Compare commits

..

2 Commits

Author SHA1 Message Date
copilot-swe-agent[bot] b14e48dd78 Fix custom parameters placement for Vercel AI Gateway
For AI Gateway provider, custom parameters are now placed at the body level
instead of being nested inside providerOptions.gateway. This fixes the issue
where parameters like 'tools' were being incorrectly added to
providerOptions.gateway when they should be at the same level as providerOptions.

Fixes #4197

Co-authored-by: DeJeune <67425183+DeJeune@users.noreply.github.com>
2025-12-01 05:13:55 +00:00
copilot-swe-agent[bot] 64fde27f9e Initial plan 2025-12-01 05:05:10 +00:00
68 changed files with 1249 additions and 3975 deletions
-2
View File
@@ -73,5 +73,3 @@ test-results
YOUR_MEMORY_FILE_PATH YOUR_MEMORY_FILE_PATH
.sessions/ .sessions/
.next/
*.tsbuildinfo
+1 -4
View File
@@ -25,10 +25,7 @@ export default defineConfig({
'@shared': resolve('packages/shared'), '@shared': resolve('packages/shared'),
'@logger': resolve('src/main/services/LoggerService'), '@logger': resolve('src/main/services/LoggerService'),
'@mcp-trace/trace-core': resolve('packages/mcp-trace/trace-core'), '@mcp-trace/trace-core': resolve('packages/mcp-trace/trace-core'),
'@mcp-trace/trace-node': resolve('packages/mcp-trace/trace-node'), '@mcp-trace/trace-node': resolve('packages/mcp-trace/trace-node')
'@cherrystudio/ai-core/provider': resolve('packages/aiCore/src/core/providers'),
'@cherrystudio/ai-core': resolve('packages/aiCore/src'),
'@cherrystudio/ai-sdk-provider': resolve('packages/ai-sdk-provider/src')
} }
}, },
build: { build: {
+8 -57
View File
@@ -9,27 +9,13 @@
*/ */
import Anthropic from '@anthropic-ai/sdk' import Anthropic from '@anthropic-ai/sdk'
import type { MessageCreateParams, TextBlockParam, Tool as AnthropicTool } from '@anthropic-ai/sdk/resources' import type { TextBlockParam } from '@anthropic-ai/sdk/resources'
import { loggerService } from '@logger' import { loggerService } from '@logger'
import { type Provider, SystemProviderIds } from '@types' import type { Provider } from '@types'
import type { ModelMessage } from 'ai' import type { ModelMessage } from 'ai'
const logger = loggerService.withContext('anthropic-sdk') const logger = loggerService.withContext('anthropic-sdk')
/**
* Context for Anthropic SDK client creation.
* This allows the shared module to be used in different environments
* by providing environment-specific implementations.
*/
export interface AnthropicSdkContext {
/**
* Custom fetch function to use for HTTP requests.
* In Electron main process, this should be `net.fetch`.
* In other environments, can use the default fetch or a custom implementation.
*/
fetch?: typeof globalThis.fetch
}
const defaultClaudeCodeSystemPrompt = `You are Claude Code, Anthropic's official CLI for Claude.` const defaultClaudeCodeSystemPrompt = `You are Claude Code, Anthropic's official CLI for Claude.`
const defaultClaudeCodeSystem: Array<TextBlockParam> = [ const defaultClaudeCodeSystem: Array<TextBlockParam> = [
@@ -72,11 +58,8 @@ const defaultClaudeCodeSystem: Array<TextBlockParam> = [
export function getSdkClient( export function getSdkClient(
provider: Provider, provider: Provider,
oauthToken?: string | null, oauthToken?: string | null,
extraHeaders?: Record<string, string | string[]>, extraHeaders?: Record<string, string | string[]>
context?: AnthropicSdkContext
): Anthropic { ): Anthropic {
const customFetch = context?.fetch
if (provider.authType === 'oauth') { if (provider.authType === 'oauth') {
if (!oauthToken) { if (!oauthToken) {
throw new Error('OAuth token is not available') throw new Error('OAuth token is not available')
@@ -102,8 +85,7 @@ export function getSdkClient(
'x-stainless-runtime': 'node', 'x-stainless-runtime': 'node',
'x-stainless-runtime-version': 'v22.18.0', 'x-stainless-runtime-version': 'v22.18.0',
...extraHeaders ...extraHeaders
}, }
fetch: customFetch
}) })
} }
let baseURL = let baseURL =
@@ -124,12 +106,11 @@ export function getSdkClient(
baseURL, baseURL,
dangerouslyAllowBrowser: true, dangerouslyAllowBrowser: true,
defaultHeaders: { defaultHeaders: {
'anthropic-beta': 'interleaved-thinking-2025-05-14', 'anthropic-beta': 'output-128k-2025-02-19',
'APP-Code': 'MLTG2087', 'APP-Code': 'MLTG2087',
...provider.extra_headers, ...provider.extra_headers,
...extraHeaders ...extraHeaders
}, }
fetch: customFetch
}) })
} }
@@ -139,11 +120,9 @@ export function getSdkClient(
baseURL, baseURL,
dangerouslyAllowBrowser: true, dangerouslyAllowBrowser: true,
defaultHeaders: { defaultHeaders: {
'anthropic-beta': 'interleaved-thinking-2025-05-14', 'anthropic-beta': 'output-128k-2025-02-19',
Authorization: provider.id === SystemProviderIds.longcat ? `Bearer ${provider.apiKey}` : undefined,
...provider.extra_headers ...provider.extra_headers
}, }
fetch: customFetch
}) })
} }
@@ -194,31 +173,3 @@ export function buildClaudeCodeSystemModelMessage(system?: string | Array<TextBl
content: block.text content: block.text
})) }))
} }
/**
* Sanitize tool definitions for Anthropic API.
*
* Removes non-standard fields like `input_examples` from tool definitions
* that Anthropic's API doesn't support. This prevents validation errors when
* tools with extended fields are passed to the Anthropic SDK.
*
* @param tools - Array of tool definitions from MessageCreateParams
* @returns Sanitized tools array with non-standard fields removed
*
* @example
* ```typescript
* const sanitizedTools = sanitizeToolsForAnthropic(request.tools)
* ```
*/
export function sanitizeToolsForAnthropic(tools?: MessageCreateParams['tools']): MessageCreateParams['tools'] {
if (!tools || tools.length === 0) return tools
return tools.map((tool) => {
if ('type' in tool && tool.type !== 'custom') return tool
// oxlint-disable-next-line no-unused-vars
const { input_examples, ...sanitizedTool } = tool as AnthropicTool & { input_examples?: unknown }
return sanitizedTool as typeof tool
})
}
-245
View File
@@ -1,245 +0,0 @@
/**
* Shared API Utilities
*
* Common utilities for API URL formatting and validation.
* Used by both main process (API Server) and renderer.
*/
import type { MinimalProvider } from '@shared/provider'
import { trim } from 'lodash'
// Supported endpoints for routing
export const SUPPORTED_IMAGE_ENDPOINT_LIST = ['images/generations', 'images/edits', 'predict'] as const
export const SUPPORTED_ENDPOINT_LIST = [
'chat/completions',
'responses',
'messages',
'generateContent',
'streamGenerateContent',
...SUPPORTED_IMAGE_ENDPOINT_LIST
] as const
/**
* Removes the trailing slash from a URL string if it exists.
*/
export function withoutTrailingSlash<T extends string>(url: T): T {
return url.replace(/\/$/, '') as T
}
/**
* Matches a version segment in a path that starts with `/v<number>` and optionally
* continues with `alpha` or `beta`. The segment may be followed by `/` or the end
* of the string (useful for cases like `/v3alpha/resources`).
*/
const VERSION_REGEX_PATTERN = '\\/v\\d+(?:alpha|beta)?(?=\\/|$)'
/**
* Matches an API version at the end of a URL (with optional trailing slash).
* Used to detect and extract versions only from the trailing position.
*/
const TRAILING_VERSION_REGEX = /\/v\d+(?:alpha|beta)?\/?$/i
/**
* 判断 host 的 path 中是否包含形如版本的字符串(例如 /v1、/v2beta 等),
*
* @param host - 要检查的 host 或 path 字符串
* @returns 如果 path 中包含版本字符串则返回 true,否则 false
*/
export function hasAPIVersion(host?: string): boolean {
if (!host) return false
const regex = new RegExp(VERSION_REGEX_PATTERN, 'i')
try {
const url = new URL(host)
return regex.test(url.pathname)
} catch {
// 若无法作为完整 URL 解析,则当作路径直接检测
return regex.test(host)
}
}
/**
* 格式化 Azure OpenAI 的 API 主机地址。
*/
export function formatAzureOpenAIApiHost(host: string): string {
const normalizedHost = withoutTrailingSlash(host)
?.replace(/\/v1$/, '')
.replace(/\/openai$/, '')
// NOTE: AISDK会添加上`v1`
return formatApiHost(normalizedHost + '/openai', false)
}
export function formatVertexApiHost(
provider: MinimalProvider,
project: string = 'test-project',
location: string = 'us-central1'
): string {
const { apiHost } = provider
const trimmedHost = withoutTrailingSlash(trim(apiHost))
if (!trimmedHost || trimmedHost.endsWith('aiplatform.googleapis.com')) {
const host =
location === 'global' ? 'https://aiplatform.googleapis.com' : `https://${location}-aiplatform.googleapis.com`
return `${formatApiHost(host)}/projects/${project}/locations/${location}`
}
return formatApiHost(trimmedHost)
}
/**
* Formats an API host URL by normalizing it and optionally appending an API version.
*
* @param host - The API host URL to format. Leading/trailing whitespace will be trimmed and trailing slashes removed.
* @param supportApiVersion - Whether the API version is supported. Defaults to `true`.
* @param apiVersion - The API version to append if needed. Defaults to `'v1'`.
*
* @returns The formatted API host URL. If the host is empty after normalization, returns an empty string.
* If the host ends with '#', API version is not supported, or the host already contains a version, returns the normalized host as-is.
* Otherwise, returns the host with the API version appended.
*
* @example
* formatApiHost('https://api.example.com/') // Returns 'https://api.example.com/v1'
* formatApiHost('https://api.example.com#') // Returns 'https://api.example.com#'
* formatApiHost('https://api.example.com/v2', true, 'v1') // Returns 'https://api.example.com/v2'
*/
export function formatApiHost(host?: string, supportApiVersion: boolean = true, apiVersion: string = 'v1'): string {
const normalizedHost = withoutTrailingSlash(trim(host))
if (!normalizedHost) {
return ''
}
if (normalizedHost.endsWith('#') || !supportApiVersion || hasAPIVersion(normalizedHost)) {
return normalizedHost
}
return `${normalizedHost}/${apiVersion}`
}
/**
* Converts an API host URL into separate base URL and endpoint components.
*
* This function extracts endpoint information from a composite API host string.
* If the host ends with '#', it attempts to match the preceding part against the supported endpoint list.
*
* @param apiHost - The API host string to parse
* @returns An object containing:
* - `baseURL`: The base URL without the endpoint suffix
* - `endpoint`: The matched endpoint identifier, or empty string if no match found
*
* @example
* routeToEndpoint('https://api.example.com/openai/chat/completions#')
* // Returns: { baseURL: 'https://api.example.com/v1', endpoint: 'chat/completions' }
*
* @example
* routeToEndpoint('https://api.example.com/v1')
* // Returns: { baseURL: 'https://api.example.com/v1', endpoint: '' }
*/
export function routeToEndpoint(apiHost: string): { baseURL: string; endpoint: string } {
const trimmedHost = (apiHost || '').trim()
if (!trimmedHost.endsWith('#')) {
return { baseURL: trimmedHost, endpoint: '' }
}
// Remove trailing #
const host = trimmedHost.slice(0, -1)
const endpointMatch = SUPPORTED_ENDPOINT_LIST.find((endpoint) => host.endsWith(endpoint))
if (!endpointMatch) {
const baseURL = withoutTrailingSlash(host)
return { baseURL, endpoint: '' }
}
const baseSegment = host.slice(0, host.length - endpointMatch.length)
const baseURL = withoutTrailingSlash(baseSegment).replace(/:$/, '') // Remove trailing colon (gemini special case)
return { baseURL, endpoint: endpointMatch }
}
/**
* Gets the AI SDK compatible base URL from a provider's apiHost.
*
* AI SDK expects baseURL WITH version suffix (e.g., /v1).
* This function:
* 1. Handles '#' endpoint routing format
* 2. Ensures the URL has a version suffix (adds /v1 if missing)
*
* @param apiHost - The provider's apiHost value (may or may not have /v1)
* @param apiVersion - The API version to use if missing. Defaults to 'v1'.
* @returns The baseURL suitable for AI SDK (with version suffix)
*
* @example
* getAiSdkBaseUrl('https://api.openai.com') // 'https://api.openai.com/v1'
* getAiSdkBaseUrl('https://api.openai.com/v1') // 'https://api.openai.com/v1'
* getAiSdkBaseUrl('https://api.example.com/chat/completions#') // 'https://api.example.com'
*/
export function getAiSdkBaseUrl(apiHost: string, apiVersion: string = 'v1'): string {
// First handle '#' endpoint routing format
const { baseURL } = routeToEndpoint(apiHost)
// If already has version, return as-is
if (hasAPIVersion(baseURL)) {
return withoutTrailingSlash(baseURL)
}
// Add version suffix
return `${withoutTrailingSlash(baseURL)}/${apiVersion}`
}
/**
* Validates an API host address.
*
* @param apiHost - The API host address to validate
* @returns true if valid URL with http/https protocol, false otherwise
*/
export function validateApiHost(apiHost: string): boolean {
if (!apiHost || !apiHost.trim()) {
return true // Allow empty
}
try {
const url = new URL(apiHost.trim())
return url.protocol === 'http:' || url.protocol === 'https:'
} catch {
return false
}
}
/**
* Extracts the trailing API version segment from a URL path.
*
* This function extracts API version patterns (e.g., `v1`, `v2beta`) from the end of a URL.
* Only versions at the end of the path are extracted, not versions in the middle.
* The returned version string does not include leading or trailing slashes.
*
* @param {string} url - The URL string to parse.
* @returns {string | undefined} The trailing API version found (e.g., 'v1', 'v2beta'), or undefined if none found.
*
* @example
* getTrailingApiVersion('https://api.example.com/v1') // 'v1'
* getTrailingApiVersion('https://api.example.com/v2beta/') // 'v2beta'
* getTrailingApiVersion('https://api.example.com/v1/chat') // undefined (version not at end)
* getTrailingApiVersion('https://gateway.ai.cloudflare.com/v1/xxx/v1beta') // 'v1beta'
* getTrailingApiVersion('https://api.example.com') // undefined
*/
export function getTrailingApiVersion(url: string): string | undefined {
const match = url.match(TRAILING_VERSION_REGEX)
if (match) {
// Extract version without leading slash and trailing slash
return match[0].replace(/^\//, '').replace(/\/$/, '')
}
return undefined
}
/**
* Removes the trailing API version segment from a URL path.
*
* This function removes API version patterns (e.g., `/v1`, `/v2beta`) from the end of a URL.
* Only versions at the end of the path are removed, not versions in the middle.
*
* @param {string} url - The URL string to process.
* @returns {string} The URL with the trailing API version removed, or the original URL if no trailing version found.
*
* @example
* withoutTrailingApiVersion('https://api.example.com/v1') // 'https://api.example.com'
* withoutTrailingApiVersion('https://api.example.com/v2beta/') // 'https://api.example.com'
* withoutTrailingApiVersion('https://api.example.com/v1/chat') // 'https://api.example.com/v1/chat' (no change)
* withoutTrailingApiVersion('https://api.example.com') // 'https://api.example.com'
*/
export function withoutTrailingApiVersion(url: string): string {
return url.replace(TRAILING_VERSION_REGEX, '')
}
+2 -31
View File
@@ -43,35 +43,6 @@ export function isSiliconAnthropicCompatibleModel(modelId: string): boolean {
} }
/** /**
* PPIO provider models that support Anthropic API endpoint. * Silicon provider's Anthropic API host URL.
* These models can be used with Claude Code via the Anthropic-compatible API.
*
* @see https://ppio.com/docs/model/llm-anthropic-compatibility
*/ */
export const PPIO_ANTHROPIC_COMPATIBLE_MODELS: readonly string[] = [ export const SILICON_ANTHROPIC_API_HOST = 'https://api.siliconflow.cn'
'moonshotai/kimi-k2-thinking',
'minimax/minimax-m2',
'deepseek/deepseek-v3.2-exp',
'deepseek/deepseek-v3.1-terminus',
'zai-org/glm-4.6',
'moonshotai/kimi-k2-0905',
'deepseek/deepseek-v3.1',
'moonshotai/kimi-k2-instruct',
'qwen/qwen3-next-80b-a3b-instruct',
'qwen/qwen3-next-80b-a3b-thinking'
]
/**
* Creates a Set for efficient lookup of PPIO Anthropic-compatible model IDs.
*/
const PPIO_ANTHROPIC_COMPATIBLE_MODEL_SET = new Set(PPIO_ANTHROPIC_COMPATIBLE_MODELS)
/**
* Checks if a model ID is compatible with Anthropic API on PPIO provider.
*
* @param modelId - The model ID to check
* @returns true if the model supports Anthropic API endpoint
*/
export function isPpioAnthropicCompatibleModel(modelId: string): boolean {
return PPIO_ANTHROPIC_COMPATIBLE_MODEL_SET.has(modelId)
}
-15
View File
@@ -1,15 +0,0 @@
/**
* Shared AI SDK Middlewares
*
* Environment-agnostic middlewares that can be used in both
* renderer process and main process (API server).
*/
export {
buildSharedMiddlewares,
getReasoningTagName,
isGemini3ModelId,
openrouterReasoningMiddleware,
type SharedMiddlewareConfig,
skipGeminiThoughtSignatureMiddleware
} from './middlewares'
-205
View File
@@ -1,205 +0,0 @@
/**
* Shared AI SDK Middlewares
*
* These middlewares are environment-agnostic and can be used in both
* renderer process and main process (API server).
*/
import type { LanguageModelV2Middleware, LanguageModelV2StreamPart } from '@ai-sdk/provider'
import { extractReasoningMiddleware } from 'ai'
/**
* Configuration for building shared middlewares
*/
export interface SharedMiddlewareConfig {
/**
* Whether to enable reasoning extraction
*/
enableReasoning?: boolean
/**
* Tag name for reasoning extraction
* Defaults based on model ID
*/
reasoningTagName?: string
/**
* Model ID - used to determine default reasoning tag and model detection
*/
modelId?: string
/**
* Provider ID (Cherry Studio provider ID)
* Used for provider-specific middlewares like OpenRouter
*/
providerId?: string
/**
* AI SDK Provider ID
* Used for Gemini thought signature middleware
* e.g., 'google', 'google-vertex'
*/
aiSdkProviderId?: string
}
/**
* Check if model ID represents a Gemini 3 (2.5) model
* that requires thought signature handling
*
* @param modelId - The model ID string (not Model object)
*/
export function isGemini3ModelId(modelId?: string): boolean {
if (!modelId) return false
const lowerModelId = modelId.toLowerCase()
return lowerModelId.includes('gemini-3')
}
/**
* Get the default reasoning tag name based on model ID
*
* Different models use different tags for reasoning content:
* - Most models: 'think'
* - GPT-OSS models: 'reasoning'
* - Gemini models: 'thought'
* - Seed models: 'seed:think'
*/
export function getReasoningTagName(modelId?: string): string {
if (!modelId) return 'think'
const lowerModelId = modelId.toLowerCase()
if (lowerModelId.includes('gpt-oss')) return 'reasoning'
if (lowerModelId.includes('gemini')) return 'thought'
if (lowerModelId.includes('seed-oss-36b')) return 'seed:think'
return 'think'
}
/**
* Skip Gemini Thought Signature Middleware
*
* Due to the complexity of multi-model client requests (which can switch
* to other models mid-process), this middleware skips all Gemini 3
* thinking signatures validation.
*
* @param aiSdkId - AI SDK Provider ID (e.g., 'google', 'google-vertex')
* @returns LanguageModelV2Middleware
*/
export function skipGeminiThoughtSignatureMiddleware(aiSdkId: string): LanguageModelV2Middleware {
const MAGIC_STRING = 'skip_thought_signature_validator'
return {
middlewareVersion: 'v2',
transformParams: async ({ params }) => {
const transformedParams = { ...params }
// Process messages in prompt
if (transformedParams.prompt && Array.isArray(transformedParams.prompt)) {
transformedParams.prompt = transformedParams.prompt.map((message) => {
if (typeof message.content !== 'string') {
for (const part of message.content) {
const googleOptions = part?.providerOptions?.[aiSdkId]
if (googleOptions?.thoughtSignature) {
googleOptions.thoughtSignature = MAGIC_STRING
}
}
}
return message
})
}
return transformedParams
}
}
}
/**
* OpenRouter Reasoning Middleware
*
* Filters out [REDACTED] blocks from OpenRouter reasoning responses.
* OpenRouter may include [REDACTED] markers in reasoning content that
* should be removed for cleaner output.
*
* @see https://openrouter.ai/docs/docs/best-practices/reasoning-tokens
* @returns LanguageModelV2Middleware
*/
export function openrouterReasoningMiddleware(): LanguageModelV2Middleware {
const REDACTED_BLOCK = '[REDACTED]'
return {
middlewareVersion: 'v2',
wrapGenerate: async ({ doGenerate }) => {
const { content, ...rest } = await doGenerate()
const modifiedContent = content.map((part) => {
if (part.type === 'reasoning' && part.text.includes(REDACTED_BLOCK)) {
return {
...part,
text: part.text.replace(REDACTED_BLOCK, '')
}
}
return part
})
return { content: modifiedContent, ...rest }
},
wrapStream: async ({ doStream }) => {
const { stream, ...rest } = await doStream()
return {
stream: stream.pipeThrough(
new TransformStream<LanguageModelV2StreamPart, LanguageModelV2StreamPart>({
transform(
chunk: LanguageModelV2StreamPart,
controller: TransformStreamDefaultController<LanguageModelV2StreamPart>
) {
if (chunk.type === 'reasoning-delta' && chunk.delta.includes(REDACTED_BLOCK)) {
controller.enqueue({
...chunk,
delta: chunk.delta.replace(REDACTED_BLOCK, '')
})
} else {
controller.enqueue(chunk)
}
}
})
),
...rest
}
}
}
}
/**
* Build shared middlewares based on configuration
*
* This function builds a set of middlewares that are commonly needed
* across different environments (renderer, API server).
*
* @param config - Configuration for middleware building
* @returns Array of AI SDK middlewares
*
* @example
* ```typescript
* import { buildSharedMiddlewares } from '@shared/middleware'
*
* const middlewares = buildSharedMiddlewares({
* enableReasoning: true,
* modelId: 'gemini-2.5-pro',
* providerId: 'openrouter',
* aiSdkProviderId: 'google'
* })
* ```
*/
export function buildSharedMiddlewares(config: SharedMiddlewareConfig): LanguageModelV2Middleware[] {
const middlewares: LanguageModelV2Middleware[] = []
// 1. Reasoning extraction middleware
if (config.enableReasoning) {
const tagName = config.reasoningTagName || getReasoningTagName(config.modelId)
middlewares.push(extractReasoningMiddleware({ tagName }))
}
// 2. OpenRouter-specific: filter [REDACTED] blocks
if (config.providerId === 'openrouter' && config.enableReasoning) {
middlewares.push(openrouterReasoningMiddleware())
}
// 3. Gemini 3 (2.5) specific: skip thought signature validation
if (isGemini3ModelId(config.modelId) && config.aiSdkProviderId) {
middlewares.push(skipGeminiThoughtSignatureMiddleware(config.aiSdkProviderId))
}
return middlewares
}
@@ -1,22 +0,0 @@
import type { MinimalModel, MinimalProvider, ProviderType } from '../types'
import { provider2Provider, startsWith } from './helper'
import type { RuleSet } from './types'
// https://platform.claude.com/docs/en/build-with-claude/claude-in-microsoft-foundry
const AZURE_ANTHROPIC_RULES: RuleSet = {
rules: [
{
match: startsWith('claude'),
provider: (provider: MinimalProvider) => ({
...provider,
type: 'anthropic' as ProviderType,
apiHost: provider.apiHost + 'anthropic/v1',
id: 'azure-anthropic'
})
}
],
fallbackRule: (provider: MinimalProvider) => provider
}
export const azureAnthropicProviderCreator = <P extends MinimalProvider>(model: MinimalModel, provider: P): P =>
provider2Provider<MinimalModel, MinimalProvider, P>(AZURE_ANTHROPIC_RULES, model, provider)
-32
View File
@@ -1,32 +0,0 @@
import type { MinimalModel, MinimalProvider } from '../types'
import type { RuleSet } from './types'
export const startsWith =
(prefix: string) =>
<M extends MinimalModel>(model: M) =>
model.id.toLowerCase().startsWith(prefix.toLowerCase())
export const endpointIs =
(type: string) =>
<M extends MinimalModel>(model: M) =>
model.endpoint_type === type
/**
* 解析模型对应的Provider
* @param ruleSet 规则集对象
* @param model 模型对象
* @param provider 原始provider对象
* @returns 解析出的provider对象
*/
export function provider2Provider<M extends MinimalModel, R extends MinimalProvider, P extends R = R>(
ruleSet: RuleSet<M, R>,
model: M,
provider: P
): P {
for (const rule of ruleSet.rules) {
if (rule.match(model)) {
return rule.provider(provider) as P
}
}
return ruleSet.fallbackRule(provider) as P
}
-6
View File
@@ -1,6 +0,0 @@
export { aihubmixProviderCreator } from './aihubmix'
export { azureAnthropicProviderCreator } from './azure-anthropic'
export { endpointIs, provider2Provider, startsWith } from './helper'
export { newApiResolverCreator } from './newApi'
export type { RuleSet } from './types'
export { vertexAnthropicProviderCreator } from './vertex-anthropic'
-9
View File
@@ -1,9 +0,0 @@
import type { MinimalModel, MinimalProvider } from '../types'
export interface RuleSet<M extends MinimalModel = MinimalModel, P extends MinimalProvider = MinimalProvider> {
rules: Array<{
match: (model: M) => boolean
provider: (provider: P) => P
}>
fallbackRule: (provider: P) => P
}
@@ -1,19 +0,0 @@
import type { MinimalModel, MinimalProvider } from '../types'
import { provider2Provider, startsWith } from './helper'
import type { RuleSet } from './types'
const VERTEX_ANTHROPIC_RULES: RuleSet = {
rules: [
{
match: startsWith('claude'),
provider: (provider: MinimalProvider) => ({
...provider,
id: 'google-vertex-anthropic'
})
}
],
fallbackRule: (provider: MinimalProvider) => provider
}
export const vertexAnthropicProviderCreator = <P extends MinimalProvider>(model: MinimalModel, provider: P): P =>
provider2Provider<MinimalModel, MinimalProvider, P>(VERTEX_ANTHROPIC_RULES, model, provider)
-26
View File
@@ -1,26 +0,0 @@
import { getLowerBaseModelName } from '@shared/utils/naming'
import type { MinimalModel } from './types'
export const COPILOT_EDITOR_VERSION = 'vscode/1.104.1'
export const COPILOT_PLUGIN_VERSION = 'copilot-chat/0.26.7'
export const COPILOT_INTEGRATION_ID = 'vscode-chat'
export const COPILOT_USER_AGENT = 'GitHubCopilotChat/0.26.7'
export const COPILOT_DEFAULT_HEADERS = {
'Copilot-Integration-Id': COPILOT_INTEGRATION_ID,
'User-Agent': COPILOT_USER_AGENT,
'Editor-Version': COPILOT_EDITOR_VERSION,
'Editor-Plugin-Version': COPILOT_PLUGIN_VERSION,
'editor-version': COPILOT_EDITOR_VERSION,
'editor-plugin-version': COPILOT_PLUGIN_VERSION,
'copilot-vision-request': 'true'
} as const
// Models that require the OpenAI Responses endpoint when routed through GitHub Copilot (#10560)
const COPILOT_RESPONSES_MODEL_IDS = ['gpt-5-codex', 'gpt-5.1-codex', 'gpt-5.1-codex-mini']
export function isCopilotResponsesModel<M extends MinimalModel>(model: M): boolean {
const normalizedId = getLowerBaseModelName(model.id)
return COPILOT_RESPONSES_MODEL_IDS.some((target) => normalizedId === target)
}
-100
View File
@@ -1,100 +0,0 @@
/**
* Provider Type Detection Utilities
*
* Functions to detect provider types based on provider configuration.
* These are pure functions that only depend on provider.type and provider.id.
*
* NOTE: These functions should match the logic in @renderer/utils/provider.ts
*/
import type { MinimalProvider } from './types'
/**
* Check if provider is Anthropic type
*/
export function isAnthropicProvider<P extends MinimalProvider>(provider: P): boolean {
return provider.type === 'anthropic'
}
/**
* Check if provider is OpenAI Response type (openai-response)
* NOTE: This matches isOpenAIProvider in renderer/utils/provider.ts
*/
export function isOpenAIProvider<P extends MinimalProvider>(provider: P): boolean {
return provider.type === 'openai-response'
}
/**
* Check if provider is Gemini type
*/
export function isGeminiProvider<P extends MinimalProvider>(provider: P): boolean {
return provider.type === 'gemini'
}
/**
* Check if provider is Azure OpenAI type
*/
export function isAzureOpenAIProvider<P extends MinimalProvider>(provider: P): boolean {
return provider.type === 'azure-openai'
}
/**
* Check if provider is Vertex AI type
*/
export function isVertexProvider<P extends MinimalProvider>(provider: P): boolean {
return provider.type === 'vertexai'
}
/**
* Check if provider is AWS Bedrock type
*/
export function isAwsBedrockProvider<P extends MinimalProvider>(provider: P): boolean {
return provider.type === 'aws-bedrock'
}
/**
* Check if provider is AI Gateway type
*/
export function isAIGatewayProvider<P extends MinimalProvider>(provider: P): boolean {
return provider.type === 'ai-gateway'
}
/**
* Check if Azure OpenAI provider uses responses endpoint
* Matches isAzureResponsesEndpoint in renderer/utils/provider.ts
*/
export function isAzureResponsesEndpoint<P extends MinimalProvider>(provider: P): boolean {
return provider.apiVersion === 'preview' || provider.apiVersion === 'v1'
}
/**
* Check if provider is Cherry AI type
* Matches isCherryAIProvider in renderer/utils/provider.ts
*/
export function isCherryAIProvider<P extends MinimalProvider>(provider: P): boolean {
return provider.id === 'cherryai'
}
/**
* Check if provider is Perplexity type
* Matches isPerplexityProvider in renderer/utils/provider.ts
*/
export function isPerplexityProvider<P extends MinimalProvider>(provider: P): boolean {
return provider.id === 'perplexity'
}
/**
* Check if provider is new-api type (supports multiple backends)
* Matches isNewApiProvider in renderer/utils/provider.ts
*/
export function isNewApiProvider<P extends MinimalProvider>(provider: P): boolean {
return ['new-api', 'cherryin'].includes(provider.id) || provider.type === ('new-api' as string)
}
/**
* Check if provider is OpenAI compatible
* Matches isOpenAICompatibleProvider in renderer/utils/provider.ts
*/
export function isOpenAICompatibleProvider<P extends MinimalProvider>(provider: P): boolean {
return ['openai', 'new-api', 'mistral'].includes(provider.type)
}
-136
View File
@@ -1,136 +0,0 @@
/**
* Provider API Host Formatting
*
* Utilities for formatting provider API hosts to work with AI SDK.
* These handle the differences between how Cherry Studio stores API hosts
* and how AI SDK expects them.
*/
import {
formatApiHost,
formatAzureOpenAIApiHost,
formatVertexApiHost,
routeToEndpoint,
withoutTrailingSlash
} from '../api'
import {
isAnthropicProvider,
isAzureOpenAIProvider,
isCherryAIProvider,
isGeminiProvider,
isPerplexityProvider,
isVertexProvider
} from './detection'
import type { MinimalProvider } from './types'
import { SystemProviderIds } from './types'
/**
* Interface for environment-specific implementations
* Renderer and Main process can provide their own implementations
*/
export interface ProviderFormatContext {
vertex: {
project: string
location: string
}
}
/**
* Default Azure OpenAI API host formatter
*/
export function defaultFormatAzureOpenAIApiHost(host: string): string {
const normalizedHost = withoutTrailingSlash(host)
?.replace(/\/v1$/, '')
.replace(/\/openai$/, '')
// AI SDK will add /v1
return formatApiHost(normalizedHost + '/openai', false)
}
/**
* Format provider API host for AI SDK
*
* This function normalizes the apiHost to work with AI SDK.
* Different providers have different requirements:
* - Most providers: add /v1 suffix
* - Gemini: add /v1beta suffix
* - Some providers: no suffix needed
*
* @param provider - The provider to format
* @param context - Optional context with environment-specific implementations
* @returns Provider with formatted apiHost (and anthropicApiHost if applicable)
*/
export function formatProviderApiHost<T extends MinimalProvider>(provider: T, context: ProviderFormatContext): T {
const formatted = { ...provider }
// Format anthropicApiHost if present
if (formatted.anthropicApiHost) {
formatted.anthropicApiHost = formatApiHost(formatted.anthropicApiHost)
}
// Format based on provider type
if (isAnthropicProvider(provider)) {
const baseHost = formatted.anthropicApiHost || formatted.apiHost
// AI SDK needs /v1 in baseURL
formatted.apiHost = formatApiHost(baseHost)
if (!formatted.anthropicApiHost) {
formatted.anthropicApiHost = formatted.apiHost
}
} else if (formatted.id === SystemProviderIds.copilot || formatted.id === SystemProviderIds.github) {
formatted.apiHost = formatApiHost(formatted.apiHost, false)
} else if (isGeminiProvider(formatted)) {
formatted.apiHost = formatApiHost(formatted.apiHost, true, 'v1beta')
} else if (isAzureOpenAIProvider(formatted)) {
formatted.apiHost = formatAzureOpenAIApiHost(formatted.apiHost)
} else if (isVertexProvider(formatted)) {
formatted.apiHost = formatVertexApiHost(formatted, context.vertex.project, context.vertex.location)
} else if (isCherryAIProvider(formatted)) {
formatted.apiHost = formatApiHost(formatted.apiHost, false)
} else if (isPerplexityProvider(formatted)) {
formatted.apiHost = formatApiHost(formatted.apiHost, false)
} else {
formatted.apiHost = formatApiHost(formatted.apiHost)
}
return formatted
}
/**
* Get the base URL for AI SDK from a formatted provider
*
* This extracts the baseURL that AI SDK expects, handling
* the '#' endpoint routing format if present.
*
* @param formattedApiHost - The formatted apiHost (after formatProviderApiHost)
* @returns The baseURL for AI SDK
*/
export function getBaseUrlForAiSdk(formattedApiHost: string): string {
const { baseURL } = routeToEndpoint(formattedApiHost)
return baseURL
}
/**
* Get rotated API key from comma-separated keys
*
* This is the interface for API key rotation. The actual implementation
* depends on the environment (renderer uses window.keyv, main uses its own storage).
*/
export interface ApiKeyRotator {
/**
* Get the next API key in rotation
* @param providerId - The provider ID for tracking rotation
* @param keys - Comma-separated API keys
* @returns The next API key to use
*/
getRotatedKey(providerId: string, keys: string): string
}
/**
* Simple API key rotator that always returns the first key
* Use this when rotation is not needed
*/
export const simpleKeyRotator: ApiKeyRotator = {
getRotatedKey(_providerId: string, keys: string): string {
const keyList = keys.split(',').map((k) => k.trim())
return keyList[0] || keys
}
}
-48
View File
@@ -1,48 +0,0 @@
/**
* Shared Provider Utilities
*
* This module exports utilities for working with AI providers
* that can be shared between main process and renderer process.
*/
// Type definitions
export type { MinimalProvider, ProviderType, SystemProviderId } from './types'
export { SystemProviderIds } from './types'
// Provider type detection
export {
isAIGatewayProvider,
isAnthropicProvider,
isAwsBedrockProvider,
isAzureOpenAIProvider,
isAzureResponsesEndpoint,
isCherryAIProvider,
isGeminiProvider,
isNewApiProvider,
isOpenAICompatibleProvider,
isOpenAIProvider,
isPerplexityProvider,
isVertexProvider
} from './detection'
// API host formatting
export type { ApiKeyRotator, ProviderFormatContext } from './format'
export {
defaultFormatAzureOpenAIApiHost,
formatProviderApiHost,
getBaseUrlForAiSdk,
simpleKeyRotator
} from './format'
// Provider ID mapping
export { getAiSdkProviderId, STATIC_PROVIDER_MAPPING, tryResolveProviderId } from './mapping'
// AI SDK configuration
export type { AiSdkConfig, AiSdkConfigContext } from './sdk-config'
export { providerToAiSdkConfig } from './sdk-config'
// Provider resolution
export { resolveActualProvider } from './resolve'
// Provider initialization
export { initializeSharedProviders, SHARED_PROVIDER_CONFIGS } from './initialization'
-107
View File
@@ -1,107 +0,0 @@
import { type ProviderConfig, registerMultipleProviderConfigs } from '@cherrystudio/ai-core/provider'
type ProviderInitializationLogger = {
warn?: (message: string) => void
error?: (message: string, error: Error) => void
}
export const SHARED_PROVIDER_CONFIGS: ProviderConfig[] = [
{
id: 'openrouter',
name: 'OpenRouter',
import: () => import('@openrouter/ai-sdk-provider'),
creatorFunctionName: 'createOpenRouter',
supportsImageGeneration: true,
aliases: ['openrouter']
},
{
id: 'google-vertex',
name: 'Google Vertex AI',
import: () => import('@ai-sdk/google-vertex/edge'),
creatorFunctionName: 'createVertex',
supportsImageGeneration: true,
aliases: ['vertexai']
},
{
id: 'google-vertex-anthropic',
name: 'Google Vertex AI Anthropic',
import: () => import('@ai-sdk/google-vertex/anthropic/edge'),
creatorFunctionName: 'createVertexAnthropic',
supportsImageGeneration: true,
aliases: ['vertexai-anthropic']
},
{
id: 'azure-anthropic',
name: 'Azure AI Anthropic',
import: () => import('@ai-sdk/anthropic'),
creatorFunctionName: 'createAnthropic',
supportsImageGeneration: false,
aliases: ['azure-anthropic']
},
{
id: 'github-copilot-openai-compatible',
name: 'GitHub Copilot OpenAI Compatible',
import: () => import('@opeoginni/github-copilot-openai-compatible'),
creatorFunctionName: 'createGitHubCopilotOpenAICompatible',
supportsImageGeneration: false,
aliases: ['copilot', 'github-copilot']
},
{
id: 'bedrock',
name: 'Amazon Bedrock',
import: () => import('@ai-sdk/amazon-bedrock'),
creatorFunctionName: 'createAmazonBedrock',
supportsImageGeneration: true,
aliases: ['aws-bedrock']
},
{
id: 'perplexity',
name: 'Perplexity',
import: () => import('@ai-sdk/perplexity'),
creatorFunctionName: 'createPerplexity',
supportsImageGeneration: false,
aliases: ['perplexity']
},
{
id: 'mistral',
name: 'Mistral',
import: () => import('@ai-sdk/mistral'),
creatorFunctionName: 'createMistral',
supportsImageGeneration: false,
aliases: ['mistral']
},
{
id: 'huggingface',
name: 'HuggingFace',
import: () => import('@ai-sdk/huggingface'),
creatorFunctionName: 'createHuggingFace',
supportsImageGeneration: true,
aliases: ['hf', 'hugging-face']
},
{
id: 'ai-gateway',
name: 'AI Gateway',
import: () => import('@ai-sdk/gateway'),
creatorFunctionName: 'createGateway',
supportsImageGeneration: true,
aliases: ['gateway']
},
{
id: 'cerebras',
name: 'Cerebras',
import: () => import('@ai-sdk/cerebras'),
creatorFunctionName: 'createCerebras',
supportsImageGeneration: false
}
] as const
export function initializeSharedProviders(logger?: ProviderInitializationLogger): void {
try {
const successCount = registerMultipleProviderConfigs(SHARED_PROVIDER_CONFIGS)
if (successCount < SHARED_PROVIDER_CONFIGS.length) {
logger?.warn?.('Some providers failed to register. Check previous error logs.')
}
} catch (error) {
logger?.error?.('Failed to initialize shared providers', error as Error)
}
}
-95
View File
@@ -1,95 +0,0 @@
/**
* Provider ID Mapping
*
* Maps Cherry Studio provider IDs/types to AI SDK provider IDs.
* This logic should match @renderer/aiCore/provider/factory.ts
*/
import { hasProviderConfigByAlias, type ProviderId, resolveProviderConfigId } from '@cherrystudio/ai-core/provider'
import { isAzureOpenAIProvider, isAzureResponsesEndpoint } from './detection'
import type { MinimalProvider } from './types'
/**
* Static mapping from Cherry Studio provider ID/type to AI SDK provider ID
* Matches STATIC_PROVIDER_MAPPING in @renderer/aiCore/provider/factory.ts
*/
export const STATIC_PROVIDER_MAPPING: Record<string, ProviderId> = {
gemini: 'google', // Google Gemini -> google
'azure-openai': 'azure', // Azure OpenAI -> azure
'openai-response': 'openai', // OpenAI Responses -> openai
grok: 'xai', // Grok -> xai
copilot: 'github-copilot-openai-compatible'
}
/**
* Try to resolve a provider identifier to an AI SDK provider ID
* Matches tryResolveProviderId in @renderer/aiCore/provider/factory.ts
*
* @param identifier - The provider ID or type to resolve
* @param checker - Provider config checker (defaults to static mapping only)
* @returns The resolved AI SDK provider ID, or null if not found
*/
export function tryResolveProviderId(identifier: string): ProviderId | null {
// 1. 检查静态映射
const staticMapping = STATIC_PROVIDER_MAPPING[identifier]
if (staticMapping) {
return staticMapping
}
// 2. 检查AiCore是否支持(包括别名支持)
if (hasProviderConfigByAlias(identifier)) {
// 解析为真实的Provider ID
return resolveProviderConfigId(identifier) as ProviderId
}
return null
}
/**
* Get the AI SDK Provider ID for a Cherry Studio provider
* Matches getAiSdkProviderId in @renderer/aiCore/provider/factory.ts
*
* Logic:
* 1. Handle Azure OpenAI specially (check responses endpoint)
* 2. Try to resolve from provider.id
* 3. Try to resolve from provider.type (but not for generic 'openai' type)
* 4. Check for OpenAI API host pattern
* 5. Fallback to provider's own ID
*
* @param provider - The Cherry Studio provider
* @param checker - Provider config checker (defaults to static mapping only)
* @returns The AI SDK provider ID to use
*/
export function getAiSdkProviderId(provider: MinimalProvider): ProviderId {
// 1. Handle Azure OpenAI specially - check this FIRST before other resolution
if (isAzureOpenAIProvider(provider)) {
if (isAzureResponsesEndpoint(provider)) {
return 'azure-responses'
}
return 'azure'
}
// 2. 尝试解析provider.id
const resolvedFromId = tryResolveProviderId(provider.id)
if (resolvedFromId) {
return resolvedFromId
}
// 3. 尝试解析provider.type
// 会把所有类型为openai的自定义provider解析到aisdk的openaiProvider上
if (provider.type !== 'openai') {
const resolvedFromType = tryResolveProviderId(provider.type)
if (resolvedFromType) {
return resolvedFromType
}
}
// 4. Check for OpenAI API host pattern
if (provider.apiHost.includes('api.openai.com')) {
return 'openai-chat'
}
// 5. 最后的fallback(使用provider本身的id
return provider.id
}
-43
View File
@@ -1,43 +0,0 @@
import { aihubmixProviderCreator, newApiResolverCreator, vertexAnthropicProviderCreator } from './config'
import { azureAnthropicProviderCreator } from './config/azure-anthropic'
import { isAzureOpenAIProvider, isNewApiProvider } from './detection'
import type { MinimalModel, MinimalProvider } from './types'
export interface ResolveActualProviderOptions<P extends MinimalProvider> {
isSystemProvider?: (provider: P) => boolean
}
const defaultIsSystemProvider = <P extends MinimalProvider>(provider: P): boolean => {
if ('isSystem' in provider) {
return Boolean((provider as unknown as { isSystem?: boolean }).isSystem)
}
return false
}
export function resolveActualProvider<M extends MinimalModel, P extends MinimalProvider>(
provider: P,
model: M,
options: ResolveActualProviderOptions<P> = {}
): P {
let resolvedProvider = provider
if (isNewApiProvider(resolvedProvider)) {
resolvedProvider = newApiResolverCreator(model, resolvedProvider)
}
const isSystemProvider = options.isSystemProvider?.(resolvedProvider) ?? defaultIsSystemProvider(resolvedProvider)
if (isSystemProvider && resolvedProvider.id === 'aihubmix') {
resolvedProvider = aihubmixProviderCreator(model, resolvedProvider)
}
if (isSystemProvider && resolvedProvider.id === 'vertexai') {
resolvedProvider = vertexAnthropicProviderCreator(model, resolvedProvider)
}
if (isAzureOpenAIProvider(resolvedProvider)) {
resolvedProvider = azureAnthropicProviderCreator(model, resolvedProvider)
}
return resolvedProvider
}
-259
View File
@@ -1,259 +0,0 @@
/**
* AI SDK Configuration
*
* Shared utilities for converting Cherry Studio Provider to AI SDK configuration.
* Environment-specific logic (renderer/main) is injected via context interfaces.
*/
import { formatPrivateKey, hasProviderConfig, ProviderConfigFactory } from '@cherrystudio/ai-core/provider'
import { routeToEndpoint } from '../api'
import { getAiSdkProviderId } from './mapping'
import type { MinimalProvider } from './types'
import { SystemProviderIds } from './types'
/**
* AI SDK configuration result
*/
export interface AiSdkConfig {
providerId: string
options: Record<string, unknown>
}
/**
* Context for environment-specific implementations
*/
export interface AiSdkConfigContext {
/**
* Get the rotated API key (for multi-key support)
* Default: returns first key
*/
getRotatedApiKey?: (provider: MinimalProvider) => string
/**
* Check if a model uses chat completion only (for OpenAI response mode)
* Default: returns false
*/
isOpenAIChatCompletionOnlyModel?: (modelId: string) => boolean
/**
* Get Copilot default headers (constants)
* Default: returns empty object
*/
getCopilotDefaultHeaders?: () => Record<string, string>
/**
* Get Copilot stored headers from state
* Default: returns empty object
*/
getCopilotStoredHeaders?: () => Record<string, string>
/**
* Get AWS Bedrock configuration
* Default: returns undefined (not configured)
*/
getAwsBedrockConfig?: () =>
| {
authType: 'apiKey' | 'iam'
region: string
apiKey?: string
accessKeyId?: string
secretAccessKey?: string
}
| undefined
/**
* Get Vertex AI configuration
* Default: returns undefined (not configured)
*/
getVertexConfig?: (provider: MinimalProvider) =>
| {
project: string
location: string
googleCredentials: {
privateKey: string
clientEmail: string
}
}
| undefined
/**
* Get endpoint type for cherryin provider
*/
getEndpointType?: (modelId: string) => string | undefined
/**
* Custom fetch implementation
* Main process: use Electron net.fetch
* Renderer process: use browser fetch (default)
*/
fetch?: typeof globalThis.fetch
/**
* Get CherryAI signed fetch wrapper
* Returns a fetch function that adds signature headers to requests
*/
getCherryAISignedFetch?: () => typeof globalThis.fetch
}
/**
* Default simple key rotator - returns first key
*/
function defaultGetRotatedApiKey(provider: MinimalProvider): string {
const keys = provider.apiKey.split(',').map((k) => k.trim())
return keys[0] || provider.apiKey
}
/**
* Convert Cherry Studio Provider to AI SDK configuration
*
* @param provider - The formatted provider (after formatProviderApiHost)
* @param modelId - The model ID to use
* @param context - Environment-specific implementations
* @returns AI SDK configuration
*/
export function providerToAiSdkConfig(
provider: MinimalProvider,
modelId: string,
context: AiSdkConfigContext = {}
): AiSdkConfig {
const getRotatedApiKey = context.getRotatedApiKey || defaultGetRotatedApiKey
const isOpenAIChatCompletionOnlyModel = context.isOpenAIChatCompletionOnlyModel || (() => false)
const aiSdkProviderId = getAiSdkProviderId(provider)
// Build base config
const { baseURL, endpoint } = routeToEndpoint(provider.apiHost)
const baseConfig = {
baseURL,
apiKey: getRotatedApiKey(provider)
}
// Handle Copilot specially
if (provider.id === SystemProviderIds.copilot) {
const defaultHeaders = context.getCopilotDefaultHeaders?.() ?? {}
const storedHeaders = context.getCopilotStoredHeaders?.() ?? {}
const copilotExtraOptions: Record<string, unknown> = {
headers: {
...defaultHeaders,
...storedHeaders,
...provider.extra_headers
},
name: provider.id,
includeUsage: true
}
if (context.fetch) {
copilotExtraOptions.fetch = context.fetch
}
const options = ProviderConfigFactory.fromProvider(
'github-copilot-openai-compatible',
baseConfig,
copilotExtraOptions
)
return {
providerId: 'github-copilot-openai-compatible',
options
}
}
// Build extra options
const extraOptions: Record<string, unknown> = {}
if (endpoint) {
extraOptions.endpoint = endpoint
}
// Handle OpenAI mode
if (provider.type === 'openai-response' && !isOpenAIChatCompletionOnlyModel(modelId)) {
extraOptions.mode = 'responses'
} else if (aiSdkProviderId === 'openai' || (aiSdkProviderId === 'cherryin' && provider.type === 'openai')) {
extraOptions.mode = 'chat'
}
// Add extra headers
if (provider.extra_headers) {
extraOptions.headers = provider.extra_headers
if (aiSdkProviderId === 'openai') {
extraOptions.headers = {
...(extraOptions.headers as Record<string, string>),
'HTTP-Referer': 'https://cherry-ai.com',
'X-Title': 'Cherry Studio',
'X-Api-Key': baseConfig.apiKey
}
}
}
// Handle Azure modes
if (aiSdkProviderId === 'azure-responses') {
extraOptions.mode = 'responses'
} else if (aiSdkProviderId === 'azure') {
extraOptions.mode = 'chat'
}
// Handle AWS Bedrock
if (aiSdkProviderId === 'bedrock') {
const bedrockConfig = context.getAwsBedrockConfig?.()
if (bedrockConfig) {
extraOptions.region = bedrockConfig.region
if (bedrockConfig.authType === 'apiKey') {
extraOptions.apiKey = bedrockConfig.apiKey
} else {
extraOptions.accessKeyId = bedrockConfig.accessKeyId
extraOptions.secretAccessKey = bedrockConfig.secretAccessKey
}
}
}
// Handle Vertex AI
if (aiSdkProviderId === 'google-vertex' || aiSdkProviderId === 'google-vertex-anthropic') {
const vertexConfig = context.getVertexConfig?.(provider)
if (vertexConfig) {
extraOptions.project = vertexConfig.project
extraOptions.location = vertexConfig.location
extraOptions.googleCredentials = {
...vertexConfig.googleCredentials,
privateKey: formatPrivateKey(vertexConfig.googleCredentials.privateKey)
}
baseConfig.baseURL += aiSdkProviderId === 'google-vertex' ? '/publishers/google' : '/publishers/anthropic/models'
}
}
// Handle cherryin endpoint type
if (aiSdkProviderId === 'cherryin') {
const endpointType = context.getEndpointType?.(modelId)
if (endpointType) {
extraOptions.endpointType = endpointType
}
}
// Handle cherryai signed fetch
if (provider.id === 'cherryai') {
const signedFetch = context.getCherryAISignedFetch?.()
if (signedFetch) {
extraOptions.fetch = signedFetch
}
} else if (context.fetch) {
extraOptions.fetch = context.fetch
}
// Check if AI SDK supports this provider natively
if (hasProviderConfig(aiSdkProviderId) && aiSdkProviderId !== 'openai-compatible') {
const options = ProviderConfigFactory.fromProvider(aiSdkProviderId, baseConfig, extraOptions)
return {
providerId: aiSdkProviderId,
options
}
}
// Fallback to openai-compatible
const options = ProviderConfigFactory.createOpenAICompatible(baseConfig.baseURL, baseConfig.apiKey)
return {
providerId: 'openai-compatible',
options: {
...options,
name: provider.id,
...extraOptions,
includeUsage: true
}
}
}
-174
View File
@@ -1,174 +0,0 @@
import * as z from 'zod'
export const ProviderTypeSchema = z.enum([
'openai',
'openai-response',
'anthropic',
'gemini',
'azure-openai',
'vertexai',
'mistral',
'aws-bedrock',
'vertex-anthropic',
'new-api',
'ai-gateway'
])
export type ProviderType = z.infer<typeof ProviderTypeSchema>
/**
* Minimal provider interface for shared utilities
* This is the subset of Provider that shared code needs
*/
export type MinimalProvider = {
id: string
type: ProviderType
apiKey: string
apiHost: string
anthropicApiHost?: string
apiVersion?: string
extra_headers?: Record<string, string>
}
/**
* Minimal model interface for shared utilities
* This is the subset of Model that shared code needs
*/
export type MinimalModel = {
id: string
endpoint_type?: string
}
export const SystemProviderIdSchema = z.enum([
'cherryin',
'silicon',
'aihubmix',
'ocoolai',
'deepseek',
'ppio',
'alayanew',
'qiniu',
'dmxapi',
'burncloud',
'tokenflux',
'302ai',
'cephalon',
'lanyun',
'ph8',
'openrouter',
'ollama',
'ovms',
'new-api',
'lmstudio',
'anthropic',
'openai',
'azure-openai',
'gemini',
'vertexai',
'github',
'copilot',
'zhipu',
'yi',
'moonshot',
'baichuan',
'dashscope',
'stepfun',
'doubao',
'infini',
'minimax',
'groq',
'together',
'fireworks',
'nvidia',
'grok',
'hyperbolic',
'mistral',
'jina',
'perplexity',
'modelscope',
'xirang',
'hunyuan',
'tencent-cloud-ti',
'baidu-cloud',
'gpustack',
'voyageai',
'aws-bedrock',
'poe',
'aionly',
'longcat',
'huggingface',
'sophnet',
'ai-gateway',
'cerebras'
])
export type SystemProviderId = z.infer<typeof SystemProviderIdSchema>
export const isSystemProviderId = (id: string): id is SystemProviderId => {
return SystemProviderIdSchema.safeParse(id).success
}
export const SystemProviderIds = {
cherryin: 'cherryin',
silicon: 'silicon',
aihubmix: 'aihubmix',
ocoolai: 'ocoolai',
deepseek: 'deepseek',
ppio: 'ppio',
alayanew: 'alayanew',
qiniu: 'qiniu',
dmxapi: 'dmxapi',
burncloud: 'burncloud',
tokenflux: 'tokenflux',
'302ai': '302ai',
cephalon: 'cephalon',
lanyun: 'lanyun',
ph8: 'ph8',
sophnet: 'sophnet',
openrouter: 'openrouter',
ollama: 'ollama',
ovms: 'ovms',
'new-api': 'new-api',
lmstudio: 'lmstudio',
anthropic: 'anthropic',
openai: 'openai',
'azure-openai': 'azure-openai',
gemini: 'gemini',
vertexai: 'vertexai',
github: 'github',
copilot: 'copilot',
zhipu: 'zhipu',
yi: 'yi',
moonshot: 'moonshot',
baichuan: 'baichuan',
dashscope: 'dashscope',
stepfun: 'stepfun',
doubao: 'doubao',
infini: 'infini',
minimax: 'minimax',
groq: 'groq',
together: 'together',
fireworks: 'fireworks',
nvidia: 'nvidia',
grok: 'grok',
hyperbolic: 'hyperbolic',
mistral: 'mistral',
jina: 'jina',
perplexity: 'perplexity',
modelscope: 'modelscope',
xirang: 'xirang',
hunyuan: 'hunyuan',
'tencent-cloud-ti': 'tencent-cloud-ti',
'baidu-cloud': 'baidu-cloud',
gpustack: 'gpustack',
voyageai: 'voyageai',
'aws-bedrock': 'aws-bedrock',
poe: 'poe',
aionly: 'aionly',
longcat: 'longcat',
huggingface: 'huggingface',
'ai-gateway': 'ai-gateway',
cerebras: 'cerebras'
} as const satisfies Record<SystemProviderId, SystemProviderId>
export type SystemProviderIdTypeMap = typeof SystemProviderIds
-1
View File
@@ -1 +0,0 @@
export { getBaseModelName, getLowerBaseModelName } from './naming'
-31
View File
@@ -1,31 +0,0 @@
/**
* 从模型 ID 中提取基础名称。
* 例如:
* - 'deepseek/deepseek-r1' => 'deepseek-r1'
* - 'deepseek-ai/deepseek/deepseek-r1' => 'deepseek-r1'
* @param {string} id 模型 ID
* @param {string} [delimiter='/'] 分隔符,默认为 '/'
* @returns {string} 基础名称
*/
export const getBaseModelName = (id: string, delimiter: string = '/'): string => {
const parts = id.split(delimiter)
return parts[parts.length - 1]
}
/**
* 从模型 ID 中提取基础名称并转换为小写。
* 例如:
* - 'deepseek/DeepSeek-R1' => 'deepseek-r1'
* - 'deepseek-ai/deepseek/DeepSeek-R1' => 'deepseek-r1'
* @param {string} id 模型 ID
* @param {string} [delimiter='/'] 分隔符,默认为 '/'
* @returns {string} 小写的基础名称
*/
export const getLowerBaseModelName = (id: string, delimiter: string = '/'): string => {
const baseModelName = getBaseModelName(id, delimiter).toLowerCase()
// for openrouter
if (baseModelName.endsWith(':free')) {
return baseModelName.replace(':free', '')
}
return baseModelName
}
@@ -1,638 +0,0 @@
/**
* AI SDK to Anthropic SSE Adapter
*
* Converts AI SDK's fullStream (TextStreamPart) events to Anthropic Messages API SSE format.
* This enables any AI provider supported by AI SDK to be exposed via Anthropic-compatible API.
*
* Anthropic SSE Event Flow:
* 1. message_start - Initial message with metadata
* 2. content_block_start - Begin a content block (text, tool_use, thinking)
* 3. content_block_delta - Incremental content updates
* 4. content_block_stop - End a content block
* 5. message_delta - Updates to overall message (stop_reason, usage)
* 6. message_stop - Stream complete
*
* @see https://docs.anthropic.com/en/api/messages-streaming
*/
import type {
ContentBlock,
InputJSONDelta,
Message,
MessageDeltaUsage,
RawContentBlockDeltaEvent,
RawContentBlockStartEvent,
RawContentBlockStopEvent,
RawMessageDeltaEvent,
RawMessageStartEvent,
RawMessageStopEvent,
RawMessageStreamEvent,
StopReason,
TextBlock,
TextDelta,
ThinkingBlock,
ThinkingDelta,
ToolUseBlock,
Usage
} from '@anthropic-ai/sdk/resources/messages'
import { loggerService } from '@logger'
import { type FinishReason, type LanguageModelUsage, type TextStreamPart, type ToolSet } from 'ai'
import { googleReasoningCache, openRouterReasoningCache } from '../../services/CacheService'
const logger = loggerService.withContext('AiSdkToAnthropicSSE')
interface ContentBlockState {
type: 'text' | 'tool_use' | 'thinking'
index: number
started: boolean
content: string
// For tool_use blocks
toolId?: string
toolName?: string
toolInput?: string
}
interface AdapterState {
messageId: string
model: string
inputTokens: number
outputTokens: number
cacheInputTokens: number
currentBlockIndex: number
blocks: Map<number, ContentBlockState>
textBlockIndex: number | null
// Track multiple thinking blocks by their reasoning ID
thinkingBlocks: Map<string, number> // reasoningId -> blockIndex
currentThinkingId: string | null // Currently active thinking block ID
toolBlocks: Map<string, number> // toolCallId -> blockIndex
stopReason: StopReason | null
hasEmittedMessageStart: boolean
}
export type SSEEventCallback = (event: RawMessageStreamEvent) => void
export interface AiSdkToAnthropicSSEOptions {
model: string
messageId?: string
inputTokens?: number
onEvent: SSEEventCallback
}
/**
* Adapter that converts AI SDK fullStream events to Anthropic SSE events
*/
export class AiSdkToAnthropicSSE {
private state: AdapterState
private onEvent: SSEEventCallback
constructor(options: AiSdkToAnthropicSSEOptions) {
this.onEvent = options.onEvent
this.state = {
messageId: options.messageId || `msg_${Date.now()}_${Math.random().toString(36).substring(2, 11)}`,
model: options.model,
inputTokens: options.inputTokens || 0,
outputTokens: 0,
cacheInputTokens: 0,
currentBlockIndex: 0,
blocks: new Map(),
textBlockIndex: null,
thinkingBlocks: new Map(),
currentThinkingId: null,
toolBlocks: new Map(),
stopReason: null,
hasEmittedMessageStart: false
}
}
/**
* Process the AI SDK stream and emit Anthropic SSE events
*/
async processStream(fullStream: ReadableStream<TextStreamPart<ToolSet>>): Promise<void> {
const reader = fullStream.getReader()
try {
// Emit message_start at the beginning
this.emitMessageStart()
while (true) {
const { done, value } = await reader.read()
if (done) {
break
}
this.processChunk(value)
}
// Ensure all blocks are closed and emit final events
this.finalize()
} catch (error) {
await reader.cancel()
throw error
} finally {
reader.releaseLock()
}
}
/**
* Process a single AI SDK chunk and emit corresponding Anthropic events
*/
private processChunk(chunk: TextStreamPart<ToolSet>): void {
logger.silly('AiSdkToAnthropicSSE - Processing chunk:', { chunk: JSON.stringify(chunk) })
switch (chunk.type) {
// === Text Events ===
case 'text-start':
this.startTextBlock()
break
case 'text-delta':
this.emitTextDelta(chunk.text || '')
break
case 'text-end':
this.stopTextBlock()
break
// === Reasoning/Thinking Events ===
case 'reasoning-start': {
const reasoningId = chunk.id
this.startThinkingBlock(reasoningId)
break
}
case 'reasoning-delta': {
const reasoningId = chunk.id
this.emitThinkingDelta(chunk.text || '', reasoningId)
break
}
case 'reasoning-end': {
const reasoningId = chunk.id
this.stopThinkingBlock(reasoningId)
break
}
// === Tool Events ===
case 'tool-call':
if (googleReasoningCache && chunk.providerMetadata?.google?.thoughtSignature) {
googleReasoningCache.set(
`google-${chunk.toolName}`,
chunk.providerMetadata?.google?.thoughtSignature as string
)
}
// FIXME: 按toolcall id绑定
if (
openRouterReasoningCache &&
chunk.providerMetadata?.openrouter?.reasoning_details &&
Array.isArray(chunk.providerMetadata.openrouter.reasoning_details)
) {
openRouterReasoningCache.set(
'openrouter',
JSON.parse(JSON.stringify(chunk.providerMetadata.openrouter.reasoning_details))
)
}
this.handleToolCall({
type: 'tool-call',
toolCallId: chunk.toolCallId,
toolName: chunk.toolName,
args: chunk.input
})
break
case 'tool-result':
// this.handleToolResult({
// type: 'tool-result',
// toolCallId: chunk.toolCallId,
// toolName: chunk.toolName,
// args: chunk.input,
// result: chunk.output
// })
break
case 'finish-step':
if (chunk.finishReason === 'tool-calls') {
this.state.stopReason = 'tool_use'
}
break
case 'finish':
this.handleFinish(chunk)
break
case 'error':
throw chunk.error
// Ignore other event types
default:
break
}
}
private emitMessageStart(): void {
if (this.state.hasEmittedMessageStart) return
this.state.hasEmittedMessageStart = true
const usage: Usage = {
input_tokens: this.state.inputTokens,
output_tokens: 0,
cache_creation_input_tokens: 0,
cache_read_input_tokens: 0,
server_tool_use: null
}
const message: Message = {
id: this.state.messageId,
type: 'message',
role: 'assistant',
content: [],
model: this.state.model,
stop_reason: null,
stop_sequence: null,
usage
}
const event: RawMessageStartEvent = {
type: 'message_start',
message
}
this.onEvent(event)
}
private startTextBlock(): void {
// If we already have a text block, don't create another
if (this.state.textBlockIndex !== null) return
const index = this.state.currentBlockIndex++
this.state.textBlockIndex = index
this.state.blocks.set(index, {
type: 'text',
index,
started: true,
content: ''
})
const contentBlock: TextBlock = {
type: 'text',
text: '',
citations: null
}
const event: RawContentBlockStartEvent = {
type: 'content_block_start',
index,
content_block: contentBlock
}
this.onEvent(event)
}
private emitTextDelta(text: string): void {
if (!text) return
// Auto-start text block if not started
if (this.state.textBlockIndex === null) {
this.startTextBlock()
}
const index = this.state.textBlockIndex!
const block = this.state.blocks.get(index)
if (block) {
block.content += text
}
const delta: TextDelta = {
type: 'text_delta',
text
}
const event: RawContentBlockDeltaEvent = {
type: 'content_block_delta',
index,
delta
}
this.onEvent(event)
}
private stopTextBlock(): void {
if (this.state.textBlockIndex === null) return
const index = this.state.textBlockIndex
const event: RawContentBlockStopEvent = {
type: 'content_block_stop',
index
}
this.onEvent(event)
this.state.textBlockIndex = null
}
private startThinkingBlock(reasoningId: string): void {
// Check if this thinking block already exists
if (this.state.thinkingBlocks.has(reasoningId)) return
const index = this.state.currentBlockIndex++
this.state.thinkingBlocks.set(reasoningId, index)
this.state.currentThinkingId = reasoningId
this.state.blocks.set(index, {
type: 'thinking',
index,
started: true,
content: ''
})
const contentBlock: ThinkingBlock = {
type: 'thinking',
thinking: '',
signature: ''
}
const event: RawContentBlockStartEvent = {
type: 'content_block_start',
index,
content_block: contentBlock
}
this.onEvent(event)
}
private emitThinkingDelta(text: string, reasoningId?: string): void {
if (!text) return
// Determine which thinking block to use
const targetId = reasoningId || this.state.currentThinkingId
if (!targetId) {
// Auto-start thinking block if not started
const newId = `reasoning_${Date.now()}`
this.startThinkingBlock(newId)
return this.emitThinkingDelta(text, newId)
}
const index = this.state.thinkingBlocks.get(targetId)
if (index === undefined) {
// If the block doesn't exist, create it
this.startThinkingBlock(targetId)
return this.emitThinkingDelta(text, targetId)
}
const block = this.state.blocks.get(index)
if (block) {
block.content += text
}
const delta: ThinkingDelta = {
type: 'thinking_delta',
thinking: text
}
const event: RawContentBlockDeltaEvent = {
type: 'content_block_delta',
index,
delta
}
this.onEvent(event)
}
private stopThinkingBlock(reasoningId?: string): void {
const targetId = reasoningId || this.state.currentThinkingId
if (!targetId) return
const index = this.state.thinkingBlocks.get(targetId)
if (index === undefined) return
const event: RawContentBlockStopEvent = {
type: 'content_block_stop',
index
}
this.onEvent(event)
this.state.thinkingBlocks.delete(targetId)
// Update currentThinkingId if we just closed the current one
if (this.state.currentThinkingId === targetId) {
// Set to the most recent remaining thinking block, or null if none
const remaining = Array.from(this.state.thinkingBlocks.keys())
this.state.currentThinkingId = remaining.length > 0 ? remaining[remaining.length - 1] : null
}
}
private handleToolCall(chunk: { type: 'tool-call'; toolCallId: string; toolName: string; args: unknown }): void {
const { toolCallId, toolName, args } = chunk
// Check if we already have this tool call
if (this.state.toolBlocks.has(toolCallId)) {
return
}
const index = this.state.currentBlockIndex++
this.state.toolBlocks.set(toolCallId, index)
const inputJson = JSON.stringify(args)
this.state.blocks.set(index, {
type: 'tool_use',
index,
started: true,
content: inputJson,
toolId: toolCallId,
toolName,
toolInput: inputJson
})
// Emit content_block_start for tool_use
const contentBlock: ToolUseBlock = {
type: 'tool_use',
id: toolCallId,
name: toolName,
input: {}
}
const startEvent: RawContentBlockStartEvent = {
type: 'content_block_start',
index,
content_block: contentBlock
}
this.onEvent(startEvent)
// Emit the full input as a delta (Anthropic streams JSON incrementally)
const delta: InputJSONDelta = {
type: 'input_json_delta',
partial_json: inputJson
}
const deltaEvent: RawContentBlockDeltaEvent = {
type: 'content_block_delta',
index,
delta
}
this.onEvent(deltaEvent)
// Emit content_block_stop
const stopEvent: RawContentBlockStopEvent = {
type: 'content_block_stop',
index
}
this.onEvent(stopEvent)
// Mark that we have tool use
this.state.stopReason = 'tool_use'
}
private handleFinish(chunk: { type: 'finish'; finishReason?: FinishReason; totalUsage?: LanguageModelUsage }): void {
// Update usage
if (chunk.totalUsage) {
this.state.inputTokens = chunk.totalUsage.inputTokens || 0
this.state.outputTokens = chunk.totalUsage.outputTokens || 0
this.state.cacheInputTokens = chunk.totalUsage.cachedInputTokens || 0
}
// Determine finish reason
if (!this.state.stopReason) {
switch (chunk.finishReason) {
case 'stop':
this.state.stopReason = 'end_turn'
break
case 'length':
this.state.stopReason = 'max_tokens'
break
case 'tool-calls':
this.state.stopReason = 'tool_use'
break
case 'content-filter':
this.state.stopReason = 'refusal'
break
default:
this.state.stopReason = 'end_turn'
}
}
}
private finalize(): void {
// Close any open blocks
if (this.state.textBlockIndex !== null) {
this.stopTextBlock()
}
// Close all open thinking blocks
for (const reasoningId of this.state.thinkingBlocks.keys()) {
this.stopThinkingBlock(reasoningId)
}
// Emit message_delta with final stop reason and usage
const usage: MessageDeltaUsage = {
output_tokens: this.state.outputTokens,
input_tokens: this.state.inputTokens,
cache_creation_input_tokens: this.state.cacheInputTokens,
cache_read_input_tokens: null,
server_tool_use: null
}
const messageDeltaEvent: RawMessageDeltaEvent = {
type: 'message_delta',
delta: {
stop_reason: this.state.stopReason || 'end_turn',
stop_sequence: null
},
usage
}
this.onEvent(messageDeltaEvent)
// Emit message_stop
const messageStopEvent: RawMessageStopEvent = {
type: 'message_stop'
}
this.onEvent(messageStopEvent)
}
/**
* Set input token count (typically from prompt)
*/
setInputTokens(count: number): void {
this.state.inputTokens = count
}
/**
* Get the current message ID
*/
getMessageId(): string {
return this.state.messageId
}
/**
* Build a complete Message object for non-streaming responses
*/
buildNonStreamingResponse(): Message {
const content: ContentBlock[] = []
// Collect all content blocks in order
const sortedBlocks = Array.from(this.state.blocks.values()).sort((a, b) => a.index - b.index)
for (const block of sortedBlocks) {
switch (block.type) {
case 'text':
content.push({
type: 'text',
text: block.content,
citations: null
} as TextBlock)
break
case 'thinking':
content.push({
type: 'thinking',
thinking: block.content
} as ThinkingBlock)
break
case 'tool_use':
content.push({
type: 'tool_use',
id: block.toolId!,
name: block.toolName!,
input: JSON.parse(block.toolInput || '{}')
} as ToolUseBlock)
break
}
}
return {
id: this.state.messageId,
type: 'message',
role: 'assistant',
content,
model: this.state.model,
stop_reason: this.state.stopReason || 'end_turn',
stop_sequence: null,
usage: {
input_tokens: this.state.inputTokens,
output_tokens: this.state.outputTokens,
cache_creation_input_tokens: 0,
cache_read_input_tokens: 0,
server_tool_use: null
}
}
}
}
/**
* Format an Anthropic SSE event for HTTP streaming
*/
export function formatSSEEvent(event: RawMessageStreamEvent): string {
return `event: ${event.type}\ndata: ${JSON.stringify(event)}\n\n`
}
/**
* Create a done marker for SSE stream
*/
export function formatSSEDone(): string {
return 'data: [DONE]\n\n'
}
export default AiSdkToAnthropicSSE
-13
View File
@@ -1,13 +0,0 @@
/**
* Shared Adapters
*
* This module exports adapters for converting between different AI API formats.
*/
export {
AiSdkToAnthropicSSE,
type AiSdkToAnthropicSSEOptions,
formatSSEDone,
formatSSEEvent,
type SSEEventCallback
} from './AiSdkToAnthropicSSE'
-95
View File
@@ -1,95 +0,0 @@
import * as z from 'zod/v4'
enum ReasoningFormat {
Unknown = 'unknown',
OpenAIResponsesV1 = 'openai-responses-v1',
XAIResponsesV1 = 'xai-responses-v1',
AnthropicClaudeV1 = 'anthropic-claude-v1',
GoogleGeminiV1 = 'google-gemini-v1'
}
// Anthropic Claude was the first reasoning that we're
// passing back and forth
export const DEFAULT_REASONING_FORMAT = ReasoningFormat.AnthropicClaudeV1
function isDefinedOrNotNull<T>(value: T | null | undefined): value is T {
return value !== null && value !== undefined
}
export enum ReasoningDetailType {
Summary = 'reasoning.summary',
Encrypted = 'reasoning.encrypted',
Text = 'reasoning.text'
}
export const CommonReasoningDetailSchema = z
.object({
id: z.string().nullish(),
format: z.enum(ReasoningFormat).nullish(),
index: z.number().optional()
})
.loose()
export const ReasoningDetailSummarySchema = z
.object({
type: z.literal(ReasoningDetailType.Summary),
summary: z.string()
})
.extend(CommonReasoningDetailSchema.shape)
export type ReasoningDetailSummary = z.infer<typeof ReasoningDetailSummarySchema>
export const ReasoningDetailEncryptedSchema = z
.object({
type: z.literal(ReasoningDetailType.Encrypted),
data: z.string()
})
.extend(CommonReasoningDetailSchema.shape)
export type ReasoningDetailEncrypted = z.infer<typeof ReasoningDetailEncryptedSchema>
export const ReasoningDetailTextSchema = z
.object({
type: z.literal(ReasoningDetailType.Text),
text: z.string().nullish(),
signature: z.string().nullish()
})
.extend(CommonReasoningDetailSchema.shape)
export type ReasoningDetailText = z.infer<typeof ReasoningDetailTextSchema>
export const ReasoningDetailUnionSchema = z.union([
ReasoningDetailSummarySchema,
ReasoningDetailEncryptedSchema,
ReasoningDetailTextSchema
])
export type ReasoningDetailUnion = z.infer<typeof ReasoningDetailUnionSchema>
const ReasoningDetailsWithUnknownSchema = z.union([ReasoningDetailUnionSchema, z.unknown().transform(() => null)])
export const ReasoningDetailArraySchema = z
.array(ReasoningDetailsWithUnknownSchema)
.transform((d) => d.filter((d): d is ReasoningDetailUnion => !!d))
export const OutputUnionToReasoningDetailsSchema = z.union([
z
.object({
delta: z.object({
reasoning_details: z.array(ReasoningDetailsWithUnknownSchema)
})
})
.transform((data) => data.delta.reasoning_details.filter(isDefinedOrNotNull)),
z
.object({
message: z.object({
reasoning_details: z.array(ReasoningDetailsWithUnknownSchema)
})
})
.transform((data) => data.message.reasoning_details.filter(isDefinedOrNotNull)),
z
.object({
text: z.string(),
reasoning_details: z.array(ReasoningDetailsWithUnknownSchema)
})
.transform((data) => data.reasoning_details.filter(isDefinedOrNotNull))
])
+13 -330
View File
@@ -1,93 +1,17 @@
import type { MessageCreateParams } from '@anthropic-ai/sdk/resources' import type { MessageCreateParams } from '@anthropic-ai/sdk/resources'
import { loggerService } from '@logger' import { loggerService } from '@logger'
import { buildSharedMiddlewares, type SharedMiddlewareConfig } from '@shared/middleware'
import { getAiSdkProviderId } from '@shared/provider'
import type { Provider } from '@types' import type { Provider } from '@types'
import type { Request, Response } from 'express' import type { Request, Response } from 'express'
import express from 'express' import express from 'express'
import { messagesService } from '../services/messages' import { messagesService } from '../services/messages'
import { generateUnifiedMessage, streamUnifiedMessages } from '../services/unified-messages' import { getProviderById, validateModelId } from '../utils'
import { getProviderById, isModelAnthropicCompatible, validateModelId } from '../utils'
/**
* Check if a specific model on a provider should use direct Anthropic SDK
*
* A provider+model combination is considered "Anthropic-compatible" if:
* 1. It's a native Anthropic provider (type === 'anthropic'), OR
* 2. It has anthropicApiHost configured AND the specific model supports Anthropic API
* (for aggregated providers like Silicon, only certain models support Anthropic endpoint)
*
* @param provider - The provider to check
* @param modelId - The model ID to check (without provider prefix)
* @returns true if should use direct Anthropic SDK, false for unified SDK
*/
function shouldUseDirectAnthropic(provider: Provider, modelId: string): boolean {
// Native Anthropic provider - always use direct SDK
if (provider.type === 'anthropic') {
return true
}
// No anthropicApiHost configured - use unified SDK
if (!provider.anthropicApiHost?.trim()) {
return false
}
// Has anthropicApiHost - check model-level compatibility
// For aggregated providers, only specific models support Anthropic API
return isModelAnthropicCompatible(provider, modelId)
}
const logger = loggerService.withContext('ApiServerMessagesRoutes') const logger = loggerService.withContext('ApiServerMessagesRoutes')
const router = express.Router() const router = express.Router()
const providerRouter = express.Router({ mergeParams: true }) const providerRouter = express.Router({ mergeParams: true })
/**
* Estimate token count from messages
* Simple approximation: ~4 characters per token for English text
*/
interface CountTokensInput {
messages: Array<{ role: string; content: string | Array<{ type: string; text?: string }> }>
system?: string | Array<{ type: string; text?: string }>
}
function estimateTokenCount(input: CountTokensInput): number {
const { messages, system } = input
let totalChars = 0
// Count system message tokens
if (system) {
if (typeof system === 'string') {
totalChars += system.length
} else if (Array.isArray(system)) {
for (const block of system) {
if (block.type === 'text' && block.text) {
totalChars += block.text.length
}
}
}
}
// Count message tokens
for (const msg of messages) {
if (typeof msg.content === 'string') {
totalChars += msg.content.length
} else if (Array.isArray(msg.content)) {
for (const block of msg.content) {
if (block.type === 'text' && block.text) {
totalChars += block.text.length
}
}
}
// Add overhead for role
totalChars += 10
}
// Estimate tokens (~4 chars per token, with some overhead)
return Math.ceil(totalChars / 4) + messages.length * 3
}
// Helper function for basic request validation // Helper function for basic request validation
async function validateRequestBody(req: Request): Promise<{ valid: boolean; error?: any }> { async function validateRequestBody(req: Request): Promise<{ valid: boolean; error?: any }> {
const request: MessageCreateParams = req.body const request: MessageCreateParams = req.body
@@ -109,36 +33,21 @@ async function validateRequestBody(req: Request): Promise<{ valid: boolean; erro
} }
interface HandleMessageProcessingOptions { interface HandleMessageProcessingOptions {
req: Request
res: Response res: Response
provider: Provider provider: Provider
request: MessageCreateParams request: MessageCreateParams
modelId?: string modelId?: string
} }
/** async function handleMessageProcessing({
* Handle message processing using direct Anthropic SDK req,
* Used for providers with anthropicApiHost or native Anthropic providers
* This bypasses AI SDK conversion and uses native Anthropic protocol
*/
async function handleDirectAnthropicProcessing({
res, res,
provider, provider,
request, request,
modelId, modelId
extraHeaders }: HandleMessageProcessingOptions): Promise<void> {
}: HandleMessageProcessingOptions & { extraHeaders?: Record<string, string | string[]> }): Promise<void> {
const actualModelId = modelId || request.model
logger.info('Processing message via direct Anthropic SDK', {
providerId: provider.id,
providerType: provider.type,
modelId: actualModelId,
stream: !!request.stream,
anthropicApiHost: provider.anthropicApiHost
})
try { try {
// Validate request
const validation = messagesService.validateRequest(request) const validation = messagesService.validateRequest(request)
if (!validation.isValid) { if (!validation.isValid) {
res.status(400).json({ res.status(400).json({
@@ -151,126 +60,28 @@ async function handleDirectAnthropicProcessing({
return return
} }
// Process message using messagesService (native Anthropic SDK) const extraHeaders = messagesService.prepareHeaders(req.headers)
const { client, anthropicRequest } = await messagesService.processMessage({ const { client, anthropicRequest } = await messagesService.processMessage({
provider, provider,
request, request,
extraHeaders, extraHeaders,
modelId: actualModelId modelId
}) })
if (request.stream) { if (request.stream) {
// Use native Anthropic streaming
await messagesService.handleStreaming(client, anthropicRequest, { response: res }, provider) await messagesService.handleStreaming(client, anthropicRequest, { response: res }, provider)
} else {
// Use native Anthropic non-streaming
const response = await client.messages.create(anthropicRequest)
res.json(response)
}
} catch (error: any) {
logger.error('Direct Anthropic processing error', { error })
const { statusCode, errorResponse } = messagesService.transformError(error)
res.status(statusCode).json(errorResponse)
}
}
/**
* Handle message processing using unified AI SDK
* Used for non-Anthropic providers that need format conversion
* - Uses AI SDK adapters with output converted to Anthropic SSE format
*/
async function handleUnifiedProcessing({
res,
provider,
request,
modelId
}: HandleMessageProcessingOptions): Promise<void> {
const actualModelId = modelId || request.model
logger.info('Processing message via unified AI SDK', {
providerId: provider.id,
providerType: provider.type,
modelId: actualModelId,
stream: !!request.stream
})
try {
// Validate request
const validation = messagesService.validateRequest(request)
if (!validation.isValid) {
res.status(400).json({
type: 'error',
error: {
type: 'invalid_request_error',
message: validation.errors.join('; ')
}
})
return return
} }
const middlewareConfig: SharedMiddlewareConfig = { const response = await client.messages.create(anthropicRequest)
modelId: actualModelId, res.json(response)
providerId: provider.id,
aiSdkProviderId: getAiSdkProviderId(provider)
}
const middlewares = buildSharedMiddlewares(middlewareConfig)
logger.debug('Built middlewares for unified processing', {
middlewareCount: middlewares.length,
modelId: actualModelId,
providerId: provider.id
})
if (request.stream) {
await streamUnifiedMessages({
response: res,
provider,
modelId: actualModelId,
params: request,
middlewares,
onError: (error) => {
logger.error('Stream error', error as Error)
},
onComplete: () => {
logger.debug('Stream completed')
}
})
} else {
const response = await generateUnifiedMessage({
provider,
modelId: actualModelId,
params: request,
middlewares
})
res.json(response)
}
} catch (error: any) { } catch (error: any) {
logger.error('Message processing error', { error })
const { statusCode, errorResponse } = messagesService.transformError(error) const { statusCode, errorResponse } = messagesService.transformError(error)
res.status(statusCode).json(errorResponse) res.status(statusCode).json(errorResponse)
} }
} }
/**
* Handle message processing - routes to appropriate handler based on provider and model
*
* Routing logic:
* - Native Anthropic providers (type === 'anthropic'): Direct Anthropic SDK
* - Providers with anthropicApiHost AND model supports Anthropic API: Direct Anthropic SDK
* - Other providers/models: Unified AI SDK with Anthropic SSE conversion
*/
async function handleMessageProcessing({
res,
provider,
request,
modelId
}: HandleMessageProcessingOptions): Promise<void> {
const actualModelId = modelId || request.model
if (shouldUseDirectAnthropic(provider, actualModelId)) {
return handleDirectAnthropicProcessing({ res, provider, request, modelId })
}
return handleUnifiedProcessing({ res, provider, request, modelId })
}
/** /**
* @swagger * @swagger
* /v1/messages: * /v1/messages:
@@ -424,7 +235,7 @@ router.post('/', async (req: Request, res: Response) => {
const provider = modelValidation.provider! const provider = modelValidation.provider!
const modelId = modelValidation.modelId! const modelId = modelValidation.modelId!
return handleMessageProcessing({ res, provider, request, modelId }) return handleMessageProcessing({ req, res, provider, request, modelId })
} catch (error: any) { } catch (error: any) {
logger.error('Message processing error', { error }) logger.error('Message processing error', { error })
const { statusCode, errorResponse } = messagesService.transformError(error) const { statusCode, errorResponse } = messagesService.transformError(error)
@@ -582,7 +393,7 @@ providerRouter.post('/', async (req: Request, res: Response) => {
const request: MessageCreateParams = req.body const request: MessageCreateParams = req.body
return handleMessageProcessing({ res, provider, request }) return handleMessageProcessing({ req, res, provider, request })
} catch (error: any) { } catch (error: any) {
logger.error('Message processing error', { error }) logger.error('Message processing error', { error })
const { statusCode, errorResponse } = messagesService.transformError(error) const { statusCode, errorResponse } = messagesService.transformError(error)
@@ -590,132 +401,4 @@ providerRouter.post('/', async (req: Request, res: Response) => {
} }
}) })
/**
* @swagger
* /v1/messages/count_tokens:
* post:
* summary: Count tokens for messages
* description: Count tokens for Anthropic Messages API format (required by Claude Code SDK)
* tags: [Messages]
* requestBody:
* required: true
* content:
* application/json:
* schema:
* type: object
* required:
* - model
* - messages
* properties:
* model:
* type: string
* description: Model ID
* messages:
* type: array
* items:
* type: object
* system:
* type: string
* description: System message
* responses:
* 200:
* description: Token count response
* content:
* application/json:
* schema:
* type: object
* properties:
* input_tokens:
* type: integer
* 400:
* description: Bad request
*/
router.post('/count_tokens', async (req: Request, res: Response) => {
try {
const { model, messages, system } = req.body
if (!model) {
return res.status(400).json({
type: 'error',
error: {
type: 'invalid_request_error',
message: 'model parameter is required'
}
})
}
if (!messages || !Array.isArray(messages)) {
return res.status(400).json({
type: 'error',
error: {
type: 'invalid_request_error',
message: 'messages parameter is required'
}
})
}
const estimatedTokens = estimateTokenCount({ messages, system })
logger.debug('Token count estimated', {
model,
messageCount: messages.length,
estimatedTokens
})
return res.json({
input_tokens: estimatedTokens
})
} catch (error: any) {
logger.error('Token counting error', { error })
return res.status(500).json({
type: 'error',
error: {
type: 'api_error',
message: error.message || 'Internal server error'
}
})
}
})
/**
* Provider-specific count_tokens endpoint
*/
providerRouter.post('/count_tokens', async (req: Request, res: Response) => {
try {
const { model, messages, system } = req.body
if (!messages || !Array.isArray(messages)) {
return res.status(400).json({
type: 'error',
error: {
type: 'invalid_request_error',
message: 'messages parameter is required'
}
})
}
const estimatedTokens = estimateTokenCount({ messages, system })
logger.debug('Token count estimated (provider route)', {
providerId: req.params.provider,
model,
messageCount: messages.length,
estimatedTokens
})
return res.json({
input_tokens: estimatedTokens
})
} catch (error: any) {
logger.error('Token counting error', { error })
return res.status(500).json({
type: 'error',
error: {
type: 'api_error',
message: error.message || 'Internal server error'
}
})
}
})
export { providerRouter as messagesProviderRoutes, router as messagesRoutes } export { providerRouter as messagesProviderRoutes, router as messagesRoutes }
+9 -93
View File
@@ -2,10 +2,8 @@ import type Anthropic from '@anthropic-ai/sdk'
import type { MessageCreateParams, MessageStreamEvent } from '@anthropic-ai/sdk/resources' import type { MessageCreateParams, MessageStreamEvent } from '@anthropic-ai/sdk/resources'
import { loggerService } from '@logger' import { loggerService } from '@logger'
import anthropicService from '@main/services/AnthropicService' import anthropicService from '@main/services/AnthropicService'
import { buildClaudeCodeSystemMessage, getSdkClient, sanitizeToolsForAnthropic } from '@shared/anthropic' import { buildClaudeCodeSystemMessage, getSdkClient } from '@shared/anthropic'
import type { Provider } from '@types' import type { Provider } from '@types'
import { APICallError, RetryError } from 'ai'
import { net } from 'electron'
import type { Response } from 'express' import type { Response } from 'express'
const logger = loggerService.withContext('MessagesService') const logger = loggerService.withContext('MessagesService')
@@ -100,30 +98,11 @@ export class MessagesService {
async getClient(provider: Provider, extraHeaders?: Record<string, string | string[]>): Promise<Anthropic> { async getClient(provider: Provider, extraHeaders?: Record<string, string | string[]>): Promise<Anthropic> {
// Create Anthropic client for the provider // Create Anthropic client for the provider
// Wrap net.fetch to handle compatibility issues:
// 1. net.fetch expects string URLs, not Request objects
// 2. net.fetch doesn't support 'agent' option from Node.js http module
const electronFetch: typeof globalThis.fetch = async (input: URL | RequestInfo, init?: RequestInit) => {
const url = typeof input === 'string' ? input : input instanceof URL ? input.toString() : input.url
// Remove unsupported options for Electron's net.fetch
if (init) {
const initWithAgent = init as RequestInit & { agent?: unknown }
delete initWithAgent.agent
const headers = new Headers(initWithAgent.headers)
if (headers.has('content-length')) {
headers.delete('content-length')
}
initWithAgent.headers = headers
return net.fetch(url, initWithAgent)
}
return net.fetch(url)
}
const context = { fetch: electronFetch }
if (provider.authType === 'oauth') { if (provider.authType === 'oauth') {
const oauthToken = await anthropicService.getValidAccessToken() const oauthToken = await anthropicService.getValidAccessToken()
return getSdkClient(provider, oauthToken, extraHeaders, context) return getSdkClient(provider, oauthToken, extraHeaders)
} }
return getSdkClient(provider, null, extraHeaders, context) return getSdkClient(provider, null, extraHeaders)
} }
prepareHeaders(headers: Record<string, string | string[] | undefined>): Record<string, string | string[]> { prepareHeaders(headers: Record<string, string | string[] | undefined>): Record<string, string | string[]> {
@@ -148,8 +127,7 @@ export class MessagesService {
createAnthropicRequest(request: MessageCreateParams, provider: Provider, modelId?: string): MessageCreateParams { createAnthropicRequest(request: MessageCreateParams, provider: Provider, modelId?: string): MessageCreateParams {
const anthropicRequest: MessageCreateParams = { const anthropicRequest: MessageCreateParams = {
...request, ...request,
stream: !!request.stream, stream: !!request.stream
tools: sanitizeToolsForAnthropic(request.tools)
} }
// Override model if provided // Override model if provided
@@ -255,71 +233,9 @@ export class MessagesService {
} }
transformError(error: any): { statusCode: number; errorResponse: ErrorResponse } { transformError(error: any): { statusCode: number; errorResponse: ErrorResponse } {
let statusCode: number | undefined = undefined let statusCode = 500
let errorType: string | undefined = undefined let errorType = 'api_error'
let errorMessage: string | undefined = undefined let errorMessage = 'Internal server error'
const errorMap: Record<number, string> = {
400: 'invalid_request_error',
401: 'authentication_error',
403: 'forbidden_error',
404: 'not_found_error',
429: 'rate_limit_error',
500: 'internal_server_error'
}
// Handle AI SDK RetryError - extract the last error for better error messages
if (RetryError.isInstance(error)) {
const lastError = error.lastError
// If the last error is an APICallError, extract its details
if (APICallError.isInstance(lastError)) {
statusCode = lastError.statusCode || 502
errorMessage = lastError.message
return {
statusCode,
errorResponse: {
type: 'error',
error: {
type: errorMap[statusCode] || 'api_error',
message: `${error.reason}: ${errorMessage}`,
requestId: lastError.name
}
}
}
}
// Fallback for other retry errors
errorMessage = error.message
statusCode = 502
return {
statusCode,
errorResponse: {
type: 'error',
error: {
type: 'api_error',
message: errorMessage,
requestId: error.name
}
}
}
}
if (APICallError.isInstance(error)) {
statusCode = error.statusCode
errorMessage = error.message
if (statusCode) {
return {
statusCode,
errorResponse: {
type: 'error',
error: {
type: errorMap[statusCode] || 'api_error',
message: errorMessage,
requestId: error.name
}
}
}
}
}
const anthropicStatus = typeof error?.status === 'number' ? error.status : undefined const anthropicStatus = typeof error?.status === 'number' ? error.status : undefined
const anthropicError = error?.error const anthropicError = error?.error
@@ -361,11 +277,11 @@ export class MessagesService {
typeof errorMessage === 'string' && errorMessage.length > 0 ? errorMessage : 'Internal server error' typeof errorMessage === 'string' && errorMessage.length > 0 ? errorMessage : 'Internal server error'
return { return {
statusCode: statusCode ?? 500, statusCode,
errorResponse: { errorResponse: {
type: 'error', type: 'error',
error: { error: {
type: errorType || 'api_error', type: errorType,
message: safeErrorMessage, message: safeErrorMessage,
requestId: error?.request_id requestId: error?.request_id
} }
+21 -6
View File
@@ -1,6 +1,13 @@
import { isEmpty } from 'lodash'
import type { ApiModel, ApiModelsFilter, ApiModelsResponse } from '../../../renderer/src/types/apiModels' import type { ApiModel, ApiModelsFilter, ApiModelsResponse } from '../../../renderer/src/types/apiModels'
import { loggerService } from '../../services/LoggerService' import { loggerService } from '../../services/LoggerService'
import { getAvailableProviders, listAllAvailableModels, transformModelToOpenAI } from '../utils' import {
getAvailableProviders,
getProviderAnthropicModelChecker,
listAllAvailableModels,
transformModelToOpenAI
} from '../utils'
const logger = loggerService.withContext('ModelsService') const logger = loggerService.withContext('ModelsService')
@@ -13,12 +20,11 @@ export class ModelsService {
try { try {
logger.debug('Getting available models from providers', { filter }) logger.debug('Getting available models from providers', { filter })
const providers = await getAvailableProviders() let providers = await getAvailableProviders()
// Note: When providerType === 'anthropic', we now return ALL available models if (filter.providerType === 'anthropic') {
// because the API Server's unified adapter (AiSdkToAnthropicSSE) can convert providers = providers.filter((p) => p.type === 'anthropic' || !isEmpty(p.anthropicApiHost?.trim()))
// any provider's response to Anthropic SSE format. This enables Claude Code Agent }
// to work with OpenAI, Gemini, and other providers transparently.
const models = await listAllAvailableModels(providers) const models = await listAllAvailableModels(providers)
// Use Map to deduplicate models by their full ID (provider:model_id) // Use Map to deduplicate models by their full ID (provider:model_id)
@@ -26,11 +32,20 @@ export class ModelsService {
for (const model of models) { for (const model of models) {
const provider = providers.find((p) => p.id === model.provider) const provider = providers.find((p) => p.id === model.provider)
// logger.debug(`Processing model ${model.id}`)
if (!provider) { if (!provider) {
logger.debug(`Skipping model ${model.id} . Reason: Provider not found.`) logger.debug(`Skipping model ${model.id} . Reason: Provider not found.`)
continue continue
} }
if (filter.providerType === 'anthropic') {
const checker = getProviderAnthropicModelChecker(provider.id)
if (!checker(model)) {
logger.debug(`Skipping model ${model.id} from ${model.provider}. Reason: Not an Anthropic model.`)
continue
}
}
const openAIModel = transformModelToOpenAI(model, provider) const openAIModel = transformModelToOpenAI(model, provider)
const fullModelId = openAIModel.id // This is already in format "provider:model_id" const fullModelId = openAIModel.id // This is already in format "provider:model_id"
@@ -1,718 +0,0 @@
import type { AnthropicProviderOptions } from '@ai-sdk/anthropic'
import type { GoogleGenerativeAIProviderOptions } from '@ai-sdk/google'
import type { OpenAIResponsesProviderOptions } from '@ai-sdk/openai'
import type { LanguageModelV2Middleware, LanguageModelV2ToolResultOutput } from '@ai-sdk/provider'
import type { ProviderOptions, ReasoningPart, ToolCallPart, ToolResultPart } from '@ai-sdk/provider-utils'
import type {
ImageBlockParam,
MessageCreateParams,
TextBlockParam,
Tool as AnthropicTool
} from '@anthropic-ai/sdk/resources/messages'
import { type AiPlugin, createExecutor } from '@cherrystudio/ai-core'
import { createProvider as createProviderCore } from '@cherrystudio/ai-core/provider'
import { loggerService } from '@logger'
import { AiSdkToAnthropicSSE, formatSSEDone, formatSSEEvent } from '@main/apiServer/adapters'
import { generateSignature as cherryaiGenerateSignature } from '@main/integration/cherryai'
import anthropicService from '@main/services/AnthropicService'
import copilotService from '@main/services/CopilotService'
import { reduxService } from '@main/services/ReduxService'
import { isGemini3ModelId } from '@shared/middleware'
import {
type AiSdkConfig,
type AiSdkConfigContext,
formatProviderApiHost,
initializeSharedProviders,
isAnthropicProvider,
isGeminiProvider,
isOpenAIProvider,
type ProviderFormatContext,
providerToAiSdkConfig as sharedProviderToAiSdkConfig,
resolveActualProvider
} from '@shared/provider'
import { COPILOT_DEFAULT_HEADERS } from '@shared/provider/constant'
import { defaultAppHeaders } from '@shared/utils'
import type { Provider } from '@types'
import type { ImagePart, JSONValue, ModelMessage, Provider as AiSdkProvider, TextPart, Tool as AiSdkTool } from 'ai'
import { simulateStreamingMiddleware, stepCountIs, tool, wrapLanguageModel, zodSchema } from 'ai'
import { net } from 'electron'
import type { Response } from 'express'
import * as z from 'zod'
import { googleReasoningCache, openRouterReasoningCache } from '../../services/CacheService'
const logger = loggerService.withContext('UnifiedMessagesService')
const MAGIC_STRING = 'skip_thought_signature_validator'
function sanitizeJson(value: unknown): JSONValue {
return JSON.parse(JSON.stringify(value))
}
initializeSharedProviders({
warn: (message) => logger.warn(message),
error: (message, error) => logger.error(message, error)
})
/**
* Configuration for unified message streaming
*/
export interface UnifiedStreamConfig {
response: Response
provider: Provider
modelId: string
params: MessageCreateParams
onError?: (error: unknown) => void
onComplete?: () => void
/**
* Optional AI SDK middlewares to apply
*/
middlewares?: LanguageModelV2Middleware[]
/**
* Optional AI Core plugins to use with the executor
*/
plugins?: AiPlugin[]
}
/**
* Configuration for non-streaming message generation
*/
export interface GenerateUnifiedMessageConfig {
provider: Provider
modelId: string
params: MessageCreateParams
middlewares?: LanguageModelV2Middleware[]
plugins?: AiPlugin[]
}
function getMainProcessFormatContext(): ProviderFormatContext {
const vertexSettings = reduxService.selectSync<{ projectId: string; location: string }>('state.llm.settings.vertexai')
return {
vertex: {
project: vertexSettings?.projectId || 'default-project',
location: vertexSettings?.location || 'us-central1'
}
}
}
const mainProcessSdkContext: AiSdkConfigContext = {
getRotatedApiKey: (provider) => {
const keys = provider.apiKey.split(',').map((k) => k.trim())
return keys[0] || provider.apiKey
},
fetch: net.fetch as typeof globalThis.fetch
}
function getActualProvider(provider: Provider, modelId: string): Provider {
const model = provider.models?.find((m) => m.id === modelId)
if (!model) return provider
return resolveActualProvider(provider, model)
}
function providerToAiSdkConfig(provider: Provider, modelId: string): AiSdkConfig {
const actualProvider = getActualProvider(provider, modelId)
const formattedProvider = formatProviderApiHost(actualProvider, getMainProcessFormatContext())
return sharedProviderToAiSdkConfig(formattedProvider, modelId, mainProcessSdkContext)
}
function convertAnthropicToolResultToAiSdk(
content: string | Array<TextBlockParam | ImageBlockParam>
): LanguageModelV2ToolResultOutput {
if (typeof content === 'string') {
return { type: 'text', value: content }
}
const values: Array<{ type: 'text'; text: string } | { type: 'media'; data: string; mediaType: string }> = []
for (const block of content) {
if (block.type === 'text') {
values.push({ type: 'text', text: block.text })
} else if (block.type === 'image') {
values.push({
type: 'media',
data: block.source.type === 'base64' ? block.source.data : block.source.url,
mediaType: block.source.type === 'base64' ? block.source.media_type : 'image/png'
})
}
}
return { type: 'content', value: values }
}
// Type alias for JSON Schema (compatible with recursive calls)
type JsonSchemaLike = AnthropicTool.InputSchema | Record<string, unknown>
/**
* Convert JSON Schema to Zod schema
* This avoids non-standard fields like input_examples that Anthropic doesn't support
*/
function jsonSchemaToZod(schema: JsonSchemaLike): z.ZodTypeAny {
const s = schema as Record<string, unknown>
const schemaType = s.type as string | string[] | undefined
const enumValues = s.enum as unknown[] | undefined
const description = s.description as string | undefined
// Handle enum first
if (enumValues && Array.isArray(enumValues) && enumValues.length > 0) {
if (enumValues.every((v) => typeof v === 'string')) {
const zodEnum = z.enum(enumValues as [string, ...string[]])
return description ? zodEnum.describe(description) : zodEnum
}
// For non-string enums, use union of literals
const literals = enumValues.map((v) => z.literal(v as string | number | boolean))
if (literals.length === 1) {
return description ? literals[0].describe(description) : literals[0]
}
const zodUnion = z.union(literals as unknown as [z.ZodTypeAny, z.ZodTypeAny, ...z.ZodTypeAny[]])
return description ? zodUnion.describe(description) : zodUnion
}
// Handle union types (type: ["string", "null"])
if (Array.isArray(schemaType)) {
const schemas = schemaType.map((t) => jsonSchemaToZod({ ...s, type: t, enum: undefined }))
if (schemas.length === 1) {
return schemas[0]
}
return z.union(schemas as [z.ZodTypeAny, z.ZodTypeAny, ...z.ZodTypeAny[]])
}
// Handle by type
switch (schemaType) {
case 'string': {
let zodString = z.string()
if (typeof s.minLength === 'number') zodString = zodString.min(s.minLength)
if (typeof s.maxLength === 'number') zodString = zodString.max(s.maxLength)
if (typeof s.pattern === 'string') zodString = zodString.regex(new RegExp(s.pattern))
return description ? zodString.describe(description) : zodString
}
case 'number':
case 'integer': {
let zodNumber = schemaType === 'integer' ? z.number().int() : z.number()
if (typeof s.minimum === 'number') zodNumber = zodNumber.min(s.minimum)
if (typeof s.maximum === 'number') zodNumber = zodNumber.max(s.maximum)
return description ? zodNumber.describe(description) : zodNumber
}
case 'boolean': {
const zodBoolean = z.boolean()
return description ? zodBoolean.describe(description) : zodBoolean
}
case 'null':
return z.null()
case 'array': {
const items = s.items as Record<string, unknown> | undefined
let zodArray = items ? z.array(jsonSchemaToZod(items)) : z.array(z.unknown())
if (typeof s.minItems === 'number') zodArray = zodArray.min(s.minItems)
if (typeof s.maxItems === 'number') zodArray = zodArray.max(s.maxItems)
return description ? zodArray.describe(description) : zodArray
}
case 'object': {
const properties = s.properties as Record<string, Record<string, unknown>> | undefined
const required = (s.required as string[]) || []
// Always use z.object() to ensure "properties" field is present in output schema
// OpenAI requires explicit properties field even for empty objects
const shape: Record<string, z.ZodTypeAny> = {}
if (properties) {
for (const [key, propSchema] of Object.entries(properties)) {
const zodProp = jsonSchemaToZod(propSchema)
shape[key] = required.includes(key) ? zodProp : zodProp.optional()
}
}
const zodObject = z.object(shape)
return description ? zodObject.describe(description) : zodObject
}
default:
// Unknown type, use z.unknown()
return z.unknown()
}
}
function convertAnthropicToolsToAiSdk(tools: MessageCreateParams['tools']): Record<string, AiSdkTool> | undefined {
if (!tools || tools.length === 0) return undefined
const aiSdkTools: Record<string, AiSdkTool> = {}
for (const anthropicTool of tools) {
if (anthropicTool.type === 'bash_20250124') continue
const toolDef = anthropicTool as AnthropicTool
const rawSchema = toolDef.input_schema
const schema = jsonSchemaToZod(rawSchema)
// Use tool() with inputSchema (AI SDK v5 API)
const aiTool = tool({
description: toolDef.description || '',
inputSchema: zodSchema(schema)
})
aiSdkTools[toolDef.name] = aiTool
}
return Object.keys(aiSdkTools).length > 0 ? aiSdkTools : undefined
}
function convertAnthropicToAiMessages(params: MessageCreateParams): ModelMessage[] {
const messages: ModelMessage[] = []
// System message
if (params.system) {
if (typeof params.system === 'string') {
messages.push({ role: 'system', content: params.system })
} else if (Array.isArray(params.system)) {
const systemText = params.system
.filter((block) => block.type === 'text')
.map((block) => block.text)
.join('\n')
if (systemText) {
messages.push({ role: 'system', content: systemText })
}
}
}
const toolCallIdToName = new Map<string, string>()
for (const msg of params.messages) {
if (Array.isArray(msg.content)) {
for (const block of msg.content) {
if (block.type === 'tool_use') {
toolCallIdToName.set(block.id, block.name)
}
}
}
}
// User/assistant messages
for (const msg of params.messages) {
if (typeof msg.content === 'string') {
messages.push({
role: msg.role === 'user' ? 'user' : 'assistant',
content: msg.content
})
} else if (Array.isArray(msg.content)) {
const textParts: TextPart[] = []
const imageParts: ImagePart[] = []
const reasoningParts: ReasoningPart[] = []
const toolCallParts: ToolCallPart[] = []
const toolResultParts: ToolResultPart[] = []
for (const block of msg.content) {
if (block.type === 'text') {
textParts.push({ type: 'text', text: block.text })
} else if (block.type === 'thinking') {
reasoningParts.push({ type: 'reasoning', text: block.thinking })
} else if (block.type === 'redacted_thinking') {
reasoningParts.push({ type: 'reasoning', text: block.data })
} else if (block.type === 'image') {
const source = block.source
if (source.type === 'base64') {
imageParts.push({ type: 'image', image: `data:${source.media_type};base64,${source.data}` })
} else if (source.type === 'url') {
imageParts.push({ type: 'image', image: source.url })
}
} else if (block.type === 'tool_use') {
const options: ProviderOptions = {}
if (isGemini3ModelId(params.model)) {
if (googleReasoningCache.get(`google-${block.name}`)) {
options.google = {
thoughtSignature: MAGIC_STRING
}
} else if (openRouterReasoningCache.get('openrouter')) {
options.openrouter = {
reasoning_details: (sanitizeJson(openRouterReasoningCache.get('openrouter')) as JSONValue[]) || []
}
}
}
toolCallParts.push({
type: 'tool-call',
toolName: block.name,
toolCallId: block.id,
input: block.input,
providerOptions: options
})
} else if (block.type === 'tool_result') {
// Look up toolName from the pre-built map (covers cross-message references)
const toolName = toolCallIdToName.get(block.tool_use_id) || 'unknown'
toolResultParts.push({
type: 'tool-result',
toolCallId: block.tool_use_id,
toolName,
output: block.content ? convertAnthropicToolResultToAiSdk(block.content) : { type: 'text', value: '' }
})
}
}
if (toolResultParts.length > 0) {
messages.push({ role: 'tool', content: [...toolResultParts] })
}
if (msg.role === 'user') {
const userContent = [...textParts, ...imageParts]
if (userContent.length > 0) {
messages.push({ role: 'user', content: userContent })
}
} else {
const assistantContent = [...reasoningParts, ...textParts, ...toolCallParts]
if (assistantContent.length > 0) {
let providerOptions: ProviderOptions | undefined = undefined
if (openRouterReasoningCache.get('openrouter')) {
providerOptions = {
openrouter: {
reasoning_details: (sanitizeJson(openRouterReasoningCache.get('openrouter')) as JSONValue[]) || []
}
}
} else if (isGemini3ModelId(params.model)) {
providerOptions = {
google: {
thoughtSignature: MAGIC_STRING
}
}
}
messages.push({ role: 'assistant', content: assistantContent, providerOptions })
}
}
}
}
return messages
}
interface ExecuteStreamConfig {
provider: Provider
modelId: string
params: MessageCreateParams
middlewares?: LanguageModelV2Middleware[]
plugins?: AiPlugin[]
onEvent?: (event: Parameters<typeof formatSSEEvent>[0]) => void
}
/**
* Create AI SDK provider instance from config
* Similar to renderer's createAiSdkProvider
*/
async function createAiSdkProvider(config: AiSdkConfig): Promise<AiSdkProvider> {
let providerId = config.providerId
// Handle special provider modes (same as renderer)
if (providerId === 'openai' && config.options?.mode === 'chat') {
providerId = 'openai-chat'
} else if (providerId === 'azure' && config.options?.mode === 'responses') {
providerId = 'azure-responses'
} else if (providerId === 'cherryin' && config.options?.mode === 'chat') {
providerId = 'cherryin-chat'
}
const provider = await createProviderCore(providerId, config.options)
return provider
}
/**
* Prepare special provider configuration for providers that need dynamic tokens
* Similar to renderer's prepareSpecialProviderConfig
*/
async function prepareSpecialProviderConfig(provider: Provider, config: AiSdkConfig): Promise<AiSdkConfig> {
switch (provider.id) {
case 'copilot': {
const storedHeaders =
((await reduxService.select('state.copilot.defaultHeaders')) as Record<string, string> | null) ?? {}
const headers: Record<string, string> = {
...COPILOT_DEFAULT_HEADERS,
...storedHeaders
}
try {
const { token } = await copilotService.getToken(null as any, headers)
config.options.apiKey = token
const existingHeaders = (config.options.headers as Record<string, string> | undefined) ?? {}
config.options.headers = {
...headers,
...existingHeaders
}
} catch (error) {
logger.error('Failed to get Copilot token', error as Error)
throw new Error('Failed to get Copilot token. Please re-authorize Copilot.')
}
break
}
case 'anthropic': {
if (provider.authType === 'oauth') {
try {
const oauthToken = await anthropicService.getValidAccessToken()
if (!oauthToken) {
throw new Error('Anthropic OAuth token not available. Please re-authorize.')
}
config.options = {
...config.options,
headers: {
...(config.options.headers ? config.options.headers : {}),
'Content-Type': 'application/json',
'anthropic-version': '2023-06-01',
'anthropic-beta': 'oauth-2025-04-20',
Authorization: `Bearer ${oauthToken}`
},
baseURL: 'https://api.anthropic.com/v1',
apiKey: ''
}
} catch (error) {
logger.error('Failed to get Anthropic OAuth token', error as Error)
throw new Error('Failed to get Anthropic OAuth token. Please re-authorize.')
}
}
break
}
case 'cherryai': {
// Create a signed fetch wrapper for cherryai
const baseFetch = net.fetch as typeof globalThis.fetch
config.options.fetch = async (url: RequestInfo | URL, options?: RequestInit) => {
if (!options?.body) {
return baseFetch(url, options)
}
const signature = cherryaiGenerateSignature({
method: 'POST',
path: '/chat/completions',
query: '',
body: JSON.parse(options.body as string)
})
return baseFetch(url, {
...options,
headers: {
...(options.headers as Record<string, string>),
...signature
}
})
}
break
}
}
return config
}
function mapAnthropicThinkToAISdkProviderOptions(
provider: Provider,
config: MessageCreateParams['thinking']
): ProviderOptions | undefined {
if (!config) return undefined
if (isAnthropicProvider(provider)) {
return {
anthropic: {
...mapToAnthropicProviderOptions(config)
}
}
}
if (isGeminiProvider(provider)) {
return {
google: {
...mapToGeminiProviderOptions(config)
}
}
}
if (isOpenAIProvider(provider)) {
return {
openai: {
...mapToOpenAIProviderOptions(config)
}
}
}
return undefined
}
function mapToAnthropicProviderOptions(config: NonNullable<MessageCreateParams['thinking']>): AnthropicProviderOptions {
return {
thinking: {
type: config.type,
budgetTokens: config.type === 'enabled' ? config.budget_tokens : undefined
}
}
}
function mapToGeminiProviderOptions(
config: NonNullable<MessageCreateParams['thinking']>
): GoogleGenerativeAIProviderOptions {
return {
thinkingConfig: {
thinkingBudget: config.type === 'enabled' ? config.budget_tokens : -1,
includeThoughts: config.type === 'enabled'
}
}
}
function mapToOpenAIProviderOptions(
config: NonNullable<MessageCreateParams['thinking']>
): OpenAIResponsesProviderOptions {
return {
reasoningEffort: config.type === 'enabled' ? 'high' : 'none'
}
}
/**
* Core stream execution function - single source of truth for AI SDK calls
*/
async function executeStream(config: ExecuteStreamConfig): Promise<AiSdkToAnthropicSSE> {
const { provider, modelId, params, middlewares = [], plugins = [], onEvent } = config
// Convert provider config to AI SDK config
let sdkConfig = providerToAiSdkConfig(provider, modelId)
// Prepare special provider config (Copilot, Anthropic OAuth, etc.)
sdkConfig = await prepareSpecialProviderConfig(provider, sdkConfig)
// Create provider instance and get language model
const aiSdkProvider = await createAiSdkProvider(sdkConfig)
const baseModel = aiSdkProvider.languageModel(modelId)
// Apply middlewares if present
const model =
middlewares.length > 0 && typeof baseModel === 'object'
? (wrapLanguageModel({ model: baseModel, middleware: middlewares }) as typeof baseModel)
: baseModel
// Create executor with plugins
const executor = createExecutor(sdkConfig.providerId, sdkConfig.options, plugins)
// Convert messages and tools
const coreMessages = convertAnthropicToAiMessages(params)
const tools = convertAnthropicToolsToAiSdk(params.tools)
// Create the adapter
const adapter = new AiSdkToAnthropicSSE({
model: `${provider.id}:${modelId}`,
onEvent: onEvent || (() => {})
})
// Execute stream - pass model object instead of string
const result = await executor.streamText({
model, // Now passing LanguageModel object, not string
messages: coreMessages,
// FIXME: Claude Code传入的maxToken会超出有些模型限制,需做特殊处理,可能在v2好修复一点,现在维护的成本有点高
// 已知: 豆包
maxOutputTokens: params.max_tokens,
temperature: params.temperature,
topP: params.top_p,
topK: params.top_k,
stopSequences: params.stop_sequences,
stopWhen: stepCountIs(100),
headers: defaultAppHeaders(),
tools,
providerOptions: mapAnthropicThinkToAISdkProviderOptions(provider, params.thinking)
})
// Process the stream through the adapter
await adapter.processStream(result.fullStream)
return adapter
}
/**
* Stream a message request using AI SDK executor and convert to Anthropic SSE format
*/
export async function streamUnifiedMessages(config: UnifiedStreamConfig): Promise<void> {
const { response, provider, modelId, params, onError, onComplete, middlewares = [], plugins = [] } = config
logger.info('Starting unified message stream', {
providerId: provider.id,
providerType: provider.type,
modelId,
stream: params.stream,
middlewareCount: middlewares.length,
pluginCount: plugins.length
})
try {
response.setHeader('Content-Type', 'text/event-stream')
response.setHeader('Cache-Control', 'no-cache')
response.setHeader('Connection', 'keep-alive')
response.setHeader('X-Accel-Buffering', 'no')
await executeStream({
provider,
modelId,
params,
middlewares,
plugins,
onEvent: (event) => {
logger.silly('Streaming event', { eventType: event.type })
const sseData = formatSSEEvent(event)
response.write(sseData)
}
})
// Send done marker
response.write(formatSSEDone())
response.end()
logger.info('Unified message stream completed', { providerId: provider.id, modelId })
onComplete?.()
} catch (error) {
logger.error('Error in unified message stream', error as Error, { providerId: provider.id, modelId })
onError?.(error)
throw error
}
}
/**
* Generate a non-streaming message response
*
* Uses simulateStreamingMiddleware to reuse the same streaming logic,
* similar to renderer's ModernAiProvider pattern.
*/
export async function generateUnifiedMessage(
providerOrConfig: Provider | GenerateUnifiedMessageConfig,
modelId?: string,
params?: MessageCreateParams
): Promise<ReturnType<typeof AiSdkToAnthropicSSE.prototype.buildNonStreamingResponse>> {
// Support both old signature and new config-based signature
let config: GenerateUnifiedMessageConfig
if ('provider' in providerOrConfig && 'modelId' in providerOrConfig && 'params' in providerOrConfig) {
config = providerOrConfig
} else {
config = {
provider: providerOrConfig as Provider,
modelId: modelId!,
params: params!
}
}
const { provider, middlewares = [], plugins = [] } = config
logger.info('Starting unified message generation', {
providerId: provider.id,
providerType: provider.type,
modelId: config.modelId,
middlewareCount: middlewares.length,
pluginCount: plugins.length
})
try {
// Add simulateStreamingMiddleware to reuse streaming logic for non-streaming
const allMiddlewares = [simulateStreamingMiddleware(), ...middlewares]
const adapter = await executeStream({
provider,
modelId: config.modelId,
params: config.params,
middlewares: allMiddlewares,
plugins
})
const finalResponse = adapter.buildNonStreamingResponse()
logger.info('Unified message generation completed', {
providerId: provider.id,
modelId: config.modelId
})
return finalResponse
} catch (error) {
logger.error('Error in unified message generation', error as Error, {
providerId: provider.id,
modelId: config.modelId
})
throw error
}
}
export default {
streamUnifiedMessages,
generateUnifiedMessage
}
+14 -38
View File
@@ -1,7 +1,7 @@
import { CacheService } from '@main/services/CacheService' import { CacheService } from '@main/services/CacheService'
import { loggerService } from '@main/services/LoggerService' import { loggerService } from '@main/services/LoggerService'
import { reduxService } from '@main/services/ReduxService' import { reduxService } from '@main/services/ReduxService'
import { isPpioAnthropicCompatibleModel, isSiliconAnthropicCompatibleModel } from '@shared/config/providers' import { isSiliconAnthropicCompatibleModel } from '@shared/config/providers'
import type { ApiModel, Model, Provider } from '@types' import type { ApiModel, Model, Provider } from '@types'
const logger = loggerService.withContext('ApiServerUtils') const logger = loggerService.withContext('ApiServerUtils')
@@ -28,9 +28,10 @@ export async function getAvailableProviders(): Promise<Provider[]> {
return [] return []
} }
// Support all provider types that AI SDK can handle // Support OpenAI and Anthropic type providers for API server
// The unified-messages service uses AI SDK which supports many providers const supportedProviders = providers.filter(
const supportedProviders = providers.filter((p: Provider) => p.enabled) (p: Provider) => p.enabled && (p.type === 'openai' || p.type === 'anthropic')
)
// Cache the filtered results // Cache the filtered results
CacheService.set(PROVIDERS_CACHE_KEY, supportedProviders, PROVIDERS_CACHE_TTL) CacheService.set(PROVIDERS_CACHE_KEY, supportedProviders, PROVIDERS_CACHE_TTL)
@@ -159,7 +160,7 @@ export async function validateModelId(model: string): Promise<{
valid: false, valid: false,
error: { error: {
type: 'provider_not_found', type: 'provider_not_found',
message: `Provider '${providerId}' not found or not enabled.`, message: `Provider '${providerId}' not found, not enabled, or not supported. Only OpenAI providers are currently supported.`,
code: 'provider_not_found' code: 'provider_not_found'
} }
} }
@@ -261,8 +262,14 @@ export function validateProvider(provider: Provider): boolean {
return false return false
} }
// AI SDK supports many provider types, no longer need to filter by type // Support OpenAI and Anthropic type providers
// The unified-messages service handles all supported types if (provider.type !== 'openai' && provider.type !== 'anthropic') {
logger.debug('Provider type not supported', {
providerId: provider.id,
providerType: provider.type
})
return false
}
return true return true
} catch (error: any) { } catch (error: any) {
@@ -283,39 +290,8 @@ export const getProviderAnthropicModelChecker = (providerId: string): ((m: Model
return (m: Model) => m.id.includes('claude') return (m: Model) => m.id.includes('claude')
case 'silicon': case 'silicon':
return (m: Model) => isSiliconAnthropicCompatibleModel(m.id) return (m: Model) => isSiliconAnthropicCompatibleModel(m.id)
case 'ppio':
return (m: Model) => isPpioAnthropicCompatibleModel(m.id)
default: default:
// allow all models when checker not configured // allow all models when checker not configured
return () => true return () => true
} }
} }
/**
* Check if a specific model is compatible with Anthropic API for a given provider.
*
* This is used for fine-grained routing decisions at the model level.
* For aggregated providers (like Silicon), only certain models support the Anthropic API endpoint.
*
* @param provider - The provider to check
* @param modelId - The model ID to check (without provider prefix)
* @returns true if the model supports Anthropic API endpoint
*/
export function isModelAnthropicCompatible(provider: Provider, modelId: string): boolean {
const checker = getProviderAnthropicModelChecker(provider.id)
const model = provider.models?.find((m) => m.id === modelId)
if (model) {
return checker(model)
}
const minimalModel: Model = {
id: modelId,
name: modelId,
provider: provider.id,
group: ''
}
return checker(minimalModel)
}
-21
View File
@@ -1,19 +1,9 @@
import type { ReasoningDetailUnion } from '@main/apiServer/adapters/openrouter'
interface CacheItem<T> { interface CacheItem<T> {
data: T data: T
timestamp: number timestamp: number
duration: number duration: number
} }
/**
* Interface for reasoning cache
*/
export interface IReasoningCache<T> {
set(key: string, value: T): void
get(key: string): T | undefined
}
export class CacheService { export class CacheService {
private static cache: Map<string, CacheItem<any>> = new Map() private static cache: Map<string, CacheItem<any>> = new Map()
@@ -82,14 +72,3 @@ export class CacheService {
return true return true
} }
} }
// Singleton cache instances using CacheService
export const googleReasoningCache: IReasoningCache<string> = {
set: (key, value) => CacheService.set(`google-reasoning:${key}`, value, 30 * 60 * 1000),
get: (key) => CacheService.get(`google-reasoning:${key}`) || undefined
}
export const openRouterReasoningCache: IReasoningCache<ReasoningDetailUnion[]> = {
set: (key, value) => CacheService.set(`openrouter-reasoning:${key}`, value, 30 * 60 * 1000),
get: (key) => CacheService.get(`openrouter-reasoning:${key}`) || undefined
}
@@ -87,7 +87,6 @@ export class ClaudeStreamState {
private pendingUsage: PendingUsageState = {} private pendingUsage: PendingUsageState = {}
private pendingToolCalls = new Map<string, PendingToolCall>() private pendingToolCalls = new Map<string, PendingToolCall>()
private stepActive = false private stepActive = false
private _streamFinished = false
constructor(options: ClaudeStreamStateOptions) { constructor(options: ClaudeStreamStateOptions) {
this.logger = loggerService.withContext('ClaudeStreamState') this.logger = loggerService.withContext('ClaudeStreamState')
@@ -290,16 +289,6 @@ export class ClaudeStreamState {
getNamespacedToolCallId(rawToolCallId: string): string { getNamespacedToolCallId(rawToolCallId: string): string {
return buildNamespacedToolCallId(this.agentSessionId, rawToolCallId) return buildNamespacedToolCallId(this.agentSessionId, rawToolCallId)
} }
/** Marks the stream as finished (either completed or errored). */
markFinished(): void {
this._streamFinished = true
}
/** Returns true if the stream has already emitted a terminal event. */
isFinished(): boolean {
return this._streamFinished
}
} }
export type { PendingToolCall } export type { PendingToolCall }
@@ -85,14 +85,18 @@ class ClaudeCodeService implements AgentServiceInterface {
}) })
return aiStream return aiStream
} }
// Validate provider has required configuration if (
// Note: We no longer restrict to anthropic type only - the API Server's unified adapter (modelInfo.provider?.type !== 'anthropic' &&
// handles format conversion for any provider type (OpenAI, Gemini, etc.) (modelInfo.provider?.anthropicApiHost === undefined || modelInfo.provider.anthropicApiHost.trim() === '')) ||
if (!modelInfo.provider?.apiKey) { modelInfo.provider.apiKey === ''
logger.error('Provider API key is missing', { modelInfo }) ) {
logger.error('Anthropic provider configuration is missing', {
modelInfo
})
aiStream.emit('data', { aiStream.emit('data', {
type: 'error', type: 'error',
error: new Error(`Provider '${modelInfo.provider?.id}' is missing API key configuration.`) error: new Error(`Invalid provider type '${modelInfo.provider?.type}'. Expected 'anthropic' provider type.`)
}) })
return aiStream return aiStream
} }
@@ -103,14 +107,15 @@ class ClaudeCodeService implements AgentServiceInterface {
Object.entries(loginShellEnv).filter(([key]) => !key.toLowerCase().endsWith('_proxy')) Object.entries(loginShellEnv).filter(([key]) => !key.toLowerCase().endsWith('_proxy'))
) as Record<string, string> ) as Record<string, string>
// Route through local API Server which handles format conversion via unified adapter
// This enables Claude Code Agent to work with any provider (OpenAI, Gemini, etc.)
// The API Server converts AI SDK responses to Anthropic SSE format transparently
const env = { const env = {
...loginShellEnvWithoutProxies, ...loginShellEnvWithoutProxies,
ANTHROPIC_API_KEY: apiConfig.apiKey, // TODO: fix the proxy api server
ANTHROPIC_AUTH_TOKEN: apiConfig.apiKey, // ANTHROPIC_API_KEY: apiConfig.apiKey,
ANTHROPIC_BASE_URL: `http://${apiConfig.host}:${apiConfig.port}/${modelInfo.provider.id}`, // ANTHROPIC_AUTH_TOKEN: apiConfig.apiKey,
// ANTHROPIC_BASE_URL: `http://${apiConfig.host}:${apiConfig.port}/${modelInfo.provider.id}`,
ANTHROPIC_API_KEY: modelInfo.provider.apiKey,
ANTHROPIC_AUTH_TOKEN: modelInfo.provider.apiKey,
ANTHROPIC_BASE_URL: modelInfo.provider.anthropicApiHost?.trim() || modelInfo.provider.apiHost,
ANTHROPIC_MODEL: modelInfo.modelId, ANTHROPIC_MODEL: modelInfo.modelId,
ANTHROPIC_DEFAULT_OPUS_MODEL: modelInfo.modelId, ANTHROPIC_DEFAULT_OPUS_MODEL: modelInfo.modelId,
ANTHROPIC_DEFAULT_SONNET_MODEL: modelInfo.modelId, ANTHROPIC_DEFAULT_SONNET_MODEL: modelInfo.modelId,
@@ -534,19 +539,6 @@ class ClaudeCodeService implements AgentServiceInterface {
return return
} }
// Skip emitting error if stream already finished (error was handled via result message)
if (streamState.isFinished()) {
logger.debug('SDK process exited after stream finished, skipping duplicate error event', {
duration,
error: errorObj instanceof Error ? { name: errorObj.name, message: errorObj.message } : String(errorObj)
})
// Still emit complete to signal stream end
stream.emit('data', {
type: 'complete'
})
return
}
errorChunks.push(errorObj instanceof Error ? errorObj.message : String(errorObj)) errorChunks.push(errorObj instanceof Error ? errorObj.message : String(errorObj))
const errorMessage = errorChunks.join('\n\n') const errorMessage = errorChunks.join('\n\n')
logger.error('SDK query failed', { logger.error('SDK query failed', {
@@ -121,7 +121,7 @@ export function transformSDKMessageToStreamParts(sdkMessage: SDKMessage, state:
case 'system': case 'system':
return handleSystemMessage(sdkMessage) return handleSystemMessage(sdkMessage)
case 'result': case 'result':
return handleResultMessage(sdkMessage, state) return handleResultMessage(sdkMessage)
default: default:
logger.warn('Unknown SDKMessage type', { type: (sdkMessage as any).type }) logger.warn('Unknown SDKMessage type', { type: (sdkMessage as any).type })
return [] return []
@@ -193,30 +193,6 @@ function handleAssistantMessage(
} }
break break
} }
case 'thinking':
case 'redacted_thinking': {
const thinkingText = block.type === 'thinking' ? block.thinking : block.data
if (thinkingText) {
const id = generateMessageId()
chunks.push({
type: 'reasoning-start',
id,
providerMetadata
})
chunks.push({
type: 'reasoning-delta',
id,
text: thinkingText,
providerMetadata
})
chunks.push({
type: 'reasoning-end',
id,
providerMetadata
})
}
break
}
case 'tool_use': case 'tool_use':
handleAssistantToolUse(block as ToolUseContent, providerMetadata, state, chunks) handleAssistantToolUse(block as ToolUseContent, providerMetadata, state, chunks)
break break
@@ -469,11 +445,7 @@ function handleStreamEvent(
case 'content_block_stop': { case 'content_block_stop': {
const block = state.closeBlock(event.index) const block = state.closeBlock(event.index)
if (!block) { if (!block) {
// Some providers (e.g., Gemini) send content via assistant message before stream events, logger.warn('Received content_block_stop for unknown index', { index: event.index })
// so the block may not exist in state. This is expected behavior, not an error.
logger.debug('Received content_block_stop for unknown index (may be from non-streaming content)', {
index: event.index
})
break break
} }
@@ -707,13 +679,7 @@ function handleSystemMessage(message: Extract<SDKMessage, { type: 'system' }>):
* Successful runs yield a `finish` frame with aggregated usage metrics, while * Successful runs yield a `finish` frame with aggregated usage metrics, while
* failures are surfaced as `error` frames. * failures are surfaced as `error` frames.
*/ */
function handleResultMessage( function handleResultMessage(message: Extract<SDKMessage, { type: 'result' }>): AgentStreamPart[] {
message: Extract<SDKMessage, { type: 'result' }>,
state: ClaudeStreamState
): AgentStreamPart[] {
// Mark stream as finished to prevent duplicate error events when SDK process exits
state.markFinished()
const chunks: AgentStreamPart[] = [] const chunks: AgentStreamPart[] = []
let usage: LanguageModelUsage | undefined let usage: LanguageModelUsage | undefined
@@ -725,33 +691,26 @@ function handleResultMessage(
} }
} }
chunks.push({ if (message.subtype === 'success') {
type: 'finish', chunks.push({
totalUsage: usage ?? emptyUsage, type: 'finish',
finishReason: mapClaudeCodeFinishReason(message.subtype), totalUsage: usage ?? emptyUsage,
providerMetadata: { finishReason: mapClaudeCodeFinishReason(message.subtype),
...sdkMessageToProviderMetadata(message), providerMetadata: {
usage: message.usage, ...sdkMessageToProviderMetadata(message),
durationMs: message.duration_ms, usage: message.usage,
costUsd: message.total_cost_usd, durationMs: message.duration_ms,
raw: message costUsd: message.total_cost_usd,
} raw: message
} as AgentStreamPart) }
if (message.subtype !== 'success') { } as AgentStreamPart)
} else {
chunks.push({ chunks.push({
type: 'error', type: 'error',
error: { error: {
message: `${message.subtype}: Process failed after ${message.num_turns} turns` message: `${message.subtype}: Process failed after ${message.num_turns} turns`
} }
} as AgentStreamPart) } as AgentStreamPart)
} else {
if (message.is_error) {
const errorMatch = message.result.match(/\{.*\}/)
if (errorMatch) {
const errorDetail = JSON.parse(errorMatch[0])
chunks.push(errorDetail)
}
}
} }
return chunks return chunks
} }
@@ -24,7 +24,7 @@ export class VertexAPIClient extends GeminiAPIClient {
this.anthropicVertexClient = new AnthropicVertexClient(provider) this.anthropicVertexClient = new AnthropicVertexClient(provider)
// 如果传入的是普通 Provider,转换为 VertexProvider // 如果传入的是普通 Provider,转换为 VertexProvider
if (isVertexProvider(provider)) { if (isVertexProvider(provider)) {
this.vertexProvider = provider as VertexProvider this.vertexProvider = provider
} else { } else {
this.vertexProvider = createVertexProvider(provider) this.vertexProvider = createVertexProvider(provider)
} }
@@ -5,7 +5,6 @@ import type { MCPTool } from '@renderer/types'
import { type Assistant, type Message, type Model, type Provider, SystemProviderIds } from '@renderer/types' import { type Assistant, type Message, type Model, type Provider, SystemProviderIds } from '@renderer/types'
import type { Chunk } from '@renderer/types/chunk' import type { Chunk } from '@renderer/types/chunk'
import { isSupportEnableThinkingProvider } from '@renderer/utils/provider' import { isSupportEnableThinkingProvider } from '@renderer/utils/provider'
import { openrouterReasoningMiddleware, skipGeminiThoughtSignatureMiddleware } from '@shared/middleware'
import type { LanguageModelMiddleware } from 'ai' import type { LanguageModelMiddleware } from 'ai'
import { extractReasoningMiddleware, simulateStreamingMiddleware } from 'ai' import { extractReasoningMiddleware, simulateStreamingMiddleware } from 'ai'
import { isEmpty } from 'lodash' import { isEmpty } from 'lodash'
@@ -14,7 +13,9 @@ import { getAiSdkProviderId } from '../provider/factory'
import { isOpenRouterGeminiGenerateImageModel } from '../utils/image' import { isOpenRouterGeminiGenerateImageModel } from '../utils/image'
import { noThinkMiddleware } from './noThinkMiddleware' import { noThinkMiddleware } from './noThinkMiddleware'
import { openrouterGenerateImageMiddleware } from './openrouterGenerateImageMiddleware' import { openrouterGenerateImageMiddleware } from './openrouterGenerateImageMiddleware'
import { openrouterReasoningMiddleware } from './openrouterReasoningMiddleware'
import { qwenThinkingMiddleware } from './qwenThinkingMiddleware' import { qwenThinkingMiddleware } from './qwenThinkingMiddleware'
import { skipGeminiThoughtSignatureMiddleware } from './skipGeminiThoughtSignatureMiddleware'
import { toolChoiceMiddleware } from './toolChoiceMiddleware' import { toolChoiceMiddleware } from './toolChoiceMiddleware'
const logger = loggerService.withContext('AiSdkMiddlewareBuilder') const logger = loggerService.withContext('AiSdkMiddlewareBuilder')
@@ -0,0 +1,50 @@
import type { LanguageModelV2StreamPart } from '@ai-sdk/provider'
import type { LanguageModelMiddleware } from 'ai'
/**
* https://openrouter.ai/docs/docs/best-practices/reasoning-tokens#example-preserving-reasoning-blocks-with-openrouter-and-claude
*
* @returns LanguageModelMiddleware - a middleware filter redacted block
*/
export function openrouterReasoningMiddleware(): LanguageModelMiddleware {
const REDACTED_BLOCK = '[REDACTED]'
return {
middlewareVersion: 'v2',
wrapGenerate: async ({ doGenerate }) => {
const { content, ...rest } = await doGenerate()
const modifiedContent = content.map((part) => {
if (part.type === 'reasoning' && part.text.includes(REDACTED_BLOCK)) {
return {
...part,
text: part.text.replace(REDACTED_BLOCK, '')
}
}
return part
})
return { content: modifiedContent, ...rest }
},
wrapStream: async ({ doStream }) => {
const { stream, ...rest } = await doStream()
return {
stream: stream.pipeThrough(
new TransformStream<LanguageModelV2StreamPart, LanguageModelV2StreamPart>({
transform(
chunk: LanguageModelV2StreamPart,
controller: TransformStreamDefaultController<LanguageModelV2StreamPart>
) {
if (chunk.type === 'reasoning-delta' && chunk.delta.includes(REDACTED_BLOCK)) {
controller.enqueue({
...chunk,
delta: chunk.delta.replace(REDACTED_BLOCK, '')
})
} else {
controller.enqueue(chunk)
}
}
})
),
...rest
}
}
}
}
@@ -0,0 +1,36 @@
import type { LanguageModelMiddleware } from 'ai'
/**
* skip Gemini Thought Signature Middleware
* 由于多模型客户端请求的复杂性(可以中途切换其他模型),这里选择通过中间件方式添加跳过所有 Gemini3 思考签名
* Due to the complexity of multi-model client requests (which can switch to other models mid-process),
* it was decided to add a skip for all Gemini3 thinking signatures via middleware.
* @param aiSdkId AI SDK Provider ID
* @returns LanguageModelMiddleware
*/
export function skipGeminiThoughtSignatureMiddleware(aiSdkId: string): LanguageModelMiddleware {
const MAGIC_STRING = 'skip_thought_signature_validator'
return {
middlewareVersion: 'v2',
transformParams: async ({ params }) => {
const transformedParams = { ...params }
// Process messages in prompt
if (transformedParams.prompt && Array.isArray(transformedParams.prompt)) {
transformedParams.prompt = transformedParams.prompt.map((message) => {
if (typeof message.content !== 'string') {
for (const part of message.content) {
const googleOptions = part?.providerOptions?.[aiSdkId]
if (googleOptions?.thoughtSignature) {
googleOptions.thoughtSignature = MAGIC_STRING
}
}
}
return message
})
}
return transformedParams
}
}
}
@@ -107,7 +107,7 @@ export async function buildStreamTextParams(
searchWithTime: store.getState().websearch.searchWithTime searchWithTime: store.getState().websearch.searchWithTime
} }
const { providerOptions, standardParams } = buildProviderOptions(assistant, model, provider, { const { providerOptions, standardParams, bodyParams } = buildProviderOptions(assistant, model, provider, {
enableReasoning, enableReasoning,
enableWebSearch, enableWebSearch,
enableGenerateImage enableGenerateImage
@@ -185,6 +185,7 @@ export async function buildStreamTextParams(
// Note: standardParams (topK, frequencyPenalty, presencePenalty, stopSequences, seed) // Note: standardParams (topK, frequencyPenalty, presencePenalty, stopSequences, seed)
// are extracted from custom parameters and passed directly to streamText() // are extracted from custom parameters and passed directly to streamText()
// instead of being placed in providerOptions // instead of being placed in providerOptions
// Note: bodyParams are custom parameters for AI Gateway that should be at body level
const params: StreamTextParams = { const params: StreamTextParams = {
messages: sdkMessages, messages: sdkMessages,
maxOutputTokens: getMaxTokens(assistant, model), maxOutputTokens: getMaxTokens(assistant, model),
@@ -192,6 +193,8 @@ export async function buildStreamTextParams(
topP: getTopP(assistant, model), topP: getTopP(assistant, model),
// Include AI SDK standard params extracted from custom parameters // Include AI SDK standard params extracted from custom parameters
...standardParams, ...standardParams,
// Include body-level params for AI Gateway custom parameters
...bodyParams,
abortSignal: options.requestOptions?.signal, abortSignal: options.requestOptions?.signal,
headers, headers,
providerOptions, providerOptions,
@@ -24,17 +24,7 @@ vi.mock('@renderer/services/AssistantService', () => ({
vi.mock('@renderer/store', () => ({ vi.mock('@renderer/store', () => ({
default: { default: {
getState: () => ({ getState: () => ({ copilot: { defaultHeaders: {} } })
copilot: { defaultHeaders: {} },
llm: {
settings: {
vertexai: {
projectId: 'test-project',
location: 'us-central1'
}
}
}
})
} }
})) }))
@@ -43,7 +33,7 @@ vi.mock('@renderer/utils/api', () => ({
if (isSupportedAPIVersion === false) { if (isSupportedAPIVersion === false) {
return host // Return host as-is when isSupportedAPIVersion is false return host // Return host as-is when isSupportedAPIVersion is false
} }
return host ? `${host}/v1` : '' // Default behavior when isSupportedAPIVersion is true return `${host}/v1` // Default behavior when isSupportedAPIVersion is true
}), }),
routeToEndpoint: vi.fn((host) => ({ routeToEndpoint: vi.fn((host) => ({
baseURL: host, baseURL: host,
@@ -51,20 +41,6 @@ vi.mock('@renderer/utils/api', () => ({
})) }))
})) }))
// Also mock @shared/api since formatProviderApiHost uses it directly
vi.mock('@shared/api', async (importOriginal) => {
const actual = (await importOriginal()) as any
return {
...actual,
formatApiHost: vi.fn((host, isSupportedAPIVersion = true) => {
if (isSupportedAPIVersion === false) {
return host || '' // Return host as-is when isSupportedAPIVersion is false
}
return host ? `${host}/v1` : '' // Default behavior when isSupportedAPIVersion is true
})
}
})
vi.mock('@renderer/utils/provider', async (importOriginal) => { vi.mock('@renderer/utils/provider', async (importOriginal) => {
const actual = (await importOriginal()) as any const actual = (await importOriginal()) as any
return { return {
@@ -97,8 +73,8 @@ vi.mock('@renderer/services/AssistantService', () => ({
import { getProviderByModel } from '@renderer/services/AssistantService' import { getProviderByModel } from '@renderer/services/AssistantService'
import type { Model, Provider } from '@renderer/types' import type { Model, Provider } from '@renderer/types'
import { formatApiHost } from '@renderer/utils/api'
import { isCherryAIProvider, isPerplexityProvider } from '@renderer/utils/provider' import { isCherryAIProvider, isPerplexityProvider } from '@renderer/utils/provider'
import { formatApiHost } from '@shared/api'
import { COPILOT_DEFAULT_HEADERS, COPILOT_EDITOR_VERSION, isCopilotResponsesModel } from '../constants' import { COPILOT_DEFAULT_HEADERS, COPILOT_EDITOR_VERSION, isCopilotResponsesModel } from '../constants'
import { getActualProvider, providerToAiSdkConfig } from '../providerConfig' import { getActualProvider, providerToAiSdkConfig } from '../providerConfig'
@@ -1,13 +1,13 @@
/** /**
* AiHubMix规则集 * AiHubMix规则集
*/ */
import { getLowerBaseModelName } from '@shared/utils/naming' import { isOpenAILLMModel } from '@renderer/config/models'
import type { Provider } from '@renderer/types'
import type { MinimalModel, MinimalProvider } from '../types'
import { provider2Provider, startsWith } from './helper' import { provider2Provider, startsWith } from './helper'
import type { RuleSet } from './types' import type { RuleSet } from './types'
const extraProviderConfig = <P extends MinimalProvider>(provider: P) => { const extraProviderConfig = (provider: Provider) => {
return { return {
...provider, ...provider,
extra_headers: { extra_headers: {
@@ -17,23 +17,11 @@ const extraProviderConfig = <P extends MinimalProvider>(provider: P) => {
} }
} }
function isOpenAILLMModel<M extends MinimalModel>(model: M): boolean {
const modelId = getLowerBaseModelName(model.id)
const reasonings = ['o1', 'o3', 'o4', 'gpt-oss']
if (reasonings.some((r) => modelId.includes(r))) {
return true
}
if (modelId.includes('gpt')) {
return true
}
return false
}
const AIHUBMIX_RULES: RuleSet = { const AIHUBMIX_RULES: RuleSet = {
rules: [ rules: [
{ {
match: startsWith('claude'), match: startsWith('claude'),
provider: (provider) => { provider: (provider: Provider) => {
return extraProviderConfig({ return extraProviderConfig({
...provider, ...provider,
type: 'anthropic' type: 'anthropic'
@@ -46,7 +34,7 @@ const AIHUBMIX_RULES: RuleSet = {
!model.id.endsWith('-nothink') && !model.id.endsWith('-nothink') &&
!model.id.endsWith('-search') && !model.id.endsWith('-search') &&
!model.id.includes('embedding'), !model.id.includes('embedding'),
provider: (provider) => { provider: (provider: Provider) => {
return extraProviderConfig({ return extraProviderConfig({
...provider, ...provider,
type: 'gemini', type: 'gemini',
@@ -56,7 +44,7 @@ const AIHUBMIX_RULES: RuleSet = {
}, },
{ {
match: isOpenAILLMModel, match: isOpenAILLMModel,
provider: (provider) => { provider: (provider: Provider) => {
return extraProviderConfig({ return extraProviderConfig({
...provider, ...provider,
type: 'openai-response' type: 'openai-response'
@@ -64,8 +52,7 @@ const AIHUBMIX_RULES: RuleSet = {
} }
} }
], ],
fallbackRule: (provider) => extraProviderConfig(provider) fallbackRule: (provider: Provider) => extraProviderConfig(provider)
} }
export const aihubmixProviderCreator = <P extends MinimalProvider>(model: MinimalModel, provider: P): P => export const aihubmixProviderCreator = provider2Provider.bind(null, AIHUBMIX_RULES)
provider2Provider<MinimalModel, MinimalProvider, P>(AIHUBMIX_RULES, model, provider)
@@ -0,0 +1,22 @@
import type { Provider } from '@renderer/types'
import { provider2Provider, startsWith } from './helper'
import type { RuleSet } from './types'
// https://platform.claude.com/docs/en/build-with-claude/claude-in-microsoft-foundry
const AZURE_ANTHROPIC_RULES: RuleSet = {
rules: [
{
match: startsWith('claude'),
provider: (provider: Provider) => ({
...provider,
type: 'anthropic',
apiHost: provider.apiHost + 'anthropic/v1',
id: 'azure-anthropic'
})
}
],
fallbackRule: (provider: Provider) => provider
}
export const azureAnthropicProviderCreator = provider2Provider.bind(null, AZURE_ANTHROPIC_RULES)
@@ -0,0 +1,22 @@
import type { Model, Provider } from '@renderer/types'
import type { RuleSet } from './types'
export const startsWith = (prefix: string) => (model: Model) => model.id.toLowerCase().startsWith(prefix.toLowerCase())
export const endpointIs = (type: string) => (model: Model) => model.endpoint_type === type
/**
* 解析模型对应的Provider
* @param ruleSet 规则集对象
* @param model 模型对象
* @param provider 原始provider对象
* @returns 解析出的provider对象
*/
export function provider2Provider(ruleSet: RuleSet, model: Model, provider: Provider): Provider {
for (const rule of ruleSet.rules) {
if (rule.match(model)) {
return rule.provider(provider)
}
}
return ruleSet.fallbackRule(provider)
}
@@ -1,7 +1,3 @@
// Re-export from shared config export { aihubmixProviderCreator } from './aihubmix'
export { export { newApiResolverCreator } from './newApi'
aihubmixProviderCreator, export { vertexAnthropicProviderCreator } from './vertext-anthropic'
azureAnthropicProviderCreator,
newApiResolverCreator,
vertexAnthropicProviderCreator
} from '@shared/provider/config'
@@ -1,7 +1,8 @@
/** /**
* NewAPI规则集 * NewAPI规则集
*/ */
import type { MinimalModel, MinimalProvider, ProviderType } from '../types' import type { Provider } from '@renderer/types'
import { endpointIs, provider2Provider } from './helper' import { endpointIs, provider2Provider } from './helper'
import type { RuleSet } from './types' import type { RuleSet } from './types'
@@ -9,43 +10,42 @@ const NEWAPI_RULES: RuleSet = {
rules: [ rules: [
{ {
match: endpointIs('anthropic'), match: endpointIs('anthropic'),
provider: (provider) => { provider: (provider: Provider) => {
return { return {
...provider, ...provider,
type: 'anthropic' as ProviderType type: 'anthropic'
} }
} }
}, },
{ {
match: endpointIs('gemini'), match: endpointIs('gemini'),
provider: (provider) => { provider: (provider: Provider) => {
return { return {
...provider, ...provider,
type: 'gemini' as ProviderType type: 'gemini'
} }
} }
}, },
{ {
match: endpointIs('openai-response'), match: endpointIs('openai-response'),
provider: (provider) => { provider: (provider: Provider) => {
return { return {
...provider, ...provider,
type: 'openai-response' as ProviderType type: 'openai-response'
} }
} }
}, },
{ {
match: (model) => endpointIs('openai')(model) || endpointIs('image-generation')(model), match: (model) => endpointIs('openai')(model) || endpointIs('image-generation')(model),
provider: (provider) => { provider: (provider: Provider) => {
return { return {
...provider, ...provider,
type: 'openai' as ProviderType type: 'openai'
} }
} }
} }
], ],
fallbackRule: (provider) => provider fallbackRule: (provider: Provider) => provider
} }
export const newApiResolverCreator = <P extends MinimalProvider>(model: MinimalModel, provider: P): P => export const newApiResolverCreator = provider2Provider.bind(null, NEWAPI_RULES)
provider2Provider<MinimalModel, MinimalProvider, P>(NEWAPI_RULES, model, provider)
@@ -0,0 +1,9 @@
import type { Model, Provider } from '@renderer/types'
export interface RuleSet {
rules: Array<{
match: (model: Model) => boolean
provider: (provider: Provider) => Provider
}>
fallbackRule: (provider: Provider) => Provider
}
@@ -0,0 +1,19 @@
import type { Provider } from '@renderer/types'
import { provider2Provider, startsWith } from './helper'
import type { RuleSet } from './types'
const VERTEX_ANTHROPIC_RULES: RuleSet = {
rules: [
{
match: startsWith('claude'),
provider: (provider: Provider) => ({
...provider,
id: 'google-vertex-anthropic'
})
}
],
fallbackRule: (provider: Provider) => provider
}
export const vertexAnthropicProviderCreator = provider2Provider.bind(null, VERTEX_ANTHROPIC_RULES)
+25 -1
View File
@@ -1 +1,25 @@
export { COPILOT_DEFAULT_HEADERS, COPILOT_EDITOR_VERSION, isCopilotResponsesModel } from '@shared/provider/constant' import type { Model } from '@renderer/types'
export const COPILOT_EDITOR_VERSION = 'vscode/1.104.1'
export const COPILOT_PLUGIN_VERSION = 'copilot-chat/0.26.7'
export const COPILOT_INTEGRATION_ID = 'vscode-chat'
export const COPILOT_USER_AGENT = 'GitHubCopilotChat/0.26.7'
export const COPILOT_DEFAULT_HEADERS = {
'Copilot-Integration-Id': COPILOT_INTEGRATION_ID,
'User-Agent': COPILOT_USER_AGENT,
'Editor-Version': COPILOT_EDITOR_VERSION,
'Editor-Plugin-Version': COPILOT_PLUGIN_VERSION,
'editor-version': COPILOT_EDITOR_VERSION,
'editor-plugin-version': COPILOT_PLUGIN_VERSION,
'copilot-vision-request': 'true'
} as const
// Models that require the OpenAI Responses endpoint when routed through GitHub Copilot (#10560)
const COPILOT_RESPONSES_MODEL_IDS = ['gpt-5-codex']
export function isCopilotResponsesModel(model: Model): boolean {
const normalizedId = model.id?.trim().toLowerCase()
const normalizedName = model.name?.trim().toLowerCase()
return COPILOT_RESPONSES_MODEL_IDS.some((target) => normalizedId === target || normalizedName === target)
}
+60 -3
View File
@@ -1,7 +1,8 @@
import { hasProviderConfigByAlias, type ProviderId, resolveProviderConfigId } from '@cherrystudio/ai-core/provider'
import { createProvider as createProviderCore } from '@cherrystudio/ai-core/provider' import { createProvider as createProviderCore } from '@cherrystudio/ai-core/provider'
import { loggerService } from '@logger' import { loggerService } from '@logger'
import type { Provider } from '@renderer/types' import type { Provider } from '@renderer/types'
import { getAiSdkProviderId as sharedGetAiSdkProviderId } from '@shared/provider' import { isAzureOpenAIProvider, isAzureResponsesEndpoint } from '@renderer/utils/provider'
import type { Provider as AiSdkProvider } from 'ai' import type { Provider as AiSdkProvider } from 'ai'
import type { AiSdkConfig } from '../types' import type { AiSdkConfig } from '../types'
@@ -21,12 +22,68 @@ const logger = loggerService.withContext('ProviderFactory')
} }
})() })()
/**
* 静态Provider映射表
* 处理Cherry Studio特有的provider ID到AI SDK标准ID的映射
*/
const STATIC_PROVIDER_MAPPING: Record<string, ProviderId> = {
gemini: 'google', // Google Gemini -> google
'azure-openai': 'azure', // Azure OpenAI -> azure
'openai-response': 'openai', // OpenAI Responses -> openai
grok: 'xai', // Grok -> xai
copilot: 'github-copilot-openai-compatible'
}
/**
* 尝试解析provider标识符(支持静态映射和别名)
*/
function tryResolveProviderId(identifier: string): ProviderId | null {
// 1. 检查静态映射
const staticMapping = STATIC_PROVIDER_MAPPING[identifier]
if (staticMapping) {
return staticMapping
}
// 2. 检查AiCore是否支持(包括别名支持)
if (hasProviderConfigByAlias(identifier)) {
// 解析为真实的Provider ID
return resolveProviderConfigId(identifier) as ProviderId
}
return null
}
/** /**
* 获取AI SDK Provider ID * 获取AI SDK Provider ID
* Uses shared implementation with renderer-specific config checker * 简化版:减少重复逻辑,利用通用解析函数
*/ */
export function getAiSdkProviderId(provider: Provider): string { export function getAiSdkProviderId(provider: Provider): string {
return sharedGetAiSdkProviderId(provider) // 1. 尝试解析provider.id
const resolvedFromId = tryResolveProviderId(provider.id)
if (isAzureOpenAIProvider(provider)) {
if (isAzureResponsesEndpoint(provider)) {
return 'azure-responses'
} else {
return 'azure'
}
}
if (resolvedFromId) {
return resolvedFromId
}
// 2. 尝试解析provider.type
// 会把所有类型为openai的自定义provider解析到aisdk的openaiProvider上
if (provider.type !== 'openai') {
const resolvedFromType = tryResolveProviderId(provider.type)
if (resolvedFromType) {
return resolvedFromType
}
}
if (provider.apiHost.includes('api.openai.com')) {
return 'openai-chat'
}
// 3. 最后的fallback(使用provider本身的id
return provider.id
} }
export async function createAiSdkProvider(config: AiSdkConfig): Promise<AiSdkProvider | null> { export async function createAiSdkProvider(config: AiSdkConfig): Promise<AiSdkProvider | null> {
@@ -1,4 +1,4 @@
import { hasProviderConfig } from '@cherrystudio/ai-core/provider' import { formatPrivateKey, hasProviderConfig, ProviderConfigFactory } from '@cherrystudio/ai-core/provider'
import { isOpenAIChatCompletionOnlyModel } from '@renderer/config/models' import { isOpenAIChatCompletionOnlyModel } from '@renderer/config/models'
import { import {
getAwsBedrockAccessKeyId, getAwsBedrockAccessKeyId,
@@ -10,17 +10,22 @@ import {
import { createVertexProvider, isVertexAIConfigured } from '@renderer/hooks/useVertexAI' import { createVertexProvider, isVertexAIConfigured } from '@renderer/hooks/useVertexAI'
import { getProviderByModel } from '@renderer/services/AssistantService' import { getProviderByModel } from '@renderer/services/AssistantService'
import store from '@renderer/store' import store from '@renderer/store'
import { isSystemProvider, type Model, type Provider } from '@renderer/types' import { isSystemProvider, type Model, type Provider, SystemProviderIds } from '@renderer/types'
import { formatApiHost, formatAzureOpenAIApiHost, formatVertexApiHost, routeToEndpoint } from '@renderer/utils/api'
import { import {
type AiSdkConfigContext, isAnthropicProvider,
formatProviderApiHost as sharedFormatProviderApiHost, isAzureOpenAIProvider,
type ProviderFormatContext, isCherryAIProvider,
providerToAiSdkConfig as sharedProviderToAiSdkConfig, isGeminiProvider,
resolveActualProvider isNewApiProvider,
} from '@shared/provider' isPerplexityProvider,
isVertexProvider
} from '@renderer/utils/provider'
import { cloneDeep } from 'lodash' import { cloneDeep } from 'lodash'
import type { AiSdkConfig } from '../types' import type { AiSdkConfig } from '../types'
import { aihubmixProviderCreator, newApiResolverCreator, vertexAnthropicProviderCreator } from './config'
import { azureAnthropicProviderCreator } from './config/azure-anthropic'
import { COPILOT_DEFAULT_HEADERS } from './constants' import { COPILOT_DEFAULT_HEADERS } from './constants'
import { getAiSdkProviderId } from './factory' import { getAiSdkProviderId } from './factory'
@@ -51,51 +56,61 @@ function getRotatedApiKey(provider: Provider): string {
} }
/** /**
* Renderer-specific context for providerToAiSdkConfig * 处理特殊provider的转换逻辑
* Provides implementations using browser APIs, store, and hooks
*/ */
function createRendererSdkContext(model: Model): AiSdkConfigContext { function handleSpecialProviders(model: Model, provider: Provider): Provider {
return { if (isNewApiProvider(provider)) {
getRotatedApiKey: (provider) => getRotatedApiKey(provider as Provider), return newApiResolverCreator(model, provider)
isOpenAIChatCompletionOnlyModel: () => isOpenAIChatCompletionOnlyModel(model),
getCopilotDefaultHeaders: () => COPILOT_DEFAULT_HEADERS,
getCopilotStoredHeaders: () => store.getState().copilot.defaultHeaders ?? {},
getAwsBedrockConfig: () => {
const authType = getAwsBedrockAuthType()
return {
authType,
region: getAwsBedrockRegion(),
apiKey: authType === 'apiKey' ? getAwsBedrockApiKey() : undefined,
accessKeyId: authType === 'iam' ? getAwsBedrockAccessKeyId() : undefined,
secretAccessKey: authType === 'iam' ? getAwsBedrockSecretAccessKey() : undefined
}
},
getVertexConfig: (provider) => {
if (!isVertexAIConfigured()) {
return undefined
}
return createVertexProvider(provider as Provider)
},
getEndpointType: () => model.endpoint_type
} }
if (isSystemProvider(provider)) {
if (provider.id === 'aihubmix') {
return aihubmixProviderCreator(model, provider)
}
if (provider.id === 'vertexai') {
return vertexAnthropicProviderCreator(model, provider)
}
}
if (isAzureOpenAIProvider(provider)) {
return azureAnthropicProviderCreator(model, provider)
}
return provider
} }
/** /**
* 主要用来对齐AISdk的BaseURL格式 * 主要用来对齐AISdk的BaseURL格式
* Uses shared implementation with renderer-specific context * @param provider
* @returns
*/ */
function getRendererFormatContext(): ProviderFormatContext {
const vertexSettings = store.getState().llm.settings.vertexai
return {
vertex: {
project: vertexSettings.projectId || 'default-project',
location: vertexSettings.location || 'us-central1'
}
}
}
function formatProviderApiHost(provider: Provider): Provider { function formatProviderApiHost(provider: Provider): Provider {
return sharedFormatProviderApiHost(provider, getRendererFormatContext()) const formatted = { ...provider }
if (formatted.anthropicApiHost) {
formatted.anthropicApiHost = formatApiHost(formatted.anthropicApiHost)
}
if (isAnthropicProvider(provider)) {
const baseHost = formatted.anthropicApiHost || formatted.apiHost
// AI SDK needs /v1 in baseURL, Anthropic SDK will strip it in getSdkClient
formatted.apiHost = formatApiHost(baseHost)
if (!formatted.anthropicApiHost) {
formatted.anthropicApiHost = formatted.apiHost
}
} else if (formatted.id === SystemProviderIds.copilot || formatted.id === SystemProviderIds.github) {
formatted.apiHost = formatApiHost(formatted.apiHost, false)
} else if (isGeminiProvider(formatted)) {
formatted.apiHost = formatApiHost(formatted.apiHost, true, 'v1beta')
} else if (isAzureOpenAIProvider(formatted)) {
formatted.apiHost = formatAzureOpenAIApiHost(formatted.apiHost)
} else if (isVertexProvider(formatted)) {
formatted.apiHost = formatVertexApiHost(formatted)
} else if (isCherryAIProvider(formatted)) {
formatted.apiHost = formatApiHost(formatted.apiHost, false)
} else if (isPerplexityProvider(formatted)) {
formatted.apiHost = formatApiHost(formatted.apiHost, false)
} else {
formatted.apiHost = formatApiHost(formatted.apiHost)
}
return formatted
} }
/** /**
@@ -107,9 +122,7 @@ export function getActualProvider(model: Model): Provider {
// 按顺序处理各种转换 // 按顺序处理各种转换
let actualProvider = cloneDeep(baseProvider) let actualProvider = cloneDeep(baseProvider)
actualProvider = resolveActualProvider(actualProvider, model, { actualProvider = handleSpecialProviders(model, actualProvider)
isSystemProvider
}) as Provider
actualProvider = formatProviderApiHost(actualProvider) actualProvider = formatProviderApiHost(actualProvider)
return actualProvider return actualProvider
@@ -117,11 +130,121 @@ export function getActualProvider(model: Model): Provider {
/** /**
* 将 Provider 配置转换为新 AI SDK 格式 * 将 Provider 配置转换为新 AI SDK 格式
* Uses shared implementation with renderer-specific context * 简化版:利用新的别名映射系统
*/ */
export function providerToAiSdkConfig(actualProvider: Provider, model: Model): AiSdkConfig { export function providerToAiSdkConfig(actualProvider: Provider, model: Model): AiSdkConfig {
const context = createRendererSdkContext(model) const aiSdkProviderId = getAiSdkProviderId(actualProvider)
return sharedProviderToAiSdkConfig(actualProvider, model.id, context) as AiSdkConfig
// 构建基础配置
const { baseURL, endpoint } = routeToEndpoint(actualProvider.apiHost)
const baseConfig = {
baseURL: baseURL,
apiKey: getRotatedApiKey(actualProvider)
}
const isCopilotProvider = actualProvider.id === SystemProviderIds.copilot
if (isCopilotProvider) {
const storedHeaders = store.getState().copilot.defaultHeaders ?? {}
const options = ProviderConfigFactory.fromProvider('github-copilot-openai-compatible', baseConfig, {
headers: {
...COPILOT_DEFAULT_HEADERS,
...storedHeaders,
...actualProvider.extra_headers
},
name: actualProvider.id,
includeUsage: true
})
return {
providerId: 'github-copilot-openai-compatible',
options
}
}
// 处理OpenAI模式
const extraOptions: any = {}
extraOptions.endpoint = endpoint
if (actualProvider.type === 'openai-response' && !isOpenAIChatCompletionOnlyModel(model)) {
extraOptions.mode = 'responses'
} else if (aiSdkProviderId === 'openai' || (aiSdkProviderId === 'cherryin' && actualProvider.type === 'openai')) {
extraOptions.mode = 'chat'
}
// 添加额外headers
if (actualProvider.extra_headers) {
extraOptions.headers = actualProvider.extra_headers
// copy from openaiBaseClient/openaiResponseApiClient
if (aiSdkProviderId === 'openai') {
extraOptions.headers = {
...extraOptions.headers,
'HTTP-Referer': 'https://cherry-ai.com',
'X-Title': 'Cherry Studio',
'X-Api-Key': baseConfig.apiKey
}
}
}
// azure
// https://learn.microsoft.com/en-us/azure/ai-foundry/openai/latest
// https://learn.microsoft.com/en-us/azure/ai-foundry/openai/how-to/responses?tabs=python-key#responses-api
if (aiSdkProviderId === 'azure-responses') {
extraOptions.mode = 'responses'
} else if (aiSdkProviderId === 'azure') {
extraOptions.mode = 'chat'
}
// bedrock
if (aiSdkProviderId === 'bedrock') {
const authType = getAwsBedrockAuthType()
extraOptions.region = getAwsBedrockRegion()
if (authType === 'apiKey') {
extraOptions.apiKey = getAwsBedrockApiKey()
} else {
extraOptions.accessKeyId = getAwsBedrockAccessKeyId()
extraOptions.secretAccessKey = getAwsBedrockSecretAccessKey()
}
}
// google-vertex
if (aiSdkProviderId === 'google-vertex' || aiSdkProviderId === 'google-vertex-anthropic') {
if (!isVertexAIConfigured()) {
throw new Error('VertexAI is not configured. Please configure project, location and service account credentials.')
}
const { project, location, googleCredentials } = createVertexProvider(actualProvider)
extraOptions.project = project
extraOptions.location = location
extraOptions.googleCredentials = {
...googleCredentials,
privateKey: formatPrivateKey(googleCredentials.privateKey)
}
baseConfig.baseURL += aiSdkProviderId === 'google-vertex' ? '/publishers/google' : '/publishers/anthropic/models'
}
// cherryin
if (aiSdkProviderId === 'cherryin') {
if (model.endpoint_type) {
extraOptions.endpointType = model.endpoint_type
}
}
if (hasProviderConfig(aiSdkProviderId) && aiSdkProviderId !== 'openai-compatible') {
const options = ProviderConfigFactory.fromProvider(aiSdkProviderId, baseConfig, extraOptions)
return {
providerId: aiSdkProviderId,
options
}
}
// 否则fallback到openai-compatible
const options = ProviderConfigFactory.createOpenAICompatible(baseConfig.baseURL, baseConfig.apiKey)
return {
providerId: 'openai-compatible',
options: {
...options,
name: actualProvider.id,
...extraOptions,
includeUsage: true
}
}
} }
/** /**
@@ -164,13 +287,13 @@ export async function prepareSpecialProviderConfig(
break break
} }
case 'cherryai': { case 'cherryai': {
config.options.fetch = async (url: RequestInfo | URL, options: RequestInit) => { config.options.fetch = async (url, options) => {
// 在这里对最终参数进行签名 // 在这里对最终参数进行签名
const signature = await window.api.cherryai.generateSignature({ const signature = await window.api.cherryai.generateSignature({
method: 'POST', method: 'POST',
path: '/chat/completions', path: '/chat/completions',
query: '', query: '',
body: JSON.parse(options.body as string) body: JSON.parse(options.body)
}) })
return fetch(url, { return fetch(url, {
...options, ...options,
@@ -1,13 +1,113 @@
import { type ProviderConfig, registerMultipleProviderConfigs } from '@cherrystudio/ai-core/provider'
import { loggerService } from '@logger' import { loggerService } from '@logger'
import { initializeSharedProviders, SHARED_PROVIDER_CONFIGS } from '@shared/provider'
const logger = loggerService.withContext('ProviderConfigs') const logger = loggerService.withContext('ProviderConfigs')
export const NEW_PROVIDER_CONFIGS = SHARED_PROVIDER_CONFIGS /**
* 新Provider配置定义
* 定义了需要动态注册的AI Providers
*/
export const NEW_PROVIDER_CONFIGS: ProviderConfig[] = [
{
id: 'openrouter',
name: 'OpenRouter',
import: () => import('@openrouter/ai-sdk-provider'),
creatorFunctionName: 'createOpenRouter',
supportsImageGeneration: true,
aliases: ['openrouter']
},
{
id: 'google-vertex',
name: 'Google Vertex AI',
import: () => import('@ai-sdk/google-vertex/edge'),
creatorFunctionName: 'createVertex',
supportsImageGeneration: true,
aliases: ['vertexai']
},
{
id: 'google-vertex-anthropic',
name: 'Google Vertex AI Anthropic',
import: () => import('@ai-sdk/google-vertex/anthropic/edge'),
creatorFunctionName: 'createVertexAnthropic',
supportsImageGeneration: true,
aliases: ['vertexai-anthropic']
},
{
id: 'azure-anthropic',
name: 'Azure AI Anthropic',
import: () => import('@ai-sdk/anthropic'),
creatorFunctionName: 'createAnthropic',
supportsImageGeneration: false,
aliases: ['azure-anthropic']
},
{
id: 'github-copilot-openai-compatible',
name: 'GitHub Copilot OpenAI Compatible',
import: () => import('@opeoginni/github-copilot-openai-compatible'),
creatorFunctionName: 'createGitHubCopilotOpenAICompatible',
supportsImageGeneration: false,
aliases: ['copilot', 'github-copilot']
},
{
id: 'bedrock',
name: 'Amazon Bedrock',
import: () => import('@ai-sdk/amazon-bedrock'),
creatorFunctionName: 'createAmazonBedrock',
supportsImageGeneration: true,
aliases: ['aws-bedrock']
},
{
id: 'perplexity',
name: 'Perplexity',
import: () => import('@ai-sdk/perplexity'),
creatorFunctionName: 'createPerplexity',
supportsImageGeneration: false,
aliases: ['perplexity']
},
{
id: 'mistral',
name: 'Mistral',
import: () => import('@ai-sdk/mistral'),
creatorFunctionName: 'createMistral',
supportsImageGeneration: false,
aliases: ['mistral']
},
{
id: 'huggingface',
name: 'HuggingFace',
import: () => import('@ai-sdk/huggingface'),
creatorFunctionName: 'createHuggingFace',
supportsImageGeneration: true,
aliases: ['hf', 'hugging-face']
},
{
id: 'ai-gateway',
name: 'AI Gateway',
import: () => import('@ai-sdk/gateway'),
creatorFunctionName: 'createGateway',
supportsImageGeneration: true,
aliases: ['gateway']
},
{
id: 'cerebras',
name: 'Cerebras',
import: () => import('@ai-sdk/cerebras'),
creatorFunctionName: 'createCerebras',
supportsImageGeneration: false
}
] as const
/**
* 初始化新的Providers
* 使用aiCore的动态注册功能
*/
export async function initializeNewProviders(): Promise<void> { export async function initializeNewProviders(): Promise<void> {
initializeSharedProviders({ try {
warn: (message) => logger.warn(message), const successCount = registerMultipleProviderConfigs(NEW_PROVIDER_CONFIGS)
error: (message, error) => logger.error(message, error) if (successCount < NEW_PROVIDER_CONFIGS.length) {
}) logger.warn('Some providers failed to register. Check previous error logs.')
}
} catch (error) {
logger.error('Failed to initialize new providers:', error as Error)
}
} }
@@ -37,7 +37,7 @@ vi.mock('@cherrystudio/ai-core/provider', async (importOriginal) => {
}, },
customProviderIdSchema: { customProviderIdSchema: {
safeParse: vi.fn((id) => { safeParse: vi.fn((id) => {
const customProviders = ['google-vertex', 'google-vertex-anthropic', 'bedrock'] const customProviders = ['google-vertex', 'google-vertex-anthropic', 'bedrock', 'ai-gateway']
if (customProviders.includes(id)) { if (customProviders.includes(id)) {
return { success: true, data: id } return { success: true, data: id }
} }
@@ -56,7 +56,8 @@ vi.mock('../provider/factory', () => ({
[SystemProviderIds.anthropic]: 'anthropic', [SystemProviderIds.anthropic]: 'anthropic',
[SystemProviderIds.grok]: 'xai', [SystemProviderIds.grok]: 'xai',
[SystemProviderIds.deepseek]: 'deepseek', [SystemProviderIds.deepseek]: 'deepseek',
[SystemProviderIds.openrouter]: 'openrouter' [SystemProviderIds.openrouter]: 'openrouter',
[SystemProviderIds['ai-gateway']]: 'ai-gateway'
} }
return mapping[provider.id] || provider.id return mapping[provider.id] || provider.id
}) })
@@ -204,6 +205,8 @@ describe('options utils', () => {
expect(result.providerOptions).toHaveProperty('openai') expect(result.providerOptions).toHaveProperty('openai')
expect(result.providerOptions.openai).toBeDefined() expect(result.providerOptions.openai).toBeDefined()
expect(result.standardParams).toBeDefined() expect(result.standardParams).toBeDefined()
expect(result.bodyParams).toBeDefined()
expect(result.bodyParams).toEqual({})
}) })
it('should include reasoning parameters when enabled', () => { it('should include reasoning parameters when enabled', () => {
@@ -696,5 +699,90 @@ describe('options utils', () => {
}) })
}) })
}) })
describe('AI Gateway provider', () => {
const aiGatewayProvider: Provider = {
id: SystemProviderIds['ai-gateway'],
name: 'AI Gateway',
type: 'ai-gateway',
apiKey: 'test-key',
apiHost: 'https://ai-gateway.vercel.sh/v1/ai',
isSystem: true,
models: [] as Model[]
} as Provider
const aiGatewayModel: Model = {
id: 'openai/gpt-4',
name: 'GPT-4',
provider: SystemProviderIds['ai-gateway']
} as Model
it('should build basic AI Gateway options with empty bodyParams', () => {
const result = buildProviderOptions(mockAssistant, aiGatewayModel, aiGatewayProvider, {
enableReasoning: false,
enableWebSearch: false,
enableGenerateImage: false
})
expect(result.providerOptions).toHaveProperty('gateway')
expect(result.providerOptions.gateway).toBeDefined()
expect(result.bodyParams).toEqual({})
})
it('should place custom parameters in bodyParams for AI Gateway instead of providerOptions', async () => {
const { getCustomParameters } = await import('../reasoning')
vi.mocked(getCustomParameters).mockReturnValue({
tools: [{ id: 'openai.image_generation' }],
custom_param: 'custom_value'
})
const result = buildProviderOptions(mockAssistant, aiGatewayModel, aiGatewayProvider, {
enableReasoning: false,
enableWebSearch: false,
enableGenerateImage: false
})
// Custom parameters should be in bodyParams, NOT in providerOptions.gateway
expect(result.bodyParams).toHaveProperty('tools')
expect(result.bodyParams.tools).toEqual([{ id: 'openai.image_generation' }])
expect(result.bodyParams).toHaveProperty('custom_param')
expect(result.bodyParams.custom_param).toBe('custom_value')
// providerOptions.gateway should NOT contain custom parameters
expect(result.providerOptions.gateway).not.toHaveProperty('tools')
expect(result.providerOptions.gateway).not.toHaveProperty('custom_param')
})
it('should still extract AI SDK standard params from custom parameters for AI Gateway', async () => {
const { getCustomParameters } = await import('../reasoning')
vi.mocked(getCustomParameters).mockReturnValue({
topK: 5,
frequencyPenalty: 0.5,
tools: [{ id: 'openai.image_generation' }]
})
const result = buildProviderOptions(mockAssistant, aiGatewayModel, aiGatewayProvider, {
enableReasoning: false,
enableWebSearch: false,
enableGenerateImage: false
})
// Standard params should be extracted and returned separately
expect(result.standardParams).toEqual({
topK: 5,
frequencyPenalty: 0.5
})
// Custom params (non-standard) should be in bodyParams
expect(result.bodyParams).toHaveProperty('tools')
expect(result.bodyParams.tools).toEqual([{ id: 'openai.image_generation' }])
// Neither should be in providerOptions.gateway
expect(result.providerOptions.gateway).not.toHaveProperty('topK')
expect(result.providerOptions.gateway).not.toHaveProperty('tools')
})
})
}) })
}) })
+18 -8
View File
@@ -155,6 +155,7 @@ export function buildProviderOptions(
): { ): {
providerOptions: Record<string, Record<string, JSONValue>> providerOptions: Record<string, Record<string, JSONValue>>
standardParams: Partial<Record<AiSdkParam, any>> standardParams: Partial<Record<AiSdkParam, any>>
bodyParams: Record<string, any>
} { } {
logger.debug('buildProviderOptions', { assistant, model, actualProvider, capabilities }) logger.debug('buildProviderOptions', { assistant, model, actualProvider, capabilities })
const rawProviderId = getAiSdkProviderId(actualProvider) const rawProviderId = getAiSdkProviderId(actualProvider)
@@ -253,12 +254,6 @@ export function buildProviderOptions(
const customParams = getCustomParameters(assistant) const customParams = getCustomParameters(assistant)
const { standardParams, providerParams } = extractAiSdkStandardParams(customParams) const { standardParams, providerParams } = extractAiSdkStandardParams(customParams)
// 合并 provider 特定的自定义参数到 providerSpecificOptions
providerSpecificOptions = {
...providerSpecificOptions,
...providerParams
}
let rawProviderKey = let rawProviderKey =
{ {
'google-vertex': 'google', 'google-vertex': 'google',
@@ -273,12 +268,27 @@ export function buildProviderOptions(
rawProviderKey = { gemini: 'google', ['openai-response']: 'openai' }[actualProvider.type] || actualProvider.type rawProviderKey = { gemini: 'google', ['openai-response']: 'openai' }[actualProvider.type] || actualProvider.type
} }
// 返回 AI Core SDK 要求的格式:{ 'providerId': providerOptions } 以及提取的标准参数 // For AI Gateway, custom parameters should be placed at body level, not inside providerOptions.gateway
// See: https://github.com/CherryHQ/cherry-studio/issues/4197
let bodyParams: Record<string, any> = {}
if (rawProviderKey === 'gateway') {
// Custom parameters go to body level for AI Gateway
bodyParams = providerParams
} else {
// For other providers, merge custom parameters into providerSpecificOptions
providerSpecificOptions = {
...providerSpecificOptions,
...providerParams
}
}
// 返回 AI Core SDK 要求的格式:{ 'providerId': providerOptions } 以及提取的标准参数和 body 参数
return { return {
providerOptions: { providerOptions: {
[rawProviderKey]: providerSpecificOptions [rawProviderKey]: providerSpecificOptions
}, },
standardParams standardParams,
bodyParams
} }
} }
@@ -15,7 +15,6 @@ import {
isSupportVerbosityModel isSupportVerbosityModel
} from '../openai' } from '../openai'
import { isQwenMTModel } from '../qwen' import { isQwenMTModel } from '../qwen'
import { isFunctionCallingModel } from '../tooluse'
import { import {
agentModelFilter, agentModelFilter,
getModelSupportedVerbosity, getModelSupportedVerbosity,
@@ -113,7 +112,6 @@ const textToImageMock = vi.mocked(isTextToImageModel)
const generateImageMock = vi.mocked(isGenerateImageModel) const generateImageMock = vi.mocked(isGenerateImageModel)
const reasoningMock = vi.mocked(isOpenAIReasoningModel) const reasoningMock = vi.mocked(isOpenAIReasoningModel)
const openAIWebSearchOnlyMock = vi.mocked(isOpenAIWebSearchChatCompletionOnlyModel) const openAIWebSearchOnlyMock = vi.mocked(isOpenAIWebSearchChatCompletionOnlyModel)
const isFunctionCallingModelMock = vi.mocked(isFunctionCallingModel)
describe('model utils', () => { describe('model utils', () => {
beforeEach(() => { beforeEach(() => {
@@ -122,7 +120,7 @@ describe('model utils', () => {
rerankMock.mockReturnValue(false) rerankMock.mockReturnValue(false)
visionMock.mockReturnValue(true) visionMock.mockReturnValue(true)
textToImageMock.mockReturnValue(false) textToImageMock.mockReturnValue(false)
generateImageMock.mockReturnValue(false) generateImageMock.mockReturnValue(true)
reasoningMock.mockReturnValue(false) reasoningMock.mockReturnValue(false)
openAIWebSearchOnlyMock.mockReturnValue(false) openAIWebSearchOnlyMock.mockReturnValue(false)
}) })
@@ -420,7 +418,6 @@ describe('model utils', () => {
describe('isGenerateImageModels', () => { describe('isGenerateImageModels', () => {
it('returns true when all models support image generation', () => { it('returns true when all models support image generation', () => {
const models = [createModel({ id: 'gpt-4o' }), createModel({ id: 'gpt-4o-mini' })] const models = [createModel({ id: 'gpt-4o' }), createModel({ id: 'gpt-4o-mini' })]
generateImageMock.mockReturnValue(true)
expect(isGenerateImageModels(models)).toBe(true) expect(isGenerateImageModels(models)).toBe(true)
}) })
@@ -459,22 +456,12 @@ describe('model utils', () => {
expect(agentModelFilter(createModel({ id: 'rerank' }))).toBe(false) expect(agentModelFilter(createModel({ id: 'rerank' }))).toBe(false)
}) })
it('filters out non-function-call models', () => {
rerankMock.mockReturnValue(false)
isFunctionCallingModelMock.mockReturnValueOnce(false)
expect(agentModelFilter(createModel({ id: 'DeepSeek R1' }))).toBe(false)
})
it('filters out text-to-image models', () => { it('filters out text-to-image models', () => {
rerankMock.mockReturnValue(false) rerankMock.mockReturnValue(false)
textToImageMock.mockReturnValueOnce(true) textToImageMock.mockReturnValueOnce(true)
expect(agentModelFilter(createModel({ id: 'gpt-image-1' }))).toBe(false) expect(agentModelFilter(createModel({ id: 'gpt-image-1' }))).toBe(false)
}) })
}) })
textToImageMock.mockReturnValue(false)
generateImageMock.mockReturnValueOnce(true)
expect(agentModelFilter(createModel({ id: 'dall-e-3' }))).toBe(false)
}) })
describe('Temperature limits', () => { describe('Temperature limits', () => {
-18
View File
@@ -1,8 +1,6 @@
import { getProviderByModel } from '@renderer/services/AssistantService'
import type { Model } from '@renderer/types' import type { Model } from '@renderer/types'
import { isSystemProviderId } from '@renderer/types' import { isSystemProviderId } from '@renderer/types'
import { getLowerBaseModelName, isUserSelectedModelType } from '@renderer/utils' import { getLowerBaseModelName, isUserSelectedModelType } from '@renderer/utils'
import { isAzureOpenAIProvider } from '@shared/provider'
import { isEmbeddingModel, isRerankModel } from './embedding' import { isEmbeddingModel, isRerankModel } from './embedding'
import { isDeepSeekHybridInferenceModel } from './reasoning' import { isDeepSeekHybridInferenceModel } from './reasoning'
@@ -54,13 +52,6 @@ export const FUNCTION_CALLING_REGEX = new RegExp(
'i' 'i'
) )
const AZURE_FUNCTION_CALLING_EXCLUDED_MODELS = [
'(?:Meta-)?Llama-3(?:\\.\\d+)?-[\\w-]+',
'Phi-[34](?:\\.[\\w-]+)?(?:-[\\w-]+)?',
'DeepSeek-(?:R1|V3)',
'Codestral-2501'
]
export function isFunctionCallingModel(model?: Model): boolean { export function isFunctionCallingModel(model?: Model): boolean {
if (!model || isEmbeddingModel(model) || isRerankModel(model) || isTextToImageModel(model)) { if (!model || isEmbeddingModel(model) || isRerankModel(model) || isTextToImageModel(model)) {
return false return false
@@ -76,15 +67,6 @@ export function isFunctionCallingModel(model?: Model): boolean {
return FUNCTION_CALLING_REGEX.test(modelId) || FUNCTION_CALLING_REGEX.test(model.name) return FUNCTION_CALLING_REGEX.test(modelId) || FUNCTION_CALLING_REGEX.test(model.name)
} }
const provider = getProviderByModel(model)
if (isAzureOpenAIProvider(provider)) {
const azureExcludedRegex = new RegExp(`\\b(?:${AZURE_FUNCTION_CALLING_EXCLUDED_MODELS.join('|')})\\b`, 'i')
if (azureExcludedRegex.test(modelId)) {
return false
}
}
if (['deepseek', 'anthropic', 'kimi', 'moonshot'].includes(model.provider)) { if (['deepseek', 'anthropic', 'kimi', 'moonshot'].includes(model.provider)) {
return true return true
} }
+1 -16
View File
@@ -1,6 +1,5 @@
import type OpenAI from '@cherrystudio/openai' import type OpenAI from '@cherrystudio/openai'
import { isEmbeddingModel, isRerankModel } from '@renderer/config/models/embedding' import { isEmbeddingModel, isRerankModel } from '@renderer/config/models/embedding'
import { getProviderByModel } from '@renderer/services/AssistantService'
import { type Model, SystemProviderIds } from '@renderer/types' import { type Model, SystemProviderIds } from '@renderer/types'
import type { OpenAIVerbosity, ValidOpenAIVerbosity } from '@renderer/types/aiCoreTypes' import type { OpenAIVerbosity, ValidOpenAIVerbosity } from '@renderer/types/aiCoreTypes'
import { getLowerBaseModelName } from '@renderer/utils' import { getLowerBaseModelName } from '@renderer/utils'
@@ -14,7 +13,6 @@ import {
isOpenAIReasoningModel isOpenAIReasoningModel
} from './openai' } from './openai'
import { isQwenMTModel } from './qwen' import { isQwenMTModel } from './qwen'
import { isFunctionCallingModel } from './tooluse'
import { isGenerateImageModel, isTextToImageModel, isVisionModel } from './vision' import { isGenerateImageModel, isTextToImageModel, isVisionModel } from './vision'
export const NOT_SUPPORTED_REGEX = /(?:^tts|whisper|speech)/i export const NOT_SUPPORTED_REGEX = /(?:^tts|whisper|speech)/i
export const GEMINI_FLASH_MODEL_REGEX = new RegExp('gemini.*-flash.*$', 'i') export const GEMINI_FLASH_MODEL_REGEX = new RegExp('gemini.*-flash.*$', 'i')
@@ -183,21 +181,8 @@ export const isGeminiModel = (model: Model) => {
// zhipu 视觉推理模型用这组 special token 标记推理结果 // zhipu 视觉推理模型用这组 special token 标记推理结果
export const ZHIPU_RESULT_TOKENS = ['<|begin_of_box|>', '<|end_of_box|>'] as const export const ZHIPU_RESULT_TOKENS = ['<|begin_of_box|>', '<|end_of_box|>'] as const
// TODO: 支持提示词模式的工具调用
export const agentModelFilter = (model: Model): boolean => { export const agentModelFilter = (model: Model): boolean => {
const provider = getProviderByModel(model) return !isEmbeddingModel(model) && !isRerankModel(model) && !isTextToImageModel(model)
// 需要适配,且容易超出限额
if (provider.id === SystemProviderIds.copilot) {
return false
}
return (
!isEmbeddingModel(model) &&
!isRerankModel(model) &&
!isTextToImageModel(model) &&
!isGenerateImageModel(model) &&
isFunctionCallingModel(model)
)
} }
export const isMaxTemperatureOneModel = (model: Model): boolean => { export const isMaxTemperatureOneModel = (model: Model): boolean => {
@@ -17,7 +17,7 @@ import type { EndpointType, Model } from '@renderer/types'
import { getClaudeSupportedProviders } from '@renderer/utils/provider' import { getClaudeSupportedProviders } from '@renderer/utils/provider'
import type { TerminalConfig } from '@shared/config/constant' import type { TerminalConfig } from '@shared/config/constant'
import { codeTools, terminalApps } from '@shared/config/constant' import { codeTools, terminalApps } from '@shared/config/constant'
import { isPpioAnthropicCompatibleModel, isSiliconAnthropicCompatibleModel } from '@shared/config/providers' import { isSiliconAnthropicCompatibleModel } from '@shared/config/providers'
import { Alert, Avatar, Button, Checkbox, Input, Popover, Select, Space, Tooltip } from 'antd' import { Alert, Avatar, Button, Checkbox, Input, Popover, Select, Space, Tooltip } from 'antd'
import { ArrowUpRight, Download, FolderOpen, HelpCircle, Terminal, X } from 'lucide-react' import { ArrowUpRight, Download, FolderOpen, HelpCircle, Terminal, X } from 'lucide-react'
import type { FC } from 'react' import type { FC } from 'react'
@@ -82,12 +82,10 @@ const CodeToolsPage: FC = () => {
if (m.supported_endpoint_types) { if (m.supported_endpoint_types) {
return m.supported_endpoint_types.includes('anthropic') return m.supported_endpoint_types.includes('anthropic')
} }
// Special handling for silicon provider: only specific models support Anthropic API
if (m.provider === 'silicon') { if (m.provider === 'silicon') {
return isSiliconAnthropicCompatibleModel(m.id) return isSiliconAnthropicCompatibleModel(m.id)
} }
if (m.provider === 'ppio') {
return isPpioAnthropicCompatibleModel(m.id)
}
return m.id.includes('claude') || CLAUDE_OFFICIAL_SUPPORTED_PROVIDERS.includes(m.provider) return m.id.includes('claude') || CLAUDE_OFFICIAL_SUPPORTED_PROVIDERS.includes(m.provider)
} }
@@ -85,8 +85,7 @@ const ANTHROPIC_COMPATIBLE_PROVIDER_IDS = [
SystemProviderIds.minimax, SystemProviderIds.minimax,
SystemProviderIds.silicon, SystemProviderIds.silicon,
SystemProviderIds.qiniu, SystemProviderIds.qiniu,
SystemProviderIds.dmxapi, SystemProviderIds.dmxapi
SystemProviderIds.ppio
] as const ] as const
type AnthropicCompatibleProviderId = (typeof ANTHROPIC_COMPATIBLE_PROVIDER_IDS)[number] type AnthropicCompatibleProviderId = (typeof ANTHROPIC_COMPATIBLE_PROVIDER_IDS)[number]
+1 -1
View File
@@ -67,7 +67,7 @@ const persistedReducer = persistReducer(
{ {
key: 'cherry-studio', key: 'cherry-studio',
storage, storage,
version: 180, version: 179,
blacklist: ['runtime', 'messages', 'messageBlocks', 'tabs', 'toolPermissions'], blacklist: ['runtime', 'messages', 'messageBlocks', 'tabs', 'toolPermissions'],
migrate migrate
}, },
-14
View File
@@ -2906,20 +2906,6 @@ const migrateConfig = {
logger.error('migrate 179 error', error as Error) logger.error('migrate 179 error', error as Error)
return state return state
} }
},
'180': (state: RootState) => {
try {
state.llm.providers.forEach((provider) => {
if (provider.id === SystemProviderIds.ppio) {
provider.anthropicApiHost = 'https://api.ppinfra.com/anthropic'
}
})
logger.info('migrate 180 success')
return state
} catch (error) {
logger.error('migrate 180 error', error as Error)
return state
}
} }
} }
+1 -3
View File
@@ -7,8 +7,6 @@ import type { CSSProperties } from 'react'
export * from './file' export * from './file'
export * from './note' export * from './note'
import type { MinimalModel } from '@shared/provider/types'
import type { StreamTextParams } from './aiCoreTypes' import type { StreamTextParams } from './aiCoreTypes'
import type { Chunk } from './chunk' import type { Chunk } from './chunk'
import type { FileMetadata } from './file' import type { FileMetadata } from './file'
@@ -259,7 +257,7 @@ export type ModelCapability = {
isUserSelected?: boolean isUserSelected?: boolean
} }
export type Model = MinimalModel & { export type Model = {
id: string id: string
provider: string provider: string
name: string name: string
+151 -7
View File
@@ -1,14 +1,24 @@
import type OpenAI from '@cherrystudio/openai' import type OpenAI from '@cherrystudio/openai'
import type { MinimalProvider } from '@shared/provider'
import type { ProviderType, SystemProviderId, SystemProviderIdTypeMap } from '@shared/provider/types'
import { isSystemProviderId, SystemProviderIds } from '@shared/provider/types'
import type { Model } from '@types' import type { Model } from '@types'
import * as z from 'zod'
import type { OpenAIVerbosity } from './aiCoreTypes' import type { OpenAIVerbosity } from './aiCoreTypes'
export type { ProviderType } from '@shared/provider' export const ProviderTypeSchema = z.enum([
export type { SystemProviderId, SystemProviderIdTypeMap } from '@shared/provider/types' 'openai',
export { isSystemProviderId, ProviderTypeSchema, SystemProviderIds } from '@shared/provider/types' 'openai-response',
'anthropic',
'gemini',
'azure-openai',
'vertexai',
'mistral',
'aws-bedrock',
'vertex-anthropic',
'new-api',
'ai-gateway'
])
export type ProviderType = z.infer<typeof ProviderTypeSchema>
// undefined is treated as supported, enabled by default // undefined is treated as supported, enabled by default
export type ProviderApiOptions = { export type ProviderApiOptions = {
@@ -83,7 +93,7 @@ export function isAwsBedrockAuthType(type: string): type is AwsBedrockAuthType {
return Object.hasOwn(AwsBedrockAuthTypes, type) return Object.hasOwn(AwsBedrockAuthTypes, type)
} }
export type Provider = MinimalProvider & { export type Provider = {
id: string id: string
type: ProviderType type: ProviderType
name: string name: string
@@ -118,6 +128,140 @@ export type Provider = MinimalProvider & {
extra_headers?: Record<string, string> extra_headers?: Record<string, string>
} }
export const SystemProviderIdSchema = z.enum([
'cherryin',
'silicon',
'aihubmix',
'ocoolai',
'deepseek',
'ppio',
'alayanew',
'qiniu',
'dmxapi',
'burncloud',
'tokenflux',
'302ai',
'cephalon',
'lanyun',
'ph8',
'openrouter',
'ollama',
'ovms',
'new-api',
'lmstudio',
'anthropic',
'openai',
'azure-openai',
'gemini',
'vertexai',
'github',
'copilot',
'zhipu',
'yi',
'moonshot',
'baichuan',
'dashscope',
'stepfun',
'doubao',
'infini',
'minimax',
'groq',
'together',
'fireworks',
'nvidia',
'grok',
'hyperbolic',
'mistral',
'jina',
'perplexity',
'modelscope',
'xirang',
'hunyuan',
'tencent-cloud-ti',
'baidu-cloud',
'gpustack',
'voyageai',
'aws-bedrock',
'poe',
'aionly',
'longcat',
'huggingface',
'sophnet',
'ai-gateway',
'cerebras'
])
export type SystemProviderId = z.infer<typeof SystemProviderIdSchema>
export const isSystemProviderId = (id: string): id is SystemProviderId => {
return SystemProviderIdSchema.safeParse(id).success
}
export const SystemProviderIds = {
cherryin: 'cherryin',
silicon: 'silicon',
aihubmix: 'aihubmix',
ocoolai: 'ocoolai',
deepseek: 'deepseek',
ppio: 'ppio',
alayanew: 'alayanew',
qiniu: 'qiniu',
dmxapi: 'dmxapi',
burncloud: 'burncloud',
tokenflux: 'tokenflux',
'302ai': '302ai',
cephalon: 'cephalon',
lanyun: 'lanyun',
ph8: 'ph8',
sophnet: 'sophnet',
openrouter: 'openrouter',
ollama: 'ollama',
ovms: 'ovms',
'new-api': 'new-api',
lmstudio: 'lmstudio',
anthropic: 'anthropic',
openai: 'openai',
'azure-openai': 'azure-openai',
gemini: 'gemini',
vertexai: 'vertexai',
github: 'github',
copilot: 'copilot',
zhipu: 'zhipu',
yi: 'yi',
moonshot: 'moonshot',
baichuan: 'baichuan',
dashscope: 'dashscope',
stepfun: 'stepfun',
doubao: 'doubao',
infini: 'infini',
minimax: 'minimax',
groq: 'groq',
together: 'together',
fireworks: 'fireworks',
nvidia: 'nvidia',
grok: 'grok',
hyperbolic: 'hyperbolic',
mistral: 'mistral',
jina: 'jina',
perplexity: 'perplexity',
modelscope: 'modelscope',
xirang: 'xirang',
hunyuan: 'hunyuan',
'tencent-cloud-ti': 'tencent-cloud-ti',
'baidu-cloud': 'baidu-cloud',
gpustack: 'gpustack',
voyageai: 'voyageai',
'aws-bedrock': 'aws-bedrock',
poe: 'poe',
aionly: 'aionly',
longcat: 'longcat',
huggingface: 'huggingface',
'ai-gateway': 'ai-gateway',
cerebras: 'cerebras'
} as const satisfies Record<SystemProviderId, SystemProviderId>
type SystemProviderIdTypeMap = typeof SystemProviderIds
export type SystemProvider = Provider & { export type SystemProvider = Provider & {
id: SystemProviderId id: SystemProviderId
isSystem: true isSystem: true
+12 -1
View File
@@ -302,7 +302,18 @@ describe('api', () => {
}) })
it('uses global endpoint when location equals global', () => { it('uses global endpoint when location equals global', () => {
expect(formatVertexApiHost(createVertexProvider(''), 'global-project', 'global')).toBe( getStateMock.mockReturnValueOnce({
llm: {
settings: {
vertexai: {
projectId: 'global-project',
location: 'global'
}
}
}
})
expect(formatVertexApiHost(createVertexProvider(''))).toBe(
'https://aiplatform.googleapis.com/v1/projects/global-project/locations/global' 'https://aiplatform.googleapis.com/v1/projects/global-project/locations/global'
) )
}) })
+224 -14
View File
@@ -1,17 +1,6 @@
export { import store from '@renderer/store'
formatApiHost, import type { VertexProvider } from '@renderer/types'
formatAzureOpenAIApiHost, import { trim } from 'lodash'
formatVertexApiHost,
getAiSdkBaseUrl,
getTrailingApiVersion,
hasAPIVersion,
routeToEndpoint,
SUPPORTED_ENDPOINT_LIST,
SUPPORTED_IMAGE_ENDPOINT_LIST,
validateApiHost,
withoutTrailingApiVersion,
withoutTrailingSlash
} from '@shared/api'
/** /**
* API key * API key
@@ -23,6 +12,180 @@ export function formatApiKeys(value: string): string {
return value.replaceAll('', ',').replaceAll('\n', ',') return value.replaceAll('', ',').replaceAll('\n', ',')
} }
/**
* Matches a version segment in a path that starts with `/v<number>` and optionally
* continues with `alpha` or `beta`. The segment may be followed by `/` or the end
* of the string (useful for cases like `/v3alpha/resources`).
*/
const VERSION_REGEX_PATTERN = '\\/v\\d+(?:alpha|beta)?(?=\\/|$)'
/**
* Matches an API version at the end of a URL (with optional trailing slash).
* Used to detect and extract versions only from the trailing position.
*/
const TRAILING_VERSION_REGEX = /\/v\d+(?:alpha|beta)?\/?$/i
/**
* host path /v1/v2beta
*
* @param host - host path
* @returns path true false
*/
export function hasAPIVersion(host?: string): boolean {
if (!host) return false
const regex = new RegExp(VERSION_REGEX_PATTERN, 'i')
try {
const url = new URL(host)
return regex.test(url.pathname)
} catch {
// 若无法作为完整 URL 解析,则当作路径直接检测
return regex.test(host)
}
}
/**
* Removes the trailing slash from a URL string if it exists.
*
* @template T - The string type to preserve type safety
* @param {T} url - The URL string to process
* @returns {T} The URL string without a trailing slash
*
* @example
* ```ts
* withoutTrailingSlash('https://example.com/') // 'https://example.com'
* withoutTrailingSlash('https://example.com') // 'https://example.com'
* ```
*/
export function withoutTrailingSlash<T extends string>(url: T): T {
return url.replace(/\/$/, '') as T
}
/**
* Formats an API host URL by normalizing it and optionally appending an API version.
*
* @param host - The API host URL to format. Leading/trailing whitespace will be trimmed and trailing slashes removed.
* @param supportApiVersion - Whether the API version is supported. Defaults to `true`.
* @param apiVersion - The API version to append if needed. Defaults to `'v1'`.
*
* @returns The formatted API host URL. If the host is empty after normalization, returns an empty string.
* If the host ends with '#', API version is not supported, or the host already contains a version, returns the normalized host as-is.
* Otherwise, returns the host with the API version appended.
*
* @example
* formatApiHost('https://api.example.com/') // Returns 'https://api.example.com/v1'
* formatApiHost('https://api.example.com#') // Returns 'https://api.example.com#'
* formatApiHost('https://api.example.com/v2', true, 'v1') // Returns 'https://api.example.com/v2'
*/
export function formatApiHost(host?: string, supportApiVersion: boolean = true, apiVersion: string = 'v1'): string {
const normalizedHost = withoutTrailingSlash(trim(host))
if (!normalizedHost) {
return ''
}
if (normalizedHost.endsWith('#') || !supportApiVersion || hasAPIVersion(normalizedHost)) {
return normalizedHost
}
return `${normalizedHost}/${apiVersion}`
}
/**
* Azure OpenAI API
*/
export function formatAzureOpenAIApiHost(host: string): string {
const normalizedHost = withoutTrailingSlash(host)
?.replace(/\/v1$/, '')
.replace(/\/openai$/, '')
// NOTE: AISDK会添加上`v1`
return formatApiHost(normalizedHost + '/openai', false)
}
export function formatVertexApiHost(provider: VertexProvider): string {
const { apiHost } = provider
const { projectId: project, location } = store.getState().llm.settings.vertexai
const trimmedHost = withoutTrailingSlash(trim(apiHost))
if (!trimmedHost || trimmedHost.endsWith('aiplatform.googleapis.com')) {
const host =
location == 'global' ? 'https://aiplatform.googleapis.com' : `https://${location}-aiplatform.googleapis.com`
return `${formatApiHost(host)}/projects/${project}/locations/${location}`
}
return formatApiHost(trimmedHost)
}
// 目前对话界面只支持这些端点
export const SUPPORTED_IMAGE_ENDPOINT_LIST = ['images/generations', 'images/edits', 'predict'] as const
export const SUPPORTED_ENDPOINT_LIST = [
'chat/completions',
'responses',
'messages',
'generateContent',
'streamGenerateContent',
...SUPPORTED_IMAGE_ENDPOINT_LIST
] as const
/**
* Converts an API host URL into separate base URL and endpoint components.
*
* @param apiHost - The API host string to parse. Expected to be a trimmed URL that may end with '#' followed by an endpoint identifier.
* @returns An object containing:
* - `baseURL`: The base URL without the endpoint suffix
* - `endpoint`: The matched endpoint identifier, or empty string if no match found
*
* @description
* This function extracts endpoint information from a composite API host string.
* If the host ends with '#', it attempts to match the preceding part against the supported endpoint list.
* The '#' delimiter is removed before processing.
*
* @example
* routeToEndpoint('https://api.example.com/openai/chat/completions#')
* // Returns: { baseURL: 'https://api.example.com/v1', endpoint: 'chat/completions' }
*
* @example
* routeToEndpoint('https://api.example.com/v1')
* // Returns: { baseURL: 'https://api.example.com/v1', endpoint: '' }
*/
export function routeToEndpoint(apiHost: string): { baseURL: string; endpoint: string } {
const trimmedHost = trim(apiHost)
// 前面已经确保apiHost合法
if (!trimmedHost.endsWith('#')) {
return { baseURL: trimmedHost, endpoint: '' }
}
// 去掉结尾的 #
const host = trimmedHost.slice(0, -1)
const endpointMatch = SUPPORTED_ENDPOINT_LIST.find((endpoint) => host.endsWith(endpoint))
if (!endpointMatch) {
const baseURL = withoutTrailingSlash(host)
return { baseURL, endpoint: '' }
}
const baseSegment = host.slice(0, host.length - endpointMatch.length)
const baseURL = withoutTrailingSlash(baseSegment).replace(/:$/, '') // 去掉结尾可能存在的冒号(gemini的特殊情况)
return { baseURL, endpoint: endpointMatch }
}
/**
* API
*
* @param {string} apiHost - API
* @returns {boolean} URL true false
*/
export function validateApiHost(apiHost: string): boolean {
// 允许apiHost为空
if (!apiHost || !trim(apiHost)) {
return true
}
try {
const url = new URL(trim(apiHost))
// 验证协议是否为 http 或 https
if (url.protocol !== 'http:' && url.protocol !== 'https:') {
return false
}
return true
} catch {
return false
}
}
/** /**
* API key * API key
* *
@@ -61,3 +224,50 @@ export function splitApiKeyString(keyStr: string): string[] {
.map((k) => k.replace(/\\,/g, ',')) .map((k) => k.replace(/\\,/g, ','))
.filter((k) => k) .filter((k) => k)
} }
/**
* Extracts the trailing API version segment from a URL path.
*
* This function extracts API version patterns (e.g., `v1`, `v2beta`) from the end of a URL.
* Only versions at the end of the path are extracted, not versions in the middle.
* The returned version string does not include leading or trailing slashes.
*
* @param {string} url - The URL string to parse.
* @returns {string | undefined} The trailing API version found (e.g., 'v1', 'v2beta'), or undefined if none found.
*
* @example
* getTrailingApiVersion('https://api.example.com/v1') // 'v1'
* getTrailingApiVersion('https://api.example.com/v2beta/') // 'v2beta'
* getTrailingApiVersion('https://api.example.com/v1/chat') // undefined (version not at end)
* getTrailingApiVersion('https://gateway.ai.cloudflare.com/v1/xxx/v1beta') // 'v1beta'
* getTrailingApiVersion('https://api.example.com') // undefined
*/
export function getTrailingApiVersion(url: string): string | undefined {
const match = url.match(TRAILING_VERSION_REGEX)
if (match) {
// Extract version without leading slash and trailing slash
return match[0].replace(/^\//, '').replace(/\/$/, '')
}
return undefined
}
/**
* Removes the trailing API version segment from a URL path.
*
* This function removes API version patterns (e.g., `/v1`, `/v2beta`) from the end of a URL.
* Only versions at the end of the path are removed, not versions in the middle.
*
* @param {string} url - The URL string to process.
* @returns {string} The URL with the trailing API version removed, or the original URL if no trailing version found.
*
* @example
* withoutTrailingApiVersion('https://api.example.com/v1') // 'https://api.example.com'
* withoutTrailingApiVersion('https://api.example.com/v2beta/') // 'https://api.example.com'
* withoutTrailingApiVersion('https://api.example.com/v1/chat') // 'https://api.example.com/v1/chat' (no change)
* withoutTrailingApiVersion('https://api.example.com') // 'https://api.example.com'
*/
export function withoutTrailingApiVersion(url: string): string {
return url.replace(TRAILING_VERSION_REGEX, '')
}
+32 -2
View File
@@ -2,8 +2,6 @@ import { getProviderLabel } from '@renderer/i18n/label'
import type { Provider } from '@renderer/types' import type { Provider } from '@renderer/types'
import { isSystemProvider } from '@renderer/types' import { isSystemProvider } from '@renderer/types'
export { getBaseModelName, getLowerBaseModelName } from '@shared/utils/naming'
/** /**
* ID * ID
* *
@@ -52,6 +50,38 @@ export const getDefaultGroupName = (id: string, provider?: string): string => {
return str return str
} }
/**
* ID
*
* - 'deepseek/deepseek-r1' => 'deepseek-r1'
* - 'deepseek-ai/deepseek/deepseek-r1' => 'deepseek-r1'
* @param {string} id ID
* @param {string} [delimiter='/'] '/'
* @returns {string}
*/
export const getBaseModelName = (id: string, delimiter: string = '/'): string => {
const parts = id.split(delimiter)
return parts[parts.length - 1]
}
/**
* ID
*
* - 'deepseek/DeepSeek-R1' => 'deepseek-r1'
* - 'deepseek-ai/deepseek/DeepSeek-R1' => 'deepseek-r1'
* @param {string} id ID
* @param {string} [delimiter='/'] '/'
* @returns {string}
*/
export const getLowerBaseModelName = (id: string, delimiter: string = '/'): string => {
const baseModelName = getBaseModelName(id, delimiter).toLowerCase()
// for openrouter
if (baseModelName.endsWith(':free')) {
return baseModelName.replace(':free', '')
}
return baseModelName
}
/** /**
* *
* @param provider * @param provider
+54 -15
View File
@@ -1,20 +1,10 @@
import { CLAUDE_SUPPORTED_PROVIDERS } from '@renderer/pages/code' import { CLAUDE_SUPPORTED_PROVIDERS } from '@renderer/pages/code'
import type { ProviderType } from '@renderer/types' import type { AzureOpenAIProvider, ProviderType, VertexProvider } from '@renderer/types'
import { isSystemProvider, type Provider, type SystemProviderId, SystemProviderIds } from '@renderer/types' import { isSystemProvider, type Provider, type SystemProviderId, SystemProviderIds } from '@renderer/types'
export {
isAIGatewayProvider, export const isAzureResponsesEndpoint = (provider: AzureOpenAIProvider) => {
isAnthropicProvider, return provider.apiVersion === 'preview' || provider.apiVersion === 'v1'
isAwsBedrockProvider, }
isAzureOpenAIProvider,
isAzureResponsesEndpoint,
isCherryAIProvider,
isGeminiProvider,
isNewApiProvider,
isOpenAICompatibleProvider,
isOpenAIProvider,
isPerplexityProvider,
isVertexProvider
} from '@shared/provider'
export const getClaudeSupportedProviders = (providers: Provider[]) => { export const getClaudeSupportedProviders = (providers: Provider[]) => {
return providers.filter( return providers.filter(
@@ -136,6 +126,55 @@ export const isGeminiWebSearchProvider = (provider: Provider) => {
return SUPPORT_GEMINI_NATIVE_WEB_SEARCH_PROVIDERS.some((id) => id === provider.id) return SUPPORT_GEMINI_NATIVE_WEB_SEARCH_PROVIDERS.some((id) => id === provider.id)
} }
export const isNewApiProvider = (provider: Provider) => {
return ['new-api', 'cherryin'].includes(provider.id) || provider.type === 'new-api'
}
export function isCherryAIProvider(provider: Provider): boolean {
return provider.id === 'cherryai'
}
export function isPerplexityProvider(provider: Provider): boolean {
return provider.id === 'perplexity'
}
/**
* OpenAI
* @param {Provider} provider
* @returns {boolean} OpenAI
*/
export function isOpenAICompatibleProvider(provider: Provider): boolean {
return ['openai', 'new-api', 'mistral'].includes(provider.type)
}
export function isAzureOpenAIProvider(provider: Provider): provider is AzureOpenAIProvider {
return provider.type === 'azure-openai'
}
export function isOpenAIProvider(provider: Provider): boolean {
return provider.type === 'openai-response'
}
export function isVertexProvider(provider: Provider): provider is VertexProvider {
return provider.type === 'vertexai'
}
export function isAwsBedrockProvider(provider: Provider): boolean {
return provider.type === 'aws-bedrock'
}
export function isAnthropicProvider(provider: Provider): boolean {
return provider.type === 'anthropic'
}
export function isGeminiProvider(provider: Provider): boolean {
return provider.type === 'gemini'
}
export function isAIGatewayProvider(provider: Provider): boolean {
return provider.type === 'ai-gateway'
}
const NOT_SUPPORT_API_VERSION_PROVIDERS = ['github', 'copilot', 'perplexity'] as const satisfies SystemProviderId[] const NOT_SUPPORT_API_VERSION_PROVIDERS = ['github', 'copilot', 'perplexity'] as const satisfies SystemProviderId[]
export const isSupportAPIVersionProvider = (provider: Provider) => { export const isSupportAPIVersionProvider = (provider: Provider) => {
+1 -8
View File
@@ -8,10 +8,8 @@
"src/preload/**/*", "src/preload/**/*",
"src/renderer/src/services/traceApi.ts", "src/renderer/src/services/traceApi.ts",
"src/renderer/src/types/*", "src/renderer/src/types/*",
"packages/aiCore/src/**/*",
"packages/mcp-trace/**/*", "packages/mcp-trace/**/*",
"packages/shared/**/*", "packages/shared/**/*",
"packages/ai-sdk-provider/**/*"
], ],
"compilerOptions": { "compilerOptions": {
"composite": true, "composite": true,
@@ -28,12 +26,7 @@
"@types": ["./src/renderer/src/types/index.ts"], "@types": ["./src/renderer/src/types/index.ts"],
"@shared/*": ["./packages/shared/*"], "@shared/*": ["./packages/shared/*"],
"@mcp-trace/*": ["./packages/mcp-trace/*"], "@mcp-trace/*": ["./packages/mcp-trace/*"],
"@modelcontextprotocol/sdk/*": ["./node_modules/@modelcontextprotocol/sdk/dist/esm/*"], "@modelcontextprotocol/sdk/*": ["./node_modules/@modelcontextprotocol/sdk/dist/esm/*"]
"@cherrystudio/ai-core/provider": ["./packages/aiCore/src/core/providers/index.ts"],
"@cherrystudio/ai-core/built-in/plugins": ["./packages/aiCore/src/core/plugins/built-in/index.ts"],
"@cherrystudio/ai-core/*": ["./packages/aiCore/src/*"],
"@cherrystudio/ai-core": ["./packages/aiCore/src/index.ts"],
"@cherrystudio/ai-sdk-provider": ["./packages/ai-sdk-provider/src/index.ts"]
}, },
"experimentalDecorators": true, "experimentalDecorators": true,
"emitDecoratorMetadata": true, "emitDecoratorMetadata": true,