Compare commits
2 Commits
main
...
copilot/fi
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b14e48dd78 | ||
|
|
64fde27f9e |
@@ -107,7 +107,7 @@ export async function buildStreamTextParams(
|
|||||||
searchWithTime: store.getState().websearch.searchWithTime
|
searchWithTime: store.getState().websearch.searchWithTime
|
||||||
}
|
}
|
||||||
|
|
||||||
const { providerOptions, standardParams } = buildProviderOptions(assistant, model, provider, {
|
const { providerOptions, standardParams, bodyParams } = buildProviderOptions(assistant, model, provider, {
|
||||||
enableReasoning,
|
enableReasoning,
|
||||||
enableWebSearch,
|
enableWebSearch,
|
||||||
enableGenerateImage
|
enableGenerateImage
|
||||||
@@ -185,6 +185,7 @@ export async function buildStreamTextParams(
|
|||||||
// Note: standardParams (topK, frequencyPenalty, presencePenalty, stopSequences, seed)
|
// Note: standardParams (topK, frequencyPenalty, presencePenalty, stopSequences, seed)
|
||||||
// are extracted from custom parameters and passed directly to streamText()
|
// are extracted from custom parameters and passed directly to streamText()
|
||||||
// instead of being placed in providerOptions
|
// instead of being placed in providerOptions
|
||||||
|
// Note: bodyParams are custom parameters for AI Gateway that should be at body level
|
||||||
const params: StreamTextParams = {
|
const params: StreamTextParams = {
|
||||||
messages: sdkMessages,
|
messages: sdkMessages,
|
||||||
maxOutputTokens: getMaxTokens(assistant, model),
|
maxOutputTokens: getMaxTokens(assistant, model),
|
||||||
@@ -192,6 +193,8 @@ export async function buildStreamTextParams(
|
|||||||
topP: getTopP(assistant, model),
|
topP: getTopP(assistant, model),
|
||||||
// Include AI SDK standard params extracted from custom parameters
|
// Include AI SDK standard params extracted from custom parameters
|
||||||
...standardParams,
|
...standardParams,
|
||||||
|
// Include body-level params for AI Gateway custom parameters
|
||||||
|
...bodyParams,
|
||||||
abortSignal: options.requestOptions?.signal,
|
abortSignal: options.requestOptions?.signal,
|
||||||
headers,
|
headers,
|
||||||
providerOptions,
|
providerOptions,
|
||||||
|
|||||||
@@ -37,7 +37,7 @@ vi.mock('@cherrystudio/ai-core/provider', async (importOriginal) => {
|
|||||||
},
|
},
|
||||||
customProviderIdSchema: {
|
customProviderIdSchema: {
|
||||||
safeParse: vi.fn((id) => {
|
safeParse: vi.fn((id) => {
|
||||||
const customProviders = ['google-vertex', 'google-vertex-anthropic', 'bedrock']
|
const customProviders = ['google-vertex', 'google-vertex-anthropic', 'bedrock', 'ai-gateway']
|
||||||
if (customProviders.includes(id)) {
|
if (customProviders.includes(id)) {
|
||||||
return { success: true, data: id }
|
return { success: true, data: id }
|
||||||
}
|
}
|
||||||
@@ -56,7 +56,8 @@ vi.mock('../provider/factory', () => ({
|
|||||||
[SystemProviderIds.anthropic]: 'anthropic',
|
[SystemProviderIds.anthropic]: 'anthropic',
|
||||||
[SystemProviderIds.grok]: 'xai',
|
[SystemProviderIds.grok]: 'xai',
|
||||||
[SystemProviderIds.deepseek]: 'deepseek',
|
[SystemProviderIds.deepseek]: 'deepseek',
|
||||||
[SystemProviderIds.openrouter]: 'openrouter'
|
[SystemProviderIds.openrouter]: 'openrouter',
|
||||||
|
[SystemProviderIds['ai-gateway']]: 'ai-gateway'
|
||||||
}
|
}
|
||||||
return mapping[provider.id] || provider.id
|
return mapping[provider.id] || provider.id
|
||||||
})
|
})
|
||||||
@@ -204,6 +205,8 @@ describe('options utils', () => {
|
|||||||
expect(result.providerOptions).toHaveProperty('openai')
|
expect(result.providerOptions).toHaveProperty('openai')
|
||||||
expect(result.providerOptions.openai).toBeDefined()
|
expect(result.providerOptions.openai).toBeDefined()
|
||||||
expect(result.standardParams).toBeDefined()
|
expect(result.standardParams).toBeDefined()
|
||||||
|
expect(result.bodyParams).toBeDefined()
|
||||||
|
expect(result.bodyParams).toEqual({})
|
||||||
})
|
})
|
||||||
|
|
||||||
it('should include reasoning parameters when enabled', () => {
|
it('should include reasoning parameters when enabled', () => {
|
||||||
@@ -696,5 +699,90 @@ describe('options utils', () => {
|
|||||||
})
|
})
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
|
describe('AI Gateway provider', () => {
|
||||||
|
const aiGatewayProvider: Provider = {
|
||||||
|
id: SystemProviderIds['ai-gateway'],
|
||||||
|
name: 'AI Gateway',
|
||||||
|
type: 'ai-gateway',
|
||||||
|
apiKey: 'test-key',
|
||||||
|
apiHost: 'https://ai-gateway.vercel.sh/v1/ai',
|
||||||
|
isSystem: true,
|
||||||
|
models: [] as Model[]
|
||||||
|
} as Provider
|
||||||
|
|
||||||
|
const aiGatewayModel: Model = {
|
||||||
|
id: 'openai/gpt-4',
|
||||||
|
name: 'GPT-4',
|
||||||
|
provider: SystemProviderIds['ai-gateway']
|
||||||
|
} as Model
|
||||||
|
|
||||||
|
it('should build basic AI Gateway options with empty bodyParams', () => {
|
||||||
|
const result = buildProviderOptions(mockAssistant, aiGatewayModel, aiGatewayProvider, {
|
||||||
|
enableReasoning: false,
|
||||||
|
enableWebSearch: false,
|
||||||
|
enableGenerateImage: false
|
||||||
|
})
|
||||||
|
|
||||||
|
expect(result.providerOptions).toHaveProperty('gateway')
|
||||||
|
expect(result.providerOptions.gateway).toBeDefined()
|
||||||
|
expect(result.bodyParams).toEqual({})
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should place custom parameters in bodyParams for AI Gateway instead of providerOptions', async () => {
|
||||||
|
const { getCustomParameters } = await import('../reasoning')
|
||||||
|
|
||||||
|
vi.mocked(getCustomParameters).mockReturnValue({
|
||||||
|
tools: [{ id: 'openai.image_generation' }],
|
||||||
|
custom_param: 'custom_value'
|
||||||
|
})
|
||||||
|
|
||||||
|
const result = buildProviderOptions(mockAssistant, aiGatewayModel, aiGatewayProvider, {
|
||||||
|
enableReasoning: false,
|
||||||
|
enableWebSearch: false,
|
||||||
|
enableGenerateImage: false
|
||||||
|
})
|
||||||
|
|
||||||
|
// Custom parameters should be in bodyParams, NOT in providerOptions.gateway
|
||||||
|
expect(result.bodyParams).toHaveProperty('tools')
|
||||||
|
expect(result.bodyParams.tools).toEqual([{ id: 'openai.image_generation' }])
|
||||||
|
expect(result.bodyParams).toHaveProperty('custom_param')
|
||||||
|
expect(result.bodyParams.custom_param).toBe('custom_value')
|
||||||
|
|
||||||
|
// providerOptions.gateway should NOT contain custom parameters
|
||||||
|
expect(result.providerOptions.gateway).not.toHaveProperty('tools')
|
||||||
|
expect(result.providerOptions.gateway).not.toHaveProperty('custom_param')
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should still extract AI SDK standard params from custom parameters for AI Gateway', async () => {
|
||||||
|
const { getCustomParameters } = await import('../reasoning')
|
||||||
|
|
||||||
|
vi.mocked(getCustomParameters).mockReturnValue({
|
||||||
|
topK: 5,
|
||||||
|
frequencyPenalty: 0.5,
|
||||||
|
tools: [{ id: 'openai.image_generation' }]
|
||||||
|
})
|
||||||
|
|
||||||
|
const result = buildProviderOptions(mockAssistant, aiGatewayModel, aiGatewayProvider, {
|
||||||
|
enableReasoning: false,
|
||||||
|
enableWebSearch: false,
|
||||||
|
enableGenerateImage: false
|
||||||
|
})
|
||||||
|
|
||||||
|
// Standard params should be extracted and returned separately
|
||||||
|
expect(result.standardParams).toEqual({
|
||||||
|
topK: 5,
|
||||||
|
frequencyPenalty: 0.5
|
||||||
|
})
|
||||||
|
|
||||||
|
// Custom params (non-standard) should be in bodyParams
|
||||||
|
expect(result.bodyParams).toHaveProperty('tools')
|
||||||
|
expect(result.bodyParams.tools).toEqual([{ id: 'openai.image_generation' }])
|
||||||
|
|
||||||
|
// Neither should be in providerOptions.gateway
|
||||||
|
expect(result.providerOptions.gateway).not.toHaveProperty('topK')
|
||||||
|
expect(result.providerOptions.gateway).not.toHaveProperty('tools')
|
||||||
|
})
|
||||||
|
})
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -155,6 +155,7 @@ export function buildProviderOptions(
|
|||||||
): {
|
): {
|
||||||
providerOptions: Record<string, Record<string, JSONValue>>
|
providerOptions: Record<string, Record<string, JSONValue>>
|
||||||
standardParams: Partial<Record<AiSdkParam, any>>
|
standardParams: Partial<Record<AiSdkParam, any>>
|
||||||
|
bodyParams: Record<string, any>
|
||||||
} {
|
} {
|
||||||
logger.debug('buildProviderOptions', { assistant, model, actualProvider, capabilities })
|
logger.debug('buildProviderOptions', { assistant, model, actualProvider, capabilities })
|
||||||
const rawProviderId = getAiSdkProviderId(actualProvider)
|
const rawProviderId = getAiSdkProviderId(actualProvider)
|
||||||
@@ -253,12 +254,6 @@ export function buildProviderOptions(
|
|||||||
const customParams = getCustomParameters(assistant)
|
const customParams = getCustomParameters(assistant)
|
||||||
const { standardParams, providerParams } = extractAiSdkStandardParams(customParams)
|
const { standardParams, providerParams } = extractAiSdkStandardParams(customParams)
|
||||||
|
|
||||||
// 合并 provider 特定的自定义参数到 providerSpecificOptions
|
|
||||||
providerSpecificOptions = {
|
|
||||||
...providerSpecificOptions,
|
|
||||||
...providerParams
|
|
||||||
}
|
|
||||||
|
|
||||||
let rawProviderKey =
|
let rawProviderKey =
|
||||||
{
|
{
|
||||||
'google-vertex': 'google',
|
'google-vertex': 'google',
|
||||||
@@ -273,12 +268,27 @@ export function buildProviderOptions(
|
|||||||
rawProviderKey = { gemini: 'google', ['openai-response']: 'openai' }[actualProvider.type] || actualProvider.type
|
rawProviderKey = { gemini: 'google', ['openai-response']: 'openai' }[actualProvider.type] || actualProvider.type
|
||||||
}
|
}
|
||||||
|
|
||||||
// 返回 AI Core SDK 要求的格式:{ 'providerId': providerOptions } 以及提取的标准参数
|
// For AI Gateway, custom parameters should be placed at body level, not inside providerOptions.gateway
|
||||||
|
// See: https://github.com/CherryHQ/cherry-studio/issues/4197
|
||||||
|
let bodyParams: Record<string, any> = {}
|
||||||
|
if (rawProviderKey === 'gateway') {
|
||||||
|
// Custom parameters go to body level for AI Gateway
|
||||||
|
bodyParams = providerParams
|
||||||
|
} else {
|
||||||
|
// For other providers, merge custom parameters into providerSpecificOptions
|
||||||
|
providerSpecificOptions = {
|
||||||
|
...providerSpecificOptions,
|
||||||
|
...providerParams
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 返回 AI Core SDK 要求的格式:{ 'providerId': providerOptions } 以及提取的标准参数和 body 参数
|
||||||
return {
|
return {
|
||||||
providerOptions: {
|
providerOptions: {
|
||||||
[rawProviderKey]: providerSpecificOptions
|
[rawProviderKey]: providerSpecificOptions
|
||||||
},
|
},
|
||||||
standardParams
|
standardParams,
|
||||||
|
bodyParams
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user