diff --git a/astrbot/core/__init__.py b/astrbot/core/__init__.py index 20ec6167..13a63063 100644 --- a/astrbot/core/__init__.py +++ b/astrbot/core/__init__.py @@ -20,7 +20,9 @@ if os.environ.get("TESTING", ""): logger.setLevel("DEBUG") db_helper = SQLiteDatabase(DB_PATH) -sp = SharedPreferences() # 简单的偏好设置存储 +sp = ( + SharedPreferences() +) # 简单的偏好设置存储, 这里后续应该存储到数据库中, 一些部分可以存储到配置中 pip_installer = PipInstaller(astrbot_config.get("pip_install_arg", "")) web_chat_queue = asyncio.Queue(maxsize=32) web_chat_back_queue = asyncio.Queue(maxsize=32) diff --git a/astrbot/core/conversation_mgr.py b/astrbot/core/conversation_mgr.py index ed29dddd..0e11fd46 100644 --- a/astrbot/core/conversation_mgr.py +++ b/astrbot/core/conversation_mgr.py @@ -1,3 +1,10 @@ +""" +AstrBot 会话-对话管理器, 维护两个本地存储, 其中一个是 json 格式的shared_preferences, 另外一个是数据库 + +在 AstrBot 中, 会话和对话是独立的, 会话用于标记对话窗口, 例如群聊"123456789"可以建立一个会话, +在一个会话中可以建立多个对话, 并且支持对话的切换和删除 +""" + import uuid import json import asyncio @@ -11,21 +18,24 @@ class ConversationManager: """负责管理会话与 LLM 的对话,某个会话当前正在用哪个对话。""" def __init__(self, db_helper: BaseDatabase): - # session_conversations 字典记录会话ID-用户ID 映射关系 + # session_conversations 字典记录会话ID-对话ID 映射关系 self.session_conversations: Dict[str, str] = sp.get("session_conversation", {}) self.db = db_helper self.save_interval = 60 # 每 60 秒保存一次 self._start_periodic_save() def _start_periodic_save(self): + """启动定时保存任务""" asyncio.create_task(self._periodic_save()) async def _periodic_save(self): + """定时保存会话对话映射关系到存储中""" while True: await asyncio.sleep(self.save_interval) self._save_to_storage() def _save_to_storage(self): + """保存会话对话映射关系到存储中""" sp.put("session_conversation", self.session_conversations) async def new_conversation(self, unified_msg_origin: str) -> str: @@ -97,6 +107,7 @@ class ConversationManager: async def get_human_readable_context( self, unified_msg_origin, conversation_id, page=1, page_size=10 ): + """获取人类可读的上下文""" conversation = await self.get_conversation(unified_msg_origin, conversation_id) history = json.loads(conversation.history) diff --git a/astrbot/core/pipeline/__init__.py b/astrbot/core/pipeline/__init__.py index 76844f6f..b97fc0f1 100644 --- a/astrbot/core/pipeline/__init__.py +++ b/astrbot/core/pipeline/__init__.py @@ -12,6 +12,7 @@ from .process_stage.stage import ProcessStage from .result_decorate.stage import ResultDecorateStage from .respond.stage import RespondStage +# 管道阶段顺序 STAGES_ORDER = [ "WakingCheckStage", # 检查是否需要唤醒 "WhitelistCheckStage", # 检查是否在群聊/私聊白名单 diff --git a/astrbot/core/pipeline/context.py b/astrbot/core/pipeline/context.py index 1abbca4e..eb5ffb1c 100644 --- a/astrbot/core/pipeline/context.py +++ b/astrbot/core/pipeline/context.py @@ -5,5 +5,7 @@ from astrbot.core.star import PluginManager @dataclass class PipelineContext: - astrbot_config: AstrBotConfig - plugin_manager: PluginManager + """上下文对象,包含管道执行所需的上下文信息""" + + astrbot_config: AstrBotConfig # AstrBot 配置对象 + plugin_manager: PluginManager # 插件管理器对象 diff --git a/astrbot/core/pipeline/scheduler.py b/astrbot/core/pipeline/scheduler.py index 66874b80..2ed3c0d2 100644 --- a/astrbot/core/pipeline/scheduler.py +++ b/astrbot/core/pipeline/scheduler.py @@ -7,23 +7,29 @@ from astrbot.core import logger class PipelineScheduler: + """管道调度器,负责调度各个阶段的执行""" + def __init__(self, context: PipelineContext): - registered_stages.sort(key=lambda x: STAGES_ORDER.index(x.__class__.__name__)) - self.ctx = context + registered_stages.sort( + key=lambda x: STAGES_ORDER.index(x.__class__.__name__) + ) # 按照顺序排序 + self.ctx = context # 上下文对象 async def initialize(self): + """初始化管道调度器时, 初始化所有阶段""" for stage in registered_stages: # logger.debug(f"初始化阶段 {stage.__class__ .__name__}") await stage.initialize(self.ctx) async def _process_stages(self, event: AstrMessageEvent, from_stage=0): + """依次执行各个阶段""" for i in range(from_stage, len(registered_stages)): stage = registered_stages[i] # logger.debug(f"执行阶段 {stage.__class__ .__name__}") - coro = stage.process(event) - if isinstance(coro, AsyncGenerator): - async for _ in coro: + coroutine = stage.process(event) + if isinstance(coroutine, AsyncGenerator): + async for _ in coroutine: if event.is_stopped(): logger.debug( f"阶段 {stage.__class__.__name__} 已终止事件传播。" @@ -36,7 +42,7 @@ class PipelineScheduler: ) break else: - await coro + await coroutine if event.is_stopped(): logger.debug(f"阶段 {stage.__class__.__name__} 已终止事件传播。") diff --git a/astrbot/core/updator.py b/astrbot/core/updator.py index 0d9860a6..1e7279a8 100644 --- a/astrbot/core/updator.py +++ b/astrbot/core/updator.py @@ -9,6 +9,11 @@ from astrbot.core.utils.io import download_file class AstrBotUpdator(RepoZipUpdator): + """AstrBot 更新器,继承自 RepoZipUpdator 类 + 该类用于处理 AstrBot 的更新操作 + 功能包括检查更新、下载更新文件、解压缩更新文件等 + """ + def __init__(self, repo_mirror: str = "") -> None: super().__init__(repo_mirror) self.MAIN_PATH = os.path.abspath( @@ -17,6 +22,9 @@ class AstrBotUpdator(RepoZipUpdator): self.ASTRBOT_RELEASE_API = "https://api.soulter.top/releases" def terminate_child_processes(self): + """终止当前进程的所有子进程 + 使用 psutil 库获取当前进程的所有子进程,并尝试终止它们 + """ try: parent = psutil.Process(os.getpid()) children = parent.children(recursive=True) @@ -35,6 +43,9 @@ class AstrBotUpdator(RepoZipUpdator): pass def _reboot(self, delay: int = 3): + """重启当前程序 + 在指定的延迟后,终止所有子进程并重新启动程序 + """ py = sys.executable time.sleep(delay) self.terminate_child_processes() @@ -46,6 +57,7 @@ class AstrBotUpdator(RepoZipUpdator): raise e async def check_update(self, url: str, current_version: str) -> ReleaseInfo: + """检查更新""" return await super().check_update(self.ASTRBOT_RELEASE_API, VERSION) async def get_releases(self) -> list: