Compare commits

..

3 Commits

115 changed files with 1196 additions and 31176 deletions

View File

@@ -107,9 +107,11 @@ afterSign: scripts/notarize.js
artifactBuildCompleted: scripts/artifact-build-completed.js
releaseInfo:
releaseNotes: |
- 新功能:可选数据保存目录
- 快捷助手:支持单独选择助手,支持暂停、上下文、思考过程、流式
- 划词助手:系统托盘菜单开关
- 翻译:新增 Markdown 预览选项
- 新供应商:新增 Vertex AI 服务商
- 错误修复和界面优化
划词助手:支持文本选择快捷键、开关快捷键、思考块支持和引用功能
复制功能新增纯文本复制去除Markdown格式符号
知识库支持设置向量维度修复Ollama分数错误和维度编辑问题
多语言:增加模型名称多语言提示和翻译源语言手动选择
文件管理:修复主题/消息删除时文件未清理问题,优化文件选择流程
模型修复Gemini模型推理预算、Voyage AI嵌入问题和DeepSeek翻译模型更新
图像功能统一图片查看器支持Base64图片渲染修复图片预览相关问题
UI实现标签折叠/拖拽排序,修复气泡溢出,增加引文索引显示

View File

@@ -68,16 +68,12 @@ export default defineConfig({
}
},
optimizeDeps: {
exclude: ['pyodide'],
esbuildOptions: {
target: 'esnext' // for dev
}
exclude: ['pyodide']
},
worker: {
format: 'es'
},
build: {
target: 'esnext', // for build
rollupOptions: {
input: {
index: resolve(__dirname, 'src/renderer/index.html'),

View File

@@ -1,6 +1,6 @@
{
"name": "CherryStudio",
"version": "1.4.4",
"version": "1.4.2",
"private": true,
"description": "A powerful AI assistant for producer.",
"main": "./out/main/index.js",
@@ -73,7 +73,6 @@
"@agentic/tavily": "^7.3.3",
"@ant-design/v5-patch-for-react-19": "^1.0.3",
"@anthropic-ai/sdk": "^0.41.0",
"@cherrystudio/ai-core": "workspace:*",
"@cherrystudio/embedjs": "^0.1.31",
"@cherrystudio/embedjs-libsql": "^0.1.31",
"@cherrystudio/embedjs-loader-csv": "^0.1.31",
@@ -191,7 +190,7 @@
"react-hotkeys-hook": "^4.6.1",
"react-i18next": "^14.1.2",
"react-infinite-scroll-component": "^6.1.0",
"react-markdown": "^10.1.0",
"react-markdown": "^9.0.1",
"react-redux": "^9.1.2",
"react-router": "6",
"react-router-dom": "6",
@@ -200,10 +199,10 @@
"redux": "^5.0.1",
"redux-persist": "^6.0.0",
"rehype-katex": "^7.0.1",
"rehype-mathjax": "^7.1.0",
"rehype-mathjax": "^7.0.0",
"rehype-raw": "^7.0.0",
"remark-cjk-friendly": "^1.2.0",
"remark-gfm": "^4.0.1",
"remark-cjk-friendly": "^1.1.0",
"remark-gfm": "^4.0.0",
"remark-math": "^6.0.0",
"remove-markdown": "^0.6.2",
"rollup-plugin-visualizer": "^5.12.0",

View File

@@ -1,622 +0,0 @@
# Cherry Studio AI Core 基于 Vercel AI SDK 的技术架构
## 1. 架构设计理念
### 1.1 设计目标
- **统一接口**:使用 Vercel AI SDK 统一不同 AI Provider 的接口差异
- **动态导入**:通过动态导入实现按需加载,减少打包体积
- **最小包装**:直接使用 AI SDK 的类型和接口,避免重复定义
- **插件系统**:基于钩子的插件架构,支持请求全生命周期扩展
- **类型安全**:利用 TypeScript 和 AI SDK 的类型系统确保类型安全
- **轻量级**:专注核心功能,保持包的轻量和高效
- **包级独立**:作为独立包管理,便于复用和维护
### 1.2 核心优势
- **标准化**AI SDK 提供统一的模型接口,减少适配工作
- **简化维护**:废弃复杂的 XxxApiClient统一为工厂函数模式
- **更好的开发体验**:完整的 TypeScript 支持和丰富的生态系统
- **性能优化**AI SDK 内置优化和最佳实践
- **模块化设计**:独立包结构,支持跨项目复用
- **可扩展插件**:基于钩子的插件系统,支持灵活的功能扩展和流转换
## 2. 整体架构图
```mermaid
graph TD
subgraph "Cherry Studio 主应用"
UI["用户界面"]
Components["React 组件"]
end
subgraph "packages/aiCore (AI Core 包)"
ApiClientFactory["ApiClientFactory (工厂类)"]
UniversalClient["UniversalAiSdkClient (统一客户端)"]
ProviderRegistry["Provider 注册表"]
PluginManager["插件管理器"]
end
subgraph "动态导入层"
DynamicImport["动态导入"]
end
subgraph "Vercel AI SDK"
AICore["ai (核心库)"]
OpenAI["@ai-sdk/openai"]
Anthropic["@ai-sdk/anthropic"]
Google["@ai-sdk/google"]
XAI["@ai-sdk/xai"]
Others["其他 19+ Providers"]
end
subgraph "插件生态"
FirstHooks["First Hooks (resolveModel, loadTemplate)"]
SequentialHooks["Sequential Hooks (transformParams, transformResult)"]
ParallelHooks["Parallel Hooks (onRequestStart, onRequestEnd, onError)"]
StreamHooks["Stream Hooks (transformStream)"]
end
UI --> ApiClientFactory
Components --> ApiClientFactory
ApiClientFactory --> UniversalClient
UniversalClient --> PluginManager
PluginManager --> ProviderRegistry
ProviderRegistry --> DynamicImport
DynamicImport --> OpenAI
DynamicImport --> Anthropic
DynamicImport --> Google
DynamicImport --> XAI
DynamicImport --> Others
UniversalClient --> AICore
AICore --> streamText
AICore --> generateText
PluginManager --> FirstHooks
PluginManager --> SequentialHooks
PluginManager --> ParallelHooks
PluginManager --> StreamHooks
```
## 3. 包结构设计
### 3.1 包级文件结构(当前简化版 + 规划)
```
packages/aiCore/
├── src/
│ ├── providers/
│ │ ├── registry.ts # Provider 注册表 ✅
│ │ └── types.ts # 核心类型定义 ✅
│ ├── clients/
│ │ ├── UniversalAiSdkClient.ts # 统一AI SDK客户端 ✅
│ │ └── ApiClientFactory.ts # 客户端工厂 ✅
│ ├── middleware/ # 插件系统 ✅
│ │ ├── types.ts # 插件类型定义 ✅
│ │ ├── manager.ts # 插件管理器 ✅
│ │ ├── examples/ # 示例插件 ✅
│ │ │ ├── example-plugins.ts # 示例插件实现 ✅
│ │ │ └── example-usage.ts # 使用示例 ✅
│ │ ├── README.md # 插件系统文档 ✅
│ │ └── index.ts # 插件模块入口 ✅
│ ├── services/ # 高级服务 (规划中)
│ │ ├── AiCoreService.ts # 统一服务入口
│ │ ├── CompletionsService.ts # 文本生成服务
│ │ ├── EmbeddingService.ts # 嵌入服务
│ │ └── ImageService.ts # 图像生成服务
│ └── index.ts # 包主入口文件 ✅
├── package.json # 包配置文件 ✅
├── tsconfig.json # TypeScript 配置 ✅
├── README.md # 包说明文档 ✅
└── AI_SDK_ARCHITECTURE.md # 本文档 ✅
```
**图例:**
- ✅ 已实现
- 规划中:设计完成,待实现
## 4. 核心组件详解
### 4.1 Provider 注册表 (`providers/registry.ts`)
统一管理所有 AI Provider 的注册和动态导入。
**主要功能:**
- 动态导入 AI SDK providers
- 提供统一的 Provider 创建接口
- 支持 19+ 官方 AI SDK providers
- 类型安全的 Provider 配置
**核心 API**
```typescript
export interface ProviderConfig {
id: string
name: string
import: () => Promise<any>
creatorFunctionName: string
}
export class AiProviderRegistry {
getProvider(id: string): ProviderConfig | undefined
getAllProviders(): ProviderConfig[]
isSupported(id: string): boolean
registerProvider(config: ProviderConfig): void
}
```
**支持的 Providers**
- OpenAI, Anthropic, Google, XAI
- Azure OpenAI, Amazon Bedrock, Google Vertex
- Groq, Together.ai, Fireworks, DeepSeek
- Cerebras, DeepInfra, Replicate, Perplexity
- Cohere, Fal AI, Vercel (19+ providers)
### 4.2 统一AI SDK客户端 (`clients/UniversalAiSdkClient.ts`)
将不同 AI providers 包装为统一接口。
**主要功能:**
- 异步初始化和动态加载
- 统一的 stream() 和 generate() 方法
- 直接使用 AI SDK 的 streamText() 和 generateText()
- 配置验证和错误处理
**核心 API**
```typescript
export class UniversalAiSdkClient {
async initialize(): Promise<void>
isInitialized(): boolean
async stream(request: any): Promise<any>
async generate(request: any): Promise<any>
validateConfig(): boolean
getProviderInfo(): { id: string; name: string; isInitialized: boolean }
}
```
### 4.3 客户端工厂 (`clients/ApiClientFactory.ts`)
统一创建和管理 AI SDK 客户端。
**主要功能:**
- 统一的客户端创建接口
- 智能缓存和复用机制
- 批量创建和健康检查
- 错误处理和重试
**核心 API**
```typescript
export class ApiClientFactory {
static async createAiSdkClient(providerId: string, options: any): Promise<UniversalAiSdkClient>
static getCachedClient(providerId: string, options: any): UniversalAiSdkClient | undefined
static clearCache(): void
static async healthCheck(): Promise<HealthCheckResult>
static getSupportedProviders(): ProviderInfo[]
}
```
### 4.4 钩子风格插件系统 ✅
基于钩子机制的插件架构设计,提供灵活的扩展系统。
**钩子类型:**
1. **First Hooks**:执行到第一个有效结果就停止
2. **Sequential Hooks**:按序链式执行,可变换数据
3. **Parallel Hooks**:并发执行,用于副作用
4. **Stream Hooks**:流转换,直接传递给 AI SDK
**优先级系统:**
- `pre`:前置处理(-100 到 -1
- `normal`标准处理0 到 99
- `post`后置处理100 到 199
**核心钩子:**
**First Hooks (第一个有效结果)**
- `resolveModel`:模型解析,返回第一个匹配的模型
- `loadTemplate`:模板加载,返回第一个找到的模板
**Sequential Hooks (链式变换)**
- `transformParams`:参数转换,依次变换请求参数
- `transformResult`:结果转换,依次变换响应结果
**Parallel Hooks (并发副作用)**
- `onRequestStart`:请求开始时触发
- `onRequestEnd`:请求结束时触发
- `onError`:错误发生时触发
**Stream Hooks (流转换)**
- `transformStream`:流转换,返回 AI SDK 转换函数
**插件 API 设计:**
```typescript
export interface Plugin {
name: string
enforce?: 'pre' | 'normal' | 'post'
// First hooks - 执行到第一个有效结果
resolveModel?(params: ResolveModelParams): Promise<string | null>
loadTemplate?(params: LoadTemplateParams): Promise<Template | null>
// Sequential hooks - 链式变换
transformParams?(params: any, context: PluginContext): Promise<any>
transformResult?(result: any, context: PluginContext): Promise<any>
// Parallel hooks - 并发副作用
onRequestStart?(context: PluginContext): Promise<void>
onRequestEnd?(context: PluginContext): Promise<void>
onError?(error: Error, context: PluginContext): Promise<void>
// Stream hooks - AI SDK 流转换
transformStream?(context: PluginContext): Promise<(readable: ReadableStream) => ReadableStream>
}
export interface PluginContext {
request: any
response?: any
metadata: Record<string, any>
provider: string
model: string
}
export class PluginManager {
use(plugin: Plugin): this
executeFirstHook<T>(hookName: string, ...args: any[]): Promise<T | null>
executeSequentialHook<T>(hookName: string, initialValue: T, context: PluginContext): Promise<T>
executeParallelHook(hookName: string, ...args: any[]): Promise<void>
collectStreamTransforms(context: PluginContext): Promise<Array<(readable: ReadableStream) => ReadableStream>>
}
```
### 4.5 统一服务接口 (规划中)
作为包的主要对外接口,提供高级 AI 功能。
**服务方法:**
- `completions()`: 文本生成
- `streamCompletions()`: 流式文本生成
- `generateObject()`: 结构化数据生成
- `generateImage()`: 图像生成
- `embed()`: 文本嵌入
**API 设计:**
```typescript
export class AiCoreService {
constructor(middlewares?: Middleware[])
async completions(request: CompletionRequest): Promise<CompletionResponse>
async streamCompletions(request: CompletionRequest): Promise<StreamCompletionResponse>
async generateObject<T>(request: ObjectGenerationRequest): Promise<T>
async generateImage(request: ImageGenerationRequest): Promise<ImageResponse>
async embed(request: EmbeddingRequest): Promise<EmbeddingResponse>
use(middleware: Middleware): this
configure(config: AiCoreConfig): this
}
```
## 5. 使用方式
### 5.1 多 Provider 支持
```typescript
import { createAiSdkClient, AiCore } from '@cherrystudio/ai-core'
// 检查支持的 providers
const providers = AiCore.getSupportedProviders()
console.log(`支持 ${providers.length} 个 AI providers`)
// 创建多个 provider 客户端
const openai = await createAiSdkClient('openai', { apiKey: 'openai-key' })
const anthropic = await createAiSdkClient('anthropic', { apiKey: 'anthropic-key' })
const google = await createAiSdkClient('google', { apiKey: 'google-key' })
const xai = await createAiSdkClient('xai', { apiKey: 'xai-key' })
```
### 5.2 在 Cherry Studio 中集成
```typescript
// 替换现有的 XxxApiClient
// 之前:
// const openaiClient = new OpenAIApiClient(config)
// const anthropicClient = new AnthropicApiClient(config)
// 现在:
import { createAiSdkClient } from '@cherrystudio/ai-core'
const createProviderClient = async (provider: CherryProvider) => {
return await createAiSdkClient(provider.id, {
apiKey: provider.apiKey,
baseURL: provider.baseURL
})
}
```
### 5.6 完整的工作流示例 (规划中)
```typescript
import {
createAiSdkClient,
AiCoreService,
MiddlewareChain,
PreRequestMiddleware,
StreamProcessingMiddleware,
PostResponseMiddleware
} from '@cherrystudio/ai-core'
// 创建完整的工作流
const createEnhancedAiService = async () => {
// 创建中间件链
const middlewareChain = new MiddlewareChain()
.use(
new PreRequestMiddleware({
validateApiKey: true,
checkRateLimit: true
})
)
.use(
new StreamProcessingMiddleware({
enableProgressTracking: true,
chunkTransform: (chunk) => ({
...chunk,
timestamp: Date.now()
})
})
)
.use(
new PostResponseMiddleware({
saveToHistory: true,
calculateMetrics: true
})
)
// 创建服务实例
const service = new AiCoreService(middlewareChain.middlewares)
return service
}
// 使用增强服务
const enhancedService = await createEnhancedAiService()
const response = await enhancedService.completions({
provider: 'anthropic',
model: 'claude-3-sonnet',
messages: [{ role: 'user', content: 'Write a technical blog post about AI middleware' }],
options: {
temperature: 0.7,
maxTokens: 2000
},
middleware: {
// 中间件特定配置
thinking: { recordSteps: true },
cache: { enabled: true, ttl: 1800 },
logging: { level: 'debug' }
}
})
```
## 6. 简化设计原则
### 6.1 最小包装原则
- 直接使用 AI SDK 的类型,不重复定义
- 避免过度抽象和复杂的中间层
- 保持与 AI SDK 原生 API 的一致性
### 6.2 动态导入优化
```typescript
// 按需加载,减少打包体积
const module = await import('@ai-sdk/openai')
const createOpenAI = module.createOpenAI
```
### 6.3 类型安全
```typescript
// 直接使用 AI SDK 类型
import { streamText, generateText } from 'ai'
// 避免重复定义,直接传递参数
return streamText({ model, ...request })
```
### 6.4 配置简化
```typescript
// 简化的 Provider 配置
interface ProviderConfig {
id: string // provider 标识
name: string // 显示名称
import: () => Promise<any> // 动态导入函数
creatorFunctionName: string // 创建函数名
}
```
## 7. 技术要点
### 7.1 动态导入策略
- **按需加载**:只加载用户实际使用的 providers
- **缓存机制**:避免重复导入和初始化
- **错误处理**:优雅处理导入失败的情况
### 7.2 依赖管理策略
- **核心依赖**`ai` 库作为必需依赖
- **可选依赖**:所有 `@ai-sdk/*` 包都是可选的
- **版本兼容**:支持 AI SDK v3-v5 版本
### 7.3 缓存策略
- **客户端缓存**:基于 provider + options 的智能缓存
- **配置哈希**:安全的 API key 哈希处理
- **生命周期管理**:支持缓存清理和验证
## 8. 迁移策略
### 8.1 阶段一:包基础搭建 (Week 1) ✅ 已完成
1. ✅ 创建简化的包结构
2. ✅ 实现 Provider 注册表
3. ✅ 创建统一客户端和工厂
4. ✅ 配置构建和类型系统
### 8.2 阶段二:核心功能完善 (Week 2) ✅ 已完成
1. ✅ 支持 19+ 官方 AI SDK providers
2. ✅ 实现缓存和错误处理
3. ✅ 完善类型安全和 API 设计
4. ✅ 添加便捷函数和工具
### 8.3 阶段三:集成测试 (Week 3) 🔄 进行中
1. 在 Cherry Studio 中集成测试
2. 功能完整性验证
3. 性能基准测试
4. 兼容性问题修复
### 8.4 阶段四:插件系统实现 ✅ 已完成
1. **插件核心架构**
- 实现 `PluginManager``PluginContext`
- 创建钩子风格插件接口和类型系统
- 建立四种钩子类型执行机制
2. **钩子系统**
- `First Hooks`:第一个有效结果执行
- `Sequential Hooks`:链式数据变换
- `Parallel Hooks`:并发副作用处理
- `Stream Hooks`AI SDK 流转换集成
3. **优先级和排序**
- `pre`/`normal`/`post` 优先级系统
- 插件注册顺序维护
- 错误处理和插件隔离
4. **集成到现有架构**
-`UniversalAiSdkClient` 中集成插件管理器
- 更新 `ApiClientFactory` 支持插件配置
- 创建示例插件和使用文档
### 8.5 阶段五:特性插件扩展 (规划中)
1. **Cherry Studio 特性插件**
- `ThinkingPlugin`:思考过程记录和提取
- `ToolCallPlugin`:工具调用处理和增强
- `WebSearchPlugin`:网络搜索集成
2. **高级功能**
- 插件组合和条件执行
- 动态插件加载系统
- 插件配置管理和持久化
### 8.6 阶段六:文档和发布 (Week 7) 📋 规划中
1. 完善使用文档和示例
2. 插件开发指南和最佳实践
3. 准备发布到 npm
4. 建立维护流程
### 8.7 阶段七:生态系统扩展 (Week 8+) 🚀 未来规划
1. 社区插件生态系统
2. 可视化插件编排工具
3. 性能监控和分析
4. 高级缓存和优化策略
## 9. 预期收益
### 9.1 开发效率提升
- **90%** 减少新 Provider 接入时间(只需添加注册表配置)
- **70%** 减少维护工作量
- **95%** 提升开发体验(统一接口 + 类型安全)
- **独立开发**:可以独立于主应用开发和测试
### 9.2 代码质量改善
- 完整的 TypeScript 类型安全
- 统一的错误处理机制
- 标准化的 AI SDK 接口
- 更好的测试覆盖率
### 9.3 架构优势
- **轻量级**:最小化的包装层
- **可复用**:其他项目可以直接使用
- **可维护**:独立版本管理和发布
- **可扩展**:新 provider 只需配置即可
### 9.4 生态系统价值
- 支持 AI SDK 的完整生态系统
- 可以独立发布到 npm
- 为开源社区贡献价值
- 建立统一的 AI 基础设施
## 10. 风险评估与应对
### 10.1 技术风险
- **AI SDK 版本兼容**:支持多版本兼容策略
- **依赖管理**:合理使用 peerDependencies
- **类型一致性**:直接使用 AI SDK 类型
- **性能影响**:最小化包装层开销
### 10.2 迁移风险
- **功能对等性**:确保所有现有功能都能实现
- **API 兼容性**:提供平滑的迁移路径
- **集成复杂度**:保持简单的集成方式
- **学习成本**:提供清晰的使用文档
## 11. 总结
简化的 AI Core 架构专注于核心价值:
### 11.1 核心价值
- **统一接口**:一套 API 支持 19+ AI providers
- **按需加载**:只打包用户实际使用的 providers
- **类型安全**:完整的 TypeScript 支持
- **轻量高效**:最小化的包装层
### 11.2 设计哲学
- **直接使用 AI SDK**:避免重复造轮子,充分利用原生能力
- **最小包装**:只在必要时添加抽象层,保持轻量高效
- **开发者友好**:简单易用的 API 设计,熟悉的钩子风格
- **生态兼容**:充分利用 AI SDK 生态系统和原生流转换
- **插件优先**:基于钩子的扩展模式,支持灵活组合
### 11.3 成功关键
1. **保持简单**:专注核心功能,避免过度设计
2. **充分测试**:确保功能完整性和稳定性
3. **渐进迁移**:平滑过渡,降低风险
4. **文档完善**:支持快速上手和深度使用
这个基于钩子的插件系统架构为 Cherry Studio 提供了一个轻量、高效、可维护的 AI 基础设施,通过熟悉的钩子模式和原生 AI SDK 集成,为开发者提供了强大而简洁的扩展能力,同时为社区贡献了一个高质量的开源包。

View File

@@ -1,222 +0,0 @@
# @cherrystudio/ai-core
Cherry Studio AI Core 是一个基于 Vercel AI SDK 的统一 AI Provider 接口包。
## 特性
- 🚀 统一的 AI Provider 接口
- 🔄 动态导入支持
- 💾 智能缓存机制
- 🛠️ TypeScript 支持
- 📦 轻量级设计
## 支持的 Providers
基于 [AI SDK 官方支持的 providers](https://ai-sdk.dev/providers/ai-sdk-providers)
**核心 Providers:**
- OpenAI
- Anthropic
- Google Generative AI
- Google Vertex AI
- Mistral AI
- xAI (Grok)
- Azure OpenAI
- Amazon Bedrock
**扩展 Providers:**
- Cohere
- Groq
- Together.ai
- Fireworks
- DeepSeek
- Cerebras
- DeepInfra
- Replicate
- Perplexity
- Fal AI
- Vercel
## 安装
```bash
npm install @cherrystudio/ai-core ai
```
还需要安装你要使用的 AI SDK provider:
```bash
npm install @ai-sdk/openai @ai-sdk/anthropic @ai-sdk/google
```
## 使用示例
### 基础用法
```typescript
import { createAiSdkClient } from '@cherrystudio/ai-core'
// 创建 OpenAI 客户端
const client = await createAiSdkClient('openai', {
apiKey: 'your-api-key'
})
// 流式生成
const result = await client.stream({
modelId: 'gpt-4',
messages: [{ role: 'user', content: 'Hello!' }]
})
// 非流式生成
const response = await client.generate({
modelId: 'gpt-4',
messages: [{ role: 'user', content: 'Hello!' }]
})
```
### 便捷函数
```typescript
import { createOpenAIClient, streamGeneration } from '@cherrystudio/ai-core'
// 快速创建 OpenAI 客户端
const client = await createOpenAIClient({
apiKey: 'your-api-key'
})
// 便捷流式生成
const result = await streamGeneration('openai', 'gpt-4', [{ role: 'user', content: 'Hello!' }], {
apiKey: 'your-api-key'
})
```
### 多 Provider 支持
```typescript
import { createAiSdkClient } from '@cherrystudio/ai-core'
// 支持多种 AI providers
const openaiClient = await createAiSdkClient('openai', { apiKey: 'openai-key' })
const anthropicClient = await createAiSdkClient('anthropic', { apiKey: 'anthropic-key' })
const googleClient = await createAiSdkClient('google', { apiKey: 'google-key' })
const xaiClient = await createAiSdkClient('xai', { apiKey: 'xai-key' })
```
### 使用 AI SDK 原生 Provider 注册表
> https://ai-sdk.dev/docs/reference/ai-sdk-core/provider-registry
除了使用内建的 provider 管理,你还可以使用 AI SDK 原生的 `createProviderRegistry` 来构建自己的 provider 注册表。
#### 基本用法示例
```typescript
import { createClient } from '@cherrystudio/ai-core'
import { createProviderRegistry } from 'ai'
import { createOpenAI } from '@ai-sdk/openai'
import { anthropic } from '@ai-sdk/anthropic'
// 1. 创建 AI SDK 原生注册表
export const registry = createProviderRegistry({
// register provider with prefix and default setup:
anthropic,
// register provider with prefix and custom setup:
openai: createOpenAI({
apiKey: process.env.OPENAI_API_KEY
})
})
// 2. 创建client,'openai'可以传空或者传providerId(内建的provider)
const client = PluginEnabledAiClient.create('openai', {
apiKey: process.env.OPENAI_API_KEY
})
// 3. 方式1使用内建逻辑传统方式
const result1 = await client.streamText('gpt-4', {
messages: [{ role: 'user', content: 'Hello with built-in logic!' }]
})
// 4. 方式2使用自定义注册表灵活方式
const result2 = await client.streamText({
model: registry.languageModel('openai:gpt-4'),
messages: [{ role: 'user', content: 'Hello with custom registry!' }]
})
// 5. 支持的重载方法
await client.generateObject({
model: registry.languageModel('openai:gpt-4'),
schema: z.object({ name: z.string() }),
messages: [{ role: 'user', content: 'Generate a user' }]
})
await client.streamObject({
model: registry.languageModel('anthropic:claude-3-opus-20240229'),
schema: z.object({ items: z.array(z.string()) }),
messages: [{ role: 'user', content: 'Generate a list' }]
})
```
#### 与插件系统配合使用
更强大的是,你还可以将自定义注册表与 Cherry Studio 的插件系统结合使用:
```typescript
import { PluginEnabledAiClient } from '@cherrystudio/ai-core'
import { createProviderRegistry } from 'ai'
import { createOpenAI } from '@ai-sdk/openai'
import { anthropic } from '@ai-sdk/anthropic'
// 1. 创建带插件的客户端
const client = PluginEnabledAiClient.create(
'openai',
{
apiKey: process.env.OPENAI_API_KEY
},
[LoggingPlugin, RetryPlugin]
)
// 2. 创建自定义注册表
const registry = createProviderRegistry({
openai: createOpenAI({ apiKey: process.env.OPENAI_API_KEY }),
anthropic: anthropic({ apiKey: process.env.ANTHROPIC_API_KEY })
})
// 3. 方式1使用内建逻辑 + 完整插件系统
await client.streamText('gpt-4', {
messages: [{ role: 'user', content: 'Hello with plugins!' }]
})
// 4. 方式2使用自定义注册表 + 有限插件支持
await client.streamText({
model: registry.languageModel('anthropic:claude-3-opus-20240229'),
messages: [{ role: 'user', content: 'Hello from Claude!' }]
})
// 5. 支持的方法
await client.generateObject({
model: registry.languageModel('openai:gpt-4'),
schema: z.object({ name: z.string() }),
messages: [{ role: 'user', content: 'Generate a user' }]
})
await client.streamObject({
model: registry.languageModel('openai:gpt-4'),
schema: z.object({ items: z.array(z.string()) }),
messages: [{ role: 'user', content: 'Generate a list' }]
})
```
#### 混合使用的优势
- **灵活性**:可以根据需要选择使用内建逻辑或自定义注册表
- **兼容性**:完全兼容 AI SDK 的 `createProviderRegistry` API
- **渐进式**:可以逐步迁移现有代码,无需一次性重构
- **插件支持**:自定义注册表仍可享受 Cherry Studio 插件系统的部分功能
- **最佳实践**:结合两种方式的优点,既有动态加载的性能优势,又有统一注册表的便利性
## License
MIT

View File

@@ -1,124 +0,0 @@
{
"name": "@cherrystudio/ai-core",
"version": "1.0.0",
"description": "Cherry Studio AI Core - Unified AI Provider Interface Based on Vercel AI SDK",
"main": "src/index.ts",
"types": "src/index.ts",
"scripts": {
"build": "tsdown",
"dev": "tsc -w",
"clean": "rm -rf dist"
},
"keywords": [
"ai",
"sdk",
"openai",
"anthropic",
"google",
"cherry-studio",
"vercel-ai-sdk"
],
"author": "Cherry Studio",
"license": "MIT",
"dependencies": {
"@ai-sdk/amazon-bedrock": "^2.2.10",
"@ai-sdk/anthropic": "^1.2.12",
"@ai-sdk/azure": "^1.3.23",
"@ai-sdk/cerebras": "^0.2.14",
"@ai-sdk/cohere": "^1.2.10",
"@ai-sdk/deepinfra": "^0.2.15",
"@ai-sdk/deepseek": "^0.2.14",
"@ai-sdk/fal": "^0.1.12",
"@ai-sdk/fireworks": "^0.2.14",
"@ai-sdk/google": "^1.2.19",
"@ai-sdk/google-vertex": "^2.2.24",
"@ai-sdk/groq": "^1.2.9",
"@ai-sdk/mistral": "^1.2.8",
"@ai-sdk/openai": "^1.3.22",
"@ai-sdk/openai-compatible": "^0.2.14",
"@ai-sdk/perplexity": "^1.1.9",
"@ai-sdk/replicate": "^0.2.8",
"@ai-sdk/togetherai": "^0.2.14",
"@ai-sdk/vercel": "^0.0.1",
"@ai-sdk/xai": "^1.2.16",
"@openrouter/ai-sdk-provider": "^0.1.0",
"ai": "^4.3.16",
"anthropic-vertex-ai": "^1.0.2",
"ollama-ai-provider": "^1.2.0",
"qwen-ai-provider": "^0.1.0",
"zhipu-ai-provider": "^0.1.1"
},
"peerDependenciesMeta": {
"@ai-sdk/amazon-bedrock": {
"optional": true
},
"@ai-sdk/anthropic": {
"optional": true
},
"@ai-sdk/azure": {
"optional": true
},
"@ai-sdk/cerebras": {
"optional": true
},
"@ai-sdk/cohere": {
"optional": true
},
"@ai-sdk/deepinfra": {
"optional": true
},
"@ai-sdk/deepseek": {
"optional": true
},
"@ai-sdk/fal": {
"optional": true
},
"@ai-sdk/fireworks": {
"optional": true
},
"@ai-sdk/google": {
"optional": true
},
"@ai-sdk/google-vertex": {
"optional": true
},
"@ai-sdk/groq": {
"optional": true
},
"@ai-sdk/mistral": {
"optional": true
},
"@ai-sdk/openai": {
"optional": true
},
"@ai-sdk/perplexity": {
"optional": true
},
"@ai-sdk/replicate": {
"optional": true
},
"@ai-sdk/together": {
"optional": true
},
"@ai-sdk/vercel": {
"optional": true
},
"@ai-sdk/xai": {
"optional": true
}
},
"devDependencies": {
"tsdown": "^0.12.8",
"typescript": "^5.0.0"
},
"files": [
"dist"
],
"exports": {
".": {
"types": "./src/index.ts",
"import": "./src/index.ts",
"require": "./src/index.ts"
}
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -1,223 +0,0 @@
/**
* API Client Factory
* 整合现有实现的改进版API客户端工厂
*/
import type { ImageModelV1 } from '@ai-sdk/provider'
import { type LanguageModelV1, LanguageModelV1Middleware, wrapLanguageModel } from 'ai'
import { aiProviderRegistry } from '../providers/registry'
import { type ProviderId, type ProviderSettingsMap } from './types'
// 客户端配置接口
export interface ClientConfig {
providerId: string
options?: any
}
// 错误类型
export class ClientFactoryError extends Error {
constructor(
message: string,
public providerId?: string,
public cause?: Error
) {
super(message)
this.name = 'ClientFactoryError'
}
}
/**
* API Client Factory
* 统一管理和创建AI SDK客户端
*/
export class ApiClientFactory {
/**
* 创建 AI SDK 模型实例
* 对于已知的 Provider 使用严格类型检查,未知的 Provider 默认使用 openai-compatible
*/
static async createClient<T extends ProviderId>(
providerId: T,
modelId: string,
options: ProviderSettingsMap[T],
middlewares?: LanguageModelV1Middleware[]
): Promise<LanguageModelV1>
static async createClient(
providerId: string,
modelId: string,
options: ProviderSettingsMap['openai-compatible'],
middlewares?: LanguageModelV1Middleware[]
): Promise<LanguageModelV1>
static async createClient(
providerId: string,
modelId: string = 'default',
options: any,
middlewares?: LanguageModelV1Middleware[]
): Promise<LanguageModelV1> {
try {
// 对于不在注册表中的 provider默认使用 openai-compatible
const effectiveProviderId = aiProviderRegistry.isSupported(providerId) ? providerId : 'openai-compatible'
// 获取Provider配置
const providerConfig = aiProviderRegistry.getProvider(effectiveProviderId)
if (!providerConfig) {
throw new ClientFactoryError(`Provider "${effectiveProviderId}" is not registered`, providerId)
}
// 动态导入模块
const module = await providerConfig.import()
// 获取创建函数
const creatorFunction = module[providerConfig.creatorFunctionName]
if (typeof creatorFunction !== 'function') {
throw new ClientFactoryError(
`Creator function "${providerConfig.creatorFunctionName}" not found in the imported module for provider "${effectiveProviderId}"`
)
}
// 创建provider实例
const provider = creatorFunction(options)
// 返回模型实例
if (typeof provider === 'function') {
let model: LanguageModelV1 = provider(modelId)
// 应用 AI SDK 中间件
if (middlewares && middlewares.length > 0) {
model = wrapLanguageModel({
model: model,
middleware: middlewares
})
}
return model
} else {
throw new ClientFactoryError(`Unknown model access pattern for provider "${effectiveProviderId}"`)
}
} catch (error) {
if (error instanceof ClientFactoryError) {
throw error
}
throw new ClientFactoryError(
`Failed to create client for provider "${providerId}": ${error instanceof Error ? error.message : 'Unknown error'}`,
providerId,
error instanceof Error ? error : undefined
)
}
}
static async createImageClient<T extends ProviderId>(
providerId: T,
modelId: string,
options: ProviderSettingsMap[T]
): Promise<ImageModelV1>
static async createImageClient(
providerId: string,
modelId: string,
options: ProviderSettingsMap['openai-compatible']
): Promise<ImageModelV1>
static async createImageClient(providerId: string, modelId: string = 'default', options: any): Promise<ImageModelV1> {
try {
if (!aiProviderRegistry.isSupported(providerId)) {
throw new ClientFactoryError(`Provider "${providerId}" is not supported`, providerId)
}
const providerConfig = aiProviderRegistry.getProvider(providerId)
if (!providerConfig) {
throw new ClientFactoryError(`Provider "${providerId}" is not registered`, providerId)
}
if (!providerConfig.supportsImageGeneration) {
throw new ClientFactoryError(`Provider "${providerId}" does not support image generation`, providerId)
}
const module = await providerConfig.import()
const creatorFunction = module[providerConfig.creatorFunctionName]
if (typeof creatorFunction !== 'function') {
throw new ClientFactoryError(
`Creator function "${providerConfig.creatorFunctionName}" not found in the imported module for provider "${providerId}"`
)
}
const provider = creatorFunction(options)
if (provider && typeof provider.image === 'function') {
return provider.image(modelId)
} else {
throw new ClientFactoryError(`Image model function not found for provider "${providerId}"`)
}
} catch (error) {
if (error instanceof ClientFactoryError) {
throw error
}
throw new ClientFactoryError(
`Failed to create image client for provider "${providerId}": ${error instanceof Error ? error.message : 'Unknown error'}`,
providerId,
error instanceof Error ? error : undefined
)
}
}
/**
* 获取支持的 Providers 列表
*/
static getSupportedProviders(): Array<{
id: string
name: string
}> {
return aiProviderRegistry.getAllProviders().map((provider) => ({
id: provider.id,
name: provider.name
}))
}
/**
* 获取 Provider 信息
*/
static getClientInfo(providerId: string): {
id: string
name: string
isSupported: boolean
effectiveProvider: string
} {
const effectiveProviderId = aiProviderRegistry.isSupported(providerId) ? providerId : 'openai-compatible'
const provider = aiProviderRegistry.getProvider(effectiveProviderId)
return {
id: providerId,
name: provider?.name || providerId,
isSupported: aiProviderRegistry.isSupported(providerId),
effectiveProvider: effectiveProviderId
}
}
}
// 便捷导出函数
export function createClient<T extends ProviderId>(
providerId: T,
modelId: string,
options: ProviderSettingsMap[T]
): Promise<LanguageModelV1>
export function createClient(
providerId: string,
modelId: string,
options: ProviderSettingsMap['openai-compatible']
): Promise<LanguageModelV1>
export function createClient(providerId: string, modelId: string = 'default', options: any): Promise<LanguageModelV1> {
return ApiClientFactory.createClient(providerId, modelId, options)
}
export const createImageClient = (providerId: string, modelId: string, options: any): Promise<ImageModelV1> =>
ApiClientFactory.createImageClient(providerId, modelId, options)
export const getSupportedProviders = () => ApiClientFactory.getSupportedProviders()
export const getClientInfo = (providerId: string) => ApiClientFactory.getClientInfo(providerId)

View File

@@ -1,381 +0,0 @@
/**
* AI Client - Cherry Studio AI Core 的主要客户端接口
* 默认集成插件系统,提供完整的 AI 调用能力
*
* ## 使用方式
*
* ```typescript
* import { AiClient } from '@cherrystudio/ai-core'
*
* // 创建客户端(默认带插件系统)
* const client = AiClient.create('openai', {
* name: 'openai',
* apiKey: process.env.OPENAI_API_KEY
* }, [LoggingPlugin, ContentFilterPlugin])
*
* // 使用方式与 UniversalAiSdkClient 完全相同
* const result = await client.generateText('gpt-4', {
* messages: [{ role: 'user', content: 'Hello!' }]
* })
* ```
*/
import { generateObject, generateText, LanguageModelV1Middleware, streamObject, streamText } from 'ai'
import { AiPlugin, createContext, PluginManager } from '../plugins'
import { isProviderSupported } from '../providers/registry'
import { ApiClientFactory } from './ApiClientFactory'
import { type ProviderId, type ProviderSettingsMap } from './types'
import { UniversalAiSdkClient } from './UniversalAiSdkClient'
/**
* Cherry Studio AI Core 的主要客户端
* 默认集成插件系统,提供完整的 AI 调用能力
*/
export class PluginEnabledAiClient<T extends ProviderId = ProviderId> {
private pluginManager: PluginManager
private baseClient: UniversalAiSdkClient<T>
private middlewares: LanguageModelV1Middleware[] = []
constructor(
private readonly providerId: T,
private readonly options: ProviderSettingsMap[T],
plugins: AiPlugin[] = []
) {
this.pluginManager = new PluginManager(plugins)
this.baseClient = UniversalAiSdkClient.create(providerId, options)
this.updateMiddlewares()
}
/**
* 添加单个插件
*/
use(plugin: AiPlugin): this {
this.pluginManager.use(plugin)
this.updateMiddlewares()
return this
}
/**
* 批量添加插件
*/
usePlugins(plugins: AiPlugin[]): this {
plugins.forEach((plugin) => this.pluginManager.use(plugin))
this.updateMiddlewares()
return this
}
/**
* 移除插件
*/
removePlugin(pluginName: string): this {
this.pluginManager.remove(pluginName)
this.updateMiddlewares()
return this
}
/**
* 重新计算并更新中间件列表
* 这是一个原子操作,以确保中间件列表总是最新的
*/
private updateMiddlewares(): void {
const pluginMiddlewares = this.pluginManager.collectAiSdkMiddlewares()
this.middlewares = pluginMiddlewares
}
/**
* 获取插件统计信息
*/
getPluginStats() {
return this.pluginManager.getStats()
}
/**
* 获取插件列表
*/
getPlugins() {
return this.pluginManager.getPlugins()
}
/**
* 执行插件处理的通用逻辑
* 1-5步骤的通用处理
*/
private async executeWithPlugins<TParams, TResult>(
methodName: string,
modelId: string,
params: TParams,
executor: (finalModelId: string, transformedParams: TParams) => Promise<TResult>
): Promise<TResult> {
// 创建请求上下文
const context = createContext(this.providerId, modelId, params)
try {
// 1. 触发请求开始事件
await this.pluginManager.executeParallel('onRequestStart', context)
// 2. 解析模型别名
const resolvedModelId = await this.pluginManager.executeFirst<string>('resolveModel', modelId, context)
const finalModelId = resolvedModelId || modelId
// 3. 转换请求参数
const transformedParams = await this.pluginManager.executeSequential('transformParams', params, context)
// 4. 执行具体的 API 调用
const result = await executor(finalModelId, transformedParams)
// 5. 转换结果(对于非流式调用)
const transformedResult = await this.pluginManager.executeSequential('transformResult', result, context)
// 6. 触发完成事件
await this.pluginManager.executeParallel('onRequestEnd', context, transformedResult)
return transformedResult
} catch (error) {
// 7. 触发错误事件
await this.pluginManager.executeParallel('onError', context, undefined, error as Error)
throw error
}
}
/**
* 执行流式调用的通用逻辑
* 流式调用的特殊处理(支持流转换器)
*/
private async executeStreamWithPlugins<TParams, TResult>(
methodName: string,
modelId: string,
params: TParams,
executor: (finalModelId: string, transformedParams: TParams, streamTransforms: any[]) => Promise<TResult>
): Promise<TResult> {
// 创建请求上下文
const context = createContext(this.providerId, modelId, params)
try {
// 1. 触发请求开始事件
await this.pluginManager.executeParallel('onRequestStart', context)
// 2. 解析模型别名
const resolvedModelId = await this.pluginManager.executeFirst<string>('resolveModel', modelId, context)
const finalModelId = resolvedModelId || modelId
// 3. 转换请求参数
const transformedParams = await this.pluginManager.executeSequential('transformParams', params, context)
// 4. 收集流转换器
const streamTransforms = this.pluginManager.collectStreamTransforms()
// 5. 执行流式 API 调用
const result = await executor(finalModelId, transformedParams, streamTransforms)
// 6. 触发完成事件(注意:对于流式调用,这里触发的是开始流式响应的事件)
await this.pluginManager.executeParallel('onRequestEnd', context, { stream: true })
return result
} catch (error) {
// 7. 触发错误事件
await this.pluginManager.executeParallel('onError', context, undefined, error as Error)
throw error
}
}
/**
* 获取注入了中间件的 AI SDK 模型实例
* 这是应用原生中间件的关键
*/
private async getModelWithMiddlewares(modelId: string) {
const middlewares = this.middlewares
// 3. 如果有中间件,创建一个新的、注入了中间件的客户端实例
return await ApiClientFactory.createClient(
this.providerId,
modelId,
this.options,
middlewares.length > 0 ? middlewares : undefined
)
}
/**
* 流式文本生成
*/
async streamText(
modelId: string,
params: Omit<Parameters<typeof streamText>[0], 'model'>
): Promise<ReturnType<typeof streamText>>
async streamText(params: Parameters<typeof streamText>[0]): Promise<ReturnType<typeof streamText>>
async streamText(
modelIdOrParams: string | Parameters<typeof streamText>[0],
params?: Omit<Parameters<typeof streamText>[0], 'model'>
): Promise<ReturnType<typeof streamText>> {
if (typeof modelIdOrParams === 'string') {
// 传统方式:使用内建逻辑
return this.executeStreamWithPlugins(
'streamText',
modelIdOrParams,
params!,
async (finalModelId, transformedParams, streamTransforms) => {
const model = await this.getModelWithMiddlewares(finalModelId)
const experimental_transform =
params?.experimental_transform ?? (streamTransforms.length > 0 ? streamTransforms : undefined)
return await streamText({
model,
...transformedParams,
experimental_transform
})
}
)
} else {
// 外部 registry 方式:直接使用用户提供的 model
return await streamText(modelIdOrParams)
}
}
/**
* 生成文本
* 可能不需要了,因为内置模拟非流中间件
*/
async generateText(
modelId: string,
params: Omit<Parameters<typeof generateText>[0], 'model'>
): Promise<ReturnType<typeof generateText>> {
return this.executeWithPlugins('generateText', modelId, params, async (finalModelId, transformedParams) => {
const model = await this.getModelWithMiddlewares(finalModelId)
return await generateText({ model, ...transformedParams })
})
}
/**
* 生成结构化对象
*/
async generateObject(
modelId: string,
params: Omit<Parameters<typeof generateObject>[0], 'model'>
): Promise<ReturnType<typeof generateObject>>
async generateObject(params: Parameters<typeof generateObject>[0]): Promise<ReturnType<typeof generateObject>>
async generateObject(
modelIdOrParams: string | Parameters<typeof generateObject>[0],
params?: Omit<Parameters<typeof generateObject>[0], 'model'>
): Promise<ReturnType<typeof generateObject>> {
if (typeof modelIdOrParams === 'string') {
// 传统方式:使用内建逻辑
return this.executeWithPlugins(
'generateObject',
modelIdOrParams,
params!,
async (finalModelId, transformedParams) => {
const model = await this.getModelWithMiddlewares(finalModelId)
return await generateObject({ model, ...transformedParams })
}
)
} else {
// 外部 registry 方式:直接使用用户提供的 model
return await generateObject(modelIdOrParams)
}
}
/**
* 流式生成结构化对象
*/
async streamObject(
modelId: string,
params: Omit<Parameters<typeof streamObject>[0], 'model'>
): Promise<ReturnType<typeof streamObject>>
async streamObject(params: Parameters<typeof streamObject>[0]): Promise<ReturnType<typeof streamObject>>
async streamObject(
modelIdOrParams: string | Parameters<typeof streamObject>[0],
params?: Omit<Parameters<typeof streamObject>[0], 'model'>
): Promise<ReturnType<typeof streamObject>> {
if (typeof modelIdOrParams === 'string') {
// 传统方式:使用内建逻辑
return this.executeWithPlugins(
'streamObject',
modelIdOrParams,
params!,
async (finalModelId, transformedParams) => {
return await this.baseClient.streamObject(finalModelId, transformedParams)
}
)
} else {
// 外部 registry 方式:直接使用用户提供的 model
return await streamObject(modelIdOrParams)
}
}
/**
* 获取客户端信息
*/
getClientInfo() {
return this.baseClient.getClientInfo()
}
/**
* 获取底层客户端实例(用于高级用法)
*/
getBaseClient(): UniversalAiSdkClient<T> {
return this.baseClient
}
// === 静态工厂方法 ===
/**
* 创建 OpenAI Compatible 客户端
*/
static createOpenAICompatible(
config: ProviderSettingsMap['openai-compatible'],
plugins: AiPlugin[] = []
): PluginEnabledAiClient<'openai-compatible'> {
return new PluginEnabledAiClient('openai-compatible', config, plugins)
}
/**
* 创建标准提供商客户端
*/
static create<T extends ProviderId>(
providerId: T,
options: ProviderSettingsMap[T],
plugins?: AiPlugin[]
): PluginEnabledAiClient<T>
static create(
providerId: string,
options: ProviderSettingsMap['openai-compatible'],
plugins?: AiPlugin[]
): PluginEnabledAiClient<'openai-compatible'>
static create(providerId: string, options: any, plugins: AiPlugin[] = []): PluginEnabledAiClient {
if (isProviderSupported(providerId)) {
return new PluginEnabledAiClient(providerId as ProviderId, options, plugins)
} else {
// 对于未知 provider使用 openai-compatible
return new PluginEnabledAiClient('openai-compatible', options, plugins)
}
}
}
/**
* 创建 AI 客户端的工厂函数(默认带插件系统)
*/
export function createClient<T extends ProviderId>(
providerId: T,
options: ProviderSettingsMap[T],
plugins?: AiPlugin[]
): PluginEnabledAiClient<T>
export function createClient(
providerId: string,
options: ProviderSettingsMap['openai-compatible'],
plugins?: AiPlugin[]
): PluginEnabledAiClient<'openai-compatible'>
export function createClient(providerId: string, options: any, plugins: AiPlugin[] = []): PluginEnabledAiClient {
return PluginEnabledAiClient.create(providerId, options, plugins)
}
/**
* 创建 OpenAI Compatible 客户端的便捷函数
*/
export function createCompatibleClient(
config: ProviderSettingsMap['openai-compatible'],
plugins: AiPlugin[] = []
): PluginEnabledAiClient<'openai-compatible'> {
return PluginEnabledAiClient.createOpenAICompatible(config, plugins)
}

View File

@@ -1,228 +0,0 @@
/**
* Universal AI SDK Client
* 统一的AI SDK客户端实现
*
* ## 使用方式
*
* ### 1. 官方提供商
* ```typescript
* import { UniversalAiSdkClient } from '@cherrystudio/ai-core'
*
* // OpenAI
* const openai = UniversalAiSdkClient.create('openai', {
* name: 'openai',
* apiHost: 'https://api.openai.com/v1',
* apiKey: process.env.OPENAI_API_KEY
* })
*
* // Anthropic
* const anthropic = UniversalAiSdkClient.create('anthropic', {
* name: 'anthropic',
* apiHost: 'https://api.anthropic.com',
* apiKey: process.env.ANTHROPIC_API_KEY
* })
* ```
*
* ### 2. OpenAI Compatible 第三方提供商
* ```typescript
* // LM Studio (本地运行)
* const lmStudio = UniversalAiSdkClient.createOpenAICompatible({
* name: 'lm-studio',
* baseURL: 'http://localhost:1234/v1'
* })
*
* // Ollama (本地运行)
* const ollama = UniversalAiSdkClient.createOpenAICompatible({
* name: 'ollama',
* baseURL: 'http://localhost:11434/v1'
* })
*
* // 自定义第三方 API
* const customProvider = UniversalAiSdkClient.createOpenAICompatible({
* name: 'my-provider',
* apiKey: process.env.CUSTOM_API_KEY,
* baseURL: 'https://api.customprovider.com/v1',
* headers: {
* 'X-Custom-Header': 'value',
* 'User-Agent': 'MyApp/1.0'
* },
* queryParams: {
* 'api-version': '2024-01'
* }
* })
* ```
*
* ### 3. 使用客户端进行 AI 调用
* ```typescript
* // 流式文本生成
* const stream = await client.streamText('gpt-4', {
* messages: [{ role: 'user', content: 'Hello!' }]
* })
*
* // 生成文本
* const { text } = await client.generateText('gpt-4', {
* messages: [{ role: 'user', content: 'Hello!' }]
* })
*
* // 生成结构化对象
* const { object } = await client.generateObject('gpt-4', {
* messages: [{ role: 'user', content: 'Generate a user profile' }],
* schema: z.object({
* name: z.string(),
* age: z.number()
* })
* })
* ```
*/
import { experimental_generateImage as generateImage, generateObject, generateText, streamObject, streamText } from 'ai'
import { ApiClientFactory } from './ApiClientFactory'
import { type ProviderId, type ProviderSettingsMap } from './types'
/**
* 通用 AI SDK 客户端
* 为特定 AI 提供商创建的客户端实例
*/
export class UniversalAiSdkClient<T extends ProviderId = ProviderId> {
constructor(
private readonly providerId: T,
private readonly options: ProviderSettingsMap[T]
) {}
/**
* 流式文本生成
* 直接使用 AI SDK 的 streamText 参数类型
*/
async streamText(
modelId: string,
params: Omit<Parameters<typeof streamText>[0], 'model'>
): Promise<ReturnType<typeof streamText>> {
const model = await ApiClientFactory.createClient(this.providerId, modelId, this.options)
return streamText({
model,
...params
})
}
/**
* 生成文本
* 直接使用 AI SDK 的 generateText 参数类型
*/
async generateText(
modelId: string,
params: Omit<Parameters<typeof generateText>[0], 'model'>
): Promise<ReturnType<typeof generateText>> {
const model = await ApiClientFactory.createClient(this.providerId, modelId, this.options)
return generateText({
model,
...params
})
}
/**
* 生成结构化对象
* 直接使用 AI SDK 的 generateObject 参数类型
*/
async generateObject(
modelId: string,
params: Omit<Parameters<typeof generateObject>[0], 'model'>
): Promise<ReturnType<typeof generateObject>> {
const model = await ApiClientFactory.createClient(this.providerId, modelId, this.options)
return await generateObject({
model,
...params
})
}
/**
* 流式生成结构化对象
* 直接使用 AI SDK 的 streamObject 参数类型
*/
async streamObject(
modelId: string,
params: Omit<Parameters<typeof streamObject>[0], 'model'>
): Promise<ReturnType<typeof streamObject>> {
const model = await ApiClientFactory.createClient(this.providerId, modelId, this.options)
return streamObject({
model,
...params
})
}
async generateImage(
modelId: string,
params: Omit<Parameters<typeof generateImage>[0], 'model'>
): Promise<ReturnType<typeof generateImage>> {
const model = await ApiClientFactory.createImageClient(this.providerId, modelId, this.options)
return generateImage({
model,
...params
})
}
/**
* 获取客户端信息
*/
getClientInfo() {
return ApiClientFactory.getClientInfo(this.providerId)
}
// === 静态工厂方法 ===
/**
* 创建 OpenAI Compatible 客户端
* 用于那些实现 OpenAI API 的第三方提供商
*/
static createOpenAICompatible(
config: ProviderSettingsMap['openai-compatible']
): UniversalAiSdkClient<'openai-compatible'> {
return new UniversalAiSdkClient('openai-compatible', config)
}
/**
* 创建标准提供商客户端
* 对于已知的 Provider 使用严格类型检查,未知的 Provider 默认使用 openai-compatible
*/
static create<T extends ProviderId>(providerId: T, options: ProviderSettingsMap[T]): UniversalAiSdkClient<T>
static create(
providerId: string,
options: ProviderSettingsMap['openai-compatible']
): UniversalAiSdkClient<'openai-compatible'>
static create(providerId: string, options: any): UniversalAiSdkClient {
if (providerId in ({} as ProviderSettingsMap)) {
return new UniversalAiSdkClient(providerId as ProviderId, options)
} else {
// 对于未知 provider使用 openai-compatible
return new UniversalAiSdkClient('openai-compatible', options)
}
}
}
/**
* 创建客户端实例的工厂函数
*/
export function createUniversalClient<T extends ProviderId>(
providerId: T,
options: ProviderSettingsMap[T]
): UniversalAiSdkClient<T>
export function createUniversalClient(
providerId: string,
options: ProviderSettingsMap['openai-compatible']
): UniversalAiSdkClient<'openai-compatible'>
export function createUniversalClient(providerId: string, options: any): UniversalAiSdkClient {
return UniversalAiSdkClient.create(providerId, options)
}
/**
* 创建 OpenAI Compatible 客户端的便捷函数
*/
export function createOpenAICompatibleClient(
config: ProviderSettingsMap['openai-compatible']
): UniversalAiSdkClient<'openai-compatible'> {
return UniversalAiSdkClient.createOpenAICompatible(config)
}

View File

@@ -1,42 +0,0 @@
import { generateObject, generateText, streamObject, streamText } from 'ai'
import type { ProviderSettingsMap } from '../providers/registry'
// ProviderSettings 是所有 Provider Settings 的联合类型
export type ProviderSettings = ProviderSettingsMap[keyof ProviderSettingsMap]
export type StreamTextParams = Omit<Parameters<typeof streamText>[0], 'model'>
export type GenerateTextParams = Omit<Parameters<typeof generateText>[0], 'model'>
export type StreamObjectParams = Omit<Parameters<typeof streamObject>[0], 'model'>
export type GenerateObjectParams = Omit<Parameters<typeof generateObject>[0], 'model'>
// 重新导出 ProviderSettingsMap 中的所有类型
export type {
AmazonBedrockProviderSettings,
AnthropicProviderSettings,
AnthropicVertexProviderSettings,
AzureOpenAIProviderSettings,
CerebrasProviderSettings,
CohereProviderSettings,
DeepInfraProviderSettings,
DeepSeekProviderSettings,
FalProviderSettings,
FireworksProviderSettings,
GoogleGenerativeAIProviderSettings,
GoogleVertexProviderSettings,
GroqProviderSettings,
MistralProviderSettings,
OllamaProviderSettings,
OpenAICompatibleProviderSettings,
OpenAIProviderSettings,
OpenRouterProviderSettings,
PerplexityProviderSettings,
ProviderId,
ProviderSettingsMap,
QwenProviderSettings,
ReplicateProviderSettings,
TogetherAIProviderSettings,
VercelProviderSettings,
XaiProviderSettings,
ZhipuProviderSettings
} from '../providers/registry'

View File

@@ -1,214 +0,0 @@
/**
* Cherry Studio AI Core Package
* 基于 Vercel AI SDK 的统一 AI Provider 接口
*/
// 导入内部使用的类和函数
import { ApiClientFactory } from './clients/ApiClientFactory'
import { createClient } from './clients/PluginEnabledAiClient'
import { type ProviderSettingsMap } from './clients/types'
import { createUniversalClient } from './clients/UniversalAiSdkClient'
import { aiProviderRegistry, isProviderSupported } from './providers/registry'
// ==================== 主要客户端接口 ====================
// 默认使用集成插件系统的客户端
export {
PluginEnabledAiClient as AiClient,
createClient,
createCompatibleClient
} from './clients/PluginEnabledAiClient'
// 为了向后兼容,也导出原名称
export { PluginEnabledAiClient } from './clients/PluginEnabledAiClient'
// ==================== 插件系统 ====================
export type { AiPlugin, AiRequestContext, HookResult, HookType, PluginManagerConfig } from './plugins'
export { createContext, definePlugin, PluginManager } from './plugins'
// ==================== 底层客户端(高级用法) ====================
// 不带插件系统的基础客户端,用于需要绕过插件系统的场景
export {
createOpenAICompatibleClient as createBasicOpenAICompatibleClient,
createUniversalClient,
UniversalAiSdkClient
} from './clients/UniversalAiSdkClient'
// ==================== 低级 API ====================
export { ApiClientFactory } from './clients/ApiClientFactory'
export { aiProviderRegistry } from './providers/registry'
// ==================== 类型定义 ====================
export type { ClientFactoryError } from './clients/ApiClientFactory'
export type {
GenerateObjectParams,
GenerateTextParams,
ProviderSettings,
StreamObjectParams,
StreamTextParams
} from './clients/types'
export type { ProviderConfig } from './providers/registry'
export type { ProviderError } from './providers/types'
export * as aiSdk from 'ai'
// ==================== AI SDK 常用类型导出 ====================
// 直接导出 AI SDK 的常用类型,方便使用
export type {
CoreAssistantMessage,
// 消息相关类型
CoreMessage,
CoreSystemMessage,
CoreToolMessage,
CoreUserMessage,
// 通用类型
FinishReason,
GenerateObjectResult,
// 生成相关类型
GenerateTextResult,
InvalidToolArgumentsError,
LanguageModelUsage, // AI SDK 4.0 中 TokenUsage 改名为 LanguageModelUsage
LanguageModelV1Middleware,
LanguageModelV1StreamPart,
// 错误类型
NoSuchToolError,
StreamTextResult,
// 流相关类型
TextStreamPart,
// 工具相关类型
Tool,
ToolCall,
ToolExecutionError,
ToolResult
} from 'ai'
export { defaultSettingsMiddleware, extractReasoningMiddleware, simulateStreamingMiddleware, smoothStream } from 'ai'
// 重新导出所有 Provider Settings 类型
export type {
AmazonBedrockProviderSettings,
AnthropicProviderSettings,
AnthropicVertexProviderSettings,
AzureOpenAIProviderSettings,
CerebrasProviderSettings,
CohereProviderSettings,
DeepInfraProviderSettings,
DeepSeekProviderSettings,
FalProviderSettings,
FireworksProviderSettings,
GoogleGenerativeAIProviderSettings,
GoogleVertexProviderSettings,
GroqProviderSettings,
MistralProviderSettings,
OllamaProviderSettings,
OpenAICompatibleProviderSettings,
OpenAIProviderSettings,
OpenRouterProviderSettings,
PerplexityProviderSettings,
ProviderId,
ProviderSettingsMap,
QwenProviderSettings,
ReplicateProviderSettings,
TogetherAIProviderSettings,
VercelProviderSettings,
XaiProviderSettings,
ZhipuProviderSettings
} from './clients/types'
// ==================== 工具函数 ====================
export { createClient as createApiClient, getClientInfo, getSupportedProviders } from './clients/ApiClientFactory'
export { getAllProviders, getProvider, isProviderSupported, registerProvider } from './providers/registry'
// ==================== 包信息 ====================
export const AI_CORE_VERSION = '1.0.0'
export const AI_CORE_NAME = '@cherrystudio/ai-core'
// ==================== 便捷 API ====================
// 主要的便捷工厂类
export const AiCore = {
version: AI_CORE_VERSION,
name: AI_CORE_NAME,
// 创建主要客户端(默认带插件系统)
create(providerId: string, options: any = {}, plugins: any[] = []) {
return createClient(providerId, options, plugins)
},
// 创建基础客户端(不带插件系统)
createBasic(providerId: string, options: any = {}) {
return createUniversalClient(providerId, options)
},
// 获取支持的providers
getSupportedProviders() {
return ApiClientFactory.getSupportedProviders()
},
// 检查provider支持
isSupported(providerId: string) {
return isProviderSupported(providerId)
},
// 获取客户端信息
getClientInfo(providerId: string) {
return ApiClientFactory.getClientInfo(providerId)
}
}
export const createOpenAIClient = (options: ProviderSettingsMap['openai'], plugins?: any[]) => {
return createClient('openai', options, plugins)
}
export const createAnthropicClient = (options: ProviderSettingsMap['anthropic'], plugins?: any[]) => {
return createClient('anthropic', options, plugins)
}
export const createGoogleClient = (options: ProviderSettingsMap['google'], plugins?: any[]) => {
return createClient('google', options, plugins)
}
export const createXAIClient = (options: ProviderSettingsMap['xai'], plugins?: any[]) => {
return createClient('xai', options, plugins)
}
// ==================== 调试和开发工具 ====================
export const DevTools = {
// 列出所有注册的providers
listProviders() {
return aiProviderRegistry.getAllProviders().map((p) => ({
id: p.id,
name: p.name
}))
},
// 测试provider连接
async testProvider(providerId: string, options: any) {
try {
const client = createClient(providerId, options)
const info = client.getClientInfo()
return {
success: true,
providerId: info.id,
name: info.name,
isSupported: info.isSupported
}
} catch (error) {
return {
success: false,
providerId,
error: error instanceof Error ? error.message : 'Unknown error'
}
}
},
// 获取provider详细信息
getProviderDetails() {
const providers = aiProviderRegistry.getAllProviders()
return {
supportedProviders: providers.length,
registeredProviders: providers.length,
providers: providers.map((p) => ({
id: p.id,
name: p.name
}))
}
}
}

View File

@@ -1,259 +0,0 @@
# AI Core 插件系统
支持四种钩子类型:**First**、**Sequential**、**Parallel** 和 **Stream**
## 🎯 设计理念
借鉴 Rollup/Vite 的成熟插件思想:
- **语义清晰**:不同钩子有不同的执行语义
- **类型安全**TypeScript 完整支持
- **性能优化**First 短路、Parallel 并发、Sequential 链式
- **易于扩展**`enforce` 排序 + 功能分类
## 📋 钩子类型
### 🥇 First 钩子 - 首个有效结果
```typescript
// 只执行第一个返回值的插件,用于解析和查找
resolveModel?: (modelId: string, context: AiRequestContext) => string | null
loadTemplate?: (templateName: string, context: AiRequestContext) => any | null
```
### 🔄 Sequential 钩子 - 链式数据转换
```typescript
// 按顺序链式执行,每个插件可以修改数据
transformParams?: (params: any, context: AiRequestContext) => any
transformResult?: (result: any, context: AiRequestContext) => any
```
### ⚡ Parallel 钩子 - 并行副作用
```typescript
// 并发执行,用于日志、监控等副作用
onRequestStart?: (context: AiRequestContext) => void
onRequestEnd?: (context: AiRequestContext, result: any) => void
onError?: (error: Error, context: AiRequestContext) => void
```
### 🌊 Stream 钩子 - 流处理
```typescript
// 直接使用 AI SDK 的 TransformStream
transformStream?: () => (options) => TransformStream<TextStreamPart, TextStreamPart>
```
## 🚀 快速开始
### 基础用法
```typescript
import { PluginManager, createContext, definePlugin } from '@cherrystudio/ai-core/middleware'
// 创建插件管理器
const pluginManager = new PluginManager()
// 添加插件
pluginManager.use({
name: 'my-plugin',
async transformParams(params, context) {
return { ...params, temperature: 0.7 }
}
})
// 使用插件
const context = createContext('openai', 'gpt-4', { messages: [] })
const transformedParams = await pluginManager.executeSequential(
'transformParams',
{ messages: [{ role: 'user', content: 'Hello' }] },
context
)
```
### 完整示例
```typescript
import {
PluginManager,
ModelAliasPlugin,
LoggingPlugin,
ParamsValidationPlugin,
createContext
} from '@cherrystudio/ai-core/middleware'
// 创建插件管理器
const manager = new PluginManager([
ModelAliasPlugin, // 模型别名解析
ParamsValidationPlugin, // 参数验证
LoggingPlugin // 日志记录
])
// AI 请求流程
async function aiRequest(providerId: string, modelId: string, params: any) {
const context = createContext(providerId, modelId, params)
try {
// 1. 【并行】触发请求开始事件
await manager.executeParallel('onRequestStart', context)
// 2. 【首个】解析模型别名
const resolvedModel = await manager.executeFirst('resolveModel', modelId, context)
context.modelId = resolvedModel || modelId
// 3. 【串行】转换请求参数
const transformedParams = await manager.executeSequential('transformParams', params, context)
// 4. 【流处理】收集流转换器AI SDK 原生支持数组)
const streamTransforms = manager.collectStreamTransforms()
// 5. 调用 AI SDK这里省略具体实现
const result = await callAiSdk(transformedParams, streamTransforms)
// 6. 【串行】转换响应结果
const transformedResult = await manager.executeSequential('transformResult', result, context)
// 7. 【并行】触发请求完成事件
await manager.executeParallel('onRequestEnd', context, transformedResult)
return transformedResult
} catch (error) {
// 8. 【并行】触发错误事件
await manager.executeParallel('onError', context, undefined, error)
throw error
}
}
```
## 🔧 自定义插件
### 模型别名插件
```typescript
const ModelAliasPlugin = definePlugin({
name: 'model-alias',
enforce: 'pre', // 最先执行
async resolveModel(modelId) {
const aliases = {
gpt4: 'gpt-4-turbo-preview',
claude: 'claude-3-sonnet-20240229'
}
return aliases[modelId] || null
}
})
```
### 参数验证插件
```typescript
const ValidationPlugin = definePlugin({
name: 'validation',
async transformParams(params) {
if (!params.messages) {
throw new Error('messages is required')
}
return {
...params,
temperature: params.temperature ?? 0.7,
max_tokens: params.max_tokens ?? 4096
}
}
})
```
### 监控插件
```typescript
const MonitoringPlugin = definePlugin({
name: 'monitoring',
enforce: 'post', // 最后执行
async onRequestEnd(context, result) {
const duration = Date.now() - context.startTime
console.log(`请求耗时: ${duration}ms`)
}
})
```
### 内容过滤插件
```typescript
const FilterPlugin = definePlugin({
name: 'content-filter',
transformStream() {
return () =>
new TransformStream({
transform(chunk, controller) {
if (chunk.type === 'text-delta') {
const filtered = chunk.textDelta.replace(/敏感词/g, '***')
controller.enqueue({ ...chunk, textDelta: filtered })
} else {
controller.enqueue(chunk)
}
}
})
}
})
```
## 📊 执行顺序
### 插件排序
```
enforce: 'pre' → normal → enforce: 'post'
```
### 钩子执行流程
```mermaid
graph TD
A[请求开始] --> B[onRequestStart 并行执行]
B --> C[resolveModel 首个有效]
C --> D[loadTemplate 首个有效]
D --> E[transformParams 串行执行]
E --> F[collectStreamTransforms]
F --> G[AI SDK 调用]
G --> H[transformResult 串行执行]
H --> I[onRequestEnd 并行执行]
G --> J[异常处理]
J --> K[onError 并行执行]
```
## 💡 最佳实践
1. **功能单一**:每个插件专注一个功能
2. **幂等性**:插件应该是幂等的,重复执行不会产生副作用
3. **错误处理**:插件内部处理异常,不要让异常向上传播
4. **性能优化**使用合适的钩子类型First vs Sequential vs Parallel
5. **命名规范**:使用语义化的插件名称
## 🔍 调试工具
```typescript
// 查看插件统计信息
const stats = manager.getStats()
console.log('插件统计:', stats)
// 查看所有插件
const plugins = manager.getPlugins()
console.log(
'已注册插件:',
plugins.map((p) => p.name)
)
```
## ⚡ 性能优势
- **First 钩子**:一旦找到结果立即停止,避免无效计算
- **Parallel 钩子**:真正并发执行,不阻塞主流程
- **Sequential 钩子**:保证数据转换的顺序性
- **Stream 钩子**:直接集成 AI SDK零开销
这个设计兼顾了简洁性和强大功能,为 AI Core 提供了灵活而高效的扩展机制。

View File

@@ -1,192 +0,0 @@
import type { AiPlugin } from '../types'
/**
* 【First 钩子示例】模型别名解析插件
*/
export const ModelAliasPlugin: AiPlugin = {
name: 'model-alias',
enforce: 'pre',
async resolveModel(modelId) {
const aliases: Record<string, string> = {
gpt4: 'gpt-4-turbo-preview',
claude: 'claude-3-sonnet-20240229',
gemini: 'gemini-pro'
}
return aliases[modelId] || null
}
}
/**
* 【Sequential 钩子示例】参数验证和转换插件
*/
export const ParamsValidationPlugin: AiPlugin = {
name: 'params-validation',
async transformParams(params) {
// 参数验证
if (!params.messages || !Array.isArray(params.messages)) {
throw new Error('Invalid messages parameter')
}
// 参数转换:添加默认配置
return {
...params,
temperature: params.temperature ?? 0.7,
max_tokens: params.max_tokens ?? 4096,
stream: params.stream ?? true
}
},
async transformResult(result, context) {
// 结果后处理:添加元数据
return {
...result,
metadata: {
...result.metadata,
processedAt: new Date().toISOString(),
provider: context.providerId,
model: context.modelId
}
}
}
}
/**
* 【Parallel 钩子示例】日志记录插件
*/
export const LoggingPlugin: AiPlugin = {
name: 'logging',
async onRequestStart(context) {
console.log(`🚀 AI请求开始: ${context.providerId}/${context.modelId}`, {
requestId: context.requestId,
timestamp: new Date().toISOString()
})
},
async onRequestEnd(context, result) {
const duration = Date.now() - context.startTime
console.log(`✅ AI请求完成: ${context.requestId} (${duration}ms)`, {
provider: context.providerId,
model: context.modelId,
hasResult: !!result
})
},
async onError(error, context) {
const duration = Date.now() - context.startTime
console.error(`❌ AI请求失败: ${context.requestId} (${duration}ms)`, {
provider: context.providerId,
model: context.modelId,
error: error.message,
stack: error.stack
})
}
}
/**
* 【Parallel 钩子示例】性能监控插件
*/
export const PerformancePlugin: AiPlugin = {
name: 'performance',
enforce: 'post',
async onRequestEnd(context) {
const duration = Date.now() - context.startTime
// 记录性能指标
const metrics = {
requestId: context.requestId,
provider: context.providerId,
model: context.modelId,
duration,
timestamp: context.startTime,
success: true
}
// 发送到监控系统(这里只是示例)
// await sendMetrics(metrics)
console.log('📊 性能指标:', metrics)
},
async onError(error, context) {
const duration = Date.now() - context.startTime
const metrics = {
requestId: context.requestId,
provider: context.providerId,
model: context.modelId,
duration,
timestamp: context.startTime,
success: false,
errorType: error.constructor.name
}
console.log('📊 错误指标:', metrics)
}
}
/**
* 【Stream 钩子示例】内容过滤插件
*/
export const ContentFilterPlugin: AiPlugin = {
name: 'content-filter',
transformStream() {
return () =>
new TransformStream({
transform(chunk, controller) {
// 过滤敏感内容
if (chunk.type === 'text-delta') {
const filtered = chunk.textDelta.replace(/\b(敏感词|违禁词)\b/g, '***')
controller.enqueue({
...chunk,
textDelta: filtered
})
} else {
controller.enqueue(chunk)
}
}
})
}
}
/**
* 【First 钩子示例】模板加载插件
*/
export const TemplatePlugin: AiPlugin = {
name: 'template-loader',
async loadTemplate(templateName) {
const templates: Record<string, any> = {
chat: {
systemPrompt: '你是一个有用的AI助手',
temperature: 0.7
},
coding: {
systemPrompt: '你是一个专业的编程助手,请提供清晰、高质量的代码',
temperature: 0.3
},
creative: {
systemPrompt: '你是一个创意写作助手,请发挥想象力',
temperature: 0.9
}
}
return templates[templateName] || null
}
}
/**
* 示例插件组合
*/
export const defaultPlugins: AiPlugin[] = [
ModelAliasPlugin,
TemplatePlugin,
ParamsValidationPlugin,
LoggingPlugin,
PerformancePlugin,
ContentFilterPlugin
]

View File

@@ -1,255 +0,0 @@
import { openai } from '@ai-sdk/openai'
import { streamText } from 'ai'
import { PluginEnabledAiClient } from '../../clients/PluginEnabledAiClient'
import { createContext, PluginManager } from '../'
import { ContentFilterPlugin, LoggingPlugin } from './example-plugins'
/**
* 使用 PluginEnabledAiClient 的推荐方式
* 这是最简单直接的使用方法
*/
export async function exampleWithPluginEnabledClient() {
console.log('=== 使用 PluginEnabledAiClient 示例 ===')
// 1. 创建带插件的客户端 - 链式调用方式
const client = PluginEnabledAiClient.create('openai-compatible', {
name: 'openai',
baseURL: 'https://api.openai.com/v1',
apiKey: process.env.OPENAI_API_KEY || 'sk-test'
})
.use(LoggingPlugin)
.use(ContentFilterPlugin)
// 2. 或者在创建时传入插件(也可以这样使用)
// const clientWithPlugins = PluginEnabledAiClient.create(
// 'openai-compatible',
// {
// name: 'openai',
// baseURL: 'https://api.openai.com/v1',
// apiKey: process.env.OPENAI_API_KEY || 'sk-test'
// },
// [LoggingPlugin, ContentFilterPlugin]
// )
// 3. 查看插件统计信息
console.log('插件统计:', client.getPluginStats())
try {
// 4. 使用客户端进行 AI 调用(插件会自动生效)
console.log('开始生成文本...')
const result = await client.generateText('gpt-4', {
messages: [{ role: 'user', content: 'Hello, world!' }],
temperature: 0.7
})
console.log('生成的文本:', result.text)
// 5. 流式调用(支持流转换器)
console.log('开始流式生成...')
const streamResult = await client.streamText('gpt-4', {
messages: [{ role: 'user', content: 'Tell me a short story about AI' }]
})
console.log('开始流式响应...')
for await (const textPart of streamResult.textStream) {
process.stdout.write(textPart)
}
console.log('\n流式响应完成')
return result
} catch (error) {
console.error('调用失败:', error)
throw error
}
}
/**
* 创建 OpenAI Compatible 客户端的示例
*/
export function exampleOpenAICompatible() {
console.log('=== OpenAI Compatible 示例 ===')
// Ollama 示例
const ollama = PluginEnabledAiClient.createOpenAICompatible(
{
name: 'ollama',
baseURL: 'http://localhost:11434/v1'
},
[LoggingPlugin]
)
// LM Studio 示例
const lmStudio = PluginEnabledAiClient.createOpenAICompatible({
name: 'lm-studio',
baseURL: 'http://localhost:1234/v1'
}).use(ContentFilterPlugin)
console.log('Ollama 插件统计:', ollama.getPluginStats())
console.log('LM Studio 插件统计:', lmStudio.getPluginStats())
return { ollama, lmStudio }
}
/**
* 动态插件管理示例
*/
export function exampleDynamicPlugins() {
console.log('=== 动态插件管理示例 ===')
const client = PluginEnabledAiClient.create('openai-compatible', {
name: 'openai',
baseURL: 'https://api.openai.com/v1',
apiKey: 'your-api-key'
})
console.log('初始状态:', client.getPluginStats())
// 动态添加插件
client.use(LoggingPlugin)
console.log('添加 LoggingPlugin 后:', client.getPluginStats())
client.usePlugins([ContentFilterPlugin])
console.log('添加 ContentFilterPlugin 后:', client.getPluginStats())
// 移除插件
client.removePlugin('content-filter')
console.log('移除 content-filter 后:', client.getPluginStats())
return client
}
/**
* 完整的低级 API 示例(原有的 example-usage.ts 的方式)
* 这种方式适合需要精细控制插件生命周期的场景
*/
export async function exampleLowLevelApi() {
console.log('=== 低级 API 示例 ===')
// 1. 创建插件管理器
const pluginManager = new PluginManager([LoggingPlugin, ContentFilterPlugin])
// 2. 创建请求上下文
const context = createContext('openai', 'gpt-4', {
messages: [{ role: 'user', content: 'Hello!' }]
})
try {
// 3. 触发请求开始事件
await pluginManager.executeParallel('onRequestStart', context)
// 4. 解析模型别名
const resolvedModel = await pluginManager.executeFirst('resolveModel', 'gpt-4', context)
console.log('Resolved model:', resolvedModel || 'gpt-4')
// 5. 转换请求参数
const params = {
messages: [{ role: 'user' as const, content: 'Hello, AI!' }],
temperature: 0.7
}
const transformedParams = await pluginManager.executeSequential('transformParams', params, context)
// 6. 收集流转换器关键AI SDK 原生支持数组!)
const streamTransforms = pluginManager.collectStreamTransforms()
// 7. 调用 AI SDK直接传入转换器工厂数组
const result = await streamText({
model: openai('gpt-4'),
...transformedParams,
experimental_transform: streamTransforms // 直接传入工厂函数数组
})
// 8. 处理结果
let fullText = ''
for await (const textPart of result.textStream) {
fullText += textPart
console.log('Streaming:', textPart)
}
// 9. 转换最终结果
const finalResult = { text: fullText, usage: await result.usage }
const transformedResult = await pluginManager.executeSequential('transformResult', finalResult, context)
// 10. 触发完成事件
await pluginManager.executeParallel('onRequestEnd', context, transformedResult)
return transformedResult
} catch (error) {
// 11. 触发错误事件
await pluginManager.executeParallel('onError', context, undefined, error as Error)
throw error
}
}
/**
* 流转换器数组的其他使用方式
*/
export function demonstrateStreamTransforms() {
console.log('=== 流转换器示例 ===')
const pluginManager = new PluginManager([
ContentFilterPlugin,
{
name: 'text-replacer',
transformStream() {
return () =>
new TransformStream({
transform(chunk, controller) {
if (chunk.type === 'text-delta') {
const replaced = chunk.textDelta.replace(/hello/gi, 'hi')
controller.enqueue({ ...chunk, textDelta: replaced })
} else {
controller.enqueue(chunk)
}
}
})
}
}
])
// 获取所有流转换器
const transforms = pluginManager.collectStreamTransforms()
console.log(`收集到 ${transforms.length} 个流转换器`)
// 可以单独使用每个转换器
transforms.forEach((factory, index) => {
console.log(`转换器 ${index + 1} 已准备就绪`)
const transform = factory({ stopStream: () => {} })
console.log('Transform created:', transform)
})
return transforms
}
/**
* 运行所有示例
*/
export async function runAllExamples() {
console.log('🚀 开始运行所有示例...\n')
try {
// 1. PluginEnabledAiClient 示例(推荐)
await exampleWithPluginEnabledClient()
console.log('✅ PluginEnabledAiClient 示例完成\n')
// 2. OpenAI Compatible 示例
exampleOpenAICompatible()
console.log('✅ OpenAI Compatible 示例完成\n')
// 3. 动态插件管理示例
exampleDynamicPlugins()
console.log('✅ 动态插件管理示例完成\n')
// 4. 流转换器示例
demonstrateStreamTransforms()
console.log('✅ 流转换器示例完成\n')
// 5. 低级 API 示例
// await exampleLowLevelApi()
console.log('✅ 低级 API 示例完成\n')
console.log('🎉 所有示例运行完成!')
} catch (error) {
console.error('❌ 示例运行失败:', error)
}
}

View File

@@ -1,23 +0,0 @@
// 核心类型和接口
export type { AiPlugin, AiRequestContext, HookResult, HookType, PluginManagerConfig } from './types'
import type { AiPlugin, AiRequestContext } from './types'
// 插件管理器
export { PluginManager } from './manager'
// 工具函数
export function createContext(providerId: string, modelId: string, originalParams: any): AiRequestContext {
return {
providerId,
modelId,
originalParams,
metadata: {},
startTime: Date.now(),
requestId: `${providerId}-${modelId}-${Date.now()}-${Math.random().toString(36).slice(2)}`
}
}
// 插件构建器 - 便于创建插件
export function definePlugin(plugin: AiPlugin): AiPlugin {
return plugin
}

View File

@@ -1,189 +0,0 @@
import type { LanguageModelV1Middleware, TextStreamPart, ToolSet } from 'ai'
import { AiPlugin, AiRequestContext } from './types'
/**
* 插件管理器 - 基于 Rollup 钩子分类设计
*/
export class PluginManager {
private plugins: AiPlugin[] = []
constructor(plugins: AiPlugin[] = []) {
this.plugins = this.sortPlugins(plugins)
}
/**
* 添加插件
*/
use(plugin: AiPlugin): this {
this.plugins = this.sortPlugins([...this.plugins, plugin])
return this
}
/**
* 移除插件
*/
remove(pluginName: string): this {
this.plugins = this.plugins.filter((p) => p.name !== pluginName)
return this
}
/**
* 插件排序pre -> normal -> post
*/
private sortPlugins(plugins: AiPlugin[]): AiPlugin[] {
const pre: AiPlugin[] = []
const normal: AiPlugin[] = []
const post: AiPlugin[] = []
plugins.forEach((plugin) => {
if (plugin.enforce === 'pre') {
pre.push(plugin)
} else if (plugin.enforce === 'post') {
post.push(plugin)
} else {
normal.push(plugin)
}
})
return [...pre, ...normal, ...post]
}
/**
* 执行 First 钩子 - 返回第一个有效结果
*/
async executeFirst<T>(
hookName: 'resolveModel' | 'loadTemplate',
arg: string,
context: AiRequestContext
): Promise<T | null> {
for (const plugin of this.plugins) {
const hook = plugin[hookName]
if (hook) {
const result = await hook(arg, context)
if (result !== null && result !== undefined) {
return result as T
}
}
}
return null
}
/**
* 执行 Sequential 钩子 - 链式数据转换
*/
async executeSequential<T>(
hookName: 'transformParams' | 'transformResult',
initialValue: T,
context: AiRequestContext
): Promise<T> {
let result = initialValue
for (const plugin of this.plugins) {
const hook = plugin[hookName]
if (hook) {
result = await hook(result, context)
}
}
return result
}
/**
* 执行 Parallel 钩子 - 并行副作用
*/
async executeParallel(
hookName: 'onRequestStart' | 'onRequestEnd' | 'onError',
context: AiRequestContext,
result?: any,
error?: Error
): Promise<void> {
const promises = this.plugins
.map((plugin) => {
const hook = plugin[hookName]
if (!hook) return null
if (hookName === 'onError' && error) {
return (hook as any)(error, context)
} else if (hookName === 'onRequestEnd' && result !== undefined) {
return (hook as any)(context, result)
} else if (hookName === 'onRequestStart') {
return (hook as any)(context)
}
return null
})
.filter(Boolean)
// 使用 Promise.all 而不是 allSettled让插件错误能够抛出
await Promise.all(promises)
}
/**
* 收集所有流转换器返回数组AI SDK 原生支持)
*/
collectStreamTransforms<TOOLS extends ToolSet>(): Array<
(options: {
tools?: TOOLS
stopStream: () => void
}) => TransformStream<TextStreamPart<TOOLS>, TextStreamPart<TOOLS>>
> {
return this.plugins.map((plugin) => plugin.transformStream?.()).filter(Boolean) as Array<
(options: {
tools?: TOOLS
stopStream: () => void
}) => TransformStream<TextStreamPart<TOOLS>, TextStreamPart<TOOLS>>
>
}
/**
* 收集所有 AI SDK 原生中间件
*/
collectAiSdkMiddlewares(): LanguageModelV1Middleware[] {
return this.plugins.flatMap((plugin) => plugin.aiSdkMiddlewares || [])
}
/**
* 获取所有插件信息
*/
getPlugins(): AiPlugin[] {
return [...this.plugins]
}
/**
* 获取插件统计信息
*/
getStats() {
const stats = {
total: this.plugins.length,
pre: 0,
normal: 0,
post: 0,
hooks: {
resolveModel: 0,
loadTemplate: 0,
transformParams: 0,
transformResult: 0,
onRequestStart: 0,
onRequestEnd: 0,
onError: 0,
transformStream: 0
}
}
this.plugins.forEach((plugin) => {
// 统计 enforce 类型
if (plugin.enforce === 'pre') stats.pre++
else if (plugin.enforce === 'post') stats.post++
else stats.normal++
// 统计钩子数量
Object.keys(stats.hooks).forEach((hookName) => {
if (plugin[hookName as keyof AiPlugin]) {
stats.hooks[hookName as keyof typeof stats.hooks]++
}
})
})
return stats
}
}

View File

@@ -1,87 +0,0 @@
import type { LanguageModelV1Middleware, TextStreamPart, ToolSet } from 'ai'
/**
* 生命周期阶段定义
*/
export enum LifecycleStage {
PRE_REQUEST = 'pre-request', // 请求预处理
REQUEST_EXECUTION = 'execution', // 请求执行
STREAM_PROCESSING = 'stream', // 流式处理(仅流模式)
POST_RESPONSE = 'post-response', // 响应后处理
ERROR_HANDLING = 'error' // 错误处理
}
/**
* 生命周期上下文
*/
export interface LifecycleContext {
currentStage: LifecycleStage
startTime: number
stageStartTime: number
completedStages: Set<LifecycleStage>
stageDurations: Map<LifecycleStage, number>
metadata: Record<string, any>
}
/**
* AI 请求上下文
*/
export interface AiRequestContext {
providerId: string
modelId: string
originalParams: any
metadata: Record<string, any>
startTime: number
requestId: string
}
/**
* 借鉴 Rollup 的钩子分类设计
*/
export interface AiPlugin {
name: string
enforce?: 'pre' | 'post'
// 【First】首个钩子 - 只执行第一个返回值的插件
resolveModel?: (modelId: string, context: AiRequestContext) => string | null | Promise<string | null>
loadTemplate?: (templateName: string, context: AiRequestContext) => any | null | Promise<any | null>
// 【Sequential】串行钩子 - 链式执行,支持数据转换
transformParams?: (params: any, context: AiRequestContext) => any | Promise<any>
transformResult?: (result: any, context: AiRequestContext) => any | Promise<any>
// 【Parallel】并行钩子 - 不依赖顺序,用于副作用
onRequestStart?: (context: AiRequestContext) => void | Promise<void>
onRequestEnd?: (context: AiRequestContext, result: any) => void | Promise<void>
onError?: (error: Error, context: AiRequestContext) => void | Promise<void>
// 【Stream】流处理 - 直接使用 AI SDK
transformStream?: <TOOLS extends ToolSet>() => (options: {
tools?: TOOLS
stopStream: () => void
}) => TransformStream<TextStreamPart<TOOLS>, TextStreamPart<TOOLS>>
// AI SDK 原生中间件
aiSdkMiddlewares?: LanguageModelV1Middleware[]
}
/**
* 插件管理器配置
*/
export interface PluginManagerConfig {
plugins: AiPlugin[]
context: Partial<AiRequestContext>
}
/**
* 钩子执行器类型
*/
export type HookType = 'first' | 'sequential' | 'parallel' | 'stream'
/**
* 钩子执行结果
*/
export interface HookResult<T = any> {
value: T
stop?: boolean
}

View File

@@ -1,376 +0,0 @@
/**
* AI Provider 注册表
* 静态类型 + 动态导入模式:所有类型静态导入,所有实现动态导入
*/
// 静态导入所有 AI SDK 类型
import { type AmazonBedrockProviderSettings } from '@ai-sdk/amazon-bedrock'
import { type AnthropicProviderSettings } from '@ai-sdk/anthropic'
import { type AzureOpenAIProviderSettings } from '@ai-sdk/azure'
import { type CerebrasProviderSettings } from '@ai-sdk/cerebras'
import { type CohereProviderSettings } from '@ai-sdk/cohere'
import { type DeepInfraProviderSettings } from '@ai-sdk/deepinfra'
import { type DeepSeekProviderSettings } from '@ai-sdk/deepseek'
import { type FalProviderSettings } from '@ai-sdk/fal'
import { type FireworksProviderSettings } from '@ai-sdk/fireworks'
import { type GoogleGenerativeAIProviderSettings } from '@ai-sdk/google'
import { type GoogleVertexProviderSettings } from '@ai-sdk/google-vertex'
import { type GroqProviderSettings } from '@ai-sdk/groq'
import { type MistralProviderSettings } from '@ai-sdk/mistral'
import { type OpenAIProviderSettings } from '@ai-sdk/openai'
import { type OpenAICompatibleProviderSettings } from '@ai-sdk/openai-compatible'
import { type PerplexityProviderSettings } from '@ai-sdk/perplexity'
import { type ReplicateProviderSettings } from '@ai-sdk/replicate'
import { type TogetherAIProviderSettings } from '@ai-sdk/togetherai'
import { type VercelProviderSettings } from '@ai-sdk/vercel'
import { type XaiProviderSettings } from '@ai-sdk/xai'
import { type OpenRouterProviderSettings } from '@openrouter/ai-sdk-provider'
import { type AnthropicVertexProviderSettings } from 'anthropic-vertex-ai'
import { type OllamaProviderSettings } from 'ollama-ai-provider'
import { type QwenProviderSettings } from 'qwen-ai-provider'
import { type ZhipuProviderSettings } from 'zhipu-ai-provider'
// 类型安全的 Provider Settings 映射
export type ProviderSettingsMap = {
openai: OpenAIProviderSettings
'openai-compatible': OpenAICompatibleProviderSettings
anthropic: AnthropicProviderSettings
google: GoogleGenerativeAIProviderSettings
'google-vertex': GoogleVertexProviderSettings
mistral: MistralProviderSettings
xai: XaiProviderSettings
azure: AzureOpenAIProviderSettings
bedrock: AmazonBedrockProviderSettings
cohere: CohereProviderSettings
groq: GroqProviderSettings
together: TogetherAIProviderSettings
fireworks: FireworksProviderSettings
deepseek: DeepSeekProviderSettings
cerebras: CerebrasProviderSettings
deepinfra: DeepInfraProviderSettings
replicate: ReplicateProviderSettings
perplexity: PerplexityProviderSettings
fal: FalProviderSettings
vercel: VercelProviderSettings
ollama: OllamaProviderSettings
qwen: QwenProviderSettings
zhipu: ZhipuProviderSettings
'anthropic-vertex': AnthropicVertexProviderSettings
openrouter: OpenRouterProviderSettings
}
export type ProviderId = keyof ProviderSettingsMap
// 统一的 Provider 配置接口(所有都使用动态导入)
export interface ProviderConfig {
id: string
name: string
// 动态导入函数
import: () => Promise<any>
// 创建函数名称
creatorFunctionName: string
// 是否支持图片生成
supportsImageGeneration?: boolean
}
/**
* AI SDK Provider 注册表
* 管理所有支持的 AI Providers 及其动态导入
*/
export class AiProviderRegistry {
private static instance: AiProviderRegistry
private registry = new Map<string, ProviderConfig>()
private constructor() {
this.initializeProviders()
}
public static getInstance(): AiProviderRegistry {
if (!AiProviderRegistry.instance) {
AiProviderRegistry.instance = new AiProviderRegistry()
}
return AiProviderRegistry.instance
}
/**
* 初始化所有支持的 Providers
* 基于 AI SDK 官方文档: https://ai-sdk.dev/providers/ai-sdk-providers
*/
private initializeProviders(): void {
const providers: ProviderConfig[] = [
// 官方 AI SDK Providers (19个)
{
id: 'openai',
name: 'OpenAI',
import: () => import('@ai-sdk/openai'),
creatorFunctionName: 'createOpenAI',
supportsImageGeneration: true
},
{
id: 'openai-compatible',
name: 'OpenAI Compatible',
import: () => import('@ai-sdk/openai-compatible'),
creatorFunctionName: 'createOpenAICompatible'
},
{
id: 'anthropic',
name: 'Anthropic',
import: () => import('@ai-sdk/anthropic'),
creatorFunctionName: 'createAnthropic',
supportsImageGeneration: false
},
{
id: 'google',
name: 'Google Generative AI',
import: () => import('@ai-sdk/google'),
creatorFunctionName: 'createGoogleGenerativeAI',
supportsImageGeneration: true
},
{
id: 'google-vertex',
name: 'Google Vertex AI',
import: () => import('@ai-sdk/google-vertex'),
creatorFunctionName: 'createVertex',
supportsImageGeneration: true
},
{
id: 'mistral',
name: 'Mistral AI',
import: () => import('@ai-sdk/mistral'),
creatorFunctionName: 'createMistral',
supportsImageGeneration: false
},
{
id: 'xai',
name: 'xAI (Grok)',
import: () => import('@ai-sdk/xai'),
creatorFunctionName: 'createXai',
supportsImageGeneration: true
},
{
id: 'azure',
name: 'Azure OpenAI',
import: () => import('@ai-sdk/azure'),
creatorFunctionName: 'createAzure',
supportsImageGeneration: true
},
{
id: 'bedrock',
name: 'Amazon Bedrock',
import: () => import('@ai-sdk/amazon-bedrock'),
creatorFunctionName: 'createAmazonBedrock',
supportsImageGeneration: false
},
{
id: 'cohere',
name: 'Cohere',
import: () => import('@ai-sdk/cohere'),
creatorFunctionName: 'createCohere',
supportsImageGeneration: false
},
{
id: 'groq',
name: 'Groq',
import: () => import('@ai-sdk/groq'),
creatorFunctionName: 'createGroq',
supportsImageGeneration: false
},
{
id: 'together',
name: 'Together.ai',
import: () => import('@ai-sdk/togetherai'),
creatorFunctionName: 'createTogetherAI',
supportsImageGeneration: true
},
{
id: 'fireworks',
name: 'Fireworks',
import: () => import('@ai-sdk/fireworks'),
creatorFunctionName: 'createFireworks',
supportsImageGeneration: true
},
{
id: 'deepseek',
name: 'DeepSeek',
import: () => import('@ai-sdk/deepseek'),
creatorFunctionName: 'createDeepSeek',
supportsImageGeneration: false
},
{
id: 'cerebras',
name: 'Cerebras',
import: () => import('@ai-sdk/cerebras'),
creatorFunctionName: 'createCerebras',
supportsImageGeneration: false
},
{
id: 'deepinfra',
name: 'DeepInfra',
import: () => import('@ai-sdk/deepinfra'),
creatorFunctionName: 'createDeepInfra',
supportsImageGeneration: false
},
{
id: 'replicate',
name: 'Replicate',
import: () => import('@ai-sdk/replicate'),
creatorFunctionName: 'createReplicate',
supportsImageGeneration: true
},
{
id: 'perplexity',
name: 'Perplexity',
import: () => import('@ai-sdk/perplexity'),
creatorFunctionName: 'createPerplexity',
supportsImageGeneration: false
},
{
id: 'fal',
name: 'Fal AI',
import: () => import('@ai-sdk/fal'),
creatorFunctionName: 'createFal',
supportsImageGeneration: false
},
{
id: 'vercel',
name: 'Vercel',
import: () => import('@ai-sdk/vercel'),
creatorFunctionName: 'createVercel'
},
// 社区 Providers (5个)
{
id: 'ollama',
name: 'Ollama',
import: () => import('ollama-ai-provider'),
creatorFunctionName: 'createOllama',
supportsImageGeneration: false
},
{
id: 'qwen',
name: 'Qwen',
import: () => import('qwen-ai-provider'),
creatorFunctionName: 'createQwen',
supportsImageGeneration: false
},
{
id: 'zhipu',
name: 'Zhipu AI',
import: () => import('zhipu-ai-provider'),
creatorFunctionName: 'createZhipu',
supportsImageGeneration: false
},
{
id: 'anthropic-vertex',
name: 'Anthropic Vertex AI',
import: () => import('anthropic-vertex-ai'),
creatorFunctionName: 'createAnthropicVertex',
supportsImageGeneration: false
},
{
id: 'openrouter',
name: 'OpenRouter',
import: () => import('@openrouter/ai-sdk-provider'),
creatorFunctionName: 'createOpenRouter',
supportsImageGeneration: false
}
]
// 注册所有 providers (总计24个)
providers.forEach((config) => {
this.registry.set(config.id, config)
})
}
/**
* 获取所有已注册的 Providers
*/
public getAllProviders(): ProviderConfig[] {
return Array.from(this.registry.values())
}
/**
* 根据 ID 获取 Provider 配置
*/
public getProvider(id: string): ProviderConfig | undefined {
return this.registry.get(id)
}
/**
* 检查 Provider 是否支持(是否已注册)
*/
public isSupported(id: string): boolean {
return this.registry.has(id)
}
/**
* 注册新的 Provider用于扩展
*/
public registerProvider(config: ProviderConfig): void {
this.registry.set(config.id, config)
}
/**
* 清理资源
*/
public cleanup(): void {
this.registry.clear()
}
/**
* 获取兼容现有实现的注册表格式
*/
public getCompatibleRegistry(): Record<string, { import: () => Promise<any>; creatorFunctionName: string }> {
const compatibleRegistry: Record<string, { import: () => Promise<any>; creatorFunctionName: string }> = {}
this.getAllProviders().forEach((provider) => {
compatibleRegistry[provider.id] = {
import: provider.import,
creatorFunctionName: provider.creatorFunctionName
}
})
return compatibleRegistry
}
}
// 导出单例实例
export const aiProviderRegistry = AiProviderRegistry.getInstance()
// 便捷函数
export const getProvider = (id: string) => aiProviderRegistry.getProvider(id)
export const getAllProviders = () => aiProviderRegistry.getAllProviders()
export const isProviderSupported = (id: string) => aiProviderRegistry.isSupported(id)
export const registerProvider = (config: ProviderConfig) => aiProviderRegistry.registerProvider(config)
// 兼容现有实现的导出
export const PROVIDER_REGISTRY = aiProviderRegistry.getCompatibleRegistry()
// 重新导出所有类型供外部使用
export type {
AmazonBedrockProviderSettings,
AnthropicProviderSettings,
AnthropicVertexProviderSettings,
AzureOpenAIProviderSettings,
CerebrasProviderSettings,
CohereProviderSettings,
DeepInfraProviderSettings,
DeepSeekProviderSettings,
FalProviderSettings,
FireworksProviderSettings,
GoogleGenerativeAIProviderSettings,
GoogleVertexProviderSettings,
GroqProviderSettings,
MistralProviderSettings,
OllamaProviderSettings,
OpenAICompatibleProviderSettings,
OpenAIProviderSettings,
OpenRouterProviderSettings,
PerplexityProviderSettings,
QwenProviderSettings,
ReplicateProviderSettings,
TogetherAIProviderSettings,
VercelProviderSettings,
XaiProviderSettings,
ZhipuProviderSettings
}

View File

@@ -1,47 +0,0 @@
/**
* Provider 相关核心类型定义
* 只定义必要的接口,其他类型直接使用 AI SDK
*/
// Provider 配置接口(简化版)
export interface ProviderConfig {
id: string
name: string
import: () => Promise<any>
creatorFunctionName: string
}
// API 客户端工厂接口
export interface ApiClientFactory {
createAiSdkClient(providerId: string, options?: any): Promise<any>
getCachedClient(providerId: string, options?: any): any
clearCache(): void
}
// 客户端配置
export interface ClientConfig {
providerId: string
apiKey?: string
baseURL?: string
[key: string]: any
}
// 错误类型
export class ProviderError extends Error {
constructor(
message: string,
public providerId: string,
public code?: string,
public cause?: Error
) {
super(message)
this.name = 'ProviderError'
}
}
// 缓存统计信息
export interface CacheStats {
size: number
keys: string[]
lastCleanup?: Date
}

View File

@@ -1,26 +0,0 @@
{
"compilerOptions": {
"target": "ES2020",
"module": "ESNext",
"moduleResolution": "node",
"declaration": true,
"outDir": "./dist",
"rootDir": "./src",
"strict": true,
"esModuleInterop": true,
"skipLibCheck": true,
"forceConsistentCasingInFileNames": true,
"resolveJsonModule": true,
"allowSyntheticDefaultImports": true,
"noEmitOnError": false,
"experimentalDecorators": true,
"emitDecoratorMetadata": true
},
"include": [
"src/**/*"
],
"exclude": [
"node_modules",
"dist"
]
}

View File

@@ -15,12 +15,7 @@ export enum IpcChannel {
App_SetAutoUpdate = 'app:set-auto-update',
App_SetFeedUrl = 'app:set-feed-url',
App_HandleZoomFactor = 'app:handle-zoom-factor',
App_Select = 'app:select',
App_HasWritePermission = 'app:has-write-permission',
App_Copy = 'app:copy',
App_SetStopQuitApp = 'app:set-stop-quit-app',
App_SetAppDataPath = 'app:set-app-data-path',
App_RelaunchApp = 'app:relaunch-app',
App_IsBinaryExist = 'app:is-binary-exist',
App_GetBinaryPath = 'app:get-binary-path',
App_InstallUvBinary = 'app:install-uv-binary',

File diff suppressed because it is too large Load Diff

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@@ -1,6 +1,7 @@
import { app } from 'electron'
import { getDataPath } from './utils'
const isDev = process.env.NODE_ENV === 'development'
if (isDev) {

View File

@@ -1,7 +1,6 @@
import '@main/config'
import { electronApp, optimizer } from '@electron-toolkit/utils'
import { initAppDataDir } from '@main/utils/file'
import { replaceDevtoolsFont } from '@main/utils/windowUtil'
import { app } from 'electron'
import installExtension, { REACT_DEVELOPER_TOOLS, REDUX_DEVTOOLS } from 'electron-devtools-installer'
@@ -21,8 +20,8 @@ import selectionService, { initSelectionService } from './services/SelectionServ
import { registerShortcuts } from './services/ShortcutService'
import { TrayService } from './services/TrayService'
import { windowService } from './services/WindowService'
import { setUserDataDir } from './utils/file'
initAppDataDir()
Logger.initialize()
/**
@@ -73,6 +72,9 @@ if (!app.requestSingleInstanceLock()) {
app.quit()
process.exit(0)
} else {
// Portable dir must be setup before app ready
setUserDataDir()
// This method will be called when Electron has finished
// initialization and is ready to create browser windows.
// Some APIs can only be used after this event occurs.

View File

@@ -7,7 +7,7 @@ import { handleZoomFactor } from '@main/utils/zoom'
import { FeedUrl } from '@shared/config/constant'
import { IpcChannel } from '@shared/IpcChannel'
import { Shortcut, ThemeMode } from '@types'
import { BrowserWindow, dialog, ipcMain, session, shell } from 'electron'
import { BrowserWindow, ipcMain, session, shell } from 'electron'
import log from 'electron-log'
import { Notification } from 'src/renderer/src/types/notification'
@@ -34,7 +34,7 @@ import { setOpenLinkExternal } from './services/WebviewService'
import { windowService } from './services/WindowService'
import { calculateDirectorySize, getResourcePath } from './utils'
import { decrypt, encrypt } from './utils/aes'
import { getCacheDir, getConfigDir, getFilesDir, hasWritePermission, updateConfig } from './utils/file'
import { getCacheDir, getConfigDir, getFilesDir } from './utils/file'
import { compress, decompress } from './utils/zip'
const fileManager = new FileStorage()
@@ -175,70 +175,6 @@ export function registerIpc(mainWindow: BrowserWindow, app: Electron.App) {
}
})
let preventQuitListener: ((event: Electron.Event) => void) | null = null
ipcMain.handle(IpcChannel.App_SetStopQuitApp, (_, stop: boolean = false, reason: string = '') => {
if (stop) {
// Only add listener if not already added
if (!preventQuitListener) {
preventQuitListener = (event: Electron.Event) => {
event.preventDefault()
notificationService.sendNotification({
title: reason,
message: reason
} as Notification)
}
app.on('before-quit', preventQuitListener)
}
} else {
// Remove listener if it exists
if (preventQuitListener) {
app.removeListener('before-quit', preventQuitListener)
preventQuitListener = null
}
}
})
// Select app data path
ipcMain.handle(IpcChannel.App_Select, async (_, options: Electron.OpenDialogOptions) => {
try {
const { canceled, filePaths } = await dialog.showOpenDialog(options)
if (canceled || filePaths.length === 0) {
return null
}
return filePaths[0]
} catch (error: any) {
log.error('Failed to select app data path:', error)
return null
}
})
ipcMain.handle(IpcChannel.App_HasWritePermission, async (_, filePath: string) => {
return hasWritePermission(filePath)
})
// Set app data path
ipcMain.handle(IpcChannel.App_SetAppDataPath, async (_, filePath: string) => {
updateConfig(filePath)
app.setPath('userData', filePath)
})
// Copy user data to new location
ipcMain.handle(IpcChannel.App_Copy, async (_, oldPath: string, newPath: string) => {
try {
await fs.promises.cp(oldPath, newPath, { recursive: true })
return { success: true }
} catch (error: any) {
log.error('Failed to copy user data:', error)
return { success: false, error: error.message }
}
})
// Relaunch app
ipcMain.handle(IpcChannel.App_RelaunchApp, () => {
app.relaunch()
app.exit(0)
})
// check for update
ipcMain.handle(IpcChannel.App_CheckForUpdate, async () => {
return await appUpdater.checkForUpdates()

View File

@@ -9,7 +9,6 @@ import StreamZip from 'node-stream-zip'
import * as path from 'path'
import { CreateDirectoryOptions, FileStat } from 'webdav'
import { getDataPath } from '../utils'
import WebDav from './WebDav'
import { windowService } from './WindowService'
@@ -254,7 +253,7 @@ class BackupManager {
Logger.log('[backup] step 3: restore Data directory')
// 恢复 Data 目录
const sourcePath = path.join(this.tempDir, 'Data')
const destPath = getDataPath()
const destPath = path.join(app.getPath('userData'), 'Data')
const dataExists = await fs.pathExists(sourcePath)
const dataFiles = dataExists ? await fs.readdir(sourcePath) : []

View File

@@ -25,12 +25,12 @@ import Embeddings from '@main/embeddings/Embeddings'
import { addFileLoader } from '@main/loader'
import Reranker from '@main/reranker/Reranker'
import { windowService } from '@main/services/WindowService'
import { getDataPath } from '@main/utils'
import { getAllFiles } from '@main/utils/file'
import { MB } from '@shared/config/constant'
import type { LoaderReturn } from '@shared/config/types'
import { IpcChannel } from '@shared/IpcChannel'
import { FileType, KnowledgeBaseParams, KnowledgeItem } from '@types'
import { app } from 'electron'
import Logger from 'electron-log'
import { v4 as uuidv4 } from 'uuid'
@@ -88,7 +88,7 @@ const loaderTaskIntoOfSet = (loaderTask: LoaderTask): LoaderTaskOfSet => {
}
class KnowledgeService {
private storageDir = path.join(getDataPath(), 'KnowledgeBase')
private storageDir = path.join(app.getPath('userData'), 'Data', 'KnowledgeBase')
// Byte based
private workload = 0
private processingItemCount = 0

View File

@@ -1,4 +1,4 @@
import { isLinux, isMac, isWin } from '@main/constant'
import { isMac } from '@main/constant'
import { locales } from '@main/utils/locales'
import { app, Menu, MenuItemConstructorOptions, nativeImage, nativeTheme, Tray } from 'electron'
@@ -6,7 +6,6 @@ import icon from '../../../build/tray_icon.png?asset'
import iconDark from '../../../build/tray_icon_dark.png?asset'
import iconLight from '../../../build/tray_icon_light.png?asset'
import { ConfigKeys, configManager } from './ConfigManager'
import selectionService from './SelectionService'
import { windowService } from './WindowService'
export class TrayService {
@@ -30,14 +29,14 @@ export class TrayService {
const iconPath = isMac ? (nativeTheme.shouldUseDarkColors ? iconLight : iconDark) : icon
const tray = new Tray(iconPath)
if (isWin) {
if (process.platform === 'win32') {
tray.setImage(iconPath)
} else if (isMac) {
} else if (process.platform === 'darwin') {
const image = nativeImage.createFromPath(iconPath)
const resizedImage = image.resize({ width: 16, height: 16 })
resizedImage.setTemplateImage(true)
tray.setImage(resizedImage)
} else if (isLinux) {
} else if (process.platform === 'linux') {
const image = nativeImage.createFromPath(iconPath)
const resizedImage = image.resize({ width: 16, height: 16 })
tray.setImage(resizedImage)
@@ -47,7 +46,7 @@ export class TrayService {
this.updateContextMenu()
if (isLinux) {
if (process.platform === 'linux') {
this.tray.setContextMenu(this.contextMenu)
}
@@ -70,31 +69,19 @@ export class TrayService {
private updateContextMenu() {
const locale = locales[configManager.getLanguage()]
const { tray: trayLocale, selection: selectionLocale } = locale.translation
const { tray: trayLocale } = locale.translation
const quickAssistantEnabled = configManager.getEnableQuickAssistant()
const selectionAssistantEnabled = configManager.getSelectionAssistantEnabled()
const enableQuickAssistant = configManager.getEnableQuickAssistant()
const template = [
{
label: trayLocale.show_window,
click: () => windowService.showMainWindow()
},
quickAssistantEnabled && {
enableQuickAssistant && {
label: trayLocale.show_mini_window,
click: () => windowService.showMiniWindow()
},
isWin && {
label: selectionLocale.name + (selectionAssistantEnabled ? ' - On' : ' - Off'),
// type: 'checkbox',
// checked: selectionAssistantEnabled,
click: () => {
if (selectionService) {
selectionService.toggleEnabled()
this.updateContextMenu()
}
}
},
{ type: 'separator' },
{
label: trayLocale.quit,
@@ -131,10 +118,6 @@ export class TrayService {
configManager.subscribe(ConfigKeys.EnableQuickAssistant, () => {
this.updateContextMenu()
})
configManager.subscribe(ConfigKeys.SelectionAssistantEnabled, () => {
this.updateContextMenu()
})
}
private quit() {

View File

@@ -56,7 +56,7 @@ export class WindowService {
minHeight: 600,
show: false,
autoHideMenuBar: true,
transparent: false,
transparent: isMac,
vibrancy: 'sidebar',
visualEffectState: 'active',
titleBarStyle: 'hidden',

View File

@@ -2,7 +2,7 @@ import * as fs from 'node:fs'
import os from 'node:os'
import path from 'node:path'
import { isPortable } from '@main/constant'
import { isMac } from '@main/constant'
import { audioExts, documentExts, imageExts, textExts, videoExts } from '@shared/config/constant'
import { FileType, FileTypes } from '@types'
import { app } from 'electron'
@@ -23,61 +23,6 @@ function initFileTypeMap() {
// 初始化映射表
initFileTypeMap()
export function hasWritePermission(path: string) {
try {
fs.accessSync(path, fs.constants.W_OK)
return true
} catch (error) {
return false
}
}
function getAppDataPathFromConfig() {
try {
const configPath = path.join(getConfigDir(), 'config.json')
if (fs.existsSync(configPath)) {
const config = JSON.parse(fs.readFileSync(configPath, 'utf-8'))
if (config.appDataPath && fs.existsSync(config.appDataPath) && hasWritePermission(config.appDataPath)) {
return config.appDataPath
}
}
} catch (error) {
return null
}
return null
}
export function initAppDataDir() {
const appDataPath = getAppDataPathFromConfig()
if (appDataPath) {
app.setPath('userData', appDataPath)
return
}
if (isPortable) {
const portableDir = process.env.PORTABLE_EXECUTABLE_DIR
app.setPath('userData', path.join(portableDir || app.getPath('exe'), 'data'))
return
}
}
export function updateConfig(appDataPath: string) {
const configDir = getConfigDir()
if (!fs.existsSync(configDir)) {
fs.mkdirSync(configDir, { recursive: true })
}
const configPath = path.join(getConfigDir(), 'config.json')
if (!fs.existsSync(configPath)) {
fs.writeFileSync(configPath, JSON.stringify({ appDataPath }, null, 2))
return
}
const config = JSON.parse(fs.readFileSync(configPath, 'utf-8'))
config.appDataPath = appDataPath
fs.writeFileSync(configPath, JSON.stringify(config, null, 2))
}
export function getFileType(ext: string): FileTypes {
ext = ext.toLowerCase()
return fileTypeMap.get(ext) || FileTypes.OTHER
@@ -143,3 +88,12 @@ export function getCacheDir() {
export function getAppConfigDir(name: string) {
return path.join(getConfigDir(), name)
}
export function setUserDataDir() {
if (!isMac) {
const dir = path.join(path.dirname(app.getPath('exe')), 'data')
if (fs.existsSync(dir) && fs.statSync(dir).isDirectory()) {
app.setPath('userData', dir)
}
}
}

View File

@@ -26,12 +26,6 @@ const api = {
handleZoomFactor: (delta: number, reset: boolean = false) =>
ipcRenderer.invoke(IpcChannel.App_HandleZoomFactor, delta, reset),
setAutoUpdate: (isActive: boolean) => ipcRenderer.invoke(IpcChannel.App_SetAutoUpdate, isActive),
select: (options: Electron.OpenDialogOptions) => ipcRenderer.invoke(IpcChannel.App_Select, options),
hasWritePermission: (path: string) => ipcRenderer.invoke(IpcChannel.App_HasWritePermission, path),
setAppDataPath: (path: string) => ipcRenderer.invoke(IpcChannel.App_SetAppDataPath, path),
copy: (oldPath: string, newPath: string) => ipcRenderer.invoke(IpcChannel.App_Copy, oldPath, newPath),
setStopQuitApp: (stop: boolean, reason: string) => ipcRenderer.invoke(IpcChannel.App_SetStopQuitApp, stop, reason),
relaunchApp: () => ipcRenderer.invoke(IpcChannel.App_RelaunchApp),
openWebsite: (url: string) => ipcRenderer.invoke(IpcChannel.Open_Website, url),
getCacheSize: () => ipcRenderer.invoke(IpcChannel.App_GetCacheSize),
clearCache: () => ipcRenderer.invoke(IpcChannel.App_ClearCache),

View File

@@ -1,284 +0,0 @@
/**
* AI SDK 到 Cherry Studio Chunk 适配器
* 用于将 AI SDK 的 fullStream 转换为 Cherry Studio 的 chunk 格式
*/
import { TextStreamPart } from '@cherrystudio/ai-core'
import { Chunk, ChunkType } from '@renderer/types/chunk'
export interface CherryStudioChunk {
type: 'text-delta' | 'text-complete' | 'tool-call' | 'tool-result' | 'finish' | 'error'
text?: string
toolCall?: any
toolResult?: any
finishReason?: string
usage?: any
error?: any
}
/**
* AI SDK 到 Cherry Studio Chunk 适配器类
* 处理 fullStream 到 Cherry Studio chunk 的转换
*/
export class AiSdkToChunkAdapter {
constructor(private onChunk: (chunk: Chunk) => void) {}
/**
* 处理 AI SDK 流结果
* @param aiSdkResult AI SDK 的流结果对象
* @returns 最终的文本内容
*/
async processStream(aiSdkResult: any): Promise<string> {
// 如果是流式且有 fullStream
if (aiSdkResult.fullStream) {
await this.readFullStream(aiSdkResult.fullStream)
}
// 使用 streamResult.text 获取最终结果
return await aiSdkResult.text
}
/**
* 读取 fullStream 并转换为 Cherry Studio chunks
* @param fullStream AI SDK 的 fullStream (ReadableStream)
*/
private async readFullStream(fullStream: ReadableStream<TextStreamPart<any>>) {
const reader = fullStream.getReader()
const final = {
text: '',
reasoning_content: ''
}
try {
while (true) {
const { done, value } = await reader.read()
if (done) {
break
}
// 转换并发送 chunk
this.convertAndEmitChunk(value, final)
}
} finally {
reader.releaseLock()
}
}
/**
* 转换 AI SDK chunk 为 Cherry Studio chunk 并调用回调
* @param chunk AI SDK 的 chunk 数据
*/
private convertAndEmitChunk(chunk: any, final: { text: string; reasoning_content: string }) {
console.log('AI SDK chunk type:', chunk.type, chunk)
switch (chunk.type) {
// === 文本相关事件 ===
case 'text-delta':
final.text += chunk.textDelta || ''
this.onChunk({
type: ChunkType.TEXT_DELTA,
text: chunk.textDelta || ''
})
break
case 'reasoning':
this.onChunk({
type: ChunkType.THINKING_DELTA,
text: chunk.textDelta || ''
})
break
case 'redacted-reasoning':
this.onChunk({
type: ChunkType.THINKING_DELTA,
text: chunk.data || ''
})
break
case 'reasoning-signature':
this.onChunk({
type: ChunkType.THINKING_COMPLETE,
text: chunk.text || '',
thinking_millsec: chunk.thinking_millsec || 0
})
break
// === 工具调用相关事件 ===
case 'tool-call-streaming-start':
// 开始流式工具调用
this.onChunk({
type: ChunkType.MCP_TOOL_CREATED,
tool_calls: [
{
id: chunk.toolCallId,
name: chunk.toolName,
args: {}
}
]
})
break
case 'tool-call-delta':
// 工具调用参数的增量更新
this.onChunk({
type: ChunkType.MCP_TOOL_IN_PROGRESS,
responses: [
{
id: chunk.toolCallId,
tool: {
id: chunk.toolName,
// TODO: serverId,serverName
serverId: 'ai-sdk',
serverName: 'AI SDK',
name: chunk.toolName,
description: '',
inputSchema: {
type: 'object',
title: chunk.toolName,
properties: {}
}
},
arguments: {},
status: 'invoking',
response: chunk.argsTextDelta,
toolCallId: chunk.toolCallId
}
]
})
break
case 'tool-call':
// 完整的工具调用
this.onChunk({
type: ChunkType.MCP_TOOL_CREATED,
tool_calls: [
{
id: chunk.toolCallId,
name: chunk.toolName,
args: chunk.args
}
]
})
break
case 'tool-result':
// 工具调用结果
this.onChunk({
type: ChunkType.MCP_TOOL_COMPLETE,
responses: [
{
id: chunk.toolCallId,
tool: {
id: chunk.toolName,
// TODO: serverId,serverName
serverId: 'ai-sdk',
serverName: 'AI SDK',
name: chunk.toolName,
description: '',
inputSchema: {
type: 'object',
title: chunk.toolName,
properties: {}
}
},
arguments: chunk.args || {},
status: 'done',
response: chunk.result,
toolCallId: chunk.toolCallId
}
]
})
break
// === 步骤相关事件 ===
// case 'step-start':
// this.onChunk({
// type: ChunkType.LLM_RESPONSE_CREATED
// })
// break
case 'step-finish':
this.onChunk({
type: ChunkType.BLOCK_COMPLETE,
response: {
text: final.text || '',
reasoning_content: final.reasoning_content || '',
usage: {
completion_tokens: chunk.usage.completionTokens || 0,
prompt_tokens: chunk.usage.promptTokens || 0,
total_tokens: chunk.usage.totalTokens || 0
},
metrics: chunk.usage
? {
completion_tokens: chunk.usage.completionTokens || 0,
time_completion_millsec: 0
}
: undefined
}
})
break
case 'finish':
this.onChunk({
type: ChunkType.TEXT_COMPLETE,
text: final.text || '' // TEXT_COMPLETE 需要 text 字段
})
this.onChunk({
type: ChunkType.LLM_RESPONSE_COMPLETE,
response: {
text: final.text || '',
reasoning_content: final.reasoning_content || '',
usage: {
completion_tokens: chunk.usage.completionTokens || 0,
prompt_tokens: chunk.usage.promptTokens || 0,
total_tokens: chunk.usage.totalTokens || 0
},
metrics: chunk.usage
? {
completion_tokens: chunk.usage.completionTokens || 0,
time_completion_millsec: 0
}
: undefined
}
})
break
// === 源和文件相关事件 ===
case 'source':
// 源信息,可以映射到知识搜索完成
this.onChunk({
type: ChunkType.KNOWLEDGE_SEARCH_COMPLETE,
knowledge: [
{
id: Number(chunk.source.id) || Date.now(),
content: chunk.source.title || '',
sourceUrl: chunk.source.url || '',
type: 'url'
}
]
})
break
case 'file':
// 文件相关事件,可能是图片生成
this.onChunk({
type: ChunkType.IMAGE_COMPLETE,
image: {
type: 'base64',
images: [chunk.base64]
}
})
break
case 'error':
this.onChunk({
type: ChunkType.ERROR,
error: {
message: chunk.error || 'Unknown error'
}
})
break
default:
// 其他类型的 chunk 可以忽略或记录日志
console.log('Unhandled AI SDK chunk type:', chunk.type, chunk)
}
}
}
export default AiSdkToChunkAdapter

View File

@@ -20,7 +20,6 @@ import {
SdkToolCall
} from '@renderer/types/sdk'
import { CompletionsContext } from '../middleware/types'
import { AnthropicAPIClient } from './anthropic/AnthropicAPIClient'
import { BaseApiClient } from './BaseApiClient'
import { GeminiAPIClient } from './gemini/GeminiAPIClient'
@@ -164,8 +163,8 @@ export class AihubmixAPIClient extends BaseApiClient {
return this.currentClient.getRequestTransformer()
}
getResponseChunkTransformer(ctx: CompletionsContext): ResponseChunkTransformer<SdkRawChunk> {
return this.currentClient.getResponseChunkTransformer(ctx)
getResponseChunkTransformer(): ResponseChunkTransformer<SdkRawChunk> {
return this.currentClient.getResponseChunkTransformer()
}
convertMcpToolsToSdkTools(mcpTools: MCPTool[]): SdkTool[] {

View File

@@ -42,8 +42,7 @@ import { defaultTimeout } from '@shared/config/constant'
import Logger from 'electron-log/renderer'
import { isEmpty } from 'lodash'
import { CompletionsContext } from '../middleware/types'
import { ApiClient, RequestTransformer, ResponseChunkTransformer } from './types'
import { ApiClient, RawStreamListener, RequestTransformer, ResponseChunkTransformer } from './types'
/**
* Abstract base class for API clients.
@@ -96,7 +95,7 @@ export abstract class BaseApiClient<
// 在 CoreRequestToSdkParamsMiddleware中使用
abstract getRequestTransformer(): RequestTransformer<TSdkParams, TMessageParam>
// 在RawSdkChunkToGenericChunkMiddleware中使用
abstract getResponseChunkTransformer(ctx: CompletionsContext): ResponseChunkTransformer<TRawChunk>
abstract getResponseChunkTransformer(): ResponseChunkTransformer<TRawChunk>
/**
* 工具转换
@@ -111,7 +110,7 @@ export abstract class BaseApiClient<
abstract buildSdkMessages(
currentReqMessages: TMessageParam[],
output: TRawOutput | string | undefined,
output: TRawOutput | string,
toolResults: TMessageParam[],
toolCalls?: TToolCall[]
): TMessageParam[]
@@ -130,6 +129,17 @@ export abstract class BaseApiClient<
*/
abstract extractMessagesFromSdkPayload(sdkPayload: TSdkParams): TMessageParam[]
/**
* 附加原始流监听器
*/
public attachRawStreamListener<TListener extends RawStreamListener<TRawChunk>>(
rawOutput: TRawOutput,
// eslint-disable-next-line @typescript-eslint/no-unused-vars
_listener: TListener
): TRawOutput {
return rawOutput
}
/**
* 通用函数
**/

View File

@@ -90,7 +90,7 @@ export class AnthropicAPIClient extends BaseApiClient<
return this.sdkInstance
}
this.sdkInstance = new Anthropic({
apiKey: this.apiKey,
apiKey: this.getApiKey(),
baseURL: this.getBaseURL(),
dangerouslyAllowBrowser: true,
defaultHeaders: {
@@ -125,7 +125,7 @@ export class AnthropicAPIClient extends BaseApiClient<
// @ts-ignore sdk未提供
override async getEmbeddingDimensions(): Promise<number> {
throw new Error("Anthropic SDK doesn't support getEmbeddingDimensions method.")
return 0
}
override getTemperature(assistant: Assistant, model: Model): number | undefined {
@@ -367,13 +367,12 @@ export class AnthropicAPIClient extends BaseApiClient<
* Anthropic专用的原始流监听器
* 处理MessageStream对象的特定事件
*/
attachRawStreamListener(
override attachRawStreamListener(
rawOutput: AnthropicSdkRawOutput,
listener: RawStreamListener<AnthropicSdkRawChunk>
): AnthropicSdkRawOutput {
console.log(`[AnthropicApiClient] 附加流监听器到原始输出`)
// 专用的Anthropic事件处理
const anthropicListener = listener as AnthropicStreamListener
// 检查是否为MessageStream
if (rawOutput instanceof MessageStream) {
console.log(`[AnthropicApiClient] 检测到 Anthropic MessageStream附加专用监听器`)
@@ -388,6 +387,9 @@ export class AnthropicAPIClient extends BaseApiClient<
})
}
// 专用的Anthropic事件处理
const anthropicListener = listener as AnthropicStreamListener
if (anthropicListener.onContentBlock) {
rawOutput.on('contentBlock', anthropicListener.onContentBlock)
}
@@ -411,10 +413,6 @@ export class AnthropicAPIClient extends BaseApiClient<
return rawOutput
}
if (anthropicListener.onMessage) {
anthropicListener.onMessage(rawOutput)
}
// 对于非MessageStream响应
return rawOutput
}
@@ -520,7 +518,6 @@ export class AnthropicAPIClient extends BaseApiClient<
async transform(rawChunk: AnthropicSdkRawChunk, controller: TransformStreamDefaultController<GenericChunk>) {
switch (rawChunk.type) {
case 'message': {
let i = 0
for (const content of rawChunk.content) {
switch (content.type) {
case 'text': {
@@ -531,8 +528,7 @@ export class AnthropicAPIClient extends BaseApiClient<
break
}
case 'tool_use': {
toolCalls[i] = content
i++
toolCalls[0] = content
break
}
case 'thinking': {
@@ -554,22 +550,6 @@ export class AnthropicAPIClient extends BaseApiClient<
}
}
}
if (i > 0) {
controller.enqueue({
type: ChunkType.MCP_TOOL_CREATED,
tool_calls: Object.values(toolCalls)
} as MCPToolCreatedChunk)
}
controller.enqueue({
type: ChunkType.LLM_RESPONSE_COMPLETE,
response: {
usage: {
prompt_tokens: rawChunk.usage.input_tokens || 0,
completion_tokens: rawChunk.usage.output_tokens || 0,
total_tokens: (rawChunk.usage.input_tokens || 0) + (rawChunk.usage.output_tokens || 0)
}
}
})
break
}
case 'content_block_start': {

View File

@@ -85,7 +85,7 @@ export class GeminiAPIClient extends BaseApiClient<
...rest,
config: {
...rest.config,
abortSignal: options?.signal,
abortSignal: options?.abortSignal,
httpOptions: {
...rest.config?.httpOptions,
timeout: options?.timeout
@@ -147,12 +147,15 @@ export class GeminiAPIClient extends BaseApiClient<
override async getEmbeddingDimensions(model: Model): Promise<number> {
const sdk = await this.getSdkInstance()
const data = await sdk.models.embedContent({
model: model.id,
contents: [{ role: 'user', parts: [{ text: 'hi' }] }]
})
return data.embeddings?.[0]?.values?.length || 0
try {
const data = await sdk.models.embedContent({
model: model.id,
contents: [{ role: 'user', parts: [{ text: 'hi' }] }]
})
return data.embeddings?.[0]?.values?.length || 0
} catch (e) {
return 0
}
}
override async listModels(): Promise<GeminiModel[]> {
@@ -413,9 +416,8 @@ export class GeminiAPIClient extends BaseApiClient<
}
}
const { min, max } = findTokenLimit(model.id) || { min: 0, max: 0 }
// 计算 budgetTokens确保不低于 min
const budget = Math.floor((max - min) * effortRatio + min)
const { max } = findTokenLimit(model.id) || { max: 0 }
const budget = Math.floor(max * effortRatio)
return {
thinkingConfig: {
@@ -464,7 +466,7 @@ export class GeminiAPIClient extends BaseApiClient<
systemInstruction = await buildSystemPrompt(assistant.prompt || '', mcpTools, assistant)
}
let messageContents: Content = { role: 'user', parts: [] } // Initialize messageContents
let messageContents: Content
const history: Content[] = []
// 3. 处理用户消息
if (typeof messages === 'string') {
@@ -473,13 +475,10 @@ export class GeminiAPIClient extends BaseApiClient<
parts: [{ text: messages }]
}
} else {
const userLastMessage = messages.pop()
if (userLastMessage) {
messageContents = await this.convertMessageToSdkParam(userLastMessage)
for (const message of messages) {
history.push(await this.convertMessageToSdkParam(message))
}
messages.push(userLastMessage)
const userLastMessage = messages.pop()!
messageContents = await this.convertMessageToSdkParam(userLastMessage)
for (const message of messages) {
history.push(await this.convertMessageToSdkParam(message))
}
}
@@ -492,10 +491,6 @@ export class GeminiAPIClient extends BaseApiClient<
if (isGemmaModel(model) && assistant.prompt) {
const isFirstMessage = history.length === 0
if (isFirstMessage && messageContents) {
const userMessageText =
messageContents.parts && messageContents.parts.length > 0
? (messageContents.parts[0] as Part).text || ''
: ''
const systemMessage = [
{
text:
@@ -503,7 +498,7 @@ export class GeminiAPIClient extends BaseApiClient<
systemInstruction +
'<end_of_turn>\n' +
'<start_of_turn>user\n' +
userMessageText +
(messageContents?.parts?.[0] as Part).text +
'<end_of_turn>'
}
] as Part[]
@@ -520,7 +515,13 @@ export class GeminiAPIClient extends BaseApiClient<
const newMessageContents =
isRecursiveCall && recursiveSdkMessages && recursiveSdkMessages.length > 0
? recursiveSdkMessages[recursiveSdkMessages.length - 1]
? {
...messageContents,
parts: [
...(messageContents.parts || []),
...(recursiveSdkMessages[recursiveSdkMessages.length - 1].parts || [])
]
}
: messageContents
const generateContentConfig: GenerateContentConfig = {
@@ -554,7 +555,7 @@ export class GeminiAPIClient extends BaseApiClient<
getResponseChunkTransformer(): ResponseChunkTransformer<GeminiSdkRawChunk> {
return () => ({
async transform(chunk: GeminiSdkRawChunk, controller: TransformStreamDefaultController<GenericChunk>) {
const toolCalls: FunctionCall[] = []
let toolCalls: FunctionCall[] = []
if (chunk.candidates && chunk.candidates.length > 0) {
for (const candidate of chunk.candidates) {
if (candidate.content) {
@@ -582,8 +583,6 @@ export class GeminiAPIClient extends BaseApiClient<
]
}
})
} else if (part.functionCall) {
toolCalls.push(part.functionCall)
}
})
}
@@ -598,6 +597,9 @@ export class GeminiAPIClient extends BaseApiClient<
}
} as LLMWebSearchCompleteChunk)
}
if (chunk.functionCalls) {
toolCalls = toolCalls.concat(chunk.functionCalls)
}
controller.enqueue({
type: ChunkType.LLM_RESPONSE_COMPLETE,
response: {
@@ -700,11 +702,12 @@ export class GeminiAPIClient extends BaseApiClient<
.filter((p) => p !== undefined)
)
const lastMessage = currentReqMessages[currentReqMessages.length - 1]
if (lastMessage) {
lastMessage.parts?.push(...parts)
const userMessage: Content = {
role: 'user',
parts: parts
}
return currentReqMessages
return [...currentReqMessages, userMessage]
}
override estimateMessageTokens(message: GeminiSdkMessageParam): number {
@@ -731,20 +734,7 @@ export class GeminiAPIClient extends BaseApiClient<
}
public extractMessagesFromSdkPayload(sdkPayload: GeminiSdkParams): GeminiSdkMessageParam[] {
const messageParam: GeminiSdkMessageParam = {
role: 'user',
parts: []
}
if (Array.isArray(sdkPayload.message)) {
sdkPayload.message.forEach((part) => {
if (typeof part === 'string') {
messageParam.parts?.push({ text: part })
} else if (typeof part === 'object') {
messageParam.parts?.push(part)
}
})
}
return [messageParam, ...(sdkPayload.history || [])]
return sdkPayload.history || []
}
private async uploadFile(file: FileType): Promise<File> {

View File

@@ -337,14 +337,10 @@ export class OpenAIAPIClient extends OpenAIBaseClient<
public buildSdkMessages(
currentReqMessages: OpenAISdkMessageParam[],
output: string | undefined,
output: string,
toolResults: OpenAISdkMessageParam[],
toolCalls: OpenAI.Chat.Completions.ChatCompletionMessageToolCall[]
): OpenAISdkMessageParam[] {
if (!output && toolCalls.length === 0) {
return [...currentReqMessages, ...toolResults]
}
const assistantMessage: OpenAISdkMessageParam = {
role: 'assistant',
content: output,
@@ -494,7 +490,7 @@ export class OpenAIAPIClient extends OpenAIBaseClient<
}
// 在RawSdkChunkToGenericChunkMiddleware中使用
getResponseChunkTransformer(): ResponseChunkTransformer<OpenAISdkRawChunk> {
getResponseChunkTransformer = (): ResponseChunkTransformer<OpenAISdkRawChunk> => {
let hasBeenCollectedWebSearch = false
const collectWebSearchData = (
chunk: OpenAISdkRawChunk,
@@ -588,52 +584,9 @@ export class OpenAIAPIClient extends OpenAIBaseClient<
return null
}
const toolCalls: OpenAI.Chat.Completions.ChatCompletionMessageToolCall[] = []
let isFinished = false
let lastUsageInfo: any = null
/**
* 统一的完成信号发送逻辑
* - 有 finish_reason 时
* - 无 finish_reason 但是流正常结束时
*/
const emitCompletionSignals = (controller: TransformStreamDefaultController<GenericChunk>) => {
if (isFinished) return
if (toolCalls.length > 0) {
controller.enqueue({
type: ChunkType.MCP_TOOL_CREATED,
tool_calls: toolCalls
})
}
const usage = lastUsageInfo || {
prompt_tokens: 0,
completion_tokens: 0,
total_tokens: 0
}
controller.enqueue({
type: ChunkType.LLM_RESPONSE_COMPLETE,
response: { usage }
})
// 防止重复发送
isFinished = true
}
return (context: ResponseChunkTransformerContext) => ({
async transform(chunk: OpenAISdkRawChunk, controller: TransformStreamDefaultController<GenericChunk>) {
// 持续更新usage信息
if (chunk.usage) {
lastUsageInfo = {
prompt_tokens: chunk.usage.prompt_tokens || 0,
completion_tokens: chunk.usage.completion_tokens || 0,
total_tokens: (chunk.usage.prompt_tokens || 0) + (chunk.usage.completion_tokens || 0)
}
}
// 处理chunk
if ('choices' in chunk && chunk.choices && chunk.choices.length > 0) {
const choice = chunk.choices[0]
@@ -698,6 +651,12 @@ export class OpenAIAPIClient extends OpenAIBaseClient<
// 处理finish_reason发送流结束信号
if ('finish_reason' in choice && choice.finish_reason) {
Logger.debug(`[OpenAIApiClient] Stream finished with reason: ${choice.finish_reason}`)
if (toolCalls.length > 0) {
controller.enqueue({
type: ChunkType.MCP_TOOL_CREATED,
tool_calls: toolCalls
})
}
const webSearchData = collectWebSearchData(chunk, contentSource, context)
if (webSearchData) {
controller.enqueue({
@@ -705,17 +664,18 @@ export class OpenAIAPIClient extends OpenAIBaseClient<
llm_web_search: webSearchData
})
}
emitCompletionSignals(controller)
controller.enqueue({
type: ChunkType.LLM_RESPONSE_COMPLETE,
response: {
usage: {
prompt_tokens: chunk.usage?.prompt_tokens || 0,
completion_tokens: chunk.usage?.completion_tokens || 0,
total_tokens: (chunk.usage?.prompt_tokens || 0) + (chunk.usage?.completion_tokens || 0)
}
}
})
}
}
},
// 流正常结束时,检查是否需要发送完成信号
flush(controller) {
if (isFinished) return
Logger.debug('[OpenAIApiClient] Stream ended without finish_reason, emitting fallback completion signals')
emitCompletionSignals(controller)
}
})
}

View File

@@ -85,13 +85,16 @@ export abstract class OpenAIBaseClient<
override async getEmbeddingDimensions(model: Model): Promise<number> {
const sdk = await this.getSdkInstance()
const data = await sdk.embeddings.create({
model: model.id,
input: model?.provider === 'baidu-cloud' ? ['hi'] : 'hi',
encoding_format: 'float'
})
return data.data[0].embedding.length
try {
const data = await sdk.embeddings.create({
model: model.id,
input: model?.provider === 'baidu-cloud' ? ['hi'] : 'hi',
encoding_format: 'float'
})
return data.data[0].embedding.length
} catch (e) {
return 0
}
}
override async listModels(): Promise<OpenAI.Models.Model[]> {
@@ -135,7 +138,7 @@ export abstract class OpenAIBaseClient<
return this.sdkInstance
}
let apiKeyForSdkInstance = this.apiKey
let apiKeyForSdkInstance = this.provider.apiKey
if (this.provider.id === 'copilot') {
const defaultHeaders = store.getState().copilot.defaultHeaders

View File

@@ -1,5 +1,4 @@
import { GenericChunk } from '@renderer/aiCore/middleware/schemas'
import { CompletionsContext } from '@renderer/aiCore/middleware/types'
import {
isOpenAIChatCompletionOnlyModel,
isSupportedReasoningEffortOpenAIModel,
@@ -39,7 +38,6 @@ import { buildSystemPrompt } from '@renderer/utils/prompt'
import { MB } from '@shared/config/constant'
import { isEmpty } from 'lodash'
import OpenAI from 'openai'
import { ResponseInput } from 'openai/resources/responses/responses'
import { RequestTransformer, ResponseChunkTransformer } from '../types'
import { OpenAIAPIClient } from './OpenAIApiClient'
@@ -78,7 +76,7 @@ export class OpenAIResponseAPIClient extends OpenAIBaseClient<
return new OpenAI({
dangerouslyAllowBrowser: true,
apiKey: this.apiKey,
apiKey: this.provider.apiKey,
baseURL: this.getBaseURL(),
defaultHeaders: {
...this.defaultHeaders()
@@ -227,29 +225,17 @@ export class OpenAIResponseAPIClient extends OpenAIBaseClient<
return
}
private convertResponseToMessageContent(response: OpenAI.Responses.Response): ResponseInput {
const content: OpenAI.Responses.ResponseInput = []
content.push(...response.output)
return content
}
public buildSdkMessages(
currentReqMessages: OpenAIResponseSdkMessageParam[],
output: OpenAI.Responses.Response | undefined,
output: string,
toolResults: OpenAIResponseSdkMessageParam[],
toolCalls: OpenAIResponseSdkToolCall[]
): OpenAIResponseSdkMessageParam[] {
if (!output && toolCalls.length === 0) {
return [...currentReqMessages, ...toolResults]
const assistantMessage: OpenAIResponseSdkMessageParam = {
role: 'assistant',
content: [{ type: 'input_text', text: output }]
}
if (!output) {
return [...currentReqMessages, ...(toolCalls || []), ...(toolResults || [])]
}
const content = this.convertResponseToMessageContent(output)
const newReqMessages = [...currentReqMessages, ...content, ...(toolResults || [])]
const newReqMessages = [...currentReqMessages, assistantMessage, ...(toolCalls || []), ...(toolResults || [])]
return newReqMessages
}
@@ -421,17 +407,13 @@ export class OpenAIResponseAPIClient extends OpenAIBaseClient<
}
}
getResponseChunkTransformer(ctx: CompletionsContext): ResponseChunkTransformer<OpenAIResponseSdkRawChunk> {
getResponseChunkTransformer(): ResponseChunkTransformer<OpenAIResponseSdkRawChunk> {
const toolCalls: OpenAIResponseSdkToolCall[] = []
const outputItems: OpenAI.Responses.ResponseOutputItem[] = []
let hasBeenCollectedToolCalls = false
return () => ({
async transform(chunk: OpenAIResponseSdkRawChunk, controller: TransformStreamDefaultController<GenericChunk>) {
// 处理chunk
if ('output' in chunk) {
if (ctx._internal?.toolProcessingState) {
ctx._internal.toolProcessingState.output = chunk
}
for (const output of chunk.output) {
switch (output.type) {
case 'message':
@@ -473,22 +455,6 @@ export class OpenAIResponseAPIClient extends OpenAIBaseClient<
})
}
}
if (toolCalls.length > 0) {
controller.enqueue({
type: ChunkType.MCP_TOOL_CREATED,
tool_calls: toolCalls
})
}
controller.enqueue({
type: ChunkType.LLM_RESPONSE_COMPLETE,
response: {
usage: {
prompt_tokens: chunk.usage?.input_tokens || 0,
completion_tokens: chunk.usage?.output_tokens || 0,
total_tokens: chunk.usage?.total_tokens || 0
}
}
})
} else {
switch (chunk.type) {
case 'response.output_item.added':
@@ -536,8 +502,7 @@ export class OpenAIResponseAPIClient extends OpenAIBaseClient<
if (outputItem.type === 'function_call') {
toolCalls.push({
...outputItem,
arguments: chunk.arguments,
status: 'completed'
arguments: chunk.arguments
})
}
}
@@ -553,26 +518,15 @@ export class OpenAIResponseAPIClient extends OpenAIBaseClient<
}
})
}
if (toolCalls.length > 0 && !hasBeenCollectedToolCalls) {
if (toolCalls.length > 0) {
controller.enqueue({
type: ChunkType.MCP_TOOL_CREATED,
tool_calls: toolCalls
})
hasBeenCollectedToolCalls = true
}
break
}
case 'response.completed': {
if (ctx._internal?.toolProcessingState) {
ctx._internal.toolProcessingState.output = chunk.response
}
if (toolCalls.length > 0 && !hasBeenCollectedToolCalls) {
controller.enqueue({
type: ChunkType.MCP_TOOL_CREATED,
tool_calls: toolCalls
})
hasBeenCollectedToolCalls = true
}
const completion_tokens = chunk.response.usage?.output_tokens || 0
const total_tokens = chunk.response.usage?.total_tokens || 0
controller.enqueue({

View File

@@ -3,8 +3,6 @@ import { Assistant, MCPTool, MCPToolResponse, Model, ToolCallResponse } from '@r
import { Provider } from '@renderer/types'
import {
AnthropicSdkRawChunk,
OpenAIResponseSdkRawChunk,
OpenAIResponseSdkRawOutput,
OpenAISdkRawChunk,
SdkMessageParam,
SdkParams,
@@ -16,7 +14,6 @@ import {
import OpenAI from 'openai'
import { CompletionsParams, GenericChunk } from '../middleware/schemas'
import { CompletionsContext } from '../middleware/types'
/**
* 原始流监听器接口
@@ -36,14 +33,6 @@ export interface OpenAIStreamListener extends RawStreamListener<OpenAISdkRawChun
onFinishReason?: (reason: string) => void
}
/**
* OpenAI Response 专用的流监听器
*/
export interface OpenAIResponseStreamListener<TChunk extends OpenAIResponseSdkRawChunk = OpenAIResponseSdkRawChunk>
extends RawStreamListener<TChunk> {
onMessage?: (response: OpenAIResponseSdkRawOutput) => void
}
/**
* Anthropic 专用的流监听器
*/
@@ -112,7 +101,7 @@ export interface ApiClient<
// SDK相关方法
getSdkInstance(): Promise<TSdkInstance> | TSdkInstance
getRequestTransformer(): RequestTransformer<TSdkParams, TMessageParam>
getResponseChunkTransformer(ctx: CompletionsContext): ResponseChunkTransformer<TRawChunk>
getResponseChunkTransformer(): ResponseChunkTransformer<TRawChunk>
// 原始流监听方法
attachRawStreamListener?(rawOutput: TRawOutput, listener: RawStreamListener<TRawChunk>): TRawOutput

View File

@@ -11,7 +11,6 @@ import { AnthropicAPIClient } from './clients/anthropic/AnthropicAPIClient'
import { OpenAIResponseAPIClient } from './clients/openai/OpenAIResponseAPIClient'
import { CompletionsMiddlewareBuilder } from './middleware/builder'
import { MIDDLEWARE_NAME as AbortHandlerMiddlewareName } from './middleware/common/AbortHandlerMiddleware'
import { MIDDLEWARE_NAME as ErrorHandlerMiddlewareName } from './middleware/common/ErrorHandlerMiddleware'
import { MIDDLEWARE_NAME as FinalChunkConsumerMiddlewareName } from './middleware/common/FinalChunkConsumerMiddleware'
import { applyCompletionsMiddlewares } from './middleware/composer'
import { MIDDLEWARE_NAME as McpToolChunkMiddlewareName } from './middleware/core/McpToolChunkMiddleware'
@@ -63,7 +62,6 @@ export default class AiProvider {
builder.clear()
builder
.add(MiddlewareRegistry[FinalChunkConsumerMiddlewareName])
.add(MiddlewareRegistry[ErrorHandlerMiddlewareName])
.add(MiddlewareRegistry[AbortHandlerMiddlewareName])
.add(MiddlewareRegistry[ImageGenerationMiddlewareName])
} else {
@@ -76,7 +74,7 @@ export default class AiProvider {
if (!(this.apiClient instanceof OpenAIAPIClient)) {
builder.remove(ThinkingTagExtractionMiddlewareName)
}
if (!(this.apiClient instanceof AnthropicAPIClient) && !(this.apiClient instanceof OpenAIResponseAPIClient)) {
if (!(this.apiClient instanceof AnthropicAPIClient)) {
builder.remove(RawStreamListenerMiddlewareName)
}
if (!params.enableWebSearch) {
@@ -114,7 +112,7 @@ export default class AiProvider {
return dimensions
} catch (error) {
console.error('Error getting embedding dimensions:', error)
throw error
return 0
}
}

View File

@@ -1,229 +0,0 @@
/**
* Cherry Studio AI Core - 新版本入口
* 集成 @cherrystudio/ai-core 库的渐进式重构方案
*
* 融合方案:简化实现,专注于核心功能
* 1. 优先使用新AI SDK
* 2. 失败时fallback到原有实现
* 3. 暂时保持接口兼容性
*/
import {
AiClient,
AiCore,
createClient,
type OpenAICompatibleProviderSettings,
type ProviderId,
smoothStream,
StreamTextParams
} from '@cherrystudio/ai-core'
import { isDedicatedImageGenerationModel } from '@renderer/config/models'
import type { GenerateImageParams, Model, Provider } from '@renderer/types'
// 引入适配器
import AiSdkToChunkAdapter from './AiSdkToChunkAdapter'
// 引入原有的AiProvider作为fallback
import LegacyAiProvider from './index'
import { AiSdkMiddlewareConfig, buildAiSdkMiddlewares } from './middleware/aisdk/AiSdkMiddlewareBuilder'
import { CompletionsResult } from './middleware/schemas'
// 引入参数转换模块
/**
* 将现有 Provider 类型映射到 AI SDK 的 Provider ID
* 根据 registry.ts 中的支持列表进行映射
*/
function mapProviderTypeToAiSdkId(providerType: string): string {
// Cherry Studio Provider Type -> AI SDK Provider ID 映射表
const typeMapping: Record<string, string> = {
// 需要转换的映射
grok: 'xai', // grok -> xai
'azure-openai': 'azure', // azure-openai -> azure
gemini: 'google', // gemini -> google
vertexai: 'google-vertex' // vertexai -> google-vertex
}
return typeMapping[providerType]
}
/**
* 将 Provider 配置转换为新 AI SDK 格式
*/
function providerToAiSdkConfig(provider: Provider): {
providerId: ProviderId | 'openai-compatible'
options: any
} {
console.log('provider', provider)
// 1. 先映射 provider 类型到 AI SDK ID
const mappedProviderId = mapProviderTypeToAiSdkId(provider.id)
// 2. 检查映射后的 provider ID 是否在 AI SDK 注册表中
const isSupported = AiCore.isSupported(mappedProviderId)
console.log(`Provider mapping: ${provider.type} -> ${mappedProviderId}, supported: ${isSupported}`)
// 3. 如果映射的 provider 不支持,则使用 openai-compatible
if (isSupported) {
return {
providerId: mappedProviderId as ProviderId,
options: {
apiKey: provider.apiKey
}
}
} else {
console.log(`Using openai-compatible fallback for provider: ${provider.type}`)
const compatibleConfig: OpenAICompatibleProviderSettings = {
name: provider.name || provider.type,
apiKey: provider.apiKey,
baseURL: provider.apiHost
}
return {
providerId: 'openai-compatible',
options: compatibleConfig
}
}
}
/**
* 检查是否支持使用新的AI SDK
*/
function isModernSdkSupported(provider: Provider, model?: Model): boolean {
// 目前支持主要的providers
const supportedProviders = ['openai', 'anthropic', 'gemini', 'azure-openai', 'vertexai']
// 检查provider类型
if (!supportedProviders.includes(provider.type)) {
return false
}
// 检查是否为图像生成模型(暂时不支持)
if (model && isDedicatedImageGenerationModel(model)) {
return false
}
return true
}
export default class ModernAiProvider {
private modernClient?: AiClient
private legacyProvider: LegacyAiProvider
private provider: Provider
constructor(provider: Provider) {
this.provider = provider
this.legacyProvider = new LegacyAiProvider(provider)
// 初始化时不构建中间件,等到需要时再构建
const config = providerToAiSdkConfig(provider)
this.modernClient = createClient(config.providerId, config.options)
}
public async completions(
modelId: string,
params: StreamTextParams,
middlewareConfig: AiSdkMiddlewareConfig
): Promise<CompletionsResult> {
// const model = params.assistant.model
// 检查是否应该使用现代化客户端
// if (this.modernClient && model && isModernSdkSupported(this.provider, model)) {
// try {
return await this.modernCompletions(modelId, params, middlewareConfig)
// } catch (error) {
// console.warn('Modern client failed, falling back to legacy:', error)
// fallback到原有实现
// }
// }
// 使用原有实现
// return this.legacyProvider.completions(params, options)
}
/**
* 使用现代化AI SDK的completions实现
* 使用建造者模式动态构建中间件
*/
private async modernCompletions(
modelId: string,
params: StreamTextParams,
middlewareConfig: AiSdkMiddlewareConfig
): Promise<CompletionsResult> {
if (!this.modernClient) {
throw new Error('Modern AI SDK client not initialized')
}
try {
// 合并传入的配置和实例配置
const finalConfig: AiSdkMiddlewareConfig = {
...middlewareConfig,
provider: this.provider,
// 工具相关信息从 params 中获取
enableTool: params.tools !== undefined && Array.isArray(params.tools) && params.tools.length > 0
}
// 动态构建中间件数组
const middlewares = buildAiSdkMiddlewares(finalConfig)
console.log(
'构建的中间件:',
middlewares.map((m) => m.name)
)
// 创建带有中间件的客户端
const config = providerToAiSdkConfig(this.provider)
const clientWithMiddlewares = createClient(config.providerId, config.options, middlewares)
if (middlewareConfig.onChunk) {
// 流式处理 - 使用适配器
const adapter = new AiSdkToChunkAdapter(middlewareConfig.onChunk)
const streamResult = await clientWithMiddlewares.streamText(modelId, {
...params,
experimental_transform: smoothStream({
delayInMs: 80,
// 中文3个字符一个chunk,英文一个单词一个chunk
chunking: /([\u4E00-\u9FFF]{3})|\S+\s+/
})
})
const finalText = await adapter.processStream(streamResult)
return {
getText: () => finalText
}
} else {
// 流式处理但没有 onChunk 回调
const streamResult = await clientWithMiddlewares.streamText(modelId, params)
const finalText = await streamResult.text
return {
getText: () => finalText
}
}
} catch (error) {
console.error('Modern AI SDK error:', error)
throw error
}
}
// 代理其他方法到原有实现
public async models() {
return this.legacyProvider.models()
}
public async getEmbeddingDimensions(model: Model): Promise<number> {
return this.legacyProvider.getEmbeddingDimensions(model)
}
public async generateImage(params: GenerateImageParams): Promise<string[]> {
return this.legacyProvider.generateImage(params)
}
public getBaseURL(): string {
return this.legacyProvider.getBaseURL()
}
public getApiKey(): string {
return this.legacyProvider.getApiKey()
}
}
// 为了方便调试,导出一些工具函数
export { isModernSdkSupported, providerToAiSdkConfig }

View File

@@ -1,188 +0,0 @@
import { AiPlugin, simulateStreamingMiddleware } from '@cherrystudio/ai-core'
import { isReasoningModel } from '@renderer/config/models'
import type { Model, Provider } from '@renderer/types'
import type { Chunk } from '@renderer/types/chunk'
import thinkingTimeMiddleware from './ThinkingTimeMiddleware'
/**
* AI SDK 中间件配置项
*/
export interface AiSdkMiddlewareConfig {
streamOutput?: boolean
onChunk?: (chunk: Chunk) => void
model?: Model
provider?: Provider
enableReasoning?: boolean
enableTool?: boolean
enableWebSearch?: boolean
}
/**
* 具名的 AI SDK 中间件
*/
export type NamedAiSdkMiddleware = AiPlugin
/**
* AI SDK 中间件建造者
* 用于根据不同条件动态构建中间件数组
*/
export class AiSdkMiddlewareBuilder {
private middlewares: NamedAiSdkMiddleware[] = []
/**
* 添加具名中间件
*/
public add(namedMiddleware: NamedAiSdkMiddleware): this {
this.middlewares.push(namedMiddleware)
return this
}
/**
* 在指定位置插入中间件
*/
public insertAfter(targetName: string, middleware: NamedAiSdkMiddleware): this {
const index = this.middlewares.findIndex((m) => m.name === targetName)
if (index !== -1) {
this.middlewares.splice(index + 1, 0, middleware)
} else {
console.warn(`AiSdkMiddlewareBuilder: 未找到名为 '${targetName}' 的中间件,无法插入`)
}
return this
}
/**
* 检查是否包含指定名称的中间件
*/
public has(name: string): boolean {
return this.middlewares.some((m) => m.name === name)
}
/**
* 移除指定名称的中间件
*/
public remove(name: string): this {
this.middlewares = this.middlewares.filter((m) => m.name !== name)
return this
}
/**
* 构建最终的中间件数组
*/
public build(): NamedAiSdkMiddleware[] {
return [...this.middlewares]
}
/**
* 清空所有中间件
*/
public clear(): this {
this.middlewares = []
return this
}
/**
* 获取中间件总数
*/
public get length(): number {
return this.middlewares.length
}
}
/**
* 根据配置构建AI SDK中间件的工厂函数
* 这里要注意构建顺序,因为有些中间件需要依赖其他中间件的结果
*/
export function buildAiSdkMiddlewares(config: AiSdkMiddlewareConfig): NamedAiSdkMiddleware[] {
const builder = new AiSdkMiddlewareBuilder()
// 1. 思考模型且有onChunk回调时添加思考时间中间件
if (config.onChunk && config.model && isReasoningModel(config.model)) {
builder.add({
name: 'thinking-time',
aiSdkMiddlewares: [thinkingTimeMiddleware()]
})
}
// 2. 可以在这里根据其他条件添加更多中间件
// 例如工具调用、Web搜索等相关中间件
// 3. 根据provider添加特定中间件
if (config.provider) {
addProviderSpecificMiddlewares(builder, config)
}
// 4. 根据模型类型添加特定中间件
if (config.model) {
addModelSpecificMiddlewares(builder, config)
}
// 5. 非流式输出时添加模拟流中间件
if (config.streamOutput === false) {
builder.add({
name: 'simulate-streaming',
aiSdkMiddlewares: [simulateStreamingMiddleware()]
})
}
return builder.build()
}
/**
* 添加provider特定的中间件
*/
function addProviderSpecificMiddlewares(builder: AiSdkMiddlewareBuilder, config: AiSdkMiddlewareConfig): void {
if (!config.provider) return
// 根据不同provider添加特定中间件
switch (config.provider.type) {
case 'anthropic':
// Anthropic特定中间件
break
case 'openai':
// OpenAI特定中间件
break
case 'gemini':
// Gemini特定中间件
break
default:
// 其他provider的通用处理
break
}
}
/**
* 添加模型特定的中间件
*/
function addModelSpecificMiddlewares(builder: AiSdkMiddlewareBuilder, config: AiSdkMiddlewareConfig): void {
if (!config.model) return
// 可以根据模型ID或特性添加特定中间件
// 例如:图像生成模型、多模态模型等
// 示例:某些模型需要特殊处理
if (config.model.id.includes('dalle') || config.model.id.includes('midjourney')) {
// 图像生成相关中间件
}
}
/**
* 创建一个预配置的中间件建造者
*/
export function createAiSdkMiddlewareBuilder(): AiSdkMiddlewareBuilder {
return new AiSdkMiddlewareBuilder()
}
/**
* 创建一个带有默认中间件的建造者
*/
export function createDefaultAiSdkMiddlewareBuilder(config: AiSdkMiddlewareConfig): AiSdkMiddlewareBuilder {
const builder = new AiSdkMiddlewareBuilder()
const defaultMiddlewares = buildAiSdkMiddlewares(config)
defaultMiddlewares.forEach((middleware) => {
builder.add(middleware)
})
return builder
}

View File

@@ -1,140 +0,0 @@
# AI SDK 中间件建造者
## 概述
`AiSdkMiddlewareBuilder` 是一个用于动态构建 AI SDK 中间件数组的建造者模式实现。它可以根据不同的条件如流式输出、思考模型、provider类型等自动构建合适的中间件组合。
## 使用方式
### 基本用法
```typescript
import { buildAiSdkMiddlewares, type AiSdkMiddlewareConfig } from './AiSdkMiddlewareBuilder'
// 配置中间件参数
const config: AiSdkMiddlewareConfig = {
streamOutput: false, // 非流式输出
onChunk: chunkHandler, // chunk回调函数
model: currentModel, // 当前模型
provider: currentProvider, // 当前provider
enableReasoning: true, // 启用推理
enableTool: false, // 禁用工具
enableWebSearch: false // 禁用网页搜索
}
// 构建中间件数组
const middlewares = buildAiSdkMiddlewares(config)
// 创建带有中间件的客户端
const client = createClient(providerId, options, middlewares)
```
### 手动构建
```typescript
import { AiSdkMiddlewareBuilder, createAiSdkMiddlewareBuilder } from './AiSdkMiddlewareBuilder'
const builder = createAiSdkMiddlewareBuilder()
// 添加特定中间件
builder.add({
name: 'custom-middleware',
aiSdkMiddlewares: [customMiddleware()]
})
// 检查是否包含某个中间件
if (builder.has('thinking-time')) {
console.log('已包含思考时间中间件')
}
// 移除不需要的中间件
builder.remove('simulate-streaming')
// 构建最终数组
const middlewares = builder.build()
```
## 支持的条件
### 1. 流式输出控制
- **streamOutput = false**: 自动添加 `simulateStreamingMiddleware`
- **streamOutput = true**: 使用原生流式处理
### 2. 思考模型处理
- **条件**: `onChunk` 存在 && `isReasoningModel(model)` 为 true
- **效果**: 自动添加 `thinkingTimeMiddleware`
### 3. Provider 特定中间件
根据不同的 provider 类型添加特定中间件:
- **anthropic**: Anthropic 特定处理
- **openai**: OpenAI 特定处理
- **gemini**: Gemini 特定处理
### 4. 模型特定中间件
根据模型特性添加中间件:
- **图像生成模型**: 添加图像处理相关中间件
- **多模态模型**: 添加多模态处理中间件
## 扩展指南
### 添加新的条件判断
`buildAiSdkMiddlewares` 函数中添加新的条件:
```typescript
// 例如:添加缓存中间件
if (config.enableCache) {
builder.add({
name: 'cache',
aiSdkMiddlewares: [cacheMiddleware(config.cacheOptions)]
})
}
```
### 添加 Provider 特定处理
`addProviderSpecificMiddlewares` 函数中添加:
```typescript
case 'custom-provider':
builder.add({
name: 'custom-provider-middleware',
aiSdkMiddlewares: [customProviderMiddleware()]
})
break
```
### 添加模型特定处理
`addModelSpecificMiddlewares` 函数中添加:
```typescript
if (config.model.id.includes('custom-model')) {
builder.add({
name: 'custom-model-middleware',
aiSdkMiddlewares: [customModelMiddleware()]
})
}
```
## 中间件执行顺序
中间件按照添加顺序执行:
1. **simulate-streaming** (如果 streamOutput = false)
2. **thinking-time** (如果是思考模型且有 onChunk)
3. **provider-specific** (根据 provider 类型)
4. **model-specific** (根据模型类型)
## 注意事项
1. 中间件的执行顺序很重要,确保按正确顺序添加
2. 避免添加冲突的中间件
3. 某些中间件可能有依赖关系,需要确保依赖的中间件先添加
4. 建议在开发环境下启用日志,以便调试中间件构建过程

View File

@@ -1,67 +0,0 @@
import { LanguageModelV1Middleware, LanguageModelV1StreamPart } from '@cherrystudio/ai-core'
import { ChunkType, ThinkingCompleteChunk } from '@renderer/types/chunk'
/**
* 一个用于统计 LLM "思考时间"Time to First Token的 AI SDK 中间件。
*
* 工作原理:
* 1. 在 `stream` 方法被调用时,记录一个起始时间。
* 2. 它会创建一个新的 `TransformStream` 来代理原始的流。
* 3. 当第一个数据块 (chunk) 从原始流中到达时,记录结束时间。
* 4. 计算两者之差,即为 "思考时间"
* 这里只处理了thinking_complete
*/
export default function thinkingTimeMiddleware(): LanguageModelV1Middleware {
return {
wrapStream: async ({ doStream }) => {
let hasThinkingContent = false
let thinkingStartTime = 0
let accumulatedThinkingContent = ''
const { stream, ...reset } = await doStream()
const transformStream = new TransformStream<LanguageModelV1StreamPart, any>({
transform(chunk, controller) {
if (chunk.type === 'reasoning' || chunk.type === 'redacted-reasoning') {
if (!hasThinkingContent) {
hasThinkingContent = true
thinkingStartTime = Date.now()
}
accumulatedThinkingContent += chunk.textDelta || ''
} else {
if (hasThinkingContent && thinkingStartTime > 0) {
const thinkingTime = Date.now() - thinkingStartTime
const thinkingCompleteChunk = {
type: 'reasoning-signature',
text: accumulatedThinkingContent,
thinking_millsec: thinkingTime
}
controller.enqueue(thinkingCompleteChunk)
hasThinkingContent = false
thinkingStartTime = 0
accumulatedThinkingContent = ''
}
}
// 将所有 chunk 原样传递下去
controller.enqueue(chunk)
},
flush(controller) {
// 如果流的末尾都是 reasoning也需要发送 complete 事件
if (hasThinkingContent && thinkingStartTime > 0) {
const thinkingTime = Date.now() - thinkingStartTime
const thinkingCompleteChunk: ThinkingCompleteChunk = {
type: ChunkType.THINKING_COMPLETE,
text: accumulatedThinkingContent,
thinking_millsec: thinkingTime
}
controller.enqueue(thinkingCompleteChunk)
}
controller.terminate()
}
})
return {
stream: stream.pipeThrough(transformStream),
...reset
}
}
}
}

View File

@@ -1,4 +1,5 @@
import { Chunk } from '@renderer/types/chunk'
import { isAbortError } from '@renderer/utils/error'
import { CompletionsResult } from '../schemas'
import { CompletionsContext } from '../types'
@@ -25,26 +26,29 @@ export const ErrorHandlerMiddleware =
// 尝试执行下一个中间件
return await next(ctx, params)
} catch (error: any) {
console.log('ErrorHandlerMiddleware_error', error)
// 1. 使用通用的工具函数将错误解析为标准格式
const errorChunk = createErrorChunk(error)
// 2. 调用从外部传入的 onError 回调
if (params.onError) {
params.onError(error)
}
// 3. 根据配置决定是重新抛出错误,还是将其作为流的一部分向下传递
if (shouldThrow) {
throw error
}
// 如果不抛出,则创建一个只包含该错误块的流并向下传递
const errorStream = new ReadableStream<Chunk>({
start(controller) {
controller.enqueue(errorChunk)
controller.close()
let errorStream: ReadableStream<Chunk> | undefined
// 有些sdk的abort error 是直接抛出的
if (!isAbortError(error)) {
// 1. 使用通用的工具函数将错误解析为标准格式
const errorChunk = createErrorChunk(error)
// 2. 调用从外部传入的 onError 回调
if (params.onError) {
params.onError(error)
}
})
// 3. 根据配置决定是重新抛出错误,还是将其作为流的一部分向下传递
if (shouldThrow) {
throw error
}
// 如果不抛出,则创建一个只包含该错误块的流并向下传递
errorStream = new ReadableStream<Chunk>({
start(controller) {
controller.enqueue(errorChunk)
controller.close()
}
})
}
return {
rawOutput: undefined,

View File

@@ -153,7 +153,7 @@ function createToolHandlingTransform(
if (toolResult.length > 0) {
const output = ctx._internal.toolProcessingState?.output
const newParams = buildParamsWithToolResults(ctx, currentParams, output, toolResult, toolCalls)
const newParams = buildParamsWithToolResults(ctx, currentParams, output!, toolResult, toolCalls)
await executeWithToolHandling(newParams, depth + 1)
}
} catch (error) {
@@ -243,7 +243,7 @@ async function executeToolUseResponses(
function buildParamsWithToolResults(
ctx: CompletionsContext,
currentParams: CompletionsParams,
output: SdkRawOutput | string | undefined,
output: SdkRawOutput | string,
toolResults: SdkMessageParam[],
toolCalls: SdkToolCall[]
): CompletionsParams {

View File

@@ -15,6 +15,8 @@ export const RawStreamListenerMiddleware: CompletionsMiddleware =
// 在这里可以监听到从SDK返回的最原始流
if (result.rawOutput) {
console.log(`[${MIDDLEWARE_NAME}] 检测到原始SDK输出准备附加监听器`)
const providerType = ctx.apiClientInstance.provider.type
// TODO: 后面下放到AnthropicAPIClient
if (providerType === 'anthropic') {

View File

@@ -37,7 +37,7 @@ export const ResponseTransformMiddleware: CompletionsMiddleware =
}
// 获取响应转换器
const responseChunkTransformer = apiClient.getResponseChunkTransformer(ctx)
const responseChunkTransformer = apiClient.getResponseChunkTransformer?.()
if (!responseChunkTransformer) {
Logger.warn(`[${MIDDLEWARE_NAME}] No ResponseChunkTransformer available, skipping transformation`)
return result

View File

@@ -25,6 +25,7 @@ export const StreamAdapterMiddleware: CompletionsMiddleware =
// 但是这个中间件的职责是流适配,是否在这调用优待商榷
// 调用下游中间件
const result = await next(ctx, params)
if (
result.rawOutput &&
!(result.rawOutput instanceof ReadableStream) &&

View File

@@ -14,6 +14,8 @@ export const TransformCoreToSdkParamsMiddleware: CompletionsMiddleware =
() =>
(next) =>
async (ctx: CompletionsContext, params: CompletionsParams): Promise<CompletionsResult> => {
Logger.debug(`🔄 [${MIDDLEWARE_NAME}] Starting core to SDK params transformation:`, ctx)
const internal = ctx._internal
// 🔧 检测递归调用:检查 params 中是否携带了预处理的 SDK 消息

View File

@@ -17,6 +17,7 @@ export const ImageGenerationMiddleware: CompletionsMiddleware =
const { assistant, messages } = params
const client = context.apiClientInstance as BaseApiClient<OpenAI>
const signal = context._internal?.flowControl?.abortSignal
if (!assistant.model || !isDedicatedImageGenerationModel(assistant.model) || typeof messages === 'string') {
return next(context, params)
}

View File

@@ -1,110 +0,0 @@
/**
* Cherry Studio 参数转换插件
* 专门处理 Cherry Studio 特有的消息格式、文件处理、Assistant 设置等
*/
import { definePlugin } from '@cherrystudio/ai-core'
import type { Assistant, MCPTool, Message, Model } from '@renderer/types'
import {
buildStreamTextParams,
convertMessagesToSdkMessages,
getCustomParameters,
getTemperature,
getTopP
} from '../transformParameters'
/**
* Cherry Studio 核心转换插件
* 负责将 Cherry Studio 的数据结构转换为 AI SDK 兼容格式
*/
export const cherryStudioTransformPlugin = definePlugin({
name: 'cherry-studio-transform',
/**
* 转换请求参数
* 将 Cherry Studio 的 Assistant + Messages 转换为 AI SDK 格式
*/
transformParams: async (params: any, context) => {
// 检查是否有 Cherry Studio 特有的数据结构
const cherryData = context.metadata?.cherryStudio
if (!cherryData) {
return params // 不是 Cherry Studio 调用,直接返回
}
const { assistant, messages, mcpTools, enableTools } = cherryData
try {
// 1. 转换 Cherry Studio 消息为 AI SDK 消息
const sdkMessages = await convertMessagesToSdkMessages(messages as Message[], assistant.model as Model)
// 2. 构建完整的 AI SDK 参数
const { params: transformedParams } = await buildStreamTextParams(sdkMessages, assistant as Assistant, {
mcpTools: mcpTools as MCPTool[],
enableTools,
requestOptions: {
signal: params.abortSignal,
headers: params.headers
}
})
// 3. 合并原始参数和转换后的参数
return {
...params,
...transformedParams,
// 保留原始的一些关键参数
abortSignal: params.abortSignal,
headers: params.headers
}
} catch (error) {
console.error('Cherry Studio 参数转换失败:', error)
return params // 转换失败时返回原始参数
}
}
})
/**
* Cherry Studio Assistant 设置插件
* 专门处理 Assistant 的温度、TopP、自定义参数等设置
*/
export const cherryStudioSettingsPlugin = definePlugin({
name: 'cherry-studio-settings',
transformParams: async (params: any, context) => {
const cherryData = context.metadata?.cherryStudio
if (!cherryData?.assistant) {
return params
}
const { assistant } = cherryData
const model = assistant.model as Model
return {
...params,
temperature: getTemperature(assistant as Assistant, model),
topP: getTopP(assistant as Assistant, model),
...getCustomParameters(assistant as Assistant)
}
}
})
/**
* 便捷函数:为 Cherry Studio 调用准备上下文元数据
*/
export function createCherryStudioContext(
assistant: Assistant,
messages: Message[],
options: {
mcpTools?: MCPTool[]
enableTools?: boolean
} = {}
) {
return {
cherryStudio: {
assistant,
messages,
mcpTools: options.mcpTools,
enableTools: options.enableTools
}
}
}

View File

@@ -1,311 +0,0 @@
/**
* AI SDK 参数转换模块
* 统一管理从各个 apiClient 提取的参数处理和转换功能
*/
import type { CoreMessage, StreamTextParams } from '@cherrystudio/ai-core'
import {
isGenerateImageModel,
isNotSupportTemperatureAndTopP,
isOpenRouterBuiltInWebSearchModel,
isReasoningModel,
isSupportedDisableGenerationModel,
isSupportedFlexServiceTier,
isSupportedReasoningEffortModel,
isSupportedThinkingTokenModel,
isWebSearchModel
} from '@renderer/config/models'
import { getAssistantSettings, getDefaultModel } from '@renderer/services/AssistantService'
import type { Assistant, MCPTool, Message, Model } from '@renderer/types'
import { FileTypes } from '@renderer/types'
import { findFileBlocks, findImageBlocks, getMainTextContent } from '@renderer/utils/messageUtils/find'
import { buildSystemPrompt } from '@renderer/utils/prompt'
import { defaultTimeout } from '@shared/config/constant'
/**
* 获取温度参数
*/
export function getTemperature(assistant: Assistant, model: Model): number | undefined {
return isNotSupportTemperatureAndTopP(model) ? undefined : assistant.settings?.temperature
}
/**
* 获取 TopP 参数
*/
export function getTopP(assistant: Assistant, model: Model): number | undefined {
return isNotSupportTemperatureAndTopP(model) ? undefined : assistant.settings?.topP
}
/**
* 获取超时设置
*/
export function getTimeout(model: Model): number {
if (isSupportedFlexServiceTier(model)) {
return 15 * 1000 * 60
}
return defaultTimeout
}
/**
* 构建系统提示词
*/
export async function buildSystemPromptWithTools(
prompt: string,
mcpTools?: MCPTool[],
assistant?: Assistant
): Promise<string> {
return await buildSystemPrompt(prompt, mcpTools, assistant)
}
// /**
// * 转换 MCP 工具为 AI SDK 工具格式
// * 注意:这里返回通用格式,实际使用时需要根据具体 provider 转换
// TODO: 需要使用ai-sdk的mcp
// */
// export function convertMcpToolsToSdkTools(mcpTools: MCPTool[]): Pick<StreamTextParams, 'tools'> {
// return mcpTools.map((tool) => ({
// type: 'function',
// function: {
// name: tool.id,
// description: tool.description,
// parameters: tool.inputSchema || {}
// }
// }))
// }
/**
* 提取文件内容
*/
export async function extractFileContent(message: Message): Promise<string> {
const fileBlocks = findFileBlocks(message)
if (fileBlocks.length > 0) {
const textFileBlocks = fileBlocks.filter(
(fb) => fb.file && [FileTypes.TEXT, FileTypes.DOCUMENT].includes(fb.file.type)
)
if (textFileBlocks.length > 0) {
let text = ''
const divider = '\n\n---\n\n'
for (const fileBlock of textFileBlocks) {
const file = fileBlock.file
const fileContent = (await window.api.file.read(file.id + file.ext)).trim()
const fileNameRow = 'file: ' + file.origin_name + '\n\n'
text = text + fileNameRow + fileContent + divider
}
return text
}
}
return ''
}
/**
* 转换消息为 AI SDK 参数格式
* 基于 OpenAI 格式的通用转换,支持文本、图片和文件
*/
export async function convertMessageToSdkParam(message: Message, isVisionModel = false): Promise<any> {
const content = getMainTextContent(message)
const fileBlocks = findFileBlocks(message)
const imageBlocks = findImageBlocks(message)
// 简单消息(无文件无图片)
if (fileBlocks.length === 0 && imageBlocks.length === 0) {
return {
role: message.role === 'system' ? 'user' : message.role,
content
}
}
// 复杂消息(包含文件或图片)
const parts: any[] = []
if (content) {
parts.push({ type: 'text', text: content })
}
// 处理图片(仅在支持视觉的模型中)
if (isVisionModel) {
for (const imageBlock of imageBlocks) {
if (imageBlock.file) {
try {
const image = await window.api.file.base64Image(imageBlock.file.id + imageBlock.file.ext)
parts.push({
type: 'image_url',
image_url: { url: image.data }
})
} catch (error) {
console.warn('Failed to load image:', error)
}
} else if (imageBlock.url && imageBlock.url.startsWith('data:')) {
parts.push({
type: 'image_url',
image_url: { url: imageBlock.url }
})
}
}
}
// 处理文件
for (const fileBlock of fileBlocks) {
const file = fileBlock.file
if (!file) continue
if ([FileTypes.TEXT, FileTypes.DOCUMENT].includes(file.type)) {
try {
const fileContent = await window.api.file.read(file.id + file.ext)
parts.push({
type: 'text',
text: `${file.origin_name}\n${fileContent.trim()}`
})
} catch (error) {
console.warn('Failed to read file:', error)
}
}
}
return {
role: message.role === 'system' ? 'user' : message.role,
content: parts.length === 1 && parts[0].type === 'text' ? parts[0].text : parts
}
}
/**
* 转换 Cherry Studio 消息数组为 AI SDK 消息数组
*/
export async function convertMessagesToSdkMessages(
messages: Message[],
model: Model
): Promise<StreamTextParams['messages']> {
const sdkMessages: StreamTextParams['messages'] = []
const isVision = model.id.includes('vision') || model.id.includes('gpt-4') // 简单的视觉模型检测
for (const message of messages) {
const sdkMessage = await convertMessageToSdkParam(message, isVision)
sdkMessages.push(sdkMessage)
}
return sdkMessages
}
/**
* 构建 AI SDK 流式参数
* 这是主要的参数构建函数,整合所有转换逻辑
*/
export async function buildStreamTextParams(
sdkMessages: StreamTextParams['messages'],
assistant: Assistant,
options: {
mcpTools?: MCPTool[]
enableTools?: boolean
requestOptions?: {
signal?: AbortSignal
timeout?: number
headers?: Record<string, string>
}
} = {}
): Promise<{ params: StreamTextParams; modelId: string }> {
const { mcpTools, enableTools = false } = options
const model = assistant.model || getDefaultModel()
const { maxTokens, reasoning_effort } = getAssistantSettings(assistant)
const enableReasoning =
((isSupportedThinkingTokenModel(model) || isSupportedReasoningEffortModel(model)) &&
reasoning_effort !== undefined) ||
(isReasoningModel(model) && (!isSupportedThinkingTokenModel(model) || !isSupportedReasoningEffortModel(model)))
const enableWebSearch =
(assistant.enableWebSearch && isWebSearchModel(model)) ||
isOpenRouterBuiltInWebSearchModel(model) ||
model.id.includes('sonar') ||
false
const enableGenerateImage =
isGenerateImageModel(model) &&
(isSupportedDisableGenerationModel(model) ? assistant.enableGenerateImage || false : true)
// 构建系统提示
let systemPrompt = assistant.prompt || ''
if (mcpTools && mcpTools.length > 0) {
systemPrompt = await buildSystemPromptWithTools(systemPrompt, mcpTools, assistant)
}
// 构建基础参数
const params: StreamTextParams = {
messages: sdkMessages,
maxTokens: maxTokens || 1000,
temperature: getTemperature(assistant, model),
topP: getTopP(assistant, model),
system: systemPrompt || undefined,
abortSignal: options.requestOptions?.signal,
headers: options.requestOptions?.headers,
// 随便填着,后面再改
providerOptions: {
reasoning: {
enabled: enableReasoning
},
webSearch: {
enabled: enableWebSearch
},
generateImage: {
enabled: enableGenerateImage
}
},
...getCustomParameters(assistant)
}
// 添加工具(如果启用且有工具)
if (enableTools && mcpTools && mcpTools.length > 0) {
// TODO: 暂时注释掉工具支持,等类型问题解决后再启用
// params.tools = convertMcpToolsToSdkTools(mcpTools)
}
return { params, modelId: model.id }
}
/**
* 构建非流式的 generateText 参数
*/
export async function buildGenerateTextParams(
messages: CoreMessage[],
assistant: Assistant,
options: {
mcpTools?: MCPTool[]
enableTools?: boolean
} = {}
): Promise<any> {
// 复用流式参数的构建逻辑
return await buildStreamTextParams(messages, assistant, options)
}
/**
* 获取自定义参数
* 从 assistant 设置中提取自定义参数
*/
export function getCustomParameters(assistant: Assistant): Record<string, any> {
return (
assistant?.settings?.customParameters?.reduce((acc, param) => {
if (!param.name?.trim()) {
return acc
}
if (param.type === 'json') {
const value = param.value as string
if (value === 'undefined') {
return { ...acc, [param.name]: undefined }
}
try {
return { ...acc, [param.name]: JSON.parse(value) }
} catch {
return { ...acc, [param.name]: value }
}
}
return {
...acc,
[param.name]: param.value
}
}, {}) || {}
)
}

View File

@@ -1,8 +1 @@
<svg width="22" height="22" viewBox="13 -2 25 22" xmlns="http://www.w3.org/2000/svg">
<g id="White=False">
<g id="if">
<path d="M21.2002 3.73454C22.5633 3.73454 23.0666 2.89917 23.0666 1.86812C23.0666 0.837081 22.5623 0.00170898 21.2002 0.00170898C19.838 0.00170898 19.3337 0.837081 19.3337 1.86812C19.3337 2.89917 19.838 3.73454 21.2002 3.73454Z" fill="#0033FF"/>
<path d="M27.7336 4.13435V5.33473H24.6668V8.00171H27.7336V14.6687H22.6668V5.33567H15.9998V8.00265H19.7336V14.6696H15.3337V17.3366H35.3337V14.6696H30.6668V8.00265H35.3337V5.33567H30.6668V2.66869H35.3337V0.00170898H31.8671C29.5877 0.00170898 27.7336 1.8559 27.7336 4.13529V4.13435Z" fill="#0033FF"/>
</g>
</g>
</svg>
<svg fill="currentColor" fill-rule="evenodd" height="1em" style="flex:none;line-height:1" viewBox="0 0 24 24" width="1em" xmlns="http://www.w3.org/2000/svg"><title>Dify</title><clipPath id="lobe-icons-dify-fill"><path d="M1 0h10.286c6.627 0 12 5.373 12 12s-5.373 12-12 12H1V0z"></path></clipPath><foreignObject clip-path="url(#lobe-icons-dify-fill)" height="24" style="background:conic-gradient(from 180deg at 50% 50%, #0222C3, #8FB1F4, #FFFFFF)" width="24"></foreignObject></svg>

Before

Width:  |  Height:  |  Size: 680 B

After

Width:  |  Height:  |  Size: 480 B

View File

@@ -197,26 +197,11 @@
}
}
.ant-dropdown {
.ant-dropdown-menu {
max-height: 50vh;
overflow-y: auto;
border: 0.5px solid var(--color-border);
.ant-dropdown-menu-sub {
max-height: 50vh;
width: max-content;
overflow-y: auto;
overflow-x: hidden;
border: 0.5px solid var(--color-border);
}
}
.ant-dropdown-arrow + .ant-dropdown-menu {
border: none;
}
}
.ant-select-dropdown {
border: 0.5px solid var(--color-border);
.ant-dropdown-menu .ant-dropdown-menu-sub {
max-height: 350px;
width: max-content;
overflow-y: auto;
overflow-x: hidden;
}
.ant-collapse {

View File

@@ -4,7 +4,7 @@ import { CodeTool, CodeToolbar, TOOL_SPECS, useCodeTool } from '@renderer/compon
import { useSettings } from '@renderer/hooks/useSettings'
import { pyodideService } from '@renderer/services/PyodideService'
import { extractTitle } from '@renderer/utils/formats'
import { getExtensionByLanguage, isValidPlantUML } from '@renderer/utils/markdown'
import { isValidPlantUML } from '@renderer/utils/markdown'
import dayjs from 'dayjs'
import { CirclePlay, CodeXml, Copy, Download, Eye, Square, SquarePen, SquareSplitHorizontal } from 'lucide-react'
import React, { memo, useCallback, useEffect, useMemo, useState } from 'react'
@@ -67,21 +67,23 @@ const CodeBlockView: React.FC<Props> = ({ children, language, onSave }) => {
window.message.success({ content: t('code_block.copy.success'), key: 'copy-code' })
}, [children, t])
const handleDownloadSource = useCallback(async () => {
const handleDownloadSource = useCallback(() => {
let fileName = ''
// 尝试提取 HTML 标题
// 尝试提取标题
if (language === 'html' && children.includes('</html>')) {
fileName = extractTitle(children) || ''
const title = extractTitle(children)
if (title) {
fileName = `${title}.html`
}
}
// 默认使用日期格式命名
if (!fileName) {
fileName = `${dayjs().format('YYYYMMDDHHmm')}`
fileName = `${dayjs().format('YYYYMMDDHHmm')}.${language}`
}
const ext = await getExtensionByLanguage(language)
window.api.file.save(`${fileName}${ext}`, children)
window.api.file.save(fileName, children)
}, [children, language])
const handleRunScript = useCallback(() => {

View File

@@ -41,10 +41,11 @@ const MarkdownEditor: FC<MarkdownEditorProps> = ({
return (
<EditorContainer style={{ height }}>
<InputArea value={inputValue} onChange={handleChange} placeholder={placeholder} autoFocus={autoFocus} />
<PreviewArea className="markdown">
<PreviewArea>
<ReactMarkdown
remarkPlugins={[remarkGfm, remarkCjkFriendly, remarkMath]}
rehypePlugins={[rehypeRaw, rehypeKatex]}>
rehypePlugins={[rehypeRaw, rehypeKatex]}
className="markdown">
{inputValue || t('settings.provider.notes.markdown_editor_default_value')}
</ReactMarkdown>
</PreviewArea>

View File

@@ -1,9 +1,16 @@
import { backupToWebdav } from '@renderer/services/BackupService'
import { Input, Modal } from 'antd'
import { backupToWebdav, restoreFromWebdav } from '@renderer/services/BackupService'
import { formatFileSize } from '@renderer/utils'
import { Input, Modal, Select, Spin } from 'antd'
import dayjs from 'dayjs'
import { useCallback, useState } from 'react'
import { useTranslation } from 'react-i18next'
interface BackupFile {
fileName: string
modifiedTime: string
size: number
}
interface WebdavModalProps {
isModalVisible: boolean
handleBackup: () => void
@@ -80,3 +87,156 @@ export function WebdavBackupModal({
</Modal>
)
}
interface WebdavRestoreModalProps {
isRestoreModalVisible: boolean
handleRestore: () => void
handleCancel: () => void
restoring: boolean
selectedFile: string | null
setSelectedFile: (value: string | null) => void
loadingFiles: boolean
backupFiles: BackupFile[]
}
interface UseWebdavRestoreModalProps {
webdavHost: string | undefined
webdavUser: string | undefined
webdavPass: string | undefined
webdavPath: string | undefined
restoreMethod?: typeof restoreFromWebdav
}
export function useWebdavRestoreModal({
webdavHost,
webdavUser,
webdavPass,
webdavPath,
restoreMethod
}: UseWebdavRestoreModalProps) {
const { t } = useTranslation()
const [isRestoreModalVisible, setIsRestoreModalVisible] = useState(false)
const [restoring, setRestoring] = useState(false)
const [selectedFile, setSelectedFile] = useState<string | null>(null)
const [loadingFiles, setLoadingFiles] = useState(false)
const [backupFiles, setBackupFiles] = useState<BackupFile[]>([])
const showRestoreModal = useCallback(async () => {
if (!webdavHost) {
window.message.error({ content: t('message.error.invalid.webdav'), key: 'webdav-error' })
return
}
setIsRestoreModalVisible(true)
setLoadingFiles(true)
try {
const files = await window.api.backup.listWebdavFiles({
webdavHost,
webdavUser,
webdavPass,
webdavPath
})
setBackupFiles(files)
} catch (error: any) {
window.message.error({ content: error.message, key: 'list-files-error' })
} finally {
setLoadingFiles(false)
}
}, [webdavHost, webdavUser, webdavPass, webdavPath, t])
const handleRestore = useCallback(async () => {
if (!selectedFile || !webdavHost) {
window.message.error({
content: !selectedFile ? t('message.error.no.file.selected') : t('message.error.invalid.webdav'),
key: 'restore-error'
})
return
}
window.modal.confirm({
title: t('settings.data.webdav.restore.confirm.title'),
content: t('settings.data.webdav.restore.confirm.content'),
centered: true,
onOk: async () => {
setRestoring(true)
try {
await (restoreMethod ?? restoreFromWebdav)(selectedFile)
setIsRestoreModalVisible(false)
} catch (error: any) {
window.message.error({ content: error.message, key: 'restore-error' })
} finally {
setRestoring(false)
}
}
})
}, [selectedFile, webdavHost, t, restoreMethod])
const handleCancel = () => {
setIsRestoreModalVisible(false)
}
return {
isRestoreModalVisible,
handleRestore,
handleCancel,
restoring,
selectedFile,
setSelectedFile,
loadingFiles,
backupFiles,
showRestoreModal
}
}
export function WebdavRestoreModal({
isRestoreModalVisible,
handleRestore,
handleCancel,
restoring,
selectedFile,
setSelectedFile,
loadingFiles,
backupFiles
}: WebdavRestoreModalProps) {
const { t } = useTranslation()
return (
<Modal
title={t('settings.data.webdav.restore.modal.title')}
open={isRestoreModalVisible}
onOk={handleRestore}
onCancel={handleCancel}
okButtonProps={{ loading: restoring }}
width={600}
transitionName="animation-move-down"
centered>
<div style={{ position: 'relative' }}>
<Select
style={{ width: '100%' }}
placeholder={t('settings.data.webdav.restore.modal.select.placeholder')}
value={selectedFile}
onChange={setSelectedFile}
options={backupFiles.map(formatFileOption)}
loading={loadingFiles}
showSearch
filterOption={(input, option) => (option?.label ?? '').toLowerCase().includes(input.toLowerCase())}
/>
{loadingFiles && (
<div style={{ position: 'absolute', top: '50%', left: '50%', transform: 'translate(-50%, -50%)' }}>
<Spin />
</div>
)}
</div>
</Modal>
)
}
function formatFileOption(file: BackupFile) {
const date = dayjs(file.modifiedTime).format('YYYY-MM-DD HH:mm:ss')
const size = formatFileSize(file.size)
return {
label: `${file.fileName} (${date}, ${size})`,
value: file.fileName
}
}

View File

@@ -184,7 +184,7 @@ const visionAllowedModels = [
'deepseek-vl(?:[\\w-]+)?',
'kimi-latest',
'gemma-3(?:-[\\w-]+)',
'doubao-seed-1[.-]6(?:-[\\w-]+)'
'doubao-1.6-seed(?:-[\\w-]+)'
]
const visionExcludedModels = [
@@ -238,8 +238,7 @@ export const FUNCTION_CALLING_MODELS = [
'glm-4(?:-[\\w-]+)?',
'learnlm(?:-[\\w-]+)?',
'gemini(?:-[\\w-]+)?', // 提前排除了gemini的嵌入模型
'grok-3(?:-[\\w-]+)?',
'doubao-seed-1[.-]6(?:-[\\w-]+)?'
'grok-3(?:-[\\w-]+)?'
]
const FUNCTION_CALLING_EXCLUDED_MODELS = [
@@ -2289,8 +2288,6 @@ export const TEXT_TO_IMAGES_MODELS_SUPPORT_IMAGE_ENHANCEMENT = [
]
export const SUPPORTED_DISABLE_GENERATION_MODELS = [
'gemini-2.0-flash-exp-image-generation',
'gemini-2.0-flash-preview-image-generation',
'gemini-2.0-flash-exp',
'gpt-4o',
'gpt-4o-mini',
@@ -2310,7 +2307,21 @@ export const GENERATE_IMAGE_MODELS = [
...SUPPORTED_DISABLE_GENERATION_MODELS
]
export const GEMINI_SEARCH_REGEX = new RegExp('gemini-2\\..*', 'i')
export const GEMINI_SEARCH_MODELS = [
'gemini-2.0-flash',
'gemini-2.0-flash-lite',
'gemini-2.0-flash-exp',
'gemini-2.0-flash-001',
'gemini-2.0-pro-exp-02-05',
'gemini-2.0-pro-exp',
'gemini-2.5-pro-exp',
'gemini-2.5-pro-exp-03-25',
'gemini-2.5-pro-preview',
'gemini-2.5-pro-preview-03-25',
'gemini-2.5-pro-preview-05-06',
'gemini-2.5-flash-preview',
'gemini-2.5-flash-preview-04-17'
]
export const OPENAI_NO_SUPPORT_DEV_ROLE_MODELS = ['o1-preview', 'o1-mini']
@@ -2354,7 +2365,7 @@ export function isVisionModel(model: Model): boolean {
// }
if (model.provider === 'doubao') {
return VISION_REGEX.test(model.name) || VISION_REGEX.test(model.id) || model.type?.includes('vision') || false
return VISION_REGEX.test(model.name) || model.type?.includes('vision') || false
}
return VISION_REGEX.test(model.id) || model.type?.includes('vision') || false
@@ -2643,13 +2654,13 @@ export function isWebSearchModel(model: Model): boolean {
}
if (provider?.type === 'openai') {
if (GEMINI_SEARCH_REGEX.test(baseName) || isOpenAIWebSearchModel(model)) {
if (GEMINI_SEARCH_MODELS.includes(baseName) || isOpenAIWebSearchModel(model)) {
return true
}
}
if (provider.id === 'gemini' || provider?.type === 'gemini') {
return GEMINI_SEARCH_REGEX.test(baseName)
return GEMINI_SEARCH_MODELS.includes(baseName)
}
if (provider.id === 'hunyuan') {
@@ -2688,7 +2699,7 @@ export function isOpenRouterBuiltInWebSearchModel(model: Model): boolean {
return false
}
return isOpenAIWebSearchChatCompletionOnlyModel(model) || model.id.includes('sonar')
return isOpenAIWebSearchModel(model) || model.id.includes('sonar')
}
export function isGenerateImageModel(model: Model): boolean {
@@ -2826,7 +2837,6 @@ export function groupQwenModels(models: Model[]): Record<string, Model[]> {
export const THINKING_TOKEN_MAP: Record<string, { min: number; max: number }> = {
// Gemini models
'gemini-2\\.5-flash-lite.*$': { min: 512, max: 24576 },
'gemini-.*-flash.*$': { min: 0, max: 24576 },
'gemini-.*-pro.*$': { min: 128, max: 32768 },
@@ -2853,10 +2863,10 @@ export const findTokenLimit = (modelId: string): { min: number; max: number } |
// Doubao 支持思考模式的模型正则
export const DOUBAO_THINKING_MODEL_REGEX =
/doubao-(?:1[.-]5-thinking-vision-pro|1[.-]5-thinking-pro-m|seed-1[.-]6(?:-flash)?)(?:-[\w-]+)?/i
/doubao-(?:1(\.|-5)-thinking-vision-pro|1(\.|-)5-thinking-pro-m|seed-1\.6|seed-1\.6-flash)(?:-[\\w-]+)?/i
// 支持 auto 的 Doubao 模型 doubao-seed-1.6-xxx doubao-seed-1-6-xxx doubao-1-5-thinking-pro-m-xxx
export const DOUBAO_THINKING_AUTO_MODEL_REGEX = /doubao-(1-5-thinking-pro-m|seed-1\.6|seed-1-6-[\w-]+)(?:-[\w-]+)*/i
// 支持 auto 的 Doubao 模型
export const DOUBAO_THINKING_AUTO_MODEL_REGEX = /doubao-(?:1-5-thinking-pro-m|seed-1.6)(?:-[\\w-]+)?/i
export function isDoubaoThinkingAutoModel(model: Model): boolean {
return DOUBAO_THINKING_AUTO_MODEL_REGEX.test(model.id)

View File

@@ -183,7 +183,7 @@
"input.new.context": "Clear Context {{Command}}",
"input.new_topic": "New Topic {{Command}}",
"input.pause": "Pause",
"input.placeholder": "Type your message here, press {{key}} to send...",
"input.placeholder": "Type your message here...",
"input.send": "Send",
"input.settings": "Settings",
"input.topics": " Topics ",
@@ -787,18 +787,6 @@
"string": "Text"
},
"pinned": "Pinned",
"price": {
"cost": "Cost",
"currency": "Currency",
"custom": "Custom",
"custom_currency": "Custom Currency",
"custom_currency_placeholder": "Enter Custom Currency",
"input": "Input Price",
"million_tokens": "M Tokens",
"output": "Output Price",
"price": "Price"
},
"reasoning": "Reasoning",
"rerank_model": "Reranker",
"rerank_model_support_provider": "Currently, the reranker model only supports some providers ({{provider}})",
"rerank_model_not_support_provider": "Currently, the reranker model does not support this provider ({{provider}})",
@@ -1085,24 +1073,6 @@
"assistant.title": "Default Assistant",
"data": {
"app_data": "App Data",
"app_data.select": "Modify Directory",
"app_data.select_title": "Change App Data Directory",
"app_data.restart_notice": "The app will need to restart to apply the changes",
"app_data.copy_data_option": "Copy data from original directory to new directory",
"app_data.copy_time_notice": "Copying data may take a while, do not force quit app",
"app_data.path_changed_without_copy": "Path changed successfully, but data not copied",
"app_data.copying_warning": "Data copying, do not force quit app",
"app_data.copying": "Copying data to new location...",
"app_data.copy_success": "Successfully copied data to new location",
"app_data.copy_failed": "Failed to copy data",
"app_data.select_success": "Data directory changed, the app will restart to apply changes",
"app_data.select_error": "Failed to change data directory",
"app_data.migration_title": "Data Migration",
"app_data.original_path": "Original Path",
"app_data.new_path": "New Path",
"app_data.select_error_root_path": "New path cannot be the root path",
"app_data.select_error_write_permission": "New path does not have write permission",
"app_data.stop_quit_app_reason": "The app is currently migrating data and cannot be exited",
"app_knowledge": "Knowledge Base Files",
"app_knowledge.button.delete": "Delete File",
"app_knowledge.remove_all": "Remove Knowledge Base Files",
@@ -1135,8 +1105,7 @@
"obsidian": "Export to Obsidian",
"siyuan": "Export to SiYuan Note",
"joplin": "Export to Joplin",
"docx": "Export as Word",
"plain_text": "Copy as Plain Text"
"docx": "Export as Word"
},
"joplin": {
"check": {
@@ -1163,7 +1132,7 @@
"markdown_export.select": "Select",
"markdown_export.title": "Markdown Export",
"markdown_export.show_model_name.title": "Use Model Name on Export",
"markdown_export.show_model_name.help": "When enabled, the model name will be displayed when exporting to Markdown. Note: This option also affects all export methods through Markdown, such as Notion, Yuque, etc.",
"markdown_export.show_model_name.help": "When enabled, the topic-naming model will be used to create titles for exported messages. Note: This option also affects all export methods through Markdown, such as Notion, Yuque, etc.",
"markdown_export.show_model_provider.title": "Show Model Provider",
"markdown_export.show_model_provider.help": "Display the model provider (e.g., OpenAI, Gemini) when exporting to Markdown",
"minute_interval_one": "{{count}} minute",
@@ -1227,6 +1196,8 @@
"restore.confirm.content": "Restoring from WebDAV will overwrite current data. Do you want to continue?",
"restore.confirm.title": "Confirm Restore",
"restore.content": "Restore from WebDAV will overwrite the current data, continue?",
"restore.modal.select.placeholder": "Please select a backup file to restore",
"restore.modal.title": "Restore from WebDAV",
"restore.title": "Restore from WebDAV",
"syncError": "Backup Error",
"syncStatus": "Backup Status",
@@ -1914,8 +1885,7 @@
"model_desc": "Model used for translation service",
"bidirectional": "Bidirectional Translation Settings",
"bidirectional_tip": "When enabled, only bidirectional translation between source and target languages is supported",
"scroll_sync": "Scroll Sync Settings",
"preview": "Markdown Preview"
"scroll_sync": "Scroll Sync Settings"
},
"title": "Translation",
"tooltip.newline": "Newline",

View File

@@ -183,7 +183,7 @@
"input.new.context": "コンテキストをクリア {{Command}}",
"input.new_topic": "新しいトピック {{Command}}",
"input.pause": "一時停止",
"input.placeholder": "ここにメッセージを入力し、{{key}} を押して送信...",
"input.placeholder": "ここにメッセージを入力...",
"input.send": "送信",
"input.settings": "設定",
"input.topics": " トピック ",
@@ -804,19 +804,7 @@
"vision": "画像",
"websearch": "ウェブ検索"
},
"rerank_model_not_support_provider": "現在、並べ替えモデルはこのプロバイダー ({{provider}}) をサポートしていません。",
"price": {
"cost": "コスト",
"currency": "通貨",
"custom": "カスタム",
"custom_currency": "カスタム通貨",
"custom_currency_placeholder": "カスタム通貨を入力してください",
"input": "入力価格",
"million_tokens": "百万トークン",
"output": "出力価格",
"price": "価格"
},
"reasoning": "思考"
"rerank_model_not_support_provider": "現在、並べ替えモデルはこのプロバイダー ({{provider}}) をサポートしていません。"
},
"navbar": {
"expand": "ダイアログを展開",
@@ -1083,25 +1071,7 @@
"assistant.title": "デフォルトアシスタント",
"data": {
"app_data": "アプリデータ",
"app_data.select": "ディレクトリを変更",
"app_data.select_title": "アプリデータディレクトリの変更",
"app_data.restart_notice": "変更を適用するには、アプリを再起動する必要があります",
"app_data.copy_data_option": "データをコピーする, 開くと元のディレクトリのデータが新しいディレクトリにコピーされます",
"app_data.copy_time_notice": "データコピーには時間がかかります。アプリを強制終了しないでください",
"app_data.path_changed_without_copy": "パスが変更されましたが、データがコピーされていません",
"app_data.copying_warning": "データコピー中、アプリを強制終了しないでください",
"app_data.copying": "新しい場所にデータをコピーしています...",
"app_data.copy_success": "データを新しい場所に正常にコピーしました",
"app_data.copy_failed": "データのコピーに失敗しました",
"app_data.select_success": "データディレクトリが変更されました。変更を適用するためにアプリが再起動します",
"app_data.select_error": "データディレクトリの変更に失敗しました",
"app_data.migration_title": "データ移行",
"app_data.original_path": "元のパス",
"app_data.new_path": "新しいパス",
"app_data.select_error_root_path": "新しいパスはルートパスにできません",
"app_data.select_error_write_permission": "新しいパスに書き込み権限がありません",
"app_data.stop_quit_app_reason": "アプリは現在データを移行しているため、終了できません",
"app_knowledge": "知識ベースファイル",
"app_knowledge": "ナレッジベースファイル",
"app_knowledge.button.delete": "ファイルを削除",
"app_knowledge.remove_all": "ナレッジベースファイルを削除",
"app_knowledge.remove_all_confirm": "ナレッジベースファイルを削除すると、ナレッジベース自体は削除されません。これにより、ストレージ容量を節約できます。続行しますか?",
@@ -1133,8 +1103,7 @@
"obsidian": "Obsidianにエクスポート",
"siyuan": "思源ノートにエクスポート",
"joplin": "Joplinにエクスポート",
"docx": "Wordとしてエクスポート",
"plain_text": "プレーンテキストとしてコピー"
"docx": "Wordとしてエクスポート"
},
"joplin": {
"check": {
@@ -1161,7 +1130,7 @@
"markdown_export.select": "選択",
"markdown_export.title": "Markdown エクスポート",
"markdown_export.show_model_name.title": "エクスポート時にモデル名を使用",
"markdown_export.show_model_name.help": "有効にすると、Markdownエクスポート時にモデル名を表示します。注意この設定はNotion、Yuqueなど、Markdownを通じたすべてのエクスポート方法にも影響します。",
"markdown_export.show_model_name.help": "有効にすると、トピック命名モデルがエクスポートされたメッセージのタイトル作成に使用されます。注意この設定はNotion、Yuqueなど、Markdownを通じたすべてのエクスポート方法にも影響します。",
"markdown_export.show_model_provider.title": "モデルプロバイダーを表示",
"markdown_export.show_model_provider.help": "Markdownエクスポート時にモデルプロバイダーOpenAI、Geminiなどを表示します。",
"minute_interval_one": "{{count}} 分",
@@ -1207,6 +1176,8 @@
"restore.confirm.content": "WebDAV から復元すると現在のデータが上書きされます。続行しますか?",
"restore.confirm.title": "復元を確認",
"restore.content": "WebDAVから復元すると現在のデータが上書きされます。続行しますか",
"restore.modal.select.placeholder": "復元するバックアップファイルを選択してください",
"restore.modal.title": "WebDAV から復元",
"restore.title": "WebDAVから復元",
"syncError": "バックアップエラー",
"syncStatus": "バックアップ状態",
@@ -1913,8 +1884,7 @@
"model_desc": "翻訳サービスで使用されるモデル",
"bidirectional": "双方向翻訳設定",
"bidirectional_tip": "有効にすると、ソース言語と目標言語間の双方向翻訳のみがサポートされます",
"scroll_sync": "スクロール同期設定",
"preview": "Markdown プレビュー"
"scroll_sync": "スクロール同期設定"
},
"title": "翻訳",
"tooltip.newline": "改行",

View File

@@ -183,7 +183,7 @@
"input.new.context": "Очистить контекст {{Command}}",
"input.new_topic": "Новый топик {{Command}}",
"input.pause": "Остановить",
"input.placeholder": "Введите ваше сообщение здесь, нажмите {{key}} для отправки...",
"input.placeholder": "Введите ваше сообщение здесь...",
"input.send": "Отправить",
"input.settings": "Настройки",
"input.topics": " Топики ",
@@ -804,19 +804,7 @@
"vision": "Визуальные",
"websearch": "Веб-поисковые"
},
"rerank_model_not_support_provider": "В настоящее время модель переупорядочивания не поддерживает этого провайдера ({{provider}})",
"price": {
"cost": "Стоимость",
"currency": "Валюта",
"custom": "Пользовательский",
"custom_currency": "Пользовательская валюта",
"custom_currency_placeholder": "Введите пользовательскую валюту",
"input": "Цена ввода",
"million_tokens": "M Tokens",
"output": "Цена вывода",
"price": "Цена"
},
"reasoning": "Рассуждение"
"rerank_model_not_support_provider": "В настоящее время модель переупорядочивания не поддерживает этого провайдера ({{provider}})"
},
"navbar": {
"expand": "Развернуть диалоговое окно",
@@ -974,8 +962,7 @@
"per_image": "за изображение",
"per_images": "за изображения",
"required_field": "Обязательное поле",
"uploaded_input": "Загруженный ввод",
"prompt_placeholder_en": "[to be translated]:Enter your image description, currently Imagen only supports English prompts"
"uploaded_input": "Загруженный ввод"
},
"prompts": {
"explanation": "Объясните мне этот концепт",
@@ -1083,25 +1070,7 @@
"assistant.title": "Ассистент по умолчанию",
"data": {
"app_data": "Данные приложения",
"app_data.select": "Изменить директорию",
"app_data.select_title": "Изменить директорию данных приложения",
"app_data.restart_notice": "Для применения изменений потребуется перезапуск приложения",
"app_data.copy_data_option": "Копировать данные из исходной директории в новую директорию",
"app_data.copy_time_notice": "Копирование данных из исходной директории займет некоторое время, пожалуйста, будьте терпеливы",
"app_data.path_changed_without_copy": "Путь изменен успешно, но данные не скопированы",
"app_data.copying_warning": "Копирование данных, нельзя взаимодействовать с приложением, не закрывайте приложение",
"app_data.copying": "Копирование данных в новое место...",
"app_data.copy_success": "Данные успешно скопированы в новое место",
"app_data.copy_failed": "Не удалось скопировать данные",
"app_data.select_success": "Директория данных изменена, приложение будет перезапущено для применения изменений",
"app_data.select_error": "Не удалось изменить директорию данных",
"app_data.migration_title": "Миграция данных",
"app_data.original_path": "Исходный путь",
"app_data.new_path": "Новый путь",
"app_data.select_error_root_path": "Новый путь не может быть корневым",
"app_data.select_error_write_permission": "Новый путь не имеет разрешения на запись",
"app_data.stop_quit_app_reason": "Приложение в настоящее время перемещает данные и не может быть закрыто",
"app_knowledge": "Файлы базы знаний",
"app_knowledge": "База знаний",
"app_knowledge.button.delete": "Удалить файл",
"app_knowledge.remove_all": "Удалить файлы базы знаний",
"app_knowledge.remove_all_confirm": "Удаление файлов базы знаний не удалит саму базу знаний, что позволит уменьшить занимаемый объем памяти, продолжить?",
@@ -1133,8 +1102,7 @@
"obsidian": "Экспорт в Obsidian",
"siyuan": "Экспорт в SiYuan Note",
"joplin": "Экспорт в Joplin",
"docx": "Экспорт в Word",
"plain_text": "Копировать как чистый текст"
"docx": "Экспорт в Word"
},
"joplin": {
"check": {
@@ -1161,7 +1129,7 @@
"markdown_export.select": "Выбрать",
"markdown_export.title": "Экспорт в Markdown",
"markdown_export.show_model_name.title": "Использовать имя модели при экспорте",
"markdown_export.show_model_name.help": "Если включено, при экспорте в Markdown будет отображаться имя модели. Примечание: Эта опция также влияет на все методы экспорта через Markdown, такие как Notion, Yuque и т.д.",
"markdown_export.show_model_name.help": "Если включено, для создания заголовков экспортируемых сообщений будет использоваться модель именования темы. Примечание: Эта опция также влияет на все методы экспорта через Markdown, такие как Notion, Yuque и т.д.",
"markdown_export.show_model_provider.title": "Показать поставщика модели",
"markdown_export.show_model_provider.help": "Показывать поставщика модели (например, OpenAI, Gemini) при экспорте в Markdown",
"minute_interval_one": "{{count}} минута",
@@ -1225,6 +1193,8 @@
"restore.confirm.content": "Восстановление с WebDAV перезапишет текущие данные, продолжить?",
"restore.confirm.title": "Подтверждение восстановления",
"restore.content": "Восстановление с WebDAV перезапишет текущие данные, продолжить?",
"restore.modal.select.placeholder": "Выберите файл резервной копии для восстановления",
"restore.modal.title": "Восстановление с WebDAV",
"restore.title": "Восстановление с WebDAV",
"syncError": "Ошибка резервного копирования",
"syncStatus": "Статус резервного копирования",
@@ -1913,8 +1883,7 @@
"model_desc": "Модель, используемая для службы перевода",
"bidirectional": "Настройки двунаправленного перевода",
"scroll_sync": "Настройки синхронизации прокрутки",
"bidirectional_tip": "Если включено, перевод будет выполняться в обоих направлениях, исходный текст будет переведен на целевой язык и наоборот.",
"preview": "Markdown предпросмотр"
"bidirectional_tip": "Если включено, перевод будет выполняться в обоих направлениях, исходный текст будет переведен на целевой язык и наоборот."
},
"title": "Перевод",
"tooltip.newline": "Перевести",

View File

@@ -183,7 +183,7 @@
"input.new.context": "清除上下文 {{Command}}",
"input.new_topic": "新话题 {{Command}}",
"input.pause": "暂停",
"input.placeholder": "在这里输入消息,按 {{key}} 发送...",
"input.placeholder": "在这里输入消息...",
"input.translating": "翻译中...",
"input.send": "发送",
"input.settings": "设置",
@@ -787,18 +787,6 @@
"string": "文本"
},
"pinned": "已固定",
"price": {
"cost": "花费",
"currency": "币种",
"custom": "自定义",
"custom_currency": "自定义币种",
"custom_currency_placeholder": "请输入自定义币种",
"input": "输入价格",
"million_tokens": "百万 Token",
"output": "输出价格",
"price": "价格"
},
"reasoning": "推理",
"rerank_model": "重排模型",
"rerank_model_support_provider": "目前重排序模型仅支持部分服务商 ({{provider}})",
"rerank_model_not_support_provider": "目前重排序模型不支持该服务商 ({{provider}})",
@@ -1085,24 +1073,6 @@
"assistant.title": "默认助手",
"data": {
"app_data": "应用数据",
"app_data.select": "修改目录",
"app_data.select_title": "更改应用数据目录",
"app_data.restart_notice": "应用需要重启以应用更改",
"app_data.copy_data_option": "复制数据,开启后会将原始目录数据复制到新目录",
"app_data.copy_time_notice": "复制数据将需要一些时间,复制期间不要关闭应用",
"app_data.path_changed_without_copy": "路径已更改成功,但数据未复制",
"app_data.copying_warning": "数据复制中不要强制退出app",
"app_data.copying": "正在将数据复制到新位置...",
"app_data.copy_success": "已成功复制数据到新位置",
"app_data.copy_failed": "复制数据失败",
"app_data.select_success": "数据目录已更改,应用将重启以应用更改",
"app_data.select_error": "更改数据目录失败",
"app_data.migration_title": "数据迁移",
"app_data.original_path": "原始路径",
"app_data.new_path": "新路径",
"app_data.select_error_root_path": "新路径不能是根路径",
"app_data.select_error_write_permission": "新路径没有写入权限",
"app_data.stop_quit_app_reason": "应用目前在迁移数据, 不能退出",
"app_knowledge": "知识库文件",
"app_knowledge.button.delete": "删除文件",
"app_knowledge.remove_all": "删除知识库文件",
@@ -1135,8 +1105,7 @@
"obsidian": "导出到Obsidian",
"siyuan": "导出到思源笔记",
"joplin": "导出到Joplin",
"docx": "导出为Word",
"plain_text": "复制为纯文本"
"docx": "导出为Word"
},
"joplin": {
"check": {
@@ -1163,7 +1132,7 @@
"markdown_export.select": "选择",
"markdown_export.title": "Markdown 导出",
"markdown_export.show_model_name.title": "导出时使用模型名称",
"markdown_export.show_model_name.help": "开启后,导出Markdown时会显示模型名称。注意该项也会影响所有通过Markdown导出的方式如Notion、语雀等。",
"markdown_export.show_model_name.help": "开启后,使用话题命名模型为导出的消息创建标题。注意该项也会影响所有通过Markdown导出的方式如Notion、语雀等。",
"markdown_export.show_model_provider.title": "显示模型供应商",
"markdown_export.show_model_provider.help": "在导出Markdown时显示模型供应商如OpenAI、Gemini等",
"message_title.use_topic_naming.title": "使用话题命名模型为导出的消息创建标题",
@@ -1229,6 +1198,8 @@
"restore.confirm.content": "从 WebDAV 恢复将会覆盖当前数据,是否继续?",
"restore.confirm.title": "确认恢复",
"restore.content": "从 WebDAV 恢复将覆盖当前数据,是否继续?",
"restore.modal.select.placeholder": "请选择要恢复的备份文件",
"restore.modal.title": "从 WebDAV 恢复",
"restore.title": "从 WebDAV 恢复",
"syncError": "备份错误",
"syncStatus": "备份状态",
@@ -1916,8 +1887,7 @@
"model_desc": "翻译服务使用的模型",
"bidirectional": "双向翻译设置",
"bidirectional_tip": "开启后,仅支持在源语言和目标语言之间进行双向翻译",
"scroll_sync": "滚动同步设置",
"preview": "Markdown 预览"
"scroll_sync": "滚动同步设置"
},
"title": "翻译",
"tooltip.newline": "换行",

View File

@@ -183,7 +183,7 @@
"input.new.context": "清除上下文 {{Command}}",
"input.new_topic": "新話題 {{Command}}",
"input.pause": "暫停",
"input.placeholder": "在此輸入您的訊息,按 {{key}} 傳送...",
"input.placeholder": "在此輸入您的訊息...",
"input.send": "傳送",
"input.settings": "設定",
"input.topics": " 話題 ",
@@ -804,19 +804,7 @@
"vision": "視覺",
"websearch": "網路搜尋"
},
"rerank_model_not_support_provider": "目前,重新排序模型不支援此提供者({{provider}}",
"price": {
"cost": "花費",
"currency": "幣種",
"custom": "自訂",
"custom_currency": "自訂幣種",
"custom_currency_placeholder": "請輸入自訂幣種",
"input": "輸入價格",
"million_tokens": "M Tokens",
"output": "輸出價格",
"price": "價格"
},
"reasoning": "推理"
"rerank_model_not_support_provider": "目前,重新排序模型不支援此提供者({{provider}}"
},
"navbar": {
"expand": "伸縮對話框",
@@ -1084,25 +1072,7 @@
"assistant.icon.type.none": "不顯示",
"assistant.title": "預設助手",
"data": {
"app_data": "應用數據",
"app_data.select": "修改目錄",
"app_data.select_title": "變更應用數據目錄",
"app_data.restart_notice": "變更數據目錄後需要重啟應用才能生效",
"app_data.copy_data_option": "複製數據, 開啟後會將原始目錄數據複製到新目錄",
"app_data.copy_time_notice": "複製數據將需要一些時間,複製期間不要關閉應用",
"app_data.path_changed_without_copy": "路徑已變更成功,但數據未複製",
"app_data.copying_warning": "數據複製中,不要強制退出應用",
"app_data.copying": "正在複製數據到新位置...",
"app_data.copy_success": "成功複製數據到新位置",
"app_data.copy_failed": "複製數據失敗",
"app_data.select_success": "數據目錄已變更,應用將重啟以應用變更",
"app_data.select_error": "變更數據目錄失敗",
"app_data.migration_title": "數據遷移",
"app_data.original_path": "原始路徑",
"app_data.new_path": "新路徑",
"app_data.select_error_root_path": "新路徑不能是根路徑",
"app_data.select_error_write_permission": "新路徑沒有寫入權限",
"app_data.stop_quit_app_reason": "應用目前正在遷移數據,不能退出",
"app_data": "應用程式資料",
"app_knowledge": "知識庫文件",
"app_knowledge.button.delete": "刪除檔案",
"app_knowledge.remove_all": "刪除知識庫檔案",
@@ -1135,8 +1105,7 @@
"obsidian": "匯出到Obsidian",
"siyuan": "匯出到思源筆記",
"joplin": "匯出到Joplin",
"docx": "匯出為Word",
"plain_text": "複製為純文本"
"docx": "匯出為Word"
},
"joplin": {
"check": {
@@ -1163,7 +1132,7 @@
"markdown_export.select": "選擇",
"markdown_export.title": "Markdown 匯出",
"markdown_export.show_model_name.title": "匯出時使用模型名稱",
"markdown_export.show_model_name.help": "啟用後,匯出Markdown時會顯示模型名稱。注意該項也會影響所有透過Markdown匯出的方式如Notion、語雀等。",
"markdown_export.show_model_name.help": "啟用後,將以主題命名模型為匯出的訊息建立標題。注意該項也會影響所有透過Markdown匯出的方式如Notion、語雀等。",
"markdown_export.show_model_provider.title": "顯示模型供應商",
"markdown_export.show_model_provider.help": "在匯出Markdown時顯示模型供應商如OpenAI、Gemini等",
"minute_interval_one": "{{count}} 分鐘",
@@ -1227,6 +1196,8 @@
"restore.confirm.content": "從 WebDAV 恢復將覆蓋目前資料,是否繼續?",
"restore.confirm.title": "復元確認",
"restore.content": "從 WebDAV 恢復將覆蓋目前資料,是否繼續?",
"restore.modal.select.placeholder": "請選擇要恢復的備份文件",
"restore.modal.title": "從 WebDAV 恢復",
"restore.title": "從 WebDAV 恢復",
"syncError": "備份錯誤",
"syncStatus": "備份狀態",
@@ -1913,8 +1884,7 @@
"model_desc": "翻譯服務使用的模型",
"bidirectional": "雙向翻譯設定",
"bidirectional_tip": "開啟後,僅支援在源語言和目標語言之間進行雙向翻譯",
"scroll_sync": "滾動同步設定",
"preview": "Markdown 預覽"
"scroll_sync": "滾動同步設定"
},
"title": "翻譯",
"tooltip.newline": "換行",

View File

@@ -75,8 +75,8 @@ const AgentsPage: FC = () => {
{agent.description && <AgentDescription>{agent.description}</AgentDescription>}
{agent.prompt && (
<AgentPrompt className="markdown">
<ReactMarkdown>{agent.prompt}</ReactMarkdown>
<AgentPrompt>
<ReactMarkdown className="markdown">{agent.prompt}</ReactMarkdown>{' '}
</AgentPrompt>
)}
</Flex>

View File

@@ -1,7 +1,5 @@
import { groupTranslations } from '@renderer/pages/agents/agentGroupTranslations'
import { DynamicIcon, IconName } from 'lucide-react/dynamic'
import { FC } from 'react'
import { useTranslation } from 'react-i18next'
interface Props {
groupName: string
@@ -10,25 +8,6 @@ interface Props {
}
export const AgentGroupIcon: FC<Props> = ({ groupName, size = 20, strokeWidth = 1.2 }) => {
const { i18n } = useTranslation()
const currentLanguage = i18n.language as keyof (typeof groupTranslations)[string]
const findOriginalKey = (name: string): string => {
if (groupTranslations[name]) {
return name
}
for (const key in groupTranslations) {
if (groupTranslations[key][currentLanguage] === name) {
return key
}
}
return name
}
const originalKey = findOriginalKey(groupName)
const iconMap: { [key: string]: IconName } = {
: 'user-check',
: 'star',
@@ -67,5 +46,5 @@ export const AgentGroupIcon: FC<Props> = ({ groupName, size = 20, strokeWidth =
: 'search'
} as const
return <DynamicIcon name={iconMap[originalKey] || 'bot-message-square'} size={size} strokeWidth={strokeWidth} />
return <DynamicIcon name={iconMap[groupName] || 'bot-message-square'} size={size} strokeWidth={strokeWidth} />
}

View File

@@ -3,7 +3,6 @@ import { useSettings } from '@renderer/hooks/useSettings'
import store from '@renderer/store'
import { Agent } from '@renderer/types'
import { useEffect, useState } from 'react'
import { useTranslation } from 'react-i18next'
let _agents: Agent[] = []
@@ -23,8 +22,6 @@ export function useSystemAgents() {
const [agents, setAgents] = useState<Agent[]>([])
const { resourcesPath } = useRuntime()
const { agentssubscribeUrl } = store.getState().settings
const { i18n } = useTranslation()
const currentLanguage = i18n.language
useEffect(() => {
const loadAgents = async () => {
@@ -47,21 +44,9 @@ export function useSystemAgents() {
}
// 如果没有远程配置或获取失败,加载本地代理
if (resourcesPath) {
try {
let fileName = 'agents.json'
if (currentLanguage === 'zh-CN') {
fileName = 'agents-zh.json'
} else {
fileName = 'agents-en.json'
}
const localAgentsData = await window.api.fs.read(`${resourcesPath}/data/${fileName}`, 'utf-8')
_agents = JSON.parse(localAgentsData) as Agent[]
} catch (error) {
const localAgentsData = await window.api.fs.read(resourcesPath + '/data/agents.json', 'utf-8')
_agents = JSON.parse(localAgentsData) as Agent[]
}
if (resourcesPath && _agents.length === 0) {
const localAgentsData = await window.api.fs.read(resourcesPath + '/data/agents.json', 'utf-8')
_agents = JSON.parse(localAgentsData) as Agent[]
}
setAgents(_agents)
@@ -73,7 +58,7 @@ export function useSystemAgents() {
}
loadAgents()
}, [defaultAgent, resourcesPath, agentssubscribeUrl, currentLanguage])
}, [defaultAgent, resourcesPath, agentssubscribeUrl])
return agents
}

View File

@@ -13,9 +13,10 @@ import {
import db from '@renderer/databases'
import { useAssistant } from '@renderer/hooks/useAssistant'
import { useKnowledgeBases } from '@renderer/hooks/useKnowledge'
import { useMCPServers } from '@renderer/hooks/useMCPServers'
import { useMessageOperations, useTopicLoading } from '@renderer/hooks/useMessageOperations'
import { modelGenerating, useRuntime } from '@renderer/hooks/useRuntime'
import { useSettings } from '@renderer/hooks/useSettings'
import { useMessageStyle, useSettings } from '@renderer/hooks/useSettings'
import { useShortcut, useShortcutDisplay } from '@renderer/hooks/useShortcuts'
import { useSidebarIconShow } from '@renderer/hooks/useSidebarIcon'
import { getDefaultTopic } from '@renderer/services/AssistantService'
@@ -35,7 +36,6 @@ import type { MessageInputBaseParams } from '@renderer/types/newMessage'
import { classNames, delay, formatFileSize, getFileExtension } from '@renderer/utils'
import { formatQuotedText } from '@renderer/utils/formats'
import { getFilesFromDropEvent } from '@renderer/utils/input'
import { getSendMessageShortcutLabel, isSendMessageKeyPressed } from '@renderer/utils/input'
import { documentExts, imageExts, textExts } from '@shared/config/constant'
import { IpcChannel } from '@shared/IpcChannel'
import { Button, Tooltip } from 'antd'
@@ -87,6 +87,7 @@ const Inputbar: FC<Props> = ({ assistant: _assistant, setActiveTopic, topic }) =
const { t } = useTranslation()
const containerRef = useRef(null)
const { searching } = useRuntime()
const { isBubbleStyle } = useMessageStyle()
const { pauseMessages } = useMessageOperations(topic)
const loading = useTopicLoading(topic)
const dispatch = useAppDispatch()
@@ -103,6 +104,7 @@ const Inputbar: FC<Props> = ({ assistant: _assistant, setActiveTopic, topic }) =
const currentMessageId = useRef<string>('')
const isVision = useMemo(() => isVisionModel(model), [model])
const supportExts = useMemo(() => [...textExts, ...documentExts, ...(isVision ? imageExts : [])], [isVision])
const { activedMcpServers } = useMCPServers()
const { bases: knowledgeBases } = useKnowledgeBases()
const isMultiSelectMode = useAppSelector((state) => state.runtime.chat.isMultiSelectMode)
@@ -173,11 +175,22 @@ const Inputbar: FC<Props> = ({ assistant: _assistant, setActiveTopic, topic }) =
if (uploadedFiles) {
baseUserMessage.files = uploadedFiles
}
const knowledgeBaseIds = selectedKnowledgeBases?.map((base) => base.id)
if (knowledgeBaseIds) {
baseUserMessage.knowledgeBaseIds = knowledgeBaseIds
}
if (mentionModels) {
baseUserMessage.mentions = mentionModels
}
if (!isEmpty(assistant.mcpServers) && !isEmpty(activedMcpServers)) {
baseUserMessage.enabledMCPs = activedMcpServers.filter((server) =>
assistant.mcpServers?.some((s) => s.id === server.id)
)
}
const assistantWithTopicPrompt = topic.prompt
? { ...assistant, prompt: `${assistant.prompt}\n${topic.prompt}` }
: assistant
@@ -198,7 +211,19 @@ const Inputbar: FC<Props> = ({ assistant: _assistant, setActiveTopic, topic }) =
} catch (error) {
console.error('Failed to send message:', error)
}
}, [assistant, dispatch, files, inputEmpty, loading, mentionModels, resizeTextArea, text, topic])
}, [
activedMcpServers,
assistant,
dispatch,
files,
inputEmpty,
loading,
mentionModels,
resizeTextArea,
selectedKnowledgeBases,
text,
topic
])
const translate = useCallback(async () => {
if (isTranslating) {
@@ -284,6 +309,8 @@ const Inputbar: FC<Props> = ({ assistant: _assistant, setActiveTopic, topic }) =
}, [knowledgeBases, openKnowledgeFileList, quickPanel, t, inputbarToolsRef])
const handleKeyDown = (event: React.KeyboardEvent<HTMLTextAreaElement>) => {
const isEnterPressed = event.key === 'Enter' && !event.nativeEvent.isComposing
// 按下Tab键自动选中${xxx}
if (event.key === 'Tab' && inputFocus) {
event.preventDefault()
@@ -339,37 +366,32 @@ const Inputbar: FC<Props> = ({ assistant: _assistant, setActiveTopic, topic }) =
}
}
//to check if the SendMessage key is pressed
//other keys should be ignored
const isEnterPressed = event.key === 'Enter' && !event.nativeEvent.isComposing
if (isEnterPressed) {
if (isSendMessageKeyPressed(event, sendMessageShortcut)) {
if (quickPanel.isVisible) return event.preventDefault()
sendMessage()
return event.preventDefault()
} else {
//shift+enter's default behavior is to add a new line, ignore it
if (!event.shiftKey) {
event.preventDefault()
if (isEnterPressed && !event.shiftKey && sendMessageShortcut === 'Enter') {
if (quickPanel.isVisible) return event.preventDefault()
const textArea = textareaRef.current?.resizableTextArea?.textArea
if (textArea) {
const start = textArea.selectionStart
const end = textArea.selectionEnd
const text = textArea.value
const newText = text.substring(0, start) + '\n' + text.substring(end)
sendMessage()
return event.preventDefault()
}
// update text by setState, not directly modify textarea.value
setText(newText)
if (sendMessageShortcut === 'Shift+Enter' && isEnterPressed && event.shiftKey) {
if (quickPanel.isVisible) return event.preventDefault()
// set cursor position in the next render cycle
setTimeout(() => {
textArea.selectionStart = textArea.selectionEnd = start + 1
onInput() // trigger resizeTextArea
}, 0)
}
}
}
sendMessage()
return event.preventDefault()
}
if (sendMessageShortcut === 'Ctrl+Enter' && isEnterPressed && event.ctrlKey) {
if (quickPanel.isVisible) return event.preventDefault()
sendMessage()
return event.preventDefault()
}
if (sendMessageShortcut === 'Command+Enter' && isEnterPressed && event.metaKey) {
if (quickPanel.isVisible) return event.preventDefault()
sendMessage()
return event.preventDefault()
}
if (enableBackspaceDeleteModel && event.key === 'Backspace' && text.trim() === '' && mentionModels.length > 0) {
@@ -672,6 +694,8 @@ const Inputbar: FC<Props> = ({ assistant: _assistant, setActiveTopic, topic }) =
setSelectedKnowledgeBases(showKnowledgeIcon ? (assistant.knowledge_bases ?? []) : [])
}, [assistant.id, assistant.knowledge_bases, showKnowledgeIcon])
const textareaRows = window.innerHeight >= 1000 || isBubbleStyle ? 2 : 1
const handleKnowledgeBaseSelect = (bases?: KnowledgeBase[]) => {
updateAssistant({ ...assistant, knowledge_bases: bases })
setSelectedKnowledgeBases(bases ?? [])
@@ -774,16 +798,12 @@ const Inputbar: FC<Props> = ({ assistant: _assistant, setActiveTopic, topic }) =
value={text}
onChange={onChange}
onKeyDown={handleKeyDown}
placeholder={
isTranslating
? t('chat.input.translating')
: t('chat.input.placeholder', { key: getSendMessageShortcutLabel(sendMessageShortcut) })
}
placeholder={isTranslating ? t('chat.input.translating') : t('chat.input.placeholder')}
autoFocus
contextMenu="true"
variant="borderless"
spellCheck={false}
rows={2}
rows={textareaRows}
ref={textareaRef}
style={{
fontSize,
@@ -930,7 +950,7 @@ const Textarea = styled(TextArea)`
overflow: auto;
width: 100%;
box-sizing: border-box;
transition: none !important;
transition: height 0.2s ease;
&.ant-input {
line-height: 1.4;
}

View File

@@ -24,7 +24,6 @@ import remarkMath from 'remark-math'
import CodeBlock from './CodeBlock'
import Link from './Link'
import remarkDisableConstructs from './plugins/remarkDisableConstructs'
import Table from './Table'
const ALLOWED_ELEMENTS =
@@ -41,7 +40,7 @@ const Markdown: FC<Props> = ({ block }) => {
const { mathEngine } = useSettings()
const remarkPlugins = useMemo(() => {
const plugins = [remarkGfm, remarkCjkFriendly, remarkDisableConstructs(['codeIndented'])]
const plugins = [remarkGfm, remarkCjkFriendly]
if (mathEngine !== 'none') {
plugins.push(remarkMath)
}
@@ -106,21 +105,20 @@ const Markdown: FC<Props> = ({ block }) => {
}, [])
return (
<div className="markdown">
<ReactMarkdown
rehypePlugins={rehypePlugins}
remarkPlugins={remarkPlugins}
components={components}
disallowedElements={DISALLOWED_ELEMENTS}
urlTransform={urlTransform}
remarkRehypeOptions={{
footnoteLabel: t('common.footnotes'),
footnoteLabelTagName: 'h4',
footnoteBackContent: ' '
}}>
{messageContent}
</ReactMarkdown>
</div>
<ReactMarkdown
rehypePlugins={rehypePlugins}
remarkPlugins={remarkPlugins}
className="markdown"
components={components}
disallowedElements={DISALLOWED_ELEMENTS}
urlTransform={urlTransform}
remarkRehypeOptions={{
footnoteLabel: t('common.footnotes'),
footnoteLabelTagName: 'h4',
footnoteBackContent: ' '
}}>
{messageContent}
</ReactMarkdown>
)
}

View File

@@ -103,12 +103,6 @@ vi.mock('rehype-katex', () => ({ __esModule: true, default: vi.fn() }))
vi.mock('rehype-mathjax', () => ({ __esModule: true, default: vi.fn() }))
vi.mock('rehype-raw', () => ({ __esModule: true, default: vi.fn() }))
// Mock custom plugins
vi.mock('../plugins/remarkDisableConstructs', () => ({
__esModule: true,
default: vi.fn()
}))
// Mock ReactMarkdown with realistic rendering
vi.mock('react-markdown', () => ({
__esModule: true,
@@ -168,16 +162,12 @@ describe('Markdown', () => {
describe('rendering', () => {
it('should render markdown content with correct structure', () => {
const block = createMainTextBlock({ content: 'Test content' })
const { container } = render(<Markdown block={block} />)
render(<Markdown block={block} />)
// Check that the outer container has the markdown class
const markdownContainer = container.querySelector('.markdown')
expect(markdownContainer).toBeInTheDocument()
// Check that the markdown content is rendered inside
const markdownContent = screen.getByTestId('markdown-content')
expect(markdownContent).toBeInTheDocument()
expect(markdownContent).toHaveTextContent('Test content')
const markdown = screen.getByTestId('markdown-content')
expect(markdown).toBeInTheDocument()
expect(markdown).toHaveClass('markdown')
expect(markdown).toHaveTextContent('Test content')
})
it('should handle empty content gracefully', () => {

View File

@@ -3,58 +3,55 @@
exports[`Markdown > rendering > should match snapshot 1`] = `
<div
class="markdown"
data-testid="markdown-content"
>
<div
data-testid="markdown-content"
>
# Test Markdown
# Test Markdown
This is **bold** text.
<span
data-testid="has-link-component"
>
link
</span>
<span
data-testid="has-link-component"
>
link
</span>
<div
data-testid="has-code-component"
>
<div
data-testid="has-code-component"
data-id="code-block-1"
data-testid="code-block"
>
<div
data-id="code-block-1"
data-testid="code-block"
<code>
test code
</code>
<button
type="button"
>
<code>
test code
</code>
<button
type="button"
>
Save
</button>
</div>
Save
</button>
</div>
<div
data-testid="has-table-component"
>
<div
data-block-id="test-block-1"
data-testid="table-component"
>
<table>
test table
</table>
<button
data-testid="copy-table-button"
type="button"
>
Copy Table
</button>
</div>
</div>
<span
data-testid="has-img-component"
>
img
</span>
</div>
<div
data-testid="has-table-component"
>
<div
data-block-id="test-block-1"
data-testid="table-component"
>
<table>
test table
</table>
<button
data-testid="copy-table-button"
type="button"
>
Copy Table
</button>
</div>
</div>
<span
data-testid="has-img-component"
>
img
</span>
</div>
`;

View File

@@ -1,155 +0,0 @@
import { render } from '@testing-library/react'
import ReactMarkdown from 'react-markdown'
import { describe, expect, it } from 'vitest'
import remarkDisableConstructs from '../remarkDisableConstructs'
describe('disableIndentedCode', () => {
const renderMarkdown = (markdown: string, constructs: string[] = ['codeIndented']) => {
return render(<ReactMarkdown remarkPlugins={[remarkDisableConstructs(constructs)]}>{markdown}</ReactMarkdown>)
}
describe('normal path', () => {
it('should disable indented code blocks while preserving other code types', () => {
const markdown = `
# Test Document
Regular paragraph.
This should be treated as a regular paragraph, not code
\`inline code\` should work
\`\`\`javascript
// This fenced code should work
console.log('hello')
\`\`\`
Another paragraph.
`
const { container } = renderMarkdown(markdown)
// Verify only fenced code (pre element)
expect(container.querySelectorAll('pre')).toHaveLength(1)
// Verify inline code
const inlineCode = container.querySelector('code:not(pre code)')
expect(inlineCode?.textContent).toBe('inline code')
// Verify fenced code
const fencedCode = container.querySelector('pre code')
expect(fencedCode?.textContent).toContain('console.log')
// Verify indented content becomes paragraph
const paragraphs = container.querySelectorAll('p')
const indentedParagraph = Array.from(paragraphs).find((p) =>
p.textContent?.includes('This should be treated as a regular paragraph')
)
expect(indentedParagraph).toBeTruthy()
})
it('should handle indented code in nested structures', () => {
const markdown = `
> Blockquote with \`inline code\`
>
> This indented code in blockquote should become text
1. List item
This indented code in list should become text
* Bullet list
* Nested item
More indented code to convert
`
const { container } = renderMarkdown(markdown)
// Verify no indented code blocks
expect(container.querySelectorAll('pre')).toHaveLength(0)
// Verify blockquote exists and contains converted text
const blockquote = container.querySelector('blockquote')
expect(blockquote?.textContent).toContain('This indented code in blockquote should become text')
// Verify lists exist
const lists = container.querySelectorAll('ul, ol')
expect(lists.length).toBeGreaterThan(0)
})
it('should preserve other markdown elements when disabling constructs', () => {
const markdown = `
# Heading
Paragraph text.
Indented code to disable
[Link text](https://example.com)
\`\`\`
Fenced code to keep
\`\`\`
`
const { container } = renderMarkdown(markdown)
// Verify heading
expect(container.querySelector('h1')?.textContent).toBe('Heading')
// Verify link
const link = container.querySelector('a')
expect(link?.textContent).toBe('Link text')
expect(link?.getAttribute('href')).toBe('https://example.com')
// Verify only fenced code
expect(container.querySelectorAll('pre')).toHaveLength(1)
})
})
describe('edge cases', () => {
it('should not affect markdown when no constructs are disabled', () => {
const markdown = `
This is indented code
\`inline code\`
\`\`\`javascript
console.log('fenced')
\`\`\`
`
const { container } = renderMarkdown(markdown, [])
// Should have indented code and fenced code
expect(container.querySelectorAll('pre')).toHaveLength(2)
})
it('should handle markdown with only inline and fenced code', () => {
const markdown = `
Regular paragraph with \`inline code\`.
\`\`\`typescript
function test(): string {
return "hello";
}
\`\`\`
`
const { container } = renderMarkdown(markdown)
// Should have only fenced code
expect(container.querySelectorAll('pre')).toHaveLength(1)
// Verify fenced code content
const fencedCode = container.querySelector('pre code')
expect(fencedCode?.textContent).toContain('function test()')
// Verify inline code
const inlineCode = container.querySelector('code:not(pre code)')
expect(inlineCode?.textContent).toBe('inline code')
})
})
})

View File

@@ -1,107 +0,0 @@
import { beforeEach, describe, expect, it, vi } from 'vitest'
import remarkDisableConstructs from '../remarkDisableConstructs'
describe('remarkDisableConstructs', () => {
let mockData: any
let mockThis: any
beforeEach(() => {
mockData = {}
mockThis = {
data: vi.fn().mockReturnValue(mockData)
}
})
describe('plugin creation', () => {
it('should return a function when called', () => {
const plugin = remarkDisableConstructs(['codeIndented'])
expect(typeof plugin).toBe('function')
})
})
describe('normal path', () => {
it('should add micromarkExtensions for single construct', () => {
const plugin = remarkDisableConstructs(['codeIndented'])
plugin.call(mockThis as any)
expect(mockData).toHaveProperty('micromarkExtensions')
expect(Array.isArray(mockData.micromarkExtensions)).toBe(true)
expect(mockData.micromarkExtensions).toHaveLength(1)
expect(mockData.micromarkExtensions[0]).toEqual({
disable: {
null: ['codeIndented']
}
})
})
it('should handle multiple constructs', () => {
const constructs = ['codeIndented', 'autolink', 'htmlFlow']
const plugin = remarkDisableConstructs(constructs)
plugin.call(mockThis as any)
expect(mockData.micromarkExtensions[0]).toEqual({
disable: {
null: constructs
}
})
})
})
describe('edge cases', () => {
it('should not add extensions when empty array is provided', () => {
const plugin = remarkDisableConstructs([])
plugin.call(mockThis as any)
expect(mockData).not.toHaveProperty('micromarkExtensions')
})
it('should not add extensions when undefined is passed', () => {
const plugin = remarkDisableConstructs()
plugin.call(mockThis as any)
expect(mockData).not.toHaveProperty('micromarkExtensions')
})
it('should handle empty construct names', () => {
const plugin = remarkDisableConstructs(['', ' '])
plugin.call(mockThis as any)
expect(mockData.micromarkExtensions[0]).toEqual({
disable: {
null: ['', ' ']
}
})
})
it('should handle mixed valid and empty construct names', () => {
const plugin = remarkDisableConstructs(['codeIndented', '', 'autolink'])
plugin.call(mockThis as any)
expect(mockData.micromarkExtensions[0]).toEqual({
disable: {
null: ['codeIndented', '', 'autolink']
}
})
})
})
describe('interaction with existing data', () => {
it('should append to existing micromarkExtensions', () => {
const existingExtension = { some: 'extension' }
mockData.micromarkExtensions = [existingExtension]
const plugin = remarkDisableConstructs(['codeIndented'])
plugin.call(mockThis as any)
expect(mockData.micromarkExtensions).toHaveLength(2)
expect(mockData.micromarkExtensions[0]).toBe(existingExtension)
expect(mockData.micromarkExtensions[1]).toEqual({
disable: {
null: ['codeIndented']
}
})
})
})
})

View File

@@ -1,53 +0,0 @@
import type { Plugin } from 'unified'
/**
* Custom remark plugin to disable specific markdown constructs
*
* This plugin allows you to disable specific markdown constructs by passing
* them as micromark extensions to the underlying parser.
*
* @see https://github.com/micromark/micromark
*
* @example
* ```typescript
* // Disable indented code blocks
* remarkDisableConstructs(['codeIndented'])
*
* // Disable multiple constructs
* remarkDisableConstructs(['codeIndented', 'autolink', 'htmlFlow'])
* ```
*/
/**
* Helper function to add values to plugin data
* @param data - The plugin data object
* @param field - The field name to add to
* @param value - The value to add
*/
function add(data: any, field: string, value: unknown): void {
const list = data[field] ? data[field] : (data[field] = [])
list.push(value)
}
/**
* Remark plugin to disable specific markdown constructs
* @param constructs - Array of construct names to disable (e.g., ['codeIndented', 'autolink'])
* @returns A remark plugin function
*/
function remarkDisableConstructs(constructs: string[] = []): Plugin<[], any, any> {
return function () {
const data = this.data()
if (constructs.length > 0) {
const disableExtension = {
disable: {
null: constructs
}
}
add(data, 'micromarkExtensions', disableExtension)
}
}
}
export default remarkDisableConstructs

View File

@@ -1,3 +1,4 @@
import { DownOutlined } from '@ant-design/icons'
import EmojiAvatar from '@renderer/components/Avatar/EmojiAvatar'
import { APP_NAME, AppLogo, isLocalAi } from '@renderer/config/env'
import { getModelLogo } from '@renderer/config/models'
@@ -13,7 +14,6 @@ import type { Message } from '@renderer/types/newMessage'
import { isEmoji, removeLeadingEmoji } from '@renderer/utils'
import { getMainTextContent } from '@renderer/utils/messageUtils/find'
import { Avatar } from 'antd'
import { CircleChevronDown } from 'lucide-react'
import { type FC, useCallback, useEffect, useRef, useState } from 'react'
import { useTranslation } from 'react-i18next'
import styled from 'styled-components'
@@ -104,18 +104,14 @@ const MessageAnchorLine: FC<MessageLineProps> = ({ messages }) => {
if (groupMessages.length > 1) {
for (const m of groupMessages) {
dispatch(
newMessagesActions.updateMessage({
topicId: m.topicId,
messageId: m.id,
updates: { foldSelected: m.id === message.id }
})
newMessagesActions.updateMessage({ topicId: m.topicId, messageId: m.id, updates: { foldSelected: true } })
)
}
setTimeout(() => {
const messageElement = document.getElementById(`message-${message.id}`)
if (messageElement) {
messageElement.scrollIntoView({ behavior: 'auto', block: 'start' })
messageElement.scrollIntoView({ behavior: 'smooth', block: 'nearest' })
}
}, 100)
}
@@ -187,9 +183,16 @@ const MessageAnchorLine: FC<MessageLineProps> = ({ messages }) => {
opacity: mouseY ? 0.5 + calculateValueByDistance('bottom-anchor', 1) : 0.6
}}
onClick={scrollToBottom}>
<CircleChevronDown
<MessageItemContainer
style={{ transform: `scale(${1 + calculateValueByDistance('bottom-anchor', 1)})` }}></MessageItemContainer>
<Avatar
icon={<DownOutlined style={{ color: theme === 'dark' ? 'var(--color-text)' : 'var(--color-primary)' }} />}
size={10 + calculateValueByDistance('bottom-anchor', 20)}
style={{ color: theme === 'dark' ? 'var(--color-text)' : 'var(--color-primary)' }}
style={{
backgroundColor: theme === 'dark' ? 'var(--color-background-soft)' : 'var(--color-primary-light)',
border: `1px solid ${theme === 'dark' ? 'var(--color-border-soft)' : 'var(--color-primary-soft)'}`,
opacity: 0.9
}}
/>
</MessageItem>
{messages.map((message, index) => {
@@ -200,8 +203,6 @@ const MessageAnchorLine: FC<MessageLineProps> = ({ messages }) => {
const username = removeLeadingEmoji(getUserName(message))
const content = getMainTextContent(message)
if (message.type === 'clear') return null
return (
<MessageItem
key={message.id}
@@ -261,6 +262,7 @@ const MessageItemContainer = styled.div`
justify-content: space-between;
text-align: right;
gap: 4px;
text-shadow: 0 0 2px rgba(255, 255, 255, 0.5);
opacity: 0;
transform-origin: right center;
`

View File

@@ -8,7 +8,7 @@ import PasteService from '@renderer/services/PasteService'
import { FileType, FileTypes } from '@renderer/types'
import { Message, MessageBlock, MessageBlockStatus, MessageBlockType } from '@renderer/types/newMessage'
import { classNames, getFileExtension } from '@renderer/utils'
import { getFilesFromDropEvent, isSendMessageKeyPressed } from '@renderer/utils/input'
import { getFilesFromDropEvent } from '@renderer/utils/input'
import { createFileBlock, createImageBlock } from '@renderer/utils/messageUtils/create'
import { findAllBlocks } from '@renderer/utils/messageUtils/find'
import { documentExts, imageExts, textExts } from '@shared/config/constant'
@@ -169,39 +169,31 @@ const MessageBlockEditor: FC<Props> = ({ message, onSave, onResend, onCancel })
onResend(updatedBlocks)
}
const handleKeyDown = (event: React.KeyboardEvent<HTMLTextAreaElement>, blockId: string) => {
const handleKeyDown = (event: React.KeyboardEvent<HTMLTextAreaElement>) => {
if (message.role !== 'user') {
return
}
// keep the same enter behavior as inputbar
const isEnterPressed = event.key === 'Enter' && !event.nativeEvent.isComposing
if (isEnterPressed) {
if (isSendMessageKeyPressed(event, sendMessageShortcut)) {
handleResend()
return event.preventDefault()
} else {
if (!event.shiftKey) {
event.preventDefault()
const textArea = textareaRef.current?.resizableTextArea?.textArea
if (textArea) {
const start = textArea.selectionStart
const end = textArea.selectionEnd
const text = textArea.value
const newText = text.substring(0, start) + '\n' + text.substring(end)
if (isEnterPressed && !event.shiftKey && sendMessageShortcut === 'Enter') {
handleResend()
return event.preventDefault()
}
//same with onChange()
handleTextChange(blockId, newText)
if (sendMessageShortcut === 'Shift+Enter' && isEnterPressed && event.shiftKey) {
handleResend()
return event.preventDefault()
}
// set cursor position in the next render cycle
setTimeout(() => {
textArea.selectionStart = textArea.selectionEnd = start + 1
resizeTextArea() // trigger resizeTextArea
}, 0)
}
}
}
if (sendMessageShortcut === 'Ctrl+Enter' && isEnterPressed && event.ctrlKey) {
handleResend()
return event.preventDefault()
}
if (sendMessageShortcut === 'Command+Enter' && isEnterPressed && event.metaKey) {
handleResend()
return event.preventDefault()
}
}
@@ -220,7 +212,7 @@ const MessageBlockEditor: FC<Props> = ({ message, onSave, onResend, onCancel })
handleTextChange(block.id, e.target.value)
resizeTextArea()
}}
onKeyDown={(e) => handleKeyDown(e, block.id)}
onKeyDown={handleKeyDown}
autoFocus
contextMenu="true"
spellCheck={false}

View File

@@ -205,7 +205,7 @@ const MessageMenubar: FC<Props> = (props) => {
key: 'export',
icon: <Share size={16} color="var(--color-icon)" style={{ marginTop: 3 }} />,
children: [
exportMenuOptions.plain_text && {
{
label: t('chat.topics.copy.plain_text'),
key: 'copy_message_plain_text',
onClick: () => copyMessageAsPlainText(message)

View File

@@ -18,29 +18,6 @@ const MessgeTokens: React.FC<MessageTokensProps> = ({ message }) => {
EventEmitter.emit(EVENT_NAMES.LOCATE_MESSAGE + ':' + message.id, false)
}
const getPrice = () => {
const inputTokens = message?.usage?.prompt_tokens ?? 0
const outputTokens = message?.usage?.completion_tokens ?? 0
const model = message.model
if (!model || model.pricing?.input_per_million_tokens === 0 || model.pricing?.output_per_million_tokens === 0) {
return 0
}
return (
(inputTokens * (model.pricing?.input_per_million_tokens ?? 0) +
outputTokens * (model.pricing?.output_per_million_tokens ?? 0)) /
1000000
)
}
const getPriceString = () => {
const price = getPrice()
if (price === 0) {
return ''
}
const currencySymbol = message.model?.pricing?.currencySymbol || '$'
return `| ${t('models.price.cost')}: ${currencySymbol}${price}`
}
if (!message.usage) {
return <div />
}
@@ -72,7 +49,6 @@ const MessgeTokens: React.FC<MessageTokensProps> = ({ message }) => {
<span>{message?.usage?.total_tokens}</span>
<span>{message?.usage?.prompt_tokens}</span>
<span>{message?.usage?.completion_tokens}</span>
<span>{getPriceString()}</span>
</span>
)

View File

@@ -1,7 +1,13 @@
import { CheckOutlined } from '@ant-design/icons'
import { HStack } from '@renderer/components/Layout'
import Scrollbar from '@renderer/components/Scrollbar'
import { DEFAULT_CONTEXTCOUNT, DEFAULT_MAX_TOKENS, DEFAULT_TEMPERATURE } from '@renderer/config/constant'
import {
DEFAULT_CONTEXTCOUNT,
DEFAULT_MAX_TOKENS,
DEFAULT_TEMPERATURE,
isMac,
isWindows
} from '@renderer/config/constant'
import {
isOpenAIModel,
isSupportedFlexServiceTier,
@@ -53,7 +59,6 @@ import {
TranslateLanguageVarious
} from '@renderer/types'
import { modalConfirm } from '@renderer/utils'
import { getSendMessageShortcutLabel } from '@renderer/utils/input'
import { Button, Col, InputNumber, Row, Select, Slider, Switch, Tooltip } from 'antd'
import { CircleHelp, Settings2 } from 'lucide-react'
import { FC, useCallback, useEffect, useMemo, useState } from 'react'
@@ -665,11 +670,10 @@ const SettingsTab: FC<Props> = (props) => {
value={sendMessageShortcut}
menuItemSelectedIcon={<CheckOutlined />}
options={[
{ value: 'Enter', label: getSendMessageShortcutLabel('Enter') },
{ value: 'Ctrl+Enter', label: getSendMessageShortcutLabel('Ctrl+Enter') },
{ value: 'Alt+Enter', label: getSendMessageShortcutLabel('Alt+Enter') },
{ value: 'Command+Enter', label: getSendMessageShortcutLabel('Command+Enter') },
{ value: 'Shift+Enter', label: getSendMessageShortcutLabel('Shift+Enter') }
{ value: 'Enter', label: 'Enter' },
{ value: 'Shift+Enter', label: 'Shift + Enter' },
{ value: 'Ctrl+Enter', label: 'Ctrl + Enter' },
{ value: 'Command+Enter', label: `${isMac ? '⌘' : isWindows ? 'Win' : 'Super'} + Enter` }
]}
onChange={(value) => setSendMessageShortcut(value as SendMessageShortcut)}
style={{ width: 135 }}

View File

@@ -157,7 +157,7 @@ const Container = styled.div`
flex-direction: column;
max-width: var(--assistants-width);
min-width: var(--assistants-width);
background-color: var(--color-background);
background-color: transparent;
overflow: hidden;
.collapsed {
width: 0;

View File

@@ -6,7 +6,6 @@ import { useAssistant } from '@renderer/hooks/useAssistant'
import { getProviderName } from '@renderer/services/ProviderService'
import { Assistant } from '@renderer/types'
import { Button } from 'antd'
import { ChevronsUpDown } from 'lucide-react'
import { FC } from 'react'
import { useTranslation } from 'react-i18next'
import styled from 'styled-components'
@@ -46,10 +45,9 @@ const SelectModelButton: FC<Props> = ({ assistant }) => {
<ButtonContent>
<ModelAvatar model={model} size={20} />
<ModelName>
{model ? model.name : t('button.select_model')} {providerName ? ' | ' + providerName : ''}
{model ? model.name : t('button.select_model')} {providerName ? '| ' + providerName : ''}
</ModelName>
</ButtonContent>
<ChevronsUpDown size={14} color="var(--color-icon)" />
</DropdownButton>
)
}
@@ -57,23 +55,21 @@ const SelectModelButton: FC<Props> = ({ assistant }) => {
const DropdownButton = styled(Button)`
font-size: 11px;
border-radius: 15px;
padding: 13px 5px;
padding: 12px 8px 12px 3px;
-webkit-app-region: none;
box-shadow: none;
background-color: transparent;
border: 1px solid transparent;
margin-top: 1px;
`
const ButtonContent = styled.div`
display: flex;
align-items: center;
gap: 6px;
gap: 5px;
`
const ModelName = styled.span`
font-weight: 500;
margin-right: -2px;
`
export default SelectModelButton

View File

@@ -116,6 +116,7 @@ const PopupContainer: React.FC<Props> = ({ title, resolve }) => {
const aiProvider = new AiProvider(provider)
values.dimensions = await aiProvider.getEmbeddingDimensions(selectedEmbeddingModel)
} catch (error) {
console.error('Error getting embedding dimensions:', error)
window.message.error(t('message.error.get_embedding_dimensions') + '\n' + getErrorMessage(error))
setLoading(false)
return

View File

@@ -103,8 +103,8 @@ const AssistantPromptSettings: React.FC<Props> = ({ assistant, updateAssistant }
</HStack>
<TextAreaContainer>
{showMarkdown ? (
<MarkdownContainer className="markdown" onClick={() => setShowMarkdown(false)}>
<ReactMarkdown>{prompt}</ReactMarkdown>
<MarkdownContainer onClick={() => setShowMarkdown(false)}>
<ReactMarkdown className="markdown">{prompt}</ReactMarkdown>
<div style={{ height: '30px' }} />
</MarkdownContainer>
) : (

View File

@@ -2,7 +2,6 @@ import {
CloudSyncOutlined,
FileSearchOutlined,
FolderOpenOutlined,
LoadingOutlined,
SaveOutlined,
YuqueOutlined
} from '@ant-design/icons'
@@ -19,7 +18,7 @@ import store, { useAppDispatch } from '@renderer/store'
import { setSkipBackupFile as _setSkipBackupFile } from '@renderer/store/settings'
import { AppInfo } from '@renderer/types'
import { formatFileSize } from '@renderer/utils'
import { Button, Progress, Switch, Typography } from 'antd'
import { Button, Switch, Typography } from 'antd'
import { FileText, FolderCog, FolderInput, Sparkle } from 'lucide-react'
import { FC, useEffect, useState } from 'react'
import { useTranslation } from 'react-i18next'
@@ -180,281 +179,6 @@ const DataSettings: FC = () => {
})
}
const handleSelectAppDataPath = async () => {
if (!appInfo || !appInfo.appDataPath) {
return
}
const newAppDataPath = await window.api.select({
properties: ['openDirectory', 'createDirectory'],
title: t('settings.data.app_data.select_title')
})
if (!newAppDataPath) {
return
}
// check new app data path is root path
// if is root path, show error
const pathParts = newAppDataPath.split(/[/\\]/).filter((part: string) => part !== '')
if (pathParts.length <= 1) {
window.message.error(t('settings.data.app_data.select_error_root_path'))
return
}
// check new app data path has write permission
const hasWritePermission = await window.api.hasWritePermission(newAppDataPath)
if (!hasWritePermission) {
window.message.error(t('settings.data.app_data.select_error_write_permission'))
return
}
const migrationTitle = (
<div style={{ fontSize: '18px', fontWeight: 'bold' }}>{t('settings.data.app_data.migration_title')}</div>
)
const migrationClassName = 'migration-modal'
const messageKey = 'data-migration'
// 显示确认对话框
showMigrationConfirmModal(appInfo.appDataPath, newAppDataPath, migrationTitle, migrationClassName, messageKey)
}
// 显示确认迁移的对话框
const showMigrationConfirmModal = (
originalPath: string,
newPath: string,
title: React.ReactNode,
className: string,
messageKey: string
) => {
// 复制数据选项状态
let shouldCopyData = true
// 创建路径内容组件
const PathsContent = () => (
<div>
<MigrationPathRow>
<MigrationPathLabel>{t('settings.data.app_data.original_path')}:</MigrationPathLabel>
<MigrationPathValue>{originalPath}</MigrationPathValue>
</MigrationPathRow>
<MigrationPathRow style={{ marginTop: '16px' }}>
<MigrationPathLabel>{t('settings.data.app_data.new_path')}:</MigrationPathLabel>
<MigrationPathValue>{newPath}</MigrationPathValue>
</MigrationPathRow>
</div>
)
const CopyDataContent = () => (
<div>
<MigrationPathRow style={{ marginTop: '20px', flexDirection: 'row', alignItems: 'center' }}>
<Switch
defaultChecked={true}
onChange={(checked) => {
shouldCopyData = checked
}}
style={{ marginRight: '8px' }}
/>
<MigrationPathLabel style={{ fontWeight: 'normal', fontSize: '14px' }}>
{t('settings.data.app_data.copy_data_option')}
</MigrationPathLabel>
</MigrationPathRow>
</div>
)
// 显示确认模态框
const modal = window.modal.confirm({
title,
className,
width: 'min(600px, 90vw)',
style: { minHeight: '400px' },
content: (
<MigrationModalContent>
<PathsContent />
<CopyDataContent />
<MigrationNotice>
<p style={{ color: 'var(--color-warning)' }}>{t('settings.data.app_data.restart_notice')}</p>
<p style={{ color: 'var(--color-text-3)', marginTop: '8px' }}>
{t('settings.data.app_data.copy_time_notice')}
</p>
</MigrationNotice>
</MigrationModalContent>
),
centered: true,
okButtonProps: {
danger: true
},
okText: t('common.confirm'),
cancelText: t('common.cancel'),
onOk: async () => {
try {
// 立即关闭确认对话框
modal.destroy()
// 设置停止退出应用
window.api.setStopQuitApp(true, t('settings.data.app_data.stop_quit_app_reason'))
if (shouldCopyData) {
// 如果选择复制数据,显示进度模态框并执行迁移
const { loadingModal, progressInterval, updateProgress } = showProgressModal(title, className, PathsContent)
try {
await startMigration(originalPath, newPath, progressInterval, updateProgress, loadingModal, messageKey)
} catch (error) {
if (progressInterval) {
clearInterval(progressInterval)
}
loadingModal.destroy()
throw error
}
} else {
// 如果不复制数据,直接设置新的应用数据路径
await window.api.setAppDataPath(newPath)
window.message.success(t('settings.data.app_data.path_changed_without_copy'))
}
// 更新应用数据路径
setAppInfo(await window.api.getAppInfo())
// 通知用户并重启应用
setTimeout(() => {
window.message.success(t('settings.data.app_data.select_success'))
window.api.setStopQuitApp(false, '')
window.api.relaunchApp()
}, 1000)
} catch (error) {
window.api.setStopQuitApp(false, '')
window.message.error({
content:
(shouldCopyData
? t('settings.data.app_data.copy_failed')
: t('settings.data.app_data.path_change_failed')) +
': ' +
error,
duration: 5
})
}
}
})
}
// 显示进度模态框
const showProgressModal = (title: React.ReactNode, className: string, PathsContent: React.FC) => {
let currentProgress = 0
let progressInterval: NodeJS.Timeout | null = null
// 创建进度更新模态框
const loadingModal = window.modal.info({
title,
className,
width: 'min(600px, 90vw)',
style: { minHeight: '400px' },
icon: <LoadingOutlined style={{ fontSize: 18 }} />,
content: (
<MigrationModalContent>
<PathsContent />
<MigrationNotice>
<p>{t('settings.data.app_data.copying')}</p>
<div style={{ marginTop: '12px' }}>
<Progress percent={currentProgress} status="active" strokeWidth={8} />
</div>
<p style={{ color: 'var(--color-warning)', marginTop: '12px', fontSize: '13px' }}>
{t('settings.data.app_data.copying_warning')}
</p>
</MigrationNotice>
</MigrationModalContent>
),
centered: true,
closable: false,
maskClosable: false,
okButtonProps: { style: { display: 'none' } }
})
// 更新进度的函数
const updateProgress = (progress: number, status: 'active' | 'success' = 'active') => {
loadingModal.update({
title,
content: (
<MigrationModalContent>
<PathsContent />
<MigrationNotice>
<p>{t('settings.data.app_data.copying')}</p>
<div style={{ marginTop: '12px' }}>
<Progress percent={Math.round(progress)} status={status} strokeWidth={8} />
</div>
<p style={{ color: 'var(--color-warning)', marginTop: '12px', fontSize: '13px' }}>
{t('settings.data.app_data.copying_warning')}
</p>
</MigrationNotice>
</MigrationModalContent>
)
})
}
// 开始模拟进度更新
progressInterval = setInterval(() => {
if (currentProgress < 95) {
currentProgress += Math.random() * 5 + 1
if (currentProgress > 95) currentProgress = 95
updateProgress(currentProgress)
}
}, 500)
return { loadingModal, progressInterval, updateProgress }
}
// 开始迁移数据
const startMigration = async (
originalPath: string,
newPath: string,
progressInterval: NodeJS.Timeout | null,
updateProgress: (progress: number, status?: 'active' | 'success') => void,
loadingModal: { destroy: () => void },
messageKey: string
): Promise<void> => {
// 开始复制过程
const copyResult = await window.api.copy(originalPath, newPath)
// 停止进度更新
if (progressInterval) {
clearInterval(progressInterval)
}
// 显示100%完成
updateProgress(100, 'success')
if (!copyResult.success) {
// 延迟关闭加载模态框
await new Promise<void>((resolve) => {
setTimeout(() => {
loadingModal.destroy()
window.message.error({
content: t('settings.data.app_data.copy_failed') + ': ' + copyResult.error,
key: messageKey,
duration: 5
})
resolve()
}, 500)
})
throw new Error(copyResult.error || 'Unknown error during copy')
}
// 在复制成功后设置新的AppDataPath
await window.api.setAppDataPath(newPath)
// 短暂延迟以显示100%完成
await new Promise((resolve) => setTimeout(resolve, 500))
// 关闭加载模态框
loadingModal.destroy()
window.message.success({
content: t('settings.data.app_data.copy_success'),
key: messageKey,
duration: 2
})
}
const onSkipBackupFilesChange = (value: boolean) => {
setSkipBackupFile(value)
dispatch(_setSkipBackupFile(value))
@@ -521,9 +245,6 @@ const DataSettings: FC = () => {
<PathRow>
<PathText style={{ color: 'var(--color-text-3)' }}>{appInfo?.appDataPath}</PathText>
<StyledIcon onClick={() => handleOpenPath(appInfo?.appDataPath)} style={{ flexShrink: 0 }} />
<HStack gap="5px" style={{ marginLeft: '8px' }}>
<Button onClick={handleSelectAppDataPath}>{t('settings.data.app_data.select')}</Button>
</HStack>
</PathRow>
</SettingRow>
<SettingDivider />
@@ -631,38 +352,4 @@ const PathRow = styled(HStack)`
gap: 5px;
`
// Add styled components for migration modal
const MigrationModalContent = styled.div`
padding: 20px 0 10px;
display: flex;
flex-direction: column;
`
const MigrationNotice = styled.div`
margin-top: 24px;
font-size: 14px;
`
const MigrationPathRow = styled.div`
display: flex;
flex-direction: column;
gap: 5px;
`
const MigrationPathLabel = styled.div`
font-weight: 600;
font-size: 15px;
color: var(--color-text-1);
`
const MigrationPathValue = styled.div`
font-size: 14px;
color: var(--color-text-2);
background-color: var(--color-background-soft);
padding: 8px 12px;
border-radius: 4px;
word-break: break-all;
border: 1px solid var(--color-border);
`
export default DataSettings

View File

@@ -84,16 +84,6 @@ const ExportMenuOptions: FC = () => {
<SettingRowTitle>{t('settings.data.export_menu.docx')}</SettingRowTitle>
<Switch checked={exportMenuOptions.docx} onChange={(checked) => handleToggleOption('docx', checked)} />
</SettingRow>
<SettingDivider />
<SettingRow>
<SettingRowTitle>{t('settings.data.export_menu.plain_text')}</SettingRowTitle>
<Switch
checked={exportMenuOptions.plain_text}
onChange={(checked) => handleToggleOption('plain_text', checked)}
/>
</SettingRow>
</SettingGroup>
)
}

View File

@@ -9,7 +9,7 @@ import {
} from '@renderer/config/models'
import { Model, ModelType } from '@renderer/types'
import { getDefaultGroupName } from '@renderer/utils'
import { Button, Checkbox, Divider, Flex, Form, Input, InputNumber, message, Modal, Select } from 'antd'
import { Button, Checkbox, Divider, Flex, Form, Input, message, Modal } from 'antd'
import { FC, useState } from 'react'
import { useTranslation } from 'react-i18next'
import styled from 'styled-components'
@@ -20,42 +20,25 @@ interface ModelEditContentProps {
onClose: () => void
}
const symbols = ['$', '¥', '€', '£']
const ModelEditContent: FC<ModelEditContentProps> = ({ model, onUpdateModel, open, onClose }) => {
const [form] = Form.useForm()
const { t } = useTranslation()
const [showMoreSettings, setShowMoreSettings] = useState(false)
const [currencySymbol, setCurrencySymbol] = useState(model.pricing?.currencySymbol || '$')
const [isCustomCurrency, setIsCustomCurrency] = useState(!symbols.includes(model.pricing?.currencySymbol || '$'))
const [showModelTypes, setShowModelTypes] = useState(false)
const onFinish = (values: any) => {
const finalCurrencySymbol = isCustomCurrency ? values.customCurrencySymbol : values.currencySymbol
const updatedModel = {
...model,
id: values.id || model.id,
name: values.name || model.name,
group: values.group || model.group,
pricing: {
input_per_million_tokens: Number(values.input_per_million_tokens) || 0,
output_per_million_tokens: Number(values.output_per_million_tokens) || 0,
currencySymbol: finalCurrencySymbol || '$'
}
group: values.group || model.group
}
onUpdateModel(updatedModel)
setShowMoreSettings(false)
setShowModelTypes(false)
onClose()
}
const handleClose = () => {
setShowMoreSettings(false)
setShowModelTypes(false)
onClose()
}
const currencyOptions = [
...symbols.map((symbol) => ({ label: symbol, value: symbol })),
{ label: t('models.price.custom'), value: 'custom' }
]
return (
<Modal
title={t('models.edit')}
@@ -69,7 +52,7 @@ const ModelEditContent: FC<ModelEditContentProps> = ({ model, onUpdateModel, ope
if (visible) {
form.getFieldInstance('id')?.focus()
} else {
setShowMoreSettings(false)
setShowModelTypes(false)
}
}}>
<Form
@@ -81,15 +64,7 @@ const ModelEditContent: FC<ModelEditContentProps> = ({ model, onUpdateModel, ope
initialValues={{
id: model.id,
name: model.name,
group: model.group,
input_per_million_tokens: model.pricing?.input_per_million_tokens ?? 0,
output_per_million_tokens: model.pricing?.output_per_million_tokens ?? 0,
currencySymbol: symbols.includes(model.pricing?.currencySymbol || '$')
? model.pricing?.currencySymbol || '$'
: 'custom',
customCurrencySymbol: symbols.includes(model.pricing?.currencySymbol || '$')
? ''
: model.pricing?.currencySymbol || ''
group: model.group
}}
onFinish={onFinish}>
<Form.Item
@@ -134,22 +109,20 @@ const ModelEditContent: FC<ModelEditContentProps> = ({ model, onUpdateModel, ope
<Input placeholder={t('settings.models.add.group_name.placeholder')} spellCheck={false} />
</Form.Item>
<Form.Item style={{ marginBottom: 15, textAlign: 'center' }}>
<Flex justify="center" align="center" style={{ position: 'relative' }}>
<MoreSettingsRow
onClick={() => setShowMoreSettings(!showMoreSettings)}
style={{ position: 'absolute', right: 0 }}>
<Flex justify="space-between" align="center" style={{ position: 'relative' }}>
<MoreSettingsRow onClick={() => setShowModelTypes(!showModelTypes)}>
{t('settings.moresetting')}
<ExpandIcon>{showMoreSettings ? <UpOutlined /> : <DownOutlined />}</ExpandIcon>
<ExpandIcon>{showModelTypes ? <UpOutlined /> : <DownOutlined />}</ExpandIcon>
</MoreSettingsRow>
<Button type="primary" htmlType="submit" size="middle">
{t('common.save')}
</Button>
</Flex>
</Form.Item>
{showMoreSettings && (
{showModelTypes && (
<div>
<Divider style={{ margin: '0 0 15px 0' }} />
<TypeTitle>{t('models.type.select')}</TypeTitle>
<TypeTitle>{t('models.type.select')}:</TypeTitle>
{(() => {
const defaultTypes = [
...(isVisionModel(model) ? ['vision'] : []),
@@ -220,59 +193,6 @@ const ModelEditContent: FC<ModelEditContentProps> = ({ model, onUpdateModel, ope
/>
)
})()}
<TypeTitle>{t('models.price.price')}</TypeTitle>
<Form.Item name="currencySymbol" label={t('models.price.currency')} style={{ marginBottom: 10 }}>
<Select
style={{ width: '100px' }}
options={currencyOptions}
onChange={(value) => {
if (value === 'custom') {
setIsCustomCurrency(true)
setCurrencySymbol(form.getFieldValue('customCurrencySymbol') || '')
} else {
setIsCustomCurrency(false)
setCurrencySymbol(value)
}
}}
dropdownMatchSelectWidth={false}
/>
</Form.Item>
{isCustomCurrency && (
<Form.Item
name="customCurrencySymbol"
label={t('models.price.custom_currency')}
style={{ marginBottom: 10 }}
rules={[{ required: isCustomCurrency }]}>
<Input
style={{ width: '100px' }}
placeholder={t('models.price.custom_currency_placeholder')}
maxLength={5}
onChange={(e) => setCurrencySymbol(e.target.value)}
/>
</Form.Item>
)}
<Form.Item label={t('models.price.input')} name="input_per_million_tokens">
<InputNumber
placeholder="0.00"
min={0}
step={0.01}
precision={2}
style={{ width: '240px' }}
addonAfter={`${currencySymbol} / ${t('models.price.million_tokens')}`}
/>
</Form.Item>
<Form.Item label={t('models.price.output')} name="output_per_million_tokens">
<InputNumber
placeholder="0.00"
min={0}
step={0.01}
precision={2}
style={{ width: '240px' }}
addonAfter={`${currencySymbol} / ${t('models.price.million_tokens')}`}
/>
</Form.Item>
</div>
)}
</Form>
@@ -281,7 +201,6 @@ const ModelEditContent: FC<ModelEditContentProps> = ({ model, onUpdateModel, ope
}
const TypeTitle = styled.div`
margin-top: 16px;
margin-bottom: 12px;
font-size: 14px;
font-weight: 600;

View File

@@ -26,7 +26,6 @@ import { find, isEmpty, sortBy } from 'lodash'
import { HelpCircle, Settings2, TriangleAlert } from 'lucide-react'
import { FC, useEffect, useMemo, useRef, useState } from 'react'
import { useTranslation } from 'react-i18next'
import ReactMarkdown from 'react-markdown'
import styled from 'styled-components'
let _text = ''
@@ -40,8 +39,6 @@ const TranslateSettings: FC<{
setIsScrollSyncEnabled: (value: boolean) => void
isBidirectional: boolean
setIsBidirectional: (value: boolean) => void
enableMarkdown: boolean
setEnableMarkdown: (value: boolean) => void
bidirectionalPair: [string, string]
setBidirectionalPair: (value: [string, string]) => void
translateModel: Model | undefined
@@ -55,8 +52,6 @@ const TranslateSettings: FC<{
setIsScrollSyncEnabled,
isBidirectional,
setIsBidirectional,
enableMarkdown,
setEnableMarkdown,
bidirectionalPair,
setBidirectionalPair,
translateModel,
@@ -87,7 +82,6 @@ const TranslateSettings: FC<{
setBidirectionalPair(localPair)
db.settings.put({ id: 'translate:bidirectional:pair', value: localPair })
db.settings.put({ id: 'translate:scroll:sync', value: isScrollSyncEnabled })
db.settings.put({ id: 'translate:markdown:enabled', value: enableMarkdown })
window.message.success({
content: t('message.save.success.title'),
key: 'translate-settings-save'
@@ -141,13 +135,6 @@ const TranslateSettings: FC<{
</div>
</div>
<div>
<Flex align="center" justify="space-between">
<div style={{ fontWeight: 500 }}>{t('translate.settings.preview')}</div>
<Switch checked={enableMarkdown} onChange={setEnableMarkdown} />
</Flex>
</div>
<div>
<Flex align="center" justify="space-between">
<div style={{ fontWeight: 500 }}>{t('translate.settings.scroll_sync')}</div>
@@ -225,7 +212,6 @@ const TranslatePage: FC = () => {
const [historyDrawerVisible, setHistoryDrawerVisible] = useState(false)
const [isScrollSyncEnabled, setIsScrollSyncEnabled] = useState(false)
const [isBidirectional, setIsBidirectional] = useState(false)
const [enableMarkdown, setEnableMarkdown] = useState(false)
const [bidirectionalPair, setBidirectionalPair] = useState<[string, string]>(['english', 'chinese'])
const [settingsVisible, setSettingsVisible] = useState(false)
const [detectedLanguage, setDetectedLanguage] = useState<string | null>(null)
@@ -402,9 +388,6 @@ const TranslatePage: FC = () => {
const scrollSyncSetting = await db.settings.get({ id: 'translate:scroll:sync' })
setIsScrollSyncEnabled(scrollSyncSetting ? scrollSyncSetting.value : false)
const markdownSetting = await db.settings.get({ id: 'translate:markdown:enabled' })
setEnableMarkdown(markdownSetting ? markdownSetting.value : false)
})
}, [])
@@ -603,13 +586,7 @@ const TranslatePage: FC = () => {
</OperationBar>
<OutputText ref={outputTextRef} onScroll={handleOutputScroll} className="selectable">
{!result ? (
t('translate.output.placeholder')
) : enableMarkdown ? (
<ReactMarkdown>{result}</ReactMarkdown>
) : (
result
)}
{result || t('translate.output.placeholder')}
</OutputText>
</OutputContainer>
</ContentContainer>
@@ -621,8 +598,6 @@ const TranslatePage: FC = () => {
setIsScrollSyncEnabled={setIsScrollSyncEnabled}
isBidirectional={isBidirectional}
setIsBidirectional={toggleBidirectional}
enableMarkdown={enableMarkdown}
setEnableMarkdown={setEnableMarkdown}
bidirectionalPair={bidirectionalPair}
setBidirectionalPair={setBidirectionalPair}
translateModel={translateModel}

View File

@@ -1,321 +1,365 @@
/**
* 职责提供原子化的、无状态的API调用函数
*/
import { StreamTextParams } from '@cherrystudio/ai-core'
import { AiSdkMiddlewareConfig } from '@renderer/aiCore/middleware/aisdk/AiSdkMiddlewareBuilder'
import { CompletionsParams } from '@renderer/aiCore/middleware/schemas'
import { buildStreamTextParams } from '@renderer/aiCore/transformParameters'
import Logger from '@renderer/config/logger'
import {
isEmbeddingModel,
isGenerateImageModel,
isOpenRouterBuiltInWebSearchModel,
isReasoningModel,
isSupportedDisableGenerationModel,
isSupportedReasoningEffortModel,
isSupportedThinkingTokenModel
isSupportedThinkingTokenModel,
isWebSearchModel
} from '@renderer/config/models'
import {
SEARCH_SUMMARY_PROMPT,
SEARCH_SUMMARY_PROMPT_KNOWLEDGE_ONLY,
SEARCH_SUMMARY_PROMPT_WEB_ONLY
} from '@renderer/config/prompts'
import { getStoreSetting } from '@renderer/hooks/useSettings'
import i18n from '@renderer/i18n'
import { Assistant, MCPTool, Model, Provider } from '@renderer/types'
import {
Assistant,
ExternalToolResult,
KnowledgeReference,
MCPTool,
Model,
Provider,
WebSearchResponse,
WebSearchSource
} from '@renderer/types'
import { type Chunk, ChunkType } from '@renderer/types/chunk'
import { Message } from '@renderer/types/newMessage'
import { SdkModel } from '@renderer/types/sdk'
import { removeSpecialCharactersForTopicName } from '@renderer/utils'
import { findFileBlocks, getMainTextContent } from '@renderer/utils/messageUtils/find'
import { isEmpty, takeRight } from 'lodash'
import { isAbortError } from '@renderer/utils/error'
import { extractInfoFromXML, ExtractResults } from '@renderer/utils/extract'
import { findFileBlocks, getKnowledgeBaseIds, getMainTextContent } from '@renderer/utils/messageUtils/find'
import { findLast, isEmpty, takeRight } from 'lodash'
import AiProvider from '../aiCore'
import AiProviderNew from '../aiCore/index_new'
import store from '../store'
import {
getAssistantProvider,
getAssistantSettings,
getDefaultModel,
getProviderByModel,
getTopNamingModel,
getTranslateModel
} from './AssistantService'
import { getDefaultAssistant } from './AssistantService'
import { processKnowledgeSearch } from './KnowledgeService'
import {
filterContextMessages,
filterEmptyMessages,
filterUsefulMessages,
filterUserRoleStartMessages
} from './MessagesService'
import WebSearchService from './WebSearchService'
// // TODO考虑拆开
// async function fetchExternalTool(
// lastUserMessage: Message,
// assistant: Assistant,
// onChunkReceived: (chunk: Chunk) => void,
// lastAnswer?: Message
// ) {
// // 可能会有重复?
// const knowledgeBaseIds = assistant.knowledge_bases?.map((base) => base.id)
// const hasKnowledgeBase = !isEmpty(knowledgeBaseIds)
// const knowledgeRecognition = assistant.knowledgeRecognition || 'on'
// const webSearchProvider = WebSearchService.getWebSearchProvider(assistant.webSearchProviderId)
// TODO考虑拆开
async function fetchExternalTool(
lastUserMessage: Message,
assistant: Assistant,
onChunkReceived: (chunk: Chunk) => void,
lastAnswer?: Message
): Promise<ExternalToolResult> {
// 可能会有重复?
const knowledgeBaseIds = getKnowledgeBaseIds(lastUserMessage)
const hasKnowledgeBase = !isEmpty(knowledgeBaseIds)
const knowledgeRecognition = assistant.knowledgeRecognition || 'on'
const webSearchProvider = WebSearchService.getWebSearchProvider(assistant.webSearchProviderId)
// // 使用外部搜索工具
// const shouldWebSearch = !!assistant.webSearchProviderId && webSearchProvider !== null
// const shouldKnowledgeSearch = hasKnowledgeBase
// 使用外部搜索工具
const shouldWebSearch = !!assistant.webSearchProviderId && webSearchProvider !== null
const shouldKnowledgeSearch = hasKnowledgeBase
// // 在工具链开始时发送进度通知
// const willUseTools = shouldWebSearch || shouldKnowledgeSearch
// if (willUseTools) {
// onChunkReceived({ type: ChunkType.EXTERNEL_TOOL_IN_PROGRESS })
// }
// 在工具链开始时发送进度通知
const willUseTools = shouldWebSearch || shouldKnowledgeSearch
if (willUseTools) {
onChunkReceived({ type: ChunkType.EXTERNEL_TOOL_IN_PROGRESS })
}
// // --- Keyword/Question Extraction Function ---
// const extract = async (): Promise<ExtractResults | undefined> => {
// if (!lastUserMessage) return undefined
// --- Keyword/Question Extraction Function ---
const extract = async (): Promise<ExtractResults | undefined> => {
if (!lastUserMessage) return undefined
// // 根据配置决定是否需要提取
// const needWebExtract = shouldWebSearch
// const needKnowledgeExtract = hasKnowledgeBase && knowledgeRecognition === 'on'
// 根据配置决定是否需要提取
const needWebExtract = shouldWebSearch
const needKnowledgeExtract = hasKnowledgeBase && knowledgeRecognition === 'on'
// if (!needWebExtract && !needKnowledgeExtract) return undefined
if (!needWebExtract && !needKnowledgeExtract) return undefined
// let prompt: string
// if (needWebExtract && !needKnowledgeExtract) {
// prompt = SEARCH_SUMMARY_PROMPT_WEB_ONLY
// } else if (!needWebExtract && needKnowledgeExtract) {
// prompt = SEARCH_SUMMARY_PROMPT_KNOWLEDGE_ONLY
// } else {
// prompt = SEARCH_SUMMARY_PROMPT
// }
let prompt: string
if (needWebExtract && !needKnowledgeExtract) {
prompt = SEARCH_SUMMARY_PROMPT_WEB_ONLY
} else if (!needWebExtract && needKnowledgeExtract) {
prompt = SEARCH_SUMMARY_PROMPT_KNOWLEDGE_ONLY
} else {
prompt = SEARCH_SUMMARY_PROMPT
}
// const summaryAssistant = getDefaultAssistant()
// summaryAssistant.model = assistant.model || getDefaultModel()
// summaryAssistant.prompt = prompt
const summaryAssistant = getDefaultAssistant()
summaryAssistant.model = assistant.model || getDefaultModel()
summaryAssistant.prompt = prompt
// try {
// const result = await fetchSearchSummary({
// messages: lastAnswer ? [lastAnswer, lastUserMessage] : [lastUserMessage],
// assistant: summaryAssistant
// })
// if (!result) return getFallbackResult()
// const extracted = extractInfoFromXML(result.getText())
// // 根据需求过滤结果
// return {
// websearch: needWebExtract ? extracted?.websearch : undefined,
// knowledge: needKnowledgeExtract ? extracted?.knowledge : undefined
// }
// } catch (e: any) {
// console.error('extract error', e)
// if (isAbortError(e)) throw e
// return getFallbackResult()
// }
// }
// const getFallbackResult = (): ExtractResults => {
// const fallbackContent = getMainTextContent(lastUserMessage)
// return {
// websearch: shouldWebSearch ? { question: [fallbackContent || 'search'] } : undefined,
// knowledge: shouldKnowledgeSearch
// ? {
// question: [fallbackContent || 'search'],
// rewrite: fallbackContent
// }
// : undefined
// }
// }
// // --- Web Search Function ---
// const searchTheWeb = async (extractResults: ExtractResults | undefined): Promise<WebSearchResponse | undefined> => {
// if (!shouldWebSearch) return
// // Add check for extractResults existence early
// if (!extractResults?.websearch) {
// console.warn('searchTheWeb called without valid extractResults.websearch')
// return
// }
// if (extractResults.websearch.question[0] === 'not_needed') return
// // Add check for assistant.model before using it
// if (!assistant.model) {
// console.warn('searchTheWeb called without assistant.model')
// return undefined
// }
// try {
// // Use the consolidated processWebsearch function
// WebSearchService.createAbortSignal(lastUserMessage.id)
// return {
// results: await WebSearchService.processWebsearch(webSearchProvider!, extractResults),
// source: WebSearchSource.WEBSEARCH
// }
// } catch (error) {
// if (isAbortError(error)) throw error
// console.error('Web search failed:', error)
// return
// }
// }
// // --- Knowledge Base Search Function ---
// const searchKnowledgeBase = async (
// extractResults: ExtractResults | undefined
// ): Promise<KnowledgeReference[] | undefined> => {
// if (!hasKnowledgeBase) return
// // 知识库搜索条件
// let searchCriteria: { question: string[]; rewrite: string }
// if (knowledgeRecognition === 'off') {
// const directContent = getMainTextContent(lastUserMessage)
// searchCriteria = { question: [directContent || 'search'], rewrite: directContent }
// } else {
// // auto mode
// if (!extractResults?.knowledge) {
// console.warn('searchKnowledgeBase: No valid search criteria in auto mode')
// return
// }
// searchCriteria = extractResults.knowledge
// }
// if (searchCriteria.question[0] === 'not_needed') return
// try {
// const tempExtractResults: ExtractResults = {
// websearch: undefined,
// knowledge: searchCriteria
// }
// // Attempt to get knowledgeBaseIds from the main text block
// // NOTE: This assumes knowledgeBaseIds are ONLY on the main text block
// // NOTE: processKnowledgeSearch needs to handle undefined ids gracefully
// // const mainTextBlock = mainTextBlocks
// // ?.map((blockId) => store.getState().messageBlocks.entities[blockId])
// // .find((block) => block?.type === MessageBlockType.MAIN_TEXT) as MainTextMessageBlock | undefined
// return await processKnowledgeSearch(tempExtractResults, knowledgeBaseIds)
// } catch (error) {
// console.error('Knowledge base search failed:', error)
// return
// }
// }
// // --- Execute Extraction and Searches ---
// let extractResults: ExtractResults | undefined
// try {
// // 根据配置决定是否需要提取
// if (shouldWebSearch || hasKnowledgeBase) {
// extractResults = await extract()
// Logger.log('[fetchExternalTool] Extraction results:', extractResults)
// }
// let webSearchResponseFromSearch: WebSearchResponse | undefined
// let knowledgeReferencesFromSearch: KnowledgeReference[] | undefined
// // 并行执行搜索
// if (shouldWebSearch || shouldKnowledgeSearch) {
// ;[webSearchResponseFromSearch, knowledgeReferencesFromSearch] = await Promise.all([
// searchTheWeb(extractResults),
// searchKnowledgeBase(extractResults)
// ])
// }
// // 存储搜索结果
// if (lastUserMessage) {
// if (webSearchResponseFromSearch) {
// window.keyv.set(`web-search-${lastUserMessage.id}`, webSearchResponseFromSearch)
// }
// if (knowledgeReferencesFromSearch) {
// window.keyv.set(`knowledge-search-${lastUserMessage.id}`, knowledgeReferencesFromSearch)
// }
// }
// // 发送工具执行完成通知
// if (willUseTools) {
// onChunkReceived({
// type: ChunkType.EXTERNEL_TOOL_COMPLETE,
// external_tool: {
// webSearch: webSearchResponseFromSearch,
// knowledge: knowledgeReferencesFromSearch
// }
// })
// }
// } catch (error) {
// if (isAbortError(error)) throw error
// console.error('Tool execution failed:', error)
// // 发送错误状态
// if (willUseTools) {
// onChunkReceived({
// type: ChunkType.EXTERNEL_TOOL_COMPLETE,
// external_tool: {
// webSearch: undefined,
// knowledge: undefined
// }
// })
// }
// return { mcpTools: [] }
// }
// }
export async function fetchMcpTools(assistant: Assistant) {
// Get MCP tools (Fix duplicate declaration)
let mcpTools: MCPTool[] = [] // Initialize as empty array
const allMcpServers = store.getState().mcp.servers || []
const activedMcpServers = allMcpServers.filter((s) => s.isActive)
const assistantMcpServers = assistant.mcpServers || []
const enabledMCPs = activedMcpServers.filter((server) => assistantMcpServers.some((s) => s.id === server.id))
if (enabledMCPs && enabledMCPs.length > 0) {
try {
const toolPromises = enabledMCPs.map<Promise<MCPTool[]>>(async (mcpServer) => {
try {
const tools = await window.api.mcp.listTools(mcpServer)
return tools.filter((tool: any) => !mcpServer.disabledTools?.includes(tool.name))
} catch (error) {
console.error(`Error fetching tools from MCP server ${mcpServer.name}:`, error)
return []
}
const result = await fetchSearchSummary({
messages: lastAnswer ? [lastAnswer, lastUserMessage] : [lastUserMessage],
assistant: summaryAssistant
})
const results = await Promise.allSettled(toolPromises)
mcpTools = results
.filter((result): result is PromiseFulfilledResult<MCPTool[]> => result.status === 'fulfilled')
.map((result) => result.value)
.flat()
} catch (toolError) {
console.error('Error fetching MCP tools:', toolError)
if (!result) return getFallbackResult()
const extracted = extractInfoFromXML(result.getText())
// 根据需求过滤结果
return {
websearch: needWebExtract ? extracted?.websearch : undefined,
knowledge: needKnowledgeExtract ? extracted?.knowledge : undefined
}
} catch (e: any) {
console.error('extract error', e)
if (isAbortError(e)) throw e
return getFallbackResult()
}
}
return mcpTools
const getFallbackResult = (): ExtractResults => {
const fallbackContent = getMainTextContent(lastUserMessage)
return {
websearch: shouldWebSearch ? { question: [fallbackContent || 'search'] } : undefined,
knowledge: shouldKnowledgeSearch
? {
question: [fallbackContent || 'search'],
rewrite: fallbackContent
}
: undefined
}
}
// --- Web Search Function ---
const searchTheWeb = async (extractResults: ExtractResults | undefined): Promise<WebSearchResponse | undefined> => {
if (!shouldWebSearch) return
// Add check for extractResults existence early
if (!extractResults?.websearch) {
console.warn('searchTheWeb called without valid extractResults.websearch')
return
}
if (extractResults.websearch.question[0] === 'not_needed') return
// Add check for assistant.model before using it
if (!assistant.model) {
console.warn('searchTheWeb called without assistant.model')
return undefined
}
try {
// Use the consolidated processWebsearch function
WebSearchService.createAbortSignal(lastUserMessage.id)
return {
results: await WebSearchService.processWebsearch(webSearchProvider!, extractResults),
source: WebSearchSource.WEBSEARCH
}
} catch (error) {
if (isAbortError(error)) throw error
console.error('Web search failed:', error)
return
}
}
// --- Knowledge Base Search Function ---
const searchKnowledgeBase = async (
extractResults: ExtractResults | undefined
): Promise<KnowledgeReference[] | undefined> => {
if (!hasKnowledgeBase) return
// 知识库搜索条件
let searchCriteria: { question: string[]; rewrite: string }
if (knowledgeRecognition === 'off') {
const directContent = getMainTextContent(lastUserMessage)
searchCriteria = { question: [directContent || 'search'], rewrite: directContent }
} else {
// auto mode
if (!extractResults?.knowledge) {
console.warn('searchKnowledgeBase: No valid search criteria in auto mode')
return
}
searchCriteria = extractResults.knowledge
}
if (searchCriteria.question[0] === 'not_needed') return
try {
const tempExtractResults: ExtractResults = {
websearch: undefined,
knowledge: searchCriteria
}
// Attempt to get knowledgeBaseIds from the main text block
// NOTE: This assumes knowledgeBaseIds are ONLY on the main text block
// NOTE: processKnowledgeSearch needs to handle undefined ids gracefully
// const mainTextBlock = mainTextBlocks
// ?.map((blockId) => store.getState().messageBlocks.entities[blockId])
// .find((block) => block?.type === MessageBlockType.MAIN_TEXT) as MainTextMessageBlock | undefined
return await processKnowledgeSearch(tempExtractResults, knowledgeBaseIds)
} catch (error) {
console.error('Knowledge base search failed:', error)
return
}
}
// --- Execute Extraction and Searches ---
let extractResults: ExtractResults | undefined
try {
// 根据配置决定是否需要提取
if (shouldWebSearch || hasKnowledgeBase) {
extractResults = await extract()
Logger.log('[fetchExternalTool] Extraction results:', extractResults)
}
let webSearchResponseFromSearch: WebSearchResponse | undefined
let knowledgeReferencesFromSearch: KnowledgeReference[] | undefined
// 并行执行搜索
if (shouldWebSearch || shouldKnowledgeSearch) {
;[webSearchResponseFromSearch, knowledgeReferencesFromSearch] = await Promise.all([
searchTheWeb(extractResults),
searchKnowledgeBase(extractResults)
])
}
// 存储搜索结果
if (lastUserMessage) {
if (webSearchResponseFromSearch) {
window.keyv.set(`web-search-${lastUserMessage.id}`, webSearchResponseFromSearch)
}
if (knowledgeReferencesFromSearch) {
window.keyv.set(`knowledge-search-${lastUserMessage.id}`, knowledgeReferencesFromSearch)
}
}
// 发送工具执行完成通知
if (willUseTools) {
onChunkReceived({
type: ChunkType.EXTERNEL_TOOL_COMPLETE,
external_tool: {
webSearch: webSearchResponseFromSearch,
knowledge: knowledgeReferencesFromSearch
}
})
}
// Get MCP tools (Fix duplicate declaration)
let mcpTools: MCPTool[] = [] // Initialize as empty array
const enabledMCPs = assistant.mcpServers
if (enabledMCPs && enabledMCPs.length > 0) {
try {
const toolPromises = enabledMCPs.map(async (mcpServer) => {
const tools = await window.api.mcp.listTools(mcpServer)
return tools.filter((tool: any) => !mcpServer.disabledTools?.includes(tool.name))
})
const results = await Promise.all(toolPromises)
mcpTools = results.flat() // Flatten the array of arrays
} catch (toolError) {
console.error('Error fetching MCP tools:', toolError)
}
}
return { mcpTools }
} catch (error) {
if (isAbortError(error)) throw error
console.error('Tool execution failed:', error)
// 发送错误状态
if (willUseTools) {
onChunkReceived({
type: ChunkType.EXTERNEL_TOOL_COMPLETE,
external_tool: {
webSearch: undefined,
knowledge: undefined
}
})
}
return { mcpTools: [] }
}
}
export async function fetchChatCompletion({
messages,
assistant,
options,
onChunkReceived
}: {
messages: StreamTextParams['messages']
messages: Message[]
assistant: Assistant
options: {
signal?: AbortSignal
timeout?: number
headers?: Record<string, string>
}
onChunkReceived: (chunk: Chunk) => void
// TODO
// onChunkStatus: (status: 'searching' | 'processing' | 'success' | 'error') => void
}) {
console.log('fetchChatCompletion', messages, assistant)
const provider = getAssistantProvider(assistant)
const AI = new AiProviderNew(provider)
const AI = new AiProvider(provider)
const mcpTools = await fetchMcpTools(assistant)
// Make sure that 'Clear Context' works for all scenarios including external tool and normal chat.
messages = filterContextMessages(messages)
// 使用 transformParameters 模块构建参数
const { params: aiSdkParams, modelId } = await buildStreamTextParams(messages, assistant, {
mcpTools: mcpTools,
requestOptions: options
})
const middlewareConfig: AiSdkMiddlewareConfig = {
streamOutput: assistant.settings?.streamOutput ?? true,
onChunk: onChunkReceived,
model: assistant.model,
provider: provider,
enableReasoning: assistant.settings?.reasoning_effort !== undefined
const lastUserMessage = findLast(messages, (m) => m.role === 'user')
const lastAnswer = findLast(messages, (m) => m.role === 'assistant')
if (!lastUserMessage) {
console.error('fetchChatCompletion returning early: Missing lastUserMessage or lastAnswer')
return
}
// try {
// NOTE: The search results are NOT added to the messages sent to the AI here.
// They will be retrieved and used by the messageThunk later to create CitationBlocks.
const { mcpTools } = await fetchExternalTool(lastUserMessage, assistant, onChunkReceived, lastAnswer)
const model = assistant.model || getDefaultModel()
const { maxTokens, contextCount } = getAssistantSettings(assistant)
const filteredMessages = filterUsefulMessages(messages)
const _messages = filterUserRoleStartMessages(
filterEmptyMessages(filterContextMessages(takeRight(filteredMessages, contextCount + 2))) // 取原来几个provider的最大值
)
const enableReasoning =
((isSupportedThinkingTokenModel(model) || isSupportedReasoningEffortModel(model)) &&
assistant.settings?.reasoning_effort !== undefined) ||
(isReasoningModel(model) && (!isSupportedThinkingTokenModel(model) || !isSupportedReasoningEffortModel(model)))
const enableWebSearch =
(assistant.enableWebSearch && isWebSearchModel(model)) ||
isOpenRouterBuiltInWebSearchModel(model) ||
model.id.includes('sonar') ||
false
const enableGenerateImage =
isGenerateImageModel(model) && (isSupportedDisableGenerationModel(model) ? assistant.enableGenerateImage : true)
// --- Call AI Completions ---
onChunkReceived({ type: ChunkType.LLM_RESPONSE_CREATED })
await AI.completions(modelId, aiSdkParams, middlewareConfig)
if (enableWebSearch) {
onChunkReceived({ type: ChunkType.LLM_WEB_SEARCH_IN_PROGRESS })
}
await AI.completions(
{
callType: 'chat',
messages: _messages,
assistant,
onChunk: onChunkReceived,
mcpTools: mcpTools,
maxTokens,
streamOutput: assistant.settings?.streamOutput || false,
enableReasoning,
enableWebSearch,
enableGenerateImage
},
{
streamOutput: assistant.settings?.streamOutput || false
}
)
}
interface FetchTranslateProps {
@@ -526,7 +570,10 @@ export async function checkApi(provider: Provider, model: Model): Promise<void>
assistant.model = model
try {
if (isEmbeddingModel(model)) {
await ai.getEmbeddingDimensions(model)
const result = await ai.getEmbeddingDimensions(model)
if (result === 0) {
throw new Error(i18n.t('message.error.enter.model'))
}
} else {
const params: CompletionsParams = {
callType: 'check',

View File

@@ -14,17 +14,7 @@ export function getDefaultAssistant(): Assistant {
topics: [getDefaultTopic('default')],
messages: [],
type: 'assistant',
regularPhrases: [], // Added regularPhrases
settings: {
temperature: DEFAULT_TEMPERATURE,
contextCount: DEFAULT_CONTEXTCOUNT,
enableMaxTokens: false,
maxTokens: 0,
streamOutput: true,
topP: 1,
toolUseMode: 'prompt',
customParameters: []
}
regularPhrases: [] // Added regularPhrases
}
}
@@ -137,17 +127,7 @@ export async function createAssistantFromAgent(agent: Agent) {
topics: [topic],
model: agent.defaultModel,
type: 'assistant',
regularPhrases: agent.regularPhrases || [], // Ensured regularPhrases
settings: agent.settings || {
temperature: DEFAULT_TEMPERATURE,
contextCount: DEFAULT_CONTEXTCOUNT,
enableMaxTokens: false,
maxTokens: 0,
streamOutput: true,
topP: 1,
toolUseMode: 'prompt',
customParameters: []
}
regularPhrases: agent.regularPhrases || [] // Ensured regularPhrases
}
store.dispatch(addAssistant(assistant))

View File

@@ -1,34 +0,0 @@
import { StreamTextParams } from '@cherrystudio/ai-core'
import { convertMessagesToSdkMessages } from '@renderer/aiCore/transformParameters'
import { Assistant, Message } from '@renderer/types'
import { isEmpty, takeRight } from 'lodash'
import { getAssistantSettings, getDefaultModel } from './AssistantService'
import {
filterContextMessages,
filterEmptyMessages,
filterUsefulMessages,
filterUserRoleStartMessages
} from './MessagesService'
export class ConversationService {
static async prepareMessagesForLlm(messages: Message[], assistant: Assistant): Promise<StreamTextParams['messages']> {
const { contextCount } = getAssistantSettings(assistant)
// This logic is extracted from the original ApiService.fetchChatCompletion
const contextMessages = filterContextMessages(messages)
const filteredMessages = filterUsefulMessages(contextMessages)
// Take the last `contextCount` messages, plus 2 to allow for a final user/assistant exchange.
const finalMessages = filterUserRoleStartMessages(
filterEmptyMessages(takeRight(filteredMessages, contextCount + 2))
)
return await convertMessagesToSdkMessages(finalMessages, assistant.model || getDefaultModel())
}
static needsWebSearch(assistant: Assistant): boolean {
return !!assistant.webSearchProviderId
}
static needsKnowledgeSearch(assistant: Assistant): boolean {
return !isEmpty(assistant.knowledge_bases)
}
}

View File

@@ -101,7 +101,7 @@ export const searchKnowledgeBase = async (
// 执行搜索
const searchResults = await window.api.knowledgeBase.search({
search: rewrite || query,
search: query,
base: baseParams
})

View File

@@ -6,7 +6,7 @@ import { fetchMessagesSummary } from '@renderer/services/ApiService'
import store from '@renderer/store'
import { messageBlocksSelectors, removeManyBlocks } from '@renderer/store/messageBlock'
import { selectMessagesForTopic } from '@renderer/store/newMessage'
import type { Assistant, FileType, Model, Topic, Usage } from '@renderer/types'
import type { Assistant, FileType, MCPServer, Model, Topic, Usage } from '@renderer/types'
import { FileTypes } from '@renderer/types'
import type { Message, MessageBlock } from '@renderer/types/newMessage'
import { AssistantMessageStatus, MessageBlockStatus, MessageBlockType } from '@renderer/types/newMessage'
@@ -108,7 +108,9 @@ export function getUserMessage({
content,
files,
// Keep other potential params if needed by createMessage
knowledgeBaseIds,
mentions,
enabledMCPs,
usage
}: {
assistant: Assistant
@@ -118,6 +120,7 @@ export function getUserMessage({
files?: FileType[]
knowledgeBaseIds?: string[]
mentions?: Model[]
enabledMCPs?: MCPServer[]
usage?: Usage
}): { message: Message; blocks: MessageBlock[] } {
const defaultModel = getDefaultModel()
@@ -130,7 +133,8 @@ export function getUserMessage({
if (content !== undefined) {
// Pass messageId when creating blocks
const textBlock = createMainTextBlock(messageId, content, {
status: MessageBlockStatus.SUCCESS
status: MessageBlockStatus.SUCCESS,
knowledgeBaseIds
})
blocks.push(textBlock)
blockIds.push(textBlock.id)
@@ -161,7 +165,7 @@ export function getUserMessage({
blocks: blockIds,
// 移除knowledgeBaseIds
mentions,
// 移除mcp
enabledMCPs,
type,
usage
}
@@ -199,6 +203,7 @@ export function resetAssistantMessage(message: Message, model?: Model): Message
useful: undefined,
askId: undefined,
mentions: undefined,
enabledMCPs: undefined,
blocks: [],
createdAt: new Date().toISOString()
}

View File

@@ -1,54 +0,0 @@
import { Assistant, Message } from '@renderer/types'
import { Chunk, ChunkType } from '@renderer/types/chunk'
import { fetchChatCompletion } from './ApiService'
import { ConversationService } from './ConversationService'
/**
* The request object for handling a user message.
*/
export interface OrchestrationRequest {
messages: Message[]
assistant: Assistant
options: {
signal?: AbortSignal
timeout?: number
headers?: Record<string, string>
}
}
/**
* The OrchestrationService is responsible for orchestrating the different services
* to handle a user's message. It contains the core logic of the application.
*/
export class OrchestrationService {
constructor() {
// In the future, this could be a singleton, but for now, a new instance is fine.
// this.conversationService = new ConversationService()
}
/**
* This is the core method to handle user messages.
* It takes the message context and an events object for callbacks,
* and orchestrates the call to the LLM.
* The logic is moved from `messageThunk.ts`.
* @param request The orchestration request containing messages and assistant info.
* @param events A set of callbacks to report progress and results to the UI layer.
*/
async handleUserMessage(request: OrchestrationRequest, onChunkReceived: (chunk: Chunk) => void) {
const { messages, assistant } = request
try {
const llmMessages = await ConversationService.prepareMessagesForLlm(messages, assistant)
await fetchChatCompletion({
messages: llmMessages,
assistant: assistant,
options: request.options,
onChunkReceived
})
} catch (error: any) {
onChunkReceived({ type: ChunkType.ERROR, error })
}
}
}

View File

@@ -50,7 +50,7 @@ const persistedReducer = persistReducer(
{
key: 'cherry-studio',
storage,
version: 114,
version: 112,
blacklist: ['runtime', 'messages', 'messageBlocks'],
migrate
},

View File

@@ -237,15 +237,14 @@ export const INITIAL_PROVIDERS: Provider[] = [
isVertex: false
},
{
id: 'vertexai',
name: 'VertexAI',
type: 'vertexai',
id: 'zhipu',
name: 'ZhiPu',
type: 'openai',
apiKey: '',
apiHost: 'https://aiplatform.googleapis.com',
models: [],
apiHost: 'https://open.bigmodel.cn/api/paas/v4/',
models: SYSTEM_MODELS.zhipu,
isSystem: true,
enabled: false,
isVertex: true
enabled: false
},
{
id: 'github',
@@ -268,16 +267,6 @@ export const INITIAL_PROVIDERS: Provider[] = [
enabled: false,
isAuthed: false
},
{
id: 'zhipu',
name: 'ZhiPu',
type: 'openai',
apiKey: '',
apiHost: 'https://open.bigmodel.cn/api/paas/v4/',
models: SYSTEM_MODELS.zhipu,
isSystem: true,
enabled: false
},
{
id: 'yi',
name: 'Yi',
@@ -388,6 +377,26 @@ export const INITIAL_PROVIDERS: Provider[] = [
isSystem: true,
enabled: false
},
{
id: 'zhinao',
name: 'zhinao',
type: 'openai',
apiKey: '',
apiHost: 'https://api.360.cn',
models: SYSTEM_MODELS.zhinao,
isSystem: true,
enabled: false
},
{
id: 'hunyuan',
name: 'hunyuan',
type: 'openai',
apiKey: '',
apiHost: 'https://api.hunyuan.cloud.tencent.com',
models: SYSTEM_MODELS.hunyuan,
isSystem: true,
enabled: false
},
{
id: 'nvidia',
name: 'nvidia',
@@ -468,16 +477,6 @@ export const INITIAL_PROVIDERS: Provider[] = [
isSystem: true,
enabled: false
},
{
id: 'hunyuan',
name: 'hunyuan',
type: 'openai',
apiKey: '',
apiHost: 'https://api.hunyuan.cloud.tencent.com',
models: SYSTEM_MODELS.hunyuan,
isSystem: true,
enabled: false
},
{
id: 'tencent-cloud-ti',
name: 'Tencent Cloud TI',
@@ -517,6 +516,17 @@ export const INITIAL_PROVIDERS: Provider[] = [
models: SYSTEM_MODELS.voyageai,
isSystem: true,
enabled: false
},
{
id: 'vertexai',
name: 'VertexAI',
type: 'vertexai',
apiKey: '',
apiHost: 'https://aiplatform.googleapis.com',
models: [],
isSystem: true,
enabled: false,
isVertex: true
}
]

View File

@@ -1582,6 +1582,7 @@ const migrateConfig = {
'113': (state: RootState) => {
try {
addProvider(state, 'vertexai')
state.llm.providers = moveProvider(state.llm.providers, 'vertexai', 10)
if (!state.llm.settings.vertexai) {
state.llm.settings.vertexai = llmInitialState.settings.vertexai
}
@@ -1595,18 +1596,6 @@ const migrateConfig = {
} catch (error) {
return state
}
},
'114': (state: RootState) => {
try {
if (state.settings && state.settings.exportMenuOptions) {
if (typeof state.settings.exportMenuOptions.plain_text === 'undefined') {
state.settings.exportMenuOptions.plain_text = true
}
}
return state
} catch (error) {
return state
}
}
}

Some files were not shown because too many files have changed in this diff Show More