Compare commits
1 Commits
fix/inputb
...
fix/valida
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b3cd1edfdc |
@@ -7,10 +7,10 @@
|
||||
* 2. 暂时保持接口兼容性
|
||||
*/
|
||||
|
||||
import type { GatewayLanguageModelEntry } from '@ai-sdk/gateway'
|
||||
import { createExecutor } from '@cherrystudio/ai-core'
|
||||
import { loggerService } from '@logger'
|
||||
import { getEnableDeveloperMode } from '@renderer/hooks/useSettings'
|
||||
import { normalizeGatewayModels, normalizeSdkModels } from '@renderer/services/models/ModelAdapter'
|
||||
import { addSpan, endSpan } from '@renderer/services/SpanManagerService'
|
||||
import type { StartSpanParams } from '@renderer/trace/types/ModelSpanEntity'
|
||||
import { type Assistant, type GenerateImageParams, type Model, type Provider, SystemProviderIds } from '@renderer/types'
|
||||
@@ -481,18 +481,11 @@ export default class ModernAiProvider {
|
||||
// 代理其他方法到原有实现
|
||||
public async models() {
|
||||
if (this.actualProvider.id === SystemProviderIds['ai-gateway']) {
|
||||
const formatModel = function (models: GatewayLanguageModelEntry[]): Model[] {
|
||||
return models.map((m) => ({
|
||||
id: m.id,
|
||||
name: m.name,
|
||||
provider: 'gateway',
|
||||
group: m.id.split('/')[0],
|
||||
description: m.description ?? undefined
|
||||
}))
|
||||
}
|
||||
return formatModel((await gateway.getAvailableModels()).models)
|
||||
const gatewayModels = (await gateway.getAvailableModels()).models
|
||||
return normalizeGatewayModels(this.actualProvider, gatewayModels)
|
||||
}
|
||||
return this.legacyProvider.models()
|
||||
const sdkModels = await this.legacyProvider.models()
|
||||
return normalizeSdkModels(this.actualProvider, sdkModels)
|
||||
}
|
||||
|
||||
public async getEmbeddingDimensions(model: Model): Promise<number> {
|
||||
|
||||
@@ -18,7 +18,7 @@ import NewApiAddModelPopup from '@renderer/pages/settings/ProviderSettings/Model
|
||||
import NewApiBatchAddModelPopup from '@renderer/pages/settings/ProviderSettings/ModelList/NewApiBatchAddModelPopup'
|
||||
import { fetchModels } from '@renderer/services/ApiService'
|
||||
import type { Model, Provider } from '@renderer/types'
|
||||
import { filterModelsByKeywords, getDefaultGroupName, getFancyProviderName } from '@renderer/utils'
|
||||
import { filterModelsByKeywords, getFancyProviderName } from '@renderer/utils'
|
||||
import { isFreeModel } from '@renderer/utils/model'
|
||||
import { isNewApiProvider } from '@renderer/utils/provider'
|
||||
import { Button, Empty, Flex, Modal, Spin, Tabs, Tooltip } from 'antd'
|
||||
@@ -183,25 +183,7 @@ const PopupContainer: React.FC<Props> = ({ providerId, resolve }) => {
|
||||
setLoadingModels(true)
|
||||
try {
|
||||
const models = await fetchModels(provider)
|
||||
// TODO: More robust conversion
|
||||
const filteredModels = models
|
||||
.map((model) => ({
|
||||
// @ts-ignore modelId
|
||||
id: model?.id || model?.name,
|
||||
// @ts-ignore name
|
||||
name: model?.display_name || model?.displayName || model?.name || model?.id,
|
||||
provider: provider.id,
|
||||
// @ts-ignore group
|
||||
group: getDefaultGroupName(model?.id || model?.name, provider.id),
|
||||
// @ts-ignore description
|
||||
description: model?.description || '',
|
||||
// @ts-ignore owned_by
|
||||
owned_by: model?.owned_by || '',
|
||||
// @ts-ignore supported_endpoint_types
|
||||
supported_endpoint_types: model?.supported_endpoint_types
|
||||
}))
|
||||
.filter((model) => !isEmpty(model.name))
|
||||
|
||||
const filteredModels = models.filter((model) => !isEmpty(model.name))
|
||||
setListModels(filteredModels)
|
||||
} catch (error) {
|
||||
logger.error(`Failed to load models for provider ${getFancyProviderName(provider)}`, error as Error)
|
||||
|
||||
@@ -13,7 +13,6 @@ import type { Assistant, MCPServer, MCPTool, Model, Provider } from '@renderer/t
|
||||
import type { StreamTextParams } from '@renderer/types/aiCoreTypes'
|
||||
import { type Chunk, ChunkType } from '@renderer/types/chunk'
|
||||
import type { Message, ResponseError } from '@renderer/types/newMessage'
|
||||
import type { SdkModel } from '@renderer/types/sdk'
|
||||
import { removeSpecialCharactersForTopicName, uuid } from '@renderer/utils'
|
||||
import { abortCompletion, readyToAbort } from '@renderer/utils/abortController'
|
||||
import { isToolUseModeFunction } from '@renderer/utils/assistant'
|
||||
@@ -424,7 +423,7 @@ export function hasApiKey(provider: Provider) {
|
||||
// return undefined
|
||||
// }
|
||||
|
||||
export async function fetchModels(provider: Provider): Promise<SdkModel[]> {
|
||||
export async function fetchModels(provider: Provider): Promise<Model[]> {
|
||||
const AI = new AiProviderNew(provider)
|
||||
|
||||
try {
|
||||
|
||||
102
src/renderer/src/services/__tests__/ModelAdapter.test.ts
Normal file
102
src/renderer/src/services/__tests__/ModelAdapter.test.ts
Normal file
@@ -0,0 +1,102 @@
|
||||
import type { GatewayLanguageModelEntry } from '@ai-sdk/gateway'
|
||||
import { normalizeGatewayModels, normalizeSdkModels } from '@renderer/services/models/ModelAdapter'
|
||||
import type { Model, Provider } from '@renderer/types'
|
||||
import type { EndpointType } from '@renderer/types/index'
|
||||
import type { SdkModel } from '@renderer/types/sdk'
|
||||
import { describe, expect, it } from 'vitest'
|
||||
|
||||
const createProvider = (overrides: Partial<Provider> = {}): Provider => ({
|
||||
id: 'openai',
|
||||
type: 'openai',
|
||||
name: 'OpenAI',
|
||||
apiKey: 'test-key',
|
||||
apiHost: 'https://example.com/v1',
|
||||
models: [],
|
||||
...overrides
|
||||
})
|
||||
|
||||
describe('ModelAdapter', () => {
|
||||
it('adapts generic SDK models into internal models', () => {
|
||||
const provider = createProvider({ id: 'openai' })
|
||||
const models = normalizeSdkModels(provider, [
|
||||
{
|
||||
id: 'gpt-4o-mini',
|
||||
display_name: 'GPT-4o mini',
|
||||
description: 'General purpose model',
|
||||
owned_by: 'openai'
|
||||
} as unknown as SdkModel
|
||||
])
|
||||
|
||||
expect(models).toHaveLength(1)
|
||||
expect(models[0]).toMatchObject({
|
||||
id: 'gpt-4o-mini',
|
||||
name: 'GPT-4o mini',
|
||||
provider: 'openai',
|
||||
group: 'gpt-4o',
|
||||
description: 'General purpose model',
|
||||
owned_by: 'openai'
|
||||
} as Partial<Model>)
|
||||
})
|
||||
|
||||
it('preserves supported endpoint types for New API models', () => {
|
||||
const provider = createProvider({ id: 'new-api' })
|
||||
const endpointTypes: EndpointType[] = ['openai', 'image-generation']
|
||||
const [model] = normalizeSdkModels(provider, [
|
||||
{
|
||||
id: 'new-api-model',
|
||||
name: 'New API Model',
|
||||
supported_endpoint_types: endpointTypes
|
||||
} as unknown as SdkModel
|
||||
])
|
||||
|
||||
expect(model.supported_endpoint_types).toEqual(endpointTypes)
|
||||
})
|
||||
|
||||
it('filters unsupported endpoint types while keeping valid ones', () => {
|
||||
const provider = createProvider({ id: 'new-api' })
|
||||
const [model] = normalizeSdkModels(provider, [
|
||||
{
|
||||
id: 'another-model',
|
||||
name: 'Another Model',
|
||||
supported_endpoint_types: ['openai', 'unknown-endpoint', 'gemini']
|
||||
} as unknown as SdkModel
|
||||
])
|
||||
|
||||
expect(model.supported_endpoint_types).toEqual(['openai', 'gemini'])
|
||||
})
|
||||
|
||||
it('adapts ai-gateway entries through the same adapter', () => {
|
||||
const provider = createProvider({ id: 'ai-gateway', type: 'ai-gateway' })
|
||||
const [model] = normalizeGatewayModels(provider, [
|
||||
{
|
||||
id: 'openai/gpt-4o',
|
||||
name: 'OpenAI GPT-4o',
|
||||
description: 'Gateway entry',
|
||||
specification: {
|
||||
specificationVersion: 'v2',
|
||||
provider: 'openai',
|
||||
modelId: 'gpt-4o'
|
||||
}
|
||||
} as GatewayLanguageModelEntry
|
||||
])
|
||||
|
||||
expect(model).toMatchObject({
|
||||
id: 'openai/gpt-4o',
|
||||
group: 'openai',
|
||||
provider: 'ai-gateway',
|
||||
description: 'Gateway entry'
|
||||
})
|
||||
})
|
||||
|
||||
it('drops invalid entries without ids or names', () => {
|
||||
const provider = createProvider()
|
||||
const models = normalizeSdkModels(provider, [
|
||||
{
|
||||
id: '',
|
||||
name: ''
|
||||
} as unknown as SdkModel
|
||||
])
|
||||
|
||||
expect(models).toHaveLength(0)
|
||||
})
|
||||
})
|
||||
180
src/renderer/src/services/models/ModelAdapter.ts
Normal file
180
src/renderer/src/services/models/ModelAdapter.ts
Normal file
@@ -0,0 +1,180 @@
|
||||
import type { GatewayLanguageModelEntry } from '@ai-sdk/gateway'
|
||||
import { loggerService } from '@logger'
|
||||
import { type EndpointType, EndPointTypeSchema, type Model, type Provider } from '@renderer/types'
|
||||
import type { NewApiModel, SdkModel } from '@renderer/types/sdk'
|
||||
import { getDefaultGroupName } from '@renderer/utils/naming'
|
||||
import * as z from 'zod'
|
||||
|
||||
const logger = loggerService.withContext('ModelAdapter')
|
||||
|
||||
const EndpointTypeArraySchema = z.array(EndPointTypeSchema).nonempty()
|
||||
|
||||
const NormalizedModelSchema = z.object({
|
||||
id: z.string().trim().min(1),
|
||||
name: z.string().trim().min(1),
|
||||
provider: z.string().trim().min(1),
|
||||
group: z.string().trim().min(1),
|
||||
description: z.string().optional(),
|
||||
owned_by: z.string().optional(),
|
||||
supported_endpoint_types: EndpointTypeArraySchema.optional()
|
||||
})
|
||||
|
||||
type NormalizedModelInput = z.input<typeof NormalizedModelSchema>
|
||||
|
||||
export function normalizeSdkModels(provider: Provider, models: SdkModel[]): Model[] {
|
||||
return normalizeModels(models, (entry) => adaptSdkModel(provider, entry))
|
||||
}
|
||||
|
||||
export function normalizeGatewayModels(provider: Provider, models: GatewayLanguageModelEntry[]): Model[] {
|
||||
return normalizeModels(models, (entry) => adaptGatewayModel(provider, entry))
|
||||
}
|
||||
|
||||
function normalizeModels<T>(models: T[], transformer: (entry: T) => Model | null): Model[] {
|
||||
const uniqueModels: Model[] = []
|
||||
const seen = new Set<string>()
|
||||
|
||||
for (const entry of models) {
|
||||
const normalized = transformer(entry)
|
||||
if (!normalized) continue
|
||||
if (seen.has(normalized.id)) continue
|
||||
seen.add(normalized.id)
|
||||
uniqueModels.push(normalized)
|
||||
}
|
||||
|
||||
return uniqueModels
|
||||
}
|
||||
|
||||
function adaptSdkModel(provider: Provider, model: SdkModel): Model | null {
|
||||
const id = pickPreferredString([(model as any)?.id, (model as any)?.modelId])
|
||||
const name = pickPreferredString([
|
||||
(model as any)?.display_name,
|
||||
(model as any)?.displayName,
|
||||
(model as any)?.name,
|
||||
id
|
||||
])
|
||||
|
||||
if (!id || !name) {
|
||||
logger.warn('Skip SDK model with missing id or name', {
|
||||
providerId: provider.id,
|
||||
modelSnippet: summarizeModel(model)
|
||||
})
|
||||
return null
|
||||
}
|
||||
|
||||
const candidate: NormalizedModelInput = {
|
||||
id,
|
||||
name,
|
||||
provider: provider.id,
|
||||
group: getDefaultGroupName(id, provider.id),
|
||||
description: pickPreferredString([(model as any)?.description, (model as any)?.summary]),
|
||||
owned_by: pickPreferredString([(model as any)?.owned_by, (model as any)?.publisher])
|
||||
}
|
||||
|
||||
const supportedEndpointTypes = pickSupportedEndpointTypes(provider.id, model)
|
||||
if (supportedEndpointTypes) {
|
||||
candidate.supported_endpoint_types = supportedEndpointTypes
|
||||
}
|
||||
|
||||
return validateModel(candidate, model)
|
||||
}
|
||||
|
||||
function adaptGatewayModel(provider: Provider, model: GatewayLanguageModelEntry): Model | null {
|
||||
const id = model?.id?.trim()
|
||||
const name = model?.name?.trim() || id
|
||||
|
||||
if (!id || !name) {
|
||||
logger.warn('Skip gateway model with missing id or name', {
|
||||
providerId: provider.id,
|
||||
modelSnippet: summarizeModel(model)
|
||||
})
|
||||
return null
|
||||
}
|
||||
|
||||
const candidate: NormalizedModelInput = {
|
||||
id,
|
||||
name,
|
||||
provider: provider.id,
|
||||
group: getDefaultGroupName(id, provider.id),
|
||||
description: model.description ?? undefined
|
||||
}
|
||||
|
||||
return validateModel(candidate, model)
|
||||
}
|
||||
|
||||
function pickPreferredString(values: Array<unknown>): string | undefined {
|
||||
for (const value of values) {
|
||||
if (typeof value === 'string') {
|
||||
const trimmed = value.trim()
|
||||
if (trimmed.length > 0) {
|
||||
return trimmed
|
||||
}
|
||||
}
|
||||
}
|
||||
return undefined
|
||||
}
|
||||
|
||||
function pickSupportedEndpointTypes(providerId: string, model: SdkModel): EndpointType[] | undefined {
|
||||
const candidate =
|
||||
(model as Partial<NewApiModel>).supported_endpoint_types ??
|
||||
((model as Record<string, unknown>).supported_endpoint_types as EndpointType[] | undefined)
|
||||
|
||||
if (!Array.isArray(candidate) || candidate.length === 0) {
|
||||
return undefined
|
||||
}
|
||||
|
||||
const supported: EndpointType[] = []
|
||||
const unsupported: unknown[] = []
|
||||
|
||||
for (const value of candidate) {
|
||||
const parsed = EndPointTypeSchema.safeParse(value)
|
||||
if (parsed.success) {
|
||||
supported.push(parsed.data)
|
||||
} else {
|
||||
unsupported.push(value)
|
||||
}
|
||||
}
|
||||
|
||||
if (unsupported.length > 0) {
|
||||
logger.warn('Pruned unsupported endpoint types', {
|
||||
providerId,
|
||||
values: unsupported,
|
||||
modelSnippet: summarizeModel(model)
|
||||
})
|
||||
}
|
||||
|
||||
return supported.length > 0 ? supported : undefined
|
||||
}
|
||||
|
||||
function validateModel(candidate: NormalizedModelInput, source: unknown): Model | null {
|
||||
const parsed = NormalizedModelSchema.safeParse(candidate)
|
||||
if (!parsed.success) {
|
||||
logger.warn('Discard invalid model entry', {
|
||||
providerId: candidate.provider,
|
||||
issues: parsed.error.issues,
|
||||
modelSnippet: summarizeModel(source)
|
||||
})
|
||||
return null
|
||||
}
|
||||
|
||||
return parsed.data
|
||||
}
|
||||
|
||||
function summarizeModel(model: unknown) {
|
||||
if (!model || typeof model !== 'object') {
|
||||
return model
|
||||
}
|
||||
const { id, name, display_name, displayName, description, owned_by, supported_endpoint_types } = model as Record<
|
||||
string,
|
||||
unknown
|
||||
>
|
||||
|
||||
return {
|
||||
id,
|
||||
name,
|
||||
display_name,
|
||||
displayName,
|
||||
description,
|
||||
owned_by,
|
||||
supported_endpoint_types
|
||||
}
|
||||
}
|
||||
@@ -7,6 +7,8 @@ import type { CSSProperties } from 'react'
|
||||
export * from './file'
|
||||
export * from './note'
|
||||
|
||||
import * as z from 'zod'
|
||||
|
||||
import type { StreamTextParams } from './aiCoreTypes'
|
||||
import type { Chunk } from './chunk'
|
||||
import type { FileMetadata } from './file'
|
||||
@@ -240,7 +242,15 @@ export type ModelType = 'text' | 'vision' | 'embedding' | 'reasoning' | 'functio
|
||||
export type ModelTag = Exclude<ModelType, 'text'> | 'free'
|
||||
|
||||
// "image-generation" is also openai endpoint, but specifically for image generation.
|
||||
export type EndpointType = 'openai' | 'openai-response' | 'anthropic' | 'gemini' | 'image-generation' | 'jina-rerank'
|
||||
export const EndPointTypeSchema = z.enum([
|
||||
'openai',
|
||||
'openai-response',
|
||||
'anthropic',
|
||||
'gemini',
|
||||
'image-generation',
|
||||
'jina-rerank'
|
||||
])
|
||||
export type EndpointType = z.infer<typeof EndPointTypeSchema>
|
||||
|
||||
export type ModelPricing = {
|
||||
input_per_million_tokens: number
|
||||
|
||||
Reference in New Issue
Block a user