Compare commits

..

1 Commits

Author SHA1 Message Date
dependabot[bot]
94a28f1055 ci(deps): bump actions/github-script from 7 to 8
Bumps [actions/github-script](https://github.com/actions/github-script) from 7 to 8.
- [Release notes](https://github.com/actions/github-script/releases)
- [Commits](https://github.com/actions/github-script/compare/v7...v8)

---
updated-dependencies:
- dependency-name: actions/github-script
  dependency-version: '8'
  dependency-type: direct:production
  update-type: version-update:semver-major
...

Signed-off-by: dependabot[bot] <support@github.com>
2025-12-01 00:33:53 +00:00
18 changed files with 125 additions and 259 deletions

View File

@@ -42,7 +42,7 @@ jobs:
- name: Add pending label if in quiet hours - name: Add pending label if in quiet hours
if: steps.check_time.outputs.should_delay == 'true' if: steps.check_time.outputs.should_delay == 'true'
uses: actions/github-script@v7 uses: actions/github-script@v8
with: with:
script: | script: |
github.rest.issues.addLabels({ github.rest.issues.addLabels({

View File

@@ -27,7 +27,6 @@ import { buildAiSdkMiddlewares } from './middleware/AiSdkMiddlewareBuilder'
import { buildPlugins } from './plugins/PluginBuilder' import { buildPlugins } from './plugins/PluginBuilder'
import { createAiSdkProvider } from './provider/factory' import { createAiSdkProvider } from './provider/factory'
import { import {
adaptProvider,
getActualProvider, getActualProvider,
isModernSdkSupported, isModernSdkSupported,
prepareSpecialProviderConfig, prepareSpecialProviderConfig,
@@ -65,11 +64,12 @@ export default class ModernAiProvider {
* - URL will be automatically formatted via `formatProviderApiHost`, adding version suffixes like `/v1` * - URL will be automatically formatted via `formatProviderApiHost`, adding version suffixes like `/v1`
* *
* 2. When called with `(model, provider)`: * 2. When called with `(model, provider)`:
* - The provided provider will be adapted via `adaptProvider` * - **Directly uses the provided provider WITHOUT going through `getActualProvider`**
* - URL formatting behavior depends on the adapted result * - **URL will NOT be automatically formatted, `/v1` suffix will NOT be added**
* - This is legacy behavior kept for backward compatibility
* *
* 3. When called with `(provider)`: * 3. When called with `(provider)`:
* - The provider will be adapted via `adaptProvider` * - Directly uses the provider without requiring a model
* - Used for operations that don't need a model (e.g., fetchModels) * - Used for operations that don't need a model (e.g., fetchModels)
* *
* @example * @example
@@ -77,7 +77,7 @@ export default class ModernAiProvider {
* // Recommended: Auto-format URL * // Recommended: Auto-format URL
* const ai = new ModernAiProvider(model) * const ai = new ModernAiProvider(model)
* *
* // Provider will be adapted * // Not recommended: Skip URL formatting (only for special cases)
* const ai = new ModernAiProvider(model, customProvider) * const ai = new ModernAiProvider(model, customProvider)
* *
* // For operations that don't need a model * // For operations that don't need a model
@@ -91,12 +91,12 @@ export default class ModernAiProvider {
if (this.isModel(modelOrProvider)) { if (this.isModel(modelOrProvider)) {
// 传入的是 Model // 传入的是 Model
this.model = modelOrProvider this.model = modelOrProvider
this.actualProvider = provider ? adaptProvider({ provider }) : getActualProvider(modelOrProvider) this.actualProvider = provider || getActualProvider(modelOrProvider)
// 只保存配置不预先创建executor // 只保存配置不预先创建executor
this.config = providerToAiSdkConfig(this.actualProvider, modelOrProvider) this.config = providerToAiSdkConfig(this.actualProvider, modelOrProvider)
} else { } else {
// 传入的是 Provider // 传入的是 Provider
this.actualProvider = adaptProvider({ provider: modelOrProvider }) this.actualProvider = modelOrProvider
// model为可选某些操作如fetchModels不需要model // model为可选某些操作如fetchModels不需要model
} }
@@ -120,12 +120,9 @@ export default class ModernAiProvider {
throw new Error('Model is required for completions. Please use constructor with model parameter.') throw new Error('Model is required for completions. Please use constructor with model parameter.')
} }
// Config is now set in constructor, ApiService handles key rotation before passing provider // 每次请求时重新生成配置以确保API key轮换生效
if (!this.config) { this.config = providerToAiSdkConfig(this.actualProvider, this.model)
// If config wasn't set in constructor (when provider only), generate it now logger.debug('Generated provider config for completions', this.config)
this.config = providerToAiSdkConfig(this.actualProvider, this.model!)
}
logger.debug('Using provider config for completions', this.config)
// 检查 config 是否存在 // 检查 config 是否存在
if (!this.config) { if (!this.config) {

View File

@@ -29,6 +29,32 @@ import { azureAnthropicProviderCreator } from './config/azure-anthropic'
import { COPILOT_DEFAULT_HEADERS } from './constants' import { COPILOT_DEFAULT_HEADERS } from './constants'
import { getAiSdkProviderId } from './factory' import { getAiSdkProviderId } from './factory'
/**
* 获取轮询的API key
* 复用legacy架构的多key轮询逻辑
*/
function getRotatedApiKey(provider: Provider): string {
const keys = provider.apiKey.split(',').map((key) => key.trim())
const keyName = `provider:${provider.id}:last_used_key`
if (keys.length === 1) {
return keys[0]
}
const lastUsedKey = window.keyv.get(keyName)
if (!lastUsedKey) {
window.keyv.set(keyName, keys[0])
return keys[0]
}
const currentIndex = keys.indexOf(lastUsedKey)
const nextIndex = (currentIndex + 1) % keys.length
const nextKey = keys[nextIndex]
window.keyv.set(keyName, nextKey)
return nextKey
}
/** /**
* 处理特殊provider的转换逻辑 * 处理特殊provider的转换逻辑
*/ */
@@ -52,13 +78,11 @@ function handleSpecialProviders(model: Model, provider: Provider): Provider {
} }
/** /**
* Format and normalize the API host URL for a provider. * 主要用来对齐AISdk的BaseURL格式
* Handles provider-specific URL formatting rules (e.g., appending version paths, Azure formatting). * @param provider
* * @returns
* @param provider - The provider whose API host is to be formatted.
* @returns A new provider instance with the formatted API host.
*/ */
export function formatProviderApiHost(provider: Provider): Provider { function formatProviderApiHost(provider: Provider): Provider {
const formatted = { ...provider } const formatted = { ...provider }
if (formatted.anthropicApiHost) { if (formatted.anthropicApiHost) {
formatted.anthropicApiHost = formatApiHost(formatted.anthropicApiHost) formatted.anthropicApiHost = formatApiHost(formatted.anthropicApiHost)
@@ -90,38 +114,18 @@ export function formatProviderApiHost(provider: Provider): Provider {
} }
/** /**
* Retrieve the effective Provider configuration for the given model. * 获取实际的Provider配置
* Applies all necessary transformations (special-provider handling, URL formatting, etc.). * 简化版:将逻辑分解为小函数
*
* @param model - The model whose provider is to be resolved.
* @returns A new Provider instance with all adaptations applied.
*/ */
export function getActualProvider(model: Model): Provider { export function getActualProvider(model: Model): Provider {
const baseProvider = getProviderByModel(model) const baseProvider = getProviderByModel(model)
return adaptProvider({ provider: baseProvider, model }) // 按顺序处理各种转换
} let actualProvider = cloneDeep(baseProvider)
actualProvider = handleSpecialProviders(model, actualProvider)
actualProvider = formatProviderApiHost(actualProvider)
/** return actualProvider
* Transforms a provider configuration by applying model-specific adaptations and normalizing its API host.
* The transformations are applied in the following order:
* 1. Model-specific provider handling (e.g., New-API, system providers, Azure OpenAI)
* 2. API host formatting (provider-specific URL normalization)
*
* @param provider - The base provider configuration to transform.
* @param model - The model associated with the provider; optional but required for special-provider handling.
* @returns A new Provider instance with all transformations applied.
*/
export function adaptProvider({ provider, model }: { provider: Provider; model?: Model }): Provider {
let adaptedProvider = cloneDeep(provider)
// Apply transformations in order
if (model) {
adaptedProvider = handleSpecialProviders(model, adaptedProvider)
}
adaptedProvider = formatProviderApiHost(adaptedProvider)
return adaptedProvider
} }
/** /**
@@ -135,7 +139,7 @@ export function providerToAiSdkConfig(actualProvider: Provider, model: Model): A
const { baseURL, endpoint } = routeToEndpoint(actualProvider.apiHost) const { baseURL, endpoint } = routeToEndpoint(actualProvider.apiHost)
const baseConfig = { const baseConfig = {
baseURL: baseURL, baseURL: baseURL,
apiKey: actualProvider.apiKey apiKey: getRotatedApiKey(actualProvider)
} }
const isCopilotProvider = actualProvider.id === SystemProviderIds.copilot const isCopilotProvider = actualProvider.id === SystemProviderIds.copilot

View File

@@ -4372,7 +4372,7 @@
"url": { "url": {
"preview": "Preview: {{url}}", "preview": "Preview: {{url}}",
"reset": "Reset", "reset": "Reset",
"tip": "Add # at the end to disable the automatically appended API version." "tip": "ending with # forces use of input address"
} }
}, },
"api_host": "API Host", "api_host": "API Host",

View File

@@ -4372,7 +4372,7 @@
"url": { "url": {
"preview": "预览: {{url}}", "preview": "预览: {{url}}",
"reset": "重置", "reset": "重置",
"tip": "在末尾添加 # 以禁用自动附加的API版本。" "tip": "# 结尾强制使用输入地址"
} }
}, },
"api_host": "API 地址", "api_host": "API 地址",

View File

@@ -4372,7 +4372,7 @@
"url": { "url": {
"preview": "預覽:{{url}}", "preview": "預覽:{{url}}",
"reset": "重設", "reset": "重設",
"tip": "在末尾添加 # 以停用自動附加的 API 版本。" "tip": "# 結尾強制使用輸入位址"
} }
}, },
"api_host": "API 主機地址", "api_host": "API 主機地址",

View File

@@ -4372,7 +4372,7 @@
"url": { "url": {
"preview": "Vorschau: {{url}}", "preview": "Vorschau: {{url}}",
"reset": "Zurücksetzen", "reset": "Zurücksetzen",
"tip": "Fügen Sie am Ende ein # hinzu, um die automatisch angehängte API-Version zu deaktivieren." "tip": "# am Ende erzwingt die Verwendung der Eingabe-Adresse"
} }
}, },
"api_host": "API-Adresse", "api_host": "API-Adresse",

View File

@@ -4372,7 +4372,7 @@
"url": { "url": {
"preview": "Προεπισκόπηση: {{url}}", "preview": "Προεπισκόπηση: {{url}}",
"reset": "Επαναφορά", "reset": "Επαναφορά",
"tip": "Προσθέστε το σύμβολο # στο τέλος για να απενεργοποιήσετε την αυτόματα προστιθέμενη έκδοση API." "tip": "#τέλος ενδεχόμενη χρήση της εισαγωγής διευθύνσεως"
} }
}, },
"api_host": "Διεύθυνση API", "api_host": "Διεύθυνση API",

View File

@@ -4372,7 +4372,7 @@
"url": { "url": {
"preview": "Vista previa: {{url}}", "preview": "Vista previa: {{url}}",
"reset": "Restablecer", "reset": "Restablecer",
"tip": "Añada # al final para deshabilitar la versión de la API que se añade automáticamente." "tip": "forzar uso de dirección de entrada con # al final"
} }
}, },
"api_host": "Dirección API", "api_host": "Dirección API",

View File

@@ -4372,7 +4372,7 @@
"url": { "url": {
"preview": "Aperçu : {{url}}", "preview": "Aperçu : {{url}}",
"reset": "Réinitialiser", "reset": "Réinitialiser",
"tip": "Ajoutez # à la fin pour désactiver la version d'API ajoutée automatiquement." "tip": "forcer l'utilisation de l'adresse d'entrée si terminé par #"
} }
}, },
"api_host": "Adresse API", "api_host": "Adresse API",

View File

@@ -4372,7 +4372,7 @@
"url": { "url": {
"preview": "プレビュー: {{url}}", "preview": "プレビュー: {{url}}",
"reset": "リセット", "reset": "リセット",
"tip": "自動的に付加されるAPIバージョンを無効にするには、末尾に#を追加します" "tip": "#で終わる場合、入力されたアドレスを強制的に使用します"
} }
}, },
"api_host": "APIホスト", "api_host": "APIホスト",

View File

@@ -4372,7 +4372,7 @@
"url": { "url": {
"preview": "Pré-visualização: {{url}}", "preview": "Pré-visualização: {{url}}",
"reset": "Redefinir", "reset": "Redefinir",
"tip": "Adicione # no final para desativar a versão da API adicionada automaticamente." "tip": "e forçar o uso do endereço original quando terminar com '#'"
} }
}, },
"api_host": "Endereço API", "api_host": "Endereço API",

View File

@@ -4372,7 +4372,7 @@
"url": { "url": {
"preview": "Предпросмотр: {{url}}", "preview": "Предпросмотр: {{url}}",
"reset": "Сброс", "reset": "Сброс",
"tip": "Добавьте # в конце, чтобы отключить автоматически добавляемую версию API." "tip": "заканчивая на # принудительно использует введенный адрес"
} }
}, },
"api_host": "Хост API", "api_host": "Хост API",

View File

@@ -1,10 +1,8 @@
import { adaptProvider } from '@renderer/aiCore/provider/providerConfig'
import OpenAIAlert from '@renderer/components/Alert/OpenAIAlert' import OpenAIAlert from '@renderer/components/Alert/OpenAIAlert'
import { LoadingIcon } from '@renderer/components/Icons' import { LoadingIcon } from '@renderer/components/Icons'
import { HStack } from '@renderer/components/Layout' import { HStack } from '@renderer/components/Layout'
import { ApiKeyListPopup } from '@renderer/components/Popups/ApiKeyListPopup' import { ApiKeyListPopup } from '@renderer/components/Popups/ApiKeyListPopup'
import Selector from '@renderer/components/Selector' import Selector from '@renderer/components/Selector'
import { HelpTooltip } from '@renderer/components/TooltipIcons'
import { isEmbeddingModel, isRerankModel } from '@renderer/config/models' import { isEmbeddingModel, isRerankModel } from '@renderer/config/models'
import { PROVIDER_URLS } from '@renderer/config/providers' import { PROVIDER_URLS } from '@renderer/config/providers'
import { useTheme } from '@renderer/context/ThemeProvider' import { useTheme } from '@renderer/context/ThemeProvider'
@@ -21,7 +19,14 @@ import type { SystemProviderId } from '@renderer/types'
import { isSystemProvider, isSystemProviderId, SystemProviderIds } from '@renderer/types' import { isSystemProvider, isSystemProviderId, SystemProviderIds } from '@renderer/types'
import type { ApiKeyConnectivity } from '@renderer/types/healthCheck' import type { ApiKeyConnectivity } from '@renderer/types/healthCheck'
import { HealthStatus } from '@renderer/types/healthCheck' import { HealthStatus } from '@renderer/types/healthCheck'
import { formatApiHost, formatApiKeys, getFancyProviderName, validateApiHost } from '@renderer/utils' import {
formatApiHost,
formatApiKeys,
formatAzureOpenAIApiHost,
formatVertexApiHost,
getFancyProviderName,
validateApiHost
} from '@renderer/utils'
import { formatErrorMessage } from '@renderer/utils/error' import { formatErrorMessage } from '@renderer/utils/error'
import { import {
isAIGatewayProvider, isAIGatewayProvider,
@@ -31,6 +36,7 @@ import {
isNewApiProvider, isNewApiProvider,
isOpenAICompatibleProvider, isOpenAICompatibleProvider,
isOpenAIProvider, isOpenAIProvider,
isSupportAPIVersionProvider,
isVertexProvider isVertexProvider
} from '@renderer/utils/provider' } from '@renderer/utils/provider'
import { Button, Divider, Flex, Input, Select, Space, Switch, Tooltip } from 'antd' import { Button, Divider, Flex, Input, Select, Space, Switch, Tooltip } from 'antd'
@@ -275,10 +281,12 @@ const ProviderSetting: FC<Props> = ({ providerId }) => {
}, [configuredApiHost, apiHost]) }, [configuredApiHost, apiHost])
const hostPreview = () => { const hostPreview = () => {
const formattedApiHost = adaptProvider({ provider: { ...provider, apiHost } }).apiHost if (apiHost.endsWith('#')) {
return apiHost.replace('#', '')
}
if (isOpenAICompatibleProvider(provider)) { if (isOpenAICompatibleProvider(provider)) {
return formattedApiHost + '/chat/completions' return formatApiHost(apiHost, isSupportAPIVersionProvider(provider)) + '/chat/completions'
} }
if (isAzureOpenAIProvider(provider)) { if (isAzureOpenAIProvider(provider)) {
@@ -286,26 +294,29 @@ const ProviderSetting: FC<Props> = ({ providerId }) => {
const path = !['preview', 'v1'].includes(apiVersion) const path = !['preview', 'v1'].includes(apiVersion)
? `/v1/chat/completion?apiVersion=v1` ? `/v1/chat/completion?apiVersion=v1`
: `/v1/responses?apiVersion=v1` : `/v1/responses?apiVersion=v1`
return formattedApiHost + path return formatAzureOpenAIApiHost(apiHost) + path
} }
if (isAnthropicProvider(provider)) { if (isAnthropicProvider(provider)) {
return formattedApiHost + '/messages' // AI SDK uses the baseURL with /v1, then appends /messages
// formatApiHost adds /v1 automatically if not present
const normalizedHost = formatApiHost(apiHost)
return normalizedHost + '/messages'
} }
if (isGeminiProvider(provider)) { if (isGeminiProvider(provider)) {
return formattedApiHost + '/models' return formatApiHost(apiHost, true, 'v1beta') + '/models'
} }
if (isOpenAIProvider(provider)) { if (isOpenAIProvider(provider)) {
return formattedApiHost + '/responses' return formatApiHost(apiHost) + '/responses'
} }
if (isVertexProvider(provider)) { if (isVertexProvider(provider)) {
return formattedApiHost + '/publishers/google' return formatVertexApiHost(provider) + '/publishers/google'
} }
if (isAIGatewayProvider(provider)) { if (isAIGatewayProvider(provider)) {
return formattedApiHost + '/language-model' return formatApiHost(apiHost) + '/language-model'
} }
return formattedApiHost return formatApiHost(apiHost)
} }
// API key 连通性检查状态指示器,目前仅在失败时显示 // API key 连通性检查状态指示器,目前仅在失败时显示
@@ -483,21 +494,16 @@ const ProviderSetting: FC<Props> = ({ providerId }) => {
{!isDmxapi && ( {!isDmxapi && (
<> <>
<SettingSubtitle style={{ display: 'flex', alignItems: 'center', justifyContent: 'space-between' }}> <SettingSubtitle style={{ display: 'flex', alignItems: 'center', justifyContent: 'space-between' }}>
<div className="flex items-center gap-1"> <Tooltip title={hostSelectorTooltip} mouseEnterDelay={0.3}>
<Tooltip title={hostSelectorTooltip} mouseEnterDelay={0.3}> <Selector
<div> size={14}
<Selector value={activeHostField}
size={14} onChange={(value) => setActiveHostField(value as HostField)}
value={activeHostField} options={hostSelectorOptions}
onChange={(value) => setActiveHostField(value as HostField)} style={{ paddingLeft: 1, fontWeight: 'bold' }}
options={hostSelectorOptions} placement="bottomLeft"
style={{ paddingLeft: 1, fontWeight: 'bold' }} />
placement="bottomLeft" </Tooltip>
/>
</div>
</Tooltip>
<HelpTooltip title={t('settings.provider.api.url.tip')}></HelpTooltip>
</div>
<div style={{ display: 'flex', alignItems: 'center', gap: 4 }}> <div style={{ display: 'flex', alignItems: 'center', gap: 4 }}>
<Button <Button
type="text" type="text"

View File

@@ -8,8 +8,8 @@ import { isDedicatedImageGenerationModel, isEmbeddingModel, isFunctionCallingMod
import { getStoreSetting } from '@renderer/hooks/useSettings' import { getStoreSetting } from '@renderer/hooks/useSettings'
import i18n from '@renderer/i18n' import i18n from '@renderer/i18n'
import store from '@renderer/store' import store from '@renderer/store'
import type { FetchChatCompletionParams } from '@renderer/types'
import type { Assistant, MCPServer, MCPTool, Model, Provider } from '@renderer/types' import type { Assistant, MCPServer, MCPTool, Model, Provider } from '@renderer/types'
import { type FetchChatCompletionParams, isSystemProvider } from '@renderer/types'
import type { StreamTextParams } from '@renderer/types/aiCoreTypes' import type { StreamTextParams } from '@renderer/types/aiCoreTypes'
import { type Chunk, ChunkType } from '@renderer/types/chunk' import { type Chunk, ChunkType } from '@renderer/types/chunk'
import type { Message, ResponseError } from '@renderer/types/newMessage' import type { Message, ResponseError } from '@renderer/types/newMessage'
@@ -22,8 +22,7 @@ import { purifyMarkdownImages } from '@renderer/utils/markdown'
import { isPromptToolUse, isSupportedToolUse } from '@renderer/utils/mcp-tools' import { isPromptToolUse, isSupportedToolUse } from '@renderer/utils/mcp-tools'
import { findFileBlocks, getMainTextContent } from '@renderer/utils/messageUtils/find' import { findFileBlocks, getMainTextContent } from '@renderer/utils/messageUtils/find'
import { containsSupportedVariables, replacePromptVariables } from '@renderer/utils/prompt' import { containsSupportedVariables, replacePromptVariables } from '@renderer/utils/prompt'
import { NOT_SUPPORT_API_KEY_PROVIDERS } from '@renderer/utils/provider' import { isEmpty, takeRight } from 'lodash'
import { cloneDeep, isEmpty, takeRight } from 'lodash'
import type { ModernAiProviderConfig } from '../aiCore/index_new' import type { ModernAiProviderConfig } from '../aiCore/index_new'
import AiProviderNew from '../aiCore/index_new' import AiProviderNew from '../aiCore/index_new'
@@ -44,8 +43,6 @@ import {
// } from './MessagesService' // } from './MessagesService'
// import WebSearchService from './WebSearchService' // import WebSearchService from './WebSearchService'
// FIXME: 这里太多重复逻辑,需要重构
const logger = loggerService.withContext('ApiService') const logger = loggerService.withContext('ApiService')
export async function fetchMcpTools(assistant: Assistant) { export async function fetchMcpTools(assistant: Assistant) {
@@ -98,15 +95,7 @@ export async function fetchChatCompletion({
modelId: assistant.model?.id, modelId: assistant.model?.id,
modelName: assistant.model?.name modelName: assistant.model?.name
}) })
const AI = new AiProviderNew(assistant.model || getDefaultModel())
// Get base provider and apply API key rotation
const baseProvider = getProviderByModel(assistant.model || getDefaultModel())
const providerWithRotatedKey = {
...cloneDeep(baseProvider),
apiKey: getRotatedApiKey(baseProvider)
}
const AI = new AiProviderNew(assistant.model || getDefaultModel(), providerWithRotatedKey)
const provider = AI.getActualProvider() const provider = AI.getActualProvider()
const mcpTools: MCPTool[] = [] const mcpTools: MCPTool[] = []
@@ -183,13 +172,7 @@ export async function fetchMessagesSummary({ messages, assistant }: { messages:
return null return null
} }
// Apply API key rotation const AI = new AiProviderNew(model)
const providerWithRotatedKey = {
...cloneDeep(provider),
apiKey: getRotatedApiKey(provider)
}
const AI = new AiProviderNew(model, providerWithRotatedKey)
const topicId = messages?.find((message) => message.topicId)?.topicId || '' const topicId = messages?.find((message) => message.topicId)?.topicId || ''
@@ -288,13 +271,7 @@ export async function fetchNoteSummary({ content, assistant }: { content: string
return null return null
} }
// Apply API key rotation const AI = new AiProviderNew(model)
const providerWithRotatedKey = {
...cloneDeep(provider),
apiKey: getRotatedApiKey(provider)
}
const AI = new AiProviderNew(model, providerWithRotatedKey)
// only 2000 char and no images // only 2000 char and no images
const truncatedContent = content.substring(0, 2000) const truncatedContent = content.substring(0, 2000)
@@ -382,13 +359,7 @@ export async function fetchGenerate({
return '' return ''
} }
// Apply API key rotation const AI = new AiProviderNew(model)
const providerWithRotatedKey = {
...cloneDeep(provider),
apiKey: getRotatedApiKey(provider)
}
const AI = new AiProviderNew(model, providerWithRotatedKey)
const assistant = getDefaultAssistant() const assistant = getDefaultAssistant()
assistant.model = model assistant.model = model
@@ -433,44 +404,28 @@ export async function fetchGenerate({
export function hasApiKey(provider: Provider) { export function hasApiKey(provider: Provider) {
if (!provider) return false if (!provider) return false
if (isSystemProvider(provider) && NOT_SUPPORT_API_KEY_PROVIDERS.includes(provider.id)) return true if (['ollama', 'lmstudio', 'vertexai', 'cherryai'].includes(provider.id)) return true
return !isEmpty(provider.apiKey) return !isEmpty(provider.apiKey)
} }
/** /**
* 获取轮询的API key * Get the first available embedding model from enabled providers
* 复用legacy架构的多key轮询逻辑
*/ */
function getRotatedApiKey(provider: Provider): string { // function getFirstEmbeddingModel() {
const keys = provider.apiKey.split(',').map((key) => key.trim()) // const providers = store.getState().llm.providers.filter((p) => p.enabled)
const keyName = `provider:${provider.id}:last_used_key`
if (keys.length === 1) { // for (const provider of providers) {
return keys[0] // const embeddingModel = provider.models.find((model) => isEmbeddingModel(model))
} // if (embeddingModel) {
// return embeddingModel
// }
// }
const lastUsedKey = window.keyv.get(keyName) // return undefined
if (!lastUsedKey) { // }
window.keyv.set(keyName, keys[0])
return keys[0]
}
const currentIndex = keys.indexOf(lastUsedKey)
const nextIndex = (currentIndex + 1) % keys.length
const nextKey = keys[nextIndex]
window.keyv.set(keyName, nextKey)
return nextKey
}
export async function fetchModels(provider: Provider): Promise<SdkModel[]> { export async function fetchModels(provider: Provider): Promise<SdkModel[]> {
// Apply API key rotation const AI = new AiProviderNew(provider)
const providerWithRotatedKey = {
...cloneDeep(provider),
apiKey: getRotatedApiKey(provider)
}
const AI = new AiProviderNew(providerWithRotatedKey)
try { try {
return await AI.models() return await AI.models()
@@ -480,7 +435,12 @@ export async function fetchModels(provider: Provider): Promise<SdkModel[]> {
} }
export function checkApiProvider(provider: Provider): void { export function checkApiProvider(provider: Provider): void {
if (isSystemProvider(provider) && !NOT_SUPPORT_API_KEY_PROVIDERS.includes(provider.id)) { if (
provider.id !== 'ollama' &&
provider.id !== 'lmstudio' &&
provider.type !== 'vertexai' &&
provider.id !== 'copilot'
) {
if (!provider.apiKey) { if (!provider.apiKey) {
window.toast.error(i18n.t('message.error.enter.api.label')) window.toast.error(i18n.t('message.error.enter.api.label'))
throw new Error(i18n.t('message.error.enter.api.label')) throw new Error(i18n.t('message.error.enter.api.label'))
@@ -501,7 +461,8 @@ export function checkApiProvider(provider: Provider): void {
export async function checkApi(provider: Provider, model: Model, timeout = 15000): Promise<void> { export async function checkApi(provider: Provider, model: Model, timeout = 15000): Promise<void> {
checkApiProvider(provider) checkApiProvider(provider)
const ai = new AiProviderNew(model, provider) // Don't pass in provider parameter. We need auto-format URL
const ai = new AiProviderNew(model)
const assistant = getDefaultAssistant() const assistant = getDefaultAssistant()
assistant.model = model assistant.model = model

View File

@@ -13,8 +13,7 @@ import {
routeToEndpoint, routeToEndpoint,
splitApiKeyString, splitApiKeyString,
validateApiHost, validateApiHost,
withoutTrailingApiVersion, withoutTrailingApiVersion
withoutTrailingSharp
} from '../api' } from '../api'
vi.mock('@renderer/store', () => { vi.mock('@renderer/store', () => {
@@ -82,27 +81,6 @@ describe('api', () => {
it('keeps host untouched when api version unsupported', () => { it('keeps host untouched when api version unsupported', () => {
expect(formatApiHost('https://api.example.com', false)).toBe('https://api.example.com') expect(formatApiHost('https://api.example.com', false)).toBe('https://api.example.com')
}) })
it('removes trailing # and does not append api version when host ends with #', () => {
expect(formatApiHost('https://api.example.com#')).toBe('https://api.example.com')
expect(formatApiHost('http://localhost:5173/#')).toBe('http://localhost:5173/')
expect(formatApiHost(' https://api.openai.com/# ')).toBe('https://api.openai.com/')
})
it('handles trailing # with custom api version settings', () => {
expect(formatApiHost('https://api.example.com#', true, 'v2')).toBe('https://api.example.com')
expect(formatApiHost('https://api.example.com#', false, 'v2')).toBe('https://api.example.com')
})
it('handles host with both trailing # and existing api version', () => {
expect(formatApiHost('https://api.example.com/v2#')).toBe('https://api.example.com/v2')
expect(formatApiHost('https://api.example.com/v3beta#')).toBe('https://api.example.com/v3beta')
})
it('trims whitespace before processing trailing #', () => {
expect(formatApiHost(' https://api.example.com# ')).toBe('https://api.example.com')
expect(formatApiHost('\thttps://api.example.com#\n')).toBe('https://api.example.com')
})
}) })
describe('hasAPIVersion', () => { describe('hasAPIVersion', () => {
@@ -426,56 +404,4 @@ describe('api', () => {
expect(withoutTrailingApiVersion('')).toBe('') expect(withoutTrailingApiVersion('')).toBe('')
}) })
}) })
describe('withoutTrailingSharp', () => {
it('removes trailing # from URL', () => {
expect(withoutTrailingSharp('https://api.example.com#')).toBe('https://api.example.com')
expect(withoutTrailingSharp('http://localhost:3000#')).toBe('http://localhost:3000')
})
it('returns URL unchanged when no trailing #', () => {
expect(withoutTrailingSharp('https://api.example.com')).toBe('https://api.example.com')
expect(withoutTrailingSharp('http://localhost:3000')).toBe('http://localhost:3000')
})
it('handles URLs with multiple # characters but only removes trailing one', () => {
expect(withoutTrailingSharp('https://api.example.com#path#')).toBe('https://api.example.com#path')
})
it('handles URLs with # in the middle (not trailing)', () => {
expect(withoutTrailingSharp('https://api.example.com#section/path')).toBe('https://api.example.com#section/path')
expect(withoutTrailingSharp('https://api.example.com/v1/chat/completions#')).toBe(
'https://api.example.com/v1/chat/completions'
)
})
it('handles empty string', () => {
expect(withoutTrailingSharp('')).toBe('')
})
it('handles single character #', () => {
expect(withoutTrailingSharp('#')).toBe('')
})
it('preserves whitespace around the URL (pure function)', () => {
expect(withoutTrailingSharp(' https://api.example.com# ')).toBe(' https://api.example.com# ')
expect(withoutTrailingSharp('\thttps://api.example.com#\n')).toBe('\thttps://api.example.com#\n')
})
it('only removes exact trailing # character', () => {
expect(withoutTrailingSharp('https://api.example.com# ')).toBe('https://api.example.com# ')
expect(withoutTrailingSharp(' https://api.example.com#')).toBe(' https://api.example.com')
expect(withoutTrailingSharp('https://api.example.com#\t')).toBe('https://api.example.com#\t')
})
it('handles URLs ending with multiple # characters', () => {
expect(withoutTrailingSharp('https://api.example.com##')).toBe('https://api.example.com#')
expect(withoutTrailingSharp('https://api.example.com###')).toBe('https://api.example.com##')
})
it('preserves URL with trailing # and other content', () => {
expect(withoutTrailingSharp('https://api.example.com/v1#')).toBe('https://api.example.com/v1')
expect(withoutTrailingSharp('https://api.example.com/v2beta#')).toBe('https://api.example.com/v2beta')
})
})
}) })

View File

@@ -62,23 +62,6 @@ export function withoutTrailingSlash<T extends string>(url: T): T {
return url.replace(/\/$/, '') as T return url.replace(/\/$/, '') as T
} }
/**
* Removes the trailing '#' from a URL string if it exists.
*
* @template T - The string type to preserve type safety
* @param {T} url - The URL string to process
* @returns {T} The URL string without a trailing '#'
*
* @example
* ```ts
* withoutTrailingSharp('https://example.com#') // 'https://example.com'
* withoutTrailingSharp('https://example.com') // 'https://example.com'
* ```
*/
export function withoutTrailingSharp<T extends string>(url: T): T {
return url.replace(/#$/, '') as T
}
/** /**
* Formats an API host URL by normalizing it and optionally appending an API version. * Formats an API host URL by normalizing it and optionally appending an API version.
* *
@@ -87,12 +70,12 @@ export function withoutTrailingSharp<T extends string>(url: T): T {
* @param apiVersion - The API version to append if needed. Defaults to `'v1'`. * @param apiVersion - The API version to append if needed. Defaults to `'v1'`.
* *
* @returns The formatted API host URL. If the host is empty after normalization, returns an empty string. * @returns The formatted API host URL. If the host is empty after normalization, returns an empty string.
* If the host ends with '#', API version is not supported, or the host already contains a version, returns the normalized host with trailing '#' removed. * If the host ends with '#', API version is not supported, or the host already contains a version, returns the normalized host as-is.
* Otherwise, returns the host with the API version appended. * Otherwise, returns the host with the API version appended.
* *
* @example * @example
* formatApiHost('https://api.example.com/') // Returns 'https://api.example.com/v1' * formatApiHost('https://api.example.com/') // Returns 'https://api.example.com/v1'
* formatApiHost('https://api.example.com#') // Returns 'https://api.example.com' * formatApiHost('https://api.example.com#') // Returns 'https://api.example.com#'
* formatApiHost('https://api.example.com/v2', true, 'v1') // Returns 'https://api.example.com/v2' * formatApiHost('https://api.example.com/v2', true, 'v1') // Returns 'https://api.example.com/v2'
*/ */
export function formatApiHost(host?: string, supportApiVersion: boolean = true, apiVersion: string = 'v1'): string { export function formatApiHost(host?: string, supportApiVersion: boolean = true, apiVersion: string = 'v1'): string {
@@ -101,13 +84,10 @@ export function formatApiHost(host?: string, supportApiVersion: boolean = true,
return '' return ''
} }
const shouldAppendApiVersion = !(normalizedHost.endsWith('#') || !supportApiVersion || hasAPIVersion(normalizedHost)) if (normalizedHost.endsWith('#') || !supportApiVersion || hasAPIVersion(normalizedHost)) {
return normalizedHost
if (shouldAppendApiVersion) {
return `${normalizedHost}/${apiVersion}`
} else {
return withoutTrailingSharp(normalizedHost)
} }
return `${normalizedHost}/${apiVersion}`
} }
/** /**

View File

@@ -183,11 +183,3 @@ export const isSupportAPIVersionProvider = (provider: Provider) => {
} }
return provider.apiOptions?.isNotSupportAPIVersion !== false return provider.apiOptions?.isNotSupportAPIVersion !== false
} }
export const NOT_SUPPORT_API_KEY_PROVIDERS: readonly SystemProviderId[] = [
'ollama',
'lmstudio',
'vertexai',
'aws-bedrock',
'copilot'
]