diff --git a/.env b/.env new file mode 100644 index 00000000..ab192957 --- /dev/null +++ b/.env @@ -0,0 +1 @@ +ASTRBOT_ROOT = ./data \ No newline at end of file diff --git a/.env.example b/.env.example new file mode 100644 index 00000000..a0507a12 --- /dev/null +++ b/.env.example @@ -0,0 +1,2 @@ +# ASTRBOT 数据目录 +# ASTRBOT_ROOT = ./data diff --git a/astrbot/base/README.md b/astrbot/base/README.md new file mode 100644 index 00000000..b62e9f48 --- /dev/null +++ b/astrbot/base/README.md @@ -0,0 +1,3 @@ +# Base 包 + +- 此包内容仅可单向导出 diff --git a/astrbot/base/__init__.py b/astrbot/base/__init__.py new file mode 100644 index 00000000..466c3e47 --- /dev/null +++ b/astrbot/base/__init__.py @@ -0,0 +1,7 @@ +from .abc import IAstrbotPaths +from .paths import AstrbotPaths + +__all__ = [ + "IAstrbotPaths", + "AstrbotPaths", +] \ No newline at end of file diff --git a/astrbot/base/abc.py b/astrbot/base/abc.py new file mode 100644 index 00000000..d823beaa --- /dev/null +++ b/astrbot/base/abc.py @@ -0,0 +1,61 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from contextlib import AbstractAsyncContextManager, AbstractContextManager +from pathlib import Path + + +# TODO: 抽象基类 +class IAstrbotPaths(ABC): + """路径管理的抽象基类.""" + + @abstractmethod + def __init__(self, name: str) -> None: + """初始化路径管理器.""" + + @classmethod + @abstractmethod + def getPaths(cls, name: str) -> IAstrbotPaths: + """返回Paths实例,用于访问模块的各类目录.""" + + @property + @abstractmethod + def root(self) -> Path: + """获取根目录.""" + + @property + @abstractmethod + def home(self) -> Path: + """获取模块/插件主目录.""" + + @property + @abstractmethod + def config(self) -> Path: + """获取模块配置目录.""" + + @property + @abstractmethod + def data(self) -> Path: + """获取模块数据目录.""" + + @property + @abstractmethod + def log(self) -> Path: + """获取模块日志目录.""" + + @abstractmethod + def reload(self) -> None: + """重新加载环境变量.""" + + @abstractmethod + @classmethod + def is_root(cls, path: Path) -> bool: + """判断路径是否为根目录.""" + + @abstractmethod + def chdir(self, cwd: str = "home") -> AbstractContextManager[Path]: + """临时切换到指定目录, 子进程将继承此 CWD。""" + + @abstractmethod + async def achdir(self, cwd: str = "home") -> AbstractAsyncContextManager[Path]: + """异步临时切换到指定目录, 子进程将继承此 CWD。""" diff --git a/astrbot/base/paths.py b/astrbot/base/paths.py new file mode 100644 index 00000000..c4e98c05 --- /dev/null +++ b/astrbot/base/paths.py @@ -0,0 +1,121 @@ +from __future__ import annotations + +import os +from contextlib import ( + asynccontextmanager, + contextmanager, +) +from os import getenv +from pathlib import Path +from typing import TYPE_CHECKING, ClassVar + +from dotenv import load_dotenv +from packaging.utils import NormalizedName, canonicalize_name + +from astrbot.base.abc import IAstrbotPaths + +if TYPE_CHECKING: + from collections.abc import AsyncGenerator, Generator + + +class AstrbotPaths(IAstrbotPaths): + """Class to manage and provide paths used by Astrbot Canary.""" + + load_dotenv() + astrbot_root: ClassVar[Path] = Path( + getenv("ASTRBOT_ROOT", Path.home() / ".astrbot") + ).absolute() + + def __init__(self, name: str) -> None: + self.name: str = name + # 确保根目录存在 + self.astrbot_root.mkdir(parents=True, exist_ok=True) + + @classmethod + def getPaths(cls, name: str) -> AstrbotPaths: + """返回Paths实例,用于访问模块的各类目录.""" + normalized_name: NormalizedName = canonicalize_name(name) + instance: AstrbotPaths = cls(normalized_name) + instance.name = normalized_name + return instance + + @property + def root(self) -> Path: + """返回根目录.""" + return ( + self.astrbot_root if self.astrbot_root.exists() else Path.cwd() / ".astrbot" + ) + + @property + def home(self) -> Path: + """模块/插件主目录. + + 通过此属性获取模块/插件主目录. + """ + my_home = self.astrbot_root / "home" / self.name + my_home.mkdir(parents=True, exist_ok=True) + return my_home + + @property + def config(self) -> Path: + """返回模块/插件配置目录. + + 搭配 astrbot_config 使用. + """ + config_path = self.astrbot_root / "config" / self.name + config_path.mkdir(parents=True, exist_ok=True) + return config_path + + @property + def data(self) -> Path: + """返回模块数据目录.""" + data_path = self.astrbot_root / "data" / self.name + data_path.mkdir(parents=True, exist_ok=True) + return data_path + + @property + def log(self) -> Path: + """返回模块日志目录.""" + log_path = self.astrbot_root / "logs" / self.name + log_path.mkdir(parents=True, exist_ok=True) + return log_path + + @classmethod + def is_root(cls, path: Path) -> bool: + """检查路径是否为 Astrbot 根目录.""" + if not path.exists() or not path.is_dir(): + return False + # 检查此目录内是是否包含.astrbot标记文件 + if not (path / ".astrbot").exists(): + return False + return True + + def reload(self) -> None: + """重新加载环境变量.""" + load_dotenv() + self.__class__.astrbot_root = Path( + getenv("ASTRBOT_ROOT", Path.home() / ".astrbot") + ).absolute() + + @contextmanager + def chdir(self, cwd: str = "home") -> Generator[Path]: + """临时切换到指定目录, 子进程将继承此 CWD。""" + original_cwd = Path.cwd() + target_dir = self.root / cwd + try: + os.chdir(target_dir) + yield target_dir + finally: + os.chdir(original_cwd) + + # 上面类型标注没错,这里mypy报错,但是这不应该错误,直接忽略掉 + @asynccontextmanager + async def achdir(self, cwd: str = "home") -> AsyncGenerator[Path]: # type: ignore + """异步上下文管理器: 临时切换到指定目录, 子进程将继承此 CWD。""" + original_cwd = Path.cwd() + target_dir = self.root / cwd + try: + os.chdir(target_dir) + yield target_dir + finally: + os.chdir(original_cwd) diff --git a/astrbot/cli/utils/basic.py b/astrbot/cli/utils/basic.py index 5dbe2900..84535564 100644 --- a/astrbot/cli/utils/basic.py +++ b/astrbot/cli/utils/basic.py @@ -1,22 +1,28 @@ +import warnings from pathlib import Path import click +from astrbot.base import AstrbotPaths + def check_astrbot_root(path: str | Path) -> bool: """检查路径是否为 AstrBot 根目录""" - if not isinstance(path, Path): - path = Path(path) - if not path.exists() or not path.is_dir(): - return False - if not (path / ".astrbot").exists(): - return False - return True - + warnings.warn( + "请使用 AstrbotPaths 类代替本模块中的函数", + DeprecationWarning, + stacklevel=2, + ) + return AstrbotPaths.is_root(Path(path)) def get_astrbot_root() -> Path: """获取Astrbot根目录路径""" - return Path.cwd() + warnings.warn( + "请使用 AstrbotPaths 类代替本模块中的函数", + DeprecationWarning, + stacklevel=2, + ) + return AstrbotPaths.astrbot_root async def check_dashboard(astrbot_root: Path) -> None: diff --git a/astrbot/core/pipeline/respond/stage.py b/astrbot/core/pipeline/respond/stage.py index f2044559..107c82a5 100644 --- a/astrbot/core/pipeline/respond/stage.py +++ b/astrbot/core/pipeline/respond/stage.py @@ -1,7 +1,6 @@ import asyncio import math import random -from collections.abc import AsyncGenerator import astrbot.core.message.components as Comp from astrbot.core import logger @@ -153,7 +152,7 @@ class RespondStage(Stage): async def process( self, event: AstrMessageEvent, - ) -> None | AsyncGenerator[None, None]: + ) -> None: result = event.get_result() if result is None: return diff --git a/astrbot/core/updator.py b/astrbot/core/updator.py index d13bab68..3deb1126 100644 --- a/astrbot/core/updator.py +++ b/astrbot/core/updator.py @@ -6,7 +6,7 @@ import psutil from astrbot.core import logger from astrbot.core.config.default import VERSION -from astrbot.core.utils.astrbot_path import get_astrbot_path +from astrbot.core.utils.astrbot_path import get_astrbot_data_path from astrbot.core.utils.io import download_file from .zip_updator import ReleaseInfo, RepoZipUpdator @@ -20,7 +20,7 @@ class AstrBotUpdator(RepoZipUpdator): def __init__(self, repo_mirror: str = "") -> None: super().__init__(repo_mirror) - self.MAIN_PATH = get_astrbot_path() + self.astrbot_root = get_astrbot_data_path() self.ASTRBOT_RELEASE_API = "https://api.soulter.top/releases" def terminate_child_processes(self): @@ -117,7 +117,7 @@ class AstrBotUpdator(RepoZipUpdator): try: await download_file(file_url, "temp.zip") logger.info("下载 AstrBot Core 更新文件完成,正在执行解压...") - self.unzip_file("temp.zip", self.MAIN_PATH) + self.unzip_file("temp.zip", self.astrbot_root) except BaseException as e: raise e diff --git a/astrbot/core/utils/astrbot_path.py b/astrbot/core/utils/astrbot_path.py index e13379b9..8301438c 100644 --- a/astrbot/core/utils/astrbot_path.py +++ b/astrbot/core/utils/astrbot_path.py @@ -8,32 +8,63 @@ """ import os +import warnings + +from astrbot.base import AstrbotPaths def get_astrbot_path() -> str: - """获取Astrbot项目路径""" + """获取Astrbot项目路径 -- 请勿使用本函数!!! -- 仅供兼容旧代码使用""" + warnings.warn( + "get_astrbot_path is deprecated. Use AstrbotPaths class instead.", + DeprecationWarning, + stacklevel=2, + ) return os.path.realpath( os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../../"), ) def get_astrbot_root() -> str: - """获取Astrbot根目录路径""" - if path := os.environ.get("ASTRBOT_ROOT"): - return os.path.realpath(path) - return os.path.realpath(os.getcwd()) + """获取Astrbot根目录路径 --> get_astrbot_data_path""" + warnings.warn( + "不要再使用本函数!实际上就是获取data目录!等效于: AstrbotPaths.getPaths('any!').root", + DeprecationWarning, + stacklevel=2, + ) + return str(AstrbotPaths.astrbot_root) def get_astrbot_data_path() -> str: - """获取Astrbot数据目录路径""" - return os.path.realpath(os.path.join(get_astrbot_root(), "data")) + """获取Astrbot数据目录路径 + 特别注意! + 这里的data目录指的就是.astrbot根目录! + 两者是等价的! + 不要和AstrbotPaths.data混淆! + """ + warnings.warn( + "等效于: AstrbotPaths.getPaths('any!').root.env 文件内容: ASTRBOT_ROOT=./data", + DeprecationWarning, + stacklevel=2, + ) + return str(AstrbotPaths.astrbot_root) def get_astrbot_config_path() -> str: """获取Astrbot配置文件路径""" - return os.path.realpath(os.path.join(get_astrbot_data_path(), "config")) + warnings.warn( + "get_astrbot_config_path is deprecated. Use AstrbotPaths class instead.", + DeprecationWarning, + stacklevel=2, + ) + return str(AstrbotPaths.astrbot_root / "config") def get_astrbot_plugin_path() -> str: """获取Astrbot插件目录路径""" - return os.path.realpath(os.path.join(get_astrbot_data_path(), "plugins")) + warnings.warn( + "get_astrbot_plugin_path is deprecated. Use AstrbotPaths class instead.", + DeprecationWarning, + stacklevel=2, + ) + return str(AstrbotPaths.astrbot_root / "plugins") diff --git a/astrbot/core/utils/t2i/template_manager.py b/astrbot/core/utils/t2i/template_manager.py index 6d44f735..883bc276 100644 --- a/astrbot/core/utils/t2i/template_manager.py +++ b/astrbot/core/utils/t2i/template_manager.py @@ -2,8 +2,9 @@ import os import shutil +from importlib import resources -from astrbot.core.utils.astrbot_path import get_astrbot_data_path, get_astrbot_path +from astrbot.core.utils.astrbot_path import get_astrbot_data_path class TemplateManager: @@ -15,14 +16,10 @@ class TemplateManager: CORE_TEMPLATES = ["base.html", "astrbot_powershell.html"] def __init__(self): - self.builtin_template_dir = os.path.join( - get_astrbot_path(), - "astrbot", - "core", - "utils", - "t2i", - "template", + self.builtin_template_dir = str( + resources.files("astrbot.core.utils.t2i.template") ) + self.user_template_dir = os.path.join(get_astrbot_data_path(), "t2i_templates") os.makedirs(self.user_template_dir, exist_ok=True) diff --git a/astrbot/dashboard/routes/config.py b/astrbot/dashboard/routes/config.py index b947d26f..cdea62b4 100644 --- a/astrbot/dashboard/routes/config.py +++ b/astrbot/dashboard/routes/config.py @@ -1,4 +1,5 @@ import asyncio +import importlib.resources import inspect import os import traceback @@ -21,7 +22,6 @@ from astrbot.core.provider.entities import ProviderType from astrbot.core.provider.provider import RerankProvider from astrbot.core.provider.register import provider_registry from astrbot.core.star.star import star_registry -from astrbot.core.utils.astrbot_path import get_astrbot_path from .route import Response, Route, RouteContext @@ -461,11 +461,12 @@ class ConfigRoute(Route): logger.debug( f"Sending health check audio to provider: {status_info['name']}", ) - sample_audio_path = os.path.join( - get_astrbot_path(), - "samples", - "stt_health_check.wav", + sample_audio_path = str( + importlib.resources.files("astrbot") + / "samples" + / "stt_health_check.wav" ) + if not os.path.exists(sample_audio_path): status_info["status"] = "unavailable" status_info["error"] = ( diff --git a/samples/stt_health_check.wav b/astrbot/samples/stt_health_check.wav similarity index 100% rename from samples/stt_health_check.wav rename to astrbot/samples/stt_health_check.wav diff --git a/pyproject.toml b/pyproject.toml index c83fdf2d..413a073a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -63,6 +63,7 @@ dependencies = [ "jieba>=0.42.1", "markitdown-no-magika[docx,xls,xlsx]>=0.1.2", "xinference-client", + "dotenv>=0.9.9", ] [dependency-groups]