refactor: update ToolSet initialization to use Pydantic Field and clean up deprecated methods in Context

This commit is contained in:
Soulter
2025-11-16 12:13:11 +08:00
parent 388c1ab16d
commit fd9cb703db
2 changed files with 46 additions and 46 deletions

View File

@@ -4,7 +4,7 @@ from typing import Any, Generic
import jsonschema
import mcp
from deprecated import deprecated
from pydantic import model_validator
from pydantic import Field, model_validator
from pydantic.dataclasses import dataclass
from .run_context import ContextWrapper, TContext
@@ -63,6 +63,7 @@ class FunctionTool(ToolSchema, Generic[TContext]):
)
@dataclass
class ToolSet:
"""A set of function tools that can be used in function calling.
@@ -70,8 +71,7 @@ class ToolSet:
convert the tools to different API formats (OpenAI, Anthropic, Google GenAI).
"""
def __init__(self, tools: list[FunctionTool] | None = None):
self.tools: list[FunctionTool] = tools or []
tools: list[FunctionTool] = Field(default_factory=list)
def empty(self) -> bool:
"""Check if the tool set is empty."""

View File

@@ -259,10 +259,6 @@ class Context:
"""
return self.provider_manager.llm_tools.deactivate_llm_tool(name)
def register_provider(self, provider: Provider):
"""注册一个 LLM Provider(Chat_Completion 类型)。"""
self.provider_manager.provider_insts.append(provider)
def get_provider_by_id(
self,
provider_id: str,
@@ -341,45 +337,6 @@ class Context:
return self._config
return self.astrbot_config_mgr.get_conf(umo)
def get_db(self) -> BaseDatabase:
"""获取 AstrBot 数据库。"""
return self._db
def get_event_queue(self) -> Queue:
"""获取事件队列。"""
return self._event_queue
@deprecated(version="4.0.0", reason="Use get_platform_inst instead")
def get_platform(self, platform_type: PlatformAdapterType | str) -> Platform | None:
"""获取指定类型的平台适配器。
该方法已经过时,请使用 get_platform_inst 方法。(>= AstrBot v4.0.0)
"""
for platform in self.platform_manager.platform_insts:
name = platform.meta().name
if isinstance(platform_type, str):
if name == platform_type:
return platform
elif (
name in ADAPTER_NAME_2_TYPE
and ADAPTER_NAME_2_TYPE[name] & platform_type
):
return platform
def get_platform_inst(self, platform_id: str) -> Platform | None:
"""获取指定 ID 的平台适配器实例。
Args:
platform_id (str): 平台适配器的唯一标识符。你可以通过 event.get_platform_id() 获取。
Returns:
Platform: 平台适配器实例,如果未找到则返回 None。
"""
for platform in self.platform_manager.platform_insts:
if platform.meta().id == platform_id:
return platform
async def send_message(
self,
session: str | MessageSesion,
@@ -452,6 +409,49 @@ class Context:
以下的方法已经不推荐使用。请从 AstrBot 文档查看更好的注册方式。
"""
def get_event_queue(self) -> Queue:
"""获取事件队列。"""
return self._event_queue
@deprecated(version="4.0.0", reason="Use get_platform_inst instead")
def get_platform(self, platform_type: PlatformAdapterType | str) -> Platform | None:
"""获取指定类型的平台适配器。
该方法已经过时,请使用 get_platform_inst 方法。(>= AstrBot v4.0.0)
"""
for platform in self.platform_manager.platform_insts:
name = platform.meta().name
if isinstance(platform_type, str):
if name == platform_type:
return platform
elif (
name in ADAPTER_NAME_2_TYPE
and ADAPTER_NAME_2_TYPE[name] & platform_type
):
return platform
def get_platform_inst(self, platform_id: str) -> Platform | None:
"""获取指定 ID 的平台适配器实例。
Args:
platform_id (str): 平台适配器的唯一标识符。你可以通过 event.get_platform_id() 获取。
Returns:
Platform: 平台适配器实例,如果未找到则返回 None。
"""
for platform in self.platform_manager.platform_insts:
if platform.meta().id == platform_id:
return platform
def get_db(self) -> BaseDatabase:
"""获取 AstrBot 数据库。"""
return self._db
def register_provider(self, provider: Provider):
"""注册一个 LLM Provider(Chat_Completion 类型)。"""
self.provider_manager.provider_insts.append(provider)
def register_llm_tool(
self,
name: str,