✨ feat: implement comprehensive Claude Code OAuth integration and API enhancements
- Add shared Anthropic utilities package with OAuth and API key client creation - Implement provider-specific message routing alongside existing v1 API - Enhance authentication middleware with priority handling (API key > Bearer token) - Add comprehensive auth middleware test suite with timing attack protection - Update session handling and message transformation for Claude Code integration - Improve error handling and validation across message processing pipeline - Standardize import formatting and code structure across affected modules This establishes the foundation for Claude Code OAuth authentication while maintaining backward compatibility with existing API key authentication methods.
This commit is contained in:
146
packages/shared/anthropic/index.ts
Normal file
146
packages/shared/anthropic/index.ts
Normal file
@@ -0,0 +1,146 @@
|
||||
/**
|
||||
* @fileoverview Shared Anthropic AI client utilities for Cherry Studio
|
||||
*
|
||||
* This module provides functions for creating Anthropic SDK clients with different
|
||||
* authentication methods (OAuth, API key) and building Claude Code system messages.
|
||||
* It supports both standard Anthropic API and Anthropic Vertex AI endpoints.
|
||||
*
|
||||
* This shared module can be used by both main and renderer processes.
|
||||
*/
|
||||
|
||||
import Anthropic from "@anthropic-ai/sdk";
|
||||
import {TextBlockParam} from "@anthropic-ai/sdk/resources";
|
||||
import {Provider} from "@types";
|
||||
|
||||
/**
|
||||
* Creates and configures an Anthropic SDK client based on the provider configuration.
|
||||
*
|
||||
* This function supports two authentication methods:
|
||||
* 1. OAuth: Uses OAuth tokens passed as parameter
|
||||
* 2. API Key: Uses traditional API key authentication
|
||||
*
|
||||
* For OAuth authentication, it includes Claude Code specific headers and beta features.
|
||||
* For API key authentication, it uses the provider's configuration with custom headers.
|
||||
*
|
||||
* @param provider - The provider configuration containing authentication details
|
||||
* @param oauthToken - Optional OAuth token for OAuth authentication
|
||||
* @returns An initialized Anthropic or AnthropicVertex client
|
||||
* @throws Error when OAuth token is not available for OAuth authentication
|
||||
*
|
||||
* @example
|
||||
* ```typescript
|
||||
* // OAuth authentication
|
||||
* const oauthProvider = { authType: 'oauth' };
|
||||
* const oauthClient = getSdkClient(oauthProvider, 'oauth-token-here');
|
||||
*
|
||||
* // API key authentication
|
||||
* const apiKeyProvider = {
|
||||
* authType: 'apikey',
|
||||
* apiKey: 'your-api-key',
|
||||
* apiHost: 'https://api.anthropic.com'
|
||||
* };
|
||||
* const apiKeyClient = getSdkClient(apiKeyProvider);
|
||||
* ```
|
||||
*/
|
||||
export function getSdkClient(provider: Provider, oauthToken?: string | null): Anthropic {
|
||||
if (provider.authType === 'oauth') {
|
||||
if (!oauthToken) {
|
||||
throw new Error('OAuth token is not available')
|
||||
}
|
||||
return new Anthropic({
|
||||
authToken: oauthToken,
|
||||
baseURL: 'https://api.anthropic.com',
|
||||
dangerouslyAllowBrowser: true,
|
||||
defaultHeaders: {
|
||||
'Content-Type': 'application/json',
|
||||
'anthropic-version': '2023-06-01',
|
||||
'anthropic-beta': 'oauth-2025-04-20,claude-code-20250219,interleaved-thinking-2025-05-14,fine-grained-tool-streaming-2025-05-14',
|
||||
'anthropic-dangerous-direct-browser-access': 'true',
|
||||
'user-agent': 'claude-cli/1.0.118 (external, sdk-ts)',
|
||||
'x-app': 'cli',
|
||||
'x-stainless-retry-count': '0',
|
||||
'x-stainless-timeout': '600',
|
||||
'x-stainless-lang': 'js',
|
||||
'x-stainless-package-version': '0.60.0',
|
||||
'x-stainless-os': 'MacOS',
|
||||
'x-stainless-arch': 'arm64',
|
||||
'x-stainless-runtime': 'node',
|
||||
'x-stainless-runtime-version': 'v22.18.0'
|
||||
}
|
||||
})
|
||||
}
|
||||
return new Anthropic({
|
||||
apiKey: provider.apiKey,
|
||||
baseURL: provider.apiHost,
|
||||
dangerouslyAllowBrowser: true,
|
||||
defaultHeaders: {
|
||||
'anthropic-beta': 'output-128k-2025-02-19',
|
||||
...provider.extra_headers
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
/**
|
||||
* Builds and prepends the Claude Code system message to user-provided system messages.
|
||||
*
|
||||
* This function ensures that all interactions with Claude include the official Claude Code
|
||||
* system prompt, which identifies the assistant as "Claude Code, Anthropic's official CLI for Claude."
|
||||
*
|
||||
* The function handles three cases:
|
||||
* 1. No system message provided: Returns only the default Claude Code system message
|
||||
* 2. String system message: Converts to array format and prepends Claude Code message
|
||||
* 3. Array system message: Checks if Claude Code message exists and prepends if missing
|
||||
*
|
||||
* @param system - Optional user-provided system message (string or TextBlockParam array)
|
||||
* @returns Combined system message with Claude Code prompt prepended
|
||||
*
|
||||
* @example
|
||||
* ```typescript
|
||||
* // No system message
|
||||
* const result1 = buildClaudeCodeSystemMessage();
|
||||
* // Returns: "You are Claude Code, Anthropic's official CLI for Claude."
|
||||
*
|
||||
* // String system message
|
||||
* const result2 = buildClaudeCodeSystemMessage("You are a helpful assistant.");
|
||||
* // Returns: [
|
||||
* // { type: 'text', text: "You are Claude Code, Anthropic's official CLI for Claude." },
|
||||
* // { type: 'text', text: "You are a helpful assistant." }
|
||||
* // ]
|
||||
*
|
||||
* // Array system message
|
||||
* const systemArray = [{ type: 'text', text: 'Custom instructions' }];
|
||||
* const result3 = buildClaudeCodeSystemMessage(systemArray);
|
||||
* // Returns: Array with Claude Code message prepended
|
||||
* ```
|
||||
*/
|
||||
export function buildClaudeCodeSystemMessage(system?: string | Array<TextBlockParam>): string | Array<TextBlockParam> {
|
||||
const defaultClaudeCodeSystem = `You are Claude Code, Anthropic's official CLI for Claude.`
|
||||
if (!system) {
|
||||
return defaultClaudeCodeSystem
|
||||
}
|
||||
|
||||
if (typeof system === 'string') {
|
||||
if (system.trim() === defaultClaudeCodeSystem) {
|
||||
return system
|
||||
}
|
||||
return [
|
||||
{
|
||||
type: 'text',
|
||||
text: defaultClaudeCodeSystem
|
||||
},
|
||||
{
|
||||
type: 'text',
|
||||
text: system
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
if (system[0].text.trim() != defaultClaudeCodeSystem) {
|
||||
system.unshift({
|
||||
type: 'text',
|
||||
text: defaultClaudeCodeSystem
|
||||
})
|
||||
}
|
||||
|
||||
return system
|
||||
}
|
||||
@@ -1,16 +1,16 @@
|
||||
import { loggerService } from '@main/services/LoggerService'
|
||||
import {loggerService} from '@main/services/LoggerService'
|
||||
import cors from 'cors'
|
||||
import express from 'express'
|
||||
import { v4 as uuidv4 } from 'uuid'
|
||||
import {v4 as uuidv4} from 'uuid'
|
||||
|
||||
import { authMiddleware } from './middleware/auth'
|
||||
import { errorHandler } from './middleware/error'
|
||||
import { setupOpenAPIDocumentation } from './middleware/openapi'
|
||||
import { agentsRoutes } from './routes/agents'
|
||||
import { chatRoutes } from './routes/chat'
|
||||
import { mcpRoutes } from './routes/mcp'
|
||||
import { messagesRoutes } from './routes/messages'
|
||||
import { modelsRoutes } from './routes/models'
|
||||
import {authMiddleware} from './middleware/auth'
|
||||
import {errorHandler} from './middleware/error'
|
||||
import {setupOpenAPIDocumentation} from './middleware/openapi'
|
||||
import {agentsRoutes} from './routes/agents'
|
||||
import {chatRoutes} from './routes/chat'
|
||||
import {mcpRoutes} from './routes/mcp'
|
||||
import {messagesProviderRoutes, messagesRoutes} from './routes/messages'
|
||||
import {modelsRoutes} from './routes/models'
|
||||
|
||||
const logger = loggerService.withContext('ApiServer')
|
||||
|
||||
@@ -108,6 +108,14 @@ app.get('/', (_req, res) => {
|
||||
})
|
||||
})
|
||||
|
||||
// Provider-specific API routes with auth (must be before /v1 to avoid conflicts)
|
||||
const providerRouter = express.Router({mergeParams: true})
|
||||
providerRouter.use(authMiddleware)
|
||||
providerRouter.use(express.json())
|
||||
// Mount provider-specific messages route
|
||||
providerRouter.use('/v1/messages', messagesProviderRoutes)
|
||||
app.use('/:provider', providerRouter)
|
||||
|
||||
// API v1 routes with auth
|
||||
const apiRouter = express.Router()
|
||||
apiRouter.use(authMiddleware)
|
||||
@@ -120,10 +128,11 @@ apiRouter.use('/models', modelsRoutes)
|
||||
apiRouter.use('/agents', agentsRoutes)
|
||||
app.use('/v1', apiRouter)
|
||||
|
||||
|
||||
// Setup OpenAPI documentation
|
||||
setupOpenAPIDocumentation(app)
|
||||
|
||||
// Error handling (must be last)
|
||||
app.use(errorHandler)
|
||||
|
||||
export { app }
|
||||
export {app}
|
||||
|
||||
368
src/main/apiServer/middleware/__tests__/auth.test.ts
Normal file
368
src/main/apiServer/middleware/__tests__/auth.test.ts
Normal file
@@ -0,0 +1,368 @@
|
||||
import type {NextFunction, Request, Response} from 'express'
|
||||
import {beforeEach, describe, expect, it, vi} from 'vitest'
|
||||
|
||||
import {config} from '../../config'
|
||||
import {authMiddleware} from '../auth'
|
||||
|
||||
// Mock the config module
|
||||
vi.mock('../../config', () => ({
|
||||
config: {
|
||||
get: vi.fn()
|
||||
}
|
||||
}))
|
||||
|
||||
// Mock the logger
|
||||
vi.mock('@logger', () => ({
|
||||
loggerService: {
|
||||
withContext: vi.fn(() => ({
|
||||
debug: vi.fn()
|
||||
}))
|
||||
}
|
||||
}))
|
||||
|
||||
const mockConfig = config as any
|
||||
|
||||
describe('authMiddleware', () => {
|
||||
let req: Partial<Request>
|
||||
let res: Partial<Response>
|
||||
let next: NextFunction
|
||||
let jsonMock: ReturnType<typeof vi.fn>
|
||||
let statusMock: ReturnType<typeof vi.fn>
|
||||
|
||||
beforeEach(() => {
|
||||
jsonMock = vi.fn()
|
||||
statusMock = vi.fn(() => ({json: jsonMock}))
|
||||
|
||||
req = {
|
||||
header: vi.fn()
|
||||
}
|
||||
res = {
|
||||
status: statusMock
|
||||
}
|
||||
next = vi.fn()
|
||||
|
||||
vi.clearAllMocks()
|
||||
})
|
||||
|
||||
describe('Missing credentials', () => {
|
||||
it('should return 401 when both auth headers are missing', async () => {
|
||||
;(req.header as any).mockReturnValue('')
|
||||
|
||||
await authMiddleware(req as Request, res as Response, next)
|
||||
|
||||
expect(statusMock).toHaveBeenCalledWith(401)
|
||||
expect(jsonMock).toHaveBeenCalledWith({error: 'Unauthorized: missing credentials'})
|
||||
expect(next).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should return 401 when both auth headers are empty strings', async () => {
|
||||
;(req.header as any).mockImplementation((header: string) => {
|
||||
if (header === 'authorization') return ''
|
||||
if (header === 'x-api-key') return ''
|
||||
return ''
|
||||
})
|
||||
|
||||
await authMiddleware(req as Request, res as Response, next)
|
||||
|
||||
expect(statusMock).toHaveBeenCalledWith(401)
|
||||
expect(jsonMock).toHaveBeenCalledWith({error: 'Unauthorized: missing credentials'})
|
||||
expect(next).not.toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
|
||||
describe('Server configuration', () => {
|
||||
it('should return 403 when API key is not configured', async () => {
|
||||
;(req.header as any).mockImplementation((header: string) => {
|
||||
if (header === 'x-api-key') return 'some-key'
|
||||
return ''
|
||||
})
|
||||
|
||||
mockConfig.get.mockResolvedValue({apiKey: ''})
|
||||
|
||||
await authMiddleware(req as Request, res as Response, next)
|
||||
|
||||
expect(statusMock).toHaveBeenCalledWith(403)
|
||||
expect(jsonMock).toHaveBeenCalledWith({error: 'Forbidden'})
|
||||
expect(next).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should return 403 when API key is null', async () => {
|
||||
;(req.header as any).mockImplementation((header: string) => {
|
||||
if (header === 'x-api-key') return 'some-key'
|
||||
return ''
|
||||
})
|
||||
|
||||
mockConfig.get.mockResolvedValue({apiKey: null})
|
||||
|
||||
await authMiddleware(req as Request, res as Response, next)
|
||||
|
||||
expect(statusMock).toHaveBeenCalledWith(403)
|
||||
expect(jsonMock).toHaveBeenCalledWith({error: 'Forbidden'})
|
||||
expect(next).not.toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
|
||||
describe('API Key authentication (priority)', () => {
|
||||
const validApiKey = 'valid-api-key-123'
|
||||
|
||||
beforeEach(() => {
|
||||
mockConfig.get.mockResolvedValue({apiKey: validApiKey})
|
||||
})
|
||||
|
||||
it('should authenticate successfully with valid API key', async () => {
|
||||
;(req.header as any).mockImplementation((header: string) => {
|
||||
if (header === 'x-api-key') return validApiKey
|
||||
return ''
|
||||
})
|
||||
|
||||
await authMiddleware(req as Request, res as Response, next)
|
||||
|
||||
expect(next).toHaveBeenCalled()
|
||||
expect(statusMock).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should return 403 with invalid API key', async () => {
|
||||
;(req.header as any).mockImplementation((header: string) => {
|
||||
if (header === 'x-api-key') return 'invalid-key'
|
||||
return ''
|
||||
})
|
||||
|
||||
await authMiddleware(req as Request, res as Response, next)
|
||||
|
||||
expect(statusMock).toHaveBeenCalledWith(403)
|
||||
expect(jsonMock).toHaveBeenCalledWith({error: 'Forbidden'})
|
||||
expect(next).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should return 401 with empty API key', async () => {
|
||||
;(req.header as any).mockImplementation((header: string) => {
|
||||
if (header === 'x-api-key') return ' '
|
||||
return ''
|
||||
})
|
||||
|
||||
await authMiddleware(req as Request, res as Response, next)
|
||||
|
||||
expect(statusMock).toHaveBeenCalledWith(401)
|
||||
expect(jsonMock).toHaveBeenCalledWith({error: 'Unauthorized: empty x-api-key'})
|
||||
expect(next).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should handle API key with whitespace', async () => {
|
||||
;(req.header as any).mockImplementation((header: string) => {
|
||||
if (header === 'x-api-key') return ` ${validApiKey} `
|
||||
return ''
|
||||
})
|
||||
|
||||
await authMiddleware(req as Request, res as Response, next)
|
||||
|
||||
expect(next).toHaveBeenCalled()
|
||||
expect(statusMock).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should prioritize API key over Bearer token when both are present', async () => {
|
||||
;(req.header as any).mockImplementation((header: string) => {
|
||||
if (header === 'x-api-key') return validApiKey
|
||||
if (header === 'authorization') return 'Bearer invalid-token'
|
||||
return ''
|
||||
})
|
||||
|
||||
await authMiddleware(req as Request, res as Response, next)
|
||||
|
||||
expect(next).toHaveBeenCalled()
|
||||
expect(statusMock).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should return 403 when API key is invalid even if Bearer token is valid', async () => {
|
||||
;(req.header as any).mockImplementation((header: string) => {
|
||||
if (header === 'x-api-key') return 'invalid-key'
|
||||
if (header === 'authorization') return `Bearer ${validApiKey}`
|
||||
return ''
|
||||
})
|
||||
|
||||
await authMiddleware(req as Request, res as Response, next)
|
||||
|
||||
expect(statusMock).toHaveBeenCalledWith(403)
|
||||
expect(jsonMock).toHaveBeenCalledWith({error: 'Forbidden'})
|
||||
expect(next).not.toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
|
||||
describe('Bearer token authentication (fallback)', () => {
|
||||
const validApiKey = 'valid-api-key-123'
|
||||
|
||||
beforeEach(() => {
|
||||
mockConfig.get.mockResolvedValue({apiKey: validApiKey})
|
||||
})
|
||||
|
||||
it('should authenticate successfully with valid Bearer token when no API key', async () => {
|
||||
;(req.header as any).mockImplementation((header: string) => {
|
||||
if (header === 'authorization') return `Bearer ${validApiKey}`
|
||||
return ''
|
||||
})
|
||||
|
||||
await authMiddleware(req as Request, res as Response, next)
|
||||
|
||||
expect(next).toHaveBeenCalled()
|
||||
expect(statusMock).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should return 403 with invalid Bearer token', async () => {
|
||||
;(req.header as any).mockImplementation((header: string) => {
|
||||
if (header === 'authorization') return 'Bearer invalid-token'
|
||||
return ''
|
||||
})
|
||||
|
||||
await authMiddleware(req as Request, res as Response, next)
|
||||
|
||||
expect(statusMock).toHaveBeenCalledWith(403)
|
||||
expect(jsonMock).toHaveBeenCalledWith({error: 'Forbidden'})
|
||||
expect(next).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should return 401 with malformed authorization header', async () => {
|
||||
;(req.header as any).mockImplementation((header: string) => {
|
||||
if (header === 'authorization') return 'Basic sometoken'
|
||||
return ''
|
||||
})
|
||||
|
||||
await authMiddleware(req as Request, res as Response, next)
|
||||
|
||||
expect(statusMock).toHaveBeenCalledWith(401)
|
||||
expect(jsonMock).toHaveBeenCalledWith({error: 'Unauthorized: invalid authorization format'})
|
||||
expect(next).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should return 401 with Bearer without space', async () => {
|
||||
;(req.header as any).mockImplementation((header: string) => {
|
||||
if (header === 'authorization') return 'Bearer'
|
||||
return ''
|
||||
})
|
||||
|
||||
await authMiddleware(req as Request, res as Response, next)
|
||||
|
||||
expect(statusMock).toHaveBeenCalledWith(401)
|
||||
expect(jsonMock).toHaveBeenCalledWith({error: 'Unauthorized: invalid authorization format'})
|
||||
expect(next).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should handle Bearer token with only trailing spaces (edge case)', async () => {
|
||||
;(req.header as any).mockImplementation((header: string) => {
|
||||
if (header === 'authorization') return 'Bearer ' // This will be trimmed to "Bearer" and fail format check
|
||||
return ''
|
||||
})
|
||||
|
||||
await authMiddleware(req as Request, res as Response, next)
|
||||
|
||||
expect(statusMock).toHaveBeenCalledWith(401)
|
||||
expect(jsonMock).toHaveBeenCalledWith({error: 'Unauthorized: invalid authorization format'})
|
||||
expect(next).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should handle Bearer token with case insensitive prefix', async () => {
|
||||
;(req.header as any).mockImplementation((header: string) => {
|
||||
if (header === 'authorization') return `bearer ${validApiKey}`
|
||||
return ''
|
||||
})
|
||||
|
||||
await authMiddleware(req as Request, res as Response, next)
|
||||
|
||||
expect(next).toHaveBeenCalled()
|
||||
expect(statusMock).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should handle Bearer token with whitespace', async () => {
|
||||
;(req.header as any).mockImplementation((header: string) => {
|
||||
if (header === 'authorization') return ` Bearer ${validApiKey} `
|
||||
return ''
|
||||
})
|
||||
|
||||
await authMiddleware(req as Request, res as Response, next)
|
||||
|
||||
expect(next).toHaveBeenCalled()
|
||||
expect(statusMock).not.toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
|
||||
describe('Edge cases', () => {
|
||||
const validApiKey = 'valid-api-key-123'
|
||||
|
||||
beforeEach(() => {
|
||||
mockConfig.get.mockResolvedValue({apiKey: validApiKey})
|
||||
})
|
||||
|
||||
it('should handle config.get() rejection', async () => {
|
||||
;(req.header as any).mockImplementation((header: string) => {
|
||||
if (header === 'x-api-key') return validApiKey
|
||||
return ''
|
||||
})
|
||||
|
||||
mockConfig.get.mockRejectedValue(new Error('Config error'))
|
||||
|
||||
await expect(authMiddleware(req as Request, res as Response, next)).rejects.toThrow('Config error')
|
||||
})
|
||||
|
||||
it('should use timing-safe comparison for different length tokens', async () => {
|
||||
;(req.header as any).mockImplementation((header: string) => {
|
||||
if (header === 'x-api-key') return 'short'
|
||||
return ''
|
||||
})
|
||||
|
||||
await authMiddleware(req as Request, res as Response, next)
|
||||
|
||||
expect(statusMock).toHaveBeenCalledWith(403)
|
||||
expect(jsonMock).toHaveBeenCalledWith({error: 'Forbidden'})
|
||||
expect(next).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should return 401 when neither credential format is valid', async () => {
|
||||
;(req.header as any).mockImplementation((header: string) => {
|
||||
if (header === 'authorization') return 'Invalid format'
|
||||
return ''
|
||||
})
|
||||
|
||||
await authMiddleware(req as Request, res as Response, next)
|
||||
|
||||
expect(statusMock).toHaveBeenCalledWith(401)
|
||||
expect(jsonMock).toHaveBeenCalledWith({error: 'Unauthorized: invalid authorization format'})
|
||||
expect(next).not.toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
|
||||
describe('Timing attack protection', () => {
|
||||
const validApiKey = 'valid-api-key-123'
|
||||
|
||||
beforeEach(() => {
|
||||
mockConfig.get.mockResolvedValue({apiKey: validApiKey})
|
||||
})
|
||||
|
||||
it('should handle similar length but different API keys securely', async () => {
|
||||
const similarKey = 'valid-api-key-124' // Same length, different last char
|
||||
|
||||
;(req.header as any).mockImplementation((header: string) => {
|
||||
if (header === 'x-api-key') return similarKey
|
||||
return ''
|
||||
})
|
||||
|
||||
await authMiddleware(req as Request, res as Response, next)
|
||||
|
||||
expect(statusMock).toHaveBeenCalledWith(403)
|
||||
expect(jsonMock).toHaveBeenCalledWith({error: 'Forbidden'})
|
||||
expect(next).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should handle similar length but different Bearer tokens securely', async () => {
|
||||
const similarKey = 'valid-api-key-124' // Same length, different last char
|
||||
|
||||
;(req.header as any).mockImplementation((header: string) => {
|
||||
if (header === 'authorization') return `Bearer ${similarKey}`
|
||||
return ''
|
||||
})
|
||||
|
||||
await authMiddleware(req as Request, res as Response, next)
|
||||
|
||||
expect(statusMock).toHaveBeenCalledWith(403)
|
||||
expect(jsonMock).toHaveBeenCalledWith({error: 'Forbidden'})
|
||||
expect(next).not.toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
})
|
||||
@@ -1,62 +1,67 @@
|
||||
import crypto from 'crypto'
|
||||
import { NextFunction, Request, Response } from 'express'
|
||||
import {NextFunction, Request, Response} from 'express'
|
||||
|
||||
import { config } from '../config'
|
||||
import {config} from '../config'
|
||||
|
||||
|
||||
const isValidToken = (token: string, apiKey: string): boolean => {
|
||||
if (token.length !== apiKey.length) {
|
||||
return false
|
||||
}
|
||||
const tokenBuf = Buffer.from(token)
|
||||
const keyBuf = Buffer.from(apiKey)
|
||||
return crypto.timingSafeEqual(tokenBuf, keyBuf)
|
||||
}
|
||||
|
||||
export const authMiddleware = async (req: Request, res: Response, next: NextFunction) => {
|
||||
const auth = req.header('Authorization') || ''
|
||||
const auth = req.header('authorization') || ''
|
||||
const xApiKey = req.header('x-api-key') || ''
|
||||
|
||||
// Fast rejection if neither credential header provided
|
||||
if (!auth && !xApiKey) {
|
||||
return res.status(401).json({ error: 'Unauthorized: missing credentials' })
|
||||
return res.status(401).json({error: 'Unauthorized: missing credentials'})
|
||||
}
|
||||
|
||||
let token: string | undefined
|
||||
const {apiKey} = await config.get()
|
||||
|
||||
// Prefer Bearer if well‑formed
|
||||
if (!apiKey) {
|
||||
return res.status(403).json({error: 'Forbidden'})
|
||||
}
|
||||
|
||||
// Check API key first (priority)
|
||||
if (xApiKey) {
|
||||
const trimmedApiKey = xApiKey.trim()
|
||||
if (!trimmedApiKey) {
|
||||
return res.status(401).json({error: 'Unauthorized: empty x-api-key'})
|
||||
}
|
||||
|
||||
if (isValidToken(trimmedApiKey, apiKey)) {
|
||||
return next()
|
||||
} else {
|
||||
return res.status(403).json({error: 'Forbidden'})
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback to Bearer token
|
||||
if (auth) {
|
||||
const trimmed = auth.trim()
|
||||
const bearerPrefix = /^Bearer\s+/i
|
||||
if (bearerPrefix.test(trimmed)) {
|
||||
const candidate = trimmed.replace(bearerPrefix, '').trim()
|
||||
if (!candidate) {
|
||||
return res.status(401).json({ error: 'Unauthorized: empty bearer token' })
|
||||
}
|
||||
token = candidate
|
||||
|
||||
if (!bearerPrefix.test(trimmed)) {
|
||||
return res.status(401).json({error: 'Unauthorized: invalid authorization format'})
|
||||
}
|
||||
|
||||
const token = trimmed.replace(bearerPrefix, '').trim()
|
||||
if (!token) {
|
||||
return res.status(401).json({error: 'Unauthorized: empty bearer token'})
|
||||
}
|
||||
|
||||
if (isValidToken(token, apiKey)) {
|
||||
return next()
|
||||
} else {
|
||||
return res.status(403).json({error: 'Forbidden'})
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback to x-api-key if token still not resolved
|
||||
if (!token && xApiKey) {
|
||||
if (!xApiKey.trim()) {
|
||||
return res.status(401).json({ error: 'Unauthorized: empty x-api-key' })
|
||||
}
|
||||
token = xApiKey.trim()
|
||||
}
|
||||
|
||||
if (!token) {
|
||||
// At this point we had at least one header, but none yielded a usable token
|
||||
return res.status(401).json({ error: 'Unauthorized: invalid credentials format' })
|
||||
}
|
||||
|
||||
const { apiKey } = await config.get()
|
||||
|
||||
if (!apiKey) {
|
||||
// If server not configured, treat as forbidden (or could be 500). Choose 403 to avoid leaking config state.
|
||||
return res.status(403).json({ error: 'Forbidden' })
|
||||
}
|
||||
|
||||
// Timing-safe compare when lengths match, else immediate forbidden
|
||||
if (token.length !== apiKey.length) {
|
||||
return res.status(403).json({ error: 'Forbidden' })
|
||||
}
|
||||
|
||||
const tokenBuf = Buffer.from(token)
|
||||
const keyBuf = Buffer.from(apiKey)
|
||||
if (!crypto.timingSafeEqual(tokenBuf, keyBuf)) {
|
||||
return res.status(403).json({ error: 'Forbidden' })
|
||||
}
|
||||
|
||||
return next()
|
||||
return res.status(401).json({error: 'Unauthorized: invalid credentials format'})
|
||||
}
|
||||
|
||||
@@ -1,13 +1,185 @@
|
||||
import { MessageCreateParams } from '@anthropic-ai/sdk/resources'
|
||||
import express, { Request, Response } from 'express'
|
||||
import {MessageCreateParams} from '@anthropic-ai/sdk/resources'
|
||||
import {loggerService} from '@logger'
|
||||
import express, {Request, Response} from 'express'
|
||||
|
||||
import { loggerService } from '../../services/LoggerService'
|
||||
import { messagesService } from '../services/messages'
|
||||
import { validateModelId } from '../utils'
|
||||
import {messagesService} from '../services/messages'
|
||||
import {getProviderById, validateModelId} from '../utils'
|
||||
|
||||
const logger = loggerService.withContext('ApiServerMessagesRoutes')
|
||||
|
||||
const router = express.Router()
|
||||
const providerRouter = express.Router({mergeParams: true})
|
||||
|
||||
// Helper functions for shared logic
|
||||
async function validateRequestBody(req: Request): Promise<{ valid: boolean; error?: any }> {
|
||||
logger.info('Validating request body', { body: req.body })
|
||||
const request: MessageCreateParams = req.body
|
||||
|
||||
if (!request) {
|
||||
return {
|
||||
valid: false,
|
||||
error: {
|
||||
type: 'error',
|
||||
error: {
|
||||
type: 'invalid_request_error',
|
||||
message: 'Request body is required'
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return {valid: true}
|
||||
}
|
||||
|
||||
async function handleStreamingResponse(
|
||||
res: Response,
|
||||
request: MessageCreateParams,
|
||||
provider: any,
|
||||
messagesService: any,
|
||||
logger: any
|
||||
): Promise<void> {
|
||||
res.setHeader('Content-Type', 'text/event-stream; charset=utf-8')
|
||||
res.setHeader('Cache-Control', 'no-cache, no-transform')
|
||||
res.setHeader('Connection', 'keep-alive')
|
||||
res.setHeader('X-Accel-Buffering', 'no')
|
||||
res.flushHeaders()
|
||||
|
||||
try {
|
||||
for await (const chunk of messagesService.processStreamingMessage(request, provider)) {
|
||||
res.write(`data: ${JSON.stringify(chunk)}\n\n`)
|
||||
}
|
||||
res.write('data: [DONE]\n\n')
|
||||
} catch (streamError: any) {
|
||||
logger.error('Stream error:', streamError)
|
||||
res.write(
|
||||
`data: ${JSON.stringify({
|
||||
type: 'error',
|
||||
error: {
|
||||
type: 'api_error',
|
||||
message: 'Stream processing error'
|
||||
}
|
||||
})}\n\n`
|
||||
)
|
||||
} finally {
|
||||
res.end()
|
||||
}
|
||||
}
|
||||
|
||||
function handleErrorResponse(res: Response, error: any, logger: any): Response {
|
||||
logger.error('Message processing error:', error)
|
||||
|
||||
let statusCode = 500
|
||||
let errorType = 'api_error'
|
||||
let errorMessage = 'Internal server error'
|
||||
|
||||
const anthropicStatus = typeof error?.status === 'number' ? error.status : undefined
|
||||
const anthropicError = error?.error
|
||||
|
||||
if (anthropicStatus) {
|
||||
statusCode = anthropicStatus
|
||||
}
|
||||
|
||||
if (anthropicError?.type) {
|
||||
errorType = anthropicError.type
|
||||
}
|
||||
|
||||
if (anthropicError?.message) {
|
||||
errorMessage = anthropicError.message
|
||||
} else if (error instanceof Error && error.message) {
|
||||
errorMessage = error.message
|
||||
}
|
||||
|
||||
if (!anthropicStatus && error instanceof Error) {
|
||||
if (error.message.includes('API key') || error.message.includes('authentication')) {
|
||||
statusCode = 401
|
||||
errorType = 'authentication_error'
|
||||
} else if (error.message.includes('rate limit') || error.message.includes('quota')) {
|
||||
statusCode = 429
|
||||
errorType = 'rate_limit_error'
|
||||
} else if (error.message.includes('timeout') || error.message.includes('connection')) {
|
||||
statusCode = 502
|
||||
errorType = 'api_error'
|
||||
} else if (error.message.includes('validation') || error.message.includes('invalid')) {
|
||||
statusCode = 400
|
||||
errorType = 'invalid_request_error'
|
||||
}
|
||||
}
|
||||
|
||||
return res.status(statusCode).json({
|
||||
type: 'error',
|
||||
error: {
|
||||
type: errorType,
|
||||
message: errorMessage,
|
||||
requestId: error?.request_id
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
async function processMessageRequest(
|
||||
req: Request,
|
||||
res: Response,
|
||||
provider: any,
|
||||
modelId?: string
|
||||
): Promise<Response | void> {
|
||||
try {
|
||||
const request: MessageCreateParams = req.body
|
||||
|
||||
// Use provided modelId or keep original model
|
||||
if (modelId) {
|
||||
request.model = modelId
|
||||
}
|
||||
|
||||
logger.info('Processing message request:', {
|
||||
provider: provider.id,
|
||||
model: request.model,
|
||||
messageCount: request.messages?.length || 0,
|
||||
stream: request.stream,
|
||||
max_tokens: request.max_tokens,
|
||||
temperature: request.temperature
|
||||
})
|
||||
|
||||
// Ensure provider is Anthropic type
|
||||
if (provider.type !== 'anthropic') {
|
||||
return res.status(400).json({
|
||||
type: 'error',
|
||||
error: {
|
||||
type: 'invalid_request_error',
|
||||
message: `Invalid provider type '${provider.type}' for messages endpoint. Expected 'anthropic' provider.`
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
logger.info('Provider validation successful:', {
|
||||
provider: provider.id,
|
||||
providerType: provider.type,
|
||||
modelId: request.model
|
||||
})
|
||||
|
||||
// Validate request
|
||||
const validation = messagesService.validateRequest(request)
|
||||
if (!validation.isValid) {
|
||||
return res.status(400).json({
|
||||
type: 'error',
|
||||
error: {
|
||||
type: 'invalid_request_error',
|
||||
message: validation.errors.join('; ')
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// Handle streaming
|
||||
if (request.stream) {
|
||||
await handleStreamingResponse(res, request, provider, messagesService, logger)
|
||||
return
|
||||
}
|
||||
|
||||
// Handle non-streaming
|
||||
const response = await messagesService.processMessage(request, provider)
|
||||
return res.json(response)
|
||||
} catch (error: any) {
|
||||
return handleErrorResponse(res, error, logger)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* @swagger
|
||||
@@ -133,25 +305,20 @@ const router = express.Router()
|
||||
* description: Internal server error
|
||||
*/
|
||||
router.post('/', async (req: Request, res: Response) => {
|
||||
// Validate request body
|
||||
const bodyValidation = await validateRequestBody(req)
|
||||
if (!bodyValidation.valid) {
|
||||
return res.status(400).json(bodyValidation.error)
|
||||
}
|
||||
|
||||
try {
|
||||
const request: MessageCreateParams = req.body
|
||||
|
||||
if (!request) {
|
||||
return res.status(400).json({
|
||||
type: 'error',
|
||||
error: {
|
||||
type: 'invalid_request_error',
|
||||
message: 'Request body is required'
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
logger.info('Anthropic message request:', {
|
||||
model: request.model,
|
||||
messageCount: request.messages?.length || 0,
|
||||
stream: request.stream,
|
||||
max_tokens: request.max_tokens,
|
||||
temperature: request.temperature
|
||||
})
|
||||
|
||||
// Validate model ID and get provider
|
||||
@@ -169,20 +336,7 @@ router.post('/', async (req: Request, res: Response) => {
|
||||
}
|
||||
|
||||
const provider = modelValidation.provider!
|
||||
|
||||
// Ensure provider is Anthropic type
|
||||
if (provider.type !== 'anthropic') {
|
||||
return res.status(400).json({
|
||||
type: 'error',
|
||||
error: {
|
||||
type: 'invalid_request_error',
|
||||
message: `Invalid provider type '${provider.type}' for messages endpoint. Expected 'anthropic' provider.`
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
const modelId = modelValidation.modelId!
|
||||
request.model = modelId
|
||||
|
||||
logger.info('Model validation successful:', {
|
||||
provider: provider.id,
|
||||
@@ -191,100 +345,181 @@ router.post('/', async (req: Request, res: Response) => {
|
||||
fullModelId: request.model
|
||||
})
|
||||
|
||||
// Validate request
|
||||
const validation = messagesService.validateRequest(request)
|
||||
if (!validation.isValid) {
|
||||
// Use shared processing function
|
||||
return await processMessageRequest(req, res, provider, modelId)
|
||||
} catch (error: any) {
|
||||
return handleErrorResponse(res, error, logger)
|
||||
}
|
||||
})
|
||||
|
||||
/**
|
||||
* @swagger
|
||||
* /{provider_id}/v1/messages:
|
||||
* post:
|
||||
* summary: Create message with provider in path
|
||||
* description: Create a message response using provider ID from URL path
|
||||
* tags: [Messages]
|
||||
* parameters:
|
||||
* - in: path
|
||||
* name: provider_id
|
||||
* required: true
|
||||
* schema:
|
||||
* type: string
|
||||
* description: Provider ID (e.g., "my-anthropic")
|
||||
* example: "my-anthropic"
|
||||
* requestBody:
|
||||
* required: true
|
||||
* content:
|
||||
* application/json:
|
||||
* schema:
|
||||
* type: object
|
||||
* required:
|
||||
* - model
|
||||
* - max_tokens
|
||||
* - messages
|
||||
* properties:
|
||||
* model:
|
||||
* type: string
|
||||
* description: Model ID without provider prefix
|
||||
* example: "claude-3-5-sonnet-20241022"
|
||||
* max_tokens:
|
||||
* type: integer
|
||||
* minimum: 1
|
||||
* description: Maximum number of tokens to generate
|
||||
* example: 1024
|
||||
* messages:
|
||||
* type: array
|
||||
* items:
|
||||
* type: object
|
||||
* properties:
|
||||
* role:
|
||||
* type: string
|
||||
* enum: [user, assistant]
|
||||
* content:
|
||||
* oneOf:
|
||||
* - type: string
|
||||
* - type: array
|
||||
* system:
|
||||
* type: string
|
||||
* description: System message
|
||||
* temperature:
|
||||
* type: number
|
||||
* minimum: 0
|
||||
* maximum: 1
|
||||
* description: Sampling temperature
|
||||
* top_p:
|
||||
* type: number
|
||||
* minimum: 0
|
||||
* maximum: 1
|
||||
* description: Nucleus sampling
|
||||
* top_k:
|
||||
* type: integer
|
||||
* minimum: 0
|
||||
* description: Top-k sampling
|
||||
* stream:
|
||||
* type: boolean
|
||||
* description: Whether to stream the response
|
||||
* tools:
|
||||
* type: array
|
||||
* description: Available tools for the model
|
||||
* responses:
|
||||
* 200:
|
||||
* description: Message response
|
||||
* content:
|
||||
* application/json:
|
||||
* schema:
|
||||
* type: object
|
||||
* properties:
|
||||
* id:
|
||||
* type: string
|
||||
* type:
|
||||
* type: string
|
||||
* example: message
|
||||
* role:
|
||||
* type: string
|
||||
* example: assistant
|
||||
* content:
|
||||
* type: array
|
||||
* items:
|
||||
* type: object
|
||||
* model:
|
||||
* type: string
|
||||
* stop_reason:
|
||||
* type: string
|
||||
* stop_sequence:
|
||||
* type: string
|
||||
* usage:
|
||||
* type: object
|
||||
* properties:
|
||||
* input_tokens:
|
||||
* type: integer
|
||||
* output_tokens:
|
||||
* type: integer
|
||||
* text/event-stream:
|
||||
* schema:
|
||||
* type: string
|
||||
* description: Server-sent events stream (when stream=true)
|
||||
* 400:
|
||||
* description: Bad request
|
||||
* 401:
|
||||
* description: Unauthorized
|
||||
* 429:
|
||||
* description: Rate limit exceeded
|
||||
* 500:
|
||||
* description: Internal server error
|
||||
*/
|
||||
providerRouter.post('/', async (req: Request, res: Response) => {
|
||||
// Validate request body
|
||||
const bodyValidation = await validateRequestBody(req)
|
||||
if (!bodyValidation.valid) {
|
||||
return res.status(400).json(bodyValidation.error)
|
||||
}
|
||||
|
||||
try {
|
||||
const providerId = req.params.provider
|
||||
const request: MessageCreateParams = req.body
|
||||
|
||||
if (!providerId) {
|
||||
return res.status(400).json({
|
||||
type: 'error',
|
||||
error: {
|
||||
type: 'invalid_request_error',
|
||||
message: validation.errors.join('; ')
|
||||
message: 'Provider ID is required in URL path'
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// Handle streaming
|
||||
if (request.stream) {
|
||||
res.setHeader('Content-Type', 'text/event-stream; charset=utf-8')
|
||||
res.setHeader('Cache-Control', 'no-cache, no-transform')
|
||||
res.setHeader('Connection', 'keep-alive')
|
||||
res.setHeader('X-Accel-Buffering', 'no')
|
||||
res.flushHeaders()
|
||||
|
||||
try {
|
||||
for await (const chunk of messagesService.processStreamingMessage(request, provider)) {
|
||||
res.write(`data: ${JSON.stringify(chunk)}\n\n`)
|
||||
}
|
||||
res.write('data: [DONE]\n\n')
|
||||
} catch (streamError: any) {
|
||||
logger.error('Stream error:', streamError)
|
||||
res.write(
|
||||
`data: ${JSON.stringify({
|
||||
type: 'error',
|
||||
error: {
|
||||
type: 'api_error',
|
||||
message: 'Stream processing error'
|
||||
}
|
||||
})}\n\n`
|
||||
)
|
||||
} finally {
|
||||
res.end()
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// Handle non-streaming
|
||||
const response = await messagesService.processMessage(request, provider)
|
||||
return res.json(response)
|
||||
} catch (error: any) {
|
||||
logger.error('Anthropic message error:', error)
|
||||
|
||||
let statusCode = 500
|
||||
let errorType = 'api_error'
|
||||
let errorMessage = 'Internal server error'
|
||||
|
||||
const anthropicStatus = typeof error?.status === 'number' ? error.status : undefined
|
||||
const anthropicError = error?.error
|
||||
|
||||
if (anthropicStatus) {
|
||||
statusCode = anthropicStatus
|
||||
}
|
||||
|
||||
if (anthropicError?.type) {
|
||||
errorType = anthropicError.type
|
||||
}
|
||||
|
||||
if (anthropicError?.message) {
|
||||
errorMessage = anthropicError.message
|
||||
} else if (error instanceof Error && error.message) {
|
||||
errorMessage = error.message
|
||||
}
|
||||
|
||||
if (!anthropicStatus && error instanceof Error) {
|
||||
if (error.message.includes('API key') || error.message.includes('authentication')) {
|
||||
statusCode = 401
|
||||
errorType = 'authentication_error'
|
||||
} else if (error.message.includes('rate limit') || error.message.includes('quota')) {
|
||||
statusCode = 429
|
||||
errorType = 'rate_limit_error'
|
||||
} else if (error.message.includes('timeout') || error.message.includes('connection')) {
|
||||
statusCode = 502
|
||||
errorType = 'api_error'
|
||||
} else if (error.message.includes('validation') || error.message.includes('invalid')) {
|
||||
statusCode = 400
|
||||
errorType = 'invalid_request_error'
|
||||
}
|
||||
}
|
||||
|
||||
return res.status(statusCode).json({
|
||||
type: 'error',
|
||||
error: {
|
||||
type: errorType,
|
||||
message: errorMessage,
|
||||
requestId: error?.request_id
|
||||
}
|
||||
logger.info('Provider-specific message request:', {
|
||||
providerId,
|
||||
model: request.model,
|
||||
messageCount: request.messages?.length || 0,
|
||||
stream: request.stream,
|
||||
max_tokens: request.max_tokens
|
||||
})
|
||||
|
||||
// Get provider directly by ID from URL path
|
||||
const provider = await getProviderById(providerId)
|
||||
if (!provider) {
|
||||
return res.status(400).json({
|
||||
type: 'error',
|
||||
error: {
|
||||
type: 'invalid_request_error',
|
||||
message: `Provider '${providerId}' not found or not enabled`
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
logger.info('Provider validation successful:', {
|
||||
provider: provider.id,
|
||||
providerType: provider.type,
|
||||
modelId: request.model
|
||||
})
|
||||
|
||||
// Use shared processing function (no modelId override needed)
|
||||
return await processMessageRequest(req, res, provider)
|
||||
} catch (error: any) {
|
||||
return handleErrorResponse(res, error, logger)
|
||||
}
|
||||
})
|
||||
|
||||
export { router as messagesRoutes }
|
||||
export {providerRouter as messagesProviderRoutes, router as messagesRoutes}
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
import Anthropic from '@anthropic-ai/sdk'
|
||||
import { Message, MessageCreateParams, RawMessageStreamEvent } from '@anthropic-ai/sdk/resources'
|
||||
import { Provider } from '@types'
|
||||
import Anthropic from "@anthropic-ai/sdk";
|
||||
import {Message, MessageCreateParams, RawMessageStreamEvent} from '@anthropic-ai/sdk/resources'
|
||||
import {loggerService} from '@logger'
|
||||
import anthropicService from "@main/services/AnthropicService";
|
||||
import { buildClaudeCodeSystemMessage, getSdkClient } from '@shared/anthropic'
|
||||
import {Provider} from '@types'
|
||||
|
||||
import { loggerService } from '../../services/LoggerService'
|
||||
|
||||
const logger = loggerService.withContext('MessagesService')
|
||||
|
||||
@@ -35,6 +37,16 @@ export class MessagesService {
|
||||
}
|
||||
}
|
||||
|
||||
async getClient(provider: Provider): Promise<Anthropic> {
|
||||
// Create Anthropic client for the provider
|
||||
if (provider.authType === 'oauth') {
|
||||
const oauthToken = await anthropicService.getValidAccessToken()
|
||||
return getSdkClient(provider, oauthToken)
|
||||
}
|
||||
return getSdkClient(provider)
|
||||
}
|
||||
|
||||
|
||||
async processMessage(request: MessageCreateParams, provider: Provider): Promise<Message> {
|
||||
logger.info('Processing Anthropic message request:', {
|
||||
model: request.model,
|
||||
@@ -44,10 +56,7 @@ export class MessagesService {
|
||||
})
|
||||
|
||||
// Create Anthropic client for the provider
|
||||
const client = new Anthropic({
|
||||
baseURL: provider.apiHost,
|
||||
apiKey: provider.apiKey
|
||||
})
|
||||
const client = await this.getClient(provider)
|
||||
|
||||
// Prepare request with the actual model ID
|
||||
const anthropicRequest: MessageCreateParams = {
|
||||
@@ -55,6 +64,10 @@ export class MessagesService {
|
||||
stream: false
|
||||
}
|
||||
|
||||
if (provider.authType === 'oauth') {
|
||||
anthropicRequest.system = buildClaudeCodeSystemMessage(request.system || '')
|
||||
}
|
||||
|
||||
logger.debug('Sending request to Anthropic provider:', {
|
||||
provider: provider.id,
|
||||
apiHost: provider.apiHost
|
||||
@@ -66,7 +79,7 @@ export class MessagesService {
|
||||
return response
|
||||
}
|
||||
|
||||
async *processStreamingMessage(
|
||||
async* processStreamingMessage(
|
||||
request: MessageCreateParams,
|
||||
provider: Provider
|
||||
): AsyncIterable<RawMessageStreamEvent> {
|
||||
@@ -76,10 +89,7 @@ export class MessagesService {
|
||||
})
|
||||
|
||||
// Create Anthropic client for the provider
|
||||
const client = new Anthropic({
|
||||
baseURL: provider.apiHost,
|
||||
apiKey: provider.apiKey
|
||||
})
|
||||
const client = await this.getClient(provider)
|
||||
|
||||
// Prepare streaming request
|
||||
const streamingRequest: MessageCreateParams = {
|
||||
@@ -87,6 +97,10 @@ export class MessagesService {
|
||||
stream: true
|
||||
}
|
||||
|
||||
if (provider.authType === 'oauth') {
|
||||
streamingRequest.system = buildClaudeCodeSystemMessage(request.system || '')
|
||||
}
|
||||
|
||||
logger.debug('Sending streaming request to Anthropic provider:', {
|
||||
provider: provider.id,
|
||||
apiHost: provider.apiHost
|
||||
|
||||
@@ -204,6 +204,31 @@ export function transformModelToOpenAI(model: Model, providers: Provider[]): Api
|
||||
}
|
||||
}
|
||||
|
||||
export async function getProviderById(providerId: string): Promise<Provider | undefined> {
|
||||
try {
|
||||
if (!providerId || typeof providerId !== 'string') {
|
||||
logger.warn(`Invalid provider ID parameter: ${providerId}`)
|
||||
return undefined
|
||||
}
|
||||
|
||||
const providers = await getAvailableProviders()
|
||||
const provider = providers.find((p: Provider) => p.id === providerId)
|
||||
|
||||
if (!provider) {
|
||||
logger.warn(
|
||||
`Provider '${providerId}' not found or not enabled. Available providers: ${providers.map((p) => p.id).join(', ')}`
|
||||
)
|
||||
return undefined
|
||||
}
|
||||
|
||||
logger.debug(`Found provider '${providerId}'`)
|
||||
return provider
|
||||
} catch (error: any) {
|
||||
logger.error('Failed to get provider by ID:', error)
|
||||
return undefined
|
||||
}
|
||||
}
|
||||
|
||||
export function validateProvider(provider: Provider): boolean {
|
||||
try {
|
||||
if (!provider) {
|
||||
|
||||
@@ -1,16 +1,11 @@
|
||||
import { loggerService } from '@logger'
|
||||
import type {
|
||||
AgentSessionMessageEntity,
|
||||
CreateSessionMessageRequest,
|
||||
GetAgentSessionResponse,
|
||||
ListOptions
|
||||
} from '@types'
|
||||
import { ModelMessage, TextStreamPart } from 'ai'
|
||||
import { desc, eq } from 'drizzle-orm'
|
||||
import {loggerService} from '@logger'
|
||||
import type {AgentSessionMessageEntity, CreateSessionMessageRequest, GetAgentSessionResponse, ListOptions} from '@types'
|
||||
import {TextStreamPart} from 'ai'
|
||||
import {desc, eq} from 'drizzle-orm'
|
||||
|
||||
import { BaseService } from '../BaseService'
|
||||
import { sessionMessagesTable } from '../database/schema'
|
||||
import { AgentStreamEvent } from '../interfaces/AgentStreamInterface'
|
||||
import {BaseService} from '../BaseService'
|
||||
import {sessionMessagesTable} from '../database/schema'
|
||||
import {AgentStreamEvent} from '../interfaces/AgentStreamInterface'
|
||||
import ClaudeCodeService from './claudecode'
|
||||
|
||||
const logger = loggerService.withContext('SessionMessageService')
|
||||
@@ -34,7 +29,7 @@ function serializeError(error: unknown): { message: string; name?: string; stack
|
||||
}
|
||||
|
||||
if (typeof error === 'string') {
|
||||
return { message: error }
|
||||
return {message: error}
|
||||
}
|
||||
|
||||
return {
|
||||
@@ -104,7 +99,7 @@ export class SessionMessageService extends BaseService {
|
||||
this.ensureInitialized()
|
||||
|
||||
const result = await this.database
|
||||
.select({ id: sessionMessagesTable.id })
|
||||
.select({id: sessionMessagesTable.id})
|
||||
.from(sessionMessagesTable)
|
||||
.where(eq(sessionMessagesTable.id, id))
|
||||
.limit(1)
|
||||
@@ -134,7 +129,7 @@ export class SessionMessageService extends BaseService {
|
||||
|
||||
const messages = result.map((row) => this.deserializeSessionMessage(row)) as AgentSessionMessageEntity[]
|
||||
|
||||
return { messages }
|
||||
return {messages}
|
||||
}
|
||||
|
||||
async createSessionMessage(
|
||||
@@ -153,11 +148,11 @@ export class SessionMessageService extends BaseService {
|
||||
abortController: AbortController
|
||||
): Promise<SessionStreamResult> {
|
||||
const agentSessionId = await this.getLastAgentSessionId(session.id)
|
||||
logger.debug('Session Message stream message data:', { message: req, session_id: agentSessionId })
|
||||
logger.debug('Session Message stream message data:', {message: req, session_id: agentSessionId})
|
||||
|
||||
if (session.agent_type !== 'claude-code') {
|
||||
// TODO: Implement support for other agent types
|
||||
logger.error('Unsupported agent type for streaming:', { agent_type: session.agent_type })
|
||||
logger.error('Unsupported agent type for streaming:', {agent_type: session.agent_type})
|
||||
throw new Error('Unsupported agent type for streaming')
|
||||
}
|
||||
|
||||
@@ -248,7 +243,7 @@ export class SessionMessageService extends BaseService {
|
||||
}
|
||||
})
|
||||
|
||||
return { stream, completion }
|
||||
return {stream, completion}
|
||||
}
|
||||
|
||||
private async getLastAgentSessionId(sessionId: string): Promise<string> {
|
||||
@@ -256,7 +251,7 @@ export class SessionMessageService extends BaseService {
|
||||
|
||||
try {
|
||||
const result = await this.database
|
||||
.select({ agent_session_id: sessionMessagesTable.agent_session_id })
|
||||
.select({agent_session_id: sessionMessagesTable.agent_session_id})
|
||||
.from(sessionMessagesTable)
|
||||
.where(eq(sessionMessagesTable.session_id, sessionId))
|
||||
.orderBy(desc(sessionMessagesTable.created_at))
|
||||
@@ -275,7 +270,7 @@ export class SessionMessageService extends BaseService {
|
||||
private deserializeSessionMessage(data: any): AgentSessionMessageEntity {
|
||||
if (!data) return data
|
||||
|
||||
const deserialized = { ...data }
|
||||
const deserialized = {...data}
|
||||
|
||||
// Parse content JSON
|
||||
if (deserialized.content && typeof deserialized.content === 'string') {
|
||||
|
||||
@@ -69,17 +69,9 @@ class ClaudeCodeService implements AgentServiceInterface {
|
||||
// process.env.ANTHROPIC_BASE_URL = `http://${apiConfig.host}:${apiConfig.port}`
|
||||
const env = {
|
||||
...process.env,
|
||||
ELECTRON_RUN_AS_NODE: '1',
|
||||
}
|
||||
|
||||
if (modelInfo.provider.authType === 'oauth') {
|
||||
// TODO: support claude code max oauth
|
||||
// env['ANTHROPIC_AUTH_TOKEN'] = await anthropicService.getValidAccessToken()
|
||||
// env['ANTHROPIC_BASE_URL'] = 'https://api.anthropic.com'
|
||||
} else {
|
||||
env['ANTHROPIC_AUTH_TOKEN'] = modelInfo.provider.apiKey
|
||||
env['ANTHROPIC_API_KEY'] = modelInfo.provider.apiKey
|
||||
env['ANTHROPIC_BASE_URL'] = modelInfo.provider.apiHost
|
||||
ANTHROPIC_API_KEY: apiConfig.apiKey,
|
||||
ANTHROPIC_BASE_URL: `http://${apiConfig.host}:${apiConfig.port}/${modelInfo.provider.id}`,
|
||||
ELECTRON_RUN_AS_NODE: '1'
|
||||
}
|
||||
|
||||
// Build SDK options from parameters
|
||||
@@ -121,7 +113,7 @@ class ClaudeCodeService implements AgentServiceInterface {
|
||||
options.resume = lastAgentSessionId
|
||||
}
|
||||
|
||||
logger.info('Starting Claude Code SDK query', {
|
||||
logger.silly('Starting Claude Code SDK query', {
|
||||
prompt,
|
||||
options
|
||||
})
|
||||
|
||||
@@ -150,8 +150,7 @@ function handleStreamEvent(message: Extract<SDKMessage, { type: 'stream_event' }
|
||||
break
|
||||
|
||||
case 'content_block_start':
|
||||
const contentBlockType = event.content_block.type
|
||||
switch (contentBlockType) {
|
||||
switch (event.content_block.type) {
|
||||
case 'text': {
|
||||
contentBlockState.set(blockKey, { type: 'text' })
|
||||
chunks.push({
|
||||
|
||||
@@ -22,14 +22,14 @@ import {
|
||||
WebSearchToolResultBlockParam,
|
||||
WebSearchToolResultError
|
||||
} from '@anthropic-ai/sdk/resources/messages'
|
||||
import { MessageStream } from '@anthropic-ai/sdk/resources/messages/messages'
|
||||
import {MessageStream} from '@anthropic-ai/sdk/resources/messages/messages'
|
||||
import AnthropicVertex from '@anthropic-ai/vertex-sdk'
|
||||
import { loggerService } from '@logger'
|
||||
import { DEFAULT_MAX_TOKENS } from '@renderer/config/constant'
|
||||
import { findTokenLimit, isClaudeReasoningModel, isReasoningModel, isWebSearchModel } from '@renderer/config/models'
|
||||
import { getAssistantSettings } from '@renderer/services/AssistantService'
|
||||
import {loggerService} from '@logger'
|
||||
import {DEFAULT_MAX_TOKENS} from '@renderer/config/constant'
|
||||
import {findTokenLimit, isClaudeReasoningModel, isReasoningModel, isWebSearchModel} from '@renderer/config/models'
|
||||
import {getAssistantSettings} from '@renderer/services/AssistantService'
|
||||
import FileManager from '@renderer/services/FileManager'
|
||||
import { estimateTextTokens } from '@renderer/services/TokenService'
|
||||
import {estimateTextTokens} from '@renderer/services/TokenService'
|
||||
import {
|
||||
Assistant,
|
||||
EFFORT_RATIO,
|
||||
@@ -53,26 +53,17 @@ import {
|
||||
ThinkingDeltaChunk,
|
||||
ThinkingStartChunk
|
||||
} from '@renderer/types/chunk'
|
||||
import { type Message } from '@renderer/types/newMessage'
|
||||
import {
|
||||
AnthropicSdkMessageParam,
|
||||
AnthropicSdkParams,
|
||||
AnthropicSdkRawChunk,
|
||||
AnthropicSdkRawOutput
|
||||
} from '@renderer/types/sdk'
|
||||
import { addImageFileToContents } from '@renderer/utils/formats'
|
||||
import {
|
||||
anthropicToolUseToMcpTool,
|
||||
isSupportedToolUse,
|
||||
mcpToolCallResponseToAnthropicMessage,
|
||||
mcpToolsToAnthropicTools
|
||||
} from '@renderer/utils/mcp-tools'
|
||||
import { findFileBlocks, findImageBlocks } from '@renderer/utils/messageUtils/find'
|
||||
import { t } from 'i18next'
|
||||
import {type Message} from '@renderer/types/newMessage'
|
||||
import {AnthropicSdkMessageParam, AnthropicSdkParams, AnthropicSdkRawChunk, AnthropicSdkRawOutput} from '@renderer/types/sdk'
|
||||
import {addImageFileToContents} from '@renderer/utils/formats'
|
||||
import {anthropicToolUseToMcpTool, isSupportedToolUse, mcpToolCallResponseToAnthropicMessage, mcpToolsToAnthropicTools} from '@renderer/utils/mcp-tools'
|
||||
import {findFileBlocks, findImageBlocks} from '@renderer/utils/messageUtils/find'
|
||||
import {buildClaudeCodeSystemMessage, getSdkClient} from "@shared/anthropic";
|
||||
import {t} from 'i18next'
|
||||
|
||||
import { GenericChunk } from '../../middleware/schemas'
|
||||
import { BaseApiClient } from '../BaseApiClient'
|
||||
import { AnthropicStreamListener, RawStreamListener, RequestTransformer, ResponseChunkTransformer } from '../types'
|
||||
import {GenericChunk} from '../../middleware/schemas'
|
||||
import {BaseApiClient} from '../BaseApiClient'
|
||||
import {AnthropicStreamListener, RawStreamListener, RequestTransformer, ResponseChunkTransformer} from '../types'
|
||||
|
||||
const logger = loggerService.withContext('AnthropicAPIClient')
|
||||
|
||||
@@ -86,8 +77,8 @@ export class AnthropicAPIClient extends BaseApiClient<
|
||||
ToolUnion
|
||||
> {
|
||||
oauthToken: string | undefined = undefined
|
||||
isOAuthMode: boolean = false
|
||||
sdkInstance: Anthropic | AnthropicVertex | undefined = undefined
|
||||
|
||||
constructor(provider: Provider) {
|
||||
super(provider)
|
||||
}
|
||||
@@ -96,84 +87,25 @@ export class AnthropicAPIClient extends BaseApiClient<
|
||||
if (this.sdkInstance) {
|
||||
return this.sdkInstance
|
||||
}
|
||||
|
||||
if (this.provider.authType === 'oauth') {
|
||||
if (!this.oauthToken) {
|
||||
throw new Error('OAuth token is not available')
|
||||
}
|
||||
this.sdkInstance = new Anthropic({
|
||||
authToken: this.oauthToken,
|
||||
baseURL: 'https://api.anthropic.com',
|
||||
dangerouslyAllowBrowser: true,
|
||||
defaultHeaders: {
|
||||
'Content-Type': 'application/json',
|
||||
'anthropic-version': '2023-06-01',
|
||||
'anthropic-beta': 'oauth-2025-04-20'
|
||||
// ...this.provider.extra_headers
|
||||
}
|
||||
})
|
||||
} else {
|
||||
this.sdkInstance = new Anthropic({
|
||||
apiKey: this.apiKey,
|
||||
baseURL: this.getBaseURL(),
|
||||
dangerouslyAllowBrowser: true,
|
||||
defaultHeaders: {
|
||||
'anthropic-beta': 'output-128k-2025-02-19',
|
||||
...this.provider.extra_headers
|
||||
}
|
||||
})
|
||||
this.oauthToken = await window.api.anthropic_oauth.getAccessToken()
|
||||
}
|
||||
|
||||
this.sdkInstance = getSdkClient(this.provider, this.oauthToken)
|
||||
return this.sdkInstance
|
||||
}
|
||||
|
||||
private buildClaudeCodeSystemMessage(system?: string | Array<TextBlockParam>): string | Array<TextBlockParam> {
|
||||
const defaultClaudeCodeSystem = `You are Claude Code, Anthropic's official CLI for Claude.`
|
||||
if (!system) {
|
||||
return defaultClaudeCodeSystem
|
||||
}
|
||||
|
||||
if (typeof system === 'string') {
|
||||
if (system.trim() === defaultClaudeCodeSystem) {
|
||||
return system
|
||||
}
|
||||
return [
|
||||
{
|
||||
type: 'text',
|
||||
text: defaultClaudeCodeSystem
|
||||
},
|
||||
{
|
||||
type: 'text',
|
||||
text: system
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
if (system[0].text.trim() != defaultClaudeCodeSystem) {
|
||||
system.unshift({
|
||||
type: 'text',
|
||||
text: defaultClaudeCodeSystem
|
||||
})
|
||||
}
|
||||
|
||||
return system
|
||||
}
|
||||
|
||||
override async createCompletions(
|
||||
payload: AnthropicSdkParams,
|
||||
options?: Anthropic.RequestOptions
|
||||
): Promise<AnthropicSdkRawOutput> {
|
||||
if (this.provider.authType === 'oauth') {
|
||||
this.oauthToken = await window.api.anthropic_oauth.getAccessToken()
|
||||
this.isOAuthMode = true
|
||||
logger.info('[Anthropic Provider] Using OAuth token for authentication')
|
||||
payload.system = this.buildClaudeCodeSystemMessage(payload.system)
|
||||
payload.system = buildClaudeCodeSystemMessage(payload.system)
|
||||
}
|
||||
const sdk = (await this.getSdkInstance()) as Anthropic
|
||||
if (payload.stream) {
|
||||
return sdk.messages.stream(payload, options)
|
||||
}
|
||||
return await sdk.messages.create(payload, options)
|
||||
return sdk.messages.create(payload, options);
|
||||
}
|
||||
|
||||
// @ts-ignore sdk未提供
|
||||
@@ -183,14 +115,8 @@ export class AnthropicAPIClient extends BaseApiClient<
|
||||
}
|
||||
|
||||
override async listModels(): Promise<Anthropic.ModelInfo[]> {
|
||||
if (this.provider.authType === 'oauth') {
|
||||
this.oauthToken = await window.api.anthropic_oauth.getAccessToken()
|
||||
this.isOAuthMode = true
|
||||
logger.info('[Anthropic Provider] Using OAuth token for authentication')
|
||||
}
|
||||
const sdk = (await this.getSdkInstance()) as Anthropic
|
||||
const response = await sdk.models.list()
|
||||
|
||||
return response.data
|
||||
}
|
||||
|
||||
@@ -223,7 +149,7 @@ export class AnthropicAPIClient extends BaseApiClient<
|
||||
if (!isReasoningModel(model)) {
|
||||
return undefined
|
||||
}
|
||||
const { maxTokens } = getAssistantSettings(assistant)
|
||||
const {maxTokens} = getAssistantSettings(assistant)
|
||||
|
||||
const reasoningEffort = assistant?.settings?.reasoning_effort
|
||||
|
||||
@@ -240,7 +166,7 @@ export class AnthropicAPIClient extends BaseApiClient<
|
||||
Math.floor(
|
||||
Math.min(
|
||||
(findTokenLimit(model.id)?.max! - findTokenLimit(model.id)?.min!) * effortRatio +
|
||||
findTokenLimit(model.id)?.min!,
|
||||
findTokenLimit(model.id)?.min!,
|
||||
(maxTokens || DEFAULT_MAX_TOKENS) * effortRatio
|
||||
)
|
||||
)
|
||||
@@ -262,7 +188,7 @@ export class AnthropicAPIClient extends BaseApiClient<
|
||||
* @returns The message parameter
|
||||
*/
|
||||
public async convertMessageToSdkParam(message: Message): Promise<AnthropicSdkMessageParam> {
|
||||
const { textContent, imageContents } = await this.getMessageContent(message)
|
||||
const {textContent, imageContents} = await this.getMessageContent(message)
|
||||
|
||||
const parts: MessageParam['content'] = [
|
||||
{
|
||||
@@ -285,7 +211,7 @@ export class AnthropicAPIClient extends BaseApiClient<
|
||||
}
|
||||
})
|
||||
} else {
|
||||
logger.warn('Unsupported image type, ignored.', { mime: base64Data.mime })
|
||||
logger.warn('Unsupported image type, ignored.', {mime: base64Data.mime})
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -310,7 +236,7 @@ export class AnthropicAPIClient extends BaseApiClient<
|
||||
// Get and process file blocks
|
||||
const fileBlocks = findFileBlocks(message)
|
||||
for (const fileBlock of fileBlocks) {
|
||||
const { file } = fileBlock
|
||||
const {file} = fileBlock
|
||||
if ([FileTypes.TEXT, FileTypes.DOCUMENT].includes(file.type)) {
|
||||
if (file.ext === '.pdf' && file.size < 32 * 1024 * 1024) {
|
||||
const base64Data = await FileManager.readBase64File(file)
|
||||
@@ -538,25 +464,25 @@ export class AnthropicAPIClient extends BaseApiClient<
|
||||
messages: AnthropicSdkMessageParam[]
|
||||
metadata: Record<string, any>
|
||||
}> => {
|
||||
const { messages, mcpTools, maxTokens, streamOutput, enableWebSearch } = coreRequest
|
||||
const {messages, mcpTools, maxTokens, streamOutput, enableWebSearch} = coreRequest
|
||||
// 1. 处理系统消息
|
||||
const systemPrompt = assistant.prompt
|
||||
|
||||
// 2. 设置工具
|
||||
const { tools } = this.setupToolsConfig({
|
||||
const {tools} = this.setupToolsConfig({
|
||||
mcpTools: mcpTools,
|
||||
model,
|
||||
enableToolUse: isSupportedToolUse(assistant)
|
||||
})
|
||||
|
||||
const systemMessage: TextBlockParam | undefined = systemPrompt
|
||||
? { type: 'text', text: systemPrompt }
|
||||
? {type: 'text', text: systemPrompt}
|
||||
: undefined
|
||||
|
||||
// 3. 处理用户消息
|
||||
const sdkMessages: AnthropicSdkMessageParam[] = []
|
||||
if (typeof messages === 'string') {
|
||||
sdkMessages.push({ role: 'user', content: messages })
|
||||
sdkMessages.push({role: 'user', content: messages})
|
||||
} else {
|
||||
const processedMessages = addImageFileToContents(messages)
|
||||
for (const message of processedMessages) {
|
||||
@@ -590,7 +516,7 @@ export class AnthropicAPIClient extends BaseApiClient<
|
||||
}
|
||||
|
||||
const timeout = this.getTimeout(model)
|
||||
return { payload: commonParams, messages: sdkMessages, metadata: { timeout } }
|
||||
return {payload: commonParams, messages: sdkMessages, metadata: {timeout}}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -605,7 +531,7 @@ export class AnthropicAPIClient extends BaseApiClient<
|
||||
try {
|
||||
rawChunk = JSON.parse(rawChunk)
|
||||
} catch (error) {
|
||||
logger.error('invalid chunk', { rawChunk, error })
|
||||
logger.error('invalid chunk', {rawChunk, error})
|
||||
throw new Error(t('error.chat.chunk.non_json'))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -77,3 +77,20 @@ Content-Type: application/json
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
### Anthropic Chat Message with streaming
|
||||
POST {{host}}/anthropic/v1/messages
|
||||
Authorization: Bearer {{token}}
|
||||
Content-Type: application/json
|
||||
|
||||
{
|
||||
"model": "claude-sonnet-4-20250514",
|
||||
"stream": true,
|
||||
"max_tokens": 1024,
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "Explain the theory of relativity in simple terms."
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user