Compare commits
1 Commits
copilot/fi
...
fix/inputb
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
776fb27a58 |
@@ -107,7 +107,7 @@ export async function buildStreamTextParams(
|
||||
searchWithTime: store.getState().websearch.searchWithTime
|
||||
}
|
||||
|
||||
const { providerOptions, standardParams, bodyParams } = buildProviderOptions(assistant, model, provider, {
|
||||
const { providerOptions, standardParams } = buildProviderOptions(assistant, model, provider, {
|
||||
enableReasoning,
|
||||
enableWebSearch,
|
||||
enableGenerateImage
|
||||
@@ -185,7 +185,6 @@ export async function buildStreamTextParams(
|
||||
// Note: standardParams (topK, frequencyPenalty, presencePenalty, stopSequences, seed)
|
||||
// are extracted from custom parameters and passed directly to streamText()
|
||||
// instead of being placed in providerOptions
|
||||
// Note: bodyParams are custom parameters for AI Gateway that should be at body level
|
||||
const params: StreamTextParams = {
|
||||
messages: sdkMessages,
|
||||
maxOutputTokens: getMaxTokens(assistant, model),
|
||||
@@ -193,8 +192,6 @@ export async function buildStreamTextParams(
|
||||
topP: getTopP(assistant, model),
|
||||
// Include AI SDK standard params extracted from custom parameters
|
||||
...standardParams,
|
||||
// Include body-level params for AI Gateway custom parameters
|
||||
...bodyParams,
|
||||
abortSignal: options.requestOptions?.signal,
|
||||
headers,
|
||||
providerOptions,
|
||||
|
||||
@@ -37,7 +37,7 @@ vi.mock('@cherrystudio/ai-core/provider', async (importOriginal) => {
|
||||
},
|
||||
customProviderIdSchema: {
|
||||
safeParse: vi.fn((id) => {
|
||||
const customProviders = ['google-vertex', 'google-vertex-anthropic', 'bedrock', 'ai-gateway']
|
||||
const customProviders = ['google-vertex', 'google-vertex-anthropic', 'bedrock']
|
||||
if (customProviders.includes(id)) {
|
||||
return { success: true, data: id }
|
||||
}
|
||||
@@ -56,8 +56,7 @@ vi.mock('../provider/factory', () => ({
|
||||
[SystemProviderIds.anthropic]: 'anthropic',
|
||||
[SystemProviderIds.grok]: 'xai',
|
||||
[SystemProviderIds.deepseek]: 'deepseek',
|
||||
[SystemProviderIds.openrouter]: 'openrouter',
|
||||
[SystemProviderIds['ai-gateway']]: 'ai-gateway'
|
||||
[SystemProviderIds.openrouter]: 'openrouter'
|
||||
}
|
||||
return mapping[provider.id] || provider.id
|
||||
})
|
||||
@@ -205,8 +204,6 @@ describe('options utils', () => {
|
||||
expect(result.providerOptions).toHaveProperty('openai')
|
||||
expect(result.providerOptions.openai).toBeDefined()
|
||||
expect(result.standardParams).toBeDefined()
|
||||
expect(result.bodyParams).toBeDefined()
|
||||
expect(result.bodyParams).toEqual({})
|
||||
})
|
||||
|
||||
it('should include reasoning parameters when enabled', () => {
|
||||
@@ -699,90 +696,5 @@ describe('options utils', () => {
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe('AI Gateway provider', () => {
|
||||
const aiGatewayProvider: Provider = {
|
||||
id: SystemProviderIds['ai-gateway'],
|
||||
name: 'AI Gateway',
|
||||
type: 'ai-gateway',
|
||||
apiKey: 'test-key',
|
||||
apiHost: 'https://ai-gateway.vercel.sh/v1/ai',
|
||||
isSystem: true,
|
||||
models: [] as Model[]
|
||||
} as Provider
|
||||
|
||||
const aiGatewayModel: Model = {
|
||||
id: 'openai/gpt-4',
|
||||
name: 'GPT-4',
|
||||
provider: SystemProviderIds['ai-gateway']
|
||||
} as Model
|
||||
|
||||
it('should build basic AI Gateway options with empty bodyParams', () => {
|
||||
const result = buildProviderOptions(mockAssistant, aiGatewayModel, aiGatewayProvider, {
|
||||
enableReasoning: false,
|
||||
enableWebSearch: false,
|
||||
enableGenerateImage: false
|
||||
})
|
||||
|
||||
expect(result.providerOptions).toHaveProperty('gateway')
|
||||
expect(result.providerOptions.gateway).toBeDefined()
|
||||
expect(result.bodyParams).toEqual({})
|
||||
})
|
||||
|
||||
it('should place custom parameters in bodyParams for AI Gateway instead of providerOptions', async () => {
|
||||
const { getCustomParameters } = await import('../reasoning')
|
||||
|
||||
vi.mocked(getCustomParameters).mockReturnValue({
|
||||
tools: [{ id: 'openai.image_generation' }],
|
||||
custom_param: 'custom_value'
|
||||
})
|
||||
|
||||
const result = buildProviderOptions(mockAssistant, aiGatewayModel, aiGatewayProvider, {
|
||||
enableReasoning: false,
|
||||
enableWebSearch: false,
|
||||
enableGenerateImage: false
|
||||
})
|
||||
|
||||
// Custom parameters should be in bodyParams, NOT in providerOptions.gateway
|
||||
expect(result.bodyParams).toHaveProperty('tools')
|
||||
expect(result.bodyParams.tools).toEqual([{ id: 'openai.image_generation' }])
|
||||
expect(result.bodyParams).toHaveProperty('custom_param')
|
||||
expect(result.bodyParams.custom_param).toBe('custom_value')
|
||||
|
||||
// providerOptions.gateway should NOT contain custom parameters
|
||||
expect(result.providerOptions.gateway).not.toHaveProperty('tools')
|
||||
expect(result.providerOptions.gateway).not.toHaveProperty('custom_param')
|
||||
})
|
||||
|
||||
it('should still extract AI SDK standard params from custom parameters for AI Gateway', async () => {
|
||||
const { getCustomParameters } = await import('../reasoning')
|
||||
|
||||
vi.mocked(getCustomParameters).mockReturnValue({
|
||||
topK: 5,
|
||||
frequencyPenalty: 0.5,
|
||||
tools: [{ id: 'openai.image_generation' }]
|
||||
})
|
||||
|
||||
const result = buildProviderOptions(mockAssistant, aiGatewayModel, aiGatewayProvider, {
|
||||
enableReasoning: false,
|
||||
enableWebSearch: false,
|
||||
enableGenerateImage: false
|
||||
})
|
||||
|
||||
// Standard params should be extracted and returned separately
|
||||
expect(result.standardParams).toEqual({
|
||||
topK: 5,
|
||||
frequencyPenalty: 0.5
|
||||
})
|
||||
|
||||
// Custom params (non-standard) should be in bodyParams
|
||||
expect(result.bodyParams).toHaveProperty('tools')
|
||||
expect(result.bodyParams.tools).toEqual([{ id: 'openai.image_generation' }])
|
||||
|
||||
// Neither should be in providerOptions.gateway
|
||||
expect(result.providerOptions.gateway).not.toHaveProperty('topK')
|
||||
expect(result.providerOptions.gateway).not.toHaveProperty('tools')
|
||||
})
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@@ -155,7 +155,6 @@ export function buildProviderOptions(
|
||||
): {
|
||||
providerOptions: Record<string, Record<string, JSONValue>>
|
||||
standardParams: Partial<Record<AiSdkParam, any>>
|
||||
bodyParams: Record<string, any>
|
||||
} {
|
||||
logger.debug('buildProviderOptions', { assistant, model, actualProvider, capabilities })
|
||||
const rawProviderId = getAiSdkProviderId(actualProvider)
|
||||
@@ -254,6 +253,12 @@ export function buildProviderOptions(
|
||||
const customParams = getCustomParameters(assistant)
|
||||
const { standardParams, providerParams } = extractAiSdkStandardParams(customParams)
|
||||
|
||||
// 合并 provider 特定的自定义参数到 providerSpecificOptions
|
||||
providerSpecificOptions = {
|
||||
...providerSpecificOptions,
|
||||
...providerParams
|
||||
}
|
||||
|
||||
let rawProviderKey =
|
||||
{
|
||||
'google-vertex': 'google',
|
||||
@@ -268,27 +273,12 @@ export function buildProviderOptions(
|
||||
rawProviderKey = { gemini: 'google', ['openai-response']: 'openai' }[actualProvider.type] || actualProvider.type
|
||||
}
|
||||
|
||||
// For AI Gateway, custom parameters should be placed at body level, not inside providerOptions.gateway
|
||||
// See: https://github.com/CherryHQ/cherry-studio/issues/4197
|
||||
let bodyParams: Record<string, any> = {}
|
||||
if (rawProviderKey === 'gateway') {
|
||||
// Custom parameters go to body level for AI Gateway
|
||||
bodyParams = providerParams
|
||||
} else {
|
||||
// For other providers, merge custom parameters into providerSpecificOptions
|
||||
providerSpecificOptions = {
|
||||
...providerSpecificOptions,
|
||||
...providerParams
|
||||
}
|
||||
}
|
||||
|
||||
// 返回 AI Core SDK 要求的格式:{ 'providerId': providerOptions } 以及提取的标准参数和 body 参数
|
||||
// 返回 AI Core SDK 要求的格式:{ 'providerId': providerOptions } 以及提取的标准参数
|
||||
return {
|
||||
providerOptions: {
|
||||
[rawProviderKey]: providerSpecificOptions
|
||||
},
|
||||
standardParams,
|
||||
bodyParams
|
||||
standardParams
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -51,7 +51,7 @@ export function useTextareaResize(options: UseTextareaResizeOptions = {}): UseTe
|
||||
const { maxHeight = 400, minHeight = 30, autoResize = true } = options
|
||||
|
||||
const textareaRef = useRef<TextAreaRef>(null)
|
||||
const [customHeight, setCustomHeight] = useState<number>()
|
||||
const [customHeight, setCustomHeight] = useState<number | undefined>(undefined)
|
||||
const [isExpanded, setIsExpanded] = useState(false)
|
||||
|
||||
const resize = useCallback(
|
||||
|
||||
@@ -177,8 +177,10 @@ const AgentSessionInputbarInner: FC<InnerProps> = ({ assistant, agentId, session
|
||||
resize: resizeTextArea,
|
||||
focus: focusTextarea,
|
||||
setExpanded,
|
||||
isExpanded: textareaIsExpanded
|
||||
} = useTextareaResize({ maxHeight: 400, minHeight: 30 })
|
||||
isExpanded: textareaIsExpanded,
|
||||
customHeight,
|
||||
setCustomHeight
|
||||
} = useTextareaResize({ maxHeight: 500, minHeight: 30 })
|
||||
const { sendMessageShortcut, apiServer } = useSettings()
|
||||
|
||||
const { t } = useTranslation()
|
||||
@@ -474,6 +476,8 @@ const AgentSessionInputbarInner: FC<InnerProps> = ({ assistant, agentId, session
|
||||
text={text}
|
||||
onTextChange={setText}
|
||||
textareaRef={textareaRef}
|
||||
height={customHeight}
|
||||
onHeightChange={setCustomHeight}
|
||||
resizeTextArea={resizeTextArea}
|
||||
focusTextarea={focusTextarea}
|
||||
placeholder={placeholderText}
|
||||
|
||||
@@ -143,9 +143,11 @@ const InputbarInner: FC<InputbarInnerProps> = ({ assistant: initialAssistant, se
|
||||
resize: resizeTextArea,
|
||||
focus: focusTextarea,
|
||||
setExpanded,
|
||||
isExpanded: textareaIsExpanded
|
||||
isExpanded: textareaIsExpanded,
|
||||
customHeight,
|
||||
setCustomHeight
|
||||
} = useTextareaResize({
|
||||
maxHeight: 400,
|
||||
maxHeight: 500,
|
||||
minHeight: 30
|
||||
})
|
||||
|
||||
@@ -257,7 +259,7 @@ const InputbarInner: FC<InputbarInnerProps> = ({ assistant: initialAssistant, se
|
||||
setText('')
|
||||
setFiles([])
|
||||
setTimeoutTimer('sendMessage_1', () => setText(''), 500)
|
||||
setTimeoutTimer('sendMessage_2', () => resizeTextArea(true), 0)
|
||||
setTimeoutTimer('sendMessage_2', () => resizeTextArea(), 0)
|
||||
} catch (error) {
|
||||
logger.warn('Failed to send message:', error as Error)
|
||||
parent?.recordException(error as Error)
|
||||
@@ -478,6 +480,8 @@ const InputbarInner: FC<InputbarInnerProps> = ({ assistant: initialAssistant, se
|
||||
text={text}
|
||||
onTextChange={setText}
|
||||
textareaRef={textareaRef}
|
||||
height={customHeight}
|
||||
onHeightChange={setCustomHeight}
|
||||
resizeTextArea={resizeTextArea}
|
||||
focusTextarea={focusTextarea}
|
||||
isLoading={loading}
|
||||
|
||||
@@ -50,6 +50,9 @@ export interface InputbarCoreProps {
|
||||
resizeTextArea: (force?: boolean) => void
|
||||
focusTextarea: () => void
|
||||
|
||||
height: number | undefined
|
||||
onHeightChange: (height: number) => void
|
||||
|
||||
supportedExts: string[]
|
||||
isLoading: boolean
|
||||
|
||||
@@ -104,6 +107,8 @@ export const InputbarCore: FC<InputbarCoreProps> = ({
|
||||
textareaRef,
|
||||
resizeTextArea,
|
||||
focusTextarea,
|
||||
height,
|
||||
onHeightChange,
|
||||
supportedExts,
|
||||
isLoading,
|
||||
onPause,
|
||||
@@ -131,8 +136,6 @@ export const InputbarCore: FC<InputbarCoreProps> = ({
|
||||
} = useSettings()
|
||||
const quickPanelTriggersEnabled = forceEnableQuickPanelTriggers ?? enableQuickPanelTriggers
|
||||
|
||||
const [textareaHeight, setTextareaHeight] = useState<number>()
|
||||
|
||||
const { t } = useTranslation()
|
||||
const [isTranslating, setIsTranslating] = useState(false)
|
||||
const { getLanguageByLangcode } = useTranslate()
|
||||
@@ -539,7 +542,7 @@ export const InputbarCore: FC<InputbarCoreProps> = ({
|
||||
const handleMouseMove = (e: MouseEvent) => {
|
||||
const deltaY = startDragY.current - e.clientY
|
||||
const newHeight = Math.max(40, Math.min(400, startHeight.current + deltaY))
|
||||
setTextareaHeight(newHeight)
|
||||
onHeightChange(newHeight)
|
||||
}
|
||||
|
||||
const handleMouseUp = () => {
|
||||
@@ -550,7 +553,7 @@ export const InputbarCore: FC<InputbarCoreProps> = ({
|
||||
document.addEventListener('mousemove', handleMouseMove)
|
||||
document.addEventListener('mouseup', handleMouseUp)
|
||||
},
|
||||
[config.enableDragDrop, setTextareaHeight, textareaRef]
|
||||
[config.enableDragDrop, onHeightChange, textareaRef]
|
||||
)
|
||||
|
||||
const onQuote = useCallback(
|
||||
@@ -667,11 +670,11 @@ export const InputbarCore: FC<InputbarCoreProps> = ({
|
||||
variant="borderless"
|
||||
spellCheck={enableSpellCheck}
|
||||
rows={2}
|
||||
autoSize={textareaHeight ? false : { minRows: 2, maxRows: 20 }}
|
||||
autoSize={height ? false : { minRows: 2, maxRows: 20 }}
|
||||
styles={{ textarea: TextareaStyle }}
|
||||
style={{
|
||||
fontSize,
|
||||
height: textareaHeight,
|
||||
height: height,
|
||||
minHeight: '30px'
|
||||
}}
|
||||
disabled={isTranslating || searching}
|
||||
|
||||
Reference in New Issue
Block a user