* feat: add support for gemini-3 series thought signatures * feat: refactor tools_call_extra_content to use a dictionary for better structure
import base64
import enum
import json
from dataclasses import dataclass, field
from typing import Any

from anthropic.types import Message as AnthropicMessage
from google.genai.types import GenerateContentResponse
from openai.types.chat.chat_completion import ChatCompletion

import astrbot.core.message.components as Comp
from astrbot import logger
from astrbot.core.agent.message import (
    AssistantMessageSegment,
    ToolCall,
    ToolCallMessageSegment,
)
from astrbot.core.agent.tool import ToolSet
from astrbot.core.db.po import Conversation
from astrbot.core.message.message_event_result import MessageChain
from astrbot.core.utils.io import download_image_by_url


class ProviderType(enum.Enum):
    CHAT_COMPLETION = "chat_completion"
    SPEECH_TO_TEXT = "speech_to_text"
    TEXT_TO_SPEECH = "text_to_speech"
    EMBEDDING = "embedding"
    RERANK = "rerank"

@dataclass
class ProviderMeta:
    """The basic metadata of a provider instance."""

    id: str
    """the unique id of the provider instance that user configured"""
    model: str | None
    """the model name of the provider instance currently used"""
    type: str
    """the name of the provider adapter, such as openai, ollama"""
    provider_type: ProviderType = ProviderType.CHAT_COMPLETION
    """the capability type of the provider adapter"""


@dataclass
class ProviderMetaData(ProviderMeta):
    """The metadata of a provider adapter for registration."""

    desc: str = ""
    """the short description of the provider adapter"""
    cls_type: Any = None
    """the class type of the provider adapter"""
    default_config_tmpl: dict | None = None
    """the default configuration template of the provider adapter"""
    provider_display_name: str | None = None
    """the display name of the provider shown in the WebUI configuration page; if empty, the type is used"""

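# Illustrative sketch (not part of this module): how a provider adapter might be
# described with ProviderMetaData. All field values below are made up for
# demonstration; real registrations also supply the adapter class via cls_type.
#
#   meta = ProviderMetaData(
#       id="openai_default",
#       model="gpt-4o-mini",
#       type="openai",
#       provider_type=ProviderType.CHAT_COMPLETION,
#       desc="OpenAI-compatible chat completion adapter",
#       provider_display_name="OpenAI",
#   )
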
@dataclass
class ToolCallsResult:
    """The result of a round of tool calls."""

    tool_calls_info: AssistantMessageSegment
    """Information about the function calls."""
    tool_calls_result: list[ToolCallMessageSegment]
    """The results of the function calls."""

    def to_openai_messages(self) -> list[dict]:
        ret = [
            self.tool_calls_info.model_dump(),
            *[item.model_dump() for item in self.tool_calls_result],
        ]
        return ret

    def to_openai_messages_model(
        self,
    ) -> list[AssistantMessageSegment | ToolCallMessageSegment]:
        return [
            self.tool_calls_info,
            *self.tool_calls_result,
        ]

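# Illustrative sketch (not part of this module): the rough shape of the list
# returned by ToolCallsResult.to_openai_messages(). The exact keys come from
# AssistantMessageSegment.model_dump() / ToolCallMessageSegment.model_dump();
# the values below are made-up examples of the usual OpenAI tool-calling round trip:
#
#   [
#       {"role": "assistant", "tool_calls": [{"id": "call_1", "type": "function",
#           "function": {"name": "get_weather", "arguments": "{\"city\": \"Tokyo\"}"}}]},
#       {"role": "tool", "tool_call_id": "call_1", "content": "Sunny, 25 degrees"},
#   ]
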
@dataclass
class ProviderRequest:
    prompt: str | None = None
    """The prompt text."""
    session_id: str | None = ""
    """The session ID."""
    image_urls: list[str] = field(default_factory=list)
    """A list of image URLs."""
    func_tool: ToolSet | None = None
    """The available function tools."""
    contexts: list[dict] = field(default_factory=list)
    """
    The context list in OpenAI format.
    See https://platform.openai.com/docs/api-reference/chat/create#chat-create-messages
    """
    system_prompt: str = ""
    """The system prompt."""
    conversation: Conversation | None = None
    """The associated conversation object."""
    tool_calls_result: list[ToolCallsResult] | ToolCallsResult | None = None
    """The results of tool calls made after the previous request, attached to this one. See: https://platform.openai.com/docs/guides/function-calling#handling-function-calls"""
    model: str | None = None
    """The model name. If None, the provider's default model is used."""

    def __repr__(self):
        return (
            f"ProviderRequest(prompt={self.prompt}, session_id={self.session_id}, "
            f"image_count={len(self.image_urls or [])}, "
            f"func_tool={self.func_tool}, "
            f"contexts={self._print_friendly_context()}, "
            f"system_prompt={self.system_prompt}, "
            f"conversation_id={self.conversation.cid if self.conversation else 'N/A'})"
        )

    def __str__(self):
        return self.__repr__()

    def append_tool_calls_result(self, tool_calls_result: ToolCallsResult):
        """Append a tool call result to this request."""
        if not self.tool_calls_result:
            self.tool_calls_result = []
        if isinstance(self.tool_calls_result, ToolCallsResult):
            self.tool_calls_result = [self.tool_calls_result]
        self.tool_calls_result.append(tool_calls_result)
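
    # Note (illustrative, not part of the original module): append_tool_calls_result
    # normalizes tool_calls_result to a list before appending, so callers can treat
    # the field as a list afterwards regardless of how it was first assigned.
    # result_a / result_b below stand for hypothetical ToolCallsResult instances:
    #
    #   req = ProviderRequest(prompt="hi")
    #   req.append_tool_calls_result(result_a)  # tool_calls_result -> [result_a]
    #   req.append_tool_calls_result(result_b)  # tool_calls_result -> [result_a, result_b]
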
    def _print_friendly_context(self):
        """Build a print-friendly view of the message contexts. Image URL values are replaced with an image count placeholder."""
        if not self.contexts:
            return f"prompt: {self.prompt}, image_count: {len(self.image_urls or [])}"

        result_parts = []

        for ctx in self.contexts:
            role = ctx.get("role", "unknown")
            content = ctx.get("content", "")

            if isinstance(content, str):
                result_parts.append(f"{role}: {content}")
            elif isinstance(content, list):
                msg_parts = []
                image_count = 0

                for item in content:
                    item_type = item.get("type", "")

                    if item_type == "text":
                        msg_parts.append(item.get("text", ""))
                    elif item_type == "image_url":
                        image_count += 1

                if image_count > 0:
                    if msg_parts:
                        msg_parts.append(f"[+{image_count} images]")
                    else:
                        msg_parts.append(f"[{image_count} images]")

                result_parts.append(f"{role}: {''.join(msg_parts)}")

        return result_parts

    async def assemble_context(self) -> dict:
        """Wrap the request (prompt and image_urls) into an OpenAI-format message dict."""
        if self.image_urls:
            user_content = {
                "role": "user",
                "content": [
                    {"type": "text", "text": self.prompt if self.prompt else "[图片]"},
                ],
            }
            for image_url in self.image_urls:
                if image_url.startswith("http"):
                    image_path = await download_image_by_url(image_url)
                    image_data = await self._encode_image_bs64(image_path)
                elif image_url.startswith("file:///"):
                    image_path = image_url.replace("file:///", "")
                    image_data = await self._encode_image_bs64(image_path)
                else:
                    image_data = await self._encode_image_bs64(image_url)
                if not image_data:
                    logger.warning(f"图片 {image_url} 得到的结果为空,将忽略。")
                    continue
                user_content["content"].append(
                    {"type": "image_url", "image_url": {"url": image_data}},
                )
            return user_content
        return {"role": "user", "content": self.prompt}

    async def _encode_image_bs64(self, image_url: str) -> str:
        """Convert an image to a base64 data URL."""
        if image_url.startswith("base64://"):
            return image_url.replace("base64://", "data:image/jpeg;base64,")
        with open(image_url, "rb") as f:
            image_bs64 = base64.b64encode(f.read()).decode("utf-8")
            return "data:image/jpeg;base64," + image_bs64
        return ""

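# Illustrative sketch (not part of this module): the dict produced by
# ProviderRequest.assemble_context() for a request that carries images.
# The text and the base64 payload below are made up and truncated:
#
#   {
#       "role": "user",
#       "content": [
#           {"type": "text", "text": "What is in this picture?"},
#           {"type": "image_url", "image_url": {"url": "data:image/jpeg;base64,/9j/4AAQ..."}},
#       ],
#   }
#
# A request without image_urls collapses to {"role": "user", "content": prompt}.
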
@dataclass
class LLMResponse:
    role: str
    """The role of the message, e.g., assistant, tool, err"""
    result_chain: MessageChain | None = None
    """A chain of message components representing the text completion from LLM."""
    tools_call_args: list[dict[str, Any]] = field(default_factory=list)
    """Tool call arguments."""
    tools_call_name: list[str] = field(default_factory=list)
    """Tool call names."""
    tools_call_ids: list[str] = field(default_factory=list)
    """Tool call IDs."""
    tools_call_extra_content: dict[str, dict[str, Any]] = field(default_factory=dict)
    """Tool call extra content. tool_call_id -> extra_content dict"""
    reasoning_content: str = ""
    """The reasoning content extracted from the LLM, if any."""

    raw_completion: (
        ChatCompletion | GenerateContentResponse | AnthropicMessage | None
    ) = None
    """The raw completion response from the LLM provider."""

    _completion_text: str = ""
    """The plain text of the completion."""

    is_chunk: bool = False
    """Indicates if the response is a chunked response."""

    def __init__(
        self,
        role: str,
        completion_text: str = "",
        result_chain: MessageChain | None = None,
        tools_call_args: list[dict[str, Any]] | None = None,
        tools_call_name: list[str] | None = None,
        tools_call_ids: list[str] | None = None,
        tools_call_extra_content: dict[str, dict[str, Any]] | None = None,
        raw_completion: ChatCompletion
        | GenerateContentResponse
        | AnthropicMessage
        | None = None,
        is_chunk: bool = False,
    ):
        """Initialize an LLMResponse.

        Args:
            role (str): The role: assistant, tool, or err.
            completion_text (str, optional): The returned completion text. Deprecated; prefer result_chain. Defaults to "".
            result_chain (MessageChain, optional): The returned message chain. Defaults to None.
            tools_call_args (list[dict[str, Any]], optional): Tool call arguments. Defaults to None.
            tools_call_name (list[str], optional): Tool call names. Defaults to None.
            tools_call_ids (list[str], optional): Tool call IDs. Defaults to None.
            tools_call_extra_content (dict[str, dict[str, Any]], optional): Extra content keyed by tool_call_id, e.g. provider-specific data such as thought signatures. Defaults to None.
            raw_completion (ChatCompletion | GenerateContentResponse | AnthropicMessage, optional): The raw provider response. Defaults to None.
            is_chunk (bool, optional): Whether this response is a streamed chunk. Defaults to False.

        """
        if tools_call_args is None:
            tools_call_args = []
        if tools_call_name is None:
            tools_call_name = []
        if tools_call_ids is None:
            tools_call_ids = []
        if tools_call_extra_content is None:
            tools_call_extra_content = {}

        self.role = role
        self.completion_text = completion_text
        self.result_chain = result_chain
        self.tools_call_args = tools_call_args
        self.tools_call_name = tools_call_name
        self.tools_call_ids = tools_call_ids
        self.tools_call_extra_content = tools_call_extra_content
        self.raw_completion = raw_completion
        self.is_chunk = is_chunk

    @property
    def completion_text(self):
        if self.result_chain:
            return self.result_chain.get_plain_text()
        return self._completion_text

    @completion_text.setter
    def completion_text(self, value):
        if self.result_chain:
            self.result_chain.chain = [
                comp
                for comp in self.result_chain.chain
                if not isinstance(comp, Comp.Plain)
            ]  # drop existing Plain components
            self.result_chain.chain.insert(0, Comp.Plain(value))
        else:
            self._completion_text = value
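
    # Note (illustrative, not part of the original module): completion_text is a
    # computed view rather than a stored field. When result_chain is set, reads come
    # from result_chain.get_plain_text() and writes replace the chain's Plain
    # components; otherwise the private _completion_text string is used. For example:
    #
    #   resp = LLMResponse(role="assistant", completion_text="hello")
    #   resp.completion_text  # -> "hello" (no result_chain, so the plain-string path is used)
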
    def to_openai_tool_calls(self) -> list[dict]:
        """Convert to OpenAI tool calls format. Deprecated, use to_openai_to_calls_model instead."""
        ret = []
        for idx, tool_call_arg in enumerate(self.tools_call_args):
            payload = {
                "id": self.tools_call_ids[idx],
                "function": {
                    "name": self.tools_call_name[idx],
                    "arguments": json.dumps(tool_call_arg),
                },
                "type": "function",
            }
            if self.tools_call_extra_content.get(self.tools_call_ids[idx]):
                payload["extra_content"] = self.tools_call_extra_content[
                    self.tools_call_ids[idx]
                ]
            ret.append(payload)
        return ret

    def to_openai_to_calls_model(self) -> list[ToolCall]:
        """The same as to_openai_tool_calls, but returns pydantic models."""
        ret = []
        for idx, tool_call_arg in enumerate(self.tools_call_args):
            ret.append(
                ToolCall(
                    id=self.tools_call_ids[idx],
                    function=ToolCall.FunctionBody(
                        name=self.tools_call_name[idx],
                        arguments=json.dumps(tool_call_arg),
                    ),
                    # extra_content is not serialized if it is None when calling ToolCall.model_dump()
                    extra_content=self.tools_call_extra_content.get(
                        self.tools_call_ids[idx]
                    ),
                ),
            )
        return ret

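# Illustrative sketch (not part of this module): the shape of one entry produced by
# LLMResponse.to_openai_tool_calls() when tools_call_extra_content carries
# provider-specific data for that tool_call_id (for example a Gemini thought
# signature). The concrete values, including the extra_content layout, are made up:
#
#   {
#       "id": "call_1",
#       "type": "function",
#       "function": {"name": "get_weather", "arguments": "{\"city\": \"Tokyo\"}"},
#       "extra_content": {"google": {"thought_signature": "..."}},
#   }
#
# Entries whose id has no extra_content simply omit the "extra_content" key.
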
@dataclass
class RerankResult:
    index: int
    """The index of the document in the candidate list."""
    relevance_score: float
    """The relevance score."""