Merge branch 'feat/agents-new' of github.com:CherryHQ/cherry-studio into feat/agents-new
This commit is contained in:
@@ -1,8 +1,9 @@
|
||||
import { loggerService } from '@logger'
|
||||
import { ListAgentsResponse } from '@types'
|
||||
import { ListAgentsResponse,type ReplaceAgentRequest, type UpdateAgentRequest } from '@types'
|
||||
import { Request, Response } from 'express'
|
||||
|
||||
import { agentService } from '../../../../services/agents'
|
||||
import type { ValidationRequest } from '../validators/zodValidator'
|
||||
|
||||
const logger = loggerService.withContext('ApiServerAgentsHandlers')
|
||||
|
||||
@@ -263,7 +264,10 @@ export const updateAgent = async (req: Request, res: Response): Promise<Response
|
||||
logger.info(`Updating agent: ${agentId}`)
|
||||
logger.debug('Update data:', req.body)
|
||||
|
||||
const agent = await agentService.updateAgent(agentId, req.body)
|
||||
const { validatedBody } = req as ValidationRequest
|
||||
const replacePayload = (validatedBody ?? {}) as ReplaceAgentRequest
|
||||
|
||||
const agent = await agentService.updateAgent(agentId, replacePayload, { replace: true })
|
||||
|
||||
if (!agent) {
|
||||
logger.warn(`Agent not found for update: ${agentId}`)
|
||||
@@ -395,7 +399,10 @@ export const patchAgent = async (req: Request, res: Response): Promise<Response>
|
||||
logger.info(`Partially updating agent: ${agentId}`)
|
||||
logger.debug('Partial update data:', req.body)
|
||||
|
||||
const agent = await agentService.updateAgent(agentId, req.body)
|
||||
const { validatedBody } = req as ValidationRequest
|
||||
const updatePayload = (validatedBody ?? {}) as UpdateAgentRequest
|
||||
|
||||
const agent = await agentService.updateAgent(agentId, updatePayload)
|
||||
|
||||
if (!agent) {
|
||||
logger.warn(`Agent not found for partial update: ${agentId}`)
|
||||
|
||||
@@ -12,7 +12,7 @@ const verifyAgentAndSession = async (agentId: string, sessionId: string) => {
|
||||
throw { status: 404, code: 'agent_not_found', message: 'Agent not found' }
|
||||
}
|
||||
|
||||
const session = await sessionService.getSession(sessionId)
|
||||
const session = await sessionService.getSession(agentId, sessionId)
|
||||
if (!session) {
|
||||
throw { status: 404, code: 'session_not_found', message: 'Session not found' }
|
||||
}
|
||||
|
||||
@@ -1,8 +1,15 @@
|
||||
import { loggerService } from '@logger'
|
||||
import { sessionMessageService, sessionService } from '@main/services/agents'
|
||||
import { CreateSessionResponse, ListAgentSessionsResponse, UpdateSessionResponse } from '@types'
|
||||
import {
|
||||
CreateSessionResponse,
|
||||
ListAgentSessionsResponse,
|
||||
type ReplaceSessionRequest,
|
||||
UpdateSessionResponse
|
||||
} from '@types'
|
||||
import { Request, Response } from 'express'
|
||||
|
||||
import type { ValidationRequest } from '../validators/zodValidator'
|
||||
|
||||
const logger = loggerService.withContext('ApiServerSessionsHandlers')
|
||||
|
||||
export const createSession = async (req: Request, res: Response): Promise<Response> => {
|
||||
@@ -64,7 +71,7 @@ export const getSession = async (req: Request, res: Response): Promise<Response>
|
||||
const { agentId, sessionId } = req.params
|
||||
logger.info(`Getting session: ${sessionId} for agent: ${agentId}`)
|
||||
|
||||
const session = await sessionService.getSession(sessionId)
|
||||
const session = await sessionService.getSession(agentId, sessionId)
|
||||
|
||||
if (!session) {
|
||||
logger.warn(`Session not found: ${sessionId}`)
|
||||
@@ -119,7 +126,7 @@ export const updateSession = async (req: Request, res: Response): Promise<Respon
|
||||
logger.debug('Update data:', req.body)
|
||||
|
||||
// First check if session exists and belongs to agent
|
||||
const existingSession = await sessionService.getSession(sessionId)
|
||||
const existingSession = await sessionService.getSession(agentId, sessionId)
|
||||
if (!existingSession || existingSession.agent_id !== agentId) {
|
||||
logger.warn(`Session ${sessionId} not found for agent ${agentId}`)
|
||||
return res.status(404).json({
|
||||
@@ -131,9 +138,10 @@ export const updateSession = async (req: Request, res: Response): Promise<Respon
|
||||
})
|
||||
}
|
||||
|
||||
// For PUT, we replace the entire resource
|
||||
const sessionData = { ...req.body, main_agent_id: agentId }
|
||||
const session = await sessionService.updateSession(sessionId, sessionData)
|
||||
const { validatedBody } = req as ValidationRequest
|
||||
const replacePayload = (validatedBody ?? {}) as ReplaceSessionRequest
|
||||
|
||||
const session = await sessionService.updateSession(agentId, sessionId, replacePayload)
|
||||
|
||||
if (!session) {
|
||||
logger.warn(`Session not found for update: ${sessionId}`)
|
||||
@@ -167,7 +175,7 @@ export const patchSession = async (req: Request, res: Response): Promise<Respons
|
||||
logger.debug('Patch data:', req.body)
|
||||
|
||||
// First check if session exists and belongs to agent
|
||||
const existingSession = await sessionService.getSession(sessionId)
|
||||
const existingSession = await sessionService.getSession(agentId, sessionId)
|
||||
if (!existingSession || existingSession.agent_id !== agentId) {
|
||||
logger.warn(`Session ${sessionId} not found for agent ${agentId}`)
|
||||
return res.status(404).json({
|
||||
@@ -180,7 +188,7 @@ export const patchSession = async (req: Request, res: Response): Promise<Respons
|
||||
}
|
||||
|
||||
const updateSession = { ...existingSession, ...req.body }
|
||||
const session = await sessionService.updateSession(sessionId, updateSession)
|
||||
const session = await sessionService.updateSession(agentId, sessionId, updateSession)
|
||||
|
||||
if (!session) {
|
||||
logger.warn(`Session not found for patch: ${sessionId}`)
|
||||
@@ -213,7 +221,7 @@ export const deleteSession = async (req: Request, res: Response): Promise<Respon
|
||||
logger.info(`Deleting session: ${sessionId} for agent: ${agentId}`)
|
||||
|
||||
// First check if session exists and belongs to agent
|
||||
const existingSession = await sessionService.getSession(sessionId)
|
||||
const existingSession = await sessionService.getSession(agentId, sessionId)
|
||||
if (!existingSession || existingSession.agent_id !== agentId) {
|
||||
logger.warn(`Session ${sessionId} not found for agent ${agentId}`)
|
||||
return res.status(404).json({
|
||||
@@ -225,7 +233,7 @@ export const deleteSession = async (req: Request, res: Response): Promise<Respon
|
||||
})
|
||||
}
|
||||
|
||||
const deleted = await sessionService.deleteSession(sessionId)
|
||||
const deleted = await sessionService.deleteSession(agentId, sessionId)
|
||||
|
||||
if (!deleted) {
|
||||
logger.warn(`Session not found for deletion: ${sessionId}`)
|
||||
@@ -287,7 +295,7 @@ export const getSessionById = async (req: Request, res: Response): Promise<Respo
|
||||
const { sessionId } = req.params
|
||||
logger.info(`Getting session: ${sessionId}`)
|
||||
|
||||
const session = await sessionService.getSession(sessionId)
|
||||
const session = await sessionService.getSessionById(sessionId)
|
||||
|
||||
if (!session) {
|
||||
logger.warn(`Session not found: ${sessionId}`)
|
||||
|
||||
@@ -5,11 +5,13 @@ import { checkAgentExists, handleValidationErrors } from './middleware'
|
||||
import {
|
||||
validateAgent,
|
||||
validateAgentId,
|
||||
validateAgentReplace,
|
||||
validateAgentUpdate,
|
||||
validatePagination,
|
||||
validateSession,
|
||||
validateSessionId,
|
||||
validateSessionMessage,
|
||||
validateSessionReplace,
|
||||
validateSessionUpdate
|
||||
} from './validators'
|
||||
|
||||
@@ -152,7 +154,13 @@ const agentsRouter = express.Router()
|
||||
agentsRouter.post('/', validateAgent, handleValidationErrors, agentHandlers.createAgent)
|
||||
agentsRouter.get('/', validatePagination, handleValidationErrors, agentHandlers.listAgents)
|
||||
agentsRouter.get('/:agentId', validateAgentId, handleValidationErrors, agentHandlers.getAgent)
|
||||
agentsRouter.put('/:agentId', validateAgentId, validateAgentUpdate, handleValidationErrors, agentHandlers.updateAgent)
|
||||
agentsRouter.put(
|
||||
'/:agentId',
|
||||
validateAgentId,
|
||||
validateAgentReplace,
|
||||
handleValidationErrors,
|
||||
agentHandlers.updateAgent
|
||||
)
|
||||
agentsRouter.patch('/:agentId', validateAgentId, validateAgentUpdate, handleValidationErrors, agentHandlers.patchAgent)
|
||||
agentsRouter.delete('/:agentId', validateAgentId, handleValidationErrors, agentHandlers.deleteAgent)
|
||||
|
||||
@@ -167,7 +175,7 @@ const createSessionsRouter = (): express.Router => {
|
||||
sessionsRouter.put(
|
||||
'/:sessionId',
|
||||
validateSessionId,
|
||||
validateSessionUpdate,
|
||||
validateSessionReplace,
|
||||
handleValidationErrors,
|
||||
sessionHandlers.updateSession
|
||||
)
|
||||
|
||||
@@ -1,24 +1,12 @@
|
||||
import { Request, Response } from 'express'
|
||||
import { validationResult } from 'express-validator'
|
||||
|
||||
import { agentService } from '../../../../services/agents'
|
||||
import { loggerService } from '../../../../services/LoggerService'
|
||||
|
||||
const logger = loggerService.withContext('ApiServerMiddleware')
|
||||
|
||||
// Error handler for validation
|
||||
export const handleValidationErrors = (req: Request, res: Response, next: any): void => {
|
||||
const errors = validationResult(req)
|
||||
if (!errors.isEmpty()) {
|
||||
res.status(400).json({
|
||||
error: {
|
||||
message: 'Validation failed',
|
||||
type: 'validation_error',
|
||||
details: errors.array()
|
||||
}
|
||||
})
|
||||
return
|
||||
}
|
||||
// Since Zod validators handle their own errors, this is now a pass-through
|
||||
export const handleValidationErrors = (_req: Request, _res: Response, next: any): void => {
|
||||
next()
|
||||
}
|
||||
|
||||
|
||||
@@ -1,37 +1,24 @@
|
||||
import { body, param } from 'express-validator'
|
||||
import {
|
||||
AgentIdParamSchema,
|
||||
CreateAgentRequestSchema,
|
||||
ReplaceAgentRequestSchema,
|
||||
UpdateAgentRequestSchema
|
||||
} from '@types'
|
||||
|
||||
export const validateAgent = [
|
||||
body('name').notEmpty().withMessage('Name is required'),
|
||||
body('model').notEmpty().withMessage('Model is required'),
|
||||
body('description').optional().isString(),
|
||||
body('avatar').optional().isString(),
|
||||
body('instructions').optional().isString(),
|
||||
body('plan_model').optional().isString(),
|
||||
body('small_model').optional().isString(),
|
||||
body('built_in_tools').optional().isArray(),
|
||||
body('mcps').optional().isArray(),
|
||||
body('knowledges').optional().isArray(),
|
||||
body('configuration').optional().isObject(),
|
||||
body('accessible_paths').optional().isArray(),
|
||||
body('permission_mode').optional().isIn(['readOnly', 'acceptEdits', 'bypassPermissions']),
|
||||
body('max_steps').optional().isInt({ min: 1 })
|
||||
]
|
||||
import { createZodValidator } from './zodValidator'
|
||||
|
||||
export const validateAgentUpdate = [
|
||||
body('name').optional().notEmpty().withMessage('Name cannot be empty'),
|
||||
body('model').optional().notEmpty().withMessage('Model cannot be empty'),
|
||||
body('description').optional().isString(),
|
||||
body('avatar').optional().isString(),
|
||||
body('instructions').optional().isString(),
|
||||
body('plan_model').optional().isString(),
|
||||
body('small_model').optional().isString(),
|
||||
body('built_in_tools').optional().isArray(),
|
||||
body('mcps').optional().isArray(),
|
||||
body('knowledges').optional().isArray(),
|
||||
body('configuration').optional().isObject(),
|
||||
body('accessible_paths').optional().isArray(),
|
||||
body('permission_mode').optional().isIn(['readOnly', 'acceptEdits', 'bypassPermissions']),
|
||||
body('max_steps').optional().isInt({ min: 1 })
|
||||
]
|
||||
export const validateAgent = createZodValidator({
|
||||
body: CreateAgentRequestSchema
|
||||
})
|
||||
|
||||
export const validateAgentId = [param('agentId').notEmpty().withMessage('Agent ID is required')]
|
||||
export const validateAgentReplace = createZodValidator({
|
||||
body: ReplaceAgentRequestSchema
|
||||
})
|
||||
|
||||
export const validateAgentUpdate = createZodValidator({
|
||||
body: UpdateAgentRequestSchema
|
||||
})
|
||||
|
||||
export const validateAgentId = createZodValidator({
|
||||
params: AgentIdParamSchema
|
||||
})
|
||||
|
||||
@@ -1,10 +1,7 @@
|
||||
import { query } from 'express-validator'
|
||||
import { PaginationQuerySchema } from '@types'
|
||||
|
||||
export const validatePagination = [
|
||||
query('limit').optional().isInt({ min: 1, max: 100 }).withMessage('Limit must be between 1 and 100'),
|
||||
query('offset').optional().isInt({ min: 0 }).withMessage('Offset must be non-negative'),
|
||||
query('status')
|
||||
.optional()
|
||||
.isIn(['idle', 'running', 'completed', 'failed', 'stopped'])
|
||||
.withMessage('Invalid status filter')
|
||||
]
|
||||
import { createZodValidator } from './zodValidator'
|
||||
|
||||
export const validatePagination = createZodValidator({
|
||||
query: PaginationQuerySchema
|
||||
})
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
import { body } from 'express-validator'
|
||||
import { CreateSessionMessageRequestSchema } from '@types'
|
||||
|
||||
export const validateSessionMessage = [
|
||||
body('content').notEmpty().isString().withMessage('Content must be a valid string')
|
||||
]
|
||||
import { createZodValidator } from './zodValidator'
|
||||
|
||||
export const validateSessionMessage = createZodValidator({
|
||||
body: CreateSessionMessageRequestSchema
|
||||
})
|
||||
|
||||
@@ -1,47 +1,24 @@
|
||||
import { body, param } from 'express-validator'
|
||||
import {
|
||||
CreateSessionRequestSchema,
|
||||
ReplaceSessionRequestSchema,
|
||||
SessionIdParamSchema,
|
||||
UpdateSessionRequestSchema
|
||||
} from '@types'
|
||||
|
||||
export const validateSession = [
|
||||
body('name').optional().isString(),
|
||||
body('sub_agent_ids').optional().isArray(),
|
||||
body('user_goal').optional().isString(),
|
||||
body('status').optional().isIn(['idle', 'running', 'completed', 'failed', 'stopped']),
|
||||
body('external_session_id').optional().isString(),
|
||||
body('model').optional().isString(),
|
||||
body('plan_model').optional().isString(),
|
||||
body('small_model').optional().isString(),
|
||||
body('built_in_tools').optional().isArray(),
|
||||
body('mcps').optional().isArray(),
|
||||
body('knowledges').optional().isArray(),
|
||||
body('configuration').optional().isObject(),
|
||||
body('accessible_paths').optional().isArray(),
|
||||
body('permission_mode').optional().isIn(['readOnly', 'acceptEdits', 'bypassPermissions']),
|
||||
body('max_steps').optional().isInt({ min: 1 })
|
||||
]
|
||||
import { createZodValidator } from './zodValidator'
|
||||
|
||||
export const validateSessionUpdate = [
|
||||
body('name').optional().isString(),
|
||||
body('main_agent_id').optional().notEmpty().withMessage('Main agent ID cannot be empty'),
|
||||
body('sub_agent_ids').optional().isArray(),
|
||||
body('user_goal').optional().isString(),
|
||||
body('status').optional().isIn(['idle', 'running', 'completed', 'failed', 'stopped']),
|
||||
body('external_session_id').optional().isString(),
|
||||
body('model').optional().isString(),
|
||||
body('plan_model').optional().isString(),
|
||||
body('small_model').optional().isString(),
|
||||
body('built_in_tools').optional().isArray(),
|
||||
body('mcps').optional().isArray(),
|
||||
body('knowledges').optional().isArray(),
|
||||
body('configuration').optional().isObject(),
|
||||
body('accessible_paths').optional().isArray(),
|
||||
body('permission_mode').optional().isIn(['readOnly', 'acceptEdits', 'bypassPermissions']),
|
||||
body('max_steps').optional().isInt({ min: 1 })
|
||||
]
|
||||
export const validateSession = createZodValidator({
|
||||
body: CreateSessionRequestSchema
|
||||
})
|
||||
|
||||
export const validateStatusUpdate = [
|
||||
body('status')
|
||||
.notEmpty()
|
||||
.isIn(['idle', 'running', 'completed', 'failed', 'stopped'])
|
||||
.withMessage('Valid status is required')
|
||||
]
|
||||
export const validateSessionReplace = createZodValidator({
|
||||
body: ReplaceSessionRequestSchema
|
||||
})
|
||||
|
||||
export const validateSessionId = [param('sessionId').notEmpty().withMessage('Session ID is required')]
|
||||
export const validateSessionUpdate = createZodValidator({
|
||||
body: UpdateSessionRequestSchema
|
||||
})
|
||||
|
||||
export const validateSessionId = createZodValidator({
|
||||
params: SessionIdParamSchema
|
||||
})
|
||||
|
||||
@@ -0,0 +1,68 @@
|
||||
import { NextFunction,Request, Response } from 'express'
|
||||
import { ZodError, ZodType } from 'zod'
|
||||
|
||||
export interface ValidationRequest extends Request {
|
||||
validatedBody?: any
|
||||
validatedParams?: any
|
||||
validatedQuery?: any
|
||||
}
|
||||
|
||||
export interface ZodValidationConfig {
|
||||
body?: ZodType
|
||||
params?: ZodType
|
||||
query?: ZodType
|
||||
}
|
||||
|
||||
export const createZodValidator = (config: ZodValidationConfig) => {
|
||||
return (req: ValidationRequest, res: Response, next: NextFunction): void => {
|
||||
try {
|
||||
if (config.body && req.body) {
|
||||
req.validatedBody = config.body.parse(req.body)
|
||||
}
|
||||
|
||||
if (config.params && req.params) {
|
||||
req.validatedParams = config.params.parse(req.params)
|
||||
}
|
||||
|
||||
if (config.query && req.query) {
|
||||
req.validatedQuery = config.query.parse(req.query)
|
||||
}
|
||||
|
||||
next()
|
||||
} catch (error) {
|
||||
if (error instanceof ZodError) {
|
||||
const validationErrors = error.issues.map((err) => ({
|
||||
type: 'field',
|
||||
value: err.input,
|
||||
msg: err.message,
|
||||
path: err.path.map(p => String(p)).join('.'),
|
||||
location: getLocationFromPath(err.path, config)
|
||||
}))
|
||||
|
||||
res.status(400).json({
|
||||
error: {
|
||||
message: 'Validation failed',
|
||||
type: 'validation_error',
|
||||
details: validationErrors
|
||||
}
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
res.status(500).json({
|
||||
error: {
|
||||
message: 'Internal validation error',
|
||||
type: 'internal_error',
|
||||
code: 'validation_processing_failed'
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
function getLocationFromPath(path: (string | number | symbol)[], config: ZodValidationConfig): string {
|
||||
if (config.body && path.length > 0) return 'body'
|
||||
if (config.params && path.length > 0) return 'params'
|
||||
if (config.query && path.length > 0) return 'query'
|
||||
return 'unknown'
|
||||
}
|
||||
Reference in New Issue
Block a user