Compare commits

..

1 Commits

Author SHA1 Message Date
suyao
776fb27a58 fix: update Inputbar components to support dynamic textarea height adjustment 2025-12-01 01:27:46 +08:00
7 changed files with 34 additions and 124 deletions

View File

@@ -107,7 +107,7 @@ export async function buildStreamTextParams(
searchWithTime: store.getState().websearch.searchWithTime
}
const { providerOptions, standardParams, bodyParams } = buildProviderOptions(assistant, model, provider, {
const { providerOptions, standardParams } = buildProviderOptions(assistant, model, provider, {
enableReasoning,
enableWebSearch,
enableGenerateImage
@@ -185,7 +185,6 @@ export async function buildStreamTextParams(
// Note: standardParams (topK, frequencyPenalty, presencePenalty, stopSequences, seed)
// are extracted from custom parameters and passed directly to streamText()
// instead of being placed in providerOptions
// Note: bodyParams are custom parameters for AI Gateway that should be at body level
const params: StreamTextParams = {
messages: sdkMessages,
maxOutputTokens: getMaxTokens(assistant, model),
@@ -193,8 +192,6 @@ export async function buildStreamTextParams(
topP: getTopP(assistant, model),
// Include AI SDK standard params extracted from custom parameters
...standardParams,
// Include body-level params for AI Gateway custom parameters
...bodyParams,
abortSignal: options.requestOptions?.signal,
headers,
providerOptions,

View File

@@ -37,7 +37,7 @@ vi.mock('@cherrystudio/ai-core/provider', async (importOriginal) => {
},
customProviderIdSchema: {
safeParse: vi.fn((id) => {
const customProviders = ['google-vertex', 'google-vertex-anthropic', 'bedrock', 'ai-gateway']
const customProviders = ['google-vertex', 'google-vertex-anthropic', 'bedrock']
if (customProviders.includes(id)) {
return { success: true, data: id }
}
@@ -56,8 +56,7 @@ vi.mock('../provider/factory', () => ({
[SystemProviderIds.anthropic]: 'anthropic',
[SystemProviderIds.grok]: 'xai',
[SystemProviderIds.deepseek]: 'deepseek',
[SystemProviderIds.openrouter]: 'openrouter',
[SystemProviderIds['ai-gateway']]: 'ai-gateway'
[SystemProviderIds.openrouter]: 'openrouter'
}
return mapping[provider.id] || provider.id
})
@@ -205,8 +204,6 @@ describe('options utils', () => {
expect(result.providerOptions).toHaveProperty('openai')
expect(result.providerOptions.openai).toBeDefined()
expect(result.standardParams).toBeDefined()
expect(result.bodyParams).toBeDefined()
expect(result.bodyParams).toEqual({})
})
it('should include reasoning parameters when enabled', () => {
@@ -699,90 +696,5 @@ describe('options utils', () => {
})
})
})
describe('AI Gateway provider', () => {
const aiGatewayProvider: Provider = {
id: SystemProviderIds['ai-gateway'],
name: 'AI Gateway',
type: 'ai-gateway',
apiKey: 'test-key',
apiHost: 'https://ai-gateway.vercel.sh/v1/ai',
isSystem: true,
models: [] as Model[]
} as Provider
const aiGatewayModel: Model = {
id: 'openai/gpt-4',
name: 'GPT-4',
provider: SystemProviderIds['ai-gateway']
} as Model
it('should build basic AI Gateway options with empty bodyParams', () => {
const result = buildProviderOptions(mockAssistant, aiGatewayModel, aiGatewayProvider, {
enableReasoning: false,
enableWebSearch: false,
enableGenerateImage: false
})
expect(result.providerOptions).toHaveProperty('gateway')
expect(result.providerOptions.gateway).toBeDefined()
expect(result.bodyParams).toEqual({})
})
it('should place custom parameters in bodyParams for AI Gateway instead of providerOptions', async () => {
const { getCustomParameters } = await import('../reasoning')
vi.mocked(getCustomParameters).mockReturnValue({
tools: [{ id: 'openai.image_generation' }],
custom_param: 'custom_value'
})
const result = buildProviderOptions(mockAssistant, aiGatewayModel, aiGatewayProvider, {
enableReasoning: false,
enableWebSearch: false,
enableGenerateImage: false
})
// Custom parameters should be in bodyParams, NOT in providerOptions.gateway
expect(result.bodyParams).toHaveProperty('tools')
expect(result.bodyParams.tools).toEqual([{ id: 'openai.image_generation' }])
expect(result.bodyParams).toHaveProperty('custom_param')
expect(result.bodyParams.custom_param).toBe('custom_value')
// providerOptions.gateway should NOT contain custom parameters
expect(result.providerOptions.gateway).not.toHaveProperty('tools')
expect(result.providerOptions.gateway).not.toHaveProperty('custom_param')
})
it('should still extract AI SDK standard params from custom parameters for AI Gateway', async () => {
const { getCustomParameters } = await import('../reasoning')
vi.mocked(getCustomParameters).mockReturnValue({
topK: 5,
frequencyPenalty: 0.5,
tools: [{ id: 'openai.image_generation' }]
})
const result = buildProviderOptions(mockAssistant, aiGatewayModel, aiGatewayProvider, {
enableReasoning: false,
enableWebSearch: false,
enableGenerateImage: false
})
// Standard params should be extracted and returned separately
expect(result.standardParams).toEqual({
topK: 5,
frequencyPenalty: 0.5
})
// Custom params (non-standard) should be in bodyParams
expect(result.bodyParams).toHaveProperty('tools')
expect(result.bodyParams.tools).toEqual([{ id: 'openai.image_generation' }])
// Neither should be in providerOptions.gateway
expect(result.providerOptions.gateway).not.toHaveProperty('topK')
expect(result.providerOptions.gateway).not.toHaveProperty('tools')
})
})
})
})

View File

@@ -155,7 +155,6 @@ export function buildProviderOptions(
): {
providerOptions: Record<string, Record<string, JSONValue>>
standardParams: Partial<Record<AiSdkParam, any>>
bodyParams: Record<string, any>
} {
logger.debug('buildProviderOptions', { assistant, model, actualProvider, capabilities })
const rawProviderId = getAiSdkProviderId(actualProvider)
@@ -254,6 +253,12 @@ export function buildProviderOptions(
const customParams = getCustomParameters(assistant)
const { standardParams, providerParams } = extractAiSdkStandardParams(customParams)
// 合并 provider 特定的自定义参数到 providerSpecificOptions
providerSpecificOptions = {
...providerSpecificOptions,
...providerParams
}
let rawProviderKey =
{
'google-vertex': 'google',
@@ -268,27 +273,12 @@ export function buildProviderOptions(
rawProviderKey = { gemini: 'google', ['openai-response']: 'openai' }[actualProvider.type] || actualProvider.type
}
// For AI Gateway, custom parameters should be placed at body level, not inside providerOptions.gateway
// See: https://github.com/CherryHQ/cherry-studio/issues/4197
let bodyParams: Record<string, any> = {}
if (rawProviderKey === 'gateway') {
// Custom parameters go to body level for AI Gateway
bodyParams = providerParams
} else {
// For other providers, merge custom parameters into providerSpecificOptions
providerSpecificOptions = {
...providerSpecificOptions,
...providerParams
}
}
// 返回 AI Core SDK 要求的格式:{ 'providerId': providerOptions } 以及提取的标准参数和 body 参数
// 返回 AI Core SDK 要求的格式:{ 'providerId': providerOptions } 以及提取的标准参数
return {
providerOptions: {
[rawProviderKey]: providerSpecificOptions
},
standardParams,
bodyParams
standardParams
}
}

View File

@@ -51,7 +51,7 @@ export function useTextareaResize(options: UseTextareaResizeOptions = {}): UseTe
const { maxHeight = 400, minHeight = 30, autoResize = true } = options
const textareaRef = useRef<TextAreaRef>(null)
const [customHeight, setCustomHeight] = useState<number>()
const [customHeight, setCustomHeight] = useState<number | undefined>(undefined)
const [isExpanded, setIsExpanded] = useState(false)
const resize = useCallback(

View File

@@ -177,8 +177,10 @@ const AgentSessionInputbarInner: FC<InnerProps> = ({ assistant, agentId, session
resize: resizeTextArea,
focus: focusTextarea,
setExpanded,
isExpanded: textareaIsExpanded
} = useTextareaResize({ maxHeight: 400, minHeight: 30 })
isExpanded: textareaIsExpanded,
customHeight,
setCustomHeight
} = useTextareaResize({ maxHeight: 500, minHeight: 30 })
const { sendMessageShortcut, apiServer } = useSettings()
const { t } = useTranslation()
@@ -474,6 +476,8 @@ const AgentSessionInputbarInner: FC<InnerProps> = ({ assistant, agentId, session
text={text}
onTextChange={setText}
textareaRef={textareaRef}
height={customHeight}
onHeightChange={setCustomHeight}
resizeTextArea={resizeTextArea}
focusTextarea={focusTextarea}
placeholder={placeholderText}

View File

@@ -143,9 +143,11 @@ const InputbarInner: FC<InputbarInnerProps> = ({ assistant: initialAssistant, se
resize: resizeTextArea,
focus: focusTextarea,
setExpanded,
isExpanded: textareaIsExpanded
isExpanded: textareaIsExpanded,
customHeight,
setCustomHeight
} = useTextareaResize({
maxHeight: 400,
maxHeight: 500,
minHeight: 30
})
@@ -257,7 +259,7 @@ const InputbarInner: FC<InputbarInnerProps> = ({ assistant: initialAssistant, se
setText('')
setFiles([])
setTimeoutTimer('sendMessage_1', () => setText(''), 500)
setTimeoutTimer('sendMessage_2', () => resizeTextArea(true), 0)
setTimeoutTimer('sendMessage_2', () => resizeTextArea(), 0)
} catch (error) {
logger.warn('Failed to send message:', error as Error)
parent?.recordException(error as Error)
@@ -478,6 +480,8 @@ const InputbarInner: FC<InputbarInnerProps> = ({ assistant: initialAssistant, se
text={text}
onTextChange={setText}
textareaRef={textareaRef}
height={customHeight}
onHeightChange={setCustomHeight}
resizeTextArea={resizeTextArea}
focusTextarea={focusTextarea}
isLoading={loading}

View File

@@ -50,6 +50,9 @@ export interface InputbarCoreProps {
resizeTextArea: (force?: boolean) => void
focusTextarea: () => void
height: number | undefined
onHeightChange: (height: number) => void
supportedExts: string[]
isLoading: boolean
@@ -104,6 +107,8 @@ export const InputbarCore: FC<InputbarCoreProps> = ({
textareaRef,
resizeTextArea,
focusTextarea,
height,
onHeightChange,
supportedExts,
isLoading,
onPause,
@@ -131,8 +136,6 @@ export const InputbarCore: FC<InputbarCoreProps> = ({
} = useSettings()
const quickPanelTriggersEnabled = forceEnableQuickPanelTriggers ?? enableQuickPanelTriggers
const [textareaHeight, setTextareaHeight] = useState<number>()
const { t } = useTranslation()
const [isTranslating, setIsTranslating] = useState(false)
const { getLanguageByLangcode } = useTranslate()
@@ -539,7 +542,7 @@ export const InputbarCore: FC<InputbarCoreProps> = ({
const handleMouseMove = (e: MouseEvent) => {
const deltaY = startDragY.current - e.clientY
const newHeight = Math.max(40, Math.min(400, startHeight.current + deltaY))
setTextareaHeight(newHeight)
onHeightChange(newHeight)
}
const handleMouseUp = () => {
@@ -550,7 +553,7 @@ export const InputbarCore: FC<InputbarCoreProps> = ({
document.addEventListener('mousemove', handleMouseMove)
document.addEventListener('mouseup', handleMouseUp)
},
[config.enableDragDrop, setTextareaHeight, textareaRef]
[config.enableDragDrop, onHeightChange, textareaRef]
)
const onQuote = useCallback(
@@ -667,11 +670,11 @@ export const InputbarCore: FC<InputbarCoreProps> = ({
variant="borderless"
spellCheck={enableSpellCheck}
rows={2}
autoSize={textareaHeight ? false : { minRows: 2, maxRows: 20 }}
autoSize={height ? false : { minRows: 2, maxRows: 20 }}
styles={{ textarea: TextareaStyle }}
style={{
fontSize,
height: textareaHeight,
height: height,
minHeight: '30px'
}}
disabled={isTranslating || searching}