Compare commits

...

1 Commit

Author SHA1 Message Date
icarus
b3cd1edfdc fix: normalize provider model data 2025-11-30 19:51:27 +08:00
6 changed files with 301 additions and 35 deletions

View File

@@ -7,10 +7,10 @@
* 2. 暂时保持接口兼容性
*/
import type { GatewayLanguageModelEntry } from '@ai-sdk/gateway'
import { createExecutor } from '@cherrystudio/ai-core'
import { loggerService } from '@logger'
import { getEnableDeveloperMode } from '@renderer/hooks/useSettings'
import { normalizeGatewayModels, normalizeSdkModels } from '@renderer/services/models/ModelAdapter'
import { addSpan, endSpan } from '@renderer/services/SpanManagerService'
import type { StartSpanParams } from '@renderer/trace/types/ModelSpanEntity'
import { type Assistant, type GenerateImageParams, type Model, type Provider, SystemProviderIds } from '@renderer/types'
@@ -481,18 +481,11 @@ export default class ModernAiProvider {
// 代理其他方法到原有实现
public async models() {
if (this.actualProvider.id === SystemProviderIds['ai-gateway']) {
const formatModel = function (models: GatewayLanguageModelEntry[]): Model[] {
return models.map((m) => ({
id: m.id,
name: m.name,
provider: 'gateway',
group: m.id.split('/')[0],
description: m.description ?? undefined
}))
}
return formatModel((await gateway.getAvailableModels()).models)
const gatewayModels = (await gateway.getAvailableModels()).models
return normalizeGatewayModels(this.actualProvider, gatewayModels)
}
return this.legacyProvider.models()
const sdkModels = await this.legacyProvider.models()
return normalizeSdkModels(this.actualProvider, sdkModels)
}
public async getEmbeddingDimensions(model: Model): Promise<number> {

View File

@@ -18,7 +18,7 @@ import NewApiAddModelPopup from '@renderer/pages/settings/ProviderSettings/Model
import NewApiBatchAddModelPopup from '@renderer/pages/settings/ProviderSettings/ModelList/NewApiBatchAddModelPopup'
import { fetchModels } from '@renderer/services/ApiService'
import type { Model, Provider } from '@renderer/types'
import { filterModelsByKeywords, getDefaultGroupName, getFancyProviderName } from '@renderer/utils'
import { filterModelsByKeywords, getFancyProviderName } from '@renderer/utils'
import { isFreeModel } from '@renderer/utils/model'
import { isNewApiProvider } from '@renderer/utils/provider'
import { Button, Empty, Flex, Modal, Spin, Tabs, Tooltip } from 'antd'
@@ -183,25 +183,7 @@ const PopupContainer: React.FC<Props> = ({ providerId, resolve }) => {
setLoadingModels(true)
try {
const models = await fetchModels(provider)
// TODO: More robust conversion
const filteredModels = models
.map((model) => ({
// @ts-ignore modelId
id: model?.id || model?.name,
// @ts-ignore name
name: model?.display_name || model?.displayName || model?.name || model?.id,
provider: provider.id,
// @ts-ignore group
group: getDefaultGroupName(model?.id || model?.name, provider.id),
// @ts-ignore description
description: model?.description || '',
// @ts-ignore owned_by
owned_by: model?.owned_by || '',
// @ts-ignore supported_endpoint_types
supported_endpoint_types: model?.supported_endpoint_types
}))
.filter((model) => !isEmpty(model.name))
const filteredModels = models.filter((model) => !isEmpty(model.name))
setListModels(filteredModels)
} catch (error) {
logger.error(`Failed to load models for provider ${getFancyProviderName(provider)}`, error as Error)

View File

@@ -13,7 +13,6 @@ import type { Assistant, MCPServer, MCPTool, Model, Provider } from '@renderer/t
import type { StreamTextParams } from '@renderer/types/aiCoreTypes'
import { type Chunk, ChunkType } from '@renderer/types/chunk'
import type { Message, ResponseError } from '@renderer/types/newMessage'
import type { SdkModel } from '@renderer/types/sdk'
import { removeSpecialCharactersForTopicName, uuid } from '@renderer/utils'
import { abortCompletion, readyToAbort } from '@renderer/utils/abortController'
import { isToolUseModeFunction } from '@renderer/utils/assistant'
@@ -424,7 +423,7 @@ export function hasApiKey(provider: Provider) {
// return undefined
// }
export async function fetchModels(provider: Provider): Promise<SdkModel[]> {
export async function fetchModels(provider: Provider): Promise<Model[]> {
const AI = new AiProviderNew(provider)
try {

View File

@@ -0,0 +1,102 @@
import type { GatewayLanguageModelEntry } from '@ai-sdk/gateway'
import { normalizeGatewayModels, normalizeSdkModels } from '@renderer/services/models/ModelAdapter'
import type { Model, Provider } from '@renderer/types'
import type { EndpointType } from '@renderer/types/index'
import type { SdkModel } from '@renderer/types/sdk'
import { describe, expect, it } from 'vitest'
// Builds a minimal Provider fixture for the adapter tests; individual fields
// can be overridden per test case via `overrides`.
function createProvider(overrides: Partial<Provider> = {}): Provider {
  const base = {
    id: 'openai',
    type: 'openai',
    name: 'OpenAI',
    apiKey: 'test-key',
    apiHost: 'https://example.com/v1',
    models: []
  }
  return { ...base, ...overrides } as Provider
}
// Unit tests for the ModelAdapter normalization layer: SDK/gateway entries in,
// validated internal `Model` objects out.
describe('ModelAdapter', () => {
  // Field aliases such as `display_name` must map onto `name`, and the group
  // is derived from the model id.
  it('adapts generic SDK models into internal models', () => {
    const provider = createProvider({ id: 'openai' })
    const models = normalizeSdkModels(provider, [
      {
        id: 'gpt-4o-mini',
        display_name: 'GPT-4o mini',
        description: 'General purpose model',
        owned_by: 'openai'
      } as unknown as SdkModel
    ])
    expect(models).toHaveLength(1)
    expect(models[0]).toMatchObject({
      id: 'gpt-4o-mini',
      name: 'GPT-4o mini',
      provider: 'openai',
      group: 'gpt-4o',
      description: 'General purpose model',
      owned_by: 'openai'
    } as Partial<Model>)
  })
  // Valid endpoint-type arrays must pass through untouched.
  it('preserves supported endpoint types for New API models', () => {
    const provider = createProvider({ id: 'new-api' })
    const endpointTypes: EndpointType[] = ['openai', 'image-generation']
    const [model] = normalizeSdkModels(provider, [
      {
        id: 'new-api-model',
        name: 'New API Model',
        supported_endpoint_types: endpointTypes
      } as unknown as SdkModel
    ])
    expect(model.supported_endpoint_types).toEqual(endpointTypes)
  })
  // Unknown endpoint types are pruned rather than rejecting the whole model.
  it('filters unsupported endpoint types while keeping valid ones', () => {
    const provider = createProvider({ id: 'new-api' })
    const [model] = normalizeSdkModels(provider, [
      {
        id: 'another-model',
        name: 'Another Model',
        supported_endpoint_types: ['openai', 'unknown-endpoint', 'gemini']
      } as unknown as SdkModel
    ])
    expect(model.supported_endpoint_types).toEqual(['openai', 'gemini'])
  })
  // Gateway entries use the same normalization pipeline; the group comes from
  // the `vendor/model` prefix of the id.
  it('adapts ai-gateway entries through the same adapter', () => {
    const provider = createProvider({ id: 'ai-gateway', type: 'ai-gateway' })
    const [model] = normalizeGatewayModels(provider, [
      {
        id: 'openai/gpt-4o',
        name: 'OpenAI GPT-4o',
        description: 'Gateway entry',
        specification: {
          specificationVersion: 'v2',
          provider: 'openai',
          modelId: 'gpt-4o'
        }
      } as GatewayLanguageModelEntry
    ])
    expect(model).toMatchObject({
      id: 'openai/gpt-4o',
      group: 'openai',
      provider: 'ai-gateway',
      description: 'Gateway entry'
    })
  })
  // Entries with blank id/name are silently dropped, not surfaced to the UI.
  it('drops invalid entries without ids or names', () => {
    const provider = createProvider()
    const models = normalizeSdkModels(provider, [
      {
        id: '',
        name: ''
      } as unknown as SdkModel
    ])
    expect(models).toHaveLength(0)
  })
})

View File

@@ -0,0 +1,180 @@
import type { GatewayLanguageModelEntry } from '@ai-sdk/gateway'
import { loggerService } from '@logger'
import { type EndpointType, EndPointTypeSchema, type Model, type Provider } from '@renderer/types'
import type { NewApiModel, SdkModel } from '@renderer/types/sdk'
import { getDefaultGroupName } from '@renderer/utils/naming'
import * as z from 'zod'
const logger = loggerService.withContext('ModelAdapter')

// At least one endpoint type must survive validation; an empty array is
// treated the same as "not provided" by the callers below.
const EndpointTypeArraySchema = z.array(EndPointTypeSchema).nonempty()

// Canonical shape every normalized model must satisfy before it is handed to
// the rest of the renderer. Blank or whitespace-only id/name/provider/group
// values are rejected.
const NormalizedModelSchema = z.object({
  id: z.string().trim().min(1),
  name: z.string().trim().min(1),
  provider: z.string().trim().min(1),
  group: z.string().trim().min(1),
  description: z.string().optional(),
  owned_by: z.string().optional(),
  supported_endpoint_types: EndpointTypeArraySchema.optional()
})

// Pre-parse shape accepted by `NormalizedModelSchema` (before trimming etc.).
type NormalizedModelInput = z.input<typeof NormalizedModelSchema>
/**
 * Converts the raw model list returned by a provider SDK into the renderer's
 * internal `Model` representation. Invalid entries are dropped and duplicate
 * ids are removed, keeping the first occurrence.
 */
export function normalizeSdkModels(provider: Provider, models: SdkModel[]): Model[] {
  return normalizeModels(models, (sdkModel) => adaptSdkModel(provider, sdkModel))
}
/**
 * Converts AI Gateway catalogue entries into the renderer's internal `Model`
 * representation, with the same dedup/validation rules as the SDK path.
 */
export function normalizeGatewayModels(provider: Provider, models: GatewayLanguageModelEntry[]): Model[] {
  return normalizeModels(models, (gatewayEntry) => adaptGatewayModel(provider, gatewayEntry))
}
/**
 * Runs `transformer` over every raw entry, dropping entries the transformer
 * rejects (`null`) and keeping only the first occurrence of each model id.
 * Input order is preserved (Map iterates in insertion order).
 */
function normalizeModels<T>(models: T[], transformer: (entry: T) => Model | null): Model[] {
  const byId = new Map<string, Model>()
  for (const entry of models) {
    const model = transformer(entry)
    if (model !== null && !byId.has(model.id)) {
      byId.set(model.id, model)
    }
  }
  return [...byId.values()]
}
/**
 * Adapts a single SDK model entry to the internal `Model` shape.
 *
 * Falls back across the common field aliases different providers use
 * (`display_name` / `displayName` / `name`, `owned_by` / `publisher`),
 * derives the group from the id, and validates the result.
 *
 * @returns the normalized model, or `null` (with a warning logged) when the
 *   entry has no usable id/name or fails schema validation.
 */
function adaptSdkModel(provider: Provider, model: SdkModel): Model | null {
  // One null-safe untyped view of the entry instead of scattering
  // `(model as any)?.x` casts through the body.
  const raw = (model ?? {}) as unknown as Record<string, unknown>
  const id = pickPreferredString([raw.id, raw.modelId])
  const name = pickPreferredString([raw.display_name, raw.displayName, raw.name, id])
  if (!id || !name) {
    logger.warn('Skip SDK model with missing id or name', {
      providerId: provider.id,
      modelSnippet: summarizeModel(model)
    })
    return null
  }
  const candidate: NormalizedModelInput = {
    id,
    name,
    provider: provider.id,
    group: getDefaultGroupName(id, provider.id),
    description: pickPreferredString([raw.description, raw.summary]),
    owned_by: pickPreferredString([raw.owned_by, raw.publisher])
  }
  // Only attach endpoint types when at least one valid one is present, so the
  // optional field stays absent (not `undefined`) on the normalized object.
  const supportedEndpointTypes = pickSupportedEndpointTypes(provider.id, model)
  if (supportedEndpointTypes) {
    candidate.supported_endpoint_types = supportedEndpointTypes
  }
  return validateModel(candidate, model)
}
/**
 * Adapts an AI Gateway catalogue entry to the internal `Model` shape.
 * The display name falls back to the id when absent; entries without an id
 * are dropped with a warning.
 */
function adaptGatewayModel(provider: Provider, model: GatewayLanguageModelEntry): Model | null {
  const trimmedId = model?.id?.trim()
  const trimmedName = model?.name?.trim() || trimmedId
  if (!trimmedId || !trimmedName) {
    logger.warn('Skip gateway model with missing id or name', {
      providerId: provider.id,
      modelSnippet: summarizeModel(model)
    })
    return null
  }
  return validateModel(
    {
      id: trimmedId,
      name: trimmedName,
      provider: provider.id,
      group: getDefaultGroupName(trimmedId, provider.id),
      description: model.description ?? undefined
    },
    model
  )
}
/**
 * Returns the first candidate that is a non-empty string after trimming,
 * or `undefined` when none qualifies.
 */
function pickPreferredString(values: Array<unknown>): string | undefined {
  return values
    .filter((value): value is string => typeof value === 'string')
    .map((value) => value.trim())
    .find((value) => value.length > 0)
}
/**
 * Extracts and validates `supported_endpoint_types` from a raw SDK model.
 *
 * Unknown endpoint types are pruned (with a warning) rather than failing the
 * whole model. Returns `undefined` when the field is absent, empty, or no
 * valid endpoint type survives pruning.
 */
function pickSupportedEndpointTypes(providerId: string, model: SdkModel): EndpointType[] | undefined {
  // The original read the exact same property twice via `??` through two
  // different casts — the fallback could never yield a different value.
  // A single read is sufficient.
  const candidate = (model as Partial<NewApiModel>).supported_endpoint_types
  if (!Array.isArray(candidate) || candidate.length === 0) {
    return undefined
  }
  const supported: EndpointType[] = []
  const unsupported: unknown[] = []
  for (const value of candidate) {
    const parsed = EndPointTypeSchema.safeParse(value)
    if (parsed.success) {
      supported.push(parsed.data)
    } else {
      unsupported.push(value)
    }
  }
  if (unsupported.length > 0) {
    logger.warn('Pruned unsupported endpoint types', {
      providerId,
      values: unsupported,
      modelSnippet: summarizeModel(model)
    })
  }
  return supported.length > 0 ? supported : undefined
}
/**
 * Final schema gate: parses the candidate through `NormalizedModelSchema`
 * and returns the typed result, or `null` (with a warning that includes the
 * zod issues and a snippet of the raw source) when invalid.
 */
function validateModel(candidate: NormalizedModelInput, source: unknown): Model | null {
  const result = NormalizedModelSchema.safeParse(candidate)
  if (result.success) {
    return result.data
  }
  logger.warn('Discard invalid model entry', {
    providerId: candidate.provider,
    issues: result.error.issues,
    modelSnippet: summarizeModel(source)
  })
  return null
}
/**
 * Produces a compact, log-safe snippet of a raw model entry: only the handful
 * of identifying fields, never the full payload. Non-object inputs (including
 * null) are returned unchanged.
 */
function summarizeModel(model: unknown) {
  if (!model || typeof model !== 'object') {
    return model
  }
  const source = model as Record<string, unknown>
  const snippetKeys = [
    'id',
    'name',
    'display_name',
    'displayName',
    'description',
    'owned_by',
    'supported_endpoint_types'
  ] as const
  const snippet: Record<string, unknown> = {}
  for (const key of snippetKeys) {
    snippet[key] = source[key]
  }
  return snippet
}

View File

@@ -7,6 +7,8 @@ import type { CSSProperties } from 'react'
export * from './file'
export * from './note'
import * as z from 'zod'
import type { StreamTextParams } from './aiCoreTypes'
import type { Chunk } from './chunk'
import type { FileMetadata } from './file'
@@ -240,7 +242,15 @@ export type ModelType = 'text' | 'vision' | 'embedding' | 'reasoning' | 'functio
export type ModelTag = Exclude<ModelType, 'text'> | 'free'
// "image-generation" is also openai endpoint, but specifically for image generation.
export type EndpointType = 'openai' | 'openai-response' | 'anthropic' | 'gemini' | 'image-generation' | 'jina-rerank'
export const EndPointTypeSchema = z.enum([
'openai',
'openai-response',
'anthropic',
'gemini',
'image-generation',
'jina-rerank'
])
export type EndpointType = z.infer<typeof EndPointTypeSchema>
export type ModelPricing = {
input_per_million_tokens: number