Compare commits
36 Commits
v3.4.7
...
feat-platf
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
6c18971f00 | ||
|
|
d488c88e78 | ||
|
|
baae842210 | ||
|
|
ec1fb838b6 | ||
|
|
13281179df | ||
|
|
276a42c9a1 | ||
|
|
7a70a730ba | ||
|
|
d0fe59631c | ||
|
|
106892e933 | ||
|
|
19543a41b3 | ||
|
|
b172b760ab | ||
|
|
4b5d49cb41 | ||
|
|
3fd35b6058 | ||
|
|
5f86c4ab99 | ||
|
|
c94a7f6629 | ||
|
|
7d6beb4141 | ||
|
|
e2117e690a | ||
|
|
fb791290e2 | ||
|
|
5dd1488b5d | ||
|
|
529cd64d82 | ||
|
|
d2bd3e8da8 | ||
|
|
e42ce7dd86 | ||
|
|
40709462ee | ||
|
|
2ad6c01a4d | ||
|
|
70c12e788e | ||
|
|
1713791c90 | ||
|
|
9aa23fd412 | ||
|
|
e4ba09cd93 | ||
|
|
171fdf1fbc | ||
|
|
01f4e0b961 | ||
|
|
be2d5a91c7 | ||
|
|
a1d89d9478 | ||
|
|
98d1dc3b65 | ||
|
|
b80eb3acc0 | ||
|
|
05ccc1995b | ||
|
|
0de244889e |
101
README.md
101
README.md
@@ -17,80 +17,92 @@ _✨ 易上手的多平台 LLM 聊天机器人及开发框架 ✨_
|
||||
<img alt="Static Badge" src="https://img.shields.io/badge/QQ群-322154837-purple">
|
||||
[](https://wakatime.com/badge/user/915e5316-99c6-4563-a483-ef186cf000c9/project/018e705a-a1a7-409a-a849-3013485e6c8e)
|
||||
[](https://codecov.io/gh/Soulter/AstrBot)
|
||||
[<img src="https://api.gitsponsors.com/api/badge/img?id=575865240" height="20">](https://api.gitsponsors.com/api/badge/link?p=XEpbdGxlitw/RbcwiTX93UMzNK/jgDYC8NiSzamIPMoKvG2lBFmyXhSS/b0hFoWlBBMX2L5X5CxTDsUdyvcIEHTOfnkXz47UNOZvMwyt5CzbYpq0SEzsSV1OJF1cCo90qC/ZyYKYOWedal3MhZ3ikw==)
|
||||
</a>
|
||||
|
||||
<a href="https://astrbot.lwl.lol/">查看文档</a> |
|
||||
<a href="https://astrbot.app/">查看文档</a> |
|
||||
<a href="https://github.com/Soulter/AstrBot/issues">问题提交</a>
|
||||
</div>
|
||||
|
||||
AstrBot 是一个松耦合、异步、支持多消息平台部署、具有易用的插件系统和完善的大语言模型(LLM)接入功能的聊天机器人及开发框架。
|
||||
|
||||
## ✨ 多消息平台部署
|
||||
## ✨ 主要功能
|
||||
|
||||
1. QQ 群、QQ 频道、微信个人号、Telegram。
|
||||
2. 内置 Web Chat,即使不部署到消息平台也能聊天。
|
||||
3. 支持文本转图片,Markdown 渲染。
|
||||
|
||||
## ✨ 多 LLM 配置
|
||||
1. **大语言模型对话**。支持各种大语言模型,包括 OpenAI API、Google Gemini、Llama、Deepseek、ChatGLM 等,支持接入本地部署的大模型,通过 Ollama、LLMTuner。具有多轮对话、人格情境、多模态能力,支持图片理解、语音转文字(Whisper)。
|
||||
2. **多消息平台接入**。支持接入 QQ(OneBot)、QQ 频道、微信(Gewechat、VChat)、Telegram。后续将支持钉钉、飞书、Discord、WhatsApp、小爱音响。支持速率限制、白名单、关键词过滤、百度内容审核。
|
||||
3. **Agent**。原生支持部分 Agent 能力,如代码执行器、自然语言待办、网页搜索。对接 [Dify 平台](https://astrbot.app/others/dify.html),便捷接入 Dify 智能助手、知识库和 Dify 工作流。
|
||||
4. **插件扩展**。深度优化的插件机制,支持[开发插件](https://astrbot.app/dev/plugin.html)扩展功能,极简开发。已支持安装多个插件。
|
||||
5. **可视化管理面板**。支持可视化修改配置、插件管理、日志查看等功能,降低配置难度。集成 WebChat,可在面板上与大模型对话。
|
||||
6. **高稳定性、高模块化**。基于事件总线和流水线的架构设计,高度模块化,低耦合。
|
||||
|
||||
1. 适配 OpenAI API,支持接入 Gemini、GPT、Llama、Claude、DeepSeek、GLM 等各种大语言模型。
|
||||
2. 支持 OneAPI 等分发平台。
|
||||
3. 支持 LLMTuner 载入微调模型。
|
||||
4. 支持 Ollama 载入自部署模型。
|
||||
4. 支持网页搜索(Web Search)、自然语言待办提醒。
|
||||
5. 支持 Whisper 语音转文字
|
||||
> [!TIP]
|
||||
> 管理面板在线体验 Demo: [https://demo.astrbot.app/](https://demo.astrbot.app/)
|
||||
>
|
||||
> 用户名: `astrbot`, 密码: `astrbot`。此 Demo 未配置 LLM,因此无法在聊天页使用大模型。
|
||||
|
||||
## ✨ 管理面板
|
||||
## ✨ 使用方式
|
||||
|
||||
1. 支持可视化修改配置
|
||||
2. 日志实时查看
|
||||
3. 简单的信息统计
|
||||
4. 插件管理
|
||||
#### Docker 部署
|
||||
|
||||
## ✨ 支持 Dify
|
||||
请参阅官方文档 [使用 Docker 部署 AstrBot](https://astrbot.app/deploy/astrbot/docker.html#%E4%BD%BF%E7%94%A8-docker-%E9%83%A8%E7%BD%B2-astrbot) 。
|
||||
|
||||
1. 对接了 LLMOps 平台 Dify,便捷接入 Dify 智能助手、知识库和 Dify 工作流
|
||||
#### Windows 一键安装器部署
|
||||
|
||||
## ✨ 代码执行器(Beta)
|
||||
需要电脑上安装有 Python(>3.10)。请参阅官方文档 [使用 Windows 一键安装器部署 AstrBot](https://astrbot.app/deploy/astrbot/windows.html) 。
|
||||
|
||||
基于 Docker 的沙箱化代码执行器(Beta 测试中)
|
||||
|
||||
> [!NOTE]
|
||||
> 文件输入/输出目前仅测试了 Napcat(QQ), Lagrange(QQ)
|
||||
|
||||
<div align='center'>
|
||||
|
||||
<img src="https://github.com/user-attachments/assets/4ee688d9-467d-45c8-99d6-368f9a8a92d8" width="600">
|
||||
|
||||
</div>
|
||||
|
||||
## ✨ 云部署
|
||||
#### Replit 部署
|
||||
|
||||
[](https://repl.it/github/Soulter/AstrBot)
|
||||
|
||||
#### CasaOS 部署
|
||||
|
||||
社区贡献的部署方式。
|
||||
|
||||
请参阅官方文档 [通过源码部署 AstrBot](https://astrbot.app/deploy/astrbot/casaos.html) 。
|
||||
|
||||
#### 手动部署
|
||||
|
||||
请参阅官方文档 [通过源码部署 AstrBot](https://astrbot.app/deploy/astrbot/cli.html) 。
|
||||
|
||||
|
||||
## ⚡ 消息平台支持情况
|
||||
|
||||
|
||||
| 平台 | 支持性 | 详情 | 消息类型 |
|
||||
| -------- | ------- | ------- | ------ |
|
||||
| QQ | ✔ | 私聊、群聊 | 文字、图片、语音 |
|
||||
| QQ 官方API | ✔ | 私聊、群聊,QQ 频道私聊、群聊 | 文字、图片 |
|
||||
| 微信 | ✔ | [Gewechat](https://github.com/Devo919/Gewechat)。微信个人号私聊、群聊 | 文字 |
|
||||
| [Telegram](https://github.com/Soulter/astrbot_plugin_telegram) | ✔ | 私聊、群聊 | 文字、图片 |
|
||||
| 微信对话开放平台 | 🚧 | 计划内 | - |
|
||||
| 飞书 | 🚧 | 计划内 | - |
|
||||
| Discord | 🚧 | 计划内 | - |
|
||||
| WhatsApp | 🚧 | 计划内 | - |
|
||||
| 小爱音响 | 🚧 | 计划内 | - |
|
||||
|
||||
## ❤️ 贡献
|
||||
|
||||
欢迎任何 Issues/Pull Requests!只需要将你的更改提交到此项目 :)
|
||||
|
||||
对于新功能的添加,请先通过 Issue 讨论。
|
||||
|
||||
## 🔭 展望
|
||||
|
||||
1. 更强大的 Agent 系统。
|
||||
2. 打造插件工作流平台。
|
||||
|
||||
## ✨ Support
|
||||
## 🌟 支持
|
||||
|
||||
- Star 这个项目!
|
||||
- 在[爱发电](https://afdian.com/a/soulter)支持我!
|
||||
- 在[微信](https://drive.soulter.top/f/pYfA/d903f4fa49a496fda3f16d2be9e023b5.png)支持我~
|
||||
|
||||
|
||||
|
||||
## ✨ Demo
|
||||
|
||||
> [!NOTE]
|
||||
> 代码执行器的文件输入/输出目前仅测试了 Napcat(QQ), Lagrange(QQ)
|
||||
|
||||
<div align='center'>
|
||||
|
||||
<img src="https://github.com/user-attachments/assets/4ee688d9-467d-45c8-99d6-368f9a8a92d8" width="600">
|
||||
|
||||
_✨基于 Docker 的沙箱化代码执行器(Beta 测试中)✨_
|
||||
|
||||
<img src="https://github.com/user-attachments/assets/0378f407-6079-4f64-ae4c-e97ab20611d2" height=500>
|
||||
|
||||
_✨ 多模态、网页搜索、长文本转图片(可配置) ✨_
|
||||
@@ -114,6 +126,13 @@ _✨ 内置 Web Chat,在线与机器人交互 ✨_
|
||||
|
||||
</div>
|
||||
|
||||
## ⭐ Star History
|
||||
|
||||
> [!TIP]
|
||||
> 如果本项目对您的生活 / 工作产生了帮助,或者您关注本项目的未来发展,请给项目 Star,这是我维护这个开源项目的动力 <3
|
||||
|
||||
[](https://star-history.com/#soulter/astrbot&Date)
|
||||
|
||||
|
||||
<!-- ## ✨ ATRI [Beta 测试]
|
||||
|
||||
|
||||
@@ -7,7 +7,6 @@ from astrbot.core.star.register import register_llm_tool as llm_tool
|
||||
__all__ = [
|
||||
"AstrBotConfig",
|
||||
"logger",
|
||||
"personalities",
|
||||
"html_renderer",
|
||||
"llm_tool",
|
||||
"sp"
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
|
||||
from astrbot.core.config.astrbot_config import AstrBotConfig
|
||||
from astrbot import logger
|
||||
from astrbot.core.utils.personality import personalities
|
||||
from astrbot.core import html_renderer
|
||||
from astrbot.core.star.register import register_llm_tool as llm_tool
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
如需修改配置,请在 `data/cmd_config.json` 中修改或者在管理面板中可视化修改。
|
||||
"""
|
||||
|
||||
VERSION = "3.4.7"
|
||||
VERSION = "3.4.11"
|
||||
DB_PATH = "data/data_v3.db"
|
||||
|
||||
# 默认配置
|
||||
@@ -22,6 +22,9 @@ DEFAULT_CONFIG = {
|
||||
"id_whitelist_log": True,
|
||||
"wl_ignore_admin_on_group": True,
|
||||
"wl_ignore_admin_on_friend": True,
|
||||
"reply_with_mention": False,
|
||||
"reply_with_quote": False,
|
||||
"path_mapping": []
|
||||
},
|
||||
"provider": [],
|
||||
"provider_settings": {
|
||||
@@ -34,8 +37,8 @@ DEFAULT_CONFIG = {
|
||||
"prompt_prefix": "",
|
||||
},
|
||||
"provider_stt_settings": {
|
||||
"enable": False,
|
||||
"provider_id": "",
|
||||
"enable": False,
|
||||
"provider_id": "",
|
||||
},
|
||||
"content_safety": {
|
||||
"internal_keywords": {"enable": True, "extra_keywords": []},
|
||||
@@ -61,9 +64,9 @@ DEFAULT_CONFIG = {
|
||||
"name": "default",
|
||||
"prompt": "如果用户寻求帮助或者打招呼,请告诉他可以用 /help 查看 AstrBot 帮助。",
|
||||
"begin_dialogs": [],
|
||||
"mood_imitation_dialogs": []
|
||||
"mood_imitation_dialogs": [],
|
||||
}
|
||||
]
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
@@ -93,6 +96,26 @@ CONFIG_METADATA_2 = {
|
||||
"ws_reverse_port": 6199,
|
||||
},
|
||||
"vchat(微信)": {"id": "default", "type": "vchat", "enable": False},
|
||||
"gewechat(微信)": {
|
||||
"id": "gwchat",
|
||||
"type": "gewechat",
|
||||
"enable": False,
|
||||
"base_url": "http://localhost:2531",
|
||||
"nickname": "soulter",
|
||||
"host": "localhost",
|
||||
"port": 11451,
|
||||
},
|
||||
"mispeaker(小爱音箱)": {
|
||||
"id": "mispeaker",
|
||||
"type": "mispeaker",
|
||||
"enable": False,
|
||||
"username": "",
|
||||
"password": "",
|
||||
"did": "",
|
||||
"activate_word": "测试",
|
||||
"deactivate_word": "停止",
|
||||
"interval": 1,
|
||||
},
|
||||
},
|
||||
"items": {
|
||||
"id": {
|
||||
@@ -178,7 +201,7 @@ CONFIG_METADATA_2 = {
|
||||
},
|
||||
"enable_id_white_list": {
|
||||
"description": "启用 ID 白名单",
|
||||
"type": "bool"
|
||||
"type": "bool",
|
||||
},
|
||||
"id_whitelist": {
|
||||
"description": "ID 白名单",
|
||||
@@ -199,6 +222,22 @@ CONFIG_METADATA_2 = {
|
||||
"description": "管理员私聊消息无视 ID 白名单",
|
||||
"type": "bool",
|
||||
},
|
||||
"reply_with_mention": {
|
||||
"description": "回复时 @ 发送者",
|
||||
"type": "bool",
|
||||
"hint": "启用后,机器人回复消息时会 @ 发送者。实际效果以具体的平台适配器为准。",
|
||||
},
|
||||
"reply_with_quote": {
|
||||
"description": "回复时引用消息",
|
||||
"type": "bool",
|
||||
"hint": "启用后,机器人回复消息时会引用原消息。实际效果以具体的平台适配器为准。",
|
||||
},
|
||||
"path_mapping": {
|
||||
"description": "路径映射",
|
||||
"type": "list",
|
||||
"obvious_hint": True,
|
||||
"hint": "此功能解决由于文件系统不一致导致路径不存在的问题。格式为 <原路径>:<映射路径>。如 `/app/.config/QQ:/var/lib/docker/volumes/xxxx/_data`。这样,当消息平台下发的事件中图片和语音路径以 `/app/.config/QQ` 开头时,开头被替换为 `/var/lib/docker/volumes/xxxx/_data`。这在 AstrBot 或者平台协议端使用 Docker 部署时特别有用。",
|
||||
}
|
||||
},
|
||||
},
|
||||
"content_safety": {
|
||||
@@ -264,7 +303,7 @@ CONFIG_METADATA_2 = {
|
||||
"type": "openai_chat_completion",
|
||||
"enable": True,
|
||||
"key": ["ollama"], # ollama 的 key 默认是 ollama
|
||||
"api_base": "http://localhost:11434",
|
||||
"api_base": "http://localhost:11434/v1",
|
||||
"model_config": {
|
||||
"model": "llama3.1-8b",
|
||||
},
|
||||
@@ -342,14 +381,14 @@ CONFIG_METADATA_2 = {
|
||||
"id": "whisper",
|
||||
"type": "openai_whisper_selfhost",
|
||||
"model": "tiny",
|
||||
}
|
||||
},
|
||||
},
|
||||
"items": {
|
||||
"whisper_hint": {
|
||||
"description": "本地部署 Whisper 模型须知",
|
||||
"type": "string",
|
||||
"hint": "启用前请 pip 安装 openai-whisper 库(N卡用户大约下载 2GB,主要是 torch 和 cuda,CPU 用户大约下载 1 GB),并且安装 ffmpeg。否则将无法正常转文字。",
|
||||
"obvious_hint": True
|
||||
"obvious_hint": True,
|
||||
},
|
||||
"id": {
|
||||
"description": "ID",
|
||||
@@ -375,7 +414,8 @@ CONFIG_METADATA_2 = {
|
||||
"api_base": {
|
||||
"description": "API Base URL",
|
||||
"type": "string",
|
||||
"hint": "API Base URL 请在在模型提供商处获得。支持 Ollama 开放的 API 地址。如果您确认填写正确但是使用时出现了 404 异常,可以尝试在地址末尾加上 `/v1`。",
|
||||
"hint": "API Base URL 请在在模型提供商处获得。如使用时出现了 404 报错,可以尝试在地址末尾加上 `/v1`。",
|
||||
"obvious_hint": True,
|
||||
},
|
||||
"base_model_path": {
|
||||
"description": "基座模型路径",
|
||||
@@ -439,7 +479,7 @@ CONFIG_METADATA_2 = {
|
||||
"description": "Dify Workflow 输出变量名",
|
||||
"type": "string",
|
||||
"hint": "Dify Workflow 输出变量名。当应用类型为 workflow 时才使用。默认为 astrbot_wf_output。",
|
||||
}
|
||||
},
|
||||
},
|
||||
},
|
||||
"provider_settings": {
|
||||
@@ -450,7 +490,7 @@ CONFIG_METADATA_2 = {
|
||||
"description": "启用大语言模型聊天",
|
||||
"type": "bool",
|
||||
"hint": "如需切换大语言模型提供商,请使用 `/provider` 命令。",
|
||||
"obvious_hint": True
|
||||
"obvious_hint": True,
|
||||
},
|
||||
"wake_prefix": {
|
||||
"description": "LLM 聊天额外唤醒前缀",
|
||||
@@ -492,7 +532,7 @@ CONFIG_METADATA_2 = {
|
||||
"name": "",
|
||||
"prompt": "",
|
||||
"begin_dialogs": [],
|
||||
"mood_imitation_dialogs": []
|
||||
"mood_imitation_dialogs": [],
|
||||
}
|
||||
},
|
||||
"tmpl_display_title": "name",
|
||||
@@ -501,7 +541,7 @@ CONFIG_METADATA_2 = {
|
||||
"description": "人格名称",
|
||||
"type": "string",
|
||||
"hint": "人格名称,用于在多个人格中区分。使用 /persona 指令可切换人格。在 大语言模型设置 处可以设置默认人格。",
|
||||
"obvious_hint": True
|
||||
"obvious_hint": True,
|
||||
},
|
||||
"prompt": {
|
||||
"description": "设定(系统提示词)",
|
||||
@@ -513,18 +553,17 @@ CONFIG_METADATA_2 = {
|
||||
"type": "list",
|
||||
"items": {},
|
||||
"hint": "可选。在每个对话前会插入这些预设对话。格式要求:第一句为用户,第二句为助手,以此类推。",
|
||||
"obvious_hint": True
|
||||
"obvious_hint": True,
|
||||
},
|
||||
"mood_imitation_dialogs": {
|
||||
"description": "对话风格模仿",
|
||||
"type": "list",
|
||||
"items": {},
|
||||
"hint": "旨在让模型尽可能模仿学习到所填写的对话的语气风格。格式和 `预设对话` 一样。",
|
||||
"obvious_hint": True
|
||||
"obvious_hint": True,
|
||||
},
|
||||
}
|
||||
},
|
||||
},
|
||||
|
||||
"provider_stt_settings": {
|
||||
"description": "语音转文本(STT)",
|
||||
"type": "object",
|
||||
@@ -533,7 +572,7 @@ CONFIG_METADATA_2 = {
|
||||
"description": "启用语音转文本(STT)",
|
||||
"type": "bool",
|
||||
"hint": "启用前请在 服务提供商配置 处创建支持 语音转文本任务 的提供商。如 whisper。",
|
||||
"obvious_hint": True
|
||||
"obvious_hint": True,
|
||||
},
|
||||
"provider_id": {
|
||||
"description": "提供商 ID,不填则默认第一个STT提供商",
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import traceback
|
||||
import asyncio
|
||||
import time
|
||||
import threading
|
||||
@@ -81,12 +82,30 @@ class AstrBotCoreLifecycle:
|
||||
for task in self.star_context._register_tasks:
|
||||
extra_tasks.append(asyncio.create_task(task, name=task.__name__))
|
||||
|
||||
self.curr_tasks = [event_bus_task, *platform_tasks, *extra_tasks]
|
||||
# self.curr_tasks = [event_bus_task, *platform_tasks, *extra_tasks]
|
||||
|
||||
tasks_ = [event_bus_task, *platform_tasks, *extra_tasks]
|
||||
for task in tasks_:
|
||||
self.curr_tasks.append(asyncio.create_task(self._task_wrapper(task), name=task.get_name()))
|
||||
|
||||
self.start_time = int(time.time())
|
||||
|
||||
async def _task_wrapper(self, task: asyncio.Task):
|
||||
try:
|
||||
await task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
except Exception as e:
|
||||
|
||||
logger.error(f"------- 任务 {task.get_name()} 发生错误: {e}")
|
||||
for line in traceback.format_exc().split("\n"):
|
||||
logger.error(f"| {line}")
|
||||
logger.error("-------")
|
||||
|
||||
async def start(self):
|
||||
self._load()
|
||||
logger.info("AstrBot 启动完成。")
|
||||
|
||||
await asyncio.gather(*self.curr_tasks, return_exceptions=True)
|
||||
|
||||
async def stop(self):
|
||||
|
||||
@@ -5,7 +5,7 @@ from ..stage import Stage, register_stage
|
||||
from ..context import PipelineContext
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.message.components import Plain, Record
|
||||
from astrbot.core.message.components import Plain, Record, Image
|
||||
|
||||
@register_stage
|
||||
class PreProcessStage(Stage):
|
||||
@@ -16,26 +16,39 @@ class PreProcessStage(Stage):
|
||||
self.plugin_manager = ctx.plugin_manager
|
||||
|
||||
self.stt_settings: dict = self.config.get('provider_stt_settings', {})
|
||||
self.platform_settings: dict = self.config.get('platform_settings', {})
|
||||
|
||||
|
||||
async def process(self, event: AstrMessageEvent) -> Union[None, AsyncGenerator[None, None]]:
|
||||
'''在处理事件之前的预处理'''
|
||||
|
||||
# 路径映射
|
||||
if mappings := self.platform_settings.get('path_mapping', []):
|
||||
# 支持 Record,Image 消息段的路径映射。
|
||||
message_chain = event.get_messages()
|
||||
|
||||
for idx, component in enumerate(message_chain):
|
||||
if isinstance(component, (Record, Image)) and component.url:
|
||||
for mapping in mappings:
|
||||
from_, to_ = mapping.split(":")
|
||||
from_ = from_.removesuffix("/")
|
||||
to_ = to_.removesuffix("/")
|
||||
|
||||
url = component.url.removeprefix("file://")
|
||||
if url.startswith(from_):
|
||||
component.url = url.replace(from_, to_, 1)
|
||||
logger.debug(f"路径映射: {url} -> {component.url}")
|
||||
message_chain[idx] = component
|
||||
|
||||
# STT
|
||||
if self.stt_settings.get('enable', False):
|
||||
# STT 处理
|
||||
# TODO: 独立
|
||||
stt_provider = self.plugin_manager.context.provider_manager.curr_stt_provider_inst
|
||||
if stt_provider:
|
||||
message_chain = event.get_messages()
|
||||
for idx, component in enumerate(message_chain):
|
||||
if isinstance(component, Record) and component.url:
|
||||
|
||||
path = component.url
|
||||
|
||||
path.removeprefix("file:///")
|
||||
|
||||
path = component.url.removeprefix("file://")
|
||||
retry = 5
|
||||
|
||||
for i in range(retry):
|
||||
try:
|
||||
result = await stt_provider.get_text(audio_url=path)
|
||||
@@ -48,7 +61,7 @@ class PreProcessStage(Stage):
|
||||
except FileNotFoundError as e:
|
||||
# napcat workaround
|
||||
logger.warning(e)
|
||||
logger.warning(f"语音文件不存在: {path}, 重试中: {i + 1}/{retry}")
|
||||
logger.warning(f"重试中: {i + 1}/{retry}")
|
||||
await asyncio.sleep(0.5)
|
||||
continue
|
||||
except BaseException as e:
|
||||
|
||||
@@ -17,6 +17,13 @@ class LLMRequestSubStage(Stage):
|
||||
|
||||
async def initialize(self, ctx: PipelineContext) -> None:
|
||||
self.ctx = ctx
|
||||
self.bot_wake_prefixs = ctx.astrbot_config['wake_prefix'] # list
|
||||
self.provider_wake_prefix = ctx.astrbot_config['provider_settings']['wake_prefix'] # str
|
||||
|
||||
for bwp in self.bot_wake_prefixs:
|
||||
if self.provider_wake_prefix.startswith(bwp):
|
||||
logger.info(f"识别 LLM 聊天额外唤醒前缀 {self.provider_wake_prefix} 以机器人唤醒前缀 {bwp} 开头,已自动去除。")
|
||||
self.provider_wake_prefix = self.provider_wake_prefix[len(bwp):]
|
||||
|
||||
async def process(self, event: AstrMessageEvent, _nested: bool = False) -> Union[None, AsyncGenerator[None, None]]:
|
||||
req: ProviderRequest = None
|
||||
@@ -30,10 +37,10 @@ class LLMRequestSubStage(Stage):
|
||||
assert isinstance(req, ProviderRequest), "provider_request 必须是 ProviderRequest 类型。"
|
||||
else:
|
||||
req = ProviderRequest(prompt="", image_urls=[])
|
||||
if self.ctx.astrbot_config['provider_settings']['wake_prefix']:
|
||||
if not event.message_str.startswith(self.ctx.astrbot_config['provider_settings']['wake_prefix']):
|
||||
if self.provider_wake_prefix:
|
||||
if not event.message_str.startswith(self.provider_wake_prefix):
|
||||
return
|
||||
req.prompt = event.message_str[len(self.ctx.astrbot_config['provider_settings']['wake_prefix']):]
|
||||
req.prompt = event.message_str[len(self.provider_wake_prefix):]
|
||||
req.func_tool = self.ctx.plugin_manager.context.get_llm_tool_manager()
|
||||
for comp in event.message_obj.message:
|
||||
if isinstance(comp, Image):
|
||||
@@ -98,5 +105,5 @@ class LLMRequestSubStage(Stage):
|
||||
|
||||
except BaseException as e:
|
||||
logger.error(traceback.format_exc())
|
||||
event.set_result(MessageEventResult().message("AstrBot 请求 LLM 资源失败:" + str(e)))
|
||||
event.set_result(MessageEventResult().message(f"AstrBot 请求失败。\n错误类型: {type(e).__name__}\n错误信息: {str(e)}"))
|
||||
return
|
||||
@@ -3,8 +3,9 @@ from typing import Union, AsyncGenerator
|
||||
from ..stage import register_stage
|
||||
from ..context import PipelineContext
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
from astrbot.core.platform.message_type import MessageType
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.message.components import Plain, Image
|
||||
from astrbot.core.message.components import Plain, Image, At, Reply
|
||||
from astrbot.core import html_renderer
|
||||
from astrbot.core.star.star_handler import star_handlers_registry, EventType
|
||||
|
||||
@@ -13,6 +14,8 @@ class ResultDecorateStage:
|
||||
async def initialize(self, ctx: PipelineContext):
|
||||
self.ctx = ctx
|
||||
self.reply_prefix = ctx.astrbot_config['platform_settings']['reply_prefix']
|
||||
self.reply_with_mention = ctx.astrbot_config['platform_settings']['reply_with_mention']
|
||||
self.reply_with_quote = ctx.astrbot_config['platform_settings']['reply_with_quote']
|
||||
self.t2i = ctx.astrbot_config['t2i']
|
||||
|
||||
async def process(self, event: AstrMessageEvent) -> Union[None, AsyncGenerator[None, None]]:
|
||||
@@ -48,4 +51,10 @@ class ResultDecorateStage:
|
||||
if time.time() - render_start > 3:
|
||||
logger.warning("文本转图片耗时超过了 3 秒,如果觉得很慢可以使用 /t2i 关闭文本转图片模式。")
|
||||
if url:
|
||||
result.chain = [Image.fromURL(url)]
|
||||
result.chain = [Image.fromURL(url)]
|
||||
|
||||
if self.reply_with_mention and event.get_message_type() != MessageType.FRIEND_MESSAGE:
|
||||
result.chain.insert(0, At(qq=event.get_sender_id()))
|
||||
|
||||
if self.reply_with_quote:
|
||||
result.chain.insert(0, Reply(id=event.message_obj.message_id))
|
||||
@@ -25,6 +25,10 @@ class PlatformManager():
|
||||
from .sources.qqofficial.qqofficial_platform_adapter import QQOfficialPlatformAdapter # noqa: F401
|
||||
case "vchat":
|
||||
from .sources.vchat.vchat_platform_adapter import VChatPlatformAdapter # noqa: F401
|
||||
case "gewechat":
|
||||
from .sources.gewechat.gewechat_platform_adapter import GewechatPlatformAdapter # noqa: F401
|
||||
case "mispeaker":
|
||||
from .sources.mispeaker.mispeaker_adapter import MiSpeakerPlatformAdapter # noqa: F401
|
||||
|
||||
|
||||
async def initialize(self):
|
||||
|
||||
@@ -27,6 +27,8 @@ def register_platform_adapter(
|
||||
default_config_tmpl['type'] = adapter_name
|
||||
if 'enable' not in default_config_tmpl:
|
||||
default_config_tmpl['enable'] = False
|
||||
if 'id' not in default_config_tmpl:
|
||||
default_config_tmpl['id'] = adapter_name
|
||||
|
||||
pm = PlatformMetadata(
|
||||
name=adapter_name,
|
||||
|
||||
@@ -28,7 +28,9 @@ class AiocqhttpMessageEvent(AstrMessageEvent):
|
||||
elif segment.file and segment.file.startswith("http"):
|
||||
image_file_path = await download_image_by_url(segment.file)
|
||||
image_base64 = file_to_base64(image_file_path)
|
||||
d['data']['file'] = image_base64
|
||||
d['data'] = {
|
||||
'file': image_base64,
|
||||
}
|
||||
ret.append(d)
|
||||
return ret
|
||||
|
||||
|
||||
266
astrbot/core/platform/sources/gewechat/client.py
Normal file
266
astrbot/core/platform/sources/gewechat/client.py
Normal file
@@ -0,0 +1,266 @@
|
||||
import threading
|
||||
import asyncio
|
||||
import aiohttp
|
||||
import quart
|
||||
|
||||
from astrbot.api.platform import AstrBotMessage, MessageMember, MessageType
|
||||
from astrbot.api.message_components import Plain, Image, At
|
||||
from astrbot.api import logger, sp
|
||||
|
||||
class SimpleGewechatClient():
|
||||
'''针对 Gewechat 的简单实现。
|
||||
|
||||
@author: Soulter
|
||||
@website: https://github.com/Soulter
|
||||
'''
|
||||
def __init__(self, base_url: str, nickname: str, host: str, port: int, event_queue: asyncio.Queue):
|
||||
self.base_url = base_url
|
||||
if self.base_url.endswith('/'):
|
||||
self.base_url = self.base_url[:-1]
|
||||
|
||||
self.base_url += "/v2/api"
|
||||
|
||||
if isinstance(port, str):
|
||||
port = int(port)
|
||||
|
||||
self.token = None
|
||||
self.headers = {}
|
||||
self.nickname = nickname
|
||||
self.appid = sp.get(f"gewechat-appid-{nickname}", "")
|
||||
self.callback_url = None
|
||||
|
||||
self.server = quart.Quart(__name__)
|
||||
self.server.add_url_rule('/astrbot-gewechat/callback', view_func=self.callback, methods=['POST'])
|
||||
|
||||
self.host = host
|
||||
self.port = port
|
||||
|
||||
self.event_queue = event_queue
|
||||
|
||||
async def get_token_id(self):
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(f"{self.base_url}/tools/getTokenId") as resp:
|
||||
json_blob = await resp.json()
|
||||
self.token = json_blob['data']
|
||||
logger.info(f"获取到 Gewechat Token: {self.token}")
|
||||
self.headers = {
|
||||
"X-GEWE-TOKEN": self.token
|
||||
}
|
||||
|
||||
async def _convert(self, data: dict) -> AstrBotMessage:
|
||||
type_name = data['TypeName']
|
||||
if type_name == "Offline":
|
||||
logger.critical("收到 gewechat 下线通知。")
|
||||
return
|
||||
abm = AstrBotMessage()
|
||||
d = data['Data']
|
||||
msg_type = d['MsgType']
|
||||
|
||||
match msg_type:
|
||||
case 1:
|
||||
from_user_name = d['FromUserName']['string'] # 消息来源
|
||||
d['to_wxid'] = from_user_name # 用于发信息
|
||||
|
||||
user_id = "" # 发送人 wxid
|
||||
content = d['Content']['string'] # 消息内容
|
||||
user_real_name = d['PushContent'].split(' : ')[0] # 真实昵称
|
||||
user_real_name = user_real_name.replace('在群聊中@了你', '') # trick
|
||||
abm.self_id = data['Wxid'] # 机器人的 wxid
|
||||
at_me = False
|
||||
if "@chatroom" in from_user_name:
|
||||
abm.type = MessageType.GROUP_MESSAGE
|
||||
_t = content.split(':\n')
|
||||
user_id = _t[0]
|
||||
content = _t[1]
|
||||
if '\u2005' in content:
|
||||
# at
|
||||
content = content.split('\u2005')[1]
|
||||
|
||||
abm.group_id = from_user_name
|
||||
|
||||
# at
|
||||
msg_source = d['MsgSource']
|
||||
if f'<atuserlist><![CDATA[,{abm.self_id}]]>' in msg_source \
|
||||
or f'<atuserlist><![CDATA[{abm.self_id}]]>' in msg_source:
|
||||
at_me = True
|
||||
|
||||
else:
|
||||
abm.type = MessageType.FRIEND_MESSAGE
|
||||
user_id = from_user_name
|
||||
abm.session_id = from_user_name
|
||||
abm.sender = MessageMember(user_id, user_real_name)
|
||||
abm.message = [Plain(content)]
|
||||
|
||||
if at_me:
|
||||
abm.message.insert(0, At(qq=abm.self_id))
|
||||
|
||||
abm.message_id = str(d['MsgId'])
|
||||
abm.raw_message = d
|
||||
abm.message_str = content
|
||||
|
||||
logger.info(f"abm: {abm}")
|
||||
return abm
|
||||
case _:
|
||||
logger.error(f"未实现的消息类型: {msg_type}")
|
||||
|
||||
async def callback(self):
|
||||
data = await quart.request.json
|
||||
logger.debug(f"收到 gewechat 回调: {data}")
|
||||
|
||||
if data.get('testMsg', None):
|
||||
return quart.jsonify({"r": "AstrBot ACK"})
|
||||
|
||||
abm = await self._convert(data)
|
||||
|
||||
if abm:
|
||||
coro = getattr(self, "on_event_received")
|
||||
if coro:
|
||||
await coro(abm)
|
||||
|
||||
return quart.jsonify({"r": "AstrBot ACK"})
|
||||
|
||||
async def _set_callback_url(self):
|
||||
logger.info("设置回调,请等待...")
|
||||
await asyncio.sleep(3)
|
||||
callback_url = f"http://{self.host}:{self.port}/astrbot-gewechat/callback"
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"{self.base_url}/tools/setCallback",
|
||||
headers=self.headers,
|
||||
json={
|
||||
"token": self.token,
|
||||
"callbackUrl": callback_url
|
||||
}
|
||||
) as resp:
|
||||
json_blob = await resp.json()
|
||||
logger.info(f"设置回调结果: {json_blob}")
|
||||
if json_blob['ret'] != 200:
|
||||
raise Exception(f"设置回调失败: {json_blob}")
|
||||
logger.info(f"将在 {callback_url} 上接收 gewechat 下发的消息。如果一直没收到消息请先尝试重启 AstrBot。")
|
||||
|
||||
async def start_polling(self):
|
||||
|
||||
# 设置回调
|
||||
threading.Thread(target=asyncio.run, args=(self._set_callback_url(),)).start()
|
||||
|
||||
|
||||
await self.server.run_task(
|
||||
host=self.host,
|
||||
port=self.port,
|
||||
shutdown_trigger=self.shutdown_trigger_placeholder
|
||||
)
|
||||
|
||||
async def shutdown_trigger_placeholder(self):
|
||||
while not self.event_queue.closed:
|
||||
await asyncio.sleep(1)
|
||||
logger.info("gewechat 适配器已关闭。")
|
||||
|
||||
async def check_online(self, appid: str):
|
||||
# /login/checkOnline
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"{self.base_url}/login/checkOnline",
|
||||
headers=self.headers,
|
||||
json={
|
||||
"appId": appid
|
||||
}
|
||||
) as resp:
|
||||
json_blob = await resp.json()
|
||||
return json_blob['data']
|
||||
|
||||
async def logout(self):
|
||||
if self.appid:
|
||||
online = await self.check_online(self.appid)
|
||||
if online:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"{self.base_url}/login/logout",
|
||||
headers=self.headers,
|
||||
json={
|
||||
"appId": self.appid
|
||||
}
|
||||
) as resp:
|
||||
json_blob = await resp.json()
|
||||
logger.info(f"登出结果: {json_blob}")
|
||||
|
||||
async def login(self):
|
||||
if self.token is None:
|
||||
await self.get_token_id()
|
||||
|
||||
if self.appid:
|
||||
online = await self.check_online(self.appid)
|
||||
if online:
|
||||
logger.info(f"APPID: {self.appid} 已在线")
|
||||
return
|
||||
|
||||
payload = {
|
||||
"appId": self.appid
|
||||
}
|
||||
|
||||
if self.appid:
|
||||
logger.info(f"使用 APPID: {self.appid}, {self.nickname}")
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"{self.base_url}/login/getLoginQrCode",
|
||||
headers=self.headers,
|
||||
json=payload
|
||||
) as resp:
|
||||
json_blob = await resp.json()
|
||||
if json_blob['ret'] != 200:
|
||||
raise Exception(f"获取二维码失败: {json_blob}")
|
||||
qr_data = json_blob['data']['qrData']
|
||||
qr_uuid = json_blob['data']['uuid']
|
||||
appid = json_blob['data']['appId']
|
||||
logger.info(f"APPID: {appid}")
|
||||
logger.warning(f"请打开该网址,然后使用微信扫描二维码登录: https://api.cl2wm.cn/api/qrcode/code?text={qr_data}")
|
||||
|
||||
# 执行登录
|
||||
retry_cnt = 64
|
||||
payload.update({
|
||||
"uuid": qr_uuid,
|
||||
"appId": appid
|
||||
})
|
||||
while retry_cnt > 0:
|
||||
retry_cnt -= 1
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"{self.base_url}/login/checkLogin",
|
||||
headers=self.headers,
|
||||
json=payload
|
||||
) as resp:
|
||||
json_blob = await resp.json()
|
||||
logger.info(f"检查登录状态: {json_blob}")
|
||||
status = json_blob['data']['status']
|
||||
nickname = json_blob['data'].get('nickName', '')
|
||||
if status == 1:
|
||||
logger.info(f"等待确认...{nickname}")
|
||||
elif status == 2:
|
||||
logger.info(f"绿泡泡平台登录成功: {nickname}")
|
||||
break
|
||||
elif status == 0:
|
||||
logger.info("等待扫码...")
|
||||
else:
|
||||
logger.warning(f"未知状态: {status}")
|
||||
await asyncio.sleep(5)
|
||||
|
||||
if appid:
|
||||
sp.put(f"gewechat-appid-{nickname}", appid)
|
||||
self.appid = appid
|
||||
logger.info(f"已保存 APPID: {appid}")
|
||||
|
||||
async def post_text(self, to_wxid, content: str):
|
||||
payload = {
|
||||
"appId": self.appid,
|
||||
"toWxid": to_wxid,
|
||||
"content": content,
|
||||
}
|
||||
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.post(
|
||||
f"{self.base_url}/message/postText",
|
||||
headers=self.headers,
|
||||
json=payload
|
||||
) as resp:
|
||||
json_blob = await resp.json()
|
||||
logger.info(f"发送消息结果: {json_blob}")
|
||||
38
astrbot/core/platform/sources/gewechat/gewechat_event.py
Normal file
38
astrbot/core/platform/sources/gewechat/gewechat_event.py
Normal file
@@ -0,0 +1,38 @@
|
||||
import random
|
||||
import asyncio
|
||||
from astrbot.core.utils.io import download_image_by_url
|
||||
from astrbot.api import logger
|
||||
from astrbot.api.event import AstrMessageEvent, MessageChain
|
||||
from astrbot.api.platform import AstrBotMessage, PlatformMetadata
|
||||
from astrbot.api.message_components import Plain, Image
|
||||
from .client import SimpleGewechatClient
|
||||
|
||||
class GewechatPlatformEvent(AstrMessageEvent):
|
||||
def __init__(
|
||||
self,
|
||||
message_str: str,
|
||||
message_obj: AstrBotMessage,
|
||||
platform_meta: PlatformMetadata,
|
||||
session_id: str,
|
||||
client: SimpleGewechatClient
|
||||
):
|
||||
super().__init__(message_str, message_obj, platform_meta, session_id)
|
||||
self.client = client
|
||||
|
||||
@staticmethod
|
||||
async def send_with_client(message: MessageChain, user_name: str):
|
||||
pass
|
||||
|
||||
|
||||
async def send(self, message: MessageChain):
|
||||
to_wxid = self.message_obj.raw_message.get('to_wxid', None)
|
||||
|
||||
if not to_wxid:
|
||||
logger.error("无法获取到 to_wxid。")
|
||||
return
|
||||
|
||||
for comp in message.chain:
|
||||
if isinstance(comp, Plain):
|
||||
await self.client.post_text(to_wxid, comp.text)
|
||||
|
||||
await super().send(message)
|
||||
@@ -0,0 +1,93 @@
|
||||
import sys
|
||||
import asyncio
|
||||
import os
|
||||
|
||||
from astrbot.api.platform import Platform, AstrBotMessage, MessageType, PlatformMetadata
|
||||
from astrbot.api.event import MessageChain
|
||||
from astrbot.api import logger
|
||||
from astrbot.core.platform.astr_message_event import MessageSesion
|
||||
from ...register import register_platform_adapter
|
||||
from .gewechat_event import GewechatPlatformEvent
|
||||
from .client import SimpleGewechatClient
|
||||
from astrbot.core.message.components import Plain
|
||||
|
||||
if sys.version_info >= (3, 12):
|
||||
from typing import override
|
||||
else:
|
||||
from typing_extensions import override
|
||||
|
||||
|
||||
@register_platform_adapter("gewechat", "基于 gewechat 的 Wechat 适配器")
|
||||
class GewechatPlatformAdapter(Platform):
|
||||
|
||||
def __init__(self, platform_config: dict, platform_settings: dict, event_queue: asyncio.Queue) -> None:
|
||||
super().__init__(event_queue)
|
||||
self.config = platform_config
|
||||
self.settingss = platform_settings
|
||||
self.test_mode = os.environ.get('TEST_MODE', 'off') == 'on'
|
||||
self.client = None
|
||||
|
||||
@override
|
||||
async def send_by_session(self, session: MessageSesion, message_chain: MessageChain):
|
||||
to_wxid = session.session_id
|
||||
if "_" in to_wxid:
|
||||
# 群聊,开启了独立会话
|
||||
_, to_wxid = to_wxid.split("_")
|
||||
|
||||
if not to_wxid:
|
||||
logger.error("无法获取到 to_wxid。")
|
||||
return
|
||||
|
||||
for comp in message_chain.chain:
|
||||
if isinstance(comp, Plain):
|
||||
await self.client.post_text(to_wxid, comp.text)
|
||||
|
||||
await super().send_by_session(session, message_chain)
|
||||
|
||||
@override
|
||||
def meta(self) -> PlatformMetadata:
|
||||
return PlatformMetadata(
|
||||
"gewechat",
|
||||
"基于 gewechat 的 Wechat 适配器",
|
||||
)
|
||||
|
||||
@override
|
||||
def run(self):
|
||||
self.client = SimpleGewechatClient(
|
||||
self.config['base_url'],
|
||||
self.config['nickname'],
|
||||
self.config['host'],
|
||||
self.config['port'],
|
||||
self._event_queue,
|
||||
)
|
||||
|
||||
async def on_event_received(abm: AstrBotMessage):
|
||||
await self.handle_msg(abm)
|
||||
|
||||
self.client.on_event_received = on_event_received
|
||||
|
||||
return self._run()
|
||||
|
||||
async def logout(self):
|
||||
await self.client.logout()
|
||||
|
||||
async def _run(self):
|
||||
await self.client.login()
|
||||
|
||||
await self.client.start_polling()
|
||||
|
||||
|
||||
async def handle_msg(self, message: AstrBotMessage):
|
||||
if message.type == MessageType.GROUP_MESSAGE:
|
||||
if self.settingss['unique_session']:
|
||||
message.session_id = message.sender.user_id + "_" + message.group_id
|
||||
|
||||
message_event = GewechatPlatformEvent(
|
||||
message_str=message.message_str,
|
||||
message_obj=message,
|
||||
platform_meta=self.meta(),
|
||||
session_id=message.session_id,
|
||||
client=self.client
|
||||
)
|
||||
|
||||
self.commit_event(message_event)
|
||||
137
astrbot/core/platform/sources/mispeaker/client.py
Normal file
137
astrbot/core/platform/sources/mispeaker/client.py
Normal file
@@ -0,0 +1,137 @@
|
||||
import os
|
||||
import json
|
||||
import asyncio
|
||||
import aiohttp
|
||||
import time
|
||||
import traceback
|
||||
from .miservice import MiAccount, MiNAService, MiIOService, miio_command, miio_command_help
|
||||
from astrbot.core import logger
|
||||
from astrbot.api.platform import AstrBotMessage, MessageMember, MessageType
|
||||
from astrbot.api.message_components import Plain, Image, At
|
||||
|
||||
class SimpleMiSpeakerClient():
|
||||
'''
|
||||
@author: Soulter
|
||||
@references: https://github.com/yihong0618/xiaogpt/blob/main/xiaogpt/xiaogpt.py
|
||||
'''
|
||||
def __init__(self, config: dict):
|
||||
self.username = config['username']
|
||||
self.password = config['password']
|
||||
self.did = config['did']
|
||||
self.store = os.path.join("data", '.mi.token')
|
||||
self.interval = float(config.get('interval', 1))
|
||||
|
||||
self.conv_query_cookies = {
|
||||
'userId': '',
|
||||
'deviceId': '',
|
||||
'serviceToken': ''
|
||||
}
|
||||
|
||||
self.MI_CONVERSATION_URL = "https://userprofile.mina.mi.com/device_profile/v2/conversation?source=dialogu&hardware={hardware}×tamp={timestamp}&limit=1"
|
||||
|
||||
self.session = aiohttp.ClientSession()
|
||||
|
||||
self.activate_word = config.get('activate_word', '测试')
|
||||
self.deactivate_word = config.get('deactivate_word', '停止')
|
||||
|
||||
self.entered = False
|
||||
|
||||
async def initialize(self):
|
||||
account = MiAccount(self.session, self.username, self.password, self.store)
|
||||
self.miio_service = MiIOService(account) # 小米设备服务
|
||||
self.mina_service = MiNAService(account) # 小爱音箱服务
|
||||
|
||||
device = await self.get_mina_device()
|
||||
|
||||
self.deviceID = device['deviceID']
|
||||
self.hardware = device['hardware']
|
||||
|
||||
with open(self.store, 'r') as f:
|
||||
data = json.load(f)
|
||||
self.userId = data['userId']
|
||||
self.serviceToken = data['micoapi'][1]
|
||||
self.conv_query_cookies['userId'] = self.userId
|
||||
self.conv_query_cookies['deviceId'] = self.deviceID
|
||||
self.conv_query_cookies['serviceToken'] = self.serviceToken
|
||||
|
||||
logger.info(f"MiSpeakerClient initialized. Conv cookies: {self.conv_query_cookies}. Hardware: {self.hardware}")
|
||||
|
||||
async def get_mina_device(self) -> dict:
|
||||
devices = await self.mina_service.device_list()
|
||||
for device in devices:
|
||||
if device['miotDID'] == self.did:
|
||||
logger.info(f"找到设备 {device['alias']}({device['name']}) 了!")
|
||||
return device
|
||||
|
||||
async def get_conv(self) -> str:
|
||||
# 时区请确保为北京时间
|
||||
async with aiohttp.ClientSession() as session:
|
||||
session.cookie_jar.update_cookies(self.conv_query_cookies)
|
||||
query_ts = int(time.time())*1000
|
||||
logger.debug(f"Querying conversation at {query_ts}")
|
||||
async with session.get(self.MI_CONVERSATION_URL.format(hardware=self.hardware, timestamp=str(query_ts))) as resp:
|
||||
json_blob = await resp.json()
|
||||
if json_blob['code'] == 0:
|
||||
data = json.loads(json_blob['data'])
|
||||
records = data.get('records', None)
|
||||
for record in records:
|
||||
if record['time'] >= query_ts - self.interval*1000:
|
||||
return record['query']
|
||||
else:
|
||||
logger.error(f"Failed to get conversation: {json_blob}")
|
||||
|
||||
return None
|
||||
|
||||
async def start_pooling(self):
|
||||
while True:
|
||||
await asyncio.sleep(self.interval)
|
||||
try:
|
||||
query = await self.get_conv()
|
||||
if not query:
|
||||
continue
|
||||
|
||||
# is wake
|
||||
if query == self.activate_word:
|
||||
self.entered = True
|
||||
await self.stop_playing()
|
||||
await self.send("我来啦!")
|
||||
continue
|
||||
elif query == self.deactivate_word:
|
||||
self.entered = False
|
||||
await self.stop_playing()
|
||||
await self.send("再见,欢迎给个 Star。")
|
||||
continue
|
||||
if not self.entered:
|
||||
continue
|
||||
|
||||
await self.send("")
|
||||
abm = await self._convert(query)
|
||||
|
||||
if abm:
|
||||
coro = getattr(self, "on_event_received")
|
||||
if coro:
|
||||
await coro(abm)
|
||||
|
||||
except BaseException as e:
|
||||
traceback.print_exc()
|
||||
logger.error(e)
|
||||
|
||||
async def _convert(self, query: str):
|
||||
abm = AstrBotMessage()
|
||||
abm.message = [Plain(query)]
|
||||
abm.message_id = str(int(time.time()))
|
||||
abm.message_str = query
|
||||
abm.raw_message = query
|
||||
abm.session_id = f"{self.hardware}_{self.did}_{self.username}"
|
||||
abm.sender = MessageMember(self.username, "主人")
|
||||
abm.self_id = f"{self.hardware}_{self.did}"
|
||||
abm.type = MessageType.FRIEND_MESSAGE
|
||||
return abm
|
||||
|
||||
async def send(self, message: str):
|
||||
text = f'5 {message}'
|
||||
await miio_command(self.miio_service, self.did, text, 'astrbot')
|
||||
|
||||
async def stop_playing(self):
|
||||
text = f'3-2'
|
||||
await miio_command(self.miio_service, self.did, text, 'astrbot')
|
||||
21
astrbot/core/platform/sources/mispeaker/miservice/LICENSE
Normal file
21
astrbot/core/platform/sources/mispeaker/miservice/LICENSE
Normal file
@@ -0,0 +1,21 @@
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2021-2022 Yonsm
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
5
astrbot/core/platform/sources/mispeaker/miservice/__init__.py
Executable file
5
astrbot/core/platform/sources/mispeaker/miservice/__init__.py
Executable file
@@ -0,0 +1,5 @@
|
||||
from .miaccount import MiAccount, MiTokenStore
|
||||
from .minaservice import MiNAService
|
||||
from .miioservice import MiIOService
|
||||
from .miiocommand import miio_command, miio_command_help
|
||||
|
||||
135
astrbot/core/platform/sources/mispeaker/miservice/miaccount.py
Normal file
135
astrbot/core/platform/sources/mispeaker/miservice/miaccount.py
Normal file
@@ -0,0 +1,135 @@
|
||||
import base64
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
import string
|
||||
from urllib import parse
|
||||
from aiohttp import ClientSession
|
||||
from aiofiles import open as async_open
|
||||
|
||||
_LOGGER = logging.getLogger(__package__)
|
||||
|
||||
|
||||
def get_random(length):
|
||||
return ''.join(random.sample(string.ascii_letters + string.digits, length))
|
||||
|
||||
|
||||
class MiTokenStore:
|
||||
|
||||
def __init__(self, token_path):
|
||||
self.token_path = token_path
|
||||
|
||||
async def load_token(self):
|
||||
if os.path.isfile(self.token_path):
|
||||
try:
|
||||
async with async_open(self.token_path) as f:
|
||||
return json.loads(await f.read())
|
||||
except Exception as e:
|
||||
_LOGGER.exception("Exception on load token from %s: %s", self.token_path, e)
|
||||
return None
|
||||
|
||||
async def save_token(self, token=None):
|
||||
if token:
|
||||
try:
|
||||
async with async_open(self.token_path, 'w') as f:
|
||||
await f.write(json.dumps(token, indent=2))
|
||||
except Exception as e:
|
||||
_LOGGER.exception("Exception on save token to %s: %s", self.token_path, e)
|
||||
elif os.path.isfile(self.token_path):
|
||||
os.remove(self.token_path)
|
||||
|
||||
|
||||
class MiAccount:
|
||||
|
||||
def __init__(self, session: ClientSession, username, password, token_store='.mi.token'):
|
||||
self.session = session
|
||||
self.username = username
|
||||
self.password = password
|
||||
self.token_store = MiTokenStore(token_store) if isinstance(token_store, str) else token_store
|
||||
self.token = None
|
||||
|
||||
async def login(self, sid):
|
||||
if not self.token:
|
||||
self.token = {'deviceId': get_random(16).upper()}
|
||||
try:
|
||||
resp = await self._serviceLogin(f'serviceLogin?sid={sid}&_json=true')
|
||||
if resp['code'] != 0:
|
||||
data = {
|
||||
'_json': 'true',
|
||||
'qs': resp['qs'],
|
||||
'sid': resp['sid'],
|
||||
'_sign': resp['_sign'],
|
||||
'callback': resp['callback'],
|
||||
'user': self.username,
|
||||
'hash': hashlib.md5(self.password.encode()).hexdigest().upper()
|
||||
}
|
||||
resp = await self._serviceLogin('serviceLoginAuth2', data)
|
||||
if resp['code'] != 0:
|
||||
raise Exception(resp)
|
||||
|
||||
self.token['userId'] = resp['userId']
|
||||
self.token['passToken'] = resp['passToken']
|
||||
|
||||
serviceToken = await self._securityTokenService(resp['location'], resp['nonce'], resp['ssecurity'])
|
||||
self.token[sid] = (resp['ssecurity'], serviceToken)
|
||||
if self.token_store:
|
||||
await self.token_store.save_token(self.token)
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
self.token = None
|
||||
if self.token_store:
|
||||
await self.token_store.save_token()
|
||||
_LOGGER.exception("Exception on login %s: %s", self.username, e)
|
||||
return False
|
||||
|
||||
async def _serviceLogin(self, uri, data=None):
|
||||
headers = {'User-Agent': 'APP/com.xiaomi.mihome APPV/6.0.103 iosPassportSDK/3.9.0 iOS/14.4 miHSTS'}
|
||||
cookies = {'sdkVersion': '3.9', 'deviceId': self.token['deviceId']}
|
||||
if 'passToken' in self.token:
|
||||
cookies['userId'] = self.token['userId']
|
||||
cookies['passToken'] = self.token['passToken']
|
||||
url = 'https://account.xiaomi.com/pass/' + uri
|
||||
async with self.session.request('GET' if data is None else 'POST', url, data=data, cookies=cookies, headers=headers) as r:
|
||||
raw = await r.read()
|
||||
resp = json.loads(raw[11:])
|
||||
_LOGGER.debug("%s: %s", uri, resp)
|
||||
return resp
|
||||
|
||||
async def _securityTokenService(self, location, nonce, ssecurity):
|
||||
nsec = 'nonce=' + str(nonce) + '&' + ssecurity
|
||||
clientSign = base64.b64encode(hashlib.sha1(nsec.encode()).digest()).decode()
|
||||
async with self.session.get(location + '&clientSign=' + parse.quote(clientSign)) as r:
|
||||
serviceToken = r.cookies['serviceToken'].value
|
||||
if not serviceToken:
|
||||
raise Exception(await r.text())
|
||||
return serviceToken
|
||||
|
||||
async def mi_request(self, sid, url, data, headers, relogin=True):
|
||||
if self.token is None and self.token_store is not None:
|
||||
self.token = await self.token_store.load_token()
|
||||
if (self.token and sid in self.token) or await self.login(sid): # Ensure login
|
||||
cookies = {'userId': self.token['userId'], 'serviceToken': self.token[sid][1]}
|
||||
content = data(self.token, cookies) if callable(data) else data
|
||||
method = 'GET' if data is None else 'POST'
|
||||
_LOGGER.debug("%s %s", url, content)
|
||||
async with self.session.request(method, url, data=content, cookies=cookies, headers=headers) as r:
|
||||
status = r.status
|
||||
if status == 200:
|
||||
resp = await r.json(content_type=None)
|
||||
code = resp['code']
|
||||
if code == 0:
|
||||
return resp
|
||||
if 'auth' in resp.get('message', '').lower():
|
||||
status = 401
|
||||
else:
|
||||
resp = await r.text()
|
||||
if status == 401 and relogin:
|
||||
_LOGGER.warn("Auth error on request %s %s, relogin...", url, resp)
|
||||
self.token = None # Auth error, reset login
|
||||
return await self.mi_request(sid, url, data, headers, False)
|
||||
else:
|
||||
resp = "Login failed"
|
||||
raise Exception(f"Error {url}: {resp}")
|
||||
104
astrbot/core/platform/sources/mispeaker/miservice/miiocommand.py
Executable file
104
astrbot/core/platform/sources/mispeaker/miservice/miiocommand.py
Executable file
@@ -0,0 +1,104 @@
|
||||
|
||||
import json
|
||||
from .miioservice import MiIOService
|
||||
|
||||
|
||||
def twins_split(string, sep, default=None):
|
||||
pos = string.find(sep)
|
||||
return (string, default) if pos == -1 else (string[0:pos], string[pos+1:])
|
||||
|
||||
|
||||
def string_to_value(string):
|
||||
if string[0] in '"\'#':
|
||||
return string[1:-1] if string[-1] in '"\'#' else string[1:]
|
||||
elif string == 'null':
|
||||
return None
|
||||
elif string == 'false':
|
||||
return False
|
||||
elif string == 'true':
|
||||
return True
|
||||
elif string.isdigit():
|
||||
return int(string)
|
||||
try:
|
||||
return float(string)
|
||||
except:
|
||||
return string
|
||||
|
||||
def miio_command_help(did=None, prefix='?'):
|
||||
quote = '' if prefix == '?' else "'"
|
||||
return f'\
|
||||
Get Props: {prefix}<siid[-piid]>[,...]\n\
|
||||
{prefix}1,1-2,1-3,1-4,2-1,2-2,3\n\
|
||||
Set Props: {prefix}<siid[-piid]=[#]value>[,...]\n\
|
||||
{prefix}2=60,2-1=#60,2-2=false,2-3="null",3=test\n\
|
||||
Do Action: {prefix}<siid[-piid]> <arg1|[]> [...] \n\
|
||||
{prefix}2 []\n\
|
||||
{prefix}5 Hello\n\
|
||||
{prefix}5-4 Hello 1\n\n\
|
||||
Call MIoT: {prefix}<cmd=prop/get|/prop/set|action> <params>\n\
|
||||
{prefix}action {quote}{{"did":"{did or "267090026"}","siid":5,"aiid":1,"in":["Hello"]}}{quote}\n\n\
|
||||
Call MiIO: {prefix}/<uri> <data>\n\
|
||||
{prefix}/home/device_list {quote}{{"getVirtualModel":false,"getHuamiDevices":1}}{quote}\n\n\
|
||||
Devs List: {prefix}list [name=full|name_keyword] [getVirtualModel=false|true] [getHuamiDevices=0|1]\n\
|
||||
{prefix}list Light true 0\n\n\
|
||||
MIoT Spec: {prefix}spec [model_keyword|type_urn] [format=text|python|json]\n\
|
||||
{prefix}spec\n\
|
||||
{prefix}spec speaker\n\
|
||||
{prefix}spec xiaomi.wifispeaker.lx04\n\
|
||||
{prefix}spec urn:miot-spec-v2:device:speaker:0000A015:xiaomi-lx04:1\n\n\
|
||||
MIoT Decode: {prefix}decode <ssecurity> <nonce> <data> [gzip]\n\
|
||||
'
|
||||
|
||||
|
||||
async def miio_command(service: MiIOService, did, text, prefix='?'):
|
||||
cmd, arg = twins_split(text, ' ')
|
||||
|
||||
if cmd.startswith('/'):
|
||||
return await service.miio_request(cmd, arg)
|
||||
|
||||
if cmd.startswith('prop') or cmd == 'action':
|
||||
return await service.miot_request(cmd, json.loads(arg) if arg else None)
|
||||
|
||||
argv = arg.split(' ') if arg else []
|
||||
argc = len(argv)
|
||||
if cmd == 'list':
|
||||
return await service.device_list(argc > 0 and argv[0], argc > 1 and string_to_value(argv[1]), argc > 2 and argv[2])
|
||||
|
||||
if cmd == 'spec':
|
||||
return await service.miot_spec(argc > 0 and argv[0], argc > 1 and argv[1])
|
||||
|
||||
if cmd == 'decode':
|
||||
return MiIOService.miot_decode(argv[0], argv[1], argv[2], argc > 3 and argv[3] == 'gzip')
|
||||
|
||||
if not did or not cmd or cmd == '?' or cmd == '?' or cmd == 'help' or cmd == '-h' or cmd == '--help':
|
||||
return miio_command_help(did, prefix)
|
||||
|
||||
if not did.isdigit():
|
||||
devices = await service.device_list(did)
|
||||
if not devices:
|
||||
return "Device not found: " + did
|
||||
did = devices[0]['did']
|
||||
|
||||
props = []
|
||||
setp = True
|
||||
miot = True
|
||||
for item in cmd.split(','):
|
||||
key, value = twins_split(item, '=')
|
||||
siid, iid = twins_split(key, '-', '1')
|
||||
if siid.isdigit() and iid.isdigit():
|
||||
prop = [int(siid), int(iid)]
|
||||
else:
|
||||
prop = [key]
|
||||
miot = False
|
||||
if value is None:
|
||||
setp = False
|
||||
elif setp:
|
||||
prop.append(string_to_value(value))
|
||||
props.append(prop)
|
||||
|
||||
if miot and argc > 0:
|
||||
args = [] if arg == '[]' else [string_to_value(a) for a in argv]
|
||||
return await service.miot_action(did, props[0], args)
|
||||
|
||||
do_props = ((service.home_get_props, service.miot_get_props), (service.home_set_props, service.miot_set_props))[setp][miot]
|
||||
return await do_props(did, props if miot or setp else [p[0] for p in props])
|
||||
197
astrbot/core/platform/sources/mispeaker/miservice/miioservice.py
Executable file
197
astrbot/core/platform/sources/mispeaker/miservice/miioservice.py
Executable file
@@ -0,0 +1,197 @@
|
||||
import os
|
||||
import time
|
||||
import base64
|
||||
import hashlib
|
||||
import hmac
|
||||
import json
|
||||
|
||||
# REGIONS = ['cn', 'de', 'i2', 'ru', 'sg', 'us']
|
||||
|
||||
|
||||
class MiIOService:
|
||||
|
||||
def __init__(self, account=None, region=None):
|
||||
self.account = account
|
||||
self.server = 'https://' + ('' if region is None or region == 'cn' else region + '.') + 'api.io.mi.com/app'
|
||||
|
||||
async def miio_request(self, uri, data):
|
||||
def prepare_data(token, cookies):
|
||||
cookies['PassportDeviceId'] = token['deviceId']
|
||||
return MiIOService.sign_data(uri, data, token['xiaomiio'][0])
|
||||
headers = {'User-Agent': 'iOS-14.4-6.0.103-iPhone12,3--D7744744F7AF32F0544445285880DD63E47D9BE9-8816080-84A3F44E137B71AE-iPhone', 'x-xiaomi-protocal-flag-cli': 'PROTOCAL-HTTP2'}
|
||||
resp = await self.account.mi_request('xiaomiio', self.server + uri, prepare_data, headers)
|
||||
if 'result' not in resp:
|
||||
raise Exception(f"Error {uri}: {resp}")
|
||||
return resp['result']
|
||||
|
||||
async def home_request(self, did, method, params):
|
||||
return await self.miio_request('/home/rpc/' + did, {'id': 1, 'method': method, "accessKey": "IOS00026747c5acafc2", 'params': params})
|
||||
|
||||
async def home_get_props(self, did, props):
|
||||
return await self.home_request(did, 'get_prop', props)
|
||||
|
||||
async def home_set_props(self, did, props):
|
||||
return [await self.home_set_prop(did, i[0], i[1]) for i in props]
|
||||
|
||||
async def home_get_prop(self, did, prop):
|
||||
return (await self.home_get_props(did, [prop]))[0]
|
||||
|
||||
async def home_set_prop(self, did, prop, value):
|
||||
result = (await self.home_request(did, 'set_' + prop, value if isinstance(value, list) else [value]))[0]
|
||||
return 0 if result == 'ok' else result
|
||||
|
||||
async def miot_request(self, cmd, params):
|
||||
return await self.miio_request('/miotspec/' + cmd, {'params': params})
|
||||
|
||||
async def miot_get_props(self, did, iids):
|
||||
params = [{'did': did, 'siid': i[0], 'piid': i[1]} for i in iids]
|
||||
result = await self.miot_request('prop/get', params)
|
||||
return [it.get('value') if it.get('code') == 0 else None for it in result]
|
||||
|
||||
async def miot_set_props(self, did, props):
|
||||
params = [{'did': did, 'siid': i[0], 'piid': i[1], 'value': i[2]} for i in props]
|
||||
result = await self.miot_request('prop/set', params)
|
||||
return [it.get('code', -1) for it in result]
|
||||
|
||||
async def miot_get_prop(self, did, iid):
|
||||
return (await self.miot_get_props(did, [iid]))[0]
|
||||
|
||||
async def miot_set_prop(self, did, iid, value):
|
||||
return (await self.miot_set_props(did, [(iid[0], iid[1], value)]))[0]
|
||||
|
||||
async def miot_action(self, did, iid, args=[]):
|
||||
result = await self.miot_request('action', {'did': did, 'siid': iid[0], 'aiid': iid[1], 'in': args})
|
||||
return result.get('code', -1)
|
||||
|
||||
async def device_list(self, name=None, getVirtualModel=False, getHuamiDevices=0):
|
||||
result = await self.miio_request('/home/device_list', {'getVirtualModel': bool(getVirtualModel), 'getHuamiDevices': int(getHuamiDevices)})
|
||||
result = result['list']
|
||||
return result if name == 'full' else [{'name': i['name'], 'model': i['model'], 'did': i['did'], 'token': i['token']} for i in result if not name or name in i['name']]
|
||||
|
||||
async def miot_spec(self, type=None, format=None):
|
||||
if not type or not type.startswith('urn'):
|
||||
def get_spec(all):
|
||||
if not type:
|
||||
return all
|
||||
ret = {}
|
||||
for m, t in all.items():
|
||||
if type == m:
|
||||
return {m: t}
|
||||
elif type in m:
|
||||
ret[m] = t
|
||||
return ret
|
||||
import tempfile
|
||||
path = os.path.join(tempfile.gettempdir(), 'miservice_miot_specs.json')
|
||||
try:
|
||||
with open(path) as f:
|
||||
result = get_spec(json.load(f))
|
||||
except:
|
||||
result = None
|
||||
if not result:
|
||||
async with self.account.session.get('http://miot-spec.org/miot-spec-v2/instances?status=all') as r:
|
||||
all = {i['model']: i['type'] for i in (await r.json())['instances']}
|
||||
with open(path, 'w') as f:
|
||||
json.dump(all, f)
|
||||
result = get_spec(all)
|
||||
if len(result) != 1:
|
||||
return result
|
||||
type = list(result.values())[0]
|
||||
|
||||
url = 'http://miot-spec.org/miot-spec-v2/instance?type=' + type
|
||||
async with self.account.session.get(url) as r:
|
||||
result = await r.json()
|
||||
|
||||
def parse_desc(node):
|
||||
desc = node['description']
|
||||
# pos = desc.find(' ')
|
||||
# if pos != -1:
|
||||
# return (desc[:pos], ' # ' + desc[pos + 2:])
|
||||
name = ''
|
||||
for i in range(len(desc)):
|
||||
d = desc[i]
|
||||
if d in '-—{「[【((<《':
|
||||
return (name, ' # ' + desc[i:])
|
||||
name += '_' if d == ' ' else d
|
||||
return (name, '')
|
||||
|
||||
def make_line(siid, iid, desc, comment, readable=False):
|
||||
value = f"({siid}, {iid})" if format == 'python' else iid
|
||||
return f" {'' if readable else '_'}{desc} = {value}{comment}\n"
|
||||
|
||||
if format != 'json':
|
||||
STR_HEAD, STR_SRV, STR_VALUE = ('from enum import Enum\n\n', '\nclass {}(tuple, Enum):\n', '\nclass {}(int, Enum):\n') if format == 'python' else ('', '{} = {}\n', '{}\n')
|
||||
text = '# Generated by https://github.com/Yonsm/MiService\n# ' + url + '\n\n' + STR_HEAD
|
||||
svcs = []
|
||||
vals = []
|
||||
|
||||
for s in result['services']:
|
||||
siid = s['iid']
|
||||
svc = s['description'].replace(' ', '_')
|
||||
svcs.append(svc)
|
||||
text += STR_SRV.format(svc, siid)
|
||||
for p in s.get('properties', []):
|
||||
name, comment = parse_desc(p)
|
||||
access = p['access']
|
||||
|
||||
comment += ''.join([' # ' + k for k, v in [(p['format'], 'string'), (''.join([a[0] for a in access]), 'r')] if k and k != v])
|
||||
text += make_line(siid, p['iid'], name, comment, 'read' in access)
|
||||
if 'value-range' in p:
|
||||
valuer = p['value-range']
|
||||
length = min(3, len(valuer))
|
||||
values = {['MIN', 'MAX', 'STEP'][i]: valuer[i] for i in range(length) if i != 2 or valuer[i] != 1}
|
||||
elif 'value-list' in p:
|
||||
values = {i['description'].replace(' ', '_') if i['description'] else str(i['value']): i['value'] for i in p['value-list']}
|
||||
else:
|
||||
continue
|
||||
vals.append((svc + '_' + name, values))
|
||||
if 'actions' in s:
|
||||
text += '\n'
|
||||
for a in s['actions']:
|
||||
name, comment = parse_desc(a)
|
||||
comment += ''.join([f" # {io}={a[io]}" for io in ['in', 'out'] if a[io]])
|
||||
text += make_line(siid, a['iid'], name, comment)
|
||||
text += '\n'
|
||||
for name, values in vals:
|
||||
text += STR_VALUE.format(name)
|
||||
for k, v in values.items():
|
||||
text += f" {'_' + k if k.isdigit() else k} = {v}\n"
|
||||
text += '\n'
|
||||
if format == 'python':
|
||||
text += '\nALL_SVCS = (' + ', '.join(svcs) + ')\n'
|
||||
result = text
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def miot_decode(ssecurity, nonce, data, gzip=False):
|
||||
from Crypto.Cipher import ARC4
|
||||
r = ARC4.new(base64.b64decode(MiIOService.sign_nonce(ssecurity, nonce)))
|
||||
r.encrypt(bytes(1024))
|
||||
decrypted = r.encrypt(base64.b64decode(data))
|
||||
if gzip:
|
||||
try:
|
||||
from io import BytesIO
|
||||
from gzip import GzipFile
|
||||
compressed = BytesIO()
|
||||
compressed.write(decrypted)
|
||||
compressed.seek(0)
|
||||
decrypted = GzipFile(fileobj=compressed, mode='rb').read()
|
||||
except:
|
||||
pass
|
||||
return json.loads(decrypted.decode())
|
||||
|
||||
@staticmethod
|
||||
def sign_nonce(ssecurity, nonce):
|
||||
m = hashlib.sha256()
|
||||
m.update(base64.b64decode(ssecurity))
|
||||
m.update(base64.b64decode(nonce))
|
||||
return base64.b64encode(m.digest()).decode()
|
||||
|
||||
@staticmethod
|
||||
def sign_data(uri, data, ssecurity):
|
||||
if not isinstance(data, str):
|
||||
data = json.dumps(data)
|
||||
nonce = base64.b64encode(os.urandom(8) + int(time.time() / 60).to_bytes(4, 'big')).decode()
|
||||
snonce = MiIOService.sign_nonce(ssecurity, nonce)
|
||||
msg = '&'.join([uri, snonce, nonce, 'data=' + data])
|
||||
sign = hmac.new(key=base64.b64decode(snonce), msg=msg.encode(), digestmod=hashlib.sha256).digest()
|
||||
return {'_nonce': nonce, 'data': data, 'signature': base64.b64encode(sign).decode()}
|
||||
@@ -0,0 +1,50 @@
|
||||
import json
|
||||
from .miaccount import MiAccount, get_random
|
||||
|
||||
import logging
|
||||
_LOGGER = logging.getLogger(__package__)
|
||||
|
||||
|
||||
class MiNAService:
|
||||
|
||||
def __init__(self, account: MiAccount):
|
||||
self.account = account
|
||||
|
||||
async def mina_request(self, uri, data=None):
|
||||
requestId = 'app_ios_' + get_random(30)
|
||||
if data is not None:
|
||||
data['requestId'] = requestId
|
||||
else:
|
||||
uri += '&requestId=' + requestId
|
||||
headers = {'User-Agent': 'MiHome/6.0.103 (com.xiaomi.mihome; build:6.0.103.1; iOS 14.4.0) Alamofire/6.0.103 MICO/iOSApp/appStore/6.0.103'}
|
||||
return await self.account.mi_request('micoapi', 'https://api2.mina.mi.com' + uri, data, headers)
|
||||
|
||||
async def device_list(self, master=0):
|
||||
result = await self.mina_request('/admin/v2/device_list?master=' + str(master))
|
||||
return result.get('data') if result else None
|
||||
|
||||
async def ubus_request(self, deviceId, method, path, message):
|
||||
message = json.dumps(message)
|
||||
result = await self.mina_request('/remote/ubus', {'deviceId': deviceId, 'message': message, 'method': method, 'path': path})
|
||||
return result and result.get('code') == 0
|
||||
|
||||
async def text_to_speech(self, deviceId, text):
|
||||
return await self.ubus_request(deviceId, 'text_to_speech', 'mibrain', {'text': text})
|
||||
|
||||
async def player_set_volume(self, deviceId, volume):
|
||||
return await self.ubus_request(deviceId, 'player_set_volume', 'mediaplayer', {'volume': volume, 'media': 'app_ios'})
|
||||
|
||||
async def send_message(self, devices, devno, message, volume=None): # -1/0/1...
|
||||
result = False
|
||||
for i in range(0, len(devices)):
|
||||
if devno == -1 or devno != i + 1 or devices[i]['capabilities'].get('yunduantts'):
|
||||
_LOGGER.debug("Send to devno=%d index=%d: %s", devno, i, message or volume)
|
||||
deviceId = devices[i]['deviceID']
|
||||
result = True if volume is None else await self.player_set_volume(deviceId, volume)
|
||||
if result and message:
|
||||
result = await self.text_to_speech(deviceId, message)
|
||||
if not result:
|
||||
_LOGGER.error("Send failed: %s", message or volume)
|
||||
if devno != -1 or not result:
|
||||
break
|
||||
return result
|
||||
63
astrbot/core/platform/sources/mispeaker/mispeaker_adapter.py
Normal file
63
astrbot/core/platform/sources/mispeaker/mispeaker_adapter.py
Normal file
@@ -0,0 +1,63 @@
|
||||
import logging
|
||||
import time
|
||||
import asyncio
|
||||
import os
|
||||
|
||||
from astrbot.api.platform import Platform, AstrBotMessage, MessageMember, MessageType, PlatformMetadata
|
||||
from astrbot.api.event import MessageChain
|
||||
from typing import Union, List
|
||||
from astrbot.api.message_components import Image, Plain, At
|
||||
from astrbot.core.platform.astr_message_event import MessageSesion
|
||||
from ...register import register_platform_adapter
|
||||
from astrbot.core.message.components import BaseMessageComponent
|
||||
from .client import SimpleMiSpeakerClient
|
||||
from .mispeaker_event import MiSpeakerPlatformEvent
|
||||
from astrbot.core import logger
|
||||
|
||||
|
||||
@register_platform_adapter("mispeaker", "小爱音箱")
|
||||
class MiSpeakerPlatformAdapter(Platform):
|
||||
|
||||
def __init__(self, platform_config: dict, platform_settings: dict, event_queue: asyncio.Queue) -> None:
|
||||
super().__init__(event_queue)
|
||||
|
||||
self.config = platform_config
|
||||
|
||||
|
||||
async def send_by_session(self, session: MessageSesion, message_chain: MessageChain):
|
||||
pass
|
||||
|
||||
def meta(self) -> PlatformMetadata:
|
||||
return PlatformMetadata(
|
||||
"mispeaker",
|
||||
"小爱音箱",
|
||||
)
|
||||
|
||||
async def handle_msg(self, message: AstrBotMessage):
|
||||
message_event = MiSpeakerPlatformEvent(
|
||||
message_str=message.message_str,
|
||||
message_obj=message,
|
||||
platform_meta=self.meta(),
|
||||
session_id=message.session_id,
|
||||
client=self.client
|
||||
)
|
||||
|
||||
self.commit_event(message_event)
|
||||
|
||||
def run(self):
|
||||
self.client = SimpleMiSpeakerClient(
|
||||
self.config
|
||||
)
|
||||
|
||||
async def on_event_received(abm: AstrBotMessage):
|
||||
logger.info(f"on_event_received: {abm}")
|
||||
|
||||
await self.handle_msg(abm)
|
||||
|
||||
self.client.on_event_received = on_event_received
|
||||
|
||||
return self._run()
|
||||
|
||||
async def _run(self):
|
||||
await self.client.initialize()
|
||||
await self.client.start_pooling()
|
||||
30
astrbot/core/platform/sources/mispeaker/mispeaker_event.py
Normal file
30
astrbot/core/platform/sources/mispeaker/mispeaker_event.py
Normal file
@@ -0,0 +1,30 @@
|
||||
import random
|
||||
import asyncio
|
||||
from astrbot.api import logger
|
||||
from astrbot.api.event import AstrMessageEvent, MessageChain
|
||||
from astrbot.api.platform import AstrBotMessage, PlatformMetadata
|
||||
from astrbot.api.message_components import Plain, Image
|
||||
from .client import SimpleMiSpeakerClient
|
||||
|
||||
class MiSpeakerPlatformEvent(AstrMessageEvent):
|
||||
def __init__(
|
||||
self,
|
||||
message_str: str,
|
||||
message_obj: AstrBotMessage,
|
||||
platform_meta: PlatformMetadata,
|
||||
session_id: str,
|
||||
client: SimpleMiSpeakerClient
|
||||
):
|
||||
super().__init__(message_str, message_obj, platform_meta, session_id)
|
||||
self.client = client
|
||||
|
||||
@staticmethod
|
||||
async def send_with_client(message: MessageChain, user_name: str):
|
||||
pass
|
||||
|
||||
async def send(self, message: MessageChain):
|
||||
for comp in message.chain:
|
||||
if isinstance(comp, Plain):
|
||||
await self.client.send(comp.text)
|
||||
|
||||
await super().send(message)
|
||||
@@ -28,6 +28,8 @@ def register_provider_adapter(
|
||||
default_config_tmpl['type'] = provider_type_name
|
||||
if 'enable' not in default_config_tmpl:
|
||||
default_config_tmpl['enable'] = False
|
||||
if 'id' not in default_config_tmpl:
|
||||
default_config_tmpl['id'] = provider_type_name
|
||||
|
||||
pm = ProviderMetaData(
|
||||
type=provider_type_name,
|
||||
|
||||
@@ -18,7 +18,7 @@ class SimpleGoogleGenAIClient():
|
||||
self.api_base = api_base[:-1]
|
||||
else:
|
||||
self.api_base = api_base
|
||||
self.client = aiohttp.ClientSession()
|
||||
self.client = aiohttp.ClientSession(trust_env=True)
|
||||
|
||||
async def models_list(self) -> List[str]:
|
||||
request_url = f"{self.api_base}/v1beta/models?key={self.api_key}"
|
||||
@@ -224,15 +224,24 @@ class ProviderGoogleGenAI(Provider):
|
||||
|
||||
try:
|
||||
llm_response = await self._query(payloads, func_tool)
|
||||
await self.save_history(contexts, new_record, session_id, llm_response)
|
||||
return llm_response
|
||||
except Exception as e:
|
||||
if "maximum context length" in str(e):
|
||||
logger.warning(f"请求失败:{e}。上下文长度超过限制。尝试弹出最早的记录然后重试。")
|
||||
self.pop_record(session_id)
|
||||
logger.warning(traceback.format_exc())
|
||||
|
||||
await self.save_history(contexts, new_record, session_id, llm_response)
|
||||
|
||||
return llm_response
|
||||
retry_cnt = 10
|
||||
while retry_cnt > 0:
|
||||
logger.warning(f"请求失败:{e}。上下文长度超过限制。尝试弹出最早的记录然后重试。")
|
||||
try:
|
||||
self.pop_record(session_id)
|
||||
llm_response = await self._query(payloads, func_tool)
|
||||
break
|
||||
except Exception as e:
|
||||
if "maximum context length" in str(e):
|
||||
retry_cnt -= 1
|
||||
else:
|
||||
raise e
|
||||
else:
|
||||
raise e
|
||||
|
||||
async def save_history(self, contexts: List, new_record: dict, session_id: str, llm_response: LLMResponse):
|
||||
if llm_response.role == "assistant" and session_id:
|
||||
|
||||
@@ -164,15 +164,25 @@ class ProviderOpenAIOfficial(Provider):
|
||||
|
||||
try:
|
||||
llm_response = await self._query(payloads, func_tool)
|
||||
await self.save_history(contexts, new_record, session_id, llm_response)
|
||||
return llm_response
|
||||
except Exception as e:
|
||||
if "maximum context length" in str(e):
|
||||
logger.warning(f"请求失败:{e}。上下文长度超过限制。尝试弹出最早的记录然后重试。")
|
||||
self.pop_record(session_id)
|
||||
logger.warning(traceback.format_exc())
|
||||
|
||||
await self.save_history(contexts, new_record, session_id, llm_response)
|
||||
retry_cnt = 10
|
||||
while retry_cnt > 0:
|
||||
logger.warning(f"请求失败:{e}。上下文长度超过限制。尝试弹出最早的记录然后重试。")
|
||||
try:
|
||||
self.pop_record(session_id)
|
||||
llm_response = await self._query(payloads, func_tool)
|
||||
break
|
||||
except Exception as e:
|
||||
if "maximum context length" in str(e):
|
||||
retry_cnt -= 1
|
||||
else:
|
||||
raise e
|
||||
else:
|
||||
raise e
|
||||
|
||||
return llm_response
|
||||
|
||||
async def save_history(self, contexts: List, new_record: dict, session_id: str, llm_response: LLMResponse):
|
||||
if llm_response.role == "assistant" and session_id:
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
import uuid
|
||||
import os
|
||||
import io
|
||||
from openai import AsyncOpenAI, NOT_GIVEN
|
||||
from ..provider import STTProvider
|
||||
from ..entites import ProviderType
|
||||
from astrbot.core.utils.io import download_file
|
||||
from ..register import register_provider_adapter
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.utils.tencent_record_helper import tencent_silk_to_wav
|
||||
|
||||
@register_provider_adapter("openai_whisper_api", "OpenAI Whisper API", provider_type=ProviderType.SPEECH_TO_TEXT)
|
||||
class ProviderOpenAIWhisperAPI(STTProvider):
|
||||
@@ -33,34 +33,6 @@ class ProviderOpenAIWhisperAPI(STTProvider):
|
||||
output_path = ff.convert(path, os.path.join('data/temp', filename))
|
||||
return output_path
|
||||
|
||||
async def _pcm_to_wav(self, input_io: io.BytesIO, output_path: str) -> str:
|
||||
import wave
|
||||
|
||||
with wave.open(output_path, 'wb') as wav:
|
||||
wav.setnchannels(1)
|
||||
wav.setsampwidth(2)
|
||||
wav.setframerate(24000)
|
||||
wav.writeframes(input_io.read())
|
||||
|
||||
return output_path
|
||||
|
||||
async def _convert_silk(self, path: str) -> str:
|
||||
import pysilk
|
||||
filename = str(uuid.uuid4()) + '.wav'
|
||||
output_path = os.path.join('data/temp', filename)
|
||||
with open(path, "rb") as f:
|
||||
input_data = f.read()
|
||||
if input_data.startswith(b'\x02'):
|
||||
# tencent 我爱你
|
||||
input_data = input_data[1:]
|
||||
input_io = io.BytesIO(input_data)
|
||||
output_io = io.BytesIO()
|
||||
pysilk.decode(input_io, output_io, 24000)
|
||||
output_io.seek(0)
|
||||
await self._pcm_to_wav(output_io, output_path)
|
||||
|
||||
return output_path
|
||||
|
||||
async def _is_silk_file(self, file_path):
|
||||
silk_header = b"SILK"
|
||||
with open(file_path, "rb") as f:
|
||||
@@ -91,8 +63,9 @@ class ProviderOpenAIWhisperAPI(STTProvider):
|
||||
is_silk = await self._is_silk_file(audio_url)
|
||||
if is_silk:
|
||||
logger.info("Converting silk file to wav ...")
|
||||
audio_url = await self._convert_silk(audio_url)
|
||||
|
||||
output_path = os.path.join('data/temp', str(uuid.uuid4()) + '.wav')
|
||||
await tencent_silk_to_wav(audio_url, output_path)
|
||||
audio_url = output_path
|
||||
|
||||
result = await self.client.audio.transcriptions.create(
|
||||
model=self.model_name,
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import uuid
|
||||
import os
|
||||
import io
|
||||
import asyncio
|
||||
import whisper
|
||||
from ..provider import STTProvider
|
||||
@@ -8,7 +7,7 @@ from ..entites import ProviderType
|
||||
from astrbot.core.utils.io import download_file
|
||||
from ..register import register_provider_adapter
|
||||
from astrbot.core import logger
|
||||
|
||||
from astrbot.core.utils.tencent_record_helper import tencent_silk_to_wav
|
||||
|
||||
@register_provider_adapter("openai_whisper_selfhost", "OpenAI Whisper 模型部署", provider_type=ProviderType.SPEECH_TO_TEXT)
|
||||
class ProviderOpenAIWhisperSelfHost(STTProvider):
|
||||
@@ -34,34 +33,6 @@ class ProviderOpenAIWhisperSelfHost(STTProvider):
|
||||
output_path = ff.convert(path, os.path.join('data/temp', filename))
|
||||
return output_path
|
||||
|
||||
async def _pcm_to_wav(self, input_io: io.BytesIO, output_path: str) -> str:
|
||||
import wave
|
||||
|
||||
with wave.open(output_path, 'wb') as wav:
|
||||
wav.setnchannels(1)
|
||||
wav.setsampwidth(2)
|
||||
wav.setframerate(24000)
|
||||
wav.writeframes(input_io.read())
|
||||
|
||||
return output_path
|
||||
|
||||
async def _convert_silk(self, path: str) -> str:
|
||||
import pysilk
|
||||
filename = str(uuid.uuid4()) + '.wav'
|
||||
output_path = os.path.join('data/temp', filename)
|
||||
with open(path, "rb") as f:
|
||||
input_data = f.read()
|
||||
if input_data.startswith(b'\x02'):
|
||||
# tencent 我爱你
|
||||
input_data = input_data[1:]
|
||||
input_io = io.BytesIO(input_data)
|
||||
output_io = io.BytesIO()
|
||||
pysilk.decode(input_io, output_io, 24000)
|
||||
output_io.seek(0)
|
||||
await self._pcm_to_wav(output_io, output_path)
|
||||
|
||||
return output_path
|
||||
|
||||
async def _is_silk_file(self, file_path):
|
||||
silk_header = b"SILK"
|
||||
with open(file_path, "rb") as f:
|
||||
@@ -93,7 +64,9 @@ class ProviderOpenAIWhisperSelfHost(STTProvider):
|
||||
is_silk = await self._is_silk_file(audio_url)
|
||||
if is_silk:
|
||||
logger.info("Converting silk file to wav ...")
|
||||
audio_url = await self._convert_silk(audio_url)
|
||||
|
||||
output_path = os.path.join('data/temp', str(uuid.uuid4()) + '.wav')
|
||||
await tencent_silk_to_wav(audio_url, output_path)
|
||||
audio_url = output_path
|
||||
|
||||
result = await loop.run_in_executor(None, self.model.transcribe, audio_url)
|
||||
return result['text']
|
||||
@@ -60,15 +60,23 @@ class ProviderZhipu(ProviderOpenAIOfficial):
|
||||
"messages": context_query,
|
||||
**model_cfgs
|
||||
}
|
||||
llm_response = None
|
||||
try:
|
||||
llm_response = await self._query(payloads, func_tool)
|
||||
await self.save_history(contexts, new_record, session_id, llm_response)
|
||||
return llm_response
|
||||
except Exception as e:
|
||||
if "maximum context length" in str(e):
|
||||
logger.warning(f"请求失败:{e}。上下文长度超过限制。尝试弹出最早的记录然后重试。")
|
||||
self.pop_record(session_id)
|
||||
logger.warning(traceback.format_exc())
|
||||
|
||||
await self.save_history(contexts, new_record, session_id, llm_response)
|
||||
|
||||
return llm_response
|
||||
retry_cnt = 10
|
||||
while retry_cnt > 0:
|
||||
logger.warning(f"请求失败:{e}。上下文长度超过限制。尝试弹出最早的记录然后重试。")
|
||||
try:
|
||||
self.pop_record(session_id)
|
||||
llm_response = await self._query(payloads, func_tool)
|
||||
break
|
||||
except Exception as e:
|
||||
if "maximum context length" in str(e):
|
||||
retry_cnt -= 1
|
||||
else:
|
||||
raise e
|
||||
else:
|
||||
raise e
|
||||
@@ -20,6 +20,6 @@ class PermissionTypeFilter(HandlerFilter):
|
||||
if self.permission_type == PermissionType.ADMIN:
|
||||
if not event.is_admin():
|
||||
event.stop_event()
|
||||
raise ValueError("您没有权限执行此操作。")
|
||||
raise ValueError(f"您 (ID: {event.get_sender_id()}) 没有权限执行此操作。")
|
||||
|
||||
return True
|
||||
|
||||
@@ -8,12 +8,14 @@ class PlatformAdapterType(enum.Flag):
|
||||
AIOCQHTTP = enum.auto()
|
||||
QQOFFICIAL = enum.auto()
|
||||
VCHAT = enum.auto()
|
||||
ALL = AIOCQHTTP | QQOFFICIAL | VCHAT
|
||||
GEWECHAT = enum.auto()
|
||||
ALL = AIOCQHTTP | QQOFFICIAL | VCHAT | GEWECHAT
|
||||
|
||||
ADAPTER_NAME_2_TYPE = {
|
||||
"aiocqhttp": PlatformAdapterType.AIOCQHTTP,
|
||||
"qq_official": PlatformAdapterType.QQOFFICIAL,
|
||||
"vchat": PlatformAdapterType.VCHAT
|
||||
"vchat": PlatformAdapterType.VCHAT,
|
||||
"gewechat": PlatformAdapterType.GEWECHAT
|
||||
}
|
||||
|
||||
class PlatformAdapterTypeFilter(HandlerFilter):
|
||||
|
||||
@@ -7,7 +7,6 @@ import yaml
|
||||
import logging
|
||||
from types import ModuleType
|
||||
from typing import List
|
||||
from pip import main as pip_main
|
||||
from astrbot.core.config.astrbot_config import AstrBotConfig
|
||||
from astrbot.core import logger, sp, pip_installer
|
||||
from .context import Context
|
||||
|
||||
@@ -11,7 +11,7 @@ class AstrBotUpdator(RepoZipUpdator):
|
||||
def __init__(self, repo_mirror: str = "") -> None:
|
||||
super().__init__(repo_mirror)
|
||||
self.MAIN_PATH = os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../"))
|
||||
self.ASTRBOT_RELEASE_API = "https://api.github.com/repos/Soulter/AstrBot/releases"
|
||||
self.ASTRBOT_RELEASE_API = "https://api.soulter.top/releases"
|
||||
|
||||
def terminate_child_processes(self):
|
||||
try:
|
||||
|
||||
@@ -70,7 +70,7 @@ async def download_image_by_url(url: str, post: bool = False, post_data: dict =
|
||||
下载图片, 返回 path
|
||||
'''
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with aiohttp.ClientSession(trust_env=True) as session:
|
||||
if post:
|
||||
async with session.post(url, json=post_data) as resp:
|
||||
if not path:
|
||||
@@ -91,7 +91,7 @@ async def download_image_by_url(url: str, post: bool = False, post_data: dict =
|
||||
# 关闭SSL验证
|
||||
ssl_context = ssl.create_default_context()
|
||||
ssl_context.set_ciphers('DEFAULT')
|
||||
async with aiohttp.ClientSession(trust_env=False) as session:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
if post:
|
||||
async with session.get(url, ssl=ssl_context) as resp:
|
||||
return save_temp_img(await resp.read())
|
||||
@@ -101,34 +101,57 @@ async def download_image_by_url(url: str, post: bool = False, post_data: dict =
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
async def download_file(url: str, path: str):
|
||||
async def download_file(url: str, path: str, show_progress: bool = False):
|
||||
'''
|
||||
从指定 url 下载文件到指定路径 path
|
||||
'''
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(url, timeout=20) as resp:
|
||||
async with aiohttp.ClientSession(trust_env=True) as session:
|
||||
async with session.get(url, timeout=120) as resp:
|
||||
if resp.status != 200:
|
||||
raise Exception(f"下载文件失败: {resp.status}")
|
||||
total_size = int(resp.headers.get('content-length', 0))
|
||||
downloaded_size = 0
|
||||
start_time = time.time()
|
||||
if show_progress:
|
||||
print(f"文件大小: {total_size / 1024:.2f} KB | 文件地址: {url}")
|
||||
with open(path, 'wb') as f:
|
||||
while True:
|
||||
chunk = await resp.content.read(8192)
|
||||
if not chunk:
|
||||
break
|
||||
f.write(chunk)
|
||||
downloaded_size += len(chunk)
|
||||
if show_progress:
|
||||
elapsed_time = time.time() - start_time
|
||||
speed = downloaded_size / 1024 / elapsed_time # KB/s
|
||||
print(f"\r下载进度: {downloaded_size / total_size:.2%} 速度: {speed:.2f} KB/s", end='')
|
||||
except aiohttp.client.ClientConnectorSSLError:
|
||||
# 关闭SSL验证
|
||||
ssl_context = ssl.create_default_context()
|
||||
ssl_context.set_ciphers('DEFAULT')
|
||||
async with aiohttp.ClientSession(trust_env=False) as session:
|
||||
async with session.get(url, ssl=ssl_context, timeout=20) as resp:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get(url, ssl=ssl_context, timeout=120) as resp:
|
||||
total_size = int(resp.headers.get('content-length', 0))
|
||||
downloaded_size = 0
|
||||
start_time = time.time()
|
||||
if show_progress:
|
||||
print(f"文件大小: {total_size / 1024:.2f} KB | 文件地址: {url}")
|
||||
with open(path, 'wb') as f:
|
||||
while True:
|
||||
chunk = await resp.content.read(8192)
|
||||
if not chunk:
|
||||
break
|
||||
f.write(chunk)
|
||||
|
||||
downloaded_size += len(chunk)
|
||||
if show_progress:
|
||||
elapsed_time = time.time() - start_time
|
||||
speed = downloaded_size / 1024 / elapsed_time # KB/s
|
||||
print(f"\r下载进度: {downloaded_size / total_size:.2%} 速度: {speed:.2f} KB/s", end='')
|
||||
if show_progress:
|
||||
print()
|
||||
|
||||
|
||||
def file_to_base64(file_path: str) -> str:
|
||||
with open(file_path, "rb") as f:
|
||||
data_bytes = f.read()
|
||||
@@ -147,9 +170,22 @@ def get_local_ip_addresses():
|
||||
s.close()
|
||||
return ip
|
||||
|
||||
async def get_dashboard_version():
|
||||
if os.path.exists("data/dist"):
|
||||
if os.path.exists("data/dist/assets/version"):
|
||||
with open("data/dist/assets/version", "r") as f:
|
||||
v = f.read().strip()
|
||||
return v
|
||||
return None
|
||||
|
||||
async def download_dashboard():
|
||||
'''下载管理面板文件'''
|
||||
dashboard_release_url = "https://astrbot-registry.lwl.lol/download/astrbot-dashboard/latest/dist.zip"
|
||||
await download_file(dashboard_release_url, "data/dashboard.zip")
|
||||
dashboard_release_url = "https://astrbot-registry.soulter.top/download/astrbot-dashboard/latest/dist.zip"
|
||||
try:
|
||||
await download_file(dashboard_release_url, "data/dashboard.zip", show_progress=True)
|
||||
except BaseException as _:
|
||||
dashboard_release_url = "https://github.com/Soulter/AstrBot/releases/latest/download/dist.zip"
|
||||
await download_file(dashboard_release_url, "data/dashboard.zip", show_progress=True)
|
||||
print("解压管理面板文件中...")
|
||||
with zipfile.ZipFile("data/dashboard.zip", "r") as z:
|
||||
z.extractall("data")
|
||||
@@ -30,7 +30,7 @@ class Metric():
|
||||
pass
|
||||
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with aiohttp.ClientSession(trust_env=True) as session:
|
||||
async with session.post(base_url, json=payload, timeout=3) as response:
|
||||
if response.status != 200:
|
||||
pass
|
||||
|
||||
@@ -83,7 +83,7 @@ class LocalRenderStrategy(RenderStrategy):
|
||||
try:
|
||||
image_url = re.findall(IMAGE_REGEX, line)[0]
|
||||
print(image_url)
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with aiohttp.ClientSession(trust_env=True) as session:
|
||||
async with session.get(image_url) as resp:
|
||||
image_res = Image.open(BytesIO(await resp.read()))
|
||||
images[i] = image_res
|
||||
|
||||
@@ -33,7 +33,7 @@ class NetworkRenderStrategy(RenderStrategy):
|
||||
}
|
||||
}
|
||||
if return_url:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with aiohttp.ClientSession(trust_env=True) as session:
|
||||
async with session.post(f"{self.BASE_RENDER_URL}/generate", json=post_data) as resp:
|
||||
ret = await resp.json()
|
||||
return f"{self.BASE_RENDER_URL}/{ret['data']['id']}"
|
||||
|
||||
37
astrbot/core/utils/tencent_record_helper.py
Normal file
37
astrbot/core/utils/tencent_record_helper.py
Normal file
@@ -0,0 +1,37 @@
|
||||
import wave
|
||||
from io import BytesIO
|
||||
|
||||
async def tencent_silk_to_wav(silk_path: str, output_path: str) -> str:
|
||||
import pysilk
|
||||
|
||||
with open(silk_path, "rb") as f:
|
||||
input_data = f.read()
|
||||
if input_data.startswith(b'\x02'):
|
||||
input_data = input_data[1:]
|
||||
input_io = BytesIO(input_data)
|
||||
output_io = BytesIO()
|
||||
pysilk.decode(input_io, output_io, 24000)
|
||||
output_io.seek(0)
|
||||
with wave.open(output_path, 'wb') as wav:
|
||||
wav.setnchannels(1)
|
||||
wav.setsampwidth(2)
|
||||
wav.setframerate(24000)
|
||||
wav.writeframes(output_io.read())
|
||||
|
||||
return output_path
|
||||
|
||||
async def wav_to_tencent_silk(wav_path: str) -> BytesIO:
|
||||
import pysilk
|
||||
|
||||
with wave.open(wav_path, 'rb') as wav:
|
||||
wav_data = wav.readframes(wav.getnframes())
|
||||
wav_data = BytesIO(wav_data)
|
||||
output_io = BytesIO()
|
||||
pysilk.encode(wav_data, output_io, 24000)
|
||||
output_io.seek(0)
|
||||
|
||||
# 在首字节添加 \x02
|
||||
silk_data = output_io.read()
|
||||
silk_data_with_prefix = b'\x02' + silk_data
|
||||
|
||||
return BytesIO(silk_data_with_prefix)
|
||||
@@ -29,7 +29,7 @@ class RepoZipUpdator():
|
||||
返回一个列表,每个元素是一个字典,包含版本号、发布时间、更新内容、commit hash等信息。
|
||||
'''
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with aiohttp.ClientSession(trust_env=True) as session:
|
||||
async with session.get(url) as response:
|
||||
result = await response.json()
|
||||
if not result:
|
||||
@@ -111,7 +111,7 @@ class RepoZipUpdator():
|
||||
releases = await self.fetch_release_info(url=release_url)
|
||||
if not releases:
|
||||
# download from the default branch directly.
|
||||
logger.warning(f"未在仓库 {author}/{repo} 中找到任何发布版本,正在从默认分支下载。")
|
||||
logger.info(f"未在仓库 {author}/{repo} 中找到任何发布版本,正在从默认分支下载。")
|
||||
release_url = f"https://github.com/{author}/{repo}/archive/refs/heads/master.zip"
|
||||
else:
|
||||
release_url = releases[0]['zipball_url']
|
||||
|
||||
@@ -182,8 +182,7 @@ class ChatRoute(Route):
|
||||
|
||||
await asyncio.sleep(0.5)
|
||||
except BaseException as e:
|
||||
logger.error(e)
|
||||
logger.error(f"与用户 {username} 断开聊天长连接。")
|
||||
logger.debug(f"用户 {username} 断开聊天长连接: {str(e)}。")
|
||||
self.curr_chat_sse.pop(username)
|
||||
return
|
||||
|
||||
|
||||
@@ -27,7 +27,7 @@ class PluginRoute(Route):
|
||||
async def get_online_plugins(self):
|
||||
url = "https://soulter.github.io/AstrBot_Plugins_Collection/plugins.json"
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with aiohttp.ClientSession(trust_env=True) as session:
|
||||
async with session.get(url) as response:
|
||||
result = await response.json()
|
||||
return Response().ok(result).__dict__
|
||||
|
||||
@@ -15,7 +15,6 @@ class StatRoute(Route):
|
||||
self.routes = {
|
||||
'/stat/get': ('GET', self.get_stat),
|
||||
'/stat/version': ('GET', self.get_version),
|
||||
'/stat/dashboard-version': ('GET', self.get_dashboard_version),
|
||||
'/stat/start-time': ('GET', self.get_start_time),
|
||||
'/stat/restart-core': ('GET', self.restart_core)
|
||||
}
|
||||
@@ -37,16 +36,6 @@ class StatRoute(Route):
|
||||
"version": VERSION
|
||||
}).__dict__
|
||||
|
||||
async def get_dashboard_version(self):
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with session.get('https://api.github.com/repos/Soulter/Astrbot-dashboard/actions/artifacts') as resp:
|
||||
data = await resp.json()
|
||||
return Response().ok({
|
||||
"data": data,
|
||||
"mark": "unimplemented feature"
|
||||
}).__dict__
|
||||
|
||||
|
||||
async def get_start_time(self):
|
||||
return Response().ok({
|
||||
"start_time": self.core_lifecycle.start_time
|
||||
|
||||
@@ -1,31 +1,47 @@
|
||||
import threading
|
||||
import traceback
|
||||
from .route import Route, Response, RouteContext
|
||||
from quart import request
|
||||
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
|
||||
from astrbot.core.updator import AstrBotUpdator
|
||||
from astrbot.core import logger, pip_installer
|
||||
from astrbot.core.utils.io import download_dashboard, get_dashboard_version
|
||||
from astrbot.core.config.default import VERSION
|
||||
|
||||
class UpdateRoute(Route):
|
||||
def __init__(self, context: RouteContext, astrbot_updator: AstrBotUpdator) -> None:
|
||||
def __init__(self, context: RouteContext, astrbot_updator: AstrBotUpdator, core_lifecycle: AstrBotCoreLifecycle) -> None:
|
||||
super().__init__(context)
|
||||
self.routes = {
|
||||
'/update/check': ('GET', self.check_update),
|
||||
'/update/do': ('POST', self.update_project),
|
||||
'/update/dashboard': ('POST', self.update_dashboard),
|
||||
'/update/pip-install': ('POST', self.install_pip_package)
|
||||
}
|
||||
self.astrbot_updator = astrbot_updator
|
||||
self.core_lifecycle = core_lifecycle
|
||||
self.register_routes()
|
||||
|
||||
async def check_update(self):
|
||||
type_ = request.args.get('type', None)
|
||||
|
||||
try:
|
||||
ret = await self.astrbot_updator.check_update(None, None)
|
||||
return Response(
|
||||
status="success",
|
||||
message=str(ret) if ret is not None else "已经是最新版本了。",
|
||||
data={
|
||||
"has_new_version": ret is not None
|
||||
}
|
||||
).__dict__
|
||||
dv = await get_dashboard_version()
|
||||
if type_ == 'dashboard':
|
||||
return Response().ok({
|
||||
"has_new_version": dv != f"v{VERSION}",
|
||||
"current_version": dv
|
||||
}).__dict__
|
||||
else:
|
||||
ret = await self.astrbot_updator.check_update(None, None)
|
||||
return Response(
|
||||
status="success",
|
||||
message=str(ret) if ret is not None else "已经是最新版本了。",
|
||||
data={
|
||||
"version": f"v{VERSION}",
|
||||
"has_new_version": ret is not None,
|
||||
"dashboard_version": dv,
|
||||
"dashboard_has_new_version": dv != f"v{VERSION}"
|
||||
}
|
||||
).__dict__
|
||||
except Exception as e:
|
||||
logger.error(traceback.format_exc())
|
||||
return Response().error(e.__str__()).__dict__
|
||||
@@ -41,8 +57,16 @@ class UpdateRoute(Route):
|
||||
latest = False
|
||||
try:
|
||||
await self.astrbot_updator.update(latest=latest, version=version)
|
||||
|
||||
if latest:
|
||||
try:
|
||||
await download_dashboard()
|
||||
except Exception as e:
|
||||
logger.error(f"下载管理面板文件失败: {e}。")
|
||||
|
||||
if reboot:
|
||||
threading.Thread(target=self.astrbot_updator._reboot, args=(2, )).start()
|
||||
# threading.Thread(target=self.astrbot_updator._reboot, args=(2, )).start()
|
||||
self.core_lifecycle.restart()
|
||||
return Response().ok(None, "更新成功,AstrBot 将在 2 秒内全量重启以应用新的代码。").__dict__
|
||||
else:
|
||||
return Response().ok(None, "更新成功,AstrBot 将在下次启动时应用新的代码。").__dict__
|
||||
@@ -50,6 +74,18 @@ class UpdateRoute(Route):
|
||||
logger.error(f"/api/update_project: {traceback.format_exc()}")
|
||||
return Response().error(e.__str__()).__dict__
|
||||
|
||||
async def update_dashboard(self):
|
||||
try:
|
||||
try:
|
||||
await download_dashboard()
|
||||
except Exception as e:
|
||||
logger.error(f"下载管理面板文件失败: {e}。")
|
||||
return Response().error(f"下载管理面板文件失败: {e}").__dict__
|
||||
return Response().ok(None, "更新成功。刷新页面即可应用新版本面板。").__dict__
|
||||
except Exception as e:
|
||||
logger.error(f"/api/update_dashboard: {traceback.format_exc()}")
|
||||
return Response().error(e.__str__()).__dict__
|
||||
|
||||
async def install_pip_package(self):
|
||||
data = await request.json
|
||||
package = data.get('package', '')
|
||||
|
||||
@@ -24,7 +24,7 @@ class AstrBotDashboard():
|
||||
# token 用于验证请求
|
||||
logging.getLogger(self.app.name).removeHandler(default_handler)
|
||||
self.context = RouteContext(self.config, self.app)
|
||||
self.ur = UpdateRoute(self.context, core_lifecycle.astrbot_updator)
|
||||
self.ur = UpdateRoute(self.context, core_lifecycle.astrbot_updator, core_lifecycle)
|
||||
self.sr = StatRoute(self.context, db, core_lifecycle)
|
||||
self.pr = PluginRoute(self.context, core_lifecycle, core_lifecycle.plugin_manager)
|
||||
self.cr = ConfigRoute(self.context, core_lifecycle)
|
||||
|
||||
12
changelogs/v3.4.10.md
Normal file
12
changelogs/v3.4.10.md
Normal file
@@ -0,0 +1,12 @@
|
||||
# What's Changed
|
||||
|
||||
- 修复 LLM 请求报错信息被覆盖的问题,增强 LLM 请求错误处理 #243
|
||||
- 修复 Napcat 接口更新导致 QQ 图片发送失败的问题 #246
|
||||
- 修复某些请求不能正确应用代理的问题
|
||||
- 针对 api_base 的明显提示,修改 ollama 模板的 api_base #247
|
||||
- 支持登出 gewechat,在webchat等地方使用 `/gewe_logout` 指令,这在微信上显示账号下线但是 gewe 仍显示设备在线时很好用
|
||||
- 添加gewechat适配器过滤器
|
||||
- help显示AstrBot和webui版本
|
||||
- 优化webui和主程序更新的协调
|
||||
- 下载管理面板时显示提示、下载进度和下载速度
|
||||
- 管理面板前端更新功能入口移入右上角更新按钮,以便统一管理 #245
|
||||
6
changelogs/v3.4.11.md
Normal file
6
changelogs/v3.4.11.md
Normal file
@@ -0,0 +1,6 @@
|
||||
# What's Changed
|
||||
|
||||
- 为平台和提供商适配器添加默认 ID 配置 #248
|
||||
- 修复appid保存的问题和部分群聊at失效的问题和群聊@的sender username显示异常的问题
|
||||
- 优化更新项目时重启可能会导致Address already in use的问题
|
||||
- 各类异步任务报错后的优雅报错输出,而不是只有在退出程序的时候才输出异常日志。
|
||||
5
changelogs/v3.4.8.md
Normal file
5
changelogs/v3.4.8.md
Normal file
@@ -0,0 +1,5 @@
|
||||
# What's Changed
|
||||
|
||||
- 支持 Gewechat 接入微信个人号(文字交互)
|
||||
- 支持回复时 At 和引用发送者 #241
|
||||
- 清除残留的 personalities
|
||||
6
changelogs/v3.4.9.md
Normal file
6
changelogs/v3.4.9.md
Normal file
@@ -0,0 +1,6 @@
|
||||
# What's Changed
|
||||
|
||||
- AstrBot 新域名:astrbot.app
|
||||
- LLM额外唤醒词与机器人唤醒词冲突时的处理
|
||||
- 调整部分日志的严重级别
|
||||
- 下载管理面板时显示提示、下载进度和下载速度
|
||||
9998
dashboard/package-lock.json
generated
9998
dashboard/package-lock.json
generated
File diff suppressed because it is too large
Load Diff
@@ -15,6 +15,9 @@ let newUsername = ref('');
|
||||
let status = ref('');
|
||||
let updateStatus = ref('')
|
||||
let hasNewVersion = ref(false);
|
||||
let botCurrVersion = ref('');
|
||||
let dashboardHasNewVersion = ref(false);
|
||||
let dashboardCurrentVersion = ref('');
|
||||
let version = ref('');
|
||||
|
||||
const open = (link: string) => {
|
||||
@@ -64,6 +67,9 @@ function checkUpdate() {
|
||||
.then((res) => {
|
||||
hasNewVersion.value = res.data.data.has_new_version;
|
||||
updateStatus.value = res.data.message;
|
||||
botCurrVersion.value = res.data.data.version;
|
||||
dashboardCurrentVersion.value = res.data.data.dashboard_version;
|
||||
dashboardHasNewVersion.value = res.data.data.dashboard_has_new_version;
|
||||
})
|
||||
.catch((err) => {
|
||||
if (err.response.status == 401) {
|
||||
@@ -84,7 +90,24 @@ function switchVersion(version: string) {
|
||||
})
|
||||
.then((res) => {
|
||||
updateStatus.value = res.data.message;
|
||||
if (res.data.status == 'success') {
|
||||
if (res.data.status == 'ok') {
|
||||
setTimeout(() => {
|
||||
window.location.reload();
|
||||
}, 1000);
|
||||
}
|
||||
})
|
||||
.catch((err) => {
|
||||
console.log(err);
|
||||
updateStatus.value = err
|
||||
});
|
||||
}
|
||||
|
||||
function updateDashboard() {
|
||||
updateStatus.value = '正在更新...';
|
||||
axios.post('/api/update/dashboard')
|
||||
.then((res) => {
|
||||
updateStatus.value = res.data.message;
|
||||
if (res.data.status == 'ok') {
|
||||
setTimeout(() => {
|
||||
window.location.reload();
|
||||
}, 1000);
|
||||
@@ -106,8 +129,8 @@ commonStore.getStartTime();
|
||||
<template>
|
||||
<v-app-bar elevation="0" height="70">
|
||||
|
||||
<v-btn style="margin-left: 22px;" class="hidden-md-and-down text-secondary" color="lightsecondary" icon rounded="sm" variant="flat"
|
||||
@click.stop="customizer.SET_MINI_SIDEBAR(!customizer.mini_sidebar)" size="small">
|
||||
<v-btn style="margin-left: 22px;" class="hidden-md-and-down text-secondary" color="lightsecondary" icon rounded="sm"
|
||||
variant="flat" @click.stop="customizer.SET_MINI_SIDEBAR(!customizer.mini_sidebar)" size="small">
|
||||
<v-icon>mdi-menu</v-icon>
|
||||
</v-btn>
|
||||
<v-btn class="hidden-lg-and-up text-secondary ms-3" color="lightsecondary" icon rounded="sm" variant="flat"
|
||||
@@ -136,11 +159,16 @@ commonStore.getStartTime();
|
||||
</template>
|
||||
<v-card>
|
||||
<v-card-title>
|
||||
<span class="text-h5">更新项目</span>
|
||||
<span class="text-h5">更新 AstrBot</span>
|
||||
</v-card-title>
|
||||
<v-card-text>
|
||||
<v-container>
|
||||
<h3 class="mb-4">升级到最新版本</h3>
|
||||
<h3 class="mb-4">升级到项目最新版本</h3>
|
||||
<small>当前版本 {{ botCurrVersion }}</small>
|
||||
<div class="mb-4">
|
||||
<small>会同时尝试更新机器人主程序和管理面板。如果您正在使用 Docker 部署,也可以重新拉取镜像或者使用 <a
|
||||
href="https://containrrr.dev/watchtower/usage-overview/">watchtower</a> 来自动监控拉取。</small>
|
||||
</div>
|
||||
<p>{{ updateStatus }}</p>
|
||||
<v-btn class="mt-4 mb-4" @click="switchVersion('latest')" color="primary" style="border-radius: 10px;"
|
||||
:disabled="!hasNewVersion">
|
||||
@@ -148,7 +176,11 @@ commonStore.getStartTime();
|
||||
</v-btn>
|
||||
<v-divider></v-divider>
|
||||
<div style="margin-top: 16px;">
|
||||
<h3 class="mb-4">切换到指定版本或指定提交</h3>
|
||||
<h3 class="mb-4">切换到项目指定版本或指定提交</h3>
|
||||
<div class="mb-4">
|
||||
<small>跳到旧版本不会重新下载管理面板文件,这可能会造成部分数据显示错误。您可在 <a href="https://github.com/Soulter/AstrBot/releases">此处</a>
|
||||
找到对应的面板文件 dist.zip,解压后替换 data/dist 文件夹即可。</small>
|
||||
</div>
|
||||
<v-text-field label="输入版本号或 master 分支下的 commit hash。" v-model="version" required
|
||||
variant="outlined"></v-text-field>
|
||||
<div class="mb-4">
|
||||
@@ -160,7 +192,29 @@ commonStore.getStartTime();
|
||||
<v-btn color="error" style="border-radius: 10px;" @click="switchVersion(version)">
|
||||
确定切换
|
||||
</v-btn>
|
||||
</div>
|
||||
|
||||
<v-divider></v-divider>
|
||||
<div style="margin-top: 16px;">
|
||||
<h3 class="mb-4">更新管理面板到最新版本</h3>
|
||||
<div class="mb-4">
|
||||
<small>当前版本 {{ dashboardCurrentVersion }}</small>
|
||||
<br>
|
||||
|
||||
</div>
|
||||
|
||||
<div class="mb-4">
|
||||
<p v-if="dashboardHasNewVersion">
|
||||
有新版本!
|
||||
</p>
|
||||
<p v-else="dashboardHasNewVersion">
|
||||
已经是最新版本了。
|
||||
</p>
|
||||
</div>
|
||||
|
||||
<v-btn color="primary" style="border-radius: 10px;" @click="updateDashboard()">
|
||||
下载并更新
|
||||
</v-btn>
|
||||
</div>
|
||||
</v-container>
|
||||
</v-card-text>
|
||||
@@ -190,8 +244,7 @@ commonStore.getStartTime();
|
||||
<v-text-field label="原密码*" type="password" v-model="password" required
|
||||
variant="outlined"></v-text-field>
|
||||
|
||||
<v-text-field label="新用户名" v-model="newUsername" required
|
||||
variant="outlined"></v-text-field>
|
||||
<v-text-field label="新用户名" v-model="newUsername" required variant="outlined"></v-text-field>
|
||||
|
||||
<v-text-field label="新密码" type="password" v-model="newPassword" required
|
||||
variant="outlined"></v-text-field>
|
||||
|
||||
@@ -27,10 +27,10 @@ const sidebarMenu = shallowRef(sidebarItems);
|
||||
</v-btn>
|
||||
</v-list-item>
|
||||
<small style="display: block;" v-if="buildVer">构建: {{ buildVer }}</small>
|
||||
<small style="display: block;" v-else="buildVer">构建: embedded</small>
|
||||
<small style="display: block;" v-else>构建: embedded</small>
|
||||
<v-tooltip text="使用 /dashbord_update 指令更新管理面板">
|
||||
<template v-slot:activator="{ props }">
|
||||
<small v-bind="props" v-if="buildVer != version" style="display: block; margin-top: 4px;">面板有更新</small>
|
||||
<small v-bind="props" v-if="hasWebUIUpdate" style="display: block; margin-top: 4px;">面板有更新</small>
|
||||
</template>
|
||||
</v-tooltip>
|
||||
|
||||
@@ -50,19 +50,12 @@ export default {
|
||||
},
|
||||
data: () => ({
|
||||
version: "",
|
||||
buildVer: ""
|
||||
buildVer: "",
|
||||
hasWebUIUpdate: false,
|
||||
}),
|
||||
mounted() {
|
||||
this.get_version()
|
||||
fetch('/assets/version').then((res) => {
|
||||
return res.text()
|
||||
}).then((res) => {
|
||||
if (res.length > 10) {
|
||||
// 不是版本,不显示 😎
|
||||
return
|
||||
}
|
||||
this.buildVer = res.replace(/\s+/g, '')
|
||||
})
|
||||
this.check_webui_update()
|
||||
},
|
||||
methods: {
|
||||
get_version() {
|
||||
@@ -73,6 +66,16 @@ export default {
|
||||
.catch((err) => {
|
||||
console.log(err);
|
||||
});
|
||||
},
|
||||
check_webui_update() {
|
||||
axios.get('/api/update/check?type=dashboard')
|
||||
.then((res) => {
|
||||
this.hasWebUIUpdate = res.data.data.has_new_version;
|
||||
this.buildVer = res.data.data.current_version;
|
||||
})
|
||||
.catch((err) => {
|
||||
console.log(err);
|
||||
});
|
||||
}
|
||||
},
|
||||
};
|
||||
|
||||
@@ -1,4 +1,13 @@
|
||||
<template>
|
||||
<v-row style="margin: 2px;">
|
||||
<v-alert
|
||||
:type="noticeType"
|
||||
:text="noticeContent"
|
||||
:title="noticeTitle"
|
||||
v-if="noticeTitle && noticeContent"
|
||||
closable
|
||||
></v-alert>
|
||||
</v-row>
|
||||
<v-row>
|
||||
<v-col cols="12" md="4">
|
||||
<TotalMessage :stat="stat" />
|
||||
@@ -38,13 +47,26 @@ export default {
|
||||
},
|
||||
data: () => ({
|
||||
stat: {},
|
||||
noticeTitle: '',
|
||||
noticeContent: '',
|
||||
noticeType: '',
|
||||
}),
|
||||
|
||||
mounted() {
|
||||
axios.get('/api/stat/get').then((res) => {
|
||||
this.stat = res.data.data;
|
||||
});
|
||||
}
|
||||
|
||||
axios.get('https://api.soulter.top/astrbot-announcement').then((res) => {
|
||||
let data = res.data.data;
|
||||
// 如果 dashboard-notice 在其中
|
||||
if (data['dashboard-notice']) {
|
||||
this.noticeTitle = data['dashboard-notice'].title;
|
||||
this.noticeContent = data['dashboard-notice'].content;
|
||||
this.noticeType = data['dashboard-notice'].type;
|
||||
}
|
||||
});
|
||||
},
|
||||
};
|
||||
|
||||
</script>
|
||||
|
||||
22
main.py
22
main.py
@@ -6,7 +6,7 @@ from astrbot.dashboard import AstrBotDashBoardLifecycle
|
||||
from astrbot.core import db_helper
|
||||
from astrbot.core import logger, LogManager, LogBroker
|
||||
from astrbot.core.config.default import VERSION
|
||||
from astrbot.core.utils.io import download_dashboard
|
||||
from astrbot.core.utils.io import download_dashboard, get_dashboard_version
|
||||
|
||||
# add parent path to sys.path
|
||||
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
|
||||
@@ -37,22 +37,22 @@ def check_env():
|
||||
|
||||
async def check_dashboard_files():
|
||||
'''下载管理面板文件'''
|
||||
if os.path.exists("data/dist"):
|
||||
if os.path.exists("data/dist/assets/version"):
|
||||
with open("data/dist/assets/version", "r") as f:
|
||||
v = f.read().strip()
|
||||
if v != f"v{VERSION}":
|
||||
logger.warning("检测到管理面板有更新。可以使用 /dashboard update 命令更新。")
|
||||
else:
|
||||
logger.info("管理面板文件已是最新。")
|
||||
|
||||
v = await get_dashboard_version()
|
||||
if v is not None:
|
||||
# has file
|
||||
if v == f"v{VERSION}":
|
||||
logger.info("管理面板文件已是最新。")
|
||||
else:
|
||||
logger.warning("检测到管理面板有更新。可以使用 /dashboard_update 命令更新。")
|
||||
return
|
||||
|
||||
logger.info("开始下载管理面板文件...")
|
||||
logger.info("开始下载管理面板文件...高峰期(晚上)可能导致较慢的速度。如多次下载失败,请前往 https://github.com/Soulter/AstrBot/releases/latest 下载 dist.zip,并将其中的 dist 文件夹解压至 data 目录下。")
|
||||
|
||||
try:
|
||||
await download_dashboard()
|
||||
except Exception as e:
|
||||
logger.critical(f"下载管理面板文件失败: {e}")
|
||||
logger.critical(f"下载管理面板文件失败: {e}。")
|
||||
return
|
||||
|
||||
logger.info("管理面板下载完成。")
|
||||
|
||||
@@ -6,7 +6,8 @@ import astrbot.api.event.filter as filter
|
||||
from astrbot.api.event import AstrMessageEvent, MessageEventResult
|
||||
from astrbot.api import sp
|
||||
from astrbot.api.provider import Personality, ProviderRequest
|
||||
from astrbot.core.utils.io import download_dashboard
|
||||
from astrbot.core.utils.io import download_dashboard, get_dashboard_version
|
||||
from astrbot.core.config.default import VERSION
|
||||
|
||||
from typing import Union
|
||||
|
||||
@@ -23,7 +24,7 @@ class Main(star.Star):
|
||||
|
||||
async def _query_astrbot_notice(self):
|
||||
try:
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with aiohttp.ClientSession(trust_env=True) as session:
|
||||
async with session.get("https://astrbot.soulter.top/notice.json", timeout=2) as resp:
|
||||
return (await resp.json())["notice"]
|
||||
except BaseException:
|
||||
@@ -36,9 +37,12 @@ class Main(star.Star):
|
||||
notice = await self._query_astrbot_notice()
|
||||
except BaseException:
|
||||
pass
|
||||
|
||||
dashboard_version = await get_dashboard_version()
|
||||
|
||||
msg = "已注册的 AstrBot 内置指令:\n"
|
||||
msg += f"""[System]
|
||||
msg = f"""AstrBot v{VERSION}(WebUI: {dashboard_version})
|
||||
已注册的 AstrBot 内置指令:
|
||||
[System]
|
||||
/plugin: 查看注册的插件、插件帮助
|
||||
/t2i: 开启/关闭文本转图片模式
|
||||
/sid: 获取当前会话的 ID
|
||||
@@ -414,6 +418,16 @@ UID: {user_id} 此 ID 可用于设置管理员。/op <UID> 授权管理员, /deo
|
||||
del session_var[key]
|
||||
sp.put("session_variables", session_vars)
|
||||
yield event.plain_result(f"会话 {session_id} 变量 {key} 移除成功。")
|
||||
|
||||
@filter.command("gewe_logout")
|
||||
async def gewe_logout(self, event: AstrMessageEvent):
|
||||
platforms = self.context.platform_manager.platform_insts
|
||||
for platform in platforms:
|
||||
if platform.meta().name == "gewechat":
|
||||
yield event.plain_result("正在登出 gewechat")
|
||||
await platform.logout()
|
||||
yield event.plain_result("已登出 gewechat")
|
||||
return
|
||||
|
||||
@filter.command_group("kdb")
|
||||
def kdb(self):
|
||||
|
||||
@@ -127,7 +127,7 @@ class Main(star.Star):
|
||||
|
||||
s3_file_url = f"{S3_URL}/{uuid.uuid4().hex}{ext}"
|
||||
|
||||
async with aiohttp.ClientSession(headers = {"Accept": "application/json"}) as session:
|
||||
async with aiohttp.ClientSession(headers = {"Accept": "application/json"}, trust_env=True) as session:
|
||||
async with session.put(s3_file_url, data=file) as resp:
|
||||
if resp.status != 200:
|
||||
raise Exception(f"Failed to upload image: {resp.status}")
|
||||
@@ -140,7 +140,7 @@ class Main(star.Star):
|
||||
docker = aiodocker.Docker()
|
||||
await docker.version()
|
||||
return True
|
||||
except aiodocker.exceptions.DockerError as e:
|
||||
except BaseException as e:
|
||||
logger.info(f"检查 Docker 可用性: {e}")
|
||||
return False
|
||||
|
||||
@@ -159,7 +159,7 @@ class Main(star.Star):
|
||||
|
||||
async def download_image(self, image_url: str, workplace_path: str, filename: str) -> str:
|
||||
'''Download image from url to workplace_path'''
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with aiohttp.ClientSession(trust_env=True) as session:
|
||||
async with session.get(image_url) as resp:
|
||||
if resp.status != 200:
|
||||
return ""
|
||||
|
||||
@@ -39,7 +39,7 @@ class Main(star.Star):
|
||||
'''获取网页内容'''
|
||||
header = HEADERS
|
||||
header.update({'User-Agent': random.choice(USER_AGENTS)})
|
||||
async with aiohttp.ClientSession() as session:
|
||||
async with aiohttp.ClientSession(trust_env=True) as session:
|
||||
async with session.get(url, headers=header, timeout=6) as response:
|
||||
html = await response.text(encoding="utf-8")
|
||||
doc = Document(html)
|
||||
|
||||
Reference in New Issue
Block a user