fix: improve Gemini reasoning and message handling (#11439)
* fix: some bug * fix/test * fix: lint * fix: 添加跳过 Gemini3 思考签名的中间件并更新消息转换逻辑 (add a middleware that skips Gemini 3 thought signatures and update the message conversion logic) * fix: comment * fix: js docs * fix: id bug * fix: condition * fix: Update the user's verbosity setting logic to ensure that supported options are prioritized for use. * fix: Add support for the 'openai-response' provider type. * fix: lint
This commit is contained in:
@@ -156,7 +156,7 @@ export default class ModernAiProvider {
|
||||
config: ModernAiProviderConfig
|
||||
): Promise<CompletionsResult> {
|
||||
// ai-gateway不是image/generation 端点,所以就先不走legacy了
|
||||
if (config.isImageGenerationEndpoint && config.provider!.id !== SystemProviderIds['ai-gateway']) {
|
||||
if (config.isImageGenerationEndpoint && this.getActualProvider().id !== SystemProviderIds['ai-gateway']) {
|
||||
// 使用 legacy 实现处理图像生成(支持图片编辑等高级功能)
|
||||
if (!config.uiMessages) {
|
||||
throw new Error('uiMessages is required for image generation endpoint')
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import type { WebSearchPluginConfig } from '@cherrystudio/ai-core/built-in/plugins'
|
||||
import { loggerService } from '@logger'
|
||||
import { isSupportedThinkingTokenQwenModel } from '@renderer/config/models'
|
||||
import { isGemini3Model, isSupportedThinkingTokenQwenModel } from '@renderer/config/models'
|
||||
import type { MCPTool } from '@renderer/types'
|
||||
import { type Assistant, type Message, type Model, type Provider, SystemProviderIds } from '@renderer/types'
|
||||
import type { Chunk } from '@renderer/types/chunk'
|
||||
@@ -9,11 +9,13 @@ import type { LanguageModelMiddleware } from 'ai'
|
||||
import { extractReasoningMiddleware, simulateStreamingMiddleware } from 'ai'
|
||||
import { isEmpty } from 'lodash'
|
||||
|
||||
import { getAiSdkProviderId } from '../provider/factory'
|
||||
import { isOpenRouterGeminiGenerateImageModel } from '../utils/image'
|
||||
import { noThinkMiddleware } from './noThinkMiddleware'
|
||||
import { openrouterGenerateImageMiddleware } from './openrouterGenerateImageMiddleware'
|
||||
import { openrouterReasoningMiddleware } from './openrouterReasoningMiddleware'
|
||||
import { qwenThinkingMiddleware } from './qwenThinkingMiddleware'
|
||||
import { skipGeminiThoughtSignatureMiddleware } from './skipGeminiThoughtSignatureMiddleware'
|
||||
import { toolChoiceMiddleware } from './toolChoiceMiddleware'
|
||||
|
||||
const logger = loggerService.withContext('AiSdkMiddlewareBuilder')
|
||||
@@ -257,6 +259,15 @@ function addModelSpecificMiddlewares(builder: AiSdkMiddlewareBuilder, config: Ai
|
||||
middleware: openrouterGenerateImageMiddleware()
|
||||
})
|
||||
}
|
||||
|
||||
if (isGemini3Model(config.model)) {
|
||||
const aiSdkId = getAiSdkProviderId(config.provider)
|
||||
builder.add({
|
||||
name: 'skip-gemini3-thought-signature',
|
||||
middleware: skipGeminiThoughtSignatureMiddleware(aiSdkId)
|
||||
})
|
||||
logger.debug('Added skip Gemini3 thought signature middleware')
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
@@ -0,0 +1,36 @@
|
||||
import type { LanguageModelMiddleware } from 'ai'
|
||||
|
||||
/**
 * Skip Gemini Thought Signature Middleware
 *
 * Due to the complexity of multi-model client requests (which can switch to other models mid-process),
 * it was decided to add a skip for all Gemini3 thinking signatures via middleware: every
 * `thoughtSignature` found under `providerOptions[aiSdkId]` of a prompt message part is replaced
 * with a fixed placeholder string before the request is sent.
 *
 * The incoming `params` are never modified. Only the message parts that carry a signature (and the
 * messages / prompt array that contain them) are copied; everything else is passed through by
 * reference.
 *
 * @param aiSdkId AI SDK Provider ID, used as the key into each part's `providerOptions`
 * @returns LanguageModelMiddleware whose `transformParams` returns the rewritten params
 */
export function skipGeminiThoughtSignatureMiddleware(aiSdkId: string): LanguageModelMiddleware {
  const MAGIC_STRING = 'skip_thought_signature_validator'
  return {
    middlewareVersion: 'v2',

    transformParams: async ({ params }) => {
      // Nothing to rewrite when there is no prompt array; still hand back a fresh top-level object.
      if (!params.prompt || !Array.isArray(params.prompt)) {
        return { ...params }
      }

      const prompt = params.prompt.map((message) => {
        // String content has no parts and therefore no provider options.
        if (typeof message.content === 'string') {
          return message
        }

        let replaced = false
        const content = message.content.map((part) => {
          const googleOptions = part?.providerOptions?.[aiSdkId]
          if (!googleOptions?.thoughtSignature) {
            return part
          }
          replaced = true
          // Copy every level on the path to the signature so the caller's objects stay untouched.
          return {
            ...part,
            providerOptions: {
              ...part.providerOptions,
              [aiSdkId]: { ...googleOptions, thoughtSignature: MAGIC_STRING }
            }
          }
        })

        // The mapping keeps each part's kind, which the compiler cannot follow across the
        // message union, hence the assertion back to the message's own type.
        return (replaced ? { ...message, content } : message) as typeof message
      })

      return { ...params, prompt }
    }
  }
}
|
||||
@@ -180,6 +180,10 @@ describe('messageConverter', () => {
|
||||
const result = await convertMessagesToSdkMessages([initialUser, assistant, finalUser], model)
|
||||
|
||||
expect(result).toEqual([
|
||||
{
|
||||
role: 'user',
|
||||
content: [{ type: 'text', text: 'Start editing' }]
|
||||
},
|
||||
{
|
||||
role: 'assistant',
|
||||
content: [{ type: 'text', text: 'Here is the current preview' }]
|
||||
@@ -217,6 +221,7 @@ describe('messageConverter', () => {
|
||||
|
||||
expect(result).toEqual([
|
||||
{ role: 'system', content: 'fileid://reference' },
|
||||
{ role: 'user', content: [{ type: 'text', text: 'Use this document as inspiration' }] },
|
||||
{
|
||||
role: 'assistant',
|
||||
content: [{ type: 'text', text: 'Generated previews ready' }]
|
||||
|
||||
@@ -194,20 +194,20 @@ async function convertMessageToAssistantModelMessage(
|
||||
* This function processes messages and transforms them into the format required by the SDK.
|
||||
* It handles special cases for vision models and image enhancement models.
|
||||
*
|
||||
* @param messages - Array of messages to convert. Must contain at least 2 messages when using image enhancement models.
|
||||
* @param messages - Array of messages to convert. Must contain at least 3 messages when using image enhancement models for special handling.
|
||||
* @param model - The model configuration that determines conversion behavior
|
||||
*
|
||||
* @returns A promise that resolves to an array of SDK-compatible model messages
|
||||
*
|
||||
* @remarks
|
||||
* For image enhancement models with 2+ messages:
|
||||
* - Expects the second-to-last message (index length-2) to be an assistant message containing image blocks
|
||||
* - Expects the last message (index length-1) to be a user message
|
||||
* - Extracts images from the assistant message and appends them to the user message content
|
||||
* - Returns only the last two processed messages [assistantSdkMessage, userSdkMessage]
|
||||
* For image enhancement models with 3+ messages:
|
||||
* - Examines the last 2 messages to find an assistant message containing image blocks
|
||||
* - If found, extracts images from the assistant message and appends them to the last user message content
|
||||
* - Returns all converted messages (not just the last two) with the images merged into the user message
|
||||
* - Typical pattern: [system?, assistant(image), user] -> [system?, assistant, user(image)]
|
||||
*
|
||||
* For other models:
|
||||
* - Returns all converted messages in order
|
||||
* - Returns all converted messages in order without special image handling
|
||||
*
|
||||
* The function automatically detects vision model capabilities and adjusts conversion accordingly.
|
||||
*/
|
||||
@@ -220,29 +220,25 @@ export async function convertMessagesToSdkMessages(messages: Message[], model: M
|
||||
sdkMessages.push(...(Array.isArray(sdkMessage) ? sdkMessage : [sdkMessage]))
|
||||
}
|
||||
// Special handling for image enhancement models
|
||||
// Only keep the last two messages and merge images into the user message
|
||||
// [system?, user, assistant, user]
|
||||
// Only merge images into the user message
|
||||
// [system?, assistant(image), user] -> [system?, assistant, user(image)]
|
||||
if (isImageEnhancementModel(model) && messages.length >= 3) {
|
||||
const needUpdatedMessages = messages.slice(-2)
|
||||
const needUpdatedSdkMessages = sdkMessages.slice(-2)
|
||||
const assistantMessage = needUpdatedMessages.filter((m) => m.role === 'assistant')[0]
|
||||
const assistantSdkMessage = needUpdatedSdkMessages.filter((m) => m.role === 'assistant')[0]
|
||||
const userSdkMessage = needUpdatedSdkMessages.filter((m) => m.role === 'user')[0]
|
||||
const systemSdkMessages = sdkMessages.filter((m) => m.role === 'system')
|
||||
const imageBlocks = findImageBlocks(assistantMessage)
|
||||
const imageParts = await convertImageBlockToImagePart(imageBlocks)
|
||||
const parts: Array<TextPart | ImagePart | FilePart> = []
|
||||
if (typeof userSdkMessage.content === 'string') {
|
||||
parts.push({ type: 'text', text: userSdkMessage.content })
|
||||
parts.push(...imageParts)
|
||||
userSdkMessage.content = parts
|
||||
} else {
|
||||
userSdkMessage.content.push(...imageParts)
|
||||
const assistantMessage = needUpdatedMessages.find((m) => m.role === 'assistant')
|
||||
const userSdkMessage = sdkMessages[sdkMessages.length - 1]
|
||||
|
||||
if (assistantMessage && userSdkMessage?.role === 'user') {
|
||||
const imageBlocks = findImageBlocks(assistantMessage)
|
||||
const imageParts = await convertImageBlockToImagePart(imageBlocks)
|
||||
|
||||
if (imageParts.length > 0) {
|
||||
if (typeof userSdkMessage.content === 'string') {
|
||||
userSdkMessage.content = [{ type: 'text', text: userSdkMessage.content }, ...imageParts]
|
||||
} else if (Array.isArray(userSdkMessage.content)) {
|
||||
userSdkMessage.content.push(...imageParts)
|
||||
}
|
||||
}
|
||||
}
|
||||
if (systemSdkMessages.length > 0) {
|
||||
return [systemSdkMessages[0], assistantSdkMessage, userSdkMessage]
|
||||
}
|
||||
return [assistantSdkMessage, userSdkMessage]
|
||||
}
|
||||
|
||||
return sdkMessages
|
||||
|
||||
@@ -91,9 +91,21 @@ function getServiceTier<T extends Provider>(model: Model, provider: T): OpenAISe
|
||||
}
|
||||
}
|
||||
|
||||
function getVerbosity(): OpenAIVerbosity {
|
||||
function getVerbosity(model: Model): OpenAIVerbosity {
|
||||
if (!isSupportVerbosityModel(model) || !isSupportVerbosityProvider(getProviderById(model.provider)!)) {
|
||||
return undefined
|
||||
}
|
||||
const openAI = getStoreSetting('openAI')
|
||||
return openAI.verbosity
|
||||
|
||||
const userVerbosity = openAI.verbosity
|
||||
|
||||
if (userVerbosity) {
|
||||
const supportedVerbosity = getModelSupportedVerbosity(model)
|
||||
// Use user's verbosity if supported, otherwise use the first supported option
|
||||
const verbosity = supportedVerbosity.includes(userVerbosity) ? userVerbosity : supportedVerbosity[0]
|
||||
return verbosity
|
||||
}
|
||||
return undefined
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -148,7 +160,7 @@ export function buildProviderOptions(
|
||||
// 构建 provider 特定的选项
|
||||
let providerSpecificOptions: Record<string, any> = {}
|
||||
const serviceTier = getServiceTier(model, actualProvider)
|
||||
const textVerbosity = getVerbosity()
|
||||
const textVerbosity = getVerbosity(model)
|
||||
// 根据 provider 类型分离构建逻辑
|
||||
const { data: baseProviderId, success } = baseProviderIdSchema.safeParse(rawProviderId)
|
||||
if (success) {
|
||||
@@ -163,7 +175,8 @@ export function buildProviderOptions(
|
||||
assistant,
|
||||
model,
|
||||
capabilities,
|
||||
serviceTier
|
||||
serviceTier,
|
||||
textVerbosity
|
||||
)
|
||||
providerSpecificOptions = options
|
||||
}
|
||||
@@ -196,7 +209,8 @@ export function buildProviderOptions(
|
||||
model,
|
||||
capabilities,
|
||||
actualProvider,
|
||||
serviceTier
|
||||
serviceTier,
|
||||
textVerbosity
|
||||
)
|
||||
break
|
||||
default:
|
||||
@@ -255,7 +269,7 @@ export function buildProviderOptions(
|
||||
}[rawProviderId] || rawProviderId
|
||||
|
||||
if (rawProviderKey === 'cherryin') {
|
||||
rawProviderKey = { gemini: 'google' }[actualProvider.type] || actualProvider.type
|
||||
rawProviderKey = { gemini: 'google', ['openai-response']: 'openai' }[actualProvider.type] || actualProvider.type
|
||||
}
|
||||
|
||||
// 返回 AI Core SDK 要求的格式:{ 'providerId': providerOptions } 以及提取的标准参数
|
||||
@@ -278,7 +292,8 @@ function buildOpenAIProviderOptions(
|
||||
enableWebSearch: boolean
|
||||
enableGenerateImage: boolean
|
||||
},
|
||||
serviceTier: OpenAIServiceTier
|
||||
serviceTier: OpenAIServiceTier,
|
||||
textVerbosity?: OpenAIVerbosity
|
||||
): OpenAIResponsesProviderOptions {
|
||||
const { enableReasoning } = capabilities
|
||||
let providerOptions: OpenAIResponsesProviderOptions = {}
|
||||
@@ -314,7 +329,8 @@ function buildOpenAIProviderOptions(
|
||||
|
||||
providerOptions = {
|
||||
...providerOptions,
|
||||
serviceTier
|
||||
serviceTier,
|
||||
textVerbosity
|
||||
}
|
||||
|
||||
return providerOptions
|
||||
@@ -413,11 +429,13 @@ function buildCherryInProviderOptions(
|
||||
enableGenerateImage: boolean
|
||||
},
|
||||
actualProvider: Provider,
|
||||
serviceTier: OpenAIServiceTier
|
||||
serviceTier: OpenAIServiceTier,
|
||||
textVerbosity: OpenAIVerbosity
|
||||
): OpenAIResponsesProviderOptions | AnthropicProviderOptions | GoogleGenerativeAIProviderOptions {
|
||||
switch (actualProvider.type) {
|
||||
case 'openai':
|
||||
return buildOpenAIProviderOptions(assistant, model, capabilities, serviceTier)
|
||||
case 'openai-response':
|
||||
return buildOpenAIProviderOptions(assistant, model, capabilities, serviceTier, textVerbosity)
|
||||
|
||||
case 'anthropic':
|
||||
return buildAnthropicProviderOptions(assistant, model, capabilities)
|
||||
|
||||
@@ -12,7 +12,7 @@ import {
|
||||
isDeepSeekHybridInferenceModel,
|
||||
isDoubaoSeedAfter251015,
|
||||
isDoubaoThinkingAutoModel,
|
||||
isGemini3Model,
|
||||
isGemini3ThinkingTokenModel,
|
||||
isGPT51SeriesModel,
|
||||
isGrok4FastReasoningModel,
|
||||
isGrokReasoningModel,
|
||||
@@ -36,7 +36,7 @@ import {
|
||||
} from '@renderer/config/models'
|
||||
import { getStoreSetting } from '@renderer/hooks/useSettings'
|
||||
import { getAssistantSettings, getProviderByModel } from '@renderer/services/AssistantService'
|
||||
import type { Assistant, Model, ReasoningEffortOption } from '@renderer/types'
|
||||
import type { Assistant, Model } from '@renderer/types'
|
||||
import { EFFORT_RATIO, isSystemProvider, SystemProviderIds } from '@renderer/types'
|
||||
import type { OpenAISummaryText } from '@renderer/types/aiCoreTypes'
|
||||
import type { ReasoningEffortOptionalParams } from '@renderer/types/sdk'
|
||||
@@ -281,7 +281,7 @@ export function getReasoningEffort(assistant: Assistant, model: Model): Reasonin
|
||||
// gemini series, openai compatible api
|
||||
if (isSupportedThinkingTokenGeminiModel(model)) {
|
||||
// https://ai.google.dev/gemini-api/docs/gemini-3?thinking=high#openai_compatibility
|
||||
if (isGemini3Model(model)) {
|
||||
if (isGemini3ThinkingTokenModel(model)) {
|
||||
return {
|
||||
reasoning_effort: reasoningEffort
|
||||
}
|
||||
@@ -465,20 +465,20 @@ export function getAnthropicReasoningParams(
|
||||
return {}
|
||||
}
|
||||
|
||||
type GoogelThinkingLevel = NonNullable<GoogleGenerativeAIProviderOptions['thinkingConfig']>['thinkingLevel']
|
||||
// type GoogleThinkingLevel = NonNullable<GoogleGenerativeAIProviderOptions['thinkingConfig']>['thinkingLevel']
|
||||
|
||||
function mapToGeminiThinkingLevel(reasoningEffort: ReasoningEffortOption): GoogelThinkingLevel {
|
||||
switch (reasoningEffort) {
|
||||
case 'low':
|
||||
return 'low'
|
||||
case 'medium':
|
||||
return 'medium'
|
||||
case 'high':
|
||||
return 'high'
|
||||
default:
|
||||
return 'medium'
|
||||
}
|
||||
}
|
||||
// function mapToGeminiThinkingLevel(reasoningEffort: ReasoningEffortOption): GoogleThinkingLevel {
|
||||
// switch (reasoningEffort) {
|
||||
// case 'low':
|
||||
// return 'low'
|
||||
// case 'medium':
|
||||
// return 'medium'
|
||||
// case 'high':
|
||||
// return 'high'
|
||||
// default:
|
||||
// return 'medium'
|
||||
// }
|
||||
// }
|
||||
|
||||
/**
|
||||
* 获取 Gemini 推理参数
|
||||
@@ -507,14 +507,15 @@ export function getGeminiReasoningParams(
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: 很多中转还不支持
|
||||
// https://ai.google.dev/gemini-api/docs/gemini-3?thinking=high#new_api_features_in_gemini_3
|
||||
if (isGemini3Model(model)) {
|
||||
return {
|
||||
thinkingConfig: {
|
||||
thinkingLevel: mapToGeminiThinkingLevel(reasoningEffort)
|
||||
}
|
||||
}
|
||||
}
|
||||
// if (isGemini3ThinkingTokenModel(model)) {
|
||||
// return {
|
||||
// thinkingConfig: {
|
||||
// thinkingLevel: mapToGeminiThinkingLevel(reasoningEffort)
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
|
||||
const effortRatio = EFFORT_RATIO[reasoningEffort]
|
||||
|
||||
|
||||
@@ -33,6 +33,7 @@ import {
|
||||
MODEL_SUPPORTED_OPTIONS,
|
||||
MODEL_SUPPORTED_REASONING_EFFORT
|
||||
} from '../reasoning'
|
||||
import { isGemini3ThinkingTokenModel } from '../utils'
|
||||
import { isTextToImageModel } from '../vision'
|
||||
|
||||
vi.mock('@renderer/store', () => ({
|
||||
@@ -955,7 +956,7 @@ describe('Gemini Models', () => {
|
||||
provider: '',
|
||||
group: ''
|
||||
})
|
||||
).toBe(true)
|
||||
).toBe(false)
|
||||
expect(
|
||||
isSupportedThinkingTokenGeminiModel({
|
||||
id: 'gemini-3.0-flash-image-preview',
|
||||
@@ -963,7 +964,7 @@ describe('Gemini Models', () => {
|
||||
provider: '',
|
||||
group: ''
|
||||
})
|
||||
).toBe(true)
|
||||
).toBe(false)
|
||||
expect(
|
||||
isSupportedThinkingTokenGeminiModel({
|
||||
id: 'gemini-3.5-pro-image-preview',
|
||||
@@ -971,7 +972,7 @@ describe('Gemini Models', () => {
|
||||
provider: '',
|
||||
group: ''
|
||||
})
|
||||
).toBe(true)
|
||||
).toBe(false)
|
||||
})
|
||||
|
||||
it('should return false for gemini-2.x image models', () => {
|
||||
@@ -1163,7 +1164,7 @@ describe('Gemini Models', () => {
|
||||
provider: '',
|
||||
group: ''
|
||||
})
|
||||
).toBe(true)
|
||||
).toBe(false)
|
||||
expect(
|
||||
isGeminiReasoningModel({
|
||||
id: 'gemini-3.5-flash-image-preview',
|
||||
@@ -1171,7 +1172,7 @@ describe('Gemini Models', () => {
|
||||
provider: '',
|
||||
group: ''
|
||||
})
|
||||
).toBe(true)
|
||||
).toBe(false)
|
||||
})
|
||||
|
||||
it('should return false for older gemini models without thinking', () => {
|
||||
@@ -1230,3 +1231,153 @@ describe('findTokenLimit', () => {
|
||||
expect(findTokenLimit('unknown-model')).toBeUndefined()
|
||||
})
|
||||
})
|
||||
|
||||
describe('isGemini3ThinkingTokenModel', () => {
  // Runs the predicate against a minimal model fixture that only varies by id.
  const check = (id: string) => isGemini3ThinkingTokenModel({ id, name: '', provider: '', group: '' })

  it('should return true for Gemini 3 non-image models', () => {
    const ids = [
      'gemini-3-flash',
      'gemini-3-pro',
      'gemini-3-pro-preview',
      'google/gemini-3-flash',
      'gemini-3.0-flash',
      'gemini-3.5-pro-preview'
    ]
    for (const id of ids) {
      expect(check(id)).toBe(true)
    }
  })

  it('should return false for Gemini 3 image models', () => {
    const ids = [
      'gemini-3-flash-image',
      'gemini-3-pro-image-preview',
      'gemini-3.0-flash-image-preview',
      'gemini-3.5-pro-image-preview'
    ]
    for (const id of ids) {
      expect(check(id)).toBe(false)
    }
  })

  it('should return false for non-Gemini 3 models', () => {
    const ids = ['gemini-2.5-flash', 'gemini-1.5-pro', 'gpt-4', 'claude-3-opus']
    for (const id of ids) {
      expect(check(id)).toBe(false)
    }
  })

  it('should handle case insensitivity', () => {
    // Mixed-case ids must be classified exactly like their lower-case forms.
    expect(check('Gemini-3-Flash')).toBe(true)
    expect(check('GEMINI-3-PRO')).toBe(true)
    expect(check('Gemini-3-Pro-Image')).toBe(false)
  })
})
|
||||
|
||||
@@ -16,7 +16,7 @@ import {
|
||||
isOpenAIReasoningModel,
|
||||
isSupportedReasoningEffortOpenAIModel
|
||||
} from './openai'
|
||||
import { GEMINI_FLASH_MODEL_REGEX, isGemini3Model } from './utils'
|
||||
import { GEMINI_FLASH_MODEL_REGEX, isGemini3ThinkingTokenModel } from './utils'
|
||||
import { isTextToImageModel } from './vision'
|
||||
|
||||
// Reasoning models
|
||||
@@ -115,7 +115,7 @@ const _getThinkModelType = (model: Model): ThinkingModelType => {
|
||||
} else {
|
||||
thinkingModelType = 'gemini_pro'
|
||||
}
|
||||
if (isGemini3Model(model)) {
|
||||
if (isGemini3ThinkingTokenModel(model)) {
|
||||
thinkingModelType = 'gemini3'
|
||||
}
|
||||
} else if (isSupportedReasoningEffortGrokModel(model)) thinkingModelType = 'grok'
|
||||
@@ -271,14 +271,6 @@ export const GEMINI_THINKING_MODEL_REGEX =
|
||||
export const isSupportedThinkingTokenGeminiModel = (model: Model): boolean => {
|
||||
const modelId = getLowerBaseModelName(model.id, '/')
|
||||
if (GEMINI_THINKING_MODEL_REGEX.test(modelId)) {
|
||||
// gemini-3.x 的 image 模型支持思考模式
|
||||
if (isGemini3Model(model)) {
|
||||
if (modelId.includes('tts')) {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
// gemini-2.x 的 image/tts 模型不支持
|
||||
if (modelId.includes('image') || modelId.includes('tts')) {
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -43,7 +43,8 @@ const FUNCTION_CALLING_EXCLUDED_MODELS = [
|
||||
'gpt-5-chat(?:-[\\w-]+)?',
|
||||
'glm-4\\.5v',
|
||||
'gemini-2.5-flash-image(?:-[\\w-]+)?',
|
||||
'gemini-2.0-flash-preview-image-generation'
|
||||
'gemini-2.0-flash-preview-image-generation',
|
||||
'gemini-3(?:\\.\\d+)?-pro-image(?:-[\\w-]+)?'
|
||||
]
|
||||
|
||||
export const FUNCTION_CALLING_REGEX = new RegExp(
|
||||
|
||||
@@ -164,3 +164,8 @@ export const isGemini3Model = (model: Model) => {
|
||||
const modelId = getLowerBaseModelName(model.id)
|
||||
return modelId.includes('gemini-3')
|
||||
}
|
||||
|
||||
/**
 * Tells whether a model is a Gemini 3 model whose id does not mark it as an image model.
 *
 * @param model - The model to inspect; only its `id` is read here.
 * @returns `true` when `isGemini3Model` accepts the model and its lower-cased base name
 * does not contain `image`, otherwise `false`.
 */
export const isGemini3ThinkingTokenModel = (model: Model) => {
  if (!isGemini3Model(model)) {
    return false
  }
  const baseName = getLowerBaseModelName(model.id)
  return !baseName.includes('image')
}
|
||||
|
||||
@@ -4,6 +4,7 @@ import { BingLogo, BochaLogo, ExaLogo, SearXNGLogo, TavilyLogo, ZhipuLogo } from
|
||||
import type { QuickPanelListItem } from '@renderer/components/QuickPanel'
|
||||
import { QuickPanelReservedSymbol } from '@renderer/components/QuickPanel'
|
||||
import {
|
||||
isFunctionCallingModel,
|
||||
isGeminiModel,
|
||||
isGPT5SeriesReasoningModel,
|
||||
isOpenAIWebSearchModel,
|
||||
@@ -18,6 +19,7 @@ import WebSearchService from '@renderer/services/WebSearchService'
|
||||
import type { WebSearchProvider, WebSearchProviderId } from '@renderer/types'
|
||||
import { hasObjectKey } from '@renderer/utils'
|
||||
import { isToolUseModeFunction } from '@renderer/utils/assistant'
|
||||
import { isPromptToolUse } from '@renderer/utils/mcp-tools'
|
||||
import { isGeminiWebSearchProvider } from '@renderer/utils/provider'
|
||||
import { Globe } from 'lucide-react'
|
||||
import { useCallback, useEffect, useMemo } from 'react'
|
||||
@@ -126,20 +128,25 @@ export const useWebSearchPanelController = (assistantId: string, quickPanelContr
|
||||
|
||||
const providerItems = useMemo<QuickPanelListItem[]>(() => {
|
||||
const isWebSearchModelEnabled = assistant.model && isWebSearchModel(assistant.model)
|
||||
const items: QuickPanelListItem[] = providers
|
||||
.map((p) => ({
|
||||
label: p.name,
|
||||
description: WebSearchService.isWebSearchEnabled(p.id)
|
||||
? hasObjectKey(p, 'apiKey')
|
||||
? t('settings.tool.websearch.apikey')
|
||||
: t('settings.tool.websearch.free')
|
||||
: t('chat.input.web_search.enable_content'),
|
||||
icon: <WebSearchProviderIcon size={13} pid={p.id} />,
|
||||
isSelected: p.id === assistant?.webSearchProviderId,
|
||||
disabled: !WebSearchService.isWebSearchEnabled(p.id),
|
||||
action: () => updateQuickPanelItem(p.id)
|
||||
}))
|
||||
.filter((item) => !item.disabled)
|
||||
const items: QuickPanelListItem[] = []
|
||||
if (isFunctionCallingModel(assistant.model) || isPromptToolUse(assistant)) {
|
||||
items.push(
|
||||
...providers
|
||||
.map((p) => ({
|
||||
label: p.name,
|
||||
description: WebSearchService.isWebSearchEnabled(p.id)
|
||||
? hasObjectKey(p, 'apiKey')
|
||||
? t('settings.tool.websearch.apikey')
|
||||
: t('settings.tool.websearch.free')
|
||||
: t('chat.input.web_search.enable_content'),
|
||||
icon: <WebSearchProviderIcon size={13} pid={p.id} />,
|
||||
isSelected: p.id === assistant?.webSearchProviderId,
|
||||
disabled: !WebSearchService.isWebSearchEnabled(p.id),
|
||||
action: () => updateQuickPanelItem(p.id)
|
||||
}))
|
||||
.filter((item) => !item.disabled)
|
||||
)
|
||||
}
|
||||
|
||||
if (isWebSearchModelEnabled) {
|
||||
items.unshift({
|
||||
@@ -155,15 +162,7 @@ export const useWebSearchPanelController = (assistantId: string, quickPanelContr
|
||||
}
|
||||
|
||||
return items
|
||||
}, [
|
||||
assistant.enableWebSearch,
|
||||
assistant.model,
|
||||
assistant?.webSearchProviderId,
|
||||
providers,
|
||||
t,
|
||||
updateQuickPanelItem,
|
||||
updateToModelBuiltinWebSearch
|
||||
])
|
||||
}, [assistant, providers, t, updateQuickPanelItem, updateToModelBuiltinWebSearch])
|
||||
|
||||
const openQuickPanel = useCallback(() => {
|
||||
quickPanelController.open({
|
||||
|
||||
Reference in New Issue
Block a user