Merge branch 'main' of github.com:CherryHQ/cherry-studio into v2
This commit is contained in:
@@ -3,23 +3,42 @@ import cors from 'cors'
|
||||
import express from 'express'
|
||||
import { v4 as uuidv4 } from 'uuid'
|
||||
|
||||
import { LONG_POLL_TIMEOUT_MS } from './config/timeouts'
|
||||
import { authMiddleware } from './middleware/auth'
|
||||
import { errorHandler } from './middleware/error'
|
||||
import { setupOpenAPIDocumentation } from './middleware/openapi'
|
||||
import { agentsRoutes } from './routes/agents'
|
||||
import { chatRoutes } from './routes/chat'
|
||||
import { mcpRoutes } from './routes/mcp'
|
||||
import { messagesProviderRoutes, messagesRoutes } from './routes/messages'
|
||||
import { modelsRoutes } from './routes/models'
|
||||
|
||||
const logger = loggerService.withContext('ApiServer')
|
||||
|
||||
const extendMessagesTimeout: express.RequestHandler = (req, res, next) => {
|
||||
req.setTimeout(LONG_POLL_TIMEOUT_MS)
|
||||
res.setTimeout(LONG_POLL_TIMEOUT_MS)
|
||||
next()
|
||||
}
|
||||
|
||||
const app = express()
|
||||
app.use(
|
||||
express.json({
|
||||
limit: '50mb'
|
||||
})
|
||||
)
|
||||
|
||||
// Global middleware
|
||||
app.use((req, res, next) => {
|
||||
const start = Date.now()
|
||||
res.on('finish', () => {
|
||||
const duration = Date.now() - start
|
||||
logger.info(`${req.method} ${req.path} - ${res.statusCode} - ${duration}ms`)
|
||||
logger.info('API request completed', {
|
||||
method: req.method,
|
||||
path: req.path,
|
||||
statusCode: res.statusCode,
|
||||
durationMs: duration
|
||||
})
|
||||
})
|
||||
next()
|
||||
})
|
||||
@@ -101,27 +120,28 @@ app.get('/', (_req, res) => {
|
||||
name: 'Cherry Studio API',
|
||||
version: '1.0.0',
|
||||
endpoints: {
|
||||
health: 'GET /health',
|
||||
models: 'GET /v1/models',
|
||||
chat: 'POST /v1/chat/completions',
|
||||
mcp: 'GET /v1/mcps'
|
||||
health: 'GET /health'
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
// Setup OpenAPI documentation before protected routes so docs remain public
|
||||
setupOpenAPIDocumentation(app)
|
||||
|
||||
// Provider-specific messages route requires authentication
|
||||
app.use('/:provider/v1/messages', authMiddleware, extendMessagesTimeout, messagesProviderRoutes)
|
||||
|
||||
// API v1 routes with auth
|
||||
const apiRouter = express.Router()
|
||||
apiRouter.use(authMiddleware)
|
||||
apiRouter.use(express.json())
|
||||
// Mount routes
|
||||
apiRouter.use('/chat', chatRoutes)
|
||||
apiRouter.use('/mcps', mcpRoutes)
|
||||
apiRouter.use('/messages', extendMessagesTimeout, messagesRoutes)
|
||||
apiRouter.use('/models', modelsRoutes)
|
||||
apiRouter.use('/agents', agentsRoutes)
|
||||
app.use('/v1', apiRouter)
|
||||
|
||||
// Setup OpenAPI documentation
|
||||
setupOpenAPIDocumentation(app)
|
||||
|
||||
// Error handling (must be last)
|
||||
app.use(errorHandler)
|
||||
|
||||
|
||||
@@ -36,7 +36,7 @@ class ConfigManager {
|
||||
}
|
||||
return this._config
|
||||
} catch (error: any) {
|
||||
logger.warn('Failed to load config from Redux, using defaults:', error)
|
||||
logger.warn('Failed to load config from Redux, using defaults', { error })
|
||||
this._config = {
|
||||
enabled: false,
|
||||
port: defaultPort,
|
||||
|
||||
@@ -0,0 +1,3 @@
|
||||
export const LONG_POLL_TIMEOUT_MS = 120 * 60_000 // 120 minutes
|
||||
|
||||
export const MESSAGE_STREAM_TIMEOUT_MS = LONG_POLL_TIMEOUT_MS
|
||||
@@ -0,0 +1,368 @@
|
||||
import type { NextFunction, Request, Response } from 'express'
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import { config } from '../../config'
|
||||
import { authMiddleware } from '../auth'
|
||||
|
||||
// Mock the config module
|
||||
vi.mock('../../config', () => ({
|
||||
config: {
|
||||
get: vi.fn()
|
||||
}
|
||||
}))
|
||||
|
||||
// Mock the logger
|
||||
vi.mock('@logger', () => ({
|
||||
loggerService: {
|
||||
withContext: vi.fn(() => ({
|
||||
debug: vi.fn()
|
||||
}))
|
||||
}
|
||||
}))
|
||||
|
||||
const mockConfig = config as any
|
||||
|
||||
describe('authMiddleware', () => {
|
||||
let req: Partial<Request>
|
||||
let res: Partial<Response>
|
||||
let next: NextFunction
|
||||
let jsonMock: ReturnType<typeof vi.fn>
|
||||
let statusMock: ReturnType<typeof vi.fn>
|
||||
|
||||
beforeEach(() => {
|
||||
jsonMock = vi.fn()
|
||||
statusMock = vi.fn(() => ({ json: jsonMock }))
|
||||
|
||||
req = {
|
||||
header: vi.fn()
|
||||
}
|
||||
res = {
|
||||
status: statusMock
|
||||
}
|
||||
next = vi.fn()
|
||||
|
||||
vi.clearAllMocks()
|
||||
})
|
||||
|
||||
describe('Missing credentials', () => {
|
||||
it('should return 401 when both auth headers are missing', async () => {
|
||||
;(req.header as any).mockReturnValue('')
|
||||
|
||||
await authMiddleware(req as Request, res as Response, next)
|
||||
|
||||
expect(statusMock).toHaveBeenCalledWith(401)
|
||||
expect(jsonMock).toHaveBeenCalledWith({ error: 'Unauthorized: missing credentials' })
|
||||
expect(next).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should return 401 when both auth headers are empty strings', async () => {
|
||||
;(req.header as any).mockImplementation((header: string) => {
|
||||
if (header === 'authorization') return ''
|
||||
if (header === 'x-api-key') return ''
|
||||
return ''
|
||||
})
|
||||
|
||||
await authMiddleware(req as Request, res as Response, next)
|
||||
|
||||
expect(statusMock).toHaveBeenCalledWith(401)
|
||||
expect(jsonMock).toHaveBeenCalledWith({ error: 'Unauthorized: missing credentials' })
|
||||
expect(next).not.toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
|
||||
describe('Server configuration', () => {
|
||||
it('should return 403 when API key is not configured', async () => {
|
||||
;(req.header as any).mockImplementation((header: string) => {
|
||||
if (header === 'x-api-key') return 'some-key'
|
||||
return ''
|
||||
})
|
||||
|
||||
mockConfig.get.mockResolvedValue({ apiKey: '' })
|
||||
|
||||
await authMiddleware(req as Request, res as Response, next)
|
||||
|
||||
expect(statusMock).toHaveBeenCalledWith(403)
|
||||
expect(jsonMock).toHaveBeenCalledWith({ error: 'Forbidden' })
|
||||
expect(next).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should return 403 when API key is null', async () => {
|
||||
;(req.header as any).mockImplementation((header: string) => {
|
||||
if (header === 'x-api-key') return 'some-key'
|
||||
return ''
|
||||
})
|
||||
|
||||
mockConfig.get.mockResolvedValue({ apiKey: null })
|
||||
|
||||
await authMiddleware(req as Request, res as Response, next)
|
||||
|
||||
expect(statusMock).toHaveBeenCalledWith(403)
|
||||
expect(jsonMock).toHaveBeenCalledWith({ error: 'Forbidden' })
|
||||
expect(next).not.toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
|
||||
describe('API Key authentication (priority)', () => {
|
||||
const validApiKey = 'valid-api-key-123'
|
||||
|
||||
beforeEach(() => {
|
||||
mockConfig.get.mockResolvedValue({ apiKey: validApiKey })
|
||||
})
|
||||
|
||||
it('should authenticate successfully with valid API key', async () => {
|
||||
;(req.header as any).mockImplementation((header: string) => {
|
||||
if (header === 'x-api-key') return validApiKey
|
||||
return ''
|
||||
})
|
||||
|
||||
await authMiddleware(req as Request, res as Response, next)
|
||||
|
||||
expect(next).toHaveBeenCalled()
|
||||
expect(statusMock).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should return 403 with invalid API key', async () => {
|
||||
;(req.header as any).mockImplementation((header: string) => {
|
||||
if (header === 'x-api-key') return 'invalid-key'
|
||||
return ''
|
||||
})
|
||||
|
||||
await authMiddleware(req as Request, res as Response, next)
|
||||
|
||||
expect(statusMock).toHaveBeenCalledWith(403)
|
||||
expect(jsonMock).toHaveBeenCalledWith({ error: 'Forbidden' })
|
||||
expect(next).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should return 401 with empty API key', async () => {
|
||||
;(req.header as any).mockImplementation((header: string) => {
|
||||
if (header === 'x-api-key') return ' '
|
||||
return ''
|
||||
})
|
||||
|
||||
await authMiddleware(req as Request, res as Response, next)
|
||||
|
||||
expect(statusMock).toHaveBeenCalledWith(401)
|
||||
expect(jsonMock).toHaveBeenCalledWith({ error: 'Unauthorized: empty x-api-key' })
|
||||
expect(next).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should handle API key with whitespace', async () => {
|
||||
;(req.header as any).mockImplementation((header: string) => {
|
||||
if (header === 'x-api-key') return ` ${validApiKey} `
|
||||
return ''
|
||||
})
|
||||
|
||||
await authMiddleware(req as Request, res as Response, next)
|
||||
|
||||
expect(next).toHaveBeenCalled()
|
||||
expect(statusMock).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should prioritize API key over Bearer token when both are present', async () => {
|
||||
;(req.header as any).mockImplementation((header: string) => {
|
||||
if (header === 'x-api-key') return validApiKey
|
||||
if (header === 'authorization') return 'Bearer invalid-token'
|
||||
return ''
|
||||
})
|
||||
|
||||
await authMiddleware(req as Request, res as Response, next)
|
||||
|
||||
expect(next).toHaveBeenCalled()
|
||||
expect(statusMock).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should return 403 when API key is invalid even if Bearer token is valid', async () => {
|
||||
;(req.header as any).mockImplementation((header: string) => {
|
||||
if (header === 'x-api-key') return 'invalid-key'
|
||||
if (header === 'authorization') return `Bearer ${validApiKey}`
|
||||
return ''
|
||||
})
|
||||
|
||||
await authMiddleware(req as Request, res as Response, next)
|
||||
|
||||
expect(statusMock).toHaveBeenCalledWith(403)
|
||||
expect(jsonMock).toHaveBeenCalledWith({ error: 'Forbidden' })
|
||||
expect(next).not.toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
|
||||
describe('Bearer token authentication (fallback)', () => {
|
||||
const validApiKey = 'valid-api-key-123'
|
||||
|
||||
beforeEach(() => {
|
||||
mockConfig.get.mockResolvedValue({ apiKey: validApiKey })
|
||||
})
|
||||
|
||||
it('should authenticate successfully with valid Bearer token when no API key', async () => {
|
||||
;(req.header as any).mockImplementation((header: string) => {
|
||||
if (header === 'authorization') return `Bearer ${validApiKey}`
|
||||
return ''
|
||||
})
|
||||
|
||||
await authMiddleware(req as Request, res as Response, next)
|
||||
|
||||
expect(next).toHaveBeenCalled()
|
||||
expect(statusMock).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should return 403 with invalid Bearer token', async () => {
|
||||
;(req.header as any).mockImplementation((header: string) => {
|
||||
if (header === 'authorization') return 'Bearer invalid-token'
|
||||
return ''
|
||||
})
|
||||
|
||||
await authMiddleware(req as Request, res as Response, next)
|
||||
|
||||
expect(statusMock).toHaveBeenCalledWith(403)
|
||||
expect(jsonMock).toHaveBeenCalledWith({ error: 'Forbidden' })
|
||||
expect(next).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should return 401 with malformed authorization header', async () => {
|
||||
;(req.header as any).mockImplementation((header: string) => {
|
||||
if (header === 'authorization') return 'Basic sometoken'
|
||||
return ''
|
||||
})
|
||||
|
||||
await authMiddleware(req as Request, res as Response, next)
|
||||
|
||||
expect(statusMock).toHaveBeenCalledWith(401)
|
||||
expect(jsonMock).toHaveBeenCalledWith({ error: 'Unauthorized: invalid authorization format' })
|
||||
expect(next).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should return 401 with Bearer without space', async () => {
|
||||
;(req.header as any).mockImplementation((header: string) => {
|
||||
if (header === 'authorization') return 'Bearer'
|
||||
return ''
|
||||
})
|
||||
|
||||
await authMiddleware(req as Request, res as Response, next)
|
||||
|
||||
expect(statusMock).toHaveBeenCalledWith(401)
|
||||
expect(jsonMock).toHaveBeenCalledWith({ error: 'Unauthorized: invalid authorization format' })
|
||||
expect(next).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should handle Bearer token with only trailing spaces (edge case)', async () => {
|
||||
;(req.header as any).mockImplementation((header: string) => {
|
||||
if (header === 'authorization') return 'Bearer ' // This will be trimmed to "Bearer" and fail format check
|
||||
return ''
|
||||
})
|
||||
|
||||
await authMiddleware(req as Request, res as Response, next)
|
||||
|
||||
expect(statusMock).toHaveBeenCalledWith(401)
|
||||
expect(jsonMock).toHaveBeenCalledWith({ error: 'Unauthorized: invalid authorization format' })
|
||||
expect(next).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should handle Bearer token with case insensitive prefix', async () => {
|
||||
;(req.header as any).mockImplementation((header: string) => {
|
||||
if (header === 'authorization') return `bearer ${validApiKey}`
|
||||
return ''
|
||||
})
|
||||
|
||||
await authMiddleware(req as Request, res as Response, next)
|
||||
|
||||
expect(next).toHaveBeenCalled()
|
||||
expect(statusMock).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should handle Bearer token with whitespace', async () => {
|
||||
;(req.header as any).mockImplementation((header: string) => {
|
||||
if (header === 'authorization') return ` Bearer ${validApiKey} `
|
||||
return ''
|
||||
})
|
||||
|
||||
await authMiddleware(req as Request, res as Response, next)
|
||||
|
||||
expect(next).toHaveBeenCalled()
|
||||
expect(statusMock).not.toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
|
||||
describe('Edge cases', () => {
|
||||
const validApiKey = 'valid-api-key-123'
|
||||
|
||||
beforeEach(() => {
|
||||
mockConfig.get.mockResolvedValue({ apiKey: validApiKey })
|
||||
})
|
||||
|
||||
it('should handle config.get() rejection', async () => {
|
||||
;(req.header as any).mockImplementation((header: string) => {
|
||||
if (header === 'x-api-key') return validApiKey
|
||||
return ''
|
||||
})
|
||||
|
||||
mockConfig.get.mockRejectedValue(new Error('Config error'))
|
||||
|
||||
await expect(authMiddleware(req as Request, res as Response, next)).rejects.toThrow('Config error')
|
||||
})
|
||||
|
||||
it('should use timing-safe comparison for different length tokens', async () => {
|
||||
;(req.header as any).mockImplementation((header: string) => {
|
||||
if (header === 'x-api-key') return 'short'
|
||||
return ''
|
||||
})
|
||||
|
||||
await authMiddleware(req as Request, res as Response, next)
|
||||
|
||||
expect(statusMock).toHaveBeenCalledWith(403)
|
||||
expect(jsonMock).toHaveBeenCalledWith({ error: 'Forbidden' })
|
||||
expect(next).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should return 401 when neither credential format is valid', async () => {
|
||||
;(req.header as any).mockImplementation((header: string) => {
|
||||
if (header === 'authorization') return 'Invalid format'
|
||||
return ''
|
||||
})
|
||||
|
||||
await authMiddleware(req as Request, res as Response, next)
|
||||
|
||||
expect(statusMock).toHaveBeenCalledWith(401)
|
||||
expect(jsonMock).toHaveBeenCalledWith({ error: 'Unauthorized: invalid authorization format' })
|
||||
expect(next).not.toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
|
||||
describe('Timing attack protection', () => {
|
||||
const validApiKey = 'valid-api-key-123'
|
||||
|
||||
beforeEach(() => {
|
||||
mockConfig.get.mockResolvedValue({ apiKey: validApiKey })
|
||||
})
|
||||
|
||||
it('should handle similar length but different API keys securely', async () => {
|
||||
const similarKey = 'valid-api-key-124' // Same length, different last char
|
||||
|
||||
;(req.header as any).mockImplementation((header: string) => {
|
||||
if (header === 'x-api-key') return similarKey
|
||||
return ''
|
||||
})
|
||||
|
||||
await authMiddleware(req as Request, res as Response, next)
|
||||
|
||||
expect(statusMock).toHaveBeenCalledWith(403)
|
||||
expect(jsonMock).toHaveBeenCalledWith({ error: 'Forbidden' })
|
||||
expect(next).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should handle similar length but different Bearer tokens securely', async () => {
|
||||
const similarKey = 'valid-api-key-124' // Same length, different last char
|
||||
|
||||
;(req.header as any).mockImplementation((header: string) => {
|
||||
if (header === 'authorization') return `Bearer ${similarKey}`
|
||||
return ''
|
||||
})
|
||||
|
||||
await authMiddleware(req as Request, res as Response, next)
|
||||
|
||||
expect(statusMock).toHaveBeenCalledWith(403)
|
||||
expect(jsonMock).toHaveBeenCalledWith({ error: 'Forbidden' })
|
||||
expect(next).not.toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
})
|
||||
@@ -3,8 +3,17 @@ import type { NextFunction, Request, Response } from 'express'
|
||||
|
||||
import { config } from '../config'
|
||||
|
||||
const isValidToken = (token: string, apiKey: string): boolean => {
|
||||
if (token.length !== apiKey.length) {
|
||||
return false
|
||||
}
|
||||
const tokenBuf = Buffer.from(token)
|
||||
const keyBuf = Buffer.from(apiKey)
|
||||
return crypto.timingSafeEqual(tokenBuf, keyBuf)
|
||||
}
|
||||
|
||||
export const authMiddleware = async (req: Request, res: Response, next: NextFunction) => {
|
||||
const auth = req.header('Authorization') || ''
|
||||
const auth = req.header('authorization') || ''
|
||||
const xApiKey = req.header('x-api-key') || ''
|
||||
|
||||
// Fast rejection if neither credential header provided
|
||||
@@ -12,51 +21,46 @@ export const authMiddleware = async (req: Request, res: Response, next: NextFunc
|
||||
return res.status(401).json({ error: 'Unauthorized: missing credentials' })
|
||||
}
|
||||
|
||||
let token: string | undefined
|
||||
|
||||
// Prefer Bearer if well‑formed
|
||||
if (auth) {
|
||||
const trimmed = auth.trim()
|
||||
const bearerPrefix = /^Bearer\s+/i
|
||||
if (bearerPrefix.test(trimmed)) {
|
||||
const candidate = trimmed.replace(bearerPrefix, '').trim()
|
||||
if (!candidate) {
|
||||
return res.status(401).json({ error: 'Unauthorized: empty bearer token' })
|
||||
}
|
||||
token = candidate
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback to x-api-key if token still not resolved
|
||||
if (!token && xApiKey) {
|
||||
if (!xApiKey.trim()) {
|
||||
return res.status(401).json({ error: 'Unauthorized: empty x-api-key' })
|
||||
}
|
||||
token = xApiKey.trim()
|
||||
}
|
||||
|
||||
if (!token) {
|
||||
// At this point we had at least one header, but none yielded a usable token
|
||||
return res.status(401).json({ error: 'Unauthorized: invalid credentials format' })
|
||||
}
|
||||
|
||||
const { apiKey } = await config.get()
|
||||
|
||||
if (!apiKey) {
|
||||
// If server not configured, treat as forbidden (or could be 500). Choose 403 to avoid leaking config state.
|
||||
return res.status(403).json({ error: 'Forbidden' })
|
||||
}
|
||||
|
||||
// Timing-safe compare when lengths match, else immediate forbidden
|
||||
if (token.length !== apiKey.length) {
|
||||
return res.status(403).json({ error: 'Forbidden' })
|
||||
// Check API key first (priority)
|
||||
if (xApiKey) {
|
||||
const trimmedApiKey = xApiKey.trim()
|
||||
if (!trimmedApiKey) {
|
||||
return res.status(401).json({ error: 'Unauthorized: empty x-api-key' })
|
||||
}
|
||||
|
||||
if (isValidToken(trimmedApiKey, apiKey)) {
|
||||
return next()
|
||||
} else {
|
||||
return res.status(403).json({ error: 'Forbidden' })
|
||||
}
|
||||
}
|
||||
|
||||
const tokenBuf = Buffer.from(token)
|
||||
const keyBuf = Buffer.from(apiKey)
|
||||
if (!crypto.timingSafeEqual(tokenBuf, keyBuf)) {
|
||||
return res.status(403).json({ error: 'Forbidden' })
|
||||
// Fallback to Bearer token
|
||||
if (auth) {
|
||||
const trimmed = auth.trim()
|
||||
const bearerPrefix = /^Bearer\s+/i
|
||||
|
||||
if (!bearerPrefix.test(trimmed)) {
|
||||
return res.status(401).json({ error: 'Unauthorized: invalid authorization format' })
|
||||
}
|
||||
|
||||
const token = trimmed.replace(bearerPrefix, '').trim()
|
||||
if (!token) {
|
||||
return res.status(401).json({ error: 'Unauthorized: empty bearer token' })
|
||||
}
|
||||
|
||||
if (isValidToken(token, apiKey)) {
|
||||
return next()
|
||||
} else {
|
||||
return res.status(403).json({ error: 'Forbidden' })
|
||||
}
|
||||
}
|
||||
|
||||
return next()
|
||||
return res.status(401).json({ error: 'Unauthorized: invalid credentials format' })
|
||||
}
|
||||
|
||||
@@ -6,7 +6,7 @@ const logger = loggerService.withContext('ApiServerErrorHandler')
|
||||
|
||||
// oxlint-disable-next-line @typescript-eslint/no-unused-vars
|
||||
export const errorHandler = (err: Error, _req: Request, res: Response, _next: NextFunction) => {
|
||||
logger.error('API Server Error:', err)
|
||||
logger.error('API server error', { error: err })
|
||||
|
||||
// Don't expose internal errors in production
|
||||
const isDev = process.env.NODE_ENV === 'development'
|
||||
|
||||
@@ -197,10 +197,11 @@ export function setupOpenAPIDocumentation(app: Express) {
|
||||
})
|
||||
)
|
||||
|
||||
logger.info('OpenAPI documentation setup complete')
|
||||
logger.info('Documentation available at /api-docs')
|
||||
logger.info('OpenAPI spec available at /api-docs.json')
|
||||
logger.info('OpenAPI documentation ready', {
|
||||
docsPath: '/api-docs',
|
||||
specPath: '/api-docs.json'
|
||||
})
|
||||
} catch (error) {
|
||||
logger.error('Failed to setup OpenAPI documentation:', error as Error)
|
||||
logger.error('Failed to setup OpenAPI documentation', { error })
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,567 @@
|
||||
import { loggerService } from '@logger'
|
||||
import { AgentModelValidationError, agentService, sessionService } from '@main/services/agents'
|
||||
import { ListAgentsResponse, type ReplaceAgentRequest, type UpdateAgentRequest } from '@types'
|
||||
import { Request, Response } from 'express'
|
||||
|
||||
import type { ValidationRequest } from '../validators/zodValidator'
|
||||
|
||||
const logger = loggerService.withContext('ApiServerAgentsHandlers')
|
||||
|
||||
const modelValidationErrorBody = (error: AgentModelValidationError) => ({
|
||||
error: {
|
||||
message: `Invalid ${error.context.field}: ${error.detail.message}`,
|
||||
type: 'invalid_request_error',
|
||||
code: error.detail.code
|
||||
}
|
||||
})
|
||||
|
||||
/**
|
||||
* @swagger
|
||||
* /v1/agents:
|
||||
* post:
|
||||
* summary: Create a new agent
|
||||
* description: Creates a new autonomous agent with the specified configuration and automatically
|
||||
* provisions an initial session that mirrors the agent's settings.
|
||||
* tags: [Agents]
|
||||
* requestBody:
|
||||
* required: true
|
||||
* content:
|
||||
* application/json:
|
||||
* schema:
|
||||
* $ref: '#/components/schemas/CreateAgentRequest'
|
||||
* responses:
|
||||
* 201:
|
||||
* description: Agent created successfully
|
||||
* content:
|
||||
* application/json:
|
||||
* schema:
|
||||
* $ref: '#/components/schemas/AgentEntity'
|
||||
* 400:
|
||||
* description: Validation error
|
||||
* content:
|
||||
* application/json:
|
||||
* schema:
|
||||
* $ref: '#/components/schemas/Error'
|
||||
* 500:
|
||||
* description: Internal server error
|
||||
* content:
|
||||
* application/json:
|
||||
* schema:
|
||||
* $ref: '#/components/schemas/Error'
|
||||
*/
|
||||
export const createAgent = async (req: Request, res: Response): Promise<Response> => {
|
||||
try {
|
||||
logger.debug('Creating agent')
|
||||
logger.debug('Agent payload', { body: req.body })
|
||||
|
||||
const agent = await agentService.createAgent(req.body)
|
||||
|
||||
try {
|
||||
logger.info('Agent created', { agentId: agent.id })
|
||||
logger.debug('Creating default session for agent', { agentId: agent.id })
|
||||
|
||||
await sessionService.createSession(agent.id, {})
|
||||
|
||||
logger.info('Default session created for agent', { agentId: agent.id })
|
||||
return res.status(201).json(agent)
|
||||
} catch (sessionError: any) {
|
||||
logger.error('Failed to create default session for new agent, rolling back agent creation', {
|
||||
agentId: agent.id,
|
||||
error: sessionError
|
||||
})
|
||||
|
||||
try {
|
||||
await agentService.deleteAgent(agent.id)
|
||||
} catch (rollbackError: any) {
|
||||
logger.error('Failed to roll back agent after session creation failure', {
|
||||
agentId: agent.id,
|
||||
error: rollbackError
|
||||
})
|
||||
}
|
||||
|
||||
return res.status(500).json({
|
||||
error: {
|
||||
message: `Failed to create default session for agent: ${sessionError.message}`,
|
||||
type: 'internal_error',
|
||||
code: 'agent_session_creation_failed'
|
||||
}
|
||||
})
|
||||
}
|
||||
} catch (error: any) {
|
||||
if (error instanceof AgentModelValidationError) {
|
||||
logger.warn('Agent model validation error during create', {
|
||||
agentType: error.context.agentType,
|
||||
field: error.context.field,
|
||||
model: error.context.model,
|
||||
detail: error.detail
|
||||
})
|
||||
return res.status(400).json(modelValidationErrorBody(error))
|
||||
}
|
||||
|
||||
logger.error('Error creating agent', { error })
|
||||
return res.status(500).json({
|
||||
error: {
|
||||
message: `Failed to create agent: ${error.message}`,
|
||||
type: 'internal_error',
|
||||
code: 'agent_creation_failed'
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @swagger
|
||||
* /v1/agents:
|
||||
* get:
|
||||
* summary: List all agents
|
||||
* description: Retrieves a paginated list of all agents
|
||||
* tags: [Agents]
|
||||
* parameters:
|
||||
* - in: query
|
||||
* name: limit
|
||||
* schema:
|
||||
* type: integer
|
||||
* minimum: 1
|
||||
* maximum: 100
|
||||
* default: 20
|
||||
* description: Number of agents to return
|
||||
* - in: query
|
||||
* name: offset
|
||||
* schema:
|
||||
* type: integer
|
||||
* minimum: 0
|
||||
* default: 0
|
||||
* description: Number of agents to skip
|
||||
* responses:
|
||||
* 200:
|
||||
* description: List of agents
|
||||
* content:
|
||||
* application/json:
|
||||
* schema:
|
||||
* type: object
|
||||
* properties:
|
||||
* data:
|
||||
* type: array
|
||||
* items:
|
||||
* $ref: '#/components/schemas/AgentEntity'
|
||||
* total:
|
||||
* type: integer
|
||||
* description: Total number of agents
|
||||
* limit:
|
||||
* type: integer
|
||||
* description: Number of agents returned
|
||||
* offset:
|
||||
* type: integer
|
||||
* description: Number of agents skipped
|
||||
* 400:
|
||||
* description: Validation error
|
||||
* content:
|
||||
* application/json:
|
||||
* schema:
|
||||
* $ref: '#/components/schemas/Error'
|
||||
* 500:
|
||||
* description: Internal server error
|
||||
* content:
|
||||
* application/json:
|
||||
* schema:
|
||||
* $ref: '#/components/schemas/Error'
|
||||
*/
|
||||
export const listAgents = async (req: Request, res: Response): Promise<Response> => {
|
||||
try {
|
||||
const limit = req.query.limit ? parseInt(req.query.limit as string) : 20
|
||||
const offset = req.query.offset ? parseInt(req.query.offset as string) : 0
|
||||
|
||||
logger.debug('Listing agents', { limit, offset })
|
||||
|
||||
const result = await agentService.listAgents({ limit, offset })
|
||||
|
||||
logger.info('Agents listed', {
|
||||
returned: result.agents.length,
|
||||
total: result.total,
|
||||
limit,
|
||||
offset
|
||||
})
|
||||
return res.json({
|
||||
data: result.agents,
|
||||
total: result.total,
|
||||
limit,
|
||||
offset
|
||||
} satisfies ListAgentsResponse)
|
||||
} catch (error: any) {
|
||||
logger.error('Error listing agents', { error })
|
||||
return res.status(500).json({
|
||||
error: {
|
||||
message: 'Failed to list agents',
|
||||
type: 'internal_error',
|
||||
code: 'agent_list_failed'
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @swagger
|
||||
* /v1/agents/{agentId}:
|
||||
* get:
|
||||
* summary: Get agent by ID
|
||||
* description: Retrieves a specific agent by its ID
|
||||
* tags: [Agents]
|
||||
* parameters:
|
||||
* - in: path
|
||||
* name: agentId
|
||||
* required: true
|
||||
* schema:
|
||||
* type: string
|
||||
* description: Agent ID
|
||||
* responses:
|
||||
* 200:
|
||||
* description: Agent details
|
||||
* content:
|
||||
* application/json:
|
||||
* schema:
|
||||
* $ref: '#/components/schemas/AgentEntity'
|
||||
* 404:
|
||||
* description: Agent not found
|
||||
* content:
|
||||
* application/json:
|
||||
* schema:
|
||||
* $ref: '#/components/schemas/Error'
|
||||
* 500:
|
||||
* description: Internal server error
|
||||
* content:
|
||||
* application/json:
|
||||
* schema:
|
||||
* $ref: '#/components/schemas/Error'
|
||||
*/
|
||||
export const getAgent = async (req: Request, res: Response): Promise<Response> => {
|
||||
try {
|
||||
const { agentId } = req.params
|
||||
logger.debug('Getting agent', { agentId })
|
||||
|
||||
const agent = await agentService.getAgent(agentId)
|
||||
|
||||
if (!agent) {
|
||||
logger.warn('Agent not found', { agentId })
|
||||
return res.status(404).json({
|
||||
error: {
|
||||
message: 'Agent not found',
|
||||
type: 'not_found',
|
||||
code: 'agent_not_found'
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
logger.info('Agent retrieved', { agentId })
|
||||
return res.json(agent)
|
||||
} catch (error: any) {
|
||||
logger.error('Error getting agent', { error, agentId: req.params.agentId })
|
||||
return res.status(500).json({
|
||||
error: {
|
||||
message: 'Failed to get agent',
|
||||
type: 'internal_error',
|
||||
code: 'agent_get_failed'
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @swagger
|
||||
* /v1/agents/{agentId}:
|
||||
* put:
|
||||
* summary: Update agent
|
||||
* description: Updates an existing agent with the provided data
|
||||
* tags: [Agents]
|
||||
* parameters:
|
||||
* - in: path
|
||||
* name: agentId
|
||||
* required: true
|
||||
* schema:
|
||||
* type: string
|
||||
* description: Agent ID
|
||||
* requestBody:
|
||||
* required: true
|
||||
* content:
|
||||
* application/json:
|
||||
* schema:
|
||||
* $ref: '#/components/schemas/CreateAgentRequest'
|
||||
* responses:
|
||||
* 200:
|
||||
* description: Agent updated successfully
|
||||
* content:
|
||||
* application/json:
|
||||
* schema:
|
||||
* $ref: '#/components/schemas/AgentEntity'
|
||||
* 400:
|
||||
* description: Validation error
|
||||
* content:
|
||||
* application/json:
|
||||
* schema:
|
||||
* $ref: '#/components/schemas/Error'
|
||||
* 404:
|
||||
* description: Agent not found
|
||||
* content:
|
||||
* application/json:
|
||||
* schema:
|
||||
* $ref: '#/components/schemas/Error'
|
||||
* 500:
|
||||
* description: Internal server error
|
||||
* content:
|
||||
* application/json:
|
||||
* schema:
|
||||
* $ref: '#/components/schemas/Error'
|
||||
*/
|
||||
export const updateAgent = async (req: Request, res: Response): Promise<Response> => {
|
||||
const { agentId } = req.params
|
||||
try {
|
||||
logger.debug('Updating agent', { agentId })
|
||||
logger.debug('Replace payload', { body: req.body })
|
||||
|
||||
const { validatedBody } = req as ValidationRequest
|
||||
const replacePayload = (validatedBody ?? {}) as ReplaceAgentRequest
|
||||
|
||||
const agent = await agentService.updateAgent(agentId, replacePayload, { replace: true })
|
||||
|
||||
if (!agent) {
|
||||
logger.warn('Agent not found for update', { agentId })
|
||||
return res.status(404).json({
|
||||
error: {
|
||||
message: 'Agent not found',
|
||||
type: 'not_found',
|
||||
code: 'agent_not_found'
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
logger.info('Agent updated', { agentId })
|
||||
return res.json(agent)
|
||||
} catch (error: any) {
|
||||
if (error instanceof AgentModelValidationError) {
|
||||
logger.warn('Agent model validation error during update', {
|
||||
agentId,
|
||||
agentType: error.context.agentType,
|
||||
field: error.context.field,
|
||||
model: error.context.model,
|
||||
detail: error.detail
|
||||
})
|
||||
return res.status(400).json(modelValidationErrorBody(error))
|
||||
}
|
||||
|
||||
logger.error('Error updating agent', { error, agentId })
|
||||
return res.status(500).json({
|
||||
error: {
|
||||
message: 'Failed to update agent: ' + error.message,
|
||||
type: 'internal_error',
|
||||
code: 'agent_update_failed'
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @swagger
|
||||
* /v1/agents/{agentId}:
|
||||
* patch:
|
||||
* summary: Partially update agent
|
||||
* description: Partially updates an existing agent with only the provided fields
|
||||
* tags: [Agents]
|
||||
* parameters:
|
||||
* - in: path
|
||||
* name: agentId
|
||||
* required: true
|
||||
* schema:
|
||||
* type: string
|
||||
* description: Agent ID
|
||||
* requestBody:
|
||||
* required: true
|
||||
* content:
|
||||
* application/json:
|
||||
* schema:
|
||||
* type: object
|
||||
* properties:
|
||||
* name:
|
||||
* type: string
|
||||
* description: Agent name
|
||||
* description:
|
||||
* type: string
|
||||
* description: Agent description
|
||||
* avatar:
|
||||
* type: string
|
||||
* description: Agent avatar URL
|
||||
* instructions:
|
||||
* type: string
|
||||
* description: System prompt/instructions
|
||||
* model:
|
||||
* type: string
|
||||
* description: Main model ID
|
||||
* plan_model:
|
||||
* type: string
|
||||
* description: Optional planning model ID
|
||||
* small_model:
|
||||
* type: string
|
||||
* description: Optional small/fast model ID
|
||||
* tools:
|
||||
* type: array
|
||||
* items:
|
||||
* type: string
|
||||
* description: Tools
|
||||
* mcps:
|
||||
* type: array
|
||||
* items:
|
||||
* type: string
|
||||
* description: MCP tool IDs
|
||||
* knowledges:
|
||||
* type: array
|
||||
* items:
|
||||
* type: string
|
||||
* description: Knowledge base IDs
|
||||
* configuration:
|
||||
* type: object
|
||||
* description: Extensible settings
|
||||
* accessible_paths:
|
||||
* type: array
|
||||
* items:
|
||||
* type: string
|
||||
* description: Accessible directory paths
|
||||
* permission_mode:
|
||||
* type: string
|
||||
* enum: [readOnly, acceptEdits, bypassPermissions]
|
||||
* description: Permission mode
|
||||
* max_steps:
|
||||
* type: integer
|
||||
* description: Maximum steps the agent can take
|
||||
* description: Only include the fields you want to update
|
||||
* responses:
|
||||
* 200:
|
||||
* description: Agent updated successfully
|
||||
* content:
|
||||
* application/json:
|
||||
* schema:
|
||||
* $ref: '#/components/schemas/AgentEntity'
|
||||
* 400:
|
||||
* description: Validation error
|
||||
* content:
|
||||
* application/json:
|
||||
* schema:
|
||||
* $ref: '#/components/schemas/Error'
|
||||
* 404:
|
||||
* description: Agent not found
|
||||
* content:
|
||||
* application/json:
|
||||
* schema:
|
||||
* $ref: '#/components/schemas/Error'
|
||||
* 500:
|
||||
* description: Internal server error
|
||||
* content:
|
||||
* application/json:
|
||||
* schema:
|
||||
* $ref: '#/components/schemas/Error'
|
||||
*/
|
||||
export const patchAgent = async (req: Request, res: Response): Promise<Response> => {
|
||||
const { agentId } = req.params
|
||||
try {
|
||||
logger.debug('Partially updating agent', { agentId })
|
||||
logger.debug('Patch payload', { body: req.body })
|
||||
|
||||
const { validatedBody } = req as ValidationRequest
|
||||
const updatePayload = (validatedBody ?? {}) as UpdateAgentRequest
|
||||
|
||||
const agent = await agentService.updateAgent(agentId, updatePayload)
|
||||
|
||||
if (!agent) {
|
||||
logger.warn('Agent not found for partial update', { agentId })
|
||||
return res.status(404).json({
|
||||
error: {
|
||||
message: 'Agent not found',
|
||||
type: 'not_found',
|
||||
code: 'agent_not_found'
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
logger.info('Agent patched', { agentId })
|
||||
return res.json(agent)
|
||||
} catch (error: any) {
|
||||
if (error instanceof AgentModelValidationError) {
|
||||
logger.warn('Agent model validation error during partial update', {
|
||||
agentId,
|
||||
agentType: error.context.agentType,
|
||||
field: error.context.field,
|
||||
model: error.context.model,
|
||||
detail: error.detail
|
||||
})
|
||||
return res.status(400).json(modelValidationErrorBody(error))
|
||||
}
|
||||
|
||||
logger.error('Error partially updating agent', { error, agentId })
|
||||
return res.status(500).json({
|
||||
error: {
|
||||
message: `Failed to partially update agent: ${error.message}`,
|
||||
type: 'internal_error',
|
||||
code: 'agent_patch_failed'
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @swagger
|
||||
* /v1/agents/{agentId}:
|
||||
* delete:
|
||||
* summary: Delete agent
|
||||
* description: Deletes an agent and all associated sessions and logs
|
||||
* tags: [Agents]
|
||||
* parameters:
|
||||
* - in: path
|
||||
* name: agentId
|
||||
* required: true
|
||||
* schema:
|
||||
* type: string
|
||||
* description: Agent ID
|
||||
* responses:
|
||||
* 204:
|
||||
* description: Agent deleted successfully
|
||||
* 404:
|
||||
* description: Agent not found
|
||||
* content:
|
||||
* application/json:
|
||||
* schema:
|
||||
* $ref: '#/components/schemas/Error'
|
||||
* 500:
|
||||
* description: Internal server error
|
||||
* content:
|
||||
* application/json:
|
||||
* schema:
|
||||
* $ref: '#/components/schemas/Error'
|
||||
*/
|
||||
export const deleteAgent = async (req: Request, res: Response): Promise<Response> => {
|
||||
try {
|
||||
const { agentId } = req.params
|
||||
logger.debug('Deleting agent', { agentId })
|
||||
|
||||
const deleted = await agentService.deleteAgent(agentId)
|
||||
|
||||
if (!deleted) {
|
||||
logger.warn('Agent not found for deletion', { agentId })
|
||||
return res.status(404).json({
|
||||
error: {
|
||||
message: 'Agent not found',
|
||||
type: 'not_found',
|
||||
code: 'agent_not_found'
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
logger.info('Agent deleted', { agentId })
|
||||
return res.status(204).send()
|
||||
} catch (error: any) {
|
||||
logger.error('Error deleting agent', { error, agentId: req.params.agentId })
|
||||
return res.status(500).json({
|
||||
error: {
|
||||
message: 'Failed to delete agent',
|
||||
type: 'internal_error',
|
||||
code: 'agent_delete_failed'
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,3 @@
|
||||
export * as agentHandlers from './agents'
|
||||
export * as messageHandlers from './messages'
|
||||
export * as sessionHandlers from './sessions'
|
||||
@@ -0,0 +1,317 @@
|
||||
import { loggerService } from '@logger'
|
||||
import { MESSAGE_STREAM_TIMEOUT_MS } from '@main/apiServer/config/timeouts'
|
||||
import { createStreamAbortController, STREAM_TIMEOUT_REASON } from '@main/apiServer/utils/createStreamAbortController'
|
||||
import { agentService, sessionMessageService, sessionService } from '@main/services/agents'
|
||||
import { Request, Response } from 'express'
|
||||
|
||||
const logger = loggerService.withContext('ApiServerMessagesHandlers')
|
||||
|
||||
// Helper function to verify agent and session exist and belong together
|
||||
const verifyAgentAndSession = async (agentId: string, sessionId: string) => {
|
||||
const agentExists = await agentService.agentExists(agentId)
|
||||
if (!agentExists) {
|
||||
throw { status: 404, code: 'agent_not_found', message: 'Agent not found' }
|
||||
}
|
||||
|
||||
const session = await sessionService.getSession(agentId, sessionId)
|
||||
if (!session) {
|
||||
throw { status: 404, code: 'session_not_found', message: 'Session not found' }
|
||||
}
|
||||
|
||||
if (session.agent_id !== agentId) {
|
||||
throw { status: 404, code: 'session_not_found', message: 'Session not found for this agent' }
|
||||
}
|
||||
|
||||
return session
|
||||
}
|
||||
|
||||
export const createMessage = async (req: Request, res: Response): Promise<void> => {
|
||||
let clearAbortTimeout: (() => void) | undefined
|
||||
|
||||
try {
|
||||
const { agentId, sessionId } = req.params
|
||||
|
||||
const session = await verifyAgentAndSession(agentId, sessionId)
|
||||
|
||||
const messageData = req.body
|
||||
|
||||
logger.info('Creating streaming message', { agentId, sessionId })
|
||||
logger.debug('Streaming message payload', { messageData })
|
||||
|
||||
// Set SSE headers
|
||||
res.setHeader('Content-Type', 'text/event-stream')
|
||||
res.setHeader('Cache-Control', 'no-cache')
|
||||
res.setHeader('Connection', 'keep-alive')
|
||||
res.setHeader('Access-Control-Allow-Origin', '*')
|
||||
res.setHeader('Access-Control-Allow-Headers', 'Cache-Control')
|
||||
|
||||
const {
|
||||
abortController,
|
||||
registerAbortHandler,
|
||||
clearAbortTimeout: helperClearAbortTimeout
|
||||
} = createStreamAbortController({
|
||||
timeoutMs: MESSAGE_STREAM_TIMEOUT_MS
|
||||
})
|
||||
clearAbortTimeout = helperClearAbortTimeout
|
||||
const { stream, completion } = await sessionMessageService.createSessionMessage(
|
||||
session,
|
||||
messageData,
|
||||
abortController
|
||||
)
|
||||
const reader = stream.getReader()
|
||||
|
||||
// Track stream lifecycle so we keep the SSE connection open until persistence finishes
|
||||
let responseEnded = false
|
||||
let streamFinished = false
|
||||
|
||||
const cleanupAbortTimeout = () => {
|
||||
clearAbortTimeout?.()
|
||||
}
|
||||
|
||||
const finalizeResponse = () => {
|
||||
if (responseEnded) {
|
||||
return
|
||||
}
|
||||
|
||||
if (!streamFinished) {
|
||||
return
|
||||
}
|
||||
|
||||
responseEnded = true
|
||||
cleanupAbortTimeout()
|
||||
try {
|
||||
// res.write('data: {"type":"finish"}\n\n')
|
||||
res.write('data: [DONE]\n\n')
|
||||
} catch (writeError) {
|
||||
logger.error('Error writing final sentinel to SSE stream', { error: writeError as Error })
|
||||
}
|
||||
res.end()
|
||||
}
|
||||
|
||||
/**
|
||||
* Client Disconnect Detection for Server-Sent Events (SSE)
|
||||
*
|
||||
* We monitor multiple HTTP events to reliably detect when a client disconnects
|
||||
* from the streaming response. This is crucial for:
|
||||
* - Aborting long-running Claude Code processes
|
||||
* - Cleaning up resources and preventing memory leaks
|
||||
* - Avoiding orphaned processes
|
||||
*
|
||||
* Event Priority & Behavior:
|
||||
* 1. res.on('close') - Most common for SSE client disconnects (browser tab close, curl Ctrl+C)
|
||||
* 2. req.on('aborted') - Explicit request abortion
|
||||
* 3. req.on('close') - Request object closure (less common with SSE)
|
||||
*
|
||||
* When any disconnect event fires, we:
|
||||
* - Abort the Claude Code SDK process via abortController
|
||||
* - Clean up event listeners to prevent memory leaks
|
||||
* - Mark the response as ended to prevent further writes
|
||||
*/
|
||||
registerAbortHandler((abortReason) => {
|
||||
cleanupAbortTimeout()
|
||||
|
||||
if (responseEnded) return
|
||||
|
||||
responseEnded = true
|
||||
|
||||
if (abortReason === STREAM_TIMEOUT_REASON) {
|
||||
logger.error('Streaming message timeout', { agentId, sessionId })
|
||||
try {
|
||||
res.write(
|
||||
`data: ${JSON.stringify({
|
||||
type: 'error',
|
||||
error: {
|
||||
message: 'Stream timeout',
|
||||
type: 'timeout_error',
|
||||
code: 'stream_timeout'
|
||||
}
|
||||
})}\n\n`
|
||||
)
|
||||
} catch (writeError) {
|
||||
logger.error('Error writing timeout to SSE stream', { error: writeError })
|
||||
}
|
||||
} else if (abortReason === 'Client disconnected') {
|
||||
logger.info('Streaming client disconnected', { agentId, sessionId })
|
||||
} else {
|
||||
logger.warn('Streaming aborted', { agentId, sessionId, reason: abortReason })
|
||||
}
|
||||
|
||||
reader.cancel(abortReason ?? 'stream aborted').catch(() => {})
|
||||
|
||||
if (!res.headersSent) {
|
||||
res.setHeader('Content-Type', 'text/event-stream')
|
||||
res.setHeader('Cache-Control', 'no-cache')
|
||||
res.setHeader('Connection', 'keep-alive')
|
||||
}
|
||||
|
||||
if (!res.writableEnded) {
|
||||
res.end()
|
||||
}
|
||||
})
|
||||
|
||||
const handleDisconnect = () => {
|
||||
if (abortController.signal.aborted) return
|
||||
abortController.abort('Client disconnected')
|
||||
}
|
||||
|
||||
req.on('close', handleDisconnect)
|
||||
req.on('aborted', handleDisconnect)
|
||||
res.on('close', handleDisconnect)
|
||||
|
||||
const pumpStream = async () => {
|
||||
try {
|
||||
while (!responseEnded) {
|
||||
const { done, value } = await reader.read()
|
||||
if (done) {
|
||||
break
|
||||
}
|
||||
|
||||
res.write(`data: ${JSON.stringify(value)}\n\n`)
|
||||
}
|
||||
|
||||
streamFinished = true
|
||||
finalizeResponse()
|
||||
} catch (error) {
|
||||
if (responseEnded) return
|
||||
logger.error('Error reading agent stream', { error })
|
||||
try {
|
||||
res.write(
|
||||
`data: ${JSON.stringify({
|
||||
type: 'error',
|
||||
error: {
|
||||
message: (error as Error).message || 'Stream processing error',
|
||||
type: 'stream_error',
|
||||
code: 'stream_processing_failed'
|
||||
}
|
||||
})}\n\n`
|
||||
)
|
||||
} catch (writeError) {
|
||||
logger.error('Error writing stream error to SSE', { error: writeError })
|
||||
}
|
||||
responseEnded = true
|
||||
cleanupAbortTimeout()
|
||||
res.end()
|
||||
}
|
||||
}
|
||||
|
||||
pumpStream().catch((error) => {
|
||||
logger.error('Pump stream failure', { error })
|
||||
})
|
||||
|
||||
completion
|
||||
.then(() => {
|
||||
streamFinished = true
|
||||
finalizeResponse()
|
||||
})
|
||||
.catch((error) => {
|
||||
if (responseEnded) return
|
||||
logger.error('Streaming message error', { agentId, sessionId, error })
|
||||
try {
|
||||
res.write(
|
||||
`data: ${JSON.stringify({
|
||||
type: 'error',
|
||||
error: {
|
||||
message: (error as { message?: string })?.message || 'Stream processing error',
|
||||
type: 'stream_error',
|
||||
code: 'stream_processing_failed'
|
||||
}
|
||||
})}\n\n`
|
||||
)
|
||||
} catch (writeError) {
|
||||
logger.error('Error writing completion error to SSE stream', { error: writeError })
|
||||
}
|
||||
responseEnded = true
|
||||
cleanupAbortTimeout()
|
||||
res.end()
|
||||
})
|
||||
// Clear timeout when response ends
|
||||
res.on('close', cleanupAbortTimeout)
|
||||
res.on('finish', cleanupAbortTimeout)
|
||||
} catch (error: any) {
|
||||
clearAbortTimeout?.()
|
||||
logger.error('Error in streaming message handler', {
|
||||
error,
|
||||
agentId: req.params.agentId,
|
||||
sessionId: req.params.sessionId
|
||||
})
|
||||
|
||||
// Send error as SSE if possible
|
||||
if (!res.headersSent) {
|
||||
res.setHeader('Content-Type', 'text/event-stream')
|
||||
res.setHeader('Cache-Control', 'no-cache')
|
||||
res.setHeader('Connection', 'keep-alive')
|
||||
}
|
||||
|
||||
try {
|
||||
const errorResponse = {
|
||||
type: 'error',
|
||||
error: {
|
||||
message: error.status ? error.message : 'Failed to create streaming message',
|
||||
type: error.status ? 'not_found' : 'internal_error',
|
||||
code: error.status ? error.code : 'stream_creation_failed'
|
||||
}
|
||||
}
|
||||
|
||||
res.write(`data: ${JSON.stringify(errorResponse)}\n\n`)
|
||||
} catch (writeError) {
|
||||
logger.error('Error writing initial error to SSE stream', { error: writeError })
|
||||
}
|
||||
|
||||
res.end()
|
||||
}
|
||||
}
|
||||
|
||||
export const deleteMessage = async (req: Request, res: Response): Promise<Response> => {
|
||||
try {
|
||||
const { agentId, sessionId, messageId: messageIdParam } = req.params
|
||||
const messageId = Number(messageIdParam)
|
||||
|
||||
await verifyAgentAndSession(agentId, sessionId)
|
||||
|
||||
const deleted = await sessionMessageService.deleteSessionMessage(sessionId, messageId)
|
||||
|
||||
if (!deleted) {
|
||||
logger.warn('Session message not found', { agentId, sessionId, messageId })
|
||||
return res.status(404).json({
|
||||
error: {
|
||||
message: 'Message not found for this session',
|
||||
type: 'not_found',
|
||||
code: 'session_message_not_found'
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
logger.info('Session message deleted', { agentId, sessionId, messageId })
|
||||
return res.status(204).send()
|
||||
} catch (error: any) {
|
||||
if (error?.status === 404) {
|
||||
logger.warn('Delete message failed - missing resource', {
|
||||
agentId: req.params.agentId,
|
||||
sessionId: req.params.sessionId,
|
||||
messageId: req.params.messageId,
|
||||
error
|
||||
})
|
||||
return res.status(404).json({
|
||||
error: {
|
||||
message: error.message,
|
||||
type: 'not_found',
|
||||
code: error.code ?? 'session_message_not_found'
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
logger.error('Error deleting session message', {
|
||||
error,
|
||||
agentId: req.params.agentId,
|
||||
sessionId: req.params.sessionId,
|
||||
messageId: Number(req.params.messageId)
|
||||
})
|
||||
return res.status(500).json({
|
||||
error: {
|
||||
message: 'Failed to delete session message',
|
||||
type: 'internal_error',
|
||||
code: 'session_message_delete_failed'
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,366 @@
|
||||
import { loggerService } from '@logger'
|
||||
import { AgentModelValidationError, sessionMessageService, sessionService } from '@main/services/agents'
|
||||
import { ListAgentSessionsResponse, type ReplaceSessionRequest, UpdateSessionResponse } from '@types'
|
||||
import { Request, Response } from 'express'
|
||||
|
||||
import type { ValidationRequest } from '../validators/zodValidator'
|
||||
|
||||
const logger = loggerService.withContext('ApiServerSessionsHandlers')
|
||||
|
||||
const modelValidationErrorBody = (error: AgentModelValidationError) => ({
|
||||
error: {
|
||||
message: `Invalid ${error.context.field}: ${error.detail.message}`,
|
||||
type: 'invalid_request_error',
|
||||
code: error.detail.code
|
||||
}
|
||||
})
|
||||
|
||||
export const createSession = async (req: Request, res: Response): Promise<Response> => {
|
||||
const { agentId } = req.params
|
||||
try {
|
||||
const sessionData = req.body
|
||||
|
||||
logger.debug('Creating new session', { agentId })
|
||||
logger.debug('Session payload', { sessionData })
|
||||
|
||||
const session = await sessionService.createSession(agentId, sessionData)
|
||||
|
||||
logger.info('Session created', { agentId, sessionId: session?.id })
|
||||
return res.status(201).json(session)
|
||||
} catch (error: any) {
|
||||
if (error instanceof AgentModelValidationError) {
|
||||
logger.warn('Session model validation error during create', {
|
||||
agentId,
|
||||
agentType: error.context.agentType,
|
||||
field: error.context.field,
|
||||
model: error.context.model,
|
||||
detail: error.detail
|
||||
})
|
||||
return res.status(400).json(modelValidationErrorBody(error))
|
||||
}
|
||||
|
||||
logger.error('Error creating session', { error, agentId })
|
||||
return res.status(500).json({
|
||||
error: {
|
||||
message: `Failed to create session: ${error.message}`,
|
||||
type: 'internal_error',
|
||||
code: 'session_creation_failed'
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
export const listSessions = async (req: Request, res: Response): Promise<Response> => {
|
||||
const { agentId } = req.params
|
||||
try {
|
||||
const limit = req.query.limit ? parseInt(req.query.limit as string) : 20
|
||||
const offset = req.query.offset ? parseInt(req.query.offset as string) : 0
|
||||
const status = req.query.status as any
|
||||
|
||||
logger.debug('Listing agent sessions', { agentId, limit, offset, status })
|
||||
|
||||
const result = await sessionService.listSessions(agentId, { limit, offset })
|
||||
|
||||
logger.info('Agent sessions listed', {
|
||||
agentId,
|
||||
returned: result.sessions.length,
|
||||
total: result.total,
|
||||
limit,
|
||||
offset
|
||||
})
|
||||
return res.json({
|
||||
data: result.sessions,
|
||||
total: result.total,
|
||||
limit,
|
||||
offset
|
||||
})
|
||||
} catch (error: any) {
|
||||
logger.error('Error listing sessions', { error, agentId })
|
||||
return res.status(500).json({
|
||||
error: {
|
||||
message: 'Failed to list sessions',
|
||||
type: 'internal_error',
|
||||
code: 'session_list_failed'
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
export const getSession = async (req: Request, res: Response): Promise<Response> => {
|
||||
try {
|
||||
const { agentId, sessionId } = req.params
|
||||
logger.debug('Getting session', { agentId, sessionId })
|
||||
|
||||
const session = await sessionService.getSession(agentId, sessionId)
|
||||
|
||||
if (!session) {
|
||||
logger.warn('Session not found', { agentId, sessionId })
|
||||
return res.status(404).json({
|
||||
error: {
|
||||
message: 'Session not found',
|
||||
type: 'not_found',
|
||||
code: 'session_not_found'
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// // Verify session belongs to the agent
|
||||
// logger.warn(`Session ${sessionId} does not belong to agent ${agentId}`)
|
||||
// return res.status(404).json({
|
||||
// error: {
|
||||
// message: 'Session not found for this agent',
|
||||
// type: 'not_found',
|
||||
// code: 'session_not_found'
|
||||
// }
|
||||
// })
|
||||
// }
|
||||
|
||||
// Fetch session messages
|
||||
logger.debug('Fetching session messages', { sessionId })
|
||||
const { messages } = await sessionMessageService.listSessionMessages(sessionId)
|
||||
|
||||
// Add messages to session
|
||||
const sessionWithMessages = {
|
||||
...session,
|
||||
messages: messages
|
||||
}
|
||||
|
||||
logger.info('Session retrieved', { agentId, sessionId, messageCount: messages.length })
|
||||
return res.json(sessionWithMessages)
|
||||
} catch (error: any) {
|
||||
logger.error('Error getting session', { error, agentId: req.params.agentId, sessionId: req.params.sessionId })
|
||||
return res.status(500).json({
|
||||
error: {
|
||||
message: 'Failed to get session',
|
||||
type: 'internal_error',
|
||||
code: 'session_get_failed'
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
export const updateSession = async (req: Request, res: Response): Promise<Response> => {
|
||||
const { agentId, sessionId } = req.params
|
||||
try {
|
||||
logger.debug('Updating session', { agentId, sessionId })
|
||||
logger.debug('Replace payload', { body: req.body })
|
||||
|
||||
// First check if session exists and belongs to agent
|
||||
const existingSession = await sessionService.getSession(agentId, sessionId)
|
||||
if (!existingSession || existingSession.agent_id !== agentId) {
|
||||
logger.warn('Session not found for update', { agentId, sessionId })
|
||||
return res.status(404).json({
|
||||
error: {
|
||||
message: 'Session not found for this agent',
|
||||
type: 'not_found',
|
||||
code: 'session_not_found'
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
const { validatedBody } = req as ValidationRequest
|
||||
const replacePayload = (validatedBody ?? {}) as ReplaceSessionRequest
|
||||
|
||||
const session = await sessionService.updateSession(agentId, sessionId, replacePayload)
|
||||
|
||||
if (!session) {
|
||||
logger.warn('Session missing during update', { agentId, sessionId })
|
||||
return res.status(404).json({
|
||||
error: {
|
||||
message: 'Session not found',
|
||||
type: 'not_found',
|
||||
code: 'session_not_found'
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
logger.info('Session updated', { agentId, sessionId })
|
||||
return res.json(session satisfies UpdateSessionResponse)
|
||||
} catch (error: any) {
|
||||
if (error instanceof AgentModelValidationError) {
|
||||
logger.warn('Session model validation error during update', {
|
||||
agentId,
|
||||
sessionId,
|
||||
agentType: error.context.agentType,
|
||||
field: error.context.field,
|
||||
model: error.context.model,
|
||||
detail: error.detail
|
||||
})
|
||||
return res.status(400).json(modelValidationErrorBody(error))
|
||||
}
|
||||
|
||||
logger.error('Error updating session', { error, agentId, sessionId })
|
||||
return res.status(500).json({
|
||||
error: {
|
||||
message: `Failed to update session: ${error.message}`,
|
||||
type: 'internal_error',
|
||||
code: 'session_update_failed'
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
export const patchSession = async (req: Request, res: Response): Promise<Response> => {
|
||||
const { agentId, sessionId } = req.params
|
||||
try {
|
||||
logger.debug('Patching session', { agentId, sessionId })
|
||||
logger.debug('Patch payload', { body: req.body })
|
||||
|
||||
// First check if session exists and belongs to agent
|
||||
const existingSession = await sessionService.getSession(agentId, sessionId)
|
||||
if (!existingSession || existingSession.agent_id !== agentId) {
|
||||
logger.warn('Session not found for patch', { agentId, sessionId })
|
||||
return res.status(404).json({
|
||||
error: {
|
||||
message: 'Session not found for this agent',
|
||||
type: 'not_found',
|
||||
code: 'session_not_found'
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
const updateSession = { ...existingSession, ...req.body }
|
||||
const session = await sessionService.updateSession(agentId, sessionId, updateSession)
|
||||
|
||||
if (!session) {
|
||||
logger.warn('Session missing while patching', { agentId, sessionId })
|
||||
return res.status(404).json({
|
||||
error: {
|
||||
message: 'Session not found',
|
||||
type: 'not_found',
|
||||
code: 'session_not_found'
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
logger.info('Session patched', { agentId, sessionId })
|
||||
return res.json(session)
|
||||
} catch (error: any) {
|
||||
if (error instanceof AgentModelValidationError) {
|
||||
logger.warn('Session model validation error during patch', {
|
||||
agentId,
|
||||
sessionId,
|
||||
agentType: error.context.agentType,
|
||||
field: error.context.field,
|
||||
model: error.context.model,
|
||||
detail: error.detail
|
||||
})
|
||||
return res.status(400).json(modelValidationErrorBody(error))
|
||||
}
|
||||
|
||||
logger.error('Error patching session', { error, agentId, sessionId })
|
||||
return res.status(500).json({
|
||||
error: {
|
||||
message: `Failed to patch session, ${error.message}`,
|
||||
type: 'internal_error',
|
||||
code: 'session_patch_failed'
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
export const deleteSession = async (req: Request, res: Response): Promise<Response> => {
|
||||
try {
|
||||
const { agentId, sessionId } = req.params
|
||||
logger.debug('Deleting session', { agentId, sessionId })
|
||||
|
||||
// First check if session exists and belongs to agent
|
||||
const existingSession = await sessionService.getSession(agentId, sessionId)
|
||||
if (!existingSession || existingSession.agent_id !== agentId) {
|
||||
logger.warn('Session not found for deletion', { agentId, sessionId })
|
||||
return res.status(404).json({
|
||||
error: {
|
||||
message: 'Session not found for this agent',
|
||||
type: 'not_found',
|
||||
code: 'session_not_found'
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
const deleted = await sessionService.deleteSession(agentId, sessionId)
|
||||
|
||||
if (!deleted) {
|
||||
logger.warn('Session missing during delete', { agentId, sessionId })
|
||||
return res.status(404).json({
|
||||
error: {
|
||||
message: 'Session not found',
|
||||
type: 'not_found',
|
||||
code: 'session_not_found'
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
logger.info('Session deleted', { agentId, sessionId })
|
||||
|
||||
const { total } = await sessionService.listSessions(agentId, { limit: 1 })
|
||||
|
||||
if (total === 0) {
|
||||
logger.info('No remaining sessions, creating default', { agentId })
|
||||
try {
|
||||
const fallbackSession = await sessionService.createSession(agentId, {})
|
||||
logger.info('Default session created after delete', {
|
||||
agentId,
|
||||
sessionId: fallbackSession?.id
|
||||
})
|
||||
} catch (recoveryError: any) {
|
||||
logger.error('Failed to recreate session after deleting last session', {
|
||||
agentId,
|
||||
error: recoveryError
|
||||
})
|
||||
return res.status(500).json({
|
||||
error: {
|
||||
message: `Failed to recreate session after deletion: ${recoveryError.message}`,
|
||||
type: 'internal_error',
|
||||
code: 'session_recovery_failed'
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return res.status(204).send()
|
||||
} catch (error: any) {
|
||||
logger.error('Error deleting session', { error, agentId: req.params.agentId, sessionId: req.params.sessionId })
|
||||
return res.status(500).json({
|
||||
error: {
|
||||
message: 'Failed to delete session',
|
||||
type: 'internal_error',
|
||||
code: 'session_delete_failed'
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Convenience endpoints for sessions without agent context
|
||||
export const listAllSessions = async (req: Request, res: Response): Promise<Response> => {
|
||||
try {
|
||||
const limit = req.query.limit ? parseInt(req.query.limit as string) : 20
|
||||
const offset = req.query.offset ? parseInt(req.query.offset as string) : 0
|
||||
const status = req.query.status as any
|
||||
|
||||
logger.debug('Listing all sessions', { limit, offset, status })
|
||||
|
||||
const result = await sessionService.listSessions(undefined, { limit, offset })
|
||||
|
||||
logger.info('Sessions listed', {
|
||||
returned: result.sessions.length,
|
||||
total: result.total,
|
||||
limit,
|
||||
offset
|
||||
})
|
||||
return res.json({
|
||||
data: result.sessions,
|
||||
total: result.total,
|
||||
limit,
|
||||
offset
|
||||
} satisfies ListAgentSessionsResponse)
|
||||
} catch (error: any) {
|
||||
logger.error('Error listing all sessions', { error })
|
||||
return res.status(500).json({
|
||||
error: {
|
||||
message: 'Failed to list sessions',
|
||||
type: 'internal_error',
|
||||
code: 'session_list_failed'
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,965 @@
|
||||
import express from 'express'
|
||||
|
||||
import { agentHandlers, messageHandlers, sessionHandlers } from './handlers'
|
||||
import { checkAgentExists, handleValidationErrors } from './middleware'
|
||||
import {
|
||||
validateAgent,
|
||||
validateAgentId,
|
||||
validateAgentReplace,
|
||||
validateAgentUpdate,
|
||||
validatePagination,
|
||||
validateSession,
|
||||
validateSessionId,
|
||||
validateSessionMessage,
|
||||
validateSessionMessageId,
|
||||
validateSessionReplace,
|
||||
validateSessionUpdate
|
||||
} from './validators'
|
||||
|
||||
// Create main agents router
|
||||
const agentsRouter = express.Router()
|
||||
|
||||
/**
|
||||
* @swagger
|
||||
* components:
|
||||
* schemas:
|
||||
* PermissionMode:
|
||||
* type: string
|
||||
* enum: [default, acceptEdits, bypassPermissions, plan]
|
||||
* description: Permission mode for agent operations
|
||||
*
|
||||
* AgentType:
|
||||
* type: string
|
||||
* enum: [claude-code]
|
||||
* description: Type of agent
|
||||
*
|
||||
* AgentConfiguration:
|
||||
* type: object
|
||||
* properties:
|
||||
* permission_mode:
|
||||
* $ref: '#/components/schemas/PermissionMode'
|
||||
* default: default
|
||||
* max_turns:
|
||||
* type: integer
|
||||
* default: 10
|
||||
* description: Maximum number of interaction turns
|
||||
* additionalProperties: true
|
||||
*
|
||||
* AgentBase:
|
||||
* type: object
|
||||
* properties:
|
||||
* name:
|
||||
* type: string
|
||||
* description: Agent name
|
||||
* description:
|
||||
* type: string
|
||||
* description: Agent description
|
||||
* accessible_paths:
|
||||
* type: array
|
||||
* items:
|
||||
* type: string
|
||||
* description: Array of directory paths the agent can access
|
||||
* instructions:
|
||||
* type: string
|
||||
* description: System prompt/instructions
|
||||
* model:
|
||||
* type: string
|
||||
* description: Main model ID
|
||||
* plan_model:
|
||||
* type: string
|
||||
* description: Optional planning model ID
|
||||
* small_model:
|
||||
* type: string
|
||||
* description: Optional small/fast model ID
|
||||
* mcps:
|
||||
* type: array
|
||||
* items:
|
||||
* type: string
|
||||
* description: Array of MCP tool IDs
|
||||
* allowed_tools:
|
||||
* type: array
|
||||
* items:
|
||||
* type: string
|
||||
* description: Array of allowed tool IDs (whitelist)
|
||||
* configuration:
|
||||
* $ref: '#/components/schemas/AgentConfiguration'
|
||||
* required:
|
||||
* - model
|
||||
* - accessible_paths
|
||||
*
|
||||
* AgentEntity:
|
||||
* allOf:
|
||||
* - $ref: '#/components/schemas/AgentBase'
|
||||
* - type: object
|
||||
* properties:
|
||||
* id:
|
||||
* type: string
|
||||
* description: Unique agent identifier
|
||||
* type:
|
||||
* $ref: '#/components/schemas/AgentType'
|
||||
* created_at:
|
||||
* type: string
|
||||
* format: date-time
|
||||
* description: ISO timestamp of creation
|
||||
* updated_at:
|
||||
* type: string
|
||||
* format: date-time
|
||||
* description: ISO timestamp of last update
|
||||
* required:
|
||||
* - id
|
||||
* - type
|
||||
* - created_at
|
||||
* - updated_at
|
||||
* CreateAgentRequest:
|
||||
* allOf:
|
||||
* - $ref: '#/components/schemas/AgentBase'
|
||||
* - type: object
|
||||
* properties:
|
||||
* type:
|
||||
* $ref: '#/components/schemas/AgentType'
|
||||
* name:
|
||||
* type: string
|
||||
* minLength: 1
|
||||
* description: Agent name (required)
|
||||
* model:
|
||||
* type: string
|
||||
* minLength: 1
|
||||
* description: Main model ID (required)
|
||||
* required:
|
||||
* - type
|
||||
* - name
|
||||
* - model
|
||||
*
|
||||
* UpdateAgentRequest:
|
||||
* type: object
|
||||
* properties:
|
||||
* name:
|
||||
* type: string
|
||||
* description: Agent name
|
||||
* description:
|
||||
* type: string
|
||||
* description: Agent description
|
||||
* accessible_paths:
|
||||
* type: array
|
||||
* items:
|
||||
* type: string
|
||||
* description: Array of directory paths the agent can access
|
||||
* instructions:
|
||||
* type: string
|
||||
* description: System prompt/instructions
|
||||
* model:
|
||||
* type: string
|
||||
* description: Main model ID
|
||||
* plan_model:
|
||||
* type: string
|
||||
* description: Optional planning model ID
|
||||
* small_model:
|
||||
* type: string
|
||||
* description: Optional small/fast model ID
|
||||
* mcps:
|
||||
* type: array
|
||||
* items:
|
||||
* type: string
|
||||
* description: Array of MCP tool IDs
|
||||
* allowed_tools:
|
||||
* type: array
|
||||
* items:
|
||||
* type: string
|
||||
* description: Array of allowed tool IDs (whitelist)
|
||||
* configuration:
|
||||
* $ref: '#/components/schemas/AgentConfiguration'
|
||||
* description: Partial update - all fields are optional
|
||||
*
|
||||
* ReplaceAgentRequest:
|
||||
* $ref: '#/components/schemas/AgentBase'
|
||||
*
|
||||
* SessionEntity:
|
||||
* allOf:
|
||||
* - $ref: '#/components/schemas/AgentBase'
|
||||
* - type: object
|
||||
* properties:
|
||||
* id:
|
||||
* type: string
|
||||
* description: Unique session identifier
|
||||
* agent_id:
|
||||
* type: string
|
||||
* description: Primary agent ID for the session
|
||||
* agent_type:
|
||||
* $ref: '#/components/schemas/AgentType'
|
||||
* created_at:
|
||||
* type: string
|
||||
* format: date-time
|
||||
* description: ISO timestamp of creation
|
||||
* updated_at:
|
||||
* type: string
|
||||
* format: date-time
|
||||
* description: ISO timestamp of last update
|
||||
* required:
|
||||
* - id
|
||||
* - agent_id
|
||||
* - agent_type
|
||||
* - created_at
|
||||
* - updated_at
|
||||
*
|
||||
* CreateSessionRequest:
|
||||
* allOf:
|
||||
* - $ref: '#/components/schemas/AgentBase'
|
||||
* - type: object
|
||||
* properties:
|
||||
* model:
|
||||
* type: string
|
||||
* minLength: 1
|
||||
* description: Main model ID (required)
|
||||
* required:
|
||||
* - model
|
||||
*
|
||||
* UpdateSessionRequest:
|
||||
* type: object
|
||||
* properties:
|
||||
* name:
|
||||
* type: string
|
||||
* description: Session name
|
||||
* description:
|
||||
* type: string
|
||||
* description: Session description
|
||||
* accessible_paths:
|
||||
* type: array
|
||||
* items:
|
||||
* type: string
|
||||
* description: Array of directory paths the agent can access
|
||||
* instructions:
|
||||
* type: string
|
||||
* description: System prompt/instructions
|
||||
* model:
|
||||
* type: string
|
||||
* description: Main model ID
|
||||
* plan_model:
|
||||
* type: string
|
||||
* description: Optional planning model ID
|
||||
* small_model:
|
||||
* type: string
|
||||
* description: Optional small/fast model ID
|
||||
* mcps:
|
||||
* type: array
|
||||
* items:
|
||||
* type: string
|
||||
* description: Array of MCP tool IDs
|
||||
* allowed_tools:
|
||||
* type: array
|
||||
* items:
|
||||
* type: string
|
||||
* description: Array of allowed tool IDs (whitelist)
|
||||
* configuration:
|
||||
* $ref: '#/components/schemas/AgentConfiguration'
|
||||
* description: Partial update - all fields are optional
|
||||
*
|
||||
* ReplaceSessionRequest:
|
||||
* allOf:
|
||||
* - $ref: '#/components/schemas/AgentBase'
|
||||
* - type: object
|
||||
* properties:
|
||||
* model:
|
||||
* type: string
|
||||
* minLength: 1
|
||||
* description: Main model ID (required)
|
||||
* required:
|
||||
* - model
|
||||
*
|
||||
* CreateSessionMessageRequest:
|
||||
* type: object
|
||||
* properties:
|
||||
* content:
|
||||
* type: string
|
||||
* minLength: 1
|
||||
* description: Message content
|
||||
* required:
|
||||
* - content
|
||||
*
|
||||
* PaginationQuery:
|
||||
* type: object
|
||||
* properties:
|
||||
* limit:
|
||||
* type: integer
|
||||
* minimum: 1
|
||||
* maximum: 100
|
||||
* default: 20
|
||||
* description: Number of items to return
|
||||
* offset:
|
||||
* type: integer
|
||||
* minimum: 0
|
||||
* default: 0
|
||||
* description: Number of items to skip
|
||||
* status:
|
||||
* type: string
|
||||
* enum: [idle, running, completed, failed, stopped]
|
||||
* description: Filter by session status
|
||||
*
|
||||
* ListAgentsResponse:
|
||||
* type: object
|
||||
* properties:
|
||||
* agents:
|
||||
* type: array
|
||||
* items:
|
||||
* $ref: '#/components/schemas/AgentEntity'
|
||||
* total:
|
||||
* type: integer
|
||||
* description: Total number of agents
|
||||
* limit:
|
||||
* type: integer
|
||||
* description: Number of items returned
|
||||
* offset:
|
||||
* type: integer
|
||||
* description: Number of items skipped
|
||||
* required:
|
||||
* - agents
|
||||
* - total
|
||||
* - limit
|
||||
* - offset
|
||||
*
|
||||
* ListSessionsResponse:
|
||||
* type: object
|
||||
* properties:
|
||||
* sessions:
|
||||
* type: array
|
||||
* items:
|
||||
* $ref: '#/components/schemas/SessionEntity'
|
||||
* total:
|
||||
* type: integer
|
||||
* description: Total number of sessions
|
||||
* limit:
|
||||
* type: integer
|
||||
* description: Number of items returned
|
||||
* offset:
|
||||
* type: integer
|
||||
* description: Number of items skipped
|
||||
* required:
|
||||
* - sessions
|
||||
* - total
|
||||
* - limit
|
||||
* - offset
|
||||
*
|
||||
* ErrorResponse:
|
||||
* type: object
|
||||
* properties:
|
||||
* error:
|
||||
* type: object
|
||||
* properties:
|
||||
* message:
|
||||
* type: string
|
||||
* description: Error message
|
||||
* type:
|
||||
* type: string
|
||||
* description: Error type
|
||||
* code:
|
||||
* type: string
|
||||
* description: Error code
|
||||
* required:
|
||||
* - message
|
||||
* - type
|
||||
* - code
|
||||
* required:
|
||||
* - error
|
||||
*/
|
||||
|
||||
/**
|
||||
* @swagger
|
||||
* /agents:
|
||||
* post:
|
||||
* summary: Create a new agent
|
||||
* tags: [Agents]
|
||||
* requestBody:
|
||||
* required: true
|
||||
* content:
|
||||
* application/json:
|
||||
* schema:
|
||||
* $ref: '#/components/schemas/CreateAgentRequest'
|
||||
* responses:
|
||||
* 201:
|
||||
* description: Agent created successfully
|
||||
* content:
|
||||
* application/json:
|
||||
* schema:
|
||||
* $ref: '#/components/schemas/AgentEntity'
|
||||
* 400:
|
||||
* description: Invalid request body
|
||||
* content:
|
||||
* application/json:
|
||||
* schema:
|
||||
* $ref: '#/components/schemas/ErrorResponse'
|
||||
*/
|
||||
// Agent CRUD routes
|
||||
agentsRouter.post('/', validateAgent, handleValidationErrors, agentHandlers.createAgent)
|
||||
|
||||
/**
|
||||
* @swagger
|
||||
* /agents:
|
||||
* get:
|
||||
* summary: List all agents with pagination
|
||||
* tags: [Agents]
|
||||
* parameters:
|
||||
* - in: query
|
||||
* name: limit
|
||||
* schema:
|
||||
* type: integer
|
||||
* minimum: 1
|
||||
* maximum: 100
|
||||
* default: 20
|
||||
* description: Number of agents to return
|
||||
* - in: query
|
||||
* name: offset
|
||||
* schema:
|
||||
* type: integer
|
||||
* minimum: 0
|
||||
* default: 0
|
||||
* description: Number of agents to skip
|
||||
* - in: query
|
||||
* name: status
|
||||
* schema:
|
||||
* type: string
|
||||
* enum: [idle, running, completed, failed, stopped]
|
||||
* description: Filter by agent status
|
||||
* responses:
|
||||
* 200:
|
||||
* description: List of agents
|
||||
* content:
|
||||
* application/json:
|
||||
* schema:
|
||||
* $ref: '#/components/schemas/ListAgentsResponse'
|
||||
*/
|
||||
agentsRouter.get('/', validatePagination, handleValidationErrors, agentHandlers.listAgents)
|
||||
|
||||
/**
|
||||
* @swagger
|
||||
* /agents/{agentId}:
|
||||
* get:
|
||||
* summary: Get agent by ID
|
||||
* tags: [Agents]
|
||||
* parameters:
|
||||
* - in: path
|
||||
* name: agentId
|
||||
* required: true
|
||||
* schema:
|
||||
* type: string
|
||||
* description: Agent ID
|
||||
* responses:
|
||||
* 200:
|
||||
* description: Agent details
|
||||
* content:
|
||||
* application/json:
|
||||
* schema:
|
||||
* $ref: '#/components/schemas/AgentEntity'
|
||||
* 404:
|
||||
* description: Agent not found
|
||||
* content:
|
||||
* application/json:
|
||||
* schema:
|
||||
* $ref: '#/components/schemas/ErrorResponse'
|
||||
*/
|
||||
agentsRouter.get('/:agentId', validateAgentId, handleValidationErrors, agentHandlers.getAgent)
|
||||
/**
|
||||
* @swagger
|
||||
* /agents/{agentId}:
|
||||
* put:
|
||||
* summary: Replace agent (full update)
|
||||
* tags: [Agents]
|
||||
* parameters:
|
||||
* - in: path
|
||||
* name: agentId
|
||||
* required: true
|
||||
* schema:
|
||||
* type: string
|
||||
* description: Agent ID
|
||||
* requestBody:
|
||||
* required: true
|
||||
* content:
|
||||
* application/json:
|
||||
* schema:
|
||||
* $ref: '#/components/schemas/ReplaceAgentRequest'
|
||||
* responses:
|
||||
* 200:
|
||||
* description: Agent updated successfully
|
||||
* content:
|
||||
* application/json:
|
||||
* schema:
|
||||
* $ref: '#/components/schemas/AgentEntity'
|
||||
* 400:
|
||||
* description: Invalid request body
|
||||
* content:
|
||||
* application/json:
|
||||
* schema:
|
||||
* $ref: '#/components/schemas/ErrorResponse'
|
||||
* 404:
|
||||
* description: Agent not found
|
||||
* content:
|
||||
* application/json:
|
||||
* schema:
|
||||
* $ref: '#/components/schemas/ErrorResponse'
|
||||
*/
|
||||
agentsRouter.put('/:agentId', validateAgentId, validateAgentReplace, handleValidationErrors, agentHandlers.updateAgent)
|
||||
/**
|
||||
* @swagger
|
||||
* /agents/{agentId}:
|
||||
* patch:
|
||||
* summary: Update agent (partial update)
|
||||
* tags: [Agents]
|
||||
* parameters:
|
||||
* - in: path
|
||||
* name: agentId
|
||||
* required: true
|
||||
* schema:
|
||||
* type: string
|
||||
* description: Agent ID
|
||||
* requestBody:
|
||||
* required: true
|
||||
* content:
|
||||
* application/json:
|
||||
* schema:
|
||||
* $ref: '#/components/schemas/UpdateAgentRequest'
|
||||
* responses:
|
||||
* 200:
|
||||
* description: Agent updated successfully
|
||||
* content:
|
||||
* application/json:
|
||||
* schema:
|
||||
* $ref: '#/components/schemas/AgentEntity'
|
||||
* 400:
|
||||
* description: Invalid request body
|
||||
* content:
|
||||
* application/json:
|
||||
* schema:
|
||||
* $ref: '#/components/schemas/ErrorResponse'
|
||||
* 404:
|
||||
* description: Agent not found
|
||||
* content:
|
||||
* application/json:
|
||||
* schema:
|
||||
* $ref: '#/components/schemas/ErrorResponse'
|
||||
*/
|
||||
agentsRouter.patch('/:agentId', validateAgentId, validateAgentUpdate, handleValidationErrors, agentHandlers.patchAgent)
|
||||
/**
|
||||
* @swagger
|
||||
* /agents/{agentId}:
|
||||
* delete:
|
||||
* summary: Delete agent
|
||||
* tags: [Agents]
|
||||
* parameters:
|
||||
* - in: path
|
||||
* name: agentId
|
||||
* required: true
|
||||
* schema:
|
||||
* type: string
|
||||
* description: Agent ID
|
||||
* responses:
|
||||
* 204:
|
||||
* description: Agent deleted successfully
|
||||
* 404:
|
||||
* description: Agent not found
|
||||
* content:
|
||||
* application/json:
|
||||
* schema:
|
||||
* $ref: '#/components/schemas/ErrorResponse'
|
||||
*/
|
||||
agentsRouter.delete('/:agentId', validateAgentId, handleValidationErrors, agentHandlers.deleteAgent)
|
||||
|
||||
// Create sessions router with agent context
|
||||
const createSessionsRouter = (): express.Router => {
|
||||
const sessionsRouter = express.Router({ mergeParams: true })
|
||||
|
||||
// Session CRUD routes (nested under agent)
|
||||
/**
|
||||
* @swagger
|
||||
* /agents/{agentId}/sessions:
|
||||
* post:
|
||||
* summary: Create a new session for an agent
|
||||
* tags: [Sessions]
|
||||
* parameters:
|
||||
* - in: path
|
||||
* name: agentId
|
||||
* required: true
|
||||
* schema:
|
||||
* type: string
|
||||
* description: Agent ID
|
||||
* requestBody:
|
||||
* required: true
|
||||
* content:
|
||||
* application/json:
|
||||
* schema:
|
||||
* $ref: '#/components/schemas/CreateSessionRequest'
|
||||
* responses:
|
||||
* 201:
|
||||
* description: Session created successfully
|
||||
* content:
|
||||
* application/json:
|
||||
* schema:
|
||||
* $ref: '#/components/schemas/SessionEntity'
|
||||
* 400:
|
||||
* description: Invalid request body
|
||||
* content:
|
||||
* application/json:
|
||||
* schema:
|
||||
* $ref: '#/components/schemas/ErrorResponse'
|
||||
* 404:
|
||||
* description: Agent not found
|
||||
* content:
|
||||
* application/json:
|
||||
* schema:
|
||||
* $ref: '#/components/schemas/ErrorResponse'
|
||||
*/
|
||||
sessionsRouter.post('/', validateSession, handleValidationErrors, sessionHandlers.createSession)
|
||||
|
||||
/**
|
||||
* @swagger
|
||||
* /agents/{agentId}/sessions:
|
||||
* get:
|
||||
* summary: List sessions for an agent
|
||||
* tags: [Sessions]
|
||||
* parameters:
|
||||
* - in: path
|
||||
* name: agentId
|
||||
* required: true
|
||||
* schema:
|
||||
* type: string
|
||||
* description: Agent ID
|
||||
* - in: query
|
||||
* name: limit
|
||||
* schema:
|
||||
* type: integer
|
||||
* minimum: 1
|
||||
* maximum: 100
|
||||
* default: 20
|
||||
* description: Number of sessions to return
|
||||
* - in: query
|
||||
* name: offset
|
||||
* schema:
|
||||
* type: integer
|
||||
* minimum: 0
|
||||
* default: 0
|
||||
* description: Number of sessions to skip
|
||||
* - in: query
|
||||
* name: status
|
||||
* schema:
|
||||
* type: string
|
||||
* enum: [idle, running, completed, failed, stopped]
|
||||
* description: Filter by session status
|
||||
* responses:
|
||||
* 200:
|
||||
* description: List of sessions
|
||||
* content:
|
||||
* application/json:
|
||||
* schema:
|
||||
* $ref: '#/components/schemas/ListSessionsResponse'
|
||||
* 404:
|
||||
* description: Agent not found
|
||||
* content:
|
||||
* application/json:
|
||||
* schema:
|
||||
* $ref: '#/components/schemas/ErrorResponse'
|
||||
*/
|
||||
sessionsRouter.get('/', validatePagination, handleValidationErrors, sessionHandlers.listSessions)
|
||||
/**
|
||||
* @swagger
|
||||
* /agents/{agentId}/sessions/{sessionId}:
|
||||
* get:
|
||||
* summary: Get session by ID
|
||||
* tags: [Sessions]
|
||||
* parameters:
|
||||
* - in: path
|
||||
* name: agentId
|
||||
* required: true
|
||||
* schema:
|
||||
* type: string
|
||||
* description: Agent ID
|
||||
* - in: path
|
||||
* name: sessionId
|
||||
* required: true
|
||||
* schema:
|
||||
* type: string
|
||||
* description: Session ID
|
||||
* responses:
|
||||
* 200:
|
||||
* description: Session details
|
||||
* content:
|
||||
* application/json:
|
||||
* schema:
|
||||
* $ref: '#/components/schemas/SessionEntity'
|
||||
* 404:
|
||||
* description: Agent or session not found
|
||||
* content:
|
||||
* application/json:
|
||||
* schema:
|
||||
* $ref: '#/components/schemas/ErrorResponse'
|
||||
*/
|
||||
sessionsRouter.get('/:sessionId', validateSessionId, handleValidationErrors, sessionHandlers.getSession)
|
||||
/**
|
||||
* @swagger
|
||||
* /agents/{agentId}/sessions/{sessionId}:
|
||||
* put:
|
||||
* summary: Replace session (full update)
|
||||
* tags: [Sessions]
|
||||
* parameters:
|
||||
* - in: path
|
||||
* name: agentId
|
||||
* required: true
|
||||
* schema:
|
||||
* type: string
|
||||
* description: Agent ID
|
||||
* - in: path
|
||||
* name: sessionId
|
||||
* required: true
|
||||
* schema:
|
||||
* type: string
|
||||
* description: Session ID
|
||||
* requestBody:
|
||||
* required: true
|
||||
* content:
|
||||
* application/json:
|
||||
* schema:
|
||||
* $ref: '#/components/schemas/ReplaceSessionRequest'
|
||||
* responses:
|
||||
* 200:
|
||||
* description: Session updated successfully
|
||||
* content:
|
||||
* application/json:
|
||||
* schema:
|
||||
* $ref: '#/components/schemas/SessionEntity'
|
||||
* 400:
|
||||
* description: Invalid request body
|
||||
* content:
|
||||
* application/json:
|
||||
* schema:
|
||||
* $ref: '#/components/schemas/ErrorResponse'
|
||||
* 404:
|
||||
* description: Agent or session not found
|
||||
* content:
|
||||
* application/json:
|
||||
* schema:
|
||||
* $ref: '#/components/schemas/ErrorResponse'
|
||||
*/
|
||||
sessionsRouter.put(
|
||||
'/:sessionId',
|
||||
validateSessionId,
|
||||
validateSessionReplace,
|
||||
handleValidationErrors,
|
||||
sessionHandlers.updateSession
|
||||
)
|
||||
/**
|
||||
* @swagger
|
||||
* /agents/{agentId}/sessions/{sessionId}:
|
||||
* patch:
|
||||
* summary: Update session (partial update)
|
||||
* tags: [Sessions]
|
||||
* parameters:
|
||||
* - in: path
|
||||
* name: agentId
|
||||
* required: true
|
||||
* schema:
|
||||
* type: string
|
||||
* description: Agent ID
|
||||
* - in: path
|
||||
* name: sessionId
|
||||
* required: true
|
||||
* schema:
|
||||
* type: string
|
||||
* description: Session ID
|
||||
* requestBody:
|
||||
* required: true
|
||||
* content:
|
||||
* application/json:
|
||||
* schema:
|
||||
* $ref: '#/components/schemas/UpdateSessionRequest'
|
||||
* responses:
|
||||
* 200:
|
||||
* description: Session updated successfully
|
||||
* content:
|
||||
* application/json:
|
||||
* schema:
|
||||
* $ref: '#/components/schemas/SessionEntity'
|
||||
* 400:
|
||||
* description: Invalid request body
|
||||
* content:
|
||||
* application/json:
|
||||
* schema:
|
||||
* $ref: '#/components/schemas/ErrorResponse'
|
||||
* 404:
|
||||
* description: Agent or session not found
|
||||
* content:
|
||||
* application/json:
|
||||
* schema:
|
||||
* $ref: '#/components/schemas/ErrorResponse'
|
||||
*/
|
||||
sessionsRouter.patch(
|
||||
'/:sessionId',
|
||||
validateSessionId,
|
||||
validateSessionUpdate,
|
||||
handleValidationErrors,
|
||||
sessionHandlers.patchSession
|
||||
)
|
||||
/**
|
||||
* @swagger
|
||||
* /agents/{agentId}/sessions/{sessionId}:
|
||||
* delete:
|
||||
* summary: Delete session
|
||||
* tags: [Sessions]
|
||||
* parameters:
|
||||
* - in: path
|
||||
* name: agentId
|
||||
* required: true
|
||||
* schema:
|
||||
* type: string
|
||||
* description: Agent ID
|
||||
* - in: path
|
||||
* name: sessionId
|
||||
* required: true
|
||||
* schema:
|
||||
* type: string
|
||||
* description: Session ID
|
||||
* responses:
|
||||
* 204:
|
||||
* description: Session deleted successfully
|
||||
* 404:
|
||||
* description: Agent or session not found
|
||||
* content:
|
||||
* application/json:
|
||||
* schema:
|
||||
* $ref: '#/components/schemas/ErrorResponse'
|
||||
*/
|
||||
sessionsRouter.delete('/:sessionId', validateSessionId, handleValidationErrors, sessionHandlers.deleteSession)
|
||||
|
||||
return sessionsRouter
|
||||
}
|
||||
|
||||
// Create messages router with agent and session context
|
||||
const createMessagesRouter = (): express.Router => {
|
||||
const messagesRouter = express.Router({ mergeParams: true })
|
||||
|
||||
// Message CRUD routes (nested under agent/session)
|
||||
/**
|
||||
* @swagger
|
||||
* /agents/{agentId}/sessions/{sessionId}/messages:
|
||||
* post:
|
||||
* summary: Create a new message in a session
|
||||
* tags: [Messages]
|
||||
* parameters:
|
||||
* - in: path
|
||||
* name: agentId
|
||||
* required: true
|
||||
* schema:
|
||||
* type: string
|
||||
* description: Agent ID
|
||||
* - in: path
|
||||
* name: sessionId
|
||||
* required: true
|
||||
* schema:
|
||||
* type: string
|
||||
* description: Session ID
|
||||
* requestBody:
|
||||
* required: true
|
||||
* content:
|
||||
* application/json:
|
||||
* schema:
|
||||
* $ref: '#/components/schemas/CreateSessionMessageRequest'
|
||||
* responses:
|
||||
* 201:
|
||||
* description: Message created successfully
|
||||
* content:
|
||||
* application/json:
|
||||
* schema:
|
||||
* type: object
|
||||
* properties:
|
||||
* id:
|
||||
* type: number
|
||||
* description: Message ID
|
||||
* session_id:
|
||||
* type: string
|
||||
* description: Session ID
|
||||
* role:
|
||||
* type: string
|
||||
* enum: [assistant, user, system, tool]
|
||||
* description: Message role
|
||||
* content:
|
||||
* type: object
|
||||
* description: Message content (AI SDK format)
|
||||
* agent_session_id:
|
||||
* type: string
|
||||
* description: Agent session ID for resuming
|
||||
* metadata:
|
||||
* type: object
|
||||
* description: Additional metadata
|
||||
* created_at:
|
||||
* type: string
|
||||
* format: date-time
|
||||
* updated_at:
|
||||
* type: string
|
||||
* format: date-time
|
||||
* 400:
|
||||
* description: Invalid request body
|
||||
* content:
|
||||
* application/json:
|
||||
* schema:
|
||||
* $ref: '#/components/schemas/ErrorResponse'
|
||||
* 404:
|
||||
* description: Agent or session not found
|
||||
* content:
|
||||
* application/json:
|
||||
* schema:
|
||||
* $ref: '#/components/schemas/ErrorResponse'
|
||||
*/
|
||||
messagesRouter.post('/', validateSessionMessage, handleValidationErrors, messageHandlers.createMessage)
|
||||
|
||||
/**
|
||||
* @swagger
|
||||
* /agents/{agentId}/sessions/{sessionId}/messages/{messageId}:
|
||||
* delete:
|
||||
* summary: Delete a message from a session
|
||||
* tags: [Messages]
|
||||
* parameters:
|
||||
* - in: path
|
||||
* name: agentId
|
||||
* required: true
|
||||
* schema:
|
||||
* type: string
|
||||
* description: Agent ID
|
||||
* - in: path
|
||||
* name: sessionId
|
||||
* required: true
|
||||
* schema:
|
||||
* type: string
|
||||
* description: Session ID
|
||||
* - in: path
|
||||
* name: messageId
|
||||
* required: true
|
||||
* schema:
|
||||
* type: integer
|
||||
* description: Message ID
|
||||
* responses:
|
||||
* 204:
|
||||
* description: Message deleted successfully
|
||||
* 404:
|
||||
* description: Agent, session, or message not found
|
||||
* content:
|
||||
* application/json:
|
||||
* schema:
|
||||
* $ref: '#/components/schemas/ErrorResponse'
|
||||
*/
|
||||
messagesRouter.delete('/:messageId', validateSessionMessageId, handleValidationErrors, messageHandlers.deleteMessage)
|
||||
return messagesRouter
|
||||
}
|
||||
|
||||
// Mount nested resources with clear hierarchy
|
||||
const sessionsRouter = createSessionsRouter()
|
||||
const messagesRouter = createMessagesRouter()
|
||||
|
||||
// Mount sessions under specific agent
|
||||
agentsRouter.use('/:agentId/sessions', validateAgentId, checkAgentExists, handleValidationErrors, sessionsRouter)
|
||||
|
||||
// Mount messages under specific agent/session
|
||||
agentsRouter.use(
|
||||
'/:agentId/sessions/:sessionId/messages',
|
||||
validateAgentId,
|
||||
validateSessionId,
|
||||
handleValidationErrors,
|
||||
messagesRouter
|
||||
)
|
||||
|
||||
// Export main router and convenience router
|
||||
export const agentsRoutes = agentsRouter
|
||||
@@ -0,0 +1,44 @@
|
||||
import { Request, Response } from 'express'
|
||||
|
||||
import { agentService } from '../../../../services/agents'
|
||||
import { loggerService } from '../../../../services/LoggerService'
|
||||
|
||||
const logger = loggerService.withContext('ApiServerMiddleware')
|
||||
|
||||
// Since Zod validators handle their own errors, this is now a pass-through
|
||||
export const handleValidationErrors = (_req: Request, _res: Response, next: any): void => {
|
||||
next()
|
||||
}
|
||||
|
||||
// Middleware to check if agent exists
|
||||
export const checkAgentExists = async (req: Request, res: Response, next: any): Promise<void> => {
|
||||
try {
|
||||
const { agentId } = req.params
|
||||
const exists = await agentService.agentExists(agentId)
|
||||
|
||||
if (!exists) {
|
||||
res.status(404).json({
|
||||
error: {
|
||||
message: 'Agent not found',
|
||||
type: 'not_found',
|
||||
code: 'agent_not_found'
|
||||
}
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
next()
|
||||
} catch (error) {
|
||||
logger.error('Error checking agent existence', {
|
||||
error: error as Error,
|
||||
agentId: req.params.agentId
|
||||
})
|
||||
res.status(500).json({
|
||||
error: {
|
||||
message: 'Failed to validate agent',
|
||||
type: 'internal_error',
|
||||
code: 'agent_validation_failed'
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1 @@
|
||||
export * from './common'
|
||||
@@ -0,0 +1,24 @@
|
||||
import {
|
||||
AgentIdParamSchema,
|
||||
CreateAgentRequestSchema,
|
||||
ReplaceAgentRequestSchema,
|
||||
UpdateAgentRequestSchema
|
||||
} from '@types'
|
||||
|
||||
import { createZodValidator } from './zodValidator'
|
||||
|
||||
export const validateAgent = createZodValidator({
|
||||
body: CreateAgentRequestSchema
|
||||
})
|
||||
|
||||
export const validateAgentReplace = createZodValidator({
|
||||
body: ReplaceAgentRequestSchema
|
||||
})
|
||||
|
||||
export const validateAgentUpdate = createZodValidator({
|
||||
body: UpdateAgentRequestSchema
|
||||
})
|
||||
|
||||
export const validateAgentId = createZodValidator({
|
||||
params: AgentIdParamSchema
|
||||
})
|
||||
@@ -0,0 +1,7 @@
|
||||
import { PaginationQuerySchema } from '@types'
|
||||
|
||||
import { createZodValidator } from './zodValidator'
|
||||
|
||||
export const validatePagination = createZodValidator({
|
||||
query: PaginationQuerySchema
|
||||
})
|
||||
@@ -0,0 +1,4 @@
|
||||
export * from './agents'
|
||||
export * from './common'
|
||||
export * from './messages'
|
||||
export * from './sessions'
|
||||
@@ -0,0 +1,11 @@
|
||||
import { CreateSessionMessageRequestSchema, SessionMessageIdParamSchema } from '@types'
|
||||
|
||||
import { createZodValidator } from './zodValidator'
|
||||
|
||||
export const validateSessionMessage = createZodValidator({
|
||||
body: CreateSessionMessageRequestSchema
|
||||
})
|
||||
|
||||
export const validateSessionMessageId = createZodValidator({
|
||||
params: SessionMessageIdParamSchema
|
||||
})
|
||||
@@ -0,0 +1,24 @@
|
||||
import {
|
||||
CreateSessionRequestSchema,
|
||||
ReplaceSessionRequestSchema,
|
||||
SessionIdParamSchema,
|
||||
UpdateSessionRequestSchema
|
||||
} from '@types'
|
||||
|
||||
import { createZodValidator } from './zodValidator'
|
||||
|
||||
export const validateSession = createZodValidator({
|
||||
body: CreateSessionRequestSchema
|
||||
})
|
||||
|
||||
export const validateSessionReplace = createZodValidator({
|
||||
body: ReplaceSessionRequestSchema
|
||||
})
|
||||
|
||||
export const validateSessionUpdate = createZodValidator({
|
||||
body: UpdateSessionRequestSchema
|
||||
})
|
||||
|
||||
export const validateSessionId = createZodValidator({
|
||||
params: SessionIdParamSchema
|
||||
})
|
||||
@@ -0,0 +1,68 @@
|
||||
import { NextFunction, Request, Response } from 'express'
|
||||
import { ZodError, ZodType } from 'zod'
|
||||
|
||||
export interface ValidationRequest extends Request {
|
||||
validatedBody?: any
|
||||
validatedParams?: any
|
||||
validatedQuery?: any
|
||||
}
|
||||
|
||||
export interface ZodValidationConfig {
|
||||
body?: ZodType
|
||||
params?: ZodType
|
||||
query?: ZodType
|
||||
}
|
||||
|
||||
export const createZodValidator = (config: ZodValidationConfig) => {
|
||||
return (req: ValidationRequest, res: Response, next: NextFunction): void => {
|
||||
try {
|
||||
if (config.body && req.body) {
|
||||
req.validatedBody = config.body.parse(req.body)
|
||||
}
|
||||
|
||||
if (config.params && req.params) {
|
||||
req.validatedParams = config.params.parse(req.params)
|
||||
}
|
||||
|
||||
if (config.query && req.query) {
|
||||
req.validatedQuery = config.query.parse(req.query)
|
||||
}
|
||||
|
||||
next()
|
||||
} catch (error) {
|
||||
if (error instanceof ZodError) {
|
||||
const validationErrors = error.issues.map((err) => ({
|
||||
type: 'field',
|
||||
value: err.input,
|
||||
msg: err.message,
|
||||
path: err.path.map((p) => String(p)).join('.'),
|
||||
location: getLocationFromPath(err.path, config)
|
||||
}))
|
||||
|
||||
res.status(400).json({
|
||||
error: {
|
||||
message: 'Validation failed',
|
||||
type: 'validation_error',
|
||||
details: validationErrors
|
||||
}
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
res.status(500).json({
|
||||
error: {
|
||||
message: 'Internal validation error',
|
||||
type: 'internal_error',
|
||||
code: 'validation_processing_failed'
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
function getLocationFromPath(path: (string | number | symbol)[], config: ZodValidationConfig): string {
|
||||
if (config.body && path.length > 0) return 'body'
|
||||
if (config.params && path.length > 0) return 'params'
|
||||
if (config.query && path.length > 0) return 'query'
|
||||
return 'unknown'
|
||||
}
|
||||
@@ -1,16 +1,106 @@
|
||||
import type { Request, Response } from 'express'
|
||||
import express from 'express'
|
||||
import OpenAI from 'openai'
|
||||
import type { ChatCompletionCreateParams } from 'openai/resources'
|
||||
|
||||
import { loggerService } from '../../services/LoggerService'
|
||||
import { chatCompletionService } from '../services/chat-completion'
|
||||
import { validateModelId } from '../utils'
|
||||
import {
|
||||
ChatCompletionModelError,
|
||||
chatCompletionService,
|
||||
ChatCompletionValidationError
|
||||
} from '../services/chat-completion'
|
||||
|
||||
const logger = loggerService.withContext('ApiServerChatRoutes')
|
||||
|
||||
const router = express.Router()
|
||||
|
||||
interface ErrorResponseBody {
|
||||
error: {
|
||||
message: string
|
||||
type: string
|
||||
code: string
|
||||
}
|
||||
}
|
||||
|
||||
const mapChatCompletionError = (error: unknown): { status: number; body: ErrorResponseBody } => {
|
||||
if (error instanceof ChatCompletionValidationError) {
|
||||
logger.warn('Chat completion validation error', {
|
||||
errors: error.errors
|
||||
})
|
||||
|
||||
return {
|
||||
status: 400,
|
||||
body: {
|
||||
error: {
|
||||
message: error.errors.join('; '),
|
||||
type: 'invalid_request_error',
|
||||
code: 'validation_failed'
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (error instanceof ChatCompletionModelError) {
|
||||
logger.warn('Chat completion model error', error.error)
|
||||
|
||||
return {
|
||||
status: 400,
|
||||
body: {
|
||||
error: {
|
||||
message: error.error.message,
|
||||
type: 'invalid_request_error',
|
||||
code: error.error.code
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (error instanceof Error) {
|
||||
let statusCode = 500
|
||||
let errorType = 'server_error'
|
||||
let errorCode = 'internal_error'
|
||||
|
||||
if (error.message.includes('API key') || error.message.includes('authentication')) {
|
||||
statusCode = 401
|
||||
errorType = 'authentication_error'
|
||||
errorCode = 'invalid_api_key'
|
||||
} else if (error.message.includes('rate limit') || error.message.includes('quota')) {
|
||||
statusCode = 429
|
||||
errorType = 'rate_limit_error'
|
||||
errorCode = 'rate_limit_exceeded'
|
||||
} else if (error.message.includes('timeout') || error.message.includes('connection')) {
|
||||
statusCode = 502
|
||||
errorType = 'server_error'
|
||||
errorCode = 'upstream_error'
|
||||
}
|
||||
|
||||
logger.error('Chat completion error', { error })
|
||||
|
||||
return {
|
||||
status: statusCode,
|
||||
body: {
|
||||
error: {
|
||||
message: error.message || 'Internal server error',
|
||||
type: errorType,
|
||||
code: errorCode
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
logger.error('Chat completion unknown error', { error })
|
||||
|
||||
return {
|
||||
status: 500,
|
||||
body: {
|
||||
error: {
|
||||
message: 'Internal server error',
|
||||
type: 'server_error',
|
||||
code: 'internal_error'
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @swagger
|
||||
* /v1/chat/completions:
|
||||
@@ -61,7 +151,7 @@ const router = express.Router()
|
||||
* type: integer
|
||||
* total_tokens:
|
||||
* type: integer
|
||||
* text/plain:
|
||||
* text/event-stream:
|
||||
* schema:
|
||||
* type: string
|
||||
* description: Server-sent events stream (when stream=true)
|
||||
@@ -104,72 +194,31 @@ router.post('/completions', async (req: Request, res: Response) => {
|
||||
})
|
||||
}
|
||||
|
||||
logger.info('Chat completion request:', {
|
||||
logger.debug('Chat completion request', {
|
||||
model: request.model,
|
||||
messageCount: request.messages?.length || 0,
|
||||
stream: request.stream,
|
||||
temperature: request.temperature
|
||||
})
|
||||
|
||||
// Validate request
|
||||
const validation = chatCompletionService.validateRequest(request)
|
||||
if (!validation.isValid) {
|
||||
return res.status(400).json({
|
||||
error: {
|
||||
message: validation.errors.join('; '),
|
||||
type: 'invalid_request_error',
|
||||
code: 'validation_failed'
|
||||
}
|
||||
})
|
||||
}
|
||||
const isStreaming = !!request.stream
|
||||
|
||||
// Validate model ID and get provider
|
||||
const modelValidation = await validateModelId(request.model)
|
||||
if (!modelValidation.valid) {
|
||||
const error = modelValidation.error!
|
||||
logger.warn(`Model validation failed for '${request.model}':`, error)
|
||||
return res.status(400).json({
|
||||
error: {
|
||||
message: error.message,
|
||||
type: 'invalid_request_error',
|
||||
code: error.code
|
||||
}
|
||||
})
|
||||
}
|
||||
if (isStreaming) {
|
||||
const { stream } = await chatCompletionService.processStreamingCompletion(request)
|
||||
|
||||
const provider = modelValidation.provider!
|
||||
const modelId = modelValidation.modelId!
|
||||
|
||||
logger.info('Model validation successful:', {
|
||||
provider: provider.id,
|
||||
providerType: provider.type,
|
||||
modelId: modelId,
|
||||
fullModelId: request.model
|
||||
})
|
||||
|
||||
// Create OpenAI client
|
||||
const client = new OpenAI({
|
||||
baseURL: provider.apiHost,
|
||||
apiKey: provider.apiKey
|
||||
})
|
||||
request.model = modelId
|
||||
|
||||
// Handle streaming
|
||||
if (request.stream) {
|
||||
const streamResponse = await client.chat.completions.create(request)
|
||||
|
||||
res.setHeader('Content-Type', 'text/plain; charset=utf-8')
|
||||
res.setHeader('Cache-Control', 'no-cache')
|
||||
res.setHeader('Content-Type', 'text/event-stream; charset=utf-8')
|
||||
res.setHeader('Cache-Control', 'no-cache, no-transform')
|
||||
res.setHeader('Connection', 'keep-alive')
|
||||
res.setHeader('X-Accel-Buffering', 'no')
|
||||
res.flushHeaders()
|
||||
|
||||
try {
|
||||
for await (const chunk of streamResponse as any) {
|
||||
for await (const chunk of stream) {
|
||||
res.write(`data: ${JSON.stringify(chunk)}\n\n`)
|
||||
}
|
||||
res.write('data: [DONE]\n\n')
|
||||
res.end()
|
||||
} catch (streamError: any) {
|
||||
logger.error('Stream error:', streamError)
|
||||
logger.error('Stream error', { error: streamError })
|
||||
res.write(
|
||||
`data: ${JSON.stringify({
|
||||
error: {
|
||||
@@ -179,47 +228,17 @@ router.post('/completions', async (req: Request, res: Response) => {
|
||||
}
|
||||
})}\n\n`
|
||||
)
|
||||
} finally {
|
||||
res.end()
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Handle non-streaming
|
||||
const response = await client.chat.completions.create(request)
|
||||
const { response } = await chatCompletionService.processCompletion(request)
|
||||
return res.json(response)
|
||||
} catch (error: any) {
|
||||
logger.error('Chat completion error:', error)
|
||||
|
||||
let statusCode = 500
|
||||
let errorType = 'server_error'
|
||||
let errorCode = 'internal_error'
|
||||
let errorMessage = 'Internal server error'
|
||||
|
||||
if (error instanceof Error) {
|
||||
errorMessage = error.message
|
||||
|
||||
if (error.message.includes('API key') || error.message.includes('authentication')) {
|
||||
statusCode = 401
|
||||
errorType = 'authentication_error'
|
||||
errorCode = 'invalid_api_key'
|
||||
} else if (error.message.includes('rate limit') || error.message.includes('quota')) {
|
||||
statusCode = 429
|
||||
errorType = 'rate_limit_error'
|
||||
errorCode = 'rate_limit_exceeded'
|
||||
} else if (error.message.includes('timeout') || error.message.includes('connection')) {
|
||||
statusCode = 502
|
||||
errorType = 'server_error'
|
||||
errorCode = 'upstream_error'
|
||||
}
|
||||
}
|
||||
|
||||
return res.status(statusCode).json({
|
||||
error: {
|
||||
message: errorMessage,
|
||||
type: errorType,
|
||||
code: errorCode
|
||||
}
|
||||
})
|
||||
} catch (error: unknown) {
|
||||
const { status, body } = mapChatCompletionError(error)
|
||||
return res.status(status).json(body)
|
||||
}
|
||||
})
|
||||
|
||||
|
||||
@@ -44,14 +44,14 @@ const router = express.Router()
|
||||
*/
|
||||
router.get('/', async (req: Request, res: Response) => {
|
||||
try {
|
||||
logger.info('Get all MCP servers request received')
|
||||
logger.debug('Listing MCP servers')
|
||||
const servers = await mcpApiService.getAllServers(req)
|
||||
return res.json({
|
||||
success: true,
|
||||
data: servers
|
||||
})
|
||||
} catch (error: any) {
|
||||
logger.error('Error fetching MCP servers:', error)
|
||||
logger.error('Error fetching MCP servers', { error })
|
||||
return res.status(503).json({
|
||||
success: false,
|
||||
error: {
|
||||
@@ -104,10 +104,12 @@ router.get('/', async (req: Request, res: Response) => {
|
||||
*/
|
||||
router.get('/:server_id', async (req: Request, res: Response) => {
|
||||
try {
|
||||
logger.info('Get MCP server info request received')
|
||||
logger.debug('Get MCP server info request received', {
|
||||
serverId: req.params.server_id
|
||||
})
|
||||
const server = await mcpApiService.getServerInfo(req.params.server_id)
|
||||
if (!server) {
|
||||
logger.warn('MCP server not found')
|
||||
logger.warn('MCP server not found', { serverId: req.params.server_id })
|
||||
return res.status(404).json({
|
||||
success: false,
|
||||
error: {
|
||||
@@ -122,7 +124,7 @@ router.get('/:server_id', async (req: Request, res: Response) => {
|
||||
data: server
|
||||
})
|
||||
} catch (error: any) {
|
||||
logger.error('Error fetching MCP server info:', error)
|
||||
logger.error('Error fetching MCP server info', { error, serverId: req.params.server_id })
|
||||
return res.status(503).json({
|
||||
success: false,
|
||||
error: {
|
||||
@@ -138,7 +140,7 @@ router.get('/:server_id', async (req: Request, res: Response) => {
|
||||
router.all('/:server_id/mcp', async (req: Request, res: Response) => {
|
||||
const server = await mcpApiService.getServerById(req.params.server_id)
|
||||
if (!server) {
|
||||
logger.warn('MCP server not found')
|
||||
logger.warn('MCP server not found', { serverId: req.params.server_id })
|
||||
return res.status(404).json({
|
||||
success: false,
|
||||
error: {
|
||||
|
||||
@@ -0,0 +1,403 @@
|
||||
import { MessageCreateParams } from '@anthropic-ai/sdk/resources'
|
||||
import { loggerService } from '@logger'
|
||||
import { Provider } from '@types'
|
||||
import express, { Request, Response } from 'express'
|
||||
|
||||
import { messagesService } from '../services/messages'
|
||||
import { getProviderById, validateModelId } from '../utils'
|
||||
|
||||
const logger = loggerService.withContext('ApiServerMessagesRoutes')
|
||||
|
||||
const router = express.Router()
|
||||
const providerRouter = express.Router({ mergeParams: true })
|
||||
|
||||
// Helper function for basic request validation
|
||||
async function validateRequestBody(req: Request): Promise<{ valid: boolean; error?: any }> {
|
||||
const request: MessageCreateParams = req.body
|
||||
|
||||
if (!request) {
|
||||
return {
|
||||
valid: false,
|
||||
error: {
|
||||
type: 'error',
|
||||
error: {
|
||||
type: 'invalid_request_error',
|
||||
message: 'Request body is required'
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return { valid: true }
|
||||
}
|
||||
|
||||
interface HandleMessageProcessingOptions {
|
||||
req: Request
|
||||
res: Response
|
||||
provider: Provider
|
||||
request: MessageCreateParams
|
||||
modelId?: string
|
||||
}
|
||||
|
||||
async function handleMessageProcessing({
|
||||
req,
|
||||
res,
|
||||
provider,
|
||||
request,
|
||||
modelId
|
||||
}: HandleMessageProcessingOptions): Promise<void> {
|
||||
try {
|
||||
const validation = messagesService.validateRequest(request)
|
||||
if (!validation.isValid) {
|
||||
res.status(400).json({
|
||||
type: 'error',
|
||||
error: {
|
||||
type: 'invalid_request_error',
|
||||
message: validation.errors.join('; ')
|
||||
}
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
const extraHeaders = messagesService.prepareHeaders(req.headers)
|
||||
const { client, anthropicRequest } = await messagesService.processMessage({
|
||||
provider,
|
||||
request,
|
||||
extraHeaders,
|
||||
modelId
|
||||
})
|
||||
|
||||
if (request.stream) {
|
||||
await messagesService.handleStreaming(client, anthropicRequest, { response: res }, provider)
|
||||
return
|
||||
}
|
||||
|
||||
const response = await client.messages.create(anthropicRequest)
|
||||
res.json(response)
|
||||
} catch (error: any) {
|
||||
logger.error('Message processing error', { error })
|
||||
const { statusCode, errorResponse } = messagesService.transformError(error)
|
||||
res.status(statusCode).json(errorResponse)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @swagger
|
||||
* /v1/messages:
|
||||
* post:
|
||||
* summary: Create message
|
||||
* description: Create a message response using Anthropic's API format
|
||||
* tags: [Messages]
|
||||
* requestBody:
|
||||
* required: true
|
||||
* content:
|
||||
* application/json:
|
||||
* schema:
|
||||
* type: object
|
||||
* required:
|
||||
* - model
|
||||
* - max_tokens
|
||||
* - messages
|
||||
* properties:
|
||||
* model:
|
||||
* type: string
|
||||
* description: Model ID in format "provider:model_id"
|
||||
* example: "my-anthropic:claude-3-5-sonnet-20241022"
|
||||
* max_tokens:
|
||||
* type: integer
|
||||
* minimum: 1
|
||||
* description: Maximum number of tokens to generate
|
||||
* example: 1024
|
||||
* messages:
|
||||
* type: array
|
||||
* items:
|
||||
* type: object
|
||||
* properties:
|
||||
* role:
|
||||
* type: string
|
||||
* enum: [user, assistant]
|
||||
* content:
|
||||
* oneOf:
|
||||
* - type: string
|
||||
* - type: array
|
||||
* system:
|
||||
* type: string
|
||||
* description: System message
|
||||
* temperature:
|
||||
* type: number
|
||||
* minimum: 0
|
||||
* maximum: 1
|
||||
* description: Sampling temperature
|
||||
* top_p:
|
||||
* type: number
|
||||
* minimum: 0
|
||||
* maximum: 1
|
||||
* description: Nucleus sampling
|
||||
* top_k:
|
||||
* type: integer
|
||||
* minimum: 0
|
||||
* description: Top-k sampling
|
||||
* stream:
|
||||
* type: boolean
|
||||
* description: Whether to stream the response
|
||||
* tools:
|
||||
* type: array
|
||||
* description: Available tools for the model
|
||||
* responses:
|
||||
* 200:
|
||||
* description: Message response
|
||||
* content:
|
||||
* application/json:
|
||||
* schema:
|
||||
* type: object
|
||||
* properties:
|
||||
* id:
|
||||
* type: string
|
||||
* type:
|
||||
* type: string
|
||||
* example: message
|
||||
* role:
|
||||
* type: string
|
||||
* example: assistant
|
||||
* content:
|
||||
* type: array
|
||||
* items:
|
||||
* type: object
|
||||
* model:
|
||||
* type: string
|
||||
* stop_reason:
|
||||
* type: string
|
||||
* stop_sequence:
|
||||
* type: string
|
||||
* usage:
|
||||
* type: object
|
||||
* properties:
|
||||
* input_tokens:
|
||||
* type: integer
|
||||
* output_tokens:
|
||||
* type: integer
|
||||
* text/event-stream:
|
||||
* schema:
|
||||
* type: string
|
||||
* description: Server-sent events stream (when stream=true)
|
||||
* 400:
|
||||
* description: Bad request
|
||||
* content:
|
||||
* application/json:
|
||||
* schema:
|
||||
* type: object
|
||||
* properties:
|
||||
* type:
|
||||
* type: string
|
||||
* example: error
|
||||
* error:
|
||||
* type: object
|
||||
* properties:
|
||||
* type:
|
||||
* type: string
|
||||
* message:
|
||||
* type: string
|
||||
* 401:
|
||||
* description: Unauthorized
|
||||
* 429:
|
||||
* description: Rate limit exceeded
|
||||
* 500:
|
||||
* description: Internal server error
|
||||
*/
|
||||
router.post('/', async (req: Request, res: Response) => {
|
||||
// Validate request body
|
||||
const bodyValidation = await validateRequestBody(req)
|
||||
if (!bodyValidation.valid) {
|
||||
return res.status(400).json(bodyValidation.error)
|
||||
}
|
||||
|
||||
try {
|
||||
const request: MessageCreateParams = req.body
|
||||
|
||||
// Validate model ID and get provider
|
||||
const modelValidation = await validateModelId(request.model)
|
||||
if (!modelValidation.valid) {
|
||||
const error = modelValidation.error!
|
||||
logger.warn('Model validation failed', {
|
||||
model: request.model,
|
||||
error
|
||||
})
|
||||
return res.status(400).json({
|
||||
type: 'error',
|
||||
error: {
|
||||
type: 'invalid_request_error',
|
||||
message: error.message
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
const provider = modelValidation.provider!
|
||||
const modelId = modelValidation.modelId!
|
||||
|
||||
return handleMessageProcessing({ req, res, provider, request, modelId })
|
||||
} catch (error: any) {
|
||||
logger.error('Message processing error', { error })
|
||||
const { statusCode, errorResponse } = messagesService.transformError(error)
|
||||
return res.status(statusCode).json(errorResponse)
|
||||
}
|
||||
})
|
||||
|
||||
/**
|
||||
* @swagger
|
||||
* /{provider_id}/v1/messages:
|
||||
* post:
|
||||
* summary: Create message with provider in path
|
||||
* description: Create a message response using provider ID from URL path
|
||||
* tags: [Messages]
|
||||
* parameters:
|
||||
* - in: path
|
||||
* name: provider_id
|
||||
* required: true
|
||||
* schema:
|
||||
* type: string
|
||||
* description: Provider ID (e.g., "my-anthropic")
|
||||
* example: "my-anthropic"
|
||||
* requestBody:
|
||||
* required: true
|
||||
* content:
|
||||
* application/json:
|
||||
* schema:
|
||||
* type: object
|
||||
* required:
|
||||
* - model
|
||||
* - max_tokens
|
||||
* - messages
|
||||
* properties:
|
||||
* model:
|
||||
* type: string
|
||||
* description: Model ID without provider prefix
|
||||
* example: "claude-3-5-sonnet-20241022"
|
||||
* max_tokens:
|
||||
* type: integer
|
||||
* minimum: 1
|
||||
* description: Maximum number of tokens to generate
|
||||
* example: 1024
|
||||
* messages:
|
||||
* type: array
|
||||
* items:
|
||||
* type: object
|
||||
* properties:
|
||||
* role:
|
||||
* type: string
|
||||
* enum: [user, assistant]
|
||||
* content:
|
||||
* oneOf:
|
||||
* - type: string
|
||||
* - type: array
|
||||
* system:
|
||||
* type: string
|
||||
* description: System message
|
||||
* temperature:
|
||||
* type: number
|
||||
* minimum: 0
|
||||
* maximum: 1
|
||||
* description: Sampling temperature
|
||||
* top_p:
|
||||
* type: number
|
||||
* minimum: 0
|
||||
* maximum: 1
|
||||
* description: Nucleus sampling
|
||||
* top_k:
|
||||
* type: integer
|
||||
* minimum: 0
|
||||
* description: Top-k sampling
|
||||
* stream:
|
||||
* type: boolean
|
||||
* description: Whether to stream the response
|
||||
* tools:
|
||||
* type: array
|
||||
* description: Available tools for the model
|
||||
* responses:
|
||||
* 200:
|
||||
* description: Message response
|
||||
* content:
|
||||
* application/json:
|
||||
* schema:
|
||||
* type: object
|
||||
* properties:
|
||||
* id:
|
||||
* type: string
|
||||
* type:
|
||||
* type: string
|
||||
* example: message
|
||||
* role:
|
||||
* type: string
|
||||
* example: assistant
|
||||
* content:
|
||||
* type: array
|
||||
* items:
|
||||
* type: object
|
||||
* model:
|
||||
* type: string
|
||||
* stop_reason:
|
||||
* type: string
|
||||
* stop_sequence:
|
||||
* type: string
|
||||
* usage:
|
||||
* type: object
|
||||
* properties:
|
||||
* input_tokens:
|
||||
* type: integer
|
||||
* output_tokens:
|
||||
* type: integer
|
||||
* text/event-stream:
|
||||
* schema:
|
||||
* type: string
|
||||
* description: Server-sent events stream (when stream=true)
|
||||
* 400:
|
||||
* description: Bad request
|
||||
* 401:
|
||||
* description: Unauthorized
|
||||
* 429:
|
||||
* description: Rate limit exceeded
|
||||
* 500:
|
||||
* description: Internal server error
|
||||
*/
|
||||
providerRouter.post('/', async (req: Request, res: Response) => {
|
||||
// Validate request body
|
||||
const bodyValidation = await validateRequestBody(req)
|
||||
if (!bodyValidation.valid) {
|
||||
return res.status(400).json(bodyValidation.error)
|
||||
}
|
||||
|
||||
try {
|
||||
const providerId = req.params.provider
|
||||
|
||||
if (!providerId) {
|
||||
return res.status(400).json({
|
||||
type: 'error',
|
||||
error: {
|
||||
type: 'invalid_request_error',
|
||||
message: 'Provider ID is required in URL path'
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// Get provider directly by ID from URL path
|
||||
const provider = await getProviderById(providerId)
|
||||
if (!provider) {
|
||||
return res.status(400).json({
|
||||
type: 'error',
|
||||
error: {
|
||||
type: 'invalid_request_error',
|
||||
message: `Provider '${providerId}' not found or not enabled`
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
const request: MessageCreateParams = req.body
|
||||
|
||||
return handleMessageProcessing({ req, res, provider, request })
|
||||
} catch (error: any) {
|
||||
logger.error('Message processing error', { error })
|
||||
const { statusCode, errorResponse } = messagesService.transformError(error)
|
||||
return res.status(statusCode).json(errorResponse)
|
||||
}
|
||||
})
|
||||
|
||||
export { providerRouter as messagesProviderRoutes, router as messagesRoutes }
|
||||
@@ -1,74 +1,126 @@
|
||||
import type { ApiModelsResponse } from '@types'
|
||||
import { ApiModelsFilterSchema } from '@types'
|
||||
import type { Request, Response } from 'express'
|
||||
import express from 'express'
|
||||
|
||||
import { loggerService } from '../../services/LoggerService'
|
||||
import { chatCompletionService } from '../services/chat-completion'
|
||||
import { loggerService } from '@logger'
|
||||
import { modelsService } from '../services/models'
|
||||
|
||||
const logger = loggerService.withContext('ApiServerModelsRoutes')
|
||||
|
||||
const router = express.Router()
|
||||
const router = express
|
||||
.Router()
|
||||
|
||||
/**
|
||||
* @swagger
|
||||
* /v1/models:
|
||||
* get:
|
||||
* summary: List available models
|
||||
* description: Returns a list of available AI models from all configured providers
|
||||
* tags: [Models]
|
||||
* responses:
|
||||
* 200:
|
||||
* description: List of available models
|
||||
* content:
|
||||
* application/json:
|
||||
* schema:
|
||||
* type: object
|
||||
* properties:
|
||||
* object:
|
||||
* type: string
|
||||
* example: list
|
||||
* data:
|
||||
* type: array
|
||||
* items:
|
||||
* $ref: '#/components/schemas/Model'
|
||||
* 503:
|
||||
* description: Service unavailable
|
||||
* content:
|
||||
* application/json:
|
||||
* schema:
|
||||
* $ref: '#/components/schemas/Error'
|
||||
*/
|
||||
router.get('/', async (_req: Request, res: Response) => {
|
||||
try {
|
||||
logger.info('Models list request received')
|
||||
/**
|
||||
* @swagger
|
||||
* /v1/models:
|
||||
* get:
|
||||
* summary: List available models
|
||||
* description: Returns a list of available AI models from all configured providers with optional filtering
|
||||
* tags: [Models]
|
||||
* parameters:
|
||||
* - in: query
|
||||
* name: providerType
|
||||
* schema:
|
||||
* type: string
|
||||
* enum: [openai, openai-response, anthropic, gemini]
|
||||
* description: Filter models by provider type
|
||||
* - in: query
|
||||
* name: offset
|
||||
* schema:
|
||||
* type: integer
|
||||
* minimum: 0
|
||||
* default: 0
|
||||
* description: Pagination offset
|
||||
* - in: query
|
||||
* name: limit
|
||||
* schema:
|
||||
* type: integer
|
||||
* minimum: 1
|
||||
* description: Maximum number of models to return
|
||||
* responses:
|
||||
* 200:
|
||||
* description: List of available models
|
||||
* content:
|
||||
* application/json:
|
||||
* schema:
|
||||
* type: object
|
||||
* properties:
|
||||
* object:
|
||||
* type: string
|
||||
* example: list
|
||||
* data:
|
||||
* type: array
|
||||
* items:
|
||||
* $ref: '#/components/schemas/Model'
|
||||
* total:
|
||||
* type: integer
|
||||
* description: Total number of models (when using pagination)
|
||||
* offset:
|
||||
* type: integer
|
||||
* description: Current offset (when using pagination)
|
||||
* limit:
|
||||
* type: integer
|
||||
* description: Current limit (when using pagination)
|
||||
* 400:
|
||||
* description: Invalid query parameters
|
||||
* content:
|
||||
* application/json:
|
||||
* schema:
|
||||
* $ref: '#/components/schemas/Error'
|
||||
* 503:
|
||||
* description: Service unavailable
|
||||
* content:
|
||||
* application/json:
|
||||
* schema:
|
||||
* $ref: '#/components/schemas/Error'
|
||||
*/
|
||||
.get('/', async (req: Request, res: Response) => {
|
||||
try {
|
||||
logger.debug('Models list request received', { query: req.query })
|
||||
|
||||
const models = await chatCompletionService.getModels()
|
||||
// Validate query parameters using Zod schema
|
||||
const filterResult = ApiModelsFilterSchema.safeParse(req.query)
|
||||
|
||||
if (models.length === 0) {
|
||||
logger.warn(
|
||||
'No models available from providers. This may be because no OpenAI providers are configured or enabled.'
|
||||
)
|
||||
}
|
||||
|
||||
logger.info(`Returning ${models.length} models (OpenAI providers only)`)
|
||||
logger.debug(
|
||||
'Model IDs:',
|
||||
models.map((m) => m.id)
|
||||
)
|
||||
|
||||
return res.json({
|
||||
object: 'list',
|
||||
data: models
|
||||
})
|
||||
} catch (error: any) {
|
||||
logger.error('Error fetching models:', error)
|
||||
return res.status(503).json({
|
||||
error: {
|
||||
message: 'Failed to retrieve models from available providers',
|
||||
type: 'service_unavailable',
|
||||
code: 'models_unavailable'
|
||||
if (!filterResult.success) {
|
||||
logger.warn('Invalid model query parameters', { issues: filterResult.error.issues })
|
||||
return res.status(400).json({
|
||||
error: {
|
||||
message: 'Invalid query parameters',
|
||||
type: 'invalid_request_error',
|
||||
code: 'invalid_parameters',
|
||||
details: filterResult.error.issues.map((issue) => ({
|
||||
field: issue.path.join('.'),
|
||||
message: issue.message
|
||||
}))
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
const filter = filterResult.data
|
||||
const response = await modelsService.getModels(filter)
|
||||
|
||||
if (response.data.length === 0) {
|
||||
logger.warn('No models available from providers', { filter })
|
||||
}
|
||||
|
||||
logger.info('Models response ready', {
|
||||
filter,
|
||||
total: response.total,
|
||||
modelIds: response.data.map((m) => m.id)
|
||||
})
|
||||
|
||||
return res.json(response satisfies ApiModelsResponse)
|
||||
} catch (error: any) {
|
||||
logger.error('Error fetching models', { error })
|
||||
return res.status(503).json({
|
||||
error: {
|
||||
message: 'Failed to retrieve models from available providers',
|
||||
type: 'service_unavailable',
|
||||
code: 'models_unavailable'
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
export { router as modelsRoutes }
|
||||
|
||||
@@ -1,11 +1,16 @@
|
||||
import { createServer } from 'node:http'
|
||||
|
||||
import { agentService } from '../services/agents'
|
||||
import { loggerService } from '../services/LoggerService'
|
||||
import { app } from './app'
|
||||
import { config } from './config'
|
||||
|
||||
const logger = loggerService.withContext('ApiServer')
|
||||
|
||||
const GLOBAL_REQUEST_TIMEOUT_MS = 5 * 60_000
|
||||
const GLOBAL_HEADERS_TIMEOUT_MS = GLOBAL_REQUEST_TIMEOUT_MS + 5_000
|
||||
const GLOBAL_KEEPALIVE_TIMEOUT_MS = 60_000
|
||||
|
||||
export class ApiServer {
|
||||
private server: ReturnType<typeof createServer> | null = null
|
||||
|
||||
@@ -16,16 +21,21 @@ export class ApiServer {
|
||||
}
|
||||
|
||||
// Load config
|
||||
const { port, host, apiKey } = await config.load()
|
||||
const { port, host } = await config.load()
|
||||
|
||||
// Initialize AgentService
|
||||
logger.info('Initializing AgentService')
|
||||
await agentService.initialize()
|
||||
logger.info('AgentService initialized')
|
||||
|
||||
// Create server with Express app
|
||||
this.server = createServer(app)
|
||||
this.applyServerTimeouts(this.server)
|
||||
|
||||
// Start server
|
||||
return new Promise((resolve, reject) => {
|
||||
this.server!.listen(port, host, () => {
|
||||
logger.info(`API Server started at http://${host}:${port}`)
|
||||
logger.info(`API Key: ${apiKey}`)
|
||||
logger.info('API server started', { host, port })
|
||||
resolve()
|
||||
})
|
||||
|
||||
@@ -33,12 +43,19 @@ export class ApiServer {
|
||||
})
|
||||
}
|
||||
|
||||
private applyServerTimeouts(server: ReturnType<typeof createServer>): void {
|
||||
server.requestTimeout = GLOBAL_REQUEST_TIMEOUT_MS
|
||||
server.headersTimeout = Math.max(GLOBAL_HEADERS_TIMEOUT_MS, server.requestTimeout + 1_000)
|
||||
server.keepAliveTimeout = GLOBAL_KEEPALIVE_TIMEOUT_MS
|
||||
server.setTimeout(0)
|
||||
}
|
||||
|
||||
async stop(): Promise<void> {
|
||||
if (!this.server) return
|
||||
|
||||
return new Promise((resolve) => {
|
||||
this.server!.close(() => {
|
||||
logger.info('API Server stopped')
|
||||
logger.info('API server stopped')
|
||||
this.server = null
|
||||
resolve()
|
||||
})
|
||||
@@ -56,7 +73,7 @@ export class ApiServer {
|
||||
const isListening = this.server?.listening || false
|
||||
const result = hasServer && isListening
|
||||
|
||||
logger.debug('isRunning check:', { hasServer, isListening, result })
|
||||
logger.debug('isRunning check', { hasServer, isListening, result })
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
@@ -1,83 +1,132 @@
|
||||
import type { Provider } from '@types'
|
||||
import OpenAI from 'openai'
|
||||
import type { ChatCompletionCreateParams } from 'openai/resources'
|
||||
import type { ChatCompletionCreateParams, ChatCompletionCreateParamsStreaming } from 'openai/resources'
|
||||
|
||||
import { loggerService } from '../../services/LoggerService'
|
||||
import type { OpenAICompatibleModel } from '../utils'
|
||||
import {
|
||||
getProviderByModel,
|
||||
getRealProviderModel,
|
||||
listAllAvailableModels,
|
||||
transformModelToOpenAI,
|
||||
validateProvider
|
||||
} from '../utils'
|
||||
import { loggerService } from '@logger'
|
||||
import { type ModelValidationError, validateModelId } from '../utils'
|
||||
|
||||
const logger = loggerService.withContext('ChatCompletionService')
|
||||
|
||||
export interface ModelData extends OpenAICompatibleModel {
|
||||
provider_id: string
|
||||
model_id: string
|
||||
name: string
|
||||
}
|
||||
|
||||
export interface ValidationResult {
|
||||
isValid: boolean
|
||||
errors: string[]
|
||||
}
|
||||
|
||||
export class ChatCompletionValidationError extends Error {
|
||||
constructor(public readonly errors: string[]) {
|
||||
super(`Request validation failed: ${errors.join('; ')}`)
|
||||
this.name = 'ChatCompletionValidationError'
|
||||
}
|
||||
}
|
||||
|
||||
export class ChatCompletionModelError extends Error {
|
||||
constructor(public readonly error: ModelValidationError) {
|
||||
super(`Model validation failed: ${error.message}`)
|
||||
this.name = 'ChatCompletionModelError'
|
||||
}
|
||||
}
|
||||
|
||||
export type PrepareRequestResult =
|
||||
| { status: 'validation_error'; errors: string[] }
|
||||
| { status: 'model_error'; error: ModelValidationError }
|
||||
| {
|
||||
status: 'ok'
|
||||
provider: Provider
|
||||
modelId: string
|
||||
client: OpenAI
|
||||
providerRequest: ChatCompletionCreateParams
|
||||
}
|
||||
|
||||
export class ChatCompletionService {
|
||||
async getModels(): Promise<ModelData[]> {
|
||||
try {
|
||||
logger.info('Getting available models from providers')
|
||||
async resolveProviderContext(
|
||||
model: string
|
||||
): Promise<
|
||||
{ ok: false; error: ModelValidationError } | { ok: true; provider: Provider; modelId: string; client: OpenAI }
|
||||
> {
|
||||
const modelValidation = await validateModelId(model)
|
||||
if (!modelValidation.valid) {
|
||||
return {
|
||||
ok: false,
|
||||
error: modelValidation.error!
|
||||
}
|
||||
}
|
||||
|
||||
const models = await listAllAvailableModels()
|
||||
const provider = modelValidation.provider!
|
||||
|
||||
// Use Map to deduplicate models by their full ID (provider:model_id)
|
||||
const uniqueModels = new Map<string, ModelData>()
|
||||
|
||||
for (const model of models) {
|
||||
const openAIModel = transformModelToOpenAI(model)
|
||||
const fullModelId = openAIModel.id // This is already in format "provider:model_id"
|
||||
|
||||
// Only add if not already present (first occurrence wins)
|
||||
if (!uniqueModels.has(fullModelId)) {
|
||||
uniqueModels.set(fullModelId, {
|
||||
...openAIModel,
|
||||
provider_id: model.provider,
|
||||
model_id: model.id,
|
||||
name: model.name
|
||||
})
|
||||
} else {
|
||||
logger.debug(`Skipping duplicate model: ${fullModelId}`)
|
||||
if (provider.type !== 'openai') {
|
||||
return {
|
||||
ok: false,
|
||||
error: {
|
||||
type: 'unsupported_provider_type',
|
||||
message: `Provider '${provider.id}' of type '${provider.type}' is not supported for OpenAI chat completions`,
|
||||
code: 'unsupported_provider_type'
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const modelData = Array.from(uniqueModels.values())
|
||||
const modelId = modelValidation.modelId!
|
||||
|
||||
logger.info(`Successfully retrieved ${modelData.length} unique models from ${models.length} total models`)
|
||||
const client = new OpenAI({
|
||||
baseURL: provider.apiHost,
|
||||
apiKey: provider.apiKey
|
||||
})
|
||||
|
||||
if (models.length > modelData.length) {
|
||||
logger.debug(`Filtered out ${models.length - modelData.length} duplicate models`)
|
||||
return {
|
||||
ok: true,
|
||||
provider,
|
||||
modelId,
|
||||
client
|
||||
}
|
||||
}
|
||||
|
||||
async prepareRequest(request: ChatCompletionCreateParams, stream: boolean): Promise<PrepareRequestResult> {
|
||||
const requestValidation = this.validateRequest(request)
|
||||
if (!requestValidation.isValid) {
|
||||
return {
|
||||
status: 'validation_error',
|
||||
errors: requestValidation.errors
|
||||
}
|
||||
}
|
||||
|
||||
return modelData
|
||||
} catch (error: any) {
|
||||
logger.error('Error getting models:', error)
|
||||
return []
|
||||
const providerContext = await this.resolveProviderContext(request.model!)
|
||||
if (!providerContext.ok) {
|
||||
return {
|
||||
status: 'model_error',
|
||||
error: providerContext.error
|
||||
}
|
||||
}
|
||||
|
||||
const { provider, modelId, client } = providerContext
|
||||
|
||||
logger.debug('Model validation successful', {
|
||||
provider: provider.id,
|
||||
providerType: provider.type,
|
||||
modelId,
|
||||
fullModelId: request.model
|
||||
})
|
||||
|
||||
return {
|
||||
status: 'ok',
|
||||
provider,
|
||||
modelId,
|
||||
client,
|
||||
providerRequest: stream
|
||||
? {
|
||||
...request,
|
||||
model: modelId,
|
||||
stream: true as const
|
||||
}
|
||||
: {
|
||||
...request,
|
||||
model: modelId,
|
||||
stream: false as const
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
validateRequest(request: ChatCompletionCreateParams): ValidationResult {
|
||||
const errors: string[] = []
|
||||
|
||||
// Validate model
|
||||
if (!request.model) {
|
||||
errors.push('Model is required')
|
||||
} else if (typeof request.model !== 'string') {
|
||||
errors.push('Model must be a string')
|
||||
} else if (!request.model.includes(':')) {
|
||||
errors.push('Model must be in format "provider:model_id"')
|
||||
}
|
||||
|
||||
// Validate messages
|
||||
if (!request.messages) {
|
||||
errors.push('Messages array is required')
|
||||
@@ -98,17 +147,6 @@ export class ChatCompletionService {
|
||||
}
|
||||
|
||||
// Validate optional parameters
|
||||
if (request.temperature !== undefined) {
|
||||
if (typeof request.temperature !== 'number' || request.temperature < 0 || request.temperature > 2) {
|
||||
errors.push('Temperature must be a number between 0 and 2')
|
||||
}
|
||||
}
|
||||
|
||||
if (request.max_tokens !== undefined) {
|
||||
if (typeof request.max_tokens !== 'number' || request.max_tokens < 1) {
|
||||
errors.push('max_tokens must be a positive number')
|
||||
}
|
||||
}
|
||||
|
||||
return {
|
||||
isValid: errors.length === 0,
|
||||
@@ -116,48 +154,30 @@ export class ChatCompletionService {
|
||||
}
|
||||
}
|
||||
|
||||
async processCompletion(request: ChatCompletionCreateParams): Promise<OpenAI.Chat.Completions.ChatCompletion> {
|
||||
async processCompletion(request: ChatCompletionCreateParams): Promise<{
|
||||
provider: Provider
|
||||
modelId: string
|
||||
response: OpenAI.Chat.Completions.ChatCompletion
|
||||
}> {
|
||||
try {
|
||||
logger.info('Processing chat completion request:', {
|
||||
logger.debug('Processing chat completion request', {
|
||||
model: request.model,
|
||||
messageCount: request.messages.length,
|
||||
stream: request.stream
|
||||
})
|
||||
|
||||
// Validate request
|
||||
const validation = this.validateRequest(request)
|
||||
if (!validation.isValid) {
|
||||
throw new Error(`Request validation failed: ${validation.errors.join(', ')}`)
|
||||
const preparation = await this.prepareRequest(request, false)
|
||||
if (preparation.status === 'validation_error') {
|
||||
throw new ChatCompletionValidationError(preparation.errors)
|
||||
}
|
||||
|
||||
// Get provider for the model
|
||||
const provider = await getProviderByModel(request.model!)
|
||||
if (!provider) {
|
||||
throw new Error(`Provider not found for model: ${request.model}`)
|
||||
if (preparation.status === 'model_error') {
|
||||
throw new ChatCompletionModelError(preparation.error)
|
||||
}
|
||||
|
||||
// Validate provider
|
||||
if (!validateProvider(provider)) {
|
||||
throw new Error(`Provider validation failed for: ${provider.id}`)
|
||||
}
|
||||
const { provider, modelId, client, providerRequest } = preparation
|
||||
|
||||
// Extract model ID from the full model string
|
||||
const modelId = getRealProviderModel(request.model)
|
||||
|
||||
// Create OpenAI client for the provider
|
||||
const client = new OpenAI({
|
||||
baseURL: provider.apiHost,
|
||||
apiKey: provider.apiKey
|
||||
})
|
||||
|
||||
// Prepare request with the actual model ID
|
||||
const providerRequest = {
|
||||
...request,
|
||||
model: modelId,
|
||||
stream: false
|
||||
}
|
||||
|
||||
logger.debug('Sending request to provider:', {
|
||||
logger.debug('Sending request to provider', {
|
||||
provider: provider.id,
|
||||
model: modelId,
|
||||
apiHost: provider.apiHost
|
||||
@@ -165,71 +185,71 @@ export class ChatCompletionService {
|
||||
|
||||
const response = (await client.chat.completions.create(providerRequest)) as OpenAI.Chat.Completions.ChatCompletion
|
||||
|
||||
logger.info('Successfully processed chat completion')
|
||||
return response
|
||||
logger.info('Chat completion processed', {
|
||||
modelId,
|
||||
provider: provider.id
|
||||
})
|
||||
return {
|
||||
provider,
|
||||
modelId,
|
||||
response
|
||||
}
|
||||
} catch (error: any) {
|
||||
logger.error('Error processing chat completion:', error)
|
||||
logger.error('Error processing chat completion', {
|
||||
error,
|
||||
model: request.model
|
||||
})
|
||||
throw error
|
||||
}
|
||||
}
|
||||
|
||||
async *processStreamingCompletion(
|
||||
request: ChatCompletionCreateParams
|
||||
): AsyncIterable<OpenAI.Chat.Completions.ChatCompletionChunk> {
|
||||
async processStreamingCompletion(request: ChatCompletionCreateParams): Promise<{
|
||||
provider: Provider
|
||||
modelId: string
|
||||
stream: AsyncIterable<OpenAI.Chat.Completions.ChatCompletionChunk>
|
||||
}> {
|
||||
try {
|
||||
logger.info('Processing streaming chat completion request:', {
|
||||
logger.debug('Processing streaming chat completion request', {
|
||||
model: request.model,
|
||||
messageCount: request.messages.length
|
||||
})
|
||||
|
||||
// Validate request
|
||||
const validation = this.validateRequest(request)
|
||||
if (!validation.isValid) {
|
||||
throw new Error(`Request validation failed: ${validation.errors.join(', ')}`)
|
||||
const preparation = await this.prepareRequest(request, true)
|
||||
if (preparation.status === 'validation_error') {
|
||||
throw new ChatCompletionValidationError(preparation.errors)
|
||||
}
|
||||
|
||||
// Get provider for the model
|
||||
const provider = await getProviderByModel(request.model!)
|
||||
if (!provider) {
|
||||
throw new Error(`Provider not found for model: ${request.model}`)
|
||||
if (preparation.status === 'model_error') {
|
||||
throw new ChatCompletionModelError(preparation.error)
|
||||
}
|
||||
|
||||
// Validate provider
|
||||
if (!validateProvider(provider)) {
|
||||
throw new Error(`Provider validation failed for: ${provider.id}`)
|
||||
}
|
||||
const { provider, modelId, client, providerRequest } = preparation
|
||||
|
||||
// Extract model ID from the full model string
|
||||
const modelId = getRealProviderModel(request.model)
|
||||
|
||||
// Create OpenAI client for the provider
|
||||
const client = new OpenAI({
|
||||
baseURL: provider.apiHost,
|
||||
apiKey: provider.apiKey
|
||||
})
|
||||
|
||||
// Prepare streaming request
|
||||
const streamingRequest = {
|
||||
...request,
|
||||
model: modelId,
|
||||
stream: true as const
|
||||
}
|
||||
|
||||
logger.debug('Sending streaming request to provider:', {
|
||||
logger.debug('Sending streaming request to provider', {
|
||||
provider: provider.id,
|
||||
model: modelId,
|
||||
apiHost: provider.apiHost
|
||||
})
|
||||
|
||||
const stream = await client.chat.completions.create(streamingRequest)
|
||||
const streamRequest = providerRequest as ChatCompletionCreateParamsStreaming
|
||||
const stream = (await client.chat.completions.create(
|
||||
streamRequest
|
||||
)) as AsyncIterable<OpenAI.Chat.Completions.ChatCompletionChunk>
|
||||
|
||||
for await (const chunk of stream) {
|
||||
yield chunk
|
||||
logger.info('Streaming chat completion started', {
|
||||
modelId,
|
||||
provider: provider.id
|
||||
})
|
||||
return {
|
||||
provider,
|
||||
modelId,
|
||||
stream
|
||||
}
|
||||
|
||||
logger.info('Successfully completed streaming chat completion')
|
||||
} catch (error: any) {
|
||||
logger.error('Error processing streaming chat completion:', error)
|
||||
logger.error('Error processing streaming chat completion', {
|
||||
error,
|
||||
model: request.model
|
||||
})
|
||||
throw error
|
||||
}
|
||||
}
|
||||
|
||||
@@ -9,8 +9,7 @@ import type { Request, Response } from 'express'
|
||||
import type { IncomingMessage, ServerResponse } from 'http'
|
||||
|
||||
import { loggerService } from '../../services/LoggerService'
|
||||
import { reduxService } from '../../services/ReduxService'
|
||||
import { getMcpServerById } from '../utils/mcp'
|
||||
import { getMcpServerById, getMCPServersFromRedux } from '../utils/mcp'
|
||||
|
||||
const logger = loggerService.withContext('MCPApiService')
|
||||
const transports: Record<string, StreamableHTTPServerTransport> = {}
|
||||
@@ -46,42 +45,18 @@ class MCPApiService extends EventEmitter {
|
||||
constructor() {
|
||||
super()
|
||||
this.initMcpServer()
|
||||
logger.silly('MCPApiService initialized')
|
||||
logger.debug('MCPApiService initialized')
|
||||
}
|
||||
|
||||
private initMcpServer() {
|
||||
this.transport.onmessage = this.onMessage
|
||||
}
|
||||
|
||||
/**
|
||||
* Get servers directly from Redux store
|
||||
*/
|
||||
private async getServersFromRedux(): Promise<MCPServer[]> {
|
||||
try {
|
||||
logger.silly('Getting servers from Redux store')
|
||||
|
||||
// Try to get from cache first (faster)
|
||||
const cachedServers = reduxService.selectSync<MCPServer[]>('state.mcp.servers')
|
||||
if (cachedServers && Array.isArray(cachedServers)) {
|
||||
logger.silly(`Found ${cachedServers.length} servers in Redux cache`)
|
||||
return cachedServers
|
||||
}
|
||||
|
||||
// If cache is not available, get fresh data
|
||||
const servers = await reduxService.select<MCPServer[]>('state.mcp.servers')
|
||||
logger.silly(`Fetched ${servers?.length || 0} servers from Redux store`)
|
||||
return servers || []
|
||||
} catch (error: any) {
|
||||
logger.error('Failed to get servers from Redux:', error)
|
||||
return []
|
||||
}
|
||||
}
|
||||
|
||||
// get all activated servers
|
||||
async getAllServers(req: Request): Promise<McpServersResp> {
|
||||
try {
|
||||
const servers = await this.getServersFromRedux()
|
||||
logger.silly(`Returning ${servers.length} servers`)
|
||||
const servers = await getMCPServersFromRedux()
|
||||
logger.debug('Returning servers from Redux', { count: servers.length })
|
||||
const resp: McpServersResp = {
|
||||
servers: {}
|
||||
}
|
||||
@@ -98,7 +73,7 @@ class MCPApiService extends EventEmitter {
|
||||
}
|
||||
return resp
|
||||
} catch (error: any) {
|
||||
logger.error('Failed to get all servers:', error)
|
||||
logger.error('Failed to get all servers', { error })
|
||||
throw new Error('Failed to retrieve servers')
|
||||
}
|
||||
}
|
||||
@@ -106,87 +81,47 @@ class MCPApiService extends EventEmitter {
|
||||
// get server by id
|
||||
async getServerById(id: string): Promise<MCPServer | null> {
|
||||
try {
|
||||
logger.silly(`getServerById called with id: ${id}`)
|
||||
const servers = await this.getServersFromRedux()
|
||||
logger.debug('getServerById called', { id })
|
||||
const servers = await getMCPServersFromRedux()
|
||||
const server = servers.find((s) => s.id === id)
|
||||
if (!server) {
|
||||
logger.warn(`Server with id ${id} not found`)
|
||||
logger.warn('Server not found', { id })
|
||||
return null
|
||||
}
|
||||
logger.silly(`Returning server with id ${id}`)
|
||||
logger.debug('Returning server', { id })
|
||||
return server
|
||||
} catch (error: any) {
|
||||
logger.error(`Failed to get server with id ${id}:`, error)
|
||||
logger.error('Failed to get server', { id, error })
|
||||
throw new Error('Failed to retrieve server')
|
||||
}
|
||||
}
|
||||
|
||||
async getServerInfo(id: string): Promise<any> {
|
||||
try {
|
||||
logger.silly(`getServerInfo called with id: ${id}`)
|
||||
const server = await this.getServerById(id)
|
||||
if (!server) {
|
||||
logger.warn(`Server with id ${id} not found`)
|
||||
logger.warn('Server not found while fetching info', { id })
|
||||
return null
|
||||
}
|
||||
logger.silly(`Returning server info for id ${id}`)
|
||||
|
||||
const client = await mcpService.initClient(server)
|
||||
const tools = await client.listTools()
|
||||
|
||||
logger.info(`Server with id ${id} info:`, { tools: JSON.stringify(tools) })
|
||||
|
||||
// const [version, tools, prompts, resources] = await Promise.all([
|
||||
// () => {
|
||||
// try {
|
||||
// return client.getServerVersion()
|
||||
// } catch (error) {
|
||||
// logger.error(`Failed to get server version for id ${id}:`, { error: error })
|
||||
// return '1.0.0'
|
||||
// }
|
||||
// },
|
||||
// (() => {
|
||||
// try {
|
||||
// return client.listTools()
|
||||
// } catch (error) {
|
||||
// logger.error(`Failed to list tools for id ${id}:`, { error: error })
|
||||
// return []
|
||||
// }
|
||||
// })(),
|
||||
// (() => {
|
||||
// try {
|
||||
// return client.listPrompts()
|
||||
// } catch (error) {
|
||||
// logger.error(`Failed to list prompts for id ${id}:`, { error: error })
|
||||
// return []
|
||||
// }
|
||||
// })(),
|
||||
// (() => {
|
||||
// try {
|
||||
// return client.listResources()
|
||||
// } catch (error) {
|
||||
// logger.error(`Failed to list resources for id ${id}:`, { error: error })
|
||||
// return []
|
||||
// }
|
||||
// })()
|
||||
// ])
|
||||
|
||||
return {
|
||||
id: server.id,
|
||||
name: server.name,
|
||||
type: server.type,
|
||||
description: server.description,
|
||||
tools
|
||||
tools: tools.tools
|
||||
}
|
||||
} catch (error: any) {
|
||||
logger.error(`Failed to get server info with id ${id}:`, error)
|
||||
logger.error('Failed to get server info', { id, error })
|
||||
throw new Error('Failed to retrieve server info')
|
||||
}
|
||||
}
|
||||
|
||||
async handleRequest(req: Request, res: Response, server: MCPServer) {
|
||||
const sessionId = req.headers['mcp-session-id'] as string | undefined
|
||||
logger.silly(`Handling request for server with sessionId ${sessionId}`)
|
||||
logger.debug('Handling MCP request', { sessionId, serverId: server.id })
|
||||
let transport: StreamableHTTPServerTransport
|
||||
if (sessionId && transports[sessionId]) {
|
||||
transport = transports[sessionId]
|
||||
@@ -199,7 +134,7 @@ class MCPApiService extends EventEmitter {
|
||||
})
|
||||
|
||||
transport.onclose = () => {
|
||||
logger.info(`Transport for sessionId ${sessionId} closed`)
|
||||
logger.info('Transport closed', { sessionId })
|
||||
if (transport.sessionId) {
|
||||
delete transports[transport.sessionId]
|
||||
}
|
||||
@@ -234,12 +169,15 @@ class MCPApiService extends EventEmitter {
|
||||
}
|
||||
}
|
||||
|
||||
logger.info(`Request body`, { rawBody: req.body, messages: JSON.stringify(messages) })
|
||||
logger.debug('Dispatching MCP request', {
|
||||
sessionId: transport.sessionId ?? sessionId,
|
||||
messageCount: messages.length
|
||||
})
|
||||
await transport.handleRequest(req as IncomingMessage, res as ServerResponse, messages)
|
||||
}
|
||||
|
||||
private onMessage(message: JSONRPCMessage, extra?: MessageExtraInfo) {
|
||||
logger.info(`Received message: ${JSON.stringify(message)}`, extra)
|
||||
logger.debug('Received MCP message', { message, extra })
|
||||
// Handle message here
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,321 @@
|
||||
import Anthropic from '@anthropic-ai/sdk'
|
||||
import { MessageCreateParams, MessageStreamEvent } from '@anthropic-ai/sdk/resources'
|
||||
import { loggerService } from '@logger'
|
||||
import anthropicService from '@main/services/AnthropicService'
|
||||
import { buildClaudeCodeSystemMessage, getSdkClient } from '@shared/anthropic'
|
||||
import { Provider } from '@types'
|
||||
import { Response } from 'express'
|
||||
|
||||
const logger = loggerService.withContext('MessagesService')
|
||||
const EXCLUDED_FORWARD_HEADERS: ReadonlySet<string> = new Set([
|
||||
'host',
|
||||
'x-api-key',
|
||||
'authorization',
|
||||
'sentry-trace',
|
||||
'baggage',
|
||||
'content-length',
|
||||
'connection'
|
||||
])
|
||||
|
||||
export interface ValidationResult {
|
||||
isValid: boolean
|
||||
errors: string[]
|
||||
}
|
||||
|
||||
export interface ErrorResponse {
|
||||
type: 'error'
|
||||
error: {
|
||||
type: string
|
||||
message: string
|
||||
requestId?: string
|
||||
}
|
||||
}
|
||||
|
||||
export interface StreamConfig {
|
||||
response: Response
|
||||
onChunk?: (chunk: MessageStreamEvent) => void
|
||||
onError?: (error: any) => void
|
||||
onComplete?: () => void
|
||||
}
|
||||
|
||||
export interface ProcessMessageOptions {
|
||||
provider: Provider
|
||||
request: MessageCreateParams
|
||||
extraHeaders?: Record<string, string | string[]>
|
||||
modelId?: string
|
||||
}
|
||||
|
||||
export interface ProcessMessageResult {
|
||||
client: Anthropic
|
||||
anthropicRequest: MessageCreateParams
|
||||
}
|
||||
|
||||
export class MessagesService {
|
||||
validateRequest(request: MessageCreateParams): ValidationResult {
|
||||
// TODO: Implement comprehensive request validation
|
||||
const errors: string[] = []
|
||||
|
||||
if (!request.model || typeof request.model !== 'string') {
|
||||
errors.push('Model is required')
|
||||
}
|
||||
|
||||
if (typeof request.max_tokens !== 'number' || !Number.isFinite(request.max_tokens) || request.max_tokens < 1) {
|
||||
errors.push('max_tokens is required and must be a positive number')
|
||||
}
|
||||
|
||||
if (!request.messages || !Array.isArray(request.messages) || request.messages.length === 0) {
|
||||
errors.push('messages is required and must be a non-empty array')
|
||||
} else {
|
||||
request.messages.forEach((message, index) => {
|
||||
if (!message || typeof message !== 'object') {
|
||||
errors.push(`messages[${index}] must be an object`)
|
||||
return
|
||||
}
|
||||
|
||||
if (!('role' in message) || typeof message.role !== 'string' || message.role.trim().length === 0) {
|
||||
errors.push(`messages[${index}].role is required`)
|
||||
}
|
||||
|
||||
const content: unknown = message.content
|
||||
if (content === undefined || content === null) {
|
||||
errors.push(`messages[${index}].content is required`)
|
||||
return
|
||||
}
|
||||
|
||||
if (typeof content === 'string' && content.trim().length === 0) {
|
||||
errors.push(`messages[${index}].content cannot be empty`)
|
||||
} else if (Array.isArray(content) && content.length === 0) {
|
||||
errors.push(`messages[${index}].content must include at least one item when using an array`)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
return {
|
||||
isValid: errors.length === 0,
|
||||
errors
|
||||
}
|
||||
}
|
||||
|
||||
async getClient(provider: Provider, extraHeaders?: Record<string, string | string[]>): Promise<Anthropic> {
|
||||
// Create Anthropic client for the provider
|
||||
if (provider.authType === 'oauth') {
|
||||
const oauthToken = await anthropicService.getValidAccessToken()
|
||||
return getSdkClient(provider, oauthToken, extraHeaders)
|
||||
}
|
||||
return getSdkClient(provider, null, extraHeaders)
|
||||
}
|
||||
|
||||
prepareHeaders(headers: Record<string, string | string[] | undefined>): Record<string, string | string[]> {
|
||||
const extraHeaders: Record<string, string | string[]> = {}
|
||||
|
||||
for (const [key, value] of Object.entries(headers)) {
|
||||
if (value === undefined) {
|
||||
continue
|
||||
}
|
||||
|
||||
const normalizedKey = key.toLowerCase()
|
||||
if (EXCLUDED_FORWARD_HEADERS.has(normalizedKey)) {
|
||||
continue
|
||||
}
|
||||
|
||||
extraHeaders[normalizedKey] = value
|
||||
}
|
||||
|
||||
return extraHeaders
|
||||
}
|
||||
|
||||
createAnthropicRequest(request: MessageCreateParams, provider: Provider, modelId?: string): MessageCreateParams {
|
||||
const anthropicRequest: MessageCreateParams = {
|
||||
...request,
|
||||
stream: !!request.stream
|
||||
}
|
||||
|
||||
// Override model if provided
|
||||
if (modelId) {
|
||||
anthropicRequest.model = modelId
|
||||
}
|
||||
|
||||
// Add Claude Code system message for OAuth providers
|
||||
if (provider.type === 'anthropic' && provider.authType === 'oauth') {
|
||||
anthropicRequest.system = buildClaudeCodeSystemMessage(request.system)
|
||||
}
|
||||
|
||||
return anthropicRequest
|
||||
}
|
||||
|
||||
async handleStreaming(
|
||||
client: Anthropic,
|
||||
request: MessageCreateParams,
|
||||
config: StreamConfig,
|
||||
provider: Provider
|
||||
): Promise<void> {
|
||||
const { response, onChunk, onError, onComplete } = config
|
||||
|
||||
// Set streaming headers
|
||||
response.setHeader('Content-Type', 'text/event-stream; charset=utf-8')
|
||||
response.setHeader('Cache-Control', 'no-cache, no-transform')
|
||||
response.setHeader('Connection', 'keep-alive')
|
||||
response.setHeader('X-Accel-Buffering', 'no')
|
||||
response.flushHeaders()
|
||||
|
||||
const flushableResponse = response as Response & { flush?: () => void }
|
||||
const flushStream = () => {
|
||||
if (typeof flushableResponse.flush !== 'function') {
|
||||
return
|
||||
}
|
||||
try {
|
||||
flushableResponse.flush()
|
||||
} catch (flushError: unknown) {
|
||||
logger.warn('Failed to flush streaming response', { error: flushError })
|
||||
}
|
||||
}
|
||||
|
||||
const writeSse = (eventType: string | undefined, payload: unknown) => {
|
||||
if (response.writableEnded || response.destroyed) {
|
||||
return
|
||||
}
|
||||
|
||||
if (eventType) {
|
||||
response.write(`event: ${eventType}\n`)
|
||||
}
|
||||
|
||||
const data = typeof payload === 'string' ? payload : JSON.stringify(payload)
|
||||
response.write(`data: ${data}\n\n`)
|
||||
flushStream()
|
||||
}
|
||||
|
||||
try {
|
||||
const stream = client.messages.stream(request)
|
||||
for await (const chunk of stream) {
|
||||
if (response.writableEnded || response.destroyed) {
|
||||
logger.warn('Streaming response ended before stream completion', {
|
||||
provider: provider.id,
|
||||
model: request.model
|
||||
})
|
||||
break
|
||||
}
|
||||
|
||||
writeSse(chunk.type, chunk)
|
||||
|
||||
if (onChunk) {
|
||||
onChunk(chunk)
|
||||
}
|
||||
}
|
||||
writeSse(undefined, '[DONE]')
|
||||
|
||||
if (onComplete) {
|
||||
onComplete()
|
||||
}
|
||||
} catch (streamError: any) {
|
||||
logger.error('Stream error', {
|
||||
error: streamError,
|
||||
provider: provider.id,
|
||||
model: request.model,
|
||||
apiHost: provider.apiHost,
|
||||
anthropicApiHost: provider.anthropicApiHost
|
||||
})
|
||||
writeSse(undefined, {
|
||||
type: 'error',
|
||||
error: {
|
||||
type: 'api_error',
|
||||
message: 'Stream processing error'
|
||||
}
|
||||
})
|
||||
|
||||
if (onError) {
|
||||
onError(streamError)
|
||||
}
|
||||
} finally {
|
||||
if (!response.writableEnded) {
|
||||
response.end()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
transformError(error: any): { statusCode: number; errorResponse: ErrorResponse } {
|
||||
let statusCode = 500
|
||||
let errorType = 'api_error'
|
||||
let errorMessage = 'Internal server error'
|
||||
|
||||
const anthropicStatus = typeof error?.status === 'number' ? error.status : undefined
|
||||
const anthropicError = error?.error
|
||||
|
||||
if (anthropicStatus) {
|
||||
statusCode = anthropicStatus
|
||||
}
|
||||
|
||||
if (anthropicError?.type) {
|
||||
errorType = anthropicError.type
|
||||
}
|
||||
|
||||
if (anthropicError?.message) {
|
||||
errorMessage = anthropicError.message
|
||||
} else if (error instanceof Error && error.message) {
|
||||
errorMessage = error.message
|
||||
}
|
||||
|
||||
// Infer error type from message if not from Anthropic API
|
||||
if (!anthropicStatus && error instanceof Error) {
|
||||
const errorMessageText = error.message ?? ''
|
||||
|
||||
if (errorMessageText.includes('API key') || errorMessageText.includes('authentication')) {
|
||||
statusCode = 401
|
||||
errorType = 'authentication_error'
|
||||
} else if (errorMessageText.includes('rate limit') || errorMessageText.includes('quota')) {
|
||||
statusCode = 429
|
||||
errorType = 'rate_limit_error'
|
||||
} else if (errorMessageText.includes('timeout') || errorMessageText.includes('connection')) {
|
||||
statusCode = 502
|
||||
errorType = 'api_error'
|
||||
} else if (errorMessageText.includes('validation') || errorMessageText.includes('invalid')) {
|
||||
statusCode = 400
|
||||
errorType = 'invalid_request_error'
|
||||
}
|
||||
}
|
||||
|
||||
const safeErrorMessage =
|
||||
typeof errorMessage === 'string' && errorMessage.length > 0 ? errorMessage : 'Internal server error'
|
||||
|
||||
return {
|
||||
statusCode,
|
||||
errorResponse: {
|
||||
type: 'error',
|
||||
error: {
|
||||
type: errorType,
|
||||
message: safeErrorMessage,
|
||||
requestId: error?.request_id
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async processMessage(options: ProcessMessageOptions): Promise<ProcessMessageResult> {
|
||||
const { provider, request, extraHeaders, modelId } = options
|
||||
|
||||
const client = await this.getClient(provider, extraHeaders)
|
||||
const anthropicRequest = this.createAnthropicRequest(request, provider, modelId)
|
||||
|
||||
const messageCount = Array.isArray(request.messages) ? request.messages.length : 0
|
||||
|
||||
logger.info('Processing anthropic messages request', {
|
||||
provider: provider.id,
|
||||
apiHost: provider.apiHost,
|
||||
anthropicApiHost: provider.anthropicApiHost,
|
||||
model: anthropicRequest.model,
|
||||
stream: !!anthropicRequest.stream,
|
||||
// systemPrompt: JSON.stringify(!!request.system),
|
||||
// messages: JSON.stringify(request.messages),
|
||||
messageCount,
|
||||
toolCount: Array.isArray(request.tools) ? request.tools.length : 0
|
||||
})
|
||||
|
||||
// Return client and request for route layer to handle streaming/non-streaming
|
||||
return {
|
||||
client,
|
||||
anthropicRequest
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Export singleton instance
|
||||
export const messagesService = new MessagesService()
|
||||
@@ -0,0 +1,108 @@
|
||||
import { ApiModel, ApiModelsFilter, ApiModelsResponse } from '../../../renderer/src/types/apiModels'
|
||||
import { loggerService } from '../../services/LoggerService'
|
||||
import { getAvailableProviders, listAllAvailableModels, transformModelToOpenAI } from '../utils'
|
||||
|
||||
const logger = loggerService.withContext('ModelsService')
|
||||
|
||||
// Re-export for backward compatibility
|
||||
|
||||
export type ModelsFilter = ApiModelsFilter
|
||||
|
||||
export class ModelsService {
|
||||
async getModels(filter: ModelsFilter): Promise<ApiModelsResponse> {
|
||||
try {
|
||||
logger.debug('Getting available models from providers', { filter })
|
||||
|
||||
let providers = await getAvailableProviders()
|
||||
|
||||
if (filter.providerType === 'anthropic') {
|
||||
providers = providers.filter(
|
||||
(p) => p.type === 'anthropic' || (p.anthropicApiHost !== undefined && p.anthropicApiHost.trim() !== '')
|
||||
)
|
||||
}
|
||||
|
||||
const models = await listAllAvailableModels(providers)
|
||||
// Use Map to deduplicate models by their full ID (provider:model_id)
|
||||
const uniqueModels = new Map<string, ApiModel>()
|
||||
|
||||
for (const model of models) {
|
||||
const provider = providers.find((p) => p.id === model.provider)
|
||||
logger.debug(`Processing model ${model.id} from provider ${model.provider}`, {
|
||||
isAnthropicModel: provider?.isAnthropicModel
|
||||
})
|
||||
if (
|
||||
!provider ||
|
||||
(filter.providerType === 'anthropic' && provider.isAnthropicModel && !provider.isAnthropicModel(model))
|
||||
) {
|
||||
continue
|
||||
}
|
||||
// Special case: For "aihubmix", it should be covered by above condition, but just in case
|
||||
if (provider.id === 'aihubmix' && filter.providerType === 'anthropic' && !model.id.includes('claude')) {
|
||||
continue
|
||||
}
|
||||
|
||||
const openAIModel = transformModelToOpenAI(model, provider)
|
||||
const fullModelId = openAIModel.id // This is already in format "provider:model_id"
|
||||
|
||||
// Only add if not already present (first occurrence wins)
|
||||
if (!uniqueModels.has(fullModelId)) {
|
||||
uniqueModels.set(fullModelId, openAIModel)
|
||||
} else {
|
||||
logger.debug(`Skipping duplicate model: ${fullModelId}`)
|
||||
}
|
||||
}
|
||||
|
||||
let modelData = Array.from(uniqueModels.values())
|
||||
const total = modelData.length
|
||||
|
||||
// Apply pagination
|
||||
const offset = filter?.offset || 0
|
||||
const limit = filter?.limit
|
||||
|
||||
if (limit !== undefined) {
|
||||
modelData = modelData.slice(offset, offset + limit)
|
||||
logger.debug(
|
||||
`Applied pagination: offset=${offset}, limit=${limit}, showing ${modelData.length} of ${total} models`
|
||||
)
|
||||
} else if (offset > 0) {
|
||||
modelData = modelData.slice(offset)
|
||||
logger.debug(`Applied offset: offset=${offset}, showing ${modelData.length} of ${total} models`)
|
||||
}
|
||||
|
||||
logger.info('Models retrieved', {
|
||||
returned: modelData.length,
|
||||
discovered: models.length,
|
||||
filter
|
||||
})
|
||||
|
||||
if (models.length > total) {
|
||||
logger.debug(`Filtered out ${models.length - total} models after deduplication and filtering`)
|
||||
}
|
||||
|
||||
const response: ApiModelsResponse = {
|
||||
object: 'list',
|
||||
data: modelData
|
||||
}
|
||||
|
||||
// Add pagination metadata if applicable
|
||||
if (filter?.limit !== undefined || filter?.offset !== undefined) {
|
||||
response.total = total
|
||||
response.offset = offset
|
||||
if (filter?.limit !== undefined) {
|
||||
response.limit = filter.limit
|
||||
}
|
||||
}
|
||||
|
||||
return response
|
||||
} catch (error: any) {
|
||||
logger.error('Error getting models', { error, filter })
|
||||
return {
|
||||
object: 'list',
|
||||
data: []
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Export singleton instance
|
||||
export const modelsService = new ModelsService()
|
||||
@@ -0,0 +1,64 @@
|
||||
export type StreamAbortHandler = (reason: unknown) => void
|
||||
|
||||
export interface StreamAbortController {
|
||||
abortController: AbortController
|
||||
registerAbortHandler: (handler: StreamAbortHandler) => void
|
||||
clearAbortTimeout: () => void
|
||||
}
|
||||
|
||||
export const STREAM_TIMEOUT_REASON = 'stream timeout'
|
||||
|
||||
interface CreateStreamAbortControllerOptions {
|
||||
timeoutMs: number
|
||||
}
|
||||
|
||||
export const createStreamAbortController = (options: CreateStreamAbortControllerOptions): StreamAbortController => {
|
||||
const { timeoutMs } = options
|
||||
const abortController = new AbortController()
|
||||
const signal = abortController.signal
|
||||
|
||||
let timeoutId: NodeJS.Timeout | undefined
|
||||
let abortHandler: StreamAbortHandler | undefined
|
||||
|
||||
const clearAbortTimeout = () => {
|
||||
if (!timeoutId) {
|
||||
return
|
||||
}
|
||||
clearTimeout(timeoutId)
|
||||
timeoutId = undefined
|
||||
}
|
||||
|
||||
const handleAbort = () => {
|
||||
clearAbortTimeout()
|
||||
|
||||
if (!abortHandler) {
|
||||
return
|
||||
}
|
||||
|
||||
abortHandler(signal.reason)
|
||||
}
|
||||
|
||||
signal.addEventListener('abort', handleAbort, { once: true })
|
||||
|
||||
const registerAbortHandler = (handler: StreamAbortHandler) => {
|
||||
abortHandler = handler
|
||||
|
||||
if (signal.aborted) {
|
||||
abortHandler(signal.reason)
|
||||
}
|
||||
}
|
||||
|
||||
if (timeoutMs > 0) {
|
||||
timeoutId = setTimeout(() => {
|
||||
if (!signal.aborted) {
|
||||
abortController.abort(STREAM_TIMEOUT_REASON)
|
||||
}
|
||||
}, timeoutMs)
|
||||
}
|
||||
|
||||
return {
|
||||
abortController,
|
||||
registerAbortHandler,
|
||||
clearAbortTimeout
|
||||
}
|
||||
}
|
||||
@@ -1,46 +1,60 @@
|
||||
import { CacheService } from '@main/services/CacheService'
|
||||
import { loggerService } from '@main/services/LoggerService'
|
||||
import { reduxService } from '@main/services/ReduxService'
|
||||
import type { Model, Provider } from '@types'
|
||||
import type { ApiModel, Model, Provider } from '@types'
|
||||
|
||||
const logger = loggerService.withContext('ApiServerUtils')
|
||||
|
||||
// OpenAI compatible model format
|
||||
export interface OpenAICompatibleModel {
|
||||
id: string
|
||||
object: 'model'
|
||||
created: number
|
||||
owned_by: string
|
||||
provider?: string
|
||||
provider_model_id?: string
|
||||
}
|
||||
// Cache configuration
|
||||
const PROVIDERS_CACHE_KEY = 'api-server:providers'
|
||||
const PROVIDERS_CACHE_TTL = 10 * 1000 // 10 seconds
|
||||
|
||||
export async function getAvailableProviders(): Promise<Provider[]> {
|
||||
try {
|
||||
// Wait for store to be ready before accessing providers
|
||||
// Try to get from cache first (faster)
|
||||
const cachedSupportedProviders = CacheService.get<Provider[]>(PROVIDERS_CACHE_KEY)
|
||||
if (cachedSupportedProviders && cachedSupportedProviders.length > 0) {
|
||||
logger.debug('Providers resolved from cache', {
|
||||
count: cachedSupportedProviders.length
|
||||
})
|
||||
return cachedSupportedProviders
|
||||
}
|
||||
|
||||
// If cache is not available, get fresh data from Redux
|
||||
const providers = await reduxService.select('state.llm.providers')
|
||||
if (!providers || !Array.isArray(providers)) {
|
||||
logger.warn('No providers found in Redux store, returning empty array')
|
||||
logger.warn('No providers found in Redux store')
|
||||
return []
|
||||
}
|
||||
|
||||
// Only support OpenAI type providers for API server
|
||||
const openAIProviders = providers.filter((p: Provider) => p.enabled && p.type === 'openai')
|
||||
// Support OpenAI and Anthropic type providers for API server
|
||||
const supportedProviders = providers.filter(
|
||||
(p: Provider) => p.enabled && (p.type === 'openai' || p.type === 'anthropic')
|
||||
)
|
||||
|
||||
logger.info(`Filtered to ${openAIProviders.length} OpenAI providers from ${providers.length} total providers`)
|
||||
// Cache the filtered results
|
||||
CacheService.set(PROVIDERS_CACHE_KEY, supportedProviders, PROVIDERS_CACHE_TTL)
|
||||
|
||||
return openAIProviders
|
||||
logger.info('Providers filtered', {
|
||||
supported: supportedProviders.length,
|
||||
total: providers.length
|
||||
})
|
||||
|
||||
return supportedProviders
|
||||
} catch (error: any) {
|
||||
logger.error('Failed to get providers from Redux store:', error)
|
||||
logger.error('Failed to get providers from Redux store', { error })
|
||||
return []
|
||||
}
|
||||
}
|
||||
|
||||
export async function listAllAvailableModels(): Promise<Model[]> {
|
||||
export async function listAllAvailableModels(providers?: Provider[]): Promise<Model[]> {
|
||||
try {
|
||||
const providers = await getAvailableProviders()
|
||||
if (!providers) {
|
||||
providers = await getAvailableProviders()
|
||||
}
|
||||
return providers.map((p: Provider) => p.models || []).flat()
|
||||
} catch (error: any) {
|
||||
logger.error('Failed to list available models:', error)
|
||||
logger.error('Failed to list available models', { error })
|
||||
return []
|
||||
}
|
||||
}
|
||||
@@ -48,15 +62,13 @@ export async function listAllAvailableModels(): Promise<Model[]> {
|
||||
export async function getProviderByModel(model: string): Promise<Provider | undefined> {
|
||||
try {
|
||||
if (!model || typeof model !== 'string') {
|
||||
logger.warn(`Invalid model parameter: ${model}`)
|
||||
logger.warn('Invalid model parameter', { model })
|
||||
return undefined
|
||||
}
|
||||
|
||||
// Validate model format first
|
||||
if (!model.includes(':')) {
|
||||
logger.warn(
|
||||
`Invalid model format, must contain ':' separator. Expected format "provider:model_id", got: ${model}`
|
||||
)
|
||||
logger.warn('Invalid model format missing separator', { model })
|
||||
return undefined
|
||||
}
|
||||
|
||||
@@ -64,7 +76,7 @@ export async function getProviderByModel(model: string): Promise<Provider | unde
|
||||
const modelInfo = model.split(':')
|
||||
|
||||
if (modelInfo.length < 2 || modelInfo[0].length === 0 || modelInfo[1].length === 0) {
|
||||
logger.warn(`Invalid model format, expected "provider:model_id" with non-empty parts, got: ${model}`)
|
||||
logger.warn('Invalid model format with empty parts', { model })
|
||||
return undefined
|
||||
}
|
||||
|
||||
@@ -72,16 +84,17 @@ export async function getProviderByModel(model: string): Promise<Provider | unde
|
||||
const provider = providers.find((p: Provider) => p.id === providerId)
|
||||
|
||||
if (!provider) {
|
||||
logger.warn(
|
||||
`Provider '${providerId}' not found or not enabled. Available providers: ${providers.map((p) => p.id).join(', ')}`
|
||||
)
|
||||
logger.warn('Provider not found for model', {
|
||||
providerId,
|
||||
available: providers.map((p) => p.id)
|
||||
})
|
||||
return undefined
|
||||
}
|
||||
|
||||
logger.debug(`Found provider '${providerId}' for model: ${model}`)
|
||||
logger.debug('Provider resolved for model', { providerId, model })
|
||||
return provider
|
||||
} catch (error: any) {
|
||||
logger.error('Failed to get provider by model:', error)
|
||||
logger.error('Failed to get provider by model', { error, model })
|
||||
return undefined
|
||||
}
|
||||
}
|
||||
@@ -96,9 +109,12 @@ export interface ModelValidationError {
|
||||
code: string
|
||||
}
|
||||
|
||||
export async function validateModelId(
|
||||
model: string
|
||||
): Promise<{ valid: boolean; error?: ModelValidationError; provider?: Provider; modelId?: string }> {
|
||||
export async function validateModelId(model: string): Promise<{
|
||||
valid: boolean
|
||||
error?: ModelValidationError
|
||||
provider?: Provider
|
||||
modelId?: string
|
||||
}> {
|
||||
try {
|
||||
if (!model || typeof model !== 'string') {
|
||||
return {
|
||||
@@ -169,7 +185,7 @@ export async function validateModelId(
|
||||
modelId
|
||||
}
|
||||
} catch (error: any) {
|
||||
logger.error('Error validating model ID:', error)
|
||||
logger.error('Error validating model ID', { error, model })
|
||||
return {
|
||||
valid: false,
|
||||
error: {
|
||||
@@ -181,17 +197,47 @@ export async function validateModelId(
|
||||
}
|
||||
}
|
||||
|
||||
export function transformModelToOpenAI(model: Model): OpenAICompatibleModel {
|
||||
export function transformModelToOpenAI(model: Model, provider?: Provider): ApiModel {
|
||||
const providerDisplayName = provider?.name
|
||||
return {
|
||||
id: `${model.provider}:${model.id}`,
|
||||
object: 'model',
|
||||
name: model.name,
|
||||
created: Math.floor(Date.now() / 1000),
|
||||
owned_by: model.owned_by || model.provider,
|
||||
owned_by: model.owned_by || providerDisplayName || model.provider,
|
||||
provider: model.provider,
|
||||
provider_name: providerDisplayName,
|
||||
provider_type: provider?.type,
|
||||
provider_model_id: model.id
|
||||
}
|
||||
}
|
||||
|
||||
export async function getProviderById(providerId: string): Promise<Provider | undefined> {
|
||||
try {
|
||||
if (!providerId || typeof providerId !== 'string') {
|
||||
logger.warn('Invalid provider ID parameter', { providerId })
|
||||
return undefined
|
||||
}
|
||||
|
||||
const providers = await getAvailableProviders()
|
||||
const provider = providers.find((p: Provider) => p.id === providerId)
|
||||
|
||||
if (!provider) {
|
||||
logger.warn('Provider not found by ID', {
|
||||
providerId,
|
||||
available: providers.map((p) => p.id)
|
||||
})
|
||||
return undefined
|
||||
}
|
||||
|
||||
logger.debug('Provider found by ID', { providerId })
|
||||
return provider
|
||||
} catch (error: any) {
|
||||
logger.error('Failed to get provider by ID', { error, providerId })
|
||||
return undefined
|
||||
}
|
||||
}
|
||||
|
||||
export function validateProvider(provider: Provider): boolean {
|
||||
try {
|
||||
if (!provider) {
|
||||
@@ -200,7 +246,7 @@ export function validateProvider(provider: Provider): boolean {
|
||||
|
||||
// Check required fields
|
||||
if (!provider.id || !provider.type || !provider.apiKey || !provider.apiHost) {
|
||||
logger.warn('Provider missing required fields:', {
|
||||
logger.warn('Provider missing required fields', {
|
||||
id: !!provider.id,
|
||||
type: !!provider.type,
|
||||
apiKey: !!provider.apiKey,
|
||||
@@ -211,21 +257,25 @@ export function validateProvider(provider: Provider): boolean {
|
||||
|
||||
// Check if provider is enabled
|
||||
if (!provider.enabled) {
|
||||
logger.debug(`Provider is disabled: ${provider.id}`)
|
||||
logger.debug('Provider is disabled', { providerId: provider.id })
|
||||
return false
|
||||
}
|
||||
|
||||
// Only support OpenAI type providers
|
||||
if (provider.type !== 'openai') {
|
||||
logger.debug(
|
||||
`Provider type '${provider.type}' not supported, only 'openai' type is currently supported: ${provider.id}`
|
||||
)
|
||||
// Support OpenAI and Anthropic type providers
|
||||
if (provider.type !== 'openai' && provider.type !== 'anthropic') {
|
||||
logger.debug('Provider type not supported', {
|
||||
providerId: provider.id,
|
||||
providerType: provider.type
|
||||
})
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
} catch (error: any) {
|
||||
logger.error('Error validating provider:', error)
|
||||
logger.error('Error validating provider', {
|
||||
error,
|
||||
providerId: provider?.id
|
||||
})
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import { CacheService } from '@main/services/CacheService'
|
||||
import mcpService from '@main/services/MCPService'
|
||||
import { Server } from '@modelcontextprotocol/sdk/server/index.js'
|
||||
import type { ListToolsResult } from '@modelcontextprotocol/sdk/types.js'
|
||||
@@ -9,6 +10,10 @@ import { reduxService } from '../../services/ReduxService'
|
||||
|
||||
const logger = loggerService.withContext('MCPApiService')
|
||||
|
||||
// Cache configuration
|
||||
const MCP_SERVERS_CACHE_KEY = 'api-server:mcp-servers'
|
||||
const MCP_SERVERS_CACHE_TTL = 5 * 60 * 1000 // 5 minutes
|
||||
|
||||
const cachedServers: Record<string, Server> = {}
|
||||
|
||||
async function handleListToolsRequest(request: any, extra: any): Promise<ListToolsResult> {
|
||||
@@ -34,20 +39,35 @@ async function handleCallToolRequest(request: any, extra: any): Promise<any> {
|
||||
}
|
||||
|
||||
async function getMcpServerConfigById(id: string): Promise<MCPServer | undefined> {
|
||||
const servers = await getServersFromRedux()
|
||||
const servers = await getMCPServersFromRedux()
|
||||
return servers.find((s) => s.id === id || s.name === id)
|
||||
}
|
||||
|
||||
/**
|
||||
* Get servers directly from Redux store
|
||||
*/
|
||||
async function getServersFromRedux(): Promise<MCPServer[]> {
|
||||
export async function getMCPServersFromRedux(): Promise<MCPServer[]> {
|
||||
try {
|
||||
logger.debug('Getting servers from Redux store')
|
||||
|
||||
// Try to get from cache first (faster)
|
||||
const cachedServers = CacheService.get<MCPServer[]>(MCP_SERVERS_CACHE_KEY)
|
||||
if (cachedServers) {
|
||||
logger.debug('MCP servers resolved from cache', { count: cachedServers.length })
|
||||
return cachedServers
|
||||
}
|
||||
|
||||
// If cache is not available, get fresh data from Redux
|
||||
const servers = await reduxService.select<MCPServer[]>('state.mcp.servers')
|
||||
logger.silly(`Fetched ${servers?.length || 0} servers from Redux store`)
|
||||
return servers || []
|
||||
const serverList = servers || []
|
||||
|
||||
// Cache the results
|
||||
CacheService.set(MCP_SERVERS_CACHE_KEY, serverList, MCP_SERVERS_CACHE_TTL)
|
||||
|
||||
logger.debug('Fetched servers from Redux store', { count: serverList.length })
|
||||
return serverList
|
||||
} catch (error: any) {
|
||||
logger.error('Failed to get servers from Redux:', error)
|
||||
logger.error('Failed to get servers from Redux', { error })
|
||||
return []
|
||||
}
|
||||
}
|
||||
@@ -55,7 +75,7 @@ async function getServersFromRedux(): Promise<MCPServer[]> {
|
||||
export async function getMcpServerById(id: string): Promise<Server> {
|
||||
const server = cachedServers[id]
|
||||
if (!server) {
|
||||
const servers = await getServersFromRedux()
|
||||
const servers = await getMCPServersFromRedux()
|
||||
const mcpServer = servers.find((s) => s.id === id || s.name === id)
|
||||
if (!mcpServer) {
|
||||
throw new Error(`Server not found: ${id}`)
|
||||
@@ -72,6 +92,6 @@ export async function getMcpServerById(id: string): Promise<Server> {
|
||||
cachedServers[id] = newServer
|
||||
return newServer
|
||||
}
|
||||
logger.silly('getMcpServer ', { server: server })
|
||||
logger.debug('Returning cached MCP server', { id, hasHandlers: Boolean(server) })
|
||||
return server
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user