Compare commits


4 Commits

Author SHA1 Message Date
suyao
76d48f9ccb fix: address PR review issues for Volcengine integration
- Fix region field being ignored: pass user-configured region to listFoundationModels and listEndpoints
- Add user notification before silent fallback when API fails
- Throw error on credential corruption instead of returning null
- Remove redundant credentials (accessKeyId, secretAccessKey) from Redux store (they're securely stored via safeStorage)
- Add warnings field to ListModelsResult for partial API failures
- Fix Redux/IPC order: save to secure storage first, then update Redux on success
- Update related tests

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-11-29 19:39:15 +08:00
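For the last two fixes in the commit above, a minimal sketch of the intended ordering: secure storage via IPC first, Redux only on success. The channel name comes from the IpcChannel diff further down; the function and parameter names here are illustrative, not the actual implementation.

```typescript
type VolcengineCredentials = { accessKeyId: string; secretAccessKey: string }

async function saveVolcengineCredentials(
  creds: VolcengineCredentials,
  ipcInvoke: (channel: string, payload: unknown) => Promise<void>,
  updateStore: (state: { hasCredentials: boolean }) => void
): Promise<void> {
  // Persist to the main process first, where safeStorage encrypts the keys.
  // If this throws, Redux is never touched, so UI state cannot drift from
  // what was actually saved.
  await ipcInvoke('volcengine:save-credentials', creds)

  // Only after the IPC call succeeds, record a non-secret flag in Redux.
  // The raw accessKeyId/secretAccessKey deliberately stay out of the store.
  updateStore({ hasCredentials: true })
}
```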
suyao
115cd80432 fix: format 2025-11-27 01:33:17 +08:00
suyao
531101742e feat: add project name support for Volcengine integration 2025-11-27 01:29:02 +08:00
suyao
c3c577dff4 feat: add Volcengine integration with settings and API client
- Implement Volcengine configuration in multiple languages (el-gr, es-es, fr-fr, ja-jp, pt-pt, ru-ru).
- Add Volcengine settings component to manage access key ID, secret access key, and region.
- Create Volcengine service for API interactions, including credential management and model listing.
- Extend OpenAI API client to support Volcengine's signed API for model retrieval.
- Update Redux store to handle Volcengine settings and credentials.
- Implement migration for Volcengine settings in the store.
- Add hooks for accessing and managing Volcengine settings in the application.
2025-11-27 01:15:22 +08:00
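A rough sketch of the main-process wiring this commit describes, assuming Electron's safeStorage for encryption. The in-memory buffer stands in for whatever persistence the real service uses, and the handler bodies are illustrative only.

```typescript
import { ipcMain, safeStorage } from 'electron'

let encrypted: Buffer | null = null // stand-in for the real persistence layer

ipcMain.handle('volcengine:save-credentials', (_event, creds: { accessKeyId: string; secretAccessKey: string }) => {
  if (!safeStorage.isEncryptionAvailable()) {
    throw new Error('OS-level encryption is not available')
  }
  // Credentials are encrypted before they are ever stored.
  encrypted = safeStorage.encryptString(JSON.stringify(creds))
})

ipcMain.handle('volcengine:has-credentials', () => encrypted !== null)

ipcMain.handle('volcengine:clear-credentials', () => {
  encrypted = null
})
```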
187 changed files with 3084 additions and 7265 deletions

.gitignore vendored
View File

@@ -73,5 +73,3 @@ test-results
YOUR_MEMORY_FILE_PATH
.sessions/
.next/
*.tsbuildinfo

View File

@@ -11,7 +11,6 @@
"dist/**",
"out/**",
"local/**",
"tests/**",
".yarn/**",
".gitignore",
"scripts/cloudflare-worker.js",

View File

@@ -12,15 +12,7 @@ This file provides guidance to AI coding assistants when working with code in th
- **Always propose before executing**: Before making any changes, clearly explain your planned approach and wait for explicit user approval to ensure alignment and prevent unwanted modifications.
- **Lint, test, and format before completion**: Coding tasks are only complete after running `yarn lint`, `yarn test`, and `yarn format` successfully.
- **Write conventional commits**: Commit small, focused changes using Conventional Commit messages (e.g., `feat:`, `fix:`, `refactor:`, `docs:`).
## Pull Request Workflow (CRITICAL)
When creating a Pull Request, you MUST:
1. **Read the PR template first**: Always read `.github/pull_request_template.md` before creating the PR
2. **Follow ALL template sections**: Structure the `--body` parameter to include every section from the template
3. **Never skip sections**: Include all sections even if marking them as N/A or "None"
4. **Use proper formatting**: Match the template's markdown structure exactly (headings, checkboxes, code blocks)
- **Follow PR template**: When submitting pull requests, follow the template in `.github/pull_request_template.md` to ensure complete context and documentation.
## Development Commands

View File

@@ -134,108 +134,56 @@ artifactBuildCompleted: scripts/artifact-build-completed.js
releaseInfo:
releaseNotes: |
<!--LANG:en-->
A New Era of Intelligence with Cherry Studio 1.7.1
What's New in v1.7.0-rc.3
Today we're releasing Cherry Studio 1.7.1 — our most ambitious update yet, introducing Agent: autonomous AI that thinks, plans, and acts.
✨ New Features:
- Provider: Added Silicon provider support for Anthropic API compatibility
- Provider: AIHubMix support for nano banana
For years, AI assistants have been reactive — waiting for your commands, responding to your questions. With Agent, we're changing that. Now, AI can truly work alongside you: understanding complex goals, breaking them into steps, and executing them independently.
🐛 Bug Fixes:
- i18n: Clean up translation tags and untranslated strings
- Provider: Fixed Silicon provider code list
- Provider: Fixed Poe API reasoning parameters for GPT-5 and reasoning models
- Provider: Fixed duplicate /v1 in Anthropic API endpoints
- Provider: Fixed Azure provider handling in AI SDK integration
- Models: Added Claude Opus 4.5 pattern to THINKING_TOKEN_MAP
- Models: Improved Gemini reasoning and message handling
- Models: Fixed custom parameters for Gemini models
- Models: Fixed qwen-mt-flash text delta support
- Models: Fixed Groq verbosity setting
- UI: Fixed quota display and quota tips
- UI: Fixed web search button condition
- Settings: Fixed updateAssistantPreset reducer to properly update preset
- Settings: Respect enableMaxTokens setting when maxTokens is not configured
- SDK: Fixed header merging logic in AI SDK
This is what we've been building toward. And it's just the beginning.
🤖 Meet Agent
Imagine having a brilliant colleague who never sleeps. Give Agent a goal — write a report, analyze data, refactor code — and watch it work. It reasons through problems, breaks them into steps, calls the right tools, and adapts when things change.
- **Think → Plan → Act**: From goal to execution, fully autonomous
- **Deep Reasoning**: Multi-turn thinking that solves real problems
- **Tool Mastery**: File operations, web search, code execution, and more
- **Skill Plugins**: Extend with custom commands and capabilities
- **You Stay in Control**: Real-time approval for sensitive actions
- **Full Visibility**: Every thought, every decision, fully transparent
🌐 Expanding Ecosystem
- **New Providers**: HuggingFace, Mistral, CherryIN, AI Gateway, Intel OVMS, Didi MCP
- **New Models**: Claude 4.5 Haiku, DeepSeek v3.2, GLM-4.6, Doubao, Ling series
- **MCP Integration**: Alibaba Cloud, ModelScope, Higress, MCP.so, TokenFlux and more
📚 Smarter Knowledge Base
- **OpenMinerU**: Self-hosted document processing
- **Full-Text Search**: Find anything instantly across your notes
- **Enhanced Tool Selection**: Smarter configuration for better AI assistance
📝 Notes, Reimagined
- Full-text search with highlighted results
- AI-powered smart rename
- Export as image
- Auto-wrap for tables
🖼️ Image & OCR
- Intel OVMS painting capabilities
- Intel OpenVINO NPU-accelerated OCR
🌍 Now in 10+ Languages
- Added German support
- Enhanced internationalization
⚡ Faster & More Polished
- Electron 38 upgrade
- New MCP management interface
- Dozens of UI refinements
❤️ Fully Open Source
Commercial restrictions removed. Cherry Studio now follows standard AGPL v3 — free for teams of any size.
The Agent Era is here. We can't wait to see what you'll create.
⚡ Improvements:
- SDK: Upgraded @anthropic-ai/claude-agent-sdk to 0.1.53
<!--LANG:zh-CN-->
Cherry Studio 1.7.1:开启智能新纪元
v1.7.0-rc.3 更新内容
今天,我们正式发布 Cherry Studio 1.7.1 —— 迄今最具雄心的版本,带来全新的 Agent:能够自主思考、规划和行动的 AI。
✨ 新功能:
- 提供商:新增 Silicon 提供商对 Anthropic API 的兼容性支持
- 提供商:AIHubMix 支持 nano banana
多年来,AI 助手一直是被动的——等待你的指令,回应你的问题。Agent 改变了这一切。现在,AI 能够真正与你并肩工作:理解复杂目标,将其拆解为步骤,并独立执行。
🐛 问题修复:
- 国际化:清理翻译标签和未翻译字符串
- 提供商:修复 Silicon 提供商代码列表
- 提供商:修复 Poe API 对 GPT-5 和推理模型的推理参数
- 提供商:修复 Anthropic API 端点重复 /v1 问题
- 提供商:修复 Azure 提供商在 AI SDK 集成中的处理
- 模型:将 Claude Opus 4.5 添加到 THINKING_TOKEN_MAP
- 模型:改进 Gemini 推理和消息处理
- 模型:修复 Gemini 模型自定义参数
- 模型:修复 qwen-mt-flash text delta 支持
- 模型:修复 Groq verbosity 设置
- 界面:修复配额显示和配额提示
- 界面:修复 Web 搜索按钮条件
- 设置:修复 updateAssistantPreset reducer,使其正确更新 preset
- 设置:在未配置 maxTokens 时尊重 enableMaxTokens 设置
- SDK:修复 AI SDK 中 header 合并逻辑
这是我们一直在构建的未来。而这,仅仅是开始。
🤖 认识 Agent
想象一位永不疲倦的得力伙伴。给 Agent 一个目标——撰写报告、分析数据、重构代码——然后看它工作。它会推理问题、拆解步骤、调用工具,并在情况变化时灵活应对。
- **思考 → 规划 → 行动**:从目标到执行,全程自主
- **深度推理**:多轮思考,解决真实问题
- **工具大师**:文件操作、网络搜索、代码执行,样样精通
- **技能插件**:自定义命令,无限扩展
- **你掌控全局**:敏感操作,实时审批
- **完全透明**:每一步思考,每一个决策,清晰可见
🌐 生态持续壮大
- **新增服务商**:Hugging Face、Mistral、Perplexity、SophNet、AI Gateway、Cerebras AI
- **新增模型**:Gemini 3、Gemini 3 Pro(支持图像预览)、GPT-5.1、Claude Opus 4.5
- **MCP 集成**:百炼、魔搭、Higress、MCP.so、TokenFlux 等平台
📚 更智能的知识库
- **OpenMinerU**:本地自部署文档处理
- **全文搜索**:笔记内容一搜即达
- **增强工具选择**:更智能的配置,更好的 AI 协助
📝 笔记,焕然一新
- 全文搜索,结果高亮
- AI 智能重命名
- 导出为图片
- 表格自动换行
🖼️ 图像与 OCR
- Intel OVMS 绘图能力
- Intel OpenVINO NPU 加速 OCR
🌍 支持 10+ 种语言
- 新增德语支持
- 全面增强国际化
⚡ 更快、更精致
- 升级 Electron 38
- 新的 MCP 管理界面
- 数十处 UI 细节打磨
❤️ 完全开源
商用限制已移除。Cherry Studio 现遵循标准 AGPL v3 协议——任意规模团队均可自由使用。
Agent 纪元已至。期待你的创造。
⚡ 改进:
- SDK:升级 @anthropic-ai/claude-agent-sdk 到 0.1.53
<!--LANG:END-->

View File

@@ -25,10 +25,7 @@ export default defineConfig({
'@shared': resolve('packages/shared'),
'@logger': resolve('src/main/services/LoggerService'),
'@mcp-trace/trace-core': resolve('packages/mcp-trace/trace-core'),
'@mcp-trace/trace-node': resolve('packages/mcp-trace/trace-node'),
'@cherrystudio/ai-core/provider': resolve('packages/aiCore/src/core/providers'),
'@cherrystudio/ai-core': resolve('packages/aiCore/src'),
'@cherrystudio/ai-sdk-provider': resolve('packages/ai-sdk-provider/src')
'@mcp-trace/trace-node': resolve('packages/mcp-trace/trace-node')
}
},
build: {

View File

@@ -58,7 +58,6 @@ export default defineConfig([
'dist/**',
'out/**',
'local/**',
'tests/**',
'.yarn/**',
'.gitignore',
'scripts/cloudflare-worker.js',

View File

@@ -1,6 +1,6 @@
{
"name": "CherryStudio",
"version": "1.7.1",
"version": "1.7.0-rc.3",
"private": true,
"description": "A powerful AI assistant for producer.",
"main": "./out/main/index.js",
@@ -62,7 +62,6 @@
"test": "vitest run --silent",
"test:main": "vitest run --project main",
"test:renderer": "vitest run --project renderer",
"test:aicore": "vitest run --project aiCore",
"test:update": "yarn test:renderer --update",
"test:coverage": "vitest run --coverage --silent",
"test:ui": "vitest --ui",
@@ -165,7 +164,7 @@
"@modelcontextprotocol/sdk": "^1.17.5",
"@mozilla/readability": "^0.6.0",
"@notionhq/client": "^2.2.15",
"@openrouter/ai-sdk-provider": "^1.2.8",
"@openrouter/ai-sdk-provider": "^1.2.5",
"@opentelemetry/api": "^1.9.0",
"@opentelemetry/core": "2.0.0",
"@opentelemetry/exporter-trace-otlp-http": "^0.200.0",
@@ -173,7 +172,7 @@
"@opentelemetry/sdk-trace-node": "^2.0.0",
"@opentelemetry/sdk-trace-web": "^2.0.0",
"@opeoginni/github-copilot-openai-compatible": "^0.1.21",
"@playwright/test": "^1.55.1",
"@playwright/test": "^1.52.0",
"@radix-ui/react-context-menu": "^2.2.16",
"@reduxjs/toolkit": "^2.2.5",
"@shikijs/markdown-it": "^3.12.0",
@@ -322,6 +321,7 @@
"p-queue": "^8.1.0",
"pdf-lib": "^1.17.1",
"pdf-parse": "^1.1.1",
"playwright": "^1.55.1",
"proxy-agent": "^6.5.0",
"react": "^19.2.0",
"react-dom": "^19.2.0",

View File

@@ -69,7 +69,6 @@ export interface CherryInProviderSettings {
headers?: HeadersInput
/**
* Optional endpoint type to distinguish different endpoint behaviors.
* "image-generation" is also openai endpoint, but specifically for image generation.
*/
endpointType?: 'openai' | 'openai-response' | 'anthropic' | 'gemini' | 'image-generation' | 'jina-rerank'
}

View File

@@ -3,13 +3,12 @@
* Provides realistic mock responses for all provider types
*/
import type { ModelMessage, Tool } from 'ai'
import { jsonSchema } from 'ai'
import { jsonSchema, type ModelMessage, type Tool } from 'ai'
/**
* Standard test messages for all scenarios
*/
export const testMessages: Record<string, ModelMessage[]> = {
export const testMessages = {
simple: [{ role: 'user' as const, content: 'Hello, how are you?' }],
conversation: [
@@ -46,7 +45,7 @@ export const testMessages: Record<string, ModelMessage[]> = {
{ role: 'assistant' as const, content: '15 * 23 = 345' },
{ role: 'user' as const, content: 'Now divide that by 5' }
]
}
} satisfies Record<string, ModelMessage[]>
/**
* Standard test tools for tool calling scenarios
@@ -139,17 +138,68 @@ export const testTools: Record<string, Tool> = {
}
}
/**
* Mock streaming chunks for different providers
*/
export const mockStreamingChunks = {
text: [
{ type: 'text-delta' as const, textDelta: 'Hello' },
{ type: 'text-delta' as const, textDelta: ', ' },
{ type: 'text-delta' as const, textDelta: 'this ' },
{ type: 'text-delta' as const, textDelta: 'is ' },
{ type: 'text-delta' as const, textDelta: 'a ' },
{ type: 'text-delta' as const, textDelta: 'test.' }
],
withToolCall: [
{ type: 'text-delta' as const, textDelta: 'Let me check the weather for you.' },
{
type: 'tool-call-delta' as const,
toolCallType: 'function' as const,
toolCallId: 'call_123',
toolName: 'getWeather',
argsTextDelta: '{"location":'
},
{
type: 'tool-call-delta' as const,
toolCallType: 'function' as const,
toolCallId: 'call_123',
toolName: 'getWeather',
argsTextDelta: ' "San Francisco, CA"}'
},
{
type: 'tool-call' as const,
toolCallType: 'function' as const,
toolCallId: 'call_123',
toolName: 'getWeather',
args: { location: 'San Francisco, CA' }
}
],
withFinish: [
{ type: 'text-delta' as const, textDelta: 'Complete response.' },
{
type: 'finish' as const,
finishReason: 'stop' as const,
usage: {
promptTokens: 10,
completionTokens: 5,
totalTokens: 15
}
}
]
}
/**
* Mock complete responses for non-streaming scenarios
* Note: AI SDK v5 uses inputTokens/outputTokens instead of promptTokens/completionTokens
*/
export const mockCompleteResponses = {
simple: {
text: 'This is a simple response.',
finishReason: 'stop' as const,
usage: {
inputTokens: 15,
outputTokens: 8,
promptTokens: 15,
completionTokens: 8,
totalTokens: 23
}
},
@@ -165,8 +215,8 @@ export const mockCompleteResponses = {
],
finishReason: 'tool-calls' as const,
usage: {
inputTokens: 25,
outputTokens: 12,
promptTokens: 25,
completionTokens: 12,
totalTokens: 37
}
},
@@ -175,15 +225,14 @@ export const mockCompleteResponses = {
text: 'Response with warnings.',
finishReason: 'stop' as const,
usage: {
inputTokens: 10,
outputTokens: 5,
promptTokens: 10,
completionTokens: 5,
totalTokens: 15
},
warnings: [
{
type: 'unsupported-setting' as const,
setting: 'temperature',
details: 'Temperature parameter not supported for this model'
message: 'Temperature parameter not supported for this model'
}
]
}
@@ -236,3 +285,47 @@ export const mockImageResponses = {
warnings: []
}
}
/**
* Mock error responses
*/
export const mockErrors = {
invalidApiKey: {
name: 'APIError',
message: 'Invalid API key provided',
statusCode: 401
},
rateLimitExceeded: {
name: 'RateLimitError',
message: 'Rate limit exceeded. Please try again later.',
statusCode: 429,
headers: {
'retry-after': '60'
}
},
modelNotFound: {
name: 'ModelNotFoundError',
message: 'The requested model was not found',
statusCode: 404
},
contextLengthExceeded: {
name: 'ContextLengthError',
message: "This model's maximum context length is 4096 tokens",
statusCode: 400
},
timeout: {
name: 'TimeoutError',
message: 'Request timed out after 30000ms',
code: 'ETIMEDOUT'
},
networkError: {
name: 'NetworkError',
message: 'Network connection failed',
code: 'ECONNREFUSED'
}
}
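Aside on the first hunk of this file: one side types testMessages with an explicit Record annotation, the other with a satisfies clause. The practical difference, in a standalone snippet (toy values, not project code):

```typescript
import type { ModelMessage } from 'ai'

// Annotation widens the key set: any string key is considered valid.
const annotated: Record<string, ModelMessage[]> = { simple: [] }
// annotated.doesNotExist  -> type-checks, even though the key is absent

// `satisfies` checks the same shape but keeps the literal keys.
const checked = { simple: [] as ModelMessage[] } satisfies Record<string, ModelMessage[]>
// checked.simple        -> OK
// checked.doesNotExist  -> compile-time error
```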

View File

@@ -1,35 +0,0 @@
/**
* Mock for @cherrystudio/ai-sdk-provider
* This mock is used in tests to avoid importing the actual package
*/
export type CherryInProviderSettings = {
apiKey?: string
baseURL?: string
}
// oxlint-disable-next-line no-unused-vars
export const createCherryIn = (_options?: CherryInProviderSettings) => ({
// oxlint-disable-next-line no-unused-vars
languageModel: (_modelId: string) => ({
specificationVersion: 'v1',
provider: 'cherryin',
modelId: 'mock-model',
doGenerate: async () => ({ text: 'mock response' }),
doStream: async () => ({ stream: (async function* () {})() })
}),
// oxlint-disable-next-line no-unused-vars
chat: (_modelId: string) => ({
specificationVersion: 'v1',
provider: 'cherryin-chat',
modelId: 'mock-model',
doGenerate: async () => ({ text: 'mock response' }),
doStream: async () => ({ stream: (async function* () {})() })
}),
// oxlint-disable-next-line no-unused-vars
textEmbeddingModel: (_modelId: string) => ({
specificationVersion: 'v1',
provider: 'cherryin',
modelId: 'mock-embedding-model'
})
})

View File

@@ -1,9 +0,0 @@
/**
* Vitest Setup File
* Global test configuration and mocks for @cherrystudio/ai-core package
*/
// Mock Vite SSR helper to avoid Node environment errors
;(globalThis as any).__vite_ssr_exportName__ = (_name: string, value: any) => value
// Note: @cherrystudio/ai-sdk-provider is mocked via alias in vitest.config.ts

View File

@@ -1,109 +0,0 @@
import { describe, expect, it } from 'vitest'
import { createOpenAIOptions, createOpenRouterOptions, mergeProviderOptions } from '../factory'
describe('mergeProviderOptions', () => {
it('deep merges provider options for the same provider', () => {
const reasoningOptions = createOpenRouterOptions({
reasoning: {
enabled: true,
effort: 'medium'
}
})
const webSearchOptions = createOpenRouterOptions({
plugins: [{ id: 'web', max_results: 5 }]
})
const merged = mergeProviderOptions(reasoningOptions, webSearchOptions)
expect(merged.openrouter).toEqual({
reasoning: {
enabled: true,
effort: 'medium'
},
plugins: [{ id: 'web', max_results: 5 }]
})
})
it('preserves options from other providers while merging', () => {
const openRouter = createOpenRouterOptions({
reasoning: { enabled: true }
})
const openAI = createOpenAIOptions({
reasoningEffort: 'low'
})
const merged = mergeProviderOptions(openRouter, openAI)
expect(merged.openrouter).toEqual({ reasoning: { enabled: true } })
expect(merged.openai).toEqual({ reasoningEffort: 'low' })
})
it('overwrites primitive values with later values', () => {
const first = createOpenAIOptions({
reasoningEffort: 'low',
user: 'user-123'
})
const second = createOpenAIOptions({
reasoningEffort: 'high',
maxToolCalls: 5
})
const merged = mergeProviderOptions(first, second)
expect(merged.openai).toEqual({
reasoningEffort: 'high', // overwritten by second
user: 'user-123', // preserved from first
maxToolCalls: 5 // added from second
})
})
it('overwrites arrays with later values instead of merging', () => {
const first = createOpenRouterOptions({
models: ['gpt-4', 'gpt-3.5-turbo']
})
const second = createOpenRouterOptions({
models: ['claude-3-opus', 'claude-3-sonnet']
})
const merged = mergeProviderOptions(first, second)
// Array is completely replaced, not merged
expect(merged.openrouter?.models).toEqual(['claude-3-opus', 'claude-3-sonnet'])
})
it('deeply merges nested objects while overwriting primitives', () => {
const first = createOpenRouterOptions({
reasoning: {
enabled: true,
effort: 'low'
},
user: 'user-123'
})
const second = createOpenRouterOptions({
reasoning: {
effort: 'high',
max_tokens: 500
},
user: 'user-456'
})
const merged = mergeProviderOptions(first, second)
expect(merged.openrouter).toEqual({
reasoning: {
enabled: true, // preserved from first
effort: 'high', // overwritten by second
max_tokens: 500 // added from second
},
user: 'user-456' // overwritten by second
})
})
it('replaces arrays instead of merging them', () => {
const first = createOpenRouterOptions({ plugins: [{ id: 'old' }] })
const second = createOpenRouterOptions({ plugins: [{ id: 'new' }] })
const merged = mergeProviderOptions(first, second)
// @ts-expect-error type-check for openrouter options is skipped. see function signature of createOpenRouterOptions
expect(merged.openrouter?.plugins).toEqual([{ id: 'new' }])
})
})

View File

@@ -26,65 +26,13 @@ export function createGenericProviderOptions<T extends string>(
return { [provider]: options } as Record<T, Record<string, any>>
}
type PlainObject = Record<string, any>
const isPlainObject = (value: unknown): value is PlainObject => {
return typeof value === 'object' && value !== null && !Array.isArray(value)
}
function deepMergeObjects<T extends PlainObject>(target: T, source: PlainObject): T {
const result: PlainObject = { ...target }
Object.entries(source).forEach(([key, value]) => {
if (isPlainObject(value) && isPlainObject(result[key])) {
result[key] = deepMergeObjects(result[key], value)
} else {
result[key] = value
}
})
return result as T
}
/**
* Deep-merge multiple provider-specific options.
* Nested objects are recursively merged; primitive values are overwritten.
*
* When the same key appears in multiple options:
* - If both values are plain objects: they are deeply merged (recursive merge)
* - If values are primitives/arrays: the later value overwrites the earlier one
*
* @example
* mergeProviderOptions(
* { openrouter: { reasoning: { enabled: true, effort: 'low' }, user: 'user-123' } },
* { openrouter: { reasoning: { effort: 'high', max_tokens: 500 }, models: ['gpt-4'] } }
* )
* // Result: {
* // openrouter: {
* // reasoning: { enabled: true, effort: 'high', max_tokens: 500 },
* // user: 'user-123',
* // models: ['gpt-4']
* // }
* // }
*
* @param optionsMap Objects containing options for multiple providers
* @returns Fully merged TypedProviderOptions
* 合并多个供应商的options
* @param optionsMap 包含多个供应商选项的对象
* @returns 合并后的TypedProviderOptions
*/
export function mergeProviderOptions(...optionsMap: Partial<TypedProviderOptions>[]): TypedProviderOptions {
return optionsMap.reduce<TypedProviderOptions>((acc, options) => {
if (!options) {
return acc
}
Object.entries(options).forEach(([providerId, providerOptions]) => {
if (!providerOptions) {
return
}
if (acc[providerId]) {
acc[providerId] = deepMergeObjects(acc[providerId] as PlainObject, providerOptions as PlainObject)
} else {
acc[providerId] = providerOptions as any
}
})
return acc
}, {} as TypedProviderOptions)
return Object.assign({}, ...optionsMap)
}
/**
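For contrast, what the two sides of this hunk do when the same provider key appears in both inputs, using toy data taken from the doc comment's own example:

```typescript
const first = { openrouter: { reasoning: { enabled: true }, user: 'user-123' } }
const second = { openrouter: { reasoning: { effort: 'high' } } }

// Shallow merge: the later `openrouter` object replaces the earlier one
// wholesale, dropping both `user` and `reasoning.enabled`.
const shallow = Object.assign({}, first, second)
// => { openrouter: { reasoning: { effort: 'high' } } }

// Deep merge (the deepMergeObjects path above): nested objects are combined.
// => { openrouter: { reasoning: { enabled: true, effort: 'high' }, user: 'user-123' } }
```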

View File

@@ -19,20 +19,15 @@ describe('Provider Schemas', () => {
expect(Array.isArray(baseProviders)).toBe(true)
expect(baseProviders.length).toBeGreaterThan(0)
// These are the actual base providers defined in schemas.ts
const expectedIds = [
'openai',
'openai-chat',
'openai-responses',
'openai-compatible',
'anthropic',
'google',
'xai',
'azure',
'azure-responses',
'deepseek',
'openrouter',
'cherryin',
'cherryin-chat'
'deepseek'
]
const actualIds = baseProviders.map((p) => p.id)
expectedIds.forEach((id) => {

View File

@@ -232,13 +232,11 @@ describe('RuntimeExecutor.generateImage', () => {
expect(pluginCallOrder).toEqual(['onRequestStart', 'transformParams', 'transformResult', 'onRequestEnd'])
// transformParams receives params without model (model is handled separately)
// and context with core fields + dynamic fields (requestId, startTime, etc.)
expect(testPlugin.transformParams).toHaveBeenCalledWith(
expect.objectContaining({ prompt: 'A test image' }),
{ prompt: 'A test image' },
expect.objectContaining({
providerId: 'openai',
model: 'dall-e-3'
modelId: 'dall-e-3'
})
)
@@ -275,12 +273,11 @@ describe('RuntimeExecutor.generateImage', () => {
await executorWithPlugin.generateImage({ model: 'dall-e-3', prompt: 'A test image' })
// resolveModel receives model id and context with core fields
expect(modelResolutionPlugin.resolveModel).toHaveBeenCalledWith(
'dall-e-3',
expect.objectContaining({
providerId: 'openai',
model: 'dall-e-3'
modelId: 'dall-e-3'
})
)
@@ -342,11 +339,12 @@ describe('RuntimeExecutor.generateImage', () => {
.generateImage({ model: 'invalid-model', prompt: 'A test image' })
.catch((error) => error)
// Error is thrown from pluginEngine directly as ImageModelResolutionError
expect(thrownError).toBeInstanceOf(ImageModelResolutionError)
expect(thrownError.message).toContain('Failed to resolve image model: invalid-model')
expect(thrownError).toBeInstanceOf(ImageGenerationError)
expect(thrownError.message).toContain('Failed to generate image:')
expect(thrownError.providerId).toBe('openai')
expect(thrownError.modelId).toBe('invalid-model')
expect(thrownError.cause).toBeInstanceOf(ImageModelResolutionError)
expect(thrownError.cause.message).toContain('Failed to resolve image model: invalid-model')
})
it('should handle ImageModelResolutionError without provider', async () => {
@@ -364,9 +362,8 @@ describe('RuntimeExecutor.generateImage', () => {
const apiError = new Error('API request failed')
vi.mocked(aiGenerateImage).mockRejectedValue(apiError)
// Error propagates directly from pluginEngine without wrapping
await expect(executor.generateImage({ model: 'dall-e-3', prompt: 'A test image' })).rejects.toThrow(
'API request failed'
'Failed to generate image:'
)
})
@@ -379,9 +376,8 @@ describe('RuntimeExecutor.generateImage', () => {
vi.mocked(aiGenerateImage).mockRejectedValue(noImageError)
vi.mocked(NoImageGeneratedError.isInstance).mockReturnValue(true)
// Error propagates directly from pluginEngine
await expect(executor.generateImage({ model: 'dall-e-3', prompt: 'A test image' })).rejects.toThrow(
'No image generated'
'Failed to generate image:'
)
})
@@ -402,17 +398,15 @@ describe('RuntimeExecutor.generateImage', () => {
[errorPlugin]
)
// Error propagates directly from pluginEngine
await expect(executorWithPlugin.generateImage({ model: 'dall-e-3', prompt: 'A test image' })).rejects.toThrow(
'Generation failed'
'Failed to generate image:'
)
// onError receives the original error and context with core fields
expect(errorPlugin.onError).toHaveBeenCalledWith(
error,
expect.objectContaining({
providerId: 'openai',
model: 'dall-e-3'
modelId: 'dall-e-3'
})
)
})
@@ -425,10 +419,9 @@ describe('RuntimeExecutor.generateImage', () => {
const abortController = new AbortController()
setTimeout(() => abortController.abort(), 10)
// Error propagates directly from pluginEngine
await expect(
executor.generateImage({ model: 'dall-e-3', prompt: 'A test image', abortSignal: abortController.signal })
).rejects.toThrow('Operation was aborted')
).rejects.toThrow('Failed to generate image:')
})
})

View File

@@ -17,14 +17,10 @@ import type { AiPlugin } from '../../plugins'
import { globalRegistryManagement } from '../../providers/RegistryManagement'
import { RuntimeExecutor } from '../executor'
// Mock AI SDK - use importOriginal to keep jsonSchema and other non-mocked exports
vi.mock('ai', async (importOriginal) => {
const actual = (await importOriginal()) as Record<string, unknown>
return {
...actual,
generateText: vi.fn()
}
})
// Mock AI SDK
vi.mock('ai', () => ({
generateText: vi.fn()
}))
vi.mock('../../providers/RegistryManagement', () => ({
globalRegistryManagement: {
@@ -413,12 +409,11 @@ describe('RuntimeExecutor.generateText', () => {
})
).rejects.toThrow('Generation failed')
// onError receives the original error and context with core fields
expect(errorPlugin.onError).toHaveBeenCalledWith(
error,
expect.objectContaining({
providerId: 'openai',
model: 'gpt-4'
modelId: 'gpt-4'
})
)
})

View File

@@ -11,14 +11,10 @@ import type { AiPlugin } from '../../plugins'
import { globalRegistryManagement } from '../../providers/RegistryManagement'
import { RuntimeExecutor } from '../executor'
// Mock AI SDK - use importOriginal to keep jsonSchema and other non-mocked exports
vi.mock('ai', async (importOriginal) => {
const actual = (await importOriginal()) as Record<string, unknown>
return {
...actual,
streamText: vi.fn()
}
})
// Mock AI SDK
vi.mock('ai', () => ({
streamText: vi.fn()
}))
vi.mock('../../providers/RegistryManagement', () => ({
globalRegistryManagement: {
@@ -157,7 +153,7 @@ describe('RuntimeExecutor.streamText', () => {
describe('Max Tokens Parameter', () => {
const maxTokensValues = [10, 50, 100, 500, 1000, 2000, 4000]
it.each(maxTokensValues)('should support maxOutputTokens=%s', async (maxOutputTokens) => {
it.each(maxTokensValues)('should support maxTokens=%s', async (maxTokens) => {
const mockStream = {
textStream: (async function* () {
yield 'Response'
@@ -172,13 +168,12 @@ describe('RuntimeExecutor.streamText', () => {
await executor.streamText({
model: 'gpt-4',
messages: testMessages.simple,
maxOutputTokens
maxOutputTokens: maxTokens
})
// Parameters are passed through without transformation
expect(streamText).toHaveBeenCalledWith(
expect.objectContaining({
maxOutputTokens
maxTokens
})
)
})
@@ -518,12 +513,11 @@ describe('RuntimeExecutor.streamText', () => {
})
).rejects.toThrow('Stream error')
// onError receives the original error and context with core fields
expect(errorPlugin.onError).toHaveBeenCalledWith(
error,
expect.objectContaining({
providerId: 'openai',
model: 'gpt-4'
modelId: 'gpt-4'
})
)
})

View File

@@ -1,20 +1,12 @@
import path from 'node:path'
import { fileURLToPath } from 'node:url'
import { defineConfig } from 'vitest/config'
const __dirname = path.dirname(fileURLToPath(import.meta.url))
export default defineConfig({
test: {
globals: true,
setupFiles: [path.resolve(__dirname, './src/__tests__/setup.ts')]
globals: true
},
resolve: {
alias: {
'@': path.resolve(__dirname, './src'),
// Mock external packages that may not be available in test environment
'@cherrystudio/ai-sdk-provider': path.resolve(__dirname, './src/__tests__/mocks/ai-sdk-provider.ts')
'@': './src'
}
},
esbuild: {

View File

@@ -374,5 +374,13 @@ export enum IpcChannel {
WebSocket_Stop = 'webSocket:stop',
WebSocket_Status = 'webSocket:status',
WebSocket_SendFile = 'webSocket:send-file',
WebSocket_GetAllCandidates = 'webSocket:get-all-candidates'
WebSocket_GetAllCandidates = 'webSocket:get-all-candidates',
// Volcengine
Volcengine_SaveCredentials = 'volcengine:save-credentials',
Volcengine_HasCredentials = 'volcengine:has-credentials',
Volcengine_ClearCredentials = 'volcengine:clear-credentials',
Volcengine_ListModels = 'volcengine:list-models',
Volcengine_GetAuthHeaders = 'volcengine:get-auth-headers',
Volcengine_MakeRequest = 'volcengine:make-request'
}
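A hedged sketch of how a renderer might drive the channels added above; real code would go through a preload bridge rather than ipcRenderer directly, and the payload and result shapes are assumptions:

```typescript
import { ipcRenderer } from 'electron'

async function listVolcengineModels(region: string): Promise<unknown> {
  const hasCreds: boolean = await ipcRenderer.invoke('volcengine:has-credentials')
  if (!hasCreds) {
    throw new Error('Volcengine credentials are not configured')
  }
  // Region is passed explicitly; per the first commit in this compare,
  // it must reach the underlying listFoundationModels/listEndpoints calls.
  return ipcRenderer.invoke('volcengine:list-models', { region })
}
```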

View File

@@ -9,27 +9,13 @@
*/
import Anthropic from '@anthropic-ai/sdk'
import type { MessageCreateParams, TextBlockParam, Tool as AnthropicTool } from '@anthropic-ai/sdk/resources'
import type { TextBlockParam } from '@anthropic-ai/sdk/resources'
import { loggerService } from '@logger'
import { type Provider, SystemProviderIds } from '@types'
import type { Provider } from '@types'
import type { ModelMessage } from 'ai'
const logger = loggerService.withContext('anthropic-sdk')
/**
* Context for Anthropic SDK client creation.
* This allows the shared module to be used in different environments
* by providing environment-specific implementations.
*/
export interface AnthropicSdkContext {
/**
* Custom fetch function to use for HTTP requests.
* In Electron main process, this should be `net.fetch`.
* In other environments, can use the default fetch or a custom implementation.
*/
fetch?: typeof globalThis.fetch
}
const defaultClaudeCodeSystemPrompt = `You are Claude Code, Anthropic's official CLI for Claude.`
const defaultClaudeCodeSystem: Array<TextBlockParam> = [
@@ -72,11 +58,8 @@ const defaultClaudeCodeSystem: Array<TextBlockParam> = [
export function getSdkClient(
provider: Provider,
oauthToken?: string | null,
extraHeaders?: Record<string, string | string[]>,
context?: AnthropicSdkContext
extraHeaders?: Record<string, string | string[]>
): Anthropic {
const customFetch = context?.fetch
if (provider.authType === 'oauth') {
if (!oauthToken) {
throw new Error('OAuth token is not available')
@@ -102,8 +85,7 @@ export function getSdkClient(
'x-stainless-runtime': 'node',
'x-stainless-runtime-version': 'v22.18.0',
...extraHeaders
},
fetch: customFetch
}
})
}
let baseURL =
@@ -124,12 +106,11 @@ export function getSdkClient(
baseURL,
dangerouslyAllowBrowser: true,
defaultHeaders: {
'anthropic-beta': 'interleaved-thinking-2025-05-14',
'anthropic-beta': 'output-128k-2025-02-19',
'APP-Code': 'MLTG2087',
...provider.extra_headers,
...extraHeaders
},
fetch: customFetch
}
})
}
@@ -139,11 +120,9 @@ export function getSdkClient(
baseURL,
dangerouslyAllowBrowser: true,
defaultHeaders: {
'anthropic-beta': 'interleaved-thinking-2025-05-14',
Authorization: provider.id === SystemProviderIds.longcat ? `Bearer ${provider.apiKey}` : undefined,
'anthropic-beta': 'output-128k-2025-02-19',
...provider.extra_headers
},
fetch: customFetch
}
})
}
@@ -194,31 +173,3 @@ export function buildClaudeCodeSystemModelMessage(system?: string | Array<TextBl
content: block.text
}))
}
/**
* Sanitize tool definitions for Anthropic API.
*
* Removes non-standard fields like `input_examples` from tool definitions
* that Anthropic's API doesn't support. This prevents validation errors when
* tools with extended fields are passed to the Anthropic SDK.
*
* @param tools - Array of tool definitions from MessageCreateParams
* @returns Sanitized tools array with non-standard fields removed
*
* @example
* ```typescript
* const sanitizedTools = sanitizeToolsForAnthropic(request.tools)
* ```
*/
export function sanitizeToolsForAnthropic(tools?: MessageCreateParams['tools']): MessageCreateParams['tools'] {
if (!tools || tools.length === 0) return tools
return tools.map((tool) => {
if ('type' in tool && tool.type !== 'custom') return tool
// oxlint-disable-next-line no-unused-vars
const { input_examples, ...sanitizedTool } = tool as AnthropicTool & { input_examples?: unknown }
return sanitizedTool as typeof tool
})
}
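One side of this diff threads an AnthropicSdkContext through getSdkClient so the shared module can supply an environment-specific fetch. A usage sketch for the Electron main process, with the import path and provider value assumed:

```typescript
import { net } from 'electron'
import type { Provider } from '@types'
import { type AnthropicSdkContext, getSdkClient } from './anthropic-sdk' // path assumed

declare const provider: Provider // some configured Anthropic-type provider

// Per the AnthropicSdkContext docs above, the main process should inject
// net.fetch; other environments can omit the context and use global fetch.
const context: AnthropicSdkContext = { fetch: net.fetch as unknown as typeof globalThis.fetch }
const client = getSdkClient(provider, null, undefined, context)
```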

View File

@@ -1,245 +0,0 @@
/**
* Shared API Utilities
*
* Common utilities for API URL formatting and validation.
* Used by both main process (API Server) and renderer.
*/
import type { MinimalProvider } from '@shared/provider'
import { trim } from 'lodash'
// Supported endpoints for routing
export const SUPPORTED_IMAGE_ENDPOINT_LIST = ['images/generations', 'images/edits', 'predict'] as const
export const SUPPORTED_ENDPOINT_LIST = [
'chat/completions',
'responses',
'messages',
'generateContent',
'streamGenerateContent',
...SUPPORTED_IMAGE_ENDPOINT_LIST
] as const
/**
* Removes the trailing slash from a URL string if it exists.
*/
export function withoutTrailingSlash<T extends string>(url: T): T {
return url.replace(/\/$/, '') as T
}
/**
* Matches a version segment in a path that starts with `/v<number>` and optionally
* continues with `alpha` or `beta`. The segment may be followed by `/` or the end
* of the string (useful for cases like `/v3alpha/resources`).
*/
const VERSION_REGEX_PATTERN = '\\/v\\d+(?:alpha|beta)?(?=\\/|$)'
/**
* Matches an API version at the end of a URL (with optional trailing slash).
* Used to detect and extract versions only from the trailing position.
*/
const TRAILING_VERSION_REGEX = /\/v\d+(?:alpha|beta)?\/?$/i
/**
* 判断 host 的 path 中是否包含形如版本的字符串(例如 /v1、/v2beta 等),
*
* @param host - 要检查的 host 或 path 字符串
* @returns 如果 path 中包含版本字符串则返回 true否则 false
*/
export function hasAPIVersion(host?: string): boolean {
if (!host) return false
const regex = new RegExp(VERSION_REGEX_PATTERN, 'i')
try {
const url = new URL(host)
return regex.test(url.pathname)
} catch {
// 若无法作为完整 URL 解析,则当作路径直接检测
return regex.test(host)
}
}
/**
* 格式化 Azure OpenAI 的 API 主机地址。
*/
export function formatAzureOpenAIApiHost(host: string): string {
const normalizedHost = withoutTrailingSlash(host)
?.replace(/\/v1$/, '')
.replace(/\/openai$/, '')
// NOTE: AISDK会添加上`v1`
return formatApiHost(normalizedHost + '/openai', false)
}
export function formatVertexApiHost(
provider: MinimalProvider,
project: string = 'test-project',
location: string = 'us-central1'
): string {
const { apiHost } = provider
const trimmedHost = withoutTrailingSlash(trim(apiHost))
if (!trimmedHost || trimmedHost.endsWith('aiplatform.googleapis.com')) {
const host =
location === 'global' ? 'https://aiplatform.googleapis.com' : `https://${location}-aiplatform.googleapis.com`
return `${formatApiHost(host)}/projects/${project}/locations/${location}`
}
return formatApiHost(trimmedHost)
}
/**
* Formats an API host URL by normalizing it and optionally appending an API version.
*
* @param host - The API host URL to format. Leading/trailing whitespace will be trimmed and trailing slashes removed.
* @param supportApiVersion - Whether the API version is supported. Defaults to `true`.
* @param apiVersion - The API version to append if needed. Defaults to `'v1'`.
*
* @returns The formatted API host URL. If the host is empty after normalization, returns an empty string.
* If the host ends with '#', API version is not supported, or the host already contains a version, returns the normalized host as-is.
* Otherwise, returns the host with the API version appended.
*
* @example
* formatApiHost('https://api.example.com/') // Returns 'https://api.example.com/v1'
* formatApiHost('https://api.example.com#') // Returns 'https://api.example.com#'
* formatApiHost('https://api.example.com/v2', true, 'v1') // Returns 'https://api.example.com/v2'
*/
export function formatApiHost(host?: string, supportApiVersion: boolean = true, apiVersion: string = 'v1'): string {
const normalizedHost = withoutTrailingSlash(trim(host))
if (!normalizedHost) {
return ''
}
if (normalizedHost.endsWith('#') || !supportApiVersion || hasAPIVersion(normalizedHost)) {
return normalizedHost
}
return `${normalizedHost}/${apiVersion}`
}
/**
* Converts an API host URL into separate base URL and endpoint components.
*
* This function extracts endpoint information from a composite API host string.
* If the host ends with '#', it attempts to match the preceding part against the supported endpoint list.
*
* @param apiHost - The API host string to parse
* @returns An object containing:
* - `baseURL`: The base URL without the endpoint suffix
* - `endpoint`: The matched endpoint identifier, or empty string if no match found
*
* @example
* routeToEndpoint('https://api.example.com/openai/chat/completions#')
* // Returns: { baseURL: 'https://api.example.com/v1', endpoint: 'chat/completions' }
*
* @example
* routeToEndpoint('https://api.example.com/v1')
* // Returns: { baseURL: 'https://api.example.com/v1', endpoint: '' }
*/
export function routeToEndpoint(apiHost: string): { baseURL: string; endpoint: string } {
const trimmedHost = (apiHost || '').trim()
if (!trimmedHost.endsWith('#')) {
return { baseURL: trimmedHost, endpoint: '' }
}
// Remove trailing #
const host = trimmedHost.slice(0, -1)
const endpointMatch = SUPPORTED_ENDPOINT_LIST.find((endpoint) => host.endsWith(endpoint))
if (!endpointMatch) {
const baseURL = withoutTrailingSlash(host)
return { baseURL, endpoint: '' }
}
const baseSegment = host.slice(0, host.length - endpointMatch.length)
const baseURL = withoutTrailingSlash(baseSegment).replace(/:$/, '') // Remove trailing colon (gemini special case)
return { baseURL, endpoint: endpointMatch }
}
/**
* Gets the AI SDK compatible base URL from a provider's apiHost.
*
* AI SDK expects baseURL WITH version suffix (e.g., /v1).
* This function:
* 1. Handles '#' endpoint routing format
* 2. Ensures the URL has a version suffix (adds /v1 if missing)
*
* @param apiHost - The provider's apiHost value (may or may not have /v1)
* @param apiVersion - The API version to use if missing. Defaults to 'v1'.
* @returns The baseURL suitable for AI SDK (with version suffix)
*
* @example
* getAiSdkBaseUrl('https://api.openai.com') // 'https://api.openai.com/v1'
* getAiSdkBaseUrl('https://api.openai.com/v1') // 'https://api.openai.com/v1'
* getAiSdkBaseUrl('https://api.example.com/chat/completions#') // 'https://api.example.com'
*/
export function getAiSdkBaseUrl(apiHost: string, apiVersion: string = 'v1'): string {
// First handle '#' endpoint routing format
const { baseURL } = routeToEndpoint(apiHost)
// If already has version, return as-is
if (hasAPIVersion(baseURL)) {
return withoutTrailingSlash(baseURL)
}
// Add version suffix
return `${withoutTrailingSlash(baseURL)}/${apiVersion}`
}
/**
* Validates an API host address.
*
* @param apiHost - The API host address to validate
* @returns true if valid URL with http/https protocol, false otherwise
*/
export function validateApiHost(apiHost: string): boolean {
if (!apiHost || !apiHost.trim()) {
return true // Allow empty
}
try {
const url = new URL(apiHost.trim())
return url.protocol === 'http:' || url.protocol === 'https:'
} catch {
return false
}
}
/**
* Extracts the trailing API version segment from a URL path.
*
* This function extracts API version patterns (e.g., `v1`, `v2beta`) from the end of a URL.
* Only versions at the end of the path are extracted, not versions in the middle.
* The returned version string does not include leading or trailing slashes.
*
* @param {string} url - The URL string to parse.
* @returns {string | undefined} The trailing API version found (e.g., 'v1', 'v2beta'), or undefined if none found.
*
* @example
* getTrailingApiVersion('https://api.example.com/v1') // 'v1'
* getTrailingApiVersion('https://api.example.com/v2beta/') // 'v2beta'
* getTrailingApiVersion('https://api.example.com/v1/chat') // undefined (version not at end)
* getTrailingApiVersion('https://gateway.ai.cloudflare.com/v1/xxx/v1beta') // 'v1beta'
* getTrailingApiVersion('https://api.example.com') // undefined
*/
export function getTrailingApiVersion(url: string): string | undefined {
const match = url.match(TRAILING_VERSION_REGEX)
if (match) {
// Extract version without leading slash and trailing slash
return match[0].replace(/^\//, '').replace(/\/$/, '')
}
return undefined
}
/**
* Removes the trailing API version segment from a URL path.
*
* This function removes API version patterns (e.g., `/v1`, `/v2beta`) from the end of a URL.
* Only versions at the end of the path are removed, not versions in the middle.
*
* @param {string} url - The URL string to process.
* @returns {string} The URL with the trailing API version removed, or the original URL if no trailing version found.
*
* @example
* withoutTrailingApiVersion('https://api.example.com/v1') // 'https://api.example.com'
* withoutTrailingApiVersion('https://api.example.com/v2beta/') // 'https://api.example.com'
* withoutTrailingApiVersion('https://api.example.com/v1/chat') // 'https://api.example.com/v1/chat' (no change)
* withoutTrailingApiVersion('https://api.example.com') // 'https://api.example.com'
*/
export function withoutTrailingApiVersion(url: string): string {
return url.replace(TRAILING_VERSION_REGEX, '')
}

View File

@@ -43,35 +43,6 @@ export function isSiliconAnthropicCompatibleModel(modelId: string): boolean {
}
/**
* PPIO provider models that support Anthropic API endpoint.
* These models can be used with Claude Code via the Anthropic-compatible API.
*
* @see https://ppio.com/docs/model/llm-anthropic-compatibility
* Silicon provider's Anthropic API host URL.
*/
export const PPIO_ANTHROPIC_COMPATIBLE_MODELS: readonly string[] = [
'moonshotai/kimi-k2-thinking',
'minimax/minimax-m2',
'deepseek/deepseek-v3.2-exp',
'deepseek/deepseek-v3.1-terminus',
'zai-org/glm-4.6',
'moonshotai/kimi-k2-0905',
'deepseek/deepseek-v3.1',
'moonshotai/kimi-k2-instruct',
'qwen/qwen3-next-80b-a3b-instruct',
'qwen/qwen3-next-80b-a3b-thinking'
]
/**
* Creates a Set for efficient lookup of PPIO Anthropic-compatible model IDs.
*/
const PPIO_ANTHROPIC_COMPATIBLE_MODEL_SET = new Set(PPIO_ANTHROPIC_COMPATIBLE_MODELS)
/**
* Checks if a model ID is compatible with Anthropic API on PPIO provider.
*
* @param modelId - The model ID to check
* @returns true if the model supports Anthropic API endpoint
*/
export function isPpioAnthropicCompatibleModel(modelId: string): boolean {
return PPIO_ANTHROPIC_COMPATIBLE_MODEL_SET.has(modelId)
}
export const SILICON_ANTHROPIC_API_HOST = 'https://api.siliconflow.cn'

View File

@@ -1,15 +0,0 @@
/**
* Shared AI SDK Middlewares
*
* Environment-agnostic middlewares that can be used in both
* renderer process and main process (API server).
*/
export {
buildSharedMiddlewares,
getReasoningTagName,
isGemini3ModelId,
openrouterReasoningMiddleware,
type SharedMiddlewareConfig,
skipGeminiThoughtSignatureMiddleware
} from './middlewares'

View File

@@ -1,205 +0,0 @@
/**
* Shared AI SDK Middlewares
*
* These middlewares are environment-agnostic and can be used in both
* renderer process and main process (API server).
*/
import type { LanguageModelV2Middleware, LanguageModelV2StreamPart } from '@ai-sdk/provider'
import { extractReasoningMiddleware } from 'ai'
/**
* Configuration for building shared middlewares
*/
export interface SharedMiddlewareConfig {
/**
* Whether to enable reasoning extraction
*/
enableReasoning?: boolean
/**
* Tag name for reasoning extraction
* Defaults based on model ID
*/
reasoningTagName?: string
/**
* Model ID - used to determine default reasoning tag and model detection
*/
modelId?: string
/**
* Provider ID (Cherry Studio provider ID)
* Used for provider-specific middlewares like OpenRouter
*/
providerId?: string
/**
* AI SDK Provider ID
* Used for Gemini thought signature middleware
* e.g., 'google', 'google-vertex'
*/
aiSdkProviderId?: string
}
/**
* Check if model ID represents a Gemini 3 (2.5) model
* that requires thought signature handling
*
* @param modelId - The model ID string (not Model object)
*/
export function isGemini3ModelId(modelId?: string): boolean {
if (!modelId) return false
const lowerModelId = modelId.toLowerCase()
return lowerModelId.includes('gemini-3')
}
/**
* Get the default reasoning tag name based on model ID
*
* Different models use different tags for reasoning content:
* - Most models: 'think'
* - GPT-OSS models: 'reasoning'
* - Gemini models: 'thought'
* - Seed models: 'seed:think'
*/
export function getReasoningTagName(modelId?: string): string {
if (!modelId) return 'think'
const lowerModelId = modelId.toLowerCase()
if (lowerModelId.includes('gpt-oss')) return 'reasoning'
if (lowerModelId.includes('gemini')) return 'thought'
if (lowerModelId.includes('seed-oss-36b')) return 'seed:think'
return 'think'
}
/**
* Skip Gemini Thought Signature Middleware
*
* Due to the complexity of multi-model client requests (which can switch
* to other models mid-process), this middleware skips all Gemini 3
* thinking signatures validation.
*
* @param aiSdkId - AI SDK Provider ID (e.g., 'google', 'google-vertex')
* @returns LanguageModelV2Middleware
*/
export function skipGeminiThoughtSignatureMiddleware(aiSdkId: string): LanguageModelV2Middleware {
const MAGIC_STRING = 'skip_thought_signature_validator'
return {
middlewareVersion: 'v2',
transformParams: async ({ params }) => {
const transformedParams = { ...params }
// Process messages in prompt
if (transformedParams.prompt && Array.isArray(transformedParams.prompt)) {
transformedParams.prompt = transformedParams.prompt.map((message) => {
if (typeof message.content !== 'string') {
for (const part of message.content) {
const googleOptions = part?.providerOptions?.[aiSdkId]
if (googleOptions?.thoughtSignature) {
googleOptions.thoughtSignature = MAGIC_STRING
}
}
}
return message
})
}
return transformedParams
}
}
}
/**
* OpenRouter Reasoning Middleware
*
* Filters out [REDACTED] blocks from OpenRouter reasoning responses.
* OpenRouter may include [REDACTED] markers in reasoning content that
* should be removed for cleaner output.
*
* @see https://openrouter.ai/docs/docs/best-practices/reasoning-tokens
* @returns LanguageModelV2Middleware
*/
export function openrouterReasoningMiddleware(): LanguageModelV2Middleware {
const REDACTED_BLOCK = '[REDACTED]'
return {
middlewareVersion: 'v2',
wrapGenerate: async ({ doGenerate }) => {
const { content, ...rest } = await doGenerate()
const modifiedContent = content.map((part) => {
if (part.type === 'reasoning' && part.text.includes(REDACTED_BLOCK)) {
return {
...part,
text: part.text.replace(REDACTED_BLOCK, '')
}
}
return part
})
return { content: modifiedContent, ...rest }
},
wrapStream: async ({ doStream }) => {
const { stream, ...rest } = await doStream()
return {
stream: stream.pipeThrough(
new TransformStream<LanguageModelV2StreamPart, LanguageModelV2StreamPart>({
transform(
chunk: LanguageModelV2StreamPart,
controller: TransformStreamDefaultController<LanguageModelV2StreamPart>
) {
if (chunk.type === 'reasoning-delta' && chunk.delta.includes(REDACTED_BLOCK)) {
controller.enqueue({
...chunk,
delta: chunk.delta.replace(REDACTED_BLOCK, '')
})
} else {
controller.enqueue(chunk)
}
}
})
),
...rest
}
}
}
}
/**
* Build shared middlewares based on configuration
*
* This function builds a set of middlewares that are commonly needed
* across different environments (renderer, API server).
*
* @param config - Configuration for middleware building
* @returns Array of AI SDK middlewares
*
* @example
* ```typescript
* import { buildSharedMiddlewares } from '@shared/middleware'
*
* const middlewares = buildSharedMiddlewares({
* enableReasoning: true,
* modelId: 'gemini-2.5-pro',
* providerId: 'openrouter',
* aiSdkProviderId: 'google'
* })
* ```
*/
export function buildSharedMiddlewares(config: SharedMiddlewareConfig): LanguageModelV2Middleware[] {
const middlewares: LanguageModelV2Middleware[] = []
// 1. Reasoning extraction middleware
if (config.enableReasoning) {
const tagName = config.reasoningTagName || getReasoningTagName(config.modelId)
middlewares.push(extractReasoningMiddleware({ tagName }))
}
// 2. OpenRouter-specific: filter [REDACTED] blocks
if (config.providerId === 'openrouter' && config.enableReasoning) {
middlewares.push(openrouterReasoningMiddleware())
}
// 3. Gemini 3 (2.5) specific: skip thought signature validation
if (isGemini3ModelId(config.modelId) && config.aiSdkProviderId) {
middlewares.push(skipGeminiThoughtSignatureMiddleware(config.aiSdkProviderId))
}
return middlewares
}

View File

@@ -1,22 +0,0 @@
import type { MinimalModel, MinimalProvider, ProviderType } from '../types'
import { provider2Provider, startsWith } from './helper'
import type { RuleSet } from './types'
// https://platform.claude.com/docs/en/build-with-claude/claude-in-microsoft-foundry
const AZURE_ANTHROPIC_RULES: RuleSet = {
rules: [
{
match: startsWith('claude'),
provider: (provider: MinimalProvider) => ({
...provider,
type: 'anthropic' as ProviderType,
apiHost: provider.apiHost + 'anthropic/v1',
id: 'azure-anthropic'
})
}
],
fallbackRule: (provider: MinimalProvider) => provider
}
export const azureAnthropicProviderCreator = <P extends MinimalProvider>(model: MinimalModel, provider: P): P =>
provider2Provider<MinimalModel, MinimalProvider, P>(AZURE_ANTHROPIC_RULES, model, provider)

View File

@@ -1,32 +0,0 @@
import type { MinimalModel, MinimalProvider } from '../types'
import type { RuleSet } from './types'
export const startsWith =
(prefix: string) =>
<M extends MinimalModel>(model: M) =>
model.id.toLowerCase().startsWith(prefix.toLowerCase())
export const endpointIs =
(type: string) =>
<M extends MinimalModel>(model: M) =>
model.endpoint_type === type
/**
* 解析模型对应的Provider
* @param ruleSet 规则集对象
* @param model 模型对象
* @param provider 原始provider对象
* @returns 解析出的provider对象
*/
export function provider2Provider<M extends MinimalModel, R extends MinimalProvider, P extends R = R>(
ruleSet: RuleSet<M, R>,
model: M,
provider: P
): P {
for (const rule of ruleSet.rules) {
if (rule.match(model)) {
return rule.provider(provider) as P
}
}
return ruleSet.fallbackRule(provider) as P
}

View File

@@ -1,6 +0,0 @@
export { aihubmixProviderCreator } from './aihubmix'
export { azureAnthropicProviderCreator } from './azure-anthropic'
export { endpointIs, provider2Provider, startsWith } from './helper'
export { newApiResolverCreator } from './newApi'
export type { RuleSet } from './types'
export { vertexAnthropicProviderCreator } from './vertex-anthropic'

View File

@@ -1,9 +0,0 @@
import type { MinimalModel, MinimalProvider } from '../types'
export interface RuleSet<M extends MinimalModel = MinimalModel, P extends MinimalProvider = MinimalProvider> {
rules: Array<{
match: (model: M) => boolean
provider: (provider: P) => P
}>
fallbackRule: (provider: P) => P
}

View File

@@ -1,19 +0,0 @@
import type { MinimalModel, MinimalProvider } from '../types'
import { provider2Provider, startsWith } from './helper'
import type { RuleSet } from './types'
const VERTEX_ANTHROPIC_RULES: RuleSet = {
rules: [
{
match: startsWith('claude'),
provider: (provider: MinimalProvider) => ({
...provider,
id: 'google-vertex-anthropic'
})
}
],
fallbackRule: (provider: MinimalProvider) => provider
}
export const vertexAnthropicProviderCreator = <P extends MinimalProvider>(model: MinimalModel, provider: P): P =>
provider2Provider<MinimalModel, MinimalProvider, P>(VERTEX_ANTHROPIC_RULES, model, provider)

View File

@@ -1,26 +0,0 @@
import { getLowerBaseModelName } from '@shared/utils/naming'
import type { MinimalModel } from './types'
export const COPILOT_EDITOR_VERSION = 'vscode/1.104.1'
export const COPILOT_PLUGIN_VERSION = 'copilot-chat/0.26.7'
export const COPILOT_INTEGRATION_ID = 'vscode-chat'
export const COPILOT_USER_AGENT = 'GitHubCopilotChat/0.26.7'
export const COPILOT_DEFAULT_HEADERS = {
'Copilot-Integration-Id': COPILOT_INTEGRATION_ID,
'User-Agent': COPILOT_USER_AGENT,
'Editor-Version': COPILOT_EDITOR_VERSION,
'Editor-Plugin-Version': COPILOT_PLUGIN_VERSION,
'editor-version': COPILOT_EDITOR_VERSION,
'editor-plugin-version': COPILOT_PLUGIN_VERSION,
'copilot-vision-request': 'true'
} as const
// Models that require the OpenAI Responses endpoint when routed through GitHub Copilot (#10560)
const COPILOT_RESPONSES_MODEL_IDS = ['gpt-5-codex', 'gpt-5.1-codex', 'gpt-5.1-codex-mini']
export function isCopilotResponsesModel<M extends MinimalModel>(model: M): boolean {
const normalizedId = getLowerBaseModelName(model.id)
return COPILOT_RESPONSES_MODEL_IDS.some((target) => normalizedId === target)
}

View File

@@ -1,100 +0,0 @@
/**
* Provider Type Detection Utilities
*
* Functions to detect provider types based on provider configuration.
* These are pure functions that only depend on provider.type and provider.id.
*
* NOTE: These functions should match the logic in @renderer/utils/provider.ts
*/
import type { MinimalProvider } from './types'
/**
* Check if provider is Anthropic type
*/
export function isAnthropicProvider<P extends MinimalProvider>(provider: P): boolean {
return provider.type === 'anthropic'
}
/**
* Check if provider is OpenAI Response type (openai-response)
* NOTE: This matches isOpenAIProvider in renderer/utils/provider.ts
*/
export function isOpenAIProvider<P extends MinimalProvider>(provider: P): boolean {
return provider.type === 'openai-response'
}
/**
* Check if provider is Gemini type
*/
export function isGeminiProvider<P extends MinimalProvider>(provider: P): boolean {
return provider.type === 'gemini'
}
/**
* Check if provider is Azure OpenAI type
*/
export function isAzureOpenAIProvider<P extends MinimalProvider>(provider: P): boolean {
return provider.type === 'azure-openai'
}
/**
* Check if provider is Vertex AI type
*/
export function isVertexProvider<P extends MinimalProvider>(provider: P): boolean {
return provider.type === 'vertexai'
}
/**
* Check if provider is AWS Bedrock type
*/
export function isAwsBedrockProvider<P extends MinimalProvider>(provider: P): boolean {
return provider.type === 'aws-bedrock'
}
/**
* Check if provider is AI Gateway type
*/
export function isAIGatewayProvider<P extends MinimalProvider>(provider: P): boolean {
return provider.type === 'ai-gateway'
}
/**
* Check if Azure OpenAI provider uses responses endpoint
* Matches isAzureResponsesEndpoint in renderer/utils/provider.ts
*/
export function isAzureResponsesEndpoint<P extends MinimalProvider>(provider: P): boolean {
return provider.apiVersion === 'preview' || provider.apiVersion === 'v1'
}
/**
* Check if provider is Cherry AI type
* Matches isCherryAIProvider in renderer/utils/provider.ts
*/
export function isCherryAIProvider<P extends MinimalProvider>(provider: P): boolean {
return provider.id === 'cherryai'
}
/**
* Check if provider is Perplexity type
* Matches isPerplexityProvider in renderer/utils/provider.ts
*/
export function isPerplexityProvider<P extends MinimalProvider>(provider: P): boolean {
return provider.id === 'perplexity'
}
/**
* Check if provider is new-api type (supports multiple backends)
* Matches isNewApiProvider in renderer/utils/provider.ts
*/
export function isNewApiProvider<P extends MinimalProvider>(provider: P): boolean {
return ['new-api', 'cherryin'].includes(provider.id) || provider.type === ('new-api' as string)
}
/**
* Check if provider is OpenAI compatible
* Matches isOpenAICompatibleProvider in renderer/utils/provider.ts
*/
export function isOpenAICompatibleProvider<P extends MinimalProvider>(provider: P): boolean {
return ['openai', 'new-api', 'mistral'].includes(provider.type)
}

View File

@@ -1,136 +0,0 @@
/**
* Provider API Host Formatting
*
* Utilities for formatting provider API hosts to work with AI SDK.
* These handle the differences between how Cherry Studio stores API hosts
* and how AI SDK expects them.
*/
import {
formatApiHost,
formatAzureOpenAIApiHost,
formatVertexApiHost,
routeToEndpoint,
withoutTrailingSlash
} from '../api'
import {
isAnthropicProvider,
isAzureOpenAIProvider,
isCherryAIProvider,
isGeminiProvider,
isPerplexityProvider,
isVertexProvider
} from './detection'
import type { MinimalProvider } from './types'
import { SystemProviderIds } from './types'
/**
* Interface for environment-specific implementations
* Renderer and Main process can provide their own implementations
*/
export interface ProviderFormatContext {
vertex: {
project: string
location: string
}
}
/**
* Default Azure OpenAI API host formatter
*/
export function defaultFormatAzureOpenAIApiHost(host: string): string {
const normalizedHost = withoutTrailingSlash(host)
?.replace(/\/v1$/, '')
.replace(/\/openai$/, '')
// AI SDK will add /v1
return formatApiHost(normalizedHost + '/openai', false)
}
/**
* Format provider API host for AI SDK
*
* This function normalizes the apiHost to work with AI SDK.
* Different providers have different requirements:
* - Most providers: add /v1 suffix
* - Gemini: add /v1beta suffix
* - Some providers: no suffix needed
*
* @param provider - The provider to format
* @param context - Context with environment-specific implementations
* @returns Provider with formatted apiHost (and anthropicApiHost if applicable)
*/
export function formatProviderApiHost<T extends MinimalProvider>(provider: T, context: ProviderFormatContext): T {
const formatted = { ...provider }
// Format anthropicApiHost if present
if (formatted.anthropicApiHost) {
formatted.anthropicApiHost = formatApiHost(formatted.anthropicApiHost)
}
// Format based on provider type
if (isAnthropicProvider(provider)) {
const baseHost = formatted.anthropicApiHost || formatted.apiHost
// AI SDK needs /v1 in baseURL
formatted.apiHost = formatApiHost(baseHost)
if (!formatted.anthropicApiHost) {
formatted.anthropicApiHost = formatted.apiHost
}
} else if (formatted.id === SystemProviderIds.copilot || formatted.id === SystemProviderIds.github) {
formatted.apiHost = formatApiHost(formatted.apiHost, false)
} else if (isGeminiProvider(formatted)) {
formatted.apiHost = formatApiHost(formatted.apiHost, true, 'v1beta')
} else if (isAzureOpenAIProvider(formatted)) {
formatted.apiHost = formatAzureOpenAIApiHost(formatted.apiHost)
} else if (isVertexProvider(formatted)) {
formatted.apiHost = formatVertexApiHost(formatted, context.vertex.project, context.vertex.location)
} else if (isCherryAIProvider(formatted)) {
formatted.apiHost = formatApiHost(formatted.apiHost, false)
} else if (isPerplexityProvider(formatted)) {
formatted.apiHost = formatApiHost(formatted.apiHost, false)
} else {
formatted.apiHost = formatApiHost(formatted.apiHost)
}
return formatted
}
/**
* Get the base URL for AI SDK from a formatted provider
*
* This extracts the baseURL that AI SDK expects, handling
* the '#' endpoint routing format if present.
*
* @param formattedApiHost - The formatted apiHost (after formatProviderApiHost)
* @returns The baseURL for AI SDK
*/
export function getBaseUrlForAiSdk(formattedApiHost: string): string {
const { baseURL } = routeToEndpoint(formattedApiHost)
return baseURL
}
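// Usage sketch (hypothetical Gemini provider; the functions are real exports from this
// module). Per the rules above, Gemini hosts receive a '/v1beta' suffix, and
// getBaseUrlForAiSdk then strips any '#' endpoint-routing suffix for AI SDK.
const exampleGemini: MinimalProvider = {
  id: 'gemini',
  type: 'gemini',
  apiKey: 'example-key',
  apiHost: 'https://generativelanguage.googleapis.com'
}
const formattedGemini = formatProviderApiHost(exampleGemini, { vertex: { project: '', location: '' } })
const geminiBaseURL = getBaseUrlForAiSdk(formattedGemini.apiHost)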
/**
* Get rotated API key from comma-separated keys
*
* This is the interface for API key rotation. The actual implementation
* depends on the environment (renderer uses window.keyv, main uses its own storage).
*/
export interface ApiKeyRotator {
/**
* Get the next API key in rotation
* @param providerId - The provider ID for tracking rotation
* @param keys - Comma-separated API keys
* @returns The next API key to use
*/
getRotatedKey(providerId: string, keys: string): string
}
/**
* Simple API key rotator that always returns the first key
* Use this when rotation is not needed
*/
export const simpleKeyRotator: ApiKeyRotator = {
getRotatedKey(_providerId: string, keys: string): string {
const keyList = keys.split(',').map((k) => k.trim())
return keyList[0] || keys
}
}
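// A minimal round-robin sketch of ApiKeyRotator (not part of the original module),
// assuming in-memory state is acceptable; the renderer reportedly tracks rotation
// via window.keyv instead.
const rotationCounters = new Map<string, number>()
export const roundRobinKeyRotator: ApiKeyRotator = {
  getRotatedKey(providerId: string, keys: string): string {
    const keyList = keys
      .split(',')
      .map((k) => k.trim())
      .filter(Boolean)
    if (keyList.length === 0) return keys
    // Advance a per-provider counter and wrap around the key list
    const next = (rotationCounters.get(providerId) ?? 0) % keyList.length
    rotationCounters.set(providerId, next + 1)
    return keyList[next]
  }
}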

View File

@@ -1,48 +0,0 @@
/**
* Shared Provider Utilities
*
* This module exports utilities for working with AI providers
* that can be shared between main process and renderer process.
*/
// Type definitions
export type { MinimalProvider, ProviderType, SystemProviderId } from './types'
export { SystemProviderIds } from './types'
// Provider type detection
export {
isAIGatewayProvider,
isAnthropicProvider,
isAwsBedrockProvider,
isAzureOpenAIProvider,
isAzureResponsesEndpoint,
isCherryAIProvider,
isGeminiProvider,
isNewApiProvider,
isOpenAICompatibleProvider,
isOpenAIProvider,
isPerplexityProvider,
isVertexProvider
} from './detection'
// API host formatting
export type { ApiKeyRotator, ProviderFormatContext } from './format'
export {
defaultFormatAzureOpenAIApiHost,
formatProviderApiHost,
getBaseUrlForAiSdk,
simpleKeyRotator
} from './format'
// Provider ID mapping
export { getAiSdkProviderId, STATIC_PROVIDER_MAPPING, tryResolveProviderId } from './mapping'
// AI SDK configuration
export type { AiSdkConfig, AiSdkConfigContext } from './sdk-config'
export { providerToAiSdkConfig } from './sdk-config'
// Provider resolution
export { resolveActualProvider } from './resolve'
// Provider initialization
export { initializeSharedProviders, SHARED_PROVIDER_CONFIGS } from './initialization'

View File

@@ -1,107 +0,0 @@
import { type ProviderConfig, registerMultipleProviderConfigs } from '@cherrystudio/ai-core/provider'
type ProviderInitializationLogger = {
warn?: (message: string) => void
error?: (message: string, error: Error) => void
}
export const SHARED_PROVIDER_CONFIGS: ProviderConfig[] = [
{
id: 'openrouter',
name: 'OpenRouter',
import: () => import('@openrouter/ai-sdk-provider'),
creatorFunctionName: 'createOpenRouter',
supportsImageGeneration: true,
aliases: ['openrouter']
},
{
id: 'google-vertex',
name: 'Google Vertex AI',
import: () => import('@ai-sdk/google-vertex/edge'),
creatorFunctionName: 'createVertex',
supportsImageGeneration: true,
aliases: ['vertexai']
},
{
id: 'google-vertex-anthropic',
name: 'Google Vertex AI Anthropic',
import: () => import('@ai-sdk/google-vertex/anthropic/edge'),
creatorFunctionName: 'createVertexAnthropic',
supportsImageGeneration: true,
aliases: ['vertexai-anthropic']
},
{
id: 'azure-anthropic',
name: 'Azure AI Anthropic',
import: () => import('@ai-sdk/anthropic'),
creatorFunctionName: 'createAnthropic',
supportsImageGeneration: false,
aliases: ['azure-anthropic']
},
{
id: 'github-copilot-openai-compatible',
name: 'GitHub Copilot OpenAI Compatible',
import: () => import('@opeoginni/github-copilot-openai-compatible'),
creatorFunctionName: 'createGitHubCopilotOpenAICompatible',
supportsImageGeneration: false,
aliases: ['copilot', 'github-copilot']
},
{
id: 'bedrock',
name: 'Amazon Bedrock',
import: () => import('@ai-sdk/amazon-bedrock'),
creatorFunctionName: 'createAmazonBedrock',
supportsImageGeneration: true,
aliases: ['aws-bedrock']
},
{
id: 'perplexity',
name: 'Perplexity',
import: () => import('@ai-sdk/perplexity'),
creatorFunctionName: 'createPerplexity',
supportsImageGeneration: false,
aliases: ['perplexity']
},
{
id: 'mistral',
name: 'Mistral',
import: () => import('@ai-sdk/mistral'),
creatorFunctionName: 'createMistral',
supportsImageGeneration: false,
aliases: ['mistral']
},
{
id: 'huggingface',
name: 'HuggingFace',
import: () => import('@ai-sdk/huggingface'),
creatorFunctionName: 'createHuggingFace',
supportsImageGeneration: true,
aliases: ['hf', 'hugging-face']
},
{
id: 'ai-gateway',
name: 'AI Gateway',
import: () => import('@ai-sdk/gateway'),
creatorFunctionName: 'createGateway',
supportsImageGeneration: true,
aliases: ['gateway']
},
{
id: 'cerebras',
name: 'Cerebras',
import: () => import('@ai-sdk/cerebras'),
creatorFunctionName: 'createCerebras',
supportsImageGeneration: false
}
] as const
export function initializeSharedProviders(logger?: ProviderInitializationLogger): void {
try {
const successCount = registerMultipleProviderConfigs(SHARED_PROVIDER_CONFIGS)
if (successCount < SHARED_PROVIDER_CONFIGS.length) {
logger?.warn?.('Some providers failed to register. Check previous error logs.')
}
} catch (error) {
logger?.error?.('Failed to initialize shared providers', error as Error)
}
}
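// Usage sketch: any object matching ProviderInitializationLogger works; a
// console-backed logger is shown here purely for illustration.
// initializeSharedProviders({
//   warn: (message) => console.warn(message),
//   error: (message, error) => console.error(message, error)
// })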

View File

@@ -1,95 +0,0 @@
/**
* Provider ID Mapping
*
* Maps Cherry Studio provider IDs/types to AI SDK provider IDs.
* This logic should match @renderer/aiCore/provider/factory.ts
*/
import { hasProviderConfigByAlias, type ProviderId, resolveProviderConfigId } from '@cherrystudio/ai-core/provider'
import { isAzureOpenAIProvider, isAzureResponsesEndpoint } from './detection'
import type { MinimalProvider } from './types'
/**
* Static mapping from Cherry Studio provider ID/type to AI SDK provider ID
* Matches STATIC_PROVIDER_MAPPING in @renderer/aiCore/provider/factory.ts
*/
export const STATIC_PROVIDER_MAPPING: Record<string, ProviderId> = {
gemini: 'google', // Google Gemini -> google
'azure-openai': 'azure', // Azure OpenAI -> azure
'openai-response': 'openai', // OpenAI Responses -> openai
grok: 'xai', // Grok -> xai
copilot: 'github-copilot-openai-compatible'
}
/**
* Try to resolve a provider identifier to an AI SDK provider ID
* Matches tryResolveProviderId in @renderer/aiCore/provider/factory.ts
*
* @param identifier - The provider ID or type to resolve
* @returns The resolved AI SDK provider ID, or null if not found
*/
export function tryResolveProviderId(identifier: string): ProviderId | null {
// 1. Check the static mapping
const staticMapping = STATIC_PROVIDER_MAPPING[identifier]
if (staticMapping) {
return staticMapping
}
// 2. Check whether AiCore supports it (including alias support)
if (hasProviderConfigByAlias(identifier)) {
// Resolve to the actual provider ID
return resolveProviderConfigId(identifier) as ProviderId
}
return null
}
/**
* Get the AI SDK Provider ID for a Cherry Studio provider
* Matches getAiSdkProviderId in @renderer/aiCore/provider/factory.ts
*
* Logic:
* 1. Handle Azure OpenAI specially (check responses endpoint)
* 2. Try to resolve from provider.id
* 3. Try to resolve from provider.type (but not for generic 'openai' type)
* 4. Check for OpenAI API host pattern
* 5. Fallback to provider's own ID
*
* @param provider - The Cherry Studio provider
* @returns The AI SDK provider ID to use
*/
export function getAiSdkProviderId(provider: MinimalProvider): ProviderId {
// 1. Handle Azure OpenAI specially - check this FIRST before other resolution
if (isAzureOpenAIProvider(provider)) {
if (isAzureResponsesEndpoint(provider)) {
return 'azure-responses'
}
return 'azure'
}
// 2. Try to resolve from provider.id
const resolvedFromId = tryResolveProviderId(provider.id)
if (resolvedFromId) {
return resolvedFromId
}
// 3. Try to resolve from provider.type
// (resolving the generic 'openai' type here would map every custom OpenAI-type provider to AI SDK's openaiProvider)
if (provider.type !== 'openai') {
const resolvedFromType = tryResolveProviderId(provider.type)
if (resolvedFromType) {
return resolvedFromType
}
}
// 4. Check for OpenAI API host pattern
if (provider.apiHost.includes('api.openai.com')) {
return 'openai-chat'
}
// 5. Final fallback: use the provider's own id
return provider.id
}
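// Illustrative resolution walk-through (hypothetical inputs, assuming 'my-proxy' is not
// a registered AiCore alias; outputs follow the numbered steps above):
// getAiSdkProviderId({ id: 'gemini', type: 'gemini', apiKey: '', apiHost: '' })
//   -> 'google' (step 2: provider.id hits STATIC_PROVIDER_MAPPING)
// getAiSdkProviderId({ id: 'my-proxy', type: 'openai', apiKey: '', apiHost: 'https://api.openai.com/v1' })
//   -> 'openai-chat' (steps 2-3 find nothing; step 4 matches the api.openai.com host)
// getAiSdkProviderId({ id: 'my-proxy', type: 'openai', apiKey: '', apiHost: 'https://example.com/v1' })
//   -> 'my-proxy' (step 5: fallback to the provider's own id)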

View File

@@ -1,43 +0,0 @@
import { aihubmixProviderCreator, newApiResolverCreator, vertexAnthropicProviderCreator } from './config'
import { azureAnthropicProviderCreator } from './config/azure-anthropic'
import { isAzureOpenAIProvider, isNewApiProvider } from './detection'
import type { MinimalModel, MinimalProvider } from './types'
export interface ResolveActualProviderOptions<P extends MinimalProvider> {
isSystemProvider?: (provider: P) => boolean
}
const defaultIsSystemProvider = <P extends MinimalProvider>(provider: P): boolean => {
if ('isSystem' in provider) {
return Boolean((provider as unknown as { isSystem?: boolean }).isSystem)
}
return false
}
export function resolveActualProvider<M extends MinimalModel, P extends MinimalProvider>(
provider: P,
model: M,
options: ResolveActualProviderOptions<P> = {}
): P {
let resolvedProvider = provider
if (isNewApiProvider(resolvedProvider)) {
resolvedProvider = newApiResolverCreator(model, resolvedProvider)
}
const isSystemProvider = options.isSystemProvider?.(resolvedProvider) ?? defaultIsSystemProvider(resolvedProvider)
if (isSystemProvider && resolvedProvider.id === 'aihubmix') {
resolvedProvider = aihubmixProviderCreator(model, resolvedProvider)
}
if (isSystemProvider && resolvedProvider.id === 'vertexai') {
resolvedProvider = vertexAnthropicProviderCreator(model, resolvedProvider)
}
if (isAzureOpenAIProvider(resolvedProvider)) {
resolvedProvider = azureAnthropicProviderCreator(model, resolvedProvider)
}
return resolvedProvider
}
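// Usage sketch (the override shown is hypothetical; by default the function reads an
// optional `isSystem` flag off the provider itself):
// const resolved = resolveActualProvider(provider, model, {
//   isSystemProvider: (p) => p.id === 'aihubmix' || p.id === 'vertexai'
// })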

View File

@@ -1,259 +0,0 @@
/**
* AI SDK Configuration
*
* Shared utilities for converting Cherry Studio Provider to AI SDK configuration.
* Environment-specific logic (renderer/main) is injected via context interfaces.
*/
import { formatPrivateKey, hasProviderConfig, ProviderConfigFactory } from '@cherrystudio/ai-core/provider'
import { routeToEndpoint } from '../api'
import { getAiSdkProviderId } from './mapping'
import type { MinimalProvider } from './types'
import { SystemProviderIds } from './types'
/**
* AI SDK configuration result
*/
export interface AiSdkConfig {
providerId: string
options: Record<string, unknown>
}
/**
* Context for environment-specific implementations
*/
export interface AiSdkConfigContext {
/**
* Get the rotated API key (for multi-key support)
* Default: returns first key
*/
getRotatedApiKey?: (provider: MinimalProvider) => string
/**
* Check if a model uses chat completion only (for OpenAI response mode)
* Default: returns false
*/
isOpenAIChatCompletionOnlyModel?: (modelId: string) => boolean
/**
* Get Copilot default headers (constants)
* Default: returns empty object
*/
getCopilotDefaultHeaders?: () => Record<string, string>
/**
* Get Copilot stored headers from state
* Default: returns empty object
*/
getCopilotStoredHeaders?: () => Record<string, string>
/**
* Get AWS Bedrock configuration
* Default: returns undefined (not configured)
*/
getAwsBedrockConfig?: () =>
| {
authType: 'apiKey' | 'iam'
region: string
apiKey?: string
accessKeyId?: string
secretAccessKey?: string
}
| undefined
/**
* Get Vertex AI configuration
* Default: returns undefined (not configured)
*/
getVertexConfig?: (provider: MinimalProvider) =>
| {
project: string
location: string
googleCredentials: {
privateKey: string
clientEmail: string
}
}
| undefined
/**
* Get endpoint type for cherryin provider
*/
getEndpointType?: (modelId: string) => string | undefined
/**
* Custom fetch implementation
* Main process: use Electron net.fetch
* Renderer process: use browser fetch (default)
*/
fetch?: typeof globalThis.fetch
/**
* Get CherryAI signed fetch wrapper
* Returns a fetch function that adds signature headers to requests
*/
getCherryAISignedFetch?: () => typeof globalThis.fetch
}
/**
* Default simple key rotator - returns first key
*/
function defaultGetRotatedApiKey(provider: MinimalProvider): string {
const keys = provider.apiKey.split(',').map((k) => k.trim())
return keys[0] || provider.apiKey
}
/**
* Convert Cherry Studio Provider to AI SDK configuration
*
* @param provider - The formatted provider (after formatProviderApiHost)
* @param modelId - The model ID to use
* @param context - Environment-specific implementations
* @returns AI SDK configuration
*/
export function providerToAiSdkConfig(
provider: MinimalProvider,
modelId: string,
context: AiSdkConfigContext = {}
): AiSdkConfig {
const getRotatedApiKey = context.getRotatedApiKey || defaultGetRotatedApiKey
const isOpenAIChatCompletionOnlyModel = context.isOpenAIChatCompletionOnlyModel || (() => false)
const aiSdkProviderId = getAiSdkProviderId(provider)
// Build base config
const { baseURL, endpoint } = routeToEndpoint(provider.apiHost)
const baseConfig = {
baseURL,
apiKey: getRotatedApiKey(provider)
}
// Handle Copilot specially
if (provider.id === SystemProviderIds.copilot) {
const defaultHeaders = context.getCopilotDefaultHeaders?.() ?? {}
const storedHeaders = context.getCopilotStoredHeaders?.() ?? {}
const copilotExtraOptions: Record<string, unknown> = {
headers: {
...defaultHeaders,
...storedHeaders,
...provider.extra_headers
},
name: provider.id,
includeUsage: true
}
if (context.fetch) {
copilotExtraOptions.fetch = context.fetch
}
const options = ProviderConfigFactory.fromProvider(
'github-copilot-openai-compatible',
baseConfig,
copilotExtraOptions
)
return {
providerId: 'github-copilot-openai-compatible',
options
}
}
// Build extra options
const extraOptions: Record<string, unknown> = {}
if (endpoint) {
extraOptions.endpoint = endpoint
}
// Handle OpenAI mode
if (provider.type === 'openai-response' && !isOpenAIChatCompletionOnlyModel(modelId)) {
extraOptions.mode = 'responses'
} else if (aiSdkProviderId === 'openai' || (aiSdkProviderId === 'cherryin' && provider.type === 'openai')) {
extraOptions.mode = 'chat'
}
// Add extra headers
if (provider.extra_headers) {
extraOptions.headers = provider.extra_headers
if (aiSdkProviderId === 'openai') {
extraOptions.headers = {
...(extraOptions.headers as Record<string, string>),
'HTTP-Referer': 'https://cherry-ai.com',
'X-Title': 'Cherry Studio',
'X-Api-Key': baseConfig.apiKey
}
}
}
// Handle Azure modes
if (aiSdkProviderId === 'azure-responses') {
extraOptions.mode = 'responses'
} else if (aiSdkProviderId === 'azure') {
extraOptions.mode = 'chat'
}
// Handle AWS Bedrock
if (aiSdkProviderId === 'bedrock') {
const bedrockConfig = context.getAwsBedrockConfig?.()
if (bedrockConfig) {
extraOptions.region = bedrockConfig.region
if (bedrockConfig.authType === 'apiKey') {
extraOptions.apiKey = bedrockConfig.apiKey
} else {
extraOptions.accessKeyId = bedrockConfig.accessKeyId
extraOptions.secretAccessKey = bedrockConfig.secretAccessKey
}
}
}
// Handle Vertex AI
if (aiSdkProviderId === 'google-vertex' || aiSdkProviderId === 'google-vertex-anthropic') {
const vertexConfig = context.getVertexConfig?.(provider)
if (vertexConfig) {
extraOptions.project = vertexConfig.project
extraOptions.location = vertexConfig.location
extraOptions.googleCredentials = {
...vertexConfig.googleCredentials,
privateKey: formatPrivateKey(vertexConfig.googleCredentials.privateKey)
}
baseConfig.baseURL += aiSdkProviderId === 'google-vertex' ? '/publishers/google' : '/publishers/anthropic/models'
}
}
// Handle cherryin endpoint type
if (aiSdkProviderId === 'cherryin') {
const endpointType = context.getEndpointType?.(modelId)
if (endpointType) {
extraOptions.endpointType = endpointType
}
}
// Handle cherryai signed fetch
if (provider.id === 'cherryai') {
const signedFetch = context.getCherryAISignedFetch?.()
if (signedFetch) {
extraOptions.fetch = signedFetch
}
} else if (context.fetch) {
extraOptions.fetch = context.fetch
}
// Check if AI SDK supports this provider natively
if (hasProviderConfig(aiSdkProviderId) && aiSdkProviderId !== 'openai-compatible') {
const options = ProviderConfigFactory.fromProvider(aiSdkProviderId, baseConfig, extraOptions)
return {
providerId: aiSdkProviderId,
options
}
}
// Fallback to openai-compatible
const options = ProviderConfigFactory.createOpenAICompatible(baseConfig.baseURL, baseConfig.apiKey)
return {
providerId: 'openai-compatible',
options: {
...options,
name: provider.id,
...extraOptions,
includeUsage: true
}
}
}
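// Minimal usage sketch (assumptions: `formattedProvider` already went through
// formatProviderApiHost, and the default context hooks are acceptable):
// const { providerId, options } = providerToAiSdkConfig(formattedProvider, 'gpt-4o-mini', {
//   fetch: globalThis.fetch // the main process would inject Electron's net.fetch here instead
// })
// `providerId` is an AI SDK id such as 'openai' (or 'openai-compatible' as the fallback),
// and `options` feeds the corresponding provider factory.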

View File

@@ -1,174 +0,0 @@
import * as z from 'zod'
export const ProviderTypeSchema = z.enum([
'openai',
'openai-response',
'anthropic',
'gemini',
'azure-openai',
'vertexai',
'mistral',
'aws-bedrock',
'vertex-anthropic',
'new-api',
'ai-gateway'
])
export type ProviderType = z.infer<typeof ProviderTypeSchema>
/**
* Minimal provider interface for shared utilities
* This is the subset of Provider that shared code needs
*/
export type MinimalProvider = {
id: string
type: ProviderType
apiKey: string
apiHost: string
anthropicApiHost?: string
apiVersion?: string
extra_headers?: Record<string, string>
}
/**
* Minimal model interface for shared utilities
* This is the subset of Model that shared code needs
*/
export type MinimalModel = {
id: string
endpoint_type?: string
}
export const SystemProviderIdSchema = z.enum([
'cherryin',
'silicon',
'aihubmix',
'ocoolai',
'deepseek',
'ppio',
'alayanew',
'qiniu',
'dmxapi',
'burncloud',
'tokenflux',
'302ai',
'cephalon',
'lanyun',
'ph8',
'openrouter',
'ollama',
'ovms',
'new-api',
'lmstudio',
'anthropic',
'openai',
'azure-openai',
'gemini',
'vertexai',
'github',
'copilot',
'zhipu',
'yi',
'moonshot',
'baichuan',
'dashscope',
'stepfun',
'doubao',
'infini',
'minimax',
'groq',
'together',
'fireworks',
'nvidia',
'grok',
'hyperbolic',
'mistral',
'jina',
'perplexity',
'modelscope',
'xirang',
'hunyuan',
'tencent-cloud-ti',
'baidu-cloud',
'gpustack',
'voyageai',
'aws-bedrock',
'poe',
'aionly',
'longcat',
'huggingface',
'sophnet',
'ai-gateway',
'cerebras'
])
export type SystemProviderId = z.infer<typeof SystemProviderIdSchema>
export const isSystemProviderId = (id: string): id is SystemProviderId => {
return SystemProviderIdSchema.safeParse(id).success
}
export const SystemProviderIds = {
cherryin: 'cherryin',
silicon: 'silicon',
aihubmix: 'aihubmix',
ocoolai: 'ocoolai',
deepseek: 'deepseek',
ppio: 'ppio',
alayanew: 'alayanew',
qiniu: 'qiniu',
dmxapi: 'dmxapi',
burncloud: 'burncloud',
tokenflux: 'tokenflux',
'302ai': '302ai',
cephalon: 'cephalon',
lanyun: 'lanyun',
ph8: 'ph8',
sophnet: 'sophnet',
openrouter: 'openrouter',
ollama: 'ollama',
ovms: 'ovms',
'new-api': 'new-api',
lmstudio: 'lmstudio',
anthropic: 'anthropic',
openai: 'openai',
'azure-openai': 'azure-openai',
gemini: 'gemini',
vertexai: 'vertexai',
github: 'github',
copilot: 'copilot',
zhipu: 'zhipu',
yi: 'yi',
moonshot: 'moonshot',
baichuan: 'baichuan',
dashscope: 'dashscope',
stepfun: 'stepfun',
doubao: 'doubao',
infini: 'infini',
minimax: 'minimax',
groq: 'groq',
together: 'together',
fireworks: 'fireworks',
nvidia: 'nvidia',
grok: 'grok',
hyperbolic: 'hyperbolic',
mistral: 'mistral',
jina: 'jina',
perplexity: 'perplexity',
modelscope: 'modelscope',
xirang: 'xirang',
hunyuan: 'hunyuan',
'tencent-cloud-ti': 'tencent-cloud-ti',
'baidu-cloud': 'baidu-cloud',
gpustack: 'gpustack',
voyageai: 'voyageai',
'aws-bedrock': 'aws-bedrock',
poe: 'poe',
aionly: 'aionly',
longcat: 'longcat',
huggingface: 'huggingface',
'ai-gateway': 'ai-gateway',
cerebras: 'cerebras'
} as const satisfies Record<SystemProviderId, SystemProviderId>
export type SystemProviderIdTypeMap = typeof SystemProviderIds

View File

@@ -1 +0,0 @@
export { getBaseModelName, getLowerBaseModelName } from './naming'

View File

@@ -1,31 +0,0 @@
/**
* Extract the base name from a model ID.
* Examples:
* - 'deepseek/deepseek-r1' => 'deepseek-r1'
* - 'deepseek-ai/deepseek/deepseek-r1' => 'deepseek-r1'
* @param {string} id The model ID
* @param {string} [delimiter='/'] Delimiter, defaults to '/'
* @returns {string} The base name
*/
export const getBaseModelName = (id: string, delimiter: string = '/'): string => {
const parts = id.split(delimiter)
return parts[parts.length - 1]
}
/**
* Extract the base name from a model ID and convert it to lowercase.
* Examples:
* - 'deepseek/DeepSeek-R1' => 'deepseek-r1'
* - 'deepseek-ai/deepseek/DeepSeek-R1' => 'deepseek-r1'
* @param {string} id The model ID
* @param {string} [delimiter='/'] Delimiter, defaults to '/'
* @returns {string} The lowercase base name
*/
export const getLowerBaseModelName = (id: string, delimiter: string = '/'): string => {
const baseModelName = getBaseModelName(id, delimiter).toLowerCase()
// for openrouter
if (baseModelName.endsWith(':free')) {
return baseModelName.replace(':free', '')
}
return baseModelName
}
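// Illustrative calls, including the OpenRouter ':free' suffix handling above:
// getLowerBaseModelName('deepseek/DeepSeek-R1')       // => 'deepseek-r1'
// getLowerBaseModelName('deepseek/deepseek-r1:free')  // => 'deepseek-r1'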

View File

@@ -1,64 +1,42 @@
import { defineConfig } from '@playwright/test'
import { defineConfig, devices } from '@playwright/test'
/**
* Playwright configuration for Electron e2e testing.
* See https://playwright.dev/docs/test-configuration
* See https://playwright.dev/docs/test-configuration.
*/
export default defineConfig({
// Look for test files in the specs directory
testDir: './tests/e2e/specs',
// Global timeout for each test
timeout: 60000,
// Assertion timeout
expect: {
timeout: 10000
},
// Electron apps should run tests sequentially to avoid conflicts
fullyParallel: false,
workers: 1,
// Fail the build on CI if you accidentally left test.only in the source code
// Look for test files, relative to this configuration file.
testDir: './tests/e2e',
/* Run tests in files in parallel */
fullyParallel: true,
/* Fail the build on CI if you accidentally left test.only in the source code. */
forbidOnly: !!process.env.CI,
// Retry on CI only
/* Retry on CI only */
retries: process.env.CI ? 2 : 0,
// Reporter configuration
reporter: [['html', { outputFolder: 'playwright-report' }], ['list']],
// Global setup and teardown
globalSetup: './tests/e2e/global-setup.ts',
globalTeardown: './tests/e2e/global-teardown.ts',
// Output directory for test artifacts
outputDir: './test-results',
// Shared settings for all tests
/* Opt out of parallel tests on CI. */
workers: process.env.CI ? 1 : undefined,
/* Reporter to use. See https://playwright.dev/docs/test-reporters */
reporter: 'html',
/* Shared settings for all the projects below. See https://playwright.dev/docs/api/class-testoptions. */
use: {
// Collect trace when retrying the failed test
trace: 'retain-on-failure',
/* Base URL to use in actions like `await page.goto('/')`. */
// baseURL: 'http://localhost:3000',
// Take screenshot only on failure
screenshot: 'only-on-failure',
// Record video only on failure
video: 'retain-on-failure',
// Action timeout
actionTimeout: 15000,
// Navigation timeout
navigationTimeout: 30000
/* Collect trace when retrying the failed test. See https://playwright.dev/docs/trace-viewer */
trace: 'on-first-retry'
},
// Single project for Electron testing
/* Configure projects for major browsers */
projects: [
{
name: 'electron',
testMatch: '**/*.spec.ts'
name: 'chromium',
use: { ...devices['Desktop Chrome'] }
}
]
/* Run your local dev server before starting the tests */
// webServer: {
// command: 'npm run start',
// url: 'http://localhost:3000',
// reuseExistingServer: !process.env.CI,
// },
})

View File

@@ -1,638 +0,0 @@
/**
* AI SDK to Anthropic SSE Adapter
*
* Converts AI SDK's fullStream (TextStreamPart) events to Anthropic Messages API SSE format.
* This enables any AI provider supported by AI SDK to be exposed via Anthropic-compatible API.
*
* Anthropic SSE Event Flow:
* 1. message_start - Initial message with metadata
* 2. content_block_start - Begin a content block (text, tool_use, thinking)
* 3. content_block_delta - Incremental content updates
* 4. content_block_stop - End a content block
* 5. message_delta - Updates to overall message (stop_reason, usage)
* 6. message_stop - Stream complete
*
* @see https://docs.anthropic.com/en/api/messages-streaming
*/
import type {
ContentBlock,
InputJSONDelta,
Message,
MessageDeltaUsage,
RawContentBlockDeltaEvent,
RawContentBlockStartEvent,
RawContentBlockStopEvent,
RawMessageDeltaEvent,
RawMessageStartEvent,
RawMessageStopEvent,
RawMessageStreamEvent,
StopReason,
TextBlock,
TextDelta,
ThinkingBlock,
ThinkingDelta,
ToolUseBlock,
Usage
} from '@anthropic-ai/sdk/resources/messages'
import { loggerService } from '@logger'
import { type FinishReason, type LanguageModelUsage, type TextStreamPart, type ToolSet } from 'ai'
import { googleReasoningCache, openRouterReasoningCache } from '../../services/CacheService'
const logger = loggerService.withContext('AiSdkToAnthropicSSE')
interface ContentBlockState {
type: 'text' | 'tool_use' | 'thinking'
index: number
started: boolean
content: string
// For tool_use blocks
toolId?: string
toolName?: string
toolInput?: string
}
interface AdapterState {
messageId: string
model: string
inputTokens: number
outputTokens: number
cacheInputTokens: number
currentBlockIndex: number
blocks: Map<number, ContentBlockState>
textBlockIndex: number | null
// Track multiple thinking blocks by their reasoning ID
thinkingBlocks: Map<string, number> // reasoningId -> blockIndex
currentThinkingId: string | null // Currently active thinking block ID
toolBlocks: Map<string, number> // toolCallId -> blockIndex
stopReason: StopReason | null
hasEmittedMessageStart: boolean
}
export type SSEEventCallback = (event: RawMessageStreamEvent) => void
export interface AiSdkToAnthropicSSEOptions {
model: string
messageId?: string
inputTokens?: number
onEvent: SSEEventCallback
}
/**
* Adapter that converts AI SDK fullStream events to Anthropic SSE events
*/
export class AiSdkToAnthropicSSE {
private state: AdapterState
private onEvent: SSEEventCallback
constructor(options: AiSdkToAnthropicSSEOptions) {
this.onEvent = options.onEvent
this.state = {
messageId: options.messageId || `msg_${Date.now()}_${Math.random().toString(36).substring(2, 11)}`,
model: options.model,
inputTokens: options.inputTokens || 0,
outputTokens: 0,
cacheInputTokens: 0,
currentBlockIndex: 0,
blocks: new Map(),
textBlockIndex: null,
thinkingBlocks: new Map(),
currentThinkingId: null,
toolBlocks: new Map(),
stopReason: null,
hasEmittedMessageStart: false
}
}
/**
* Process the AI SDK stream and emit Anthropic SSE events
*/
async processStream(fullStream: ReadableStream<TextStreamPart<ToolSet>>): Promise<void> {
const reader = fullStream.getReader()
try {
// Emit message_start at the beginning
this.emitMessageStart()
while (true) {
const { done, value } = await reader.read()
if (done) {
break
}
this.processChunk(value)
}
// Ensure all blocks are closed and emit final events
this.finalize()
} catch (error) {
await reader.cancel()
throw error
} finally {
reader.releaseLock()
}
}
/**
* Process a single AI SDK chunk and emit corresponding Anthropic events
*/
private processChunk(chunk: TextStreamPart<ToolSet>): void {
logger.silly('AiSdkToAnthropicSSE - Processing chunk:', { chunk: JSON.stringify(chunk) })
switch (chunk.type) {
// === Text Events ===
case 'text-start':
this.startTextBlock()
break
case 'text-delta':
this.emitTextDelta(chunk.text || '')
break
case 'text-end':
this.stopTextBlock()
break
// === Reasoning/Thinking Events ===
case 'reasoning-start': {
const reasoningId = chunk.id
this.startThinkingBlock(reasoningId)
break
}
case 'reasoning-delta': {
const reasoningId = chunk.id
this.emitThinkingDelta(chunk.text || '', reasoningId)
break
}
case 'reasoning-end': {
const reasoningId = chunk.id
this.stopThinkingBlock(reasoningId)
break
}
// === Tool Events ===
case 'tool-call':
if (googleReasoningCache && chunk.providerMetadata?.google?.thoughtSignature) {
googleReasoningCache.set(
`google-${chunk.toolName}`,
chunk.providerMetadata?.google?.thoughtSignature as string
)
}
// FIXME: key this by tool call ID
if (
openRouterReasoningCache &&
chunk.providerMetadata?.openrouter?.reasoning_details &&
Array.isArray(chunk.providerMetadata.openrouter.reasoning_details)
) {
openRouterReasoningCache.set(
'openrouter',
JSON.parse(JSON.stringify(chunk.providerMetadata.openrouter.reasoning_details))
)
}
this.handleToolCall({
type: 'tool-call',
toolCallId: chunk.toolCallId,
toolName: chunk.toolName,
args: chunk.input
})
break
case 'tool-result':
// this.handleToolResult({
// type: 'tool-result',
// toolCallId: chunk.toolCallId,
// toolName: chunk.toolName,
// args: chunk.input,
// result: chunk.output
// })
break
case 'finish-step':
if (chunk.finishReason === 'tool-calls') {
this.state.stopReason = 'tool_use'
}
break
case 'finish':
this.handleFinish(chunk)
break
case 'error':
throw chunk.error
// Ignore other event types
default:
break
}
}
private emitMessageStart(): void {
if (this.state.hasEmittedMessageStart) return
this.state.hasEmittedMessageStart = true
const usage: Usage = {
input_tokens: this.state.inputTokens,
output_tokens: 0,
cache_creation_input_tokens: 0,
cache_read_input_tokens: 0,
server_tool_use: null
}
const message: Message = {
id: this.state.messageId,
type: 'message',
role: 'assistant',
content: [],
model: this.state.model,
stop_reason: null,
stop_sequence: null,
usage
}
const event: RawMessageStartEvent = {
type: 'message_start',
message
}
this.onEvent(event)
}
private startTextBlock(): void {
// If we already have a text block, don't create another
if (this.state.textBlockIndex !== null) return
const index = this.state.currentBlockIndex++
this.state.textBlockIndex = index
this.state.blocks.set(index, {
type: 'text',
index,
started: true,
content: ''
})
const contentBlock: TextBlock = {
type: 'text',
text: '',
citations: null
}
const event: RawContentBlockStartEvent = {
type: 'content_block_start',
index,
content_block: contentBlock
}
this.onEvent(event)
}
private emitTextDelta(text: string): void {
if (!text) return
// Auto-start text block if not started
if (this.state.textBlockIndex === null) {
this.startTextBlock()
}
const index = this.state.textBlockIndex!
const block = this.state.blocks.get(index)
if (block) {
block.content += text
}
const delta: TextDelta = {
type: 'text_delta',
text
}
const event: RawContentBlockDeltaEvent = {
type: 'content_block_delta',
index,
delta
}
this.onEvent(event)
}
private stopTextBlock(): void {
if (this.state.textBlockIndex === null) return
const index = this.state.textBlockIndex
const event: RawContentBlockStopEvent = {
type: 'content_block_stop',
index
}
this.onEvent(event)
this.state.textBlockIndex = null
}
private startThinkingBlock(reasoningId: string): void {
// Check if this thinking block already exists
if (this.state.thinkingBlocks.has(reasoningId)) return
const index = this.state.currentBlockIndex++
this.state.thinkingBlocks.set(reasoningId, index)
this.state.currentThinkingId = reasoningId
this.state.blocks.set(index, {
type: 'thinking',
index,
started: true,
content: ''
})
const contentBlock: ThinkingBlock = {
type: 'thinking',
thinking: '',
signature: ''
}
const event: RawContentBlockStartEvent = {
type: 'content_block_start',
index,
content_block: contentBlock
}
this.onEvent(event)
}
private emitThinkingDelta(text: string, reasoningId?: string): void {
if (!text) return
// Determine which thinking block to use
const targetId = reasoningId || this.state.currentThinkingId
if (!targetId) {
// Auto-start thinking block if not started
const newId = `reasoning_${Date.now()}`
this.startThinkingBlock(newId)
return this.emitThinkingDelta(text, newId)
}
const index = this.state.thinkingBlocks.get(targetId)
if (index === undefined) {
// If the block doesn't exist, create it
this.startThinkingBlock(targetId)
return this.emitThinkingDelta(text, targetId)
}
const block = this.state.blocks.get(index)
if (block) {
block.content += text
}
const delta: ThinkingDelta = {
type: 'thinking_delta',
thinking: text
}
const event: RawContentBlockDeltaEvent = {
type: 'content_block_delta',
index,
delta
}
this.onEvent(event)
}
private stopThinkingBlock(reasoningId?: string): void {
const targetId = reasoningId || this.state.currentThinkingId
if (!targetId) return
const index = this.state.thinkingBlocks.get(targetId)
if (index === undefined) return
const event: RawContentBlockStopEvent = {
type: 'content_block_stop',
index
}
this.onEvent(event)
this.state.thinkingBlocks.delete(targetId)
// Update currentThinkingId if we just closed the current one
if (this.state.currentThinkingId === targetId) {
// Set to the most recent remaining thinking block, or null if none
const remaining = Array.from(this.state.thinkingBlocks.keys())
this.state.currentThinkingId = remaining.length > 0 ? remaining[remaining.length - 1] : null
}
}
private handleToolCall(chunk: { type: 'tool-call'; toolCallId: string; toolName: string; args: unknown }): void {
const { toolCallId, toolName, args } = chunk
// Check if we already have this tool call
if (this.state.toolBlocks.has(toolCallId)) {
return
}
const index = this.state.currentBlockIndex++
this.state.toolBlocks.set(toolCallId, index)
const inputJson = JSON.stringify(args)
this.state.blocks.set(index, {
type: 'tool_use',
index,
started: true,
content: inputJson,
toolId: toolCallId,
toolName,
toolInput: inputJson
})
// Emit content_block_start for tool_use
const contentBlock: ToolUseBlock = {
type: 'tool_use',
id: toolCallId,
name: toolName,
input: {}
}
const startEvent: RawContentBlockStartEvent = {
type: 'content_block_start',
index,
content_block: contentBlock
}
this.onEvent(startEvent)
// Emit the full input as a delta (Anthropic streams JSON incrementally)
const delta: InputJSONDelta = {
type: 'input_json_delta',
partial_json: inputJson
}
const deltaEvent: RawContentBlockDeltaEvent = {
type: 'content_block_delta',
index,
delta
}
this.onEvent(deltaEvent)
// Emit content_block_stop
const stopEvent: RawContentBlockStopEvent = {
type: 'content_block_stop',
index
}
this.onEvent(stopEvent)
// Mark that we have tool use
this.state.stopReason = 'tool_use'
}
private handleFinish(chunk: { type: 'finish'; finishReason?: FinishReason; totalUsage?: LanguageModelUsage }): void {
// Update usage
if (chunk.totalUsage) {
this.state.inputTokens = chunk.totalUsage.inputTokens || 0
this.state.outputTokens = chunk.totalUsage.outputTokens || 0
this.state.cacheInputTokens = chunk.totalUsage.cachedInputTokens || 0
}
// Determine finish reason
if (!this.state.stopReason) {
switch (chunk.finishReason) {
case 'stop':
this.state.stopReason = 'end_turn'
break
case 'length':
this.state.stopReason = 'max_tokens'
break
case 'tool-calls':
this.state.stopReason = 'tool_use'
break
case 'content-filter':
this.state.stopReason = 'refusal'
break
default:
this.state.stopReason = 'end_turn'
}
}
}
private finalize(): void {
// Close any open blocks
if (this.state.textBlockIndex !== null) {
this.stopTextBlock()
}
// Close all open thinking blocks
for (const reasoningId of this.state.thinkingBlocks.keys()) {
this.stopThinkingBlock(reasoningId)
}
// Emit message_delta with final stop reason and usage
const usage: MessageDeltaUsage = {
output_tokens: this.state.outputTokens,
input_tokens: this.state.inputTokens,
cache_creation_input_tokens: this.state.cacheInputTokens,
cache_read_input_tokens: null,
server_tool_use: null
}
const messageDeltaEvent: RawMessageDeltaEvent = {
type: 'message_delta',
delta: {
stop_reason: this.state.stopReason || 'end_turn',
stop_sequence: null
},
usage
}
this.onEvent(messageDeltaEvent)
// Emit message_stop
const messageStopEvent: RawMessageStopEvent = {
type: 'message_stop'
}
this.onEvent(messageStopEvent)
}
/**
* Set input token count (typically from prompt)
*/
setInputTokens(count: number): void {
this.state.inputTokens = count
}
/**
* Get the current message ID
*/
getMessageId(): string {
return this.state.messageId
}
/**
* Build a complete Message object for non-streaming responses
*/
buildNonStreamingResponse(): Message {
const content: ContentBlock[] = []
// Collect all content blocks in order
const sortedBlocks = Array.from(this.state.blocks.values()).sort((a, b) => a.index - b.index)
for (const block of sortedBlocks) {
switch (block.type) {
case 'text':
content.push({
type: 'text',
text: block.content,
citations: null
} as TextBlock)
break
case 'thinking':
content.push({
type: 'thinking',
thinking: block.content
} as ThinkingBlock)
break
case 'tool_use':
content.push({
type: 'tool_use',
id: block.toolId!,
name: block.toolName!,
input: JSON.parse(block.toolInput || '{}')
} as ToolUseBlock)
break
}
}
return {
id: this.state.messageId,
type: 'message',
role: 'assistant',
content,
model: this.state.model,
stop_reason: this.state.stopReason || 'end_turn',
stop_sequence: null,
usage: {
input_tokens: this.state.inputTokens,
output_tokens: this.state.outputTokens,
cache_creation_input_tokens: 0,
cache_read_input_tokens: 0,
server_tool_use: null
}
}
}
}
/**
* Format an Anthropic SSE event for HTTP streaming
*/
export function formatSSEEvent(event: RawMessageStreamEvent): string {
return `event: ${event.type}\ndata: ${JSON.stringify(event)}\n\n`
}
/**
* Create a done marker for SSE stream
*/
export function formatSSEDone(): string {
return 'data: [DONE]\n\n'
}
export default AiSdkToAnthropicSSE
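// Usage sketch (hypothetical Express handler; `result` is assumed to come from an
// AI SDK streamText call, so `result.fullStream` is a ReadableStream of TextStreamPart):
// const adapter = new AiSdkToAnthropicSSE({
//   model: 'claude-3-5-sonnet',
//   onEvent: (event) => res.write(formatSSEEvent(event))
// })
// await adapter.processStream(result.fullStream)
// res.write(formatSSEDone())
// res.end()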

View File

@@ -1,13 +0,0 @@
/**
* Shared Adapters
*
* This module exports adapters for converting between different AI API formats.
*/
export {
AiSdkToAnthropicSSE,
type AiSdkToAnthropicSSEOptions,
formatSSEDone,
formatSSEEvent,
type SSEEventCallback
} from './AiSdkToAnthropicSSE'

View File

@@ -1,95 +0,0 @@
import * as z from 'zod/v4'
enum ReasoningFormat {
Unknown = 'unknown',
OpenAIResponsesV1 = 'openai-responses-v1',
XAIResponsesV1 = 'xai-responses-v1',
AnthropicClaudeV1 = 'anthropic-claude-v1',
GoogleGeminiV1 = 'google-gemini-v1'
}
// Anthropic Claude was the first reasoning format that we pass back and forth
export const DEFAULT_REASONING_FORMAT = ReasoningFormat.AnthropicClaudeV1
function isDefinedOrNotNull<T>(value: T | null | undefined): value is T {
return value !== null && value !== undefined
}
export enum ReasoningDetailType {
Summary = 'reasoning.summary',
Encrypted = 'reasoning.encrypted',
Text = 'reasoning.text'
}
export const CommonReasoningDetailSchema = z
.object({
id: z.string().nullish(),
format: z.enum(ReasoningFormat).nullish(),
index: z.number().optional()
})
.loose()
export const ReasoningDetailSummarySchema = z
.object({
type: z.literal(ReasoningDetailType.Summary),
summary: z.string()
})
.extend(CommonReasoningDetailSchema.shape)
export type ReasoningDetailSummary = z.infer<typeof ReasoningDetailSummarySchema>
export const ReasoningDetailEncryptedSchema = z
.object({
type: z.literal(ReasoningDetailType.Encrypted),
data: z.string()
})
.extend(CommonReasoningDetailSchema.shape)
export type ReasoningDetailEncrypted = z.infer<typeof ReasoningDetailEncryptedSchema>
export const ReasoningDetailTextSchema = z
.object({
type: z.literal(ReasoningDetailType.Text),
text: z.string().nullish(),
signature: z.string().nullish()
})
.extend(CommonReasoningDetailSchema.shape)
export type ReasoningDetailText = z.infer<typeof ReasoningDetailTextSchema>
export const ReasoningDetailUnionSchema = z.union([
ReasoningDetailSummarySchema,
ReasoningDetailEncryptedSchema,
ReasoningDetailTextSchema
])
export type ReasoningDetailUnion = z.infer<typeof ReasoningDetailUnionSchema>
const ReasoningDetailsWithUnknownSchema = z.union([ReasoningDetailUnionSchema, z.unknown().transform(() => null)])
export const ReasoningDetailArraySchema = z
.array(ReasoningDetailsWithUnknownSchema)
.transform((d) => d.filter((d): d is ReasoningDetailUnion => !!d))
export const OutputUnionToReasoningDetailsSchema = z.union([
z
.object({
delta: z.object({
reasoning_details: z.array(ReasoningDetailsWithUnknownSchema)
})
})
.transform((data) => data.delta.reasoning_details.filter(isDefinedOrNotNull)),
z
.object({
message: z.object({
reasoning_details: z.array(ReasoningDetailsWithUnknownSchema)
})
})
.transform((data) => data.message.reasoning_details.filter(isDefinedOrNotNull)),
z
.object({
text: z.string(),
reasoning_details: z.array(ReasoningDetailsWithUnknownSchema)
})
.transform((data) => data.reasoning_details.filter(isDefinedOrNotNull))
])
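// Parsing sketch: entries that fail every branch of the union collapse to null via the
// z.unknown() transform and are filtered out, so only well-formed details survive.
// The input literal is hypothetical:
// ReasoningDetailArraySchema.parse([
//   { type: 'reasoning.text', text: 'step 1' },
//   { type: 'something-else' } // dropped: unknown -> null -> filtered
// ])
// // => [{ type: 'reasoning.text', text: 'step 1' }]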

View File

@@ -1,93 +1,17 @@
import type { MessageCreateParams } from '@anthropic-ai/sdk/resources'
import { loggerService } from '@logger'
import { buildSharedMiddlewares, type SharedMiddlewareConfig } from '@shared/middleware'
import { getAiSdkProviderId } from '@shared/provider'
import type { Provider } from '@types'
import type { Request, Response } from 'express'
import express from 'express'
import { messagesService } from '../services/messages'
import { generateUnifiedMessage, streamUnifiedMessages } from '../services/unified-messages'
import { getProviderById, isModelAnthropicCompatible, validateModelId } from '../utils'
/**
* Check if a specific model on a provider should use direct Anthropic SDK
*
* A provider+model combination is considered "Anthropic-compatible" if:
* 1. It's a native Anthropic provider (type === 'anthropic'), OR
* 2. It has anthropicApiHost configured AND the specific model supports Anthropic API
* (for aggregated providers like Silicon, only certain models support Anthropic endpoint)
*
* @param provider - The provider to check
* @param modelId - The model ID to check (without provider prefix)
* @returns true if should use direct Anthropic SDK, false for unified SDK
*/
function shouldUseDirectAnthropic(provider: Provider, modelId: string): boolean {
// Native Anthropic provider - always use direct SDK
if (provider.type === 'anthropic') {
return true
}
// No anthropicApiHost configured - use unified SDK
if (!provider.anthropicApiHost?.trim()) {
return false
}
// Has anthropicApiHost - check model-level compatibility
// For aggregated providers, only specific models support Anthropic API
return isModelAnthropicCompatible(provider, modelId)
}
import { getProviderById, validateModelId } from '../utils'
const logger = loggerService.withContext('ApiServerMessagesRoutes')
const router = express.Router()
const providerRouter = express.Router({ mergeParams: true })
/**
* Estimate token count from messages
* Simple approximation: ~4 characters per token for English text
*/
interface CountTokensInput {
messages: Array<{ role: string; content: string | Array<{ type: string; text?: string }> }>
system?: string | Array<{ type: string; text?: string }>
}
function estimateTokenCount(input: CountTokensInput): number {
const { messages, system } = input
let totalChars = 0
// Count system message tokens
if (system) {
if (typeof system === 'string') {
totalChars += system.length
} else if (Array.isArray(system)) {
for (const block of system) {
if (block.type === 'text' && block.text) {
totalChars += block.text.length
}
}
}
}
// Count message tokens
for (const msg of messages) {
if (typeof msg.content === 'string') {
totalChars += msg.content.length
} else if (Array.isArray(msg.content)) {
for (const block of msg.content) {
if (block.type === 'text' && block.text) {
totalChars += block.text.length
}
}
}
// Add overhead for role
totalChars += 10
}
// Estimate tokens (~4 chars per token, with some overhead)
return Math.ceil(totalChars / 4) + messages.length * 3
}
// Helper function for basic request validation
async function validateRequestBody(req: Request): Promise<{ valid: boolean; error?: any }> {
const request: MessageCreateParams = req.body
@@ -109,36 +33,21 @@ async function validateRequestBody(req: Request): Promise<{ valid: boolean; erro
}
interface HandleMessageProcessingOptions {
req: Request
res: Response
provider: Provider
request: MessageCreateParams
modelId?: string
}
/**
* Handle message processing using direct Anthropic SDK
* Used for providers with anthropicApiHost or native Anthropic providers
* This bypasses AI SDK conversion and uses native Anthropic protocol
*/
async function handleDirectAnthropicProcessing({
async function handleMessageProcessing({
req,
res,
provider,
request,
modelId,
extraHeaders
}: HandleMessageProcessingOptions & { extraHeaders?: Record<string, string | string[]> }): Promise<void> {
const actualModelId = modelId || request.model
logger.info('Processing message via direct Anthropic SDK', {
providerId: provider.id,
providerType: provider.type,
modelId: actualModelId,
stream: !!request.stream,
anthropicApiHost: provider.anthropicApiHost
})
modelId
}: HandleMessageProcessingOptions): Promise<void> {
try {
// Validate request
const validation = messagesService.validateRequest(request)
if (!validation.isValid) {
res.status(400).json({
@@ -151,126 +60,28 @@ async function handleDirectAnthropicProcessing({
return
}
// Process message using messagesService (native Anthropic SDK)
const extraHeaders = messagesService.prepareHeaders(req.headers)
const { client, anthropicRequest } = await messagesService.processMessage({
provider,
request,
extraHeaders,
modelId: actualModelId
modelId
})
if (request.stream) {
// Use native Anthropic streaming
await messagesService.handleStreaming(client, anthropicRequest, { response: res }, provider)
} else {
// Use native Anthropic non-streaming
const response = await client.messages.create(anthropicRequest)
res.json(response)
}
} catch (error: any) {
logger.error('Direct Anthropic processing error', { error })
const { statusCode, errorResponse } = messagesService.transformError(error)
res.status(statusCode).json(errorResponse)
}
}
/**
* Handle message processing using unified AI SDK
* Used for non-Anthropic providers that need format conversion
* - Uses AI SDK adapters with output converted to Anthropic SSE format
*/
async function handleUnifiedProcessing({
res,
provider,
request,
modelId
}: HandleMessageProcessingOptions): Promise<void> {
const actualModelId = modelId || request.model
logger.info('Processing message via unified AI SDK', {
providerId: provider.id,
providerType: provider.type,
modelId: actualModelId,
stream: !!request.stream
})
try {
// Validate request
const validation = messagesService.validateRequest(request)
if (!validation.isValid) {
res.status(400).json({
type: 'error',
error: {
type: 'invalid_request_error',
message: validation.errors.join('; ')
}
})
return
}
const middlewareConfig: SharedMiddlewareConfig = {
modelId: actualModelId,
providerId: provider.id,
aiSdkProviderId: getAiSdkProviderId(provider)
}
const middlewares = buildSharedMiddlewares(middlewareConfig)
logger.debug('Built middlewares for unified processing', {
middlewareCount: middlewares.length,
modelId: actualModelId,
providerId: provider.id
})
if (request.stream) {
await streamUnifiedMessages({
response: res,
provider,
modelId: actualModelId,
params: request,
middlewares,
onError: (error) => {
logger.error('Stream error', error as Error)
},
onComplete: () => {
logger.debug('Stream completed')
}
})
} else {
const response = await generateUnifiedMessage({
provider,
modelId: actualModelId,
params: request,
middlewares
})
res.json(response)
}
const response = await client.messages.create(anthropicRequest)
res.json(response)
} catch (error: any) {
logger.error('Message processing error', { error })
const { statusCode, errorResponse } = messagesService.transformError(error)
res.status(statusCode).json(errorResponse)
}
}
/**
* Handle message processing - routes to appropriate handler based on provider and model
*
* Routing logic:
* - Native Anthropic providers (type === 'anthropic'): Direct Anthropic SDK
* - Providers with anthropicApiHost AND model supports Anthropic API: Direct Anthropic SDK
* - Other providers/models: Unified AI SDK with Anthropic SSE conversion
*/
async function handleMessageProcessing({
res,
provider,
request,
modelId
}: HandleMessageProcessingOptions): Promise<void> {
const actualModelId = modelId || request.model
if (shouldUseDirectAnthropic(provider, actualModelId)) {
return handleDirectAnthropicProcessing({ res, provider, request, modelId })
}
return handleUnifiedProcessing({ res, provider, request, modelId })
}
/**
* @swagger
* /v1/messages:
@@ -424,7 +235,7 @@ router.post('/', async (req: Request, res: Response) => {
const provider = modelValidation.provider!
const modelId = modelValidation.modelId!
return handleMessageProcessing({ res, provider, request, modelId })
return handleMessageProcessing({ req, res, provider, request, modelId })
} catch (error: any) {
logger.error('Message processing error', { error })
const { statusCode, errorResponse } = messagesService.transformError(error)
@@ -582,7 +393,7 @@ providerRouter.post('/', async (req: Request, res: Response) => {
const request: MessageCreateParams = req.body
return handleMessageProcessing({ res, provider, request })
return handleMessageProcessing({ req, res, provider, request })
} catch (error: any) {
logger.error('Message processing error', { error })
const { statusCode, errorResponse } = messagesService.transformError(error)
@@ -590,132 +401,4 @@ providerRouter.post('/', async (req: Request, res: Response) => {
}
})
/**
* @swagger
* /v1/messages/count_tokens:
* post:
* summary: Count tokens for messages
* description: Count tokens for Anthropic Messages API format (required by Claude Code SDK)
* tags: [Messages]
* requestBody:
* required: true
* content:
* application/json:
* schema:
* type: object
* required:
* - model
* - messages
* properties:
* model:
* type: string
* description: Model ID
* messages:
* type: array
* items:
* type: object
* system:
* type: string
* description: System message
* responses:
* 200:
* description: Token count response
* content:
* application/json:
* schema:
* type: object
* properties:
* input_tokens:
* type: integer
* 400:
* description: Bad request
*/
router.post('/count_tokens', async (req: Request, res: Response) => {
try {
const { model, messages, system } = req.body
if (!model) {
return res.status(400).json({
type: 'error',
error: {
type: 'invalid_request_error',
message: 'model parameter is required'
}
})
}
if (!messages || !Array.isArray(messages)) {
return res.status(400).json({
type: 'error',
error: {
type: 'invalid_request_error',
message: 'messages parameter is required'
}
})
}
const estimatedTokens = estimateTokenCount({ messages, system })
logger.debug('Token count estimated', {
model,
messageCount: messages.length,
estimatedTokens
})
return res.json({
input_tokens: estimatedTokens
})
} catch (error: any) {
logger.error('Token counting error', { error })
return res.status(500).json({
type: 'error',
error: {
type: 'api_error',
message: error.message || 'Internal server error'
}
})
}
})
/**
* Provider-specific count_tokens endpoint
*/
providerRouter.post('/count_tokens', async (req: Request, res: Response) => {
try {
const { model, messages, system } = req.body
if (!messages || !Array.isArray(messages)) {
return res.status(400).json({
type: 'error',
error: {
type: 'invalid_request_error',
message: 'messages parameter is required'
}
})
}
const estimatedTokens = estimateTokenCount({ messages, system })
logger.debug('Token count estimated (provider route)', {
providerId: req.params.provider,
model,
messageCount: messages.length,
estimatedTokens
})
return res.json({
input_tokens: estimatedTokens
})
} catch (error: any) {
logger.error('Token counting error', { error })
return res.status(500).json({
type: 'error',
error: {
type: 'api_error',
message: error.message || 'Internal server error'
}
})
}
})
export { providerRouter as messagesProviderRoutes, router as messagesRoutes }

View File

@@ -2,10 +2,8 @@ import type Anthropic from '@anthropic-ai/sdk'
import type { MessageCreateParams, MessageStreamEvent } from '@anthropic-ai/sdk/resources'
import { loggerService } from '@logger'
import anthropicService from '@main/services/AnthropicService'
import { buildClaudeCodeSystemMessage, getSdkClient, sanitizeToolsForAnthropic } from '@shared/anthropic'
import { buildClaudeCodeSystemMessage, getSdkClient } from '@shared/anthropic'
import type { Provider } from '@types'
import { APICallError, RetryError } from 'ai'
import { net } from 'electron'
import type { Response } from 'express'
const logger = loggerService.withContext('MessagesService')
@@ -100,30 +98,11 @@ export class MessagesService {
async getClient(provider: Provider, extraHeaders?: Record<string, string | string[]>): Promise<Anthropic> {
// Create Anthropic client for the provider
// Wrap net.fetch to handle compatibility issues:
// 1. net.fetch expects string URLs, not Request objects
// 2. net.fetch doesn't support 'agent' option from Node.js http module
const electronFetch: typeof globalThis.fetch = async (input: URL | RequestInfo, init?: RequestInit) => {
const url = typeof input === 'string' ? input : input instanceof URL ? input.toString() : input.url
// Remove unsupported options for Electron's net.fetch
if (init) {
const initWithAgent = init as RequestInit & { agent?: unknown }
delete initWithAgent.agent
const headers = new Headers(initWithAgent.headers)
if (headers.has('content-length')) {
headers.delete('content-length')
}
initWithAgent.headers = headers
return net.fetch(url, initWithAgent)
}
return net.fetch(url)
}
const context = { fetch: electronFetch }
if (provider.authType === 'oauth') {
const oauthToken = await anthropicService.getValidAccessToken()
return getSdkClient(provider, oauthToken, extraHeaders, context)
return getSdkClient(provider, oauthToken, extraHeaders)
}
return getSdkClient(provider, null, extraHeaders, context)
return getSdkClient(provider, null, extraHeaders)
}
prepareHeaders(headers: Record<string, string | string[] | undefined>): Record<string, string | string[]> {
@@ -148,8 +127,7 @@ export class MessagesService {
createAnthropicRequest(request: MessageCreateParams, provider: Provider, modelId?: string): MessageCreateParams {
const anthropicRequest: MessageCreateParams = {
...request,
stream: !!request.stream,
tools: sanitizeToolsForAnthropic(request.tools)
stream: !!request.stream
}
// Override model if provided
@@ -255,71 +233,9 @@ export class MessagesService {
}
transformError(error: any): { statusCode: number; errorResponse: ErrorResponse } {
let statusCode: number | undefined = undefined
let errorType: string | undefined = undefined
let errorMessage: string | undefined = undefined
const errorMap: Record<number, string> = {
400: 'invalid_request_error',
401: 'authentication_error',
403: 'forbidden_error',
404: 'not_found_error',
429: 'rate_limit_error',
500: 'internal_server_error'
}
// Handle AI SDK RetryError - extract the last error for better error messages
if (RetryError.isInstance(error)) {
const lastError = error.lastError
// If the last error is an APICallError, extract its details
if (APICallError.isInstance(lastError)) {
statusCode = lastError.statusCode || 502
errorMessage = lastError.message
return {
statusCode,
errorResponse: {
type: 'error',
error: {
type: errorMap[statusCode] || 'api_error',
message: `${error.reason}: ${errorMessage}`,
requestId: lastError.name
}
}
}
}
// Fallback for other retry errors
errorMessage = error.message
statusCode = 502
return {
statusCode,
errorResponse: {
type: 'error',
error: {
type: 'api_error',
message: errorMessage,
requestId: error.name
}
}
}
}
if (APICallError.isInstance(error)) {
statusCode = error.statusCode
errorMessage = error.message
if (statusCode) {
return {
statusCode,
errorResponse: {
type: 'error',
error: {
type: errorMap[statusCode] || 'api_error',
message: errorMessage,
requestId: error.name
}
}
}
}
}
let statusCode = 500
let errorType = 'api_error'
let errorMessage = 'Internal server error'
const anthropicStatus = typeof error?.status === 'number' ? error.status : undefined
const anthropicError = error?.error
@@ -361,11 +277,11 @@ export class MessagesService {
typeof errorMessage === 'string' && errorMessage.length > 0 ? errorMessage : 'Internal server error'
return {
statusCode: statusCode ?? 500,
statusCode,
errorResponse: {
type: 'error',
error: {
type: errorType || 'api_error',
type: errorType,
message: safeErrorMessage,
requestId: error?.request_id
}

View File

@@ -1,6 +1,13 @@
import { isEmpty } from 'lodash'
import type { ApiModel, ApiModelsFilter, ApiModelsResponse } from '../../../renderer/src/types/apiModels'
import { loggerService } from '../../services/LoggerService'
import { getAvailableProviders, listAllAvailableModels, transformModelToOpenAI } from '../utils'
import {
getAvailableProviders,
getProviderAnthropicModelChecker,
listAllAvailableModels,
transformModelToOpenAI
} from '../utils'
const logger = loggerService.withContext('ModelsService')
@@ -13,12 +20,11 @@ export class ModelsService {
try {
logger.debug('Getting available models from providers', { filter })
const providers = await getAvailableProviders()
let providers = await getAvailableProviders()
// Note: When providerType === 'anthropic', we now return ALL available models
// because the API Server's unified adapter (AiSdkToAnthropicSSE) can convert
// any provider's response to Anthropic SSE format. This enables Claude Code Agent
// to work with OpenAI, Gemini, and other providers transparently.
if (filter.providerType === 'anthropic') {
providers = providers.filter((p) => p.type === 'anthropic' || !isEmpty(p.anthropicApiHost?.trim()))
}
const models = await listAllAvailableModels(providers)
// Use Map to deduplicate models by their full ID (provider:model_id)
@@ -26,11 +32,20 @@ export class ModelsService {
for (const model of models) {
const provider = providers.find((p) => p.id === model.provider)
// logger.debug(`Processing model ${model.id}`)
if (!provider) {
logger.debug(`Skipping model ${model.id}. Reason: Provider not found.`)
continue
}
if (filter.providerType === 'anthropic') {
const checker = getProviderAnthropicModelChecker(provider.id)
if (!checker(model)) {
logger.debug(`Skipping model ${model.id} from ${model.provider}. Reason: Not an Anthropic model.`)
continue
}
}
const openAIModel = transformModelToOpenAI(model, provider)
const fullModelId = openAIModel.id // This is already in format "provider:model_id"

View File

@@ -1,718 +0,0 @@
import type { AnthropicProviderOptions } from '@ai-sdk/anthropic'
import type { GoogleGenerativeAIProviderOptions } from '@ai-sdk/google'
import type { OpenAIResponsesProviderOptions } from '@ai-sdk/openai'
import type { LanguageModelV2Middleware, LanguageModelV2ToolResultOutput } from '@ai-sdk/provider'
import type { ProviderOptions, ReasoningPart, ToolCallPart, ToolResultPart } from '@ai-sdk/provider-utils'
import type {
ImageBlockParam,
MessageCreateParams,
TextBlockParam,
Tool as AnthropicTool
} from '@anthropic-ai/sdk/resources/messages'
import { type AiPlugin, createExecutor } from '@cherrystudio/ai-core'
import { createProvider as createProviderCore } from '@cherrystudio/ai-core/provider'
import { loggerService } from '@logger'
import { AiSdkToAnthropicSSE, formatSSEDone, formatSSEEvent } from '@main/apiServer/adapters'
import { generateSignature as cherryaiGenerateSignature } from '@main/integration/cherryai'
import anthropicService from '@main/services/AnthropicService'
import copilotService from '@main/services/CopilotService'
import { reduxService } from '@main/services/ReduxService'
import { isGemini3ModelId } from '@shared/middleware'
import {
type AiSdkConfig,
type AiSdkConfigContext,
formatProviderApiHost,
initializeSharedProviders,
isAnthropicProvider,
isGeminiProvider,
isOpenAIProvider,
type ProviderFormatContext,
providerToAiSdkConfig as sharedProviderToAiSdkConfig,
resolveActualProvider
} from '@shared/provider'
import { COPILOT_DEFAULT_HEADERS } from '@shared/provider/constant'
import { defaultAppHeaders } from '@shared/utils'
import type { Provider } from '@types'
import type { ImagePart, JSONValue, ModelMessage, Provider as AiSdkProvider, TextPart, Tool as AiSdkTool } from 'ai'
import { simulateStreamingMiddleware, stepCountIs, tool, wrapLanguageModel, zodSchema } from 'ai'
import { net } from 'electron'
import type { Response } from 'express'
import * as z from 'zod'
import { googleReasoningCache, openRouterReasoningCache } from '../../services/CacheService'
const logger = loggerService.withContext('UnifiedMessagesService')
const MAGIC_STRING = 'skip_thought_signature_validator'
function sanitizeJson(value: unknown): JSONValue {
return JSON.parse(JSON.stringify(value))
}
initializeSharedProviders({
warn: (message) => logger.warn(message),
error: (message, error) => logger.error(message, error)
})
/**
* Configuration for unified message streaming
*/
export interface UnifiedStreamConfig {
response: Response
provider: Provider
modelId: string
params: MessageCreateParams
onError?: (error: unknown) => void
onComplete?: () => void
/**
* Optional AI SDK middlewares to apply
*/
middlewares?: LanguageModelV2Middleware[]
/**
* Optional AI Core plugins to use with the executor
*/
plugins?: AiPlugin[]
}
/**
* Configuration for non-streaming message generation
*/
export interface GenerateUnifiedMessageConfig {
provider: Provider
modelId: string
params: MessageCreateParams
middlewares?: LanguageModelV2Middleware[]
plugins?: AiPlugin[]
}
function getMainProcessFormatContext(): ProviderFormatContext {
const vertexSettings = reduxService.selectSync<{ projectId: string; location: string }>('state.llm.settings.vertexai')
return {
vertex: {
project: vertexSettings?.projectId || 'default-project',
location: vertexSettings?.location || 'us-central1'
}
}
}
const mainProcessSdkContext: AiSdkConfigContext = {
getRotatedApiKey: (provider) => {
const keys = provider.apiKey.split(',').map((k) => k.trim())
return keys[0] || provider.apiKey
},
fetch: net.fetch as typeof globalThis.fetch
}
function getActualProvider(provider: Provider, modelId: string): Provider {
const model = provider.models?.find((m) => m.id === modelId)
if (!model) return provider
return resolveActualProvider(provider, model)
}
function providerToAiSdkConfig(provider: Provider, modelId: string): AiSdkConfig {
const actualProvider = getActualProvider(provider, modelId)
const formattedProvider = formatProviderApiHost(actualProvider, getMainProcessFormatContext())
return sharedProviderToAiSdkConfig(formattedProvider, modelId, mainProcessSdkContext)
}
function convertAnthropicToolResultToAiSdk(
content: string | Array<TextBlockParam | ImageBlockParam>
): LanguageModelV2ToolResultOutput {
if (typeof content === 'string') {
return { type: 'text', value: content }
}
const values: Array<{ type: 'text'; text: string } | { type: 'media'; data: string; mediaType: string }> = []
for (const block of content) {
if (block.type === 'text') {
values.push({ type: 'text', text: block.text })
} else if (block.type === 'image') {
values.push({
type: 'media',
data: block.source.type === 'base64' ? block.source.data : block.source.url,
mediaType: block.source.type === 'base64' ? block.source.media_type : 'image/png'
})
}
}
return { type: 'content', value: values }
}
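// Example (illustrative sketch; block contents are toy values): a mixed
// Anthropic tool_result converts to an AI SDK 'content' output as follows.
// convertAnthropicToolResultToAiSdk([
//   { type: 'text', text: 'Here is the chart' },
//   { type: 'image', source: { type: 'base64', media_type: 'image/png', data: 'iVBORw0...' } }
// ])
// => { type: 'content', value: [
//      { type: 'text', text: 'Here is the chart' },
//      { type: 'media', data: 'iVBORw0...', mediaType: 'image/png' }
//    ] }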
// Type alias for JSON Schema (compatible with recursive calls)
type JsonSchemaLike = AnthropicTool.InputSchema | Record<string, unknown>
/**
* Convert JSON Schema to Zod schema
* This avoids non-standard fields like input_examples that Anthropic doesn't support
*/
function jsonSchemaToZod(schema: JsonSchemaLike): z.ZodTypeAny {
const s = schema as Record<string, unknown>
const schemaType = s.type as string | string[] | undefined
const enumValues = s.enum as unknown[] | undefined
const description = s.description as string | undefined
// Handle enum first
if (enumValues && Array.isArray(enumValues) && enumValues.length > 0) {
if (enumValues.every((v) => typeof v === 'string')) {
const zodEnum = z.enum(enumValues as [string, ...string[]])
return description ? zodEnum.describe(description) : zodEnum
}
// For non-string enums, use union of literals
const literals = enumValues.map((v) => z.literal(v as string | number | boolean))
if (literals.length === 1) {
return description ? literals[0].describe(description) : literals[0]
}
const zodUnion = z.union(literals as unknown as [z.ZodTypeAny, z.ZodTypeAny, ...z.ZodTypeAny[]])
return description ? zodUnion.describe(description) : zodUnion
}
// Handle union types (type: ["string", "null"])
if (Array.isArray(schemaType)) {
const schemas = schemaType.map((t) => jsonSchemaToZod({ ...s, type: t, enum: undefined }))
if (schemas.length === 1) {
return schemas[0]
}
return z.union(schemas as [z.ZodTypeAny, z.ZodTypeAny, ...z.ZodTypeAny[]])
}
// Handle by type
switch (schemaType) {
case 'string': {
let zodString = z.string()
if (typeof s.minLength === 'number') zodString = zodString.min(s.minLength)
if (typeof s.maxLength === 'number') zodString = zodString.max(s.maxLength)
if (typeof s.pattern === 'string') zodString = zodString.regex(new RegExp(s.pattern))
return description ? zodString.describe(description) : zodString
}
case 'number':
case 'integer': {
let zodNumber = schemaType === 'integer' ? z.number().int() : z.number()
if (typeof s.minimum === 'number') zodNumber = zodNumber.min(s.minimum)
if (typeof s.maximum === 'number') zodNumber = zodNumber.max(s.maximum)
return description ? zodNumber.describe(description) : zodNumber
}
case 'boolean': {
const zodBoolean = z.boolean()
return description ? zodBoolean.describe(description) : zodBoolean
}
case 'null':
return z.null()
case 'array': {
const items = s.items as Record<string, unknown> | undefined
let zodArray = items ? z.array(jsonSchemaToZod(items)) : z.array(z.unknown())
if (typeof s.minItems === 'number') zodArray = zodArray.min(s.minItems)
if (typeof s.maxItems === 'number') zodArray = zodArray.max(s.maxItems)
return description ? zodArray.describe(description) : zodArray
}
case 'object': {
const properties = s.properties as Record<string, Record<string, unknown>> | undefined
const required = (s.required as string[]) || []
// Always use z.object() to ensure "properties" field is present in output schema
// OpenAI requires explicit properties field even for empty objects
const shape: Record<string, z.ZodTypeAny> = {}
if (properties) {
for (const [key, propSchema] of Object.entries(properties)) {
const zodProp = jsonSchemaToZod(propSchema)
shape[key] = required.includes(key) ? zodProp : zodProp.optional()
}
}
const zodObject = z.object(shape)
return description ? zodObject.describe(description) : zodObject
}
default:
// Unknown type, use z.unknown()
return z.unknown()
}
}
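// Example (illustrative sketch, toy schema): how a typical tool input schema
// maps through jsonSchemaToZod. Non-required properties become optional.
// const weatherInput = jsonSchemaToZod({
//   type: 'object',
//   properties: {
//     city: { type: 'string', description: 'City name' },
//     days: { type: 'integer', minimum: 1, maximum: 14 }
//   },
//   required: ['city']
// })
// weatherInput.parse({ city: 'Paris' })          // ok: days is optional
// weatherInput.parse({ city: 'Paris', days: 3 }) // ok
// weatherInput.parse({ days: 3 })                // throws: 'city' is required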
function convertAnthropicToolsToAiSdk(tools: MessageCreateParams['tools']): Record<string, AiSdkTool> | undefined {
if (!tools || tools.length === 0) return undefined
const aiSdkTools: Record<string, AiSdkTool> = {}
for (const anthropicTool of tools) {
if (anthropicTool.type === 'bash_20250124') continue
const toolDef = anthropicTool as AnthropicTool
const rawSchema = toolDef.input_schema
const schema = jsonSchemaToZod(rawSchema)
// Use tool() with inputSchema (AI SDK v5 API)
const aiTool = tool({
description: toolDef.description || '',
inputSchema: zodSchema(schema)
})
aiSdkTools[toolDef.name] = aiTool
}
return Object.keys(aiSdkTools).length > 0 ? aiSdkTools : undefined
}
function convertAnthropicToAiMessages(params: MessageCreateParams): ModelMessage[] {
const messages: ModelMessage[] = []
// System message
if (params.system) {
if (typeof params.system === 'string') {
messages.push({ role: 'system', content: params.system })
} else if (Array.isArray(params.system)) {
const systemText = params.system
.filter((block) => block.type === 'text')
.map((block) => block.text)
.join('\n')
if (systemText) {
messages.push({ role: 'system', content: systemText })
}
}
}
const toolCallIdToName = new Map<string, string>()
for (const msg of params.messages) {
if (Array.isArray(msg.content)) {
for (const block of msg.content) {
if (block.type === 'tool_use') {
toolCallIdToName.set(block.id, block.name)
}
}
}
}
// User/assistant messages
for (const msg of params.messages) {
if (typeof msg.content === 'string') {
messages.push({
role: msg.role === 'user' ? 'user' : 'assistant',
content: msg.content
})
} else if (Array.isArray(msg.content)) {
const textParts: TextPart[] = []
const imageParts: ImagePart[] = []
const reasoningParts: ReasoningPart[] = []
const toolCallParts: ToolCallPart[] = []
const toolResultParts: ToolResultPart[] = []
for (const block of msg.content) {
if (block.type === 'text') {
textParts.push({ type: 'text', text: block.text })
} else if (block.type === 'thinking') {
reasoningParts.push({ type: 'reasoning', text: block.thinking })
} else if (block.type === 'redacted_thinking') {
reasoningParts.push({ type: 'reasoning', text: block.data })
} else if (block.type === 'image') {
const source = block.source
if (source.type === 'base64') {
imageParts.push({ type: 'image', image: `data:${source.media_type};base64,${source.data}` })
} else if (source.type === 'url') {
imageParts.push({ type: 'image', image: source.url })
}
} else if (block.type === 'tool_use') {
const options: ProviderOptions = {}
if (isGemini3ModelId(params.model)) {
if (googleReasoningCache.get(`google-${block.name}`)) {
options.google = {
thoughtSignature: MAGIC_STRING
}
} else if (openRouterReasoningCache.get('openrouter')) {
options.openrouter = {
reasoning_details: (sanitizeJson(openRouterReasoningCache.get('openrouter')) as JSONValue[]) || []
}
}
}
toolCallParts.push({
type: 'tool-call',
toolName: block.name,
toolCallId: block.id,
input: block.input,
providerOptions: options
})
} else if (block.type === 'tool_result') {
// Look up toolName from the pre-built map (covers cross-message references)
const toolName = toolCallIdToName.get(block.tool_use_id) || 'unknown'
toolResultParts.push({
type: 'tool-result',
toolCallId: block.tool_use_id,
toolName,
output: block.content ? convertAnthropicToolResultToAiSdk(block.content) : { type: 'text', value: '' }
})
}
}
if (toolResultParts.length > 0) {
messages.push({ role: 'tool', content: [...toolResultParts] })
}
if (msg.role === 'user') {
const userContent = [...textParts, ...imageParts]
if (userContent.length > 0) {
messages.push({ role: 'user', content: userContent })
}
} else {
const assistantContent = [...reasoningParts, ...textParts, ...toolCallParts]
if (assistantContent.length > 0) {
let providerOptions: ProviderOptions | undefined = undefined
if (openRouterReasoningCache.get('openrouter')) {
providerOptions = {
openrouter: {
reasoning_details: (sanitizeJson(openRouterReasoningCache.get('openrouter')) as JSONValue[]) || []
}
}
} else if (isGemini3ModelId(params.model)) {
providerOptions = {
google: {
thoughtSignature: MAGIC_STRING
}
}
}
messages.push({ role: 'assistant', content: assistantContent, providerOptions })
}
}
}
}
return messages
}
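// Example (illustrative sketch; model id and prompt are toy values):
// convertAnthropicToAiMessages({
//   model: 'claude-example',
//   max_tokens: 1024,
//   system: 'You are terse.',
//   messages: [{ role: 'user', content: 'Hi' }]
// })
// => [
//   { role: 'system', content: 'You are terse.' },
//   { role: 'user', content: 'Hi' }
// ]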
interface ExecuteStreamConfig {
provider: Provider
modelId: string
params: MessageCreateParams
middlewares?: LanguageModelV2Middleware[]
plugins?: AiPlugin[]
onEvent?: (event: Parameters<typeof formatSSEEvent>[0]) => void
}
/**
* Create AI SDK provider instance from config
* Similar to renderer's createAiSdkProvider
*/
async function createAiSdkProvider(config: AiSdkConfig): Promise<AiSdkProvider> {
let providerId = config.providerId
// Handle special provider modes (same as renderer)
if (providerId === 'openai' && config.options?.mode === 'chat') {
providerId = 'openai-chat'
} else if (providerId === 'azure' && config.options?.mode === 'responses') {
providerId = 'azure-responses'
} else if (providerId === 'cherryin' && config.options?.mode === 'chat') {
providerId = 'cherryin-chat'
}
const provider = await createProviderCore(providerId, config.options)
return provider
}
/**
* Prepare special provider configuration for providers that need dynamic tokens
* Similar to renderer's prepareSpecialProviderConfig
*/
async function prepareSpecialProviderConfig(provider: Provider, config: AiSdkConfig): Promise<AiSdkConfig> {
switch (provider.id) {
case 'copilot': {
const storedHeaders =
((await reduxService.select('state.copilot.defaultHeaders')) as Record<string, string> | null) ?? {}
const headers: Record<string, string> = {
...COPILOT_DEFAULT_HEADERS,
...storedHeaders
}
try {
const { token } = await copilotService.getToken(null as any, headers)
config.options.apiKey = token
const existingHeaders = (config.options.headers as Record<string, string> | undefined) ?? {}
config.options.headers = {
...headers,
...existingHeaders
}
} catch (error) {
logger.error('Failed to get Copilot token', error as Error)
throw new Error('Failed to get Copilot token. Please re-authorize Copilot.')
}
break
}
case 'anthropic': {
if (provider.authType === 'oauth') {
try {
const oauthToken = await anthropicService.getValidAccessToken()
if (!oauthToken) {
throw new Error('Anthropic OAuth token not available. Please re-authorize.')
}
config.options = {
...config.options,
headers: {
...(config.options.headers ? config.options.headers : {}),
'Content-Type': 'application/json',
'anthropic-version': '2023-06-01',
'anthropic-beta': 'oauth-2025-04-20',
Authorization: `Bearer ${oauthToken}`
},
baseURL: 'https://api.anthropic.com/v1',
apiKey: ''
}
} catch (error) {
logger.error('Failed to get Anthropic OAuth token', error as Error)
throw new Error('Failed to get Anthropic OAuth token. Please re-authorize.')
}
}
break
}
case 'cherryai': {
// Create a signed fetch wrapper for cherryai
const baseFetch = net.fetch as typeof globalThis.fetch
config.options.fetch = async (url: RequestInfo | URL, options?: RequestInit) => {
if (!options?.body) {
return baseFetch(url, options)
}
const signature = cherryaiGenerateSignature({
method: 'POST',
path: '/chat/completions',
query: '',
body: JSON.parse(options.body as string)
})
return baseFetch(url, {
...options,
headers: {
...(options.headers as Record<string, string>),
...signature
}
})
}
break
}
}
return config
}
function mapAnthropicThinkToAISdkProviderOptions(
provider: Provider,
config: MessageCreateParams['thinking']
): ProviderOptions | undefined {
if (!config) return undefined
if (isAnthropicProvider(provider)) {
return {
anthropic: {
...mapToAnthropicProviderOptions(config)
}
}
}
if (isGeminiProvider(provider)) {
return {
google: {
...mapToGeminiProviderOptions(config)
}
}
}
if (isOpenAIProvider(provider)) {
return {
openai: {
...mapToOpenAIProviderOptions(config)
}
}
}
return undefined
}
function mapToAnthropicProviderOptions(config: NonNullable<MessageCreateParams['thinking']>): AnthropicProviderOptions {
return {
thinking: {
type: config.type,
budgetTokens: config.type === 'enabled' ? config.budget_tokens : undefined
}
}
}
function mapToGeminiProviderOptions(
config: NonNullable<MessageCreateParams['thinking']>
): GoogleGenerativeAIProviderOptions {
return {
thinkingConfig: {
thinkingBudget: config.type === 'enabled' ? config.budget_tokens : -1,
includeThoughts: config.type === 'enabled'
}
}
}
function mapToOpenAIProviderOptions(
config: NonNullable<MessageCreateParams['thinking']>
): OpenAIResponsesProviderOptions {
return {
reasoningEffort: config.type === 'enabled' ? 'high' : 'none'
}
}
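// Example (illustrative): one Anthropic thinking config fans out per provider.
// Input { type: 'enabled', budget_tokens: 2048 } yields:
//   anthropic -> { thinking: { type: 'enabled', budgetTokens: 2048 } }
//   google    -> { thinkingConfig: { thinkingBudget: 2048, includeThoughts: true } }
//   openai    -> { reasoningEffort: 'high' }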
/**
* Core stream execution function - single source of truth for AI SDK calls
*/
async function executeStream(config: ExecuteStreamConfig): Promise<AiSdkToAnthropicSSE> {
const { provider, modelId, params, middlewares = [], plugins = [], onEvent } = config
// Convert provider config to AI SDK config
let sdkConfig = providerToAiSdkConfig(provider, modelId)
// Prepare special provider config (Copilot, Anthropic OAuth, etc.)
sdkConfig = await prepareSpecialProviderConfig(provider, sdkConfig)
// Create provider instance and get language model
const aiSdkProvider = await createAiSdkProvider(sdkConfig)
const baseModel = aiSdkProvider.languageModel(modelId)
// Apply middlewares if present
const model =
middlewares.length > 0 && typeof baseModel === 'object'
? (wrapLanguageModel({ model: baseModel, middleware: middlewares }) as typeof baseModel)
: baseModel
// Create executor with plugins
const executor = createExecutor(sdkConfig.providerId, sdkConfig.options, plugins)
// Convert messages and tools
const coreMessages = convertAnthropicToAiMessages(params)
const tools = convertAnthropicToolsToAiSdk(params.tools)
// Create the adapter
const adapter = new AiSdkToAnthropicSSE({
model: `${provider.id}:${modelId}`,
onEvent: onEvent || (() => {})
})
// Execute stream - pass model object instead of string
const result = await executor.streamText({
model, // Now passing LanguageModel object, not string
messages: coreMessages,
// FIXME: The max_tokens value passed in by Claude Code can exceed some models' limits and needs special handling.
// This might be easier to fix in v2; the maintenance cost of a workaround is fairly high right now.
// Known affected: Doubao
maxOutputTokens: params.max_tokens,
temperature: params.temperature,
topP: params.top_p,
topK: params.top_k,
stopSequences: params.stop_sequences,
stopWhen: stepCountIs(100),
headers: defaultAppHeaders(),
tools,
providerOptions: mapAnthropicThinkToAISdkProviderOptions(provider, params.thinking)
})
// Process the stream through the adapter
await adapter.processStream(result.fullStream)
return adapter
}
/**
* Stream a message request using AI SDK executor and convert to Anthropic SSE format
*/
export async function streamUnifiedMessages(config: UnifiedStreamConfig): Promise<void> {
const { response, provider, modelId, params, onError, onComplete, middlewares = [], plugins = [] } = config
logger.info('Starting unified message stream', {
providerId: provider.id,
providerType: provider.type,
modelId,
stream: params.stream,
middlewareCount: middlewares.length,
pluginCount: plugins.length
})
try {
response.setHeader('Content-Type', 'text/event-stream')
response.setHeader('Cache-Control', 'no-cache')
response.setHeader('Connection', 'keep-alive')
response.setHeader('X-Accel-Buffering', 'no')
await executeStream({
provider,
modelId,
params,
middlewares,
plugins,
onEvent: (event) => {
logger.silly('Streaming event', { eventType: event.type })
const sseData = formatSSEEvent(event)
response.write(sseData)
}
})
// Send done marker
response.write(formatSSEDone())
response.end()
logger.info('Unified message stream completed', { providerId: provider.id, modelId })
onComplete?.()
} catch (error) {
logger.error('Error in unified message stream', error as Error, { providerId: provider.id, modelId })
onError?.(error)
throw error
}
}
/**
* Generate a non-streaming message response
*
* Uses simulateStreamingMiddleware to reuse the same streaming logic,
* similar to renderer's ModernAiProvider pattern.
*/
export async function generateUnifiedMessage(
providerOrConfig: Provider | GenerateUnifiedMessageConfig,
modelId?: string,
params?: MessageCreateParams
): Promise<ReturnType<typeof AiSdkToAnthropicSSE.prototype.buildNonStreamingResponse>> {
// Support both old signature and new config-based signature
let config: GenerateUnifiedMessageConfig
if ('provider' in providerOrConfig && 'modelId' in providerOrConfig && 'params' in providerOrConfig) {
config = providerOrConfig
} else {
config = {
provider: providerOrConfig as Provider,
modelId: modelId!,
params: params!
}
}
const { provider, middlewares = [], plugins = [] } = config
logger.info('Starting unified message generation', {
providerId: provider.id,
providerType: provider.type,
modelId: config.modelId,
middlewareCount: middlewares.length,
pluginCount: plugins.length
})
try {
// Add simulateStreamingMiddleware to reuse streaming logic for non-streaming
const allMiddlewares = [simulateStreamingMiddleware(), ...middlewares]
const adapter = await executeStream({
provider,
modelId: config.modelId,
params: config.params,
middlewares: allMiddlewares,
plugins
})
const finalResponse = adapter.buildNonStreamingResponse()
logger.info('Unified message generation completed', {
providerId: provider.id,
modelId: config.modelId
})
return finalResponse
} catch (error) {
logger.error('Error in unified message generation', error as Error, {
providerId: provider.id,
modelId: config.modelId
})
throw error
}
}
export default {
streamUnifiedMessages,
generateUnifiedMessage
}

View File

@@ -1,7 +1,7 @@
import { CacheService } from '@main/services/CacheService'
import { loggerService } from '@main/services/LoggerService'
import { reduxService } from '@main/services/ReduxService'
import { isPpioAnthropicCompatibleModel, isSiliconAnthropicCompatibleModel } from '@shared/config/providers'
import { isSiliconAnthropicCompatibleModel } from '@shared/config/providers'
import type { ApiModel, Model, Provider } from '@types'
const logger = loggerService.withContext('ApiServerUtils')
@@ -28,9 +28,10 @@ export async function getAvailableProviders(): Promise<Provider[]> {
return []
}
// Support all provider types that AI SDK can handle
// The unified-messages service uses AI SDK which supports many providers
const supportedProviders = providers.filter((p: Provider) => p.enabled)
// Support OpenAI and Anthropic type providers for API server
const supportedProviders = providers.filter(
(p: Provider) => p.enabled && (p.type === 'openai' || p.type === 'anthropic')
)
// Cache the filtered results
CacheService.set(PROVIDERS_CACHE_KEY, supportedProviders, PROVIDERS_CACHE_TTL)
@@ -159,7 +160,7 @@ export async function validateModelId(model: string): Promise<{
valid: false,
error: {
type: 'provider_not_found',
message: `Provider '${providerId}' not found or not enabled.`,
message: `Provider '${providerId}' not found, not enabled, or not supported. Only OpenAI and Anthropic providers are currently supported.`,
code: 'provider_not_found'
}
}
@@ -261,8 +262,14 @@ export function validateProvider(provider: Provider): boolean {
return false
}
// AI SDK supports many provider types, no longer need to filter by type
// The unified-messages service handles all supported types
// Support OpenAI and Anthropic type providers
if (provider.type !== 'openai' && provider.type !== 'anthropic') {
logger.debug('Provider type not supported', {
providerId: provider.id,
providerType: provider.type
})
return false
}
return true
} catch (error: any) {
@@ -283,39 +290,8 @@ export const getProviderAnthropicModelChecker = (providerId: string): ((m: Model
return (m: Model) => m.id.includes('claude')
case 'silicon':
return (m: Model) => isSiliconAnthropicCompatibleModel(m.id)
case 'ppio':
return (m: Model) => isPpioAnthropicCompatibleModel(m.id)
default:
// allow all models when checker not configured
return () => true
}
}
/**
* Check if a specific model is compatible with Anthropic API for a given provider.
*
* This is used for fine-grained routing decisions at the model level.
* For aggregated providers (like Silicon), only certain models support the Anthropic API endpoint.
*
* @param provider - The provider to check
* @param modelId - The model ID to check (without provider prefix)
* @returns true if the model supports Anthropic API endpoint
*/
export function isModelAnthropicCompatible(provider: Provider, modelId: string): boolean {
const checker = getProviderAnthropicModelChecker(provider.id)
const model = provider.models?.find((m) => m.id === modelId)
if (model) {
return checker(model)
}
const minimalModel: Model = {
id: modelId,
name: modelId,
provider: provider.id,
group: ''
}
return checker(minimalModel)
}

View File

@@ -73,6 +73,7 @@ import {
import storeSyncService from './services/StoreSyncService'
import { themeService } from './services/ThemeService'
import VertexAIService from './services/VertexAIService'
import VolcengineService from './services/VolcengineService'
import WebSocketService from './services/WebSocketService'
import { setOpenLinkExternal } from './services/WebviewService'
import { windowService } from './services/WindowService'
@@ -1077,6 +1078,14 @@ export function registerIpc(mainWindow: BrowserWindow, app: Electron.App) {
ipcMain.handle(IpcChannel.WebSocket_SendFile, WebSocketService.sendFile)
ipcMain.handle(IpcChannel.WebSocket_GetAllCandidates, WebSocketService.getAllCandidates)
// Volcengine
ipcMain.handle(IpcChannel.Volcengine_SaveCredentials, VolcengineService.saveCredentials)
ipcMain.handle(IpcChannel.Volcengine_HasCredentials, VolcengineService.hasCredentials)
ipcMain.handle(IpcChannel.Volcengine_ClearCredentials, VolcengineService.clearCredentials)
ipcMain.handle(IpcChannel.Volcengine_ListModels, VolcengineService.listModels)
ipcMain.handle(IpcChannel.Volcengine_GetAuthHeaders, VolcengineService.getAuthHeaders)
ipcMain.handle(IpcChannel.Volcengine_MakeRequest, VolcengineService.makeRequest)
ipcMain.handle(IpcChannel.APP_CrashRenderProcess, () => {
mainWindow.webContents.forcefullyCrashRenderer()
})
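For reference, a minimal sketch of calling the new Volcengine handlers from the renderer, assuming the preload script exposes ipcRenderer.invoke for these channels (the window.electron exposure path, import path, and argument values below are assumptions, not part of this diff):

import { IpcChannel } from '@shared/IpcChannel'

// Returns { models, total?, warnings? }; warnings surface partial API failures.
const result = await window.electron.ipcRenderer.invoke(
  IpcChannel.Volcengine_ListModels,
  'default',    // projectName
  'cn-beijing'  // region
)
if (result.warnings?.length) {
  console.warn('Volcengine model list is partial:', result.warnings)
}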

View File

@@ -1,19 +1,9 @@
import type { ReasoningDetailUnion } from '@main/apiServer/adapters/openrouter'
interface CacheItem<T> {
data: T
timestamp: number
duration: number
}
/**
* Interface for reasoning cache
*/
export interface IReasoningCache<T> {
set(key: string, value: T): void
get(key: string): T | undefined
}
export class CacheService {
private static cache: Map<string, CacheItem<any>> = new Map()
@@ -82,14 +72,3 @@ export class CacheService {
return true
}
}
// Singleton cache instances using CacheService
export const googleReasoningCache: IReasoningCache<string> = {
set: (key, value) => CacheService.set(`google-reasoning:${key}`, value, 30 * 60 * 1000),
get: (key) => CacheService.get(`google-reasoning:${key}`) || undefined
}
export const openRouterReasoningCache: IReasoningCache<ReasoningDetailUnion[]> = {
set: (key, value) => CacheService.set(`openrouter-reasoning:${key}`, value, 30 * 60 * 1000),
get: (key) => CacheService.get(`openrouter-reasoning:${key}`) || undefined
}

View File

@@ -548,17 +548,6 @@ class CodeToolsService {
logger.debug(`Environment variables:`, Object.keys(env))
logger.debug(`Options:`, options)
// Validate directory exists before proceeding
if (!directory || !fs.existsSync(directory)) {
const errorMessage = `Directory does not exist: ${directory}`
logger.error(errorMessage)
return {
success: false,
message: errorMessage,
command: ''
}
}
const packageName = await this.getPackageName(cliTool)
const bunPath = await this.getBunPath()
const executableName = await this.getCliExecutableName(cliTool)
@@ -720,7 +709,6 @@ class CodeToolsService {
// Build bat file content, including debug information
const batContent = [
'@echo off',
'chcp 65001 >nul 2>&1', // Switch to UTF-8 code page for international path support
`title ${cliTool} - Cherry Studio`, // Set window title in bat file
'echo ================================================',
'echo Cherry Studio CLI Tool Launcher',

View File

@@ -620,7 +620,7 @@ class McpService {
tools.map((tool: SDKTool) => {
const serverTool: MCPTool = {
...tool,
id: buildFunctionCallToolName(server.name, tool.name, server.id),
id: buildFunctionCallToolName(server.name, tool.name),
serverId: server.id,
serverName: server.name,
type: 'mcp'

View File

@@ -0,0 +1,732 @@
import { loggerService } from '@logger'
import crypto from 'crypto'
import { app, net, safeStorage } from 'electron'
import fs from 'fs'
import path from 'path'
import * as z from 'zod'
import { getConfigDir } from '../utils/file'
const logger = loggerService.withContext('VolcengineService')
// Configuration constants
const CONFIG = {
ALGORITHM: 'HMAC-SHA256',
REQUEST_TYPE: 'request',
DEFAULT_REGION: 'cn-beijing',
SERVICE_NAME: 'ark',
DEFAULT_HEADERS: {
'content-type': 'application/json',
accept: 'application/json'
},
API_URLS: {
ARK_HOST: 'open.volcengineapi.com'
},
CREDENTIALS_FILE_NAME: '.volcengine_credentials',
API_VERSION: '2024-01-01',
DEFAULT_PAGE_SIZE: 100
} as const
// Request schemas
const ListFoundationModelsRequestSchema = z.object({
PageNumber: z.optional(z.number()),
PageSize: z.optional(z.number())
})
const ListEndpointsRequestSchema = z.object({
ProjectName: z.optional(z.string()),
PageNumber: z.optional(z.number()),
PageSize: z.optional(z.number())
})
// Response schemas - only keep fields needed for model list
const FoundationModelItemSchema = z.object({
Name: z.string(),
DisplayName: z.optional(z.string()),
Description: z.optional(z.string())
})
const EndpointItemSchema = z.object({
Id: z.string(),
Name: z.optional(z.string()),
Description: z.optional(z.string()),
ModelReference: z.optional(
z.object({
FoundationModel: z.optional(
z.object({
Name: z.optional(z.string()),
ModelVersion: z.optional(z.string())
})
),
CustomModelId: z.optional(z.string())
})
)
})
const ListFoundationModelsResponseSchema = z.object({
Result: z.object({
TotalCount: z.number(),
Items: z.array(FoundationModelItemSchema)
})
})
const ListEndpointsResponseSchema = z.object({
Result: z.object({
TotalCount: z.number(),
Items: z.array(EndpointItemSchema)
})
})
// Infer types from schemas
type ListFoundationModelsRequest = z.infer<typeof ListFoundationModelsRequestSchema>
type ListEndpointsRequest = z.infer<typeof ListEndpointsRequestSchema>
type ListFoundationModelsResponse = z.infer<typeof ListFoundationModelsResponseSchema>
type ListEndpointsResponse = z.infer<typeof ListEndpointsResponseSchema>
// ============= Internal Type Definitions =============
interface VolcengineCredentials {
accessKeyId: string
secretAccessKey: string
}
interface SignedRequestParams {
method: 'GET' | 'POST'
host: string
path: string
query: Record<string, string>
headers: Record<string, string>
body?: string
service: string
region: string
}
interface SignedHeaders {
Authorization: string
'X-Date': string
'X-Content-Sha256': string
Host: string
}
interface ModelInfo {
id: string
name: string
description?: string
created?: number
}
interface ListModelsResult {
models: ModelInfo[]
total?: number
warnings?: string[]
}
// Custom error class
class VolcengineServiceError extends Error {
constructor(
message: string,
public readonly cause?: unknown
) {
super(message)
this.name = 'VolcengineServiceError'
}
}
/**
* Volcengine API Signing Service
*
* Implements HMAC-SHA256 signing algorithm for Volcengine API authentication.
* Securely stores credentials using Electron's safeStorage.
*/
class VolcengineService {
private readonly credentialsFilePath: string
constructor() {
this.credentialsFilePath = this.getCredentialsFilePath()
}
/**
* Get the path for storing encrypted credentials
*/
private getCredentialsFilePath(): string {
const oldPath = path.join(app.getPath('userData'), CONFIG.CREDENTIALS_FILE_NAME)
if (fs.existsSync(oldPath)) {
return oldPath
}
return path.join(getConfigDir(), CONFIG.CREDENTIALS_FILE_NAME)
}
// ============= Cryptographic Helper Methods =============
/**
* Calculate SHA256 hash of data and return hex encoded string
*/
private sha256Hash(data: string | Buffer): string {
return crypto.createHash('sha256').update(data).digest('hex')
}
/**
* Calculate HMAC-SHA256 and return buffer
*/
private hmacSha256(key: Buffer | string, data: string): Buffer {
return crypto.createHmac('sha256', key).update(data, 'utf8').digest()
}
/**
* Calculate HMAC-SHA256 and return hex encoded string
*/
private hmacSha256Hex(key: Buffer | string, data: string): string {
return crypto.createHmac('sha256', key).update(data, 'utf8').digest('hex')
}
/**
* URL encode according to RFC3986
*/
private uriEncode(str: string, encodeSlash: boolean = true): string {
if (!str) return ''
return str
.split('')
.map((char) => {
if (
(char >= 'A' && char <= 'Z') ||
(char >= 'a' && char <= 'z') ||
(char >= '0' && char <= '9') ||
char === '_' ||
char === '-' ||
char === '~' ||
char === '.'
) {
return char
}
if (char === '/' && !encodeSlash) {
return char
}
return encodeURIComponent(char)
})
.join('')
}
// ============= Signing Implementation =============
/**
* Get current UTC time in ISO8601 format (YYYYMMDD'T'HHMMSS'Z')
*/
private getIso8601DateTime(): string {
const now = new Date()
return now
.toISOString()
.replace(/[-:]/g, '')
.replace(/\.\d{3}/, '')
}
/**
* Get date portion from datetime (YYYYMMDD)
*/
private getDateFromDateTime(dateTime: string): string {
return dateTime.substring(0, 8)
}
/**
* Build canonical query string from query parameters
*/
private buildCanonicalQueryString(query: Record<string, string>): string {
if (!query || Object.keys(query).length === 0) {
return ''
}
return Object.keys(query)
.sort()
.map((key) => `${this.uriEncode(key)}=${this.uriEncode(query[key])}`)
.join('&')
}
/**
* Build canonical headers string
*/
private buildCanonicalHeaders(headers: Record<string, string>): {
canonicalHeaders: string
signedHeaders: string
} {
const sortedKeys = Object.keys(headers)
.map((k) => k.toLowerCase())
.sort()
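// Note: the lookup below indexes the original headers map with the lowercased key.
// This works because generateSignature builds headersToSign with lowercase keys;
// a mixed-case caller key would resolve to undefined here.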
const canonicalHeaders = sortedKeys.map((key) => `${key}:${headers[key]?.trim() || ''}`).join('\n') + '\n'
const signedHeaders = sortedKeys.join(';')
return { canonicalHeaders, signedHeaders }
}
/**
* Create the signing key through a series of HMAC operations
*
* kSecret = SecretAccessKey
* kDate = HMAC(kSecret, Date)
* kRegion = HMAC(kDate, Region)
* kService = HMAC(kRegion, Service)
* kSigning = HMAC(kService, "request")
*/
private deriveSigningKey(secretKey: string, date: string, region: string, service: string): Buffer {
const kDate = this.hmacSha256(secretKey, date)
const kRegion = this.hmacSha256(kDate, region)
const kService = this.hmacSha256(kRegion, service)
const kSigning = this.hmacSha256(kService, CONFIG.REQUEST_TYPE)
return kSigning
}
/**
* Create canonical request string
*
* CanonicalRequest =
* HTTPRequestMethod + '\n' +
* CanonicalURI + '\n' +
* CanonicalQueryString + '\n' +
* CanonicalHeaders + '\n' +
* SignedHeaders + '\n' +
* HexEncode(Hash(RequestPayload))
*/
private createCanonicalRequest(
method: string,
canonicalUri: string,
canonicalQueryString: string,
canonicalHeaders: string,
signedHeaders: string,
payloadHash: string
): string {
return [method, canonicalUri, canonicalQueryString, canonicalHeaders, signedHeaders, payloadHash].join('\n')
}
/**
* Create string to sign
*
* StringToSign =
* Algorithm + '\n' +
* RequestDateTime + '\n' +
* CredentialScope + '\n' +
* HexEncode(Hash(CanonicalRequest))
*/
private createStringToSign(dateTime: string, credentialScope: string, canonicalRequest: string): string {
const hashedCanonicalRequest = this.sha256Hash(canonicalRequest)
return [CONFIG.ALGORITHM, dateTime, credentialScope, hashedCanonicalRequest].join('\n')
}
/**
* Generate signature for the request
*/
private generateSignature(params: SignedRequestParams, credentials: VolcengineCredentials): SignedHeaders {
const { method, host, path: requestPath, query, body, service, region } = params
// Step 1: Prepare datetime
const dateTime = this.getIso8601DateTime()
const date = this.getDateFromDateTime(dateTime)
// Step 2: Calculate payload hash
const payloadHash = this.sha256Hash(body || '')
// Step 3: Prepare headers for signing
const headersToSign: Record<string, string> = {
host: host,
'x-date': dateTime,
'x-content-sha256': payloadHash,
'content-type': 'application/json'
}
// Step 4: Build canonical components
const canonicalUri = this.uriEncode(requestPath, false) || '/'
const canonicalQueryString = this.buildCanonicalQueryString(query)
const { canonicalHeaders, signedHeaders } = this.buildCanonicalHeaders(headersToSign)
// Step 5: Create canonical request
const canonicalRequest = this.createCanonicalRequest(
method.toUpperCase(),
canonicalUri,
canonicalQueryString,
canonicalHeaders,
signedHeaders,
payloadHash
)
// Step 6: Create credential scope and string to sign
const credentialScope = `${date}/${region}/${service}/${CONFIG.REQUEST_TYPE}`
const stringToSign = this.createStringToSign(dateTime, credentialScope, canonicalRequest)
// Step 7: Calculate signature
const signingKey = this.deriveSigningKey(credentials.secretAccessKey, date, region, service)
const signature = this.hmacSha256Hex(signingKey, stringToSign)
// Step 8: Build authorization header
const authorization = `${CONFIG.ALGORITHM} Credential=${credentials.accessKeyId}/${credentialScope}, SignedHeaders=${signedHeaders}, Signature=${signature}`
return {
Authorization: authorization,
'X-Date': dateTime,
'X-Content-Sha256': payloadHash,
Host: host
}
}
// ============= Credential Management =============
/**
* Save credentials securely using Electron's safeStorage
*/
public saveCredentials = async (
_: Electron.IpcMainInvokeEvent,
accessKeyId: string,
secretAccessKey: string
): Promise<void> => {
try {
if (!accessKeyId || !secretAccessKey) {
throw new VolcengineServiceError('Access Key ID and Secret Access Key are required')
}
const credentials: VolcengineCredentials = { accessKeyId, secretAccessKey }
const credentialsJson = JSON.stringify(credentials)
const encryptedData = safeStorage.encryptString(credentialsJson)
// Ensure directory exists
const dir = path.dirname(this.credentialsFilePath)
if (!fs.existsSync(dir)) {
await fs.promises.mkdir(dir, { recursive: true })
}
await fs.promises.writeFile(this.credentialsFilePath, encryptedData)
logger.info('Volcengine credentials saved successfully')
} catch (error) {
logger.error('Failed to save Volcengine credentials:', error as Error)
throw new VolcengineServiceError('Failed to save credentials', error)
}
}
/**
* Load credentials from encrypted storage
* @throws VolcengineServiceError if credentials file exists but is corrupted
*/
private async loadCredentials(): Promise<VolcengineCredentials | null> {
if (!fs.existsSync(this.credentialsFilePath)) {
return null
}
try {
const encryptedData = await fs.promises.readFile(this.credentialsFilePath)
const decryptedJson = safeStorage.decryptString(Buffer.from(encryptedData))
return JSON.parse(decryptedJson) as VolcengineCredentials
} catch (error) {
logger.error('Failed to load Volcengine credentials:', error as Error)
throw new VolcengineServiceError(
'Credentials file exists but could not be loaded. Please re-enter your credentials.',
error
)
}
}
/**
* Check if credentials exist
*/
public hasCredentials = async (): Promise<boolean> => {
return fs.existsSync(this.credentialsFilePath)
}
/**
* Clear stored credentials
*/
public clearCredentials = async (): Promise<void> => {
try {
if (fs.existsSync(this.credentialsFilePath)) {
await fs.promises.unlink(this.credentialsFilePath)
logger.info('Volcengine credentials cleared')
}
} catch (error) {
logger.error('Failed to clear Volcengine credentials:', error as Error)
throw new VolcengineServiceError('Failed to clear credentials', error)
}
}
// ============= API Methods =============
/**
* Make a signed request to Volcengine API
*/
private async makeSignedRequest<T>(
method: 'GET' | 'POST',
host: string,
path: string,
action: string,
version: string,
query?: Record<string, string>,
body?: Record<string, unknown>,
service: string = CONFIG.SERVICE_NAME,
region: string = CONFIG.DEFAULT_REGION
): Promise<T> {
const credentials = await this.loadCredentials()
if (!credentials) {
throw new VolcengineServiceError('No credentials found. Please save credentials first.')
}
const fullQuery: Record<string, string> = {
Action: action,
Version: version,
...query
}
const bodyString = body ? JSON.stringify(body) : ''
const signedHeaders = this.generateSignature(
{
method,
host,
path,
query: fullQuery,
headers: {},
body: bodyString,
service,
region
},
credentials
)
// Build URL with query string (use simple encoding for URL, canonical encoding is only for signature)
const urlParams = new URLSearchParams(fullQuery)
const url = `https://${host}${path}?${urlParams.toString()}`
const requestHeaders: Record<string, string> = {
...CONFIG.DEFAULT_HEADERS,
Authorization: signedHeaders.Authorization,
'X-Date': signedHeaders['X-Date'],
'X-Content-Sha256': signedHeaders['X-Content-Sha256']
}
logger.debug('Making Volcengine API request', { url, method, action })
try {
const response = await net.fetch(url, {
method,
headers: requestHeaders,
body: method === 'POST' && bodyString ? bodyString : undefined
})
if (!response.ok) {
const errorText = await response.text()
logger.error(`Volcengine API error: ${response.status}`, { errorText })
throw new VolcengineServiceError(`API request failed: ${response.status} - ${errorText}`)
}
return (await response.json()) as T
} catch (error) {
if (error instanceof VolcengineServiceError) {
throw error
}
logger.error('Volcengine API request failed:', error as Error)
throw new VolcengineServiceError('API request failed', error)
}
}
/**
* List foundation models from Volcengine ARK
*/
private async listFoundationModels(region: string = CONFIG.DEFAULT_REGION): Promise<ListFoundationModelsResponse> {
const requestBody: ListFoundationModelsRequest = {
PageNumber: 1,
PageSize: CONFIG.DEFAULT_PAGE_SIZE
}
const response = await this.makeSignedRequest<unknown>(
'POST',
CONFIG.API_URLS.ARK_HOST,
'/',
'ListFoundationModels',
CONFIG.API_VERSION,
{},
requestBody,
CONFIG.SERVICE_NAME,
region
)
return ListFoundationModelsResponseSchema.parse(response)
}
/**
* List user-created endpoints from Volcengine ARK
*/
private async listEndpoints(
projectName?: string,
region: string = CONFIG.DEFAULT_REGION
): Promise<ListEndpointsResponse> {
const requestBody: ListEndpointsRequest = {
ProjectName: projectName || 'default',
PageNumber: 1,
PageSize: CONFIG.DEFAULT_PAGE_SIZE
}
const response = await this.makeSignedRequest<unknown>(
'POST',
CONFIG.API_URLS.ARK_HOST,
'/',
'ListEndpoints',
CONFIG.API_VERSION,
{},
requestBody,
CONFIG.SERVICE_NAME,
region
)
return ListEndpointsResponseSchema.parse(response)
}
/**
* List all available models from Volcengine ARK
* Combines foundation models and user-created endpoints
*/
public listModels = async (
_?: Electron.IpcMainInvokeEvent,
projectName?: string,
region?: string
): Promise<ListModelsResult> => {
try {
const effectiveRegion = region || CONFIG.DEFAULT_REGION
const [foundationModelsResult, endpointsResult] = await Promise.allSettled([
this.listFoundationModels(effectiveRegion),
this.listEndpoints(projectName, effectiveRegion)
])
const models: ModelInfo[] = []
const warnings: string[] = []
if (foundationModelsResult.status === 'fulfilled') {
const foundationModels = foundationModelsResult.value
for (const item of foundationModels.Result.Items) {
models.push({
id: item.Name,
name: item.DisplayName || item.Name,
description: item.Description
})
}
logger.info(`Found ${foundationModels.Result.Items.length} foundation models`)
} else {
const errorMsg = `Failed to fetch foundation models: ${foundationModelsResult.reason}`
logger.warn(errorMsg)
warnings.push(errorMsg)
}
// Process endpoints
if (endpointsResult.status === 'fulfilled') {
const endpoints = endpointsResult.value
for (const item of endpoints.Result.Items) {
const modelRef = item.ModelReference
const foundationModelName = modelRef?.FoundationModel?.Name
const modelVersion = modelRef?.FoundationModel?.ModelVersion
const customModelId = modelRef?.CustomModelId
let displayName = item.Name || item.Id
if (foundationModelName) {
displayName = modelVersion ? `${foundationModelName} (${modelVersion})` : foundationModelName
} else if (customModelId) {
displayName = customModelId
}
models.push({
id: item.Id,
name: displayName,
description: item.Description
})
}
logger.info(`Found ${endpoints.Result.Items.length} endpoints`)
} else {
const errorMsg = `Failed to fetch endpoints: ${endpointsResult.reason}`
logger.warn(errorMsg)
warnings.push(errorMsg)
}
// If both failed, throw error
if (foundationModelsResult.status === 'rejected' && endpointsResult.status === 'rejected') {
throw new VolcengineServiceError('Failed to fetch both foundation models and endpoints')
}
const total =
(foundationModelsResult.status === 'fulfilled' ? foundationModelsResult.value.Result.TotalCount : 0) +
(endpointsResult.status === 'fulfilled' ? endpointsResult.value.Result.TotalCount : 0)
logger.info(`Total models found: ${models.length}`)
return {
models,
total,
warnings: warnings.length > 0 ? warnings : undefined
}
} catch (error) {
logger.error('Failed to list Volcengine models:', error as Error)
throw new VolcengineServiceError('Failed to list models', error)
}
}
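// Example result shape (illustrative values only):
// {
//   models: [
//     { id: 'doubao-example', name: 'Doubao Example', description: '...' },
//     { id: 'ep-20240101-abcdef', name: 'doubao-example (240101)' }
//   ],
//   total: 2,
//   warnings: undefined  // or ['Failed to fetch endpoints: ...'] on partial failure
// }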
/**
* Get authorization headers for external use
* This allows the renderer process to make direct API calls with proper authentication
*/
public getAuthHeaders = async (
_: Electron.IpcMainInvokeEvent,
params: {
method: 'GET' | 'POST'
host: string
path: string
query?: Record<string, string>
body?: string
service?: string
region?: string
}
): Promise<SignedHeaders> => {
const credentials = await this.loadCredentials()
if (!credentials) {
throw new VolcengineServiceError('No credentials found. Please save credentials first.')
}
return this.generateSignature(
{
method: params.method,
host: params.host,
path: params.path,
query: params.query || {},
headers: {},
body: params.body,
service: params.service || CONFIG.SERVICE_NAME,
region: params.region || CONFIG.DEFAULT_REGION
},
credentials
)
}
/**
* Make a generic signed API request
* This is a more flexible method that allows custom API calls
*/
public makeRequest = async (
_: Electron.IpcMainInvokeEvent,
params: {
method: 'GET' | 'POST'
host: string
path: string
action: string
version: string
query?: Record<string, string>
body?: Record<string, unknown>
service?: string
region?: string
}
): Promise<unknown> => {
return this.makeSignedRequest(
params.method,
params.host,
params.path,
params.action,
params.version,
params.query,
params.body,
params.service || CONFIG.SERVICE_NAME,
params.region || CONFIG.DEFAULT_REGION
)
}
}
export default new VolcengineService()
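The HMAC chain documented in deriveSigningKey can be verified in isolation; a minimal standalone sketch with toy inputs (all values below are placeholders, not real credentials):

import crypto from 'crypto'

const hmac = (key: crypto.BinaryLike, data: string) =>
  crypto.createHmac('sha256', key).update(data, 'utf8').digest()

// kSecret -> kDate -> kRegion -> kService -> kSigning, as documented above.
const kDate = hmac('toy-secret-access-key', '20240101')
const kRegion = hmac(kDate, 'cn-beijing')
const kService = hmac(kRegion, 'ark')
const kSigning = hmac(kService, 'request')

// The final signature is the hex HMAC of the string-to-sign under kSigning.
const signature = crypto.createHmac('sha256', kSigning).update('<string-to-sign>', 'utf8').digest('hex')
console.log(signature.length) // 64 hex chars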

View File

@@ -87,7 +87,6 @@ export class ClaudeStreamState {
private pendingUsage: PendingUsageState = {}
private pendingToolCalls = new Map<string, PendingToolCall>()
private stepActive = false
private _streamFinished = false
constructor(options: ClaudeStreamStateOptions) {
this.logger = loggerService.withContext('ClaudeStreamState')
@@ -290,16 +289,6 @@ export class ClaudeStreamState {
getNamespacedToolCallId(rawToolCallId: string): string {
return buildNamespacedToolCallId(this.agentSessionId, rawToolCallId)
}
/** Marks the stream as finished (either completed or errored). */
markFinished(): void {
this._streamFinished = true
}
/** Returns true if the stream has already emitted a terminal event. */
isFinished(): boolean {
return this._streamFinished
}
}
export type { PendingToolCall }

View File

@@ -1,7 +1,6 @@
// src/main/services/agents/services/claudecode/index.ts
import { EventEmitter } from 'node:events'
import { createRequire } from 'node:module'
import path from 'node:path'
import type {
CanUseTool,
@@ -85,14 +84,18 @@ class ClaudeCodeService implements AgentServiceInterface {
})
return aiStream
}
// Validate provider has required configuration
// Note: We no longer restrict to anthropic type only - the API Server's unified adapter
// handles format conversion for any provider type (OpenAI, Gemini, etc.)
if (!modelInfo.provider?.apiKey) {
logger.error('Provider API key is missing', { modelInfo })
if (
(modelInfo.provider?.type !== 'anthropic' &&
(modelInfo.provider?.anthropicApiHost === undefined || modelInfo.provider.anthropicApiHost.trim() === '')) ||
modelInfo.provider.apiKey === ''
) {
logger.error('Anthropic provider configuration is missing', {
modelInfo
})
aiStream.emit('data', {
type: 'error',
error: new Error(`Provider '${modelInfo.provider?.id}' is missing API key configuration.`)
error: new Error(`Invalid provider type '${modelInfo.provider?.type}'. Expected 'anthropic' provider type.`)
})
return aiStream
}
@@ -103,25 +106,22 @@ class ClaudeCodeService implements AgentServiceInterface {
Object.entries(loginShellEnv).filter(([key]) => !key.toLowerCase().endsWith('_proxy'))
) as Record<string, string>
// Route through local API Server which handles format conversion via unified adapter
// This enables Claude Code Agent to work with any provider (OpenAI, Gemini, etc.)
// The API Server converts AI SDK responses to Anthropic SSE format transparently
const env = {
...loginShellEnvWithoutProxies,
ANTHROPIC_API_KEY: apiConfig.apiKey,
ANTHROPIC_AUTH_TOKEN: apiConfig.apiKey,
ANTHROPIC_BASE_URL: `http://${apiConfig.host}:${apiConfig.port}/${modelInfo.provider.id}`,
// TODO: fix the proxy api server
// ANTHROPIC_API_KEY: apiConfig.apiKey,
// ANTHROPIC_AUTH_TOKEN: apiConfig.apiKey,
// ANTHROPIC_BASE_URL: `http://${apiConfig.host}:${apiConfig.port}/${modelInfo.provider.id}`,
ANTHROPIC_API_KEY: modelInfo.provider.apiKey,
ANTHROPIC_AUTH_TOKEN: modelInfo.provider.apiKey,
ANTHROPIC_BASE_URL: modelInfo.provider.anthropicApiHost?.trim() || modelInfo.provider.apiHost,
ANTHROPIC_MODEL: modelInfo.modelId,
ANTHROPIC_DEFAULT_OPUS_MODEL: modelInfo.modelId,
ANTHROPIC_DEFAULT_SONNET_MODEL: modelInfo.modelId,
// TODO: support set small model in UI
ANTHROPIC_DEFAULT_HAIKU_MODEL: modelInfo.modelId,
ELECTRON_RUN_AS_NODE: '1',
ELECTRON_NO_ATTACH_CONSOLE: '1',
// Set CLAUDE_CONFIG_DIR to app's userData directory to avoid path encoding issues
// on Windows when the username contains non-ASCII characters (e.g., Chinese characters)
// This prevents the SDK from using the user's home directory which may have encoding problems
CLAUDE_CONFIG_DIR: path.join(app.getPath('userData'), '.claude')
ELECTRON_NO_ATTACH_CONSOLE: '1'
}
const errorChunks: string[] = []
@@ -534,19 +534,6 @@ class ClaudeCodeService implements AgentServiceInterface {
return
}
// Skip emitting error if stream already finished (error was handled via result message)
if (streamState.isFinished()) {
logger.debug('SDK process exited after stream finished, skipping duplicate error event', {
duration,
error: errorObj instanceof Error ? { name: errorObj.name, message: errorObj.message } : String(errorObj)
})
// Still emit complete to signal stream end
stream.emit('data', {
type: 'complete'
})
return
}
errorChunks.push(errorObj instanceof Error ? errorObj.message : String(errorObj))
const errorMessage = errorChunks.join('\n\n')
logger.error('SDK query failed', {

View File

@@ -121,7 +121,7 @@ export function transformSDKMessageToStreamParts(sdkMessage: SDKMessage, state:
case 'system':
return handleSystemMessage(sdkMessage)
case 'result':
return handleResultMessage(sdkMessage, state)
return handleResultMessage(sdkMessage)
default:
logger.warn('Unknown SDKMessage type', { type: (sdkMessage as any).type })
return []
@@ -193,30 +193,6 @@ function handleAssistantMessage(
}
break
}
case 'thinking':
case 'redacted_thinking': {
const thinkingText = block.type === 'thinking' ? block.thinking : block.data
if (thinkingText) {
const id = generateMessageId()
chunks.push({
type: 'reasoning-start',
id,
providerMetadata
})
chunks.push({
type: 'reasoning-delta',
id,
text: thinkingText,
providerMetadata
})
chunks.push({
type: 'reasoning-end',
id,
providerMetadata
})
}
break
}
case 'tool_use':
handleAssistantToolUse(block as ToolUseContent, providerMetadata, state, chunks)
break
@@ -469,11 +445,7 @@ function handleStreamEvent(
case 'content_block_stop': {
const block = state.closeBlock(event.index)
if (!block) {
// Some providers (e.g., Gemini) send content via assistant message before stream events,
// so the block may not exist in state. This is expected behavior, not an error.
logger.debug('Received content_block_stop for unknown index (may be from non-streaming content)', {
index: event.index
})
logger.warn('Received content_block_stop for unknown index', { index: event.index })
break
}
@@ -707,13 +679,7 @@ function handleSystemMessage(message: Extract<SDKMessage, { type: 'system' }>):
* Successful runs yield a `finish` frame with aggregated usage metrics, while
* failures are surfaced as `error` frames.
*/
function handleResultMessage(
message: Extract<SDKMessage, { type: 'result' }>,
state: ClaudeStreamState
): AgentStreamPart[] {
// Mark stream as finished to prevent duplicate error events when SDK process exits
state.markFinished()
function handleResultMessage(message: Extract<SDKMessage, { type: 'result' }>): AgentStreamPart[] {
const chunks: AgentStreamPart[] = []
let usage: LanguageModelUsage | undefined
@@ -725,33 +691,26 @@ function handleResultMessage(
}
}
chunks.push({
type: 'finish',
totalUsage: usage ?? emptyUsage,
finishReason: mapClaudeCodeFinishReason(message.subtype),
providerMetadata: {
...sdkMessageToProviderMetadata(message),
usage: message.usage,
durationMs: message.duration_ms,
costUsd: message.total_cost_usd,
raw: message
}
} as AgentStreamPart)
if (message.subtype !== 'success') {
if (message.subtype === 'success') {
chunks.push({
type: 'finish',
totalUsage: usage ?? emptyUsage,
finishReason: mapClaudeCodeFinishReason(message.subtype),
providerMetadata: {
...sdkMessageToProviderMetadata(message),
usage: message.usage,
durationMs: message.duration_ms,
costUsd: message.total_cost_usd,
raw: message
}
} as AgentStreamPart)
} else {
chunks.push({
type: 'error',
error: {
message: `${message.subtype}: Process failed after ${message.num_turns} turns`
}
} as AgentStreamPart)
} else {
if (message.is_error) {
const errorMatch = message.result.match(/\{.*\}/)
if (errorMatch) {
const errorDetail = JSON.parse(errorMatch[0])
chunks.push(errorDetail)
}
}
}
return chunks
}

View File

@@ -1,196 +0,0 @@
import { describe, expect, it } from 'vitest'
import { buildFunctionCallToolName } from '../mcp'
describe('buildFunctionCallToolName', () => {
describe('basic functionality', () => {
it('should combine server name and tool name', () => {
const result = buildFunctionCallToolName('github', 'search_issues')
expect(result).toContain('github')
expect(result).toContain('search')
})
it('should sanitize names by replacing dashes with underscores', () => {
const result = buildFunctionCallToolName('my-server', 'my-tool')
// Input dashes are replaced, but the separator between server and tool is a dash
expect(result).toBe('my_serv-my_tool')
expect(result).toContain('_')
})
it('should handle empty server names gracefully', () => {
const result = buildFunctionCallToolName('', 'tool')
expect(result).toBeTruthy()
})
})
describe('uniqueness with serverId', () => {
it('should generate different IDs for same server name but different serverIds', () => {
const serverId1 = 'server-id-123456'
const serverId2 = 'server-id-789012'
const serverName = 'github'
const toolName = 'search_repos'
const result1 = buildFunctionCallToolName(serverName, toolName, serverId1)
const result2 = buildFunctionCallToolName(serverName, toolName, serverId2)
expect(result1).not.toBe(result2)
expect(result1).toContain('123456')
expect(result2).toContain('789012')
})
it('should generate same ID when serverId is not provided', () => {
const serverName = 'github'
const toolName = 'search_repos'
const result1 = buildFunctionCallToolName(serverName, toolName)
const result2 = buildFunctionCallToolName(serverName, toolName)
expect(result1).toBe(result2)
})
it('should include serverId suffix when provided', () => {
const serverId = 'abc123def456'
const result = buildFunctionCallToolName('server', 'tool', serverId)
// Should include last 6 chars of serverId
expect(result).toContain('ef456')
})
})
describe('character sanitization', () => {
it('should replace invalid characters with underscores', () => {
const result = buildFunctionCallToolName('test@server', 'tool#name')
expect(result).not.toMatch(/[@#]/)
expect(result).toMatch(/^[a-zA-Z0-9_-]+$/)
})
it('should ensure name starts with a letter', () => {
const result = buildFunctionCallToolName('123server', '456tool')
expect(result).toMatch(/^[a-zA-Z]/)
})
it('should handle consecutive underscores/dashes', () => {
const result = buildFunctionCallToolName('my--server', 'my__tool')
expect(result).not.toMatch(/[_-]{2,}/)
})
})
describe('length constraints', () => {
it('should truncate names longer than 63 characters', () => {
const longServerName = 'a'.repeat(50)
const longToolName = 'b'.repeat(50)
const result = buildFunctionCallToolName(longServerName, longToolName, 'id123456')
expect(result.length).toBeLessThanOrEqual(63)
})
it('should not end with underscore or dash after truncation', () => {
const longServerName = 'a'.repeat(50)
const longToolName = 'b'.repeat(50)
const result = buildFunctionCallToolName(longServerName, longToolName, 'id123456')
expect(result).not.toMatch(/[_-]$/)
})
it('should preserve serverId suffix even with long server/tool names', () => {
const longServerName = 'a'.repeat(50)
const longToolName = 'b'.repeat(50)
const serverId = 'server-id-xyz789'
const result = buildFunctionCallToolName(longServerName, longToolName, serverId)
// The suffix should be preserved and not truncated
expect(result).toContain('xyz789')
expect(result.length).toBeLessThanOrEqual(63)
})
it('should ensure two long-named servers with different IDs produce different results', () => {
const longServerName = 'a'.repeat(50)
const longToolName = 'b'.repeat(50)
const serverId1 = 'server-id-abc123'
const serverId2 = 'server-id-def456'
const result1 = buildFunctionCallToolName(longServerName, longToolName, serverId1)
const result2 = buildFunctionCallToolName(longServerName, longToolName, serverId2)
// Both should be within limit
expect(result1.length).toBeLessThanOrEqual(63)
expect(result2.length).toBeLessThanOrEqual(63)
// They should be different due to preserved suffix
expect(result1).not.toBe(result2)
})
})
describe('edge cases with serverId', () => {
it('should handle serverId with only non-alphanumeric characters', () => {
const serverId = '------' // All dashes
const result = buildFunctionCallToolName('server', 'tool', serverId)
// Should still produce a valid unique suffix via fallback hash
expect(result).toBeTruthy()
expect(result.length).toBeLessThanOrEqual(63)
expect(result).toMatch(/^[a-zA-Z][a-zA-Z0-9_-]*$/)
// Should have a suffix (underscore followed by something)
expect(result).toMatch(/_[a-z0-9]+$/)
})
it('should produce different results for different non-alphanumeric serverIds', () => {
const serverId1 = '------'
const serverId2 = '!!!!!!'
const result1 = buildFunctionCallToolName('server', 'tool', serverId1)
const result2 = buildFunctionCallToolName('server', 'tool', serverId2)
// Should be different because the hash fallback produces different values
expect(result1).not.toBe(result2)
})
it('should handle empty string serverId differently from undefined', () => {
const resultWithEmpty = buildFunctionCallToolName('server', 'tool', '')
const resultWithUndefined = buildFunctionCallToolName('server', 'tool', undefined)
// Empty string is falsy, so both should behave the same (no suffix)
expect(resultWithEmpty).toBe(resultWithUndefined)
})
it('should handle serverId with mixed alphanumeric and special chars', () => {
const serverId = 'ab@#cd' // Mixed chars, last 6 chars contain some alphanumeric
const result = buildFunctionCallToolName('server', 'tool', serverId)
// Should extract alphanumeric chars: 'abcd' from 'ab@#cd'
expect(result).toContain('abcd')
})
})
describe('real-world scenarios', () => {
it('should handle GitHub MCP server instances correctly', () => {
const serverName = 'github'
const toolName = 'search_repositories'
const githubComId = 'server-github-com-abc123'
const gheId = 'server-ghe-internal-xyz789'
const tool1 = buildFunctionCallToolName(serverName, toolName, githubComId)
const tool2 = buildFunctionCallToolName(serverName, toolName, gheId)
// Should be different
expect(tool1).not.toBe(tool2)
// Both should be valid identifiers
expect(tool1).toMatch(/^[a-zA-Z][a-zA-Z0-9_-]*$/)
expect(tool2).toMatch(/^[a-zA-Z][a-zA-Z0-9_-]*$/)
// Both should be <= 63 chars
expect(tool1.length).toBeLessThanOrEqual(63)
expect(tool2.length).toBeLessThanOrEqual(63)
})
it('should handle tool names that already include server name prefix', () => {
const result = buildFunctionCallToolName('github', 'github_search_repos')
expect(result).toBeTruthy()
// Should not double the server name
expect(result.split('github').length - 1).toBeLessThanOrEqual(2)
})
})
})

View File

@@ -1,25 +1,7 @@
export function buildFunctionCallToolName(serverName: string, toolName: string, serverId?: string) {
export function buildFunctionCallToolName(serverName: string, toolName: string) {
const sanitizedServer = serverName.trim().replace(/-/g, '_')
const sanitizedTool = toolName.trim().replace(/-/g, '_')
// Calculate suffix first to reserve space for it
// Suffix format: "_" + 6 alphanumeric chars = 7 chars total
let serverIdSuffix = ''
if (serverId) {
// Take the last 6 characters of the serverId for brevity
serverIdSuffix = serverId.slice(-6).replace(/[^a-zA-Z0-9]/g, '')
// Fallback: if suffix becomes empty (all non-alphanumeric chars), use a simple hash
if (!serverIdSuffix) {
const hash = serverId.split('').reduce((acc, char) => acc + char.charCodeAt(0), 0)
serverIdSuffix = hash.toString(36).slice(-6) || 'x'
}
}
// Reserve space for suffix when calculating max base name length
const SUFFIX_LENGTH = serverIdSuffix ? serverIdSuffix.length + 1 : 0 // +1 for underscore
const MAX_BASE_LENGTH = 63 - SUFFIX_LENGTH
// Combine server name and tool name
let name = sanitizedTool
if (!sanitizedTool.includes(sanitizedServer.slice(0, 7))) {
@@ -38,9 +20,9 @@ export function buildFunctionCallToolName(serverName: string, toolName: string,
// Remove consecutive underscores/dashes (optional improvement)
name = name.replace(/[_-]{2,}/g, '_')
// Truncate base name BEFORE adding suffix to ensure suffix is never cut off
if (name.length > MAX_BASE_LENGTH) {
name = name.slice(0, MAX_BASE_LENGTH)
// Truncate to 63 characters maximum
if (name.length > 63) {
name = name.slice(0, 63)
}
// Handle edge case: ensure we still have a valid name if truncation left invalid chars at edges
@@ -48,10 +30,5 @@ export function buildFunctionCallToolName(serverName: string, toolName: string,
name = name.slice(0, -1)
}
// Now append the suffix - it will always fit within 63 chars
if (serverIdSuffix) {
name = `${name}_${serverIdSuffix}`
}
return name
}
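A quick sketch of expected outputs for the suffixed variant, inferred from the tests above; the exact joined forms are assumptions, not verified against a run:
// buildFunctionCallToolName('my-server', 'my-tool')
//   -> 'my_serv-my_tool' (input dashes become underscores; server name capped at 7 chars, joined with a dash)
// buildFunctionCallToolName('github', 'search_repos', 'server-id-123456')
//   -> likely 'github-search_repos_123456' (last 6 alphanumeric chars of serverId appended after an underscore)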

View File

@@ -572,6 +572,41 @@ const api = {
status: () => ipcRenderer.invoke(IpcChannel.WebSocket_Status),
sendFile: (filePath: string) => ipcRenderer.invoke(IpcChannel.WebSocket_SendFile, filePath),
getAllCandidates: () => ipcRenderer.invoke(IpcChannel.WebSocket_GetAllCandidates)
},
volcengine: {
saveCredentials: (accessKeyId: string, secretAccessKey: string): Promise<void> =>
ipcRenderer.invoke(IpcChannel.Volcengine_SaveCredentials, accessKeyId, secretAccessKey),
hasCredentials: (): Promise<boolean> => ipcRenderer.invoke(IpcChannel.Volcengine_HasCredentials),
clearCredentials: (): Promise<void> => ipcRenderer.invoke(IpcChannel.Volcengine_ClearCredentials),
listModels: (
projectName?: string,
region?: string
): Promise<{
models: Array<{ id: string; name: string; description?: string; created?: number }>
total?: number
warnings?: string[]
}> => ipcRenderer.invoke(IpcChannel.Volcengine_ListModels, projectName, region),
getAuthHeaders: (params: {
method: 'GET' | 'POST'
host: string
path: string
query?: Record<string, string>
body?: string
service?: string
region?: string
}): Promise<{ Authorization: string; 'X-Date': string; 'X-Content-Sha256': string; Host: string }> =>
ipcRenderer.invoke(IpcChannel.Volcengine_GetAuthHeaders, params),
makeRequest: (params: {
method: 'GET' | 'POST'
host: string
path: string
action: string
version: string
query?: Record<string, string>
body?: Record<string, unknown>
service?: string
region?: string
}): Promise<unknown> => ipcRenderer.invoke(IpcChannel.Volcengine_MakeRequest, params)
}
}
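A minimal renderer-side usage sketch of the bridged surface above; the project name and region are placeholder values:
// Load models through the signed Volcengine API, surfacing partial failures
async function loadVolcengineModels() {
  const hasCreds = await window.api.volcengine.hasCredentials()
  if (!hasCreds) return []
  const { models, warnings } = await window.api.volcengine.listModels('my-project', 'cn-beijing')
  if (warnings?.length) {
    console.warn('Partial Volcengine API failures:', warnings)
  }
  return models
}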

View File

@@ -212,9 +212,8 @@ export class ToolCallChunkHandler {
description: toolName,
type: 'builtin'
} as BaseTool
} else if ((mcpTool = this.mcpTools.find((t) => t.id === toolName) as MCPTool)) {
} else if ((mcpTool = this.mcpTools.find((t) => t.name === toolName) as MCPTool)) {
// If the tool is a client-side executed MCP tool, keep the existing logic
// toolName is mcpTool.id (registered with id as key in convertMcpToolsToAiSdkTools)
logger.info(`[ToolCallChunkHandler] Handling client-side MCP tool: ${toolName}`)
// mcpTool = this.mcpTools.find((t) => t.name === toolName) as MCPTool
// if (!mcpTool) {

View File

@@ -14,6 +14,7 @@ import { OpenAIAPIClient } from './openai/OpenAIApiClient'
import { OpenAIResponseAPIClient } from './openai/OpenAIResponseAPIClient'
import { OVMSClient } from './ovms/OVMSClient'
import { PPIOAPIClient } from './ppio/PPIOAPIClient'
import { VolcengineAPIClient } from './volcengine/VolcengineAPIClient'
import { ZhipuAPIClient } from './zhipu/ZhipuAPIClient'
const logger = loggerService.withContext('ApiClientFactory')
@@ -64,6 +65,12 @@ export class ApiClientFactory {
return instance
}
if (provider.id === 'doubao') {
logger.debug(`Creating VolcengineAPIClient for provider: ${provider.id}`)
instance = new VolcengineAPIClient(provider) as BaseApiClient
return instance
}
if (provider.id === 'ovms') {
logger.debug(`Creating OVMSClient for provider: ${provider.id}`)
instance = new OVMSClient(provider) as BaseApiClient

View File

@@ -405,9 +405,6 @@ export abstract class BaseApiClient<
if (!param.name?.trim()) {
return acc
}
// Parse JSON type parameters (Legacy API clients)
// Related: src/renderer/src/pages/settings/AssistantSettings/AssistantModelSettings.tsx:133-148
// The UI stores JSON type params as strings, this function parses them before sending to API
if (param.type === 'json') {
const value = param.value as string
if (value === 'undefined') {

View File

@@ -46,7 +46,6 @@ import type {
GeminiSdkRawOutput,
GeminiSdkToolCall
} from '@renderer/types/sdk'
import { getTrailingApiVersion, withoutTrailingApiVersion } from '@renderer/utils'
import { isToolUseModeFunction } from '@renderer/utils/assistant'
import {
geminiFunctionCallToMcpTool,
@@ -164,10 +163,6 @@ export class GeminiAPIClient extends BaseApiClient<
return models
}
override getBaseURL(): string {
return withoutTrailingApiVersion(super.getBaseURL())
}
override async getSdkInstance() {
if (this.sdkInstance) {
return this.sdkInstance
@@ -193,13 +188,6 @@ export class GeminiAPIClient extends BaseApiClient<
if (this.provider.isVertex) {
return 'v1'
}
// Extract trailing API version from the URL
const trailingVersion = getTrailingApiVersion(this.provider.apiHost || '')
if (trailingVersion) {
return trailingVersion
}
return 'v1beta'
}

View File

@@ -24,7 +24,7 @@ export class VertexAPIClient extends GeminiAPIClient {
this.anthropicVertexClient = new AnthropicVertexClient(provider)
// If a plain Provider is passed in, convert it to a VertexProvider
if (isVertexProvider(provider)) {
this.vertexProvider = provider as VertexProvider
this.vertexProvider = provider
} else {
this.vertexProvider = createVertexProvider(provider)
}

View File

@@ -0,0 +1,74 @@
import type OpenAI from '@cherrystudio/openai'
import { loggerService } from '@logger'
import { getVolcengineProjectName, getVolcengineRegion } from '@renderer/hooks/useVolcengine'
import type { Provider } from '@renderer/types'
import { OpenAIAPIClient } from '../openai/OpenAIApiClient'
const logger = loggerService.withContext('VolcengineAPIClient')
/**
* Volcengine (Doubao) API Client
*
* Extends OpenAIAPIClient for standard chat completions (OpenAI-compatible),
* but overrides listModels to use Volcengine's signed API via IPC.
*/
export class VolcengineAPIClient extends OpenAIAPIClient {
constructor(provider: Provider) {
super(provider)
}
/**
* List models using Volcengine's signed API
* This calls the main process VolcengineService which handles HMAC-SHA256 signing
*/
override async listModels(): Promise<OpenAI.Models.Model[]> {
try {
const hasCredentials = await window.api.volcengine.hasCredentials()
if (!hasCredentials) {
logger.info('Volcengine credentials not configured, falling back to OpenAI-compatible list')
// Fall back to standard OpenAI-compatible API if no Volcengine credentials
return super.listModels()
}
logger.info('Fetching models from Volcengine API using signed request')
const projectName = getVolcengineProjectName()
const region = getVolcengineRegion()
const response = await window.api.volcengine.listModels(projectName, region)
if (!response || !response.models) {
logger.warn('Empty response from Volcengine listModels')
return []
}
// Notify user of any partial failures
if (response.warnings && response.warnings.length > 0) {
for (const warning of response.warnings) {
logger.warn(warning)
}
window.toast?.warning('Some Volcengine models could not be fetched. Check logs for details.')
}
const models: OpenAI.Models.Model[] = response.models.map((model) => ({
id: model.id,
object: 'model' as const,
created: model.created || Math.floor(Date.now() / 1000),
owned_by: 'volcengine',
// @ts-ignore - name is not part of OpenAI.Models.Model; the UI uses it to display the model
name: model.name || model.id
}))
logger.info(`Found ${models.length} models from Volcengine API`)
return models
} catch (error) {
logger.error('Failed to list Volcengine models:', error as Error)
// Notify user before falling back
window.toast?.warning('Failed to fetch Volcengine models. Check credentials if this persists.')
// Fall back to standard OpenAI-compatible API on error
logger.info('Falling back to OpenAI-compatible model list')
return super.listModels()
}
}
}
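For reference, the shape of one mapped entry as the UI receives it; all values are made up:
// {
//   id: 'doubao-pro-32k',
//   object: 'model',
//   created: 1732600000,
//   owned_by: 'volcengine',
//   name: 'Doubao Pro 32K' // extra field, not part of OpenAI.Models.Model
// }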

View File

@@ -5,7 +5,6 @@ import type { MCPTool } from '@renderer/types'
import { type Assistant, type Message, type Model, type Provider, SystemProviderIds } from '@renderer/types'
import type { Chunk } from '@renderer/types/chunk'
import { isSupportEnableThinkingProvider } from '@renderer/utils/provider'
import { openrouterReasoningMiddleware, skipGeminiThoughtSignatureMiddleware } from '@shared/middleware'
import type { LanguageModelMiddleware } from 'ai'
import { extractReasoningMiddleware, simulateStreamingMiddleware } from 'ai'
import { isEmpty } from 'lodash'
@@ -14,7 +13,9 @@ import { getAiSdkProviderId } from '../provider/factory'
import { isOpenRouterGeminiGenerateImageModel } from '../utils/image'
import { noThinkMiddleware } from './noThinkMiddleware'
import { openrouterGenerateImageMiddleware } from './openrouterGenerateImageMiddleware'
import { openrouterReasoningMiddleware } from './openrouterReasoningMiddleware'
import { qwenThinkingMiddleware } from './qwenThinkingMiddleware'
import { skipGeminiThoughtSignatureMiddleware } from './skipGeminiThoughtSignatureMiddleware'
import { toolChoiceMiddleware } from './toolChoiceMiddleware'
const logger = loggerService.withContext('AiSdkMiddlewareBuilder')

View File

@@ -0,0 +1,50 @@
import type { LanguageModelV2StreamPart } from '@ai-sdk/provider'
import type { LanguageModelMiddleware } from 'ai'
/**
* https://openrouter.ai/docs/docs/best-practices/reasoning-tokens#example-preserving-reasoning-blocks-with-openrouter-and-claude
*
* @returns LanguageModelMiddleware - a middleware that filters out redacted reasoning blocks
*/
export function openrouterReasoningMiddleware(): LanguageModelMiddleware {
const REDACTED_BLOCK = '[REDACTED]'
return {
middlewareVersion: 'v2',
wrapGenerate: async ({ doGenerate }) => {
const { content, ...rest } = await doGenerate()
const modifiedContent = content.map((part) => {
if (part.type === 'reasoning' && part.text.includes(REDACTED_BLOCK)) {
return {
...part,
text: part.text.replace(REDACTED_BLOCK, '')
}
}
return part
})
return { content: modifiedContent, ...rest }
},
wrapStream: async ({ doStream }) => {
const { stream, ...rest } = await doStream()
return {
stream: stream.pipeThrough(
new TransformStream<LanguageModelV2StreamPart, LanguageModelV2StreamPart>({
transform(
chunk: LanguageModelV2StreamPart,
controller: TransformStreamDefaultController<LanguageModelV2StreamPart>
) {
if (chunk.type === 'reasoning-delta' && chunk.delta.includes(REDACTED_BLOCK)) {
controller.enqueue({
...chunk,
delta: chunk.delta.replace(REDACTED_BLOCK, '')
})
} else {
controller.enqueue(chunk)
}
}
})
),
...rest
}
}
}
}
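A usage sketch, assuming the AI SDK's wrapLanguageModel helper and an already-constructed OpenRouter model:
import type { LanguageModelV2 } from '@ai-sdk/provider'
import { wrapLanguageModel } from 'ai'
declare const baseModel: LanguageModelV2 // assumed to come from the OpenRouter provider
// Reasoning parts now arrive with the '[REDACTED]' marker removed
const modelWithFilter = wrapLanguageModel({
  model: baseModel,
  middleware: openrouterReasoningMiddleware()
})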

View File

@@ -0,0 +1,36 @@
import type { LanguageModelMiddleware } from 'ai'
/**
* Skip Gemini thought signature middleware.
* Due to the complexity of multi-model client requests (which can switch to other models mid-process),
* it was decided to add a skip for all Gemini3 thinking signatures via middleware.
* @param aiSdkId AI SDK Provider ID
* @returns LanguageModelMiddleware
*/
export function skipGeminiThoughtSignatureMiddleware(aiSdkId: string): LanguageModelMiddleware {
const MAGIC_STRING = 'skip_thought_signature_validator'
return {
middlewareVersion: 'v2',
transformParams: async ({ params }) => {
const transformedParams = { ...params }
// Process messages in prompt
if (transformedParams.prompt && Array.isArray(transformedParams.prompt)) {
transformedParams.prompt = transformedParams.prompt.map((message) => {
if (typeof message.content !== 'string') {
for (const part of message.content) {
const googleOptions = part?.providerOptions?.[aiSdkId]
if (googleOptions?.thoughtSignature) {
googleOptions.thoughtSignature = MAGIC_STRING
}
}
}
return message
})
}
return transformedParams
}
}
}
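The transform in practice, assuming aiSdkId === 'google'; the original signature value is illustrative:
// before: part.providerOptions.google.thoughtSignature === '<opaque-signature-from-gemini-3>'
// after:  part.providerOptions.google.thoughtSignature === 'skip_thought_signature_validator'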

View File

@@ -7,7 +7,7 @@ import { isAwsBedrockProvider, isVertexProvider } from '@renderer/utils/provider
// https://docs.claude.com/en/docs/build-with-claude/extended-thinking#interleaved-thinking
const INTERLEAVED_THINKING_HEADER = 'interleaved-thinking-2025-05-14'
// https://docs.claude.com/en/docs/build-with-claude/context-windows#1m-token-context-window
// const CONTEXT_100M_HEADER = 'context-1m-2025-08-07'
const CONTEXT_100M_HEADER = 'context-1m-2025-08-07'
// https://docs.cloud.google.com/vertex-ai/generative-ai/docs/partner-models/claude/web-search
const WEBSEARCH_HEADER = 'web-search-2025-03-05'
@@ -17,7 +17,7 @@ export function addAnthropicHeaders(assistant: Assistant, model: Model): string[
if (
isClaude45ReasoningModel(model) &&
isToolUseModeFunction(assistant) &&
!(isVertexProvider(provider) || isAwsBedrockProvider(provider))
!(isVertexProvider(provider) && isAwsBedrockProvider(provider))
) {
anthropicHeaders.push(INTERLEAVED_THINKING_HEADER)
}
@@ -25,9 +25,7 @@ export function addAnthropicHeaders(assistant: Assistant, model: Model): string[
if (isVertexProvider(provider) && assistant.enableWebSearch) {
anthropicHeaders.push(WEBSEARCH_HEADER)
}
// We may add it by user preference in assistant.settings instead of always adding it.
// See #11540, #11397
// anthropicHeaders.push(CONTEXT_100M_HEADER)
anthropicHeaders.push(CONTEXT_100M_HEADER)
}
return anthropicHeaders
}
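These values are later joined into a single anthropic-beta request header (see the buildStreamTextParams change below); a sketch of the post-change outcome when both branches apply:
// anthropic-beta: interleaved-thinking-2025-05-14,context-1m-2025-08-07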

View File

@@ -28,7 +28,6 @@ import { type Assistant, type MCPTool, type Provider } from '@renderer/types'
import type { StreamTextParams } from '@renderer/types/aiCoreTypes'
import { mapRegexToPatterns } from '@renderer/utils/blacklistMatchPattern'
import { replacePromptVariables } from '@renderer/utils/prompt'
import { isAwsBedrockProvider } from '@renderer/utils/provider'
import type { ModelMessage, Tool } from 'ai'
import { stepCountIs } from 'ai'
@@ -176,7 +175,7 @@ export async function buildStreamTextParams(
let headers: Record<string, string | undefined> = options.requestOptions?.headers ?? {}
if (isAnthropicModel(model) && !isAwsBedrockProvider(provider)) {
if (isAnthropicModel(model)) {
const newBetaHeaders = { 'anthropic-beta': addAnthropicHeaders(assistant, model).join(',') }
headers = combineHeaders(headers, newBetaHeaders)
}

View File

@@ -1,4 +1,4 @@
import type { Model, Provider } from '@renderer/types'
import type { Provider } from '@renderer/types'
import { describe, expect, it, vi } from 'vitest'
import { getAiSdkProviderId } from '../factory'
@@ -68,18 +68,6 @@ function createTestProvider(id: string, type: string): Provider {
} as Provider
}
function createAzureProvider(id: string, apiVersion?: string, model?: string): Provider {
return {
id,
type: 'azure-openai',
name: `Azure Test ${id}`,
apiKey: 'azure-test-key',
apiHost: 'azure-test-host',
apiVersion,
models: [{ id: model || 'gpt-4' } as Model]
}
}
describe('Integrated Provider Registry', () => {
describe('Provider ID Resolution', () => {
it('should resolve openrouter provider correctly', () => {
@@ -123,24 +111,6 @@ describe('Integrated Provider Registry', () => {
const result = getAiSdkProviderId(unknownProvider)
expect(result).toBe('unknown-provider')
})
it('should handle Azure OpenAI providers correctly', () => {
const azureProvider = createAzureProvider('azure-test', '2024-02-15', 'gpt-4o')
const result = getAiSdkProviderId(azureProvider)
expect(result).toBe('azure')
})
it('should handle Azure OpenAI providers response endpoint correctly', () => {
const azureProvider = createAzureProvider('azure-test', 'v1', 'gpt-4o')
const result = getAiSdkProviderId(azureProvider)
expect(result).toBe('azure-responses')
})
it('should handle Azure provider Claude Models', () => {
const provider = createTestProvider('azure-anthropic', 'anthropic')
const result = getAiSdkProviderId(provider)
expect(result).toBe('azure-anthropic')
})
})
describe('Backward Compatibility', () => {

View File

@@ -24,17 +24,7 @@ vi.mock('@renderer/services/AssistantService', () => ({
vi.mock('@renderer/store', () => ({
default: {
getState: () => ({
copilot: { defaultHeaders: {} },
llm: {
settings: {
vertexai: {
projectId: 'test-project',
location: 'us-central1'
}
}
}
})
getState: () => ({ copilot: { defaultHeaders: {} } })
}
}))
@@ -43,7 +33,7 @@ vi.mock('@renderer/utils/api', () => ({
if (isSupportedAPIVersion === false) {
return host // Return host as-is when isSupportedAPIVersion is false
}
return host ? `${host}/v1` : '' // Default behavior when isSupportedAPIVersion is true
return `${host}/v1` // Default behavior when isSupportedAPIVersion is true
}),
routeToEndpoint: vi.fn((host) => ({
baseURL: host,
@@ -51,20 +41,6 @@ vi.mock('@renderer/utils/api', () => ({
}))
}))
// Also mock @shared/api since formatProviderApiHost uses it directly
vi.mock('@shared/api', async (importOriginal) => {
const actual = (await importOriginal()) as any
return {
...actual,
formatApiHost: vi.fn((host, isSupportedAPIVersion = true) => {
if (isSupportedAPIVersion === false) {
return host || '' // Return host as-is when isSupportedAPIVersion is false
}
return host ? `${host}/v1` : '' // Default behavior when isSupportedAPIVersion is true
})
}
})
vi.mock('@renderer/utils/provider', async (importOriginal) => {
const actual = (await importOriginal()) as any
return {
@@ -97,8 +73,8 @@ vi.mock('@renderer/services/AssistantService', () => ({
import { getProviderByModel } from '@renderer/services/AssistantService'
import type { Model, Provider } from '@renderer/types'
import { formatApiHost } from '@renderer/utils/api'
import { isCherryAIProvider, isPerplexityProvider } from '@renderer/utils/provider'
import { formatApiHost } from '@shared/api'
import { COPILOT_DEFAULT_HEADERS, COPILOT_EDITOR_VERSION, isCopilotResponsesModel } from '../constants'
import { getActualProvider, providerToAiSdkConfig } from '../providerConfig'

View File

@@ -1,13 +1,13 @@
/**
* AiHubMix rule set
*/
import { getLowerBaseModelName } from '@shared/utils/naming'
import { isOpenAILLMModel } from '@renderer/config/models'
import type { Provider } from '@renderer/types'
import type { MinimalModel, MinimalProvider } from '../types'
import { provider2Provider, startsWith } from './helper'
import type { RuleSet } from './types'
const extraProviderConfig = <P extends MinimalProvider>(provider: P) => {
const extraProviderConfig = (provider: Provider) => {
return {
...provider,
extra_headers: {
@@ -17,23 +17,11 @@ const extraProviderConfig = <P extends MinimalProvider>(provider: P) => {
}
}
function isOpenAILLMModel<M extends MinimalModel>(model: M): boolean {
const modelId = getLowerBaseModelName(model.id)
const reasonings = ['o1', 'o3', 'o4', 'gpt-oss']
if (reasonings.some((r) => modelId.includes(r))) {
return true
}
if (modelId.includes('gpt')) {
return true
}
return false
}
const AIHUBMIX_RULES: RuleSet = {
rules: [
{
match: startsWith('claude'),
provider: (provider) => {
provider: (provider: Provider) => {
return extraProviderConfig({
...provider,
type: 'anthropic'
@@ -46,7 +34,7 @@ const AIHUBMIX_RULES: RuleSet = {
!model.id.endsWith('-nothink') &&
!model.id.endsWith('-search') &&
!model.id.includes('embedding'),
provider: (provider) => {
provider: (provider: Provider) => {
return extraProviderConfig({
...provider,
type: 'gemini',
@@ -56,7 +44,7 @@ const AIHUBMIX_RULES: RuleSet = {
},
{
match: isOpenAILLMModel,
provider: (provider) => {
provider: (provider: Provider) => {
return extraProviderConfig({
...provider,
type: 'openai-response'
@@ -64,8 +52,7 @@ const AIHUBMIX_RULES: RuleSet = {
}
}
],
fallbackRule: (provider) => extraProviderConfig(provider)
fallbackRule: (provider: Provider) => extraProviderConfig(provider)
}
export const aihubmixProviderCreator = <P extends MinimalProvider>(model: MinimalModel, provider: P): P =>
provider2Provider<MinimalModel, MinimalProvider, P>(AIHUBMIX_RULES, model, provider)
export const aihubmixProviderCreator = provider2Provider.bind(null, AIHUBMIX_RULES)

View File

@@ -0,0 +1,22 @@
import type { Provider } from '@renderer/types'
import { provider2Provider, startsWith } from './helper'
import type { RuleSet } from './types'
// https://platform.claude.com/docs/en/build-with-claude/claude-in-microsoft-foundry
const AZURE_ANTHROPIC_RULES: RuleSet = {
rules: [
{
match: startsWith('claude'),
provider: (provider: Provider) => ({
...provider,
type: 'anthropic',
apiHost: provider.apiHost + 'anthropic/v1',
id: 'azure-anthropic'
})
}
],
fallbackRule: (provider: Provider) => provider
}
export const azureAnthropicProviderCreator = provider2Provider.bind(null, AZURE_ANTHROPIC_RULES)

View File

@@ -0,0 +1,22 @@
import type { Model, Provider } from '@renderer/types'
import type { RuleSet } from './types'
export const startsWith = (prefix: string) => (model: Model) => model.id.toLowerCase().startsWith(prefix.toLowerCase())
export const endpointIs = (type: string) => (model: Model) => model.endpoint_type === type
/**
* Resolve the Provider that corresponds to a model
* @param ruleSet the rule set
* @param model the model
* @param provider the original provider
* @returns the resolved provider
*/
export function provider2Provider(ruleSet: RuleSet, model: Model, provider: Provider): Provider {
for (const rule of ruleSet.rules) {
if (rule.match(model)) {
return rule.provider(provider)
}
}
return ruleSet.fallbackRule(provider)
}
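A minimal sketch of defining and applying a rule set with this helper; the rule itself is hypothetical:
import type { Model, Provider } from '@renderer/types'
import { provider2Provider, startsWith } from './helper'
import type { RuleSet } from './types'
// Hypothetical rule: route Claude models to the Anthropic client type
const EXAMPLE_RULES: RuleSet = {
  rules: [{ match: startsWith('claude'), provider: (p: Provider) => ({ ...p, type: 'anthropic' }) }],
  fallbackRule: (p: Provider) => p
}
declare const model: Model
declare const provider: Provider
// The first matching rule wins; otherwise fallbackRule returns the provider unchanged
const resolved = provider2Provider(EXAMPLE_RULES, model, provider)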

View File

@@ -1,7 +1,3 @@
// Re-export from shared config
export {
aihubmixProviderCreator,
azureAnthropicProviderCreator,
newApiResolverCreator,
vertexAnthropicProviderCreator
} from '@shared/provider/config'
export { aihubmixProviderCreator } from './aihubmix'
export { newApiResolverCreator } from './newApi'
export { vertexAnthropicProviderCreator } from './vertext-anthropic'

View File

@@ -1,7 +1,8 @@
/**
* NewAPI rule set
*/
import type { MinimalModel, MinimalProvider, ProviderType } from '../types'
import type { Provider } from '@renderer/types'
import { endpointIs, provider2Provider } from './helper'
import type { RuleSet } from './types'
@@ -9,43 +10,42 @@ const NEWAPI_RULES: RuleSet = {
rules: [
{
match: endpointIs('anthropic'),
provider: (provider) => {
provider: (provider: Provider) => {
return {
...provider,
type: 'anthropic' as ProviderType
type: 'anthropic'
}
}
},
{
match: endpointIs('gemini'),
provider: (provider) => {
provider: (provider: Provider) => {
return {
...provider,
type: 'gemini' as ProviderType
type: 'gemini'
}
}
},
{
match: endpointIs('openai-response'),
provider: (provider) => {
provider: (provider: Provider) => {
return {
...provider,
type: 'openai-response' as ProviderType
type: 'openai-response'
}
}
},
{
match: (model) => endpointIs('openai')(model) || endpointIs('image-generation')(model),
provider: (provider) => {
provider: (provider: Provider) => {
return {
...provider,
type: 'openai' as ProviderType
type: 'openai'
}
}
}
],
fallbackRule: (provider) => provider
fallbackRule: (provider: Provider) => provider
}
export const newApiResolverCreator = <P extends MinimalProvider>(model: MinimalModel, provider: P): P =>
provider2Provider<MinimalModel, MinimalProvider, P>(NEWAPI_RULES, model, provider)
export const newApiResolverCreator = provider2Provider.bind(null, NEWAPI_RULES)

View File

@@ -0,0 +1,9 @@
import type { Model, Provider } from '@renderer/types'
export interface RuleSet {
rules: Array<{
match: (model: Model) => boolean
provider: (provider: Provider) => Provider
}>
fallbackRule: (provider: Provider) => Provider
}

View File

@@ -0,0 +1,19 @@
import type { Provider } from '@renderer/types'
import { provider2Provider, startsWith } from './helper'
import type { RuleSet } from './types'
const VERTEX_ANTHROPIC_RULES: RuleSet = {
rules: [
{
match: startsWith('claude'),
provider: (provider: Provider) => ({
...provider,
id: 'google-vertex-anthropic'
})
}
],
fallbackRule: (provider: Provider) => provider
}
export const vertexAnthropicProviderCreator = provider2Provider.bind(null, VERTEX_ANTHROPIC_RULES)

View File

@@ -1 +1,25 @@
export { COPILOT_DEFAULT_HEADERS, COPILOT_EDITOR_VERSION, isCopilotResponsesModel } from '@shared/provider/constant'
import type { Model } from '@renderer/types'
export const COPILOT_EDITOR_VERSION = 'vscode/1.104.1'
export const COPILOT_PLUGIN_VERSION = 'copilot-chat/0.26.7'
export const COPILOT_INTEGRATION_ID = 'vscode-chat'
export const COPILOT_USER_AGENT = 'GitHubCopilotChat/0.26.7'
export const COPILOT_DEFAULT_HEADERS = {
'Copilot-Integration-Id': COPILOT_INTEGRATION_ID,
'User-Agent': COPILOT_USER_AGENT,
'Editor-Version': COPILOT_EDITOR_VERSION,
'Editor-Plugin-Version': COPILOT_PLUGIN_VERSION,
'editor-version': COPILOT_EDITOR_VERSION,
'editor-plugin-version': COPILOT_PLUGIN_VERSION,
'copilot-vision-request': 'true'
} as const
// Models that require the OpenAI Responses endpoint when routed through GitHub Copilot (#10560)
const COPILOT_RESPONSES_MODEL_IDS = ['gpt-5-codex']
export function isCopilotResponsesModel(model: Model): boolean {
const normalizedId = model.id?.trim().toLowerCase()
const normalizedName = model.name?.trim().toLowerCase()
return COPILOT_RESPONSES_MODEL_IDS.some((target) => normalizedId === target || normalizedName === target)
}
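Matching is case-insensitive on either the id or the name, so (sketch):
// isCopilotResponsesModel({ id: 'GPT-5-Codex' } as Model) -> true
// isCopilotResponsesModel({ id: 'gpt-4o' } as Model)      -> false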

View File

@@ -1,7 +1,8 @@
import { hasProviderConfigByAlias, type ProviderId, resolveProviderConfigId } from '@cherrystudio/ai-core/provider'
import { createProvider as createProviderCore } from '@cherrystudio/ai-core/provider'
import { loggerService } from '@logger'
import type { Provider } from '@renderer/types'
import { getAiSdkProviderId as sharedGetAiSdkProviderId } from '@shared/provider'
import { isAzureOpenAIProvider, isAzureResponsesEndpoint } from '@renderer/utils/provider'
import type { Provider as AiSdkProvider } from 'ai'
import type { AiSdkConfig } from '../types'
@@ -21,12 +22,68 @@ const logger = loggerService.withContext('ProviderFactory')
}
})()
/**
* Static provider mapping table
* Maps Cherry Studio-specific provider IDs to standard AI SDK IDs
*/
const STATIC_PROVIDER_MAPPING: Record<string, ProviderId> = {
gemini: 'google', // Google Gemini -> google
'azure-openai': 'azure', // Azure OpenAI -> azure
'openai-response': 'openai', // OpenAI Responses -> openai
grok: 'xai', // Grok -> xai
copilot: 'github-copilot-openai-compatible'
}
/**
* Try to resolve a provider identifier (supports static mappings and aliases)
*/
function tryResolveProviderId(identifier: string): ProviderId | null {
// 1. Check the static mapping
const staticMapping = STATIC_PROVIDER_MAPPING[identifier]
if (staticMapping) {
return staticMapping
}
// 2. Check whether AiCore supports it (including alias support)
if (hasProviderConfigByAlias(identifier)) {
// Resolve to the real Provider ID
return resolveProviderConfigId(identifier) as ProviderId
}
return null
}
/**
* Get the AI SDK provider ID
* Uses shared implementation with renderer-specific config checker
* Simplified version: reduces duplicated logic via the generic resolver
*/
export function getAiSdkProviderId(provider: Provider): string {
return sharedGetAiSdkProviderId(provider)
// 1. Try to resolve provider.id
const resolvedFromId = tryResolveProviderId(provider.id)
if (isAzureOpenAIProvider(provider)) {
if (isAzureResponsesEndpoint(provider)) {
return 'azure-responses'
} else {
return 'azure'
}
}
if (resolvedFromId) {
return resolvedFromId
}
// 2. Try to resolve provider.type
// Skipped for 'openai', since that would map every custom OpenAI-type provider onto the AI SDK openaiProvider
if (provider.type !== 'openai') {
const resolvedFromType = tryResolveProviderId(provider.type)
if (resolvedFromType) {
return resolvedFromType
}
}
if (provider.apiHost.includes('api.openai.com')) {
return 'openai-chat'
}
// 3. Final fallback: use the provider's own id
return provider.id
}
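Expected resolutions under this branch order, inferred from the post-change code rather than a test run (the azure-responses case assumes isAzureResponsesEndpoint keys off apiVersion 'v1'):
// { id: 'gemini', ... }                                                  -> 'google'          (static mapping)
// { id: 'azure-openai', apiVersion: 'v1', ... }                          -> 'azure-responses' (responses endpoint)
// { id: 'custom', type: 'openai', apiHost: 'https://api.openai.com/v1' } -> 'openai-chat'     (OpenAI host check)
// { id: 'custom', type: 'openai', apiHost: 'https://example.com/v1' }    -> 'custom'          (final fallback)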
export async function createAiSdkProvider(config: AiSdkConfig): Promise<AiSdkProvider | null> {

View File

@@ -1,4 +1,4 @@
import { hasProviderConfig } from '@cherrystudio/ai-core/provider'
import { formatPrivateKey, hasProviderConfig, ProviderConfigFactory } from '@cherrystudio/ai-core/provider'
import { isOpenAIChatCompletionOnlyModel } from '@renderer/config/models'
import {
getAwsBedrockAccessKeyId,
@@ -10,17 +10,22 @@ import {
import { createVertexProvider, isVertexAIConfigured } from '@renderer/hooks/useVertexAI'
import { getProviderByModel } from '@renderer/services/AssistantService'
import store from '@renderer/store'
import { isSystemProvider, type Model, type Provider } from '@renderer/types'
import { isSystemProvider, type Model, type Provider, SystemProviderIds } from '@renderer/types'
import { formatApiHost, formatAzureOpenAIApiHost, formatVertexApiHost, routeToEndpoint } from '@renderer/utils/api'
import {
type AiSdkConfigContext,
formatProviderApiHost as sharedFormatProviderApiHost,
type ProviderFormatContext,
providerToAiSdkConfig as sharedProviderToAiSdkConfig,
resolveActualProvider
} from '@shared/provider'
isAnthropicProvider,
isAzureOpenAIProvider,
isCherryAIProvider,
isGeminiProvider,
isNewApiProvider,
isPerplexityProvider,
isVertexProvider
} from '@renderer/utils/provider'
import { cloneDeep } from 'lodash'
import type { AiSdkConfig } from '../types'
import { aihubmixProviderCreator, newApiResolverCreator, vertexAnthropicProviderCreator } from './config'
import { azureAnthropicProviderCreator } from './config/azure-anthropic'
import { COPILOT_DEFAULT_HEADERS } from './constants'
import { getAiSdkProviderId } from './factory'
@@ -51,51 +56,61 @@ function getRotatedApiKey(provider: Provider): string {
}
/**
* Renderer-specific context for providerToAiSdkConfig
* Provides implementations using browser APIs, store, and hooks
* Handles the conversion logic for special providers
*/
function createRendererSdkContext(model: Model): AiSdkConfigContext {
return {
getRotatedApiKey: (provider) => getRotatedApiKey(provider as Provider),
isOpenAIChatCompletionOnlyModel: () => isOpenAIChatCompletionOnlyModel(model),
getCopilotDefaultHeaders: () => COPILOT_DEFAULT_HEADERS,
getCopilotStoredHeaders: () => store.getState().copilot.defaultHeaders ?? {},
getAwsBedrockConfig: () => {
const authType = getAwsBedrockAuthType()
return {
authType,
region: getAwsBedrockRegion(),
apiKey: authType === 'apiKey' ? getAwsBedrockApiKey() : undefined,
accessKeyId: authType === 'iam' ? getAwsBedrockAccessKeyId() : undefined,
secretAccessKey: authType === 'iam' ? getAwsBedrockSecretAccessKey() : undefined
}
},
getVertexConfig: (provider) => {
if (!isVertexAIConfigured()) {
return undefined
}
return createVertexProvider(provider as Provider)
},
getEndpointType: () => model.endpoint_type
function handleSpecialProviders(model: Model, provider: Provider): Provider {
if (isNewApiProvider(provider)) {
return newApiResolverCreator(model, provider)
}
if (isSystemProvider(provider)) {
if (provider.id === 'aihubmix') {
return aihubmixProviderCreator(model, provider)
}
if (provider.id === 'vertexai') {
return vertexAnthropicProviderCreator(model, provider)
}
}
if (isAzureOpenAIProvider(provider)) {
return azureAnthropicProviderCreator(model, provider)
}
return provider
}
/**
* Mainly used to align the baseURL with the AI SDK's expected format
* Uses shared implementation with renderer-specific context
* @param provider
* @returns
*/
function getRendererFormatContext(): ProviderFormatContext {
const vertexSettings = store.getState().llm.settings.vertexai
return {
vertex: {
project: vertexSettings.projectId || 'default-project',
location: vertexSettings.location || 'us-central1'
}
}
}
function formatProviderApiHost(provider: Provider): Provider {
return sharedFormatProviderApiHost(provider, getRendererFormatContext())
const formatted = { ...provider }
if (formatted.anthropicApiHost) {
formatted.anthropicApiHost = formatApiHost(formatted.anthropicApiHost)
}
if (isAnthropicProvider(provider)) {
const baseHost = formatted.anthropicApiHost || formatted.apiHost
// AI SDK needs /v1 in baseURL, Anthropic SDK will strip it in getSdkClient
formatted.apiHost = formatApiHost(baseHost)
if (!formatted.anthropicApiHost) {
formatted.anthropicApiHost = formatted.apiHost
}
} else if (formatted.id === SystemProviderIds.copilot || formatted.id === SystemProviderIds.github) {
formatted.apiHost = formatApiHost(formatted.apiHost, false)
} else if (isGeminiProvider(formatted)) {
formatted.apiHost = formatApiHost(formatted.apiHost, true, 'v1beta')
} else if (isAzureOpenAIProvider(formatted)) {
formatted.apiHost = formatAzureOpenAIApiHost(formatted.apiHost)
} else if (isVertexProvider(formatted)) {
formatted.apiHost = formatVertexApiHost(formatted)
} else if (isCherryAIProvider(formatted)) {
formatted.apiHost = formatApiHost(formatted.apiHost, false)
} else if (isPerplexityProvider(formatted)) {
formatted.apiHost = formatApiHost(formatted.apiHost, false)
} else {
formatted.apiHost = formatApiHost(formatted.apiHost)
}
return formatted
}
/**
@@ -107,9 +122,7 @@ export function getActualProvider(model: Model): Provider {
// Apply each conversion in order
let actualProvider = cloneDeep(baseProvider)
actualProvider = resolveActualProvider(actualProvider, model, {
isSystemProvider
}) as Provider
actualProvider = handleSpecialProviders(model, actualProvider)
actualProvider = formatProviderApiHost(actualProvider)
return actualProvider
@@ -117,11 +130,121 @@ export function getActualProvider(model: Model): Provider {
/**
* Convert a Provider config into the new AI SDK format
* Uses shared implementation with renderer-specific context
* Simplified version: uses the new alias mapping system
*/
export function providerToAiSdkConfig(actualProvider: Provider, model: Model): AiSdkConfig {
const context = createRendererSdkContext(model)
return sharedProviderToAiSdkConfig(actualProvider, model.id, context) as AiSdkConfig
const aiSdkProviderId = getAiSdkProviderId(actualProvider)
// Build the base config
const { baseURL, endpoint } = routeToEndpoint(actualProvider.apiHost)
const baseConfig = {
baseURL: baseURL,
apiKey: getRotatedApiKey(actualProvider)
}
const isCopilotProvider = actualProvider.id === SystemProviderIds.copilot
if (isCopilotProvider) {
const storedHeaders = store.getState().copilot.defaultHeaders ?? {}
const options = ProviderConfigFactory.fromProvider('github-copilot-openai-compatible', baseConfig, {
headers: {
...COPILOT_DEFAULT_HEADERS,
...storedHeaders,
...actualProvider.extra_headers
},
name: actualProvider.id,
includeUsage: true
})
return {
providerId: 'github-copilot-openai-compatible',
options
}
}
// Handle the OpenAI mode
const extraOptions: any = {}
extraOptions.endpoint = endpoint
if (actualProvider.type === 'openai-response' && !isOpenAIChatCompletionOnlyModel(model)) {
extraOptions.mode = 'responses'
} else if (aiSdkProviderId === 'openai' || (aiSdkProviderId === 'cherryin' && actualProvider.type === 'openai')) {
extraOptions.mode = 'chat'
}
// Add extra headers
if (actualProvider.extra_headers) {
extraOptions.headers = actualProvider.extra_headers
// copy from openaiBaseClient/openaiResponseApiClient
if (aiSdkProviderId === 'openai') {
extraOptions.headers = {
...extraOptions.headers,
'HTTP-Referer': 'https://cherry-ai.com',
'X-Title': 'Cherry Studio',
'X-Api-Key': baseConfig.apiKey
}
}
}
// azure
// https://learn.microsoft.com/en-us/azure/ai-foundry/openai/latest
// https://learn.microsoft.com/en-us/azure/ai-foundry/openai/how-to/responses?tabs=python-key#responses-api
if (aiSdkProviderId === 'azure-responses') {
extraOptions.mode = 'responses'
} else if (aiSdkProviderId === 'azure') {
extraOptions.mode = 'chat'
}
// bedrock
if (aiSdkProviderId === 'bedrock') {
const authType = getAwsBedrockAuthType()
extraOptions.region = getAwsBedrockRegion()
if (authType === 'apiKey') {
extraOptions.apiKey = getAwsBedrockApiKey()
} else {
extraOptions.accessKeyId = getAwsBedrockAccessKeyId()
extraOptions.secretAccessKey = getAwsBedrockSecretAccessKey()
}
}
// google-vertex
if (aiSdkProviderId === 'google-vertex' || aiSdkProviderId === 'google-vertex-anthropic') {
if (!isVertexAIConfigured()) {
throw new Error('VertexAI is not configured. Please configure project, location and service account credentials.')
}
const { project, location, googleCredentials } = createVertexProvider(actualProvider)
extraOptions.project = project
extraOptions.location = location
extraOptions.googleCredentials = {
...googleCredentials,
privateKey: formatPrivateKey(googleCredentials.privateKey)
}
baseConfig.baseURL += aiSdkProviderId === 'google-vertex' ? '/publishers/google' : '/publishers/anthropic/models'
}
// cherryin
if (aiSdkProviderId === 'cherryin') {
if (model.endpoint_type) {
extraOptions.endpointType = model.endpoint_type
}
}
if (hasProviderConfig(aiSdkProviderId) && aiSdkProviderId !== 'openai-compatible') {
const options = ProviderConfigFactory.fromProvider(aiSdkProviderId, baseConfig, extraOptions)
return {
providerId: aiSdkProviderId,
options
}
}
// Otherwise fall back to openai-compatible
const options = ProviderConfigFactory.createOpenAICompatible(baseConfig.baseURL, baseConfig.apiKey)
return {
providerId: 'openai-compatible',
options: {
...options,
name: actualProvider.id,
...extraOptions,
includeUsage: true
}
}
}
/**
@@ -164,13 +287,13 @@ export async function prepareSpecialProviderConfig(
break
}
case 'cherryai': {
config.options.fetch = async (url: RequestInfo | URL, options: RequestInit) => {
config.options.fetch = async (url, options) => {
// Sign the final request parameters here
const signature = await window.api.cherryai.generateSignature({
method: 'POST',
path: '/chat/completions',
query: '',
body: JSON.parse(options.body as string)
body: JSON.parse(options.body)
})
return fetch(url, {
...options,

View File

@@ -1,13 +1,113 @@
import { type ProviderConfig, registerMultipleProviderConfigs } from '@cherrystudio/ai-core/provider'
import { loggerService } from '@logger'
import { initializeSharedProviders, SHARED_PROVIDER_CONFIGS } from '@shared/provider'
const logger = loggerService.withContext('ProviderConfigs')
export const NEW_PROVIDER_CONFIGS = SHARED_PROVIDER_CONFIGS
/**
* New provider config definitions
* Defines the AI providers that need to be registered dynamically
*/
export const NEW_PROVIDER_CONFIGS: ProviderConfig[] = [
{
id: 'openrouter',
name: 'OpenRouter',
import: () => import('@openrouter/ai-sdk-provider'),
creatorFunctionName: 'createOpenRouter',
supportsImageGeneration: true,
aliases: ['openrouter']
},
{
id: 'google-vertex',
name: 'Google Vertex AI',
import: () => import('@ai-sdk/google-vertex/edge'),
creatorFunctionName: 'createVertex',
supportsImageGeneration: true,
aliases: ['vertexai']
},
{
id: 'google-vertex-anthropic',
name: 'Google Vertex AI Anthropic',
import: () => import('@ai-sdk/google-vertex/anthropic/edge'),
creatorFunctionName: 'createVertexAnthropic',
supportsImageGeneration: true,
aliases: ['vertexai-anthropic']
},
{
id: 'azure-anthropic',
name: 'Azure AI Anthropic',
import: () => import('@ai-sdk/anthropic'),
creatorFunctionName: 'createAnthropic',
supportsImageGeneration: false,
aliases: ['azure-anthropic']
},
{
id: 'github-copilot-openai-compatible',
name: 'GitHub Copilot OpenAI Compatible',
import: () => import('@opeoginni/github-copilot-openai-compatible'),
creatorFunctionName: 'createGitHubCopilotOpenAICompatible',
supportsImageGeneration: false,
aliases: ['copilot', 'github-copilot']
},
{
id: 'bedrock',
name: 'Amazon Bedrock',
import: () => import('@ai-sdk/amazon-bedrock'),
creatorFunctionName: 'createAmazonBedrock',
supportsImageGeneration: true,
aliases: ['aws-bedrock']
},
{
id: 'perplexity',
name: 'Perplexity',
import: () => import('@ai-sdk/perplexity'),
creatorFunctionName: 'createPerplexity',
supportsImageGeneration: false,
aliases: ['perplexity']
},
{
id: 'mistral',
name: 'Mistral',
import: () => import('@ai-sdk/mistral'),
creatorFunctionName: 'createMistral',
supportsImageGeneration: false,
aliases: ['mistral']
},
{
id: 'huggingface',
name: 'HuggingFace',
import: () => import('@ai-sdk/huggingface'),
creatorFunctionName: 'createHuggingFace',
supportsImageGeneration: true,
aliases: ['hf', 'hugging-face']
},
{
id: 'ai-gateway',
name: 'AI Gateway',
import: () => import('@ai-sdk/gateway'),
creatorFunctionName: 'createGateway',
supportsImageGeneration: true,
aliases: ['gateway']
},
{
id: 'cerebras',
name: 'Cerebras',
import: () => import('@ai-sdk/cerebras'),
creatorFunctionName: 'createCerebras',
supportsImageGeneration: false
}
] as const
/**
* Initialize the new providers
* Uses aiCore's dynamic registration
*/
export async function initializeNewProviders(): Promise<void> {
initializeSharedProviders({
warn: (message) => logger.warn(message),
error: (message, error) => logger.error(message, error)
})
try {
const successCount = registerMultipleProviderConfigs(NEW_PROVIDER_CONFIGS)
if (successCount < NEW_PROVIDER_CONFIGS.length) {
logger.warn('Some providers failed to register. Check previous error logs.')
}
} catch (error) {
logger.error('Failed to initialize new providers:', error as Error)
}
}

View File

@@ -245,8 +245,8 @@ export class AiSdkSpanAdapter {
'gen_ai.usage.output_tokens'
]
const promptTokens = attributes[inputsTokenKeys.find((key) => attributes[key]) || '']
const completionTokens = attributes[outputTokenKeys.find((key) => attributes[key]) || '']
const completionTokens = attributes[inputsTokenKeys.find((key) => attributes[key]) || '']
const promptTokens = attributes[outputTokenKeys.find((key) => attributes[key]) || '']
if (completionTokens !== undefined || promptTokens !== undefined) {
const usage: TokenUsage = {

View File

@@ -1,53 +0,0 @@
import type { Span } from '@opentelemetry/api'
import { SpanKind, SpanStatusCode } from '@opentelemetry/api'
import { describe, expect, it, vi } from 'vitest'
import { AiSdkSpanAdapter } from '../AiSdkSpanAdapter'
vi.mock('@logger', () => ({
loggerService: {
withContext: () => ({
debug: vi.fn(),
error: vi.fn(),
info: vi.fn(),
warn: vi.fn()
})
}
}))
describe('AiSdkSpanAdapter', () => {
const createMockSpan = (attributes: Record<string, unknown>): Span => {
const span = {
spanContext: () => ({
traceId: 'trace-id',
spanId: 'span-id'
}),
_attributes: attributes,
_events: [],
name: 'test span',
status: { code: SpanStatusCode.OK },
kind: SpanKind.CLIENT,
startTime: [0, 0] as [number, number],
endTime: [0, 1] as [number, number],
ended: true,
parentSpanId: '',
links: []
}
return span as unknown as Span
}
it('maps prompt and completion usage tokens to the correct fields', () => {
const attributes = {
'ai.usage.promptTokens': 321,
'ai.usage.completionTokens': 654
}
const span = createMockSpan(attributes)
const result = AiSdkSpanAdapter.convertToSpanEntity({ span })
expect(result.usage).toBeDefined()
expect(result.usage?.prompt_tokens).toBe(321)
expect(result.usage?.completion_tokens).toBe(654)
expect(result.usage?.total_tokens).toBe(975)
})
})

View File

@@ -71,11 +71,10 @@ describe('mcp utils', () => {
const result = setupToolsConfig(mcpTools)
expect(result).not.toBeUndefined()
// Tools are now keyed by id (which includes serverId suffix) for uniqueness
expect(Object.keys(result!)).toEqual(['test-tool-1'])
expect(result!['test-tool-1']).toHaveProperty('description')
expect(result!['test-tool-1']).toHaveProperty('inputSchema')
expect(result!['test-tool-1']).toHaveProperty('execute')
expect(Object.keys(result!)).toEqual(['test-tool'])
expect(result!['test-tool']).toHaveProperty('description')
expect(result!['test-tool']).toHaveProperty('inputSchema')
expect(result!['test-tool']).toHaveProperty('execute')
})
it('should handle multiple MCP tools', () => {
@@ -110,8 +109,7 @@ describe('mcp utils', () => {
expect(result).not.toBeUndefined()
expect(Object.keys(result!)).toHaveLength(2)
// Tools are keyed by id for uniqueness
expect(Object.keys(result!)).toEqual(['tool1-id', 'tool2-id'])
expect(Object.keys(result!)).toEqual(['tool1', 'tool2'])
})
})
@@ -137,10 +135,9 @@ describe('mcp utils', () => {
const result = convertMcpToolsToAiSdkTools(mcpTools)
// Tools are keyed by id for uniqueness when multiple server instances exist
expect(Object.keys(result)).toEqual(['get-weather-id'])
expect(Object.keys(result)).toEqual(['get-weather'])
const tool = result['get-weather-id'] as Tool
const tool = result['get-weather'] as Tool
expect(tool.description).toBe('Get weather information')
expect(tool.inputSchema).toBeDefined()
expect(typeof tool.execute).toBe('function')
@@ -163,8 +160,8 @@ describe('mcp utils', () => {
const result = convertMcpToolsToAiSdkTools(mcpTools)
expect(Object.keys(result)).toEqual(['no-desc-tool-id'])
const tool = result['no-desc-tool-id'] as Tool
expect(Object.keys(result)).toEqual(['no-desc-tool'])
const tool = result['no-desc-tool'] as Tool
expect(tool.description).toBe('Tool from test-server')
})
@@ -205,13 +202,13 @@ describe('mcp utils', () => {
const result = convertMcpToolsToAiSdkTools(mcpTools)
expect(Object.keys(result)).toEqual(['complex-tool-id'])
const tool = result['complex-tool-id'] as Tool
expect(Object.keys(result)).toEqual(['complex-tool'])
const tool = result['complex-tool'] as Tool
expect(tool.inputSchema).toBeDefined()
expect(typeof tool.execute).toBe('function')
})
it('should preserve tool id with special characters', () => {
it('should preserve tool names with special characters', () => {
const mcpTools: MCPTool[] = [
{
id: 'special-tool-id',
@@ -228,8 +225,7 @@ describe('mcp utils', () => {
]
const result = convertMcpToolsToAiSdkTools(mcpTools)
// Tools are keyed by id for uniqueness
expect(Object.keys(result)).toEqual(['special-tool-id'])
expect(Object.keys(result)).toEqual(['tool_with-special.chars'])
})
it('should handle multiple tools with different schemas', () => {
@@ -280,11 +276,10 @@ describe('mcp utils', () => {
const result = convertMcpToolsToAiSdkTools(mcpTools)
// Tools are keyed by id for uniqueness
expect(Object.keys(result).sort()).toEqual(['boolean-tool-id', 'number-tool-id', 'string-tool-id'])
expect(result['string-tool-id']).toBeDefined()
expect(result['number-tool-id']).toBeDefined()
expect(result['boolean-tool-id']).toBeDefined()
expect(Object.keys(result).sort()).toEqual(['boolean-tool', 'number-tool', 'string-tool'])
expect(result['string-tool']).toBeDefined()
expect(result['number-tool']).toBeDefined()
expect(result['boolean-tool']).toBeDefined()
})
})
@@ -315,7 +310,7 @@ describe('mcp utils', () => {
]
const tools = convertMcpToolsToAiSdkTools(mcpTools)
const tool = tools['test-exec-tool-id'] as Tool
const tool = tools['test-exec-tool'] as Tool
const result = await tool.execute!({}, { messages: [], abortSignal: undefined, toolCallId: 'test-call-123' })
expect(requestToolConfirmation).toHaveBeenCalled()
@@ -348,7 +343,7 @@ describe('mcp utils', () => {
]
const tools = convertMcpToolsToAiSdkTools(mcpTools)
const tool = tools['cancelled-tool-id'] as Tool
const tool = tools['cancelled-tool'] as Tool
const result = await tool.execute!({}, { messages: [], abortSignal: undefined, toolCallId: 'cancel-call-123' })
expect(requestToolConfirmation).toHaveBeenCalled()
@@ -390,7 +385,7 @@ describe('mcp utils', () => {
]
const tools = convertMcpToolsToAiSdkTools(mcpTools)
const tool = tools['error-tool-id'] as Tool
const tool = tools['error-tool'] as Tool
await expect(
tool.execute!({}, { messages: [], abortSignal: undefined, toolCallId: 'error-call-123' })
@@ -426,7 +421,7 @@ describe('mcp utils', () => {
]
const tools = convertMcpToolsToAiSdkTools(mcpTools)
const tool = tools['auto-approve-tool-id'] as Tool
const tool = tools['auto-approve-tool'] as Tool
const result = await tool.execute!({}, { messages: [], abortSignal: undefined, toolCallId: 'auto-call-123' })
expect(requestToolConfirmation).not.toHaveBeenCalled()

View File

@@ -154,10 +154,6 @@ vi.mock('../websearch', () => ({
getWebSearchParams: vi.fn(() => ({ enable_search: true }))
}))
vi.mock('../../prepareParams/header', () => ({
addAnthropicHeaders: vi.fn(() => ['context-1m-2025-08-07'])
}))
const ensureWindowApi = () => {
const globalWindow = window as any
globalWindow.api = globalWindow.api || {}
@@ -637,64 +633,5 @@ describe('options utils', () => {
expect(result.providerOptions).toHaveProperty('anthropic')
})
})
describe('AWS Bedrock provider', () => {
const bedrockProvider = {
id: 'bedrock',
name: 'AWS Bedrock',
type: 'aws-bedrock',
apiKey: 'test-key',
apiHost: 'https://bedrock.us-east-1.amazonaws.com',
models: [] as Model[]
} as Provider
const bedrockModel: Model = {
id: 'anthropic.claude-sonnet-4-20250514-v1:0',
name: 'Claude Sonnet 4',
provider: 'bedrock'
} as Model
it('should build basic Bedrock options', () => {
const result = buildProviderOptions(mockAssistant, bedrockModel, bedrockProvider, {
enableReasoning: false,
enableWebSearch: false,
enableGenerateImage: false
})
expect(result.providerOptions).toHaveProperty('bedrock')
expect(result.providerOptions.bedrock).toBeDefined()
})
it('should include anthropicBeta when Anthropic headers are needed', async () => {
const { addAnthropicHeaders } = await import('../../prepareParams/header')
vi.mocked(addAnthropicHeaders).mockReturnValue(['interleaved-thinking-2025-05-14', 'context-1m-2025-08-07'])
const result = buildProviderOptions(mockAssistant, bedrockModel, bedrockProvider, {
enableReasoning: false,
enableWebSearch: false,
enableGenerateImage: false
})
expect(result.providerOptions.bedrock).toHaveProperty('anthropicBeta')
expect(result.providerOptions.bedrock.anthropicBeta).toEqual([
'interleaved-thinking-2025-05-14',
'context-1m-2025-08-07'
])
})
it('should include reasoning parameters when enabled', () => {
const result = buildProviderOptions(mockAssistant, bedrockModel, bedrockProvider, {
enableReasoning: true,
enableWebSearch: false,
enableGenerateImage: false
})
expect(result.providerOptions.bedrock).toHaveProperty('reasoningConfig')
expect(result.providerOptions.bedrock.reasoningConfig).toEqual({
type: 'enabled',
budgetTokens: 5000
})
})
})
})
})

View File

@@ -144,7 +144,7 @@ describe('reasoning utils', () => {
expect(result).toEqual({})
})
it('should not override reasoning for OpenRouter when reasoning effort undefined', async () => {
it('should disable reasoning for OpenRouter when no reasoning effort set', async () => {
const { isReasoningModel } = await import('@renderer/config/models')
vi.mocked(isReasoningModel).mockReturnValue(true)
@@ -161,29 +161,6 @@ describe('reasoning utils', () => {
settings: {}
} as Assistant
const result = getReasoningEffort(assistant, model)
expect(result).toEqual({})
})
it('should disable reasoning for OpenRouter when reasoning effort explicitly none', async () => {
const { isReasoningModel } = await import('@renderer/config/models')
vi.mocked(isReasoningModel).mockReturnValue(true)
const model: Model = {
id: 'anthropic/claude-sonnet-4',
name: 'Claude Sonnet 4',
provider: SystemProviderIds.openrouter
} as Model
const assistant: Assistant = {
id: 'test',
name: 'Test',
settings: {
reasoning_effort: 'none'
}
} as Assistant
const result = getReasoningEffort(assistant, model)
expect(result).toEqual({ reasoning: { enabled: false, exclude: true } })
})
@@ -292,9 +269,7 @@ describe('reasoning utils', () => {
const assistant: Assistant = {
id: 'test',
name: 'Test',
settings: {
reasoning_effort: 'none'
}
settings: {}
} as Assistant
const result = getReasoningEffort(assistant, model)


@@ -28,9 +28,7 @@ export function convertMcpToolsToAiSdkTools(mcpTools: MCPTool[]): ToolSet {
const tools: ToolSet = {}
for (const mcpTool of mcpTools) {
// Use mcpTool.id (which includes serverId suffix) to ensure uniqueness
// when multiple instances of the same MCP server type are configured
tools[mcpTool.id] = tool({
tools[mcpTool.name] = tool({
description: mcpTool.description || `Tool from ${mcpTool.serverName}`,
inputSchema: jsonSchema(mcpTool.inputSchema as JSONSchema7),
execute: async (params, { toolCallId }) => {

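For context on the id-vs-name keying change above: a minimal sketch (with simplified stand-in types, not the real MCPTool shape) of why keying the ToolSet by mcpTool.id survives two server instances exposing the same tool name, while keying by name silently drops one.

```ts
// Minimal sketch: two MCP server instances expose the same tool name.
// McpToolLike is a simplified stand-in, not the real MCPTool type.
type McpToolLike = { id: string; name: string; serverName: string }

const mcpTools: McpToolLike[] = [
  { id: 'read_file-server-a', name: 'read_file', serverName: 'fs-a' },
  { id: 'read_file-server-b', name: 'read_file', serverName: 'fs-b' }
]

// Keyed by name, the second entry overwrites the first:
const byName = Object.fromEntries(mcpTools.map((t) => [t.name, t]))
console.log(Object.keys(byName).length) // 1

// Keyed by id (which carries the server suffix), both survive:
const byId = Object.fromEntries(mcpTools.map((t) => [t.id, t]))
console.log(Object.keys(byId).length) // 2
```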

@@ -36,7 +36,6 @@ import { isSupportServiceTierProvider, isSupportVerbosityProvider } from '@rende
import type { JSONValue } from 'ai'
import { t } from 'i18next'
import { addAnthropicHeaders } from '../prepareParams/header'
import { getAiSdkProviderId } from '../provider/factory'
import { buildGeminiGenerateImageParams } from './image'
import {
@@ -470,11 +469,6 @@ function buildBedrockProviderOptions(
}
}
const betaHeaders = addAnthropicHeaders(assistant, model)
if (betaHeaders.length > 0) {
providerOptions.anthropicBeta = betaHeaders
}
return providerOptions
}
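The block removed above forwarded Anthropic beta headers as a Bedrock provider option; a minimal sketch of that shape, assuming addAnthropicHeaders returns a string[] of beta flags:

```ts
// Sketch: attach Anthropic beta flags as a Bedrock provider option.
// The guard mirrors the removed block: only set the field when non-empty.
type BedrockProviderOptions = { anthropicBeta?: string[]; [key: string]: unknown }

function withAnthropicBeta(options: BedrockProviderOptions, betaHeaders: string[]): BedrockProviderOptions {
  return betaHeaders.length > 0 ? { ...options, anthropicBeta: betaHeaders } : options
}

console.log(withAnthropicBeta({}, ['interleaved-thinking-2025-05-14', 'context-1m-2025-08-07']))
// -> { anthropicBeta: [ 'interleaved-thinking-2025-05-14', 'context-1m-2025-08-07' ] }
console.log(withAnthropicBeta({}, [])) // -> {}
```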


@@ -16,8 +16,10 @@ import {
isGPT5SeriesModel,
isGPT51SeriesModel,
isGrok4FastReasoningModel,
isGrokReasoningModel,
isOpenAIDeepResearchModel,
isOpenAIModel,
isOpenAIReasoningModel,
isQwenAlwaysThinkModel,
isQwenReasoningModel,
isReasoningModel,
@@ -62,22 +64,30 @@ export function getReasoningEffort(assistant: Assistant, model: Model): Reasonin
}
const reasoningEffort = assistant?.settings?.reasoning_effort
// reasoningEffort is not set, no extra reasoning setting
// Generally, for every model which supports reasoning control, the reasoning effort won't be undefined.
// It's for some reasoning models that don't support reasoning control, such as deepseek reasoner.
if (!reasoningEffort) {
return {}
}
// Handle 'none' reasoningEffort. It's explicitly off.
if (reasoningEffort === 'none') {
// Handle undefined and 'none' reasoningEffort.
// TODO: They should be separated.
if (!reasoningEffort || reasoningEffort === 'none') {
// openrouter: use reasoning
if (model.provider === SystemProviderIds.openrouter) {
// Don't disable reasoning for Gemini models that support thinking tokens
if (isSupportedThinkingTokenGeminiModel(model) && !GEMINI_FLASH_MODEL_REGEX.test(model.id)) {
return {}
}
// 'none' is not an available value for effort for now.
// I think they should resolve this issue soon, so I'll just go ahead and use this value.
if (isGPT51SeriesModel(model) && reasoningEffort === 'none') {
return { reasoning: { effort: 'none' } }
}
// Don't disable reasoning for models that require it
if (
isGrokReasoningModel(model) ||
isOpenAIReasoningModel(model) ||
isQwenAlwaysThinkModel(model) ||
model.id.includes('seed-oss') ||
model.id.includes('minimax-m2')
) {
return {}
}
return { reasoning: { enabled: false, exclude: true } }
}
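A condensed sketch of the OpenRouter branch above, with the model-detection helpers stubbed as plain booleans; this illustrates the control flow only, not the full getReasoningEffort:

```ts
// Condensed sketch of the 'none'/undefined branch for OpenRouter.
// The `is` flags stand in for the real model-detection predicates.
type ReasoningParams = Record<string, unknown>

function openRouterReasoningOff(
  effort: string | undefined,
  is: { geminiThinking: boolean; geminiFlash: boolean; gpt51: boolean; requiresReasoning: boolean }
): ReasoningParams {
  // Gemini thinking-token models (except flash variants) keep their defaults.
  if (is.geminiThinking && !is.geminiFlash) return {}
  // GPT-5.1 accepts an explicit 'none' effort value.
  if (is.gpt51 && effort === 'none') return { reasoning: { effort: 'none' } }
  // Models that cannot run without reasoning are left untouched.
  if (is.requiresReasoning) return {}
  return { reasoning: { enabled: false, exclude: true } }
}

console.log(
  openRouterReasoningOff('none', { geminiThinking: false, geminiFlash: false, gpt51: false, requiresReasoning: false })
)
// -> { reasoning: { enabled: false, exclude: true } }
```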
@@ -91,6 +101,11 @@ export function getReasoningEffort(assistant: Assistant, model: Model): Reasonin
return { enable_thinking: false }
}
// claude
if (isSupportedThinkingTokenClaudeModel(model)) {
return {}
}
// gemini
if (isSupportedThinkingTokenGeminiModel(model)) {
if (GEMINI_FLASH_MODEL_REGEX.test(model.id)) {
@@ -103,10 +118,8 @@ export function getReasoningEffort(assistant: Assistant, model: Model): Reasonin
}
}
}
} else {
logger.warn(`Model ${model.id} cannot disable reasoning. Falling back to empty reasoning param.`)
return {}
}
return {}
}
// use thinking, doubao, zhipu, etc.
@@ -126,7 +139,6 @@ export function getReasoningEffort(assistant: Assistant, model: Model): Reasonin
}
}
logger.warn(`Model ${model.id} doesn't match any disable-reasoning behavior. Falling back to empty reasoning param.`)
return {}
}
@@ -281,7 +293,6 @@ export function getReasoningEffort(assistant: Assistant, model: Model): Reasonin
}
// OpenRouter models, use reasoning
// FIXME: duplicated openrouter handling. remove one
if (model.provider === SystemProviderIds.openrouter) {
if (isSupportedReasoningEffortModel(model) || isSupportedThinkingTokenModel(model)) {
return {
@@ -673,10 +684,6 @@ export function getCustomParameters(assistant: Assistant): Record<string, any> {
if (!param.name?.trim()) {
return acc
}
// Parse JSON type parameters
// Related: src/renderer/src/pages/settings/AssistantSettings/AssistantModelSettings.tsx:133-148
// The UI stores JSON type params as strings (e.g., '{"key":"value"}')
// This function parses them into objects before sending to the API
if (param.type === 'json') {
const value = param.value as string
if (value === 'undefined') {

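On the JSON-parameter comment above: the UI stores 'json'-typed custom parameters as strings, so they must be parsed before being sent to the API. A minimal sketch; the 'undefined' sentinel comes from the diff, while the keep-raw-string fallback is an assumption:

```ts
// Sketch: parse a UI-stored JSON parameter string (e.g. '{"key":"value"}').
// Falling back to the raw string on parse failure is an assumption here.
function parseJsonParam(value: string): unknown {
  if (value === 'undefined') return undefined // sentinel seen in the diff
  try {
    return JSON.parse(value)
  } catch {
    return value
  }
}

console.log(parseJsonParam('{"key":"value"}')) // -> { key: 'value' }
console.log(parseJsonParam('not json')) // -> 'not json'
```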

@@ -215,10 +215,6 @@
border-top: none !important;
}
.ant-collapse-header-text {
overflow-x: hidden;
}
.ant-slider .ant-slider-handle::after {
box-shadow: 0 1px 4px 0px rgb(128 128 128 / 50%) !important;
}


@@ -10,7 +10,6 @@ import {
} from '@ant-design/icons'
import { loggerService } from '@logger'
import { download } from '@renderer/utils/download'
import { convertImageToPng } from '@renderer/utils/image'
import type { ImageProps as AntImageProps } from 'antd'
import { Dropdown, Image as AntImage, Space } from 'antd'
import { Base64 } from 'js-base64'
@@ -34,38 +33,39 @@ const ImageViewer: React.FC<ImageViewerProps> = ({ src, style, ...props }) => {
// Copy the image to the clipboard
const handleCopyImage = async (src: string) => {
try {
let blob: Blob
if (src.startsWith('data:')) {
// Handle base64-encoded images
const match = src.match(/^data:(image\/\w+);base64,(.+)$/)
if (!match) throw new Error('Invalid base64 image format')
const mimeType = match[1]
const byteArray = Base64.toUint8Array(match[2])
blob = new Blob([byteArray], { type: mimeType })
const blob = new Blob([byteArray], { type: mimeType })
await navigator.clipboard.write([new ClipboardItem({ [mimeType]: blob })])
} else if (src.startsWith('file://')) {
// Handle local file paths
const bytes = await window.api.fs.read(src)
const mimeType = mime.getType(src) || 'application/octet-stream'
blob = new Blob([bytes], { type: mimeType })
const blob = new Blob([bytes], { type: mimeType })
await navigator.clipboard.write([
new ClipboardItem({
[mimeType]: blob
})
])
} else {
// Handle URL-based images
const response = await fetch(src)
blob = await response.blob()
const blob = await response.blob()
await navigator.clipboard.write([
new ClipboardItem({
[blob.type]: blob
})
])
}
// Always convert to PNG for compatibility (the clipboard API does not support JPEG)
const pngBlob = await convertImageToPng(blob)
const item = new ClipboardItem({
'image/png': pngBlob
})
await navigator.clipboard.write([item])
window.toast.success(t('message.copy.success'))
} catch (error) {
const err = error as Error
logger.error(`Failed to copy image: ${err.message}`, { stack: err.stack })
logger.error('Failed to copy image:', error as Error)
window.toast.error(t('message.copy.failed'))
}
}
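The convertImageToPng helper imported above isn't shown in this diff; a plausible browser-side sketch using a canvas, with names and behavior assumed rather than taken from utils/image:

```ts
// Plausible sketch of a Blob-to-PNG conversion via canvas; the real
// convertImageToPng in @renderer/utils/image may differ.
async function convertImageToPng(blob: Blob): Promise<Blob> {
  const bitmap = await createImageBitmap(blob)
  const canvas = document.createElement('canvas')
  canvas.width = bitmap.width
  canvas.height = bitmap.height
  const ctx = canvas.getContext('2d')
  if (!ctx) throw new Error('2D canvas context unavailable')
  ctx.drawImage(bitmap, 0, 0)
  return new Promise((resolve, reject) => {
    // The clipboard API only reliably accepts PNG, so always encode as PNG.
    canvas.toBlob((png) => (png ? resolve(png) : reject(new Error('PNG encoding failed'))), 'image/png')
  })
}
```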


@@ -57,7 +57,7 @@ const PopupContainer: React.FC<Props> = ({ model, apiFilter, modelFilter, showTa
const [_searchText, setSearchText] = useState('')
const searchText = useDeferredValue(_searchText)
const { models, isLoading } = useApiModels(apiFilter)
const adaptedModels = useMemo(() => models.map((model) => apiModelAdapter(model)), [models])
const adaptedModels = models.map((model) => apiModelAdapter(model))
// ID of the currently selected model
const currentModelId = model ? model.id : ''
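The useMemo wrapper added above keeps apiModelAdapter from re-running on every render; a minimal sketch of the pattern:

```ts
// Sketch: memoize the adapted list so it is recomputed only when
// `models` (or the adapter) changes identity, not on every render.
import { useMemo } from 'react'

function useAdaptedModels<T, U>(models: T[], adapt: (m: T) => U): U[] {
  return useMemo(() => models.map(adapt), [models, adapt])
}
```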


@@ -309,14 +309,11 @@ describe('Ling Models', () => {
describe('Claude & regional providers', () => {
it('identifies claude 4.5 variants', () => {
expect(isClaude45ReasoningModel(createModel({ id: 'claude-sonnet-4.5-preview' }))).toBe(true)
expect(isClaude4SeriesModel(createModel({ id: 'claude-sonnet-4-5@20250929' }))).toBe(true)
expect(isClaude45ReasoningModel(createModel({ id: 'claude-3-sonnet' }))).toBe(false)
})
it('identifies claude 4 variants', () => {
expect(isClaude4SeriesModel(createModel({ id: 'claude-opus-4' }))).toBe(true)
expect(isClaude4SeriesModel(createModel({ id: 'claude-sonnet-4@20250514' }))).toBe(true)
expect(isClaude4SeriesModel(createModel({ id: 'anthropic.claude-sonnet-4-20250514-v1:0' }))).toBe(true)
expect(isClaude4SeriesModel(createModel({ id: 'claude-4.2-sonnet-variant' }))).toBe(false)
expect(isClaude4SeriesModel(createModel({ id: 'claude-3-haiku' }))).toBe(false)
})


@@ -15,7 +15,6 @@ import {
isSupportVerbosityModel
} from '../openai'
import { isQwenMTModel } from '../qwen'
import { isFunctionCallingModel } from '../tooluse'
import {
agentModelFilter,
getModelSupportedVerbosity,
@@ -113,7 +112,6 @@ const textToImageMock = vi.mocked(isTextToImageModel)
const generateImageMock = vi.mocked(isGenerateImageModel)
const reasoningMock = vi.mocked(isOpenAIReasoningModel)
const openAIWebSearchOnlyMock = vi.mocked(isOpenAIWebSearchChatCompletionOnlyModel)
const isFunctionCallingModelMock = vi.mocked(isFunctionCallingModel)
describe('model utils', () => {
beforeEach(() => {
@@ -122,387 +120,200 @@ describe('model utils', () => {
rerankMock.mockReturnValue(false)
visionMock.mockReturnValue(true)
textToImageMock.mockReturnValue(false)
generateImageMock.mockReturnValue(false)
generateImageMock.mockReturnValue(true)
reasoningMock.mockReturnValue(false)
openAIWebSearchOnlyMock.mockReturnValue(false)
})
describe('OpenAI model detection', () => {
describe('isOpenAILLMModel', () => {
it('returns false for undefined model', () => {
expect(isOpenAILLMModel(undefined as unknown as Model)).toBe(false)
})
it('detects OpenAI LLM models through reasoning and GPT prefix', () => {
expect(isOpenAILLMModel(undefined as unknown as Model)).toBe(false)
expect(isOpenAILLMModel(createModel({ id: 'gpt-4o-image' }))).toBe(false)
it('returns false for image generation models', () => {
expect(isOpenAILLMModel(createModel({ id: 'gpt-4o-image' }))).toBe(false)
})
reasoningMock.mockReturnValueOnce(true)
expect(isOpenAILLMModel(createModel({ id: 'o1-preview' }))).toBe(true)
it('returns true for reasoning models', () => {
reasoningMock.mockReturnValueOnce(true)
expect(isOpenAILLMModel(createModel({ id: 'o1-preview' }))).toBe(true)
})
expect(isOpenAILLMModel(createModel({ id: 'GPT-5-turbo' }))).toBe(true)
})
it('returns true for GPT-prefixed models', () => {
expect(isOpenAILLMModel(createModel({ id: 'GPT-5-turbo' }))).toBe(true)
})
it('detects OpenAI models via GPT prefix or reasoning support', () => {
expect(isOpenAIModel(createModel({ id: 'gpt-4.1' }))).toBe(true)
reasoningMock.mockReturnValueOnce(true)
expect(isOpenAIModel(createModel({ id: 'o3' }))).toBe(true)
})
it('evaluates support for flex service tier and alias helper', () => {
expect(isSupportFlexServiceTierModel(createModel({ id: 'o3' }))).toBe(true)
expect(isSupportFlexServiceTierModel(createModel({ id: 'o3-mini' }))).toBe(false)
expect(isSupportFlexServiceTierModel(createModel({ id: 'o4-mini' }))).toBe(true)
expect(isSupportFlexServiceTierModel(createModel({ id: 'gpt-5-preview' }))).toBe(true)
expect(isSupportedFlexServiceTier(createModel({ id: 'gpt-4o' }))).toBe(false)
})
it('detects verbosity support for GPT-5+ families', () => {
expect(isSupportVerbosityModel(createModel({ id: 'gpt-5' }))).toBe(true)
expect(isSupportVerbosityModel(createModel({ id: 'gpt-5-chat' }))).toBe(false)
expect(isSupportVerbosityModel(createModel({ id: 'gpt-5.1-preview' }))).toBe(true)
})
it('limits verbosity controls for GPT-5 Pro models', () => {
const proModel = createModel({ id: 'gpt-5-pro' })
const previewModel = createModel({ id: 'gpt-5-preview' })
expect(getModelSupportedVerbosity(proModel)).toEqual([undefined, 'high'])
expect(getModelSupportedVerbosity(previewModel)).toEqual([undefined, 'low', 'medium', 'high'])
expect(isGPT5ProModel(proModel)).toBe(true)
expect(isGPT5ProModel(previewModel)).toBe(false)
})
it('identifies OpenAI chat-completion-only models', () => {
expect(isOpenAIChatCompletionOnlyModel(createModel({ id: 'gpt-4o-search-preview' }))).toBe(true)
expect(isOpenAIChatCompletionOnlyModel(createModel({ id: 'o1-mini' }))).toBe(true)
expect(isOpenAIChatCompletionOnlyModel(createModel({ id: 'gpt-4o' }))).toBe(false)
})
it('filters unsupported OpenAI catalog entries', () => {
expect(isSupportedModel({ id: 'gpt-4', object: 'model' } as any)).toBe(true)
expect(isSupportedModel({ id: 'tts-1', object: 'model' } as any)).toBe(false)
})
it('calculates temperature/top-p support correctly', () => {
const model = createModel({ id: 'o1' })
reasoningMock.mockReturnValue(true)
expect(isNotSupportTemperatureAndTopP(model)).toBe(true)
const openWeight = createModel({ id: 'gpt-oss-debug' })
expect(isNotSupportTemperatureAndTopP(openWeight)).toBe(false)
const chatOnly = createModel({ id: 'o1-preview' })
reasoningMock.mockReturnValue(false)
expect(isNotSupportTemperatureAndTopP(chatOnly)).toBe(true)
const qwenMt = createModel({ id: 'qwen-mt-large', provider: 'aliyun' })
expect(isNotSupportTemperatureAndTopP(qwenMt)).toBe(true)
})
it('handles gemma and gemini detections plus zhipu tagging', () => {
expect(isGemmaModel(createModel({ id: 'Gemma-3-27B' }))).toBe(true)
expect(isGemmaModel(createModel({ group: 'Gemma' }))).toBe(true)
expect(isGemmaModel(createModel({ id: 'gpt-4o' }))).toBe(false)
expect(isGeminiModel(createModel({ id: 'Gemini-2.0' }))).toBe(true)
expect(isZhipuModel(createModel({ provider: 'zhipu' }))).toBe(true)
expect(isZhipuModel(createModel({ provider: 'openai' }))).toBe(false)
})
it('groups qwen models by prefix', () => {
const qwen = createModel({ id: 'Qwen-7B', provider: 'qwen', name: 'Qwen-7B' })
const qwenOmni = createModel({ id: 'qwen2.5-omni', name: 'qwen2.5-omni' })
const other = createModel({ id: 'deepseek-v3', group: 'DeepSeek' })
const grouped = groupQwenModels([qwen, qwenOmni, other])
expect(Object.keys(grouped)).toContain('qwen-7b')
expect(Object.keys(grouped)).toContain('qwen2.5')
expect(grouped.DeepSeek).toContain(other)
})
it('aggregates boolean helpers based on regex rules', () => {
expect(isAnthropicModel(createModel({ id: 'claude-3.5' }))).toBe(true)
expect(isQwenMTModel(createModel({ id: 'qwen-mt-plus' }))).toBe(true)
expect(isNotSupportSystemMessageModel(createModel({ id: 'gemma-moe' }))).toBe(true)
expect(isOpenAIOpenWeightModel(createModel({ id: 'gpt-oss-free' }))).toBe(true)
})
describe('isNotSupportedTextDelta', () => {
it('returns true for qwen-mt-turbo and qwen-mt-plus models', () => {
// qwen-mt series that don't support text delta
expect(isNotSupportTextDeltaModel(createModel({ id: 'qwen-mt-turbo' }))).toBe(true)
expect(isNotSupportTextDeltaModel(createModel({ id: 'qwen-mt-plus' }))).toBe(true)
expect(isNotSupportTextDeltaModel(createModel({ id: 'Qwen-MT-Turbo' }))).toBe(true)
expect(isNotSupportTextDeltaModel(createModel({ id: 'QWEN-MT-PLUS' }))).toBe(true)
})
describe('isOpenAIModel', () => {
it('detects models via GPT prefix', () => {
expect(isOpenAIModel(createModel({ id: 'gpt-4.1' }))).toBe(true)
})
it('returns false for qwen-mt-flash and other models', () => {
// qwen-mt-flash supports text delta
expect(isNotSupportTextDeltaModel(createModel({ id: 'qwen-mt-flash' }))).toBe(false)
expect(isNotSupportTextDeltaModel(createModel({ id: 'Qwen-MT-Flash' }))).toBe(false)
it('detects models via reasoning support', () => {
reasoningMock.mockReturnValueOnce(true)
expect(isOpenAIModel(createModel({ id: 'o3' }))).toBe(true)
})
// Legacy qwen models without mt prefix (support text delta)
expect(isNotSupportTextDeltaModel(createModel({ id: 'qwen-turbo' }))).toBe(false)
expect(isNotSupportTextDeltaModel(createModel({ id: 'qwen-plus' }))).toBe(false)
// Other qwen models
expect(isNotSupportTextDeltaModel(createModel({ id: 'qwen-max' }))).toBe(false)
expect(isNotSupportTextDeltaModel(createModel({ id: 'qwen2.5-72b' }))).toBe(false)
expect(isNotSupportTextDeltaModel(createModel({ id: 'qwen-vl-plus' }))).toBe(false)
// Non-qwen models
expect(isNotSupportTextDeltaModel(createModel({ id: 'gpt-4o' }))).toBe(false)
expect(isNotSupportTextDeltaModel(createModel({ id: 'claude-3.5' }))).toBe(false)
expect(isNotSupportTextDeltaModel(createModel({ id: 'glm-4-plus' }))).toBe(false)
})
describe('isOpenAIChatCompletionOnlyModel', () => {
it('identifies chat-completion-only models', () => {
expect(isOpenAIChatCompletionOnlyModel(createModel({ id: 'gpt-4o-search-preview' }))).toBe(true)
expect(isOpenAIChatCompletionOnlyModel(createModel({ id: 'o1-mini' }))).toBe(true)
})
it('handles models with version suffixes', () => {
// qwen-mt models with version suffixes
expect(isNotSupportTextDeltaModel(createModel({ id: 'qwen-mt-turbo-1201' }))).toBe(true)
expect(isNotSupportTextDeltaModel(createModel({ id: 'qwen-mt-plus-0828' }))).toBe(true)
it('returns false for general models', () => {
expect(isOpenAIChatCompletionOnlyModel(createModel({ id: 'gpt-4o' }))).toBe(false)
})
// Legacy qwen models with version suffixes (support text delta)
expect(isNotSupportTextDeltaModel(createModel({ id: 'qwen-turbo-0828' }))).toBe(false)
expect(isNotSupportTextDeltaModel(createModel({ id: 'qwen-plus-latest' }))).toBe(false)
})
})
describe('GPT-5 family detection', () => {
describe('isGPT5SeriesModel', () => {
it('returns true for GPT-5 models', () => {
expect(isGPT5SeriesModel(createModel({ id: 'gpt-5-preview' }))).toBe(true)
})
it('returns false for GPT-5.1 models', () => {
expect(isGPT5SeriesModel(createModel({ id: 'gpt-5.1-preview' }))).toBe(false)
})
})
describe('isGPT51SeriesModel', () => {
it('returns true for GPT-5.1 models', () => {
expect(isGPT51SeriesModel(createModel({ id: 'gpt-5.1-mini' }))).toBe(true)
})
})
describe('isGPT5SeriesReasoningModel', () => {
it('returns true for GPT-5 reasoning models', () => {
expect(isGPT5SeriesReasoningModel(createModel({ id: 'gpt-5' }))).toBe(true)
})
it('returns false for gpt-5-chat', () => {
expect(isGPT5SeriesReasoningModel(createModel({ id: 'gpt-5-chat' }))).toBe(false)
})
})
describe('isGPT5ProModel', () => {
it('returns true for GPT-5 Pro models', () => {
expect(isGPT5ProModel(createModel({ id: 'gpt-5-pro' }))).toBe(true)
})
it('returns false for non-Pro GPT-5 models', () => {
expect(isGPT5ProModel(createModel({ id: 'gpt-5-preview' }))).toBe(false)
})
})
it('evaluates GPT-5 family helpers', () => {
expect(isGPT5SeriesModel(createModel({ id: 'gpt-5-preview' }))).toBe(true)
expect(isGPT5SeriesModel(createModel({ id: 'gpt-5.1-preview' }))).toBe(false)
expect(isGPT51SeriesModel(createModel({ id: 'gpt-5.1-mini' }))).toBe(true)
expect(isGPT5SeriesReasoningModel(createModel({ id: 'gpt-5-prompt' }))).toBe(true)
expect(isSupportVerbosityModel(createModel({ id: 'gpt-5-chat' }))).toBe(false)
})
describe('Verbosity support', () => {
describe('isSupportVerbosityModel', () => {
it('returns true for GPT-5 models', () => {
expect(isSupportVerbosityModel(createModel({ id: 'gpt-5' }))).toBe(true)
})
it('wraps generate/vision helpers that operate on arrays', () => {
const models = [createModel({ id: 'gpt-4o' }), createModel({ id: 'gpt-4o-mini' })]
expect(isVisionModels(models)).toBe(true)
visionMock.mockReturnValueOnce(true).mockReturnValueOnce(false)
expect(isVisionModels(models)).toBe(false)
it('returns false for GPT-5 chat models', () => {
expect(isSupportVerbosityModel(createModel({ id: 'gpt-5-chat' }))).toBe(false)
})
it('returns true for GPT-5.1 models', () => {
expect(isSupportVerbosityModel(createModel({ id: 'gpt-5.1-preview' }))).toBe(true)
})
})
describe('getModelSupportedVerbosity', () => {
it('returns only "high" for GPT-5 Pro models', () => {
expect(getModelSupportedVerbosity(createModel({ id: 'gpt-5-pro' }))).toEqual([undefined, 'high'])
expect(getModelSupportedVerbosity(createModel({ id: 'gpt-5-pro-2025-10-06' }))).toEqual([undefined, 'high'])
})
it('returns all levels for non-Pro GPT-5 models', () => {
const previewModel = createModel({ id: 'gpt-5-preview' })
expect(getModelSupportedVerbosity(previewModel)).toEqual([undefined, 'low', 'medium', 'high'])
})
it('returns all levels for GPT-5.1 models', () => {
const gpt51Model = createModel({ id: 'gpt-5.1-preview' })
expect(getModelSupportedVerbosity(gpt51Model)).toEqual([undefined, 'low', 'medium', 'high'])
})
it('returns only undefined for non-GPT-5 models', () => {
expect(getModelSupportedVerbosity(createModel({ id: 'gpt-4o' }))).toEqual([undefined])
expect(getModelSupportedVerbosity(createModel({ id: 'claude-3.5' }))).toEqual([undefined])
})
it('returns only undefined for undefined/null input', () => {
expect(getModelSupportedVerbosity(undefined)).toEqual([undefined])
expect(getModelSupportedVerbosity(null)).toEqual([undefined])
})
})
expect(isGenerateImageModels(models)).toBe(true)
generateImageMock.mockReturnValueOnce(true).mockReturnValueOnce(false)
expect(isGenerateImageModels(models)).toBe(false)
})
describe('Flex service tier support', () => {
describe('isSupportFlexServiceTierModel', () => {
it('returns true for supported models', () => {
expect(isSupportFlexServiceTierModel(createModel({ id: 'o3' }))).toBe(true)
expect(isSupportFlexServiceTierModel(createModel({ id: 'o4-mini' }))).toBe(true)
expect(isSupportFlexServiceTierModel(createModel({ id: 'gpt-5-preview' }))).toBe(true)
})
it('filters models for agent usage', () => {
expect(agentModelFilter(createModel())).toBe(true)
it('returns false for unsupported models', () => {
expect(isSupportFlexServiceTierModel(createModel({ id: 'o3-mini' }))).toBe(false)
})
})
embeddingMock.mockReturnValueOnce(true)
expect(agentModelFilter(createModel({ id: 'text-embedding' }))).toBe(false)
describe('isSupportedFlexServiceTier', () => {
it('returns false for non-flex models', () => {
expect(isSupportedFlexServiceTier(createModel({ id: 'gpt-4o' }))).toBe(false)
})
})
embeddingMock.mockReturnValue(false)
rerankMock.mockReturnValueOnce(true)
expect(agentModelFilter(createModel({ id: 'rerank' }))).toBe(false)
rerankMock.mockReturnValue(false)
textToImageMock.mockReturnValueOnce(true)
expect(agentModelFilter(createModel({ id: 'gpt-image-1' }))).toBe(false)
})
describe('Temperature and top-p support', () => {
describe('isNotSupportTemperatureAndTopP', () => {
it('returns true for reasoning models', () => {
const model = createModel({ id: 'o1' })
reasoningMock.mockReturnValue(true)
expect(isNotSupportTemperatureAndTopP(model)).toBe(true)
})
it('identifies models with maximum temperature of 1.0', () => {
// Zhipu models should have max temperature of 1.0
expect(isMaxTemperatureOneModel(createModel({ id: 'glm-4' }))).toBe(true)
expect(isMaxTemperatureOneModel(createModel({ id: 'GLM-4-Plus' }))).toBe(true)
expect(isMaxTemperatureOneModel(createModel({ id: 'glm-3-turbo' }))).toBe(true)
it('returns false for open weight models', () => {
const openWeight = createModel({ id: 'gpt-oss-debug' })
expect(isNotSupportTemperatureAndTopP(openWeight)).toBe(false)
})
// Anthropic models should have max temperature of 1.0
expect(isMaxTemperatureOneModel(createModel({ id: 'claude-3.5-sonnet' }))).toBe(true)
expect(isMaxTemperatureOneModel(createModel({ id: 'Claude-3-opus' }))).toBe(true)
expect(isMaxTemperatureOneModel(createModel({ id: 'claude-2.1' }))).toBe(true)
it('returns true for chat-only models without reasoning', () => {
const chatOnly = createModel({ id: 'o1-preview' })
reasoningMock.mockReturnValue(false)
expect(isNotSupportTemperatureAndTopP(chatOnly)).toBe(true)
})
// Moonshot models should have max temperature of 1.0
expect(isMaxTemperatureOneModel(createModel({ id: 'moonshot-1.0' }))).toBe(true)
expect(isMaxTemperatureOneModel(createModel({ id: 'kimi-k2-thinking' }))).toBe(true)
expect(isMaxTemperatureOneModel(createModel({ id: 'Moonshot-Pro' }))).toBe(true)
it('returns true for Qwen MT models', () => {
const qwenMt = createModel({ id: 'qwen-mt-large', provider: 'aliyun' })
expect(isNotSupportTemperatureAndTopP(qwenMt)).toBe(true)
})
})
})
describe('Text delta support', () => {
describe('isNotSupportTextDeltaModel', () => {
it('returns true for qwen-mt-turbo and qwen-mt-plus models', () => {
expect(isNotSupportTextDeltaModel(createModel({ id: 'qwen-mt-turbo' }))).toBe(true)
expect(isNotSupportTextDeltaModel(createModel({ id: 'qwen-mt-plus' }))).toBe(true)
expect(isNotSupportTextDeltaModel(createModel({ id: 'Qwen-MT-Turbo' }))).toBe(true)
expect(isNotSupportTextDeltaModel(createModel({ id: 'QWEN-MT-PLUS' }))).toBe(true)
})
it('returns false for qwen-mt-flash and other models', () => {
expect(isNotSupportTextDeltaModel(createModel({ id: 'qwen-mt-flash' }))).toBe(false)
expect(isNotSupportTextDeltaModel(createModel({ id: 'Qwen-MT-Flash' }))).toBe(false)
expect(isNotSupportTextDeltaModel(createModel({ id: 'qwen-turbo' }))).toBe(false)
expect(isNotSupportTextDeltaModel(createModel({ id: 'qwen-plus' }))).toBe(false)
expect(isNotSupportTextDeltaModel(createModel({ id: 'qwen-max' }))).toBe(false)
expect(isNotSupportTextDeltaModel(createModel({ id: 'qwen2.5-72b' }))).toBe(false)
expect(isNotSupportTextDeltaModel(createModel({ id: 'qwen-vl-plus' }))).toBe(false)
})
it('returns false for non-qwen models', () => {
expect(isNotSupportTextDeltaModel(createModel({ id: 'gpt-4o' }))).toBe(false)
expect(isNotSupportTextDeltaModel(createModel({ id: 'claude-3.5' }))).toBe(false)
expect(isNotSupportTextDeltaModel(createModel({ id: 'glm-4-plus' }))).toBe(false)
})
it('handles models with version suffixes', () => {
expect(isNotSupportTextDeltaModel(createModel({ id: 'qwen-mt-turbo-1201' }))).toBe(true)
expect(isNotSupportTextDeltaModel(createModel({ id: 'qwen-mt-plus-0828' }))).toBe(true)
expect(isNotSupportTextDeltaModel(createModel({ id: 'qwen-turbo-0828' }))).toBe(false)
expect(isNotSupportTextDeltaModel(createModel({ id: 'qwen-plus-latest' }))).toBe(false)
})
})
})
describe('Model provider detection', () => {
describe('isGemmaModel', () => {
it('detects Gemma models by ID', () => {
expect(isGemmaModel(createModel({ id: 'Gemma-3-27B' }))).toBe(true)
})
it('detects Gemma models by group', () => {
expect(isGemmaModel(createModel({ group: 'Gemma' }))).toBe(true)
})
it('returns false for non-Gemma models', () => {
expect(isGemmaModel(createModel({ id: 'gpt-4o' }))).toBe(false)
})
})
describe('isGeminiModel', () => {
it('detects Gemini models', () => {
expect(isGeminiModel(createModel({ id: 'Gemini-2.0' }))).toBe(true)
})
})
describe('isZhipuModel', () => {
it('detects Zhipu models by provider', () => {
expect(isZhipuModel(createModel({ provider: 'zhipu' }))).toBe(true)
})
it('returns false for non-Zhipu models', () => {
expect(isZhipuModel(createModel({ provider: 'openai' }))).toBe(false)
})
})
describe('isAnthropicModel', () => {
it('detects Anthropic models', () => {
expect(isAnthropicModel(createModel({ id: 'claude-3.5' }))).toBe(true)
})
})
describe('isQwenMTModel', () => {
it('detects Qwen MT models', () => {
expect(isQwenMTModel(createModel({ id: 'qwen-mt-plus' }))).toBe(true)
})
})
describe('isOpenAIOpenWeightModel', () => {
it('detects OpenAI open weight models', () => {
expect(isOpenAIOpenWeightModel(createModel({ id: 'gpt-oss-free' }))).toBe(true)
})
})
})
describe('System message support', () => {
describe('isNotSupportSystemMessageModel', () => {
it('returns true for models that do not support system messages', () => {
expect(isNotSupportSystemMessageModel(createModel({ id: 'gemma-moe' }))).toBe(true)
})
})
})
describe('Model grouping', () => {
describe('groupQwenModels', () => {
it('groups qwen models by prefix', () => {
const qwen = createModel({ id: 'Qwen-7B', provider: 'qwen', name: 'Qwen-7B' })
const qwenOmni = createModel({ id: 'qwen2.5-omni', name: 'qwen2.5-omni' })
const other = createModel({ id: 'deepseek-v3', group: 'DeepSeek' })
const grouped = groupQwenModels([qwen, qwenOmni, other])
expect(Object.keys(grouped)).toContain('qwen-7b')
expect(Object.keys(grouped)).toContain('qwen2.5')
expect(grouped.DeepSeek).toContain(other)
})
})
})
describe('Vision and image generation', () => {
describe('isVisionModels', () => {
it('returns true when all models support vision', () => {
const models = [createModel({ id: 'gpt-4o' }), createModel({ id: 'gpt-4o-mini' })]
expect(isVisionModels(models)).toBe(true)
})
it('returns false when some models do not support vision', () => {
const models = [createModel({ id: 'gpt-4o' }), createModel({ id: 'gpt-4o-mini' })]
visionMock.mockReturnValueOnce(true).mockReturnValueOnce(false)
expect(isVisionModels(models)).toBe(false)
})
})
describe('isGenerateImageModels', () => {
it('returns true when all models support image generation', () => {
const models = [createModel({ id: 'gpt-4o' }), createModel({ id: 'gpt-4o-mini' })]
generateImageMock.mockReturnValue(true)
expect(isGenerateImageModels(models)).toBe(true)
})
it('returns false when some models do not support image generation', () => {
const models = [createModel({ id: 'gpt-4o' }), createModel({ id: 'gpt-4o-mini' })]
generateImageMock.mockReturnValueOnce(true).mockReturnValueOnce(false)
expect(isGenerateImageModels(models)).toBe(false)
})
})
})
describe('Model filtering', () => {
describe('isSupportedModel', () => {
it('filters supported OpenAI catalog entries', () => {
expect(isSupportedModel({ id: 'gpt-4', object: 'model' } as any)).toBe(true)
})
it('filters unsupported OpenAI catalog entries', () => {
expect(isSupportedModel({ id: 'tts-1', object: 'model' } as any)).toBe(false)
})
})
describe('agentModelFilter', () => {
it('returns true for regular models', () => {
expect(agentModelFilter(createModel())).toBe(true)
})
it('filters out embedding models', () => {
embeddingMock.mockReturnValueOnce(true)
expect(agentModelFilter(createModel({ id: 'text-embedding' }))).toBe(false)
})
it('filters out rerank models', () => {
embeddingMock.mockReturnValue(false)
rerankMock.mockReturnValueOnce(true)
expect(agentModelFilter(createModel({ id: 'rerank' }))).toBe(false)
})
it('filters out non-function-call models', () => {
rerankMock.mockReturnValue(false)
isFunctionCallingModelMock.mockReturnValueOnce(false)
expect(agentModelFilter(createModel({ id: 'DeepSeek R1' }))).toBe(false)
})
it('filters out text-to-image models', () => {
rerankMock.mockReturnValue(false)
textToImageMock.mockReturnValueOnce(true)
expect(agentModelFilter(createModel({ id: 'gpt-image-1' }))).toBe(false)
})
})
textToImageMock.mockReturnValue(false)
generateImageMock.mockReturnValueOnce(true)
expect(agentModelFilter(createModel({ id: 'dall-e-3' }))).toBe(false)
})
describe('Temperature limits', () => {
describe('isMaxTemperatureOneModel', () => {
it('returns true for Zhipu models', () => {
expect(isMaxTemperatureOneModel(createModel({ id: 'glm-4' }))).toBe(true)
expect(isMaxTemperatureOneModel(createModel({ id: 'GLM-4-Plus' }))).toBe(true)
expect(isMaxTemperatureOneModel(createModel({ id: 'glm-3-turbo' }))).toBe(true)
})
it('returns true for Anthropic models', () => {
expect(isMaxTemperatureOneModel(createModel({ id: 'claude-3.5-sonnet' }))).toBe(true)
expect(isMaxTemperatureOneModel(createModel({ id: 'Claude-3-opus' }))).toBe(true)
expect(isMaxTemperatureOneModel(createModel({ id: 'claude-2.1' }))).toBe(true)
})
it('returns true for Moonshot models', () => {
expect(isMaxTemperatureOneModel(createModel({ id: 'moonshot-1.0' }))).toBe(true)
expect(isMaxTemperatureOneModel(createModel({ id: 'kimi-k2-thinking' }))).toBe(true)
expect(isMaxTemperatureOneModel(createModel({ id: 'Moonshot-Pro' }))).toBe(true)
})
it('returns false for other models', () => {
expect(isMaxTemperatureOneModel(createModel({ id: 'gpt-4o' }))).toBe(false)
expect(isMaxTemperatureOneModel(createModel({ id: 'gpt-4-turbo' }))).toBe(false)
expect(isMaxTemperatureOneModel(createModel({ id: 'qwen-max' }))).toBe(false)
expect(isMaxTemperatureOneModel(createModel({ id: 'gemini-pro' }))).toBe(false)
})
})
// Other models should return false
expect(isMaxTemperatureOneModel(createModel({ id: 'gpt-4o' }))).toBe(false)
expect(isMaxTemperatureOneModel(createModel({ id: 'gpt-4-turbo' }))).toBe(false)
expect(isMaxTemperatureOneModel(createModel({ id: 'qwen-max' }))).toBe(false)
expect(isMaxTemperatureOneModel(createModel({ id: 'gemini-pro' }))).toBe(false)
})
})


@@ -396,11 +396,7 @@ export function isClaude45ReasoningModel(model: Model): boolean {
export function isClaude4SeriesModel(model: Model): boolean {
const modelId = getLowerBaseModelName(model.id, '/')
// Supports various formats including:
// - Direct API: claude-sonnet-4, claude-opus-4-20250514
// - GCP Vertex AI: claude-sonnet-4@20250514
// - AWS Bedrock: anthropic.claude-sonnet-4-20250514-v1:0
const regex = /claude-(sonnet|opus|haiku)-4(?:[.-]\d+)?(?:[@\-:][\w\-:]+)?$/i
const regex = /claude-(sonnet|opus|haiku)-4(?:[.-]\d+)?(?:-[\w-]+)?$/i
return regex.test(modelId)
}
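The broadened regex above can be exercised directly against the id formats named in its comments (direct API, GCP Vertex AI, AWS Bedrock):

```ts
// The regex from the diff, tested against the documented id formats.
const claude4Regex = /claude-(sonnet|opus|haiku)-4(?:[.-]\d+)?(?:[@\-:][\w\-:]+)?$/i

for (const id of [
  'claude-opus-4', // direct API -> true
  'claude-sonnet-4@20250514', // GCP Vertex AI -> true
  'anthropic.claude-sonnet-4-20250514-v1:0', // AWS Bedrock -> true
  'claude-4.2-sonnet-variant', // -> false
  'claude-3-haiku' // -> false
]) {
  console.log(id, claude4Regex.test(id))
}
```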
@@ -460,19 +456,16 @@ export const isSupportedThinkingTokenZhipuModel = (model: Model): boolean => {
}
export const isDeepSeekHybridInferenceModel = (model: Model) => {
const { idResult, nameResult } = withModelIdAndNameAsId(model, (model) => {
const modelId = getLowerBaseModelName(model.id)
// DeepSeek's official API uses separate chat and reasoner models for reasoning control; other providers need their own checks since ids may differ
// openrouter: deepseek/deepseek-chat-v3.1. Other providers might copy DeepSeek's official split and expose a same-id model as the non-thinking variant, so this check carries some risk
// Matches: "deepseek-v3" followed by ".digit" or "-digit".
// Optionally, this can be followed by ".alphanumeric_sequence" or "-alphanumeric_sequence"
// until the end of the string.
// Examples: deepseek-v3.1, deepseek-v3-1, deepseek-v3.1.2, deepseek-v3.1-alpha
// Does NOT match: deepseek-v3.123 (missing separator after '1'), deepseek-v3.x (x isn't a digit)
// TODO: move to utils and add test cases
return /deepseek-v3(?:\.\d|-\d)(?:(\.|-)\w+)?$/.test(modelId) || modelId.includes('deepseek-chat-v3.1')
})
return idResult || nameResult
const modelId = getLowerBaseModelName(model.id)
// DeepSeek's official API uses separate chat and reasoner models for reasoning control; other providers need their own checks since ids may differ
// openrouter: deepseek/deepseek-chat-v3.1. Other providers might copy DeepSeek's official split and expose a same-id model as the non-thinking variant, so this check carries some risk
// Matches: "deepseek-v3" followed by ".digit" or "-digit".
// Optionally, this can be followed by ".alphanumeric_sequence" or "-alphanumeric_sequence"
// until the end of the string.
// Examples: deepseek-v3.1, deepseek-v3-1, deepseek-v3.1.2, deepseek-v3.1-alpha
// Does NOT match: deepseek-v3.123 (missing separator after '1'), deepseek-v3.x (x isn't a digit)
// TODO: move to utils and add test cases
return /deepseek-v3(?:\.\d|-\d)(?:(\.|-)\w+)?$/.test(modelId) || modelId.includes('deepseek-chat-v3.1')
}
export const isLingReasoningModel = (model?: Model): boolean => {
@@ -526,6 +519,7 @@ export function isReasoningModel(model?: Model): boolean {
REASONING_REGEX.test(model.name) ||
isSupportedThinkingTokenDoubaoModel(model) ||
isDeepSeekHybridInferenceModel(model) ||
isDeepSeekHybridInferenceModel({ ...model, id: model.name }) ||
false
)
}
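The two hunks above make the DeepSeek hybrid-inference check consult both model.id and model.name; a sketch of that pattern, assuming withModelIdAndNameAsId simply re-runs the predicate with the name substituted for the id:

```ts
// Sketch: run the hybrid-inference check on both id and name, since some
// providers expose an opaque id but a descriptive display name.
type ModelLike = { id: string; name: string }

const deepseekHybridRegex = /deepseek-v3(?:\.\d|-\d)(?:(\.|-)\w+)?$/

function isDeepSeekHybrid(model: ModelLike): boolean {
  const check = (id: string) =>
    deepseekHybridRegex.test(id.toLowerCase()) || id.toLowerCase().includes('deepseek-chat-v3.1')
  return check(model.id) || check(model.name)
}

console.log(isDeepSeekHybrid({ id: 'ep-20250101-abc', name: 'deepseek-v3.1' })) // true (via name)
console.log(isDeepSeekHybrid({ id: 'deepseek-r1', name: 'DeepSeek R1' })) // false
```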


@@ -1,8 +1,6 @@
import { getProviderByModel } from '@renderer/services/AssistantService'
import type { Model } from '@renderer/types'
import { isSystemProviderId } from '@renderer/types'
import { getLowerBaseModelName, isUserSelectedModelType } from '@renderer/utils'
import { isAzureOpenAIProvider } from '@shared/provider'
import { isEmbeddingModel, isRerankModel } from './embedding'
import { isDeepSeekHybridInferenceModel } from './reasoning'
@@ -54,13 +52,6 @@ export const FUNCTION_CALLING_REGEX = new RegExp(
'i'
)
const AZURE_FUNCTION_CALLING_EXCLUDED_MODELS = [
'(?:Meta-)?Llama-3(?:\\.\\d+)?-[\\w-]+',
'Phi-[34](?:\\.[\\w-]+)?(?:-[\\w-]+)?',
'DeepSeek-(?:R1|V3)',
'Codestral-2501'
]
export function isFunctionCallingModel(model?: Model): boolean {
if (!model || isEmbeddingModel(model) || isRerankModel(model) || isTextToImageModel(model)) {
return false
@@ -76,15 +67,6 @@ export function isFunctionCallingModel(model?: Model): boolean {
return FUNCTION_CALLING_REGEX.test(modelId) || FUNCTION_CALLING_REGEX.test(model.name)
}
const provider = getProviderByModel(model)
if (isAzureOpenAIProvider(provider)) {
const azureExcludedRegex = new RegExp(`\\b(?:${AZURE_FUNCTION_CALLING_EXCLUDED_MODELS.join('|')})\\b`, 'i')
if (azureExcludedRegex.test(modelId)) {
return false
}
}
if (['deepseek', 'anthropic', 'kimi', 'moonshot'].includes(model.provider)) {
return true
}
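The removed Azure exclusion list above compiles into a single alternation regex; reproducing it verbatim shows how matching ids get function calling disabled:

```ts
// The exclusion list and regex construction from the removed block.
const AZURE_FUNCTION_CALLING_EXCLUDED_MODELS = [
  '(?:Meta-)?Llama-3(?:\\.\\d+)?-[\\w-]+',
  'Phi-[34](?:\\.[\\w-]+)?(?:-[\\w-]+)?',
  'DeepSeek-(?:R1|V3)',
  'Codestral-2501'
]

const azureExcludedRegex = new RegExp(`\\b(?:${AZURE_FUNCTION_CALLING_EXCLUDED_MODELS.join('|')})\\b`, 'i')

console.log(azureExcludedRegex.test('meta-llama-3.1-70b-instruct')) // true  -> excluded
console.log(azureExcludedRegex.test('gpt-4o')) // false -> allowed
```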


@@ -1,20 +1,11 @@
import type OpenAI from '@cherrystudio/openai'
import { isEmbeddingModel, isRerankModel } from '@renderer/config/models/embedding'
import { getProviderByModel } from '@renderer/services/AssistantService'
import { type Model, SystemProviderIds } from '@renderer/types'
import type { OpenAIVerbosity, ValidOpenAIVerbosity } from '@renderer/types/aiCoreTypes'
import { getLowerBaseModelName } from '@renderer/utils'
import {
isGPT5ProModel,
isGPT5SeriesModel,
isGPT51SeriesModel,
isOpenAIChatCompletionOnlyModel,
isOpenAIOpenWeightModel,
isOpenAIReasoningModel
} from './openai'
import { isOpenAIChatCompletionOnlyModel, isOpenAIOpenWeightModel, isOpenAIReasoningModel } from './openai'
import { isQwenMTModel } from './qwen'
import { isFunctionCallingModel } from './tooluse'
import { isGenerateImageModel, isTextToImageModel, isVisionModel } from './vision'
export const NOT_SUPPORTED_REGEX = /(?:^tts|whisper|speech)/i
export const GEMINI_FLASH_MODEL_REGEX = new RegExp('gemini.*-flash.*$', 'i')
@@ -132,46 +123,21 @@ export const isNotSupportSystemMessageModel = (model: Model): boolean => {
return isQwenMTModel(model) || isGemmaModel(model)
}
// Verbosity settings is only supported by GPT-5 and newer models
// Specifically, GPT-5 and GPT-5.1 for now
// GPT-5 verbosity configuration
// gpt-5-pro only supports 'high', other GPT-5 models support all levels
const MODEL_SUPPORTED_VERBOSITY: readonly {
readonly validator: (model: Model) => boolean
readonly values: readonly ValidOpenAIVerbosity[]
}[] = [
// gpt-5-pro
{ validator: isGPT5ProModel, values: ['high'] },
// gpt-5 except gpt-5-pro
{
validator: (model: Model) => isGPT5SeriesModel(model) && !isGPT5ProModel(model),
values: ['low', 'medium', 'high']
},
// gpt-5.1
{ validator: isGPT51SeriesModel, values: ['low', 'medium', 'high'] }
]
export const MODEL_SUPPORTED_VERBOSITY: Record<string, ValidOpenAIVerbosity[]> = {
'gpt-5-pro': ['high'],
default: ['low', 'medium', 'high']
} as const
/**
* Returns the list of supported verbosity levels for the given model.
* If the model is not recognized as a GPT-5 series model, only `undefined` is returned.
* For GPT-5-pro, only 'high' is supported; for other GPT-5 models, 'low', 'medium', and 'high' are supported.
* For GPT-5.1 series models, 'low', 'medium', and 'high' are supported.
* @param model - The model to check
* @returns An array of supported verbosity levels, always including `undefined` as the first element
*/
export const getModelSupportedVerbosity = (model: Model | undefined | null): OpenAIVerbosity[] => {
if (!model) {
return [undefined]
export const getModelSupportedVerbosity = (model: Model): OpenAIVerbosity[] => {
const modelId = getLowerBaseModelName(model.id)
let supportedValues: ValidOpenAIVerbosity[]
if (modelId.includes('gpt-5-pro')) {
supportedValues = MODEL_SUPPORTED_VERBOSITY['gpt-5-pro']
} else {
supportedValues = MODEL_SUPPORTED_VERBOSITY.default
}
let supportedValues: ValidOpenAIVerbosity[] = []
for (const { validator, values } of MODEL_SUPPORTED_VERBOSITY) {
if (validator(model)) {
supportedValues = [...values]
break
}
}
return [undefined, ...supportedValues]
}
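A usage sketch of the validator-list lookup above; the predicates here are simplified regex stand-ins for isGPT5ProModel / isGPT5SeriesModel, not the real detectors:

```ts
// Sketch of the validator-list lookup; predicates are simplified stand-ins.
type Verbosity = 'low' | 'medium' | 'high'
type ModelLike = { id: string }

const verbosityRules: { validator: (m: ModelLike) => boolean; values: readonly Verbosity[] }[] = [
  { validator: (m) => /gpt-5-pro/i.test(m.id), values: ['high'] },
  { validator: (m) => /gpt-5(?!-pro)/i.test(m.id), values: ['low', 'medium', 'high'] }
]

function supportedVerbosity(model: ModelLike | null | undefined): (Verbosity | undefined)[] {
  if (!model) return [undefined]
  const match = verbosityRules.find((rule) => rule.validator(model))
  return [undefined, ...(match ? match.values : [])]
}

console.log(supportedVerbosity({ id: 'gpt-5-pro' })) // [ undefined, 'high' ]
console.log(supportedVerbosity({ id: 'gpt-5-preview' })) // [ undefined, 'low', 'medium', 'high' ]
console.log(supportedVerbosity({ id: 'gpt-4o' })) // [ undefined ]
```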
@@ -183,21 +149,8 @@ export const isGeminiModel = (model: Model) => {
// Zhipu vision reasoning models wrap reasoning results in this pair of special tokens
export const ZHIPU_RESULT_TOKENS = ['<|begin_of_box|>', '<|end_of_box|>'] as const
// TODO: support prompt-based tool calling
export const agentModelFilter = (model: Model): boolean => {
const provider = getProviderByModel(model)
// Needs adaptation and easily exceeds rate limits
if (provider.id === SystemProviderIds.copilot) {
return false
}
return (
!isEmbeddingModel(model) &&
!isRerankModel(model) &&
!isTextToImageModel(model) &&
!isGenerateImageModel(model) &&
isFunctionCallingModel(model)
)
return !isEmbeddingModel(model) && !isRerankModel(model) && !isTextToImageModel(model)
}
export const isMaxTemperatureOneModel = (model: Model): boolean => {
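The stricter agentModelFilter removed above also required tool calling and excluded image-generation and Copilot models; a sketch with the capability predicates stubbed as plain booleans:

```ts
// Sketch of the stricter agent filter; `caps` stands in for the real
// isEmbeddingModel / isRerankModel / ... predicates.
type Caps = {
  embedding: boolean
  rerank: boolean
  textToImage: boolean
  generateImage: boolean
  functionCalling: boolean
  isCopilotProvider: boolean
}

function agentModelFilterSketch(caps: Caps): boolean {
  // Copilot needs extra adaptation and hits rate limits quickly.
  if (caps.isCopilotProvider) return false
  return !caps.embedding && !caps.rerank && !caps.textToImage && !caps.generateImage && caps.functionCalling
}

console.log(
  agentModelFilterSketch({
    embedding: false,
    rerank: false,
    textToImage: false,
    generateImage: false,
    functionCalling: true,
    isCopilotProvider: false
  })
) // true
```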

Some files were not shown because too many files have changed in this diff.