From c111da468100330c3d1a986961ea18d56fbda13a Mon Sep 17 00:00:00 2001 From: Raven95676 Date: Tue, 6 May 2025 11:54:14 +0800 Subject: [PATCH] =?UTF-8?q?refactor:=20=E4=BF=AE=E6=94=B9=E6=A1=86?= =?UTF-8?q?=E6=9E=B6=E8=B7=AF=E5=BE=84=E8=8E=B7=E5=8F=96=E6=96=B9=E5=BC=8F?= =?UTF-8?q?=EF=BC=8C=E8=A7=84=E8=8C=83=E5=8C=96=E8=B7=AF=E5=BE=84=E6=8B=BC?= =?UTF-8?q?=E6=8E=A5?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- astrbot/core/__init__.py | 3 +- astrbot/core/config/astrbot_config.py | 3 +- astrbot/core/config/default.py | 5 +- astrbot/core/db/plugin/sqlite_impl.py | 3 +- astrbot/core/message/components.py | 14 +- .../sources/dingtalk/dingtalk_adapter.py | 5 +- .../core/platform/sources/gewechat/client.py | 16 +- .../sources/gewechat/gewechat_event.py | 19 +- .../core/platform/sources/lark/lark_event.py | 5 +- .../platform/sources/telegram/tg_event.py | 8 +- .../sources/webchat/webchat_adapter.py | 4 +- .../platform/sources/webchat/webchat_event.py | 3 +- .../platform/sources/wecom/wecom_adapter.py | 7 +- .../platform/sources/wecom/wecom_event.py | 23 +- astrbot/core/provider/func_tool_manager.py | 5 +- .../core/provider/sources/dashscope_tts.py | 5 +- astrbot/core/provider/sources/dify_source.py | 6 +- .../core/provider/sources/edge_tts_source.py | 7 +- .../sources/fishaudio_tts_api_source.py | 5 +- .../core/provider/sources/gsvi_tts_source.py | 5 +- .../provider/sources/openai_tts_api_source.py | 5 +- .../provider/sources/whisper_api_source.py | 7 +- .../sources/whisper_selfhosted_source.py | 7 +- astrbot/core/rag/knowledge_db_mgr.py | 3 +- astrbot/core/rag/store/chroma_db.py | 4 +- astrbot/core/star/config.py | 10 +- astrbot/core/star/star_manager.py | 16 +- astrbot/core/star/star_tools.py | 6 +- astrbot/core/star/updator.py | 7 +- astrbot/core/updator.py | 16 +- astrbot/core/utils/astrbot_path.py | 41 ++ astrbot/core/utils/io.py | 30 +- astrbot/core/utils/shared_preferences.py | 5 +- astrbot/core/utils/t2i/local_strategy.py | 624 +++++++++++------- astrbot/dashboard/routes/chat.py | 4 +- packages/astrbot/main.py | 5 +- packages/python_interpreter/main.py | 6 +- packages/reminder/main.py | 11 +- 38 files changed, 624 insertions(+), 334 deletions(-) create mode 100644 astrbot/core/utils/astrbot_path.py diff --git a/astrbot/core/__init__.py b/astrbot/core/__init__.py index 59e61d73..b87978fe 100644 --- a/astrbot/core/__init__.py +++ b/astrbot/core/__init__.py @@ -7,9 +7,10 @@ from astrbot.core.utils.pip_installer import PipInstaller from astrbot.core.db.sqlite import SQLiteDatabase from astrbot.core.config.default import DB_PATH from astrbot.core.config import AstrBotConfig +from .utils.astrbot_path import get_astrbot_data_path # 初始化数据存储文件夹 -os.makedirs("data", exist_ok=True) +os.makedirs(get_astrbot_data_path(), exist_ok=True) astrbot_config = AstrBotConfig() t2i_base_url = astrbot_config.get("t2i_endpoint", "https://t2i.soulter.top/text2img") diff --git a/astrbot/core/config/astrbot_config.py b/astrbot/core/config/astrbot_config.py index 09e66ce1..c43536ea 100644 --- a/astrbot/core/config/astrbot_config.py +++ b/astrbot/core/config/astrbot_config.py @@ -4,8 +4,9 @@ import logging import enum from .default import DEFAULT_CONFIG, DEFAULT_VALUE_MAP from typing import Dict +from astrbot.core.utils.astrbot_path import get_astrbot_data_path -ASTRBOT_CONFIG_PATH = "data/cmd_config.json" +ASTRBOT_CONFIG_PATH = os.path.join(get_astrbot_data_path(), "cmd_config.json") logger = logging.getLogger("astrbot") diff --git a/astrbot/core/config/default.py b/astrbot/core/config/default.py index d4f92438..41677cc5 100644 --- a/astrbot/core/config/default.py +++ b/astrbot/core/config/default.py @@ -2,8 +2,11 @@ 如需修改配置,请在 `data/cmd_config.json` 中修改或者在管理面板中可视化修改。 """ +import os +from astrbot.core.utils.astrbot_path import get_astrbot_data_path + VERSION = "3.5.6" -DB_PATH = "data/data_v3.db" +DB_PATH = os.path.join(get_astrbot_data_path(), "data_v3.db") # 默认配置 DEFAULT_CONFIG = { diff --git a/astrbot/core/db/plugin/sqlite_impl.py b/astrbot/core/db/plugin/sqlite_impl.py index 5440362a..53cfb828 100644 --- a/astrbot/core/db/plugin/sqlite_impl.py +++ b/astrbot/core/db/plugin/sqlite_impl.py @@ -3,8 +3,9 @@ import aiosqlite import os from typing import Any from .plugin_storage import PluginStorage +from astrbot.core.utils.astrbot_path import get_astrbot_data_path -DBPATH = "data/plugin_data/sqlite/plugin_data.db" +DBPATH = os.path.join(get_astrbot_data_path(), "plugin_data", "sqlite", "plugin_data.db") class SQLitePluginStorage(PluginStorage): diff --git a/astrbot/core/message/components.py b/astrbot/core/message/components.py index 74538d09..08d7b869 100644 --- a/astrbot/core/message/components.py +++ b/astrbot/core/message/components.py @@ -32,6 +32,7 @@ from enum import Enum from pydantic.v1 import BaseModel from astrbot.core import logger from astrbot.core.utils.io import download_image_by_url, file_to_base64, download_file +from astrbot.core.utils.astrbot_path import get_astrbot_data_path class ComponentType(Enum): @@ -167,7 +168,8 @@ class Record(BaseMessageComponent): elif self.file and self.file.startswith("base64://"): bs64_data = self.file.removeprefix("base64://") image_bytes = base64.b64decode(bs64_data) - file_path = f"data/temp/{uuid.uuid4()}.jpg" + temp_dir = os.path.join(get_astrbot_data_path(), "temp") + file_path = os.path.join(temp_dir, f"{uuid.uuid4()}.jpg") with open(file_path, "wb") as f: f.write(image_bytes) return os.path.abspath(file_path) @@ -371,7 +373,9 @@ class Image(BaseMessageComponent): elif url and url.startswith("base64://"): bs64_data = url.removeprefix("base64://") image_bytes = base64.b64decode(bs64_data) - image_file_path = f"data/temp/{uuid.uuid4()}.jpg" + temp_dir = os.path.join(get_astrbot_data_path(), "temp") + os.makedirs(temp_dir, exist_ok=True) + image_file_path = os.path.join(temp_dir, f"{uuid.uuid4()}.jpg") with open(image_file_path, "wb") as f: f.write(image_bytes) return os.path.abspath(image_file_path) @@ -631,9 +635,9 @@ class File(BaseMessageComponent): if self._downloaded: return - os.makedirs("data/download", exist_ok=True) - filename = self.name or f"{uuid.uuid4().hex}" - file_path = f"data/download/{filename}" + download_dir = os.path.join(get_astrbot_data_path(), "download") + os.makedirs(download_dir, exist_ok=True) + file_path = os.path.join(download_dir, f"{uuid.uuid4().hex}") await download_file(self.url, file_path) diff --git a/astrbot/core/platform/sources/dingtalk/dingtalk_adapter.py b/astrbot/core/platform/sources/dingtalk/dingtalk_adapter.py index 7a83a8ab..e61e2385 100644 --- a/astrbot/core/platform/sources/dingtalk/dingtalk_adapter.py +++ b/astrbot/core/platform/sources/dingtalk/dingtalk_adapter.py @@ -1,4 +1,5 @@ import asyncio +import os import uuid import aiohttp import dingtalk_stream @@ -19,6 +20,7 @@ from ...register import register_platform_adapter from astrbot import logger from dingtalk_stream import AckMessage from astrbot.core.utils.io import download_file +from astrbot.core.utils.astrbot_path import get_astrbot_data_path class MyEventHandler(dingtalk_stream.EventHandler): @@ -152,7 +154,8 @@ class DingtalkPlatformAdapter(Platform): "downloadCode": download_code, "robotCode": robot_code, } - f_path = f"data/dingtalk_file_{uuid.uuid4()}.{ext}" + temp_dir = os.path.join(get_astrbot_data_path(), "temp") + f_path = os.path.join(temp_dir, f"dingtalk_file_{uuid.uuid4()}.{ext}") async with aiohttp.ClientSession() as session: async with session.post( "https://api.dingtalk.com/v1.0/robot/messageFiles/download", diff --git a/astrbot/core/platform/sources/gewechat/client.py b/astrbot/core/platform/sources/gewechat/client.py index 36a18ec6..5f97a677 100644 --- a/astrbot/core/platform/sources/gewechat/client.py +++ b/astrbot/core/platform/sources/gewechat/client.py @@ -15,6 +15,7 @@ from astrbot.api.message_components import Plain, Image, At, Record, Video from astrbot.api.platform import AstrBotMessage, MessageMember, MessageType from astrbot.core.utils.io import download_image_by_url from .downloader import GeweDownloader +from astrbot.core.utils.astrbot_path import get_astrbot_data_path try: from .xml_data_parser import GeweDataParser @@ -250,7 +251,10 @@ class SimpleGewechatClient: # 语音消息 if "ImgBuf" in d and "buffer" in d["ImgBuf"]: voice_data = base64.b64decode(d["ImgBuf"]["buffer"]) - file_path = f"data/temp/gewe_voice_{abm.message_id}.silk" + temp_dir = os.path.join(get_astrbot_data_path(), "temp") + file_path = os.path.join( + temp_dir, f"gewe_voice_{abm.message_id}.silk" + ) async with await anyio.open_file(file_path, "wb") as f: await f.write(voice_data) @@ -458,8 +462,10 @@ class SimpleGewechatClient: retry_cnt -= 1 # 需要验证码 - if os.path.exists("data/temp/gewe_code"): - with open("data/temp/gewe_code", "r") as f: + temp_dir = os.path.join(get_astrbot_data_path(), "temp") + code_file_path = os.path.join(temp_dir, "gewe_code") + if os.path.exists(code_file_path): + with open(code_file_path, "r") as f: code = f.read().strip() if not code: logger.warning( @@ -470,9 +476,9 @@ class SimpleGewechatClient: payload["captchCode"] = code logger.info(f"使用验证码: {code}") try: - os.remove("data/temp/gewe_code") + os.remove(code_file_path) except Exception: - logger.warning("删除验证码文件 data/temp/gewe_code 失败。") + logger.warning(f"删除验证码文件 {code_file_path} 失败。") async with aiohttp.ClientSession() as session: async with session.post( diff --git a/astrbot/core/platform/sources/gewechat/gewechat_event.py b/astrbot/core/platform/sources/gewechat/gewechat_event.py index 3a62b7a8..f549d9ec 100644 --- a/astrbot/core/platform/sources/gewechat/gewechat_event.py +++ b/astrbot/core/platform/sources/gewechat/gewechat_event.py @@ -6,7 +6,7 @@ import traceback import os from typing import AsyncGenerator -from astrbot.core.utils.io import save_temp_img, download_file +from astrbot.core.utils.io import download_file from astrbot.core.utils.tencent_record_helper import wav_to_tencent_silk from astrbot.api import logger from astrbot.api.event import AstrMessageEvent, MessageChain @@ -21,6 +21,7 @@ from astrbot.api.message_components import ( WechatEmoji as Emoji, ) from .client import SimpleGewechatClient +from astrbot.core.utils.astrbot_path import get_astrbot_data_path def get_wav_duration(file_path): @@ -106,7 +107,8 @@ class GewechatPlatformEvent(AstrMessageEvent): # 根据 url 下载视频 if video_url.startswith("http"): video_filename = f"{uuid.uuid4()}.mp4" - video_path = f"data/temp/{video_filename}" + temp_dir = os.path.join(get_astrbot_data_path(), "temp") + video_path = os.path.join(temp_dir, video_filename) await download_file(video_url, video_path) else: video_path = video_url @@ -115,7 +117,10 @@ class GewechatPlatformEvent(AstrMessageEvent): video_callback_url = f"{client.file_server_url}/{video_token}" # 获取视频第一帧 - thumb_path = f"data/temp/gewechat_video_thumb_{uuid.uuid4()}.jpg" + temp_dir = os.path.join(get_astrbot_data_path(), "temp") + thumb_path = os.path.join( + temp_dir, f"gewechat_video_thumb_{uuid.uuid4()}.jpg" + ) video_path = video_path.replace(" ", "\\ ") try: @@ -154,7 +159,8 @@ class GewechatPlatformEvent(AstrMessageEvent): record_url = comp.file record_path = await comp.convert_to_file_path() - silk_path = f"data/temp/{uuid.uuid4()}.silk" + temp_dir = os.path.join(get_astrbot_data_path(), "temp") + silk_path = os.path.join(temp_dir, f"{uuid.uuid4()}.silk") try: duration = await wav_to_tencent_silk(record_path, silk_path) except Exception as e: @@ -173,7 +179,10 @@ class GewechatPlatformEvent(AstrMessageEvent): if file_path.startswith("file:///"): file_path = file_path[8:] elif file_path.startswith("http"): - await download_file(file_path, f"data/temp/{file_name}") + temp_dir = os.path.join(get_astrbot_data_path(), "temp") + temp_file_path = os.path.join(temp_dir, file_name) + await download_file(file_path, temp_file_path) + file_path = temp_file_path else: file_path = file_path diff --git a/astrbot/core/platform/sources/lark/lark_event.py b/astrbot/core/platform/sources/lark/lark_event.py index 600c13e5..994d1495 100644 --- a/astrbot/core/platform/sources/lark/lark_event.py +++ b/astrbot/core/platform/sources/lark/lark_event.py @@ -1,4 +1,5 @@ import json +import os import uuid import base64 import lark_oapi as lark @@ -9,6 +10,7 @@ from astrbot.api.message_components import Plain, Image as AstrBotImage, At from astrbot.core.utils.io import download_image_by_url from lark_oapi.api.im.v1 import * from astrbot import logger +from astrbot.core.utils.astrbot_path import get_astrbot_data_path class LarkMessageEvent(AstrMessageEvent): @@ -40,7 +42,8 @@ class LarkMessageEvent(AstrMessageEvent): base64_str = comp.file.removeprefix("base64://") image_data = base64.b64decode(base64_str) # save as temp file - file_path = f"data/temp/{uuid.uuid4()}_test.jpg" + temp_dir = os.path.join(get_astrbot_data_path(), "temp") + file_path = os.path.join(temp_dir, f"{uuid.uuid4()}_test.jpg") with open(file_path, "wb") as f: f.write(BytesIO(image_data).getvalue()) else: diff --git a/astrbot/core/platform/sources/telegram/tg_event.py b/astrbot/core/platform/sources/telegram/tg_event.py index 8e26b896..4b9fd0ad 100644 --- a/astrbot/core/platform/sources/telegram/tg_event.py +++ b/astrbot/core/platform/sources/telegram/tg_event.py @@ -1,3 +1,4 @@ +import os import asyncio import telegramify_markdown from astrbot.api.event import AstrMessageEvent, MessageChain @@ -13,6 +14,7 @@ from astrbot.api.message_components import ( from telegram.ext import ExtBot from astrbot.core.utils.io import download_file from astrbot import logger +from astrbot.core.utils.astrbot_path import get_astrbot_data_path class TelegramPlatformEvent(AstrMessageEvent): @@ -75,7 +77,8 @@ class TelegramPlatformEvent(AstrMessageEvent): await client.send_photo(photo=image_path, **payload) elif isinstance(i, File): if i.file.startswith("https://"): - path = "data/temp/" + i.name + temp_dir = os.path.join(get_astrbot_data_path(), "temp") + path = os.path.join(temp_dir, i.name) await download_file(i.file, path) i.file = path @@ -126,7 +129,8 @@ class TelegramPlatformEvent(AstrMessageEvent): continue elif isinstance(i, File): if i.file.startswith("https://"): - path = "data/temp/" + i.name + temp_dir = os.path.join(get_astrbot_data_path(), "temp") + path = os.path.join(temp_dir, i.name) await download_file(i.file, path) i.file = path diff --git a/astrbot/core/platform/sources/webchat/webchat_adapter.py b/astrbot/core/platform/sources/webchat/webchat_adapter.py index 01a042fb..fa384ed9 100644 --- a/astrbot/core/platform/sources/webchat/webchat_adapter.py +++ b/astrbot/core/platform/sources/webchat/webchat_adapter.py @@ -17,6 +17,7 @@ from astrbot.core import web_chat_queue from .webchat_event import WebChatMessageEvent from astrbot.core.platform.astr_message_event import MessageSesion from ...register import register_platform_adapter +from astrbot.core.utils.astrbot_path import get_astrbot_data_path class QueueListener: @@ -40,7 +41,8 @@ class WebChatAdapter(Platform): self.config = platform_config self.settings = platform_settings self.unique_session = platform_settings["unique_session"] - self.imgs_dir = "data/webchat/imgs" + self.imgs_dir = os.path.join(get_astrbot_data_path(), "webchat", "imgs") + os.makedirs(self.imgs_dir, exist_ok=True) self.metadata = PlatformMetadata( name="webchat", description="webchat", id=self.config.get("id") diff --git a/astrbot/core/platform/sources/webchat/webchat_event.py b/astrbot/core/platform/sources/webchat/webchat_event.py index e60d6d14..76b5dc85 100644 --- a/astrbot/core/platform/sources/webchat/webchat_event.py +++ b/astrbot/core/platform/sources/webchat/webchat_event.py @@ -6,8 +6,9 @@ from astrbot.api.event import AstrMessageEvent, MessageChain from astrbot.api.message_components import Plain, Image, Record from astrbot.core.utils.io import download_image_by_url from astrbot.core import web_chat_back_queue +from astrbot.core.utils.astrbot_path import get_astrbot_data_path -imgs_dir = "data/webchat/imgs" +imgs_dir = os.path.join(get_astrbot_data_path(), "webchat", "imgs") class WebChatMessageEvent(AstrMessageEvent): diff --git a/astrbot/core/platform/sources/wecom/wecom_adapter.py b/astrbot/core/platform/sources/wecom/wecom_adapter.py index c6b7c096..d04a7b74 100644 --- a/astrbot/core/platform/sources/wecom/wecom_adapter.py +++ b/astrbot/core/platform/sources/wecom/wecom_adapter.py @@ -1,3 +1,4 @@ +import os import sys import uuid import asyncio @@ -23,6 +24,7 @@ from wechatpy.enterprise.messages import TextMessage, ImageMessage, VoiceMessage from wechatpy.exceptions import InvalidSignatureException from wechatpy.enterprise import parse_message from .wecom_event import WecomPlatformEvent +from astrbot.core.utils.astrbot_path import get_astrbot_data_path if sys.version_info >= (3, 12): from typing import override @@ -191,14 +193,15 @@ class WecomPlatformAdapter(Platform): resp: Response = await asyncio.get_event_loop().run_in_executor( None, self.client.media.download, msg.media_id ) - path = f"data/temp/wecom_{msg.media_id}.amr" + temp_dir = os.path.join(get_astrbot_data_path(), "temp") + path = os.path.join(temp_dir, f"wecom_{msg.media_id}.amr") with open(path, "wb") as f: f.write(resp.content) try: from pydub import AudioSegment - path_wav = f"data/temp/wecom_{msg.media_id}.wav" + path_wav = os.path.join(temp_dir, f"wecom_{msg.media_id}.wav") audio = AudioSegment.from_file(path) audio.export(path_wav, format="wav") except Exception as e: diff --git a/astrbot/core/platform/sources/wecom/wecom_event.py b/astrbot/core/platform/sources/wecom/wecom_event.py index 05fc33da..fb820d6e 100644 --- a/astrbot/core/platform/sources/wecom/wecom_event.py +++ b/astrbot/core/platform/sources/wecom/wecom_event.py @@ -1,3 +1,4 @@ +import os import uuid import asyncio from astrbot.api.event import AstrMessageEvent, MessageChain @@ -6,6 +7,7 @@ from astrbot.api.message_components import Plain, Image, Record from wechatpy.enterprise import WeChatClient from astrbot.api import logger +from astrbot.core.utils.astrbot_path import get_astrbot_data_path try: import pydub @@ -52,19 +54,29 @@ class WecomPlatformEvent(AstrMessageEvent): if start + 2048 >= len(plain): result.append(plain[start:]) break - + # 向前搜索分割标点符号 end = min(start + 2048, len(plain)) cut_position = end for i in range(end, start, -1): - if i < len(plain) and plain[i-1] in ["。", "!", "?", ".", "!", "?", "\n", ";", ";"]: + if i < len(plain) and plain[i - 1] in [ + "。", + "!", + "?", + ".", + "!", + "?", + "\n", + ";", + ";", + ]: cut_position = i break - + # 没找到合适的位置分割, 直接切分 if cut_position == end and end < len(plain): cut_position = end - + result.append(plain[start:cut_position]) start = cut_position @@ -103,7 +115,8 @@ class WecomPlatformEvent(AstrMessageEvent): elif isinstance(comp, Record): record_path = await comp.convert_to_file_path() # 转成amr - record_path_amr = f"data/temp/{uuid.uuid4()}.amr" + temp_dir = os.path.join(get_astrbot_data_path(), "temp") + record_path_amr = os.path.join(temp_dir, f"{uuid.uuid4()}.amr") pydub.AudioSegment.from_wav(record_path).export( record_path_amr, format="amr" ) diff --git a/astrbot/core/provider/func_tool_manager.py b/astrbot/core/provider/func_tool_manager.py index 1ce1efba..7059a00f 100644 --- a/astrbot/core/provider/func_tool_manager.py +++ b/astrbot/core/provider/func_tool_manager.py @@ -12,6 +12,8 @@ from contextlib import AsyncExitStack from astrbot import logger from astrbot.core.utils.log_pipe import LogPipe +from astrbot.core.utils.astrbot_path import get_astrbot_data_path + try: import mcp from mcp.client.sse import sse_client @@ -238,8 +240,7 @@ class FuncCall: } ``` """ - current_dir = os.path.dirname(os.path.abspath(__file__)) - data_dir = os.path.abspath(os.path.join(current_dir, "../../../data")) + data_dir = get_astrbot_data_path() mcp_json_file = os.path.join(data_dir, "mcp_server.json") if not os.path.exists(mcp_json_file): diff --git a/astrbot/core/provider/sources/dashscope_tts.py b/astrbot/core/provider/sources/dashscope_tts.py index f135a35d..29c988d7 100644 --- a/astrbot/core/provider/sources/dashscope_tts.py +++ b/astrbot/core/provider/sources/dashscope_tts.py @@ -1,3 +1,4 @@ +import os import dashscope import uuid import asyncio @@ -5,6 +6,7 @@ from dashscope.audio.tts_v2 import * from ..provider import TTSProvider from ..entities import ProviderType from ..register import register_provider_adapter +from astrbot.core.utils.astrbot_path import get_astrbot_data_path @register_provider_adapter( @@ -24,7 +26,8 @@ class ProviderDashscopeTTSAPI(TTSProvider): dashscope.api_key = self.chosen_api_key async def get_audio(self, text: str) -> str: - path = f"data/temp/dashscope_tts_{uuid.uuid4()}.wav" + temp_dir = os.path.join(get_astrbot_data_path(), "temp") + path = os.path.join(temp_dir, f"dashscope_tts_{uuid.uuid4()}.wav") self.synthesizer = SpeechSynthesizer( model=self.get_model(), voice=self.voice, diff --git a/astrbot/core/provider/sources/dify_source.py b/astrbot/core/provider/sources/dify_source.py index 78e3760c..ad0605f1 100644 --- a/astrbot/core/provider/sources/dify_source.py +++ b/astrbot/core/provider/sources/dify_source.py @@ -1,5 +1,5 @@ import astrbot.core.message.components as Comp - +import os from typing import List from .. import Provider, Personality from ..entities import LLMResponse @@ -10,6 +10,7 @@ from astrbot.core.utils.dify_api_client import DifyAPIClient from astrbot.core.utils.io import download_image_by_url, download_file from astrbot.core import logger, sp from astrbot.core.message.message_event_result import MessageChain +from astrbot.core.utils.astrbot_path import get_astrbot_data_path @register_provider_adapter("dify", "Dify APP 适配器。") @@ -227,7 +228,8 @@ class ProviderDify(Provider): return Comp.Image(file=item["url"], url=item["url"]) case "audio": # 仅支持 wav - path = f"data/temp/{item['filename']}.wav" + temp_dir = os.path.join(get_astrbot_data_path(), "temp") + path = os.path.join(temp_dir, f"{item['filename']}.wav") await download_file(item["url"], path) return Comp.Image(file=item["url"], url=item["url"]) case "video": diff --git a/astrbot/core/provider/sources/edge_tts_source.py b/astrbot/core/provider/sources/edge_tts_source.py index 338abe26..44c2d175 100644 --- a/astrbot/core/provider/sources/edge_tts_source.py +++ b/astrbot/core/provider/sources/edge_tts_source.py @@ -7,6 +7,7 @@ from ..provider import TTSProvider from ..entities import ProviderType from ..register import register_provider_adapter from astrbot.core import logger +from astrbot.core.utils.astrbot_path import get_astrbot_data_path """ edge_tts 方式,能够免费、快速生成语音,使用需要先安装edge-tts库 @@ -40,9 +41,9 @@ class ProviderEdgeTTS(TTSProvider): self.set_model("edge_tts") async def get_audio(self, text: str) -> str: - os.makedirs("data/temp", exist_ok=True) - mp3_path = f"data/temp/edge_tts_temp_{uuid.uuid4()}.mp3" - wav_path = f"data/temp/edge_tts_{uuid.uuid4()}.wav" + temp_dir = os.path.join(get_astrbot_data_path(), "temp") + mp3_path = os.path.join(temp_dir, f"edge_tts_temp_{uuid.uuid4()}.mp3") + wav_path = os.path.join(temp_dir, f"edge_tts_{uuid.uuid4()}.wav") # 构建 Edge TTS 参数 kwargs = {"text": text, "voice": self.voice} diff --git a/astrbot/core/provider/sources/fishaudio_tts_api_source.py b/astrbot/core/provider/sources/fishaudio_tts_api_source.py index 07d0c32a..c0cf044b 100644 --- a/astrbot/core/provider/sources/fishaudio_tts_api_source.py +++ b/astrbot/core/provider/sources/fishaudio_tts_api_source.py @@ -1,3 +1,4 @@ +import os import uuid import ormsgpack from pydantic import BaseModel, conint @@ -6,6 +7,7 @@ from typing import Annotated, Literal from ..provider import TTSProvider from ..entities import ProviderType from ..register import register_provider_adapter +from astrbot.core.utils.astrbot_path import get_astrbot_data_path class ServeReferenceAudio(BaseModel): @@ -87,7 +89,8 @@ class ProviderFishAudioTTSAPI(TTSProvider): ) async def get_audio(self, text: str) -> str: - path = f"data/temp/fishaudio_tts_api_{uuid.uuid4()}.wav" + temp_dir = os.path.join(get_astrbot_data_path(), "temp") + path = os.path.join(temp_dir, f"fishaudio_tts_api_{uuid.uuid4()}.wav") self.headers["content-type"] = "application/msgpack" request = await self._generate_request(text) async with AsyncClient(base_url=self.api_base).stream( diff --git a/astrbot/core/provider/sources/gsvi_tts_source.py b/astrbot/core/provider/sources/gsvi_tts_source.py index 581eef4d..c2444819 100644 --- a/astrbot/core/provider/sources/gsvi_tts_source.py +++ b/astrbot/core/provider/sources/gsvi_tts_source.py @@ -1,9 +1,11 @@ +import os import uuid import aiohttp import urllib.parse from ..provider import TTSProvider from ..entities import ProviderType from ..register import register_provider_adapter +from astrbot.core.utils.astrbot_path import get_astrbot_data_path @register_provider_adapter( @@ -23,7 +25,8 @@ class ProviderGSVITTS(TTSProvider): self.emotion = provider_config.get("emotion") async def get_audio(self, text: str) -> str: - path = f"data/temp/gsvi_tts_{uuid.uuid4()}.wav" + temp_dir = os.path.join(get_astrbot_data_path(), "temp") + path = os.path.join(temp_dir, f"gsvi_tts_{uuid.uuid4()}.wav") params = {"text": text} if self.character: diff --git a/astrbot/core/provider/sources/openai_tts_api_source.py b/astrbot/core/provider/sources/openai_tts_api_source.py index 20b00f94..c188a9fa 100644 --- a/astrbot/core/provider/sources/openai_tts_api_source.py +++ b/astrbot/core/provider/sources/openai_tts_api_source.py @@ -1,8 +1,10 @@ +import os import uuid from openai import AsyncOpenAI, NOT_GIVEN from ..provider import TTSProvider from ..entities import ProviderType from ..register import register_provider_adapter +from astrbot.core.utils.astrbot_path import get_astrbot_data_path @register_provider_adapter( @@ -31,7 +33,8 @@ class ProviderOpenAITTSAPI(TTSProvider): self.set_model(provider_config.get("model", None)) async def get_audio(self, text: str) -> str: - path = f"data/temp/openai_tts_api_{uuid.uuid4()}.wav" + temp_dir = os.path.join(get_astrbot_data_path(), "temp") + path = os.path.join(temp_dir, f"openai_tts_api_{uuid.uuid4()}.wav") async with self.client.audio.speech.with_streaming_response.create( model=self.model_name, voice=self.voice, response_format="wav", input=text ) as response: diff --git a/astrbot/core/provider/sources/whisper_api_source.py b/astrbot/core/provider/sources/whisper_api_source.py index 0009af90..dfe28697 100644 --- a/astrbot/core/provider/sources/whisper_api_source.py +++ b/astrbot/core/provider/sources/whisper_api_source.py @@ -7,6 +7,7 @@ from astrbot.core.utils.io import download_file from ..register import register_provider_adapter from astrbot.core import logger from astrbot.core.utils.tencent_record_helper import tencent_silk_to_wav +from astrbot.core.utils.astrbot_path import get_astrbot_data_path @register_provider_adapter( @@ -50,7 +51,8 @@ class ProviderOpenAIWhisperAPI(STTProvider): is_tencent = True name = str(uuid.uuid4()) - path = os.path.join("data/temp", name) + temp_dir = os.path.join(get_astrbot_data_path(), "temp") + path = os.path.join(temp_dir, name) await download_file(audio_url, path) audio_url = path @@ -61,7 +63,8 @@ class ProviderOpenAIWhisperAPI(STTProvider): is_silk = await self._is_silk_file(audio_url) if is_silk: logger.info("Converting silk file to wav ...") - output_path = os.path.join("data/temp", str(uuid.uuid4()) + ".wav") + temp_dir = os.path.join(get_astrbot_data_path(), "temp") + output_path = os.path.join(temp_dir, str(uuid.uuid4()) + ".wav") await tencent_silk_to_wav(audio_url, output_path) audio_url = output_path diff --git a/astrbot/core/provider/sources/whisper_selfhosted_source.py b/astrbot/core/provider/sources/whisper_selfhosted_source.py index 96f0b6f6..7cb76cc4 100644 --- a/astrbot/core/provider/sources/whisper_selfhosted_source.py +++ b/astrbot/core/provider/sources/whisper_selfhosted_source.py @@ -8,6 +8,7 @@ from astrbot.core.utils.io import download_file from ..register import register_provider_adapter from astrbot.core import logger from astrbot.core.utils.tencent_record_helper import tencent_silk_to_wav +from astrbot.core.utils.astrbot_path import get_astrbot_data_path @register_provider_adapter( @@ -53,7 +54,8 @@ class ProviderOpenAIWhisperSelfHost(STTProvider): is_tencent = True name = str(uuid.uuid4()) - path = os.path.join("data/temp", name) + temp_dir = os.path.join(get_astrbot_data_path(), "temp") + path = os.path.join(temp_dir, name) await download_file(audio_url, path) audio_url = path @@ -64,7 +66,8 @@ class ProviderOpenAIWhisperSelfHost(STTProvider): is_silk = await self._is_silk_file(audio_url) if is_silk: logger.info("Converting silk file to wav ...") - output_path = os.path.join("data/temp", str(uuid.uuid4()) + ".wav") + temp_dir = os.path.join(get_astrbot_data_path(), "temp") + output_path = os.path.join(temp_dir, str(uuid.uuid4()) + ".wav") await tencent_silk_to_wav(audio_url, output_path) audio_url = output_path diff --git a/astrbot/core/rag/knowledge_db_mgr.py b/astrbot/core/rag/knowledge_db_mgr.py index 2aed0e44..f1c1f386 100644 --- a/astrbot/core/rag/knowledge_db_mgr.py +++ b/astrbot/core/rag/knowledge_db_mgr.py @@ -3,11 +3,12 @@ from typing import List, Dict from astrbot.core import logger from .store import Store from astrbot.core.config import AstrBotConfig +from astrbot.core.utils.astrbot_path import get_astrbot_data_path class KnowledgeDBManager: def __init__(self, astrbot_config: AstrBotConfig) -> None: - self.db_path = "data/knowledge_db/" + self.db_path = os.path.join(get_astrbot_data_path(), "knowledge_db") self.config = astrbot_config.get("knowledge_db", {}) self.astrbot_config = astrbot_config if not os.path.exists(self.db_path): diff --git a/astrbot/core/rag/store/chroma_db.py b/astrbot/core/rag/store/chroma_db.py index 30befb97..d4cfae94 100644 --- a/astrbot/core/rag/store/chroma_db.py +++ b/astrbot/core/rag/store/chroma_db.py @@ -4,12 +4,14 @@ from typing import List, Dict from astrbot.api import logger from ..embedding.openai_source import SimpleOpenAIEmbedding from . import Store +from astrbot.core.utils.astrbot_path import get_astrbot_data_path class ChromaVectorStore(Store): def __init__(self, name: str, embedding_cfg: Dict) -> None: + import os self.chroma_client = chromadb.PersistentClient( - path="data/long_term_memory_chroma.db" + path=os.path.join(get_astrbot_data_path(), "long_term_memory_chroma.db") ) self.collection = self.chroma_client.get_or_create_collection(name=name) self.embedding = None diff --git a/astrbot/core/star/config.py b/astrbot/core/star/config.py index dc07fe6f..23a522dc 100644 --- a/astrbot/core/star/config.py +++ b/astrbot/core/star/config.py @@ -5,6 +5,7 @@ from typing import Union import os import json +from astrbot.core.utils.astrbot_path import get_astrbot_data_path def load_config(namespace: str) -> Union[dict, bool]: @@ -13,7 +14,7 @@ def load_config(namespace: str) -> Union[dict, bool]: namespace: str, 配置的唯一识别符,也就是配置文件的名字。 返回值: 当配置文件存在时,返回 namespace 对应配置文件的内容dict,否则返回 False。 """ - path = f"data/config/{namespace}.json" + path = os.path.join(get_astrbot_data_path(), "config", f"{namespace}.json") if not os.path.exists(path): return False with open(path, "r", encoding="utf-8-sig") as f: @@ -43,7 +44,10 @@ def put_config(namespace: str, name: str, key: str, value, description: str): raise ValueError("key 只支持 str 类型。") if not isinstance(value, (str, int, float, bool, list)): raise ValueError("value 只支持 str, int, float, bool, list 类型。") - path = f"data/config/{namespace}.json" + + config_dir = os.path.join(get_astrbot_data_path(), "config") + path = os.path.join(config_dir, f"{namespace}.json") + if not os.path.exists(path): with open(path, "w", encoding="utf-8-sig") as f: f.write("{}") @@ -71,7 +75,7 @@ def update_config(namespace: str, key: str, value): key: str, 配置项的键。 value: str, int, float, bool, list, 配置项的值。 """ - path = f"data/config/{namespace}.json" + path = os.path.join(get_astrbot_data_path(), "config", f"{namespace}.json") if not os.path.exists(path): raise FileNotFoundError(f"配置文件 {namespace}.json 不存在。") with open(path, "r", encoding="utf-8-sig") as f: diff --git a/astrbot/core/star/star_manager.py b/astrbot/core/star/star_manager.py index 60b0e0c6..5aa70b23 100644 --- a/astrbot/core/star/star_manager.py +++ b/astrbot/core/star/star_manager.py @@ -22,6 +22,10 @@ from astrbot.core.utils.io import remove_dir from .star import star_registry, star_map from .star_handler import star_handlers_registry from astrbot.core.provider.register import llm_tools +from astrbot.core.utils.astrbot_path import ( + get_astrbot_plugin_path, + get_astrbot_config_path, +) from .filter.permission import PermissionTypeFilter, PermissionType @@ -34,17 +38,9 @@ class PluginManager: self.context._star_manager = self self.config = config - self.plugin_store_path = os.path.abspath( - os.path.join( - os.path.dirname(os.path.abspath(__file__)), "../../../data/plugins" - ) - ) + self.plugin_store_path = get_astrbot_plugin_path() """存储插件的路径。即 data/plugins""" - self.plugin_config_path = os.path.abspath( - os.path.join( - os.path.dirname(os.path.abspath(__file__)), "../../../data/config" - ) - ) + self.plugin_config_path = get_astrbot_config_path() """存储插件配置的路径。data/config""" self.reserved_plugin_path = os.path.abspath( os.path.join( diff --git a/astrbot/core/star/star_tools.py b/astrbot/core/star/star_tools.py index 405ccc63..40fa3e51 100644 --- a/astrbot/core/star/star_tools.py +++ b/astrbot/core/star/star_tools.py @@ -1,4 +1,6 @@ import inspect +import os +from pathlib import Path from typing import Union, Awaitable, List, Optional, ClassVar from astrbot.core.message.components import BaseMessageComponent from astrbot.core.message.message_event_result import MessageChain @@ -6,7 +8,7 @@ from astrbot.api.platform import MessageMember, AstrBotMessage from astrbot.core.platform.astr_message_event import MessageSesion from astrbot.core.star.context import Context from astrbot.core.star.star import star_map -from pathlib import Path +from astrbot.core.utils.astrbot_path import get_astrbot_data_path class StarTools: @@ -180,7 +182,7 @@ class StarTools: plugin_name = metadata.name - data_dir = Path("data/plugin_data") / plugin_name + data_dir = Path(os.path.join(get_astrbot_data_path(), "plugin_data", plugin_name)) try: data_dir.mkdir(parents=True, exist_ok=True) diff --git a/astrbot/core/star/updator.py b/astrbot/core/star/updator.py index d439e98c..45f8b8a2 100644 --- a/astrbot/core/star/updator.py +++ b/astrbot/core/star/updator.py @@ -6,16 +6,13 @@ from ..updator import RepoZipUpdator from astrbot.core.utils.io import remove_dir, on_error from ..star.star import StarMetadata from astrbot.core import logger +from astrbot.core.utils.astrbot_path import get_astrbot_plugin_path class PluginUpdator(RepoZipUpdator): def __init__(self, repo_mirror: str = "") -> None: super().__init__(repo_mirror) - self.plugin_store_path = os.path.abspath( - os.path.join( - os.path.dirname(os.path.abspath(__file__)), "../../../data/plugins" - ) - ) + self.plugin_store_path = get_astrbot_plugin_path() def get_plugin_store_path(self) -> str: return self.plugin_store_path diff --git a/astrbot/core/updator.py b/astrbot/core/updator.py index 1e7279a8..60ed6860 100644 --- a/astrbot/core/updator.py +++ b/astrbot/core/updator.py @@ -6,6 +6,7 @@ from .zip_updator import ReleaseInfo, RepoZipUpdator from astrbot.core import logger from astrbot.core.config.default import VERSION from astrbot.core.utils.io import download_file +from astrbot.core.utils.astrbot_path import get_astrbot_path class AstrBotUpdator(RepoZipUpdator): @@ -16,9 +17,7 @@ class AstrBotUpdator(RepoZipUpdator): def __init__(self, repo_mirror: str = "") -> None: super().__init__(repo_mirror) - self.MAIN_PATH = os.path.abspath( - os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../") - ) + self.MAIN_PATH = get_astrbot_path() self.ASTRBOT_RELEASE_API = "https://api.soulter.top/releases" def terminate_child_processes(self): @@ -51,7 +50,13 @@ class AstrBotUpdator(RepoZipUpdator): self.terminate_child_processes() py = py.replace(" ", "\\ ") try: - os.execl(py, py, *sys.argv) + if "astrbot" in os.path.basename(sys.argv[0]): # 兼容cli + args = [ + f'"{arg}"' if " " in arg else arg for arg in sys.argv[1:] + ] + os.execl(py, py, "-m", "astrbot.cli.__main__", *args) + else: + os.execl(py, py, *sys.argv) except Exception as e: logger.error(f"重启失败({py}, {e}),请尝试手动重启。") raise e @@ -67,6 +72,9 @@ class AstrBotUpdator(RepoZipUpdator): update_data = await self.fetch_release_info(self.ASTRBOT_RELEASE_API, latest) file_url = None + if os.environ.get("ASTRBOT_CLI"): + raise Exception("不支持更新CLI启动的AstrBot") # 避免版本管理混乱 + if latest: latest_version = update_data[0]["tag_name"] if self.compare_version(VERSION, latest_version) >= 0: diff --git a/astrbot/core/utils/astrbot_path.py b/astrbot/core/utils/astrbot_path.py new file mode 100644 index 00000000..64ed9229 --- /dev/null +++ b/astrbot/core/utils/astrbot_path.py @@ -0,0 +1,41 @@ +""" +Astrbot统一路径获取 + +项目路径:固定为源码所在路径 +根目录路径:默认为当前工作目录,可通过环境变量 ASTRBOT_ROOT 指定 +数据目录路径:固定为根目录下的 data 目录 +配置文件路径:固定为数据目录下的 config 目录 +插件目录路径:固定为数据目录下的 plugins 目录 +""" + +import os + + +def get_astrbot_path() -> str: + """获取Astrbot项目路径""" + return os.path.realpath( + os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../../") + ) + + +def get_astrbot_root() -> str: + """获取Astrbot根目录路径""" + if path := os.environ.get("ASTRBOT_ROOT"): + return os.path.realpath(path) + else: + return os.path.realpath(os.getcwd()) + + +def get_astrbot_data_path() -> str: + """获取Astrbot数据目录路径""" + return os.path.realpath(os.path.join(get_astrbot_root(), "data")) + + +def get_astrbot_config_path() -> str: + """获取Astrbot配置文件路径""" + return os.path.realpath(os.path.join(get_astrbot_data_path(), "config")) + + +def get_astrbot_plugin_path() -> str: + """获取Astrbot插件目录路径""" + return os.path.realpath(os.path.join(get_astrbot_data_path(), "plugins")) diff --git a/astrbot/core/utils/io.py b/astrbot/core/utils/io.py index 37a39de9..2cd8fd9c 100644 --- a/astrbot/core/utils/io.py +++ b/astrbot/core/utils/io.py @@ -14,6 +14,7 @@ import certifi from typing import Union from PIL import Image +from .astrbot_path import get_astrbot_data_path def on_error(func, path, exc_info): @@ -49,11 +50,11 @@ def port_checker(port: int, host: str = "localhost"): def save_temp_img(img: Union[Image.Image, str]) -> str: - os.makedirs("data/temp", exist_ok=True) + temp_dir = os.path.join(get_astrbot_data_path(), "temp") # 获得文件创建时间,清除超过 12 小时的 try: - for f in os.listdir("data/temp"): - path = os.path.join("data/temp", f) + for f in os.listdir(temp_dir): + path = os.path.join(temp_dir, f) if os.path.isfile(path): ctime = os.path.getctime(path) if time.time() - ctime > 3600 * 12: @@ -63,7 +64,7 @@ def save_temp_img(img: Union[Image.Image, str]) -> str: # 获得时间戳 timestamp = f"{int(time.time())}_{uuid.uuid4().hex[:8]}" - p = f"data/temp/{timestamp}.jpg" + p = os.path.join(temp_dir, f"{timestamp}.jpg") if isinstance(img, Image.Image): img.save(p) @@ -201,28 +202,29 @@ def get_local_ip_addresses(): async def get_dashboard_version(): - if os.path.exists("data/dist"): - if os.path.exists("data/dist/assets/version"): - with open("data/dist/assets/version", "r") as f: + dist_dir = os.path.join(get_astrbot_data_path(), "dist") + if os.path.exists(dist_dir): + version_file = os.path.join(dist_dir, "assets", "version") + if os.path.exists(version_file): + with open(version_file, "r") as f: v = f.read().strip() return v return None -async def download_dashboard(path: str = "data/dashboard.zip", extract_path: str = "data"): +async def download_dashboard(path: str = None, extract_path: str = "data"): """下载管理面板文件""" + if path is None: + path = os.path.join(get_astrbot_data_path(), "dashboard.zip") + dashboard_release_url = "https://astrbot-registry.soulter.top/download/astrbot-dashboard/latest/dist.zip" try: - await download_file( - dashboard_release_url, path, show_progress=True - ) + await download_file(dashboard_release_url, path, show_progress=True) except BaseException as _: dashboard_release_url = ( "https://github.com/Soulter/AstrBot/releases/latest/download/dist.zip" ) - await download_file( - dashboard_release_url, path, show_progress=True - ) + await download_file(dashboard_release_url, path, show_progress=True) print("解压管理面板文件中...") with zipfile.ZipFile(path, "r") as z: z.extractall(extract_path) diff --git a/astrbot/core/utils/shared_preferences.py b/astrbot/core/utils/shared_preferences.py index 33a68141..7a503583 100644 --- a/astrbot/core/utils/shared_preferences.py +++ b/astrbot/core/utils/shared_preferences.py @@ -1,9 +1,12 @@ import json import os +from .astrbot_path import get_astrbot_data_path class SharedPreferences: - def __init__(self, path="data/shared_preferences.json"): + def __init__(self, path=None): + if path is None: + path = os.path.join(get_astrbot_data_path(), "shared_preferences.json") self.path = path self._data = self._load_preferences() diff --git a/astrbot/core/utils/t2i/local_strategy.py b/astrbot/core/utils/t2i/local_strategy.py index 514e0dd7..19eab2ef 100644 --- a/astrbot/core/utils/t2i/local_strategy.py +++ b/astrbot/core/utils/t2i/local_strategy.py @@ -1,4 +1,5 @@ import re +import os import aiohttp import ssl import certifi @@ -10,38 +11,40 @@ from astrbot.core.config import VERSION from . import RenderStrategy from PIL import ImageFont, Image, ImageDraw from astrbot.core.utils.io import save_temp_img +from astrbot.core.utils.astrbot_path import get_astrbot_data_path class FontManager: """字体管理类,负责加载和缓存字体""" - + _font_cache = {} - + @classmethod def get_font(cls, size: int) -> ImageFont.FreeTypeFont: """获取指定大小的字体,优先从缓存获取""" if size in cls._font_cache: return cls._font_cache[size] - + # 首先尝试加载自定义字体 try: - font = ImageFont.truetype("data/font.ttf", size) + font_path = os.path.join(get_astrbot_data_path(), "font.ttf") + font = ImageFont.truetype(font_path, size) cls._font_cache[size] = font return font except Exception: pass - + # 跨平台常见字体列表 fonts = [ - "msyh.ttc", # Windows + "msyh.ttc", # Windows "NotoSansCJK-Regular.ttc", # Linux - "msyhbd.ttc", # Windows - "PingFang.ttc", # macOS - "Heiti.ttc", # macOS - "Arial.ttf", # 通用 - "DejaVuSans.ttf", # Linux + "msyhbd.ttc", # Windows + "PingFang.ttc", # macOS + "Heiti.ttc", # macOS + "Arial.ttf", # 通用 + "DejaVuSans.ttf", # Linux ] - + for font_name in fonts: try: font = ImageFont.truetype(font_name, size) @@ -49,7 +52,7 @@ class FontManager: return font except Exception: continue - + # 如果所有字体都失败,使用默认字体 try: default_font = ImageFont.load_default() @@ -61,24 +64,30 @@ class FontManager: class TextMeasurer: """测量文本尺寸的工具类""" - + @staticmethod def get_text_size(text: str, font: ImageFont.FreeTypeFont) -> Tuple[int, int]: """获取文本的尺寸""" try: # PIL 9.0.0 以上版本 - return font.getbbox(text)[2:] if hasattr(font, 'getbbox') else font.getsize(text) + return ( + font.getbbox(text)[2:] + if hasattr(font, "getbbox") + else font.getsize(text) + ) except Exception: # 兼容旧版本 return font.getsize(text) @staticmethod - def split_text_to_fit_width(text: str, font: ImageFont.FreeTypeFont, max_width: int) -> List[str]: + def split_text_to_fit_width( + text: str, font: ImageFont.FreeTypeFont, max_width: int + ) -> List[str]: """将文本拆分为多行,确保每行不超过指定宽度""" lines = [] if not text: return lines - + remaining_text = text while remaining_text: # 如果文本宽度小于最大宽度,直接添加 @@ -86,7 +95,7 @@ class TextMeasurer: if text_width <= max_width: lines.append(remaining_text) break - + # 尝试逐字计算能放入当前行的最多字符 for i in range(len(remaining_text), 0, -1): width = TextMeasurer.get_text_size(remaining_text[:i], font)[0] @@ -98,69 +107,99 @@ class TextMeasurer: # 如果单个字符都放不下,强制放一个字符 lines.append(remaining_text[0]) remaining_text = remaining_text[1:] - + return lines class MarkdownElement(ABC): """Markdown元素的基类""" - + def __init__(self, content: str): self.content = content - + @abstractmethod def calculate_height(self, image_width: int, font_size: int) -> int: """计算元素的高度""" pass - + @abstractmethod - def render(self, image: Image.Image, draw: ImageDraw.Draw, x: int, y: int, image_width: int, font_size: int) -> int: + def render( + self, + image: Image.Image, + draw: ImageDraw.Draw, + x: int, + y: int, + image_width: int, + font_size: int, + ) -> int: """渲染元素到图像,返回新的y坐标""" pass class TextElement(MarkdownElement): """普通文本元素""" - + def calculate_height(self, image_width: int, font_size: int) -> int: if not self.content.strip(): return 10 # 空行高度 - + font = FontManager.get_font(font_size) - lines = TextMeasurer.split_text_to_fit_width(self.content, font, image_width - 20) + lines = TextMeasurer.split_text_to_fit_width( + self.content, font, image_width - 20 + ) return len(lines) * (font_size + 8) - - def render(self, image: Image.Image, draw: ImageDraw.Draw, x: int, y: int, image_width: int, font_size: int) -> int: + + def render( + self, + image: Image.Image, + draw: ImageDraw.Draw, + x: int, + y: int, + image_width: int, + font_size: int, + ) -> int: if not self.content.strip(): return y + 10 # 空行 - + font = FontManager.get_font(font_size) - lines = TextMeasurer.split_text_to_fit_width(self.content, font, image_width - 20) - + lines = TextMeasurer.split_text_to_fit_width( + self.content, font, image_width - 20 + ) + for line in lines: draw.text((x, y), line, font=font, fill=(0, 0, 0)) y += font_size + 8 - + return y class BoldTextElement(MarkdownElement): """粗体文本元素""" - + def calculate_height(self, image_width: int, font_size: int) -> int: font = FontManager.get_font(font_size) - lines = TextMeasurer.split_text_to_fit_width(self.content, font, image_width - 20) + lines = TextMeasurer.split_text_to_fit_width( + self.content, font, image_width - 20 + ) return len(lines) * (font_size + 8) - - def render(self, image: Image.Image, draw: ImageDraw.Draw, x: int, y: int, image_width: int, font_size: int) -> int: + + def render( + self, + image: Image.Image, + draw: ImageDraw.Draw, + x: int, + y: int, + image_width: int, + font_size: int, + ) -> int: # 尝试使用粗体字体,如果没有则绘制两次模拟粗体效果 try: bold_fonts = [ - "msyhbd.ttc", # 微软雅黑粗体 (Windows) + "msyhbd.ttc", # 微软雅黑粗体 (Windows) "Arial-Bold.ttf", # Arial粗体 "DejaVuSans-Bold.ttf", # Linux粗体 ] - + bold_font = None for font_name in bold_fonts: try: @@ -168,48 +207,64 @@ class BoldTextElement(MarkdownElement): break except Exception: continue - + if bold_font: - lines = TextMeasurer.split_text_to_fit_width(self.content, bold_font, image_width - 20) + lines = TextMeasurer.split_text_to_fit_width( + self.content, bold_font, image_width - 20 + ) for line in lines: draw.text((x, y), line, font=bold_font, fill=(0, 0, 0)) y += font_size + 8 else: # 如果没有粗体字体,则绘制两次文本轻微偏移以模拟粗体 font = FontManager.get_font(font_size) - lines = TextMeasurer.split_text_to_fit_width(self.content, font, image_width - 20) + lines = TextMeasurer.split_text_to_fit_width( + self.content, font, image_width - 20 + ) for line in lines: draw.text((x, y), line, font=font, fill=(0, 0, 0)) - draw.text((x+1, y), line, font=font, fill=(0, 0, 0)) + draw.text((x + 1, y), line, font=font, fill=(0, 0, 0)) y += font_size + 8 except Exception: # 兜底方案:使用普通字体 font = FontManager.get_font(font_size) - lines = TextMeasurer.split_text_to_fit_width(self.content, font, image_width - 20) + lines = TextMeasurer.split_text_to_fit_width( + self.content, font, image_width - 20 + ) for line in lines: draw.text((x, y), line, font=font, fill=(0, 0, 0)) y += font_size + 8 - + return y class ItalicTextElement(MarkdownElement): """斜体文本元素""" - + def calculate_height(self, image_width: int, font_size: int) -> int: font = FontManager.get_font(font_size) - lines = TextMeasurer.split_text_to_fit_width(self.content, font, image_width - 20) + lines = TextMeasurer.split_text_to_fit_width( + self.content, font, image_width - 20 + ) return len(lines) * (font_size + 8) - - def render(self, image: Image.Image, draw: ImageDraw.Draw, x: int, y: int, image_width: int, font_size: int) -> int: + + def render( + self, + image: Image.Image, + draw: ImageDraw.Draw, + x: int, + y: int, + image_width: int, + font_size: int, + ) -> int: # 尝试使用斜体字体,如果没有则使用倾斜变换模拟斜体效果 try: italic_fonts = [ - "msyhi.ttc", # 微软雅黑斜体 (Windows) + "msyhi.ttc", # 微软雅黑斜体 (Windows) "Arial-Italic.ttf", # Arial斜体 "DejaVuSans-Oblique.ttf", # Linux斜体 ] - + italic_font = None for font_name in italic_fonts: try: @@ -217,312 +272,388 @@ class ItalicTextElement(MarkdownElement): break except Exception: continue - + if italic_font: - lines = TextMeasurer.split_text_to_fit_width(self.content, italic_font, image_width - 20) + lines = TextMeasurer.split_text_to_fit_width( + self.content, italic_font, image_width - 20 + ) for line in lines: draw.text((x, y), line, font=italic_font, fill=(0, 0, 0)) y += font_size + 8 else: # 如果没有斜体字体,使用变换 font = FontManager.get_font(font_size) - lines = TextMeasurer.split_text_to_fit_width(self.content, font, image_width - 20) - + lines = TextMeasurer.split_text_to_fit_width( + self.content, font, image_width - 20 + ) + for line in lines: # 先创建一个临时图像用于倾斜处理 text_width, text_height = TextMeasurer.get_text_size(line, font) - text_img = Image.new('RGBA', (text_width + 20, text_height + 10), (0, 0, 0, 0)) + text_img = Image.new( + "RGBA", (text_width + 20, text_height + 10), (0, 0, 0, 0) + ) text_draw = ImageDraw.Draw(text_img) text_draw.text((0, 0), line, font=font, fill=(0, 0, 0, 255)) - + # 倾斜变换,使用仿射变换实现斜体效果 # 变换矩阵: [1, 0.2, 0, 0, 1, 0] italic_img = text_img.transform( - text_img.size, - Image.AFFINE, - (1, 0.2, 0, 0, 1, 0), - Image.BICUBIC + text_img.size, Image.AFFINE, (1, 0.2, 0, 0, 1, 0), Image.BICUBIC ) - + # 粘贴到原图像 image.paste(italic_img, (x, y), italic_img) y += font_size + 8 except Exception: # 兜底方案:使用普通字体 font = FontManager.get_font(font_size) - lines = TextMeasurer.split_text_to_fit_width(self.content, font, image_width - 20) + lines = TextMeasurer.split_text_to_fit_width( + self.content, font, image_width - 20 + ) for line in lines: draw.text((x, y), line, font=font, fill=(0, 0, 0)) y += font_size + 8 - + return y class UnderlineTextElement(MarkdownElement): """下划线文本元素""" - + def calculate_height(self, image_width: int, font_size: int) -> int: font = FontManager.get_font(font_size) - lines = TextMeasurer.split_text_to_fit_width(self.content, font, image_width - 20) + lines = TextMeasurer.split_text_to_fit_width( + self.content, font, image_width - 20 + ) return len(lines) * (font_size + 8) - - def render(self, image: Image.Image, draw: ImageDraw.Draw, x: int, y: int, image_width: int, font_size: int) -> int: + + def render( + self, + image: Image.Image, + draw: ImageDraw.Draw, + x: int, + y: int, + image_width: int, + font_size: int, + ) -> int: font = FontManager.get_font(font_size) - lines = TextMeasurer.split_text_to_fit_width(self.content, font, image_width - 20) - + lines = TextMeasurer.split_text_to_fit_width( + self.content, font, image_width - 20 + ) + for line in lines: # 绘制文本 draw.text((x, y), line, font=font, fill=(0, 0, 0)) - + # 绘制下划线 text_width, _ = TextMeasurer.get_text_size(line, font) underline_y = y + font_size + 2 - draw.line((x, underline_y, x + text_width, underline_y), fill=(0, 0, 0), width=1) - + draw.line( + (x, underline_y, x + text_width, underline_y), fill=(0, 0, 0), width=1 + ) + y += font_size + 8 - + return y class StrikethroughTextElement(MarkdownElement): """删除线文本元素""" - + def calculate_height(self, image_width: int, font_size: int) -> int: font = FontManager.get_font(font_size) - lines = TextMeasurer.split_text_to_fit_width(self.content, font, image_width - 20) + lines = TextMeasurer.split_text_to_fit_width( + self.content, font, image_width - 20 + ) return len(lines) * (font_size + 8) - - def render(self, image: Image.Image, draw: ImageDraw.Draw, x: int, y: int, image_width: int, font_size: int) -> int: + + def render( + self, + image: Image.Image, + draw: ImageDraw.Draw, + x: int, + y: int, + image_width: int, + font_size: int, + ) -> int: font = FontManager.get_font(font_size) - lines = TextMeasurer.split_text_to_fit_width(self.content, font, image_width - 20) - + lines = TextMeasurer.split_text_to_fit_width( + self.content, font, image_width - 20 + ) + for line in lines: # 绘制文本 draw.text((x, y), line, font=font, fill=(0, 0, 0)) - + # 绘制删除线 text_width, _ = TextMeasurer.get_text_size(line, font) strike_y = y + font_size // 2 draw.line((x, strike_y, x + text_width, strike_y), fill=(0, 0, 0), width=1) - + y += font_size + 8 - + return y class HeaderElement(MarkdownElement): """标题元素""" - + def __init__(self, content: str): # 去除开头的 # 并计算级别 level = 0 for char in content: - if char == '#': + if char == "#": level += 1 else: break - + super().__init__(content[level:].strip()) self.level = min(level, 6) # h1-h6 - + def calculate_height(self, image_width: int, font_size: int) -> int: header_font_size = 42 - (self.level - 1) * 4 font = FontManager.get_font(header_font_size) - lines = TextMeasurer.split_text_to_fit_width(self.content, font, image_width - 20) + lines = TextMeasurer.split_text_to_fit_width( + self.content, font, image_width - 20 + ) return len(lines) * header_font_size + 30 # 包含上下间距和分隔线 - - def render(self, image: Image.Image, draw: ImageDraw.Draw, x: int, y: int, image_width: int, font_size: int) -> int: + + def render( + self, + image: Image.Image, + draw: ImageDraw.Draw, + x: int, + y: int, + image_width: int, + font_size: int, + ) -> int: header_font_size = 42 - (self.level - 1) * 4 font = FontManager.get_font(header_font_size) - + y += 10 # 上间距 draw.text((x, y), self.content, font=font, fill=(0, 0, 0)) - + # 添加分隔线 y += header_font_size + 8 - draw.line( - (x, y, image_width - 10, y), - fill=(230, 230, 230), - width=3 - ) - + draw.line((x, y, image_width - 10, y), fill=(230, 230, 230), width=3) + return y + 10 # 返回包含下间距的新y坐标 class QuoteElement(MarkdownElement): """引用元素""" - + def __init__(self, content: str): # 去除开头的 > super().__init__(content[1:].strip()) - + def calculate_height(self, image_width: int, font_size: int) -> int: font = FontManager.get_font(font_size) - lines = TextMeasurer.split_text_to_fit_width(self.content, font, image_width - 30) # 左边留出引用线的空间 + lines = TextMeasurer.split_text_to_fit_width( + self.content, font, image_width - 30 + ) # 左边留出引用线的空间 return len(lines) * (font_size + 6) + 12 # 包含上下间距 - - def render(self, image: Image.Image, draw: ImageDraw.Draw, x: int, y: int, image_width: int, font_size: int) -> int: + + def render( + self, + image: Image.Image, + draw: ImageDraw.Draw, + x: int, + y: int, + image_width: int, + font_size: int, + ) -> int: font = FontManager.get_font(font_size) - lines = TextMeasurer.split_text_to_fit_width(self.content, font, image_width - 30) - + lines = TextMeasurer.split_text_to_fit_width( + self.content, font, image_width - 30 + ) + total_height = len(lines) * (font_size + 6) - + # 绘制引用线 quote_line_x = x + 3 draw.line( (quote_line_x, y + 6, quote_line_x, y + total_height + 6), fill=(180, 180, 180), - width=5 + width=5, ) - + # 绘制文本 text_x = x + 15 text_y = y + 6 for line in lines: draw.text((text_x, text_y), line, font=font, fill=(180, 180, 180)) text_y += font_size + 6 - + return y + total_height + 12 class ListItemElement(MarkdownElement): """列表项元素""" - + def calculate_height(self, image_width: int, font_size: int) -> int: font = FontManager.get_font(font_size) - lines = TextMeasurer.split_text_to_fit_width(self.content, font, image_width - 30) # 左边留出项目符号的空间 + lines = TextMeasurer.split_text_to_fit_width( + self.content, font, image_width - 30 + ) # 左边留出项目符号的空间 return len(lines) * (font_size + 6) + 16 # 包含上下间距 - - def render(self, image: Image.Image, draw: ImageDraw.Draw, x: int, y: int, image_width: int, font_size: int) -> int: + + def render( + self, + image: Image.Image, + draw: ImageDraw.Draw, + x: int, + y: int, + image_width: int, + font_size: int, + ) -> int: font = FontManager.get_font(font_size) - lines = TextMeasurer.split_text_to_fit_width(self.content, font, image_width - 30) - + lines = TextMeasurer.split_text_to_fit_width( + self.content, font, image_width - 30 + ) + y += 8 # 上间距 - + # 绘制项目符号 bullet_x = x + 5 draw.text((bullet_x, y), "•", font=font, fill=(0, 0, 0)) - + # 绘制文本 text_x = x + 25 text_y = y for line in lines: draw.text((text_x, text_y), line, font=font, fill=(0, 0, 0)) text_y += font_size + 6 - + return text_y + 8 # 包含下间距 class CodeBlockElement(MarkdownElement): """代码块元素""" - + def __init__(self, content: List[str]): super().__init__("\n".join(content)) - + def calculate_height(self, image_width: int, font_size: int) -> int: if not self.content: return 40 # 空代码块的最小高度 - + font = FontManager.get_font(font_size) lines = self.content.split("\n") wrapped_lines = [] - + for line in lines: wrapped = TextMeasurer.split_text_to_fit_width(line, font, image_width - 40) wrapped_lines.extend(wrapped) - + return len(wrapped_lines) * (font_size + 4) + 40 # 包含内边距和上下间距 - - def render(self, image: Image.Image, draw: ImageDraw.Draw, x: int, y: int, image_width: int, font_size: int) -> int: + + def render( + self, + image: Image.Image, + draw: ImageDraw.Draw, + x: int, + y: int, + image_width: int, + font_size: int, + ) -> int: font = FontManager.get_font(font_size) lines = self.content.split("\n") wrapped_lines = [] - + for line in lines: wrapped = TextMeasurer.split_text_to_fit_width(line, font, image_width - 40) wrapped_lines.extend(wrapped) - + content_height = len(wrapped_lines) * (font_size + 4) total_height = content_height + 30 # 包含内边距 - + # 绘制背景 draw.rounded_rectangle( (x, y + 5, image_width - 10, y + total_height), radius=5, fill=(240, 240, 240), - width=1 + width=1, ) - + # 绘制代码 text_y = y + 15 for line in wrapped_lines: draw.text((x + 15, text_y), line, font=font, fill=(0, 0, 0)) text_y += font_size + 4 - + return y + total_height + 10 class InlineCodeElement(MarkdownElement): """行内代码元素""" - + def calculate_height(self, image_width: int, font_size: int) -> int: return font_size + 16 # 包含内边距和上下间距 - - def render(self, image: Image.Image, draw: ImageDraw.Draw, x: int, y: int, image_width: int, font_size: int) -> int: + + def render( + self, + image: Image.Image, + draw: ImageDraw.Draw, + x: int, + y: int, + image_width: int, + font_size: int, + ) -> int: font = FontManager.get_font(font_size) - + # 计算文本大小 text_width, _ = TextMeasurer.get_text_size(self.content, font) text_height = font_size - + # 绘制背景 padding = 4 draw.rounded_rectangle( - ( - x, - y + 4, - x + text_width + padding * 2, - y + text_height + padding * 2 + 4 - ), + (x, y + 4, x + text_width + padding * 2, y + text_height + padding * 2 + 4), radius=5, fill=(230, 230, 230), - width=1 + width=1, ) - + # 绘制文本 - draw.text((x + padding, y + padding + 4), self.content, font=font, fill=(0, 0, 0)) - + draw.text( + (x + padding, y + padding + 4), self.content, font=font, fill=(0, 0, 0) + ) + return y + text_height + 16 # 返回新的y坐标 class ImageElement(MarkdownElement): """图片元素""" - + def __init__(self, content: str, image_url: str): super().__init__(content) self.image_url = image_url self.image = None - + async def load_image(self): """加载图片""" try: ssl_context = ssl.create_default_context(cafile=certifi.where()) connector = aiohttp.TCPConnector(ssl=ssl_context) - - async with aiohttp.ClientSession(trust_env=True, connector=connector) as session: + + async with aiohttp.ClientSession( + trust_env=True, connector=connector + ) as session: async with session.get(self.image_url) as resp: - if (resp.status == 200): + if resp.status == 200: image_data = await resp.read() self.image = Image.open(BytesIO(image_data)) else: print(f"Failed to load image: HTTP {resp.status}") except Exception as e: print(f"Failed to load image: {e}") - + def calculate_height(self, image_width: int, font_size: int) -> int: if self.image is None: return font_size + 20 # 图片加载失败的默认高度 - + # 计算调整大小后的图片高度 max_width = image_width * 0.8 if self.image.width > max_width: @@ -530,52 +661,60 @@ class ImageElement(MarkdownElement): height = int(self.image.height * ratio) else: height = self.image.height - + return height + 30 # 包含上下间距 - - def render(self, image: Image.Image, draw: ImageDraw.Draw, x: int, y: int, image_width: int, font_size: int) -> int: + + def render( + self, + image: Image.Image, + draw: ImageDraw.Draw, + x: int, + y: int, + image_width: int, + font_size: int, + ) -> int: if self.image is None: # 图片加载失败 font = FontManager.get_font(font_size) draw.text((x, y + 10), "[图片加载失败]", font=font, fill=(255, 0, 0)) return y + font_size + 20 - + # 调整图片大小 max_width = image_width * 0.8 pasted_image = self.image - + if pasted_image.width > max_width: ratio = max_width / pasted_image.width new_size = (int(max_width), int(pasted_image.height * ratio)) pasted_image = pasted_image.resize(new_size, Image.LANCZOS) - + # 计算居中位置 paste_x = x + (image_width - pasted_image.width) // 2 - 10 - + # 粘贴图片 - if pasted_image.mode == 'RGBA': + if pasted_image.mode == "RGBA": # 处理透明图片 image.paste(pasted_image, (paste_x, y + 15), pasted_image) else: image.paste(pasted_image, (paste_x, y + 15)) - + return y + pasted_image.height + 30 class MarkdownParser: """Markdown解析器,将文本解析为元素""" - + @staticmethod async def parse(text: str) -> List[MarkdownElement]: elements = [] - lines = text.split('\n') - + lines = text.split("\n") + i = 0 while i < len(lines): line = lines[i].rstrip() - + # 图片检测 - image_match = re.search(r'!\s*\[(.*?)\]\s*\((.*?)\)', line) + image_match = re.search(r"!\s*\[(.*?)\]\s*\((.*?)\)", line) if image_match: image_url = image_match.group(2) element = ImageElement(line, image_url) @@ -583,101 +722,108 @@ class MarkdownParser: elements.append(element) i += 1 continue - + # 标题 - if line.startswith('#'): + if line.startswith("#"): elements.append(HeaderElement(line)) i += 1 continue - + # 引用 - if line.startswith('>'): + if line.startswith(">"): elements.append(QuoteElement(line)) i += 1 continue - + # 列表项 - if line.startswith('-') or line.startswith('*'): + if line.startswith("-") or line.startswith("*"): elements.append(ListItemElement(line[1:].strip())) i += 1 continue - + # 代码块 - if line.startswith('```'): + if line.startswith("```"): code_lines = [] i += 1 # 跳过开始标记行 - - while i < len(lines) and not lines[i].startswith('```'): + + while i < len(lines) and not lines[i].startswith("```"): code_lines.append(lines[i]) i += 1 - + i += 1 # 跳过结束标记行 elements.append(CodeBlockElement(code_lines)) continue - + # 检查行内样式(粗体、斜体、下划线、删除线、行内代码) - if re.search(r'(\*\*.*?\*\*)|(\*.*?\*)|(__.*?__)|(_.*?_)|(~~.*?~~)|(`.*?`)', line): + if re.search( + r"(\*\*.*?\*\*)|(\*.*?\*)|(__.*?__)|(_.*?_)|(~~.*?~~)|(`.*?`)", line + ): # 分析行内样式: # - 粗体: **text** 或 __text__ # - 斜体: *text* 或 _text_ # - 删除线: ~~text~~ # - 行内代码: `text` - + # 定义正则模式和对应的元素类型 patterns = [ - (r'\*\*(.*?)\*\*', BoldTextElement), # **粗体** - (r'__(.*?)__', BoldTextElement), # __粗体__ - (r'\*((?!\*\*).*?)\*', ItalicTextElement), # *斜体* (但不匹配 ** 开头) - (r'_((?!__).*?)_', ItalicTextElement), # _斜体_ (但不匹配 __ 开头) - (r'~~(.*?)~~', StrikethroughTextElement), # ~~删除线~~ - (r'__(.*?)__', UnderlineTextElement), # __下划线__ - (r'`(.*?)`', InlineCodeElement) # `行内代码` + (r"\*\*(.*?)\*\*", BoldTextElement), # **粗体** + (r"__(.*?)__", BoldTextElement), # __粗体__ + ( + r"\*((?!\*\*).*?)\*", + ItalicTextElement, + ), # *斜体* (但不匹配 ** 开头) + (r"_((?!__).*?)_", ItalicTextElement), # _斜体_ (但不匹配 __ 开头) + (r"~~(.*?)~~", StrikethroughTextElement), # ~~删除线~~ + (r"__(.*?)__", UnderlineTextElement), # __下划线__ + (r"`(.*?)`", InlineCodeElement), # `行内代码` ] - + # 创建标记位置列表 markers = [] for pattern, element_class in patterns: for match in re.finditer(pattern, line): - markers.append({ - 'start': match.start(), - 'end': match.end(), - 'text': match.group(1), # 提取内容部分 - 'element_class': element_class - }) - + markers.append( + { + "start": match.start(), + "end": match.end(), + "text": match.group(1), # 提取内容部分 + "element_class": element_class, + } + ) + # 按开始位置排序 - markers.sort(key=lambda x: x['start']) - + markers.sort(key=lambda x: x["start"]) + # 如果没有找到任何匹配,直接添加为普通文本 if not markers: elements.append(TextElement(line)) i += 1 continue - + # 处理每个文本片段 current_pos = 0 for marker in markers: # 添加前面的普通文本 - if marker['start'] > current_pos: - normal_text = line[current_pos:marker['start']] + if marker["start"] > current_pos: + normal_text = line[current_pos : marker["start"]] if normal_text: elements.append(TextElement(normal_text)) - + # 添加特殊样式的文本 - elements.append(marker['element_class'](marker['text'])) - current_pos = marker['end'] - + elements.append(marker["element_class"](marker["text"])) + current_pos = marker["end"] + # 添加最后一段普通文本 if current_pos < len(line): elements.append(TextElement(line[current_pos:])) - + i += 1 continue - + # 行内代码 (如果之前没匹配到混合样式) - inline_code_matches = re.findall(r'`([^`]+)`', line) + inline_code_matches = re.findall(r"`([^`]+)`", line) if inline_code_matches: - parts = re.split(r'`([^`]+)`', line) + parts = re.split(r"`([^`]+)`", line) for j, part in enumerate(parts): if j % 2 == 0: # 普通文本 if part: @@ -686,88 +832,90 @@ class MarkdownParser: elements.append(InlineCodeElement(part)) i += 1 continue - + # 普通文本 elements.append(TextElement(line)) i += 1 - + return elements class MarkdownRenderer: """Markdown渲染器,将元素渲染为图像""" - - def __init__(self, font_size: int = 26, width: int = 800, bg_color: Tuple[int, int, int] = (255, 255, 255)): + + def __init__( + self, + font_size: int = 26, + width: int = 800, + bg_color: Tuple[int, int, int] = (255, 255, 255), + ): self.font_size = font_size self.width = width self.bg_color = bg_color - + async def render(self, markdown_text: str) -> Image.Image: # 解析Markdown文本 elements = await MarkdownParser.parse(markdown_text) - + # 计算总高度 total_height = 20 # 初始边距 for element in elements: total_height += element.calculate_height(self.width, self.font_size) - + # 为页脚添加额外空间 footer_height = 40 total_height += 20 + footer_height # 结束边距 + 页脚高度 - + # 创建图像 - image = Image.new('RGB', (self.width, max(100, total_height)), self.bg_color) + image = Image.new("RGB", (self.width, max(100, total_height)), self.bg_color) draw = ImageDraw.Draw(image) - + # 渲染元素 y = 10 for element in elements: y = element.render(image, draw, 10, y, self.width, self.font_size) - + # 添加页脚 # 克莱因蓝色,近似RGB为(0, 47, 167) klein_blue = (0, 47, 167) # 灰色 grey_color = (130, 130, 130) - + # 绘制"Powered by AstrBot"文本 footer_font_size = 20 footer_font = FontManager.get_font(footer_font_size) - + # 获取"Powered by "和"AstrBot"的宽度以便居中 powered_by_text = "Powered by " astrbot_text = f"AstrBot v{VERSION}" - + powered_by_width, _ = TextMeasurer.get_text_size(powered_by_text, footer_font) astrbot_width, _ = TextMeasurer.get_text_size(astrbot_text, footer_font) - + total_width = powered_by_width + astrbot_width x_start = (self.width - total_width) // 2 - + footer_y = total_height - footer_height - + # 绘制"Powered by "(灰色) draw.text( - (x_start, footer_y), - powered_by_text, - font=footer_font, - fill=grey_color + (x_start, footer_y), powered_by_text, font=footer_font, fill=grey_color ) - + # 绘制"AstrBot"(克莱因蓝) draw.text( - (x_start + powered_by_width, footer_y), - astrbot_text, - font=footer_font, - fill=klein_blue + (x_start + powered_by_width, footer_y), + astrbot_text, + font=footer_font, + fill=klein_blue, ) - + return image class LocalRenderStrategy(RenderStrategy): """本地渲染策略实现""" - + async def render_custom_template( self, tmpl_str: str, tmpl_data: dict, return_url: bool = True ) -> str: @@ -776,9 +924,9 @@ class LocalRenderStrategy(RenderStrategy): async def render(self, text: str, return_url: bool = False) -> str: # 创建渲染器 renderer = MarkdownRenderer(font_size=26, width=800) - + # 渲染Markdown文本 image = await renderer.render(text) - + # 保存图像并返回路径/URL return save_temp_img(image) diff --git a/astrbot/dashboard/routes/chat.py b/astrbot/dashboard/routes/chat.py index d767ddea..17e8b115 100644 --- a/astrbot/dashboard/routes/chat.py +++ b/astrbot/dashboard/routes/chat.py @@ -8,6 +8,7 @@ from astrbot.core.db import BaseDatabase import asyncio from astrbot.core import logger from astrbot.core.core_lifecycle import AstrBotCoreLifecycle +from astrbot.core.utils.astrbot_path import get_astrbot_data_path class ChatRoute(Route): @@ -33,7 +34,8 @@ class ChatRoute(Route): self.db = db self.core_lifecycle = core_lifecycle self.register_routes() - self.imgs_dir = "data/webchat/imgs" + self.imgs_dir = os.path.join(get_astrbot_data_path(), "webchat", "imgs") + os.makedirs(self.imgs_dir, exist_ok=True) self.supported_imgs = ["jpg", "jpeg", "png", "gif", "webp"] diff --git a/packages/astrbot/main.py b/packages/astrbot/main.py index 9dcd4a68..8f8e1fa1 100644 --- a/packages/astrbot/main.py +++ b/packages/astrbot/main.py @@ -1,3 +1,4 @@ +import os import aiohttp import datetime import builtins @@ -13,6 +14,7 @@ from astrbot.core.platform.astr_message_event import MessageSesion from astrbot.core.platform.message_type import MessageType from astrbot.core.provider.sources.dify_source import ProviderDify from astrbot.core.utils.io import download_dashboard, get_dashboard_version +from astrbot.core.utils.astrbot_path import get_astrbot_data_path from astrbot.core.star.star_handler import star_handlers_registry, StarHandlerMetadata from astrbot.core.star.star import star_map from astrbot.core.star.star_manager import PluginManager @@ -1159,7 +1161,8 @@ UID: {user_id} 此 ID 可用于设置管理员。 @filter.command("gewe_code") async def gewe_code(self, event: AstrMessageEvent, code: str): """保存 gewechat 验证码""" - with open("data/temp/gewe_code", "w", encoding="utf-8") as f: + code_path = os.path.join(get_astrbot_data_path(), "temp","gewe_code") + with open(code_path, "w", encoding="utf-8") as f: f.write(code) yield event.plain_result("验证码已保存。") diff --git a/packages/python_interpreter/main.py b/packages/python_interpreter/main.py index 20eae0c3..84d431b3 100644 --- a/packages/python_interpreter/main.py +++ b/packages/python_interpreter/main.py @@ -15,6 +15,7 @@ from astrbot.api.event import filter from astrbot.api.provider import ProviderRequest from astrbot.api.message_components import Image, File from astrbot.core.utils.io import download_image_by_url, download_file +from astrbot.core.utils.astrbot_path import get_astrbot_data_path PROMPT = """ ## Task @@ -90,7 +91,7 @@ DEFAULT_CONFIG = { }, "docker_host_astrbot_abs_path": "", } -PATH = "data/config/python_interpreter.json" +PATH = os.path.join(get_astrbot_data_path(), "config", "python_interpreter.json") @star.register( @@ -212,7 +213,8 @@ class Main(star.Star): if isinstance(comp, File): if comp.file.startswith("http"): name = comp.name if comp.name else uuid.uuid4().hex[:8] - path = f"data/temp/{name}" + temp_dir = os.path.join(get_astrbot_data_path(), "temp") + path = os.path.join(temp_dir, name) await download_file(comp.file, path) else: path = comp.file diff --git a/packages/reminder/main.py b/packages/reminder/main.py index d72624ef..b15add54 100644 --- a/packages/reminder/main.py +++ b/packages/reminder/main.py @@ -8,6 +8,7 @@ from astrbot.api.event import filter from apscheduler.schedulers.asyncio import AsyncIOScheduler from astrbot.api.event import AstrMessageEvent, MessageEventResult from astrbot.api import llm_tool, logger +from astrbot.core.utils.astrbot_path import get_astrbot_data_path @star.register( @@ -29,10 +30,11 @@ class Main(star.Star): self.scheduler = AsyncIOScheduler(timezone=self.timezone) # set and load config - if not os.path.exists("data/astrbot-reminder.json"): - with open("data/astrbot-reminder.json", "w", encoding="utf-8") as f: + reminder_file = os.path.join(get_astrbot_data_path(), "astrbot-reminder.json") + if not os.path.exists(reminder_file): + with open(reminder_file, "w", encoding="utf-8") as f: f.write("{}") - with open("data/astrbot-reminder.json", "r", encoding="utf-8") as f: + with open(reminder_file, "r", encoding="utf-8") as f: self.reminder_data = json.load(f) self._init_scheduler() @@ -82,7 +84,8 @@ class Main(star.Star): async def _save_data(self): """Save the reminder data.""" - with open("data/astrbot-reminder.json", "w", encoding="utf-8") as f: + reminder_file = os.path.join(get_astrbot_data_path(), "astrbot-reminder.json") + with open(reminder_file, "w", encoding="utf-8") as f: json.dump(self.reminder_data, f, ensure_ascii=False) def _parse_cron_expr(self, cron_expr: str):