feat: update package dependencies and introduce new patches for AI SDK tools
- Added patches for `@ai-sdk/google-vertex` and `@ai-sdk/openai-compatible` to enhance functionality and fix issues. - Updated `package.json` to reflect new dependency versions and patch paths. - Refactored `transformParameters` and `ApiService` to support new tool configurations and improve parameter handling. - Introduced utility functions for setting up tools and managing options, enhancing the overall integration of tools within the AI SDK.
This commit is contained in:
@@ -162,16 +162,8 @@ export default class ModernAiProvider {
|
||||
}
|
||||
|
||||
try {
|
||||
// 合并传入的配置和实例配置
|
||||
const finalConfig: AiSdkMiddlewareConfig = {
|
||||
...middlewareConfig,
|
||||
provider: this.provider,
|
||||
// 工具相关信息从 params 中获取
|
||||
enableTool: !!Object.keys(params.tools || {}).length
|
||||
}
|
||||
|
||||
// 动态构建中间件数组
|
||||
const middlewares = buildAiSdkMiddlewares(finalConfig)
|
||||
const middlewares = buildAiSdkMiddlewares(middlewareConfig)
|
||||
console.log('构建的中间件:', middlewares)
|
||||
|
||||
// 创建带有中间件的执行器
|
||||
@@ -179,27 +171,29 @@ export default class ModernAiProvider {
|
||||
// 流式处理 - 使用适配器
|
||||
const adapter = new AiSdkToChunkAdapter(middlewareConfig.onChunk)
|
||||
// 创建MCP Prompt插件
|
||||
const mcpPromptPlugin = createMCPPromptPlugin({
|
||||
enabled: true,
|
||||
createSystemMessage: (systemPrompt, params, context) => {
|
||||
console.log('createSystemMessage_context', context.isRecursiveCall)
|
||||
if (context.modelId.includes('o1-mini') || context.modelId.includes('o1-preview')) {
|
||||
if (context.isRecursiveCall) {
|
||||
if (middlewareConfig.enableTool) {
|
||||
const mcpPromptPlugin = createMCPPromptPlugin({
|
||||
enabled: true,
|
||||
createSystemMessage: (systemPrompt, params, context) => {
|
||||
console.log('createSystemMessage_context', context.isRecursiveCall)
|
||||
if (context.modelId.includes('o1-mini') || context.modelId.includes('o1-preview')) {
|
||||
if (context.isRecursiveCall) {
|
||||
return null
|
||||
}
|
||||
params.messages = [
|
||||
{
|
||||
role: 'assistant',
|
||||
content: systemPrompt
|
||||
},
|
||||
...params.messages
|
||||
]
|
||||
return null
|
||||
}
|
||||
params.messages = [
|
||||
{
|
||||
role: 'assistant',
|
||||
content: systemPrompt
|
||||
},
|
||||
...params.messages
|
||||
]
|
||||
return null
|
||||
return systemPrompt
|
||||
}
|
||||
return systemPrompt
|
||||
}
|
||||
})
|
||||
this.modernExecutor.pluginEngine.use(mcpPromptPlugin)
|
||||
})
|
||||
this.modernExecutor.pluginEngine.use(mcpPromptPlugin)
|
||||
}
|
||||
const streamResult = await this.modernExecutor.streamText(
|
||||
modelId,
|
||||
params,
|
||||
|
||||
@@ -18,6 +18,7 @@ export interface AiSdkMiddlewareConfig {
|
||||
model?: Model
|
||||
provider?: Provider
|
||||
enableReasoning?: boolean
|
||||
// 是否开启提示词工具调用
|
||||
enableTool?: boolean
|
||||
enableWebSearch?: boolean
|
||||
mcpTools?: MCPTool[]
|
||||
|
||||
@@ -39,9 +39,8 @@ export default function thinkingTimeMiddleware(): LanguageModelV1Middleware {
|
||||
hasThinkingContent = false
|
||||
thinkingStartTime = 0
|
||||
accumulatedThinkingContent = ''
|
||||
} else {
|
||||
controller.enqueue(chunk)
|
||||
}
|
||||
controller.enqueue(chunk)
|
||||
}
|
||||
},
|
||||
flush(controller) {
|
||||
|
||||
@@ -11,18 +11,18 @@ const PROVIDER_MAPPING: Record<string, ProviderId> = {
|
||||
}
|
||||
|
||||
export function getAiSdkProviderId(provider: Provider): ProviderId | 'openai-compatible' {
|
||||
const providerType = PROVIDER_MAPPING[provider.type] // 有些第三方需要映射到aicore对应sdk
|
||||
|
||||
if (providerType) {
|
||||
return providerType
|
||||
}
|
||||
|
||||
const providerId = PROVIDER_MAPPING[provider.id]
|
||||
|
||||
if (providerId) {
|
||||
return providerId
|
||||
}
|
||||
|
||||
const providerType = PROVIDER_MAPPING[provider.type] // 有些第三方需要映射到aicore对应sdk
|
||||
|
||||
if (providerType) {
|
||||
return providerType
|
||||
}
|
||||
|
||||
if (AiCore.isSupported(provider.id)) {
|
||||
return provider.id as ProviderId as ProviderId
|
||||
}
|
||||
|
||||
@@ -4,7 +4,6 @@
|
||||
*/
|
||||
|
||||
import { type CoreMessage, type StreamTextParams } from '@cherrystudio/ai-core'
|
||||
import { aiSdk } from '@cherrystudio/ai-core'
|
||||
import { DEFAULT_MAX_TOKENS } from '@renderer/config/constant'
|
||||
import {
|
||||
isGenerateImageModel,
|
||||
@@ -18,18 +17,16 @@ import {
|
||||
isWebSearchModel
|
||||
} from '@renderer/config/models'
|
||||
import { getAssistantSettings, getDefaultModel } from '@renderer/services/AssistantService'
|
||||
import type { Assistant, MCPTool, MCPToolInputSchema, MCPToolResponse, Message, Model } from '@renderer/types'
|
||||
import type { Assistant, MCPTool, Message, Model } from '@renderer/types'
|
||||
import { FileTypes } from '@renderer/types'
|
||||
import { callMCPTool } from '@renderer/utils/mcp-tools'
|
||||
import { findFileBlocks, findImageBlocks, getMainTextContent } from '@renderer/utils/messageUtils/find'
|
||||
import { buildSystemPrompt } from '@renderer/utils/prompt'
|
||||
import { defaultTimeout } from '@shared/config/constant'
|
||||
|
||||
// import { jsonSchemaToZod } from 'json-schema-to-zod'
|
||||
import { jsonSchema } from 'ai'
|
||||
import { setupToolsConfig } from './utils/mcp'
|
||||
import { buildProviderOptions } from './utils/options'
|
||||
|
||||
import { buildProviderOptions } from './utils/reasoning'
|
||||
|
||||
const { tool } = aiSdk
|
||||
/**
|
||||
* 获取温度参数
|
||||
*/
|
||||
@@ -190,7 +187,6 @@ export async function buildStreamTextParams(
|
||||
assistant: Assistant,
|
||||
options: {
|
||||
mcpTools?: MCPTool[]
|
||||
// FIXME: 上游没传
|
||||
enableTools?: boolean
|
||||
requestOptions?: {
|
||||
signal?: AbortSignal
|
||||
@@ -199,7 +195,7 @@ export async function buildStreamTextParams(
|
||||
}
|
||||
} = {}
|
||||
): Promise<{ params: StreamTextParams; modelId: string }> {
|
||||
const { mcpTools } = options
|
||||
const { mcpTools, enableTools } = options
|
||||
|
||||
const model = assistant.model || getDefaultModel()
|
||||
|
||||
@@ -221,11 +217,11 @@ export async function buildStreamTextParams(
|
||||
(isSupportedDisableGenerationModel(model) ? assistant.enableGenerateImage || false : true)
|
||||
|
||||
// 构建系统提示
|
||||
let systemPrompt = assistant.prompt || ''
|
||||
// TODO:根据调用类型判断是否添加systemPrompt
|
||||
if (mcpTools && mcpTools.length > 0 && assistant.settings?.toolUseMode === 'prompt') {
|
||||
systemPrompt = await buildSystemPromptWithTools(systemPrompt, mcpTools, assistant)
|
||||
}
|
||||
const { tools } = setupToolsConfig({
|
||||
mcpTools,
|
||||
model,
|
||||
enableToolUse: enableTools
|
||||
})
|
||||
|
||||
// 构建真正的 providerOptions
|
||||
const providerOptions = buildProviderOptions(assistant, model, {
|
||||
@@ -240,22 +236,14 @@ export async function buildStreamTextParams(
|
||||
maxTokens: maxTokens || DEFAULT_MAX_TOKENS,
|
||||
temperature: getTemperature(assistant, model),
|
||||
topP: getTopP(assistant, model),
|
||||
system: systemPrompt || undefined,
|
||||
system: assistant.prompt || '',
|
||||
abortSignal: options.requestOptions?.signal,
|
||||
headers: options.requestOptions?.headers,
|
||||
providerOptions,
|
||||
tools,
|
||||
maxSteps: 10
|
||||
}
|
||||
|
||||
const tools = mcpTools ? convertMcpToolsToAiSdkTools(mcpTools) : {}
|
||||
console.log('tools', tools)
|
||||
console.log('enableTools', assistant?.mcpServers?.length)
|
||||
// console.log('tools.length > 0', tools.length > 0)
|
||||
// 添加工具(如果启用且有工具)
|
||||
if (!!assistant?.mcpServers?.length && Object.keys(tools).length > 0) {
|
||||
params.tools = tools
|
||||
}
|
||||
|
||||
return { params, modelId: model.id }
|
||||
}
|
||||
|
||||
@@ -273,76 +261,3 @@ export async function buildGenerateTextParams(
|
||||
// 复用流式参数的构建逻辑
|
||||
return await buildStreamTextParams(messages, assistant, options)
|
||||
}
|
||||
|
||||
/**
|
||||
* 将 MCPToolInputSchema 转换为 JSONSchema7 格式
|
||||
*/
|
||||
function convertMcpSchemaToJsonSchema7(schema: MCPToolInputSchema): any {
|
||||
// 创建符合 JSONSchema7 的对象
|
||||
const jsonSchema7: Record<string, any> = {
|
||||
type: 'object',
|
||||
properties: schema.properties || {},
|
||||
required: schema.required || []
|
||||
}
|
||||
|
||||
// 如果有 description,添加它
|
||||
if (schema.description) {
|
||||
jsonSchema7.description = schema.description
|
||||
}
|
||||
|
||||
// 如果有 title,添加它
|
||||
if (schema.title) {
|
||||
jsonSchema7.title = schema.title
|
||||
}
|
||||
|
||||
return jsonSchema7
|
||||
}
|
||||
|
||||
/**
|
||||
* 将 MCPTool 转换为 AI SDK 工具格式
|
||||
*/
|
||||
export function convertMcpToolsToAiSdkTools(mcpTools: MCPTool[]): Record<string, any> {
|
||||
const tools: Record<string, any> = {}
|
||||
|
||||
for (const mcpTool of mcpTools) {
|
||||
console.log('mcpTool', mcpTool.inputSchema)
|
||||
tools[mcpTool.name] = tool({
|
||||
description: mcpTool.description || `Tool from ${mcpTool.serverName}`,
|
||||
parameters: jsonSchema<Record<string, object>>(convertMcpSchemaToJsonSchema7(mcpTool.inputSchema)),
|
||||
execute: async (params) => {
|
||||
console.log('execute_params', params)
|
||||
// 创建适配的 MCPToolResponse 对象
|
||||
const toolResponse: MCPToolResponse = {
|
||||
id: `tool_${Date.now()}_${Math.random().toString(36).substr(2, 9)}`,
|
||||
tool: mcpTool,
|
||||
arguments: params,
|
||||
status: 'invoking',
|
||||
toolCallId: `call_${Date.now()}_${Math.random().toString(36).substr(2, 9)}`
|
||||
}
|
||||
|
||||
try {
|
||||
// 复用现有的 callMCPTool 函数
|
||||
const result = await callMCPTool(toolResponse)
|
||||
|
||||
// 返回结果,AI SDK 会处理序列化
|
||||
if (result.isError) {
|
||||
throw new Error(result.content?.[0]?.text || 'Tool execution failed')
|
||||
}
|
||||
console.log('result', result)
|
||||
// 返回工具执行结果
|
||||
return {
|
||||
success: true,
|
||||
data: result
|
||||
}
|
||||
} catch (error) {
|
||||
console.error(`MCP Tool execution failed: ${mcpTool.name}`, error)
|
||||
throw new Error(
|
||||
`Tool ${mcpTool.name} execution failed: ${error instanceof Error ? error.message : String(error)}`
|
||||
)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
return tools
|
||||
}
|
||||
|
||||
@@ -0,0 +1,88 @@
|
||||
import { aiSdk, Tool } from '@cherrystudio/ai-core'
|
||||
import { SYSTEM_PROMPT_THRESHOLD } from '@renderer/config/constant'
|
||||
import { isFunctionCallingModel } from '@renderer/config/models'
|
||||
import { MCPCallToolResponse, MCPTool, MCPToolResponse, Model } from '@renderer/types'
|
||||
import { callMCPTool } from '@renderer/utils/mcp-tools'
|
||||
import { JSONSchema7 } from 'json-schema'
|
||||
|
||||
type ToolCallResult = {
|
||||
success: boolean
|
||||
data: MCPCallToolResponse
|
||||
}
|
||||
|
||||
type AiSdkTool = Tool<any, ToolCallResult>
|
||||
|
||||
// Setup tools configuration based on provided parameters
|
||||
export function setupToolsConfig(params: { mcpTools?: MCPTool[]; model: Model; enableToolUse?: boolean }): {
|
||||
tools: Record<string, AiSdkTool>
|
||||
useSystemPromptForTools?: boolean
|
||||
} {
|
||||
const { mcpTools, model, enableToolUse } = params
|
||||
|
||||
let tools: Record<string, AiSdkTool> = {}
|
||||
|
||||
if (!mcpTools?.length) {
|
||||
return { tools }
|
||||
}
|
||||
|
||||
tools = convertMcpToolsToAiSdkTools(mcpTools)
|
||||
|
||||
if (mcpTools.length > SYSTEM_PROMPT_THRESHOLD) {
|
||||
return { tools, useSystemPromptForTools: true }
|
||||
}
|
||||
|
||||
if (isFunctionCallingModel(model) && enableToolUse) {
|
||||
return { tools, useSystemPromptForTools: false }
|
||||
}
|
||||
|
||||
return { tools, useSystemPromptForTools: true }
|
||||
}
|
||||
|
||||
/**
|
||||
* 将 MCPTool 转换为 AI SDK 工具格式
|
||||
*/
|
||||
export function convertMcpToolsToAiSdkTools(mcpTools: MCPTool[]): Record<string, Tool<any, ToolCallResult>> {
|
||||
const tools: Record<string, Tool<any, ToolCallResult>> = {}
|
||||
|
||||
for (const mcpTool of mcpTools) {
|
||||
console.log('mcpTool', mcpTool.inputSchema)
|
||||
tools[mcpTool.name] = aiSdk.tool<any, ToolCallResult>({
|
||||
description: mcpTool.description || `Tool from ${mcpTool.serverName}`,
|
||||
parameters: aiSdk.jsonSchema(mcpTool.inputSchema as JSONSchema7),
|
||||
execute: async (params): Promise<ToolCallResult> => {
|
||||
console.log('execute_params', params)
|
||||
// 创建适配的 MCPToolResponse 对象
|
||||
const toolResponse: MCPToolResponse = {
|
||||
id: `tool_${Date.now()}_${Math.random().toString(36).substr(2, 9)}`,
|
||||
tool: mcpTool,
|
||||
arguments: params,
|
||||
status: 'invoking',
|
||||
toolCallId: `call_${Date.now()}_${Math.random().toString(36).substr(2, 9)}`
|
||||
}
|
||||
|
||||
try {
|
||||
// 复用现有的 callMCPTool 函数
|
||||
const result = await callMCPTool(toolResponse)
|
||||
|
||||
// 返回结果,AI SDK 会处理序列化
|
||||
if (result.isError) {
|
||||
throw new Error(result.content?.[0]?.text || 'Tool execution failed')
|
||||
}
|
||||
console.log('result', result)
|
||||
// 返回工具执行结果
|
||||
return {
|
||||
success: true,
|
||||
data: result
|
||||
}
|
||||
} catch (error) {
|
||||
console.error(`MCP Tool execution failed: ${mcpTool.name}`, error)
|
||||
throw new Error(
|
||||
`Tool ${mcpTool.name} execution failed: ${error instanceof Error ? error.message : String(error)}`
|
||||
)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
return tools
|
||||
}
|
||||
@@ -0,0 +1,203 @@
|
||||
import { getProviderByModel } from '@renderer/services/AssistantService'
|
||||
import { Assistant, Model } from '@renderer/types'
|
||||
|
||||
import { getAiSdkProviderId } from '../provider/factory'
|
||||
import {
|
||||
getAnthropicReasoningParams,
|
||||
getCustomParameters,
|
||||
getGeminiReasoningParams,
|
||||
getOpenAIReasoningParams,
|
||||
getReasoningEffort
|
||||
} from './reasoning'
|
||||
|
||||
/**
|
||||
* 构建 AI SDK 的 providerOptions
|
||||
* 按 provider 类型分离,保持类型安全
|
||||
* 返回格式:{ 'providerId': providerOptions }
|
||||
*/
|
||||
export function buildProviderOptions(
|
||||
assistant: Assistant,
|
||||
model: Model,
|
||||
capabilities: {
|
||||
enableReasoning: boolean
|
||||
enableWebSearch: boolean
|
||||
enableGenerateImage: boolean
|
||||
}
|
||||
): Record<string, any> {
|
||||
const provider = getProviderByModel(model)
|
||||
const providerId = getAiSdkProviderId(provider)
|
||||
|
||||
// 构建 provider 特定的选项
|
||||
let providerSpecificOptions: Record<string, any> = {}
|
||||
|
||||
// 根据 provider 类型分离构建逻辑
|
||||
switch (provider.type) {
|
||||
case 'openai-response':
|
||||
case 'azure-openai':
|
||||
providerSpecificOptions = buildOpenAIProviderOptions(assistant, model, capabilities)
|
||||
break
|
||||
|
||||
case 'anthropic':
|
||||
providerSpecificOptions = buildAnthropicProviderOptions(assistant, model, capabilities)
|
||||
break
|
||||
|
||||
case 'gemini':
|
||||
case 'vertexai':
|
||||
providerSpecificOptions = buildGeminiProviderOptions(assistant, model, capabilities)
|
||||
break
|
||||
|
||||
default:
|
||||
// 对于其他 provider,使用通用的构建逻辑
|
||||
providerSpecificOptions = buildGenericProviderOptions(assistant, model, capabilities)
|
||||
break
|
||||
}
|
||||
|
||||
// 合并自定义参数到 provider 特定的选项中
|
||||
providerSpecificOptions = {
|
||||
...providerSpecificOptions,
|
||||
...getCustomParameters(assistant)
|
||||
}
|
||||
|
||||
// 返回 AI Core SDK 要求的格式:{ 'providerId': providerOptions }
|
||||
return {
|
||||
[providerId]: providerSpecificOptions
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 构建 OpenAI 特定的 providerOptions
|
||||
*/
|
||||
function buildOpenAIProviderOptions(
|
||||
assistant: Assistant,
|
||||
model: Model,
|
||||
capabilities: {
|
||||
enableReasoning: boolean
|
||||
enableWebSearch: boolean
|
||||
enableGenerateImage: boolean
|
||||
}
|
||||
): Record<string, any> {
|
||||
const { enableReasoning, enableWebSearch, enableGenerateImage } = capabilities
|
||||
let providerOptions: Record<string, any> = {}
|
||||
|
||||
// OpenAI 推理参数
|
||||
if (enableReasoning) {
|
||||
const reasoningParams = getOpenAIReasoningParams(assistant, model)
|
||||
providerOptions = {
|
||||
...providerOptions,
|
||||
...reasoningParams
|
||||
}
|
||||
}
|
||||
|
||||
// Web 搜索和图像生成暂时使用通用格式
|
||||
if (enableWebSearch) {
|
||||
providerOptions.webSearch = { enabled: true }
|
||||
}
|
||||
|
||||
if (enableGenerateImage) {
|
||||
providerOptions.generateImage = { enabled: true }
|
||||
}
|
||||
|
||||
return providerOptions
|
||||
}
|
||||
|
||||
/**
|
||||
* 构建 Anthropic 特定的 providerOptions
|
||||
*/
|
||||
function buildAnthropicProviderOptions(
|
||||
assistant: Assistant,
|
||||
model: Model,
|
||||
capabilities: {
|
||||
enableReasoning: boolean
|
||||
enableWebSearch: boolean
|
||||
enableGenerateImage: boolean
|
||||
}
|
||||
): Record<string, any> {
|
||||
const { enableReasoning, enableWebSearch, enableGenerateImage } = capabilities
|
||||
let providerOptions: Record<string, any> = {}
|
||||
|
||||
// Anthropic 推理参数
|
||||
if (enableReasoning) {
|
||||
const reasoningParams = getAnthropicReasoningParams(assistant, model)
|
||||
providerOptions = {
|
||||
...providerOptions,
|
||||
...reasoningParams
|
||||
}
|
||||
}
|
||||
|
||||
if (enableWebSearch) {
|
||||
providerOptions.webSearch = { enabled: true }
|
||||
}
|
||||
|
||||
if (enableGenerateImage) {
|
||||
providerOptions.generateImage = { enabled: true }
|
||||
}
|
||||
|
||||
return providerOptions
|
||||
}
|
||||
|
||||
/**
|
||||
* 构建 Gemini 特定的 providerOptions
|
||||
*/
|
||||
function buildGeminiProviderOptions(
|
||||
assistant: Assistant,
|
||||
model: Model,
|
||||
capabilities: {
|
||||
enableReasoning: boolean
|
||||
enableWebSearch: boolean
|
||||
enableGenerateImage: boolean
|
||||
}
|
||||
): Record<string, any> {
|
||||
const { enableReasoning, enableWebSearch, enableGenerateImage } = capabilities
|
||||
const providerOptions: Record<string, any> = {}
|
||||
|
||||
// Gemini 推理参数
|
||||
if (enableReasoning) {
|
||||
const reasoningParams = getGeminiReasoningParams(assistant, model)
|
||||
Object.assign(providerOptions, reasoningParams)
|
||||
}
|
||||
|
||||
if (enableWebSearch) {
|
||||
providerOptions.webSearch = { enabled: true }
|
||||
}
|
||||
|
||||
if (enableGenerateImage) {
|
||||
providerOptions.generateImage = { enabled: true }
|
||||
}
|
||||
|
||||
return providerOptions
|
||||
}
|
||||
|
||||
/**
|
||||
* 构建通用的 providerOptions(用于其他 provider)
|
||||
*/
|
||||
function buildGenericProviderOptions(
|
||||
assistant: Assistant,
|
||||
model: Model,
|
||||
capabilities: {
|
||||
enableReasoning: boolean
|
||||
enableWebSearch: boolean
|
||||
enableGenerateImage: boolean
|
||||
}
|
||||
): Record<string, any> {
|
||||
const { enableReasoning, enableWebSearch, enableGenerateImage } = capabilities
|
||||
let providerOptions: Record<string, any> = {}
|
||||
|
||||
// 使用原有的通用推理逻辑
|
||||
if (enableReasoning) {
|
||||
const reasoningParams = getReasoningEffort(assistant, model)
|
||||
providerOptions = {
|
||||
...providerOptions,
|
||||
...reasoningParams
|
||||
}
|
||||
}
|
||||
|
||||
if (enableWebSearch) {
|
||||
providerOptions.webSearch = { enabled: true }
|
||||
}
|
||||
|
||||
if (enableGenerateImage) {
|
||||
providerOptions.generateImage = { enabled: true }
|
||||
}
|
||||
|
||||
return providerOptions
|
||||
}
|
||||
@@ -17,8 +17,6 @@ import { getAssistantSettings, getProviderByModel } from '@renderer/services/Ass
|
||||
import { Assistant, EFFORT_RATIO, Model } from '@renderer/types'
|
||||
import { ReasoningEffortOptionalParams } from '@renderer/types/sdk'
|
||||
|
||||
import { getAiSdkProviderId } from '../provider/factory'
|
||||
|
||||
export function getReasoningEffort(assistant: Assistant, model: Model): ReasoningEffortOptionalParams {
|
||||
const provider = getProviderByModel(model)
|
||||
if (provider.id === 'groq') {
|
||||
@@ -138,192 +136,11 @@ export function getReasoningEffort(assistant: Assistant, model: Model): Reasonin
|
||||
return {}
|
||||
}
|
||||
|
||||
/**
|
||||
* 构建 AI SDK 的 providerOptions
|
||||
* 按 provider 类型分离,保持类型安全
|
||||
* 返回格式:{ 'providerId': providerOptions }
|
||||
*/
|
||||
export function buildProviderOptions(
|
||||
assistant: Assistant,
|
||||
model: Model,
|
||||
capabilities: {
|
||||
enableReasoning: boolean
|
||||
enableWebSearch: boolean
|
||||
enableGenerateImage: boolean
|
||||
}
|
||||
): Record<string, any> {
|
||||
const provider = getProviderByModel(model)
|
||||
const providerId = getAiSdkProviderId(provider)
|
||||
|
||||
// 构建 provider 特定的选项
|
||||
let providerSpecificOptions: Record<string, any> = {}
|
||||
|
||||
// 根据 provider 类型分离构建逻辑
|
||||
switch (provider.type) {
|
||||
case 'openai':
|
||||
case 'azure-openai':
|
||||
providerSpecificOptions = buildOpenAIProviderOptions(assistant, model, capabilities)
|
||||
break
|
||||
|
||||
case 'anthropic':
|
||||
providerSpecificOptions = buildAnthropicProviderOptions(assistant, model, capabilities)
|
||||
break
|
||||
|
||||
case 'gemini':
|
||||
case 'vertexai':
|
||||
providerSpecificOptions = buildGeminiProviderOptions(assistant, model, capabilities)
|
||||
break
|
||||
|
||||
default:
|
||||
// 对于其他 provider,使用通用的构建逻辑
|
||||
providerSpecificOptions = buildGenericProviderOptions(assistant, model, capabilities)
|
||||
break
|
||||
}
|
||||
|
||||
// 合并自定义参数到 provider 特定的选项中
|
||||
const customParameters = getCustomParameters(assistant)
|
||||
Object.assign(providerSpecificOptions, customParameters)
|
||||
|
||||
// 返回 AI Core SDK 要求的格式:{ 'providerId': providerOptions }
|
||||
return {
|
||||
[providerId]: providerSpecificOptions
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 构建 OpenAI 特定的 providerOptions
|
||||
*/
|
||||
function buildOpenAIProviderOptions(
|
||||
assistant: Assistant,
|
||||
model: Model,
|
||||
capabilities: {
|
||||
enableReasoning: boolean
|
||||
enableWebSearch: boolean
|
||||
enableGenerateImage: boolean
|
||||
}
|
||||
): Record<string, any> {
|
||||
const { enableReasoning, enableWebSearch, enableGenerateImage } = capabilities
|
||||
const providerOptions: Record<string, any> = {}
|
||||
|
||||
// OpenAI 推理参数
|
||||
if (enableReasoning) {
|
||||
const reasoningParams = getOpenAIReasoningParams(assistant, model)
|
||||
Object.assign(providerOptions, reasoningParams)
|
||||
}
|
||||
|
||||
// Web 搜索和图像生成暂时使用通用格式
|
||||
if (enableWebSearch) {
|
||||
providerOptions.webSearch = { enabled: true }
|
||||
}
|
||||
|
||||
if (enableGenerateImage) {
|
||||
providerOptions.generateImage = { enabled: true }
|
||||
}
|
||||
|
||||
return providerOptions
|
||||
}
|
||||
|
||||
/**
|
||||
* 构建 Anthropic 特定的 providerOptions
|
||||
*/
|
||||
function buildAnthropicProviderOptions(
|
||||
assistant: Assistant,
|
||||
model: Model,
|
||||
capabilities: {
|
||||
enableReasoning: boolean
|
||||
enableWebSearch: boolean
|
||||
enableGenerateImage: boolean
|
||||
}
|
||||
): Record<string, any> {
|
||||
const { enableReasoning, enableWebSearch, enableGenerateImage } = capabilities
|
||||
const providerOptions: Record<string, any> = {}
|
||||
|
||||
// Anthropic 推理参数
|
||||
if (enableReasoning) {
|
||||
const reasoningParams = getAnthropicReasoningParams(assistant, model)
|
||||
Object.assign(providerOptions, reasoningParams)
|
||||
}
|
||||
|
||||
if (enableWebSearch) {
|
||||
providerOptions.webSearch = { enabled: true }
|
||||
}
|
||||
|
||||
if (enableGenerateImage) {
|
||||
providerOptions.generateImage = { enabled: true }
|
||||
}
|
||||
|
||||
return providerOptions
|
||||
}
|
||||
|
||||
/**
|
||||
* 构建 Gemini 特定的 providerOptions
|
||||
*/
|
||||
function buildGeminiProviderOptions(
|
||||
assistant: Assistant,
|
||||
model: Model,
|
||||
capabilities: {
|
||||
enableReasoning: boolean
|
||||
enableWebSearch: boolean
|
||||
enableGenerateImage: boolean
|
||||
}
|
||||
): Record<string, any> {
|
||||
const { enableReasoning, enableWebSearch, enableGenerateImage } = capabilities
|
||||
const providerOptions: Record<string, any> = {}
|
||||
|
||||
// Gemini 推理参数
|
||||
if (enableReasoning) {
|
||||
const reasoningParams = getGeminiReasoningParams(assistant, model)
|
||||
Object.assign(providerOptions, reasoningParams)
|
||||
}
|
||||
|
||||
if (enableWebSearch) {
|
||||
providerOptions.webSearch = { enabled: true }
|
||||
}
|
||||
|
||||
if (enableGenerateImage) {
|
||||
providerOptions.generateImage = { enabled: true }
|
||||
}
|
||||
|
||||
return providerOptions
|
||||
}
|
||||
|
||||
/**
|
||||
* 构建通用的 providerOptions(用于其他 provider)
|
||||
*/
|
||||
function buildGenericProviderOptions(
|
||||
assistant: Assistant,
|
||||
model: Model,
|
||||
capabilities: {
|
||||
enableReasoning: boolean
|
||||
enableWebSearch: boolean
|
||||
enableGenerateImage: boolean
|
||||
}
|
||||
): Record<string, any> {
|
||||
const { enableReasoning, enableWebSearch, enableGenerateImage } = capabilities
|
||||
const providerOptions: Record<string, any> = {}
|
||||
|
||||
// 使用原有的通用推理逻辑
|
||||
if (enableReasoning) {
|
||||
const reasoningParams = getReasoningEffort(assistant, model)
|
||||
Object.assign(providerOptions, reasoningParams)
|
||||
}
|
||||
|
||||
if (enableWebSearch) {
|
||||
providerOptions.webSearch = { enabled: true }
|
||||
}
|
||||
|
||||
if (enableGenerateImage) {
|
||||
providerOptions.generateImage = { enabled: true }
|
||||
}
|
||||
|
||||
return providerOptions
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取 OpenAI 推理参数
|
||||
* 从 OpenAIResponseAPIClient 和 OpenAIAPIClient 中提取的逻辑
|
||||
*/
|
||||
function getOpenAIReasoningParams(assistant: Assistant, model: Model): Record<string, any> {
|
||||
export function getOpenAIReasoningParams(assistant: Assistant, model: Model): Record<string, any> {
|
||||
if (!isReasoningModel(model)) {
|
||||
return {}
|
||||
}
|
||||
@@ -348,7 +165,7 @@ function getOpenAIReasoningParams(assistant: Assistant, model: Model): Record<st
|
||||
* 获取 Anthropic 推理参数
|
||||
* 从 AnthropicAPIClient 中提取的逻辑
|
||||
*/
|
||||
function getAnthropicReasoningParams(assistant: Assistant, model: Model): Record<string, any> {
|
||||
export function getAnthropicReasoningParams(assistant: Assistant, model: Model): Record<string, any> {
|
||||
if (!isReasoningModel(model)) {
|
||||
return {}
|
||||
}
|
||||
@@ -394,7 +211,7 @@ function getAnthropicReasoningParams(assistant: Assistant, model: Model): Record
|
||||
* 获取 Gemini 推理参数
|
||||
* 从 GeminiAPIClient 中提取的逻辑
|
||||
*/
|
||||
function getGeminiReasoningParams(assistant: Assistant, model: Model): Record<string, any> {
|
||||
export function getGeminiReasoningParams(assistant: Assistant, model: Model): Record<string, any> {
|
||||
if (!isReasoningModel(model)) {
|
||||
return {}
|
||||
}
|
||||
@@ -440,7 +257,7 @@ function getGeminiReasoningParams(assistant: Assistant, model: Model): Record<st
|
||||
* 获取自定义参数
|
||||
* 从 assistant 设置中提取自定义参数
|
||||
*/
|
||||
function getCustomParameters(assistant: Assistant): Record<string, any> {
|
||||
export function getCustomParameters(assistant: Assistant): Record<string, any> {
|
||||
return (
|
||||
assistant?.settings?.customParameters?.reduce((acc, param) => {
|
||||
if (!param.name?.trim()) {
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
export const DEFAULT_TEMPERATURE = 1.0
|
||||
export const DEFAULT_CONTEXTCOUNT = 5
|
||||
export const DEFAULT_MAX_TOKENS = 4096
|
||||
export const SYSTEM_PROMPT_THRESHOLD = 128
|
||||
export const DEFAULT_KNOWLEDGE_DOCUMENT_COUNT = 6
|
||||
export const DEFAULT_KNOWLEDGE_THRESHOLD = 0.0
|
||||
|
||||
|
||||
@@ -20,6 +20,7 @@ import { type Chunk, ChunkType } from '@renderer/types/chunk'
|
||||
import { Message } from '@renderer/types/newMessage'
|
||||
import { SdkModel } from '@renderer/types/sdk'
|
||||
import { removeSpecialCharactersForTopicName } from '@renderer/utils'
|
||||
import { isEnabledToolUse } from '@renderer/utils/mcp-tools'
|
||||
import { findFileBlocks, getMainTextContent } from '@renderer/utils/messageUtils/find'
|
||||
import { isEmpty, takeRight } from 'lodash'
|
||||
|
||||
@@ -302,6 +303,7 @@ export async function fetchChatCompletion({
|
||||
// 使用 transformParameters 模块构建参数
|
||||
const { params: aiSdkParams, modelId } = await buildStreamTextParams(messages, assistant, {
|
||||
mcpTools: mcpTools,
|
||||
enableTools: isEnabledToolUse(assistant),
|
||||
requestOptions: options
|
||||
})
|
||||
|
||||
@@ -311,6 +313,7 @@ export async function fetchChatCompletion({
|
||||
model: assistant.model,
|
||||
provider: provider,
|
||||
enableReasoning: assistant.settings?.reasoning_effort !== undefined,
|
||||
enableTool: assistant.settings?.toolUseMode === 'prompt',
|
||||
mcpTools
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user