modified: astrbot/core/pipeline/process_stage/method/llm_request.py
new file: astrbot/core/star/session_llm_manager.py modified: astrbot/dashboard/routes/session_management.py modified: dashboard/src/views/SessionManagementPage.vue 增加了精确到会话的LLM启停管理以及插件启停管理
This commit is contained in:
@@ -26,6 +26,7 @@ from astrbot.core.provider.entities import (
|
||||
)
|
||||
from astrbot.core.star.star_handler import star_handlers_registry, EventType
|
||||
from astrbot.core.star.star import star_map
|
||||
from astrbot.core.star.session_llm_manager import SessionLLMManager
|
||||
from mcp.types import (
|
||||
TextContent,
|
||||
ImageContent,
|
||||
@@ -71,6 +72,12 @@ class LLMRequestSubStage(Stage):
|
||||
if not self.ctx.astrbot_config["provider_settings"]["enable"]:
|
||||
logger.debug("未启用 LLM 能力,跳过处理。")
|
||||
return
|
||||
|
||||
# 检查会话级别的LLM启停状态
|
||||
if not SessionLLMManager.should_process_llm_request(event):
|
||||
logger.debug(f"会话 {event.unified_msg_origin} 禁用了 LLM,跳过处理。")
|
||||
return
|
||||
|
||||
umo = event.unified_msg_origin
|
||||
provider = self.ctx.plugin_manager.context.get_using_provider(umo=umo)
|
||||
if provider is None:
|
||||
|
||||
@@ -0,0 +1,82 @@
|
||||
"""
|
||||
会话LLM管理器 - 负责管理每个会话的LLM启停状态
|
||||
"""
|
||||
|
||||
from astrbot.core import sp, logger
|
||||
from typing import Dict, Optional
|
||||
from astrbot.core.platform.astr_message_event import AstrMessageEvent
|
||||
|
||||
|
||||
class SessionLLMManager:
|
||||
"""管理会话级别的LLM启停状态"""
|
||||
|
||||
@staticmethod
|
||||
def is_llm_enabled_for_session(session_id: str) -> bool:
|
||||
"""检查LLM是否在指定会话中启用
|
||||
|
||||
Args:
|
||||
session_id: 会话ID (unified_msg_origin)
|
||||
|
||||
Returns:
|
||||
bool: True表示启用,False表示禁用
|
||||
"""
|
||||
# 获取会话LLM配置
|
||||
session_llm_config = sp.get("session_llm_config", {})
|
||||
session_config = session_llm_config.get(session_id, {})
|
||||
|
||||
# 如果配置了该会话的LLM状态,返回该状态
|
||||
llm_enabled = session_config.get("llm_enabled")
|
||||
if llm_enabled is not None:
|
||||
return llm_enabled
|
||||
|
||||
# 如果没有配置,默认为启用(兼容性考虑)
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def set_llm_status_for_session(session_id: str, enabled: bool) -> None:
|
||||
"""设置LLM在指定会话中的启停状态
|
||||
|
||||
Args:
|
||||
session_id: 会话ID (unified_msg_origin)
|
||||
enabled: True表示启用,False表示禁用
|
||||
"""
|
||||
# 获取当前配置
|
||||
session_llm_config = sp.get("session_llm_config", {})
|
||||
if session_id not in session_llm_config:
|
||||
session_llm_config[session_id] = {}
|
||||
|
||||
# 设置LLM状态
|
||||
session_llm_config[session_id]["llm_enabled"] = enabled
|
||||
|
||||
# 保存配置
|
||||
sp.put("session_llm_config", session_llm_config)
|
||||
|
||||
logger.info(f"会话 {session_id} 的LLM状态已更新为: {'启用' if enabled else '禁用'}")
|
||||
|
||||
@staticmethod
|
||||
def get_session_llm_config(session_id: str) -> Dict[str, bool]:
|
||||
"""获取指定会话的LLM配置
|
||||
|
||||
Args:
|
||||
session_id: 会话ID (unified_msg_origin)
|
||||
|
||||
Returns:
|
||||
Dict[str, bool]: 包含llm_enabled的字典
|
||||
"""
|
||||
session_llm_config = sp.get("session_llm_config", {})
|
||||
return session_llm_config.get(session_id, {
|
||||
"llm_enabled": True # 默认启用
|
||||
})
|
||||
|
||||
@staticmethod
|
||||
def should_process_llm_request(event: AstrMessageEvent) -> bool:
|
||||
"""检查是否应该处理LLM请求
|
||||
|
||||
Args:
|
||||
event: 消息事件
|
||||
|
||||
Returns:
|
||||
bool: True表示应该处理,False表示跳过
|
||||
"""
|
||||
session_id = event.unified_msg_origin
|
||||
return SessionLLMManager.is_llm_enabled_for_session(session_id)
|
||||
@@ -6,6 +6,7 @@ from astrbot.core.db import BaseDatabase
|
||||
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
|
||||
from astrbot.core.provider.entities import ProviderType
|
||||
from astrbot.core.star.session_plugin_manager import SessionPluginManager
|
||||
from astrbot.core.star.session_llm_manager import SessionLLMManager
|
||||
|
||||
|
||||
class SessionManagementRoute(Route):
|
||||
@@ -23,6 +24,7 @@ class SessionManagementRoute(Route):
|
||||
"/session/get_session_info": ("POST", self.get_session_info),
|
||||
"/session/plugins": ("GET", self.get_session_plugins),
|
||||
"/session/update_plugin": ("POST", self.update_session_plugin),
|
||||
"/session/update_llm": ("POST", self.update_session_llm),
|
||||
}
|
||||
self.db_helper = db_helper
|
||||
self.core_lifecycle = core_lifecycle
|
||||
@@ -47,8 +49,7 @@ class SessionManagementRoute(Route):
|
||||
provider_manager = self.core_lifecycle.star_context.provider_manager
|
||||
|
||||
sessions = []
|
||||
|
||||
# 构建会话信息
|
||||
# 构建会话信息
|
||||
for session_id, conversation_id in session_conversations.items():
|
||||
session_info = {
|
||||
"session_id": session_id,
|
||||
@@ -61,6 +62,7 @@ class SessionManagementRoute(Route):
|
||||
"stt_provider_name": None,
|
||||
"tts_provider_id": None,
|
||||
"tts_provider_name": None,
|
||||
"llm_enabled": SessionLLMManager.is_llm_enabled_for_session(session_id),
|
||||
"platform": session_id.split(":")[0] if ":" in session_id else "unknown",
|
||||
"message_type": session_id.split(":")[1] if session_id.count(":") >= 1 else "unknown",
|
||||
"session_name": session_id.split(":")[2] if session_id.count(":") >= 2 else session_id,
|
||||
@@ -260,8 +262,7 @@ class SessionManagementRoute(Route):
|
||||
|
||||
if not session_id:
|
||||
return Response().error("缺少必要参数: session_id").__dict__
|
||||
|
||||
# 获取会话对话信息
|
||||
# 获取会话对话信息
|
||||
session_conversations = sp.get("session_conversation", {})
|
||||
conversation_id = session_conversations.get(session_id)
|
||||
|
||||
@@ -279,6 +280,7 @@ class SessionManagementRoute(Route):
|
||||
"stt_provider_name": None,
|
||||
"tts_provider_id": None,
|
||||
"tts_provider_name": None,
|
||||
"llm_enabled": SessionLLMManager.is_llm_enabled_for_session(session_id),
|
||||
"platform": session_id.split(":")[0] if ":" in session_id else "unknown",
|
||||
"message_type": session_id.split(":")[1] if session_id.count(":") >= 1 else "unknown",
|
||||
"session_name": session_id.split(":")[2] if session_id.count(":") >= 2 else session_id,
|
||||
@@ -441,3 +443,30 @@ class SessionManagementRoute(Route):
|
||||
error_msg = f"更新会话插件状态失败: {str(e)}\n{traceback.format_exc()}"
|
||||
logger.error(error_msg)
|
||||
return Response().error(f"更新会话插件状态失败: {str(e)}").__dict__
|
||||
|
||||
async def update_session_llm(self):
|
||||
"""更新指定会话的LLM启停状态"""
|
||||
try:
|
||||
data = await request.get_json()
|
||||
session_id = data.get("session_id")
|
||||
enabled = data.get("enabled")
|
||||
|
||||
if not session_id:
|
||||
return Response().error("缺少必要参数: session_id").__dict__
|
||||
|
||||
if enabled is None:
|
||||
return Response().error("缺少必要参数: enabled").__dict__
|
||||
|
||||
# 使用 SessionLLMManager 更新LLM状态
|
||||
SessionLLMManager.set_llm_status_for_session(session_id, enabled)
|
||||
|
||||
return Response().ok({
|
||||
"message": f"LLM已{'启用' if enabled else '禁用'}",
|
||||
"session_id": session_id,
|
||||
"llm_enabled": enabled,
|
||||
}).__dict__
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"更新会话LLM状态失败: {str(e)}\n{traceback.format_exc()}"
|
||||
logger.error(error_msg)
|
||||
return Response().error(f"更新会话LLM状态失败: {str(e)}").__dict__
|
||||
|
||||
Reference in New Issue
Block a user