# AstrBot/astrbot/core/provider/sources/openai_source.py

import asyncio
import base64
import inspect
import json
import os
import random
import re
from collections.abc import AsyncGenerator
from openai import AsyncAzureOpenAI, AsyncOpenAI
from openai._exceptions import NotFoundError
from openai.lib.streaming.chat._completions import ChatCompletionStreamState
from openai.types.chat.chat_completion import ChatCompletion
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
import astrbot.core.message.components as Comp
from astrbot import logger
from astrbot.api.provider import Provider
from astrbot.core.agent.message import Message
from astrbot.core.agent.tool import ToolSet
from astrbot.core.message.message_event_result import MessageChain
from astrbot.core.provider.entities import LLMResponse, ToolCallsResult
from astrbot.core.utils.io import download_image_by_url
from ..register import register_provider_adapter
@register_provider_adapter(
"openai_chat_completion",
"OpenAI API Chat Completion 提供商适配器",
)
class ProviderOpenAIOfficial(Provider):
def __init__(self, provider_config, provider_settings) -> None:
super().__init__(provider_config, provider_settings)
        self.api_keys: list = super().get_keys()
        self.chosen_api_key = self.api_keys[0] if len(self.api_keys) > 0 else None
self.timeout = provider_config.get("timeout", 120)
self.custom_headers = provider_config.get("custom_headers", {})
if isinstance(self.timeout, str):
self.timeout = int(self.timeout)
if not isinstance(self.custom_headers, dict) or not self.custom_headers:
self.custom_headers = None
else:
for key in self.custom_headers:
self.custom_headers[key] = str(self.custom_headers[key])
if "api_version" in provider_config:
# Using Azure OpenAI API
self.client = AsyncAzureOpenAI(
api_key=self.chosen_api_key,
api_version=provider_config.get("api_version", None),
default_headers=self.custom_headers,
base_url=provider_config.get("api_base", ""),
timeout=self.timeout,
)
else:
# Using OpenAI Official API
self.client = AsyncOpenAI(
api_key=self.chosen_api_key,
base_url=provider_config.get("api_base", None),
default_headers=self.custom_headers,
timeout=self.timeout,
)
self.default_params = inspect.signature(
self.client.chat.completions.create,
).parameters.keys()
model_config = provider_config.get("model_config", {})
model = model_config.get("model", "unknown")
self.set_model(model)
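        # Non-standard field some providers (e.g. DeepSeek) use to carry the
        # model's reasoning trace alongside the regular completion content.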
self.reasoning_key = "reasoning_content"
def _maybe_inject_xai_search(self, payloads: dict, **kwargs):
"""当开启 xAI 原生搜索时,向请求体注入 Live Search 参数。
- 仅在 provider_config.xai_native_search 为 True 时生效
- 默认注入 {"mode": "auto"}
- 允许通过 kwargs 使用 xai_search_mode 覆盖on/auto/off
"""
if not bool(self.provider_config.get("xai_native_search", False)):
return
mode = kwargs.get("xai_search_mode", "auto")
mode = str(mode).lower()
if mode not in ("auto", "on", "off"):
mode = "auto"
        # "off" injects nothing, keeping behavior identical to the feature being disabled
if mode == "off":
return
        # Fields the OpenAI SDK does not recognize are moved into extra_body in _query/_query_stream
payloads["search_parameters"] = {"mode": mode}
async def get_models(self):
try:
models_str = []
models = await self.client.models.list()
models = sorted(models.data, key=lambda x: x.id)
for model in models:
models_str.append(model.id)
return models_str
except NotFoundError as e:
raise Exception(f"获取模型列表失败:{e}")
async def _query(self, payloads: dict, tools: ToolSet | None) -> LLMResponse:
if tools:
model = payloads.get("model", "").lower()
omit_empty_param_field = "gemini" in model
tool_list = tools.get_func_desc_openai_style(
omit_empty_parameter_field=omit_empty_param_field,
)
if tool_list:
payloads["tools"] = tool_list
        # Parameters not in the default signature are placed into extra_body
extra_body = {}
to_del = []
for key in payloads:
if key not in self.default_params:
extra_body[key] = payloads[key]
to_del.append(key)
for key in to_del:
del payloads[key]
        # Read and merge the custom_extra_body config
custom_extra_body = self.provider_config.get("custom_extra_body", {})
if isinstance(custom_extra_body, dict):
extra_body.update(custom_extra_body)
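        # Illustrative (hypothetical config values): a provider entry such as
        #     "custom_extra_body": {"repetition_penalty": 1.05}
        # would be merged here and forwarded verbatim through extra_body.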
model = payloads.get("model", "").lower()
        # Special handling for DeepSeek: deepseek-reasoner calls must not include
        # tools, otherwise the request is silently switched to deepseek-chat
if model == "deepseek-reasoner" and "tools" in payloads:
del payloads["tools"]
completion = await self.client.chat.completions.create(
**payloads,
stream=False,
extra_body=extra_body,
)
if not isinstance(completion, ChatCompletion):
            raise Exception(
                f"API returned a completion of the wrong type: {type(completion)}: {completion}",
            )
logger.debug(f"completion: {completion}")
llm_response = await self._parse_openai_completion(completion, tools)
return llm_response
async def _query_stream(
self,
payloads: dict,
tools: ToolSet | None,
) -> AsyncGenerator[LLMResponse, None]:
"""流式查询API逐步返回结果"""
if tools:
model = payloads.get("model", "").lower()
omit_empty_param_field = "gemini" in model
tool_list = tools.get_func_desc_openai_style(
omit_empty_parameter_field=omit_empty_param_field,
)
if tool_list:
payloads["tools"] = tool_list
        # Parameters not in the default signature are placed into extra_body
extra_body = {}
        # Read and merge the custom_extra_body config
custom_extra_body = self.provider_config.get("custom_extra_body", {})
if isinstance(custom_extra_body, dict):
extra_body.update(custom_extra_body)
to_del = []
for key in payloads:
if key not in self.default_params:
extra_body[key] = payloads[key]
to_del.append(key)
for key in to_del:
del payloads[key]
stream = await self.client.chat.completions.create(
**payloads,
stream=True,
extra_body=extra_body,
)
llm_response = LLMResponse("assistant", is_chunk=True)
state = ChatCompletionStreamState()
async for chunk in stream:
try:
state.handle_chunk(chunk)
except Exception as e:
logger.warning("Saving chunk state error: " + str(e))
if len(chunk.choices) == 0:
continue
delta = chunk.choices[0].delta
# logger.debug(f"chunk delta: {delta}")
# handle the content delta
reasoning = self._extract_reasoning_content(chunk)
            has_update = False
            if reasoning:
                llm_response.reasoning_content = reasoning
                has_update = True
            if delta.content:
                completion_text = delta.content
                llm_response.result_chain = MessageChain(
                    chain=[Comp.Plain(completion_text)],
                )
                has_update = True
            if has_update:
                yield llm_response
final_completion = state.get_final_completion()
llm_response = await self._parse_openai_completion(final_completion, tools)
yield llm_response
def _extract_reasoning_content(
self,
completion: ChatCompletion | ChatCompletionChunk,
) -> str:
"""Extract reasoning content from OpenAI ChatCompletion if available."""
reasoning_text = ""
if len(completion.choices) == 0:
return reasoning_text
if isinstance(completion, ChatCompletion):
choice = completion.choices[0]
reasoning_attr = getattr(choice.message, self.reasoning_key, None)
if reasoning_attr:
reasoning_text = str(reasoning_attr)
elif isinstance(completion, ChatCompletionChunk):
delta = completion.choices[0].delta
reasoning_attr = getattr(delta, self.reasoning_key, None)
if reasoning_attr:
reasoning_text = str(reasoning_attr)
return reasoning_text
async def _parse_openai_completion(
self, completion: ChatCompletion, tools: ToolSet | None
) -> LLMResponse:
"""Parse OpenAI ChatCompletion into LLMResponse"""
llm_response = LLMResponse("assistant")
if len(completion.choices) == 0:
raise Exception("API 返回的 completion 为空。")
choice = completion.choices[0]
# parse the text completion
if choice.message.content is not None:
# text completion
completion_text = str(choice.message.content).strip()
            # specially, some providers wrap reasoning content in <think> tags inside
            # the completion text; we strip them with a regex and store the extracted
            # text in the reasoning_content field
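            # Illustrative: "<think>step 1...</think>The answer is 42." yields
            # reasoning_content="step 1..." and completion_text="The answer is 42."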
reasoning_pattern = re.compile(r"<think>(.*?)</think>", re.DOTALL)
matches = reasoning_pattern.findall(completion_text)
if matches:
llm_response.reasoning_content = "\n".join(
[match.strip() for match in matches],
)
completion_text = reasoning_pattern.sub("", completion_text).strip()
llm_response.result_chain = MessageChain().message(completion_text)
        # parse the dedicated reasoning field if any;
        # it takes priority over the <think> tag extraction, but an empty result
        # must not overwrite reasoning already recovered from <think> tags
        reasoning = self._extract_reasoning_content(completion)
        if reasoning:
            llm_response.reasoning_content = reasoning
# parse tool calls if any
if choice.message.tool_calls and tools is not None:
args_ls = []
func_name_ls = []
tool_call_ids = []
tool_call_extra_content_dict = {}
for tool_call in choice.message.tool_calls:
if isinstance(tool_call, str):
# workaround for #1359
tool_call = json.loads(tool_call)
for tool in tools.func_list:
if (
tool_call.type == "function"
and tool.name == tool_call.function.name
):
# workaround for #1454
if isinstance(tool_call.function.arguments, str):
args = json.loads(tool_call.function.arguments)
else:
args = tool_call.function.arguments
args_ls.append(args)
func_name_ls.append(tool_call.function.name)
tool_call_ids.append(tool_call.id)
# gemini-2.5 / gemini-3 series extra_content handling
extra_content = getattr(tool_call, "extra_content", None)
if extra_content is not None:
tool_call_extra_content_dict[tool_call.id] = extra_content
llm_response.role = "tool"
llm_response.tools_call_args = args_ls
llm_response.tools_call_name = func_name_ls
llm_response.tools_call_ids = tool_call_ids
llm_response.tools_call_extra_content = tool_call_extra_content_dict
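        # Illustrative parsed tool call (hypothetical names):
        #     tool_call.function.name == "get_weather"
        #     tool_call.function.arguments == '{"city": "Beijing"}'  # JSON string
        # json.loads turns the arguments into {"city": "Beijing"} in tools_call_args.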
# specially handle finish reason
if choice.finish_reason == "content_filter":
raise Exception(
"API 返回的 completion 由于内容安全过滤被拒绝(非 AstrBot)。",
)
if llm_response.completion_text is None and not llm_response.tools_call_args:
logger.error(f"API 返回的 completion 无法解析:{completion}")
raise Exception(f"API 返回的 completion 无法解析:{completion}")
llm_response.raw_completion = completion
return llm_response
async def _prepare_chat_payload(
self,
prompt: str | None,
image_urls: list[str] | None = None,
contexts: list[dict] | list[Message] | None = None,
system_prompt: str | None = None,
tool_calls_result: ToolCallsResult | list[ToolCallsResult] | None = None,
model: str | None = None,
**kwargs,
) -> tuple:
"""准备聊天所需的有效载荷和上下文"""
if contexts is None:
contexts = []
new_record = None
if prompt is not None:
new_record = await self.assemble_context(prompt, image_urls)
context_query = self._ensure_message_to_dicts(contexts)
if new_record:
context_query.append(new_record)
if system_prompt:
context_query.insert(0, {"role": "system", "content": system_prompt})
for part in context_query:
if "_no_save" in part:
del part["_no_save"]
# tool calls result
if tool_calls_result:
if isinstance(tool_calls_result, ToolCallsResult):
context_query.extend(tool_calls_result.to_openai_messages())
else:
for tcr in tool_calls_result:
context_query.extend(tcr.to_openai_messages())
model_config = self.provider_config.get("model_config", {})
model_config["model"] = model or self.get_model()
payloads = {"messages": context_query, **model_config}
        # Inject the xAI native search parameters if enabled
self._maybe_inject_xai_search(payloads, **kwargs)
return payloads, context_query
async def _handle_api_error(
self,
e: Exception,
payloads: dict,
context_query: list,
func_tool: ToolSet | None,
chosen_key: str,
available_api_keys: list[str],
retry_cnt: int,
max_retries: int,
) -> tuple:
"""处理API错误并尝试恢复"""
if "429" in str(e):
            logger.warning(
                f"API rate limit hit; retrying with a different key. Current key: {chosen_key[:12]}",
            )
            # Skip the wait on the final attempt
            if retry_cnt < max_retries - 1:
                await asyncio.sleep(1)
available_api_keys.remove(chosen_key)
if len(available_api_keys) > 0:
chosen_key = random.choice(available_api_keys)
return (
False,
chosen_key,
available_api_keys,
payloads,
context_query,
func_tool,
)
raise e
if "maximum context length" in str(e):
logger.warning(
f"上下文长度超过限制。尝试弹出最早的记录然后重试。当前记录条数: {len(context_query)}",
)
await self.pop_record(context_query)
payloads["messages"] = context_query
return (
False,
chosen_key,
available_api_keys,
payloads,
context_query,
func_tool,
)
if "The model is not a VLM" in str(e): # siliconcloud
            # Try removing every image from the context
new_contexts = await self._remove_image_from_context(context_query)
payloads["messages"] = new_contexts
context_query = new_contexts
return (
False,
chosen_key,
available_api_keys,
payloads,
context_query,
func_tool,
)
if (
"Function calling is not enabled" in str(e)
or ("tool" in str(e).lower() and "support" in str(e).lower())
or ("function" in str(e).lower() and "support" in str(e).lower())
):
            # openai, ollama, gemini (OpenAI-compatible) and siliconcloud report this with
            # inconsistent messages and error codes, so string matching is the only option
            logger.info(
                f"{self.get_model()} does not support function/tool calling; tools were removed automatically and this does not affect usage.",
            )
payloads.pop("tools", None)
return False, chosen_key, available_api_keys, payloads, context_query, None
# logger.error(f"发生了错误。Provider 配置如下: {self.provider_config}")
if "tool" in str(e).lower() and "support" in str(e).lower():
logger.error("疑似该模型不支持函数调用工具调用。请输入 /tool off_all")
if "Connection error." in str(e):
proxy = os.environ.get("http_proxy", None)
if proxy:
                logger.error(
                    f"This may be a proxy issue; please check that the proxy works. Current proxy: {proxy}",
                )
raise e
async def text_chat(
self,
prompt=None,
session_id=None,
image_urls=None,
func_tool=None,
contexts=None,
system_prompt=None,
tool_calls_result=None,
model=None,
**kwargs,
) -> LLMResponse:
payloads, context_query = await self._prepare_chat_payload(
prompt,
image_urls,
contexts,
system_prompt,
tool_calls_result,
model=model,
**kwargs,
)
llm_response = None
max_retries = 10
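        # Retry over a working copy of the key list: on HTTP 429 _handle_api_error
        # drops the failing key from the copy and picks another one at random.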
available_api_keys = self.api_keys.copy()
chosen_key = random.choice(available_api_keys)
last_exception = None
retry_cnt = 0
for retry_cnt in range(max_retries):
try:
self.client.api_key = chosen_key
llm_response = await self._query(payloads, func_tool)
break
except Exception as e:
last_exception = e
(
success,
chosen_key,
available_api_keys,
payloads,
context_query,
func_tool,
) = await self._handle_api_error(
e,
payloads,
context_query,
func_tool,
chosen_key,
available_api_keys,
retry_cnt,
max_retries,
)
if success:
break
        # Only the absence of a response counts as exhaustion; a success on the
        # very last attempt must not be treated as a failure.
        if llm_response is None:
            logger.error(f"API call failed; still failing after {max_retries} retries.")
            if last_exception is None:
                raise Exception("Unknown error")
            raise last_exception
return llm_response
async def text_chat_stream(
self,
prompt=None,
session_id=None,
image_urls=None,
func_tool=None,
contexts=None,
system_prompt=None,
tool_calls_result=None,
model=None,
**kwargs,
) -> AsyncGenerator[LLMResponse, None]:
"""流式对话,与服务商交互并逐步返回结果"""
payloads, context_query = await self._prepare_chat_payload(
prompt,
image_urls,
contexts,
system_prompt,
tool_calls_result,
model=model,
**kwargs,
)
max_retries = 10
available_api_keys = self.api_keys.copy()
chosen_key = random.choice(available_api_keys)
last_exception = None
        retry_cnt = 0
        succeeded = False
        for retry_cnt in range(max_retries):
try:
self.client.api_key = chosen_key
async for response in self._query_stream(payloads, func_tool):
yield response
                succeeded = True
                break
except Exception as e:
last_exception = e
(
success,
chosen_key,
available_api_keys,
payloads,
context_query,
func_tool,
) = await self._handle_api_error(
e,
payloads,
context_query,
func_tool,
chosen_key,
available_api_keys,
retry_cnt,
max_retries,
)
if success:
break
        # As in text_chat, a success on the final attempt must not raise.
        if not succeeded:
            logger.error(f"API call failed; still failing after {max_retries} retries.")
            if last_exception is None:
                raise Exception("Unknown error")
            raise last_exception
async def _remove_image_from_context(self, contexts: list):
"""从上下文中删除所有带有 image 的记录"""
new_contexts = []
for context in contexts:
if "content" in context and isinstance(context["content"], list):
new_content = []
for item in context["content"]:
if isinstance(item, dict) and "image_url" in item:
continue
new_content.append(item)
if not new_content:
                    # The user sent only images
                    new_content = [{"type": "text", "text": "[image]"}]
context["content"] = new_content
new_contexts.append(context)
return new_contexts
def get_current_key(self) -> str:
return self.client.api_key
def get_keys(self) -> list[str]:
return self.api_keys
def set_key(self, key):
self.client.api_key = key
async def assemble_context(
self,
text: str,
image_urls: list[str] | None = None,
) -> dict:
"""组装成符合 OpenAI 格式的 role 为 user 的消息段"""
if image_urls:
user_content = {
"role": "user",
"content": [{"type": "text", "text": text if text else "[图片]"}],
}
for image_url in image_urls:
if image_url.startswith("http"):
image_path = await download_image_by_url(image_url)
image_data = await self.encode_image_bs64(image_path)
elif image_url.startswith("file:///"):
image_path = image_url.replace("file:///", "")
image_data = await self.encode_image_bs64(image_path)
else:
image_data = await self.encode_image_bs64(image_url)
if not image_data:
logger.warning(f"图片 {image_url} 得到的结果为空,将忽略。")
continue
user_content["content"].append(
{
"type": "image_url",
"image_url": {"url": image_data},
},
)
return user_content
return {"role": "user", "content": text}
async def encode_image_bs64(self, image_url: str) -> str:
"""将图片转换为 base64"""
if image_url.startswith("base64://"):
return image_url.replace("base64://", "data:image/jpeg;base64,")
with open(image_url, "rb") as f:
image_bs64 = base64.b64encode(f.read()).decode("utf-8")
return "data:image/jpeg;base64," + image_bs64