import logging from asyncio import Queue from collections.abc import Awaitable, Callable from typing import Any from deprecated import deprecated from astrbot.core.agent.hooks import BaseAgentRunHooks from astrbot.core.agent.message import Message from astrbot.core.agent.runners.tool_loop_agent_runner import ToolLoopAgentRunner from astrbot.core.agent.tool import ToolSet from astrbot.core.astrbot_config_mgr import AstrBotConfigManager from astrbot.core.config.astrbot_config import AstrBotConfig from astrbot.core.conversation_mgr import ConversationManager from astrbot.core.db import BaseDatabase from astrbot.core.knowledge_base.kb_mgr import KnowledgeBaseManager from astrbot.core.message.message_event_result import MessageChain from astrbot.core.persona_mgr import PersonaManager from astrbot.core.platform import Platform from astrbot.core.platform.astr_message_event import AstrMessageEvent, MessageSesion from astrbot.core.platform.manager import PlatformManager from astrbot.core.platform_message_history_mgr import PlatformMessageHistoryManager from astrbot.core.provider.entities import LLMResponse, ProviderRequest, ProviderType from astrbot.core.provider.func_tool_manager import FunctionTool, FunctionToolManager from astrbot.core.provider.manager import ProviderManager from astrbot.core.provider.provider import ( EmbeddingProvider, Provider, RerankProvider, STTProvider, TTSProvider, ) from astrbot.core.star.filter.platform_adapter_type import ( ADAPTER_NAME_2_TYPE, PlatformAdapterType, ) from ..exceptions import ProviderNotFoundError from .filter.command import CommandFilter from .filter.regex import RegexFilter from .star import StarMetadata, star_map, star_registry from .star_handler import EventType, StarHandlerMetadata, star_handlers_registry logger = logging.getLogger("astrbot") class Context: """暴露给插件的接口上下文。""" registered_web_apis: list = [] # back compatibility _register_tasks: list[Awaitable] = [] _star_manager = None def __init__( self, event_queue: Queue, config: AstrBotConfig, db: BaseDatabase, provider_manager: ProviderManager, platform_manager: PlatformManager, conversation_manager: ConversationManager, message_history_manager: PlatformMessageHistoryManager, persona_manager: PersonaManager, astrbot_config_mgr: AstrBotConfigManager, knowledge_base_manager: KnowledgeBaseManager, ): self._event_queue = event_queue """事件队列。消息平台通过事件队列传递消息事件。""" self._config = config """AstrBot 默认配置""" self._db = db """AstrBot 数据库""" self.provider_manager = provider_manager self.platform_manager = platform_manager self.conversation_manager = conversation_manager self.message_history_manager = message_history_manager self.persona_manager = persona_manager self.astrbot_config_mgr = astrbot_config_mgr self.kb_manager = knowledge_base_manager async def llm_generate( self, *, chat_provider_id: str, prompt: str | None = None, image_urls: list[str] | None = None, tools: ToolSet | None = None, system_prompt: str | None = None, contexts: list[Message] | None = None, **kwargs: Any, ) -> LLMResponse: """Call the LLM to generate a response. The method will not automatically execute tool calls. If you want to use tool calls, please use `tool_loop_agent()`. .. versionadded:: 4.5.7 (sdk) Args: chat_provider_id: The chat provider ID to use. prompt: The prompt to send to the LLM, if `contexts` and `prompt` are both provided, `prompt` will be appended as the last user message image_urls: List of image URLs to include in the prompt, if `contexts` and `prompt` are both provided, `image_urls` will be appended to the last user message tools: ToolSet of tools available to the LLM system_prompt: System prompt to guide the LLM's behavior, if provided, it will always insert as the first system message in the context contexts: context messages for the LLM **kwargs: Additional keyword arguments for LLM generation, OpenAI compatible Raises: ChatProviderNotFoundError: If the specified chat provider ID is not found Exception: For other errors during LLM generation """ prov = await self.provider_manager.get_provider_by_id(chat_provider_id) if not prov or not isinstance(prov, Provider): raise ProviderNotFoundError(f"Provider {chat_provider_id} not found") llm_resp = await prov.text_chat( prompt=prompt, image_urls=image_urls, func_tool=tools, contexts=contexts, system_prompt=system_prompt, **kwargs, ) return llm_resp async def tool_loop_agent( self, *, event: AstrMessageEvent, chat_provider_id: str, prompt: str | None = None, image_urls: list[str] | None = None, tools: ToolSet | None = None, system_prompt: str | None = None, contexts: list[Message] | None = None, max_steps: int = 30, tool_call_timeout: int = 60, **kwargs: Any, ) -> LLMResponse: """Run an agent loop that allows the LLM to call tools iteratively until a final answer is produced. If you do not pass the agent_context parameter, the method will recreate a new agent context. .. versionadded:: 4.5.7 (sdk) Args: chat_provider_id: The chat provider ID to use. prompt: The prompt to send to the LLM, if `contexts` and `prompt` are both provided, `prompt` will be appended as the last user message image_urls: List of image URLs to include in the prompt, if `contexts` and `prompt` are both provided, `image_urls` will be appended to the last user message tools: ToolSet of tools available to the LLM system_prompt: System prompt to guide the LLM's behavior, if provided, it will always insert as the first system message in the context contexts: context messages for the LLM max_steps: Maximum number of tool calls before stopping the loop **kwargs: Additional keyword arguments. The kwargs will not be passed to the LLM directly for now, but can include: agent_hooks: BaseAgentRunHooks[AstrAgentContext] - hooks to run during agent execution agent_context: AstrAgentContext - context to use for the agent Returns: The final LLMResponse after tool calls are completed. Raises: ChatProviderNotFoundError: If the specified chat provider ID is not found Exception: For other errors during LLM generation """ # Import here to avoid circular imports from astrbot.core.astr_agent_context import ( AgentContextWrapper, AstrAgentContext, ) from astrbot.core.astr_agent_tool_exec import FunctionToolExecutor prov = await self.provider_manager.get_provider_by_id(chat_provider_id) if not prov or not isinstance(prov, Provider): raise ProviderNotFoundError(f"Provider {chat_provider_id} not found") agent_hooks = kwargs.get("agent_hooks") or BaseAgentRunHooks[AstrAgentContext]() agent_context = kwargs.get("agent_context") context_ = [] for msg in contexts or []: if isinstance(msg, Message): context_.append(msg.model_dump()) else: context_.append(msg) request = ProviderRequest( prompt=prompt, image_urls=image_urls or [], func_tool=tools, contexts=context_, system_prompt=system_prompt or "", ) if agent_context is None: agent_context = AstrAgentContext( context=self, event=event, ) agent_runner = ToolLoopAgentRunner() tool_executor = FunctionToolExecutor() await agent_runner.reset( provider=prov, request=request, run_context=AgentContextWrapper( context=agent_context, tool_call_timeout=tool_call_timeout, ), tool_executor=tool_executor, agent_hooks=agent_hooks, streaming=kwargs.get("stream", False), ) async for _ in agent_runner.step_until_done(max_steps): pass llm_resp = agent_runner.get_final_llm_resp() if not llm_resp: raise Exception("Agent did not produce a final LLM response") return llm_resp async def get_current_chat_provider_id(self, umo: str) -> str: """Get the ID of the currently used chat provider. Args: umo(str): unified_message_origin value, if provided and user has enabled provider session isolation, the provider preferred by that session will be used. Raises: ProviderNotFoundError: If the specified chat provider is not found """ prov = self.get_using_provider(umo) if not prov: raise ProviderNotFoundError("Provider not found") return prov.meta().id def get_registered_star(self, star_name: str) -> StarMetadata | None: """根据插件名获取插件的 Metadata""" for star in star_registry: if star.name == star_name: return star def get_all_stars(self) -> list[StarMetadata]: """获取当前载入的所有插件 Metadata 的列表""" return star_registry def get_llm_tool_manager(self) -> FunctionToolManager: """获取 LLM Tool Manager,其用于管理注册的所有的 Function-calling tools""" return self.provider_manager.llm_tools def activate_llm_tool(self, name: str) -> bool: """激活一个已经注册的函数调用工具。注册的工具默认是激活状态。 Returns: 如果没找到,会返回 False """ return self.provider_manager.llm_tools.activate_llm_tool(name, star_map) def deactivate_llm_tool(self, name: str) -> bool: """停用一个已经注册的函数调用工具。 Returns: 如果没找到,会返回 False """ return self.provider_manager.llm_tools.deactivate_llm_tool(name) def get_provider_by_id( self, provider_id: str, ) -> ( Provider | TTSProvider | STTProvider | EmbeddingProvider | RerankProvider | None ): """通过 ID 获取对应的 LLM Provider。""" prov = self.provider_manager.inst_map.get(provider_id) return prov def get_all_providers(self) -> list[Provider]: """获取所有用于文本生成任务的 LLM Provider(Chat_Completion 类型)。""" return self.provider_manager.provider_insts def get_all_tts_providers(self) -> list[TTSProvider]: """获取所有用于 TTS 任务的 Provider。""" return self.provider_manager.tts_provider_insts def get_all_stt_providers(self) -> list[STTProvider]: """获取所有用于 STT 任务的 Provider。""" return self.provider_manager.stt_provider_insts def get_all_embedding_providers(self) -> list[EmbeddingProvider]: """获取所有用于 Embedding 任务的 Provider。""" return self.provider_manager.embedding_provider_insts def get_using_provider(self, umo: str | None = None) -> Provider | None: """获取当前使用的用于文本生成任务的 LLM Provider(Chat_Completion 类型)。通过 /provider 指令切换。 Args: umo(str): unified_message_origin 值,如果传入并且用户启用了提供商会话隔离,则使用该会话偏好的提供商。 """ prov = self.provider_manager.get_using_provider( provider_type=ProviderType.CHAT_COMPLETION, umo=umo, ) if prov and not isinstance(prov, Provider): raise ValueError("返回的 Provider 不是 Provider 类型") return prov def get_using_tts_provider(self, umo: str | None = None) -> TTSProvider | None: """获取当前使用的用于 TTS 任务的 Provider。 Args: umo(str): unified_message_origin 值,如果传入,则使用该会话偏好的提供商。 """ prov = self.provider_manager.get_using_provider( provider_type=ProviderType.TEXT_TO_SPEECH, umo=umo, ) if prov and not isinstance(prov, TTSProvider): raise ValueError("返回的 Provider 不是 TTSProvider 类型") return prov def get_using_stt_provider(self, umo: str | None = None) -> STTProvider | None: """获取当前使用的用于 STT 任务的 Provider。 Args: umo(str): unified_message_origin 值,如果传入,则使用该会话偏好的提供商。 """ prov = self.provider_manager.get_using_provider( provider_type=ProviderType.SPEECH_TO_TEXT, umo=umo, ) if prov and not isinstance(prov, STTProvider): raise ValueError("返回的 Provider 不是 STTProvider 类型") return prov def get_config(self, umo: str | None = None) -> AstrBotConfig: """获取 AstrBot 的配置。""" if not umo: # using default config return self._config return self.astrbot_config_mgr.get_conf(umo) async def send_message( self, session: str | MessageSesion, message_chain: MessageChain, ) -> bool: """根据 session(unified_msg_origin) 主动发送消息。 @param session: 消息会话。通过 event.session 或者 event.unified_msg_origin 获取。 @param message_chain: 消息链。 @return: 是否找到匹配的平台。 当 session 为字符串时,会尝试解析为 MessageSesion 对象,如果解析失败,会抛出 ValueError 异常。 NOTE: qq_official(QQ 官方 API 平台) 不支持此方法 """ if isinstance(session, str): try: session = MessageSesion.from_str(session) except BaseException as e: raise ValueError("不合法的 session 字符串: " + str(e)) for platform in self.platform_manager.platform_insts: if platform.meta().id == session.platform_name: await platform.send_by_session(session, message_chain) return True return False def add_llm_tools(self, *tools: FunctionTool) -> None: """添加 LLM 工具。""" tool_name = {tool.name for tool in self.provider_manager.llm_tools.func_list} module_path = "" for tool in tools: if not module_path: _parts = [] module_part = tool.__module__.split(".") flags = ["packages", "plugins"] for i, part in enumerate(module_part): _parts.append(part) if part in flags and i + 1 < len(module_part): _parts.append(module_part[i + 1]) break tool.handler_module_path = ".".join(_parts) module_path = tool.handler_module_path else: tool.handler_module_path = module_path logger.info( f"plugin(module_path {module_path}) added LLM tool: {tool.name}" ) if tool.name in tool_name: logger.warning("替换已存在的 LLM 工具: " + tool.name) self.provider_manager.llm_tools.remove_func(tool.name) self.provider_manager.llm_tools.func_list.append(tool) def register_web_api( self, route: str, view_handler: Awaitable, methods: list, desc: str, ): for idx, api in enumerate(self.registered_web_apis): if api[0] == route and methods == api[2]: self.registered_web_apis[idx] = (route, view_handler, methods, desc) return self.registered_web_apis.append((route, view_handler, methods, desc)) """ 以下的方法已经不推荐使用。请从 AstrBot 文档查看更好的注册方式。 """ def get_event_queue(self) -> Queue: """获取事件队列。""" return self._event_queue @deprecated(version="4.0.0", reason="Use get_platform_inst instead") def get_platform(self, platform_type: PlatformAdapterType | str) -> Platform | None: """获取指定类型的平台适配器。 该方法已经过时,请使用 get_platform_inst 方法。(>= AstrBot v4.0.0) """ for platform in self.platform_manager.platform_insts: name = platform.meta().name if isinstance(platform_type, str): if name == platform_type: return platform elif ( name in ADAPTER_NAME_2_TYPE and ADAPTER_NAME_2_TYPE[name] & platform_type ): return platform def get_platform_inst(self, platform_id: str) -> Platform | None: """获取指定 ID 的平台适配器实例。 Args: platform_id (str): 平台适配器的唯一标识符。你可以通过 event.get_platform_id() 获取。 Returns: Platform: 平台适配器实例,如果未找到则返回 None。 """ for platform in self.platform_manager.platform_insts: if platform.meta().id == platform_id: return platform def get_db(self) -> BaseDatabase: """获取 AstrBot 数据库。""" return self._db def register_provider(self, provider: Provider): """注册一个 LLM Provider(Chat_Completion 类型)。""" self.provider_manager.provider_insts.append(provider) def register_llm_tool( self, name: str, func_args: list, desc: str, func_obj: Callable[..., Awaitable[Any]], ) -> None: """[DEPRECATED]为函数调用(function-calling / tools-use)添加工具。 @param name: 函数名 @param func_args: 函数参数列表,格式为 [{"type": "string", "name": "arg_name", "description": "arg_description"}, ...] @param desc: 函数描述 @param func_obj: 异步处理函数。 异步处理函数会接收到额外的的关键词参数:event: AstrMessageEvent, context: Context。 """ md = StarHandlerMetadata( event_type=EventType.OnLLMRequestEvent, handler_full_name=func_obj.__module__ + "_" + func_obj.__name__, handler_name=func_obj.__name__, handler_module_path=func_obj.__module__, handler=func_obj, event_filters=[], desc=desc, ) star_handlers_registry.append(md) self.provider_manager.llm_tools.add_func(name, func_args, desc, func_obj) def unregister_llm_tool(self, name: str) -> None: """[DEPRECATED]删除一个函数调用工具。如果再要启用,需要重新注册。""" self.provider_manager.llm_tools.remove_func(name) def register_commands( self, star_name: str, command_name: str, desc: str, priority: int, awaitable: Callable[..., Awaitable[Any]], use_regex=False, ignore_prefix=False, ): """注册一个命令。 [Deprecated] 推荐使用装饰器注册指令。该方法将在未来的版本中被移除。 @param star_name: 插件(Star)名称。 @param command_name: 命令名称。 @param desc: 命令描述。 @param priority: 优先级。1-10。 @param awaitable: 异步处理函数。 """ md = StarHandlerMetadata( event_type=EventType.AdapterMessageEvent, handler_full_name=awaitable.__module__ + "_" + awaitable.__name__, handler_name=awaitable.__name__, handler_module_path=awaitable.__module__, handler=awaitable, event_filters=[], desc=desc, ) if use_regex: md.event_filters.append(RegexFilter(regex=command_name)) else: md.event_filters.append( CommandFilter(command_name=command_name, handler_md=md), ) star_handlers_registry.append(md) def register_task(self, task: Awaitable, desc: str): """[DEPRECATED]注册一个异步任务。""" self._register_tasks.append(task)