Compare commits

..

3 Commits

Author SHA1 Message Date
lizhixuan
c518c9090b docs: add example implementation of Redux Slice for message management
- Introduced a new section in the technical documentation detailing the implementation of a `messages` slice using `createEntityAdapter`.
- Provided TypeScript code for the slice, including actions for adding, updating, and removing messages.
- Summarized the core principles of the slice's design, emphasizing single responsibility, logical separation of concerns, and performance optimization.
- Included a migration strategy for transitioning from the previous state structure to the new message pool approach.
2025-06-12 18:34:36 +08:00
suyao
91045ecc2b docs: finalize technical report for message history version management system with multi-model support
- Updated the design document to reflect the final version, incorporating multi-model support and enhanced version management features.
- Expanded the data structure section to include new entities and relationships, such as `askId`, `parentMessageId`, and `siblingIds`.
- Improved the core operation processes, including sending new messages and managing message versions.
- Added detailed diagrams and performance analysis to illustrate the new architecture and its advantages over the previous system.
- Ensured backward compatibility while introducing new functionalities for branching conversations and version control.
2025-06-12 15:59:54 +08:00
suyao
748ca008b4 docs: add technical report for message history version management system
- Introduced a comprehensive design document outlining the architecture and requirements for a message history version management system.
- Added new entities `UserMessage` and `AssistantMessageGroup` to support a directed multi-branch conversation structure and version management.
- Updated existing entities to accommodate the new architecture while maintaining backward compatibility.
- Included performance analysis and migration strategies for transitioning to the new system.
2025-06-12 14:25:42 +08:00
213 changed files with 7783 additions and 48992 deletions

View File

@@ -44,4 +44,4 @@ jobs:
run: yarn build:check
- name: Lint Check
run: yarn test:lint
run: yarn lint

View File

@@ -27,7 +27,7 @@ jobs:
- name: Check out Git repository
uses: actions/checkout@v4
with:
fetch-depth: 0
ref: main
- name: Get release tag
id: get-tag
@@ -149,4 +149,4 @@ jobs:
token: ${{ secrets.REPO_DISPATCH_TOKEN }}
repository: CherryHQ/cherry-studio-docs
event-type: update-download-version
client-payload: '{"version": "${{ steps.get-tag.outputs.tag }}"}'
client-payload: '{"version": "${{ steps.get-tag.outputs.tag }}"}'

1
.vscode/launch.json vendored
View File

@@ -7,6 +7,7 @@
"request": "launch",
"cwd": "${workspaceRoot}",
"runtimeExecutable": "${workspaceRoot}/node_modules/.bin/electron-vite",
"runtimeVersion": "20",
"windows": {
"runtimeExecutable": "${workspaceRoot}/node_modules/.bin/electron-vite.cmd"
},

File diff suppressed because one or more lines are too long

View File

@@ -65,44 +65,11 @@ index e8bd7bb46c8a54b3f55cf3a853ef924195271e01..f956e9f3fe9eb903c78aef3502553b01
await packager.info.emitArtifactBuildCompleted({
file: installerPath,
updateInfo,
diff --git a/out/util/yarn.js b/out/util/yarn.js
index 1ee20f8b252a8f28d0c7b103789cf0a9a427aec1..c2878ec54d57da50bf14225e0c70c9c88664eb8a 100644
--- a/out/util/yarn.js
+++ b/out/util/yarn.js
@@ -140,6 +140,7 @@ async function rebuild(config, { appDir, projectDir }, options) {
arch,
platform,
buildFromSource,
+ ignoreModules: config.excludeReBuildModules || undefined,
projectRootPath: projectDir,
mode: config.nativeRebuilder || "sequential",
disablePreGypCopy: true,
diff --git a/scheme.json b/scheme.json
index 433e2efc9cef156ff5444f0c4520362ed2ef9ea7..0167441bf928a92f59b5dbe70b2317a74dda74c9 100644
index 433e2efc9cef156ff5444f0c4520362ed2ef9ea7..a89c7a9b0b608fef67902c49106a43ebd0fa8b61 100644
--- a/scheme.json
+++ b/scheme.json
@@ -1825,6 +1825,20 @@
"string"
]
},
+ "excludeReBuildModules": {
+ "anyOf": [
+ {
+ "items": {
+ "type": "string"
+ },
+ "type": "array"
+ },
+ {
+ "type": "null"
+ }
+ ],
+ "description": "The modules to exclude from the rebuild."
+ },
"executableArgs": {
"anyOf": [
{
@@ -1975,6 +1989,13 @@
@@ -1975,6 +1975,13 @@
],
"description": "The mime types in addition to specified in the file associations. Use it if you don't want to register a new mime type, but reuse existing."
},
@@ -116,7 +83,7 @@ index 433e2efc9cef156ff5444f0c4520362ed2ef9ea7..0167441bf928a92f59b5dbe70b2317a7
"packageCategory": {
"description": "backward compatibility + to allow specify fpm-only category for all possible fpm targets in one place",
"type": [
@@ -2327,6 +2348,13 @@
@@ -2327,6 +2334,13 @@
"MacConfiguration": {
"additionalProperties": false,
"properties": {
@@ -130,28 +97,7 @@ index 433e2efc9cef156ff5444f0c4520362ed2ef9ea7..0167441bf928a92f59b5dbe70b2317a7
"additionalArguments": {
"anyOf": [
{
@@ -2527,6 +2555,20 @@
"string"
]
},
+ "excludeReBuildModules": {
+ "anyOf": [
+ {
+ "items": {
+ "type": "string"
+ },
+ "type": "array"
+ },
+ {
+ "type": "null"
+ }
+ ],
+ "description": "The modules to exclude from the rebuild."
+ },
"executableName": {
"description": "The executable name. Defaults to `productName`.",
"type": [
@@ -2737,7 +2779,7 @@
@@ -2737,7 +2751,7 @@
"type": "boolean"
},
"minimumSystemVersion": {
@@ -160,7 +106,7 @@ index 433e2efc9cef156ff5444f0c4520362ed2ef9ea7..0167441bf928a92f59b5dbe70b2317a7
"type": [
"null",
"string"
@@ -2959,6 +3001,13 @@
@@ -2959,6 +2973,13 @@
"MasConfiguration": {
"additionalProperties": false,
"properties": {
@@ -174,28 +120,7 @@ index 433e2efc9cef156ff5444f0c4520362ed2ef9ea7..0167441bf928a92f59b5dbe70b2317a7
"additionalArguments": {
"anyOf": [
{
@@ -3159,6 +3208,20 @@
"string"
]
},
+ "excludeReBuildModules": {
+ "anyOf": [
+ {
+ "items": {
+ "type": "string"
+ },
+ "type": "array"
+ },
+ {
+ "type": "null"
+ }
+ ],
+ "description": "The modules to exclude from the rebuild."
+ },
"executableName": {
"description": "The executable name. Defaults to `productName`.",
"type": [
@@ -3369,7 +3432,7 @@
@@ -3369,7 +3390,7 @@
"type": "boolean"
},
"minimumSystemVersion": {
@@ -204,28 +129,7 @@ index 433e2efc9cef156ff5444f0c4520362ed2ef9ea7..0167441bf928a92f59b5dbe70b2317a7
"type": [
"null",
"string"
@@ -6381,6 +6444,20 @@
"string"
]
},
+ "excludeReBuildModules": {
+ "anyOf": [
+ {
+ "items": {
+ "type": "string"
+ },
+ "type": "array"
+ },
+ {
+ "type": "null"
+ }
+ ],
+ "description": "The modules to exclude from the rebuild."
+ },
"executableName": {
"description": "The executable name. Defaults to `productName`.",
"type": [
@@ -6507,6 +6584,13 @@
@@ -6507,6 +6528,13 @@
"string"
]
},
@@ -239,28 +143,7 @@ index 433e2efc9cef156ff5444f0c4520362ed2ef9ea7..0167441bf928a92f59b5dbe70b2317a7
"protocols": {
"anyOf": [
{
@@ -7153,6 +7237,20 @@
"string"
]
},
+ "excludeReBuildModules": {
+ "anyOf": [
+ {
+ "items": {
+ "type": "string"
+ },
+ "type": "array"
+ },
+ {
+ "type": "null"
+ }
+ ],
+ "description": "The modules to exclude from the rebuild."
+ },
"executableName": {
"description": "The executable name. Defaults to `productName`.",
"type": [
@@ -7376,6 +7474,13 @@
@@ -7376,6 +7404,13 @@
],
"description": "MAS (Mac Application Store) development options (`mas-dev` target)."
},

View File

@@ -1,214 +0,0 @@
# 如何为 AI Provider 编写中间件
本文档旨在指导开发者如何为我们的 AI Provider 框架创建和集成自定义中间件。中间件提供了一种强大而灵活的方式来增强、修改或观察 Provider 方法的调用过程,例如日志记录、缓存、请求/响应转换、错误处理等。
## 架构概览
我们的中间件架构借鉴了 Redux 的三段式设计,并结合了 JavaScript Proxy 来动态地将中间件应用于 Provider 的方法。
- **Proxy**: 拦截对 Provider 方法的调用,并将调用引导至中间件链。
- **中间件链**: 一系列按顺序执行的中间件函数。每个中间件都可以处理请求/响应,然后将控制权传递给链中的下一个中间件,或者在某些情况下提前终止链。
- **上下文 (Context)**: 一个在中间件之间传递的对象携带了关于当前调用的信息如方法名、原始参数、Provider 实例、以及中间件自定义的数据)。
## 中间件的类型
目前主要支持两种类型的中间件,它们共享相似的结构但针对不同的场景:
1. **`CompletionsMiddleware`**: 专门为 `completions` 方法设计。这是最常用的中间件类型,因为它允许对 AI 模型的核心聊天/文本生成功能进行精细控制。
2. **`ProviderMethodMiddleware`**: 通用中间件,可以应用于 Provider 上的任何其他方法(例如,`translate`, `summarize` 等,如果这些方法也通过中间件系统包装)。
## 编写一个 `CompletionsMiddleware`
`CompletionsMiddleware` 的基本签名TypeScript 类型)如下:
```typescript
import { AiProviderMiddlewareCompletionsContext, CompletionsParams, MiddlewareAPI } from './AiProviderMiddlewareTypes' // 假设类型定义文件路径
export type CompletionsMiddleware = (
api: MiddlewareAPI<AiProviderMiddlewareCompletionsContext, [CompletionsParams]>
) => (
next: (context: AiProviderMiddlewareCompletionsContext, params: CompletionsParams) => Promise<any> // next 返回 Promise<any> 代表原始SDK响应或下游中间件的结果
) => (context: AiProviderMiddlewareCompletionsContext, params: CompletionsParams) => Promise<void> // 最内层函数通常返回 Promise<void>,因为结果通过 onChunk 或 context 副作用传递
```
让我们分解这个三段式结构:
1. **第一层函数 `(api) => { ... }`**:
- 接收一个 `api` 对象。
- `api` 对象提供了以下方法:
- `api.getContext()`: 获取当前调用的上下文对象 (`AiProviderMiddlewareCompletionsContext`)。
- `api.getOriginalArgs()`: 获取传递给 `completions` 方法的原始参数数组 (即 `[CompletionsParams]`)。
- `api.getProviderId()`: 获取当前 Provider 的 ID。
- `api.getProviderInstance()`: 获取原始的 Provider 实例。
- 此函数通常用于进行一次性的设置或获取所需的服务/配置。它返回第二层函数。
2. **第二层函数 `(next) => { ... }`**:
- 接收一个 `next` 函数。
- `next` 函数代表了中间件链中的下一个环节。调用 `next(context, params)` 会将控制权传递给下一个中间件,或者如果当前中间件是链中的最后一个,则会调用核心的 Provider 方法逻辑 (例如,实际的 SDK 调用)。
- `next` 函数接收当前的 `context``params` (这些可能已被上游中间件修改)。
- **重要的是**`next` 的返回类型通常是 `Promise<any>`。对于 `completions` 方法,如果 `next` 调用了实际的 SDK它将返回原始的 SDK 响应例如OpenAI 的流对象或 JSON 对象)。你需要处理这个响应。
- 此函数返回第三层(也是最核心的)函数。
3. **第三层函数 `(context, params) => { ... }`**:
- 这是执行中间件主要逻辑的地方。
- 它接收当前的 `context` (`AiProviderMiddlewareCompletionsContext`) 和 `params` (`CompletionsParams`)。
- 在此函数中,你可以:
- **在调用 `next` 之前**:
- 读取或修改 `params`。例如,添加默认参数、转换消息格式。
- 读取或修改 `context`。例如,设置一个时间戳用于后续计算延迟。
- 执行某些检查,如果不满足条件,可以不调用 `next` 而直接返回或抛出错误(例如,参数校验失败)。
- **调用 `await next(context, params)`**:
- 这是将控制权传递给下游的关键步骤。
- `next` 的返回值是原始的 SDK 响应或下游中间件的结果,你需要根据情况处理它(例如,如果是流,则开始消费流)。
- **在调用 `next` 之后**:
- 处理 `next` 的返回结果。例如,如果 `next` 返回了一个流,你可以在这里开始迭代处理这个流,并通过 `context.onChunk` 发送数据块。
- 基于 `context` 的变化或 `next` 的结果执行进一步操作。例如,计算总耗时、记录日志。
- 修改最终结果(尽管对于 `completions`,结果通常通过 `onChunk` 副作用发出)。
### 示例:一个简单的日志中间件
```typescript
import {
AiProviderMiddlewareCompletionsContext,
CompletionsParams,
MiddlewareAPI,
OnChunkFunction // 假设 OnChunkFunction 类型被导出
} from './AiProviderMiddlewareTypes' // 调整路径
import { ChunkType } from '@renderer/types' // 调整路径
export const createSimpleLoggingMiddleware = (): CompletionsMiddleware => {
return (api: MiddlewareAPI<AiProviderMiddlewareCompletionsContext, [CompletionsParams]>) => {
// console.log(`[LoggingMiddleware] Initialized for provider: ${api.getProviderId()}`);
return (next: (context: AiProviderMiddlewareCompletionsContext, params: CompletionsParams) => Promise<any>) => {
return async (context: AiProviderMiddlewareCompletionsContext, params: CompletionsParams): Promise<void> => {
const startTime = Date.now()
// 从 context 中获取 onChunk (它最初来自 params.onChunk)
const onChunk = context.onChunk
console.log(
`[LoggingMiddleware] Request for ${context.methodName} with params:`,
params.messages?.[params.messages.length - 1]?.content
)
try {
// 调用下一个中间件或核心逻辑
// `rawSdkResponse` 是来自下游的原始响应 (例如 OpenAIStream 或 ChatCompletion 对象)
const rawSdkResponse = await next(context, params)
// 此处简单示例不处理 rawSdkResponse假设下游中间件 (如 StreamingResponseHandler)
// 会处理它并通过 onChunk 发送数据。
// 如果这个日志中间件在 StreamingResponseHandler 之后,那么流已经被处理。
// 如果在之前,那么它需要自己处理 rawSdkResponse 或确保下游会处理。
const duration = Date.now() - startTime
console.log(`[LoggingMiddleware] Request for ${context.methodName} completed in ${duration}ms.`)
// 假设下游已经通过 onChunk 发送了所有数据。
// 如果这个中间件是链的末端,并且需要确保 BLOCK_COMPLETE 被发送,
// 它可能需要更复杂的逻辑来跟踪何时所有数据都已发送。
} catch (error) {
const duration = Date.now() - startTime
console.error(`[LoggingMiddleware] Request for ${context.methodName} failed after ${duration}ms:`, error)
// 如果 onChunk 可用,可以尝试发送一个错误块
if (onChunk) {
onChunk({
type: ChunkType.ERROR,
error: { message: (error as Error).message, name: (error as Error).name, stack: (error as Error).stack }
})
// 考虑是否还需要发送 BLOCK_COMPLETE 来结束流
onChunk({ type: ChunkType.BLOCK_COMPLETE, response: {} })
}
throw error // 重新抛出错误,以便上层或全局错误处理器可以捕获
}
}
}
}
}
```
### `AiProviderMiddlewareCompletionsContext` 的重要性
`AiProviderMiddlewareCompletionsContext` 是在中间件之间传递状态和数据的核心。它通常包含:
- `methodName`: 当前调用的方法名 (总是 `'completions'`)。
- `originalArgs`: 传递给 `completions` 的原始参数数组。
- `providerId`: Provider 的 ID。
- `_providerInstance`: Provider 实例。
- `onChunk`: 从原始 `CompletionsParams` 传入的回调函数,用于流式发送数据块。**所有中间件都应该通过 `context.onChunk` 来发送数据。**
- `messages`, `model`, `assistant`, `mcpTools`: 从原始 `CompletionsParams` 中提取的常用字段,方便访问。
- **自定义字段**: 中间件可以向上下文中添加自定义字段,以供后续中间件使用。例如,一个缓存中间件可能会添加 `context.cacheHit = true`
**关键**: 当你在中间件中修改 `params``context` 时,这些修改会向下游中间件传播(如果它们在 `next` 调用之前修改)。
### 中间件的顺序
中间件的执行顺序非常重要。它们在 `AiProviderMiddlewareConfig` 的数组中定义的顺序就是它们的执行顺序。
- 请求首先通过第一个中间件,然后是第二个,依此类推。
- 响应(或 `next` 的调用结果)则以相反的顺序"冒泡"回来。
例如,如果链是 `[AuthMiddleware, CacheMiddleware, LoggingMiddleware]`
1. `AuthMiddleware` 先执行其 "调用 `next` 之前" 的逻辑。
2. 然后 `CacheMiddleware` 执行其 "调用 `next` 之前" 的逻辑。
3. 然后 `LoggingMiddleware` 执行其 "调用 `next` 之前" 的逻辑。
4. 核心SDK调用或链的末端
5. `LoggingMiddleware` 先接收到结果,执行其 "调用 `next` 之后" 的逻辑。
6. 然后 `CacheMiddleware` 接收到结果(可能已被 LoggingMiddleware 修改的上下文),执行其 "调用 `next` 之后" 的逻辑(例如,存储结果)。
7. 最后 `AuthMiddleware` 接收到结果,执行其 "调用 `next` 之后" 的逻辑。
### 注册中间件
中间件在 `src/renderer/src/providers/middleware/register.ts` (或其他类似的配置文件) 中进行注册。
```typescript
// register.ts
import { AiProviderMiddlewareConfig } from './AiProviderMiddlewareTypes'
import { createSimpleLoggingMiddleware } from './common/SimpleLoggingMiddleware' // 假设你创建了这个文件
import { createCompletionsLoggingMiddleware } from './common/CompletionsLoggingMiddleware' // 已有的
const middlewareConfig: AiProviderMiddlewareConfig = {
completions: [
createSimpleLoggingMiddleware(), // 你新加的中间件
createCompletionsLoggingMiddleware() // 已有的日志中间件
// ... 其他 completions 中间件
],
methods: {
// translate: [createGenericLoggingMiddleware()],
// ... 其他方法的中间件
}
}
export default middlewareConfig
```
### 最佳实践
1. **单一职责**: 每个中间件应专注于一个特定的功能(例如,日志、缓存、转换特定数据)。
2. **无副作用 (尽可能)**: 除了通过 `context``onChunk` 明确的副作用外,尽量避免修改全局状态或产生其他隐蔽的副作用。
3. **错误处理**:
- 在中间件内部使用 `try...catch` 来处理可能发生的错误。
- 决定是自行处理错误(例如,通过 `onChunk` 发送错误块)还是将错误重新抛出给上游。
- 如果重新抛出,确保错误对象包含足够的信息。
4. **性能考虑**: 中间件会增加请求处理的开销。避免在中间件中执行非常耗时的同步操作。对于IO密集型操作确保它们是异步的。
5. **可配置性**: 使中间件的行为可通过参数或配置进行调整。例如,日志中间件可以接受一个日志级别参数。
6. **上下文管理**:
- 谨慎地向 `context` 添加数据。避免污染 `context` 或添加过大的对象。
- 明确你添加到 `context` 的字段的用途和生命周期。
7. **`next` 的调用**:
- 除非你有充分的理由提前终止请求(例如,缓存命中、授权失败),否则**总是确保调用 `await next(context, params)`**。否则,下游的中间件和核心逻辑将不会执行。
- 理解 `next` 的返回值并正确处理它,特别是当它是一个流时。你需要负责消费这个流或将其传递给另一个能够消费它的组件/中间件。
8. **命名清晰**: 给你的中间件和它们创建的函数起描述性的名字。
9. **文档和注释**: 对复杂的中间件逻辑添加注释,解释其工作原理和目的。
### 调试技巧
- 在中间件的关键点使用 `console.log` 或调试器来检查 `params``context` 的状态以及 `next` 的返回值。
- 暂时简化中间件链,只保留你正在调试的中间件和最简单的核心逻辑,以隔离问题。
- 编写单元测试来独立验证每个中间件的行为。
通过遵循这些指南,你应该能够有效地为我们的系统创建强大且可维护的中间件。如果你有任何疑问或需要进一步的帮助,请咨询团队。

View File

@@ -0,0 +1,635 @@
# 消息历史版本管理系统设计技术报告(最终版 - 含多模型支持)
## 1. 系统概述
基于现有扁平化架构的最小化扩展,通过 **Topic快照 + Message字段扩展含siblingIds** 实现版本管理、分支对话和多模型并行回复功能。
### 1.1 核心设计理念
- **最小破坏性**:只扩展现有实体,不新增表
- **快照渲染**通过Topic简单快照管理主线渲染顺序
- **关系扩展**通过Message字段实现树状分支、双向链表版本、多模型兄弟关系
## 2. 数据结构设计
### 2.1 实体定义
```typescript
interface Topic {
// === 现有字段保持不变 ===
id: string
name: string
createdAt: string
updatedAt: string
// === 保持简单快照 ===
activeMessageIds: string[] // 当前活跃对话主线的消息ID顺序
}
interface Message {
// === 现有字段保持不变 ===
id: string
role: 'user' | 'assistant' | 'system'
topicId: string
blocks: MessageBlock['id'][]
// === 新增:关系字段 ===
askId?: string // 问答关系assistant指向对应的user消息
parentMessageId?: string // 分支关系:指向回复的目标消息
version?: number // 版本号assistant消息专用
prevVersionId?: string // 版本链表:前一版本
nextVersionId?: string // 版本链表:后一版本
groupRequestId?: string // 请求分组同次API请求的标识
siblingIds?: string[] // 兄弟关系同级多模型回复的ID列表
}
interface MessageBlock {
// === 完全不变 ===
id: string
messageId: string
type: MessageBlockType
content: string
// ...其他现有字段
}
```
### 2.2 数据关系图
```mermaid
graph TB
subgraph "Topic快照层 (主线)"
T[Topic.activeMessageIds: user1→asst1-gpt→user2]
end
subgraph "消息实体层"
U1[User Message 1<br/>id: user1]
A1G["GPT-4 回复<br/>id: asst1-gpt, askId: user1<br/>siblingIds: [asst1-claude]"]
A1C["Claude 回复<br/>id: asst1-claude, askId: user1<br/>siblingIds: [asst1-gpt]"]
U2["User Message 2<br/>id: user2, parentMessageId: asst1-gpt"]
end
subgraph "版本链表层 (隐藏)"
A1GV0[GPT-4 v0<br/>askId: user1, version: 0]
A1GV1[GPT-4 v1<br/>askId: user1, version: 1]
A1GV0 -.->|nextVersionId| A1GV1
A1GV1 -.->|prevVersionId| A1GV0
end
subgraph "分支树层 (隐藏)"
U1B[User Branch 1<br/>parentMessageId: asst1-gpt]
A1B[Assistant Branch 1<br/>askId: user1b]
end
T --> U1
T --> A1G
T --> U2
A1G -.->|askId| U1
A1C -.->|askId| U1
A1G -.->|siblingIds| A1C
A1C -.->|siblingIds| A1G
U2 -.->|parentMessageId| A1G
U1B -.->|parentMessageId| A1G
A1B -.->|askId| U1B
```
## 3. 核心操作流程
### 3.1 发送新消息(多模型)
```mermaid
sequenceDiagram
participant UI
participant Redux
participant DB
participant API
UI->>Redux: sendMessage(userContent, models[])
Note over Redux: 1. 创建用户消息
Redux->>Redux: userMessage = { id: uuid(), role: 'user', ... }
Note over Redux: 2. 创建助手消息(多模型)
Redux->>Redux: groupRequestId = uuid()
Redux->>Redux: assistantMessages = models.map(m => createAssistant(userMessage.id, m))
Note over Redux: 3. 设置兄弟关系
Redux->>Redux: assistantIds = assistantMessages.map(m => m.id)
loop 每个助手消息
Redux->>Redux: msg.siblingIds = assistantIds.filter(id => id !== msg.id)
end
Note over Redux: 4. 更新Topic快照
Redux->>Redux: newActiveMessageIds = [<br/>...oldIds,<br/>userMessage.id,<br/>assistantMessages[0].id<br/>]
Note over Redux: 5. 原子保存
Redux->>DB: transaction([messages, topics])
DB->>DB: messages.bulkPut([userMessage, ...assistantMessages])
DB->>DB: topics.update(topicId, { activeMessageIds })
Note over Redux: 6. 发送API请求
loop 每个模型
Redux->>API: generateResponse(model, userContent)
end
Redux->>UI: 更新状态
```
**复杂度**O(M) where M = 模型数量
### 3.2 重发消息(版本管理)
```mermaid
sequenceDiagram
participant UI
participant Redux
participant DB
UI->>Redux: resendMessage(userMessageId)
Note over Redux: 1. 查找现有版本
Redux->>DB: messages.where('askId').equals(userMessageId)
DB-->>Redux: existingVersions[]
Note over Redux: 2. 计算新版本号
Redux->>Redux: latestVersion = max(versions.map(v => v.version))
Redux->>Redux: newVersion = latestVersion + 1
Note over Redux: 3. 创建新版本消息(可能多模型)
Redux->>Redux: newGroupRequestId = uuid()
Redux->>Redux: newVersionMessages = models.map(m => createNewVersion(prevMsg, newVersion, newGroupRequestId))
Note over Redux: 4. 设置新版本的兄弟关系
Redux->>Redux: newVersionIds = newVersionMessages.map(m => m.id)
loop 每个新版本消息
Redux->>Redux: msg.siblingIds = newVersionIds.filter(id => id !== msg.id)
end
Note over Redux: 5. 更新版本链表
Redux->>DB: transaction(messages)
DB->>DB: messages.update(prevMessage.id, { nextVersionId })
DB->>DB: messages.bulkPut(newVersionMessages)
Redux->>UI: 更新状态
```
**复杂度**O(V) 查找 + O(M) 创建
### 3.3 切换活跃模型UI交互
```mermaid
flowchart TD
A[用户在UI上选择其他模型] --> B[获取当前快照]
B --> C[找到当前助手消息在快照中的位置]
C --> D[用新选择的模型消息ID替换快照中的ID]
D --> E[保存到数据库]
E --> F[Redux自动重新渲染]
style A fill:#e1f5fe
style F fill:#c8e6c9
```
```typescript
const switchActiveModel = async (topicId: string, messageIndex: number, newModelMessageId: string) => {
const topic = await db.topics.get(topicId)
const newActiveMessageIds = [...topic.activeMessageIds]
newActiveMessageIds[messageIndex] = newModelMessageId
await db.topics.update(topicId, { activeMessageIds: newActiveMessageIds })
}
```
**复杂度**O(1)
## 4. 字段作用详解
### 4.1 关键字段关系图
```mermaid
graph LR
subgraph "问答关系"
askId[askId<br/>assistant → user<br/>逻辑关系,永久不变]
end
subgraph "分支关系"
parentId[parentMessageId<br/>message → message<br/>分支对话,树状结构]
end
subgraph "版本关系"
version[version + prevVersionId + nextVersionId<br/>同askId下的版本链表]
end
subgraph "请求分组"
groupId[groupRequestId<br/>同次API请求标识<br/>一次性,每次重发都变]
end
subgraph "兄弟关系"
siblingId[siblingIds<br/>同级多模型回复<br/>双向引用]
end
askId -.-> version
askId -.-> siblingId
parentId -.-> askId
groupId -.-> askId
```
### 4.2 字段使用场景
| 字段 | 用途 | 查询场景 | 生命周期 |
| -------------------------------- | ---------- | -------------------------- | -------- |
| **askId** | 问答映射 | 查找用户问题的所有回复版本 | 永久不变 |
| **parentMessageId** | 分支对话 | 查找某消息的回复分支 | 永久不变 |
| **version + prev/nextVersionId** | 版本管理 | 版本历史导航 | 永久不变 |
| **groupRequestId** | 请求追踪 | 批量状态更新、请求监控 | 一次性 |
| **siblingIds** | 多模型并行 | 渲染同级多模型回复 | 永久不变 |
### 4.3 多模型并行渲染示例
```mermaid
graph TD
U1[User: 帮我写个函数<br/>id: user1]
subgraph "第一次请求 (groupRequestId: req1)"
A1["GPT-4 回复<br/>id: asst1-gpt, askId: user1<br/>siblingIds: [asst1-claude]"]
A2["Claude 回复<br/>id: asst1-claude, askId: user1<br/>siblingIds: [asst1-gpt]"]
end
subgraph "Topic快照 (主线)"
T["activeMessageIds: [user1, asst1-gpt]"]
end
subgraph "UI渲染 (通过siblingIds扩展)"
UI_U1[User: 帮我写个函数]
UI_A1["GPT-4 回复 (活跃)"]
UI_A2["Claude 回复 (可选)"]
end
U1 --> A1
U1 --> A2
T --> U1
T --> A1
A1 -.->|siblingIds| A2
A2 -.->|siblingIds| A1
UI_U1 -.-> UI_A1
UI_U1 -.-> UI_A2
```
## 5. 数据查询与状态管理
### 5.1 话题加载流程
```mermaid
sequenceDiagram
participant UI
participant Redux
participant DB
participant Selector
UI->>Redux: loadTopic(topicId)
Redux->>DB: 并行查询
par 查询消息
DB->>DB: messages.where('topicId').equals(topicId)
and 查询块
DB->>DB: messageBlocks.where('topicId').equals(topicId)
end
DB-->>Redux: { messages[], blocks[] }
Redux->>Redux: 更新实体状态
UI->>Selector: selectActiveConversationWithSiblings(topicId)
Selector->>Redux: 获取Topic.activeMessageIds
Selector->>Redux: 获取messages实体
Selector-->>UI: 按快照顺序的消息列表 (含兄弟节点)
Note over UI: 渲染对话界面 (支持多模型)
```
### 5.2 渲染选择器(含兄弟节点)
```typescript
export const selectActiveConversationWithSiblings = createSelector(
[
(state: RootState, topicId: string) => state.topics.entities[topicId]?.activeMessageIds || [],
(state: RootState) => state.messages.entities,
(state: RootState) => state.messageBlocks.entities
],
(activeMessageIds, messagesEntities, blocksEntities) => {
return activeMessageIds
.map((messageId) => {
const message = messagesEntities[messageId]
if (!message) return null
if (message.role === 'user') {
return { type: 'user', message, blocks: getMessageBlocks(message, blocksEntities) }
} else if (message.role === 'assistant') {
const siblingMessages = (message.siblingIds || []).map((id) => messagesEntities[id]).filter(Boolean)
const allAssistantMessages = [message, ...siblingMessages]
return {
type: 'assistant_group',
messages: allAssistantMessages.map((msg) => ({
message: msg,
blocks: getMessageBlocks(msg, blocksEntities),
isActive: msg.id === messageId
})),
activeMessageId: messageId
}
}
})
.filter(Boolean)
}
)
```
**复杂度**O(N + S) where N = 快照长度, S = 兄弟节点总数
## 6. 时空复杂度分析
### 6.1 核心操作复杂度对比
```mermaid
graph LR
subgraph "现有架构"
A1[加载话题: O(M+B)]
A2[渲染对话: O(M) 需要过滤排序]
A3[发送消息: O(1)]
end
subgraph "新架构 (含多模型)"
B1[加载话题: O(M+B) ✅相同]
B2[渲染对话: O(N+S) ✅更优]
B3[发送消息: O(M_models) ✅相同]
B4[版本切换: O(1) ➕新功能]
B5[重发消息: O(V)+O(M_models) ➕新功能]
B6[模型切换: O(1) ➕新功能]
end
style B1 fill:#c8e6c9
style B2 fill:#c8e6c9
style B3 fill:#c8e6c9
style B4 fill:#fff3e0
style B5 fill:#fff3e0
style B6 fill:#fff3e0
```
### 6.2 性能优势分析
| 操作 | 现有架构 | 新架构 | 优势说明 |
| ------------ | -------------- | ---------------------------- | -------------------- |
| **话题加载** | O(M + B) | O(M + B) | 性能保持不变 |
| **对话渲染** | O(M) 过滤+排序 | **O(N+S)** 直接索引+兄弟扩展 | N << MS通常较小 |
| **发送消息** | O(1) | O(M_models) | 支持多模型,合理增长 |
| **版本切换** | 不支持 | **O(1)** | 新功能,极佳性能 |
| **模型切换** | 不支持 | **O(1)** | 新功能,极佳性能 |
**关键优势**
- **渲染性能提升**:从 O(M) 优化到 O(N+S),长对话场景收益显著
- **多模型支持**:通过 siblingIds 优雅实现
- **版本管理**O(1) 的版本/模型切换,用户体验极佳
- **向后兼容**:现有核心操作性能保持不变
## 7. 数据库Schema演进
### 7.1 Migration策略
```mermaid
flowchart TD
A[现有Schema] --> B[添加字段]
B --> C[创建索引]
C --> D[数据迁移]
D --> E[验证完整性]
B1[Topic: +activeMessageIds]
B2[Message: +askId, +parentMessageId<br/>+version, +prevVersionId<br/>+nextVersionId, +groupRequestId<br/>+siblingIds]
C1[idx_messages_askid_version]
C2[idx_messages_parent]
C3[idx_messages_group_request]
D1[生成activeMessageIds快照]
D2[设置现有assistant消息version=0]
B --> B1
B --> B2
C --> C1
C --> C2
C --> C3
D --> D1
D --> D2
```
### 7.2 SQL Migration
```sql
-- 1. 添加字段
ALTER TABLE topics ADD COLUMN activeMessageIds TEXT; -- JSON数组
ALTER TABLE messages ADD COLUMN askId TEXT;
ALTER TABLE messages ADD COLUMN parentMessageId TEXT;
ALTER TABLE messages ADD COLUMN version INTEGER;
ALTER TABLE messages ADD COLUMN prevVersionId TEXT;
ALTER TABLE messages ADD COLUMN nextVersionId TEXT;
ALTER TABLE messages ADD COLUMN groupRequestId TEXT;
ALTER TABLE messages ADD COLUMN siblingIds TEXT; -- JSON数组
-- 2. 创建索引
CREATE INDEX idx_messages_askid_version ON messages(askId, version);
CREATE INDEX idx_messages_parent ON messages(parentMessageId);
CREATE INDEX idx_messages_group_request ON messages(groupRequestId);
-- 3. 数据迁移
UPDATE messages SET version = 0 WHERE role = 'assistant';
```
## 8. 流式更新兼容性
### 8.1 MessageBlock更新流程
```mermaid
sequenceDiagram
participant Stream
participant Redux
participant DB
participant UI
Note over Stream: 流式内容到达
Stream->>Redux: updateBlock(blockId, content)
Redux->>Redux: updateOneBlock({ id, changes })
Redux->>UI: 立即更新显示
Note over Redux: 节流数据库写入
Redux->>DB: throttledDbUpdate(blockId, content)
Note over Stream,UI: 版本/兄弟关系不影响块更新
```
**关键点**
- MessageBlock 仍然直接关联到 Message
- 版本/兄弟关系在 Message 层面,不影响 Block 的流式更新
- 现有的节流机制和更新逻辑完全保持不变
## 9. 系统架构总览
### 9.1 整体架构图
```mermaid
graph TB
subgraph "UI层"
UI1[对话界面]
UI2[版本选择器]
UI3[分支导航]
UI4[模型切换器]
end
subgraph "Redux状态层"
R1[topics: EntityAdapter]
R2[messages: EntityAdapter]
R3[messageBlocks: EntityAdapter]
S1[selectActiveConversationWithSiblings]
S2[selectVersionHistory]
end
subgraph "数据库层"
DB1[(topics表)]
DB2[(messages表)]
DB3[(messageBlocks表)]
end
subgraph "API层"
API1[多模型并行请求]
API2[流式响应处理]
end
UI1 --> S1
UI2 --> S2
UI4 --> S1
S1 --> R1
S1 --> R2
S2 --> R2
R1 <--> DB1
R2 <--> DB2
R3 <--> DB3
R2 --> API1
API2 --> R3
style UI1 fill:#e3f2fd
style R1 fill:#f3e5f5
style R2 fill:#f3e5f5
style R3 fill:#f3e5f5
style DB1 fill:#e8f5e8
style DB2 fill:#e8f5e8
style DB3 fill:#e8f5e8
```
### 9.2 数据流向
```mermaid
flowchart LR
A[用户输入] --> B[创建User Message]
B --> C["创建Assistant Messages (多模型)"]
C --> C1[设置Sibling关系]
C1 --> D["更新Topic快照 (主线)"]
D --> E[API并行请求]
E --> F[流式更新Blocks]
F --> G["UI实时渲染 (含多模型)"]
H[版本切换] --> I[更新快照指针]
I --> G
J[分支对话] --> K[创建分支消息]
K --> D
L[模型切换] --> I
style A fill:#ffebee
style G fill:#e8f5e8
style H fill:#fff3e0
style J fill:#f3e5f5
style L fill:#e1f5fe
```
## 10. Redux Slice 实现范例
根据上述架构设计,`messages` slice 将演变为一个纯粹的、由 `createEntityAdapter` 管理的"消息池"。它只负责高效地存储和访问单个消息实体,而不再关心对话的顺序。
### `store/messagesSlice.ts`
```typescript
import { createSlice, createEntityAdapter, PayloadAction } from '@reduxjs/toolkit'
import type { RootState } from './store' // 你的store类型定义
import type { Message } from '@renderer/types/newMessage' // 假设 Message 类型定义在外部
// 1. 创建 Entity Adapter
// 它会自动生成管理实体的reducer逻辑实现一个高效的消息池。
const messagesAdapter = createEntityAdapter<Message>()
// 2. 定义 Slice 的初始状态
// adapter.getInitialState() 会自动创建 { ids: [], entities: {} } 结构
const initialState = messagesAdapter.getInitialState()
// 3. 创建 Slice
const messagesSlice = createSlice({
name: 'messages',
initialState,
// Reducers被极大简化多数直接引用adapter提供的方法
reducers: {
// Action: 添加一条消息
messageAdded: messagesAdapter.addOne,
// Action: 一次性添加或更新多个消息 (高性能)
// 用途: 加载话题历史、发送新一轮问答(user+assistants)
messagesUpserted: messagesAdapter.upsertMany,
// Action: 更新单个消息
// 用途: 流式更新结束、状态变更等
messageUpdated: messagesAdapter.updateOne,
// Action: 删除单个消息
messageRemoved: messagesAdapter.removeOne,
// Action: 删除多个消息
messagesRemoved: messagesAdapter.removeMany,
// Action: 用新数据完全替换消息池
// 用途: 首次加载或强制刷新
messagesSet: messagesAdapter.setAll
}
})
// 4. 导出 Actions
export const { messageAdded, messagesUpserted, messageUpdated, messageRemoved, messagesRemoved, messagesSet } =
messagesSlice.actions
// 5. 导出 Selectors
// Adapter 会自动创建高效的查询函数 (e.g., O(1) by ID)
export const messagesSelectors = messagesAdapter.getSelectors((state: RootState) => state.messages)
// 6. 导出 Reducer
export default messagesSlice.reducer
```
### 核心思想总结
1. **职责单一**: 此 Slice 只做一件事——管理 `Message` 实体。它像一个数据库表,高效地处理增删改查,但对业务逻辑(如对话顺序)一无所知。
2. **逻辑上移**: 所有涉及多个 Slice 的复杂业务逻辑(如发送消息、切换版本)都应封装在 **Thunks** 或其他中间件中。Thunk 作为流程协调者,会 `dispatch` 多个原子化的 Action 给 `messagesSlice``topicsSlice`,以完成一次完整的业务操作并保证数据一致性。
3. **性能保证**: `createEntityAdapter` 内部使用哈希表(对象)来存储实体,确保通过 ID 查询消息的操作为 O(1) 复杂度,性能极佳。
### 旧状态属性迁移
为了完成 `messagesSlice` 向纯粹"消息池"的演进,原有的混合状态属性需要被迁移或废弃,以实现彻底的职责分离。
| 原属性 (`newMessage.ts`) | 处理方式 | 新的归宿 / 说明 |
| :----------------------- | :------------ | :-------------------------------------------------------------------------------------------- |
| `messageIdsByTopic` | **废弃** | 核心职责转移。由 `topicsSlice` 中的 `activeMessageIds` 字段接管,作为渲染快照。 |
| `currentTopicId` | **迁移** | 属于UI当前上下文状态应迁移至 `topicsSlice`。 |
| `loadingByTopic` | **迁移** | 话题的加载状态与话题本身更相关,应迁移至 `topicsSlice`。 |
| `displayCount` | **废弃/迁移** | UI相关的显示逻辑不属于消息数据层。建议迁移至专门的 `Slice` 或在相关组件中作为本地状态管理。 |

Binary file not shown.

After

Width:  |  Height:  |  Size: 122 KiB

View File

@@ -107,9 +107,11 @@ afterSign: scripts/notarize.js
artifactBuildCompleted: scripts/artifact-build-completed.js
releaseInfo:
releaseNotes: |
- 新功能:可选数据保存目录
- 快捷助手:支持单独选择助手,支持暂停、上下文、思考过程、流式
- 划词助手:系统托盘菜单开关
- 翻译:新增 Markdown 预览选项
- 新供应商:新增 Vertex AI 服务商
- 错误修复和界面优化
划词助手:支持文本选择快捷键、开关快捷键、思考块支持和引用功能
复制功能新增纯文本复制去除Markdown格式符号
知识库支持设置向量维度修复Ollama分数错误和维度编辑问题
多语言:增加模型名称多语言提示和翻译源语言手动选择
文件管理:修复主题/消息删除时文件未清理问题,优化文件选择流程
模型修复Gemini模型推理预算、Voyage AI嵌入问题和DeepSeek翻译模型更新
图像功能统一图片查看器支持Base64图片渲染修复图片预览相关问题
UI实现标签折叠/拖拽排序,修复气泡溢出,增加引文索引显示

View File

@@ -19,13 +19,7 @@ export default defineConfig({
},
build: {
rollupOptions: {
external: ['@libsql/client', 'bufferutil', 'utf-8-validate'],
output: {
// 彻底禁用代码分割 - 返回 null 强制单文件打包
manualChunks: undefined,
// 内联所有动态导入,这是关键配置
inlineDynamicImports: true
}
external: ['@libsql/client', 'bufferutil', 'utf-8-validate']
},
sourcemap: process.env.NODE_ENV === 'development'
},
@@ -68,16 +62,12 @@ export default defineConfig({
}
},
optimizeDeps: {
exclude: ['pyodide'],
esbuildOptions: {
target: 'esnext' // for dev
}
exclude: ['pyodide']
},
worker: {
format: 'es'
},
build: {
target: 'esnext', // for build
rollupOptions: {
input: {
index: resolve(__dirname, 'src/renderer/index.html'),

View File

@@ -1,6 +1,6 @@
{
"name": "CherryStudio",
"version": "1.4.4",
"version": "1.4.2",
"private": true,
"description": "A powerful AI assistant for producer.",
"main": "./out/main/index.js",
@@ -58,22 +58,6 @@
"prepare": "husky"
},
"dependencies": {
"@libsql/client": "0.14.0",
"@libsql/win32-x64-msvc": "^0.4.7",
"@strongtz/win32-arm64-msvc": "^0.4.7",
"jsdom": "26.1.0",
"notion-helper": "^1.3.22",
"os-proxy-config": "^1.1.2",
"selection-hook": "^0.9.23",
"turndown": "7.2.0"
},
"devDependencies": {
"@agentic/exa": "^7.3.3",
"@agentic/searxng": "^7.3.3",
"@agentic/tavily": "^7.3.3",
"@ant-design/v5-patch-for-react-19": "^1.0.3",
"@anthropic-ai/sdk": "^0.41.0",
"@cherrystudio/ai-core": "workspace:*",
"@cherrystudio/embedjs": "^0.1.31",
"@cherrystudio/embedjs-libsql": "^0.1.31",
"@cherrystudio/embedjs-loader-csv": "^0.1.31",
@@ -86,20 +70,54 @@
"@cherrystudio/embedjs-loader-xml": "^0.1.31",
"@cherrystudio/embedjs-ollama": "^0.1.31",
"@cherrystudio/embedjs-openai": "^0.1.31",
"@electron-toolkit/utils": "^3.0.0",
"@langchain/community": "^0.3.36",
"@langchain/ollama": "^0.2.1",
"@strongtz/win32-arm64-msvc": "^0.4.7",
"@tanstack/react-query": "^5.27.0",
"@types/react-infinite-scroll-component": "^5.0.0",
"archiver": "^7.0.1",
"async-mutex": "^0.5.0",
"diff": "^7.0.0",
"docx": "^9.0.2",
"electron-log": "^5.1.5",
"electron-store": "^8.2.0",
"electron-updater": "6.6.4",
"electron-window-state": "^5.0.3",
"epub": "patch:epub@npm%3A1.3.0#~/.yarn/patches/epub-npm-1.3.0-8325494ffe.patch",
"fast-xml-parser": "^5.2.0",
"franc-min": "^6.2.0",
"fs-extra": "^11.2.0",
"jsdom": "^26.0.0",
"markdown-it": "^14.1.0",
"node-stream-zip": "^1.15.0",
"officeparser": "^4.1.1",
"os-proxy-config": "^1.1.2",
"proxy-agent": "^6.5.0",
"remove-markdown": "^0.6.2",
"selection-hook": "^0.9.23",
"tar": "^7.4.3",
"turndown": "^7.2.0",
"webdav": "^5.8.0",
"zipread": "^1.3.3"
},
"devDependencies": {
"@agentic/exa": "^7.3.3",
"@agentic/searxng": "^7.3.3",
"@agentic/tavily": "^7.3.3",
"@ant-design/v5-patch-for-react-19": "^1.0.3",
"@anthropic-ai/sdk": "^0.41.0",
"@electron-toolkit/eslint-config-prettier": "^3.0.0",
"@electron-toolkit/eslint-config-ts": "^3.0.0",
"@electron-toolkit/preload": "^3.0.0",
"@electron-toolkit/tsconfig": "^1.0.1",
"@electron-toolkit/utils": "^3.0.0",
"@electron/notarize": "^2.5.0",
"@emotion/is-prop-valid": "^1.3.1",
"@eslint-react/eslint-plugin": "^1.36.1",
"@eslint/js": "^9.22.0",
"@google/genai": "patch:@google/genai@npm%3A1.0.1#~/.yarn/patches/@google-genai-npm-1.0.1-e26f0f9af7.patch",
"@google/genai": "^1.0.1",
"@hello-pangea/dnd": "^16.6.0",
"@kangfenmao/keyv-storage": "^0.1.0",
"@langchain/community": "^0.3.36",
"@langchain/ollama": "^0.2.1",
"@modelcontextprotocol/sdk": "^1.11.4",
"@mozilla/readability": "^0.6.0",
"@notionhq/client": "^2.2.15",
@@ -107,7 +125,6 @@
"@reduxjs/toolkit": "^2.2.5",
"@shikijs/markdown-it": "^3.4.2",
"@swc/plugin-styled-components": "^7.1.5",
"@tanstack/react-query": "^5.27.0",
"@testing-library/dom": "^10.4.0",
"@testing-library/jest-dom": "^6.6.3",
"@testing-library/react": "^16.3.0",
@@ -134,37 +151,24 @@
"@vitest/web-worker": "^3.1.4",
"@xyflow/react": "^12.4.4",
"antd": "^5.22.5",
"archiver": "^7.0.1",
"async-mutex": "^0.5.0",
"axios": "^1.7.3",
"browser-image-compression": "^2.0.2",
"color": "^5.0.0",
"dayjs": "^1.11.11",
"dexie": "^4.0.8",
"dexie-react-hooks": "^1.1.7",
"diff": "^7.0.0",
"docx": "^9.0.2",
"dotenv-cli": "^7.4.2",
"electron": "35.4.0",
"electron-builder": "26.0.15",
"electron-devtools-installer": "^3.2.0",
"electron-log": "^5.1.5",
"electron-store": "^8.2.0",
"electron-updater": "6.6.4",
"electron-vite": "^3.1.0",
"electron-window-state": "^5.0.3",
"emittery": "^1.0.3",
"emoji-picker-element": "^1.22.1",
"epub": "patch:epub@npm%3A1.3.0#~/.yarn/patches/epub-npm-1.3.0-8325494ffe.patch",
"eslint": "^9.22.0",
"eslint-plugin-react-hooks": "^5.2.0",
"eslint-plugin-simple-import-sort": "^12.1.1",
"eslint-plugin-unused-imports": "^4.1.4",
"fast-diff": "^1.3.0",
"fast-xml-parser": "^5.2.0",
"franc-min": "^6.2.0",
"fs-extra": "^11.2.0",
"google-auth-library": "^9.15.1",
"html-to-image": "^1.11.13",
"husky": "^9.1.7",
"i18next": "^23.11.5",
@@ -173,25 +177,21 @@
"lodash": "^4.17.21",
"lru-cache": "^11.1.0",
"lucide-react": "^0.487.0",
"markdown-it": "^14.1.0",
"mermaid": "^11.6.0",
"mime": "^4.0.4",
"motion": "^12.10.5",
"node-stream-zip": "^1.15.0",
"npx-scope-finder": "^1.2.0",
"officeparser": "^4.1.1",
"openai": "patch:openai@npm%3A5.1.0#~/.yarn/patches/openai-npm-5.1.0-0e7b3ccb07.patch",
"p-queue": "^8.1.0",
"playwright": "^1.52.0",
"prettier": "^3.5.3",
"proxy-agent": "^6.5.0",
"rc-virtual-list": "^3.18.6",
"react": "^19.0.0",
"react-dom": "^19.0.0",
"react-hotkeys-hook": "^4.6.1",
"react-i18next": "^14.1.2",
"react-infinite-scroll-component": "^6.1.0",
"react-markdown": "^10.1.0",
"react-markdown": "^9.0.1",
"react-redux": "^9.1.2",
"react-router": "6",
"react-router-dom": "6",
@@ -200,26 +200,22 @@
"redux": "^5.0.1",
"redux-persist": "^6.0.0",
"rehype-katex": "^7.0.1",
"rehype-mathjax": "^7.1.0",
"rehype-mathjax": "^7.0.0",
"rehype-raw": "^7.0.0",
"remark-cjk-friendly": "^1.2.0",
"remark-gfm": "^4.0.1",
"remark-cjk-friendly": "^1.1.0",
"remark-gfm": "^4.0.0",
"remark-math": "^6.0.0",
"remove-markdown": "^0.6.2",
"rollup-plugin-visualizer": "^5.12.0",
"sass": "^1.88.0",
"shiki": "^3.4.2",
"string-width": "^7.2.0",
"styled-components": "^6.1.11",
"tar": "^7.4.3",
"tiny-pinyin": "^1.3.2",
"tokenx": "^0.4.1",
"typescript": "^5.6.2",
"uuid": "^10.0.0",
"vite": "6.2.6",
"vitest": "^3.1.4",
"webdav": "^5.8.0",
"zipread": "^1.3.3"
"vitest": "^3.1.4"
},
"resolutions": {
"pdf-parse@npm:1.1.1": "patch:pdf-parse@npm%3A1.1.1#~/.yarn/patches/pdf-parse-npm-1.1.1-04a6109b2a.patch",

View File

@@ -1,622 +0,0 @@
# Cherry Studio AI Core 基于 Vercel AI SDK 的技术架构
## 1. 架构设计理念
### 1.1 设计目标
- **统一接口**:使用 Vercel AI SDK 统一不同 AI Provider 的接口差异
- **动态导入**:通过动态导入实现按需加载,减少打包体积
- **最小包装**:直接使用 AI SDK 的类型和接口,避免重复定义
- **插件系统**:基于钩子的插件架构,支持请求全生命周期扩展
- **类型安全**:利用 TypeScript 和 AI SDK 的类型系统确保类型安全
- **轻量级**:专注核心功能,保持包的轻量和高效
- **包级独立**:作为独立包管理,便于复用和维护
### 1.2 核心优势
- **标准化**AI SDK 提供统一的模型接口,减少适配工作
- **简化维护**:废弃复杂的 XxxApiClient统一为工厂函数模式
- **更好的开发体验**:完整的 TypeScript 支持和丰富的生态系统
- **性能优化**AI SDK 内置优化和最佳实践
- **模块化设计**:独立包结构,支持跨项目复用
- **可扩展插件**:基于钩子的插件系统,支持灵活的功能扩展和流转换
## 2. 整体架构图
```mermaid
graph TD
subgraph "Cherry Studio 主应用"
UI["用户界面"]
Components["React 组件"]
end
subgraph "packages/aiCore (AI Core 包)"
ApiClientFactory["ApiClientFactory (工厂类)"]
UniversalClient["UniversalAiSdkClient (统一客户端)"]
ProviderRegistry["Provider 注册表"]
PluginManager["插件管理器"]
end
subgraph "动态导入层"
DynamicImport["动态导入"]
end
subgraph "Vercel AI SDK"
AICore["ai (核心库)"]
OpenAI["@ai-sdk/openai"]
Anthropic["@ai-sdk/anthropic"]
Google["@ai-sdk/google"]
XAI["@ai-sdk/xai"]
Others["其他 19+ Providers"]
end
subgraph "插件生态"
FirstHooks["First Hooks (resolveModel, loadTemplate)"]
SequentialHooks["Sequential Hooks (transformParams, transformResult)"]
ParallelHooks["Parallel Hooks (onRequestStart, onRequestEnd, onError)"]
StreamHooks["Stream Hooks (transformStream)"]
end
UI --> ApiClientFactory
Components --> ApiClientFactory
ApiClientFactory --> UniversalClient
UniversalClient --> PluginManager
PluginManager --> ProviderRegistry
ProviderRegistry --> DynamicImport
DynamicImport --> OpenAI
DynamicImport --> Anthropic
DynamicImport --> Google
DynamicImport --> XAI
DynamicImport --> Others
UniversalClient --> AICore
AICore --> streamText
AICore --> generateText
PluginManager --> FirstHooks
PluginManager --> SequentialHooks
PluginManager --> ParallelHooks
PluginManager --> StreamHooks
```
## 3. 包结构设计
### 3.1 包级文件结构(当前简化版 + 规划)
```
packages/aiCore/
├── src/
│ ├── providers/
│ │ ├── registry.ts # Provider 注册表 ✅
│ │ └── types.ts # 核心类型定义 ✅
│ ├── clients/
│ │ ├── UniversalAiSdkClient.ts # 统一AI SDK客户端 ✅
│ │ └── ApiClientFactory.ts # 客户端工厂 ✅
│ ├── middleware/ # 插件系统 ✅
│ │ ├── types.ts # 插件类型定义 ✅
│ │ ├── manager.ts # 插件管理器 ✅
│ │ ├── examples/ # 示例插件 ✅
│ │ │ ├── example-plugins.ts # 示例插件实现 ✅
│ │ │ └── example-usage.ts # 使用示例 ✅
│ │ ├── README.md # 插件系统文档 ✅
│ │ └── index.ts # 插件模块入口 ✅
│ ├── services/ # 高级服务 (规划中)
│ │ ├── AiCoreService.ts # 统一服务入口
│ │ ├── CompletionsService.ts # 文本生成服务
│ │ ├── EmbeddingService.ts # 嵌入服务
│ │ └── ImageService.ts # 图像生成服务
│ └── index.ts # 包主入口文件 ✅
├── package.json # 包配置文件 ✅
├── tsconfig.json # TypeScript 配置 ✅
├── README.md # 包说明文档 ✅
└── AI_SDK_ARCHITECTURE.md # 本文档 ✅
```
**图例:**
- ✅ 已实现
- 规划中:设计完成,待实现
## 4. 核心组件详解
### 4.1 Provider 注册表 (`providers/registry.ts`)
统一管理所有 AI Provider 的注册和动态导入。
**主要功能:**
- 动态导入 AI SDK providers
- 提供统一的 Provider 创建接口
- 支持 19+ 官方 AI SDK providers
- 类型安全的 Provider 配置
**核心 API**
```typescript
export interface ProviderConfig {
id: string
name: string
import: () => Promise<any>
creatorFunctionName: string
}
export class AiProviderRegistry {
getProvider(id: string): ProviderConfig | undefined
getAllProviders(): ProviderConfig[]
isSupported(id: string): boolean
registerProvider(config: ProviderConfig): void
}
```
**支持的 Providers**
- OpenAI, Anthropic, Google, XAI
- Azure OpenAI, Amazon Bedrock, Google Vertex
- Groq, Together.ai, Fireworks, DeepSeek
- Cerebras, DeepInfra, Replicate, Perplexity
- Cohere, Fal AI, Vercel (19+ providers)
### 4.2 统一AI SDK客户端 (`clients/UniversalAiSdkClient.ts`)
将不同 AI providers 包装为统一接口。
**主要功能:**
- 异步初始化和动态加载
- 统一的 stream() 和 generate() 方法
- 直接使用 AI SDK 的 streamText() 和 generateText()
- 配置验证和错误处理
**核心 API**
```typescript
export class UniversalAiSdkClient {
async initialize(): Promise<void>
isInitialized(): boolean
async stream(request: any): Promise<any>
async generate(request: any): Promise<any>
validateConfig(): boolean
getProviderInfo(): { id: string; name: string; isInitialized: boolean }
}
```
### 4.3 客户端工厂 (`clients/ApiClientFactory.ts`)
统一创建和管理 AI SDK 客户端。
**主要功能:**
- 统一的客户端创建接口
- 智能缓存和复用机制
- 批量创建和健康检查
- 错误处理和重试
**核心 API**
```typescript
export class ApiClientFactory {
static async createAiSdkClient(providerId: string, options: any): Promise<UniversalAiSdkClient>
static getCachedClient(providerId: string, options: any): UniversalAiSdkClient | undefined
static clearCache(): void
static async healthCheck(): Promise<HealthCheckResult>
static getSupportedProviders(): ProviderInfo[]
}
```
### 4.4 钩子风格插件系统 ✅
基于钩子机制的插件架构设计,提供灵活的扩展系统。
**钩子类型:**
1. **First Hooks**:执行到第一个有效结果就停止
2. **Sequential Hooks**:按序链式执行,可变换数据
3. **Parallel Hooks**:并发执行,用于副作用
4. **Stream Hooks**:流转换,直接传递给 AI SDK
**优先级系统:**
- `pre`:前置处理(-100 到 -1
- `normal`标准处理0 到 99
- `post`后置处理100 到 199
**核心钩子:**
**First Hooks (第一个有效结果)**
- `resolveModel`:模型解析,返回第一个匹配的模型
- `loadTemplate`:模板加载,返回第一个找到的模板
**Sequential Hooks (链式变换)**
- `transformParams`:参数转换,依次变换请求参数
- `transformResult`:结果转换,依次变换响应结果
**Parallel Hooks (并发副作用)**
- `onRequestStart`:请求开始时触发
- `onRequestEnd`:请求结束时触发
- `onError`:错误发生时触发
**Stream Hooks (流转换)**
- `transformStream`:流转换,返回 AI SDK 转换函数
**插件 API 设计:**
```typescript
export interface Plugin {
name: string
enforce?: 'pre' | 'normal' | 'post'
// First hooks - 执行到第一个有效结果
resolveModel?(params: ResolveModelParams): Promise<string | null>
loadTemplate?(params: LoadTemplateParams): Promise<Template | null>
// Sequential hooks - 链式变换
transformParams?(params: any, context: PluginContext): Promise<any>
transformResult?(result: any, context: PluginContext): Promise<any>
// Parallel hooks - 并发副作用
onRequestStart?(context: PluginContext): Promise<void>
onRequestEnd?(context: PluginContext): Promise<void>
onError?(error: Error, context: PluginContext): Promise<void>
// Stream hooks - AI SDK 流转换
transformStream?(context: PluginContext): Promise<(readable: ReadableStream) => ReadableStream>
}
export interface PluginContext {
request: any
response?: any
metadata: Record<string, any>
provider: string
model: string
}
export class PluginManager {
use(plugin: Plugin): this
executeFirstHook<T>(hookName: string, ...args: any[]): Promise<T | null>
executeSequentialHook<T>(hookName: string, initialValue: T, context: PluginContext): Promise<T>
executeParallelHook(hookName: string, ...args: any[]): Promise<void>
collectStreamTransforms(context: PluginContext): Promise<Array<(readable: ReadableStream) => ReadableStream>>
}
```
### 4.5 统一服务接口 (规划中)
作为包的主要对外接口,提供高级 AI 功能。
**服务方法:**
- `completions()`: 文本生成
- `streamCompletions()`: 流式文本生成
- `generateObject()`: 结构化数据生成
- `generateImage()`: 图像生成
- `embed()`: 文本嵌入
**API 设计:**
```typescript
export class AiCoreService {
constructor(middlewares?: Middleware[])
async completions(request: CompletionRequest): Promise<CompletionResponse>
async streamCompletions(request: CompletionRequest): Promise<StreamCompletionResponse>
async generateObject<T>(request: ObjectGenerationRequest): Promise<T>
async generateImage(request: ImageGenerationRequest): Promise<ImageResponse>
async embed(request: EmbeddingRequest): Promise<EmbeddingResponse>
use(middleware: Middleware): this
configure(config: AiCoreConfig): this
}
```
## 5. 使用方式
### 5.1 多 Provider 支持
```typescript
import { createAiSdkClient, AiCore } from '@cherrystudio/ai-core'
// 检查支持的 providers
const providers = AiCore.getSupportedProviders()
console.log(`支持 ${providers.length} 个 AI providers`)
// 创建多个 provider 客户端
const openai = await createAiSdkClient('openai', { apiKey: 'openai-key' })
const anthropic = await createAiSdkClient('anthropic', { apiKey: 'anthropic-key' })
const google = await createAiSdkClient('google', { apiKey: 'google-key' })
const xai = await createAiSdkClient('xai', { apiKey: 'xai-key' })
```
### 5.2 在 Cherry Studio 中集成
```typescript
// 替换现有的 XxxApiClient
// 之前:
// const openaiClient = new OpenAIApiClient(config)
// const anthropicClient = new AnthropicApiClient(config)
// 现在:
import { createAiSdkClient } from '@cherrystudio/ai-core'
const createProviderClient = async (provider: CherryProvider) => {
return await createAiSdkClient(provider.id, {
apiKey: provider.apiKey,
baseURL: provider.baseURL
})
}
```
### 5.6 完整的工作流示例 (规划中)
```typescript
import {
createAiSdkClient,
AiCoreService,
MiddlewareChain,
PreRequestMiddleware,
StreamProcessingMiddleware,
PostResponseMiddleware
} from '@cherrystudio/ai-core'
// 创建完整的工作流
const createEnhancedAiService = async () => {
// 创建中间件链
const middlewareChain = new MiddlewareChain()
.use(
new PreRequestMiddleware({
validateApiKey: true,
checkRateLimit: true
})
)
.use(
new StreamProcessingMiddleware({
enableProgressTracking: true,
chunkTransform: (chunk) => ({
...chunk,
timestamp: Date.now()
})
})
)
.use(
new PostResponseMiddleware({
saveToHistory: true,
calculateMetrics: true
})
)
// 创建服务实例
const service = new AiCoreService(middlewareChain.middlewares)
return service
}
// 使用增强服务
const enhancedService = await createEnhancedAiService()
const response = await enhancedService.completions({
provider: 'anthropic',
model: 'claude-3-sonnet',
messages: [{ role: 'user', content: 'Write a technical blog post about AI middleware' }],
options: {
temperature: 0.7,
maxTokens: 2000
},
middleware: {
// 中间件特定配置
thinking: { recordSteps: true },
cache: { enabled: true, ttl: 1800 },
logging: { level: 'debug' }
}
})
```
## 6. 简化设计原则
### 6.1 最小包装原则
- 直接使用 AI SDK 的类型,不重复定义
- 避免过度抽象和复杂的中间层
- 保持与 AI SDK 原生 API 的一致性
### 6.2 动态导入优化
```typescript
// 按需加载,减少打包体积
const module = await import('@ai-sdk/openai')
const createOpenAI = module.createOpenAI
```
### 6.3 类型安全
```typescript
// 直接使用 AI SDK 类型
import { streamText, generateText } from 'ai'
// 避免重复定义,直接传递参数
return streamText({ model, ...request })
```
### 6.4 配置简化
```typescript
// 简化的 Provider 配置
interface ProviderConfig {
id: string // provider 标识
name: string // 显示名称
import: () => Promise<any> // 动态导入函数
creatorFunctionName: string // 创建函数名
}
```
## 7. 技术要点
### 7.1 动态导入策略
- **按需加载**:只加载用户实际使用的 providers
- **缓存机制**:避免重复导入和初始化
- **错误处理**:优雅处理导入失败的情况
### 7.2 依赖管理策略
- **核心依赖**`ai` 库作为必需依赖
- **可选依赖**:所有 `@ai-sdk/*` 包都是可选的
- **版本兼容**:支持 AI SDK v3-v5 版本
### 7.3 缓存策略
- **客户端缓存**:基于 provider + options 的智能缓存
- **配置哈希**:安全的 API key 哈希处理
- **生命周期管理**:支持缓存清理和验证
## 8. 迁移策略
### 8.1 阶段一:包基础搭建 (Week 1) ✅ 已完成
1. ✅ 创建简化的包结构
2. ✅ 实现 Provider 注册表
3. ✅ 创建统一客户端和工厂
4. ✅ 配置构建和类型系统
### 8.2 阶段二:核心功能完善 (Week 2) ✅ 已完成
1. ✅ 支持 19+ 官方 AI SDK providers
2. ✅ 实现缓存和错误处理
3. ✅ 完善类型安全和 API 设计
4. ✅ 添加便捷函数和工具
### 8.3 阶段三:集成测试 (Week 3) 🔄 进行中
1. 在 Cherry Studio 中集成测试
2. 功能完整性验证
3. 性能基准测试
4. 兼容性问题修复
### 8.4 阶段四:插件系统实现 ✅ 已完成
1. **插件核心架构**
- 实现 `PluginManager``PluginContext`
- 创建钩子风格插件接口和类型系统
- 建立四种钩子类型执行机制
2. **钩子系统**
- `First Hooks`:第一个有效结果执行
- `Sequential Hooks`:链式数据变换
- `Parallel Hooks`:并发副作用处理
- `Stream Hooks`AI SDK 流转换集成
3. **优先级和排序**
- `pre`/`normal`/`post` 优先级系统
- 插件注册顺序维护
- 错误处理和插件隔离
4. **集成到现有架构**
-`UniversalAiSdkClient` 中集成插件管理器
- 更新 `ApiClientFactory` 支持插件配置
- 创建示例插件和使用文档
### 8.5 阶段五:特性插件扩展 (规划中)
1. **Cherry Studio 特性插件**
- `ThinkingPlugin`:思考过程记录和提取
- `ToolCallPlugin`:工具调用处理和增强
- `WebSearchPlugin`:网络搜索集成
2. **高级功能**
- 插件组合和条件执行
- 动态插件加载系统
- 插件配置管理和持久化
### 8.6 阶段六:文档和发布 (Week 7) 📋 规划中
1. 完善使用文档和示例
2. 插件开发指南和最佳实践
3. 准备发布到 npm
4. 建立维护流程
### 8.7 阶段七:生态系统扩展 (Week 8+) 🚀 未来规划
1. 社区插件生态系统
2. 可视化插件编排工具
3. 性能监控和分析
4. 高级缓存和优化策略
## 9. 预期收益
### 9.1 开发效率提升
- **90%** 减少新 Provider 接入时间(只需添加注册表配置)
- **70%** 减少维护工作量
- **95%** 提升开发体验(统一接口 + 类型安全)
- **独立开发**:可以独立于主应用开发和测试
### 9.2 代码质量改善
- 完整的 TypeScript 类型安全
- 统一的错误处理机制
- 标准化的 AI SDK 接口
- 更好的测试覆盖率
### 9.3 架构优势
- **轻量级**:最小化的包装层
- **可复用**:其他项目可以直接使用
- **可维护**:独立版本管理和发布
- **可扩展**:新 provider 只需配置即可
### 9.4 生态系统价值
- 支持 AI SDK 的完整生态系统
- 可以独立发布到 npm
- 为开源社区贡献价值
- 建立统一的 AI 基础设施
## 10. 风险评估与应对
### 10.1 技术风险
- **AI SDK 版本兼容**:支持多版本兼容策略
- **依赖管理**:合理使用 peerDependencies
- **类型一致性**:直接使用 AI SDK 类型
- **性能影响**:最小化包装层开销
### 10.2 迁移风险
- **功能对等性**:确保所有现有功能都能实现
- **API 兼容性**:提供平滑的迁移路径
- **集成复杂度**:保持简单的集成方式
- **学习成本**:提供清晰的使用文档
## 11. 总结
简化的 AI Core 架构专注于核心价值:
### 11.1 核心价值
- **统一接口**:一套 API 支持 19+ AI providers
- **按需加载**:只打包用户实际使用的 providers
- **类型安全**:完整的 TypeScript 支持
- **轻量高效**:最小化的包装层
### 11.2 设计哲学
- **直接使用 AI SDK**:避免重复造轮子,充分利用原生能力
- **最小包装**:只在必要时添加抽象层,保持轻量高效
- **开发者友好**:简单易用的 API 设计,熟悉的钩子风格
- **生态兼容**:充分利用 AI SDK 生态系统和原生流转换
- **插件优先**:基于钩子的扩展模式,支持灵活组合
### 11.3 成功关键
1. **保持简单**:专注核心功能,避免过度设计
2. **充分测试**:确保功能完整性和稳定性
3. **渐进迁移**:平滑过渡,降低风险
4. **文档完善**:支持快速上手和深度使用
这个基于钩子的插件系统架构为 Cherry Studio 提供了一个轻量、高效、可维护的 AI 基础设施,通过熟悉的钩子模式和原生 AI SDK 集成,为开发者提供了强大而简洁的扩展能力,同时为社区贡献了一个高质量的开源包。

View File

@@ -1,222 +0,0 @@
# @cherrystudio/ai-core
Cherry Studio AI Core 是一个基于 Vercel AI SDK 的统一 AI Provider 接口包。
## 特性
- 🚀 统一的 AI Provider 接口
- 🔄 动态导入支持
- 💾 智能缓存机制
- 🛠️ TypeScript 支持
- 📦 轻量级设计
## 支持的 Providers
基于 [AI SDK 官方支持的 providers](https://ai-sdk.dev/providers/ai-sdk-providers)
**核心 Providers:**
- OpenAI
- Anthropic
- Google Generative AI
- Google Vertex AI
- Mistral AI
- xAI (Grok)
- Azure OpenAI
- Amazon Bedrock
**扩展 Providers:**
- Cohere
- Groq
- Together.ai
- Fireworks
- DeepSeek
- Cerebras
- DeepInfra
- Replicate
- Perplexity
- Fal AI
- Vercel
## 安装
```bash
npm install @cherrystudio/ai-core ai
```
还需要安装你要使用的 AI SDK provider:
```bash
npm install @ai-sdk/openai @ai-sdk/anthropic @ai-sdk/google
```
## 使用示例
### 基础用法
```typescript
import { createAiSdkClient } from '@cherrystudio/ai-core'
// 创建 OpenAI 客户端
const client = await createAiSdkClient('openai', {
apiKey: 'your-api-key'
})
// 流式生成
const result = await client.stream({
modelId: 'gpt-4',
messages: [{ role: 'user', content: 'Hello!' }]
})
// 非流式生成
const response = await client.generate({
modelId: 'gpt-4',
messages: [{ role: 'user', content: 'Hello!' }]
})
```
### 便捷函数
```typescript
import { createOpenAIClient, streamGeneration } from '@cherrystudio/ai-core'
// 快速创建 OpenAI 客户端
const client = await createOpenAIClient({
apiKey: 'your-api-key'
})
// 便捷流式生成
const result = await streamGeneration('openai', 'gpt-4', [{ role: 'user', content: 'Hello!' }], {
apiKey: 'your-api-key'
})
```
### 多 Provider 支持
```typescript
import { createAiSdkClient } from '@cherrystudio/ai-core'
// 支持多种 AI providers
const openaiClient = await createAiSdkClient('openai', { apiKey: 'openai-key' })
const anthropicClient = await createAiSdkClient('anthropic', { apiKey: 'anthropic-key' })
const googleClient = await createAiSdkClient('google', { apiKey: 'google-key' })
const xaiClient = await createAiSdkClient('xai', { apiKey: 'xai-key' })
```
### 使用 AI SDK 原生 Provider 注册表
> https://ai-sdk.dev/docs/reference/ai-sdk-core/provider-registry
除了使用内建的 provider 管理,你还可以使用 AI SDK 原生的 `createProviderRegistry` 来构建自己的 provider 注册表。
#### 基本用法示例
```typescript
import { createClient } from '@cherrystudio/ai-core'
import { createProviderRegistry } from 'ai'
import { createOpenAI } from '@ai-sdk/openai'
import { anthropic } from '@ai-sdk/anthropic'
// 1. 创建 AI SDK 原生注册表
export const registry = createProviderRegistry({
// register provider with prefix and default setup:
anthropic,
// register provider with prefix and custom setup:
openai: createOpenAI({
apiKey: process.env.OPENAI_API_KEY
})
})
// 2. 创建client,'openai'可以传空或者传providerId(内建的provider)
const client = PluginEnabledAiClient.create('openai', {
apiKey: process.env.OPENAI_API_KEY
})
// 3. 方式1使用内建逻辑传统方式
const result1 = await client.streamText('gpt-4', {
messages: [{ role: 'user', content: 'Hello with built-in logic!' }]
})
// 4. 方式2使用自定义注册表灵活方式
const result2 = await client.streamText({
model: registry.languageModel('openai:gpt-4'),
messages: [{ role: 'user', content: 'Hello with custom registry!' }]
})
// 5. 支持的重载方法
await client.generateObject({
model: registry.languageModel('openai:gpt-4'),
schema: z.object({ name: z.string() }),
messages: [{ role: 'user', content: 'Generate a user' }]
})
await client.streamObject({
model: registry.languageModel('anthropic:claude-3-opus-20240229'),
schema: z.object({ items: z.array(z.string()) }),
messages: [{ role: 'user', content: 'Generate a list' }]
})
```
#### 与插件系统配合使用
更强大的是,你还可以将自定义注册表与 Cherry Studio 的插件系统结合使用:
```typescript
import { PluginEnabledAiClient } from '@cherrystudio/ai-core'
import { createProviderRegistry } from 'ai'
import { createOpenAI } from '@ai-sdk/openai'
import { anthropic } from '@ai-sdk/anthropic'
// 1. 创建带插件的客户端
const client = PluginEnabledAiClient.create(
'openai',
{
apiKey: process.env.OPENAI_API_KEY
},
[LoggingPlugin, RetryPlugin]
)
// 2. 创建自定义注册表
const registry = createProviderRegistry({
openai: createOpenAI({ apiKey: process.env.OPENAI_API_KEY }),
anthropic: anthropic({ apiKey: process.env.ANTHROPIC_API_KEY })
})
// 3. 方式1使用内建逻辑 + 完整插件系统
await client.streamText('gpt-4', {
messages: [{ role: 'user', content: 'Hello with plugins!' }]
})
// 4. 方式2使用自定义注册表 + 有限插件支持
await client.streamText({
model: registry.languageModel('anthropic:claude-3-opus-20240229'),
messages: [{ role: 'user', content: 'Hello from Claude!' }]
})
// 5. 支持的方法
await client.generateObject({
model: registry.languageModel('openai:gpt-4'),
schema: z.object({ name: z.string() }),
messages: [{ role: 'user', content: 'Generate a user' }]
})
await client.streamObject({
model: registry.languageModel('openai:gpt-4'),
schema: z.object({ items: z.array(z.string()) }),
messages: [{ role: 'user', content: 'Generate a list' }]
})
```
#### 混合使用的优势
- **灵活性**:可以根据需要选择使用内建逻辑或自定义注册表
- **兼容性**:完全兼容 AI SDK 的 `createProviderRegistry` API
- **渐进式**:可以逐步迁移现有代码,无需一次性重构
- **插件支持**:自定义注册表仍可享受 Cherry Studio 插件系统的部分功能
- **最佳实践**:结合两种方式的优点,既有动态加载的性能优势,又有统一注册表的便利性
## License
MIT

View File

@@ -1,124 +0,0 @@
{
"name": "@cherrystudio/ai-core",
"version": "1.0.0",
"description": "Cherry Studio AI Core - Unified AI Provider Interface Based on Vercel AI SDK",
"main": "src/index.ts",
"types": "src/index.ts",
"scripts": {
"build": "tsdown",
"dev": "tsc -w",
"clean": "rm -rf dist"
},
"keywords": [
"ai",
"sdk",
"openai",
"anthropic",
"google",
"cherry-studio",
"vercel-ai-sdk"
],
"author": "Cherry Studio",
"license": "MIT",
"dependencies": {
"@ai-sdk/amazon-bedrock": "^2.2.10",
"@ai-sdk/anthropic": "^1.2.12",
"@ai-sdk/azure": "^1.3.23",
"@ai-sdk/cerebras": "^0.2.14",
"@ai-sdk/cohere": "^1.2.10",
"@ai-sdk/deepinfra": "^0.2.15",
"@ai-sdk/deepseek": "^0.2.14",
"@ai-sdk/fal": "^0.1.12",
"@ai-sdk/fireworks": "^0.2.14",
"@ai-sdk/google": "^1.2.19",
"@ai-sdk/google-vertex": "^2.2.24",
"@ai-sdk/groq": "^1.2.9",
"@ai-sdk/mistral": "^1.2.8",
"@ai-sdk/openai": "^1.3.22",
"@ai-sdk/openai-compatible": "^0.2.14",
"@ai-sdk/perplexity": "^1.1.9",
"@ai-sdk/replicate": "^0.2.8",
"@ai-sdk/togetherai": "^0.2.14",
"@ai-sdk/vercel": "^0.0.1",
"@ai-sdk/xai": "^1.2.16",
"@openrouter/ai-sdk-provider": "^0.1.0",
"ai": "^4.3.16",
"anthropic-vertex-ai": "^1.0.2",
"ollama-ai-provider": "^1.2.0",
"qwen-ai-provider": "^0.1.0",
"zhipu-ai-provider": "^0.1.1"
},
"peerDependenciesMeta": {
"@ai-sdk/amazon-bedrock": {
"optional": true
},
"@ai-sdk/anthropic": {
"optional": true
},
"@ai-sdk/azure": {
"optional": true
},
"@ai-sdk/cerebras": {
"optional": true
},
"@ai-sdk/cohere": {
"optional": true
},
"@ai-sdk/deepinfra": {
"optional": true
},
"@ai-sdk/deepseek": {
"optional": true
},
"@ai-sdk/fal": {
"optional": true
},
"@ai-sdk/fireworks": {
"optional": true
},
"@ai-sdk/google": {
"optional": true
},
"@ai-sdk/google-vertex": {
"optional": true
},
"@ai-sdk/groq": {
"optional": true
},
"@ai-sdk/mistral": {
"optional": true
},
"@ai-sdk/openai": {
"optional": true
},
"@ai-sdk/perplexity": {
"optional": true
},
"@ai-sdk/replicate": {
"optional": true
},
"@ai-sdk/together": {
"optional": true
},
"@ai-sdk/vercel": {
"optional": true
},
"@ai-sdk/xai": {
"optional": true
}
},
"devDependencies": {
"tsdown": "^0.12.8",
"typescript": "^5.0.0"
},
"files": [
"dist"
],
"exports": {
".": {
"types": "./src/index.ts",
"import": "./src/index.ts",
"require": "./src/index.ts"
}
}
}

File diff suppressed because it is too large Load Diff

View File

@@ -1,223 +0,0 @@
/**
* API Client Factory
* 整合现有实现的改进版API客户端工厂
*/
import type { ImageModelV1 } from '@ai-sdk/provider'
import { type LanguageModelV1, LanguageModelV1Middleware, wrapLanguageModel } from 'ai'
import { aiProviderRegistry } from '../providers/registry'
import { type ProviderId, type ProviderSettingsMap } from './types'
// 客户端配置接口
export interface ClientConfig {
providerId: string
options?: any
}
// 错误类型
export class ClientFactoryError extends Error {
constructor(
message: string,
public providerId?: string,
public cause?: Error
) {
super(message)
this.name = 'ClientFactoryError'
}
}
/**
* API Client Factory
* 统一管理和创建AI SDK客户端
*/
export class ApiClientFactory {
/**
* 创建 AI SDK 模型实例
* 对于已知的 Provider 使用严格类型检查,未知的 Provider 默认使用 openai-compatible
*/
static async createClient<T extends ProviderId>(
providerId: T,
modelId: string,
options: ProviderSettingsMap[T],
middlewares?: LanguageModelV1Middleware[]
): Promise<LanguageModelV1>
static async createClient(
providerId: string,
modelId: string,
options: ProviderSettingsMap['openai-compatible'],
middlewares?: LanguageModelV1Middleware[]
): Promise<LanguageModelV1>
static async createClient(
providerId: string,
modelId: string = 'default',
options: any,
middlewares?: LanguageModelV1Middleware[]
): Promise<LanguageModelV1> {
try {
// 对于不在注册表中的 provider默认使用 openai-compatible
const effectiveProviderId = aiProviderRegistry.isSupported(providerId) ? providerId : 'openai-compatible'
// 获取Provider配置
const providerConfig = aiProviderRegistry.getProvider(effectiveProviderId)
if (!providerConfig) {
throw new ClientFactoryError(`Provider "${effectiveProviderId}" is not registered`, providerId)
}
// 动态导入模块
const module = await providerConfig.import()
// 获取创建函数
const creatorFunction = module[providerConfig.creatorFunctionName]
if (typeof creatorFunction !== 'function') {
throw new ClientFactoryError(
`Creator function "${providerConfig.creatorFunctionName}" not found in the imported module for provider "${effectiveProviderId}"`
)
}
// 创建provider实例
const provider = creatorFunction(options)
// 返回模型实例
if (typeof provider === 'function') {
let model: LanguageModelV1 = provider(modelId)
// 应用 AI SDK 中间件
if (middlewares && middlewares.length > 0) {
model = wrapLanguageModel({
model: model,
middleware: middlewares
})
}
return model
} else {
throw new ClientFactoryError(`Unknown model access pattern for provider "${effectiveProviderId}"`)
}
} catch (error) {
if (error instanceof ClientFactoryError) {
throw error
}
throw new ClientFactoryError(
`Failed to create client for provider "${providerId}": ${error instanceof Error ? error.message : 'Unknown error'}`,
providerId,
error instanceof Error ? error : undefined
)
}
}
static async createImageClient<T extends ProviderId>(
providerId: T,
modelId: string,
options: ProviderSettingsMap[T]
): Promise<ImageModelV1>
static async createImageClient(
providerId: string,
modelId: string,
options: ProviderSettingsMap['openai-compatible']
): Promise<ImageModelV1>
static async createImageClient(providerId: string, modelId: string = 'default', options: any): Promise<ImageModelV1> {
try {
if (!aiProviderRegistry.isSupported(providerId)) {
throw new ClientFactoryError(`Provider "${providerId}" is not supported`, providerId)
}
const providerConfig = aiProviderRegistry.getProvider(providerId)
if (!providerConfig) {
throw new ClientFactoryError(`Provider "${providerId}" is not registered`, providerId)
}
if (!providerConfig.supportsImageGeneration) {
throw new ClientFactoryError(`Provider "${providerId}" does not support image generation`, providerId)
}
const module = await providerConfig.import()
const creatorFunction = module[providerConfig.creatorFunctionName]
if (typeof creatorFunction !== 'function') {
throw new ClientFactoryError(
`Creator function "${providerConfig.creatorFunctionName}" not found in the imported module for provider "${providerId}"`
)
}
const provider = creatorFunction(options)
if (provider && typeof provider.image === 'function') {
return provider.image(modelId)
} else {
throw new ClientFactoryError(`Image model function not found for provider "${providerId}"`)
}
} catch (error) {
if (error instanceof ClientFactoryError) {
throw error
}
throw new ClientFactoryError(
`Failed to create image client for provider "${providerId}": ${error instanceof Error ? error.message : 'Unknown error'}`,
providerId,
error instanceof Error ? error : undefined
)
}
}
/**
* 获取支持的 Providers 列表
*/
static getSupportedProviders(): Array<{
id: string
name: string
}> {
return aiProviderRegistry.getAllProviders().map((provider) => ({
id: provider.id,
name: provider.name
}))
}
/**
* 获取 Provider 信息
*/
static getClientInfo(providerId: string): {
id: string
name: string
isSupported: boolean
effectiveProvider: string
} {
const effectiveProviderId = aiProviderRegistry.isSupported(providerId) ? providerId : 'openai-compatible'
const provider = aiProviderRegistry.getProvider(effectiveProviderId)
return {
id: providerId,
name: provider?.name || providerId,
isSupported: aiProviderRegistry.isSupported(providerId),
effectiveProvider: effectiveProviderId
}
}
}
// 便捷导出函数
export function createClient<T extends ProviderId>(
providerId: T,
modelId: string,
options: ProviderSettingsMap[T]
): Promise<LanguageModelV1>
export function createClient(
providerId: string,
modelId: string,
options: ProviderSettingsMap['openai-compatible']
): Promise<LanguageModelV1>
export function createClient(providerId: string, modelId: string = 'default', options: any): Promise<LanguageModelV1> {
return ApiClientFactory.createClient(providerId, modelId, options)
}
export const createImageClient = (providerId: string, modelId: string, options: any): Promise<ImageModelV1> =>
ApiClientFactory.createImageClient(providerId, modelId, options)
export const getSupportedProviders = () => ApiClientFactory.getSupportedProviders()
export const getClientInfo = (providerId: string) => ApiClientFactory.getClientInfo(providerId)

View File

@@ -1,381 +0,0 @@
/**
* AI Client - Cherry Studio AI Core 的主要客户端接口
* 默认集成插件系统,提供完整的 AI 调用能力
*
* ## 使用方式
*
* ```typescript
* import { AiClient } from '@cherrystudio/ai-core'
*
* // 创建客户端(默认带插件系统)
* const client = AiClient.create('openai', {
* name: 'openai',
* apiKey: process.env.OPENAI_API_KEY
* }, [LoggingPlugin, ContentFilterPlugin])
*
* // 使用方式与 UniversalAiSdkClient 完全相同
* const result = await client.generateText('gpt-4', {
* messages: [{ role: 'user', content: 'Hello!' }]
* })
* ```
*/
import { generateObject, generateText, LanguageModelV1Middleware, streamObject, streamText } from 'ai'
import { AiPlugin, createContext, PluginManager } from '../plugins'
import { isProviderSupported } from '../providers/registry'
import { ApiClientFactory } from './ApiClientFactory'
import { type ProviderId, type ProviderSettingsMap } from './types'
import { UniversalAiSdkClient } from './UniversalAiSdkClient'
/**
* Cherry Studio AI Core 的主要客户端
* 默认集成插件系统,提供完整的 AI 调用能力
*/
export class PluginEnabledAiClient<T extends ProviderId = ProviderId> {
private pluginManager: PluginManager
private baseClient: UniversalAiSdkClient<T>
private middlewares: LanguageModelV1Middleware[] = []
constructor(
private readonly providerId: T,
private readonly options: ProviderSettingsMap[T],
plugins: AiPlugin[] = []
) {
this.pluginManager = new PluginManager(plugins)
this.baseClient = UniversalAiSdkClient.create(providerId, options)
this.updateMiddlewares()
}
/**
* 添加单个插件
*/
use(plugin: AiPlugin): this {
this.pluginManager.use(plugin)
this.updateMiddlewares()
return this
}
/**
* 批量添加插件
*/
usePlugins(plugins: AiPlugin[]): this {
plugins.forEach((plugin) => this.pluginManager.use(plugin))
this.updateMiddlewares()
return this
}
/**
* 移除插件
*/
removePlugin(pluginName: string): this {
this.pluginManager.remove(pluginName)
this.updateMiddlewares()
return this
}
/**
* 重新计算并更新中间件列表
* 这是一个原子操作,以确保中间件列表总是最新的
*/
private updateMiddlewares(): void {
const pluginMiddlewares = this.pluginManager.collectAiSdkMiddlewares()
this.middlewares = pluginMiddlewares
}
/**
* 获取插件统计信息
*/
getPluginStats() {
return this.pluginManager.getStats()
}
/**
* 获取插件列表
*/
getPlugins() {
return this.pluginManager.getPlugins()
}
/**
* 执行插件处理的通用逻辑
* 1-5步骤的通用处理
*/
private async executeWithPlugins<TParams, TResult>(
methodName: string,
modelId: string,
params: TParams,
executor: (finalModelId: string, transformedParams: TParams) => Promise<TResult>
): Promise<TResult> {
// 创建请求上下文
const context = createContext(this.providerId, modelId, params)
try {
// 1. 触发请求开始事件
await this.pluginManager.executeParallel('onRequestStart', context)
// 2. 解析模型别名
const resolvedModelId = await this.pluginManager.executeFirst<string>('resolveModel', modelId, context)
const finalModelId = resolvedModelId || modelId
// 3. 转换请求参数
const transformedParams = await this.pluginManager.executeSequential('transformParams', params, context)
// 4. 执行具体的 API 调用
const result = await executor(finalModelId, transformedParams)
// 5. 转换结果(对于非流式调用)
const transformedResult = await this.pluginManager.executeSequential('transformResult', result, context)
// 6. 触发完成事件
await this.pluginManager.executeParallel('onRequestEnd', context, transformedResult)
return transformedResult
} catch (error) {
// 7. 触发错误事件
await this.pluginManager.executeParallel('onError', context, undefined, error as Error)
throw error
}
}
/**
* 执行流式调用的通用逻辑
* 流式调用的特殊处理(支持流转换器)
*/
private async executeStreamWithPlugins<TParams, TResult>(
methodName: string,
modelId: string,
params: TParams,
executor: (finalModelId: string, transformedParams: TParams, streamTransforms: any[]) => Promise<TResult>
): Promise<TResult> {
// 创建请求上下文
const context = createContext(this.providerId, modelId, params)
try {
// 1. 触发请求开始事件
await this.pluginManager.executeParallel('onRequestStart', context)
// 2. 解析模型别名
const resolvedModelId = await this.pluginManager.executeFirst<string>('resolveModel', modelId, context)
const finalModelId = resolvedModelId || modelId
// 3. 转换请求参数
const transformedParams = await this.pluginManager.executeSequential('transformParams', params, context)
// 4. 收集流转换器
const streamTransforms = this.pluginManager.collectStreamTransforms()
// 5. 执行流式 API 调用
const result = await executor(finalModelId, transformedParams, streamTransforms)
// 6. 触发完成事件(注意:对于流式调用,这里触发的是开始流式响应的事件)
await this.pluginManager.executeParallel('onRequestEnd', context, { stream: true })
return result
} catch (error) {
// 7. 触发错误事件
await this.pluginManager.executeParallel('onError', context, undefined, error as Error)
throw error
}
}
/**
* 获取注入了中间件的 AI SDK 模型实例
* 这是应用原生中间件的关键
*/
private async getModelWithMiddlewares(modelId: string) {
const middlewares = this.middlewares
// 3. 如果有中间件,创建一个新的、注入了中间件的客户端实例
return await ApiClientFactory.createClient(
this.providerId,
modelId,
this.options,
middlewares.length > 0 ? middlewares : undefined
)
}
/**
* 流式文本生成
*/
async streamText(
modelId: string,
params: Omit<Parameters<typeof streamText>[0], 'model'>
): Promise<ReturnType<typeof streamText>>
async streamText(params: Parameters<typeof streamText>[0]): Promise<ReturnType<typeof streamText>>
async streamText(
modelIdOrParams: string | Parameters<typeof streamText>[0],
params?: Omit<Parameters<typeof streamText>[0], 'model'>
): Promise<ReturnType<typeof streamText>> {
if (typeof modelIdOrParams === 'string') {
// 传统方式:使用内建逻辑
return this.executeStreamWithPlugins(
'streamText',
modelIdOrParams,
params!,
async (finalModelId, transformedParams, streamTransforms) => {
const model = await this.getModelWithMiddlewares(finalModelId)
const experimental_transform =
params?.experimental_transform ?? (streamTransforms.length > 0 ? streamTransforms : undefined)
return await streamText({
model,
...transformedParams,
experimental_transform
})
}
)
} else {
// 外部 registry 方式:直接使用用户提供的 model
return await streamText(modelIdOrParams)
}
}
/**
* 生成文本
* 可能不需要了,因为内置模拟非流中间件
*/
async generateText(
modelId: string,
params: Omit<Parameters<typeof generateText>[0], 'model'>
): Promise<ReturnType<typeof generateText>> {
return this.executeWithPlugins('generateText', modelId, params, async (finalModelId, transformedParams) => {
const model = await this.getModelWithMiddlewares(finalModelId)
return await generateText({ model, ...transformedParams })
})
}
/**
* 生成结构化对象
*/
async generateObject(
modelId: string,
params: Omit<Parameters<typeof generateObject>[0], 'model'>
): Promise<ReturnType<typeof generateObject>>
async generateObject(params: Parameters<typeof generateObject>[0]): Promise<ReturnType<typeof generateObject>>
async generateObject(
modelIdOrParams: string | Parameters<typeof generateObject>[0],
params?: Omit<Parameters<typeof generateObject>[0], 'model'>
): Promise<ReturnType<typeof generateObject>> {
if (typeof modelIdOrParams === 'string') {
// 传统方式:使用内建逻辑
return this.executeWithPlugins(
'generateObject',
modelIdOrParams,
params!,
async (finalModelId, transformedParams) => {
const model = await this.getModelWithMiddlewares(finalModelId)
return await generateObject({ model, ...transformedParams })
}
)
} else {
// 外部 registry 方式:直接使用用户提供的 model
return await generateObject(modelIdOrParams)
}
}
/**
* 流式生成结构化对象
*/
async streamObject(
modelId: string,
params: Omit<Parameters<typeof streamObject>[0], 'model'>
): Promise<ReturnType<typeof streamObject>>
async streamObject(params: Parameters<typeof streamObject>[0]): Promise<ReturnType<typeof streamObject>>
async streamObject(
modelIdOrParams: string | Parameters<typeof streamObject>[0],
params?: Omit<Parameters<typeof streamObject>[0], 'model'>
): Promise<ReturnType<typeof streamObject>> {
if (typeof modelIdOrParams === 'string') {
// 传统方式:使用内建逻辑
return this.executeWithPlugins(
'streamObject',
modelIdOrParams,
params!,
async (finalModelId, transformedParams) => {
return await this.baseClient.streamObject(finalModelId, transformedParams)
}
)
} else {
// 外部 registry 方式:直接使用用户提供的 model
return await streamObject(modelIdOrParams)
}
}
/**
* 获取客户端信息
*/
getClientInfo() {
return this.baseClient.getClientInfo()
}
/**
* 获取底层客户端实例(用于高级用法)
*/
getBaseClient(): UniversalAiSdkClient<T> {
return this.baseClient
}
// === 静态工厂方法 ===
/**
* 创建 OpenAI Compatible 客户端
*/
static createOpenAICompatible(
config: ProviderSettingsMap['openai-compatible'],
plugins: AiPlugin[] = []
): PluginEnabledAiClient<'openai-compatible'> {
return new PluginEnabledAiClient('openai-compatible', config, plugins)
}
/**
* 创建标准提供商客户端
*/
static create<T extends ProviderId>(
providerId: T,
options: ProviderSettingsMap[T],
plugins?: AiPlugin[]
): PluginEnabledAiClient<T>
static create(
providerId: string,
options: ProviderSettingsMap['openai-compatible'],
plugins?: AiPlugin[]
): PluginEnabledAiClient<'openai-compatible'>
static create(providerId: string, options: any, plugins: AiPlugin[] = []): PluginEnabledAiClient {
if (isProviderSupported(providerId)) {
return new PluginEnabledAiClient(providerId as ProviderId, options, plugins)
} else {
// 对于未知 provider使用 openai-compatible
return new PluginEnabledAiClient('openai-compatible', options, plugins)
}
}
}
/**
* 创建 AI 客户端的工厂函数(默认带插件系统)
*/
export function createClient<T extends ProviderId>(
providerId: T,
options: ProviderSettingsMap[T],
plugins?: AiPlugin[]
): PluginEnabledAiClient<T>
export function createClient(
providerId: string,
options: ProviderSettingsMap['openai-compatible'],
plugins?: AiPlugin[]
): PluginEnabledAiClient<'openai-compatible'>
export function createClient(providerId: string, options: any, plugins: AiPlugin[] = []): PluginEnabledAiClient {
return PluginEnabledAiClient.create(providerId, options, plugins)
}
/**
* 创建 OpenAI Compatible 客户端的便捷函数
*/
export function createCompatibleClient(
config: ProviderSettingsMap['openai-compatible'],
plugins: AiPlugin[] = []
): PluginEnabledAiClient<'openai-compatible'> {
return PluginEnabledAiClient.createOpenAICompatible(config, plugins)
}

View File

@@ -1,228 +0,0 @@
/**
* Universal AI SDK Client
* 统一的AI SDK客户端实现
*
* ## 使用方式
*
* ### 1. 官方提供商
* ```typescript
* import { UniversalAiSdkClient } from '@cherrystudio/ai-core'
*
* // OpenAI
* const openai = UniversalAiSdkClient.create('openai', {
* name: 'openai',
* apiHost: 'https://api.openai.com/v1',
* apiKey: process.env.OPENAI_API_KEY
* })
*
* // Anthropic
* const anthropic = UniversalAiSdkClient.create('anthropic', {
* name: 'anthropic',
* apiHost: 'https://api.anthropic.com',
* apiKey: process.env.ANTHROPIC_API_KEY
* })
* ```
*
* ### 2. OpenAI Compatible 第三方提供商
* ```typescript
* // LM Studio (本地运行)
* const lmStudio = UniversalAiSdkClient.createOpenAICompatible({
* name: 'lm-studio',
* baseURL: 'http://localhost:1234/v1'
* })
*
* // Ollama (本地运行)
* const ollama = UniversalAiSdkClient.createOpenAICompatible({
* name: 'ollama',
* baseURL: 'http://localhost:11434/v1'
* })
*
* // 自定义第三方 API
* const customProvider = UniversalAiSdkClient.createOpenAICompatible({
* name: 'my-provider',
* apiKey: process.env.CUSTOM_API_KEY,
* baseURL: 'https://api.customprovider.com/v1',
* headers: {
* 'X-Custom-Header': 'value',
* 'User-Agent': 'MyApp/1.0'
* },
* queryParams: {
* 'api-version': '2024-01'
* }
* })
* ```
*
* ### 3. 使用客户端进行 AI 调用
* ```typescript
* // 流式文本生成
* const stream = await client.streamText('gpt-4', {
* messages: [{ role: 'user', content: 'Hello!' }]
* })
*
* // 生成文本
* const { text } = await client.generateText('gpt-4', {
* messages: [{ role: 'user', content: 'Hello!' }]
* })
*
* // 生成结构化对象
* const { object } = await client.generateObject('gpt-4', {
* messages: [{ role: 'user', content: 'Generate a user profile' }],
* schema: z.object({
* name: z.string(),
* age: z.number()
* })
* })
* ```
*/
import { experimental_generateImage as generateImage, generateObject, generateText, streamObject, streamText } from 'ai'
import { ApiClientFactory } from './ApiClientFactory'
import { type ProviderId, type ProviderSettingsMap } from './types'
/**
* 通用 AI SDK 客户端
* 为特定 AI 提供商创建的客户端实例
*/
export class UniversalAiSdkClient<T extends ProviderId = ProviderId> {
constructor(
private readonly providerId: T,
private readonly options: ProviderSettingsMap[T]
) {}
/**
* 流式文本生成
* 直接使用 AI SDK 的 streamText 参数类型
*/
async streamText(
modelId: string,
params: Omit<Parameters<typeof streamText>[0], 'model'>
): Promise<ReturnType<typeof streamText>> {
const model = await ApiClientFactory.createClient(this.providerId, modelId, this.options)
return streamText({
model,
...params
})
}
/**
* 生成文本
* 直接使用 AI SDK 的 generateText 参数类型
*/
async generateText(
modelId: string,
params: Omit<Parameters<typeof generateText>[0], 'model'>
): Promise<ReturnType<typeof generateText>> {
const model = await ApiClientFactory.createClient(this.providerId, modelId, this.options)
return generateText({
model,
...params
})
}
/**
* 生成结构化对象
* 直接使用 AI SDK 的 generateObject 参数类型
*/
async generateObject(
modelId: string,
params: Omit<Parameters<typeof generateObject>[0], 'model'>
): Promise<ReturnType<typeof generateObject>> {
const model = await ApiClientFactory.createClient(this.providerId, modelId, this.options)
return await generateObject({
model,
...params
})
}
/**
* 流式生成结构化对象
* 直接使用 AI SDK 的 streamObject 参数类型
*/
async streamObject(
modelId: string,
params: Omit<Parameters<typeof streamObject>[0], 'model'>
): Promise<ReturnType<typeof streamObject>> {
const model = await ApiClientFactory.createClient(this.providerId, modelId, this.options)
return streamObject({
model,
...params
})
}
async generateImage(
modelId: string,
params: Omit<Parameters<typeof generateImage>[0], 'model'>
): Promise<ReturnType<typeof generateImage>> {
const model = await ApiClientFactory.createImageClient(this.providerId, modelId, this.options)
return generateImage({
model,
...params
})
}
/**
* 获取客户端信息
*/
getClientInfo() {
return ApiClientFactory.getClientInfo(this.providerId)
}
// === 静态工厂方法 ===
/**
* 创建 OpenAI Compatible 客户端
* 用于那些实现 OpenAI API 的第三方提供商
*/
static createOpenAICompatible(
config: ProviderSettingsMap['openai-compatible']
): UniversalAiSdkClient<'openai-compatible'> {
return new UniversalAiSdkClient('openai-compatible', config)
}
/**
* 创建标准提供商客户端
* 对于已知的 Provider 使用严格类型检查,未知的 Provider 默认使用 openai-compatible
*/
static create<T extends ProviderId>(providerId: T, options: ProviderSettingsMap[T]): UniversalAiSdkClient<T>
static create(
providerId: string,
options: ProviderSettingsMap['openai-compatible']
): UniversalAiSdkClient<'openai-compatible'>
static create(providerId: string, options: any): UniversalAiSdkClient {
if (providerId in ({} as ProviderSettingsMap)) {
return new UniversalAiSdkClient(providerId as ProviderId, options)
} else {
// 对于未知 provider使用 openai-compatible
return new UniversalAiSdkClient('openai-compatible', options)
}
}
}
/**
* 创建客户端实例的工厂函数
*/
export function createUniversalClient<T extends ProviderId>(
providerId: T,
options: ProviderSettingsMap[T]
): UniversalAiSdkClient<T>
export function createUniversalClient(
providerId: string,
options: ProviderSettingsMap['openai-compatible']
): UniversalAiSdkClient<'openai-compatible'>
export function createUniversalClient(providerId: string, options: any): UniversalAiSdkClient {
return UniversalAiSdkClient.create(providerId, options)
}
/**
* 创建 OpenAI Compatible 客户端的便捷函数
*/
export function createOpenAICompatibleClient(
config: ProviderSettingsMap['openai-compatible']
): UniversalAiSdkClient<'openai-compatible'> {
return UniversalAiSdkClient.createOpenAICompatible(config)
}

View File

@@ -1,42 +0,0 @@
import { generateObject, generateText, streamObject, streamText } from 'ai'
import type { ProviderSettingsMap } from '../providers/registry'
// ProviderSettings 是所有 Provider Settings 的联合类型
export type ProviderSettings = ProviderSettingsMap[keyof ProviderSettingsMap]
export type StreamTextParams = Omit<Parameters<typeof streamText>[0], 'model'>
export type GenerateTextParams = Omit<Parameters<typeof generateText>[0], 'model'>
export type StreamObjectParams = Omit<Parameters<typeof streamObject>[0], 'model'>
export type GenerateObjectParams = Omit<Parameters<typeof generateObject>[0], 'model'>
// 重新导出 ProviderSettingsMap 中的所有类型
export type {
AmazonBedrockProviderSettings,
AnthropicProviderSettings,
AnthropicVertexProviderSettings,
AzureOpenAIProviderSettings,
CerebrasProviderSettings,
CohereProviderSettings,
DeepInfraProviderSettings,
DeepSeekProviderSettings,
FalProviderSettings,
FireworksProviderSettings,
GoogleGenerativeAIProviderSettings,
GoogleVertexProviderSettings,
GroqProviderSettings,
MistralProviderSettings,
OllamaProviderSettings,
OpenAICompatibleProviderSettings,
OpenAIProviderSettings,
OpenRouterProviderSettings,
PerplexityProviderSettings,
ProviderId,
ProviderSettingsMap,
QwenProviderSettings,
ReplicateProviderSettings,
TogetherAIProviderSettings,
VercelProviderSettings,
XaiProviderSettings,
ZhipuProviderSettings
} from '../providers/registry'

View File

@@ -1,214 +0,0 @@
/**
* Cherry Studio AI Core Package
* 基于 Vercel AI SDK 的统一 AI Provider 接口
*/
// 导入内部使用的类和函数
import { ApiClientFactory } from './clients/ApiClientFactory'
import { createClient } from './clients/PluginEnabledAiClient'
import { type ProviderSettingsMap } from './clients/types'
import { createUniversalClient } from './clients/UniversalAiSdkClient'
import { aiProviderRegistry, isProviderSupported } from './providers/registry'
// ==================== 主要客户端接口 ====================
// 默认使用集成插件系统的客户端
export {
PluginEnabledAiClient as AiClient,
createClient,
createCompatibleClient
} from './clients/PluginEnabledAiClient'
// 为了向后兼容,也导出原名称
export { PluginEnabledAiClient } from './clients/PluginEnabledAiClient'
// ==================== 插件系统 ====================
export type { AiPlugin, AiRequestContext, HookResult, HookType, PluginManagerConfig } from './plugins'
export { createContext, definePlugin, PluginManager } from './plugins'
// ==================== 底层客户端(高级用法) ====================
// 不带插件系统的基础客户端,用于需要绕过插件系统的场景
export {
createOpenAICompatibleClient as createBasicOpenAICompatibleClient,
createUniversalClient,
UniversalAiSdkClient
} from './clients/UniversalAiSdkClient'
// ==================== 低级 API ====================
export { ApiClientFactory } from './clients/ApiClientFactory'
export { aiProviderRegistry } from './providers/registry'
// ==================== 类型定义 ====================
export type { ClientFactoryError } from './clients/ApiClientFactory'
export type {
GenerateObjectParams,
GenerateTextParams,
ProviderSettings,
StreamObjectParams,
StreamTextParams
} from './clients/types'
export type { ProviderConfig } from './providers/registry'
export type { ProviderError } from './providers/types'
export * as aiSdk from 'ai'
// ==================== AI SDK 常用类型导出 ====================
// 直接导出 AI SDK 的常用类型,方便使用
export type {
CoreAssistantMessage,
// 消息相关类型
CoreMessage,
CoreSystemMessage,
CoreToolMessage,
CoreUserMessage,
// 通用类型
FinishReason,
GenerateObjectResult,
// 生成相关类型
GenerateTextResult,
InvalidToolArgumentsError,
LanguageModelUsage, // AI SDK 4.0 中 TokenUsage 改名为 LanguageModelUsage
LanguageModelV1Middleware,
LanguageModelV1StreamPart,
// 错误类型
NoSuchToolError,
StreamTextResult,
// 流相关类型
TextStreamPart,
// 工具相关类型
Tool,
ToolCall,
ToolExecutionError,
ToolResult
} from 'ai'
export { defaultSettingsMiddleware, extractReasoningMiddleware, simulateStreamingMiddleware, smoothStream } from 'ai'
// 重新导出所有 Provider Settings 类型
export type {
AmazonBedrockProviderSettings,
AnthropicProviderSettings,
AnthropicVertexProviderSettings,
AzureOpenAIProviderSettings,
CerebrasProviderSettings,
CohereProviderSettings,
DeepInfraProviderSettings,
DeepSeekProviderSettings,
FalProviderSettings,
FireworksProviderSettings,
GoogleGenerativeAIProviderSettings,
GoogleVertexProviderSettings,
GroqProviderSettings,
MistralProviderSettings,
OllamaProviderSettings,
OpenAICompatibleProviderSettings,
OpenAIProviderSettings,
OpenRouterProviderSettings,
PerplexityProviderSettings,
ProviderId,
ProviderSettingsMap,
QwenProviderSettings,
ReplicateProviderSettings,
TogetherAIProviderSettings,
VercelProviderSettings,
XaiProviderSettings,
ZhipuProviderSettings
} from './clients/types'
// ==================== 工具函数 ====================
export { createClient as createApiClient, getClientInfo, getSupportedProviders } from './clients/ApiClientFactory'
export { getAllProviders, getProvider, isProviderSupported, registerProvider } from './providers/registry'
// ==================== 包信息 ====================
export const AI_CORE_VERSION = '1.0.0'
export const AI_CORE_NAME = '@cherrystudio/ai-core'
// ==================== 便捷 API ====================
// 主要的便捷工厂类
export const AiCore = {
version: AI_CORE_VERSION,
name: AI_CORE_NAME,
// 创建主要客户端(默认带插件系统)
create(providerId: string, options: any = {}, plugins: any[] = []) {
return createClient(providerId, options, plugins)
},
// 创建基础客户端(不带插件系统)
createBasic(providerId: string, options: any = {}) {
return createUniversalClient(providerId, options)
},
// 获取支持的providers
getSupportedProviders() {
return ApiClientFactory.getSupportedProviders()
},
// 检查provider支持
isSupported(providerId: string) {
return isProviderSupported(providerId)
},
// 获取客户端信息
getClientInfo(providerId: string) {
return ApiClientFactory.getClientInfo(providerId)
}
}
export const createOpenAIClient = (options: ProviderSettingsMap['openai'], plugins?: any[]) => {
return createClient('openai', options, plugins)
}
export const createAnthropicClient = (options: ProviderSettingsMap['anthropic'], plugins?: any[]) => {
return createClient('anthropic', options, plugins)
}
export const createGoogleClient = (options: ProviderSettingsMap['google'], plugins?: any[]) => {
return createClient('google', options, plugins)
}
export const createXAIClient = (options: ProviderSettingsMap['xai'], plugins?: any[]) => {
return createClient('xai', options, plugins)
}
// ==================== 调试和开发工具 ====================
export const DevTools = {
// 列出所有注册的providers
listProviders() {
return aiProviderRegistry.getAllProviders().map((p) => ({
id: p.id,
name: p.name
}))
},
// 测试provider连接
async testProvider(providerId: string, options: any) {
try {
const client = createClient(providerId, options)
const info = client.getClientInfo()
return {
success: true,
providerId: info.id,
name: info.name,
isSupported: info.isSupported
}
} catch (error) {
return {
success: false,
providerId,
error: error instanceof Error ? error.message : 'Unknown error'
}
}
},
// 获取provider详细信息
getProviderDetails() {
const providers = aiProviderRegistry.getAllProviders()
return {
supportedProviders: providers.length,
registeredProviders: providers.length,
providers: providers.map((p) => ({
id: p.id,
name: p.name
}))
}
}
}

View File

@@ -1,259 +0,0 @@
# AI Core 插件系统
支持四种钩子类型:**First**、**Sequential**、**Parallel** 和 **Stream**
## 🎯 设计理念
借鉴 Rollup/Vite 的成熟插件思想:
- **语义清晰**:不同钩子有不同的执行语义
- **类型安全**TypeScript 完整支持
- **性能优化**First 短路、Parallel 并发、Sequential 链式
- **易于扩展**`enforce` 排序 + 功能分类
## 📋 钩子类型
### 🥇 First 钩子 - 首个有效结果
```typescript
// 只执行第一个返回值的插件,用于解析和查找
resolveModel?: (modelId: string, context: AiRequestContext) => string | null
loadTemplate?: (templateName: string, context: AiRequestContext) => any | null
```
### 🔄 Sequential 钩子 - 链式数据转换
```typescript
// 按顺序链式执行,每个插件可以修改数据
transformParams?: (params: any, context: AiRequestContext) => any
transformResult?: (result: any, context: AiRequestContext) => any
```
### ⚡ Parallel 钩子 - 并行副作用
```typescript
// 并发执行,用于日志、监控等副作用
onRequestStart?: (context: AiRequestContext) => void
onRequestEnd?: (context: AiRequestContext, result: any) => void
onError?: (error: Error, context: AiRequestContext) => void
```
### 🌊 Stream 钩子 - 流处理
```typescript
// 直接使用 AI SDK 的 TransformStream
transformStream?: () => (options) => TransformStream<TextStreamPart, TextStreamPart>
```
## 🚀 快速开始
### 基础用法
```typescript
import { PluginManager, createContext, definePlugin } from '@cherrystudio/ai-core/middleware'
// 创建插件管理器
const pluginManager = new PluginManager()
// 添加插件
pluginManager.use({
name: 'my-plugin',
async transformParams(params, context) {
return { ...params, temperature: 0.7 }
}
})
// 使用插件
const context = createContext('openai', 'gpt-4', { messages: [] })
const transformedParams = await pluginManager.executeSequential(
'transformParams',
{ messages: [{ role: 'user', content: 'Hello' }] },
context
)
```
### 完整示例
```typescript
import {
PluginManager,
ModelAliasPlugin,
LoggingPlugin,
ParamsValidationPlugin,
createContext
} from '@cherrystudio/ai-core/middleware'
// 创建插件管理器
const manager = new PluginManager([
ModelAliasPlugin, // 模型别名解析
ParamsValidationPlugin, // 参数验证
LoggingPlugin // 日志记录
])
// AI 请求流程
async function aiRequest(providerId: string, modelId: string, params: any) {
const context = createContext(providerId, modelId, params)
try {
// 1. 【并行】触发请求开始事件
await manager.executeParallel('onRequestStart', context)
// 2. 【首个】解析模型别名
const resolvedModel = await manager.executeFirst('resolveModel', modelId, context)
context.modelId = resolvedModel || modelId
// 3. 【串行】转换请求参数
const transformedParams = await manager.executeSequential('transformParams', params, context)
// 4. 【流处理】收集流转换器AI SDK 原生支持数组)
const streamTransforms = manager.collectStreamTransforms()
// 5. 调用 AI SDK这里省略具体实现
const result = await callAiSdk(transformedParams, streamTransforms)
// 6. 【串行】转换响应结果
const transformedResult = await manager.executeSequential('transformResult', result, context)
// 7. 【并行】触发请求完成事件
await manager.executeParallel('onRequestEnd', context, transformedResult)
return transformedResult
} catch (error) {
// 8. 【并行】触发错误事件
await manager.executeParallel('onError', context, undefined, error)
throw error
}
}
```
## 🔧 自定义插件
### 模型别名插件
```typescript
const ModelAliasPlugin = definePlugin({
name: 'model-alias',
enforce: 'pre', // 最先执行
async resolveModel(modelId) {
const aliases = {
gpt4: 'gpt-4-turbo-preview',
claude: 'claude-3-sonnet-20240229'
}
return aliases[modelId] || null
}
})
```
### 参数验证插件
```typescript
const ValidationPlugin = definePlugin({
name: 'validation',
async transformParams(params) {
if (!params.messages) {
throw new Error('messages is required')
}
return {
...params,
temperature: params.temperature ?? 0.7,
max_tokens: params.max_tokens ?? 4096
}
}
})
```
### 监控插件
```typescript
const MonitoringPlugin = definePlugin({
name: 'monitoring',
enforce: 'post', // 最后执行
async onRequestEnd(context, result) {
const duration = Date.now() - context.startTime
console.log(`请求耗时: ${duration}ms`)
}
})
```
### 内容过滤插件
```typescript
const FilterPlugin = definePlugin({
name: 'content-filter',
transformStream() {
return () =>
new TransformStream({
transform(chunk, controller) {
if (chunk.type === 'text-delta') {
const filtered = chunk.textDelta.replace(/敏感词/g, '***')
controller.enqueue({ ...chunk, textDelta: filtered })
} else {
controller.enqueue(chunk)
}
}
})
}
})
```
## 📊 执行顺序
### 插件排序
```
enforce: 'pre' → normal → enforce: 'post'
```
### 钩子执行流程
```mermaid
graph TD
A[请求开始] --> B[onRequestStart 并行执行]
B --> C[resolveModel 首个有效]
C --> D[loadTemplate 首个有效]
D --> E[transformParams 串行执行]
E --> F[collectStreamTransforms]
F --> G[AI SDK 调用]
G --> H[transformResult 串行执行]
H --> I[onRequestEnd 并行执行]
G --> J[异常处理]
J --> K[onError 并行执行]
```
## 💡 最佳实践
1. **功能单一**:每个插件专注一个功能
2. **幂等性**:插件应该是幂等的,重复执行不会产生副作用
3. **错误处理**:插件内部处理异常,不要让异常向上传播
4. **性能优化**使用合适的钩子类型First vs Sequential vs Parallel
5. **命名规范**:使用语义化的插件名称
## 🔍 调试工具
```typescript
// 查看插件统计信息
const stats = manager.getStats()
console.log('插件统计:', stats)
// 查看所有插件
const plugins = manager.getPlugins()
console.log(
'已注册插件:',
plugins.map((p) => p.name)
)
```
## ⚡ 性能优势
- **First 钩子**:一旦找到结果立即停止,避免无效计算
- **Parallel 钩子**:真正并发执行,不阻塞主流程
- **Sequential 钩子**:保证数据转换的顺序性
- **Stream 钩子**:直接集成 AI SDK零开销
这个设计兼顾了简洁性和强大功能,为 AI Core 提供了灵活而高效的扩展机制。

View File

@@ -1,192 +0,0 @@
import type { AiPlugin } from '../types'
/**
* 【First 钩子示例】模型别名解析插件
*/
export const ModelAliasPlugin: AiPlugin = {
name: 'model-alias',
enforce: 'pre',
async resolveModel(modelId) {
const aliases: Record<string, string> = {
gpt4: 'gpt-4-turbo-preview',
claude: 'claude-3-sonnet-20240229',
gemini: 'gemini-pro'
}
return aliases[modelId] || null
}
}
/**
* 【Sequential 钩子示例】参数验证和转换插件
*/
export const ParamsValidationPlugin: AiPlugin = {
name: 'params-validation',
async transformParams(params) {
// 参数验证
if (!params.messages || !Array.isArray(params.messages)) {
throw new Error('Invalid messages parameter')
}
// 参数转换:添加默认配置
return {
...params,
temperature: params.temperature ?? 0.7,
max_tokens: params.max_tokens ?? 4096,
stream: params.stream ?? true
}
},
async transformResult(result, context) {
// 结果后处理:添加元数据
return {
...result,
metadata: {
...result.metadata,
processedAt: new Date().toISOString(),
provider: context.providerId,
model: context.modelId
}
}
}
}
/**
* 【Parallel 钩子示例】日志记录插件
*/
export const LoggingPlugin: AiPlugin = {
name: 'logging',
async onRequestStart(context) {
console.log(`🚀 AI请求开始: ${context.providerId}/${context.modelId}`, {
requestId: context.requestId,
timestamp: new Date().toISOString()
})
},
async onRequestEnd(context, result) {
const duration = Date.now() - context.startTime
console.log(`✅ AI请求完成: ${context.requestId} (${duration}ms)`, {
provider: context.providerId,
model: context.modelId,
hasResult: !!result
})
},
async onError(error, context) {
const duration = Date.now() - context.startTime
console.error(`❌ AI请求失败: ${context.requestId} (${duration}ms)`, {
provider: context.providerId,
model: context.modelId,
error: error.message,
stack: error.stack
})
}
}
/**
* 【Parallel 钩子示例】性能监控插件
*/
export const PerformancePlugin: AiPlugin = {
name: 'performance',
enforce: 'post',
async onRequestEnd(context) {
const duration = Date.now() - context.startTime
// 记录性能指标
const metrics = {
requestId: context.requestId,
provider: context.providerId,
model: context.modelId,
duration,
timestamp: context.startTime,
success: true
}
// 发送到监控系统(这里只是示例)
// await sendMetrics(metrics)
console.log('📊 性能指标:', metrics)
},
async onError(error, context) {
const duration = Date.now() - context.startTime
const metrics = {
requestId: context.requestId,
provider: context.providerId,
model: context.modelId,
duration,
timestamp: context.startTime,
success: false,
errorType: error.constructor.name
}
console.log('📊 错误指标:', metrics)
}
}
/**
* 【Stream 钩子示例】内容过滤插件
*/
export const ContentFilterPlugin: AiPlugin = {
name: 'content-filter',
transformStream() {
return () =>
new TransformStream({
transform(chunk, controller) {
// 过滤敏感内容
if (chunk.type === 'text-delta') {
const filtered = chunk.textDelta.replace(/\b(敏感词|违禁词)\b/g, '***')
controller.enqueue({
...chunk,
textDelta: filtered
})
} else {
controller.enqueue(chunk)
}
}
})
}
}
/**
* 【First 钩子示例】模板加载插件
*/
export const TemplatePlugin: AiPlugin = {
name: 'template-loader',
async loadTemplate(templateName) {
const templates: Record<string, any> = {
chat: {
systemPrompt: '你是一个有用的AI助手',
temperature: 0.7
},
coding: {
systemPrompt: '你是一个专业的编程助手,请提供清晰、高质量的代码',
temperature: 0.3
},
creative: {
systemPrompt: '你是一个创意写作助手,请发挥想象力',
temperature: 0.9
}
}
return templates[templateName] || null
}
}
/**
* 示例插件组合
*/
export const defaultPlugins: AiPlugin[] = [
ModelAliasPlugin,
TemplatePlugin,
ParamsValidationPlugin,
LoggingPlugin,
PerformancePlugin,
ContentFilterPlugin
]

View File

@@ -1,255 +0,0 @@
import { openai } from '@ai-sdk/openai'
import { streamText } from 'ai'
import { PluginEnabledAiClient } from '../../clients/PluginEnabledAiClient'
import { createContext, PluginManager } from '../'
import { ContentFilterPlugin, LoggingPlugin } from './example-plugins'
/**
* 使用 PluginEnabledAiClient 的推荐方式
* 这是最简单直接的使用方法
*/
export async function exampleWithPluginEnabledClient() {
console.log('=== 使用 PluginEnabledAiClient 示例 ===')
// 1. 创建带插件的客户端 - 链式调用方式
const client = PluginEnabledAiClient.create('openai-compatible', {
name: 'openai',
baseURL: 'https://api.openai.com/v1',
apiKey: process.env.OPENAI_API_KEY || 'sk-test'
})
.use(LoggingPlugin)
.use(ContentFilterPlugin)
// 2. 或者在创建时传入插件(也可以这样使用)
// const clientWithPlugins = PluginEnabledAiClient.create(
// 'openai-compatible',
// {
// name: 'openai',
// baseURL: 'https://api.openai.com/v1',
// apiKey: process.env.OPENAI_API_KEY || 'sk-test'
// },
// [LoggingPlugin, ContentFilterPlugin]
// )
// 3. 查看插件统计信息
console.log('插件统计:', client.getPluginStats())
try {
// 4. 使用客户端进行 AI 调用(插件会自动生效)
console.log('开始生成文本...')
const result = await client.generateText('gpt-4', {
messages: [{ role: 'user', content: 'Hello, world!' }],
temperature: 0.7
})
console.log('生成的文本:', result.text)
// 5. 流式调用(支持流转换器)
console.log('开始流式生成...')
const streamResult = await client.streamText('gpt-4', {
messages: [{ role: 'user', content: 'Tell me a short story about AI' }]
})
console.log('开始流式响应...')
for await (const textPart of streamResult.textStream) {
process.stdout.write(textPart)
}
console.log('\n流式响应完成')
return result
} catch (error) {
console.error('调用失败:', error)
throw error
}
}
/**
* 创建 OpenAI Compatible 客户端的示例
*/
export function exampleOpenAICompatible() {
console.log('=== OpenAI Compatible 示例 ===')
// Ollama 示例
const ollama = PluginEnabledAiClient.createOpenAICompatible(
{
name: 'ollama',
baseURL: 'http://localhost:11434/v1'
},
[LoggingPlugin]
)
// LM Studio 示例
const lmStudio = PluginEnabledAiClient.createOpenAICompatible({
name: 'lm-studio',
baseURL: 'http://localhost:1234/v1'
}).use(ContentFilterPlugin)
console.log('Ollama 插件统计:', ollama.getPluginStats())
console.log('LM Studio 插件统计:', lmStudio.getPluginStats())
return { ollama, lmStudio }
}
/**
* 动态插件管理示例
*/
export function exampleDynamicPlugins() {
console.log('=== 动态插件管理示例 ===')
const client = PluginEnabledAiClient.create('openai-compatible', {
name: 'openai',
baseURL: 'https://api.openai.com/v1',
apiKey: 'your-api-key'
})
console.log('初始状态:', client.getPluginStats())
// 动态添加插件
client.use(LoggingPlugin)
console.log('添加 LoggingPlugin 后:', client.getPluginStats())
client.usePlugins([ContentFilterPlugin])
console.log('添加 ContentFilterPlugin 后:', client.getPluginStats())
// 移除插件
client.removePlugin('content-filter')
console.log('移除 content-filter 后:', client.getPluginStats())
return client
}
/**
* 完整的低级 API 示例(原有的 example-usage.ts 的方式)
* 这种方式适合需要精细控制插件生命周期的场景
*/
export async function exampleLowLevelApi() {
console.log('=== 低级 API 示例 ===')
// 1. 创建插件管理器
const pluginManager = new PluginManager([LoggingPlugin, ContentFilterPlugin])
// 2. 创建请求上下文
const context = createContext('openai', 'gpt-4', {
messages: [{ role: 'user', content: 'Hello!' }]
})
try {
// 3. 触发请求开始事件
await pluginManager.executeParallel('onRequestStart', context)
// 4. 解析模型别名
const resolvedModel = await pluginManager.executeFirst('resolveModel', 'gpt-4', context)
console.log('Resolved model:', resolvedModel || 'gpt-4')
// 5. 转换请求参数
const params = {
messages: [{ role: 'user' as const, content: 'Hello, AI!' }],
temperature: 0.7
}
const transformedParams = await pluginManager.executeSequential('transformParams', params, context)
// 6. 收集流转换器关键AI SDK 原生支持数组!)
const streamTransforms = pluginManager.collectStreamTransforms()
// 7. 调用 AI SDK直接传入转换器工厂数组
const result = await streamText({
model: openai('gpt-4'),
...transformedParams,
experimental_transform: streamTransforms // 直接传入工厂函数数组
})
// 8. 处理结果
let fullText = ''
for await (const textPart of result.textStream) {
fullText += textPart
console.log('Streaming:', textPart)
}
// 9. 转换最终结果
const finalResult = { text: fullText, usage: await result.usage }
const transformedResult = await pluginManager.executeSequential('transformResult', finalResult, context)
// 10. 触发完成事件
await pluginManager.executeParallel('onRequestEnd', context, transformedResult)
return transformedResult
} catch (error) {
// 11. 触发错误事件
await pluginManager.executeParallel('onError', context, undefined, error as Error)
throw error
}
}
/**
* 流转换器数组的其他使用方式
*/
export function demonstrateStreamTransforms() {
console.log('=== 流转换器示例 ===')
const pluginManager = new PluginManager([
ContentFilterPlugin,
{
name: 'text-replacer',
transformStream() {
return () =>
new TransformStream({
transform(chunk, controller) {
if (chunk.type === 'text-delta') {
const replaced = chunk.textDelta.replace(/hello/gi, 'hi')
controller.enqueue({ ...chunk, textDelta: replaced })
} else {
controller.enqueue(chunk)
}
}
})
}
}
])
// 获取所有流转换器
const transforms = pluginManager.collectStreamTransforms()
console.log(`收集到 ${transforms.length} 个流转换器`)
// 可以单独使用每个转换器
transforms.forEach((factory, index) => {
console.log(`转换器 ${index + 1} 已准备就绪`)
const transform = factory({ stopStream: () => {} })
console.log('Transform created:', transform)
})
return transforms
}
/**
* 运行所有示例
*/
export async function runAllExamples() {
console.log('🚀 开始运行所有示例...\n')
try {
// 1. PluginEnabledAiClient 示例(推荐)
await exampleWithPluginEnabledClient()
console.log('✅ PluginEnabledAiClient 示例完成\n')
// 2. OpenAI Compatible 示例
exampleOpenAICompatible()
console.log('✅ OpenAI Compatible 示例完成\n')
// 3. 动态插件管理示例
exampleDynamicPlugins()
console.log('✅ 动态插件管理示例完成\n')
// 4. 流转换器示例
demonstrateStreamTransforms()
console.log('✅ 流转换器示例完成\n')
// 5. 低级 API 示例
// await exampleLowLevelApi()
console.log('✅ 低级 API 示例完成\n')
console.log('🎉 所有示例运行完成!')
} catch (error) {
console.error('❌ 示例运行失败:', error)
}
}

View File

@@ -1,23 +0,0 @@
// 核心类型和接口
export type { AiPlugin, AiRequestContext, HookResult, HookType, PluginManagerConfig } from './types'
import type { AiPlugin, AiRequestContext } from './types'
// 插件管理器
export { PluginManager } from './manager'
// 工具函数
export function createContext(providerId: string, modelId: string, originalParams: any): AiRequestContext {
return {
providerId,
modelId,
originalParams,
metadata: {},
startTime: Date.now(),
requestId: `${providerId}-${modelId}-${Date.now()}-${Math.random().toString(36).slice(2)}`
}
}
// 插件构建器 - 便于创建插件
export function definePlugin(plugin: AiPlugin): AiPlugin {
return plugin
}

View File

@@ -1,189 +0,0 @@
import type { LanguageModelV1Middleware, TextStreamPart, ToolSet } from 'ai'
import { AiPlugin, AiRequestContext } from './types'
/**
* 插件管理器 - 基于 Rollup 钩子分类设计
*/
export class PluginManager {
private plugins: AiPlugin[] = []
constructor(plugins: AiPlugin[] = []) {
this.plugins = this.sortPlugins(plugins)
}
/**
* 添加插件
*/
use(plugin: AiPlugin): this {
this.plugins = this.sortPlugins([...this.plugins, plugin])
return this
}
/**
* 移除插件
*/
remove(pluginName: string): this {
this.plugins = this.plugins.filter((p) => p.name !== pluginName)
return this
}
/**
* 插件排序pre -> normal -> post
*/
private sortPlugins(plugins: AiPlugin[]): AiPlugin[] {
const pre: AiPlugin[] = []
const normal: AiPlugin[] = []
const post: AiPlugin[] = []
plugins.forEach((plugin) => {
if (plugin.enforce === 'pre') {
pre.push(plugin)
} else if (plugin.enforce === 'post') {
post.push(plugin)
} else {
normal.push(plugin)
}
})
return [...pre, ...normal, ...post]
}
/**
* 执行 First 钩子 - 返回第一个有效结果
*/
async executeFirst<T>(
hookName: 'resolveModel' | 'loadTemplate',
arg: string,
context: AiRequestContext
): Promise<T | null> {
for (const plugin of this.plugins) {
const hook = plugin[hookName]
if (hook) {
const result = await hook(arg, context)
if (result !== null && result !== undefined) {
return result as T
}
}
}
return null
}
/**
* 执行 Sequential 钩子 - 链式数据转换
*/
async executeSequential<T>(
hookName: 'transformParams' | 'transformResult',
initialValue: T,
context: AiRequestContext
): Promise<T> {
let result = initialValue
for (const plugin of this.plugins) {
const hook = plugin[hookName]
if (hook) {
result = await hook(result, context)
}
}
return result
}
/**
* 执行 Parallel 钩子 - 并行副作用
*/
async executeParallel(
hookName: 'onRequestStart' | 'onRequestEnd' | 'onError',
context: AiRequestContext,
result?: any,
error?: Error
): Promise<void> {
const promises = this.plugins
.map((plugin) => {
const hook = plugin[hookName]
if (!hook) return null
if (hookName === 'onError' && error) {
return (hook as any)(error, context)
} else if (hookName === 'onRequestEnd' && result !== undefined) {
return (hook as any)(context, result)
} else if (hookName === 'onRequestStart') {
return (hook as any)(context)
}
return null
})
.filter(Boolean)
// 使用 Promise.all 而不是 allSettled让插件错误能够抛出
await Promise.all(promises)
}
/**
* 收集所有流转换器返回数组AI SDK 原生支持)
*/
collectStreamTransforms<TOOLS extends ToolSet>(): Array<
(options: {
tools?: TOOLS
stopStream: () => void
}) => TransformStream<TextStreamPart<TOOLS>, TextStreamPart<TOOLS>>
> {
return this.plugins.map((plugin) => plugin.transformStream?.()).filter(Boolean) as Array<
(options: {
tools?: TOOLS
stopStream: () => void
}) => TransformStream<TextStreamPart<TOOLS>, TextStreamPart<TOOLS>>
>
}
/**
* 收集所有 AI SDK 原生中间件
*/
collectAiSdkMiddlewares(): LanguageModelV1Middleware[] {
return this.plugins.flatMap((plugin) => plugin.aiSdkMiddlewares || [])
}
/**
* 获取所有插件信息
*/
getPlugins(): AiPlugin[] {
return [...this.plugins]
}
/**
* 获取插件统计信息
*/
getStats() {
const stats = {
total: this.plugins.length,
pre: 0,
normal: 0,
post: 0,
hooks: {
resolveModel: 0,
loadTemplate: 0,
transformParams: 0,
transformResult: 0,
onRequestStart: 0,
onRequestEnd: 0,
onError: 0,
transformStream: 0
}
}
this.plugins.forEach((plugin) => {
// 统计 enforce 类型
if (plugin.enforce === 'pre') stats.pre++
else if (plugin.enforce === 'post') stats.post++
else stats.normal++
// 统计钩子数量
Object.keys(stats.hooks).forEach((hookName) => {
if (plugin[hookName as keyof AiPlugin]) {
stats.hooks[hookName as keyof typeof stats.hooks]++
}
})
})
return stats
}
}

View File

@@ -1,87 +0,0 @@
import type { LanguageModelV1Middleware, TextStreamPart, ToolSet } from 'ai'
/**
* 生命周期阶段定义
*/
export enum LifecycleStage {
PRE_REQUEST = 'pre-request', // 请求预处理
REQUEST_EXECUTION = 'execution', // 请求执行
STREAM_PROCESSING = 'stream', // 流式处理(仅流模式)
POST_RESPONSE = 'post-response', // 响应后处理
ERROR_HANDLING = 'error' // 错误处理
}
/**
* 生命周期上下文
*/
export interface LifecycleContext {
currentStage: LifecycleStage
startTime: number
stageStartTime: number
completedStages: Set<LifecycleStage>
stageDurations: Map<LifecycleStage, number>
metadata: Record<string, any>
}
/**
* AI 请求上下文
*/
export interface AiRequestContext {
providerId: string
modelId: string
originalParams: any
metadata: Record<string, any>
startTime: number
requestId: string
}
/**
* 借鉴 Rollup 的钩子分类设计
*/
export interface AiPlugin {
name: string
enforce?: 'pre' | 'post'
// 【First】首个钩子 - 只执行第一个返回值的插件
resolveModel?: (modelId: string, context: AiRequestContext) => string | null | Promise<string | null>
loadTemplate?: (templateName: string, context: AiRequestContext) => any | null | Promise<any | null>
// 【Sequential】串行钩子 - 链式执行,支持数据转换
transformParams?: (params: any, context: AiRequestContext) => any | Promise<any>
transformResult?: (result: any, context: AiRequestContext) => any | Promise<any>
// 【Parallel】并行钩子 - 不依赖顺序,用于副作用
onRequestStart?: (context: AiRequestContext) => void | Promise<void>
onRequestEnd?: (context: AiRequestContext, result: any) => void | Promise<void>
onError?: (error: Error, context: AiRequestContext) => void | Promise<void>
// 【Stream】流处理 - 直接使用 AI SDK
transformStream?: <TOOLS extends ToolSet>() => (options: {
tools?: TOOLS
stopStream: () => void
}) => TransformStream<TextStreamPart<TOOLS>, TextStreamPart<TOOLS>>
// AI SDK 原生中间件
aiSdkMiddlewares?: LanguageModelV1Middleware[]
}
/**
* 插件管理器配置
*/
export interface PluginManagerConfig {
plugins: AiPlugin[]
context: Partial<AiRequestContext>
}
/**
* 钩子执行器类型
*/
export type HookType = 'first' | 'sequential' | 'parallel' | 'stream'
/**
* 钩子执行结果
*/
export interface HookResult<T = any> {
value: T
stop?: boolean
}

View File

@@ -1,376 +0,0 @@
/**
* AI Provider 注册表
* 静态类型 + 动态导入模式:所有类型静态导入,所有实现动态导入
*/
// 静态导入所有 AI SDK 类型
import { type AmazonBedrockProviderSettings } from '@ai-sdk/amazon-bedrock'
import { type AnthropicProviderSettings } from '@ai-sdk/anthropic'
import { type AzureOpenAIProviderSettings } from '@ai-sdk/azure'
import { type CerebrasProviderSettings } from '@ai-sdk/cerebras'
import { type CohereProviderSettings } from '@ai-sdk/cohere'
import { type DeepInfraProviderSettings } from '@ai-sdk/deepinfra'
import { type DeepSeekProviderSettings } from '@ai-sdk/deepseek'
import { type FalProviderSettings } from '@ai-sdk/fal'
import { type FireworksProviderSettings } from '@ai-sdk/fireworks'
import { type GoogleGenerativeAIProviderSettings } from '@ai-sdk/google'
import { type GoogleVertexProviderSettings } from '@ai-sdk/google-vertex'
import { type GroqProviderSettings } from '@ai-sdk/groq'
import { type MistralProviderSettings } from '@ai-sdk/mistral'
import { type OpenAIProviderSettings } from '@ai-sdk/openai'
import { type OpenAICompatibleProviderSettings } from '@ai-sdk/openai-compatible'
import { type PerplexityProviderSettings } from '@ai-sdk/perplexity'
import { type ReplicateProviderSettings } from '@ai-sdk/replicate'
import { type TogetherAIProviderSettings } from '@ai-sdk/togetherai'
import { type VercelProviderSettings } from '@ai-sdk/vercel'
import { type XaiProviderSettings } from '@ai-sdk/xai'
import { type OpenRouterProviderSettings } from '@openrouter/ai-sdk-provider'
import { type AnthropicVertexProviderSettings } from 'anthropic-vertex-ai'
import { type OllamaProviderSettings } from 'ollama-ai-provider'
import { type QwenProviderSettings } from 'qwen-ai-provider'
import { type ZhipuProviderSettings } from 'zhipu-ai-provider'
// 类型安全的 Provider Settings 映射
export type ProviderSettingsMap = {
openai: OpenAIProviderSettings
'openai-compatible': OpenAICompatibleProviderSettings
anthropic: AnthropicProviderSettings
google: GoogleGenerativeAIProviderSettings
'google-vertex': GoogleVertexProviderSettings
mistral: MistralProviderSettings
xai: XaiProviderSettings
azure: AzureOpenAIProviderSettings
bedrock: AmazonBedrockProviderSettings
cohere: CohereProviderSettings
groq: GroqProviderSettings
together: TogetherAIProviderSettings
fireworks: FireworksProviderSettings
deepseek: DeepSeekProviderSettings
cerebras: CerebrasProviderSettings
deepinfra: DeepInfraProviderSettings
replicate: ReplicateProviderSettings
perplexity: PerplexityProviderSettings
fal: FalProviderSettings
vercel: VercelProviderSettings
ollama: OllamaProviderSettings
qwen: QwenProviderSettings
zhipu: ZhipuProviderSettings
'anthropic-vertex': AnthropicVertexProviderSettings
openrouter: OpenRouterProviderSettings
}
export type ProviderId = keyof ProviderSettingsMap
// 统一的 Provider 配置接口(所有都使用动态导入)
export interface ProviderConfig {
id: string
name: string
// 动态导入函数
import: () => Promise<any>
// 创建函数名称
creatorFunctionName: string
// 是否支持图片生成
supportsImageGeneration?: boolean
}
/**
* AI SDK Provider 注册表
* 管理所有支持的 AI Providers 及其动态导入
*/
export class AiProviderRegistry {
private static instance: AiProviderRegistry
private registry = new Map<string, ProviderConfig>()
private constructor() {
this.initializeProviders()
}
public static getInstance(): AiProviderRegistry {
if (!AiProviderRegistry.instance) {
AiProviderRegistry.instance = new AiProviderRegistry()
}
return AiProviderRegistry.instance
}
/**
* 初始化所有支持的 Providers
* 基于 AI SDK 官方文档: https://ai-sdk.dev/providers/ai-sdk-providers
*/
private initializeProviders(): void {
const providers: ProviderConfig[] = [
// 官方 AI SDK Providers (19个)
{
id: 'openai',
name: 'OpenAI',
import: () => import('@ai-sdk/openai'),
creatorFunctionName: 'createOpenAI',
supportsImageGeneration: true
},
{
id: 'openai-compatible',
name: 'OpenAI Compatible',
import: () => import('@ai-sdk/openai-compatible'),
creatorFunctionName: 'createOpenAICompatible'
},
{
id: 'anthropic',
name: 'Anthropic',
import: () => import('@ai-sdk/anthropic'),
creatorFunctionName: 'createAnthropic',
supportsImageGeneration: false
},
{
id: 'google',
name: 'Google Generative AI',
import: () => import('@ai-sdk/google'),
creatorFunctionName: 'createGoogleGenerativeAI',
supportsImageGeneration: true
},
{
id: 'google-vertex',
name: 'Google Vertex AI',
import: () => import('@ai-sdk/google-vertex'),
creatorFunctionName: 'createVertex',
supportsImageGeneration: true
},
{
id: 'mistral',
name: 'Mistral AI',
import: () => import('@ai-sdk/mistral'),
creatorFunctionName: 'createMistral',
supportsImageGeneration: false
},
{
id: 'xai',
name: 'xAI (Grok)',
import: () => import('@ai-sdk/xai'),
creatorFunctionName: 'createXai',
supportsImageGeneration: true
},
{
id: 'azure',
name: 'Azure OpenAI',
import: () => import('@ai-sdk/azure'),
creatorFunctionName: 'createAzure',
supportsImageGeneration: true
},
{
id: 'bedrock',
name: 'Amazon Bedrock',
import: () => import('@ai-sdk/amazon-bedrock'),
creatorFunctionName: 'createAmazonBedrock',
supportsImageGeneration: false
},
{
id: 'cohere',
name: 'Cohere',
import: () => import('@ai-sdk/cohere'),
creatorFunctionName: 'createCohere',
supportsImageGeneration: false
},
{
id: 'groq',
name: 'Groq',
import: () => import('@ai-sdk/groq'),
creatorFunctionName: 'createGroq',
supportsImageGeneration: false
},
{
id: 'together',
name: 'Together.ai',
import: () => import('@ai-sdk/togetherai'),
creatorFunctionName: 'createTogetherAI',
supportsImageGeneration: true
},
{
id: 'fireworks',
name: 'Fireworks',
import: () => import('@ai-sdk/fireworks'),
creatorFunctionName: 'createFireworks',
supportsImageGeneration: true
},
{
id: 'deepseek',
name: 'DeepSeek',
import: () => import('@ai-sdk/deepseek'),
creatorFunctionName: 'createDeepSeek',
supportsImageGeneration: false
},
{
id: 'cerebras',
name: 'Cerebras',
import: () => import('@ai-sdk/cerebras'),
creatorFunctionName: 'createCerebras',
supportsImageGeneration: false
},
{
id: 'deepinfra',
name: 'DeepInfra',
import: () => import('@ai-sdk/deepinfra'),
creatorFunctionName: 'createDeepInfra',
supportsImageGeneration: false
},
{
id: 'replicate',
name: 'Replicate',
import: () => import('@ai-sdk/replicate'),
creatorFunctionName: 'createReplicate',
supportsImageGeneration: true
},
{
id: 'perplexity',
name: 'Perplexity',
import: () => import('@ai-sdk/perplexity'),
creatorFunctionName: 'createPerplexity',
supportsImageGeneration: false
},
{
id: 'fal',
name: 'Fal AI',
import: () => import('@ai-sdk/fal'),
creatorFunctionName: 'createFal',
supportsImageGeneration: false
},
{
id: 'vercel',
name: 'Vercel',
import: () => import('@ai-sdk/vercel'),
creatorFunctionName: 'createVercel'
},
// 社区 Providers (5个)
{
id: 'ollama',
name: 'Ollama',
import: () => import('ollama-ai-provider'),
creatorFunctionName: 'createOllama',
supportsImageGeneration: false
},
{
id: 'qwen',
name: 'Qwen',
import: () => import('qwen-ai-provider'),
creatorFunctionName: 'createQwen',
supportsImageGeneration: false
},
{
id: 'zhipu',
name: 'Zhipu AI',
import: () => import('zhipu-ai-provider'),
creatorFunctionName: 'createZhipu',
supportsImageGeneration: false
},
{
id: 'anthropic-vertex',
name: 'Anthropic Vertex AI',
import: () => import('anthropic-vertex-ai'),
creatorFunctionName: 'createAnthropicVertex',
supportsImageGeneration: false
},
{
id: 'openrouter',
name: 'OpenRouter',
import: () => import('@openrouter/ai-sdk-provider'),
creatorFunctionName: 'createOpenRouter',
supportsImageGeneration: false
}
]
// 注册所有 providers (总计24个)
providers.forEach((config) => {
this.registry.set(config.id, config)
})
}
/**
* 获取所有已注册的 Providers
*/
public getAllProviders(): ProviderConfig[] {
return Array.from(this.registry.values())
}
/**
* 根据 ID 获取 Provider 配置
*/
public getProvider(id: string): ProviderConfig | undefined {
return this.registry.get(id)
}
/**
* 检查 Provider 是否支持(是否已注册)
*/
public isSupported(id: string): boolean {
return this.registry.has(id)
}
/**
* 注册新的 Provider用于扩展
*/
public registerProvider(config: ProviderConfig): void {
this.registry.set(config.id, config)
}
/**
* 清理资源
*/
public cleanup(): void {
this.registry.clear()
}
/**
* 获取兼容现有实现的注册表格式
*/
public getCompatibleRegistry(): Record<string, { import: () => Promise<any>; creatorFunctionName: string }> {
const compatibleRegistry: Record<string, { import: () => Promise<any>; creatorFunctionName: string }> = {}
this.getAllProviders().forEach((provider) => {
compatibleRegistry[provider.id] = {
import: provider.import,
creatorFunctionName: provider.creatorFunctionName
}
})
return compatibleRegistry
}
}
// 导出单例实例
export const aiProviderRegistry = AiProviderRegistry.getInstance()
// 便捷函数
export const getProvider = (id: string) => aiProviderRegistry.getProvider(id)
export const getAllProviders = () => aiProviderRegistry.getAllProviders()
export const isProviderSupported = (id: string) => aiProviderRegistry.isSupported(id)
export const registerProvider = (config: ProviderConfig) => aiProviderRegistry.registerProvider(config)
// 兼容现有实现的导出
export const PROVIDER_REGISTRY = aiProviderRegistry.getCompatibleRegistry()
// 重新导出所有类型供外部使用
export type {
AmazonBedrockProviderSettings,
AnthropicProviderSettings,
AnthropicVertexProviderSettings,
AzureOpenAIProviderSettings,
CerebrasProviderSettings,
CohereProviderSettings,
DeepInfraProviderSettings,
DeepSeekProviderSettings,
FalProviderSettings,
FireworksProviderSettings,
GoogleGenerativeAIProviderSettings,
GoogleVertexProviderSettings,
GroqProviderSettings,
MistralProviderSettings,
OllamaProviderSettings,
OpenAICompatibleProviderSettings,
OpenAIProviderSettings,
OpenRouterProviderSettings,
PerplexityProviderSettings,
QwenProviderSettings,
ReplicateProviderSettings,
TogetherAIProviderSettings,
VercelProviderSettings,
XaiProviderSettings,
ZhipuProviderSettings
}

View File

@@ -1,47 +0,0 @@
/**
* Provider 相关核心类型定义
* 只定义必要的接口,其他类型直接使用 AI SDK
*/
// Provider 配置接口(简化版)
export interface ProviderConfig {
id: string
name: string
import: () => Promise<any>
creatorFunctionName: string
}
// API 客户端工厂接口
export interface ApiClientFactory {
createAiSdkClient(providerId: string, options?: any): Promise<any>
getCachedClient(providerId: string, options?: any): any
clearCache(): void
}
// 客户端配置
export interface ClientConfig {
providerId: string
apiKey?: string
baseURL?: string
[key: string]: any
}
// 错误类型
export class ProviderError extends Error {
constructor(
message: string,
public providerId: string,
public code?: string,
public cause?: Error
) {
super(message)
this.name = 'ProviderError'
}
}
// 缓存统计信息
export interface CacheStats {
size: number
keys: string[]
lastCleanup?: Date
}

View File

@@ -1,26 +0,0 @@
{
"compilerOptions": {
"target": "ES2020",
"module": "ESNext",
"moduleResolution": "node",
"declaration": true,
"outDir": "./dist",
"rootDir": "./src",
"strict": true,
"esModuleInterop": true,
"skipLibCheck": true,
"forceConsistentCasingInFileNames": true,
"resolveJsonModule": true,
"allowSyntheticDefaultImports": true,
"noEmitOnError": false,
"experimentalDecorators": true,
"emitDecoratorMetadata": true
},
"include": [
"src/**/*"
],
"exclude": [
"node_modules",
"dist"
]
}

View File

@@ -15,12 +15,7 @@ export enum IpcChannel {
App_SetAutoUpdate = 'app:set-auto-update',
App_SetFeedUrl = 'app:set-feed-url',
App_HandleZoomFactor = 'app:handle-zoom-factor',
App_Select = 'app:select',
App_HasWritePermission = 'app:has-write-permission',
App_Copy = 'app:copy',
App_SetStopQuitApp = 'app:set-stop-quit-app',
App_SetAppDataPath = 'app:set-app-data-path',
App_RelaunchApp = 'app:relaunch-app',
App_IsBinaryExist = 'app:is-binary-exist',
App_GetBinaryPath = 'app:get-binary-path',
App_InstallUvBinary = 'app:install-uv-binary',
@@ -91,10 +86,6 @@ export enum IpcChannel {
Gemini_ListFiles = 'gemini:list-files',
Gemini_DeleteFile = 'gemini:delete-file',
// VertexAI
VertexAI_GetAuthHeaders = 'vertexai:get-auth-headers',
VertexAI_ClearAuthCache = 'vertexai:clear-auth-cache',
Windows_ResetMinimumSize = 'window:reset-minimum-size',
Windows_SetMinimumSize = 'window:set-minimum-size',
@@ -127,7 +118,6 @@ export enum IpcChannel {
File_Copy = 'file:copy',
File_BinaryImage = 'file:binaryImage',
File_Base64File = 'file:base64File',
File_GetPdfInfo = 'file:getPdfInfo',
Fs_Read = 'fs:read',
Export_Word = 'export:word',

View File

@@ -408,4 +408,3 @@ export enum FeedUrl {
PRODUCTION = 'https://releases.cherry-ai.com',
EARLY_ACCESS = 'https://github.com/CherryHQ/cherry-studio/releases/latest/download'
}
export const defaultTimeout = 5 * 1000 * 60

File diff suppressed because it is too large Load Diff

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@@ -36,11 +36,6 @@ exports.default = async function (context) {
keepPackageNodeFiles(node_modules_path, '@libsql', ['win32-x64-msvc'])
}
}
if (platform === 'windows') {
fs.rmSync(path.join(context.appOutDir, 'LICENSE.electron.txt'), { force: true })
fs.rmSync(path.join(context.appOutDir, 'LICENSES.chromium.html'), { force: true })
}
}
/**

View File

@@ -1,6 +1,7 @@
import { app } from 'electron'
import { getDataPath } from './utils'
const isDev = process.env.NODE_ENV === 'development'
if (isDev) {

View File

@@ -1,7 +1,6 @@
import '@main/config'
import { electronApp, optimizer } from '@electron-toolkit/utils'
import { initAppDataDir } from '@main/utils/file'
import { replaceDevtoolsFont } from '@main/utils/windowUtil'
import { app } from 'electron'
import installExtension, { REACT_DEVELOPER_TOOLS, REDUX_DEVTOOLS } from 'electron-devtools-installer'
@@ -21,8 +20,8 @@ import selectionService, { initSelectionService } from './services/SelectionServ
import { registerShortcuts } from './services/ShortcutService'
import { TrayService } from './services/TrayService'
import { windowService } from './services/WindowService'
import { setUserDataDir } from './utils/file'
initAppDataDir()
Logger.initialize()
/**
@@ -73,6 +72,9 @@ if (!app.requestSingleInstanceLock()) {
app.quit()
process.exit(0)
} else {
// Portable dir must be setup before app ready
setUserDataDir()
// This method will be called when Electron has finished
// initialization and is ready to create browser windows.
// Some APIs can only be used after this event occurs.

View File

@@ -4,10 +4,9 @@ import { arch } from 'node:os'
import { isMac, isWin } from '@main/constant'
import { getBinaryPath, isBinaryExists, runInstallScript } from '@main/utils/process'
import { handleZoomFactor } from '@main/utils/zoom'
import { FeedUrl } from '@shared/config/constant'
import { IpcChannel } from '@shared/IpcChannel'
import { Shortcut, ThemeMode } from '@types'
import { BrowserWindow, dialog, ipcMain, session, shell } from 'electron'
import { BrowserWindow, ipcMain, session, shell } from 'electron'
import log from 'electron-log'
import { Notification } from 'src/renderer/src/types/notification'
@@ -29,19 +28,18 @@ import { SelectionService } from './services/SelectionService'
import { registerShortcuts, unregisterAllShortcuts } from './services/ShortcutService'
import storeSyncService from './services/StoreSyncService'
import { themeService } from './services/ThemeService'
import VertexAIService from './services/VertexAIService'
import { setOpenLinkExternal } from './services/WebviewService'
import { windowService } from './services/WindowService'
import { calculateDirectorySize, getResourcePath } from './utils'
import { decrypt, encrypt } from './utils/aes'
import { getCacheDir, getConfigDir, getFilesDir, hasWritePermission, updateConfig } from './utils/file'
import { getCacheDir, getConfigDir, getFilesDir } from './utils/file'
import { compress, decompress } from './utils/zip'
import { FeedUrl } from '@shared/config/constant'
const fileManager = new FileStorage()
const backupManager = new BackupManager()
const exportService = new ExportService(fileManager)
const obsidianVaultService = new ObsidianVaultService()
const vertexAIService = VertexAIService.getInstance()
export function registerIpc(mainWindow: BrowserWindow, app: Electron.App) {
const appUpdater = new AppUpdater(mainWindow)
@@ -175,70 +173,6 @@ export function registerIpc(mainWindow: BrowserWindow, app: Electron.App) {
}
})
let preventQuitListener: ((event: Electron.Event) => void) | null = null
ipcMain.handle(IpcChannel.App_SetStopQuitApp, (_, stop: boolean = false, reason: string = '') => {
if (stop) {
// Only add listener if not already added
if (!preventQuitListener) {
preventQuitListener = (event: Electron.Event) => {
event.preventDefault()
notificationService.sendNotification({
title: reason,
message: reason
} as Notification)
}
app.on('before-quit', preventQuitListener)
}
} else {
// Remove listener if it exists
if (preventQuitListener) {
app.removeListener('before-quit', preventQuitListener)
preventQuitListener = null
}
}
})
// Select app data path
ipcMain.handle(IpcChannel.App_Select, async (_, options: Electron.OpenDialogOptions) => {
try {
const { canceled, filePaths } = await dialog.showOpenDialog(options)
if (canceled || filePaths.length === 0) {
return null
}
return filePaths[0]
} catch (error: any) {
log.error('Failed to select app data path:', error)
return null
}
})
ipcMain.handle(IpcChannel.App_HasWritePermission, async (_, filePath: string) => {
return hasWritePermission(filePath)
})
// Set app data path
ipcMain.handle(IpcChannel.App_SetAppDataPath, async (_, filePath: string) => {
updateConfig(filePath)
app.setPath('userData', filePath)
})
// Copy user data to new location
ipcMain.handle(IpcChannel.App_Copy, async (_, oldPath: string, newPath: string) => {
try {
await fs.promises.cp(oldPath, newPath, { recursive: true })
return { success: true }
} catch (error: any) {
log.error('Failed to copy user data:', error)
return { success: false, error: error.message }
}
})
// Relaunch app
ipcMain.handle(IpcChannel.App_RelaunchApp, () => {
app.relaunch()
app.exit(0)
})
// check for update
ipcMain.handle(IpcChannel.App_CheckForUpdate, async () => {
return await appUpdater.checkForUpdates()
@@ -292,7 +226,6 @@ export function registerIpc(mainWindow: BrowserWindow, app: Electron.App) {
ipcMain.handle(IpcChannel.File_Base64Image, fileManager.base64Image)
ipcMain.handle(IpcChannel.File_SaveBase64Image, fileManager.saveBase64Image)
ipcMain.handle(IpcChannel.File_Base64File, fileManager.base64File)
ipcMain.handle(IpcChannel.File_GetPdfInfo, fileManager.pdfPageCount)
ipcMain.handle(IpcChannel.File_Download, fileManager.downloadFile)
ipcMain.handle(IpcChannel.File_Copy, fileManager.copyFile)
ipcMain.handle(IpcChannel.File_BinaryImage, fileManager.binaryImage)
@@ -340,15 +273,6 @@ export function registerIpc(mainWindow: BrowserWindow, app: Electron.App) {
}
})
// VertexAI
ipcMain.handle(IpcChannel.VertexAI_GetAuthHeaders, async (_, params) => {
return vertexAIService.getAuthHeaders(params)
})
ipcMain.handle(IpcChannel.VertexAI_ClearAuthCache, async (_, projectId: string, clientEmail?: string) => {
vertexAIService.clearAuthCache(projectId, clientEmail)
})
// mini window
ipcMain.handle(IpcChannel.MiniWindow_Show, () => windowService.showMiniWindow())
ipcMain.handle(IpcChannel.MiniWindow_Hide, () => windowService.hideMiniWindow())

View File

@@ -21,13 +21,10 @@ export default abstract class BaseReranker {
return 'https://dashscope.aliyuncs.com/api/v1/services/rerank/text-rerank/text-rerank'
}
let baseURL = this.base.rerankBaseURL
if (baseURL && baseURL.endsWith('/')) {
// `/` 结尾强制使用rerankBaseURL
return `${baseURL}rerank`
}
let baseURL = this.base?.rerankBaseURL?.endsWith('/')
? this.base.rerankBaseURL.slice(0, -1)
: this.base.rerankBaseURL
// 必须携带/v1否则会404
if (baseURL && !baseURL.endsWith('/v1')) {
baseURL = `${baseURL}/v1`
}
@@ -61,12 +58,6 @@ export default abstract class BaseReranker {
top_n: topN
}
}
} else if (provider?.includes('tei')) {
return {
query,
texts: documents,
return_text: true
}
} else {
return {
model: this.base.rerankModel,
@@ -86,13 +77,6 @@ export default abstract class BaseReranker {
return data.output.results
} else if (provider === 'voyageai') {
return data.data
} else if (provider === 'mis-tei') {
return data.map((item: any) => {
return {
index: item.index,
relevance_score: item.score
}
})
} else {
return data.results
}

View File

@@ -1,12 +1,11 @@
import { isWin } from '@main/constant'
import { locales } from '@main/utils/locales'
import { FeedUrl } from '@shared/config/constant'
import { IpcChannel } from '@shared/IpcChannel'
import { FeedUrl } from '@shared/config/constant'
import { UpdateInfo } from 'builder-util-runtime'
import { app, BrowserWindow, dialog } from 'electron'
import logger from 'electron-log'
import { AppUpdater as _AppUpdater, autoUpdater, NsisUpdater } from 'electron-updater'
import path from 'path'
import { AppUpdater as _AppUpdater, autoUpdater } from 'electron-updater'
import icon from '../../../build/icon.png?asset'
import { configManager } from './ConfigManager'
@@ -57,37 +56,9 @@ export default class AppUpdater {
logger.info('下载完成', releaseInfo)
})
if (isWin) {
;(autoUpdater as NsisUpdater).installDirectory = path.dirname(app.getPath('exe'))
}
this.autoUpdater = autoUpdater
}
private async _getIpCountry() {
try {
// add timeout using AbortController
const controller = new AbortController()
const timeoutId = setTimeout(() => controller.abort(), 5000)
const ipinfo = await fetch('https://ipinfo.io/json', {
signal: controller.signal,
headers: {
'User-Agent':
'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/124.0.0.0 Safari/537.36',
'Accept-Language': 'en-US,en;q=0.9'
}
})
clearTimeout(timeoutId)
const data = await ipinfo.json()
return data.country || 'CN'
} catch (error) {
logger.error('Failed to get ipinfo:', error)
return 'CN'
}
}
public setAutoUpdate(isActive: boolean) {
autoUpdater.autoDownload = isActive
autoUpdater.autoInstallOnAppQuit = isActive
@@ -106,12 +77,6 @@ export default class AppUpdater {
}
}
const ipCountry = await this._getIpCountry()
logger.info('ipCountry', ipCountry)
if (ipCountry !== 'CN') {
this.autoUpdater.setFeedURL(FeedUrl.EARLY_ACCESS)
}
try {
const update = await this.autoUpdater.checkForUpdates()
if (update?.isUpdateAvailable && !this.autoUpdater.autoDownload) {

View File

@@ -9,7 +9,6 @@ import StreamZip from 'node-stream-zip'
import * as path from 'path'
import { CreateDirectoryOptions, FileStat } from 'webdav'
import { getDataPath } from '../utils'
import WebDav from './WebDav'
import { windowService } from './WindowService'
@@ -254,7 +253,7 @@ class BackupManager {
Logger.log('[backup] step 3: restore Data directory')
// 恢复 Data 目录
const sourcePath = path.join(this.tempDir, 'Data')
const destPath = getDataPath()
const destPath = path.join(app.getPath('userData'), 'Data')
const dataExists = await fs.pathExists(sourcePath)
const dataFiles = dataExists ? await fs.readdir(sourcePath) : []

View File

@@ -15,7 +15,6 @@ import * as fs from 'fs'
import { writeFileSync } from 'fs'
import { readFile } from 'fs/promises'
import officeParser from 'officeparser'
import { getDocument } from 'officeparser/pdfjs-dist-build/pdf.js'
import * as path from 'path'
import { chdir } from 'process'
import { v4 as uuidv4 } from 'uuid'
@@ -322,16 +321,6 @@ class FileStorage {
return { data: base64, mime }
}
public pdfPageCount = async (_: Electron.IpcMainInvokeEvent, id: string): Promise<number> => {
const filePath = path.join(this.storageDir, id)
const buffer = await fs.promises.readFile(filePath)
const doc = await getDocument({ data: buffer }).promise
const pages = doc.numPages
await doc.destroy()
return pages
}
public binaryImage = async (_: Electron.IpcMainInvokeEvent, id: string): Promise<{ data: Buffer; mime: string }> => {
const filePath = path.join(this.storageDir, id)
const data = await fs.promises.readFile(filePath)

View File

@@ -25,12 +25,12 @@ import Embeddings from '@main/embeddings/Embeddings'
import { addFileLoader } from '@main/loader'
import Reranker from '@main/reranker/Reranker'
import { windowService } from '@main/services/WindowService'
import { getDataPath } from '@main/utils'
import { getAllFiles } from '@main/utils/file'
import { MB } from '@shared/config/constant'
import type { LoaderReturn } from '@shared/config/types'
import { IpcChannel } from '@shared/IpcChannel'
import { FileType, KnowledgeBaseParams, KnowledgeItem } from '@types'
import { app } from 'electron'
import Logger from 'electron-log'
import { v4 as uuidv4 } from 'uuid'
@@ -88,7 +88,7 @@ const loaderTaskIntoOfSet = (loaderTask: LoaderTask): LoaderTaskOfSet => {
}
class KnowledgeService {
private storageDir = path.join(getDataPath(), 'KnowledgeBase')
private storageDir = path.join(app.getPath('userData'), 'Data', 'KnowledgeBase')
// Byte based
private workload = 0
private processingItemCount = 0

View File

@@ -285,7 +285,7 @@ export class SelectionService {
this.processTriggerMode()
this.started = true
this.logInfo('SelectionService Started', true)
this.logInfo('SelectionService Started')
return true
}
@@ -319,7 +319,7 @@ export class SelectionService {
this.closePreloadedActionWindows()
this.started = false
this.logInfo('SelectionService Stopped', true)
this.logInfo('SelectionService Stopped')
return true
}
@@ -335,7 +335,7 @@ export class SelectionService {
this.selectionHook = null
this.initStatus = false
SelectionService.instance = null
this.logInfo('SelectionService Quitted', true)
this.logInfo('SelectionService Quitted')
}
/**
@@ -456,18 +456,8 @@ export class SelectionService {
x: posX,
y: posY
})
//set the window to always on top (highest level)
//should set every time the window is shown
this.toolbarWindow!.setAlwaysOnTop(true, 'screen-saver')
this.toolbarWindow!.show()
/**
* In Windows 10, setOpacity(1) will make the window completely transparent
* It's a strange behavior, so we don't use it for compatibility
*/
// this.toolbarWindow!.setOpacity(1)
this.toolbarWindow!.setOpacity(1)
this.startHideByMouseKeyListener()
}
@@ -477,7 +467,7 @@ export class SelectionService {
public hideToolbar(): void {
if (!this.isToolbarAlive()) return
// this.toolbarWindow!.setOpacity(0)
this.toolbarWindow!.setOpacity(0)
this.toolbarWindow!.hide()
this.stopHideByMouseKeyListener()
@@ -1274,10 +1264,8 @@ export class SelectionService {
this.isIpcHandlerRegistered = true
}
private logInfo(message: string, forceShow: boolean = false) {
if (isDev || forceShow) {
Logger.info('[SelectionService] Info: ', message)
}
private logInfo(message: string) {
isDev && Logger.info('[SelectionService] Info: ', message)
}
private logError(...args: [...string[], Error]) {

View File

@@ -1,4 +1,4 @@
import { isLinux, isMac, isWin } from '@main/constant'
import { isMac } from '@main/constant'
import { locales } from '@main/utils/locales'
import { app, Menu, MenuItemConstructorOptions, nativeImage, nativeTheme, Tray } from 'electron'
@@ -6,7 +6,6 @@ import icon from '../../../build/tray_icon.png?asset'
import iconDark from '../../../build/tray_icon_dark.png?asset'
import iconLight from '../../../build/tray_icon_light.png?asset'
import { ConfigKeys, configManager } from './ConfigManager'
import selectionService from './SelectionService'
import { windowService } from './WindowService'
export class TrayService {
@@ -30,14 +29,14 @@ export class TrayService {
const iconPath = isMac ? (nativeTheme.shouldUseDarkColors ? iconLight : iconDark) : icon
const tray = new Tray(iconPath)
if (isWin) {
if (process.platform === 'win32') {
tray.setImage(iconPath)
} else if (isMac) {
} else if (process.platform === 'darwin') {
const image = nativeImage.createFromPath(iconPath)
const resizedImage = image.resize({ width: 16, height: 16 })
resizedImage.setTemplateImage(true)
tray.setImage(resizedImage)
} else if (isLinux) {
} else if (process.platform === 'linux') {
const image = nativeImage.createFromPath(iconPath)
const resizedImage = image.resize({ width: 16, height: 16 })
tray.setImage(resizedImage)
@@ -47,7 +46,7 @@ export class TrayService {
this.updateContextMenu()
if (isLinux) {
if (process.platform === 'linux') {
this.tray.setContextMenu(this.contextMenu)
}
@@ -70,31 +69,19 @@ export class TrayService {
private updateContextMenu() {
const locale = locales[configManager.getLanguage()]
const { tray: trayLocale, selection: selectionLocale } = locale.translation
const { tray: trayLocale } = locale.translation
const quickAssistantEnabled = configManager.getEnableQuickAssistant()
const selectionAssistantEnabled = configManager.getSelectionAssistantEnabled()
const enableQuickAssistant = configManager.getEnableQuickAssistant()
const template = [
{
label: trayLocale.show_window,
click: () => windowService.showMainWindow()
},
quickAssistantEnabled && {
enableQuickAssistant && {
label: trayLocale.show_mini_window,
click: () => windowService.showMiniWindow()
},
isWin && {
label: selectionLocale.name + (selectionAssistantEnabled ? ' - On' : ' - Off'),
// type: 'checkbox',
// checked: selectionAssistantEnabled,
click: () => {
if (selectionService) {
selectionService.toggleEnabled()
this.updateContextMenu()
}
}
},
{ type: 'separator' },
{
label: trayLocale.quit,
@@ -131,10 +118,6 @@ export class TrayService {
configManager.subscribe(ConfigKeys.EnableQuickAssistant, () => {
this.updateContextMenu()
})
configManager.subscribe(ConfigKeys.SelectionAssistantEnabled, () => {
this.updateContextMenu()
})
}
private quit() {

View File

@@ -1,142 +0,0 @@
import { GoogleAuth } from 'google-auth-library'
interface ServiceAccountCredentials {
privateKey: string
clientEmail: string
}
interface VertexAIAuthParams {
projectId: string
serviceAccount?: ServiceAccountCredentials
}
const REQUIRED_VERTEX_AI_SCOPE = 'https://www.googleapis.com/auth/cloud-platform'
class VertexAIService {
private static instance: VertexAIService
private authClients: Map<string, GoogleAuth> = new Map()
static getInstance(): VertexAIService {
if (!VertexAIService.instance) {
VertexAIService.instance = new VertexAIService()
}
return VertexAIService.instance
}
/**
* 格式化私钥确保它包含正确的PEM头部和尾部
*/
private formatPrivateKey(privateKey: string): string {
if (!privateKey || typeof privateKey !== 'string') {
throw new Error('Private key must be a non-empty string')
}
// 处理JSON字符串中的转义换行符
let key = privateKey.replace(/\\n/g, '\n')
// 如果已经是正确格式的PEM直接返回
if (key.includes('-----BEGIN PRIVATE KEY-----') && key.includes('-----END PRIVATE KEY-----')) {
return key
}
// 移除所有换行符和空白字符(为了重新格式化)
key = key.replace(/\s+/g, '')
// 移除可能存在的头部和尾部
key = key.replace(/-----BEGIN[^-]*-----/g, '')
key = key.replace(/-----END[^-]*-----/g, '')
// 确保私钥不为空
if (!key) {
throw new Error('Private key is empty after formatting')
}
// 添加正确的PEM头部和尾部并格式化为64字符一行
const formattedKey = key.match(/.{1,64}/g)?.join('\n') || key
return `-----BEGIN PRIVATE KEY-----\n${formattedKey}\n-----END PRIVATE KEY-----`
}
/**
* 获取认证头用于 Vertex AI 请求
*/
async getAuthHeaders(params: VertexAIAuthParams): Promise<Record<string, string>> {
const { projectId, serviceAccount } = params
if (!serviceAccount?.privateKey || !serviceAccount?.clientEmail) {
throw new Error('Service account credentials are required')
}
// 创建缓存键
const cacheKey = `${projectId}-${serviceAccount.clientEmail}`
// 检查是否已有客户端实例
let auth = this.authClients.get(cacheKey)
if (!auth) {
try {
// 格式化私钥
const formattedPrivateKey = this.formatPrivateKey(serviceAccount.privateKey)
// 创建新的认证客户端
auth = new GoogleAuth({
credentials: {
private_key: formattedPrivateKey,
client_email: serviceAccount.clientEmail
},
projectId,
scopes: [REQUIRED_VERTEX_AI_SCOPE]
})
this.authClients.set(cacheKey, auth)
} catch (formatError: any) {
throw new Error(`Invalid private key format: ${formatError.message}`)
}
}
try {
// 获取认证头
const authHeaders = await auth.getRequestHeaders()
// 转换为普通对象
const headers: Record<string, string> = {}
for (const [key, value] of Object.entries(authHeaders)) {
if (typeof value === 'string') {
headers[key] = value
}
}
return headers
} catch (error: any) {
// 如果认证失败,清除缓存的客户端
this.authClients.delete(cacheKey)
throw new Error(`Failed to authenticate with service account: ${error.message}`)
}
}
/**
* 清理指定项目的认证缓存
*/
clearAuthCache(projectId: string, clientEmail?: string): void {
if (clientEmail) {
const cacheKey = `${projectId}-${clientEmail}`
this.authClients.delete(cacheKey)
} else {
// 清理该项目的所有缓存
for (const [key] of this.authClients) {
if (key.startsWith(`${projectId}-`)) {
this.authClients.delete(key)
}
}
}
}
/**
* 清理所有认证缓存
*/
clearAllAuthCache(): void {
this.authClients.clear()
}
}
export default VertexAIService

View File

@@ -56,7 +56,7 @@ export class WindowService {
minHeight: 600,
show: false,
autoHideMenuBar: true,
transparent: false,
transparent: isMac,
vibrancy: 'sidebar',
visualEffectState: 'active',
titleBarStyle: 'hidden',

View File

@@ -2,7 +2,7 @@ import * as fs from 'node:fs'
import os from 'node:os'
import path from 'node:path'
import { isPortable } from '@main/constant'
import { isMac } from '@main/constant'
import { audioExts, documentExts, imageExts, textExts, videoExts } from '@shared/config/constant'
import { FileType, FileTypes } from '@types'
import { app } from 'electron'
@@ -23,61 +23,6 @@ function initFileTypeMap() {
// 初始化映射表
initFileTypeMap()
export function hasWritePermission(path: string) {
try {
fs.accessSync(path, fs.constants.W_OK)
return true
} catch (error) {
return false
}
}
function getAppDataPathFromConfig() {
try {
const configPath = path.join(getConfigDir(), 'config.json')
if (fs.existsSync(configPath)) {
const config = JSON.parse(fs.readFileSync(configPath, 'utf-8'))
if (config.appDataPath && fs.existsSync(config.appDataPath) && hasWritePermission(config.appDataPath)) {
return config.appDataPath
}
}
} catch (error) {
return null
}
return null
}
export function initAppDataDir() {
const appDataPath = getAppDataPathFromConfig()
if (appDataPath) {
app.setPath('userData', appDataPath)
return
}
if (isPortable) {
const portableDir = process.env.PORTABLE_EXECUTABLE_DIR
app.setPath('userData', path.join(portableDir || app.getPath('exe'), 'data'))
return
}
}
export function updateConfig(appDataPath: string) {
const configDir = getConfigDir()
if (!fs.existsSync(configDir)) {
fs.mkdirSync(configDir, { recursive: true })
}
const configPath = path.join(getConfigDir(), 'config.json')
if (!fs.existsSync(configPath)) {
fs.writeFileSync(configPath, JSON.stringify({ appDataPath }, null, 2))
return
}
const config = JSON.parse(fs.readFileSync(configPath, 'utf-8'))
config.appDataPath = appDataPath
fs.writeFileSync(configPath, JSON.stringify(config, null, 2))
}
export function getFileType(ext: string): FileTypes {
ext = ext.toLowerCase()
return fileTypeMap.get(ext) || FileTypes.OTHER
@@ -143,3 +88,12 @@ export function getCacheDir() {
export function getAppConfigDir(name: string) {
return path.join(getConfigDir(), name)
}
export function setUserDataDir() {
if (!isMac) {
const dir = path.join(path.dirname(app.getPath('exe')), 'data')
if (fs.existsSync(dir) && fs.statSync(dir).isDirectory()) {
app.setPath('userData', dir)
}
}
}

View File

@@ -26,12 +26,6 @@ const api = {
handleZoomFactor: (delta: number, reset: boolean = false) =>
ipcRenderer.invoke(IpcChannel.App_HandleZoomFactor, delta, reset),
setAutoUpdate: (isActive: boolean) => ipcRenderer.invoke(IpcChannel.App_SetAutoUpdate, isActive),
select: (options: Electron.OpenDialogOptions) => ipcRenderer.invoke(IpcChannel.App_Select, options),
hasWritePermission: (path: string) => ipcRenderer.invoke(IpcChannel.App_HasWritePermission, path),
setAppDataPath: (path: string) => ipcRenderer.invoke(IpcChannel.App_SetAppDataPath, path),
copy: (oldPath: string, newPath: string) => ipcRenderer.invoke(IpcChannel.App_Copy, oldPath, newPath),
setStopQuitApp: (stop: boolean, reason: string) => ipcRenderer.invoke(IpcChannel.App_SetStopQuitApp, stop, reason),
relaunchApp: () => ipcRenderer.invoke(IpcChannel.App_RelaunchApp),
openWebsite: (url: string) => ipcRenderer.invoke(IpcChannel.Open_Website, url),
getCacheSize: () => ipcRenderer.invoke(IpcChannel.App_GetCacheSize),
clearCache: () => ipcRenderer.invoke(IpcChannel.App_ClearCache),
@@ -89,7 +83,6 @@ const api = {
copy: (fileId: string, destPath: string) => ipcRenderer.invoke(IpcChannel.File_Copy, fileId, destPath),
binaryImage: (fileId: string) => ipcRenderer.invoke(IpcChannel.File_BinaryImage, fileId),
base64File: (fileId: string) => ipcRenderer.invoke(IpcChannel.File_Base64File, fileId),
pdfInfo: (fileId: string) => ipcRenderer.invoke(IpcChannel.File_GetPdfInfo, fileId),
getPathForFile: (file: File) => webUtils.getPathForFile(file)
},
fs: {
@@ -135,13 +128,6 @@ const api = {
listFiles: (apiKey: string) => ipcRenderer.invoke(IpcChannel.Gemini_ListFiles, apiKey),
deleteFile: (fileId: string, apiKey: string) => ipcRenderer.invoke(IpcChannel.Gemini_DeleteFile, fileId, apiKey)
},
vertexAI: {
getAuthHeaders: (params: { projectId: string; serviceAccount?: { privateKey: string; clientEmail: string } }) =>
ipcRenderer.invoke(IpcChannel.VertexAI_GetAuthHeaders, params),
clearAuthCache: (projectId: string, clientEmail?: string) =>
ipcRenderer.invoke(IpcChannel.VertexAI_ClearAuthCache, projectId, clientEmail)
},
config: {
set: (key: string, value: any, isNotify: boolean = false) =>
ipcRenderer.invoke(IpcChannel.Config_Set, key, value, isNotify),

View File

@@ -1,223 +0,0 @@
# Cherry Studio AI Provider 技术架构文档 (新方案)
## 1. 核心设计理念与目标
本架构旨在重构 Cherry Studio 的 AI Provider现称为 `aiCore`)层,以实现以下目标:
- **职责清晰**:明确划分各组件的职责,降低耦合度。
- **高度复用**:最大化业务逻辑和通用处理逻辑的复用,减少重复代码。
- **易于扩展**:方便快捷地接入新的 AI Provider (LLM供应商) 和添加新的 AI 功能 (如翻译、摘要、图像生成等)。
- **易于维护**:简化单个组件的复杂性,提高代码的可读性和可维护性。
- **标准化**:统一内部数据流和接口,简化不同 Provider 之间的差异处理。
核心思路是将纯粹的 **SDK 适配层 (`XxxApiClient`)**、**通用逻辑处理与智能解析层 (中间件)** 以及 **统一业务功能入口层 (`AiCoreService`)** 清晰地分离开来。
## 2. 核心组件详解
### 2.1. `aiCore` (原 `AiProvider` 文件夹)
这是整个 AI 功能的核心模块。
#### 2.1.1. `XxxApiClient` (例如 `aiCore/clients/openai/OpenAIApiClient.ts`)
- **职责**:作为特定 AI Provider SDK 的纯粹适配层。
- **参数适配**:将应用内部统一的 `CoreRequest` 对象 (见下文) 转换为特定 SDK 所需的请求参数格式。
- **基础响应转换**:将 SDK 返回的原始数据块 (`RawSdkChunk`,例如 `OpenAI.Chat.Completions.ChatCompletionChunk`) 转换为一组最基础、最直接的应用层 `Chunk` 对象 (定义于 `src/renderer/src/types/chunk.ts`)。
- 例如SDK 的 `delta.content` -> `TextDeltaChunk`SDK 的 `delta.reasoning_content` -> `ThinkingDeltaChunk`SDK 的 `delta.tool_calls` -> `RawToolCallChunk` (包含原始工具调用数据)。
- **关键**`XxxApiClient` **不处理**耦合在文本内容中的复杂结构,如 `<think>``<tool_use>` 标签。
- **特点**:极度轻量化,代码量少,易于实现和维护新的 Provider 适配。
#### 2.1.2. `ApiClient.ts` (或 `BaseApiClient.ts` 的核心接口)
- 定义了所有 `XxxApiClient` 必须实现的接口,如:
- `getSdkInstance(): Promise<TSdkInstance> | TSdkInstance`
- `getRequestTransformer(): RequestTransformer<TSdkParams>`
- `getResponseChunkTransformer(): ResponseChunkTransformer<TRawChunk, TResponseContext>`
- 其他可选的、与特定 Provider 相关的辅助方法 (如工具调用转换)。
#### 2.1.3. `ApiClientFactory.ts`
- 根据 Provider 配置动态创建和返回相应的 `XxxApiClient` 实例。
#### 2.1.4. `AiCoreService.ts` (`aiCore/index.ts`)
- **职责**:作为所有 AI 相关业务功能的统一入口。
- 提供面向应用的高层接口,例如:
- `executeCompletions(params: CompletionsParams): Promise<AggregatedCompletionsResult>`
- `translateText(params: TranslateParams): Promise<AggregatedTranslateResult>`
- `summarizeText(params: SummarizeParams): Promise<AggregatedSummarizeResult>`
- 未来可能的 `generateImage(prompt: string): Promise<ImageResult>` 等。
- **返回 `Promise`**:每个服务方法返回一个 `Promise`,该 `Promise` 会在整个(可能是流式的)操作完成后,以包含所有聚合结果(如完整文本、工具调用详情、最终的`usage`/`metrics`等)的对象来 `resolve`
- **支持流式回调**:服务方法的参数 (如 `CompletionsParams`) 依然包含 `onChunk` 回调,用于向调用方实时推送处理过程中的 `Chunk` 数据实现流式UI更新。
- **封装特定任务的提示工程 (Prompt Engineering)**
- 例如,`translateText` 方法内部会构建一个包含特定翻译指令的 `CoreRequest`
- **编排和调用中间件链**:通过内部的 `MiddlewareBuilder` (参见 `middleware/BUILDER_USAGE.md`) 实例,根据调用的业务方法和参数,动态构建和组织合适的中间件序列,然后通过 `applyCompletionsMiddlewares` 等组合函数执行。
- 获取 `ApiClient` 实例并将其注入到中间件上游的 `Context` 中。
- **将 `Promise``resolve``reject` 函数传递给中间件链** (通过 `Context`),以便 `FinalChunkConsumerAndNotifierMiddleware` 可以在操作完成或发生错误时结束该 `Promise`
- **优势**
- 业务逻辑(如翻译、摘要的提示构建和流程控制)只需实现一次,即可支持所有通过 `ApiClient` 接入的底层 Provider。
- **支持外部编排**:调用方可以 `await` 服务方法以获取最终聚合结果,然后将此结果作为后续操作的输入,轻松实现多步骤工作流。
- **支持内部组合**:服务自身也可以通过 `await` 调用其他原子服务方法来构建更复杂的组合功能。
#### 2.1.5. `coreRequestTypes.ts` (或 `types.ts`)
- 定义核心的、Provider 无关的内部请求结构,例如:
- `CoreCompletionsRequest`: 包含标准化后的消息列表、模型配置、工具列表、最大Token数、是否流式输出等。
- `CoreTranslateRequest`, `CoreSummarizeRequest` 等 (如果与 `CoreCompletionsRequest` 结构差异较大,否则可复用并添加任务类型标记)。
### 2.2. `middleware`
中间件层负责处理请求和响应流中的通用逻辑和特定特性。其设计和使用遵循 `middleware/BUILDER_USAGE.md` 中定义的规范。
**核心组件包括:**
- **`MiddlewareBuilder`**: 一个通用的、提供流式API的类用于动态构建中间件链。它支持从基础链开始根据条件添加、插入、替换或移除中间件。
- **`applyCompletionsMiddlewares`**: 负责接收 `MiddlewareBuilder` 构建的链并按顺序执行,专门用于 Completions 流程。
- **`MiddlewareRegistry`**: 集中管理所有可用中间件的注册表,提供统一的中间件访问接口。
- **各种独立的中间件模块** (存放于 `common/`, `core/`, `feat/` 子目录)。
#### 2.2.1. `middlewareTypes.ts`
- 定义中间件的核心类型,如 `AiProviderMiddlewareContext` (扩展后包含 `_apiClientInstance``_coreRequest`)、`MiddlewareAPI``CompletionsMiddleware` 等。
#### 2.2.2. 核心中间件 (`middleware/core/`)
- **`TransformCoreToSdkParamsMiddleware.ts`**: 调用 `ApiClient.getRequestTransformer()``CoreRequest` 转换为特定 SDK 的参数,并存入上下文。
- **`RequestExecutionMiddleware.ts`**: 调用 `ApiClient.getSdkInstance()` 获取 SDK 实例,并使用转换后的参数执行实际的 API 调用,返回原始 SDK 流。
- **`StreamAdapterMiddleware.ts`**: 将各种形态的原始 SDK 流 (如异步迭代器) 统一适配为 `ReadableStream<RawSdkChunk>`
- **`RawSdkChunk`**指特定AI提供商SDK在流式响应中返回的、未经应用层统一处理的原始数据块格式 (例如 OpenAI 的 `ChatCompletionChunk`Gemini 的 `GenerateContentResponse` 中的部分等)。
- **`RawSdkChunkToAppChunkMiddleware.ts`**: (新增) 消费 `ReadableStream<RawSdkChunk>`,在其内部对每个 `RawSdkChunk` 调用 `ApiClient.getResponseChunkTransformer()`,将其转换为一个或多个基础的应用层 `Chunk` 对象,并输出 `ReadableStream<Chunk>`
#### 2.2.3. 特性中间件 (`middleware/feat/`)
这些中间件消费由 `ResponseTransformMiddleware` 输出的、相对标准化的 `Chunk` 流,并处理更复杂的逻辑。
- **`ThinkingTagExtractionMiddleware.ts`**: 检查 `TextDeltaChunk`,解析其中可能包含的 `<think>...</think>` 文本内嵌标签,生成 `ThinkingDeltaChunk``ThinkingCompleteChunk`
- **`ToolUseExtractionMiddleware.ts`**: 检查 `TextDeltaChunk`,解析其中可能包含的 `<tool_use>...</tool_use>` 文本内嵌标签,生成工具调用相关的 Chunk。如果 `ApiClient` 输出了原生工具调用数据,此中间件也负责将其转换为标准格式。
#### 2.2.4. 核心处理中间件 (`middleware/core/`)
- **`TransformCoreToSdkParamsMiddleware.ts`**: 调用 `ApiClient.getRequestTransformer()``CoreRequest` 转换为特定 SDK 的参数,并存入上下文。
- **`SdkCallMiddleware.ts`**: 调用 `ApiClient.getSdkInstance()` 获取 SDK 实例,并使用转换后的参数执行实际的 API 调用,返回原始 SDK 流。
- **`StreamAdapterMiddleware.ts`**: 将各种形态的原始 SDK 流统一适配为标准流格式。
- **`ResponseTransformMiddleware.ts`**: 将原始 SDK 响应转换为应用层标准 `Chunk` 对象。
- **`TextChunkMiddleware.ts`**: 处理文本相关的 Chunk 流。
- **`ThinkChunkMiddleware.ts`**: 处理思考相关的 Chunk 流。
- **`McpToolChunkMiddleware.ts`**: 处理工具调用相关的 Chunk 流。
- **`WebSearchMiddleware.ts`**: 处理 Web 搜索相关逻辑。
#### 2.2.5. 通用中间件 (`middleware/common/`)
- **`LoggingMiddleware.ts`**: 请求和响应日志。
- **`AbortHandlerMiddleware.ts`**: 处理请求中止。
- **`FinalChunkConsumerMiddleware.ts`**: 消费最终的 `Chunk` 流,通过 `context.onChunk` 回调通知应用层实时数据。
- **累积数据**:在流式处理过程中,累积关键数据,如文本片段、工具调用信息、`usage`/`metrics` 等。
- **结束 `Promise`**:当输入流结束时,使用累积的聚合结果来完成整个处理流程。
- 在流结束时,发送包含最终累加信息的完成信号。
### 2.3. `types/chunk.ts`
- 定义应用全局统一的 `Chunk` 类型及其所有变体。这包括基础类型 (如 `TextDeltaChunk`, `ThinkingDeltaChunk`)、SDK原生数据传递类型 (如 `RawToolCallChunk`, `RawFinishChunk` - 作为 `ApiClient` 转换的中间产物),以及功能性类型 (如 `McpToolCallRequestChunk`, `WebSearchCompleteChunk`)。
## 3. 核心执行流程 (以 `AiCoreService.executeCompletions` 为例)
```markdown
**应用层 (例如 UI 组件)**
||
\\/
**`AiProvider.completions` (`aiCore/index.ts`)**
(1. prepare ApiClient instance. 2. use `CompletionsMiddlewareBuilder.withDefaults()` to build middleware chain. 3. call `applyCompletionsMiddlewares`)
||
\\/
**`applyCompletionsMiddlewares` (`middleware/composer.ts`)**
(接收构建好的链、ApiClient实例、原始SDK方法开始按序执行中间件)
||
\\/
**[ 预处理阶段中间件 ]**
(例如: `FinalChunkConsumerMiddleware`, `TransformCoreToSdkParamsMiddleware`, `AbortHandlerMiddleware`)
|| (Context 中准备好 SDK 请求参数)
\\/
**[ 处理阶段中间件 ]**
(例如: `McpToolChunkMiddleware`, `WebSearchMiddleware`, `TextChunkMiddleware`, `ThinkingTagExtractionMiddleware`)
|| (处理各种特性和Chunk类型)
\\/
**[ SDK调用阶段中间件 ]**
(例如: `ResponseTransformMiddleware`, `StreamAdapterMiddleware`, `SdkCallMiddleware`)
|| (输出: 标准化的应用层Chunk流)
\\/
**`FinalChunkConsumerMiddleware` (核心)**
(消费最终的 `Chunk` 流, 通过 `context.onChunk` 回调通知应用层, 并在流结束时完成处理)
||
\\/
**`AiProvider.completions` 返回 `Promise<CompletionsResult>`**
```
## 4. 建议的文件/目录结构
```
src/renderer/src/
└── aiCore/
├── clients/
│ ├── openai/
│ ├── gemini/
│ ├── anthropic/
│ ├── BaseApiClient.ts
│ ├── ApiClientFactory.ts
│ ├── AihubmixAPIClient.ts
│ ├── index.ts
│ └── types.ts
├── middleware/
│ ├── common/
│ ├── core/
│ ├── feat/
│ ├── builder.ts
│ ├── composer.ts
│ ├── index.ts
│ ├── register.ts
│ ├── schemas.ts
│ ├── types.ts
│ └── utils.ts
├── types/
│ ├── chunk.ts
│ └── ...
└── index.ts
```
## 5. 迁移和实施建议
- **小步快跑,逐步迭代**:优先完成核心流程的重构(例如 `completions`),再逐步迁移其他功能(`translate` 等)和其他 Provider。
- **优先定义核心类型**`CoreRequest`, `Chunk`, `ApiClient` 接口是整个架构的基石。
- **为 `ApiClient` 瘦身**:将现有 `XxxProvider` 中的复杂逻辑剥离到新的中间件或 `AiCoreService` 中。
- **强化中间件**:让中间件承担起更多解析和特性处理的责任。
- **编写单元测试和集成测试**:确保每个组件和整体流程的正确性。
此架构旨在提供一个更健壮、更灵活、更易于维护的 AI 功能核心,支撑 Cherry Studio 未来的发展。
## 6. 迁移策略与实施建议
本节内容提炼自早期的 `migrate.md` 文档,并根据最新的架构讨论进行了调整。
**目标架构核心组件回顾:**
与第 2 节描述的核心组件一致,主要包括 `XxxApiClient`, `AiCoreService`, 中间件链, `CoreRequest` 类型, 和标准化的 `Chunk` 类型。
**迁移步骤:**
**Phase 0: 准备工作和类型定义**
1. **定义核心数据结构 (TypeScript 类型)**
- `CoreCompletionsRequest` (Type):定义应用内部统一的对话请求结构。
- `Chunk` (Type - 检查并按需扩展现有 `src/renderer/src/types/chunk.ts`)定义所有可能的通用Chunk类型。
- 为其他API翻译、总结定义类似的 `CoreXxxRequest` (Type)。
2. **定义 `ApiClient` 接口:** 明确 `getRequestTransformer`, `getResponseChunkTransformer`, `getSdkInstance` 等核心方法。
3. **调整 `AiProviderMiddlewareContext`**
- 确保包含 `_apiClientInstance: ApiClient<any,any,any>`
- 确保包含 `_coreRequest: CoreRequestType`
- 考虑添加 `resolvePromise: (value: AggregatedResultType) => void``rejectPromise: (reason?: any) => void` 用于 `AiCoreService` 的 Promise 返回。
**Phase 1: 实现第一个 `ApiClient` (以 `OpenAIApiClient` 为例)**
1. **创建 `OpenAIApiClient` 类:** 实现 `ApiClient` 接口。
2. **迁移SDK实例和配置。**
3. **实现 `getRequestTransformer()`**`CoreCompletionsRequest` 转换为 OpenAI SDK 参数。
4. **实现 `getResponseChunkTransformer()`**`OpenAI.Chat.Completions.ChatCompletionChunk` 转换为基础的 `

View File

@@ -1,284 +0,0 @@
/**
* AI SDK 到 Cherry Studio Chunk 适配器
* 用于将 AI SDK 的 fullStream 转换为 Cherry Studio 的 chunk 格式
*/
import { TextStreamPart } from '@cherrystudio/ai-core'
import { Chunk, ChunkType } from '@renderer/types/chunk'
export interface CherryStudioChunk {
type: 'text-delta' | 'text-complete' | 'tool-call' | 'tool-result' | 'finish' | 'error'
text?: string
toolCall?: any
toolResult?: any
finishReason?: string
usage?: any
error?: any
}
/**
* AI SDK 到 Cherry Studio Chunk 适配器类
* 处理 fullStream 到 Cherry Studio chunk 的转换
*/
export class AiSdkToChunkAdapter {
constructor(private onChunk: (chunk: Chunk) => void) {}
/**
* 处理 AI SDK 流结果
* @param aiSdkResult AI SDK 的流结果对象
* @returns 最终的文本内容
*/
async processStream(aiSdkResult: any): Promise<string> {
// 如果是流式且有 fullStream
if (aiSdkResult.fullStream) {
await this.readFullStream(aiSdkResult.fullStream)
}
// 使用 streamResult.text 获取最终结果
return await aiSdkResult.text
}
/**
* 读取 fullStream 并转换为 Cherry Studio chunks
* @param fullStream AI SDK 的 fullStream (ReadableStream)
*/
private async readFullStream(fullStream: ReadableStream<TextStreamPart<any>>) {
const reader = fullStream.getReader()
const final = {
text: '',
reasoning_content: ''
}
try {
while (true) {
const { done, value } = await reader.read()
if (done) {
break
}
// 转换并发送 chunk
this.convertAndEmitChunk(value, final)
}
} finally {
reader.releaseLock()
}
}
/**
* 转换 AI SDK chunk 为 Cherry Studio chunk 并调用回调
* @param chunk AI SDK 的 chunk 数据
*/
private convertAndEmitChunk(chunk: any, final: { text: string; reasoning_content: string }) {
console.log('AI SDK chunk type:', chunk.type, chunk)
switch (chunk.type) {
// === 文本相关事件 ===
case 'text-delta':
final.text += chunk.textDelta || ''
this.onChunk({
type: ChunkType.TEXT_DELTA,
text: chunk.textDelta || ''
})
break
case 'reasoning':
this.onChunk({
type: ChunkType.THINKING_DELTA,
text: chunk.textDelta || ''
})
break
case 'redacted-reasoning':
this.onChunk({
type: ChunkType.THINKING_DELTA,
text: chunk.data || ''
})
break
case 'reasoning-signature':
this.onChunk({
type: ChunkType.THINKING_COMPLETE,
text: chunk.text || '',
thinking_millsec: chunk.thinking_millsec || 0
})
break
// === 工具调用相关事件 ===
case 'tool-call-streaming-start':
// 开始流式工具调用
this.onChunk({
type: ChunkType.MCP_TOOL_CREATED,
tool_calls: [
{
id: chunk.toolCallId,
name: chunk.toolName,
args: {}
}
]
})
break
case 'tool-call-delta':
// 工具调用参数的增量更新
this.onChunk({
type: ChunkType.MCP_TOOL_IN_PROGRESS,
responses: [
{
id: chunk.toolCallId,
tool: {
id: chunk.toolName,
// TODO: serverId,serverName
serverId: 'ai-sdk',
serverName: 'AI SDK',
name: chunk.toolName,
description: '',
inputSchema: {
type: 'object',
title: chunk.toolName,
properties: {}
}
},
arguments: {},
status: 'invoking',
response: chunk.argsTextDelta,
toolCallId: chunk.toolCallId
}
]
})
break
case 'tool-call':
// 完整的工具调用
this.onChunk({
type: ChunkType.MCP_TOOL_CREATED,
tool_calls: [
{
id: chunk.toolCallId,
name: chunk.toolName,
args: chunk.args
}
]
})
break
case 'tool-result':
// 工具调用结果
this.onChunk({
type: ChunkType.MCP_TOOL_COMPLETE,
responses: [
{
id: chunk.toolCallId,
tool: {
id: chunk.toolName,
// TODO: serverId,serverName
serverId: 'ai-sdk',
serverName: 'AI SDK',
name: chunk.toolName,
description: '',
inputSchema: {
type: 'object',
title: chunk.toolName,
properties: {}
}
},
arguments: chunk.args || {},
status: 'done',
response: chunk.result,
toolCallId: chunk.toolCallId
}
]
})
break
// === 步骤相关事件 ===
// case 'step-start':
// this.onChunk({
// type: ChunkType.LLM_RESPONSE_CREATED
// })
// break
case 'step-finish':
this.onChunk({
type: ChunkType.BLOCK_COMPLETE,
response: {
text: final.text || '',
reasoning_content: final.reasoning_content || '',
usage: {
completion_tokens: chunk.usage.completionTokens || 0,
prompt_tokens: chunk.usage.promptTokens || 0,
total_tokens: chunk.usage.totalTokens || 0
},
metrics: chunk.usage
? {
completion_tokens: chunk.usage.completionTokens || 0,
time_completion_millsec: 0
}
: undefined
}
})
break
case 'finish':
this.onChunk({
type: ChunkType.TEXT_COMPLETE,
text: final.text || '' // TEXT_COMPLETE 需要 text 字段
})
this.onChunk({
type: ChunkType.LLM_RESPONSE_COMPLETE,
response: {
text: final.text || '',
reasoning_content: final.reasoning_content || '',
usage: {
completion_tokens: chunk.usage.completionTokens || 0,
prompt_tokens: chunk.usage.promptTokens || 0,
total_tokens: chunk.usage.totalTokens || 0
},
metrics: chunk.usage
? {
completion_tokens: chunk.usage.completionTokens || 0,
time_completion_millsec: 0
}
: undefined
}
})
break
// === 源和文件相关事件 ===
case 'source':
// 源信息,可以映射到知识搜索完成
this.onChunk({
type: ChunkType.KNOWLEDGE_SEARCH_COMPLETE,
knowledge: [
{
id: Number(chunk.source.id) || Date.now(),
content: chunk.source.title || '',
sourceUrl: chunk.source.url || '',
type: 'url'
}
]
})
break
case 'file':
// 文件相关事件,可能是图片生成
this.onChunk({
type: ChunkType.IMAGE_COMPLETE,
image: {
type: 'base64',
images: [chunk.base64]
}
})
break
case 'error':
this.onChunk({
type: ChunkType.ERROR,
error: {
message: chunk.error || 'Unknown error'
}
})
break
default:
// 其他类型的 chunk 可以忽略或记录日志
console.log('Unhandled AI SDK chunk type:', chunk.type, chunk)
}
}
}
export default AiSdkToChunkAdapter

View File

@@ -1,208 +0,0 @@
import { isOpenAILLMModel } from '@renderer/config/models'
import {
GenerateImageParams,
MCPCallToolResponse,
MCPTool,
MCPToolResponse,
Model,
Provider,
ToolCallResponse
} from '@renderer/types'
import {
RequestOptions,
SdkInstance,
SdkMessageParam,
SdkModel,
SdkParams,
SdkRawChunk,
SdkRawOutput,
SdkTool,
SdkToolCall
} from '@renderer/types/sdk'
import { CompletionsContext } from '../middleware/types'
import { AnthropicAPIClient } from './anthropic/AnthropicAPIClient'
import { BaseApiClient } from './BaseApiClient'
import { GeminiAPIClient } from './gemini/GeminiAPIClient'
import { OpenAIAPIClient } from './openai/OpenAIApiClient'
import { OpenAIResponseAPIClient } from './openai/OpenAIResponseAPIClient'
import { RequestTransformer, ResponseChunkTransformer } from './types'
/**
* AihubmixAPIClient - 根据模型类型自动选择合适的ApiClient
* 使用装饰器模式实现在ApiClient层面进行模型路由
*/
export class AihubmixAPIClient extends BaseApiClient {
// 使用联合类型而不是any保持类型安全
private clients: Map<string, AnthropicAPIClient | GeminiAPIClient | OpenAIResponseAPIClient | OpenAIAPIClient> =
new Map()
private defaultClient: OpenAIAPIClient
private currentClient: BaseApiClient
constructor(provider: Provider) {
super(provider)
// 初始化各个client - 现在有类型安全
const claudeClient = new AnthropicAPIClient(provider)
const geminiClient = new GeminiAPIClient({ ...provider, apiHost: 'https://aihubmix.com/gemini' })
const openaiClient = new OpenAIResponseAPIClient(provider)
const defaultClient = new OpenAIAPIClient(provider)
this.clients.set('claude', claudeClient)
this.clients.set('gemini', geminiClient)
this.clients.set('openai', openaiClient)
this.clients.set('default', defaultClient)
// 设置默认client
this.defaultClient = defaultClient
this.currentClient = this.defaultClient as BaseApiClient
}
/**
* 类型守卫确保client是BaseApiClient的实例
*/
private isValidClient(client: unknown): client is BaseApiClient {
return (
client !== null &&
client !== undefined &&
typeof client === 'object' &&
'createCompletions' in client &&
'getRequestTransformer' in client &&
'getResponseChunkTransformer' in client
)
}
/**
* 根据模型获取合适的client
*/
private getClient(model: Model): BaseApiClient {
const id = model.id.toLowerCase()
// claude开头
if (id.startsWith('claude')) {
const client = this.clients.get('claude')
if (!client || !this.isValidClient(client)) {
throw new Error('Claude client not properly initialized')
}
return client
}
// gemini开头 且不以-nothink、-search结尾
if ((id.startsWith('gemini') || id.startsWith('imagen')) && !id.endsWith('-nothink') && !id.endsWith('-search')) {
const client = this.clients.get('gemini')
if (!client || !this.isValidClient(client)) {
throw new Error('Gemini client not properly initialized')
}
return client
}
// OpenAI系列模型
if (isOpenAILLMModel(model)) {
const client = this.clients.get('openai')
if (!client || !this.isValidClient(client)) {
throw new Error('OpenAI client not properly initialized')
}
return client
}
return this.defaultClient as BaseApiClient
}
/**
* 根据模型选择合适的client并委托调用
*/
public getClientForModel(model: Model): BaseApiClient {
this.currentClient = this.getClient(model)
return this.currentClient
}
// ============ BaseApiClient 抽象方法实现 ============
async createCompletions(payload: SdkParams, options?: RequestOptions): Promise<SdkRawOutput> {
// 尝试从payload中提取模型信息来选择client
const modelId = this.extractModelFromPayload(payload)
if (modelId) {
const modelObj = { id: modelId } as Model
const targetClient = this.getClient(modelObj)
return targetClient.createCompletions(payload, options)
}
// 如果无法从payload中提取模型使用当前设置的client
return this.currentClient.createCompletions(payload, options)
}
/**
* 从SDK payload中提取模型ID
*/
private extractModelFromPayload(payload: SdkParams): string | null {
// 不同的SDK可能有不同的字段名
if ('model' in payload && typeof payload.model === 'string') {
return payload.model
}
return null
}
async generateImage(params: GenerateImageParams): Promise<string[]> {
return this.currentClient.generateImage(params)
}
async getEmbeddingDimensions(model?: Model): Promise<number> {
const client = model ? this.getClient(model) : this.currentClient
return client.getEmbeddingDimensions(model)
}
async listModels(): Promise<SdkModel[]> {
// 可以聚合所有client的模型或者使用默认client
return this.defaultClient.listModels()
}
async getSdkInstance(): Promise<SdkInstance> {
return this.currentClient.getSdkInstance()
}
getRequestTransformer(): RequestTransformer<SdkParams, SdkMessageParam> {
return this.currentClient.getRequestTransformer()
}
getResponseChunkTransformer(ctx: CompletionsContext): ResponseChunkTransformer<SdkRawChunk> {
return this.currentClient.getResponseChunkTransformer(ctx)
}
convertMcpToolsToSdkTools(mcpTools: MCPTool[]): SdkTool[] {
return this.currentClient.convertMcpToolsToSdkTools(mcpTools)
}
convertSdkToolCallToMcp(toolCall: SdkToolCall, mcpTools: MCPTool[]): MCPTool | undefined {
return this.currentClient.convertSdkToolCallToMcp(toolCall, mcpTools)
}
convertSdkToolCallToMcpToolResponse(toolCall: SdkToolCall, mcpTool: MCPTool): ToolCallResponse {
return this.currentClient.convertSdkToolCallToMcpToolResponse(toolCall, mcpTool)
}
buildSdkMessages(
currentReqMessages: SdkMessageParam[],
output: SdkRawOutput | string,
toolResults: SdkMessageParam[],
toolCalls?: SdkToolCall[]
): SdkMessageParam[] {
return this.currentClient.buildSdkMessages(currentReqMessages, output, toolResults, toolCalls)
}
convertMcpToolResponseToSdkMessageParam(
mcpToolResponse: MCPToolResponse,
resp: MCPCallToolResponse,
model: Model
): SdkMessageParam | undefined {
const client = this.getClient(model)
return client.convertMcpToolResponseToSdkMessageParam(mcpToolResponse, resp, model)
}
extractMessagesFromSdkPayload(sdkPayload: SdkParams): SdkMessageParam[] {
return this.currentClient.extractMessagesFromSdkPayload(sdkPayload)
}
estimateMessageTokens(message: SdkMessageParam): number {
return this.currentClient.estimateMessageTokens(message)
}
}

View File

@@ -1,66 +0,0 @@
import { Provider } from '@renderer/types'
import { AihubmixAPIClient } from './AihubmixAPIClient'
import { AnthropicAPIClient } from './anthropic/AnthropicAPIClient'
import { BaseApiClient } from './BaseApiClient'
import { GeminiAPIClient } from './gemini/GeminiAPIClient'
import { VertexAPIClient } from './gemini/VertexAPIClient'
import { OpenAIAPIClient } from './openai/OpenAIApiClient'
import { OpenAIResponseAPIClient } from './openai/OpenAIResponseAPIClient'
/**
* Factory for creating ApiClient instances based on provider configuration
* 根据提供者配置创建ApiClient实例的工厂
*/
export class ApiClientFactory {
/**
* Create an ApiClient instance for the given provider
* 为给定的提供者创建ApiClient实例
*/
static create(provider: Provider): BaseApiClient {
console.log(`[ApiClientFactory] Creating ApiClient for provider:`, {
id: provider.id,
type: provider.type
})
let instance: BaseApiClient
// 首先检查特殊的provider id
if (provider.id === 'aihubmix') {
console.log(`[ApiClientFactory] Creating AihubmixAPIClient for provider: ${provider.id}`)
instance = new AihubmixAPIClient(provider) as BaseApiClient
return instance
}
// 然后检查标准的provider type
switch (provider.type) {
case 'openai':
case 'azure-openai':
console.log(`[ApiClientFactory] Creating OpenAIApiClient for provider: ${provider.id}`)
instance = new OpenAIAPIClient(provider) as BaseApiClient
break
case 'openai-response':
instance = new OpenAIResponseAPIClient(provider) as BaseApiClient
break
case 'gemini':
instance = new GeminiAPIClient(provider) as BaseApiClient
break
case 'vertexai':
instance = new VertexAPIClient(provider) as BaseApiClient
break
case 'anthropic':
instance = new AnthropicAPIClient(provider) as BaseApiClient
break
default:
console.log(`[ApiClientFactory] Using default OpenAIApiClient for provider: ${provider.id}`)
instance = new OpenAIAPIClient(provider) as BaseApiClient
break
}
return instance
}
}
export function isOpenAIProvider(provider: Provider) {
return !['anthropic', 'gemini'].includes(provider.type)
}

View File

@@ -1,734 +0,0 @@
import Anthropic from '@anthropic-ai/sdk'
import {
Base64ImageSource,
ImageBlockParam,
MessageParam,
TextBlockParam,
ToolResultBlockParam,
ToolUseBlock,
WebSearchTool20250305
} from '@anthropic-ai/sdk/resources'
import {
ContentBlock,
ContentBlockParam,
MessageCreateParams,
MessageCreateParamsBase,
RedactedThinkingBlockParam,
ServerToolUseBlockParam,
ThinkingBlockParam,
ThinkingConfigParam,
ToolUnion,
ToolUseBlockParam,
WebSearchResultBlock,
WebSearchToolResultBlockParam,
WebSearchToolResultError
} from '@anthropic-ai/sdk/resources/messages'
import { MessageStream } from '@anthropic-ai/sdk/resources/messages/messages'
import { GenericChunk } from '@renderer/aiCore/middleware/schemas'
import { DEFAULT_MAX_TOKENS } from '@renderer/config/constant'
import Logger from '@renderer/config/logger'
import { findTokenLimit, isClaudeReasoningModel, isReasoningModel, isWebSearchModel } from '@renderer/config/models'
import { getAssistantSettings } from '@renderer/services/AssistantService'
import FileManager from '@renderer/services/FileManager'
import { estimateTextTokens } from '@renderer/services/TokenService'
import {
Assistant,
EFFORT_RATIO,
FileTypes,
MCPCallToolResponse,
MCPTool,
MCPToolResponse,
Model,
Provider,
ToolCallResponse,
WebSearchSource
} from '@renderer/types'
import {
ChunkType,
ErrorChunk,
LLMWebSearchCompleteChunk,
LLMWebSearchInProgressChunk,
MCPToolCreatedChunk,
TextDeltaChunk,
ThinkingDeltaChunk
} from '@renderer/types/chunk'
import type { Message } from '@renderer/types/newMessage'
import {
AnthropicSdkMessageParam,
AnthropicSdkParams,
AnthropicSdkRawChunk,
AnthropicSdkRawOutput
} from '@renderer/types/sdk'
import { addImageFileToContents } from '@renderer/utils/formats'
import {
anthropicToolUseToMcpTool,
isEnabledToolUse,
mcpToolCallResponseToAnthropicMessage,
mcpToolsToAnthropicTools
} from '@renderer/utils/mcp-tools'
import { findFileBlocks, findImageBlocks, getMainTextContent } from '@renderer/utils/messageUtils/find'
import { buildSystemPrompt } from '@renderer/utils/prompt'
import { BaseApiClient } from '../BaseApiClient'
import { AnthropicStreamListener, RawStreamListener, RequestTransformer, ResponseChunkTransformer } from '../types'
export class AnthropicAPIClient extends BaseApiClient<
Anthropic,
AnthropicSdkParams,
AnthropicSdkRawOutput,
AnthropicSdkRawChunk,
AnthropicSdkMessageParam,
ToolUseBlock,
ToolUnion
> {
constructor(provider: Provider) {
super(provider)
}
async getSdkInstance(): Promise<Anthropic> {
if (this.sdkInstance) {
return this.sdkInstance
}
this.sdkInstance = new Anthropic({
apiKey: this.apiKey,
baseURL: this.getBaseURL(),
dangerouslyAllowBrowser: true,
defaultHeaders: {
'anthropic-beta': 'output-128k-2025-02-19'
}
})
return this.sdkInstance
}
override async createCompletions(
payload: AnthropicSdkParams,
options?: Anthropic.RequestOptions
): Promise<AnthropicSdkRawOutput> {
const sdk = await this.getSdkInstance()
if (payload.stream) {
return sdk.messages.stream(payload, options)
}
return await sdk.messages.create(payload, options)
}
// @ts-ignore sdk未提供
// eslint-disable-next-line @typescript-eslint/no-unused-vars
override async generateImage(generateImageParams: GenerateImageParams): Promise<string[]> {
return []
}
override async listModels(): Promise<Anthropic.ModelInfo[]> {
const sdk = await this.getSdkInstance()
const response = await sdk.models.list()
return response.data
}
// @ts-ignore sdk未提供
override async getEmbeddingDimensions(): Promise<number> {
throw new Error("Anthropic SDK doesn't support getEmbeddingDimensions method.")
}
override getTemperature(assistant: Assistant, model: Model): number | undefined {
if (assistant.settings?.reasoning_effort && isClaudeReasoningModel(model)) {
return undefined
}
return assistant.settings?.temperature
}
override getTopP(assistant: Assistant, model: Model): number | undefined {
if (assistant.settings?.reasoning_effort && isClaudeReasoningModel(model)) {
return undefined
}
return assistant.settings?.topP
}
/**
* Get the reasoning effort
* @param assistant - The assistant
* @param model - The model
* @returns The reasoning effort
*/
private getBudgetToken(assistant: Assistant, model: Model): ThinkingConfigParam | undefined {
if (!isReasoningModel(model)) {
return undefined
}
const { maxTokens } = getAssistantSettings(assistant)
const reasoningEffort = assistant?.settings?.reasoning_effort
if (reasoningEffort === undefined) {
return {
type: 'disabled'
}
}
const effortRatio = EFFORT_RATIO[reasoningEffort]
const budgetTokens = Math.max(
1024,
Math.floor(
Math.min(
(findTokenLimit(model.id)?.max! - findTokenLimit(model.id)?.min!) * effortRatio +
findTokenLimit(model.id)?.min!,
(maxTokens || DEFAULT_MAX_TOKENS) * effortRatio
)
)
)
return {
type: 'enabled',
budget_tokens: budgetTokens
}
}
/**
* Get the message parameter
* @param message - The message
* @param model - The model
* @returns The message parameter
*/
public async convertMessageToSdkParam(message: Message): Promise<AnthropicSdkMessageParam> {
const parts: MessageParam['content'] = [
{
type: 'text',
text: getMainTextContent(message)
}
]
// Get and process image blocks
const imageBlocks = findImageBlocks(message)
for (const imageBlock of imageBlocks) {
if (imageBlock.file) {
// Handle uploaded file
const file = imageBlock.file
const base64Data = await window.api.file.base64Image(file.id + file.ext)
parts.push({
type: 'image',
source: {
data: base64Data.base64,
media_type: base64Data.mime.replace('jpg', 'jpeg') as any,
type: 'base64'
}
})
}
}
// Get and process file blocks
const fileBlocks = findFileBlocks(message)
for (const fileBlock of fileBlocks) {
const { file } = fileBlock
if ([FileTypes.TEXT, FileTypes.DOCUMENT].includes(file.type)) {
if (file.ext === '.pdf' && file.size < 32 * 1024 * 1024) {
const base64Data = await FileManager.readBase64File(file)
parts.push({
type: 'document',
source: {
type: 'base64',
media_type: 'application/pdf',
data: base64Data
}
})
} else {
const fileContent = await (await window.api.file.read(file.id + file.ext)).trim()
parts.push({
type: 'text',
text: file.origin_name + '\n' + fileContent
})
}
}
}
return {
role: message.role === 'system' ? 'user' : message.role,
content: parts
}
}
public convertMcpToolsToSdkTools(mcpTools: MCPTool[]): ToolUnion[] {
return mcpToolsToAnthropicTools(mcpTools)
}
public convertMcpToolResponseToSdkMessageParam(
mcpToolResponse: MCPToolResponse,
resp: MCPCallToolResponse,
model: Model
): AnthropicSdkMessageParam | undefined {
if ('toolUseId' in mcpToolResponse && mcpToolResponse.toolUseId) {
return mcpToolCallResponseToAnthropicMessage(mcpToolResponse, resp, model)
} else if ('toolCallId' in mcpToolResponse) {
return {
role: 'user',
content: [
{
type: 'tool_result',
tool_use_id: mcpToolResponse.toolCallId!,
content: resp.content
.map((item) => {
if (item.type === 'text') {
return {
type: 'text',
text: item.text || ''
} satisfies TextBlockParam
}
if (item.type === 'image') {
return {
type: 'image',
source: {
data: item.data || '',
media_type: (item.mimeType || 'image/png') as Base64ImageSource['media_type'],
type: 'base64'
}
} satisfies ImageBlockParam
}
return
})
.filter((n) => typeof n !== 'undefined'),
is_error: resp.isError
} satisfies ToolResultBlockParam
]
}
}
return
}
// Implementing abstract methods from BaseApiClient
convertSdkToolCallToMcp(toolCall: ToolUseBlock, mcpTools: MCPTool[]): MCPTool | undefined {
// Based on anthropicToolUseToMcpTool logic in AnthropicProvider
// This might need adjustment based on how tool calls are specifically handled in the new structure
const mcpTool = anthropicToolUseToMcpTool(mcpTools, toolCall)
return mcpTool
}
convertSdkToolCallToMcpToolResponse(toolCall: ToolUseBlock, mcpTool: MCPTool): ToolCallResponse {
return {
id: toolCall.id,
toolCallId: toolCall.id,
tool: mcpTool,
arguments: toolCall.input as Record<string, unknown>,
status: 'pending'
} as ToolCallResponse
}
override buildSdkMessages(
currentReqMessages: AnthropicSdkMessageParam[],
output: Anthropic.Message,
toolResults: AnthropicSdkMessageParam[]
): AnthropicSdkMessageParam[] {
const assistantMessage: AnthropicSdkMessageParam = {
role: output.role,
content: convertContentBlocksToParams(output.content)
}
const newMessages: AnthropicSdkMessageParam[] = [...currentReqMessages, assistantMessage]
if (toolResults && toolResults.length > 0) {
newMessages.push(...toolResults)
}
return newMessages
}
override estimateMessageTokens(message: AnthropicSdkMessageParam): number {
if (typeof message.content === 'string') {
return estimateTextTokens(message.content)
}
return message.content
.map((content) => {
switch (content.type) {
case 'text':
return estimateTextTokens(content.text)
case 'image':
if (content.source.type === 'base64') {
return estimateTextTokens(content.source.data)
} else {
return estimateTextTokens(content.source.url)
}
case 'tool_use':
return estimateTextTokens(JSON.stringify(content.input))
case 'tool_result':
return estimateTextTokens(JSON.stringify(content.content))
default:
return 0
}
})
.reduce((acc, curr) => acc + curr, 0)
}
public buildAssistantMessage(message: Anthropic.Message): AnthropicSdkMessageParam {
const messageParam: AnthropicSdkMessageParam = {
role: message.role,
content: convertContentBlocksToParams(message.content)
}
return messageParam
}
public extractMessagesFromSdkPayload(sdkPayload: AnthropicSdkParams): AnthropicSdkMessageParam[] {
return sdkPayload.messages || []
}
/**
* Anthropic专用的原始流监听器
* 处理MessageStream对象的特定事件
*/
attachRawStreamListener(
rawOutput: AnthropicSdkRawOutput,
listener: RawStreamListener<AnthropicSdkRawChunk>
): AnthropicSdkRawOutput {
console.log(`[AnthropicApiClient] 附加流监听器到原始输出`)
// 专用的Anthropic事件处理
const anthropicListener = listener as AnthropicStreamListener
// 检查是否为MessageStream
if (rawOutput instanceof MessageStream) {
console.log(`[AnthropicApiClient] 检测到 Anthropic MessageStream附加专用监听器`)
if (listener.onStart) {
listener.onStart()
}
if (listener.onChunk) {
rawOutput.on('streamEvent', (event: AnthropicSdkRawChunk) => {
listener.onChunk!(event)
})
}
if (anthropicListener.onContentBlock) {
rawOutput.on('contentBlock', anthropicListener.onContentBlock)
}
if (anthropicListener.onMessage) {
rawOutput.on('finalMessage', anthropicListener.onMessage)
}
if (listener.onEnd) {
rawOutput.on('end', () => {
listener.onEnd!()
})
}
if (listener.onError) {
rawOutput.on('error', (error: Error) => {
listener.onError!(error)
})
}
return rawOutput
}
if (anthropicListener.onMessage) {
anthropicListener.onMessage(rawOutput)
}
// 对于非MessageStream响应
return rawOutput
}
private async getWebSearchParams(model: Model): Promise<WebSearchTool20250305 | undefined> {
if (!isWebSearchModel(model)) {
return undefined
}
return {
type: 'web_search_20250305',
name: 'web_search',
max_uses: 5
} as WebSearchTool20250305
}
getRequestTransformer(): RequestTransformer<AnthropicSdkParams, AnthropicSdkMessageParam> {
return {
transform: async (
coreRequest,
assistant,
model,
isRecursiveCall,
recursiveSdkMessages
): Promise<{
payload: AnthropicSdkParams
messages: AnthropicSdkMessageParam[]
metadata: Record<string, any>
}> => {
const { messages, mcpTools, maxTokens, streamOutput, enableWebSearch } = coreRequest
// 1. 处理系统消息
let systemPrompt = assistant.prompt
// 2. 设置工具
const { tools } = this.setupToolsConfig({
mcpTools: mcpTools,
model,
enableToolUse: isEnabledToolUse(assistant)
})
if (this.useSystemPromptForTools) {
systemPrompt = await buildSystemPrompt(systemPrompt, mcpTools, assistant)
}
const systemMessage: TextBlockParam | undefined = systemPrompt
? { type: 'text', text: systemPrompt }
: undefined
// 3. 处理用户消息
const sdkMessages: AnthropicSdkMessageParam[] = []
if (typeof messages === 'string') {
sdkMessages.push({ role: 'user', content: messages })
} else {
const processedMessages = addImageFileToContents(messages)
for (const message of processedMessages) {
sdkMessages.push(await this.convertMessageToSdkParam(message))
}
}
if (enableWebSearch) {
const webSearchTool = await this.getWebSearchParams(model)
if (webSearchTool) {
tools.push(webSearchTool)
}
}
const commonParams: MessageCreateParamsBase = {
model: model.id,
messages:
isRecursiveCall && recursiveSdkMessages && recursiveSdkMessages.length > 0
? recursiveSdkMessages
: sdkMessages,
max_tokens: maxTokens || DEFAULT_MAX_TOKENS,
temperature: this.getTemperature(assistant, model),
top_p: this.getTopP(assistant, model),
system: systemMessage ? [systemMessage] : undefined,
thinking: this.getBudgetToken(assistant, model),
tools: tools.length > 0 ? tools : undefined,
...this.getCustomParameters(assistant)
}
const finalParams: MessageCreateParams = streamOutput
? {
...commonParams,
stream: true
}
: {
...commonParams,
stream: false
}
const timeout = this.getTimeout(model)
return { payload: finalParams, messages: sdkMessages, metadata: { timeout } }
}
}
}
getResponseChunkTransformer(): ResponseChunkTransformer<AnthropicSdkRawChunk> {
return () => {
let accumulatedJson = ''
const toolCalls: Record<number, ToolUseBlock> = {}
return {
async transform(rawChunk: AnthropicSdkRawChunk, controller: TransformStreamDefaultController<GenericChunk>) {
switch (rawChunk.type) {
case 'message': {
let i = 0
for (const content of rawChunk.content) {
switch (content.type) {
case 'text': {
controller.enqueue({
type: ChunkType.TEXT_DELTA,
text: content.text
} as TextDeltaChunk)
break
}
case 'tool_use': {
toolCalls[i] = content
i++
break
}
case 'thinking': {
controller.enqueue({
type: ChunkType.THINKING_DELTA,
text: content.thinking
} as ThinkingDeltaChunk)
break
}
case 'web_search_tool_result': {
controller.enqueue({
type: ChunkType.LLM_WEB_SEARCH_COMPLETE,
llm_web_search: {
results: content.content,
source: WebSearchSource.ANTHROPIC
}
} as LLMWebSearchCompleteChunk)
break
}
}
}
if (i > 0) {
controller.enqueue({
type: ChunkType.MCP_TOOL_CREATED,
tool_calls: Object.values(toolCalls)
} as MCPToolCreatedChunk)
}
controller.enqueue({
type: ChunkType.LLM_RESPONSE_COMPLETE,
response: {
usage: {
prompt_tokens: rawChunk.usage.input_tokens || 0,
completion_tokens: rawChunk.usage.output_tokens || 0,
total_tokens: (rawChunk.usage.input_tokens || 0) + (rawChunk.usage.output_tokens || 0)
}
}
})
break
}
case 'content_block_start': {
const contentBlock = rawChunk.content_block
switch (contentBlock.type) {
case 'server_tool_use': {
if (contentBlock.name === 'web_search') {
controller.enqueue({
type: ChunkType.LLM_WEB_SEARCH_IN_PROGRESS
} as LLMWebSearchInProgressChunk)
}
break
}
case 'web_search_tool_result': {
if (
contentBlock.content &&
(contentBlock.content as WebSearchToolResultError).type === 'web_search_tool_result_error'
) {
controller.enqueue({
type: ChunkType.ERROR,
error: {
code: (contentBlock.content as WebSearchToolResultError).error_code,
message: (contentBlock.content as WebSearchToolResultError).error_code
}
} as ErrorChunk)
} else {
controller.enqueue({
type: ChunkType.LLM_WEB_SEARCH_COMPLETE,
llm_web_search: {
results: contentBlock.content as Array<WebSearchResultBlock>,
source: WebSearchSource.ANTHROPIC
}
} as LLMWebSearchCompleteChunk)
}
break
}
case 'tool_use': {
toolCalls[rawChunk.index] = contentBlock
break
}
}
break
}
case 'content_block_delta': {
const messageDelta = rawChunk.delta
switch (messageDelta.type) {
case 'text_delta': {
if (messageDelta.text) {
controller.enqueue({
type: ChunkType.TEXT_DELTA,
text: messageDelta.text
} as TextDeltaChunk)
}
break
}
case 'thinking_delta': {
if (messageDelta.thinking) {
controller.enqueue({
type: ChunkType.THINKING_DELTA,
text: messageDelta.thinking
} as ThinkingDeltaChunk)
}
break
}
case 'input_json_delta': {
if (messageDelta.partial_json) {
accumulatedJson += messageDelta.partial_json
}
break
}
}
break
}
case 'content_block_stop': {
const toolCall = toolCalls[rawChunk.index]
if (toolCall) {
try {
toolCall.input = JSON.parse(accumulatedJson)
Logger.debug(`Tool call id: ${toolCall.id}, accumulated json: ${accumulatedJson}`)
controller.enqueue({
type: ChunkType.MCP_TOOL_CREATED,
tool_calls: [toolCall]
} as MCPToolCreatedChunk)
} catch (error) {
Logger.error(`Error parsing tool call input: ${error}`)
}
}
break
}
case 'message_delta': {
controller.enqueue({
type: ChunkType.LLM_RESPONSE_COMPLETE,
response: {
usage: {
prompt_tokens: rawChunk.usage.input_tokens || 0,
completion_tokens: rawChunk.usage.output_tokens || 0,
total_tokens: (rawChunk.usage.input_tokens || 0) + (rawChunk.usage.output_tokens || 0)
}
}
})
}
}
}
}
}
}
}
/**
* 将 ContentBlock 数组转换为 ContentBlockParam 数组
* 去除服务器生成的额外字段只保留发送给API所需的字段
*/
function convertContentBlocksToParams(contentBlocks: ContentBlock[]): ContentBlockParam[] {
return contentBlocks.map((block): ContentBlockParam => {
switch (block.type) {
case 'text':
// TextBlock -> TextBlockParam去除 citations 等服务器字段
return {
type: 'text',
text: block.text
} satisfies TextBlockParam
case 'tool_use':
// ToolUseBlock -> ToolUseBlockParam
return {
type: 'tool_use',
id: block.id,
name: block.name,
input: block.input
} satisfies ToolUseBlockParam
case 'thinking':
// ThinkingBlock -> ThinkingBlockParam
return {
type: 'thinking',
thinking: block.thinking,
signature: block.signature
} satisfies ThinkingBlockParam
case 'redacted_thinking':
// RedactedThinkingBlock -> RedactedThinkingBlockParam
return {
type: 'redacted_thinking',
data: block.data
} satisfies RedactedThinkingBlockParam
case 'server_tool_use':
// ServerToolUseBlock -> ServerToolUseBlockParam
return {
type: 'server_tool_use',
id: block.id,
name: block.name,
input: block.input
} satisfies ServerToolUseBlockParam
case 'web_search_tool_result':
// WebSearchToolResultBlock -> WebSearchToolResultBlockParam
return {
type: 'web_search_tool_result',
tool_use_id: block.tool_use_id,
content: block.content
} satisfies WebSearchToolResultBlockParam
default:
return block as ContentBlockParam
}
})
}

View File

@@ -1,807 +0,0 @@
import {
Content,
File,
FileState,
FunctionCall,
GenerateContentConfig,
GenerateImagesConfig,
GoogleGenAI,
HarmBlockThreshold,
HarmCategory,
Modality,
Model as GeminiModel,
Pager,
Part,
SafetySetting,
SendMessageParameters,
ThinkingConfig,
Tool
} from '@google/genai'
import { nanoid } from '@reduxjs/toolkit'
import { GenericChunk } from '@renderer/aiCore/middleware/schemas'
import {
findTokenLimit,
GEMINI_FLASH_MODEL_REGEX,
isGeminiReasoningModel,
isGemmaModel,
isVisionModel
} from '@renderer/config/models'
import { CacheService } from '@renderer/services/CacheService'
import { estimateTextTokens } from '@renderer/services/TokenService'
import {
Assistant,
EFFORT_RATIO,
FileType,
FileTypes,
GenerateImageParams,
MCPCallToolResponse,
MCPTool,
MCPToolResponse,
Model,
Provider,
ToolCallResponse,
WebSearchSource
} from '@renderer/types'
import { ChunkType, LLMWebSearchCompleteChunk } from '@renderer/types/chunk'
import { Message } from '@renderer/types/newMessage'
import {
GeminiOptions,
GeminiSdkMessageParam,
GeminiSdkParams,
GeminiSdkRawChunk,
GeminiSdkRawOutput,
GeminiSdkToolCall
} from '@renderer/types/sdk'
import {
geminiFunctionCallToMcpTool,
isEnabledToolUse,
mcpToolCallResponseToGeminiMessage,
mcpToolsToGeminiTools
} from '@renderer/utils/mcp-tools'
import { findFileBlocks, findImageBlocks, getMainTextContent } from '@renderer/utils/messageUtils/find'
import { buildSystemPrompt } from '@renderer/utils/prompt'
import { MB } from '@shared/config/constant'
import { BaseApiClient } from '../BaseApiClient'
import { RequestTransformer, ResponseChunkTransformer } from '../types'
export class GeminiAPIClient extends BaseApiClient<
GoogleGenAI,
GeminiSdkParams,
GeminiSdkRawOutput,
GeminiSdkRawChunk,
GeminiSdkMessageParam,
GeminiSdkToolCall,
Tool
> {
constructor(provider: Provider) {
super(provider)
}
override async createCompletions(payload: GeminiSdkParams, options?: GeminiOptions): Promise<GeminiSdkRawOutput> {
const sdk = await this.getSdkInstance()
const { model, history, ...rest } = payload
const realPayload: Omit<GeminiSdkParams, 'model'> = {
...rest,
config: {
...rest.config,
abortSignal: options?.signal,
httpOptions: {
...rest.config?.httpOptions,
timeout: options?.timeout
}
}
} satisfies SendMessageParameters
const streamOutput = options?.streamOutput
const chat = sdk.chats.create({
model: model,
history: history
})
if (streamOutput) {
const stream = chat.sendMessageStream(realPayload)
return stream
} else {
const response = await chat.sendMessage(realPayload)
return response
}
}
override async generateImage(generateImageParams: GenerateImageParams): Promise<string[]> {
const sdk = await this.getSdkInstance()
try {
const { model, prompt, imageSize, batchSize, signal } = generateImageParams
const config: GenerateImagesConfig = {
numberOfImages: batchSize,
aspectRatio: imageSize,
abortSignal: signal,
httpOptions: {
timeout: 5 * 60 * 1000
}
}
const response = await sdk.models.generateImages({
model: model,
prompt,
config
})
if (!response.generatedImages || response.generatedImages.length === 0) {
return []
}
const images = response.generatedImages
.filter((image) => image.image?.imageBytes)
.map((image) => {
const dataPrefix = `data:${image.image?.mimeType || 'image/png'};base64,`
return dataPrefix + image.image?.imageBytes
})
// console.log(response?.generatedImages?.[0]?.image?.imageBytes);
return images
} catch (error) {
console.error('[generateImage] error:', error)
throw error
}
}
override async getEmbeddingDimensions(model: Model): Promise<number> {
const sdk = await this.getSdkInstance()
const data = await sdk.models.embedContent({
model: model.id,
contents: [{ role: 'user', parts: [{ text: 'hi' }] }]
})
return data.embeddings?.[0]?.values?.length || 0
}
override async listModels(): Promise<GeminiModel[]> {
const sdk = await this.getSdkInstance()
const response = await sdk.models.list()
const models: GeminiModel[] = []
for await (const model of response) {
models.push(model)
}
return models
}
override async getSdkInstance() {
if (this.sdkInstance) {
return this.sdkInstance
}
this.sdkInstance = new GoogleGenAI({
vertexai: false,
apiKey: this.apiKey,
apiVersion: this.getApiVersion(),
httpOptions: {
baseUrl: this.getBaseURL(),
apiVersion: this.getApiVersion()
}
})
return this.sdkInstance
}
protected getApiVersion(): string {
if (this.provider.isVertex) {
return 'v1'
}
return 'v1beta'
}
/**
* Handle a PDF file
* @param file - The file
* @returns The part
*/
private async handlePdfFile(file: FileType): Promise<Part> {
const smallFileSize = 20 * MB
const isSmallFile = file.size < smallFileSize
if (isSmallFile) {
const { data, mimeType } = await this.base64File(file)
return {
inlineData: {
data,
mimeType
} as Part['inlineData']
}
}
// Retrieve file from Gemini uploaded files
const fileMetadata: File | undefined = await this.retrieveFile(file)
if (fileMetadata) {
return {
fileData: {
fileUri: fileMetadata.uri,
mimeType: fileMetadata.mimeType
} as Part['fileData']
}
}
// If file is not found, upload it to Gemini
const result = await this.uploadFile(file)
return {
fileData: {
fileUri: result.uri,
mimeType: result.mimeType
} as Part['fileData']
}
}
/**
* Get the message contents
* @param message - The message
* @returns The message contents
*/
private async convertMessageToSdkParam(message: Message): Promise<Content> {
const role = message.role === 'user' ? 'user' : 'model'
const parts: Part[] = [{ text: await this.getMessageContent(message) }]
// Add any generated images from previous responses
const imageBlocks = findImageBlocks(message)
for (const imageBlock of imageBlocks) {
if (
imageBlock.metadata?.generateImageResponse?.images &&
imageBlock.metadata.generateImageResponse.images.length > 0
) {
for (const imageUrl of imageBlock.metadata.generateImageResponse.images) {
if (imageUrl && imageUrl.startsWith('data:')) {
// Extract base64 data and mime type from the data URL
const matches = imageUrl.match(/^data:(.+);base64,(.*)$/)
if (matches && matches.length === 3) {
const mimeType = matches[1]
const base64Data = matches[2]
parts.push({
inlineData: {
data: base64Data,
mimeType: mimeType
} as Part['inlineData']
})
}
}
}
}
const file = imageBlock.file
if (file) {
const base64Data = await window.api.file.base64Image(file.id + file.ext)
parts.push({
inlineData: {
data: base64Data.base64,
mimeType: base64Data.mime
} as Part['inlineData']
})
}
}
const fileBlocks = findFileBlocks(message)
for (const fileBlock of fileBlocks) {
const file = fileBlock.file
if (file.type === FileTypes.IMAGE) {
const base64Data = await window.api.file.base64Image(file.id + file.ext)
parts.push({
inlineData: {
data: base64Data.base64,
mimeType: base64Data.mime
} as Part['inlineData']
})
}
if (file.ext === '.pdf') {
parts.push(await this.handlePdfFile(file))
continue
}
if ([FileTypes.TEXT, FileTypes.DOCUMENT].includes(file.type)) {
const fileContent = await (await window.api.file.read(file.id + file.ext)).trim()
parts.push({
text: file.origin_name + '\n' + fileContent
})
}
}
return {
role,
parts: parts
}
}
// @ts-ignore unused
private async getImageFileContents(message: Message): Promise<Content> {
const role = message.role === 'user' ? 'user' : 'model'
const content = getMainTextContent(message)
const parts: Part[] = [{ text: content }]
const imageBlocks = findImageBlocks(message)
for (const imageBlock of imageBlocks) {
if (
imageBlock.metadata?.generateImageResponse?.images &&
imageBlock.metadata.generateImageResponse.images.length > 0
) {
for (const imageUrl of imageBlock.metadata.generateImageResponse.images) {
if (imageUrl && imageUrl.startsWith('data:')) {
// Extract base64 data and mime type from the data URL
const matches = imageUrl.match(/^data:(.+);base64,(.*)$/)
if (matches && matches.length === 3) {
const mimeType = matches[1]
const base64Data = matches[2]
parts.push({
inlineData: {
data: base64Data,
mimeType: mimeType
} as Part['inlineData']
})
}
}
}
}
const file = imageBlock.file
if (file) {
const base64Data = await window.api.file.base64Image(file.id + file.ext)
parts.push({
inlineData: {
data: base64Data.base64,
mimeType: base64Data.mime
} as Part['inlineData']
})
}
}
return {
role,
parts: parts
}
}
/**
* Get the safety settings
* @returns The safety settings
*/
private getSafetySettings(): SafetySetting[] {
const safetyThreshold = 'OFF' as HarmBlockThreshold
return [
{
category: HarmCategory.HARM_CATEGORY_HATE_SPEECH,
threshold: safetyThreshold
},
{
category: HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
threshold: safetyThreshold
},
{
category: HarmCategory.HARM_CATEGORY_HARASSMENT,
threshold: safetyThreshold
},
{
category: HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
threshold: safetyThreshold
},
{
category: HarmCategory.HARM_CATEGORY_CIVIC_INTEGRITY,
threshold: HarmBlockThreshold.BLOCK_NONE
}
]
}
/**
* Get the reasoning effort for the assistant
* @param assistant - The assistant
* @param model - The model
* @returns The reasoning effort
*/
private getBudgetToken(assistant: Assistant, model: Model) {
if (isGeminiReasoningModel(model)) {
const reasoningEffort = assistant?.settings?.reasoning_effort
// 如果thinking_budget是undefined不思考
if (reasoningEffort === undefined) {
return {
thinkingConfig: {
includeThoughts: false,
...(GEMINI_FLASH_MODEL_REGEX.test(model.id) ? { thinkingBudget: 0 } : {})
} as ThinkingConfig
}
}
const effortRatio = EFFORT_RATIO[reasoningEffort]
if (effortRatio > 1) {
return {
thinkingConfig: {
includeThoughts: true
}
}
}
const { min, max } = findTokenLimit(model.id) || { min: 0, max: 0 }
// 计算 budgetTokens确保不低于 min
const budget = Math.floor((max - min) * effortRatio + min)
return {
thinkingConfig: {
...(budget > 0 ? { thinkingBudget: budget } : {}),
includeThoughts: true
} as ThinkingConfig
}
}
return {}
}
private getGenerateImageParameter(): Partial<GenerateContentConfig> {
return {
systemInstruction: undefined,
responseModalities: [Modality.TEXT, Modality.IMAGE],
responseMimeType: 'text/plain'
}
}
getRequestTransformer(): RequestTransformer<GeminiSdkParams, GeminiSdkMessageParam> {
return {
transform: async (
coreRequest,
assistant,
model,
isRecursiveCall,
recursiveSdkMessages
): Promise<{
payload: GeminiSdkParams
messages: GeminiSdkMessageParam[]
metadata: Record<string, any>
}> => {
const { messages, mcpTools, maxTokens, enableWebSearch, enableGenerateImage } = coreRequest
// 1. 处理系统消息
let systemInstruction = assistant.prompt
// 2. 设置工具
const { tools } = this.setupToolsConfig({
mcpTools,
model,
enableToolUse: isEnabledToolUse(assistant)
})
if (this.useSystemPromptForTools) {
systemInstruction = await buildSystemPrompt(assistant.prompt || '', mcpTools, assistant)
}
let messageContents: Content = { role: 'user', parts: [] } // Initialize messageContents
const history: Content[] = []
// 3. 处理用户消息
if (typeof messages === 'string') {
messageContents = {
role: 'user',
parts: [{ text: messages }]
}
} else {
const userLastMessage = messages.pop()
if (userLastMessage) {
messageContents = await this.convertMessageToSdkParam(userLastMessage)
for (const message of messages) {
history.push(await this.convertMessageToSdkParam(message))
}
messages.push(userLastMessage)
}
}
if (enableWebSearch) {
tools.push({
googleSearch: {}
})
}
if (isGemmaModel(model) && assistant.prompt) {
const isFirstMessage = history.length === 0
if (isFirstMessage && messageContents) {
const userMessageText =
messageContents.parts && messageContents.parts.length > 0
? (messageContents.parts[0] as Part).text || ''
: ''
const systemMessage = [
{
text:
'<start_of_turn>user\n' +
systemInstruction +
'<end_of_turn>\n' +
'<start_of_turn>user\n' +
userMessageText +
'<end_of_turn>'
}
] as Part[]
if (messageContents && messageContents.parts) {
messageContents.parts[0] = systemMessage[0]
}
}
}
const newHistory =
isRecursiveCall && recursiveSdkMessages && recursiveSdkMessages.length > 0
? recursiveSdkMessages.slice(0, recursiveSdkMessages.length - 1)
: history
const newMessageContents =
isRecursiveCall && recursiveSdkMessages && recursiveSdkMessages.length > 0
? recursiveSdkMessages[recursiveSdkMessages.length - 1]
: messageContents
const generateContentConfig: GenerateContentConfig = {
safetySettings: this.getSafetySettings(),
systemInstruction: isGemmaModel(model) ? undefined : systemInstruction,
temperature: this.getTemperature(assistant, model),
topP: this.getTopP(assistant, model),
maxOutputTokens: maxTokens,
tools: tools,
...(enableGenerateImage ? this.getGenerateImageParameter() : {}),
...this.getBudgetToken(assistant, model),
...this.getCustomParameters(assistant)
}
const param: GeminiSdkParams = {
model: model.id,
config: generateContentConfig,
history: newHistory,
message: newMessageContents.parts!
}
return {
payload: param,
messages: [messageContents],
metadata: {}
}
}
}
}
getResponseChunkTransformer(): ResponseChunkTransformer<GeminiSdkRawChunk> {
return () => ({
async transform(chunk: GeminiSdkRawChunk, controller: TransformStreamDefaultController<GenericChunk>) {
const toolCalls: FunctionCall[] = []
if (chunk.candidates && chunk.candidates.length > 0) {
for (const candidate of chunk.candidates) {
if (candidate.content) {
candidate.content.parts?.forEach((part) => {
const text = part.text || ''
if (part.thought) {
controller.enqueue({
type: ChunkType.THINKING_DELTA,
text: text
})
} else if (part.text) {
controller.enqueue({
type: ChunkType.TEXT_DELTA,
text: text
})
} else if (part.inlineData) {
controller.enqueue({
type: ChunkType.IMAGE_COMPLETE,
image: {
type: 'base64',
images: [
part.inlineData?.data?.startsWith('data:')
? part.inlineData?.data
: `data:${part.inlineData?.mimeType || 'image/png'};base64,${part.inlineData?.data}`
]
}
})
} else if (part.functionCall) {
toolCalls.push(part.functionCall)
}
})
}
if (candidate.finishReason) {
if (candidate.groundingMetadata) {
controller.enqueue({
type: ChunkType.LLM_WEB_SEARCH_COMPLETE,
llm_web_search: {
results: candidate.groundingMetadata,
source: WebSearchSource.GEMINI
}
} as LLMWebSearchCompleteChunk)
}
controller.enqueue({
type: ChunkType.LLM_RESPONSE_COMPLETE,
response: {
usage: {
prompt_tokens: chunk.usageMetadata?.promptTokenCount || 0,
completion_tokens:
(chunk.usageMetadata?.totalTokenCount || 0) - (chunk.usageMetadata?.promptTokenCount || 0),
total_tokens: chunk.usageMetadata?.totalTokenCount || 0
}
}
})
}
}
}
if (toolCalls.length > 0) {
controller.enqueue({
type: ChunkType.MCP_TOOL_CREATED,
tool_calls: toolCalls
})
}
}
})
}
public convertMcpToolsToSdkTools(mcpTools: MCPTool[]): Tool[] {
return mcpToolsToGeminiTools(mcpTools)
}
public convertSdkToolCallToMcp(toolCall: GeminiSdkToolCall, mcpTools: MCPTool[]): MCPTool | undefined {
return geminiFunctionCallToMcpTool(mcpTools, toolCall)
}
public convertSdkToolCallToMcpToolResponse(toolCall: GeminiSdkToolCall, mcpTool: MCPTool): ToolCallResponse {
const parsedArgs = (() => {
try {
return typeof toolCall.args === 'string' ? JSON.parse(toolCall.args) : toolCall.args
} catch {
return toolCall.args
}
})()
return {
id: toolCall.id || nanoid(),
toolCallId: toolCall.id,
tool: mcpTool,
arguments: parsedArgs,
status: 'pending'
} as ToolCallResponse
}
public convertMcpToolResponseToSdkMessageParam(
mcpToolResponse: MCPToolResponse,
resp: MCPCallToolResponse,
model: Model
): GeminiSdkMessageParam | undefined {
if ('toolUseId' in mcpToolResponse && mcpToolResponse.toolUseId) {
return mcpToolCallResponseToGeminiMessage(mcpToolResponse, resp, isVisionModel(model))
} else if ('toolCallId' in mcpToolResponse) {
return {
role: 'user',
parts: [
{
functionResponse: {
id: mcpToolResponse.toolCallId,
name: mcpToolResponse.tool.id,
response: {
output: !resp.isError ? resp.content : undefined,
error: resp.isError ? resp.content : undefined
}
}
}
]
} satisfies Content
}
return
}
public buildSdkMessages(
currentReqMessages: Content[],
output: string,
toolResults: Content[],
toolCalls: FunctionCall[]
): Content[] {
const parts: Part[] = []
if (output) {
parts.push({
text: output
})
}
toolCalls.forEach((toolCall) => {
parts.push({
functionCall: toolCall
})
})
parts.push(
...toolResults
.map((ts) => ts.parts)
.flat()
.filter((p) => p !== undefined)
)
const lastMessage = currentReqMessages[currentReqMessages.length - 1]
if (lastMessage) {
lastMessage.parts?.push(...parts)
}
return currentReqMessages
}
override estimateMessageTokens(message: GeminiSdkMessageParam): number {
return (
message.parts?.reduce((acc, part) => {
if (part.text) {
return acc + estimateTextTokens(part.text)
}
if (part.functionCall) {
return acc + estimateTextTokens(JSON.stringify(part.functionCall))
}
if (part.functionResponse) {
return acc + estimateTextTokens(JSON.stringify(part.functionResponse.response))
}
if (part.inlineData) {
return acc + estimateTextTokens(part.inlineData.data || '')
}
if (part.fileData) {
return acc + estimateTextTokens(part.fileData.fileUri || '')
}
return acc
}, 0) || 0
)
}
public extractMessagesFromSdkPayload(sdkPayload: GeminiSdkParams): GeminiSdkMessageParam[] {
const messageParam: GeminiSdkMessageParam = {
role: 'user',
parts: []
}
if (Array.isArray(sdkPayload.message)) {
sdkPayload.message.forEach((part) => {
if (typeof part === 'string') {
messageParam.parts?.push({ text: part })
} else if (typeof part === 'object') {
messageParam.parts?.push(part)
}
})
}
return [messageParam, ...(sdkPayload.history || [])]
}
private async uploadFile(file: FileType): Promise<File> {
return await this.sdkInstance!.files.upload({
file: file.path,
config: {
mimeType: 'application/pdf',
name: file.id,
displayName: file.origin_name
}
})
}
private async base64File(file: FileType) {
const { data } = await window.api.file.base64File(file.id + file.ext)
return {
data,
mimeType: 'application/pdf'
}
}
private async retrieveFile(file: FileType): Promise<File | undefined> {
const cachedResponse = CacheService.get<any>('gemini_file_list')
if (cachedResponse) {
return this.processResponse(cachedResponse, file)
}
const response = await this.sdkInstance!.files.list()
CacheService.set('gemini_file_list', response, 3000)
return this.processResponse(response, file)
}
private async processResponse(response: Pager<File>, file: FileType) {
for await (const f of response) {
if (f.state === FileState.ACTIVE) {
if (f.displayName === file.origin_name && Number(f.sizeBytes) === file.size) {
return f
}
}
}
return undefined
}
// @ts-ignore unused
private async listFiles(): Promise<File[]> {
const files: File[] = []
for await (const f of await this.sdkInstance!.files.list()) {
files.push(f)
}
return files
}
// @ts-ignore unused
private async deleteFile(fileId: string) {
await this.sdkInstance!.files.delete({ name: fileId })
}
}

View File

@@ -1,95 +0,0 @@
import { GoogleGenAI } from '@google/genai'
import { getVertexAILocation, getVertexAIProjectId, getVertexAIServiceAccount } from '@renderer/hooks/useVertexAI'
import { Provider } from '@renderer/types'
import { GeminiAPIClient } from './GeminiAPIClient'
export class VertexAPIClient extends GeminiAPIClient {
private authHeaders?: Record<string, string>
private authHeadersExpiry?: number
constructor(provider: Provider) {
super(provider)
}
override async getSdkInstance() {
if (this.sdkInstance) {
return this.sdkInstance
}
const serviceAccount = getVertexAIServiceAccount()
const projectId = getVertexAIProjectId()
const location = getVertexAILocation()
if (!serviceAccount.privateKey || !serviceAccount.clientEmail || !projectId || !location) {
throw new Error('Vertex AI settings are not configured')
}
const authHeaders = await this.getServiceAccountAuthHeaders()
this.sdkInstance = new GoogleGenAI({
vertexai: true,
project: projectId,
location: location,
httpOptions: {
apiVersion: this.getApiVersion(),
headers: authHeaders
}
})
return this.sdkInstance
}
/**
* 获取认证头,如果配置了 service account 则从主进程获取
*/
private async getServiceAccountAuthHeaders(): Promise<Record<string, string> | undefined> {
const serviceAccount = getVertexAIServiceAccount()
const projectId = getVertexAIProjectId()
// 检查是否配置了 service account
if (!serviceAccount.privateKey || !serviceAccount.clientEmail || !projectId) {
return undefined
}
// 检查是否已有有效的认证头(提前 5 分钟过期)
const now = Date.now()
if (this.authHeaders && this.authHeadersExpiry && this.authHeadersExpiry - now > 5 * 60 * 1000) {
return this.authHeaders
}
try {
// 从主进程获取认证头
this.authHeaders = await window.api.vertexAI.getAuthHeaders({
projectId,
serviceAccount: {
privateKey: serviceAccount.privateKey,
clientEmail: serviceAccount.clientEmail
}
})
// 设置过期时间(通常认证头有效期为 1 小时)
this.authHeadersExpiry = now + 60 * 60 * 1000
return this.authHeaders
} catch (error: any) {
console.error('Failed to get auth headers:', error)
throw new Error(`Service Account authentication failed: ${error.message}`)
}
}
/**
* 清理认证缓存并重新初始化
*/
clearAuthCache(): void {
this.authHeaders = undefined
this.authHeadersExpiry = undefined
const serviceAccount = getVertexAIServiceAccount()
const projectId = getVertexAIProjectId()
if (projectId && serviceAccount.clientEmail) {
window.api.vertexAI.clearAuthCache(projectId, serviceAccount.clientEmail)
}
}
}

View File

@@ -1,6 +0,0 @@
export * from './ApiClientFactory'
export * from './BaseApiClient'
export * from './types'
// Export specific clients from subdirectories
export * from './openai/OpenAIApiClient'

View File

@@ -1,722 +0,0 @@
import { DEFAULT_MAX_TOKENS } from '@renderer/config/constant'
import Logger from '@renderer/config/logger'
import {
findTokenLimit,
GEMINI_FLASH_MODEL_REGEX,
getOpenAIWebSearchParams,
isDoubaoThinkingAutoModel,
isReasoningModel,
isSupportedReasoningEffortGrokModel,
isSupportedReasoningEffortModel,
isSupportedReasoningEffortOpenAIModel,
isSupportedThinkingTokenClaudeModel,
isSupportedThinkingTokenDoubaoModel,
isSupportedThinkingTokenGeminiModel,
isSupportedThinkingTokenModel,
isSupportedThinkingTokenQwenModel,
isVisionModel
} from '@renderer/config/models'
import { processPostsuffixQwen3Model, processReqMessages } from '@renderer/services/ModelMessageService'
import { estimateTextTokens } from '@renderer/services/TokenService'
// For Copilot token
import {
Assistant,
EFFORT_RATIO,
FileTypes,
MCPCallToolResponse,
MCPTool,
MCPToolResponse,
Model,
Provider,
ToolCallResponse,
WebSearchSource
} from '@renderer/types'
import { ChunkType } from '@renderer/types/chunk'
import { Message } from '@renderer/types/newMessage'
import {
OpenAISdkMessageParam,
OpenAISdkParams,
OpenAISdkRawChunk,
OpenAISdkRawContentSource,
OpenAISdkRawOutput,
ReasoningEffortOptionalParams
} from '@renderer/types/sdk'
import { addImageFileToContents } from '@renderer/utils/formats'
import {
isEnabledToolUse,
mcpToolCallResponseToOpenAICompatibleMessage,
mcpToolsToOpenAIChatTools,
openAIToolsToMcpTool
} from '@renderer/utils/mcp-tools'
import { findFileBlocks, findImageBlocks } from '@renderer/utils/messageUtils/find'
import { buildSystemPrompt } from '@renderer/utils/prompt'
import OpenAI, { AzureOpenAI } from 'openai'
import { ChatCompletionContentPart, ChatCompletionContentPartRefusal, ChatCompletionTool } from 'openai/resources'
import { GenericChunk } from '../../middleware/schemas'
import { RequestTransformer, ResponseChunkTransformer, ResponseChunkTransformerContext } from '../types'
import { OpenAIBaseClient } from './OpenAIBaseClient'
export class OpenAIAPIClient extends OpenAIBaseClient<
OpenAI | AzureOpenAI,
OpenAISdkParams,
OpenAISdkRawOutput,
OpenAISdkRawChunk,
OpenAISdkMessageParam,
OpenAI.Chat.Completions.ChatCompletionMessageToolCall,
ChatCompletionTool
> {
constructor(provider: Provider) {
super(provider)
}
override async createCompletions(
payload: OpenAISdkParams,
options?: OpenAI.RequestOptions
): Promise<OpenAISdkRawOutput> {
const sdk = await this.getSdkInstance()
// @ts-ignore - SDK参数可能有额外的字段
return await sdk.chat.completions.create(payload, options)
}
/**
* Get the reasoning effort for the assistant
* @param assistant - The assistant
* @param model - The model
* @returns The reasoning effort
*/
// Method for reasoning effort, moved from OpenAIProvider
override getReasoningEffort(assistant: Assistant, model: Model): ReasoningEffortOptionalParams {
if (this.provider.id === 'groq') {
return {}
}
if (!isReasoningModel(model)) {
return {}
}
const reasoningEffort = assistant?.settings?.reasoning_effort
// Doubao 思考模式支持
if (isSupportedThinkingTokenDoubaoModel(model)) {
// reasoningEffort 为空,默认开启 enabled
if (!reasoningEffort) {
return { thinking: { type: 'disabled' } }
}
if (reasoningEffort === 'high') {
return { thinking: { type: 'enabled' } }
}
if (reasoningEffort === 'auto' && isDoubaoThinkingAutoModel(model)) {
return { thinking: { type: 'auto' } }
}
// 其他情况不带 thinking 字段
return {}
}
if (!reasoningEffort) {
if (isSupportedThinkingTokenQwenModel(model)) {
return { enable_thinking: false }
}
if (isSupportedThinkingTokenClaudeModel(model)) {
return {}
}
if (isSupportedThinkingTokenGeminiModel(model)) {
// openrouter没有提供一个不推理的选项先隐藏
if (this.provider.id === 'openrouter') {
return { reasoning: { max_tokens: 0, exclude: true } }
}
if (GEMINI_FLASH_MODEL_REGEX.test(model.id)) {
return { reasoning_effort: 'none' }
}
return {}
}
if (isSupportedThinkingTokenDoubaoModel(model)) {
return { thinking: { type: 'disabled' } }
}
return {}
}
const effortRatio = EFFORT_RATIO[reasoningEffort]
const budgetTokens = Math.floor(
(findTokenLimit(model.id)?.max! - findTokenLimit(model.id)?.min!) * effortRatio + findTokenLimit(model.id)?.min!
)
// OpenRouter models
if (model.provider === 'openrouter') {
if (isSupportedReasoningEffortModel(model) || isSupportedThinkingTokenModel(model)) {
return {
reasoning: {
effort: reasoningEffort === 'auto' ? 'medium' : reasoningEffort
}
}
}
}
// Qwen models
if (isSupportedThinkingTokenQwenModel(model)) {
return {
enable_thinking: true,
thinking_budget: budgetTokens
}
}
// Grok models
if (isSupportedReasoningEffortGrokModel(model)) {
return {
reasoning_effort: reasoningEffort
}
}
// OpenAI models
if (isSupportedReasoningEffortOpenAIModel(model) || isSupportedThinkingTokenGeminiModel(model)) {
return {
reasoning_effort: reasoningEffort
}
}
// Claude models
if (isSupportedThinkingTokenClaudeModel(model)) {
const maxTokens = assistant.settings?.maxTokens
return {
thinking: {
type: 'enabled',
budget_tokens: Math.floor(
Math.max(1024, Math.min(budgetTokens, (maxTokens || DEFAULT_MAX_TOKENS) * effortRatio))
)
}
}
}
// Doubao models
if (isSupportedThinkingTokenDoubaoModel(model)) {
if (assistant.settings?.reasoning_effort === 'high') {
return {
thinking: {
type: 'enabled'
}
}
}
}
// Default case: no special thinking settings
return {}
}
/**
* Check if the provider does not support files
* @returns True if the provider does not support files, false otherwise
*/
private get isNotSupportFiles() {
if (this.provider?.isNotSupportArrayContent) {
return true
}
const providers = ['deepseek', 'baichuan', 'minimax', 'xirang']
return providers.includes(this.provider.id)
}
/**
* Get the message parameter
* @param message - The message
* @param model - The model
* @returns The message parameter
*/
public async convertMessageToSdkParam(message: Message, model: Model): Promise<OpenAISdkMessageParam> {
const isVision = isVisionModel(model)
const content = await this.getMessageContent(message)
const fileBlocks = findFileBlocks(message)
const imageBlocks = findImageBlocks(message)
if (fileBlocks.length === 0 && imageBlocks.length === 0) {
return {
role: message.role === 'system' ? 'user' : message.role,
content
} as OpenAISdkMessageParam
}
// If the model does not support files, extract the file content
if (this.isNotSupportFiles) {
const fileContent = await this.extractFileContent(message)
return {
role: message.role === 'system' ? 'user' : message.role,
content: content + '\n\n---\n\n' + fileContent
} as OpenAISdkMessageParam
}
// If the model supports files, add the file content to the message
const parts: ChatCompletionContentPart[] = []
if (content) {
parts.push({ type: 'text', text: content })
}
for (const imageBlock of imageBlocks) {
if (isVision) {
if (imageBlock.file) {
const image = await window.api.file.base64Image(imageBlock.file.id + imageBlock.file.ext)
parts.push({ type: 'image_url', image_url: { url: image.data } })
} else if (imageBlock.url && imageBlock.url.startsWith('data:')) {
parts.push({ type: 'image_url', image_url: { url: imageBlock.url } })
}
}
}
for (const fileBlock of fileBlocks) {
const file = fileBlock.file
if (!file) {
continue
}
if ([FileTypes.TEXT, FileTypes.DOCUMENT].includes(file.type)) {
const fileContent = await (await window.api.file.read(file.id + file.ext)).trim()
parts.push({
type: 'text',
text: file.origin_name + '\n' + fileContent
})
}
}
return {
role: message.role === 'system' ? 'user' : message.role,
content: parts
} as OpenAISdkMessageParam
}
public convertMcpToolsToSdkTools(mcpTools: MCPTool[]): ChatCompletionTool[] {
return mcpToolsToOpenAIChatTools(mcpTools)
}
public convertSdkToolCallToMcp(
toolCall: OpenAI.Chat.Completions.ChatCompletionMessageToolCall,
mcpTools: MCPTool[]
): MCPTool | undefined {
return openAIToolsToMcpTool(mcpTools, toolCall)
}
public convertSdkToolCallToMcpToolResponse(
toolCall: OpenAI.Chat.Completions.ChatCompletionMessageToolCall,
mcpTool: MCPTool
): ToolCallResponse {
let parsedArgs: any
try {
parsedArgs = JSON.parse(toolCall.function.arguments)
} catch {
parsedArgs = toolCall.function.arguments
}
return {
id: toolCall.id,
toolCallId: toolCall.id,
tool: mcpTool,
arguments: parsedArgs,
status: 'pending'
} as ToolCallResponse
}
public convertMcpToolResponseToSdkMessageParam(
mcpToolResponse: MCPToolResponse,
resp: MCPCallToolResponse,
model: Model
): OpenAISdkMessageParam | undefined {
if ('toolUseId' in mcpToolResponse && mcpToolResponse.toolUseId) {
// This case is for Anthropic/Claude like tool usage, OpenAI uses tool_call_id
// For OpenAI, we primarily expect toolCallId. This might need adjustment if mixing provider concepts.
return mcpToolCallResponseToOpenAICompatibleMessage(mcpToolResponse, resp, isVisionModel(model))
} else if ('toolCallId' in mcpToolResponse && mcpToolResponse.toolCallId) {
return {
role: 'tool',
tool_call_id: mcpToolResponse.toolCallId,
content: JSON.stringify(resp.content)
} as OpenAI.Chat.Completions.ChatCompletionToolMessageParam
}
return undefined
}
public buildSdkMessages(
currentReqMessages: OpenAISdkMessageParam[],
output: string | undefined,
toolResults: OpenAISdkMessageParam[],
toolCalls: OpenAI.Chat.Completions.ChatCompletionMessageToolCall[]
): OpenAISdkMessageParam[] {
if (!output && toolCalls.length === 0) {
return [...currentReqMessages, ...toolResults]
}
const assistantMessage: OpenAISdkMessageParam = {
role: 'assistant',
content: output,
tool_calls: toolCalls.length > 0 ? toolCalls : undefined
}
const newReqMessages = [...currentReqMessages, assistantMessage, ...toolResults]
return newReqMessages
}
override estimateMessageTokens(message: OpenAISdkMessageParam): number {
let sum = 0
if (typeof message.content === 'string') {
sum += estimateTextTokens(message.content)
} else if (Array.isArray(message.content)) {
sum += (message.content || [])
.map((part: ChatCompletionContentPart | ChatCompletionContentPartRefusal) => {
switch (part.type) {
case 'text':
return estimateTextTokens(part.text)
case 'image_url':
return estimateTextTokens(part.image_url.url)
case 'input_audio':
return estimateTextTokens(part.input_audio.data)
case 'file':
return estimateTextTokens(part.file.file_data || '')
default:
return 0
}
})
.reduce((acc, curr) => acc + curr, 0)
}
if ('tool_calls' in message && message.tool_calls) {
sum += message.tool_calls.reduce((acc, toolCall) => {
return acc + estimateTextTokens(JSON.stringify(toolCall.function.arguments))
}, 0)
}
return sum
}
public extractMessagesFromSdkPayload(sdkPayload: OpenAISdkParams): OpenAISdkMessageParam[] {
return sdkPayload.messages || []
}
getRequestTransformer(): RequestTransformer<OpenAISdkParams, OpenAISdkMessageParam> {
return {
transform: async (
coreRequest,
assistant,
model,
isRecursiveCall,
recursiveSdkMessages
): Promise<{
payload: OpenAISdkParams
messages: OpenAISdkMessageParam[]
metadata: Record<string, any>
}> => {
const { messages, mcpTools, maxTokens, streamOutput, enableWebSearch } = coreRequest
// 1. 处理系统消息
let systemMessage = { role: 'system', content: assistant.prompt || '' }
if (isSupportedReasoningEffortOpenAIModel(model)) {
systemMessage = {
role: 'developer',
content: `Formatting re-enabled${systemMessage ? '\n' + systemMessage.content : ''}`
}
}
if (model.id.includes('o1-mini') || model.id.includes('o1-preview')) {
systemMessage.role = 'assistant'
}
// 2. 设置工具必须在this.usesystemPromptForTools前面
const { tools } = this.setupToolsConfig({
mcpTools: mcpTools,
model,
enableToolUse: isEnabledToolUse(assistant)
})
if (this.useSystemPromptForTools) {
systemMessage.content = await buildSystemPrompt(systemMessage.content || '', mcpTools, assistant)
}
// 3. 处理用户消息
const userMessages: OpenAISdkMessageParam[] = []
if (typeof messages === 'string') {
userMessages.push({ role: 'user', content: messages })
} else {
const processedMessages = addImageFileToContents(messages)
for (const message of processedMessages) {
userMessages.push(await this.convertMessageToSdkParam(message, model))
}
}
const lastUserMsg = userMessages.findLast((m) => m.role === 'user')
if (lastUserMsg && isSupportedThinkingTokenQwenModel(model)) {
const postsuffix = '/no_think'
const qwenThinkModeEnabled = assistant.settings?.qwenThinkMode === true
const currentContent = lastUserMsg.content
lastUserMsg.content = processPostsuffixQwen3Model(currentContent, postsuffix, qwenThinkModeEnabled) as any
}
// 4. 最终请求消息
let reqMessages: OpenAISdkMessageParam[]
if (!systemMessage.content) {
reqMessages = [...userMessages]
} else {
reqMessages = [systemMessage, ...userMessages].filter(Boolean) as OpenAISdkMessageParam[]
}
reqMessages = processReqMessages(model, reqMessages)
// 5. 创建通用参数
const commonParams = {
model: model.id,
messages:
isRecursiveCall && recursiveSdkMessages && recursiveSdkMessages.length > 0
? recursiveSdkMessages
: reqMessages,
temperature: this.getTemperature(assistant, model),
top_p: this.getTopP(assistant, model),
max_tokens: maxTokens,
tools: tools.length > 0 ? tools : undefined,
service_tier: this.getServiceTier(model),
...this.getProviderSpecificParameters(assistant, model),
...this.getReasoningEffort(assistant, model),
...getOpenAIWebSearchParams(model, enableWebSearch),
...this.getCustomParameters(assistant)
}
// Create the appropriate parameters object based on whether streaming is enabled
const sdkParams: OpenAISdkParams = streamOutput
? {
...commonParams,
stream: true
}
: {
...commonParams,
stream: false
}
const timeout = this.getTimeout(model)
return { payload: sdkParams, messages: reqMessages, metadata: { timeout } }
}
}
}
// 在RawSdkChunkToGenericChunkMiddleware中使用
getResponseChunkTransformer(): ResponseChunkTransformer<OpenAISdkRawChunk> {
let hasBeenCollectedWebSearch = false
const collectWebSearchData = (
chunk: OpenAISdkRawChunk,
contentSource: OpenAISdkRawContentSource,
context: ResponseChunkTransformerContext
) => {
if (hasBeenCollectedWebSearch) {
return
}
// OpenAI annotations
// @ts-ignore - annotations may not be in standard type definitions
const annotations = contentSource.annotations || chunk.annotations
if (annotations && annotations.length > 0 && annotations[0].type === 'url_citation') {
hasBeenCollectedWebSearch = true
return {
results: annotations,
source: WebSearchSource.OPENAI
}
}
// Grok citations
// @ts-ignore - citations may not be in standard type definitions
if (context.provider?.id === 'grok' && chunk.citations) {
hasBeenCollectedWebSearch = true
return {
// @ts-ignore - citations may not be in standard type definitions
results: chunk.citations,
source: WebSearchSource.GROK
}
}
// Perplexity citations
// @ts-ignore - citations may not be in standard type definitions
if (context.provider?.id === 'perplexity' && chunk.citations && chunk.citations.length > 0) {
hasBeenCollectedWebSearch = true
return {
// @ts-ignore - citations may not be in standard type definitions
results: chunk.citations,
source: WebSearchSource.PERPLEXITY
}
}
// OpenRouter citations
// @ts-ignore - citations may not be in standard type definitions
if (context.provider?.id === 'openrouter' && chunk.citations && chunk.citations.length > 0) {
hasBeenCollectedWebSearch = true
return {
// @ts-ignore - citations may not be in standard type definitions
results: chunk.citations,
source: WebSearchSource.OPENROUTER
}
}
// Zhipu web search
// @ts-ignore - web_search may not be in standard type definitions
if (context.provider?.id === 'zhipu' && chunk.web_search) {
hasBeenCollectedWebSearch = true
return {
// @ts-ignore - web_search may not be in standard type definitions
results: chunk.web_search,
source: WebSearchSource.ZHIPU
}
}
// Hunyuan web search
// @ts-ignore - search_info may not be in standard type definitions
if (context.provider?.id === 'hunyuan' && chunk.search_info?.search_results) {
hasBeenCollectedWebSearch = true
return {
// @ts-ignore - search_info may not be in standard type definitions
results: chunk.search_info.search_results,
source: WebSearchSource.HUNYUAN
}
}
// TODO: 放到AnthropicApiClient中
// // Other providers...
// // @ts-ignore - web_search may not be in standard type definitions
// if (chunk.web_search) {
// const sourceMap: Record<string, string> = {
// openai: 'openai',
// anthropic: 'anthropic',
// qwenlm: 'qwen'
// }
// const source = sourceMap[context.provider?.id] || 'openai_response'
// return {
// results: chunk.web_search,
// source: source as const
// }
// }
return null
}
const toolCalls: OpenAI.Chat.Completions.ChatCompletionMessageToolCall[] = []
let isFinished = false
let lastUsageInfo: any = null
/**
* 统一的完成信号发送逻辑
* - 有 finish_reason 时
* - 无 finish_reason 但是流正常结束时
*/
const emitCompletionSignals = (controller: TransformStreamDefaultController<GenericChunk>) => {
if (isFinished) return
if (toolCalls.length > 0) {
controller.enqueue({
type: ChunkType.MCP_TOOL_CREATED,
tool_calls: toolCalls
})
}
const usage = lastUsageInfo || {
prompt_tokens: 0,
completion_tokens: 0,
total_tokens: 0
}
controller.enqueue({
type: ChunkType.LLM_RESPONSE_COMPLETE,
response: { usage }
})
// 防止重复发送
isFinished = true
}
return (context: ResponseChunkTransformerContext) => ({
async transform(chunk: OpenAISdkRawChunk, controller: TransformStreamDefaultController<GenericChunk>) {
// 持续更新usage信息
if (chunk.usage) {
lastUsageInfo = {
prompt_tokens: chunk.usage.prompt_tokens || 0,
completion_tokens: chunk.usage.completion_tokens || 0,
total_tokens: (chunk.usage.prompt_tokens || 0) + (chunk.usage.completion_tokens || 0)
}
}
// 处理chunk
if ('choices' in chunk && chunk.choices && chunk.choices.length > 0) {
const choice = chunk.choices[0]
if (!choice) return
// 对于流式响应使用delta对于非流式响应使用message
const contentSource: OpenAISdkRawContentSource | null =
'delta' in choice ? choice.delta : 'message' in choice ? choice.message : null
if (!contentSource) return
const webSearchData = collectWebSearchData(chunk, contentSource, context)
if (webSearchData) {
controller.enqueue({
type: ChunkType.LLM_WEB_SEARCH_COMPLETE,
llm_web_search: webSearchData
})
}
// 处理推理内容 (e.g. from OpenRouter DeepSeek-R1)
// @ts-ignore - reasoning_content is not in standard OpenAI types but some providers use it
const reasoningText = contentSource.reasoning_content || contentSource.reasoning
if (reasoningText) {
controller.enqueue({
type: ChunkType.THINKING_DELTA,
text: reasoningText
})
}
// 处理文本内容
if (contentSource.content) {
controller.enqueue({
type: ChunkType.TEXT_DELTA,
text: contentSource.content
})
}
// 处理工具调用
if (contentSource.tool_calls) {
for (const toolCall of contentSource.tool_calls) {
if ('index' in toolCall) {
const { id, index, function: fun } = toolCall
if (fun?.name) {
toolCalls[index] = {
id: id || '',
function: {
name: fun.name,
arguments: fun.arguments || ''
},
type: 'function'
}
} else if (fun?.arguments) {
toolCalls[index].function.arguments += fun.arguments
}
} else {
toolCalls.push(toolCall)
}
}
}
// 处理finish_reason发送流结束信号
if ('finish_reason' in choice && choice.finish_reason) {
Logger.debug(`[OpenAIApiClient] Stream finished with reason: ${choice.finish_reason}`)
const webSearchData = collectWebSearchData(chunk, contentSource, context)
if (webSearchData) {
controller.enqueue({
type: ChunkType.LLM_WEB_SEARCH_COMPLETE,
llm_web_search: webSearchData
})
}
emitCompletionSignals(controller)
}
}
},
// 流正常结束时,检查是否需要发送完成信号
flush(controller) {
if (isFinished) return
Logger.debug('[OpenAIApiClient] Stream ended without finish_reason, emitting fallback completion signals')
emitCompletionSignals(controller)
}
})
}
}

View File

@@ -1,255 +0,0 @@
import {
isClaudeReasoningModel,
isNotSupportTemperatureAndTopP,
isOpenAIReasoningModel,
isSupportedModel,
isSupportedReasoningEffortOpenAIModel
} from '@renderer/config/models'
import { getStoreSetting } from '@renderer/hooks/useSettings'
import { getAssistantSettings } from '@renderer/services/AssistantService'
import store from '@renderer/store'
import { SettingsState } from '@renderer/store/settings'
import { Assistant, GenerateImageParams, Model, Provider } from '@renderer/types'
import {
OpenAIResponseSdkMessageParam,
OpenAIResponseSdkParams,
OpenAIResponseSdkRawChunk,
OpenAIResponseSdkRawOutput,
OpenAIResponseSdkTool,
OpenAIResponseSdkToolCall,
OpenAISdkMessageParam,
OpenAISdkParams,
OpenAISdkRawChunk,
OpenAISdkRawOutput,
ReasoningEffortOptionalParams
} from '@renderer/types/sdk'
import { formatApiHost } from '@renderer/utils/api'
import OpenAI, { AzureOpenAI } from 'openai'
import { BaseApiClient } from '../BaseApiClient'
/**
* 抽象的OpenAI基础客户端类包含两个OpenAI客户端之间的共享功能
*/
export abstract class OpenAIBaseClient<
TSdkInstance extends OpenAI | AzureOpenAI,
TSdkParams extends OpenAISdkParams | OpenAIResponseSdkParams,
TRawOutput extends OpenAISdkRawOutput | OpenAIResponseSdkRawOutput,
TRawChunk extends OpenAISdkRawChunk | OpenAIResponseSdkRawChunk,
TMessageParam extends OpenAISdkMessageParam | OpenAIResponseSdkMessageParam,
TToolCall extends OpenAI.Chat.Completions.ChatCompletionMessageToolCall | OpenAIResponseSdkToolCall,
TSdkSpecificTool extends OpenAI.Chat.Completions.ChatCompletionTool | OpenAIResponseSdkTool
> extends BaseApiClient<TSdkInstance, TSdkParams, TRawOutput, TRawChunk, TMessageParam, TToolCall, TSdkSpecificTool> {
constructor(provider: Provider) {
super(provider)
}
// 仅适用于openai
override getBaseURL(): string {
const host = this.provider.apiHost
return formatApiHost(host)
}
override async generateImage({
model,
prompt,
negativePrompt,
imageSize,
batchSize,
seed,
numInferenceSteps,
guidanceScale,
signal,
promptEnhancement
}: GenerateImageParams): Promise<string[]> {
const sdk = await this.getSdkInstance()
const response = (await sdk.request({
method: 'post',
path: '/images/generations',
signal,
body: {
model,
prompt,
negative_prompt: negativePrompt,
image_size: imageSize,
batch_size: batchSize,
seed: seed ? parseInt(seed) : undefined,
num_inference_steps: numInferenceSteps,
guidance_scale: guidanceScale,
prompt_enhancement: promptEnhancement
}
})) as { data: Array<{ url: string }> }
return response.data.map((item) => item.url)
}
override async getEmbeddingDimensions(model: Model): Promise<number> {
const sdk = await this.getSdkInstance()
const data = await sdk.embeddings.create({
model: model.id,
input: model?.provider === 'baidu-cloud' ? ['hi'] : 'hi',
encoding_format: 'float'
})
return data.data[0].embedding.length
}
override async listModels(): Promise<OpenAI.Models.Model[]> {
try {
const sdk = await this.getSdkInstance()
const response = await sdk.models.list()
if (this.provider.id === 'github') {
// @ts-ignore key is not typed
return response?.body
.map((model) => ({
id: model.name,
description: model.summary,
object: 'model',
owned_by: model.publisher
}))
.filter(isSupportedModel)
}
if (this.provider.id === 'together') {
// @ts-ignore key is not typed
return response?.body.map((model) => ({
id: model.id,
description: model.display_name,
object: 'model',
owned_by: model.organization
}))
}
const models = response.data || []
models.forEach((model) => {
model.id = model.id.trim()
})
return models.filter(isSupportedModel)
} catch (error) {
console.error('Error listing models:', error)
return []
}
}
override async getSdkInstance() {
if (this.sdkInstance) {
return this.sdkInstance
}
let apiKeyForSdkInstance = this.apiKey
if (this.provider.id === 'copilot') {
const defaultHeaders = store.getState().copilot.defaultHeaders
const { token } = await window.api.copilot.getToken(defaultHeaders)
// this.provider.apiKey不允许修改
// this.provider.apiKey = token
apiKeyForSdkInstance = token
}
if (this.provider.id === 'azure-openai' || this.provider.type === 'azure-openai') {
this.sdkInstance = new AzureOpenAI({
dangerouslyAllowBrowser: true,
apiKey: apiKeyForSdkInstance,
apiVersion: this.provider.apiVersion,
endpoint: this.provider.apiHost
}) as TSdkInstance
} else {
this.sdkInstance = new OpenAI({
dangerouslyAllowBrowser: true,
apiKey: apiKeyForSdkInstance,
baseURL: this.getBaseURL(),
defaultHeaders: {
...this.defaultHeaders(),
...(this.provider.id === 'copilot' ? { 'editor-version': 'vscode/1.97.2' } : {}),
...(this.provider.id === 'copilot' ? { 'copilot-vision-request': 'true' } : {})
}
}) as TSdkInstance
}
return this.sdkInstance
}
override getTemperature(assistant: Assistant, model: Model): number | undefined {
if (
isNotSupportTemperatureAndTopP(model) ||
(assistant.settings?.reasoning_effort && isClaudeReasoningModel(model))
) {
return undefined
}
return assistant.settings?.temperature
}
override getTopP(assistant: Assistant, model: Model): number | undefined {
if (
isNotSupportTemperatureAndTopP(model) ||
(assistant.settings?.reasoning_effort && isClaudeReasoningModel(model))
) {
return undefined
}
return assistant.settings?.topP
}
/**
* Get the provider specific parameters for the assistant
* @param assistant - The assistant
* @param model - The model
* @returns The provider specific parameters
*/
protected getProviderSpecificParameters(assistant: Assistant, model: Model) {
const { maxTokens } = getAssistantSettings(assistant)
if (this.provider.id === 'openrouter') {
if (model.id.includes('deepseek-r1')) {
return {
include_reasoning: true
}
}
}
if (isOpenAIReasoningModel(model)) {
return {
max_tokens: undefined,
max_completion_tokens: maxTokens
}
}
return {}
}
/**
* Get the reasoning effort for the assistant
* @param assistant - The assistant
* @param model - The model
* @returns The reasoning effort
*/
protected getReasoningEffort(assistant: Assistant, model: Model): ReasoningEffortOptionalParams {
if (!isSupportedReasoningEffortOpenAIModel(model)) {
return {}
}
const openAI = getStoreSetting('openAI') as SettingsState['openAI']
const summaryText = openAI?.summaryText || 'off'
let summary: string | undefined = undefined
if (summaryText === 'off' || model.id.includes('o1-pro')) {
summary = undefined
} else {
summary = summaryText
}
const reasoningEffort = assistant?.settings?.reasoning_effort
if (!reasoningEffort) {
return {}
}
if (isSupportedReasoningEffortOpenAIModel(model)) {
return {
reasoning: {
effort: reasoningEffort as OpenAI.ReasoningEffort,
summary: summary
} as OpenAI.Reasoning
}
}
return {}
}
}

View File

@@ -1,605 +0,0 @@
import { GenericChunk } from '@renderer/aiCore/middleware/schemas'
import { CompletionsContext } from '@renderer/aiCore/middleware/types'
import {
isOpenAIChatCompletionOnlyModel,
isSupportedReasoningEffortOpenAIModel,
isVisionModel
} from '@renderer/config/models'
import { estimateTextTokens } from '@renderer/services/TokenService'
import {
FileType,
FileTypes,
MCPCallToolResponse,
MCPTool,
MCPToolResponse,
Model,
Provider,
ToolCallResponse,
WebSearchSource
} from '@renderer/types'
import { ChunkType } from '@renderer/types/chunk'
import { Message } from '@renderer/types/newMessage'
import {
OpenAIResponseSdkMessageParam,
OpenAIResponseSdkParams,
OpenAIResponseSdkRawChunk,
OpenAIResponseSdkRawOutput,
OpenAIResponseSdkTool,
OpenAIResponseSdkToolCall
} from '@renderer/types/sdk'
import { addImageFileToContents } from '@renderer/utils/formats'
import {
isEnabledToolUse,
mcpToolCallResponseToOpenAIMessage,
mcpToolsToOpenAIResponseTools,
openAIToolsToMcpTool
} from '@renderer/utils/mcp-tools'
import { findFileBlocks, findImageBlocks } from '@renderer/utils/messageUtils/find'
import { buildSystemPrompt } from '@renderer/utils/prompt'
import { MB } from '@shared/config/constant'
import { isEmpty } from 'lodash'
import OpenAI from 'openai'
import { ResponseInput } from 'openai/resources/responses/responses'
import { RequestTransformer, ResponseChunkTransformer } from '../types'
import { OpenAIAPIClient } from './OpenAIApiClient'
import { OpenAIBaseClient } from './OpenAIBaseClient'
export class OpenAIResponseAPIClient extends OpenAIBaseClient<
OpenAI,
OpenAIResponseSdkParams,
OpenAIResponseSdkRawOutput,
OpenAIResponseSdkRawChunk,
OpenAIResponseSdkMessageParam,
OpenAIResponseSdkToolCall,
OpenAIResponseSdkTool
> {
private client: OpenAIAPIClient
constructor(provider: Provider) {
super(provider)
this.client = new OpenAIAPIClient(provider)
}
/**
* 根据模型特征选择合适的客户端
*/
public getClient(model: Model) {
if (isOpenAIChatCompletionOnlyModel(model)) {
return this.client
} else {
return this
}
}
override async getSdkInstance() {
if (this.sdkInstance) {
return this.sdkInstance
}
return new OpenAI({
dangerouslyAllowBrowser: true,
apiKey: this.apiKey,
baseURL: this.getBaseURL(),
defaultHeaders: {
...this.defaultHeaders()
}
})
}
override async createCompletions(
payload: OpenAIResponseSdkParams,
options?: OpenAI.RequestOptions
): Promise<OpenAIResponseSdkRawOutput> {
const sdk = await this.getSdkInstance()
return await sdk.responses.create(payload, options)
}
private async handlePdfFile(file: FileType): Promise<OpenAI.Responses.ResponseInputFile | undefined> {
if (file.size > 32 * MB) return undefined
try {
const pageCount = await window.api.file.pdfInfo(file.id + file.ext)
if (pageCount > 100) return undefined
} catch {
return undefined
}
const { data } = await window.api.file.base64File(file.id + file.ext)
return {
type: 'input_file',
filename: file.origin_name,
file_data: `data:application/pdf;base64,${data}`
} as OpenAI.Responses.ResponseInputFile
}
public async convertMessageToSdkParam(message: Message, model: Model): Promise<OpenAIResponseSdkMessageParam> {
const isVision = isVisionModel(model)
const content = await this.getMessageContent(message)
const fileBlocks = findFileBlocks(message)
const imageBlocks = findImageBlocks(message)
if (fileBlocks.length === 0 && imageBlocks.length === 0) {
if (message.role === 'assistant') {
return {
role: 'assistant',
content: content
}
} else {
return {
role: message.role === 'system' ? 'user' : message.role,
content: content ? [{ type: 'input_text', text: content }] : []
} as OpenAI.Responses.EasyInputMessage
}
}
const parts: OpenAI.Responses.ResponseInputContent[] = []
if (content) {
parts.push({
type: 'input_text',
text: content
})
}
for (const imageBlock of imageBlocks) {
if (isVision) {
if (imageBlock.file) {
const image = await window.api.file.base64Image(imageBlock.file.id + imageBlock.file.ext)
parts.push({
detail: 'auto',
type: 'input_image',
image_url: image.data as string
})
} else if (imageBlock.url && imageBlock.url.startsWith('data:')) {
parts.push({
detail: 'auto',
type: 'input_image',
image_url: imageBlock.url
})
}
}
}
for (const fileBlock of fileBlocks) {
const file = fileBlock.file
if (!file) continue
if (isVision && file.ext === '.pdf') {
const pdfPart = await this.handlePdfFile(file)
if (pdfPart) {
parts.push(pdfPart)
continue
}
}
if ([FileTypes.TEXT, FileTypes.DOCUMENT].includes(file.type)) {
const fileContent = (await window.api.file.read(file.id + file.ext)).trim()
parts.push({
type: 'input_text',
text: file.origin_name + '\n' + fileContent
})
}
}
return {
role: message.role === 'system' ? 'user' : message.role,
content: parts
}
}
public convertMcpToolsToSdkTools(mcpTools: MCPTool[]): OpenAI.Responses.Tool[] {
return mcpToolsToOpenAIResponseTools(mcpTools)
}
public convertSdkToolCallToMcp(toolCall: OpenAIResponseSdkToolCall, mcpTools: MCPTool[]): MCPTool | undefined {
return openAIToolsToMcpTool(mcpTools, toolCall)
}
public convertSdkToolCallToMcpToolResponse(toolCall: OpenAIResponseSdkToolCall, mcpTool: MCPTool): ToolCallResponse {
const parsedArgs = (() => {
try {
return JSON.parse(toolCall.arguments)
} catch {
return toolCall.arguments
}
})()
return {
id: toolCall.call_id,
toolCallId: toolCall.call_id,
tool: mcpTool,
arguments: parsedArgs,
status: 'pending'
}
}
public convertMcpToolResponseToSdkMessageParam(
mcpToolResponse: MCPToolResponse,
resp: MCPCallToolResponse,
model: Model
): OpenAIResponseSdkMessageParam | undefined {
if ('toolUseId' in mcpToolResponse && mcpToolResponse.toolUseId) {
return mcpToolCallResponseToOpenAIMessage(mcpToolResponse, resp, isVisionModel(model))
} else if ('toolCallId' in mcpToolResponse && mcpToolResponse.toolCallId) {
return {
type: 'function_call_output',
call_id: mcpToolResponse.toolCallId,
output: JSON.stringify(resp.content)
}
}
return
}
private convertResponseToMessageContent(response: OpenAI.Responses.Response): ResponseInput {
const content: OpenAI.Responses.ResponseInput = []
content.push(...response.output)
return content
}
public buildSdkMessages(
currentReqMessages: OpenAIResponseSdkMessageParam[],
output: OpenAI.Responses.Response | undefined,
toolResults: OpenAIResponseSdkMessageParam[],
toolCalls: OpenAIResponseSdkToolCall[]
): OpenAIResponseSdkMessageParam[] {
if (!output && toolCalls.length === 0) {
return [...currentReqMessages, ...toolResults]
}
if (!output) {
return [...currentReqMessages, ...(toolCalls || []), ...(toolResults || [])]
}
const content = this.convertResponseToMessageContent(output)
const newReqMessages = [...currentReqMessages, ...content, ...(toolResults || [])]
return newReqMessages
}
override estimateMessageTokens(message: OpenAIResponseSdkMessageParam): number {
let sum = 0
if ('content' in message) {
if (typeof message.content === 'string') {
sum += estimateTextTokens(message.content)
} else if (Array.isArray(message.content)) {
for (const part of message.content) {
switch (part.type) {
case 'input_text':
sum += estimateTextTokens(part.text)
break
case 'input_image':
sum += estimateTextTokens(part.image_url || '')
break
default:
break
}
}
}
}
switch (message.type) {
case 'function_call_output':
sum += estimateTextTokens(message.output)
break
case 'function_call':
sum += estimateTextTokens(message.arguments)
break
default:
break
}
return sum
}
public extractMessagesFromSdkPayload(sdkPayload: OpenAIResponseSdkParams): OpenAIResponseSdkMessageParam[] {
if (typeof sdkPayload.input === 'string') {
return [{ role: 'user', content: sdkPayload.input }]
}
return sdkPayload.input
}
getRequestTransformer(): RequestTransformer<OpenAIResponseSdkParams, OpenAIResponseSdkMessageParam> {
return {
transform: async (
coreRequest,
assistant,
model,
isRecursiveCall,
recursiveSdkMessages
): Promise<{
payload: OpenAIResponseSdkParams
messages: OpenAIResponseSdkMessageParam[]
metadata: Record<string, any>
}> => {
const { messages, mcpTools, maxTokens, streamOutput, enableWebSearch, enableGenerateImage } = coreRequest
// 1. 处理系统消息
const systemMessage: OpenAI.Responses.EasyInputMessage = {
role: 'system',
content: []
}
const systemMessageContent: OpenAI.Responses.ResponseInputMessageContentList = []
const systemMessageInput: OpenAI.Responses.ResponseInputText = {
text: assistant.prompt || '',
type: 'input_text'
}
if (isSupportedReasoningEffortOpenAIModel(model)) {
systemMessage.role = 'developer'
}
// 2. 设置工具
let tools: OpenAI.Responses.Tool[] = []
const { tools: extraTools } = this.setupToolsConfig({
mcpTools: mcpTools,
model,
enableToolUse: isEnabledToolUse(assistant)
})
if (this.useSystemPromptForTools) {
systemMessageInput.text = await buildSystemPrompt(systemMessageInput.text || '', mcpTools, assistant)
}
systemMessageContent.push(systemMessageInput)
systemMessage.content = systemMessageContent
// 3. 处理用户消息
let userMessage: OpenAI.Responses.ResponseInputItem[] = []
if (typeof messages === 'string') {
userMessage.push({ role: 'user', content: messages })
} else {
const processedMessages = addImageFileToContents(messages)
for (const message of processedMessages) {
userMessage.push(await this.convertMessageToSdkParam(message, model))
}
}
// FIXME: 最好还是直接使用previous_response_id来处理或者在数据库中存储image_generation_call的id
if (enableGenerateImage) {
const finalAssistantMessage = userMessage.findLast(
(m) => (m as OpenAI.Responses.EasyInputMessage).role === 'assistant'
) as OpenAI.Responses.EasyInputMessage
const finalUserMessage = userMessage.pop() as OpenAI.Responses.EasyInputMessage
if (
finalAssistantMessage &&
Array.isArray(finalAssistantMessage.content) &&
finalUserMessage &&
Array.isArray(finalUserMessage.content)
) {
finalAssistantMessage.content = [...finalAssistantMessage.content, ...finalUserMessage.content]
}
// 这里是故意将上条助手消息的内容(包含图片和文件)作为用户消息发送
userMessage = [{ ...finalAssistantMessage, role: 'user' } as OpenAI.Responses.EasyInputMessage]
}
// 4. 最终请求消息
let reqMessages: OpenAI.Responses.ResponseInput
if (!systemMessage.content) {
reqMessages = [...userMessage]
} else {
reqMessages = [systemMessage, ...userMessage].filter(Boolean) as OpenAI.Responses.EasyInputMessage[]
}
if (enableWebSearch) {
tools.push({
type: 'web_search_preview'
})
}
if (enableGenerateImage) {
tools.push({
type: 'image_generation',
partial_images: streamOutput ? 2 : undefined
})
}
const toolChoices: OpenAI.Responses.ToolChoiceTypes = {
type: 'web_search_preview'
}
tools = tools.concat(extraTools)
const commonParams = {
model: model.id,
input:
isRecursiveCall && recursiveSdkMessages && recursiveSdkMessages.length > 0
? recursiveSdkMessages
: reqMessages,
temperature: this.getTemperature(assistant, model),
top_p: this.getTopP(assistant, model),
max_output_tokens: maxTokens,
stream: streamOutput,
tools: !isEmpty(tools) ? tools : undefined,
tool_choice: enableWebSearch ? toolChoices : undefined,
service_tier: this.getServiceTier(model),
...(this.getReasoningEffort(assistant, model) as OpenAI.Reasoning),
...this.getCustomParameters(assistant)
}
const sdkParams: OpenAIResponseSdkParams = streamOutput
? {
...commonParams,
stream: true
}
: {
...commonParams,
stream: false
}
const timeout = this.getTimeout(model)
return { payload: sdkParams, messages: reqMessages, metadata: { timeout } }
}
}
}
getResponseChunkTransformer(ctx: CompletionsContext): ResponseChunkTransformer<OpenAIResponseSdkRawChunk> {
const toolCalls: OpenAIResponseSdkToolCall[] = []
const outputItems: OpenAI.Responses.ResponseOutputItem[] = []
let hasBeenCollectedToolCalls = false
return () => ({
async transform(chunk: OpenAIResponseSdkRawChunk, controller: TransformStreamDefaultController<GenericChunk>) {
// 处理chunk
if ('output' in chunk) {
if (ctx._internal?.toolProcessingState) {
ctx._internal.toolProcessingState.output = chunk
}
for (const output of chunk.output) {
switch (output.type) {
case 'message':
if (output.content[0].type === 'output_text') {
controller.enqueue({
type: ChunkType.TEXT_DELTA,
text: output.content[0].text
})
if (output.content[0].annotations && output.content[0].annotations.length > 0) {
controller.enqueue({
type: ChunkType.LLM_WEB_SEARCH_COMPLETE,
llm_web_search: {
source: WebSearchSource.OPENAI_RESPONSE,
results: output.content[0].annotations
}
})
}
}
break
case 'reasoning':
controller.enqueue({
type: ChunkType.THINKING_DELTA,
text: output.summary.map((s) => s.text).join('\n')
})
break
case 'function_call':
toolCalls.push(output)
break
case 'image_generation_call':
controller.enqueue({
type: ChunkType.IMAGE_CREATED
})
controller.enqueue({
type: ChunkType.IMAGE_COMPLETE,
image: {
type: 'base64',
images: [`data:image/png;base64,${output.result}`]
}
})
}
}
if (toolCalls.length > 0) {
controller.enqueue({
type: ChunkType.MCP_TOOL_CREATED,
tool_calls: toolCalls
})
}
controller.enqueue({
type: ChunkType.LLM_RESPONSE_COMPLETE,
response: {
usage: {
prompt_tokens: chunk.usage?.input_tokens || 0,
completion_tokens: chunk.usage?.output_tokens || 0,
total_tokens: chunk.usage?.total_tokens || 0
}
}
})
} else {
switch (chunk.type) {
case 'response.output_item.added':
if (chunk.item.type === 'function_call') {
outputItems.push(chunk.item)
}
break
case 'response.reasoning_summary_text.delta':
controller.enqueue({
type: ChunkType.THINKING_DELTA,
text: chunk.delta
})
break
case 'response.image_generation_call.generating':
controller.enqueue({
type: ChunkType.IMAGE_CREATED
})
break
case 'response.image_generation_call.partial_image':
controller.enqueue({
type: ChunkType.IMAGE_DELTA,
image: {
type: 'base64',
images: [`data:image/png;base64,${chunk.partial_image_b64}`]
}
})
break
case 'response.image_generation_call.completed':
controller.enqueue({
type: ChunkType.IMAGE_COMPLETE
})
break
case 'response.output_text.delta': {
controller.enqueue({
type: ChunkType.TEXT_DELTA,
text: chunk.delta
})
break
}
case 'response.function_call_arguments.done': {
const outputItem: OpenAI.Responses.ResponseOutputItem | undefined = outputItems.find(
(item) => item.id === chunk.item_id
)
if (outputItem) {
if (outputItem.type === 'function_call') {
toolCalls.push({
...outputItem,
arguments: chunk.arguments,
status: 'completed'
})
}
}
break
}
case 'response.content_part.done': {
if (chunk.part.type === 'output_text' && chunk.part.annotations && chunk.part.annotations.length > 0) {
controller.enqueue({
type: ChunkType.LLM_WEB_SEARCH_COMPLETE,
llm_web_search: {
source: WebSearchSource.OPENAI_RESPONSE,
results: chunk.part.annotations
}
})
}
if (toolCalls.length > 0 && !hasBeenCollectedToolCalls) {
controller.enqueue({
type: ChunkType.MCP_TOOL_CREATED,
tool_calls: toolCalls
})
hasBeenCollectedToolCalls = true
}
break
}
case 'response.completed': {
if (ctx._internal?.toolProcessingState) {
ctx._internal.toolProcessingState.output = chunk.response
}
if (toolCalls.length > 0 && !hasBeenCollectedToolCalls) {
controller.enqueue({
type: ChunkType.MCP_TOOL_CREATED,
tool_calls: toolCalls
})
hasBeenCollectedToolCalls = true
}
const completion_tokens = chunk.response.usage?.output_tokens || 0
const total_tokens = chunk.response.usage?.total_tokens || 0
controller.enqueue({
type: ChunkType.LLM_RESPONSE_COMPLETE,
response: {
usage: {
prompt_tokens: chunk.response.usage?.input_tokens || 0,
completion_tokens: completion_tokens,
total_tokens: total_tokens
}
}
})
break
}
case 'error': {
controller.enqueue({
type: ChunkType.ERROR,
error: {
message: chunk.message,
code: chunk.code
}
})
break
}
}
}
}
})
}
}

View File

@@ -1,140 +0,0 @@
import Anthropic from '@anthropic-ai/sdk'
import { Assistant, MCPTool, MCPToolResponse, Model, ToolCallResponse } from '@renderer/types'
import { Provider } from '@renderer/types'
import {
AnthropicSdkRawChunk,
OpenAIResponseSdkRawChunk,
OpenAIResponseSdkRawOutput,
OpenAISdkRawChunk,
SdkMessageParam,
SdkParams,
SdkRawChunk,
SdkRawOutput,
SdkTool,
SdkToolCall
} from '@renderer/types/sdk'
import OpenAI from 'openai'
import { CompletionsParams, GenericChunk } from '../middleware/schemas'
import { CompletionsContext } from '../middleware/types'
/**
* 原始流监听器接口
*/
export interface RawStreamListener<TRawChunk = SdkRawChunk> {
onChunk?: (chunk: TRawChunk) => void
onStart?: () => void
onEnd?: () => void
onError?: (error: Error) => void
}
/**
* OpenAI 专用的流监听器
*/
export interface OpenAIStreamListener extends RawStreamListener<OpenAISdkRawChunk> {
onChoice?: (choice: OpenAI.Chat.Completions.ChatCompletionChunk.Choice) => void
onFinishReason?: (reason: string) => void
}
/**
* OpenAI Response 专用的流监听器
*/
export interface OpenAIResponseStreamListener<TChunk extends OpenAIResponseSdkRawChunk = OpenAIResponseSdkRawChunk>
extends RawStreamListener<TChunk> {
onMessage?: (response: OpenAIResponseSdkRawOutput) => void
}
/**
* Anthropic 专用的流监听器
*/
export interface AnthropicStreamListener<TChunk extends AnthropicSdkRawChunk = AnthropicSdkRawChunk>
extends RawStreamListener<TChunk> {
onContentBlock?: (contentBlock: Anthropic.Messages.ContentBlock) => void
onMessage?: (message: Anthropic.Messages.Message) => void
}
/**
* 请求转换器接口
*/
export interface RequestTransformer<
TSdkParams extends SdkParams = SdkParams,
TMessageParam extends SdkMessageParam = SdkMessageParam
> {
transform(
completionsParams: CompletionsParams,
assistant: Assistant,
model: Model,
isRecursiveCall?: boolean,
recursiveSdkMessages?: TMessageParam[]
): Promise<{
payload: TSdkParams
messages: TMessageParam[]
metadata?: Record<string, any>
}>
}
/**
* 响应块转换器接口
*/
export type ResponseChunkTransformer<TRawChunk extends SdkRawChunk = SdkRawChunk, TContext = any> = (
context?: TContext
) => Transformer<TRawChunk, GenericChunk>
export interface ResponseChunkTransformerContext {
isStreaming: boolean
isEnabledToolCalling: boolean
isEnabledWebSearch: boolean
isEnabledReasoning: boolean
mcpTools: MCPTool[]
provider: Provider
}
/**
* API客户端接口
*/
export interface ApiClient<
TSdkInstance = any,
TSdkParams extends SdkParams = SdkParams,
TRawOutput extends SdkRawOutput = SdkRawOutput,
TRawChunk extends SdkRawChunk = SdkRawChunk,
TMessageParam extends SdkMessageParam = SdkMessageParam,
TToolCall extends SdkToolCall = SdkToolCall,
TSdkSpecificTool extends SdkTool = SdkTool
> {
provider: Provider
// 核心方法 - 在中间件架构中,这个方法可能只是一个占位符
// 实际的SDK调用由SdkCallMiddleware处理
// completions(params: CompletionsParams): Promise<CompletionsResult>
createCompletions(payload: TSdkParams): Promise<TRawOutput>
// SDK相关方法
getSdkInstance(): Promise<TSdkInstance> | TSdkInstance
getRequestTransformer(): RequestTransformer<TSdkParams, TMessageParam>
getResponseChunkTransformer(ctx: CompletionsContext): ResponseChunkTransformer<TRawChunk>
// 原始流监听方法
attachRawStreamListener?(rawOutput: TRawOutput, listener: RawStreamListener<TRawChunk>): TRawOutput
// 工具转换相关方法 (保持可选因为不是所有Provider都支持工具)
convertMcpToolsToSdkTools(mcpTools: MCPTool[]): TSdkSpecificTool[]
convertMcpToolResponseToSdkMessageParam?(
mcpToolResponse: MCPToolResponse,
resp: any,
model: Model
): TMessageParam | undefined
convertSdkToolCallToMcp?(toolCall: TToolCall, mcpTools: MCPTool[]): MCPTool | undefined
convertSdkToolCallToMcpToolResponse(toolCall: TToolCall, mcpTool: MCPTool): ToolCallResponse
// 构建SDK特定的消息列表用于工具调用后的递归调用
buildSdkMessages(
currentReqMessages: TMessageParam[],
output: TRawOutput | string,
toolResults: TMessageParam[],
toolCalls?: TToolCall[]
): TMessageParam[]
// 从SDK载荷中提取消息数组用于中间件中的类型安全访问
extractMessagesFromSdkPayload(sdkPayload: TSdkParams): TMessageParam[]
}

View File

@@ -1,132 +0,0 @@
import { ApiClientFactory } from '@renderer/aiCore/clients/ApiClientFactory'
import { BaseApiClient } from '@renderer/aiCore/clients/BaseApiClient'
import { isDedicatedImageGenerationModel, isFunctionCallingModel } from '@renderer/config/models'
import type { GenerateImageParams, Model, Provider } from '@renderer/types'
import { RequestOptions, SdkModel } from '@renderer/types/sdk'
import { isEnabledToolUse } from '@renderer/utils/mcp-tools'
import { OpenAIAPIClient } from './clients'
import { AihubmixAPIClient } from './clients/AihubmixAPIClient'
import { AnthropicAPIClient } from './clients/anthropic/AnthropicAPIClient'
import { OpenAIResponseAPIClient } from './clients/openai/OpenAIResponseAPIClient'
import { CompletionsMiddlewareBuilder } from './middleware/builder'
import { MIDDLEWARE_NAME as AbortHandlerMiddlewareName } from './middleware/common/AbortHandlerMiddleware'
import { MIDDLEWARE_NAME as ErrorHandlerMiddlewareName } from './middleware/common/ErrorHandlerMiddleware'
import { MIDDLEWARE_NAME as FinalChunkConsumerMiddlewareName } from './middleware/common/FinalChunkConsumerMiddleware'
import { applyCompletionsMiddlewares } from './middleware/composer'
import { MIDDLEWARE_NAME as McpToolChunkMiddlewareName } from './middleware/core/McpToolChunkMiddleware'
import { MIDDLEWARE_NAME as RawStreamListenerMiddlewareName } from './middleware/core/RawStreamListenerMiddleware'
import { MIDDLEWARE_NAME as ThinkChunkMiddlewareName } from './middleware/core/ThinkChunkMiddleware'
import { MIDDLEWARE_NAME as WebSearchMiddlewareName } from './middleware/core/WebSearchMiddleware'
import { MIDDLEWARE_NAME as ImageGenerationMiddlewareName } from './middleware/feat/ImageGenerationMiddleware'
import { MIDDLEWARE_NAME as ThinkingTagExtractionMiddlewareName } from './middleware/feat/ThinkingTagExtractionMiddleware'
import { MIDDLEWARE_NAME as ToolUseExtractionMiddlewareName } from './middleware/feat/ToolUseExtractionMiddleware'
import { MiddlewareRegistry } from './middleware/register'
import { CompletionsParams, CompletionsResult } from './middleware/schemas'
export default class AiProvider {
private apiClient: BaseApiClient
constructor(provider: Provider) {
// Use the new ApiClientFactory to get a BaseApiClient instance
this.apiClient = ApiClientFactory.create(provider)
}
public async completions(params: CompletionsParams, options?: RequestOptions): Promise<CompletionsResult> {
// 1. 根据模型识别正确的客户端
const model = params.assistant.model
if (!model) {
return Promise.reject(new Error('Model is required'))
}
// 根据client类型选择合适的处理方式
let client: BaseApiClient
if (this.apiClient instanceof AihubmixAPIClient) {
// AihubmixAPIClient: 根据模型选择合适的子client
client = this.apiClient.getClientForModel(model)
if (client instanceof OpenAIResponseAPIClient) {
client = client.getClient(model) as BaseApiClient
}
} else if (this.apiClient instanceof OpenAIResponseAPIClient) {
// OpenAIResponseAPIClient: 根据模型特征选择API类型
client = this.apiClient.getClient(model) as BaseApiClient
} else {
// 其他client直接使用
client = this.apiClient
}
// 2. 构建中间件链
const builder = CompletionsMiddlewareBuilder.withDefaults()
// images api
if (isDedicatedImageGenerationModel(model)) {
builder.clear()
builder
.add(MiddlewareRegistry[FinalChunkConsumerMiddlewareName])
.add(MiddlewareRegistry[ErrorHandlerMiddlewareName])
.add(MiddlewareRegistry[AbortHandlerMiddlewareName])
.add(MiddlewareRegistry[ImageGenerationMiddlewareName])
} else {
// Existing logic for other models
if (!params.enableReasoning) {
builder.remove(ThinkingTagExtractionMiddlewareName)
builder.remove(ThinkChunkMiddlewareName)
}
// 注意用client判断会导致typescript类型收窄
if (!(this.apiClient instanceof OpenAIAPIClient)) {
builder.remove(ThinkingTagExtractionMiddlewareName)
}
if (!(this.apiClient instanceof AnthropicAPIClient) && !(this.apiClient instanceof OpenAIResponseAPIClient)) {
builder.remove(RawStreamListenerMiddlewareName)
}
if (!params.enableWebSearch) {
builder.remove(WebSearchMiddlewareName)
}
if (!params.mcpTools?.length) {
builder.remove(ToolUseExtractionMiddlewareName)
builder.remove(McpToolChunkMiddlewareName)
}
if (isEnabledToolUse(params.assistant) && isFunctionCallingModel(model)) {
builder.remove(ToolUseExtractionMiddlewareName)
}
if (params.callType !== 'chat') {
builder.remove(AbortHandlerMiddlewareName)
}
}
const middlewares = builder.build()
// 3. Create the wrapped SDK method with middlewares
const wrappedCompletionMethod = applyCompletionsMiddlewares(client, client.createCompletions, middlewares)
// 4. Execute the wrapped method with the original params
return wrappedCompletionMethod(params, options)
}
public async models(): Promise<SdkModel[]> {
return this.apiClient.listModels()
}
public async getEmbeddingDimensions(model: Model): Promise<number> {
try {
// Use the SDK instance to test embedding capabilities
const dimensions = await this.apiClient.getEmbeddingDimensions(model)
return dimensions
} catch (error) {
console.error('Error getting embedding dimensions:', error)
throw error
}
}
public async generateImage(params: GenerateImageParams): Promise<string[]> {
return this.apiClient.generateImage(params)
}
public getBaseURL(): string {
return this.apiClient.getBaseURL()
}
public getApiKey(): string {
return this.apiClient.getApiKey()
}
}

View File

@@ -1,229 +0,0 @@
/**
* Cherry Studio AI Core - 新版本入口
* 集成 @cherrystudio/ai-core 库的渐进式重构方案
*
* 融合方案:简化实现,专注于核心功能
* 1. 优先使用新AI SDK
* 2. 失败时fallback到原有实现
* 3. 暂时保持接口兼容性
*/
import {
AiClient,
AiCore,
createClient,
type OpenAICompatibleProviderSettings,
type ProviderId,
smoothStream,
StreamTextParams
} from '@cherrystudio/ai-core'
import { isDedicatedImageGenerationModel } from '@renderer/config/models'
import type { GenerateImageParams, Model, Provider } from '@renderer/types'
// 引入适配器
import AiSdkToChunkAdapter from './AiSdkToChunkAdapter'
// 引入原有的AiProvider作为fallback
import LegacyAiProvider from './index'
import { AiSdkMiddlewareConfig, buildAiSdkMiddlewares } from './middleware/aisdk/AiSdkMiddlewareBuilder'
import { CompletionsResult } from './middleware/schemas'
// 引入参数转换模块
/**
* 将现有 Provider 类型映射到 AI SDK 的 Provider ID
* 根据 registry.ts 中的支持列表进行映射
*/
function mapProviderTypeToAiSdkId(providerType: string): string {
// Cherry Studio Provider Type -> AI SDK Provider ID 映射表
const typeMapping: Record<string, string> = {
// 需要转换的映射
grok: 'xai', // grok -> xai
'azure-openai': 'azure', // azure-openai -> azure
gemini: 'google', // gemini -> google
vertexai: 'google-vertex' // vertexai -> google-vertex
}
return typeMapping[providerType]
}
/**
* 将 Provider 配置转换为新 AI SDK 格式
*/
function providerToAiSdkConfig(provider: Provider): {
providerId: ProviderId | 'openai-compatible'
options: any
} {
console.log('provider', provider)
// 1. 先映射 provider 类型到 AI SDK ID
const mappedProviderId = mapProviderTypeToAiSdkId(provider.id)
// 2. 检查映射后的 provider ID 是否在 AI SDK 注册表中
const isSupported = AiCore.isSupported(mappedProviderId)
console.log(`Provider mapping: ${provider.type} -> ${mappedProviderId}, supported: ${isSupported}`)
// 3. 如果映射的 provider 不支持,则使用 openai-compatible
if (isSupported) {
return {
providerId: mappedProviderId as ProviderId,
options: {
apiKey: provider.apiKey
}
}
} else {
console.log(`Using openai-compatible fallback for provider: ${provider.type}`)
const compatibleConfig: OpenAICompatibleProviderSettings = {
name: provider.name || provider.type,
apiKey: provider.apiKey,
baseURL: provider.apiHost
}
return {
providerId: 'openai-compatible',
options: compatibleConfig
}
}
}
/**
* 检查是否支持使用新的AI SDK
*/
function isModernSdkSupported(provider: Provider, model?: Model): boolean {
// 目前支持主要的providers
const supportedProviders = ['openai', 'anthropic', 'gemini', 'azure-openai', 'vertexai']
// 检查provider类型
if (!supportedProviders.includes(provider.type)) {
return false
}
// 检查是否为图像生成模型(暂时不支持)
if (model && isDedicatedImageGenerationModel(model)) {
return false
}
return true
}
export default class ModernAiProvider {
private modernClient?: AiClient
private legacyProvider: LegacyAiProvider
private provider: Provider
constructor(provider: Provider) {
this.provider = provider
this.legacyProvider = new LegacyAiProvider(provider)
// 初始化时不构建中间件,等到需要时再构建
const config = providerToAiSdkConfig(provider)
this.modernClient = createClient(config.providerId, config.options)
}
public async completions(
modelId: string,
params: StreamTextParams,
middlewareConfig: AiSdkMiddlewareConfig
): Promise<CompletionsResult> {
// const model = params.assistant.model
// 检查是否应该使用现代化客户端
// if (this.modernClient && model && isModernSdkSupported(this.provider, model)) {
// try {
return await this.modernCompletions(modelId, params, middlewareConfig)
// } catch (error) {
// console.warn('Modern client failed, falling back to legacy:', error)
// fallback到原有实现
// }
// }
// 使用原有实现
// return this.legacyProvider.completions(params, options)
}
/**
* 使用现代化AI SDK的completions实现
* 使用建造者模式动态构建中间件
*/
private async modernCompletions(
modelId: string,
params: StreamTextParams,
middlewareConfig: AiSdkMiddlewareConfig
): Promise<CompletionsResult> {
if (!this.modernClient) {
throw new Error('Modern AI SDK client not initialized')
}
try {
// 合并传入的配置和实例配置
const finalConfig: AiSdkMiddlewareConfig = {
...middlewareConfig,
provider: this.provider,
// 工具相关信息从 params 中获取
enableTool: params.tools !== undefined && Array.isArray(params.tools) && params.tools.length > 0
}
// 动态构建中间件数组
const middlewares = buildAiSdkMiddlewares(finalConfig)
console.log(
'构建的中间件:',
middlewares.map((m) => m.name)
)
// 创建带有中间件的客户端
const config = providerToAiSdkConfig(this.provider)
const clientWithMiddlewares = createClient(config.providerId, config.options, middlewares)
if (middlewareConfig.onChunk) {
// 流式处理 - 使用适配器
const adapter = new AiSdkToChunkAdapter(middlewareConfig.onChunk)
const streamResult = await clientWithMiddlewares.streamText(modelId, {
...params,
experimental_transform: smoothStream({
delayInMs: 80,
// 中文3个字符一个chunk,英文一个单词一个chunk
chunking: /([\u4E00-\u9FFF]{3})|\S+\s+/
})
})
const finalText = await adapter.processStream(streamResult)
return {
getText: () => finalText
}
} else {
// 流式处理但没有 onChunk 回调
const streamResult = await clientWithMiddlewares.streamText(modelId, params)
const finalText = await streamResult.text
return {
getText: () => finalText
}
}
} catch (error) {
console.error('Modern AI SDK error:', error)
throw error
}
}
// 代理其他方法到原有实现
public async models() {
return this.legacyProvider.models()
}
public async getEmbeddingDimensions(model: Model): Promise<number> {
return this.legacyProvider.getEmbeddingDimensions(model)
}
public async generateImage(params: GenerateImageParams): Promise<string[]> {
return this.legacyProvider.generateImage(params)
}
public getBaseURL(): string {
return this.legacyProvider.getBaseURL()
}
public getApiKey(): string {
return this.legacyProvider.getApiKey()
}
}
// 为了方便调试,导出一些工具函数
export { isModernSdkSupported, providerToAiSdkConfig }

View File

@@ -1,182 +0,0 @@
# MiddlewareBuilder 使用指南
`MiddlewareBuilder` 是一个用于动态构建和管理中间件链的工具,提供灵活的中间件组织和配置能力。
## 主要特性
### 1. 统一的中间件命名
所有中间件都通过导出的 `MIDDLEWARE_NAME` 常量标识:
```typescript
// 中间件文件示例
export const MIDDLEWARE_NAME = 'SdkCallMiddleware'
export const SdkCallMiddleware: CompletionsMiddleware = ...
```
### 2. NamedMiddleware 接口
中间件使用统一的 `NamedMiddleware` 接口格式:
```typescript
interface NamedMiddleware<TMiddleware = any> {
name: string
middleware: TMiddleware
}
```
### 3. 中间件注册表
通过 `MiddlewareRegistry` 集中管理所有可用中间件:
```typescript
import { MiddlewareRegistry } from './register'
// 通过名称获取中间件
const sdkCallMiddleware = MiddlewareRegistry['SdkCallMiddleware']
```
## 基本用法
### 1. 使用默认中间件链
```typescript
import { CompletionsMiddlewareBuilder } from './builder'
const builder = CompletionsMiddlewareBuilder.withDefaults()
const middlewares = builder.build()
```
### 2. 自定义中间件链
```typescript
import { createCompletionsBuilder, MiddlewareRegistry } from './builder'
const builder = createCompletionsBuilder([
MiddlewareRegistry['AbortHandlerMiddleware'],
MiddlewareRegistry['TextChunkMiddleware']
])
const middlewares = builder.build()
```
### 3. 动态调整中间件链
```typescript
const builder = CompletionsMiddlewareBuilder.withDefaults()
// 根据条件添加、移除、替换中间件
if (needsLogging) {
builder.prepend(MiddlewareRegistry['GenericLoggingMiddleware'])
}
if (disableTools) {
builder.remove('McpToolChunkMiddleware')
}
if (customThinking) {
builder.replace('ThinkingTagExtractionMiddleware', customThinkingMiddleware)
}
const middlewares = builder.build()
```
### 4. 链式操作
```typescript
const middlewares = CompletionsMiddlewareBuilder.withDefaults()
.add(MiddlewareRegistry['CustomMiddleware'])
.insertBefore('SdkCallMiddleware', MiddlewareRegistry['SecurityCheckMiddleware'])
.remove('WebSearchMiddleware')
.build()
```
## API 参考
### CompletionsMiddlewareBuilder
**静态方法:**
- `static withDefaults()`: 创建带有默认中间件链的构建器
**实例方法:**
- `add(middleware: NamedMiddleware)`: 在链末尾添加中间件
- `prepend(middleware: NamedMiddleware)`: 在链开头添加中间件
- `insertAfter(targetName: string, middleware: NamedMiddleware)`: 在指定中间件后插入
- `insertBefore(targetName: string, middleware: NamedMiddleware)`: 在指定中间件前插入
- `replace(targetName: string, middleware: NamedMiddleware)`: 替换指定中间件
- `remove(targetName: string)`: 移除指定中间件
- `has(name: string)`: 检查是否包含指定中间件
- `build()`: 构建最终的中间件数组
- `getChain()`: 获取当前链(包含名称信息)
- `clear()`: 清空中间件链
- `execute(context, params, middlewareExecutor)`: 直接执行构建好的中间件链
### 工厂函数
- `createCompletionsBuilder(baseChain?)`: 创建 Completions 中间件构建器
- `createMethodBuilder(baseChain?)`: 创建通用方法中间件构建器
- `addMiddlewareName(middleware, name)`: 为中间件添加名称属性的辅助函数
### 中间件注册表
- `MiddlewareRegistry`: 所有注册中间件的集中访问点
- `getMiddleware(name)`: 根据名称获取中间件
- `getRegisteredMiddlewareNames()`: 获取所有注册的中间件名称
- `DefaultCompletionsNamedMiddlewares`: 默认的 Completions 中间件链NamedMiddleware 格式)
## 类型安全
构建器提供完整的 TypeScript 类型支持:
- `CompletionsMiddlewareBuilder` 专门用于 `CompletionsMiddleware` 类型
- `MethodMiddlewareBuilder` 用于通用的 `MethodMiddleware` 类型
- 所有中间件操作都基于 `NamedMiddleware<TMiddleware>` 接口
## 默认中间件链
默认的 Completions 中间件执行顺序:
1. `FinalChunkConsumerMiddleware` - 最终消费者
2. `TransformCoreToSdkParamsMiddleware` - 参数转换
3. `AbortHandlerMiddleware` - 中止处理
4. `McpToolChunkMiddleware` - 工具处理
5. `WebSearchMiddleware` - Web搜索处理
6. `TextChunkMiddleware` - 文本处理
7. `ThinkingTagExtractionMiddleware` - 思考标签提取处理
8. `ThinkChunkMiddleware` - 思考处理
9. `ResponseTransformMiddleware` - 响应转换
10. `StreamAdapterMiddleware` - 流适配器
11. `SdkCallMiddleware` - SDK调用
## 在 AiProvider 中的使用
```typescript
export default class AiProvider {
public async completions(params: CompletionsParams): Promise<CompletionsResult> {
// 1. 构建中间件链
const builder = CompletionsMiddlewareBuilder.withDefaults()
// 2. 根据参数动态调整
if (params.enableCustomFeature) {
builder.insertAfter('StreamAdapterMiddleware', customFeatureMiddleware)
}
// 3. 应用中间件
const middlewares = builder.build()
const wrappedMethod = applyCompletionsMiddlewares(this.apiClient, this.apiClient.createCompletions, middlewares)
return wrappedMethod(params)
}
}
```
## 注意事项
1. **类型兼容性**`MethodMiddleware``CompletionsMiddleware` 不兼容,需要使用对应的构建器
2. **中间件名称**:所有中间件必须导出 `MIDDLEWARE_NAME` 常量用于标识
3. **注册表管理**:新增中间件需要在 `register.ts` 中注册
4. **默认链**:默认链通过 `DefaultCompletionsNamedMiddlewares` 提供,支持延迟加载避免循环依赖
这种设计使得中间件链的构建既灵活又类型安全,同时保持了简洁的 API 接口。

View File

@@ -1,175 +0,0 @@
# Cherry Studio 中间件规范
本文档定义了 Cherry Studio `aiCore` 模块中中间件的设计、实现和使用规范。目标是建立一个灵活、可维护且易于扩展的中间件系统。
## 1. 核心概念
### 1.1. 中间件 (Middleware)
中间件是一个函数或对象,它在 AI 请求的处理流程中的特定阶段执行,可以访问和修改请求上下文 (`AiProviderMiddlewareContext`)、请求参数 (`Params`),并控制是否将请求传递给下一个中间件或终止流程。
每个中间件应该专注于一个单一的横切关注点,例如日志记录、错误处理、流适配、特性解析等。
### 1.2. `AiProviderMiddlewareContext` (上下文对象)
这是一个在整个中间件链执行过程中传递的对象,包含以下核心信息:
- `_apiClientInstance: ApiClient<any,any,any>`: 当前选定的、已实例化的 AI Provider 客户端。
- `_coreRequest: CoreRequestType`: 标准化的内部核心请求对象。
- `resolvePromise: (value: AggregatedResultType) => void`: 用于在整个操作成功完成时解析 `AiCoreService` 返回的 Promise。
- `rejectPromise: (reason?: any) => void`: 用于在发生错误时拒绝 `AiCoreService` 返回的 Promise。
- `onChunk?: (chunk: Chunk) => void`: 应用层提供的流式数据块回调。
- `abortController?: AbortController`: 用于中止请求的控制器。
- 其他中间件可能读写的、与当前请求相关的动态数据。
### 1.3. `MiddlewareName` (中间件名称)
为了方便动态操作(如插入、替换、移除)中间件,每个重要的、可能被其他逻辑引用的中间件都应该有一个唯一的、可识别的名称。推荐使用 TypeScript 的 `enum` 来定义:
```typescript
// example
export enum MiddlewareName {
LOGGING_START = 'LoggingStartMiddleware',
LOGGING_END = 'LoggingEndMiddleware',
ERROR_HANDLING = 'ErrorHandlingMiddleware',
ABORT_HANDLER = 'AbortHandlerMiddleware',
// Core Flow
TRANSFORM_CORE_TO_SDK_PARAMS = 'TransformCoreToSdkParamsMiddleware',
REQUEST_EXECUTION = 'RequestExecutionMiddleware',
STREAM_ADAPTER = 'StreamAdapterMiddleware',
RAW_SDK_CHUNK_TO_APP_CHUNK = 'RawSdkChunkToAppChunkMiddleware',
// Features
THINKING_TAG_EXTRACTION = 'ThinkingTagExtractionMiddleware',
TOOL_USE_TAG_EXTRACTION = 'ToolUseTagExtractionMiddleware',
MCP_TOOL_HANDLER = 'McpToolHandlerMiddleware',
// Finalization
FINAL_CHUNK_CONSUMER = 'FinalChunkConsumerAndNotifierMiddleware'
// Add more as needed
}
```
中间件实例需要某种方式暴露其 `MiddlewareName`,例如通过一个 `name` 属性。
### 1.4. 中间件执行结构
我们采用一种灵活的中间件执行结构。一个中间件通常是一个函数,它接收 `Context``Params`,以及一个 `next` 函数(用于调用链中的下一个中间件)。
```typescript
// 简化形式的中间件函数签名
type MiddlewareFunction = (
context: AiProviderMiddlewareContext,
params: any, // e.g., CompletionsParams
next: () => Promise<void> // next 通常返回 Promise 以支持异步操作
) => Promise<void> // 中间件自身也可能返回 Promise
// 或者更经典的 Koa/Express 风格 (三段式)
// type MiddlewareFactory = (api?: MiddlewareApi) =>
// (nextMiddleware: (ctx: AiProviderMiddlewareContext, params: any) => Promise<void>) =>
// (context: AiProviderMiddlewareContext, params: any) => Promise<void>;
// 当前设计更倾向于上述简化的 MiddlewareFunction由 MiddlewareExecutor 负责 next 的编排。
```
`MiddlewareExecutor` (或 `applyMiddlewares`) 会负责管理 `next` 的调用。
## 2. `MiddlewareBuilder` (通用中间件构建器)
为了动态构建和管理中间件链,我们引入一个通用的 `MiddlewareBuilder` 类。
### 2.1. 设计理念
`MiddlewareBuilder` 提供了一个流式 API用于以声明式的方式构建中间件链。它允许从一个基础链开始然后根据特定条件添加、插入、替换或移除中间件。
### 2.2. API 概览
```typescript
class MiddlewareBuilder {
constructor(baseChain?: Middleware[])
add(middleware: Middleware): this
prepend(middleware: Middleware): this
insertAfter(targetName: MiddlewareName, middlewareToInsert: Middleware): this
insertBefore(targetName: MiddlewareName, middlewareToInsert: Middleware): this
replace(targetName: MiddlewareName, newMiddleware: Middleware): this
remove(targetName: MiddlewareName): this
build(): Middleware[] // 返回构建好的中间件数组
// 可选:直接执行链
execute(
context: AiProviderMiddlewareContext,
params: any,
middlewareExecutor: (chain: Middleware[], context: AiProviderMiddlewareContext, params: any) => void
): void
}
```
### 2.3. 使用示例
```typescript
// 1. 定义一些中间件实例 (假设它们有 .name 属性)
const loggingStart = { name: MiddlewareName.LOGGING_START, fn: loggingStartFn }
const requestExec = { name: MiddlewareName.REQUEST_EXECUTION, fn: requestExecFn }
const streamAdapter = { name: MiddlewareName.STREAM_ADAPTER, fn: streamAdapterFn }
const customFeature = { name: MiddlewareName.CUSTOM_FEATURE, fn: customFeatureFn } // 假设自定义
// 2. 定义一个基础链 (可选)
const BASE_CHAIN: Middleware[] = [loggingStart, requestExec, streamAdapter]
// 3. 使用 MiddlewareBuilder
const builder = new MiddlewareBuilder(BASE_CHAIN)
if (params.needsCustomFeature) {
builder.insertAfter(MiddlewareName.STREAM_ADAPTER, customFeature)
}
if (params.isHighSecurityContext) {
builder.insertBefore(MiddlewareName.REQUEST_EXECUTION, высокоSecurityCheckMiddleware)
}
if (params.overrideLogging) {
builder.replace(MiddlewareName.LOGGING_START, newSpecialLoggingMiddleware)
}
// 4. 获取最终链
const finalChain = builder.build()
// 5. 执行 (通过外部执行器)
// middlewareExecutor(finalChain, context, params);
// 或者 builder.execute(context, params, middlewareExecutor);
```
## 3. `MiddlewareExecutor` / `applyMiddlewares` (中间件执行器)
这是负责接收 `MiddlewareBuilder` 构建的中间件链并实际执行它们的组件。
### 3.1. 职责
- 接收 `Middleware[]`, `AiProviderMiddlewareContext`, `Params`
- 按顺序迭代中间件。
- 为每个中间件提供正确的 `next` 函数,该函数在被调用时会执行链中的下一个中间件。
- 处理中间件执行过程中的Promise如果中间件是异步的
- 基础的错误捕获(具体错误处理应由链内的 `ErrorHandlingMiddleware` 负责)。
## 4. 在 `AiCoreService` 中使用
`AiCoreService` 中的每个核心业务方法 (如 `executeCompletions`) 将负责:
1. 准备基础数据:实例化 `ApiClient`,转换 `Params``CoreRequest`
2. 实例化 `MiddlewareBuilder`,可能会传入一个特定于该业务方法的基础中间件链。
3. 根据 `Params``CoreRequest` 中的条件,调用 `MiddlewareBuilder` 的方法来动态调整中间件链。
4. 调用 `MiddlewareBuilder.build()` 获取最终的中间件链。
5. 创建完整的 `AiProviderMiddlewareContext` (包含 `resolvePromise`, `rejectPromise` 等)。
6. 调用 `MiddlewareExecutor` (或 `applyMiddlewares`) 来执行构建好的链。
## 5. 组合功能
对于组合功能(例如 "Completions then Translate"
- 不推荐创建一个单一、庞大的 `MiddlewareBuilder` 来处理整个组合流程。
- 推荐在 `AiCoreService` 中创建一个新的方法,该方法按顺序 `await` 调用底层的原子 `AiCoreService` 方法(例如,先 `await this.executeCompletions(...)`,然后用其结果 `await this.translateText(...)`)。
- 每个被调用的原子方法内部会使用其自身的 `MiddlewareBuilder` 实例来构建和执行其特定阶段的中间件链。
- 这种方式最大化了复用,并保持了各部分职责的清晰。
## 6. 中间件命名和发现
为中间件赋予唯一的 `MiddlewareName` 对于 `MiddlewareBuilder``insertAfter`, `insertBefore`, `replace`, `remove` 等操作至关重要。确保中间件实例能够以某种方式暴露其名称(例如,一个 `name` 属性)。

View File

@@ -1,188 +0,0 @@
import { AiPlugin, simulateStreamingMiddleware } from '@cherrystudio/ai-core'
import { isReasoningModel } from '@renderer/config/models'
import type { Model, Provider } from '@renderer/types'
import type { Chunk } from '@renderer/types/chunk'
import thinkingTimeMiddleware from './ThinkingTimeMiddleware'
/**
* AI SDK 中间件配置项
*/
export interface AiSdkMiddlewareConfig {
streamOutput?: boolean
onChunk?: (chunk: Chunk) => void
model?: Model
provider?: Provider
enableReasoning?: boolean
enableTool?: boolean
enableWebSearch?: boolean
}
/**
* 具名的 AI SDK 中间件
*/
export type NamedAiSdkMiddleware = AiPlugin
/**
* AI SDK 中间件建造者
* 用于根据不同条件动态构建中间件数组
*/
export class AiSdkMiddlewareBuilder {
private middlewares: NamedAiSdkMiddleware[] = []
/**
* 添加具名中间件
*/
public add(namedMiddleware: NamedAiSdkMiddleware): this {
this.middlewares.push(namedMiddleware)
return this
}
/**
* 在指定位置插入中间件
*/
public insertAfter(targetName: string, middleware: NamedAiSdkMiddleware): this {
const index = this.middlewares.findIndex((m) => m.name === targetName)
if (index !== -1) {
this.middlewares.splice(index + 1, 0, middleware)
} else {
console.warn(`AiSdkMiddlewareBuilder: 未找到名为 '${targetName}' 的中间件,无法插入`)
}
return this
}
/**
* 检查是否包含指定名称的中间件
*/
public has(name: string): boolean {
return this.middlewares.some((m) => m.name === name)
}
/**
* 移除指定名称的中间件
*/
public remove(name: string): this {
this.middlewares = this.middlewares.filter((m) => m.name !== name)
return this
}
/**
* 构建最终的中间件数组
*/
public build(): NamedAiSdkMiddleware[] {
return [...this.middlewares]
}
/**
* 清空所有中间件
*/
public clear(): this {
this.middlewares = []
return this
}
/**
* 获取中间件总数
*/
public get length(): number {
return this.middlewares.length
}
}
/**
* 根据配置构建AI SDK中间件的工厂函数
* 这里要注意构建顺序,因为有些中间件需要依赖其他中间件的结果
*/
export function buildAiSdkMiddlewares(config: AiSdkMiddlewareConfig): NamedAiSdkMiddleware[] {
const builder = new AiSdkMiddlewareBuilder()
// 1. 思考模型且有onChunk回调时添加思考时间中间件
if (config.onChunk && config.model && isReasoningModel(config.model)) {
builder.add({
name: 'thinking-time',
aiSdkMiddlewares: [thinkingTimeMiddleware()]
})
}
// 2. 可以在这里根据其他条件添加更多中间件
// 例如工具调用、Web搜索等相关中间件
// 3. 根据provider添加特定中间件
if (config.provider) {
addProviderSpecificMiddlewares(builder, config)
}
// 4. 根据模型类型添加特定中间件
if (config.model) {
addModelSpecificMiddlewares(builder, config)
}
// 5. 非流式输出时添加模拟流中间件
if (config.streamOutput === false) {
builder.add({
name: 'simulate-streaming',
aiSdkMiddlewares: [simulateStreamingMiddleware()]
})
}
return builder.build()
}
/**
* 添加provider特定的中间件
*/
function addProviderSpecificMiddlewares(builder: AiSdkMiddlewareBuilder, config: AiSdkMiddlewareConfig): void {
if (!config.provider) return
// 根据不同provider添加特定中间件
switch (config.provider.type) {
case 'anthropic':
// Anthropic特定中间件
break
case 'openai':
// OpenAI特定中间件
break
case 'gemini':
// Gemini特定中间件
break
default:
// 其他provider的通用处理
break
}
}
/**
* 添加模型特定的中间件
*/
function addModelSpecificMiddlewares(builder: AiSdkMiddlewareBuilder, config: AiSdkMiddlewareConfig): void {
if (!config.model) return
// 可以根据模型ID或特性添加特定中间件
// 例如:图像生成模型、多模态模型等
// 示例:某些模型需要特殊处理
if (config.model.id.includes('dalle') || config.model.id.includes('midjourney')) {
// 图像生成相关中间件
}
}
/**
* 创建一个预配置的中间件建造者
*/
export function createAiSdkMiddlewareBuilder(): AiSdkMiddlewareBuilder {
return new AiSdkMiddlewareBuilder()
}
/**
* 创建一个带有默认中间件的建造者
*/
export function createDefaultAiSdkMiddlewareBuilder(config: AiSdkMiddlewareConfig): AiSdkMiddlewareBuilder {
const builder = new AiSdkMiddlewareBuilder()
const defaultMiddlewares = buildAiSdkMiddlewares(config)
defaultMiddlewares.forEach((middleware) => {
builder.add(middleware)
})
return builder
}

View File

@@ -1,140 +0,0 @@
# AI SDK 中间件建造者
## 概述
`AiSdkMiddlewareBuilder` 是一个用于动态构建 AI SDK 中间件数组的建造者模式实现。它可以根据不同的条件如流式输出、思考模型、provider类型等自动构建合适的中间件组合。
## 使用方式
### 基本用法
```typescript
import { buildAiSdkMiddlewares, type AiSdkMiddlewareConfig } from './AiSdkMiddlewareBuilder'
// 配置中间件参数
const config: AiSdkMiddlewareConfig = {
streamOutput: false, // 非流式输出
onChunk: chunkHandler, // chunk回调函数
model: currentModel, // 当前模型
provider: currentProvider, // 当前provider
enableReasoning: true, // 启用推理
enableTool: false, // 禁用工具
enableWebSearch: false // 禁用网页搜索
}
// 构建中间件数组
const middlewares = buildAiSdkMiddlewares(config)
// 创建带有中间件的客户端
const client = createClient(providerId, options, middlewares)
```
### 手动构建
```typescript
import { AiSdkMiddlewareBuilder, createAiSdkMiddlewareBuilder } from './AiSdkMiddlewareBuilder'
const builder = createAiSdkMiddlewareBuilder()
// 添加特定中间件
builder.add({
name: 'custom-middleware',
aiSdkMiddlewares: [customMiddleware()]
})
// 检查是否包含某个中间件
if (builder.has('thinking-time')) {
console.log('已包含思考时间中间件')
}
// 移除不需要的中间件
builder.remove('simulate-streaming')
// 构建最终数组
const middlewares = builder.build()
```
## 支持的条件
### 1. 流式输出控制
- **streamOutput = false**: 自动添加 `simulateStreamingMiddleware`
- **streamOutput = true**: 使用原生流式处理
### 2. 思考模型处理
- **条件**: `onChunk` 存在 && `isReasoningModel(model)` 为 true
- **效果**: 自动添加 `thinkingTimeMiddleware`
### 3. Provider 特定中间件
根据不同的 provider 类型添加特定中间件:
- **anthropic**: Anthropic 特定处理
- **openai**: OpenAI 特定处理
- **gemini**: Gemini 特定处理
### 4. 模型特定中间件
根据模型特性添加中间件:
- **图像生成模型**: 添加图像处理相关中间件
- **多模态模型**: 添加多模态处理中间件
## 扩展指南
### 添加新的条件判断
`buildAiSdkMiddlewares` 函数中添加新的条件:
```typescript
// 例如:添加缓存中间件
if (config.enableCache) {
builder.add({
name: 'cache',
aiSdkMiddlewares: [cacheMiddleware(config.cacheOptions)]
})
}
```
### 添加 Provider 特定处理
`addProviderSpecificMiddlewares` 函数中添加:
```typescript
case 'custom-provider':
builder.add({
name: 'custom-provider-middleware',
aiSdkMiddlewares: [customProviderMiddleware()]
})
break
```
### 添加模型特定处理
`addModelSpecificMiddlewares` 函数中添加:
```typescript
if (config.model.id.includes('custom-model')) {
builder.add({
name: 'custom-model-middleware',
aiSdkMiddlewares: [customModelMiddleware()]
})
}
```
## 中间件执行顺序
中间件按照添加顺序执行:
1. **simulate-streaming** (如果 streamOutput = false)
2. **thinking-time** (如果是思考模型且有 onChunk)
3. **provider-specific** (根据 provider 类型)
4. **model-specific** (根据模型类型)
## 注意事项
1. 中间件的执行顺序很重要,确保按正确顺序添加
2. 避免添加冲突的中间件
3. 某些中间件可能有依赖关系,需要确保依赖的中间件先添加
4. 建议在开发环境下启用日志,以便调试中间件构建过程

View File

@@ -1,67 +0,0 @@
import { LanguageModelV1Middleware, LanguageModelV1StreamPart } from '@cherrystudio/ai-core'
import { ChunkType, ThinkingCompleteChunk } from '@renderer/types/chunk'
/**
* 一个用于统计 LLM "思考时间"Time to First Token的 AI SDK 中间件。
*
* 工作原理:
* 1. 在 `stream` 方法被调用时,记录一个起始时间。
* 2. 它会创建一个新的 `TransformStream` 来代理原始的流。
* 3. 当第一个数据块 (chunk) 从原始流中到达时,记录结束时间。
* 4. 计算两者之差,即为 "思考时间"
* 这里只处理了thinking_complete
*/
export default function thinkingTimeMiddleware(): LanguageModelV1Middleware {
return {
wrapStream: async ({ doStream }) => {
let hasThinkingContent = false
let thinkingStartTime = 0
let accumulatedThinkingContent = ''
const { stream, ...reset } = await doStream()
const transformStream = new TransformStream<LanguageModelV1StreamPart, any>({
transform(chunk, controller) {
if (chunk.type === 'reasoning' || chunk.type === 'redacted-reasoning') {
if (!hasThinkingContent) {
hasThinkingContent = true
thinkingStartTime = Date.now()
}
accumulatedThinkingContent += chunk.textDelta || ''
} else {
if (hasThinkingContent && thinkingStartTime > 0) {
const thinkingTime = Date.now() - thinkingStartTime
const thinkingCompleteChunk = {
type: 'reasoning-signature',
text: accumulatedThinkingContent,
thinking_millsec: thinkingTime
}
controller.enqueue(thinkingCompleteChunk)
hasThinkingContent = false
thinkingStartTime = 0
accumulatedThinkingContent = ''
}
}
// 将所有 chunk 原样传递下去
controller.enqueue(chunk)
},
flush(controller) {
// 如果流的末尾都是 reasoning也需要发送 complete 事件
if (hasThinkingContent && thinkingStartTime > 0) {
const thinkingTime = Date.now() - thinkingStartTime
const thinkingCompleteChunk: ThinkingCompleteChunk = {
type: ChunkType.THINKING_COMPLETE,
text: accumulatedThinkingContent,
thinking_millsec: thinkingTime
}
controller.enqueue(thinkingCompleteChunk)
}
controller.terminate()
}
})
return {
stream: stream.pipeThrough(transformStream),
...reset
}
}
}
}

View File

@@ -1,241 +0,0 @@
import { DefaultCompletionsNamedMiddlewares } from './register'
import { BaseContext, CompletionsMiddleware, MethodMiddleware } from './types'
/**
* 带有名称标识的中间件接口
*/
export interface NamedMiddleware<TMiddleware = any> {
name: string
middleware: TMiddleware
}
/**
* 中间件执行器函数类型
*/
export type MiddlewareExecutor<TContext extends BaseContext = BaseContext> = (
chain: any[],
context: TContext,
params: any
) => Promise<any>
/**
* 通用中间件构建器类
* 提供流式 API 用于动态构建和管理中间件链
*
* 注意:所有中间件都通过 MiddlewareRegistry 管理,使用 NamedMiddleware 格式
*/
export class MiddlewareBuilder<TMiddleware = any> {
private middlewares: NamedMiddleware<TMiddleware>[]
/**
* 构造函数
* @param baseChain - 可选的基础中间件链NamedMiddleware 格式)
*/
constructor(baseChain?: NamedMiddleware<TMiddleware>[]) {
this.middlewares = baseChain ? [...baseChain] : []
}
/**
* 在链的末尾添加中间件
* @param middleware - 要添加的具名中间件
* @returns this支持链式调用
*/
add(middleware: NamedMiddleware<TMiddleware>): this {
this.middlewares.push(middleware)
return this
}
/**
* 在链的开头添加中间件
* @param middleware - 要添加的具名中间件
* @returns this支持链式调用
*/
prepend(middleware: NamedMiddleware<TMiddleware>): this {
this.middlewares.unshift(middleware)
return this
}
/**
* 在指定中间件之后插入新中间件
* @param targetName - 目标中间件名称
* @param middlewareToInsert - 要插入的具名中间件
* @returns this支持链式调用
*/
insertAfter(targetName: string, middlewareToInsert: NamedMiddleware<TMiddleware>): this {
const index = this.findMiddlewareIndex(targetName)
if (index !== -1) {
this.middlewares.splice(index + 1, 0, middlewareToInsert)
} else {
console.warn(`MiddlewareBuilder: 未找到名为 '${targetName}' 的中间件,无法插入`)
}
return this
}
/**
* 在指定中间件之前插入新中间件
* @param targetName - 目标中间件名称
* @param middlewareToInsert - 要插入的具名中间件
* @returns this支持链式调用
*/
insertBefore(targetName: string, middlewareToInsert: NamedMiddleware<TMiddleware>): this {
const index = this.findMiddlewareIndex(targetName)
if (index !== -1) {
this.middlewares.splice(index, 0, middlewareToInsert)
} else {
console.warn(`MiddlewareBuilder: 未找到名为 '${targetName}' 的中间件,无法插入`)
}
return this
}
/**
* 替换指定的中间件
* @param targetName - 要替换的中间件名称
* @param newMiddleware - 新的具名中间件
* @returns this支持链式调用
*/
replace(targetName: string, newMiddleware: NamedMiddleware<TMiddleware>): this {
const index = this.findMiddlewareIndex(targetName)
if (index !== -1) {
this.middlewares[index] = newMiddleware
} else {
console.warn(`MiddlewareBuilder: 未找到名为 '${targetName}' 的中间件,无法替换`)
}
return this
}
/**
* 移除指定的中间件
* @param targetName - 要移除的中间件名称
* @returns this支持链式调用
*/
remove(targetName: string): this {
const index = this.findMiddlewareIndex(targetName)
if (index !== -1) {
this.middlewares.splice(index, 1)
}
return this
}
/**
* 构建最终的中间件数组
* @returns 构建好的中间件数组
*/
build(): TMiddleware[] {
return this.middlewares.map((item) => item.middleware)
}
/**
* 获取当前中间件链的副本(包含名称信息)
* @returns 当前中间件链的副本
*/
getChain(): NamedMiddleware<TMiddleware>[] {
return [...this.middlewares]
}
/**
* 检查是否包含指定名称的中间件
* @param name - 中间件名称
* @returns 是否包含该中间件
*/
has(name: string): boolean {
return this.findMiddlewareIndex(name) !== -1
}
/**
* 获取中间件链的长度
* @returns 中间件数量
*/
get length(): number {
return this.middlewares.length
}
/**
* 清空中间件链
* @returns this支持链式调用
*/
clear(): this {
this.middlewares = []
return this
}
/**
* 直接执行构建好的中间件链
* @param context - 中间件上下文
* @param params - 参数
* @param middlewareExecutor - 中间件执行器
* @returns 执行结果
*/
execute<TContext extends BaseContext>(
context: TContext,
params: any,
middlewareExecutor: MiddlewareExecutor<TContext>
): Promise<any> {
const chain = this.build()
return middlewareExecutor(chain, context, params)
}
/**
* 查找中间件在链中的索引
* @param name - 中间件名称
* @returns 索引,如果未找到返回 -1
*/
private findMiddlewareIndex(name: string): number {
return this.middlewares.findIndex((item) => item.name === name)
}
}
/**
* Completions 中间件构建器
*/
export class CompletionsMiddlewareBuilder extends MiddlewareBuilder<CompletionsMiddleware> {
constructor(baseChain?: NamedMiddleware<CompletionsMiddleware>[]) {
super(baseChain)
}
/**
* 使用默认的 Completions 中间件链
* @returns CompletionsMiddlewareBuilder 实例
*/
static withDefaults(): CompletionsMiddlewareBuilder {
return new CompletionsMiddlewareBuilder(DefaultCompletionsNamedMiddlewares)
}
}
/**
* 通用方法中间件构建器
*/
export class MethodMiddlewareBuilder extends MiddlewareBuilder<MethodMiddleware> {
constructor(baseChain?: NamedMiddleware<MethodMiddleware>[]) {
super(baseChain)
}
}
// 便捷的工厂函数
/**
* 创建 Completions 中间件构建器
* @param baseChain - 可选的基础链
* @returns Completions 中间件构建器实例
*/
export function createCompletionsBuilder(
baseChain?: NamedMiddleware<CompletionsMiddleware>[]
): CompletionsMiddlewareBuilder {
return new CompletionsMiddlewareBuilder(baseChain)
}
/**
* 创建通用方法中间件构建器
* @param baseChain - 可选的基础链
* @returns 通用方法中间件构建器实例
*/
export function createMethodBuilder(baseChain?: NamedMiddleware<MethodMiddleware>[]): MethodMiddlewareBuilder {
return new MethodMiddlewareBuilder(baseChain)
}
/**
* 为中间件添加名称属性的辅助函数
* 可以用于给现有的中间件添加名称属性
*/
export function addMiddlewareName<T extends object>(middleware: T, name: string): T & { MIDDLEWARE_NAME: string } {
return Object.assign(middleware, { MIDDLEWARE_NAME: name })
}

View File

@@ -1,106 +0,0 @@
import { Chunk, ChunkType, ErrorChunk } from '@renderer/types/chunk'
import { addAbortController, removeAbortController } from '@renderer/utils/abortController'
import { CompletionsParams, CompletionsResult } from '../schemas'
import type { CompletionsContext, CompletionsMiddleware } from '../types'
export const MIDDLEWARE_NAME = 'AbortHandlerMiddleware'
export const AbortHandlerMiddleware: CompletionsMiddleware =
() =>
(next) =>
async (ctx: CompletionsContext, params: CompletionsParams): Promise<CompletionsResult> => {
const isRecursiveCall = ctx._internal?.toolProcessingState?.isRecursiveCall || false
// 在递归调用中,跳过 AbortController 的创建,直接使用已有的
if (isRecursiveCall) {
const result = await next(ctx, params)
return result
}
// 获取当前消息的ID用于abort管理
// 优先使用处理过的消息,如果没有则使用原始消息
let messageId: string | undefined
if (typeof params.messages === 'string') {
messageId = `message-${Date.now()}-${Math.random().toString(36).substring(2, 9)}`
} else {
const processedMessages = params.messages
const lastUserMessage = processedMessages.findLast((m) => m.role === 'user')
messageId = lastUserMessage?.id
}
if (!messageId) {
console.warn(`[${MIDDLEWARE_NAME}] No messageId found, abort functionality will not be available.`)
return next(ctx, params)
}
const abortController = new AbortController()
const abortFn = (): void => abortController.abort()
addAbortController(messageId, abortFn)
let abortSignal: AbortSignal | null = abortController.signal
const cleanup = (): void => {
removeAbortController(messageId as string, abortFn)
if (ctx._internal?.flowControl) {
ctx._internal.flowControl.abortController = undefined
ctx._internal.flowControl.abortSignal = undefined
ctx._internal.flowControl.cleanup = undefined
}
abortSignal = null
}
// 将controller添加到_internal中的flowControl状态
if (!ctx._internal.flowControl) {
ctx._internal.flowControl = {}
}
ctx._internal.flowControl.abortController = abortController
ctx._internal.flowControl.abortSignal = abortSignal
ctx._internal.flowControl.cleanup = cleanup
const result = await next(ctx, params)
const error = new DOMException('Request was aborted', 'AbortError')
const streamWithAbortHandler = (result.stream as ReadableStream<Chunk>).pipeThrough(
new TransformStream<Chunk, Chunk | ErrorChunk>({
transform(chunk, controller) {
// 检查 abort 状态
if (abortSignal?.aborted) {
// 转换为 ErrorChunk
const errorChunk: ErrorChunk = {
type: ChunkType.ERROR,
error
}
controller.enqueue(errorChunk)
cleanup()
return
}
// 正常传递 chunk
controller.enqueue(chunk)
},
flush(controller) {
// 在流结束时再次检查 abort 状态
if (abortSignal?.aborted) {
const errorChunk: ErrorChunk = {
type: ChunkType.ERROR,
error
}
controller.enqueue(errorChunk)
}
// 在流完全处理完成后清理 AbortController
cleanup()
}
})
)
return {
...result,
stream: streamWithAbortHandler
}
}

View File

@@ -1,56 +0,0 @@
import { Chunk } from '@renderer/types/chunk'
import { CompletionsResult } from '../schemas'
import { CompletionsContext } from '../types'
import { createErrorChunk } from '../utils'
export const MIDDLEWARE_NAME = 'ErrorHandlerMiddleware'
/**
* 创建一个错误处理中间件。
*
* 这是一个高阶函数,它接收配置并返回一个标准的中间件。
* 它的主要职责是捕获下游中间件或API调用中发生的任何错误。
*
* @param config - 中间件的配置。
* @returns 一个配置好的CompletionsMiddleware。
*/
export const ErrorHandlerMiddleware =
() =>
(next) =>
async (ctx: CompletionsContext, params): Promise<CompletionsResult> => {
const { shouldThrow } = params
try {
// 尝试执行下一个中间件
return await next(ctx, params)
} catch (error: any) {
console.log('ErrorHandlerMiddleware_error', error)
// 1. 使用通用的工具函数将错误解析为标准格式
const errorChunk = createErrorChunk(error)
// 2. 调用从外部传入的 onError 回调
if (params.onError) {
params.onError(error)
}
// 3. 根据配置决定是重新抛出错误,还是将其作为流的一部分向下传递
if (shouldThrow) {
throw error
}
// 如果不抛出,则创建一个只包含该错误块的流并向下传递
const errorStream = new ReadableStream<Chunk>({
start(controller) {
controller.enqueue(errorChunk)
controller.close()
}
})
return {
rawOutput: undefined,
stream: errorStream, // 将包含错误的流传递下去
controller: undefined,
getText: () => '' // 错误情况下没有文本结果
}
}
}

View File

@@ -1,183 +0,0 @@
import Logger from '@renderer/config/logger'
import { Usage } from '@renderer/types'
import type { Chunk } from '@renderer/types/chunk'
import { ChunkType } from '@renderer/types/chunk'
import { CompletionsParams, CompletionsResult, GenericChunk } from '../schemas'
import { CompletionsContext, CompletionsMiddleware } from '../types'
export const MIDDLEWARE_NAME = 'FinalChunkConsumerAndNotifierMiddleware'
/**
* 最终Chunk消费和通知中间件
*
* 职责:
* 1. 消费所有GenericChunk流中的chunks并转发给onChunk回调
* 2. 累加usage/metrics数据从原始SDK chunks或GenericChunk中提取
* 3. 在检测到LLM_RESPONSE_COMPLETE时发送包含累计数据的BLOCK_COMPLETE
* 4. 处理MCP工具调用的多轮请求中的数据累加
*/
const FinalChunkConsumerMiddleware: CompletionsMiddleware =
() =>
(next) =>
async (ctx: CompletionsContext, params: CompletionsParams): Promise<CompletionsResult> => {
const isRecursiveCall =
params._internal?.toolProcessingState?.isRecursiveCall ||
ctx._internal?.toolProcessingState?.isRecursiveCall ||
false
// 初始化累计数据(只在顶层调用时初始化)
if (!isRecursiveCall) {
if (!ctx._internal.customState) {
ctx._internal.customState = {}
}
ctx._internal.observer = {
usage: {
prompt_tokens: 0,
completion_tokens: 0,
total_tokens: 0
},
metrics: {
completion_tokens: 0,
time_completion_millsec: 0,
time_first_token_millsec: 0,
time_thinking_millsec: 0
}
}
// 初始化文本累积器
ctx._internal.customState.accumulatedText = ''
ctx._internal.customState.startTimestamp = Date.now()
}
// 调用下游中间件
const result = await next(ctx, params)
// 响应后处理处理GenericChunk流式响应
if (result.stream) {
const resultFromUpstream = result.stream
if (resultFromUpstream && resultFromUpstream instanceof ReadableStream) {
const reader = resultFromUpstream.getReader()
try {
while (true) {
const { done, value: chunk } = await reader.read()
if (done) {
Logger.debug(`[${MIDDLEWARE_NAME}] Input stream finished.`)
break
}
if (chunk) {
const genericChunk = chunk as GenericChunk
// 提取并累加usage/metrics数据
extractAndAccumulateUsageMetrics(ctx, genericChunk)
const shouldSkipChunk =
isRecursiveCall &&
(genericChunk.type === ChunkType.BLOCK_COMPLETE ||
genericChunk.type === ChunkType.LLM_RESPONSE_COMPLETE)
if (!shouldSkipChunk) params.onChunk?.(genericChunk)
} else {
Logger.warn(`[${MIDDLEWARE_NAME}] Received undefined chunk before stream was done.`)
}
}
} catch (error) {
Logger.error(`[${MIDDLEWARE_NAME}] Error consuming stream:`, error)
throw error
} finally {
if (params.onChunk && !isRecursiveCall) {
params.onChunk({
type: ChunkType.BLOCK_COMPLETE,
response: {
usage: ctx._internal.observer?.usage ? { ...ctx._internal.observer.usage } : undefined,
metrics: ctx._internal.observer?.metrics ? { ...ctx._internal.observer.metrics } : undefined
}
} as Chunk)
if (ctx._internal.toolProcessingState) {
ctx._internal.toolProcessingState = {}
}
}
}
// 为流式输出添加getText方法
const modifiedResult = {
...result,
stream: new ReadableStream<GenericChunk>({
start(controller) {
controller.close()
}
}),
getText: () => {
return ctx._internal.customState?.accumulatedText || ''
}
}
return modifiedResult
} else {
Logger.debug(`[${MIDDLEWARE_NAME}] No GenericChunk stream to process.`)
}
}
return result
}
/**
* 从GenericChunk或原始SDK chunks中提取usage/metrics数据并累加
*/
function extractAndAccumulateUsageMetrics(ctx: CompletionsContext, chunk: GenericChunk): void {
if (!ctx._internal.observer?.usage || !ctx._internal.observer?.metrics) {
return
}
try {
if (ctx._internal.customState && !ctx._internal.customState?.firstTokenTimestamp) {
ctx._internal.customState.firstTokenTimestamp = Date.now()
Logger.debug(`[${MIDDLEWARE_NAME}] First token timestamp: ${ctx._internal.customState.firstTokenTimestamp}`)
}
if (chunk.type === ChunkType.LLM_RESPONSE_COMPLETE) {
Logger.debug(`[${MIDDLEWARE_NAME}] LLM_RESPONSE_COMPLETE chunk received:`, ctx._internal)
// 从LLM_RESPONSE_COMPLETE chunk中提取usage数据
if (chunk.response?.usage) {
accumulateUsage(ctx._internal.observer.usage, chunk.response.usage)
}
if (ctx._internal.customState && ctx._internal.customState?.firstTokenTimestamp) {
ctx._internal.observer.metrics.time_first_token_millsec =
ctx._internal.customState.firstTokenTimestamp - ctx._internal.customState.startTimestamp
ctx._internal.observer.metrics.time_completion_millsec +=
Date.now() - ctx._internal.customState.firstTokenTimestamp
}
}
// 也可以从其他chunk类型中提取metrics数据
if (chunk.type === ChunkType.THINKING_COMPLETE && chunk.thinking_millsec && ctx._internal.observer?.metrics) {
ctx._internal.observer.metrics.time_thinking_millsec = Math.max(
ctx._internal.observer.metrics.time_thinking_millsec || 0,
chunk.thinking_millsec
)
}
} catch (error) {
console.error(`[${MIDDLEWARE_NAME}] Error extracting usage/metrics from chunk:`, error)
}
}
/**
* 累加usage数据
*/
function accumulateUsage(accumulated: Usage, newUsage: Usage): void {
if (newUsage.prompt_tokens !== undefined) {
accumulated.prompt_tokens += newUsage.prompt_tokens
}
if (newUsage.completion_tokens !== undefined) {
accumulated.completion_tokens += newUsage.completion_tokens
}
if (newUsage.total_tokens !== undefined) {
accumulated.total_tokens += newUsage.total_tokens
}
if (newUsage.thoughts_tokens !== undefined) {
accumulated.thoughts_tokens = (accumulated.thoughts_tokens || 0) + newUsage.thoughts_tokens
}
}
export default FinalChunkConsumerMiddleware

View File

@@ -1,64 +0,0 @@
import { BaseContext, MethodMiddleware, MiddlewareAPI } from '../types'
export const MIDDLEWARE_NAME = 'GenericLoggingMiddlewares'
/**
* Helper function to safely stringify arguments for logging, handling circular references and large objects.
* 安全地字符串化日志参数的辅助函数,处理循环引用和大型对象。
* @param args - The arguments array to stringify. 要字符串化的参数数组。
* @returns A string representation of the arguments. 参数的字符串表示形式。
*/
const stringifyArgsForLogging = (args: any[]): string => {
try {
return args
.map((arg) => {
if (typeof arg === 'function') return '[Function]'
if (typeof arg === 'object' && arg !== null && arg.constructor === Object && Object.keys(arg).length > 20) {
return '[Object with >20 keys]'
}
// Truncate long strings to avoid flooding logs 截断长字符串以避免日志泛滥
const stringifiedArg = JSON.stringify(arg, null, 2)
return stringifiedArg && stringifiedArg.length > 200 ? stringifiedArg.substring(0, 200) + '...' : stringifiedArg
})
.join(', ')
} catch (e) {
return '[Error serializing arguments]' // Handle potential errors during stringification 处理字符串化期间的潜在错误
}
}
/**
* Generic logging middleware for provider methods.
* 为提供者方法创建一个通用的日志中间件。
* This middleware logs the initiation, success/failure, and duration of a method call.
* 此中间件记录方法调用的启动、成功/失败以及持续时间。
*/
/**
* Creates a generic logging middleware for provider methods.
* 为提供者方法创建一个通用的日志中间件。
* @returns A `MethodMiddleware` instance. 一个 `MethodMiddleware` 实例。
*/
export const createGenericLoggingMiddleware: () => MethodMiddleware = () => {
const middlewareName = 'GenericLoggingMiddleware'
// eslint-disable-next-line @typescript-eslint/no-unused-vars
return (_: MiddlewareAPI<BaseContext, any[]>) => (next) => async (ctx, args) => {
const methodName = ctx.methodName
const logPrefix = `[${middlewareName} (${methodName})]`
console.log(`${logPrefix} Initiating. Args:`, stringifyArgsForLogging(args))
const startTime = Date.now()
try {
const result = await next(ctx, args)
const duration = Date.now() - startTime
// Log successful completion of the method call with duration. /
// 记录方法调用成功完成及其持续时间。
console.log(`${logPrefix} Successful. Duration: ${duration}ms`)
return result
} catch (error) {
const duration = Date.now() - startTime
// Log failure of the method call with duration and error information. /
// 记录方法调用失败及其持续时间和错误信息。
console.error(`${logPrefix} Failed. Duration: ${duration}ms`, error)
throw error // Re-throw the error to be handled by subsequent layers or the caller / 重新抛出错误,由后续层或调用者处理
}
}
}

View File

@@ -1,285 +0,0 @@
import {
RequestOptions,
SdkInstance,
SdkMessageParam,
SdkParams,
SdkRawChunk,
SdkRawOutput,
SdkTool,
SdkToolCall
} from '@renderer/types/sdk'
import { BaseApiClient } from '../clients'
import { CompletionsParams, CompletionsResult } from './schemas'
import {
BaseContext,
CompletionsContext,
CompletionsMiddleware,
MethodMiddleware,
MIDDLEWARE_CONTEXT_SYMBOL,
MiddlewareAPI
} from './types'
/**
* Creates the initial context for a method call, populating method-specific fields. /
* 为方法调用创建初始上下文,并填充特定于该方法的字段。
* @param methodName - The name of the method being called. / 被调用的方法名。
* @param originalCallArgs - The actual arguments array from the proxy/method call. / 代理/方法调用的实际参数数组。
* @param providerId - The ID of the provider, if available. / 提供者的ID如果可用
* @param providerInstance - The instance of the provider. / 提供者实例。
* @param specificContextFactory - An optional factory function to create a specific context type from the base context and original call arguments. / 一个可选的工厂函数,用于从基础上下文和原始调用参数创建特定的上下文类型。
* @returns The created context object. / 创建的上下文对象。
*/
function createInitialCallContext<TContext extends BaseContext, TCallArgs extends unknown[]>(
methodName: string,
originalCallArgs: TCallArgs, // Renamed from originalArgs to avoid confusion with context.originalArgs
// Factory to create specific context from base and the *original call arguments array*
specificContextFactory?: (base: BaseContext, callArgs: TCallArgs) => TContext
): TContext {
const baseContext: BaseContext = {
[MIDDLEWARE_CONTEXT_SYMBOL]: true,
methodName,
originalArgs: originalCallArgs // Store the full original arguments array in the context
}
if (specificContextFactory) {
return specificContextFactory(baseContext, originalCallArgs)
}
return baseContext as TContext // Fallback to base context if no specific factory
}
/**
* Composes an array of functions from right to left. /
* 从右到左组合一个函数数组。
* `compose(f, g, h)` is `(...args) => f(g(h(...args)))`. /
* `compose(f, g, h)` 等同于 `(...args) => f(g(h(...args)))`。
* Each function in funcs is expected to take the result of the next function
* (or the initial value for the rightmost function) as its argument. /
* `funcs` 中的每个函数都期望接收下一个函数的结果(或最右侧函数的初始值)作为其参数。
* @param funcs - Array of functions to compose. / 要组合的函数数组。
* @returns The composed function. / 组合后的函数。
*/
function compose(...funcs: Array<(...args: any[]) => any>): (...args: any[]) => any {
if (funcs.length === 0) {
// If no functions to compose, return a function that returns its first argument, or undefined if no args. /
// 如果没有要组合的函数则返回一个函数该函数返回其第一个参数如果没有参数则返回undefined。
return (...args: any[]) => (args.length > 0 ? args[0] : undefined)
}
if (funcs.length === 1) {
return funcs[0]
}
return funcs.reduce(
(a, b) =>
(...args: any[]) =>
a(b(...args))
)
}
/**
* Applies an array of Redux-style middlewares to a generic provider method. /
* 将一组Redux风格的中间件应用于一个通用的提供者方法。
* This version keeps arguments as an array throughout the middleware chain. /
* 此版本在整个中间件链中将参数保持为数组形式。
* @param originalProviderInstance - The original provider instance. / 原始提供者实例。
* @param methodName - The name of the method to be enhanced. / 需要增强的方法名。
* @param originalMethod - The original method to be wrapped. / 需要包装的原始方法。
* @param middlewares - An array of `ProviderMethodMiddleware` to apply. / 要应用的 `ProviderMethodMiddleware` 数组。
* @param specificContextFactory - An optional factory to create a specific context for this method. / 可选的工厂函数,用于为此方法创建特定的上下文。
* @returns An enhanced method with the middlewares applied. / 应用了中间件的增强方法。
*/
export function applyMethodMiddlewares<
TArgs extends unknown[] = unknown[], // Original method's arguments array type / 原始方法的参数数组类型
TResult = unknown,
TContext extends BaseContext = BaseContext
>(
methodName: string,
originalMethod: (...args: TArgs) => Promise<TResult>,
middlewares: MethodMiddleware[], // Expects generic middlewares / 期望通用中间件
specificContextFactory?: (base: BaseContext, callArgs: TArgs) => TContext
): (...args: TArgs) => Promise<TResult> {
// Returns a function matching the original method signature. /
// 返回一个与原始方法签名匹配的函数。
return async function enhancedMethod(...methodCallArgs: TArgs): Promise<TResult> {
const ctx = createInitialCallContext<TContext, TArgs>(
methodName,
methodCallArgs, // Pass the actual call arguments array / 传递实际的调用参数数组
specificContextFactory
)
const api: MiddlewareAPI<TContext, TArgs> = {
getContext: () => ctx,
getOriginalArgs: () => methodCallArgs // API provides the original arguments array / API提供原始参数数组
}
// `finalDispatch` is the function that will ultimately call the original provider method. /
// `finalDispatch` 是最终将调用原始提供者方法的函数。
// It receives the current context and arguments, which may have been transformed by middlewares. /
// 它接收当前的上下文和参数,这些参数可能已被中间件转换。
const finalDispatch = async (
_: TContext,
currentArgs: TArgs // Generic final dispatch expects args array / 通用finalDispatch期望参数数组
): Promise<TResult> => {
return originalMethod.apply(currentArgs)
}
const chain = middlewares.map((middleware) => middleware(api)) // Cast API if TContext/TArgs mismatch general ProviderMethodMiddleware / 如果TContext/TArgs与通用的ProviderMethodMiddleware不匹配则转换API
const composedMiddlewareLogic = compose(...chain)
const enhancedDispatch = composedMiddlewareLogic(finalDispatch)
return enhancedDispatch(ctx, methodCallArgs) // Pass context and original args array / 传递上下文和原始参数数组
}
}
/**
* Applies an array of `CompletionsMiddleware` to the `completions` method. /
* 将一组 `CompletionsMiddleware` 应用于 `completions` 方法。
* This version adapts for `CompletionsMiddleware` expecting a single `params` object. /
* 此版本适配了期望单个 `params` 对象的 `CompletionsMiddleware`。
* @param originalProviderInstance - The original provider instance. / 原始提供者实例。
* @param originalCompletionsMethod - The original SDK `createCompletions` method. / 原始的 SDK `createCompletions` 方法。
* @param middlewares - An array of `CompletionsMiddleware` to apply. / 要应用的 `CompletionsMiddleware` 数组。
* @returns An enhanced `completions` method with the middlewares applied. / 应用了中间件的增强版 `completions` 方法。
*/
export function applyCompletionsMiddlewares<
TSdkInstance extends SdkInstance = SdkInstance,
TSdkParams extends SdkParams = SdkParams,
TRawOutput extends SdkRawOutput = SdkRawOutput,
TRawChunk extends SdkRawChunk = SdkRawChunk,
TMessageParam extends SdkMessageParam = SdkMessageParam,
TToolCall extends SdkToolCall = SdkToolCall,
TSdkSpecificTool extends SdkTool = SdkTool
>(
originalApiClientInstance: BaseApiClient<
TSdkInstance,
TSdkParams,
TRawOutput,
TRawChunk,
TMessageParam,
TToolCall,
TSdkSpecificTool
>,
originalCompletionsMethod: (payload: TSdkParams, options?: RequestOptions) => Promise<TRawOutput>,
middlewares: CompletionsMiddleware<
TSdkParams,
TMessageParam,
TToolCall,
TSdkInstance,
TRawOutput,
TRawChunk,
TSdkSpecificTool
>[]
): (params: CompletionsParams, options?: RequestOptions) => Promise<CompletionsResult> {
// Returns a function matching the original method signature. /
// 返回一个与原始方法签名匹配的函数。
const methodName = 'completions'
// Factory to create AiProviderMiddlewareCompletionsContext. /
// 用于创建 AiProviderMiddlewareCompletionsContext 的工厂函数。
const completionsContextFactory = (
base: BaseContext,
callArgs: [CompletionsParams]
): CompletionsContext<
TSdkParams,
TMessageParam,
TToolCall,
TSdkInstance,
TRawOutput,
TRawChunk,
TSdkSpecificTool
> => {
return {
...base,
methodName,
apiClientInstance: originalApiClientInstance,
originalArgs: callArgs,
_internal: {
toolProcessingState: {
recursionDepth: 0,
isRecursiveCall: false
},
observer: {}
}
}
}
return async function enhancedCompletionsMethod(
params: CompletionsParams,
options?: RequestOptions
): Promise<CompletionsResult> {
// `originalCallArgs` for context creation is `[params]`. /
// 用于上下文创建的 `originalCallArgs` 是 `[params]`。
const originalCallArgs: [CompletionsParams] = [params]
const baseContext: BaseContext = {
[MIDDLEWARE_CONTEXT_SYMBOL]: true,
methodName,
originalArgs: originalCallArgs
}
const ctx = completionsContextFactory(baseContext, originalCallArgs)
const api: MiddlewareAPI<
CompletionsContext<TSdkParams, TMessageParam, TToolCall, TSdkInstance, TRawOutput, TRawChunk, TSdkSpecificTool>,
[CompletionsParams]
> = {
getContext: () => ctx,
getOriginalArgs: () => originalCallArgs // API provides [CompletionsParams] / API提供 `[CompletionsParams]`
}
// `finalDispatch` for CompletionsMiddleware: expects (context, params) not (context, args_array). /
// `CompletionsMiddleware` 的 `finalDispatch`:期望 (context, params) 而不是 (context, args_array)。
const finalDispatch = async (
context: CompletionsContext<
TSdkParams,
TMessageParam,
TToolCall,
TSdkInstance,
TRawOutput,
TRawChunk,
TSdkSpecificTool
> // Context passed through / 上下文透传
// _currentParams: CompletionsParams // Directly takes params / 直接接收参数 (unused but required for middleware signature)
): Promise<CompletionsResult> => {
// At this point, middleware should have transformed CompletionsParams to SDK params
// and stored them in context. If no transformation happened, we need to handle it.
// 此时,中间件应该已经将 CompletionsParams 转换为 SDK 参数并存储在上下文中。
// 如果没有进行转换,我们需要处理它。
const sdkPayload = context._internal?.sdkPayload
if (!sdkPayload) {
throw new Error('SDK payload not found in context. Middleware chain should have transformed parameters.')
}
const abortSignal = context._internal.flowControl?.abortSignal
const timeout = context._internal.customState?.sdkMetadata?.timeout
// Call the original SDK method with transformed parameters
// 使用转换后的参数调用原始 SDK 方法
const rawOutput = await originalCompletionsMethod.call(originalApiClientInstance, sdkPayload, {
...options,
signal: abortSignal,
timeout
})
// Return result wrapped in CompletionsResult format
// 以 CompletionsResult 格式返回包装的结果
return {
rawOutput
} as CompletionsResult
}
const chain = middlewares.map((middleware) => middleware(api))
const composedMiddlewareLogic = compose(...chain)
// `enhancedDispatch` has the signature `(context, params) => Promise<CompletionsResult>`. /
// `enhancedDispatch` 的签名为 `(context, params) => Promise<CompletionsResult>`。
const enhancedDispatch = composedMiddlewareLogic(finalDispatch)
// 将 enhancedDispatch 保存到 context 中,供中间件进行递归调用
// 这样可以避免重复执行整个中间件链
ctx._internal.enhancedDispatch = enhancedDispatch
// Execute with context and the single params object. /
// 使用上下文和单个参数对象执行。
return enhancedDispatch(ctx, params)
}
}

View File

@@ -1,306 +0,0 @@
import Logger from '@renderer/config/logger'
import { MCPTool, MCPToolResponse, Model, ToolCallResponse } from '@renderer/types'
import { ChunkType, MCPToolCreatedChunk } from '@renderer/types/chunk'
import { SdkMessageParam, SdkRawOutput, SdkToolCall } from '@renderer/types/sdk'
import { parseAndCallTools } from '@renderer/utils/mcp-tools'
import { CompletionsParams, CompletionsResult, GenericChunk } from '../schemas'
import { CompletionsContext, CompletionsMiddleware } from '../types'
export const MIDDLEWARE_NAME = 'McpToolChunkMiddleware'
const MAX_TOOL_RECURSION_DEPTH = 20 // 防止无限递归
/**
* MCP工具处理中间件
*
* 职责:
* 1. 检测并拦截MCP工具进展chunkFunction Call方式和Tool Use方式
* 2. 执行工具调用
* 3. 递归处理工具结果
* 4. 管理工具调用状态和递归深度
*/
export const McpToolChunkMiddleware: CompletionsMiddleware =
() =>
(next) =>
async (ctx: CompletionsContext, params: CompletionsParams): Promise<CompletionsResult> => {
const mcpTools = params.mcpTools || []
// 如果没有工具,直接调用下一个中间件
if (!mcpTools || mcpTools.length === 0) {
return next(ctx, params)
}
const executeWithToolHandling = async (currentParams: CompletionsParams, depth = 0): Promise<CompletionsResult> => {
if (depth >= MAX_TOOL_RECURSION_DEPTH) {
Logger.error(`🔧 [${MIDDLEWARE_NAME}] Maximum recursion depth ${MAX_TOOL_RECURSION_DEPTH} exceeded`)
throw new Error(`Maximum tool recursion depth ${MAX_TOOL_RECURSION_DEPTH} exceeded`)
}
let result: CompletionsResult
if (depth === 0) {
result = await next(ctx, currentParams)
} else {
const enhancedCompletions = ctx._internal.enhancedDispatch
if (!enhancedCompletions) {
Logger.error(`🔧 [${MIDDLEWARE_NAME}] Enhanced completions method not found, cannot perform recursive call`)
throw new Error('Enhanced completions method not found')
}
ctx._internal.toolProcessingState!.isRecursiveCall = true
ctx._internal.toolProcessingState!.recursionDepth = depth
result = await enhancedCompletions(ctx, currentParams)
}
if (!result.stream) {
Logger.error(`🔧 [${MIDDLEWARE_NAME}] No stream returned from enhanced completions`)
throw new Error('No stream returned from enhanced completions')
}
const resultFromUpstream = result.stream as ReadableStream<GenericChunk>
const toolHandlingStream = resultFromUpstream.pipeThrough(
createToolHandlingTransform(ctx, currentParams, mcpTools, depth, executeWithToolHandling)
)
return {
...result,
stream: toolHandlingStream
}
}
return executeWithToolHandling(params, 0)
}
/**
* 创建工具处理的 TransformStream
*/
function createToolHandlingTransform(
ctx: CompletionsContext,
currentParams: CompletionsParams,
mcpTools: MCPTool[],
depth: number,
executeWithToolHandling: (params: CompletionsParams, depth: number) => Promise<CompletionsResult>
): TransformStream<GenericChunk, GenericChunk> {
const toolCalls: SdkToolCall[] = []
const toolUseResponses: MCPToolResponse[] = []
const allToolResponses: MCPToolResponse[] = [] // 统一的工具响应状态管理数组
let hasToolCalls = false
let hasToolUseResponses = false
let streamEnded = false
return new TransformStream({
async transform(chunk: GenericChunk, controller) {
try {
// 处理MCP工具进展chunk
if (chunk.type === ChunkType.MCP_TOOL_CREATED) {
const createdChunk = chunk as MCPToolCreatedChunk
// 1. 处理Function Call方式的工具调用
if (createdChunk.tool_calls && createdChunk.tool_calls.length > 0) {
toolCalls.push(...createdChunk.tool_calls)
hasToolCalls = true
}
// 2. 处理Tool Use方式的工具调用
if (createdChunk.tool_use_responses && createdChunk.tool_use_responses.length > 0) {
toolUseResponses.push(...createdChunk.tool_use_responses)
hasToolUseResponses = true
}
// 不转发MCP工具进展chunks避免重复处理
return
}
// 转发其他所有chunk
controller.enqueue(chunk)
} catch (error) {
console.error(`🔧 [${MIDDLEWARE_NAME}] Error processing chunk:`, error)
controller.error(error)
}
},
async flush(controller) {
const shouldExecuteToolCalls = hasToolCalls && toolCalls.length > 0
const shouldExecuteToolUseResponses = hasToolUseResponses && toolUseResponses.length > 0
if (!streamEnded && (shouldExecuteToolCalls || shouldExecuteToolUseResponses)) {
streamEnded = true
try {
let toolResult: SdkMessageParam[] = []
if (shouldExecuteToolCalls) {
toolResult = await executeToolCalls(
ctx,
toolCalls,
mcpTools,
allToolResponses,
currentParams.onChunk,
currentParams.assistant.model!
)
} else if (shouldExecuteToolUseResponses) {
toolResult = await executeToolUseResponses(
ctx,
toolUseResponses,
mcpTools,
allToolResponses,
currentParams.onChunk,
currentParams.assistant.model!
)
}
if (toolResult.length > 0) {
const output = ctx._internal.toolProcessingState?.output
const newParams = buildParamsWithToolResults(ctx, currentParams, output, toolResult, toolCalls)
await executeWithToolHandling(newParams, depth + 1)
}
} catch (error) {
console.error(`🔧 [${MIDDLEWARE_NAME}] Error in tool processing:`, error)
controller.error(error)
} finally {
hasToolCalls = false
hasToolUseResponses = false
}
}
}
})
}
/**
* 执行工具调用Function Call 方式)
*/
async function executeToolCalls(
ctx: CompletionsContext,
toolCalls: SdkToolCall[],
mcpTools: MCPTool[],
allToolResponses: MCPToolResponse[],
onChunk: CompletionsParams['onChunk'],
model: Model
): Promise<SdkMessageParam[]> {
// 转换为MCPToolResponse格式
const mcpToolResponses: ToolCallResponse[] = toolCalls
.map((toolCall) => {
const mcpTool = ctx.apiClientInstance.convertSdkToolCallToMcp(toolCall, mcpTools)
if (!mcpTool) {
return undefined
}
return ctx.apiClientInstance.convertSdkToolCallToMcpToolResponse(toolCall, mcpTool)
})
.filter((t): t is ToolCallResponse => typeof t !== 'undefined')
if (mcpToolResponses.length === 0) {
console.warn(`🔧 [${MIDDLEWARE_NAME}] No valid MCP tool responses to execute`)
return []
}
// 使用现有的parseAndCallTools函数执行工具
const toolResults = await parseAndCallTools(
mcpToolResponses,
allToolResponses,
onChunk,
(mcpToolResponse, resp, model) => {
return ctx.apiClientInstance.convertMcpToolResponseToSdkMessageParam(mcpToolResponse, resp, model)
},
model,
mcpTools
)
return toolResults
}
/**
* 执行工具使用响应Tool Use Response 方式)
* 处理已经解析好的 ToolUseResponse[],不需要重新解析字符串
*/
async function executeToolUseResponses(
ctx: CompletionsContext,
toolUseResponses: MCPToolResponse[],
mcpTools: MCPTool[],
allToolResponses: MCPToolResponse[],
onChunk: CompletionsParams['onChunk'],
model: Model
): Promise<SdkMessageParam[]> {
// 直接使用parseAndCallTools函数处理已经解析好的ToolUseResponse
const toolResults = await parseAndCallTools(
toolUseResponses,
allToolResponses,
onChunk,
(mcpToolResponse, resp, model) => {
return ctx.apiClientInstance.convertMcpToolResponseToSdkMessageParam(mcpToolResponse, resp, model)
},
model,
mcpTools
)
return toolResults
}
/**
* 构建包含工具结果的新参数
*/
function buildParamsWithToolResults(
ctx: CompletionsContext,
currentParams: CompletionsParams,
output: SdkRawOutput | string | undefined,
toolResults: SdkMessageParam[],
toolCalls: SdkToolCall[]
): CompletionsParams {
// 获取当前已经转换好的reqMessages如果没有则使用原始messages
const currentReqMessages = getCurrentReqMessages(ctx)
const apiClient = ctx.apiClientInstance
// 从回复中构建助手消息
const newReqMessages = apiClient.buildSdkMessages(currentReqMessages, output, toolResults, toolCalls)
// 估算新增消息的 token 消耗并累加到 usage 中
if (ctx._internal.observer?.usage && newReqMessages.length > currentReqMessages.length) {
try {
const newMessages = newReqMessages.slice(currentReqMessages.length)
const additionalTokens = newMessages.reduce((acc, message) => {
return acc + ctx.apiClientInstance.estimateMessageTokens(message)
}, 0)
if (additionalTokens > 0) {
ctx._internal.observer.usage.prompt_tokens += additionalTokens
ctx._internal.observer.usage.total_tokens += additionalTokens
}
} catch (error) {
Logger.error(`🔧 [${MIDDLEWARE_NAME}] Error estimating token usage for new messages:`, error)
}
}
// 更新递归状态
if (!ctx._internal.toolProcessingState) {
ctx._internal.toolProcessingState = {}
}
ctx._internal.toolProcessingState.isRecursiveCall = true
ctx._internal.toolProcessingState.recursionDepth = (ctx._internal.toolProcessingState?.recursionDepth || 0) + 1
return {
...currentParams,
_internal: {
...ctx._internal,
sdkPayload: ctx._internal.sdkPayload,
newReqMessages: newReqMessages
}
}
}
/**
* 类型安全地获取当前请求消息
* 使用API客户端提供的抽象方法保持中间件的provider无关性
*/
function getCurrentReqMessages(ctx: CompletionsContext): SdkMessageParam[] {
const sdkPayload = ctx._internal.sdkPayload
if (!sdkPayload) {
return []
}
// 使用API客户端的抽象方法来提取消息保持provider无关性
return ctx.apiClientInstance.extractMessagesFromSdkPayload(sdkPayload)
}
export default McpToolChunkMiddleware

View File

@@ -1,46 +0,0 @@
import { AnthropicAPIClient } from '@renderer/aiCore/clients/anthropic/AnthropicAPIClient'
import { AnthropicSdkRawChunk, AnthropicSdkRawOutput } from '@renderer/types/sdk'
import { AnthropicStreamListener } from '../../clients/types'
import { CompletionsParams, CompletionsResult } from '../schemas'
import { CompletionsContext, CompletionsMiddleware } from '../types'
export const MIDDLEWARE_NAME = 'RawStreamListenerMiddleware'
export const RawStreamListenerMiddleware: CompletionsMiddleware =
() =>
(next) =>
async (ctx: CompletionsContext, params: CompletionsParams): Promise<CompletionsResult> => {
const result = await next(ctx, params)
// 在这里可以监听到从SDK返回的最原始流
if (result.rawOutput) {
const providerType = ctx.apiClientInstance.provider.type
// TODO: 后面下放到AnthropicAPIClient
if (providerType === 'anthropic') {
const anthropicListener: AnthropicStreamListener<AnthropicSdkRawChunk> = {
onMessage: (message) => {
if (ctx._internal?.toolProcessingState) {
ctx._internal.toolProcessingState.output = message
}
}
// onContentBlock: (contentBlock) => {
// console.log(`[${MIDDLEWARE_NAME}] 📝 Anthropic content block:`, contentBlock.type)
// }
}
const specificApiClient = ctx.apiClientInstance as AnthropicAPIClient
const monitoredOutput = specificApiClient.attachRawStreamListener(
result.rawOutput as AnthropicSdkRawOutput,
anthropicListener
)
return {
...result,
rawOutput: monitoredOutput
}
}
}
return result
}

View File

@@ -1,85 +0,0 @@
import Logger from '@renderer/config/logger'
import { SdkRawChunk } from '@renderer/types/sdk'
import { ResponseChunkTransformerContext } from '../../clients/types'
import { CompletionsParams, CompletionsResult, GenericChunk } from '../schemas'
import { CompletionsContext, CompletionsMiddleware } from '../types'
export const MIDDLEWARE_NAME = 'ResponseTransformMiddleware'
/**
* 响应转换中间件
*
* 职责:
* 1. 检测ReadableStream类型的响应流
* 2. 使用ApiClient的getResponseChunkTransformer()将原始SDK响应块转换为通用格式
* 3. 将转换后的ReadableStream保存到ctx._internal.apiCall.genericChunkStream供下游中间件使用
*
* 注意此中间件应该在StreamAdapterMiddleware之后执行
*/
export const ResponseTransformMiddleware: CompletionsMiddleware =
() =>
(next) =>
async (ctx: CompletionsContext, params: CompletionsParams): Promise<CompletionsResult> => {
// 调用下游中间件
const result = await next(ctx, params)
// 响应后处理转换原始SDK响应块
if (result.stream) {
const adaptedStream = result.stream
// 处理ReadableStream类型的流
if (adaptedStream instanceof ReadableStream) {
const apiClient = ctx.apiClientInstance
if (!apiClient) {
console.error(`[${MIDDLEWARE_NAME}] ApiClient instance not found in context`)
throw new Error('ApiClient instance not found in context')
}
// 获取响应转换器
const responseChunkTransformer = apiClient.getResponseChunkTransformer(ctx)
if (!responseChunkTransformer) {
Logger.warn(`[${MIDDLEWARE_NAME}] No ResponseChunkTransformer available, skipping transformation`)
return result
}
const assistant = params.assistant
const model = assistant?.model
if (!assistant || !model) {
console.error(`[${MIDDLEWARE_NAME}] Assistant or Model not found for transformation`)
throw new Error('Assistant or Model not found for transformation')
}
const transformerContext: ResponseChunkTransformerContext = {
isStreaming: params.streamOutput || false,
isEnabledToolCalling: (params.mcpTools && params.mcpTools.length > 0) || false,
isEnabledWebSearch: params.enableWebSearch || false,
isEnabledReasoning: params.enableReasoning || false,
mcpTools: params.mcpTools || [],
provider: ctx.apiClientInstance?.provider
}
console.log(`[${MIDDLEWARE_NAME}] Transforming raw SDK chunks with context:`, transformerContext)
try {
// 创建转换后的流
const genericChunkTransformStream = (adaptedStream as ReadableStream<SdkRawChunk>).pipeThrough<GenericChunk>(
new TransformStream<SdkRawChunk, GenericChunk>(responseChunkTransformer(transformerContext))
)
// 将转换后的ReadableStream保存到result供下游中间件使用
return {
...result,
stream: genericChunkTransformStream
}
} catch (error) {
Logger.error(`[${MIDDLEWARE_NAME}] Error during chunk transformation:`, error)
throw error
}
}
}
// 如果没有流或不是ReadableStream返回原始结果
return result
}

View File

@@ -1,56 +0,0 @@
import { SdkRawChunk } from '@renderer/types/sdk'
import { asyncGeneratorToReadableStream, createSingleChunkReadableStream } from '@renderer/utils/stream'
import { CompletionsParams, CompletionsResult } from '../schemas'
import { CompletionsContext, CompletionsMiddleware } from '../types'
import { isAsyncIterable } from '../utils'
export const MIDDLEWARE_NAME = 'StreamAdapterMiddleware'
/**
* 流适配器中间件
*
* 职责:
* 1. 检测ctx._internal.apiCall.rawSdkOutput优先或原始AsyncIterable流
* 2. 将AsyncIterable转换为WHATWG ReadableStream
* 3. 更新响应结果中的stream
*
* 注意如果ResponseTransformMiddleware已处理过会优先使用transformedStream
*/
export const StreamAdapterMiddleware: CompletionsMiddleware =
() =>
(next) =>
async (ctx: CompletionsContext, params: CompletionsParams): Promise<CompletionsResult> => {
// TODO:调用开始因为这个是最靠近接口请求的地方next执行代表着开始接口请求了
// 但是这个中间件的职责是流适配,是否在这调用优待商榷
// 调用下游中间件
const result = await next(ctx, params)
if (
result.rawOutput &&
!(result.rawOutput instanceof ReadableStream) &&
isAsyncIterable<SdkRawChunk>(result.rawOutput)
) {
const whatwgReadableStream: ReadableStream<SdkRawChunk> = asyncGeneratorToReadableStream<SdkRawChunk>(
result.rawOutput
)
return {
...result,
stream: whatwgReadableStream
}
} else if (result.rawOutput && result.rawOutput instanceof ReadableStream) {
return {
...result,
stream: result.rawOutput
}
} else if (result.rawOutput) {
// 非流式输出,强行变为可读流
const whatwgReadableStream: ReadableStream<SdkRawChunk> = createSingleChunkReadableStream<SdkRawChunk>(
result.rawOutput
)
return {
...result,
stream: whatwgReadableStream
}
}
return result
}

View File

@@ -1,99 +0,0 @@
import Logger from '@renderer/config/logger'
import { ChunkType, TextDeltaChunk } from '@renderer/types/chunk'
import { CompletionsParams, CompletionsResult, GenericChunk } from '../schemas'
import { CompletionsContext, CompletionsMiddleware } from '../types'
export const MIDDLEWARE_NAME = 'TextChunkMiddleware'
/**
* 文本块处理中间件
*
* 职责:
* 1. 累积文本内容TEXT_DELTA
* 2. 对文本内容进行智能链接转换
* 3. 生成TEXT_COMPLETE事件
* 4. 暂存Web搜索结果用于最终链接完善
* 5. 处理 onResponse 回调,实时发送文本更新和最终完整文本
*/
export const TextChunkMiddleware: CompletionsMiddleware =
() =>
(next) =>
async (ctx: CompletionsContext, params: CompletionsParams): Promise<CompletionsResult> => {
// 调用下游中间件
const result = await next(ctx, params)
// 响应后处理:转换流式响应中的文本内容
if (result.stream) {
const resultFromUpstream = result.stream as ReadableStream<GenericChunk>
if (resultFromUpstream && resultFromUpstream instanceof ReadableStream) {
const assistant = params.assistant
const model = params.assistant?.model
if (!assistant || !model) {
Logger.warn(`[${MIDDLEWARE_NAME}] Missing assistant or model information, skipping text processing`)
return result
}
// 用于跨chunk的状态管理
let accumulatedTextContent = ''
let hasEnqueue = false
const enhancedTextStream = resultFromUpstream.pipeThrough(
new TransformStream<GenericChunk, GenericChunk>({
transform(chunk: GenericChunk, controller) {
if (chunk.type === ChunkType.TEXT_DELTA) {
const textChunk = chunk as TextDeltaChunk
accumulatedTextContent += textChunk.text
// 处理 onResponse 回调 - 发送增量文本更新
if (params.onResponse) {
params.onResponse(accumulatedTextContent, false)
}
// 创建新的chunk包含处理后的文本
controller.enqueue(chunk)
} else if (accumulatedTextContent) {
if (chunk.type !== ChunkType.LLM_RESPONSE_COMPLETE) {
controller.enqueue(chunk)
hasEnqueue = true
}
const finalText = accumulatedTextContent
ctx._internal.customState!.accumulatedText = finalText
if (ctx._internal.toolProcessingState && !ctx._internal.toolProcessingState?.output) {
ctx._internal.toolProcessingState.output = finalText
}
// 处理 onResponse 回调 - 发送最终完整文本
if (params.onResponse) {
params.onResponse(finalText, true)
}
controller.enqueue({
type: ChunkType.TEXT_COMPLETE,
text: finalText
})
accumulatedTextContent = ''
if (!hasEnqueue) {
controller.enqueue(chunk)
}
} else {
// 其他类型的chunk直接传递
controller.enqueue(chunk)
}
}
})
)
// 更新响应结果
return {
...result,
stream: enhancedTextStream
}
} else {
Logger.warn(`[${MIDDLEWARE_NAME}] No stream to process or not a ReadableStream. Returning original result.`)
}
}
return result
}

View File

@@ -1,101 +0,0 @@
import Logger from '@renderer/config/logger'
import { ChunkType, ThinkingCompleteChunk, ThinkingDeltaChunk } from '@renderer/types/chunk'
import { CompletionsParams, CompletionsResult, GenericChunk } from '../schemas'
import { CompletionsContext, CompletionsMiddleware } from '../types'
export const MIDDLEWARE_NAME = 'ThinkChunkMiddleware'
/**
* 处理思考内容的中间件
*
* 注意:从 v2 版本开始,流结束语义的判断已移至 ApiClient 层处理
* 此中间件现在主要负责:
* 1. 处理原始SDK chunk中的reasoning字段
* 2. 计算准确的思考时间
* 3. 在思考内容结束时生成THINKING_COMPLETE事件
*
* 职责:
* 1. 累积思考内容THINKING_DELTA
* 2. 监听流结束信号生成THINKING_COMPLETE事件
* 3. 计算准确的思考时间
*
*/
export const ThinkChunkMiddleware: CompletionsMiddleware =
() =>
(next) =>
async (ctx: CompletionsContext, params: CompletionsParams): Promise<CompletionsResult> => {
// 调用下游中间件
const result = await next(ctx, params)
// 响应后处理:处理思考内容
if (result.stream) {
const resultFromUpstream = result.stream as ReadableStream<GenericChunk>
// 检查是否启用reasoning
const enableReasoning = params.enableReasoning || false
if (!enableReasoning) {
return result
}
// 检查是否有流需要处理
if (resultFromUpstream && resultFromUpstream instanceof ReadableStream) {
// thinking 处理状态
let accumulatedThinkingContent = ''
let hasThinkingContent = false
let thinkingStartTime = 0
const processedStream = resultFromUpstream.pipeThrough(
new TransformStream<GenericChunk, GenericChunk>({
transform(chunk: GenericChunk, controller) {
if (chunk.type === ChunkType.THINKING_DELTA) {
const thinkingChunk = chunk as ThinkingDeltaChunk
// 第一次接收到思考内容时记录开始时间
if (!hasThinkingContent) {
hasThinkingContent = true
thinkingStartTime = Date.now()
}
accumulatedThinkingContent += thinkingChunk.text
// 更新思考时间并传递
const enhancedChunk: ThinkingDeltaChunk = {
...thinkingChunk,
thinking_millsec: thinkingStartTime > 0 ? Date.now() - thinkingStartTime : 0
}
controller.enqueue(enhancedChunk)
} else if (hasThinkingContent && thinkingStartTime > 0) {
// 收到任何非THINKING_DELTA的chunk时如果有累积的思考内容生成THINKING_COMPLETE
const thinkingCompleteChunk: ThinkingCompleteChunk = {
type: ChunkType.THINKING_COMPLETE,
text: accumulatedThinkingContent,
thinking_millsec: thinkingStartTime > 0 ? Date.now() - thinkingStartTime : 0
}
controller.enqueue(thinkingCompleteChunk)
hasThinkingContent = false
accumulatedThinkingContent = ''
thinkingStartTime = 0
// 继续传递当前chunk
controller.enqueue(chunk)
} else {
// 其他情况直接传递
controller.enqueue(chunk)
}
}
})
)
// 更新响应结果
return {
...result,
stream: processedStream
}
} else {
Logger.warn(`[${MIDDLEWARE_NAME}] No generic chunk stream to process or not a ReadableStream.`)
}
}
return result
}

View File

@@ -1,81 +0,0 @@
import Logger from '@renderer/config/logger'
import { ChunkType } from '@renderer/types/chunk'
import { CompletionsParams, CompletionsResult } from '../schemas'
import { CompletionsContext, CompletionsMiddleware } from '../types'
export const MIDDLEWARE_NAME = 'TransformCoreToSdkParamsMiddleware'
/**
* 中间件将CoreCompletionsRequest转换为SDK特定的参数
* 使用上下文中ApiClient实例的requestTransformer进行转换
*/
export const TransformCoreToSdkParamsMiddleware: CompletionsMiddleware =
() =>
(next) =>
async (ctx: CompletionsContext, params: CompletionsParams): Promise<CompletionsResult> => {
const internal = ctx._internal
// 🔧 检测递归调用:检查 params 中是否携带了预处理的 SDK 消息
const isRecursiveCall = internal?.toolProcessingState?.isRecursiveCall || false
const newSdkMessages = params._internal?.newReqMessages
const apiClient = ctx.apiClientInstance
if (!apiClient) {
Logger.error(`🔄 [${MIDDLEWARE_NAME}] ApiClient instance not found in context.`)
throw new Error('ApiClient instance not found in context')
}
// 检查是否有requestTransformer方法
const requestTransformer = apiClient.getRequestTransformer()
if (!requestTransformer) {
Logger.warn(
`🔄 [${MIDDLEWARE_NAME}] ApiClient does not have getRequestTransformer method, skipping transformation`
)
const result = await next(ctx, params)
return result
}
// 确保assistant和model可用它们是transformer所需的
const assistant = params.assistant
const model = params.assistant.model
if (!assistant || !model) {
console.error(`🔄 [${MIDDLEWARE_NAME}] Assistant or Model not found for transformation.`)
throw new Error('Assistant or Model not found for transformation')
}
try {
const transformResult = await requestTransformer.transform(
params,
assistant,
model,
isRecursiveCall,
newSdkMessages
)
const { payload: sdkPayload, metadata } = transformResult
// 将SDK特定的payload和metadata存储在状态中供下游中间件使用
ctx._internal.sdkPayload = sdkPayload
if (metadata) {
ctx._internal.customState = {
...ctx._internal.customState,
sdkMetadata: metadata
}
}
if (params.enableGenerateImage) {
params.onChunk?.({
type: ChunkType.IMAGE_CREATED
})
}
return next(ctx, params)
} catch (error) {
Logger.error(`🔄 [${MIDDLEWARE_NAME}] Error during request transformation:`, error)
// 让错误向上传播,或者可以在这里进行特定的错误处理
throw error
}
}

View File

@@ -1,76 +0,0 @@
import { ChunkType } from '@renderer/types/chunk'
import { smartLinkConverter } from '@renderer/utils/linkConverter'
import { CompletionsParams, CompletionsResult, GenericChunk } from '../schemas'
import { CompletionsContext, CompletionsMiddleware } from '../types'
export const MIDDLEWARE_NAME = 'WebSearchMiddleware'
/**
* Web搜索处理中间件 - 基于GenericChunk流处理
*
* 职责:
* 1. 监听和记录Web搜索事件
* 2. 可以在此处添加Web搜索结果的后处理逻辑
* 3. 维护Web搜索相关的状态
*
* 注意Web搜索结果的识别和生成已在ApiClient的响应转换器中处理
*/
export const WebSearchMiddleware: CompletionsMiddleware =
() =>
(next) =>
async (ctx: CompletionsContext, params: CompletionsParams): Promise<CompletionsResult> => {
ctx._internal.webSearchState = {
results: undefined
}
// 调用下游中间件
const result = await next(ctx, params)
const model = params.assistant?.model!
let isFirstChunk = true
// 响应后处理记录Web搜索事件
if (result.stream) {
const resultFromUpstream = result.stream
if (resultFromUpstream && resultFromUpstream instanceof ReadableStream) {
// Web搜索状态跟踪
const enhancedStream = (resultFromUpstream as ReadableStream<GenericChunk>).pipeThrough(
new TransformStream<GenericChunk, GenericChunk>({
transform(chunk: GenericChunk, controller) {
if (chunk.type === ChunkType.TEXT_DELTA) {
const providerType = model.provider || 'openai'
// 使用当前可用的Web搜索结果进行链接转换
const text = chunk.text
const processedText = smartLinkConverter(text, providerType, isFirstChunk)
if (isFirstChunk) {
isFirstChunk = false
}
controller.enqueue({
...chunk,
text: processedText
})
} else if (chunk.type === ChunkType.LLM_WEB_SEARCH_COMPLETE) {
// 暂存Web搜索结果用于链接完善
ctx._internal.webSearchState!.results = chunk.llm_web_search
// 将Web搜索完成事件继续传递下去
controller.enqueue(chunk)
} else {
controller.enqueue(chunk)
}
}
})
)
return {
...result,
stream: enhancedStream
}
} else {
console.log(`[${MIDDLEWARE_NAME}] No stream to process or not a ReadableStream.`)
}
}
return result
}

View File

@@ -1,141 +0,0 @@
import { BaseApiClient } from '@renderer/aiCore/clients/BaseApiClient'
import { isDedicatedImageGenerationModel } from '@renderer/config/models'
import { ChunkType } from '@renderer/types/chunk'
import { findImageBlocks, getMainTextContent } from '@renderer/utils/messageUtils/find'
import OpenAI from 'openai'
import { toFile } from 'openai/uploads'
import { CompletionsParams, CompletionsResult, GenericChunk } from '../schemas'
import { CompletionsContext, CompletionsMiddleware } from '../types'
export const MIDDLEWARE_NAME = 'ImageGenerationMiddleware'
export const ImageGenerationMiddleware: CompletionsMiddleware =
() =>
(next) =>
async (context: CompletionsContext, params: CompletionsParams): Promise<CompletionsResult> => {
const { assistant, messages } = params
const client = context.apiClientInstance as BaseApiClient<OpenAI>
const signal = context._internal?.flowControl?.abortSignal
if (!assistant.model || !isDedicatedImageGenerationModel(assistant.model) || typeof messages === 'string') {
return next(context, params)
}
const stream = new ReadableStream<GenericChunk>({
async start(controller) {
const enqueue = (chunk: GenericChunk) => controller.enqueue(chunk)
try {
if (!assistant.model) {
throw new Error('Assistant model is not defined.')
}
const sdk = await client.getSdkInstance()
const lastUserMessage = messages.findLast((m) => m.role === 'user')
const lastAssistantMessage = messages.findLast((m) => m.role === 'assistant')
if (!lastUserMessage) {
throw new Error('No user message found for image generation.')
}
const prompt = getMainTextContent(lastUserMessage)
let imageFiles: Blob[] = []
// Collect images from user message
const userImageBlocks = findImageBlocks(lastUserMessage)
const userImages = await Promise.all(
userImageBlocks.map(async (block) => {
if (!block.file) return null
const binaryData: Uint8Array = await window.api.file.binaryImage(block.file.id)
const mimeType = `${block.file.type}/${block.file.ext.slice(1)}`
return await toFile(new Blob([binaryData]), block.file.origin_name || 'image.png', { type: mimeType })
})
)
imageFiles = imageFiles.concat(userImages.filter(Boolean) as Blob[])
// Collect images from last assistant message
if (lastAssistantMessage) {
const assistantImageBlocks = findImageBlocks(lastAssistantMessage)
const assistantImages = await Promise.all(
assistantImageBlocks.map(async (block) => {
const b64 = block.url?.replace(/^data:image\/\w+;base64,/, '')
if (!b64) return null
const binary = atob(b64)
const bytes = new Uint8Array(binary.length)
for (let i = 0; i < binary.length; i++) bytes[i] = binary.charCodeAt(i)
return await toFile(new Blob([bytes]), 'assistant_image.png', { type: 'image/png' })
})
)
imageFiles = imageFiles.concat(assistantImages.filter(Boolean) as Blob[])
}
enqueue({ type: ChunkType.IMAGE_CREATED })
const startTime = Date.now()
let response: OpenAI.Images.ImagesResponse
const options = { signal, timeout: 300_000 }
if (imageFiles.length > 0) {
response = await sdk.images.edit(
{
model: assistant.model.id,
image: imageFiles,
prompt: prompt || ''
},
options
)
} else {
response = await sdk.images.generate(
{
model: assistant.model.id,
prompt: prompt || '',
response_format: assistant.model.id.includes('gpt-image-1') ? undefined : 'b64_json'
},
options
)
}
let imageType: 'url' | 'base64' = 'base64'
const imageList =
response.data?.reduce((acc: string[], image) => {
if (image.url) {
acc.push(image.url)
imageType = 'url'
} else if (image.b64_json) {
acc.push(`data:image/png;base64,${image.b64_json}`)
}
return acc
}, []) || []
enqueue({
type: ChunkType.IMAGE_COMPLETE,
image: { type: imageType, images: imageList }
})
const usage = (response as any).usage || { prompt_tokens: 0, completion_tokens: 0, total_tokens: 0 }
enqueue({
type: ChunkType.LLM_RESPONSE_COMPLETE,
response: {
usage,
metrics: {
completion_tokens: usage.completion_tokens,
time_first_token_millsec: 0,
time_completion_millsec: Date.now() - startTime
}
}
})
} catch (error: any) {
enqueue({ type: ChunkType.ERROR, error })
} finally {
controller.close()
}
}
})
return {
stream,
getText: () => ''
}
}

View File

@@ -1,136 +0,0 @@
import { Model } from '@renderer/types'
import { ChunkType, TextDeltaChunk, ThinkingCompleteChunk, ThinkingDeltaChunk } from '@renderer/types/chunk'
import { TagConfig, TagExtractor } from '@renderer/utils/tagExtraction'
import Logger from 'electron-log/renderer'
import { CompletionsParams, CompletionsResult, GenericChunk } from '../schemas'
import { CompletionsContext, CompletionsMiddleware } from '../types'
export const MIDDLEWARE_NAME = 'ThinkingTagExtractionMiddleware'
// 不同模型的思考标签配置
const reasoningTags: TagConfig[] = [
{ openingTag: '<think>', closingTag: '</think>', separator: '\n' },
{ openingTag: '###Thinking', closingTag: '###Response', separator: '\n' }
]
const getAppropriateTag = (model?: Model): TagConfig => {
if (model?.id?.includes('qwen3')) return reasoningTags[0]
// 可以在这里添加更多模型特定的标签配置
return reasoningTags[0] // 默认使用 <think> 标签
}
/**
* 处理文本流中思考标签提取的中间件
*
* 该中间件专门处理文本流中的思考标签内容(如 <think>...</think>
* 主要用于 OpenAI 等支持思考标签的 provider
*
* 职责:
* 1. 从文本流中提取思考标签内容
* 2. 将标签内的内容转换为 THINKING_DELTA chunk
* 3. 将标签外的内容作为正常文本输出
* 4. 处理不同模型的思考标签格式
* 5. 在思考内容结束时生成 THINKING_COMPLETE 事件
*/
export const ThinkingTagExtractionMiddleware: CompletionsMiddleware =
() =>
(next) =>
async (context: CompletionsContext, params: CompletionsParams): Promise<CompletionsResult> => {
// 调用下游中间件
const result = await next(context, params)
// 响应后处理:处理思考标签提取
if (result.stream) {
const resultFromUpstream = result.stream as ReadableStream<GenericChunk>
// 检查是否有流需要处理
if (resultFromUpstream && resultFromUpstream instanceof ReadableStream) {
// 获取当前模型的思考标签配置
const model = params.assistant?.model
const reasoningTag = getAppropriateTag(model)
// 创建标签提取器
const tagExtractor = new TagExtractor(reasoningTag)
// thinking 处理状态
let hasThinkingContent = false
let thinkingStartTime = 0
const processedStream = resultFromUpstream.pipeThrough(
new TransformStream<GenericChunk, GenericChunk>({
transform(chunk: GenericChunk, controller) {
if (chunk.type === ChunkType.TEXT_DELTA) {
const textChunk = chunk as TextDeltaChunk
// 使用 TagExtractor 处理文本
const extractionResults = tagExtractor.processText(textChunk.text)
for (const extractionResult of extractionResults) {
if (extractionResult.complete && extractionResult.tagContentExtracted) {
// 生成 THINKING_COMPLETE 事件
const thinkingCompleteChunk: ThinkingCompleteChunk = {
type: ChunkType.THINKING_COMPLETE,
text: extractionResult.tagContentExtracted,
thinking_millsec: thinkingStartTime > 0 ? Date.now() - thinkingStartTime : 0
}
controller.enqueue(thinkingCompleteChunk)
// 重置思考状态
hasThinkingContent = false
thinkingStartTime = 0
} else if (extractionResult.content.length > 0) {
if (extractionResult.isTagContent) {
// 第一次接收到思考内容时记录开始时间
if (!hasThinkingContent) {
hasThinkingContent = true
thinkingStartTime = Date.now()
}
const thinkingDeltaChunk: ThinkingDeltaChunk = {
type: ChunkType.THINKING_DELTA,
text: extractionResult.content,
thinking_millsec: thinkingStartTime > 0 ? Date.now() - thinkingStartTime : 0
}
controller.enqueue(thinkingDeltaChunk)
} else {
// 发送清理后的文本内容
const cleanTextChunk: TextDeltaChunk = {
...textChunk,
text: extractionResult.content
}
controller.enqueue(cleanTextChunk)
}
}
}
} else {
// 其他类型的chunk直接传递包括 THINKING_DELTA, THINKING_COMPLETE 等)
controller.enqueue(chunk)
}
},
flush(controller) {
// 处理可能剩余的思考内容
const finalResult = tagExtractor.finalize()
if (finalResult?.tagContentExtracted) {
const thinkingCompleteChunk: ThinkingCompleteChunk = {
type: ChunkType.THINKING_COMPLETE,
text: finalResult.tagContentExtracted,
thinking_millsec: thinkingStartTime > 0 ? Date.now() - thinkingStartTime : 0
}
controller.enqueue(thinkingCompleteChunk)
}
}
})
)
// 更新响应结果
return {
...result,
stream: processedStream
}
} else {
Logger.warn(`[${MIDDLEWARE_NAME}] No generic chunk stream to process or not a ReadableStream.`)
}
}
return result
}

View File

@@ -1,124 +0,0 @@
import { MCPTool } from '@renderer/types'
import { ChunkType, MCPToolCreatedChunk, TextDeltaChunk } from '@renderer/types/chunk'
import { parseToolUse } from '@renderer/utils/mcp-tools'
import { TagConfig, TagExtractor } from '@renderer/utils/tagExtraction'
import { CompletionsParams, CompletionsResult, GenericChunk } from '../schemas'
import { CompletionsContext, CompletionsMiddleware } from '../types'
export const MIDDLEWARE_NAME = 'ToolUseExtractionMiddleware'
// 工具使用标签配置
const TOOL_USE_TAG_CONFIG: TagConfig = {
openingTag: '<tool_use>',
closingTag: '</tool_use>',
separator: '\n'
}
/**
* 工具使用提取中间件
*
* 职责:
* 1. 从文本流中检测并提取 <tool_use></tool_use> 标签
* 2. 解析工具调用信息并转换为 ToolUseResponse 格式
* 3. 生成 MCP_TOOL_CREATED chunk 供 McpToolChunkMiddleware 处理
* 4. 清理文本流,移除工具使用标签但保留正常文本
*
* 注意:此中间件只负责提取和转换,实际工具调用由 McpToolChunkMiddleware 处理
*/
export const ToolUseExtractionMiddleware: CompletionsMiddleware =
() =>
(next) =>
async (ctx: CompletionsContext, params: CompletionsParams): Promise<CompletionsResult> => {
const mcpTools = params.mcpTools || []
// 如果没有工具,直接调用下一个中间件
if (!mcpTools || mcpTools.length === 0) return next(ctx, params)
// 调用下游中间件
const result = await next(ctx, params)
// 响应后处理:处理工具使用标签提取
if (result.stream) {
const resultFromUpstream = result.stream as ReadableStream<GenericChunk>
const processedStream = resultFromUpstream.pipeThrough(createToolUseExtractionTransform(ctx, mcpTools))
return {
...result,
stream: processedStream
}
}
return result
}
/**
* 创建工具使用提取的 TransformStream
*/
function createToolUseExtractionTransform(
_ctx: CompletionsContext,
mcpTools: MCPTool[]
): TransformStream<GenericChunk, GenericChunk> {
const tagExtractor = new TagExtractor(TOOL_USE_TAG_CONFIG)
return new TransformStream({
async transform(chunk: GenericChunk, controller) {
try {
// 处理文本内容,检测工具使用标签
if (chunk.type === ChunkType.TEXT_DELTA) {
const textChunk = chunk as TextDeltaChunk
const extractionResults = tagExtractor.processText(textChunk.text)
for (const result of extractionResults) {
if (result.complete && result.tagContentExtracted) {
// 提取到完整的工具使用内容,解析并转换为 SDK ToolCall 格式
const toolUseResponses = parseToolUse(result.tagContentExtracted, mcpTools)
if (toolUseResponses.length > 0) {
// 生成 MCP_TOOL_CREATED chunk复用现有的处理流程
const mcpToolCreatedChunk: MCPToolCreatedChunk = {
type: ChunkType.MCP_TOOL_CREATED,
tool_use_responses: toolUseResponses
}
controller.enqueue(mcpToolCreatedChunk)
}
} else if (!result.isTagContent && result.content) {
// 发送标签外的正常文本内容
const cleanTextChunk: TextDeltaChunk = {
...textChunk,
text: result.content
}
controller.enqueue(cleanTextChunk)
}
// 注意标签内的内容不会作为TEXT_DELTA转发避免重复显示
}
return
}
// 转发其他所有chunk
controller.enqueue(chunk)
} catch (error) {
console.error(`🔧 [${MIDDLEWARE_NAME}] Error processing chunk:`, error)
controller.error(error)
}
},
async flush(controller) {
// 检查是否有未完成的标签内容
const finalResult = tagExtractor.finalize()
if (finalResult && finalResult.tagContentExtracted) {
const toolUseResponses = parseToolUse(finalResult.tagContentExtracted, mcpTools)
if (toolUseResponses.length > 0) {
const mcpToolCreatedChunk: MCPToolCreatedChunk = {
type: ChunkType.MCP_TOOL_CREATED,
tool_use_responses: toolUseResponses
}
controller.enqueue(mcpToolCreatedChunk)
}
}
}
})
}
export default ToolUseExtractionMiddleware

View File

@@ -1,88 +0,0 @@
import { CompletionsMiddleware, MethodMiddleware } from './types'
// /**
// * Wraps a provider instance with middlewares.
// */
// export function wrapProviderWithMiddleware(
// apiClientInstance: BaseApiClient,
// middlewareConfig: MiddlewareConfig
// ): BaseApiClient {
// console.log(`[wrapProviderWithMiddleware] Wrapping provider: ${apiClientInstance.provider?.id}`)
// console.log(`[wrapProviderWithMiddleware] Middleware config:`, {
// completions: middlewareConfig.completions?.length || 0,
// methods: Object.keys(middlewareConfig.methods || {}).length
// })
// // Cache for already wrapped methods to avoid re-wrapping on every access.
// const wrappedMethodsCache = new Map<string, (...args: any[]) => Promise<any>>()
// const proxy = new Proxy(apiClientInstance, {
// get(target, propKey, receiver) {
// const methodName = typeof propKey === 'string' ? propKey : undefined
// if (!methodName) {
// return Reflect.get(target, propKey, receiver)
// }
// if (wrappedMethodsCache.has(methodName)) {
// console.log(`[wrapProviderWithMiddleware] Using cached wrapped method: ${methodName}`)
// return wrappedMethodsCache.get(methodName)
// }
// const originalMethod = Reflect.get(target, propKey, receiver)
// // If the property is not a function, return it directly.
// if (typeof originalMethod !== 'function') {
// return originalMethod
// }
// let wrappedMethod: ((...args: any[]) => Promise<any>) | undefined
// // Handle completions method
// if (methodName === 'completions' && middlewareConfig.completions?.length) {
// console.log(
// `[wrapProviderWithMiddleware] Wrapping completions method with ${middlewareConfig.completions.length} middlewares`
// )
// const completionsOriginalMethod = originalMethod as (params: CompletionsParams) => Promise<any>
// wrappedMethod = applyCompletionsMiddlewares(target, completionsOriginalMethod, middlewareConfig.completions)
// }
// // Handle other methods
// else {
// const methodMiddlewares = middlewareConfig.methods?.[methodName]
// if (methodMiddlewares?.length) {
// console.log(
// `[wrapProviderWithMiddleware] Wrapping method ${methodName} with ${methodMiddlewares.length} middlewares`
// )
// const genericOriginalMethod = originalMethod as (...args: any[]) => Promise<any>
// wrappedMethod = applyMethodMiddlewares(target, methodName, genericOriginalMethod, methodMiddlewares)
// }
// }
// if (wrappedMethod) {
// console.log(`[wrapProviderWithMiddleware] Successfully wrapped method: ${methodName}`)
// wrappedMethodsCache.set(methodName, wrappedMethod)
// return wrappedMethod
// }
// // If no middlewares are configured for this method, return the original method bound to the target. /
// // 如果没有为此方法配置中间件,则返回绑定到目标的原始方法。
// console.log(`[wrapProviderWithMiddleware] No middlewares for method ${methodName}, returning original`)
// return originalMethod.bind(target)
// }
// })
// return proxy as BaseApiClient
// }
// Export types for external use
export type { CompletionsMiddleware, MethodMiddleware }
// Export MiddlewareBuilder related types and classes
export {
CompletionsMiddlewareBuilder,
createCompletionsBuilder,
createMethodBuilder,
MethodMiddlewareBuilder,
MiddlewareBuilder,
type MiddlewareExecutor,
type NamedMiddleware
} from './builder'

View File

@@ -1,149 +0,0 @@
import * as AbortHandlerModule from './common/AbortHandlerMiddleware'
import * as ErrorHandlerModule from './common/ErrorHandlerMiddleware'
import * as FinalChunkConsumerModule from './common/FinalChunkConsumerMiddleware'
import * as LoggingModule from './common/LoggingMiddleware'
import * as McpToolChunkModule from './core/McpToolChunkMiddleware'
import * as RawStreamListenerModule from './core/RawStreamListenerMiddleware'
import * as ResponseTransformModule from './core/ResponseTransformMiddleware'
// import * as SdkCallModule from './core/SdkCallMiddleware'
import * as StreamAdapterModule from './core/StreamAdapterMiddleware'
import * as TextChunkModule from './core/TextChunkMiddleware'
import * as ThinkChunkModule from './core/ThinkChunkMiddleware'
import * as TransformCoreToSdkParamsModule from './core/TransformCoreToSdkParamsMiddleware'
import * as WebSearchModule from './core/WebSearchMiddleware'
import * as ImageGenerationModule from './feat/ImageGenerationMiddleware'
import * as ThinkingTagExtractionModule from './feat/ThinkingTagExtractionMiddleware'
import * as ToolUseExtractionMiddleware from './feat/ToolUseExtractionMiddleware'
/**
* 中间件注册表 - 提供所有可用中间件的集中访问
* 注意:目前中间件文件还未导出 MIDDLEWARE_NAME会有 linter 错误,这是正常的
*/
export const MiddlewareRegistry = {
[ErrorHandlerModule.MIDDLEWARE_NAME]: {
name: ErrorHandlerModule.MIDDLEWARE_NAME,
middleware: ErrorHandlerModule.ErrorHandlerMiddleware
},
// 通用中间件
[AbortHandlerModule.MIDDLEWARE_NAME]: {
name: AbortHandlerModule.MIDDLEWARE_NAME,
middleware: AbortHandlerModule.AbortHandlerMiddleware
},
[FinalChunkConsumerModule.MIDDLEWARE_NAME]: {
name: FinalChunkConsumerModule.MIDDLEWARE_NAME,
middleware: FinalChunkConsumerModule.default
},
// 核心流程中间件
[TransformCoreToSdkParamsModule.MIDDLEWARE_NAME]: {
name: TransformCoreToSdkParamsModule.MIDDLEWARE_NAME,
middleware: TransformCoreToSdkParamsModule.TransformCoreToSdkParamsMiddleware
},
// [SdkCallModule.MIDDLEWARE_NAME]: {
// name: SdkCallModule.MIDDLEWARE_NAME,
// middleware: SdkCallModule.SdkCallMiddleware
// },
[StreamAdapterModule.MIDDLEWARE_NAME]: {
name: StreamAdapterModule.MIDDLEWARE_NAME,
middleware: StreamAdapterModule.StreamAdapterMiddleware
},
[RawStreamListenerModule.MIDDLEWARE_NAME]: {
name: RawStreamListenerModule.MIDDLEWARE_NAME,
middleware: RawStreamListenerModule.RawStreamListenerMiddleware
},
[ResponseTransformModule.MIDDLEWARE_NAME]: {
name: ResponseTransformModule.MIDDLEWARE_NAME,
middleware: ResponseTransformModule.ResponseTransformMiddleware
},
// 特性处理中间件
[ThinkingTagExtractionModule.MIDDLEWARE_NAME]: {
name: ThinkingTagExtractionModule.MIDDLEWARE_NAME,
middleware: ThinkingTagExtractionModule.ThinkingTagExtractionMiddleware
},
[ToolUseExtractionMiddleware.MIDDLEWARE_NAME]: {
name: ToolUseExtractionMiddleware.MIDDLEWARE_NAME,
middleware: ToolUseExtractionMiddleware.ToolUseExtractionMiddleware
},
[ThinkChunkModule.MIDDLEWARE_NAME]: {
name: ThinkChunkModule.MIDDLEWARE_NAME,
middleware: ThinkChunkModule.ThinkChunkMiddleware
},
[McpToolChunkModule.MIDDLEWARE_NAME]: {
name: McpToolChunkModule.MIDDLEWARE_NAME,
middleware: McpToolChunkModule.McpToolChunkMiddleware
},
[WebSearchModule.MIDDLEWARE_NAME]: {
name: WebSearchModule.MIDDLEWARE_NAME,
middleware: WebSearchModule.WebSearchMiddleware
},
[TextChunkModule.MIDDLEWARE_NAME]: {
name: TextChunkModule.MIDDLEWARE_NAME,
middleware: TextChunkModule.TextChunkMiddleware
},
[ImageGenerationModule.MIDDLEWARE_NAME]: {
name: ImageGenerationModule.MIDDLEWARE_NAME,
middleware: ImageGenerationModule.ImageGenerationMiddleware
}
} as const
/**
* 根据名称获取中间件
* @param name - 中间件名称
* @returns 对应的中间件信息
*/
export function getMiddleware(name: string) {
return MiddlewareRegistry[name]
}
/**
* 获取所有注册的中间件名称
* @returns 中间件名称列表
*/
export function getRegisteredMiddlewareNames(): string[] {
return Object.keys(MiddlewareRegistry)
}
/**
* 默认的 Completions 中间件配置 - NamedMiddleware 格式,用于 MiddlewareBuilder
*/
export const DefaultCompletionsNamedMiddlewares = [
MiddlewareRegistry[FinalChunkConsumerModule.MIDDLEWARE_NAME], // 最终消费者
MiddlewareRegistry[ErrorHandlerModule.MIDDLEWARE_NAME], // 错误处理
MiddlewareRegistry[TransformCoreToSdkParamsModule.MIDDLEWARE_NAME], // 参数转换
MiddlewareRegistry[AbortHandlerModule.MIDDLEWARE_NAME], // 中止处理
MiddlewareRegistry[McpToolChunkModule.MIDDLEWARE_NAME], // 工具处理
MiddlewareRegistry[TextChunkModule.MIDDLEWARE_NAME], // 文本处理
MiddlewareRegistry[WebSearchModule.MIDDLEWARE_NAME], // Web搜索处理
MiddlewareRegistry[ToolUseExtractionMiddleware.MIDDLEWARE_NAME], // 工具使用提取处理
MiddlewareRegistry[ThinkingTagExtractionModule.MIDDLEWARE_NAME], // 思考标签提取处理特定provider
MiddlewareRegistry[ThinkChunkModule.MIDDLEWARE_NAME], // 思考处理通用SDK
MiddlewareRegistry[ResponseTransformModule.MIDDLEWARE_NAME], // 响应转换
MiddlewareRegistry[StreamAdapterModule.MIDDLEWARE_NAME], // 流适配器
MiddlewareRegistry[RawStreamListenerModule.MIDDLEWARE_NAME] // 原始流监听器
]
/**
* 默认的通用方法中间件 - 例如翻译、摘要等
*/
export const DefaultMethodMiddlewares = {
translate: [LoggingModule.createGenericLoggingMiddleware()],
summaries: [LoggingModule.createGenericLoggingMiddleware()]
}
/**
* 导出所有中间件模块,方便外部使用
*/
export {
AbortHandlerModule,
FinalChunkConsumerModule,
LoggingModule,
McpToolChunkModule,
ResponseTransformModule,
StreamAdapterModule,
TextChunkModule,
ThinkChunkModule,
ThinkingTagExtractionModule,
TransformCoreToSdkParamsModule,
WebSearchModule
}

View File

@@ -1,77 +0,0 @@
import { Assistant, MCPTool } from '@renderer/types'
import { Chunk } from '@renderer/types/chunk'
import { Message } from '@renderer/types/newMessage'
import { SdkRawChunk, SdkRawOutput } from '@renderer/types/sdk'
import { ProcessingState } from './types'
// ============================================================================
// Core Request Types - 核心请求结构
// ============================================================================
/**
* 标准化的内部核心请求结构用于所有AI Provider的统一处理
* 这是应用层参数转换后的标准格式,不包含回调函数和控制逻辑
*/
export interface CompletionsParams {
/**
* 调用的业务场景类型,用于中间件判断是否执行
* 'chat': 主要对话流程
* 'translate': 翻译
* 'summary': 摘要
* 'search': 搜索摘要
* 'generate': 生成
* 'check': API检查
*/
callType?: 'chat' | 'translate' | 'summary' | 'search' | 'generate' | 'check'
// 基础对话数据
messages: Message[] | string // 联合类型方便判断是否为空
assistant: Assistant // 助手为基本单位
// model: Model
onChunk?: (chunk: Chunk) => void
onResponse?: (text: string, isComplete: boolean) => void
// 错误相关
onError?: (error: Error) => void
shouldThrow?: boolean
// 工具相关
mcpTools?: MCPTool[]
// 生成参数
temperature?: number
topP?: number
maxTokens?: number
// 功能开关
streamOutput: boolean
enableWebSearch?: boolean
enableReasoning?: boolean
enableGenerateImage?: boolean
// 上下文控制
contextCount?: number
_internal?: ProcessingState
}
export interface CompletionsResult {
rawOutput?: SdkRawOutput
stream?: ReadableStream<SdkRawChunk> | ReadableStream<Chunk> | AsyncIterable<Chunk>
controller?: AbortController
getText: () => string
}
// ============================================================================
// Generic Chunk Types - 通用数据块结构
// ============================================================================
/**
* 通用数据块类型
* 复用现有的 Chunk 类型这是所有AI Provider都应该输出的标准化数据块格式
*/
export type GenericChunk = Chunk

View File

@@ -1,166 +0,0 @@
import { MCPToolResponse, Metrics, Usage, WebSearchResponse } from '@renderer/types'
import { Chunk, ErrorChunk } from '@renderer/types/chunk'
import {
SdkInstance,
SdkMessageParam,
SdkParams,
SdkRawChunk,
SdkRawOutput,
SdkTool,
SdkToolCall
} from '@renderer/types/sdk'
import { BaseApiClient } from '../clients'
import { CompletionsParams, CompletionsResult } from './schemas'
/**
* Symbol to uniquely identify middleware context objects.
*/
export const MIDDLEWARE_CONTEXT_SYMBOL = Symbol.for('AiProviderMiddlewareContext')
/**
* Defines the structure for the onChunk callback function.
*/
export type OnChunkFunction = (chunk: Chunk | ErrorChunk) => void
/**
* Base context that carries information about the current method call.
*/
export interface BaseContext {
[MIDDLEWARE_CONTEXT_SYMBOL]: true
methodName: string
originalArgs: Readonly<any[]>
}
/**
* Processing state shared between middlewares.
*/
export interface ProcessingState<
TParams extends SdkParams = SdkParams,
TMessageParam extends SdkMessageParam = SdkMessageParam,
TToolCall extends SdkToolCall = SdkToolCall
> {
sdkPayload?: TParams
newReqMessages?: TMessageParam[]
observer?: {
usage?: Usage
metrics?: Metrics
}
toolProcessingState?: {
pendingToolCalls?: Array<TToolCall>
executingToolCalls?: Array<{
sdkToolCall: TToolCall
mcpToolResponse: MCPToolResponse
}>
output?: SdkRawOutput | string
isRecursiveCall?: boolean
recursionDepth?: number
}
webSearchState?: {
results?: WebSearchResponse
}
flowControl?: {
abortController?: AbortController
abortSignal?: AbortSignal
cleanup?: () => void
}
enhancedDispatch?: (context: CompletionsContext, params: CompletionsParams) => Promise<CompletionsResult>
customState?: Record<string, any>
}
/**
* Extended context for completions method.
*/
export interface CompletionsContext<
TSdkParams extends SdkParams = SdkParams,
TSdkMessageParam extends SdkMessageParam = SdkMessageParam,
TSdkToolCall extends SdkToolCall = SdkToolCall,
TSdkInstance extends SdkInstance = SdkInstance,
TRawOutput extends SdkRawOutput = SdkRawOutput,
TRawChunk extends SdkRawChunk = SdkRawChunk,
TSdkSpecificTool extends SdkTool = SdkTool
> extends BaseContext {
readonly methodName: 'completions' // 强制方法名为 'completions'
apiClientInstance: BaseApiClient<
TSdkInstance,
TSdkParams,
TRawOutput,
TRawChunk,
TSdkMessageParam,
TSdkToolCall,
TSdkSpecificTool
>
// --- Mutable internal state for the duration of the middleware chain ---
_internal: ProcessingState<TSdkParams, TSdkMessageParam, TSdkToolCall> // 包含所有可变的处理状态
}
export interface MiddlewareAPI<Ctx extends BaseContext = BaseContext, Args extends any[] = any[]> {
getContext: () => Ctx // Function to get the current context / 获取当前上下文的函数
getOriginalArgs: () => Args // Function to get the original arguments of the method call / 获取方法调用原始参数的函数
}
/**
* Base middleware type.
*/
export type Middleware<TContext extends BaseContext> = (
api: MiddlewareAPI<TContext>
) => (
next: (context: TContext, args: any[]) => Promise<unknown>
) => (context: TContext, args: any[]) => Promise<unknown>
export type MethodMiddleware = Middleware<BaseContext>
/**
* Completions middleware type.
*/
export type CompletionsMiddleware<
TSdkParams extends SdkParams = SdkParams,
TSdkMessageParam extends SdkMessageParam = SdkMessageParam,
TSdkToolCall extends SdkToolCall = SdkToolCall,
TSdkInstance extends SdkInstance = SdkInstance,
TRawOutput extends SdkRawOutput = SdkRawOutput,
TRawChunk extends SdkRawChunk = SdkRawChunk,
TSdkSpecificTool extends SdkTool = SdkTool
> = (
api: MiddlewareAPI<
CompletionsContext<
TSdkParams,
TSdkMessageParam,
TSdkToolCall,
TSdkInstance,
TRawOutput,
TRawChunk,
TSdkSpecificTool
>,
[CompletionsParams]
>
) => (
next: (
context: CompletionsContext<
TSdkParams,
TSdkMessageParam,
TSdkToolCall,
TSdkInstance,
TRawOutput,
TRawChunk,
TSdkSpecificTool
>,
params: CompletionsParams
) => Promise<CompletionsResult>
) => (
context: CompletionsContext<
TSdkParams,
TSdkMessageParam,
TSdkToolCall,
TSdkInstance,
TRawOutput,
TRawChunk,
TSdkSpecificTool
>,
params: CompletionsParams
) => Promise<CompletionsResult>
// Re-export for convenience
export type { Chunk as OnChunkArg } from '@renderer/types/chunk'

View File

@@ -1,57 +0,0 @@
import { ChunkType, ErrorChunk } from '@renderer/types/chunk'
/**
* Creates an ErrorChunk object with a standardized structure.
* @param error The error object or message.
* @param chunkType The type of chunk, defaults to ChunkType.ERROR.
* @returns An ErrorChunk object.
*/
export function createErrorChunk(error: any, chunkType: ChunkType = ChunkType.ERROR): ErrorChunk {
let errorDetails: Record<string, any> = {}
if (error instanceof Error) {
errorDetails = {
message: error.message,
name: error.name,
stack: error.stack
}
} else if (typeof error === 'string') {
errorDetails = { message: error }
} else if (typeof error === 'object' && error !== null) {
errorDetails = Object.getOwnPropertyNames(error).reduce(
(acc, key) => {
acc[key] = error[key]
return acc
},
{} as Record<string, any>
)
if (!errorDetails.message && error.toString && typeof error.toString === 'function') {
const errMsg = error.toString()
if (errMsg !== '[object Object]') {
errorDetails.message = errMsg
}
}
}
return {
type: chunkType,
error: errorDetails
} as ErrorChunk
}
// Helper to capitalize method names for hook construction
export function capitalize(str: string): string {
if (!str) return ''
return str.charAt(0).toUpperCase() + str.slice(1)
}
/**
* 检查对象是否实现了AsyncIterable接口
*/
export function isAsyncIterable<T = unknown>(obj: unknown): obj is AsyncIterable<T> {
return (
obj !== null &&
typeof obj === 'object' &&
typeof (obj as Record<symbol, unknown>)[Symbol.asyncIterator] === 'function'
)
}

View File

@@ -1,110 +0,0 @@
/**
* Cherry Studio 参数转换插件
* 专门处理 Cherry Studio 特有的消息格式、文件处理、Assistant 设置等
*/
import { definePlugin } from '@cherrystudio/ai-core'
import type { Assistant, MCPTool, Message, Model } from '@renderer/types'
import {
buildStreamTextParams,
convertMessagesToSdkMessages,
getCustomParameters,
getTemperature,
getTopP
} from '../transformParameters'
/**
* Cherry Studio 核心转换插件
* 负责将 Cherry Studio 的数据结构转换为 AI SDK 兼容格式
*/
export const cherryStudioTransformPlugin = definePlugin({
name: 'cherry-studio-transform',
/**
* 转换请求参数
* 将 Cherry Studio 的 Assistant + Messages 转换为 AI SDK 格式
*/
transformParams: async (params: any, context) => {
// 检查是否有 Cherry Studio 特有的数据结构
const cherryData = context.metadata?.cherryStudio
if (!cherryData) {
return params // 不是 Cherry Studio 调用,直接返回
}
const { assistant, messages, mcpTools, enableTools } = cherryData
try {
// 1. 转换 Cherry Studio 消息为 AI SDK 消息
const sdkMessages = await convertMessagesToSdkMessages(messages as Message[], assistant.model as Model)
// 2. 构建完整的 AI SDK 参数
const { params: transformedParams } = await buildStreamTextParams(sdkMessages, assistant as Assistant, {
mcpTools: mcpTools as MCPTool[],
enableTools,
requestOptions: {
signal: params.abortSignal,
headers: params.headers
}
})
// 3. 合并原始参数和转换后的参数
return {
...params,
...transformedParams,
// 保留原始的一些关键参数
abortSignal: params.abortSignal,
headers: params.headers
}
} catch (error) {
console.error('Cherry Studio 参数转换失败:', error)
return params // 转换失败时返回原始参数
}
}
})
/**
* Cherry Studio Assistant 设置插件
* 专门处理 Assistant 的温度、TopP、自定义参数等设置
*/
export const cherryStudioSettingsPlugin = definePlugin({
name: 'cherry-studio-settings',
transformParams: async (params: any, context) => {
const cherryData = context.metadata?.cherryStudio
if (!cherryData?.assistant) {
return params
}
const { assistant } = cherryData
const model = assistant.model as Model
return {
...params,
temperature: getTemperature(assistant as Assistant, model),
topP: getTopP(assistant as Assistant, model),
...getCustomParameters(assistant as Assistant)
}
}
})
/**
* 便捷函数:为 Cherry Studio 调用准备上下文元数据
*/
export function createCherryStudioContext(
assistant: Assistant,
messages: Message[],
options: {
mcpTools?: MCPTool[]
enableTools?: boolean
} = {}
) {
return {
cherryStudio: {
assistant,
messages,
mcpTools: options.mcpTools,
enableTools: options.enableTools
}
}
}

View File

@@ -1,311 +0,0 @@
/**
* AI SDK 参数转换模块
* 统一管理从各个 apiClient 提取的参数处理和转换功能
*/
import type { CoreMessage, StreamTextParams } from '@cherrystudio/ai-core'
import {
isGenerateImageModel,
isNotSupportTemperatureAndTopP,
isOpenRouterBuiltInWebSearchModel,
isReasoningModel,
isSupportedDisableGenerationModel,
isSupportedFlexServiceTier,
isSupportedReasoningEffortModel,
isSupportedThinkingTokenModel,
isWebSearchModel
} from '@renderer/config/models'
import { getAssistantSettings, getDefaultModel } from '@renderer/services/AssistantService'
import type { Assistant, MCPTool, Message, Model } from '@renderer/types'
import { FileTypes } from '@renderer/types'
import { findFileBlocks, findImageBlocks, getMainTextContent } from '@renderer/utils/messageUtils/find'
import { buildSystemPrompt } from '@renderer/utils/prompt'
import { defaultTimeout } from '@shared/config/constant'
/**
* 获取温度参数
*/
export function getTemperature(assistant: Assistant, model: Model): number | undefined {
return isNotSupportTemperatureAndTopP(model) ? undefined : assistant.settings?.temperature
}
/**
* 获取 TopP 参数
*/
export function getTopP(assistant: Assistant, model: Model): number | undefined {
return isNotSupportTemperatureAndTopP(model) ? undefined : assistant.settings?.topP
}
/**
* 获取超时设置
*/
export function getTimeout(model: Model): number {
if (isSupportedFlexServiceTier(model)) {
return 15 * 1000 * 60
}
return defaultTimeout
}
/**
* 构建系统提示词
*/
export async function buildSystemPromptWithTools(
prompt: string,
mcpTools?: MCPTool[],
assistant?: Assistant
): Promise<string> {
return await buildSystemPrompt(prompt, mcpTools, assistant)
}
// /**
// * 转换 MCP 工具为 AI SDK 工具格式
// * 注意:这里返回通用格式,实际使用时需要根据具体 provider 转换
// TODO: 需要使用ai-sdk的mcp
// */
// export function convertMcpToolsToSdkTools(mcpTools: MCPTool[]): Pick<StreamTextParams, 'tools'> {
// return mcpTools.map((tool) => ({
// type: 'function',
// function: {
// name: tool.id,
// description: tool.description,
// parameters: tool.inputSchema || {}
// }
// }))
// }
/**
* 提取文件内容
*/
export async function extractFileContent(message: Message): Promise<string> {
const fileBlocks = findFileBlocks(message)
if (fileBlocks.length > 0) {
const textFileBlocks = fileBlocks.filter(
(fb) => fb.file && [FileTypes.TEXT, FileTypes.DOCUMENT].includes(fb.file.type)
)
if (textFileBlocks.length > 0) {
let text = ''
const divider = '\n\n---\n\n'
for (const fileBlock of textFileBlocks) {
const file = fileBlock.file
const fileContent = (await window.api.file.read(file.id + file.ext)).trim()
const fileNameRow = 'file: ' + file.origin_name + '\n\n'
text = text + fileNameRow + fileContent + divider
}
return text
}
}
return ''
}
/**
* 转换消息为 AI SDK 参数格式
* 基于 OpenAI 格式的通用转换,支持文本、图片和文件
*/
export async function convertMessageToSdkParam(message: Message, isVisionModel = false): Promise<any> {
const content = getMainTextContent(message)
const fileBlocks = findFileBlocks(message)
const imageBlocks = findImageBlocks(message)
// 简单消息(无文件无图片)
if (fileBlocks.length === 0 && imageBlocks.length === 0) {
return {
role: message.role === 'system' ? 'user' : message.role,
content
}
}
// 复杂消息(包含文件或图片)
const parts: any[] = []
if (content) {
parts.push({ type: 'text', text: content })
}
// 处理图片(仅在支持视觉的模型中)
if (isVisionModel) {
for (const imageBlock of imageBlocks) {
if (imageBlock.file) {
try {
const image = await window.api.file.base64Image(imageBlock.file.id + imageBlock.file.ext)
parts.push({
type: 'image_url',
image_url: { url: image.data }
})
} catch (error) {
console.warn('Failed to load image:', error)
}
} else if (imageBlock.url && imageBlock.url.startsWith('data:')) {
parts.push({
type: 'image_url',
image_url: { url: imageBlock.url }
})
}
}
}
// 处理文件
for (const fileBlock of fileBlocks) {
const file = fileBlock.file
if (!file) continue
if ([FileTypes.TEXT, FileTypes.DOCUMENT].includes(file.type)) {
try {
const fileContent = await window.api.file.read(file.id + file.ext)
parts.push({
type: 'text',
text: `${file.origin_name}\n${fileContent.trim()}`
})
} catch (error) {
console.warn('Failed to read file:', error)
}
}
}
return {
role: message.role === 'system' ? 'user' : message.role,
content: parts.length === 1 && parts[0].type === 'text' ? parts[0].text : parts
}
}
/**
* 转换 Cherry Studio 消息数组为 AI SDK 消息数组
*/
export async function convertMessagesToSdkMessages(
messages: Message[],
model: Model
): Promise<StreamTextParams['messages']> {
const sdkMessages: StreamTextParams['messages'] = []
const isVision = model.id.includes('vision') || model.id.includes('gpt-4') // 简单的视觉模型检测
for (const message of messages) {
const sdkMessage = await convertMessageToSdkParam(message, isVision)
sdkMessages.push(sdkMessage)
}
return sdkMessages
}
/**
* 构建 AI SDK 流式参数
* 这是主要的参数构建函数,整合所有转换逻辑
*/
export async function buildStreamTextParams(
sdkMessages: StreamTextParams['messages'],
assistant: Assistant,
options: {
mcpTools?: MCPTool[]
enableTools?: boolean
requestOptions?: {
signal?: AbortSignal
timeout?: number
headers?: Record<string, string>
}
} = {}
): Promise<{ params: StreamTextParams; modelId: string }> {
const { mcpTools, enableTools = false } = options
const model = assistant.model || getDefaultModel()
const { maxTokens, reasoning_effort } = getAssistantSettings(assistant)
const enableReasoning =
((isSupportedThinkingTokenModel(model) || isSupportedReasoningEffortModel(model)) &&
reasoning_effort !== undefined) ||
(isReasoningModel(model) && (!isSupportedThinkingTokenModel(model) || !isSupportedReasoningEffortModel(model)))
const enableWebSearch =
(assistant.enableWebSearch && isWebSearchModel(model)) ||
isOpenRouterBuiltInWebSearchModel(model) ||
model.id.includes('sonar') ||
false
const enableGenerateImage =
isGenerateImageModel(model) &&
(isSupportedDisableGenerationModel(model) ? assistant.enableGenerateImage || false : true)
// 构建系统提示
let systemPrompt = assistant.prompt || ''
if (mcpTools && mcpTools.length > 0) {
systemPrompt = await buildSystemPromptWithTools(systemPrompt, mcpTools, assistant)
}
// 构建基础参数
const params: StreamTextParams = {
messages: sdkMessages,
maxTokens: maxTokens || 1000,
temperature: getTemperature(assistant, model),
topP: getTopP(assistant, model),
system: systemPrompt || undefined,
abortSignal: options.requestOptions?.signal,
headers: options.requestOptions?.headers,
// 随便填着,后面再改
providerOptions: {
reasoning: {
enabled: enableReasoning
},
webSearch: {
enabled: enableWebSearch
},
generateImage: {
enabled: enableGenerateImage
}
},
...getCustomParameters(assistant)
}
// 添加工具(如果启用且有工具)
if (enableTools && mcpTools && mcpTools.length > 0) {
// TODO: 暂时注释掉工具支持,等类型问题解决后再启用
// params.tools = convertMcpToolsToSdkTools(mcpTools)
}
return { params, modelId: model.id }
}
/**
* 构建非流式的 generateText 参数
*/
export async function buildGenerateTextParams(
messages: CoreMessage[],
assistant: Assistant,
options: {
mcpTools?: MCPTool[]
enableTools?: boolean
} = {}
): Promise<any> {
// 复用流式参数的构建逻辑
return await buildStreamTextParams(messages, assistant, options)
}
/**
* 获取自定义参数
* 从 assistant 设置中提取自定义参数
*/
export function getCustomParameters(assistant: Assistant): Record<string, any> {
return (
assistant?.settings?.customParameters?.reduce((acc, param) => {
if (!param.name?.trim()) {
return acc
}
if (param.type === 'json') {
const value = param.value as string
if (value === 'undefined') {
return { ...acc, [param.name]: undefined }
}
try {
return { ...acc, [param.name]: JSON.parse(value) }
} catch {
return { ...acc, [param.name]: value }
}
}
return {
...acc,
[param.name]: param.value
}
}, {}) || {}
)
}

View File

@@ -1,8 +1 @@
<svg width="22" height="22" viewBox="13 -2 25 22" xmlns="http://www.w3.org/2000/svg">
<g id="White=False">
<g id="if">
<path d="M21.2002 3.73454C22.5633 3.73454 23.0666 2.89917 23.0666 1.86812C23.0666 0.837081 22.5623 0.00170898 21.2002 0.00170898C19.838 0.00170898 19.3337 0.837081 19.3337 1.86812C19.3337 2.89917 19.838 3.73454 21.2002 3.73454Z" fill="#0033FF"/>
<path d="M27.7336 4.13435V5.33473H24.6668V8.00171H27.7336V14.6687H22.6668V5.33567H15.9998V8.00265H19.7336V14.6696H15.3337V17.3366H35.3337V14.6696H30.6668V8.00265H35.3337V5.33567H30.6668V2.66869H35.3337V0.00170898H31.8671C29.5877 0.00170898 27.7336 1.8559 27.7336 4.13529V4.13435Z" fill="#0033FF"/>
</g>
</g>
</svg>
<svg fill="currentColor" fill-rule="evenodd" height="1em" style="flex:none;line-height:1" viewBox="0 0 24 24" width="1em" xmlns="http://www.w3.org/2000/svg"><title>Dify</title><clipPath id="lobe-icons-dify-fill"><path d="M1 0h10.286c6.627 0 12 5.373 12 12s-5.373 12-12 12H1V0z"></path></clipPath><foreignObject clip-path="url(#lobe-icons-dify-fill)" height="24" style="background:conic-gradient(from 180deg at 50% 50%, #0222C3, #8FB1F4, #FFFFFF)" width="24"></foreignObject></svg>

Before

Width:  |  Height:  |  Size: 680 B

After

Width:  |  Height:  |  Size: 480 B

Binary file not shown.

Before

Width:  |  Height:  |  Size: 20 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 16 KiB

View File

@@ -1 +0,0 @@
<svg height="1em" style="flex:none;line-height:1" viewBox="0 0 24 24" width="1em" xmlns="http://www.w3.org/2000/svg"><title>VertexAI</title><path d="M11.995 20.216a1.892 1.892 0 100 3.785 1.892 1.892 0 000-3.785zm0 2.806a.927.927 0 11.927-.914.914.914 0 01-.927.914z" fill="#4285F4"></path><path clip-rule="evenodd" d="M21.687 14.144c.237.038.452.16.605.344a.978.978 0 01-.18 1.3l-8.24 6.082a1.892 1.892 0 00-1.147-1.508l8.28-6.08a.991.991 0 01.682-.138z" fill="#669DF6" fill-rule="evenodd"></path><path clip-rule="evenodd" d="M10.122 21.842l-8.217-6.066a.952.952 0 01-.206-1.287.978.978 0 011.287-.206l8.28 6.08a1.893 1.893 0 00-1.144 1.479z" fill="#AECBFA" fill-rule="evenodd"></path><path d="M4.273 4.475a.978.978 0 01-.965-.965V1.09a.978.978 0 111.943 0v2.42a.978.978 0 01-.978.965zM4.247 13.034a.978.978 0 100-1.956.978.978 0 000 1.956zM4.247 10.19a.978.978 0 100-1.956.978.978 0 000 1.956zM4.247 7.332a.978.978 0 100-1.956.978.978 0 000 1.956z" fill="#AECBFA"></path><path d="M19.718 7.307a.978.978 0 01-.965-.979v-2.42a.965.965 0 011.93 0v2.42a.964.964 0 01-.965.979zM19.743 13.047a.978.978 0 100-1.956.978.978 0 000 1.956zM19.743 10.151a.978.978 0 100-1.956.978.978 0 000 1.956zM19.743 2.068a.978.978 0 100-1.956.978.978 0 000 1.956z" fill="#4285F4"></path><path d="M11.995 15.917a.978.978 0 01-.965-.965v-2.459a.978.978 0 011.943 0v2.433a.976.976 0 01-.978.991zM11.995 18.762a.978.978 0 100-1.956.978.978 0 000 1.956zM11.995 10.64a.978.978 0 100-1.956.978.978 0 000 1.956zM11.995 7.783a.978.978 0 100-1.956.978.978 0 000 1.956z" fill="#669DF6"></path><path d="M15.856 10.177a.978.978 0 01-.965-.965v-2.42a.977.977 0 011.702-.763.979.979 0 01.241.763v2.42a.978.978 0 01-.978.965zM15.869 4.913a.978.978 0 100-1.956.978.978 0 000 1.956zM15.869 15.853a.978.978 0 100-1.956.978.978 0 000 1.956zM15.869 12.996a.978.978 0 100-1.956.978.978 0 000 1.956z" fill="#4285F4"></path><path d="M8.121 15.853a.978.978 0 100-1.956.978.978 0 000 1.956zM8.121 7.783a.978.978 0 100-1.956.978.978 0 000 1.956zM8.121 4.913a.978.978 0 100-1.957.978.978 0 000 1.957zM8.134 12.996a.978.978 0 01-.978-.94V9.611a.965.965 0 011.93 0v2.445a.966.966 0 01-.952.94z" fill="#AECBFA"></path></svg>

Before

Width:  |  Height:  |  Size: 2.1 KiB

View File

@@ -197,26 +197,11 @@
}
}
.ant-dropdown {
.ant-dropdown-menu {
max-height: 50vh;
overflow-y: auto;
border: 0.5px solid var(--color-border);
.ant-dropdown-menu-sub {
max-height: 50vh;
width: max-content;
overflow-y: auto;
overflow-x: hidden;
border: 0.5px solid var(--color-border);
}
}
.ant-dropdown-arrow + .ant-dropdown-menu {
border: none;
}
}
.ant-select-dropdown {
border: 0.5px solid var(--color-border);
.ant-dropdown-menu .ant-dropdown-menu-sub {
max-height: 350px;
width: max-content;
overflow-y: auto;
overflow-x: hidden;
}
.ant-collapse {

View File

@@ -136,10 +136,6 @@ ul {
display: flow-root;
}
.block-wrapper:last-child > *:last-child {
margin-bottom: 0;
}
.message-content-container > *:last-child {
margin-bottom: 0;
}

View File

@@ -321,7 +321,6 @@ mjx-container {
.cm-gutters {
line-height: 1.6;
border-right: none;
}
.cm-content {

View File

@@ -22,7 +22,6 @@ const MermaidPreview: React.FC<Props> = ({ children, setTools }) => {
const diagramId = useRef<string>(`mermaid-${nanoid(6)}`).current
const [error, setError] = useState<string | null>(null)
const [isRendering, setIsRendering] = useState(false)
const [isVisible, setIsVisible] = useState(true)
// 使用通用图像工具
const { handleZoom, handleCopyImage, handleDownload } = usePreviewToolHandlers(mermaidRef, {
@@ -76,55 +75,10 @@ const MermaidPreview: React.FC<Props> = ({ children, setTools }) => {
[renderMermaid]
)
/**
* 监听可见性变化,用于触发重新渲染。
* 这是为了解决 `MessageGroup` 组件的 `fold` 布局中被 `display: none` 隐藏的图标无法正确渲染的问题。
* 监听时向上遍历到第一个有 `fold` className 的父节点为止(也就是目前的 `MessageWrapper`)。
* FIXME: 将来 mermaid-js 修复此问题后可以移除这里的相关逻辑。
*/
useEffect(() => {
if (!mermaidRef.current) return
const checkVisibility = () => {
const element = mermaidRef.current
if (!element) return
const currentlyVisible = element.offsetParent !== null
setIsVisible(currentlyVisible)
}
// 初始检查
checkVisibility()
const observer = new MutationObserver(() => {
checkVisibility()
})
let targetElement = mermaidRef.current.parentElement
while (targetElement) {
observer.observe(targetElement, {
attributes: true,
attributeFilter: ['class', 'style']
})
if (targetElement.className?.includes('fold')) {
break
}
targetElement = targetElement.parentElement
}
return () => {
observer.disconnect()
}
}, [])
// 触发渲染
useEffect(() => {
if (isLoadingMermaid) return
if (mermaidRef.current?.offsetParent === null) return
if (children) {
setIsRendering(true)
debouncedRender(children)
@@ -136,7 +90,7 @@ const MermaidPreview: React.FC<Props> = ({ children, setTools }) => {
return () => {
debouncedRender.cancel()
}
}, [children, isLoadingMermaid, debouncedRender, isVisible])
}, [children, isLoadingMermaid, debouncedRender])
const isLoading = isLoadingMermaid || isRendering

Some files were not shown because too many files have changed in this diff Show More