refactor: MCPService for improved error handling and header management (#10100)

* refactor MCPService for improved error handling and header management

* refactor MCPService: reorder header preparation for improved clarity

* refactor: enhance MCP server type determination and clean up error handling
This commit is contained in:
SuYao
2025-09-11 20:52:13 +08:00
committed by GitHub
parent 95a332f38a
commit 6104b7803b
12 changed files with 99 additions and 113 deletions

View File

@@ -139,7 +139,7 @@
"@langchain/ollama": "^0.2.1",
"@langchain/openai": "^0.6.7",
"@mistralai/mistralai": "^1.7.5",
"@modelcontextprotocol/sdk": "^1.17.0",
"@modelcontextprotocol/sdk": "^1.17.5",
"@mozilla/readability": "^0.6.0",
"@notionhq/client": "^2.2.15",
"@openrouter/ai-sdk-provider": "^1.1.2",

7
packages/shared/utils.ts Normal file
View File

@@ -0,0 +1,7 @@
export const defaultAppHeaders = () => {
return {
'HTTP-Referer': 'https://cherry-ai.com',
Referer: 'https://cherry-ai.com',
'X-Title': 'Cherry Studio'
}
}

View File

@@ -16,6 +16,7 @@ import {
type StreamableHTTPClientTransportOptions
} from '@modelcontextprotocol/sdk/client/streamableHttp'
import { InMemoryTransport } from '@modelcontextprotocol/sdk/inMemory'
import { McpError, type Tool as SDKTool } from '@modelcontextprotocol/sdk/types'
// Import notification schemas from MCP SDK
import {
CancelledNotificationSchema,
@@ -29,6 +30,7 @@ import {
import { nanoid } from '@reduxjs/toolkit'
import { MCPProgressEvent } from '@shared/config/types'
import { IpcChannel } from '@shared/IpcChannel'
import { defaultAppHeaders } from '@shared/utils'
import {
BuiltinMCPServerNames,
type GetResourceResponse,
@@ -94,7 +96,7 @@ function getServerLogger(server: MCPServer, extra?: Record<string, any>) {
baseUrl: server?.baseUrl,
type: server?.type || (server?.command ? 'stdio' : server?.baseUrl ? 'http' : 'inmemory')
}
return loggerService.withContext('MCPService', { ...base, ...(extra || {}) })
return loggerService.withContext('MCPService', { ...base, ...extra })
}
/**
@@ -193,11 +195,18 @@ class McpService {
return existingClient
}
} catch (error: any) {
getServerLogger(server).error(`Error pinging server`, error as Error)
getServerLogger(server).error(`Error pinging server ${server.name}`, error as Error)
this.clients.delete(serverKey)
}
}
const prepareHeaders = () => {
return {
...defaultAppHeaders(),
...server.headers
}
}
// Create a promise for the initialization process
const initPromise = (async () => {
try {
@@ -235,8 +244,11 @@ class McpService {
} else if (server.baseUrl) {
if (server.type === 'streamableHttp') {
const options: StreamableHTTPClientTransportOptions = {
fetch: async (url, init) => {
return net.fetch(typeof url === 'string' ? url : url.toString(), init)
},
requestInit: {
headers: server.headers || {}
headers: prepareHeaders()
},
authProvider
}
@@ -249,25 +261,11 @@ class McpService {
const options: SSEClientTransportOptions = {
eventSourceInit: {
fetch: async (url, init) => {
const headers = { ...(server.headers || {}), ...(init?.headers || {}) }
// Get tokens from authProvider to make sure using the latest tokens
if (authProvider && typeof authProvider.tokens === 'function') {
try {
const tokens = await authProvider.tokens()
if (tokens && tokens.access_token) {
headers['Authorization'] = `Bearer ${tokens.access_token}`
}
} catch (error) {
getServerLogger(server).error('Failed to fetch tokens:', error as Error)
}
}
return net.fetch(typeof url === 'string' ? url : url.toString(), { ...init, headers })
return net.fetch(typeof url === 'string' ? url : url.toString(), init)
}
},
requestInit: {
headers: server.headers || {}
headers: prepareHeaders()
},
authProvider
}
@@ -444,9 +442,9 @@ class McpService {
logger.debug(`Activated server: ${server.name}`)
return client
} catch (error: any) {
getServerLogger(server).error(`Error activating server`, error as Error)
throw new Error(`[MCP] Error activating server ${server.name}: ${error.message}`)
} catch (error) {
getServerLogger(server).error(`Error activating server ${server.name}`, error as Error)
throw error
}
} finally {
// Clean up the pending promise when done
@@ -614,12 +612,11 @@ class McpService {
}
private async listToolsImpl(server: MCPServer): Promise<MCPTool[]> {
getServerLogger(server).debug(`Listing tools`)
const client = await this.initClient(server)
try {
const { tools } = await client.listTools()
const serverTools: MCPTool[] = []
tools.map((tool: any) => {
tools.map((tool: SDKTool) => {
const serverTool: MCPTool = {
...tool,
id: buildFunctionCallToolName(server.name, tool.name),
@@ -628,11 +625,12 @@ class McpService {
type: 'mcp'
}
serverTools.push(serverTool)
getServerLogger(server).debug(`Listing tools`, { tool: serverTool })
})
return serverTools
} catch (error: any) {
} catch (error: unknown) {
getServerLogger(server).error(`Failed to list tools`, error as Error)
return []
throw error
}
}
@@ -739,9 +737,9 @@ class McpService {
serverId: server.id,
serverName: server.name
}))
} catch (error: any) {
} catch (error: unknown) {
// -32601 is the code for the method not found
if (error?.code !== -32601) {
if (error instanceof McpError && error.code !== -32601) {
getServerLogger(server).error(`Failed to list prompts`, error as Error)
}
return []

View File

@@ -46,6 +46,7 @@ import { isJSON, parseJSON } from '@renderer/utils'
import { addAbortController, removeAbortController } from '@renderer/utils/abortController'
import { findFileBlocks, getMainTextContent } from '@renderer/utils/messageUtils/find'
import { defaultTimeout } from '@shared/config/constant'
import { defaultAppHeaders } from '@shared/utils'
import { isEmpty } from 'lodash'
import { CompletionsContext } from '../middleware/types'
@@ -179,8 +180,7 @@ export abstract class BaseApiClient<
public defaultHeaders() {
return {
'HTTP-Referer': 'https://cherry-ai.com',
'X-Title': 'Cherry Studio',
...defaultAppHeaders(),
'X-Api-Key': this.apiKey
}
}

View File

@@ -5,7 +5,7 @@ import { useAppDispatch, useAppSelector } from '@renderer/store'
import { setMCPServers } from '@renderer/store/mcp'
import { MCPServer, safeValidateMcpConfig } from '@renderer/types'
import { parseJSON } from '@renderer/utils'
import { formatZodError } from '@renderer/utils/error'
import { formatErrorMessage, formatZodError } from '@renderer/utils/error'
import { Modal, Spin, Typography } from 'antd'
import { useEffect, useState } from 'react'
import { useTranslation } from 'react-i18next'
@@ -80,11 +80,8 @@ const PopupContainer: React.FC<Props> = ({ resolve }) => {
const server: MCPServer = {
id,
isActive: false,
...(serverConfig as any)
}
if (!server.name) {
server.name = id
name: serverConfig.name || id,
...serverConfig
}
serversArray.push(server)
@@ -95,9 +92,8 @@ const PopupContainer: React.FC<Props> = ({ resolve }) => {
window.toast.success(t('settings.mcp.jsonSaveSuccess'))
setJsonError('')
setOpen(false)
} catch (error: any) {
logger.error('Failed to save JSON config:', error)
setJsonError(error.message || t('settings.mcp.jsonSaveError'))
} catch (error: unknown) {
setJsonError(formatErrorMessage(error) || t('settings.mcp.jsonSaveError'))
window.toast.error(t('settings.mcp.jsonSaveError'))
} finally {
setJsonSaving(false)

View File

@@ -1,4 +1,5 @@
import { loggerService } from '@logger'
import type { McpError } from '@modelcontextprotocol/sdk/types.js'
import { DeleteIcon } from '@renderer/components/Icons'
import { useTheme } from '@renderer/context/ThemeProvider'
import { useMCPServer, useMCPServers } from '@renderer/hooks/useMCPServers'
@@ -424,7 +425,7 @@ const McpSettings: React.FC = () => {
} catch (error: any) {
window.modal.error({
title: t('settings.mcp.startError'),
content: formatMcpError(error),
content: formatMcpError(error as McpError),
centered: true
})
updateMCPServer({ ...server, isActive: oldActiveState })

View File

@@ -1,6 +1,6 @@
import { loggerService } from '@logger'
import { nanoid } from '@reduxjs/toolkit'
import type { MCPServer } from '@renderer/types'
import { getMcpServerType, type MCPServer } from '@renderer/types'
import i18next from 'i18next'
const logger = loggerService.withContext('ModelScopeSyncUtils')
@@ -104,13 +104,13 @@ export const syncModelScopeServers = async (
// Check if server already exists
const existingServer = existingServers.find((s) => s.id === `@modelscope/${server.id}`)
const url = server.operational_urls[0].url
const mcpServer: MCPServer = {
id: `@modelscope/${server.id}`,
name: server.chinese_name || server.name || `ModelScope Server ${nanoid()}`,
description: server.description || '',
type: 'sse',
baseUrl: server.operational_urls[0].url,
type: getMcpServerType(url),
baseUrl: url,
command: '',
args: [],
env: {},

View File

@@ -2,7 +2,6 @@ import type { WebSearchResultBlock } from '@anthropic-ai/sdk/resources'
import type { GenerateImagesConfig, GroundingMetadata, PersonGeneration } from '@google/genai'
import type OpenAI from 'openai'
import type { CSSProperties } from 'react'
import { z } from 'zod'
export * from './file'
export * from './note'
@@ -836,20 +835,6 @@ export const isBuiltinMCPServerName = (name: string): name is BuiltinMCPServerNa
return BuiltinMCPServerNamesArray.some((n) => n === name)
}
export interface MCPToolInputSchema {
type: string
title: string
description?: string
required?: string[]
properties: Record<string, object>
}
export const MCPToolOutputSchema = z.object({
type: z.literal('object'),
properties: z.record(z.string(), z.unknown()),
required: z.array(z.string())
})
export interface MCPPromptArguments {
name: string
description?: string

View File

@@ -1,4 +1,4 @@
import { z } from 'zod'
import * as z from 'zod'
import { isBuiltinMCPServerName } from '.'
@@ -187,18 +187,13 @@ export const McpServerConfigSchema = z
// 显式传入的type会覆盖掉从url推断的逻辑
if (!schema.type) {
const url = schema.baseUrl ?? schema.url ?? null
// NOTE: url 暗示了服务器的类型为 streamableHttp 或 sse未来可能会扩展其他类型
if (url !== null) {
if (url.endsWith('/mcp')) {
return {
...schema,
type: 'streamableHttp'
} as const
} else if (url.endsWith('/sse')) {
return {
...schema,
type: 'sse'
} as const
}
const type = getMcpServerType(url)
return {
...schema,
type
} as const
}
}
return schema
@@ -254,3 +249,14 @@ export function safeValidateMcpConfig(config: unknown) {
export function safeValidateMcpServerConfig(config: unknown) {
return McpServerConfigSchema.safeParse(config)
}
/**
* 根据给定的URL判断MCP服务器的类型。
* 如果URL以 "/mcp" 结尾,则类型为 "streamableHttp",否则为 "sse"。
*
* @param url - 服务器的URL地址
* @returns MCP服务器类型'streamableHttp' 或 'sse'
*/
export function getMcpServerType(url: string): McpServerType {
return url.endsWith('/mcp') ? 'streamableHttp' : 'sse'
}

View File

@@ -19,22 +19,24 @@ export interface BaseTool {
// providerExecuted?: boolean // 标识是Provider端执行还是客户端执行
// }
export const MCPToolOutputSchema = z.object({
type: z.literal('object'),
properties: z.record(z.string(), z.unknown()),
required: z.array(z.string())
})
export const MCPToolOutputSchema = z
.object({
type: z.literal('object'),
properties: z.object({}).loose().optional(),
required: z.array(z.string()).optional()
})
.loose()
export interface MCPToolInputSchema {
type: string
title: string
description?: string
required?: string[]
properties: Record<string, object>
}
export const MCPToolInputSchema = z
.object({
type: z.literal('object'),
properties: z.object({}).loose().optional(),
required: z.array(z.string()).optional()
})
.loose()
export interface BuiltinTool extends BaseTool {
inputSchema: MCPToolInputSchema
inputSchema: z.infer<typeof MCPToolInputSchema>
type: 'builtin'
}
@@ -44,7 +46,7 @@ export interface MCPTool extends BaseTool {
serverName: string
name: string
description?: string
inputSchema: MCPToolInputSchema
inputSchema: z.infer<typeof MCPToolInputSchema>
outputSchema?: z.infer<typeof MCPToolOutputSchema>
isBuiltIn?: boolean // 标识是否为内置工具内置工具不需要通过MCP协议调用
type: 'mcp'

View File

@@ -1,3 +1,4 @@
import { McpError } from '@modelcontextprotocol/sdk/types.js'
import {
AiSdkErrorUnion,
isSerializedAiSdkAPICallError,
@@ -41,25 +42,17 @@ export function getErrorDetails(err: any, seen = new WeakSet()): any {
return result
}
export function formatErrorMessage(error: any): string {
try {
const detailedError = getErrorDetails(error)
delete detailedError?.headers
delete detailedError?.stack
delete detailedError?.request_id
export function formatErrorMessage(error: unknown): string {
const detailedError = getErrorDetails(error)
delete detailedError?.headers
delete detailedError?.stack
delete detailedError?.request_id
const formattedJson = JSON.stringify(detailedError, null, 2)
.split('\n')
.map((line) => ` ${line}`)
.join('\n')
return `Error Details:\n${formattedJson}`
} catch (e) {
try {
return `Error: ${String(error)}`
} catch {
return 'Error: Unable to format error message'
}
}
const formattedJson = JSON.stringify(detailedError, null, 2)
.split('\n')
.map((line) => ` ${line}`)
.join('\n')
return `Error Details:\n${formattedJson}`
}
export const isAbortError = (error: any): boolean => {
@@ -89,10 +82,8 @@ export const isAbortError = (error: any): boolean => {
return false
}
export const formatMcpError = (error: any) => {
if (error.message.includes('32000')) {
return t('settings.mcp.errors.32000')
}
// TODO: format
export const formatMcpError = (error: McpError) => {
return error.message
}

View File

@@ -6654,9 +6654,9 @@ __metadata:
languageName: node
linkType: hard
"@modelcontextprotocol/sdk@npm:^1.17.0":
version: 1.17.0
resolution: "@modelcontextprotocol/sdk@npm:1.17.0"
"@modelcontextprotocol/sdk@npm:^1.17.5":
version: 1.17.5
resolution: "@modelcontextprotocol/sdk@npm:1.17.5"
dependencies:
ajv: "npm:^6.12.6"
content-type: "npm:^1.0.5"
@@ -6670,7 +6670,7 @@ __metadata:
raw-body: "npm:^3.0.0"
zod: "npm:^3.23.8"
zod-to-json-schema: "npm:^3.24.1"
checksum: 10c0/ac497edeb05a434bf8092475e4354ec602644b0197735d3bcd809ee1922f2078ab71e7d8d9dbe1c42765978fa3f2f807df01a2a3ad421c986f0b2207c3a40a68
checksum: 10c0/182b92b5e7c07da428fd23c6de22021c4f9a91f799c02a8ef15def07e4f9361d0fc22303548658fec2a700623535fd44a9dc4d010fb5d803a8f80e3c6c64a45e
languageName: node
linkType: hard
@@ -13067,7 +13067,7 @@ __metadata:
"@libsql/client": "npm:0.14.0"
"@libsql/win32-x64-msvc": "npm:^0.4.7"
"@mistralai/mistralai": "npm:^1.7.5"
"@modelcontextprotocol/sdk": "npm:^1.17.0"
"@modelcontextprotocol/sdk": "npm:^1.17.5"
"@mozilla/readability": "npm:^0.6.0"
"@napi-rs/system-ocr": "patch:@napi-rs/system-ocr@npm%3A1.0.2#~/.yarn/patches/@napi-rs-system-ocr-npm-1.0.2-59e7a78e8b.patch"
"@notionhq/client": "npm:^2.2.15"