Compare commits
68 Commits
feat/acces
...
feat/proxy
| Author | SHA1 | Date | |
|---|---|---|---|
| | 4d77202afd | | |
| | 8c9d79a7d4 | | |
| | fb9a8e7e2c | | |
| | 14cc38a626 | | |
| | a231952969 | | |
| | 874d69291f | | |
| | 4a913fcef7 | | |
| | b3a58ec321 | | |
| | 0097ca80e2 | | |
| | 4c1466cd27 | | |
| | d968df4612 | | |
| | 2bd680361a | | |
| | cc676d4bef | | |
| | 3b1155b538 | | |
| | 03ff6e1ca6 | | |
| | 350519ac0a | | |
| | 706fac898a | | |
| | f5c144404d | | |
| | 50a217a638 | | |
| | 444c13e1e3 | | |
| | 255b19d6ee | | |
| | 3989229f61 | | |
| | c6c7c240a3 | | |
| | f1f4831157 | | |
| | 35cfc7c517 | | |
| | e255a992cc | | |
| | 876f59d650 | | |
| | c23e88ecd1 | | |
| | 284d0f99e1 | | |
| | 13ac5d564a | | |
| | e8dccf51fe | | |
| | 4620b71aee | | |
| | ed769ac4f7 | | |
| | 1b926178f1 | | |
| | 5167c927be | | |
| | 95c18d192a | | |
| | 534d27f37e | | |
| | 313be4427b | | |
| | 9d34098a53 | | |
| | d367040fd4 | | |
| | b18c64b725 | | |
| | 7ce1590eaf | | |
| | 356e828422 | | |
| | ce25001590 | | |
| | 77a9504f74 | | |
| | 77c1b77113 | | |
| | bf35902696 | | |
| | f163c4d3ee | | |
| | 0d12b5fbc2 | | |
| | 0f6ec3e061 | | |
| | 1746e8b21f | | |
| | 5d1d2b7a9b | | |
| | 15c0a3881c | | |
| | dad9cc95ad | | |
| | f02c0fe962 | | |
| | 4c4102da20 | | |
| | 2a1adfe322 | | |
| | 36ed062b84 | | |
| | f225fbe3e3 | | |
| | 0836eef1a6 | | |
| | d0bd10190d | | |
| | ccfb9423e0 | | |
| | 192357a32e | | |
| | d8191bd4fb | | |
| | a5e7aa1342 | | |
| | d15571c727 | | |
| | a2f67dddb6 | | |
| | 8f00321a60 | | |
2
.gitignore
vendored
@@ -73,3 +73,5 @@ test-results
 YOUR_MEMORY_FILE_PATH

 .sessions/
+.next/
+*.tsbuildinfo
@@ -11,6 +11,7 @@
 "dist/**",
 "out/**",
 "local/**",
+"tests/**",
 ".yarn/**",
 ".gitignore",
 "scripts/cloudflare-worker.js",
10
CLAUDE.md
@@ -12,7 +12,15 @@ This file provides guidance to AI coding assistants when working with code in th
 - **Always propose before executing**: Before making any changes, clearly explain your planned approach and wait for explicit user approval to ensure alignment and prevent unwanted modifications.
 - **Lint, test, and format before completion**: Coding tasks are only complete after running `yarn lint`, `yarn test`, and `yarn format` successfully.
 - **Write conventional commits**: Commit small, focused changes using Conventional Commit messages (e.g., `feat:`, `fix:`, `refactor:`, `docs:`).
-- **Follow PR template**: When submitting pull requests, follow the template in `.github/pull_request_template.md` to ensure complete context and documentation.
+
+## Pull Request Workflow (CRITICAL)
+
+When creating a Pull Request, you MUST:
+
+1. **Read the PR template first**: Always read `.github/pull_request_template.md` before creating the PR
+2. **Follow ALL template sections**: Structure the `--body` parameter to include every section from the template
+3. **Never skip sections**: Include all sections even if marking them as N/A or "None"
+4. **Use proper formatting**: Match the template's markdown structure exactly (headings, checkboxes, code blocks)

 ## Development Commands
@@ -134,56 +134,108 @@ artifactBuildCompleted: scripts/artifact-build-completed.js
 releaseInfo:
   releaseNotes: |
     <!--LANG:en-->
-    What's New in v1.7.0-rc.3
+    A New Era of Intelligence with Cherry Studio 1.7.1

-    ✨ New Features:
-    - Provider: Added Silicon provider support for Anthropic API compatibility
-    - Provider: AIHubMix support for nano banana
+    Today we're releasing Cherry Studio 1.7.1 — our most ambitious update yet, introducing Agent: autonomous AI that thinks, plans, and acts.

-    🐛 Bug Fixes:
-    - i18n: Clean up translation tags and untranslated strings
-    - Provider: Fixed Silicon provider code list
-    - Provider: Fixed Poe API reasoning parameters for GPT-5 and reasoning models
-    - Provider: Fixed duplicate /v1 in Anthropic API endpoints
-    - Provider: Fixed Azure provider handling in AI SDK integration
-    - Models: Added Claude Opus 4.5 pattern to THINKING_TOKEN_MAP
-    - Models: Improved Gemini reasoning and message handling
-    - Models: Fixed custom parameters for Gemini models
-    - Models: Fixed qwen-mt-flash text delta support
-    - Models: Fixed Groq verbosity setting
-    - UI: Fixed quota display and quota tips
-    - UI: Fixed web search button condition
-    - Settings: Fixed updateAssistantPreset reducer to properly update preset
-    - Settings: Respect enableMaxTokens setting when maxTokens is not configured
-    - SDK: Fixed header merging logic in AI SDK
+    For years, AI assistants have been reactive — waiting for your commands, responding to your questions. With Agent, we're changing that. Now, AI can truly work alongside you: understanding complex goals, breaking them into steps, and executing them independently.

-    ⚡ Improvements:
-    - SDK: Upgraded @anthropic-ai/claude-agent-sdk to 0.1.53
+    This is what we've been building toward. And it's just the beginning.
+
+    🤖 Meet Agent
+    Imagine having a brilliant colleague who never sleeps. Give Agent a goal — write a report, analyze data, refactor code — and watch it work. It reasons through problems, breaks them into steps, calls the right tools, and adapts when things change.
+
+    - **Think → Plan → Act**: From goal to execution, fully autonomous
+    - **Deep Reasoning**: Multi-turn thinking that solves real problems
+    - **Tool Mastery**: File operations, web search, code execution, and more
+    - **Skill Plugins**: Extend with custom commands and capabilities
+    - **You Stay in Control**: Real-time approval for sensitive actions
+    - **Full Visibility**: Every thought, every decision, fully transparent
+
+    🌐 Expanding Ecosystem
+    - **New Providers**: HuggingFace, Mistral, CherryIN, AI Gateway, Intel OVMS, Didi MCP
+    - **New Models**: Claude 4.5 Haiku, DeepSeek v3.2, GLM-4.6, Doubao, Ling series
+    - **MCP Integration**: Alibaba Cloud, ModelScope, Higress, MCP.so, TokenFlux and more
+
+    📚 Smarter Knowledge Base
+    - **OpenMinerU**: Self-hosted document processing
+    - **Full-Text Search**: Find anything instantly across your notes
+    - **Enhanced Tool Selection**: Smarter configuration for better AI assistance
+
+    📝 Notes, Reimagined
+    - Full-text search with highlighted results
+    - AI-powered smart rename
+    - Export as image
+    - Auto-wrap for tables
+
+    🖼️ Image & OCR
+    - Intel OVMS painting capabilities
+    - Intel OpenVINO NPU-accelerated OCR
+
+    🌍 Now in 10+ Languages
+    - Added German support
+    - Enhanced internationalization
+
+    ⚡ Faster & More Polished
+    - Electron 38 upgrade
+    - New MCP management interface
+    - Dozens of UI refinements
+
+    ❤️ Fully Open Source
+    Commercial restrictions removed. Cherry Studio now follows standard AGPL v3 — free for teams of any size.
+
+    The Agent Era is here. We can't wait to see what you'll create.

     <!--LANG:zh-CN-->
-    v1.7.0-rc.3 更新内容
+    Cherry Studio 1.7.1:开启智能新纪元

-    ✨ 新功能:
-    - 提供商:新增 Silicon 提供商对 Anthropic API 的兼容性支持
-    - 提供商:AIHubMix 支持 nano banana
+    今天,我们正式发布 Cherry Studio 1.7.1 —— 迄今最具雄心的版本,带来全新的 Agent:能够自主思考、规划和行动的 AI。

-    🐛 问题修复:
-    - 国际化:清理翻译标签和未翻译字符串
-    - 提供商:修复 Silicon 提供商代码列表
-    - 提供商:修复 Poe API 对 GPT-5 和推理模型的推理参数
-    - 提供商:修复 Anthropic API 端点重复 /v1 问题
-    - 提供商:修复 Azure 提供商在 AI SDK 集成中的处理
-    - 模型:Claude Opus 4.5 添加到 THINKING_TOKEN_MAP
-    - 模型:改进 Gemini 推理和消息处理
-    - 模型:修复 Gemini 模型自定义参数
-    - 模型:修复 qwen-mt-flash text delta 支持
-    - 模型:修复 Groq verbosity 设置
-    - 界面:修复配额显示和配额提示
-    - 界面:修复 Web 搜索按钮条件
-    - 设置:修复 updateAssistantPreset reducer 正确更新 preset
-    - 设置:尊重 enableMaxTokens 设置
-    - SDK:修复 AI SDK 中 header 合并逻辑
+    多年来,AI 助手一直是被动的——等待你的指令,回应你的问题。Agent 改变了这一切。现在,AI 能够真正与你并肩工作:理解复杂目标,将其拆解为步骤,并独立执行。

-    ⚡ 改进:
-    - SDK:升级 @anthropic-ai/claude-agent-sdk 到 0.1.53
+    这是我们一直在构建的未来。而这,仅仅是开始。
+
+    🤖 认识 Agent
+    想象一位永不疲倦的得力伙伴。给 Agent 一个目标——撰写报告、分析数据、重构代码——然后看它工作。它会推理问题、拆解步骤、调用工具,并在情况变化时灵活应对。
+
+    - **思考 → 规划 → 行动**:从目标到执行,全程自主
+    - **深度推理**:多轮思考,解决真实问题
+    - **工具大师**:文件操作、网络搜索、代码执行,样样精通
+    - **技能插件**:自定义命令,无限扩展
+    - **你掌控全局**:敏感操作,实时审批
+    - **完全透明**:每一步思考,每一个决策,清晰可见
+
+    🌐 生态持续壮大
+    - **新增服务商**:Hugging Face、Mistral、Perplexity、SophNet、AI Gateway、Cerebras AI
+    - **新增模型**:Gemini 3、Gemini 3 Pro(支持图像预览)、GPT-5.1、Claude Opus 4.5
+    - **MCP 集成**:百炼、魔搭、Higress、MCP.so、TokenFlux 等平台
+
+    📚 更智能的知识库
+    - **OpenMinerU**:本地自部署文档处理
+    - **全文搜索**:笔记内容一搜即达
+    - **增强工具选择**:更智能的配置,更好的 AI 协助
+
+    📝 笔记,焕然一新
+    - 全文搜索,结果高亮
+    - AI 智能重命名
+    - 导出为图片
+    - 表格自动换行
+
+    🖼️ 图像与 OCR
+    - Intel OVMS 绘图能力
+    - Intel OpenVINO NPU 加速 OCR
+
+    🌍 支持 10+ 种语言
+    - 新增德语支持
+    - 全面增强国际化
+
+    ⚡ 更快、更精致
+    - 升级 Electron 38
+    - 新的 MCP 管理界面
+    - 数十处 UI 细节打磨
+
+    ❤️ 完全开源
+    商用限制已移除。Cherry Studio 现遵循标准 AGPL v3 协议——任意规模团队均可自由使用。
+
+    Agent 纪元已至。期待你的创造。
     <!--LANG:END-->
@@ -25,7 +25,10 @@ export default defineConfig({
     '@shared': resolve('packages/shared'),
     '@logger': resolve('src/main/services/LoggerService'),
     '@mcp-trace/trace-core': resolve('packages/mcp-trace/trace-core'),
-    '@mcp-trace/trace-node': resolve('packages/mcp-trace/trace-node')
+    '@mcp-trace/trace-node': resolve('packages/mcp-trace/trace-node'),
+    '@cherrystudio/ai-core/provider': resolve('packages/aiCore/src/core/providers'),
+    '@cherrystudio/ai-core': resolve('packages/aiCore/src'),
+    '@cherrystudio/ai-sdk-provider': resolve('packages/ai-sdk-provider/src')
   }
 },
 build: {
@@ -58,6 +58,7 @@ export default defineConfig([
 'dist/**',
 'out/**',
 'local/**',
+'tests/**',
 '.yarn/**',
 '.gitignore',
 'scripts/cloudflare-worker.js',
@@ -1,6 +1,6 @@
 {
   "name": "CherryStudio",
-  "version": "1.7.0-rc.3",
+  "version": "1.7.1",
   "private": true,
   "description": "A powerful AI assistant for producer.",
   "main": "./out/main/index.js",
@@ -62,6 +62,7 @@
 "test": "vitest run --silent",
 "test:main": "vitest run --project main",
 "test:renderer": "vitest run --project renderer",
+"test:aicore": "vitest run --project aiCore",
 "test:update": "yarn test:renderer --update",
 "test:coverage": "vitest run --coverage --silent",
 "test:ui": "vitest --ui",
@@ -164,7 +165,7 @@
 "@modelcontextprotocol/sdk": "^1.17.5",
 "@mozilla/readability": "^0.6.0",
 "@notionhq/client": "^2.2.15",
-"@openrouter/ai-sdk-provider": "^1.2.5",
+"@openrouter/ai-sdk-provider": "^1.2.8",
 "@opentelemetry/api": "^1.9.0",
 "@opentelemetry/core": "2.0.0",
 "@opentelemetry/exporter-trace-otlp-http": "^0.200.0",
@@ -172,7 +173,7 @@
 "@opentelemetry/sdk-trace-node": "^2.0.0",
 "@opentelemetry/sdk-trace-web": "^2.0.0",
 "@opeoginni/github-copilot-openai-compatible": "^0.1.21",
-"@playwright/test": "^1.52.0",
+"@playwright/test": "^1.55.1",
 "@radix-ui/react-context-menu": "^2.2.16",
 "@reduxjs/toolkit": "^2.2.5",
 "@shikijs/markdown-it": "^3.12.0",
@@ -321,7 +322,6 @@
 "p-queue": "^8.1.0",
 "pdf-lib": "^1.17.1",
 "pdf-parse": "^1.1.1",
-"playwright": "^1.55.1",
 "proxy-agent": "^6.5.0",
 "react": "^19.2.0",
 "react-dom": "^19.2.0",
@@ -69,6 +69,7 @@ export interface CherryInProviderSettings {
   headers?: HeadersInput
   /**
    * Optional endpoint type to distinguish different endpoint behaviors.
+   * "image-generation" is also openai endpoint, but specifically for image generation.
    */
   endpointType?: 'openai' | 'openai-response' | 'anthropic' | 'gemini' | 'image-generation' | 'jina-rerank'
 }
@@ -3,12 +3,13 @@
  * Provides realistic mock responses for all provider types
  */

-import { jsonSchema, type ModelMessage, type Tool } from 'ai'
+import type { ModelMessage, Tool } from 'ai'
+import { jsonSchema } from 'ai'

 /**
  * Standard test messages for all scenarios
  */
-export const testMessages = {
+export const testMessages: Record<string, ModelMessage[]> = {
   simple: [{ role: 'user' as const, content: 'Hello, how are you?' }],

   conversation: [
@@ -45,7 +46,7 @@ export const testMessages = {
     { role: 'assistant' as const, content: '15 * 23 = 345' },
     { role: 'user' as const, content: 'Now divide that by 5' }
   ]
-} satisfies Record<string, ModelMessage[]>
+}

 /**
  * Standard test tools for tool calling scenarios
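Review note: the hunk above replaces a trailing `satisfies` check on `testMessages` with an explicit type annotation. The sketch below illustrates the practical difference between the two forms. It uses a simplified stand-in `Message` interface rather than the real `ModelMessage` type from `ai`, which is an assumption made only to keep the snippet self-contained.

```typescript
// Simplified stand-in for ModelMessage from 'ai' (assumption for illustration only).
interface Message {
  role: 'user' | 'assistant'
  content: string
}

// Explicit annotation: the variable's type is the index signature itself,
// so any string key type-checks and resolves to Message[].
const annotated: Record<string, Message[]> = {
  simple: [{ role: 'user', content: 'Hello, how are you?' }]
}

// `satisfies`: the object is validated against the same shape,
// but the inferred type keeps the literal key set.
const checked = {
  simple: [{ role: 'user', content: 'Hello, how are you?' }]
} satisfies Record<string, Message[]>

// Only 'simple' is a known key of `checked`.
const knownKeys: Array<keyof typeof checked> = ['simple']

// With the annotation, looking up an unknown key compiles and is undefined at runtime.
const missing = annotated['missing']

console.log(knownKeys, missing === undefined)
```

With the annotated form, callers can index the fixture by any string, at the cost of losing compile-time detection of misspelled keys.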
@@ -138,68 +139,17 @@ export const testTools: Record<string, Tool> = {
   }
 }

-/**
- * Mock streaming chunks for different providers
- */
-export const mockStreamingChunks = {
-  text: [
-    { type: 'text-delta' as const, textDelta: 'Hello' },
-    { type: 'text-delta' as const, textDelta: ', ' },
-    { type: 'text-delta' as const, textDelta: 'this ' },
-    { type: 'text-delta' as const, textDelta: 'is ' },
-    { type: 'text-delta' as const, textDelta: 'a ' },
-    { type: 'text-delta' as const, textDelta: 'test.' }
-  ],
-
-  withToolCall: [
-    { type: 'text-delta' as const, textDelta: 'Let me check the weather for you.' },
-    {
-      type: 'tool-call-delta' as const,
-      toolCallType: 'function' as const,
-      toolCallId: 'call_123',
-      toolName: 'getWeather',
-      argsTextDelta: '{"location":'
-    },
-    {
-      type: 'tool-call-delta' as const,
-      toolCallType: 'function' as const,
-      toolCallId: 'call_123',
-      toolName: 'getWeather',
-      argsTextDelta: ' "San Francisco, CA"}'
-    },
-    {
-      type: 'tool-call' as const,
-      toolCallType: 'function' as const,
-      toolCallId: 'call_123',
-      toolName: 'getWeather',
-      args: { location: 'San Francisco, CA' }
-    }
-  ],
-
-  withFinish: [
-    { type: 'text-delta' as const, textDelta: 'Complete response.' },
-    {
-      type: 'finish' as const,
-      finishReason: 'stop' as const,
-      usage: {
-        promptTokens: 10,
-        completionTokens: 5,
-        totalTokens: 15
-      }
-    }
-  ]
-}
-
 /**
  * Mock complete responses for non-streaming scenarios
+ * Note: AI SDK v5 uses inputTokens/outputTokens instead of promptTokens/completionTokens
  */
 export const mockCompleteResponses = {
   simple: {
     text: 'This is a simple response.',
     finishReason: 'stop' as const,
     usage: {
-      promptTokens: 15,
-      completionTokens: 8,
+      inputTokens: 15,
+      outputTokens: 8,
       totalTokens: 23
     }
   },
@@ -215,8 +165,8 @@ export const mockCompleteResponses = {
     ],
     finishReason: 'tool-calls' as const,
     usage: {
-      promptTokens: 25,
-      completionTokens: 12,
+      inputTokens: 25,
+      outputTokens: 12,
       totalTokens: 37
     }
   },
@@ -225,14 +175,15 @@ export const mockCompleteResponses = {
     text: 'Response with warnings.',
     finishReason: 'stop' as const,
     usage: {
-      promptTokens: 10,
-      completionTokens: 5,
+      inputTokens: 10,
+      outputTokens: 5,
       totalTokens: 15
     },
     warnings: [
       {
         type: 'unsupported-setting' as const,
-        message: 'Temperature parameter not supported for this model'
+        setting: 'temperature',
+        details: 'Temperature parameter not supported for this model'
       }
     ]
   }
@@ -285,47 +236,3 @@ export const mockImageResponses = {
     warnings: []
   }
 }
-
-/**
- * Mock error responses
- */
-export const mockErrors = {
-  invalidApiKey: {
-    name: 'APIError',
-    message: 'Invalid API key provided',
-    statusCode: 401
-  },
-
-  rateLimitExceeded: {
-    name: 'RateLimitError',
-    message: 'Rate limit exceeded. Please try again later.',
-    statusCode: 429,
-    headers: {
-      'retry-after': '60'
-    }
-  },
-
-  modelNotFound: {
-    name: 'ModelNotFoundError',
-    message: 'The requested model was not found',
-    statusCode: 404
-  },
-
-  contextLengthExceeded: {
-    name: 'ContextLengthError',
-    message: "This model's maximum context length is 4096 tokens",
-    statusCode: 400
-  },
-
-  timeout: {
-    name: 'TimeoutError',
-    message: 'Request timed out after 30000ms',
-    code: 'ETIMEDOUT'
-  },
-
-  networkError: {
-    name: 'NetworkError',
-    message: 'Network connection failed',
-    code: 'ECONNREFUSED'
-  }
-}
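Review note: the fixture changes in this file rename the usage counters from `promptTokens`/`completionTokens` to `inputTokens`/`outputTokens`, following the note in the diff about AI SDK v5. A minimal sketch of that mapping is shown below. The two interfaces are simplified local types built only from the field names visible in the diff, and `toV5Usage` is a hypothetical helper that does not appear in the change set.

```typescript
// Usage shape of the old fixtures (field names taken from the removed lines).
interface LegacyUsage {
  promptTokens: number
  completionTokens: number
  totalTokens: number
}

// Usage shape of the updated fixtures (field names taken from the added lines).
// Simplified: the real SDK type may allow more fields or optional values.
interface V5Usage {
  inputTokens: number
  outputTokens: number
  totalTokens: number
}

// Hypothetical helper, not part of the diff: rename the two token counters.
function toV5Usage(usage: LegacyUsage): V5Usage {
  return {
    inputTokens: usage.promptTokens,
    outputTokens: usage.completionTokens,
    totalTokens: usage.totalTokens
  }
}

// Values mirror the `simple` fixture from the diff.
const converted = toV5Usage({ promptTokens: 15, completionTokens: 8, totalTokens: 23 })
console.log(converted)
```

A helper of this kind would only matter if old-shaped data still had to be read somewhere; the diff itself simply rewrites the fixtures in place.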
35
packages/aiCore/src/__tests__/mocks/ai-sdk-provider.ts
Normal file
@@ -0,0 +1,35 @@
+/**
+ * Mock for @cherrystudio/ai-sdk-provider
+ * This mock is used in tests to avoid importing the actual package
+ */
+
+export type CherryInProviderSettings = {
+  apiKey?: string
+  baseURL?: string
+}
+
+// oxlint-disable-next-line no-unused-vars
+export const createCherryIn = (_options?: CherryInProviderSettings) => ({
+  // oxlint-disable-next-line no-unused-vars
+  languageModel: (_modelId: string) => ({
+    specificationVersion: 'v1',
+    provider: 'cherryin',
+    modelId: 'mock-model',
+    doGenerate: async () => ({ text: 'mock response' }),
+    doStream: async () => ({ stream: (async function* () {})() })
+  }),
+  // oxlint-disable-next-line no-unused-vars
+  chat: (_modelId: string) => ({
+    specificationVersion: 'v1',
+    provider: 'cherryin-chat',
+    modelId: 'mock-model',
+    doGenerate: async () => ({ text: 'mock response' }),
+    doStream: async () => ({ stream: (async function* () {})() })
+  }),
+  // oxlint-disable-next-line no-unused-vars
+  textEmbeddingModel: (_modelId: string) => ({
+    specificationVersion: 'v1',
+    provider: 'cherryin',
+    modelId: 'mock-embedding-model'
+  })
+})
9
packages/aiCore/src/__tests__/setup.ts
Normal file
@@ -0,0 +1,9 @@
+/**
+ * Vitest Setup File
+ * Global test configuration and mocks for @cherrystudio/ai-core package
+ */
+
+// Mock Vite SSR helper to avoid Node environment errors
+;(globalThis as any).__vite_ssr_exportName__ = (_name: string, value: any) => value
+
+// Note: @cherrystudio/ai-sdk-provider is mocked via alias in vitest.config.ts
109
packages/aiCore/src/core/options/__tests__/factory.test.ts
Normal file
@@ -0,0 +1,109 @@
+import { describe, expect, it } from 'vitest'
+
+import { createOpenAIOptions, createOpenRouterOptions, mergeProviderOptions } from '../factory'
+
+describe('mergeProviderOptions', () => {
+  it('deep merges provider options for the same provider', () => {
+    const reasoningOptions = createOpenRouterOptions({
+      reasoning: {
+        enabled: true,
+        effort: 'medium'
+      }
+    })
+    const webSearchOptions = createOpenRouterOptions({
+      plugins: [{ id: 'web', max_results: 5 }]
+    })
+
+    const merged = mergeProviderOptions(reasoningOptions, webSearchOptions)
+
+    expect(merged.openrouter).toEqual({
+      reasoning: {
+        enabled: true,
+        effort: 'medium'
+      },
+      plugins: [{ id: 'web', max_results: 5 }]
+    })
+  })
+
+  it('preserves options from other providers while merging', () => {
+    const openRouter = createOpenRouterOptions({
+      reasoning: { enabled: true }
+    })
+    const openAI = createOpenAIOptions({
+      reasoningEffort: 'low'
+    })
+    const merged = mergeProviderOptions(openRouter, openAI)
+
+    expect(merged.openrouter).toEqual({ reasoning: { enabled: true } })
+    expect(merged.openai).toEqual({ reasoningEffort: 'low' })
+  })
+
+  it('overwrites primitive values with later values', () => {
+    const first = createOpenAIOptions({
+      reasoningEffort: 'low',
+      user: 'user-123'
+    })
+    const second = createOpenAIOptions({
+      reasoningEffort: 'high',
+      maxToolCalls: 5
+    })
+
+    const merged = mergeProviderOptions(first, second)
+
+    expect(merged.openai).toEqual({
+      reasoningEffort: 'high', // overwritten by second
+      user: 'user-123', // preserved from first
+      maxToolCalls: 5 // added from second
+    })
+  })
+
+  it('overwrites arrays with later values instead of merging', () => {
+    const first = createOpenRouterOptions({
+      models: ['gpt-4', 'gpt-3.5-turbo']
+    })
+    const second = createOpenRouterOptions({
+      models: ['claude-3-opus', 'claude-3-sonnet']
+    })
+
+    const merged = mergeProviderOptions(first, second)
+
+    // Array is completely replaced, not merged
+    expect(merged.openrouter?.models).toEqual(['claude-3-opus', 'claude-3-sonnet'])
+  })
+
+  it('deeply merges nested objects while overwriting primitives', () => {
+    const first = createOpenRouterOptions({
+      reasoning: {
+        enabled: true,
+        effort: 'low'
+      },
+      user: 'user-123'
+    })
+    const second = createOpenRouterOptions({
+      reasoning: {
+        effort: 'high',
+        max_tokens: 500
+      },
+      user: 'user-456'
+    })
+
+    const merged = mergeProviderOptions(first, second)
+
+    expect(merged.openrouter).toEqual({
+      reasoning: {
+        enabled: true, // preserved from first
+        effort: 'high', // overwritten by second
+        max_tokens: 500 // added from second
+      },
+      user: 'user-456' // overwritten by second
+    })
+  })
+
+  it('replaces arrays instead of merging them', () => {
+    const first = createOpenRouterOptions({ plugins: [{ id: 'old' }] })
+    const second = createOpenRouterOptions({ plugins: [{ id: 'new' }] })
+    const merged = mergeProviderOptions(first, second)
+    // @ts-expect-error type-check for openrouter options is skipped. see function signature of createOpenRouterOptions
+    expect(merged.openrouter?.plugins).toEqual([{ id: 'new' }])
+  })
+})
@@ -26,13 +26,65 @@ export function createGenericProviderOptions<T extends string>(
   return { [provider]: options } as Record<T, Record<string, any>>
 }

+type PlainObject = Record<string, any>
+
+const isPlainObject = (value: unknown): value is PlainObject => {
+  return typeof value === 'object' && value !== null && !Array.isArray(value)
+}
+
+function deepMergeObjects<T extends PlainObject>(target: T, source: PlainObject): T {
+  const result: PlainObject = { ...target }
+  Object.entries(source).forEach(([key, value]) => {
+    if (isPlainObject(value) && isPlainObject(result[key])) {
+      result[key] = deepMergeObjects(result[key], value)
+    } else {
+      result[key] = value
+    }
+  })
+  return result as T
+}
+
 /**
- * 合并多个供应商的options
- * @param optionsMap 包含多个供应商选项的对象
- * @returns 合并后的TypedProviderOptions
+ * Deep-merge multiple provider-specific options.
+ * Nested objects are recursively merged; primitive values are overwritten.
+ *
+ * When the same key appears in multiple options:
+ * - If both values are plain objects: they are deeply merged (recursive merge)
+ * - If values are primitives/arrays: the later value overwrites the earlier one
+ *
+ * @example
+ * mergeProviderOptions(
+ *   { openrouter: { reasoning: { enabled: true, effort: 'low' }, user: 'user-123' } },
+ *   { openrouter: { reasoning: { effort: 'high', max_tokens: 500 }, models: ['gpt-4'] } }
+ * )
+ * // Result: {
+ * //   openrouter: {
+ * //     reasoning: { enabled: true, effort: 'high', max_tokens: 500 },
+ * //     user: 'user-123',
+ * //     models: ['gpt-4']
+ * //   }
+ * // }
+ *
+ * @param optionsMap Objects containing options for multiple providers
+ * @returns Fully merged TypedProviderOptions
  */
 export function mergeProviderOptions(...optionsMap: Partial<TypedProviderOptions>[]): TypedProviderOptions {
-  return Object.assign({}, ...optionsMap)
+  return optionsMap.reduce<TypedProviderOptions>((acc, options) => {
+    if (!options) {
+      return acc
+    }
+    Object.entries(options).forEach(([providerId, providerOptions]) => {
+      if (!providerOptions) {
+        return
+      }
+      if (acc[providerId]) {
+        acc[providerId] = deepMergeObjects(acc[providerId] as PlainObject, providerOptions as PlainObject)
+      } else {
+        acc[providerId] = providerOptions as any
+      }
+    })
+    return acc
+  }, {} as TypedProviderOptions)
 }

 /**
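Review note: the change above swaps a shallow `Object.assign` for a recursive merge. The standalone sketch below reproduces the documented semantics (plain objects merge recursively, primitives and arrays are overwritten by the later value) and contrasts them with the previous shallow behavior. The names `ProviderOptions`, `deepMerge`, and `mergeOptions` are simplified stand-ins chosen for this illustration, not the identifiers from the repository.

```typescript
// Simplified stand-ins for the repository types (assumptions for illustration).
type PlainObject = Record<string, unknown>
type ProviderOptions = Record<string, PlainObject>

const isPlainObject = (value: unknown): value is PlainObject =>
  typeof value === 'object' && value !== null && !Array.isArray(value)

// Recursively merge `source` into a shallow copy of `target`.
// Arrays and primitives from `source` replace the existing value.
function deepMerge(target: PlainObject, source: PlainObject): PlainObject {
  const result: PlainObject = { ...target }
  for (const [key, value] of Object.entries(source)) {
    const existing = result[key]
    result[key] = isPlainObject(existing) && isPlainObject(value) ? deepMerge(existing, value) : value
  }
  return result
}

// Merge per-provider option maps from left to right.
function mergeOptions(...maps: ProviderOptions[]): ProviderOptions {
  const merged: ProviderOptions = {}
  for (const map of maps) {
    for (const [providerId, options] of Object.entries(map)) {
      merged[providerId] = deepMerge(merged[providerId] ?? {}, options)
    }
  }
  return merged
}

// Inputs taken from the @example in the JSDoc above.
const firstOptions: ProviderOptions = {
  openrouter: { reasoning: { enabled: true, effort: 'low' }, user: 'user-123' }
}
const secondOptions: ProviderOptions = {
  openrouter: { reasoning: { effort: 'high', max_tokens: 500 }, models: ['gpt-4'] }
}

const deepResult = mergeOptions(firstOptions, secondOptions)
// Previous behavior: the second `openrouter` object replaces the first wholesale.
const shallowResult: ProviderOptions = Object.assign({}, firstOptions, secondOptions)

console.log(JSON.stringify(deepResult))
console.log(JSON.stringify(shallowResult))
```

In the shallow result the `enabled` flag and the `user` field from the first argument are lost, which is the case the new tests in `factory.test.ts` guard against.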
@@ -19,15 +19,20 @@ describe('Provider Schemas', () => {
 expect(Array.isArray(baseProviders)).toBe(true)
 expect(baseProviders.length).toBeGreaterThan(0)

+// These are the actual base providers defined in schemas.ts
 const expectedIds = [
   'openai',
-  'openai-responses',
+  'openai-chat',
   'openai-compatible',
   'anthropic',
   'google',
   'xai',
   'azure',
-  'deepseek'
+  'azure-responses',
+  'deepseek',
+  'openrouter',
+  'cherryin',
+  'cherryin-chat'
 ]
 const actualIds = baseProviders.map((p) => p.id)
 expectedIds.forEach((id) => {
@@ -232,11 +232,13 @@ describe('RuntimeExecutor.generateImage', () => {
|
|||||||
|
|
||||||
expect(pluginCallOrder).toEqual(['onRequestStart', 'transformParams', 'transformResult', 'onRequestEnd'])
|
expect(pluginCallOrder).toEqual(['onRequestStart', 'transformParams', 'transformResult', 'onRequestEnd'])
|
||||||
|
|
||||||
|
// transformParams receives params without model (model is handled separately)
|
||||||
|
// and context with core fields + dynamic fields (requestId, startTime, etc.)
|
||||||
expect(testPlugin.transformParams).toHaveBeenCalledWith(
|
expect(testPlugin.transformParams).toHaveBeenCalledWith(
|
||||||
{ prompt: 'A test image' },
|
expect.objectContaining({ prompt: 'A test image' }),
|
||||||
expect.objectContaining({
|
expect.objectContaining({
|
||||||
providerId: 'openai',
|
providerId: 'openai',
|
||||||
modelId: 'dall-e-3'
|
model: 'dall-e-3'
|
||||||
})
|
})
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -273,11 +275,12 @@ describe('RuntimeExecutor.generateImage', () => {
|
|||||||
|
|
||||||
await executorWithPlugin.generateImage({ model: 'dall-e-3', prompt: 'A test image' })
|
await executorWithPlugin.generateImage({ model: 'dall-e-3', prompt: 'A test image' })
|
||||||
|
|
||||||
|
// resolveModel receives model id and context with core fields
|
||||||
expect(modelResolutionPlugin.resolveModel).toHaveBeenCalledWith(
|
expect(modelResolutionPlugin.resolveModel).toHaveBeenCalledWith(
|
||||||
'dall-e-3',
|
'dall-e-3',
|
||||||
expect.objectContaining({
|
expect.objectContaining({
|
||||||
providerId: 'openai',
|
providerId: 'openai',
|
||||||
modelId: 'dall-e-3'
|
model: 'dall-e-3'
|
||||||
})
|
})
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -339,12 +342,11 @@ describe('RuntimeExecutor.generateImage', () => {
|
|||||||
.generateImage({ model: 'invalid-model', prompt: 'A test image' })
|
.generateImage({ model: 'invalid-model', prompt: 'A test image' })
|
||||||
.catch((error) => error)
|
.catch((error) => error)
|
||||||
|
|
||||||
expect(thrownError).toBeInstanceOf(ImageGenerationError)
|
// Error is thrown from pluginEngine directly as ImageModelResolutionError
|
||||||
expect(thrownError.message).toContain('Failed to generate image:')
|
expect(thrownError).toBeInstanceOf(ImageModelResolutionError)
|
||||||
|
expect(thrownError.message).toContain('Failed to resolve image model: invalid-model')
|
||||||
expect(thrownError.providerId).toBe('openai')
|
expect(thrownError.providerId).toBe('openai')
|
||||||
expect(thrownError.modelId).toBe('invalid-model')
|
expect(thrownError.modelId).toBe('invalid-model')
|
||||||
expect(thrownError.cause).toBeInstanceOf(ImageModelResolutionError)
|
|
||||||
expect(thrownError.cause.message).toContain('Failed to resolve image model: invalid-model')
|
|
||||||
})
|
})
|
||||||
|
|
||||||
it('should handle ImageModelResolutionError without provider', async () => {
|
it('should handle ImageModelResolutionError without provider', async () => {
|
||||||
@@ -362,8 +364,9 @@ describe('RuntimeExecutor.generateImage', () => {
|
|||||||
const apiError = new Error('API request failed')
|
const apiError = new Error('API request failed')
|
||||||
vi.mocked(aiGenerateImage).mockRejectedValue(apiError)
|
vi.mocked(aiGenerateImage).mockRejectedValue(apiError)
|
||||||
|
|
||||||
|
// Error propagates directly from pluginEngine without wrapping
|
||||||
await expect(executor.generateImage({ model: 'dall-e-3', prompt: 'A test image' })).rejects.toThrow(
|
await expect(executor.generateImage({ model: 'dall-e-3', prompt: 'A test image' })).rejects.toThrow(
|
||||||
'Failed to generate image:'
|
'API request failed'
|
||||||
)
|
)
|
||||||
})
|
})
|
||||||
|
|
||||||
@@ -376,8 +379,9 @@ describe('RuntimeExecutor.generateImage', () => {
|
|||||||
vi.mocked(aiGenerateImage).mockRejectedValue(noImageError)
|
vi.mocked(aiGenerateImage).mockRejectedValue(noImageError)
|
||||||
vi.mocked(NoImageGeneratedError.isInstance).mockReturnValue(true)
|
vi.mocked(NoImageGeneratedError.isInstance).mockReturnValue(true)
|
||||||
|
|
||||||
|
// Error propagates directly from pluginEngine
|
||||||
await expect(executor.generateImage({ model: 'dall-e-3', prompt: 'A test image' })).rejects.toThrow(
|
await expect(executor.generateImage({ model: 'dall-e-3', prompt: 'A test image' })).rejects.toThrow(
|
||||||
'Failed to generate image:'
|
'No image generated'
|
||||||
)
|
)
|
||||||
})
|
})
|
||||||
|
|
||||||
@@ -398,15 +402,17 @@ describe('RuntimeExecutor.generateImage', () => {
|
|||||||
[errorPlugin]
|
[errorPlugin]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// Error propagates directly from pluginEngine
|
||||||
await expect(executorWithPlugin.generateImage({ model: 'dall-e-3', prompt: 'A test image' })).rejects.toThrow(
|
await expect(executorWithPlugin.generateImage({ model: 'dall-e-3', prompt: 'A test image' })).rejects.toThrow(
|
||||||
'Failed to generate image:'
|
'Generation failed'
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// onError receives the original error and context with core fields
|
||||||
expect(errorPlugin.onError).toHaveBeenCalledWith(
|
expect(errorPlugin.onError).toHaveBeenCalledWith(
|
||||||
error,
|
error,
|
||||||
expect.objectContaining({
|
expect.objectContaining({
|
||||||
providerId: 'openai',
|
providerId: 'openai',
|
||||||
modelId: 'dall-e-3'
|
model: 'dall-e-3'
|
||||||
})
|
})
|
||||||
)
|
)
|
||||||
})
|
})
|
||||||
@@ -419,9 +425,10 @@ describe('RuntimeExecutor.generateImage', () => {
|
|||||||
const abortController = new AbortController()
|
const abortController = new AbortController()
|
||||||
setTimeout(() => abortController.abort(), 10)
|
setTimeout(() => abortController.abort(), 10)
|
||||||
|
|
||||||
|
// Error propagates directly from pluginEngine
|
||||||
await expect(
|
await expect(
|
||||||
executor.generateImage({ model: 'dall-e-3', prompt: 'A test image', abortSignal: abortController.signal })
|
executor.generateImage({ model: 'dall-e-3', prompt: 'A test image', abortSignal: abortController.signal })
|
||||||
).rejects.toThrow('Failed to generate image:')
|
).rejects.toThrow('Operation was aborted')
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
|
|||||||
@@ -17,10 +17,14 @@ import type { AiPlugin } from '../../plugins'
|
|||||||
import { globalRegistryManagement } from '../../providers/RegistryManagement'
|
import { globalRegistryManagement } from '../../providers/RegistryManagement'
|
||||||
import { RuntimeExecutor } from '../executor'
|
import { RuntimeExecutor } from '../executor'
|
||||||
|
|
||||||
// Mock AI SDK
|
// Mock AI SDK - use importOriginal to keep jsonSchema and other non-mocked exports
|
||||||
vi.mock('ai', () => ({
|
vi.mock('ai', async (importOriginal) => {
|
||||||
generateText: vi.fn()
|
const actual = (await importOriginal()) as Record<string, unknown>
|
||||||
}))
|
return {
|
||||||
|
...actual,
|
||||||
|
generateText: vi.fn()
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
vi.mock('../../providers/RegistryManagement', () => ({
|
vi.mock('../../providers/RegistryManagement', () => ({
|
||||||
globalRegistryManagement: {
|
globalRegistryManagement: {
|
||||||
@@ -409,11 +413,12 @@ describe('RuntimeExecutor.generateText', () => {
|
|||||||
})
|
})
|
||||||
).rejects.toThrow('Generation failed')
|
).rejects.toThrow('Generation failed')
|
||||||
|
|
||||||
|
// onError receives the original error and context with core fields
|
||||||
expect(errorPlugin.onError).toHaveBeenCalledWith(
|
expect(errorPlugin.onError).toHaveBeenCalledWith(
|
||||||
error,
|
error,
|
||||||
expect.objectContaining({
|
expect.objectContaining({
|
||||||
providerId: 'openai',
|
providerId: 'openai',
|
||||||
modelId: 'gpt-4'
|
model: 'gpt-4'
|
||||||
})
|
})
|
||||||
)
|
)
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -11,10 +11,14 @@ import type { AiPlugin } from '../../plugins'
|
|||||||
import { globalRegistryManagement } from '../../providers/RegistryManagement'
|
import { globalRegistryManagement } from '../../providers/RegistryManagement'
|
||||||
import { RuntimeExecutor } from '../executor'
|
import { RuntimeExecutor } from '../executor'
|
||||||
|
|
||||||
// Mock AI SDK
|
// Mock AI SDK - use importOriginal to keep jsonSchema and other non-mocked exports
|
||||||
vi.mock('ai', () => ({
|
vi.mock('ai', async (importOriginal) => {
|
||||||
streamText: vi.fn()
|
const actual = (await importOriginal()) as Record<string, unknown>
|
||||||
}))
|
return {
|
||||||
|
...actual,
|
||||||
|
streamText: vi.fn()
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
vi.mock('../../providers/RegistryManagement', () => ({
|
vi.mock('../../providers/RegistryManagement', () => ({
|
||||||
globalRegistryManagement: {
|
globalRegistryManagement: {
|
||||||
@@ -153,7 +157,7 @@ describe('RuntimeExecutor.streamText', () => {
|
|||||||
describe('Max Tokens Parameter', () => {
|
describe('Max Tokens Parameter', () => {
|
||||||
const maxTokensValues = [10, 50, 100, 500, 1000, 2000, 4000]
|
const maxTokensValues = [10, 50, 100, 500, 1000, 2000, 4000]
|
||||||
|
|
||||||
it.each(maxTokensValues)('should support maxTokens=%s', async (maxTokens) => {
|
it.each(maxTokensValues)('should support maxOutputTokens=%s', async (maxOutputTokens) => {
|
||||||
const mockStream = {
|
const mockStream = {
|
||||||
textStream: (async function* () {
|
textStream: (async function* () {
|
||||||
yield 'Response'
|
yield 'Response'
|
||||||
@@ -168,12 +172,13 @@ describe('RuntimeExecutor.streamText', () => {
|
|||||||
await executor.streamText({
|
await executor.streamText({
|
||||||
model: 'gpt-4',
|
model: 'gpt-4',
|
||||||
messages: testMessages.simple,
|
messages: testMessages.simple,
|
||||||
maxOutputTokens: maxTokens
|
maxOutputTokens
|
||||||
})
|
})
|
||||||
|
|
||||||
|
// Parameters are passed through without transformation
|
||||||
expect(streamText).toHaveBeenCalledWith(
|
expect(streamText).toHaveBeenCalledWith(
|
||||||
expect.objectContaining({
|
expect.objectContaining({
|
||||||
maxTokens
|
maxOutputTokens
|
||||||
})
|
})
|
||||||
)
|
)
|
||||||
})
|
})
|
||||||
@@ -513,11 +518,12 @@ describe('RuntimeExecutor.streamText', () => {
|
|||||||
})
|
})
|
||||||
).rejects.toThrow('Stream error')
|
).rejects.toThrow('Stream error')
|
||||||
|
|
||||||
|
// onError receives the original error and context with core fields
|
||||||
expect(errorPlugin.onError).toHaveBeenCalledWith(
|
expect(errorPlugin.onError).toHaveBeenCalledWith(
|
||||||
error,
|
error,
|
||||||
expect.objectContaining({
|
expect.objectContaining({
|
||||||
providerId: 'openai',
|
providerId: 'openai',
|
||||||
modelId: 'gpt-4'
|
model: 'gpt-4'
|
||||||
})
|
})
|
||||||
)
|
)
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -1,12 +1,20 @@
|
|||||||
|
import path from 'node:path'
|
||||||
|
import { fileURLToPath } from 'node:url'
|
||||||
|
|
||||||
import { defineConfig } from 'vitest/config'
|
import { defineConfig } from 'vitest/config'
|
||||||
|
|
||||||
|
const __dirname = path.dirname(fileURLToPath(import.meta.url))
|
||||||
|
|
||||||
export default defineConfig({
|
export default defineConfig({
|
||||||
test: {
|
test: {
|
||||||
globals: true
|
globals: true,
|
||||||
|
setupFiles: [path.resolve(__dirname, './src/__tests__/setup.ts')]
|
||||||
},
|
},
|
||||||
resolve: {
|
resolve: {
|
||||||
alias: {
|
alias: {
|
||||||
'@': './src'
|
'@': path.resolve(__dirname, './src'),
|
||||||
|
// Mock external packages that may not be available in test environment
|
||||||
|
'@cherrystudio/ai-sdk-provider': path.resolve(__dirname, './src/__tests__/mocks/ai-sdk-provider.ts')
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
esbuild: {
|
esbuild: {
|
||||||
|
|||||||
@@ -9,13 +9,27 @@
|
|||||||
*/
|
*/
|
||||||
|
|
||||||
import Anthropic from '@anthropic-ai/sdk'
|
import Anthropic from '@anthropic-ai/sdk'
|
||||||
import type { TextBlockParam } from '@anthropic-ai/sdk/resources'
|
import type { MessageCreateParams, TextBlockParam, Tool as AnthropicTool } from '@anthropic-ai/sdk/resources'
|
||||||
import { loggerService } from '@logger'
|
import { loggerService } from '@logger'
|
||||||
import type { Provider } from '@types'
|
import { type Provider, SystemProviderIds } from '@types'
|
||||||
import type { ModelMessage } from 'ai'
|
import type { ModelMessage } from 'ai'
|
||||||
|
|
||||||
const logger = loggerService.withContext('anthropic-sdk')
|
const logger = loggerService.withContext('anthropic-sdk')
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Context for Anthropic SDK client creation.
|
||||||
|
* This allows the shared module to be used in different environments
|
||||||
|
* by providing environment-specific implementations.
|
||||||
|
*/
|
||||||
|
export interface AnthropicSdkContext {
|
||||||
|
/**
|
||||||
|
* Custom fetch function to use for HTTP requests.
|
||||||
|
* In Electron main process, this should be `net.fetch`.
|
||||||
|
* In other environments, can use the default fetch or a custom implementation.
|
||||||
|
*/
|
||||||
|
fetch?: typeof globalThis.fetch
|
||||||
|
}
|
||||||
|
|
||||||
const defaultClaudeCodeSystemPrompt = `You are Claude Code, Anthropic's official CLI for Claude.`
|
const defaultClaudeCodeSystemPrompt = `You are Claude Code, Anthropic's official CLI for Claude.`
|
||||||
|
|
||||||
const defaultClaudeCodeSystem: Array<TextBlockParam> = [
|
const defaultClaudeCodeSystem: Array<TextBlockParam> = [
|
||||||
@@ -58,8 +72,11 @@ const defaultClaudeCodeSystem: Array<TextBlockParam> = [
|
|||||||
export function getSdkClient(
|
export function getSdkClient(
|
||||||
provider: Provider,
|
provider: Provider,
|
||||||
oauthToken?: string | null,
|
oauthToken?: string | null,
|
||||||
extraHeaders?: Record<string, string | string[]>
|
extraHeaders?: Record<string, string | string[]>,
|
||||||
|
context?: AnthropicSdkContext
|
||||||
): Anthropic {
|
): Anthropic {
|
||||||
|
const customFetch = context?.fetch
|
||||||
|
|
||||||
if (provider.authType === 'oauth') {
|
if (provider.authType === 'oauth') {
|
||||||
if (!oauthToken) {
|
if (!oauthToken) {
|
||||||
throw new Error('OAuth token is not available')
|
throw new Error('OAuth token is not available')
|
||||||
@@ -85,7 +102,8 @@ export function getSdkClient(
|
|||||||
'x-stainless-runtime': 'node',
|
'x-stainless-runtime': 'node',
|
||||||
'x-stainless-runtime-version': 'v22.18.0',
|
'x-stainless-runtime-version': 'v22.18.0',
|
||||||
...extraHeaders
|
...extraHeaders
|
||||||
}
|
},
|
||||||
|
fetch: customFetch
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
let baseURL =
|
let baseURL =
|
||||||
@@ -106,11 +124,12 @@ export function getSdkClient(
|
|||||||
baseURL,
|
baseURL,
|
||||||
dangerouslyAllowBrowser: true,
|
dangerouslyAllowBrowser: true,
|
||||||
defaultHeaders: {
|
defaultHeaders: {
|
||||||
'anthropic-beta': 'output-128k-2025-02-19',
|
'anthropic-beta': 'interleaved-thinking-2025-05-14',
|
||||||
'APP-Code': 'MLTG2087',
|
'APP-Code': 'MLTG2087',
|
||||||
...provider.extra_headers,
|
...provider.extra_headers,
|
||||||
...extraHeaders
|
...extraHeaders
|
||||||
}
|
},
|
||||||
|
fetch: customFetch
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -120,9 +139,11 @@ export function getSdkClient(
|
|||||||
baseURL,
|
baseURL,
|
||||||
dangerouslyAllowBrowser: true,
|
dangerouslyAllowBrowser: true,
|
||||||
defaultHeaders: {
|
defaultHeaders: {
|
||||||
'anthropic-beta': 'output-128k-2025-02-19',
|
'anthropic-beta': 'interleaved-thinking-2025-05-14',
|
||||||
|
Authorization: provider.id === SystemProviderIds.longcat ? `Bearer ${provider.apiKey}` : undefined,
|
||||||
...provider.extra_headers
|
...provider.extra_headers
|
||||||
}
|
},
|
||||||
|
fetch: customFetch
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -173,3 +194,31 @@ export function buildClaudeCodeSystemModelMessage(system?: string | Array<TextBl
|
|||||||
content: block.text
|
content: block.text
|
||||||
}))
|
}))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Sanitize tool definitions for Anthropic API.
|
||||||
|
*
|
||||||
|
* Removes non-standard fields like `input_examples` from tool definitions
|
||||||
|
* that Anthropic's API doesn't support. This prevents validation errors when
|
||||||
|
* tools with extended fields are passed to the Anthropic SDK.
|
||||||
|
*
|
||||||
|
* @param tools - Array of tool definitions from MessageCreateParams
|
||||||
|
* @returns Sanitized tools array with non-standard fields removed
|
||||||
|
*
|
||||||
|
* @example
|
||||||
|
* ```typescript
|
||||||
|
* const sanitizedTools = sanitizeToolsForAnthropic(request.tools)
|
||||||
|
* ```
|
||||||
|
*/
|
||||||
|
export function sanitizeToolsForAnthropic(tools?: MessageCreateParams['tools']): MessageCreateParams['tools'] {
|
||||||
|
if (!tools || tools.length === 0) return tools
|
||||||
|
|
||||||
|
return tools.map((tool) => {
|
||||||
|
if ('type' in tool && tool.type !== 'custom') return tool
|
||||||
|
|
||||||
|
// oxlint-disable-next-line no-unused-vars
|
||||||
|
const { input_examples, ...sanitizedTool } = tool as AnthropicTool & { input_examples?: unknown }
|
||||||
|
|
||||||
|
return sanitizedTool as typeof tool
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|||||||
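To illustrate what `sanitizeToolsForAnthropic` does without pulling in the Anthropic SDK types, here is a small self-contained sketch. `ToolLike` and `stripInputExamples` are hypothetical names, and the tool type `'server_tool_example'` is made up for the demonstration; only the pass-through rule and the removal of `input_examples` come from the diff.

```typescript
// Hypothetical minimal tool shape; the real function uses MessageCreateParams['tools'].
interface ToolLike {
  name: string
  type?: string
  input_schema?: Record<string, unknown>
  [extra: string]: unknown
}

function stripInputExamples(tools?: ToolLike[]): ToolLike[] | undefined {
  if (!tools || tools.length === 0) return tools

  return tools.map((tool) => {
    // Tools that declare a type other than 'custom' are passed through untouched.
    if ('type' in tool && tool.type !== 'custom') return tool

    // Drop the non-standard field and keep everything else.
    const { input_examples: _dropped, ...rest } = tool
    return rest as ToolLike
  })
}

const sanitized = stripInputExamples([
  { name: 'search', input_schema: { type: 'object' }, input_examples: [{ query: 'cats' }] },
  { name: 'builtin', type: 'server_tool_example', input_examples: [] }
])
console.log(sanitized)
```

The first tool comes back without `input_examples`; the second is returned as the same object because it carries a non-`custom` type.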
245
packages/shared/api/index.ts
Normal file
@@ -0,0 +1,245 @@
|
|||||||
|
/**
|
||||||
|
* Shared API Utilities
|
||||||
|
*
|
||||||
|
* Common utilities for API URL formatting and validation.
|
||||||
|
* Used by both main process (API Server) and renderer.
|
||||||
|
*/
|
||||||
|
|
||||||
|
import type { MinimalProvider } from '@shared/provider'
|
||||||
|
import { trim } from 'lodash'
|
||||||
|
|
||||||
|
// Supported endpoints for routing
|
||||||
|
export const SUPPORTED_IMAGE_ENDPOINT_LIST = ['images/generations', 'images/edits', 'predict'] as const
|
||||||
|
export const SUPPORTED_ENDPOINT_LIST = [
|
||||||
|
'chat/completions',
|
||||||
|
'responses',
|
||||||
|
'messages',
|
||||||
|
'generateContent',
|
||||||
|
'streamGenerateContent',
|
||||||
|
...SUPPORTED_IMAGE_ENDPOINT_LIST
|
||||||
|
] as const
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Removes the trailing slash from a URL string if it exists.
|
||||||
|
*/
|
||||||
|
export function withoutTrailingSlash<T extends string>(url: T): T {
|
||||||
|
return url.replace(/\/$/, '') as T
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Matches a version segment in a path that starts with `/v<number>` and optionally
|
||||||
|
* continues with `alpha` or `beta`. The segment may be followed by `/` or the end
|
||||||
|
* of the string (useful for cases like `/v3alpha/resources`).
|
||||||
|
*/
|
||||||
|
const VERSION_REGEX_PATTERN = '\\/v\\d+(?:alpha|beta)?(?=\\/|$)'
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Matches an API version at the end of a URL (with optional trailing slash).
|
||||||
|
* Used to detect and extract versions only from the trailing position.
|
||||||
|
*/
|
||||||
|
const TRAILING_VERSION_REGEX = /\/v\d+(?:alpha|beta)?\/?$/i
|
||||||
|
|
||||||
|
/**
|
||||||
|
 * Checks whether the path of the host contains a version-like segment (for example /v1 or /v2beta).
|
||||||
|
*
|
||||||
|
 * @param host - The host or path string to check
|
||||||
|
 * @returns true if the path contains a version segment, false otherwise
|
||||||
|
*/
|
||||||
|
export function hasAPIVersion(host?: string): boolean {
|
||||||
|
if (!host) return false
|
||||||
|
|
||||||
|
const regex = new RegExp(VERSION_REGEX_PATTERN, 'i')
|
||||||
|
|
||||||
|
try {
|
||||||
|
const url = new URL(host)
|
||||||
|
return regex.test(url.pathname)
|
||||||
|
} catch {
|
||||||
|
    // If the value cannot be parsed as a full URL, test it directly as a path
|
||||||
|
return regex.test(host)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
 * Formats the API host address for Azure OpenAI.
|
||||||
|
*/
|
||||||
|
export function formatAzureOpenAIApiHost(host: string): string {
|
||||||
|
const normalizedHost = withoutTrailingSlash(host)
|
||||||
|
?.replace(/\/v1$/, '')
|
||||||
|
.replace(/\/openai$/, '')
|
||||||
|
  // NOTE: the AI SDK appends `v1` itself
|
||||||
|
return formatApiHost(normalizedHost + '/openai', false)
|
||||||
|
}
|
||||||
|
|
||||||
|
export function formatVertexApiHost(
|
||||||
|
provider: MinimalProvider,
|
||||||
|
project: string = 'test-project',
|
||||||
|
location: string = 'us-central1'
|
||||||
|
): string {
|
||||||
|
const { apiHost } = provider
|
||||||
|
const trimmedHost = withoutTrailingSlash(trim(apiHost))
|
||||||
|
if (!trimmedHost || trimmedHost.endsWith('aiplatform.googleapis.com')) {
|
||||||
|
const host =
|
||||||
|
location === 'global' ? 'https://aiplatform.googleapis.com' : `https://${location}-aiplatform.googleapis.com`
|
||||||
|
return `${formatApiHost(host)}/projects/${project}/locations/${location}`
|
||||||
|
}
|
||||||
|
return formatApiHost(trimmedHost)
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Formats an API host URL by normalizing it and optionally appending an API version.
|
||||||
|
*
|
||||||
|
* @param host - The API host URL to format. Leading/trailing whitespace will be trimmed and trailing slashes removed.
|
||||||
|
* @param supportApiVersion - Whether the API version is supported. Defaults to `true`.
|
||||||
|
* @param apiVersion - The API version to append if needed. Defaults to `'v1'`.
|
||||||
|
*
|
||||||
|
* @returns The formatted API host URL. If the host is empty after normalization, returns an empty string.
|
||||||
|
* If the host ends with '#', API version is not supported, or the host already contains a version, returns the normalized host as-is.
|
||||||
|
* Otherwise, returns the host with the API version appended.
|
||||||
|
*
|
||||||
|
* @example
|
||||||
|
* formatApiHost('https://api.example.com/') // Returns 'https://api.example.com/v1'
|
||||||
|
* formatApiHost('https://api.example.com#') // Returns 'https://api.example.com#'
|
||||||
|
* formatApiHost('https://api.example.com/v2', true, 'v1') // Returns 'https://api.example.com/v2'
|
||||||
|
*/
|
||||||
|
export function formatApiHost(host?: string, supportApiVersion: boolean = true, apiVersion: string = 'v1'): string {
|
||||||
|
const normalizedHost = withoutTrailingSlash(trim(host))
|
||||||
|
if (!normalizedHost) {
|
||||||
|
return ''
|
||||||
|
}
|
||||||
|
|
||||||
|
if (normalizedHost.endsWith('#') || !supportApiVersion || hasAPIVersion(normalizedHost)) {
|
||||||
|
return normalizedHost
|
||||||
|
}
|
||||||
|
return `${normalizedHost}/${apiVersion}`
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Converts an API host URL into separate base URL and endpoint components.
|
||||||
|
*
|
||||||
|
* This function extracts endpoint information from a composite API host string.
|
||||||
|
* If the host ends with '#', it attempts to match the preceding part against the supported endpoint list.
|
||||||
|
*
|
||||||
|
* @param apiHost - The API host string to parse
|
||||||
|
* @returns An object containing:
|
||||||
|
* - `baseURL`: The base URL without the endpoint suffix
|
||||||
|
* - `endpoint`: The matched endpoint identifier, or empty string if no match found
|
||||||
|
*
|
||||||
|
* @example
|
||||||
|
* routeToEndpoint('https://api.example.com/openai/chat/completions#')
|
||||||
|
 * // Returns: { baseURL: 'https://api.example.com/openai', endpoint: 'chat/completions' }
|
||||||
|
*
|
||||||
|
* @example
|
||||||
|
* routeToEndpoint('https://api.example.com/v1')
|
||||||
|
* // Returns: { baseURL: 'https://api.example.com/v1', endpoint: '' }
|
||||||
|
*/
|
||||||
|
export function routeToEndpoint(apiHost: string): { baseURL: string; endpoint: string } {
|
||||||
|
const trimmedHost = (apiHost || '').trim()
|
||||||
|
if (!trimmedHost.endsWith('#')) {
|
||||||
|
return { baseURL: trimmedHost, endpoint: '' }
|
||||||
|
}
|
||||||
|
// Remove trailing #
|
||||||
|
const host = trimmedHost.slice(0, -1)
|
||||||
|
const endpointMatch = SUPPORTED_ENDPOINT_LIST.find((endpoint) => host.endsWith(endpoint))
|
||||||
|
if (!endpointMatch) {
|
||||||
|
const baseURL = withoutTrailingSlash(host)
|
||||||
|
return { baseURL, endpoint: '' }
|
||||||
|
}
|
||||||
|
const baseSegment = host.slice(0, host.length - endpointMatch.length)
|
||||||
|
const baseURL = withoutTrailingSlash(baseSegment).replace(/:$/, '') // Remove trailing colon (gemini special case)
|
||||||
|
return { baseURL, endpoint: endpointMatch }
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Gets the AI SDK compatible base URL from a provider's apiHost.
|
||||||
|
*
|
||||||
|
* AI SDK expects baseURL WITH version suffix (e.g., /v1).
|
||||||
|
* This function:
|
||||||
|
* 1. Handles '#' endpoint routing format
|
||||||
|
* 2. Ensures the URL has a version suffix (adds /v1 if missing)
|
||||||
|
*
|
||||||
|
* @param apiHost - The provider's apiHost value (may or may not have /v1)
|
||||||
|
* @param apiVersion - The API version to use if missing. Defaults to 'v1'.
|
||||||
|
* @returns The baseURL suitable for AI SDK (with version suffix)
|
||||||
|
*
|
||||||
|
* @example
|
||||||
|
* getAiSdkBaseUrl('https://api.openai.com') // 'https://api.openai.com/v1'
|
||||||
|
* getAiSdkBaseUrl('https://api.openai.com/v1') // 'https://api.openai.com/v1'
|
||||||
|
 * getAiSdkBaseUrl('https://api.example.com/chat/completions#') // 'https://api.example.com/v1'
|
||||||
|
*/
|
||||||
|
export function getAiSdkBaseUrl(apiHost: string, apiVersion: string = 'v1'): string {
|
||||||
|
// First handle '#' endpoint routing format
|
||||||
|
const { baseURL } = routeToEndpoint(apiHost)
|
||||||
|
|
||||||
|
// If already has version, return as-is
|
||||||
|
if (hasAPIVersion(baseURL)) {
|
||||||
|
return withoutTrailingSlash(baseURL)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add version suffix
|
||||||
|
return `${withoutTrailingSlash(baseURL)}/${apiVersion}`
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Validates an API host address.
|
||||||
|
*
|
||||||
|
* @param apiHost - The API host address to validate
|
||||||
|
 * @returns true if the value is empty or a valid URL with http/https protocol, false otherwise
|
||||||
|
*/
|
||||||
|
export function validateApiHost(apiHost: string): boolean {
|
||||||
|
if (!apiHost || !apiHost.trim()) {
|
||||||
|
return true // Allow empty
|
||||||
|
}
|
||||||
|
try {
|
||||||
|
const url = new URL(apiHost.trim())
|
||||||
|
return url.protocol === 'http:' || url.protocol === 'https:'
|
||||||
|
} catch {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Extracts the trailing API version segment from a URL path.
|
||||||
|
*
|
||||||
|
* This function extracts API version patterns (e.g., `v1`, `v2beta`) from the end of a URL.
|
||||||
|
* Only versions at the end of the path are extracted, not versions in the middle.
|
||||||
|
* The returned version string does not include leading or trailing slashes.
|
||||||
|
*
|
||||||
|
* @param {string} url - The URL string to parse.
|
||||||
|
* @returns {string | undefined} The trailing API version found (e.g., 'v1', 'v2beta'), or undefined if none found.
|
||||||
|
*
|
||||||
|
* @example
|
||||||
|
* getTrailingApiVersion('https://api.example.com/v1') // 'v1'
|
||||||
|
* getTrailingApiVersion('https://api.example.com/v2beta/') // 'v2beta'
|
||||||
|
* getTrailingApiVersion('https://api.example.com/v1/chat') // undefined (version not at end)
|
||||||
|
* getTrailingApiVersion('https://gateway.ai.cloudflare.com/v1/xxx/v1beta') // 'v1beta'
|
||||||
|
* getTrailingApiVersion('https://api.example.com') // undefined
|
||||||
|
*/
|
||||||
|
export function getTrailingApiVersion(url: string): string | undefined {
|
||||||
|
const match = url.match(TRAILING_VERSION_REGEX)
|
||||||
|
|
||||||
|
if (match) {
|
||||||
|
// Extract version without leading slash and trailing slash
|
||||||
|
return match[0].replace(/^\//, '').replace(/\/$/, '')
|
||||||
|
}
|
||||||
|
|
||||||
|
return undefined
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Removes the trailing API version segment from a URL path.
|
||||||
|
*
|
||||||
|
* This function removes API version patterns (e.g., `/v1`, `/v2beta`) from the end of a URL.
|
||||||
|
* Only versions at the end of the path are removed, not versions in the middle.
|
||||||
|
*
|
||||||
|
* @param {string} url - The URL string to process.
|
||||||
|
* @returns {string} The URL with the trailing API version removed, or the original URL if no trailing version found.
|
||||||
|
*
|
||||||
|
* @example
|
||||||
|
* withoutTrailingApiVersion('https://api.example.com/v1') // 'https://api.example.com'
|
||||||
|
* withoutTrailingApiVersion('https://api.example.com/v2beta/') // 'https://api.example.com'
|
||||||
|
* withoutTrailingApiVersion('https://api.example.com/v1/chat') // 'https://api.example.com/v1/chat' (no change)
|
||||||
|
* withoutTrailingApiVersion('https://api.example.com') // 'https://api.example.com'
|
||||||
|
*/
|
||||||
|
export function withoutTrailingApiVersion(url: string): string {
|
||||||
|
return url.replace(TRAILING_VERSION_REGEX, '')
|
||||||
|
}
|
||||||
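The URL helpers in `packages/shared/api/index.ts` interact in ways that are easy to misread from the JSDoc alone, so here is a condensed, dependency-free sketch of the `#` routing and version-suffix logic. The names `hasVersion`, `splitEndpoint`, and `toSdkBaseUrl` are hypothetical shorthand for `hasAPIVersion`, `routeToEndpoint`, and `getAiSdkBaseUrl`; the endpoint list is shortened and the URLs are placeholders.

```typescript
const VERSION_PATTERN = /\/v\d+(?:alpha|beta)?(?=\/|$)/i
// Shortened copy of the supported endpoint list, for illustration only.
const ENDPOINTS = ['chat/completions', 'responses', 'messages', 'generateContent'] as const

const stripSlash = (url: string): string => url.replace(/\/$/, '')

// True when the path part contains a segment such as /v1 or /v2beta.
function hasVersion(host: string): boolean {
  try {
    return VERSION_PATTERN.test(new URL(host).pathname)
  } catch {
    return VERSION_PATTERN.test(host) // not a full URL: test it as a path
  }
}

// A trailing '#' marks the host as "base URL plus explicit endpoint".
function splitEndpoint(apiHost: string): { baseURL: string; endpoint: string } {
  const trimmed = apiHost.trim()
  if (!trimmed.endsWith('#')) return { baseURL: trimmed, endpoint: '' }

  const host = trimmed.slice(0, -1)
  const endpoint = ENDPOINTS.find((candidate) => host.endsWith(candidate))
  if (!endpoint) return { baseURL: stripSlash(host), endpoint: '' }

  const base = host.slice(0, host.length - endpoint.length)
  // The trailing colon covers Gemini-style ':generateContent' paths.
  return { baseURL: stripSlash(base).replace(/:$/, ''), endpoint }
}

// Base URL with a version suffix, adding one only when none is present.
function toSdkBaseUrl(apiHost: string, apiVersion = 'v1'): string {
  const base = stripSlash(splitEndpoint(apiHost).baseURL)
  return hasVersion(base) ? base : `${base}/${apiVersion}`
}

console.log(splitEndpoint('https://api.example.com/openai/chat/completions#'))
// { baseURL: 'https://api.example.com/openai', endpoint: 'chat/completions' }
console.log(toSdkBaseUrl('https://api.example.com')) // 'https://api.example.com/v1'
console.log(toSdkBaseUrl('https://api.example.com/v1/')) // 'https://api.example.com/v1'
console.log(toSdkBaseUrl('https://api.example.com/chat/completions#')) // 'https://api.example.com/v1'
```

Note that the version check looks anywhere in the path, so a host such as `https://api.example.com/v2/proxy` is treated as already versioned and returned unchanged.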
@@ -43,6 +43,35 @@ export function isSiliconAnthropicCompatibleModel(modelId: string): boolean {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Silicon provider's Anthropic API host URL.
|
* PPIO provider models that support Anthropic API endpoint.
|
||||||
|
* These models can be used with Claude Code via the Anthropic-compatible API.
|
||||||
|
*
|
||||||
|
* @see https://ppio.com/docs/model/llm-anthropic-compatibility
|
||||||
*/
|
*/
|
||||||
export const SILICON_ANTHROPIC_API_HOST = 'https://api.siliconflow.cn'
|
export const PPIO_ANTHROPIC_COMPATIBLE_MODELS: readonly string[] = [
|
||||||
|
'moonshotai/kimi-k2-thinking',
|
||||||
|
'minimax/minimax-m2',
|
||||||
|
'deepseek/deepseek-v3.2-exp',
|
||||||
|
'deepseek/deepseek-v3.1-terminus',
|
||||||
|
'zai-org/glm-4.6',
|
||||||
|
'moonshotai/kimi-k2-0905',
|
||||||
|
'deepseek/deepseek-v3.1',
|
||||||
|
'moonshotai/kimi-k2-instruct',
|
||||||
|
'qwen/qwen3-next-80b-a3b-instruct',
|
||||||
|
'qwen/qwen3-next-80b-a3b-thinking'
|
||||||
|
]
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Creates a Set for efficient lookup of PPIO Anthropic-compatible model IDs.
|
||||||
|
*/
|
||||||
|
const PPIO_ANTHROPIC_COMPATIBLE_MODEL_SET = new Set(PPIO_ANTHROPIC_COMPATIBLE_MODELS)
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Checks if a model ID is compatible with Anthropic API on PPIO provider.
|
||||||
|
*
|
||||||
|
* @param modelId - The model ID to check
|
||||||
|
* @returns true if the model supports Anthropic API endpoint
|
||||||
|
*/
|
||||||
|
export function isPpioAnthropicCompatibleModel(modelId: string): boolean {
|
||||||
|
return PPIO_ANTHROPIC_COMPATIBLE_MODEL_SET.has(modelId)
|
||||||
|
}
|
||||||
|
|||||||
15
packages/shared/middleware/index.ts
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
/**
|
||||||
|
* Shared AI SDK Middlewares
|
||||||
|
*
|
||||||
|
* Environment-agnostic middlewares that can be used in both
|
||||||
|
* renderer process and main process (API server).
|
||||||
|
*/
|
||||||
|
|
||||||
|
export {
|
||||||
|
buildSharedMiddlewares,
|
||||||
|
getReasoningTagName,
|
||||||
|
isGemini3ModelId,
|
||||||
|
openrouterReasoningMiddleware,
|
||||||
|
type SharedMiddlewareConfig,
|
||||||
|
skipGeminiThoughtSignatureMiddleware
|
||||||
|
} from './middlewares'
|
||||||
205
packages/shared/middleware/middlewares.ts
Normal file
@@ -0,0 +1,205 @@
+/**
+ * Shared AI SDK Middlewares
+ *
+ * These middlewares are environment-agnostic and can be used in both
+ * renderer process and main process (API server).
+ */
+import type { LanguageModelV2Middleware, LanguageModelV2StreamPart } from '@ai-sdk/provider'
+import { extractReasoningMiddleware } from 'ai'
+
+/**
+ * Configuration for building shared middlewares
+ */
+export interface SharedMiddlewareConfig {
+  /**
+   * Whether to enable reasoning extraction
+   */
+  enableReasoning?: boolean
+
+  /**
+   * Tag name for reasoning extraction
+   * Defaults based on model ID
+   */
+  reasoningTagName?: string
+
+  /**
+   * Model ID - used to determine default reasoning tag and model detection
+   */
+  modelId?: string
+
+  /**
+   * Provider ID (Cherry Studio provider ID)
+   * Used for provider-specific middlewares like OpenRouter
+   */
+  providerId?: string
+
+  /**
+   * AI SDK Provider ID
+   * Used for Gemini thought signature middleware
+   * e.g., 'google', 'google-vertex'
+   */
+  aiSdkProviderId?: string
+}
+
+/**
+ * Check if model ID represents a Gemini 3 (2.5) model
+ * that requires thought signature handling
+ *
+ * @param modelId - The model ID string (not Model object)
+ */
+export function isGemini3ModelId(modelId?: string): boolean {
+  if (!modelId) return false
+  const lowerModelId = modelId.toLowerCase()
+  return lowerModelId.includes('gemini-3')
+}
+
+/**
+ * Get the default reasoning tag name based on model ID
+ *
+ * Different models use different tags for reasoning content:
+ * - Most models: 'think'
+ * - GPT-OSS models: 'reasoning'
+ * - Gemini models: 'thought'
+ * - Seed models: 'seed:think'
+ */
+export function getReasoningTagName(modelId?: string): string {
+  if (!modelId) return 'think'
+  const lowerModelId = modelId.toLowerCase()
+  if (lowerModelId.includes('gpt-oss')) return 'reasoning'
+  if (lowerModelId.includes('gemini')) return 'thought'
+  if (lowerModelId.includes('seed-oss-36b')) return 'seed:think'
+  return 'think'
+}
+
+/**
+ * Skip Gemini Thought Signature Middleware
+ *
+ * Due to the complexity of multi-model client requests (which can switch
+ * to other models mid-process), this middleware skips all Gemini 3
+ * thinking signatures validation.
+ *
+ * @param aiSdkId - AI SDK Provider ID (e.g., 'google', 'google-vertex')
+ * @returns LanguageModelV2Middleware
+ */
+export function skipGeminiThoughtSignatureMiddleware(aiSdkId: string): LanguageModelV2Middleware {
+  const MAGIC_STRING = 'skip_thought_signature_validator'
+  return {
+    middlewareVersion: 'v2',
+
+    transformParams: async ({ params }) => {
+      const transformedParams = { ...params }
+      // Process messages in prompt
+      if (transformedParams.prompt && Array.isArray(transformedParams.prompt)) {
+        transformedParams.prompt = transformedParams.prompt.map((message) => {
+          if (typeof message.content !== 'string') {
+            for (const part of message.content) {
+              const googleOptions = part?.providerOptions?.[aiSdkId]
+              if (googleOptions?.thoughtSignature) {
+                googleOptions.thoughtSignature = MAGIC_STRING
+              }
+            }
+          }
+          return message
+        })
+      }
+
+      return transformedParams
+    }
+  }
+}
+
+/**
+ * OpenRouter Reasoning Middleware
+ *
+ * Filters out [REDACTED] blocks from OpenRouter reasoning responses.
+ * OpenRouter may include [REDACTED] markers in reasoning content that
+ * should be removed for cleaner output.
+ *
+ * @see https://openrouter.ai/docs/docs/best-practices/reasoning-tokens
+ * @returns LanguageModelV2Middleware
+ */
+export function openrouterReasoningMiddleware(): LanguageModelV2Middleware {
+  const REDACTED_BLOCK = '[REDACTED]'
+  return {
+    middlewareVersion: 'v2',
+    wrapGenerate: async ({ doGenerate }) => {
+      const { content, ...rest } = await doGenerate()
+      const modifiedContent = content.map((part) => {
+        if (part.type === 'reasoning' && part.text.includes(REDACTED_BLOCK)) {
+          return {
+            ...part,
+            text: part.text.replace(REDACTED_BLOCK, '')
+          }
+        }
+        return part
+      })
+      return { content: modifiedContent, ...rest }
+    },
+    wrapStream: async ({ doStream }) => {
+      const { stream, ...rest } = await doStream()
+      return {
+        stream: stream.pipeThrough(
+          new TransformStream<LanguageModelV2StreamPart, LanguageModelV2StreamPart>({
+            transform(
+              chunk: LanguageModelV2StreamPart,
+              controller: TransformStreamDefaultController<LanguageModelV2StreamPart>
+            ) {
+              if (chunk.type === 'reasoning-delta' && chunk.delta.includes(REDACTED_BLOCK)) {
+                controller.enqueue({
+                  ...chunk,
+                  delta: chunk.delta.replace(REDACTED_BLOCK, '')
+                })
+              } else {
+                controller.enqueue(chunk)
+              }
+            }
+          })
+        ),
+        ...rest
+      }
+    }
+  }
+}
+
+/**
+ * Build shared middlewares based on configuration
+ *
+ * This function builds a set of middlewares that are commonly needed
+ * across different environments (renderer, API server).
+ *
+ * @param config - Configuration for middleware building
+ * @returns Array of AI SDK middlewares
+ *
+ * @example
+ * ```typescript
+ * import { buildSharedMiddlewares } from '@shared/middleware'
+ *
+ * const middlewares = buildSharedMiddlewares({
+ *   enableReasoning: true,
+ *   modelId: 'gemini-2.5-pro',
+ *   providerId: 'openrouter',
+ *   aiSdkProviderId: 'google'
+ * })
+ * ```
+ */
+export function buildSharedMiddlewares(config: SharedMiddlewareConfig): LanguageModelV2Middleware[] {
+  const middlewares: LanguageModelV2Middleware[] = []
+
+  // 1. Reasoning extraction middleware
+  if (config.enableReasoning) {
+    const tagName = config.reasoningTagName || getReasoningTagName(config.modelId)
+    middlewares.push(extractReasoningMiddleware({ tagName }))
+  }
+
+  // 2. OpenRouter-specific: filter [REDACTED] blocks
+  if (config.providerId === 'openrouter' && config.enableReasoning) {
+    middlewares.push(openrouterReasoningMiddleware())
+  }
+
+  // 3. Gemini 3 (2.5) specific: skip thought signature validation
+  if (isGemini3ModelId(config.modelId) && config.aiSdkProviderId) {
+    middlewares.push(skipGeminiThoughtSignatureMiddleware(config.aiSdkProviderId))
+  }
+
+  return middlewares
+}
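Note on `buildSharedMiddlewares`: the sketch below traces which of the three branches fire for a given config. It returns labels rather than real middleware objects so it can run without the `ai` package. `planSharedMiddlewares` and `PlanConfig` are hypothetical names, and the model IDs `'gemini-3-pro-preview'` and `'gpt-oss-120b'` are illustrative inputs, not values from the diff. The two helper functions are local copies of the ones added in `middlewares.ts`.

```typescript
// Local copies of the two helpers from the diff, so the sketch runs on its own.
function isGemini3ModelId(modelId?: string): boolean {
  if (!modelId) return false
  return modelId.toLowerCase().includes('gemini-3')
}

function getReasoningTagName(modelId?: string): string {
  if (!modelId) return 'think'
  const lower = modelId.toLowerCase()
  if (lower.includes('gpt-oss')) return 'reasoning'
  if (lower.includes('gemini')) return 'thought'
  if (lower.includes('seed-oss-36b')) return 'seed:think'
  return 'think'
}

// Same fields as SharedMiddlewareConfig in the diff.
interface PlanConfig {
  enableReasoning?: boolean
  reasoningTagName?: string
  modelId?: string
  providerId?: string
  aiSdkProviderId?: string
}

// Hypothetical helper: mirrors the three branches of buildSharedMiddlewares,
// in the same order, but pushes descriptive labels instead of middlewares.
function planSharedMiddlewares(config: PlanConfig): string[] {
  const plan: string[] = []
  if (config.enableReasoning) {
    // An explicit reasoningTagName takes precedence over the model-based default.
    const tagName = config.reasoningTagName || getReasoningTagName(config.modelId)
    plan.push(`extractReasoning(${tagName})`)
  }
  if (config.providerId === 'openrouter' && config.enableReasoning) {
    plan.push('openrouterReasoning')
  }
  if (isGemini3ModelId(config.modelId) && config.aiSdkProviderId) {
    plan.push('skipGeminiThoughtSignature')
  }
  return plan
}

// The config from the @example block in the JSDoc above.
const docExamplePlan = planSharedMiddlewares({
  enableReasoning: true,
  modelId: 'gemini-2.5-pro',
  providerId: 'openrouter',
  aiSdkProviderId: 'google'
})
const gemini3Plan = planSharedMiddlewares({ modelId: 'gemini-3-pro-preview', aiSdkProviderId: 'google' })
const overridePlan = planSharedMiddlewares({
  enableReasoning: true,
  reasoningTagName: 'analysis',
  modelId: 'gpt-oss-120b'
})
console.log(docExamplePlan, gemini3Plan, overridePlan)
```

Under this reading, the `@example` config selects the reasoning-extraction and OpenRouter middlewares only: `'gemini-2.5-pro'` does not contain the substring `'gemini-3'`, so the thought-signature branch is not taken for it.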
@@ -1,13 +1,13 @@
 /**
  * AiHubMix rule set
  */
-import { isOpenAILLMModel } from '@renderer/config/models'
-import type { Provider } from '@renderer/types'
+import { getLowerBaseModelName } from '@shared/utils/naming'
 
+import type { MinimalModel, MinimalProvider } from '../types'
 import { provider2Provider, startsWith } from './helper'
 import type { RuleSet } from './types'
 
-const extraProviderConfig = (provider: Provider) => {
+const extraProviderConfig = <P extends MinimalProvider>(provider: P) => {
   return {
     ...provider,
     extra_headers: {
@@ -17,11 +17,23 @@ const extraProviderConfig = (provider: Provider) => {
   }
 }
 
+function isOpenAILLMModel<M extends MinimalModel>(model: M): boolean {
+  const modelId = getLowerBaseModelName(model.id)
+  const reasonings = ['o1', 'o3', 'o4', 'gpt-oss']
+  if (reasonings.some((r) => modelId.includes(r))) {
+    return true
+  }
+  if (modelId.includes('gpt')) {
+    return true
+  }
+  return false
+}
+
 const AIHUBMIX_RULES: RuleSet = {
   rules: [
     {
       match: startsWith('claude'),
-      provider: (provider: Provider) => {
+      provider: (provider) => {
         return extraProviderConfig({
           ...provider,
           type: 'anthropic'
@@ -34,7 +46,7 @@ const AIHUBMIX_RULES: RuleSet = {
         !model.id.endsWith('-nothink') &&
         !model.id.endsWith('-search') &&
         !model.id.includes('embedding'),
-      provider: (provider: Provider) => {
+      provider: (provider) => {
         return extraProviderConfig({
           ...provider,
           type: 'gemini',
@@ -44,7 +56,7 @@ const AIHUBMIX_RULES: RuleSet = {
     },
     {
       match: isOpenAILLMModel,
-      provider: (provider: Provider) => {
+      provider: (provider) => {
         return extraProviderConfig({
           ...provider,
           type: 'openai-response'
@@ -52,7 +64,8 @@ const AIHUBMIX_RULES: RuleSet = {
       }
     }
   ],
-  fallbackRule: (provider: Provider) => extraProviderConfig(provider)
+  fallbackRule: (provider) => extraProviderConfig(provider)
 }
 
-export const aihubmixProviderCreator = provider2Provider.bind(null, AIHUBMIX_RULES)
+export const aihubmixProviderCreator = <P extends MinimalProvider>(model: MinimalModel, provider: P): P =>
+  provider2Provider<MinimalModel, MinimalProvider, P>(AIHUBMIX_RULES, model, provider)
22
packages/shared/provider/config/azure-anthropic.ts
Normal file
@@ -0,0 +1,22 @@
+import type { MinimalModel, MinimalProvider, ProviderType } from '../types'
+import { provider2Provider, startsWith } from './helper'
+import type { RuleSet } from './types'
+
+// https://platform.claude.com/docs/en/build-with-claude/claude-in-microsoft-foundry
+const AZURE_ANTHROPIC_RULES: RuleSet = {
+  rules: [
+    {
+      match: startsWith('claude'),
+      provider: (provider: MinimalProvider) => ({
+        ...provider,
+        type: 'anthropic' as ProviderType,
+        apiHost: provider.apiHost + 'anthropic/v1',
+        id: 'azure-anthropic'
+      })
+    }
+  ],
+  fallbackRule: (provider: MinimalProvider) => provider
+}
+
+export const azureAnthropicProviderCreator = <P extends MinimalProvider>(model: MinimalModel, provider: P): P =>
+  provider2Provider<MinimalModel, MinimalProvider, P>(AZURE_ANTHROPIC_RULES, model, provider)
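Note on the `apiHost` concatenation in the Azure rule: the diff appends `'anthropic/v1'` directly, which yields a valid URL only when the incoming host ends with `/`. Whether callers guarantee that is not visible in this diff. The sketch below shows both cases and a hypothetical `joinHost` helper that tolerates either form. The host `https://example.services.ai.azure.com` is a placeholder, not taken from the source.

```typescript
// Same expression as in the rule from the diff: plain string concatenation.
function concatHost(apiHost: string): string {
  return apiHost + 'anthropic/v1'
}

// Hypothetical alternative: strip trailing slashes from the host and leading
// slashes from the suffix, then join them with exactly one '/'.
function joinHost(apiHost: string, suffix: string): string {
  return apiHost.replace(/\/+$/, '') + '/' + suffix.replace(/^\/+/, '')
}

// Placeholder host used only for illustration.
const withSlash = concatHost('https://example.services.ai.azure.com/')
const withoutSlash = concatHost('https://example.services.ai.azure.com')
const joinedBare = joinHost('https://example.services.ai.azure.com', 'anthropic/v1')
const joinedSlash = joinHost('https://example.services.ai.azure.com/', 'anthropic/v1')
console.log(withSlash, withoutSlash, joinedBare, joinedSlash)
```

If the host is always normalized to end with a slash before this rule runs, the plain concatenation is sufficient and the helper is unnecessary.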
32
packages/shared/provider/config/helper.ts
Normal file
@@ -0,0 +1,32 @@
+import type { MinimalModel, MinimalProvider } from '../types'
+import type { RuleSet } from './types'
+
+export const startsWith =
+  (prefix: string) =>
+  <M extends MinimalModel>(model: M) =>
+    model.id.toLowerCase().startsWith(prefix.toLowerCase())
+
+export const endpointIs =
+  (type: string) =>
+  <M extends MinimalModel>(model: M) =>
+    model.endpoint_type === type
+
+/**
+ * Resolve the provider that corresponds to a model
+ * @param ruleSet The rule set object
+ * @param model The model object
+ * @param provider The original provider object
+ * @returns The resolved provider object
+ */
+export function provider2Provider<M extends MinimalModel, R extends MinimalProvider, P extends R = R>(
+  ruleSet: RuleSet<M, R>,
+  model: M,
+  provider: P
+): P {
+  for (const rule of ruleSet.rules) {
+    if (rule.match(model)) {
+      return rule.provider(provider) as P
+    }
+  }
+  return ruleSet.fallbackRule(provider) as P
+}
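Note on `provider2Provider`: rules are tried in order, the first matching rule transforms the provider, and `fallbackRule` applies when nothing matches. The sketch below reproduces that control flow with simplified local types. `SketchModel`, `SketchProvider`, `SketchRuleSet`, and `resolve` are hypothetical names, and the field shapes are assumptions, since `MinimalModel` and `MinimalProvider` are defined in `../types`, which is not part of this excerpt.

```typescript
// Assumed minimal shapes, for illustration only.
interface SketchModel {
  id: string
  endpoint_type?: string
}
interface SketchProvider {
  id: string
  type: string
  apiHost: string
}
interface SketchRuleSet {
  rules: Array<{
    match: (model: SketchModel) => boolean
    provider: (provider: SketchProvider) => SketchProvider
  }>
  fallbackRule: (provider: SketchProvider) => SketchProvider
}

// Matchers with the same bodies as startsWith / endpointIs in helper.ts.
const startsWith = (prefix: string) => (model: SketchModel) =>
  model.id.toLowerCase().startsWith(prefix.toLowerCase())
const endpointIs = (type: string) => (model: SketchModel) => model.endpoint_type === type

// Same loop as provider2Provider: first match wins, otherwise fall back.
function resolve(ruleSet: SketchRuleSet, model: SketchModel, provider: SketchProvider): SketchProvider {
  for (const rule of ruleSet.rules) {
    if (rule.match(model)) return rule.provider(provider)
  }
  return ruleSet.fallbackRule(provider)
}

const RULES: SketchRuleSet = {
  rules: [
    { match: startsWith('claude'), provider: (p) => ({ ...p, type: 'anthropic' }) },
    { match: endpointIs('gemini'), provider: (p) => ({ ...p, type: 'gemini' }) }
  ],
  fallbackRule: (p) => p
}

const base: SketchProvider = { id: 'demo', type: 'openai', apiHost: 'https://api.example.com' }

// Both rules could match this model; the earlier 'claude' rule is the one applied.
const claudeType = resolve(RULES, { id: 'Claude-Sonnet-4', endpoint_type: 'gemini' }, base).type
const geminiType = resolve(RULES, { id: 'some-model', endpoint_type: 'gemini' }, base).type
const fallbackType = resolve(RULES, { id: 'some-model' }, base).type
console.log(claudeType, geminiType, fallbackType)
```

Each rule returns a new object via spread, so the provider passed in is left unmodified. Rule order matters whenever two matchers can both accept the same model.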
6
packages/shared/provider/config/index.ts
Normal file
@@ -0,0 +1,6 @@
+export { aihubmixProviderCreator } from './aihubmix'
+export { azureAnthropicProviderCreator } from './azure-anthropic'
+export { endpointIs, provider2Provider, startsWith } from './helper'
+export { newApiResolverCreator } from './newApi'
+export type { RuleSet } from './types'
+export { vertexAnthropicProviderCreator } from './vertex-anthropic'
@@ -1,8 +1,7 @@
 /**
  * NewAPI rule set
  */
-import type { Provider } from '@renderer/types'
-
+import type { MinimalModel, MinimalProvider, ProviderType } from '../types'
 import { endpointIs, provider2Provider } from './helper'
 import type { RuleSet } from './types'
 
@@ -10,42 +9,43 @@ const NEWAPI_RULES: RuleSet = {
   rules: [
     {
       match: endpointIs('anthropic'),
-      provider: (provider: Provider) => {
+      provider: (provider) => {
         return {
           ...provider,
-          type: 'anthropic'
+          type: 'anthropic' as ProviderType
         }
       }
     },
     {
       match: endpointIs('gemini'),
-      provider: (provider: Provider) => {
+      provider: (provider) => {
         return {
           ...provider,
-          type: 'gemini'
+          type: 'gemini' as ProviderType
         }
       }
     },
     {
       match: endpointIs('openai-response'),
-      provider: (provider: Provider) => {
+      provider: (provider) => {
         return {
           ...provider,
-          type: 'openai-response'
+          type: 'openai-response' as ProviderType
         }
       }
     },
     {
       match: (model) => endpointIs('openai')(model) || endpointIs('image-generation')(model),
-      provider: (provider: Provider) => {
+      provider: (provider) => {
         return {
           ...provider,
-          type: 'openai'
+          type: 'openai' as ProviderType
         }
       }
     }
   ],
-  fallbackRule: (provider: Provider) => provider
+  fallbackRule: (provider) => provider
 }
 
-export const newApiResolverCreator = provider2Provider.bind(null, NEWAPI_RULES)
+export const newApiResolverCreator = <P extends MinimalProvider>(model: MinimalModel, provider: P): P =>
+  provider2Provider<MinimalModel, MinimalProvider, P>(NEWAPI_RULES, model, provider)
9
packages/shared/provider/config/types.ts
Normal file
@@ -0,0 +1,9 @@
+import type { MinimalModel, MinimalProvider } from '../types'
+
+export interface RuleSet<M extends MinimalModel = MinimalModel, P extends MinimalProvider = MinimalProvider> {
+  rules: Array<{
+    match: (model: M) => boolean
+    provider: (provider: P) => P
+  }>
+  fallbackRule: (provider: P) => P
+}
19
packages/shared/provider/config/vertex-anthropic.ts
Normal file
@@ -0,0 +1,19 @@
+import type { MinimalModel, MinimalProvider } from '../types'
+import { provider2Provider, startsWith } from './helper'
+import type { RuleSet } from './types'
+
+const VERTEX_ANTHROPIC_RULES: RuleSet = {
+  rules: [
+    {
+      match: startsWith('claude'),
+      provider: (provider: MinimalProvider) => ({
+        ...provider,
+        id: 'google-vertex-anthropic'
+      })
+    }
+  ],
+  fallbackRule: (provider: MinimalProvider) => provider
+}
+
+export const vertexAnthropicProviderCreator = <P extends MinimalProvider>(model: MinimalModel, provider: P): P =>
+  provider2Provider<MinimalModel, MinimalProvider, P>(VERTEX_ANTHROPIC_RULES, model, provider)
26
packages/shared/provider/constant.ts
Normal file
@@ -0,0 +1,26 @@
+import { getLowerBaseModelName } from '@shared/utils/naming'
+
+import type { MinimalModel } from './types'
+
+export const COPILOT_EDITOR_VERSION = 'vscode/1.104.1'
+export const COPILOT_PLUGIN_VERSION = 'copilot-chat/0.26.7'
+export const COPILOT_INTEGRATION_ID = 'vscode-chat'
+export const COPILOT_USER_AGENT = 'GitHubCopilotChat/0.26.7'
+
+export const COPILOT_DEFAULT_HEADERS = {
+  'Copilot-Integration-Id': COPILOT_INTEGRATION_ID,
+  'User-Agent': COPILOT_USER_AGENT,
+  'Editor-Version': COPILOT_EDITOR_VERSION,
+  'Editor-Plugin-Version': COPILOT_PLUGIN_VERSION,
+  'editor-version': COPILOT_EDITOR_VERSION,
+  'editor-plugin-version': COPILOT_PLUGIN_VERSION,
+  'copilot-vision-request': 'true'
+} as const
+
+// Models that require the OpenAI Responses endpoint when routed through GitHub Copilot (#10560)
+const COPILOT_RESPONSES_MODEL_IDS = ['gpt-5-codex', 'gpt-5.1-codex', 'gpt-5.1-codex-mini']
+
+export function isCopilotResponsesModel<M extends MinimalModel>(model: M): boolean {
+  const normalizedId = getLowerBaseModelName(model.id)
+  return COPILOT_RESPONSES_MODEL_IDS.some((target) => normalizedId === target)
+}
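Note on `isCopilotResponsesModel`: the check is an exact comparison against three IDs after normalization. The sketch below illustrates the effect. `normalizeModelId` is a hypothetical stand-in for `getLowerBaseModelName`, whose implementation is not part of this diff; here it is assumed to lowercase the ID and keep only the segment after the last `/`. The ID `'gpt-5-codex-preview'` is made up to show a near miss.

```typescript
// Assumption: stand-in for getLowerBaseModelName from '@shared/utils/naming'.
// The real function may normalize differently.
function normalizeModelId(id: string): string {
  const base = id.split('/').pop() ?? id
  return base.toLowerCase()
}

// The three IDs listed in the diff.
const RESPONSES_MODEL_IDS = ['gpt-5-codex', 'gpt-5.1-codex', 'gpt-5.1-codex-mini']

function needsResponsesEndpoint(modelId: string): boolean {
  const normalized = normalizeModelId(modelId)
  // Exact equality, as in the diff: an ID that merely contains a listed ID does not match.
  return RESPONSES_MODEL_IDS.some((target) => normalized === target)
}

const listed = needsResponsesEndpoint('GPT-5-Codex')
const prefixed = needsResponsesEndpoint('openai/gpt-5.1-codex-mini')
const nearMiss = needsResponsesEndpoint('gpt-5-codex-preview')
console.log(listed, prefixed, nearMiss)
```

With exact matching, a new Codex variant is not routed to the Responses endpoint until its ID is added to the list. That is the conservative choice compared with a substring test.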
100
packages/shared/provider/detection.ts
Normal file
@@ -0,0 +1,100 @@
+/**
+ * Provider Type Detection Utilities
+ *
+ * Functions to detect provider types based on provider configuration.
+ * These are pure functions that depend only on provider fields (type, id, and apiVersion).
+ *
+ * NOTE: These functions should match the logic in @renderer/utils/provider.ts
+ */
+
+import type { MinimalProvider } from './types'
+
+/**
+ * Check if provider is Anthropic type
+ */
+export function isAnthropicProvider<P extends MinimalProvider>(provider: P): boolean {
+  return provider.type === 'anthropic'
+}
+
+/**
+ * Check if provider is OpenAI Response type (openai-response)
+ * NOTE: This matches isOpenAIProvider in renderer/utils/provider.ts
+ */
+export function isOpenAIProvider<P extends MinimalProvider>(provider: P): boolean {
+  return provider.type === 'openai-response'
+}
+
+/**
+ * Check if provider is Gemini type
+ */
+export function isGeminiProvider<P extends MinimalProvider>(provider: P): boolean {
+  return provider.type === 'gemini'
+}
+
+/**
+ * Check if provider is Azure OpenAI type
+ */
+export function isAzureOpenAIProvider<P extends MinimalProvider>(provider: P): boolean {
+  return provider.type === 'azure-openai'
+}
+
+/**
+ * Check if provider is Vertex AI type
+ */
+export function isVertexProvider<P extends MinimalProvider>(provider: P): boolean {
+  return provider.type === 'vertexai'
+}
+
+/**
+ * Check if provider is AWS Bedrock type
+ */
+export function isAwsBedrockProvider<P extends MinimalProvider>(provider: P): boolean {
+  return provider.type === 'aws-bedrock'
+}
+
+/**
+ * Check if provider is AI Gateway type
+ */
+export function isAIGatewayProvider<P extends MinimalProvider>(provider: P): boolean {
+  return provider.type === 'ai-gateway'
+}
+
+/**
+ * Check if Azure OpenAI provider uses responses endpoint
+ * Matches isAzureResponsesEndpoint in renderer/utils/provider.ts
+ */
+export function isAzureResponsesEndpoint<P extends MinimalProvider>(provider: P): boolean {
+  return provider.apiVersion === 'preview' || provider.apiVersion === 'v1'
+}
+
+/**
+ * Check if provider is Cherry AI type
+ * Matches isCherryAIProvider in renderer/utils/provider.ts
+ */
+export function isCherryAIProvider<P extends MinimalProvider>(provider: P): boolean {
+  return provider.id === 'cherryai'
+}
+
+/**
+ * Check if provider is Perplexity type
+ * Matches isPerplexityProvider in renderer/utils/provider.ts
+ */
+export function isPerplexityProvider<P extends MinimalProvider>(provider: P): boolean {
+  return provider.id === 'perplexity'
+}
+
+/**
+ * Check if provider is new-api type (supports multiple backends)
+ * Matches isNewApiProvider in renderer/utils/provider.ts
+ */
+export function isNewApiProvider<P extends MinimalProvider>(provider: P): boolean {
+  return ['new-api', 'cherryin'].includes(provider.id) || provider.type === ('new-api' as string)
+}
+
+/**
+ * Check if provider is OpenAI compatible
+ * Matches isOpenAICompatibleProvider in renderer/utils/provider.ts
+ */
+export function isOpenAICompatibleProvider<P extends MinimalProvider>(provider: P): boolean {
+  return ['openai', 'new-api', 'mistral'].includes(provider.type)
+}
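Note on the detection predicates: some names are easy to misread. `isOpenAIProvider` checks for the `'openai-response'` type, not `'openai'`, and `isNewApiProvider` matches on either `id` or `type`. The sketch below uses local copies of three predicates with a simplified provider shape. `SketchProvider` is a hypothetical type, and the sample providers, including giving `cherryin` the type `'openai'`, are illustrative assumptions.

```typescript
// Assumed minimal provider shape; the real MinimalProvider is defined in './types'.
interface SketchProvider {
  id: string
  type: string
}

// Local copies of three predicates from detection.ts.
const isOpenAIProvider = (p: SketchProvider): boolean => p.type === 'openai-response'
const isOpenAICompatibleProvider = (p: SketchProvider): boolean =>
  ['openai', 'new-api', 'mistral'].includes(p.type)
const isNewApiProvider = (p: SketchProvider): boolean =>
  ['new-api', 'cherryin'].includes(p.id) || p.type === 'new-api'

// Illustrative providers, not taken from the diff.
const plainOpenAI: SketchProvider = { id: 'my-openai', type: 'openai' }
const cherryin: SketchProvider = { id: 'cherryin', type: 'openai' }

const results = {
  plainIsOpenAI: isOpenAIProvider(plainOpenAI),
  plainIsCompatible: isOpenAICompatibleProvider(plainOpenAI),
  plainIsNewApi: isNewApiProvider(plainOpenAI),
  cherryinIsNewApi: isNewApiProvider(cherryin)
}
console.log(results)
```

A provider of type `'openai'` is therefore "OpenAI compatible" but not an "OpenAI provider" in this module's terms, which matters when the predicates are used to choose between the Chat Completions and Responses code paths.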
136
packages/shared/provider/format.ts
Normal file
@@ -0,0 +1,136 @@
+/**
+ * Provider API Host Formatting
+ *
+ * Utilities for formatting provider API hosts to work with AI SDK.
+ * These handle the differences between how Cherry Studio stores API hosts
+ * and how AI SDK expects them.
+ */
+
+import {
+  formatApiHost,
+  formatAzureOpenAIApiHost,
+  formatVertexApiHost,
+  routeToEndpoint,
+  withoutTrailingSlash
+} from '../api'
+import {
+  isAnthropicProvider,
+  isAzureOpenAIProvider,
+  isCherryAIProvider,
+  isGeminiProvider,
+  isPerplexityProvider,
+  isVertexProvider
+} from './detection'
+import type { MinimalProvider } from './types'
+import { SystemProviderIds } from './types'
+
+/**
+ * Interface for environment-specific implementations
+ * Renderer and Main process can provide their own implementations
+ */
+export interface ProviderFormatContext {
+  vertex: {
+    project: string
+    location: string
+  }
+}
+
+/**
+ * Default Azure OpenAI API host formatter
+ */
+export function defaultFormatAzureOpenAIApiHost(host: string): string {
+  const normalizedHost = withoutTrailingSlash(host)
+    ?.replace(/\/v1$/, '')
+    .replace(/\/openai$/, '')
+  // AI SDK will add /v1
+  return formatApiHost(normalizedHost + '/openai', false)
+}
+
+/**
+ * Format provider API host for AI SDK
+ *
+ * This function normalizes the apiHost to work with AI SDK.
+ * Different providers have different requirements:
+ * - Most providers: add /v1 suffix
+ * - Gemini: add /v1beta suffix
+ * - Some providers: no suffix needed
+ *
+ * @param provider - The provider to format
+ * @param context - Context with environment-specific values (Vertex project and location)
+ * @returns Provider with formatted apiHost (and anthropicApiHost if applicable)
+ */
+export function formatProviderApiHost<T extends MinimalProvider>(provider: T, context: ProviderFormatContext): T {
+  const formatted = { ...provider }
+
+  // Format anthropicApiHost if present
+  if (formatted.anthropicApiHost) {
+    formatted.anthropicApiHost = formatApiHost(formatted.anthropicApiHost)
+  }
+
+  // Format based on provider type
+  if (isAnthropicProvider(provider)) {
+    const baseHost = formatted.anthropicApiHost || formatted.apiHost
+    // AI SDK needs /v1 in baseURL
+    formatted.apiHost = formatApiHost(baseHost)
+    if (!formatted.anthropicApiHost) {
+      formatted.anthropicApiHost = formatted.apiHost
+    }
+  } else if (formatted.id === SystemProviderIds.copilot || formatted.id === SystemProviderIds.github) {
+    formatted.apiHost = formatApiHost(formatted.apiHost, false)
+  } else if (isGeminiProvider(formatted)) {
+    formatted.apiHost = formatApiHost(formatted.apiHost, true, 'v1beta')
+  } else if (isAzureOpenAIProvider(formatted)) {
+    formatted.apiHost = formatAzureOpenAIApiHost(formatted.apiHost)
+  } else if (isVertexProvider(formatted)) {
+    formatted.apiHost = formatVertexApiHost(formatted, context.vertex.project, context.vertex.location)
+  } else if (isCherryAIProvider(formatted)) {
+    formatted.apiHost = formatApiHost(formatted.apiHost, false)
+  } else if (isPerplexityProvider(formatted)) {
+    formatted.apiHost = formatApiHost(formatted.apiHost, false)
+  } else {
+    formatted.apiHost = formatApiHost(formatted.apiHost)
+  }
+
+  return formatted
+}
+
+/**
+ * Get the base URL for AI SDK from a formatted provider
+ *
+ * This extracts the baseURL that AI SDK expects, handling
+ * the '#' endpoint routing format if present.
+ *
+ * @param formattedApiHost - The formatted apiHost (after formatProviderApiHost)
+ * @returns The baseURL for AI SDK
+ */
+export function getBaseUrlForAiSdk(formattedApiHost: string): string {
+  const { baseURL } = routeToEndpoint(formattedApiHost)
+  return baseURL
+}
+
+/**
+ * Get rotated API key from comma-separated keys
+ *
+ * This is the interface for API key rotation. The actual implementation
+ * depends on the environment (renderer uses window.keyv, main uses its own storage).
+ */
+export interface ApiKeyRotator {
+  /**
+   * Get the next API key in rotation
+   * @param providerId - The provider ID for tracking rotation
+   * @param keys - Comma-separated API keys
+   * @returns The next API key to use
+   */
+  getRotatedKey(providerId: string, keys: string): string
+}
+
+/**
+ * Simple API key rotator that always returns the first key
+ * Use this when rotation is not needed
+ */
+export const simpleKeyRotator: ApiKeyRotator = {
+  getRotatedKey(_providerId: string, keys: string): string {
+    const keyList = keys.split(',').map((k) => k.trim())
+    return keyList[0] || keys
+  }
+}
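Note on `ApiKeyRotator`: `simpleKeyRotator` always returns the first key, and the interface leaves actual rotation to the environment. As a rough illustration of what an environment-specific implementation might look like, the sketch below adds a round-robin rotator that keeps its cursor in process memory. `createRoundRobinRotator` is a hypothetical name and is not how the renderer or main process implements rotation; the diff only says those use `window.keyv` and their own storage respectively.

```typescript
// Interface as declared in format.ts.
interface ApiKeyRotator {
  getRotatedKey(providerId: string, keys: string): string
}

// Copy of simpleKeyRotator from the diff: always the first key.
const simpleKeyRotator: ApiKeyRotator = {
  getRotatedKey(_providerId: string, keys: string): string {
    const keyList = keys.split(',').map((k) => k.trim())
    return keyList[0] || keys
  }
}

// Hypothetical in-memory round-robin rotator, one cursor per provider ID.
function createRoundRobinRotator(): ApiKeyRotator {
  const cursors = new Map<string, number>()
  return {
    getRotatedKey(providerId: string, keys: string): string {
      const keyList = keys
        .split(',')
        .map((k) => k.trim())
        .filter((k) => k.length > 0)
      // Nothing usable to rotate over: hand back the input unchanged.
      if (keyList.length === 0) return keys
      const index = (cursors.get(providerId) ?? 0) % keyList.length
      cursors.set(providerId, index + 1)
      return keyList[index] ?? keys
    }
  }
}

const rotator = createRoundRobinRotator()
const sequence = [1, 2, 3, 4].map(() => rotator.getRotatedKey('demo', 'key-a, key-b, key-c'))
const firstOnly = simpleKeyRotator.getRotatedKey('demo', 'key-a, key-b, key-c')
console.log(sequence, firstOnly)
```

An in-memory cursor resets on restart and is not shared across processes, which is presumably why the interface is left open for each environment to back with its own storage.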
48
packages/shared/provider/index.ts
Normal file
@@ -0,0 +1,48 @@
+/**
+ * Shared Provider Utilities
+ *
+ * This module exports utilities for working with AI providers
+ * that can be shared between main process and renderer process.
+ */
+
+// Type definitions
+export type { MinimalProvider, ProviderType, SystemProviderId } from './types'
+export { SystemProviderIds } from './types'
+
+// Provider type detection
+export {
+  isAIGatewayProvider,
+  isAnthropicProvider,
+  isAwsBedrockProvider,
+  isAzureOpenAIProvider,
+  isAzureResponsesEndpoint,
+  isCherryAIProvider,
+  isGeminiProvider,
+  isNewApiProvider,
+  isOpenAICompatibleProvider,
+  isOpenAIProvider,
+  isPerplexityProvider,
+  isVertexProvider
+} from './detection'
+
+// API host formatting
+export type { ApiKeyRotator, ProviderFormatContext } from './format'
+export {
+  defaultFormatAzureOpenAIApiHost,
+  formatProviderApiHost,
+  getBaseUrlForAiSdk,
+  simpleKeyRotator
+} from './format'
+
+// Provider ID mapping
+export { getAiSdkProviderId, STATIC_PROVIDER_MAPPING, tryResolveProviderId } from './mapping'
+
+// AI SDK configuration
+export type { AiSdkConfig, AiSdkConfigContext } from './sdk-config'
+export { providerToAiSdkConfig } from './sdk-config'
+
+// Provider resolution
+export { resolveActualProvider } from './resolve'
+
+// Provider initialization
+export { initializeSharedProviders, SHARED_PROVIDER_CONFIGS } from './initialization'
107
packages/shared/provider/initialization.ts
Normal file
@@ -0,0 +1,107 @@
import { type ProviderConfig, registerMultipleProviderConfigs } from '@cherrystudio/ai-core/provider'

type ProviderInitializationLogger = {
  warn?: (message: string) => void
  error?: (message: string, error: Error) => void
}

export const SHARED_PROVIDER_CONFIGS: ProviderConfig[] = [
  {
    id: 'openrouter',
    name: 'OpenRouter',
    import: () => import('@openrouter/ai-sdk-provider'),
    creatorFunctionName: 'createOpenRouter',
    supportsImageGeneration: true,
    aliases: ['openrouter']
  },
  {
    id: 'google-vertex',
    name: 'Google Vertex AI',
    import: () => import('@ai-sdk/google-vertex/edge'),
    creatorFunctionName: 'createVertex',
    supportsImageGeneration: true,
    aliases: ['vertexai']
  },
  {
    id: 'google-vertex-anthropic',
    name: 'Google Vertex AI Anthropic',
    import: () => import('@ai-sdk/google-vertex/anthropic/edge'),
    creatorFunctionName: 'createVertexAnthropic',
    supportsImageGeneration: true,
    aliases: ['vertexai-anthropic']
  },
  {
    id: 'azure-anthropic',
    name: 'Azure AI Anthropic',
    import: () => import('@ai-sdk/anthropic'),
    creatorFunctionName: 'createAnthropic',
    supportsImageGeneration: false,
    aliases: ['azure-anthropic']
  },
  {
    id: 'github-copilot-openai-compatible',
    name: 'GitHub Copilot OpenAI Compatible',
    import: () => import('@opeoginni/github-copilot-openai-compatible'),
    creatorFunctionName: 'createGitHubCopilotOpenAICompatible',
    supportsImageGeneration: false,
    aliases: ['copilot', 'github-copilot']
  },
  {
    id: 'bedrock',
    name: 'Amazon Bedrock',
    import: () => import('@ai-sdk/amazon-bedrock'),
    creatorFunctionName: 'createAmazonBedrock',
    supportsImageGeneration: true,
    aliases: ['aws-bedrock']
  },
  {
    id: 'perplexity',
    name: 'Perplexity',
    import: () => import('@ai-sdk/perplexity'),
    creatorFunctionName: 'createPerplexity',
    supportsImageGeneration: false,
    aliases: ['perplexity']
  },
  {
    id: 'mistral',
    name: 'Mistral',
    import: () => import('@ai-sdk/mistral'),
    creatorFunctionName: 'createMistral',
    supportsImageGeneration: false,
    aliases: ['mistral']
  },
  {
    id: 'huggingface',
    name: 'HuggingFace',
    import: () => import('@ai-sdk/huggingface'),
    creatorFunctionName: 'createHuggingFace',
    supportsImageGeneration: true,
    aliases: ['hf', 'hugging-face']
  },
  {
    id: 'ai-gateway',
    name: 'AI Gateway',
    import: () => import('@ai-sdk/gateway'),
    creatorFunctionName: 'createGateway',
    supportsImageGeneration: true,
    aliases: ['gateway']
  },
  {
    id: 'cerebras',
    name: 'Cerebras',
    import: () => import('@ai-sdk/cerebras'),
    creatorFunctionName: 'createCerebras',
    supportsImageGeneration: false
  }
] as const

export function initializeSharedProviders(logger?: ProviderInitializationLogger): void {
  try {
    const successCount = registerMultipleProviderConfigs(SHARED_PROVIDER_CONFIGS)
    if (successCount < SHARED_PROVIDER_CONFIGS.length) {
      logger?.warn?.('Some providers failed to register. Check previous error logs.')
    }
  } catch (error) {
    logger?.error?.('Failed to initialize shared providers', error as Error)
  }
}
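Each entry above pairs a canonical `id` with optional `aliases`. The registry in `@cherrystudio/ai-core/provider` is not shown in this diff, so the following is only a minimal sketch of how an alias lookup over such entries could work; `ProviderEntry`, `buildAliasIndex`, and `resolveId` are hypothetical names, not part of the package.

```typescript
// Hypothetical, trimmed-down entry shape (the real ProviderConfig has more fields).
type ProviderEntry = { id: string; aliases?: string[] }

const entries: ProviderEntry[] = [
  { id: 'huggingface', aliases: ['hf', 'hugging-face'] },
  { id: 'bedrock', aliases: ['aws-bedrock'] },
  { id: 'cerebras' } // no aliases, as in the config above
]

// Build a lookup from every id and alias to the canonical id.
function buildAliasIndex(list: ProviderEntry[]): Map<string, string> {
  const index = new Map<string, string>()
  for (const entry of list) {
    index.set(entry.id, entry.id)
    for (const alias of entry.aliases ?? []) {
      index.set(alias, entry.id)
    }
  }
  return index
}

const aliasIndex = buildAliasIndex(entries)

// Returns the canonical id, or null when the identifier is unknown.
function resolveId(identifier: string): string | null {
  return aliasIndex.get(identifier) ?? null
}
```

With this shape, a Cherry Studio provider ID such as `aws-bedrock` can be passed straight to the lookup and comes back as the AI SDK ID `bedrock`.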
95
packages/shared/provider/mapping.ts
Normal file
@@ -0,0 +1,95 @@
/**
 * Provider ID Mapping
 *
 * Maps Cherry Studio provider IDs/types to AI SDK provider IDs.
 * This logic should match @renderer/aiCore/provider/factory.ts
 */

import { hasProviderConfigByAlias, type ProviderId, resolveProviderConfigId } from '@cherrystudio/ai-core/provider'

import { isAzureOpenAIProvider, isAzureResponsesEndpoint } from './detection'
import type { MinimalProvider } from './types'

/**
 * Static mapping from Cherry Studio provider ID/type to AI SDK provider ID
 * Matches STATIC_PROVIDER_MAPPING in @renderer/aiCore/provider/factory.ts
 */
export const STATIC_PROVIDER_MAPPING: Record<string, ProviderId> = {
  gemini: 'google', // Google Gemini -> google
  'azure-openai': 'azure', // Azure OpenAI -> azure
  'openai-response': 'openai', // OpenAI Responses -> openai
  grok: 'xai', // Grok -> xai
  copilot: 'github-copilot-openai-compatible'
}

/**
 * Try to resolve a provider identifier to an AI SDK provider ID
 * Matches tryResolveProviderId in @renderer/aiCore/provider/factory.ts
 *
 * @param identifier - The provider ID or type to resolve
 * @returns The resolved AI SDK provider ID, or null if not found
 */
export function tryResolveProviderId(identifier: string): ProviderId | null {
  // 1. Check the static mapping
  const staticMapping = STATIC_PROVIDER_MAPPING[identifier]
  if (staticMapping) {
    return staticMapping
  }

  // 2. Check whether AiCore supports it (including alias support)
  if (hasProviderConfigByAlias(identifier)) {
    // Resolve to the real provider ID
    return resolveProviderConfigId(identifier) as ProviderId
  }

  return null
}

/**
 * Get the AI SDK Provider ID for a Cherry Studio provider
 * Matches getAiSdkProviderId in @renderer/aiCore/provider/factory.ts
 *
 * Logic:
 * 1. Handle Azure OpenAI specially (check responses endpoint)
 * 2. Try to resolve from provider.id
 * 3. Try to resolve from provider.type (but not for generic 'openai' type)
 * 4. Check for OpenAI API host pattern
 * 5. Fallback to provider's own ID
 *
 * @param provider - The Cherry Studio provider
 * @returns The AI SDK provider ID to use
 */
export function getAiSdkProviderId(provider: MinimalProvider): ProviderId {
  // 1. Handle Azure OpenAI specially - check this FIRST before other resolution
  if (isAzureOpenAIProvider(provider)) {
    if (isAzureResponsesEndpoint(provider)) {
      return 'azure-responses'
    }
    return 'azure'
  }

  // 2. Try to resolve from provider.id
  const resolvedFromId = tryResolveProviderId(provider.id)
  if (resolvedFromId) {
    return resolvedFromId
  }

  // 3. Try to resolve from provider.type
  // Skip the generic 'openai' type; otherwise every custom provider of type openai
  // would resolve to the AI SDK's OpenAI provider
  if (provider.type !== 'openai') {
    const resolvedFromType = tryResolveProviderId(provider.type)
    if (resolvedFromType) {
      return resolvedFromType
    }
  }

  // 4. Check for OpenAI API host pattern
  if (provider.apiHost.includes('api.openai.com')) {
    return 'openai-chat'
  }

  // 5. Final fallback: use the provider's own ID
  return provider.id
}
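The five-step resolution order in `getAiSdkProviderId` can be exercised in isolation. The sketch below is a simplified stand-in, not the shared module itself: `SketchProvider`, `KNOWN_IDS`, `tryResolve`, and `resolveSdkId` are hypothetical, the Azure check is assumed to be a plain `type === 'azure-openai'` test, and the registry lookup is replaced by a fixed set.

```typescript
// Hypothetical, simplified provider shape for this sketch only.
type SketchProvider = { id: string; type: string; apiHost: string }

const STATIC_MAPPING: Record<string, string> = {
  gemini: 'google',
  'azure-openai': 'azure',
  'openai-response': 'openai',
  grok: 'xai',
  copilot: 'github-copilot-openai-compatible'
}

// Stand-in for the registry lookup (hasProviderConfigByAlias / resolveProviderConfigId).
const KNOWN_IDS = new Set(['openai', 'anthropic', 'mistral'])

function tryResolve(identifier: string): string | null {
  const mapped = STATIC_MAPPING[identifier]
  if (mapped) return mapped
  return KNOWN_IDS.has(identifier) ? identifier : null
}

function resolveSdkId(
  provider: SketchProvider,
  isAzureResponses: (p: SketchProvider) => boolean = () => false
): string {
  // 1. Azure first (assumption: detection is a simple type check here)
  if (provider.type === 'azure-openai') {
    return isAzureResponses(provider) ? 'azure-responses' : 'azure'
  }
  // 2. Resolve from the provider ID
  const fromId = tryResolve(provider.id)
  if (fromId) return fromId
  // 3. Resolve from the provider type, except the generic 'openai' type
  if (provider.type !== 'openai') {
    const fromType = tryResolve(provider.type)
    if (fromType) return fromType
  }
  // 4. OpenAI host pattern
  if (provider.apiHost.includes('api.openai.com')) return 'openai-chat'
  // 5. Fall back to the provider's own ID
  return provider.id
}

const viaStatic = resolveSdkId({ id: 'gemini', type: 'gemini', apiHost: '' })
const viaType = resolveSdkId({ id: 'custom', type: 'anthropic', apiHost: 'https://example.com' })
const viaHost = resolveSdkId({ id: 'my-proxy', type: 'openai', apiHost: 'https://api.openai.com/v1' })
const viaFallback = resolveSdkId({ id: 'my-proxy', type: 'openai', apiHost: 'https://example.com' })
```

The last two calls differ only in `apiHost`, which shows why step 3 skips the generic `openai` type: a custom OpenAI-compatible provider keeps its own ID unless its host is the official OpenAI endpoint.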
43
packages/shared/provider/resolve.ts
Normal file
@@ -0,0 +1,43 @@
import { aihubmixProviderCreator, newApiResolverCreator, vertexAnthropicProviderCreator } from './config'
import { azureAnthropicProviderCreator } from './config/azure-anthropic'
import { isAzureOpenAIProvider, isNewApiProvider } from './detection'
import type { MinimalModel, MinimalProvider } from './types'

export interface ResolveActualProviderOptions<P extends MinimalProvider> {
  isSystemProvider?: (provider: P) => boolean
}

const defaultIsSystemProvider = <P extends MinimalProvider>(provider: P): boolean => {
  if ('isSystem' in provider) {
    return Boolean((provider as unknown as { isSystem?: boolean }).isSystem)
  }
  return false
}

export function resolveActualProvider<M extends MinimalModel, P extends MinimalProvider>(
  provider: P,
  model: M,
  options: ResolveActualProviderOptions<P> = {}
): P {
  let resolvedProvider = provider

  if (isNewApiProvider(resolvedProvider)) {
    resolvedProvider = newApiResolverCreator(model, resolvedProvider)
  }

  const isSystemProvider = options.isSystemProvider?.(resolvedProvider) ?? defaultIsSystemProvider(resolvedProvider)

  if (isSystemProvider && resolvedProvider.id === 'aihubmix') {
    resolvedProvider = aihubmixProviderCreator(model, resolvedProvider)
  }

  if (isSystemProvider && resolvedProvider.id === 'vertexai') {
    resolvedProvider = vertexAnthropicProviderCreator(model, resolvedProvider)
  }

  if (isAzureOpenAIProvider(resolvedProvider)) {
    resolvedProvider = azureAnthropicProviderCreator(model, resolvedProvider)
  }

  return resolvedProvider
}
259
packages/shared/provider/sdk-config.ts
Normal file
@@ -0,0 +1,259 @@
/**
 * AI SDK Configuration
 *
 * Shared utilities for converting Cherry Studio Provider to AI SDK configuration.
 * Environment-specific logic (renderer/main) is injected via context interfaces.
 */

import { formatPrivateKey, hasProviderConfig, ProviderConfigFactory } from '@cherrystudio/ai-core/provider'

import { routeToEndpoint } from '../api'
import { getAiSdkProviderId } from './mapping'
import type { MinimalProvider } from './types'
import { SystemProviderIds } from './types'

/**
 * AI SDK configuration result
 */
export interface AiSdkConfig {
  providerId: string
  options: Record<string, unknown>
}

/**
 * Context for environment-specific implementations
 */
export interface AiSdkConfigContext {
  /**
   * Get the rotated API key (for multi-key support)
   * Default: returns first key
   */
  getRotatedApiKey?: (provider: MinimalProvider) => string

  /**
   * Check if a model uses chat completion only (for OpenAI response mode)
   * Default: returns false
   */
  isOpenAIChatCompletionOnlyModel?: (modelId: string) => boolean

  /**
   * Get Copilot default headers (constants)
   * Default: returns empty object
   */
  getCopilotDefaultHeaders?: () => Record<string, string>

  /**
   * Get Copilot stored headers from state
   * Default: returns empty object
   */
  getCopilotStoredHeaders?: () => Record<string, string>

  /**
   * Get AWS Bedrock configuration
   * Default: returns undefined (not configured)
   */
  getAwsBedrockConfig?: () =>
    | {
        authType: 'apiKey' | 'iam'
        region: string
        apiKey?: string
        accessKeyId?: string
        secretAccessKey?: string
      }
    | undefined

  /**
   * Get Vertex AI configuration
   * Default: returns undefined (not configured)
   */
  getVertexConfig?: (provider: MinimalProvider) =>
    | {
        project: string
        location: string
        googleCredentials: {
          privateKey: string
          clientEmail: string
        }
      }
    | undefined

  /**
   * Get endpoint type for cherryin provider
   */
  getEndpointType?: (modelId: string) => string | undefined

  /**
   * Custom fetch implementation
   * Main process: use Electron net.fetch
   * Renderer process: use browser fetch (default)
   */
  fetch?: typeof globalThis.fetch

  /**
   * Get CherryAI signed fetch wrapper
   * Returns a fetch function that adds signature headers to requests
   */
  getCherryAISignedFetch?: () => typeof globalThis.fetch
}

/**
 * Default simple key rotator - returns first key
 */
function defaultGetRotatedApiKey(provider: MinimalProvider): string {
  const keys = provider.apiKey.split(',').map((k) => k.trim())
  return keys[0] || provider.apiKey
}

/**
 * Convert Cherry Studio Provider to AI SDK configuration
 *
 * @param provider - The formatted provider (after formatProviderApiHost)
 * @param modelId - The model ID to use
 * @param context - Environment-specific implementations
 * @returns AI SDK configuration
 */
export function providerToAiSdkConfig(
  provider: MinimalProvider,
  modelId: string,
  context: AiSdkConfigContext = {}
): AiSdkConfig {
  const getRotatedApiKey = context.getRotatedApiKey || defaultGetRotatedApiKey
  const isOpenAIChatCompletionOnlyModel = context.isOpenAIChatCompletionOnlyModel || (() => false)

  const aiSdkProviderId = getAiSdkProviderId(provider)

  // Build base config
  const { baseURL, endpoint } = routeToEndpoint(provider.apiHost)
  const baseConfig = {
    baseURL,
    apiKey: getRotatedApiKey(provider)
  }

  // Handle Copilot specially
  if (provider.id === SystemProviderIds.copilot) {
    const defaultHeaders = context.getCopilotDefaultHeaders?.() ?? {}
    const storedHeaders = context.getCopilotStoredHeaders?.() ?? {}
    const copilotExtraOptions: Record<string, unknown> = {
      headers: {
        ...defaultHeaders,
        ...storedHeaders,
        ...provider.extra_headers
      },
      name: provider.id,
      includeUsage: true
    }
    if (context.fetch) {
      copilotExtraOptions.fetch = context.fetch
    }
    const options = ProviderConfigFactory.fromProvider(
      'github-copilot-openai-compatible',
      baseConfig,
      copilotExtraOptions
    )

    return {
      providerId: 'github-copilot-openai-compatible',
      options
    }
  }

  // Build extra options
  const extraOptions: Record<string, unknown> = {}
  if (endpoint) {
    extraOptions.endpoint = endpoint
  }

  // Handle OpenAI mode
  if (provider.type === 'openai-response' && !isOpenAIChatCompletionOnlyModel(modelId)) {
    extraOptions.mode = 'responses'
  } else if (aiSdkProviderId === 'openai' || (aiSdkProviderId === 'cherryin' && provider.type === 'openai')) {
    extraOptions.mode = 'chat'
  }

  // Add extra headers
  if (provider.extra_headers) {
    extraOptions.headers = provider.extra_headers
    if (aiSdkProviderId === 'openai') {
      extraOptions.headers = {
        ...(extraOptions.headers as Record<string, string>),
        'HTTP-Referer': 'https://cherry-ai.com',
        'X-Title': 'Cherry Studio',
        'X-Api-Key': baseConfig.apiKey
      }
    }
  }

  // Handle Azure modes
  if (aiSdkProviderId === 'azure-responses') {
    extraOptions.mode = 'responses'
  } else if (aiSdkProviderId === 'azure') {
    extraOptions.mode = 'chat'
  }

  // Handle AWS Bedrock
  if (aiSdkProviderId === 'bedrock') {
    const bedrockConfig = context.getAwsBedrockConfig?.()
    if (bedrockConfig) {
      extraOptions.region = bedrockConfig.region
      if (bedrockConfig.authType === 'apiKey') {
        extraOptions.apiKey = bedrockConfig.apiKey
      } else {
        extraOptions.accessKeyId = bedrockConfig.accessKeyId
        extraOptions.secretAccessKey = bedrockConfig.secretAccessKey
      }
    }
  }

  // Handle Vertex AI
  if (aiSdkProviderId === 'google-vertex' || aiSdkProviderId === 'google-vertex-anthropic') {
    const vertexConfig = context.getVertexConfig?.(provider)
    if (vertexConfig) {
      extraOptions.project = vertexConfig.project
      extraOptions.location = vertexConfig.location
      extraOptions.googleCredentials = {
        ...vertexConfig.googleCredentials,
        privateKey: formatPrivateKey(vertexConfig.googleCredentials.privateKey)
      }
      baseConfig.baseURL += aiSdkProviderId === 'google-vertex' ? '/publishers/google' : '/publishers/anthropic/models'
    }
  }

  // Handle cherryin endpoint type
  if (aiSdkProviderId === 'cherryin') {
    const endpointType = context.getEndpointType?.(modelId)
    if (endpointType) {
      extraOptions.endpointType = endpointType
    }
  }

  // Handle cherryai signed fetch
  if (provider.id === 'cherryai') {
    const signedFetch = context.getCherryAISignedFetch?.()
    if (signedFetch) {
      extraOptions.fetch = signedFetch
    }
  } else if (context.fetch) {
    extraOptions.fetch = context.fetch
  }

  // Check if AI SDK supports this provider natively
  if (hasProviderConfig(aiSdkProviderId) && aiSdkProviderId !== 'openai-compatible') {
    const options = ProviderConfigFactory.fromProvider(aiSdkProviderId, baseConfig, extraOptions)
    return {
      providerId: aiSdkProviderId,
      options
    }
  }

  // Fallback to openai-compatible
  const options = ProviderConfigFactory.createOpenAICompatible(baseConfig.baseURL, baseConfig.apiKey)
  return {
    providerId: 'openai-compatible',
    options: {
      ...options,
      name: provider.id,
      ...extraOptions,
      includeUsage: true
    }
  }
}
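Two small behaviors in `providerToAiSdkConfig` are easy to miss: the default key selection and the precedence of the Copilot header spread. The sketch below restates both in standalone form so they can be checked without the rest of the module; `firstApiKey`, `mergeHeaders`, and the `X-Demo` header names are illustrative only and do not exist in the source.

```typescript
// Same logic as defaultGetRotatedApiKey above, taking the raw key string directly:
// split on commas, trim each entry, return the first one.
function firstApiKey(apiKey: string): string {
  const keys = apiKey.split(',').map((k) => k.trim())
  // If the first entry is empty, the raw string is returned unchanged.
  return keys[0] || apiKey
}

// Later spreads win, so provider-level extra headers override stored headers,
// which in turn override the defaults.
function mergeHeaders(
  defaults: Record<string, string>,
  stored: Record<string, string>,
  extra: Record<string, string> = {}
): Record<string, string> {
  return { ...defaults, ...stored, ...extra }
}

const key = firstApiKey('sk-a, sk-b')
const headers = mergeHeaders(
  { 'X-Demo': 'default', 'X-Only-Default': '1' },
  { 'X-Demo': 'stored' },
  { 'X-Demo': 'extra' }
)
```

One consequence worth noting: a key string with a leading comma, such as `',sk-b'`, yields an empty first entry, so the fallback hands back the whole untrimmed string rather than `sk-b`. Callers that accept user-entered key lists may want to supply their own `getRotatedApiKey` through the context.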
174
packages/shared/provider/types.ts
Normal file
@@ -0,0 +1,174 @@
import * as z from 'zod'

export const ProviderTypeSchema = z.enum([
  'openai',
  'openai-response',
  'anthropic',
  'gemini',
  'azure-openai',
  'vertexai',
  'mistral',
  'aws-bedrock',
  'vertex-anthropic',
  'new-api',
  'ai-gateway'
])

export type ProviderType = z.infer<typeof ProviderTypeSchema>

/**
 * Minimal provider interface for shared utilities
 * This is the subset of Provider that shared code needs
 */
export type MinimalProvider = {
  id: string
  type: ProviderType
  apiKey: string
  apiHost: string
  anthropicApiHost?: string
  apiVersion?: string
  extra_headers?: Record<string, string>
}

/**
 * Minimal model interface for shared utilities
 * This is the subset of Model that shared code needs
 */
export type MinimalModel = {
  id: string
  endpoint_type?: string
}

export const SystemProviderIdSchema = z.enum([
  'cherryin',
  'silicon',
  'aihubmix',
  'ocoolai',
  'deepseek',
  'ppio',
  'alayanew',
  'qiniu',
  'dmxapi',
  'burncloud',
  'tokenflux',
  '302ai',
  'cephalon',
  'lanyun',
  'ph8',
  'openrouter',
  'ollama',
  'ovms',
  'new-api',
  'lmstudio',
  'anthropic',
  'openai',
  'azure-openai',
  'gemini',
  'vertexai',
  'github',
  'copilot',
  'zhipu',
  'yi',
  'moonshot',
  'baichuan',
  'dashscope',
  'stepfun',
  'doubao',
  'infini',
  'minimax',
  'groq',
  'together',
  'fireworks',
  'nvidia',
  'grok',
  'hyperbolic',
  'mistral',
  'jina',
  'perplexity',
  'modelscope',
  'xirang',
  'hunyuan',
  'tencent-cloud-ti',
  'baidu-cloud',
  'gpustack',
  'voyageai',
  'aws-bedrock',
  'poe',
  'aionly',
  'longcat',
  'huggingface',
  'sophnet',
  'ai-gateway',
  'cerebras'
])

export type SystemProviderId = z.infer<typeof SystemProviderIdSchema>

export const isSystemProviderId = (id: string): id is SystemProviderId => {
  return SystemProviderIdSchema.safeParse(id).success
}

export const SystemProviderIds = {
  cherryin: 'cherryin',
  silicon: 'silicon',
  aihubmix: 'aihubmix',
  ocoolai: 'ocoolai',
  deepseek: 'deepseek',
  ppio: 'ppio',
  alayanew: 'alayanew',
  qiniu: 'qiniu',
  dmxapi: 'dmxapi',
  burncloud: 'burncloud',
  tokenflux: 'tokenflux',
  '302ai': '302ai',
  cephalon: 'cephalon',
  lanyun: 'lanyun',
  ph8: 'ph8',
  sophnet: 'sophnet',
  openrouter: 'openrouter',
  ollama: 'ollama',
  ovms: 'ovms',
  'new-api': 'new-api',
  lmstudio: 'lmstudio',
  anthropic: 'anthropic',
  openai: 'openai',
  'azure-openai': 'azure-openai',
  gemini: 'gemini',
  vertexai: 'vertexai',
  github: 'github',
  copilot: 'copilot',
  zhipu: 'zhipu',
  yi: 'yi',
  moonshot: 'moonshot',
  baichuan: 'baichuan',
  dashscope: 'dashscope',
  stepfun: 'stepfun',
  doubao: 'doubao',
  infini: 'infini',
  minimax: 'minimax',
  groq: 'groq',
  together: 'together',
  fireworks: 'fireworks',
  nvidia: 'nvidia',
  grok: 'grok',
  hyperbolic: 'hyperbolic',
  mistral: 'mistral',
  jina: 'jina',
  perplexity: 'perplexity',
  modelscope: 'modelscope',
  xirang: 'xirang',
  hunyuan: 'hunyuan',
  'tencent-cloud-ti': 'tencent-cloud-ti',
  'baidu-cloud': 'baidu-cloud',
  gpustack: 'gpustack',
  voyageai: 'voyageai',
  'aws-bedrock': 'aws-bedrock',
  poe: 'poe',
  aionly: 'aionly',
  longcat: 'longcat',
  huggingface: 'huggingface',
  'ai-gateway': 'ai-gateway',
  cerebras: 'cerebras'
} as const satisfies Record<SystemProviderId, SystemProviderId>

export type SystemProviderIdTypeMap = typeof SystemProviderIds
1
packages/shared/utils/index.ts
Normal file
@@ -0,0 +1 @@
export { getBaseModelName, getLowerBaseModelName } from './naming'
31
packages/shared/utils/naming.ts
Normal file
@@ -0,0 +1,31 @@
/**
 * Extracts the base name from a model ID.
 * For example:
 * - 'deepseek/deepseek-r1' => 'deepseek-r1'
 * - 'deepseek-ai/deepseek/deepseek-r1' => 'deepseek-r1'
 * @param {string} id The model ID
 * @param {string} [delimiter='/'] The delimiter, defaults to '/'
 * @returns {string} The base name
 */
export const getBaseModelName = (id: string, delimiter: string = '/'): string => {
  const parts = id.split(delimiter)
  return parts[parts.length - 1]
}

/**
 * Extracts the base name from a model ID and converts it to lowercase.
 * For example:
 * - 'deepseek/DeepSeek-R1' => 'deepseek-r1'
 * - 'deepseek-ai/deepseek/DeepSeek-R1' => 'deepseek-r1'
 * @param {string} id The model ID
 * @param {string} [delimiter='/'] The delimiter, defaults to '/'
 * @returns {string} The lowercase base name
 */
export const getLowerBaseModelName = (id: string, delimiter: string = '/'): string => {
  const baseModelName = getBaseModelName(id, delimiter).toLowerCase()
  // for openrouter
  if (baseModelName.endsWith(':free')) {
    return baseModelName.replace(':free', '')
  }
  return baseModelName
}
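A quick usage sketch for the two naming helpers. To keep it runnable on its own, the block restates both functions locally rather than importing them; the only deviation is a `?? id` fallback on the array access, added so the sketch also compiles under stricter indexed-access settings. The `vendor::model-x` ID is a made-up example of a custom delimiter.

```typescript
// Local restatement of the helpers from naming.ts (not an import).
const getBaseModelName = (id: string, delimiter: string = '/'): string => {
  const parts = id.split(delimiter)
  return parts[parts.length - 1] ?? id
}

const getLowerBaseModelName = (id: string, delimiter: string = '/'): string => {
  const base = getBaseModelName(id, delimiter).toLowerCase()
  // OpenRouter appends ':free' to free-tier model IDs; strip it.
  return base.endsWith(':free') ? base.replace(':free', '') : base
}

// Nested prefixes are dropped; only the last segment is kept.
const nested = getBaseModelName('deepseek-ai/deepseek/deepseek-r1')
// Lowercasing and ':free' stripping are applied together.
const freeTier = getLowerBaseModelName('deepseek/DeepSeek-R1:free')
// A custom delimiter splits on that string instead of '/'.
const custom = getBaseModelName('vendor::model-x', '::')
```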
@@ -1,42 +1,64 @@
-import { defineConfig, devices } from '@playwright/test'
+import { defineConfig } from '@playwright/test'
 
 /**
- * See https://playwright.dev/docs/test-configuration.
+ * Playwright configuration for Electron e2e testing.
+ * See https://playwright.dev/docs/test-configuration
  */
 export default defineConfig({
-  // Look for test files, relative to this configuration file.
+  // Look for test files in the specs directory
-  testDir: './tests/e2e',
+  testDir: './tests/e2e/specs',
-  /* Run tests in files in parallel */
-  fullyParallel: true,
-  /* Fail the build on CI if you accidentally left test.only in the source code. */
-  forbidOnly: !!process.env.CI,
-  /* Retry on CI only */
-  retries: process.env.CI ? 2 : 0,
-  /* Opt out of parallel tests on CI. */
-  workers: process.env.CI ? 1 : undefined,
-  /* Reporter to use. See https://playwright.dev/docs/test-reporters */
-  reporter: 'html',
-  /* Shared settings for all the projects below. See https://playwright.dev/docs/api/class-testoptions. */
-  use: {
-    /* Base URL to use in actions like `await page.goto('/')`. */
-    // baseURL: 'http://localhost:3000',
 
-    /* Collect trace when retrying the failed test. See https://playwright.dev/docs/trace-viewer */
+  // Global timeout for each test
-    trace: 'on-first-retry'
+  timeout: 60000,
+
+  // Assertion timeout
+  expect: {
+    timeout: 10000
   },
 
-  /* Configure projects for major browsers */
+  // Electron apps should run tests sequentially to avoid conflicts
+  fullyParallel: false,
+  workers: 1,
+
+  // Fail the build on CI if you accidentally left test.only in the source code
+  forbidOnly: !!process.env.CI,
+
+  // Retry on CI only
+  retries: process.env.CI ? 2 : 0,
+
+  // Reporter configuration
+  reporter: [['html', { outputFolder: 'playwright-report' }], ['list']],
+
+  // Global setup and teardown
+  globalSetup: './tests/e2e/global-setup.ts',
+  globalTeardown: './tests/e2e/global-teardown.ts',
+
+  // Output directory for test artifacts
+  outputDir: './test-results',
+
+  // Shared settings for all tests
+  use: {
+    // Collect trace when retrying the failed test
+    trace: 'retain-on-failure',
+
+    // Take screenshot only on failure
+    screenshot: 'only-on-failure',
+
+    // Record video only on failure
+    video: 'retain-on-failure',
+
+    // Action timeout
+    actionTimeout: 15000,
+
+    // Navigation timeout
+    navigationTimeout: 30000
+  },
+
+  // Single project for Electron testing
   projects: [
     {
-      name: 'chromium',
+      name: 'electron',
-      use: { ...devices['Desktop Chrome'] }
+      testMatch: '**/*.spec.ts'
     }
   ]
-
-  /* Run your local dev server before starting the tests */
-  // webServer: {
-  //   command: 'npm run start',
-  //   url: 'http://localhost:3000',
-  //   reuseExistingServer: !process.env.CI,
-  // },
 })
638
src/main/apiServer/adapters/AiSdkToAnthropicSSE.ts
Normal file
@@ -0,0 +1,638 @@
|
|||||||
|
/**
|
||||||
|
* AI SDK to Anthropic SSE Adapter
|
||||||
|
*
|
||||||
|
* Converts AI SDK's fullStream (TextStreamPart) events to Anthropic Messages API SSE format.
|
||||||
|
* This enables any AI provider supported by AI SDK to be exposed via Anthropic-compatible API.
|
||||||
|
*
|
||||||
|
* Anthropic SSE Event Flow:
|
||||||
|
* 1. message_start - Initial message with metadata
|
||||||
|
* 2. content_block_start - Begin a content block (text, tool_use, thinking)
|
||||||
|
* 3. content_block_delta - Incremental content updates
|
||||||
|
* 4. content_block_stop - End a content block
|
||||||
|
* 5. message_delta - Updates to overall message (stop_reason, usage)
|
||||||
|
* 6. message_stop - Stream complete
|
||||||
|
*
|
||||||
|
* @see https://docs.anthropic.com/en/api/messages-streaming
|
||||||
|
*/
|
||||||
|
|
||||||
|
import type {
|
||||||
|
ContentBlock,
|
||||||
|
InputJSONDelta,
|
||||||
|
Message,
|
||||||
|
MessageDeltaUsage,
|
||||||
|
RawContentBlockDeltaEvent,
|
||||||
|
RawContentBlockStartEvent,
|
||||||
|
RawContentBlockStopEvent,
|
||||||
|
RawMessageDeltaEvent,
|
||||||
|
RawMessageStartEvent,
|
||||||
|
RawMessageStopEvent,
|
||||||
|
RawMessageStreamEvent,
|
||||||
|
StopReason,
|
||||||
|
TextBlock,
|
||||||
|
TextDelta,
|
||||||
|
ThinkingBlock,
|
||||||
|
ThinkingDelta,
|
||||||
|
ToolUseBlock,
|
||||||
|
Usage
|
||||||
|
} from '@anthropic-ai/sdk/resources/messages'
|
||||||
|
import { loggerService } from '@logger'
|
||||||
|
import { type FinishReason, type LanguageModelUsage, type TextStreamPart, type ToolSet } from 'ai'
|
||||||
|
|
||||||
|
import { googleReasoningCache, openRouterReasoningCache } from '../../services/CacheService'
|
||||||
|
|
||||||
|
const logger = loggerService.withContext('AiSdkToAnthropicSSE')
|
||||||
|
|
||||||
|
interface ContentBlockState {
  type: 'text' | 'tool_use' | 'thinking'
  index: number
  started: boolean
  content: string
  // For tool_use blocks
  toolId?: string
  toolName?: string
  toolInput?: string
}

interface AdapterState {
  messageId: string
  model: string
  inputTokens: number
  outputTokens: number
  cacheInputTokens: number
  currentBlockIndex: number
  blocks: Map<number, ContentBlockState>
  textBlockIndex: number | null
  // Track multiple thinking blocks by their reasoning ID
  thinkingBlocks: Map<string, number> // reasoningId -> blockIndex
  currentThinkingId: string | null // Currently active thinking block ID
  toolBlocks: Map<string, number> // toolCallId -> blockIndex
  stopReason: StopReason | null
  hasEmittedMessageStart: boolean
}

export type SSEEventCallback = (event: RawMessageStreamEvent) => void

export interface AiSdkToAnthropicSSEOptions {
  model: string
  messageId?: string
  inputTokens?: number
  onEvent: SSEEventCallback
}

/**
 * Adapter that converts AI SDK fullStream events to Anthropic SSE events
 */
export class AiSdkToAnthropicSSE {
  private state: AdapterState
  private onEvent: SSEEventCallback

  constructor(options: AiSdkToAnthropicSSEOptions) {
    this.onEvent = options.onEvent
    this.state = {
      messageId: options.messageId || `msg_${Date.now()}_${Math.random().toString(36).substring(2, 11)}`,
      model: options.model,
      inputTokens: options.inputTokens || 0,
      outputTokens: 0,
      cacheInputTokens: 0,
      currentBlockIndex: 0,
      blocks: new Map(),
      textBlockIndex: null,
      thinkingBlocks: new Map(),
      currentThinkingId: null,
      toolBlocks: new Map(),
      stopReason: null,
      hasEmittedMessageStart: false
    }
  }

  /**
   * Process the AI SDK stream and emit Anthropic SSE events
   */
  async processStream(fullStream: ReadableStream<TextStreamPart<ToolSet>>): Promise<void> {
    const reader = fullStream.getReader()

    try {
      // Emit message_start at the beginning
      this.emitMessageStart()

      while (true) {
        const { done, value } = await reader.read()

        if (done) {
          break
        }

        this.processChunk(value)
      }

      // Ensure all blocks are closed and emit final events
      this.finalize()
    } catch (error) {
      await reader.cancel()
      throw error
    } finally {
      reader.releaseLock()
    }
  }

  /**
   * Process a single AI SDK chunk and emit corresponding Anthropic events
   */
  private processChunk(chunk: TextStreamPart<ToolSet>): void {
    logger.silly('AiSdkToAnthropicSSE - Processing chunk:', { chunk: JSON.stringify(chunk) })
    switch (chunk.type) {
      // === Text Events ===
      case 'text-start':
        this.startTextBlock()
        break

      case 'text-delta':
        this.emitTextDelta(chunk.text || '')
        break

      case 'text-end':
        this.stopTextBlock()
        break

      // === Reasoning/Thinking Events ===
      case 'reasoning-start': {
        const reasoningId = chunk.id
        this.startThinkingBlock(reasoningId)
        break
      }

      case 'reasoning-delta': {
        const reasoningId = chunk.id
        this.emitThinkingDelta(chunk.text || '', reasoningId)
        break
      }

      case 'reasoning-end': {
        const reasoningId = chunk.id
        this.stopThinkingBlock(reasoningId)
        break
      }

      // === Tool Events ===
      case 'tool-call':
        if (googleReasoningCache && chunk.providerMetadata?.google?.thoughtSignature) {
          googleReasoningCache.set(
            `google-${chunk.toolName}`,
            chunk.providerMetadata?.google?.thoughtSignature as string
          )
        }
        // FIXME: key the cache entry by tool call id
        if (
          openRouterReasoningCache &&
          chunk.providerMetadata?.openrouter?.reasoning_details &&
          Array.isArray(chunk.providerMetadata.openrouter.reasoning_details)
        ) {
          openRouterReasoningCache.set(
            'openrouter',
            JSON.parse(JSON.stringify(chunk.providerMetadata.openrouter.reasoning_details))
          )
        }
        this.handleToolCall({
          type: 'tool-call',
          toolCallId: chunk.toolCallId,
          toolName: chunk.toolName,
          args: chunk.input
        })
        break

      case 'tool-result':
        // this.handleToolResult({
        //   type: 'tool-result',
        //   toolCallId: chunk.toolCallId,
        //   toolName: chunk.toolName,
        //   args: chunk.input,
        //   result: chunk.output
        // })
        break

      case 'finish-step':
        if (chunk.finishReason === 'tool-calls') {
          this.state.stopReason = 'tool_use'
        }
        break

      case 'finish':
        this.handleFinish(chunk)
        break

      case 'error':
        throw chunk.error

      // Ignore other event types
      default:
        break
    }
  }

  private emitMessageStart(): void {
    if (this.state.hasEmittedMessageStart) return

    this.state.hasEmittedMessageStart = true

    const usage: Usage = {
      input_tokens: this.state.inputTokens,
      output_tokens: 0,
      cache_creation_input_tokens: 0,
      cache_read_input_tokens: 0,
      server_tool_use: null
    }

    const message: Message = {
      id: this.state.messageId,
      type: 'message',
      role: 'assistant',
      content: [],
      model: this.state.model,
      stop_reason: null,
      stop_sequence: null,
      usage
    }

    const event: RawMessageStartEvent = {
      type: 'message_start',
      message
    }

    this.onEvent(event)
  }

  private startTextBlock(): void {
    // If we already have a text block, don't create another
    if (this.state.textBlockIndex !== null) return

    const index = this.state.currentBlockIndex++
    this.state.textBlockIndex = index
    this.state.blocks.set(index, {
      type: 'text',
      index,
      started: true,
      content: ''
    })

    const contentBlock: TextBlock = {
      type: 'text',
      text: '',
      citations: null
    }

    const event: RawContentBlockStartEvent = {
      type: 'content_block_start',
      index,
      content_block: contentBlock
    }

    this.onEvent(event)
  }

  private emitTextDelta(text: string): void {
    if (!text) return

    // Auto-start text block if not started
    if (this.state.textBlockIndex === null) {
      this.startTextBlock()
    }

    const index = this.state.textBlockIndex!
    const block = this.state.blocks.get(index)
    if (block) {
      block.content += text
    }

    const delta: TextDelta = {
      type: 'text_delta',
      text
    }

    const event: RawContentBlockDeltaEvent = {
      type: 'content_block_delta',
      index,
      delta
    }

    this.onEvent(event)
  }

  private stopTextBlock(): void {
    if (this.state.textBlockIndex === null) return

    const index = this.state.textBlockIndex

    const event: RawContentBlockStopEvent = {
      type: 'content_block_stop',
      index
    }

    this.onEvent(event)
    this.state.textBlockIndex = null
  }

  private startThinkingBlock(reasoningId: string): void {
    // Check if this thinking block already exists
    if (this.state.thinkingBlocks.has(reasoningId)) return

    const index = this.state.currentBlockIndex++
    this.state.thinkingBlocks.set(reasoningId, index)
    this.state.currentThinkingId = reasoningId
    this.state.blocks.set(index, {
      type: 'thinking',
      index,
      started: true,
      content: ''
    })

    const contentBlock: ThinkingBlock = {
      type: 'thinking',
      thinking: '',
      signature: ''
    }

    const event: RawContentBlockStartEvent = {
      type: 'content_block_start',
      index,
      content_block: contentBlock
    }

    this.onEvent(event)
  }

  private emitThinkingDelta(text: string, reasoningId?: string): void {
    if (!text) return

    // Determine which thinking block to use
    const targetId = reasoningId || this.state.currentThinkingId
    if (!targetId) {
      // Auto-start thinking block if not started
      const newId = `reasoning_${Date.now()}`
      this.startThinkingBlock(newId)
      return this.emitThinkingDelta(text, newId)
    }

    const index = this.state.thinkingBlocks.get(targetId)
    if (index === undefined) {
      // If the block doesn't exist, create it
      this.startThinkingBlock(targetId)
      return this.emitThinkingDelta(text, targetId)
    }

    const block = this.state.blocks.get(index)
    if (block) {
      block.content += text
    }

    const delta: ThinkingDelta = {
      type: 'thinking_delta',
      thinking: text
    }

    const event: RawContentBlockDeltaEvent = {
      type: 'content_block_delta',
      index,
      delta
    }

    this.onEvent(event)
  }

  private stopThinkingBlock(reasoningId?: string): void {
    const targetId = reasoningId || this.state.currentThinkingId
    if (!targetId) return

    const index = this.state.thinkingBlocks.get(targetId)
    if (index === undefined) return

    const event: RawContentBlockStopEvent = {
      type: 'content_block_stop',
      index
    }

    this.onEvent(event)
    this.state.thinkingBlocks.delete(targetId)

    // Update currentThinkingId if we just closed the current one
    if (this.state.currentThinkingId === targetId) {
      // Set to the most recent remaining thinking block, or null if none
      const remaining = Array.from(this.state.thinkingBlocks.keys())
      this.state.currentThinkingId = remaining.length > 0 ? remaining[remaining.length - 1] : null
    }
  }

  private handleToolCall(chunk: { type: 'tool-call'; toolCallId: string; toolName: string; args: unknown }): void {
    const { toolCallId, toolName, args } = chunk

    // Check if we already have this tool call
    if (this.state.toolBlocks.has(toolCallId)) {
      return
    }

    const index = this.state.currentBlockIndex++
    this.state.toolBlocks.set(toolCallId, index)

    const inputJson = JSON.stringify(args)

    this.state.blocks.set(index, {
      type: 'tool_use',
      index,
      started: true,
      content: inputJson,
      toolId: toolCallId,
      toolName,
      toolInput: inputJson
    })

    // Emit content_block_start for tool_use
    const contentBlock: ToolUseBlock = {
      type: 'tool_use',
      id: toolCallId,
      name: toolName,
      input: {}
    }

    const startEvent: RawContentBlockStartEvent = {
      type: 'content_block_start',
      index,
      content_block: contentBlock
    }

    this.onEvent(startEvent)

    // Emit the full input as a delta (Anthropic streams JSON incrementally)
    const delta: InputJSONDelta = {
      type: 'input_json_delta',
      partial_json: inputJson
    }

    const deltaEvent: RawContentBlockDeltaEvent = {
      type: 'content_block_delta',
      index,
      delta
    }

    this.onEvent(deltaEvent)

    // Emit content_block_stop
    const stopEvent: RawContentBlockStopEvent = {
      type: 'content_block_stop',
      index
    }

    this.onEvent(stopEvent)

    // Mark that we have tool use
    this.state.stopReason = 'tool_use'
  }

  private handleFinish(chunk: { type: 'finish'; finishReason?: FinishReason; totalUsage?: LanguageModelUsage }): void {
    // Update usage
    if (chunk.totalUsage) {
      this.state.inputTokens = chunk.totalUsage.inputTokens || 0
      this.state.outputTokens = chunk.totalUsage.outputTokens || 0
      this.state.cacheInputTokens = chunk.totalUsage.cachedInputTokens || 0
    }

    // Determine finish reason
    if (!this.state.stopReason) {
      switch (chunk.finishReason) {
        case 'stop':
          this.state.stopReason = 'end_turn'
          break
        case 'length':
          this.state.stopReason = 'max_tokens'
          break
        case 'tool-calls':
          this.state.stopReason = 'tool_use'
          break
        case 'content-filter':
          this.state.stopReason = 'refusal'
          break
        default:
          this.state.stopReason = 'end_turn'
      }
    }
  }

  private finalize(): void {
    // Close any open blocks
    if (this.state.textBlockIndex !== null) {
      this.stopTextBlock()
    }
    // Close all open thinking blocks
    for (const reasoningId of this.state.thinkingBlocks.keys()) {
      this.stopThinkingBlock(reasoningId)
    }

    // Emit message_delta with final stop reason and usage
    // cachedInputTokens from the AI SDK counts tokens read from the cache,
    // so it is reported as cache_read_input_tokens rather than cache creation
    const usage: MessageDeltaUsage = {
      output_tokens: this.state.outputTokens,
      input_tokens: this.state.inputTokens,
      cache_creation_input_tokens: null,
      cache_read_input_tokens: this.state.cacheInputTokens,
      server_tool_use: null
    }

    const messageDeltaEvent: RawMessageDeltaEvent = {
      type: 'message_delta',
      delta: {
        stop_reason: this.state.stopReason || 'end_turn',
        stop_sequence: null
      },
      usage
    }

    this.onEvent(messageDeltaEvent)

    // Emit message_stop
    const messageStopEvent: RawMessageStopEvent = {
      type: 'message_stop'
    }

    this.onEvent(messageStopEvent)
  }

  /**
   * Set input token count (typically from prompt)
   */
  setInputTokens(count: number): void {
    this.state.inputTokens = count
  }

  /**
   * Get the current message ID
   */
  getMessageId(): string {
    return this.state.messageId
  }

  /**
   * Build a complete Message object for non-streaming responses
   */
  buildNonStreamingResponse(): Message {
    const content: ContentBlock[] = []

    // Collect all content blocks in order
    const sortedBlocks = Array.from(this.state.blocks.values()).sort((a, b) => a.index - b.index)

    for (const block of sortedBlocks) {
      switch (block.type) {
        case 'text':
          content.push({
            type: 'text',
            text: block.content,
            citations: null
          } as TextBlock)
          break
        case 'thinking':
          content.push({
            type: 'thinking',
            thinking: block.content
          } as ThinkingBlock)
          break
        case 'tool_use':
          content.push({
            type: 'tool_use',
            id: block.toolId!,
            name: block.toolName!,
            input: JSON.parse(block.toolInput || '{}')
          } as ToolUseBlock)
          break
      }
    }

    return {
      id: this.state.messageId,
      type: 'message',
      role: 'assistant',
      content,
      model: this.state.model,
      stop_reason: this.state.stopReason || 'end_turn',
      stop_sequence: null,
      usage: {
        input_tokens: this.state.inputTokens,
        output_tokens: this.state.outputTokens,
        cache_creation_input_tokens: 0,
        cache_read_input_tokens: 0,
        server_tool_use: null
      }
    }
  }
}

/**
 * Format an Anthropic SSE event for HTTP streaming
 */
export function formatSSEEvent(event: RawMessageStreamEvent): string {
  return `event: ${event.type}\ndata: ${JSON.stringify(event)}\n\n`
}

/**
 * Create a done marker for SSE stream
 */
export function formatSSEDone(): string {
  return 'data: [DONE]\n\n'
}

export default AiSdkToAnthropicSSE
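Reviewer note: the wire framing and the finish-reason mapping in the adapter above can be checked in isolation. The following is a minimal sketch, not the adapter itself. `SketchEvent`, `frameEvent` and `toStopReason` are hypothetical names introduced only for illustration; the real code uses `RawMessageStreamEvent` from `@anthropic-ai/sdk` and the private `handleFinish` method.

```typescript
// Hypothetical, simplified event type for illustration only.
type SketchEvent =
  | { type: 'content_block_delta'; index: number; delta: { type: 'text_delta'; text: string } }
  | { type: 'message_stop' }

// Same framing as formatSSEEvent: an `event:` line, a `data:` line, then a blank line.
function frameEvent(event: SketchEvent): string {
  return `event: ${event.type}\ndata: ${JSON.stringify(event)}\n\n`
}

// Mirrors the switch in handleFinish: AI SDK finish reasons -> Anthropic stop reasons.
function toStopReason(finishReason: string | undefined): string {
  switch (finishReason) {
    case 'length':
      return 'max_tokens'
    case 'tool-calls':
      return 'tool_use'
    case 'content-filter':
      return 'refusal'
    default:
      // 'stop' and any unrecognised reason both map to end_turn
      return 'end_turn'
  }
}

const framed = frameEvent({
  type: 'content_block_delta',
  index: 0,
  delta: { type: 'text_delta', text: 'Hi' }
})
console.log(framed)
console.log(toStopReason('tool-calls'))
```

Each frame ends with an empty line, which is what lets an SSE client split the byte stream back into events.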
13
src/main/apiServer/adapters/index.ts
Normal file
@@ -0,0 +1,13 @@
/**
 * Shared Adapters
 *
 * This module exports adapters for converting between different AI API formats.
 */

export {
  AiSdkToAnthropicSSE,
  type AiSdkToAnthropicSSEOptions,
  formatSSEDone,
  formatSSEEvent,
  type SSEEventCallback
} from './AiSdkToAnthropicSSE'
95
src/main/apiServer/adapters/openrouter.ts
Normal file
@@ -0,0 +1,95 @@
import * as z from 'zod/v4'

enum ReasoningFormat {
  Unknown = 'unknown',
  OpenAIResponsesV1 = 'openai-responses-v1',
  XAIResponsesV1 = 'xai-responses-v1',
  AnthropicClaudeV1 = 'anthropic-claude-v1',
  GoogleGeminiV1 = 'google-gemini-v1'
}

// Anthropic Claude was the first reasoning format that we passed
// back and forth, so it is the default
export const DEFAULT_REASONING_FORMAT = ReasoningFormat.AnthropicClaudeV1

function isDefinedOrNotNull<T>(value: T | null | undefined): value is T {
  return value !== null && value !== undefined
}

export enum ReasoningDetailType {
  Summary = 'reasoning.summary',
  Encrypted = 'reasoning.encrypted',
  Text = 'reasoning.text'
}

export const CommonReasoningDetailSchema = z
  .object({
    id: z.string().nullish(),
    format: z.enum(ReasoningFormat).nullish(),
    index: z.number().optional()
  })
  .loose()

export const ReasoningDetailSummarySchema = z
  .object({
    type: z.literal(ReasoningDetailType.Summary),
    summary: z.string()
  })
  .extend(CommonReasoningDetailSchema.shape)
export type ReasoningDetailSummary = z.infer<typeof ReasoningDetailSummarySchema>

export const ReasoningDetailEncryptedSchema = z
  .object({
    type: z.literal(ReasoningDetailType.Encrypted),
    data: z.string()
  })
  .extend(CommonReasoningDetailSchema.shape)

export type ReasoningDetailEncrypted = z.infer<typeof ReasoningDetailEncryptedSchema>

export const ReasoningDetailTextSchema = z
  .object({
    type: z.literal(ReasoningDetailType.Text),
    text: z.string().nullish(),
    signature: z.string().nullish()
  })
  .extend(CommonReasoningDetailSchema.shape)

export type ReasoningDetailText = z.infer<typeof ReasoningDetailTextSchema>

export const ReasoningDetailUnionSchema = z.union([
  ReasoningDetailSummarySchema,
  ReasoningDetailEncryptedSchema,
  ReasoningDetailTextSchema
])

export type ReasoningDetailUnion = z.infer<typeof ReasoningDetailUnionSchema>

const ReasoningDetailsWithUnknownSchema = z.union([ReasoningDetailUnionSchema, z.unknown().transform(() => null)])

export const ReasoningDetailArraySchema = z
  .array(ReasoningDetailsWithUnknownSchema)
  .transform((d) => d.filter((d): d is ReasoningDetailUnion => !!d))

export const OutputUnionToReasoningDetailsSchema = z.union([
  z
    .object({
      delta: z.object({
        reasoning_details: z.array(ReasoningDetailsWithUnknownSchema)
      })
    })
    .transform((data) => data.delta.reasoning_details.filter(isDefinedOrNotNull)),
  z
    .object({
      message: z.object({
        reasoning_details: z.array(ReasoningDetailsWithUnknownSchema)
      })
    })
    .transform((data) => data.message.reasoning_details.filter(isDefinedOrNotNull)),
  z
    .object({
      text: z.string(),
      reasoning_details: z.array(ReasoningDetailsWithUnknownSchema)
    })
    .transform((data) => data.reasoning_details.filter(isDefinedOrNotNull))
])
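Reviewer note: the schemas above map any entry that matches none of the three known detail shapes to `null` and then filter the nulls out, so one malformed entry does not reject the whole array. The sketch below illustrates that "parse what you recognise, drop the rest" idea with hand-written type guards instead of zod, so it is a stand-in technique and not the file's implementation. `SketchDetail`, `toDetail` and `parseDetails` are hypothetical names, and the common `id`/`format`/`index` fields are omitted for brevity.

```typescript
// Hypothetical stand-ins for the zod-inferred types; common fields omitted.
type SketchDetail =
  | { type: 'reasoning.summary'; summary: string }
  | { type: 'reasoning.encrypted'; data: string }
  | { type: 'reasoning.text'; text?: string | null; signature?: string | null }

function isRecord(value: unknown): value is Record<string, unknown> {
  return typeof value === 'object' && value !== null
}

// Accepts undefined, null or a string, like zod's z.string().nullish().
function isNullishString(value: unknown): value is string | null | undefined {
  return value === undefined || value === null || typeof value === 'string'
}

// Returns the recognised detail, or null for anything else.
function toDetail(value: unknown): SketchDetail | null {
  if (!isRecord(value)) return null
  const type = value['type']
  if (type === 'reasoning.summary') {
    const summary = value['summary']
    return typeof summary === 'string' ? { type, summary } : null
  }
  if (type === 'reasoning.encrypted') {
    const data = value['data']
    return typeof data === 'string' ? { type, data } : null
  }
  if (type === 'reasoning.text') {
    const text = value['text']
    const signature = value['signature']
    return isNullishString(text) && isNullishString(signature) ? { type, text, signature } : null
  }
  return null
}

// Unknown entries become null and are then filtered out.
function parseDetails(items: unknown[]): SketchDetail[] {
  return items.map(toDetail).filter((d): d is SketchDetail => d !== null)
}

const parsed = parseDetails([
  { type: 'reasoning.text', text: 'step 1' },
  { type: 'something.else' },
  42,
  { type: 'reasoning.summary' } // missing `summary`, so it is dropped
])
console.log(parsed.length)
```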
@@ -1,17 +1,93 @@
 import type { MessageCreateParams } from '@anthropic-ai/sdk/resources'
 import { loggerService } from '@logger'
+import { buildSharedMiddlewares, type SharedMiddlewareConfig } from '@shared/middleware'
+import { getAiSdkProviderId } from '@shared/provider'
 import type { Provider } from '@types'
 import type { Request, Response } from 'express'
 import express from 'express'

 import { messagesService } from '../services/messages'
-import { getProviderById, validateModelId } from '../utils'
+import { generateUnifiedMessage, streamUnifiedMessages } from '../services/unified-messages'
+import { getProviderById, isModelAnthropicCompatible, validateModelId } from '../utils'

+/**
+ * Check if a specific model on a provider should use direct Anthropic SDK
+ *
+ * A provider+model combination is considered "Anthropic-compatible" if:
+ * 1. It's a native Anthropic provider (type === 'anthropic'), OR
+ * 2. It has anthropicApiHost configured AND the specific model supports Anthropic API
+ *    (for aggregated providers like Silicon, only certain models support Anthropic endpoint)
+ *
+ * @param provider - The provider to check
+ * @param modelId - The model ID to check (without provider prefix)
+ * @returns true if should use direct Anthropic SDK, false for unified SDK
+ */
+function shouldUseDirectAnthropic(provider: Provider, modelId: string): boolean {
+  // Native Anthropic provider - always use direct SDK
+  if (provider.type === 'anthropic') {
+    return true
+  }
+
+  // No anthropicApiHost configured - use unified SDK
+  if (!provider.anthropicApiHost?.trim()) {
+    return false
+  }
+
+  // Has anthropicApiHost - check model-level compatibility
+  // For aggregated providers, only specific models support Anthropic API
+  return isModelAnthropicCompatible(provider, modelId)
+}
+
 const logger = loggerService.withContext('ApiServerMessagesRoutes')

 const router = express.Router()
 const providerRouter = express.Router({ mergeParams: true })

+/**
+ * Estimate token count from messages
+ * Simple approximation: ~4 characters per token for English text
+ */
+interface CountTokensInput {
+  messages: Array<{ role: string; content: string | Array<{ type: string; text?: string }> }>
+  system?: string | Array<{ type: string; text?: string }>
+}
+
+function estimateTokenCount(input: CountTokensInput): number {
+  const { messages, system } = input
+  let totalChars = 0
+
+  // Count system message tokens
+  if (system) {
+    if (typeof system === 'string') {
+      totalChars += system.length
+    } else if (Array.isArray(system)) {
+      for (const block of system) {
+        if (block.type === 'text' && block.text) {
+          totalChars += block.text.length
+        }
+      }
+    }
+  }
+
+  // Count message tokens
+  for (const msg of messages) {
+    if (typeof msg.content === 'string') {
+      totalChars += msg.content.length
+    } else if (Array.isArray(msg.content)) {
+      for (const block of msg.content) {
+        if (block.type === 'text' && block.text) {
+          totalChars += block.text.length
+        }
+      }
+    }
+    // Add overhead for role
+    totalChars += 10
+  }
+
+  // Estimate tokens (~4 chars per token, with some overhead)
+  return Math.ceil(totalChars / 4) + messages.length * 3
+}
+
 // Helper function for basic request validation
 async function validateRequestBody(req: Request): Promise<{ valid: boolean; error?: any }> {
   const request: MessageCreateParams = req.body
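Reviewer note: the `estimateTokenCount` heuristic added in the hunk above is easy to check by hand. The sketch below restates the same arithmetic in a compact, standalone form. `SketchBlock`, `SketchInput`, `textLength` and `estimateTokens` are hypothetical names used only here, and the result is a rough character-based approximation, not a real tokenizer count.

```typescript
// Hypothetical, trimmed-down copies of the input shapes from the hunk above.
type SketchBlock = { type: string; text?: string }
type SketchInput = {
  messages: Array<{ role: string; content: string | SketchBlock[] }>
  system?: string | SketchBlock[]
}

// Character count of a string, or of the text blocks in a block array.
function textLength(content: string | SketchBlock[] | undefined): number {
  if (content === undefined) return 0
  if (typeof content === 'string') return content.length
  return content.reduce((sum, block) => sum + (block.type === 'text' && block.text ? block.text.length : 0), 0)
}

// Same arithmetic as estimateTokenCount: +10 chars per message for the role,
// ~4 chars per token, plus 3 tokens of overhead per message.
function estimateTokens(input: SketchInput): number {
  let totalChars = textLength(input.system)
  for (const msg of input.messages) {
    totalChars += textLength(msg.content) + 10
  }
  return Math.ceil(totalChars / 4) + input.messages.length * 3
}

// Worked example: 16 system chars + 13 message chars + 10 role overhead = 39 chars,
// ceil(39 / 4) = 10 tokens, plus 3 per-message overhead = 13.
const estimate = estimateTokens({
  system: 'You are helpful.',
  messages: [{ role: 'user', content: 'Hello, world!' }]
})
console.log(estimate)
```

Non-text blocks such as images contribute nothing to the character total, so requests dominated by such content will be underestimated.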
@@ -33,21 +109,36 @@ async function validateRequestBody(req: Request): Promise<{ valid: boolean; erro
 }

 interface HandleMessageProcessingOptions {
-  req: Request
   res: Response
   provider: Provider
   request: MessageCreateParams
   modelId?: string
 }

-async function handleMessageProcessing({
-  req,
+/**
+ * Handle message processing using direct Anthropic SDK
+ * Used for providers with anthropicApiHost or native Anthropic providers
+ * This bypasses AI SDK conversion and uses native Anthropic protocol
+ */
+async function handleDirectAnthropicProcessing({
   res,
   provider,
   request,
-  modelId
-}: HandleMessageProcessingOptions): Promise<void> {
+  modelId,
+  extraHeaders
+}: HandleMessageProcessingOptions & { extraHeaders?: Record<string, string | string[]> }): Promise<void> {
+  const actualModelId = modelId || request.model
+
+  logger.info('Processing message via direct Anthropic SDK', {
+    providerId: provider.id,
+    providerType: provider.type,
+    modelId: actualModelId,
+    stream: !!request.stream,
+    anthropicApiHost: provider.anthropicApiHost
+  })
+
   try {
+    // Validate request
     const validation = messagesService.validateRequest(request)
     if (!validation.isValid) {
       res.status(400).json({
@@ -60,28 +151,126 @@ async function handleMessageProcessing({
       return
     }

-    const extraHeaders = messagesService.prepareHeaders(req.headers)
+    // Process message using messagesService (native Anthropic SDK)
     const { client, anthropicRequest } = await messagesService.processMessage({
       provider,
       request,
       extraHeaders,
-      modelId
+      modelId: actualModelId
     })

     if (request.stream) {
+      // Use native Anthropic streaming
       await messagesService.handleStreaming(client, anthropicRequest, { response: res }, provider)
-      return
+    } else {
+      // Use native Anthropic non-streaming
+      const response = await client.messages.create(anthropicRequest)
+      res.json(response)
     }
-
-    const response = await client.messages.create(anthropicRequest)
-    res.json(response)
   } catch (error: any) {
-    logger.error('Message processing error', { error })
+    logger.error('Direct Anthropic processing error', { error })
     const { statusCode, errorResponse } = messagesService.transformError(error)
     res.status(statusCode).json(errorResponse)
   }
 }

+/**
+ * Handle message processing using unified AI SDK
+ * Used for non-Anthropic providers that need format conversion
+ * - Uses AI SDK adapters with output converted to Anthropic SSE format
+ */
+async function handleUnifiedProcessing({
+  res,
+  provider,
+  request,
+  modelId
+}: HandleMessageProcessingOptions): Promise<void> {
+  const actualModelId = modelId || request.model
+
+  logger.info('Processing message via unified AI SDK', {
+    providerId: provider.id,
+    providerType: provider.type,
+    modelId: actualModelId,
+    stream: !!request.stream
+  })
+
+  try {
+    // Validate request
+    const validation = messagesService.validateRequest(request)
+    if (!validation.isValid) {
+      res.status(400).json({
+        type: 'error',
+        error: {
+          type: 'invalid_request_error',
+          message: validation.errors.join('; ')
+        }
+      })
+      return
+    }
+
+    const middlewareConfig: SharedMiddlewareConfig = {
+      modelId: actualModelId,
+      providerId: provider.id,
+      aiSdkProviderId: getAiSdkProviderId(provider)
+    }
+    const middlewares = buildSharedMiddlewares(middlewareConfig)
+
+    logger.debug('Built middlewares for unified processing', {
+      middlewareCount: middlewares.length,
+      modelId: actualModelId,
+      providerId: provider.id
+    })
+
+    if (request.stream) {
+      await streamUnifiedMessages({
+        response: res,
+        provider,
+        modelId: actualModelId,
+        params: request,
+        middlewares,
+        onError: (error) => {
+          logger.error('Stream error', error as Error)
+        },
+        onComplete: () => {
+          logger.debug('Stream completed')
+        }
+      })
+    } else {
+      const response = await generateUnifiedMessage({
+        provider,
+        modelId: actualModelId,
+        params: request,
+        middlewares
+      })
+      res.json(response)
+    }
+  } catch (error: any) {
+    const { statusCode, errorResponse } = messagesService.transformError(error)
+    res.status(statusCode).json(errorResponse)
+  }
+}
+
/**
 * Handle message processing - routes to appropriate handler based on provider and model
 *
 * Routing logic:
 * - Native Anthropic providers (type === 'anthropic'): Direct Anthropic SDK
 * - Providers with anthropicApiHost AND model supports Anthropic API: Direct Anthropic SDK
 * - Other providers/models: Unified AI SDK with Anthropic SSE conversion
 */
async function handleMessageProcessing({
  res,
  provider,
  request,
  modelId
}: HandleMessageProcessingOptions): Promise<void> {
  const actualModelId = modelId || request.model
  if (shouldUseDirectAnthropic(provider, actualModelId)) {
    return handleDirectAnthropicProcessing({ res, provider, request, modelId })
  }
  return handleUnifiedProcessing({ res, provider, request, modelId })
}
|
||||||
|
|
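Review note: the routing rules in the doc comment for `handleMessageProcessing` can be read as a small pure function. The sketch below is an illustration only. `ProviderLike`, `Route`, `chooseRoute`, and the `supportsAnthropicApi` predicate are hypothetical names, and the body of `shouldUseDirectAnthropic` is not shown in this diff, so the real check may differ.

```typescript
// Hypothetical, simplified provider shape; the real Provider type has more fields.
interface ProviderLike {
  id: string
  type: string
  anthropicApiHost?: string
}

type Route = 'direct-anthropic' | 'unified-ai-sdk'

// Assumption: the caller supplies a predicate that says whether a model can be
// served through the provider's Anthropic-compatible endpoint.
type AnthropicModelCheck = (modelId: string) => boolean

function chooseRoute(provider: ProviderLike, modelId: string, supportsAnthropicApi: AnthropicModelCheck): Route {
  // Rule 1: native Anthropic providers always use the Anthropic SDK directly.
  if (provider.type === 'anthropic') return 'direct-anthropic'

  // Rule 2: an Anthropic-compatible host AND a model that supports that API.
  const host = (provider.anthropicApiHost ?? '').trim()
  if (host.length > 0 && supportsAnthropicApi(modelId)) return 'direct-anthropic'

  // Rule 3: everything else goes through the unified AI SDK path.
  return 'unified-ai-sdk'
}
```

Keeping the decision separate from the Express handlers makes the three branches easy to unit test without a running server.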
||||||
/**
|
/**
|
||||||
* @swagger
|
* @swagger
|
||||||
* /v1/messages:
|
* /v1/messages:
|
||||||
@@ -235,7 +424,7 @@ router.post('/', async (req: Request, res: Response) => {
|
|||||||
const provider = modelValidation.provider!
|
const provider = modelValidation.provider!
|
||||||
const modelId = modelValidation.modelId!
|
const modelId = modelValidation.modelId!
|
||||||
|
|
||||||
return handleMessageProcessing({ req, res, provider, request, modelId })
|
return handleMessageProcessing({ res, provider, request, modelId })
|
||||||
} catch (error: any) {
|
} catch (error: any) {
|
||||||
logger.error('Message processing error', { error })
|
logger.error('Message processing error', { error })
|
||||||
const { statusCode, errorResponse } = messagesService.transformError(error)
|
const { statusCode, errorResponse } = messagesService.transformError(error)
|
||||||
@@ -393,7 +582,7 @@ providerRouter.post('/', async (req: Request, res: Response) => {
|
|||||||
|
|
||||||
const request: MessageCreateParams = req.body
|
const request: MessageCreateParams = req.body
|
||||||
|
|
||||||
return handleMessageProcessing({ req, res, provider, request })
|
return handleMessageProcessing({ res, provider, request })
|
||||||
} catch (error: any) {
|
} catch (error: any) {
|
||||||
logger.error('Message processing error', { error })
|
logger.error('Message processing error', { error })
|
||||||
const { statusCode, errorResponse } = messagesService.transformError(error)
|
const { statusCode, errorResponse } = messagesService.transformError(error)
|
||||||
@@ -401,4 +590,132 @@ providerRouter.post('/', async (req: Request, res: Response) => {
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @swagger
|
||||||
|
* /v1/messages/count_tokens:
|
||||||
|
* post:
|
||||||
|
* summary: Count tokens for messages
|
||||||
|
* description: Count tokens for Anthropic Messages API format (required by Claude Code SDK)
|
||||||
|
* tags: [Messages]
|
||||||
|
* requestBody:
|
||||||
|
* required: true
|
||||||
|
* content:
|
||||||
|
* application/json:
|
||||||
|
* schema:
|
||||||
|
* type: object
|
||||||
|
* required:
|
||||||
|
* - model
|
||||||
|
* - messages
|
||||||
|
* properties:
|
||||||
|
* model:
|
||||||
|
* type: string
|
||||||
|
* description: Model ID
|
||||||
|
* messages:
|
||||||
|
* type: array
|
||||||
|
* items:
|
||||||
|
* type: object
|
||||||
|
* system:
|
||||||
|
* type: string
|
||||||
|
* description: System message
|
||||||
|
* responses:
|
||||||
|
* 200:
|
||||||
|
* description: Token count response
|
||||||
|
* content:
|
||||||
|
* application/json:
|
||||||
|
* schema:
|
||||||
|
* type: object
|
||||||
|
* properties:
|
||||||
|
* input_tokens:
|
||||||
|
* type: integer
|
||||||
|
* 400:
|
||||||
|
* description: Bad request
|
||||||
|
*/
|
||||||
|
router.post('/count_tokens', async (req: Request, res: Response) => {
|
||||||
|
try {
|
||||||
|
const { model, messages, system } = req.body
|
||||||
|
|
||||||
|
if (!model) {
|
||||||
|
return res.status(400).json({
|
||||||
|
type: 'error',
|
||||||
|
error: {
|
||||||
|
type: 'invalid_request_error',
|
||||||
|
message: 'model parameter is required'
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!messages || !Array.isArray(messages)) {
|
||||||
|
return res.status(400).json({
|
||||||
|
type: 'error',
|
||||||
|
error: {
|
||||||
|
type: 'invalid_request_error',
|
||||||
|
message: 'messages parameter is required'
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
const estimatedTokens = estimateTokenCount({ messages, system })
|
||||||
|
|
||||||
|
logger.debug('Token count estimated', {
|
||||||
|
model,
|
||||||
|
messageCount: messages.length,
|
||||||
|
estimatedTokens
|
||||||
|
})
|
||||||
|
|
||||||
|
return res.json({
|
||||||
|
input_tokens: estimatedTokens
|
||||||
|
})
|
||||||
|
} catch (error: any) {
|
||||||
|
logger.error('Token counting error', { error })
|
||||||
|
return res.status(500).json({
|
||||||
|
type: 'error',
|
||||||
|
error: {
|
||||||
|
type: 'api_error',
|
||||||
|
message: error.message || 'Internal server error'
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Provider-specific count_tokens endpoint
|
||||||
|
*/
|
||||||
|
providerRouter.post('/count_tokens', async (req: Request, res: Response) => {
|
||||||
|
try {
|
||||||
|
const { model, messages, system } = req.body
|
||||||
|
|
||||||
|
if (!messages || !Array.isArray(messages)) {
|
||||||
|
return res.status(400).json({
|
||||||
|
type: 'error',
|
||||||
|
error: {
|
||||||
|
type: 'invalid_request_error',
|
||||||
|
message: 'messages parameter is required'
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
const estimatedTokens = estimateTokenCount({ messages, system })
|
||||||
|
|
||||||
|
logger.debug('Token count estimated (provider route)', {
|
||||||
|
providerId: req.params.provider,
|
||||||
|
model,
|
||||||
|
messageCount: messages.length,
|
||||||
|
estimatedTokens
|
||||||
|
})
|
||||||
|
|
||||||
|
return res.json({
|
||||||
|
input_tokens: estimatedTokens
|
||||||
|
})
|
||||||
|
} catch (error: any) {
|
||||||
|
logger.error('Token counting error', { error })
|
||||||
|
return res.status(500).json({
|
||||||
|
type: 'error',
|
||||||
|
error: {
|
||||||
|
type: 'api_error',
|
||||||
|
message: error.message || 'Internal server error'
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
})
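Review note: both `count_tokens` routes delegate to `estimateTokenCount`, whose body is not part of this diff. As a rough illustration of what such an estimate could look like, the sketch below uses a characters-per-token heuristic. The ratio of 4, the type names, and `roughTokenEstimate` are all assumptions for illustration, not the helper the PR actually ships.

```typescript
// Hypothetical, simplified shapes for the request body fields used here.
interface TextBlockLike {
  type: string
  text?: string
}

interface MessageLike {
  role: string
  content: string | TextBlockLike[]
}

interface CountInput {
  messages: MessageLike[]
  system?: string | TextBlockLike[]
}

// Assumption: about 4 characters per token. This is a heuristic, not a tokenizer.
const CHARS_PER_TOKEN = 4

// Count characters in either a plain string or a list of text-bearing blocks.
function textLength(content: string | TextBlockLike[] | undefined): number {
  if (!content) return 0
  if (typeof content === 'string') return content.length
  return content.reduce((sum, block) => sum + (block.text ? block.text.length : 0), 0)
}

function roughTokenEstimate(input: CountInput): number {
  const chars = input.messages.reduce((sum, m) => sum + textLength(m.content), textLength(input.system))
  return Math.ceil(chars / CHARS_PER_TOKEN)
}
```

A client that needs exact counts should not rely on a heuristic like this one; it only gives a cheap upper-level approximation.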
|
||||||
|
|
||||||
export { providerRouter as messagesProviderRoutes, router as messagesRoutes }
|
export { providerRouter as messagesProviderRoutes, router as messagesRoutes }
|
||||||
|
|||||||
@@ -2,8 +2,10 @@ import type Anthropic from '@anthropic-ai/sdk'
|
|||||||
import type { MessageCreateParams, MessageStreamEvent } from '@anthropic-ai/sdk/resources'
|
import type { MessageCreateParams, MessageStreamEvent } from '@anthropic-ai/sdk/resources'
|
||||||
import { loggerService } from '@logger'
|
import { loggerService } from '@logger'
|
||||||
import anthropicService from '@main/services/AnthropicService'
|
import anthropicService from '@main/services/AnthropicService'
|
||||||
import { buildClaudeCodeSystemMessage, getSdkClient } from '@shared/anthropic'
|
import { buildClaudeCodeSystemMessage, getSdkClient, sanitizeToolsForAnthropic } from '@shared/anthropic'
|
||||||
import type { Provider } from '@types'
|
import type { Provider } from '@types'
|
||||||
|
import { APICallError, RetryError } from 'ai'
|
||||||
|
import { net } from 'electron'
|
||||||
import type { Response } from 'express'
|
import type { Response } from 'express'
|
||||||
|
|
||||||
const logger = loggerService.withContext('MessagesService')
|
const logger = loggerService.withContext('MessagesService')
|
||||||
@@ -98,11 +100,30 @@ export class MessagesService {
|
|||||||
|
|
||||||
async getClient(provider: Provider, extraHeaders?: Record<string, string | string[]>): Promise<Anthropic> {
|
async getClient(provider: Provider, extraHeaders?: Record<string, string | string[]>): Promise<Anthropic> {
|
||||||
// Create Anthropic client for the provider
|
// Create Anthropic client for the provider
|
||||||
|
// Wrap net.fetch to handle compatibility issues:
|
||||||
|
// 1. net.fetch expects string URLs, not Request objects
|
||||||
|
// 2. net.fetch doesn't support 'agent' option from Node.js http module
|
||||||
|
const electronFetch: typeof globalThis.fetch = async (input: URL | RequestInfo, init?: RequestInit) => {
|
||||||
|
const url = typeof input === 'string' ? input : input instanceof URL ? input.toString() : input.url
|
||||||
|
// Remove unsupported options for Electron's net.fetch
|
||||||
|
if (init) {
|
||||||
|
const initWithAgent = init as RequestInit & { agent?: unknown }
|
||||||
|
delete initWithAgent.agent
|
||||||
|
const headers = new Headers(initWithAgent.headers)
|
||||||
|
if (headers.has('content-length')) {
|
||||||
|
headers.delete('content-length')
|
||||||
|
}
|
||||||
|
initWithAgent.headers = headers
|
||||||
|
return net.fetch(url, initWithAgent)
|
||||||
|
}
|
||||||
|
return net.fetch(url)
|
||||||
|
}
|
||||||
|
const context = { fetch: electronFetch }
|
||||||
if (provider.authType === 'oauth') {
|
if (provider.authType === 'oauth') {
|
||||||
const oauthToken = await anthropicService.getValidAccessToken()
|
const oauthToken = await anthropicService.getValidAccessToken()
|
||||||
return getSdkClient(provider, oauthToken, extraHeaders)
|
return getSdkClient(provider, oauthToken, extraHeaders, context)
|
||||||
}
|
}
|
||||||
return getSdkClient(provider, null, extraHeaders)
|
return getSdkClient(provider, null, extraHeaders, context)
|
||||||
}
|
}
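Review note: the `electronFetch` wrapper in `getClient` does three things before calling `net.fetch`: it reduces the input to a URL string, drops the Node-only `agent` option, and removes any `content-length` header. The sketch below isolates that normalization as a pure function so it can be tested without Electron. `FetchInput`, `InitLike`, and `normalizeFetchArgs` are hypothetical names with simplified types, and unlike the diff this version copies `init` instead of mutating it.

```typescript
// Hypothetical stand-ins for Request and URL objects (only the fields used here).
interface RequestLike {
  url: string
}

interface UrlLike {
  href: string
}

type FetchInput = string | RequestLike | UrlLike

// Simplified init shape; headers are modelled as a plain record.
interface InitLike {
  method?: string
  headers?: Record<string, string>
  body?: string
  agent?: unknown
}

function normalizeFetchArgs(input: FetchInput, init?: InitLike): { url: string; init?: InitLike } {
  // 1. Always hand a plain string URL to the underlying fetch.
  const url = typeof input === 'string' ? input : 'url' in input ? input.url : input.href
  if (!init) return { url }

  // 2. Copy the options and drop the unsupported `agent` field.
  const cleanedInit: InitLike = { ...init }
  delete cleanedInit.agent

  // 3. Drop content-length (case-insensitively) and keep every other header.
  const source = init.headers ?? {}
  const headers: Record<string, string> = {}
  Object.keys(source).forEach((name) => {
    if (name.toLowerCase() !== 'content-length') headers[name] = source[name]
  })
  cleanedInit.headers = headers

  return { url, init: cleanedInit }
}
```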
|
||||||
|
|
||||||
prepareHeaders(headers: Record<string, string | string[] | undefined>): Record<string, string | string[]> {
|
prepareHeaders(headers: Record<string, string | string[] | undefined>): Record<string, string | string[]> {
|
||||||
@@ -127,7 +148,8 @@ export class MessagesService {
|
|||||||
createAnthropicRequest(request: MessageCreateParams, provider: Provider, modelId?: string): MessageCreateParams {
|
createAnthropicRequest(request: MessageCreateParams, provider: Provider, modelId?: string): MessageCreateParams {
|
||||||
const anthropicRequest: MessageCreateParams = {
|
const anthropicRequest: MessageCreateParams = {
|
||||||
...request,
|
...request,
|
||||||
stream: !!request.stream
|
stream: !!request.stream,
|
||||||
|
tools: sanitizeToolsForAnthropic(request.tools)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Override model if provided
|
// Override model if provided
|
||||||
@@ -233,9 +255,71 @@ export class MessagesService {
|
|||||||
}
|
}
|
||||||
|
|
||||||
transformError(error: any): { statusCode: number; errorResponse: ErrorResponse } {
|
transformError(error: any): { statusCode: number; errorResponse: ErrorResponse } {
|
||||||
let statusCode = 500
|
let statusCode: number | undefined = undefined
|
||||||
let errorType = 'api_error'
|
let errorType: string | undefined = undefined
|
||||||
let errorMessage = 'Internal server error'
|
let errorMessage: string | undefined = undefined
|
||||||
|
|
||||||
|
const errorMap: Record<number, string> = {
|
||||||
|
400: 'invalid_request_error',
|
||||||
|
401: 'authentication_error',
|
||||||
|
403: 'forbidden_error',
|
||||||
|
404: 'not_found_error',
|
||||||
|
429: 'rate_limit_error',
|
||||||
|
500: 'internal_server_error'
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle AI SDK RetryError - extract the last error for better error messages
|
||||||
|
if (RetryError.isInstance(error)) {
|
||||||
|
const lastError = error.lastError
|
||||||
|
// If the last error is an APICallError, extract its details
|
||||||
|
if (APICallError.isInstance(lastError)) {
|
||||||
|
statusCode = lastError.statusCode || 502
|
||||||
|
errorMessage = lastError.message
|
||||||
|
return {
|
||||||
|
statusCode,
|
||||||
|
errorResponse: {
|
||||||
|
type: 'error',
|
||||||
|
error: {
|
||||||
|
type: errorMap[statusCode] || 'api_error',
|
||||||
|
message: `${error.reason}: ${errorMessage}`,
|
||||||
|
requestId: lastError.name
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Fallback for other retry errors
|
||||||
|
errorMessage = error.message
|
||||||
|
statusCode = 502
|
||||||
|
return {
|
||||||
|
statusCode,
|
||||||
|
errorResponse: {
|
||||||
|
type: 'error',
|
||||||
|
error: {
|
||||||
|
type: 'api_error',
|
||||||
|
message: errorMessage,
|
||||||
|
requestId: error.name
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (APICallError.isInstance(error)) {
|
||||||
|
statusCode = error.statusCode
|
||||||
|
errorMessage = error.message
|
||||||
|
if (statusCode) {
|
||||||
|
return {
|
||||||
|
statusCode,
|
||||||
|
errorResponse: {
|
||||||
|
type: 'error',
|
||||||
|
error: {
|
||||||
|
type: errorMap[statusCode] || 'api_error',
|
||||||
|
message: errorMessage,
|
||||||
|
requestId: error.name
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
const anthropicStatus = typeof error?.status === 'number' ? error.status : undefined
|
const anthropicStatus = typeof error?.status === 'number' ? error.status : undefined
|
||||||
const anthropicError = error?.error
|
const anthropicError = error?.error
|
||||||
@@ -277,11 +361,11 @@ export class MessagesService {
|
|||||||
typeof errorMessage === 'string' && errorMessage.length > 0 ? errorMessage : 'Internal server error'
|
typeof errorMessage === 'string' && errorMessage.length > 0 ? errorMessage : 'Internal server error'
|
||||||
|
|
||||||
return {
|
return {
|
||||||
statusCode,
|
statusCode: statusCode ?? 500,
|
||||||
errorResponse: {
|
errorResponse: {
|
||||||
type: 'error',
|
type: 'error',
|
||||||
error: {
|
error: {
|
||||||
type: errorType,
|
type: errorType || 'api_error',
|
||||||
message: safeErrorMessage,
|
message: safeErrorMessage,
|
||||||
requestId: error?.request_id
|
requestId: error?.request_id
|
||||||
}
|
}
|
||||||
|
|||||||
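Review note: the new `transformError` branches share one pattern: take the upstream status code if there is one, fall back to a default otherwise, and look the error type up in `errorMap` with `'api_error'` as the catch-all. A minimal sketch of that pattern follows. `UpstreamFailure` and `toErrorResponse` are hypothetical names, and the sketch leaves out the `RetryError` unwrapping and the `requestId` field that the diff also handles.

```typescript
// Same status-to-type table as the errorMap added in the diff.
const ERROR_TYPE_BY_STATUS: Record<number, string> = {
  400: 'invalid_request_error',
  401: 'authentication_error',
  403: 'forbidden_error',
  404: 'not_found_error',
  429: 'rate_limit_error',
  500: 'internal_server_error'
}

// Hypothetical minimal view of an upstream error.
interface UpstreamFailure {
  statusCode?: number
  message?: string
}

interface ErrorResult {
  statusCode: number
  errorResponse: { type: 'error'; error: { type: string; message: string } }
}

function toErrorResponse(failure: UpstreamFailure, fallbackStatus: number = 500): ErrorResult {
  // A missing (or zero) status falls back, mirroring `statusCode || 502` in the diff.
  const statusCode = failure.statusCode || fallbackStatus
  const message = failure.message && failure.message.length > 0 ? failure.message : 'Internal server error'
  return {
    statusCode,
    errorResponse: {
      type: 'error',
      error: { type: ERROR_TYPE_BY_STATUS[statusCode] || 'api_error', message }
    }
  }
}
```

Passing `502` as the fallback reproduces the retry branch, where a failure without a usable status is reported as a bad gateway.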
@@ -1,13 +1,6 @@
|
|||||||
import { isEmpty } from 'lodash'
|
|
||||||
|
|
||||||
import type { ApiModel, ApiModelsFilter, ApiModelsResponse } from '../../../renderer/src/types/apiModels'
|
import type { ApiModel, ApiModelsFilter, ApiModelsResponse } from '../../../renderer/src/types/apiModels'
|
||||||
import { loggerService } from '../../services/LoggerService'
|
import { loggerService } from '../../services/LoggerService'
|
||||||
import {
|
import { getAvailableProviders, listAllAvailableModels, transformModelToOpenAI } from '../utils'
|
||||||
getAvailableProviders,
|
|
||||||
getProviderAnthropicModelChecker,
|
|
||||||
listAllAvailableModels,
|
|
||||||
transformModelToOpenAI
|
|
||||||
} from '../utils'
|
|
||||||
|
|
||||||
const logger = loggerService.withContext('ModelsService')
|
const logger = loggerService.withContext('ModelsService')
|
||||||
|
|
||||||
@@ -20,11 +13,12 @@ export class ModelsService {
|
|||||||
try {
|
try {
|
||||||
logger.debug('Getting available models from providers', { filter })
|
logger.debug('Getting available models from providers', { filter })
|
||||||
|
|
||||||
let providers = await getAvailableProviders()
|
const providers = await getAvailableProviders()
|
||||||
|
|
||||||
if (filter.providerType === 'anthropic') {
|
// Note: When providerType === 'anthropic', we now return ALL available models
|
||||||
providers = providers.filter((p) => p.type === 'anthropic' || !isEmpty(p.anthropicApiHost?.trim()))
|
// because the API Server's unified adapter (AiSdkToAnthropicSSE) can convert
|
||||||
}
|
// any provider's response to Anthropic SSE format. This enables Claude Code Agent
|
||||||
|
// to work with OpenAI, Gemini, and other providers transparently.
|
||||||
|
|
||||||
const models = await listAllAvailableModels(providers)
|
const models = await listAllAvailableModels(providers)
|
||||||
// Use Map to deduplicate models by their full ID (provider:model_id)
|
// Use Map to deduplicate models by their full ID (provider:model_id)
|
||||||
@@ -32,20 +26,11 @@ export class ModelsService {
|
|||||||
|
|
||||||
for (const model of models) {
|
for (const model of models) {
|
||||||
const provider = providers.find((p) => p.id === model.provider)
|
const provider = providers.find((p) => p.id === model.provider)
|
||||||
// logger.debug(`Processing model ${model.id}`)
|
|
||||||
if (!provider) {
|
if (!provider) {
|
||||||
logger.debug(`Skipping model ${model.id} . Reason: Provider not found.`)
|
logger.debug(`Skipping model ${model.id} . Reason: Provider not found.`)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
if (filter.providerType === 'anthropic') {
|
|
||||||
const checker = getProviderAnthropicModelChecker(provider.id)
|
|
||||||
if (!checker(model)) {
|
|
||||||
logger.debug(`Skipping model ${model.id} from ${model.provider}. Reason: Not an Anthropic model.`)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
const openAIModel = transformModelToOpenAI(model, provider)
|
const openAIModel = transformModelToOpenAI(model, provider)
|
||||||
const fullModelId = openAIModel.id // This is already in format "provider:model_id"
|
const fullModelId = openAIModel.id // This is already in format "provider:model_id"
|
||||||
|
|
||||||
|
|||||||
718
src/main/apiServer/services/unified-messages.ts
Normal file
@@ -0,0 +1,718 @@
|
|||||||
|
import type { AnthropicProviderOptions } from '@ai-sdk/anthropic'
|
||||||
|
import type { GoogleGenerativeAIProviderOptions } from '@ai-sdk/google'
|
||||||
|
import type { OpenAIResponsesProviderOptions } from '@ai-sdk/openai'
|
||||||
|
import type { LanguageModelV2Middleware, LanguageModelV2ToolResultOutput } from '@ai-sdk/provider'
|
||||||
|
import type { ProviderOptions, ReasoningPart, ToolCallPart, ToolResultPart } from '@ai-sdk/provider-utils'
|
||||||
|
import type {
|
||||||
|
ImageBlockParam,
|
||||||
|
MessageCreateParams,
|
||||||
|
TextBlockParam,
|
||||||
|
Tool as AnthropicTool
|
||||||
|
} from '@anthropic-ai/sdk/resources/messages'
|
||||||
|
import { type AiPlugin, createExecutor } from '@cherrystudio/ai-core'
|
||||||
|
import { createProvider as createProviderCore } from '@cherrystudio/ai-core/provider'
|
||||||
|
import { loggerService } from '@logger'
|
||||||
|
import { AiSdkToAnthropicSSE, formatSSEDone, formatSSEEvent } from '@main/apiServer/adapters'
|
||||||
|
import { generateSignature as cherryaiGenerateSignature } from '@main/integration/cherryai'
|
||||||
|
import anthropicService from '@main/services/AnthropicService'
|
||||||
|
import copilotService from '@main/services/CopilotService'
|
||||||
|
import { reduxService } from '@main/services/ReduxService'
|
||||||
|
import { isGemini3ModelId } from '@shared/middleware'
|
||||||
|
import {
|
||||||
|
type AiSdkConfig,
|
||||||
|
type AiSdkConfigContext,
|
||||||
|
formatProviderApiHost,
|
||||||
|
initializeSharedProviders,
|
||||||
|
isAnthropicProvider,
|
||||||
|
isGeminiProvider,
|
||||||
|
isOpenAIProvider,
|
||||||
|
type ProviderFormatContext,
|
||||||
|
providerToAiSdkConfig as sharedProviderToAiSdkConfig,
|
||||||
|
resolveActualProvider
|
||||||
|
} from '@shared/provider'
|
||||||
|
import { COPILOT_DEFAULT_HEADERS } from '@shared/provider/constant'
|
||||||
|
import { defaultAppHeaders } from '@shared/utils'
|
||||||
|
import type { Provider } from '@types'
|
||||||
|
import type { ImagePart, JSONValue, ModelMessage, Provider as AiSdkProvider, TextPart, Tool as AiSdkTool } from 'ai'
|
||||||
|
import { simulateStreamingMiddleware, stepCountIs, tool, wrapLanguageModel, zodSchema } from 'ai'
|
||||||
|
import { net } from 'electron'
|
||||||
|
import type { Response } from 'express'
|
||||||
|
import * as z from 'zod'
|
||||||
|
|
||||||
|
import { googleReasoningCache, openRouterReasoningCache } from '../../services/CacheService'
|
||||||
|
|
||||||
|
const logger = loggerService.withContext('UnifiedMessagesService')
|
||||||
|
|
||||||
|
const MAGIC_STRING = 'skip_thought_signature_validator'
|
||||||
|
|
||||||
|
function sanitizeJson(value: unknown): JSONValue {
|
||||||
|
return JSON.parse(JSON.stringify(value))
|
||||||
|
}
|
||||||
|
|
||||||
|
initializeSharedProviders({
|
||||||
|
warn: (message) => logger.warn(message),
|
||||||
|
error: (message, error) => logger.error(message, error)
|
||||||
|
})
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Configuration for unified message streaming
|
||||||
|
*/
|
||||||
|
export interface UnifiedStreamConfig {
|
||||||
|
response: Response
|
||||||
|
provider: Provider
|
||||||
|
modelId: string
|
||||||
|
params: MessageCreateParams
|
||||||
|
onError?: (error: unknown) => void
|
||||||
|
onComplete?: () => void
|
||||||
|
/**
|
||||||
|
* Optional AI SDK middlewares to apply
|
||||||
|
*/
|
||||||
|
middlewares?: LanguageModelV2Middleware[]
|
||||||
|
/**
|
||||||
|
* Optional AI Core plugins to use with the executor
|
||||||
|
*/
|
||||||
|
plugins?: AiPlugin[]
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Configuration for non-streaming message generation
|
||||||
|
*/
|
||||||
|
export interface GenerateUnifiedMessageConfig {
|
||||||
|
provider: Provider
|
||||||
|
modelId: string
|
||||||
|
params: MessageCreateParams
|
||||||
|
middlewares?: LanguageModelV2Middleware[]
|
||||||
|
plugins?: AiPlugin[]
|
||||||
|
}
|
||||||
|
|
||||||
|
function getMainProcessFormatContext(): ProviderFormatContext {
|
||||||
|
const vertexSettings = reduxService.selectSync<{ projectId: string; location: string }>('state.llm.settings.vertexai')
|
||||||
|
return {
|
||||||
|
vertex: {
|
||||||
|
project: vertexSettings?.projectId || 'default-project',
|
||||||
|
location: vertexSettings?.location || 'us-central1'
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const mainProcessSdkContext: AiSdkConfigContext = {
|
||||||
|
getRotatedApiKey: (provider) => {
|
||||||
|
const keys = provider.apiKey.split(',').map((k) => k.trim())
|
||||||
|
return keys[0] || provider.apiKey
|
||||||
|
},
|
||||||
|
fetch: net.fetch as typeof globalThis.fetch
|
||||||
|
}
|
||||||
|
|
||||||
|
function getActualProvider(provider: Provider, modelId: string): Provider {
|
||||||
|
const model = provider.models?.find((m) => m.id === modelId)
|
||||||
|
if (!model) return provider
|
||||||
|
return resolveActualProvider(provider, model)
|
||||||
|
}
|
||||||
|
|
||||||
|
function providerToAiSdkConfig(provider: Provider, modelId: string): AiSdkConfig {
|
||||||
|
const actualProvider = getActualProvider(provider, modelId)
|
||||||
|
const formattedProvider = formatProviderApiHost(actualProvider, getMainProcessFormatContext())
|
||||||
|
return sharedProviderToAiSdkConfig(formattedProvider, modelId, mainProcessSdkContext)
|
||||||
|
}
|
||||||
|
|
||||||
|
function convertAnthropicToolResultToAiSdk(
  content: string | Array<TextBlockParam | ImageBlockParam>
): LanguageModelV2ToolResultOutput {
  if (typeof content === 'string') {
    return { type: 'text', value: content }
  }
  const values: Array<{ type: 'text'; text: string } | { type: 'media'; data: string; mediaType: string }> = []
  for (const block of content) {
    if (block.type === 'text') {
      values.push({ type: 'text', text: block.text })
    } else if (block.type === 'image') {
      values.push({
        type: 'media',
        data: block.source.type === 'base64' ? block.source.data : block.source.url,
        mediaType: block.source.type === 'base64' ? block.source.media_type : 'image/png'
      })
    }
  }
  return { type: 'content', value: values }
}
|
||||||
|
|
||||||
|
// Type alias for JSON Schema (compatible with recursive calls)
|
||||||
|
type JsonSchemaLike = AnthropicTool.InputSchema | Record<string, unknown>
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Convert JSON Schema to Zod schema
|
||||||
|
* This avoids non-standard fields like input_examples that Anthropic doesn't support
|
||||||
|
*/
|
||||||
|
function jsonSchemaToZod(schema: JsonSchemaLike): z.ZodTypeAny {
|
||||||
|
const s = schema as Record<string, unknown>
|
||||||
|
const schemaType = s.type as string | string[] | undefined
|
||||||
|
const enumValues = s.enum as unknown[] | undefined
|
||||||
|
const description = s.description as string | undefined
|
||||||
|
|
||||||
|
// Handle enum first
|
||||||
|
if (enumValues && Array.isArray(enumValues) && enumValues.length > 0) {
|
||||||
|
if (enumValues.every((v) => typeof v === 'string')) {
|
||||||
|
const zodEnum = z.enum(enumValues as [string, ...string[]])
|
||||||
|
return description ? zodEnum.describe(description) : zodEnum
|
||||||
|
}
|
||||||
|
// For non-string enums, use union of literals
|
||||||
|
const literals = enumValues.map((v) => z.literal(v as string | number | boolean))
|
||||||
|
if (literals.length === 1) {
|
||||||
|
return description ? literals[0].describe(description) : literals[0]
|
||||||
|
}
|
||||||
|
const zodUnion = z.union(literals as unknown as [z.ZodTypeAny, z.ZodTypeAny, ...z.ZodTypeAny[]])
|
||||||
|
return description ? zodUnion.describe(description) : zodUnion
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle union types (type: ["string", "null"])
|
||||||
|
if (Array.isArray(schemaType)) {
|
||||||
|
const schemas = schemaType.map((t) => jsonSchemaToZod({ ...s, type: t, enum: undefined }))
|
||||||
|
if (schemas.length === 1) {
|
||||||
|
return schemas[0]
|
||||||
|
}
|
||||||
|
return z.union(schemas as [z.ZodTypeAny, z.ZodTypeAny, ...z.ZodTypeAny[]])
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle by type
|
||||||
|
switch (schemaType) {
|
||||||
|
case 'string': {
|
||||||
|
let zodString = z.string()
|
||||||
|
if (typeof s.minLength === 'number') zodString = zodString.min(s.minLength)
|
||||||
|
if (typeof s.maxLength === 'number') zodString = zodString.max(s.maxLength)
|
||||||
|
if (typeof s.pattern === 'string') zodString = zodString.regex(new RegExp(s.pattern))
|
||||||
|
return description ? zodString.describe(description) : zodString
|
||||||
|
}
|
||||||
|
|
||||||
|
case 'number':
|
||||||
|
case 'integer': {
|
||||||
|
let zodNumber = schemaType === 'integer' ? z.number().int() : z.number()
|
||||||
|
if (typeof s.minimum === 'number') zodNumber = zodNumber.min(s.minimum)
|
||||||
|
if (typeof s.maximum === 'number') zodNumber = zodNumber.max(s.maximum)
|
||||||
|
return description ? zodNumber.describe(description) : zodNumber
|
||||||
|
}
|
||||||
|
|
||||||
|
case 'boolean': {
|
||||||
|
const zodBoolean = z.boolean()
|
||||||
|
return description ? zodBoolean.describe(description) : zodBoolean
|
||||||
|
}
|
||||||
|
|
||||||
|
case 'null':
|
||||||
|
return z.null()
|
||||||
|
|
||||||
|
case 'array': {
|
||||||
|
const items = s.items as Record<string, unknown> | undefined
|
||||||
|
let zodArray = items ? z.array(jsonSchemaToZod(items)) : z.array(z.unknown())
|
||||||
|
if (typeof s.minItems === 'number') zodArray = zodArray.min(s.minItems)
|
||||||
|
if (typeof s.maxItems === 'number') zodArray = zodArray.max(s.maxItems)
|
||||||
|
return description ? zodArray.describe(description) : zodArray
|
||||||
|
}
|
||||||
|
|
||||||
|
case 'object': {
|
||||||
|
const properties = s.properties as Record<string, Record<string, unknown>> | undefined
|
||||||
|
const required = (s.required as string[]) || []
|
||||||
|
|
||||||
|
// Always use z.object() to ensure "properties" field is present in output schema
|
||||||
|
// OpenAI requires explicit properties field even for empty objects
|
||||||
|
const shape: Record<string, z.ZodTypeAny> = {}
|
||||||
|
if (properties) {
|
||||||
|
for (const [key, propSchema] of Object.entries(properties)) {
|
||||||
|
const zodProp = jsonSchemaToZod(propSchema)
|
||||||
|
shape[key] = required.includes(key) ? zodProp : zodProp.optional()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const zodObject = z.object(shape)
|
||||||
|
return description ? zodObject.describe(description) : zodObject
|
||||||
|
}
|
||||||
|
|
||||||
|
default:
|
||||||
|
// Unknown type, use z.unknown()
|
||||||
|
return z.unknown()
|
||||||
|
}
|
||||||
|
}
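Review note: `jsonSchemaToZod` supports a specific subset of JSON Schema: enums first, then union types such as `["string", "null"]`, then a per-type switch where unknown types accept anything. To make that subset concrete without pulling in Zod, the sketch below checks a value directly against the same rules. This is a dependency-free stand-in, not the PR's technique: `SchemaLike` and `matchesSchema` are hypothetical, `pattern`, `minItems`, and `maxItems` are omitted for brevity, and Zod's own parsing behavior may differ in edge cases.

```typescript
// Hypothetical subset of JSON Schema mirroring the fields the converter reads.
interface SchemaLike {
  type?: string | string[]
  enum?: unknown[]
  properties?: Record<string, SchemaLike>
  required?: string[]
  items?: SchemaLike
  minLength?: number
  maxLength?: number
  minimum?: number
  maximum?: number
}

function matchesSchema(schema: SchemaLike, value: unknown): boolean {
  // Enums win over everything else, as in the converter.
  if (schema.enum && schema.enum.length > 0) return schema.enum.indexOf(value) !== -1

  // Union types: the value must satisfy at least one member type.
  if (Array.isArray(schema.type)) {
    return schema.type.some((t) => matchesSchema({ ...schema, type: t, enum: undefined }, value))
  }

  switch (schema.type) {
    case 'string':
      if (typeof value !== 'string') return false
      if (schema.minLength !== undefined && value.length < schema.minLength) return false
      return schema.maxLength === undefined || value.length <= schema.maxLength
    case 'number':
    case 'integer':
      if (typeof value !== 'number') return false
      if (schema.type === 'integer' && Math.floor(value) !== value) return false
      if (schema.minimum !== undefined && value < schema.minimum) return false
      return schema.maximum === undefined || value <= schema.maximum
    case 'boolean':
      return typeof value === 'boolean'
    case 'null':
      return value === null
    case 'array': {
      if (!Array.isArray(value)) return false
      const items = schema.items
      return items ? value.every((item) => matchesSchema(items, item)) : true
    }
    case 'object': {
      if (typeof value !== 'object' || value === null || Array.isArray(value)) return false
      const record = value as Record<string, unknown>
      const props = schema.properties ?? {}
      const required = schema.required ?? []
      // Declared properties are checked; keys outside `properties` are ignored.
      return Object.keys(props).every((key) =>
        record[key] === undefined ? required.indexOf(key) === -1 : matchesSchema(props[key], record[key])
      )
    }
    default:
      // Unknown or missing type accepts any value, like z.unknown().
      return true
  }
}
```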
|
||||||
|
|
||||||
|
function convertAnthropicToolsToAiSdk(tools: MessageCreateParams['tools']): Record<string, AiSdkTool> | undefined {
|
||||||
|
if (!tools || tools.length === 0) return undefined
|
||||||
|
|
||||||
|
const aiSdkTools: Record<string, AiSdkTool> = {}
|
||||||
|
for (const anthropicTool of tools) {
|
||||||
|
if (anthropicTool.type === 'bash_20250124') continue
|
||||||
|
const toolDef = anthropicTool as AnthropicTool
|
||||||
|
const rawSchema = toolDef.input_schema
|
||||||
|
const schema = jsonSchemaToZod(rawSchema)
|
||||||
|
|
||||||
|
// Use tool() with inputSchema (AI SDK v5 API)
|
||||||
|
const aiTool = tool({
|
||||||
|
description: toolDef.description || '',
|
||||||
|
inputSchema: zodSchema(schema)
|
||||||
|
})
|
||||||
|
|
||||||
|
aiSdkTools[toolDef.name] = aiTool
|
||||||
|
}
|
||||||
|
return Object.keys(aiSdkTools).length > 0 ? aiSdkTools : undefined
|
||||||
|
}
|
||||||
|
|
||||||
|
function convertAnthropicToAiMessages(params: MessageCreateParams): ModelMessage[] {
|
||||||
|
const messages: ModelMessage[] = []
|
||||||
|
|
||||||
|
// System message
|
||||||
|
if (params.system) {
|
||||||
|
if (typeof params.system === 'string') {
|
||||||
|
messages.push({ role: 'system', content: params.system })
|
||||||
|
} else if (Array.isArray(params.system)) {
|
||||||
|
const systemText = params.system
|
||||||
|
.filter((block) => block.type === 'text')
|
||||||
|
.map((block) => block.text)
|
||||||
|
.join('\n')
|
||||||
|
if (systemText) {
|
||||||
|
messages.push({ role: 'system', content: systemText })
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const toolCallIdToName = new Map<string, string>()
|
||||||
|
for (const msg of params.messages) {
|
||||||
|
if (Array.isArray(msg.content)) {
|
||||||
|
for (const block of msg.content) {
|
||||||
|
if (block.type === 'tool_use') {
|
||||||
|
toolCallIdToName.set(block.id, block.name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// User/assistant messages
|
||||||
|
for (const msg of params.messages) {
|
||||||
|
if (typeof msg.content === 'string') {
|
||||||
|
messages.push({
|
||||||
|
role: msg.role === 'user' ? 'user' : 'assistant',
|
||||||
|
content: msg.content
|
||||||
|
})
|
||||||
|
} else if (Array.isArray(msg.content)) {
|
||||||
|
const textParts: TextPart[] = []
|
||||||
|
const imageParts: ImagePart[] = []
|
||||||
|
const reasoningParts: ReasoningPart[] = []
|
||||||
|
const toolCallParts: ToolCallPart[] = []
|
||||||
|
const toolResultParts: ToolResultPart[] = []
|
||||||
|
|
||||||
|
for (const block of msg.content) {
|
||||||
|
if (block.type === 'text') {
|
||||||
|
textParts.push({ type: 'text', text: block.text })
|
||||||
|
} else if (block.type === 'thinking') {
|
||||||
|
reasoningParts.push({ type: 'reasoning', text: block.thinking })
|
||||||
|
} else if (block.type === 'redacted_thinking') {
|
||||||
|
reasoningParts.push({ type: 'reasoning', text: block.data })
|
||||||
|
} else if (block.type === 'image') {
|
||||||
|
const source = block.source
|
||||||
|
if (source.type === 'base64') {
|
||||||
|
imageParts.push({ type: 'image', image: `data:${source.media_type};base64,${source.data}` })
|
||||||
|
} else if (source.type === 'url') {
|
||||||
|
imageParts.push({ type: 'image', image: source.url })
|
||||||
|
}
|
||||||
|
} else if (block.type === 'tool_use') {
|
||||||
|
const options: ProviderOptions = {}
|
||||||
|
|
||||||
|
if (isGemini3ModelId(params.model)) {
|
||||||
|
if (googleReasoningCache.get(`google-${block.name}`)) {
|
||||||
|
options.google = {
|
||||||
|
thoughtSignature: MAGIC_STRING
|
||||||
|
}
|
||||||
|
} else if (openRouterReasoningCache.get('openrouter')) {
|
||||||
|
options.openrouter = {
|
||||||
|
reasoning_details: (sanitizeJson(openRouterReasoningCache.get('openrouter')) as JSONValue[]) || []
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
toolCallParts.push({
|
||||||
|
type: 'tool-call',
|
||||||
|
toolName: block.name,
|
||||||
|
toolCallId: block.id,
|
||||||
|
input: block.input,
|
||||||
|
providerOptions: options
|
||||||
|
})
|
||||||
|
} else if (block.type === 'tool_result') {
|
||||||
|
// Look up toolName from the pre-built map (covers cross-message references)
|
||||||
|
const toolName = toolCallIdToName.get(block.tool_use_id) || 'unknown'
|
||||||
|
toolResultParts.push({
|
||||||
|
type: 'tool-result',
|
||||||
|
toolCallId: block.tool_use_id,
|
||||||
|
toolName,
|
||||||
|
output: block.content ? convertAnthropicToolResultToAiSdk(block.content) : { type: 'text', value: '' }
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (toolResultParts.length > 0) {
|
||||||
|
messages.push({ role: 'tool', content: [...toolResultParts] })
|
||||||
|
}
|
||||||
|
|
||||||
|
if (msg.role === 'user') {
|
||||||
|
const userContent = [...textParts, ...imageParts]
|
||||||
|
if (userContent.length > 0) {
|
||||||
|
messages.push({ role: 'user', content: userContent })
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
const assistantContent = [...reasoningParts, ...textParts, ...toolCallParts]
|
||||||
|
if (assistantContent.length > 0) {
|
||||||
|
let providerOptions: ProviderOptions | undefined = undefined
|
||||||
|
if (openRouterReasoningCache.get('openrouter')) {
|
||||||
|
providerOptions = {
|
||||||
|
openrouter: {
|
||||||
|
reasoning_details: (sanitizeJson(openRouterReasoningCache.get('openrouter')) as JSONValue[]) || []
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else if (isGemini3ModelId(params.model)) {
|
||||||
|
providerOptions = {
|
||||||
|
google: {
|
||||||
|
thoughtSignature: MAGIC_STRING
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
messages.push({ role: 'assistant', content: assistantContent, providerOptions })
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return messages
|
||||||
|
}
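The conversion above resolves each `tool_result` to a tool name by first scanning every message for `tool_use` blocks. A minimal sketch of that two-pass lookup follows, assuming simplified local types (`Block`, `Msg`) and a hypothetical helper name `resolveToolResultNames`; none of these are part of the diff, and the real block types come from the Anthropic SDK.

```typescript
// Simplified stand-ins for the Anthropic message types (assumption, not the SDK types).
type Block =
  | { type: 'text'; text: string }
  | { type: 'tool_use'; id: string; name: string }
  | { type: 'tool_result'; tool_use_id: string }
type Msg = { role: 'user' | 'assistant'; content: string | Block[] }

// Hypothetical helper: returns the tool name for every tool_result, in message order.
function resolveToolResultNames(messages: Msg[]): string[] {
  // Pass 1: index every tool_use id so results can reference calls from other messages.
  const idToName = new Map<string, string>()
  for (const msg of messages) {
    if (!Array.isArray(msg.content)) continue
    for (const block of msg.content) {
      if (block.type === 'tool_use') idToName.set(block.id, block.name)
    }
  }

  // Pass 2: look up each tool_result, falling back to 'unknown' when no call matches.
  const names: string[] = []
  for (const msg of messages) {
    if (!Array.isArray(msg.content)) continue
    for (const block of msg.content) {
      if (block.type === 'tool_result') names.push(idToName.get(block.tool_use_id) ?? 'unknown')
    }
  }
  return names
}
```

Building the map before the main loop means a result is matched regardless of which message carried the original call, which is presumably why the diff pre-builds `toolCallIdToName` rather than filling it lazily.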
|
||||||
|
|
||||||
|
interface ExecuteStreamConfig {
|
||||||
|
provider: Provider
|
||||||
|
modelId: string
|
||||||
|
params: MessageCreateParams
|
||||||
|
middlewares?: LanguageModelV2Middleware[]
|
||||||
|
plugins?: AiPlugin[]
|
||||||
|
onEvent?: (event: Parameters<typeof formatSSEEvent>[0]) => void
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Create AI SDK provider instance from config
|
||||||
|
* Similar to renderer's createAiSdkProvider
|
||||||
|
*/
|
||||||
|
async function createAiSdkProvider(config: AiSdkConfig): Promise<AiSdkProvider> {
|
||||||
|
let providerId = config.providerId
|
||||||
|
|
||||||
|
// Handle special provider modes (same as renderer)
|
||||||
|
if (providerId === 'openai' && config.options?.mode === 'chat') {
|
||||||
|
providerId = 'openai-chat'
|
||||||
|
} else if (providerId === 'azure' && config.options?.mode === 'responses') {
|
||||||
|
providerId = 'azure-responses'
|
||||||
|
} else if (providerId === 'cherryin' && config.options?.mode === 'chat') {
|
||||||
|
providerId = 'cherryin-chat'
|
||||||
|
}
|
||||||
|
|
||||||
|
const provider = await createProviderCore(providerId, config.options)
|
||||||
|
|
||||||
|
return provider
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Prepare special provider configuration for providers that need dynamic tokens
|
||||||
|
* Similar to renderer's prepareSpecialProviderConfig
|
||||||
|
*/
|
||||||
|
async function prepareSpecialProviderConfig(provider: Provider, config: AiSdkConfig): Promise<AiSdkConfig> {
|
||||||
|
switch (provider.id) {
|
||||||
|
case 'copilot': {
|
||||||
|
const storedHeaders =
|
||||||
|
((await reduxService.select('state.copilot.defaultHeaders')) as Record<string, string> | null) ?? {}
|
||||||
|
const headers: Record<string, string> = {
|
||||||
|
...COPILOT_DEFAULT_HEADERS,
|
||||||
|
...storedHeaders
|
||||||
|
}
|
||||||
|
|
||||||
|
try {
|
||||||
|
const { token } = await copilotService.getToken(null as any, headers)
|
||||||
|
config.options.apiKey = token
|
||||||
|
const existingHeaders = (config.options.headers as Record<string, string> | undefined) ?? {}
|
||||||
|
config.options.headers = {
|
||||||
|
...headers,
|
||||||
|
...existingHeaders
|
||||||
|
}
|
||||||
|
} catch (error) {
|
||||||
|
logger.error('Failed to get Copilot token', error as Error)
|
||||||
|
throw new Error('Failed to get Copilot token. Please re-authorize Copilot.')
|
||||||
|
}
|
||||||
|
break
|
||||||
|
}
|
||||||
|
case 'anthropic': {
|
||||||
|
if (provider.authType === 'oauth') {
|
||||||
|
try {
|
||||||
|
const oauthToken = await anthropicService.getValidAccessToken()
|
||||||
|
if (!oauthToken) {
|
||||||
|
throw new Error('Anthropic OAuth token not available. Please re-authorize.')
|
||||||
|
}
|
||||||
|
config.options = {
|
||||||
|
...config.options,
|
||||||
|
headers: {
|
||||||
|
...(config.options.headers ? config.options.headers : {}),
|
||||||
|
'Content-Type': 'application/json',
|
||||||
|
'anthropic-version': '2023-06-01',
|
||||||
|
'anthropic-beta': 'oauth-2025-04-20',
|
||||||
|
Authorization: `Bearer ${oauthToken}`
|
||||||
|
},
|
||||||
|
baseURL: 'https://api.anthropic.com/v1',
|
||||||
|
apiKey: ''
|
||||||
|
}
|
||||||
|
} catch (error) {
|
||||||
|
logger.error('Failed to get Anthropic OAuth token', error as Error)
|
||||||
|
throw new Error('Failed to get Anthropic OAuth token. Please re-authorize.')
|
||||||
|
}
|
||||||
|
}
|
||||||
|
break
|
||||||
|
}
|
||||||
|
case 'cherryai': {
|
||||||
|
// Create a signed fetch wrapper for cherryai
|
||||||
|
const baseFetch = net.fetch as typeof globalThis.fetch
|
||||||
|
config.options.fetch = async (url: RequestInfo | URL, options?: RequestInit) => {
|
||||||
|
if (!options?.body) {
|
||||||
|
return baseFetch(url, options)
|
||||||
|
}
|
||||||
|
const signature = cherryaiGenerateSignature({
|
||||||
|
method: 'POST',
|
||||||
|
path: '/chat/completions',
|
||||||
|
query: '',
|
||||||
|
body: JSON.parse(options.body as string)
|
||||||
|
})
|
||||||
|
return baseFetch(url, {
|
||||||
|
...options,
|
||||||
|
headers: {
|
||||||
|
...(options.headers as Record<string, string>),
|
||||||
|
...signature
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return config
|
||||||
|
}
|
||||||
|
|
||||||
|
function mapAnthropicThinkToAISdkProviderOptions(
|
||||||
|
provider: Provider,
|
||||||
|
config: MessageCreateParams['thinking']
|
||||||
|
): ProviderOptions | undefined {
|
||||||
|
if (!config) return undefined
|
||||||
|
if (isAnthropicProvider(provider)) {
|
||||||
|
return {
|
||||||
|
anthropic: {
|
||||||
|
...mapToAnthropicProviderOptions(config)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (isGeminiProvider(provider)) {
|
||||||
|
return {
|
||||||
|
google: {
|
||||||
|
...mapToGeminiProviderOptions(config)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (isOpenAIProvider(provider)) {
|
||||||
|
return {
|
||||||
|
openai: {
|
||||||
|
...mapToOpenAIProviderOptions(config)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return undefined
|
||||||
|
}
|
||||||
|
|
||||||
|
function mapToAnthropicProviderOptions(config: NonNullable<MessageCreateParams['thinking']>): AnthropicProviderOptions {
|
||||||
|
return {
|
||||||
|
thinking: {
|
||||||
|
type: config.type,
|
||||||
|
budgetTokens: config.type === 'enabled' ? config.budget_tokens : undefined
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
function mapToGeminiProviderOptions(
|
||||||
|
config: NonNullable<MessageCreateParams['thinking']>
|
||||||
|
): GoogleGenerativeAIProviderOptions {
|
||||||
|
return {
|
||||||
|
thinkingConfig: {
|
||||||
|
thinkingBudget: config.type === 'enabled' ? config.budget_tokens : -1,
|
||||||
|
includeThoughts: config.type === 'enabled'
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
function mapToOpenAIProviderOptions(
|
||||||
|
config: NonNullable<MessageCreateParams['thinking']>
|
||||||
|
): OpenAIResponsesProviderOptions {
|
||||||
|
return {
|
||||||
|
reasoningEffort: config.type === 'enabled' ? 'high' : 'none'
|
||||||
|
}
|
||||||
|
}
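The three mappers above translate one Anthropic `thinking` setting into a provider-specific option shape. The sketch below condenses that dispatch into a single function under simplified assumptions: `ThinkingConfig`, `ProviderKind`, and `mapThinking` are hypothetical names, and the real code uses the AI SDK option types and the `isAnthropicProvider` / `isGeminiProvider` / `isOpenAIProvider` predicates instead of a string tag.

```typescript
// Simplified stand-in for MessageCreateParams['thinking'] (assumption).
type ThinkingConfig = { type: 'enabled'; budget_tokens: number } | { type: 'disabled' }
// Hypothetical tag replacing the provider predicates used in the diff.
type ProviderKind = 'anthropic' | 'gemini' | 'openai' | 'other'

function mapThinking(
  kind: ProviderKind,
  config?: ThinkingConfig
): Record<string, Record<string, unknown>> | undefined {
  if (!config) return undefined
  switch (kind) {
    case 'anthropic':
      // Anthropic keeps the type and forwards the token budget only when enabled.
      return {
        anthropic: {
          thinking: {
            type: config.type,
            budgetTokens: config.type === 'enabled' ? config.budget_tokens : undefined
          }
        }
      }
    case 'gemini':
      // Mirrors the diff: when thinking is not enabled, -1 is passed as the budget.
      return {
        google: {
          thinkingConfig: {
            thinkingBudget: config.type === 'enabled' ? config.budget_tokens : -1,
            includeThoughts: config.type === 'enabled'
          }
        }
      }
    case 'openai':
      // OpenAI has no token budget here, so the setting collapses to an effort level.
      return { openai: { reasoningEffort: config.type === 'enabled' ? 'high' : 'none' } }
    default:
      // Other providers receive no reasoning options.
      return undefined
  }
}
```

Note that the token budget is lost for OpenAI-style providers in this mapping; only the on/off state survives.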
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Core stream execution function - single source of truth for AI SDK calls
|
||||||
|
*/
|
||||||
|
async function executeStream(config: ExecuteStreamConfig): Promise<AiSdkToAnthropicSSE> {
|
||||||
|
const { provider, modelId, params, middlewares = [], plugins = [], onEvent } = config
|
||||||
|
|
||||||
|
// Convert provider config to AI SDK config
|
||||||
|
let sdkConfig = providerToAiSdkConfig(provider, modelId)
|
||||||
|
|
||||||
|
// Prepare special provider config (Copilot, Anthropic OAuth, etc.)
|
||||||
|
sdkConfig = await prepareSpecialProviderConfig(provider, sdkConfig)
|
||||||
|
|
||||||
|
// Create provider instance and get language model
|
||||||
|
const aiSdkProvider = await createAiSdkProvider(sdkConfig)
|
||||||
|
const baseModel = aiSdkProvider.languageModel(modelId)
|
||||||
|
|
||||||
|
// Apply middlewares if present
|
||||||
|
const model =
|
||||||
|
middlewares.length > 0 && typeof baseModel === 'object'
|
||||||
|
? (wrapLanguageModel({ model: baseModel, middleware: middlewares }) as typeof baseModel)
|
||||||
|
: baseModel
|
||||||
|
|
||||||
|
// Create executor with plugins
|
||||||
|
const executor = createExecutor(sdkConfig.providerId, sdkConfig.options, plugins)
|
||||||
|
|
||||||
|
// Convert messages and tools
|
||||||
|
const coreMessages = convertAnthropicToAiMessages(params)
|
||||||
|
const tools = convertAnthropicToolsToAiSdk(params.tools)
|
||||||
|
|
||||||
|
// Create the adapter
|
||||||
|
const adapter = new AiSdkToAnthropicSSE({
|
||||||
|
model: `${provider.id}:${modelId}`,
|
||||||
|
onEvent: onEvent || (() => {})
|
||||||
|
})
|
||||||
|
|
||||||
|
// Execute stream - pass model object instead of string
|
||||||
|
const result = await executor.streamText({
|
||||||
|
model, // Now passing LanguageModel object, not string
|
||||||
|
messages: coreMessages,
|
||||||
|
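// FIXME: the max_tokens value passed in by Claude Code can exceed some models' limits and needs special handling; this may be easier to fix in v2, since the maintenance cost is rather high right now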
|
||||||
|
// Known affected: Doubao
|
||||||
|
maxOutputTokens: params.max_tokens,
|
||||||
|
temperature: params.temperature,
|
||||||
|
topP: params.top_p,
|
||||||
|
topK: params.top_k,
|
||||||
|
stopSequences: params.stop_sequences,
|
||||||
|
stopWhen: stepCountIs(100),
|
||||||
|
headers: defaultAppHeaders(),
|
||||||
|
tools,
|
||||||
|
providerOptions: mapAnthropicThinkToAISdkProviderOptions(provider, params.thinking)
|
||||||
|
})
|
||||||
|
|
||||||
|
// Process the stream through the adapter
|
||||||
|
await adapter.processStream(result.fullStream)
|
||||||
|
|
||||||
|
return adapter
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Stream a message request using AI SDK executor and convert to Anthropic SSE format
|
||||||
|
*/
|
||||||
|
export async function streamUnifiedMessages(config: UnifiedStreamConfig): Promise<void> {
|
||||||
|
const { response, provider, modelId, params, onError, onComplete, middlewares = [], plugins = [] } = config
|
||||||
|
|
||||||
|
logger.info('Starting unified message stream', {
|
||||||
|
providerId: provider.id,
|
||||||
|
providerType: provider.type,
|
||||||
|
modelId,
|
||||||
|
stream: params.stream,
|
||||||
|
middlewareCount: middlewares.length,
|
||||||
|
pluginCount: plugins.length
|
||||||
|
})
|
||||||
|
|
||||||
|
try {
|
||||||
|
response.setHeader('Content-Type', 'text/event-stream')
|
||||||
|
response.setHeader('Cache-Control', 'no-cache')
|
||||||
|
response.setHeader('Connection', 'keep-alive')
|
||||||
|
response.setHeader('X-Accel-Buffering', 'no')
|
||||||
|
|
||||||
|
await executeStream({
|
||||||
|
provider,
|
||||||
|
modelId,
|
||||||
|
params,
|
||||||
|
middlewares,
|
||||||
|
plugins,
|
||||||
|
onEvent: (event) => {
|
||||||
|
logger.silly('Streaming event', { eventType: event.type })
|
||||||
|
const sseData = formatSSEEvent(event)
|
||||||
|
response.write(sseData)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
// Send done marker
|
||||||
|
response.write(formatSSEDone())
|
||||||
|
response.end()
|
||||||
|
|
||||||
|
logger.info('Unified message stream completed', { providerId: provider.id, modelId })
|
||||||
|
onComplete?.()
|
||||||
|
} catch (error) {
|
||||||
|
logger.error('Error in unified message stream', error as Error, { providerId: provider.id, modelId })
|
||||||
|
onError?.(error)
|
||||||
|
throw error
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Generate a non-streaming message response
|
||||||
|
*
|
||||||
|
* Uses simulateStreamingMiddleware to reuse the same streaming logic,
|
||||||
|
* similar to renderer's ModernAiProvider pattern.
|
||||||
|
*/
|
||||||
|
export async function generateUnifiedMessage(
|
||||||
|
providerOrConfig: Provider | GenerateUnifiedMessageConfig,
|
||||||
|
modelId?: string,
|
||||||
|
params?: MessageCreateParams
|
||||||
|
): Promise<ReturnType<typeof AiSdkToAnthropicSSE.prototype.buildNonStreamingResponse>> {
|
||||||
|
// Support both old signature and new config-based signature
|
||||||
|
let config: GenerateUnifiedMessageConfig
|
||||||
|
if ('provider' in providerOrConfig && 'modelId' in providerOrConfig && 'params' in providerOrConfig) {
|
||||||
|
config = providerOrConfig
|
||||||
|
} else {
|
||||||
|
config = {
|
||||||
|
provider: providerOrConfig as Provider,
|
||||||
|
modelId: modelId!,
|
||||||
|
params: params!
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const { provider, middlewares = [], plugins = [] } = config
|
||||||
|
|
||||||
|
logger.info('Starting unified message generation', {
|
||||||
|
providerId: provider.id,
|
||||||
|
providerType: provider.type,
|
||||||
|
modelId: config.modelId,
|
||||||
|
middlewareCount: middlewares.length,
|
||||||
|
pluginCount: plugins.length
|
||||||
|
})
|
||||||
|
|
||||||
|
try {
|
||||||
|
// Add simulateStreamingMiddleware to reuse streaming logic for non-streaming
|
||||||
|
const allMiddlewares = [simulateStreamingMiddleware(), ...middlewares]
|
||||||
|
|
||||||
|
const adapter = await executeStream({
|
||||||
|
provider,
|
||||||
|
modelId: config.modelId,
|
||||||
|
params: config.params,
|
||||||
|
middlewares: allMiddlewares,
|
||||||
|
plugins
|
||||||
|
})
|
||||||
|
|
||||||
|
const finalResponse = adapter.buildNonStreamingResponse()
|
||||||
|
|
||||||
|
logger.info('Unified message generation completed', {
|
||||||
|
providerId: provider.id,
|
||||||
|
modelId: config.modelId
|
||||||
|
})
|
||||||
|
|
||||||
|
return finalResponse
|
||||||
|
} catch (error) {
|
||||||
|
logger.error('Error in unified message generation', error as Error, {
|
||||||
|
providerId: provider.id,
|
||||||
|
modelId: config.modelId
|
||||||
|
})
|
||||||
|
throw error
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
export default {
|
||||||
|
streamUnifiedMessages,
|
||||||
|
generateUnifiedMessage
|
||||||
|
}
|
||||||
@@ -1,7 +1,7 @@
|
|||||||
import { CacheService } from '@main/services/CacheService'
|
import { CacheService } from '@main/services/CacheService'
|
||||||
import { loggerService } from '@main/services/LoggerService'
|
import { loggerService } from '@main/services/LoggerService'
|
||||||
import { reduxService } from '@main/services/ReduxService'
|
import { reduxService } from '@main/services/ReduxService'
|
||||||
import { isSiliconAnthropicCompatibleModel } from '@shared/config/providers'
|
import { isPpioAnthropicCompatibleModel, isSiliconAnthropicCompatibleModel } from '@shared/config/providers'
|
||||||
import type { ApiModel, Model, Provider } from '@types'
|
import type { ApiModel, Model, Provider } from '@types'
|
||||||
|
|
||||||
const logger = loggerService.withContext('ApiServerUtils')
|
const logger = loggerService.withContext('ApiServerUtils')
|
||||||
@@ -28,10 +28,9 @@ export async function getAvailableProviders(): Promise<Provider[]> {
|
|||||||
return []
|
return []
|
||||||
}
|
}
|
||||||
|
|
||||||
// Support OpenAI and Anthropic type providers for API server
|
// Support all provider types that AI SDK can handle
|
||||||
const supportedProviders = providers.filter(
|
// The unified-messages service uses AI SDK which supports many providers
|
||||||
(p: Provider) => p.enabled && (p.type === 'openai' || p.type === 'anthropic')
|
const supportedProviders = providers.filter((p: Provider) => p.enabled)
|
||||||
)
|
|
||||||
|
|
||||||
// Cache the filtered results
|
// Cache the filtered results
|
||||||
CacheService.set(PROVIDERS_CACHE_KEY, supportedProviders, PROVIDERS_CACHE_TTL)
|
CacheService.set(PROVIDERS_CACHE_KEY, supportedProviders, PROVIDERS_CACHE_TTL)
|
||||||
@@ -160,7 +159,7 @@ export async function validateModelId(model: string): Promise<{
|
|||||||
valid: false,
|
valid: false,
|
||||||
error: {
|
error: {
|
||||||
type: 'provider_not_found',
|
type: 'provider_not_found',
|
||||||
message: `Provider '${providerId}' not found, not enabled, or not supported. Only OpenAI providers are currently supported.`,
|
message: `Provider '${providerId}' not found or not enabled.`,
|
||||||
code: 'provider_not_found'
|
code: 'provider_not_found'
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -262,14 +261,8 @@ export function validateProvider(provider: Provider): boolean {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
// Support OpenAI and Anthropic type providers
|
// AI SDK supports many provider types, so filtering by type is no longer needed
|
||||||
if (provider.type !== 'openai' && provider.type !== 'anthropic') {
|
// The unified-messages service handles all supported types
|
||||||
logger.debug('Provider type not supported', {
|
|
||||||
providerId: provider.id,
|
|
||||||
providerType: provider.type
|
|
||||||
})
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
return true
|
return true
|
||||||
} catch (error: any) {
|
} catch (error: any) {
|
||||||
@@ -290,8 +283,39 @@ export const getProviderAnthropicModelChecker = (providerId: string): ((m: Model
|
|||||||
return (m: Model) => m.id.includes('claude')
|
return (m: Model) => m.id.includes('claude')
|
||||||
case 'silicon':
|
case 'silicon':
|
||||||
return (m: Model) => isSiliconAnthropicCompatibleModel(m.id)
|
return (m: Model) => isSiliconAnthropicCompatibleModel(m.id)
|
||||||
|
case 'ppio':
|
||||||
|
return (m: Model) => isPpioAnthropicCompatibleModel(m.id)
|
||||||
default:
|
default:
|
||||||
// allow all models when checker not configured
|
// allow all models when checker not configured
|
||||||
return () => true
|
return () => true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Check if a specific model is compatible with Anthropic API for a given provider.
|
||||||
|
*
|
||||||
|
* This is used for fine-grained routing decisions at the model level.
|
||||||
|
* For aggregated providers (like Silicon), only certain models support the Anthropic API endpoint.
|
||||||
|
*
|
||||||
|
* @param provider - The provider to check
|
||||||
|
* @param modelId - The model ID to check (without provider prefix)
|
||||||
|
* @returns true if the model supports Anthropic API endpoint
|
||||||
|
*/
|
||||||
|
export function isModelAnthropicCompatible(provider: Provider, modelId: string): boolean {
|
||||||
|
const checker = getProviderAnthropicModelChecker(provider.id)
|
||||||
|
|
||||||
|
const model = provider.models?.find((m) => m.id === modelId)
|
||||||
|
|
||||||
|
if (model) {
|
||||||
|
return checker(model)
|
||||||
|
}
|
||||||
|
|
||||||
|
const minimalModel: Model = {
|
||||||
|
id: modelId,
|
||||||
|
name: modelId,
|
||||||
|
provider: provider.id,
|
||||||
|
group: ''
|
||||||
|
}
|
||||||
|
|
||||||
|
return checker(minimalModel)
|
||||||
|
}
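As a rough illustration of the lookup-then-fallback logic in `isModelAnthropicCompatible`, the sketch below uses trimmed-down `Model` and `Provider` types and a hypothetical allow-list (`SILICON_COMPATIBLE`) in place of `isSiliconAnthropicCompatibleModel`, whose actual rules are not shown in this diff.

```typescript
// Trimmed-down stand-ins for the app's Model and Provider types (assumption).
type Model = { id: string; name: string; provider: string; group: string }
type Provider = { id: string; models?: Model[] }

// Hypothetical allow-list; the real check lives in @shared/config/providers.
const SILICON_COMPATIBLE = new Set(['example/model-a'])

function getChecker(providerId: string): (m: Model) => boolean {
  switch (providerId) {
    case 'silicon':
      return (m) => SILICON_COMPATIBLE.has(m.id)
    default:
      // No checker configured: allow every model.
      return () => true
  }
}

function isCompatible(provider: Provider, modelId: string): boolean {
  const checker = getChecker(provider.id)
  // Prefer the provider's own model record; otherwise synthesize a minimal one
  // so the checker can still run on a model the provider has not listed.
  const model = provider.models?.find((m) => m.id === modelId)
  return checker(model ?? { id: modelId, name: modelId, provider: provider.id, group: '' })
}
```

Synthesizing a minimal model keeps the checker signature uniform, at the cost of checkers only being able to rely on the `id` field for unlisted models.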
|
||||||
|
|||||||
@@ -1,9 +1,19 @@
|
|||||||
|
import type { ReasoningDetailUnion } from '@main/apiServer/adapters/openrouter'
|
||||||
|
|
||||||
interface CacheItem<T> {
|
interface CacheItem<T> {
|
||||||
data: T
|
data: T
|
||||||
timestamp: number
|
timestamp: number
|
||||||
duration: number
|
duration: number
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Interface for reasoning cache
|
||||||
|
*/
|
||||||
|
export interface IReasoningCache<T> {
|
||||||
|
set(key: string, value: T): void
|
||||||
|
get(key: string): T | undefined
|
||||||
|
}
|
||||||
|
|
||||||
export class CacheService {
|
export class CacheService {
|
||||||
private static cache: Map<string, CacheItem<any>> = new Map()
|
private static cache: Map<string, CacheItem<any>> = new Map()
|
||||||
|
|
||||||
@@ -72,3 +82,14 @@ export class CacheService {
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Singleton cache instances using CacheService
|
||||||
|
export const googleReasoningCache: IReasoningCache<string> = {
|
||||||
|
set: (key, value) => CacheService.set(`google-reasoning:${key}`, value, 30 * 60 * 1000),
|
||||||
|
get: (key) => CacheService.get(`google-reasoning:${key}`) || undefined
|
||||||
|
}
|
||||||
|
|
||||||
|
export const openRouterReasoningCache: IReasoningCache<ReasoningDetailUnion[]> = {
|
||||||
|
set: (key, value) => CacheService.set(`openrouter-reasoning:${key}`, value, 30 * 60 * 1000),
|
||||||
|
get: (key) => CacheService.get(`openrouter-reasoning:${key}`) || undefined
|
||||||
|
}
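The two reasoning caches above are thin wrappers that prefix the key and store the value with a 30-minute lifetime. A self-contained sketch of the same idea follows, assuming an in-memory `Map` instead of `CacheService` and a hypothetical factory name `createTtlCache`; the injectable `now` parameter is added here only to make expiry easy to exercise.

```typescript
interface ReasoningCache<T> {
  set(key: string, value: T): void
  get(key: string): T | undefined
}

// Hypothetical factory: namespaces keys with a prefix and expires entries after ttlMs.
function createTtlCache<T>(prefix: string, ttlMs: number, now: () => number = Date.now): ReasoningCache<T> {
  const store = new Map<string, { data: T; expiresAt: number }>()
  return {
    set: (key, value) => {
      store.set(`${prefix}:${key}`, { data: value, expiresAt: now() + ttlMs })
    },
    get: (key) => {
      const fullKey = `${prefix}:${key}`
      const item = store.get(fullKey)
      if (!item) return undefined
      if (now() >= item.expiresAt) {
        // Drop expired entries lazily on read.
        store.delete(fullKey)
        return undefined
      }
      return item.data
    }
  }
}

// Usage mirroring the 30-minute lifetime used in the diff.
const exampleCache = createTtlCache<string>('google-reasoning', 30 * 60 * 1000)
exampleCache.set('google-search', 'signature')
```

The prefix keeps entries from different providers apart even when they share one backing store, which matches how the diff reuses the single static `CacheService` map.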
|
||||||
|
|||||||
@@ -548,6 +548,17 @@ class CodeToolsService {
|
|||||||
logger.debug(`Environment variables:`, Object.keys(env))
|
logger.debug(`Environment variables:`, Object.keys(env))
|
||||||
logger.debug(`Options:`, options)
|
logger.debug(`Options:`, options)
|
||||||
|
|
||||||
|
// Validate directory exists before proceeding
|
||||||
|
if (!directory || !fs.existsSync(directory)) {
|
||||||
|
const errorMessage = `Directory does not exist: ${directory}`
|
||||||
|
logger.error(errorMessage)
|
||||||
|
return {
|
||||||
|
success: false,
|
||||||
|
message: errorMessage,
|
||||||
|
command: ''
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
const packageName = await this.getPackageName(cliTool)
|
const packageName = await this.getPackageName(cliTool)
|
||||||
const bunPath = await this.getBunPath()
|
const bunPath = await this.getBunPath()
|
||||||
const executableName = await this.getCliExecutableName(cliTool)
|
const executableName = await this.getCliExecutableName(cliTool)
|
||||||
@@ -709,6 +720,7 @@ class CodeToolsService {
|
|||||||
// Build bat file content, including debug information
|
// Build bat file content, including debug information
|
||||||
const batContent = [
|
const batContent = [
|
||||||
'@echo off',
|
'@echo off',
|
||||||
|
'chcp 65001 >nul 2>&1', // Switch to UTF-8 code page for international path support
|
||||||
`title ${cliTool} - Cherry Studio`, // Set window title in bat file
|
`title ${cliTool} - Cherry Studio`, // Set window title in bat file
|
||||||
'echo ================================================',
|
'echo ================================================',
|
||||||
'echo Cherry Studio CLI Tool Launcher',
|
'echo Cherry Studio CLI Tool Launcher',
|
||||||
|
|||||||
@@ -620,7 +620,7 @@ class McpService {
|
|||||||
tools.map((tool: SDKTool) => {
|
tools.map((tool: SDKTool) => {
|
||||||
const serverTool: MCPTool = {
|
const serverTool: MCPTool = {
|
||||||
...tool,
|
...tool,
|
||||||
id: buildFunctionCallToolName(server.name, tool.name),
|
id: buildFunctionCallToolName(server.name, tool.name, server.id),
|
||||||
serverId: server.id,
|
serverId: server.id,
|
||||||
serverName: server.name,
|
serverName: server.name,
|
||||||
type: 'mcp'
|
type: 'mcp'
|
||||||
|
|||||||
@@ -87,6 +87,7 @@ export class ClaudeStreamState {
|
|||||||
private pendingUsage: PendingUsageState = {}
|
private pendingUsage: PendingUsageState = {}
|
||||||
private pendingToolCalls = new Map<string, PendingToolCall>()
|
private pendingToolCalls = new Map<string, PendingToolCall>()
|
||||||
private stepActive = false
|
private stepActive = false
|
||||||
|
private _streamFinished = false
|
||||||
|
|
||||||
constructor(options: ClaudeStreamStateOptions) {
|
constructor(options: ClaudeStreamStateOptions) {
|
||||||
this.logger = loggerService.withContext('ClaudeStreamState')
|
this.logger = loggerService.withContext('ClaudeStreamState')
|
||||||
@@ -289,6 +290,16 @@ export class ClaudeStreamState {
|
|||||||
getNamespacedToolCallId(rawToolCallId: string): string {
|
getNamespacedToolCallId(rawToolCallId: string): string {
|
||||||
return buildNamespacedToolCallId(this.agentSessionId, rawToolCallId)
|
return buildNamespacedToolCallId(this.agentSessionId, rawToolCallId)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/** Marks the stream as finished (either completed or errored). */
|
||||||
|
markFinished(): void {
|
||||||
|
this._streamFinished = true
|
||||||
|
}
|
||||||
|
|
||||||
|
/** Returns true if the stream has already emitted a terminal event. */
|
||||||
|
isFinished(): boolean {
|
||||||
|
return this._streamFinished
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
export type { PendingToolCall }
|
export type { PendingToolCall }
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
// src/main/services/agents/services/claudecode/index.ts
|
// src/main/services/agents/services/claudecode/index.ts
|
||||||
import { EventEmitter } from 'node:events'
|
import { EventEmitter } from 'node:events'
|
||||||
import { createRequire } from 'node:module'
|
import { createRequire } from 'node:module'
|
||||||
|
import path from 'node:path'
|
||||||
|
|
||||||
import type {
|
import type {
|
||||||
CanUseTool,
|
CanUseTool,
|
||||||
@@ -84,18 +85,14 @@ class ClaudeCodeService implements AgentServiceInterface {
|
|||||||
})
|
})
|
||||||
return aiStream
|
return aiStream
|
||||||
}
|
}
|
||||||
if (
|
// Validate provider has required configuration
|
||||||
(modelInfo.provider?.type !== 'anthropic' &&
|
// Note: We no longer restrict to anthropic type only - the API Server's unified adapter
|
||||||
(modelInfo.provider?.anthropicApiHost === undefined || modelInfo.provider.anthropicApiHost.trim() === '')) ||
|
// handles format conversion for any provider type (OpenAI, Gemini, etc.)
|
||||||
modelInfo.provider.apiKey === ''
|
if (!modelInfo.provider?.apiKey) {
|
||||||
) {
|
logger.error('Provider API key is missing', { modelInfo })
|
||||||
logger.error('Anthropic provider configuration is missing', {
|
|
||||||
modelInfo
|
|
||||||
})
|
|
||||||
|
|
||||||
aiStream.emit('data', {
|
aiStream.emit('data', {
|
||||||
type: 'error',
|
type: 'error',
|
||||||
error: new Error(`Invalid provider type '${modelInfo.provider?.type}'. Expected 'anthropic' provider type.`)
|
error: new Error(`Provider '${modelInfo.provider?.id}' is missing API key configuration.`)
|
||||||
})
|
})
|
||||||
return aiStream
|
return aiStream
|
||||||
}
|
}
|
||||||
@@ -106,22 +103,25 @@ class ClaudeCodeService implements AgentServiceInterface {
|
|||||||
Object.entries(loginShellEnv).filter(([key]) => !key.toLowerCase().endsWith('_proxy'))
|
Object.entries(loginShellEnv).filter(([key]) => !key.toLowerCase().endsWith('_proxy'))
|
||||||
) as Record<string, string>
|
) as Record<string, string>
|
||||||
|
|
||||||
|
// Route through local API Server which handles format conversion via unified adapter
|
||||||
|
// This enables Claude Code Agent to work with any provider (OpenAI, Gemini, etc.)
|
||||||
|
// The API Server converts AI SDK responses to Anthropic SSE format transparently
|
||||||
const env = {
|
const env = {
|
||||||
...loginShellEnvWithoutProxies,
|
...loginShellEnvWithoutProxies,
|
||||||
// TODO: fix the proxy api server
|
ANTHROPIC_API_KEY: apiConfig.apiKey,
|
||||||
// ANTHROPIC_API_KEY: apiConfig.apiKey,
|
ANTHROPIC_AUTH_TOKEN: apiConfig.apiKey,
|
||||||
// ANTHROPIC_AUTH_TOKEN: apiConfig.apiKey,
|
ANTHROPIC_BASE_URL: `http://${apiConfig.host}:${apiConfig.port}/${modelInfo.provider.id}`,
|
||||||
// ANTHROPIC_BASE_URL: `http://${apiConfig.host}:${apiConfig.port}/${modelInfo.provider.id}`,
|
|
||||||
ANTHROPIC_API_KEY: modelInfo.provider.apiKey,
|
|
||||||
ANTHROPIC_AUTH_TOKEN: modelInfo.provider.apiKey,
|
|
||||||
ANTHROPIC_BASE_URL: modelInfo.provider.anthropicApiHost?.trim() || modelInfo.provider.apiHost,
|
|
||||||
ANTHROPIC_MODEL: modelInfo.modelId,
|
ANTHROPIC_MODEL: modelInfo.modelId,
|
||||||
ANTHROPIC_DEFAULT_OPUS_MODEL: modelInfo.modelId,
|
ANTHROPIC_DEFAULT_OPUS_MODEL: modelInfo.modelId,
|
||||||
ANTHROPIC_DEFAULT_SONNET_MODEL: modelInfo.modelId,
|
ANTHROPIC_DEFAULT_SONNET_MODEL: modelInfo.modelId,
|
||||||
// TODO: support set small model in UI
|
// TODO: support set small model in UI
|
||||||
ANTHROPIC_DEFAULT_HAIKU_MODEL: modelInfo.modelId,
|
ANTHROPIC_DEFAULT_HAIKU_MODEL: modelInfo.modelId,
|
||||||
ELECTRON_RUN_AS_NODE: '1',
|
ELECTRON_RUN_AS_NODE: '1',
|
||||||
ELECTRON_NO_ATTACH_CONSOLE: '1'
|
ELECTRON_NO_ATTACH_CONSOLE: '1',
|
||||||
|
// Set CLAUDE_CONFIG_DIR to app's userData directory to avoid path encoding issues
|
||||||
|
// on Windows when the username contains non-ASCII characters (e.g., Chinese characters)
|
||||||
|
// This prevents the SDK from using the user's home directory which may have encoding problems
|
||||||
|
CLAUDE_CONFIG_DIR: path.join(app.getPath('userData'), '.claude')
|
||||||
}
|
}
|
||||||
|
|
||||||
const errorChunks: string[] = []
|
const errorChunks: string[] = []
|
||||||
@@ -534,6 +534,19 @@ class ClaudeCodeService implements AgentServiceInterface {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Skip emitting error if stream already finished (error was handled via result message)
|
||||||
|
if (streamState.isFinished()) {
|
||||||
|
logger.debug('SDK process exited after stream finished, skipping duplicate error event', {
|
||||||
|
duration,
|
||||||
|
error: errorObj instanceof Error ? { name: errorObj.name, message: errorObj.message } : String(errorObj)
|
||||||
|
})
|
||||||
|
// Still emit complete to signal stream end
|
||||||
|
stream.emit('data', {
|
||||||
|
type: 'complete'
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
errorChunks.push(errorObj instanceof Error ? errorObj.message : String(errorObj))
|
errorChunks.push(errorObj instanceof Error ? errorObj.message : String(errorObj))
|
||||||
const errorMessage = errorChunks.join('\n\n')
|
const errorMessage = errorChunks.join('\n\n')
|
||||||
logger.error('SDK query failed', {
|
logger.error('SDK query failed', {
|
||||||
|
|||||||
@@ -121,7 +121,7 @@ export function transformSDKMessageToStreamParts(sdkMessage: SDKMessage, state:
|
|||||||
case 'system':
|
case 'system':
|
||||||
return handleSystemMessage(sdkMessage)
|
return handleSystemMessage(sdkMessage)
|
||||||
case 'result':
|
case 'result':
|
||||||
return handleResultMessage(sdkMessage)
|
return handleResultMessage(sdkMessage, state)
|
||||||
default:
|
default:
|
||||||
logger.warn('Unknown SDKMessage type', { type: (sdkMessage as any).type })
|
logger.warn('Unknown SDKMessage type', { type: (sdkMessage as any).type })
|
||||||
return []
|
return []
|
||||||
@@ -193,6 +193,30 @@ function handleAssistantMessage(
|
|||||||
}
|
}
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
case 'thinking':
|
||||||
|
case 'redacted_thinking': {
|
||||||
|
const thinkingText = block.type === 'thinking' ? block.thinking : block.data
|
||||||
|
if (thinkingText) {
|
||||||
|
const id = generateMessageId()
|
||||||
|
chunks.push({
|
||||||
|
type: 'reasoning-start',
|
||||||
|
id,
|
||||||
|
providerMetadata
|
||||||
|
})
|
||||||
|
chunks.push({
|
||||||
|
type: 'reasoning-delta',
|
||||||
|
id,
|
||||||
|
text: thinkingText,
|
||||||
|
providerMetadata
|
||||||
|
})
|
||||||
|
chunks.push({
|
||||||
|
type: 'reasoning-end',
|
||||||
|
id,
|
||||||
|
providerMetadata
|
||||||
|
})
|
||||||
|
}
|
||||||
|
break
|
||||||
|
}
|
||||||
case 'tool_use':
|
case 'tool_use':
|
||||||
handleAssistantToolUse(block as ToolUseContent, providerMetadata, state, chunks)
|
handleAssistantToolUse(block as ToolUseContent, providerMetadata, state, chunks)
|
||||||
break
|
break
|
||||||
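Review note: the new `thinking` / `redacted_thinking` case expands one content block into a start, delta, end triple that shares a single id. A minimal sketch of that mapping follows. The types `ThinkingBlock` and `ReasoningChunk` and the function `thinkingToChunks` are simplified assumptions for illustration; they drop `providerMetadata` and take the id as a parameter instead of calling `generateMessageId()`.

```typescript
// Simplified, assumed shapes; the real SDK types carry more fields.
type ThinkingBlock =
  | { type: 'thinking'; thinking: string }
  | { type: 'redacted_thinking'; data: string }

type ReasoningChunk =
  | { type: 'reasoning-start'; id: string }
  | { type: 'reasoning-delta'; id: string; text: string }
  | { type: 'reasoning-end'; id: string }

// Expand one thinking block into a start/delta/end triple sharing one id.
function thinkingToChunks(block: ThinkingBlock, id: string): ReasoningChunk[] {
  const text = block.type === 'thinking' ? block.thinking : block.data
  if (!text) {
    // Empty blocks produce no chunks, mirroring the `if (thinkingText)` guard.
    return []
  }
  return [
    { type: 'reasoning-start', id },
    { type: 'reasoning-delta', id, text },
    { type: 'reasoning-end', id }
  ]
}
```

Emitting the whole text as one delta keeps non-streamed blocks compatible with consumers that expect the streaming chunk sequence.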
@@ -445,7 +469,11 @@ function handleStreamEvent(
|
|||||||
case 'content_block_stop': {
|
case 'content_block_stop': {
|
||||||
const block = state.closeBlock(event.index)
|
const block = state.closeBlock(event.index)
|
||||||
if (!block) {
|
if (!block) {
|
||||||
logger.warn('Received content_block_stop for unknown index', { index: event.index })
|
// Some providers (e.g., Gemini) send content via assistant message before stream events,
|
||||||
|
// so the block may not exist in state. This is expected behavior, not an error.
|
||||||
|
logger.debug('Received content_block_stop for unknown index (may be from non-streaming content)', {
|
||||||
|
index: event.index
|
||||||
|
})
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -679,7 +707,13 @@ function handleSystemMessage(message: Extract<SDKMessage, { type: 'system' }>):
|
|||||||
* Successful runs yield a `finish` frame with aggregated usage metrics, while
|
* Successful runs yield a `finish` frame with aggregated usage metrics, while
|
||||||
* failures are surfaced as `error` frames.
|
* failures are surfaced as `error` frames.
|
||||||
*/
|
*/
|
||||||
function handleResultMessage(message: Extract<SDKMessage, { type: 'result' }>): AgentStreamPart[] {
|
function handleResultMessage(
|
||||||
|
message: Extract<SDKMessage, { type: 'result' }>,
|
||||||
|
state: ClaudeStreamState
|
||||||
|
): AgentStreamPart[] {
|
||||||
|
// Mark stream as finished to prevent duplicate error events when SDK process exits
|
||||||
|
state.markFinished()
|
||||||
|
|
||||||
const chunks: AgentStreamPart[] = []
|
const chunks: AgentStreamPart[] = []
|
||||||
|
|
||||||
let usage: LanguageModelUsage | undefined
|
let usage: LanguageModelUsage | undefined
|
||||||
@@ -691,26 +725,33 @@ function handleResultMessage(message: Extract<SDKMessage, { type: 'result' }>):
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (message.subtype === 'success') {
|
chunks.push({
|
||||||
chunks.push({
|
type: 'finish',
|
||||||
type: 'finish',
|
totalUsage: usage ?? emptyUsage,
|
||||||
totalUsage: usage ?? emptyUsage,
|
finishReason: mapClaudeCodeFinishReason(message.subtype),
|
||||||
finishReason: mapClaudeCodeFinishReason(message.subtype),
|
providerMetadata: {
|
||||||
providerMetadata: {
|
...sdkMessageToProviderMetadata(message),
|
||||||
...sdkMessageToProviderMetadata(message),
|
usage: message.usage,
|
||||||
usage: message.usage,
|
durationMs: message.duration_ms,
|
||||||
durationMs: message.duration_ms,
|
costUsd: message.total_cost_usd,
|
||||||
costUsd: message.total_cost_usd,
|
raw: message
|
||||||
raw: message
|
}
|
||||||
}
|
} as AgentStreamPart)
|
||||||
} as AgentStreamPart)
|
if (message.subtype !== 'success') {
|
||||||
} else {
|
|
||||||
chunks.push({
|
chunks.push({
|
||||||
type: 'error',
|
type: 'error',
|
||||||
error: {
|
error: {
|
||||||
message: `${message.subtype}: Process failed after ${message.num_turns} turns`
|
message: `${message.subtype}: Process failed after ${message.num_turns} turns`
|
||||||
}
|
}
|
||||||
} as AgentStreamPart)
|
} as AgentStreamPart)
|
||||||
|
} else {
|
||||||
|
if (message.is_error) {
|
||||||
|
const errorMatch = message.result.match(/\{.*\}/)
|
||||||
|
if (errorMatch) {
|
||||||
|
const errorDetail = JSON.parse(errorMatch[0])
|
||||||
|
chunks.push(errorDetail)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return chunks
|
return chunks
|
||||||
}
|
}
|
||||||
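Review note: the new `is_error` branch calls `JSON.parse` on whatever `/\{.*\}/` matches, which throws if the matched text is not valid JSON, and `.` does not cross newlines. One possible hardening is sketched below. `extractEmbeddedJson` is a hypothetical helper name, not something in this PR, and whether the parsed object is a valid `AgentStreamPart` would still need a separate check.

```typescript
// Hypothetical helper: pull the first-to-last brace span out of a message and parse it.
function extractEmbeddedJson(text: string): Record<string, unknown> | undefined {
  // [\s\S] also spans newlines, unlike `.` without the `s` flag.
  const match = text.match(/\{[\s\S]*\}/)
  if (!match) {
    return undefined
  }
  try {
    const parsed: unknown = JSON.parse(match[0])
    // Only accept plain objects; anything else is treated as "no detail found".
    return typeof parsed === 'object' && parsed !== null && !Array.isArray(parsed)
      ? (parsed as Record<string, unknown>)
      : undefined
  } catch {
    // Malformed JSON in the result text should not crash the stream handler.
    return undefined
  }
}
```

A caller could then push the detail only when the helper returns an object, and fall back to a generic error chunk otherwise.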
|
|||||||
196
src/main/utils/__tests__/mcp.test.ts
Normal file
@@ -0,0 +1,196 @@
|
|||||||
|
import { describe, expect, it } from 'vitest'
|
||||||
|
|
||||||
|
import { buildFunctionCallToolName } from '../mcp'
|
||||||
|
|
||||||
|
describe('buildFunctionCallToolName', () => {
|
||||||
|
describe('basic functionality', () => {
|
||||||
|
it('should combine server name and tool name', () => {
|
||||||
|
const result = buildFunctionCallToolName('github', 'search_issues')
|
||||||
|
expect(result).toContain('github')
|
||||||
|
expect(result).toContain('search')
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should sanitize names by replacing dashes with underscores', () => {
|
||||||
|
const result = buildFunctionCallToolName('my-server', 'my-tool')
|
||||||
|
// Input dashes are replaced, but the separator between server and tool is a dash
|
||||||
|
expect(result).toBe('my_serv-my_tool')
|
||||||
|
expect(result).toContain('_')
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should handle empty server names gracefully', () => {
|
||||||
|
const result = buildFunctionCallToolName('', 'tool')
|
||||||
|
expect(result).toBeTruthy()
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('uniqueness with serverId', () => {
|
||||||
|
it('should generate different IDs for same server name but different serverIds', () => {
|
||||||
|
const serverId1 = 'server-id-123456'
|
||||||
|
const serverId2 = 'server-id-789012'
|
||||||
|
const serverName = 'github'
|
||||||
|
const toolName = 'search_repos'
|
||||||
|
|
||||||
|
const result1 = buildFunctionCallToolName(serverName, toolName, serverId1)
|
||||||
|
const result2 = buildFunctionCallToolName(serverName, toolName, serverId2)
|
||||||
|
|
||||||
|
expect(result1).not.toBe(result2)
|
||||||
|
expect(result1).toContain('123456')
|
||||||
|
expect(result2).toContain('789012')
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should generate same ID when serverId is not provided', () => {
|
||||||
|
const serverName = 'github'
|
||||||
|
const toolName = 'search_repos'
|
||||||
|
|
||||||
|
const result1 = buildFunctionCallToolName(serverName, toolName)
|
||||||
|
const result2 = buildFunctionCallToolName(serverName, toolName)
|
||||||
|
|
||||||
|
expect(result1).toBe(result2)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should include serverId suffix when provided', () => {
|
||||||
|
const serverId = 'abc123def456'
|
||||||
|
const result = buildFunctionCallToolName('server', 'tool', serverId)
|
||||||
|
|
||||||
|
// Should include last 6 chars of serverId
|
||||||
|
expect(result).toContain('def456')
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('character sanitization', () => {
|
||||||
|
it('should replace invalid characters with underscores', () => {
|
||||||
|
const result = buildFunctionCallToolName('test@server', 'tool#name')
|
||||||
|
expect(result).not.toMatch(/[@#]/)
|
||||||
|
expect(result).toMatch(/^[a-zA-Z0-9_-]+$/)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should ensure name starts with a letter', () => {
|
||||||
|
const result = buildFunctionCallToolName('123server', '456tool')
|
||||||
|
expect(result).toMatch(/^[a-zA-Z]/)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should handle consecutive underscores/dashes', () => {
|
||||||
|
const result = buildFunctionCallToolName('my--server', 'my__tool')
|
||||||
|
expect(result).not.toMatch(/[_-]{2,}/)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('length constraints', () => {
|
||||||
|
it('should truncate names longer than 63 characters', () => {
|
||||||
|
const longServerName = 'a'.repeat(50)
|
||||||
|
const longToolName = 'b'.repeat(50)
|
||||||
|
const result = buildFunctionCallToolName(longServerName, longToolName, 'id123456')
|
||||||
|
|
||||||
|
expect(result.length).toBeLessThanOrEqual(63)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should not end with underscore or dash after truncation', () => {
|
||||||
|
const longServerName = 'a'.repeat(50)
|
||||||
|
const longToolName = 'b'.repeat(50)
|
||||||
|
const result = buildFunctionCallToolName(longServerName, longToolName, 'id123456')
|
||||||
|
|
||||||
|
expect(result).not.toMatch(/[_-]$/)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should preserve serverId suffix even with long server/tool names', () => {
|
||||||
|
const longServerName = 'a'.repeat(50)
|
||||||
|
const longToolName = 'b'.repeat(50)
|
||||||
|
const serverId = 'server-id-xyz789'
|
||||||
|
|
||||||
|
const result = buildFunctionCallToolName(longServerName, longToolName, serverId)
|
||||||
|
|
||||||
|
// The suffix should be preserved and not truncated
|
||||||
|
expect(result).toContain('xyz789')
|
||||||
|
expect(result.length).toBeLessThanOrEqual(63)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should ensure two long-named servers with different IDs produce different results', () => {
|
||||||
|
const longServerName = 'a'.repeat(50)
|
||||||
|
const longToolName = 'b'.repeat(50)
|
||||||
|
const serverId1 = 'server-id-abc123'
|
||||||
|
const serverId2 = 'server-id-def456'
|
||||||
|
|
||||||
|
const result1 = buildFunctionCallToolName(longServerName, longToolName, serverId1)
|
||||||
|
const result2 = buildFunctionCallToolName(longServerName, longToolName, serverId2)
|
||||||
|
|
||||||
|
// Both should be within limit
|
||||||
|
expect(result1.length).toBeLessThanOrEqual(63)
|
||||||
|
expect(result2.length).toBeLessThanOrEqual(63)
|
||||||
|
|
||||||
|
// They should be different due to preserved suffix
|
||||||
|
expect(result1).not.toBe(result2)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('edge cases with serverId', () => {
|
||||||
|
it('should handle serverId with only non-alphanumeric characters', () => {
|
||||||
|
const serverId = '------' // All dashes
|
||||||
|
const result = buildFunctionCallToolName('server', 'tool', serverId)
|
||||||
|
|
||||||
|
// Should still produce a valid unique suffix via fallback hash
|
||||||
|
expect(result).toBeTruthy()
|
||||||
|
expect(result.length).toBeLessThanOrEqual(63)
|
||||||
|
expect(result).toMatch(/^[a-zA-Z][a-zA-Z0-9_-]*$/)
|
||||||
|
// Should have a suffix (underscore followed by something)
|
||||||
|
expect(result).toMatch(/_[a-z0-9]+$/)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should produce different results for different non-alphanumeric serverIds', () => {
|
||||||
|
const serverId1 = '------'
|
||||||
|
const serverId2 = '!!!!!!'
|
||||||
|
|
||||||
|
const result1 = buildFunctionCallToolName('server', 'tool', serverId1)
|
||||||
|
const result2 = buildFunctionCallToolName('server', 'tool', serverId2)
|
||||||
|
|
||||||
|
// Should be different because the hash fallback produces different values
|
||||||
|
expect(result1).not.toBe(result2)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should treat an empty string serverId the same as undefined', () => {
|
||||||
|
const resultWithEmpty = buildFunctionCallToolName('server', 'tool', '')
|
||||||
|
const resultWithUndefined = buildFunctionCallToolName('server', 'tool', undefined)
|
||||||
|
|
||||||
|
// Empty string is falsy, so both should behave the same (no suffix)
|
||||||
|
expect(resultWithEmpty).toBe(resultWithUndefined)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should handle serverId with mixed alphanumeric and special chars', () => {
|
||||||
|
const serverId = 'ab@#cd' // Mixed chars, last 6 chars contain some alphanumeric
|
||||||
|
const result = buildFunctionCallToolName('server', 'tool', serverId)
|
||||||
|
|
||||||
|
// Should extract alphanumeric chars: 'abcd' from 'ab@#cd'
|
||||||
|
expect(result).toContain('abcd')
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('real-world scenarios', () => {
|
||||||
|
it('should handle GitHub MCP server instances correctly', () => {
|
||||||
|
const serverName = 'github'
|
||||||
|
const toolName = 'search_repositories'
|
||||||
|
|
||||||
|
const githubComId = 'server-github-com-abc123'
|
||||||
|
const gheId = 'server-ghe-internal-xyz789'
|
||||||
|
|
||||||
|
const tool1 = buildFunctionCallToolName(serverName, toolName, githubComId)
|
||||||
|
const tool2 = buildFunctionCallToolName(serverName, toolName, gheId)
|
||||||
|
|
||||||
|
// Should be different
|
||||||
|
expect(tool1).not.toBe(tool2)
|
||||||
|
|
||||||
|
// Both should be valid identifiers
|
||||||
|
expect(tool1).toMatch(/^[a-zA-Z][a-zA-Z0-9_-]*$/)
|
||||||
|
expect(tool2).toMatch(/^[a-zA-Z][a-zA-Z0-9_-]*$/)
|
||||||
|
|
||||||
|
// Both should be <= 63 chars
|
||||||
|
expect(tool1.length).toBeLessThanOrEqual(63)
|
||||||
|
expect(tool2.length).toBeLessThanOrEqual(63)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should handle tool names that already include server name prefix', () => {
|
||||||
|
const result = buildFunctionCallToolName('github', 'github_search_repos')
|
||||||
|
expect(result).toBeTruthy()
|
||||||
|
// Should not double the server name
|
||||||
|
expect(result.split('github').length - 1).toBeLessThanOrEqual(2)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
@@ -1,7 +1,25 @@
|
|||||||
export function buildFunctionCallToolName(serverName: string, toolName: string) {
|
export function buildFunctionCallToolName(serverName: string, toolName: string, serverId?: string) {
|
||||||
const sanitizedServer = serverName.trim().replace(/-/g, '_')
|
const sanitizedServer = serverName.trim().replace(/-/g, '_')
|
||||||
const sanitizedTool = toolName.trim().replace(/-/g, '_')
|
const sanitizedTool = toolName.trim().replace(/-/g, '_')
|
||||||
|
|
||||||
|
// Calculate suffix first to reserve space for it
|
||||||
|
// Suffix format: "_" + 6 alphanumeric chars = 7 chars total
|
||||||
|
let serverIdSuffix = ''
|
||||||
|
if (serverId) {
|
||||||
|
// Take the last 6 characters of the serverId for brevity
|
||||||
|
serverIdSuffix = serverId.slice(-6).replace(/[^a-zA-Z0-9]/g, '')
|
||||||
|
|
||||||
|
// Fallback: if suffix becomes empty (all non-alphanumeric chars), use a simple hash
|
||||||
|
if (!serverIdSuffix) {
|
||||||
|
const hash = serverId.split('').reduce((acc, char) => acc + char.charCodeAt(0), 0)
|
||||||
|
serverIdSuffix = hash.toString(36).slice(-6) || 'x'
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reserve space for suffix when calculating max base name length
|
||||||
|
const SUFFIX_LENGTH = serverIdSuffix ? serverIdSuffix.length + 1 : 0 // +1 for underscore
|
||||||
|
const MAX_BASE_LENGTH = 63 - SUFFIX_LENGTH
|
||||||
|
|
||||||
// Combine server name and tool name
|
// Combine server name and tool name
|
||||||
let name = sanitizedTool
|
let name = sanitizedTool
|
||||||
if (!sanitizedTool.includes(sanitizedServer.slice(0, 7))) {
|
if (!sanitizedTool.includes(sanitizedServer.slice(0, 7))) {
|
||||||
@@ -20,9 +38,9 @@ export function buildFunctionCallToolName(serverName: string, toolName: string)
|
|||||||
// Remove consecutive underscores/dashes (optional improvement)
|
// Remove consecutive underscores/dashes (optional improvement)
|
||||||
name = name.replace(/[_-]{2,}/g, '_')
|
name = name.replace(/[_-]{2,}/g, '_')
|
||||||
|
|
||||||
// Truncate to 63 characters maximum
|
// Truncate base name BEFORE adding suffix to ensure suffix is never cut off
|
||||||
if (name.length > 63) {
|
if (name.length > MAX_BASE_LENGTH) {
|
||||||
name = name.slice(0, 63)
|
name = name.slice(0, MAX_BASE_LENGTH)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Handle edge case: ensure we still have a valid name if truncation left invalid chars at edges
|
// Handle edge case: ensure we still have a valid name if truncation left invalid chars at edges
|
||||||
@@ -30,5 +48,10 @@ export function buildFunctionCallToolName(serverName: string, toolName: string)
|
|||||||
name = name.slice(0, -1)
|
name = name.slice(0, -1)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Now append the suffix - it will always fit within 63 chars
|
||||||
|
if (serverIdSuffix) {
|
||||||
|
name = `${name}_${serverIdSuffix}`
|
||||||
|
}
|
||||||
|
|
||||||
return name
|
return name
|
||||||
}
|
}
|
||||||
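Review note: the key change in `buildFunctionCallToolName` is that the suffix is computed first and the base name is truncated to whatever space remains, so the suffix is never cut off. The sketch below isolates only that part. `deriveSuffix` and `appendSuffix` are hypothetical names for illustration, and the sketch skips the sanitization and trailing-character trimming that the real function performs.

```typescript
const MAX_NAME_LENGTH = 63

// Last six characters of the id, keeping only alphanumerics; falls back to a
// base-36 sum of char codes when nothing alphanumeric is left.
function deriveSuffix(serverId: string): string {
  const tail = serverId.slice(-6).replace(/[^a-zA-Z0-9]/g, '')
  if (tail) {
    return tail
  }
  const sum = serverId.split('').reduce((acc, ch) => acc + ch.charCodeAt(0), 0)
  return sum.toString(36).slice(-6) || 'x'
}

// Truncate the base name first, then append "_" + suffix so the total stays within the limit.
function appendSuffix(baseName: string, serverId?: string): string {
  if (!serverId) {
    return baseName.slice(0, MAX_NAME_LENGTH)
  }
  const suffix = deriveSuffix(serverId)
  const maxBase = MAX_NAME_LENGTH - (suffix.length + 1) // +1 for the underscore
  return `${baseName.slice(0, maxBase)}_${suffix}`
}
```

One thing to keep in mind about the fallback: a plain sum of char codes ignores character order, so ids that are permutations of each other (for example `-!` and `!-`) map to the same suffix.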
|
|||||||
@@ -212,8 +212,9 @@ export class ToolCallChunkHandler {
|
|||||||
description: toolName,
|
description: toolName,
|
||||||
type: 'builtin'
|
type: 'builtin'
|
||||||
} as BaseTool
|
} as BaseTool
|
||||||
} else if ((mcpTool = this.mcpTools.find((t) => t.name === toolName) as MCPTool)) {
|
} else if ((mcpTool = this.mcpTools.find((t) => t.id === toolName) as MCPTool)) {
|
||||||
// If this is a client-executed MCP tool, keep the existing logic
|
// If this is a client-executed MCP tool, keep the existing logic
|
||||||
|
// toolName is mcpTool.id (registered with id as key in convertMcpToolsToAiSdkTools)
|
||||||
logger.info(`[ToolCallChunkHandler] Handling client-side MCP tool: ${toolName}`)
|
logger.info(`[ToolCallChunkHandler] Handling client-side MCP tool: ${toolName}`)
|
||||||
// mcpTool = this.mcpTools.find((t) => t.name === toolName) as MCPTool
|
// mcpTool = this.mcpTools.find((t) => t.name === toolName) as MCPTool
|
||||||
// if (!mcpTool) {
|
// if (!mcpTool) {
|
||||||
|
|||||||
@@ -405,6 +405,9 @@ export abstract class BaseApiClient<
|
|||||||
if (!param.name?.trim()) {
|
if (!param.name?.trim()) {
|
||||||
return acc
|
return acc
|
||||||
}
|
}
|
||||||
|
// Parse JSON type parameters (Legacy API clients)
|
||||||
|
// Related: src/renderer/src/pages/settings/AssistantSettings/AssistantModelSettings.tsx:133-148
|
||||||
|
// The UI stores JSON type params as strings, this function parses them before sending to API
|
||||||
if (param.type === 'json') {
|
if (param.type === 'json') {
|
||||||
const value = param.value as string
|
const value = param.value as string
|
||||||
if (value === 'undefined') {
|
if (value === 'undefined') {
|
||||||
|
|||||||
@@ -46,6 +46,7 @@ import type {
|
|||||||
GeminiSdkRawOutput,
|
GeminiSdkRawOutput,
|
||||||
GeminiSdkToolCall
|
GeminiSdkToolCall
|
||||||
} from '@renderer/types/sdk'
|
} from '@renderer/types/sdk'
|
||||||
|
import { getTrailingApiVersion, withoutTrailingApiVersion } from '@renderer/utils'
|
||||||
import { isToolUseModeFunction } from '@renderer/utils/assistant'
|
import { isToolUseModeFunction } from '@renderer/utils/assistant'
|
||||||
import {
|
import {
|
||||||
geminiFunctionCallToMcpTool,
|
geminiFunctionCallToMcpTool,
|
||||||
@@ -163,6 +164,10 @@ export class GeminiAPIClient extends BaseApiClient<
|
|||||||
return models
|
return models
|
||||||
}
|
}
|
||||||
|
|
||||||
|
override getBaseURL(): string {
|
||||||
|
return withoutTrailingApiVersion(super.getBaseURL())
|
||||||
|
}
|
||||||
|
|
||||||
override async getSdkInstance() {
|
override async getSdkInstance() {
|
||||||
if (this.sdkInstance) {
|
if (this.sdkInstance) {
|
||||||
return this.sdkInstance
|
return this.sdkInstance
|
||||||
@@ -188,6 +193,13 @@ export class GeminiAPIClient extends BaseApiClient<
|
|||||||
if (this.provider.isVertex) {
|
if (this.provider.isVertex) {
|
||||||
return 'v1'
|
return 'v1'
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Extract trailing API version from the URL
|
||||||
|
const trailingVersion = getTrailingApiVersion(this.provider.apiHost || '')
|
||||||
|
if (trailingVersion) {
|
||||||
|
return trailingVersion
|
||||||
|
}
|
||||||
|
|
||||||
return 'v1beta'
|
return 'v1beta'
|
||||||
}
|
}
|
||||||
|
|
||||||
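Review note: the Gemini client now strips a trailing API version from the base URL and reuses it as the API version, falling back to `v1beta`. The implementations of `getTrailingApiVersion` and `withoutTrailingApiVersion` are not shown in this diff, so the sketch below uses hypothetical stand-ins (`trailingVersionOf`, `stripTrailingVersion`, `resolveApiVersion`) and an assumed pattern for what counts as a version segment.

```typescript
// Assumed shape of a trailing version segment: "/v1", "/v1beta", "/v2alpha", with an optional final slash.
const TRAILING_VERSION = /\/(v\d+(?:alpha|beta)?)\/?$/

// Hypothetical stand-in: return the trailing version segment, if any.
function trailingVersionOf(url: string): string | undefined {
  return url.match(TRAILING_VERSION)?.[1]
}

// Hypothetical stand-in: return the URL without its trailing version segment.
function stripTrailingVersion(url: string): string {
  return url.replace(TRAILING_VERSION, '')
}

// Mirrors the order of checks in the diff: Vertex first, then the URL, then the default.
function resolveApiVersion(apiHost: string, isVertex: boolean): string {
  if (isVertex) {
    return 'v1'
  }
  return trailingVersionOf(apiHost) ?? 'v1beta'
}
```

Splitting the host this way lets a user-supplied host such as `https://example.com/v1` select the version without the version appearing twice in the final request URL.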
|
|||||||
@@ -24,7 +24,7 @@ export class VertexAPIClient extends GeminiAPIClient {
|
|||||||
this.anthropicVertexClient = new AnthropicVertexClient(provider)
|
this.anthropicVertexClient = new AnthropicVertexClient(provider)
|
||||||
// If a plain Provider was passed in, convert it to a VertexProvider
|
// If a plain Provider was passed in, convert it to a VertexProvider
|
||||||
if (isVertexProvider(provider)) {
|
if (isVertexProvider(provider)) {
|
||||||
this.vertexProvider = provider
|
this.vertexProvider = provider as VertexProvider
|
||||||
} else {
|
} else {
|
||||||
this.vertexProvider = createVertexProvider(provider)
|
this.vertexProvider = createVertexProvider(provider)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ import type { MCPTool } from '@renderer/types'
|
|||||||
import { type Assistant, type Message, type Model, type Provider, SystemProviderIds } from '@renderer/types'
|
import { type Assistant, type Message, type Model, type Provider, SystemProviderIds } from '@renderer/types'
|
||||||
import type { Chunk } from '@renderer/types/chunk'
|
import type { Chunk } from '@renderer/types/chunk'
|
||||||
import { isSupportEnableThinkingProvider } from '@renderer/utils/provider'
|
import { isSupportEnableThinkingProvider } from '@renderer/utils/provider'
|
||||||
|
import { openrouterReasoningMiddleware, skipGeminiThoughtSignatureMiddleware } from '@shared/middleware'
|
||||||
import type { LanguageModelMiddleware } from 'ai'
|
import type { LanguageModelMiddleware } from 'ai'
|
||||||
import { extractReasoningMiddleware, simulateStreamingMiddleware } from 'ai'
|
import { extractReasoningMiddleware, simulateStreamingMiddleware } from 'ai'
|
||||||
import { isEmpty } from 'lodash'
|
import { isEmpty } from 'lodash'
|
||||||
@@ -13,9 +14,7 @@ import { getAiSdkProviderId } from '../provider/factory'
|
|||||||
import { isOpenRouterGeminiGenerateImageModel } from '../utils/image'
|
import { isOpenRouterGeminiGenerateImageModel } from '../utils/image'
|
||||||
import { noThinkMiddleware } from './noThinkMiddleware'
|
import { noThinkMiddleware } from './noThinkMiddleware'
|
||||||
import { openrouterGenerateImageMiddleware } from './openrouterGenerateImageMiddleware'
|
import { openrouterGenerateImageMiddleware } from './openrouterGenerateImageMiddleware'
|
||||||
import { openrouterReasoningMiddleware } from './openrouterReasoningMiddleware'
|
|
||||||
import { qwenThinkingMiddleware } from './qwenThinkingMiddleware'
|
import { qwenThinkingMiddleware } from './qwenThinkingMiddleware'
|
||||||
import { skipGeminiThoughtSignatureMiddleware } from './skipGeminiThoughtSignatureMiddleware'
|
|
||||||
import { toolChoiceMiddleware } from './toolChoiceMiddleware'
|
import { toolChoiceMiddleware } from './toolChoiceMiddleware'
|
||||||
|
|
||||||
const logger = loggerService.withContext('AiSdkMiddlewareBuilder')
|
const logger = loggerService.withContext('AiSdkMiddlewareBuilder')
|
||||||
|
|||||||
@@ -1,50 +0,0 @@
|
|||||||
import type { LanguageModelV2StreamPart } from '@ai-sdk/provider'
|
|
||||||
import type { LanguageModelMiddleware } from 'ai'
|
|
||||||
|
|
||||||
/**
|
|
||||||
* https://openrouter.ai/docs/docs/best-practices/reasoning-tokens#example-preserving-reasoning-blocks-with-openrouter-and-claude
|
|
||||||
*
|
|
||||||
* @returns LanguageModelMiddleware - a middleware that filters redacted blocks
|
|
||||||
*/
|
|
||||||
export function openrouterReasoningMiddleware(): LanguageModelMiddleware {
|
|
||||||
const REDACTED_BLOCK = '[REDACTED]'
|
|
||||||
return {
|
|
||||||
middlewareVersion: 'v2',
|
|
||||||
wrapGenerate: async ({ doGenerate }) => {
|
|
||||||
const { content, ...rest } = await doGenerate()
|
|
||||||
const modifiedContent = content.map((part) => {
|
|
||||||
if (part.type === 'reasoning' && part.text.includes(REDACTED_BLOCK)) {
|
|
||||||
return {
|
|
||||||
...part,
|
|
||||||
text: part.text.replace(REDACTED_BLOCK, '')
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return part
|
|
||||||
})
|
|
||||||
return { content: modifiedContent, ...rest }
|
|
||||||
},
|
|
||||||
wrapStream: async ({ doStream }) => {
|
|
||||||
const { stream, ...rest } = await doStream()
|
|
||||||
return {
|
|
||||||
stream: stream.pipeThrough(
|
|
||||||
new TransformStream<LanguageModelV2StreamPart, LanguageModelV2StreamPart>({
|
|
||||||
transform(
|
|
||||||
chunk: LanguageModelV2StreamPart,
|
|
||||||
controller: TransformStreamDefaultController<LanguageModelV2StreamPart>
|
|
||||||
) {
|
|
||||||
if (chunk.type === 'reasoning-delta' && chunk.delta.includes(REDACTED_BLOCK)) {
|
|
||||||
controller.enqueue({
|
|
||||||
...chunk,
|
|
||||||
delta: chunk.delta.replace(REDACTED_BLOCK, '')
|
|
||||||
})
|
|
||||||
} else {
|
|
||||||
controller.enqueue(chunk)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
})
|
|
||||||
),
|
|
||||||
...rest
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -1,36 +0,0 @@
|
|||||||
import type { LanguageModelMiddleware } from 'ai'
|
|
||||||
|
|
||||||
/**
|
|
||||||
* skip Gemini Thought Signature Middleware
|
|
||||||
* Due to the complexity of multi-model client requests (which can switch to other models mid-process),
|
|
||||||
* it was decided to add a skip for all Gemini3 thinking signatures via middleware.
|
|
||||||
* @param aiSdkId AI SDK Provider ID
|
|
||||||
* @returns LanguageModelMiddleware
|
|
||||||
*/
|
|
||||||
export function skipGeminiThoughtSignatureMiddleware(aiSdkId: string): LanguageModelMiddleware {
|
|
||||||
const MAGIC_STRING = 'skip_thought_signature_validator'
|
|
||||||
return {
|
|
||||||
middlewareVersion: 'v2',
|
|
||||||
|
|
||||||
transformParams: async ({ params }) => {
|
|
||||||
const transformedParams = { ...params }
|
|
||||||
// Process messages in prompt
|
|
||||||
if (transformedParams.prompt && Array.isArray(transformedParams.prompt)) {
|
|
||||||
transformedParams.prompt = transformedParams.prompt.map((message) => {
|
|
||||||
if (typeof message.content !== 'string') {
|
|
||||||
for (const part of message.content) {
|
|
||||||
const googleOptions = part?.providerOptions?.[aiSdkId]
|
|
||||||
if (googleOptions?.thoughtSignature) {
|
|
||||||
googleOptions.thoughtSignature = MAGIC_STRING
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return message
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
return transformedParams
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -7,7 +7,7 @@ import { isAwsBedrockProvider, isVertexProvider } from '@renderer/utils/provider
|
|||||||
// https://docs.claude.com/en/docs/build-with-claude/extended-thinking#interleaved-thinking
|
// https://docs.claude.com/en/docs/build-with-claude/extended-thinking#interleaved-thinking
|
||||||
const INTERLEAVED_THINKING_HEADER = 'interleaved-thinking-2025-05-14'
|
const INTERLEAVED_THINKING_HEADER = 'interleaved-thinking-2025-05-14'
|
||||||
// https://docs.claude.com/en/docs/build-with-claude/context-windows#1m-token-context-window
|
// https://docs.claude.com/en/docs/build-with-claude/context-windows#1m-token-context-window
|
||||||
const CONTEXT_100M_HEADER = 'context-1m-2025-08-07'
|
// const CONTEXT_100M_HEADER = 'context-1m-2025-08-07'
|
||||||
// https://docs.cloud.google.com/vertex-ai/generative-ai/docs/partner-models/claude/web-search
|
// https://docs.cloud.google.com/vertex-ai/generative-ai/docs/partner-models/claude/web-search
|
||||||
const WEBSEARCH_HEADER = 'web-search-2025-03-05'
|
const WEBSEARCH_HEADER = 'web-search-2025-03-05'
|
||||||
|
|
||||||
@@ -17,7 +17,7 @@ export function addAnthropicHeaders(assistant: Assistant, model: Model): string[
|
|||||||
if (
|
if (
|
||||||
isClaude45ReasoningModel(model) &&
|
isClaude45ReasoningModel(model) &&
|
||||||
isToolUseModeFunction(assistant) &&
|
isToolUseModeFunction(assistant) &&
|
||||||
!(isVertexProvider(provider) && isAwsBedrockProvider(provider))
|
!(isVertexProvider(provider) || isAwsBedrockProvider(provider))
|
||||||
) {
|
) {
|
||||||
anthropicHeaders.push(INTERLEAVED_THINKING_HEADER)
|
anthropicHeaders.push(INTERLEAVED_THINKING_HEADER)
|
||||||
}
|
}
|
||||||
@@ -25,7 +25,9 @@ export function addAnthropicHeaders(assistant: Assistant, model: Model): string[
|
|||||||
if (isVertexProvider(provider) && assistant.enableWebSearch) {
|
if (isVertexProvider(provider) && assistant.enableWebSearch) {
|
||||||
anthropicHeaders.push(WEBSEARCH_HEADER)
|
anthropicHeaders.push(WEBSEARCH_HEADER)
|
||||||
}
|
}
|
||||||
anthropicHeaders.push(CONTEXT_100M_HEADER)
|
// We may add it by user preference in assistant.settings instead of always adding it.
|
||||||
|
// See #11540, #11397
|
||||||
|
// anthropicHeaders.push(CONTEXT_100M_HEADER)
|
||||||
}
|
}
|
||||||
return anthropicHeaders
|
return anthropicHeaders
|
||||||
}
|
}
|
||||||
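Review note: the switch from `&&` to `||` in `addAnthropicHeaders` changes which providers receive the interleaved-thinking header. The truth table below makes the difference explicit. `oldGuard` and `newGuard` are illustrative names that model only the provider part of the condition, not the model and tool-use checks.

```typescript
// Provider part of the condition before the change: !(vertex && bedrock)
function oldGuard(isVertex: boolean, isBedrock: boolean): boolean {
  return !(isVertex && isBedrock)
}

// Provider part of the condition after the change: !(vertex || bedrock)
function newGuard(isVertex: boolean, isBedrock: boolean): boolean {
  return !(isVertex || isBedrock)
}

const cases = [
  [false, false],
  [true, false],
  [false, true],
  [true, true]
] as const

// One row per combination, showing whether the header would be added.
const table = cases.map(([isVertex, isBedrock]) => ({
  isVertex,
  isBedrock,
  before: oldGuard(isVertex, isBedrock),
  after: newGuard(isVertex, isBedrock)
}))

console.table(table)
```

The old guard only blocked the header when a provider was both Vertex and Bedrock at once, so a Vertex-only or Bedrock-only provider still received it. The new guard blocks the header for either one.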
|
|||||||
@@ -28,6 +28,7 @@ import { type Assistant, type MCPTool, type Provider } from '@renderer/types'
|
|||||||
import type { StreamTextParams } from '@renderer/types/aiCoreTypes'
|
import type { StreamTextParams } from '@renderer/types/aiCoreTypes'
|
||||||
import { mapRegexToPatterns } from '@renderer/utils/blacklistMatchPattern'
|
import { mapRegexToPatterns } from '@renderer/utils/blacklistMatchPattern'
|
||||||
import { replacePromptVariables } from '@renderer/utils/prompt'
|
import { replacePromptVariables } from '@renderer/utils/prompt'
|
||||||
|
import { isAwsBedrockProvider } from '@renderer/utils/provider'
|
||||||
import type { ModelMessage, Tool } from 'ai'
|
import type { ModelMessage, Tool } from 'ai'
|
||||||
import { stepCountIs } from 'ai'
|
import { stepCountIs } from 'ai'
|
||||||
|
|
||||||
@@ -175,7 +176,7 @@ export async function buildStreamTextParams(
|
|||||||
|
|
||||||
let headers: Record<string, string | undefined> = options.requestOptions?.headers ?? {}
|
let headers: Record<string, string | undefined> = options.requestOptions?.headers ?? {}
|
||||||
|
|
||||||
if (isAnthropicModel(model)) {
|
if (isAnthropicModel(model) && !isAwsBedrockProvider(provider)) {
|
||||||
const newBetaHeaders = { 'anthropic-beta': addAnthropicHeaders(assistant, model).join(',') }
|
const newBetaHeaders = { 'anthropic-beta': addAnthropicHeaders(assistant, model).join(',') }
|
||||||
headers = combineHeaders(headers, newBetaHeaders)
|
headers = combineHeaders(headers, newBetaHeaders)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
import type { Provider } from '@renderer/types'
|
import type { Model, Provider } from '@renderer/types'
|
||||||
import { describe, expect, it, vi } from 'vitest'
|
import { describe, expect, it, vi } from 'vitest'
|
||||||
|
|
||||||
import { getAiSdkProviderId } from '../factory'
|
import { getAiSdkProviderId } from '../factory'
|
||||||
@@ -68,6 +68,18 @@ function createTestProvider(id: string, type: string): Provider {
|
|||||||
} as Provider
|
} as Provider
|
||||||
}
|
}
|
||||||
|
|
||||||
|
function createAzureProvider(id: string, apiVersion?: string, model?: string): Provider {
|
||||||
|
return {
|
||||||
|
id,
|
||||||
|
type: 'azure-openai',
|
||||||
|
name: `Azure Test ${id}`,
|
||||||
|
apiKey: 'azure-test-key',
|
||||||
|
apiHost: 'azure-test-host',
|
||||||
|
apiVersion,
|
||||||
|
models: [{ id: model || 'gpt-4' } as Model]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
describe('Integrated Provider Registry', () => {
|
describe('Integrated Provider Registry', () => {
|
||||||
describe('Provider ID Resolution', () => {
|
describe('Provider ID Resolution', () => {
|
||||||
it('should resolve openrouter provider correctly', () => {
|
it('should resolve openrouter provider correctly', () => {
|
||||||
@@ -111,6 +123,24 @@ describe('Integrated Provider Registry', () => {
|
|||||||
const result = getAiSdkProviderId(unknownProvider)
|
const result = getAiSdkProviderId(unknownProvider)
|
||||||
expect(result).toBe('unknown-provider')
|
expect(result).toBe('unknown-provider')
|
||||||
})
|
})
|
||||||
|
|
||||||
|
it('should handle Azure OpenAI providers correctly', () => {
|
||||||
|
const azureProvider = createAzureProvider('azure-test', '2024-02-15', 'gpt-4o')
|
||||||
|
const result = getAiSdkProviderId(azureProvider)
|
||||||
|
expect(result).toBe('azure')
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should handle Azure OpenAI providers response endpoint correctly', () => {
|
||||||
|
const azureProvider = createAzureProvider('azure-test', 'v1', 'gpt-4o')
|
||||||
|
const result = getAiSdkProviderId(azureProvider)
|
||||||
|
expect(result).toBe('azure-responses')
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should handle Azure provider Claude Models', () => {
|
||||||
|
const provider = createTestProvider('azure-anthropic', 'anthropic')
|
||||||
|
const result = getAiSdkProviderId(provider)
|
||||||
|
expect(result).toBe('azure-anthropic')
|
||||||
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
describe('Backward Compatibility', () => {
|
describe('Backward Compatibility', () => {
|
||||||
|
|||||||
@@ -24,7 +24,17 @@ vi.mock('@renderer/services/AssistantService', () => ({
|
|||||||
|
|
||||||
vi.mock('@renderer/store', () => ({
|
vi.mock('@renderer/store', () => ({
|
||||||
default: {
|
default: {
|
||||||
getState: () => ({ copilot: { defaultHeaders: {} } })
|
getState: () => ({
|
||||||
|
copilot: { defaultHeaders: {} },
|
||||||
|
llm: {
|
||||||
|
settings: {
|
||||||
|
vertexai: {
|
||||||
|
projectId: 'test-project',
|
||||||
|
location: 'us-central1'
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
}
|
}
|
||||||
}))
|
}))
|
||||||
|
|
||||||
@@ -33,7 +43,7 @@ vi.mock('@renderer/utils/api', () => ({
|
|||||||
if (isSupportedAPIVersion === false) {
|
if (isSupportedAPIVersion === false) {
|
||||||
return host // Return host as-is when isSupportedAPIVersion is false
|
return host // Return host as-is when isSupportedAPIVersion is false
|
||||||
}
|
}
|
||||||
return `${host}/v1` // Default behavior when isSupportedAPIVersion is true
|
return host ? `${host}/v1` : '' // Default behavior when isSupportedAPIVersion is true
|
||||||
}),
|
}),
|
||||||
routeToEndpoint: vi.fn((host) => ({
|
routeToEndpoint: vi.fn((host) => ({
|
||||||
baseURL: host,
|
baseURL: host,
|
||||||
@@ -41,6 +51,20 @@ vi.mock('@renderer/utils/api', () => ({
|
|||||||
}))
|
}))
|
||||||
}))
|
}))
|
||||||
|
|
||||||
|
// Also mock @shared/api since formatProviderApiHost uses it directly
|
||||||
|
vi.mock('@shared/api', async (importOriginal) => {
|
||||||
|
const actual = (await importOriginal()) as any
|
||||||
|
return {
|
||||||
|
...actual,
|
||||||
|
formatApiHost: vi.fn((host, isSupportedAPIVersion = true) => {
|
||||||
|
if (isSupportedAPIVersion === false) {
|
||||||
|
return host || '' // Return host as-is when isSupportedAPIVersion is false
|
||||||
|
}
|
||||||
|
return host ? `${host}/v1` : '' // Default behavior when isSupportedAPIVersion is true
|
||||||
|
})
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
vi.mock('@renderer/utils/provider', async (importOriginal) => {
|
vi.mock('@renderer/utils/provider', async (importOriginal) => {
|
||||||
const actual = (await importOriginal()) as any
|
const actual = (await importOriginal()) as any
|
||||||
return {
|
return {
|
||||||
@@ -73,8 +97,8 @@ vi.mock('@renderer/services/AssistantService', () => ({
|
|||||||
|
|
||||||
import { getProviderByModel } from '@renderer/services/AssistantService'
|
import { getProviderByModel } from '@renderer/services/AssistantService'
|
||||||
import type { Model, Provider } from '@renderer/types'
|
import type { Model, Provider } from '@renderer/types'
|
||||||
import { formatApiHost } from '@renderer/utils/api'
|
|
||||||
import { isCherryAIProvider, isPerplexityProvider } from '@renderer/utils/provider'
|
import { isCherryAIProvider, isPerplexityProvider } from '@renderer/utils/provider'
|
||||||
|
import { formatApiHost } from '@shared/api'
|
||||||
|
|
||||||
import { COPILOT_DEFAULT_HEADERS, COPILOT_EDITOR_VERSION, isCopilotResponsesModel } from '../constants'
|
import { COPILOT_DEFAULT_HEADERS, COPILOT_EDITOR_VERSION, isCopilotResponsesModel } from '../constants'
|
||||||
import { getActualProvider, providerToAiSdkConfig } from '../providerConfig'
|
import { getActualProvider, providerToAiSdkConfig } from '../providerConfig'
|
||||||
|
|||||||
@@ -1,22 +0,0 @@
|
|||||||
import type { Provider } from '@renderer/types'
|
|
||||||
|
|
||||||
import { provider2Provider, startsWith } from './helper'
|
|
||||||
import type { RuleSet } from './types'
|
|
||||||
|
|
||||||
// https://platform.claude.com/docs/en/build-with-claude/claude-in-microsoft-foundry
|
|
||||||
const AZURE_ANTHROPIC_RULES: RuleSet = {
|
|
||||||
rules: [
|
|
||||||
{
|
|
||||||
match: startsWith('claude'),
|
|
||||||
provider: (provider: Provider) => ({
|
|
||||||
...provider,
|
|
||||||
type: 'anthropic',
|
|
||||||
apiHost: provider.apiHost + 'anthropic/v1',
|
|
||||||
id: 'azure-anthropic'
|
|
||||||
})
|
|
||||||
}
|
|
||||||
],
|
|
||||||
fallbackRule: (provider: Provider) => provider
|
|
||||||
}
|
|
||||||
|
|
||||||
export const azureAnthropicProviderCreator = provider2Provider.bind(null, AZURE_ANTHROPIC_RULES)
|
|
||||||
@@ -1,22 +0,0 @@
|
|||||||
import type { Model, Provider } from '@renderer/types'
|
|
||||||
|
|
||||||
import type { RuleSet } from './types'
|
|
||||||
|
|
||||||
export const startsWith = (prefix: string) => (model: Model) => model.id.toLowerCase().startsWith(prefix.toLowerCase())
|
|
||||||
export const endpointIs = (type: string) => (model: Model) => model.endpoint_type === type
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Resolve the provider that corresponds to a model
|
|
||||||
* @param ruleSet the rule set object
|
|
||||||
* @param model the model object
|
|
||||||
* @param provider the original provider object
|
|
||||||
* @returns the resolved provider object
|
|
||||||
*/
|
|
||||||
export function provider2Provider(ruleSet: RuleSet, model: Model, provider: Provider): Provider {
|
|
||||||
for (const rule of ruleSet.rules) {
|
|
||||||
if (rule.match(model)) {
|
|
||||||
return rule.provider(provider)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return ruleSet.fallbackRule(provider)
|
|
||||||
}
|
|
||||||
@@ -1,3 +1,7 @@
|
|||||||
export { aihubmixProviderCreator } from './aihubmix'
|
// Re-export from shared config
|
||||||
export { newApiResolverCreator } from './newApi'
|
export {
|
||||||
export { vertexAnthropicProviderCreator } from './vertext-anthropic'
|
aihubmixProviderCreator,
|
||||||
|
azureAnthropicProviderCreator,
|
||||||
|
newApiResolverCreator,
|
||||||
|
vertexAnthropicProviderCreator
|
||||||
|
} from '@shared/provider/config'
|
||||||
|
|||||||
@@ -1,9 +0,0 @@
|
|||||||
import type { Model, Provider } from '@renderer/types'
|
|
||||||
|
|
||||||
export interface RuleSet {
|
|
||||||
rules: Array<{
|
|
||||||
match: (model: Model) => boolean
|
|
||||||
provider: (provider: Provider) => Provider
|
|
||||||
}>
|
|
||||||
fallbackRule: (provider: Provider) => Provider
|
|
||||||
}
|
|
||||||
@@ -1,19 +0,0 @@
|
|||||||
import type { Provider } from '@renderer/types'
|
|
||||||
|
|
||||||
import { provider2Provider, startsWith } from './helper'
|
|
||||||
import type { RuleSet } from './types'
|
|
||||||
|
|
||||||
const VERTEX_ANTHROPIC_RULES: RuleSet = {
|
|
||||||
rules: [
|
|
||||||
{
|
|
||||||
match: startsWith('claude'),
|
|
||||||
provider: (provider: Provider) => ({
|
|
||||||
...provider,
|
|
||||||
id: 'google-vertex-anthropic'
|
|
||||||
})
|
|
||||||
}
|
|
||||||
],
|
|
||||||
fallbackRule: (provider: Provider) => provider
|
|
||||||
}
|
|
||||||
|
|
||||||
export const vertexAnthropicProviderCreator = provider2Provider.bind(null, VERTEX_ANTHROPIC_RULES)
|
|
||||||
@@ -1,25 +1 @@
|
|||||||
import type { Model } from '@renderer/types'
|
export { COPILOT_DEFAULT_HEADERS, COPILOT_EDITOR_VERSION, isCopilotResponsesModel } from '@shared/provider/constant'
|
||||||
|
|
||||||
export const COPILOT_EDITOR_VERSION = 'vscode/1.104.1'
|
|
||||||
export const COPILOT_PLUGIN_VERSION = 'copilot-chat/0.26.7'
|
|
||||||
export const COPILOT_INTEGRATION_ID = 'vscode-chat'
|
|
||||||
export const COPILOT_USER_AGENT = 'GitHubCopilotChat/0.26.7'
|
|
||||||
|
|
||||||
export const COPILOT_DEFAULT_HEADERS = {
|
|
||||||
'Copilot-Integration-Id': COPILOT_INTEGRATION_ID,
|
|
||||||
'User-Agent': COPILOT_USER_AGENT,
|
|
||||||
'Editor-Version': COPILOT_EDITOR_VERSION,
|
|
||||||
'Editor-Plugin-Version': COPILOT_PLUGIN_VERSION,
|
|
||||||
'editor-version': COPILOT_EDITOR_VERSION,
|
|
||||||
'editor-plugin-version': COPILOT_PLUGIN_VERSION,
|
|
||||||
'copilot-vision-request': 'true'
|
|
||||||
} as const
|
|
||||||
|
|
||||||
// Models that require the OpenAI Responses endpoint when routed through GitHub Copilot (#10560)
|
|
||||||
const COPILOT_RESPONSES_MODEL_IDS = ['gpt-5-codex']
|
|
||||||
|
|
||||||
export function isCopilotResponsesModel(model: Model): boolean {
|
|
||||||
const normalizedId = model.id?.trim().toLowerCase()
|
|
||||||
const normalizedName = model.name?.trim().toLowerCase()
|
|
||||||
return COPILOT_RESPONSES_MODEL_IDS.some((target) => normalizedId === target || normalizedName === target)
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -1,8 +1,7 @@
|
|||||||
import { hasProviderConfigByAlias, type ProviderId, resolveProviderConfigId } from '@cherrystudio/ai-core/provider'
|
|
||||||
import { createProvider as createProviderCore } from '@cherrystudio/ai-core/provider'
|
import { createProvider as createProviderCore } from '@cherrystudio/ai-core/provider'
|
||||||
import { loggerService } from '@logger'
|
import { loggerService } from '@logger'
|
||||||
import type { Provider } from '@renderer/types'
|
import type { Provider } from '@renderer/types'
|
||||||
import { isAzureOpenAIProvider, isAzureResponsesEndpoint } from '@renderer/utils/provider'
|
import { getAiSdkProviderId as sharedGetAiSdkProviderId } from '@shared/provider'
|
||||||
import type { Provider as AiSdkProvider } from 'ai'
|
import type { Provider as AiSdkProvider } from 'ai'
|
||||||
|
|
||||||
import type { AiSdkConfig } from '../types'
|
import type { AiSdkConfig } from '../types'
|
||||||
@@ -22,68 +21,12 @@ const logger = loggerService.withContext('ProviderFactory')
|
|||||||
}
|
}
|
||||||
})()
|
})()
|
||||||
|
|
||||||
/**
|
|
||||||
* Static provider mapping table
|
|
||||||
* Maps Cherry Studio-specific provider IDs to standard AI SDK IDs
|
|
||||||
*/
|
|
||||||
const STATIC_PROVIDER_MAPPING: Record<string, ProviderId> = {
|
|
||||||
gemini: 'google', // Google Gemini -> google
|
|
||||||
'azure-openai': 'azure', // Azure OpenAI -> azure
|
|
||||||
'openai-response': 'openai', // OpenAI Responses -> openai
|
|
||||||
grok: 'xai', // Grok -> xai
|
|
||||||
copilot: 'github-copilot-openai-compatible'
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Try to resolve a provider identifier (supports static mappings and aliases)
|
|
||||||
*/
|
|
||||||
function tryResolveProviderId(identifier: string): ProviderId | null {
|
|
||||||
// 1. Check the static mapping
|
|
||||||
const staticMapping = STATIC_PROVIDER_MAPPING[identifier]
|
|
||||||
if (staticMapping) {
|
|
||||||
return staticMapping
|
|
||||||
}
|
|
||||||
|
|
||||||
// 2. Check whether AiCore supports it (including alias support)
|
|
||||||
if (hasProviderConfigByAlias(identifier)) {
|
|
||||||
// Resolve to the real provider ID
|
|
||||||
return resolveProviderConfigId(identifier) as ProviderId
|
|
||||||
}
|
|
||||||
|
|
||||||
return null
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Get the AI SDK provider ID
|
* Get the AI SDK provider ID
|
||||||
* Simplified version: reduces duplicated logic by using a generic resolver function
|
* Uses shared implementation with renderer-specific config checker
|
||||||
*/
|
*/
|
||||||
export function getAiSdkProviderId(provider: Provider): string {
|
export function getAiSdkProviderId(provider: Provider): string {
|
||||||
// 1. Try to resolve provider.id
|
return sharedGetAiSdkProviderId(provider)
|
||||||
const resolvedFromId = tryResolveProviderId(provider.id)
|
|
||||||
if (isAzureOpenAIProvider(provider)) {
|
|
||||||
if (isAzureResponsesEndpoint(provider)) {
|
|
||||||
return 'azure-responses'
|
|
||||||
} else {
|
|
||||||
return 'azure'
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (resolvedFromId) {
|
|
||||||
return resolvedFromId
|
|
||||||
}
|
|
||||||
|
|
||||||
// 2. Try to resolve provider.type
|
|
||||||
// Resolving type 'openai' here would map every custom openai-type provider to the AI SDK's openaiProvider
|
|
||||||
if (provider.type !== 'openai') {
|
|
||||||
const resolvedFromType = tryResolveProviderId(provider.type)
|
|
||||||
if (resolvedFromType) {
|
|
||||||
return resolvedFromType
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (provider.apiHost.includes('api.openai.com')) {
|
|
||||||
return 'openai-chat'
|
|
||||||
}
|
|
||||||
// 3. Final fallback (use the provider's own id)
|
|
||||||
return provider.id
|
|
||||||
}
|
}
|
||||||
|
|
||||||
export async function createAiSdkProvider(config: AiSdkConfig): Promise<AiSdkProvider | null> {
|
export async function createAiSdkProvider(config: AiSdkConfig): Promise<AiSdkProvider | null> {
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
import { formatPrivateKey, hasProviderConfig, ProviderConfigFactory } from '@cherrystudio/ai-core/provider'
|
import { hasProviderConfig } from '@cherrystudio/ai-core/provider'
|
||||||
import { isOpenAIChatCompletionOnlyModel } from '@renderer/config/models'
|
import { isOpenAIChatCompletionOnlyModel } from '@renderer/config/models'
|
||||||
import {
|
import {
|
||||||
getAwsBedrockAccessKeyId,
|
getAwsBedrockAccessKeyId,
|
||||||
@@ -10,22 +10,17 @@ import {
|
|||||||
import { createVertexProvider, isVertexAIConfigured } from '@renderer/hooks/useVertexAI'
|
import { createVertexProvider, isVertexAIConfigured } from '@renderer/hooks/useVertexAI'
|
||||||
import { getProviderByModel } from '@renderer/services/AssistantService'
|
import { getProviderByModel } from '@renderer/services/AssistantService'
|
||||||
import store from '@renderer/store'
|
import store from '@renderer/store'
|
||||||
import { isSystemProvider, type Model, type Provider, SystemProviderIds } from '@renderer/types'
|
import { isSystemProvider, type Model, type Provider } from '@renderer/types'
|
||||||
import { formatApiHost, formatAzureOpenAIApiHost, formatVertexApiHost, routeToEndpoint } from '@renderer/utils/api'
|
|
||||||
import {
|
import {
|
||||||
isAnthropicProvider,
|
type AiSdkConfigContext,
|
||||||
isAzureOpenAIProvider,
|
formatProviderApiHost as sharedFormatProviderApiHost,
|
||||||
isCherryAIProvider,
|
type ProviderFormatContext,
|
||||||
isGeminiProvider,
|
providerToAiSdkConfig as sharedProviderToAiSdkConfig,
|
||||||
isNewApiProvider,
|
resolveActualProvider
|
||||||
isPerplexityProvider,
|
} from '@shared/provider'
|
||||||
isVertexProvider
|
|
||||||
} from '@renderer/utils/provider'
|
|
||||||
import { cloneDeep } from 'lodash'
|
import { cloneDeep } from 'lodash'
|
||||||
|
|
||||||
import type { AiSdkConfig } from '../types'
|
import type { AiSdkConfig } from '../types'
|
||||||
import { aihubmixProviderCreator, newApiResolverCreator, vertexAnthropicProviderCreator } from './config'
|
|
||||||
import { azureAnthropicProviderCreator } from './config/azure-anthropic'
|
|
||||||
import { COPILOT_DEFAULT_HEADERS } from './constants'
|
import { COPILOT_DEFAULT_HEADERS } from './constants'
|
||||||
import { getAiSdkProviderId } from './factory'
|
import { getAiSdkProviderId } from './factory'
|
||||||
|
|
||||||
@@ -56,61 +51,51 @@ function getRotatedApiKey(provider: Provider): string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Handles the conversion logic for special providers
|
* Renderer-specific context for providerToAiSdkConfig
|
||||||
|
* Provides implementations using browser APIs, store, and hooks
|
||||||
*/
|
*/
|
||||||
function handleSpecialProviders(model: Model, provider: Provider): Provider {
|
function createRendererSdkContext(model: Model): AiSdkConfigContext {
|
||||||
if (isNewApiProvider(provider)) {
|
return {
|
||||||
return newApiResolverCreator(model, provider)
|
getRotatedApiKey: (provider) => getRotatedApiKey(provider as Provider),
|
||||||
|
isOpenAIChatCompletionOnlyModel: () => isOpenAIChatCompletionOnlyModel(model),
|
||||||
|
getCopilotDefaultHeaders: () => COPILOT_DEFAULT_HEADERS,
|
||||||
|
getCopilotStoredHeaders: () => store.getState().copilot.defaultHeaders ?? {},
|
||||||
|
getAwsBedrockConfig: () => {
|
||||||
|
const authType = getAwsBedrockAuthType()
|
||||||
|
return {
|
||||||
|
authType,
|
||||||
|
region: getAwsBedrockRegion(),
|
||||||
|
apiKey: authType === 'apiKey' ? getAwsBedrockApiKey() : undefined,
|
||||||
|
accessKeyId: authType === 'iam' ? getAwsBedrockAccessKeyId() : undefined,
|
||||||
|
secretAccessKey: authType === 'iam' ? getAwsBedrockSecretAccessKey() : undefined
|
||||||
|
}
|
||||||
|
},
|
||||||
|
getVertexConfig: (provider) => {
|
||||||
|
if (!isVertexAIConfigured()) {
|
||||||
|
return undefined
|
||||||
|
}
|
||||||
|
return createVertexProvider(provider as Provider)
|
||||||
|
},
|
||||||
|
getEndpointType: () => model.endpoint_type
|
||||||
}
|
}
|
||||||
|
|
||||||
if (isSystemProvider(provider)) {
|
|
||||||
if (provider.id === 'aihubmix') {
|
|
||||||
return aihubmixProviderCreator(model, provider)
|
|
||||||
}
|
|
||||||
if (provider.id === 'vertexai') {
|
|
||||||
return vertexAnthropicProviderCreator(model, provider)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if (isAzureOpenAIProvider(provider)) {
|
|
||||||
return azureAnthropicProviderCreator(model, provider)
|
|
||||||
}
|
|
||||||
return provider
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Mainly used to align with the AI SDK's baseURL format
|
* Mainly used to align with the AI SDK's baseURL format
|
||||||
* @param provider
|
* Uses shared implementation with renderer-specific context
|
||||||
* @returns
|
|
||||||
*/
|
*/
|
||||||
function formatProviderApiHost(provider: Provider): Provider {
|
function getRendererFormatContext(): ProviderFormatContext {
|
||||||
const formatted = { ...provider }
|
const vertexSettings = store.getState().llm.settings.vertexai
|
||||||
if (formatted.anthropicApiHost) {
|
return {
|
||||||
formatted.anthropicApiHost = formatApiHost(formatted.anthropicApiHost)
|
vertex: {
|
||||||
}
|
project: vertexSettings.projectId || 'default-project',
|
||||||
|
location: vertexSettings.location || 'us-central1'
|
||||||
if (isAnthropicProvider(provider)) {
|
|
||||||
const baseHost = formatted.anthropicApiHost || formatted.apiHost
|
|
||||||
// AI SDK needs /v1 in baseURL, Anthropic SDK will strip it in getSdkClient
|
|
||||||
formatted.apiHost = formatApiHost(baseHost)
|
|
||||||
if (!formatted.anthropicApiHost) {
|
|
||||||
formatted.anthropicApiHost = formatted.apiHost
|
|
||||||
}
|
}
|
||||||
} else if (formatted.id === SystemProviderIds.copilot || formatted.id === SystemProviderIds.github) {
|
|
||||||
formatted.apiHost = formatApiHost(formatted.apiHost, false)
|
|
||||||
} else if (isGeminiProvider(formatted)) {
|
|
||||||
formatted.apiHost = formatApiHost(formatted.apiHost, true, 'v1beta')
|
|
||||||
} else if (isAzureOpenAIProvider(formatted)) {
|
|
||||||
formatted.apiHost = formatAzureOpenAIApiHost(formatted.apiHost)
|
|
||||||
} else if (isVertexProvider(formatted)) {
|
|
||||||
formatted.apiHost = formatVertexApiHost(formatted)
|
|
||||||
} else if (isCherryAIProvider(formatted)) {
|
|
||||||
formatted.apiHost = formatApiHost(formatted.apiHost, false)
|
|
||||||
} else if (isPerplexityProvider(formatted)) {
|
|
||||||
formatted.apiHost = formatApiHost(formatted.apiHost, false)
|
|
||||||
} else {
|
|
||||||
formatted.apiHost = formatApiHost(formatted.apiHost)
|
|
||||||
}
|
}
|
||||||
return formatted
|
}
|
||||||
|
|
||||||
|
function formatProviderApiHost(provider: Provider): Provider {
|
||||||
|
return sharedFormatProviderApiHost(provider, getRendererFormatContext())
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@@ -122,7 +107,9 @@ export function getActualProvider(model: Model): Provider {
|
|||||||
|
|
||||||
// Apply the transformations in order
|
// Apply the transformations in order
|
||||||
let actualProvider = cloneDeep(baseProvider)
|
let actualProvider = cloneDeep(baseProvider)
|
||||||
actualProvider = handleSpecialProviders(model, actualProvider)
|
actualProvider = resolveActualProvider(actualProvider, model, {
|
||||||
|
isSystemProvider
|
||||||
|
}) as Provider
|
||||||
actualProvider = formatProviderApiHost(actualProvider)
|
actualProvider = formatProviderApiHost(actualProvider)
|
||||||
|
|
||||||
return actualProvider
|
return actualProvider
|
||||||
@@ -130,121 +117,11 @@ export function getActualProvider(model: Model): Provider {
|
|||||||
|
|
||||||
/**
|
/**
|
||||||
* Convert the Provider config to the new AI SDK format
|
* Convert the Provider config to the new AI SDK format
|
||||||
* Simplified version: uses the new alias mapping system
|
* Uses shared implementation with renderer-specific context
|
||||||
*/
|
*/
|
||||||
export function providerToAiSdkConfig(actualProvider: Provider, model: Model): AiSdkConfig {
|
export function providerToAiSdkConfig(actualProvider: Provider, model: Model): AiSdkConfig {
|
||||||
const aiSdkProviderId = getAiSdkProviderId(actualProvider)
|
const context = createRendererSdkContext(model)
|
||||||
|
return sharedProviderToAiSdkConfig(actualProvider, model.id, context) as AiSdkConfig
|
||||||
// Build the base config
|
|
||||||
const { baseURL, endpoint } = routeToEndpoint(actualProvider.apiHost)
|
|
||||||
const baseConfig = {
|
|
||||||
baseURL: baseURL,
|
|
||||||
apiKey: getRotatedApiKey(actualProvider)
|
|
||||||
}
|
|
||||||
|
|
||||||
const isCopilotProvider = actualProvider.id === SystemProviderIds.copilot
|
|
||||||
if (isCopilotProvider) {
|
|
||||||
const storedHeaders = store.getState().copilot.defaultHeaders ?? {}
|
|
||||||
const options = ProviderConfigFactory.fromProvider('github-copilot-openai-compatible', baseConfig, {
|
|
||||||
headers: {
|
|
||||||
...COPILOT_DEFAULT_HEADERS,
|
|
||||||
...storedHeaders,
|
|
||||||
...actualProvider.extra_headers
|
|
||||||
},
|
|
||||||
name: actualProvider.id,
|
|
||||||
includeUsage: true
|
|
||||||
})
|
|
||||||
|
|
||||||
return {
|
|
||||||
providerId: 'github-copilot-openai-compatible',
|
|
||||||
options
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Handle OpenAI mode
|
|
||||||
const extraOptions: any = {}
|
|
||||||
extraOptions.endpoint = endpoint
|
|
||||||
if (actualProvider.type === 'openai-response' && !isOpenAIChatCompletionOnlyModel(model)) {
|
|
||||||
extraOptions.mode = 'responses'
|
|
||||||
} else if (aiSdkProviderId === 'openai' || (aiSdkProviderId === 'cherryin' && actualProvider.type === 'openai')) {
|
|
||||||
extraOptions.mode = 'chat'
|
|
||||||
}
|
|
||||||
|
|
||||||
// Add extra headers
|
|
||||||
if (actualProvider.extra_headers) {
|
|
||||||
extraOptions.headers = actualProvider.extra_headers
|
|
||||||
// copy from openaiBaseClient/openaiResponseApiClient
|
|
||||||
if (aiSdkProviderId === 'openai') {
|
|
||||||
extraOptions.headers = {
|
|
||||||
...extraOptions.headers,
|
|
||||||
'HTTP-Referer': 'https://cherry-ai.com',
|
|
||||||
'X-Title': 'Cherry Studio',
|
|
||||||
'X-Api-Key': baseConfig.apiKey
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// azure
|
|
||||||
// https://learn.microsoft.com/en-us/azure/ai-foundry/openai/latest
|
|
||||||
// https://learn.microsoft.com/en-us/azure/ai-foundry/openai/how-to/responses?tabs=python-key#responses-api
|
|
||||||
if (aiSdkProviderId === 'azure-responses') {
|
|
||||||
extraOptions.mode = 'responses'
|
|
||||||
} else if (aiSdkProviderId === 'azure') {
|
|
||||||
extraOptions.mode = 'chat'
|
|
||||||
}
|
|
||||||
|
|
||||||
// bedrock
|
|
||||||
if (aiSdkProviderId === 'bedrock') {
|
|
||||||
const authType = getAwsBedrockAuthType()
|
|
||||||
extraOptions.region = getAwsBedrockRegion()
|
|
||||||
|
|
||||||
if (authType === 'apiKey') {
|
|
||||||
extraOptions.apiKey = getAwsBedrockApiKey()
|
|
||||||
} else {
|
|
||||||
extraOptions.accessKeyId = getAwsBedrockAccessKeyId()
|
|
||||||
extraOptions.secretAccessKey = getAwsBedrockSecretAccessKey()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// google-vertex
|
|
||||||
if (aiSdkProviderId === 'google-vertex' || aiSdkProviderId === 'google-vertex-anthropic') {
|
|
||||||
if (!isVertexAIConfigured()) {
|
|
||||||
throw new Error('VertexAI is not configured. Please configure project, location and service account credentials.')
|
|
||||||
}
|
|
||||||
const { project, location, googleCredentials } = createVertexProvider(actualProvider)
|
|
||||||
extraOptions.project = project
|
|
||||||
extraOptions.location = location
|
|
||||||
extraOptions.googleCredentials = {
|
|
||||||
...googleCredentials,
|
|
||||||
privateKey: formatPrivateKey(googleCredentials.privateKey)
|
|
||||||
}
|
|
||||||
baseConfig.baseURL += aiSdkProviderId === 'google-vertex' ? '/publishers/google' : '/publishers/anthropic/models'
|
|
||||||
}
|
|
||||||
|
|
||||||
// cherryin
|
|
||||||
if (aiSdkProviderId === 'cherryin') {
|
|
||||||
if (model.endpoint_type) {
|
|
||||||
extraOptions.endpointType = model.endpoint_type
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if (hasProviderConfig(aiSdkProviderId) && aiSdkProviderId !== 'openai-compatible') {
|
|
||||||
const options = ProviderConfigFactory.fromProvider(aiSdkProviderId, baseConfig, extraOptions)
|
|
||||||
return {
|
|
||||||
providerId: aiSdkProviderId,
|
|
||||||
options
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Otherwise fall back to openai-compatible
|
|
||||||
const options = ProviderConfigFactory.createOpenAICompatible(baseConfig.baseURL, baseConfig.apiKey)
|
|
||||||
return {
|
|
||||||
providerId: 'openai-compatible',
|
|
||||||
options: {
|
|
||||||
...options,
|
|
||||||
name: actualProvider.id,
|
|
||||||
...extraOptions,
|
|
||||||
includeUsage: true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@@ -287,13 +164,13 @@ export async function prepareSpecialProviderConfig(
|
|||||||
break
|
break
|
||||||
}
|
}
|
||||||
case 'cherryai': {
|
case 'cherryai': {
|
||||||
config.options.fetch = async (url, options) => {
|
config.options.fetch = async (url: RequestInfo | URL, options: RequestInit) => {
|
||||||
// Sign the final parameters here
|
// Sign the final parameters here
|
||||||
const signature = await window.api.cherryai.generateSignature({
|
const signature = await window.api.cherryai.generateSignature({
|
||||||
method: 'POST',
|
method: 'POST',
|
||||||
path: '/chat/completions',
|
path: '/chat/completions',
|
||||||
query: '',
|
query: '',
|
||||||
body: JSON.parse(options.body)
|
body: JSON.parse(options.body as string)
|
||||||
})
|
})
|
||||||
return fetch(url, {
|
return fetch(url, {
|
||||||
...options,
|
...options,
|
||||||
|
|||||||
@@ -1,113 +1,13 @@
|
|||||||
import { type ProviderConfig, registerMultipleProviderConfigs } from '@cherrystudio/ai-core/provider'
|
|
||||||
import { loggerService } from '@logger'
|
import { loggerService } from '@logger'
|
||||||
|
import { initializeSharedProviders, SHARED_PROVIDER_CONFIGS } from '@shared/provider'
|
||||||
|
|
||||||
const logger = loggerService.withContext('ProviderConfigs')
|
const logger = loggerService.withContext('ProviderConfigs')
|
||||||
|
|
||||||
/**
|
export const NEW_PROVIDER_CONFIGS = SHARED_PROVIDER_CONFIGS
|
||||||
* New provider config definitions
|
|
||||||
* Defines the AI providers that need to be registered dynamically
|
|
||||||
*/
|
|
||||||
export const NEW_PROVIDER_CONFIGS: ProviderConfig[] = [
|
|
||||||
{
|
|
||||||
id: 'openrouter',
|
|
||||||
name: 'OpenRouter',
|
|
||||||
import: () => import('@openrouter/ai-sdk-provider'),
|
|
||||||
creatorFunctionName: 'createOpenRouter',
|
|
||||||
supportsImageGeneration: true,
|
|
||||||
aliases: ['openrouter']
|
|
||||||
},
|
|
||||||
{
|
|
||||||
id: 'google-vertex',
|
|
||||||
name: 'Google Vertex AI',
|
|
||||||
import: () => import('@ai-sdk/google-vertex/edge'),
|
|
||||||
creatorFunctionName: 'createVertex',
|
|
||||||
supportsImageGeneration: true,
|
|
||||||
aliases: ['vertexai']
|
|
||||||
},
|
|
||||||
{
|
|
||||||
id: 'google-vertex-anthropic',
|
|
||||||
name: 'Google Vertex AI Anthropic',
|
|
||||||
import: () => import('@ai-sdk/google-vertex/anthropic/edge'),
|
|
||||||
creatorFunctionName: 'createVertexAnthropic',
|
|
||||||
supportsImageGeneration: true,
|
|
||||||
aliases: ['vertexai-anthropic']
|
|
||||||
},
|
|
||||||
{
|
|
||||||
id: 'azure-anthropic',
|
|
||||||
name: 'Azure AI Anthropic',
|
|
||||||
import: () => import('@ai-sdk/anthropic'),
|
|
||||||
creatorFunctionName: 'createAnthropic',
|
|
||||||
supportsImageGeneration: false,
|
|
||||||
aliases: ['azure-anthropic']
|
|
||||||
},
|
|
||||||
{
|
|
||||||
id: 'github-copilot-openai-compatible',
|
|
||||||
name: 'GitHub Copilot OpenAI Compatible',
|
|
||||||
import: () => import('@opeoginni/github-copilot-openai-compatible'),
|
|
||||||
creatorFunctionName: 'createGitHubCopilotOpenAICompatible',
|
|
||||||
supportsImageGeneration: false,
|
|
||||||
aliases: ['copilot', 'github-copilot']
|
|
||||||
},
|
|
||||||
{
|
|
||||||
id: 'bedrock',
|
|
||||||
name: 'Amazon Bedrock',
|
|
||||||
import: () => import('@ai-sdk/amazon-bedrock'),
|
|
||||||
creatorFunctionName: 'createAmazonBedrock',
|
|
||||||
supportsImageGeneration: true,
|
|
||||||
aliases: ['aws-bedrock']
|
|
||||||
},
|
|
||||||
{
|
|
||||||
id: 'perplexity',
|
|
||||||
name: 'Perplexity',
|
|
||||||
import: () => import('@ai-sdk/perplexity'),
|
|
||||||
creatorFunctionName: 'createPerplexity',
|
|
||||||
supportsImageGeneration: false,
|
|
||||||
aliases: ['perplexity']
|
|
||||||
},
|
|
||||||
{
|
|
||||||
id: 'mistral',
|
|
||||||
name: 'Mistral',
|
|
||||||
import: () => import('@ai-sdk/mistral'),
|
|
||||||
creatorFunctionName: 'createMistral',
|
|
||||||
supportsImageGeneration: false,
|
|
||||||
aliases: ['mistral']
|
|
||||||
},
|
|
||||||
{
|
|
||||||
id: 'huggingface',
|
|
||||||
name: 'HuggingFace',
|
|
||||||
import: () => import('@ai-sdk/huggingface'),
|
|
||||||
creatorFunctionName: 'createHuggingFace',
|
|
||||||
supportsImageGeneration: true,
|
|
||||||
aliases: ['hf', 'hugging-face']
|
|
||||||
},
|
|
||||||
{
|
|
||||||
id: 'ai-gateway',
|
|
||||||
name: 'AI Gateway',
|
|
||||||
import: () => import('@ai-sdk/gateway'),
|
|
||||||
creatorFunctionName: 'createGateway',
|
|
||||||
supportsImageGeneration: true,
|
|
||||||
aliases: ['gateway']
|
|
||||||
},
|
|
||||||
{
|
|
||||||
id: 'cerebras',
|
|
||||||
name: 'Cerebras',
|
|
||||||
import: () => import('@ai-sdk/cerebras'),
|
|
||||||
creatorFunctionName: 'createCerebras',
|
|
||||||
supportsImageGeneration: false
|
|
||||||
}
|
|
||||||
] as const
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Initialize the new providers
|
|
||||||
* Uses aiCore's dynamic registration feature
|
|
||||||
*/
|
|
||||||
export async function initializeNewProviders(): Promise<void> {
|
export async function initializeNewProviders(): Promise<void> {
|
||||||
try {
|
initializeSharedProviders({
|
||||||
const successCount = registerMultipleProviderConfigs(NEW_PROVIDER_CONFIGS)
|
warn: (message) => logger.warn(message),
|
||||||
if (successCount < NEW_PROVIDER_CONFIGS.length) {
|
error: (message, error) => logger.error(message, error)
|
||||||
logger.warn('Some providers failed to register. Check previous error logs.')
|
})
|
||||||
}
|
|
||||||
} catch (error) {
|
|
||||||
logger.error('Failed to initialize new providers:', error as Error)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -245,8 +245,8 @@ export class AiSdkSpanAdapter {
|
|||||||
'gen_ai.usage.output_tokens'
|
'gen_ai.usage.output_tokens'
|
||||||
]
|
]
|
||||||
|
|
||||||
const completionTokens = attributes[inputsTokenKeys.find((key) => attributes[key]) || '']
|
const promptTokens = attributes[inputsTokenKeys.find((key) => attributes[key]) || '']
|
||||||
const promptTokens = attributes[outputTokenKeys.find((key) => attributes[key]) || '']
|
const completionTokens = attributes[outputTokenKeys.find((key) => attributes[key]) || '']
|
||||||
|
|
||||||
if (completionTokens !== undefined || promptTokens !== undefined) {
|
if (completionTokens !== undefined || promptTokens !== undefined) {
|
||||||
const usage: TokenUsage = {
|
const usage: TokenUsage = {
|
||||||
|
|||||||
@@ -0,0 +1,53 @@
|
|||||||
|
import type { Span } from '@opentelemetry/api'
|
||||||
|
import { SpanKind, SpanStatusCode } from '@opentelemetry/api'
|
||||||
|
import { describe, expect, it, vi } from 'vitest'
|
||||||
|
|
||||||
|
import { AiSdkSpanAdapter } from '../AiSdkSpanAdapter'
|
||||||
|
|
||||||
|
vi.mock('@logger', () => ({
|
||||||
|
loggerService: {
|
||||||
|
withContext: () => ({
|
||||||
|
debug: vi.fn(),
|
||||||
|
error: vi.fn(),
|
||||||
|
info: vi.fn(),
|
||||||
|
warn: vi.fn()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
|
||||||
|
describe('AiSdkSpanAdapter', () => {
|
||||||
|
const createMockSpan = (attributes: Record<string, unknown>): Span => {
|
||||||
|
const span = {
|
||||||
|
spanContext: () => ({
|
||||||
|
traceId: 'trace-id',
|
||||||
|
spanId: 'span-id'
|
||||||
|
}),
|
||||||
|
_attributes: attributes,
|
||||||
|
_events: [],
|
||||||
|
name: 'test span',
|
||||||
|
status: { code: SpanStatusCode.OK },
|
||||||
|
kind: SpanKind.CLIENT,
|
||||||
|
startTime: [0, 0] as [number, number],
|
||||||
|
endTime: [0, 1] as [number, number],
|
||||||
|
ended: true,
|
||||||
|
parentSpanId: '',
|
||||||
|
links: []
|
||||||
|
}
|
||||||
|
return span as unknown as Span
|
||||||
|
}
|
||||||
|
|
||||||
|
it('maps prompt and completion usage tokens to the correct fields', () => {
|
||||||
|
const attributes = {
|
||||||
|
'ai.usage.promptTokens': 321,
|
||||||
|
'ai.usage.completionTokens': 654
|
||||||
|
}
|
||||||
|
|
||||||
|
const span = createMockSpan(attributes)
|
||||||
|
const result = AiSdkSpanAdapter.convertToSpanEntity({ span })
|
||||||
|
|
||||||
|
expect(result.usage).toBeDefined()
|
||||||
|
expect(result.usage?.prompt_tokens).toBe(321)
|
||||||
|
expect(result.usage?.completion_tokens).toBe(654)
|
||||||
|
expect(result.usage?.total_tokens).toBe(975)
|
||||||
|
})
|
||||||
|
})
|
||||||
@@ -71,10 +71,11 @@ describe('mcp utils', () => {
|
|||||||
const result = setupToolsConfig(mcpTools)
|
const result = setupToolsConfig(mcpTools)
|
||||||
|
|
||||||
expect(result).not.toBeUndefined()
|
expect(result).not.toBeUndefined()
|
||||||
expect(Object.keys(result!)).toEqual(['test-tool'])
|
// Tools are now keyed by id (which includes serverId suffix) for uniqueness
|
||||||
expect(result!['test-tool']).toHaveProperty('description')
|
expect(Object.keys(result!)).toEqual(['test-tool-1'])
|
||||||
expect(result!['test-tool']).toHaveProperty('inputSchema')
|
expect(result!['test-tool-1']).toHaveProperty('description')
|
||||||
expect(result!['test-tool']).toHaveProperty('execute')
|
expect(result!['test-tool-1']).toHaveProperty('inputSchema')
|
||||||
|
expect(result!['test-tool-1']).toHaveProperty('execute')
|
||||||
})
|
})
|
||||||
|
|
||||||
it('should handle multiple MCP tools', () => {
|
it('should handle multiple MCP tools', () => {
|
||||||
@@ -109,7 +110,8 @@ describe('mcp utils', () => {
|
|||||||
|
|
||||||
expect(result).not.toBeUndefined()
|
expect(result).not.toBeUndefined()
|
||||||
expect(Object.keys(result!)).toHaveLength(2)
|
expect(Object.keys(result!)).toHaveLength(2)
|
||||||
expect(Object.keys(result!)).toEqual(['tool1', 'tool2'])
|
// Tools are keyed by id for uniqueness
|
||||||
|
expect(Object.keys(result!)).toEqual(['tool1-id', 'tool2-id'])
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
@@ -135,9 +137,10 @@ describe('mcp utils', () => {
|
|||||||
|
|
||||||
const result = convertMcpToolsToAiSdkTools(mcpTools)
|
const result = convertMcpToolsToAiSdkTools(mcpTools)
|
||||||
|
|
||||||
expect(Object.keys(result)).toEqual(['get-weather'])
|
// Tools are keyed by id for uniqueness when multiple server instances exist
|
||||||
|
expect(Object.keys(result)).toEqual(['get-weather-id'])
|
||||||
|
|
||||||
const tool = result['get-weather'] as Tool
|
const tool = result['get-weather-id'] as Tool
|
||||||
expect(tool.description).toBe('Get weather information')
|
expect(tool.description).toBe('Get weather information')
|
||||||
expect(tool.inputSchema).toBeDefined()
|
expect(tool.inputSchema).toBeDefined()
|
||||||
expect(typeof tool.execute).toBe('function')
|
expect(typeof tool.execute).toBe('function')
|
||||||
@@ -160,8 +163,8 @@ describe('mcp utils', () => {
|
|||||||
|
|
||||||
const result = convertMcpToolsToAiSdkTools(mcpTools)
|
const result = convertMcpToolsToAiSdkTools(mcpTools)
|
||||||
|
|
||||||
expect(Object.keys(result)).toEqual(['no-desc-tool'])
|
expect(Object.keys(result)).toEqual(['no-desc-tool-id'])
|
||||||
const tool = result['no-desc-tool'] as Tool
|
const tool = result['no-desc-tool-id'] as Tool
|
||||||
expect(tool.description).toBe('Tool from test-server')
|
expect(tool.description).toBe('Tool from test-server')
|
||||||
})
|
})
|
||||||
|
|
||||||
@@ -202,13 +205,13 @@ describe('mcp utils', () => {
|
|||||||
|
|
||||||
const result = convertMcpToolsToAiSdkTools(mcpTools)
|
const result = convertMcpToolsToAiSdkTools(mcpTools)
|
||||||
|
|
||||||
expect(Object.keys(result)).toEqual(['complex-tool'])
|
expect(Object.keys(result)).toEqual(['complex-tool-id'])
|
||||||
const tool = result['complex-tool'] as Tool
|
const tool = result['complex-tool-id'] as Tool
|
||||||
expect(tool.inputSchema).toBeDefined()
|
expect(tool.inputSchema).toBeDefined()
|
||||||
expect(typeof tool.execute).toBe('function')
|
expect(typeof tool.execute).toBe('function')
|
||||||
})
|
})
|
||||||
|
|
||||||
it('should preserve tool names with special characters', () => {
|
it('should preserve tool id with special characters', () => {
|
||||||
const mcpTools: MCPTool[] = [
|
const mcpTools: MCPTool[] = [
|
||||||
{
|
{
|
||||||
id: 'special-tool-id',
|
id: 'special-tool-id',
|
||||||
@@ -225,7 +228,8 @@ describe('mcp utils', () => {
|
|||||||
]
|
]
|
||||||
|
|
||||||
const result = convertMcpToolsToAiSdkTools(mcpTools)
|
const result = convertMcpToolsToAiSdkTools(mcpTools)
|
||||||
expect(Object.keys(result)).toEqual(['tool_with-special.chars'])
|
// Tools are keyed by id for uniqueness
|
||||||
|
expect(Object.keys(result)).toEqual(['special-tool-id'])
|
||||||
})
|
})
|
||||||
|
|
||||||
it('should handle multiple tools with different schemas', () => {
|
it('should handle multiple tools with different schemas', () => {
|
||||||
@@ -276,10 +280,11 @@ describe('mcp utils', () => {
|
|||||||
|
|
||||||
const result = convertMcpToolsToAiSdkTools(mcpTools)
|
const result = convertMcpToolsToAiSdkTools(mcpTools)
|
||||||
|
|
||||||
expect(Object.keys(result).sort()).toEqual(['boolean-tool', 'number-tool', 'string-tool'])
|
// Tools are keyed by id for uniqueness
|
||||||
expect(result['string-tool']).toBeDefined()
|
expect(Object.keys(result).sort()).toEqual(['boolean-tool-id', 'number-tool-id', 'string-tool-id'])
|
||||||
expect(result['number-tool']).toBeDefined()
|
expect(result['string-tool-id']).toBeDefined()
|
||||||
expect(result['boolean-tool']).toBeDefined()
|
expect(result['number-tool-id']).toBeDefined()
|
||||||
|
expect(result['boolean-tool-id']).toBeDefined()
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
@@ -310,7 +315,7 @@ describe('mcp utils', () => {
|
|||||||
]
|
]
|
||||||
|
|
||||||
const tools = convertMcpToolsToAiSdkTools(mcpTools)
|
const tools = convertMcpToolsToAiSdkTools(mcpTools)
|
||||||
const tool = tools['test-exec-tool'] as Tool
|
const tool = tools['test-exec-tool-id'] as Tool
|
||||||
const result = await tool.execute!({}, { messages: [], abortSignal: undefined, toolCallId: 'test-call-123' })
|
const result = await tool.execute!({}, { messages: [], abortSignal: undefined, toolCallId: 'test-call-123' })
|
||||||
|
|
||||||
expect(requestToolConfirmation).toHaveBeenCalled()
|
expect(requestToolConfirmation).toHaveBeenCalled()
|
||||||
@@ -343,7 +348,7 @@ describe('mcp utils', () => {
|
|||||||
]
|
]
|
||||||
|
|
||||||
const tools = convertMcpToolsToAiSdkTools(mcpTools)
|
const tools = convertMcpToolsToAiSdkTools(mcpTools)
|
||||||
const tool = tools['cancelled-tool'] as Tool
|
const tool = tools['cancelled-tool-id'] as Tool
|
||||||
const result = await tool.execute!({}, { messages: [], abortSignal: undefined, toolCallId: 'cancel-call-123' })
|
const result = await tool.execute!({}, { messages: [], abortSignal: undefined, toolCallId: 'cancel-call-123' })
|
||||||
|
|
||||||
expect(requestToolConfirmation).toHaveBeenCalled()
|
expect(requestToolConfirmation).toHaveBeenCalled()
|
||||||
@@ -385,7 +390,7 @@ describe('mcp utils', () => {
|
|||||||
]
|
]
|
||||||
|
|
||||||
const tools = convertMcpToolsToAiSdkTools(mcpTools)
|
const tools = convertMcpToolsToAiSdkTools(mcpTools)
|
||||||
const tool = tools['error-tool'] as Tool
|
const tool = tools['error-tool-id'] as Tool
|
||||||
|
|
||||||
await expect(
|
await expect(
|
||||||
tool.execute!({}, { messages: [], abortSignal: undefined, toolCallId: 'error-call-123' })
|
tool.execute!({}, { messages: [], abortSignal: undefined, toolCallId: 'error-call-123' })
|
||||||
@@ -421,7 +426,7 @@ describe('mcp utils', () => {
|
|||||||
]
|
]
|
||||||
|
|
||||||
const tools = convertMcpToolsToAiSdkTools(mcpTools)
|
const tools = convertMcpToolsToAiSdkTools(mcpTools)
|
||||||
const tool = tools['auto-approve-tool'] as Tool
|
const tool = tools['auto-approve-tool-id'] as Tool
|
||||||
const result = await tool.execute!({}, { messages: [], abortSignal: undefined, toolCallId: 'auto-call-123' })
|
const result = await tool.execute!({}, { messages: [], abortSignal: undefined, toolCallId: 'auto-call-123' })
|
||||||
|
|
||||||
expect(requestToolConfirmation).not.toHaveBeenCalled()
|
expect(requestToolConfirmation).not.toHaveBeenCalled()
|
||||||
|
|||||||
@@ -154,6 +154,10 @@ vi.mock('../websearch', () => ({
|
|||||||
getWebSearchParams: vi.fn(() => ({ enable_search: true }))
|
getWebSearchParams: vi.fn(() => ({ enable_search: true }))
|
||||||
}))
|
}))
|
||||||
|
|
||||||
|
vi.mock('../../prepareParams/header', () => ({
|
||||||
|
addAnthropicHeaders: vi.fn(() => ['context-1m-2025-08-07'])
|
||||||
|
}))
|
||||||
|
|
||||||
const ensureWindowApi = () => {
|
const ensureWindowApi = () => {
|
||||||
const globalWindow = window as any
|
const globalWindow = window as any
|
||||||
globalWindow.api = globalWindow.api || {}
|
globalWindow.api = globalWindow.api || {}
|
||||||
@@ -633,5 +637,64 @@ describe('options utils', () => {
|
|||||||
expect(result.providerOptions).toHaveProperty('anthropic')
|
expect(result.providerOptions).toHaveProperty('anthropic')
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
|
describe('AWS Bedrock provider', () => {
|
||||||
|
const bedrockProvider = {
|
||||||
|
id: 'bedrock',
|
||||||
|
name: 'AWS Bedrock',
|
||||||
|
type: 'aws-bedrock',
|
||||||
|
apiKey: 'test-key',
|
||||||
|
apiHost: 'https://bedrock.us-east-1.amazonaws.com',
|
||||||
|
models: [] as Model[]
|
||||||
|
} as Provider
|
||||||
|
|
||||||
|
const bedrockModel: Model = {
|
||||||
|
id: 'anthropic.claude-sonnet-4-20250514-v1:0',
|
||||||
|
name: 'Claude Sonnet 4',
|
||||||
|
provider: 'bedrock'
|
||||||
|
} as Model
|
||||||
|
|
||||||
|
it('should build basic Bedrock options', () => {
|
||||||
|
const result = buildProviderOptions(mockAssistant, bedrockModel, bedrockProvider, {
|
||||||
|
enableReasoning: false,
|
||||||
|
enableWebSearch: false,
|
||||||
|
enableGenerateImage: false
|
||||||
|
})
|
||||||
|
|
||||||
|
expect(result.providerOptions).toHaveProperty('bedrock')
|
||||||
|
expect(result.providerOptions.bedrock).toBeDefined()
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should include anthropicBeta when Anthropic headers are needed', async () => {
|
||||||
|
const { addAnthropicHeaders } = await import('../../prepareParams/header')
|
||||||
|
vi.mocked(addAnthropicHeaders).mockReturnValue(['interleaved-thinking-2025-05-14', 'context-1m-2025-08-07'])
|
||||||
|
|
||||||
|
const result = buildProviderOptions(mockAssistant, bedrockModel, bedrockProvider, {
|
||||||
|
enableReasoning: false,
|
||||||
|
enableWebSearch: false,
|
||||||
|
enableGenerateImage: false
|
||||||
|
})
|
||||||
|
|
||||||
|
expect(result.providerOptions.bedrock).toHaveProperty('anthropicBeta')
|
||||||
|
expect(result.providerOptions.bedrock.anthropicBeta).toEqual([
|
||||||
|
'interleaved-thinking-2025-05-14',
|
||||||
|
'context-1m-2025-08-07'
|
||||||
|
])
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should include reasoning parameters when enabled', () => {
|
||||||
|
const result = buildProviderOptions(mockAssistant, bedrockModel, bedrockProvider, {
|
||||||
|
enableReasoning: true,
|
||||||
|
enableWebSearch: false,
|
||||||
|
enableGenerateImage: false
|
||||||
|
})
|
||||||
|
|
||||||
|
expect(result.providerOptions.bedrock).toHaveProperty('reasoningConfig')
|
||||||
|
expect(result.providerOptions.bedrock.reasoningConfig).toEqual({
|
||||||
|
type: 'enabled',
|
||||||
|
budgetTokens: 5000
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -144,7 +144,7 @@ describe('reasoning utils', () => {
|
|||||||
expect(result).toEqual({})
|
expect(result).toEqual({})
|
||||||
})
|
})
|
||||||
|
|
||||||
it('should disable reasoning for OpenRouter when no reasoning effort set', async () => {
|
it('should not override reasoning for OpenRouter when reasoning effort undefined', async () => {
|
||||||
const { isReasoningModel } = await import('@renderer/config/models')
|
const { isReasoningModel } = await import('@renderer/config/models')
|
||||||
|
|
||||||
vi.mocked(isReasoningModel).mockReturnValue(true)
|
vi.mocked(isReasoningModel).mockReturnValue(true)
|
||||||
@@ -161,6 +161,29 @@ describe('reasoning utils', () => {
|
|||||||
settings: {}
|
settings: {}
|
||||||
} as Assistant
|
} as Assistant
|
||||||
|
|
||||||
|
const result = getReasoningEffort(assistant, model)
|
||||||
|
expect(result).toEqual({})
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should disable reasoning for OpenRouter when reasoning effort explicitly none', async () => {
|
||||||
|
const { isReasoningModel } = await import('@renderer/config/models')
|
||||||
|
|
||||||
|
vi.mocked(isReasoningModel).mockReturnValue(true)
|
||||||
|
|
||||||
|
const model: Model = {
|
||||||
|
id: 'anthropic/claude-sonnet-4',
|
||||||
|
name: 'Claude Sonnet 4',
|
||||||
|
provider: SystemProviderIds.openrouter
|
||||||
|
} as Model
|
||||||
|
|
||||||
|
const assistant: Assistant = {
|
||||||
|
id: 'test',
|
||||||
|
name: 'Test',
|
||||||
|
settings: {
|
||||||
|
reasoning_effort: 'none'
|
||||||
|
}
|
||||||
|
} as Assistant
|
||||||
|
|
||||||
const result = getReasoningEffort(assistant, model)
|
const result = getReasoningEffort(assistant, model)
|
||||||
expect(result).toEqual({ reasoning: { enabled: false, exclude: true } })
|
expect(result).toEqual({ reasoning: { enabled: false, exclude: true } })
|
||||||
})
|
})
|
||||||
@@ -269,7 +292,9 @@ describe('reasoning utils', () => {
|
|||||||
const assistant: Assistant = {
|
const assistant: Assistant = {
|
||||||
id: 'test',
|
id: 'test',
|
||||||
name: 'Test',
|
name: 'Test',
|
||||||
settings: {}
|
settings: {
|
||||||
|
reasoning_effort: 'none'
|
||||||
|
}
|
||||||
} as Assistant
|
} as Assistant
|
||||||
|
|
||||||
const result = getReasoningEffort(assistant, model)
|
const result = getReasoningEffort(assistant, model)
|
||||||
|
|||||||
@@ -28,7 +28,9 @@ export function convertMcpToolsToAiSdkTools(mcpTools: MCPTool[]): ToolSet {
|
|||||||
const tools: ToolSet = {}
|
const tools: ToolSet = {}
|
||||||
|
|
||||||
for (const mcpTool of mcpTools) {
|
for (const mcpTool of mcpTools) {
|
||||||
tools[mcpTool.name] = tool({
|
// Use mcpTool.id (which includes serverId suffix) to ensure uniqueness
|
||||||
|
// when multiple instances of the same MCP server type are configured
|
||||||
|
tools[mcpTool.id] = tool({
|
||||||
description: mcpTool.description || `Tool from ${mcpTool.serverName}`,
|
description: mcpTool.description || `Tool from ${mcpTool.serverName}`,
|
||||||
inputSchema: jsonSchema(mcpTool.inputSchema as JSONSchema7),
|
inputSchema: jsonSchema(mcpTool.inputSchema as JSONSchema7),
|
||||||
execute: async (params, { toolCallId }) => {
|
execute: async (params, { toolCallId }) => {
|
||||||
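Note on the hunk above: keying the tool set by `mcpTool.id` instead of `mcpTool.name` matters when two configured servers expose a tool with the same name, because the later entry would otherwise overwrite the earlier one. The sketch below illustrates the collision with plain objects. `McpToolLike`, `keyBy`, and the sample ids are hypothetical and only mirror the fields visible in this diff; the real `MCPTool` type and id format may differ.

```typescript
// Hypothetical, reduced shape of an MCP tool; only fields seen in the diff are included.
interface McpToolLike {
  id: string
  name: string
  serverName: string
  description?: string
}

// Build a lookup keyed by whatever `pick` returns; later entries overwrite earlier ones.
function keyBy(tools: McpToolLike[], pick: (tool: McpToolLike) => string): Record<string, McpToolLike> {
  const result: Record<string, McpToolLike> = {}
  for (const tool of tools) {
    result[pick(tool)] = tool
  }
  return result
}

// Two instances of the same server type expose a tool with the same name.
const sampleTools: McpToolLike[] = [
  { id: 'search-serverA', name: 'search', serverName: 'A' },
  { id: 'search-serverB', name: 'search', serverName: 'B' }
]

const byName = keyBy(sampleTools, (tool) => tool.name) // collision: one key survives
const byId = keyBy(sampleTools, (tool) => tool.id) // both tools survive

console.log(Object.keys(byName), Object.keys(byId))
```

Keyed by name, only the tool from server B remains; keyed by id, both remain, which is why the updated tests expect keys such as `'get-weather-id'`.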
|
|||||||
@@ -36,6 +36,7 @@ import { isSupportServiceTierProvider, isSupportVerbosityProvider } from '@rende
|
|||||||
import type { JSONValue } from 'ai'
|
import type { JSONValue } from 'ai'
|
||||||
import { t } from 'i18next'
|
import { t } from 'i18next'
|
||||||
|
|
||||||
|
import { addAnthropicHeaders } from '../prepareParams/header'
|
||||||
import { getAiSdkProviderId } from '../provider/factory'
|
import { getAiSdkProviderId } from '../provider/factory'
|
||||||
import { buildGeminiGenerateImageParams } from './image'
|
import { buildGeminiGenerateImageParams } from './image'
|
||||||
import {
|
import {
|
||||||
@@ -469,6 +470,11 @@ function buildBedrockProviderOptions(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const betaHeaders = addAnthropicHeaders(assistant, model)
|
||||||
|
if (betaHeaders.length > 0) {
|
||||||
|
providerOptions.anthropicBeta = betaHeaders
|
||||||
|
}
|
||||||
|
|
||||||
return providerOptions
|
return providerOptions
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -16,10 +16,8 @@ import {
|
|||||||
isGPT5SeriesModel,
|
isGPT5SeriesModel,
|
||||||
isGPT51SeriesModel,
|
isGPT51SeriesModel,
|
||||||
isGrok4FastReasoningModel,
|
isGrok4FastReasoningModel,
|
||||||
isGrokReasoningModel,
|
|
||||||
isOpenAIDeepResearchModel,
|
isOpenAIDeepResearchModel,
|
||||||
isOpenAIModel,
|
isOpenAIModel,
|
||||||
isOpenAIReasoningModel,
|
|
||||||
isQwenAlwaysThinkModel,
|
isQwenAlwaysThinkModel,
|
||||||
isQwenReasoningModel,
|
isQwenReasoningModel,
|
||||||
isReasoningModel,
|
isReasoningModel,
|
||||||
@@ -64,30 +62,22 @@ export function getReasoningEffort(assistant: Assistant, model: Model): Reasonin
|
|||||||
}
|
}
|
||||||
const reasoningEffort = assistant?.settings?.reasoning_effort
|
const reasoningEffort = assistant?.settings?.reasoning_effort
|
||||||
|
|
||||||
// Handle undefined and 'none' reasoningEffort.
|
// reasoningEffort is not set, no extra reasoning setting
|
||||||
// TODO: They should be separated.
|
// Generally, for every model which supports reasoning control, the reasoning effort won't be undefined.
|
||||||
if (!reasoningEffort || reasoningEffort === 'none') {
|
// It's for some reasoning models that don't support reasoning control, such as deepseek reasoner.
|
||||||
|
if (!reasoningEffort) {
|
||||||
|
return {}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle 'none' reasoningEffort. It's explicitly off.
|
||||||
|
if (reasoningEffort === 'none') {
|
||||||
// openrouter: use reasoning
|
// openrouter: use reasoning
|
||||||
if (model.provider === SystemProviderIds.openrouter) {
|
if (model.provider === SystemProviderIds.openrouter) {
|
||||||
// Don't disable reasoning for Gemini models that support thinking tokens
|
|
||||||
if (isSupportedThinkingTokenGeminiModel(model) && !GEMINI_FLASH_MODEL_REGEX.test(model.id)) {
|
|
||||||
return {}
|
|
||||||
}
|
|
||||||
// 'none' is not an available value for effort for now.
|
// 'none' is not an available value for effort for now.
|
||||||
// I think they should resolve this issue soon, so I'll just go ahead and use this value.
|
// I think they should resolve this issue soon, so I'll just go ahead and use this value.
|
||||||
if (isGPT51SeriesModel(model) && reasoningEffort === 'none') {
|
if (isGPT51SeriesModel(model) && reasoningEffort === 'none') {
|
||||||
return { reasoning: { effort: 'none' } }
|
return { reasoning: { effort: 'none' } }
|
||||||
}
|
}
|
||||||
// Don't disable reasoning for models that require it
|
|
||||||
if (
|
|
||||||
isGrokReasoningModel(model) ||
|
|
||||||
isOpenAIReasoningModel(model) ||
|
|
||||||
isQwenAlwaysThinkModel(model) ||
|
|
||||||
model.id.includes('seed-oss') ||
|
|
||||||
model.id.includes('minimax-m2')
|
|
||||||
) {
|
|
||||||
return {}
|
|
||||||
}
|
|
||||||
return { reasoning: { enabled: false, exclude: true } }
|
return { reasoning: { enabled: false, exclude: true } }
|
||||||
}
|
}
|
||||||
|
|
||||||
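Note on the hunk above: the change separates two cases that were previously handled together. An unset `reasoning_effort` now returns an empty object (nothing is overridden), while an explicit `'none'` goes on to the provider-specific "disable" logic. The sketch below is a heavily simplified illustration of that control flow, not the real `getReasoningEffort`. `resolveReasoning` and its boolean `isOpenRouter` parameter are hypothetical; the OpenRouter return value is taken from the tests in this diff, and the final `{ reasoning: { effort } }` branch is an assumption standing in for the per-provider logic not shown here.

```typescript
type ReasoningEffort = 'none' | 'low' | 'medium' | 'high' | undefined
type ReasoningParams = Record<string, unknown>

function resolveReasoning(effort: ReasoningEffort, isOpenRouter: boolean): ReasoningParams {
  // Case 1: effort is not set. Send no reasoning parameters at all.
  if (!effort) {
    return {}
  }

  // Case 2: effort is explicitly 'none'. Try to switch reasoning off.
  if (effort === 'none') {
    if (isOpenRouter) {
      return { reasoning: { enabled: false, exclude: true } }
    }
    // Simplification: other providers fall back to an empty param here.
    return {}
  }

  // Case 3 (assumed shape): pass the requested effort through.
  return { reasoning: { effort } }
}

console.log(resolveReasoning(undefined, true), resolveReasoning('none', true))
```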
@@ -101,11 +91,6 @@ export function getReasoningEffort(assistant: Assistant, model: Model): Reasonin
|
|||||||
return { enable_thinking: false }
|
return { enable_thinking: false }
|
||||||
}
|
}
|
||||||
|
|
||||||
// claude
|
|
||||||
if (isSupportedThinkingTokenClaudeModel(model)) {
|
|
||||||
return {}
|
|
||||||
}
|
|
||||||
|
|
||||||
// gemini
|
// gemini
|
||||||
if (isSupportedThinkingTokenGeminiModel(model)) {
|
if (isSupportedThinkingTokenGeminiModel(model)) {
|
||||||
if (GEMINI_FLASH_MODEL_REGEX.test(model.id)) {
|
if (GEMINI_FLASH_MODEL_REGEX.test(model.id)) {
|
||||||
@@ -118,8 +103,10 @@ export function getReasoningEffort(assistant: Assistant, model: Model): Reasonin
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
} else {
|
||||||
|
logger.warn(`Model ${model.id} cannot disable reasoning. Fallback to empty reasoning param.`)
|
||||||
|
return {}
|
||||||
}
|
}
|
||||||
return {}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// use thinking, doubao, zhipu, etc.
|
// use thinking, doubao, zhipu, etc.
|
||||||
@@ -139,6 +126,7 @@ export function getReasoningEffort(assistant: Assistant, model: Model): Reasonin
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
logger.warn(`Model ${model.id} doesn't match any disable reasoning behavior. Fallback to empty reasoning param.`)
|
||||||
return {}
|
return {}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -293,6 +281,7 @@ export function getReasoningEffort(assistant: Assistant, model: Model): Reasonin
|
|||||||
}
|
}
|
||||||
|
|
||||||
// OpenRouter models, use reasoning
|
// OpenRouter models, use reasoning
|
||||||
|
// FIXME: duplicated openrouter handling. remove one
|
||||||
if (model.provider === SystemProviderIds.openrouter) {
|
if (model.provider === SystemProviderIds.openrouter) {
|
||||||
if (isSupportedReasoningEffortModel(model) || isSupportedThinkingTokenModel(model)) {
|
if (isSupportedReasoningEffortModel(model) || isSupportedThinkingTokenModel(model)) {
|
||||||
return {
|
return {
|
||||||
@@ -684,6 +673,10 @@ export function getCustomParameters(assistant: Assistant): Record<string, any> {
|
|||||||
if (!param.name?.trim()) {
|
if (!param.name?.trim()) {
|
||||||
return acc
|
return acc
|
||||||
}
|
}
|
||||||
|
// Parse JSON type parameters
|
||||||
|
// Related: src/renderer/src/pages/settings/AssistantSettings/AssistantModelSettings.tsx:133-148
|
||||||
|
// The UI stores JSON type params as strings (e.g., '{"key":"value"}')
|
||||||
|
// This function parses them into objects before sending to the API
|
||||||
if (param.type === 'json') {
|
if (param.type === 'json') {
|
||||||
const value = param.value as string
|
const value = param.value as string
|
||||||
if (value === 'undefined') {
|
if (value === 'undefined') {
|
||||||
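Note on the hunk above: the new comments explain that the UI stores JSON-type custom parameters as strings (e.g., `'{"key":"value"}'`) and that they are parsed into objects before being sent to the API. The sketch below shows one way such parsing could look. `CustomParam` and `toRequestParams` are hypothetical names; how the real function treats the literal string `'undefined'` and parse failures is not visible in this diff, so both behaviors here are assumptions.

```typescript
// Hypothetical shape of a stored custom parameter.
interface CustomParam {
  name: string
  type: 'string' | 'number' | 'boolean' | 'json'
  value: unknown
}

function toRequestParams(params: CustomParam[]): Record<string, unknown> {
  return params.reduce<Record<string, unknown>>((acc, param) => {
    // Skip parameters without a usable name, as in the diff.
    if (!param.name.trim()) {
      return acc
    }
    if (param.type === 'json') {
      const value = param.value as string
      // Assumption: the literal string 'undefined' maps to an undefined value.
      if (value === 'undefined') {
        acc[param.name] = undefined
        return acc
      }
      try {
        acc[param.name] = JSON.parse(value)
      } catch {
        // Assumption: keep the raw string when it is not valid JSON.
        acc[param.name] = value
      }
      return acc
    }
    acc[param.name] = param.value
    return acc
  }, {})
}

console.log(toRequestParams([{ name: 'extra', type: 'json', value: '{"key":"value"}' }]))
```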
|
|||||||
@@ -215,6 +215,10 @@
|
|||||||
border-top: none !important;
|
border-top: none !important;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
.ant-collapse-header-text {
|
||||||
|
overflow-x: hidden;
|
||||||
|
}
|
||||||
|
|
||||||
.ant-slider .ant-slider-handle::after {
|
.ant-slider .ant-slider-handle::after {
|
||||||
box-shadow: 0 1px 4px 0px rgb(128 128 128 / 50%) !important;
|
box-shadow: 0 1px 4px 0px rgb(128 128 128 / 50%) !important;
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ import {
|
|||||||
} from '@ant-design/icons'
|
} from '@ant-design/icons'
|
||||||
import { loggerService } from '@logger'
|
import { loggerService } from '@logger'
|
||||||
import { download } from '@renderer/utils/download'
|
import { download } from '@renderer/utils/download'
|
||||||
|
import { convertImageToPng } from '@renderer/utils/image'
|
||||||
import type { ImageProps as AntImageProps } from 'antd'
|
import type { ImageProps as AntImageProps } from 'antd'
|
||||||
import { Dropdown, Image as AntImage, Space } from 'antd'
|
import { Dropdown, Image as AntImage, Space } from 'antd'
|
||||||
import { Base64 } from 'js-base64'
|
import { Base64 } from 'js-base64'
|
||||||
@@ -33,39 +34,38 @@ const ImageViewer: React.FC<ImageViewerProps> = ({ src, style, ...props }) => {
|
|||||||
// Copy the image to the clipboard
|
// Copy the image to the clipboard
|
||||||
const handleCopyImage = async (src: string) => {
|
const handleCopyImage = async (src: string) => {
|
||||||
try {
|
try {
|
||||||
|
let blob: Blob
|
||||||
|
|
||||||
if (src.startsWith('data:')) {
|
if (src.startsWith('data:')) {
|
||||||
// Handle base64-encoded images
|
// Handle base64-encoded images
|
||||||
const match = src.match(/^data:(image\/\w+);base64,(.+)$/)
|
const match = src.match(/^data:(image\/\w+);base64,(.+)$/)
|
||||||
if (!match) throw new Error('Invalid base64 image format')
|
if (!match) throw new Error('Invalid base64 image format')
|
||||||
const mimeType = match[1]
|
const mimeType = match[1]
|
||||||
const byteArray = Base64.toUint8Array(match[2])
|
const byteArray = Base64.toUint8Array(match[2])
|
||||||
const blob = new Blob([byteArray], { type: mimeType })
|
blob = new Blob([byteArray], { type: mimeType })
|
||||||
await navigator.clipboard.write([new ClipboardItem({ [mimeType]: blob })])
|
|
||||||
} else if (src.startsWith('file://')) {
|
} else if (src.startsWith('file://')) {
|
||||||
// Handle local file paths
|
// Handle local file paths
|
||||||
const bytes = await window.api.fs.read(src)
|
const bytes = await window.api.fs.read(src)
|
||||||
const mimeType = mime.getType(src) || 'application/octet-stream'
|
const mimeType = mime.getType(src) || 'application/octet-stream'
|
||||||
const blob = new Blob([bytes], { type: mimeType })
|
blob = new Blob([bytes], { type: mimeType })
|
||||||
await navigator.clipboard.write([
|
|
||||||
new ClipboardItem({
|
|
||||||
[mimeType]: blob
|
|
||||||
})
|
|
||||||
])
|
|
||||||
} else {
|
} else {
|
||||||
// Handle images referenced by URL
|
// Handle images referenced by URL
|
||||||
const response = await fetch(src)
|
const response = await fetch(src)
|
||||||
const blob = await response.blob()
|
blob = await response.blob()
|
||||||
|
|
||||||
await navigator.clipboard.write([
|
|
||||||
new ClipboardItem({
|
|
||||||
[blob.type]: blob
|
|
||||||
})
|
|
||||||
])
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Always convert to PNG for compatibility (the Clipboard API does not support JPEG)
|
||||||
|
const pngBlob = await convertImageToPng(blob)
|
||||||
|
|
||||||
|
const item = new ClipboardItem({
|
||||||
|
'image/png': pngBlob
|
||||||
|
})
|
||||||
|
await navigator.clipboard.write([item])
|
||||||
|
|
||||||
window.toast.success(t('message.copy.success'))
|
window.toast.success(t('message.copy.success'))
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
logger.error('Failed to copy image:', error as Error)
|
const err = error as Error
|
||||||
|
logger.error(`Failed to copy image: ${err.message}`, { stack: err.stack })
|
||||||
window.toast.error(t('message.copy.failed'))
|
window.toast.error(t('message.copy.failed'))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -57,7 +57,7 @@ const PopupContainer: React.FC<Props> = ({ model, apiFilter, modelFilter, showTa
|
|||||||
const [_searchText, setSearchText] = useState('')
|
const [_searchText, setSearchText] = useState('')
|
||||||
const searchText = useDeferredValue(_searchText)
|
const searchText = useDeferredValue(_searchText)
|
||||||
const { models, isLoading } = useApiModels(apiFilter)
|
const { models, isLoading } = useApiModels(apiFilter)
|
||||||
const adaptedModels = models.map((model) => apiModelAdapter(model))
|
const adaptedModels = useMemo(() => models.map((model) => apiModelAdapter(model)), [models])
|
||||||
|
|
||||||
// ID of the currently selected model
|
// ID of the currently selected model
|
||||||
const currentModelId = model ? model.id : ''
|
const currentModelId = model ? model.id : ''
|
||||||
|
|||||||
@@ -309,11 +309,14 @@ describe('Ling Models', () => {
|
|||||||
describe('Claude & regional providers', () => {
|
describe('Claude & regional providers', () => {
|
||||||
it('identifies claude 4.5 variants', () => {
|
it('identifies claude 4.5 variants', () => {
|
||||||
expect(isClaude45ReasoningModel(createModel({ id: 'claude-sonnet-4.5-preview' }))).toBe(true)
|
expect(isClaude45ReasoningModel(createModel({ id: 'claude-sonnet-4.5-preview' }))).toBe(true)
|
||||||
|
expect(isClaude4SeriesModel(createModel({ id: 'claude-sonnet-4-5@20250929' }))).toBe(true)
|
||||||
expect(isClaude45ReasoningModel(createModel({ id: 'claude-3-sonnet' }))).toBe(false)
|
expect(isClaude45ReasoningModel(createModel({ id: 'claude-3-sonnet' }))).toBe(false)
|
||||||
})
|
})
|
||||||
|
|
||||||
it('identifies claude 4 variants', () => {
|
it('identifies claude 4 variants', () => {
|
||||||
expect(isClaude4SeriesModel(createModel({ id: 'claude-opus-4' }))).toBe(true)
|
expect(isClaude4SeriesModel(createModel({ id: 'claude-opus-4' }))).toBe(true)
|
||||||
|
expect(isClaude4SeriesModel(createModel({ id: 'claude-sonnet-4@20250514' }))).toBe(true)
|
||||||
|
expect(isClaude4SeriesModel(createModel({ id: 'anthropic.claude-sonnet-4-20250514-v1:0' }))).toBe(true)
|
||||||
expect(isClaude4SeriesModel(createModel({ id: 'claude-4.2-sonnet-variant' }))).toBe(false)
|
expect(isClaude4SeriesModel(createModel({ id: 'claude-4.2-sonnet-variant' }))).toBe(false)
|
||||||
expect(isClaude4SeriesModel(createModel({ id: 'claude-3-haiku' }))).toBe(false)
|
expect(isClaude4SeriesModel(createModel({ id: 'claude-3-haiku' }))).toBe(false)
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ import {
|
|||||||
isSupportVerbosityModel
|
isSupportVerbosityModel
|
||||||
} from '../openai'
|
} from '../openai'
|
||||||
import { isQwenMTModel } from '../qwen'
|
import { isQwenMTModel } from '../qwen'
|
||||||
|
import { isFunctionCallingModel } from '../tooluse'
|
||||||
import {
|
import {
|
||||||
agentModelFilter,
|
agentModelFilter,
|
||||||
getModelSupportedVerbosity,
|
getModelSupportedVerbosity,
|
||||||
@@ -112,6 +113,7 @@ const textToImageMock = vi.mocked(isTextToImageModel)
|
|||||||
const generateImageMock = vi.mocked(isGenerateImageModel)
|
const generateImageMock = vi.mocked(isGenerateImageModel)
|
||||||
const reasoningMock = vi.mocked(isOpenAIReasoningModel)
|
const reasoningMock = vi.mocked(isOpenAIReasoningModel)
|
||||||
const openAIWebSearchOnlyMock = vi.mocked(isOpenAIWebSearchChatCompletionOnlyModel)
|
const openAIWebSearchOnlyMock = vi.mocked(isOpenAIWebSearchChatCompletionOnlyModel)
|
||||||
|
const isFunctionCallingModelMock = vi.mocked(isFunctionCallingModel)
|
||||||
|
|
||||||
describe('model utils', () => {
|
describe('model utils', () => {
|
||||||
beforeEach(() => {
|
beforeEach(() => {
|
||||||
@@ -120,200 +122,387 @@ describe('model utils', () => {
|
|||||||
rerankMock.mockReturnValue(false)
|
rerankMock.mockReturnValue(false)
|
||||||
visionMock.mockReturnValue(true)
|
visionMock.mockReturnValue(true)
|
||||||
textToImageMock.mockReturnValue(false)
|
textToImageMock.mockReturnValue(false)
|
||||||
generateImageMock.mockReturnValue(true)
|
generateImageMock.mockReturnValue(false)
|
||||||
reasoningMock.mockReturnValue(false)
|
reasoningMock.mockReturnValue(false)
|
||||||
openAIWebSearchOnlyMock.mockReturnValue(false)
|
openAIWebSearchOnlyMock.mockReturnValue(false)
|
||||||
})
|
})
|
||||||
|
|
||||||
it('detects OpenAI LLM models through reasoning and GPT prefix', () => {
|
describe('OpenAI model detection', () => {
|
||||||
expect(isOpenAILLMModel(undefined as unknown as Model)).toBe(false)
|
describe('isOpenAILLMModel', () => {
|
||||||
expect(isOpenAILLMModel(createModel({ id: 'gpt-4o-image' }))).toBe(false)
|
it('returns false for undefined model', () => {
|
||||||
|
expect(isOpenAILLMModel(undefined as unknown as Model)).toBe(false)
|
||||||
|
})
|
||||||
|
|
||||||
reasoningMock.mockReturnValueOnce(true)
|
it('returns false for image generation models', () => {
|
||||||
expect(isOpenAILLMModel(createModel({ id: 'o1-preview' }))).toBe(true)
|
expect(isOpenAILLMModel(createModel({ id: 'gpt-4o-image' }))).toBe(false)
|
||||||
|
})
|
||||||
|
|
||||||
expect(isOpenAILLMModel(createModel({ id: 'GPT-5-turbo' }))).toBe(true)
|
it('returns true for reasoning models', () => {
|
||||||
})
|
reasoningMock.mockReturnValueOnce(true)
|
||||||
|
expect(isOpenAILLMModel(createModel({ id: 'o1-preview' }))).toBe(true)
|
||||||
|
})
|
||||||
|
|
||||||
it('detects OpenAI models via GPT prefix or reasoning support', () => {
|
it('returns true for GPT-prefixed models', () => {
|
||||||
expect(isOpenAIModel(createModel({ id: 'gpt-4.1' }))).toBe(true)
|
expect(isOpenAILLMModel(createModel({ id: 'GPT-5-turbo' }))).toBe(true)
|
||||||
reasoningMock.mockReturnValueOnce(true)
|
})
|
||||||
expect(isOpenAIModel(createModel({ id: 'o3' }))).toBe(true)
|
|
||||||
})
|
|
||||||
|
|
||||||
it('evaluates support for flex service tier and alias helper', () => {
|
|
||||||
expect(isSupportFlexServiceTierModel(createModel({ id: 'o3' }))).toBe(true)
|
|
||||||
expect(isSupportFlexServiceTierModel(createModel({ id: 'o3-mini' }))).toBe(false)
|
|
||||||
expect(isSupportFlexServiceTierModel(createModel({ id: 'o4-mini' }))).toBe(true)
|
|
||||||
expect(isSupportFlexServiceTierModel(createModel({ id: 'gpt-5-preview' }))).toBe(true)
|
|
||||||
expect(isSupportedFlexServiceTier(createModel({ id: 'gpt-4o' }))).toBe(false)
|
|
||||||
})
|
|
||||||
|
|
||||||
it('detects verbosity support for GPT-5+ families', () => {
|
|
||||||
expect(isSupportVerbosityModel(createModel({ id: 'gpt-5' }))).toBe(true)
|
|
||||||
expect(isSupportVerbosityModel(createModel({ id: 'gpt-5-chat' }))).toBe(false)
|
|
||||||
expect(isSupportVerbosityModel(createModel({ id: 'gpt-5.1-preview' }))).toBe(true)
|
|
||||||
})
|
|
||||||
|
|
||||||
it('limits verbosity controls for GPT-5 Pro models', () => {
|
|
||||||
const proModel = createModel({ id: 'gpt-5-pro' })
|
|
||||||
const previewModel = createModel({ id: 'gpt-5-preview' })
|
|
||||||
expect(getModelSupportedVerbosity(proModel)).toEqual([undefined, 'high'])
|
|
||||||
expect(getModelSupportedVerbosity(previewModel)).toEqual([undefined, 'low', 'medium', 'high'])
|
|
||||||
expect(isGPT5ProModel(proModel)).toBe(true)
|
|
||||||
expect(isGPT5ProModel(previewModel)).toBe(false)
|
|
||||||
})
|
|
||||||
|
|
||||||
it('identifies OpenAI chat-completion-only models', () => {
|
|
||||||
expect(isOpenAIChatCompletionOnlyModel(createModel({ id: 'gpt-4o-search-preview' }))).toBe(true)
|
|
||||||
expect(isOpenAIChatCompletionOnlyModel(createModel({ id: 'o1-mini' }))).toBe(true)
|
|
||||||
expect(isOpenAIChatCompletionOnlyModel(createModel({ id: 'gpt-4o' }))).toBe(false)
|
|
||||||
})
|
|
||||||
|
|
||||||
it('filters unsupported OpenAI catalog entries', () => {
|
|
||||||
expect(isSupportedModel({ id: 'gpt-4', object: 'model' } as any)).toBe(true)
|
|
||||||
expect(isSupportedModel({ id: 'tts-1', object: 'model' } as any)).toBe(false)
|
|
||||||
})
|
|
||||||
|
|
||||||
it('calculates temperature/top-p support correctly', () => {
|
|
||||||
const model = createModel({ id: 'o1' })
|
|
||||||
reasoningMock.mockReturnValue(true)
|
|
||||||
expect(isNotSupportTemperatureAndTopP(model)).toBe(true)
|
|
||||||
|
|
||||||
const openWeight = createModel({ id: 'gpt-oss-debug' })
|
|
||||||
expect(isNotSupportTemperatureAndTopP(openWeight)).toBe(false)
|
|
||||||
|
|
||||||
const chatOnly = createModel({ id: 'o1-preview' })
|
|
||||||
reasoningMock.mockReturnValue(false)
|
|
||||||
expect(isNotSupportTemperatureAndTopP(chatOnly)).toBe(true)
|
|
||||||
|
|
||||||
const qwenMt = createModel({ id: 'qwen-mt-large', provider: 'aliyun' })
|
|
||||||
expect(isNotSupportTemperatureAndTopP(qwenMt)).toBe(true)
|
|
||||||
})
|
|
||||||
|
|
||||||
it('handles gemma and gemini detections plus zhipu tagging', () => {
|
|
||||||
expect(isGemmaModel(createModel({ id: 'Gemma-3-27B' }))).toBe(true)
|
|
||||||
expect(isGemmaModel(createModel({ group: 'Gemma' }))).toBe(true)
|
|
||||||
expect(isGemmaModel(createModel({ id: 'gpt-4o' }))).toBe(false)
|
|
||||||
|
|
||||||
expect(isGeminiModel(createModel({ id: 'Gemini-2.0' }))).toBe(true)
|
|
||||||
|
|
||||||
expect(isZhipuModel(createModel({ provider: 'zhipu' }))).toBe(true)
|
|
||||||
expect(isZhipuModel(createModel({ provider: 'openai' }))).toBe(false)
|
|
||||||
})
|
|
||||||
|
|
||||||
it('groups qwen models by prefix', () => {
|
|
||||||
const qwen = createModel({ id: 'Qwen-7B', provider: 'qwen', name: 'Qwen-7B' })
|
|
||||||
const qwenOmni = createModel({ id: 'qwen2.5-omni', name: 'qwen2.5-omni' })
|
|
||||||
const other = createModel({ id: 'deepseek-v3', group: 'DeepSeek' })
|
|
||||||
|
|
||||||
const grouped = groupQwenModels([qwen, qwenOmni, other])
|
|
||||||
expect(Object.keys(grouped)).toContain('qwen-7b')
|
|
||||||
expect(Object.keys(grouped)).toContain('qwen2.5')
|
|
||||||
expect(grouped.DeepSeek).toContain(other)
|
|
||||||
})
|
|
||||||
|
|
||||||
it('aggregates boolean helpers based on regex rules', () => {
|
|
||||||
expect(isAnthropicModel(createModel({ id: 'claude-3.5' }))).toBe(true)
|
|
||||||
expect(isQwenMTModel(createModel({ id: 'qwen-mt-plus' }))).toBe(true)
|
|
||||||
expect(isNotSupportSystemMessageModel(createModel({ id: 'gemma-moe' }))).toBe(true)
|
|
||||||
expect(isOpenAIOpenWeightModel(createModel({ id: 'gpt-oss-free' }))).toBe(true)
|
|
||||||
})
|
|
||||||
|
|
||||||
describe('isNotSupportedTextDelta', () => {
|
|
||||||
it('returns true for qwen-mt-turbo and qwen-mt-plus models', () => {
|
|
||||||
// qwen-mt series that don't support text delta
|
|
||||||
expect(isNotSupportTextDeltaModel(createModel({ id: 'qwen-mt-turbo' }))).toBe(true)
|
|
||||||
expect(isNotSupportTextDeltaModel(createModel({ id: 'qwen-mt-plus' }))).toBe(true)
|
|
||||||
expect(isNotSupportTextDeltaModel(createModel({ id: 'Qwen-MT-Turbo' }))).toBe(true)
|
|
||||||
expect(isNotSupportTextDeltaModel(createModel({ id: 'QWEN-MT-PLUS' }))).toBe(true)
|
|
||||||
})
|
})
|
||||||
|
|
||||||
it('returns false for qwen-mt-flash and other models', () => {
|
describe('isOpenAIModel', () => {
|
||||||
// qwen-mt-flash supports text delta
|
it('detects models via GPT prefix', () => {
|
||||||
expect(isNotSupportTextDeltaModel(createModel({ id: 'qwen-mt-flash' }))).toBe(false)
|
expect(isOpenAIModel(createModel({ id: 'gpt-4.1' }))).toBe(true)
|
||||||
expect(isNotSupportTextDeltaModel(createModel({ id: 'Qwen-MT-Flash' }))).toBe(false)
|
})
|
||||||
|
|
||||||
// Legacy qwen models without mt prefix (support text delta)
|
it('detects models via reasoning support', () => {
|
||||||
expect(isNotSupportTextDeltaModel(createModel({ id: 'qwen-turbo' }))).toBe(false)
|
reasoningMock.mockReturnValueOnce(true)
|
||||||
expect(isNotSupportTextDeltaModel(createModel({ id: 'qwen-plus' }))).toBe(false)
|
expect(isOpenAIModel(createModel({ id: 'o3' }))).toBe(true)
|
||||||
|
})
|
||||||
// Other qwen models
|
|
||||||
expect(isNotSupportTextDeltaModel(createModel({ id: 'qwen-max' }))).toBe(false)
|
|
||||||
expect(isNotSupportTextDeltaModel(createModel({ id: 'qwen2.5-72b' }))).toBe(false)
|
|
||||||
expect(isNotSupportTextDeltaModel(createModel({ id: 'qwen-vl-plus' }))).toBe(false)
|
|
||||||
|
|
||||||
// Non-qwen models
|
|
||||||
expect(isNotSupportTextDeltaModel(createModel({ id: 'gpt-4o' }))).toBe(false)
|
|
||||||
expect(isNotSupportTextDeltaModel(createModel({ id: 'claude-3.5' }))).toBe(false)
|
|
||||||
expect(isNotSupportTextDeltaModel(createModel({ id: 'glm-4-plus' }))).toBe(false)
|
|
||||||
})
|
})
|
||||||
|
|
||||||
it('handles models with version suffixes', () => {
|
describe('isOpenAIChatCompletionOnlyModel', () => {
|
||||||
// qwen-mt models with version suffixes
|
it('identifies chat-completion-only models', () => {
|
||||||
expect(isNotSupportTextDeltaModel(createModel({ id: 'qwen-mt-turbo-1201' }))).toBe(true)
|
expect(isOpenAIChatCompletionOnlyModel(createModel({ id: 'gpt-4o-search-preview' }))).toBe(true)
|
||||||
expect(isNotSupportTextDeltaModel(createModel({ id: 'qwen-mt-plus-0828' }))).toBe(true)
|
expect(isOpenAIChatCompletionOnlyModel(createModel({ id: 'o1-mini' }))).toBe(true)
|
||||||
|
})
|
||||||
|
|
||||||
// Legacy qwen models with version suffixes (support text delta)
|
it('returns false for general models', () => {
|
||||||
expect(isNotSupportTextDeltaModel(createModel({ id: 'qwen-turbo-0828' }))).toBe(false)
|
expect(isOpenAIChatCompletionOnlyModel(createModel({ id: 'gpt-4o' }))).toBe(false)
|
||||||
expect(isNotSupportTextDeltaModel(createModel({ id: 'qwen-plus-latest' }))).toBe(false)
|
})
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
it('evaluates GPT-5 family helpers', () => {
|
describe('GPT-5 family detection', () => {
|
||||||
expect(isGPT5SeriesModel(createModel({ id: 'gpt-5-preview' }))).toBe(true)
|
describe('isGPT5SeriesModel', () => {
|
||||||
expect(isGPT5SeriesModel(createModel({ id: 'gpt-5.1-preview' }))).toBe(false)
|
it('returns true for GPT-5 models', () => {
|
||||||
expect(isGPT51SeriesModel(createModel({ id: 'gpt-5.1-mini' }))).toBe(true)
|
expect(isGPT5SeriesModel(createModel({ id: 'gpt-5-preview' }))).toBe(true)
|
||||||
expect(isGPT5SeriesReasoningModel(createModel({ id: 'gpt-5-prompt' }))).toBe(true)
|
})
|
||||||
expect(isSupportVerbosityModel(createModel({ id: 'gpt-5-chat' }))).toBe(false)
|
|
||||||
|
it('returns false for GPT-5.1 models', () => {
|
||||||
|
expect(isGPT5SeriesModel(createModel({ id: 'gpt-5.1-preview' }))).toBe(false)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('isGPT51SeriesModel', () => {
|
||||||
|
it('returns true for GPT-5.1 models', () => {
|
||||||
|
expect(isGPT51SeriesModel(createModel({ id: 'gpt-5.1-mini' }))).toBe(true)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('isGPT5SeriesReasoningModel', () => {
|
||||||
|
it('returns true for GPT-5 reasoning models', () => {
|
||||||
|
expect(isGPT5SeriesReasoningModel(createModel({ id: 'gpt-5' }))).toBe(true)
|
||||||
|
})
|
||||||
|
it('returns false for gpt-5-chat', () => {
|
||||||
|
expect(isGPT5SeriesReasoningModel(createModel({ id: 'gpt-5-chat' }))).toBe(false)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('isGPT5ProModel', () => {
|
||||||
|
it('returns true for GPT-5 Pro models', () => {
|
||||||
|
expect(isGPT5ProModel(createModel({ id: 'gpt-5-pro' }))).toBe(true)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('returns false for non-Pro GPT-5 models', () => {
|
||||||
|
expect(isGPT5ProModel(createModel({ id: 'gpt-5-preview' }))).toBe(false)
|
||||||
|
})
|
||||||
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
it('wraps generate/vision helpers that operate on arrays', () => {
|
describe('Verbosity support', () => {
|
||||||
const models = [createModel({ id: 'gpt-4o' }), createModel({ id: 'gpt-4o-mini' })]
|
describe('isSupportVerbosityModel', () => {
|
||||||
expect(isVisionModels(models)).toBe(true)
|
it('returns true for GPT-5 models', () => {
|
||||||
visionMock.mockReturnValueOnce(true).mockReturnValueOnce(false)
|
expect(isSupportVerbosityModel(createModel({ id: 'gpt-5' }))).toBe(true)
|
||||||
expect(isVisionModels(models)).toBe(false)
|
})
|
||||||
|
|
||||||
expect(isGenerateImageModels(models)).toBe(true)
|
it('returns false for GPT-5 chat models', () => {
|
||||||
generateImageMock.mockReturnValueOnce(true).mockReturnValueOnce(false)
|
expect(isSupportVerbosityModel(createModel({ id: 'gpt-5-chat' }))).toBe(false)
|
||||||
expect(isGenerateImageModels(models)).toBe(false)
|
})
|
||||||
|
|
||||||
|
it('returns true for GPT-5.1 models', () => {
|
||||||
|
expect(isSupportVerbosityModel(createModel({ id: 'gpt-5.1-preview' }))).toBe(true)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('getModelSupportedVerbosity', () => {
|
||||||
|
it('returns only "high" for GPT-5 Pro models', () => {
|
||||||
|
expect(getModelSupportedVerbosity(createModel({ id: 'gpt-5-pro' }))).toEqual([undefined, 'high'])
|
||||||
|
expect(getModelSupportedVerbosity(createModel({ id: 'gpt-5-pro-2025-10-06' }))).toEqual([undefined, 'high'])
|
||||||
|
})
|
||||||
|
|
||||||
|
it('returns all levels for non-Pro GPT-5 models', () => {
|
||||||
|
const previewModel = createModel({ id: 'gpt-5-preview' })
|
||||||
|
expect(getModelSupportedVerbosity(previewModel)).toEqual([undefined, 'low', 'medium', 'high'])
|
||||||
|
})
|
||||||
|
|
||||||
|
it('returns all levels for GPT-5.1 models', () => {
|
||||||
|
const gpt51Model = createModel({ id: 'gpt-5.1-preview' })
|
||||||
|
expect(getModelSupportedVerbosity(gpt51Model)).toEqual([undefined, 'low', 'medium', 'high'])
|
||||||
|
})
|
||||||
|
|
||||||
|
it('returns only undefined for non-GPT-5 models', () => {
|
||||||
|
expect(getModelSupportedVerbosity(createModel({ id: 'gpt-4o' }))).toEqual([undefined])
|
||||||
|
expect(getModelSupportedVerbosity(createModel({ id: 'claude-3.5' }))).toEqual([undefined])
|
||||||
|
})
|
||||||
|
|
||||||
|
it('returns only undefined for undefined/null input', () => {
|
||||||
|
expect(getModelSupportedVerbosity(undefined)).toEqual([undefined])
|
||||||
|
expect(getModelSupportedVerbosity(null)).toEqual([undefined])
|
||||||
|
})
|
||||||
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
it('filters models for agent usage', () => {
|
describe('Flex service tier support', () => {
|
||||||
expect(agentModelFilter(createModel())).toBe(true)
|
describe('isSupportFlexServiceTierModel', () => {
|
||||||
|
it('returns true for supported models', () => {
|
||||||
|
expect(isSupportFlexServiceTierModel(createModel({ id: 'o3' }))).toBe(true)
|
||||||
|
expect(isSupportFlexServiceTierModel(createModel({ id: 'o4-mini' }))).toBe(true)
|
||||||
|
expect(isSupportFlexServiceTierModel(createModel({ id: 'gpt-5-preview' }))).toBe(true)
|
||||||
|
})
|
||||||
|
|
||||||
embeddingMock.mockReturnValueOnce(true)
|
it('returns false for unsupported models', () => {
|
||||||
expect(agentModelFilter(createModel({ id: 'text-embedding' }))).toBe(false)
|
expect(isSupportFlexServiceTierModel(createModel({ id: 'o3-mini' }))).toBe(false)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
embeddingMock.mockReturnValue(false)
|
describe('isSupportedFlexServiceTier', () => {
|
||||||
rerankMock.mockReturnValueOnce(true)
|
it('returns false for non-flex models', () => {
|
||||||
expect(agentModelFilter(createModel({ id: 'rerank' }))).toBe(false)
|
expect(isSupportedFlexServiceTier(createModel({ id: 'gpt-4o' }))).toBe(false)
|
||||||
|
})
|
||||||
rerankMock.mockReturnValue(false)
|
})
|
||||||
textToImageMock.mockReturnValueOnce(true)
|
|
||||||
expect(agentModelFilter(createModel({ id: 'gpt-image-1' }))).toBe(false)
|
|
||||||
})
|
})
|
||||||
|
|
||||||
it('identifies models with maximum temperature of 1.0', () => {
|
describe('Temperature and top-p support', () => {
|
||||||
// Zhipu models should have max temperature of 1.0
|
describe('isNotSupportTemperatureAndTopP', () => {
|
||||||
expect(isMaxTemperatureOneModel(createModel({ id: 'glm-4' }))).toBe(true)
|
it('returns true for reasoning models', () => {
|
||||||
expect(isMaxTemperatureOneModel(createModel({ id: 'GLM-4-Plus' }))).toBe(true)
|
const model = createModel({ id: 'o1' })
|
||||||
expect(isMaxTemperatureOneModel(createModel({ id: 'glm-3-turbo' }))).toBe(true)
|
reasoningMock.mockReturnValue(true)
|
||||||
|
expect(isNotSupportTemperatureAndTopP(model)).toBe(true)
|
||||||
|
})
|
||||||
|
|
||||||
// Anthropic models should have max temperature of 1.0
|
it('returns false for open weight models', () => {
|
||||||
expect(isMaxTemperatureOneModel(createModel({ id: 'claude-3.5-sonnet' }))).toBe(true)
|
const openWeight = createModel({ id: 'gpt-oss-debug' })
|
||||||
expect(isMaxTemperatureOneModel(createModel({ id: 'Claude-3-opus' }))).toBe(true)
|
expect(isNotSupportTemperatureAndTopP(openWeight)).toBe(false)
|
||||||
expect(isMaxTemperatureOneModel(createModel({ id: 'claude-2.1' }))).toBe(true)
|
})
|
||||||
|
|
||||||
// Moonshot models should have max temperature of 1.0
|
it('returns true for chat-only models without reasoning', () => {
|
||||||
expect(isMaxTemperatureOneModel(createModel({ id: 'moonshot-1.0' }))).toBe(true)
|
const chatOnly = createModel({ id: 'o1-preview' })
|
||||||
expect(isMaxTemperatureOneModel(createModel({ id: 'kimi-k2-thinking' }))).toBe(true)
|
reasoningMock.mockReturnValue(false)
|
||||||
expect(isMaxTemperatureOneModel(createModel({ id: 'Moonshot-Pro' }))).toBe(true)
|
expect(isNotSupportTemperatureAndTopP(chatOnly)).toBe(true)
|
||||||
|
})
|
||||||
|
|
||||||
// Other models should return false
|
it('returns true for Qwen MT models', () => {
|
||||||
expect(isMaxTemperatureOneModel(createModel({ id: 'gpt-4o' }))).toBe(false)
|
const qwenMt = createModel({ id: 'qwen-mt-large', provider: 'aliyun' })
|
||||||
expect(isMaxTemperatureOneModel(createModel({ id: 'gpt-4-turbo' }))).toBe(false)
|
expect(isNotSupportTemperatureAndTopP(qwenMt)).toBe(true)
|
||||||
expect(isMaxTemperatureOneModel(createModel({ id: 'qwen-max' }))).toBe(false)
|
})
|
||||||
expect(isMaxTemperatureOneModel(createModel({ id: 'gemini-pro' }))).toBe(false)
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('Text delta support', () => {
|
||||||
|
describe('isNotSupportTextDeltaModel', () => {
|
||||||
|
it('returns true for qwen-mt-turbo and qwen-mt-plus models', () => {
|
||||||
|
expect(isNotSupportTextDeltaModel(createModel({ id: 'qwen-mt-turbo' }))).toBe(true)
|
||||||
|
expect(isNotSupportTextDeltaModel(createModel({ id: 'qwen-mt-plus' }))).toBe(true)
|
||||||
|
expect(isNotSupportTextDeltaModel(createModel({ id: 'Qwen-MT-Turbo' }))).toBe(true)
|
||||||
|
expect(isNotSupportTextDeltaModel(createModel({ id: 'QWEN-MT-PLUS' }))).toBe(true)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('returns false for qwen-mt-flash and other models', () => {
|
||||||
|
expect(isNotSupportTextDeltaModel(createModel({ id: 'qwen-mt-flash' }))).toBe(false)
|
||||||
|
expect(isNotSupportTextDeltaModel(createModel({ id: 'Qwen-MT-Flash' }))).toBe(false)
|
||||||
|
expect(isNotSupportTextDeltaModel(createModel({ id: 'qwen-turbo' }))).toBe(false)
|
||||||
|
expect(isNotSupportTextDeltaModel(createModel({ id: 'qwen-plus' }))).toBe(false)
|
||||||
|
expect(isNotSupportTextDeltaModel(createModel({ id: 'qwen-max' }))).toBe(false)
|
||||||
|
expect(isNotSupportTextDeltaModel(createModel({ id: 'qwen2.5-72b' }))).toBe(false)
|
||||||
|
expect(isNotSupportTextDeltaModel(createModel({ id: 'qwen-vl-plus' }))).toBe(false)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('returns false for non-qwen models', () => {
|
||||||
|
expect(isNotSupportTextDeltaModel(createModel({ id: 'gpt-4o' }))).toBe(false)
|
||||||
|
expect(isNotSupportTextDeltaModel(createModel({ id: 'claude-3.5' }))).toBe(false)
|
||||||
|
expect(isNotSupportTextDeltaModel(createModel({ id: 'glm-4-plus' }))).toBe(false)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('handles models with version suffixes', () => {
|
||||||
|
expect(isNotSupportTextDeltaModel(createModel({ id: 'qwen-mt-turbo-1201' }))).toBe(true)
|
||||||
|
expect(isNotSupportTextDeltaModel(createModel({ id: 'qwen-mt-plus-0828' }))).toBe(true)
|
||||||
|
expect(isNotSupportTextDeltaModel(createModel({ id: 'qwen-turbo-0828' }))).toBe(false)
|
||||||
|
expect(isNotSupportTextDeltaModel(createModel({ id: 'qwen-plus-latest' }))).toBe(false)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('Model provider detection', () => {
|
||||||
|
describe('isGemmaModel', () => {
|
||||||
|
it('detects Gemma models by ID', () => {
|
||||||
|
expect(isGemmaModel(createModel({ id: 'Gemma-3-27B' }))).toBe(true)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('detects Gemma models by group', () => {
|
||||||
|
expect(isGemmaModel(createModel({ group: 'Gemma' }))).toBe(true)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('returns false for non-Gemma models', () => {
|
||||||
|
expect(isGemmaModel(createModel({ id: 'gpt-4o' }))).toBe(false)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('isGeminiModel', () => {
|
||||||
|
it('detects Gemini models', () => {
|
||||||
|
expect(isGeminiModel(createModel({ id: 'Gemini-2.0' }))).toBe(true)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('isZhipuModel', () => {
|
||||||
|
it('detects Zhipu models by provider', () => {
|
||||||
|
expect(isZhipuModel(createModel({ provider: 'zhipu' }))).toBe(true)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('returns false for non-Zhipu models', () => {
|
||||||
|
expect(isZhipuModel(createModel({ provider: 'openai' }))).toBe(false)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('isAnthropicModel', () => {
|
||||||
|
it('detects Anthropic models', () => {
|
||||||
|
expect(isAnthropicModel(createModel({ id: 'claude-3.5' }))).toBe(true)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('isQwenMTModel', () => {
|
||||||
|
it('detects Qwen MT models', () => {
|
||||||
|
expect(isQwenMTModel(createModel({ id: 'qwen-mt-plus' }))).toBe(true)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('isOpenAIOpenWeightModel', () => {
|
||||||
|
it('detects OpenAI open weight models', () => {
|
||||||
|
expect(isOpenAIOpenWeightModel(createModel({ id: 'gpt-oss-free' }))).toBe(true)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('System message support', () => {
|
||||||
|
describe('isNotSupportSystemMessageModel', () => {
|
||||||
|
it('returns true for models that do not support system messages', () => {
|
||||||
|
expect(isNotSupportSystemMessageModel(createModel({ id: 'gemma-moe' }))).toBe(true)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('Model grouping', () => {
|
||||||
|
describe('groupQwenModels', () => {
|
||||||
|
it('groups qwen models by prefix', () => {
|
||||||
|
const qwen = createModel({ id: 'Qwen-7B', provider: 'qwen', name: 'Qwen-7B' })
|
||||||
|
const qwenOmni = createModel({ id: 'qwen2.5-omni', name: 'qwen2.5-omni' })
|
||||||
|
const other = createModel({ id: 'deepseek-v3', group: 'DeepSeek' })
|
||||||
|
|
||||||
|
const grouped = groupQwenModels([qwen, qwenOmni, other])
|
||||||
|
expect(Object.keys(grouped)).toContain('qwen-7b')
|
||||||
|
expect(Object.keys(grouped)).toContain('qwen2.5')
|
||||||
|
expect(grouped.DeepSeek).toContain(other)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('Vision and image generation', () => {
|
||||||
|
describe('isVisionModels', () => {
|
||||||
|
it('returns true when all models support vision', () => {
|
||||||
|
const models = [createModel({ id: 'gpt-4o' }), createModel({ id: 'gpt-4o-mini' })]
|
||||||
|
expect(isVisionModels(models)).toBe(true)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('returns false when some models do not support vision', () => {
|
||||||
|
const models = [createModel({ id: 'gpt-4o' }), createModel({ id: 'gpt-4o-mini' })]
|
||||||
|
visionMock.mockReturnValueOnce(true).mockReturnValueOnce(false)
|
||||||
|
expect(isVisionModels(models)).toBe(false)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('isGenerateImageModels', () => {
|
||||||
|
it('returns true when all models support image generation', () => {
|
||||||
|
const models = [createModel({ id: 'gpt-4o' }), createModel({ id: 'gpt-4o-mini' })]
|
||||||
|
generateImageMock.mockReturnValue(true)
|
||||||
|
expect(isGenerateImageModels(models)).toBe(true)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('returns false when some models do not support image generation', () => {
|
||||||
|
const models = [createModel({ id: 'gpt-4o' }), createModel({ id: 'gpt-4o-mini' })]
|
||||||
|
generateImageMock.mockReturnValueOnce(true).mockReturnValueOnce(false)
|
||||||
|
expect(isGenerateImageModels(models)).toBe(false)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('Model filtering', () => {
|
||||||
|
describe('isSupportedModel', () => {
|
||||||
|
it('filters supported OpenAI catalog entries', () => {
|
||||||
|
expect(isSupportedModel({ id: 'gpt-4', object: 'model' } as any)).toBe(true)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('filters unsupported OpenAI catalog entries', () => {
|
||||||
|
expect(isSupportedModel({ id: 'tts-1', object: 'model' } as any)).toBe(false)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('agentModelFilter', () => {
|
||||||
|
it('returns true for regular models', () => {
|
||||||
|
expect(agentModelFilter(createModel())).toBe(true)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('filters out embedding models', () => {
|
||||||
|
embeddingMock.mockReturnValueOnce(true)
|
||||||
|
expect(agentModelFilter(createModel({ id: 'text-embedding' }))).toBe(false)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('filters out rerank models', () => {
|
||||||
|
embeddingMock.mockReturnValue(false)
|
||||||
|
rerankMock.mockReturnValueOnce(true)
|
||||||
|
expect(agentModelFilter(createModel({ id: 'rerank' }))).toBe(false)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('filters out non-function-call models', () => {
|
||||||
|
rerankMock.mockReturnValue(false)
|
||||||
|
isFunctionCallingModelMock.mockReturnValueOnce(false)
|
||||||
|
expect(agentModelFilter(createModel({ id: 'DeepSeek R1' }))).toBe(false)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('filters out text-to-image models', () => {
|
||||||
|
rerankMock.mockReturnValue(false)
|
||||||
|
textToImageMock.mockReturnValueOnce(true)
|
||||||
|
expect(agentModelFilter(createModel({ id: 'gpt-image-1' }))).toBe(false)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
textToImageMock.mockReturnValue(false)
|
||||||
|
generateImageMock.mockReturnValueOnce(true)
|
||||||
|
expect(agentModelFilter(createModel({ id: 'dall-e-3' }))).toBe(false)
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('Temperature limits', () => {
|
||||||
|
describe('isMaxTemperatureOneModel', () => {
|
||||||
|
it('returns true for Zhipu models', () => {
|
||||||
|
expect(isMaxTemperatureOneModel(createModel({ id: 'glm-4' }))).toBe(true)
|
||||||
|
expect(isMaxTemperatureOneModel(createModel({ id: 'GLM-4-Plus' }))).toBe(true)
|
||||||
|
expect(isMaxTemperatureOneModel(createModel({ id: 'glm-3-turbo' }))).toBe(true)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('returns true for Anthropic models', () => {
|
||||||
|
expect(isMaxTemperatureOneModel(createModel({ id: 'claude-3.5-sonnet' }))).toBe(true)
|
||||||
|
expect(isMaxTemperatureOneModel(createModel({ id: 'Claude-3-opus' }))).toBe(true)
|
||||||
|
expect(isMaxTemperatureOneModel(createModel({ id: 'claude-2.1' }))).toBe(true)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('returns true for Moonshot models', () => {
|
||||||
|
expect(isMaxTemperatureOneModel(createModel({ id: 'moonshot-1.0' }))).toBe(true)
|
||||||
|
expect(isMaxTemperatureOneModel(createModel({ id: 'kimi-k2-thinking' }))).toBe(true)
|
||||||
|
expect(isMaxTemperatureOneModel(createModel({ id: 'Moonshot-Pro' }))).toBe(true)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('returns false for other models', () => {
|
||||||
|
expect(isMaxTemperatureOneModel(createModel({ id: 'gpt-4o' }))).toBe(false)
|
||||||
|
expect(isMaxTemperatureOneModel(createModel({ id: 'gpt-4-turbo' }))).toBe(false)
|
||||||
|
expect(isMaxTemperatureOneModel(createModel({ id: 'qwen-max' }))).toBe(false)
|
||||||
|
expect(isMaxTemperatureOneModel(createModel({ id: 'gemini-pro' }))).toBe(false)
|
||||||
|
})
|
||||||
|
})
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -396,7 +396,11 @@ export function isClaude45ReasoningModel(model: Model): boolean {
|
|||||||
|
|
||||||
export function isClaude4SeriesModel(model: Model): boolean {
|
export function isClaude4SeriesModel(model: Model): boolean {
|
||||||
const modelId = getLowerBaseModelName(model.id, '/')
|
const modelId = getLowerBaseModelName(model.id, '/')
|
||||||
const regex = /claude-(sonnet|opus|haiku)-4(?:[.-]\d+)?(?:-[\w-]+)?$/i
|
// Supports various formats including:
|
||||||
|
// - Direct API: claude-sonnet-4, claude-opus-4-20250514
|
||||||
|
// - GCP Vertex AI: claude-sonnet-4@20250514
|
||||||
|
// - AWS Bedrock: anthropic.claude-sonnet-4-20250514-v1:0
|
||||||
|
const regex = /claude-(sonnet|opus|haiku)-4(?:[.-]\d+)?(?:[@\-:][\w\-:]+)?$/i
|
||||||
return regex.test(modelId)
|
return regex.test(modelId)
|
||||||
}
|
}
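The effect of the widened suffix group in this hunk can be checked in isolation. The sketch below copies both regexes from the diff and applies them directly to raw ids; the helper names `matchesOld` and `matchesNew` are hypothetical, and it assumes `getLowerBaseModelName` leaves these particular ids intact apart from lowercasing, which the diff itself does not show.

```typescript
// Old pattern (removed line): the suffix must start with "-" and may only contain word chars and "-".
const OLD_CLAUDE_4_REGEX = /claude-(sonnet|opus|haiku)-4(?:[.-]\d+)?(?:-[\w-]+)?$/i
// New pattern (added line): the suffix may start with "@", "-" or ":" and may contain ":".
const NEW_CLAUDE_4_REGEX = /claude-(sonnet|opus|haiku)-4(?:[.-]\d+)?(?:[@\-:][\w\-:]+)?$/i

// Hypothetical helpers for illustration; lowercasing stands in for getLowerBaseModelName.
function matchesOld(id: string): boolean {
  return OLD_CLAUDE_4_REGEX.test(id.toLowerCase())
}

function matchesNew(id: string): boolean {
  return NEW_CLAUDE_4_REGEX.test(id.toLowerCase())
}

// The three formats listed in the added comment.
const exampleIds = [
  'claude-sonnet-4', // Direct API
  'claude-sonnet-4@20250514', // GCP Vertex AI
  'anthropic.claude-sonnet-4-20250514-v1:0' // AWS Bedrock
]

for (const id of exampleIds) {
  console.log(id, { old: matchesOld(id), new: matchesNew(id) })
}
```

Under these assumptions, the old pattern accepts only the direct API form of these three ids, while the new pattern accepts all of them.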
|
||||||
|
|
||||||
@@ -456,16 +460,19 @@ export const isSupportedThinkingTokenZhipuModel = (model: Model): boolean => {
|
|||||||
}
|
}
|
||||||
|
|
||||||
export const isDeepSeekHybridInferenceModel = (model: Model) => {
|
export const isDeepSeekHybridInferenceModel = (model: Model) => {
|
||||||
const modelId = getLowerBaseModelName(model.id)
|
const { idResult, nameResult } = withModelIdAndNameAsId(model, (model) => {
|
||||||
// DeepSeek's official API uses chat and reasoner for reasoning control; other providers need to be checked separately, and their ids may differ
|
const modelId = getLowerBaseModelName(model.id)
|
||||||
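// openrouter: deepseek/deepseek-chat-v3.1. It is unclear whether other providers will follow DeepSeek's official approach and split out a model with the same id as a non-thinking-mode model, so this is risky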
|
// DeepSeek's official API uses chat and reasoner for reasoning control; other providers need to be checked separately, and their ids may differ
|
||||||
// Matches: "deepseek-v3" followed by ".digit" or "-digit".
|
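// openrouter: deepseek/deepseek-chat-v3.1. It is unclear whether other providers will follow DeepSeek's official approach and split out a model with the same id as a non-thinking-mode model, so this is risky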
|
||||||
// Optionally, this can be followed by ".alphanumeric_sequence" or "-alphanumeric_sequence"
|
// Matches: "deepseek-v3" followed by ".digit" or "-digit".
|
||||||
// until the end of the string.
|
// Optionally, this can be followed by ".alphanumeric_sequence" or "-alphanumeric_sequence"
|
||||||
// Examples: deepseek-v3.1, deepseek-v3-1, deepseek-v3.1.2, deepseek-v3.1-alpha
|
// until the end of the string.
|
||||||
// Does NOT match: deepseek-v3.123 (missing separator after '1'), deepseek-v3.x (x isn't a digit)
|
// Examples: deepseek-v3.1, deepseek-v3-1, deepseek-v3.1.2, deepseek-v3.1-alpha
|
||||||
// TODO: move to utils and add test cases
|
// Does NOT match: deepseek-v3.123 (missing separator after '1'), deepseek-v3.x (x isn't a digit)
|
||||||
return /deepseek-v3(?:\.\d|-\d)(?:(\.|-)\w+)?$/.test(modelId) || modelId.includes('deepseek-chat-v3.1')
|
// TODO: move to utils and add test cases
|
||||||
|
return /deepseek-v3(?:\.\d|-\d)(?:(\.|-)\w+)?$/.test(modelId) || modelId.includes('deepseek-chat-v3.1')
|
||||||
|
})
|
||||||
|
return idResult || nameResult
|
||||||
}
|
}
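This hunk moves the "also try the name as the id" check from `isReasoningModel` into `isDeepSeekHybridInferenceModel` itself. The sketch below shows one way that could work. The body of `withModelIdAndNameAsId` is an assumption inferred from its call site and from the removed `{ ...model, id: model.name }` line, not the project's actual implementation. `ModelLike` and `isDeepSeekHybridSketch` are hypothetical names, and plain lowercasing stands in for `getLowerBaseModelName`.

```typescript
// Hypothetical minimal model shape; the real Model type has more fields.
interface ModelLike {
  id: string
  name: string
}

// Assumed behavior of the helper: run the same check once on the model as-is
// and once on a copy whose id is replaced by its name.
function withModelIdAndNameAsId<T>(
  model: ModelLike,
  check: (m: ModelLike) => T
): { idResult: T; nameResult: T } {
  return {
    idResult: check(model),
    nameResult: check({ ...model, id: model.name })
  }
}

// Regex and substring fallback copied from the diff.
function isDeepSeekHybridSketch(model: ModelLike): boolean {
  const { idResult, nameResult } = withModelIdAndNameAsId(model, (m) => {
    const modelId = m.id.toLowerCase()
    return /deepseek-v3(?:\.\d|-\d)(?:(\.|-)\w+)?$/.test(modelId) || modelId.includes('deepseek-chat-v3.1')
  })
  return idResult || nameResult
}

// A model whose id is opaque but whose display name identifies it.
console.log(isDeepSeekHybridSketch({ id: 'custom-endpoint', name: 'DeepSeek-V3.1' }))
```

With the check folded into the function, callers such as `isReasoningModel` no longer need the second call with `id: model.name`, which matches the line removed in the following hunk.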
|
||||||
|
|
||||||
export const isLingReasoningModel = (model?: Model): boolean => {
|
export const isLingReasoningModel = (model?: Model): boolean => {
|
||||||
@@ -519,7 +526,6 @@ export function isReasoningModel(model?: Model): boolean {
|
|||||||
REASONING_REGEX.test(model.name) ||
|
REASONING_REGEX.test(model.name) ||
|
||||||
isSupportedThinkingTokenDoubaoModel(model) ||
|
isSupportedThinkingTokenDoubaoModel(model) ||
|
||||||
isDeepSeekHybridInferenceModel(model) ||
|
isDeepSeekHybridInferenceModel(model) ||
|
||||||
isDeepSeekHybridInferenceModel({ ...model, id: model.name }) ||
|
|
||||||
false
|
false
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,6 +1,8 @@
|
|||||||
|
import { getProviderByModel } from '@renderer/services/AssistantService'
|
||||||
import type { Model } from '@renderer/types'
|
import type { Model } from '@renderer/types'
|
||||||
import { isSystemProviderId } from '@renderer/types'
|
import { isSystemProviderId } from '@renderer/types'
|
||||||
import { getLowerBaseModelName, isUserSelectedModelType } from '@renderer/utils'
|
import { getLowerBaseModelName, isUserSelectedModelType } from '@renderer/utils'
|
||||||
|
import { isAzureOpenAIProvider } from '@shared/provider'
|
||||||
|
|
||||||
import { isEmbeddingModel, isRerankModel } from './embedding'
|
import { isEmbeddingModel, isRerankModel } from './embedding'
|
||||||
import { isDeepSeekHybridInferenceModel } from './reasoning'
|
import { isDeepSeekHybridInferenceModel } from './reasoning'
|
||||||
@@ -52,6 +54,13 @@ export const FUNCTION_CALLING_REGEX = new RegExp(
|
|||||||
'i'
|
'i'
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const AZURE_FUNCTION_CALLING_EXCLUDED_MODELS = [
|
||||||
|
'(?:Meta-)?Llama-3(?:\\.\\d+)?-[\\w-]+',
|
||||||
|
'Phi-[34](?:\\.[\\w-]+)?(?:-[\\w-]+)?',
|
||||||
|
'DeepSeek-(?:R1|V3)',
|
||||||
|
'Codestral-2501'
|
||||||
|
]
|
||||||
|
|
||||||
export function isFunctionCallingModel(model?: Model): boolean {
|
export function isFunctionCallingModel(model?: Model): boolean {
|
||||||
if (!model || isEmbeddingModel(model) || isRerankModel(model) || isTextToImageModel(model)) {
|
if (!model || isEmbeddingModel(model) || isRerankModel(model) || isTextToImageModel(model)) {
|
||||||
return false
|
return false
|
||||||
@@ -67,6 +76,15 @@ export function isFunctionCallingModel(model?: Model): boolean {
|
|||||||
return FUNCTION_CALLING_REGEX.test(modelId) || FUNCTION_CALLING_REGEX.test(model.name)
|
return FUNCTION_CALLING_REGEX.test(modelId) || FUNCTION_CALLING_REGEX.test(model.name)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const provider = getProviderByModel(model)
|
||||||
|
|
||||||
|
if (isAzureOpenAIProvider(provider)) {
|
||||||
|
const azureExcludedRegex = new RegExp(`\\b(?:${AZURE_FUNCTION_CALLING_EXCLUDED_MODELS.join('|')})\\b`, 'i')
|
||||||
|
if (azureExcludedRegex.test(modelId)) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if (['deepseek', 'anthropic', 'kimi', 'moonshot'].includes(model.provider)) {
|
if (['deepseek', 'anthropic', 'kimi', 'moonshot'].includes(model.provider)) {
|
||||||
return true
|
return true
|
||||||
}
|
}
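The Azure exclusion list added in this hunk is joined into a single case-insensitive, word-bounded regex. The sketch below rebuilds that regex from the same pattern strings and runs it against a few ids. The sample ids and the `isExcludedOnAzure` helper are illustrative assumptions. The provider check (`isAzureOpenAIProvider`) is left out, so this shows only the pattern matching.

```typescript
// Pattern strings copied from the diff.
const AZURE_FUNCTION_CALLING_EXCLUDED_MODELS = [
  '(?:Meta-)?Llama-3(?:\\.\\d+)?-[\\w-]+',
  'Phi-[34](?:\\.[\\w-]+)?(?:-[\\w-]+)?',
  'DeepSeek-(?:R1|V3)',
  'Codestral-2501'
]

// Same construction as in the diff: alternation wrapped in word boundaries, case-insensitive.
const azureExcludedRegex = new RegExp(`\\b(?:${AZURE_FUNCTION_CALLING_EXCLUDED_MODELS.join('|')})\\b`, 'i')

// Hypothetical helper: true when the id would be excluded from function calling on Azure.
function isExcludedOnAzure(modelId: string): boolean {
  return azureExcludedRegex.test(modelId)
}

// Illustrative ids, not taken from the diff.
for (const id of ['Meta-Llama-3.1-8B-Instruct', 'Phi-4', 'DeepSeek-R1', 'Codestral-2501', 'gpt-4o']) {
  console.log(id, isExcludedOnAzure(id))
}
```

Because the regex is built without the `g` flag, repeated `test` calls carry no `lastIndex` state, so it could also be hoisted to module scope instead of being rebuilt on every call.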
|
||||||
|
|||||||
@@ -1,11 +1,20 @@
|
|||||||
import type OpenAI from '@cherrystudio/openai'
|
import type OpenAI from '@cherrystudio/openai'
|
||||||
import { isEmbeddingModel, isRerankModel } from '@renderer/config/models/embedding'
|
import { isEmbeddingModel, isRerankModel } from '@renderer/config/models/embedding'
|
||||||
|
import { getProviderByModel } from '@renderer/services/AssistantService'
|
||||||
import { type Model, SystemProviderIds } from '@renderer/types'
|
import { type Model, SystemProviderIds } from '@renderer/types'
|
||||||
import type { OpenAIVerbosity, ValidOpenAIVerbosity } from '@renderer/types/aiCoreTypes'
|
import type { OpenAIVerbosity, ValidOpenAIVerbosity } from '@renderer/types/aiCoreTypes'
|
||||||
import { getLowerBaseModelName } from '@renderer/utils'
|
import { getLowerBaseModelName } from '@renderer/utils'
|
||||||
|
|
||||||
import { isOpenAIChatCompletionOnlyModel, isOpenAIOpenWeightModel, isOpenAIReasoningModel } from './openai'
|
import {
|
||||||
|
isGPT5ProModel,
|
||||||
|
isGPT5SeriesModel,
|
||||||
|
isGPT51SeriesModel,
|
||||||
|
isOpenAIChatCompletionOnlyModel,
|
||||||
|
isOpenAIOpenWeightModel,
|
||||||
|
isOpenAIReasoningModel
|
||||||
|
} from './openai'
|
||||||
import { isQwenMTModel } from './qwen'
|
import { isQwenMTModel } from './qwen'
|
||||||
|
import { isFunctionCallingModel } from './tooluse'
|
||||||
import { isGenerateImageModel, isTextToImageModel, isVisionModel } from './vision'
|
import { isGenerateImageModel, isTextToImageModel, isVisionModel } from './vision'
|
||||||
export const NOT_SUPPORTED_REGEX = /(?:^tts|whisper|speech)/i
|
export const NOT_SUPPORTED_REGEX = /(?:^tts|whisper|speech)/i
|
||||||
export const GEMINI_FLASH_MODEL_REGEX = new RegExp('gemini.*-flash.*$', 'i')
|
export const GEMINI_FLASH_MODEL_REGEX = new RegExp('gemini.*-flash.*$', 'i')
|
||||||
@@ -123,21 +132,46 @@ export const isNotSupportSystemMessageModel = (model: Model): boolean => {
|
|||||||
return isQwenMTModel(model) || isGemmaModel(model)
|
return isQwenMTModel(model) || isGemmaModel(model)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GPT-5 verbosity configuration
|
// Verbosity settings are only supported by GPT-5 and newer models
|
||||||
|
// Specifically, GPT-5 and GPT-5.1 for now
|
||||||
// gpt-5-pro only supports 'high', other GPT-5 models support all levels
|
// gpt-5-pro only supports 'high', other GPT-5 models support all levels
|
||||||
export const MODEL_SUPPORTED_VERBOSITY: Record<string, ValidOpenAIVerbosity[]> = {
|
const MODEL_SUPPORTED_VERBOSITY: readonly {
|
||||||
'gpt-5-pro': ['high'],
|
readonly validator: (model: Model) => boolean
|
||||||
default: ['low', 'medium', 'high']
|
readonly values: readonly ValidOpenAIVerbosity[]
|
||||||
} as const
|
}[] = [
|
||||||
|
// gpt-5-pro
|
||||||
|
{ validator: isGPT5ProModel, values: ['high'] },
|
||||||
|
// gpt-5 except gpt-5-pro
|
||||||
|
{
|
||||||
|
validator: (model: Model) => isGPT5SeriesModel(model) && !isGPT5ProModel(model),
|
||||||
|
values: ['low', 'medium', 'high']
|
||||||
|
},
|
||||||
|
// gpt-5.1
|
||||||
|
{ validator: isGPT51SeriesModel, values: ['low', 'medium', 'high'] }
|
||||||
|
]
|
||||||
|
|
||||||
export const getModelSupportedVerbosity = (model: Model): OpenAIVerbosity[] => {
|
/**
|
||||||
const modelId = getLowerBaseModelName(model.id)
|
* Returns the list of supported verbosity levels for the given model.
|
||||||
let supportedValues: ValidOpenAIVerbosity[]
|
* If the model is not recognized as a GPT-5 series model, only `undefined` is returned.
|
||||||
if (modelId.includes('gpt-5-pro')) {
|
* For GPT-5-pro, only 'high' is supported; for other GPT-5 models, 'low', 'medium', and 'high' are supported.
|
||||||
supportedValues = MODEL_SUPPORTED_VERBOSITY['gpt-5-pro']
|
* For GPT-5.1 series models, 'low', 'medium', and 'high' are supported.
|
||||||
} else {
|
* @param model - The model to check
|
||||||
supportedValues = MODEL_SUPPORTED_VERBOSITY.default
|
* @returns An array of supported verbosity levels, always including `undefined` as the first element
|
||||||
|
*/
|
||||||
|
export const getModelSupportedVerbosity = (model: Model | undefined | null): OpenAIVerbosity[] => {
|
||||||
|
if (!model) {
|
||||||
|
return [undefined]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
let supportedValues: ValidOpenAIVerbosity[] = []
|
||||||
|
|
||||||
|
for (const { validator, values } of MODEL_SUPPORTED_VERBOSITY) {
|
||||||
|
if (validator(model)) {
|
||||||
|
supportedValues = [...values]
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return [undefined, ...supportedValues]
|
return [undefined, ...supportedValues]
|
||||||
}
|
}
|
||||||
|
|
||||||
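The lookup introduced above replaces a keyed record with an ordered list of validator/values pairs, where the first matching validator wins. A minimal self-contained sketch of that pattern follows. The `Model` type is reduced to an `id`, and the three validators are hypothetical substring checks standing in for `isGPT5ProModel`, `isGPT5SeriesModel`, and `isGPT51SeriesModel`, whose real implementations are not shown in this diff.

```typescript
type Verbosity = 'low' | 'medium' | 'high'

// Simplified stand-in for the project's Model type (assumption).
interface Model {
  id: string
}

// Hypothetical validators; the real ones live in './openai' and may differ.
const isGPT5ProModel = (m: Model): boolean => m.id.toLowerCase().includes('gpt-5-pro')
const isGPT51SeriesModel = (m: Model): boolean => m.id.toLowerCase().includes('gpt-5.1')
const isGPT5SeriesModel = (m: Model): boolean =>
  m.id.toLowerCase().includes('gpt-5') && !isGPT51SeriesModel(m)

// Ordered rules: the first validator that returns true decides the values.
const RULES: readonly {
  readonly validator: (m: Model) => boolean
  readonly values: readonly Verbosity[]
}[] = [
  { validator: isGPT5ProModel, values: ['high'] },
  { validator: (m) => isGPT5SeriesModel(m) && !isGPT5ProModel(m), values: ['low', 'medium', 'high'] },
  { validator: isGPT51SeriesModel, values: ['low', 'medium', 'high'] }
]

// `undefined` always comes first and means "do not send a verbosity setting".
function getSupportedVerbosity(model: Model | null | undefined): (Verbosity | undefined)[] {
  if (!model) {
    return [undefined]
  }
  const rule = RULES.find(({ validator }) => validator(model))
  return [undefined, ...(rule?.values ?? [])]
}
```

One reason to prefer this shape over a record keyed by model ID is that adding a new model family means appending one rule rather than extending an `if`/`else` chain.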
@@ -149,8 +183,21 @@ export const isGeminiModel = (model: Model) => {
|
|||||||
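// zhipu visual reasoning models use this pair of special tokens to mark the reasoning result
|
// zhipu visual reasoning models use this pair of special tokens to mark the reasoning result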
|
||||||
export const ZHIPU_RESULT_TOKENS = ['<|begin_of_box|>', '<|end_of_box|>'] as const
|
export const ZHIPU_RESULT_TOKENS = ['<|begin_of_box|>', '<|end_of_box|>'] as const
|
||||||
|
|
||||||
|
// TODO: support tool calling in prompt mode
|
||||||
export const agentModelFilter = (model: Model): boolean => {
|
export const agentModelFilter = (model: Model): boolean => {
|
||||||
return !isEmbeddingModel(model) && !isRerankModel(model) && !isTextToImageModel(model)
|
const provider = getProviderByModel(model)
|
||||||
|
|
||||||
|
// Needs adaptation, and easily exceeds the quota
|
||||||
|
if (provider.id === SystemProviderIds.copilot) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return (
|
||||||
|
!isEmbeddingModel(model) &&
|
||||||
|
!isRerankModel(model) &&
|
||||||
|
!isTextToImageModel(model) &&
|
||||||
|
!isGenerateImageModel(model) &&
|
||||||
|
isFunctionCallingModel(model)
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
export const isMaxTemperatureOneModel = (model: Model): boolean => {
|
export const isMaxTemperatureOneModel = (model: Model): boolean => {
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
import { throttle } from 'lodash'
|
import { throttle } from 'lodash'
|
||||||
import { useEffect, useRef } from 'react'
|
import { useEffect, useMemo, useRef } from 'react'
|
||||||
|
|
||||||
import { useTimer } from './useTimer'
|
import { useTimer } from './useTimer'
|
||||||
|
|
||||||
@@ -12,13 +12,18 @@ import { useTimer } from './useTimer'
|
|||||||
*/
|
*/
|
||||||
export default function useScrollPosition(key: string, throttleWait?: number) {
|
export default function useScrollPosition(key: string, throttleWait?: number) {
|
||||||
const containerRef = useRef<HTMLDivElement>(null)
|
const containerRef = useRef<HTMLDivElement>(null)
|
||||||
const scrollKey = `scroll:${key}`
|
const scrollKey = useMemo(() => `scroll:${key}`, [key])
|
||||||
|
const scrollKeyRef = useRef(scrollKey)
|
||||||
const { setTimeoutTimer } = useTimer()
|
const { setTimeoutTimer } = useTimer()
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
scrollKeyRef.current = scrollKey
|
||||||
|
}, [scrollKey])
|
||||||
|
|
||||||
const handleScroll = throttle(() => {
|
const handleScroll = throttle(() => {
|
||||||
const position = containerRef.current?.scrollTop ?? 0
|
const position = containerRef.current?.scrollTop ?? 0
|
||||||
window.requestAnimationFrame(() => {
|
window.requestAnimationFrame(() => {
|
||||||
window.keyv.set(scrollKey, position)
|
window.keyv.set(scrollKeyRef.current, position)
|
||||||
})
|
})
|
||||||
}, throttleWait ?? 100)
|
}, throttleWait ?? 100)
|
||||||
|
|
||||||
@@ -28,5 +33,9 @@ export default function useScrollPosition(key: string, throttleWait?: number) {
|
|||||||
setTimeoutTimer('scrollEffect', scroll, 50)
|
setTimeoutTimer('scrollEffect', scroll, 50)
|
||||||
}, [scrollKey, setTimeoutTimer])
|
}, [scrollKey, setTimeoutTimer])
|
||||||
|
|
||||||
|
useEffect(() => {
|
||||||
|
return () => handleScroll.cancel()
|
||||||
|
}, [handleScroll])
|
||||||
|
|
||||||
return { containerRef, handleScroll }
|
return { containerRef, handleScroll }
|
||||||
}
|
}
|
||||||
|
|||||||
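The `scrollKeyRef` change above makes the scroll handler read the key at call time instead of capturing it when the handler is created. The sketch below shows that idea without React: `Ref` is a hypothetical stand-in for the object returned by `useRef`, and a plain `Map` stands in for `window.keyv`.

```typescript
// Hypothetical stand-in for a React ref: a mutable box with a `current` field.
interface Ref<T> {
  current: T
}

// Stand-in for window.keyv (assumption, for illustration only).
const store = new Map<string, number>()

// The returned callback looks up keyRef.current each time it runs,
// so it always writes under the latest key.
function makeSaver(keyRef: Ref<string>): (position: number) => void {
  return (position) => {
    store.set(keyRef.current, position)
  }
}

const keyRef: Ref<string> = { current: 'scroll:a' }
const save = makeSaver(keyRef)

// The key changes after the callback was created, as it would on a re-render.
keyRef.current = 'scroll:b'
save(120)
```

With a captured string instead of a ref, a throttled call scheduled before the key changed would write the position under the old key.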
@@ -1,4 +1,4 @@
|
|||||||
import { useEffect, useRef } from 'react'
|
import { useCallback, useEffect, useRef } from 'react'
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Timer management Hook for managing setTimeout and setInterval timers; different timers can be identified by key
|
* Timer management Hook for managing setTimeout and setInterval timers; different timers can be identified by key
|
||||||
@@ -43,10 +43,38 @@ export const useTimer = () => {
|
|||||||
const timeoutMapRef = useRef(new Map<string, NodeJS.Timeout>())
|
const timeoutMapRef = useRef(new Map<string, NodeJS.Timeout>())
|
||||||
const intervalMapRef = useRef(new Map<string, NodeJS.Timeout>())
|
const intervalMapRef = useRef(new Map<string, NodeJS.Timeout>())
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Clear the setTimeout timer for the given key
|
||||||
|
* @param key - timer identifier
|
||||||
|
*/
|
||||||
|
const clearTimeoutTimer = useCallback((key: string) => {
|
||||||
|
clearTimeout(timeoutMapRef.current.get(key))
|
||||||
|
timeoutMapRef.current.delete(key)
|
||||||
|
}, [])
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Clear the setInterval timer for the given key
|
||||||
|
* @param key - timer identifier
|
||||||
|
*/
|
||||||
|
const clearIntervalTimer = useCallback((key: string) => {
|
||||||
|
clearInterval(intervalMapRef.current.get(key))
|
||||||
|
intervalMapRef.current.delete(key)
|
||||||
|
}, [])
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Clear all timers, including setTimeout and setInterval
|
||||||
|
*/
|
||||||
|
const clearAllTimers = useCallback(() => {
|
||||||
|
timeoutMapRef.current.forEach((timer) => clearTimeout(timer))
|
||||||
|
intervalMapRef.current.forEach((timer) => clearInterval(timer))
|
||||||
|
timeoutMapRef.current.clear()
|
||||||
|
intervalMapRef.current.clear()
|
||||||
|
}, [])
|
||||||
|
|
||||||
// Automatically clear all timers when the component unmounts
|
// Automatically clear all timers when the component unmounts
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
return () => clearAllTimers()
|
return () => clearAllTimers()
|
||||||
}, [])
|
}, [clearAllTimers])
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Set a setTimeout timer
|
* Set a setTimeout timer
|
||||||
@@ -65,12 +93,15 @@ export const useTimer = () => {
|
|||||||
* cleanup();
|
* cleanup();
|
||||||
* ```
|
* ```
|
||||||
*/
|
*/
|
||||||
const setTimeoutTimer = (key: string, ...args: Parameters<typeof setTimeout>) => {
|
const setTimeoutTimer = useCallback(
|
||||||
clearTimeout(timeoutMapRef.current.get(key))
|
(key: string, ...args: Parameters<typeof setTimeout>) => {
|
||||||
const timer = setTimeout(...args)
|
clearTimeout(timeoutMapRef.current.get(key))
|
||||||
timeoutMapRef.current.set(key, timer)
|
const timer = setTimeout(...args)
|
||||||
return () => clearTimeoutTimer(key)
|
timeoutMapRef.current.set(key, timer)
|
||||||
}
|
return () => clearTimeoutTimer(key)
|
||||||
|
},
|
||||||
|
[clearTimeoutTimer]
|
||||||
|
)
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Set a setInterval timer
|
* Set a setInterval timer
|
||||||
@@ -89,56 +120,31 @@ export const useTimer = () => {
|
|||||||
* cleanup();
|
* cleanup();
|
||||||
* ```
|
* ```
|
||||||
*/
|
*/
|
||||||
const setIntervalTimer = (key: string, ...args: Parameters<typeof setInterval>) => {
|
const setIntervalTimer = useCallback(
|
||||||
clearInterval(intervalMapRef.current.get(key))
|
(key: string, ...args: Parameters<typeof setInterval>) => {
|
||||||
const timer = setInterval(...args)
|
clearInterval(intervalMapRef.current.get(key))
|
||||||
intervalMapRef.current.set(key, timer)
|
const timer = setInterval(...args)
|
||||||
return () => clearIntervalTimer(key)
|
intervalMapRef.current.set(key, timer)
|
||||||
}
|
return () => clearIntervalTimer(key)
|
||||||
|
},
|
||||||
/**
|
[clearIntervalTimer]
|
||||||
* Clear the setTimeout timer for the given key
|
)
|
||||||
* @param key - timer identifier
|
|
||||||
*/
|
|
||||||
const clearTimeoutTimer = (key: string) => {
|
|
||||||
clearTimeout(timeoutMapRef.current.get(key))
|
|
||||||
timeoutMapRef.current.delete(key)
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Clear the setInterval timer for the given key
|
|
||||||
* @param key - timer identifier
|
|
||||||
*/
|
|
||||||
const clearIntervalTimer = (key: string) => {
|
|
||||||
clearInterval(intervalMapRef.current.get(key))
|
|
||||||
intervalMapRef.current.delete(key)
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Clear all setTimeout timers
|
* Clear all setTimeout timers
|
||||||
*/
|
*/
|
||||||
const clearAllTimeoutTimers = () => {
|
const clearAllTimeoutTimers = useCallback(() => {
|
||||||
timeoutMapRef.current.forEach((timer) => clearTimeout(timer))
|
timeoutMapRef.current.forEach((timer) => clearTimeout(timer))
|
||||||
timeoutMapRef.current.clear()
|
timeoutMapRef.current.clear()
|
||||||
}
|
}, [])
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Clear all setInterval timers
|
* Clear all setInterval timers
|
||||||
*/
|
*/
|
||||||
const clearAllIntervalTimers = () => {
|
const clearAllIntervalTimers = useCallback(() => {
|
||||||
intervalMapRef.current.forEach((timer) => clearInterval(timer))
|
intervalMapRef.current.forEach((timer) => clearInterval(timer))
|
||||||
intervalMapRef.current.clear()
|
intervalMapRef.current.clear()
|
||||||
}
|
}, [])
|
||||||
|
|
||||||
/**
|
|
||||||
* Clear all timers, including setTimeout and setInterval
|
|
||||||
*/
|
|
||||||
const clearAllTimers = () => {
|
|
||||||
timeoutMapRef.current.forEach((timer) => clearTimeout(timer))
|
|
||||||
intervalMapRef.current.forEach((timer) => clearInterval(timer))
|
|
||||||
timeoutMapRef.current.clear()
|
|
||||||
intervalMapRef.current.clear()
|
|
||||||
}
|
|
||||||
|
|
||||||
return {
|
return {
|
||||||
setTimeoutTimer,
|
setTimeoutTimer,
|
||||||
|
|||||||
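The keyed-timer behaviour that `useTimer` wraps in `useCallback` can be sketched outside React. The class below, `KeyedTimers`, is a hypothetical illustration and not part of the project: setting a timer under an existing key clears the previous one, and each setter returns a cleanup function, mirroring `setTimeoutTimer` and `clearTimeoutTimer` in the diff. Intervals are omitted for brevity.

```typescript
type TimerHandle = ReturnType<typeof setTimeout>

// Hypothetical framework-free counterpart of the hook's timeout map.
class KeyedTimers {
  private readonly timeouts = new Map<string, TimerHandle>()

  // Number of timers currently tracked.
  get size(): number {
    return this.timeouts.size
  }

  // Schedules a timeout under `key`, replacing any timer already using that key.
  // Returns a cleanup function that cancels this key's timer.
  setTimeoutTimer(key: string, callback: () => void, delayMs: number): () => void {
    this.clearTimeoutTimer(key)
    const handle = setTimeout(callback, delayMs)
    this.timeouts.set(key, handle)
    return () => this.clearTimeoutTimer(key)
  }

  // Cancels and forgets the timer stored under `key`, if any.
  clearTimeoutTimer(key: string): void {
    const handle = this.timeouts.get(key)
    if (handle !== undefined) {
      clearTimeout(handle)
    }
    this.timeouts.delete(key)
  }

  // Cancels every tracked timer, as the hook does on unmount.
  clearAll(): void {
    this.timeouts.forEach((handle) => clearTimeout(handle))
    this.timeouts.clear()
  }
}
```

In the hook, the clear functions are declared before the setters so that `setTimeoutTimer` and `setIntervalTimer` can list them as `useCallback` dependencies; the class form has no such ordering constraint.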
@@ -280,6 +280,7 @@
|
|||||||
"denied": "Tool request was denied.",
|
"denied": "Tool request was denied.",
|
||||||
"timeout": "Tool request timed out before receiving approval."
|
"timeout": "Tool request timed out before receiving approval."
|
||||||
},
|
},
|
||||||
|
"toolPendingFallback": "Tool",
|
||||||
"waiting": "Waiting for tool permission decision..."
|
"waiting": "Waiting for tool permission decision..."
|
||||||
},
|
},
|
||||||
"type": {
|
"type": {
|
||||||
@@ -1208,7 +1209,7 @@
|
|||||||
"endpoint_type": {
|
"endpoint_type": {
|
||||||
"anthropic": "Anthropic",
|
"anthropic": "Anthropic",
|
||||||
"gemini": "Gemini",
|
"gemini": "Gemini",
|
||||||
"image-generation": "Image Generation",
|
"image-generation": "Image Generation (OpenAI)",
|
||||||
"jina-rerank": "Jina Rerank",
|
"jina-rerank": "Jina Rerank",
|
||||||
"openai": "OpenAI",
|
"openai": "OpenAI",
|
||||||
"openai-response": "OpenAI-Response"
|
"openai-response": "OpenAI-Response"
|
||||||
|
|||||||
@@ -280,6 +280,7 @@
|
|||||||
"denied": "工具请求已被拒绝。",
|
"denied": "工具请求已被拒绝。",
|
||||||
"timeout": "工具请求在收到批准前超时。"
|
"timeout": "工具请求在收到批准前超时。"
|
||||||
},
|
},
|
||||||
|
"toolPendingFallback": "工具",
|
||||||
"waiting": "等待工具权限决定..."
|
"waiting": "等待工具权限决定..."
|
||||||
},
|
},
|
||||||
"type": {
|
"type": {
|
||||||
@@ -1208,7 +1209,7 @@
|
|||||||
"endpoint_type": {
|
"endpoint_type": {
|
||||||
"anthropic": "Anthropic",
|
"anthropic": "Anthropic",
|
||||||
"gemini": "Gemini",
|
"gemini": "Gemini",
|
||||||
"image-generation": "图片生成",
|
"image-generation": "图像生成 (OpenAI)",
|
||||||
"jina-rerank": "Jina 重排序",
|
"jina-rerank": "Jina 重排序",
|
||||||
"openai": "OpenAI",
|
"openai": "OpenAI",
|
||||||
"openai-response": "OpenAI-Response"
|
"openai-response": "OpenAI-Response"
|
||||||
|
|||||||
@@ -280,6 +280,7 @@
|
|||||||
"denied": "工具請求已被拒絕。",
|
"denied": "工具請求已被拒絕。",
|
||||||
"timeout": "工具請求在收到核准前逾時。"
|
"timeout": "工具請求在收到核准前逾時。"
|
||||||
},
|
},
|
||||||
|
"toolPendingFallback": "工具",
|
||||||
"waiting": "等待工具權限決定..."
|
"waiting": "等待工具權限決定..."
|
||||||
},
|
},
|
||||||
"type": {
|
"type": {
|
||||||
@@ -1208,7 +1209,7 @@
|
|||||||
"endpoint_type": {
|
"endpoint_type": {
|
||||||
"anthropic": "Anthropic",
|
"anthropic": "Anthropic",
|
||||||
"gemini": "Gemini",
|
"gemini": "Gemini",
|
||||||
"image-generation": "圖片生成",
|
"image-generation": "圖像生成 (OpenAI)",
|
||||||
"jina-rerank": "Jina Rerank",
|
"jina-rerank": "Jina Rerank",
|
||||||
"openai": "OpenAI",
|
"openai": "OpenAI",
|
||||||
"openai-response": "OpenAI-Response"
|
"openai-response": "OpenAI-Response"
|
||||||
|
|||||||
@@ -280,6 +280,7 @@
|
|||||||
"denied": "Tool-Anfrage wurde abgelehnt.",
|
"denied": "Tool-Anfrage wurde abgelehnt.",
|
||||||
"timeout": "Tool-Anfrage ist abgelaufen, bevor eine Genehmigung eingegangen ist."
|
"timeout": "Tool-Anfrage ist abgelaufen, bevor eine Genehmigung eingegangen ist."
|
||||||
},
|
},
|
||||||
|
"toolPendingFallback": "Werkzeug",
|
||||||
"waiting": "Warten auf Entscheidung über Tool-Berechtigung..."
|
"waiting": "Warten auf Entscheidung über Tool-Berechtigung..."
|
||||||
},
|
},
|
||||||
"type": {
|
"type": {
|
||||||
@@ -1208,7 +1209,7 @@
|
|||||||
"endpoint_type": {
|
"endpoint_type": {
|
||||||
"anthropic": "Anthropic",
|
"anthropic": "Anthropic",
|
||||||
"gemini": "Gemini",
|
"gemini": "Gemini",
|
||||||
"image-generation": "Bildgenerierung",
|
"image-generation": "Bilderzeugung (OpenAI)",
|
||||||
"jina-rerank": "Jina Reranking",
|
"jina-rerank": "Jina Reranking",
|
||||||
"openai": "OpenAI",
|
"openai": "OpenAI",
|
||||||
"openai-response": "OpenAI-Response"
|
"openai-response": "OpenAI-Response"
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff.