Compare commits
5 Commits
feat/vpo-t
...
refactor/a
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
381f7f4405 | ||
|
|
8f38e748cd | ||
|
|
f83484a8c0 | ||
|
|
56e3ddd62a | ||
|
|
80948be41d |
@@ -1,2 +0,0 @@
|
||||
# ASTRBOT 数据目录
|
||||
# ASTRBOT_ROOT = ./data
|
||||
98
.github/copilot-instructions.md
vendored
98
.github/copilot-instructions.md
vendored
@@ -1,63 +1,63 @@
|
||||
# AstrBot 开发指南
|
||||
# AstrBot Development Instructions
|
||||
|
||||
AstrBot 是一个使用 Python 编写、配备 Vue.js 仪表盘的多平台 LLM 聊天机器人开发框架。它支持多个消息平台(QQ、Telegram、Discord 等)和多种 LLM 提供商(OpenAI、Anthropic、Google Gemini 等)。
|
||||
AstrBot is a multi-platform LLM chatbot and development framework written in Python with a Vue.js dashboard. It supports multiple messaging platforms (QQ, Telegram, Discord, etc.) and various LLM providers (OpenAI, Anthropic, Google Gemini, etc.).
|
||||
|
||||
始终优先参考这些指南,仅在遇到与此处信息不符的意外情况时才回退到搜索或 bash 命令。
|
||||
Always reference these instructions first and fallback to search or bash commands only when you encounter unexpected information that does not match the info here.
|
||||
|
||||
## 高效工作
|
||||
## Working Effectively
|
||||
|
||||
### 引导和安装依赖
|
||||
- **需要 Python 3.10+** - 检查 `.python-version` 文件
|
||||
- 安装 UV 包管理器:`pip install uv`
|
||||
- 安装项目依赖:`uv sync` -- 很快几分钟。绝不要取消。设置超时时间为 10+ 分钟。
|
||||
- 创建必需的目录:`mkdir -p data/plugins data/config data/temp`
|
||||
### Bootstrap and Install Dependencies
|
||||
- **Python 3.10+ required** - Check `.python-version` file
|
||||
- Install UV package manager: `pip install uv`
|
||||
- Install project dependencies: `uv sync` -- takes 6-7 minutes. NEVER CANCEL. Set timeout to 10+ minutes.
|
||||
- Create required directories: `mkdir -p data/plugins data/config data/temp`
|
||||
|
||||
### 运行应用程序
|
||||
- 运行主应用程序:`uv run main.py` -- 约 3 秒启动
|
||||
- 应用程序在 http://localhost:6185 创建 WebUI(默认凭据:`astrbot`/`astrbot`)
|
||||
- 应用程序自动从 `packages/` 和 `data/plugins/` 目录加载插件
|
||||
### Running the Application
|
||||
- Run main application: `uv run main.py` -- starts in ~3 seconds
|
||||
- Application creates WebUI on http://localhost:6185 (default credentials: `astrbot`/`astrbot`)
|
||||
- Application loads plugins automatically from `packages/` and `data/plugins/` directories
|
||||
|
||||
### 仪表盘构建(Vue.js/Node.js)
|
||||
- **前置要求**:需要 Node.js 20+ 和 npm 10+
|
||||
- 导航到仪表盘:`cd dashboard`
|
||||
- 安装仪表盘依赖:`npm install` -- 需要 2-3 分钟。绝不要取消。设置超时时间为 5+ 分钟。
|
||||
- 构建仪表盘:`npm run build` -- 需要 25-30 秒。绝不要取消。
|
||||
- 仪表盘在 `dashboard/dist/` 创建优化的生产构建
|
||||
### Dashboard Build (Vue.js/Node.js)
|
||||
- **Prerequisites**: Node.js 20+ and npm 10+ required
|
||||
- Navigate to dashboard: `cd dashboard`
|
||||
- Install dashboard dependencies: `npm install` -- takes 2-3 minutes. NEVER CANCEL. Set timeout to 5+ minutes.
|
||||
- Build dashboard: `npm run build` -- takes 25-30 seconds. NEVER CANCEL.
|
||||
- Dashboard creates optimized production build in `dashboard/dist/`
|
||||
|
||||
### 测试
|
||||
- 暂时不要生成测试文件。
|
||||
### Testing
|
||||
- Do not generate test files for now.
|
||||
|
||||
### 代码质量和检查
|
||||
- 安装 ruff 检查器:`uv add --dev ruff`
|
||||
- 检查代码风格:`uv run ruff check .` -- 耗时 <1 秒
|
||||
- 检查格式:`uv run ruff format --check .` -- 耗时 <1 秒
|
||||
- 修复格式:`uv run ruff format .`
|
||||
- **始终**在提交更改前运行 `uv run ruff check .` 和 `uv run ruff format .`
|
||||
### Code Quality and Linting
|
||||
- Install ruff linter: `uv add --dev ruff`
|
||||
- Check code style: `uv run ruff check .` -- takes <1 second
|
||||
- Check formatting: `uv run ruff format --check .` -- takes <1 second
|
||||
- Fix formatting: `uv run ruff format .`
|
||||
- **ALWAYS** run `uv run ruff check .` and `uv run ruff format .` before committing changes
|
||||
|
||||
### 插件开发
|
||||
- 插件从 `packages/`(内置)和 `data/plugins/`(用户安装)加载
|
||||
- 插件系统支持函数工具和消息处理器
|
||||
- 关键插件:python_interpreter、web_searcher、astrbot、reminder、session_controller
|
||||
### Plugin Development
|
||||
- Plugins load from `packages/` (built-in) and `data/plugins/` (user-installed)
|
||||
- Plugin system supports function tools and message handlers
|
||||
- Key plugins: python_interpreter, web_searcher, astrbot, reminder, session_controller
|
||||
|
||||
### 常见问题和解决方法
|
||||
- **仪表盘下载失败**:已知的"除以零"错误问题 - 应用程序仍可正常工作
|
||||
- **测试中的导入错误**:确保使用 `uv run` 在适当的环境中运行测试
|
||||
- **构建超时**:始终设置适当的超时时间(uv sync 为 10+ 分钟,npm install 为 5+ 分钟)
|
||||
### Common Issues and Workarounds
|
||||
- **Dashboard download fails**: Known issue with "division by zero" error - application still works
|
||||
- **Import errors in tests**: Ensure `uv run` is used to run tests in proper environment
|
||||
=- **Build timeouts**: Always set appropriate timeouts (10+ minutes for uv sync, 5+ minutes for npm install)
|
||||
|
||||
## CI/CD 集成
|
||||
- GitHub Actions 工作流在 `.github/workflows/` 中
|
||||
- 通过 `Dockerfile` 支持 Docker 构建
|
||||
- Pre-commit 钩子强制执行 ruff 格式化和检查
|
||||
## CI/CD Integration
|
||||
- GitHub Actions workflows in `.github/workflows/`
|
||||
- Docker builds supported via `Dockerfile`
|
||||
- Pre-commit hooks enforce ruff formatting and linting
|
||||
|
||||
## Docker 支持
|
||||
- 主要部署方法:`docker run soulter/astrbot:latest`
|
||||
- 可用的 Compose 文件:`compose.yml`
|
||||
- 暴露端口:6185(WebUI)、6195(WeChat)、6199(QQ)等
|
||||
- 需要挂载卷:`./data:/AstrBot/data`
|
||||
## Docker Support
|
||||
- Primary deployment method: `docker run soulter/astrbot:latest`
|
||||
- Compose file available: `compose.yml`
|
||||
- Exposes ports: 6185 (WebUI), 6195 (WeChat), 6199 (QQ), etc.
|
||||
- Volume mount required: `./data:/AstrBot/data`
|
||||
|
||||
## 多语言支持
|
||||
- 文档包括中文(README.md)、英文(README_en.md)、日文(README_ja.md)
|
||||
- UI 支持国际化
|
||||
- 默认语言为中文
|
||||
## Multi-language Support
|
||||
- Documentation in Chinese (README.md), English (README_en.md), Japanese (README_ja.md)
|
||||
- UI supports internationalization
|
||||
- Default language is Chinese
|
||||
|
||||
请记住:这是一个有真实用户的生产聊天机器人框架。始终进行彻底测试,确保更改不会破坏现有功能。
|
||||
Remember: This is a production chatbot framework with real users. Always test thoroughly and ensure changes don't break existing functionality.
|
||||
|
||||
@@ -1,97 +0,0 @@
|
||||
import argparse
|
||||
import asyncio
|
||||
import mimetypes
|
||||
import os
|
||||
import sys
|
||||
|
||||
from astrbot.core import LogBroker, LogManager, db_helper, logger
|
||||
from astrbot.core.config.default import VERSION
|
||||
from astrbot.core.initial_loader import InitialLoader
|
||||
from astrbot.core.utils.io import download_dashboard, get_dashboard_version
|
||||
from astrbot_api import LOGO, IAstrbotPaths
|
||||
from astrbot_sdk import sync_base_container
|
||||
|
||||
AstrbotPaths: type[IAstrbotPaths] = sync_base_container.get(type[IAstrbotPaths])
|
||||
|
||||
def check_env() -> None:
|
||||
if not (sys.version_info.major == 3 and sys.version_info.minor >= 10):
|
||||
logger.error("请使用 Python3.10+ 运行本项目。")
|
||||
exit()
|
||||
|
||||
# os.makedirs("data/config", exist_ok=True)
|
||||
# os.makedirs("data/plugins", exist_ok=True)
|
||||
# os.makedirs("data/temp", exist_ok=True)
|
||||
|
||||
# 针对问题 #181 的临时解决方案
|
||||
mimetypes.add_type("text/javascript", ".js")
|
||||
mimetypes.add_type("text/javascript", ".mjs")
|
||||
mimetypes.add_type("application/json", ".json")
|
||||
|
||||
|
||||
async def check_dashboard_files(webui_dir: str | None = None):
|
||||
"""下载管理面板文件"""
|
||||
# 指定webui目录
|
||||
if webui_dir:
|
||||
if os.path.exists(webui_dir):
|
||||
logger.info(f"使用指定的 WebUI 目录: {webui_dir}")
|
||||
return webui_dir
|
||||
logger.warning(f"指定的 WebUI 目录 {webui_dir} 不存在,将使用默认逻辑。")
|
||||
|
||||
data_dist_path = str(AstrbotPaths.astrbot_root / "dist")
|
||||
if os.path.exists(data_dist_path):
|
||||
v = await get_dashboard_version()
|
||||
if v is not None:
|
||||
# 存在文件
|
||||
if v == f"v{VERSION}":
|
||||
logger.info("WebUI 版本已是最新。")
|
||||
else:
|
||||
logger.warning(
|
||||
f"检测到 WebUI 版本 ({v}) 与当前 AstrBot 版本 (v{VERSION}) 不符。",
|
||||
)
|
||||
return data_dist_path
|
||||
|
||||
logger.info(
|
||||
"开始下载管理面板文件...高峰期(晚上)可能导致较慢的速度。如多次下载失败,请前往 https://github.com/AstrBotDevs/AstrBot/releases/latest 下载 dist.zip,并将其中的 dist 文件夹解压至 data 目录下。",
|
||||
)
|
||||
|
||||
try:
|
||||
await download_dashboard(version=f"v{VERSION}", latest=False)
|
||||
except Exception as e:
|
||||
logger.critical(f"下载管理面板文件失败: {e}。")
|
||||
return None
|
||||
|
||||
logger.info("管理面板下载完成。")
|
||||
return data_dist_path
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="AstrBot")
|
||||
parser.add_argument(
|
||||
"--webui-dir",
|
||||
type=str,
|
||||
help="指定 WebUI 静态文件目录路径",
|
||||
default=None,
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
check_env()
|
||||
|
||||
# 启动日志代理
|
||||
log_broker = LogBroker()
|
||||
LogManager.set_queue_handler(logger, log_broker)
|
||||
|
||||
# 检查仪表板文件
|
||||
webui_dir = asyncio.run(check_dashboard_files(args.webui_dir))
|
||||
|
||||
db = db_helper
|
||||
|
||||
# 打印 logo
|
||||
logger.info(LOGO)
|
||||
|
||||
core_lifecycle = InitialLoader(db, log_broker)
|
||||
core_lifecycle.webui_dir = webui_dir
|
||||
asyncio.run(core_lifecycle.start())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1 +1 @@
|
||||
"""AstrBot CLI入口"""
|
||||
__version__ = "3.5.23"
|
||||
|
||||
@@ -1,22 +1,27 @@
|
||||
"""AstrBot CLI入口"""
|
||||
|
||||
import sys
|
||||
from importlib.metadata import version
|
||||
|
||||
import click
|
||||
|
||||
from astrbot_api import LOGO
|
||||
|
||||
from . import __version__
|
||||
from .commands import conf, init, plug, run
|
||||
|
||||
__version__ = version("astrbot")
|
||||
logo_tmpl = r"""
|
||||
___ _______.___________..______ .______ ______ .___________.
|
||||
/ \ / | || _ \ | _ \ / __ \ | |
|
||||
/ ^ \ | (----`---| |----`| |_) | | |_) | | | | | `---| |----`
|
||||
/ /_\ \ \ \ | | | / | _ < | | | | | |
|
||||
/ _____ \ .----) | | | | |\ \----.| |_) | | `--' | | |
|
||||
/__/ \__\ |_______/ |__| | _| `._____||______/ \______/ |__|
|
||||
"""
|
||||
|
||||
|
||||
@click.group()
|
||||
@click.version_option(__version__, prog_name="AstrBot")
|
||||
def cli() -> None:
|
||||
"""The AstrBot CLI"""
|
||||
click.echo(LOGO)
|
||||
click.echo(logo_tmpl)
|
||||
click.echo("Welcome to AstrBot CLI!")
|
||||
click.echo(f"AstrBot CLI version: {__version__}")
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
|
||||
import anyio
|
||||
import click
|
||||
from filelock import FileLock, Timeout
|
||||
|
||||
@@ -48,7 +48,7 @@ def init() -> None:
|
||||
|
||||
try:
|
||||
with lock.acquire():
|
||||
asyncio.run(initialize_astrbot(astrbot_root))
|
||||
anyio.run(initialize_astrbot, astrbot_root)
|
||||
except Timeout:
|
||||
raise click.ClickException("无法获取锁文件,请检查是否有其他实例正在运行")
|
||||
|
||||
|
||||
@@ -3,11 +3,6 @@ import shutil
|
||||
from pathlib import Path
|
||||
|
||||
import click
|
||||
from astrbot_api.abc import IAstrbotPaths
|
||||
|
||||
from astrbot_sdk import sync_base_container
|
||||
|
||||
AstrbotPaths: type[IAstrbotPaths] = sync_base_container.get(type[IAstrbotPaths])
|
||||
|
||||
from ..utils import (
|
||||
PluginStatus,
|
||||
@@ -52,7 +47,8 @@ def display_plugins(plugins, title=None, color=None):
|
||||
@click.argument("name")
|
||||
def new(name: str):
|
||||
"""创建新插件"""
|
||||
plug_path = AstrbotPaths.getPaths(name).plugins
|
||||
base_path = _get_data_path()
|
||||
plug_path = base_path / "plugins" / name
|
||||
|
||||
if plug_path.exists():
|
||||
raise click.ClickException(f"插件 {name} 已存在")
|
||||
|
||||
@@ -1,16 +1,16 @@
|
||||
import asyncio
|
||||
import os
|
||||
import sys
|
||||
import traceback
|
||||
from pathlib import Path
|
||||
|
||||
import anyio
|
||||
import click
|
||||
from filelock import FileLock, Timeout
|
||||
|
||||
from ..utils import check_astrbot_root, check_dashboard, get_astrbot_root
|
||||
|
||||
|
||||
async def run_astrbot(astrbot_root: Path):
|
||||
async def run_astrbot(astrbot_root: Path) -> None:
|
||||
"""运行 AstrBot"""
|
||||
from astrbot.core import LogBroker, LogManager, db_helper, logger
|
||||
from astrbot.core.initial_loader import InitialLoader
|
||||
@@ -53,7 +53,7 @@ def run(reload: bool, port: str) -> None:
|
||||
lock_file = astrbot_root / "astrbot.lock"
|
||||
lock = FileLock(lock_file, timeout=5)
|
||||
with lock.acquire():
|
||||
asyncio.run(run_astrbot(astrbot_root))
|
||||
anyio.run(run_astrbot, astrbot_root)
|
||||
except KeyboardInterrupt:
|
||||
click.echo("AstrBot 已关闭...")
|
||||
except Timeout:
|
||||
|
||||
@@ -1,31 +1,22 @@
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
|
||||
import click
|
||||
from astrbot_api.abc import IAstrbotPaths
|
||||
|
||||
from astrbot_sdk import sync_base_container
|
||||
|
||||
AstrbotPaths = sync_base_container.get(type[IAstrbotPaths])
|
||||
|
||||
def check_astrbot_root(path: str | Path) -> bool:
|
||||
"""检查路径是否为 AstrBot 根目录"""
|
||||
warnings.warn(
|
||||
"请使用 AstrbotPaths 类代替本模块中的函数",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
return AstrbotPaths.is_root(Path(path))
|
||||
if not isinstance(path, Path):
|
||||
path = Path(path)
|
||||
if not path.exists() or not path.is_dir():
|
||||
return False
|
||||
if not (path / ".astrbot").exists():
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def get_astrbot_root() -> Path:
|
||||
"""获取Astrbot根目录路径"""
|
||||
warnings.warn(
|
||||
"请使用 AstrbotPaths 类代替本模块中的函数",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
return AstrbotPaths.astrbot_root
|
||||
return Path.cwd()
|
||||
|
||||
|
||||
async def check_dashboard(astrbot_root: Path) -> None:
|
||||
|
||||
@@ -9,6 +9,10 @@ from astrbot.core.utils.shared_preferences import SharedPreferences
|
||||
from astrbot.core.utils.t2i.renderer import HtmlRenderer
|
||||
|
||||
from .log import LogBroker, LogManager # noqa
|
||||
from .utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
# 初始化数据存储文件夹
|
||||
os.makedirs(get_astrbot_data_path(), exist_ok=True)
|
||||
|
||||
DEMO_MODE = os.getenv("DEMO_MODE", False)
|
||||
|
||||
|
||||
@@ -3,15 +3,11 @@ import json
|
||||
import logging
|
||||
import os
|
||||
|
||||
from astrbot_api.abc import IAstrbotPaths
|
||||
|
||||
from astrbot_sdk import sync_base_container
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
from .default import DEFAULT_CONFIG, DEFAULT_VALUE_MAP
|
||||
|
||||
AstrbotPaths: type[IAstrbotPaths] = sync_base_container.get(type[IAstrbotPaths])
|
||||
|
||||
ASTRBOT_CONFIG_PATH = str(AstrbotPaths.astrbot_root / "cmd_config.json")
|
||||
ASTRBOT_CONFIG_PATH = os.path.join(get_astrbot_data_path(), "cmd_config.json")
|
||||
logger = logging.getLogger("astrbot")
|
||||
|
||||
|
||||
|
||||
@@ -1,17 +1,11 @@
|
||||
"""如需修改配置,请在 `data/cmd_config.json` 中修改或者在管理面板中可视化修改。"""
|
||||
|
||||
from importlib.metadata import version
|
||||
import os
|
||||
|
||||
from astrbot_api.abc import IAstrbotPaths
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
from astrbot_sdk import sync_base_container
|
||||
|
||||
AstrbotPaths: type[IAstrbotPaths] = sync_base_container.get(type[IAstrbotPaths])
|
||||
|
||||
# 警告,请使用version函数获取版本,此变量兼容保留
|
||||
VERSION = version("astrbot")
|
||||
|
||||
DB_PATH = str(AstrbotPaths.astrbot_root / "data_v4.db")
|
||||
VERSION = "4.5.1"
|
||||
DB_PATH = os.path.join(get_astrbot_data_path(), "data_v4.db")
|
||||
|
||||
# 默认配置
|
||||
DEFAULT_CONFIG = {
|
||||
|
||||
@@ -14,7 +14,8 @@ import os
|
||||
import threading
|
||||
import time
|
||||
import traceback
|
||||
from asyncio import Queue
|
||||
|
||||
import anyio
|
||||
|
||||
from astrbot.core import LogBroker, logger, sp
|
||||
from astrbot.core.astrbot_config_mgr import AstrBotConfigManager
|
||||
@@ -104,7 +105,9 @@ class AstrBotCoreLifecycle:
|
||||
logger.error(traceback.format_exc())
|
||||
|
||||
# 初始化事件队列
|
||||
self.event_queue = Queue()
|
||||
self._event_queue_send, self.event_queue = anyio.create_memory_object_stream[
|
||||
object
|
||||
](0)
|
||||
|
||||
# 初始化人格管理器
|
||||
self.persona_mgr = PersonaManager(self.db, self.astrbot_config_mgr)
|
||||
@@ -118,7 +121,9 @@ class AstrBotCoreLifecycle:
|
||||
)
|
||||
|
||||
# 初始化平台管理器
|
||||
self.platform_manager = PlatformManager(self.astrbot_config, self.event_queue)
|
||||
self.platform_manager = PlatformManager(
|
||||
self.astrbot_config, self._event_queue_send
|
||||
)
|
||||
|
||||
# 初始化对话管理器
|
||||
self.conversation_manager = ConversationManager(self.db)
|
||||
@@ -131,7 +136,7 @@ class AstrBotCoreLifecycle:
|
||||
|
||||
# 初始化提供给插件的上下文
|
||||
self.star_context = Context(
|
||||
self.event_queue,
|
||||
self._event_queue_send,
|
||||
self.astrbot_config,
|
||||
self.db,
|
||||
self.provider_manager,
|
||||
|
||||
@@ -271,7 +271,7 @@ class SQLiteDatabase(BaseDatabase):
|
||||
async with session.begin():
|
||||
await session.execute(
|
||||
delete(ConversationV2).where(
|
||||
col(ConversationV2.user_id) == user_id
|
||||
col(ConversationV2.user_id) == user_id,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
"""事件总线, 用于处理事件的分发和处理
|
||||
"""事件总线, 用于处理事件的分发和处理.
|
||||
|
||||
事件总线是一个异步队列, 用于接收各种消息事件, 并将其发送到Scheduler调度器进行处理
|
||||
其中包含了一个无限循环的调度函数, 用于从事件队列中获取新的事件, 并创建一个新的异步任务来执行管道调度器的处理逻辑
|
||||
|
||||
@@ -10,8 +11,8 @@ class:
|
||||
2. 无限循环的调度函数, 从事件队列中获取新的事件, 打印日志并创建一个新的异步任务来执行管道调度器的处理逻辑
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from asyncio import Queue
|
||||
import anyio
|
||||
from anyio.streams.memory import MemoryObjectReceiveStream
|
||||
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.astrbot_config_mgr import AstrBotConfigManager
|
||||
@@ -25,28 +26,29 @@ class EventBus:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
event_queue: Queue,
|
||||
event_queue: MemoryObjectReceiveStream[AstrMessageEvent],
|
||||
pipeline_scheduler_mapping: dict[str, PipelineScheduler],
|
||||
astrbot_config_mgr: AstrBotConfigManager = None,
|
||||
):
|
||||
astrbot_config_mgr: AstrBotConfigManager | None = None,
|
||||
) -> None:
|
||||
self.event_queue = event_queue # 事件队列
|
||||
# abconf uuid -> scheduler
|
||||
self.pipeline_scheduler_mapping = pipeline_scheduler_mapping
|
||||
self.astrbot_config_mgr = astrbot_config_mgr
|
||||
|
||||
async def dispatch(self):
|
||||
async def dispatch(self) -> None:
|
||||
while True:
|
||||
event: AstrMessageEvent = await self.event_queue.get()
|
||||
event: AstrMessageEvent = await self.event_queue.receive()
|
||||
conf_info = self.astrbot_config_mgr.get_conf_info(event.unified_msg_origin)
|
||||
self._print_event(event, conf_info["name"])
|
||||
scheduler = self.pipeline_scheduler_mapping.get(conf_info["id"])
|
||||
asyncio.create_task(scheduler.execute(event))
|
||||
anyio.create_task(scheduler.execute(event))
|
||||
|
||||
def _print_event(self, event: AstrMessageEvent, conf_name: str):
|
||||
def _print_event(self, event: AstrMessageEvent, conf_name: str) -> None:
|
||||
"""用于记录事件信息
|
||||
|
||||
Args:
|
||||
event (AstrMessageEvent): 事件对象
|
||||
event: 事件对象
|
||||
conf_name: 配置名称
|
||||
|
||||
"""
|
||||
# 如果有发送者名称: [平台名] 发送者名称/发送者ID: 消息概要
|
||||
|
||||
@@ -1,17 +1,18 @@
|
||||
import asyncio
|
||||
import os
|
||||
import platform
|
||||
import time
|
||||
import uuid
|
||||
from urllib.parse import unquote, urlparse
|
||||
|
||||
import anyio
|
||||
|
||||
|
||||
class FileTokenService:
|
||||
"""维护一个简单的基于令牌的文件下载服务,支持超时和懒清除。"""
|
||||
|
||||
def __init__(self, default_timeout: float = 300):
|
||||
self.lock = asyncio.Lock()
|
||||
self.staged_files = {} # token: (file_path, expire_time)
|
||||
def __init__(self, default_timeout: float = 300) -> None:
|
||||
self.lock = anyio.Lock()
|
||||
self.staged_files: dict = {} # token: (file_path, expire_time)
|
||||
self.default_timeout = default_timeout
|
||||
|
||||
async def _cleanup_expired_tokens(self):
|
||||
|
||||
@@ -21,20 +21,20 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import json
|
||||
import os
|
||||
import uuid
|
||||
from enum import Enum
|
||||
|
||||
from astrbot_api.abc import IAstrbotPaths
|
||||
from pydantic.v1 import BaseModel
|
||||
|
||||
from astrbot.core import astrbot_config, file_token_service, logger
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
from astrbot.core.utils.io import download_file, download_image_by_url, file_to_base64
|
||||
from astrbot_sdk import sync_base_container
|
||||
|
||||
AstrbotPaths: type[IAstrbotPaths] = sync_base_container.get(type[IAstrbotPaths])
|
||||
|
||||
class ComponentType(str, Enum):
|
||||
# Basic Segment Types
|
||||
Plain = "Plain" # plain text message
|
||||
@@ -153,7 +153,8 @@ class Record(BaseMessageComponent):
|
||||
if self.file.startswith("base64://"):
|
||||
bs64_data = self.file.removeprefix("base64://")
|
||||
image_bytes = base64.b64decode(bs64_data)
|
||||
file_path = str(AstrbotPaths.astrbot_root / "temp" / f"{uuid.uuid4()}.jpg")
|
||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
file_path = os.path.join(temp_dir, f"{uuid.uuid4()}.jpg")
|
||||
with open(file_path, "wb") as f:
|
||||
f.write(image_bytes)
|
||||
return os.path.abspath(file_path)
|
||||
@@ -241,9 +242,8 @@ class Video(BaseMessageComponent):
|
||||
if url and url.startswith("file:///"):
|
||||
return url[8:]
|
||||
if url and url.startswith("http"):
|
||||
video_file_path = str(
|
||||
AstrbotPaths.astrbot_root / "temp" / f"{uuid.uuid4().hex}"
|
||||
)
|
||||
download_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
video_file_path = os.path.join(download_dir, f"{uuid.uuid4().hex}")
|
||||
await download_file(url, video_file_path)
|
||||
if os.path.exists(video_file_path):
|
||||
return os.path.abspath(video_file_path)
|
||||
@@ -442,9 +442,8 @@ class Image(BaseMessageComponent):
|
||||
if url.startswith("base64://"):
|
||||
bs64_data = url.removeprefix("base64://")
|
||||
image_bytes = base64.b64decode(bs64_data)
|
||||
image_file_path = str(
|
||||
AstrbotPaths.astrbot_root / "temp" / f"{uuid.uuid4()}.jpg"
|
||||
)
|
||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
image_file_path = os.path.join(temp_dir, f"{uuid.uuid4()}.jpg")
|
||||
with open(image_file_path, "wb") as f:
|
||||
f.write(image_bytes)
|
||||
return os.path.abspath(image_file_path)
|
||||
@@ -528,7 +527,7 @@ class Reply(BaseMessageComponent):
|
||||
|
||||
|
||||
class Poke(BaseMessageComponent):
|
||||
type = ComponentType.Poke
|
||||
type: str = ComponentType.Poke
|
||||
id: int | None = 0
|
||||
qq: int | None = 0
|
||||
|
||||
@@ -655,19 +654,33 @@ class File(BaseMessageComponent):
|
||||
|
||||
@property
|
||||
def file(self) -> str:
|
||||
"""获取本地文件路径(仅返回已存在的文件)
|
||||
|
||||
⚠️ 警告:此属性不会自动下载文件!
|
||||
- 如果文件已存在,返回绝对路径
|
||||
- 如果只有 URL 没有本地文件,返回空字符串
|
||||
- 需要下载文件请使用 `await get_file()` 方法
|
||||
"""获取文件路径,如果文件不存在但有URL,则同步下载文件
|
||||
|
||||
Returns:
|
||||
str: 文件的绝对路径,如果文件不存在则返回空字符串
|
||||
str: 文件路径
|
||||
|
||||
"""
|
||||
if self.file_ and os.path.exists(self.file_):
|
||||
return os.path.abspath(self.file_)
|
||||
|
||||
if self.url:
|
||||
try:
|
||||
loop = asyncio.get_event_loop()
|
||||
if loop.is_running():
|
||||
logger.warning(
|
||||
"不可以在异步上下文中同步等待下载! "
|
||||
"这个警告通常发生于某些逻辑试图通过 <File>.file 获取文件消息段的文件内容。"
|
||||
"请使用 await get_file() 代替直接获取 <File>.file 字段",
|
||||
)
|
||||
return ""
|
||||
# 等待下载完成
|
||||
loop.run_until_complete(self._download_file())
|
||||
|
||||
if self.file_ and os.path.exists(self.file_):
|
||||
return os.path.abspath(self.file_)
|
||||
except Exception as e:
|
||||
logger.error(f"文件下载失败: {e}")
|
||||
|
||||
return ""
|
||||
|
||||
@file.setter
|
||||
@@ -701,18 +714,15 @@ class File(BaseMessageComponent):
|
||||
|
||||
if self.url:
|
||||
await self._download_file()
|
||||
if self.file_:
|
||||
return os.path.abspath(self.file_)
|
||||
return os.path.abspath(self.file_)
|
||||
|
||||
return ""
|
||||
|
||||
async def _download_file(self):
|
||||
"""下载文件"""
|
||||
if not self.url:
|
||||
raise ValueError("No URL provided for download")
|
||||
download_dir = str(AstrbotPaths.astrbot_root / "temp")
|
||||
download_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
os.makedirs(download_dir, exist_ok=True)
|
||||
file_path = str(AstrbotPaths.astrbot_root / "temp" / f"{uuid.uuid4().hex}")
|
||||
file_path = os.path.join(download_dir, f"{uuid.uuid4().hex}")
|
||||
await download_file(self.url, file_path)
|
||||
self.file_ = os.path.abspath(file_path)
|
||||
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
import asyncio
|
||||
from collections import defaultdict, deque
|
||||
from collections.abc import AsyncGenerator
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
import anyio
|
||||
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.config.astrbot_config import RateLimitStrategy
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
@@ -19,11 +20,11 @@ class RateLimitStage(Stage):
|
||||
如果触发限流,将 stall 流水线,直到下一个时间窗口来临时自动唤醒。
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self) -> None:
|
||||
# 存储每个会话的请求时间队列
|
||||
self.event_timestamps: defaultdict[str, deque[datetime]] = defaultdict(deque)
|
||||
# 为每个会话设置一个锁,避免并发冲突
|
||||
self.locks: defaultdict[str, asyncio.Lock] = defaultdict(asyncio.Lock)
|
||||
self.locks: defaultdict[str, anyio.Lock] = defaultdict(anyio.Lock)
|
||||
# 限流参数
|
||||
self.rate_limit_count: int = 0
|
||||
self.rate_limit_time: timedelta = timedelta(0)
|
||||
@@ -74,7 +75,7 @@ class RateLimitStage(Stage):
|
||||
logger.info(
|
||||
f"会话 {session_id} 被限流。根据限流策略,此会话处理将被暂停 {stall_duration:.2f} 秒。",
|
||||
)
|
||||
await asyncio.sleep(stall_duration)
|
||||
await anyio.sleep(stall_duration)
|
||||
now = datetime.now()
|
||||
case RateLimitStrategy.DISCARD.value:
|
||||
logger.info(
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import asyncio
|
||||
import math
|
||||
import random
|
||||
from collections.abc import AsyncGenerator
|
||||
|
||||
import astrbot.core.message.components as Comp
|
||||
from astrbot.core import logger
|
||||
@@ -152,7 +153,7 @@ class RespondStage(Stage):
|
||||
async def process(
|
||||
self,
|
||||
event: AstrMessageEvent,
|
||||
) -> None:
|
||||
) -> None | AsyncGenerator[None, None]:
|
||||
result = event.get_result()
|
||||
if result is None:
|
||||
return
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import asyncio
|
||||
import traceback
|
||||
from asyncio import Queue
|
||||
|
||||
from anyio.streams.memory import MemoryObjectSendStream
|
||||
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.config.astrbot_config import AstrBotConfig
|
||||
@@ -12,7 +13,7 @@ from .sources.webchat.webchat_adapter import WebChatAdapter
|
||||
|
||||
|
||||
class PlatformManager:
|
||||
def __init__(self, config: AstrBotConfig, event_queue: Queue):
|
||||
def __init__(self, config: AstrBotConfig, event_queue: MemoryObjectSendStream):
|
||||
self.platform_insts: list[Platform] = []
|
||||
"""加载的 Platform 的实例"""
|
||||
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
import abc
|
||||
import uuid
|
||||
from asyncio import Queue
|
||||
from collections.abc import Awaitable
|
||||
from typing import Any
|
||||
|
||||
from anyio.streams.memory import MemoryObjectSendStream
|
||||
|
||||
from astrbot.core.message.message_event_result import MessageChain
|
||||
from astrbot.core.utils.metrics import Metric
|
||||
|
||||
@@ -13,7 +14,7 @@ from .platform_metadata import PlatformMetadata
|
||||
|
||||
|
||||
class Platform(abc.ABC):
|
||||
def __init__(self, event_queue: Queue):
|
||||
def __init__(self, event_queue: MemoryObjectSendStream):
|
||||
super().__init__()
|
||||
# 维护了消息平台的事件队列,EventBus 会从这里取出事件并处理。
|
||||
self._event_queue = event_queue
|
||||
@@ -45,7 +46,7 @@ class Platform(abc.ABC):
|
||||
|
||||
def commit_event(self, event: AstrMessageEvent):
|
||||
"""提交一个事件到事件队列。"""
|
||||
self._event_queue.put_nowait(event)
|
||||
self._event_queue.send_nowait(event)
|
||||
|
||||
def get_client(self):
|
||||
"""获取平台的客户端对象。"""
|
||||
|
||||
@@ -216,7 +216,7 @@ class DingtalkPlatformAdapter(Platform):
|
||||
client=self.client,
|
||||
)
|
||||
|
||||
self._event_queue.put_nowait(event)
|
||||
self._event_queue.send_nowait(event)
|
||||
|
||||
async def run(self):
|
||||
# await self.client_.start()
|
||||
|
||||
@@ -224,7 +224,7 @@ class LarkPlatformAdapter(Platform):
|
||||
bot=self.lark_api,
|
||||
)
|
||||
|
||||
self._event_queue.put_nowait(event)
|
||||
self._event_queue.send_nowait(event)
|
||||
|
||||
async def run(self):
|
||||
# self.client.start()
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
import asyncio
|
||||
import hashlib
|
||||
import hmac
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
|
||||
import anyio
|
||||
from quart import Quart, Response, request
|
||||
from slack_sdk.socket_mode.aiohttp import SocketModeClient
|
||||
from slack_sdk.socket_mode.request import SocketModeRequest
|
||||
@@ -40,7 +40,7 @@ class SlackWebhookClient:
|
||||
logging.getLogger("quart.app").setLevel(logging.WARNING)
|
||||
logging.getLogger("quart.serving").setLevel(logging.WARNING)
|
||||
|
||||
self.shutdown_event = asyncio.Event()
|
||||
self.shutdown_event = anyio.Event()
|
||||
|
||||
def _setup_routes(self):
|
||||
"""设置路由"""
|
||||
|
||||
@@ -5,8 +5,6 @@ import uuid
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import Any
|
||||
|
||||
from astrbot_api.abc import IAstrbotPaths
|
||||
|
||||
from astrbot import logger
|
||||
from astrbot.core.message.components import Image, Plain, Record
|
||||
from astrbot.core.message.message_event_result import MessageChain
|
||||
@@ -18,13 +16,12 @@ from astrbot.core.platform import (
|
||||
PlatformMetadata,
|
||||
)
|
||||
from astrbot.core.platform.astr_message_event import MessageSesion
|
||||
from astrbot_sdk import sync_base_container
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
from ...register import register_platform_adapter
|
||||
from .webchat_event import WebChatMessageEvent
|
||||
from .webchat_queue_mgr import WebChatQueueMgr, webchat_queue_mgr
|
||||
|
||||
AstrbotPaths: type[IAstrbotPaths] = sync_base_container.get(type[IAstrbotPaths])
|
||||
|
||||
class QueueListener:
|
||||
def __init__(self, webchat_queue_mgr: WebChatQueueMgr, callback: Callable) -> None:
|
||||
@@ -82,7 +79,7 @@ class WebChatAdapter(Platform):
|
||||
self.config = platform_config
|
||||
self.settings = platform_settings
|
||||
self.unique_session = platform_settings["unique_session"]
|
||||
self.imgs_dir = str(AstrbotPaths.astrbot_root / "webchat" / "imgs")
|
||||
self.imgs_dir = os.path.join(get_astrbot_data_path(), "webchat", "imgs")
|
||||
os.makedirs(self.imgs_dir, exist_ok=True)
|
||||
|
||||
self.metadata = PlatformMetadata(
|
||||
|
||||
@@ -8,14 +8,10 @@ import traceback
|
||||
import aiohttp
|
||||
import anyio
|
||||
import websockets
|
||||
from astrbot_api.abc import IAstrbotPaths
|
||||
|
||||
from astrbot import logger
|
||||
from astrbot.api.message_components import At, Image, Plain, Record
|
||||
from astrbot.api.platform import Platform, PlatformMetadata
|
||||
from astrbot_sdk import sync_base_container
|
||||
|
||||
AstrbotPaths: type[IAstrbotPaths] = sync_base_container.get(type[IAstrbotPaths])
|
||||
from astrbot.core.message.message_event_result import MessageChain
|
||||
from astrbot.core.platform.astr_message_event import MessageSesion
|
||||
from astrbot.core.platform.astrbot_message import (
|
||||
@@ -23,6 +19,7 @@ from astrbot.core.platform.astrbot_message import (
|
||||
MessageMember,
|
||||
MessageType,
|
||||
)
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
from ...register import register_platform_adapter
|
||||
from .wechatpadpro_message_event import WeChatPadProMessageEvent
|
||||
@@ -71,8 +68,9 @@ class WeChatPadProAdapter(Platform):
|
||||
self.base_url = f"http://{self.host}:{self.port}"
|
||||
self.auth_key = None # 用于保存生成的授权码
|
||||
self.wxid = None # 用于保存登录成功后的 wxid
|
||||
self.credentials_file = str(
|
||||
AstrbotPaths.astrbot_root / "wechatpadpro_credentials.json"
|
||||
self.credentials_file = os.path.join(
|
||||
get_astrbot_data_path(),
|
||||
"wechatpadpro_credentials.json",
|
||||
) # 持久化文件路径
|
||||
self.ws_handle_task = None
|
||||
|
||||
@@ -157,8 +155,8 @@ class WeChatPadProAdapter(Platform):
|
||||
}
|
||||
try:
|
||||
# 确保数据目录存在
|
||||
config_dir = AstrbotPaths.astrbot_root / "config"
|
||||
config_dir.mkdir(parents=True, exist_ok=True)
|
||||
data_dir = os.path.dirname(self.credentials_file)
|
||||
os.makedirs(data_dir, exist_ok=True)
|
||||
with open(self.credentials_file, "w") as f:
|
||||
json.dump(credentials, f)
|
||||
except Exception as e:
|
||||
@@ -789,10 +787,10 @@ class WeChatPadProAdapter(Platform):
|
||||
voice_bs64_data = voice_resp.get("Data", {}).get("Base64", None)
|
||||
if voice_bs64_data:
|
||||
voice_bs64_data = base64.b64decode(voice_bs64_data)
|
||||
file_path = str(
|
||||
AstrbotPaths.astrbot_root
|
||||
/ "temp"
|
||||
/ f"wechatpadpro_voice_{abm.message_id}.silk"
|
||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
file_path = os.path.join(
|
||||
temp_dir,
|
||||
f"wechatpadpro_voice_{abm.message_id}.silk",
|
||||
)
|
||||
|
||||
async with await anyio.open_file(file_path, "wb") as f:
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
"""企业微信智能机器人 API 客户端
|
||||
"""企业微信智能机器人 API 客户端.
|
||||
|
||||
处理消息加密解密、API 调用等
|
||||
"""
|
||||
|
||||
|
||||
@@ -2,10 +2,10 @@
|
||||
处理企业微信智能机器人的 HTTP 回调请求
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
import anyio
|
||||
import quart
|
||||
|
||||
from astrbot.api import logger
|
||||
@@ -41,7 +41,7 @@ class WecomAIBotServer:
|
||||
self.app = quart.Quart(__name__)
|
||||
self._setup_routes()
|
||||
|
||||
self.shutdown_event = asyncio.Event()
|
||||
self.shutdown_event = anyio.Event()
|
||||
|
||||
def _setup_routes(self):
|
||||
"""设置 Quart 路由"""
|
||||
|
||||
@@ -7,6 +7,7 @@ from collections.abc import Awaitable, Callable
|
||||
from typing import Any
|
||||
|
||||
import aiohttp
|
||||
import anyio
|
||||
|
||||
from astrbot import logger
|
||||
from astrbot.core import sp
|
||||
@@ -98,7 +99,7 @@ class FunctionToolManager:
|
||||
self.func_list: list[FuncTool] = []
|
||||
self.mcp_client_dict: dict[str, MCPClient] = {}
|
||||
"""MCP 服务列表"""
|
||||
self.mcp_client_event: dict[str, asyncio.Event] = {}
|
||||
self.mcp_client_event: dict[str, anyio.Event] = {}
|
||||
|
||||
def empty(self) -> bool:
|
||||
return len(self.func_list) == 0
|
||||
@@ -206,7 +207,7 @@ class FunctionToolManager:
|
||||
for name in mcp_server_json_obj:
|
||||
cfg = mcp_server_json_obj[name]
|
||||
if cfg.get("active", True):
|
||||
event = asyncio.Event()
|
||||
event = anyio.Event()
|
||||
asyncio.create_task(
|
||||
self._init_mcp_client_task_wrapper(name, cfg, event),
|
||||
)
|
||||
@@ -216,7 +217,7 @@ class FunctionToolManager:
|
||||
self,
|
||||
name: str,
|
||||
cfg: dict,
|
||||
event: asyncio.Event,
|
||||
event: anyio.Event,
|
||||
ready_future: asyncio.Future | None = None,
|
||||
) -> None:
|
||||
"""初始化 MCP 客户端的包装函数,用于捕获异常"""
|
||||
@@ -307,7 +308,7 @@ class FunctionToolManager:
|
||||
self,
|
||||
name: str,
|
||||
config: dict,
|
||||
event: asyncio.Event | None = None,
|
||||
event: anyio.Event | None = None,
|
||||
ready_future: asyncio.Future | None = None,
|
||||
timeout: int = 30,
|
||||
) -> None:
|
||||
@@ -316,7 +317,7 @@ class FunctionToolManager:
|
||||
Args:
|
||||
name (str): The name of the MCP server.
|
||||
config (dict): Configuration for the MCP server.
|
||||
event (asyncio.Event): Event to signal when the MCP client is ready.
|
||||
event (anyio.Event): Event to signal when the MCP client is ready.
|
||||
ready_future (asyncio.Future): Future to signal when the MCP client is ready.
|
||||
timeout (int): Timeout for the initialization.
|
||||
|
||||
@@ -326,7 +327,7 @@ class FunctionToolManager:
|
||||
|
||||
"""
|
||||
if not event:
|
||||
event = asyncio.Event()
|
||||
event = anyio.Event()
|
||||
if not ready_future:
|
||||
ready_future = asyncio.Future()
|
||||
if name in self.mcp_client_dict:
|
||||
|
||||
@@ -15,11 +15,7 @@ except (
|
||||
): # pragma: no cover - older dashscope versions without Qwen TTS support
|
||||
MultiModalConversation = None
|
||||
|
||||
from astrbot_api.abc import IAstrbotPaths
|
||||
|
||||
from astrbot_sdk import sync_base_container
|
||||
|
||||
AstrbotPaths: type[IAstrbotPaths] = sync_base_container.get(type[IAstrbotPaths])
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
from ..entities import ProviderType
|
||||
from ..provider import TTSProvider
|
||||
@@ -49,7 +45,7 @@ class ProviderDashscopeTTSAPI(TTSProvider):
|
||||
if not model:
|
||||
raise RuntimeError("Dashscope TTS model is not configured.")
|
||||
|
||||
temp_dir = str(AstrbotPaths.astrbot_root / "temp")
|
||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
os.makedirs(temp_dir, exist_ok=True)
|
||||
|
||||
if self._is_qwen_tts_model(model):
|
||||
|
||||
@@ -4,12 +4,9 @@ import subprocess
|
||||
import uuid
|
||||
|
||||
import edge_tts
|
||||
from astrbot_api.abc import IAstrbotPaths
|
||||
|
||||
from astrbot_sdk import sync_base_container
|
||||
|
||||
AstrbotPaths: type[IAstrbotPaths] = sync_base_container.get(type[IAstrbotPaths])
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
from ..entities import ProviderType
|
||||
from ..provider import TTSProvider
|
||||
@@ -49,10 +46,8 @@ class ProviderEdgeTTS(TTSProvider):
|
||||
self.set_model("edge_tts")
|
||||
|
||||
async def get_audio(self, text: str) -> str:
|
||||
temp_dir = str(AstrbotPaths.astrbot_root / "temp")
|
||||
mp3_path = str(
|
||||
AstrbotPaths.astrbot_root / "temp" / f"edge_tts_temp_{uuid.uuid4()}.mp3"
|
||||
)
|
||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
mp3_path = os.path.join(temp_dir, f"edge_tts_temp_{uuid.uuid4()}.mp3")
|
||||
wav_path = os.path.join(temp_dir, f"edge_tts_{uuid.uuid4()}.wav")
|
||||
|
||||
# 构建 Edge TTS 参数
|
||||
|
||||
@@ -1,16 +1,13 @@
|
||||
import asyncio
|
||||
import base64
|
||||
import json
|
||||
import os
|
||||
import traceback
|
||||
import uuid
|
||||
|
||||
import aiohttp
|
||||
from astrbot_api.abc import IAstrbotPaths
|
||||
|
||||
from astrbot import logger
|
||||
from astrbot_sdk import sync_base_container
|
||||
|
||||
AstrbotPaths: type[IAstrbotPaths] = sync_base_container.get(type[IAstrbotPaths])
|
||||
|
||||
from ..entities import ProviderType
|
||||
from ..provider import TTSProvider
|
||||
@@ -95,10 +92,9 @@ class ProviderVolcengineTTS(TTSProvider):
|
||||
if "data" in resp_data:
|
||||
audio_data = base64.b64decode(resp_data["data"])
|
||||
|
||||
temp_dir = AstrbotPaths.astrbot_root / "temp"
|
||||
temp_dir.mkdir(parents=True, exist_ok=True)
|
||||
os.makedirs("data/temp", exist_ok=True)
|
||||
|
||||
file_path = str(temp_dir / f"volcengine_tts_{uuid.uuid4()}.mp3")
|
||||
file_path = f"data/temp/volcengine_tts_{uuid.uuid4()}.mp3"
|
||||
|
||||
loop = asyncio.get_running_loop()
|
||||
await loop.run_in_executor(
|
||||
|
||||
@@ -1,13 +1,10 @@
|
||||
import os
|
||||
import uuid
|
||||
|
||||
from astrbot_api.abc import IAstrbotPaths
|
||||
from openai import NOT_GIVEN, AsyncOpenAI
|
||||
|
||||
from astrbot_sdk import sync_base_container
|
||||
|
||||
AstrbotPaths: type[IAstrbotPaths] = sync_base_container.get(type[IAstrbotPaths])
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
from astrbot.core.utils.io import download_file
|
||||
from astrbot.core.utils.tencent_record_helper import tencent_silk_to_wav
|
||||
|
||||
@@ -56,7 +53,8 @@ class ProviderOpenAIWhisperAPI(STTProvider):
|
||||
is_tencent = True
|
||||
|
||||
name = str(uuid.uuid4())
|
||||
path = str(AstrbotPaths.astrbot_root / "temp" / name)
|
||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
path = os.path.join(temp_dir, name)
|
||||
await download_file(audio_url, path)
|
||||
audio_url = path
|
||||
|
||||
@@ -67,9 +65,8 @@ class ProviderOpenAIWhisperAPI(STTProvider):
|
||||
is_silk = await self._is_silk_file(audio_url)
|
||||
if is_silk:
|
||||
logger.info("Converting silk file to wav ...")
|
||||
output_path = str(
|
||||
AstrbotPaths.astrbot_root / "temp" / f"{uuid.uuid4()}.wav"
|
||||
)
|
||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
output_path = os.path.join(temp_dir, str(uuid.uuid4()) + ".wav")
|
||||
await tencent_silk_to_wav(audio_url, output_path)
|
||||
audio_url = output_path
|
||||
|
||||
|
||||
@@ -3,11 +3,7 @@
|
||||
import json
|
||||
import os
|
||||
|
||||
from astrbot_api.abc import IAstrbotPaths
|
||||
|
||||
from astrbot_sdk import sync_base_container
|
||||
|
||||
AstrbotPaths: type[IAstrbotPaths] = sync_base_container.get(type[IAstrbotPaths])
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
|
||||
def load_config(namespace: str) -> dict | bool:
|
||||
@@ -15,7 +11,7 @@ def load_config(namespace: str) -> dict | bool:
|
||||
namespace: str, 配置的唯一识别符,也就是配置文件的名字。
|
||||
返回值: 当配置文件存在时,返回 namespace 对应配置文件的内容dict,否则返回 False。
|
||||
"""
|
||||
path = str(AstrbotPaths.astrbot_root / "config" / f"{namespace}.json")
|
||||
path = os.path.join(get_astrbot_data_path(), "config", f"{namespace}.json")
|
||||
if not os.path.exists(path):
|
||||
return False
|
||||
with open(path, encoding="utf-8-sig") as f:
|
||||
@@ -45,7 +41,8 @@ def put_config(namespace: str, name: str, key: str, value, description: str):
|
||||
if not isinstance(value, (str, int, float, bool, list)):
|
||||
raise ValueError("value 只支持 str, int, float, bool, list 类型。")
|
||||
|
||||
path = str(AstrbotPaths.astrbot_root / "config" / f"{namespace}.json")
|
||||
config_dir = os.path.join(get_astrbot_data_path(), "config")
|
||||
path = os.path.join(config_dir, f"{namespace}.json")
|
||||
|
||||
if not os.path.exists(path):
|
||||
with open(path, "w", encoding="utf-8-sig") as f:
|
||||
@@ -73,7 +70,7 @@ def update_config(namespace: str, key: str, value):
|
||||
key: str, 配置项的键。
|
||||
value: str, int, float, bool, list, 配置项的值。
|
||||
"""
|
||||
path = str(AstrbotPaths.astrbot_root / "config" / f"{namespace}.json")
|
||||
path = os.path.join(get_astrbot_data_path(), "config", f"{namespace}.json")
|
||||
if not os.path.exists(path):
|
||||
raise FileNotFoundError(f"配置文件 {namespace}.json 不存在。")
|
||||
with open(path, encoding="utf-8-sig") as f:
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
import logging
|
||||
from asyncio import Queue
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import Any
|
||||
|
||||
from anyio.streams.memory import MemoryObjectSendStream
|
||||
from deprecated import deprecated
|
||||
|
||||
from astrbot.core.astrbot_config_mgr import AstrBotConfigManager
|
||||
@@ -50,7 +50,7 @@ class Context:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
event_queue: Queue,
|
||||
event_queue: MemoryObjectSendStream,
|
||||
config: AstrBotConfig,
|
||||
db: BaseDatabase,
|
||||
provider_manager: ProviderManager,
|
||||
@@ -193,7 +193,7 @@ class Context:
|
||||
"""获取 AstrBot 数据库。"""
|
||||
return self._db
|
||||
|
||||
def get_event_queue(self) -> Queue:
|
||||
def get_event_queue(self) -> MemoryObjectSendStream:
|
||||
"""获取事件队列。"""
|
||||
return self._event_queue
|
||||
|
||||
|
||||
@@ -96,7 +96,7 @@ class CommandGroupFilter(HandlerFilter):
|
||||
prefix + "│ ",
|
||||
event=event,
|
||||
cfg=cfg,
|
||||
)
|
||||
),
|
||||
)
|
||||
|
||||
return "".join(parts)
|
||||
|
||||
@@ -30,7 +30,9 @@ class UmopConfigRouter:
|
||||
if len(p1_ls) != 3 or len(p2_ls) != 3:
|
||||
return False # 非法格式
|
||||
|
||||
return all(p == "" or p == "*" or p == t for p, t in zip(p1_ls, p2_ls))
|
||||
return all(
|
||||
p == "" or p == "*" or p == t for p, t in zip(p1_ls, p2_ls, strict=False)
|
||||
)
|
||||
|
||||
def get_conf_id_for_umop(self, umo: str) -> str | None:
|
||||
"""根据 UMO 获取对应的配置文件 ID
|
||||
|
||||
@@ -6,7 +6,7 @@ import psutil
|
||||
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.config.default import VERSION
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_path
|
||||
from astrbot.core.utils.io import download_file
|
||||
|
||||
from .zip_updator import ReleaseInfo, RepoZipUpdator
|
||||
@@ -20,7 +20,7 @@ class AstrBotUpdator(RepoZipUpdator):
|
||||
|
||||
def __init__(self, repo_mirror: str = "") -> None:
|
||||
super().__init__(repo_mirror)
|
||||
self.astrbot_root = get_astrbot_data_path()
|
||||
self.MAIN_PATH = get_astrbot_path()
|
||||
self.ASTRBOT_RELEASE_API = "https://api.soulter.top/releases"
|
||||
|
||||
def terminate_child_processes(self):
|
||||
@@ -117,7 +117,7 @@ class AstrBotUpdator(RepoZipUpdator):
|
||||
try:
|
||||
await download_file(file_url, "temp.zip")
|
||||
logger.info("下载 AstrBot Core 更新文件完成,正在执行解压...")
|
||||
self.unzip_file("temp.zip", self.astrbot_root)
|
||||
self.unzip_file("temp.zip", self.MAIN_PATH)
|
||||
except BaseException as e:
|
||||
raise e
|
||||
|
||||
|
||||
@@ -1,70 +1,39 @@
|
||||
"""Astrbot统一路径获取
|
||||
|
||||
项目路径:固定为源码所在路径
|
||||
根目录路径:默认为当前工作目录,可通过环境变量 ASTRBOT_ROOT 指定
|
||||
|
||||
数据目录路径:固定为根目录下的 data 目录
|
||||
配置文件路径:固定为数据目录下的 config 目录
|
||||
插件目录路径:固定为数据目录下的 plugins 目录
|
||||
"""
|
||||
|
||||
import os
|
||||
import warnings
|
||||
|
||||
from astrbot_api.abc import IAstrbotPaths
|
||||
|
||||
from astrbot_sdk import sync_base_container
|
||||
|
||||
AstrbotPaths = sync_base_container.get(type[IAstrbotPaths])
|
||||
|
||||
def get_astrbot_path() -> str:
|
||||
"""获取Astrbot项目路径 -- 请勿使用本函数!!! -- 仅供兼容旧代码使用"""
|
||||
warnings.warn(
|
||||
"get_astrbot_path is deprecated. Use AstrbotPaths class instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
"""获取Astrbot项目路径"""
|
||||
return os.path.realpath(
|
||||
os.path.join(os.path.dirname(os.path.abspath(__file__)), "../../../"),
|
||||
)
|
||||
|
||||
|
||||
def get_astrbot_root() -> str:
|
||||
"""获取Astrbot根目录路径 --> get_astrbot_data_path"""
|
||||
warnings.warn(
|
||||
"不要再使用本函数!等效于: AstrbotPaths.astrbot_root",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
return str(AstrbotPaths.astrbot_root)
|
||||
"""获取Astrbot根目录路径"""
|
||||
if path := os.environ.get("ASTRBOT_ROOT"):
|
||||
return os.path.realpath(path)
|
||||
return os.path.realpath(os.getcwd())
|
||||
|
||||
|
||||
def get_astrbot_data_path() -> str:
|
||||
"""获取Astrbot数据目录路径
|
||||
特别注意!
|
||||
这里的data目录指的就是.astrbot根目录!
|
||||
两者是等价的!
|
||||
不要和AstrbotPaths.data混淆!
|
||||
"""
|
||||
warnings.warn(
|
||||
"等效于: AstrbotPaths.astrbot_root",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
return str(AstrbotPaths.astrbot_root)
|
||||
"""获取Astrbot数据目录路径"""
|
||||
return os.path.realpath(os.path.join(get_astrbot_root(), "data"))
|
||||
|
||||
|
||||
def get_astrbot_config_path() -> str:
|
||||
"""获取Astrbot配置文件路径"""
|
||||
warnings.warn(
|
||||
"get_astrbot_config_path is deprecated. Use AstrbotPaths class instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
return str(AstrbotPaths.astrbot_root / "config")
|
||||
return os.path.realpath(os.path.join(get_astrbot_data_path(), "config"))
|
||||
|
||||
|
||||
def get_astrbot_plugin_path() -> str:
|
||||
"""获取Astrbot插件目录路径"""
|
||||
warnings.warn(
|
||||
"get_astrbot_plugin_path is deprecated. Use AstrbotPaths class instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
return str(AstrbotPaths.astrbot_root / "plugins")
|
||||
return os.path.realpath(os.path.join(get_astrbot_data_path(), "plugins"))
|
||||
|
||||
@@ -12,16 +12,12 @@ from pathlib import Path
|
||||
import aiohttp
|
||||
import certifi
|
||||
import psutil
|
||||
from astrbot_api.abc import IAstrbotPaths
|
||||
from PIL import Image
|
||||
|
||||
from astrbot_sdk import sync_base_container
|
||||
|
||||
from .astrbot_path import get_astrbot_data_path
|
||||
|
||||
logger = logging.getLogger("astrbot")
|
||||
|
||||
AstrbotPaths: type[IAstrbotPaths] = sync_base_container.get(type[IAstrbotPaths])
|
||||
|
||||
def on_error(func, path, exc_info):
|
||||
"""A callback of the rmtree function."""
|
||||
@@ -54,11 +50,11 @@ def port_checker(port: int, host: str = "localhost"):
|
||||
|
||||
|
||||
def save_temp_img(img: Image.Image | str) -> str:
|
||||
temp_dir = str(AstrbotPaths.astrbot_root / "temp")
|
||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
# 获得文件创建时间,清除超过 12 小时的
|
||||
try:
|
||||
for f in os.listdir(temp_dir):
|
||||
path = str(AstrbotPaths.astrbot_root / "temp" / f)
|
||||
path = os.path.join(temp_dir, f)
|
||||
if os.path.isfile(path):
|
||||
ctime = os.path.getctime(path)
|
||||
if time.time() - ctime > 3600 * 12:
|
||||
@@ -68,7 +64,7 @@ def save_temp_img(img: Image.Image | str) -> str:
|
||||
|
||||
# 获得时间戳
|
||||
timestamp = f"{int(time.time())}_{uuid.uuid4().hex[:8]}"
|
||||
p = str(AstrbotPaths.astrbot_root / "temp" / f"{timestamp}.jpg")
|
||||
p = os.path.join(temp_dir, f"{timestamp}.jpg")
|
||||
|
||||
if isinstance(img, Image.Image):
|
||||
img.save(p)
|
||||
@@ -209,9 +205,9 @@ def get_local_ip_addresses():
|
||||
|
||||
|
||||
async def get_dashboard_version():
|
||||
dist_dir = str(AstrbotPaths.astrbot_root / "dist")
|
||||
dist_dir = os.path.join(get_astrbot_data_path(), "dist")
|
||||
if os.path.exists(dist_dir):
|
||||
version_file = str(AstrbotPaths.astrbot_root / "dist" / "assets" / "version")
|
||||
version_file = os.path.join(dist_dir, "assets", "version")
|
||||
if os.path.exists(version_file):
|
||||
with open(version_file, encoding="utf-8") as f:
|
||||
v = f.read().strip()
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
"""会话控制"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import abc
|
||||
import asyncio
|
||||
import copy
|
||||
@@ -8,11 +10,13 @@ import time
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import Any
|
||||
|
||||
import anyio
|
||||
|
||||
import astrbot.core.message.components as Comp
|
||||
from astrbot.core.platform import AstrMessageEvent
|
||||
|
||||
USER_SESSIONS: dict[str, "SessionWaiter"] = {} # 存储 SessionWaiter 实例
|
||||
FILTERS: list["SessionFilter"] = [] # 存储 SessionFilter 实例
|
||||
USER_SESSIONS: dict[str, SessionWaiter] = {} # 存储 SessionWaiter 实例
|
||||
FILTERS: list[SessionFilter] = [] # 存储 SessionFilter 实例
|
||||
|
||||
|
||||
class SessionController:
|
||||
@@ -20,16 +24,16 @@ class SessionController:
|
||||
|
||||
def __init__(self):
|
||||
self.future = asyncio.Future()
|
||||
self.current_event: asyncio.Event = None
|
||||
self.current_event: anyio.Event | None = None
|
||||
"""当前正在等待的所用的异步事件"""
|
||||
self.ts: float = None
|
||||
self.ts: float | None = None
|
||||
"""上次保持(keep)开始时的时间"""
|
||||
self.timeout: float | int = None
|
||||
self.timeout: float | int | None = None
|
||||
"""上次保持(keep)开始时的超时时间"""
|
||||
|
||||
self.history_chains: list[list[Comp.BaseMessageComponent]] = []
|
||||
|
||||
def stop(self, error: Exception = None):
|
||||
def stop(self, error: Exception | None = None):
|
||||
"""立即结束这个会话"""
|
||||
if not self.future.done():
|
||||
if error:
|
||||
@@ -53,7 +57,9 @@ class SessionController:
|
||||
self.stop()
|
||||
return
|
||||
else:
|
||||
left_timeout = self.timeout - (new_ts - self.ts)
|
||||
current_timeout = self.timeout if self.timeout is not None else 0
|
||||
current_ts = self.ts if self.ts is not None else new_ts
|
||||
left_timeout = current_timeout - (new_ts - current_ts)
|
||||
timeout = left_timeout + timeout
|
||||
if timeout <= 0:
|
||||
self.stop()
|
||||
@@ -62,18 +68,19 @@ class SessionController:
|
||||
if self.current_event and not self.current_event.is_set():
|
||||
self.current_event.set() # 通知上一个 keep 结束
|
||||
|
||||
new_event = asyncio.Event()
|
||||
new_event = anyio.Event()
|
||||
self.ts = new_ts
|
||||
self.current_event = new_event
|
||||
self.timeout = timeout
|
||||
|
||||
asyncio.create_task(self._holding(new_event, timeout)) # 开始新的 keep
|
||||
anyio.create_task(self._holding(new_event, timeout)) # 开始新的 keep
|
||||
|
||||
async def _holding(self, event: asyncio.Event, timeout: int):
|
||||
async def _holding(self, event: anyio.Event, timeout_seconds: float):
|
||||
"""等待事件结束或超时"""
|
||||
try:
|
||||
await asyncio.wait_for(event.wait(), timeout)
|
||||
except asyncio.TimeoutError:
|
||||
with anyio.move_on_after(timeout_seconds):
|
||||
await event.wait()
|
||||
except TimeoutError:
|
||||
if not self.future.done():
|
||||
self.future.set_exception(TimeoutError("等待超时"))
|
||||
except asyncio.CancelledError:
|
||||
@@ -105,10 +112,12 @@ class SessionWaiter:
|
||||
session_filter: SessionFilter,
|
||||
session_id: str,
|
||||
record_history_chains: bool,
|
||||
):
|
||||
) -> None:
|
||||
self.session_id = session_id
|
||||
self.session_filter = session_filter
|
||||
self.handler: Callable[[str], Awaitable[Any]] | None = None # 处理函数
|
||||
self.handler: (
|
||||
Callable[[SessionController, AstrMessageEvent], Awaitable[Any]] | None
|
||||
) = None # 处理函数
|
||||
|
||||
self.session_controller = SessionController()
|
||||
self.record_history_chains = record_history_chains
|
||||
@@ -119,15 +128,15 @@ class SessionWaiter:
|
||||
|
||||
async def register_wait(
|
||||
self,
|
||||
handler: Callable[[str], Awaitable[Any]],
|
||||
timeout: int = 30,
|
||||
handler: Callable[[SessionController, AstrMessageEvent], Awaitable[Any]],
|
||||
timeout_seconds: int = 30,
|
||||
) -> Any:
|
||||
"""等待外部输入并处理"""
|
||||
self.handler = handler
|
||||
USER_SESSIONS[self.session_id] = self
|
||||
|
||||
# 开始一个会话保持事件
|
||||
self.session_controller.keep(timeout, reset_timeout=True)
|
||||
self.session_controller.keep(timeout_seconds, reset_timeout=True)
|
||||
|
||||
try:
|
||||
return await self.session_controller.future
|
||||
@@ -137,7 +146,7 @@ class SessionWaiter:
|
||||
finally:
|
||||
self._cleanup()
|
||||
|
||||
def _cleanup(self, error: Exception = None):
|
||||
def _cleanup(self, error: Exception | None = None):
|
||||
"""清理会话"""
|
||||
USER_SESSIONS.pop(self.session_id, None)
|
||||
try:
|
||||
@@ -153,6 +162,10 @@ class SessionWaiter:
|
||||
if not session or session.session_controller.future.done():
|
||||
return
|
||||
|
||||
# 此时 session 不会是 None,因为上面的检查
|
||||
if session is None:
|
||||
return
|
||||
|
||||
async with session._lock:
|
||||
if not session.session_controller.future.done():
|
||||
if session.record_history_chains:
|
||||
@@ -161,7 +174,8 @@ class SessionWaiter:
|
||||
)
|
||||
try:
|
||||
# TODO: 这里使用 create_task,跟踪 task,防止超时后这里 handler 仍然在执行
|
||||
await session.handler(session.session_controller, event)
|
||||
if session.handler is not None:
|
||||
await session.handler(session.session_controller, event)
|
||||
except Exception as e:
|
||||
session.session_controller.stop(e)
|
||||
|
||||
@@ -173,11 +187,13 @@ def session_waiter(timeout: int = 30, record_history_chains: bool = False):
|
||||
:param record_history_chain: 是否自动记录历史消息链。可以通过 controller.get_history_chains() 获取。深拷贝。
|
||||
"""
|
||||
|
||||
def decorator(func: Callable[[str], Awaitable[Any]]):
|
||||
def decorator(
|
||||
func: Callable[[SessionController, AstrMessageEvent], Awaitable[Any]],
|
||||
):
|
||||
@functools.wraps(func)
|
||||
async def wrapper(
|
||||
event: AstrMessageEvent,
|
||||
session_filter: SessionFilter = None,
|
||||
session_filter: SessionFilter | None = None,
|
||||
*args,
|
||||
**kwargs,
|
||||
):
|
||||
|
||||
@@ -2,9 +2,8 @@
|
||||
|
||||
import os
|
||||
import shutil
|
||||
from importlib import resources
|
||||
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path, get_astrbot_path
|
||||
|
||||
|
||||
class TemplateManager:
|
||||
@@ -16,10 +15,14 @@ class TemplateManager:
|
||||
CORE_TEMPLATES = ["base.html", "astrbot_powershell.html"]
|
||||
|
||||
def __init__(self):
|
||||
self.builtin_template_dir = str(
|
||||
resources.files("astrbot.core.utils.t2i.template")
|
||||
self.builtin_template_dir = os.path.join(
|
||||
get_astrbot_path(),
|
||||
"astrbot",
|
||||
"core",
|
||||
"utils",
|
||||
"t2i",
|
||||
"template",
|
||||
)
|
||||
|
||||
self.user_template_dir = os.path.join(get_astrbot_data_path(), "t2i_templates")
|
||||
|
||||
os.makedirs(self.user_template_dir, exist_ok=True)
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import asyncio
|
||||
import datetime
|
||||
|
||||
import anyio
|
||||
import jwt
|
||||
from quart import request
|
||||
|
||||
@@ -44,7 +44,7 @@ class AuthRoute(Route):
|
||||
)
|
||||
.__dict__
|
||||
)
|
||||
await asyncio.sleep(3)
|
||||
await anyio.sleep(3)
|
||||
return Response().error("用户名或密码错误").__dict__
|
||||
|
||||
async def edit_account(self):
|
||||
|
||||
@@ -4,18 +4,16 @@ import os
|
||||
import uuid
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from astrbot_api.abc import IAstrbotPaths
|
||||
import anyio
|
||||
from quart import Response as QuartResponse
|
||||
from quart import g, make_response, request
|
||||
|
||||
from astrbot_sdk import sync_base_container
|
||||
|
||||
AstrbotPaths: type[IAstrbotPaths] = sync_base_container.get(type[IAstrbotPaths])
|
||||
from astrbot.core import logger
|
||||
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
|
||||
from astrbot.core.db import BaseDatabase
|
||||
from astrbot.core.platform.astr_message_event import MessageSession
|
||||
from astrbot.core.platform.sources.webchat.webchat_queue_mgr import webchat_queue_mgr
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
from .route import Response, Route, RouteContext
|
||||
|
||||
@@ -50,7 +48,7 @@ class ChatRoute(Route):
|
||||
}
|
||||
self.core_lifecycle = core_lifecycle
|
||||
self.register_routes()
|
||||
self.imgs_dir = str(AstrbotPaths.astrbot_root / "webchat" / "imgs")
|
||||
self.imgs_dir = os.path.join(get_astrbot_data_path(), "webchat", "imgs")
|
||||
os.makedirs(self.imgs_dir, exist_ok=True)
|
||||
|
||||
self.supported_imgs = ["jpg", "jpeg", "png", "gif", "webp"]
|
||||
@@ -191,8 +189,8 @@ class ChatRoute(Route):
|
||||
|
||||
try:
|
||||
if not client_disconnected:
|
||||
await asyncio.sleep(0.05)
|
||||
except asyncio.CancelledError:
|
||||
await anyio.sleep(0.05)
|
||||
except anyio.get_cancelled_exc_class():
|
||||
logger.debug(f"[WebChat] 用户 {username} 断开聊天长连接。")
|
||||
client_disconnected = True
|
||||
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import asyncio
|
||||
import importlib.resources
|
||||
import inspect
|
||||
import os
|
||||
import traceback
|
||||
@@ -22,6 +21,7 @@ from astrbot.core.provider.entities import ProviderType
|
||||
from astrbot.core.provider.provider import RerankProvider
|
||||
from astrbot.core.provider.register import provider_registry
|
||||
from astrbot.core.star.star import star_registry
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_path
|
||||
|
||||
from .route import Response, Route, RouteContext
|
||||
|
||||
@@ -461,12 +461,11 @@ class ConfigRoute(Route):
|
||||
logger.debug(
|
||||
f"Sending health check audio to provider: {status_info['name']}",
|
||||
)
|
||||
sample_audio_path = str(
|
||||
importlib.resources.files("astrbot")
|
||||
/ "samples"
|
||||
/ "stt_health_check.wav"
|
||||
sample_audio_path = os.path.join(
|
||||
get_astrbot_path(),
|
||||
"samples",
|
||||
"stt_health_check.wav",
|
||||
)
|
||||
|
||||
if not os.path.exists(sample_audio_path):
|
||||
status_info["status"] = "unavailable"
|
||||
status_info["error"] = (
|
||||
@@ -818,7 +817,8 @@ class ConfigRoute(Route):
|
||||
cached_token = self._logo_token_cache[cache_key]
|
||||
# 确保platform_default_tmpl[platform.name]存在且为字典
|
||||
if platform.name not in platform_default_tmpl or not isinstance(
|
||||
platform_default_tmpl[platform.name], dict
|
||||
platform_default_tmpl[platform.name],
|
||||
dict,
|
||||
):
|
||||
platform_default_tmpl[platform.name] = {}
|
||||
platform_default_tmpl[platform.name]["logo_token"] = cached_token
|
||||
@@ -847,7 +847,8 @@ class ConfigRoute(Route):
|
||||
|
||||
# 确保platform_default_tmpl[platform.name]存在且为字典
|
||||
if platform.name not in platform_default_tmpl or not isinstance(
|
||||
platform_default_tmpl[platform.name], dict
|
||||
platform_default_tmpl[platform.name],
|
||||
dict,
|
||||
):
|
||||
platform_default_tmpl[platform.name] = {}
|
||||
|
||||
|
||||
@@ -5,7 +5,6 @@ import socket
|
||||
|
||||
import jwt
|
||||
import psutil
|
||||
from astrbot_api.abc import IAstrbotPaths
|
||||
from quart import Quart, g, jsonify, request
|
||||
from quart.logging import default_handler
|
||||
|
||||
@@ -13,8 +12,8 @@ from astrbot.core import logger
|
||||
from astrbot.core.config.default import VERSION
|
||||
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
|
||||
from astrbot.core.db import BaseDatabase
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
from astrbot.core.utils.io import get_local_ip_addresses
|
||||
from astrbot_sdk import sync_base_container
|
||||
|
||||
from .routes import *
|
||||
from .routes.route import Response, RouteContext
|
||||
@@ -23,7 +22,7 @@ from .routes.t2i import T2iRoute
|
||||
|
||||
APP: Quart = None
|
||||
|
||||
AstrbotPaths: type[IAstrbotPaths] = sync_base_container.get(type[IAstrbotPaths])
|
||||
|
||||
class AstrBotDashboard:
|
||||
def __init__(
|
||||
self,
|
||||
@@ -39,7 +38,9 @@ class AstrBotDashboard:
|
||||
if webui_dir and os.path.exists(webui_dir):
|
||||
self.data_path = os.path.abspath(webui_dir)
|
||||
else:
|
||||
self.data_path = os.path.abspath(str(AstrbotPaths.astrbot_root / "dist"))
|
||||
self.data_path = os.path.abspath(
|
||||
os.path.join(get_astrbot_data_path(), "dist"),
|
||||
)
|
||||
|
||||
self.app = Quart("dashboard", static_folder=self.data_path, static_url_path="/")
|
||||
APP = self.app # noqa
|
||||
|
||||
@@ -1,25 +0,0 @@
|
||||
[project]
|
||||
name = "astrbot-api"
|
||||
dynamic = ["version"]
|
||||
description = "Astrbot Python API"
|
||||
readme = "README.md"
|
||||
authors = [
|
||||
{ name = "AstrBotDevs", email = "community@astrbot.app" }
|
||||
]
|
||||
|
||||
requires-python = ">=3.10"
|
||||
dependencies = [
|
||||
"anyio>=4.11.0",
|
||||
"mcp>=1.12.4",
|
||||
"pydantic>=2.10.6",
|
||||
"pyyaml>=6.0.3",
|
||||
"types-pyyaml>=6.0.12.20250915",
|
||||
"yarl>=1.22.0",
|
||||
]
|
||||
|
||||
[build-system]
|
||||
requires = ["hatchling", "uv-dynamic-versioning"]
|
||||
build-backend = "hatchling.build"
|
||||
|
||||
[tool.hatch.version]
|
||||
source = "uv-dynamic-versioning"
|
||||
@@ -1,7 +0,0 @@
|
||||
from .abc import IAstrbotPaths
|
||||
from .const import LOGO
|
||||
|
||||
__all__ = [
|
||||
"IAstrbotPaths",
|
||||
"LOGO",
|
||||
]
|
||||
@@ -1,379 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import AsyncGenerator, Callable, Generator
|
||||
from contextlib import (
|
||||
AbstractAsyncContextManager,
|
||||
AbstractContextManager,
|
||||
asynccontextmanager,
|
||||
contextmanager,
|
||||
)
|
||||
from importlib.abc import Traversable
|
||||
from pathlib import Path
|
||||
from typing import ClassVar
|
||||
|
||||
import anyio
|
||||
import tomli
|
||||
import yaml
|
||||
from yarl import URL
|
||||
|
||||
from .models import AstrbotPluginMetadata
|
||||
|
||||
# region 核心运行时协议
|
||||
|
||||
|
||||
class IAstrbotContainerMgr(ABC):
|
||||
"""AstrBot 容器管理器的抽象基类."""
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
class IAstrbotPaths(ABC):
|
||||
"""路径管理的抽象基类."""
|
||||
astrbot_root: ClassVar[Path]
|
||||
"""Astrbot 根目录路径."""
|
||||
|
||||
@abstractmethod
|
||||
def __init__(self, name: str) -> None:
|
||||
"""初始化路径管理器."""
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def getPaths(cls, name: str) -> IAstrbotPaths:
|
||||
"""返回Paths实例,用于访问模块的各类目录."""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def root(self) -> Path:
|
||||
"""获取根目录."""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def home(self) -> Path:
|
||||
"""获取模块/插件主目录."""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def config(self) -> Path:
|
||||
"""获取模块配置目录."""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def data(self) -> Path:
|
||||
"""获取模块数据目录."""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def log(self) -> Path:
|
||||
"""获取模块日志目录."""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def temp(self) -> Path:
|
||||
"""获取模块临时目录."""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def plugins(self) -> Path:
|
||||
"""获取插件目录."""
|
||||
|
||||
@abstractmethod
|
||||
def reload(self) -> None:
|
||||
"""重新加载环境变量."""
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def is_root(cls, path: Path) -> bool:
|
||||
"""判断路径是否为根目录."""
|
||||
|
||||
@abstractmethod
|
||||
def chdir(self, cwd: str = "home") -> AbstractContextManager[Path]:
|
||||
"""临时切换到指定目录, 子进程将继承此 CWD。"""
|
||||
|
||||
@abstractmethod
|
||||
async def achdir(self, cwd: str = "home") -> AbstractAsyncContextManager[Path]:
|
||||
"""异步临时切换到指定目录, 子进程将继承此 CWD。"""
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
class AstrbotPluginBaseSession(ABC):
|
||||
"""插件会话的基类."""
|
||||
|
||||
url: URL
|
||||
"""插件的astrbot专有协议URL地址.
|
||||
|
||||
协议: astrbot://{stdio/web/legacy}/plugin_id
|
||||
"""
|
||||
|
||||
|
||||
@abstractmethod
|
||||
def connect(self) -> AstrbotPluginBaseSession:
|
||||
"""连接到插件."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def disconnect(self) -> None:
|
||||
"""断开与插件的连接."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def async_connect(self) -> AstrbotPluginBaseSession:
|
||||
"""异步连接到插件."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def async_disconnect(self) -> None:
|
||||
"""异步断开与插件的连接."""
|
||||
...
|
||||
|
||||
@contextmanager
|
||||
def session_scope(self) -> Generator[AstrbotPluginBaseSession]:
|
||||
"""插件会话的上下文管理器."""
|
||||
try:
|
||||
yield self.connect()
|
||||
finally:
|
||||
self.disconnect()
|
||||
|
||||
@asynccontextmanager
|
||||
async def async_session_scope(self) -> AsyncGenerator[AstrbotPluginBaseSession]:
|
||||
"""异步插件会话的上下文管理器."""
|
||||
try:
|
||||
yield await self.async_connect()
|
||||
finally:
|
||||
await self.async_disconnect()
|
||||
|
||||
# region 数据传输方法
|
||||
|
||||
# 主动发送数据
|
||||
@abstractmethod
|
||||
def send(self, data: bytes) -> bytes:
|
||||
"""发送数据到插件并接收响应."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def async_send(self, data: bytes) -> bytes:
|
||||
"""异步发送数据到插件并接收响应."""
|
||||
...
|
||||
|
||||
# 被动接收数据
|
||||
@abstractmethod
|
||||
def listen(self, callback: Callable[[bytes], bytes | None]) -> None:
|
||||
"""监听插件发送的数据.
|
||||
|
||||
callback: 一个接受 bytes 类型参数并返回 bytes 或 None 的函数.
|
||||
如果返回 None, 则不给插件发送响应.
|
||||
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def async_listen(self, callback: Callable[[bytes], bytes | None]) -> None:
|
||||
"""异步监听插件发送的数据.
|
||||
|
||||
callback: 一个接受 bytes 类型参数并返回 bytes 或 None 的函数.
|
||||
如果返回 None, 则不给插件发送响应.
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
class IVirtualAstrbotPlugin(ABC):
|
||||
"""AstrBot 虚拟插件的基类协议."""
|
||||
|
||||
vpo_map: ClassVar[dict[URL, IVirtualAstrbotPlugin]] = {}
|
||||
"""虚拟插件对象映射表, key 是插件的 astrbot 协议 URL 地址."""
|
||||
|
||||
url: URL
|
||||
"""AstrBot 插件的 astrbot 专有协议 URL 地址.
|
||||
协议:
|
||||
astrbot://{stdio/web/legacy}/plugin_id
|
||||
"""
|
||||
|
||||
metadata: AstrbotPluginMetadata
|
||||
"""插件元数据."""
|
||||
|
||||
session: AstrbotPluginBaseSession
|
||||
"""插件会话对象."""
|
||||
|
||||
# region 公共方法
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def handshake(cls) -> None:
|
||||
"""通用插件握手方法.
|
||||
|
||||
1. 发送握手请求
|
||||
2. 接受插件元数据响应,并设置 metadata 属性
|
||||
3. 返回确认消息,表示握手成功
|
||||
"""
|
||||
...
|
||||
|
||||
# part1 :工厂方法
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def fromFile(cls, path: Path, * , stdio: bool = False) -> IVirtualAstrbotPlugin:
|
||||
"""从文件加载插件/插件包的公共方法.
|
||||
|
||||
通过此方法加载经典插件: stdio=False
|
||||
或者子进程插件: stdio=True
|
||||
|
||||
任务:
|
||||
1. 从 path 加载插件,调用私有方法 _load_metadata 加载插件元数据.
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def fromURL(cls, url: URL) -> IVirtualAstrbotPlugin:
|
||||
"""从URL加载插件/插件包的公共方法."""
|
||||
...
|
||||
|
||||
|
||||
# part2 :实例方法
|
||||
|
||||
@abstractmethod
|
||||
def get_logo(self) -> Traversable | None:
|
||||
"""获取插件的logo文件路径(Optional).
|
||||
|
||||
请使用importlib.resources.files来访问并返回Traversable对象.
|
||||
示例:
|
||||
from importlib.resources import files
|
||||
logo_path = files("plugin_a/assets/logo.png")
|
||||
# 如果插件安装在虚拟环境,可以直接这样获取
|
||||
|
||||
# 如果插件不可以安装到虚拟环境,适用于子进程插件/网络插件,可以这样获取
|
||||
|
||||
|
||||
返回None表示没有logo文件.
|
||||
Returns:
|
||||
Traversable | None: 插件logo文件的路径或None.
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def get_metadata(self) -> AstrbotPluginMetadata:
|
||||
"""获取插件元数据的公共方法."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def get_url(self) -> URL:
|
||||
"""获取插件URL(astrbot协议)的公共方法."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def get_session(self) -> AstrbotPluginBaseSession:
|
||||
"""获取插件会话对象的公共方法."""
|
||||
...
|
||||
|
||||
|
||||
|
||||
# region 魔术方法
|
||||
|
||||
@abstractmethod
|
||||
def __str__(self) -> str:
|
||||
"""返回插件元数据的字符串表示."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def __repr__(self) -> str:
|
||||
"""返回插件元数据的正式字符串表示."""
|
||||
...
|
||||
|
||||
# region 私有方法
|
||||
@classmethod
|
||||
def _load_metadata(cls, path: Path) -> AstrbotPluginMetadata:
|
||||
"""加载插件元数据的私有方法,自动按下列优先级加载: pyproject -> yaml -> toml -> json.
|
||||
|
||||
其中yaml: plugin.yaml > metadata.yaml
|
||||
|
||||
此函数是上述工厂方法的辅助函数,用于加载插件元数据.
|
||||
"""
|
||||
match path.suffix.lower():
|
||||
case ".json":
|
||||
return cls._load_metadata_json(path)
|
||||
case ".toml":
|
||||
return cls._load_metadata_toml(path)
|
||||
case ".yaml" | ".yml":
|
||||
return cls._load_metadata_yaml(path)
|
||||
case _:
|
||||
raise ValueError(f"不支持的插件元数据文件格式: {path.suffix}")
|
||||
|
||||
@classmethod
|
||||
async def _load_metadata_async(cls, path: Path) -> AstrbotPluginMetadata:
|
||||
"""异步加载插件元数据的私有方法,自动按下列优先级加载: pyproject -> yaml -> toml -> json.
|
||||
|
||||
其中yaml: plugin.yaml > metadata.yaml
|
||||
|
||||
此函数是上述工厂方法的辅助函数,用于加载插件元数据.
|
||||
"""
|
||||
return await anyio.to_thread.run_sync(cls._load_metadata, path) # type: ignore[attr-defined]
|
||||
|
||||
@classmethod
|
||||
def _load_metadata_json(cls, path: Path) -> AstrbotPluginMetadata:
|
||||
"""从json文件加载插件元数据的私有方法."""
|
||||
with path.open(encoding="utf-8") as f:
|
||||
data = json.load(f)
|
||||
return AstrbotPluginMetadata.model_validate(**data)
|
||||
|
||||
@classmethod
|
||||
async def _load_metadata_json_async(cls, path: Path) -> AstrbotPluginMetadata:
|
||||
"""异步从json文件加载插件元数据的私有方法."""
|
||||
async with await anyio.open_file(path, "r", encoding="utf-8") as f:
|
||||
content = await f.read()
|
||||
data = json.loads(content)
|
||||
return AstrbotPluginMetadata.model_validate(**data)
|
||||
|
||||
@classmethod
|
||||
def _load_metadata_toml(cls, path: Path) -> AstrbotPluginMetadata:
|
||||
"""从toml文件加载插件元数据的私有方法."""
|
||||
with path.open("rb") as f:
|
||||
data = tomli.load(f)
|
||||
return AstrbotPluginMetadata.model_validate(**data)
|
||||
|
||||
@classmethod
|
||||
async def _load_metadata_toml_async(cls, path: Path) -> AstrbotPluginMetadata:
|
||||
"""异步从toml文件加载插件元数据的私有方法."""
|
||||
async with await anyio.open_file(path, "rb") as f:
|
||||
content = await f.read()
|
||||
content_str = content.decode("utf-8")
|
||||
data = tomli.loads(content_str)
|
||||
return AstrbotPluginMetadata.model_validate(**data)
|
||||
|
||||
@classmethod
|
||||
def _load_metadata_yaml(cls, path: Path) -> AstrbotPluginMetadata:
|
||||
"""从yaml文件加载插件元数据的私有方法."""
|
||||
with path.open(encoding="utf-8") as f:
|
||||
data = yaml.safe_load(f)
|
||||
return AstrbotPluginMetadata.model_validate(**data)
|
||||
|
||||
@classmethod
|
||||
async def _load_metadata_yaml_async(cls, path: Path) -> AstrbotPluginMetadata:
|
||||
"""异步从yaml文件加载插件元数据的私有方法."""
|
||||
async with await anyio.open_file(path, "r", encoding="utf-8") as f:
|
||||
content = await f.read()
|
||||
data = yaml.safe_load(content)
|
||||
return AstrbotPluginMetadata.model_validate(**data)
|
||||
|
||||
|
||||
# region 插件运行时协议
|
||||
|
||||
class IAstrbotPluginRuntime(ABC):
|
||||
"""AstrBot 插件运行时的基类协议."""
|
||||
|
||||
@abstractmethod
|
||||
def start(self) -> None:
|
||||
"""启动插件运行时."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def stop(self) -> None:
|
||||
"""停止插件运行时."""
|
||||
...
|
||||
|
||||
@@ -1,9 +0,0 @@
|
||||
LOGO = r"""
|
||||
___ _______.___________..______ .______ ______ .___________.
|
||||
/ \ / | || _ \ | _ \ / __ \ | |
|
||||
/ ^ \ | (----`---| |----`| |_) | | |_) | | | | | `---| |----`
|
||||
/ /_\ \ \ \ | | | / | _ < | | | | | |
|
||||
/ _____ \ .----) | | | | |\ \----.| |_) | | `--' | | |
|
||||
/__/ \__\ |_______/ |__| | _| `._____||______/ \______/ |__|
|
||||
|
||||
"""
|
||||
@@ -1,11 +0,0 @@
|
||||
from enum import Enum
|
||||
|
||||
|
||||
class PluginType(Enum):
|
||||
"""插件大类."""
|
||||
# 兼容保留
|
||||
LEGACY = "legacy" # 经典插件,直接在主进程中运行 -- 逐渐淘汰
|
||||
|
||||
# 后两者为新插件机制
|
||||
STDIO = "stdio" # 子进程 -- 进程间通信
|
||||
WEB = "web" # 通过 HTTP/HTTPS 协议调用
|
||||
@@ -1,3 +0,0 @@
|
||||
|
||||
class AstrbotBaseError(Exception):
|
||||
"""Base exception for Astrbot API errors."""
|
||||
@@ -1,28 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class AstrbotPluginMetadata(BaseModel):
|
||||
"""AstrBot 插件元数据模型."""
|
||||
name: str | None = None
|
||||
"""插件名"""
|
||||
author: str | None = None
|
||||
"""插件作者"""
|
||||
desc: str | None = None
|
||||
"""插件简介"""
|
||||
version: str | None = None
|
||||
"""插件版本"""
|
||||
repo: str | None = None
|
||||
"""插件仓库地址"""
|
||||
|
||||
reserved: bool = False
|
||||
"""是否是 AstrBot 的保留插件"""
|
||||
|
||||
activated: bool = True
|
||||
"""是否被激活"""
|
||||
|
||||
display_name: str | None = None
|
||||
"""用于展示的插件名称"""
|
||||
|
||||
|
||||
@@ -1 +0,0 @@
|
||||
|
||||
@@ -1,25 +0,0 @@
|
||||
[project]
|
||||
name = "astrbot-sdk"
|
||||
dynamic = ["version"]
|
||||
description = "Astrbot Python SDK"
|
||||
readme = "README.md"
|
||||
authors = [
|
||||
{ name = "AstrBotDevs", email = "community@astrbot.app" }
|
||||
]
|
||||
|
||||
requires-python = ">=3.10"
|
||||
dependencies = [
|
||||
"astrbot-api",
|
||||
"dishka>=1.7.2",
|
||||
"dotenv>=0.9.9",
|
||||
]
|
||||
|
||||
[build-system]
|
||||
requires = ["hatchling", "uv-dynamic-versioning"]
|
||||
build-backend = "hatchling.build"
|
||||
|
||||
[tool.hatch.version]
|
||||
source = "uv-dynamic-versioning"
|
||||
|
||||
[tool.uv.sources]
|
||||
astrbot-api = { workspace = true }
|
||||
@@ -1,23 +0,0 @@
|
||||
from dishka import (
|
||||
AsyncContainer,
|
||||
Container,
|
||||
Provider,
|
||||
make_async_container,
|
||||
make_container,
|
||||
)
|
||||
|
||||
from .base import AstrbotBaseProvider
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
base_provider: Provider = AstrbotBaseProvider()
|
||||
async_base_container: AsyncContainer = make_async_container(base_provider)
|
||||
sync_base_container: Container = make_container(base_provider)
|
||||
|
||||
__all__ = [
|
||||
"async_base_container",
|
||||
"sync_base_container",
|
||||
]
|
||||
@@ -1,3 +0,0 @@
|
||||
# 基础类的实现
|
||||
|
||||
- 单向导出
|
||||
@@ -1,3 +0,0 @@
|
||||
from .provider import AstrbotBaseProvider
|
||||
|
||||
__all__ = ["AstrbotBaseProvider"]
|
||||
@@ -1,131 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import AsyncGenerator, Generator
|
||||
from contextlib import asynccontextmanager, contextmanager
|
||||
from os import chdir, getenv
|
||||
from pathlib import Path
|
||||
from typing import ClassVar
|
||||
|
||||
from dotenv import load_dotenv
|
||||
from packaging.utils import NormalizedName, canonicalize_name
|
||||
|
||||
from astrbot_api import IAstrbotPaths
|
||||
|
||||
|
||||
class AstrbotPaths(IAstrbotPaths):
|
||||
"""统一化路径获取."""
|
||||
|
||||
load_dotenv()
|
||||
astrbot_root: ClassVar[Path] = Path(
|
||||
getenv("ASTRBOT_ROOT", Path.home() / ".astrbot")
|
||||
).absolute()
|
||||
|
||||
_instances: ClassVar[dict[str, AstrbotPaths]] = {}
|
||||
|
||||
def __init__(self, name: str) -> None:
|
||||
self.name: str = name
|
||||
# 确保根目录存在
|
||||
self.astrbot_root.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
@classmethod
|
||||
def getPaths(cls, name: str) -> AstrbotPaths:
|
||||
"""返回Paths实例,用于访问模块的各类目录."""
|
||||
normalized_name: NormalizedName = canonicalize_name(name)
|
||||
if normalized_name in cls._instances:
|
||||
return cls._instances[normalized_name]
|
||||
instance: AstrbotPaths = cls(normalized_name)
|
||||
instance.name = normalized_name
|
||||
cls._instances[normalized_name] = instance
|
||||
return instance
|
||||
|
||||
@property
|
||||
def root(self) -> Path:
|
||||
"""返回根目录."""
|
||||
return (
|
||||
self.astrbot_root if self.astrbot_root.exists() else Path.cwd() / ".astrbot"
|
||||
)
|
||||
|
||||
@property
|
||||
def home(self) -> Path:
|
||||
"""模块/插件主目录.
|
||||
|
||||
通过此属性获取模块/插件主目录.
|
||||
"""
|
||||
my_home = self.astrbot_root / "home" / self.name
|
||||
my_home.mkdir(parents=True, exist_ok=True)
|
||||
return my_home
|
||||
|
||||
@property
|
||||
def config(self) -> Path:
|
||||
"""返回模块/插件配置目录.
|
||||
|
||||
搭配 astrbot_config 使用.
|
||||
"""
|
||||
config_path = self.astrbot_root / "config" / self.name
|
||||
config_path.mkdir(parents=True, exist_ok=True)
|
||||
return config_path
|
||||
|
||||
@property
|
||||
def data(self) -> Path:
|
||||
"""返回模块/插件数据目录."""
|
||||
data_path = self.astrbot_root / "data" / self.name
|
||||
data_path.mkdir(parents=True, exist_ok=True)
|
||||
return data_path
|
||||
|
||||
@property
|
||||
def log(self) -> Path:
|
||||
"""返回模块日志目录."""
|
||||
log_path = self.astrbot_root / "logs" / self.name
|
||||
log_path.mkdir(parents=True, exist_ok=True)
|
||||
return log_path
|
||||
|
||||
@property
|
||||
def temp(self) -> Path:
|
||||
"""返回模块临时文件目录."""
|
||||
temp_path = self.astrbot_root / "temp" / self.name
|
||||
temp_path.mkdir(parents=True, exist_ok=True)
|
||||
return temp_path
|
||||
|
||||
@property
|
||||
def plugins(self) -> Path:
|
||||
"""返回插件目录."""
|
||||
plugin_path = self.astrbot_root / "plugins" / self.name
|
||||
plugin_path.mkdir(parents=True, exist_ok=True)
|
||||
return plugin_path
|
||||
|
||||
@classmethod
|
||||
def is_root(cls, path: Path) -> bool:
|
||||
"""检查路径是否为 Astrbot 根目录."""
|
||||
if not path.exists() or not path.is_dir():
|
||||
return False
|
||||
# 检查此目录内是是否包含.astrbot标记文件
|
||||
return bool((path / ".astrbot").exists())
|
||||
|
||||
def reload(self) -> None:
|
||||
"""重新加载环境变量."""
|
||||
load_dotenv()
|
||||
self.__class__.astrbot_root = Path(
|
||||
getenv("ASTRBOT_ROOT", Path.home() / ".astrbot")
|
||||
).absolute()
|
||||
|
||||
@contextmanager
|
||||
def chdir(self, cwd: str = "home") -> Generator[Path]:
|
||||
"""临时切换到指定目录, 子进程将继承此 CWD。"""
|
||||
original_cwd = Path.cwd()
|
||||
target_dir = self.root / cwd
|
||||
try:
|
||||
chdir(target_dir)
|
||||
yield target_dir
|
||||
finally:
|
||||
chdir(original_cwd)
|
||||
|
||||
@asynccontextmanager
|
||||
async def achdir(self, cwd: str = "home") -> AsyncGenerator[Path]:
|
||||
"""异步上下文管理器: 临时切换到指定目录, 子进程将继承此 CWD。"""
|
||||
original_cwd = Path.cwd()
|
||||
target_dir = self.root / cwd
|
||||
try:
|
||||
chdir(target_dir)
|
||||
yield target_dir
|
||||
finally:
|
||||
chdir(original_cwd)
|
||||
@@ -1,19 +0,0 @@
|
||||
from dishka import Provider, Scope, provide
|
||||
|
||||
from astrbot_api import IAstrbotPaths
|
||||
|
||||
from .paths import AstrbotPaths
|
||||
|
||||
|
||||
class AstrbotBaseProvider(Provider):
|
||||
scope = Scope.APP # 基础Provider的作用域设为APP
|
||||
|
||||
@provide
|
||||
def get_astrbot_paths_cls(self) -> type[IAstrbotPaths]:
|
||||
return AstrbotPaths
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
117
main.py
117
main.py
@@ -1,4 +1,117 @@
|
||||
from astrbot.__main__ import main
|
||||
import argparse
|
||||
import mimetypes
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import anyio
|
||||
|
||||
from astrbot.core import LogBroker, LogManager, db_helper, logger
|
||||
from astrbot.core.config.default import VERSION
|
||||
from astrbot.core.initial_loader import InitialLoader
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
from astrbot.core.utils.io import download_dashboard, get_dashboard_version
|
||||
|
||||
# uvloop 仅在非 Windows 平台可用
|
||||
backend_options = {"use_uvloop": True} if sys.platform != "win32" else {}
|
||||
|
||||
# 将父目录添加到 sys.path
|
||||
sys.path.append(Path(__file__).parent.as_posix())
|
||||
|
||||
logo_tmpl = r"""
|
||||
___ _______.___________..______ .______ ______ .___________.
|
||||
/ \ / | || _ \ | _ \ / __ \ | |
|
||||
/ ^ \ | (----`---| |----`| |_) | | |_) | | | | | `---| |----`
|
||||
/ /_\ \ \ \ | | | / | _ < | | | | | |
|
||||
/ _____ \ .----) | | | | |\ \----.| |_) | | `--' | | |
|
||||
/__/ \__\ |_______/ |__| | _| `._____||______/ \______/ |__|
|
||||
|
||||
"""
|
||||
|
||||
|
||||
def check_env():
|
||||
if not (sys.version_info.major == 3 and sys.version_info.minor >= 10):
|
||||
logger.error("请使用 Python3.10+ 运行本项目。")
|
||||
exit()
|
||||
|
||||
os.makedirs("data/config", exist_ok=True)
|
||||
os.makedirs("data/plugins", exist_ok=True)
|
||||
os.makedirs("data/temp", exist_ok=True)
|
||||
|
||||
# 针对问题 #181 的临时解决方案
|
||||
mimetypes.add_type("text/javascript", ".js")
|
||||
mimetypes.add_type("text/javascript", ".mjs")
|
||||
mimetypes.add_type("application/json", ".json")
|
||||
|
||||
|
||||
async def check_dashboard_files(webui_dir: str | None = None):
|
||||
"""下载管理面板文件"""
|
||||
# 指定webui目录
|
||||
if webui_dir:
|
||||
if os.path.exists(webui_dir):
|
||||
logger.info(f"使用指定的 WebUI 目录: {webui_dir}")
|
||||
return webui_dir
|
||||
logger.warning(f"指定的 WebUI 目录 {webui_dir} 不存在,将使用默认逻辑。")
|
||||
|
||||
data_dist_path = os.path.join(get_astrbot_data_path(), "dist")
|
||||
if os.path.exists(data_dist_path):
|
||||
v = await get_dashboard_version()
|
||||
if v is not None:
|
||||
# 存在文件
|
||||
if v == f"v{VERSION}":
|
||||
logger.info("WebUI 版本已是最新。")
|
||||
else:
|
||||
logger.warning(
|
||||
f"检测到 WebUI 版本 ({v}) 与当前 AstrBot 版本 (v{VERSION}) 不符。",
|
||||
)
|
||||
return data_dist_path
|
||||
|
||||
logger.info(
|
||||
"开始下载管理面板文件...高峰期(晚上)可能导致较慢的速度。如多次下载失败,请前往 https://github.com/AstrBotDevs/AstrBot/releases/latest 下载 dist.zip,并将其中的 dist 文件夹解压至 data 目录下。",
|
||||
)
|
||||
|
||||
try:
|
||||
await download_dashboard(version=f"v{VERSION}", latest=False)
|
||||
except Exception as e:
|
||||
logger.critical(f"下载管理面板文件失败: {e}。")
|
||||
return None
|
||||
|
||||
logger.info("管理面板下载完成。")
|
||||
return data_dist_path
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
parser = argparse.ArgumentParser(description="AstrBot")
|
||||
parser.add_argument(
|
||||
"--webui-dir",
|
||||
type=str,
|
||||
help="指定 WebUI 静态文件目录路径",
|
||||
default=None,
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
check_env()
|
||||
|
||||
# 启动日志代理
|
||||
log_broker = LogBroker()
|
||||
LogManager.set_queue_handler(logger, log_broker)
|
||||
|
||||
# 检查仪表板文件
|
||||
webui_dir = anyio.run(
|
||||
check_dashboard_files,
|
||||
args.webui_dir,
|
||||
backend_options=backend_options,
|
||||
)
|
||||
|
||||
db = db_helper
|
||||
|
||||
# 打印 logo
|
||||
logger.info(logo_tmpl)
|
||||
|
||||
core_lifecycle = InitialLoader(db, log_broker)
|
||||
core_lifecycle.webui_dir = webui_dir
|
||||
logger.info(
|
||||
"将按以下异步后端启动 AstrBot: %s",
|
||||
backend_options if backend_options else "asyncio",
|
||||
)
|
||||
anyio.run(core_lifecycle.start, backend_options=backend_options)
|
||||
|
||||
@@ -164,14 +164,14 @@ class ConversationCommands:
|
||||
"%m-%d %H:%M",
|
||||
)
|
||||
parts.append(
|
||||
f"{idx}. {conv['name']}({conv['id'][:4]})\n 上次更新:{ts_h}\n"
|
||||
f"{idx}. {conv['name']}({conv['id'][:4]})\n 上次更新:{ts_h}\n",
|
||||
)
|
||||
idx += 1
|
||||
if idx == 1:
|
||||
parts.append("没有找到任何对话。")
|
||||
dify_cid = provider.conversation_ids.get(message.unified_msg_origin, None)
|
||||
parts.append(
|
||||
f"\n\n用户: {message.unified_msg_origin}\n当前对话: {dify_cid}\n使用 /switch <序号> 切换对话。"
|
||||
f"\n\n用户: {message.unified_msg_origin}\n当前对话: {dify_cid}\n使用 /switch <序号> 切换对话。",
|
||||
)
|
||||
ret = "".join(parts)
|
||||
message.set_result(MessageEventResult().message(ret))
|
||||
@@ -211,7 +211,7 @@ class ConversationCommands:
|
||||
persona_id = persona["name"]
|
||||
title = _titles.get(conv.cid, "新对话")
|
||||
parts.append(
|
||||
f"{global_index}. {title}({conv.cid[:4]})\n 人格情景: {persona_id}\n 上次更新: {datetime.datetime.fromtimestamp(conv.updated_at).strftime('%m-%d %H:%M')}\n"
|
||||
f"{global_index}. {title}({conv.cid[:4]})\n 人格情景: {persona_id}\n 上次更新: {datetime.datetime.fromtimestamp(conv.updated_at).strftime('%m-%d %H:%M')}\n",
|
||||
)
|
||||
global_index += 1
|
||||
|
||||
|
||||
@@ -136,7 +136,7 @@ class ProviderCommands:
|
||||
curr_model = prov.get_model() or "无"
|
||||
parts.append(f"\n当前模型: [{curr_model}]")
|
||||
parts.append(
|
||||
"\nTips: 使用 /model <模型名/编号>,即可实时更换模型。如目标模型不存在于上表,请输入模型名。"
|
||||
"\nTips: 使用 /model <模型名/编号>,即可实时更换模型。如目标模型不存在于上表,请输入模型名。",
|
||||
)
|
||||
|
||||
ret = "".join(parts)
|
||||
|
||||
@@ -9,15 +9,13 @@ from collections import defaultdict
|
||||
|
||||
import aiodocker
|
||||
import aiohttp
|
||||
from astrbot_api.abc import IAstrbotPaths
|
||||
import anyio
|
||||
|
||||
from astrbot.api import llm_tool, logger, star
|
||||
from astrbot.api.event import AstrMessageEvent, MessageEventResult, filter
|
||||
from astrbot.api.message_components import File, Image
|
||||
from astrbot.api.provider import ProviderRequest
|
||||
from astrbot_sdk import sync_base_container
|
||||
|
||||
AstrbotPaths: type[IAstrbotPaths] = sync_base_container.get(type[IAstrbotPaths])
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
from astrbot.core.utils.io import download_file, download_image_by_url
|
||||
|
||||
PROMPT = """
|
||||
@@ -94,7 +92,7 @@ DEFAULT_CONFIG = {
|
||||
},
|
||||
"docker_host_astrbot_abs_path": "",
|
||||
}
|
||||
PATH = str(AstrbotPaths.astrbot_root / "config" / "python_interpreter.json")
|
||||
PATH = os.path.join(get_astrbot_data_path(), "config", "python_interpreter.json")
|
||||
|
||||
|
||||
class Main(star.Star):
|
||||
@@ -104,15 +102,13 @@ class Main(star.Star):
|
||||
self.context = context
|
||||
self.curr_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
|
||||
self.shared_path = str(AstrbotPaths.astrbot_root / "py_interpreter_shared")
|
||||
self.shared_path = os.path.join("data", "py_interpreter_shared")
|
||||
if not os.path.exists(self.shared_path):
|
||||
# 复制 api.py 到 shared 目录
|
||||
os.makedirs(self.shared_path, exist_ok=True)
|
||||
shared_api_file = os.path.join(self.curr_dir, "shared", "api.py")
|
||||
shutil.copy(shared_api_file, self.shared_path)
|
||||
self.workplace_path = str(
|
||||
AstrbotPaths.astrbot_root / "py_interpreter_workplace"
|
||||
)
|
||||
self.workplace_path = os.path.join("data", "py_interpreter_workplace")
|
||||
os.makedirs(self.workplace_path, exist_ok=True)
|
||||
|
||||
self.user_file_msg_buffer = defaultdict(list)
|
||||
@@ -217,7 +213,8 @@ class Main(star.Star):
|
||||
file_path = await comp.get_file()
|
||||
if file_path.startswith("http"):
|
||||
name = comp.name if comp.name else uuid.uuid4().hex[:8]
|
||||
path = str(AstrbotPaths.astrbot_root / "temp" / name)
|
||||
temp_dir = os.path.join(get_astrbot_data_path(), "temp")
|
||||
path = os.path.join(temp_dir, name)
|
||||
await download_file(file_path, path)
|
||||
else:
|
||||
path = file_path
|
||||
@@ -295,7 +292,7 @@ class Main(star.Star):
|
||||
self.user_waiting[uid] = time.time()
|
||||
tip = "文件"
|
||||
yield event.plain_result(f"代码执行器: 请在 60s 内上传一个{tip}。")
|
||||
await asyncio.sleep(60)
|
||||
await anyio.sleep(60)
|
||||
if uid in self.user_waiting:
|
||||
yield event.plain_result(
|
||||
f"代码执行器: {event.get_sender_name()}/{event.get_sender_id()} 未在规定时间内上传{tip}。",
|
||||
|
||||
@@ -5,13 +5,9 @@ import uuid
|
||||
import zoneinfo
|
||||
|
||||
from apscheduler.schedulers.asyncio import AsyncIOScheduler
|
||||
from astrbot_api.abc import IAstrbotPaths
|
||||
|
||||
from astrbot.api import llm_tool, logger, star
|
||||
from astrbot.api.event import AstrMessageEvent, MessageEventResult, filter
|
||||
from astrbot_sdk import sync_base_container
|
||||
|
||||
AstrbotPaths: type[IAstrbotPaths] = sync_base_container.get(type[IAstrbotPaths])
|
||||
from astrbot.core.utils.astrbot_path import get_astrbot_data_path
|
||||
|
||||
|
||||
@@ -31,7 +27,7 @@ class Main(star.Star):
|
||||
self.scheduler = AsyncIOScheduler(timezone=self.timezone)
|
||||
|
||||
# set and load config
|
||||
reminder_file = str(AstrbotPaths.astrbot_root / "astrbot-reminder.json")
|
||||
reminder_file = os.path.join(get_astrbot_data_path(), "astrbot-reminder.json")
|
||||
if not os.path.exists(reminder_file):
|
||||
with open(reminder_file, "w", encoding="utf-8") as f:
|
||||
f.write("{}")
|
||||
|
||||
@@ -2,6 +2,7 @@ import asyncio
|
||||
import random
|
||||
|
||||
import aiohttp
|
||||
import anyio
|
||||
from bs4 import BeautifulSoup
|
||||
from readability import Document
|
||||
|
||||
@@ -26,7 +27,7 @@ class Main(star.Star):
|
||||
def __init__(self, context: star.Context) -> None:
|
||||
self.context = context
|
||||
self.tavily_key_index = 0
|
||||
self.tavily_key_lock = asyncio.Lock()
|
||||
self.tavily_key_lock = anyio.Lock()
|
||||
|
||||
# 将 str 类型的 key 迁移至 list[str],并保存
|
||||
cfg = self.context.get_config()
|
||||
|
||||
@@ -63,10 +63,8 @@ dependencies = [
|
||||
"jieba>=0.42.1",
|
||||
"markitdown-no-magika[docx,xls,xlsx]>=0.1.2",
|
||||
"xinference-client",
|
||||
"dotenv>=0.9.9",
|
||||
"astrbot-api",
|
||||
"astrbot-sdk",
|
||||
"dishka>=1.7.2",
|
||||
"anyio>=4.11.0",
|
||||
"uvloop>=0.22.1 ; sys_platform == 'linux'",
|
||||
]
|
||||
|
||||
[dependency-groups]
|
||||
@@ -102,29 +100,48 @@ select = [
|
||||
# "SIM", # flake8-simplify
|
||||
]
|
||||
ignore = [
|
||||
"F403",
|
||||
"F405",
|
||||
"F403",
|
||||
"F405",
|
||||
"E501",
|
||||
"ASYNC230" # TODO: handle ASYNC230 in AstrBot
|
||||
]
|
||||
[tool.uv.workspace]
|
||||
members = [
|
||||
"astrbot_api",
|
||||
"astrbot_sdk",
|
||||
]
|
||||
|
||||
[tool.uv.sources]
|
||||
astrbot-api = { workspace = true }
|
||||
astrbot-sdk = { workspace = true }
|
||||
|
||||
[tool.pyright]
|
||||
typeCheckingMode = "basic"
|
||||
pythonVersion = "3.10"
|
||||
reportMissingTypeStubs = false
|
||||
reportMissingImports = false
|
||||
include = ["astrbot","astrbot_api","astrbot_sdk","packages"]
|
||||
include = ["astrbot","packages"]
|
||||
exclude = ["dashboard", "node_modules", "dist", "data", "tests"]
|
||||
|
||||
[tool.mypy]
|
||||
python_version = "3.10"
|
||||
warn_return_any = true
|
||||
warn_unused_configs = true
|
||||
disallow_untyped_defs = false
|
||||
disallow_incomplete_defs = false
|
||||
check_untyped_defs = true
|
||||
disallow_untyped_decorators = false
|
||||
no_implicit_optional = true
|
||||
warn_redundant_casts = true
|
||||
warn_unused_ignores = true
|
||||
warn_no_return = true
|
||||
warn_unreachable = true
|
||||
strict_equality = true
|
||||
show_error_codes = true
|
||||
ignore_missing_imports = true
|
||||
explicit_package_bases = true
|
||||
namespace_packages = true
|
||||
files = ["astrbot", "packages"]
|
||||
exclude = [
|
||||
"dashboard",
|
||||
"node_modules",
|
||||
"dist",
|
||||
"data",
|
||||
"tests",
|
||||
"packages/.*/.*",
|
||||
]
|
||||
|
||||
[tool.hatch.version]
|
||||
source = "uv-dynamic-versioning"
|
||||
|
||||
@@ -136,4 +153,3 @@ bump = true
|
||||
[build-system]
|
||||
requires = ["hatchling", "uv-dynamic-versioning"]
|
||||
build-backend = "hatchling.build"
|
||||
|
||||
|
||||
550
requirements.txt
550
requirements.txt
@@ -1,54 +1,498 @@
|
||||
aiocqhttp>=1.4.4
|
||||
aiodocker>=0.24.0
|
||||
aiohttp>=3.11.18
|
||||
aiocqhttp>=1.4.4
|
||||
aiodocker>=0.24.0
|
||||
aiohttp>=3.11.18
|
||||
aiosqlite>=0.21.0
|
||||
anthropic>=0.51.0
|
||||
apscheduler>=3.11.0
|
||||
beautifulsoup4>=4.13.4
|
||||
certifi>=2025.4.26
|
||||
chardet~=5.1.0
|
||||
colorlog>=6.9.0
|
||||
cryptography>=44.0.3
|
||||
dashscope>=1.23.2
|
||||
defusedxml>=0.7.1
|
||||
deprecated>=1.2.18
|
||||
dingtalk-stream>=0.22.1
|
||||
docstring-parser>=0.16
|
||||
# This file was autogenerated by uv via the following command:
|
||||
# uv export --format requirements.txt --no-hashes --no-editable
|
||||
.
|
||||
aiocqhttp==1.4.4
|
||||
# via astrbot
|
||||
aiodocker==0.24.0
|
||||
# via astrbot
|
||||
aiofiles==25.1.0
|
||||
# via
|
||||
# astrbot
|
||||
# quart
|
||||
aiohappyeyeballs==2.6.1
|
||||
# via aiohttp
|
||||
aiohttp==3.13.2
|
||||
# via
|
||||
# aiodocker
|
||||
# astrbot
|
||||
# dashscope
|
||||
# dingtalk-stream
|
||||
# py-cord
|
||||
# qq-botpy
|
||||
# xinference-client
|
||||
aiosignal==1.4.0
|
||||
# via aiohttp
|
||||
aiosqlite==0.21.0
|
||||
# via astrbot
|
||||
annotated-types==0.7.0
|
||||
# via pydantic
|
||||
anthropic==0.72.0
|
||||
# via astrbot
|
||||
anyio==4.11.0
|
||||
# via
|
||||
# anthropic
|
||||
# astrbot
|
||||
# google-genai
|
||||
# httpx
|
||||
# mcp
|
||||
# openai
|
||||
# sse-starlette
|
||||
# starlette
|
||||
# watchfiles
|
||||
apscheduler==3.11.1
|
||||
# via
|
||||
# astrbot
|
||||
# qq-botpy
|
||||
argcomplete==3.6.3
|
||||
# via commitizen
|
||||
async-timeout==5.0.1 ; python_full_version < '3.11'
|
||||
# via aiohttp
|
||||
attrs==25.4.0
|
||||
# via
|
||||
# aiohttp
|
||||
# jsonschema
|
||||
# referencing
|
||||
audioop-lts==0.2.2 ; python_full_version >= '3.13'
|
||||
# via astrbot
|
||||
backports-asyncio-runner==1.2.0 ; python_full_version < '3.11'
|
||||
# via pytest-asyncio
|
||||
beautifulsoup4==4.14.2
|
||||
# via
|
||||
# astrbot
|
||||
# markdownify
|
||||
# markitdown-no-magika
|
||||
blinker==1.9.0
|
||||
# via
|
||||
# flask
|
||||
# quart
|
||||
cachetools==6.2.1
|
||||
# via google-auth
|
||||
certifi==2025.10.5
|
||||
# via
|
||||
# astrbot
|
||||
# dashscope
|
||||
# httpcore
|
||||
# httpx
|
||||
# requests
|
||||
cffi==2.0.0
|
||||
# via
|
||||
# cryptography
|
||||
# silk-python
|
||||
chardet==5.1.0
|
||||
# via
|
||||
# astrbot
|
||||
# readability-lxml
|
||||
charset-normalizer==3.4.4
|
||||
# via
|
||||
# commitizen
|
||||
# markitdown-no-magika
|
||||
# requests
|
||||
click==8.3.0
|
||||
# via
|
||||
# astrbot
|
||||
# flask
|
||||
# quart
|
||||
# uvicorn
|
||||
cobble==0.1.4
|
||||
# via mammoth
|
||||
colorama==0.4.6
|
||||
# via
|
||||
# click
|
||||
# colorlog
|
||||
# commitizen
|
||||
# pytest
|
||||
# tqdm
|
||||
colorlog==6.10.1
|
||||
# via astrbot
|
||||
commitizen==4.9.1
|
||||
coverage==7.11.0
|
||||
# via pytest-cov
|
||||
cryptography==46.0.3
|
||||
# via
|
||||
# astrbot
|
||||
# dashscope
|
||||
cssselect==1.3.0
|
||||
# via readability-lxml
|
||||
dashscope==1.24.9
|
||||
# via astrbot
|
||||
decli==0.6.3
|
||||
# via commitizen
|
||||
defusedxml==0.7.1
|
||||
# via
|
||||
# astrbot
|
||||
# markitdown-no-magika
|
||||
deprecated==1.3.1
|
||||
# via
|
||||
# astrbot
|
||||
# commitizen
|
||||
dingtalk-stream==0.24.3
|
||||
# via astrbot
|
||||
distro==1.9.0
|
||||
# via
|
||||
# anthropic
|
||||
# openai
|
||||
docstring-parser==0.17.0
|
||||
# via
|
||||
# anthropic
|
||||
# astrbot
|
||||
et-xmlfile==2.0.0
|
||||
# via openpyxl
|
||||
exceptiongroup==1.3.0 ; python_full_version < '3.11'
|
||||
# via
|
||||
# anyio
|
||||
# hypercorn
|
||||
# pytest
|
||||
# taskgroup
|
||||
faiss-cpu==1.10.0
|
||||
filelock>=3.18.0
|
||||
google-genai>=1.14.0
|
||||
lark-oapi>=1.4.15
|
||||
lxml-html-clean>=0.4.2
|
||||
mcp>=1.8.0
|
||||
openai>=1.78.0
|
||||
ormsgpack>=1.9.1
|
||||
pillow>=11.2.1
|
||||
pip>=25.1.1
|
||||
psutil>=5.8.0
|
||||
py-cord>=2.6.1
|
||||
pydantic~=2.10.3
|
||||
pydub>=0.25.1
|
||||
pyjwt>=2.10.1
|
||||
python-telegram-bot>=22.0
|
||||
qq-botpy>=1.2.1
|
||||
quart>=0.20.0
|
||||
readability-lxml>=0.8.4.1
|
||||
silk-python>=0.2.6
|
||||
slack-sdk>=3.35.0
|
||||
sqlalchemy[asyncio]>=2.0.41
|
||||
sqlmodel>=0.0.24
|
||||
telegramify-markdown>=0.5.1
|
||||
watchfiles>=1.0.5
|
||||
websockets>=15.0.1
|
||||
wechatpy>=1.8.18
|
||||
audioop-lts ; python_full_version >= '3.13'
|
||||
click>=8.2.1
|
||||
pypdf>=6.1.1
|
||||
aiofiles>=25.1.0
|
||||
rank-bm25>=0.2.2
|
||||
jieba>=0.42.1
|
||||
markitdown-no-magika[docx,xls,xlsx]>=0.1.2
|
||||
xinference-client
|
||||
# via astrbot
|
||||
filelock==3.20.0
|
||||
# via astrbot
|
||||
flask==3.1.2
|
||||
# via quart
|
||||
frozenlist==1.8.0
|
||||
# via
|
||||
# aiohttp
|
||||
# aiosignal
|
||||
google-auth==2.42.1
|
||||
# via google-genai
|
||||
google-genai==1.47.0
|
||||
# via astrbot
|
||||
greenlet==3.2.4
|
||||
# via sqlalchemy
|
||||
h11==0.16.0
|
||||
# via
|
||||
# httpcore
|
||||
# hypercorn
|
||||
# uvicorn
|
||||
# wsproto
|
||||
h2==4.3.0
|
||||
# via hypercorn
|
||||
hpack==4.1.0
|
||||
# via h2
|
||||
httpcore==1.0.9
|
||||
# via httpx
|
||||
httpx==0.28.1
|
||||
# via
|
||||
# aiocqhttp
|
||||
# anthropic
|
||||
# google-genai
|
||||
# lark-oapi
|
||||
# mcp
|
||||
# openai
|
||||
# python-telegram-bot
|
||||
httpx-sse==0.4.3
|
||||
# via mcp
|
||||
hypercorn==0.17.3
|
||||
# via quart
|
||||
hyperframe==6.1.0
|
||||
# via h2
|
||||
idna==3.11
|
||||
# via
|
||||
# anyio
|
||||
# httpx
|
||||
# requests
|
||||
# yarl
|
||||
iniconfig==2.3.0
|
||||
# via pytest
|
||||
itsdangerous==2.2.0
|
||||
# via
|
||||
# flask
|
||||
# quart
|
||||
jieba==0.42.1
|
||||
# via astrbot
|
||||
jinja2==3.1.6
|
||||
# via
|
||||
# commitizen
|
||||
# flask
|
||||
# quart
|
||||
jiter==0.11.1
|
||||
# via
|
||||
# anthropic
|
||||
# openai
|
||||
jsonschema==4.25.1
|
||||
# via mcp
|
||||
jsonschema-specifications==2025.9.1
|
||||
# via jsonschema
|
||||
lark-oapi==1.4.23
|
||||
# via astrbot
|
||||
lxml==6.0.2
|
||||
# via
|
||||
# lxml-html-clean
|
||||
# markitdown-no-magika
|
||||
# readability-lxml
|
||||
lxml-html-clean==0.4.3
|
||||
# via
|
||||
# astrbot
|
||||
# lxml
|
||||
# readability-lxml
|
||||
mammoth==1.11.0
|
||||
# via markitdown-no-magika
|
||||
markdownify==1.2.0
|
||||
# via markitdown-no-magika
|
||||
markitdown-no-magika==0.1.2
|
||||
# via astrbot
|
||||
markupsafe==3.0.3
|
||||
# via
|
||||
# flask
|
||||
# jinja2
|
||||
# quart
|
||||
# werkzeug
|
||||
mcp==1.12.4
|
||||
# via astrbot
|
||||
mistletoe==1.4.0
|
||||
# via telegramify-markdown
|
||||
multidict==6.7.0
|
||||
# via
|
||||
# aiohttp
|
||||
# yarl
|
||||
numpy==2.2.6 ; python_full_version < '3.11'
|
||||
# via
|
||||
# faiss-cpu
|
||||
# pandas
|
||||
# rank-bm25
|
||||
numpy==2.3.4 ; python_full_version >= '3.11'
|
||||
# via
|
||||
# faiss-cpu
|
||||
# pandas
|
||||
# rank-bm25
|
||||
openai==2.6.1
|
||||
# via astrbot
|
||||
openpyxl==3.1.5
|
||||
# via markitdown-no-magika
|
||||
optionaldict==0.1.2
|
||||
# via wechatpy
|
||||
ormsgpack==1.11.0
|
||||
# via astrbot
|
||||
packaging==25.0
|
||||
# via
|
||||
# commitizen
|
||||
# faiss-cpu
|
||||
# pytest
|
||||
pandas==2.3.3
|
||||
# via markitdown-no-magika
|
||||
pillow==12.0.0
|
||||
# via astrbot
|
||||
pip==25.3
|
||||
# via astrbot
|
||||
pluggy==1.6.0
|
||||
# via
|
||||
# pytest
|
||||
# pytest-cov
|
||||
priority==2.0.0
|
||||
# via hypercorn
|
||||
prompt-toolkit==3.0.51
|
||||
# via
|
||||
# commitizen
|
||||
# questionary
|
||||
propcache==0.4.1
|
||||
# via
|
||||
# aiohttp
|
||||
# yarl
|
||||
psutil==7.1.2
|
||||
# via astrbot
|
||||
py-cord==2.6.1
|
||||
# via astrbot
|
||||
pyasn1==0.6.1
|
||||
# via
|
||||
# pyasn1-modules
|
||||
# rsa
|
||||
pyasn1-modules==0.4.2
|
||||
# via google-auth
|
||||
pycparser==2.23 ; implementation_name != 'PyPy'
|
||||
# via cffi
|
||||
pycryptodome==3.23.0
|
||||
# via lark-oapi
|
||||
pydantic==2.10.6
|
||||
# via
|
||||
# anthropic
|
||||
# astrbot
|
||||
# google-genai
|
||||
# mcp
|
||||
# openai
|
||||
# pydantic-settings
|
||||
# sqlmodel
|
||||
# xinference-client
|
||||
pydantic-core==2.27.2
|
||||
# via pydantic
|
||||
pydantic-settings==2.11.0
|
||||
# via mcp
|
||||
pydub==0.25.1
|
||||
# via astrbot
|
||||
pygments==2.19.2
|
||||
# via pytest
|
||||
pyjwt==2.10.1
|
||||
# via astrbot
|
||||
pypdf==6.1.3
|
||||
# via astrbot
|
||||
pytest==8.4.2
|
||||
# via
|
||||
# pytest-asyncio
|
||||
# pytest-cov
|
||||
pytest-asyncio==1.2.0
|
||||
pytest-cov==7.0.0
|
||||
python-dateutil==2.9.0.post0
|
||||
# via
|
||||
# pandas
|
||||
# wechatpy
|
||||
python-dotenv==1.2.1
|
||||
# via pydantic-settings
|
||||
python-multipart==0.0.20
|
||||
# via mcp
|
||||
python-telegram-bot==22.5
|
||||
# via astrbot
|
||||
pytz==2025.2
|
||||
# via pandas
|
||||
pywin32==311 ; sys_platform == 'win32'
|
||||
# via mcp
|
||||
pyyaml==6.0.3
|
||||
# via
|
||||
# commitizen
|
||||
# qq-botpy
|
||||
qq-botpy==1.2.1
|
||||
# via astrbot
|
||||
quart==0.20.0
|
||||
# via
|
||||
# aiocqhttp
|
||||
# astrbot
|
||||
questionary==2.1.1
|
||||
# via commitizen
|
||||
rank-bm25==0.2.2
|
||||
# via astrbot
|
||||
readability-lxml==0.8.4.1
|
||||
# via astrbot
|
||||
referencing==0.37.0
|
||||
# via
|
||||
# jsonschema
|
||||
# jsonschema-specifications
|
||||
requests==2.32.5
|
||||
# via
|
||||
# dashscope
|
||||
# dingtalk-stream
|
||||
# google-genai
|
||||
# lark-oapi
|
||||
# markitdown-no-magika
|
||||
# requests-toolbelt
|
||||
# wechatpy
|
||||
# xinference-client
|
||||
requests-toolbelt==1.0.0
|
||||
# via lark-oapi
|
||||
rpds-py==0.28.0
|
||||
# via
|
||||
# jsonschema
|
||||
# referencing
|
||||
rsa==4.9.1
|
||||
# via google-auth
|
||||
ruff==0.14.3
|
||||
silk-python==0.2.7
|
||||
# via astrbot
|
||||
six==1.17.0
|
||||
# via
|
||||
# markdownify
|
||||
# python-dateutil
|
||||
# wechatpy
|
||||
slack-sdk==3.37.0
|
||||
# via astrbot
|
||||
sniffio==1.3.1
|
||||
# via
|
||||
# anthropic
|
||||
# anyio
|
||||
# openai
|
||||
soupsieve==2.8
|
||||
# via beautifulsoup4
|
||||
sqlalchemy==2.0.44
|
||||
# via
|
||||
# astrbot
|
||||
# sqlmodel
|
||||
sqlmodel==0.0.27
|
||||
# via astrbot
|
||||
sse-starlette==3.0.3
|
||||
# via mcp
|
||||
starlette==0.50.0
|
||||
# via mcp
|
||||
taskgroup==0.2.2 ; python_full_version < '3.11'
|
||||
# via hypercorn
|
||||
telegramify-markdown==0.5.2
|
||||
# via astrbot
|
||||
tenacity==9.1.2
|
||||
# via google-genai
|
||||
termcolor==3.2.0
|
||||
# via commitizen
|
||||
tomli==2.3.0 ; python_full_version <= '3.11'
|
||||
# via
|
||||
# coverage
|
||||
# hypercorn
|
||||
# pytest
|
||||
tomlkit==0.13.3
|
||||
# via commitizen
|
||||
tqdm==4.67.1
|
||||
# via openai
|
||||
typing-extensions==4.15.0
|
||||
# via
|
||||
# aiosignal
|
||||
# aiosqlite
|
||||
# anthropic
|
||||
# anyio
|
||||
# beautifulsoup4
|
||||
# commitizen
|
||||
# cryptography
|
||||
# exceptiongroup
|
||||
# google-genai
|
||||
# hypercorn
|
||||
# multidict
|
||||
# openai
|
||||
# py-cord
|
||||
# pydantic
|
||||
# pydantic-core
|
||||
# pypdf
|
||||
# pytest-asyncio
|
||||
# referencing
|
||||
# sqlalchemy
|
||||
# starlette
|
||||
# taskgroup
|
||||
# typing-inspection
|
||||
# uvicorn
|
||||
# xinference-client
|
||||
typing-inspection==0.4.2
|
||||
# via pydantic-settings
|
||||
tzdata==2025.2
|
||||
# via
|
||||
# pandas
|
||||
# tzlocal
|
||||
tzlocal==5.3.1
|
||||
# via apscheduler
|
||||
urllib3==2.5.0
|
||||
# via requests
|
||||
uvicorn==0.38.0 ; sys_platform != 'emscripten'
|
||||
# via mcp
|
||||
uvloop==0.22.1 ; sys_platform == 'linux'
|
||||
# via astrbot
|
||||
watchfiles==1.1.1
|
||||
# via astrbot
|
||||
wcwidth==0.2.14
|
||||
# via prompt-toolkit
|
||||
websocket-client==1.9.0
|
||||
# via dashscope
|
||||
websockets==15.0.1
|
||||
# via
|
||||
# astrbot
|
||||
# dingtalk-stream
|
||||
# google-genai
|
||||
# lark-oapi
|
||||
wechatpy==1.8.18
|
||||
# via astrbot
|
||||
werkzeug==3.1.3
|
||||
# via
|
||||
# flask
|
||||
# quart
|
||||
wrapt==2.0.0
|
||||
# via deprecated
|
||||
wsproto==1.2.0
|
||||
# via hypercorn
|
||||
xinference-client==1.11.0.post1
|
||||
# via astrbot
|
||||
xlrd==2.0.2
|
||||
# via markitdown-no-magika
|
||||
xmltodict==1.0.2
|
||||
# via wechatpy
|
||||
yarl==1.22.0
|
||||
# via aiohttp
|
||||
|
||||
@@ -1,517 +0,0 @@
|
||||
"""测试 AstrbotPaths 路径类的综合测试."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import pytest
|
||||
from astrbot_api.abc import IAstrbotPaths
|
||||
|
||||
from astrbot_sdk import sync_base_container
|
||||
|
||||
AstrbotPaths: type[IAstrbotPaths] = sync_base_container.get(type[IAstrbotPaths])
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Generator
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def temp_root(monkeypatch: pytest.MonkeyPatch) -> Generator[Path]:
|
||||
"""创建一个临时根目录用于测试."""
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
temp_path = Path(tmpdir)
|
||||
monkeypatch.setenv("ASTRBOT_ROOT", str(temp_path))
|
||||
# 清除类变量和实例缓存
|
||||
AstrbotPaths._instances.clear()
|
||||
# 重新加载环境变量
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv(override=True)
|
||||
AstrbotPaths.astrbot_root = temp_path
|
||||
yield temp_path
|
||||
# 清理
|
||||
AstrbotPaths._instances.clear()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def paths_instance(temp_root: Path) -> AstrbotPaths:
|
||||
"""创建一个 AstrbotPaths 实例用于测试."""
|
||||
return AstrbotPaths.getPaths("test-module")
|
||||
|
||||
|
||||
class TestAstrbotPathsInit:
|
||||
"""测试 AstrbotPaths 初始化."""
|
||||
|
||||
def test_init_creates_root_directory(self, temp_root: Path) -> None:
|
||||
"""测试初始化时创建根目录."""
|
||||
# 删除根目录以测试自动创建
|
||||
if temp_root.exists():
|
||||
import shutil
|
||||
|
||||
shutil.rmtree(temp_root)
|
||||
|
||||
AstrbotPaths("test-init")
|
||||
assert temp_root.exists()
|
||||
assert temp_root.is_dir()
|
||||
|
||||
def test_init_with_name(self, temp_root: Path) -> None:
|
||||
"""测试使用名称初始化."""
|
||||
paths = AstrbotPaths("my-module")
|
||||
assert paths.name == "my-module"
|
||||
|
||||
def test_astrbot_root_from_env(
|
||||
self, monkeypatch: pytest.MonkeyPatch, tmp_path: Path
|
||||
) -> None:
|
||||
"""测试从环境变量读取根目录."""
|
||||
custom_root = tmp_path / "custom_root"
|
||||
custom_root.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 清除实例缓存
|
||||
AstrbotPaths._instances.clear()
|
||||
|
||||
# 直接设置环境变量(在 load_dotenv 之前)
|
||||
monkeypatch.setenv("ASTRBOT_ROOT", str(custom_root))
|
||||
|
||||
# 直接更新 astrbot_root(模拟 load_dotenv 的效果但使用我们设置的环境变量)
|
||||
AstrbotPaths.astrbot_root = Path(
|
||||
os.getenv("ASTRBOT_ROOT", Path.home() / ".astrbot")
|
||||
).absolute()
|
||||
|
||||
assert AstrbotPaths.astrbot_root == custom_root.absolute()
|
||||
|
||||
def test_astrbot_root_default(
|
||||
self, monkeypatch: pytest.MonkeyPatch, tmp_path: Path
|
||||
) -> None:
|
||||
"""测试默认根目录."""
|
||||
# 清除环境变量
|
||||
monkeypatch.delenv("ASTRBOT_ROOT", raising=False)
|
||||
# 清除任何可能存在的 .env 文件影响
|
||||
monkeypatch.setattr("os.environ", {**os.environ})
|
||||
|
||||
# 清除实例缓存
|
||||
AstrbotPaths._instances.clear()
|
||||
|
||||
# 重新计算根目录
|
||||
AstrbotPaths.astrbot_root = Path(
|
||||
os.getenv("ASTRBOT_ROOT", Path.home() / ".astrbot")
|
||||
).absolute()
|
||||
|
||||
expected = (Path.home() / ".astrbot").absolute()
|
||||
assert AstrbotPaths.astrbot_root == expected
|
||||
|
||||
|
||||
class TestGetPaths:
|
||||
"""测试 getPaths 单例模式."""
|
||||
|
||||
def test_get_paths_returns_same_instance(self, temp_root: Path) -> None:
|
||||
"""测试多次调用返回同一个实例."""
|
||||
paths1 = AstrbotPaths.getPaths("test-module")
|
||||
paths2 = AstrbotPaths.getPaths("test-module")
|
||||
assert paths1 is paths2
|
||||
|
||||
def test_get_paths_different_names(self, temp_root: Path) -> None:
|
||||
"""测试不同名称返回不同实例."""
|
||||
paths1 = AstrbotPaths.getPaths("module-a")
|
||||
paths2 = AstrbotPaths.getPaths("module-b")
|
||||
assert paths1 is not paths2
|
||||
assert paths1.name == "module-a"
|
||||
assert paths2.name == "module-b"
|
||||
|
||||
def test_get_paths_normalizes_name(self, temp_root: Path) -> None:
|
||||
"""测试名称规范化."""
|
||||
# PEP 503 规范化: 转小写, 替换 -, _, .
|
||||
paths1 = AstrbotPaths.getPaths("Test_Module")
|
||||
paths2 = AstrbotPaths.getPaths("test-module")
|
||||
paths3 = AstrbotPaths.getPaths("TEST.MODULE")
|
||||
|
||||
# 所有这些名称应该被规范化为相同的名称
|
||||
assert paths1 is paths2
|
||||
assert paths2 is paths3
|
||||
|
||||
|
||||
class TestProperties:
|
||||
"""测试所有属性访问器."""
|
||||
|
||||
def test_root_property(self, paths_instance: AstrbotPaths, temp_root: Path) -> None:
|
||||
"""测试 root 属性."""
|
||||
assert paths_instance.root == temp_root
|
||||
assert paths_instance.root.exists()
|
||||
|
||||
def test_root_property_when_not_exists(
|
||||
self, monkeypatch: pytest.MonkeyPatch, tmp_path: Path
|
||||
) -> None:
|
||||
"""测试 root 属性当根目录不存在时."""
|
||||
non_existent = tmp_path / "non_existent_path"
|
||||
# 确保目录不存在
|
||||
if non_existent.exists():
|
||||
import shutil
|
||||
|
||||
shutil.rmtree(non_existent)
|
||||
|
||||
# 清除实例缓存
|
||||
AstrbotPaths._instances.clear()
|
||||
# 设置不存在的路径
|
||||
AstrbotPaths.astrbot_root = non_existent
|
||||
|
||||
# __init__ 会创建根目录,所以 getPaths 会使根目录存在
|
||||
# 我们测试的是在 __init__ 创建目录之前访问 root 属性的行为
|
||||
# 但由于 getPaths 总是调用 __init__,目录总是会被创建
|
||||
# 所以这个测试应该验证即使最初不存在,getPaths 之后也会存在
|
||||
paths = AstrbotPaths.getPaths("test")
|
||||
# getPaths 调用 __init__,__init__ 会创建根目录
|
||||
# 所以 root 应该返回 astrbot_root(现在已存在)
|
||||
assert paths.root == non_existent
|
||||
assert non_existent.exists()
|
||||
|
||||
def test_root_property_fallback_to_cwd(
|
||||
self, monkeypatch: pytest.MonkeyPatch, tmp_path: Path
|
||||
) -> None:
|
||||
"""测试 root 属性在根目录被删除后回退到 cwd/.astrbot."""
|
||||
import shutil
|
||||
|
||||
# 创建并设置一个根目录
|
||||
temp_root = tmp_path / "test_root"
|
||||
temp_root.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 清除实例缓存
|
||||
AstrbotPaths._instances.clear()
|
||||
AstrbotPaths.astrbot_root = temp_root
|
||||
|
||||
# 创建实例
|
||||
paths = AstrbotPaths.getPaths("test-fallback")
|
||||
|
||||
# 删除根目录(模拟被外部删除的情况)
|
||||
shutil.rmtree(temp_root)
|
||||
|
||||
# 现在访问 root 应该回退到 cwd/.astrbot
|
||||
expected = Path.cwd() / ".astrbot"
|
||||
assert paths.root == expected
|
||||
|
||||
def test_home_property(self, paths_instance: AstrbotPaths, temp_root: Path) -> None:
|
||||
"""测试 home 属性."""
|
||||
home_path = paths_instance.home
|
||||
expected = temp_root / "home" / paths_instance.name
|
||||
assert home_path == expected
|
||||
assert home_path.exists()
|
||||
assert home_path.is_dir()
|
||||
|
||||
def test_config_property(
|
||||
self, paths_instance: AstrbotPaths, temp_root: Path
|
||||
) -> None:
|
||||
"""测试 config 属性."""
|
||||
config_path = paths_instance.config
|
||||
expected = temp_root / "config" / paths_instance.name
|
||||
assert config_path == expected
|
||||
assert config_path.exists()
|
||||
assert config_path.is_dir()
|
||||
|
||||
def test_data_property(self, paths_instance: AstrbotPaths, temp_root: Path) -> None:
|
||||
"""测试 data 属性."""
|
||||
data_path = paths_instance.data
|
||||
expected = temp_root / "data" / paths_instance.name
|
||||
assert data_path == expected
|
||||
assert data_path.exists()
|
||||
assert data_path.is_dir()
|
||||
|
||||
def test_log_property(self, paths_instance: AstrbotPaths, temp_root: Path) -> None:
|
||||
"""测试 log 属性."""
|
||||
log_path = paths_instance.log
|
||||
expected = temp_root / "logs" / paths_instance.name
|
||||
assert log_path == expected
|
||||
assert log_path.exists()
|
||||
assert log_path.is_dir()
|
||||
|
||||
def test_temp_property(self, paths_instance: AstrbotPaths, temp_root: Path) -> None:
|
||||
"""测试 temp 属性."""
|
||||
temp_path = paths_instance.temp
|
||||
expected = temp_root / "temp" / paths_instance.name
|
||||
assert temp_path == expected
|
||||
assert temp_path.exists()
|
||||
assert temp_path.is_dir()
|
||||
|
||||
def test_plugins_property(
|
||||
self, paths_instance: AstrbotPaths, temp_root: Path
|
||||
) -> None:
|
||||
"""测试 plugins 属性."""
|
||||
plugins_path = paths_instance.plugins
|
||||
expected = temp_root / "plugins" / paths_instance.name
|
||||
assert plugins_path == expected
|
||||
assert plugins_path.exists()
|
||||
assert plugins_path.is_dir()
|
||||
|
||||
def test_properties_create_nested_directories(
|
||||
self, paths_instance: AstrbotPaths, temp_root: Path
|
||||
) -> None:
|
||||
"""测试属性访问时创建嵌套目录."""
|
||||
# 清空目录
|
||||
import shutil
|
||||
|
||||
if temp_root.exists():
|
||||
for item in temp_root.iterdir():
|
||||
if item.is_dir():
|
||||
shutil.rmtree(item)
|
||||
else:
|
||||
item.unlink()
|
||||
|
||||
# 访问所有属性
|
||||
_ = paths_instance.home
|
||||
_ = paths_instance.config
|
||||
_ = paths_instance.data
|
||||
_ = paths_instance.log
|
||||
_ = paths_instance.temp
|
||||
_ = paths_instance.plugins
|
||||
|
||||
# 验证所有目录都已创建
|
||||
assert (temp_root / "home" / paths_instance.name).exists()
|
||||
assert (temp_root / "config" / paths_instance.name).exists()
|
||||
assert (temp_root / "data" / paths_instance.name).exists()
|
||||
assert (temp_root / "logs" / paths_instance.name).exists()
|
||||
assert (temp_root / "temp" / paths_instance.name).exists()
|
||||
assert (temp_root / "plugins" / paths_instance.name).exists()
|
||||
|
||||
|
||||
class TestIsRoot:
|
||||
"""测试 is_root 类方法."""
|
||||
|
||||
def test_is_root_with_marker_file(self, temp_root: Path) -> None:
|
||||
"""测试带有标记文件的根目录识别."""
|
||||
marker_file = temp_root / ".astrbot"
|
||||
marker_file.touch()
|
||||
|
||||
assert AstrbotPaths.is_root(temp_root) is True
|
||||
|
||||
def test_is_root_without_marker_file(self, temp_root: Path) -> None:
|
||||
"""测试没有标记文件的目录."""
|
||||
marker_file = temp_root / ".astrbot"
|
||||
if marker_file.exists():
|
||||
marker_file.unlink()
|
||||
|
||||
assert AstrbotPaths.is_root(temp_root) is False
|
||||
|
||||
def test_is_root_with_non_existent_path(self) -> None:
|
||||
"""测试不存在的路径."""
|
||||
non_existent = Path("/definitely/not/exist/path")
|
||||
assert AstrbotPaths.is_root(non_existent) is False
|
||||
|
||||
def test_is_root_with_file_not_directory(self, temp_root: Path) -> None:
|
||||
"""测试路径是文件而非目录."""
|
||||
test_file = temp_root / "test.txt"
|
||||
test_file.touch()
|
||||
|
||||
assert AstrbotPaths.is_root(test_file) is False
|
||||
|
||||
|
||||
class TestReload:
|
||||
"""测试 reload 方法."""
|
||||
|
||||
def test_reload_updates_root(
|
||||
self, temp_root: Path, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
"""测试 reload 更新根目录."""
|
||||
paths = AstrbotPaths.getPaths("test-reload")
|
||||
|
||||
# 修改环境变量
|
||||
new_root = temp_root / "new_root"
|
||||
new_root.mkdir(parents=True, exist_ok=True)
|
||||
monkeypatch.setenv("ASTRBOT_ROOT", str(new_root))
|
||||
|
||||
# 重新加载
|
||||
paths.reload()
|
||||
|
||||
# 验证根目录已更新
|
||||
assert AstrbotPaths.astrbot_root == new_root.absolute()
|
||||
|
||||
def test_reload_clears_old_env(
|
||||
self, temp_root: Path, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
"""测试 reload 在环境变量被删除后使用默认值."""
|
||||
paths = AstrbotPaths.getPaths("test-reload-default")
|
||||
|
||||
# 删除环境变量
|
||||
monkeypatch.delenv("ASTRBOT_ROOT", raising=False)
|
||||
|
||||
# 重新加载
|
||||
paths.reload()
|
||||
|
||||
# 应该使用默认值
|
||||
(Path.home() / ".astrbot").absolute()
|
||||
# 由于 .env 文件可能存在,实际结果可能不变
|
||||
# 所以我们只验证 reload 没有抛出异常
|
||||
assert AstrbotPaths.astrbot_root is not None
|
||||
assert isinstance(AstrbotPaths.astrbot_root, Path)
|
||||
|
||||
|
||||
class TestChdir:
|
||||
"""测试 chdir 上下文管理器."""
|
||||
|
||||
def test_chdir_changes_directory(
|
||||
self, paths_instance: AstrbotPaths, temp_root: Path
|
||||
) -> None:
|
||||
"""测试 chdir 切换目录."""
|
||||
original_cwd = Path.cwd()
|
||||
|
||||
# 创建目标目录
|
||||
target_path = temp_root / "home"
|
||||
target_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
with paths_instance.chdir("home") as target_dir:
|
||||
current_cwd = Path.cwd()
|
||||
expected_dir = temp_root / "home"
|
||||
assert current_cwd == expected_dir
|
||||
assert target_dir == expected_dir
|
||||
|
||||
# 验证已恢复原目录
|
||||
assert Path.cwd() == original_cwd
|
||||
|
||||
def test_chdir_restores_on_exception(
|
||||
self, paths_instance: AstrbotPaths, temp_root: Path
|
||||
) -> None:
|
||||
"""测试 chdir 在异常时恢复原目录."""
|
||||
original_cwd = Path.cwd()
|
||||
|
||||
# 创建目标目录
|
||||
target_path = temp_root / "home"
|
||||
target_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
with paths_instance.chdir("home"):
|
||||
raise ValueError("Test exception")
|
||||
|
||||
# 验证已恢复原目录
|
||||
assert Path.cwd() == original_cwd
|
||||
|
||||
def test_chdir_with_different_subdirectories(
|
||||
self, paths_instance: AstrbotPaths, temp_root: Path
|
||||
) -> None:
|
||||
"""测试 chdir 使用不同的子目录."""
|
||||
original_cwd = Path.cwd()
|
||||
|
||||
# 创建测试目录
|
||||
test_dir = temp_root / "test_subdir"
|
||||
test_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
with paths_instance.chdir("test_subdir") as target_dir:
|
||||
assert Path.cwd() == test_dir
|
||||
assert target_dir == test_dir
|
||||
|
||||
assert Path.cwd() == original_cwd
|
||||
|
||||
|
||||
class TestAchdir:
|
||||
"""测试 achdir 异步上下文管理器."""
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_achdir_changes_directory(
|
||||
self, paths_instance: AstrbotPaths, temp_root: Path
|
||||
) -> None:
|
||||
"""测试 achdir 异步切换目录."""
|
||||
original_cwd = Path.cwd()
|
||||
|
||||
# 创建目标目录
|
||||
target_path = temp_root / "home"
|
||||
target_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
async with paths_instance.achdir("home") as target_dir:
|
||||
current_cwd = Path.cwd()
|
||||
expected_dir = temp_root / "home"
|
||||
assert current_cwd == expected_dir
|
||||
assert target_dir == expected_dir
|
||||
|
||||
# 验证已恢复原目录
|
||||
assert Path.cwd() == original_cwd
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_achdir_restores_on_exception(
|
||||
self, paths_instance: AstrbotPaths, temp_root: Path
|
||||
) -> None:
|
||||
"""测试 achdir 在异常时恢复原目录."""
|
||||
original_cwd = Path.cwd()
|
||||
|
||||
# 创建目标目录
|
||||
target_path = temp_root / "home"
|
||||
target_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
async with paths_instance.achdir("home"):
|
||||
raise ValueError("Test exception")
|
||||
|
||||
# 验证已恢复原目录
|
||||
assert Path.cwd() == original_cwd
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_achdir_with_different_subdirectories(
|
||||
self, paths_instance: AstrbotPaths, temp_root: Path
|
||||
) -> None:
|
||||
"""测试 achdir 使用不同的子目录."""
|
||||
original_cwd = Path.cwd()
|
||||
|
||||
# 创建测试目录
|
||||
test_dir = temp_root / "async_test_subdir"
|
||||
test_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
async with paths_instance.achdir("async_test_subdir") as target_dir:
|
||||
assert Path.cwd() == test_dir
|
||||
assert target_dir == test_dir
|
||||
|
||||
assert Path.cwd() == original_cwd
|
||||
|
||||
|
||||
class TestIntegration:
|
||||
"""集成测试."""
|
||||
|
||||
def test_multiple_modules_isolated(self, temp_root: Path) -> None:
|
||||
"""测试多个模块之间的隔离."""
|
||||
module_a = AstrbotPaths.getPaths("module-a")
|
||||
module_b = AstrbotPaths.getPaths("module-b")
|
||||
|
||||
# 访问各自的 home 目录
|
||||
home_a = module_a.home
|
||||
home_b = module_b.home
|
||||
|
||||
# 验证目录不同
|
||||
assert home_a != home_b
|
||||
assert home_a == temp_root / "home" / "module-a"
|
||||
assert home_b == temp_root / "home" / "module-b"
|
||||
|
||||
# 验证都存在
|
||||
assert home_a.exists()
|
||||
assert home_b.exists()
|
||||
|
||||
def test_full_workflow(self, temp_root: Path) -> None:
|
||||
"""测试完整工作流."""
|
||||
# 创建一个模块
|
||||
module = AstrbotPaths.getPaths("my-plugin")
|
||||
|
||||
# 创建各种文件
|
||||
config_file = module.config / "settings.json"
|
||||
config_file.write_text('{"key": "value"}')
|
||||
|
||||
data_file = module.data / "data.txt"
|
||||
data_file.write_text("some data")
|
||||
|
||||
log_file = module.log / "app.log"
|
||||
log_file.write_text("log entry")
|
||||
|
||||
# 验证文件存在
|
||||
assert config_file.exists()
|
||||
assert data_file.exists()
|
||||
assert log_file.exists()
|
||||
|
||||
# 验证内容
|
||||
assert config_file.read_text() == '{"key": "value"}'
|
||||
assert data_file.read_text() == "some data"
|
||||
assert log_file.read_text() == "log entry"
|
||||
|
||||
def test_singleton_pattern_thread_safe(self, temp_root: Path) -> None:
|
||||
"""测试单例模式的基本行为(注意:不是真正的线程安全测试)."""
|
||||
instances = []
|
||||
for _ in range(10):
|
||||
instances.append(AstrbotPaths.getPaths("singleton-test"))
|
||||
|
||||
# 所有实例应该是同一个对象
|
||||
first = instances[0]
|
||||
for instance in instances[1:]:
|
||||
assert instance is first
|
||||
Reference in New Issue
Block a user