import abc from typing import List from typing import AsyncGenerator from astrbot.core.agent.tool import ToolSet from astrbot.core.provider.entities import ( LLMResponse, ToolCallsResult, ProviderType, RerankResult, ) from astrbot.core.provider.register import provider_cls_map from astrbot.core.db.po import Personality from dataclasses import dataclass @dataclass class ProviderMeta: id: str model: str type: str provider_type: ProviderType class AbstractProvider(abc.ABC): def __init__(self, provider_config: dict) -> None: super().__init__() self.model_name = "" self.provider_config = provider_config def set_model(self, model_name: str): """设置当前使用的模型名称""" self.model_name = model_name def get_model(self) -> str: """获得当前使用的模型名称""" return self.model_name def meta(self) -> ProviderMeta: """获取 Provider 的元数据""" provider_type_name = self.provider_config["type"] meta_data = provider_cls_map.get(provider_type_name) provider_type = meta_data.provider_type if meta_data else None return ProviderMeta( id=self.provider_config["id"], model=self.get_model(), type=provider_type_name, provider_type=provider_type, ) class Provider(AbstractProvider): def __init__( self, provider_config: dict, provider_settings: dict, default_persona: Personality | None = None, ) -> None: super().__init__(provider_config) self.provider_settings = provider_settings self.curr_personality = default_persona """维护了当前的使用的 persona,即人格。可能为 None""" @abc.abstractmethod def get_current_key(self) -> str: raise NotImplementedError() def get_keys(self) -> List[str]: """获得提供商 Key""" return self.provider_config.get("key", []) @abc.abstractmethod def set_key(self, key: str): raise NotImplementedError() @abc.abstractmethod def get_models(self) -> List[str]: """获得支持的模型列表""" raise NotImplementedError() @abc.abstractmethod async def text_chat( self, prompt: str, session_id: str = None, image_urls: list[str] = None, func_tool: ToolSet = None, contexts: list = None, system_prompt: str = None, tool_calls_result: ToolCallsResult | list[ToolCallsResult] = None, model: str | None = None, **kwargs, ) -> LLMResponse: """获得 LLM 的文本对话结果。会使用当前的模型进行对话。 Args: prompt: 提示词 session_id: 会话 ID(此属性已经被废弃) image_urls: 图片 URL 列表 tools: Function-calling 工具 contexts: 上下文 tool_calls_result: 回传给 LLM 的工具调用结果。参考: https://platform.openai.com/docs/guides/function-calling kwargs: 其他参数 Notes: - 如果传入了 image_urls,将会在对话时附上图片。如果模型不支持图片输入,将会抛出错误。 - 如果传入了 tools,将会使用 tools 进行 Function-calling。如果模型不支持 Function-calling,将会抛出错误。 """ ... async def text_chat_stream( self, prompt: str, session_id: str = None, image_urls: list[str] = None, func_tool: ToolSet = None, contexts: list = None, system_prompt: str = None, tool_calls_result: ToolCallsResult | list[ToolCallsResult] = None, model: str | None = None, **kwargs, ) -> AsyncGenerator[LLMResponse, None]: """获得 LLM 的流式文本对话结果。会使用当前的模型进行对话。在生成的最后会返回一次完整的结果。 Args: prompt: 提示词 session_id: 会话 ID(此属性已经被废弃) image_urls: 图片 URL 列表 tools: Function-calling 工具 contexts: 上下文 tool_calls_result: 回传给 LLM 的工具调用结果。参考: https://platform.openai.com/docs/guides/function-calling kwargs: 其他参数 Notes: - 如果传入了 image_urls,将会在对话时附上图片。如果模型不支持图片输入,将会抛出错误。 - 如果传入了 tools,将会使用 tools 进行 Function-calling。如果模型不支持 Function-calling,将会抛出错误。 """ ... async def pop_record(self, context: List): """ 弹出 context 第一条非系统提示词对话记录 """ poped = 0 indexs_to_pop = [] for idx, record in enumerate(context): if record["role"] == "system": continue else: indexs_to_pop.append(idx) poped += 1 if poped == 2: break for idx in reversed(indexs_to_pop): context.pop(idx) class STTProvider(AbstractProvider): def __init__(self, provider_config: dict, provider_settings: dict) -> None: super().__init__(provider_config) self.provider_config = provider_config self.provider_settings = provider_settings @abc.abstractmethod async def get_text(self, audio_url: str) -> str: """获取音频的文本""" raise NotImplementedError() class TTSProvider(AbstractProvider): def __init__(self, provider_config: dict, provider_settings: dict) -> None: super().__init__(provider_config) self.provider_config = provider_config self.provider_settings = provider_settings @abc.abstractmethod async def get_audio(self, text: str) -> str: """获取文本的音频,返回音频文件路径""" raise NotImplementedError() class EmbeddingProvider(AbstractProvider): def __init__(self, provider_config: dict, provider_settings: dict) -> None: super().__init__(provider_config) self.provider_config = provider_config self.provider_settings = provider_settings @abc.abstractmethod async def get_embedding(self, text: str) -> list[float]: """获取文本的向量""" ... @abc.abstractmethod async def get_embeddings(self, text: list[str]) -> list[list[float]]: """批量获取文本的向量""" ... @abc.abstractmethod def get_dim(self) -> int: """获取向量的维度""" ... class RerankProvider(AbstractProvider): def __init__(self, provider_config: dict, provider_settings: dict) -> None: super().__init__(provider_config) self.provider_config = provider_config self.provider_settings = provider_settings @abc.abstractmethod async def rerank( self, query: str, documents: list[str], top_n: int | None = None ) -> list[RerankResult]: """获取查询和文档的重排序分数""" ...