Feat/vertex-claude-support (#7564)
* feat(migrate): add default settings for assistants during migration - Introduced a new migration step to assign default settings for assistants that lack configuration. - Default settings include temperature, context count, and other parameters to ensure consistent behavior across the application. * chore(store): increment version number to 115 for persisted reducer * feat(vertex-sdk): integrate Anthropic Vertex SDK and add access token retrieval - Added support for the new `@anthropic-ai/vertex-sdk` in the project. - Introduced a new IPC channel `VertexAI_GetAccessToken` to retrieve access tokens. - Implemented `getAccessToken` method in `VertexAIService` to handle service account authentication. - Updated the `IpcChannel` enum and related IPC handlers to support the new functionality. - Enhanced the `VertexAPIClient` to utilize the `AnthropicVertexClient` for model handling. - Refactored existing code to accommodate the integration of the Vertex SDK and improve modularity. * feat(vertex-ai): enhance VertexAI settings and API host management - Added a new method to format the API host URL in both AnthropicVertexClient and VertexAPIClient. - Updated getBaseURL methods to utilize the new formatting logic. - Enhanced VertexAISettings component to include an input for API host configuration, with help text for user guidance. - Updated localization files to include new help text for the API host field in multiple languages. * fix(vertex-sdk): update baseURL handling and patch dependencies - Refactored baseURL assignment in AnthropicVertexClient to ensure it defaults to undefined when the URL is empty. - Updated yarn.lock to reflect changes in dependency resolution and checksum for @anthropic-ai/vertex-sdk patch. * refactor(VertexAISetting): use provider.id rather than provider * refactor: improve API host formatting in AnthropicVertexClient - Updated the `formatApiHost` method to streamline host URL handling. - Introduced a helper function to determine if the original host should be used based on its format. - Ensured consistent appending of the `/v1/` path for valid API requests. * fix: handle empty host in AnthropicVertexClient - Added a check in the `getBaseURL` method to return the host if it is empty, preventing potential errors. - Included a console log for the base URL to aid in debugging and verification of the URL formatting. * feat(AnthropicVertexClient): add logging for authentication errors and mock client in tests - Introduced logging functionality in AnthropicVertexClient to replace console.error with logger service for better error tracking. - Added mock implementation for AnthropicVertexClient in tests to enhance testing capabilities. - Updated package.json to include the @aws-sdk/client-s3 dependency. * feat(tests): add comprehensive tests for client compatibility types - Introduced a new test file to validate compatibility types for various API clients including OpenAI, Anthropic, Gemini, Aihubmix, NewAPI, and Vertex. - Implemented mock services to facilitate testing and ensure isolation of client behavior. - Added tests for both direct API clients and decorator pattern clients, ensuring correct compatibility type returns. - Enhanced middleware compatibility logic tests to verify correct identification of compatible clients. --------- Co-authored-by: one <wangan.cs@gmail.com>
This commit is contained in:
@@ -533,6 +533,10 @@ export function registerIpc(mainWindow: BrowserWindow, app: Electron.App) {
|
||||
return vertexAIService.getAuthHeaders(params)
|
||||
})
|
||||
|
||||
ipcMain.handle(IpcChannel.VertexAI_GetAccessToken, async (_, params) => {
|
||||
return vertexAIService.getAccessToken(params)
|
||||
})
|
||||
|
||||
ipcMain.handle(IpcChannel.VertexAI_ClearAuthCache, async (_, projectId: string, clientEmail?: string) => {
|
||||
vertexAIService.clearAuthCache(projectId, clientEmail)
|
||||
})
|
||||
|
||||
@@ -114,6 +114,37 @@ class VertexAIService {
|
||||
}
|
||||
}
|
||||
|
||||
async getAccessToken(params: VertexAIAuthParams): Promise<string> {
|
||||
const { projectId, serviceAccount } = params
|
||||
|
||||
if (!serviceAccount?.privateKey || !serviceAccount?.clientEmail) {
|
||||
throw new Error('Service account credentials are required')
|
||||
}
|
||||
|
||||
const formattedPrivateKey = this.formatPrivateKey(serviceAccount.privateKey)
|
||||
|
||||
const cacheKey = `${projectId}-${serviceAccount.clientEmail}`
|
||||
|
||||
let auth = this.authClients.get(cacheKey)
|
||||
|
||||
if (!auth) {
|
||||
auth = new GoogleAuth({
|
||||
credentials: {
|
||||
private_key: formattedPrivateKey,
|
||||
client_email: serviceAccount.clientEmail
|
||||
},
|
||||
projectId,
|
||||
scopes: [REQUIRED_VERTEX_AI_SCOPE]
|
||||
})
|
||||
|
||||
this.authClients.set(cacheKey, auth)
|
||||
}
|
||||
|
||||
const accessToken = await auth.getAccessToken()
|
||||
|
||||
return accessToken || ''
|
||||
}
|
||||
|
||||
/**
|
||||
* 清理指定项目的认证缓存
|
||||
*/
|
||||
|
||||
@@ -246,6 +246,8 @@ const api = {
|
||||
vertexAI: {
|
||||
getAuthHeaders: (params: { projectId: string; serviceAccount?: { privateKey: string; clientEmail: string } }) =>
|
||||
ipcRenderer.invoke(IpcChannel.VertexAI_GetAuthHeaders, params),
|
||||
getAccessToken: (params: { projectId: string; serviceAccount?: { privateKey: string; clientEmail: string } }) =>
|
||||
ipcRenderer.invoke(IpcChannel.VertexAI_GetAccessToken, params),
|
||||
clearAuthCache: (projectId: string, clientEmail?: string) =>
|
||||
ipcRenderer.invoke(IpcChannel.VertexAI_ClearAuthCache, projectId, clientEmail)
|
||||
},
|
||||
|
||||
@@ -0,0 +1,347 @@
|
||||
import { AihubmixAPIClient } from '@renderer/aiCore/clients/AihubmixAPIClient'
|
||||
import { AnthropicAPIClient } from '@renderer/aiCore/clients/anthropic/AnthropicAPIClient'
|
||||
import { ApiClientFactory } from '@renderer/aiCore/clients/ApiClientFactory'
|
||||
import { GeminiAPIClient } from '@renderer/aiCore/clients/gemini/GeminiAPIClient'
|
||||
import { VertexAPIClient } from '@renderer/aiCore/clients/gemini/VertexAPIClient'
|
||||
import { NewAPIClient } from '@renderer/aiCore/clients/NewAPIClient'
|
||||
import { OpenAIAPIClient } from '@renderer/aiCore/clients/openai/OpenAIApiClient'
|
||||
import { OpenAIResponseAPIClient } from '@renderer/aiCore/clients/openai/OpenAIResponseAPIClient'
|
||||
import { EndpointType, Model, Provider } from '@renderer/types'
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
vi.mock('@renderer/config/models', () => ({
|
||||
SYSTEM_MODELS: {
|
||||
defaultModel: [
|
||||
{ id: 'gpt-4', name: 'GPT-4' },
|
||||
{ id: 'gpt-4', name: 'GPT-4' },
|
||||
{ id: 'gpt-4', name: 'GPT-4' }
|
||||
],
|
||||
silicon: [],
|
||||
openai: [],
|
||||
anthropic: [],
|
||||
gemini: []
|
||||
},
|
||||
isOpenAILLMModel: vi.fn().mockReturnValue(true),
|
||||
isOpenAIChatCompletionOnlyModel: vi.fn().mockReturnValue(false),
|
||||
isAnthropicLLMModel: vi.fn().mockReturnValue(false),
|
||||
isGeminiLLMModel: vi.fn().mockReturnValue(false),
|
||||
isSupportedReasoningEffortOpenAIModel: vi.fn().mockReturnValue(false),
|
||||
isVisionModel: vi.fn().mockReturnValue(false),
|
||||
isClaudeReasoningModel: vi.fn().mockReturnValue(false),
|
||||
isReasoningModel: vi.fn().mockReturnValue(false),
|
||||
isWebSearchModel: vi.fn().mockReturnValue(false),
|
||||
findTokenLimit: vi.fn().mockReturnValue(4096),
|
||||
isFunctionCallingModel: vi.fn().mockReturnValue(false),
|
||||
DEFAULT_MAX_TOKENS: 4096
|
||||
}))
|
||||
|
||||
vi.mock('@renderer/services/AssistantService', () => ({
|
||||
getProviderByModel: vi.fn(),
|
||||
getAssistantSettings: vi.fn(),
|
||||
getDefaultAssistant: vi.fn().mockReturnValue({
|
||||
id: 'default',
|
||||
name: 'Default Assistant',
|
||||
prompt: '',
|
||||
settings: {}
|
||||
})
|
||||
}))
|
||||
|
||||
vi.mock('@renderer/services/FileManager', () => ({
|
||||
default: class {
|
||||
static async read() {
|
||||
return 'test content'
|
||||
}
|
||||
static async write() {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}))
|
||||
|
||||
vi.mock('@renderer/services/TokenService', () => ({
|
||||
estimateTextTokens: vi.fn().mockReturnValue(100)
|
||||
}))
|
||||
|
||||
vi.mock('@logger', () => ({
|
||||
loggerService: {
|
||||
withContext: vi.fn().mockReturnValue({
|
||||
debug: vi.fn(),
|
||||
info: vi.fn(),
|
||||
warn: vi.fn(),
|
||||
error: vi.fn(),
|
||||
silly: vi.fn()
|
||||
})
|
||||
}
|
||||
}))
|
||||
|
||||
// Mock additional services and hooks that might be imported
|
||||
vi.mock('@renderer/hooks/useVertexAI', () => ({
|
||||
getVertexAILocation: vi.fn().mockReturnValue('us-central1'),
|
||||
getVertexAIProjectId: vi.fn().mockReturnValue('test-project'),
|
||||
getVertexAIServiceAccount: vi.fn().mockReturnValue({
|
||||
privateKey: 'test-key',
|
||||
clientEmail: 'test@example.com'
|
||||
})
|
||||
}))
|
||||
|
||||
vi.mock('@renderer/hooks/useSettings', () => ({
|
||||
getStoreSetting: vi.fn().mockReturnValue({}),
|
||||
useSettings: vi.fn().mockReturnValue([{}, vi.fn()])
|
||||
}))
|
||||
|
||||
vi.mock('@renderer/store/settings', () => ({
|
||||
default: {},
|
||||
settingsSlice: {
|
||||
name: 'settings',
|
||||
reducer: vi.fn(),
|
||||
actions: {}
|
||||
}
|
||||
}))
|
||||
|
||||
vi.mock('@renderer/utils/abortController', () => ({
|
||||
addAbortController: vi.fn(),
|
||||
removeAbortController: vi.fn()
|
||||
}))
|
||||
|
||||
vi.mock('@anthropic-ai/sdk', () => ({
|
||||
default: vi.fn().mockImplementation(() => ({}))
|
||||
}))
|
||||
|
||||
vi.mock('@anthropic-ai/vertex-sdk', () => ({
|
||||
default: vi.fn().mockImplementation(() => ({}))
|
||||
}))
|
||||
|
||||
vi.mock('openai', () => ({
|
||||
default: vi.fn().mockImplementation(() => ({})),
|
||||
AzureOpenAI: vi.fn().mockImplementation(() => ({}))
|
||||
}))
|
||||
|
||||
vi.mock('@google/generative-ai', () => ({
|
||||
GoogleGenerativeAI: vi.fn().mockImplementation(() => ({}))
|
||||
}))
|
||||
|
||||
vi.mock('@google-cloud/vertexai', () => ({
|
||||
VertexAI: vi.fn().mockImplementation(() => ({}))
|
||||
}))
|
||||
|
||||
// Mock the circular dependency between VertexAPIClient and AnthropicVertexClient
|
||||
vi.mock('@renderer/aiCore/clients/anthropic/AnthropicVertexClient', () => {
|
||||
const MockAnthropicVertexClient = vi.fn()
|
||||
MockAnthropicVertexClient.prototype.getClientCompatibilityType = vi.fn().mockReturnValue(['AnthropicVertexAPIClient'])
|
||||
return {
|
||||
AnthropicVertexClient: MockAnthropicVertexClient
|
||||
}
|
||||
})
|
||||
|
||||
// Helper to create test provider
|
||||
const createTestProvider = (id: string, type: string): Provider => ({
|
||||
id,
|
||||
type: type as Provider['type'],
|
||||
name: 'Test Provider',
|
||||
apiKey: 'test-key',
|
||||
apiHost: 'https://api.test.com',
|
||||
models: []
|
||||
})
|
||||
|
||||
// Helper to create test model
|
||||
const createTestModel = (id: string, provider?: string, endpointType?: string): Model => ({
|
||||
id,
|
||||
name: 'Test Model',
|
||||
provider: provider || 'test',
|
||||
type: [],
|
||||
group: 'test',
|
||||
endpoint_type: endpointType as EndpointType
|
||||
})
|
||||
|
||||
describe('Client Compatibility Types', () => {
|
||||
let openaiProvider: Provider
|
||||
let anthropicProvider: Provider
|
||||
let geminiProvider: Provider
|
||||
let azureProvider: Provider
|
||||
let aihubmixProvider: Provider
|
||||
let newApiProvider: Provider
|
||||
let vertexProvider: Provider
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
|
||||
openaiProvider = createTestProvider('openai', 'openai')
|
||||
anthropicProvider = createTestProvider('anthropic', 'anthropic')
|
||||
geminiProvider = createTestProvider('gemini', 'gemini')
|
||||
azureProvider = createTestProvider('azure-openai', 'azure-openai')
|
||||
aihubmixProvider = createTestProvider('aihubmix', 'openai')
|
||||
newApiProvider = createTestProvider('new-api', 'openai')
|
||||
vertexProvider = createTestProvider('vertex', 'vertexai')
|
||||
})
|
||||
|
||||
describe('Direct API Clients', () => {
|
||||
it('should return correct compatibility type for OpenAIAPIClient', () => {
|
||||
const client = new OpenAIAPIClient(openaiProvider)
|
||||
const compatibilityTypes = client.getClientCompatibilityType()
|
||||
|
||||
expect(compatibilityTypes).toEqual(['OpenAIAPIClient'])
|
||||
})
|
||||
|
||||
it('should return correct compatibility type for AnthropicAPIClient', () => {
|
||||
const client = new AnthropicAPIClient(anthropicProvider)
|
||||
const compatibilityTypes = client.getClientCompatibilityType()
|
||||
|
||||
expect(compatibilityTypes).toEqual(['AnthropicAPIClient'])
|
||||
})
|
||||
|
||||
it('should return correct compatibility type for GeminiAPIClient', () => {
|
||||
const client = new GeminiAPIClient(geminiProvider)
|
||||
const compatibilityTypes = client.getClientCompatibilityType()
|
||||
|
||||
expect(compatibilityTypes).toEqual(['GeminiAPIClient'])
|
||||
})
|
||||
})
|
||||
|
||||
describe('Decorator Pattern API Clients', () => {
|
||||
it('should return OpenAIResponseAPIClient for OpenAIResponseAPIClient without model', () => {
|
||||
const client = new OpenAIResponseAPIClient(azureProvider)
|
||||
const compatibilityTypes = client.getClientCompatibilityType()
|
||||
|
||||
expect(compatibilityTypes).toEqual(['OpenAIResponseAPIClient'])
|
||||
})
|
||||
|
||||
it('should delegate to underlying client for OpenAIResponseAPIClient with model', () => {
|
||||
const client = new OpenAIResponseAPIClient(azureProvider)
|
||||
const testModel = createTestModel('gpt-4', 'azure-openai')
|
||||
|
||||
// Get the actual client selected for this model
|
||||
const actualClient = client.getClient(testModel)
|
||||
const compatibilityTypes = actualClient.getClientCompatibilityType(testModel)
|
||||
|
||||
// Should return OpenAIResponseAPIClient for non-chat-completion-only models
|
||||
expect(compatibilityTypes).toEqual(['OpenAIAPIClient'])
|
||||
})
|
||||
|
||||
it('should return AihubmixAPIClient for AihubmixAPIClient without model', () => {
|
||||
const client = new AihubmixAPIClient(aihubmixProvider)
|
||||
const compatibilityTypes = client.getClientCompatibilityType()
|
||||
|
||||
expect(compatibilityTypes).toEqual(['AihubmixAPIClient'])
|
||||
})
|
||||
|
||||
it('should delegate to underlying client for AihubmixAPIClient with model', () => {
|
||||
const client = new AihubmixAPIClient(aihubmixProvider)
|
||||
const testModel = createTestModel('gpt-4', 'openai')
|
||||
|
||||
// Get the actual client selected for this model
|
||||
const actualClient = client.getClientForModel(testModel)
|
||||
const compatibilityTypes = actualClient.getClientCompatibilityType(testModel)
|
||||
|
||||
// Should return the actual underlying client type based on model (OpenAI models use OpenAIResponseAPIClient in Aihubmix)
|
||||
expect(compatibilityTypes).toEqual(['OpenAIResponseAPIClient'])
|
||||
})
|
||||
|
||||
it('should return NewAPIClient for NewAPIClient without model', () => {
|
||||
const client = new NewAPIClient(newApiProvider)
|
||||
const compatibilityTypes = client.getClientCompatibilityType()
|
||||
|
||||
expect(compatibilityTypes).toEqual(['NewAPIClient'])
|
||||
})
|
||||
|
||||
it('should delegate to underlying client for NewAPIClient with model', () => {
|
||||
const client = new NewAPIClient(newApiProvider)
|
||||
const testModel = createTestModel('gpt-4', 'openai', 'openai-response')
|
||||
|
||||
// Get the actual client selected for this model
|
||||
const actualClient = client.getClientForModel(testModel)
|
||||
const compatibilityTypes = actualClient.getClientCompatibilityType(testModel)
|
||||
|
||||
// Should return the actual underlying client type based on model
|
||||
expect(compatibilityTypes).toEqual(['OpenAIResponseAPIClient'])
|
||||
})
|
||||
|
||||
it('should return VertexAPIClient for VertexAPIClient without model', () => {
|
||||
const client = new VertexAPIClient(vertexProvider)
|
||||
const compatibilityTypes = client.getClientCompatibilityType()
|
||||
|
||||
expect(compatibilityTypes).toEqual(['VertexAPIClient'])
|
||||
})
|
||||
|
||||
it('should delegate to underlying client for VertexAPIClient with model', () => {
|
||||
const client = new VertexAPIClient(vertexProvider)
|
||||
const testModel = createTestModel('claude-3-5-sonnet', 'vertexai')
|
||||
|
||||
// Get the actual client selected for this model
|
||||
const actualClient = client.getClient(testModel)
|
||||
const compatibilityTypes = actualClient.getClientCompatibilityType(testModel)
|
||||
|
||||
// Should return the actual underlying client type based on model (Claude models use AnthropicVertexClient)
|
||||
expect(compatibilityTypes).toEqual(['AnthropicVertexAPIClient'])
|
||||
})
|
||||
})
|
||||
|
||||
describe('Middleware Compatibility Logic', () => {
|
||||
it('should correctly identify OpenAI compatible clients', () => {
|
||||
const openaiClient = new OpenAIAPIClient(openaiProvider)
|
||||
const openaiResponseClient = new OpenAIResponseAPIClient(azureProvider)
|
||||
|
||||
const openaiTypes = openaiClient.getClientCompatibilityType()
|
||||
const responseTypes = openaiResponseClient.getClientCompatibilityType()
|
||||
|
||||
// Test the logic from completions method line 94
|
||||
const isOpenAICompatible = (types: string[]) =>
|
||||
types.includes('OpenAIAPIClient') || types.includes('OpenAIResponseAPIClient')
|
||||
|
||||
expect(isOpenAICompatible(openaiTypes)).toBe(true)
|
||||
expect(isOpenAICompatible(responseTypes)).toBe(true)
|
||||
})
|
||||
|
||||
it('should correctly identify Anthropic or OpenAIResponse compatible clients', () => {
|
||||
const anthropicClient = new AnthropicAPIClient(anthropicProvider)
|
||||
const openaiResponseClient = new OpenAIResponseAPIClient(azureProvider)
|
||||
const openaiClient = new OpenAIAPIClient(openaiProvider)
|
||||
|
||||
const anthropicTypes = anthropicClient.getClientCompatibilityType()
|
||||
const responseTypes = openaiResponseClient.getClientCompatibilityType()
|
||||
const openaiTypes = openaiClient.getClientCompatibilityType()
|
||||
|
||||
// Test the logic from completions method line 101
|
||||
const isAnthropicOrOpenAIResponseCompatible = (types: string[]) =>
|
||||
types.includes('AnthropicAPIClient') || types.includes('OpenAIResponseAPIClient')
|
||||
|
||||
expect(isAnthropicOrOpenAIResponseCompatible(anthropicTypes)).toBe(true)
|
||||
expect(isAnthropicOrOpenAIResponseCompatible(responseTypes)).toBe(true)
|
||||
expect(isAnthropicOrOpenAIResponseCompatible(openaiTypes)).toBe(false)
|
||||
})
|
||||
|
||||
it('should handle non-compatible clients correctly', () => {
|
||||
const geminiClient = new GeminiAPIClient(geminiProvider)
|
||||
const geminiTypes = geminiClient.getClientCompatibilityType()
|
||||
|
||||
// Test that Gemini is not OpenAI compatible
|
||||
const isOpenAICompatible = (types: string[]) =>
|
||||
types.includes('OpenAIAPIClient') || types.includes('OpenAIResponseAPIClient')
|
||||
|
||||
// Test that Gemini is not Anthropic/OpenAIResponse compatible
|
||||
const isAnthropicOrOpenAIResponseCompatible = (types: string[]) =>
|
||||
types.includes('AnthropicAPIClient') || types.includes('OpenAIResponseAPIClient')
|
||||
|
||||
expect(isOpenAICompatible(geminiTypes)).toBe(false)
|
||||
expect(isAnthropicOrOpenAIResponseCompatible(geminiTypes)).toBe(false)
|
||||
})
|
||||
})
|
||||
|
||||
describe('Factory Integration', () => {
|
||||
it('should return correct compatibility types for factory-created clients', () => {
|
||||
const testCases = [
|
||||
{ provider: openaiProvider, expectedType: 'OpenAIAPIClient' },
|
||||
{ provider: anthropicProvider, expectedType: 'AnthropicAPIClient' },
|
||||
{ provider: azureProvider, expectedType: 'OpenAIResponseAPIClient' },
|
||||
{ provider: aihubmixProvider, expectedType: 'AihubmixAPIClient' },
|
||||
{ provider: newApiProvider, expectedType: 'NewAPIClient' },
|
||||
{ provider: vertexProvider, expectedType: 'VertexAPIClient' }
|
||||
]
|
||||
|
||||
testCases.forEach(({ provider, expectedType }) => {
|
||||
const client = ApiClientFactory.create(provider)
|
||||
const compatibilityTypes = client.getClientCompatibilityType()
|
||||
|
||||
expect(compatibilityTypes).toContain(expectedType)
|
||||
})
|
||||
})
|
||||
})
|
||||
})
|
||||
@@ -31,6 +31,9 @@ vi.mock('../AihubmixAPIClient', () => ({
|
||||
vi.mock('../anthropic/AnthropicAPIClient', () => ({
|
||||
AnthropicAPIClient: vi.fn().mockImplementation(() => ({}))
|
||||
}))
|
||||
vi.mock('../anthropic/AnthropicVertexClient', () => ({
|
||||
AnthropicVertexClient: vi.fn().mockImplementation(() => ({}))
|
||||
}))
|
||||
vi.mock('../gemini/GeminiAPIClient', () => ({
|
||||
GeminiAPIClient: vi.fn().mockImplementation(() => ({}))
|
||||
}))
|
||||
|
||||
@@ -24,6 +24,7 @@ import {
|
||||
WebSearchToolResultError
|
||||
} from '@anthropic-ai/sdk/resources/messages'
|
||||
import { MessageStream } from '@anthropic-ai/sdk/resources/messages/messages'
|
||||
import AnthropicVertex from '@anthropic-ai/vertex-sdk'
|
||||
import { loggerService } from '@logger'
|
||||
import { GenericChunk } from '@renderer/aiCore/middleware/schemas'
|
||||
import { DEFAULT_MAX_TOKENS } from '@renderer/config/constant'
|
||||
@@ -76,7 +77,7 @@ import { AnthropicStreamListener, RawStreamListener, RequestTransformer, Respons
|
||||
const logger = loggerService.withContext('AnthropicAPIClient')
|
||||
|
||||
export class AnthropicAPIClient extends BaseApiClient<
|
||||
Anthropic,
|
||||
Anthropic | AnthropicVertex,
|
||||
AnthropicSdkParams,
|
||||
AnthropicSdkRawOutput,
|
||||
AnthropicSdkRawChunk,
|
||||
@@ -84,11 +85,12 @@ export class AnthropicAPIClient extends BaseApiClient<
|
||||
ToolUseBlock,
|
||||
ToolUnion
|
||||
> {
|
||||
sdkInstance: Anthropic | AnthropicVertex | undefined = undefined
|
||||
constructor(provider: Provider) {
|
||||
super(provider)
|
||||
}
|
||||
|
||||
async getSdkInstance(): Promise<Anthropic> {
|
||||
async getSdkInstance(): Promise<Anthropic | AnthropicVertex> {
|
||||
if (this.sdkInstance) {
|
||||
return this.sdkInstance
|
||||
}
|
||||
@@ -108,7 +110,7 @@ export class AnthropicAPIClient extends BaseApiClient<
|
||||
payload: AnthropicSdkParams,
|
||||
options?: Anthropic.RequestOptions
|
||||
): Promise<AnthropicSdkRawOutput> {
|
||||
const sdk = await this.getSdkInstance()
|
||||
const sdk = (await this.getSdkInstance()) as Anthropic
|
||||
if (payload.stream) {
|
||||
return sdk.messages.stream(payload, options)
|
||||
}
|
||||
@@ -122,7 +124,7 @@ export class AnthropicAPIClient extends BaseApiClient<
|
||||
}
|
||||
|
||||
override async listModels(): Promise<Anthropic.ModelInfo[]> {
|
||||
const sdk = await this.getSdkInstance()
|
||||
const sdk = (await this.getSdkInstance()) as Anthropic
|
||||
const response = await sdk.models.list()
|
||||
return response.data
|
||||
}
|
||||
|
||||
@@ -0,0 +1,103 @@
|
||||
import Anthropic from '@anthropic-ai/sdk'
|
||||
import AnthropicVertex from '@anthropic-ai/vertex-sdk'
|
||||
import { getVertexAILocation, getVertexAIProjectId, getVertexAIServiceAccount } from '@renderer/hooks/useVertexAI'
|
||||
import { loggerService } from '@renderer/services/LoggerService'
|
||||
import { Provider } from '@renderer/types'
|
||||
import { isEmpty } from 'lodash'
|
||||
|
||||
const logger = loggerService.withContext('AnthropicVertexClient')
|
||||
import { AnthropicAPIClient } from './AnthropicAPIClient'
|
||||
|
||||
export class AnthropicVertexClient extends AnthropicAPIClient {
|
||||
sdkInstance: AnthropicVertex | undefined = undefined
|
||||
private authHeaders?: Record<string, string>
|
||||
private authHeadersExpiry?: number
|
||||
|
||||
constructor(provider: Provider) {
|
||||
super(provider)
|
||||
}
|
||||
|
||||
private formatApiHost(host: string): string {
|
||||
const forceUseOriginalHost = () => {
|
||||
return host.endsWith('/')
|
||||
}
|
||||
|
||||
if (!host) {
|
||||
return host
|
||||
}
|
||||
|
||||
return forceUseOriginalHost() ? host : `${host}/v1/`
|
||||
}
|
||||
|
||||
override getBaseURL() {
|
||||
return this.formatApiHost(this.provider.apiHost)
|
||||
}
|
||||
|
||||
override async getSdkInstance(): Promise<AnthropicVertex> {
|
||||
if (this.sdkInstance) {
|
||||
return this.sdkInstance
|
||||
}
|
||||
|
||||
const serviceAccount = getVertexAIServiceAccount()
|
||||
const projectId = getVertexAIProjectId()
|
||||
const location = getVertexAILocation()
|
||||
|
||||
if (!serviceAccount.privateKey || !serviceAccount.clientEmail || !projectId || !location) {
|
||||
throw new Error('Vertex AI settings are not configured')
|
||||
}
|
||||
|
||||
const authHeaders = await this.getServiceAccountAuthHeaders()
|
||||
|
||||
this.sdkInstance = new AnthropicVertex({
|
||||
projectId: projectId,
|
||||
region: location,
|
||||
dangerouslyAllowBrowser: true,
|
||||
defaultHeaders: authHeaders,
|
||||
baseURL: isEmpty(this.getBaseURL()) ? undefined : this.getBaseURL()
|
||||
})
|
||||
|
||||
return this.sdkInstance
|
||||
}
|
||||
|
||||
override async listModels(): Promise<Anthropic.ModelInfo[]> {
|
||||
throw new Error('Vertex AI does not support listModels method.')
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取认证头,如果配置了 service account 则从主进程获取
|
||||
*/
|
||||
private async getServiceAccountAuthHeaders(): Promise<Record<string, string> | undefined> {
|
||||
const serviceAccount = getVertexAIServiceAccount()
|
||||
const projectId = getVertexAIProjectId()
|
||||
|
||||
// 检查是否配置了 service account
|
||||
if (!serviceAccount.privateKey || !serviceAccount.clientEmail || !projectId) {
|
||||
return undefined
|
||||
}
|
||||
|
||||
// 检查是否已有有效的认证头(提前 5 分钟过期)
|
||||
const now = Date.now()
|
||||
if (this.authHeaders && this.authHeadersExpiry && this.authHeadersExpiry - now > 5 * 60 * 1000) {
|
||||
return this.authHeaders
|
||||
}
|
||||
|
||||
try {
|
||||
// 从主进程获取认证头
|
||||
this.authHeaders = await window.api.vertexAI.getAuthHeaders({
|
||||
projectId,
|
||||
serviceAccount: {
|
||||
privateKey: serviceAccount.privateKey,
|
||||
clientEmail: serviceAccount.clientEmail
|
||||
}
|
||||
})
|
||||
|
||||
// 设置过期时间(通常认证头有效期为 1 小时)
|
||||
this.authHeadersExpiry = now + 60 * 60 * 1000
|
||||
|
||||
return this.authHeaders
|
||||
} catch (error: any) {
|
||||
logger.error('Failed to get auth headers:', error)
|
||||
throw new Error(`Service Account authentication failed: ${error.message}`)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,17 +1,54 @@
|
||||
import { GoogleGenAI } from '@google/genai'
|
||||
import { loggerService } from '@logger'
|
||||
import { getVertexAILocation, getVertexAIProjectId, getVertexAIServiceAccount } from '@renderer/hooks/useVertexAI'
|
||||
import { Provider } from '@renderer/types'
|
||||
import { Model, Provider } from '@renderer/types'
|
||||
import { isEmpty } from 'lodash'
|
||||
|
||||
import { AnthropicVertexClient } from '../anthropic/AnthropicVertexClient'
|
||||
import { GeminiAPIClient } from './GeminiAPIClient'
|
||||
|
||||
const logger = loggerService.withContext('VertexAPIClient')
|
||||
export class VertexAPIClient extends GeminiAPIClient {
|
||||
private authHeaders?: Record<string, string>
|
||||
private authHeadersExpiry?: number
|
||||
private anthropicVertexClient: AnthropicVertexClient
|
||||
|
||||
constructor(provider: Provider) {
|
||||
super(provider)
|
||||
this.anthropicVertexClient = new AnthropicVertexClient(provider)
|
||||
}
|
||||
|
||||
override getClientCompatibilityType(model?: Model): string[] {
|
||||
if (!model) {
|
||||
return [this.constructor.name]
|
||||
}
|
||||
|
||||
const actualClient = this.getClient(model)
|
||||
if (actualClient === this) {
|
||||
return [this.constructor.name]
|
||||
}
|
||||
|
||||
return actualClient.getClientCompatibilityType(model)
|
||||
}
|
||||
|
||||
public getClient(model: Model) {
|
||||
if (model.id.includes('claude')) {
|
||||
return this.anthropicVertexClient
|
||||
}
|
||||
return this
|
||||
}
|
||||
|
||||
private formatApiHost(baseUrl: string) {
|
||||
if (baseUrl.endsWith('/v1/')) {
|
||||
baseUrl = baseUrl.slice(0, -4)
|
||||
} else if (baseUrl.endsWith('/v1')) {
|
||||
baseUrl = baseUrl.slice(0, -3)
|
||||
}
|
||||
return baseUrl
|
||||
}
|
||||
|
||||
override getBaseURL() {
|
||||
return this.formatApiHost(this.provider.apiHost)
|
||||
}
|
||||
|
||||
override async getSdkInstance() {
|
||||
@@ -35,7 +72,8 @@ export class VertexAPIClient extends GeminiAPIClient {
|
||||
location: location,
|
||||
httpOptions: {
|
||||
apiVersion: this.getApiVersion(),
|
||||
headers: authHeaders
|
||||
headers: authHeaders,
|
||||
baseUrl: isEmpty(this.getBaseURL()) ? undefined : this.getBaseURL()
|
||||
}
|
||||
})
|
||||
|
||||
|
||||
@@ -10,6 +10,7 @@ import type { RequestOptions, SdkModel } from '@renderer/types/sdk'
|
||||
import { isEnabledToolUse } from '@renderer/utils/mcp-tools'
|
||||
|
||||
import { AihubmixAPIClient } from './clients/AihubmixAPIClient'
|
||||
import { VertexAPIClient } from './clients/gemini/VertexAPIClient'
|
||||
import { NewAPIClient } from './clients/NewAPIClient'
|
||||
import { OpenAIResponseAPIClient } from './clients/openai/OpenAIResponseAPIClient'
|
||||
import { CompletionsMiddlewareBuilder } from './middleware/builder'
|
||||
@@ -61,6 +62,8 @@ export default class AiProvider {
|
||||
} else if (this.apiClient instanceof OpenAIResponseAPIClient) {
|
||||
// OpenAIResponseAPIClient: 根据模型特征选择API类型
|
||||
client = this.apiClient.getClient(model) as BaseApiClient
|
||||
} else if (this.apiClient instanceof VertexAPIClient) {
|
||||
client = this.apiClient.getClient(model) as BaseApiClient
|
||||
} else {
|
||||
// 其他client直接使用
|
||||
client = this.apiClient
|
||||
|
||||
@@ -2324,6 +2324,7 @@
|
||||
"search_placeholder": "Search model id or name",
|
||||
"title": "Model Provider",
|
||||
"vertex_ai": {
|
||||
"api_host_help": "The API host for Vertex AI, not recommended to fill in, generally applicable to reverse proxy",
|
||||
"documentation": "View official documentation for more configuration details:",
|
||||
"learn_more": "Learn More",
|
||||
"location": "Location",
|
||||
|
||||
@@ -2324,6 +2324,7 @@
|
||||
"search_placeholder": "モデルIDまたは名前を検索",
|
||||
"title": "モデルプロバイダー",
|
||||
"vertex_ai": {
|
||||
"api_host_help": "Vertex AIのAPIアドレス。逆プロキシに適しています。",
|
||||
"documentation": "詳細な設定については、公式ドキュメントを参照してください:",
|
||||
"learn_more": "詳細を確認",
|
||||
"location": "場所",
|
||||
|
||||
@@ -2324,6 +2324,7 @@
|
||||
"search_placeholder": "Поиск по ID или имени модели",
|
||||
"title": "Провайдеры моделей",
|
||||
"vertex_ai": {
|
||||
"api_host_help": "API-адрес Vertex AI, не рекомендуется заполнять, обычно применим к обратным прокси",
|
||||
"documentation": "Смотрите официальную документацию для получения более подробной информации о конфигурации:",
|
||||
"learn_more": "Узнать больше",
|
||||
"location": "Местоположение",
|
||||
|
||||
@@ -2324,6 +2324,7 @@
|
||||
"search_placeholder": "搜索模型 ID 或名称",
|
||||
"title": "模型服务",
|
||||
"vertex_ai": {
|
||||
"api_host_help": "Vertex AI 的 API 地址,不建议填写,通常适用于反向代理",
|
||||
"documentation": "查看官方文档了解更多配置详情:",
|
||||
"learn_more": "了解更多",
|
||||
"location": "地区",
|
||||
|
||||
@@ -2324,6 +2324,7 @@
|
||||
"search_placeholder": "搜尋模型 ID 或名稱",
|
||||
"title": "模型提供者",
|
||||
"vertex_ai": {
|
||||
"api_host_help": "Vertex AI 的 API 地址,不建議填寫,通常適用於反向代理",
|
||||
"documentation": "檢視官方文件以取得更多設定詳細資訊:",
|
||||
"learn_more": "瞭解更多",
|
||||
"location": "地區",
|
||||
|
||||
@@ -372,7 +372,7 @@ const ProviderSetting: FC<Props> = ({ providerId }) => {
|
||||
{provider.id === 'lmstudio' && <LMStudioSettings />}
|
||||
{provider.id === 'gpustack' && <GPUStackSettings />}
|
||||
{provider.id === 'copilot' && <GithubCopilotSettings providerId={provider.id} />}
|
||||
{provider.id === 'vertexai' && <VertexAISettings />}
|
||||
{provider.id === 'vertexai' && <VertexAISettings providerId={provider.id} />}
|
||||
<ModelList providerId={provider.id} />
|
||||
</SettingContainer>
|
||||
)
|
||||
|
||||
@@ -1,13 +1,18 @@
|
||||
import { HStack } from '@renderer/components/Layout'
|
||||
import { PROVIDER_CONFIG } from '@renderer/config/providers'
|
||||
import { useProvider } from '@renderer/hooks/useProvider'
|
||||
import { useVertexAISettings } from '@renderer/hooks/useVertexAI'
|
||||
import { Alert, Input } from 'antd'
|
||||
import { Alert, Input, Space } from 'antd'
|
||||
import { FC, useState } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
|
||||
import { SettingHelpLink, SettingHelpText, SettingHelpTextRow, SettingSubtitle } from '..'
|
||||
|
||||
const VertexAISettings: FC = () => {
|
||||
interface Props {
|
||||
providerId: string
|
||||
}
|
||||
|
||||
const VertexAISettings: FC<Props> = ({ providerId }) => {
|
||||
const { t } = useTranslation()
|
||||
const {
|
||||
projectId,
|
||||
@@ -19,11 +24,18 @@ const VertexAISettings: FC = () => {
|
||||
setServiceAccountClientEmail
|
||||
} = useVertexAISettings()
|
||||
|
||||
const [localProjectId, setLocalProjectId] = useState(projectId)
|
||||
const [localLocation, setLocalLocation] = useState(location)
|
||||
|
||||
const { provider, updateProvider } = useProvider(providerId)
|
||||
const [apiHost, setApiHost] = useState(provider.apiHost)
|
||||
|
||||
const providerConfig = PROVIDER_CONFIG['vertexai']
|
||||
const apiKeyWebsite = providerConfig?.websites?.apiKey
|
||||
|
||||
const [localProjectId, setLocalProjectId] = useState(projectId)
|
||||
const [localLocation, setLocalLocation] = useState(location)
|
||||
const onUpdateApiHost = () => {
|
||||
updateProvider({ apiHost })
|
||||
}
|
||||
|
||||
const handleProjectIdChange = (e: React.ChangeEvent<HTMLInputElement>) => {
|
||||
setLocalProjectId(e.target.value)
|
||||
@@ -60,6 +72,18 @@ const VertexAISettings: FC = () => {
|
||||
|
||||
return (
|
||||
<>
|
||||
<SettingSubtitle>{t('settings.provider.api_host')}</SettingSubtitle>
|
||||
<Space.Compact style={{ width: '100%', marginTop: 5 }}>
|
||||
<Input
|
||||
value={apiHost}
|
||||
placeholder={t('settings.provider.api_host')}
|
||||
onChange={(e) => setApiHost(e.target.value)}
|
||||
onBlur={onUpdateApiHost}
|
||||
/>
|
||||
</Space.Compact>
|
||||
<SettingHelpTextRow>
|
||||
<SettingHelpText>{t('settings.provider.vertex_ai.api_host_help')}</SettingHelpText>
|
||||
</SettingHelpTextRow>
|
||||
<SettingSubtitle style={{ marginTop: 5 }}>
|
||||
{t('settings.provider.vertex_ai.service_account.title')}
|
||||
</SettingSubtitle>
|
||||
|
||||
@@ -8,6 +8,7 @@ import {
|
||||
ToolUseBlock
|
||||
} from '@anthropic-ai/sdk/resources'
|
||||
import { MessageStream } from '@anthropic-ai/sdk/resources/messages/messages'
|
||||
import AnthropicVertex from '@anthropic-ai/vertex-sdk'
|
||||
import {
|
||||
Content,
|
||||
CreateChatParameters,
|
||||
@@ -23,7 +24,7 @@ import { Stream } from 'openai/streaming'
|
||||
|
||||
import { EndpointType } from './index'
|
||||
|
||||
export type SdkInstance = OpenAI | AzureOpenAI | Anthropic | GoogleGenAI
|
||||
export type SdkInstance = OpenAI | AzureOpenAI | Anthropic | AnthropicVertex | GoogleGenAI
|
||||
export type SdkParams = OpenAISdkParams | OpenAIResponseSdkParams | AnthropicSdkParams | GeminiSdkParams
|
||||
export type SdkRawChunk = OpenAISdkRawChunk | OpenAIResponseSdkRawChunk | AnthropicSdkRawChunk | GeminiSdkRawChunk
|
||||
export type SdkRawOutput = OpenAISdkRawOutput | OpenAIResponseSdkRawOutput | AnthropicSdkRawOutput | GeminiSdkRawOutput
|
||||
|
||||
Reference in New Issue
Block a user