Files
AstrBot/astrbot/core/provider/provider.py
2025-07-13 17:27:40 +08:00

209 lines
6.8 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import abc
from typing import List
from typing import TypedDict, AsyncGenerator
from astrbot.core.provider.func_tool_manager import FuncCall
from astrbot.core.provider.entities import LLMResponse, ToolCallsResult, ProviderType
from astrbot.core.provider.register import provider_cls_map
from dataclasses import dataclass
class Personality(TypedDict):
    """A persona definition, represented at runtime as a plain dict.

    Fields intentionally carry no default values: TypedDict never applies
    class-body defaults to instances, so ``Personality(prompt="x")`` is just
    ``{"prompt": "x"}``. Callers must supply every key they rely on.
    """

    prompt: str
    name: str
    begin_dialogs: List[str]
    mood_imitation_dialogs: List[str]
    # Cached, pre-processed forms of the dialog fields above. Presumably
    # populated by whoever prepares the persona; confirm against callers.
    _begin_dialogs_processed: List[dict]
    _mood_imitation_dialogs_processed: str
@dataclass
class ProviderMeta:
    """Descriptive metadata for a single provider instance."""

    # Provider instance ID, taken from provider_config["id"].
    id: str
    # Name of the model currently in use.
    model: str
    # Provider adapter type name, taken from provider_config["type"].
    type: str
    # Provider category; None when the type name is not registered in
    # provider_cls_map, so the annotation must admit None.
    provider_type: ProviderType | None
class AbstractProvider(abc.ABC):
    """Shared base for every provider kind.

    Holds the raw provider configuration and the name of the active model.
    """

    def __init__(self, provider_config: dict) -> None:
        super().__init__()
        self.model_name = ""
        self.provider_config = provider_config

    def set_model(self, model_name: str):
        """Record ``model_name`` as the model currently in use."""
        self.model_name = model_name

    def get_model(self) -> str:
        """Return the name of the model currently in use."""
        return self.model_name

    def meta(self) -> ProviderMeta:
        """Build a ProviderMeta snapshot describing this provider.

        ``provider_type`` is None when the configured type name is not found
        in ``provider_cls_map``.
        """
        config = self.provider_config
        type_name = config["type"]
        registered = provider_cls_map.get(type_name)
        if registered:
            category = registered.provider_type
        else:
            category = None
        return ProviderMeta(
            id=config["id"],
            model=self.get_model(),
            type=type_name,
            provider_type=category,
        )
class Provider(AbstractProvider):
    """Base class for chat (LLM) providers.

    Concrete providers must implement ``get_current_key``, ``set_key``,
    ``get_models`` and ``text_chat``. Overriding ``text_chat_stream`` is
    optional: the default yields a single non-streamed ``text_chat`` result.
    """

    def __init__(
        self,
        provider_config: dict,
        provider_settings: dict,
        default_persona: Personality | None = None,
    ) -> None:
        super().__init__(provider_config)
        self.provider_settings = provider_settings
        # The persona currently in use; may be None.
        self.curr_personality = default_persona

    @abc.abstractmethod
    def get_current_key(self) -> str:
        """Return the API key currently in use."""
        raise NotImplementedError()

    def get_keys(self) -> List[str]:
        """Return the configured API keys (``provider_config["key"]``), or []."""
        return self.provider_config.get("key", [])

    @abc.abstractmethod
    def set_key(self, key: str):
        """Set the API key to use."""
        raise NotImplementedError()

    @abc.abstractmethod
    def get_models(self) -> List[str]:
        """Return the list of supported model names."""
        raise NotImplementedError()

    @abc.abstractmethod
    async def text_chat(
        self,
        prompt: str,
        session_id: str | None = None,
        image_urls: list[str] | None = None,
        func_tool: FuncCall | None = None,
        contexts: list | None = None,
        system_prompt: str | None = None,
        tool_calls_result: ToolCallsResult | list[ToolCallsResult] | None = None,
        model: str | None = None,
        **kwargs,
    ) -> LLMResponse:
        """Get a text chat completion from the LLM using the current model.

        Args:
            prompt: The prompt text.
            session_id: Session ID (deprecated).
            image_urls: List of image URLs to attach.
            func_tool: Function-calling tools.
            contexts: Conversation context.
            system_prompt: System prompt.
            tool_calls_result: Tool-call results to send back to the LLM. See
                https://platform.openai.com/docs/guides/function-calling
            model: Model name for this call; presumably the current model is
                used when None (TODO confirm in implementations).
            kwargs: Additional parameters.

        Notes:
            - If image_urls is given, the images are attached to the
              conversation; an error is raised if the model does not support
              image input.
            - If func_tool is given, it is used for function calling; an error
              is raised if the model does not support function calling.
        """
        ...

    async def text_chat_stream(
        self,
        prompt: str,
        session_id: str | None = None,
        image_urls: list[str] | None = None,
        func_tool: FuncCall | None = None,
        contexts: list | None = None,
        system_prompt: str | None = None,
        tool_calls_result: ToolCallsResult | list[ToolCallsResult] | None = None,
        model: str | None = None,
        **kwargs,
    ) -> AsyncGenerator[LLMResponse, None]:
        """Get a streamed text chat completion using the current model.

        The final item yielded is the complete result.

        Args:
            prompt: The prompt text.
            session_id: Session ID (deprecated).
            image_urls: List of image URLs to attach.
            func_tool: Function-calling tools.
            contexts: Conversation context.
            system_prompt: System prompt.
            tool_calls_result: Tool-call results to send back to the LLM. See
                https://platform.openai.com/docs/guides/function-calling
            model: Model name for this call.
            kwargs: Additional parameters.

        Notes:
            - If image_urls is given, the images are attached to the
              conversation; an error is raised if the model does not support
              image input.
            - If func_tool is given, it is used for function calling; an error
              is raised if the model does not support function calling.
            - Providers that support streaming should override this. The
              default does not stream: it awaits ``text_chat`` with the same
              arguments and yields that complete result once.
        """
        # The ``yield`` makes this an async generator, so ``async for`` works
        # even for providers that do not implement real streaming.
        yield await self.text_chat(
            prompt=prompt,
            session_id=session_id,
            image_urls=image_urls,
            func_tool=func_tool,
            contexts=contexts,
            system_prompt=system_prompt,
            tool_calls_result=tool_calls_result,
            model=model,
            **kwargs,
        )

    async def pop_record(self, context: List):
        """Remove the earliest non-system records from ``context`` in place.

        Scanning from the start, up to two records whose ``role`` is not
        ``"system"`` are removed (presumably one user/assistant exchange;
        confirm against callers). System records are always kept.
        """
        removable: list[int] = []
        for index, record in enumerate(context):
            if record["role"] == "system":
                continue
            removable.append(index)
            if len(removable) == 2:
                break
        # Pop from the highest index down so the remaining indices stay valid.
        for index in reversed(removable):
            context.pop(index)
class STTProvider(AbstractProvider):
    """Base class for speech-to-text providers."""

    def __init__(self, provider_config: dict, provider_settings: dict) -> None:
        super().__init__(provider_config)
        self.provider_config, self.provider_settings = (
            provider_config,
            provider_settings,
        )

    @abc.abstractmethod
    async def get_text(self, audio_url: str) -> str:
        """Return the text transcribed from the given audio."""
        raise NotImplementedError
class TTSProvider(AbstractProvider):
    """Base class for text-to-speech providers."""

    def __init__(self, provider_config: dict, provider_settings: dict) -> None:
        super().__init__(provider_config)
        self.provider_config, self.provider_settings = (
            provider_config,
            provider_settings,
        )

    @abc.abstractmethod
    async def get_audio(self, text: str) -> str:
        """Synthesize audio for ``text`` and return the audio file path."""
        raise NotImplementedError
class EmbeddingProvider(AbstractProvider):
    """Base class for text-embedding providers."""

    def __init__(self, provider_config: dict, provider_settings: dict) -> None:
        super().__init__(provider_config)
        self.provider_config, self.provider_settings = (
            provider_config,
            provider_settings,
        )

    @abc.abstractmethod
    async def get_embedding(self, text: str) -> list[float]:
        """Return the embedding vector for a single text."""
        pass

    @abc.abstractmethod
    async def get_embeddings(self, text: list[str]) -> list[list[float]]:
        """Return embedding vectors for a batch of texts."""
        pass

    @abc.abstractmethod
    def get_dim(self) -> int:
        """Return the dimensionality of the embedding vectors."""
        pass