refactor: MCPService for improved error handling and header management (#10100)
* refactor MCPService for improved error handling and header management * refactor MCPService: reorder header preparation for improved clarity * refactor: enhance MCP server type determination and clean up error handling
This commit is contained in:
@@ -139,7 +139,7 @@
|
||||
"@langchain/ollama": "^0.2.1",
|
||||
"@langchain/openai": "^0.6.7",
|
||||
"@mistralai/mistralai": "^1.7.5",
|
||||
"@modelcontextprotocol/sdk": "^1.17.0",
|
||||
"@modelcontextprotocol/sdk": "^1.17.5",
|
||||
"@mozilla/readability": "^0.6.0",
|
||||
"@notionhq/client": "^2.2.15",
|
||||
"@openrouter/ai-sdk-provider": "^1.1.2",
|
||||
|
||||
7
packages/shared/utils.ts
Normal file
7
packages/shared/utils.ts
Normal file
@@ -0,0 +1,7 @@
|
||||
export const defaultAppHeaders = () => {
|
||||
return {
|
||||
'HTTP-Referer': 'https://cherry-ai.com',
|
||||
Referer: 'https://cherry-ai.com',
|
||||
'X-Title': 'Cherry Studio'
|
||||
}
|
||||
}
|
||||
@@ -16,6 +16,7 @@ import {
|
||||
type StreamableHTTPClientTransportOptions
|
||||
} from '@modelcontextprotocol/sdk/client/streamableHttp'
|
||||
import { InMemoryTransport } from '@modelcontextprotocol/sdk/inMemory'
|
||||
import { McpError, type Tool as SDKTool } from '@modelcontextprotocol/sdk/types'
|
||||
// Import notification schemas from MCP SDK
|
||||
import {
|
||||
CancelledNotificationSchema,
|
||||
@@ -29,6 +30,7 @@ import {
|
||||
import { nanoid } from '@reduxjs/toolkit'
|
||||
import { MCPProgressEvent } from '@shared/config/types'
|
||||
import { IpcChannel } from '@shared/IpcChannel'
|
||||
import { defaultAppHeaders } from '@shared/utils'
|
||||
import {
|
||||
BuiltinMCPServerNames,
|
||||
type GetResourceResponse,
|
||||
@@ -94,7 +96,7 @@ function getServerLogger(server: MCPServer, extra?: Record<string, any>) {
|
||||
baseUrl: server?.baseUrl,
|
||||
type: server?.type || (server?.command ? 'stdio' : server?.baseUrl ? 'http' : 'inmemory')
|
||||
}
|
||||
return loggerService.withContext('MCPService', { ...base, ...(extra || {}) })
|
||||
return loggerService.withContext('MCPService', { ...base, ...extra })
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -193,11 +195,18 @@ class McpService {
|
||||
return existingClient
|
||||
}
|
||||
} catch (error: any) {
|
||||
getServerLogger(server).error(`Error pinging server`, error as Error)
|
||||
getServerLogger(server).error(`Error pinging server ${server.name}`, error as Error)
|
||||
this.clients.delete(serverKey)
|
||||
}
|
||||
}
|
||||
|
||||
const prepareHeaders = () => {
|
||||
return {
|
||||
...defaultAppHeaders(),
|
||||
...server.headers
|
||||
}
|
||||
}
|
||||
|
||||
// Create a promise for the initialization process
|
||||
const initPromise = (async () => {
|
||||
try {
|
||||
@@ -235,8 +244,11 @@ class McpService {
|
||||
} else if (server.baseUrl) {
|
||||
if (server.type === 'streamableHttp') {
|
||||
const options: StreamableHTTPClientTransportOptions = {
|
||||
fetch: async (url, init) => {
|
||||
return net.fetch(typeof url === 'string' ? url : url.toString(), init)
|
||||
},
|
||||
requestInit: {
|
||||
headers: server.headers || {}
|
||||
headers: prepareHeaders()
|
||||
},
|
||||
authProvider
|
||||
}
|
||||
@@ -249,25 +261,11 @@ class McpService {
|
||||
const options: SSEClientTransportOptions = {
|
||||
eventSourceInit: {
|
||||
fetch: async (url, init) => {
|
||||
const headers = { ...(server.headers || {}), ...(init?.headers || {}) }
|
||||
|
||||
// Get tokens from authProvider to make sure using the latest tokens
|
||||
if (authProvider && typeof authProvider.tokens === 'function') {
|
||||
try {
|
||||
const tokens = await authProvider.tokens()
|
||||
if (tokens && tokens.access_token) {
|
||||
headers['Authorization'] = `Bearer ${tokens.access_token}`
|
||||
}
|
||||
} catch (error) {
|
||||
getServerLogger(server).error('Failed to fetch tokens:', error as Error)
|
||||
}
|
||||
}
|
||||
|
||||
return net.fetch(typeof url === 'string' ? url : url.toString(), { ...init, headers })
|
||||
return net.fetch(typeof url === 'string' ? url : url.toString(), init)
|
||||
}
|
||||
},
|
||||
requestInit: {
|
||||
headers: server.headers || {}
|
||||
headers: prepareHeaders()
|
||||
},
|
||||
authProvider
|
||||
}
|
||||
@@ -444,9 +442,9 @@ class McpService {
|
||||
|
||||
logger.debug(`Activated server: ${server.name}`)
|
||||
return client
|
||||
} catch (error: any) {
|
||||
getServerLogger(server).error(`Error activating server`, error as Error)
|
||||
throw new Error(`[MCP] Error activating server ${server.name}: ${error.message}`)
|
||||
} catch (error) {
|
||||
getServerLogger(server).error(`Error activating server ${server.name}`, error as Error)
|
||||
throw error
|
||||
}
|
||||
} finally {
|
||||
// Clean up the pending promise when done
|
||||
@@ -614,12 +612,11 @@ class McpService {
|
||||
}
|
||||
|
||||
private async listToolsImpl(server: MCPServer): Promise<MCPTool[]> {
|
||||
getServerLogger(server).debug(`Listing tools`)
|
||||
const client = await this.initClient(server)
|
||||
try {
|
||||
const { tools } = await client.listTools()
|
||||
const serverTools: MCPTool[] = []
|
||||
tools.map((tool: any) => {
|
||||
tools.map((tool: SDKTool) => {
|
||||
const serverTool: MCPTool = {
|
||||
...tool,
|
||||
id: buildFunctionCallToolName(server.name, tool.name),
|
||||
@@ -628,11 +625,12 @@ class McpService {
|
||||
type: 'mcp'
|
||||
}
|
||||
serverTools.push(serverTool)
|
||||
getServerLogger(server).debug(`Listing tools`, { tool: serverTool })
|
||||
})
|
||||
return serverTools
|
||||
} catch (error: any) {
|
||||
} catch (error: unknown) {
|
||||
getServerLogger(server).error(`Failed to list tools`, error as Error)
|
||||
return []
|
||||
throw error
|
||||
}
|
||||
}
|
||||
|
||||
@@ -739,9 +737,9 @@ class McpService {
|
||||
serverId: server.id,
|
||||
serverName: server.name
|
||||
}))
|
||||
} catch (error: any) {
|
||||
} catch (error: unknown) {
|
||||
// -32601 is the code for the method not found
|
||||
if (error?.code !== -32601) {
|
||||
if (error instanceof McpError && error.code !== -32601) {
|
||||
getServerLogger(server).error(`Failed to list prompts`, error as Error)
|
||||
}
|
||||
return []
|
||||
|
||||
@@ -46,6 +46,7 @@ import { isJSON, parseJSON } from '@renderer/utils'
|
||||
import { addAbortController, removeAbortController } from '@renderer/utils/abortController'
|
||||
import { findFileBlocks, getMainTextContent } from '@renderer/utils/messageUtils/find'
|
||||
import { defaultTimeout } from '@shared/config/constant'
|
||||
import { defaultAppHeaders } from '@shared/utils'
|
||||
import { isEmpty } from 'lodash'
|
||||
|
||||
import { CompletionsContext } from '../middleware/types'
|
||||
@@ -179,8 +180,7 @@ export abstract class BaseApiClient<
|
||||
|
||||
public defaultHeaders() {
|
||||
return {
|
||||
'HTTP-Referer': 'https://cherry-ai.com',
|
||||
'X-Title': 'Cherry Studio',
|
||||
...defaultAppHeaders(),
|
||||
'X-Api-Key': this.apiKey
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,7 +5,7 @@ import { useAppDispatch, useAppSelector } from '@renderer/store'
|
||||
import { setMCPServers } from '@renderer/store/mcp'
|
||||
import { MCPServer, safeValidateMcpConfig } from '@renderer/types'
|
||||
import { parseJSON } from '@renderer/utils'
|
||||
import { formatZodError } from '@renderer/utils/error'
|
||||
import { formatErrorMessage, formatZodError } from '@renderer/utils/error'
|
||||
import { Modal, Spin, Typography } from 'antd'
|
||||
import { useEffect, useState } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
@@ -80,11 +80,8 @@ const PopupContainer: React.FC<Props> = ({ resolve }) => {
|
||||
const server: MCPServer = {
|
||||
id,
|
||||
isActive: false,
|
||||
...(serverConfig as any)
|
||||
}
|
||||
|
||||
if (!server.name) {
|
||||
server.name = id
|
||||
name: serverConfig.name || id,
|
||||
...serverConfig
|
||||
}
|
||||
|
||||
serversArray.push(server)
|
||||
@@ -95,9 +92,8 @@ const PopupContainer: React.FC<Props> = ({ resolve }) => {
|
||||
window.toast.success(t('settings.mcp.jsonSaveSuccess'))
|
||||
setJsonError('')
|
||||
setOpen(false)
|
||||
} catch (error: any) {
|
||||
logger.error('Failed to save JSON config:', error)
|
||||
setJsonError(error.message || t('settings.mcp.jsonSaveError'))
|
||||
} catch (error: unknown) {
|
||||
setJsonError(formatErrorMessage(error) || t('settings.mcp.jsonSaveError'))
|
||||
window.toast.error(t('settings.mcp.jsonSaveError'))
|
||||
} finally {
|
||||
setJsonSaving(false)
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import { loggerService } from '@logger'
|
||||
import type { McpError } from '@modelcontextprotocol/sdk/types.js'
|
||||
import { DeleteIcon } from '@renderer/components/Icons'
|
||||
import { useTheme } from '@renderer/context/ThemeProvider'
|
||||
import { useMCPServer, useMCPServers } from '@renderer/hooks/useMCPServers'
|
||||
@@ -424,7 +425,7 @@ const McpSettings: React.FC = () => {
|
||||
} catch (error: any) {
|
||||
window.modal.error({
|
||||
title: t('settings.mcp.startError'),
|
||||
content: formatMcpError(error),
|
||||
content: formatMcpError(error as McpError),
|
||||
centered: true
|
||||
})
|
||||
updateMCPServer({ ...server, isActive: oldActiveState })
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import { loggerService } from '@logger'
|
||||
import { nanoid } from '@reduxjs/toolkit'
|
||||
import type { MCPServer } from '@renderer/types'
|
||||
import { getMcpServerType, type MCPServer } from '@renderer/types'
|
||||
import i18next from 'i18next'
|
||||
|
||||
const logger = loggerService.withContext('ModelScopeSyncUtils')
|
||||
@@ -104,13 +104,13 @@ export const syncModelScopeServers = async (
|
||||
|
||||
// Check if server already exists
|
||||
const existingServer = existingServers.find((s) => s.id === `@modelscope/${server.id}`)
|
||||
|
||||
const url = server.operational_urls[0].url
|
||||
const mcpServer: MCPServer = {
|
||||
id: `@modelscope/${server.id}`,
|
||||
name: server.chinese_name || server.name || `ModelScope Server ${nanoid()}`,
|
||||
description: server.description || '',
|
||||
type: 'sse',
|
||||
baseUrl: server.operational_urls[0].url,
|
||||
type: getMcpServerType(url),
|
||||
baseUrl: url,
|
||||
command: '',
|
||||
args: [],
|
||||
env: {},
|
||||
|
||||
@@ -2,7 +2,6 @@ import type { WebSearchResultBlock } from '@anthropic-ai/sdk/resources'
|
||||
import type { GenerateImagesConfig, GroundingMetadata, PersonGeneration } from '@google/genai'
|
||||
import type OpenAI from 'openai'
|
||||
import type { CSSProperties } from 'react'
|
||||
import { z } from 'zod'
|
||||
|
||||
export * from './file'
|
||||
export * from './note'
|
||||
@@ -836,20 +835,6 @@ export const isBuiltinMCPServerName = (name: string): name is BuiltinMCPServerNa
|
||||
return BuiltinMCPServerNamesArray.some((n) => n === name)
|
||||
}
|
||||
|
||||
export interface MCPToolInputSchema {
|
||||
type: string
|
||||
title: string
|
||||
description?: string
|
||||
required?: string[]
|
||||
properties: Record<string, object>
|
||||
}
|
||||
|
||||
export const MCPToolOutputSchema = z.object({
|
||||
type: z.literal('object'),
|
||||
properties: z.record(z.string(), z.unknown()),
|
||||
required: z.array(z.string())
|
||||
})
|
||||
|
||||
export interface MCPPromptArguments {
|
||||
name: string
|
||||
description?: string
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import { z } from 'zod'
|
||||
import * as z from 'zod'
|
||||
|
||||
import { isBuiltinMCPServerName } from '.'
|
||||
|
||||
@@ -187,18 +187,13 @@ export const McpServerConfigSchema = z
|
||||
// 显式传入的type会覆盖掉从url推断的逻辑
|
||||
if (!schema.type) {
|
||||
const url = schema.baseUrl ?? schema.url ?? null
|
||||
// NOTE: url 暗示了服务器的类型为 streamableHttp 或 sse,未来可能会扩展其他类型
|
||||
if (url !== null) {
|
||||
if (url.endsWith('/mcp')) {
|
||||
return {
|
||||
...schema,
|
||||
type: 'streamableHttp'
|
||||
} as const
|
||||
} else if (url.endsWith('/sse')) {
|
||||
return {
|
||||
...schema,
|
||||
type: 'sse'
|
||||
} as const
|
||||
}
|
||||
const type = getMcpServerType(url)
|
||||
return {
|
||||
...schema,
|
||||
type
|
||||
} as const
|
||||
}
|
||||
}
|
||||
return schema
|
||||
@@ -254,3 +249,14 @@ export function safeValidateMcpConfig(config: unknown) {
|
||||
export function safeValidateMcpServerConfig(config: unknown) {
|
||||
return McpServerConfigSchema.safeParse(config)
|
||||
}
|
||||
|
||||
/**
|
||||
* 根据给定的URL判断MCP服务器的类型。
|
||||
* 如果URL以 "/mcp" 结尾,则类型为 "streamableHttp",否则为 "sse"。
|
||||
*
|
||||
* @param url - 服务器的URL地址
|
||||
* @returns MCP服务器类型('streamableHttp' 或 'sse')
|
||||
*/
|
||||
export function getMcpServerType(url: string): McpServerType {
|
||||
return url.endsWith('/mcp') ? 'streamableHttp' : 'sse'
|
||||
}
|
||||
|
||||
@@ -19,22 +19,24 @@ export interface BaseTool {
|
||||
// providerExecuted?: boolean // 标识是Provider端执行还是客户端执行
|
||||
// }
|
||||
|
||||
export const MCPToolOutputSchema = z.object({
|
||||
type: z.literal('object'),
|
||||
properties: z.record(z.string(), z.unknown()),
|
||||
required: z.array(z.string())
|
||||
})
|
||||
export const MCPToolOutputSchema = z
|
||||
.object({
|
||||
type: z.literal('object'),
|
||||
properties: z.object({}).loose().optional(),
|
||||
required: z.array(z.string()).optional()
|
||||
})
|
||||
.loose()
|
||||
|
||||
export interface MCPToolInputSchema {
|
||||
type: string
|
||||
title: string
|
||||
description?: string
|
||||
required?: string[]
|
||||
properties: Record<string, object>
|
||||
}
|
||||
export const MCPToolInputSchema = z
|
||||
.object({
|
||||
type: z.literal('object'),
|
||||
properties: z.object({}).loose().optional(),
|
||||
required: z.array(z.string()).optional()
|
||||
})
|
||||
.loose()
|
||||
|
||||
export interface BuiltinTool extends BaseTool {
|
||||
inputSchema: MCPToolInputSchema
|
||||
inputSchema: z.infer<typeof MCPToolInputSchema>
|
||||
type: 'builtin'
|
||||
}
|
||||
|
||||
@@ -44,7 +46,7 @@ export interface MCPTool extends BaseTool {
|
||||
serverName: string
|
||||
name: string
|
||||
description?: string
|
||||
inputSchema: MCPToolInputSchema
|
||||
inputSchema: z.infer<typeof MCPToolInputSchema>
|
||||
outputSchema?: z.infer<typeof MCPToolOutputSchema>
|
||||
isBuiltIn?: boolean // 标识是否为内置工具,内置工具不需要通过MCP协议调用
|
||||
type: 'mcp'
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import { McpError } from '@modelcontextprotocol/sdk/types.js'
|
||||
import {
|
||||
AiSdkErrorUnion,
|
||||
isSerializedAiSdkAPICallError,
|
||||
@@ -41,25 +42,17 @@ export function getErrorDetails(err: any, seen = new WeakSet()): any {
|
||||
return result
|
||||
}
|
||||
|
||||
export function formatErrorMessage(error: any): string {
|
||||
try {
|
||||
const detailedError = getErrorDetails(error)
|
||||
delete detailedError?.headers
|
||||
delete detailedError?.stack
|
||||
delete detailedError?.request_id
|
||||
export function formatErrorMessage(error: unknown): string {
|
||||
const detailedError = getErrorDetails(error)
|
||||
delete detailedError?.headers
|
||||
delete detailedError?.stack
|
||||
delete detailedError?.request_id
|
||||
|
||||
const formattedJson = JSON.stringify(detailedError, null, 2)
|
||||
.split('\n')
|
||||
.map((line) => ` ${line}`)
|
||||
.join('\n')
|
||||
return `Error Details:\n${formattedJson}`
|
||||
} catch (e) {
|
||||
try {
|
||||
return `Error: ${String(error)}`
|
||||
} catch {
|
||||
return 'Error: Unable to format error message'
|
||||
}
|
||||
}
|
||||
const formattedJson = JSON.stringify(detailedError, null, 2)
|
||||
.split('\n')
|
||||
.map((line) => ` ${line}`)
|
||||
.join('\n')
|
||||
return `Error Details:\n${formattedJson}`
|
||||
}
|
||||
|
||||
export const isAbortError = (error: any): boolean => {
|
||||
@@ -89,10 +82,8 @@ export const isAbortError = (error: any): boolean => {
|
||||
return false
|
||||
}
|
||||
|
||||
export const formatMcpError = (error: any) => {
|
||||
if (error.message.includes('32000')) {
|
||||
return t('settings.mcp.errors.32000')
|
||||
}
|
||||
// TODO: format
|
||||
export const formatMcpError = (error: McpError) => {
|
||||
return error.message
|
||||
}
|
||||
|
||||
|
||||
10
yarn.lock
10
yarn.lock
@@ -6654,9 +6654,9 @@ __metadata:
|
||||
languageName: node
|
||||
linkType: hard
|
||||
|
||||
"@modelcontextprotocol/sdk@npm:^1.17.0":
|
||||
version: 1.17.0
|
||||
resolution: "@modelcontextprotocol/sdk@npm:1.17.0"
|
||||
"@modelcontextprotocol/sdk@npm:^1.17.5":
|
||||
version: 1.17.5
|
||||
resolution: "@modelcontextprotocol/sdk@npm:1.17.5"
|
||||
dependencies:
|
||||
ajv: "npm:^6.12.6"
|
||||
content-type: "npm:^1.0.5"
|
||||
@@ -6670,7 +6670,7 @@ __metadata:
|
||||
raw-body: "npm:^3.0.0"
|
||||
zod: "npm:^3.23.8"
|
||||
zod-to-json-schema: "npm:^3.24.1"
|
||||
checksum: 10c0/ac497edeb05a434bf8092475e4354ec602644b0197735d3bcd809ee1922f2078ab71e7d8d9dbe1c42765978fa3f2f807df01a2a3ad421c986f0b2207c3a40a68
|
||||
checksum: 10c0/182b92b5e7c07da428fd23c6de22021c4f9a91f799c02a8ef15def07e4f9361d0fc22303548658fec2a700623535fd44a9dc4d010fb5d803a8f80e3c6c64a45e
|
||||
languageName: node
|
||||
linkType: hard
|
||||
|
||||
@@ -13067,7 +13067,7 @@ __metadata:
|
||||
"@libsql/client": "npm:0.14.0"
|
||||
"@libsql/win32-x64-msvc": "npm:^0.4.7"
|
||||
"@mistralai/mistralai": "npm:^1.7.5"
|
||||
"@modelcontextprotocol/sdk": "npm:^1.17.0"
|
||||
"@modelcontextprotocol/sdk": "npm:^1.17.5"
|
||||
"@mozilla/readability": "npm:^0.6.0"
|
||||
"@napi-rs/system-ocr": "patch:@napi-rs/system-ocr@npm%3A1.0.2#~/.yarn/patches/@napi-rs-system-ocr-npm-1.0.2-59e7a78e8b.patch"
|
||||
"@notionhq/client": "npm:^2.2.15"
|
||||
|
||||
Reference in New Issue
Block a user