From 65d89cf75998bce2a674ee7665eb425ebed524c7 Mon Sep 17 00:00:00 2001 From: ousugo Date: Thu, 13 Mar 2025 02:11:28 +0800 Subject: [PATCH] feat(GeminiProvider): Add isGemmaModel function and update model handling Introduce isGemmaModel function to identify Gemma models and adjust system instruction handling in GeminiProvider based on model type. Ensure proper message formatting for Gemma models during chat initialization. --- src/renderer/src/config/models.ts | 8 +++++++ src/renderer/src/providers/GeminiProvider.ts | 25 ++++++++++++++++++-- 2 files changed, 31 insertions(+), 2 deletions(-) diff --git a/src/renderer/src/config/models.ts b/src/renderer/src/config/models.ts index e6b5aa6a7..e7605e673 100644 --- a/src/renderer/src/config/models.ts +++ b/src/renderer/src/config/models.ts @@ -1992,3 +1992,11 @@ export function getOpenAIWebSearchParams(assistant: Assistant, model: Model): Re return {} } + +export function isGemmaModel(model?: Model): boolean { + if (!model) { + return false + } + + return model.id.includes('gemma-') || model.group === 'Gemma' +} diff --git a/src/renderer/src/providers/GeminiProvider.ts b/src/renderer/src/providers/GeminiProvider.ts index 96dcf8b3d..c6a368bfe 100644 --- a/src/renderer/src/providers/GeminiProvider.ts +++ b/src/renderer/src/providers/GeminiProvider.ts @@ -13,7 +13,7 @@ import { SafetySetting, TextPart } from '@google/generative-ai' -import { isWebSearchModel } from '@renderer/config/models' +import { isGemmaModel, isWebSearchModel } from '@renderer/config/models' import { getStoreSetting } from '@renderer/hooks/useSettings' import i18n from '@renderer/i18n' import { getAssistantSettings, getDefaultModel, getTopNamingModel } from '@renderer/services/AssistantService' @@ -205,7 +205,7 @@ export default class GeminiProvider extends BaseProvider { const geminiModel = this.sdk.getGenerativeModel( { model: model.id, - systemInstruction: assistant.prompt, + ...(isGemmaModel(model) ? {} : { systemInstruction: assistant.prompt }), safetySettings: this.getSafetySettings(model.id), tools: tools, generationConfig: { @@ -221,6 +221,27 @@ export default class GeminiProvider extends BaseProvider { const chat = geminiModel.startChat({ history }) const messageContents = await this.getMessageContents(userLastMessage!) + if (isGemmaModel(model) && assistant.prompt) { + const isFirstMessage = history.length === 0 + if (isFirstMessage) { + const systemMessage = { + role: 'user', + parts: [ + { + text: + 'user\n' + + assistant.prompt + + '\n' + + 'user\n' + + messageContents.parts[0].text + + '' + } + ] + } + messageContents.parts = systemMessage.parts + } + } + const start_time_millsec = new Date().getTime() const { abortController, cleanup } = this.createAbortController(userLastMessage?.id) const { signal } = abortController