Compare commits

..

1 Commits

Author SHA1 Message Date
Soulter
6c18971f00 feat: 初步接入米家小爱音箱 2025-01-25 02:30:00 +08:00
34 changed files with 914 additions and 591 deletions

View File

@@ -1,12 +1,14 @@
<p align="center">
![](https://github.com/user-attachments/assets/04f22f63-4dcf-4a6c-981e-b33f45c23f3e)
<p align="center">
<img src="https://github.com/user-attachments/assets/de10f24d-cd64-433a-90b8-16c0a60de24a" width=500>
</p>
<div align="center">
<h1>AstrBot</h1>
_✨ 易上手的多平台 LLM 聊天机器人及开发框架 ✨_
[![GitHub release (latest by date)](https://img.shields.io/github/v/release/Soulter/AstrBot)](https://github.com/Soulter/AstrBot/releases/latest)
@@ -70,7 +72,7 @@ AstrBot 是一个松耦合、异步、支持多消息平台部署、具有易用
| -------- | ------- | ------- | ------ |
| QQ | ✔ | 私聊、群聊 | 文字、图片、语音 |
| QQ 官方API | ✔ | 私聊、群聊QQ 频道私聊、群聊 | 文字、图片 |
| 微信 | ✔ | [Gewechat](https://github.com/Devo919/Gewechat)。微信个人号私聊、群聊 | 文字、图片、语音 |
| 微信 | ✔ | [Gewechat](https://github.com/Devo919/Gewechat)。微信个人号私聊、群聊 | 文字 |
| [Telegram](https://github.com/Soulter/astrbot_plugin_telegram) | ✔ | 私聊、群聊 | 文字、图片 |
| 微信对话开放平台 | 🚧 | 计划内 | - |
| 飞书 | 🚧 | 计划内 | - |

View File

@@ -2,7 +2,7 @@
如需修改配置,请在 `data/cmd_config.json` 中修改或者在管理面板中可视化修改。
"""
VERSION = "3.4.13"
VERSION = "3.4.11"
DB_PATH = "data/data_v3.db"
# 默认配置
@@ -24,13 +24,7 @@ DEFAULT_CONFIG = {
"wl_ignore_admin_on_friend": True,
"reply_with_mention": False,
"reply_with_quote": False,
"path_mapping": [],
"segmented_reply": {
"enable": False,
"only_llm_result": True,
"interval": "1.5,3.5",
"regex": ".*?[。?!~…]+|.+$"
}
"path_mapping": []
},
"provider": [],
"provider_settings": {
@@ -46,10 +40,6 @@ DEFAULT_CONFIG = {
"enable": False,
"provider_id": "",
},
"provider_tts_settings": {
"enable": False,
"provider_id": "",
},
"content_safety": {
"internal_keywords": {"enable": True, "extra_keywords": []},
"baidu_aip": {"enable": False, "app_id": "", "api_key": "", "secret_key": ""},
@@ -115,6 +105,17 @@ CONFIG_METADATA_2 = {
"host": "localhost",
"port": 11451,
},
"mispeaker(小爱音箱)": {
"id": "mispeaker",
"type": "mispeaker",
"enable": False,
"username": "",
"password": "",
"did": "",
"activate_word": "测试",
"deactivate_word": "停止",
"interval": 1,
},
},
"items": {
"id": {
@@ -188,31 +189,6 @@ CONFIG_METADATA_2 = {
},
},
},
"segmented_reply": {
"description": "分段回复",
"type": "object",
"items": {
"enable": {
"description": "启用分段回复",
"type": "bool",
},
"only_llm_result": {
"description": "仅对 LLM 结果分段",
"type": "bool",
},
"interval": {
"description": "随机间隔时间(秒)",
"type": "string",
"hint": "每一段回复的间隔时间,格式为 `最小时间,最大时间`。如 `0.75,2.5`",
},
"regex": {
"description": "正则表达式",
"type": "string",
"obvious_hint": True,
"hint": "用于分隔一段消息。默认情况下会根据句号、问号等标点符号分隔。re.findall(r'<regex>', text)",
},
},
},
"reply_prefix": {
"description": "回复前缀",
"type": "string",
@@ -406,14 +382,6 @@ CONFIG_METADATA_2 = {
"type": "openai_whisper_selfhost",
"model": "tiny",
},
"openai_tts(API)": {
"id": "openai_tts",
"type": "openai_tts_api",
"enable": False,
"api_key": "",
"api_base": "",
"model": "tts-1",
},
},
"items": {
"whisper_hint": {
@@ -613,23 +581,6 @@ CONFIG_METADATA_2 = {
},
},
},
"provider_tts_settings": {
"description": "文本转语音(TTS)",
"type": "object",
"items": {
"enable": {
"description": "启用文本转语音(TTS)",
"type": "bool",
"hint": "启用前请在 服务提供商配置 处创建支持 语音转文本任务 的提供商。如 openai_tts。",
"obvious_hint": True,
},
"provider_id": {
"description": "提供商 ID不填则默认第一个TTS提供商",
"type": "string",
"hint": "文本转语音提供商 ID。如果不填写将使用载入的第一个提供商。",
},
},
},
},
},
"misc_config_group": {

View File

@@ -7,6 +7,7 @@ from .event_bus import EventBus
from . import astrbot_config
from asyncio import Queue
from typing import List
from astrbot.core.config.astrbot_config import AstrBotConfig
from astrbot.core.pipeline.scheduler import PipelineScheduler, PipelineContext
from astrbot.core.star import PluginManager
from astrbot.core.platform.manager import PlatformManager

View File

@@ -123,7 +123,7 @@ class Record(BaseMessageComponent):
proxy: T.Optional[bool] = True
timeout: T.Optional[int] = 0
# 额外
path: T.Optional[str]
path: T.Optional[str] # 用这个
def __init__(self, file: T.Optional[str], **_):
for k in _.keys():

View File

@@ -13,10 +13,12 @@ class MessageChain():
Attributes:
`chain` (list): 用于顺序存储各个组件。
`use_t2i_` (bool): 用于标记是否使用文本转图片服务。默认为 None即跟随用户的设置。当设置为 True 时,将会使用文本转图片服务。
`is_split_` (bool): 用于标记是否分条发送消息。默认为 False。启用后将会依次发送 chain 中的每个 component。
'''
chain: List[BaseMessageComponent] = field(default_factory=list)
use_t2i_: Optional[bool] = None # None 为跟随用户设置
is_split_: Optional[bool] = False # 是否将消息分条发送。默认为 False。启用后将会依次发送 chain 中的每个 component。
def message(self, message: str):
'''添加一条文本消息到消息链 `chain` 中。
@@ -75,6 +77,16 @@ class MessageChain():
'''
self.use_t2i_ = use_t2i
return self
def is_split(self, is_split: bool):
'''设置是否分条发送消息。默认为 False。启用后将会依次发送 chain 中的每个 component。
Note:
具体的效果以各适配器实现为准。
'''
self.is_split_ = is_split
return self
class EventResultType(enum.Enum):
'''用于描述事件处理的结果类型。
@@ -101,6 +113,7 @@ class MessageEventResult(MessageChain):
Attributes:
`chain` (list): 用于顺序存储各个组件。
`use_t2i_` (bool): 用于标记是否使用文本转图片服务。默认为 None即跟随用户的设置。当设置为 True 时,将会使用文本转图片服务。
`is_split_` (bool): 用于标记是否分条发送消息。默认为 False。启用后将会依次发送 chain 中的每个 component。
`result_type` (EventResultType): 事件处理的结果类型。
'''
@@ -126,7 +139,7 @@ class MessageEventResult(MessageChain):
'''
return self.result_type == EventResultType.STOP
def set_result_content_type(self, typ: ResultContentType) -> 'MessageEventResult':
def set_result_content_type(self, typ: EventResultType) -> 'MessageEventResult':
'''设置事件处理的结果类型。
Args:
@@ -135,10 +148,5 @@ class MessageEventResult(MessageChain):
self.result_content_type = typ
return self
def is_llm_result(self) -> bool:
'''是否为 LLM 结果。
'''
return self.result_content_type == ResultContentType.LLM_RESULT
CommandResult = MessageEventResult

View File

@@ -51,7 +51,7 @@ class LLMRequestSubStage(Stage):
session_provider_context = provider.session_memory.get(event.session_id)
req.contexts = session_provider_context if session_provider_context else []
if not req.prompt and not req.image_urls:
if not req.prompt:
return
# 执行请求 LLM 前事件。

View File

@@ -1,10 +1,7 @@
import random
import asyncio
from typing import Union, AsyncGenerator
from ..stage import register_stage, Stage
from ..context import PipelineContext
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.message.message_event_result import MessageChain
from astrbot.core import logger
from astrbot.core.star.star_handler import star_handlers_registry, EventType
@@ -12,17 +9,6 @@ from astrbot.core.star.star_handler import star_handlers_registry, EventType
class RespondStage(Stage):
async def initialize(self, ctx: PipelineContext):
self.ctx = ctx
# 分段回复
self.enable_seg: bool = ctx.astrbot_config['platform_settings']['segmented_reply']['enable']
interval_str: str = ctx.astrbot_config['platform_settings']['segmented_reply']['interval']
interval_str_ls = interval_str.replace(" ", "").split(",")
try:
self.interval = [float(t) for t in interval_str_ls]
except BaseException as e:
logger.error(f'解析分段回复的间隔时间失败。{e}')
self.interval = [1.5, 3.5]
async def process(self, event: AstrMessageEvent) -> Union[None, AsyncGenerator[None, None]]:
result = event.get_result()
@@ -30,15 +16,7 @@ class RespondStage(Stage):
return
if len(result.chain) > 0:
await event._pre_send()
if self.enable_seg:
# 分段回复
for comp in result.chain:
await event.send(MessageChain([comp]))
await asyncio.sleep(random.uniform(self.interval[0], self.interval[1]))
else:
await event.send(result)
await event._post_send()
await event.send(result)
logger.info(f"AstrBot -> {event.get_sender_name()}/{event.get_sender_id()}: {event._outline_chain(result.chain)}")
handlers = star_handlers_registry.get_handlers_by_event_type(EventType.OnAfterMessageSentEvent)

View File

@@ -1,13 +1,11 @@
import time
import re
import traceback
from typing import Union, AsyncGenerator
from ..stage import register_stage
from ..context import PipelineContext
from astrbot.core.platform.astr_message_event import AstrMessageEvent
from astrbot.core.platform.message_type import MessageType
from astrbot.core import logger
from astrbot.core.message.components import Plain, Image, At, Reply, Record
from astrbot.core.message.components import Plain, Image, At, Reply
from astrbot.core import html_renderer
from astrbot.core.star.star_handler import star_handlers_registry, EventType
@@ -18,13 +16,7 @@ class ResultDecorateStage:
self.reply_prefix = ctx.astrbot_config['platform_settings']['reply_prefix']
self.reply_with_mention = ctx.astrbot_config['platform_settings']['reply_with_mention']
self.reply_with_quote = ctx.astrbot_config['platform_settings']['reply_with_quote']
self.use_tts = ctx.astrbot_config['provider_tts_settings']['enable']
self.t2i = ctx.astrbot_config['t2i']
# 分段回复
self.enable_segmented_reply = ctx.astrbot_config['platform_settings']['segmented_reply']['enable']
self.only_llm_result = ctx.astrbot_config['platform_settings']['segmented_reply']['only_llm_result']
self.regex = ctx.astrbot_config['platform_settings']['segmented_reply']['regex']
async def process(self, event: AstrMessageEvent) -> Union[None, AsyncGenerator[None, None]]:
result = event.get_result()
@@ -39,53 +31,10 @@ class ResultDecorateStage:
if len(result.chain) > 0:
# 回复前缀
if self.reply_prefix:
for comp in result.chain:
if isinstance(comp, Plain):
comp.text = self.reply_prefix + comp.text
break
# 分段回复
if self.enable_segmented_reply:
if (self.only_llm_result and result.is_llm_result()) or not self.only_llm_result:
new_chain = []
for comp in result.chain:
if isinstance(comp, Plain):
split_response = re.findall(r".*?[。?!~…]+|.+$", comp.text)
if not split_response:
new_chain.append(comp)
continue
for seg in split_response:
new_chain.append(Plain(seg))
else:
# 非 Plain 类型的消息段不分段
new_chain.append(comp)
result.chain = new_chain
# TTS
if self.use_tts and result.is_llm_result():
tts_provider = self.ctx.plugin_manager.context.provider_manager.curr_tts_provider_inst
new_chain = []
for comp in result.chain:
if isinstance(comp, Plain) and len(comp.text) > 1:
try:
logger.info("TTS 请求: " + plain_str)
audio_path = await tts_provider.get_audio(plain_str)
logger.info("TTS 结果: " + audio_path)
if audio_path:
new_chain.append(Record(file=audio_path, url=audio_path))
else:
logger.error(f"由于 TTS 音频文件没找到,消息段转语音失败: {comp.text}")
new_chain.append(comp)
except BaseException:
traceback.print_exc()
logger.error("TTS 失败,使用文本发送。")
new_chain.append(comp)
else:
new_chain.append(comp)
result.chain = new_chain
result.chain.insert(0, Plain(self.reply_prefix))
# 文本转图片
elif (result.use_t2i_ is None and self.t2i) or result.use_t2i_:
if (result.use_t2i_ is None and self.t2i) or result.use_t2i_:
plain_str = ""
for comp in result.chain:
if isinstance(comp, Plain):

View File

@@ -179,15 +179,6 @@ class AstrMessageEvent(abc.ABC):
await Metric.upload(msg_event_tick = 1, adapter_name = self.platform_meta.name)
self._has_send_oper = True
async def _pre_send(self):
'''调度器会在执行 send() 前调用该方法'''
pass
async def _post_send(self):
'''调度器会在执行 send() 后调用该方法'''
pass
def set_result(self, result: Union[MessageEventResult, str]):
'''设置消息事件的结果。

View File

@@ -27,6 +27,8 @@ class PlatformManager():
from .sources.vchat.vchat_platform_adapter import VChatPlatformAdapter # noqa: F401
case "gewechat":
from .sources.gewechat.gewechat_platform_adapter import GewechatPlatformAdapter # noqa: F401
case "mispeaker":
from .sources.mispeaker.mispeaker_adapter import MiSpeakerPlatformAdapter # noqa: F401
async def initialize(self):

View File

@@ -3,7 +3,7 @@ import random
import asyncio
from astrbot.api.event import AstrMessageEvent, MessageChain
from astrbot.api.message_components import Plain, Image, Record
from astrbot.api.message_components import Plain, Image
from aiocqhttp import CQHttp
from astrbot.core.utils.io import file_to_base64, download_image_by_url
@@ -20,18 +20,16 @@ class AiocqhttpMessageEvent(AstrMessageEvent):
d = segment.toDict()
if isinstance(segment, Plain):
d['type'] = 'text'
if isinstance(segment, (Image, Record)):
if isinstance(segment, Image):
# convert to base64
if segment.file and segment.file.startswith("file:///"):
bs64_data = file_to_base64(segment.file[8:])
image_base64 = file_to_base64(segment.file[8:])
image_file_path = segment.file[8:]
elif segment.file and segment.file.startswith("http"):
image_file_path = await download_image_by_url(segment.file)
bs64_data = file_to_base64(image_file_path)
else:
bs64_data = file_to_base64(segment.file)
image_base64 = file_to_base64(image_file_path)
d['data'] = {
'file': bs64_data,
'file': image_base64,
}
ret.append(d)
return ret
@@ -40,5 +38,11 @@ class AiocqhttpMessageEvent(AstrMessageEvent):
ret = await AiocqhttpMessageEvent._parse_onebot_json(message)
if os.environ.get('TEST_MODE', 'off') == 'on':
return
await self.bot.send(self.message_obj.raw_message, ret)
if message.is_split_: # 分条发送
for m in ret:
await self.bot.send(self.message_obj.raw_message, [m])
await asyncio.sleep(random.uniform(0.75, 2.5))
else:
await self.bot.send(self.message_obj.raw_message, ret)
await super().send(message)

View File

@@ -2,14 +2,10 @@ import threading
import asyncio
import aiohttp
import quart
import base64
from astrbot.api.platform import AstrBotMessage, MessageMember, MessageType
from astrbot.api.message_components import Plain, Image, At, Record
from astrbot.api.message_components import Plain, Image, At
from astrbot.api import logger, sp
from .downloader import GeweDownloader
from astrbot.core.utils.io import download_image_by_url
class SimpleGewechatClient():
'''针对 Gewechat 的简单实现。
@@ -21,15 +17,9 @@ class SimpleGewechatClient():
self.base_url = base_url
if self.base_url.endswith('/'):
self.base_url = self.base_url[:-1]
self.download_base_url = self.base_url.split(':')[:-1] # 去掉端口
self.download_base_url = ':'.join(self.download_base_url) + ":2532/download/"
self.base_url += "/v2/api"
logger.info(f"Gewechat API: {self.base_url}")
logger.info(f"Gewechat 下载 API: {self.download_base_url}")
if isinstance(port, str):
port = int(port)
@@ -37,19 +27,15 @@ class SimpleGewechatClient():
self.headers = {}
self.nickname = nickname
self.appid = sp.get(f"gewechat-appid-{nickname}", "")
self.callback_url = None
self.server = quart.Quart(__name__)
self.server.add_url_rule('/astrbot-gewechat/callback', view_func=self.callback, methods=['POST'])
self.server.add_url_rule('/astrbot-gewechat/file/<file_id>', view_func=self.handle_file, methods=['GET'])
self.server.add_url_rule('/astrbot-gewechat/callback', view_func=self.callback, methods=['POST'])
self.host = host
self.port = port
self.callback_url = f"http://{self.host}:{self.port}/astrbot-gewechat/callback"
self.file_server_url = f"http://{self.host}:{self.port}/astrbot-gewechat/file"
self.event_queue = event_queue
self.multimedia_downloader = None
async def get_token_id(self):
async with aiohttp.ClientSession() as session:
@@ -68,84 +54,55 @@ class SimpleGewechatClient():
return
abm = AstrBotMessage()
d = data['Data']
from_user_name = d['FromUserName']['string'] # 消息来源
d['to_wxid'] = from_user_name # 用于发信息
msg_type = d['MsgType']
abm.message_id = str(d.get('MsgId'))
abm.session_id = from_user_name
abm.self_id = data['Wxid'] # 机器人的 wxid
user_id = "" # 发送人 wxid
content = d['Content']['string'] # 消息内容
at_me = False
if "@chatroom" in from_user_name:
abm.type = MessageType.GROUP_MESSAGE
_t = content.split(':\n')
user_id = _t[0]
content = _t[1]
if '\u2005' in content:
# at
content = content.split('\u2005')[1]
abm.group_id = from_user_name
# at
msg_source = d['MsgSource']
if f'<atuserlist><![CDATA[,{abm.self_id}]]>' in msg_source \
or f'<atuserlist><![CDATA[{abm.self_id}]]>' in msg_source:
at_me = True
else:
abm.type = MessageType.FRIEND_MESSAGE
user_id = from_user_name
abm.message = []
if at_me:
abm.message.insert(0, At(qq=abm.self_id))
user_real_name = d['PushContent'].split(' : ')[0] \
.replace('在群聊中@了你', '') \
.replace('在群聊中发了一段语音', '') # 真实昵称
abm.sender = MessageMember(user_id, user_real_name)
abm.raw_message = d
abm.message_str = ""
# 不同消息类型
match d['MsgType']:
match msg_type:
case 1:
# 文本消息
abm.message.append(Plain(content))
from_user_name = d['FromUserName']['string'] # 消息来源
d['to_wxid'] = from_user_name # 用于发信息
user_id = "" # 发送人 wxid
content = d['Content']['string'] # 消息内容
user_real_name = d['PushContent'].split(' : ')[0] # 真实昵称
user_real_name = user_real_name.replace('在群聊中@了你', '') # trick
abm.self_id = data['Wxid'] # 机器人的 wxid
at_me = False
if "@chatroom" in from_user_name:
abm.type = MessageType.GROUP_MESSAGE
_t = content.split(':\n')
user_id = _t[0]
content = _t[1]
if '\u2005' in content:
# at
content = content.split('\u2005')[1]
abm.group_id = from_user_name
# at
msg_source = d['MsgSource']
if f'<atuserlist><![CDATA[,{abm.self_id}]]>' in msg_source \
or f'<atuserlist><![CDATA[{abm.self_id}]]>' in msg_source:
at_me = True
else:
abm.type = MessageType.FRIEND_MESSAGE
user_id = from_user_name
abm.session_id = from_user_name
abm.sender = MessageMember(user_id, user_real_name)
abm.message = [Plain(content)]
if at_me:
abm.message.insert(0, At(qq=abm.self_id))
abm.message_id = str(d['MsgId'])
abm.raw_message = d
abm.message_str = content
case 3:
# 图片消息
file_url = await self.multimedia_downloader.download_image(
self.appid,
content
)
logger.debug(f"下载图片: {file_url}")
file_path = await download_image_by_url(file_url)
abm.message.append(Image(file=file_path, url=file_path))
case 34:
# 语音消息
# data = await self.multimedia_downloader.download_voice(
# self.appid,
# content,
# abm.message_id
# )
# print(data)
if 'ImgBuf' in d and 'buffer' in d['ImgBuf']:
voice_data = base64.b64decode(d['ImgBuf']['buffer'])
file_path = f"data/temp/gewe_voice_{abm.message_id}.silk"
with open(file_path, "wb") as f:
f.write(voice_data)
abm.message.append(Record(file=file_path, url=file_path))
logger.info(f"abm: {abm}")
return abm
case _:
logger.error(f"未实现的消息类型: {d['MsgType']}")
return
logger.info(f"abm: {abm}")
return abm
logger.error(f"未实现的消息类型: {msg_type}")
async def callback(self):
data = await quart.request.json
logger.debug(f"收到 gewechat 回调: {data}")
@@ -161,31 +118,32 @@ class SimpleGewechatClient():
await coro(abm)
return quart.jsonify({"r": "AstrBot ACK"})
async def handle_file(self, file_id):
file_path = f"data/temp/{file_id}"
return await quart.send_file(file_path)
async def _set_callback_url(self):
logger.info("设置回调,请等待...")
await asyncio.sleep(3)
callback_url = f"http://{self.host}:{self.port}/astrbot-gewechat/callback"
async with aiohttp.ClientSession() as session:
async with session.post(
f"{self.base_url}/tools/setCallback",
headers=self.headers,
json={
"token": self.token,
"callbackUrl": self.callback_url
"callbackUrl": callback_url
}
) as resp:
json_blob = await resp.json()
logger.info(f"设置回调结果: {json_blob}")
if json_blob['ret'] != 200:
raise Exception(f"设置回调失败: {json_blob}")
logger.info(f"将在 {self.callback_url} 上接收 gewechat 下发的消息。如果一直没收到消息请先尝试重启 AstrBot。")
logger.info(f"将在 {callback_url} 上接收 gewechat 下发的消息。如果一直没收到消息请先尝试重启 AstrBot。")
async def start_polling(self):
# 设置回调
threading.Thread(target=asyncio.run, args=(self._set_callback_url(),)).start()
await self.server.run_task(
host=self.host,
port=self.port,
@@ -228,8 +186,6 @@ class SimpleGewechatClient():
async def login(self):
if self.token is None:
await self.get_token_id()
self.multimedia_downloader = GeweDownloader(self.base_url, self.download_base_url, self.token)
if self.appid:
online = await self.check_online(self.appid)
@@ -307,39 +263,4 @@ class SimpleGewechatClient():
json=payload
) as resp:
json_blob = await resp.json()
logger.debug(f"发送消息结果: {json_blob}")
async def post_image(self, to_wxid, image_url: str):
payload = {
"appId": self.appid,
"toWxid": to_wxid,
"imgUrl": image_url,
}
async with aiohttp.ClientSession() as session:
async with session.post(
f"{self.base_url}/message/postImage",
headers=self.headers,
json=payload
) as resp:
json_blob = await resp.json()
logger.debug(f"发送图片结果: {json_blob}")
async def post_voice(self, to_wxid, voice_url: str, voice_duration: int):
payload = {
"appId": self.appid,
"toWxid": to_wxid,
"voiceUrl": voice_url,
"voiceDuration": voice_duration
}
logger.debug(f"发送语音: {payload}")
async with aiohttp.ClientSession() as session:
async with session.post(
f"{self.base_url}/message/postVoice",
headers=self.headers,
json=payload
) as resp:
json_blob = await resp.json()
logger.debug(f"发送语音结果: {json_blob}")
logger.info(f"发送消息结果: {json_blob}")

View File

@@ -1,51 +0,0 @@
from astrbot import logger
import aiohttp
import json
class GeweDownloader():
def __init__(self, base_url: str, download_base_url: str, token: str):
self.base_url = base_url
self.download_base_url = download_base_url
self.headers = {
"Content-Type": "application/json",
"X-GEWE-TOKEN": token
}
async def _post_json(self, baseurl: str, route: str, payload: dict):
async with aiohttp.ClientSession() as session:
async with session.post(
f"{baseurl}{route}",
headers=self.headers,
json=payload
) as resp:
return await resp.read()
async def download_voice(self, appid: str, xml: str, msg_id: str):
payload = {
"appId": appid,
"xml": xml,
"msgId": msg_id
}
return await self._post_json(self.base_url, "/message/downloadVoice", payload)
async def download_image(self, appid: str, xml: str) -> str:
'''返回一个可下载的 URL'''
choices = [2, 3] # 2:常规图片 3:缩略图
for choice in choices:
try:
payload = {
"appId": appid,
"xml": xml,
"type": choice
}
data = await self._post_json(self.base_url, "/message/downloadImage", payload)
json_blob = json.loads(data)
if 'fileUrl' in json_blob['data']:
return self.download_base_url + json_blob['data']['fileUrl']
except BaseException as e:
logger.error(f"gewe download image: {e}")
continue
raise Exception("无法下载图片")

View File

@@ -1,24 +1,12 @@
import wave
import uuid
import os
from astrbot.core.utils.io import save_temp_img, download_image_by_url, download_file
from astrbot.core.utils.tencent_record_helper import wav_to_tencent_silk
import random
import asyncio
from astrbot.core.utils.io import download_image_by_url
from astrbot.api import logger
from astrbot.api.event import AstrMessageEvent, MessageChain
from astrbot.api.platform import AstrBotMessage, PlatformMetadata
from astrbot.api.message_components import Plain, Image, Record
from astrbot.api.message_components import Plain, Image
from .client import SimpleGewechatClient
def get_wav_duration(file_path):
with wave.open(file_path, 'rb') as wav_file:
file_size = os.path.getsize(file_path)
n_channels, sampwidth, framerate, n_frames = wav_file.getparams()[:4]
if n_frames == 2147483647:
duration = (file_size - 44) / (n_channels * sampwidth * framerate)
else:
duration = n_frames / float(framerate)
return duration
class GewechatPlatformEvent(AstrMessageEvent):
def __init__(
self,
@@ -46,57 +34,5 @@ class GewechatPlatformEvent(AstrMessageEvent):
for comp in message.chain:
if isinstance(comp, Plain):
await self.client.post_text(to_wxid, comp.text)
elif isinstance(comp, Image):
img_url = comp.file
img_path = ""
if img_url.startswith("file:///"):
img_path = img_url[8:]
elif comp.file and comp.file.startswith("http"):
img_path = await download_image_by_url(comp.file)
else:
img_path = img_url
# 检查 record_path 是否在 data/temp 目录中, record_path 可能是绝对路径
temp_directory = os.path.abspath('data/temp')
img_path = os.path.abspath(img_path)
if os.path.commonpath([temp_directory, img_path]) != temp_directory:
with open(img_path, "rb") as f:
img_path = save_temp_img(f.read())
file_id = os.path.basename(img_path)
img_url = f"{self.client.file_server_url}/{file_id}"
logger.debug(f"gewe callback img url: {img_url}")
await self.client.post_image(to_wxid, img_url)
elif isinstance(comp, Record):
# 默认已经存在 data/temp 中
record_url = comp.file
record_path = ""
if record_url.startswith("file:///"):
record_path = record_url[8:]
elif record_url.startswith("http"):
await download_file(record_url, f"data/temp/{uuid.uuid4()}.wav")
else:
record_path = record_url
silk_path = f"data/temp/{uuid.uuid4()}.silk"
duration = await wav_to_tencent_silk(record_path, silk_path)
print(f"duration: {duration}, {silk_path}")
# 检查 record_path 是否在 data/temp 目录中, record_path 可能是绝对路径
# temp_directory = os.path.abspath('data/temp')
# record_path = os.path.abspath(record_path)
# if os.path.commonpath([temp_directory, record_path]) != temp_directory:
# with open(record_path, "rb") as f:
# record_path = f"data/temp/{uuid.uuid4()}.wav"
# with open(record_path, "wb") as f2:
# f2.write(f.read())
if duration == 0:
duration = get_wav_duration(record_path)
file_id = os.path.basename(silk_path)
record_url = f"{self.client.file_server_url}/{file_id}"
await self.client.post_voice(to_wxid, record_url, duration*1000)
await super().send(message)

View File

@@ -0,0 +1,137 @@
import os
import json
import asyncio
import aiohttp
import time
import traceback
from .miservice import MiAccount, MiNAService, MiIOService, miio_command, miio_command_help
from astrbot.core import logger
from astrbot.api.platform import AstrBotMessage, MessageMember, MessageType
from astrbot.api.message_components import Plain, Image, At
class SimpleMiSpeakerClient():
'''
@author: Soulter
@references: https://github.com/yihong0618/xiaogpt/blob/main/xiaogpt/xiaogpt.py
'''
def __init__(self, config: dict):
self.username = config['username']
self.password = config['password']
self.did = config['did']
self.store = os.path.join("data", '.mi.token')
self.interval = float(config.get('interval', 1))
self.conv_query_cookies = {
'userId': '',
'deviceId': '',
'serviceToken': ''
}
self.MI_CONVERSATION_URL = "https://userprofile.mina.mi.com/device_profile/v2/conversation?source=dialogu&hardware={hardware}&timestamp={timestamp}&limit=1"
self.session = aiohttp.ClientSession()
self.activate_word = config.get('activate_word', '测试')
self.deactivate_word = config.get('deactivate_word', '停止')
self.entered = False
async def initialize(self):
account = MiAccount(self.session, self.username, self.password, self.store)
self.miio_service = MiIOService(account) # 小米设备服务
self.mina_service = MiNAService(account) # 小爱音箱服务
device = await self.get_mina_device()
self.deviceID = device['deviceID']
self.hardware = device['hardware']
with open(self.store, 'r') as f:
data = json.load(f)
self.userId = data['userId']
self.serviceToken = data['micoapi'][1]
self.conv_query_cookies['userId'] = self.userId
self.conv_query_cookies['deviceId'] = self.deviceID
self.conv_query_cookies['serviceToken'] = self.serviceToken
logger.info(f"MiSpeakerClient initialized. Conv cookies: {self.conv_query_cookies}. Hardware: {self.hardware}")
async def get_mina_device(self) -> dict:
devices = await self.mina_service.device_list()
for device in devices:
if device['miotDID'] == self.did:
logger.info(f"找到设备 {device['alias']}({device['name']}) 了!")
return device
async def get_conv(self) -> str:
# 时区请确保为北京时间
async with aiohttp.ClientSession() as session:
session.cookie_jar.update_cookies(self.conv_query_cookies)
query_ts = int(time.time())*1000
logger.debug(f"Querying conversation at {query_ts}")
async with session.get(self.MI_CONVERSATION_URL.format(hardware=self.hardware, timestamp=str(query_ts))) as resp:
json_blob = await resp.json()
if json_blob['code'] == 0:
data = json.loads(json_blob['data'])
records = data.get('records', None)
for record in records:
if record['time'] >= query_ts - self.interval*1000:
return record['query']
else:
logger.error(f"Failed to get conversation: {json_blob}")
return None
async def start_pooling(self):
while True:
await asyncio.sleep(self.interval)
try:
query = await self.get_conv()
if not query:
continue
# is wake
if query == self.activate_word:
self.entered = True
await self.stop_playing()
await self.send("我来啦!")
continue
elif query == self.deactivate_word:
self.entered = False
await self.stop_playing()
await self.send("再见,欢迎给个 Star。")
continue
if not self.entered:
continue
await self.send("")
abm = await self._convert(query)
if abm:
coro = getattr(self, "on_event_received")
if coro:
await coro(abm)
except BaseException as e:
traceback.print_exc()
logger.error(e)
async def _convert(self, query: str):
abm = AstrBotMessage()
abm.message = [Plain(query)]
abm.message_id = str(int(time.time()))
abm.message_str = query
abm.raw_message = query
abm.session_id = f"{self.hardware}_{self.did}_{self.username}"
abm.sender = MessageMember(self.username, "主人")
abm.self_id = f"{self.hardware}_{self.did}"
abm.type = MessageType.FRIEND_MESSAGE
return abm
async def send(self, message: str):
text = f'5 {message}'
await miio_command(self.miio_service, self.did, text, 'astrbot')
async def stop_playing(self):
text = f'3-2'
await miio_command(self.miio_service, self.did, text, 'astrbot')

View File

@@ -0,0 +1,21 @@
MIT License
Copyright (c) 2021-2022 Yonsm
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.

View File

@@ -0,0 +1,5 @@
from .miaccount import MiAccount, MiTokenStore
from .minaservice import MiNAService
from .miioservice import MiIOService
from .miiocommand import miio_command, miio_command_help

View File

@@ -0,0 +1,135 @@
import base64
import hashlib
import json
import logging
import os
import random
import string
from urllib import parse
from aiohttp import ClientSession
from aiofiles import open as async_open
_LOGGER = logging.getLogger(__package__)
def get_random(length):
return ''.join(random.sample(string.ascii_letters + string.digits, length))
class MiTokenStore:
def __init__(self, token_path):
self.token_path = token_path
async def load_token(self):
if os.path.isfile(self.token_path):
try:
async with async_open(self.token_path) as f:
return json.loads(await f.read())
except Exception as e:
_LOGGER.exception("Exception on load token from %s: %s", self.token_path, e)
return None
async def save_token(self, token=None):
if token:
try:
async with async_open(self.token_path, 'w') as f:
await f.write(json.dumps(token, indent=2))
except Exception as e:
_LOGGER.exception("Exception on save token to %s: %s", self.token_path, e)
elif os.path.isfile(self.token_path):
os.remove(self.token_path)
class MiAccount:
def __init__(self, session: ClientSession, username, password, token_store='.mi.token'):
self.session = session
self.username = username
self.password = password
self.token_store = MiTokenStore(token_store) if isinstance(token_store, str) else token_store
self.token = None
async def login(self, sid):
if not self.token:
self.token = {'deviceId': get_random(16).upper()}
try:
resp = await self._serviceLogin(f'serviceLogin?sid={sid}&_json=true')
if resp['code'] != 0:
data = {
'_json': 'true',
'qs': resp['qs'],
'sid': resp['sid'],
'_sign': resp['_sign'],
'callback': resp['callback'],
'user': self.username,
'hash': hashlib.md5(self.password.encode()).hexdigest().upper()
}
resp = await self._serviceLogin('serviceLoginAuth2', data)
if resp['code'] != 0:
raise Exception(resp)
self.token['userId'] = resp['userId']
self.token['passToken'] = resp['passToken']
serviceToken = await self._securityTokenService(resp['location'], resp['nonce'], resp['ssecurity'])
self.token[sid] = (resp['ssecurity'], serviceToken)
if self.token_store:
await self.token_store.save_token(self.token)
return True
except Exception as e:
self.token = None
if self.token_store:
await self.token_store.save_token()
_LOGGER.exception("Exception on login %s: %s", self.username, e)
return False
async def _serviceLogin(self, uri, data=None):
headers = {'User-Agent': 'APP/com.xiaomi.mihome APPV/6.0.103 iosPassportSDK/3.9.0 iOS/14.4 miHSTS'}
cookies = {'sdkVersion': '3.9', 'deviceId': self.token['deviceId']}
if 'passToken' in self.token:
cookies['userId'] = self.token['userId']
cookies['passToken'] = self.token['passToken']
url = 'https://account.xiaomi.com/pass/' + uri
async with self.session.request('GET' if data is None else 'POST', url, data=data, cookies=cookies, headers=headers) as r:
raw = await r.read()
resp = json.loads(raw[11:])
_LOGGER.debug("%s: %s", uri, resp)
return resp
async def _securityTokenService(self, location, nonce, ssecurity):
nsec = 'nonce=' + str(nonce) + '&' + ssecurity
clientSign = base64.b64encode(hashlib.sha1(nsec.encode()).digest()).decode()
async with self.session.get(location + '&clientSign=' + parse.quote(clientSign)) as r:
serviceToken = r.cookies['serviceToken'].value
if not serviceToken:
raise Exception(await r.text())
return serviceToken
async def mi_request(self, sid, url, data, headers, relogin=True):
if self.token is None and self.token_store is not None:
self.token = await self.token_store.load_token()
if (self.token and sid in self.token) or await self.login(sid): # Ensure login
cookies = {'userId': self.token['userId'], 'serviceToken': self.token[sid][1]}
content = data(self.token, cookies) if callable(data) else data
method = 'GET' if data is None else 'POST'
_LOGGER.debug("%s %s", url, content)
async with self.session.request(method, url, data=content, cookies=cookies, headers=headers) as r:
status = r.status
if status == 200:
resp = await r.json(content_type=None)
code = resp['code']
if code == 0:
return resp
if 'auth' in resp.get('message', '').lower():
status = 401
else:
resp = await r.text()
if status == 401 and relogin:
_LOGGER.warn("Auth error on request %s %s, relogin...", url, resp)
self.token = None # Auth error, reset login
return await self.mi_request(sid, url, data, headers, False)
else:
resp = "Login failed"
raise Exception(f"Error {url}: {resp}")

View File

@@ -0,0 +1,104 @@
import json
from .miioservice import MiIOService
def twins_split(string, sep, default=None):
pos = string.find(sep)
return (string, default) if pos == -1 else (string[0:pos], string[pos+1:])
def string_to_value(string):
if string[0] in '"\'#':
return string[1:-1] if string[-1] in '"\'#' else string[1:]
elif string == 'null':
return None
elif string == 'false':
return False
elif string == 'true':
return True
elif string.isdigit():
return int(string)
try:
return float(string)
except:
return string
def miio_command_help(did=None, prefix='?'):
quote = '' if prefix == '?' else "'"
return f'\
Get Props: {prefix}<siid[-piid]>[,...]\n\
{prefix}1,1-2,1-3,1-4,2-1,2-2,3\n\
Set Props: {prefix}<siid[-piid]=[#]value>[,...]\n\
{prefix}2=60,2-1=#60,2-2=false,2-3="null",3=test\n\
Do Action: {prefix}<siid[-piid]> <arg1|[]> [...] \n\
{prefix}2 []\n\
{prefix}5 Hello\n\
{prefix}5-4 Hello 1\n\n\
Call MIoT: {prefix}<cmd=prop/get|/prop/set|action> <params>\n\
{prefix}action {quote}{{"did":"{did or "267090026"}","siid":5,"aiid":1,"in":["Hello"]}}{quote}\n\n\
Call MiIO: {prefix}/<uri> <data>\n\
{prefix}/home/device_list {quote}{{"getVirtualModel":false,"getHuamiDevices":1}}{quote}\n\n\
Devs List: {prefix}list [name=full|name_keyword] [getVirtualModel=false|true] [getHuamiDevices=0|1]\n\
{prefix}list Light true 0\n\n\
MIoT Spec: {prefix}spec [model_keyword|type_urn] [format=text|python|json]\n\
{prefix}spec\n\
{prefix}spec speaker\n\
{prefix}spec xiaomi.wifispeaker.lx04\n\
{prefix}spec urn:miot-spec-v2:device:speaker:0000A015:xiaomi-lx04:1\n\n\
MIoT Decode: {prefix}decode <ssecurity> <nonce> <data> [gzip]\n\
'
async def miio_command(service: MiIOService, did, text, prefix='?'):
cmd, arg = twins_split(text, ' ')
if cmd.startswith('/'):
return await service.miio_request(cmd, arg)
if cmd.startswith('prop') or cmd == 'action':
return await service.miot_request(cmd, json.loads(arg) if arg else None)
argv = arg.split(' ') if arg else []
argc = len(argv)
if cmd == 'list':
return await service.device_list(argc > 0 and argv[0], argc > 1 and string_to_value(argv[1]), argc > 2 and argv[2])
if cmd == 'spec':
return await service.miot_spec(argc > 0 and argv[0], argc > 1 and argv[1])
if cmd == 'decode':
return MiIOService.miot_decode(argv[0], argv[1], argv[2], argc > 3 and argv[3] == 'gzip')
if not did or not cmd or cmd == '?' or cmd == '' or cmd == 'help' or cmd == '-h' or cmd == '--help':
return miio_command_help(did, prefix)
if not did.isdigit():
devices = await service.device_list(did)
if not devices:
return "Device not found: " + did
did = devices[0]['did']
props = []
setp = True
miot = True
for item in cmd.split(','):
key, value = twins_split(item, '=')
siid, iid = twins_split(key, '-', '1')
if siid.isdigit() and iid.isdigit():
prop = [int(siid), int(iid)]
else:
prop = [key]
miot = False
if value is None:
setp = False
elif setp:
prop.append(string_to_value(value))
props.append(prop)
if miot and argc > 0:
args = [] if arg == '[]' else [string_to_value(a) for a in argv]
return await service.miot_action(did, props[0], args)
do_props = ((service.home_get_props, service.miot_get_props), (service.home_set_props, service.miot_set_props))[setp][miot]
return await do_props(did, props if miot or setp else [p[0] for p in props])

View File

@@ -0,0 +1,197 @@
import os
import time
import base64
import hashlib
import hmac
import json
# REGIONS = ['cn', 'de', 'i2', 'ru', 'sg', 'us']
class MiIOService:
def __init__(self, account=None, region=None):
self.account = account
self.server = 'https://' + ('' if region is None or region == 'cn' else region + '.') + 'api.io.mi.com/app'
async def miio_request(self, uri, data):
def prepare_data(token, cookies):
cookies['PassportDeviceId'] = token['deviceId']
return MiIOService.sign_data(uri, data, token['xiaomiio'][0])
headers = {'User-Agent': 'iOS-14.4-6.0.103-iPhone12,3--D7744744F7AF32F0544445285880DD63E47D9BE9-8816080-84A3F44E137B71AE-iPhone', 'x-xiaomi-protocal-flag-cli': 'PROTOCAL-HTTP2'}
resp = await self.account.mi_request('xiaomiio', self.server + uri, prepare_data, headers)
if 'result' not in resp:
raise Exception(f"Error {uri}: {resp}")
return resp['result']
async def home_request(self, did, method, params):
return await self.miio_request('/home/rpc/' + did, {'id': 1, 'method': method, "accessKey": "IOS00026747c5acafc2", 'params': params})
async def home_get_props(self, did, props):
return await self.home_request(did, 'get_prop', props)
async def home_set_props(self, did, props):
return [await self.home_set_prop(did, i[0], i[1]) for i in props]
async def home_get_prop(self, did, prop):
return (await self.home_get_props(did, [prop]))[0]
async def home_set_prop(self, did, prop, value):
result = (await self.home_request(did, 'set_' + prop, value if isinstance(value, list) else [value]))[0]
return 0 if result == 'ok' else result
async def miot_request(self, cmd, params):
return await self.miio_request('/miotspec/' + cmd, {'params': params})
async def miot_get_props(self, did, iids):
params = [{'did': did, 'siid': i[0], 'piid': i[1]} for i in iids]
result = await self.miot_request('prop/get', params)
return [it.get('value') if it.get('code') == 0 else None for it in result]
async def miot_set_props(self, did, props):
params = [{'did': did, 'siid': i[0], 'piid': i[1], 'value': i[2]} for i in props]
result = await self.miot_request('prop/set', params)
return [it.get('code', -1) for it in result]
async def miot_get_prop(self, did, iid):
return (await self.miot_get_props(did, [iid]))[0]
async def miot_set_prop(self, did, iid, value):
return (await self.miot_set_props(did, [(iid[0], iid[1], value)]))[0]
async def miot_action(self, did, iid, args=[]):
result = await self.miot_request('action', {'did': did, 'siid': iid[0], 'aiid': iid[1], 'in': args})
return result.get('code', -1)
async def device_list(self, name=None, getVirtualModel=False, getHuamiDevices=0):
result = await self.miio_request('/home/device_list', {'getVirtualModel': bool(getVirtualModel), 'getHuamiDevices': int(getHuamiDevices)})
result = result['list']
return result if name == 'full' else [{'name': i['name'], 'model': i['model'], 'did': i['did'], 'token': i['token']} for i in result if not name or name in i['name']]
async def miot_spec(self, type=None, format=None):
if not type or not type.startswith('urn'):
def get_spec(all):
if not type:
return all
ret = {}
for m, t in all.items():
if type == m:
return {m: t}
elif type in m:
ret[m] = t
return ret
import tempfile
path = os.path.join(tempfile.gettempdir(), 'miservice_miot_specs.json')
try:
with open(path) as f:
result = get_spec(json.load(f))
except:
result = None
if not result:
async with self.account.session.get('http://miot-spec.org/miot-spec-v2/instances?status=all') as r:
all = {i['model']: i['type'] for i in (await r.json())['instances']}
with open(path, 'w') as f:
json.dump(all, f)
result = get_spec(all)
if len(result) != 1:
return result
type = list(result.values())[0]
url = 'http://miot-spec.org/miot-spec-v2/instance?type=' + type
async with self.account.session.get(url) as r:
result = await r.json()
def parse_desc(node):
desc = node['description']
# pos = desc.find(' ')
# if pos != -1:
# return (desc[:pos], ' # ' + desc[pos + 2:])
name = ''
for i in range(len(desc)):
d = desc[i]
if d in '-—{「[【(<《':
return (name, ' # ' + desc[i:])
name += '_' if d == ' ' else d
return (name, '')
def make_line(siid, iid, desc, comment, readable=False):
value = f"({siid}, {iid})" if format == 'python' else iid
return f" {'' if readable else '_'}{desc} = {value}{comment}\n"
if format != 'json':
STR_HEAD, STR_SRV, STR_VALUE = ('from enum import Enum\n\n', '\nclass {}(tuple, Enum):\n', '\nclass {}(int, Enum):\n') if format == 'python' else ('', '{} = {}\n', '{}\n')
text = '# Generated by https://github.com/Yonsm/MiService\n# ' + url + '\n\n' + STR_HEAD
svcs = []
vals = []
for s in result['services']:
siid = s['iid']
svc = s['description'].replace(' ', '_')
svcs.append(svc)
text += STR_SRV.format(svc, siid)
for p in s.get('properties', []):
name, comment = parse_desc(p)
access = p['access']
comment += ''.join([' # ' + k for k, v in [(p['format'], 'string'), (''.join([a[0] for a in access]), 'r')] if k and k != v])
text += make_line(siid, p['iid'], name, comment, 'read' in access)
if 'value-range' in p:
valuer = p['value-range']
length = min(3, len(valuer))
values = {['MIN', 'MAX', 'STEP'][i]: valuer[i] for i in range(length) if i != 2 or valuer[i] != 1}
elif 'value-list' in p:
values = {i['description'].replace(' ', '_') if i['description'] else str(i['value']): i['value'] for i in p['value-list']}
else:
continue
vals.append((svc + '_' + name, values))
if 'actions' in s:
text += '\n'
for a in s['actions']:
name, comment = parse_desc(a)
comment += ''.join([f" # {io}={a[io]}" for io in ['in', 'out'] if a[io]])
text += make_line(siid, a['iid'], name, comment)
text += '\n'
for name, values in vals:
text += STR_VALUE.format(name)
for k, v in values.items():
text += f" {'_' + k if k.isdigit() else k} = {v}\n"
text += '\n'
if format == 'python':
text += '\nALL_SVCS = (' + ', '.join(svcs) + ')\n'
result = text
return result
@staticmethod
def miot_decode(ssecurity, nonce, data, gzip=False):
from Crypto.Cipher import ARC4
r = ARC4.new(base64.b64decode(MiIOService.sign_nonce(ssecurity, nonce)))
r.encrypt(bytes(1024))
decrypted = r.encrypt(base64.b64decode(data))
if gzip:
try:
from io import BytesIO
from gzip import GzipFile
compressed = BytesIO()
compressed.write(decrypted)
compressed.seek(0)
decrypted = GzipFile(fileobj=compressed, mode='rb').read()
except:
pass
return json.loads(decrypted.decode())
@staticmethod
def sign_nonce(ssecurity, nonce):
m = hashlib.sha256()
m.update(base64.b64decode(ssecurity))
m.update(base64.b64decode(nonce))
return base64.b64encode(m.digest()).decode()
@staticmethod
def sign_data(uri, data, ssecurity):
if not isinstance(data, str):
data = json.dumps(data)
nonce = base64.b64encode(os.urandom(8) + int(time.time() / 60).to_bytes(4, 'big')).decode()
snonce = MiIOService.sign_nonce(ssecurity, nonce)
msg = '&'.join([uri, snonce, nonce, 'data=' + data])
sign = hmac.new(key=base64.b64decode(snonce), msg=msg.encode(), digestmod=hashlib.sha256).digest()
return {'_nonce': nonce, 'data': data, 'signature': base64.b64encode(sign).decode()}

View File

@@ -0,0 +1,50 @@
import json
from .miaccount import MiAccount, get_random
import logging
_LOGGER = logging.getLogger(__package__)
class MiNAService:
def __init__(self, account: MiAccount):
self.account = account
async def mina_request(self, uri, data=None):
requestId = 'app_ios_' + get_random(30)
if data is not None:
data['requestId'] = requestId
else:
uri += '&requestId=' + requestId
headers = {'User-Agent': 'MiHome/6.0.103 (com.xiaomi.mihome; build:6.0.103.1; iOS 14.4.0) Alamofire/6.0.103 MICO/iOSApp/appStore/6.0.103'}
return await self.account.mi_request('micoapi', 'https://api2.mina.mi.com' + uri, data, headers)
async def device_list(self, master=0):
result = await self.mina_request('/admin/v2/device_list?master=' + str(master))
return result.get('data') if result else None
async def ubus_request(self, deviceId, method, path, message):
message = json.dumps(message)
result = await self.mina_request('/remote/ubus', {'deviceId': deviceId, 'message': message, 'method': method, 'path': path})
return result and result.get('code') == 0
async def text_to_speech(self, deviceId, text):
return await self.ubus_request(deviceId, 'text_to_speech', 'mibrain', {'text': text})
async def player_set_volume(self, deviceId, volume):
return await self.ubus_request(deviceId, 'player_set_volume', 'mediaplayer', {'volume': volume, 'media': 'app_ios'})
async def send_message(self, devices, devno, message, volume=None): # -1/0/1...
result = False
for i in range(0, len(devices)):
if devno == -1 or devno != i + 1 or devices[i]['capabilities'].get('yunduantts'):
_LOGGER.debug("Send to devno=%d index=%d: %s", devno, i, message or volume)
deviceId = devices[i]['deviceID']
result = True if volume is None else await self.player_set_volume(deviceId, volume)
if result and message:
result = await self.text_to_speech(deviceId, message)
if not result:
_LOGGER.error("Send failed: %s", message or volume)
if devno != -1 or not result:
break
return result

View File

@@ -0,0 +1,63 @@
import logging
import time
import asyncio
import os
from astrbot.api.platform import Platform, AstrBotMessage, MessageMember, MessageType, PlatformMetadata
from astrbot.api.event import MessageChain
from typing import Union, List
from astrbot.api.message_components import Image, Plain, At
from astrbot.core.platform.astr_message_event import MessageSesion
from ...register import register_platform_adapter
from astrbot.core.message.components import BaseMessageComponent
from .client import SimpleMiSpeakerClient
from .mispeaker_event import MiSpeakerPlatformEvent
from astrbot.core import logger
@register_platform_adapter("mispeaker", "小爱音箱")
class MiSpeakerPlatformAdapter(Platform):
def __init__(self, platform_config: dict, platform_settings: dict, event_queue: asyncio.Queue) -> None:
super().__init__(event_queue)
self.config = platform_config
async def send_by_session(self, session: MessageSesion, message_chain: MessageChain):
pass
def meta(self) -> PlatformMetadata:
return PlatformMetadata(
"mispeaker",
"小爱音箱",
)
async def handle_msg(self, message: AstrBotMessage):
message_event = MiSpeakerPlatformEvent(
message_str=message.message_str,
message_obj=message,
platform_meta=self.meta(),
session_id=message.session_id,
client=self.client
)
self.commit_event(message_event)
def run(self):
self.client = SimpleMiSpeakerClient(
self.config
)
async def on_event_received(abm: AstrBotMessage):
logger.info(f"on_event_received: {abm}")
await self.handle_msg(abm)
self.client.on_event_received = on_event_received
return self._run()
async def _run(self):
await self.client.initialize()
await self.client.start_pooling()

View File

@@ -0,0 +1,30 @@
import random
import asyncio
from astrbot.api import logger
from astrbot.api.event import AstrMessageEvent, MessageChain
from astrbot.api.platform import AstrBotMessage, PlatformMetadata
from astrbot.api.message_components import Plain, Image
from .client import SimpleMiSpeakerClient
class MiSpeakerPlatformEvent(AstrMessageEvent):
def __init__(
self,
message_str: str,
message_obj: AstrBotMessage,
platform_meta: PlatformMetadata,
session_id: str,
client: SimpleMiSpeakerClient
):
super().__init__(message_str, message_obj, platform_meta, session_id)
self.client = client
@staticmethod
async def send_with_client(message: MessageChain, user_name: str):
pass
async def send(self, message: MessageChain):
for comp in message.chain:
if isinstance(comp, Plain):
await self.client.send(comp.text)
await super().send(message)

View File

@@ -14,20 +14,12 @@ class QQOfficialMessageEvent(AstrMessageEvent):
def __init__(self, message_str: str, message_obj: AstrBotMessage, platform_meta: PlatformMetadata, session_id: str, bot: Client):
super().__init__(message_str, message_obj, platform_meta, session_id)
self.bot = bot
self.send_buffer = None
async def send(self, message: MessageChain):
if not self.send_buffer:
self.send_buffer = message
else:
self.send_buffer.chain.extend(message.chain)
async def _post_send(self):
'''QQ 官方 API 仅支持回复一次'''
source = self.message_obj.raw_message
assert isinstance(source, (botpy.message.Message, botpy.message.GroupMessage, botpy.message.DirectMessage, botpy.message.C2CMessage))
plain_text, image_base64, image_path = await QQOfficialMessageEvent._parse_to_qqofficial(self.send_buffer)
plain_text, image_base64, image_path = await QQOfficialMessageEvent._parse_to_qqofficial(message)
payload = {
'content': plain_text,
@@ -56,9 +48,7 @@ class QQOfficialMessageEvent(AstrMessageEvent):
payload['file_image'] = image_path
await self.bot.api.post_dms(guild_id=source.guild_id, **payload)
await super().send(self.send_buffer)
self.send_buffer = None
await super().send(message)
async def upload_group_and_c2c_image(self, image_base64: str, file_type: int, **kwargs) -> botpy.types.message.Media:
payload = {
@@ -90,7 +80,4 @@ class QQOfficialMessageEvent(AstrMessageEvent):
elif i.file and i.file.startswith("http"):
image_file_path = await download_image_by_url(i.file)
image_base64 = file_to_base64(image_file_path).replace("base64://", "")
else:
image_base64 = file_to_base64(i.file).replace("base64://", "")
image_file_path = i.file
return plain_text, image_base64, image_file_path

View File

@@ -32,10 +32,6 @@ class WebChatMessageEvent(AstrMessageEvent):
f.write(f2.read())
elif comp.file and comp.file.startswith("http"):
await download_image_by_url(comp.file, path=path)
else:
with open(path, "wb") as f:
with open(comp.file, "rb") as f2:
f.write(f2.read())
web_chat_back_queue.put_nowait((f"[IMAGE]{filename}", cid))
web_chat_back_queue.put_nowait(None)
await super().send(message)

View File

@@ -1,6 +1,6 @@
import traceback
from astrbot.core.config.astrbot_config import AstrBotConfig
from .provider import Provider, STTProvider, TTSProvider, Personality
from .provider import Provider, STTProvider, Personality
from .entites import ProviderType
from typing import List
from astrbot.core.db import BaseDatabase
@@ -64,15 +64,11 @@ class ProviderManager():
'''加载的 Provider 的实例'''
self.stt_provider_insts: List[STTProvider] = []
'''加载的 Speech To Text Provider 的实例'''
self.tts_provider_insts: List[TTSProvider] = []
'''加载的 Text To Speech Provider 的实例'''
self.llm_tools = llm_tools
self.curr_provider_inst: Provider = None
'''当前使用的 Provider 实例'''
self.curr_stt_provider_inst: STTProvider = None
'''当前使用的 Speech To Text Provider 实例'''
self.curr_tts_provider_inst: TTSProvider = None
'''当前使用的 Text To Speech Provider 实例'''
self.loaded_ids = defaultdict(bool)
self.db_helper = db_helper
@@ -107,8 +103,6 @@ class ProviderManager():
from .sources.whisper_api_source import ProviderOpenAIWhisperAPI # noqa: F401
case "openai_whisper_selfhost":
from .sources.whisper_selfhosted_source import ProviderOpenAIWhisperSelfHost # noqa: F401
case "openai_tts_api":
from .sources.openai_tts_api_source import ProviderOpenAITTSAPI # noqa: F401
except (ImportError, ModuleNotFoundError) as e:
logger.critical(f"加载 {provider_cfg['type']}({provider_cfg['id']}) 提供商适配器失败:{e}。可能是因为有未安装的依赖。")
continue
@@ -117,21 +111,17 @@ class ProviderManager():
continue
async def initialize(self):
selected_provider_id = sp.get("curr_provider")
selected_stt_provider_id = self.provider_stt_settings.get("provider_id")
selected_tts_provider_id = self.provider_settings.get("provider_id")
provider_enabled = self.provider_settings.get("enable", False)
stt_enabled = self.provider_stt_settings.get("enable", False)
tts_enabled = self.provider_settings.get("enable", False)
for provider_config in self.providers_config:
if not provider_config['enable']:
continue
if provider_config['type'] not in provider_cls_map:
logger.error(f"未找到适用于 {provider_config['type']}({provider_config['id']}) 的提供商适配器,请检查是否已经安装或者名称填写错误。已跳过。")
continue
selected_provider_id = sp.get("curr_provider")
selected_stt_provider_id = self.provider_stt_settings.get("provider_id")
provider_enabled = self.provider_settings.get("enable", False)
stt_enabled = self.provider_stt_settings.get("enable", False)
provider_metadata = provider_cls_map[provider_config['type']]
logger.info(f"尝试实例化 {provider_config['type']}({provider_config['id']}) 提供商适配器 ...")
try:
@@ -148,18 +138,6 @@ class ProviderManager():
if selected_stt_provider_id == provider_config['id'] and stt_enabled:
self.curr_stt_provider_inst = inst
logger.info(f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前语音转文本提供商适配器。")
elif provider_metadata.provider_type == ProviderType.TEXT_TO_SPEECH:
# TTS 任务
inst = provider_metadata.cls_type(provider_config, self.provider_settings)
if getattr(inst, "initialize", None):
await inst.initialize()
self.tts_provider_insts.append(inst)
if selected_tts_provider_id == provider_config['id'] and tts_enabled:
self.curr_tts_provider_inst = inst
logger.info(f"已选择 {provider_config['type']}({provider_config['id']}) 作为当前文本转语音提供商适配器。")
elif provider_metadata.provider_type == ProviderType.CHAT_COMPLETION:
# 文本生成任务
@@ -189,18 +167,11 @@ class ProviderManager():
if len(self.stt_provider_insts) > 0 and not self.curr_stt_provider_inst and stt_enabled:
self.curr_stt_provider_inst = self.stt_provider_insts[0]
if len(self.tts_provider_insts) > 0 and not self.curr_tts_provider_inst and tts_enabled:
self.curr_tts_provider_inst = self.tts_provider_insts[0]
if not self.curr_provider_inst:
logger.warning("未启用任何用于 文本生成 的提供商适配器。")
if stt_enabled and not self.curr_stt_provider_inst:
if self.provider_stt_settings.get("enable"):
if not self.curr_stt_provider_inst:
logger.warning("未启用任何用于 语音转文本 的提供商适配器。")
if tts_enabled and not self.curr_tts_provider_inst:
logger.warning("未启用任何用于 文本转语音 的提供商适配器。")
def get_insts(self):
return self.provider_insts

View File

@@ -24,32 +24,9 @@ class ProviderMeta():
id: str
model: str
type: str
class AbstractProvider(abc.ABC):
def __init__(self, provider_config: dict) -> None:
super().__init__()
self.model_name = ""
self.provider_config = provider_config
def set_model(self, model_name: str):
'''设置当前使用的模型名称'''
self.model_name = model_name
def get_model(self) -> str:
'''获得当前使用的模型名称'''
return self.model_name
def meta(self) -> ProviderMeta:
'''获取 Provider 的元数据'''
return ProviderMeta(
id=self.provider_config['id'],
model=self.get_model(),
type=self.provider_config['type']
)
class Provider(AbstractProvider):
class Provider(abc.ABC):
def __init__(
self,
provider_config: dict,
@@ -58,11 +35,14 @@ class Provider(AbstractProvider):
db_helper: BaseDatabase = None,
default_persona: Personality = None
) -> None:
super().__init__(provider_config)
self.model_name = ""
'''当前使用的模型名称'''
self.session_memory = defaultdict(list)
'''维护了 session_id 的上下文,**不包含 system 指令**。'''
self.provider_config = provider_config
self.provider_settings = provider_settings
self.curr_personality: Personality = default_persona
@@ -78,6 +58,14 @@ class Provider(AbstractProvider):
self.session_memory[history.session_id] = json.loads(history.content)
except BaseException as e:
logger.warning(f"读取 LLM 对话历史记录 失败:{e}。仍可正常使用。")
def set_model(self, model_name: str):
'''设置当前使用的模型名称'''
self.model_name = model_name
def get_model(self) -> str:
'''获得当前使用的模型名称'''
return self.model_name
@abc.abstractmethod
def get_current_key(self) -> str:
@@ -145,11 +133,17 @@ class Provider(AbstractProvider):
'''重置某一个 session_id 的上下文'''
raise NotImplementedError()
def meta(self) -> ProviderMeta:
'''获取 Provider 的元数据'''
return ProviderMeta(
id=self.provider_config['id'],
model=self.get_model(),
type=self.provider_config['type']
)
class STTProvider(AbstractProvider):
class STTProvider():
def __init__(self, provider_config: dict, provider_settings: dict) -> None:
super().__init__(provider_config)
self.provider_config = provider_config
self.provider_settings = provider_settings
@@ -157,15 +151,19 @@ class STTProvider(AbstractProvider):
async def get_text(self, audio_url: str) -> str:
'''获取音频的文本'''
raise NotImplementedError()
class TTSProvider(AbstractProvider):
def __init__(self, provider_config: dict, provider_settings: dict) -> None:
super().__init__(provider_config)
self.provider_config = provider_config
self.provider_settings = provider_settings
@abc.abstractmethod
async def get_audio(self, text: str) -> str:
'''获取文本的音频,返回音频文件路径'''
raise NotImplementedError()
def set_model(self, model_name: str):
'''设置当前使用的模型名称'''
self.model_name = model_name
def get_model(self) -> str:
'''获取当前使用的模型'''
return self.provider_config.get("model", "")
def meta(self) -> ProviderMeta:
'''获取 Provider 的元数据'''
return ProviderMeta(
id=self.provider_config['id'],
model=self.get_model(),
type=self.provider_config['type']
)

View File

@@ -1,3 +1,4 @@
import traceback
import base64
import json
@@ -109,7 +110,7 @@ class ProviderOpenAIOfficial(Provider):
)
assert isinstance(completion, ChatCompletion)
logger.debug(f"completion: {completion}")
logger.debug(f"completion: {completion.usage}")
if len(completion.choices) == 0:
raise Exception("API 返回的 completion 为空。")

View File

@@ -1,40 +0,0 @@
import uuid
import os
from openai import AsyncOpenAI, NOT_GIVEN
from ..provider import TTSProvider
from ..entites import ProviderType
from ..register import register_provider_adapter
@register_provider_adapter("openai_tts_api", "OpenAI TTS API", provider_type=ProviderType.TEXT_TO_SPEECH)
class ProviderOpenAITTSAPI(TTSProvider):
def __init__(
self,
provider_config: dict,
provider_settings: dict,
) -> None:
super().__init__(provider_config, provider_settings)
self.chosen_api_key = provider_config.get("api_key", "")
self.voice = provider_config.get("voice", "alloy")
self.client = AsyncOpenAI(
api_key=self.chosen_api_key,
base_url=provider_config.get("api_base", None),
timeout=provider_config.get("timeout", NOT_GIVEN),
)
self.set_model(provider_config.get("model", None))
async def get_audio(self, text: str) -> str:
path = f'data/temp/openai_tts_api_{uuid.uuid4()}.wav'
async with self.client.audio.speech.with_streaming_response.create(
model=self.model_name,
voice=self.voice,
response_format='wav',
input=text
) as response:
with open(path, 'wb') as f:
async for chunk in response.iter_bytes(chunk_size=1024):
f.write(chunk)
return path

View File

@@ -6,8 +6,6 @@ import time
import aiohttp
import base64
import zipfile
import uuid
from typing import Union
from PIL import Image
@@ -43,21 +41,21 @@ def port_checker(port: int, host: str = "localhost"):
return False
def save_temp_img(img: Union[Image.Image, str]) -> str:
def save_temp_img(img: Image) -> str:
os.makedirs("data/temp", exist_ok=True)
# 获得文件创建时间,清除超过 12 小时的
# 获得文件创建时间,清除超过1小时的
try:
for f in os.listdir("data/temp"):
path = os.path.join("data/temp", f)
if os.path.isfile(path):
ctime = os.path.getctime(path)
if time.time() - ctime > 3600*12:
if time.time() - ctime > 3600:
os.remove(path)
except Exception as e:
print(f"清除临时文件失败: {e}")
# 获得时间戳
timestamp = f"{int(time.time())}_{uuid.uuid4().hex[:8]}"
timestamp = int(time.time())
p = f"data/temp/{timestamp}.jpg"
if isinstance(img, Image.Image):

View File

@@ -20,23 +20,18 @@ async def tencent_silk_to_wav(silk_path: str, output_path: str) -> str:
return output_path
async def wav_to_tencent_silk(wav_path: str, output_path: str) -> int:
'''返回 duration'''
async def wav_to_tencent_silk(wav_path: str) -> BytesIO:
import pysilk
with wave.open(wav_path, 'rb') as wav:
wav_data = wav.readframes(wav.getnframes())
wav_data = BytesIO(wav_data)
output_io = BytesIO()
pysilk.encode(wav_data, output_io, 24000, 24000)
pysilk.encode(wav_data, output_io, 24000)
output_io.seek(0)
# 在首字节添加 \x02,去除结尾的\xff\xff
# 在首字节添加 \x02
silk_data = output_io.read()
silk_data_with_prefix = b'\x02' + silk_data[:-2]
silk_data_with_prefix = b'\x02' + silk_data
# return BytesIO(silk_data_with_prefix)
with open(output_path, "wb") as f:
f.write(silk_data_with_prefix)
return 0
return BytesIO(silk_data_with_prefix)

View File

@@ -63,13 +63,6 @@ class UpdateRoute(Route):
await download_dashboard()
except Exception as e:
logger.error(f"下载管理面板文件失败: {e}")
# pip 更新依赖
logger.info("更新依赖中...")
try:
pip_installer.install(requirements_path="requirements.txt")
except Exception as e:
logger.error(f"更新依赖失败: {e}")
if reboot:
# threading.Thread(target=self.astrbot_updator._reboot, args=(2, )).start()

View File

@@ -1,6 +0,0 @@
# What's Changed
- Gewechat 微信支持图片、语音的收和发
- 支持 OpenAI TTS文字转语音
- 支持路径映射,解决 docker 部署时两端文件系统不一致导致的富媒体文件路径不存在问题
- Napcat 下语音消息可能接收异常

View File

@@ -1,4 +0,0 @@
# What's Changed
- 修复 astrbot_updator 属性缺失与stt_enabled 未初始化 #252
- 支持消息分段回复