Compare commits

...

1 Commits

2 changed files with 99 additions and 14 deletions

View File

@@ -6,20 +6,28 @@
import { loggerService } from '@logger' import { loggerService } from '@logger'
import { isImageEnhancementModel, isVisionModel } from '@renderer/config/models' import { isImageEnhancementModel, isVisionModel } from '@renderer/config/models'
import type { Message, Model } from '@renderer/types' import type { Message, Model } from '@renderer/types'
import type { FileMessageBlock, ImageMessageBlock, ThinkingMessageBlock } from '@renderer/types/newMessage' import type {
FileMessageBlock,
ImageMessageBlock,
ThinkingMessageBlock,
ToolMessageBlock
} from '@renderer/types/newMessage'
import { import {
findFileBlocks, findFileBlocks,
findImageBlocks, findImageBlocks,
findThinkingBlocks, findThinkingBlocks,
findToolBlocks,
getMainTextContent getMainTextContent
} from '@renderer/utils/messageUtils/find' } from '@renderer/utils/messageUtils/find'
import type { import type {
AssistantModelMessage, AssistantContent,
FilePart, FilePart,
ImagePart, ImagePart,
ModelMessage, ModelMessage,
SystemModelMessage, SystemModelMessage,
TextPart, TextPart,
ToolCallPart,
ToolResultPart,
UserModelMessage UserModelMessage
} from 'ai' } from 'ai'
@@ -40,10 +48,11 @@ export async function convertMessageToSdkParam(
const fileBlocks = findFileBlocks(message) const fileBlocks = findFileBlocks(message)
const imageBlocks = findImageBlocks(message) const imageBlocks = findImageBlocks(message)
const reasoningBlocks = findThinkingBlocks(message) const reasoningBlocks = findThinkingBlocks(message)
const toolBlocks = findToolBlocks(message)
if (message.role === 'user' || message.role === 'system') { if (message.role === 'user' || message.role === 'system') {
return convertMessageToUserModelMessage(content, fileBlocks, imageBlocks, isVisionModel, model) return convertMessageToUserModelMessage(content, fileBlocks, imageBlocks, isVisionModel, model)
} else { } else {
return convertMessageToAssistantModelMessage(content, fileBlocks, reasoningBlocks, model) return convertMessageToAssistantAndToolMessages(content, fileBlocks, toolBlocks, reasoningBlocks, model)
} }
} }
@@ -147,30 +156,65 @@ async function convertMessageToUserModelMessage(
} }
} }
/** function convertToolBlockToToolCallPart(toolBlock: ToolMessageBlock): ToolCallPart {
* 转换为助手模型消息 return {
*/ type: 'tool-call',
async function convertMessageToAssistantModelMessage( toolCallId: toolBlock.toolId,
toolName: toolBlock.toolName || 'unknown',
input: toolBlock.arguments || {}
}
}
function convertToolBlockToToolResultPart(toolBlock: ToolMessageBlock): ToolResultPart {
const content = toolBlock.content
let output: ToolResultPart['output']
if (content === undefined || content === null) {
output = { type: 'text', value: '' }
} else if (typeof content === 'string') {
output = { type: 'text', value: content }
} else {
output = { type: 'json', value: JSON.parse(JSON.stringify(content)) }
}
return {
type: 'tool-result',
toolCallId: toolBlock.toolId,
toolName: toolBlock.toolName || 'unknown',
output
}
}
function hasToolResult(toolBlock: ToolMessageBlock): boolean {
return toolBlock.content !== undefined && toolBlock.content !== null
}
async function convertMessageToAssistantAndToolMessages(
content: string, content: string,
fileBlocks: FileMessageBlock[], fileBlocks: FileMessageBlock[],
toolBlocks: ToolMessageBlock[],
thinkingBlocks: ThinkingMessageBlock[], thinkingBlocks: ThinkingMessageBlock[],
model?: Model model?: Model
): Promise<AssistantModelMessage> { ): Promise<ModelMessage | ModelMessage[]> {
const parts: Array<TextPart | FilePart> = [] const assistantParts: AssistantContent = []
// 添加文本内容
if (content) { if (content) {
parts.push({ type: 'text', text: content }) assistantParts.push({ type: 'text', text: content })
} }
// 添加推理内容
for (const thinkingBlock of thinkingBlocks) { for (const thinkingBlock of thinkingBlocks) {
parts.push({ type: 'text', text: thinkingBlock.content }) assistantParts.push({ type: 'reasoning', text: thinkingBlock.content })
} }
// 处理文件
for (const fileBlock of fileBlocks) { for (const fileBlock of fileBlocks) {
// 优先尝试原生文件支持PDF等 // 优先尝试原生文件支持PDF等
if (model) { if (model) {
const filePart = await convertFileBlockToFilePart(fileBlock, model) const filePart = await convertFileBlockToFilePart(fileBlock, model)
if (filePart) { if (filePart) {
parts.push(filePart) assistantParts.push(filePart)
continue continue
} }
} }
@@ -178,13 +222,33 @@ async function convertMessageToAssistantModelMessage(
// 回退到文本处理 // 回退到文本处理
const textPart = await convertFileBlockToTextPart(fileBlock) const textPart = await convertFileBlockToTextPart(fileBlock)
if (textPart) { if (textPart) {
parts.push(textPart) assistantParts.push(textPart)
}
}
// 如果没有 tool blocks直接返回 assistant 消息
if (toolBlocks.length === 0) {
return {
role: 'assistant',
content: assistantParts
}
}
// 处理 tool blocks
// 将 tool calls 和 tool results 都添加到 assistant 消息的 content 中
for (const toolBlock of toolBlocks) {
// 添加 tool call
assistantParts.push(convertToolBlockToToolCallPart(toolBlock))
// 如果有结果,添加 tool result
if (hasToolResult(toolBlock)) {
assistantParts.push(convertToolBlockToToolResultPart(toolBlock))
} }
} }
return { return {
role: 'assistant', role: 'assistant',
content: parts content: assistantParts
} }
} }

View File

@@ -9,6 +9,7 @@ import type {
Message, Message,
MessageBlock, MessageBlock,
ThinkingMessageBlock, ThinkingMessageBlock,
ToolMessageBlock,
TranslationMessageBlock TranslationMessageBlock
} from '@renderer/types/newMessage' } from '@renderer/types/newMessage'
import { MessageBlockType } from '@renderer/types/newMessage' import { MessageBlockType } from '@renderer/types/newMessage'
@@ -108,6 +109,26 @@ export const findFileBlocks = (message: Message): FileMessageBlock[] => {
return fileBlocks return fileBlocks
} }
/**
* Finds all ToolMessageBlocks associated with a given message.
* @param message - The message object.
* @returns An array of ToolMessageBlocks (empty if none found).
*/
export const findToolBlocks = (message: Message): ToolMessageBlock[] => {
if (!message || !message.blocks || message.blocks.length === 0) {
return []
}
const state = store.getState()
const toolBlocks: ToolMessageBlock[] = []
for (const blockId of message.blocks) {
const block = messageBlocksSelectors.selectById(state, blockId)
if (block && block.type === MessageBlockType.TOOL) {
toolBlocks.push(block as ToolMessageBlock)
}
}
return toolBlocks
}
/** /**
* Gets the concatenated content string from all MainTextMessageBlocks of a message, in order. * Gets the concatenated content string from all MainTextMessageBlocks of a message, in order.
* @param message - The message object. * @param message - The message object.