Fix: 当多个相同消息平台实例部署时上下文可能混乱(共享) (#2298)

* perf: update astrbot event session format, using platfrom id to ensure uniqueness

fixes: #1000

* fix: 更新 MessageSession 类以使用 platform_id 作为唯一标识符,并调整相关方法以确保一致性

* fix: 更新 MessageSession 文档以明确 platform_id 的赋值规则,并调整 get_platform 和 get_platform_inst 方法的返回类型
This commit is contained in:
Soulter
2025-08-02 21:38:55 +08:00
committed by GitHub
parent 1b37530c96
commit 87f05fce66
7 changed files with 54 additions and 14 deletions

View File

@@ -232,6 +232,6 @@ class AstrBotCoreLifecycle:
platform_insts = self.platform_manager.get_insts()
for platform_inst in platform_insts:
tasks.append(
asyncio.create_task(platform_inst.run(), name=platform_inst.meta().name)
asyncio.create_task(platform_inst.run(), name=f"{platform_inst.meta().id}({platform_inst.meta().name})")
)
return tasks

View File

@@ -48,10 +48,10 @@ class EventBus:
# 如果有发送者名称: [平台名] 发送者名称/发送者ID: 消息概要
if event.get_sender_name():
logger.info(
f"[{event.get_platform_name()}] {event.get_sender_name()}/{event.get_sender_id()}: {event.get_message_outline()}"
f"[{event.get_platform_id()}({event.get_platform_name()})] {event.get_sender_name()}/{event.get_sender_id()}: {event.get_message_outline()}"
)
# 没有发送者名称: [平台名] 发送者ID: 消息概要
else:
logger.info(
f"[{event.get_platform_name()}] {event.get_sender_id()}: {event.get_message_outline()}"
f"[{event.get_platform_id()}({event.get_platform_name()})] {event.get_sender_id()}: {event.get_message_outline()}"
)

View File

@@ -128,7 +128,7 @@ class RespondStage(Stage):
use_fallback = self.config.get("provider_settings", {}).get(
"streaming_segmented", False
)
logger.info(f"应用流式输出({event.get_platform_name()})")
logger.info(f"应用流式输出({event.get_platform_id()})")
await event.send_streaming(result.async_stream, use_fallback)
return
elif len(result.chain) > 0:

View File

@@ -27,19 +27,29 @@ from .platform_metadata import PlatformMetadata
@dataclass
class MessageSession:
"""描述一条消息在 AstrBot 中对应的会话的唯一标识。
如果您需要实例化 MessageSession请不要给 platform_id 赋值(或者同时给 platform_name 和 platform_id 赋值相同值)。它会在 __post_init__ 中自动设置为 platform_name 的值。"""
platform_name: str
"""平台适配器实例的唯一标识符。自 AstrBot v4.0.0 起,该字段实际为 platform_id。"""
message_type: MessageType
session_id: str
platform_id: str = None
def __str__(self):
return f"{self.platform_name}:{self.message_type.value}:{self.session_id}"
return f"{self.platform_id}:{self.message_type.value}:{self.session_id}"
def __post_init__(self):
self.platform_id = self.platform_name
@staticmethod
def from_str(session_str: str):
platform_name, message_type, session_id = session_str.split(":")
return MessageSession(platform_name, MessageType(message_type), session_id)
platform_id, message_type, session_id = session_str.split(":")
return MessageSession(platform_id, MessageType(message_type), session_id)
MessageSesion = MessageSession # back compatibility
MessageSesion = MessageSession # back compatibility
class AstrMessageEvent(abc.ABC):
def __init__(
@@ -65,7 +75,7 @@ class AstrMessageEvent(abc.ABC):
"""是否是 At 机器人或者带有唤醒词或者是私聊(插件注册的事件监听器会让 is_wake 设为 True, 但是不会让这个属性置为 True)"""
self._extras = {}
self.session = MessageSesion(
platform_name=platform_meta.name,
platform_name=platform_meta.id,
message_type=message_obj.type,
session_id=session_id,
)
@@ -83,9 +93,16 @@ class AstrMessageEvent(abc.ABC):
self.platform = platform_meta
def get_platform_name(self):
"""获取这个事件所属的平台的类型(如 aiocqhttp, slack, discord 等)。
NOTE: 用户可能会同时运行多个相同类型的平台适配器。"""
return self.platform_meta.name
def get_platform_id(self):
"""获取这个事件所属的平台的 ID。
NOTE: 用户可能会同时运行多个相同类型的平台适配器,但能确定的是 ID 是唯一的。
"""
return self.platform_meta.id
def get_message_str(self) -> str:

View File

@@ -4,7 +4,7 @@ from dataclasses import dataclass
@dataclass
class PlatformMetadata:
name: str
"""平台的名称"""
"""平台的名称,即平台的类型,如 aiocqhttp, discord, slack"""
description: str
"""平台的描述"""
id: str = None

View File

@@ -2,7 +2,12 @@ from asyncio import Queue
from typing import List, Union
from astrbot.core import sp
from astrbot.core.provider.provider import Provider, TTSProvider, STTProvider, EmbeddingProvider
from astrbot.core.provider.provider import (
Provider,
TTSProvider,
STTProvider,
EmbeddingProvider,
)
from astrbot.core.provider.entities import ProviderType
from astrbot.core.db import BaseDatabase
from astrbot.core.config.astrbot_config import AstrBotConfig
@@ -22,6 +27,7 @@ from astrbot.core.star.filter.platform_adapter_type import (
PlatformAdapterType,
ADAPTER_NAME_2_TYPE,
)
from deprecated import deprecated
class Context:
@@ -201,9 +207,12 @@ class Context:
"""
return self._event_queue
def get_platform(self, platform_type: Union[PlatformAdapterType, str]) -> Platform:
@deprecated(version="4.0.0", reason="Use get_platform_inst instead")
def get_platform(self, platform_type: Union[PlatformAdapterType, str]) -> Platform | None:
"""
获取指定类型的平台适配器。
该方法已经过时,请使用 get_platform_inst 方法。(>= AstrBot v4.0.0)
"""
for platform in self.platform_manager.platform_insts:
name = platform.meta().name
@@ -217,6 +226,20 @@ class Context:
):
return platform
def get_platform_inst(self, platform_id: str) -> Platform | None:
"""
获取指定 ID 的平台适配器实例。
Args:
platform_id (str): 平台适配器的唯一标识符。你可以通过 event.get_platform_id() 获取。
Returns:
Platform: 平台适配器实例,如果未找到则返回 None。
"""
for platform in self.platform_manager.platform_insts:
if platform.meta().id == platform_id:
return platform
async def send_message(
self, session: Union[str, MessageSesion], message_chain: MessageChain
) -> bool:
@@ -240,7 +263,7 @@ class Context:
raise ValueError("不合法的 session 字符串: " + str(e))
for platform in self.platform_manager.platform_insts:
if platform.meta().name == session.platform_name:
if platform.meta().id == session.platform_name:
await platform.send_by_session(session, message_chain)
return True
return False

View File

@@ -820,7 +820,7 @@ UID: {user_id} 此 ID 可用于设置管理员。
if sid:
session = str(
MessageSesion(
platform_name=message.platform_meta.name,
platform_name=message.platform_meta.id,
message_type=MessageType("GroupMessage"),
session_id=sid,
)