"""MIT License

Copyright (c) 2021 Lxns-Network

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
"""

import asyncio
import base64
import json
import os
import uuid
from enum import Enum

from pydantic.v1 import BaseModel

from astrbot.core import astrbot_config, file_token_service, logger
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
from astrbot.core.utils.io import download_file, download_image_by_url, file_to_base64


class ComponentType(str, Enum):
    # Basic Segment Types
    Plain = "Plain"  # plain text message
    Image = "Image"  # image
    Record = "Record"  # audio
    Video = "Video"  # video
    File = "File"  # file attachment

    # IM-specific Segment Types
    Face = "Face"  # Emoji segment for Tencent QQ platform
    At = "At"  # mention a user in IM apps
    Node = "Node"  # a node in a forwarded message
    Nodes = "Nodes"  # a forwarded message consisting of multiple nodes
    Poke = "Poke"  # a poke message for Tencent QQ platform
    Reply = "Reply"  # a reply message segment
    Forward = "Forward"  # a forwarded message segment
    RPS = "RPS"  # TODO
    Dice = "Dice"  # TODO
    Shake = "Shake"  # TODO
    Share = "Share"
    Contact = "Contact"  # TODO
    Location = "Location"  # TODO
    Music = "Music"
    Json = "Json"
    Unknown = "Unknown"
    WechatEmoji = "WechatEmoji"  # emoji sticker on the WeChat platform


# ---------------------------------------------------------------------------
# Shared private helpers
# ---------------------------------------------------------------------------


def _temp_dir() -> str:
    """Return AstrBot's ``<data>/temp`` directory, creating it if it is missing."""
    temp_dir = os.path.join(get_astrbot_data_path(), "temp")
    os.makedirs(temp_dir, exist_ok=True)
    return temp_dir


def _save_base64_to_temp(bs64_data: str, suffix: str = "") -> str:
    """Decode raw base64 data into a uniquely named temp file.

    Args:
        bs64_data: Base64 payload without any ``base64://`` prefix.
        suffix: File name suffix (e.g. ``".jpg"``).

    Returns:
        str: Absolute path of the written file.
    """
    # Decode before opening so malformed input does not leave an empty file.
    raw_bytes = base64.b64decode(bs64_data)
    file_path = os.path.join(_temp_dir(), f"{uuid.uuid4()}{suffix}")
    with open(file_path, "wb") as f:
        f.write(raw_bytes)
    return os.path.abspath(file_path)


async def _source_to_file_path(src: str, base64_suffix: str) -> str:
    """Resolve a media source string to a local file path.

    ``src`` may be a ``file:///`` URI, an ``http(s)`` URL (downloaded),
    a ``base64://`` payload (written to a temp file) or a plain local path.

    Raises:
        Exception: If ``src`` matches none of the supported forms.
    """
    if src.startswith("file:///"):
        return src[8:]
    if src.startswith("http"):
        return os.path.abspath(await download_image_by_url(src))
    if src.startswith("base64://"):
        return _save_base64_to_temp(src.removeprefix("base64://"), base64_suffix)
    if os.path.exists(src):
        return os.path.abspath(src)
    raise Exception(f"not a valid file: {src}")


async def _source_to_base64(src: str) -> str:
    """Resolve a media source string to raw base64 data (no ``base64://`` prefix).

    Raises:
        Exception: If ``src`` matches none of the supported forms.
    """
    if src.startswith("file:///"):
        bs64_data = file_to_base64(src[8:])
    elif src.startswith("http"):
        bs64_data = file_to_base64(await download_image_by_url(src))
    elif src.startswith("base64://"):
        bs64_data = src
    elif os.path.exists(src):
        bs64_data = file_to_base64(src)
    else:
        raise Exception(f"not a valid file: {src}")
    # file_to_base64 output and base64:// sources may both carry the prefix.
    return bs64_data.removeprefix("base64://")


def _callback_api_base() -> str:
    """Return the configured ``callback_api_base`` without a trailing ``/``, or ``""``."""
    callback_host = astrbot_config.get("callback_api_base")
    return str(callback_host).removesuffix("/") if callback_host else ""


def _require_callback_api_base() -> str:
    """Like :func:`_callback_api_base`, but raise when it is not configured."""
    callback_host = _callback_api_base()
    if not callback_host:
        raise Exception("未配置 callback_api_base,文件服务不可用")
    return callback_host


async def _register_local_file(callback_host: str, file_path: str) -> str:
    """Register ``file_path`` with the file token service and return its public URL."""
    token = await file_token_service.register_file(file_path)
    logger.debug(f"已注册:{callback_host}/api/file/{token}")
    return f"{callback_host}/api/file/{token}"


# ---------------------------------------------------------------------------
# Message components
# ---------------------------------------------------------------------------


class BaseMessageComponent(BaseModel):
    """Base class of every message segment."""

    type: ComponentType

    def toDict(self):
        """Legacy synchronous serializer: ``{"type": <lowercase type>, "data": {...}}``.

        Every non-``None`` instance attribute except ``type`` is copied into
        ``data``; an ``_type`` attribute is exported under the key ``"type"``.
        """
        data = {}
        for k, v in self.__dict__.items():
            if k == "type" or v is None:
                continue
            if k == "_type":
                k = "type"
            data[k] = v
        return {"type": self.type.lower(), "data": data}

    async def to_dict(self) -> dict:
        """Async serializer; by default falls back to the legacy synchronous ``toDict()``."""
        return self.toDict()


class Plain(BaseMessageComponent):
    """Plain-text segment."""

    type = ComponentType.Plain
    text: str
    convert: bool | None = True

    def __init__(self, text: str, convert: bool = True, **_):
        super().__init__(text=text, convert=convert, **_)

    def toDict(self):
        # Legacy behaviour: surrounding whitespace is stripped.
        return {"type": "text", "data": {"text": self.text.strip()}}

    async def to_dict(self):
        # Unlike toDict(), the text is sent verbatim.
        return {"type": "text", "data": {"text": self.text}}


class Face(BaseMessageComponent):
    """QQ built-in emoji segment."""

    type = ComponentType.Face
    id: int

    def __init__(self, **_):
        super().__init__(**_)


class Record(BaseMessageComponent):
    """Audio / voice segment."""

    type = ComponentType.Record
    file: str | None = ""
    magic: bool | None = False
    url: str | None = ""
    cache: bool | None = True
    proxy: bool | None = True
    timeout: int | None = 0
    # extra
    path: str | None

    def __init__(self, file: str | None, **_):
        super().__init__(file=file, **_)

    @staticmethod
    def fromFileSystem(path, **_):
        return Record(file=f"file:///{os.path.abspath(path)}", path=path, **_)

    @staticmethod
    def fromURL(url: str, **_):
        if url.startswith(("http://", "https://")):
            return Record(file=url, **_)
        raise Exception("not a valid url")

    @staticmethod
    def fromBase64(bs64_data: str, **_):
        return Record(file=f"base64://{bs64_data}", **_)

    async def convert_to_file_path(self) -> str:
        """Return the audio as an absolute local path, downloading/decoding as needed.

        Returns:
            str: Local path of the audio (``file:///`` sources are returned
            with the prefix stripped).

        Raises:
            Exception: If ``file`` is empty or not a supported source.
        """
        if not self.file:
            raise Exception(f"not a valid file: {self.file}")
        # NOTE(review): decoded base64 audio is saved with a ".jpg" suffix,
        # kept for compatibility; the actual audio format is unknown here.
        return await _source_to_file_path(self.file, ".jpg")

    async def convert_to_base64(self) -> str:
        """Return the audio as base64 data without a ``base64://`` or data-URI prefix.

        Raises:
            Exception: If ``file`` is empty or not a supported source.
        """
        if not self.file:
            raise Exception(f"not a valid file: {self.file}")
        return await _source_to_base64(self.file)

    async def register_to_file_service(self) -> str:
        """Register the audio with the file service.

        Returns:
            str: The public URL of the registered file.

        Raises:
            Exception: If ``callback_api_base`` is not configured.
        """
        callback_host = _require_callback_api_base()
        file_path = await self.convert_to_file_path()
        return await _register_local_file(callback_host, file_path)


class Video(BaseMessageComponent):
    """Video segment."""

    type = ComponentType.Video
    file: str
    cover: str | None = ""
    c: int | None = 2
    # extra
    path: str | None = ""

    def __init__(self, file: str, **_):
        super().__init__(file=file, **_)

    @staticmethod
    def fromFileSystem(path, **_):
        return Video(file=f"file:///{os.path.abspath(path)}", path=path, **_)

    @staticmethod
    def fromURL(url: str, **_):
        if url.startswith(("http://", "https://")):
            return Video(file=url, **_)
        raise Exception("not a valid url")

    async def convert_to_file_path(self) -> str:
        """Return the video as an absolute local path, downloading URLs first.

        Raises:
            Exception: If the download produced no file, or ``file`` is not a
                supported source.
        """
        url = self.file
        if url and url.startswith("file:///"):
            return url[8:]
        if url and url.startswith("http"):
            video_file_path = os.path.join(_temp_dir(), uuid.uuid4().hex)
            await download_file(url, video_file_path)
            if os.path.exists(video_file_path):
                return os.path.abspath(video_file_path)
            raise Exception(f"download failed: {url}")
        if os.path.exists(url):
            return os.path.abspath(url)
        raise Exception(f"not a valid file: {url}")

    async def register_to_file_service(self):
        """Register the video with the file service.

        Returns:
            str: The public URL of the registered file.

        Raises:
            Exception: If ``callback_api_base`` is not configured.
        """
        callback_host = _require_callback_api_base()
        file_path = await self.convert_to_file_path()
        return await _register_local_file(callback_host, file_path)

    async def to_dict(self):
        """Async serializer (distinct from the synchronous ``toDict``).

        HTTP sources are passed through. Local sources become a callback URL
        when ``callback_api_base`` is configured; otherwise they are sent as-is.
        """
        url_or_path = self.file
        if url_or_path.startswith("http"):
            payload_file = url_or_path
        elif callback_host := _callback_api_base():
            # The file service needs a filesystem path, not a file:/// URI.
            local_path = url_or_path.removeprefix("file:///")
            token = await file_token_service.register_file(local_path)
            payload_file = f"{callback_host}/api/file/{token}"
            logger.debug(f"Generated video file callback link: {payload_file}")
        else:
            payload_file = url_or_path
        return {
            "type": "video",
            "data": {
                "file": payload_file,
            },
        }


class At(BaseMessageComponent):
    """Mention segment."""

    type = ComponentType.At
    qq: int | str  # the string "all" mentions everyone
    name: str | None = ""

    def __init__(self, **_):
        super().__init__(**_)

    def toDict(self):
        return {
            "type": "at",
            "data": {"qq": str(self.qq)},
        }


class AtAll(At):
    """Mention-everyone segment."""

    qq: str = "all"

    def __init__(self, **_):
        super().__init__(**_)


class RPS(BaseMessageComponent):  # TODO
    type = ComponentType.RPS

    def __init__(self, **_):
        super().__init__(**_)


class Dice(BaseMessageComponent):  # TODO
    type = ComponentType.Dice

    def __init__(self, **_):
        super().__init__(**_)


class Shake(BaseMessageComponent):  # TODO
    type = ComponentType.Shake

    def __init__(self, **_):
        super().__init__(**_)


class Share(BaseMessageComponent):
    """Link-share segment."""

    type = ComponentType.Share
    url: str
    title: str
    content: str | None = ""
    image: str | None = ""

    def __init__(self, **_):
        super().__init__(**_)


class Contact(BaseMessageComponent):  # TODO
    type = ComponentType.Contact
    # Named `_type` because `type` already holds the segment type.
    # NOTE(review): pydantic v1 does not treat underscore-prefixed annotations
    # as fields, so this value is probably never stored — confirm.
    _type: str
    id: int | None = 0

    def __init__(self, **_):
        super().__init__(**_)


class Location(BaseMessageComponent):  # TODO
    type = ComponentType.Location
    lat: float
    lon: float
    title: str | None = ""
    content: str | None = ""

    def __init__(self, **_):
        super().__init__(**_)


class Music(BaseMessageComponent):
    """Music-share segment."""

    type = ComponentType.Music
    # NOTE(review): same pydantic v1 underscore-field caveat as Contact._type.
    _type: str
    id: int | None = 0
    url: str | None = ""
    audio: str | None = ""
    title: str | None = ""
    content: str | None = ""
    image: str | None = ""

    def __init__(self, **_):
        super().__init__(**_)


class Image(BaseMessageComponent):
    """Image segment."""

    type = ComponentType.Image
    file: str | None = ""
    # NOTE(review): same pydantic v1 underscore-field caveat as Contact._type.
    _type: str | None = ""
    subType: int | None = 0
    url: str | None = ""
    cache: bool | None = True
    id: int | None = 40000
    c: int | None = 2
    # extra
    path: str | None = ""
    file_unique: str | None = ""  # unique id of a cached image on some platforms

    def __init__(self, file: str | None, **_):
        super().__init__(file=file, **_)

    @staticmethod
    def fromURL(url: str, **_):
        if url.startswith(("http://", "https://")):
            return Image(file=url, **_)
        raise Exception("not a valid url")

    @staticmethod
    def fromFileSystem(path, **_):
        return Image(file=f"file:///{os.path.abspath(path)}", path=path, **_)

    @staticmethod
    def fromBase64(base64: str, **_):
        return Image(f"base64://{base64}", **_)

    @staticmethod
    def fromBytes(byte: bytes):
        return Image.fromBase64(base64.b64encode(byte).decode())

    @staticmethod
    def fromIO(IO):
        return Image.fromBytes(IO.read())

    async def convert_to_file_path(self) -> str:
        """Return the image as an absolute local path, downloading/decoding as needed.

        ``url`` takes precedence over ``file``.

        Raises:
            ValueError: If neither ``url`` nor ``file`` is set.
            Exception: If the source is not supported.
        """
        url = self.url or self.file
        if not url:
            raise ValueError("No valid file or URL provided")
        return await _source_to_file_path(url, ".jpg")

    async def convert_to_base64(self) -> str:
        """Return the image as base64 data without a ``base64://`` or data-URI prefix.

        ``url`` takes precedence over ``file``.

        Raises:
            ValueError: If neither ``url`` nor ``file`` is set.
            Exception: If the source is not supported.
        """
        url = self.url or self.file
        if not url:
            raise ValueError("No valid file or URL provided")
        return await _source_to_base64(url)

    async def register_to_file_service(self) -> str:
        """Register the image with the file service.

        Returns:
            str: The public URL of the registered file.

        Raises:
            Exception: If ``callback_api_base`` is not configured.
        """
        callback_host = _require_callback_api_base()
        file_path = await self.convert_to_file_path()
        return await _register_local_file(callback_host, file_path)


class Reply(BaseMessageComponent):
    """Quote / reply segment."""

    type = ComponentType.Reply
    id: str | int  # ID of the quoted message
    chain: list["BaseMessageComponent"] | None = []  # segments of the quoted message
    sender_id: int | None | str = 0  # sender ID of the quoted message
    sender_nickname: str | None = ""  # sender nickname of the quoted message
    time: int | None = 0  # send time of the quoted message
    message_str: str | None = ""  # parsed plain-text form of the quoted message
    text: str | None = ""  # deprecated
    qq: int | None = 0  # deprecated
    seq: int | None = 0  # deprecated

    def __init__(self, **_):
        super().__init__(**_)


class Poke(BaseMessageComponent):
    """Poke segment; its ``type`` is stored as ``"Poke:<type>"``."""

    type: str = ComponentType.Poke
    id: int | None = 0
    qq: int | None = 0

    def __init__(self, type: str, **_):
        type = f"Poke:{type}"
        super().__init__(type=type, **_)


class Forward(BaseMessageComponent):
    """Reference to an existing forwarded message by ID."""

    type = ComponentType.Forward
    id: str

    def __init__(self, **_):
        super().__init__(**_)


class Node(BaseMessageComponent):
    """A single node of a merged-forward (group forward) message."""

    type = ComponentType.Node
    id: int | None = 0  # ignored
    name: str | None = ""  # sender's QQ nickname
    uin: str | None = "0"  # sender's QQ number
    content: list[BaseMessageComponent] | None = []
    seq: str | list | None = ""  # ignored
    time: int | None = 0  # ignored

    def __init__(self, content: list[BaseMessageComponent], **_):
        if isinstance(content, Node):
            # Legacy usage: a bare Node is accepted and wrapped in a list.
            content = [content]
        super().__init__(content=content, **_)

    async def to_dict(self):
        """Serialize this node for a OneBot forward message.

        Image and Record segments are inlined as ``base64://`` data; every
        other segment uses its own async ``to_dict`` (which defaults to the
        synchronous ``toDict``).
        """
        data_content = []
        for comp in self.content or []:
            if isinstance(comp, (Image, Record)):
                bs64 = await comp.convert_to_base64()
                data_content.append(
                    {
                        "type": comp.type.lower(),
                        "data": {"file": f"base64://{bs64}"},
                    },
                )
            else:
                data_content.append(await comp.to_dict())
        return {
            "type": "node",
            "data": {
                "user_id": str(self.uin),
                "nickname": self.name,
                "content": data_content,
            },
        }


class Nodes(BaseMessageComponent):
    """A forwarded message made of several nodes."""

    type = ComponentType.Nodes
    nodes: list[Node]

    def __init__(self, nodes: list[Node], **_):
        super().__init__(nodes=nodes, **_)

    def toDict(self):
        """Deprecated. Use to_dict instead"""
        ret = {
            "messages": [],
        }
        for node in self.nodes:
            d = node.toDict()
            ret["messages"].append(d)
        return ret

    async def to_dict(self):
        """Serialize all nodes into the OneBot JSON forward-message format."""
        ret = {"messages": []}
        for node in self.nodes:
            d = await node.to_dict()
            ret["messages"].append(d)
        return ret


class Json(BaseMessageComponent):
    """JSON card segment; a dict payload is stored as a JSON string."""

    type = ComponentType.Json
    data: str | dict
    resid: int | None = 0

    def __init__(self, data, **_):
        if isinstance(data, dict):
            data = json.dumps(data)
        super().__init__(data=data, **_)


class Unknown(BaseMessageComponent):
    type = ComponentType.Unknown
    text: str


class File(BaseMessageComponent):
    """File attachment segment."""

    type = ComponentType.File
    name: str | None = ""  # file name
    file_: str | None = ""  # local path
    url: str | None = ""  # download URL

    def __init__(self, name: str, file: str = "", url: str = ""):
        """Create a file segment from a local path and/or a download URL."""
        super().__init__(name=name, file_=file, url=url)

    def __setattr__(self, name, value):
        # pydantic v1's BaseModel.__setattr__ rejects names that are not
        # declared fields, so the `file` property setter would never run.
        if name == "file":
            type(self).file.fset(self, value)
            return
        super().__setattr__(name, value)

    @property
    def file(self) -> str:
        """Local file path, downloading ``url`` synchronously if needed.

        Returns ``""`` when there is nothing to return, when the download
        fails, or when called while an event loop is running in this thread
        (blocking there would deadlock — use ``await get_file()`` instead).
        """
        if self.file_ and os.path.exists(self.file_):
            return os.path.abspath(self.file_)

        if self.url:
            try:
                asyncio.get_running_loop()
            except RuntimeError:
                pass  # no running loop in this thread: safe to block
            else:
                logger.warning(
                    "不可以在异步上下文中同步等待下载! "
                    "这个警告通常发生于某些逻辑试图通过 .file 获取文件消息段的文件内容。"
                    "请使用 await get_file() 代替直接获取 .file 字段",
                )
                return ""

            # A private loop avoids the deprecated asyncio.get_event_loop() and
            # leaves the thread's current event loop untouched.
            loop = asyncio.new_event_loop()
            try:
                loop.run_until_complete(self._download_file())
            except Exception as e:
                logger.error(f"文件下载失败: {e}")
            finally:
                loop.close()

            if self.file_ and os.path.exists(self.file_):
                return os.path.abspath(self.file_)

        return ""

    @file.setter
    def file(self, value: str):
        """Backward-compatible setter: http(s) values go to ``url``, others to ``file_``."""
        if value.startswith(("http://", "https://")):
            self.url = value
        else:
            self.file_ = value

    async def get_file(self, allow_return_url: bool = False) -> str:
        """Asynchronously obtain the file. Clean up downloaded files after use.

        Args:
            allow_return_url: If True and a URL is set, return the URL instead
                of downloading. A local path may still be returned when there
                is no URL.

        Returns:
            str: A local path, an http download URL, or ``""`` if neither
            exists.
        """
        if allow_return_url and self.url:
            return self.url

        if self.file_ and os.path.exists(self.file_):
            return os.path.abspath(self.file_)

        if self.url:
            await self._download_file()
            return os.path.abspath(self.file_)

        return ""

    async def _download_file(self):
        """Download ``url`` into the temp dir and point ``file_`` at the result."""
        file_path = os.path.join(_temp_dir(), uuid.uuid4().hex)
        await download_file(self.url, file_path)
        self.file_ = os.path.abspath(file_path)

    async def register_to_file_service(self):
        """Register the file with the file service.

        Returns:
            str: The public URL of the registered file.

        Raises:
            Exception: If ``callback_api_base`` is not configured.
        """
        callback_host = _require_callback_api_base()
        file_path = await self.get_file()
        return await _register_local_file(callback_host, file_path)

    async def to_dict(self):
        """Async serializer (distinct from the synchronous ``toDict``).

        HTTP URLs are passed through. Local paths become a callback URL when
        ``callback_api_base`` is configured; otherwise they are sent as-is.
        """
        url_or_path = await self.get_file(allow_return_url=True)
        if url_or_path.startswith("http"):
            payload_file = url_or_path
        elif url_or_path and (callback_host := _callback_api_base()):
            token = await file_token_service.register_file(url_or_path)
            payload_file = f"{callback_host}/api/file/{token}"
            logger.debug(f"Generated file callback link: {payload_file}")
        else:
            payload_file = url_or_path

        return {
            "type": "file",
            "data": {
                "name": self.name,
                "file": payload_file,
            },
        }


class WechatEmoji(BaseMessageComponent):
    """WeChat emoji sticker segment."""

    type = ComponentType.WechatEmoji
    md5: str | None = ""
    md5_len: int | None = 0
    cdnurl: str | None = ""

    def __init__(self, **_):
        super().__init__(**_)


# Maps serialized segment type names to component classes.
ComponentTypes = {
    # Basic Message Segments
    "plain": Plain,
    "text": Plain,
    "image": Image,
    "record": Record,
    "video": Video,
    "file": File,
    # IM-specific Message Segments
    "face": Face,
    "at": At,
    "rps": RPS,
    "dice": Dice,
    "shake": Shake,
    "share": Share,
    "contact": Contact,
    "location": Location,
    "music": Music,
    "reply": Reply,
    "poke": Poke,
    "forward": Forward,
    "node": Node,
    "nodes": Nodes,
    "json": Json,
    "unknown": Unknown,
    "WechatEmoji": WechatEmoji,
}