Compare commits

..

4 Commits

Author SHA1 Message Date
suyao
7f8d0b06ee Merge branch 'main' into fix/check-api-key 2025-12-01 16:37:43 +08:00
Phantom
3e6dc56196 fix(api): add withoutTrailingSharp utility and fix # handling in formatApiHost (#11604)
* docs(providerConfig): improve jsdoc for formatProviderApiHost function

* refactor(aiCore): improve provider handling with adaptProvider function

Introduce adaptProvider to centralize provider transformations and replace direct usage of handleSpecialProviders and formatProviderApiHost. This improves maintainability and provides consistent behavior across all provider usage scenarios.

* refactor(ProviderSettings): simplify api host formatting logic by using adaptProvider

Replace multiple format functions with a single adaptProvider utility to centralize host formatting logic and improve maintainability

* feat(api): add withoutTrailingSharp utility and update formatApiHost

add utility function to remove trailing # from URLs and update formatApiHost to use it
add comprehensive tests for new functionality

* feat(ProviderSetting): add help tooltip for api url selector

Add HelpTooltip component next to host selector to provide additional guidance about API URL configuration
2025-12-01 16:27:33 +08:00
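A minimal sketch of what a `withoutTrailingSharp` utility like the one this commit describes could look like (hypothetical reconstruction mirroring the `withoutTrailingSlash` helper shown later in this diff; the actual implementation is not part of the hunks below):

```typescript
/**
 * Removes the trailing '#' from a URL string if it exists (sketch).
 */
export function withoutTrailingSharp<T extends string>(url: T): T {
  return url.replace(/#$/, '') as T
}
```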
suyao
4be5fedeec fix 2025-12-01 00:07:43 +08:00
suyao
163e016759 fix: enhance provider handling and API key rotation logic in AiProvider 2025-12-01 00:01:01 +08:00
77 changed files with 1387 additions and 4080 deletions

.gitignore vendored
View File

@@ -73,5 +73,3 @@ test-results
YOUR_MEMORY_FILE_PATH
.sessions/
.next/
*.tsbuildinfo

View File

@@ -25,10 +25,7 @@ export default defineConfig({
'@shared': resolve('packages/shared'),
'@logger': resolve('src/main/services/LoggerService'),
'@mcp-trace/trace-core': resolve('packages/mcp-trace/trace-core'),
'@mcp-trace/trace-node': resolve('packages/mcp-trace/trace-node'),
'@cherrystudio/ai-core/provider': resolve('packages/aiCore/src/core/providers'),
'@cherrystudio/ai-core': resolve('packages/aiCore/src'),
'@cherrystudio/ai-sdk-provider': resolve('packages/ai-sdk-provider/src')
'@mcp-trace/trace-node': resolve('packages/mcp-trace/trace-node')
}
},
build: {

View File

@@ -9,27 +9,13 @@
*/
import Anthropic from '@anthropic-ai/sdk'
import type { MessageCreateParams, TextBlockParam, Tool as AnthropicTool } from '@anthropic-ai/sdk/resources'
import type { TextBlockParam } from '@anthropic-ai/sdk/resources'
import { loggerService } from '@logger'
import { type Provider, SystemProviderIds } from '@types'
import type { Provider } from '@types'
import type { ModelMessage } from 'ai'
const logger = loggerService.withContext('anthropic-sdk')
/**
* Context for Anthropic SDK client creation.
* This allows the shared module to be used in different environments
* by providing environment-specific implementations.
*/
export interface AnthropicSdkContext {
/**
* Custom fetch function to use for HTTP requests.
* In Electron main process, this should be `net.fetch`.
* In other environments, can use the default fetch or a custom implementation.
*/
fetch?: typeof globalThis.fetch
}
const defaultClaudeCodeSystemPrompt = `You are Claude Code, Anthropic's official CLI for Claude.`
const defaultClaudeCodeSystem: Array<TextBlockParam> = [
@@ -72,11 +58,8 @@ const defaultClaudeCodeSystem: Array<TextBlockParam> = [
export function getSdkClient(
provider: Provider,
oauthToken?: string | null,
extraHeaders?: Record<string, string | string[]>,
context?: AnthropicSdkContext
extraHeaders?: Record<string, string | string[]>
): Anthropic {
const customFetch = context?.fetch
if (provider.authType === 'oauth') {
if (!oauthToken) {
throw new Error('OAuth token is not available')
@@ -102,8 +85,7 @@ export function getSdkClient(
'x-stainless-runtime': 'node',
'x-stainless-runtime-version': 'v22.18.0',
...extraHeaders
},
fetch: customFetch
}
})
}
let baseURL =
@@ -124,12 +106,11 @@ export function getSdkClient(
baseURL,
dangerouslyAllowBrowser: true,
defaultHeaders: {
'anthropic-beta': 'interleaved-thinking-2025-05-14',
'anthropic-beta': 'output-128k-2025-02-19',
'APP-Code': 'MLTG2087',
...provider.extra_headers,
...extraHeaders
},
fetch: customFetch
}
})
}
@@ -139,11 +120,9 @@ export function getSdkClient(
baseURL,
dangerouslyAllowBrowser: true,
defaultHeaders: {
'anthropic-beta': 'interleaved-thinking-2025-05-14',
Authorization: provider.id === SystemProviderIds.longcat ? `Bearer ${provider.apiKey}` : undefined,
'anthropic-beta': 'output-128k-2025-02-19',
...provider.extra_headers
},
fetch: customFetch
}
})
}
@@ -194,31 +173,3 @@ export function buildClaudeCodeSystemModelMessage(system?: string | Array<TextBl
content: block.text
}))
}
/**
* Sanitize tool definitions for Anthropic API.
*
* Removes non-standard fields like `input_examples` from tool definitions
* that Anthropic's API doesn't support. This prevents validation errors when
* tools with extended fields are passed to the Anthropic SDK.
*
* @param tools - Array of tool definitions from MessageCreateParams
* @returns Sanitized tools array with non-standard fields removed
*
* @example
* ```typescript
* const sanitizedTools = sanitizeToolsForAnthropic(request.tools)
* ```
*/
export function sanitizeToolsForAnthropic(tools?: MessageCreateParams['tools']): MessageCreateParams['tools'] {
if (!tools || tools.length === 0) return tools
return tools.map((tool) => {
if ('type' in tool && tool.type !== 'custom') return tool
// oxlint-disable-next-line no-unused-vars
const { input_examples, ...sanitizedTool } = tool as AnthropicTool & { input_examples?: unknown }
return sanitizedTool as typeof tool
})
}
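For reference, a hedged sketch of how the removed `sanitizeToolsForAnthropic` helper pairs with `getSdkClient`, following its own JSDoc example (the model ID and header values are placeholders, not from this diff):

```typescript
async function example(provider: Provider, requestTools: MessageCreateParams['tools']) {
  // Create a client for the provider, then send sanitized tool definitions.
  const client = getSdkClient(provider, null, { 'x-example-header': 'demo' })
  return client.messages.create({
    model: 'claude-sonnet-4-20250514', // placeholder model id
    max_tokens: 1024,
    messages: [{ role: 'user', content: 'Hello' }],
    tools: sanitizeToolsForAnthropic(requestTools) // strips input_examples etc.
  })
}
```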

View File

@@ -1,245 +0,0 @@
/**
* Shared API Utilities
*
* Common utilities for API URL formatting and validation.
* Used by both main process (API Server) and renderer.
*/
import type { MinimalProvider } from '@shared/provider'
import { trim } from 'lodash'
// Supported endpoints for routing
export const SUPPORTED_IMAGE_ENDPOINT_LIST = ['images/generations', 'images/edits', 'predict'] as const
export const SUPPORTED_ENDPOINT_LIST = [
'chat/completions',
'responses',
'messages',
'generateContent',
'streamGenerateContent',
...SUPPORTED_IMAGE_ENDPOINT_LIST
] as const
/**
* Removes the trailing slash from a URL string if it exists.
*/
export function withoutTrailingSlash<T extends string>(url: T): T {
return url.replace(/\/$/, '') as T
}
/**
* Matches a version segment in a path that starts with `/v<number>` and optionally
* continues with `alpha` or `beta`. The segment may be followed by `/` or the end
* of the string (useful for cases like `/v3alpha/resources`).
*/
const VERSION_REGEX_PATTERN = '\\/v\\d+(?:alpha|beta)?(?=\\/|$)'
/**
* Matches an API version at the end of a URL (with optional trailing slash).
* Used to detect and extract versions only from the trailing position.
*/
const TRAILING_VERSION_REGEX = /\/v\d+(?:alpha|beta)?\/?$/i
/**
* Checks whether the host's path contains a version-like segment (e.g. /v1, /v2beta).
*
* @param host - The host or path string to check
* @returns true if the path contains a version segment, false otherwise
*/
export function hasAPIVersion(host?: string): boolean {
if (!host) return false
const regex = new RegExp(VERSION_REGEX_PATTERN, 'i')
try {
const url = new URL(host)
return regex.test(url.pathname)
} catch {
// If it cannot be parsed as a full URL, test it directly as a path
return regex.test(host)
}
}
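// Example behavior (sketch, per VERSION_REGEX_PATTERN above):
// hasAPIVersion('https://api.example.com/v1')        // true
// hasAPIVersion('https://api.example.com/v2beta/x')  // true (version mid-path)
// hasAPIVersion('https://api.example.com')           // false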
/**
* Formats an Azure OpenAI API host address.
*/
export function formatAzureOpenAIApiHost(host: string): string {
const normalizedHost = withoutTrailingSlash(host)
?.replace(/\/v1$/, '')
.replace(/\/openai$/, '')
// NOTE: the AI SDK appends the `v1` segment itself
return formatApiHost(normalizedHost + '/openai', false)
}
export function formatVertexApiHost(
provider: MinimalProvider,
project: string = 'test-project',
location: string = 'us-central1'
): string {
const { apiHost } = provider
const trimmedHost = withoutTrailingSlash(trim(apiHost))
if (!trimmedHost || trimmedHost.endsWith('aiplatform.googleapis.com')) {
const host =
location === 'global' ? 'https://aiplatform.googleapis.com' : `https://${location}-aiplatform.googleapis.com`
return `${formatApiHost(host)}/projects/${project}/locations/${location}`
}
return formatApiHost(trimmedHost)
}
/**
* Formats an API host URL by normalizing it and optionally appending an API version.
*
* @param host - The API host URL to format. Leading/trailing whitespace will be trimmed and trailing slashes removed.
* @param supportApiVersion - Whether the API version is supported. Defaults to `true`.
* @param apiVersion - The API version to append if needed. Defaults to `'v1'`.
*
* @returns The formatted API host URL. If the host is empty after normalization, returns an empty string.
* If the host ends with '#', API version is not supported, or the host already contains a version, returns the normalized host as-is.
* Otherwise, returns the host with the API version appended.
*
* @example
* formatApiHost('https://api.example.com/') // Returns 'https://api.example.com/v1'
* formatApiHost('https://api.example.com#') // Returns 'https://api.example.com#'
* formatApiHost('https://api.example.com/v2', true, 'v1') // Returns 'https://api.example.com/v2'
*/
export function formatApiHost(host?: string, supportApiVersion: boolean = true, apiVersion: string = 'v1'): string {
const normalizedHost = withoutTrailingSlash(trim(host))
if (!normalizedHost) {
return ''
}
if (normalizedHost.endsWith('#') || !supportApiVersion || hasAPIVersion(normalizedHost)) {
return normalizedHost
}
return `${normalizedHost}/${apiVersion}`
}
/**
* Converts an API host URL into separate base URL and endpoint components.
*
* This function extracts endpoint information from a composite API host string.
* If the host ends with '#', it attempts to match the preceding part against the supported endpoint list.
*
* @param apiHost - The API host string to parse
* @returns An object containing:
* - `baseURL`: The base URL without the endpoint suffix
* - `endpoint`: The matched endpoint identifier, or empty string if no match found
*
* @example
* routeToEndpoint('https://api.example.com/openai/chat/completions#')
* // Returns: { baseURL: 'https://api.example.com/v1', endpoint: 'chat/completions' }
*
* @example
* routeToEndpoint('https://api.example.com/v1')
* // Returns: { baseURL: 'https://api.example.com/v1', endpoint: '' }
*/
export function routeToEndpoint(apiHost: string): { baseURL: string; endpoint: string } {
const trimmedHost = (apiHost || '').trim()
if (!trimmedHost.endsWith('#')) {
return { baseURL: trimmedHost, endpoint: '' }
}
// Remove trailing #
const host = trimmedHost.slice(0, -1)
const endpointMatch = SUPPORTED_ENDPOINT_LIST.find((endpoint) => host.endsWith(endpoint))
if (!endpointMatch) {
const baseURL = withoutTrailingSlash(host)
return { baseURL, endpoint: '' }
}
const baseSegment = host.slice(0, host.length - endpointMatch.length)
const baseURL = withoutTrailingSlash(baseSegment).replace(/:$/, '') // Remove trailing colon (gemini special case)
return { baseURL, endpoint: endpointMatch }
}
/**
* Gets the AI SDK compatible base URL from a provider's apiHost.
*
* AI SDK expects baseURL WITH version suffix (e.g., /v1).
* This function:
* 1. Handles '#' endpoint routing format
* 2. Ensures the URL has a version suffix (adds /v1 if missing)
*
* @param apiHost - The provider's apiHost value (may or may not have /v1)
* @param apiVersion - The API version to use if missing. Defaults to 'v1'.
* @returns The baseURL suitable for AI SDK (with version suffix)
*
* @example
* getAiSdkBaseUrl('https://api.openai.com') // 'https://api.openai.com/v1'
* getAiSdkBaseUrl('https://api.openai.com/v1') // 'https://api.openai.com/v1'
* getAiSdkBaseUrl('https://api.example.com/chat/completions#') // 'https://api.example.com'
*/
export function getAiSdkBaseUrl(apiHost: string, apiVersion: string = 'v1'): string {
// First handle '#' endpoint routing format
const { baseURL } = routeToEndpoint(apiHost)
// If already has version, return as-is
if (hasAPIVersion(baseURL)) {
return withoutTrailingSlash(baseURL)
}
// Add version suffix
return `${withoutTrailingSlash(baseURL)}/${apiVersion}`
}
/**
* Validates an API host address.
*
* @param apiHost - The API host address to validate
* @returns true if valid URL with http/https protocol, false otherwise
*/
export function validateApiHost(apiHost: string): boolean {
if (!apiHost || !apiHost.trim()) {
return true // Allow empty
}
try {
const url = new URL(apiHost.trim())
return url.protocol === 'http:' || url.protocol === 'https:'
} catch {
return false
}
}
/**
* Extracts the trailing API version segment from a URL path.
*
* This function extracts API version patterns (e.g., `v1`, `v2beta`) from the end of a URL.
* Only versions at the end of the path are extracted, not versions in the middle.
* The returned version string does not include leading or trailing slashes.
*
* @param {string} url - The URL string to parse.
* @returns {string | undefined} The trailing API version found (e.g., 'v1', 'v2beta'), or undefined if none found.
*
* @example
* getTrailingApiVersion('https://api.example.com/v1') // 'v1'
* getTrailingApiVersion('https://api.example.com/v2beta/') // 'v2beta'
* getTrailingApiVersion('https://api.example.com/v1/chat') // undefined (version not at end)
* getTrailingApiVersion('https://gateway.ai.cloudflare.com/v1/xxx/v1beta') // 'v1beta'
* getTrailingApiVersion('https://api.example.com') // undefined
*/
export function getTrailingApiVersion(url: string): string | undefined {
const match = url.match(TRAILING_VERSION_REGEX)
if (match) {
// Extract version without leading slash and trailing slash
return match[0].replace(/^\//, '').replace(/\/$/, '')
}
return undefined
}
/**
* Removes the trailing API version segment from a URL path.
*
* This function removes API version patterns (e.g., `/v1`, `/v2beta`) from the end of a URL.
* Only versions at the end of the path are removed, not versions in the middle.
*
* @param {string} url - The URL string to process.
* @returns {string} The URL with the trailing API version removed, or the original URL if no trailing version found.
*
* @example
* withoutTrailingApiVersion('https://api.example.com/v1') // 'https://api.example.com'
* withoutTrailingApiVersion('https://api.example.com/v2beta/') // 'https://api.example.com'
* withoutTrailingApiVersion('https://api.example.com/v1/chat') // 'https://api.example.com/v1/chat' (no change)
* withoutTrailingApiVersion('https://api.example.com') // 'https://api.example.com'
*/
export function withoutTrailingApiVersion(url: string): string {
return url.replace(TRAILING_VERSION_REGEX, '')
}
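Taken together, a sketch of how the helpers in this module compose (values follow the JSDoc examples above):

```typescript
// Normalize a user-entered host, then split off an explicit '#' endpoint route.
const host = formatApiHost('https://api.example.com/') // 'https://api.example.com/v1'
const routed = routeToEndpoint('https://api.example.com/v1/chat/completions#')
// routed => { baseURL: 'https://api.example.com/v1', endpoint: 'chat/completions' }
const baseURL = getAiSdkBaseUrl('https://api.openai.com') // 'https://api.openai.com/v1'
```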

View File

@@ -43,35 +43,6 @@ export function isSiliconAnthropicCompatibleModel(modelId: string): boolean {
}
/**
* PPIO provider models that support Anthropic API endpoint.
* These models can be used with Claude Code via the Anthropic-compatible API.
*
* @see https://ppio.com/docs/model/llm-anthropic-compatibility
* Silicon provider's Anthropic API host URL.
*/
export const PPIO_ANTHROPIC_COMPATIBLE_MODELS: readonly string[] = [
'moonshotai/kimi-k2-thinking',
'minimax/minimax-m2',
'deepseek/deepseek-v3.2-exp',
'deepseek/deepseek-v3.1-terminus',
'zai-org/glm-4.6',
'moonshotai/kimi-k2-0905',
'deepseek/deepseek-v3.1',
'moonshotai/kimi-k2-instruct',
'qwen/qwen3-next-80b-a3b-instruct',
'qwen/qwen3-next-80b-a3b-thinking'
]
/**
* Creates a Set for efficient lookup of PPIO Anthropic-compatible model IDs.
*/
const PPIO_ANTHROPIC_COMPATIBLE_MODEL_SET = new Set(PPIO_ANTHROPIC_COMPATIBLE_MODELS)
/**
* Checks if a model ID is compatible with Anthropic API on PPIO provider.
*
* @param modelId - The model ID to check
* @returns true if the model supports Anthropic API endpoint
*/
export function isPpioAnthropicCompatibleModel(modelId: string): boolean {
return PPIO_ANTHROPIC_COMPATIBLE_MODEL_SET.has(modelId)
}
export const SILICON_ANTHROPIC_API_HOST = 'https://api.siliconflow.cn'

View File

@@ -1,15 +0,0 @@
/**
* Shared AI SDK Middlewares
*
* Environment-agnostic middlewares that can be used in both
* renderer process and main process (API server).
*/
export {
buildSharedMiddlewares,
getReasoningTagName,
isGemini3ModelId,
openrouterReasoningMiddleware,
type SharedMiddlewareConfig,
skipGeminiThoughtSignatureMiddleware
} from './middlewares'

View File

@@ -1,205 +0,0 @@
/**
* Shared AI SDK Middlewares
*
* These middlewares are environment-agnostic and can be used in both
* renderer process and main process (API server).
*/
import type { LanguageModelV2Middleware, LanguageModelV2StreamPart } from '@ai-sdk/provider'
import { extractReasoningMiddleware } from 'ai'
/**
* Configuration for building shared middlewares
*/
export interface SharedMiddlewareConfig {
/**
* Whether to enable reasoning extraction
*/
enableReasoning?: boolean
/**
* Tag name for reasoning extraction
* Defaults based on model ID
*/
reasoningTagName?: string
/**
* Model ID - used to determine default reasoning tag and model detection
*/
modelId?: string
/**
* Provider ID (Cherry Studio provider ID)
* Used for provider-specific middlewares like OpenRouter
*/
providerId?: string
/**
* AI SDK Provider ID
* Used for Gemini thought signature middleware
* e.g., 'google', 'google-vertex'
*/
aiSdkProviderId?: string
}
/**
* Check if model ID represents a Gemini 3 (2.5) model
* that requires thought signature handling
*
* @param modelId - The model ID string (not Model object)
*/
export function isGemini3ModelId(modelId?: string): boolean {
if (!modelId) return false
const lowerModelId = modelId.toLowerCase()
return lowerModelId.includes('gemini-3')
}
/**
* Get the default reasoning tag name based on model ID
*
* Different models use different tags for reasoning content:
* - Most models: 'think'
* - GPT-OSS models: 'reasoning'
* - Gemini models: 'thought'
* - Seed models: 'seed:think'
*/
export function getReasoningTagName(modelId?: string): string {
if (!modelId) return 'think'
const lowerModelId = modelId.toLowerCase()
if (lowerModelId.includes('gpt-oss')) return 'reasoning'
if (lowerModelId.includes('gemini')) return 'thought'
if (lowerModelId.includes('seed-oss-36b')) return 'seed:think'
return 'think'
}
/**
* Skip Gemini Thought Signature Middleware
*
* Due to the complexity of multi-model client requests (which can switch
* to other models mid-process), this middleware skips all Gemini 3
* thinking signatures validation.
*
* @param aiSdkId - AI SDK Provider ID (e.g., 'google', 'google-vertex')
* @returns LanguageModelV2Middleware
*/
export function skipGeminiThoughtSignatureMiddleware(aiSdkId: string): LanguageModelV2Middleware {
const MAGIC_STRING = 'skip_thought_signature_validator'
return {
middlewareVersion: 'v2',
transformParams: async ({ params }) => {
const transformedParams = { ...params }
// Process messages in prompt
if (transformedParams.prompt && Array.isArray(transformedParams.prompt)) {
transformedParams.prompt = transformedParams.prompt.map((message) => {
if (typeof message.content !== 'string') {
for (const part of message.content) {
const googleOptions = part?.providerOptions?.[aiSdkId]
if (googleOptions?.thoughtSignature) {
googleOptions.thoughtSignature = MAGIC_STRING
}
}
}
return message
})
}
return transformedParams
}
}
}
/**
* OpenRouter Reasoning Middleware
*
* Filters out [REDACTED] blocks from OpenRouter reasoning responses.
* OpenRouter may include [REDACTED] markers in reasoning content that
* should be removed for cleaner output.
*
* @see https://openrouter.ai/docs/docs/best-practices/reasoning-tokens
* @returns LanguageModelV2Middleware
*/
export function openrouterReasoningMiddleware(): LanguageModelV2Middleware {
const REDACTED_BLOCK = '[REDACTED]'
return {
middlewareVersion: 'v2',
wrapGenerate: async ({ doGenerate }) => {
const { content, ...rest } = await doGenerate()
const modifiedContent = content.map((part) => {
if (part.type === 'reasoning' && part.text.includes(REDACTED_BLOCK)) {
return {
...part,
text: part.text.replace(REDACTED_BLOCK, '')
}
}
return part
})
return { content: modifiedContent, ...rest }
},
wrapStream: async ({ doStream }) => {
const { stream, ...rest } = await doStream()
return {
stream: stream.pipeThrough(
new TransformStream<LanguageModelV2StreamPart, LanguageModelV2StreamPart>({
transform(
chunk: LanguageModelV2StreamPart,
controller: TransformStreamDefaultController<LanguageModelV2StreamPart>
) {
if (chunk.type === 'reasoning-delta' && chunk.delta.includes(REDACTED_BLOCK)) {
controller.enqueue({
...chunk,
delta: chunk.delta.replace(REDACTED_BLOCK, '')
})
} else {
controller.enqueue(chunk)
}
}
})
),
...rest
}
}
}
}
/**
* Build shared middlewares based on configuration
*
* This function builds a set of middlewares that are commonly needed
* across different environments (renderer, API server).
*
* @param config - Configuration for middleware building
* @returns Array of AI SDK middlewares
*
* @example
* ```typescript
* import { buildSharedMiddlewares } from '@shared/middleware'
*
* const middlewares = buildSharedMiddlewares({
* enableReasoning: true,
* modelId: 'gemini-2.5-pro',
* providerId: 'openrouter',
* aiSdkProviderId: 'google'
* })
* ```
*/
export function buildSharedMiddlewares(config: SharedMiddlewareConfig): LanguageModelV2Middleware[] {
const middlewares: LanguageModelV2Middleware[] = []
// 1. Reasoning extraction middleware
if (config.enableReasoning) {
const tagName = config.reasoningTagName || getReasoningTagName(config.modelId)
middlewares.push(extractReasoningMiddleware({ tagName }))
}
// 2. OpenRouter-specific: filter [REDACTED] blocks
if (config.providerId === 'openrouter' && config.enableReasoning) {
middlewares.push(openrouterReasoningMiddleware())
}
// 3. Gemini 3 (2.5) specific: skip thought signature validation
if (isGemini3ModelId(config.modelId) && config.aiSdkProviderId) {
middlewares.push(skipGeminiThoughtSignatureMiddleware(config.aiSdkProviderId))
}
return middlewares
}
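A hedged sketch of wiring these middlewares into an AI SDK model (`wrapLanguageModel` is the standard AI SDK helper; `baseModel` and the IDs are placeholders):

```typescript
import { wrapLanguageModel } from 'ai'

// Sketch: wrap a base model with the shared middleware stack.
const middlewares = buildSharedMiddlewares({
  enableReasoning: true,
  modelId: 'gemini-3-pro', // placeholder id
  providerId: 'openrouter',
  aiSdkProviderId: 'google'
})
const wrapped = wrapLanguageModel({ model: baseModel, middleware: middlewares })
```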

View File

@@ -1,22 +0,0 @@
import type { MinimalModel, MinimalProvider, ProviderType } from '../types'
import { provider2Provider, startsWith } from './helper'
import type { RuleSet } from './types'
// https://platform.claude.com/docs/en/build-with-claude/claude-in-microsoft-foundry
const AZURE_ANTHROPIC_RULES: RuleSet = {
rules: [
{
match: startsWith('claude'),
provider: (provider: MinimalProvider) => ({
...provider,
type: 'anthropic' as ProviderType,
apiHost: provider.apiHost + 'anthropic/v1',
id: 'azure-anthropic'
})
}
],
fallbackRule: (provider: MinimalProvider) => provider
}
export const azureAnthropicProviderCreator = <P extends MinimalProvider>(model: MinimalModel, provider: P): P =>
provider2Provider<MinimalModel, MinimalProvider, P>(AZURE_ANTHROPIC_RULES, model, provider)

View File

@@ -1,32 +0,0 @@
import type { MinimalModel, MinimalProvider } from '../types'
import type { RuleSet } from './types'
export const startsWith =
(prefix: string) =>
<M extends MinimalModel>(model: M) =>
model.id.toLowerCase().startsWith(prefix.toLowerCase())
export const endpointIs =
(type: string) =>
<M extends MinimalModel>(model: M) =>
model.endpoint_type === type
/**
* Resolves the provider for a given model.
* @param ruleSet The rule set object
* @param model The model object
* @param provider The original provider object
* @returns The resolved provider object
*/
export function provider2Provider<M extends MinimalModel, R extends MinimalProvider, P extends R = R>(
ruleSet: RuleSet<M, R>,
model: M,
provider: P
): P {
for (const rule of ruleSet.rules) {
if (rule.match(model)) {
return rule.provider(provider) as P
}
}
return ruleSet.fallbackRule(provider) as P
}
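A minimal sketch of defining and applying a rule set with these helpers, mirroring the azure-anthropic pattern below (the rule content is illustrative only):

```typescript
// Sketch: route claude-* models to an Anthropic-typed provider.
const EXAMPLE_RULES: RuleSet = {
  rules: [
    {
      match: startsWith('claude'),
      provider: (provider: MinimalProvider) => ({ ...provider, type: 'anthropic' as ProviderType })
    }
  ],
  fallbackRule: (provider: MinimalProvider) => provider
}
// Returns the provider produced by the first matching rule, else the fallback.
const resolveExample = <P extends MinimalProvider>(model: MinimalModel, provider: P): P =>
  provider2Provider(EXAMPLE_RULES, model, provider)
```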

View File

@@ -1,6 +0,0 @@
export { aihubmixProviderCreator } from './aihubmix'
export { azureAnthropicProviderCreator } from './azure-anthropic'
export { endpointIs, provider2Provider, startsWith } from './helper'
export { newApiResolverCreator } from './newApi'
export type { RuleSet } from './types'
export { vertexAnthropicProviderCreator } from './vertex-anthropic'

View File

@@ -1,9 +0,0 @@
import type { MinimalModel, MinimalProvider } from '../types'
export interface RuleSet<M extends MinimalModel = MinimalModel, P extends MinimalProvider = MinimalProvider> {
rules: Array<{
match: (model: M) => boolean
provider: (provider: P) => P
}>
fallbackRule: (provider: P) => P
}

View File

@@ -1,19 +0,0 @@
import type { MinimalModel, MinimalProvider } from '../types'
import { provider2Provider, startsWith } from './helper'
import type { RuleSet } from './types'
const VERTEX_ANTHROPIC_RULES: RuleSet = {
rules: [
{
match: startsWith('claude'),
provider: (provider: MinimalProvider) => ({
...provider,
id: 'google-vertex-anthropic'
})
}
],
fallbackRule: (provider: MinimalProvider) => provider
}
export const vertexAnthropicProviderCreator = <P extends MinimalProvider>(model: MinimalModel, provider: P): P =>
provider2Provider<MinimalModel, MinimalProvider, P>(VERTEX_ANTHROPIC_RULES, model, provider)

View File

@@ -1,26 +0,0 @@
import { getLowerBaseModelName } from '@shared/utils/naming'
import type { MinimalModel } from './types'
export const COPILOT_EDITOR_VERSION = 'vscode/1.104.1'
export const COPILOT_PLUGIN_VERSION = 'copilot-chat/0.26.7'
export const COPILOT_INTEGRATION_ID = 'vscode-chat'
export const COPILOT_USER_AGENT = 'GitHubCopilotChat/0.26.7'
export const COPILOT_DEFAULT_HEADERS = {
'Copilot-Integration-Id': COPILOT_INTEGRATION_ID,
'User-Agent': COPILOT_USER_AGENT,
'Editor-Version': COPILOT_EDITOR_VERSION,
'Editor-Plugin-Version': COPILOT_PLUGIN_VERSION,
'editor-version': COPILOT_EDITOR_VERSION,
'editor-plugin-version': COPILOT_PLUGIN_VERSION,
'copilot-vision-request': 'true'
} as const
// Models that require the OpenAI Responses endpoint when routed through GitHub Copilot (#10560)
const COPILOT_RESPONSES_MODEL_IDS = ['gpt-5-codex', 'gpt-5.1-codex', 'gpt-5.1-codex-mini']
export function isCopilotResponsesModel<M extends MinimalModel>(model: M): boolean {
const normalizedId = getLowerBaseModelName(model.id)
return COPILOT_RESPONSES_MODEL_IDS.some((target) => normalizedId === target)
}

View File

@@ -1,100 +0,0 @@
/**
* Provider Type Detection Utilities
*
* Functions to detect provider types based on provider configuration.
* These are pure functions that only depend on provider.type and provider.id.
*
* NOTE: These functions should match the logic in @renderer/utils/provider.ts
*/
import type { MinimalProvider } from './types'
/**
* Check if provider is Anthropic type
*/
export function isAnthropicProvider<P extends MinimalProvider>(provider: P): boolean {
return provider.type === 'anthropic'
}
/**
* Check if provider is OpenAI Response type (openai-response)
* NOTE: This matches isOpenAIProvider in renderer/utils/provider.ts
*/
export function isOpenAIProvider<P extends MinimalProvider>(provider: P): boolean {
return provider.type === 'openai-response'
}
/**
* Check if provider is Gemini type
*/
export function isGeminiProvider<P extends MinimalProvider>(provider: P): boolean {
return provider.type === 'gemini'
}
/**
* Check if provider is Azure OpenAI type
*/
export function isAzureOpenAIProvider<P extends MinimalProvider>(provider: P): boolean {
return provider.type === 'azure-openai'
}
/**
* Check if provider is Vertex AI type
*/
export function isVertexProvider<P extends MinimalProvider>(provider: P): boolean {
return provider.type === 'vertexai'
}
/**
* Check if provider is AWS Bedrock type
*/
export function isAwsBedrockProvider<P extends MinimalProvider>(provider: P): boolean {
return provider.type === 'aws-bedrock'
}
/**
* Check if provider is AI Gateway type
*/
export function isAIGatewayProvider<P extends MinimalProvider>(provider: P): boolean {
return provider.type === 'ai-gateway'
}
/**
* Check if Azure OpenAI provider uses responses endpoint
* Matches isAzureResponsesEndpoint in renderer/utils/provider.ts
*/
export function isAzureResponsesEndpoint<P extends MinimalProvider>(provider: P): boolean {
return provider.apiVersion === 'preview' || provider.apiVersion === 'v1'
}
/**
* Check if provider is Cherry AI type
* Matches isCherryAIProvider in renderer/utils/provider.ts
*/
export function isCherryAIProvider<P extends MinimalProvider>(provider: P): boolean {
return provider.id === 'cherryai'
}
/**
* Check if provider is Perplexity type
* Matches isPerplexityProvider in renderer/utils/provider.ts
*/
export function isPerplexityProvider<P extends MinimalProvider>(provider: P): boolean {
return provider.id === 'perplexity'
}
/**
* Check if provider is new-api type (supports multiple backends)
* Matches isNewApiProvider in renderer/utils/provider.ts
*/
export function isNewApiProvider<P extends MinimalProvider>(provider: P): boolean {
return ['new-api', 'cherryin'].includes(provider.id) || provider.type === ('new-api' as string)
}
/**
* Check if provider is OpenAI compatible
* Matches isOpenAICompatibleProvider in renderer/utils/provider.ts
*/
export function isOpenAICompatibleProvider<P extends MinimalProvider>(provider: P): boolean {
return ['openai', 'new-api', 'mistral'].includes(provider.type)
}

View File

@@ -1,136 +0,0 @@
/**
* Provider API Host Formatting
*
* Utilities for formatting provider API hosts to work with AI SDK.
* These handle the differences between how Cherry Studio stores API hosts
* and how AI SDK expects them.
*/
import {
formatApiHost,
formatAzureOpenAIApiHost,
formatVertexApiHost,
routeToEndpoint,
withoutTrailingSlash
} from '../api'
import {
isAnthropicProvider,
isAzureOpenAIProvider,
isCherryAIProvider,
isGeminiProvider,
isPerplexityProvider,
isVertexProvider
} from './detection'
import type { MinimalProvider } from './types'
import { SystemProviderIds } from './types'
/**
* Interface for environment-specific implementations
* Renderer and Main process can provide their own implementations
*/
export interface ProviderFormatContext {
vertex: {
project: string
location: string
}
}
/**
* Default Azure OpenAI API host formatter
*/
export function defaultFormatAzureOpenAIApiHost(host: string): string {
const normalizedHost = withoutTrailingSlash(host)
?.replace(/\/v1$/, '')
.replace(/\/openai$/, '')
// AI SDK will add /v1
return formatApiHost(normalizedHost + '/openai', false)
}
/**
* Format provider API host for AI SDK
*
* This function normalizes the apiHost to work with AI SDK.
* Different providers have different requirements:
* - Most providers: add /v1 suffix
* - Gemini: add /v1beta suffix
* - Some providers: no suffix needed
*
* @param provider - The provider to format
* @param context - Optional context with environment-specific implementations
* @returns Provider with formatted apiHost (and anthropicApiHost if applicable)
*/
export function formatProviderApiHost<T extends MinimalProvider>(provider: T, context: ProviderFormatContext): T {
const formatted = { ...provider }
// Format anthropicApiHost if present
if (formatted.anthropicApiHost) {
formatted.anthropicApiHost = formatApiHost(formatted.anthropicApiHost)
}
// Format based on provider type
if (isAnthropicProvider(provider)) {
const baseHost = formatted.anthropicApiHost || formatted.apiHost
// AI SDK needs /v1 in baseURL
formatted.apiHost = formatApiHost(baseHost)
if (!formatted.anthropicApiHost) {
formatted.anthropicApiHost = formatted.apiHost
}
} else if (formatted.id === SystemProviderIds.copilot || formatted.id === SystemProviderIds.github) {
formatted.apiHost = formatApiHost(formatted.apiHost, false)
} else if (isGeminiProvider(formatted)) {
formatted.apiHost = formatApiHost(formatted.apiHost, true, 'v1beta')
} else if (isAzureOpenAIProvider(formatted)) {
formatted.apiHost = formatAzureOpenAIApiHost(formatted.apiHost)
} else if (isVertexProvider(formatted)) {
formatted.apiHost = formatVertexApiHost(formatted, context.vertex.project, context.vertex.location)
} else if (isCherryAIProvider(formatted)) {
formatted.apiHost = formatApiHost(formatted.apiHost, false)
} else if (isPerplexityProvider(formatted)) {
formatted.apiHost = formatApiHost(formatted.apiHost, false)
} else {
formatted.apiHost = formatApiHost(formatted.apiHost)
}
return formatted
}
/**
* Get the base URL for AI SDK from a formatted provider
*
* This extracts the baseURL that AI SDK expects, handling
* the '#' endpoint routing format if present.
*
* @param formattedApiHost - The formatted apiHost (after formatProviderApiHost)
* @returns The baseURL for AI SDK
*/
export function getBaseUrlForAiSdk(formattedApiHost: string): string {
const { baseURL } = routeToEndpoint(formattedApiHost)
return baseURL
}
/**
* Get rotated API key from comma-separated keys
*
* This is the interface for API key rotation. The actual implementation
* depends on the environment (renderer uses window.keyv, main uses its own storage).
*/
export interface ApiKeyRotator {
/**
* Get the next API key in rotation
* @param providerId - The provider ID for tracking rotation
* @param keys - Comma-separated API keys
* @returns The next API key to use
*/
getRotatedKey(providerId: string, keys: string): string
}
/**
* Simple API key rotator that always returns the first key
* Use this when rotation is not needed
*/
export const simpleKeyRotator: ApiKeyRotator = {
getRotatedKey(_providerId: string, keys: string): string {
const keyList = keys.split(',').map((k) => k.trim())
return keyList[0] || keys
}
}
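A hedged sketch of a stateful rotator satisfying the ApiKeyRotator interface (a plain in-memory round-robin; the real implementations are environment-specific, e.g. window.keyv in the renderer):

```typescript
// Sketch: in-memory round-robin over comma-separated keys, tracked per provider.
const rotationIndex = new Map<string, number>()
export const roundRobinKeyRotator: ApiKeyRotator = {
  getRotatedKey(providerId: string, keys: string): string {
    const keyList = keys.split(',').map((k) => k.trim()).filter(Boolean)
    if (keyList.length === 0) return keys
    const index = rotationIndex.get(providerId) ?? 0
    rotationIndex.set(providerId, (index + 1) % keyList.length)
    return keyList[index % keyList.length]
  }
}
```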

View File

@@ -1,48 +0,0 @@
/**
* Shared Provider Utilities
*
* This module exports utilities for working with AI providers
* that can be shared between main process and renderer process.
*/
// Type definitions
export type { MinimalProvider, ProviderType, SystemProviderId } from './types'
export { SystemProviderIds } from './types'
// Provider type detection
export {
isAIGatewayProvider,
isAnthropicProvider,
isAwsBedrockProvider,
isAzureOpenAIProvider,
isAzureResponsesEndpoint,
isCherryAIProvider,
isGeminiProvider,
isNewApiProvider,
isOpenAICompatibleProvider,
isOpenAIProvider,
isPerplexityProvider,
isVertexProvider
} from './detection'
// API host formatting
export type { ApiKeyRotator, ProviderFormatContext } from './format'
export {
defaultFormatAzureOpenAIApiHost,
formatProviderApiHost,
getBaseUrlForAiSdk,
simpleKeyRotator
} from './format'
// Provider ID mapping
export { getAiSdkProviderId, STATIC_PROVIDER_MAPPING, tryResolveProviderId } from './mapping'
// AI SDK configuration
export type { AiSdkConfig, AiSdkConfigContext } from './sdk-config'
export { providerToAiSdkConfig } from './sdk-config'
// Provider resolution
export { resolveActualProvider } from './resolve'
// Provider initialization
export { initializeSharedProviders, SHARED_PROVIDER_CONFIGS } from './initialization'

View File

@@ -1,107 +0,0 @@
import { type ProviderConfig, registerMultipleProviderConfigs } from '@cherrystudio/ai-core/provider'
type ProviderInitializationLogger = {
warn?: (message: string) => void
error?: (message: string, error: Error) => void
}
export const SHARED_PROVIDER_CONFIGS: ProviderConfig[] = [
{
id: 'openrouter',
name: 'OpenRouter',
import: () => import('@openrouter/ai-sdk-provider'),
creatorFunctionName: 'createOpenRouter',
supportsImageGeneration: true,
aliases: ['openrouter']
},
{
id: 'google-vertex',
name: 'Google Vertex AI',
import: () => import('@ai-sdk/google-vertex/edge'),
creatorFunctionName: 'createVertex',
supportsImageGeneration: true,
aliases: ['vertexai']
},
{
id: 'google-vertex-anthropic',
name: 'Google Vertex AI Anthropic',
import: () => import('@ai-sdk/google-vertex/anthropic/edge'),
creatorFunctionName: 'createVertexAnthropic',
supportsImageGeneration: true,
aliases: ['vertexai-anthropic']
},
{
id: 'azure-anthropic',
name: 'Azure AI Anthropic',
import: () => import('@ai-sdk/anthropic'),
creatorFunctionName: 'createAnthropic',
supportsImageGeneration: false,
aliases: ['azure-anthropic']
},
{
id: 'github-copilot-openai-compatible',
name: 'GitHub Copilot OpenAI Compatible',
import: () => import('@opeoginni/github-copilot-openai-compatible'),
creatorFunctionName: 'createGitHubCopilotOpenAICompatible',
supportsImageGeneration: false,
aliases: ['copilot', 'github-copilot']
},
{
id: 'bedrock',
name: 'Amazon Bedrock',
import: () => import('@ai-sdk/amazon-bedrock'),
creatorFunctionName: 'createAmazonBedrock',
supportsImageGeneration: true,
aliases: ['aws-bedrock']
},
{
id: 'perplexity',
name: 'Perplexity',
import: () => import('@ai-sdk/perplexity'),
creatorFunctionName: 'createPerplexity',
supportsImageGeneration: false,
aliases: ['perplexity']
},
{
id: 'mistral',
name: 'Mistral',
import: () => import('@ai-sdk/mistral'),
creatorFunctionName: 'createMistral',
supportsImageGeneration: false,
aliases: ['mistral']
},
{
id: 'huggingface',
name: 'HuggingFace',
import: () => import('@ai-sdk/huggingface'),
creatorFunctionName: 'createHuggingFace',
supportsImageGeneration: true,
aliases: ['hf', 'hugging-face']
},
{
id: 'ai-gateway',
name: 'AI Gateway',
import: () => import('@ai-sdk/gateway'),
creatorFunctionName: 'createGateway',
supportsImageGeneration: true,
aliases: ['gateway']
},
{
id: 'cerebras',
name: 'Cerebras',
import: () => import('@ai-sdk/cerebras'),
creatorFunctionName: 'createCerebras',
supportsImageGeneration: false
}
] as const
export function initializeSharedProviders(logger?: ProviderInitializationLogger): void {
try {
const successCount = registerMultipleProviderConfigs(SHARED_PROVIDER_CONFIGS)
if (successCount < SHARED_PROVIDER_CONFIGS.length) {
logger?.warn?.('Some providers failed to register. Check previous error logs.')
}
} catch (error) {
logger?.error?.('Failed to initialize shared providers', error as Error)
}
}
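Usage is a one-liner; a sketch passing a console-backed logger:

```typescript
// Sketch: register all shared provider configs at startup.
initializeSharedProviders({
  warn: (message) => console.warn(message),
  error: (message, error) => console.error(message, error)
})
```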

View File

@@ -1,95 +0,0 @@
/**
* Provider ID Mapping
*
* Maps Cherry Studio provider IDs/types to AI SDK provider IDs.
* This logic should match @renderer/aiCore/provider/factory.ts
*/
import { hasProviderConfigByAlias, type ProviderId, resolveProviderConfigId } from '@cherrystudio/ai-core/provider'
import { isAzureOpenAIProvider, isAzureResponsesEndpoint } from './detection'
import type { MinimalProvider } from './types'
/**
* Static mapping from Cherry Studio provider ID/type to AI SDK provider ID
* Matches STATIC_PROVIDER_MAPPING in @renderer/aiCore/provider/factory.ts
*/
export const STATIC_PROVIDER_MAPPING: Record<string, ProviderId> = {
gemini: 'google', // Google Gemini -> google
'azure-openai': 'azure', // Azure OpenAI -> azure
'openai-response': 'openai', // OpenAI Responses -> openai
grok: 'xai', // Grok -> xai
copilot: 'github-copilot-openai-compatible'
}
/**
* Try to resolve a provider identifier to an AI SDK provider ID
* Matches tryResolveProviderId in @renderer/aiCore/provider/factory.ts
*
* @param identifier - The provider ID or type to resolve
* @param checker - Provider config checker (defaults to static mapping only)
* @returns The resolved AI SDK provider ID, or null if not found
*/
export function tryResolveProviderId(identifier: string): ProviderId | null {
// 1. Check the static mapping
const staticMapping = STATIC_PROVIDER_MAPPING[identifier]
if (staticMapping) {
return staticMapping
}
// 2. Check whether AiCore supports this identifier (including alias support)
if (hasProviderConfigByAlias(identifier)) {
// Resolve to the canonical provider ID
return resolveProviderConfigId(identifier) as ProviderId
}
return null
}
/**
* Get the AI SDK Provider ID for a Cherry Studio provider
* Matches getAiSdkProviderId in @renderer/aiCore/provider/factory.ts
*
* Logic:
* 1. Handle Azure OpenAI specially (check responses endpoint)
* 2. Try to resolve from provider.id
* 3. Try to resolve from provider.type (but not for generic 'openai' type)
* 4. Check for OpenAI API host pattern
* 5. Fallback to provider's own ID
*
* @param provider - The Cherry Studio provider
* @param checker - Provider config checker (defaults to static mapping only)
* @returns The AI SDK provider ID to use
*/
export function getAiSdkProviderId(provider: MinimalProvider): ProviderId {
// 1. Handle Azure OpenAI specially - check this FIRST before other resolution
if (isAzureOpenAIProvider(provider)) {
if (isAzureResponsesEndpoint(provider)) {
return 'azure-responses'
}
return 'azure'
}
// 2. Try to resolve from provider.id
const resolvedFromId = tryResolveProviderId(provider.id)
if (resolvedFromId) {
return resolvedFromId
}
// 3. Try to resolve from provider.type
// (skipped for generic 'openai': it would map every custom openai-type provider onto the AI SDK openaiProvider)
if (provider.type !== 'openai') {
const resolvedFromType = tryResolveProviderId(provider.type)
if (resolvedFromType) {
return resolvedFromType
}
}
// 4. Check for OpenAI API host pattern
if (provider.apiHost.includes('api.openai.com')) {
return 'openai-chat'
}
// 5. Final fallback: use the provider's own id
return provider.id
}
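A quick sketch of the resolution order in practice (the provider objects are minimal placeholders, and 'my-proxy' is assumed not to be a registered alias):

```typescript
getAiSdkProviderId({ id: 'gemini', type: 'gemini', apiKey: '', apiHost: '' })
// => 'google' (static mapping on the id)
getAiSdkProviderId({ id: 'my-proxy', type: 'openai', apiKey: '', apiHost: 'https://api.openai.com/v1' })
// => 'openai-chat' (OpenAI API host pattern)
```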

View File

@@ -1,43 +0,0 @@
import { aihubmixProviderCreator, newApiResolverCreator, vertexAnthropicProviderCreator } from './config'
import { azureAnthropicProviderCreator } from './config/azure-anthropic'
import { isAzureOpenAIProvider, isNewApiProvider } from './detection'
import type { MinimalModel, MinimalProvider } from './types'
export interface ResolveActualProviderOptions<P extends MinimalProvider> {
isSystemProvider?: (provider: P) => boolean
}
const defaultIsSystemProvider = <P extends MinimalProvider>(provider: P): boolean => {
if ('isSystem' in provider) {
return Boolean((provider as unknown as { isSystem?: boolean }).isSystem)
}
return false
}
export function resolveActualProvider<M extends MinimalModel, P extends MinimalProvider>(
provider: P,
model: M,
options: ResolveActualProviderOptions<P> = {}
): P {
let resolvedProvider = provider
if (isNewApiProvider(resolvedProvider)) {
resolvedProvider = newApiResolverCreator(model, resolvedProvider)
}
const isSystemProvider = options.isSystemProvider?.(resolvedProvider) ?? defaultIsSystemProvider(resolvedProvider)
if (isSystemProvider && resolvedProvider.id === 'aihubmix') {
resolvedProvider = aihubmixProviderCreator(model, resolvedProvider)
}
if (isSystemProvider && resolvedProvider.id === 'vertexai') {
resolvedProvider = vertexAnthropicProviderCreator(model, resolvedProvider)
}
if (isAzureOpenAIProvider(resolvedProvider)) {
resolvedProvider = azureAnthropicProviderCreator(model, resolvedProvider)
}
return resolvedProvider
}
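A hedged usage sketch (`provider` and `model` are assumed in scope; the isSystemProvider override is optional):

```typescript
// Sketch: resolve the concrete provider for a model before building SDK config.
const actual = resolveActualProvider(provider, model, {
  isSystemProvider: (p) => p.id === 'aihubmix' // illustrative override
})
```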

View File

@@ -1,259 +0,0 @@
/**
* AI SDK Configuration
*
* Shared utilities for converting Cherry Studio Provider to AI SDK configuration.
* Environment-specific logic (renderer/main) is injected via context interfaces.
*/
import { formatPrivateKey, hasProviderConfig, ProviderConfigFactory } from '@cherrystudio/ai-core/provider'
import { routeToEndpoint } from '../api'
import { getAiSdkProviderId } from './mapping'
import type { MinimalProvider } from './types'
import { SystemProviderIds } from './types'
/**
* AI SDK configuration result
*/
export interface AiSdkConfig {
providerId: string
options: Record<string, unknown>
}
/**
* Context for environment-specific implementations
*/
export interface AiSdkConfigContext {
/**
* Get the rotated API key (for multi-key support)
* Default: returns first key
*/
getRotatedApiKey?: (provider: MinimalProvider) => string
/**
* Check if a model uses chat completion only (for OpenAI response mode)
* Default: returns false
*/
isOpenAIChatCompletionOnlyModel?: (modelId: string) => boolean
/**
* Get Copilot default headers (constants)
* Default: returns empty object
*/
getCopilotDefaultHeaders?: () => Record<string, string>
/**
* Get Copilot stored headers from state
* Default: returns empty object
*/
getCopilotStoredHeaders?: () => Record<string, string>
/**
* Get AWS Bedrock configuration
* Default: returns undefined (not configured)
*/
getAwsBedrockConfig?: () =>
| {
authType: 'apiKey' | 'iam'
region: string
apiKey?: string
accessKeyId?: string
secretAccessKey?: string
}
| undefined
/**
* Get Vertex AI configuration
* Default: returns undefined (not configured)
*/
getVertexConfig?: (provider: MinimalProvider) =>
| {
project: string
location: string
googleCredentials: {
privateKey: string
clientEmail: string
}
}
| undefined
/**
* Get endpoint type for cherryin provider
*/
getEndpointType?: (modelId: string) => string | undefined
/**
* Custom fetch implementation
* Main process: use Electron net.fetch
* Renderer process: use browser fetch (default)
*/
fetch?: typeof globalThis.fetch
/**
* Get CherryAI signed fetch wrapper
* Returns a fetch function that adds signature headers to requests
*/
getCherryAISignedFetch?: () => typeof globalThis.fetch
}
/**
* Default simple key rotator - returns first key
*/
function defaultGetRotatedApiKey(provider: MinimalProvider): string {
const keys = provider.apiKey.split(',').map((k) => k.trim())
return keys[0] || provider.apiKey
}
/**
* Convert Cherry Studio Provider to AI SDK configuration
*
* @param provider - The formatted provider (after formatProviderApiHost)
* @param modelId - The model ID to use
* @param context - Environment-specific implementations
* @returns AI SDK configuration
*/
export function providerToAiSdkConfig(
provider: MinimalProvider,
modelId: string,
context: AiSdkConfigContext = {}
): AiSdkConfig {
const getRotatedApiKey = context.getRotatedApiKey || defaultGetRotatedApiKey
const isOpenAIChatCompletionOnlyModel = context.isOpenAIChatCompletionOnlyModel || (() => false)
const aiSdkProviderId = getAiSdkProviderId(provider)
// Build base config
const { baseURL, endpoint } = routeToEndpoint(provider.apiHost)
const baseConfig = {
baseURL,
apiKey: getRotatedApiKey(provider)
}
// Handle Copilot specially
if (provider.id === SystemProviderIds.copilot) {
const defaultHeaders = context.getCopilotDefaultHeaders?.() ?? {}
const storedHeaders = context.getCopilotStoredHeaders?.() ?? {}
const copilotExtraOptions: Record<string, unknown> = {
headers: {
...defaultHeaders,
...storedHeaders,
...provider.extra_headers
},
name: provider.id,
includeUsage: true
}
if (context.fetch) {
copilotExtraOptions.fetch = context.fetch
}
const options = ProviderConfigFactory.fromProvider(
'github-copilot-openai-compatible',
baseConfig,
copilotExtraOptions
)
return {
providerId: 'github-copilot-openai-compatible',
options
}
}
// Build extra options
const extraOptions: Record<string, unknown> = {}
if (endpoint) {
extraOptions.endpoint = endpoint
}
// Handle OpenAI mode
if (provider.type === 'openai-response' && !isOpenAIChatCompletionOnlyModel(modelId)) {
extraOptions.mode = 'responses'
} else if (aiSdkProviderId === 'openai' || (aiSdkProviderId === 'cherryin' && provider.type === 'openai')) {
extraOptions.mode = 'chat'
}
// Add extra headers
if (provider.extra_headers) {
extraOptions.headers = provider.extra_headers
if (aiSdkProviderId === 'openai') {
extraOptions.headers = {
...(extraOptions.headers as Record<string, string>),
'HTTP-Referer': 'https://cherry-ai.com',
'X-Title': 'Cherry Studio',
'X-Api-Key': baseConfig.apiKey
}
}
}
// Handle Azure modes
if (aiSdkProviderId === 'azure-responses') {
extraOptions.mode = 'responses'
} else if (aiSdkProviderId === 'azure') {
extraOptions.mode = 'chat'
}
// Handle AWS Bedrock
if (aiSdkProviderId === 'bedrock') {
const bedrockConfig = context.getAwsBedrockConfig?.()
if (bedrockConfig) {
extraOptions.region = bedrockConfig.region
if (bedrockConfig.authType === 'apiKey') {
extraOptions.apiKey = bedrockConfig.apiKey
} else {
extraOptions.accessKeyId = bedrockConfig.accessKeyId
extraOptions.secretAccessKey = bedrockConfig.secretAccessKey
}
}
}
// Handle Vertex AI
if (aiSdkProviderId === 'google-vertex' || aiSdkProviderId === 'google-vertex-anthropic') {
const vertexConfig = context.getVertexConfig?.(provider)
if (vertexConfig) {
extraOptions.project = vertexConfig.project
extraOptions.location = vertexConfig.location
extraOptions.googleCredentials = {
...vertexConfig.googleCredentials,
privateKey: formatPrivateKey(vertexConfig.googleCredentials.privateKey)
}
baseConfig.baseURL += aiSdkProviderId === 'google-vertex' ? '/publishers/google' : '/publishers/anthropic/models'
}
}
// Handle cherryin endpoint type
if (aiSdkProviderId === 'cherryin') {
const endpointType = context.getEndpointType?.(modelId)
if (endpointType) {
extraOptions.endpointType = endpointType
}
}
// Handle cherryai signed fetch
if (provider.id === 'cherryai') {
const signedFetch = context.getCherryAISignedFetch?.()
if (signedFetch) {
extraOptions.fetch = signedFetch
}
} else if (context.fetch) {
extraOptions.fetch = context.fetch
}
// Check if AI SDK supports this provider natively
if (hasProviderConfig(aiSdkProviderId) && aiSdkProviderId !== 'openai-compatible') {
const options = ProviderConfigFactory.fromProvider(aiSdkProviderId, baseConfig, extraOptions)
return {
providerId: aiSdkProviderId,
options
}
}
// Fallback to openai-compatible
const options = ProviderConfigFactory.createOpenAICompatible(baseConfig.baseURL, baseConfig.apiKey)
return {
providerId: 'openai-compatible',
options: {
...options,
name: provider.id,
...extraOptions,
includeUsage: true
}
}
}
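End to end, a hedged sketch of the intended pipeline (`provider` and `model` are assumed in scope; context callbacks are placeholders):

```typescript
// Sketch: format the host, resolve the actual provider, then build SDK options.
const formatted = formatProviderApiHost(provider, { vertex: { project: 'demo', location: 'us-central1' } })
const actual = resolveActualProvider(formatted, model)
const { providerId, options } = providerToAiSdkConfig(actual, model.id, {
  fetch: globalThis.fetch // the main process would inject net.fetch instead
})
```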

View File

@@ -1,174 +0,0 @@
import * as z from 'zod'
export const ProviderTypeSchema = z.enum([
'openai',
'openai-response',
'anthropic',
'gemini',
'azure-openai',
'vertexai',
'mistral',
'aws-bedrock',
'vertex-anthropic',
'new-api',
'ai-gateway'
])
export type ProviderType = z.infer<typeof ProviderTypeSchema>
/**
* Minimal provider interface for shared utilities
* This is the subset of Provider that shared code needs
*/
export type MinimalProvider = {
id: string
type: ProviderType
apiKey: string
apiHost: string
anthropicApiHost?: string
apiVersion?: string
extra_headers?: Record<string, string>
}
/**
* Minimal model interface for shared utilities
* This is the subset of Model that shared code needs
*/
export type MinimalModel = {
id: string
endpoint_type?: string
}
export const SystemProviderIdSchema = z.enum([
'cherryin',
'silicon',
'aihubmix',
'ocoolai',
'deepseek',
'ppio',
'alayanew',
'qiniu',
'dmxapi',
'burncloud',
'tokenflux',
'302ai',
'cephalon',
'lanyun',
'ph8',
'openrouter',
'ollama',
'ovms',
'new-api',
'lmstudio',
'anthropic',
'openai',
'azure-openai',
'gemini',
'vertexai',
'github',
'copilot',
'zhipu',
'yi',
'moonshot',
'baichuan',
'dashscope',
'stepfun',
'doubao',
'infini',
'minimax',
'groq',
'together',
'fireworks',
'nvidia',
'grok',
'hyperbolic',
'mistral',
'jina',
'perplexity',
'modelscope',
'xirang',
'hunyuan',
'tencent-cloud-ti',
'baidu-cloud',
'gpustack',
'voyageai',
'aws-bedrock',
'poe',
'aionly',
'longcat',
'huggingface',
'sophnet',
'ai-gateway',
'cerebras'
])
export type SystemProviderId = z.infer<typeof SystemProviderIdSchema>
export const isSystemProviderId = (id: string): id is SystemProviderId => {
return SystemProviderIdSchema.safeParse(id).success
}
export const SystemProviderIds = {
cherryin: 'cherryin',
silicon: 'silicon',
aihubmix: 'aihubmix',
ocoolai: 'ocoolai',
deepseek: 'deepseek',
ppio: 'ppio',
alayanew: 'alayanew',
qiniu: 'qiniu',
dmxapi: 'dmxapi',
burncloud: 'burncloud',
tokenflux: 'tokenflux',
'302ai': '302ai',
cephalon: 'cephalon',
lanyun: 'lanyun',
ph8: 'ph8',
sophnet: 'sophnet',
openrouter: 'openrouter',
ollama: 'ollama',
ovms: 'ovms',
'new-api': 'new-api',
lmstudio: 'lmstudio',
anthropic: 'anthropic',
openai: 'openai',
'azure-openai': 'azure-openai',
gemini: 'gemini',
vertexai: 'vertexai',
github: 'github',
copilot: 'copilot',
zhipu: 'zhipu',
yi: 'yi',
moonshot: 'moonshot',
baichuan: 'baichuan',
dashscope: 'dashscope',
stepfun: 'stepfun',
doubao: 'doubao',
infini: 'infini',
minimax: 'minimax',
groq: 'groq',
together: 'together',
fireworks: 'fireworks',
nvidia: 'nvidia',
grok: 'grok',
hyperbolic: 'hyperbolic',
mistral: 'mistral',
jina: 'jina',
perplexity: 'perplexity',
modelscope: 'modelscope',
xirang: 'xirang',
hunyuan: 'hunyuan',
'tencent-cloud-ti': 'tencent-cloud-ti',
'baidu-cloud': 'baidu-cloud',
gpustack: 'gpustack',
voyageai: 'voyageai',
'aws-bedrock': 'aws-bedrock',
poe: 'poe',
aionly: 'aionly',
longcat: 'longcat',
huggingface: 'huggingface',
'ai-gateway': 'ai-gateway',
cerebras: 'cerebras'
} as const satisfies Record<SystemProviderId, SystemProviderId>
export type SystemProviderIdTypeMap = typeof SystemProviderIds

View File

@@ -1 +0,0 @@
export { getBaseModelName, getLowerBaseModelName } from './naming'

View File

@@ -1,31 +0,0 @@
/**
* Extracts the base name from a model ID.
* For example:
* - 'deepseek/deepseek-r1' => 'deepseek-r1'
* - 'deepseek-ai/deepseek/deepseek-r1' => 'deepseek-r1'
* @param {string} id The model ID
* @param {string} [delimiter='/'] The delimiter, defaults to '/'
* @returns {string} The base name
*/
export const getBaseModelName = (id: string, delimiter: string = '/'): string => {
const parts = id.split(delimiter)
return parts[parts.length - 1]
}
/**
* Extracts the base name from a model ID and lowercases it.
* For example:
* - 'deepseek/DeepSeek-R1' => 'deepseek-r1'
* - 'deepseek-ai/deepseek/DeepSeek-R1' => 'deepseek-r1'
* @param {string} id The model ID
* @param {string} [delimiter='/'] The delimiter, defaults to '/'
* @returns {string} The lowercase base name
*/
export const getLowerBaseModelName = (id: string, delimiter: string = '/'): string => {
const baseModelName = getBaseModelName(id, delimiter).toLowerCase()
// for openrouter
if (baseModelName.endsWith(':free')) {
return baseModelName.replace(':free', '')
}
return baseModelName
}

View File

@@ -1,638 +0,0 @@
/**
* AI SDK to Anthropic SSE Adapter
*
* Converts AI SDK's fullStream (TextStreamPart) events to Anthropic Messages API SSE format.
* This enables any AI provider supported by AI SDK to be exposed via Anthropic-compatible API.
*
* Anthropic SSE Event Flow:
* 1. message_start - Initial message with metadata
* 2. content_block_start - Begin a content block (text, tool_use, thinking)
* 3. content_block_delta - Incremental content updates
* 4. content_block_stop - End a content block
* 5. message_delta - Updates to overall message (stop_reason, usage)
* 6. message_stop - Stream complete
*
* @see https://docs.anthropic.com/en/api/messages-streaming
*/
import type {
ContentBlock,
InputJSONDelta,
Message,
MessageDeltaUsage,
RawContentBlockDeltaEvent,
RawContentBlockStartEvent,
RawContentBlockStopEvent,
RawMessageDeltaEvent,
RawMessageStartEvent,
RawMessageStopEvent,
RawMessageStreamEvent,
StopReason,
TextBlock,
TextDelta,
ThinkingBlock,
ThinkingDelta,
ToolUseBlock,
Usage
} from '@anthropic-ai/sdk/resources/messages'
import { loggerService } from '@logger'
import { type FinishReason, type LanguageModelUsage, type TextStreamPart, type ToolSet } from 'ai'
import { googleReasoningCache, openRouterReasoningCache } from '../../services/CacheService'
const logger = loggerService.withContext('AiSdkToAnthropicSSE')
interface ContentBlockState {
type: 'text' | 'tool_use' | 'thinking'
index: number
started: boolean
content: string
// For tool_use blocks
toolId?: string
toolName?: string
toolInput?: string
}
interface AdapterState {
messageId: string
model: string
inputTokens: number
outputTokens: number
cacheInputTokens: number
currentBlockIndex: number
blocks: Map<number, ContentBlockState>
textBlockIndex: number | null
// Track multiple thinking blocks by their reasoning ID
thinkingBlocks: Map<string, number> // reasoningId -> blockIndex
currentThinkingId: string | null // Currently active thinking block ID
toolBlocks: Map<string, number> // toolCallId -> blockIndex
stopReason: StopReason | null
hasEmittedMessageStart: boolean
}
export type SSEEventCallback = (event: RawMessageStreamEvent) => void
export interface AiSdkToAnthropicSSEOptions {
model: string
messageId?: string
inputTokens?: number
onEvent: SSEEventCallback
}
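// Hypothetical usage (sketch): bridge an AI SDK stream into an SSE response.
// const adapter = new AiSdkToAnthropicSSE({
//   model: 'claude-sonnet-4', // placeholder id
//   onEvent: (event) => res.write(`event: ${event.type}\ndata: ${JSON.stringify(event)}\n\n`)
// })
// await adapter.processStream(result.fullStream) // result from streamText(...)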
/**
* Adapter that converts AI SDK fullStream events to Anthropic SSE events
*/
export class AiSdkToAnthropicSSE {
private state: AdapterState
private onEvent: SSEEventCallback
constructor(options: AiSdkToAnthropicSSEOptions) {
this.onEvent = options.onEvent
this.state = {
messageId: options.messageId || `msg_${Date.now()}_${Math.random().toString(36).substring(2, 11)}`,
model: options.model,
inputTokens: options.inputTokens || 0,
outputTokens: 0,
cacheInputTokens: 0,
currentBlockIndex: 0,
blocks: new Map(),
textBlockIndex: null,
thinkingBlocks: new Map(),
currentThinkingId: null,
toolBlocks: new Map(),
stopReason: null,
hasEmittedMessageStart: false
}
}
/**
* Process the AI SDK stream and emit Anthropic SSE events
*/
async processStream(fullStream: ReadableStream<TextStreamPart<ToolSet>>): Promise<void> {
const reader = fullStream.getReader()
try {
// Emit message_start at the beginning
this.emitMessageStart()
while (true) {
const { done, value } = await reader.read()
if (done) {
break
}
this.processChunk(value)
}
// Ensure all blocks are closed and emit final events
this.finalize()
} catch (error) {
await reader.cancel()
throw error
} finally {
reader.releaseLock()
}
}
/**
* Process a single AI SDK chunk and emit corresponding Anthropic events
*/
private processChunk(chunk: TextStreamPart<ToolSet>): void {
logger.silly('AiSdkToAnthropicSSE - Processing chunk:', { chunk: JSON.stringify(chunk) })
switch (chunk.type) {
// === Text Events ===
case 'text-start':
this.startTextBlock()
break
case 'text-delta':
this.emitTextDelta(chunk.text || '')
break
case 'text-end':
this.stopTextBlock()
break
// === Reasoning/Thinking Events ===
case 'reasoning-start': {
const reasoningId = chunk.id
this.startThinkingBlock(reasoningId)
break
}
case 'reasoning-delta': {
const reasoningId = chunk.id
this.emitThinkingDelta(chunk.text || '', reasoningId)
break
}
case 'reasoning-end': {
const reasoningId = chunk.id
this.stopThinkingBlock(reasoningId)
break
}
// === Tool Events ===
case 'tool-call':
if (googleReasoningCache && chunk.providerMetadata?.google?.thoughtSignature) {
googleReasoningCache.set(
`google-${chunk.toolName}`,
chunk.providerMetadata?.google?.thoughtSignature as string
)
}
// FIXME: bind by tool-call ID
if (
openRouterReasoningCache &&
chunk.providerMetadata?.openrouter?.reasoning_details &&
Array.isArray(chunk.providerMetadata.openrouter.reasoning_details)
) {
openRouterReasoningCache.set(
'openrouter',
JSON.parse(JSON.stringify(chunk.providerMetadata.openrouter.reasoning_details))
)
}
this.handleToolCall({
type: 'tool-call',
toolCallId: chunk.toolCallId,
toolName: chunk.toolName,
args: chunk.input
})
break
case 'tool-result':
// this.handleToolResult({
// type: 'tool-result',
// toolCallId: chunk.toolCallId,
// toolName: chunk.toolName,
// args: chunk.input,
// result: chunk.output
// })
break
case 'finish-step':
if (chunk.finishReason === 'tool-calls') {
this.state.stopReason = 'tool_use'
}
break
case 'finish':
this.handleFinish(chunk)
break
case 'error':
throw chunk.error
// Ignore other event types
default:
break
}
}
private emitMessageStart(): void {
if (this.state.hasEmittedMessageStart) return
this.state.hasEmittedMessageStart = true
const usage: Usage = {
input_tokens: this.state.inputTokens,
output_tokens: 0,
cache_creation_input_tokens: 0,
cache_read_input_tokens: 0,
server_tool_use: null
}
const message: Message = {
id: this.state.messageId,
type: 'message',
role: 'assistant',
content: [],
model: this.state.model,
stop_reason: null,
stop_sequence: null,
usage
}
const event: RawMessageStartEvent = {
type: 'message_start',
message
}
this.onEvent(event)
}
private startTextBlock(): void {
// If we already have a text block, don't create another
if (this.state.textBlockIndex !== null) return
const index = this.state.currentBlockIndex++
this.state.textBlockIndex = index
this.state.blocks.set(index, {
type: 'text',
index,
started: true,
content: ''
})
const contentBlock: TextBlock = {
type: 'text',
text: '',
citations: null
}
const event: RawContentBlockStartEvent = {
type: 'content_block_start',
index,
content_block: contentBlock
}
this.onEvent(event)
}
private emitTextDelta(text: string): void {
if (!text) return
// Auto-start text block if not started
if (this.state.textBlockIndex === null) {
this.startTextBlock()
}
const index = this.state.textBlockIndex!
const block = this.state.blocks.get(index)
if (block) {
block.content += text
}
const delta: TextDelta = {
type: 'text_delta',
text
}
const event: RawContentBlockDeltaEvent = {
type: 'content_block_delta',
index,
delta
}
this.onEvent(event)
}
private stopTextBlock(): void {
if (this.state.textBlockIndex === null) return
const index = this.state.textBlockIndex
const event: RawContentBlockStopEvent = {
type: 'content_block_stop',
index
}
this.onEvent(event)
this.state.textBlockIndex = null
}
private startThinkingBlock(reasoningId: string): void {
// Check if this thinking block already exists
if (this.state.thinkingBlocks.has(reasoningId)) return
const index = this.state.currentBlockIndex++
this.state.thinkingBlocks.set(reasoningId, index)
this.state.currentThinkingId = reasoningId
this.state.blocks.set(index, {
type: 'thinking',
index,
started: true,
content: ''
})
const contentBlock: ThinkingBlock = {
type: 'thinking',
thinking: '',
signature: ''
}
const event: RawContentBlockStartEvent = {
type: 'content_block_start',
index,
content_block: contentBlock
}
this.onEvent(event)
}
private emitThinkingDelta(text: string, reasoningId?: string): void {
if (!text) return
// Determine which thinking block to use
const targetId = reasoningId || this.state.currentThinkingId
if (!targetId) {
// Auto-start thinking block if not started
const newId = `reasoning_${Date.now()}`
this.startThinkingBlock(newId)
return this.emitThinkingDelta(text, newId)
}
const index = this.state.thinkingBlocks.get(targetId)
if (index === undefined) {
// If the block doesn't exist, create it
this.startThinkingBlock(targetId)
return this.emitThinkingDelta(text, targetId)
}
const block = this.state.blocks.get(index)
if (block) {
block.content += text
}
const delta: ThinkingDelta = {
type: 'thinking_delta',
thinking: text
}
const event: RawContentBlockDeltaEvent = {
type: 'content_block_delta',
index,
delta
}
this.onEvent(event)
}
private stopThinkingBlock(reasoningId?: string): void {
const targetId = reasoningId || this.state.currentThinkingId
if (!targetId) return
const index = this.state.thinkingBlocks.get(targetId)
if (index === undefined) return
const event: RawContentBlockStopEvent = {
type: 'content_block_stop',
index
}
this.onEvent(event)
this.state.thinkingBlocks.delete(targetId)
// Update currentThinkingId if we just closed the current one
if (this.state.currentThinkingId === targetId) {
// Set to the most recent remaining thinking block, or null if none
const remaining = Array.from(this.state.thinkingBlocks.keys())
this.state.currentThinkingId = remaining.length > 0 ? remaining[remaining.length - 1] : null
}
}
private handleToolCall(chunk: { type: 'tool-call'; toolCallId: string; toolName: string; args: unknown }): void {
const { toolCallId, toolName, args } = chunk
// Check if we already have this tool call
if (this.state.toolBlocks.has(toolCallId)) {
return
}
const index = this.state.currentBlockIndex++
this.state.toolBlocks.set(toolCallId, index)
const inputJson = JSON.stringify(args)
this.state.blocks.set(index, {
type: 'tool_use',
index,
started: true,
content: inputJson,
toolId: toolCallId,
toolName,
toolInput: inputJson
})
// Emit content_block_start for tool_use
const contentBlock: ToolUseBlock = {
type: 'tool_use',
id: toolCallId,
name: toolName,
input: {}
}
const startEvent: RawContentBlockStartEvent = {
type: 'content_block_start',
index,
content_block: contentBlock
}
this.onEvent(startEvent)
// Emit the full input as a delta (Anthropic streams JSON incrementally)
const delta: InputJSONDelta = {
type: 'input_json_delta',
partial_json: inputJson
}
const deltaEvent: RawContentBlockDeltaEvent = {
type: 'content_block_delta',
index,
delta
}
this.onEvent(deltaEvent)
// Emit content_block_stop
const stopEvent: RawContentBlockStopEvent = {
type: 'content_block_stop',
index
}
this.onEvent(stopEvent)
// Mark that we have tool use
this.state.stopReason = 'tool_use'
}
private handleFinish(chunk: { type: 'finish'; finishReason?: FinishReason; totalUsage?: LanguageModelUsage }): void {
// Update usage
if (chunk.totalUsage) {
this.state.inputTokens = chunk.totalUsage.inputTokens || 0
this.state.outputTokens = chunk.totalUsage.outputTokens || 0
this.state.cacheInputTokens = chunk.totalUsage.cachedInputTokens || 0
}
// Determine finish reason
if (!this.state.stopReason) {
switch (chunk.finishReason) {
case 'stop':
this.state.stopReason = 'end_turn'
break
case 'length':
this.state.stopReason = 'max_tokens'
break
case 'tool-calls':
this.state.stopReason = 'tool_use'
break
case 'content-filter':
this.state.stopReason = 'refusal'
break
default:
this.state.stopReason = 'end_turn'
}
}
}
private finalize(): void {
// Close any open blocks
if (this.state.textBlockIndex !== null) {
this.stopTextBlock()
}
// Close all open thinking blocks
for (const reasoningId of this.state.thinkingBlocks.keys()) {
this.stopThinkingBlock(reasoningId)
}
// Emit message_delta with final stop reason and usage
const usage: MessageDeltaUsage = {
output_tokens: this.state.outputTokens,
input_tokens: this.state.inputTokens,
cache_creation_input_tokens: this.state.cacheInputTokens,
cache_read_input_tokens: null,
server_tool_use: null
}
const messageDeltaEvent: RawMessageDeltaEvent = {
type: 'message_delta',
delta: {
stop_reason: this.state.stopReason || 'end_turn',
stop_sequence: null
},
usage
}
this.onEvent(messageDeltaEvent)
// Emit message_stop
const messageStopEvent: RawMessageStopEvent = {
type: 'message_stop'
}
this.onEvent(messageStopEvent)
}
/**
* Set input token count (typically from prompt)
*/
setInputTokens(count: number): void {
this.state.inputTokens = count
}
/**
* Get the current message ID
*/
getMessageId(): string {
return this.state.messageId
}
/**
* Build a complete Message object for non-streaming responses
*/
buildNonStreamingResponse(): Message {
const content: ContentBlock[] = []
// Collect all content blocks in order
const sortedBlocks = Array.from(this.state.blocks.values()).sort((a, b) => a.index - b.index)
for (const block of sortedBlocks) {
switch (block.type) {
case 'text':
content.push({
type: 'text',
text: block.content,
citations: null
} as TextBlock)
break
case 'thinking':
content.push({
type: 'thinking',
thinking: block.content
} as ThinkingBlock)
break
case 'tool_use':
content.push({
type: 'tool_use',
id: block.toolId!,
name: block.toolName!,
input: JSON.parse(block.toolInput || '{}')
} as ToolUseBlock)
break
}
}
return {
id: this.state.messageId,
type: 'message',
role: 'assistant',
content,
model: this.state.model,
stop_reason: this.state.stopReason || 'end_turn',
stop_sequence: null,
usage: {
input_tokens: this.state.inputTokens,
output_tokens: this.state.outputTokens,
cache_creation_input_tokens: 0,
cache_read_input_tokens: 0,
server_tool_use: null
}
}
}
}
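// Usage sketch (illustrative, not part of this module's API): wiring the adapter to an
// Express response, assuming `res` is the Response and `result` comes from an AI SDK
// `streamText` call:
//
//   const adapter = new AiSdkToAnthropicSSE({
//     model: 'anthropic:claude-sonnet',
//     onEvent: (event) => res.write(formatSSEEvent(event))
//   })
//   await adapter.processStream(result.fullStream)
//   res.write(formatSSEDone())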
/**
* Format an Anthropic SSE event for HTTP streaming
*/
export function formatSSEEvent(event: RawMessageStreamEvent): string {
return `event: ${event.type}\ndata: ${JSON.stringify(event)}\n\n`
}
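// For example, a text delta event serializes to (shape shown for illustration):
//   event: content_block_delta
//   data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"Hi"}}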
/**
* Create a done marker for SSE stream
*/
export function formatSSEDone(): string {
return 'data: [DONE]\n\n'
}
export default AiSdkToAnthropicSSE

View File

@@ -1,13 +0,0 @@
/**
* Shared Adapters
*
* This module exports adapters for converting between different AI API formats.
*/
export {
AiSdkToAnthropicSSE,
type AiSdkToAnthropicSSEOptions,
formatSSEDone,
formatSSEEvent,
type SSEEventCallback
} from './AiSdkToAnthropicSSE'

View File

@@ -1,95 +0,0 @@
import * as z from 'zod/v4'
enum ReasoningFormat {
Unknown = 'unknown',
OpenAIResponsesV1 = 'openai-responses-v1',
XAIResponsesV1 = 'xai-responses-v1',
AnthropicClaudeV1 = 'anthropic-claude-v1',
GoogleGeminiV1 = 'google-gemini-v1'
}
// Anthropic Claude was the first reasoning format that
// we pass back and forth
export const DEFAULT_REASONING_FORMAT = ReasoningFormat.AnthropicClaudeV1
function isDefinedOrNotNull<T>(value: T | null | undefined): value is T {
return value !== null && value !== undefined
}
export enum ReasoningDetailType {
Summary = 'reasoning.summary',
Encrypted = 'reasoning.encrypted',
Text = 'reasoning.text'
}
export const CommonReasoningDetailSchema = z
.object({
id: z.string().nullish(),
format: z.enum(ReasoningFormat).nullish(),
index: z.number().optional()
})
.loose()
export const ReasoningDetailSummarySchema = z
.object({
type: z.literal(ReasoningDetailType.Summary),
summary: z.string()
})
.extend(CommonReasoningDetailSchema.shape)
export type ReasoningDetailSummary = z.infer<typeof ReasoningDetailSummarySchema>
export const ReasoningDetailEncryptedSchema = z
.object({
type: z.literal(ReasoningDetailType.Encrypted),
data: z.string()
})
.extend(CommonReasoningDetailSchema.shape)
export type ReasoningDetailEncrypted = z.infer<typeof ReasoningDetailEncryptedSchema>
export const ReasoningDetailTextSchema = z
.object({
type: z.literal(ReasoningDetailType.Text),
text: z.string().nullish(),
signature: z.string().nullish()
})
.extend(CommonReasoningDetailSchema.shape)
export type ReasoningDetailText = z.infer<typeof ReasoningDetailTextSchema>
export const ReasoningDetailUnionSchema = z.union([
ReasoningDetailSummarySchema,
ReasoningDetailEncryptedSchema,
ReasoningDetailTextSchema
])
export type ReasoningDetailUnion = z.infer<typeof ReasoningDetailUnionSchema>
const ReasoningDetailsWithUnknownSchema = z.union([ReasoningDetailUnionSchema, z.unknown().transform(() => null)])
export const ReasoningDetailArraySchema = z
.array(ReasoningDetailsWithUnknownSchema)
.transform((d) => d.filter((d): d is ReasoningDetailUnion => !!d))
export const OutputUnionToReasoningDetailsSchema = z.union([
z
.object({
delta: z.object({
reasoning_details: z.array(ReasoningDetailsWithUnknownSchema)
})
})
.transform((data) => data.delta.reasoning_details.filter(isDefinedOrNotNull)),
z
.object({
message: z.object({
reasoning_details: z.array(ReasoningDetailsWithUnknownSchema)
})
})
.transform((data) => data.message.reasoning_details.filter(isDefinedOrNotNull)),
z
.object({
text: z.string(),
reasoning_details: z.array(ReasoningDetailsWithUnknownSchema)
})
.transform((data) => data.reasoning_details.filter(isDefinedOrNotNull))
])
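// Usage sketch (hypothetical payload): unknown entries collapse to null and are
// filtered out by the array transform:
//
//   const parsed = ReasoningDetailArraySchema.safeParse([
//     { type: 'reasoning.text', text: 'step 1' },
//     { bogus: true }
//   ])
//   // parsed.success === true; parsed.data is [{ type: 'reasoning.text', text: 'step 1' }]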

View File

@@ -1,93 +1,17 @@
import type { MessageCreateParams } from '@anthropic-ai/sdk/resources'
import { loggerService } from '@logger'
import { buildSharedMiddlewares, type SharedMiddlewareConfig } from '@shared/middleware'
import { getAiSdkProviderId } from '@shared/provider'
import type { Provider } from '@types'
import type { Request, Response } from 'express'
import express from 'express'
import { messagesService } from '../services/messages'
import { generateUnifiedMessage, streamUnifiedMessages } from '../services/unified-messages'
import { getProviderById, isModelAnthropicCompatible, validateModelId } from '../utils'
/**
* Check if a specific model on a provider should use direct Anthropic SDK
*
* A provider+model combination is considered "Anthropic-compatible" if:
* 1. It's a native Anthropic provider (type === 'anthropic'), OR
* 2. It has anthropicApiHost configured AND the specific model supports Anthropic API
* (for aggregated providers like Silicon, only certain models support the Anthropic endpoint)
*
* @param provider - The provider to check
* @param modelId - The model ID to check (without provider prefix)
* @returns true if should use direct Anthropic SDK, false for unified SDK
*/
function shouldUseDirectAnthropic(provider: Provider, modelId: string): boolean {
// Native Anthropic provider - always use direct SDK
if (provider.type === 'anthropic') {
return true
}
// No anthropicApiHost configured - use unified SDK
if (!provider.anthropicApiHost?.trim()) {
return false
}
// Has anthropicApiHost - check model-level compatibility
// For aggregated providers, only specific models support Anthropic API
return isModelAnthropicCompatible(provider, modelId)
}
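// For example (illustrative provider shapes):
//   shouldUseDirectAnthropic({ type: 'anthropic', ... }, 'claude-sonnet') -> true
//   shouldUseDirectAnthropic({ type: 'openai', anthropicApiHost: '', ... }, 'gpt-4o') -> false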
import { getProviderById, validateModelId } from '../utils'
const logger = loggerService.withContext('ApiServerMessagesRoutes')
const router = express.Router()
const providerRouter = express.Router({ mergeParams: true })
/**
* Estimate token count from messages
* Simple approximation: ~4 characters per token for English text
*/
interface CountTokensInput {
messages: Array<{ role: string; content: string | Array<{ type: string; text?: string }> }>
system?: string | Array<{ type: string; text?: string }>
}
function estimateTokenCount(input: CountTokensInput): number {
const { messages, system } = input
let totalChars = 0
// Count system message tokens
if (system) {
if (typeof system === 'string') {
totalChars += system.length
} else if (Array.isArray(system)) {
for (const block of system) {
if (block.type === 'text' && block.text) {
totalChars += block.text.length
}
}
}
}
// Count message tokens
for (const msg of messages) {
if (typeof msg.content === 'string') {
totalChars += msg.content.length
} else if (Array.isArray(msg.content)) {
for (const block of msg.content) {
if (block.type === 'text' && block.text) {
totalChars += block.text.length
}
}
}
// Add overhead for role
totalChars += 10
}
// Estimate tokens (~4 chars per token, with some overhead)
return Math.ceil(totalChars / 4) + messages.length * 3
}
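// Worked example (illustrative): one user message 'Hello, world!' (13 chars) plus the
// 10-char role overhead gives 23 chars, so ceil(23 / 4) + 1 * 3 = 6 + 3 = 9 tokens.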
// Helper function for basic request validation
async function validateRequestBody(req: Request): Promise<{ valid: boolean; error?: any }> {
const request: MessageCreateParams = req.body
@@ -109,36 +33,21 @@ async function validateRequestBody(req: Request): Promise<{ valid: boolean; erro
}
interface HandleMessageProcessingOptions {
req: Request
res: Response
provider: Provider
request: MessageCreateParams
modelId?: string
}
/**
* Handle message processing using direct Anthropic SDK
* Used for providers with anthropicApiHost or native Anthropic providers
* This bypasses AI SDK conversion and uses native Anthropic protocol
*/
async function handleDirectAnthropicProcessing({
async function handleMessageProcessing({
req,
res,
provider,
request,
modelId,
extraHeaders
}: HandleMessageProcessingOptions & { extraHeaders?: Record<string, string | string[]> }): Promise<void> {
const actualModelId = modelId || request.model
logger.info('Processing message via direct Anthropic SDK', {
providerId: provider.id,
providerType: provider.type,
modelId: actualModelId,
stream: !!request.stream,
anthropicApiHost: provider.anthropicApiHost
})
modelId
}: HandleMessageProcessingOptions): Promise<void> {
try {
// Validate request
const validation = messagesService.validateRequest(request)
if (!validation.isValid) {
res.status(400).json({
@@ -151,126 +60,28 @@ async function handleDirectAnthropicProcessing({
return
}
// Process message using messagesService (native Anthropic SDK)
const extraHeaders = messagesService.prepareHeaders(req.headers)
const { client, anthropicRequest } = await messagesService.processMessage({
provider,
request,
extraHeaders,
modelId: actualModelId
modelId
})
if (request.stream) {
// Use native Anthropic streaming
await messagesService.handleStreaming(client, anthropicRequest, { response: res }, provider)
} else {
// Use native Anthropic non-streaming
const response = await client.messages.create(anthropicRequest)
res.json(response)
}
} catch (error: any) {
logger.error('Direct Anthropic processing error', { error })
const { statusCode, errorResponse } = messagesService.transformError(error)
res.status(statusCode).json(errorResponse)
}
}
/**
* Handle message processing using unified AI SDK
* Used for non-Anthropic providers that need format conversion
* - Uses AI SDK adapters with output converted to Anthropic SSE format
*/
async function handleUnifiedProcessing({
res,
provider,
request,
modelId
}: HandleMessageProcessingOptions): Promise<void> {
const actualModelId = modelId || request.model
logger.info('Processing message via unified AI SDK', {
providerId: provider.id,
providerType: provider.type,
modelId: actualModelId,
stream: !!request.stream
})
try {
// Validate request
const validation = messagesService.validateRequest(request)
if (!validation.isValid) {
res.status(400).json({
type: 'error',
error: {
type: 'invalid_request_error',
message: validation.errors.join('; ')
}
})
return
}
const middlewareConfig: SharedMiddlewareConfig = {
modelId: actualModelId,
providerId: provider.id,
aiSdkProviderId: getAiSdkProviderId(provider)
}
const middlewares = buildSharedMiddlewares(middlewareConfig)
logger.debug('Built middlewares for unified processing', {
middlewareCount: middlewares.length,
modelId: actualModelId,
providerId: provider.id
})
if (request.stream) {
await streamUnifiedMessages({
response: res,
provider,
modelId: actualModelId,
params: request,
middlewares,
onError: (error) => {
logger.error('Stream error', error as Error)
},
onComplete: () => {
logger.debug('Stream completed')
}
})
} else {
const response = await generateUnifiedMessage({
provider,
modelId: actualModelId,
params: request,
middlewares
})
res.json(response)
}
const response = await client.messages.create(anthropicRequest)
res.json(response)
} catch (error: any) {
logger.error('Message processing error', { error })
const { statusCode, errorResponse } = messagesService.transformError(error)
res.status(statusCode).json(errorResponse)
}
}
/**
* Handle message processing - routes to appropriate handler based on provider and model
*
* Routing logic:
* - Native Anthropic providers (type === 'anthropic'): Direct Anthropic SDK
* - Providers with anthropicApiHost AND model supports Anthropic API: Direct Anthropic SDK
* - Other providers/models: Unified AI SDK with Anthropic SSE conversion
*/
async function handleMessageProcessing({
res,
provider,
request,
modelId
}: HandleMessageProcessingOptions): Promise<void> {
const actualModelId = modelId || request.model
if (shouldUseDirectAnthropic(provider, actualModelId)) {
return handleDirectAnthropicProcessing({ res, provider, request, modelId })
}
return handleUnifiedProcessing({ res, provider, request, modelId })
}
/**
* @swagger
* /v1/messages:
@@ -424,7 +235,7 @@ router.post('/', async (req: Request, res: Response) => {
const provider = modelValidation.provider!
const modelId = modelValidation.modelId!
return handleMessageProcessing({ res, provider, request, modelId })
return handleMessageProcessing({ req, res, provider, request, modelId })
} catch (error: any) {
logger.error('Message processing error', { error })
const { statusCode, errorResponse } = messagesService.transformError(error)
@@ -582,7 +393,7 @@ providerRouter.post('/', async (req: Request, res: Response) => {
const request: MessageCreateParams = req.body
return handleMessageProcessing({ res, provider, request })
return handleMessageProcessing({ req, res, provider, request })
} catch (error: any) {
logger.error('Message processing error', { error })
const { statusCode, errorResponse } = messagesService.transformError(error)
@@ -590,132 +401,4 @@ providerRouter.post('/', async (req: Request, res: Response) => {
}
})
/**
* @swagger
* /v1/messages/count_tokens:
* post:
* summary: Count tokens for messages
* description: Count tokens for Anthropic Messages API format (required by Claude Code SDK)
* tags: [Messages]
* requestBody:
* required: true
* content:
* application/json:
* schema:
* type: object
* required:
* - model
* - messages
* properties:
* model:
* type: string
* description: Model ID
* messages:
* type: array
* items:
* type: object
* system:
* type: string
* description: System message
* responses:
* 200:
* description: Token count response
* content:
* application/json:
* schema:
* type: object
* properties:
* input_tokens:
* type: integer
* 400:
* description: Bad request
*/
router.post('/count_tokens', async (req: Request, res: Response) => {
try {
const { model, messages, system } = req.body
if (!model) {
return res.status(400).json({
type: 'error',
error: {
type: 'invalid_request_error',
message: 'model parameter is required'
}
})
}
if (!messages || !Array.isArray(messages)) {
return res.status(400).json({
type: 'error',
error: {
type: 'invalid_request_error',
message: 'messages parameter is required'
}
})
}
const estimatedTokens = estimateTokenCount({ messages, system })
logger.debug('Token count estimated', {
model,
messageCount: messages.length,
estimatedTokens
})
return res.json({
input_tokens: estimatedTokens
})
} catch (error: any) {
logger.error('Token counting error', { error })
return res.status(500).json({
type: 'error',
error: {
type: 'api_error',
message: error.message || 'Internal server error'
}
})
}
})
/**
* Provider-specific count_tokens endpoint
*/
providerRouter.post('/count_tokens', async (req: Request, res: Response) => {
try {
const { model, messages, system } = req.body
if (!messages || !Array.isArray(messages)) {
return res.status(400).json({
type: 'error',
error: {
type: 'invalid_request_error',
message: 'messages parameter is required'
}
})
}
const estimatedTokens = estimateTokenCount({ messages, system })
logger.debug('Token count estimated (provider route)', {
providerId: req.params.provider,
model,
messageCount: messages.length,
estimatedTokens
})
return res.json({
input_tokens: estimatedTokens
})
} catch (error: any) {
logger.error('Token counting error', { error })
return res.status(500).json({
type: 'error',
error: {
type: 'api_error',
message: error.message || 'Internal server error'
}
})
}
})
export { providerRouter as messagesProviderRoutes, router as messagesRoutes }

View File

@@ -2,10 +2,8 @@ import type Anthropic from '@anthropic-ai/sdk'
import type { MessageCreateParams, MessageStreamEvent } from '@anthropic-ai/sdk/resources'
import { loggerService } from '@logger'
import anthropicService from '@main/services/AnthropicService'
import { buildClaudeCodeSystemMessage, getSdkClient, sanitizeToolsForAnthropic } from '@shared/anthropic'
import { buildClaudeCodeSystemMessage, getSdkClient } from '@shared/anthropic'
import type { Provider } from '@types'
import { APICallError, RetryError } from 'ai'
import { net } from 'electron'
import type { Response } from 'express'
const logger = loggerService.withContext('MessagesService')
@@ -100,30 +98,11 @@ export class MessagesService {
async getClient(provider: Provider, extraHeaders?: Record<string, string | string[]>): Promise<Anthropic> {
// Create Anthropic client for the provider
// Wrap net.fetch to handle compatibility issues:
// 1. net.fetch expects string URLs, not Request objects
// 2. net.fetch doesn't support 'agent' option from Node.js http module
const electronFetch: typeof globalThis.fetch = async (input: URL | RequestInfo, init?: RequestInit) => {
const url = typeof input === 'string' ? input : input instanceof URL ? input.toString() : input.url
// Remove unsupported options for Electron's net.fetch
if (init) {
const initWithAgent = init as RequestInit & { agent?: unknown }
delete initWithAgent.agent
const headers = new Headers(initWithAgent.headers)
if (headers.has('content-length')) {
headers.delete('content-length')
}
initWithAgent.headers = headers
return net.fetch(url, initWithAgent)
}
return net.fetch(url)
}
const context = { fetch: electronFetch }
if (provider.authType === 'oauth') {
const oauthToken = await anthropicService.getValidAccessToken()
return getSdkClient(provider, oauthToken, extraHeaders, context)
return getSdkClient(provider, oauthToken, extraHeaders)
}
return getSdkClient(provider, null, extraHeaders, context)
return getSdkClient(provider, null, extraHeaders)
}
prepareHeaders(headers: Record<string, string | string[] | undefined>): Record<string, string | string[]> {
@@ -148,8 +127,7 @@ export class MessagesService {
createAnthropicRequest(request: MessageCreateParams, provider: Provider, modelId?: string): MessageCreateParams {
const anthropicRequest: MessageCreateParams = {
...request,
stream: !!request.stream,
tools: sanitizeToolsForAnthropic(request.tools)
stream: !!request.stream
}
// Override model if provided
@@ -255,71 +233,9 @@ export class MessagesService {
}
transformError(error: any): { statusCode: number; errorResponse: ErrorResponse } {
let statusCode: number | undefined = undefined
let errorType: string | undefined = undefined
let errorMessage: string | undefined = undefined
const errorMap: Record<number, string> = {
400: 'invalid_request_error',
401: 'authentication_error',
403: 'forbidden_error',
404: 'not_found_error',
429: 'rate_limit_error',
500: 'internal_server_error'
}
// Handle AI SDK RetryError - extract the last error for better error messages
if (RetryError.isInstance(error)) {
const lastError = error.lastError
// If the last error is an APICallError, extract its details
if (APICallError.isInstance(lastError)) {
statusCode = lastError.statusCode || 502
errorMessage = lastError.message
return {
statusCode,
errorResponse: {
type: 'error',
error: {
type: errorMap[statusCode] || 'api_error',
message: `${error.reason}: ${errorMessage}`,
requestId: lastError.name
}
}
}
}
// Fallback for other retry errors
errorMessage = error.message
statusCode = 502
return {
statusCode,
errorResponse: {
type: 'error',
error: {
type: 'api_error',
message: errorMessage,
requestId: error.name
}
}
}
}
if (APICallError.isInstance(error)) {
statusCode = error.statusCode
errorMessage = error.message
if (statusCode) {
return {
statusCode,
errorResponse: {
type: 'error',
error: {
type: errorMap[statusCode] || 'api_error',
message: errorMessage,
requestId: error.name
}
}
}
}
}
let statusCode = 500
let errorType = 'api_error'
let errorMessage = 'Internal server error'
const anthropicStatus = typeof error?.status === 'number' ? error.status : undefined
const anthropicError = error?.error
@@ -361,11 +277,11 @@ export class MessagesService {
typeof errorMessage === 'string' && errorMessage.length > 0 ? errorMessage : 'Internal server error'
return {
statusCode: statusCode ?? 500,
statusCode,
errorResponse: {
type: 'error',
error: {
type: errorType || 'api_error',
type: errorType,
message: safeErrorMessage,
requestId: error?.request_id
}

View File

@@ -1,6 +1,13 @@
import { isEmpty } from 'lodash'
import type { ApiModel, ApiModelsFilter, ApiModelsResponse } from '../../../renderer/src/types/apiModels'
import { loggerService } from '../../services/LoggerService'
import { getAvailableProviders, listAllAvailableModels, transformModelToOpenAI } from '../utils'
import {
getAvailableProviders,
getProviderAnthropicModelChecker,
listAllAvailableModels,
transformModelToOpenAI
} from '../utils'
const logger = loggerService.withContext('ModelsService')
@@ -13,12 +20,11 @@ export class ModelsService {
try {
logger.debug('Getting available models from providers', { filter })
const providers = await getAvailableProviders()
let providers = await getAvailableProviders()
// Note: When providerType === 'anthropic', we now return ALL available models
// because the API Server's unified adapter (AiSdkToAnthropicSSE) can convert
// any provider's response to Anthropic SSE format. This enables Claude Code Agent
// to work with OpenAI, Gemini, and other providers transparently.
if (filter.providerType === 'anthropic') {
providers = providers.filter((p) => p.type === 'anthropic' || !isEmpty(p.anthropicApiHost?.trim()))
}
const models = await listAllAvailableModels(providers)
// Use Map to deduplicate models by their full ID (provider:model_id)
@@ -26,11 +32,20 @@ export class ModelsService {
for (const model of models) {
const provider = providers.find((p) => p.id === model.provider)
// logger.debug(`Processing model ${model.id}`)
if (!provider) {
logger.debug(`Skipping model ${model.id}. Reason: Provider not found.`)
continue
}
if (filter.providerType === 'anthropic') {
const checker = getProviderAnthropicModelChecker(provider.id)
if (!checker(model)) {
logger.debug(`Skipping model ${model.id} from ${model.provider}. Reason: Not an Anthropic model.`)
continue
}
}
const openAIModel = transformModelToOpenAI(model, provider)
const fullModelId = openAIModel.id // This is already in format "provider:model_id"

View File

@@ -1,718 +0,0 @@
import type { AnthropicProviderOptions } from '@ai-sdk/anthropic'
import type { GoogleGenerativeAIProviderOptions } from '@ai-sdk/google'
import type { OpenAIResponsesProviderOptions } from '@ai-sdk/openai'
import type { LanguageModelV2Middleware, LanguageModelV2ToolResultOutput } from '@ai-sdk/provider'
import type { ProviderOptions, ReasoningPart, ToolCallPart, ToolResultPart } from '@ai-sdk/provider-utils'
import type {
ImageBlockParam,
MessageCreateParams,
TextBlockParam,
Tool as AnthropicTool
} from '@anthropic-ai/sdk/resources/messages'
import { type AiPlugin, createExecutor } from '@cherrystudio/ai-core'
import { createProvider as createProviderCore } from '@cherrystudio/ai-core/provider'
import { loggerService } from '@logger'
import { AiSdkToAnthropicSSE, formatSSEDone, formatSSEEvent } from '@main/apiServer/adapters'
import { generateSignature as cherryaiGenerateSignature } from '@main/integration/cherryai'
import anthropicService from '@main/services/AnthropicService'
import copilotService from '@main/services/CopilotService'
import { reduxService } from '@main/services/ReduxService'
import { isGemini3ModelId } from '@shared/middleware'
import {
type AiSdkConfig,
type AiSdkConfigContext,
formatProviderApiHost,
initializeSharedProviders,
isAnthropicProvider,
isGeminiProvider,
isOpenAIProvider,
type ProviderFormatContext,
providerToAiSdkConfig as sharedProviderToAiSdkConfig,
resolveActualProvider
} from '@shared/provider'
import { COPILOT_DEFAULT_HEADERS } from '@shared/provider/constant'
import { defaultAppHeaders } from '@shared/utils'
import type { Provider } from '@types'
import type { ImagePart, JSONValue, ModelMessage, Provider as AiSdkProvider, TextPart, Tool as AiSdkTool } from 'ai'
import { simulateStreamingMiddleware, stepCountIs, tool, wrapLanguageModel, zodSchema } from 'ai'
import { net } from 'electron'
import type { Response } from 'express'
import * as z from 'zod'
import { googleReasoningCache, openRouterReasoningCache } from '../../services/CacheService'
const logger = loggerService.withContext('UnifiedMessagesService')
const MAGIC_STRING = 'skip_thought_signature_validator'
function sanitizeJson(value: unknown): JSONValue {
return JSON.parse(JSON.stringify(value))
}
initializeSharedProviders({
warn: (message) => logger.warn(message),
error: (message, error) => logger.error(message, error)
})
/**
* Configuration for unified message streaming
*/
export interface UnifiedStreamConfig {
response: Response
provider: Provider
modelId: string
params: MessageCreateParams
onError?: (error: unknown) => void
onComplete?: () => void
/**
* Optional AI SDK middlewares to apply
*/
middlewares?: LanguageModelV2Middleware[]
/**
* Optional AI Core plugins to use with the executor
*/
plugins?: AiPlugin[]
}
/**
* Configuration for non-streaming message generation
*/
export interface GenerateUnifiedMessageConfig {
provider: Provider
modelId: string
params: MessageCreateParams
middlewares?: LanguageModelV2Middleware[]
plugins?: AiPlugin[]
}
function getMainProcessFormatContext(): ProviderFormatContext {
const vertexSettings = reduxService.selectSync<{ projectId: string; location: string }>('state.llm.settings.vertexai')
return {
vertex: {
project: vertexSettings?.projectId || 'default-project',
location: vertexSettings?.location || 'us-central1'
}
}
}
const mainProcessSdkContext: AiSdkConfigContext = {
getRotatedApiKey: (provider) => {
const keys = provider.apiKey.split(',').map((k) => k.trim())
return keys[0] || provider.apiKey
},
fetch: net.fetch as typeof globalThis.fetch
}
function getActualProvider(provider: Provider, modelId: string): Provider {
const model = provider.models?.find((m) => m.id === modelId)
if (!model) return provider
return resolveActualProvider(provider, model)
}
function providerToAiSdkConfig(provider: Provider, modelId: string): AiSdkConfig {
const actualProvider = getActualProvider(provider, modelId)
const formattedProvider = formatProviderApiHost(actualProvider, getMainProcessFormatContext())
return sharedProviderToAiSdkConfig(formattedProvider, modelId, mainProcessSdkContext)
}
function convertAnthropicToolResultToAiSdk(
content: string | Array<TextBlockParam | ImageBlockParam>
): LanguageModelV2ToolResultOutput {
if (typeof content === 'string') {
return { type: 'text', value: content }
}
const values: Array<{ type: 'text'; text: string } | { type: 'media'; data: string; mediaType: string }> = []
for (const block of content) {
if (block.type === 'text') {
values.push({ type: 'text', text: block.text })
} else if (block.type === 'image') {
values.push({
type: 'media',
data: block.source.type === 'base64' ? block.source.data : block.source.url,
mediaType: block.source.type === 'base64' ? block.source.media_type : 'image/png'
})
}
}
return { type: 'content', value: values }
}
// Type alias for JSON Schema (compatible with recursive calls)
type JsonSchemaLike = AnthropicTool.InputSchema | Record<string, unknown>
/**
* Convert JSON Schema to Zod schema
* This avoids non-standard fields like input_examples that Anthropic doesn't support
*/
function jsonSchemaToZod(schema: JsonSchemaLike): z.ZodTypeAny {
const s = schema as Record<string, unknown>
const schemaType = s.type as string | string[] | undefined
const enumValues = s.enum as unknown[] | undefined
const description = s.description as string | undefined
// Handle enum first
if (enumValues && Array.isArray(enumValues) && enumValues.length > 0) {
if (enumValues.every((v) => typeof v === 'string')) {
const zodEnum = z.enum(enumValues as [string, ...string[]])
return description ? zodEnum.describe(description) : zodEnum
}
// For non-string enums, use union of literals
const literals = enumValues.map((v) => z.literal(v as string | number | boolean))
if (literals.length === 1) {
return description ? literals[0].describe(description) : literals[0]
}
const zodUnion = z.union(literals as unknown as [z.ZodTypeAny, z.ZodTypeAny, ...z.ZodTypeAny[]])
return description ? zodUnion.describe(description) : zodUnion
}
// Handle union types (type: ["string", "null"])
if (Array.isArray(schemaType)) {
const schemas = schemaType.map((t) => jsonSchemaToZod({ ...s, type: t, enum: undefined }))
if (schemas.length === 1) {
return schemas[0]
}
return z.union(schemas as [z.ZodTypeAny, z.ZodTypeAny, ...z.ZodTypeAny[]])
}
// Handle by type
switch (schemaType) {
case 'string': {
let zodString = z.string()
if (typeof s.minLength === 'number') zodString = zodString.min(s.minLength)
if (typeof s.maxLength === 'number') zodString = zodString.max(s.maxLength)
if (typeof s.pattern === 'string') zodString = zodString.regex(new RegExp(s.pattern))
return description ? zodString.describe(description) : zodString
}
case 'number':
case 'integer': {
let zodNumber = schemaType === 'integer' ? z.number().int() : z.number()
if (typeof s.minimum === 'number') zodNumber = zodNumber.min(s.minimum)
if (typeof s.maximum === 'number') zodNumber = zodNumber.max(s.maximum)
return description ? zodNumber.describe(description) : zodNumber
}
case 'boolean': {
const zodBoolean = z.boolean()
return description ? zodBoolean.describe(description) : zodBoolean
}
case 'null':
return z.null()
case 'array': {
const items = s.items as Record<string, unknown> | undefined
let zodArray = items ? z.array(jsonSchemaToZod(items)) : z.array(z.unknown())
if (typeof s.minItems === 'number') zodArray = zodArray.min(s.minItems)
if (typeof s.maxItems === 'number') zodArray = zodArray.max(s.maxItems)
return description ? zodArray.describe(description) : zodArray
}
case 'object': {
const properties = s.properties as Record<string, Record<string, unknown>> | undefined
const required = (s.required as string[]) || []
// Always use z.object() to ensure "properties" field is present in output schema
// OpenAI requires explicit properties field even for empty objects
const shape: Record<string, z.ZodTypeAny> = {}
if (properties) {
for (const [key, propSchema] of Object.entries(properties)) {
const zodProp = jsonSchemaToZod(propSchema)
shape[key] = required.includes(key) ? zodProp : zodProp.optional()
}
}
const zodObject = z.object(shape)
return description ? zodObject.describe(description) : zodObject
}
default:
// Unknown type, use z.unknown()
return z.unknown()
}
}
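// Sketch (illustrative input): a minimal tool input schema
//   jsonSchemaToZod({
//     type: 'object',
//     properties: { city: { type: 'string', description: 'City name' } },
//     required: ['city']
//   })
// is roughly equivalent to z.object({ city: z.string().describe('City name') }).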
function convertAnthropicToolsToAiSdk(tools: MessageCreateParams['tools']): Record<string, AiSdkTool> | undefined {
if (!tools || tools.length === 0) return undefined
const aiSdkTools: Record<string, AiSdkTool> = {}
for (const anthropicTool of tools) {
if (anthropicTool.type === 'bash_20250124') continue
const toolDef = anthropicTool as AnthropicTool
const rawSchema = toolDef.input_schema
const schema = jsonSchemaToZod(rawSchema)
// Use tool() with inputSchema (AI SDK v5 API)
const aiTool = tool({
description: toolDef.description || '',
inputSchema: zodSchema(schema)
})
aiSdkTools[toolDef.name] = aiTool
}
return Object.keys(aiSdkTools).length > 0 ? aiSdkTools : undefined
}
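// Usage sketch (hypothetical tool definition):
//   convertAnthropicToolsToAiSdk([{
//     name: 'get_weather',
//     description: 'Look up weather',
//     input_schema: { type: 'object', properties: { city: { type: 'string' } }, required: ['city'] }
//   }])
// returns { get_weather } as an AI SDK tool whose inputSchema validates { city: string }.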
function convertAnthropicToAiMessages(params: MessageCreateParams): ModelMessage[] {
const messages: ModelMessage[] = []
// System message
if (params.system) {
if (typeof params.system === 'string') {
messages.push({ role: 'system', content: params.system })
} else if (Array.isArray(params.system)) {
const systemText = params.system
.filter((block) => block.type === 'text')
.map((block) => block.text)
.join('\n')
if (systemText) {
messages.push({ role: 'system', content: systemText })
}
}
}
const toolCallIdToName = new Map<string, string>()
for (const msg of params.messages) {
if (Array.isArray(msg.content)) {
for (const block of msg.content) {
if (block.type === 'tool_use') {
toolCallIdToName.set(block.id, block.name)
}
}
}
}
// User/assistant messages
for (const msg of params.messages) {
if (typeof msg.content === 'string') {
messages.push({
role: msg.role === 'user' ? 'user' : 'assistant',
content: msg.content
})
} else if (Array.isArray(msg.content)) {
const textParts: TextPart[] = []
const imageParts: ImagePart[] = []
const reasoningParts: ReasoningPart[] = []
const toolCallParts: ToolCallPart[] = []
const toolResultParts: ToolResultPart[] = []
for (const block of msg.content) {
if (block.type === 'text') {
textParts.push({ type: 'text', text: block.text })
} else if (block.type === 'thinking') {
reasoningParts.push({ type: 'reasoning', text: block.thinking })
} else if (block.type === 'redacted_thinking') {
reasoningParts.push({ type: 'reasoning', text: block.data })
} else if (block.type === 'image') {
const source = block.source
if (source.type === 'base64') {
imageParts.push({ type: 'image', image: `data:${source.media_type};base64,${source.data}` })
} else if (source.type === 'url') {
imageParts.push({ type: 'image', image: source.url })
}
} else if (block.type === 'tool_use') {
const options: ProviderOptions = {}
if (isGemini3ModelId(params.model)) {
if (googleReasoningCache.get(`google-${block.name}`)) {
options.google = {
thoughtSignature: MAGIC_STRING
}
} else if (openRouterReasoningCache.get('openrouter')) {
options.openrouter = {
reasoning_details: (sanitizeJson(openRouterReasoningCache.get('openrouter')) as JSONValue[]) || []
}
}
}
toolCallParts.push({
type: 'tool-call',
toolName: block.name,
toolCallId: block.id,
input: block.input,
providerOptions: options
})
} else if (block.type === 'tool_result') {
// Look up toolName from the pre-built map (covers cross-message references)
const toolName = toolCallIdToName.get(block.tool_use_id) || 'unknown'
toolResultParts.push({
type: 'tool-result',
toolCallId: block.tool_use_id,
toolName,
output: block.content ? convertAnthropicToolResultToAiSdk(block.content) : { type: 'text', value: '' }
})
}
}
if (toolResultParts.length > 0) {
messages.push({ role: 'tool', content: [...toolResultParts] })
}
if (msg.role === 'user') {
const userContent = [...textParts, ...imageParts]
if (userContent.length > 0) {
messages.push({ role: 'user', content: userContent })
}
} else {
const assistantContent = [...reasoningParts, ...textParts, ...toolCallParts]
if (assistantContent.length > 0) {
let providerOptions: ProviderOptions | undefined = undefined
if (openRouterReasoningCache.get('openrouter')) {
providerOptions = {
openrouter: {
reasoning_details: (sanitizeJson(openRouterReasoningCache.get('openrouter')) as JSONValue[]) || []
}
}
} else if (isGemini3ModelId(params.model)) {
providerOptions = {
google: {
thoughtSignature: MAGIC_STRING
}
}
}
messages.push({ role: 'assistant', content: assistantContent, providerOptions })
}
}
}
}
return messages
}
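// For example (illustrative request), a string system prompt plus one user turn
//   { system: 'Be terse.', messages: [{ role: 'user', content: 'hi' }] }
// converts to
//   [{ role: 'system', content: 'Be terse.' }, { role: 'user', content: 'hi' }]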
interface ExecuteStreamConfig {
provider: Provider
modelId: string
params: MessageCreateParams
middlewares?: LanguageModelV2Middleware[]
plugins?: AiPlugin[]
onEvent?: (event: Parameters<typeof formatSSEEvent>[0]) => void
}
/**
* Create AI SDK provider instance from config
* Similar to renderer's createAiSdkProvider
*/
async function createAiSdkProvider(config: AiSdkConfig): Promise<AiSdkProvider> {
let providerId = config.providerId
// Handle special provider modes (same as renderer)
if (providerId === 'openai' && config.options?.mode === 'chat') {
providerId = 'openai-chat'
} else if (providerId === 'azure' && config.options?.mode === 'responses') {
providerId = 'azure-responses'
} else if (providerId === 'cherryin' && config.options?.mode === 'chat') {
providerId = 'cherryin-chat'
}
const provider = await createProviderCore(providerId, config.options)
return provider
}
/**
* Prepare special provider configuration for providers that need dynamic tokens
* Similar to renderer's prepareSpecialProviderConfig
*/
async function prepareSpecialProviderConfig(provider: Provider, config: AiSdkConfig): Promise<AiSdkConfig> {
switch (provider.id) {
case 'copilot': {
const storedHeaders =
((await reduxService.select('state.copilot.defaultHeaders')) as Record<string, string> | null) ?? {}
const headers: Record<string, string> = {
...COPILOT_DEFAULT_HEADERS,
...storedHeaders
}
try {
const { token } = await copilotService.getToken(null as any, headers)
config.options.apiKey = token
const existingHeaders = (config.options.headers as Record<string, string> | undefined) ?? {}
config.options.headers = {
...headers,
...existingHeaders
}
} catch (error) {
logger.error('Failed to get Copilot token', error as Error)
throw new Error('Failed to get Copilot token. Please re-authorize Copilot.')
}
break
}
case 'anthropic': {
if (provider.authType === 'oauth') {
try {
const oauthToken = await anthropicService.getValidAccessToken()
if (!oauthToken) {
throw new Error('Anthropic OAuth token not available. Please re-authorize.')
}
config.options = {
...config.options,
headers: {
...(config.options.headers ? config.options.headers : {}),
'Content-Type': 'application/json',
'anthropic-version': '2023-06-01',
'anthropic-beta': 'oauth-2025-04-20',
Authorization: `Bearer ${oauthToken}`
},
baseURL: 'https://api.anthropic.com/v1',
apiKey: ''
}
} catch (error) {
logger.error('Failed to get Anthropic OAuth token', error as Error)
throw new Error('Failed to get Anthropic OAuth token. Please re-authorize.')
}
}
break
}
case 'cherryai': {
// Create a signed fetch wrapper for cherryai
const baseFetch = net.fetch as typeof globalThis.fetch
config.options.fetch = async (url: RequestInfo | URL, options?: RequestInit) => {
if (!options?.body) {
return baseFetch(url, options)
}
const signature = cherryaiGenerateSignature({
method: 'POST',
path: '/chat/completions',
query: '',
body: JSON.parse(options.body as string)
})
return baseFetch(url, {
...options,
headers: {
...(options.headers as Record<string, string>),
...signature
}
})
}
break
}
}
return config
}
function mapAnthropicThinkToAISdkProviderOptions(
provider: Provider,
config: MessageCreateParams['thinking']
): ProviderOptions | undefined {
if (!config) return undefined
if (isAnthropicProvider(provider)) {
return {
anthropic: {
...mapToAnthropicProviderOptions(config)
}
}
}
if (isGeminiProvider(provider)) {
return {
google: {
...mapToGeminiProviderOptions(config)
}
}
}
if (isOpenAIProvider(provider)) {
return {
openai: {
...mapToOpenAIProviderOptions(config)
}
}
}
return undefined
}
function mapToAnthropicProviderOptions(config: NonNullable<MessageCreateParams['thinking']>): AnthropicProviderOptions {
return {
thinking: {
type: config.type,
budgetTokens: config.type === 'enabled' ? config.budget_tokens : undefined
}
}
}
function mapToGeminiProviderOptions(
config: NonNullable<MessageCreateParams['thinking']>
): GoogleGenerativeAIProviderOptions {
return {
thinkingConfig: {
thinkingBudget: config.type === 'enabled' ? config.budget_tokens : -1,
includeThoughts: config.type === 'enabled'
}
}
}
function mapToOpenAIProviderOptions(
config: NonNullable<MessageCreateParams['thinking']>
): OpenAIResponsesProviderOptions {
return {
reasoningEffort: config.type === 'enabled' ? 'high' : 'none'
}
}
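// For example (illustrative thinking config), { type: 'enabled', budget_tokens: 2048 } maps to:
//   anthropic: { thinking: { type: 'enabled', budgetTokens: 2048 } }
//   google:    { thinkingConfig: { thinkingBudget: 2048, includeThoughts: true } }
//   openai:    { reasoningEffort: 'high' }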
/**
* Core stream execution function - single source of truth for AI SDK calls
*/
async function executeStream(config: ExecuteStreamConfig): Promise<AiSdkToAnthropicSSE> {
const { provider, modelId, params, middlewares = [], plugins = [], onEvent } = config
// Convert provider config to AI SDK config
let sdkConfig = providerToAiSdkConfig(provider, modelId)
// Prepare special provider config (Copilot, Anthropic OAuth, etc.)
sdkConfig = await prepareSpecialProviderConfig(provider, sdkConfig)
// Create provider instance and get language model
const aiSdkProvider = await createAiSdkProvider(sdkConfig)
const baseModel = aiSdkProvider.languageModel(modelId)
// Apply middlewares if present
const model =
middlewares.length > 0 && typeof baseModel === 'object'
? (wrapLanguageModel({ model: baseModel, middleware: middlewares }) as typeof baseModel)
: baseModel
// Create executor with plugins
const executor = createExecutor(sdkConfig.providerId, sdkConfig.options, plugins)
// Convert messages and tools
const coreMessages = convertAnthropicToAiMessages(params)
const tools = convertAnthropicToolsToAiSdk(params.tools)
// Create the adapter
const adapter = new AiSdkToAnthropicSSE({
model: `${provider.id}:${modelId}`,
onEvent: onEvent || (() => {})
})
// Execute stream - pass model object instead of string
const result = await executor.streamText({
model, // Now passing LanguageModel object, not string
messages: coreMessages,
// FIXME: The max_tokens passed in by Claude Code can exceed some models' limits and needs special
// handling. This may be easier to fix in v2; maintaining it here is fairly costly for now.
// Known affected provider: Doubao
maxOutputTokens: params.max_tokens,
temperature: params.temperature,
topP: params.top_p,
topK: params.top_k,
stopSequences: params.stop_sequences,
stopWhen: stepCountIs(100),
headers: defaultAppHeaders(),
tools,
providerOptions: mapAnthropicThinkToAISdkProviderOptions(provider, params.thinking)
})
// Process the stream through the adapter
await adapter.processStream(result.fullStream)
return adapter
}
/**
* Stream a message request using AI SDK executor and convert to Anthropic SSE format
*/
export async function streamUnifiedMessages(config: UnifiedStreamConfig): Promise<void> {
const { response, provider, modelId, params, onError, onComplete, middlewares = [], plugins = [] } = config
logger.info('Starting unified message stream', {
providerId: provider.id,
providerType: provider.type,
modelId,
stream: params.stream,
middlewareCount: middlewares.length,
pluginCount: plugins.length
})
try {
response.setHeader('Content-Type', 'text/event-stream')
response.setHeader('Cache-Control', 'no-cache')
response.setHeader('Connection', 'keep-alive')
response.setHeader('X-Accel-Buffering', 'no')
await executeStream({
provider,
modelId,
params,
middlewares,
plugins,
onEvent: (event) => {
logger.silly('Streaming event', { eventType: event.type })
const sseData = formatSSEEvent(event)
response.write(sseData)
}
})
// Send done marker
response.write(formatSSEDone())
response.end()
logger.info('Unified message stream completed', { providerId: provider.id, modelId })
onComplete?.()
} catch (error) {
logger.error('Error in unified message stream', error as Error, { providerId: provider.id, modelId })
onError?.(error)
throw error
}
}
/**
* Generate a non-streaming message response
*
* Uses simulateStreamingMiddleware to reuse the same streaming logic,
* similar to renderer's ModernAiProvider pattern.
*/
export async function generateUnifiedMessage(
providerOrConfig: Provider | GenerateUnifiedMessageConfig,
modelId?: string,
params?: MessageCreateParams
): Promise<ReturnType<typeof AiSdkToAnthropicSSE.prototype.buildNonStreamingResponse>> {
// Support both old signature and new config-based signature
let config: GenerateUnifiedMessageConfig
if ('provider' in providerOrConfig && 'modelId' in providerOrConfig && 'params' in providerOrConfig) {
config = providerOrConfig
} else {
config = {
provider: providerOrConfig as Provider,
modelId: modelId!,
params: params!
}
}
const { provider, middlewares = [], plugins = [] } = config
logger.info('Starting unified message generation', {
providerId: provider.id,
providerType: provider.type,
modelId: config.modelId,
middlewareCount: middlewares.length,
pluginCount: plugins.length
})
try {
// Add simulateStreamingMiddleware to reuse streaming logic for non-streaming
const allMiddlewares = [simulateStreamingMiddleware(), ...middlewares]
const adapter = await executeStream({
provider,
modelId: config.modelId,
params: config.params,
middlewares: allMiddlewares,
plugins
})
const finalResponse = adapter.buildNonStreamingResponse()
logger.info('Unified message generation completed', {
providerId: provider.id,
modelId: config.modelId
})
return finalResponse
} catch (error) {
logger.error('Error in unified message generation', error as Error, {
providerId: provider.id,
modelId: config.modelId
})
throw error
}
}
export default {
streamUnifiedMessages,
generateUnifiedMessage
}

View File

@@ -1,7 +1,7 @@
import { CacheService } from '@main/services/CacheService'
import { loggerService } from '@main/services/LoggerService'
import { reduxService } from '@main/services/ReduxService'
import { isPpioAnthropicCompatibleModel, isSiliconAnthropicCompatibleModel } from '@shared/config/providers'
import { isSiliconAnthropicCompatibleModel } from '@shared/config/providers'
import type { ApiModel, Model, Provider } from '@types'
const logger = loggerService.withContext('ApiServerUtils')
@@ -28,9 +28,10 @@ export async function getAvailableProviders(): Promise<Provider[]> {
return []
}
// Support all provider types that AI SDK can handle
// The unified-messages service uses AI SDK which supports many providers
const supportedProviders = providers.filter((p: Provider) => p.enabled)
// Support OpenAI and Anthropic type providers for API server
const supportedProviders = providers.filter(
(p: Provider) => p.enabled && (p.type === 'openai' || p.type === 'anthropic')
)
// Cache the filtered results
CacheService.set(PROVIDERS_CACHE_KEY, supportedProviders, PROVIDERS_CACHE_TTL)
@@ -159,7 +160,7 @@ export async function validateModelId(model: string): Promise<{
valid: false,
error: {
type: 'provider_not_found',
message: `Provider '${providerId}' not found or not enabled.`,
message: `Provider '${providerId}' not found, not enabled, or not supported. Only OpenAI and Anthropic providers are currently supported.`,
code: 'provider_not_found'
}
}
@@ -261,8 +262,14 @@ export function validateProvider(provider: Provider): boolean {
return false
}
// AI SDK supports many provider types, no longer need to filter by type
// The unified-messages service handles all supported types
// Support OpenAI and Anthropic type providers
if (provider.type !== 'openai' && provider.type !== 'anthropic') {
logger.debug('Provider type not supported', {
providerId: provider.id,
providerType: provider.type
})
return false
}
return true
} catch (error: any) {
@@ -283,39 +290,8 @@ export const getProviderAnthropicModelChecker = (providerId: string): ((m: Model
return (m: Model) => m.id.includes('claude')
case 'silicon':
return (m: Model) => isSiliconAnthropicCompatibleModel(m.id)
case 'ppio':
return (m: Model) => isPpioAnthropicCompatibleModel(m.id)
default:
// allow all models when checker not configured
return () => true
}
}
/**
* Check if a specific model is compatible with Anthropic API for a given provider.
*
* This is used for fine-grained routing decisions at the model level.
* For aggregated providers (like Silicon), only certain models support the Anthropic API endpoint.
*
* @param provider - The provider to check
* @param modelId - The model ID to check (without provider prefix)
* @returns true if the model supports Anthropic API endpoint
*/
export function isModelAnthropicCompatible(provider: Provider, modelId: string): boolean {
const checker = getProviderAnthropicModelChecker(provider.id)
const model = provider.models?.find((m) => m.id === modelId)
if (model) {
return checker(model)
}
const minimalModel: Model = {
id: modelId,
name: modelId,
provider: provider.id,
group: ''
}
return checker(minimalModel)
}

View File

@@ -1,19 +1,9 @@
import type { ReasoningDetailUnion } from '@main/apiServer/adapters/openrouter'
interface CacheItem<T> {
data: T
timestamp: number
duration: number
}
/**
* Interface for reasoning cache
*/
export interface IReasoningCache<T> {
set(key: string, value: T): void
get(key: string): T | undefined
}
export class CacheService {
private static cache: Map<string, CacheItem<any>> = new Map()
@@ -82,14 +72,3 @@ export class CacheService {
return true
}
}
// Singleton cache instances using CacheService
export const googleReasoningCache: IReasoningCache<string> = {
set: (key, value) => CacheService.set(`google-reasoning:${key}`, value, 30 * 60 * 1000),
get: (key) => CacheService.get(`google-reasoning:${key}`) || undefined
}
export const openRouterReasoningCache: IReasoningCache<ReasoningDetailUnion[]> = {
set: (key, value) => CacheService.set(`openrouter-reasoning:${key}`, value, 30 * 60 * 1000),
get: (key) => CacheService.get(`openrouter-reasoning:${key}`) || undefined
}

View File

@@ -87,7 +87,6 @@ export class ClaudeStreamState {
private pendingUsage: PendingUsageState = {}
private pendingToolCalls = new Map<string, PendingToolCall>()
private stepActive = false
private _streamFinished = false
constructor(options: ClaudeStreamStateOptions) {
this.logger = loggerService.withContext('ClaudeStreamState')
@@ -290,16 +289,6 @@ export class ClaudeStreamState {
getNamespacedToolCallId(rawToolCallId: string): string {
return buildNamespacedToolCallId(this.agentSessionId, rawToolCallId)
}
/** Marks the stream as finished (either completed or errored). */
markFinished(): void {
this._streamFinished = true
}
/** Returns true if the stream has already emitted a terminal event. */
isFinished(): boolean {
return this._streamFinished
}
}
export type { PendingToolCall }

View File

@@ -85,14 +85,18 @@ class ClaudeCodeService implements AgentServiceInterface {
})
return aiStream
}
// Validate provider has required configuration
// Note: We no longer restrict to anthropic type only - the API Server's unified adapter
// handles format conversion for any provider type (OpenAI, Gemini, etc.)
if (!modelInfo.provider?.apiKey) {
logger.error('Provider API key is missing', { modelInfo })
if (
(modelInfo.provider?.type !== 'anthropic' &&
(modelInfo.provider?.anthropicApiHost === undefined || modelInfo.provider.anthropicApiHost.trim() === '')) ||
modelInfo.provider.apiKey === ''
) {
logger.error('Anthropic provider configuration is missing', {
modelInfo
})
aiStream.emit('data', {
type: 'error',
error: new Error(`Provider '${modelInfo.provider?.id}' is missing API key configuration.`)
error: new Error(`Invalid provider type '${modelInfo.provider?.type}'. Expected 'anthropic' provider type.`)
})
return aiStream
}
@@ -103,14 +107,15 @@ class ClaudeCodeService implements AgentServiceInterface {
Object.entries(loginShellEnv).filter(([key]) => !key.toLowerCase().endsWith('_proxy'))
) as Record<string, string>
// Route through local API Server which handles format conversion via unified adapter
// This enables Claude Code Agent to work with any provider (OpenAI, Gemini, etc.)
// The API Server converts AI SDK responses to Anthropic SSE format transparently
const env = {
...loginShellEnvWithoutProxies,
ANTHROPIC_API_KEY: apiConfig.apiKey,
ANTHROPIC_AUTH_TOKEN: apiConfig.apiKey,
ANTHROPIC_BASE_URL: `http://${apiConfig.host}:${apiConfig.port}/${modelInfo.provider.id}`,
// TODO: fix the proxy api server
// ANTHROPIC_API_KEY: apiConfig.apiKey,
// ANTHROPIC_AUTH_TOKEN: apiConfig.apiKey,
// ANTHROPIC_BASE_URL: `http://${apiConfig.host}:${apiConfig.port}/${modelInfo.provider.id}`,
ANTHROPIC_API_KEY: modelInfo.provider.apiKey,
ANTHROPIC_AUTH_TOKEN: modelInfo.provider.apiKey,
ANTHROPIC_BASE_URL: modelInfo.provider.anthropicApiHost?.trim() || modelInfo.provider.apiHost,
ANTHROPIC_MODEL: modelInfo.modelId,
ANTHROPIC_DEFAULT_OPUS_MODEL: modelInfo.modelId,
ANTHROPIC_DEFAULT_SONNET_MODEL: modelInfo.modelId,
@@ -534,19 +539,6 @@ class ClaudeCodeService implements AgentServiceInterface {
return
}
// Skip emitting error if stream already finished (error was handled via result message)
if (streamState.isFinished()) {
logger.debug('SDK process exited after stream finished, skipping duplicate error event', {
duration,
error: errorObj instanceof Error ? { name: errorObj.name, message: errorObj.message } : String(errorObj)
})
// Still emit complete to signal stream end
stream.emit('data', {
type: 'complete'
})
return
}
errorChunks.push(errorObj instanceof Error ? errorObj.message : String(errorObj))
const errorMessage = errorChunks.join('\n\n')
logger.error('SDK query failed', {

View File

@@ -121,7 +121,7 @@ export function transformSDKMessageToStreamParts(sdkMessage: SDKMessage, state:
case 'system':
return handleSystemMessage(sdkMessage)
case 'result':
return handleResultMessage(sdkMessage, state)
return handleResultMessage(sdkMessage)
default:
logger.warn('Unknown SDKMessage type', { type: (sdkMessage as any).type })
return []
@@ -193,30 +193,6 @@ function handleAssistantMessage(
}
break
}
case 'thinking':
case 'redacted_thinking': {
const thinkingText = block.type === 'thinking' ? block.thinking : block.data
if (thinkingText) {
const id = generateMessageId()
chunks.push({
type: 'reasoning-start',
id,
providerMetadata
})
chunks.push({
type: 'reasoning-delta',
id,
text: thinkingText,
providerMetadata
})
chunks.push({
type: 'reasoning-end',
id,
providerMetadata
})
}
break
}
case 'tool_use':
handleAssistantToolUse(block as ToolUseContent, providerMetadata, state, chunks)
break
@@ -469,11 +445,7 @@ function handleStreamEvent(
case 'content_block_stop': {
const block = state.closeBlock(event.index)
if (!block) {
// Some providers (e.g., Gemini) send content via assistant message before stream events,
// so the block may not exist in state. This is expected behavior, not an error.
logger.debug('Received content_block_stop for unknown index (may be from non-streaming content)', {
index: event.index
})
logger.warn('Received content_block_stop for unknown index', { index: event.index })
break
}
@@ -707,13 +679,7 @@ function handleSystemMessage(message: Extract<SDKMessage, { type: 'system' }>):
* Successful runs yield a `finish` frame with aggregated usage metrics, while
* failures are surfaced as `error` frames.
*/
function handleResultMessage(
message: Extract<SDKMessage, { type: 'result' }>,
state: ClaudeStreamState
): AgentStreamPart[] {
// Mark stream as finished to prevent duplicate error events when SDK process exits
state.markFinished()
function handleResultMessage(message: Extract<SDKMessage, { type: 'result' }>): AgentStreamPart[] {
const chunks: AgentStreamPart[] = []
let usage: LanguageModelUsage | undefined
@@ -725,33 +691,26 @@ function handleResultMessage(
}
}
chunks.push({
type: 'finish',
totalUsage: usage ?? emptyUsage,
finishReason: mapClaudeCodeFinishReason(message.subtype),
providerMetadata: {
...sdkMessageToProviderMetadata(message),
usage: message.usage,
durationMs: message.duration_ms,
costUsd: message.total_cost_usd,
raw: message
}
} as AgentStreamPart)
if (message.subtype !== 'success') {
if (message.subtype === 'success') {
chunks.push({
type: 'finish',
totalUsage: usage ?? emptyUsage,
finishReason: mapClaudeCodeFinishReason(message.subtype),
providerMetadata: {
...sdkMessageToProviderMetadata(message),
usage: message.usage,
durationMs: message.duration_ms,
costUsd: message.total_cost_usd,
raw: message
}
} as AgentStreamPart)
} else {
chunks.push({
type: 'error',
error: {
message: `${message.subtype}: Process failed after ${message.num_turns} turns`
}
} as AgentStreamPart)
} else {
if (message.is_error) {
const errorMatch = message.result.match(/\{.*\}/)
if (errorMatch) {
const errorDetail = JSON.parse(errorMatch[0])
chunks.push(errorDetail)
}
}
}
return chunks
}

View File

@@ -27,6 +27,7 @@ import { buildAiSdkMiddlewares } from './middleware/AiSdkMiddlewareBuilder'
import { buildPlugins } from './plugins/PluginBuilder'
import { createAiSdkProvider } from './provider/factory'
import {
adaptProvider,
getActualProvider,
isModernSdkSupported,
prepareSpecialProviderConfig,
@@ -64,12 +65,11 @@ export default class ModernAiProvider {
* - URL will be automatically formatted via `formatProviderApiHost`, adding version suffixes like `/v1`
*
* 2. When called with `(model, provider)`:
* - **Directly uses the provided provider WITHOUT going through `getActualProvider`**
* - **URL will NOT be automatically formatted, `/v1` suffix will NOT be added**
* - This is legacy behavior kept for backward compatibility
* - The provided provider will be adapted via `adaptProvider`
* - URL formatting behavior depends on the adapted result
*
* 3. When called with `(provider)`:
* - Directly uses the provider without requiring a model
* - The provider will be adapted via `adaptProvider`
* - Used for operations that don't need a model (e.g., fetchModels)
*
* @example
@@ -77,7 +77,7 @@ export default class ModernAiProvider {
* // Recommended: Auto-format URL
* const ai = new ModernAiProvider(model)
*
* // Not recommended: Skip URL formatting (only for special cases)
* // Provider will be adapted
* const ai = new ModernAiProvider(model, customProvider)
*
* // For operations that don't need a model
@@ -91,12 +91,12 @@ export default class ModernAiProvider {
if (this.isModel(modelOrProvider)) {
// A Model was passed in
this.model = modelOrProvider
this.actualProvider = provider || getActualProvider(modelOrProvider)
this.actualProvider = provider ? adaptProvider({ provider }) : getActualProvider(modelOrProvider)
// Only store the config; do not create the executor up front
this.config = providerToAiSdkConfig(this.actualProvider, modelOrProvider)
} else {
// A Provider was passed in
this.actualProvider = modelOrProvider
this.actualProvider = adaptProvider({ provider: modelOrProvider })
// model is optional; some operations (e.g. fetchModels) do not need one
}
@@ -120,9 +120,12 @@ export default class ModernAiProvider {
throw new Error('Model is required for completions. Please use constructor with model parameter.')
}
// Regenerate the config on every request so that API key rotation takes effect
this.config = providerToAiSdkConfig(this.actualProvider, this.model)
logger.debug('Generated provider config for completions', this.config)
// Config is now set in the constructor; ApiService handles key rotation before passing the provider in
if (!this.config) {
// If the config wasn't set in the constructor (provider-only case), generate it now
this.config = providerToAiSdkConfig(this.actualProvider, this.model!)
}
logger.debug('Using provider config for completions', this.config)
// 检查 config 是否存在
if (!this.config) {

View File

@@ -24,7 +24,7 @@ export class VertexAPIClient extends GeminiAPIClient {
this.anthropicVertexClient = new AnthropicVertexClient(provider)
// If a plain Provider was passed in, convert it to a VertexProvider
if (isVertexProvider(provider)) {
this.vertexProvider = provider as VertexProvider
this.vertexProvider = provider
} else {
this.vertexProvider = createVertexProvider(provider)
}

View File

@@ -5,7 +5,6 @@ import type { MCPTool } from '@renderer/types'
import { type Assistant, type Message, type Model, type Provider, SystemProviderIds } from '@renderer/types'
import type { Chunk } from '@renderer/types/chunk'
import { isSupportEnableThinkingProvider } from '@renderer/utils/provider'
import { openrouterReasoningMiddleware, skipGeminiThoughtSignatureMiddleware } from '@shared/middleware'
import type { LanguageModelMiddleware } from 'ai'
import { extractReasoningMiddleware, simulateStreamingMiddleware } from 'ai'
import { isEmpty } from 'lodash'
@@ -14,7 +13,9 @@ import { getAiSdkProviderId } from '../provider/factory'
import { isOpenRouterGeminiGenerateImageModel } from '../utils/image'
import { noThinkMiddleware } from './noThinkMiddleware'
import { openrouterGenerateImageMiddleware } from './openrouterGenerateImageMiddleware'
import { openrouterReasoningMiddleware } from './openrouterReasoningMiddleware'
import { qwenThinkingMiddleware } from './qwenThinkingMiddleware'
import { skipGeminiThoughtSignatureMiddleware } from './skipGeminiThoughtSignatureMiddleware'
import { toolChoiceMiddleware } from './toolChoiceMiddleware'
const logger = loggerService.withContext('AiSdkMiddlewareBuilder')

View File

@@ -0,0 +1,50 @@
import type { LanguageModelV2StreamPart } from '@ai-sdk/provider'
import type { LanguageModelMiddleware } from 'ai'
/**
* https://openrouter.ai/docs/docs/best-practices/reasoning-tokens#example-preserving-reasoning-blocks-with-openrouter-and-claude
*
* @returns LanguageModelMiddleware - a middleware filter redacted block
*/
export function openrouterReasoningMiddleware(): LanguageModelMiddleware {
const REDACTED_BLOCK = '[REDACTED]'
return {
middlewareVersion: 'v2',
wrapGenerate: async ({ doGenerate }) => {
const { content, ...rest } = await doGenerate()
const modifiedContent = content.map((part) => {
if (part.type === 'reasoning' && part.text.includes(REDACTED_BLOCK)) {
return {
...part,
text: part.text.replace(REDACTED_BLOCK, '')
}
}
return part
})
return { content: modifiedContent, ...rest }
},
wrapStream: async ({ doStream }) => {
const { stream, ...rest } = await doStream()
return {
stream: stream.pipeThrough(
new TransformStream<LanguageModelV2StreamPart, LanguageModelV2StreamPart>({
transform(
chunk: LanguageModelV2StreamPart,
controller: TransformStreamDefaultController<LanguageModelV2StreamPart>
) {
if (chunk.type === 'reasoning-delta' && chunk.delta.includes(REDACTED_BLOCK)) {
controller.enqueue({
...chunk,
delta: chunk.delta.replace(REDACTED_BLOCK, '')
})
} else {
controller.enqueue(chunk)
}
}
})
),
...rest
}
}
}
}
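
A minimal sketch of attaching this middleware, assuming the AI SDK's wrapLanguageModel helper and a previously created OpenRouter model instance named openrouterModel:

import { wrapLanguageModel } from 'ai'

// Wrap the base model so '[REDACTED]' markers are stripped from reasoning
// content in both the generate and stream paths.
const model = wrapLanguageModel({
  model: openrouterModel, // assumed: a language model created elsewhere
  middleware: openrouterReasoningMiddleware()
})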

View File

@@ -0,0 +1,36 @@
import type { LanguageModelMiddleware } from 'ai'
/**
* Skip Gemini thought-signature middleware
* Due to the complexity of multi-model client requests (which can switch to other models mid-stream),
* it was decided to skip all Gemini 3 thought signatures via middleware.
* @param aiSdkId AI SDK Provider ID
* @returns LanguageModelMiddleware
*/
export function skipGeminiThoughtSignatureMiddleware(aiSdkId: string): LanguageModelMiddleware {
const MAGIC_STRING = 'skip_thought_signature_validator'
return {
middlewareVersion: 'v2',
transformParams: async ({ params }) => {
const transformedParams = { ...params }
// Process messages in prompt
if (transformedParams.prompt && Array.isArray(transformedParams.prompt)) {
transformedParams.prompt = transformedParams.prompt.map((message) => {
if (typeof message.content !== 'string') {
for (const part of message.content) {
const googleOptions = part?.providerOptions?.[aiSdkId]
if (googleOptions?.thoughtSignature) {
googleOptions.thoughtSignature = MAGIC_STRING
}
}
}
return message
})
}
return transformedParams
}
}
}
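
Attachment works the same way as the OpenRouter middleware above; the aiSdkId argument selects which providerOptions namespace gets its thoughtSignature overwritten. A sketch, where 'google' is an assumed id rather than a value from this diff:

const geminiModel = wrapLanguageModel({
  model: baseGeminiModel, // assumed: a language model created elsewhere
  middleware: skipGeminiThoughtSignatureMiddleware('google')
})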

View File

@@ -24,17 +24,7 @@ vi.mock('@renderer/services/AssistantService', () => ({
vi.mock('@renderer/store', () => ({
default: {
getState: () => ({
copilot: { defaultHeaders: {} },
llm: {
settings: {
vertexai: {
projectId: 'test-project',
location: 'us-central1'
}
}
}
})
getState: () => ({ copilot: { defaultHeaders: {} } })
}
}))
@@ -43,7 +33,7 @@ vi.mock('@renderer/utils/api', () => ({
if (isSupportedAPIVersion === false) {
return host // Return host as-is when isSupportedAPIVersion is false
}
return host ? `${host}/v1` : '' // Default behavior when isSupportedAPIVersion is true
return `${host}/v1` // Default behavior when isSupportedAPIVersion is true
}),
routeToEndpoint: vi.fn((host) => ({
baseURL: host,
@@ -51,20 +41,6 @@ vi.mock('@renderer/utils/api', () => ({
}))
}))
// Also mock @shared/api since formatProviderApiHost uses it directly
vi.mock('@shared/api', async (importOriginal) => {
const actual = (await importOriginal()) as any
return {
...actual,
formatApiHost: vi.fn((host, isSupportedAPIVersion = true) => {
if (isSupportedAPIVersion === false) {
return host || '' // Return host as-is when isSupportedAPIVersion is false
}
return host ? `${host}/v1` : '' // Default behavior when isSupportedAPIVersion is true
})
}
})
vi.mock('@renderer/utils/provider', async (importOriginal) => {
const actual = (await importOriginal()) as any
return {
@@ -97,8 +73,8 @@ vi.mock('@renderer/services/AssistantService', () => ({
import { getProviderByModel } from '@renderer/services/AssistantService'
import type { Model, Provider } from '@renderer/types'
import { formatApiHost } from '@renderer/utils/api'
import { isCherryAIProvider, isPerplexityProvider } from '@renderer/utils/provider'
import { formatApiHost } from '@shared/api'
import { COPILOT_DEFAULT_HEADERS, COPILOT_EDITOR_VERSION, isCopilotResponsesModel } from '../constants'
import { getActualProvider, providerToAiSdkConfig } from '../providerConfig'

View File

@@ -1,13 +1,13 @@
/**
* AiHubMix rule set
*/
import { getLowerBaseModelName } from '@shared/utils/naming'
import { isOpenAILLMModel } from '@renderer/config/models'
import type { Provider } from '@renderer/types'
import type { MinimalModel, MinimalProvider } from '../types'
import { provider2Provider, startsWith } from './helper'
import type { RuleSet } from './types'
const extraProviderConfig = <P extends MinimalProvider>(provider: P) => {
const extraProviderConfig = (provider: Provider) => {
return {
...provider,
extra_headers: {
@@ -17,23 +17,11 @@ const extraProviderConfig = <P extends MinimalProvider>(provider: P) => {
}
}
function isOpenAILLMModel<M extends MinimalModel>(model: M): boolean {
const modelId = getLowerBaseModelName(model.id)
const reasonings = ['o1', 'o3', 'o4', 'gpt-oss']
if (reasonings.some((r) => modelId.includes(r))) {
return true
}
if (modelId.includes('gpt')) {
return true
}
return false
}
const AIHUBMIX_RULES: RuleSet = {
rules: [
{
match: startsWith('claude'),
provider: (provider) => {
provider: (provider: Provider) => {
return extraProviderConfig({
...provider,
type: 'anthropic'
@@ -46,7 +34,7 @@ const AIHUBMIX_RULES: RuleSet = {
!model.id.endsWith('-nothink') &&
!model.id.endsWith('-search') &&
!model.id.includes('embedding'),
provider: (provider) => {
provider: (provider: Provider) => {
return extraProviderConfig({
...provider,
type: 'gemini',
@@ -56,7 +44,7 @@ const AIHUBMIX_RULES: RuleSet = {
},
{
match: isOpenAILLMModel,
provider: (provider) => {
provider: (provider: Provider) => {
return extraProviderConfig({
...provider,
type: 'openai-response'
@@ -64,8 +52,7 @@ const AIHUBMIX_RULES: RuleSet = {
}
}
],
fallbackRule: (provider) => extraProviderConfig(provider)
fallbackRule: (provider: Provider) => extraProviderConfig(provider)
}
export const aihubmixProviderCreator = <P extends MinimalProvider>(model: MinimalModel, provider: P): P =>
provider2Provider<MinimalModel, MinimalProvider, P>(AIHUBMIX_RULES, model, provider)
export const aihubmixProviderCreator = provider2Provider.bind(null, AIHUBMIX_RULES)

View File

@@ -0,0 +1,22 @@
import type { Provider } from '@renderer/types'
import { provider2Provider, startsWith } from './helper'
import type { RuleSet } from './types'
// https://platform.claude.com/docs/en/build-with-claude/claude-in-microsoft-foundry
const AZURE_ANTHROPIC_RULES: RuleSet = {
rules: [
{
match: startsWith('claude'),
provider: (provider: Provider) => ({
...provider,
type: 'anthropic',
apiHost: provider.apiHost + 'anthropic/v1',
id: 'azure-anthropic'
})
}
],
fallbackRule: (provider: Provider) => provider
}
export const azureAnthropicProviderCreator = provider2Provider.bind(null, AZURE_ANTHROPIC_RULES)

View File

@@ -0,0 +1,22 @@
import type { Model, Provider } from '@renderer/types'
import type { RuleSet } from './types'
export const startsWith = (prefix: string) => (model: Model) => model.id.toLowerCase().startsWith(prefix.toLowerCase())
export const endpointIs = (type: string) => (model: Model) => model.endpoint_type === type
/**
* Resolve the Provider that a given model should use
* @param ruleSet The rule set to evaluate
* @param model The model being matched
* @param provider The original provider
* @returns The resolved provider
*/
export function provider2Provider(ruleSet: RuleSet, model: Model, provider: Provider): Provider {
for (const rule of ruleSet.rules) {
if (rule.match(model)) {
return rule.provider(provider)
}
}
return ruleSet.fallbackRule(provider)
}
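
Combining the helpers, a rule set is an ordered matcher list: the first rule whose match passes rewrites the provider, and fallbackRule applies otherwise. A small illustrative rule set (the matching values are made up, assuming the same imports as the surrounding rule-set modules):

const EXAMPLE_RULES: RuleSet = {
  rules: [
    // Claude models get re-typed to the Anthropic client.
    { match: startsWith('claude'), provider: (provider: Provider) => ({ ...provider, type: 'anthropic' }) },
    // Models tagged with a gemini endpoint get the Gemini client.
    { match: endpointIs('gemini'), provider: (provider: Provider) => ({ ...provider, type: 'gemini' }) }
  ],
  fallbackRule: (provider: Provider) => provider
}
export const exampleProviderCreator = provider2Provider.bind(null, EXAMPLE_RULES)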

View File

@@ -1,7 +1,3 @@
// Re-export from shared config
export {
aihubmixProviderCreator,
azureAnthropicProviderCreator,
newApiResolverCreator,
vertexAnthropicProviderCreator
} from '@shared/provider/config'
export { aihubmixProviderCreator } from './aihubmix'
export { newApiResolverCreator } from './newApi'
export { vertexAnthropicProviderCreator } from './vertext-anthropic'

View File

@@ -1,7 +1,8 @@
/**
* NewAPI rule set
*/
import type { MinimalModel, MinimalProvider, ProviderType } from '../types'
import type { Provider } from '@renderer/types'
import { endpointIs, provider2Provider } from './helper'
import type { RuleSet } from './types'
@@ -9,43 +10,42 @@ const NEWAPI_RULES: RuleSet = {
rules: [
{
match: endpointIs('anthropic'),
provider: (provider) => {
provider: (provider: Provider) => {
return {
...provider,
type: 'anthropic' as ProviderType
type: 'anthropic'
}
}
},
{
match: endpointIs('gemini'),
provider: (provider) => {
provider: (provider: Provider) => {
return {
...provider,
type: 'gemini' as ProviderType
type: 'gemini'
}
}
},
{
match: endpointIs('openai-response'),
provider: (provider) => {
provider: (provider: Provider) => {
return {
...provider,
type: 'openai-response' as ProviderType
type: 'openai-response'
}
}
},
{
match: (model) => endpointIs('openai')(model) || endpointIs('image-generation')(model),
provider: (provider) => {
provider: (provider: Provider) => {
return {
...provider,
type: 'openai' as ProviderType
type: 'openai'
}
}
}
],
fallbackRule: (provider) => provider
fallbackRule: (provider: Provider) => provider
}
export const newApiResolverCreator = <P extends MinimalProvider>(model: MinimalModel, provider: P): P =>
provider2Provider<MinimalModel, MinimalProvider, P>(NEWAPI_RULES, model, provider)
export const newApiResolverCreator = provider2Provider.bind(null, NEWAPI_RULES)

View File

@@ -0,0 +1,9 @@
import type { Model, Provider } from '@renderer/types'
export interface RuleSet {
rules: Array<{
match: (model: Model) => boolean
provider: (provider: Provider) => Provider
}>
fallbackRule: (provider: Provider) => Provider
}

View File

@@ -0,0 +1,19 @@
import type { Provider } from '@renderer/types'
import { provider2Provider, startsWith } from './helper'
import type { RuleSet } from './types'
const VERTEX_ANTHROPIC_RULES: RuleSet = {
rules: [
{
match: startsWith('claude'),
provider: (provider: Provider) => ({
...provider,
id: 'google-vertex-anthropic'
})
}
],
fallbackRule: (provider: Provider) => provider
}
export const vertexAnthropicProviderCreator = provider2Provider.bind(null, VERTEX_ANTHROPIC_RULES)

View File

@@ -1 +1,25 @@
export { COPILOT_DEFAULT_HEADERS, COPILOT_EDITOR_VERSION, isCopilotResponsesModel } from '@shared/provider/constant'
import type { Model } from '@renderer/types'
export const COPILOT_EDITOR_VERSION = 'vscode/1.104.1'
export const COPILOT_PLUGIN_VERSION = 'copilot-chat/0.26.7'
export const COPILOT_INTEGRATION_ID = 'vscode-chat'
export const COPILOT_USER_AGENT = 'GitHubCopilotChat/0.26.7'
export const COPILOT_DEFAULT_HEADERS = {
'Copilot-Integration-Id': COPILOT_INTEGRATION_ID,
'User-Agent': COPILOT_USER_AGENT,
'Editor-Version': COPILOT_EDITOR_VERSION,
'Editor-Plugin-Version': COPILOT_PLUGIN_VERSION,
'editor-version': COPILOT_EDITOR_VERSION,
'editor-plugin-version': COPILOT_PLUGIN_VERSION,
'copilot-vision-request': 'true'
} as const
// Models that require the OpenAI Responses endpoint when routed through GitHub Copilot (#10560)
const COPILOT_RESPONSES_MODEL_IDS = ['gpt-5-codex']
export function isCopilotResponsesModel(model: Model): boolean {
const normalizedId = model.id?.trim().toLowerCase()
const normalizedName = model.name?.trim().toLowerCase()
return COPILOT_RESPONSES_MODEL_IDS.some((target) => normalizedId === target || normalizedName === target)
}
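
As a quick usage sketch (the model literal is illustrative), the check matches on either the trimmed, lower-cased id or name:

const model = { id: 'GPT-5-Codex', name: 'GPT-5 Codex' } as Model
const mode = isCopilotResponsesModel(model) ? 'responses' : 'chat' // -> 'responses'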

View File

@@ -1,7 +1,8 @@
import { hasProviderConfigByAlias, type ProviderId, resolveProviderConfigId } from '@cherrystudio/ai-core/provider'
import { createProvider as createProviderCore } from '@cherrystudio/ai-core/provider'
import { loggerService } from '@logger'
import type { Provider } from '@renderer/types'
import { getAiSdkProviderId as sharedGetAiSdkProviderId } from '@shared/provider'
import { isAzureOpenAIProvider, isAzureResponsesEndpoint } from '@renderer/utils/provider'
import type { Provider as AiSdkProvider } from 'ai'
import type { AiSdkConfig } from '../types'
@@ -21,12 +22,68 @@ const logger = loggerService.withContext('ProviderFactory')
}
})()
/**
* Static provider mapping table
* Maps Cherry Studio-specific provider IDs to the standard AI SDK IDs
*/
const STATIC_PROVIDER_MAPPING: Record<string, ProviderId> = {
gemini: 'google', // Google Gemini -> google
'azure-openai': 'azure', // Azure OpenAI -> azure
'openai-response': 'openai', // OpenAI Responses -> openai
grok: 'xai', // Grok -> xai
copilot: 'github-copilot-openai-compatible'
}
/**
* Try to resolve a provider identifier, supporting static mappings and aliases
*/
function tryResolveProviderId(identifier: string): ProviderId | null {
// 1. Check the static mapping
const staticMapping = STATIC_PROVIDER_MAPPING[identifier]
if (staticMapping) {
return staticMapping
}
// 2. Check whether AiCore supports it (including alias support)
if (hasProviderConfigByAlias(identifier)) {
// Resolve to the real Provider ID
return resolveProviderConfigId(identifier) as ProviderId
}
return null
}
/**
* Get the AI SDK Provider ID
* Uses shared implementation with renderer-specific config checker
* Simplified version: reduces duplicated logic by using the shared resolution helper
*/
export function getAiSdkProviderId(provider: Provider): string {
return sharedGetAiSdkProviderId(provider)
// 1. Try to resolve provider.id
const resolvedFromId = tryResolveProviderId(provider.id)
if (isAzureOpenAIProvider(provider)) {
if (isAzureResponsesEndpoint(provider)) {
return 'azure-responses'
} else {
return 'azure'
}
}
if (resolvedFromId) {
return resolvedFromId
}
// 2. Try to resolve provider.type
// This resolves every custom provider whose type is 'openai' to the AI SDK openaiProvider
if (provider.type !== 'openai') {
const resolvedFromType = tryResolveProviderId(provider.type)
if (resolvedFromType) {
return resolvedFromType
}
}
if (provider.apiHost.includes('api.openai.com')) {
return 'openai-chat'
}
// 3. Final fallback: use the provider's own id
return provider.id
}
export async function createAiSdkProvider(config: AiSdkConfig): Promise<AiSdkProvider | null> {

View File

@@ -1,4 +1,4 @@
import { hasProviderConfig } from '@cherrystudio/ai-core/provider'
import { formatPrivateKey, hasProviderConfig, ProviderConfigFactory } from '@cherrystudio/ai-core/provider'
import { isOpenAIChatCompletionOnlyModel } from '@renderer/config/models'
import {
getAwsBedrockAccessKeyId,
@@ -10,118 +10,237 @@ import {
import { createVertexProvider, isVertexAIConfigured } from '@renderer/hooks/useVertexAI'
import { getProviderByModel } from '@renderer/services/AssistantService'
import store from '@renderer/store'
import { isSystemProvider, type Model, type Provider } from '@renderer/types'
import { isSystemProvider, type Model, type Provider, SystemProviderIds } from '@renderer/types'
import { formatApiHost, formatAzureOpenAIApiHost, formatVertexApiHost, routeToEndpoint } from '@renderer/utils/api'
import {
type AiSdkConfigContext,
formatProviderApiHost as sharedFormatProviderApiHost,
type ProviderFormatContext,
providerToAiSdkConfig as sharedProviderToAiSdkConfig,
resolveActualProvider
} from '@shared/provider'
isAnthropicProvider,
isAzureOpenAIProvider,
isCherryAIProvider,
isGeminiProvider,
isNewApiProvider,
isPerplexityProvider,
isVertexProvider
} from '@renderer/utils/provider'
import { cloneDeep } from 'lodash'
import type { AiSdkConfig } from '../types'
import { aihubmixProviderCreator, newApiResolverCreator, vertexAnthropicProviderCreator } from './config'
import { azureAnthropicProviderCreator } from './config/azure-anthropic'
import { COPILOT_DEFAULT_HEADERS } from './constants'
import { getAiSdkProviderId } from './factory'
/**
* Get the next API key in rotation
* Reuses the multi-key round-robin logic from the legacy architecture
* Handles transformation logic for special providers
*/
function getRotatedApiKey(provider: Provider): string {
const keys = provider.apiKey.split(',').map((key) => key.trim())
const keyName = `provider:${provider.id}:last_used_key`
if (keys.length === 1) {
return keys[0]
function handleSpecialProviders(model: Model, provider: Provider): Provider {
if (isNewApiProvider(provider)) {
return newApiResolverCreator(model, provider)
}
const lastUsedKey = window.keyv.get(keyName)
if (!lastUsedKey) {
window.keyv.set(keyName, keys[0])
return keys[0]
}
const currentIndex = keys.indexOf(lastUsedKey)
const nextIndex = (currentIndex + 1) % keys.length
const nextKey = keys[nextIndex]
window.keyv.set(keyName, nextKey)
return nextKey
}
/**
* Renderer-specific context for providerToAiSdkConfig
* Provides implementations using browser APIs, store, and hooks
*/
function createRendererSdkContext(model: Model): AiSdkConfigContext {
return {
getRotatedApiKey: (provider) => getRotatedApiKey(provider as Provider),
isOpenAIChatCompletionOnlyModel: () => isOpenAIChatCompletionOnlyModel(model),
getCopilotDefaultHeaders: () => COPILOT_DEFAULT_HEADERS,
getCopilotStoredHeaders: () => store.getState().copilot.defaultHeaders ?? {},
getAwsBedrockConfig: () => {
const authType = getAwsBedrockAuthType()
return {
authType,
region: getAwsBedrockRegion(),
apiKey: authType === 'apiKey' ? getAwsBedrockApiKey() : undefined,
accessKeyId: authType === 'iam' ? getAwsBedrockAccessKeyId() : undefined,
secretAccessKey: authType === 'iam' ? getAwsBedrockSecretAccessKey() : undefined
}
},
getVertexConfig: (provider) => {
if (!isVertexAIConfigured()) {
return undefined
}
return createVertexProvider(provider as Provider)
},
getEndpointType: () => model.endpoint_type
}
}
/**
* Mainly used to align the baseURL format with the AI SDK
* Uses shared implementation with renderer-specific context
*/
function getRendererFormatContext(): ProviderFormatContext {
const vertexSettings = store.getState().llm.settings.vertexai
return {
vertex: {
project: vertexSettings.projectId || 'default-project',
location: vertexSettings.location || 'us-central1'
if (isSystemProvider(provider)) {
if (provider.id === 'aihubmix') {
return aihubmixProviderCreator(model, provider)
}
if (provider.id === 'vertexai') {
return vertexAnthropicProviderCreator(model, provider)
}
}
}
function formatProviderApiHost(provider: Provider): Provider {
return sharedFormatProviderApiHost(provider, getRendererFormatContext())
if (isAzureOpenAIProvider(provider)) {
return azureAnthropicProviderCreator(model, provider)
}
return provider
}
/**
* Get the actual Provider configuration
* Simplified version: the logic is split into small functions
* Format and normalize the API host URL for a provider.
* Handles provider-specific URL formatting rules (e.g., appending version paths, Azure formatting).
*
* @param provider - The provider whose API host is to be formatted.
* @returns A new provider instance with the formatted API host.
*/
export function formatProviderApiHost(provider: Provider): Provider {
const formatted = { ...provider }
if (formatted.anthropicApiHost) {
formatted.anthropicApiHost = formatApiHost(formatted.anthropicApiHost)
}
if (isAnthropicProvider(provider)) {
const baseHost = formatted.anthropicApiHost || formatted.apiHost
// AI SDK needs /v1 in baseURL, Anthropic SDK will strip it in getSdkClient
formatted.apiHost = formatApiHost(baseHost)
if (!formatted.anthropicApiHost) {
formatted.anthropicApiHost = formatted.apiHost
}
} else if (formatted.id === SystemProviderIds.copilot || formatted.id === SystemProviderIds.github) {
formatted.apiHost = formatApiHost(formatted.apiHost, false)
} else if (isGeminiProvider(formatted)) {
formatted.apiHost = formatApiHost(formatted.apiHost, true, 'v1beta')
} else if (isAzureOpenAIProvider(formatted)) {
formatted.apiHost = formatAzureOpenAIApiHost(formatted.apiHost)
} else if (isVertexProvider(formatted)) {
formatted.apiHost = formatVertexApiHost(formatted)
} else if (isCherryAIProvider(formatted)) {
formatted.apiHost = formatApiHost(formatted.apiHost, false)
} else if (isPerplexityProvider(formatted)) {
formatted.apiHost = formatApiHost(formatted.apiHost, false)
} else {
formatted.apiHost = formatApiHost(formatted.apiHost)
}
return formatted
}
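
To make the branching concrete, the expected shape of a few transformations is sketched below; exact outputs depend on formatApiHost's suffix rules, so the hosts and results here are assumptions, not values from this diff:

// Anthropic provider: '/v1' is appended for the AI SDK baseURL.
//   'https://api.anthropic.example'  -> 'https://api.anthropic.example/v1'
// Gemini provider: the beta API version is used instead.
//   'https://gemini.example'         -> 'https://gemini.example/v1beta'
// Copilot, GitHub, CherryAI, Perplexity: the host is kept without a version suffix.
const formatted = formatProviderApiHost(provider)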
/**
* Retrieve the effective Provider configuration for the given model.
* Applies all necessary transformations (special-provider handling, URL formatting, etc.).
*
* @param model - The model whose provider is to be resolved.
* @returns A new Provider instance with all adaptations applied.
*/
export function getActualProvider(model: Model): Provider {
const baseProvider = getProviderByModel(model)
// Process the various transformations in order
let actualProvider = cloneDeep(baseProvider)
actualProvider = resolveActualProvider(actualProvider, model, {
isSystemProvider
}) as Provider
actualProvider = formatProviderApiHost(actualProvider)
return adaptProvider({ provider: baseProvider, model })
}
return actualProvider
/**
* Transforms a provider configuration by applying model-specific adaptations and normalizing its API host.
* The transformations are applied in the following order:
* 1. Model-specific provider handling (e.g., New-API, system providers, Azure OpenAI)
* 2. API host formatting (provider-specific URL normalization)
*
* @param provider - The base provider configuration to transform.
* @param model - The model associated with the provider; optional but required for special-provider handling.
* @returns A new Provider instance with all transformations applied.
*/
export function adaptProvider({ provider, model }: { provider: Provider; model?: Model }): Provider {
let adaptedProvider = cloneDeep(provider)
// Apply transformations in order
if (model) {
adaptedProvider = handleSpecialProviders(model, adaptedProvider)
}
adaptedProvider = formatProviderApiHost(adaptedProvider)
return adaptedProvider
}
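
Usage is the same with or without a model; when the model is omitted, only the host-formatting step runs (a sketch):

// With a model: special-provider rules (New-API, AiHubMix, Azure, ...) may rewrite the provider first.
const adapted = adaptProvider({ provider: baseProvider, model })
// Without a model: only formatProviderApiHost is applied.
const hostOnly = adaptProvider({ provider: baseProvider })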
/**
* Convert a Provider configuration into the new AI SDK format
* Uses shared implementation with renderer-specific context
* Simplified version: leverages the new alias-mapping system
*/
export function providerToAiSdkConfig(actualProvider: Provider, model: Model): AiSdkConfig {
const context = createRendererSdkContext(model)
return sharedProviderToAiSdkConfig(actualProvider, model.id, context) as AiSdkConfig
const aiSdkProviderId = getAiSdkProviderId(actualProvider)
// Build the base config
const { baseURL, endpoint } = routeToEndpoint(actualProvider.apiHost)
const baseConfig = {
baseURL: baseURL,
apiKey: actualProvider.apiKey
}
const isCopilotProvider = actualProvider.id === SystemProviderIds.copilot
if (isCopilotProvider) {
const storedHeaders = store.getState().copilot.defaultHeaders ?? {}
const options = ProviderConfigFactory.fromProvider('github-copilot-openai-compatible', baseConfig, {
headers: {
...COPILOT_DEFAULT_HEADERS,
...storedHeaders,
...actualProvider.extra_headers
},
name: actualProvider.id,
includeUsage: true
})
return {
providerId: 'github-copilot-openai-compatible',
options
}
}
// Handle OpenAI mode
const extraOptions: any = {}
extraOptions.endpoint = endpoint
if (actualProvider.type === 'openai-response' && !isOpenAIChatCompletionOnlyModel(model)) {
extraOptions.mode = 'responses'
} else if (aiSdkProviderId === 'openai' || (aiSdkProviderId === 'cherryin' && actualProvider.type === 'openai')) {
extraOptions.mode = 'chat'
}
// Add extra headers
if (actualProvider.extra_headers) {
extraOptions.headers = actualProvider.extra_headers
// copy from openaiBaseClient/openaiResponseApiClient
if (aiSdkProviderId === 'openai') {
extraOptions.headers = {
...extraOptions.headers,
'HTTP-Referer': 'https://cherry-ai.com',
'X-Title': 'Cherry Studio',
'X-Api-Key': baseConfig.apiKey
}
}
}
// azure
// https://learn.microsoft.com/en-us/azure/ai-foundry/openai/latest
// https://learn.microsoft.com/en-us/azure/ai-foundry/openai/how-to/responses?tabs=python-key#responses-api
if (aiSdkProviderId === 'azure-responses') {
extraOptions.mode = 'responses'
} else if (aiSdkProviderId === 'azure') {
extraOptions.mode = 'chat'
}
// bedrock
if (aiSdkProviderId === 'bedrock') {
const authType = getAwsBedrockAuthType()
extraOptions.region = getAwsBedrockRegion()
if (authType === 'apiKey') {
extraOptions.apiKey = getAwsBedrockApiKey()
} else {
extraOptions.accessKeyId = getAwsBedrockAccessKeyId()
extraOptions.secretAccessKey = getAwsBedrockSecretAccessKey()
}
}
// google-vertex
if (aiSdkProviderId === 'google-vertex' || aiSdkProviderId === 'google-vertex-anthropic') {
if (!isVertexAIConfigured()) {
throw new Error('VertexAI is not configured. Please configure project, location and service account credentials.')
}
const { project, location, googleCredentials } = createVertexProvider(actualProvider)
extraOptions.project = project
extraOptions.location = location
extraOptions.googleCredentials = {
...googleCredentials,
privateKey: formatPrivateKey(googleCredentials.privateKey)
}
baseConfig.baseURL += aiSdkProviderId === 'google-vertex' ? '/publishers/google' : '/publishers/anthropic/models'
}
// cherryin
if (aiSdkProviderId === 'cherryin') {
if (model.endpoint_type) {
extraOptions.endpointType = model.endpoint_type
}
}
if (hasProviderConfig(aiSdkProviderId) && aiSdkProviderId !== 'openai-compatible') {
const options = ProviderConfigFactory.fromProvider(aiSdkProviderId, baseConfig, extraOptions)
return {
providerId: aiSdkProviderId,
options
}
}
// Otherwise fall back to openai-compatible
const options = ProviderConfigFactory.createOpenAICompatible(baseConfig.baseURL, baseConfig.apiKey)
return {
providerId: 'openai-compatible',
options: {
...options,
name: actualProvider.id,
...extraOptions,
includeUsage: true
}
}
}
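
The returned AiSdkConfig pairs a provider id with creator options; for an unrecognized OpenAI-compatible provider the fallback branch produces roughly the following (values are illustrative):

const config = providerToAiSdkConfig(actualProvider, model)
// e.g. {
//   providerId: 'openai-compatible',
//   options: { baseURL: 'https://my-proxy.example/v1', apiKey: 'sk-...',
//              name: 'my-provider', includeUsage: true, ... }
// }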
/**
@@ -164,13 +283,13 @@ export async function prepareSpecialProviderConfig(
break
}
case 'cherryai': {
config.options.fetch = async (url: RequestInfo | URL, options: RequestInit) => {
config.options.fetch = async (url, options) => {
// Sign the final request parameters here
const signature = await window.api.cherryai.generateSignature({
method: 'POST',
path: '/chat/completions',
query: '',
body: JSON.parse(options.body as string)
body: JSON.parse(options.body)
})
return fetch(url, {
...options,

View File

@@ -1,13 +1,113 @@
import { type ProviderConfig, registerMultipleProviderConfigs } from '@cherrystudio/ai-core/provider'
import { loggerService } from '@logger'
import { initializeSharedProviders, SHARED_PROVIDER_CONFIGS } from '@shared/provider'
const logger = loggerService.withContext('ProviderConfigs')
export const NEW_PROVIDER_CONFIGS = SHARED_PROVIDER_CONFIGS
/**
* New provider configuration definitions
* Defines the AI providers that need to be registered dynamically
*/
export const NEW_PROVIDER_CONFIGS: ProviderConfig[] = [
{
id: 'openrouter',
name: 'OpenRouter',
import: () => import('@openrouter/ai-sdk-provider'),
creatorFunctionName: 'createOpenRouter',
supportsImageGeneration: true,
aliases: ['openrouter']
},
{
id: 'google-vertex',
name: 'Google Vertex AI',
import: () => import('@ai-sdk/google-vertex/edge'),
creatorFunctionName: 'createVertex',
supportsImageGeneration: true,
aliases: ['vertexai']
},
{
id: 'google-vertex-anthropic',
name: 'Google Vertex AI Anthropic',
import: () => import('@ai-sdk/google-vertex/anthropic/edge'),
creatorFunctionName: 'createVertexAnthropic',
supportsImageGeneration: true,
aliases: ['vertexai-anthropic']
},
{
id: 'azure-anthropic',
name: 'Azure AI Anthropic',
import: () => import('@ai-sdk/anthropic'),
creatorFunctionName: 'createAnthropic',
supportsImageGeneration: false,
aliases: ['azure-anthropic']
},
{
id: 'github-copilot-openai-compatible',
name: 'GitHub Copilot OpenAI Compatible',
import: () => import('@opeoginni/github-copilot-openai-compatible'),
creatorFunctionName: 'createGitHubCopilotOpenAICompatible',
supportsImageGeneration: false,
aliases: ['copilot', 'github-copilot']
},
{
id: 'bedrock',
name: 'Amazon Bedrock',
import: () => import('@ai-sdk/amazon-bedrock'),
creatorFunctionName: 'createAmazonBedrock',
supportsImageGeneration: true,
aliases: ['aws-bedrock']
},
{
id: 'perplexity',
name: 'Perplexity',
import: () => import('@ai-sdk/perplexity'),
creatorFunctionName: 'createPerplexity',
supportsImageGeneration: false,
aliases: ['perplexity']
},
{
id: 'mistral',
name: 'Mistral',
import: () => import('@ai-sdk/mistral'),
creatorFunctionName: 'createMistral',
supportsImageGeneration: false,
aliases: ['mistral']
},
{
id: 'huggingface',
name: 'HuggingFace',
import: () => import('@ai-sdk/huggingface'),
creatorFunctionName: 'createHuggingFace',
supportsImageGeneration: true,
aliases: ['hf', 'hugging-face']
},
{
id: 'ai-gateway',
name: 'AI Gateway',
import: () => import('@ai-sdk/gateway'),
creatorFunctionName: 'createGateway',
supportsImageGeneration: true,
aliases: ['gateway']
},
{
id: 'cerebras',
name: 'Cerebras',
import: () => import('@ai-sdk/cerebras'),
creatorFunctionName: 'createCerebras',
supportsImageGeneration: false
}
] as const
/**
* Initialize the new providers
* Uses aiCore's dynamic registration feature
*/
export async function initializeNewProviders(): Promise<void> {
initializeSharedProviders({
warn: (message) => logger.warn(message),
error: (message, error) => logger.error(message, error)
})
try {
const successCount = registerMultipleProviderConfigs(NEW_PROVIDER_CONFIGS)
if (successCount < NEW_PROVIDER_CONFIGS.length) {
logger.warn('Some providers failed to register. Check previous error logs.')
}
} catch (error) {
logger.error('Failed to initialize new providers:', error as Error)
}
}
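
The registration path accepts arbitrary entries of the same shape, so a hypothetical extra provider would look like this (the package name and creator function are made up):

const CUSTOM_CONFIG: ProviderConfig = {
  id: 'my-provider',
  name: 'My Provider',
  import: () => import('@example/my-ai-sdk-provider'), // hypothetical package
  creatorFunctionName: 'createMyProvider',
  supportsImageGeneration: false,
  aliases: ['myprovider']
}
registerMultipleProviderConfigs([CUSTOM_CONFIG])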

View File

@@ -15,7 +15,6 @@ import {
isSupportVerbosityModel
} from '../openai'
import { isQwenMTModel } from '../qwen'
import { isFunctionCallingModel } from '../tooluse'
import {
agentModelFilter,
getModelSupportedVerbosity,
@@ -113,7 +112,6 @@ const textToImageMock = vi.mocked(isTextToImageModel)
const generateImageMock = vi.mocked(isGenerateImageModel)
const reasoningMock = vi.mocked(isOpenAIReasoningModel)
const openAIWebSearchOnlyMock = vi.mocked(isOpenAIWebSearchChatCompletionOnlyModel)
const isFunctionCallingModelMock = vi.mocked(isFunctionCallingModel)
describe('model utils', () => {
beforeEach(() => {
@@ -122,7 +120,7 @@ describe('model utils', () => {
rerankMock.mockReturnValue(false)
visionMock.mockReturnValue(true)
textToImageMock.mockReturnValue(false)
generateImageMock.mockReturnValue(false)
generateImageMock.mockReturnValue(true)
reasoningMock.mockReturnValue(false)
openAIWebSearchOnlyMock.mockReturnValue(false)
})
@@ -420,7 +418,6 @@ describe('model utils', () => {
describe('isGenerateImageModels', () => {
it('returns true when all models support image generation', () => {
const models = [createModel({ id: 'gpt-4o' }), createModel({ id: 'gpt-4o-mini' })]
generateImageMock.mockReturnValue(true)
expect(isGenerateImageModels(models)).toBe(true)
})
@@ -459,22 +456,12 @@ describe('model utils', () => {
expect(agentModelFilter(createModel({ id: 'rerank' }))).toBe(false)
})
it('filters out non-function-call models', () => {
rerankMock.mockReturnValue(false)
isFunctionCallingModelMock.mockReturnValueOnce(false)
expect(agentModelFilter(createModel({ id: 'DeepSeek R1' }))).toBe(false)
})
it('filters out text-to-image models', () => {
rerankMock.mockReturnValue(false)
textToImageMock.mockReturnValueOnce(true)
expect(agentModelFilter(createModel({ id: 'gpt-image-1' }))).toBe(false)
})
})
textToImageMock.mockReturnValue(false)
generateImageMock.mockReturnValueOnce(true)
expect(agentModelFilter(createModel({ id: 'dall-e-3' }))).toBe(false)
})
describe('Temperature limits', () => {

View File

@@ -1,8 +1,6 @@
import { getProviderByModel } from '@renderer/services/AssistantService'
import type { Model } from '@renderer/types'
import { isSystemProviderId } from '@renderer/types'
import { getLowerBaseModelName, isUserSelectedModelType } from '@renderer/utils'
import { isAzureOpenAIProvider } from '@shared/provider'
import { isEmbeddingModel, isRerankModel } from './embedding'
import { isDeepSeekHybridInferenceModel } from './reasoning'
@@ -54,13 +52,6 @@ export const FUNCTION_CALLING_REGEX = new RegExp(
'i'
)
const AZURE_FUNCTION_CALLING_EXCLUDED_MODELS = [
'(?:Meta-)?Llama-3(?:\\.\\d+)?-[\\w-]+',
'Phi-[34](?:\\.[\\w-]+)?(?:-[\\w-]+)?',
'DeepSeek-(?:R1|V3)',
'Codestral-2501'
]
export function isFunctionCallingModel(model?: Model): boolean {
if (!model || isEmbeddingModel(model) || isRerankModel(model) || isTextToImageModel(model)) {
return false
@@ -76,15 +67,6 @@ export function isFunctionCallingModel(model?: Model): boolean {
return FUNCTION_CALLING_REGEX.test(modelId) || FUNCTION_CALLING_REGEX.test(model.name)
}
const provider = getProviderByModel(model)
if (isAzureOpenAIProvider(provider)) {
const azureExcludedRegex = new RegExp(`\\b(?:${AZURE_FUNCTION_CALLING_EXCLUDED_MODELS.join('|')})\\b`, 'i')
if (azureExcludedRegex.test(modelId)) {
return false
}
}
if (['deepseek', 'anthropic', 'kimi', 'moonshot'].includes(model.provider)) {
return true
}

View File

@@ -1,6 +1,5 @@
import type OpenAI from '@cherrystudio/openai'
import { isEmbeddingModel, isRerankModel } from '@renderer/config/models/embedding'
import { getProviderByModel } from '@renderer/services/AssistantService'
import { type Model, SystemProviderIds } from '@renderer/types'
import type { OpenAIVerbosity, ValidOpenAIVerbosity } from '@renderer/types/aiCoreTypes'
import { getLowerBaseModelName } from '@renderer/utils'
@@ -14,7 +13,6 @@ import {
isOpenAIReasoningModel
} from './openai'
import { isQwenMTModel } from './qwen'
import { isFunctionCallingModel } from './tooluse'
import { isGenerateImageModel, isTextToImageModel, isVisionModel } from './vision'
export const NOT_SUPPORTED_REGEX = /(?:^tts|whisper|speech)/i
export const GEMINI_FLASH_MODEL_REGEX = new RegExp('gemini.*-flash.*$', 'i')
@@ -183,21 +181,8 @@ export const isGeminiModel = (model: Model) => {
// Zhipu vision reasoning models mark reasoning results with this pair of special tokens
export const ZHIPU_RESULT_TOKENS = ['<|begin_of_box|>', '<|end_of_box|>'] as const
// TODO: support prompt-based tool calling
export const agentModelFilter = (model: Model): boolean => {
const provider = getProviderByModel(model)
// Needs extra adaptation and easily exceeds rate limits
if (provider.id === SystemProviderIds.copilot) {
return false
}
return (
!isEmbeddingModel(model) &&
!isRerankModel(model) &&
!isTextToImageModel(model) &&
!isGenerateImageModel(model) &&
isFunctionCallingModel(model)
)
return !isEmbeddingModel(model) && !isRerankModel(model) && !isTextToImageModel(model)
}
export const isMaxTemperatureOneModel = (model: Model): boolean => {

View File

@@ -4372,7 +4372,7 @@
"url": {
"preview": "Preview: {{url}}",
"reset": "Reset",
"tip": "ending with # forces use of input address"
"tip": "Add # at the end to disable the automatically appended API version."
}
},
"api_host": "API Host",

View File

@@ -4372,7 +4372,7 @@
"url": {
"preview": "预览: {{url}}",
"reset": "重置",
"tip": "# 结尾强制使用输入地址"
"tip": "在末尾添加 # 以禁用自动附加的API版本。"
}
},
"api_host": "API 地址",

View File

@@ -4372,7 +4372,7 @@
"url": {
"preview": "預覽:{{url}}",
"reset": "重設",
"tip": "# 結尾強制使用輸入位址"
"tip": "在末尾添加 # 以停用自動附加的 API 版本。"
}
},
"api_host": "API 主機地址",

View File

@@ -4372,7 +4372,7 @@
"url": {
"preview": "Vorschau: {{url}}",
"reset": "Zurücksetzen",
"tip": "# am Ende erzwingt die Verwendung der Eingabe-Adresse"
"tip": "Fügen Sie am Ende ein # hinzu, um die automatisch angehängte API-Version zu deaktivieren."
}
},
"api_host": "API-Adresse",

View File

@@ -4372,7 +4372,7 @@
"url": {
"preview": "Προεπισκόπηση: {{url}}",
"reset": "Επαναφορά",
"tip": "#τέλος ενδεχόμενη χρήση της εισαγωγής διευθύνσεως"
"tip": "Προσθέστε το σύμβολο # στο τέλος για να απενεργοποιήσετε την αυτόματα προστιθέμενη έκδοση API."
}
},
"api_host": "Διεύθυνση API",

View File

@@ -4372,7 +4372,7 @@
"url": {
"preview": "Vista previa: {{url}}",
"reset": "Restablecer",
"tip": "forzar uso de dirección de entrada con # al final"
"tip": "Añada # al final para deshabilitar la versión de la API que se añade automáticamente."
}
},
"api_host": "Dirección API",

View File

@@ -4372,7 +4372,7 @@
"url": {
"preview": "Aperçu : {{url}}",
"reset": "Réinitialiser",
"tip": "forcer l'utilisation de l'adresse d'entrée si terminé par #"
"tip": "Ajoutez # à la fin pour désactiver la version d'API ajoutée automatiquement."
}
},
"api_host": "Adresse API",

View File

@@ -4372,7 +4372,7 @@
"url": {
"preview": "プレビュー: {{url}}",
"reset": "リセット",
"tip": "#で終わる場合、入力されたアドレスを強制的に使用します"
"tip": "自動的に付加されるAPIバージョンを無効にするには、末尾に#を追加します"
}
},
"api_host": "APIホスト",

View File

@@ -4372,7 +4372,7 @@
"url": {
"preview": "Pré-visualização: {{url}}",
"reset": "Redefinir",
"tip": "e forçar o uso do endereço original quando terminar com '#'"
"tip": "Adicione # no final para desativar a versão da API adicionada automaticamente."
}
},
"api_host": "Endereço API",

View File

@@ -4372,7 +4372,7 @@
"url": {
"preview": "Предпросмотр: {{url}}",
"reset": "Сброс",
"tip": "заканчивая на # принудительно использует введенный адрес"
"tip": "Добавьте # в конце, чтобы отключить автоматически добавляемую версию API."
}
},
"api_host": "Хост API",

View File

@@ -17,7 +17,7 @@ import type { EndpointType, Model } from '@renderer/types'
import { getClaudeSupportedProviders } from '@renderer/utils/provider'
import type { TerminalConfig } from '@shared/config/constant'
import { codeTools, terminalApps } from '@shared/config/constant'
import { isPpioAnthropicCompatibleModel, isSiliconAnthropicCompatibleModel } from '@shared/config/providers'
import { isSiliconAnthropicCompatibleModel } from '@shared/config/providers'
import { Alert, Avatar, Button, Checkbox, Input, Popover, Select, Space, Tooltip } from 'antd'
import { ArrowUpRight, Download, FolderOpen, HelpCircle, Terminal, X } from 'lucide-react'
import type { FC } from 'react'
@@ -82,12 +82,10 @@ const CodeToolsPage: FC = () => {
if (m.supported_endpoint_types) {
return m.supported_endpoint_types.includes('anthropic')
}
// Special handling for silicon provider: only specific models support Anthropic API
if (m.provider === 'silicon') {
return isSiliconAnthropicCompatibleModel(m.id)
}
if (m.provider === 'ppio') {
return isPpioAnthropicCompatibleModel(m.id)
}
return m.id.includes('claude') || CLAUDE_OFFICIAL_SUPPORTED_PROVIDERS.includes(m.provider)
}

View File

@@ -1,8 +1,10 @@
import { adaptProvider } from '@renderer/aiCore/provider/providerConfig'
import OpenAIAlert from '@renderer/components/Alert/OpenAIAlert'
import { LoadingIcon } from '@renderer/components/Icons'
import { HStack } from '@renderer/components/Layout'
import { ApiKeyListPopup } from '@renderer/components/Popups/ApiKeyListPopup'
import Selector from '@renderer/components/Selector'
import { HelpTooltip } from '@renderer/components/TooltipIcons'
import { isEmbeddingModel, isRerankModel } from '@renderer/config/models'
import { PROVIDER_URLS } from '@renderer/config/providers'
import { useTheme } from '@renderer/context/ThemeProvider'
@@ -19,14 +21,7 @@ import type { SystemProviderId } from '@renderer/types'
import { isSystemProvider, isSystemProviderId, SystemProviderIds } from '@renderer/types'
import type { ApiKeyConnectivity } from '@renderer/types/healthCheck'
import { HealthStatus } from '@renderer/types/healthCheck'
import {
formatApiHost,
formatApiKeys,
formatAzureOpenAIApiHost,
formatVertexApiHost,
getFancyProviderName,
validateApiHost
} from '@renderer/utils'
import { formatApiHost, formatApiKeys, getFancyProviderName, validateApiHost } from '@renderer/utils'
import { formatErrorMessage } from '@renderer/utils/error'
import {
isAIGatewayProvider,
@@ -36,7 +31,6 @@ import {
isNewApiProvider,
isOpenAICompatibleProvider,
isOpenAIProvider,
isSupportAPIVersionProvider,
isVertexProvider
} from '@renderer/utils/provider'
import { Button, Divider, Flex, Input, Select, Space, Switch, Tooltip } from 'antd'
@@ -85,8 +79,7 @@ const ANTHROPIC_COMPATIBLE_PROVIDER_IDS = [
SystemProviderIds.minimax,
SystemProviderIds.silicon,
SystemProviderIds.qiniu,
SystemProviderIds.dmxapi,
SystemProviderIds.ppio
SystemProviderIds.dmxapi
] as const
type AnthropicCompatibleProviderId = (typeof ANTHROPIC_COMPATIBLE_PROVIDER_IDS)[number]
@@ -282,12 +275,10 @@ const ProviderSetting: FC<Props> = ({ providerId }) => {
}, [configuredApiHost, apiHost])
const hostPreview = () => {
if (apiHost.endsWith('#')) {
return apiHost.replace('#', '')
}
const formattedApiHost = adaptProvider({ provider: { ...provider, apiHost } }).apiHost
if (isOpenAICompatibleProvider(provider)) {
return formatApiHost(apiHost, isSupportAPIVersionProvider(provider)) + '/chat/completions'
return formattedApiHost + '/chat/completions'
}
if (isAzureOpenAIProvider(provider)) {
@@ -295,29 +286,26 @@ const ProviderSetting: FC<Props> = ({ providerId }) => {
const path = !['preview', 'v1'].includes(apiVersion)
? `/v1/chat/completion?apiVersion=v1`
: `/v1/responses?apiVersion=v1`
return formatAzureOpenAIApiHost(apiHost) + path
return formattedApiHost + path
}
if (isAnthropicProvider(provider)) {
// AI SDK uses the baseURL with /v1, then appends /messages
// formatApiHost adds /v1 automatically if not present
const normalizedHost = formatApiHost(apiHost)
return normalizedHost + '/messages'
return formattedApiHost + '/messages'
}
if (isGeminiProvider(provider)) {
return formatApiHost(apiHost, true, 'v1beta') + '/models'
return formattedApiHost + '/models'
}
if (isOpenAIProvider(provider)) {
return formatApiHost(apiHost) + '/responses'
return formattedApiHost + '/responses'
}
if (isVertexProvider(provider)) {
return formatVertexApiHost(provider) + '/publishers/google'
return formattedApiHost + '/publishers/google'
}
if (isAIGatewayProvider(provider)) {
return formatApiHost(apiHost) + '/language-model'
return formattedApiHost + '/language-model'
}
return formatApiHost(apiHost)
return formattedApiHost
}
// API key connectivity check status indicator; currently shown only on failure
@@ -495,16 +483,21 @@ const ProviderSetting: FC<Props> = ({ providerId }) => {
{!isDmxapi && (
<>
<SettingSubtitle style={{ display: 'flex', alignItems: 'center', justifyContent: 'space-between' }}>
<Tooltip title={hostSelectorTooltip} mouseEnterDelay={0.3}>
<Selector
size={14}
value={activeHostField}
onChange={(value) => setActiveHostField(value as HostField)}
options={hostSelectorOptions}
style={{ paddingLeft: 1, fontWeight: 'bold' }}
placement="bottomLeft"
/>
</Tooltip>
<div className="flex items-center gap-1">
<Tooltip title={hostSelectorTooltip} mouseEnterDelay={0.3}>
<div>
<Selector
size={14}
value={activeHostField}
onChange={(value) => setActiveHostField(value as HostField)}
options={hostSelectorOptions}
style={{ paddingLeft: 1, fontWeight: 'bold' }}
placement="bottomLeft"
/>
</div>
</Tooltip>
<HelpTooltip title={t('settings.provider.api.url.tip')}></HelpTooltip>
</div>
<div style={{ display: 'flex', alignItems: 'center', gap: 4 }}>
<Button
type="text"

View File

@@ -8,8 +8,8 @@ import { isDedicatedImageGenerationModel, isEmbeddingModel, isFunctionCallingMod
import { getStoreSetting } from '@renderer/hooks/useSettings'
import i18n from '@renderer/i18n'
import store from '@renderer/store'
import type { FetchChatCompletionParams } from '@renderer/types'
import type { Assistant, MCPServer, MCPTool, Model, Provider } from '@renderer/types'
import { type FetchChatCompletionParams, isSystemProvider } from '@renderer/types'
import type { StreamTextParams } from '@renderer/types/aiCoreTypes'
import { type Chunk, ChunkType } from '@renderer/types/chunk'
import type { Message, ResponseError } from '@renderer/types/newMessage'
@@ -22,7 +22,8 @@ import { purifyMarkdownImages } from '@renderer/utils/markdown'
import { isPromptToolUse, isSupportedToolUse } from '@renderer/utils/mcp-tools'
import { findFileBlocks, getMainTextContent } from '@renderer/utils/messageUtils/find'
import { containsSupportedVariables, replacePromptVariables } from '@renderer/utils/prompt'
import { isEmpty, takeRight } from 'lodash'
import { NOT_SUPPORT_API_KEY_PROVIDERS } from '@renderer/utils/provider'
import { cloneDeep, isEmpty, takeRight } from 'lodash'
import type { ModernAiProviderConfig } from '../aiCore/index_new'
import AiProviderNew from '../aiCore/index_new'
@@ -43,6 +44,8 @@ import {
// } from './MessagesService'
// import WebSearchService from './WebSearchService'
// FIXME: too much duplicated logic here; needs refactoring
const logger = loggerService.withContext('ApiService')
export async function fetchMcpTools(assistant: Assistant) {
@@ -95,7 +98,15 @@ export async function fetchChatCompletion({
modelId: assistant.model?.id,
modelName: assistant.model?.name
})
const AI = new AiProviderNew(assistant.model || getDefaultModel())
// Get base provider and apply API key rotation
const baseProvider = getProviderByModel(assistant.model || getDefaultModel())
const providerWithRotatedKey = {
...cloneDeep(baseProvider),
apiKey: getRotatedApiKey(baseProvider)
}
const AI = new AiProviderNew(assistant.model || getDefaultModel(), providerWithRotatedKey)
const provider = AI.getActualProvider()
const mcpTools: MCPTool[] = []
@@ -172,7 +183,13 @@ export async function fetchMessagesSummary({ messages, assistant }: { messages:
return null
}
const AI = new AiProviderNew(model)
// Apply API key rotation
const providerWithRotatedKey = {
...cloneDeep(provider),
apiKey: getRotatedApiKey(provider)
}
const AI = new AiProviderNew(model, providerWithRotatedKey)
const topicId = messages?.find((message) => message.topicId)?.topicId || ''
@@ -271,7 +288,13 @@ export async function fetchNoteSummary({ content, assistant }: { content: string
return null
}
const AI = new AiProviderNew(model)
// Apply API key rotation
const providerWithRotatedKey = {
...cloneDeep(provider),
apiKey: getRotatedApiKey(provider)
}
const AI = new AiProviderNew(model, providerWithRotatedKey)
// only 2000 char and no images
const truncatedContent = content.substring(0, 2000)
@@ -359,7 +382,13 @@ export async function fetchGenerate({
return ''
}
const AI = new AiProviderNew(model)
// Apply API key rotation
const providerWithRotatedKey = {
...cloneDeep(provider),
apiKey: getRotatedApiKey(provider)
}
const AI = new AiProviderNew(model, providerWithRotatedKey)
const assistant = getDefaultAssistant()
assistant.model = model
@@ -404,28 +433,44 @@ export async function fetchGenerate({
export function hasApiKey(provider: Provider) {
if (!provider) return false
if (['ollama', 'lmstudio', 'vertexai', 'cherryai'].includes(provider.id)) return true
if (isSystemProvider(provider) && NOT_SUPPORT_API_KEY_PROVIDERS.includes(provider.id)) return true
return !isEmpty(provider.apiKey)
}
/**
* Get the first available embedding model from enabled providers
* Get the next API key in rotation
* Reuses the multi-key round-robin logic from the legacy architecture
*/
// function getFirstEmbeddingModel() {
// const providers = store.getState().llm.providers.filter((p) => p.enabled)
function getRotatedApiKey(provider: Provider): string {
const keys = provider.apiKey.split(',').map((key) => key.trim())
const keyName = `provider:${provider.id}:last_used_key`
// for (const provider of providers) {
// const embeddingModel = provider.models.find((model) => isEmbeddingModel(model))
// if (embeddingModel) {
// return embeddingModel
// }
// }
if (keys.length === 1) {
return keys[0]
}
// return undefined
// }
const lastUsedKey = window.keyv.get(keyName)
if (!lastUsedKey) {
window.keyv.set(keyName, keys[0])
return keys[0]
}
const currentIndex = keys.indexOf(lastUsedKey)
const nextIndex = (currentIndex + 1) % keys.length
const nextKey = keys[nextIndex]
window.keyv.set(keyName, nextKey)
return nextKey
}
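
A concrete trace of the round-robin behavior for a provider configured with three comma-separated keys (placeholder keys):

// provider.apiKey = 'k1, k2, k3'
// 1st call: nothing recorded yet        -> returns 'k1' and records it
// 2nd call: last used 'k1'              -> returns 'k2'
// 3rd call: last used 'k2'              -> returns 'k3'
// 4th call: last used 'k3', wraps round -> returns 'k1'
const apiKey = getRotatedApiKey(provider)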
export async function fetchModels(provider: Provider): Promise<SdkModel[]> {
const AI = new AiProviderNew(provider)
// Apply API key rotation
const providerWithRotatedKey = {
...cloneDeep(provider),
apiKey: getRotatedApiKey(provider)
}
const AI = new AiProviderNew(providerWithRotatedKey)
try {
return await AI.models()
@@ -435,12 +480,7 @@ export async function fetchModels(provider: Provider): Promise<SdkModel[]> {
}
export function checkApiProvider(provider: Provider): void {
if (
provider.id !== 'ollama' &&
provider.id !== 'lmstudio' &&
provider.type !== 'vertexai' &&
provider.id !== 'copilot'
) {
if (isSystemProvider(provider) && !NOT_SUPPORT_API_KEY_PROVIDERS.includes(provider.id)) {
if (!provider.apiKey) {
window.toast.error(i18n.t('message.error.enter.api.label'))
throw new Error(i18n.t('message.error.enter.api.label'))
@@ -461,8 +501,7 @@ export function checkApiProvider(provider: Provider): void {
export async function checkApi(provider: Provider, model: Model, timeout = 15000): Promise<void> {
checkApiProvider(provider)
// Don't pass in provider parameter. We need auto-format URL
const ai = new AiProviderNew(model)
const ai = new AiProviderNew(model, provider)
const assistant = getDefaultAssistant()
assistant.model = model


@@ -67,7 +67,7 @@ const persistedReducer = persistReducer(
{
key: 'cherry-studio',
storage,
version: 180,
version: 179,
blacklist: ['runtime', 'messages', 'messageBlocks', 'tabs', 'toolPermissions'],
migrate
},


@@ -2906,20 +2906,6 @@ const migrateConfig = {
logger.error('migrate 179 error', error as Error)
return state
}
},
'180': (state: RootState) => {
try {
state.llm.providers.forEach((provider) => {
if (provider.id === SystemProviderIds.ppio) {
provider.anthropicApiHost = 'https://api.ppinfra.com/anthropic'
}
})
logger.info('migrate 180 success')
return state
} catch (error) {
logger.error('migrate 180 error', error as Error)
return state
}
}
}


@@ -7,8 +7,6 @@ import type { CSSProperties } from 'react'
export * from './file'
export * from './note'
import type { MinimalModel } from '@shared/provider/types'
import type { StreamTextParams } from './aiCoreTypes'
import type { Chunk } from './chunk'
import type { FileMetadata } from './file'
@@ -259,7 +257,7 @@ export type ModelCapability = {
isUserSelected?: boolean
}
export type Model = MinimalModel & {
export type Model = {
id: string
provider: string
name: string


@@ -1,14 +1,24 @@
import type OpenAI from '@cherrystudio/openai'
import type { MinimalProvider } from '@shared/provider'
import type { ProviderType, SystemProviderId, SystemProviderIdTypeMap } from '@shared/provider/types'
import { isSystemProviderId, SystemProviderIds } from '@shared/provider/types'
import type { Model } from '@types'
import * as z from 'zod'
import type { OpenAIVerbosity } from './aiCoreTypes'
export type { ProviderType } from '@shared/provider'
export type { SystemProviderId, SystemProviderIdTypeMap } from '@shared/provider/types'
export { isSystemProviderId, ProviderTypeSchema, SystemProviderIds } from '@shared/provider/types'
export const ProviderTypeSchema = z.enum([
'openai',
'openai-response',
'anthropic',
'gemini',
'azure-openai',
'vertexai',
'mistral',
'aws-bedrock',
'vertex-anthropic',
'new-api',
'ai-gateway'
])
export type ProviderType = z.infer<typeof ProviderTypeSchema>
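Because the schema is a zod enum, it doubles as a runtime validator for the inferred ProviderType union:

```ts
ProviderTypeSchema.safeParse('anthropic') // { success: true, data: 'anthropic' }
ProviderTypeSchema.safeParse('not-a-type').success // false
```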
// undefined is treated as supported, enabled by default
export type ProviderApiOptions = {
@@ -83,7 +93,7 @@ export function isAwsBedrockAuthType(type: string): type is AwsBedrockAuthType {
return Object.hasOwn(AwsBedrockAuthTypes, type)
}
export type Provider = MinimalProvider & {
export type Provider = {
id: string
type: ProviderType
name: string
@@ -118,6 +128,140 @@ export type Provider = MinimalProvider & {
extra_headers?: Record<string, string>
}
export const SystemProviderIdSchema = z.enum([
'cherryin',
'silicon',
'aihubmix',
'ocoolai',
'deepseek',
'ppio',
'alayanew',
'qiniu',
'dmxapi',
'burncloud',
'tokenflux',
'302ai',
'cephalon',
'lanyun',
'ph8',
'openrouter',
'ollama',
'ovms',
'new-api',
'lmstudio',
'anthropic',
'openai',
'azure-openai',
'gemini',
'vertexai',
'github',
'copilot',
'zhipu',
'yi',
'moonshot',
'baichuan',
'dashscope',
'stepfun',
'doubao',
'infini',
'minimax',
'groq',
'together',
'fireworks',
'nvidia',
'grok',
'hyperbolic',
'mistral',
'jina',
'perplexity',
'modelscope',
'xirang',
'hunyuan',
'tencent-cloud-ti',
'baidu-cloud',
'gpustack',
'voyageai',
'aws-bedrock',
'poe',
'aionly',
'longcat',
'huggingface',
'sophnet',
'ai-gateway',
'cerebras'
])
export type SystemProviderId = z.infer<typeof SystemProviderIdSchema>
export const isSystemProviderId = (id: string): id is SystemProviderId => {
return SystemProviderIdSchema.safeParse(id).success
}
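A small usage sketch of the guard together with the SystemProviderIds map defined just below (`pid` is an illustrative variable):

```ts
const pid: string = 'ppio'
if (isSystemProviderId(pid)) {
  // pid is narrowed from string to SystemProviderId here
  console.log(SystemProviderIds[pid]) // 'ppio'
}
```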
export const SystemProviderIds = {
cherryin: 'cherryin',
silicon: 'silicon',
aihubmix: 'aihubmix',
ocoolai: 'ocoolai',
deepseek: 'deepseek',
ppio: 'ppio',
alayanew: 'alayanew',
qiniu: 'qiniu',
dmxapi: 'dmxapi',
burncloud: 'burncloud',
tokenflux: 'tokenflux',
'302ai': '302ai',
cephalon: 'cephalon',
lanyun: 'lanyun',
ph8: 'ph8',
sophnet: 'sophnet',
openrouter: 'openrouter',
ollama: 'ollama',
ovms: 'ovms',
'new-api': 'new-api',
lmstudio: 'lmstudio',
anthropic: 'anthropic',
openai: 'openai',
'azure-openai': 'azure-openai',
gemini: 'gemini',
vertexai: 'vertexai',
github: 'github',
copilot: 'copilot',
zhipu: 'zhipu',
yi: 'yi',
moonshot: 'moonshot',
baichuan: 'baichuan',
dashscope: 'dashscope',
stepfun: 'stepfun',
doubao: 'doubao',
infini: 'infini',
minimax: 'minimax',
groq: 'groq',
together: 'together',
fireworks: 'fireworks',
nvidia: 'nvidia',
grok: 'grok',
hyperbolic: 'hyperbolic',
mistral: 'mistral',
jina: 'jina',
perplexity: 'perplexity',
modelscope: 'modelscope',
xirang: 'xirang',
hunyuan: 'hunyuan',
'tencent-cloud-ti': 'tencent-cloud-ti',
'baidu-cloud': 'baidu-cloud',
gpustack: 'gpustack',
voyageai: 'voyageai',
'aws-bedrock': 'aws-bedrock',
poe: 'poe',
aionly: 'aionly',
longcat: 'longcat',
huggingface: 'huggingface',
'ai-gateway': 'ai-gateway',
cerebras: 'cerebras'
} as const satisfies Record<SystemProviderId, SystemProviderId>
type SystemProviderIdTypeMap = typeof SystemProviderIds
export type SystemProvider = Provider & {
id: SystemProviderId
isSystem: true


@@ -13,7 +13,8 @@ import {
routeToEndpoint,
splitApiKeyString,
validateApiHost,
withoutTrailingApiVersion
withoutTrailingApiVersion,
withoutTrailingSharp
} from '../api'
vi.mock('@renderer/store', () => {
@@ -81,6 +82,27 @@ describe('api', () => {
it('keeps host untouched when api version unsupported', () => {
expect(formatApiHost('https://api.example.com', false)).toBe('https://api.example.com')
})
it('removes trailing # and does not append api version when host ends with #', () => {
expect(formatApiHost('https://api.example.com#')).toBe('https://api.example.com')
expect(formatApiHost('http://localhost:5173/#')).toBe('http://localhost:5173/')
expect(formatApiHost(' https://api.openai.com/# ')).toBe('https://api.openai.com/')
})
it('handles trailing # with custom api version settings', () => {
expect(formatApiHost('https://api.example.com#', true, 'v2')).toBe('https://api.example.com')
expect(formatApiHost('https://api.example.com#', false, 'v2')).toBe('https://api.example.com')
})
it('handles host with both trailing # and existing api version', () => {
expect(formatApiHost('https://api.example.com/v2#')).toBe('https://api.example.com/v2')
expect(formatApiHost('https://api.example.com/v3beta#')).toBe('https://api.example.com/v3beta')
})
it('trims whitespace before processing trailing #', () => {
expect(formatApiHost(' https://api.example.com# ')).toBe('https://api.example.com')
expect(formatApiHost('\thttps://api.example.com#\n')).toBe('https://api.example.com')
})
})
describe('hasAPIVersion', () => {
@@ -302,7 +324,18 @@ describe('api', () => {
})
it('uses global endpoint when location equals global', () => {
expect(formatVertexApiHost(createVertexProvider(''), 'global-project', 'global')).toBe(
getStateMock.mockReturnValueOnce({
llm: {
settings: {
vertexai: {
projectId: 'global-project',
location: 'global'
}
}
}
})
expect(formatVertexApiHost(createVertexProvider(''))).toBe(
'https://aiplatform.googleapis.com/v1/projects/global-project/locations/global'
)
})
@@ -393,4 +426,56 @@ describe('api', () => {
expect(withoutTrailingApiVersion('')).toBe('')
})
})
describe('withoutTrailingSharp', () => {
it('removes trailing # from URL', () => {
expect(withoutTrailingSharp('https://api.example.com#')).toBe('https://api.example.com')
expect(withoutTrailingSharp('http://localhost:3000#')).toBe('http://localhost:3000')
})
it('returns URL unchanged when no trailing #', () => {
expect(withoutTrailingSharp('https://api.example.com')).toBe('https://api.example.com')
expect(withoutTrailingSharp('http://localhost:3000')).toBe('http://localhost:3000')
})
it('handles URLs with multiple # characters but only removes trailing one', () => {
expect(withoutTrailingSharp('https://api.example.com#path#')).toBe('https://api.example.com#path')
})
it('handles URLs with # in the middle (not trailing)', () => {
expect(withoutTrailingSharp('https://api.example.com#section/path')).toBe('https://api.example.com#section/path')
expect(withoutTrailingSharp('https://api.example.com/v1/chat/completions#')).toBe(
'https://api.example.com/v1/chat/completions'
)
})
it('handles empty string', () => {
expect(withoutTrailingSharp('')).toBe('')
})
it('handles single character #', () => {
expect(withoutTrailingSharp('#')).toBe('')
})
it('preserves whitespace around the URL (pure function)', () => {
expect(withoutTrailingSharp(' https://api.example.com# ')).toBe(' https://api.example.com# ')
expect(withoutTrailingSharp('\thttps://api.example.com#\n')).toBe('\thttps://api.example.com#\n')
})
it('only removes exact trailing # character', () => {
expect(withoutTrailingSharp('https://api.example.com# ')).toBe('https://api.example.com# ')
expect(withoutTrailingSharp(' https://api.example.com#')).toBe(' https://api.example.com')
expect(withoutTrailingSharp('https://api.example.com#\t')).toBe('https://api.example.com#\t')
})
it('handles URLs ending with multiple # characters', () => {
expect(withoutTrailingSharp('https://api.example.com##')).toBe('https://api.example.com#')
expect(withoutTrailingSharp('https://api.example.com###')).toBe('https://api.example.com##')
})
it('preserves URL with trailing # and other content', () => {
expect(withoutTrailingSharp('https://api.example.com/v1#')).toBe('https://api.example.com/v1')
expect(withoutTrailingSharp('https://api.example.com/v2beta#')).toBe('https://api.example.com/v2beta')
})
})
})

View File

@@ -1,17 +1,6 @@
export {
formatApiHost,
formatAzureOpenAIApiHost,
formatVertexApiHost,
getAiSdkBaseUrl,
getTrailingApiVersion,
hasAPIVersion,
routeToEndpoint,
SUPPORTED_ENDPOINT_LIST,
SUPPORTED_IMAGE_ENDPOINT_LIST,
validateApiHost,
withoutTrailingApiVersion,
withoutTrailingSlash
} from '@shared/api'
import store from '@renderer/store'
import type { VertexProvider } from '@renderer/types'
import { trim } from 'lodash'
/**
* Formats an API key string.
@@ -23,6 +12,200 @@ export function formatApiKeys(value: string): string {
return value.replaceAll('，', ',').replaceAll('\n', ',')
}
/**
* Matches a version segment in a path that starts with `/v<number>` and optionally
* continues with `alpha` or `beta`. The segment may be followed by `/` or the end
* of the string (useful for cases like `/v3alpha/resources`).
*/
const VERSION_REGEX_PATTERN = '\\/v\\d+(?:alpha|beta)?(?=\\/|$)'
/**
* Matches an API version at the end of a URL (with optional trailing slash).
* Used to detect and extract versions only from the trailing position.
*/
const TRAILING_VERSION_REGEX = /\/v\d+(?:alpha|beta)?\/?$/i
/**
* Determines whether the host's path contains a version-like segment (e.g. /v1, /v2beta).
*
* @param host - The host or path string to check
* @returns true if the path contains a version segment, otherwise false
*/
export function hasAPIVersion(host?: string): boolean {
if (!host) return false
const regex = new RegExp(VERSION_REGEX_PATTERN, 'i')
try {
const url = new URL(host)
return regex.test(url.pathname)
} catch {
// If it cannot be parsed as a full URL, test it directly as a path
return regex.test(host)
}
}
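Concretely (illustrative URLs):

```ts
hasAPIVersion('https://api.example.com/v1') // true
hasAPIVersion('https://api.example.com/v2beta/models') // true (a version mid-path also counts)
hasAPIVersion('https://api.example.com/version2') // false (no /v<number> segment)
hasAPIVersion('/v1/chat') // true (non-URL input is tested as a bare path)
```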
/**
* Removes the trailing slash from a URL string if it exists.
*
* @template T - The string type to preserve type safety
* @param {T} url - The URL string to process
* @returns {T} The URL string without a trailing slash
*
* @example
* ```ts
* withoutTrailingSlash('https://example.com/') // 'https://example.com'
* withoutTrailingSlash('https://example.com') // 'https://example.com'
* ```
*/
export function withoutTrailingSlash<T extends string>(url: T): T {
return url.replace(/\/$/, '') as T
}
/**
* Removes the trailing '#' from a URL string if it exists.
*
* @template T - The string type to preserve type safety
* @param {T} url - The URL string to process
* @returns {T} The URL string without a trailing '#'
*
* @example
* ```ts
* withoutTrailingSharp('https://example.com#') // 'https://example.com'
* withoutTrailingSharp('https://example.com') // 'https://example.com'
* ```
*/
export function withoutTrailingSharp<T extends string>(url: T): T {
return url.replace(/#$/, '') as T
}
/**
* Formats an API host URL by normalizing it and optionally appending an API version.
*
* @param host - The API host URL to format. Leading/trailing whitespace will be trimmed and trailing slashes removed.
* @param supportApiVersion - Whether the API version is supported. Defaults to `true`.
* @param apiVersion - The API version to append if needed. Defaults to `'v1'`.
*
* @returns The formatted API host URL. If the host is empty after normalization, returns an empty string.
* If the host ends with '#', if the API version is unsupported, or if the host already contains a version, it returns the normalized host with any trailing '#' removed.
* Otherwise, returns the host with the API version appended.
*
* @example
* formatApiHost('https://api.example.com/') // Returns 'https://api.example.com/v1'
* formatApiHost('https://api.example.com#') // Returns 'https://api.example.com'
* formatApiHost('https://api.example.com/v2', true, 'v1') // Returns 'https://api.example.com/v2'
*/
export function formatApiHost(host?: string, supportApiVersion: boolean = true, apiVersion: string = 'v1'): string {
const normalizedHost = withoutTrailingSlash(trim(host))
if (!normalizedHost) {
return ''
}
const shouldAppendApiVersion = !(normalizedHost.endsWith('#') || !supportApiVersion || hasAPIVersion(normalizedHost))
if (shouldAppendApiVersion) {
return `${normalizedHost}/${apiVersion}`
} else {
return withoutTrailingSharp(normalizedHost)
}
}
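Beyond the examples in the JSDoc above, the remaining knobs behave as follows (illustrative hosts):

```ts
formatApiHost('https://api.example.com', false) // 'https://api.example.com' (version support disabled)
formatApiHost('https://api.example.com', true, 'v1beta') // 'https://api.example.com/v1beta'
formatApiHost('  https://api.example.com/  ') // 'https://api.example.com/v1' (trimmed and normalized)
```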
/**
* Formats the API host for Azure OpenAI.
*/
export function formatAzureOpenAIApiHost(host: string): string {
const normalizedHost = withoutTrailingSlash(host)
?.replace(/\/v1$/, '')
.replace(/\/openai$/, '')
// NOTE: the AI SDK appends `v1` by itself
return formatApiHost(normalizedHost + '/openai', false)
}
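So stored hosts converge on the same normalized form whether or not they already carry a /v1 or /openai suffix (the resource name is illustrative):

```ts
formatAzureOpenAIApiHost('https://myres.openai.azure.com') // 'https://myres.openai.azure.com/openai'
formatAzureOpenAIApiHost('https://myres.openai.azure.com/openai/') // same result
formatAzureOpenAIApiHost('https://myres.openai.azure.com/v1') // same result
```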
export function formatVertexApiHost(provider: VertexProvider): string {
const { apiHost } = provider
const { projectId: project, location } = store.getState().llm.settings.vertexai
const trimmedHost = withoutTrailingSlash(trim(apiHost))
if (!trimmedHost || trimmedHost.endsWith('aiplatform.googleapis.com')) {
const host =
location == 'global' ? 'https://aiplatform.googleapis.com' : `https://${location}-aiplatform.googleapis.com`
return `${formatApiHost(host)}/projects/${project}/locations/${location}`
}
return formatApiHost(trimmedHost)
}
// The chat UI currently supports only these endpoints
export const SUPPORTED_IMAGE_ENDPOINT_LIST = ['images/generations', 'images/edits', 'predict'] as const
export const SUPPORTED_ENDPOINT_LIST = [
'chat/completions',
'responses',
'messages',
'generateContent',
'streamGenerateContent',
...SUPPORTED_IMAGE_ENDPOINT_LIST
] as const
/**
* Converts an API host URL into separate base URL and endpoint components.
*
* @param apiHost - The API host string to parse. Expected to be a trimmed URL that may end with '#' followed by an endpoint identifier.
* @returns An object containing:
* - `baseURL`: The base URL without the endpoint suffix
* - `endpoint`: The matched endpoint identifier, or empty string if no match found
*
* @description
* This function extracts endpoint information from a composite API host string.
* If the host ends with '#', it attempts to match the preceding part against the supported endpoint list.
* The '#' delimiter is removed before processing.
*
* @example
* routeToEndpoint('https://api.example.com/v1/chat/completions#')
* // Returns: { baseURL: 'https://api.example.com/v1', endpoint: 'chat/completions' }
*
* @example
* routeToEndpoint('https://api.example.com/v1')
* // Returns: { baseURL: 'https://api.example.com/v1', endpoint: '' }
*/
export function routeToEndpoint(apiHost: string): { baseURL: string; endpoint: string } {
const trimmedHost = trim(apiHost)
// The apiHost has already been validated upstream
if (!trimmedHost.endsWith('#')) {
return { baseURL: trimmedHost, endpoint: '' }
}
// Strip the trailing #
const host = trimmedHost.slice(0, -1)
const endpointMatch = SUPPORTED_ENDPOINT_LIST.find((endpoint) => host.endsWith(endpoint))
if (!endpointMatch) {
const baseURL = withoutTrailingSlash(host)
return { baseURL, endpoint: '' }
}
const baseSegment = host.slice(0, host.length - endpointMatch.length)
const baseURL = withoutTrailingSlash(baseSegment).replace(/:$/, '') // strip a possible trailing colon (the Gemini special case)
return { baseURL, endpoint: endpointMatch }
}
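One case worth spelling out is the Gemini-style URL, where the endpoint is attached with a colon that gets stripped from the base, along with the fallback when no known endpoint matches (illustrative URLs):

```ts
routeToEndpoint('https://generativelanguage.googleapis.com/v1beta/models/gemini-pro:generateContent#')
// { baseURL: 'https://generativelanguage.googleapis.com/v1beta/models/gemini-pro',
//   endpoint: 'generateContent' }

routeToEndpoint('https://api.example.com/custom#') // trailing # but no recognized endpoint
// { baseURL: 'https://api.example.com/custom', endpoint: '' }
```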
/**
* Validates whether an API host is a legal URL.
*
* @param {string} apiHost - The API host to validate.
* @returns {boolean} true if it is a valid URL, otherwise false.
*/
export function validateApiHost(apiHost: string): boolean {
// An empty apiHost is allowed
if (!apiHost || !trim(apiHost)) {
return true
}
try {
const url = new URL(trim(apiHost))
// Ensure the protocol is http or https
if (url.protocol !== 'http:' && url.protocol !== 'https:') {
return false
}
return true
} catch {
return false
}
}
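Concretely (illustrative inputs):

```ts
validateApiHost('') // true (empty is allowed)
validateApiHost('https://api.example.com') // true
validateApiHost('ftp://files.example.com') // false (only http/https)
validateApiHost('not a url') // false
```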
/**
* Masks an API key. Keeps only a few leading and trailing characters and replaces the middle with asterisks.
*
@@ -61,3 +244,50 @@ export function splitApiKeyString(keyStr: string): string[] {
.map((k) => k.replace(/\\,/g, ','))
.filter((k) => k)
}
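The head of splitApiKeyString falls outside this excerpt, but judging from the visible tail the intended behavior is roughly (a sketch, not verified against the full function):

```ts
splitApiKeyString('k1,k2') // ['k1', 'k2']
splitApiKeyString('k\\,1,k2') // ['k,1', 'k2'] (an escaped comma survives the split, then unescapes)
splitApiKeyString('k1,,k2') // ['k1', 'k2'] (empty entries are filtered out)
```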
/**
* Extracts the trailing API version segment from a URL path.
*
* This function extracts API version patterns (e.g., `v1`, `v2beta`) from the end of a URL.
* Only versions at the end of the path are extracted, not versions in the middle.
* The returned version string does not include leading or trailing slashes.
*
* @param {string} url - The URL string to parse.
* @returns {string | undefined} The trailing API version found (e.g., 'v1', 'v2beta'), or undefined if none found.
*
* @example
* getTrailingApiVersion('https://api.example.com/v1') // 'v1'
* getTrailingApiVersion('https://api.example.com/v2beta/') // 'v2beta'
* getTrailingApiVersion('https://api.example.com/v1/chat') // undefined (version not at end)
* getTrailingApiVersion('https://gateway.ai.cloudflare.com/v1/xxx/v1beta') // 'v1beta'
* getTrailingApiVersion('https://api.example.com') // undefined
*/
export function getTrailingApiVersion(url: string): string | undefined {
const match = url.match(TRAILING_VERSION_REGEX)
if (match) {
// Extract version without leading slash and trailing slash
return match[0].replace(/^\//, '').replace(/\/$/, '')
}
return undefined
}
/**
* Removes the trailing API version segment from a URL path.
*
* This function removes API version patterns (e.g., `/v1`, `/v2beta`) from the end of a URL.
* Only versions at the end of the path are removed, not versions in the middle.
*
* @param {string} url - The URL string to process.
* @returns {string} The URL with the trailing API version removed, or the original URL if no trailing version found.
*
* @example
* withoutTrailingApiVersion('https://api.example.com/v1') // 'https://api.example.com'
* withoutTrailingApiVersion('https://api.example.com/v2beta/') // 'https://api.example.com'
* withoutTrailingApiVersion('https://api.example.com/v1/chat') // 'https://api.example.com/v1/chat' (no change)
* withoutTrailingApiVersion('https://api.example.com') // 'https://api.example.com'
*/
export function withoutTrailingApiVersion(url: string): string {
return url.replace(TRAILING_VERSION_REGEX, '')
}
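The two helpers compose, and only the trailing segment is affected even when another version appears earlier in the path (the gateway URL is illustrative):

```ts
const host = 'https://gateway.ai.cloudflare.com/v1/acct/gw/v1beta/'
getTrailingApiVersion(host) // 'v1beta'
withoutTrailingApiVersion(host) // 'https://gateway.ai.cloudflare.com/v1/acct/gw'
```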


@@ -2,8 +2,6 @@ import { getProviderLabel } from '@renderer/i18n/label'
import type { Provider } from '@renderer/types'
import { isSystemProvider } from '@renderer/types'
export { getBaseModelName, getLowerBaseModelName } from '@shared/utils/naming'
/**
* Extracts the default group name from a model ID.
* The rules are as follows:
@@ -52,6 +50,38 @@ export const getDefaultGroupName = (id: string, provider?: string): string => {
return str
}
/**
* Extracts the base name from a model ID.
* For example:
* - 'deepseek/deepseek-r1' => 'deepseek-r1'
* - 'deepseek-ai/deepseek/deepseek-r1' => 'deepseek-r1'
* @param {string} id The model ID
* @param {string} [delimiter='/'] The delimiter, defaults to '/'
* @returns {string} The base name
*/
export const getBaseModelName = (id: string, delimiter: string = '/'): string => {
const parts = id.split(delimiter)
return parts[parts.length - 1]
}
/**
* Extracts the base name from a model ID and lowercases it.
* For example:
* - 'deepseek/DeepSeek-R1' => 'deepseek-r1'
* - 'deepseek-ai/deepseek/DeepSeek-R1' => 'deepseek-r1'
* @param {string} id The model ID
* @param {string} [delimiter='/'] The delimiter, defaults to '/'
* @returns {string} The lowercased base name
*/
export const getLowerBaseModelName = (id: string, delimiter: string = '/'): string => {
const baseModelName = getBaseModelName(id, delimiter).toLowerCase()
// for openrouter
if (baseModelName.endsWith(':free')) {
return baseModelName.replace(':free', '')
}
return baseModelName
}
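For example (illustrative model IDs):

```ts
getLowerBaseModelName('deepseek-ai/DeepSeek-R1') // 'deepseek-r1'
getLowerBaseModelName('deepseek/deepseek-r1:free') // 'deepseek-r1' (OpenRouter ':free' suffix dropped)
getLowerBaseModelName('anthropic:claude-3-opus', ':') // 'claude-3-opus'
```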
/**
* Gets the model provider's display name; whether it gets translated depends on whether it is a built-in provider
* @param provider The provider


@@ -1,20 +1,10 @@
import { CLAUDE_SUPPORTED_PROVIDERS } from '@renderer/pages/code'
import type { ProviderType } from '@renderer/types'
import type { AzureOpenAIProvider, ProviderType, VertexProvider } from '@renderer/types'
import { isSystemProvider, type Provider, type SystemProviderId, SystemProviderIds } from '@renderer/types'
export {
isAIGatewayProvider,
isAnthropicProvider,
isAwsBedrockProvider,
isAzureOpenAIProvider,
isAzureResponsesEndpoint,
isCherryAIProvider,
isGeminiProvider,
isNewApiProvider,
isOpenAICompatibleProvider,
isOpenAIProvider,
isPerplexityProvider,
isVertexProvider
} from '@shared/provider'
export const isAzureResponsesEndpoint = (provider: AzureOpenAIProvider) => {
return provider.apiVersion === 'preview' || provider.apiVersion === 'v1'
}
export const getClaudeSupportedProviders = (providers: Provider[]) => {
return providers.filter(
@@ -136,6 +126,55 @@ export const isGeminiWebSearchProvider = (provider: Provider) => {
return SUPPORT_GEMINI_NATIVE_WEB_SEARCH_PROVIDERS.some((id) => id === provider.id)
}
export const isNewApiProvider = (provider: Provider) => {
return ['new-api', 'cherryin'].includes(provider.id) || provider.type === 'new-api'
}
export function isCherryAIProvider(provider: Provider): boolean {
return provider.id === 'cherryai'
}
export function isPerplexityProvider(provider: Provider): boolean {
return provider.id === 'perplexity'
}
/**
* Determines whether a provider is OpenAI-compatible
* @param {Provider} provider The provider object
* @returns {boolean} Whether it is an OpenAI-compatible provider
*/
export function isOpenAICompatibleProvider(provider: Provider): boolean {
return ['openai', 'new-api', 'mistral'].includes(provider.type)
}
export function isAzureOpenAIProvider(provider: Provider): provider is AzureOpenAIProvider {
return provider.type === 'azure-openai'
}
export function isOpenAIProvider(provider: Provider): boolean {
return provider.type === 'openai-response'
}
export function isVertexProvider(provider: Provider): provider is VertexProvider {
return provider.type === 'vertexai'
}
export function isAwsBedrockProvider(provider: Provider): boolean {
return provider.type === 'aws-bedrock'
}
export function isAnthropicProvider(provider: Provider): boolean {
return provider.type === 'anthropic'
}
export function isGeminiProvider(provider: Provider): boolean {
return provider.type === 'gemini'
}
export function isAIGatewayProvider(provider: Provider): boolean {
return provider.type === 'ai-gateway'
}
const NOT_SUPPORT_API_VERSION_PROVIDERS = ['github', 'copilot', 'perplexity'] as const satisfies SystemProviderId[]
export const isSupportAPIVersionProvider = (provider: Provider) => {
@@ -144,3 +183,11 @@ export const isSupportAPIVersionProvider = (provider: Provider) => {
}
return provider.apiOptions?.isNotSupportAPIVersion !== false
}
export const NOT_SUPPORT_API_KEY_PROVIDERS: readonly SystemProviderId[] = [
'ollama',
'lmstudio',
'vertexai',
'aws-bedrock',
'copilot'
]


@@ -8,10 +8,8 @@
"src/preload/**/*",
"src/renderer/src/services/traceApi.ts",
"src/renderer/src/types/*",
"packages/aiCore/src/**/*",
"packages/mcp-trace/**/*",
"packages/shared/**/*",
"packages/ai-sdk-provider/**/*"
],
"compilerOptions": {
"composite": true,
@@ -28,12 +26,7 @@
"@types": ["./src/renderer/src/types/index.ts"],
"@shared/*": ["./packages/shared/*"],
"@mcp-trace/*": ["./packages/mcp-trace/*"],
"@modelcontextprotocol/sdk/*": ["./node_modules/@modelcontextprotocol/sdk/dist/esm/*"],
"@cherrystudio/ai-core/provider": ["./packages/aiCore/src/core/providers/index.ts"],
"@cherrystudio/ai-core/built-in/plugins": ["./packages/aiCore/src/core/plugins/built-in/index.ts"],
"@cherrystudio/ai-core/*": ["./packages/aiCore/src/*"],
"@cherrystudio/ai-core": ["./packages/aiCore/src/index.ts"],
"@cherrystudio/ai-sdk-provider": ["./packages/ai-sdk-provider/src/index.ts"]
"@modelcontextprotocol/sdk/*": ["./node_modules/@modelcontextprotocol/sdk/dist/esm/*"]
},
"experimentalDecorators": true,
"emitDecoratorMetadata": true,