Merge branch 'feat/agents-new' of github.com:CherryHQ/cherry-studio into feat/agents-new

This commit is contained in:
icarus
2025-09-20 00:53:55 +08:00
7 changed files with 134 additions and 136 deletions
@@ -1,4 +1,5 @@
import { loggerService } from '@logger'
import { AgentStreamEvent } from '@main/services/agents/interfaces/AgentStreamInterface'
import { Request, Response } from 'express'
import { agentService, sessionMessageService, sessionService } from '../../../../services/agents'
@@ -42,13 +43,12 @@ export const createMessage = async (req: Request, res: Response): Promise<void>
res.setHeader('Access-Control-Allow-Origin', '*')
res.setHeader('Access-Control-Allow-Headers', 'Cache-Control')
const messageStream = sessionMessageService.createSessionMessage(session, messageData)
const abortController = new AbortController()
const messageStream = sessionMessageService.createSessionMessage(session, messageData, abortController)
// Track stream lifecycle so we keep the SSE connection open until persistence finishes
let responseEnded = false
let streamFinished = false
let awaitingPersistence = false
let persistenceResolved = false
const finalizeResponse = () => {
if (responseEnded) {
@@ -59,10 +59,6 @@ export const createMessage = async (req: Request, res: Response): Promise<void>
return
}
if (awaitingPersistence && !persistenceResolved) {
return
}
responseEnded = true
try {
res.write('data: {"type":"finish"}\n\n')
@@ -73,15 +69,39 @@ export const createMessage = async (req: Request, res: Response): Promise<void>
res.end()
}
// Handle client disconnect
req.on('close', () => {
/**
* Client Disconnect Detection for Server-Sent Events (SSE)
*
* We monitor multiple HTTP events to reliably detect when a client disconnects
* from the streaming response. This is crucial for:
* - Aborting long-running Claude Code processes
* - Cleaning up resources and preventing memory leaks
* - Avoiding orphaned processes
*
* Event Priority & Behavior:
* 1. res.on('close') - Most common for SSE client disconnects (browser tab close, curl Ctrl+C)
* 2. req.on('aborted') - Explicit request abortion
* 3. req.on('close') - Request object closure (less common with SSE)
*
* When any disconnect event fires, we:
* - Abort the Claude Code SDK process via abortController
* - Clean up event listeners to prevent memory leaks
* - Mark the response as ended to prevent further writes
*/
const handleDisconnect = () => {
if (responseEnded) return
logger.info(`Client disconnected from streaming message for session: ${sessionId}`)
responseEnded = true
messageStream.removeAllListeners()
})
abortController.abort('Client disconnected')
}
req.on('close', handleDisconnect)
req.on('aborted', handleDisconnect)
res.on('close', handleDisconnect)
// Handle stream events
messageStream.on('data', (event: any) => {
messageStream.on('data', (event: AgentStreamEvent) => {
if (responseEnded) return
try {
@@ -101,12 +121,6 @@ export const createMessage = async (req: Request, res: Response): Promise<void>
logger.error(`Streaming message error for session: ${sessionId}:`, event.error)
streamFinished = true
awaitingPersistence = Boolean(event.persistScheduled)
if (!awaitingPersistence) {
persistenceResolved = true
}
finalizeResponse()
break
}
@@ -116,32 +130,22 @@ export const createMessage = async (req: Request, res: Response): Promise<void>
// res.write(`data: ${JSON.stringify({ type: 'complete', result: event.result })}\n\n`)
streamFinished = true
awaitingPersistence = true
finalizeResponse()
break
}
case 'persisted':
// Send persistence success event
// res.write(`data: ${JSON.stringify(event)}\n\n`)
logger.debug(`Session message persisted for session: ${sessionId}`, { messageId: event.message?.id })
persistenceResolved = true
finalizeResponse()
break
case 'persist-error':
// Send persistence error event
// res.write(`data: ${JSON.stringify(event)}\n\n`)
logger.error(`Failed to persist session message for session: ${sessionId}:`, event.error)
persistenceResolved = true
case 'cancelled': {
logger.info(`Streaming message cancelled for session: ${sessionId}`)
// res.write(`data: ${JSON.stringify({ type: 'cancelled' })}\n\n`)
streamFinished = true
finalizeResponse()
break
}
default:
// Handle other event types as generic data
res.write(`data: ${JSON.stringify(event)}\n\n`)
logger.info(`Streaming message event for session: ${sessionId}:`, { event })
// res.write(`data: ${JSON.stringify(event)}\n\n`)
break
}
} catch (writeError) {
@@ -199,8 +203,8 @@ export const createMessage = async (req: Request, res: Response): Promise<void>
res.end()
}
},
5 * 60 * 1000
) // 5 minutes timeout
10 * 60 * 1000
) // 10 minutes timeout
// Clear timeout when response ends
res.on('close', () => clearTimeout(timeout))
+3 -28
View File
@@ -13,8 +13,7 @@ import { Request, Response } from 'express'
import { IncomingMessage, ServerResponse } from 'http'
import { loggerService } from '../../services/LoggerService'
import { reduxService } from '../../services/ReduxService'
import { getMcpServerById } from '../utils/mcp'
import { getMcpServerById, getMCPServersFromRedux } from '../utils/mcp'
const logger = loggerService.withContext('MCPApiService')
const transports: Record<string, StreamableHTTPServerTransport> = {}
@@ -57,34 +56,10 @@ class MCPApiService extends EventEmitter {
this.transport.onmessage = this.onMessage
}
/**
* Get servers directly from Redux store
*/
private async getServersFromRedux(): Promise<MCPServer[]> {
try {
logger.silly('Getting servers from Redux store')
// Try to get from cache first (faster)
const cachedServers = reduxService.selectSync<MCPServer[]>('state.mcp.servers')
if (cachedServers && Array.isArray(cachedServers)) {
logger.silly(`Found ${cachedServers.length} servers in Redux cache`)
return cachedServers
}
// If cache is not available, get fresh data
const servers = await reduxService.select<MCPServer[]>('state.mcp.servers')
logger.silly(`Fetched ${servers?.length || 0} servers from Redux store`)
return servers || []
} catch (error: any) {
logger.error('Failed to get servers from Redux:', error)
return []
}
}
// get all activated servers
async getAllServers(req: Request): Promise<McpServersResp> {
try {
const servers = await this.getServersFromRedux()
const servers = await getMCPServersFromRedux()
logger.silly(`Returning ${servers.length} servers`)
const resp: McpServersResp = {
servers: {}
@@ -111,7 +86,7 @@ class MCPApiService extends EventEmitter {
async getServerById(id: string): Promise<MCPServer | null> {
try {
logger.silly(`getServerById called with id: ${id}`)
const servers = await this.getServersFromRedux()
const servers = await getMCPServersFromRedux()
const server = servers.find((s) => s.id === id)
if (!server) {
logger.warn(`Server with id ${id} not found`)
+16 -1
View File
@@ -1,12 +1,24 @@
import { CacheService } from '@main/services/CacheService'
import { loggerService } from '@main/services/LoggerService'
import { reduxService } from '@main/services/ReduxService'
import { ApiModel, Model, Provider } from '@types'
const logger = loggerService.withContext('ApiServerUtils')
// Cache configuration
const PROVIDERS_CACHE_KEY = 'api-server:providers'
const PROVIDERS_CACHE_TTL = 5 * 60 * 1000 // 5 minutes
export async function getAvailableProviders(): Promise<Provider[]> {
try {
// Wait for store to be ready before accessing providers
// Try to get from cache first (faster)
const cachedSupportedProviders = CacheService.get<Provider[]>(PROVIDERS_CACHE_KEY)
if (cachedSupportedProviders) {
logger.debug(`Found ${cachedSupportedProviders.length} supported providers (from cache)`)
return cachedSupportedProviders
}
// If cache is not available, get fresh data from Redux
const providers = await reduxService.select('state.llm.providers')
if (!providers || !Array.isArray(providers)) {
logger.warn('No providers found in Redux store, returning empty array')
@@ -18,6 +30,9 @@ export async function getAvailableProviders(): Promise<Provider[]> {
(p: Provider) => p.enabled && (p.type === 'openai' || p.type === 'anthropic')
)
// Cache the filtered results
CacheService.set(PROVIDERS_CACHE_KEY, supportedProviders, PROVIDERS_CACHE_TTL)
logger.info(`Filtered to ${supportedProviders.length} supported providers from ${providers.length} total providers`)
return supportedProviders
+25 -5
View File
@@ -1,3 +1,4 @@
import { CacheService } from '@main/services/CacheService'
import mcpService from '@main/services/MCPService'
import { Server } from '@modelcontextprotocol/sdk/server/index.js'
import { CallToolRequestSchema, ListToolsRequestSchema, ListToolsResult } from '@modelcontextprotocol/sdk/types.js'
@@ -8,6 +9,10 @@ import { reduxService } from '../../services/ReduxService'
const logger = loggerService.withContext('MCPApiService')
// Cache configuration
const MCP_SERVERS_CACHE_KEY = 'api-server:mcp-servers'
const MCP_SERVERS_CACHE_TTL = 5 * 60 * 1000 // 5 minutes
const cachedServers: Record<string, Server> = {}
async function handleListToolsRequest(request: any, extra: any): Promise<ListToolsResult> {
@@ -33,18 +38,33 @@ async function handleCallToolRequest(request: any, extra: any): Promise<any> {
}
async function getMcpServerConfigById(id: string): Promise<MCPServer | undefined> {
const servers = await getServersFromRedux()
const servers = await getMCPServersFromRedux()
return servers.find((s) => s.id === id || s.name === id)
}
/**
* Get servers directly from Redux store
*/
async function getServersFromRedux(): Promise<MCPServer[]> {
export async function getMCPServersFromRedux(): Promise<MCPServer[]> {
try {
logger.silly('Getting servers from Redux store')
// Try to get from cache first (faster)
const cachedServers = CacheService.get<MCPServer[]>(MCP_SERVERS_CACHE_KEY)
if (cachedServers) {
logger.silly(`Found ${cachedServers.length} servers (from cache)`)
return cachedServers
}
// If cache is not available, get fresh data from Redux
const servers = await reduxService.select<MCPServer[]>('state.mcp.servers')
logger.silly(`Fetched ${servers?.length || 0} servers from Redux store`)
return servers || []
const serverList = servers || []
// Cache the results
CacheService.set(MCP_SERVERS_CACHE_KEY, serverList, MCP_SERVERS_CACHE_TTL)
logger.silly(`Fetched ${serverList.length} servers from Redux store`)
return serverList
} catch (error: any) {
logger.error('Failed to get servers from Redux:', error)
return []
@@ -54,7 +74,7 @@ async function getServersFromRedux(): Promise<MCPServer[]> {
export async function getMcpServerById(id: string): Promise<Server> {
const server = cachedServers[id]
if (!server) {
const servers = await getServersFromRedux()
const servers = await getMCPServersFromRedux()
const mcpServer = servers.find((s) => s.id === id || s.name === id)
if (!mcpServer) {
throw new Error(`Server not found: ${id}`)