diff --git a/package.json b/package.json index 9b82e5e69..a26cb1647 100644 --- a/package.json +++ b/package.json @@ -139,7 +139,7 @@ "@langchain/ollama": "^0.2.1", "@langchain/openai": "^0.6.7", "@mistralai/mistralai": "^1.7.5", - "@modelcontextprotocol/sdk": "^1.17.0", + "@modelcontextprotocol/sdk": "^1.17.5", "@mozilla/readability": "^0.6.0", "@notionhq/client": "^2.2.15", "@openrouter/ai-sdk-provider": "^1.1.2", diff --git a/packages/shared/utils.ts b/packages/shared/utils.ts new file mode 100644 index 000000000..4101a65c0 --- /dev/null +++ b/packages/shared/utils.ts @@ -0,0 +1,7 @@ +export const defaultAppHeaders = () => { + return { + 'HTTP-Referer': 'https://cherry-ai.com', + Referer: 'https://cherry-ai.com', + 'X-Title': 'Cherry Studio' + } +} diff --git a/src/main/services/MCPService.ts b/src/main/services/MCPService.ts index ecb34500c..710867da8 100644 --- a/src/main/services/MCPService.ts +++ b/src/main/services/MCPService.ts @@ -16,6 +16,7 @@ import { type StreamableHTTPClientTransportOptions } from '@modelcontextprotocol/sdk/client/streamableHttp' import { InMemoryTransport } from '@modelcontextprotocol/sdk/inMemory' +import { McpError, type Tool as SDKTool } from '@modelcontextprotocol/sdk/types' // Import notification schemas from MCP SDK import { CancelledNotificationSchema, @@ -29,6 +30,7 @@ import { import { nanoid } from '@reduxjs/toolkit' import { MCPProgressEvent } from '@shared/config/types' import { IpcChannel } from '@shared/IpcChannel' +import { defaultAppHeaders } from '@shared/utils' import { BuiltinMCPServerNames, type GetResourceResponse, @@ -94,7 +96,7 @@ function getServerLogger(server: MCPServer, extra?: Record) { baseUrl: server?.baseUrl, type: server?.type || (server?.command ? 'stdio' : server?.baseUrl ? 'http' : 'inmemory') } - return loggerService.withContext('MCPService', { ...base, ...(extra || {}) }) + return loggerService.withContext('MCPService', { ...base, ...extra }) } /** @@ -193,11 +195,18 @@ class McpService { return existingClient } } catch (error: any) { - getServerLogger(server).error(`Error pinging server`, error as Error) + getServerLogger(server).error(`Error pinging server ${server.name}`, error as Error) this.clients.delete(serverKey) } } + const prepareHeaders = () => { + return { + ...defaultAppHeaders(), + ...server.headers + } + } + // Create a promise for the initialization process const initPromise = (async () => { try { @@ -235,8 +244,11 @@ class McpService { } else if (server.baseUrl) { if (server.type === 'streamableHttp') { const options: StreamableHTTPClientTransportOptions = { + fetch: async (url, init) => { + return net.fetch(typeof url === 'string' ? url : url.toString(), init) + }, requestInit: { - headers: server.headers || {} + headers: prepareHeaders() }, authProvider } @@ -249,25 +261,11 @@ class McpService { const options: SSEClientTransportOptions = { eventSourceInit: { fetch: async (url, init) => { - const headers = { ...(server.headers || {}), ...(init?.headers || {}) } - - // Get tokens from authProvider to make sure using the latest tokens - if (authProvider && typeof authProvider.tokens === 'function') { - try { - const tokens = await authProvider.tokens() - if (tokens && tokens.access_token) { - headers['Authorization'] = `Bearer ${tokens.access_token}` - } - } catch (error) { - getServerLogger(server).error('Failed to fetch tokens:', error as Error) - } - } - - return net.fetch(typeof url === 'string' ? url : url.toString(), { ...init, headers }) + return net.fetch(typeof url === 'string' ? url : url.toString(), init) } }, requestInit: { - headers: server.headers || {} + headers: prepareHeaders() }, authProvider } @@ -444,9 +442,9 @@ class McpService { logger.debug(`Activated server: ${server.name}`) return client - } catch (error: any) { - getServerLogger(server).error(`Error activating server`, error as Error) - throw new Error(`[MCP] Error activating server ${server.name}: ${error.message}`) + } catch (error) { + getServerLogger(server).error(`Error activating server ${server.name}`, error as Error) + throw error } } finally { // Clean up the pending promise when done @@ -614,12 +612,11 @@ class McpService { } private async listToolsImpl(server: MCPServer): Promise { - getServerLogger(server).debug(`Listing tools`) const client = await this.initClient(server) try { const { tools } = await client.listTools() const serverTools: MCPTool[] = [] - tools.map((tool: any) => { + tools.map((tool: SDKTool) => { const serverTool: MCPTool = { ...tool, id: buildFunctionCallToolName(server.name, tool.name), @@ -628,11 +625,12 @@ class McpService { type: 'mcp' } serverTools.push(serverTool) + getServerLogger(server).debug(`Listing tools`, { tool: serverTool }) }) return serverTools - } catch (error: any) { + } catch (error: unknown) { getServerLogger(server).error(`Failed to list tools`, error as Error) - return [] + throw error } } @@ -739,9 +737,9 @@ class McpService { serverId: server.id, serverName: server.name })) - } catch (error: any) { + } catch (error: unknown) { // -32601 is the code for the method not found - if (error?.code !== -32601) { + if (error instanceof McpError && error.code !== -32601) { getServerLogger(server).error(`Failed to list prompts`, error as Error) } return [] diff --git a/src/renderer/src/aiCore/legacy/clients/BaseApiClient.ts b/src/renderer/src/aiCore/legacy/clients/BaseApiClient.ts index bc146ae68..430b9e4df 100644 --- a/src/renderer/src/aiCore/legacy/clients/BaseApiClient.ts +++ b/src/renderer/src/aiCore/legacy/clients/BaseApiClient.ts @@ -46,6 +46,7 @@ import { isJSON, parseJSON } from '@renderer/utils' import { addAbortController, removeAbortController } from '@renderer/utils/abortController' import { findFileBlocks, getMainTextContent } from '@renderer/utils/messageUtils/find' import { defaultTimeout } from '@shared/config/constant' +import { defaultAppHeaders } from '@shared/utils' import { isEmpty } from 'lodash' import { CompletionsContext } from '../middleware/types' @@ -179,8 +180,7 @@ export abstract class BaseApiClient< public defaultHeaders() { return { - 'HTTP-Referer': 'https://cherry-ai.com', - 'X-Title': 'Cherry Studio', + ...defaultAppHeaders(), 'X-Api-Key': this.apiKey } } diff --git a/src/renderer/src/pages/settings/MCPSettings/EditMcpJsonPopup.tsx b/src/renderer/src/pages/settings/MCPSettings/EditMcpJsonPopup.tsx index 37e5aaa4f..48df875fc 100644 --- a/src/renderer/src/pages/settings/MCPSettings/EditMcpJsonPopup.tsx +++ b/src/renderer/src/pages/settings/MCPSettings/EditMcpJsonPopup.tsx @@ -5,7 +5,7 @@ import { useAppDispatch, useAppSelector } from '@renderer/store' import { setMCPServers } from '@renderer/store/mcp' import { MCPServer, safeValidateMcpConfig } from '@renderer/types' import { parseJSON } from '@renderer/utils' -import { formatZodError } from '@renderer/utils/error' +import { formatErrorMessage, formatZodError } from '@renderer/utils/error' import { Modal, Spin, Typography } from 'antd' import { useEffect, useState } from 'react' import { useTranslation } from 'react-i18next' @@ -80,11 +80,8 @@ const PopupContainer: React.FC = ({ resolve }) => { const server: MCPServer = { id, isActive: false, - ...(serverConfig as any) - } - - if (!server.name) { - server.name = id + name: serverConfig.name || id, + ...serverConfig } serversArray.push(server) @@ -95,9 +92,8 @@ const PopupContainer: React.FC = ({ resolve }) => { window.toast.success(t('settings.mcp.jsonSaveSuccess')) setJsonError('') setOpen(false) - } catch (error: any) { - logger.error('Failed to save JSON config:', error) - setJsonError(error.message || t('settings.mcp.jsonSaveError')) + } catch (error: unknown) { + setJsonError(formatErrorMessage(error) || t('settings.mcp.jsonSaveError')) window.toast.error(t('settings.mcp.jsonSaveError')) } finally { setJsonSaving(false) diff --git a/src/renderer/src/pages/settings/MCPSettings/McpSettings.tsx b/src/renderer/src/pages/settings/MCPSettings/McpSettings.tsx index db7484ac2..70488d5b4 100644 --- a/src/renderer/src/pages/settings/MCPSettings/McpSettings.tsx +++ b/src/renderer/src/pages/settings/MCPSettings/McpSettings.tsx @@ -1,4 +1,5 @@ import { loggerService } from '@logger' +import type { McpError } from '@modelcontextprotocol/sdk/types.js' import { DeleteIcon } from '@renderer/components/Icons' import { useTheme } from '@renderer/context/ThemeProvider' import { useMCPServer, useMCPServers } from '@renderer/hooks/useMCPServers' @@ -424,7 +425,7 @@ const McpSettings: React.FC = () => { } catch (error: any) { window.modal.error({ title: t('settings.mcp.startError'), - content: formatMcpError(error), + content: formatMcpError(error as McpError), centered: true }) updateMCPServer({ ...server, isActive: oldActiveState }) diff --git a/src/renderer/src/pages/settings/MCPSettings/providers/modelscope.ts b/src/renderer/src/pages/settings/MCPSettings/providers/modelscope.ts index 0b1a9b585..36fd4b0c0 100644 --- a/src/renderer/src/pages/settings/MCPSettings/providers/modelscope.ts +++ b/src/renderer/src/pages/settings/MCPSettings/providers/modelscope.ts @@ -1,6 +1,6 @@ import { loggerService } from '@logger' import { nanoid } from '@reduxjs/toolkit' -import type { MCPServer } from '@renderer/types' +import { getMcpServerType, type MCPServer } from '@renderer/types' import i18next from 'i18next' const logger = loggerService.withContext('ModelScopeSyncUtils') @@ -104,13 +104,13 @@ export const syncModelScopeServers = async ( // Check if server already exists const existingServer = existingServers.find((s) => s.id === `@modelscope/${server.id}`) - + const url = server.operational_urls[0].url const mcpServer: MCPServer = { id: `@modelscope/${server.id}`, name: server.chinese_name || server.name || `ModelScope Server ${nanoid()}`, description: server.description || '', - type: 'sse', - baseUrl: server.operational_urls[0].url, + type: getMcpServerType(url), + baseUrl: url, command: '', args: [], env: {}, diff --git a/src/renderer/src/types/index.ts b/src/renderer/src/types/index.ts index 40b7d4e13..a5fba1040 100644 --- a/src/renderer/src/types/index.ts +++ b/src/renderer/src/types/index.ts @@ -2,7 +2,6 @@ import type { WebSearchResultBlock } from '@anthropic-ai/sdk/resources' import type { GenerateImagesConfig, GroundingMetadata, PersonGeneration } from '@google/genai' import type OpenAI from 'openai' import type { CSSProperties } from 'react' -import { z } from 'zod' export * from './file' export * from './note' @@ -836,20 +835,6 @@ export const isBuiltinMCPServerName = (name: string): name is BuiltinMCPServerNa return BuiltinMCPServerNamesArray.some((n) => n === name) } -export interface MCPToolInputSchema { - type: string - title: string - description?: string - required?: string[] - properties: Record -} - -export const MCPToolOutputSchema = z.object({ - type: z.literal('object'), - properties: z.record(z.string(), z.unknown()), - required: z.array(z.string()) -}) - export interface MCPPromptArguments { name: string description?: string diff --git a/src/renderer/src/types/mcp.ts b/src/renderer/src/types/mcp.ts index 23f04e882..c447d697e 100644 --- a/src/renderer/src/types/mcp.ts +++ b/src/renderer/src/types/mcp.ts @@ -1,4 +1,4 @@ -import { z } from 'zod' +import * as z from 'zod' import { isBuiltinMCPServerName } from '.' @@ -187,18 +187,13 @@ export const McpServerConfigSchema = z // 显式传入的type会覆盖掉从url推断的逻辑 if (!schema.type) { const url = schema.baseUrl ?? schema.url ?? null + // NOTE: url 暗示了服务器的类型为 streamableHttp 或 sse,未来可能会扩展其他类型 if (url !== null) { - if (url.endsWith('/mcp')) { - return { - ...schema, - type: 'streamableHttp' - } as const - } else if (url.endsWith('/sse')) { - return { - ...schema, - type: 'sse' - } as const - } + const type = getMcpServerType(url) + return { + ...schema, + type + } as const } } return schema @@ -254,3 +249,14 @@ export function safeValidateMcpConfig(config: unknown) { export function safeValidateMcpServerConfig(config: unknown) { return McpServerConfigSchema.safeParse(config) } + +/** + * 根据给定的URL判断MCP服务器的类型。 + * 如果URL以 "/mcp" 结尾,则类型为 "streamableHttp",否则为 "sse"。 + * + * @param url - 服务器的URL地址 + * @returns MCP服务器类型('streamableHttp' 或 'sse') + */ +export function getMcpServerType(url: string): McpServerType { + return url.endsWith('/mcp') ? 'streamableHttp' : 'sse' +} diff --git a/src/renderer/src/types/tool.ts b/src/renderer/src/types/tool.ts index d0567a8a6..ad6f2727e 100644 --- a/src/renderer/src/types/tool.ts +++ b/src/renderer/src/types/tool.ts @@ -19,22 +19,24 @@ export interface BaseTool { // providerExecuted?: boolean // 标识是Provider端执行还是客户端执行 // } -export const MCPToolOutputSchema = z.object({ - type: z.literal('object'), - properties: z.record(z.string(), z.unknown()), - required: z.array(z.string()) -}) +export const MCPToolOutputSchema = z + .object({ + type: z.literal('object'), + properties: z.object({}).loose().optional(), + required: z.array(z.string()).optional() + }) + .loose() -export interface MCPToolInputSchema { - type: string - title: string - description?: string - required?: string[] - properties: Record -} +export const MCPToolInputSchema = z + .object({ + type: z.literal('object'), + properties: z.object({}).loose().optional(), + required: z.array(z.string()).optional() + }) + .loose() export interface BuiltinTool extends BaseTool { - inputSchema: MCPToolInputSchema + inputSchema: z.infer type: 'builtin' } @@ -44,7 +46,7 @@ export interface MCPTool extends BaseTool { serverName: string name: string description?: string - inputSchema: MCPToolInputSchema + inputSchema: z.infer outputSchema?: z.infer isBuiltIn?: boolean // 标识是否为内置工具,内置工具不需要通过MCP协议调用 type: 'mcp' diff --git a/src/renderer/src/utils/error.ts b/src/renderer/src/utils/error.ts index 1a2f13a85..024a6713b 100644 --- a/src/renderer/src/utils/error.ts +++ b/src/renderer/src/utils/error.ts @@ -1,3 +1,4 @@ +import { McpError } from '@modelcontextprotocol/sdk/types.js' import { AiSdkErrorUnion, isSerializedAiSdkAPICallError, @@ -41,25 +42,17 @@ export function getErrorDetails(err: any, seen = new WeakSet()): any { return result } -export function formatErrorMessage(error: any): string { - try { - const detailedError = getErrorDetails(error) - delete detailedError?.headers - delete detailedError?.stack - delete detailedError?.request_id +export function formatErrorMessage(error: unknown): string { + const detailedError = getErrorDetails(error) + delete detailedError?.headers + delete detailedError?.stack + delete detailedError?.request_id - const formattedJson = JSON.stringify(detailedError, null, 2) - .split('\n') - .map((line) => ` ${line}`) - .join('\n') - return `Error Details:\n${formattedJson}` - } catch (e) { - try { - return `Error: ${String(error)}` - } catch { - return 'Error: Unable to format error message' - } - } + const formattedJson = JSON.stringify(detailedError, null, 2) + .split('\n') + .map((line) => ` ${line}`) + .join('\n') + return `Error Details:\n${formattedJson}` } export const isAbortError = (error: any): boolean => { @@ -89,10 +82,8 @@ export const isAbortError = (error: any): boolean => { return false } -export const formatMcpError = (error: any) => { - if (error.message.includes('32000')) { - return t('settings.mcp.errors.32000') - } +// TODO: format +export const formatMcpError = (error: McpError) => { return error.message } diff --git a/yarn.lock b/yarn.lock index a1e4ad40e..79a814375 100644 --- a/yarn.lock +++ b/yarn.lock @@ -6654,9 +6654,9 @@ __metadata: languageName: node linkType: hard -"@modelcontextprotocol/sdk@npm:^1.17.0": - version: 1.17.0 - resolution: "@modelcontextprotocol/sdk@npm:1.17.0" +"@modelcontextprotocol/sdk@npm:^1.17.5": + version: 1.17.5 + resolution: "@modelcontextprotocol/sdk@npm:1.17.5" dependencies: ajv: "npm:^6.12.6" content-type: "npm:^1.0.5" @@ -6670,7 +6670,7 @@ __metadata: raw-body: "npm:^3.0.0" zod: "npm:^3.23.8" zod-to-json-schema: "npm:^3.24.1" - checksum: 10c0/ac497edeb05a434bf8092475e4354ec602644b0197735d3bcd809ee1922f2078ab71e7d8d9dbe1c42765978fa3f2f807df01a2a3ad421c986f0b2207c3a40a68 + checksum: 10c0/182b92b5e7c07da428fd23c6de22021c4f9a91f799c02a8ef15def07e4f9361d0fc22303548658fec2a700623535fd44a9dc4d010fb5d803a8f80e3c6c64a45e languageName: node linkType: hard @@ -13067,7 +13067,7 @@ __metadata: "@libsql/client": "npm:0.14.0" "@libsql/win32-x64-msvc": "npm:^0.4.7" "@mistralai/mistralai": "npm:^1.7.5" - "@modelcontextprotocol/sdk": "npm:^1.17.0" + "@modelcontextprotocol/sdk": "npm:^1.17.5" "@mozilla/readability": "npm:^0.6.0" "@napi-rs/system-ocr": "patch:@napi-rs/system-ocr@npm%3A1.0.2#~/.yarn/patches/@napi-rs-system-ocr-npm-1.0.2-59e7a78e8b.patch" "@notionhq/client": "npm:^2.2.15"