feat(aiCore): enhance provider management and registration system
- Added support for a new provider configuration structure in package.json, enabling better integration of provider types. - Updated tsdown.config.ts to include new entry points for provider modules, improving build organization. - Refactored index.ts to streamline exports and enhance type handling for provider-related functionalities. - Simplified provider initialization and registration processes, allowing for more flexible provider management. - Improved type definitions and removed deprecated methods to enhance code clarity and maintainability.
This commit is contained in:
@@ -8,7 +8,8 @@
|
||||
* 3. 暂时保持接口兼容性
|
||||
*/
|
||||
|
||||
import { createExecutor, generateImage, initializeProvider, StreamTextParams } from '@cherrystudio/ai-core'
|
||||
import { createExecutor, generateImage, StreamTextParams } from '@cherrystudio/ai-core'
|
||||
import { createAndRegisterProvider } from '@cherrystudio/ai-core/provider'
|
||||
import { loggerService } from '@logger'
|
||||
import { isNotSupportedImageSizeModel } from '@renderer/config/models'
|
||||
import { addSpan, endSpan } from '@renderer/services/SpanManagerService'
|
||||
@@ -36,21 +37,6 @@ export default class ModernAiProvider {
|
||||
|
||||
// 只保存配置,不预先创建executor
|
||||
this.config = providerToAiSdkConfig(this.actualProvider)
|
||||
|
||||
// 初始化 provider 到全局管理器
|
||||
try {
|
||||
initializeProvider(this.config.providerId, this.config.options)
|
||||
logger.debug('Provider initialized successfully', {
|
||||
providerId: this.config.providerId,
|
||||
hasOptions: !!this.config.options
|
||||
})
|
||||
} catch (error) {
|
||||
// 如果 provider 已经初始化过,可能会抛出错误,这里可以忽略
|
||||
logger.debug('Provider initialization skipped (may already be initialized)', {
|
||||
providerId: this.config.providerId,
|
||||
error: error instanceof Error ? error.message : String(error)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
public getActualProvider() {
|
||||
@@ -67,6 +53,21 @@ export default class ModernAiProvider {
|
||||
callType: string
|
||||
}
|
||||
): Promise<CompletionsResult> {
|
||||
// 初始化 provider 到全局管理器
|
||||
try {
|
||||
await createAndRegisterProvider(this.config.providerId, this.config.options)
|
||||
logger.debug('Provider initialized successfully', {
|
||||
providerId: this.config.providerId,
|
||||
hasOptions: !!this.config.options
|
||||
})
|
||||
} catch (error) {
|
||||
// 如果 provider 已经初始化过,可能会抛出错误,这里可以忽略
|
||||
logger.debug('Provider initialization skipped (may already be initialized)', {
|
||||
providerId: this.config.providerId,
|
||||
error: error instanceof Error ? error.message : String(error)
|
||||
})
|
||||
}
|
||||
|
||||
if (config.isImageGenerationEndpoint) {
|
||||
return await this.modernImageGeneration(modelId, params, config)
|
||||
}
|
||||
|
||||
@@ -1,4 +1,9 @@
|
||||
import { AiCore, ProviderConfigFactory, type ProviderId, type ProviderSettingsMap } from '@cherrystudio/ai-core'
|
||||
import {
|
||||
hasProviderConfig,
|
||||
ProviderConfigFactory,
|
||||
type ProviderId,
|
||||
type ProviderSettingsMap
|
||||
} from '@cherrystudio/ai-core/provider'
|
||||
import { createVertexProvider, isVertexAIConfigured, isVertexProvider } from '@renderer/hooks/useVertexAI'
|
||||
import { getProviderByModel } from '@renderer/services/AssistantService'
|
||||
import type { Model, Provider } from '@renderer/types'
|
||||
@@ -86,7 +91,7 @@ export function providerToAiSdkConfig(actualProvider: Provider): {
|
||||
}
|
||||
|
||||
// 如果AI SDK支持该provider,使用原生配置
|
||||
if (AiCore.isSupported(aiSdkProviderId) && aiSdkProviderId !== 'openai-compatible') {
|
||||
if (hasProviderConfig(aiSdkProviderId) && aiSdkProviderId !== 'openai-compatible') {
|
||||
const options = ProviderConfigFactory.fromProvider(aiSdkProviderId, baseConfig, extraOptions)
|
||||
return {
|
||||
providerId: aiSdkProviderId as ProviderId,
|
||||
@@ -120,5 +125,5 @@ export function isModernSdkSupported(provider: Provider): boolean {
|
||||
const aiSdkProviderId = getAiSdkProviderId(provider)
|
||||
|
||||
// 如果映射到了支持的provider,则支持现代SDK
|
||||
return AiCore.isSupported(aiSdkProviderId)
|
||||
return hasProviderConfig(aiSdkProviderId)
|
||||
}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import { AiCore, getProviderMapping, type ProviderId } from '@cherrystudio/ai-core'
|
||||
import { hasProviderConfigByAlias, type ProviderId, resolveProviderConfigId } from '@cherrystudio/ai-core/provider'
|
||||
import { loggerService } from '@logger'
|
||||
import { Provider } from '@renderer/types'
|
||||
|
||||
@@ -30,7 +30,7 @@ const STATIC_PROVIDER_MAPPING: Record<string, ProviderId> = {
|
||||
}
|
||||
|
||||
/**
|
||||
* 尝试解析provider标识符(支持静态映射和动态映射)
|
||||
* 尝试解析provider标识符(支持静态映射和别名)
|
||||
*/
|
||||
function tryResolveProviderId(identifier: string): ProviderId | null {
|
||||
// 1. 检查静态映射
|
||||
@@ -39,15 +39,10 @@ function tryResolveProviderId(identifier: string): ProviderId | null {
|
||||
return staticMapping
|
||||
}
|
||||
|
||||
// 2. 检查动态映射
|
||||
const dynamicMapping = getProviderMapping(identifier)
|
||||
if (dynamicMapping && dynamicMapping !== identifier) {
|
||||
return dynamicMapping as ProviderId
|
||||
}
|
||||
|
||||
// 3. 检查AiCore是否直接支持
|
||||
if (AiCore.isSupported(identifier)) {
|
||||
return identifier as ProviderId
|
||||
// 2. 检查AiCore是否支持(包括别名支持)
|
||||
if (hasProviderConfigByAlias(identifier)) {
|
||||
// 解析为真实的Provider ID
|
||||
return resolveProviderConfigId(identifier) as ProviderId
|
||||
}
|
||||
|
||||
return null
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import { type ProviderConfig, registerMultipleProviders } from '@cherrystudio/ai-core'
|
||||
import { type ProviderConfig, registerMultipleProviderConfigs } from '@cherrystudio/ai-core/provider'
|
||||
import { loggerService } from '@logger'
|
||||
|
||||
const logger = loggerService.withContext('ProviderConfigs')
|
||||
@@ -16,9 +16,7 @@ export const NEW_PROVIDER_CONFIGS: (ProviderConfig & {
|
||||
import: () => import('@openrouter/ai-sdk-provider'),
|
||||
creatorFunctionName: 'createOpenRouter',
|
||||
supportsImageGeneration: true,
|
||||
mappings: {
|
||||
openrouter: 'openrouter'
|
||||
}
|
||||
aliases: ['openrouter']
|
||||
},
|
||||
{
|
||||
id: 'google-vertex',
|
||||
@@ -26,10 +24,7 @@ export const NEW_PROVIDER_CONFIGS: (ProviderConfig & {
|
||||
import: () => import('@ai-sdk/google-vertex'),
|
||||
creatorFunctionName: 'createGoogleVertex',
|
||||
supportsImageGeneration: true,
|
||||
mappings: {
|
||||
'google-vertex': 'google-vertex',
|
||||
vertexai: 'google-vertex'
|
||||
}
|
||||
aliases: ['google-vertex', 'vertexai']
|
||||
},
|
||||
{
|
||||
id: 'bedrock',
|
||||
@@ -37,9 +32,7 @@ export const NEW_PROVIDER_CONFIGS: (ProviderConfig & {
|
||||
import: () => import('@ai-sdk/amazon-bedrock'),
|
||||
creatorFunctionName: 'createAmazonBedrock',
|
||||
supportsImageGeneration: true,
|
||||
mappings: {
|
||||
'aws-bedrock': 'bedrock'
|
||||
}
|
||||
aliases: ['aws-bedrock']
|
||||
}
|
||||
] as const
|
||||
|
||||
@@ -49,19 +42,7 @@ export const NEW_PROVIDER_CONFIGS: (ProviderConfig & {
|
||||
*/
|
||||
export async function initializeNewProviders(): Promise<void> {
|
||||
try {
|
||||
logger.info('Starting to register new providers', {
|
||||
providerCount: NEW_PROVIDER_CONFIGS.length,
|
||||
providerIds: NEW_PROVIDER_CONFIGS.map((p) => p.id)
|
||||
})
|
||||
|
||||
const successCount = registerMultipleProviders(NEW_PROVIDER_CONFIGS)
|
||||
|
||||
logger.info('Provider registration completed', {
|
||||
successCount,
|
||||
totalCount: NEW_PROVIDER_CONFIGS.length,
|
||||
failedCount: NEW_PROVIDER_CONFIGS.length - successCount
|
||||
})
|
||||
|
||||
const successCount = registerMultipleProviderConfigs(NEW_PROVIDER_CONFIGS)
|
||||
if (successCount < NEW_PROVIDER_CONFIGS.length) {
|
||||
logger.warn('Some providers failed to register. Check previous error logs.')
|
||||
}
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
import { aiSdk, Tool } from '@cherrystudio/ai-core'
|
||||
import { Tool } from '@cherrystudio/ai-core'
|
||||
import { loggerService } from '@logger'
|
||||
// import { AiSdkTool, ToolCallResult } from '@renderer/aiCore/tools/types'
|
||||
import { MCPTool, MCPToolResponse } from '@renderer/types'
|
||||
import { callMCPTool } from '@renderer/utils/mcp-tools'
|
||||
import { jsonSchema, tool } from 'ai'
|
||||
import { JSONSchema7 } from 'json-schema'
|
||||
|
||||
const { tool } = aiSdk
|
||||
const logger = loggerService.withContext('MCP-utils')
|
||||
|
||||
// Setup tools configuration based on provided parameters
|
||||
@@ -30,7 +30,7 @@ export function convertMcpToolsToAiSdkTools(mcpTools: MCPTool[]): Record<string,
|
||||
for (const mcpTool of mcpTools) {
|
||||
tools[mcpTool.name] = tool({
|
||||
description: mcpTool.description || `Tool from ${mcpTool.serverName}`,
|
||||
inputSchema: aiSdk.jsonSchema(mcpTool.inputSchema as JSONSchema7),
|
||||
inputSchema: jsonSchema(mcpTool.inputSchema as JSONSchema7),
|
||||
execute: async (params) => {
|
||||
// 创建适配的 MCPToolResponse 对象
|
||||
const toolResponse: MCPToolResponse = {
|
||||
|
||||
Reference in New Issue
Block a user