Feat/vertex-claude-support (#7564)

* feat(migrate): add default settings for assistants during migration

- Introduced a new migration step to assign default settings for assistants that lack configuration.
- Default settings include temperature, context count, and other parameters to ensure consistent behavior across the application.

* chore(store): increment version number to 115 for persisted reducer

* feat(vertex-sdk): integrate Anthropic Vertex SDK and add access token retrieval

- Added support for the new `@anthropic-ai/vertex-sdk` in the project.
- Introduced a new IPC channel `VertexAI_GetAccessToken` to retrieve access tokens.
- Implemented `getAccessToken` method in `VertexAIService` to handle service account authentication.
- Updated the `IpcChannel` enum and related IPC handlers to support the new functionality.
- Enhanced the `VertexAPIClient` to utilize the `AnthropicVertexClient` for model handling.
- Refactored existing code to accommodate the integration of the Vertex SDK and improve modularity.

* feat(vertex-ai): enhance VertexAI settings and API host management

- Added a new method to format the API host URL in both AnthropicVertexClient and VertexAPIClient.
- Updated getBaseURL methods to utilize the new formatting logic.
- Enhanced VertexAISettings component to include an input for API host configuration, with help text for user guidance.
- Updated localization files to include new help text for the API host field in multiple languages.

* fix(vertex-sdk): update baseURL handling and patch dependencies

- Refactored baseURL assignment in AnthropicVertexClient to ensure it defaults to undefined when the URL is empty.
- Updated yarn.lock to reflect changes in dependency resolution and checksum for @anthropic-ai/vertex-sdk patch.

* refactor(VertexAISetting): use provider.id rather than provider

* refactor: improve API host formatting in AnthropicVertexClient

- Updated the `formatApiHost` method to streamline host URL handling.
- Introduced a helper function to determine if the original host should be used based on its format.
- Ensured consistent appending of the `/v1/` path for valid API requests.

* fix: handle empty host in AnthropicVertexClient

- Added a check in the `getBaseURL` method to return the host if it is empty, preventing potential errors.
- Included a console log for the base URL to aid in debugging and verification of the URL formatting.

* feat(AnthropicVertexClient): add logging for authentication errors and mock client in tests

- Introduced logging functionality in AnthropicVertexClient to replace console.error with logger service for better error tracking.
- Added mock implementation for AnthropicVertexClient in tests to enhance testing capabilities.
- Updated package.json to include the @aws-sdk/client-s3 dependency.

* feat(tests): add comprehensive tests for client compatibility types

- Introduced a new test file to validate compatibility types for various API clients including OpenAI, Anthropic, Gemini, Aihubmix, NewAPI, and Vertex.
- Implemented mock services to facilitate testing and ensure isolation of client behavior.
- Added tests for both direct API clients and decorator pattern clients, ensuring correct compatibility type returns.
- Enhanced middleware compatibility logic tests to verify correct identification of compatible clients.

---------

Co-authored-by: one <wangan.cs@gmail.com>
This commit is contained in:
SuYao
2025-07-24 23:46:32 +08:00
committed by GitHub
parent 0bb3061f8d
commit 4c0167cc03
21 changed files with 804 additions and 13 deletions
+4
View File
@@ -533,6 +533,10 @@ export function registerIpc(mainWindow: BrowserWindow, app: Electron.App) {
return vertexAIService.getAuthHeaders(params)
})
ipcMain.handle(IpcChannel.VertexAI_GetAccessToken, async (_, params) => {
return vertexAIService.getAccessToken(params)
})
ipcMain.handle(IpcChannel.VertexAI_ClearAuthCache, async (_, projectId: string, clientEmail?: string) => {
vertexAIService.clearAuthCache(projectId, clientEmail)
})
+31
View File
@@ -114,6 +114,37 @@ class VertexAIService {
}
}
async getAccessToken(params: VertexAIAuthParams): Promise<string> {
const { projectId, serviceAccount } = params
if (!serviceAccount?.privateKey || !serviceAccount?.clientEmail) {
throw new Error('Service account credentials are required')
}
const formattedPrivateKey = this.formatPrivateKey(serviceAccount.privateKey)
const cacheKey = `${projectId}-${serviceAccount.clientEmail}`
let auth = this.authClients.get(cacheKey)
if (!auth) {
auth = new GoogleAuth({
credentials: {
private_key: formattedPrivateKey,
client_email: serviceAccount.clientEmail
},
projectId,
scopes: [REQUIRED_VERTEX_AI_SCOPE]
})
this.authClients.set(cacheKey, auth)
}
const accessToken = await auth.getAccessToken()
return accessToken || ''
}
/**
* 清理指定项目的认证缓存
*/
+2
View File
@@ -246,6 +246,8 @@ const api = {
vertexAI: {
getAuthHeaders: (params: { projectId: string; serviceAccount?: { privateKey: string; clientEmail: string } }) =>
ipcRenderer.invoke(IpcChannel.VertexAI_GetAuthHeaders, params),
getAccessToken: (params: { projectId: string; serviceAccount?: { privateKey: string; clientEmail: string } }) =>
ipcRenderer.invoke(IpcChannel.VertexAI_GetAccessToken, params),
clearAuthCache: (projectId: string, clientEmail?: string) =>
ipcRenderer.invoke(IpcChannel.VertexAI_ClearAuthCache, projectId, clientEmail)
},
@@ -0,0 +1,347 @@
import { AihubmixAPIClient } from '@renderer/aiCore/clients/AihubmixAPIClient'
import { AnthropicAPIClient } from '@renderer/aiCore/clients/anthropic/AnthropicAPIClient'
import { ApiClientFactory } from '@renderer/aiCore/clients/ApiClientFactory'
import { GeminiAPIClient } from '@renderer/aiCore/clients/gemini/GeminiAPIClient'
import { VertexAPIClient } from '@renderer/aiCore/clients/gemini/VertexAPIClient'
import { NewAPIClient } from '@renderer/aiCore/clients/NewAPIClient'
import { OpenAIAPIClient } from '@renderer/aiCore/clients/openai/OpenAIApiClient'
import { OpenAIResponseAPIClient } from '@renderer/aiCore/clients/openai/OpenAIResponseAPIClient'
import { EndpointType, Model, Provider } from '@renderer/types'
import { beforeEach, describe, expect, it, vi } from 'vitest'
vi.mock('@renderer/config/models', () => ({
SYSTEM_MODELS: {
defaultModel: [
{ id: 'gpt-4', name: 'GPT-4' },
{ id: 'gpt-4', name: 'GPT-4' },
{ id: 'gpt-4', name: 'GPT-4' }
],
silicon: [],
openai: [],
anthropic: [],
gemini: []
},
isOpenAILLMModel: vi.fn().mockReturnValue(true),
isOpenAIChatCompletionOnlyModel: vi.fn().mockReturnValue(false),
isAnthropicLLMModel: vi.fn().mockReturnValue(false),
isGeminiLLMModel: vi.fn().mockReturnValue(false),
isSupportedReasoningEffortOpenAIModel: vi.fn().mockReturnValue(false),
isVisionModel: vi.fn().mockReturnValue(false),
isClaudeReasoningModel: vi.fn().mockReturnValue(false),
isReasoningModel: vi.fn().mockReturnValue(false),
isWebSearchModel: vi.fn().mockReturnValue(false),
findTokenLimit: vi.fn().mockReturnValue(4096),
isFunctionCallingModel: vi.fn().mockReturnValue(false),
DEFAULT_MAX_TOKENS: 4096
}))
vi.mock('@renderer/services/AssistantService', () => ({
getProviderByModel: vi.fn(),
getAssistantSettings: vi.fn(),
getDefaultAssistant: vi.fn().mockReturnValue({
id: 'default',
name: 'Default Assistant',
prompt: '',
settings: {}
})
}))
vi.mock('@renderer/services/FileManager', () => ({
default: class {
static async read() {
return 'test content'
}
static async write() {
return true
}
}
}))
vi.mock('@renderer/services/TokenService', () => ({
estimateTextTokens: vi.fn().mockReturnValue(100)
}))
vi.mock('@logger', () => ({
loggerService: {
withContext: vi.fn().mockReturnValue({
debug: vi.fn(),
info: vi.fn(),
warn: vi.fn(),
error: vi.fn(),
silly: vi.fn()
})
}
}))
// Mock additional services and hooks that might be imported
vi.mock('@renderer/hooks/useVertexAI', () => ({
getVertexAILocation: vi.fn().mockReturnValue('us-central1'),
getVertexAIProjectId: vi.fn().mockReturnValue('test-project'),
getVertexAIServiceAccount: vi.fn().mockReturnValue({
privateKey: 'test-key',
clientEmail: 'test@example.com'
})
}))
vi.mock('@renderer/hooks/useSettings', () => ({
getStoreSetting: vi.fn().mockReturnValue({}),
useSettings: vi.fn().mockReturnValue([{}, vi.fn()])
}))
vi.mock('@renderer/store/settings', () => ({
default: {},
settingsSlice: {
name: 'settings',
reducer: vi.fn(),
actions: {}
}
}))
vi.mock('@renderer/utils/abortController', () => ({
addAbortController: vi.fn(),
removeAbortController: vi.fn()
}))
vi.mock('@anthropic-ai/sdk', () => ({
default: vi.fn().mockImplementation(() => ({}))
}))
vi.mock('@anthropic-ai/vertex-sdk', () => ({
default: vi.fn().mockImplementation(() => ({}))
}))
vi.mock('openai', () => ({
default: vi.fn().mockImplementation(() => ({})),
AzureOpenAI: vi.fn().mockImplementation(() => ({}))
}))
vi.mock('@google/generative-ai', () => ({
GoogleGenerativeAI: vi.fn().mockImplementation(() => ({}))
}))
vi.mock('@google-cloud/vertexai', () => ({
VertexAI: vi.fn().mockImplementation(() => ({}))
}))
// Mock the circular dependency between VertexAPIClient and AnthropicVertexClient
vi.mock('@renderer/aiCore/clients/anthropic/AnthropicVertexClient', () => {
const MockAnthropicVertexClient = vi.fn()
MockAnthropicVertexClient.prototype.getClientCompatibilityType = vi.fn().mockReturnValue(['AnthropicVertexAPIClient'])
return {
AnthropicVertexClient: MockAnthropicVertexClient
}
})
// Helper to create test provider
const createTestProvider = (id: string, type: string): Provider => ({
id,
type: type as Provider['type'],
name: 'Test Provider',
apiKey: 'test-key',
apiHost: 'https://api.test.com',
models: []
})
// Helper to create test model
const createTestModel = (id: string, provider?: string, endpointType?: string): Model => ({
id,
name: 'Test Model',
provider: provider || 'test',
type: [],
group: 'test',
endpoint_type: endpointType as EndpointType
})
describe('Client Compatibility Types', () => {
let openaiProvider: Provider
let anthropicProvider: Provider
let geminiProvider: Provider
let azureProvider: Provider
let aihubmixProvider: Provider
let newApiProvider: Provider
let vertexProvider: Provider
beforeEach(() => {
vi.clearAllMocks()
openaiProvider = createTestProvider('openai', 'openai')
anthropicProvider = createTestProvider('anthropic', 'anthropic')
geminiProvider = createTestProvider('gemini', 'gemini')
azureProvider = createTestProvider('azure-openai', 'azure-openai')
aihubmixProvider = createTestProvider('aihubmix', 'openai')
newApiProvider = createTestProvider('new-api', 'openai')
vertexProvider = createTestProvider('vertex', 'vertexai')
})
describe('Direct API Clients', () => {
it('should return correct compatibility type for OpenAIAPIClient', () => {
const client = new OpenAIAPIClient(openaiProvider)
const compatibilityTypes = client.getClientCompatibilityType()
expect(compatibilityTypes).toEqual(['OpenAIAPIClient'])
})
it('should return correct compatibility type for AnthropicAPIClient', () => {
const client = new AnthropicAPIClient(anthropicProvider)
const compatibilityTypes = client.getClientCompatibilityType()
expect(compatibilityTypes).toEqual(['AnthropicAPIClient'])
})
it('should return correct compatibility type for GeminiAPIClient', () => {
const client = new GeminiAPIClient(geminiProvider)
const compatibilityTypes = client.getClientCompatibilityType()
expect(compatibilityTypes).toEqual(['GeminiAPIClient'])
})
})
describe('Decorator Pattern API Clients', () => {
it('should return OpenAIResponseAPIClient for OpenAIResponseAPIClient without model', () => {
const client = new OpenAIResponseAPIClient(azureProvider)
const compatibilityTypes = client.getClientCompatibilityType()
expect(compatibilityTypes).toEqual(['OpenAIResponseAPIClient'])
})
it('should delegate to underlying client for OpenAIResponseAPIClient with model', () => {
const client = new OpenAIResponseAPIClient(azureProvider)
const testModel = createTestModel('gpt-4', 'azure-openai')
// Get the actual client selected for this model
const actualClient = client.getClient(testModel)
const compatibilityTypes = actualClient.getClientCompatibilityType(testModel)
// Should return OpenAIResponseAPIClient for non-chat-completion-only models
expect(compatibilityTypes).toEqual(['OpenAIAPIClient'])
})
it('should return AihubmixAPIClient for AihubmixAPIClient without model', () => {
const client = new AihubmixAPIClient(aihubmixProvider)
const compatibilityTypes = client.getClientCompatibilityType()
expect(compatibilityTypes).toEqual(['AihubmixAPIClient'])
})
it('should delegate to underlying client for AihubmixAPIClient with model', () => {
const client = new AihubmixAPIClient(aihubmixProvider)
const testModel = createTestModel('gpt-4', 'openai')
// Get the actual client selected for this model
const actualClient = client.getClientForModel(testModel)
const compatibilityTypes = actualClient.getClientCompatibilityType(testModel)
// Should return the actual underlying client type based on model (OpenAI models use OpenAIResponseAPIClient in Aihubmix)
expect(compatibilityTypes).toEqual(['OpenAIResponseAPIClient'])
})
it('should return NewAPIClient for NewAPIClient without model', () => {
const client = new NewAPIClient(newApiProvider)
const compatibilityTypes = client.getClientCompatibilityType()
expect(compatibilityTypes).toEqual(['NewAPIClient'])
})
it('should delegate to underlying client for NewAPIClient with model', () => {
const client = new NewAPIClient(newApiProvider)
const testModel = createTestModel('gpt-4', 'openai', 'openai-response')
// Get the actual client selected for this model
const actualClient = client.getClientForModel(testModel)
const compatibilityTypes = actualClient.getClientCompatibilityType(testModel)
// Should return the actual underlying client type based on model
expect(compatibilityTypes).toEqual(['OpenAIResponseAPIClient'])
})
it('should return VertexAPIClient for VertexAPIClient without model', () => {
const client = new VertexAPIClient(vertexProvider)
const compatibilityTypes = client.getClientCompatibilityType()
expect(compatibilityTypes).toEqual(['VertexAPIClient'])
})
it('should delegate to underlying client for VertexAPIClient with model', () => {
const client = new VertexAPIClient(vertexProvider)
const testModel = createTestModel('claude-3-5-sonnet', 'vertexai')
// Get the actual client selected for this model
const actualClient = client.getClient(testModel)
const compatibilityTypes = actualClient.getClientCompatibilityType(testModel)
// Should return the actual underlying client type based on model (Claude models use AnthropicVertexClient)
expect(compatibilityTypes).toEqual(['AnthropicVertexAPIClient'])
})
})
describe('Middleware Compatibility Logic', () => {
it('should correctly identify OpenAI compatible clients', () => {
const openaiClient = new OpenAIAPIClient(openaiProvider)
const openaiResponseClient = new OpenAIResponseAPIClient(azureProvider)
const openaiTypes = openaiClient.getClientCompatibilityType()
const responseTypes = openaiResponseClient.getClientCompatibilityType()
// Test the logic from completions method line 94
const isOpenAICompatible = (types: string[]) =>
types.includes('OpenAIAPIClient') || types.includes('OpenAIResponseAPIClient')
expect(isOpenAICompatible(openaiTypes)).toBe(true)
expect(isOpenAICompatible(responseTypes)).toBe(true)
})
it('should correctly identify Anthropic or OpenAIResponse compatible clients', () => {
const anthropicClient = new AnthropicAPIClient(anthropicProvider)
const openaiResponseClient = new OpenAIResponseAPIClient(azureProvider)
const openaiClient = new OpenAIAPIClient(openaiProvider)
const anthropicTypes = anthropicClient.getClientCompatibilityType()
const responseTypes = openaiResponseClient.getClientCompatibilityType()
const openaiTypes = openaiClient.getClientCompatibilityType()
// Test the logic from completions method line 101
const isAnthropicOrOpenAIResponseCompatible = (types: string[]) =>
types.includes('AnthropicAPIClient') || types.includes('OpenAIResponseAPIClient')
expect(isAnthropicOrOpenAIResponseCompatible(anthropicTypes)).toBe(true)
expect(isAnthropicOrOpenAIResponseCompatible(responseTypes)).toBe(true)
expect(isAnthropicOrOpenAIResponseCompatible(openaiTypes)).toBe(false)
})
it('should handle non-compatible clients correctly', () => {
const geminiClient = new GeminiAPIClient(geminiProvider)
const geminiTypes = geminiClient.getClientCompatibilityType()
// Test that Gemini is not OpenAI compatible
const isOpenAICompatible = (types: string[]) =>
types.includes('OpenAIAPIClient') || types.includes('OpenAIResponseAPIClient')
// Test that Gemini is not Anthropic/OpenAIResponse compatible
const isAnthropicOrOpenAIResponseCompatible = (types: string[]) =>
types.includes('AnthropicAPIClient') || types.includes('OpenAIResponseAPIClient')
expect(isOpenAICompatible(geminiTypes)).toBe(false)
expect(isAnthropicOrOpenAIResponseCompatible(geminiTypes)).toBe(false)
})
})
describe('Factory Integration', () => {
it('should return correct compatibility types for factory-created clients', () => {
const testCases = [
{ provider: openaiProvider, expectedType: 'OpenAIAPIClient' },
{ provider: anthropicProvider, expectedType: 'AnthropicAPIClient' },
{ provider: azureProvider, expectedType: 'OpenAIResponseAPIClient' },
{ provider: aihubmixProvider, expectedType: 'AihubmixAPIClient' },
{ provider: newApiProvider, expectedType: 'NewAPIClient' },
{ provider: vertexProvider, expectedType: 'VertexAPIClient' }
]
testCases.forEach(({ provider, expectedType }) => {
const client = ApiClientFactory.create(provider)
const compatibilityTypes = client.getClientCompatibilityType()
expect(compatibilityTypes).toContain(expectedType)
})
})
})
})
@@ -31,6 +31,9 @@ vi.mock('../AihubmixAPIClient', () => ({
vi.mock('../anthropic/AnthropicAPIClient', () => ({
AnthropicAPIClient: vi.fn().mockImplementation(() => ({}))
}))
vi.mock('../anthropic/AnthropicVertexClient', () => ({
AnthropicVertexClient: vi.fn().mockImplementation(() => ({}))
}))
vi.mock('../gemini/GeminiAPIClient', () => ({
GeminiAPIClient: vi.fn().mockImplementation(() => ({}))
}))
@@ -24,6 +24,7 @@ import {
WebSearchToolResultError
} from '@anthropic-ai/sdk/resources/messages'
import { MessageStream } from '@anthropic-ai/sdk/resources/messages/messages'
import AnthropicVertex from '@anthropic-ai/vertex-sdk'
import { loggerService } from '@logger'
import { GenericChunk } from '@renderer/aiCore/middleware/schemas'
import { DEFAULT_MAX_TOKENS } from '@renderer/config/constant'
@@ -76,7 +77,7 @@ import { AnthropicStreamListener, RawStreamListener, RequestTransformer, Respons
const logger = loggerService.withContext('AnthropicAPIClient')
export class AnthropicAPIClient extends BaseApiClient<
Anthropic,
Anthropic | AnthropicVertex,
AnthropicSdkParams,
AnthropicSdkRawOutput,
AnthropicSdkRawChunk,
@@ -84,11 +85,12 @@ export class AnthropicAPIClient extends BaseApiClient<
ToolUseBlock,
ToolUnion
> {
sdkInstance: Anthropic | AnthropicVertex | undefined = undefined
constructor(provider: Provider) {
super(provider)
}
async getSdkInstance(): Promise<Anthropic> {
async getSdkInstance(): Promise<Anthropic | AnthropicVertex> {
if (this.sdkInstance) {
return this.sdkInstance
}
@@ -108,7 +110,7 @@ export class AnthropicAPIClient extends BaseApiClient<
payload: AnthropicSdkParams,
options?: Anthropic.RequestOptions
): Promise<AnthropicSdkRawOutput> {
const sdk = await this.getSdkInstance()
const sdk = (await this.getSdkInstance()) as Anthropic
if (payload.stream) {
return sdk.messages.stream(payload, options)
}
@@ -122,7 +124,7 @@ export class AnthropicAPIClient extends BaseApiClient<
}
override async listModels(): Promise<Anthropic.ModelInfo[]> {
const sdk = await this.getSdkInstance()
const sdk = (await this.getSdkInstance()) as Anthropic
const response = await sdk.models.list()
return response.data
}
@@ -0,0 +1,103 @@
import Anthropic from '@anthropic-ai/sdk'
import AnthropicVertex from '@anthropic-ai/vertex-sdk'
import { getVertexAILocation, getVertexAIProjectId, getVertexAIServiceAccount } from '@renderer/hooks/useVertexAI'
import { loggerService } from '@renderer/services/LoggerService'
import { Provider } from '@renderer/types'
import { isEmpty } from 'lodash'
const logger = loggerService.withContext('AnthropicVertexClient')
import { AnthropicAPIClient } from './AnthropicAPIClient'
export class AnthropicVertexClient extends AnthropicAPIClient {
sdkInstance: AnthropicVertex | undefined = undefined
private authHeaders?: Record<string, string>
private authHeadersExpiry?: number
constructor(provider: Provider) {
super(provider)
}
private formatApiHost(host: string): string {
const forceUseOriginalHost = () => {
return host.endsWith('/')
}
if (!host) {
return host
}
return forceUseOriginalHost() ? host : `${host}/v1/`
}
override getBaseURL() {
return this.formatApiHost(this.provider.apiHost)
}
override async getSdkInstance(): Promise<AnthropicVertex> {
if (this.sdkInstance) {
return this.sdkInstance
}
const serviceAccount = getVertexAIServiceAccount()
const projectId = getVertexAIProjectId()
const location = getVertexAILocation()
if (!serviceAccount.privateKey || !serviceAccount.clientEmail || !projectId || !location) {
throw new Error('Vertex AI settings are not configured')
}
const authHeaders = await this.getServiceAccountAuthHeaders()
this.sdkInstance = new AnthropicVertex({
projectId: projectId,
region: location,
dangerouslyAllowBrowser: true,
defaultHeaders: authHeaders,
baseURL: isEmpty(this.getBaseURL()) ? undefined : this.getBaseURL()
})
return this.sdkInstance
}
override async listModels(): Promise<Anthropic.ModelInfo[]> {
throw new Error('Vertex AI does not support listModels method.')
}
/**
* 获取认证头,如果配置了 service account 则从主进程获取
*/
private async getServiceAccountAuthHeaders(): Promise<Record<string, string> | undefined> {
const serviceAccount = getVertexAIServiceAccount()
const projectId = getVertexAIProjectId()
// 检查是否配置了 service account
if (!serviceAccount.privateKey || !serviceAccount.clientEmail || !projectId) {
return undefined
}
// 检查是否已有有效的认证头(提前 5 分钟过期)
const now = Date.now()
if (this.authHeaders && this.authHeadersExpiry && this.authHeadersExpiry - now > 5 * 60 * 1000) {
return this.authHeaders
}
try {
// 从主进程获取认证头
this.authHeaders = await window.api.vertexAI.getAuthHeaders({
projectId,
serviceAccount: {
privateKey: serviceAccount.privateKey,
clientEmail: serviceAccount.clientEmail
}
})
// 设置过期时间(通常认证头有效期为 1 小时)
this.authHeadersExpiry = now + 60 * 60 * 1000
return this.authHeaders
} catch (error: any) {
logger.error('Failed to get auth headers:', error)
throw new Error(`Service Account authentication failed: ${error.message}`)
}
}
}
@@ -1,17 +1,54 @@
import { GoogleGenAI } from '@google/genai'
import { loggerService } from '@logger'
import { getVertexAILocation, getVertexAIProjectId, getVertexAIServiceAccount } from '@renderer/hooks/useVertexAI'
import { Provider } from '@renderer/types'
import { Model, Provider } from '@renderer/types'
import { isEmpty } from 'lodash'
import { AnthropicVertexClient } from '../anthropic/AnthropicVertexClient'
import { GeminiAPIClient } from './GeminiAPIClient'
const logger = loggerService.withContext('VertexAPIClient')
export class VertexAPIClient extends GeminiAPIClient {
private authHeaders?: Record<string, string>
private authHeadersExpiry?: number
private anthropicVertexClient: AnthropicVertexClient
constructor(provider: Provider) {
super(provider)
this.anthropicVertexClient = new AnthropicVertexClient(provider)
}
override getClientCompatibilityType(model?: Model): string[] {
if (!model) {
return [this.constructor.name]
}
const actualClient = this.getClient(model)
if (actualClient === this) {
return [this.constructor.name]
}
return actualClient.getClientCompatibilityType(model)
}
public getClient(model: Model) {
if (model.id.includes('claude')) {
return this.anthropicVertexClient
}
return this
}
private formatApiHost(baseUrl: string) {
if (baseUrl.endsWith('/v1/')) {
baseUrl = baseUrl.slice(0, -4)
} else if (baseUrl.endsWith('/v1')) {
baseUrl = baseUrl.slice(0, -3)
}
return baseUrl
}
override getBaseURL() {
return this.formatApiHost(this.provider.apiHost)
}
override async getSdkInstance() {
@@ -35,7 +72,8 @@ export class VertexAPIClient extends GeminiAPIClient {
location: location,
httpOptions: {
apiVersion: this.getApiVersion(),
headers: authHeaders
headers: authHeaders,
baseUrl: isEmpty(this.getBaseURL()) ? undefined : this.getBaseURL()
}
})
+3
View File
@@ -10,6 +10,7 @@ import type { RequestOptions, SdkModel } from '@renderer/types/sdk'
import { isEnabledToolUse } from '@renderer/utils/mcp-tools'
import { AihubmixAPIClient } from './clients/AihubmixAPIClient'
import { VertexAPIClient } from './clients/gemini/VertexAPIClient'
import { NewAPIClient } from './clients/NewAPIClient'
import { OpenAIResponseAPIClient } from './clients/openai/OpenAIResponseAPIClient'
import { CompletionsMiddlewareBuilder } from './middleware/builder'
@@ -61,6 +62,8 @@ export default class AiProvider {
} else if (this.apiClient instanceof OpenAIResponseAPIClient) {
// OpenAIResponseAPIClient: 根据模型特征选择API类型
client = this.apiClient.getClient(model) as BaseApiClient
} else if (this.apiClient instanceof VertexAPIClient) {
client = this.apiClient.getClient(model) as BaseApiClient
} else {
// 其他client直接使用
client = this.apiClient
+1
View File
@@ -2324,6 +2324,7 @@
"search_placeholder": "Search model id or name",
"title": "Model Provider",
"vertex_ai": {
"api_host_help": "The API host for Vertex AI, not recommended to fill in, generally applicable to reverse proxy",
"documentation": "View official documentation for more configuration details:",
"learn_more": "Learn More",
"location": "Location",
+1
View File
@@ -2324,6 +2324,7 @@
"search_placeholder": "モデルIDまたは名前を検索",
"title": "モデルプロバイダー",
"vertex_ai": {
"api_host_help": "Vertex AIのAPIアドレス。逆プロキシに適しています。",
"documentation": "詳細な設定については、公式ドキュメントを参照してください:",
"learn_more": "詳細を確認",
"location": "場所",
+1
View File
@@ -2324,6 +2324,7 @@
"search_placeholder": "Поиск по ID или имени модели",
"title": "Провайдеры моделей",
"vertex_ai": {
"api_host_help": "API-адрес Vertex AI, не рекомендуется заполнять, обычно применим к обратным прокси",
"documentation": "Смотрите официальную документацию для получения более подробной информации о конфигурации:",
"learn_more": "Узнать больше",
"location": "Местоположение",
+1
View File
@@ -2324,6 +2324,7 @@
"search_placeholder": "搜索模型 ID 或名称",
"title": "模型服务",
"vertex_ai": {
"api_host_help": "Vertex AI 的 API 地址,不建议填写,通常适用于反向代理",
"documentation": "查看官方文档了解更多配置详情:",
"learn_more": "了解更多",
"location": "地区",
+1
View File
@@ -2324,6 +2324,7 @@
"search_placeholder": "搜尋模型 ID 或名稱",
"title": "模型提供者",
"vertex_ai": {
"api_host_help": "Vertex AI 的 API 地址,不建議填寫,通常適用於反向代理",
"documentation": "檢視官方文件以取得更多設定詳細資訊:",
"learn_more": "瞭解更多",
"location": "地區",
@@ -372,7 +372,7 @@ const ProviderSetting: FC<Props> = ({ providerId }) => {
{provider.id === 'lmstudio' && <LMStudioSettings />}
{provider.id === 'gpustack' && <GPUStackSettings />}
{provider.id === 'copilot' && <GithubCopilotSettings providerId={provider.id} />}
{provider.id === 'vertexai' && <VertexAISettings />}
{provider.id === 'vertexai' && <VertexAISettings providerId={provider.id} />}
<ModelList providerId={provider.id} />
</SettingContainer>
)
@@ -1,13 +1,18 @@
import { HStack } from '@renderer/components/Layout'
import { PROVIDER_CONFIG } from '@renderer/config/providers'
import { useProvider } from '@renderer/hooks/useProvider'
import { useVertexAISettings } from '@renderer/hooks/useVertexAI'
import { Alert, Input } from 'antd'
import { Alert, Input, Space } from 'antd'
import { FC, useState } from 'react'
import { useTranslation } from 'react-i18next'
import { SettingHelpLink, SettingHelpText, SettingHelpTextRow, SettingSubtitle } from '..'
const VertexAISettings: FC = () => {
interface Props {
providerId: string
}
const VertexAISettings: FC<Props> = ({ providerId }) => {
const { t } = useTranslation()
const {
projectId,
@@ -19,11 +24,18 @@ const VertexAISettings: FC = () => {
setServiceAccountClientEmail
} = useVertexAISettings()
const [localProjectId, setLocalProjectId] = useState(projectId)
const [localLocation, setLocalLocation] = useState(location)
const { provider, updateProvider } = useProvider(providerId)
const [apiHost, setApiHost] = useState(provider.apiHost)
const providerConfig = PROVIDER_CONFIG['vertexai']
const apiKeyWebsite = providerConfig?.websites?.apiKey
const [localProjectId, setLocalProjectId] = useState(projectId)
const [localLocation, setLocalLocation] = useState(location)
const onUpdateApiHost = () => {
updateProvider({ apiHost })
}
const handleProjectIdChange = (e: React.ChangeEvent<HTMLInputElement>) => {
setLocalProjectId(e.target.value)
@@ -60,6 +72,18 @@ const VertexAISettings: FC = () => {
return (
<>
<SettingSubtitle>{t('settings.provider.api_host')}</SettingSubtitle>
<Space.Compact style={{ width: '100%', marginTop: 5 }}>
<Input
value={apiHost}
placeholder={t('settings.provider.api_host')}
onChange={(e) => setApiHost(e.target.value)}
onBlur={onUpdateApiHost}
/>
</Space.Compact>
<SettingHelpTextRow>
<SettingHelpText>{t('settings.provider.vertex_ai.api_host_help')}</SettingHelpText>
</SettingHelpTextRow>
<SettingSubtitle style={{ marginTop: 5 }}>
{t('settings.provider.vertex_ai.service_account.title')}
</SettingSubtitle>
+2 -1
View File
@@ -8,6 +8,7 @@ import {
ToolUseBlock
} from '@anthropic-ai/sdk/resources'
import { MessageStream } from '@anthropic-ai/sdk/resources/messages/messages'
import AnthropicVertex from '@anthropic-ai/vertex-sdk'
import {
Content,
CreateChatParameters,
@@ -23,7 +24,7 @@ import { Stream } from 'openai/streaming'
import { EndpointType } from './index'
export type SdkInstance = OpenAI | AzureOpenAI | Anthropic | GoogleGenAI
export type SdkInstance = OpenAI | AzureOpenAI | Anthropic | AnthropicVertex | GoogleGenAI
export type SdkParams = OpenAISdkParams | OpenAIResponseSdkParams | AnthropicSdkParams | GeminiSdkParams
export type SdkRawChunk = OpenAISdkRawChunk | OpenAIResponseSdkRawChunk | AnthropicSdkRawChunk | GeminiSdkRawChunk
export type SdkRawOutput = OpenAISdkRawOutput | OpenAIResponseSdkRawOutput | AnthropicSdkRawOutput | GeminiSdkRawOutput