Compare commits

..

3 Commits

Author SHA1 Message Date
suyao
7f8d0b06ee Merge branch 'main' into fix/check-api-key 2025-12-01 16:37:43 +08:00
suyao
4be5fedeec fix 2025-12-01 00:07:43 +08:00
suyao
163e016759 fix: enhance provider handling and API key rotation logic in AiProvider 2025-12-01 00:01:01 +08:00
9 changed files with 95 additions and 127 deletions

View File

@@ -162,7 +162,7 @@
"@langchain/core": "patch:@langchain/core@npm%3A1.0.2#~/.yarn/patches/@langchain-core-npm-1.0.2-183ef83fe4.patch", "@langchain/core": "patch:@langchain/core@npm%3A1.0.2#~/.yarn/patches/@langchain-core-npm-1.0.2-183ef83fe4.patch",
"@langchain/openai": "patch:@langchain/openai@npm%3A1.0.0#~/.yarn/patches/@langchain-openai-npm-1.0.0-474d0ad9d4.patch", "@langchain/openai": "patch:@langchain/openai@npm%3A1.0.0#~/.yarn/patches/@langchain-openai-npm-1.0.0-474d0ad9d4.patch",
"@mistralai/mistralai": "^1.7.5", "@mistralai/mistralai": "^1.7.5",
"@modelcontextprotocol/sdk": "^1.23.0", "@modelcontextprotocol/sdk": "^1.17.5",
"@mozilla/readability": "^0.6.0", "@mozilla/readability": "^0.6.0",
"@notionhq/client": "^2.2.15", "@notionhq/client": "^2.2.15",
"@openrouter/ai-sdk-provider": "^1.2.8", "@openrouter/ai-sdk-provider": "^1.2.8",

View File

@@ -42,14 +42,11 @@ import {
type MCPPrompt, type MCPPrompt,
type MCPResource, type MCPResource,
type MCPServer, type MCPServer,
type MCPTool, type MCPTool
MCPToolInputSchema,
MCPToolOutputSchema
} from '@types' } from '@types'
import { app, net } from 'electron' import { app, net } from 'electron'
import { EventEmitter } from 'events' import { EventEmitter } from 'events'
import { v4 as uuidv4 } from 'uuid' import { v4 as uuidv4 } from 'uuid'
import * as z from 'zod'
import { CacheService } from './CacheService' import { CacheService } from './CacheService'
import DxtService from './DxtService' import DxtService from './DxtService'
@@ -623,8 +620,6 @@ class McpService {
tools.map((tool: SDKTool) => { tools.map((tool: SDKTool) => {
const serverTool: MCPTool = { const serverTool: MCPTool = {
...tool, ...tool,
inputSchema: z.parse(MCPToolInputSchema, tool.inputSchema),
outputSchema: tool.outputSchema ? z.parse(MCPToolOutputSchema, tool.outputSchema) : undefined,
id: buildFunctionCallToolName(server.name, tool.name, server.id), id: buildFunctionCallToolName(server.name, tool.name, server.id),
serverId: server.id, serverId: server.id,
serverName: server.name, serverName: server.name,

View File

@@ -120,9 +120,12 @@ export default class ModernAiProvider {
throw new Error('Model is required for completions. Please use constructor with model parameter.') throw new Error('Model is required for completions. Please use constructor with model parameter.')
} }
// 每次请求时重新生成配置以确保API key轮换生效 // Config is now set in constructor, ApiService handles key rotation before passing provider
this.config = providerToAiSdkConfig(this.actualProvider, this.model) if (!this.config) {
logger.debug('Generated provider config for completions', this.config) // If config wasn't set in constructor (when provider only), generate it now
this.config = providerToAiSdkConfig(this.actualProvider, this.model!)
}
logger.debug('Using provider config for completions', this.config)
// 检查 config 是否存在 // 检查 config 是否存在
if (!this.config) { if (!this.config) {

View File

@@ -29,32 +29,6 @@ import { azureAnthropicProviderCreator } from './config/azure-anthropic'
import { COPILOT_DEFAULT_HEADERS } from './constants' import { COPILOT_DEFAULT_HEADERS } from './constants'
import { getAiSdkProviderId } from './factory' import { getAiSdkProviderId } from './factory'
/**
* 获取轮询的API key
* 复用legacy架构的多key轮询逻辑
*/
function getRotatedApiKey(provider: Provider): string {
const keys = provider.apiKey.split(',').map((key) => key.trim())
const keyName = `provider:${provider.id}:last_used_key`
if (keys.length === 1) {
return keys[0]
}
const lastUsedKey = window.keyv.get(keyName)
if (!lastUsedKey) {
window.keyv.set(keyName, keys[0])
return keys[0]
}
const currentIndex = keys.indexOf(lastUsedKey)
const nextIndex = (currentIndex + 1) % keys.length
const nextKey = keys[nextIndex]
window.keyv.set(keyName, nextKey)
return nextKey
}
/** /**
* 处理特殊provider的转换逻辑 * 处理特殊provider的转换逻辑
*/ */
@@ -161,7 +135,7 @@ export function providerToAiSdkConfig(actualProvider: Provider, model: Model): A
const { baseURL, endpoint } = routeToEndpoint(actualProvider.apiHost) const { baseURL, endpoint } = routeToEndpoint(actualProvider.apiHost)
const baseConfig = { const baseConfig = {
baseURL: baseURL, baseURL: baseURL,
apiKey: getRotatedApiKey(actualProvider) apiKey: actualProvider.apiKey
} }
const isCopilotProvider = actualProvider.id === SystemProviderIds.copilot const isCopilotProvider = actualProvider.id === SystemProviderIds.copilot

View File

@@ -8,8 +8,8 @@ import { isDedicatedImageGenerationModel, isEmbeddingModel, isFunctionCallingMod
import { getStoreSetting } from '@renderer/hooks/useSettings' import { getStoreSetting } from '@renderer/hooks/useSettings'
import i18n from '@renderer/i18n' import i18n from '@renderer/i18n'
import store from '@renderer/store' import store from '@renderer/store'
import type { FetchChatCompletionParams } from '@renderer/types'
import type { Assistant, MCPServer, MCPTool, Model, Provider } from '@renderer/types' import type { Assistant, MCPServer, MCPTool, Model, Provider } from '@renderer/types'
import { type FetchChatCompletionParams, isSystemProvider } from '@renderer/types'
import type { StreamTextParams } from '@renderer/types/aiCoreTypes' import type { StreamTextParams } from '@renderer/types/aiCoreTypes'
import { type Chunk, ChunkType } from '@renderer/types/chunk' import { type Chunk, ChunkType } from '@renderer/types/chunk'
import type { Message, ResponseError } from '@renderer/types/newMessage' import type { Message, ResponseError } from '@renderer/types/newMessage'
@@ -22,7 +22,8 @@ import { purifyMarkdownImages } from '@renderer/utils/markdown'
import { isPromptToolUse, isSupportedToolUse } from '@renderer/utils/mcp-tools' import { isPromptToolUse, isSupportedToolUse } from '@renderer/utils/mcp-tools'
import { findFileBlocks, getMainTextContent } from '@renderer/utils/messageUtils/find' import { findFileBlocks, getMainTextContent } from '@renderer/utils/messageUtils/find'
import { containsSupportedVariables, replacePromptVariables } from '@renderer/utils/prompt' import { containsSupportedVariables, replacePromptVariables } from '@renderer/utils/prompt'
import { isEmpty, takeRight } from 'lodash' import { NOT_SUPPORT_API_KEY_PROVIDERS } from '@renderer/utils/provider'
import { cloneDeep, isEmpty, takeRight } from 'lodash'
import type { ModernAiProviderConfig } from '../aiCore/index_new' import type { ModernAiProviderConfig } from '../aiCore/index_new'
import AiProviderNew from '../aiCore/index_new' import AiProviderNew from '../aiCore/index_new'
@@ -43,6 +44,8 @@ import {
// } from './MessagesService' // } from './MessagesService'
// import WebSearchService from './WebSearchService' // import WebSearchService from './WebSearchService'
// FIXME: 这里太多重复逻辑,需要重构
const logger = loggerService.withContext('ApiService') const logger = loggerService.withContext('ApiService')
export async function fetchMcpTools(assistant: Assistant) { export async function fetchMcpTools(assistant: Assistant) {
@@ -95,7 +98,15 @@ export async function fetchChatCompletion({
modelId: assistant.model?.id, modelId: assistant.model?.id,
modelName: assistant.model?.name modelName: assistant.model?.name
}) })
const AI = new AiProviderNew(assistant.model || getDefaultModel())
// Get base provider and apply API key rotation
const baseProvider = getProviderByModel(assistant.model || getDefaultModel())
const providerWithRotatedKey = {
...cloneDeep(baseProvider),
apiKey: getRotatedApiKey(baseProvider)
}
const AI = new AiProviderNew(assistant.model || getDefaultModel(), providerWithRotatedKey)
const provider = AI.getActualProvider() const provider = AI.getActualProvider()
const mcpTools: MCPTool[] = [] const mcpTools: MCPTool[] = []
@@ -172,7 +183,13 @@ export async function fetchMessagesSummary({ messages, assistant }: { messages:
return null return null
} }
const AI = new AiProviderNew(model) // Apply API key rotation
const providerWithRotatedKey = {
...cloneDeep(provider),
apiKey: getRotatedApiKey(provider)
}
const AI = new AiProviderNew(model, providerWithRotatedKey)
const topicId = messages?.find((message) => message.topicId)?.topicId || '' const topicId = messages?.find((message) => message.topicId)?.topicId || ''
@@ -271,7 +288,13 @@ export async function fetchNoteSummary({ content, assistant }: { content: string
return null return null
} }
const AI = new AiProviderNew(model) // Apply API key rotation
const providerWithRotatedKey = {
...cloneDeep(provider),
apiKey: getRotatedApiKey(provider)
}
const AI = new AiProviderNew(model, providerWithRotatedKey)
// only 2000 char and no images // only 2000 char and no images
const truncatedContent = content.substring(0, 2000) const truncatedContent = content.substring(0, 2000)
@@ -359,7 +382,13 @@ export async function fetchGenerate({
return '' return ''
} }
const AI = new AiProviderNew(model) // Apply API key rotation
const providerWithRotatedKey = {
...cloneDeep(provider),
apiKey: getRotatedApiKey(provider)
}
const AI = new AiProviderNew(model, providerWithRotatedKey)
const assistant = getDefaultAssistant() const assistant = getDefaultAssistant()
assistant.model = model assistant.model = model
@@ -404,28 +433,44 @@ export async function fetchGenerate({
export function hasApiKey(provider: Provider) { export function hasApiKey(provider: Provider) {
if (!provider) return false if (!provider) return false
if (['ollama', 'lmstudio', 'vertexai', 'cherryai'].includes(provider.id)) return true if (isSystemProvider(provider) && NOT_SUPPORT_API_KEY_PROVIDERS.includes(provider.id)) return true
return !isEmpty(provider.apiKey) return !isEmpty(provider.apiKey)
} }
/** /**
* Get the first available embedding model from enabled providers * 获取轮询的API key
* 复用legacy架构的多key轮询逻辑
*/ */
// function getFirstEmbeddingModel() { function getRotatedApiKey(provider: Provider): string {
// const providers = store.getState().llm.providers.filter((p) => p.enabled) const keys = provider.apiKey.split(',').map((key) => key.trim())
const keyName = `provider:${provider.id}:last_used_key`
// for (const provider of providers) { if (keys.length === 1) {
// const embeddingModel = provider.models.find((model) => isEmbeddingModel(model)) return keys[0]
// if (embeddingModel) { }
// return embeddingModel
// }
// }
// return undefined const lastUsedKey = window.keyv.get(keyName)
// } if (!lastUsedKey) {
window.keyv.set(keyName, keys[0])
return keys[0]
}
const currentIndex = keys.indexOf(lastUsedKey)
const nextIndex = (currentIndex + 1) % keys.length
const nextKey = keys[nextIndex]
window.keyv.set(keyName, nextKey)
return nextKey
}
export async function fetchModels(provider: Provider): Promise<SdkModel[]> { export async function fetchModels(provider: Provider): Promise<SdkModel[]> {
const AI = new AiProviderNew(provider) // Apply API key rotation
const providerWithRotatedKey = {
...cloneDeep(provider),
apiKey: getRotatedApiKey(provider)
}
const AI = new AiProviderNew(providerWithRotatedKey)
try { try {
return await AI.models() return await AI.models()
@@ -435,12 +480,7 @@ export async function fetchModels(provider: Provider): Promise<SdkModel[]> {
} }
export function checkApiProvider(provider: Provider): void { export function checkApiProvider(provider: Provider): void {
if ( if (isSystemProvider(provider) && !NOT_SUPPORT_API_KEY_PROVIDERS.includes(provider.id)) {
provider.id !== 'ollama' &&
provider.id !== 'lmstudio' &&
provider.type !== 'vertexai' &&
provider.id !== 'copilot'
) {
if (!provider.apiKey) { if (!provider.apiKey) {
window.toast.error(i18n.t('message.error.enter.api.label')) window.toast.error(i18n.t('message.error.enter.api.label'))
throw new Error(i18n.t('message.error.enter.api.label')) throw new Error(i18n.t('message.error.enter.api.label'))
@@ -461,8 +501,7 @@ export function checkApiProvider(provider: Provider): void {
export async function checkApi(provider: Provider, model: Model, timeout = 15000): Promise<void> { export async function checkApi(provider: Provider, model: Model, timeout = 15000): Promise<void> {
checkApiProvider(provider) checkApiProvider(provider)
// Don't pass in provider parameter. We need auto-format URL const ai = new AiProviderNew(model, provider)
const ai = new AiProviderNew(model)
const assistant = getDefaultAssistant() const assistant = getDefaultAssistant()
assistant.model = model assistant.model = model

View File

@@ -34,15 +34,6 @@ export const MCPToolInputSchema = z
required: z.array(z.string()).optional() required: z.array(z.string()).optional()
}) })
.loose() .loose()
.transform((schema) => {
if (!schema.properties) {
schema.properties = {}
}
if (!schema.required) {
schema.required = []
}
return schema
})
export interface BuiltinTool extends BaseTool { export interface BuiltinTool extends BaseTool {
inputSchema: z.infer<typeof MCPToolInputSchema> inputSchema: z.infer<typeof MCPToolInputSchema>

View File

@@ -136,10 +136,7 @@ export async function callMCPTool(
topicId?: string, topicId?: string,
modelName?: string modelName?: string
): Promise<MCPCallToolResponse> { ): Promise<MCPCallToolResponse> {
logger.info( logger.info(`Calling Tool: ${toolResponse.tool.serverName} ${toolResponse.tool.name}`, toolResponse.tool)
`Calling Tool: ${toolResponse.id} ${toolResponse.tool.serverName} ${toolResponse.tool.name}`,
toolResponse.tool
)
try { try {
const server = getMcpServerByTool(toolResponse.tool) const server = getMcpServerByTool(toolResponse.tool)

View File

@@ -183,3 +183,11 @@ export const isSupportAPIVersionProvider = (provider: Provider) => {
} }
return provider.apiOptions?.isNotSupportAPIVersion !== false return provider.apiOptions?.isNotSupportAPIVersion !== false
} }
export const NOT_SUPPORT_API_KEY_PROVIDERS: readonly SystemProviderId[] = [
'ollama',
'lmstudio',
'vertexai',
'aws-bedrock',
'copilot'
]

View File

@@ -4747,12 +4747,11 @@ __metadata:
languageName: node languageName: node
linkType: hard linkType: hard
"@modelcontextprotocol/sdk@npm:^1.23.0": "@modelcontextprotocol/sdk@npm:^1.17.5":
version: 1.23.0 version: 1.17.5
resolution: "@modelcontextprotocol/sdk@npm:1.23.0" resolution: "@modelcontextprotocol/sdk@npm:1.17.5"
dependencies: dependencies:
ajv: "npm:^8.17.1" ajv: "npm:^6.12.6"
ajv-formats: "npm:^3.0.1"
content-type: "npm:^1.0.5" content-type: "npm:^1.0.5"
cors: "npm:^2.8.5" cors: "npm:^2.8.5"
cross-spawn: "npm:^7.0.5" cross-spawn: "npm:^7.0.5"
@@ -4762,17 +4761,9 @@ __metadata:
express-rate-limit: "npm:^7.5.0" express-rate-limit: "npm:^7.5.0"
pkce-challenge: "npm:^5.0.0" pkce-challenge: "npm:^5.0.0"
raw-body: "npm:^3.0.0" raw-body: "npm:^3.0.0"
zod: "npm:^3.25 || ^4.0" zod: "npm:^3.23.8"
zod-to-json-schema: "npm:^3.25.0" zod-to-json-schema: "npm:^3.24.1"
peerDependencies: checksum: 10c0/182b92b5e7c07da428fd23c6de22021c4f9a91f799c02a8ef15def07e4f9361d0fc22303548658fec2a700623535fd44a9dc4d010fb5d803a8f80e3c6c64a45e
"@cfworker/json-schema": ^4.1.1
zod: ^3.25 || ^4.0
peerDependenciesMeta:
"@cfworker/json-schema":
optional: true
zod:
optional: false
checksum: 10c0/b0291f921ad9bda06bbf1a61b1bb61ceca1173da5d74d39a411c40428d6ca50a95f0de3a1631f25a44b439220b15c30c1306600bf48bef665ab7ad118d528260
languageName: node languageName: node
linkType: hard linkType: hard
@@ -10055,7 +10046,7 @@ __metadata:
"@libsql/client": "npm:0.14.0" "@libsql/client": "npm:0.14.0"
"@libsql/win32-x64-msvc": "npm:^0.4.7" "@libsql/win32-x64-msvc": "npm:^0.4.7"
"@mistralai/mistralai": "npm:^1.7.5" "@mistralai/mistralai": "npm:^1.7.5"
"@modelcontextprotocol/sdk": "npm:^1.23.0" "@modelcontextprotocol/sdk": "npm:^1.17.5"
"@mozilla/readability": "npm:^0.6.0" "@mozilla/readability": "npm:^0.6.0"
"@napi-rs/system-ocr": "patch:@napi-rs/system-ocr@npm%3A1.0.2#~/.yarn/patches/@napi-rs-system-ocr-npm-1.0.2-59e7a78e8b.patch" "@napi-rs/system-ocr": "patch:@napi-rs/system-ocr@npm%3A1.0.2#~/.yarn/patches/@napi-rs-system-ocr-npm-1.0.2-59e7a78e8b.patch"
"@notionhq/client": "npm:^2.2.15" "@notionhq/client": "npm:^2.2.15"
@@ -10412,20 +10403,6 @@ __metadata:
languageName: node languageName: node
linkType: hard linkType: hard
"ajv-formats@npm:^3.0.1":
version: 3.0.1
resolution: "ajv-formats@npm:3.0.1"
dependencies:
ajv: "npm:^8.0.0"
peerDependencies:
ajv: ^8.0.0
peerDependenciesMeta:
ajv:
optional: true
checksum: 10c0/168d6bca1ea9f163b41c8147bae537e67bd963357a5488a1eaf3abe8baa8eec806d4e45f15b10767e6020679315c7e1e5e6803088dfb84efa2b4e9353b83dd0a
languageName: node
linkType: hard
"ajv-keywords@npm:^3.4.1": "ajv-keywords@npm:^3.4.1":
version: 3.5.2 version: 3.5.2
resolution: "ajv-keywords@npm:3.5.2" resolution: "ajv-keywords@npm:3.5.2"
@@ -10435,7 +10412,7 @@ __metadata:
languageName: node languageName: node
linkType: hard linkType: hard
"ajv@npm:^6.10.0, ajv@npm:^6.12.0, ajv@npm:^6.12.4": "ajv@npm:^6.10.0, ajv@npm:^6.12.0, ajv@npm:^6.12.4, ajv@npm:^6.12.6":
version: 6.12.6 version: 6.12.6
resolution: "ajv@npm:6.12.6" resolution: "ajv@npm:6.12.6"
dependencies: dependencies:
@@ -10447,7 +10424,7 @@ __metadata:
languageName: node languageName: node
linkType: hard linkType: hard
"ajv@npm:^8.0.0, ajv@npm:^8.17.1, ajv@npm:^8.6.3": "ajv@npm:^8.0.0, ajv@npm:^8.6.3":
version: 8.17.1 version: 8.17.1
resolution: "ajv@npm:8.17.1" resolution: "ajv@npm:8.17.1"
dependencies: dependencies:
@@ -26376,15 +26353,6 @@ __metadata:
languageName: node languageName: node
linkType: hard linkType: hard
"zod-to-json-schema@npm:^3.25.0":
version: 3.25.0
resolution: "zod-to-json-schema@npm:3.25.0"
peerDependencies:
zod: ^3.25 || ^4
checksum: 10c0/2d2cf6ca49752bf3dc5fb37bc8f275eddbbc4020e7958d9c198ea88cd197a5f527459118188a0081b889da6a6474d64c4134cd60951fa70178c125138761c680
languageName: node
linkType: hard
"zod-validation-error@npm:^3.4.0": "zod-validation-error@npm:^3.4.0":
version: 3.4.0 version: 3.4.0
resolution: "zod-validation-error@npm:3.4.0" resolution: "zod-validation-error@npm:3.4.0"
@@ -26394,20 +26362,13 @@ __metadata:
languageName: node languageName: node
linkType: hard linkType: hard
"zod@npm:^3.22.4, zod@npm:^3.24.1": "zod@npm:^3.22.4, zod@npm:^3.23.8, zod@npm:^3.24.1":
version: 3.25.56 version: 3.25.56
resolution: "zod@npm:3.25.56" resolution: "zod@npm:3.25.56"
checksum: 10c0/3800f01d4b1df932b91354eb1e648f69cc7e5561549e6d2bf83827d930a5f33bbf92926099445f6fc1ebb64ca9c6513ef9ae5e5409cfef6325f354bcf6fc9a24 checksum: 10c0/3800f01d4b1df932b91354eb1e648f69cc7e5561549e6d2bf83827d930a5f33bbf92926099445f6fc1ebb64ca9c6513ef9ae5e5409cfef6325f354bcf6fc9a24
languageName: node languageName: node
linkType: hard linkType: hard
"zod@npm:^3.25 || ^4.0":
version: 4.1.13
resolution: "zod@npm:4.1.13"
checksum: 10c0/d7e74e82dba81a91ffc3239cd85bc034abe193a28f7087a94ab258a3e48e9a7ca4141920cac979a0d781495b48fc547777394149f26be04c3dc642f58bbc3941
languageName: node
linkType: hard
"zod@npm:^3.25.0 || ^4.0.0, zod@npm:^3.25.76 || ^4": "zod@npm:^3.25.0 || ^4.0.0, zod@npm:^3.25.76 || ^4":
version: 4.1.12 version: 4.1.12
resolution: "zod@npm:4.1.12" resolution: "zod@npm:4.1.12"