Files
AstrBot/model/command/manager.py

142 lines
5.7 KiB
Python

import heapq
import inspect
import traceback
from typing import Dict
from type.types import Context
from type.plugin import PluginMetadata
from type.message_event import AstrMessageEvent
from type.command import CommandResult
from type.register import RegisteredPlugins
from model.command.parser import CommandParser
from model.plugin.command import PluginCommandBridge
from SparkleLogging.utils.core import LogManager
from logging import Logger
from dataclasses import dataclass
logger: Logger = LogManager.GetLogger(log_name='astrbot')
@dataclass
class CommandMetadata():
inner_command: bool
plugin_metadata: PluginMetadata
handler: callable
use_regex: bool = False
ignore_prefix: bool = False
description: str = ""
class CommandManager():
def __init__(self):
self.commands = []
self.commands_handler: Dict[str, CommandMetadata] = {}
self.command_parser = CommandParser()
def register(self,
command: str,
description: str,
priority: int,
handler: callable,
use_regex: bool = False,
ignore_prefix: bool = False,
plugin_metadata: PluginMetadata = None,
):
'''
优先级越高,越先被处理。
use_regex: 是否使用正则表达式匹配指令。
'''
if command in self.commands_handler:
raise ValueError(f"Command {command} already exists.")
if not handler:
raise ValueError(f"Handler of {command} is None.")
heapq.heappush(self.commands, (-priority, command))
self.commands_handler[command] = CommandMetadata(
inner_command=plugin_metadata == None,
plugin_metadata=plugin_metadata,
handler=handler,
use_regex=use_regex,
ignore_prefix=ignore_prefix,
description=description
)
if plugin_metadata:
logger.info(f"已注册 {plugin_metadata.author}/{plugin_metadata.plugin_name} 的指令 {command}")
else:
logger.info(f"已注册指令 {command}")
def register_from_pcb(self, pcb: PluginCommandBridge):
for request in pcb.plugin_commands_waitlist:
plugin = None
for registered_plugin in pcb.cached_plugins:
if registered_plugin.metadata.plugin_name == request.plugin_name:
plugin = registered_plugin
break
if not plugin:
logger.warning(f"插件 {request.plugin_name} 未找到,无法注册指令 {request.command_name}")
else:
self.register(command=request.command_name,
description=request.description,
priority=request.priority,
handler=request.handler,
use_regex=request.use_regex,
ignore_prefix=request.ignore_prefix,
plugin_metadata=plugin.metadata)
self.plugin_commands_waitlist = []
async def check_command_ignore_prefix(self, message_str: str) -> bool:
for _, command in self.commands:
command_metadata = self.commands_handler[command]
if command_metadata.ignore_prefix:
trig = False
if self.commands_handler[command].use_regex:
trig = self.command_parser.regex_match(message_str, command)
else:
trig = message_str.startswith(command)
if trig:
return True
return False
async def scan_command(self, message_event: AstrMessageEvent, context: Context) -> CommandResult:
message_str = message_event.message_str
for _, command in self.commands:
trig = False
if self.commands_handler[command].use_regex:
trig = self.command_parser.regex_match(message_str, command)
else:
trig = message_str.startswith(command)
if trig:
logger.info(f"触发 {command} 指令。")
command_result = await self.execute_handler(command, message_event, context)
if not command_result:
continue
if command_result.hit:
return command_result
async def execute_handler(self,
command: str,
message_event: AstrMessageEvent,
context: Context) -> CommandResult:
command_metadata = self.commands_handler[command]
handler = command_metadata.handler
# call handler
try:
if inspect.iscoroutinefunction(handler):
command_result = await handler(message_event, context)
else:
command_result = handler(message_event, context)
if not isinstance(command_result, CommandResult):
raise ValueError(f"Command {command} handler should return CommandResult.")
context.metrics_uploader.command_stats[command] += 1
return command_result
except BaseException as e:
logger.error(traceback.format_exc())
if not command_metadata.inner_command:
text = f"执行 {command}/({command_metadata.plugin_metadata.plugin_name} By {command_metadata.plugin_metadata.author}) 指令时发生了异常。{e}"
logger.error(text)
else:
text = f"执行 {command} 指令时发生了异常。{e}"
logger.error(text)
return CommandResult().message(text)