integrate AWS Bedrock API (#8383)
* feat(AWS Bedrock): integrate AWS Bedrock API client and configuration * feat(AWS Bedrock): add AWS Bedrock settings management and UI integration * refactor(AWS Bedrock): refactor AWS Bedrock API client and settings management with vertexai * fix: lint error * refactor: update aws bedrock placeholder * refactor(i18n):update i18n content with aws bedrock * feat(AwsBedrockAPIClient): enhance message handling, add image support * fix: code review suggestion * feat(test): add aws bedrock utils unit test * feat(AwsBedrockAPIClient): enhance getEmbeddingDimensions method to support dynamic model dimension retrieval * fix(AwsBedrockAPIClient): Modify the processing logic when the embedded dimension cannot be parsed, throw an error instead of returning the default value * chore(package): Reorganize AWS SDK dependencies in package.json
This commit is contained in:
@@ -3,6 +3,7 @@ import { Provider } from '@renderer/types'
|
||||
|
||||
import { AihubmixAPIClient } from './AihubmixAPIClient'
|
||||
import { AnthropicAPIClient } from './anthropic/AnthropicAPIClient'
|
||||
import { AwsBedrockAPIClient } from './aws/AwsBedrockAPIClient'
|
||||
import { BaseApiClient } from './BaseApiClient'
|
||||
import { GeminiAPIClient } from './gemini/GeminiAPIClient'
|
||||
import { VertexAPIClient } from './gemini/VertexAPIClient'
|
||||
@@ -65,6 +66,9 @@ export class ApiClientFactory {
|
||||
case 'anthropic':
|
||||
instance = new AnthropicAPIClient(provider) as BaseApiClient
|
||||
break
|
||||
case 'aws-bedrock':
|
||||
instance = new AwsBedrockAPIClient(provider) as BaseApiClient
|
||||
break
|
||||
default:
|
||||
logger.debug(`Using default OpenAIApiClient for provider: ${provider.id}`)
|
||||
instance = new OpenAIAPIClient(provider) as BaseApiClient
|
||||
|
||||
@@ -0,0 +1,620 @@
|
||||
import {
|
||||
BedrockRuntimeClient,
|
||||
ConverseCommand,
|
||||
ConverseStreamCommand,
|
||||
InvokeModelCommand
|
||||
} from '@aws-sdk/client-bedrock-runtime'
|
||||
import { loggerService } from '@logger'
|
||||
import { GenericChunk } from '@renderer/aiCore/middleware/schemas'
|
||||
import { DEFAULT_MAX_TOKENS } from '@renderer/config/constant'
|
||||
import {
|
||||
getAwsBedrockAccessKeyId,
|
||||
getAwsBedrockRegion,
|
||||
getAwsBedrockSecretAccessKey
|
||||
} from '@renderer/hooks/useAwsBedrock'
|
||||
import { estimateTextTokens } from '@renderer/services/TokenService'
|
||||
import {
|
||||
GenerateImageParams,
|
||||
MCPCallToolResponse,
|
||||
MCPTool,
|
||||
MCPToolResponse,
|
||||
Model,
|
||||
Provider,
|
||||
ToolCallResponse
|
||||
} from '@renderer/types'
|
||||
import { ChunkType, MCPToolCreatedChunk, TextDeltaChunk } from '@renderer/types/chunk'
|
||||
import { Message } from '@renderer/types/newMessage'
|
||||
import {
|
||||
AwsBedrockSdkInstance,
|
||||
AwsBedrockSdkMessageParam,
|
||||
AwsBedrockSdkParams,
|
||||
AwsBedrockSdkRawChunk,
|
||||
AwsBedrockSdkRawOutput,
|
||||
AwsBedrockSdkTool,
|
||||
AwsBedrockSdkToolCall,
|
||||
SdkModel
|
||||
} from '@renderer/types/sdk'
|
||||
import { convertBase64ImageToAwsBedrockFormat } from '@renderer/utils/aws-bedrock-utils'
|
||||
import {
|
||||
awsBedrockToolUseToMcpTool,
|
||||
isEnabledToolUse,
|
||||
mcpToolCallResponseToAwsBedrockMessage,
|
||||
mcpToolsToAwsBedrockTools
|
||||
} from '@renderer/utils/mcp-tools'
|
||||
import { findImageBlocks } from '@renderer/utils/messageUtils/find'
|
||||
|
||||
import { BaseApiClient } from '../BaseApiClient'
|
||||
import { RequestTransformer, ResponseChunkTransformer } from '../types'
|
||||
|
||||
const logger = loggerService.withContext('AwsBedrockAPIClient')
|
||||
|
||||
export class AwsBedrockAPIClient extends BaseApiClient<
|
||||
AwsBedrockSdkInstance,
|
||||
AwsBedrockSdkParams,
|
||||
AwsBedrockSdkRawOutput,
|
||||
AwsBedrockSdkRawChunk,
|
||||
AwsBedrockSdkMessageParam,
|
||||
AwsBedrockSdkToolCall,
|
||||
AwsBedrockSdkTool
|
||||
> {
|
||||
constructor(provider: Provider) {
|
||||
super(provider)
|
||||
}
|
||||
|
||||
async getSdkInstance(): Promise<AwsBedrockSdkInstance> {
|
||||
if (this.sdkInstance) {
|
||||
return this.sdkInstance
|
||||
}
|
||||
|
||||
const region = getAwsBedrockRegion()
|
||||
const accessKeyId = getAwsBedrockAccessKeyId()
|
||||
const secretAccessKey = getAwsBedrockSecretAccessKey()
|
||||
|
||||
if (!region) {
|
||||
throw new Error('AWS region is required. Please configure AWS-Region in extra headers.')
|
||||
}
|
||||
|
||||
if (!accessKeyId || !secretAccessKey) {
|
||||
throw new Error('AWS credentials are required. Please configure AWS-Access-Key-ID and AWS-Secret-Access-Key.')
|
||||
}
|
||||
|
||||
const client = new BedrockRuntimeClient({
|
||||
region,
|
||||
credentials: {
|
||||
accessKeyId,
|
||||
secretAccessKey
|
||||
}
|
||||
})
|
||||
|
||||
this.sdkInstance = { client, region }
|
||||
return this.sdkInstance
|
||||
}
|
||||
|
||||
override async createCompletions(payload: AwsBedrockSdkParams): Promise<AwsBedrockSdkRawOutput> {
|
||||
const sdk = await this.getSdkInstance()
|
||||
|
||||
// 转换消息格式到AWS SDK原生格式
|
||||
const awsMessages = payload.messages.map((msg) => ({
|
||||
role: msg.role,
|
||||
content: msg.content.map((content) => {
|
||||
if (content.text) {
|
||||
return { text: content.text }
|
||||
}
|
||||
if (content.image) {
|
||||
return {
|
||||
image: {
|
||||
format: content.image.format,
|
||||
source: content.image.source
|
||||
}
|
||||
}
|
||||
}
|
||||
if (content.toolResult) {
|
||||
return {
|
||||
toolResult: {
|
||||
toolUseId: content.toolResult.toolUseId,
|
||||
content: content.toolResult.content,
|
||||
status: content.toolResult.status
|
||||
}
|
||||
}
|
||||
}
|
||||
if (content.toolUse) {
|
||||
return {
|
||||
toolUse: {
|
||||
toolUseId: content.toolUse.toolUseId,
|
||||
name: content.toolUse.name,
|
||||
input: content.toolUse.input
|
||||
}
|
||||
}
|
||||
}
|
||||
// 返回符合AWS SDK ContentBlock类型的对象
|
||||
return { text: 'Unknown content type' }
|
||||
})
|
||||
}))
|
||||
|
||||
const commonParams = {
|
||||
modelId: payload.modelId,
|
||||
messages: awsMessages as any,
|
||||
system: payload.system ? [{ text: payload.system }] : undefined,
|
||||
inferenceConfig: {
|
||||
maxTokens: payload.maxTokens || DEFAULT_MAX_TOKENS,
|
||||
temperature: payload.temperature || 0.7,
|
||||
topP: payload.topP || 1
|
||||
},
|
||||
toolConfig:
|
||||
payload.tools && payload.tools.length > 0
|
||||
? {
|
||||
tools: payload.tools
|
||||
}
|
||||
: undefined
|
||||
}
|
||||
|
||||
try {
|
||||
if (payload.stream) {
|
||||
const command = new ConverseStreamCommand(commonParams)
|
||||
const response = await sdk.client.send(command)
|
||||
// 直接返回AWS Bedrock流式响应的异步迭代器
|
||||
return this.createStreamIterator(response)
|
||||
} else {
|
||||
const command = new ConverseCommand(commonParams)
|
||||
const response = await sdk.client.send(command)
|
||||
return { output: response }
|
||||
}
|
||||
} catch (error) {
|
||||
logger.error('Failed to create completions with AWS Bedrock:', error as Error)
|
||||
throw error
|
||||
}
|
||||
}
|
||||
|
||||
private async *createStreamIterator(response: any): AsyncIterable<AwsBedrockSdkRawChunk> {
|
||||
try {
|
||||
if (response.stream) {
|
||||
for await (const chunk of response.stream) {
|
||||
logger.debug('AWS Bedrock chunk received:', chunk)
|
||||
|
||||
// AWS Bedrock的流式响应格式转换为标准格式
|
||||
if (chunk.contentBlockDelta?.delta?.text) {
|
||||
yield {
|
||||
contentBlockDelta: {
|
||||
delta: { text: chunk.contentBlockDelta.delta.text }
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (chunk.messageStart) {
|
||||
yield { messageStart: chunk.messageStart }
|
||||
}
|
||||
|
||||
if (chunk.messageStop) {
|
||||
yield { messageStop: chunk.messageStop }
|
||||
}
|
||||
|
||||
if (chunk.metadata) {
|
||||
yield { metadata: chunk.metadata }
|
||||
}
|
||||
}
|
||||
}
|
||||
} catch (error) {
|
||||
logger.error('Error in AWS Bedrock stream iterator:', error as Error)
|
||||
throw error
|
||||
}
|
||||
}
|
||||
|
||||
// @ts-ignore sdk未提供
|
||||
// eslint-disable-next-line @typescript-eslint/no-unused-vars
|
||||
override async generateImage(_generateImageParams: GenerateImageParams): Promise<string[]> {
|
||||
return []
|
||||
}
|
||||
|
||||
override async getEmbeddingDimensions(model?: Model): Promise<number> {
|
||||
if (!model) {
|
||||
throw new Error('Model is required for AWS Bedrock embedding dimensions.')
|
||||
}
|
||||
|
||||
const sdk = await this.getSdkInstance()
|
||||
|
||||
// AWS Bedrock 支持的嵌入模型及其维度
|
||||
const embeddingModels: Record<string, number> = {
|
||||
'cohere.embed-english-v3': 1024,
|
||||
'cohere.embed-multilingual-v3': 1024,
|
||||
// Amazon Titan embeddings
|
||||
'amazon.titan-embed-text-v1': 1536,
|
||||
'amazon.titan-embed-text-v2:0': 1024
|
||||
// 可以根据需要添加更多模型
|
||||
}
|
||||
|
||||
// 如果是已知的嵌入模型,直接返回维度
|
||||
if (embeddingModels[model.id]) {
|
||||
return embeddingModels[model.id]
|
||||
}
|
||||
|
||||
// 对于未知模型,尝试实际调用API获取维度
|
||||
try {
|
||||
let requestBody: any
|
||||
|
||||
if (model.id.startsWith('cohere.embed')) {
|
||||
// Cohere Embed API 格式
|
||||
requestBody = {
|
||||
texts: ['test'],
|
||||
input_type: 'search_document',
|
||||
embedding_types: ['float']
|
||||
}
|
||||
} else if (model.id.startsWith('amazon.titan-embed')) {
|
||||
// Amazon Titan Embed API 格式
|
||||
requestBody = {
|
||||
inputText: 'test'
|
||||
}
|
||||
} else {
|
||||
// 通用格式,大多数嵌入模型都支持
|
||||
requestBody = {
|
||||
inputText: 'test'
|
||||
}
|
||||
}
|
||||
|
||||
const command = new InvokeModelCommand({
|
||||
modelId: model.id,
|
||||
body: JSON.stringify(requestBody),
|
||||
contentType: 'application/json',
|
||||
accept: 'application/json'
|
||||
})
|
||||
|
||||
const response = await sdk.client.send(command)
|
||||
const responseBody = JSON.parse(new TextDecoder().decode(response.body))
|
||||
|
||||
// 解析响应获取嵌入维度
|
||||
if (responseBody.embeddings && responseBody.embeddings.length > 0) {
|
||||
// Cohere 格式
|
||||
if (responseBody.embeddings[0].values) {
|
||||
return responseBody.embeddings[0].values.length
|
||||
}
|
||||
// 其他可能的格式
|
||||
if (Array.isArray(responseBody.embeddings[0])) {
|
||||
return responseBody.embeddings[0].length
|
||||
}
|
||||
}
|
||||
|
||||
if (responseBody.embedding && Array.isArray(responseBody.embedding)) {
|
||||
// Amazon Titan 格式
|
||||
return responseBody.embedding.length
|
||||
}
|
||||
|
||||
// 如果无法解析,则抛出错误
|
||||
throw new Error(`Unable to determine embedding dimensions for model ${model.id}`)
|
||||
} catch (error) {
|
||||
logger.error('Failed to get embedding dimensions from AWS Bedrock:', error as Error)
|
||||
|
||||
// 根据模型名称推测维度
|
||||
if (model.id.includes('titan')) {
|
||||
return 1536 // Amazon Titan 默认维度
|
||||
}
|
||||
if (model.id.includes('cohere')) {
|
||||
return 1024 // Cohere 默认维度
|
||||
}
|
||||
|
||||
throw new Error(`Unable to determine embedding dimensions for model ${model.id}: ${(error as Error).message}`)
|
||||
}
|
||||
}
|
||||
|
||||
// @ts-ignore sdk未提供
|
||||
override async listModels(): Promise<SdkModel[]> {
|
||||
return []
|
||||
}
|
||||
|
||||
public async convertMessageToSdkParam(message: Message): Promise<AwsBedrockSdkMessageParam> {
|
||||
const content = await this.getMessageContent(message)
|
||||
const parts: Array<{
|
||||
text?: string
|
||||
image?: {
|
||||
format: 'png' | 'jpeg' | 'gif' | 'webp'
|
||||
source: {
|
||||
bytes?: Uint8Array
|
||||
s3Location?: {
|
||||
uri: string
|
||||
bucketOwner?: string
|
||||
}
|
||||
}
|
||||
}
|
||||
}> = []
|
||||
|
||||
// 添加文本内容 - 只在有非空内容时添加
|
||||
if (content && content.trim()) {
|
||||
parts.push({ text: content })
|
||||
}
|
||||
|
||||
// 处理图片内容
|
||||
const imageBlocks = findImageBlocks(message)
|
||||
for (const imageBlock of imageBlocks) {
|
||||
if (imageBlock.file) {
|
||||
try {
|
||||
const image = await window.api.file.base64Image(imageBlock.file.id + imageBlock.file.ext)
|
||||
const mimeType = image.mime || 'image/png'
|
||||
const base64Data = image.base64
|
||||
|
||||
const awsImage = convertBase64ImageToAwsBedrockFormat(base64Data, mimeType)
|
||||
if (awsImage) {
|
||||
parts.push({ image: awsImage })
|
||||
} else {
|
||||
// 不支持的格式,转换为文本描述
|
||||
parts.push({ text: `[Image: ${mimeType}]` })
|
||||
}
|
||||
} catch (error) {
|
||||
logger.error('Error processing image:', error as Error)
|
||||
parts.push({ text: '[Image processing failed]' })
|
||||
}
|
||||
} else if (imageBlock.url && imageBlock.url.startsWith('data:')) {
|
||||
try {
|
||||
// 处理base64图片URL
|
||||
const matches = imageBlock.url.match(/^data:(.+);base64,(.*)$/)
|
||||
if (matches && matches.length === 3) {
|
||||
const mimeType = matches[1]
|
||||
const base64Data = matches[2]
|
||||
|
||||
const awsImage = convertBase64ImageToAwsBedrockFormat(base64Data, mimeType)
|
||||
if (awsImage) {
|
||||
parts.push({ image: awsImage })
|
||||
} else {
|
||||
parts.push({ text: `[Image: ${mimeType}]` })
|
||||
}
|
||||
}
|
||||
} catch (error) {
|
||||
logger.error('Error processing base64 image:', error as Error)
|
||||
parts.push({ text: '[Image processing failed]' })
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 如果没有任何内容,添加默认文本而不是空文本
|
||||
if (parts.length === 0) {
|
||||
parts.push({ text: 'No content provided' })
|
||||
}
|
||||
|
||||
return {
|
||||
role: message.role === 'system' ? 'user' : message.role,
|
||||
content: parts
|
||||
}
|
||||
}
|
||||
|
||||
getRequestTransformer(): RequestTransformer<AwsBedrockSdkParams, AwsBedrockSdkMessageParam> {
|
||||
return {
|
||||
transform: async (
|
||||
coreRequest,
|
||||
assistant,
|
||||
model,
|
||||
isRecursiveCall,
|
||||
recursiveSdkMessages
|
||||
): Promise<{
|
||||
payload: AwsBedrockSdkParams
|
||||
messages: AwsBedrockSdkMessageParam[]
|
||||
metadata: Record<string, any>
|
||||
}> => {
|
||||
const { messages, mcpTools, maxTokens, streamOutput } = coreRequest
|
||||
// 1. 处理系统消息
|
||||
const systemPrompt = assistant.prompt
|
||||
// 2. 设置工具
|
||||
const { tools } = this.setupToolsConfig({
|
||||
mcpTools: mcpTools,
|
||||
model,
|
||||
enableToolUse: isEnabledToolUse(assistant)
|
||||
})
|
||||
|
||||
// 3. 处理消息
|
||||
const sdkMessages: AwsBedrockSdkMessageParam[] = []
|
||||
if (typeof messages === 'string') {
|
||||
sdkMessages.push({ role: 'user', content: [{ text: messages }] })
|
||||
} else {
|
||||
for (const message of messages) {
|
||||
sdkMessages.push(await this.convertMessageToSdkParam(message))
|
||||
}
|
||||
}
|
||||
|
||||
const payload: AwsBedrockSdkParams = {
|
||||
modelId: model.id,
|
||||
messages:
|
||||
isRecursiveCall && recursiveSdkMessages && recursiveSdkMessages.length > 0
|
||||
? recursiveSdkMessages
|
||||
: sdkMessages,
|
||||
system: systemPrompt,
|
||||
maxTokens: maxTokens || DEFAULT_MAX_TOKENS,
|
||||
temperature: this.getTemperature(assistant, model),
|
||||
topP: this.getTopP(assistant, model),
|
||||
stream: streamOutput !== false,
|
||||
tools: tools.length > 0 ? tools : undefined
|
||||
}
|
||||
|
||||
const timeout = this.getTimeout(model)
|
||||
return { payload, messages: sdkMessages, metadata: { timeout } }
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
getResponseChunkTransformer(): ResponseChunkTransformer<AwsBedrockSdkRawChunk> {
|
||||
return () => {
|
||||
let hasStartedText = false
|
||||
let accumulatedJson = ''
|
||||
const toolCalls: Record<number, AwsBedrockSdkToolCall> = {}
|
||||
|
||||
return {
|
||||
async transform(rawChunk: AwsBedrockSdkRawChunk, controller: TransformStreamDefaultController<GenericChunk>) {
|
||||
logger.silly('Processing AWS Bedrock chunk:', rawChunk)
|
||||
|
||||
// 处理消息开始事件
|
||||
if (rawChunk.messageStart) {
|
||||
controller.enqueue({
|
||||
type: ChunkType.TEXT_START
|
||||
})
|
||||
hasStartedText = true
|
||||
logger.debug('Message started')
|
||||
}
|
||||
|
||||
// 处理内容块开始事件 - 参考 Anthropic 的 content_block_start 处理
|
||||
if (rawChunk.contentBlockStart?.start?.toolUse) {
|
||||
const toolUse = rawChunk.contentBlockStart.start.toolUse
|
||||
const blockIndex = rawChunk.contentBlockStart.contentBlockIndex || 0
|
||||
toolCalls[blockIndex] = {
|
||||
id: toolUse.toolUseId, // 设置 id 字段与 toolUseId 相同
|
||||
name: toolUse.name,
|
||||
toolUseId: toolUse.toolUseId,
|
||||
input: {}
|
||||
}
|
||||
logger.debug('Tool use started:', toolUse)
|
||||
}
|
||||
|
||||
// 处理内容块增量事件 - 参考 Anthropic 的 content_block_delta 处理
|
||||
if (rawChunk.contentBlockDelta?.delta?.toolUse?.input) {
|
||||
const inputDelta = rawChunk.contentBlockDelta.delta.toolUse.input
|
||||
accumulatedJson += inputDelta
|
||||
}
|
||||
|
||||
// 处理文本增量
|
||||
if (rawChunk.contentBlockDelta?.delta?.text) {
|
||||
if (!hasStartedText) {
|
||||
controller.enqueue({
|
||||
type: ChunkType.TEXT_START
|
||||
})
|
||||
hasStartedText = true
|
||||
}
|
||||
|
||||
controller.enqueue({
|
||||
type: ChunkType.TEXT_DELTA,
|
||||
text: rawChunk.contentBlockDelta.delta.text
|
||||
} as TextDeltaChunk)
|
||||
}
|
||||
|
||||
// 处理内容块停止事件 - 参考 Anthropic 的 content_block_stop 处理
|
||||
if (rawChunk.contentBlockStop) {
|
||||
const blockIndex = rawChunk.contentBlockStop.contentBlockIndex || 0
|
||||
const toolCall = toolCalls[blockIndex]
|
||||
if (toolCall && accumulatedJson) {
|
||||
try {
|
||||
toolCall.input = JSON.parse(accumulatedJson)
|
||||
controller.enqueue({
|
||||
type: ChunkType.MCP_TOOL_CREATED,
|
||||
tool_calls: [toolCall]
|
||||
} as MCPToolCreatedChunk)
|
||||
accumulatedJson = ''
|
||||
} catch (error) {
|
||||
logger.error('Error parsing tool call input:', error as Error)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 处理消息结束事件
|
||||
if (rawChunk.messageStop) {
|
||||
// 从metadata中提取usage信息
|
||||
const usage = rawChunk.metadata?.usage || {}
|
||||
|
||||
controller.enqueue({
|
||||
type: ChunkType.LLM_RESPONSE_COMPLETE,
|
||||
response: {
|
||||
usage: {
|
||||
prompt_tokens: usage.inputTokens || 0,
|
||||
completion_tokens: usage.outputTokens || 0,
|
||||
total_tokens: (usage.inputTokens || 0) + (usage.outputTokens || 0)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public convertMcpToolsToSdkTools(mcpTools: MCPTool[]): AwsBedrockSdkTool[] {
|
||||
return mcpToolsToAwsBedrockTools(mcpTools)
|
||||
}
|
||||
|
||||
convertSdkToolCallToMcp(toolCall: AwsBedrockSdkToolCall, mcpTools: MCPTool[]): MCPTool | undefined {
|
||||
return awsBedrockToolUseToMcpTool(mcpTools, toolCall)
|
||||
}
|
||||
|
||||
convertSdkToolCallToMcpToolResponse(toolCall: AwsBedrockSdkToolCall, mcpTool: MCPTool): ToolCallResponse {
|
||||
return {
|
||||
id: toolCall.id,
|
||||
tool: mcpTool,
|
||||
arguments: toolCall.input || {},
|
||||
status: 'pending',
|
||||
toolCallId: toolCall.id
|
||||
}
|
||||
}
|
||||
|
||||
override buildSdkMessages(
|
||||
currentReqMessages: AwsBedrockSdkMessageParam[],
|
||||
output: AwsBedrockSdkRawOutput | string | undefined,
|
||||
toolResults: AwsBedrockSdkMessageParam[]
|
||||
): AwsBedrockSdkMessageParam[] {
|
||||
const messages: AwsBedrockSdkMessageParam[] = [...currentReqMessages]
|
||||
|
||||
if (typeof output === 'string') {
|
||||
messages.push({
|
||||
role: 'assistant',
|
||||
content: [{ text: output }]
|
||||
})
|
||||
}
|
||||
|
||||
if (toolResults.length > 0) {
|
||||
messages.push(...toolResults)
|
||||
}
|
||||
|
||||
return messages
|
||||
}
|
||||
|
||||
override estimateMessageTokens(message: AwsBedrockSdkMessageParam): number {
|
||||
if (typeof message.content === 'string') {
|
||||
return estimateTextTokens(message.content)
|
||||
}
|
||||
const content = message.content
|
||||
if (Array.isArray(content)) {
|
||||
return content.reduce((total, item) => {
|
||||
if (item.text) {
|
||||
return total + estimateTextTokens(item.text)
|
||||
}
|
||||
return total
|
||||
}, 0)
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
public convertMcpToolResponseToSdkMessageParam(
|
||||
mcpToolResponse: MCPToolResponse,
|
||||
resp: MCPCallToolResponse,
|
||||
model: Model
|
||||
): AwsBedrockSdkMessageParam | undefined {
|
||||
if ('toolUseId' in mcpToolResponse && mcpToolResponse.toolUseId) {
|
||||
// 使用专用的转换函数处理 toolUseId 情况
|
||||
return mcpToolCallResponseToAwsBedrockMessage(mcpToolResponse, resp, model)
|
||||
} else if ('toolCallId' in mcpToolResponse && mcpToolResponse.toolCallId) {
|
||||
return {
|
||||
role: 'user',
|
||||
content: [
|
||||
{
|
||||
toolResult: {
|
||||
toolUseId: mcpToolResponse.toolCallId,
|
||||
content: resp.content
|
||||
.map((item) => {
|
||||
if (item.type === 'text') {
|
||||
// 确保文本不为空,如果为空则提供默认文本
|
||||
return { text: item.text && item.text.trim() ? item.text : 'No text content' }
|
||||
}
|
||||
if (item.type === 'image' && item.data) {
|
||||
const awsImage = convertBase64ImageToAwsBedrockFormat(item.data, item.mimeType)
|
||||
if (awsImage) {
|
||||
return { image: awsImage }
|
||||
} else {
|
||||
// 如果转换失败,返回描述性文本
|
||||
return { text: `[Image: ${item.mimeType || 'unknown format'}]` }
|
||||
}
|
||||
}
|
||||
return { text: JSON.stringify(item) }
|
||||
})
|
||||
.filter((content) => content !== null)
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
return undefined
|
||||
}
|
||||
|
||||
extractMessagesFromSdkPayload(sdkPayload: AwsBedrockSdkParams): AwsBedrockSdkMessageParam[] {
|
||||
return sdkPayload.messages || []
|
||||
}
|
||||
}
|
||||
@@ -45,7 +45,7 @@ export const StreamAdapterMiddleware: CompletionsMiddleware =
|
||||
} else if (result.rawOutput) {
|
||||
// 非流式输出,强行变为可读流
|
||||
const whatwgReadableStream: ReadableStream<SdkRawChunk> = createSingleChunkReadableStream<SdkRawChunk>(
|
||||
result.rawOutput
|
||||
result.rawOutput as SdkRawChunk
|
||||
)
|
||||
return {
|
||||
...result,
|
||||
|
||||
Binary file not shown.
|
After Width: | Height: | Size: 182 KiB |
@@ -2345,7 +2345,8 @@ export const SYSTEM_MODELS: Record<string, Model[]> = {
|
||||
group: 'google'
|
||||
}
|
||||
],
|
||||
'new-api': []
|
||||
'new-api': [],
|
||||
'aws-bedrock': []
|
||||
}
|
||||
|
||||
export const TEXT_TO_IMAGES_MODELS = [
|
||||
|
||||
@@ -5,6 +5,7 @@ import Ai302ProviderLogo from '@renderer/assets/images/providers/302ai.webp'
|
||||
import AiHubMixProviderLogo from '@renderer/assets/images/providers/aihubmix.webp'
|
||||
import AlayaNewProviderLogo from '@renderer/assets/images/providers/alayanew.webp'
|
||||
import AnthropicProviderLogo from '@renderer/assets/images/providers/anthropic.png'
|
||||
import AwsProviderLogo from '@renderer/assets/images/providers/aws-bedrock.png'
|
||||
import BaichuanProviderLogo from '@renderer/assets/images/providers/baichuan.png'
|
||||
import BaiduCloudProviderLogo from '@renderer/assets/images/providers/baidu-cloud.svg'
|
||||
import BailianProviderLogo from '@renderer/assets/images/providers/bailian.png'
|
||||
@@ -106,7 +107,8 @@ const PROVIDER_LOGO_MAP = {
|
||||
cephalon: CephalonProviderLogo,
|
||||
lanyun: LanyunProviderLogo,
|
||||
vertexai: VertexAIProviderLogo,
|
||||
'new-api': NewAPIProviderLogo
|
||||
'new-api': NewAPIProviderLogo,
|
||||
'aws-bedrock': AwsProviderLogo
|
||||
} as const
|
||||
|
||||
export function getProviderLogo(providerId: string) {
|
||||
@@ -689,5 +691,16 @@ export const PROVIDER_CONFIG = {
|
||||
official: 'https://docs.newapi.pro/',
|
||||
docs: 'https://docs.newapi.pro'
|
||||
}
|
||||
},
|
||||
'aws-bedrock': {
|
||||
api: {
|
||||
url: ''
|
||||
},
|
||||
websites: {
|
||||
official: 'https://aws.amazon.com/bedrock/',
|
||||
apiKey: 'https://docs.aws.amazon.com/bedrock/latest/userguide/security-iam.html',
|
||||
docs: 'https://docs.aws.amazon.com/bedrock/',
|
||||
models: 'https://docs.aws.amazon.com/bedrock/latest/userguide/models-supported.html'
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,31 @@
|
||||
import store, { useAppSelector } from '@renderer/store'
|
||||
import { setAwsBedrockAccessKeyId, setAwsBedrockRegion, setAwsBedrockSecretAccessKey } from '@renderer/store/llm'
|
||||
import { useDispatch } from 'react-redux'
|
||||
|
||||
export function useAwsBedrockSettings() {
|
||||
const settings = useAppSelector((state) => state.llm.settings.awsBedrock)
|
||||
const dispatch = useDispatch()
|
||||
|
||||
return {
|
||||
...settings,
|
||||
setAccessKeyId: (accessKeyId: string) => dispatch(setAwsBedrockAccessKeyId(accessKeyId)),
|
||||
setSecretAccessKey: (secretAccessKey: string) => dispatch(setAwsBedrockSecretAccessKey(secretAccessKey)),
|
||||
setRegion: (region: string) => dispatch(setAwsBedrockRegion(region))
|
||||
}
|
||||
}
|
||||
|
||||
export function getAwsBedrockSettings() {
|
||||
return store.getState().llm.settings.awsBedrock
|
||||
}
|
||||
|
||||
export function getAwsBedrockAccessKeyId() {
|
||||
return store.getState().llm.settings.awsBedrock.accessKeyId
|
||||
}
|
||||
|
||||
export function getAwsBedrockSecretAccessKey() {
|
||||
return store.getState().llm.settings.awsBedrock.secretAccessKey
|
||||
}
|
||||
|
||||
export function getAwsBedrockRegion() {
|
||||
return store.getState().llm.settings.awsBedrock.region
|
||||
}
|
||||
@@ -13,6 +13,7 @@ const providerKeyMap = {
|
||||
aihubmix: 'provider.aihubmix',
|
||||
alayanew: 'provider.alayanew',
|
||||
anthropic: 'provider.anthropic',
|
||||
'aws-bedrock': 'provider.aws-bedrock',
|
||||
'azure-openai': 'provider.azure-openai',
|
||||
baichuan: 'provider.baichuan',
|
||||
'baidu-cloud': 'provider.baidu-cloud',
|
||||
|
||||
@@ -1594,6 +1594,7 @@
|
||||
"aihubmix": "AiHubMix",
|
||||
"alayanew": "Alaya NeW",
|
||||
"anthropic": "Anthropic",
|
||||
"aws-bedrock": "AWS Bedrock",
|
||||
"azure-openai": "Azure OpenAI",
|
||||
"baichuan": "Baichuan",
|
||||
"baidu-cloud": "Baidu Cloud",
|
||||
@@ -3035,6 +3036,16 @@
|
||||
"tip": "Multiple keys separated by commas or spaces"
|
||||
},
|
||||
"api_version": "API Version",
|
||||
"aws-bedrock": {
|
||||
"access_key_id": "AWS Access Key ID",
|
||||
"access_key_id_help": "Your AWS Access Key ID for accessing AWS Bedrock services",
|
||||
"description": "AWS Bedrock is Amazon's fully managed foundation model service that supports various advanced large language models",
|
||||
"region": "AWS Region",
|
||||
"region_help": "Your AWS service region, e.g., us-east-1",
|
||||
"secret_access_key": "AWS Secret Access Key",
|
||||
"secret_access_key_help": "Your AWS Secret Access Key, please keep it secure",
|
||||
"title": "AWS Bedrock Configuration"
|
||||
},
|
||||
"azure": {
|
||||
"apiversion": {
|
||||
"tip": "The API version of Azure OpenAI, if you want to use Response API, please enter the preview version"
|
||||
|
||||
@@ -1594,6 +1594,7 @@
|
||||
"aihubmix": "AiHubMix",
|
||||
"alayanew": "Alaya NeW",
|
||||
"anthropic": "Anthropic",
|
||||
"aws-bedrock": "AWS Bedrock",
|
||||
"azure-openai": "Azure OpenAI",
|
||||
"baichuan": "百川",
|
||||
"baidu-cloud": "Baidu Cloud",
|
||||
@@ -3035,6 +3036,16 @@
|
||||
"tip": "複数のキーはカンマまたはスペースで区切ります"
|
||||
},
|
||||
"api_version": "APIバージョン",
|
||||
"aws-bedrock": {
|
||||
"access_key_id": "AWS アクセスキー ID",
|
||||
"access_key_id_help": "あなたの AWS アクセスキー ID は、AWS Bedrock サービスへのアクセスに使用されます",
|
||||
"description": "AWS Bedrock は、Amazon が提供する完全に管理されたベースモデルサービスで、さまざまな最先端の大言語モデルをサポートしています",
|
||||
"region": "AWS リージョン",
|
||||
"region_help": "あなたの AWS サービスリージョン、例:us-east-1",
|
||||
"secret_access_key": "AWS アクセスキー",
|
||||
"secret_access_key_help": "あなたの AWS アクセスキー、安全に保管してください",
|
||||
"title": "AWS Bedrock 設定"
|
||||
},
|
||||
"azure": {
|
||||
"apiversion": {
|
||||
"tip": "Azure OpenAIのAPIバージョン。Response APIを使用する場合は、previewバージョンを入力してください"
|
||||
|
||||
@@ -1594,6 +1594,7 @@
|
||||
"aihubmix": "AiHubMix",
|
||||
"alayanew": "Alaya NeW",
|
||||
"anthropic": "Anthropic",
|
||||
"aws-bedrock": "AWS Bedrock",
|
||||
"azure-openai": "Azure OpenAI",
|
||||
"baichuan": "Baichuan",
|
||||
"baidu-cloud": "Baidu Cloud",
|
||||
@@ -3035,6 +3036,16 @@
|
||||
"tip": "Несколько ключей, разделенных запятыми или пробелами"
|
||||
},
|
||||
"api_version": "Версия API",
|
||||
"aws-bedrock": {
|
||||
"access_key_id": "AWS Ключ доступа ID",
|
||||
"access_key_id_help": "Ваш AWS Ключ доступа ID для доступа к AWS Bedrock",
|
||||
"description": "AWS Bedrock — это полное управляемое сервисное предложение для моделей, поддерживающее различные современные модели языка",
|
||||
"region": "AWS регион",
|
||||
"region_help": "Ваш регион AWS, например us-east-1",
|
||||
"secret_access_key": "AWS Ключ доступа",
|
||||
"secret_access_key_help": "Ваш AWS Ключ доступа, пожалуйста, храните его в безопасности",
|
||||
"title": "AWS Bedrock Конфигурация"
|
||||
},
|
||||
"azure": {
|
||||
"apiversion": {
|
||||
"tip": "Версия API Azure OpenAI. Если вы хотите использовать Response API, введите версию preview"
|
||||
|
||||
@@ -1594,6 +1594,7 @@
|
||||
"aihubmix": "AiHubMix",
|
||||
"alayanew": "Alaya NeW",
|
||||
"anthropic": "Anthropic",
|
||||
"aws-bedrock": "AWS Bedrock",
|
||||
"azure-openai": "Azure OpenAI",
|
||||
"baichuan": "百川",
|
||||
"baidu-cloud": "百度云千帆",
|
||||
@@ -3035,6 +3036,16 @@
|
||||
"tip": "多个密钥使用逗号或空格分隔"
|
||||
},
|
||||
"api_version": "API 版本",
|
||||
"aws-bedrock": {
|
||||
"access_key_id": "AWS 访问密钥 ID",
|
||||
"access_key_id_help": "您的 AWS 访问密钥 ID,用于访问 AWS Bedrock 服务",
|
||||
"description": "AWS Bedrock 是亚马逊提供的全托管基础模型服务,支持多种先进的大语言模型",
|
||||
"region": "AWS 区域",
|
||||
"region_help": "您的 AWS 服务区域,例如 us-east-1",
|
||||
"secret_access_key": "AWS 访问密钥",
|
||||
"secret_access_key_help": "您的 AWS 访问密钥,请妥善保管",
|
||||
"title": "AWS Bedrock 配置"
|
||||
},
|
||||
"azure": {
|
||||
"apiversion": {
|
||||
"tip": "Azure OpenAI 的 API 版本,如果想要使用 Response API,请输入 preview 版本"
|
||||
|
||||
@@ -1594,6 +1594,7 @@
|
||||
"aihubmix": "AiHubMix",
|
||||
"alayanew": "Alaya NeW",
|
||||
"anthropic": "Anthropic",
|
||||
"aws-bedrock": "AWS Bedrock",
|
||||
"azure-openai": "Azure OpenAI",
|
||||
"baichuan": "百川",
|
||||
"baidu-cloud": "百度雲千帆",
|
||||
@@ -3035,6 +3036,16 @@
|
||||
"tip": "多個金鑰使用逗號或空格分隔"
|
||||
},
|
||||
"api_version": "API 版本",
|
||||
"aws-bedrock": {
|
||||
"access_key_id": "AWS 存取密鑰 ID",
|
||||
"access_key_id_help": "您的 AWS 存取密鑰 ID,用於存取 AWS Bedrock 服務",
|
||||
"description": "AWS Bedrock 是亞馬遜提供的全托管基础模型服務,支持多種先進的大語言模型",
|
||||
"region": "AWS 區域",
|
||||
"region_help": "您的 AWS 服務區域,例如 us-east-1",
|
||||
"secret_access_key": "AWS 存取密鑰",
|
||||
"secret_access_key_help": "您的 AWS 存取密鑰,請妥善保管",
|
||||
"title": "AWS Bedrock 設定"
|
||||
},
|
||||
"azure": {
|
||||
"apiversion": {
|
||||
"tip": "Azure OpenAI 的 API 版本,如果想要使用 Response API,請輸入 preview 版本"
|
||||
|
||||
@@ -0,0 +1,74 @@
|
||||
import { HStack } from '@renderer/components/Layout'
|
||||
import { PROVIDER_CONFIG } from '@renderer/config/providers'
|
||||
import { useAwsBedrockSettings } from '@renderer/hooks/useAwsBedrock'
|
||||
import { Alert, Input } from 'antd'
|
||||
import { FC, useState } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
|
||||
import { SettingHelpLink, SettingHelpText, SettingHelpTextRow, SettingSubtitle } from '..'
|
||||
|
||||
const AwsBedrockSettings: FC = () => {
|
||||
const { t } = useTranslation()
|
||||
const { accessKeyId, secretAccessKey, region, setAccessKeyId, setSecretAccessKey, setRegion } =
|
||||
useAwsBedrockSettings()
|
||||
|
||||
const providerConfig = PROVIDER_CONFIG['aws-bedrock']
|
||||
const apiKeyWebsite = providerConfig?.websites?.apiKey
|
||||
|
||||
const [localAccessKeyId, setLocalAccessKeyId] = useState(accessKeyId)
|
||||
const [localSecretAccessKey, setLocalSecretAccessKey] = useState(secretAccessKey)
|
||||
const [localRegion, setLocalRegion] = useState(region)
|
||||
|
||||
return (
|
||||
<>
|
||||
<SettingSubtitle style={{ marginTop: 5 }}>{t('settings.provider.aws-bedrock.title')}</SettingSubtitle>
|
||||
<Alert type="info" style={{ marginTop: 5 }} message={t('settings.provider.aws-bedrock.description')} showIcon />
|
||||
|
||||
<SettingSubtitle style={{ marginTop: 5 }}>{t('settings.provider.aws-bedrock.access_key_id')}</SettingSubtitle>
|
||||
<Input
|
||||
value={localAccessKeyId}
|
||||
placeholder="Access Key ID"
|
||||
onChange={(e) => setLocalAccessKeyId(e.target.value)}
|
||||
onBlur={() => setAccessKeyId(localAccessKeyId)}
|
||||
style={{ marginTop: 5 }}
|
||||
/>
|
||||
<SettingHelpTextRow>
|
||||
<SettingHelpText>{t('settings.provider.aws-bedrock.access_key_id_help')}</SettingHelpText>
|
||||
</SettingHelpTextRow>
|
||||
|
||||
<SettingSubtitle style={{ marginTop: 5 }}>{t('settings.provider.aws-bedrock.secret_access_key')}</SettingSubtitle>
|
||||
<Input.Password
|
||||
value={localSecretAccessKey}
|
||||
placeholder="Secret Access Key"
|
||||
onChange={(e) => setLocalSecretAccessKey(e.target.value)}
|
||||
onBlur={() => setSecretAccessKey(localSecretAccessKey)}
|
||||
style={{ marginTop: 5 }}
|
||||
spellCheck={false}
|
||||
/>
|
||||
{apiKeyWebsite && (
|
||||
<SettingHelpTextRow style={{ justifyContent: 'space-between' }}>
|
||||
<HStack>
|
||||
<SettingHelpLink target="_blank" href={apiKeyWebsite}>
|
||||
{t('settings.provider.get_api_key')}
|
||||
</SettingHelpLink>
|
||||
</HStack>
|
||||
<SettingHelpText>{t('settings.provider.aws-bedrock.secret_access_key_help')}</SettingHelpText>
|
||||
</SettingHelpTextRow>
|
||||
)}
|
||||
|
||||
<SettingSubtitle style={{ marginTop: 5 }}>{t('settings.provider.aws-bedrock.region')}</SettingSubtitle>
|
||||
<Input
|
||||
value={localRegion}
|
||||
placeholder="us-east-1"
|
||||
onChange={(e) => setLocalRegion(e.target.value)}
|
||||
onBlur={() => setRegion(localRegion)}
|
||||
style={{ marginTop: 5 }}
|
||||
/>
|
||||
<SettingHelpTextRow>
|
||||
<SettingHelpText>{t('settings.provider.aws-bedrock.region_help')}</SettingHelpText>
|
||||
</SettingHelpTextRow>
|
||||
</>
|
||||
)
|
||||
}
|
||||
|
||||
export default AwsBedrockSettings
|
||||
@@ -29,6 +29,7 @@ import {
|
||||
SettingSubtitle,
|
||||
SettingTitle
|
||||
} from '..'
|
||||
import AwsBedrockSettings from './AwsBedrockSettings'
|
||||
import CustomHeaderPopup from './CustomHeaderPopup'
|
||||
import DMXAPISettings from './DMXAPISettings'
|
||||
import GithubCopilotSettings from './GithubCopilotSettings'
|
||||
@@ -259,7 +260,7 @@ const ProviderSetting: FC<Props> = ({ providerId }) => {
|
||||
{isProviderSupportAuth(provider) && <ProviderOAuth providerId={provider.id} />}
|
||||
{provider.id === 'openai' && <OpenAIAlert />}
|
||||
{isDmxapi && <DMXAPISettings providerId={provider.id} />}
|
||||
{provider.id !== 'vertexai' && (
|
||||
{provider.id !== 'vertexai' && provider.id !== 'aws-bedrock' && (
|
||||
<>
|
||||
<SettingSubtitle
|
||||
style={{
|
||||
@@ -372,6 +373,7 @@ const ProviderSetting: FC<Props> = ({ providerId }) => {
|
||||
{provider.id === 'lmstudio' && <LMStudioSettings />}
|
||||
{provider.id === 'gpustack' && <GPUStackSettings />}
|
||||
{provider.id === 'copilot' && <GithubCopilotSettings providerId={provider.id} />}
|
||||
{provider.id === 'aws-bedrock' && <AwsBedrockSettings />}
|
||||
{provider.id === 'vertexai' && <VertexAISettings providerId={provider.id} />}
|
||||
<ModelList providerId={provider.id} />
|
||||
</SettingContainer>
|
||||
|
||||
@@ -22,6 +22,11 @@ type LlmSettings = {
|
||||
projectId: string
|
||||
location: string
|
||||
}
|
||||
awsBedrock: {
|
||||
accessKeyId: string
|
||||
secretAccessKey: string
|
||||
region: string
|
||||
}
|
||||
}
|
||||
|
||||
export interface LlmState {
|
||||
@@ -537,6 +542,16 @@ export const INITIAL_PROVIDERS: Provider[] = [
|
||||
models: SYSTEM_MODELS.voyageai,
|
||||
isSystem: true,
|
||||
enabled: false
|
||||
},
|
||||
{
|
||||
id: 'aws-bedrock',
|
||||
name: 'AWS Bedrock',
|
||||
type: 'aws-bedrock',
|
||||
apiKey: '',
|
||||
apiHost: '',
|
||||
models: SYSTEM_MODELS['aws-bedrock'],
|
||||
isSystem: true,
|
||||
enabled: false
|
||||
}
|
||||
]
|
||||
|
||||
@@ -563,6 +578,11 @@ export const initialState: LlmState = {
|
||||
},
|
||||
projectId: '',
|
||||
location: ''
|
||||
},
|
||||
awsBedrock: {
|
||||
accessKeyId: '',
|
||||
secretAccessKey: '',
|
||||
region: ''
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -687,6 +707,15 @@ const llmSlice = createSlice({
|
||||
setVertexAIServiceAccountClientEmail: (state, action: PayloadAction<string>) => {
|
||||
state.settings.vertexai.serviceAccount.clientEmail = action.payload
|
||||
},
|
||||
setAwsBedrockAccessKeyId: (state, action: PayloadAction<string>) => {
|
||||
state.settings.awsBedrock.accessKeyId = action.payload
|
||||
},
|
||||
setAwsBedrockSecretAccessKey: (state, action: PayloadAction<string>) => {
|
||||
state.settings.awsBedrock.secretAccessKey = action.payload
|
||||
},
|
||||
setAwsBedrockRegion: (state, action: PayloadAction<string>) => {
|
||||
state.settings.awsBedrock.region = action.payload
|
||||
},
|
||||
updateModel: (
|
||||
state,
|
||||
action: PayloadAction<{
|
||||
@@ -723,6 +752,9 @@ export const {
|
||||
setVertexAILocation,
|
||||
setVertexAIServiceAccountPrivateKey,
|
||||
setVertexAIServiceAccountClientEmail,
|
||||
setAwsBedrockAccessKeyId,
|
||||
setAwsBedrockSecretAccessKey,
|
||||
setAwsBedrockRegion,
|
||||
updateModel
|
||||
} = llmSlice.actions
|
||||
|
||||
|
||||
@@ -1907,6 +1907,13 @@ const migrateConfig = {
|
||||
updateModelTextDelta(state.assistants.defaultAssistant.defaultModel)
|
||||
}
|
||||
|
||||
addProvider(state, 'aws-bedrock')
|
||||
|
||||
// 初始化 awsBedrock 设置
|
||||
if (!state.llm.settings.awsBedrock) {
|
||||
state.llm.settings.awsBedrock = llmInitialState.settings.awsBedrock
|
||||
}
|
||||
|
||||
return state
|
||||
} catch (error) {
|
||||
logger.error('migrate 124 error', error as Error)
|
||||
|
||||
@@ -187,6 +187,7 @@ export type ProviderType =
|
||||
| 'azure-openai'
|
||||
| 'vertexai'
|
||||
| 'mistral'
|
||||
| 'aws-bedrock'
|
||||
|
||||
export type ModelType = 'text' | 'vision' | 'embedding' | 'reasoning' | 'function_calling' | 'web_search' | 'rerank'
|
||||
|
||||
|
||||
@@ -9,6 +9,7 @@ import {
|
||||
} from '@anthropic-ai/sdk/resources'
|
||||
import { MessageStream } from '@anthropic-ai/sdk/resources/messages/messages'
|
||||
import AnthropicVertex from '@anthropic-ai/vertex-sdk'
|
||||
import type { BedrockRuntimeClient } from '@aws-sdk/client-bedrock-runtime'
|
||||
import {
|
||||
Content,
|
||||
CreateChatParameters,
|
||||
@@ -24,21 +25,43 @@ import { Stream } from 'openai/streaming'
|
||||
|
||||
import { EndpointType } from './index'
|
||||
|
||||
export type SdkInstance = OpenAI | AzureOpenAI | Anthropic | AnthropicVertex | GoogleGenAI
|
||||
export type SdkParams = OpenAISdkParams | OpenAIResponseSdkParams | AnthropicSdkParams | GeminiSdkParams
|
||||
export type SdkRawChunk = OpenAISdkRawChunk | OpenAIResponseSdkRawChunk | AnthropicSdkRawChunk | GeminiSdkRawChunk
|
||||
export type SdkRawOutput = OpenAISdkRawOutput | OpenAIResponseSdkRawOutput | AnthropicSdkRawOutput | GeminiSdkRawOutput
|
||||
export type SdkInstance = OpenAI | AzureOpenAI | Anthropic | AnthropicVertex | GoogleGenAI | AwsBedrockSdkInstance
|
||||
export type SdkParams =
|
||||
| OpenAISdkParams
|
||||
| OpenAIResponseSdkParams
|
||||
| AnthropicSdkParams
|
||||
| GeminiSdkParams
|
||||
| AwsBedrockSdkParams
|
||||
export type SdkRawChunk =
|
||||
| OpenAISdkRawChunk
|
||||
| OpenAIResponseSdkRawChunk
|
||||
| AnthropicSdkRawChunk
|
||||
| GeminiSdkRawChunk
|
||||
| AwsBedrockSdkRawChunk
|
||||
export type SdkRawOutput =
|
||||
| OpenAISdkRawOutput
|
||||
| OpenAIResponseSdkRawOutput
|
||||
| AnthropicSdkRawOutput
|
||||
| GeminiSdkRawOutput
|
||||
| AwsBedrockSdkRawOutput
|
||||
export type SdkMessageParam =
|
||||
| OpenAISdkMessageParam
|
||||
| OpenAIResponseSdkMessageParam
|
||||
| AnthropicSdkMessageParam
|
||||
| GeminiSdkMessageParam
|
||||
| AwsBedrockSdkMessageParam
|
||||
export type SdkToolCall =
|
||||
| OpenAI.Chat.Completions.ChatCompletionMessageToolCall
|
||||
| ToolUseBlock
|
||||
| FunctionCall
|
||||
| OpenAIResponseSdkToolCall
|
||||
export type SdkTool = OpenAI.Chat.Completions.ChatCompletionTool | ToolUnion | Tool | OpenAIResponseSdkTool
|
||||
| AwsBedrockSdkToolCall
|
||||
export type SdkTool =
|
||||
| OpenAI.Chat.Completions.ChatCompletionTool
|
||||
| ToolUnion
|
||||
| Tool
|
||||
| OpenAIResponseSdkTool
|
||||
| AwsBedrockSdkTool
|
||||
export type SdkModel = OpenAI.Models.Model | Anthropic.ModelInfo | GeminiModel | NewApiModel
|
||||
|
||||
export type RequestOptions = Anthropic.RequestOptions | OpenAI.RequestOptions | GeminiOptions
|
||||
@@ -117,3 +140,119 @@ export type GeminiOptions = {
|
||||
export interface NewApiModel extends OpenAI.Models.Model {
|
||||
supported_endpoint_types?: EndpointType[]
|
||||
}
|
||||
|
||||
/**
|
||||
* AWS Bedrock
|
||||
*/
|
||||
export interface AwsBedrockSdkInstance {
|
||||
client: BedrockRuntimeClient
|
||||
region: string
|
||||
}
|
||||
|
||||
export interface AwsBedrockSdkParams {
|
||||
modelId: string
|
||||
messages: AwsBedrockSdkMessageParam[]
|
||||
system?: string
|
||||
maxTokens?: number
|
||||
temperature?: number
|
||||
topP?: number
|
||||
stream?: boolean
|
||||
tools?: AwsBedrockSdkTool[]
|
||||
}
|
||||
|
||||
export interface AwsBedrockSdkMessageParam {
|
||||
role: 'user' | 'assistant'
|
||||
content: Array<{
|
||||
text?: string
|
||||
image?: {
|
||||
format: 'png' | 'jpeg' | 'gif' | 'webp'
|
||||
source: {
|
||||
bytes?: Uint8Array
|
||||
s3Location?: {
|
||||
uri: string
|
||||
bucketOwner?: string
|
||||
}
|
||||
}
|
||||
}
|
||||
toolResult?: {
|
||||
toolUseId: string
|
||||
content: Array<{
|
||||
json?: any
|
||||
text?: string
|
||||
image?: {
|
||||
format: 'png' | 'jpeg' | 'gif' | 'webp'
|
||||
source: {
|
||||
bytes?: Uint8Array
|
||||
s3Location?: {
|
||||
uri: string
|
||||
bucketOwner?: string
|
||||
}
|
||||
}
|
||||
}
|
||||
document?: any
|
||||
video?: any
|
||||
}>
|
||||
status?: 'success' | 'error'
|
||||
}
|
||||
toolUse?: {
|
||||
toolUseId: string
|
||||
name: string
|
||||
input: any
|
||||
}
|
||||
}>
|
||||
}
|
||||
|
||||
export interface AwsBedrockSdkRawChunk {
|
||||
contentBlockStart?: {
|
||||
start?: {
|
||||
toolUse?: {
|
||||
toolUseId: string
|
||||
name: string
|
||||
}
|
||||
}
|
||||
contentBlockIndex?: number
|
||||
}
|
||||
contentBlockDelta?: {
|
||||
delta?: {
|
||||
text?: string
|
||||
toolUse?: {
|
||||
input?: string
|
||||
}
|
||||
}
|
||||
contentBlockIndex?: number
|
||||
}
|
||||
contentBlockStop?: {
|
||||
contentBlockIndex?: number
|
||||
}
|
||||
messageStart?: any
|
||||
messageStop?: any
|
||||
metadata?: any
|
||||
}
|
||||
|
||||
export type AwsBedrockSdkRawOutput = { output: any } | AsyncIterable<AwsBedrockSdkRawChunk>
|
||||
|
||||
export interface AwsBedrockSdkTool {
|
||||
toolSpec: {
|
||||
name: string
|
||||
description?: string
|
||||
inputSchema: {
|
||||
json: {
|
||||
type: string
|
||||
properties?: {
|
||||
[key: string]: {
|
||||
type: string
|
||||
description?: string
|
||||
}
|
||||
}
|
||||
required?: string[]
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
export interface AwsBedrockSdkToolCall {
|
||||
id: string
|
||||
name: string
|
||||
input: any
|
||||
toolUseId: string
|
||||
}
|
||||
|
||||
@@ -0,0 +1,226 @@
|
||||
import { describe, expect, it } from 'vitest'
|
||||
|
||||
import {
|
||||
type AwsBedrockImage,
|
||||
type AwsBedrockImageFormat,
|
||||
base64ToUint8Array,
|
||||
convertBase64ImageToAwsBedrockFormat,
|
||||
extractImageFormatFromMimeType,
|
||||
isAwsBedrockSupportedImageFormat
|
||||
} from '../aws-bedrock-utils'
|
||||
|
||||
describe('utils/aws-bedrock-utils', () => {
|
||||
describe('extractImageFormatFromMimeType', () => {
|
||||
it('should extract png format from mime type', () => {
|
||||
expect(extractImageFormatFromMimeType('image/png')).toBe('png')
|
||||
})
|
||||
|
||||
it('should extract jpeg format from mime type', () => {
|
||||
expect(extractImageFormatFromMimeType('image/jpeg')).toBe('jpeg')
|
||||
})
|
||||
|
||||
it('should extract gif format from mime type', () => {
|
||||
expect(extractImageFormatFromMimeType('image/gif')).toBe('gif')
|
||||
})
|
||||
|
||||
it('should extract webp format from mime type', () => {
|
||||
expect(extractImageFormatFromMimeType('image/webp')).toBe('webp')
|
||||
})
|
||||
|
||||
it('should return null for unsupported mime type', () => {
|
||||
expect(extractImageFormatFromMimeType('image/bmp')).toBe(null)
|
||||
expect(extractImageFormatFromMimeType('image/svg+xml')).toBe(null)
|
||||
expect(extractImageFormatFromMimeType('image/tiff')).toBe(null)
|
||||
})
|
||||
|
||||
it('should return null for invalid mime type format', () => {
|
||||
expect(extractImageFormatFromMimeType('invalid')).toBe(null)
|
||||
expect(extractImageFormatFromMimeType('text/plain')).toBe(null)
|
||||
expect(extractImageFormatFromMimeType('application/json')).toBe(null)
|
||||
})
|
||||
|
||||
it('should return null for undefined or empty input', () => {
|
||||
expect(extractImageFormatFromMimeType(undefined)).toBe(null)
|
||||
expect(extractImageFormatFromMimeType('')).toBe(null)
|
||||
})
|
||||
|
||||
it('should handle mime type with additional parameters', () => {
|
||||
expect(extractImageFormatFromMimeType('image/png; charset=utf-8')).toBe(null)
|
||||
expect(extractImageFormatFromMimeType('image/jpeg; quality=95')).toBe(null)
|
||||
})
|
||||
})
|
||||
|
||||
describe('base64ToUint8Array', () => {
|
||||
it('should convert valid base64 string to Uint8Array', () => {
|
||||
// "hello" in base64 is "aGVsbG8="
|
||||
const base64 = 'aGVsbG8='
|
||||
const result = base64ToUint8Array(base64)
|
||||
|
||||
expect(result).toBeInstanceOf(Uint8Array)
|
||||
expect(result.length).toBe(5)
|
||||
expect(Array.from(result)).toEqual([104, 101, 108, 108, 111]) // ASCII values for "hello"
|
||||
})
|
||||
|
||||
it('should convert empty base64 string to empty Uint8Array', () => {
|
||||
const result = base64ToUint8Array('')
|
||||
expect(result).toBeInstanceOf(Uint8Array)
|
||||
expect(result.length).toBe(0)
|
||||
})
|
||||
|
||||
it('should handle base64 with padding', () => {
|
||||
const base64 = 'YQ==' // "a" in base64
|
||||
const result = base64ToUint8Array(base64)
|
||||
|
||||
expect(result).toBeInstanceOf(Uint8Array)
|
||||
expect(result.length).toBe(1)
|
||||
expect(result[0]).toBe(97) // ASCII value for "a"
|
||||
})
|
||||
|
||||
it('should handle base64 without padding', () => {
|
||||
const base64 = 'YWI' // "ab" in base64 without padding
|
||||
const result = base64ToUint8Array(base64)
|
||||
|
||||
expect(result).toBeInstanceOf(Uint8Array)
|
||||
expect(result.length).toBe(2)
|
||||
expect(Array.from(result)).toEqual([97, 98]) // ASCII values for "ab"
|
||||
})
|
||||
|
||||
it('should throw error for invalid base64 string', () => {
|
||||
expect(() => base64ToUint8Array('invalid!@#$%^&*()')).toThrow('Failed to decode base64 data')
|
||||
expect(() => base64ToUint8Array('hello world!')).toThrow('Failed to decode base64 data')
|
||||
})
|
||||
|
||||
it('should handle binary data correctly', () => {
|
||||
// Binary data that represents a simple image header
|
||||
const binaryData = new Uint8Array([137, 80, 78, 71]) // PNG header
|
||||
const base64 = btoa(String.fromCharCode(...binaryData))
|
||||
const result = base64ToUint8Array(base64)
|
||||
|
||||
expect(result).toBeInstanceOf(Uint8Array)
|
||||
expect(Array.from(result)).toEqual([137, 80, 78, 71])
|
||||
})
|
||||
})
|
||||
|
||||
describe('convertBase64ImageToAwsBedrockFormat', () => {
|
||||
const validBase64 = 'aGVsbG8=' // "hello" in base64
|
||||
|
||||
it('should convert base64 image with valid mime type', () => {
|
||||
const result = convertBase64ImageToAwsBedrockFormat(validBase64, 'image/png')
|
||||
|
||||
expect(result).not.toBe(null)
|
||||
expect(result?.format).toBe('png')
|
||||
expect(result?.source.bytes).toBeInstanceOf(Uint8Array)
|
||||
expect(result?.source.bytes.length).toBe(5)
|
||||
})
|
||||
|
||||
it('should use fallback format when mime type is not provided', () => {
|
||||
const result = convertBase64ImageToAwsBedrockFormat(validBase64)
|
||||
|
||||
expect(result).not.toBe(null)
|
||||
expect(result?.format).toBe('png') // default fallback
|
||||
expect(result?.source.bytes).toBeInstanceOf(Uint8Array)
|
||||
})
|
||||
|
||||
it('should use custom fallback format', () => {
|
||||
const result = convertBase64ImageToAwsBedrockFormat(validBase64, undefined, 'jpeg')
|
||||
|
||||
expect(result).not.toBe(null)
|
||||
expect(result?.format).toBe('jpeg')
|
||||
expect(result?.source.bytes).toBeInstanceOf(Uint8Array)
|
||||
})
|
||||
|
||||
it('should extract format from mime type when provided', () => {
|
||||
const result = convertBase64ImageToAwsBedrockFormat(validBase64, 'image/webp', 'png')
|
||||
|
||||
expect(result).not.toBe(null)
|
||||
expect(result?.format).toBe('webp') // extracted from mime type, not fallback
|
||||
})
|
||||
|
||||
it('should use fallback format for unsupported mime type', () => {
|
||||
const result = convertBase64ImageToAwsBedrockFormat(validBase64, 'image/bmp')
|
||||
|
||||
expect(result).not.toBe(null)
|
||||
expect(result?.format).toBe('png') // uses fallback format
|
||||
})
|
||||
|
||||
it('should return null for invalid base64 data', () => {
|
||||
const result = convertBase64ImageToAwsBedrockFormat('invalid!@#$%^&*()', 'image/png')
|
||||
|
||||
expect(result).toBe(null)
|
||||
})
|
||||
|
||||
it('should return null for invalid fallback format', () => {
|
||||
// @ts-ignore - testing invalid fallback format
|
||||
const result = convertBase64ImageToAwsBedrockFormat(validBase64, undefined, 'bmp')
|
||||
|
||||
expect(result).toBe(null)
|
||||
})
|
||||
|
||||
it('should handle all supported formats', () => {
|
||||
const formats: AwsBedrockImageFormat[] = ['png', 'jpeg', 'gif', 'webp']
|
||||
|
||||
formats.forEach((format) => {
|
||||
const result = convertBase64ImageToAwsBedrockFormat(validBase64, `image/${format}`)
|
||||
expect(result).not.toBe(null)
|
||||
expect(result?.format).toBe(format)
|
||||
})
|
||||
})
|
||||
|
||||
it('should return proper AwsBedrockImage structure', () => {
|
||||
const result = convertBase64ImageToAwsBedrockFormat(validBase64, 'image/png')
|
||||
|
||||
expect(result).toEqual({
|
||||
format: 'png',
|
||||
source: {
|
||||
bytes: expect.any(Uint8Array)
|
||||
}
|
||||
} as AwsBedrockImage)
|
||||
})
|
||||
|
||||
it('should handle empty base64 string', () => {
|
||||
const result = convertBase64ImageToAwsBedrockFormat('', 'image/png')
|
||||
|
||||
expect(result).not.toBe(null)
|
||||
expect(result?.format).toBe('png')
|
||||
expect(result?.source.bytes).toBeInstanceOf(Uint8Array)
|
||||
expect(result?.source.bytes.length).toBe(0)
|
||||
})
|
||||
})
|
||||
|
||||
describe('isAwsBedrockSupportedImageFormat', () => {
|
||||
it('should return true for supported formats', () => {
|
||||
expect(isAwsBedrockSupportedImageFormat('image/png')).toBe(true)
|
||||
expect(isAwsBedrockSupportedImageFormat('image/jpeg')).toBe(true)
|
||||
expect(isAwsBedrockSupportedImageFormat('image/gif')).toBe(true)
|
||||
expect(isAwsBedrockSupportedImageFormat('image/webp')).toBe(true)
|
||||
})
|
||||
|
||||
it('should return false for unsupported formats', () => {
|
||||
expect(isAwsBedrockSupportedImageFormat('image/bmp')).toBe(false)
|
||||
expect(isAwsBedrockSupportedImageFormat('image/svg+xml')).toBe(false)
|
||||
expect(isAwsBedrockSupportedImageFormat('image/tiff')).toBe(false)
|
||||
})
|
||||
|
||||
it('should return false for non-image mime types', () => {
|
||||
expect(isAwsBedrockSupportedImageFormat('text/plain')).toBe(false)
|
||||
expect(isAwsBedrockSupportedImageFormat('application/json')).toBe(false)
|
||||
expect(isAwsBedrockSupportedImageFormat('video/mp4')).toBe(false)
|
||||
})
|
||||
|
||||
it('should return false for invalid mime types', () => {
|
||||
expect(isAwsBedrockSupportedImageFormat('invalid')).toBe(false)
|
||||
expect(isAwsBedrockSupportedImageFormat('image/')).toBe(false)
|
||||
expect(isAwsBedrockSupportedImageFormat('/bmp')).toBe(false)
|
||||
})
|
||||
|
||||
it('should return false for undefined or empty input', () => {
|
||||
expect(isAwsBedrockSupportedImageFormat(undefined)).toBe(false)
|
||||
expect(isAwsBedrockSupportedImageFormat('')).toBe(false)
|
||||
})
|
||||
|
||||
it('should return false for mime types with additional parameters', () => {
|
||||
expect(isAwsBedrockSupportedImageFormat('image/png; charset=utf-8')).toBe(false)
|
||||
expect(isAwsBedrockSupportedImageFormat('image/jpeg; quality=95')).toBe(false)
|
||||
})
|
||||
})
|
||||
})
|
||||
@@ -0,0 +1,98 @@
|
||||
/**
|
||||
* AWS Bedrock 相关工具函数
|
||||
*/
|
||||
|
||||
/**
|
||||
* 支持的图片格式类型
|
||||
*/
|
||||
export type AwsBedrockImageFormat = 'png' | 'jpeg' | 'gif' | 'webp'
|
||||
|
||||
/**
|
||||
* AWS Bedrock 图片对象格式
|
||||
*/
|
||||
export interface AwsBedrockImage {
|
||||
format: AwsBedrockImageFormat
|
||||
source: {
|
||||
bytes: Uint8Array
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 从 MIME 类型中提取图片格式
|
||||
* @param mimeType MIME 类型,如 'image/png'
|
||||
* @returns 图片格式或 null(如果不支持)
|
||||
*/
|
||||
export function extractImageFormatFromMimeType(mimeType?: string): AwsBedrockImageFormat | null {
|
||||
if (!mimeType) return null
|
||||
|
||||
const format = mimeType.split('/')[1] as AwsBedrockImageFormat
|
||||
|
||||
if (['png', 'jpeg', 'gif', 'webp'].includes(format)) {
|
||||
return format
|
||||
}
|
||||
|
||||
return null
|
||||
}
|
||||
|
||||
/**
|
||||
* 将 base64 字符串转换为 Uint8Array
|
||||
* @param base64Data base64 编码的字符串
|
||||
* @returns Uint8Array
|
||||
* @throws Error 如果 base64 解码失败
|
||||
*/
|
||||
export function base64ToUint8Array(base64Data: string): Uint8Array {
|
||||
try {
|
||||
// 在浏览器环境中正确处理base64转换为Uint8Array
|
||||
const binaryString = atob(base64Data)
|
||||
const bytes = new Uint8Array(binaryString.length)
|
||||
for (let i = 0; i < binaryString.length; i++) {
|
||||
bytes[i] = binaryString.charCodeAt(i)
|
||||
}
|
||||
return bytes
|
||||
} catch (error) {
|
||||
throw new Error(`Failed to decode base64 data: ${error instanceof Error ? error.message : 'Unknown error'}`)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 将 base64 图片数据转换为 AWS Bedrock 格式
|
||||
* @param data base64 编码的图片数据
|
||||
* @param mimeType 图片的 MIME 类型
|
||||
* @param fallbackFormat 当无法从 mimeType 中提取格式时的默认格式
|
||||
* @returns AWS Bedrock 格式的图片对象,如果格式不支持则返回 null
|
||||
*/
|
||||
export function convertBase64ImageToAwsBedrockFormat(
|
||||
data: string,
|
||||
mimeType?: string,
|
||||
fallbackFormat: AwsBedrockImageFormat = 'png'
|
||||
): AwsBedrockImage | null {
|
||||
const format = extractImageFormatFromMimeType(mimeType) || fallbackFormat
|
||||
|
||||
// 验证格式是否支持
|
||||
if (!['png', 'jpeg', 'gif', 'webp'].includes(format)) {
|
||||
return null
|
||||
}
|
||||
|
||||
try {
|
||||
const bytes = base64ToUint8Array(data)
|
||||
|
||||
return {
|
||||
format,
|
||||
source: {
|
||||
bytes
|
||||
}
|
||||
}
|
||||
} catch (error) {
|
||||
// 如果转换失败,返回 null
|
||||
return null
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 检查给定的 MIME 类型是否为 AWS Bedrock 支持的图片格式
|
||||
* @param mimeType MIME 类型
|
||||
* @returns 是否支持
|
||||
*/
|
||||
export function isAwsBedrockSupportedImageFormat(mimeType?: string): boolean {
|
||||
return extractImageFormatFromMimeType(mimeType) !== null
|
||||
}
|
||||
@@ -17,6 +17,7 @@ import {
|
||||
} from '@renderer/types'
|
||||
import type { MCPToolCompleteChunk, MCPToolInProgressChunk, MCPToolPendingChunk } from '@renderer/types/chunk'
|
||||
import { ChunkType } from '@renderer/types/chunk'
|
||||
import { AwsBedrockSdkMessageParam, AwsBedrockSdkTool, AwsBedrockSdkToolCall } from '@renderer/types/sdk'
|
||||
import { isArray, isObject, pull, transform } from 'lodash'
|
||||
import { nanoid } from 'nanoid'
|
||||
import OpenAI from 'openai'
|
||||
@@ -27,6 +28,8 @@ import {
|
||||
ChatCompletionTool
|
||||
} from 'openai/resources'
|
||||
|
||||
import { convertBase64ImageToAwsBedrockFormat } from './aws-bedrock-utils'
|
||||
|
||||
const logger = loggerService.withContext('Utils:MCPTools')
|
||||
|
||||
const MCP_AUTO_INSTALL_SERVER_NAME = '@cherry/mcp-auto-install'
|
||||
@@ -533,7 +536,7 @@ export function parseToolUse(content: string, mcpTools: MCPTool[], startIdx: num
|
||||
parsedArgs = toolArgs
|
||||
}
|
||||
// Logger.log(`Parsed arguments for tool "${toolName}":`, parsedArgs)
|
||||
const mcpTool = mcpTools.find((tool) => tool.id === toolName)
|
||||
const mcpTool = mcpTools.find((tool) => tool.id === toolName || tool.name === toolName)
|
||||
if (!mcpTool) {
|
||||
logger.error(`Tool "${toolName}" not found in MCP tools`)
|
||||
window.message.error(i18n.t('settings.mcp.errors.toolNotFound', { name: toolName }))
|
||||
@@ -835,6 +838,163 @@ export function mcpToolCallResponseToGeminiMessage(
|
||||
return message
|
||||
}
|
||||
|
||||
export function mcpToolsToAwsBedrockTools(mcpTools: MCPTool[]): Array<AwsBedrockSdkTool> {
|
||||
return mcpTools.map((tool) => ({
|
||||
toolSpec: {
|
||||
name: tool.id,
|
||||
description: tool.description,
|
||||
inputSchema: {
|
||||
json: {
|
||||
type: 'object',
|
||||
properties: tool.inputSchema?.properties
|
||||
? Object.fromEntries(
|
||||
Object.entries(tool.inputSchema.properties).map(([key, value]) => [
|
||||
key,
|
||||
{
|
||||
type:
|
||||
typeof value === 'object' && value !== null && 'type' in value ? (value as any).type : 'string',
|
||||
description:
|
||||
typeof value === 'object' && value !== null && 'description' in value
|
||||
? (value as any).description
|
||||
: undefined
|
||||
}
|
||||
])
|
||||
)
|
||||
: {},
|
||||
required: tool.inputSchema?.required || []
|
||||
}
|
||||
}
|
||||
}
|
||||
}))
|
||||
}
|
||||
|
||||
export function awsBedrockToolUseToMcpTool(
|
||||
mcpTools: MCPTool[] | undefined,
|
||||
toolCall: AwsBedrockSdkToolCall
|
||||
): MCPTool | undefined {
|
||||
if (!toolCall) return undefined
|
||||
if (!mcpTools) return undefined
|
||||
const tool = mcpTools.find((tool) => tool.id === toolCall.name || tool.name === toolCall.name)
|
||||
if (!tool) {
|
||||
return undefined
|
||||
}
|
||||
return tool
|
||||
}
|
||||
|
||||
export function mcpToolCallResponseToAwsBedrockMessage(
|
||||
mcpToolResponse: MCPToolResponse,
|
||||
resp: MCPCallToolResponse,
|
||||
model: Model
|
||||
): AwsBedrockSdkMessageParam {
|
||||
const message: AwsBedrockSdkMessageParam = {
|
||||
role: 'user',
|
||||
content: []
|
||||
}
|
||||
|
||||
const toolUseId =
|
||||
'toolUseId' in mcpToolResponse && mcpToolResponse.toolUseId
|
||||
? mcpToolResponse.toolUseId
|
||||
: 'toolCallId' in mcpToolResponse && mcpToolResponse.toolCallId
|
||||
? mcpToolResponse.toolCallId
|
||||
: 'unknown-tool-id'
|
||||
|
||||
if (resp.isError) {
|
||||
message.content = [
|
||||
{
|
||||
toolResult: {
|
||||
toolUseId: toolUseId,
|
||||
content: [
|
||||
{
|
||||
text: `Error: ${JSON.stringify(resp.content)}`
|
||||
}
|
||||
],
|
||||
status: 'error'
|
||||
}
|
||||
}
|
||||
]
|
||||
} else {
|
||||
const toolResultContent: Array<{
|
||||
json?: any
|
||||
text?: string
|
||||
image?: {
|
||||
format: 'png' | 'jpeg' | 'gif' | 'webp'
|
||||
source: {
|
||||
bytes?: Uint8Array
|
||||
s3Location?: {
|
||||
uri: string
|
||||
bucketOwner?: string
|
||||
}
|
||||
}
|
||||
}
|
||||
}> = []
|
||||
|
||||
if (isVisionModel(model)) {
|
||||
for (const item of resp.content) {
|
||||
switch (item.type) {
|
||||
case 'text':
|
||||
toolResultContent.push({
|
||||
text: item.text || 'no content'
|
||||
})
|
||||
break
|
||||
case 'image':
|
||||
if (item.data && item.mimeType) {
|
||||
const awsImage = convertBase64ImageToAwsBedrockFormat(item.data, item.mimeType)
|
||||
if (awsImage) {
|
||||
toolResultContent.push({ image: awsImage })
|
||||
} else {
|
||||
toolResultContent.push({
|
||||
text: `[Image received: ${item.mimeType}, size: ${item.data?.length || 0} bytes]`
|
||||
})
|
||||
}
|
||||
} else {
|
||||
toolResultContent.push({
|
||||
text: '[Image received but no data available]'
|
||||
})
|
||||
}
|
||||
break
|
||||
default:
|
||||
toolResultContent.push({
|
||||
text: `Unsupported content type: ${item.type}`
|
||||
})
|
||||
break
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// 对于非视觉模型,将所有内容合并为文本
|
||||
const textContent = resp.content
|
||||
.map((item) => {
|
||||
if (item.type === 'text') {
|
||||
return item.text
|
||||
} else {
|
||||
// 对于非文本内容,尝试转换为JSON格式
|
||||
try {
|
||||
return JSON.stringify(item)
|
||||
} catch {
|
||||
return `[${item.type} content]`
|
||||
}
|
||||
}
|
||||
})
|
||||
.join('\n')
|
||||
|
||||
toolResultContent.push({
|
||||
text: textContent || 'Tool execution completed with no output'
|
||||
})
|
||||
}
|
||||
|
||||
message.content = [
|
||||
{
|
||||
toolResult: {
|
||||
toolUseId: toolUseId,
|
||||
content: toolResultContent,
|
||||
status: 'success'
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
return message
|
||||
}
|
||||
|
||||
export function isEnabledToolUse(assistant: Assistant) {
|
||||
if (assistant.model) {
|
||||
if (isFunctionCallingModel(assistant.model)) {
|
||||
|
||||
Reference in New Issue
Block a user