142 lines
5.7 KiB
Python
142 lines
5.7 KiB
Python
import heapq
|
|
import inspect
|
|
import traceback
|
|
from typing import Dict
|
|
from type.types import Context
|
|
from type.plugin import PluginMetadata
|
|
from type.message_event import AstrMessageEvent
|
|
from type.command import CommandResult
|
|
from type.register import RegisteredPlugins
|
|
from model.command.parser import CommandParser
|
|
from model.plugin.command import PluginCommandBridge
|
|
from SparkleLogging.utils.core import LogManager
|
|
from logging import Logger
|
|
from dataclasses import dataclass
|
|
|
|
logger: Logger = LogManager.GetLogger(log_name='astrbot')
|
|
|
|
@dataclass
|
|
class CommandMetadata():
|
|
inner_command: bool
|
|
plugin_metadata: PluginMetadata
|
|
handler: callable
|
|
use_regex: bool = False
|
|
ignore_prefix: bool = False
|
|
description: str = ""
|
|
|
|
class CommandManager():
|
|
def __init__(self):
|
|
self.commands = []
|
|
self.commands_handler: Dict[str, CommandMetadata] = {}
|
|
self.command_parser = CommandParser()
|
|
|
|
def register(self,
|
|
command: str,
|
|
description: str,
|
|
priority: int,
|
|
handler: callable,
|
|
use_regex: bool = False,
|
|
ignore_prefix: bool = False,
|
|
plugin_metadata: PluginMetadata = None,
|
|
):
|
|
'''
|
|
优先级越高,越先被处理。
|
|
|
|
use_regex: 是否使用正则表达式匹配指令。
|
|
'''
|
|
if command in self.commands_handler:
|
|
raise ValueError(f"Command {command} already exists.")
|
|
if not handler:
|
|
raise ValueError(f"Handler of {command} is None.")
|
|
|
|
heapq.heappush(self.commands, (-priority, command))
|
|
self.commands_handler[command] = CommandMetadata(
|
|
inner_command=plugin_metadata == None,
|
|
plugin_metadata=plugin_metadata,
|
|
handler=handler,
|
|
use_regex=use_regex,
|
|
ignore_prefix=ignore_prefix,
|
|
description=description
|
|
)
|
|
if plugin_metadata:
|
|
logger.info(f"已注册 {plugin_metadata.author}/{plugin_metadata.plugin_name} 的指令 {command}。")
|
|
else:
|
|
logger.info(f"已注册指令 {command}。")
|
|
|
|
def register_from_pcb(self, pcb: PluginCommandBridge):
|
|
for request in pcb.plugin_commands_waitlist:
|
|
plugin = None
|
|
for registered_plugin in pcb.cached_plugins:
|
|
if registered_plugin.metadata.plugin_name == request.plugin_name:
|
|
plugin = registered_plugin
|
|
break
|
|
if not plugin:
|
|
logger.warning(f"插件 {request.plugin_name} 未找到,无法注册指令 {request.command_name}。")
|
|
else:
|
|
self.register(command=request.command_name,
|
|
description=request.description,
|
|
priority=request.priority,
|
|
handler=request.handler,
|
|
use_regex=request.use_regex,
|
|
ignore_prefix=request.ignore_prefix,
|
|
plugin_metadata=plugin.metadata)
|
|
self.plugin_commands_waitlist = []
|
|
|
|
async def check_command_ignore_prefix(self, message_str: str) -> bool:
|
|
for _, command in self.commands:
|
|
command_metadata = self.commands_handler[command]
|
|
if command_metadata.ignore_prefix:
|
|
trig = False
|
|
if self.commands_handler[command].use_regex:
|
|
trig = self.command_parser.regex_match(message_str, command)
|
|
else:
|
|
trig = message_str.startswith(command)
|
|
if trig:
|
|
return True
|
|
return False
|
|
|
|
async def scan_command(self, message_event: AstrMessageEvent, context: Context) -> CommandResult:
|
|
message_str = message_event.message_str
|
|
for _, command in self.commands:
|
|
trig = False
|
|
if self.commands_handler[command].use_regex:
|
|
trig = self.command_parser.regex_match(message_str, command)
|
|
else:
|
|
trig = message_str.startswith(command)
|
|
if trig:
|
|
logger.info(f"触发 {command} 指令。")
|
|
command_result = await self.execute_handler(command, message_event, context)
|
|
if not command_result:
|
|
continue
|
|
if command_result.hit:
|
|
return command_result
|
|
|
|
async def execute_handler(self,
|
|
command: str,
|
|
message_event: AstrMessageEvent,
|
|
context: Context) -> CommandResult:
|
|
command_metadata = self.commands_handler[command]
|
|
handler = command_metadata.handler
|
|
# call handler
|
|
try:
|
|
if inspect.iscoroutinefunction(handler):
|
|
command_result = await handler(message_event, context)
|
|
else:
|
|
command_result = handler(message_event, context)
|
|
|
|
if not isinstance(command_result, CommandResult):
|
|
raise ValueError(f"Command {command} handler should return CommandResult.")
|
|
|
|
context.metrics_uploader.command_stats[command] += 1
|
|
|
|
return command_result
|
|
except BaseException as e:
|
|
logger.error(traceback.format_exc())
|
|
|
|
if not command_metadata.inner_command:
|
|
text = f"执行 {command}/({command_metadata.plugin_metadata.plugin_name} By {command_metadata.plugin_metadata.author}) 指令时发生了异常。{e}"
|
|
logger.error(text)
|
|
else:
|
|
text = f"执行 {command} 指令时发生了异常。{e}"
|
|
logger.error(text)
|
|
return CommandResult().message(text) |