feat(models): update models filtering to use providerType and enhance API schemas
This commit is contained in:
@@ -1,7 +1,8 @@
|
||||
import { ApiModelsFilterSchema } from '@types'
|
||||
import express, { Request, Response } from 'express'
|
||||
|
||||
import { loggerService } from '../../services/LoggerService'
|
||||
import { ModelsFilterSchema, modelsService } from '../services/models'
|
||||
import { modelsService } from '../services/models'
|
||||
|
||||
const logger = loggerService.withContext('ApiServerModelsRoutes')
|
||||
|
||||
@@ -17,10 +18,10 @@ const router = express
|
||||
* tags: [Models]
|
||||
* parameters:
|
||||
* - in: query
|
||||
* name: provider
|
||||
* name: providerType
|
||||
* schema:
|
||||
* type: string
|
||||
* enum: [openai, anthropic]
|
||||
* enum: [openai, openai-response, anthropic, gemini]
|
||||
* description: Filter models by provider type
|
||||
* - in: query
|
||||
* name: offset
|
||||
@@ -77,7 +78,7 @@ const router = express
|
||||
logger.info('Models list request received', { query: req.query })
|
||||
|
||||
// Validate query parameters using Zod schema
|
||||
const filterResult = ModelsFilterSchema.safeParse(req.query)
|
||||
const filterResult = ApiModelsFilterSchema.safeParse(req.query)
|
||||
|
||||
if (!filterResult.success) {
|
||||
logger.warn('Invalid query parameters:', filterResult.error.issues)
|
||||
|
||||
@@ -1,53 +1,43 @@
|
||||
import {
|
||||
ApiModelsRequest,
|
||||
ApiModelsRequestSchema,
|
||||
ApiModelsResponse,
|
||||
OpenAICompatibleModel
|
||||
} from '../../../renderer/src/types/apiModels'
|
||||
import { ApiModel, ApiModelsRequest, ApiModelsResponse } from '../../../renderer/src/types/apiModels'
|
||||
import { loggerService } from '../../services/LoggerService'
|
||||
import { getAvailableProviders, listAllAvailableModels, transformModelToOpenAI } from '../utils'
|
||||
|
||||
const logger = loggerService.withContext('ModelsService')
|
||||
|
||||
// Re-export for backward compatibility
|
||||
export const ModelsFilterSchema = ApiModelsRequestSchema
|
||||
|
||||
export type ModelsFilter = ApiModelsRequest
|
||||
|
||||
export class ModelsService {
|
||||
async getModels(filter?: ModelsFilter): Promise<ApiModelsResponse> {
|
||||
async getModels(filter: ModelsFilter): Promise<ApiModelsResponse> {
|
||||
try {
|
||||
logger.info('Getting available models from providers', { filter })
|
||||
logger.debug('Getting available models from providers', { filter })
|
||||
|
||||
const models = await listAllAvailableModels()
|
||||
const providers = await getAvailableProviders()
|
||||
|
||||
// Use Map to deduplicate models by their full ID (provider:model_id)
|
||||
const uniqueModels = new Map<string, OpenAICompatibleModel>()
|
||||
const uniqueModels = new Map<string, ApiModel>()
|
||||
|
||||
for (const model of models) {
|
||||
const openAIModel = transformModelToOpenAI(model)
|
||||
const openAIModel = transformModelToOpenAI(model, providers)
|
||||
const fullModelId = openAIModel.id // This is already in format "provider:model_id"
|
||||
|
||||
// Only add if not already present (first occurrence wins)
|
||||
if (!uniqueModels.has(fullModelId)) {
|
||||
uniqueModels.set(fullModelId, {
|
||||
...openAIModel,
|
||||
name: model.name
|
||||
})
|
||||
uniqueModels.set(fullModelId, openAIModel)
|
||||
} else {
|
||||
logger.debug(`Skipping duplicate model: ${fullModelId}`)
|
||||
}
|
||||
}
|
||||
|
||||
let modelData = Array.from(uniqueModels.values())
|
||||
|
||||
// Apply filters
|
||||
if (filter?.provider) {
|
||||
const providerType = filter.provider
|
||||
if (filter.providerType) {
|
||||
// Apply filters
|
||||
const providerType = filter.providerType
|
||||
modelData = modelData.filter((model) => {
|
||||
// Find the provider for this model and check its type
|
||||
const provider = providers.find((p) => p.id === model.provider)
|
||||
return provider && provider.type === providerType
|
||||
return model.provider_type === providerType
|
||||
})
|
||||
logger.debug(`Filtered by provider type '${providerType}': ${modelData.length} models`)
|
||||
}
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import { loggerService } from '@main/services/LoggerService'
|
||||
import { reduxService } from '@main/services/ReduxService'
|
||||
import { Model, OpenAICompatibleModel, Provider } from '@types'
|
||||
import { ApiModel, Model, Provider } from '@types'
|
||||
|
||||
const logger = loggerService.withContext('ApiServerUtils')
|
||||
|
||||
@@ -173,7 +173,8 @@ export async function validateModelId(
|
||||
}
|
||||
}
|
||||
|
||||
export function transformModelToOpenAI(model: Model): OpenAICompatibleModel {
|
||||
export function transformModelToOpenAI(model: Model, providers: Provider[]): ApiModel {
|
||||
const provider = providers.find((p) => p.id === model.provider)
|
||||
return {
|
||||
id: `${model.provider}:${model.id}`,
|
||||
object: 'model',
|
||||
@@ -181,6 +182,7 @@ export function transformModelToOpenAI(model: Model): OpenAICompatibleModel {
|
||||
created: Math.floor(Date.now() / 1000),
|
||||
owned_by: model.owned_by || model.provider,
|
||||
provider: model.provider,
|
||||
provider_type: provider?.type,
|
||||
provider_model_id: model.id
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user