fix: handle tool_calls_result as list or single object in context query in streaming mode

This commit is contained in:
Soulter
2025-07-02 10:16:44 +08:00
parent c36142deaf
commit adb0cbc5dd
2 changed files with 30 additions and 18 deletions

View File

@@ -300,7 +300,11 @@ class ProviderAnthropic(Provider):
         # tool calls result
         if tool_calls_result:
-            context_query.extend(tool_calls_result.to_openai_messages())
+            if not isinstance(tool_calls_result, list):
+                context_query.extend(tool_calls_result.to_openai_messages())
+            else:
+                for tcr in tool_calls_result:
+                    context_query.extend(tcr.to_openai_messages())
 
         system_prompt, new_messages = self._prepare_payload(context_query)

View File

@@ -14,7 +14,7 @@ import astrbot.core.message.components as Comp
 from astrbot import logger
 from astrbot.api.provider import Provider
 from astrbot.core.message.message_event_result import MessageChain
-from astrbot.core.provider.entities import LLMResponse, ToolCallsResult
+from astrbot.core.provider.entities import LLMResponse
 from astrbot.core.provider.func_tool_manager import FuncCall
 from astrbot.core.utils.io import download_image_by_url
@@ -259,10 +259,12 @@ class ProviderGoogleGenAI(Provider):
             contents.append(content_cls(parts=part))
 
         gemini_contents: list[types.Content] = []
-        native_tool_enabled = any([
-            self.provider_config.get("gm_native_coderunner", False),
-            self.provider_config.get("gm_native_search", False),
-        ])
+        native_tool_enabled = any(
+            [
+                self.provider_config.get("gm_native_coderunner", False),
+                self.provider_config.get("gm_native_search", False),
+            ]
+        )
 
         for message in payloads["messages"]:
             role, content = message["role"], message.get("content")
@@ -544,13 +546,13 @@ class ProviderGoogleGenAI(Provider):
     async def text_chat_stream(
         self,
-        prompt: str,
-        session_id: str = None,
-        image_urls: list[str] = None,
-        func_tool: FuncCall = None,
-        contexts: str = None,
-        system_prompt: str = None,
-        tool_calls_result: ToolCallsResult = None,
+        prompt,
+        session_id=None,
+        image_urls=None,
+        func_tool=None,
+        contexts=None,
+        system_prompt=None,
+        tool_calls_result=None,
         **kwargs,
     ) -> AsyncGenerator[LLMResponse, None]:
         if contexts is None:
@@ -566,7 +568,11 @@ class ProviderGoogleGenAI(Provider):
         # tool calls result
         if tool_calls_result:
-            context_query.extend(tool_calls_result.to_openai_messages())
+            if not isinstance(tool_calls_result, list):
+                context_query.extend(tool_calls_result.to_openai_messages())
+            else:
+                for tcr in tool_calls_result:
+                    context_query.extend(tcr.to_openai_messages())
 
         model_config = self.provider_config.get("model_config", {})
         model_config["model"] = self.get_model()
@@ -628,10 +634,12 @@
                 if not image_data:
                     logger.warning(f"图片 {image_url} 得到的结果为空,将忽略。")
                     continue
-                user_content["content"].append({
-                    "type": "image_url",
-                    "image_url": {"url": image_data},
-                })
+                user_content["content"].append(
+                    {
+                        "type": "image_url",
+                        "image_url": {"url": image_data},
+                    }
+                )
             return user_content
         else:
             return {"role": "user", "content": text}