diff --git a/astrbot/core/config/default.py b/astrbot/core/config/default.py index 7c4517e5..78f599c2 100644 --- a/astrbot/core/config/default.py +++ b/astrbot/core/config/default.py @@ -225,12 +225,15 @@ CONFIG_METADATA_2 = { "telegram_command_auto_refresh": True, "telegram_command_register_interval": 300, }, - "discord":{ + "discord": { "id": "discord", "type": "discord", "enable": False, "discord_token": "", "discord_proxy": "", + "discord_command_register": True, + "discord_guild_id_for_debug": "", + "discord_activity_name": "", }, "slack": { "id": "slack", @@ -374,6 +377,19 @@ CONFIG_METADATA_2 = { "type": "string", "hint": "可选的代理地址:http://ip:port" }, + "discord_command_register": { + "description": "是否自动将插件指令注册为 Discord 斜杠指令", + "type": "bool", + }, + "discord_activity_name": { + "description": "Discord 活动名称", + "type": "string", + "hint": "可选的 Discord 活动名称。留空则不设置活动。", + }, + "discord_guild_id_for_debug": { + "description": "【开发用】指定一个服务器(Guild)ID。在此服务器注册的指令会立刻生效,便于调试。留空则注册为全局指令。", + "type": "string", + }, }, }, "platform_settings": { diff --git a/astrbot/core/platform/sources/discord/client.py b/astrbot/core/platform/sources/discord/client.py index f151c1d1..2e3fd89b 100644 --- a/astrbot/core/platform/sources/discord/client.py +++ b/astrbot/core/platform/sources/discord/client.py @@ -1,5 +1,11 @@ import discord from astrbot import logger +import sys + +if sys.version_info >= (3, 12): + from typing import override +else: + from typing_extensions import override # Discord Bot客户端 @@ -20,12 +26,23 @@ class DiscordBotClient(discord.Bot): # 回调函数 self.on_message_received = None + self.on_ready_once_callback = None + self._ready_once_fired = False + @override async def on_ready(self): """当机器人成功连接并准备就绪时触发""" logger.info(f"[Discord] 已作为 {self.user} (ID: {self.user.id}) 登录") logger.info("[Discord] 客户端已准备就绪。") + if self.on_ready_once_callback and not self._ready_once_fired: + self._ready_once_fired = True + try: + await self.on_ready_once_callback() + except Exception as e: + logger.error( + f"[Discord] on_ready_once_callback 执行失败: {e}", exc_info=True) + def _create_message_data(self, message: discord.Message) -> dict: """从 discord.Message 创建数据字典""" is_mentioned = self.user in message.mentions @@ -59,6 +76,7 @@ class DiscordBotClient(discord.Bot): "type": "interaction", } + @override async def on_message(self, message: discord.Message): """当接收到消息时触发""" if message.author.bot: @@ -72,15 +90,6 @@ class DiscordBotClient(discord.Bot): message_data = self._create_message_data(message) await self.on_message_received(message_data) - async def on_interaction(self, interaction: discord.Interaction): - """当接收到交互(按钮点击等)时触发""" - logger.debug( - f"[Discord] 收到交互 from {interaction.user.name}: {interaction.data}" - ) - - if self.on_message_received: - interaction_data = self._create_interaction_data(interaction) - await self.on_message_received(interaction_data) def _extract_interaction_content(self, interaction: discord.Interaction) -> str: """从交互中提取内容""" @@ -110,6 +119,7 @@ class DiscordBotClient(discord.Bot): """开始轮询消息,这是个阻塞方法""" await self.start(self.token) + @override async def close(self): """关闭客户端""" if not self.is_closed(): diff --git a/astrbot/core/platform/sources/discord/components.py b/astrbot/core/platform/sources/discord/components.py index 996f7957..dbeda38a 100644 --- a/astrbot/core/platform/sources/discord/components.py +++ b/astrbot/core/platform/sources/discord/components.py @@ -79,6 +79,13 @@ class DiscordButton(BaseMessageComponent): self.url = url self.disabled = disabled +class DiscordReference(BaseMessageComponent): + """Discord引用组件""" + type: str = "discord_reference" + def __init__(self, message_id: str, channel_id: str): + self.message_id = message_id + self.channel_id = channel_id + class DiscordView(BaseMessageComponent): """Discord视图组件,包含按钮和选择菜单""" @@ -91,6 +98,7 @@ class DiscordView(BaseMessageComponent): self.components = components or [] self.timeout = timeout + def to_discord_view(self) -> discord.ui.View: """转换为Discord View对象""" view = discord.ui.View(timeout=self.timeout) diff --git a/astrbot/core/platform/sources/discord/discord_platform_adapter.py b/astrbot/core/platform/sources/discord/discord_platform_adapter.py index 7f52f52e..9b30667e 100644 --- a/astrbot/core/platform/sources/discord/discord_platform_adapter.py +++ b/astrbot/core/platform/sources/discord/discord_platform_adapter.py @@ -14,6 +14,19 @@ from astrbot.api.platform import register_platform_adapter from astrbot import logger from .client import DiscordBotClient from .discord_platform_event import DiscordPlatformEvent +import sys +from functools import partial +from typing import Any, Dict, List, Tuple, Type +from astrbot.core.star.filter.command import CommandFilter, GreedyStr +from astrbot.core.star.filter.command_group import CommandGroupFilter +from astrbot.core.star.star import star_map +from astrbot.core.star.star_handler import StarHandlerMetadata, star_handlers_registry +import re + +if sys.version_info >= (3, 12): + from typing import override +else: + from typing_extensions import override # 注册平台适配器 @@ -27,7 +40,13 @@ class DiscordPlatformAdapter(Platform): self.settings = platform_settings self.client_self_id = None self.registered_handlers = [] + # 指令注册相关 + self.enable_command_register = self.config.get( + "discord_command_register", True) + self.guild_id = self.config.get("discord_guild_id_for_debug", None) + self.activity_name = self.config.get("discord_activity_name", None) + @override async def send_by_session( self, session: MessageSesion, message_chain: MessageChain ): @@ -43,6 +62,7 @@ class DiscordPlatformAdapter(Platform): await temp_event.send(message_chain) await super().send_by_session(session, message_chain) + @override def meta(self) -> PlatformMetadata: """返回平台元数据""" return PlatformMetadata( @@ -52,6 +72,7 @@ class DiscordPlatformAdapter(Platform): default_config_tmpl=self.config, ) + @override async def run(self): """主要运行逻辑""" @@ -73,6 +94,14 @@ class DiscordPlatformAdapter(Platform): self.client = DiscordBotClient(token, proxy) self.client.on_message_received = on_received + async def callback(): + if self.enable_command_register: + await self._collect_and_register_commands() + if self.activity_name: + await self.client.change_presence(status=discord.Status.online, activity=discord.CustomActivity(name=self.activity_name)) + + self.client.on_ready_once_callback = callback + try: await self.client.start_polling() except discord.errors.LoginFailure: @@ -95,32 +124,6 @@ class DiscordPlatformAdapter(Platform): gid = guild_id or getattr(channel, "guild", None).id return MessageType.GROUP_MESSAGE, str(gid) - def _convert_interaction_to_abm(self, data: dict) -> AstrBotMessage: - """将交互事件转换为 AstrBotMessage""" - interaction: discord.Interaction = data["interaction"] - abm = AstrBotMessage() - - abm.type, abm.group_id = self._determine_message_type( - interaction.channel, interaction.guild_id - ) - - # 对于交互事件,message_str 通常没有意义,且可能导致被闲聊等通用插件错误响应。 - # 将其清空,以确保只有专门的指令处理器会响应。 - abm.message_str = "" - abm.sender = MessageMember( - user_id=str(interaction.user.id), nickname=interaction.user.display_name - ) - abm.message = [Plain(text=data["content"])] - abm.raw_message = interaction - abm.self_id = self.client_self_id - abm.session_id = ( - str(interaction.channel_id) - if interaction.channel_id - else str(interaction.user.id) - ) - abm.message_id = str(interaction.id) - return abm - def _convert_message_to_abm(self, data: dict) -> AstrBotMessage: """将普通消息转换为 AstrBotMessage""" message: discord.Message = data["message"] @@ -142,9 +145,9 @@ class DiscordPlatformAdapter(Platform): ) if content.startswith(mention_str): - content = content[len(mention_str) :].lstrip() + content = content[len(mention_str):].lstrip() elif content.startswith(mention_str_nickname): - content = content[len(mention_str_nickname) :].lstrip() + content = content[len(mention_str_nickname):].lstrip() abm = AstrBotMessage() @@ -181,12 +184,10 @@ class DiscordPlatformAdapter(Platform): async def convert_message(self, data: dict) -> AstrBotMessage: """将平台消息转换成 AstrBotMessage""" - if data.get("type") in ["interaction", "slash_command"]: - return self._convert_interaction_to_abm(data) - else: - return self._convert_message_to_abm(data) + # 由于 on_interaction 已被禁用,我们只处理普通消息 + return self._convert_message_to_abm(data) - async def handle_msg(self, message: AstrBotMessage): + async def handle_msg(self, message: AstrBotMessage, followup_webhook=None): """处理消息""" message_event = DiscordPlatformEvent( message_str=message.message_str, @@ -194,23 +195,43 @@ class DiscordPlatformAdapter(Platform): platform_meta=self.meta(), session_id=message.session_id, client=self.client, + interaction_followup_webhook=followup_webhook, ) - # 如果是被@的消息,设置为唤醒状态 - if ( + # 检查是否为斜杠指令 + is_slash_command = message_event.interaction_followup_webhook is not None + + # 检查是否被@ + is_mention = ( self.client and self.client.user and hasattr(message.raw_message, "mentions") and self.client.user in message.raw_message.mentions - ): + ) + + # 如果是斜杠指令或被@的消息,设置为唤醒状态 + if is_slash_command or is_mention: message_event.is_wake = True message_event.is_at_or_wake_command = True self.commit_event(message_event) + @override async def terminate(self): """终止适配器""" logger.info("[Discord] 正在终止适配器...") + + # 清理指令 + if self.enable_command_register and self.client: + logger.info("[Discord] 正在清理已注册的斜杠指令...") + try: + # 传入空的列表来清除所有全局指令 + # 如果指定了 guild_id,则只清除该服务器的指令 + await self.client.sync_commands(commands=[], guild_ids=[self.guild_id] if self.guild_id else None) + logger.info("[Discord] 指令清理完成。") + except Exception as e: + logger.error(f"[Discord] 清理指令时发生错误: {e}", exc_info=True) + if self.client and hasattr(self.client, "close"): await self.client.close() logger.info("[Discord] 适配器已终止。") @@ -218,3 +239,132 @@ class DiscordPlatformAdapter(Platform): def register_handler(self, handler_info): """注册处理器信息""" self.registered_handlers.append(handler_info) + + async def _collect_and_register_commands(self): + """收集所有指令并注册到Discord""" + logger.info("[Discord] 开始收集并注册斜杠指令...") + registered_commands = [] + + for handler_md in star_handlers_registry: + if not star_map[handler_md.handler_module_path].activated: + continue + for event_filter in handler_md.event_filters: + cmd_info = self._extract_command_info(event_filter, handler_md) + if not cmd_info: + continue + + cmd_name, description, cmd_filter_instance = cmd_info + + # 创建动态回调 + callback = self._create_dynamic_callback(cmd_name) + + # 创建一个通用的参数选项来接收所有文本输入 + options = [ + discord.Option( + name="params", + description="指令的所有参数", + type=discord.SlashCommandOptionType.string, + required=False, + ) + ] + + # 创建SlashCommand + slash_command = discord.SlashCommand( + name=cmd_name, + description=description, + func=callback, + options=options, + guild_ids=[self.guild_id] if self.guild_id else None, + ) + self.client.add_application_command(slash_command) + registered_commands.append(cmd_name) + + if registered_commands: + logger.info( + f"[Discord] 准备同步 {len(registered_commands)} 个指令: {', '.join(registered_commands)}") + else: + logger.info("[Discord] 没有发现可注册的指令。") + + # 使用 Pycord 的方法同步指令 + # 注意:这可能需要一些时间,并且有频率限制 + await self.client.sync_commands() + logger.info("[Discord] 指令同步完成。") + + def _create_dynamic_callback(self, cmd_name: str): + """为每个指令动态创建一个异步回调函数""" + async def dynamic_callback(ctx: discord.ApplicationContext, params: str = None): + # 将平台特定的前缀'/'剥离,以适配通用的CommandFilter + logger.debug(f"[Discord] 回调函数触发: {cmd_name}") + logger.debug(f"[Discord] 回调函数参数: {ctx}") + logger.debug(f"[Discord] 回调函数参数: {params}") + message_str_for_filter = cmd_name + if params: + message_str_for_filter += " " + params + + logger.debug( + f"[Discord] 斜杠指令 '{cmd_name}' 被触发。 " + f"原始参数: '{params}'. " + f"构建的指令字符串: '{message_str_for_filter}'" + ) + + # 尝试立即响应,防止超时 + followup_webhook = None + try: + await ctx.defer() + followup_webhook = ctx.followup + except Exception as e: + logger.warning(f"[Discord] 指令 '{cmd_name}' defer 失败: {e}") + + # 2. 构建 AstrBotMessage + abm = AstrBotMessage() + abm.type, abm.group_id = self._determine_message_type( + ctx.channel, ctx.guild_id + ) + abm.message_str = message_str_for_filter + abm.sender = MessageMember( + user_id=str(ctx.author.id), nickname=ctx.author.display_name + ) + abm.message = [Plain(text=message_str_for_filter)] + abm.raw_message = ctx.interaction + abm.self_id = self.client_self_id + abm.session_id = str(ctx.channel_id) + abm.message_id = str(ctx.interaction.id) + + # 3. 将消息和 webhook 分别交给 handle_msg 处理 + await self.handle_msg(abm, followup_webhook) + + return dynamic_callback + + @staticmethod + def _extract_command_info( + event_filter: Any, handler_metadata: StarHandlerMetadata + ) -> Tuple[str, str, CommandFilter] | None: + """从事件过滤器中提取指令信息""" + cmd_name = None + is_group = False + cmd_filter_instance = None + + if isinstance(event_filter, CommandFilter): + # 暂不支持子指令注册为斜杠指令 + if event_filter.parent_command_names and event_filter.parent_command_names != [""]: + return None + cmd_name = event_filter.command_name + cmd_filter_instance = event_filter + + elif isinstance(event_filter, CommandGroupFilter): + # 暂不支持指令组直接注册为斜杠指令,因为它们没有 handle 方法 + return None + + if not cmd_name: + return None + + # Discord 斜杠指令名称规范 + if not re.match(r"^[a-z0-9_-]{1,32}$", cmd_name): + logger.debug(f"[Discord] 跳过不符合规范的指令: {cmd_name}") + return None + + description = handler_metadata.desc or f"指令: {cmd_name}" + if len(description) > 100: + description = description[:97] + "..." + + return cmd_name, description, cmd_filter_instance diff --git a/astrbot/core/platform/sources/discord/discord_platform_event.py b/astrbot/core/platform/sources/discord/discord_platform_event.py index bb9378de..c6aaf5de 100644 --- a/astrbot/core/platform/sources/discord/discord_platform_event.py +++ b/astrbot/core/platform/sources/discord/discord_platform_event.py @@ -3,15 +3,27 @@ import discord import base64 from io import BytesIO from pathlib import Path -from typing import Optional +from typing import Optional, List +import sys from astrbot.api.event import AstrMessageEvent, MessageChain -from astrbot.api.platform import AstrBotMessage, PlatformMetadata -from astrbot.api.message_components import Plain, Image, File, BaseMessageComponent +from astrbot.api.platform import AstrBotMessage, PlatformMetadata, At +from astrbot.api.message_components import ( + Plain, + Image, + File, + BaseMessageComponent, + Reply, +) from astrbot import logger from .client import DiscordBotClient from .components import DiscordEmbed, DiscordView +if sys.version_info >= (3, 12): + from typing import override +else: + from typing_extensions import override + # 自定义Discord视图组件(兼容旧版本) class DiscordViewComponent(BaseMessageComponent): @@ -29,36 +41,52 @@ class DiscordPlatformEvent(AstrMessageEvent): platform_meta: PlatformMetadata, session_id: str, client: DiscordBotClient, + interaction_followup_webhook: Optional[discord.Webhook] = None, ): super().__init__(message_str, message_obj, platform_meta, session_id) self.client = client + self.interaction_followup_webhook = interaction_followup_webhook + @override async def send(self, message: MessageChain): """发送消息到Discord平台""" + + # 解析消息链为 Discord 所需的对象 try: - channel = await self._get_channel() - if not channel: - logger.error(f"[Discord] 无法获取频道 {self.session_id}") - return + content, files, view, embeds, reference_message_id = await self._parse_to_discord(message) + except Exception as e: + logger.error(f"[Discord] 解析消息链时失败: {e}", exc_info=True) + return - # 解析消息链 - content, files, view, embeds = await self._parse_to_discord(message) + kwargs = {} + if content: + kwargs["content"] = content + if files: + kwargs["files"] = files + if view: + kwargs["view"] = view + if embeds: + kwargs["embeds"] = embeds + if reference_message_id and not self.interaction_followup_webhook: + kwargs["reference"] = self.client.get_message(int(reference_message_id)) + if not kwargs: + logger.debug("[Discord] 尝试发送空消息,已忽略。") + return - # Discord 不允许发送完全空的消息 - if not content and not files and not view and not embeds: - logger.debug("[Discord] 尝试发送空消息,已忽略。") - return + # 根据上下文执行发送/回复操作 + try: + # -- 斜杠指令/交互上下文 -- + if self.interaction_followup_webhook: + await self.interaction_followup_webhook.send(**kwargs) - # 发送消息 - await channel.send( - content=content or None, - files=files or None, - view=view or None, - embeds=embeds or None, - ) + # -- 常规消息上下文 -- + else: + channel = await self._get_channel() + if not channel: + return + else: + await channel.send(**kwargs) - except discord.errors.HTTPException as e: - logger.error(f"[Discord] 发送消息失败: {e.status} {e.code} - {e.text}") except Exception as e: logger.error(f"[Discord] 发送消息时发生未知错误: {e}", exc_info=True) @@ -80,14 +108,18 @@ class DiscordPlatformEvent(AstrMessageEvent): message: MessageChain, ) -> tuple[str, list[discord.File], Optional[discord.ui.View], list[discord.Embed]]: """将 MessageChain 解析为 Discord 发送所需的内容""" - plain_text_parts = [] + content = "" files = [] view = None embeds = [] - + reference_message_id = None for i in message.chain: # 遍历消息链 if isinstance(i, Plain): # 如果是文字类型的 - plain_text_parts.append(i.text) + content += i.text + elif isinstance(i, Reply): + reference_message_id = i.id + elif isinstance(i, At): + content += f"<@{i.qq}>" elif isinstance(i, Image): logger.debug(f"[Discord] 开始处理 Image 组件: {i}") try: @@ -174,7 +206,8 @@ class DiscordPlatformEvent(AstrMessageEvent): if await asyncio.to_thread(path.exists): file_bytes = await asyncio.to_thread(path.read_bytes) files.append( - discord.File(BytesIO(file_bytes), filename=i.name) + discord.File(BytesIO(file_bytes), + filename=i.name) ) else: logger.warning( @@ -197,37 +230,10 @@ class DiscordPlatformEvent(AstrMessageEvent): else: logger.debug(f"[Discord] 忽略了不支持的消息组件: {i.type}") - # 合并文本内容 - content = "\n".join(plain_text_parts) if len(content) > 2000: logger.warning("[Discord] 消息内容超过2000字符,将被截断。") content = content[:2000] - - return content, files, view, embeds - - async def reply(self, message: MessageChain): - """回复消息(如果原消息存在)""" - try: - if hasattr(self.message_obj, "raw_message") and hasattr( - self.message_obj.raw_message, "reply" - ): - # 解析消息链 - content, files, view, embeds = await self._parse_to_discord(message) - - # 使用Discord的回复功能 - await self.message_obj.raw_message.reply( - content=content or None, - files=files or None, - view=view or None, - embeds=embeds or None, - ) - else: - # 如果无法回复,使用普通发送 - await self.send(message) - except Exception as e: - logger.error(f"[Discord] 回复消息失败: {e}") - # 回退到普通发送 - await self.send(message) + return content, files, view, embeds, reference_message_id async def react(self, emoji: str): """对原消息添加反应"""