Compare commits
1 Commits
fix/check-
...
fix/valida
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b3cd1edfdc |
@@ -134,9 +134,9 @@ artifactBuildCompleted: scripts/artifact-build-completed.js
|
||||
releaseInfo:
|
||||
releaseNotes: |
|
||||
<!--LANG:en-->
|
||||
A New Era of Intelligence with Cherry Studio 1.7.1
|
||||
A New Era of Intelligence with Cherry Studio 1.7.0
|
||||
|
||||
Today we're releasing Cherry Studio 1.7.1 — our most ambitious update yet, introducing Agent: autonomous AI that thinks, plans, and acts.
|
||||
Today we're releasing Cherry Studio 1.7.0 — our most ambitious update yet, introducing Agent: autonomous AI that thinks, plans, and acts.
|
||||
|
||||
For years, AI assistants have been reactive — waiting for your commands, responding to your questions. With Agent, we're changing that. Now, AI can truly work alongside you: understanding complex goals, breaking them into steps, and executing them independently.
|
||||
|
||||
@@ -187,9 +187,9 @@ releaseInfo:
|
||||
The Agent Era is here. We can't wait to see what you'll create.
|
||||
|
||||
<!--LANG:zh-CN-->
|
||||
Cherry Studio 1.7.1:开启智能新纪元
|
||||
Cherry Studio 1.7.0:开启智能新纪元
|
||||
|
||||
今天,我们正式发布 Cherry Studio 1.7.1 —— 迄今最具雄心的版本,带来全新的 Agent:能够自主思考、规划和行动的 AI。
|
||||
今天,我们正式发布 Cherry Studio 1.7.0 —— 迄今最具雄心的版本,带来全新的 Agent:能够自主思考、规划和行动的 AI。
|
||||
|
||||
多年来,AI 助手一直是被动的——等待你的指令,回应你的问题。Agent 改变了这一切。现在,AI 能够真正与你并肩工作:理解复杂目标,将其拆解为步骤,并独立执行。
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "CherryStudio",
|
||||
"version": "1.7.1",
|
||||
"version": "1.7.0",
|
||||
"private": true,
|
||||
"description": "A powerful AI assistant for producer.",
|
||||
"main": "./out/main/index.js",
|
||||
|
||||
@@ -7,10 +7,10 @@
|
||||
* 2. 暂时保持接口兼容性
|
||||
*/
|
||||
|
||||
import type { GatewayLanguageModelEntry } from '@ai-sdk/gateway'
|
||||
import { createExecutor } from '@cherrystudio/ai-core'
|
||||
import { loggerService } from '@logger'
|
||||
import { getEnableDeveloperMode } from '@renderer/hooks/useSettings'
|
||||
import { normalizeGatewayModels, normalizeSdkModels } from '@renderer/services/models/ModelAdapter'
|
||||
import { addSpan, endSpan } from '@renderer/services/SpanManagerService'
|
||||
import type { StartSpanParams } from '@renderer/trace/types/ModelSpanEntity'
|
||||
import { type Assistant, type GenerateImageParams, type Model, type Provider, SystemProviderIds } from '@renderer/types'
|
||||
@@ -27,7 +27,6 @@ import { buildAiSdkMiddlewares } from './middleware/AiSdkMiddlewareBuilder'
|
||||
import { buildPlugins } from './plugins/PluginBuilder'
|
||||
import { createAiSdkProvider } from './provider/factory'
|
||||
import {
|
||||
adaptProvider,
|
||||
getActualProvider,
|
||||
isModernSdkSupported,
|
||||
prepareSpecialProviderConfig,
|
||||
@@ -65,11 +64,12 @@ export default class ModernAiProvider {
|
||||
* - URL will be automatically formatted via `formatProviderApiHost`, adding version suffixes like `/v1`
|
||||
*
|
||||
* 2. When called with `(model, provider)`:
|
||||
* - The provided provider will be adapted via `adaptProvider`
|
||||
* - URL formatting behavior depends on the adapted result
|
||||
* - **Directly uses the provided provider WITHOUT going through `getActualProvider`**
|
||||
* - **URL will NOT be automatically formatted, `/v1` suffix will NOT be added**
|
||||
* - This is legacy behavior kept for backward compatibility
|
||||
*
|
||||
* 3. When called with `(provider)`:
|
||||
* - The provider will be adapted via `adaptProvider`
|
||||
* - Directly uses the provider without requiring a model
|
||||
* - Used for operations that don't need a model (e.g., fetchModels)
|
||||
*
|
||||
* @example
|
||||
@@ -77,7 +77,7 @@ export default class ModernAiProvider {
|
||||
* // Recommended: Auto-format URL
|
||||
* const ai = new ModernAiProvider(model)
|
||||
*
|
||||
* // Provider will be adapted
|
||||
* // Not recommended: Skip URL formatting (only for special cases)
|
||||
* const ai = new ModernAiProvider(model, customProvider)
|
||||
*
|
||||
* // For operations that don't need a model
|
||||
@@ -91,12 +91,12 @@ export default class ModernAiProvider {
|
||||
if (this.isModel(modelOrProvider)) {
|
||||
// 传入的是 Model
|
||||
this.model = modelOrProvider
|
||||
this.actualProvider = provider ? adaptProvider({ provider }) : getActualProvider(modelOrProvider)
|
||||
this.actualProvider = provider || getActualProvider(modelOrProvider)
|
||||
// 只保存配置,不预先创建executor
|
||||
this.config = providerToAiSdkConfig(this.actualProvider, modelOrProvider)
|
||||
} else {
|
||||
// 传入的是 Provider
|
||||
this.actualProvider = adaptProvider({ provider: modelOrProvider })
|
||||
this.actualProvider = modelOrProvider
|
||||
// model为可选,某些操作(如fetchModels)不需要model
|
||||
}
|
||||
|
||||
@@ -120,12 +120,9 @@ export default class ModernAiProvider {
|
||||
throw new Error('Model is required for completions. Please use constructor with model parameter.')
|
||||
}
|
||||
|
||||
// Config is now set in constructor, ApiService handles key rotation before passing provider
|
||||
if (!this.config) {
|
||||
// If config wasn't set in constructor (when provider only), generate it now
|
||||
this.config = providerToAiSdkConfig(this.actualProvider, this.model!)
|
||||
}
|
||||
logger.debug('Using provider config for completions', this.config)
|
||||
// 每次请求时重新生成配置以确保API key轮换生效
|
||||
this.config = providerToAiSdkConfig(this.actualProvider, this.model)
|
||||
logger.debug('Generated provider config for completions', this.config)
|
||||
|
||||
// 检查 config 是否存在
|
||||
if (!this.config) {
|
||||
@@ -484,18 +481,11 @@ export default class ModernAiProvider {
|
||||
// 代理其他方法到原有实现
|
||||
public async models() {
|
||||
if (this.actualProvider.id === SystemProviderIds['ai-gateway']) {
|
||||
const formatModel = function (models: GatewayLanguageModelEntry[]): Model[] {
|
||||
return models.map((m) => ({
|
||||
id: m.id,
|
||||
name: m.name,
|
||||
provider: 'gateway',
|
||||
group: m.id.split('/')[0],
|
||||
description: m.description ?? undefined
|
||||
}))
|
||||
}
|
||||
return formatModel((await gateway.getAvailableModels()).models)
|
||||
const gatewayModels = (await gateway.getAvailableModels()).models
|
||||
return normalizeGatewayModels(this.actualProvider, gatewayModels)
|
||||
}
|
||||
return this.legacyProvider.models()
|
||||
const sdkModels = await this.legacyProvider.models()
|
||||
return normalizeSdkModels(this.actualProvider, sdkModels)
|
||||
}
|
||||
|
||||
public async getEmbeddingDimensions(model: Model): Promise<number> {
|
||||
|
||||
@@ -29,6 +29,32 @@ import { azureAnthropicProviderCreator } from './config/azure-anthropic'
|
||||
import { COPILOT_DEFAULT_HEADERS } from './constants'
|
||||
import { getAiSdkProviderId } from './factory'
|
||||
|
||||
/**
|
||||
* 获取轮询的API key
|
||||
* 复用legacy架构的多key轮询逻辑
|
||||
*/
|
||||
function getRotatedApiKey(provider: Provider): string {
|
||||
const keys = provider.apiKey.split(',').map((key) => key.trim())
|
||||
const keyName = `provider:${provider.id}:last_used_key`
|
||||
|
||||
if (keys.length === 1) {
|
||||
return keys[0]
|
||||
}
|
||||
|
||||
const lastUsedKey = window.keyv.get(keyName)
|
||||
if (!lastUsedKey) {
|
||||
window.keyv.set(keyName, keys[0])
|
||||
return keys[0]
|
||||
}
|
||||
|
||||
const currentIndex = keys.indexOf(lastUsedKey)
|
||||
const nextIndex = (currentIndex + 1) % keys.length
|
||||
const nextKey = keys[nextIndex]
|
||||
window.keyv.set(keyName, nextKey)
|
||||
|
||||
return nextKey
|
||||
}
|
||||
|
||||
/**
|
||||
* 处理特殊provider的转换逻辑
|
||||
*/
|
||||
@@ -52,13 +78,11 @@ function handleSpecialProviders(model: Model, provider: Provider): Provider {
|
||||
}
|
||||
|
||||
/**
|
||||
* Format and normalize the API host URL for a provider.
|
||||
* Handles provider-specific URL formatting rules (e.g., appending version paths, Azure formatting).
|
||||
*
|
||||
* @param provider - The provider whose API host is to be formatted.
|
||||
* @returns A new provider instance with the formatted API host.
|
||||
* 主要用来对齐AISdk的BaseURL格式
|
||||
* @param provider
|
||||
* @returns
|
||||
*/
|
||||
export function formatProviderApiHost(provider: Provider): Provider {
|
||||
function formatProviderApiHost(provider: Provider): Provider {
|
||||
const formatted = { ...provider }
|
||||
if (formatted.anthropicApiHost) {
|
||||
formatted.anthropicApiHost = formatApiHost(formatted.anthropicApiHost)
|
||||
@@ -90,38 +114,18 @@ export function formatProviderApiHost(provider: Provider): Provider {
|
||||
}
|
||||
|
||||
/**
|
||||
* Retrieve the effective Provider configuration for the given model.
|
||||
* Applies all necessary transformations (special-provider handling, URL formatting, etc.).
|
||||
*
|
||||
* @param model - The model whose provider is to be resolved.
|
||||
* @returns A new Provider instance with all adaptations applied.
|
||||
* 获取实际的Provider配置
|
||||
* 简化版:将逻辑分解为小函数
|
||||
*/
|
||||
export function getActualProvider(model: Model): Provider {
|
||||
const baseProvider = getProviderByModel(model)
|
||||
|
||||
return adaptProvider({ provider: baseProvider, model })
|
||||
}
|
||||
// 按顺序处理各种转换
|
||||
let actualProvider = cloneDeep(baseProvider)
|
||||
actualProvider = handleSpecialProviders(model, actualProvider)
|
||||
actualProvider = formatProviderApiHost(actualProvider)
|
||||
|
||||
/**
|
||||
* Transforms a provider configuration by applying model-specific adaptations and normalizing its API host.
|
||||
* The transformations are applied in the following order:
|
||||
* 1. Model-specific provider handling (e.g., New-API, system providers, Azure OpenAI)
|
||||
* 2. API host formatting (provider-specific URL normalization)
|
||||
*
|
||||
* @param provider - The base provider configuration to transform.
|
||||
* @param model - The model associated with the provider; optional but required for special-provider handling.
|
||||
* @returns A new Provider instance with all transformations applied.
|
||||
*/
|
||||
export function adaptProvider({ provider, model }: { provider: Provider; model?: Model }): Provider {
|
||||
let adaptedProvider = cloneDeep(provider)
|
||||
|
||||
// Apply transformations in order
|
||||
if (model) {
|
||||
adaptedProvider = handleSpecialProviders(model, adaptedProvider)
|
||||
}
|
||||
adaptedProvider = formatProviderApiHost(adaptedProvider)
|
||||
|
||||
return adaptedProvider
|
||||
return actualProvider
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -135,7 +139,7 @@ export function providerToAiSdkConfig(actualProvider: Provider, model: Model): A
|
||||
const { baseURL, endpoint } = routeToEndpoint(actualProvider.apiHost)
|
||||
const baseConfig = {
|
||||
baseURL: baseURL,
|
||||
apiKey: actualProvider.apiKey
|
||||
apiKey: getRotatedApiKey(actualProvider)
|
||||
}
|
||||
|
||||
const isCopilotProvider = actualProvider.id === SystemProviderIds.copilot
|
||||
|
||||
@@ -4372,7 +4372,7 @@
|
||||
"url": {
|
||||
"preview": "Preview: {{url}}",
|
||||
"reset": "Reset",
|
||||
"tip": "Add # at the end to disable the automatically appended API version."
|
||||
"tip": "ending with # forces use of input address"
|
||||
}
|
||||
},
|
||||
"api_host": "API Host",
|
||||
|
||||
@@ -4372,7 +4372,7 @@
|
||||
"url": {
|
||||
"preview": "预览: {{url}}",
|
||||
"reset": "重置",
|
||||
"tip": "在末尾添加 # 以禁用自动附加的API版本。"
|
||||
"tip": "# 结尾强制使用输入地址"
|
||||
}
|
||||
},
|
||||
"api_host": "API 地址",
|
||||
|
||||
@@ -4372,7 +4372,7 @@
|
||||
"url": {
|
||||
"preview": "預覽:{{url}}",
|
||||
"reset": "重設",
|
||||
"tip": "在末尾添加 # 以停用自動附加的 API 版本。"
|
||||
"tip": "# 結尾強制使用輸入位址"
|
||||
}
|
||||
},
|
||||
"api_host": "API 主機地址",
|
||||
|
||||
@@ -4372,7 +4372,7 @@
|
||||
"url": {
|
||||
"preview": "Vorschau: {{url}}",
|
||||
"reset": "Zurücksetzen",
|
||||
"tip": "Fügen Sie am Ende ein # hinzu, um die automatisch angehängte API-Version zu deaktivieren."
|
||||
"tip": "# am Ende erzwingt die Verwendung der Eingabe-Adresse"
|
||||
}
|
||||
},
|
||||
"api_host": "API-Adresse",
|
||||
|
||||
@@ -4372,7 +4372,7 @@
|
||||
"url": {
|
||||
"preview": "Προεπισκόπηση: {{url}}",
|
||||
"reset": "Επαναφορά",
|
||||
"tip": "Προσθέστε το σύμβολο # στο τέλος για να απενεργοποιήσετε την αυτόματα προστιθέμενη έκδοση API."
|
||||
"tip": "#τέλος ενδεχόμενη χρήση της εισαγωγής διευθύνσεως"
|
||||
}
|
||||
},
|
||||
"api_host": "Διεύθυνση API",
|
||||
|
||||
@@ -4372,7 +4372,7 @@
|
||||
"url": {
|
||||
"preview": "Vista previa: {{url}}",
|
||||
"reset": "Restablecer",
|
||||
"tip": "Añada # al final para deshabilitar la versión de la API que se añade automáticamente."
|
||||
"tip": "forzar uso de dirección de entrada con # al final"
|
||||
}
|
||||
},
|
||||
"api_host": "Dirección API",
|
||||
|
||||
@@ -4372,7 +4372,7 @@
|
||||
"url": {
|
||||
"preview": "Aperçu : {{url}}",
|
||||
"reset": "Réinitialiser",
|
||||
"tip": "Ajoutez # à la fin pour désactiver la version d'API ajoutée automatiquement."
|
||||
"tip": "forcer l'utilisation de l'adresse d'entrée si terminé par #"
|
||||
}
|
||||
},
|
||||
"api_host": "Adresse API",
|
||||
|
||||
@@ -4372,7 +4372,7 @@
|
||||
"url": {
|
||||
"preview": "プレビュー: {{url}}",
|
||||
"reset": "リセット",
|
||||
"tip": "自動的に付加されるAPIバージョンを無効にするには、末尾に#を追加します。"
|
||||
"tip": "#で終わる場合、入力されたアドレスを強制的に使用します"
|
||||
}
|
||||
},
|
||||
"api_host": "APIホスト",
|
||||
|
||||
@@ -4372,7 +4372,7 @@
|
||||
"url": {
|
||||
"preview": "Pré-visualização: {{url}}",
|
||||
"reset": "Redefinir",
|
||||
"tip": "Adicione # no final para desativar a versão da API adicionada automaticamente."
|
||||
"tip": "e forçar o uso do endereço original quando terminar com '#'"
|
||||
}
|
||||
},
|
||||
"api_host": "Endereço API",
|
||||
|
||||
@@ -4372,7 +4372,7 @@
|
||||
"url": {
|
||||
"preview": "Предпросмотр: {{url}}",
|
||||
"reset": "Сброс",
|
||||
"tip": "Добавьте # в конце, чтобы отключить автоматически добавляемую версию API."
|
||||
"tip": "заканчивая на # принудительно использует введенный адрес"
|
||||
}
|
||||
},
|
||||
"api_host": "Хост API",
|
||||
|
||||
@@ -18,7 +18,7 @@ import NewApiAddModelPopup from '@renderer/pages/settings/ProviderSettings/Model
|
||||
import NewApiBatchAddModelPopup from '@renderer/pages/settings/ProviderSettings/ModelList/NewApiBatchAddModelPopup'
|
||||
import { fetchModels } from '@renderer/services/ApiService'
|
||||
import type { Model, Provider } from '@renderer/types'
|
||||
import { filterModelsByKeywords, getDefaultGroupName, getFancyProviderName } from '@renderer/utils'
|
||||
import { filterModelsByKeywords, getFancyProviderName } from '@renderer/utils'
|
||||
import { isFreeModel } from '@renderer/utils/model'
|
||||
import { isNewApiProvider } from '@renderer/utils/provider'
|
||||
import { Button, Empty, Flex, Modal, Spin, Tabs, Tooltip } from 'antd'
|
||||
@@ -183,25 +183,7 @@ const PopupContainer: React.FC<Props> = ({ providerId, resolve }) => {
|
||||
setLoadingModels(true)
|
||||
try {
|
||||
const models = await fetchModels(provider)
|
||||
// TODO: More robust conversion
|
||||
const filteredModels = models
|
||||
.map((model) => ({
|
||||
// @ts-ignore modelId
|
||||
id: model?.id || model?.name,
|
||||
// @ts-ignore name
|
||||
name: model?.display_name || model?.displayName || model?.name || model?.id,
|
||||
provider: provider.id,
|
||||
// @ts-ignore group
|
||||
group: getDefaultGroupName(model?.id || model?.name, provider.id),
|
||||
// @ts-ignore description
|
||||
description: model?.description || '',
|
||||
// @ts-ignore owned_by
|
||||
owned_by: model?.owned_by || '',
|
||||
// @ts-ignore supported_endpoint_types
|
||||
supported_endpoint_types: model?.supported_endpoint_types
|
||||
}))
|
||||
.filter((model) => !isEmpty(model.name))
|
||||
|
||||
const filteredModels = models.filter((model) => !isEmpty(model.name))
|
||||
setListModels(filteredModels)
|
||||
} catch (error) {
|
||||
logger.error(`Failed to load models for provider ${getFancyProviderName(provider)}`, error as Error)
|
||||
|
||||
@@ -1,10 +1,8 @@
|
||||
import { adaptProvider } from '@renderer/aiCore/provider/providerConfig'
|
||||
import OpenAIAlert from '@renderer/components/Alert/OpenAIAlert'
|
||||
import { LoadingIcon } from '@renderer/components/Icons'
|
||||
import { HStack } from '@renderer/components/Layout'
|
||||
import { ApiKeyListPopup } from '@renderer/components/Popups/ApiKeyListPopup'
|
||||
import Selector from '@renderer/components/Selector'
|
||||
import { HelpTooltip } from '@renderer/components/TooltipIcons'
|
||||
import { isEmbeddingModel, isRerankModel } from '@renderer/config/models'
|
||||
import { PROVIDER_URLS } from '@renderer/config/providers'
|
||||
import { useTheme } from '@renderer/context/ThemeProvider'
|
||||
@@ -21,7 +19,14 @@ import type { SystemProviderId } from '@renderer/types'
|
||||
import { isSystemProvider, isSystemProviderId, SystemProviderIds } from '@renderer/types'
|
||||
import type { ApiKeyConnectivity } from '@renderer/types/healthCheck'
|
||||
import { HealthStatus } from '@renderer/types/healthCheck'
|
||||
import { formatApiHost, formatApiKeys, getFancyProviderName, validateApiHost } from '@renderer/utils'
|
||||
import {
|
||||
formatApiHost,
|
||||
formatApiKeys,
|
||||
formatAzureOpenAIApiHost,
|
||||
formatVertexApiHost,
|
||||
getFancyProviderName,
|
||||
validateApiHost
|
||||
} from '@renderer/utils'
|
||||
import { formatErrorMessage } from '@renderer/utils/error'
|
||||
import {
|
||||
isAIGatewayProvider,
|
||||
@@ -31,6 +36,7 @@ import {
|
||||
isNewApiProvider,
|
||||
isOpenAICompatibleProvider,
|
||||
isOpenAIProvider,
|
||||
isSupportAPIVersionProvider,
|
||||
isVertexProvider
|
||||
} from '@renderer/utils/provider'
|
||||
import { Button, Divider, Flex, Input, Select, Space, Switch, Tooltip } from 'antd'
|
||||
@@ -275,10 +281,12 @@ const ProviderSetting: FC<Props> = ({ providerId }) => {
|
||||
}, [configuredApiHost, apiHost])
|
||||
|
||||
const hostPreview = () => {
|
||||
const formattedApiHost = adaptProvider({ provider: { ...provider, apiHost } }).apiHost
|
||||
if (apiHost.endsWith('#')) {
|
||||
return apiHost.replace('#', '')
|
||||
}
|
||||
|
||||
if (isOpenAICompatibleProvider(provider)) {
|
||||
return formattedApiHost + '/chat/completions'
|
||||
return formatApiHost(apiHost, isSupportAPIVersionProvider(provider)) + '/chat/completions'
|
||||
}
|
||||
|
||||
if (isAzureOpenAIProvider(provider)) {
|
||||
@@ -286,26 +294,29 @@ const ProviderSetting: FC<Props> = ({ providerId }) => {
|
||||
const path = !['preview', 'v1'].includes(apiVersion)
|
||||
? `/v1/chat/completion?apiVersion=v1`
|
||||
: `/v1/responses?apiVersion=v1`
|
||||
return formattedApiHost + path
|
||||
return formatAzureOpenAIApiHost(apiHost) + path
|
||||
}
|
||||
|
||||
if (isAnthropicProvider(provider)) {
|
||||
return formattedApiHost + '/messages'
|
||||
// AI SDK uses the baseURL with /v1, then appends /messages
|
||||
// formatApiHost adds /v1 automatically if not present
|
||||
const normalizedHost = formatApiHost(apiHost)
|
||||
return normalizedHost + '/messages'
|
||||
}
|
||||
|
||||
if (isGeminiProvider(provider)) {
|
||||
return formattedApiHost + '/models'
|
||||
return formatApiHost(apiHost, true, 'v1beta') + '/models'
|
||||
}
|
||||
if (isOpenAIProvider(provider)) {
|
||||
return formattedApiHost + '/responses'
|
||||
return formatApiHost(apiHost) + '/responses'
|
||||
}
|
||||
if (isVertexProvider(provider)) {
|
||||
return formattedApiHost + '/publishers/google'
|
||||
return formatVertexApiHost(provider) + '/publishers/google'
|
||||
}
|
||||
if (isAIGatewayProvider(provider)) {
|
||||
return formattedApiHost + '/language-model'
|
||||
return formatApiHost(apiHost) + '/language-model'
|
||||
}
|
||||
return formattedApiHost
|
||||
return formatApiHost(apiHost)
|
||||
}
|
||||
|
||||
// API key 连通性检查状态指示器,目前仅在失败时显示
|
||||
@@ -483,21 +494,16 @@ const ProviderSetting: FC<Props> = ({ providerId }) => {
|
||||
{!isDmxapi && (
|
||||
<>
|
||||
<SettingSubtitle style={{ display: 'flex', alignItems: 'center', justifyContent: 'space-between' }}>
|
||||
<div className="flex items-center gap-1">
|
||||
<Tooltip title={hostSelectorTooltip} mouseEnterDelay={0.3}>
|
||||
<div>
|
||||
<Selector
|
||||
size={14}
|
||||
value={activeHostField}
|
||||
onChange={(value) => setActiveHostField(value as HostField)}
|
||||
options={hostSelectorOptions}
|
||||
style={{ paddingLeft: 1, fontWeight: 'bold' }}
|
||||
placement="bottomLeft"
|
||||
/>
|
||||
</div>
|
||||
</Tooltip>
|
||||
<HelpTooltip title={t('settings.provider.api.url.tip')}></HelpTooltip>
|
||||
</div>
|
||||
<Tooltip title={hostSelectorTooltip} mouseEnterDelay={0.3}>
|
||||
<Selector
|
||||
size={14}
|
||||
value={activeHostField}
|
||||
onChange={(value) => setActiveHostField(value as HostField)}
|
||||
options={hostSelectorOptions}
|
||||
style={{ paddingLeft: 1, fontWeight: 'bold' }}
|
||||
placement="bottomLeft"
|
||||
/>
|
||||
</Tooltip>
|
||||
<div style={{ display: 'flex', alignItems: 'center', gap: 4 }}>
|
||||
<Button
|
||||
type="text"
|
||||
|
||||
@@ -8,12 +8,11 @@ import { isDedicatedImageGenerationModel, isEmbeddingModel, isFunctionCallingMod
|
||||
import { getStoreSetting } from '@renderer/hooks/useSettings'
|
||||
import i18n from '@renderer/i18n'
|
||||
import store from '@renderer/store'
|
||||
import type { FetchChatCompletionParams } from '@renderer/types'
|
||||
import type { Assistant, MCPServer, MCPTool, Model, Provider } from '@renderer/types'
|
||||
import { type FetchChatCompletionParams, isSystemProvider } from '@renderer/types'
|
||||
import type { StreamTextParams } from '@renderer/types/aiCoreTypes'
|
||||
import { type Chunk, ChunkType } from '@renderer/types/chunk'
|
||||
import type { Message, ResponseError } from '@renderer/types/newMessage'
|
||||
import type { SdkModel } from '@renderer/types/sdk'
|
||||
import { removeSpecialCharactersForTopicName, uuid } from '@renderer/utils'
|
||||
import { abortCompletion, readyToAbort } from '@renderer/utils/abortController'
|
||||
import { isToolUseModeFunction } from '@renderer/utils/assistant'
|
||||
@@ -22,8 +21,7 @@ import { purifyMarkdownImages } from '@renderer/utils/markdown'
|
||||
import { isPromptToolUse, isSupportedToolUse } from '@renderer/utils/mcp-tools'
|
||||
import { findFileBlocks, getMainTextContent } from '@renderer/utils/messageUtils/find'
|
||||
import { containsSupportedVariables, replacePromptVariables } from '@renderer/utils/prompt'
|
||||
import { NOT_SUPPORT_API_KEY_PROVIDERS } from '@renderer/utils/provider'
|
||||
import { cloneDeep, isEmpty, takeRight } from 'lodash'
|
||||
import { isEmpty, takeRight } from 'lodash'
|
||||
|
||||
import type { ModernAiProviderConfig } from '../aiCore/index_new'
|
||||
import AiProviderNew from '../aiCore/index_new'
|
||||
@@ -44,8 +42,6 @@ import {
|
||||
// } from './MessagesService'
|
||||
// import WebSearchService from './WebSearchService'
|
||||
|
||||
// FIXME: 这里太多重复逻辑,需要重构
|
||||
|
||||
const logger = loggerService.withContext('ApiService')
|
||||
|
||||
export async function fetchMcpTools(assistant: Assistant) {
|
||||
@@ -98,15 +94,7 @@ export async function fetchChatCompletion({
|
||||
modelId: assistant.model?.id,
|
||||
modelName: assistant.model?.name
|
||||
})
|
||||
|
||||
// Get base provider and apply API key rotation
|
||||
const baseProvider = getProviderByModel(assistant.model || getDefaultModel())
|
||||
const providerWithRotatedKey = {
|
||||
...cloneDeep(baseProvider),
|
||||
apiKey: getRotatedApiKey(baseProvider)
|
||||
}
|
||||
|
||||
const AI = new AiProviderNew(assistant.model || getDefaultModel(), providerWithRotatedKey)
|
||||
const AI = new AiProviderNew(assistant.model || getDefaultModel())
|
||||
const provider = AI.getActualProvider()
|
||||
|
||||
const mcpTools: MCPTool[] = []
|
||||
@@ -183,13 +171,7 @@ export async function fetchMessagesSummary({ messages, assistant }: { messages:
|
||||
return null
|
||||
}
|
||||
|
||||
// Apply API key rotation
|
||||
const providerWithRotatedKey = {
|
||||
...cloneDeep(provider),
|
||||
apiKey: getRotatedApiKey(provider)
|
||||
}
|
||||
|
||||
const AI = new AiProviderNew(model, providerWithRotatedKey)
|
||||
const AI = new AiProviderNew(model)
|
||||
|
||||
const topicId = messages?.find((message) => message.topicId)?.topicId || ''
|
||||
|
||||
@@ -288,13 +270,7 @@ export async function fetchNoteSummary({ content, assistant }: { content: string
|
||||
return null
|
||||
}
|
||||
|
||||
// Apply API key rotation
|
||||
const providerWithRotatedKey = {
|
||||
...cloneDeep(provider),
|
||||
apiKey: getRotatedApiKey(provider)
|
||||
}
|
||||
|
||||
const AI = new AiProviderNew(model, providerWithRotatedKey)
|
||||
const AI = new AiProviderNew(model)
|
||||
|
||||
// only 2000 char and no images
|
||||
const truncatedContent = content.substring(0, 2000)
|
||||
@@ -382,13 +358,7 @@ export async function fetchGenerate({
|
||||
return ''
|
||||
}
|
||||
|
||||
// Apply API key rotation
|
||||
const providerWithRotatedKey = {
|
||||
...cloneDeep(provider),
|
||||
apiKey: getRotatedApiKey(provider)
|
||||
}
|
||||
|
||||
const AI = new AiProviderNew(model, providerWithRotatedKey)
|
||||
const AI = new AiProviderNew(model)
|
||||
|
||||
const assistant = getDefaultAssistant()
|
||||
assistant.model = model
|
||||
@@ -433,44 +403,28 @@ export async function fetchGenerate({
|
||||
|
||||
export function hasApiKey(provider: Provider) {
|
||||
if (!provider) return false
|
||||
if (isSystemProvider(provider) && NOT_SUPPORT_API_KEY_PROVIDERS.includes(provider.id)) return true
|
||||
if (['ollama', 'lmstudio', 'vertexai', 'cherryai'].includes(provider.id)) return true
|
||||
return !isEmpty(provider.apiKey)
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取轮询的API key
|
||||
* 复用legacy架构的多key轮询逻辑
|
||||
* Get the first available embedding model from enabled providers
|
||||
*/
|
||||
function getRotatedApiKey(provider: Provider): string {
|
||||
const keys = provider.apiKey.split(',').map((key) => key.trim())
|
||||
const keyName = `provider:${provider.id}:last_used_key`
|
||||
// function getFirstEmbeddingModel() {
|
||||
// const providers = store.getState().llm.providers.filter((p) => p.enabled)
|
||||
|
||||
if (keys.length === 1) {
|
||||
return keys[0]
|
||||
}
|
||||
// for (const provider of providers) {
|
||||
// const embeddingModel = provider.models.find((model) => isEmbeddingModel(model))
|
||||
// if (embeddingModel) {
|
||||
// return embeddingModel
|
||||
// }
|
||||
// }
|
||||
|
||||
const lastUsedKey = window.keyv.get(keyName)
|
||||
if (!lastUsedKey) {
|
||||
window.keyv.set(keyName, keys[0])
|
||||
return keys[0]
|
||||
}
|
||||
// return undefined
|
||||
// }
|
||||
|
||||
const currentIndex = keys.indexOf(lastUsedKey)
|
||||
const nextIndex = (currentIndex + 1) % keys.length
|
||||
const nextKey = keys[nextIndex]
|
||||
window.keyv.set(keyName, nextKey)
|
||||
|
||||
return nextKey
|
||||
}
|
||||
|
||||
export async function fetchModels(provider: Provider): Promise<SdkModel[]> {
|
||||
// Apply API key rotation
|
||||
const providerWithRotatedKey = {
|
||||
...cloneDeep(provider),
|
||||
apiKey: getRotatedApiKey(provider)
|
||||
}
|
||||
|
||||
const AI = new AiProviderNew(providerWithRotatedKey)
|
||||
export async function fetchModels(provider: Provider): Promise<Model[]> {
|
||||
const AI = new AiProviderNew(provider)
|
||||
|
||||
try {
|
||||
return await AI.models()
|
||||
@@ -480,7 +434,12 @@ export async function fetchModels(provider: Provider): Promise<SdkModel[]> {
|
||||
}
|
||||
|
||||
export function checkApiProvider(provider: Provider): void {
|
||||
if (isSystemProvider(provider) && !NOT_SUPPORT_API_KEY_PROVIDERS.includes(provider.id)) {
|
||||
if (
|
||||
provider.id !== 'ollama' &&
|
||||
provider.id !== 'lmstudio' &&
|
||||
provider.type !== 'vertexai' &&
|
||||
provider.id !== 'copilot'
|
||||
) {
|
||||
if (!provider.apiKey) {
|
||||
window.toast.error(i18n.t('message.error.enter.api.label'))
|
||||
throw new Error(i18n.t('message.error.enter.api.label'))
|
||||
@@ -501,7 +460,8 @@ export function checkApiProvider(provider: Provider): void {
|
||||
export async function checkApi(provider: Provider, model: Model, timeout = 15000): Promise<void> {
|
||||
checkApiProvider(provider)
|
||||
|
||||
const ai = new AiProviderNew(model, provider)
|
||||
// Don't pass in provider parameter. We need auto-format URL
|
||||
const ai = new AiProviderNew(model)
|
||||
|
||||
const assistant = getDefaultAssistant()
|
||||
assistant.model = model
|
||||
|
||||
102
src/renderer/src/services/__tests__/ModelAdapter.test.ts
Normal file
102
src/renderer/src/services/__tests__/ModelAdapter.test.ts
Normal file
@@ -0,0 +1,102 @@
|
||||
import type { GatewayLanguageModelEntry } from '@ai-sdk/gateway'
|
||||
import { normalizeGatewayModels, normalizeSdkModels } from '@renderer/services/models/ModelAdapter'
|
||||
import type { Model, Provider } from '@renderer/types'
|
||||
import type { EndpointType } from '@renderer/types/index'
|
||||
import type { SdkModel } from '@renderer/types/sdk'
|
||||
import { describe, expect, it } from 'vitest'
|
||||
|
||||
const createProvider = (overrides: Partial<Provider> = {}): Provider => ({
|
||||
id: 'openai',
|
||||
type: 'openai',
|
||||
name: 'OpenAI',
|
||||
apiKey: 'test-key',
|
||||
apiHost: 'https://example.com/v1',
|
||||
models: [],
|
||||
...overrides
|
||||
})
|
||||
|
||||
describe('ModelAdapter', () => {
|
||||
it('adapts generic SDK models into internal models', () => {
|
||||
const provider = createProvider({ id: 'openai' })
|
||||
const models = normalizeSdkModels(provider, [
|
||||
{
|
||||
id: 'gpt-4o-mini',
|
||||
display_name: 'GPT-4o mini',
|
||||
description: 'General purpose model',
|
||||
owned_by: 'openai'
|
||||
} as unknown as SdkModel
|
||||
])
|
||||
|
||||
expect(models).toHaveLength(1)
|
||||
expect(models[0]).toMatchObject({
|
||||
id: 'gpt-4o-mini',
|
||||
name: 'GPT-4o mini',
|
||||
provider: 'openai',
|
||||
group: 'gpt-4o',
|
||||
description: 'General purpose model',
|
||||
owned_by: 'openai'
|
||||
} as Partial<Model>)
|
||||
})
|
||||
|
||||
it('preserves supported endpoint types for New API models', () => {
|
||||
const provider = createProvider({ id: 'new-api' })
|
||||
const endpointTypes: EndpointType[] = ['openai', 'image-generation']
|
||||
const [model] = normalizeSdkModels(provider, [
|
||||
{
|
||||
id: 'new-api-model',
|
||||
name: 'New API Model',
|
||||
supported_endpoint_types: endpointTypes
|
||||
} as unknown as SdkModel
|
||||
])
|
||||
|
||||
expect(model.supported_endpoint_types).toEqual(endpointTypes)
|
||||
})
|
||||
|
||||
it('filters unsupported endpoint types while keeping valid ones', () => {
|
||||
const provider = createProvider({ id: 'new-api' })
|
||||
const [model] = normalizeSdkModels(provider, [
|
||||
{
|
||||
id: 'another-model',
|
||||
name: 'Another Model',
|
||||
supported_endpoint_types: ['openai', 'unknown-endpoint', 'gemini']
|
||||
} as unknown as SdkModel
|
||||
])
|
||||
|
||||
expect(model.supported_endpoint_types).toEqual(['openai', 'gemini'])
|
||||
})
|
||||
|
||||
it('adapts ai-gateway entries through the same adapter', () => {
|
||||
const provider = createProvider({ id: 'ai-gateway', type: 'ai-gateway' })
|
||||
const [model] = normalizeGatewayModels(provider, [
|
||||
{
|
||||
id: 'openai/gpt-4o',
|
||||
name: 'OpenAI GPT-4o',
|
||||
description: 'Gateway entry',
|
||||
specification: {
|
||||
specificationVersion: 'v2',
|
||||
provider: 'openai',
|
||||
modelId: 'gpt-4o'
|
||||
}
|
||||
} as GatewayLanguageModelEntry
|
||||
])
|
||||
|
||||
expect(model).toMatchObject({
|
||||
id: 'openai/gpt-4o',
|
||||
group: 'openai',
|
||||
provider: 'ai-gateway',
|
||||
description: 'Gateway entry'
|
||||
})
|
||||
})
|
||||
|
||||
it('drops invalid entries without ids or names', () => {
|
||||
const provider = createProvider()
|
||||
const models = normalizeSdkModels(provider, [
|
||||
{
|
||||
id: '',
|
||||
name: ''
|
||||
} as unknown as SdkModel
|
||||
])
|
||||
|
||||
expect(models).toHaveLength(0)
|
||||
})
|
||||
})
|
||||
180
src/renderer/src/services/models/ModelAdapter.ts
Normal file
180
src/renderer/src/services/models/ModelAdapter.ts
Normal file
@@ -0,0 +1,180 @@
|
||||
import type { GatewayLanguageModelEntry } from '@ai-sdk/gateway'
|
||||
import { loggerService } from '@logger'
|
||||
import { type EndpointType, EndPointTypeSchema, type Model, type Provider } from '@renderer/types'
|
||||
import type { NewApiModel, SdkModel } from '@renderer/types/sdk'
|
||||
import { getDefaultGroupName } from '@renderer/utils/naming'
|
||||
import * as z from 'zod'
|
||||
|
||||
const logger = loggerService.withContext('ModelAdapter')
|
||||
|
||||
const EndpointTypeArraySchema = z.array(EndPointTypeSchema).nonempty()
|
||||
|
||||
const NormalizedModelSchema = z.object({
|
||||
id: z.string().trim().min(1),
|
||||
name: z.string().trim().min(1),
|
||||
provider: z.string().trim().min(1),
|
||||
group: z.string().trim().min(1),
|
||||
description: z.string().optional(),
|
||||
owned_by: z.string().optional(),
|
||||
supported_endpoint_types: EndpointTypeArraySchema.optional()
|
||||
})
|
||||
|
||||
type NormalizedModelInput = z.input<typeof NormalizedModelSchema>
|
||||
|
||||
export function normalizeSdkModels(provider: Provider, models: SdkModel[]): Model[] {
|
||||
return normalizeModels(models, (entry) => adaptSdkModel(provider, entry))
|
||||
}
|
||||
|
||||
export function normalizeGatewayModels(provider: Provider, models: GatewayLanguageModelEntry[]): Model[] {
|
||||
return normalizeModels(models, (entry) => adaptGatewayModel(provider, entry))
|
||||
}
|
||||
|
||||
function normalizeModels<T>(models: T[], transformer: (entry: T) => Model | null): Model[] {
|
||||
const uniqueModels: Model[] = []
|
||||
const seen = new Set<string>()
|
||||
|
||||
for (const entry of models) {
|
||||
const normalized = transformer(entry)
|
||||
if (!normalized) continue
|
||||
if (seen.has(normalized.id)) continue
|
||||
seen.add(normalized.id)
|
||||
uniqueModels.push(normalized)
|
||||
}
|
||||
|
||||
return uniqueModels
|
||||
}
|
||||
|
||||
function adaptSdkModel(provider: Provider, model: SdkModel): Model | null {
|
||||
const id = pickPreferredString([(model as any)?.id, (model as any)?.modelId])
|
||||
const name = pickPreferredString([
|
||||
(model as any)?.display_name,
|
||||
(model as any)?.displayName,
|
||||
(model as any)?.name,
|
||||
id
|
||||
])
|
||||
|
||||
if (!id || !name) {
|
||||
logger.warn('Skip SDK model with missing id or name', {
|
||||
providerId: provider.id,
|
||||
modelSnippet: summarizeModel(model)
|
||||
})
|
||||
return null
|
||||
}
|
||||
|
||||
const candidate: NormalizedModelInput = {
|
||||
id,
|
||||
name,
|
||||
provider: provider.id,
|
||||
group: getDefaultGroupName(id, provider.id),
|
||||
description: pickPreferredString([(model as any)?.description, (model as any)?.summary]),
|
||||
owned_by: pickPreferredString([(model as any)?.owned_by, (model as any)?.publisher])
|
||||
}
|
||||
|
||||
const supportedEndpointTypes = pickSupportedEndpointTypes(provider.id, model)
|
||||
if (supportedEndpointTypes) {
|
||||
candidate.supported_endpoint_types = supportedEndpointTypes
|
||||
}
|
||||
|
||||
return validateModel(candidate, model)
|
||||
}
|
||||
|
||||
function adaptGatewayModel(provider: Provider, model: GatewayLanguageModelEntry): Model | null {
|
||||
const id = model?.id?.trim()
|
||||
const name = model?.name?.trim() || id
|
||||
|
||||
if (!id || !name) {
|
||||
logger.warn('Skip gateway model with missing id or name', {
|
||||
providerId: provider.id,
|
||||
modelSnippet: summarizeModel(model)
|
||||
})
|
||||
return null
|
||||
}
|
||||
|
||||
const candidate: NormalizedModelInput = {
|
||||
id,
|
||||
name,
|
||||
provider: provider.id,
|
||||
group: getDefaultGroupName(id, provider.id),
|
||||
description: model.description ?? undefined
|
||||
}
|
||||
|
||||
return validateModel(candidate, model)
|
||||
}
|
||||
|
||||
function pickPreferredString(values: Array<unknown>): string | undefined {
|
||||
for (const value of values) {
|
||||
if (typeof value === 'string') {
|
||||
const trimmed = value.trim()
|
||||
if (trimmed.length > 0) {
|
||||
return trimmed
|
||||
}
|
||||
}
|
||||
}
|
||||
return undefined
|
||||
}
|
||||
|
||||
function pickSupportedEndpointTypes(providerId: string, model: SdkModel): EndpointType[] | undefined {
|
||||
const candidate =
|
||||
(model as Partial<NewApiModel>).supported_endpoint_types ??
|
||||
((model as Record<string, unknown>).supported_endpoint_types as EndpointType[] | undefined)
|
||||
|
||||
if (!Array.isArray(candidate) || candidate.length === 0) {
|
||||
return undefined
|
||||
}
|
||||
|
||||
const supported: EndpointType[] = []
|
||||
const unsupported: unknown[] = []
|
||||
|
||||
for (const value of candidate) {
|
||||
const parsed = EndPointTypeSchema.safeParse(value)
|
||||
if (parsed.success) {
|
||||
supported.push(parsed.data)
|
||||
} else {
|
||||
unsupported.push(value)
|
||||
}
|
||||
}
|
||||
|
||||
if (unsupported.length > 0) {
|
||||
logger.warn('Pruned unsupported endpoint types', {
|
||||
providerId,
|
||||
values: unsupported,
|
||||
modelSnippet: summarizeModel(model)
|
||||
})
|
||||
}
|
||||
|
||||
return supported.length > 0 ? supported : undefined
|
||||
}
|
||||
|
||||
function validateModel(candidate: NormalizedModelInput, source: unknown): Model | null {
|
||||
const parsed = NormalizedModelSchema.safeParse(candidate)
|
||||
if (!parsed.success) {
|
||||
logger.warn('Discard invalid model entry', {
|
||||
providerId: candidate.provider,
|
||||
issues: parsed.error.issues,
|
||||
modelSnippet: summarizeModel(source)
|
||||
})
|
||||
return null
|
||||
}
|
||||
|
||||
return parsed.data
|
||||
}
|
||||
|
||||
function summarizeModel(model: unknown) {
|
||||
if (!model || typeof model !== 'object') {
|
||||
return model
|
||||
}
|
||||
const { id, name, display_name, displayName, description, owned_by, supported_endpoint_types } = model as Record<
|
||||
string,
|
||||
unknown
|
||||
>
|
||||
|
||||
return {
|
||||
id,
|
||||
name,
|
||||
display_name,
|
||||
displayName,
|
||||
description,
|
||||
owned_by,
|
||||
supported_endpoint_types
|
||||
}
|
||||
}
|
||||
@@ -7,6 +7,8 @@ import type { CSSProperties } from 'react'
|
||||
export * from './file'
|
||||
export * from './note'
|
||||
|
||||
import * as z from 'zod'
|
||||
|
||||
import type { StreamTextParams } from './aiCoreTypes'
|
||||
import type { Chunk } from './chunk'
|
||||
import type { FileMetadata } from './file'
|
||||
@@ -240,7 +242,15 @@ export type ModelType = 'text' | 'vision' | 'embedding' | 'reasoning' | 'functio
|
||||
export type ModelTag = Exclude<ModelType, 'text'> | 'free'
|
||||
|
||||
// "image-generation" is also openai endpoint, but specifically for image generation.
|
||||
export type EndpointType = 'openai' | 'openai-response' | 'anthropic' | 'gemini' | 'image-generation' | 'jina-rerank'
|
||||
export const EndPointTypeSchema = z.enum([
|
||||
'openai',
|
||||
'openai-response',
|
||||
'anthropic',
|
||||
'gemini',
|
||||
'image-generation',
|
||||
'jina-rerank'
|
||||
])
|
||||
export type EndpointType = z.infer<typeof EndPointTypeSchema>
|
||||
|
||||
export type ModelPricing = {
|
||||
input_per_million_tokens: number
|
||||
|
||||
@@ -13,8 +13,7 @@ import {
|
||||
routeToEndpoint,
|
||||
splitApiKeyString,
|
||||
validateApiHost,
|
||||
withoutTrailingApiVersion,
|
||||
withoutTrailingSharp
|
||||
withoutTrailingApiVersion
|
||||
} from '../api'
|
||||
|
||||
vi.mock('@renderer/store', () => {
|
||||
@@ -82,27 +81,6 @@ describe('api', () => {
|
||||
it('keeps host untouched when api version unsupported', () => {
|
||||
expect(formatApiHost('https://api.example.com', false)).toBe('https://api.example.com')
|
||||
})
|
||||
|
||||
it('removes trailing # and does not append api version when host ends with #', () => {
|
||||
expect(formatApiHost('https://api.example.com#')).toBe('https://api.example.com')
|
||||
expect(formatApiHost('http://localhost:5173/#')).toBe('http://localhost:5173/')
|
||||
expect(formatApiHost(' https://api.openai.com/# ')).toBe('https://api.openai.com/')
|
||||
})
|
||||
|
||||
it('handles trailing # with custom api version settings', () => {
|
||||
expect(formatApiHost('https://api.example.com#', true, 'v2')).toBe('https://api.example.com')
|
||||
expect(formatApiHost('https://api.example.com#', false, 'v2')).toBe('https://api.example.com')
|
||||
})
|
||||
|
||||
it('handles host with both trailing # and existing api version', () => {
|
||||
expect(formatApiHost('https://api.example.com/v2#')).toBe('https://api.example.com/v2')
|
||||
expect(formatApiHost('https://api.example.com/v3beta#')).toBe('https://api.example.com/v3beta')
|
||||
})
|
||||
|
||||
it('trims whitespace before processing trailing #', () => {
|
||||
expect(formatApiHost(' https://api.example.com# ')).toBe('https://api.example.com')
|
||||
expect(formatApiHost('\thttps://api.example.com#\n')).toBe('https://api.example.com')
|
||||
})
|
||||
})
|
||||
|
||||
describe('hasAPIVersion', () => {
|
||||
@@ -426,56 +404,4 @@ describe('api', () => {
|
||||
expect(withoutTrailingApiVersion('')).toBe('')
|
||||
})
|
||||
})
|
||||
|
||||
describe('withoutTrailingSharp', () => {
|
||||
it('removes trailing # from URL', () => {
|
||||
expect(withoutTrailingSharp('https://api.example.com#')).toBe('https://api.example.com')
|
||||
expect(withoutTrailingSharp('http://localhost:3000#')).toBe('http://localhost:3000')
|
||||
})
|
||||
|
||||
it('returns URL unchanged when no trailing #', () => {
|
||||
expect(withoutTrailingSharp('https://api.example.com')).toBe('https://api.example.com')
|
||||
expect(withoutTrailingSharp('http://localhost:3000')).toBe('http://localhost:3000')
|
||||
})
|
||||
|
||||
it('handles URLs with multiple # characters but only removes trailing one', () => {
|
||||
expect(withoutTrailingSharp('https://api.example.com#path#')).toBe('https://api.example.com#path')
|
||||
})
|
||||
|
||||
it('handles URLs with # in the middle (not trailing)', () => {
|
||||
expect(withoutTrailingSharp('https://api.example.com#section/path')).toBe('https://api.example.com#section/path')
|
||||
expect(withoutTrailingSharp('https://api.example.com/v1/chat/completions#')).toBe(
|
||||
'https://api.example.com/v1/chat/completions'
|
||||
)
|
||||
})
|
||||
|
||||
it('handles empty string', () => {
|
||||
expect(withoutTrailingSharp('')).toBe('')
|
||||
})
|
||||
|
||||
it('handles single character #', () => {
|
||||
expect(withoutTrailingSharp('#')).toBe('')
|
||||
})
|
||||
|
||||
it('preserves whitespace around the URL (pure function)', () => {
|
||||
expect(withoutTrailingSharp(' https://api.example.com# ')).toBe(' https://api.example.com# ')
|
||||
expect(withoutTrailingSharp('\thttps://api.example.com#\n')).toBe('\thttps://api.example.com#\n')
|
||||
})
|
||||
|
||||
it('only removes exact trailing # character', () => {
|
||||
expect(withoutTrailingSharp('https://api.example.com# ')).toBe('https://api.example.com# ')
|
||||
expect(withoutTrailingSharp(' https://api.example.com#')).toBe(' https://api.example.com')
|
||||
expect(withoutTrailingSharp('https://api.example.com#\t')).toBe('https://api.example.com#\t')
|
||||
})
|
||||
|
||||
it('handles URLs ending with multiple # characters', () => {
|
||||
expect(withoutTrailingSharp('https://api.example.com##')).toBe('https://api.example.com#')
|
||||
expect(withoutTrailingSharp('https://api.example.com###')).toBe('https://api.example.com##')
|
||||
})
|
||||
|
||||
it('preserves URL with trailing # and other content', () => {
|
||||
expect(withoutTrailingSharp('https://api.example.com/v1#')).toBe('https://api.example.com/v1')
|
||||
expect(withoutTrailingSharp('https://api.example.com/v2beta#')).toBe('https://api.example.com/v2beta')
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@@ -62,23 +62,6 @@ export function withoutTrailingSlash<T extends string>(url: T): T {
|
||||
return url.replace(/\/$/, '') as T
|
||||
}
|
||||
|
||||
/**
|
||||
* Removes the trailing '#' from a URL string if it exists.
|
||||
*
|
||||
* @template T - The string type to preserve type safety
|
||||
* @param {T} url - The URL string to process
|
||||
* @returns {T} The URL string without a trailing '#'
|
||||
*
|
||||
* @example
|
||||
* ```ts
|
||||
* withoutTrailingSharp('https://example.com#') // 'https://example.com'
|
||||
* withoutTrailingSharp('https://example.com') // 'https://example.com'
|
||||
* ```
|
||||
*/
|
||||
export function withoutTrailingSharp<T extends string>(url: T): T {
|
||||
return url.replace(/#$/, '') as T
|
||||
}
|
||||
|
||||
/**
|
||||
* Formats an API host URL by normalizing it and optionally appending an API version.
|
||||
*
|
||||
@@ -87,12 +70,12 @@ export function withoutTrailingSharp<T extends string>(url: T): T {
|
||||
* @param apiVersion - The API version to append if needed. Defaults to `'v1'`.
|
||||
*
|
||||
* @returns The formatted API host URL. If the host is empty after normalization, returns an empty string.
|
||||
* If the host ends with '#', API version is not supported, or the host already contains a version, returns the normalized host with trailing '#' removed.
|
||||
* If the host ends with '#', API version is not supported, or the host already contains a version, returns the normalized host as-is.
|
||||
* Otherwise, returns the host with the API version appended.
|
||||
*
|
||||
* @example
|
||||
* formatApiHost('https://api.example.com/') // Returns 'https://api.example.com/v1'
|
||||
* formatApiHost('https://api.example.com#') // Returns 'https://api.example.com'
|
||||
* formatApiHost('https://api.example.com#') // Returns 'https://api.example.com#'
|
||||
* formatApiHost('https://api.example.com/v2', true, 'v1') // Returns 'https://api.example.com/v2'
|
||||
*/
|
||||
export function formatApiHost(host?: string, supportApiVersion: boolean = true, apiVersion: string = 'v1'): string {
|
||||
@@ -101,13 +84,10 @@ export function formatApiHost(host?: string, supportApiVersion: boolean = true,
|
||||
return ''
|
||||
}
|
||||
|
||||
const shouldAppendApiVersion = !(normalizedHost.endsWith('#') || !supportApiVersion || hasAPIVersion(normalizedHost))
|
||||
|
||||
if (shouldAppendApiVersion) {
|
||||
return `${normalizedHost}/${apiVersion}`
|
||||
} else {
|
||||
return withoutTrailingSharp(normalizedHost)
|
||||
if (normalizedHost.endsWith('#') || !supportApiVersion || hasAPIVersion(normalizedHost)) {
|
||||
return normalizedHost
|
||||
}
|
||||
return `${normalizedHost}/${apiVersion}`
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
@@ -183,11 +183,3 @@ export const isSupportAPIVersionProvider = (provider: Provider) => {
|
||||
}
|
||||
return provider.apiOptions?.isNotSupportAPIVersion !== false
|
||||
}
|
||||
|
||||
export const NOT_SUPPORT_API_KEY_PROVIDERS: readonly SystemProviderId[] = [
|
||||
'ollama',
|
||||
'lmstudio',
|
||||
'vertexai',
|
||||
'aws-bedrock',
|
||||
'copilot'
|
||||
]
|
||||
|
||||
Reference in New Issue
Block a user