Merge branch 'main' into feat-knowlege-ocr

This commit is contained in:
eeee0717
2025-06-13 21:10:12 +08:00
107 changed files with 9655 additions and 5716 deletions
-1
View File
@@ -7,7 +7,6 @@
"request": "launch",
"cwd": "${workspaceRoot}",
"runtimeExecutable": "${workspaceRoot}/node_modules/.bin/electron-vite",
"runtimeVersion": "20",
"windows": {
"runtimeExecutable": "${workspaceRoot}/node_modules/.bin/electron-vite.cmd"
},
+214
View File
@@ -0,0 +1,214 @@
# 如何为 AI Provider 编写中间件
本文档旨在指导开发者如何为我们的 AI Provider 框架创建和集成自定义中间件。中间件提供了一种强大而灵活的方式来增强、修改或观察 Provider 方法的调用过程,例如日志记录、缓存、请求/响应转换、错误处理等。
## 架构概览
我们的中间件架构借鉴了 Redux 的三段式设计,并结合了 JavaScript Proxy 来动态地将中间件应用于 Provider 的方法。
- **Proxy**: 拦截对 Provider 方法的调用,并将调用引导至中间件链。
- **中间件链**: 一系列按顺序执行的中间件函数。每个中间件都可以处理请求/响应,然后将控制权传递给链中的下一个中间件,或者在某些情况下提前终止链。
- **上下文 (Context)**: 一个在中间件之间传递的对象,携带了关于当前调用的信息(如方法名、原始参数、Provider 实例、以及中间件自定义的数据)。
## 中间件的类型
目前主要支持两种类型的中间件,它们共享相似的结构但针对不同的场景:
1. **`CompletionsMiddleware`**: 专门为 `completions` 方法设计。这是最常用的中间件类型,因为它允许对 AI 模型的核心聊天/文本生成功能进行精细控制。
2. **`ProviderMethodMiddleware`**: 通用中间件,可以应用于 Provider 上的任何其他方法(例如,`translate`, `summarize` 等,如果这些方法也通过中间件系统包装)。
## 编写一个 `CompletionsMiddleware`
`CompletionsMiddleware` 的基本签名(TypeScript 类型)如下:
```typescript
import { AiProviderMiddlewareCompletionsContext, CompletionsParams, MiddlewareAPI } from './AiProviderMiddlewareTypes' // 假设类型定义文件路径
export type CompletionsMiddleware = (
api: MiddlewareAPI<AiProviderMiddlewareCompletionsContext, [CompletionsParams]>
) => (
next: (context: AiProviderMiddlewareCompletionsContext, params: CompletionsParams) => Promise<any> // next 返回 Promise<any> 代表原始SDK响应或下游中间件的结果
) => (context: AiProviderMiddlewareCompletionsContext, params: CompletionsParams) => Promise<void> // 最内层函数通常返回 Promise<void>,因为结果通过 onChunk 或 context 副作用传递
```
让我们分解这个三段式结构:
1. **第一层函数 `(api) => { ... }`**:
- 接收一个 `api` 对象。
- `api` 对象提供了以下方法:
- `api.getContext()`: 获取当前调用的上下文对象 (`AiProviderMiddlewareCompletionsContext`)。
- `api.getOriginalArgs()`: 获取传递给 `completions` 方法的原始参数数组 (即 `[CompletionsParams]`)。
- `api.getProviderId()`: 获取当前 Provider 的 ID。
- `api.getProviderInstance()`: 获取原始的 Provider 实例。
- 此函数通常用于进行一次性的设置或获取所需的服务/配置。它返回第二层函数。
2. **第二层函数 `(next) => { ... }`**:
- 接收一个 `next` 函数。
- `next` 函数代表了中间件链中的下一个环节。调用 `next(context, params)` 会将控制权传递给下一个中间件,或者如果当前中间件是链中的最后一个,则会调用核心的 Provider 方法逻辑 (例如,实际的 SDK 调用)。
- `next` 函数接收当前的 `context``params` (这些可能已被上游中间件修改)。
- **重要的是**`next` 的返回类型通常是 `Promise<any>`。对于 `completions` 方法,如果 `next` 调用了实际的 SDK,它将返回原始的 SDK 响应(例如,OpenAI 的流对象或 JSON 对象)。你需要处理这个响应。
- 此函数返回第三层(也是最核心的)函数。
3. **第三层函数 `(context, params) => { ... }`**:
- 这是执行中间件主要逻辑的地方。
- 它接收当前的 `context` (`AiProviderMiddlewareCompletionsContext`) 和 `params` (`CompletionsParams`)。
- 在此函数中,你可以:
- **在调用 `next` 之前**:
- 读取或修改 `params`。例如,添加默认参数、转换消息格式。
- 读取或修改 `context`。例如,设置一个时间戳用于后续计算延迟。
- 执行某些检查,如果不满足条件,可以不调用 `next` 而直接返回或抛出错误(例如,参数校验失败)。
- **调用 `await next(context, params)`**:
- 这是将控制权传递给下游的关键步骤。
- `next` 的返回值是原始的 SDK 响应或下游中间件的结果,你需要根据情况处理它(例如,如果是流,则开始消费流)。
- **在调用 `next` 之后**:
- 处理 `next` 的返回结果。例如,如果 `next` 返回了一个流,你可以在这里开始迭代处理这个流,并通过 `context.onChunk` 发送数据块。
- 基于 `context` 的变化或 `next` 的结果执行进一步操作。例如,计算总耗时、记录日志。
- 修改最终结果(尽管对于 `completions`,结果通常通过 `onChunk` 副作用发出)。
### 示例:一个简单的日志中间件
```typescript
import {
AiProviderMiddlewareCompletionsContext,
CompletionsParams,
MiddlewareAPI,
OnChunkFunction // 假设 OnChunkFunction 类型被导出
} from './AiProviderMiddlewareTypes' // 调整路径
import { ChunkType } from '@renderer/types' // 调整路径
export const createSimpleLoggingMiddleware = (): CompletionsMiddleware => {
return (api: MiddlewareAPI<AiProviderMiddlewareCompletionsContext, [CompletionsParams]>) => {
// console.log(`[LoggingMiddleware] Initialized for provider: ${api.getProviderId()}`);
return (next: (context: AiProviderMiddlewareCompletionsContext, params: CompletionsParams) => Promise<any>) => {
return async (context: AiProviderMiddlewareCompletionsContext, params: CompletionsParams): Promise<void> => {
const startTime = Date.now()
// 从 context 中获取 onChunk (它最初来自 params.onChunk)
const onChunk = context.onChunk
console.log(
`[LoggingMiddleware] Request for ${context.methodName} with params:`,
params.messages?.[params.messages.length - 1]?.content
)
try {
// 调用下一个中间件或核心逻辑
// `rawSdkResponse` 是来自下游的原始响应 (例如 OpenAIStream 或 ChatCompletion 对象)
const rawSdkResponse = await next(context, params)
// 此处简单示例不处理 rawSdkResponse,假设下游中间件 (如 StreamingResponseHandler)
// 会处理它并通过 onChunk 发送数据。
// 如果这个日志中间件在 StreamingResponseHandler 之后,那么流已经被处理。
// 如果在之前,那么它需要自己处理 rawSdkResponse 或确保下游会处理。
const duration = Date.now() - startTime
console.log(`[LoggingMiddleware] Request for ${context.methodName} completed in ${duration}ms.`)
// 假设下游已经通过 onChunk 发送了所有数据。
// 如果这个中间件是链的末端,并且需要确保 BLOCK_COMPLETE 被发送,
// 它可能需要更复杂的逻辑来跟踪何时所有数据都已发送。
} catch (error) {
const duration = Date.now() - startTime
console.error(`[LoggingMiddleware] Request for ${context.methodName} failed after ${duration}ms:`, error)
// 如果 onChunk 可用,可以尝试发送一个错误块
if (onChunk) {
onChunk({
type: ChunkType.ERROR,
error: { message: (error as Error).message, name: (error as Error).name, stack: (error as Error).stack }
})
// 考虑是否还需要发送 BLOCK_COMPLETE 来结束流
onChunk({ type: ChunkType.BLOCK_COMPLETE, response: {} })
}
throw error // 重新抛出错误,以便上层或全局错误处理器可以捕获
}
}
}
}
}
```
### `AiProviderMiddlewareCompletionsContext` 的重要性
`AiProviderMiddlewareCompletionsContext` 是在中间件之间传递状态和数据的核心。它通常包含:
- `methodName`: 当前调用的方法名 (总是 `'completions'`)。
- `originalArgs`: 传递给 `completions` 的原始参数数组。
- `providerId`: Provider 的 ID。
- `_providerInstance`: Provider 实例。
- `onChunk`: 从原始 `CompletionsParams` 传入的回调函数,用于流式发送数据块。**所有中间件都应该通过 `context.onChunk` 来发送数据。**
- `messages`, `model`, `assistant`, `mcpTools`: 从原始 `CompletionsParams` 中提取的常用字段,方便访问。
- **自定义字段**: 中间件可以向上下文中添加自定义字段,以供后续中间件使用。例如,一个缓存中间件可能会添加 `context.cacheHit = true`
**关键**: 当你在中间件中修改 `params``context` 时,这些修改会向下游中间件传播(如果它们在 `next` 调用之前修改)。
### 中间件的顺序
中间件的执行顺序非常重要。它们在 `AiProviderMiddlewareConfig` 的数组中定义的顺序就是它们的执行顺序。
- 请求首先通过第一个中间件,然后是第二个,依此类推。
- 响应(或 `next` 的调用结果)则以相反的顺序"冒泡"回来。
例如,如果链是 `[AuthMiddleware, CacheMiddleware, LoggingMiddleware]`
1. `AuthMiddleware` 先执行其 "调用 `next` 之前" 的逻辑。
2. 然后 `CacheMiddleware` 执行其 "调用 `next` 之前" 的逻辑。
3. 然后 `LoggingMiddleware` 执行其 "调用 `next` 之前" 的逻辑。
4. 核心SDK调用(或链的末端)。
5. `LoggingMiddleware` 先接收到结果,执行其 "调用 `next` 之后" 的逻辑。
6. 然后 `CacheMiddleware` 接收到结果(可能已被 LoggingMiddleware 修改的上下文),执行其 "调用 `next` 之后" 的逻辑(例如,存储结果)。
7. 最后 `AuthMiddleware` 接收到结果,执行其 "调用 `next` 之后" 的逻辑。
### 注册中间件
中间件在 `src/renderer/src/providers/middleware/register.ts` (或其他类似的配置文件) 中进行注册。
```typescript
// register.ts
import { AiProviderMiddlewareConfig } from './AiProviderMiddlewareTypes'
import { createSimpleLoggingMiddleware } from './common/SimpleLoggingMiddleware' // 假设你创建了这个文件
import { createCompletionsLoggingMiddleware } from './common/CompletionsLoggingMiddleware' // 已有的
const middlewareConfig: AiProviderMiddlewareConfig = {
completions: [
createSimpleLoggingMiddleware(), // 你新加的中间件
createCompletionsLoggingMiddleware() // 已有的日志中间件
// ... 其他 completions 中间件
],
methods: {
// translate: [createGenericLoggingMiddleware()],
// ... 其他方法的中间件
}
}
export default middlewareConfig
```
### 最佳实践
1. **单一职责**: 每个中间件应专注于一个特定的功能(例如,日志、缓存、转换特定数据)。
2. **无副作用 (尽可能)**: 除了通过 `context``onChunk` 明确的副作用外,尽量避免修改全局状态或产生其他隐蔽的副作用。
3. **错误处理**:
- 在中间件内部使用 `try...catch` 来处理可能发生的错误。
- 决定是自行处理错误(例如,通过 `onChunk` 发送错误块)还是将错误重新抛出给上游。
- 如果重新抛出,确保错误对象包含足够的信息。
4. **性能考虑**: 中间件会增加请求处理的开销。避免在中间件中执行非常耗时的同步操作。对于IO密集型操作,确保它们是异步的。
5. **可配置性**: 使中间件的行为可通过参数或配置进行调整。例如,日志中间件可以接受一个日志级别参数。
6. **上下文管理**:
- 谨慎地向 `context` 添加数据。避免污染 `context` 或添加过大的对象。
- 明确你添加到 `context` 的字段的用途和生命周期。
7. **`next` 的调用**:
- 除非你有充分的理由提前终止请求(例如,缓存命中、授权失败),否则**总是确保调用 `await next(context, params)`**。否则,下游的中间件和核心逻辑将不会执行。
- 理解 `next` 的返回值并正确处理它,特别是当它是一个流时。你需要负责消费这个流或将其传递给另一个能够消费它的组件/中间件。
8. **命名清晰**: 给你的中间件和它们创建的函数起描述性的名字。
9. **文档和注释**: 对复杂的中间件逻辑添加注释,解释其工作原理和目的。
### 调试技巧
- 在中间件的关键点使用 `console.log` 或调试器来检查 `params``context` 的状态以及 `next` 的返回值。
- 暂时简化中间件链,只保留你正在调试的中间件和最简单的核心逻辑,以隔离问题。
- 编写单元测试来独立验证每个中间件的行为。
通过遵循这些指南,你应该能够有效地为我们的系统创建强大且可维护的中间件。如果你有任何疑问或需要进一步的帮助,请咨询团队。
+8 -5
View File
@@ -113,8 +113,11 @@ afterSign: scripts/notarize.js
artifactBuildCompleted: scripts/artifact-build-completed.js
releaseInfo:
releaseNotes: |
新增划词助手
助手支持分组
支持主题颜色切换
划词助手支持应用过滤
翻译模块功能改进
划词助手:支持文本选择快捷键、开关快捷键、思考块支持和引用功能
复制功能:新增纯文本复制(去除Markdown格式符号)
知识库:支持设置向量维度,修复Ollama分数错误和维度编辑问题
多语言:增加模型名称多语言提示和翻译源语言手动选择
文件管理:修复主题/消息删除时文件未清理问题,优化文件选择流程
模型:修复Gemini模型推理预算、Voyage AI嵌入问题和DeepSeek翻译模型更新
图像功能:统一图片查看器,支持Base64图片渲染,修复图片预览相关问题
UI:实现标签折叠/拖拽排序,修复气泡溢出,增加引文索引显示
+46 -45
View File
@@ -58,6 +58,23 @@
"prepare": "husky"
},
"dependencies": {
"@libsql/client": "0.14.0",
"@libsql/win32-x64-msvc": "^0.4.7",
"@strongtz/win32-arm64-msvc": "^0.4.7",
"canvas": "3.1.0",
"jsdom": "26.1.0",
"os-proxy-config": "^1.1.2",
"pdf-to-img": "^4.4.0",
"pdfjs-dist": "4.2.67",
"selection-hook": "^0.9.23",
"turndown": "7.2.0"
},
"devDependencies": {
"@agentic/exa": "^7.3.3",
"@agentic/searxng": "^7.3.3",
"@agentic/tavily": "^7.3.3",
"@ant-design/v5-patch-for-react-19": "^1.0.3",
"@anthropic-ai/sdk": "^0.41.0",
"@cherrystudio/embedjs": "^0.1.31",
"@cherrystudio/embedjs-libsql": "^0.1.31",
"@cherrystudio/embedjs-loader-csv": "^0.1.31",
@@ -70,54 +87,11 @@
"@cherrystudio/embedjs-loader-xml": "^0.1.31",
"@cherrystudio/embedjs-ollama": "^0.1.31",
"@cherrystudio/embedjs-openai": "^0.1.31",
"@electron-toolkit/utils": "^3.0.0",
"@langchain/community": "^0.3.36",
"@langchain/ollama": "^0.2.1",
"@mistralai/mistralai": "^1.6.0",
"@strongtz/win32-arm64-msvc": "^0.4.7",
"@tanstack/react-query": "^5.27.0",
"@types/react-infinite-scroll-component": "^5.0.0",
"archiver": "^7.0.1",
"async-mutex": "^0.5.0",
"canvas": "3.1.0",
"diff": "^7.0.0",
"docx": "^9.0.2",
"electron-log": "^5.1.5",
"electron-store": "^8.2.0",
"electron-updater": "6.6.4",
"electron-window-state": "^5.0.3",
"epub": "patch:epub@npm%3A1.3.0#~/.yarn/patches/epub-npm-1.3.0-8325494ffe.patch",
"fast-xml-parser": "^5.2.0",
"franc-min": "^6.2.0",
"fs-extra": "^11.2.0",
"jsdom": "^26.0.0",
"markdown-it": "^14.1.0",
"node-stream-zip": "^1.15.0",
"officeparser": "^4.1.1",
"os-proxy-config": "^1.1.2",
"pdf-to-img": "^4.4.0",
"pdfjs-dist": "4.2.67",
"proxy-agent": "^6.5.0",
"remove-markdown": "^0.6.2",
"selection-hook": "^0.9.23",
"tar": "^7.4.3",
"turndown": "^7.2.0",
"webdav": "^5.8.0",
"zipread": "^1.3.3"
},
"optionalDependencies": {
"@cherrystudio/mac-system-ocr": "^0.2.2"
},
"devDependencies": {
"@agentic/exa": "^7.3.3",
"@agentic/searxng": "^7.3.3",
"@agentic/tavily": "^7.3.3",
"@ant-design/v5-patch-for-react-19": "^1.0.3",
"@anthropic-ai/sdk": "^0.41.0",
"@electron-toolkit/eslint-config-prettier": "^3.0.0",
"@electron-toolkit/eslint-config-ts": "^3.0.0",
"@electron-toolkit/preload": "^3.0.0",
"@electron-toolkit/tsconfig": "^1.0.1",
"@electron-toolkit/utils": "^3.0.0",
"@electron/notarize": "^2.5.0",
"@emotion/is-prop-valid": "^1.3.1",
"@eslint-react/eslint-plugin": "^1.36.1",
@@ -125,6 +99,9 @@
"@google/genai": "^1.0.1",
"@hello-pangea/dnd": "^16.6.0",
"@kangfenmao/keyv-storage": "^0.1.0",
"@langchain/community": "^0.3.36",
"@langchain/ollama": "^0.2.1",
"@mistralai/mistralai": "^1.6.0",
"@modelcontextprotocol/sdk": "^1.11.4",
"@mozilla/readability": "^0.6.0",
"@notionhq/client": "^2.2.15",
@@ -132,6 +109,7 @@
"@reduxjs/toolkit": "^2.2.5",
"@shikijs/markdown-it": "^3.4.2",
"@swc/plugin-styled-components": "^7.1.5",
"@tanstack/react-query": "^5.27.0",
"@testing-library/dom": "^10.4.0",
"@testing-library/jest-dom": "^6.6.3",
"@testing-library/react": "^16.3.0",
@@ -158,24 +136,36 @@
"@vitest/web-worker": "^3.1.4",
"@xyflow/react": "^12.4.4",
"antd": "^5.22.5",
"archiver": "^7.0.1",
"async-mutex": "^0.5.0",
"axios": "^1.7.3",
"browser-image-compression": "^2.0.2",
"color": "^5.0.0",
"dayjs": "^1.11.11",
"dexie": "^4.0.8",
"dexie-react-hooks": "^1.1.7",
"diff": "^7.0.0",
"docx": "^9.0.2",
"dotenv-cli": "^7.4.2",
"electron": "35.4.0",
"electron-builder": "26.0.15",
"electron-devtools-installer": "^3.2.0",
"electron-log": "^5.1.5",
"electron-store": "^8.2.0",
"electron-updater": "6.6.4",
"electron-vite": "^3.1.0",
"electron-window-state": "^5.0.3",
"emittery": "^1.0.3",
"emoji-picker-element": "^1.22.1",
"epub": "patch:epub@npm%3A1.3.0#~/.yarn/patches/epub-npm-1.3.0-8325494ffe.patch",
"eslint": "^9.22.0",
"eslint-plugin-react-hooks": "^5.2.0",
"eslint-plugin-simple-import-sort": "^12.1.1",
"eslint-plugin-unused-imports": "^4.1.4",
"fast-diff": "^1.3.0",
"fast-xml-parser": "^5.2.0",
"franc-min": "^6.2.0",
"fs-extra": "^11.2.0",
"html-to-image": "^1.11.13",
"husky": "^9.1.7",
"i18next": "^23.11.5",
@@ -184,14 +174,18 @@
"lodash": "^4.17.21",
"lru-cache": "^11.1.0",
"lucide-react": "^0.487.0",
"markdown-it": "^14.1.0",
"mermaid": "^11.6.0",
"mime": "^4.0.4",
"motion": "^12.10.5",
"node-stream-zip": "^1.15.0",
"npx-scope-finder": "^1.2.0",
"officeparser": "^4.1.1",
"openai": "patch:openai@npm%3A5.1.0#~/.yarn/patches/openai-npm-5.1.0-0e7b3ccb07.patch",
"p-queue": "^8.1.0",
"playwright": "^1.52.0",
"prettier": "^3.5.3",
"proxy-agent": "^6.5.0",
"rc-virtual-list": "^3.18.6",
"react": "^19.0.0",
"react-dom": "^19.0.0",
@@ -212,17 +206,24 @@
"remark-cjk-friendly": "^1.1.0",
"remark-gfm": "^4.0.0",
"remark-math": "^6.0.0",
"remove-markdown": "^0.6.2",
"rollup-plugin-visualizer": "^5.12.0",
"sass": "^1.88.0",
"shiki": "^3.4.2",
"string-width": "^7.2.0",
"styled-components": "^6.1.11",
"tar": "^7.4.3",
"tiny-pinyin": "^1.3.2",
"tokenx": "^0.4.1",
"typescript": "^5.6.2",
"uuid": "^10.0.0",
"vite": "6.2.6",
"vitest": "^3.1.4"
"vitest": "^3.1.4",
"webdav": "^5.8.0",
"zipread": "^1.3.3"
},
"optionalDependencies": {
"@cherrystudio/mac-system-ocr": "^0.2.2"
},
"resolutions": {
"pdf-parse@npm:1.1.1": "patch:pdf-parse@npm%3A1.1.1#~/.yarn/patches/pdf-parse-npm-1.1.1-04a6109b2a.patch",
+1
View File
@@ -408,3 +408,4 @@ export enum FeedUrl {
PRODUCTION = 'https://releases.cherry-ai.com',
EARLY_ACCESS = 'https://github.com/CherryHQ/cherry-studio/releases/latest/download'
}
export const defaultTimeout = 5 * 1000 * 60
+6 -1
View File
@@ -5,7 +5,8 @@ import { FeedUrl } from '@shared/config/constant'
import { UpdateInfo } from 'builder-util-runtime'
import { app, BrowserWindow, dialog } from 'electron'
import logger from 'electron-log'
import { AppUpdater as _AppUpdater, autoUpdater } from 'electron-updater'
import { AppUpdater as _AppUpdater, autoUpdater, NsisUpdater } from 'electron-updater'
import path from 'path'
import icon from '../../../build/icon.png?asset'
import { configManager } from './ConfigManager'
@@ -56,6 +57,10 @@ export default class AppUpdater {
logger.info('下载完成', releaseInfo)
})
if (isWin) {
;(autoUpdater as NsisUpdater).installDirectory = path.dirname(app.getPath('exe'))
}
this.autoUpdater = autoUpdater
}
+223
View File
@@ -0,0 +1,223 @@
# Cherry Studio AI Provider 技术架构文档 (新方案)
## 1. 核心设计理念与目标
本架构旨在重构 Cherry Studio 的 AI Provider(现称为 `aiCore`)层,以实现以下目标:
- **职责清晰**:明确划分各组件的职责,降低耦合度。
- **高度复用**:最大化业务逻辑和通用处理逻辑的复用,减少重复代码。
- **易于扩展**:方便快捷地接入新的 AI Provider (LLM供应商) 和添加新的 AI 功能 (如翻译、摘要、图像生成等)。
- **易于维护**:简化单个组件的复杂性,提高代码的可读性和可维护性。
- **标准化**:统一内部数据流和接口,简化不同 Provider 之间的差异处理。
核心思路是将纯粹的 **SDK 适配层 (`XxxApiClient`)**、**通用逻辑处理与智能解析层 (中间件)** 以及 **统一业务功能入口层 (`AiCoreService`)** 清晰地分离开来。
## 2. 核心组件详解
### 2.1. `aiCore` (原 `AiProvider` 文件夹)
这是整个 AI 功能的核心模块。
#### 2.1.1. `XxxApiClient` (例如 `aiCore/clients/openai/OpenAIApiClient.ts`)
- **职责**:作为特定 AI Provider SDK 的纯粹适配层。
- **参数适配**:将应用内部统一的 `CoreRequest` 对象 (见下文) 转换为特定 SDK 所需的请求参数格式。
- **基础响应转换**:将 SDK 返回的原始数据块 (`RawSdkChunk`,例如 `OpenAI.Chat.Completions.ChatCompletionChunk`) 转换为一组最基础、最直接的应用层 `Chunk` 对象 (定义于 `src/renderer/src/types/chunk.ts`)。
- 例如:SDK 的 `delta.content` -> `TextDeltaChunk`SDK 的 `delta.reasoning_content` -> `ThinkingDeltaChunk`SDK 的 `delta.tool_calls` -> `RawToolCallChunk` (包含原始工具调用数据)。
- **关键**`XxxApiClient` **不处理**耦合在文本内容中的复杂结构,如 `<think>``<tool_use>` 标签。
- **特点**:极度轻量化,代码量少,易于实现和维护新的 Provider 适配。
#### 2.1.2. `ApiClient.ts` (或 `BaseApiClient.ts` 的核心接口)
- 定义了所有 `XxxApiClient` 必须实现的接口,如:
- `getSdkInstance(): Promise<TSdkInstance> | TSdkInstance`
- `getRequestTransformer(): RequestTransformer<TSdkParams>`
- `getResponseChunkTransformer(): ResponseChunkTransformer<TRawChunk, TResponseContext>`
- 其他可选的、与特定 Provider 相关的辅助方法 (如工具调用转换)。
#### 2.1.3. `ApiClientFactory.ts`
- 根据 Provider 配置动态创建和返回相应的 `XxxApiClient` 实例。
#### 2.1.4. `AiCoreService.ts` (`aiCore/index.ts`)
- **职责**:作为所有 AI 相关业务功能的统一入口。
- 提供面向应用的高层接口,例如:
- `executeCompletions(params: CompletionsParams): Promise<AggregatedCompletionsResult>`
- `translateText(params: TranslateParams): Promise<AggregatedTranslateResult>`
- `summarizeText(params: SummarizeParams): Promise<AggregatedSummarizeResult>`
- 未来可能的 `generateImage(prompt: string): Promise<ImageResult>` 等。
- **返回 `Promise`**:每个服务方法返回一个 `Promise`,该 `Promise` 会在整个(可能是流式的)操作完成后,以包含所有聚合结果(如完整文本、工具调用详情、最终的`usage`/`metrics`等)的对象来 `resolve`
- **支持流式回调**:服务方法的参数 (如 `CompletionsParams`) 依然包含 `onChunk` 回调,用于向调用方实时推送处理过程中的 `Chunk` 数据,实现流式UI更新。
- **封装特定任务的提示工程 (Prompt Engineering)**
- 例如,`translateText` 方法内部会构建一个包含特定翻译指令的 `CoreRequest`
- **编排和调用中间件链**:通过内部的 `MiddlewareBuilder` (参见 `middleware/BUILDER_USAGE.md`) 实例,根据调用的业务方法和参数,动态构建和组织合适的中间件序列,然后通过 `applyCompletionsMiddlewares` 等组合函数执行。
- 获取 `ApiClient` 实例并将其注入到中间件上游的 `Context` 中。
- **将 `Promise``resolve``reject` 函数传递给中间件链** (通过 `Context`),以便 `FinalChunkConsumerAndNotifierMiddleware` 可以在操作完成或发生错误时结束该 `Promise`
- **优势**
- 业务逻辑(如翻译、摘要的提示构建和流程控制)只需实现一次,即可支持所有通过 `ApiClient` 接入的底层 Provider。
- **支持外部编排**:调用方可以 `await` 服务方法以获取最终聚合结果,然后将此结果作为后续操作的输入,轻松实现多步骤工作流。
- **支持内部组合**:服务自身也可以通过 `await` 调用其他原子服务方法来构建更复杂的组合功能。
#### 2.1.5. `coreRequestTypes.ts` (或 `types.ts`)
- 定义核心的、Provider 无关的内部请求结构,例如:
- `CoreCompletionsRequest`: 包含标准化后的消息列表、模型配置、工具列表、最大Token数、是否流式输出等。
- `CoreTranslateRequest`, `CoreSummarizeRequest` 等 (如果与 `CoreCompletionsRequest` 结构差异较大,否则可复用并添加任务类型标记)。
### 2.2. `middleware`
中间件层负责处理请求和响应流中的通用逻辑和特定特性。其设计和使用遵循 `middleware/BUILDER_USAGE.md` 中定义的规范。
**核心组件包括:**
- **`MiddlewareBuilder`**: 一个通用的、提供流式API的类,用于动态构建中间件链。它支持从基础链开始,根据条件添加、插入、替换或移除中间件。
- **`applyCompletionsMiddlewares`**: 负责接收 `MiddlewareBuilder` 构建的链并按顺序执行,专门用于 Completions 流程。
- **`MiddlewareRegistry`**: 集中管理所有可用中间件的注册表,提供统一的中间件访问接口。
- **各种独立的中间件模块** (存放于 `common/`, `core/`, `feat/` 子目录)。
#### 2.2.1. `middlewareTypes.ts`
- 定义中间件的核心类型,如 `AiProviderMiddlewareContext` (扩展后包含 `_apiClientInstance``_coreRequest`)、`MiddlewareAPI``CompletionsMiddleware` 等。
#### 2.2.2. 核心中间件 (`middleware/core/`)
- **`TransformCoreToSdkParamsMiddleware.ts`**: 调用 `ApiClient.getRequestTransformer()``CoreRequest` 转换为特定 SDK 的参数,并存入上下文。
- **`RequestExecutionMiddleware.ts`**: 调用 `ApiClient.getSdkInstance()` 获取 SDK 实例,并使用转换后的参数执行实际的 API 调用,返回原始 SDK 流。
- **`StreamAdapterMiddleware.ts`**: 将各种形态的原始 SDK 流 (如异步迭代器) 统一适配为 `ReadableStream<RawSdkChunk>`
- **`RawSdkChunk`**:指特定AI提供商SDK在流式响应中返回的、未经应用层统一处理的原始数据块格式 (例如 OpenAI 的 `ChatCompletionChunk`Gemini 的 `GenerateContentResponse` 中的部分等)。
- **`RawSdkChunkToAppChunkMiddleware.ts`**: (新增) 消费 `ReadableStream<RawSdkChunk>`,在其内部对每个 `RawSdkChunk` 调用 `ApiClient.getResponseChunkTransformer()`,将其转换为一个或多个基础的应用层 `Chunk` 对象,并输出 `ReadableStream<Chunk>`
#### 2.2.3. 特性中间件 (`middleware/feat/`)
这些中间件消费由 `ResponseTransformMiddleware` 输出的、相对标准化的 `Chunk` 流,并处理更复杂的逻辑。
- **`ThinkingTagExtractionMiddleware.ts`**: 检查 `TextDeltaChunk`,解析其中可能包含的 `<think>...</think>` 文本内嵌标签,生成 `ThinkingDeltaChunk``ThinkingCompleteChunk`
- **`ToolUseExtractionMiddleware.ts`**: 检查 `TextDeltaChunk`,解析其中可能包含的 `<tool_use>...</tool_use>` 文本内嵌标签,生成工具调用相关的 Chunk。如果 `ApiClient` 输出了原生工具调用数据,此中间件也负责将其转换为标准格式。
#### 2.2.4. 核心处理中间件 (`middleware/core/`)
- **`TransformCoreToSdkParamsMiddleware.ts`**: 调用 `ApiClient.getRequestTransformer()``CoreRequest` 转换为特定 SDK 的参数,并存入上下文。
- **`SdkCallMiddleware.ts`**: 调用 `ApiClient.getSdkInstance()` 获取 SDK 实例,并使用转换后的参数执行实际的 API 调用,返回原始 SDK 流。
- **`StreamAdapterMiddleware.ts`**: 将各种形态的原始 SDK 流统一适配为标准流格式。
- **`ResponseTransformMiddleware.ts`**: 将原始 SDK 响应转换为应用层标准 `Chunk` 对象。
- **`TextChunkMiddleware.ts`**: 处理文本相关的 Chunk 流。
- **`ThinkChunkMiddleware.ts`**: 处理思考相关的 Chunk 流。
- **`McpToolChunkMiddleware.ts`**: 处理工具调用相关的 Chunk 流。
- **`WebSearchMiddleware.ts`**: 处理 Web 搜索相关逻辑。
#### 2.2.5. 通用中间件 (`middleware/common/`)
- **`LoggingMiddleware.ts`**: 请求和响应日志。
- **`AbortHandlerMiddleware.ts`**: 处理请求中止。
- **`FinalChunkConsumerMiddleware.ts`**: 消费最终的 `Chunk` 流,通过 `context.onChunk` 回调通知应用层实时数据。
- **累积数据**:在流式处理过程中,累积关键数据,如文本片段、工具调用信息、`usage`/`metrics` 等。
- **结束 `Promise`**:当输入流结束时,使用累积的聚合结果来完成整个处理流程。
- 在流结束时,发送包含最终累加信息的完成信号。
### 2.3. `types/chunk.ts`
- 定义应用全局统一的 `Chunk` 类型及其所有变体。这包括基础类型 (如 `TextDeltaChunk`, `ThinkingDeltaChunk`)、SDK原生数据传递类型 (如 `RawToolCallChunk`, `RawFinishChunk` - 作为 `ApiClient` 转换的中间产物),以及功能性类型 (如 `McpToolCallRequestChunk`, `WebSearchCompleteChunk`)。
## 3. 核心执行流程 (以 `AiCoreService.executeCompletions` 为例)
```markdown
**应用层 (例如 UI 组件)**
||
\\/
**`AiProvider.completions` (`aiCore/index.ts`)**
(1. prepare ApiClient instance. 2. use `CompletionsMiddlewareBuilder.withDefaults()` to build middleware chain. 3. call `applyCompletionsMiddlewares`)
||
\\/
**`applyCompletionsMiddlewares` (`middleware/composer.ts`)**
(接收构建好的链、ApiClient实例、原始SDK方法,开始按序执行中间件)
||
\\/
**[ 预处理阶段中间件 ]**
(例如: `FinalChunkConsumerMiddleware`, `TransformCoreToSdkParamsMiddleware`, `AbortHandlerMiddleware`)
|| (Context 中准备好 SDK 请求参数)
\\/
**[ 处理阶段中间件 ]**
(例如: `McpToolChunkMiddleware`, `WebSearchMiddleware`, `TextChunkMiddleware`, `ThinkingTagExtractionMiddleware`)
|| (处理各种特性和Chunk类型)
\\/
**[ SDK调用阶段中间件 ]**
(例如: `ResponseTransformMiddleware`, `StreamAdapterMiddleware`, `SdkCallMiddleware`)
|| (输出: 标准化的应用层Chunk流)
\\/
**`FinalChunkConsumerMiddleware` (核心)**
(消费最终的 `Chunk` 流, 通过 `context.onChunk` 回调通知应用层, 并在流结束时完成处理)
||
\\/
**`AiProvider.completions` 返回 `Promise<CompletionsResult>`**
```
## 4. 建议的文件/目录结构
```
src/renderer/src/
└── aiCore/
├── clients/
│ ├── openai/
│ ├── gemini/
│ ├── anthropic/
│ ├── BaseApiClient.ts
│ ├── ApiClientFactory.ts
│ ├── AihubmixAPIClient.ts
│ ├── index.ts
│ └── types.ts
├── middleware/
│ ├── common/
│ ├── core/
│ ├── feat/
│ ├── builder.ts
│ ├── composer.ts
│ ├── index.ts
│ ├── register.ts
│ ├── schemas.ts
│ ├── types.ts
│ └── utils.ts
├── types/
│ ├── chunk.ts
│ └── ...
└── index.ts
```
## 5. 迁移和实施建议
- **小步快跑,逐步迭代**:优先完成核心流程的重构(例如 `completions`),再逐步迁移其他功能(`translate` 等)和其他 Provider。
- **优先定义核心类型**`CoreRequest`, `Chunk`, `ApiClient` 接口是整个架构的基石。
- **为 `ApiClient` 瘦身**:将现有 `XxxProvider` 中的复杂逻辑剥离到新的中间件或 `AiCoreService` 中。
- **强化中间件**:让中间件承担起更多解析和特性处理的责任。
- **编写单元测试和集成测试**:确保每个组件和整体流程的正确性。
此架构旨在提供一个更健壮、更灵活、更易于维护的 AI 功能核心,支撑 Cherry Studio 未来的发展。
## 6. 迁移策略与实施建议
本节内容提炼自早期的 `migrate.md` 文档,并根据最新的架构讨论进行了调整。
**目标架构核心组件回顾:**
与第 2 节描述的核心组件一致,主要包括 `XxxApiClient`, `AiCoreService`, 中间件链, `CoreRequest` 类型, 和标准化的 `Chunk` 类型。
**迁移步骤:**
**Phase 0: 准备工作和类型定义**
1. **定义核心数据结构 (TypeScript 类型)**
- `CoreCompletionsRequest` (Type):定义应用内部统一的对话请求结构。
- `Chunk` (Type - 检查并按需扩展现有 `src/renderer/src/types/chunk.ts`):定义所有可能的通用Chunk类型。
- 为其他API(翻译、总结)定义类似的 `CoreXxxRequest` (Type)。
2. **定义 `ApiClient` 接口:** 明确 `getRequestTransformer`, `getResponseChunkTransformer`, `getSdkInstance` 等核心方法。
3. **调整 `AiProviderMiddlewareContext`**
- 确保包含 `_apiClientInstance: ApiClient<any,any,any>`
- 确保包含 `_coreRequest: CoreRequestType`
- 考虑添加 `resolvePromise: (value: AggregatedResultType) => void``rejectPromise: (reason?: any) => void` 用于 `AiCoreService` 的 Promise 返回。
**Phase 1: 实现第一个 `ApiClient` (以 `OpenAIApiClient` 为例)**
1. **创建 `OpenAIApiClient` 类:** 实现 `ApiClient` 接口。
2. **迁移SDK实例和配置。**
3. **实现 `getRequestTransformer()`**`CoreCompletionsRequest` 转换为 OpenAI SDK 参数。
4. **实现 `getResponseChunkTransformer()`**`OpenAI.Chat.Completions.ChatCompletionChunk` 转换为基础的 `
@@ -0,0 +1,207 @@
import { isOpenAILLMModel } from '@renderer/config/models'
import {
GenerateImageParams,
MCPCallToolResponse,
MCPTool,
MCPToolResponse,
Model,
Provider,
ToolCallResponse
} from '@renderer/types'
import {
RequestOptions,
SdkInstance,
SdkMessageParam,
SdkModel,
SdkParams,
SdkRawChunk,
SdkRawOutput,
SdkTool,
SdkToolCall
} from '@renderer/types/sdk'
import { AnthropicAPIClient } from './anthropic/AnthropicAPIClient'
import { BaseApiClient } from './BaseApiClient'
import { GeminiAPIClient } from './gemini/GeminiAPIClient'
import { OpenAIAPIClient } from './openai/OpenAIApiClient'
import { OpenAIResponseAPIClient } from './openai/OpenAIResponseAPIClient'
import { RequestTransformer, ResponseChunkTransformer } from './types'
/**
* AihubmixAPIClient - 根据模型类型自动选择合适的ApiClient
* 使用装饰器模式实现,在ApiClient层面进行模型路由
*/
export class AihubmixAPIClient extends BaseApiClient {
// 使用联合类型而不是any,保持类型安全
private clients: Map<string, AnthropicAPIClient | GeminiAPIClient | OpenAIResponseAPIClient | OpenAIAPIClient> =
new Map()
private defaultClient: OpenAIAPIClient
private currentClient: BaseApiClient
constructor(provider: Provider) {
super(provider)
// 初始化各个client - 现在有类型安全
const claudeClient = new AnthropicAPIClient(provider)
const geminiClient = new GeminiAPIClient({ ...provider, apiHost: 'https://aihubmix.com/gemini' })
const openaiClient = new OpenAIResponseAPIClient(provider)
const defaultClient = new OpenAIAPIClient(provider)
this.clients.set('claude', claudeClient)
this.clients.set('gemini', geminiClient)
this.clients.set('openai', openaiClient)
this.clients.set('default', defaultClient)
// 设置默认client
this.defaultClient = defaultClient
this.currentClient = this.defaultClient as BaseApiClient
}
/**
* 类型守卫:确保client是BaseApiClient的实例
*/
private isValidClient(client: unknown): client is BaseApiClient {
return (
client !== null &&
client !== undefined &&
typeof client === 'object' &&
'createCompletions' in client &&
'getRequestTransformer' in client &&
'getResponseChunkTransformer' in client
)
}
/**
* 根据模型获取合适的client
*/
private getClient(model: Model): BaseApiClient {
const id = model.id.toLowerCase()
// claude开头
if (id.startsWith('claude')) {
const client = this.clients.get('claude')
if (!client || !this.isValidClient(client)) {
throw new Error('Claude client not properly initialized')
}
return client
}
// gemini开头 且不以-nothink、-search结尾
if ((id.startsWith('gemini') || id.startsWith('imagen')) && !id.endsWith('-nothink') && !id.endsWith('-search')) {
const client = this.clients.get('gemini')
if (!client || !this.isValidClient(client)) {
throw new Error('Gemini client not properly initialized')
}
return client
}
// OpenAI系列模型
if (isOpenAILLMModel(model)) {
const client = this.clients.get('openai')
if (!client || !this.isValidClient(client)) {
throw new Error('OpenAI client not properly initialized')
}
return client
}
return this.defaultClient as BaseApiClient
}
/**
* 根据模型选择合适的client并委托调用
*/
public getClientForModel(model: Model): BaseApiClient {
this.currentClient = this.getClient(model)
return this.currentClient
}
// ============ BaseApiClient 抽象方法实现 ============
async createCompletions(payload: SdkParams, options?: RequestOptions): Promise<SdkRawOutput> {
// 尝试从payload中提取模型信息来选择client
const modelId = this.extractModelFromPayload(payload)
if (modelId) {
const modelObj = { id: modelId } as Model
const targetClient = this.getClient(modelObj)
return targetClient.createCompletions(payload, options)
}
// 如果无法从payload中提取模型,使用当前设置的client
return this.currentClient.createCompletions(payload, options)
}
/**
* 从SDK payload中提取模型ID
*/
private extractModelFromPayload(payload: SdkParams): string | null {
// 不同的SDK可能有不同的字段名
if ('model' in payload && typeof payload.model === 'string') {
return payload.model
}
return null
}
async generateImage(params: GenerateImageParams): Promise<string[]> {
return this.currentClient.generateImage(params)
}
async getEmbeddingDimensions(model?: Model): Promise<number> {
const client = model ? this.getClient(model) : this.currentClient
return client.getEmbeddingDimensions(model)
}
async listModels(): Promise<SdkModel[]> {
// 可以聚合所有client的模型,或者使用默认client
return this.defaultClient.listModels()
}
async getSdkInstance(): Promise<SdkInstance> {
return this.currentClient.getSdkInstance()
}
getRequestTransformer(): RequestTransformer<SdkParams, SdkMessageParam> {
return this.currentClient.getRequestTransformer()
}
getResponseChunkTransformer(): ResponseChunkTransformer<SdkRawChunk> {
return this.currentClient.getResponseChunkTransformer()
}
convertMcpToolsToSdkTools(mcpTools: MCPTool[]): SdkTool[] {
return this.currentClient.convertMcpToolsToSdkTools(mcpTools)
}
convertSdkToolCallToMcp(toolCall: SdkToolCall, mcpTools: MCPTool[]): MCPTool | undefined {
return this.currentClient.convertSdkToolCallToMcp(toolCall, mcpTools)
}
convertSdkToolCallToMcpToolResponse(toolCall: SdkToolCall, mcpTool: MCPTool): ToolCallResponse {
return this.currentClient.convertSdkToolCallToMcpToolResponse(toolCall, mcpTool)
}
buildSdkMessages(
currentReqMessages: SdkMessageParam[],
output: SdkRawOutput | string,
toolResults: SdkMessageParam[],
toolCalls?: SdkToolCall[]
): SdkMessageParam[] {
return this.currentClient.buildSdkMessages(currentReqMessages, output, toolResults, toolCalls)
}
convertMcpToolResponseToSdkMessageParam(
mcpToolResponse: MCPToolResponse,
resp: MCPCallToolResponse,
model: Model
): SdkMessageParam | undefined {
const client = this.getClient(model)
return client.convertMcpToolResponseToSdkMessageParam(mcpToolResponse, resp, model)
}
extractMessagesFromSdkPayload(sdkPayload: SdkParams): SdkMessageParam[] {
return this.currentClient.extractMessagesFromSdkPayload(sdkPayload)
}
estimateMessageTokens(message: SdkMessageParam): number {
return this.currentClient.estimateMessageTokens(message)
}
}
@@ -0,0 +1,62 @@
import { Provider } from '@renderer/types'
import { AihubmixAPIClient } from './AihubmixAPIClient'
import { AnthropicAPIClient } from './anthropic/AnthropicAPIClient'
import { BaseApiClient } from './BaseApiClient'
import { GeminiAPIClient } from './gemini/GeminiAPIClient'
import { OpenAIAPIClient } from './openai/OpenAIApiClient'
import { OpenAIResponseAPIClient } from './openai/OpenAIResponseAPIClient'
/**
* Factory for creating ApiClient instances based on provider configuration
* 根据提供者配置创建ApiClient实例的工厂
*/
export class ApiClientFactory {
/**
* Create an ApiClient instance for the given provider
* 为给定的提供者创建ApiClient实例
*/
static create(provider: Provider): BaseApiClient {
console.log(`[ApiClientFactory] Creating ApiClient for provider:`, {
id: provider.id,
type: provider.type
})
let instance: BaseApiClient
// 首先检查特殊的provider id
if (provider.id === 'aihubmix') {
console.log(`[ApiClientFactory] Creating AihubmixAPIClient for provider: ${provider.id}`)
instance = new AihubmixAPIClient(provider) as BaseApiClient
return instance
}
// 然后检查标准的provider type
switch (provider.type) {
case 'openai':
case 'azure-openai':
console.log(`[ApiClientFactory] Creating OpenAIApiClient for provider: ${provider.id}`)
instance = new OpenAIAPIClient(provider) as BaseApiClient
break
case 'openai-response':
instance = new OpenAIResponseAPIClient(provider) as BaseApiClient
break
case 'gemini':
instance = new GeminiAPIClient(provider) as BaseApiClient
break
case 'anthropic':
instance = new AnthropicAPIClient(provider) as BaseApiClient
break
default:
console.log(`[ApiClientFactory] Using default OpenAIApiClient for provider: ${provider.id}`)
instance = new OpenAIAPIClient(provider) as BaseApiClient
break
}
return instance
}
}
export function isOpenAIProvider(provider: Provider) {
return !['anthropic', 'gemini'].includes(provider.type)
}
@@ -1,40 +1,69 @@
import Logger from '@renderer/config/logger'
import { isFunctionCallingModel, isNotSupportTemperatureAndTopP } from '@renderer/config/models'
import {
isFunctionCallingModel,
isNotSupportTemperatureAndTopP,
isOpenAIModel,
isSupportedFlexServiceTier
} from '@renderer/config/models'
import { REFERENCE_PROMPT } from '@renderer/config/prompts'
import { getLMStudioKeepAliveTime } from '@renderer/hooks/useLMStudio'
import type {
import { getStoreSetting } from '@renderer/hooks/useSettings'
import { SettingsState } from '@renderer/store/settings'
import {
Assistant,
FileTypes,
GenerateImageParams,
KnowledgeReference,
MCPCallToolResponse,
MCPTool,
MCPToolResponse,
Model,
OpenAIServiceTier,
Provider,
Suggestion,
ToolCallResponse,
WebSearchProviderResponse,
WebSearchResponse
} from '@renderer/types'
import { ChunkType } from '@renderer/types/chunk'
import type { Message } from '@renderer/types/newMessage'
import { delay, isJSON, parseJSON } from '@renderer/utils'
import { Message } from '@renderer/types/newMessage'
import {
RequestOptions,
SdkInstance,
SdkMessageParam,
SdkModel,
SdkParams,
SdkRawChunk,
SdkRawOutput,
SdkTool,
SdkToolCall
} from '@renderer/types/sdk'
import { isJSON, parseJSON } from '@renderer/utils'
import { addAbortController, removeAbortController } from '@renderer/utils/abortController'
import { formatApiHost } from '@renderer/utils/api'
import { getMainTextContent } from '@renderer/utils/messageUtils/find'
import { findFileBlocks, getMainTextContent } from '@renderer/utils/messageUtils/find'
import { defaultTimeout } from '@shared/config/constant'
import Logger from 'electron-log/renderer'
import { isEmpty } from 'lodash'
import type OpenAI from 'openai'
import type { CompletionsParams } from '.'
import { ApiClient, RawStreamListener, RequestTransformer, ResponseChunkTransformer } from './types'
export default abstract class BaseProvider {
// Threshold for determining whether to use system prompt for tools
/**
* Abstract base class for API clients.
* Provides common functionality and structure for specific client implementations.
*/
export abstract class BaseApiClient<
TSdkInstance extends SdkInstance = SdkInstance,
TSdkParams extends SdkParams = SdkParams,
TRawOutput extends SdkRawOutput = SdkRawOutput,
TRawChunk extends SdkRawChunk = SdkRawChunk,
TMessageParam extends SdkMessageParam = SdkMessageParam,
TToolCall extends SdkToolCall = SdkToolCall,
TSdkSpecificTool extends SdkTool = SdkTool
> implements ApiClient<TSdkInstance, TSdkParams, TRawOutput, TRawChunk, TMessageParam, TToolCall, TSdkSpecificTool>
{
private static readonly SYSTEM_PROMPT_THRESHOLD: number = 128
protected provider: Provider
public provider: Provider
protected host: string
protected apiKey: string
protected useSystemPromptForTools: boolean = true
protected sdkInstance?: TSdkInstance
public useSystemPromptForTools: boolean = true
constructor(provider: Provider) {
this.provider = provider
@@ -42,32 +71,81 @@ export default abstract class BaseProvider {
this.apiKey = this.getApiKey()
}
abstract completions({ messages, assistant, onChunk, onFilterMessages }: CompletionsParams): Promise<void>
abstract translate(
content: string,
assistant: Assistant,
onResponse?: (text: string, isComplete: boolean) => void
): Promise<string>
abstract summaries(messages: Message[], assistant: Assistant): Promise<string>
abstract summaryForSearch(messages: Message[], assistant: Assistant): Promise<string | null>
abstract suggestions(messages: Message[], assistant: Assistant): Promise<Suggestion[]>
abstract generateText({ prompt, content }: { prompt: string; content: string }): Promise<string>
abstract check(model: Model, stream: boolean): Promise<{ valid: boolean; error: Error | null }>
abstract models(): Promise<OpenAI.Models.Model[]>
abstract generateImage(params: GenerateImageParams): Promise<string[]>
abstract generateImageByChat({ messages, assistant, onChunk, onFilterMessages }: CompletionsParams): Promise<void>
// 由于现在出现了一些能够选择嵌入维度的嵌入模型,这个不考虑dimensions参数的方法将只能应用于那些不支持dimensions的模型
abstract getEmbeddingDimensions(model: Model): Promise<number>
public abstract convertMcpTools<T>(mcpTools: MCPTool[]): T[]
public abstract mcpToolCallResponseToMessage(
// // 核心的completions方法 - 在中间件架构中,这通常只是一个占位符
// abstract completions(params: CompletionsParams, internal?: ProcessingState): Promise<CompletionsResult>
/**
* API Endpoint
**/
abstract createCompletions(payload: TSdkParams, options?: RequestOptions): Promise<TRawOutput>
abstract generateImage(generateImageParams: GenerateImageParams): Promise<string[]>
abstract getEmbeddingDimensions(model?: Model): Promise<number>
abstract listModels(): Promise<SdkModel[]>
abstract getSdkInstance(): Promise<TSdkInstance> | TSdkInstance
/**
*
**/
// 在 CoreRequestToSdkParamsMiddleware中使用
abstract getRequestTransformer(): RequestTransformer<TSdkParams, TMessageParam>
// 在RawSdkChunkToGenericChunkMiddleware中使用
abstract getResponseChunkTransformer(): ResponseChunkTransformer<TRawChunk>
/**
*
**/
// Optional tool conversion methods - implement if needed by the specific provider
abstract convertMcpToolsToSdkTools(mcpTools: MCPTool[]): TSdkSpecificTool[]
abstract convertSdkToolCallToMcp(toolCall: TToolCall, mcpTools: MCPTool[]): MCPTool | undefined
abstract convertSdkToolCallToMcpToolResponse(toolCall: TToolCall, mcpTool: MCPTool): ToolCallResponse
abstract buildSdkMessages(
currentReqMessages: TMessageParam[],
output: TRawOutput | string,
toolResults: TMessageParam[],
toolCalls?: TToolCall[]
): TMessageParam[]
abstract estimateMessageTokens(message: TMessageParam): number
abstract convertMcpToolResponseToSdkMessageParam(
mcpToolResponse: MCPToolResponse,
resp: MCPCallToolResponse,
model: Model
): any
): TMessageParam | undefined
/**
* SDK载荷中提取消息数组访
* 使messageshistory等
*/
abstract extractMessagesFromSdkPayload(sdkPayload: TSdkParams): TMessageParam[]
/**
*
*/
public attachRawStreamListener<TListener extends RawStreamListener<TRawChunk>>(
rawOutput: TRawOutput,
// eslint-disable-next-line @typescript-eslint/no-unused-vars
_listener: TListener
): TRawOutput {
return rawOutput
}
/**
*
**/
public getBaseURL(): string {
const host = this.provider.apiHost
return formatApiHost(host)
return this.provider.apiHost
}
public getApiKey() {
@@ -112,14 +190,32 @@ export default abstract class BaseProvider {
return isNotSupportTemperatureAndTopP(model) ? undefined : assistant.settings?.topP
}
public async fakeCompletions({ onChunk }: CompletionsParams) {
for (let i = 0; i < 100; i++) {
await delay(0.01)
onChunk({
response: { text: i + '\n', usage: { completion_tokens: 0, prompt_tokens: 0, total_tokens: 0 } },
type: ChunkType.BLOCK_COMPLETE
})
protected getServiceTier(model: Model) {
if (!isOpenAIModel(model) || model.provider === 'github' || model.provider === 'copilot') {
return undefined
}
const openAI = getStoreSetting('openAI') as SettingsState['openAI']
let serviceTier = 'auto' as OpenAIServiceTier
if (openAI && openAI?.serviceTier === 'flex') {
if (isSupportedFlexServiceTier(model)) {
serviceTier = 'flex'
} else {
serviceTier = 'auto'
}
} else {
serviceTier = openAI.serviceTier
}
return serviceTier
}
protected getTimeout(model: Model) {
if (isSupportedFlexServiceTier(model)) {
return 15 * 1000 * 60
}
return defaultTimeout
}
public async getMessageContent(message: Message): Promise<string> {
@@ -149,6 +245,36 @@ export default abstract class BaseProvider {
return content
}
/**
* Extract the file content from the message
* @param message - The message
* @returns The file content
*/
protected async extractFileContent(message: Message) {
const fileBlocks = findFileBlocks(message)
if (fileBlocks.length > 0) {
const textFileBlocks = fileBlocks.filter(
(fb) => fb.file && [FileTypes.TEXT, FileTypes.DOCUMENT].includes(fb.file.type)
)
if (textFileBlocks.length > 0) {
let text = ''
const divider = '\n\n---\n\n'
for (const fileBlock of textFileBlocks) {
const file = fileBlock.file
const fileContent = (await window.api.file.read(file.id + file.ext)).trim()
const fileNameRow = 'file: ' + file.origin_name + '\n\n'
text = text + fileNameRow + fileContent + divider
}
return text
}
}
return ''
}
private async getWebSearchReferencesFromCache(message: Message) {
const content = getMainTextContent(message)
if (isEmpty(content)) {
@@ -210,7 +336,7 @@ export default abstract class BaseProvider {
)
}
protected createAbortController(messageId?: string, isAddEventListener?: boolean) {
public createAbortController(messageId?: string, isAddEventListener?: boolean) {
const abortController = new AbortController()
const abortFn = () => abortController.abort()
@@ -256,11 +382,11 @@ export default abstract class BaseProvider {
}
// Setup tools configuration based on provided parameters
protected setupToolsConfig<T>(params: { mcpTools?: MCPTool[]; model: Model; enableToolUse?: boolean }): {
tools: T[]
public setupToolsConfig(params: { mcpTools?: MCPTool[]; model: Model; enableToolUse?: boolean }): {
tools: TSdkSpecificTool[]
} {
const { mcpTools, model, enableToolUse } = params
let tools: T[] = []
let tools: TSdkSpecificTool[] = []
// If there are no tools, return an empty array
if (!mcpTools?.length) {
@@ -268,14 +394,14 @@ export default abstract class BaseProvider {
}
// If the number of tools exceeds the threshold, use the system prompt
if (mcpTools.length > BaseProvider.SYSTEM_PROMPT_THRESHOLD) {
if (mcpTools.length > BaseApiClient.SYSTEM_PROMPT_THRESHOLD) {
this.useSystemPromptForTools = true
return { tools }
}
// If the model supports function calling and tool usage is enabled
if (isFunctionCallingModel(model) && enableToolUse) {
tools = this.convertMcpTools<T>(mcpTools)
tools = this.convertMcpToolsToSdkTools(mcpTools)
this.useSystemPromptForTools = false
}
@@ -0,0 +1,714 @@
import Anthropic from '@anthropic-ai/sdk'
import {
Base64ImageSource,
ImageBlockParam,
MessageParam,
TextBlockParam,
ToolResultBlockParam,
ToolUseBlock,
WebSearchTool20250305
} from '@anthropic-ai/sdk/resources'
import {
ContentBlock,
ContentBlockParam,
MessageCreateParams,
MessageCreateParamsBase,
RedactedThinkingBlockParam,
ServerToolUseBlockParam,
ThinkingBlockParam,
ThinkingConfigParam,
ToolUnion,
ToolUseBlockParam,
WebSearchResultBlock,
WebSearchToolResultBlockParam,
WebSearchToolResultError
} from '@anthropic-ai/sdk/resources/messages'
import { MessageStream } from '@anthropic-ai/sdk/resources/messages/messages'
import { GenericChunk } from '@renderer/aiCore/middleware/schemas'
import { DEFAULT_MAX_TOKENS } from '@renderer/config/constant'
import Logger from '@renderer/config/logger'
import { findTokenLimit, isClaudeReasoningModel, isReasoningModel, isWebSearchModel } from '@renderer/config/models'
import { getAssistantSettings } from '@renderer/services/AssistantService'
import FileManager from '@renderer/services/FileManager'
import { estimateTextTokens } from '@renderer/services/TokenService'
import {
Assistant,
EFFORT_RATIO,
FileTypes,
MCPCallToolResponse,
MCPTool,
MCPToolResponse,
Model,
Provider,
ToolCallResponse,
WebSearchSource
} from '@renderer/types'
import {
ChunkType,
ErrorChunk,
LLMWebSearchCompleteChunk,
LLMWebSearchInProgressChunk,
MCPToolCreatedChunk,
TextDeltaChunk,
ThinkingDeltaChunk
} from '@renderer/types/chunk'
import type { Message } from '@renderer/types/newMessage'
import {
AnthropicSdkMessageParam,
AnthropicSdkParams,
AnthropicSdkRawChunk,
AnthropicSdkRawOutput
} from '@renderer/types/sdk'
import { addImageFileToContents } from '@renderer/utils/formats'
import {
anthropicToolUseToMcpTool,
isEnabledToolUse,
mcpToolCallResponseToAnthropicMessage,
mcpToolsToAnthropicTools
} from '@renderer/utils/mcp-tools'
import { findFileBlocks, findImageBlocks, getMainTextContent } from '@renderer/utils/messageUtils/find'
import { buildSystemPrompt } from '@renderer/utils/prompt'
import { BaseApiClient } from '../BaseApiClient'
import { AnthropicStreamListener, RawStreamListener, RequestTransformer, ResponseChunkTransformer } from '../types'
export class AnthropicAPIClient extends BaseApiClient<
Anthropic,
AnthropicSdkParams,
AnthropicSdkRawOutput,
AnthropicSdkRawChunk,
AnthropicSdkMessageParam,
ToolUseBlock,
ToolUnion
> {
constructor(provider: Provider) {
super(provider)
}
async getSdkInstance(): Promise<Anthropic> {
if (this.sdkInstance) {
return this.sdkInstance
}
this.sdkInstance = new Anthropic({
apiKey: this.getApiKey(),
baseURL: this.getBaseURL(),
dangerouslyAllowBrowser: true,
defaultHeaders: {
'anthropic-beta': 'output-128k-2025-02-19'
}
})
return this.sdkInstance
}
override async createCompletions(
payload: AnthropicSdkParams,
options?: Anthropic.RequestOptions
): Promise<AnthropicSdkRawOutput> {
const sdk = await this.getSdkInstance()
if (payload.stream) {
return sdk.messages.stream(payload, options)
}
return await sdk.messages.create(payload, options)
}
// @ts-ignore sdk未提供
// eslint-disable-next-line @typescript-eslint/no-unused-vars
override async generateImage(generateImageParams: GenerateImageParams): Promise<string[]> {
return []
}
override async listModels(): Promise<Anthropic.ModelInfo[]> {
const sdk = await this.getSdkInstance()
const response = await sdk.models.list()
return response.data
}
// @ts-ignore sdk未提供
override async getEmbeddingDimensions(): Promise<number> {
return 0
}
override getTemperature(assistant: Assistant, model: Model): number | undefined {
if (assistant.settings?.reasoning_effort && isClaudeReasoningModel(model)) {
return undefined
}
return assistant.settings?.temperature
}
override getTopP(assistant: Assistant, model: Model): number | undefined {
if (assistant.settings?.reasoning_effort && isClaudeReasoningModel(model)) {
return undefined
}
return assistant.settings?.topP
}
/**
* Get the reasoning effort
* @param assistant - The assistant
* @param model - The model
* @returns The reasoning effort
*/
private getBudgetToken(assistant: Assistant, model: Model): ThinkingConfigParam | undefined {
if (!isReasoningModel(model)) {
return undefined
}
const { maxTokens } = getAssistantSettings(assistant)
const reasoningEffort = assistant?.settings?.reasoning_effort
if (reasoningEffort === undefined) {
return {
type: 'disabled'
}
}
const effortRatio = EFFORT_RATIO[reasoningEffort]
const budgetTokens = Math.max(
1024,
Math.floor(
Math.min(
(findTokenLimit(model.id)?.max! - findTokenLimit(model.id)?.min!) * effortRatio +
findTokenLimit(model.id)?.min!,
(maxTokens || DEFAULT_MAX_TOKENS) * effortRatio
)
)
)
return {
type: 'enabled',
budget_tokens: budgetTokens
}
}
/**
* Get the message parameter
* @param message - The message
* @param model - The model
* @returns The message parameter
*/
public async convertMessageToSdkParam(message: Message): Promise<AnthropicSdkMessageParam> {
const parts: MessageParam['content'] = [
{
type: 'text',
text: getMainTextContent(message)
}
]
// Get and process image blocks
const imageBlocks = findImageBlocks(message)
for (const imageBlock of imageBlocks) {
if (imageBlock.file) {
// Handle uploaded file
const file = imageBlock.file
const base64Data = await window.api.file.base64Image(file.id + file.ext)
parts.push({
type: 'image',
source: {
data: base64Data.base64,
media_type: base64Data.mime.replace('jpg', 'jpeg') as any,
type: 'base64'
}
})
}
}
// Get and process file blocks
const fileBlocks = findFileBlocks(message)
for (const fileBlock of fileBlocks) {
const { file } = fileBlock
if ([FileTypes.TEXT, FileTypes.DOCUMENT].includes(file.type)) {
if (file.ext === '.pdf' && file.size < 32 * 1024 * 1024) {
const base64Data = await FileManager.readBase64File(file)
parts.push({
type: 'document',
source: {
type: 'base64',
media_type: 'application/pdf',
data: base64Data
}
})
} else {
const fileContent = await (await window.api.file.read(file.id + file.ext)).trim()
parts.push({
type: 'text',
text: file.origin_name + '\n' + fileContent
})
}
}
}
return {
role: message.role === 'system' ? 'user' : message.role,
content: parts
}
}
public convertMcpToolsToSdkTools(mcpTools: MCPTool[]): ToolUnion[] {
return mcpToolsToAnthropicTools(mcpTools)
}
public convertMcpToolResponseToSdkMessageParam(
mcpToolResponse: MCPToolResponse,
resp: MCPCallToolResponse,
model: Model
): AnthropicSdkMessageParam | undefined {
if ('toolUseId' in mcpToolResponse && mcpToolResponse.toolUseId) {
return mcpToolCallResponseToAnthropicMessage(mcpToolResponse, resp, model)
} else if ('toolCallId' in mcpToolResponse) {
return {
role: 'user',
content: [
{
type: 'tool_result',
tool_use_id: mcpToolResponse.toolCallId!,
content: resp.content
.map((item) => {
if (item.type === 'text') {
return {
type: 'text',
text: item.text || ''
} satisfies TextBlockParam
}
if (item.type === 'image') {
return {
type: 'image',
source: {
data: item.data || '',
media_type: (item.mimeType || 'image/png') as Base64ImageSource['media_type'],
type: 'base64'
}
} satisfies ImageBlockParam
}
return
})
.filter((n) => typeof n !== 'undefined'),
is_error: resp.isError
} satisfies ToolResultBlockParam
]
}
}
return
}
// Implementing abstract methods from BaseApiClient
convertSdkToolCallToMcp(toolCall: ToolUseBlock, mcpTools: MCPTool[]): MCPTool | undefined {
// Based on anthropicToolUseToMcpTool logic in AnthropicProvider
// This might need adjustment based on how tool calls are specifically handled in the new structure
const mcpTool = anthropicToolUseToMcpTool(mcpTools, toolCall)
return mcpTool
}
convertSdkToolCallToMcpToolResponse(toolCall: ToolUseBlock, mcpTool: MCPTool): ToolCallResponse {
return {
id: toolCall.id,
toolCallId: toolCall.id,
tool: mcpTool,
arguments: toolCall.input as Record<string, unknown>,
status: 'pending'
} as ToolCallResponse
}
override buildSdkMessages(
currentReqMessages: AnthropicSdkMessageParam[],
output: Anthropic.Message,
toolResults: AnthropicSdkMessageParam[]
): AnthropicSdkMessageParam[] {
const assistantMessage: AnthropicSdkMessageParam = {
role: output.role,
content: convertContentBlocksToParams(output.content)
}
const newMessages: AnthropicSdkMessageParam[] = [...currentReqMessages, assistantMessage]
if (toolResults && toolResults.length > 0) {
newMessages.push(...toolResults)
}
return newMessages
}
override estimateMessageTokens(message: AnthropicSdkMessageParam): number {
if (typeof message.content === 'string') {
return estimateTextTokens(message.content)
}
return message.content
.map((content) => {
switch (content.type) {
case 'text':
return estimateTextTokens(content.text)
case 'image':
if (content.source.type === 'base64') {
return estimateTextTokens(content.source.data)
} else {
return estimateTextTokens(content.source.url)
}
case 'tool_use':
return estimateTextTokens(JSON.stringify(content.input))
case 'tool_result':
return estimateTextTokens(JSON.stringify(content.content))
default:
return 0
}
})
.reduce((acc, curr) => acc + curr, 0)
}
public buildAssistantMessage(message: Anthropic.Message): AnthropicSdkMessageParam {
const messageParam: AnthropicSdkMessageParam = {
role: message.role,
content: convertContentBlocksToParams(message.content)
}
return messageParam
}
public extractMessagesFromSdkPayload(sdkPayload: AnthropicSdkParams): AnthropicSdkMessageParam[] {
return sdkPayload.messages || []
}
/**
* Anthropic专用的原始流监听器
* 处理MessageStream对象的特定事件
*/
override attachRawStreamListener(
rawOutput: AnthropicSdkRawOutput,
listener: RawStreamListener<AnthropicSdkRawChunk>
): AnthropicSdkRawOutput {
console.log(`[AnthropicApiClient] 附加流监听器到原始输出`)
// 检查是否为MessageStream
if (rawOutput instanceof MessageStream) {
console.log(`[AnthropicApiClient] 检测到 Anthropic MessageStream,附加专用监听器`)
if (listener.onStart) {
listener.onStart()
}
if (listener.onChunk) {
rawOutput.on('streamEvent', (event: AnthropicSdkRawChunk) => {
listener.onChunk!(event)
})
}
// 专用的Anthropic事件处理
const anthropicListener = listener as AnthropicStreamListener
if (anthropicListener.onContentBlock) {
rawOutput.on('contentBlock', anthropicListener.onContentBlock)
}
if (anthropicListener.onMessage) {
rawOutput.on('finalMessage', anthropicListener.onMessage)
}
if (listener.onEnd) {
rawOutput.on('end', () => {
listener.onEnd!()
})
}
if (listener.onError) {
rawOutput.on('error', (error: Error) => {
listener.onError!(error)
})
}
return rawOutput
}
// 对于非MessageStream响应
return rawOutput
}
private async getWebSearchParams(model: Model): Promise<WebSearchTool20250305 | undefined> {
if (!isWebSearchModel(model)) {
return undefined
}
return {
type: 'web_search_20250305',
name: 'web_search',
max_uses: 5
} as WebSearchTool20250305
}
getRequestTransformer(): RequestTransformer<AnthropicSdkParams, AnthropicSdkMessageParam> {
return {
transform: async (
coreRequest,
assistant,
model,
isRecursiveCall,
recursiveSdkMessages
): Promise<{
payload: AnthropicSdkParams
messages: AnthropicSdkMessageParam[]
metadata: Record<string, any>
}> => {
const { messages, mcpTools, maxTokens, streamOutput, enableWebSearch } = coreRequest
// 1. 处理系统消息
let systemPrompt = assistant.prompt
// 2. 设置工具
const { tools } = this.setupToolsConfig({
mcpTools: mcpTools,
model,
enableToolUse: isEnabledToolUse(assistant)
})
if (this.useSystemPromptForTools) {
systemPrompt = await buildSystemPrompt(systemPrompt, mcpTools)
}
const systemMessage: TextBlockParam | undefined = systemPrompt
? { type: 'text', text: systemPrompt }
: undefined
// 3. 处理用户消息
const sdkMessages: AnthropicSdkMessageParam[] = []
if (typeof messages === 'string') {
sdkMessages.push({ role: 'user', content: messages })
} else {
const processedMessages = addImageFileToContents(messages)
for (const message of processedMessages) {
sdkMessages.push(await this.convertMessageToSdkParam(message))
}
}
if (enableWebSearch) {
const webSearchTool = await this.getWebSearchParams(model)
if (webSearchTool) {
tools.push(webSearchTool)
}
}
const commonParams: MessageCreateParamsBase = {
model: model.id,
messages:
isRecursiveCall && recursiveSdkMessages && recursiveSdkMessages.length > 0
? recursiveSdkMessages
: sdkMessages,
max_tokens: maxTokens || DEFAULT_MAX_TOKENS,
temperature: this.getTemperature(assistant, model),
top_p: this.getTopP(assistant, model),
system: systemMessage ? [systemMessage] : undefined,
thinking: this.getBudgetToken(assistant, model),
tools: tools.length > 0 ? tools : undefined,
...this.getCustomParameters(assistant)
}
const finalParams: MessageCreateParams = streamOutput
? {
...commonParams,
stream: true
}
: {
...commonParams,
stream: false
}
const timeout = this.getTimeout(model)
return { payload: finalParams, messages: sdkMessages, metadata: { timeout } }
}
}
}
getResponseChunkTransformer(): ResponseChunkTransformer<AnthropicSdkRawChunk> {
return () => {
let accumulatedJson = ''
const toolCalls: Record<number, ToolUseBlock> = {}
return {
async transform(rawChunk: AnthropicSdkRawChunk, controller: TransformStreamDefaultController<GenericChunk>) {
switch (rawChunk.type) {
case 'message': {
for (const content of rawChunk.content) {
switch (content.type) {
case 'text': {
controller.enqueue({
type: ChunkType.TEXT_DELTA,
text: content.text
} as TextDeltaChunk)
break
}
case 'tool_use': {
toolCalls[0] = content
break
}
case 'thinking': {
controller.enqueue({
type: ChunkType.THINKING_DELTA,
text: content.thinking
} as ThinkingDeltaChunk)
break
}
case 'web_search_tool_result': {
controller.enqueue({
type: ChunkType.LLM_WEB_SEARCH_COMPLETE,
llm_web_search: {
results: content.content,
source: WebSearchSource.ANTHROPIC
}
} as LLMWebSearchCompleteChunk)
break
}
}
}
break
}
case 'content_block_start': {
const contentBlock = rawChunk.content_block
switch (contentBlock.type) {
case 'server_tool_use': {
if (contentBlock.name === 'web_search') {
controller.enqueue({
type: ChunkType.LLM_WEB_SEARCH_IN_PROGRESS
} as LLMWebSearchInProgressChunk)
}
break
}
case 'web_search_tool_result': {
if (
contentBlock.content &&
(contentBlock.content as WebSearchToolResultError).type === 'web_search_tool_result_error'
) {
controller.enqueue({
type: ChunkType.ERROR,
error: {
code: (contentBlock.content as WebSearchToolResultError).error_code,
message: (contentBlock.content as WebSearchToolResultError).error_code
}
} as ErrorChunk)
} else {
controller.enqueue({
type: ChunkType.LLM_WEB_SEARCH_COMPLETE,
llm_web_search: {
results: contentBlock.content as Array<WebSearchResultBlock>,
source: WebSearchSource.ANTHROPIC
}
} as LLMWebSearchCompleteChunk)
}
break
}
case 'tool_use': {
toolCalls[rawChunk.index] = contentBlock
break
}
}
break
}
case 'content_block_delta': {
const messageDelta = rawChunk.delta
switch (messageDelta.type) {
case 'text_delta': {
if (messageDelta.text) {
controller.enqueue({
type: ChunkType.TEXT_DELTA,
text: messageDelta.text
} as TextDeltaChunk)
}
break
}
case 'thinking_delta': {
if (messageDelta.thinking) {
controller.enqueue({
type: ChunkType.THINKING_DELTA,
text: messageDelta.thinking
} as ThinkingDeltaChunk)
}
break
}
case 'input_json_delta': {
if (messageDelta.partial_json) {
accumulatedJson += messageDelta.partial_json
}
break
}
}
break
}
case 'content_block_stop': {
const toolCall = toolCalls[rawChunk.index]
if (toolCall) {
try {
toolCall.input = JSON.parse(accumulatedJson)
Logger.debug(`Tool call id: ${toolCall.id}, accumulated json: ${accumulatedJson}`)
controller.enqueue({
type: ChunkType.MCP_TOOL_CREATED,
tool_calls: [toolCall]
} as MCPToolCreatedChunk)
} catch (error) {
Logger.error(`Error parsing tool call input: ${error}`)
}
}
break
}
case 'message_delta': {
controller.enqueue({
type: ChunkType.LLM_RESPONSE_COMPLETE,
response: {
usage: {
prompt_tokens: rawChunk.usage.input_tokens || 0,
completion_tokens: rawChunk.usage.output_tokens || 0,
total_tokens: (rawChunk.usage.input_tokens || 0) + (rawChunk.usage.output_tokens || 0)
}
}
})
}
}
}
}
}
}
}
/**
* 将 ContentBlock 数组转换为 ContentBlockParam 数组
* 去除服务器生成的额外字段,只保留发送给API所需的字段
*/
function convertContentBlocksToParams(contentBlocks: ContentBlock[]): ContentBlockParam[] {
return contentBlocks.map((block): ContentBlockParam => {
switch (block.type) {
case 'text':
// TextBlock -> TextBlockParam,去除 citations 等服务器字段
return {
type: 'text',
text: block.text
} satisfies TextBlockParam
case 'tool_use':
// ToolUseBlock -> ToolUseBlockParam
return {
type: 'tool_use',
id: block.id,
name: block.name,
input: block.input
} satisfies ToolUseBlockParam
case 'thinking':
// ThinkingBlock -> ThinkingBlockParam
return {
type: 'thinking',
thinking: block.thinking,
signature: block.signature
} satisfies ThinkingBlockParam
case 'redacted_thinking':
// RedactedThinkingBlock -> RedactedThinkingBlockParam
return {
type: 'redacted_thinking',
data: block.data
} satisfies RedactedThinkingBlockParam
case 'server_tool_use':
// ServerToolUseBlock -> ServerToolUseBlockParam
return {
type: 'server_tool_use',
id: block.id,
name: block.name,
input: block.input
} satisfies ServerToolUseBlockParam
case 'web_search_tool_result':
// WebSearchToolResultBlock -> WebSearchToolResultBlockParam
return {
type: 'web_search_tool_result',
tool_use_id: block.tool_use_id,
content: block.content
} satisfies WebSearchToolResultBlockParam
default:
return block as ContentBlockParam
}
})
}
@@ -0,0 +1,786 @@
import {
Content,
File,
FileState,
FunctionCall,
GenerateContentConfig,
GenerateImagesConfig,
GoogleGenAI,
HarmBlockThreshold,
HarmCategory,
Modality,
Model as GeminiModel,
Pager,
Part,
SafetySetting,
SendMessageParameters,
ThinkingConfig,
Tool
} from '@google/genai'
import { nanoid } from '@reduxjs/toolkit'
import { GenericChunk } from '@renderer/aiCore/middleware/schemas'
import {
findTokenLimit,
GEMINI_FLASH_MODEL_REGEX,
isGeminiReasoningModel,
isGemmaModel,
isVisionModel
} from '@renderer/config/models'
import { CacheService } from '@renderer/services/CacheService'
import { estimateTextTokens } from '@renderer/services/TokenService'
import {
Assistant,
EFFORT_RATIO,
FileType,
FileTypes,
GenerateImageParams,
MCPCallToolResponse,
MCPTool,
MCPToolResponse,
Model,
Provider,
ToolCallResponse,
WebSearchSource
} from '@renderer/types'
import { ChunkType, LLMWebSearchCompleteChunk } from '@renderer/types/chunk'
import { Message } from '@renderer/types/newMessage'
import {
GeminiOptions,
GeminiSdkMessageParam,
GeminiSdkParams,
GeminiSdkRawChunk,
GeminiSdkRawOutput,
GeminiSdkToolCall
} from '@renderer/types/sdk'
import {
geminiFunctionCallToMcpTool,
isEnabledToolUse,
mcpToolCallResponseToGeminiMessage,
mcpToolsToGeminiTools
} from '@renderer/utils/mcp-tools'
import { findFileBlocks, findImageBlocks, getMainTextContent } from '@renderer/utils/messageUtils/find'
import { buildSystemPrompt } from '@renderer/utils/prompt'
import { MB } from '@shared/config/constant'
import { BaseApiClient } from '../BaseApiClient'
import { RequestTransformer, ResponseChunkTransformer } from '../types'
export class GeminiAPIClient extends BaseApiClient<
GoogleGenAI,
GeminiSdkParams,
GeminiSdkRawOutput,
GeminiSdkRawChunk,
GeminiSdkMessageParam,
GeminiSdkToolCall,
Tool
> {
constructor(provider: Provider) {
super(provider)
}
override async createCompletions(payload: GeminiSdkParams, options?: GeminiOptions): Promise<GeminiSdkRawOutput> {
const sdk = await this.getSdkInstance()
const { model, history, ...rest } = payload
const realPayload: Omit<GeminiSdkParams, 'model'> = {
...rest,
config: {
...rest.config,
abortSignal: options?.abortSignal,
httpOptions: {
...rest.config?.httpOptions,
timeout: options?.timeout
}
}
} satisfies SendMessageParameters
const streamOutput = options?.streamOutput
const chat = sdk.chats.create({
model: model,
history: history
})
if (streamOutput) {
const stream = chat.sendMessageStream(realPayload)
return stream
} else {
const response = await chat.sendMessage(realPayload)
return response
}
}
override async generateImage(generateImageParams: GenerateImageParams): Promise<string[]> {
const sdk = await this.getSdkInstance()
try {
const { model, prompt, imageSize, batchSize, signal } = generateImageParams
const config: GenerateImagesConfig = {
numberOfImages: batchSize,
aspectRatio: imageSize,
abortSignal: signal,
httpOptions: {
timeout: 5 * 60 * 1000
}
}
const response = await sdk.models.generateImages({
model: model,
prompt,
config
})
if (!response.generatedImages || response.generatedImages.length === 0) {
return []
}
const images = response.generatedImages
.filter((image) => image.image?.imageBytes)
.map((image) => {
const dataPrefix = `data:${image.image?.mimeType || 'image/png'};base64,`
return dataPrefix + image.image?.imageBytes
})
// console.log(response?.generatedImages?.[0]?.image?.imageBytes);
return images
} catch (error) {
console.error('[generateImage] error:', error)
throw error
}
}
override async getEmbeddingDimensions(model: Model): Promise<number> {
const sdk = await this.getSdkInstance()
try {
const data = await sdk.models.embedContent({
model: model.id,
contents: [{ role: 'user', parts: [{ text: 'hi' }] }]
})
return data.embeddings?.[0]?.values?.length || 0
} catch (e) {
return 0
}
}
override async listModels(): Promise<GeminiModel[]> {
const sdk = await this.getSdkInstance()
const response = await sdk.models.list()
const models: GeminiModel[] = []
for await (const model of response) {
models.push(model)
}
return models
}
override async getSdkInstance() {
if (this.sdkInstance) {
return this.sdkInstance
}
this.sdkInstance = new GoogleGenAI({
vertexai: false,
apiKey: this.apiKey,
httpOptions: { baseUrl: this.getBaseURL() }
})
return this.sdkInstance
}
/**
* Handle a PDF file
* @param file - The file
* @returns The part
*/
private async handlePdfFile(file: FileType): Promise<Part> {
const smallFileSize = 20 * MB
const isSmallFile = file.size < smallFileSize
if (isSmallFile) {
const { data, mimeType } = await this.base64File(file)
return {
inlineData: {
data,
mimeType
} as Part['inlineData']
}
}
// Retrieve file from Gemini uploaded files
const fileMetadata: File | undefined = await this.retrieveFile(file)
if (fileMetadata) {
return {
fileData: {
fileUri: fileMetadata.uri,
mimeType: fileMetadata.mimeType
} as Part['fileData']
}
}
// If file is not found, upload it to Gemini
const result = await this.uploadFile(file)
return {
fileData: {
fileUri: result.uri,
mimeType: result.mimeType
} as Part['fileData']
}
}
/**
* Get the message contents
* @param message - The message
* @returns The message contents
*/
private async convertMessageToSdkParam(message: Message): Promise<Content> {
const role = message.role === 'user' ? 'user' : 'model'
const parts: Part[] = [{ text: await this.getMessageContent(message) }]
// Add any generated images from previous responses
const imageBlocks = findImageBlocks(message)
for (const imageBlock of imageBlocks) {
if (
imageBlock.metadata?.generateImageResponse?.images &&
imageBlock.metadata.generateImageResponse.images.length > 0
) {
for (const imageUrl of imageBlock.metadata.generateImageResponse.images) {
if (imageUrl && imageUrl.startsWith('data:')) {
// Extract base64 data and mime type from the data URL
const matches = imageUrl.match(/^data:(.+);base64,(.*)$/)
if (matches && matches.length === 3) {
const mimeType = matches[1]
const base64Data = matches[2]
parts.push({
inlineData: {
data: base64Data,
mimeType: mimeType
} as Part['inlineData']
})
}
}
}
}
const file = imageBlock.file
if (file) {
const base64Data = await window.api.file.base64Image(file.id + file.ext)
parts.push({
inlineData: {
data: base64Data.base64,
mimeType: base64Data.mime
} as Part['inlineData']
})
}
}
const fileBlocks = findFileBlocks(message)
for (const fileBlock of fileBlocks) {
const file = fileBlock.file
if (file.type === FileTypes.IMAGE) {
const base64Data = await window.api.file.base64Image(file.id + file.ext)
parts.push({
inlineData: {
data: base64Data.base64,
mimeType: base64Data.mime
} as Part['inlineData']
})
}
if (file.ext === '.pdf') {
parts.push(await this.handlePdfFile(file))
continue
}
if ([FileTypes.TEXT, FileTypes.DOCUMENT].includes(file.type)) {
const fileContent = await (await window.api.file.read(file.id + file.ext)).trim()
parts.push({
text: file.origin_name + '\n' + fileContent
})
}
}
return {
role,
parts: parts
}
}
// @ts-ignore unused
private async getImageFileContents(message: Message): Promise<Content> {
const role = message.role === 'user' ? 'user' : 'model'
const content = getMainTextContent(message)
const parts: Part[] = [{ text: content }]
const imageBlocks = findImageBlocks(message)
for (const imageBlock of imageBlocks) {
if (
imageBlock.metadata?.generateImageResponse?.images &&
imageBlock.metadata.generateImageResponse.images.length > 0
) {
for (const imageUrl of imageBlock.metadata.generateImageResponse.images) {
if (imageUrl && imageUrl.startsWith('data:')) {
// Extract base64 data and mime type from the data URL
const matches = imageUrl.match(/^data:(.+);base64,(.*)$/)
if (matches && matches.length === 3) {
const mimeType = matches[1]
const base64Data = matches[2]
parts.push({
inlineData: {
data: base64Data,
mimeType: mimeType
} as Part['inlineData']
})
}
}
}
}
const file = imageBlock.file
if (file) {
const base64Data = await window.api.file.base64Image(file.id + file.ext)
parts.push({
inlineData: {
data: base64Data.base64,
mimeType: base64Data.mime
} as Part['inlineData']
})
}
}
return {
role,
parts: parts
}
}
/**
* Get the safety settings
* @returns The safety settings
*/
private getSafetySettings(): SafetySetting[] {
const safetyThreshold = 'OFF' as HarmBlockThreshold
return [
{
category: HarmCategory.HARM_CATEGORY_HATE_SPEECH,
threshold: safetyThreshold
},
{
category: HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
threshold: safetyThreshold
},
{
category: HarmCategory.HARM_CATEGORY_HARASSMENT,
threshold: safetyThreshold
},
{
category: HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
threshold: safetyThreshold
},
{
category: HarmCategory.HARM_CATEGORY_CIVIC_INTEGRITY,
threshold: HarmBlockThreshold.BLOCK_NONE
}
]
}
/**
* Get the reasoning effort for the assistant
* @param assistant - The assistant
* @param model - The model
* @returns The reasoning effort
*/
private getBudgetToken(assistant: Assistant, model: Model) {
if (isGeminiReasoningModel(model)) {
const reasoningEffort = assistant?.settings?.reasoning_effort
// 如果thinking_budget是undefined,不思考
if (reasoningEffort === undefined) {
return {
thinkingConfig: {
includeThoughts: false,
...(GEMINI_FLASH_MODEL_REGEX.test(model.id) ? { thinkingBudget: 0 } : {})
} as ThinkingConfig
}
}
const effortRatio = EFFORT_RATIO[reasoningEffort]
if (effortRatio > 1) {
return {
thinkingConfig: {
includeThoughts: true
}
}
}
const { max } = findTokenLimit(model.id) || { max: 0 }
const budget = Math.floor(max * effortRatio)
return {
thinkingConfig: {
...(budget > 0 ? { thinkingBudget: budget } : {}),
includeThoughts: true
} as ThinkingConfig
}
}
return {}
}
private getGenerateImageParameter(): Partial<GenerateContentConfig> {
return {
systemInstruction: undefined,
responseModalities: [Modality.TEXT, Modality.IMAGE],
responseMimeType: 'text/plain'
}
}
getRequestTransformer(): RequestTransformer<GeminiSdkParams, GeminiSdkMessageParam> {
return {
transform: async (
coreRequest,
assistant,
model,
isRecursiveCall,
recursiveSdkMessages
): Promise<{
payload: GeminiSdkParams
messages: GeminiSdkMessageParam[]
metadata: Record<string, any>
}> => {
const { messages, mcpTools, maxTokens, enableWebSearch, enableGenerateImage } = coreRequest
// 1. 处理系统消息
let systemInstruction = assistant.prompt
// 2. 设置工具
const { tools } = this.setupToolsConfig({
mcpTools,
model,
enableToolUse: isEnabledToolUse(assistant)
})
if (this.useSystemPromptForTools) {
systemInstruction = await buildSystemPrompt(assistant.prompt || '', mcpTools)
}
let messageContents: Content
const history: Content[] = []
// 3. 处理用户消息
if (typeof messages === 'string') {
messageContents = {
role: 'user',
parts: [{ text: messages }]
}
} else {
const userLastMessage = messages.pop()!
messageContents = await this.convertMessageToSdkParam(userLastMessage)
for (const message of messages) {
history.push(await this.convertMessageToSdkParam(message))
}
}
if (enableWebSearch) {
tools.push({
googleSearch: {}
})
}
if (isGemmaModel(model) && assistant.prompt) {
const isFirstMessage = history.length === 0
if (isFirstMessage && messageContents) {
const systemMessage = [
{
text:
'<start_of_turn>user\n' +
systemInstruction +
'<end_of_turn>\n' +
'<start_of_turn>user\n' +
(messageContents?.parts?.[0] as Part).text +
'<end_of_turn>'
}
] as Part[]
if (messageContents && messageContents.parts) {
messageContents.parts[0] = systemMessage[0]
}
}
}
const newHistory =
isRecursiveCall && recursiveSdkMessages && recursiveSdkMessages.length > 0
? recursiveSdkMessages.slice(0, recursiveSdkMessages.length - 1)
: history
const newMessageContents =
isRecursiveCall && recursiveSdkMessages && recursiveSdkMessages.length > 0
? {
...messageContents,
parts: [
...(messageContents.parts || []),
...(recursiveSdkMessages[recursiveSdkMessages.length - 1].parts || [])
]
}
: messageContents
const generateContentConfig: GenerateContentConfig = {
safetySettings: this.getSafetySettings(),
systemInstruction: isGemmaModel(model) ? undefined : systemInstruction,
temperature: this.getTemperature(assistant, model),
topP: this.getTopP(assistant, model),
maxOutputTokens: maxTokens,
tools: tools,
...(enableGenerateImage ? this.getGenerateImageParameter() : {}),
...this.getBudgetToken(assistant, model),
...this.getCustomParameters(assistant)
}
const param: GeminiSdkParams = {
model: model.id,
config: generateContentConfig,
history: newHistory,
message: newMessageContents.parts!
}
return {
payload: param,
messages: [messageContents],
metadata: {}
}
}
}
}
getResponseChunkTransformer(): ResponseChunkTransformer<GeminiSdkRawChunk> {
return () => ({
async transform(chunk: GeminiSdkRawChunk, controller: TransformStreamDefaultController<GenericChunk>) {
let toolCalls: FunctionCall[] = []
if (chunk.candidates && chunk.candidates.length > 0) {
for (const candidate of chunk.candidates) {
if (candidate.content) {
candidate.content.parts?.forEach((part) => {
const text = part.text || ''
if (part.thought) {
controller.enqueue({
type: ChunkType.THINKING_DELTA,
text: text
})
} else if (part.text) {
controller.enqueue({
type: ChunkType.TEXT_DELTA,
text: text
})
} else if (part.inlineData) {
controller.enqueue({
type: ChunkType.IMAGE_COMPLETE,
image: {
type: 'base64',
images: [
part.inlineData?.data?.startsWith('data:')
? part.inlineData?.data
: `data:${part.inlineData?.mimeType || 'image/png'};base64,${part.inlineData?.data}`
]
}
})
}
})
}
if (candidate.finishReason) {
if (candidate.groundingMetadata) {
controller.enqueue({
type: ChunkType.LLM_WEB_SEARCH_COMPLETE,
llm_web_search: {
results: candidate.groundingMetadata,
source: WebSearchSource.GEMINI
}
} as LLMWebSearchCompleteChunk)
}
if (chunk.functionCalls) {
toolCalls = toolCalls.concat(chunk.functionCalls)
}
controller.enqueue({
type: ChunkType.LLM_RESPONSE_COMPLETE,
response: {
usage: {
prompt_tokens: chunk.usageMetadata?.promptTokenCount || 0,
completion_tokens:
(chunk.usageMetadata?.totalTokenCount || 0) - (chunk.usageMetadata?.promptTokenCount || 0),
total_tokens: chunk.usageMetadata?.totalTokenCount || 0
}
}
})
}
}
}
if (toolCalls.length > 0) {
controller.enqueue({
type: ChunkType.MCP_TOOL_CREATED,
tool_calls: toolCalls
})
}
}
})
}
public convertMcpToolsToSdkTools(mcpTools: MCPTool[]): Tool[] {
return mcpToolsToGeminiTools(mcpTools)
}
public convertSdkToolCallToMcp(toolCall: GeminiSdkToolCall, mcpTools: MCPTool[]): MCPTool | undefined {
return geminiFunctionCallToMcpTool(mcpTools, toolCall)
}
public convertSdkToolCallToMcpToolResponse(toolCall: GeminiSdkToolCall, mcpTool: MCPTool): ToolCallResponse {
const parsedArgs = (() => {
try {
return typeof toolCall.args === 'string' ? JSON.parse(toolCall.args) : toolCall.args
} catch {
return toolCall.args
}
})()
return {
id: toolCall.id || nanoid(),
toolCallId: toolCall.id,
tool: mcpTool,
arguments: parsedArgs,
status: 'pending'
} as ToolCallResponse
}
public convertMcpToolResponseToSdkMessageParam(
mcpToolResponse: MCPToolResponse,
resp: MCPCallToolResponse,
model: Model
): GeminiSdkMessageParam | undefined {
if ('toolUseId' in mcpToolResponse && mcpToolResponse.toolUseId) {
return mcpToolCallResponseToGeminiMessage(mcpToolResponse, resp, isVisionModel(model))
} else if ('toolCallId' in mcpToolResponse) {
return {
role: 'user',
parts: [
{
functionResponse: {
id: mcpToolResponse.toolCallId,
name: mcpToolResponse.tool.id,
response: {
output: !resp.isError ? resp.content : undefined,
error: resp.isError ? resp.content : undefined
}
}
}
]
} satisfies Content
}
return
}
public buildSdkMessages(
currentReqMessages: Content[],
output: string,
toolResults: Content[],
toolCalls: FunctionCall[]
): Content[] {
const parts: Part[] = []
if (output) {
parts.push({
text: output
})
}
toolCalls.forEach((toolCall) => {
parts.push({
functionCall: toolCall
})
})
parts.push(
...toolResults
.map((ts) => ts.parts)
.flat()
.filter((p) => p !== undefined)
)
const userMessage: Content = {
role: 'user',
parts: parts
}
return [...currentReqMessages, userMessage]
}
override estimateMessageTokens(message: GeminiSdkMessageParam): number {
return (
message.parts?.reduce((acc, part) => {
if (part.text) {
return acc + estimateTextTokens(part.text)
}
if (part.functionCall) {
return acc + estimateTextTokens(JSON.stringify(part.functionCall))
}
if (part.functionResponse) {
return acc + estimateTextTokens(JSON.stringify(part.functionResponse.response))
}
if (part.inlineData) {
return acc + estimateTextTokens(part.inlineData.data || '')
}
if (part.fileData) {
return acc + estimateTextTokens(part.fileData.fileUri || '')
}
return acc
}, 0) || 0
)
}
public extractMessagesFromSdkPayload(sdkPayload: GeminiSdkParams): GeminiSdkMessageParam[] {
return sdkPayload.history || []
}
private async uploadFile(file: FileType): Promise<File> {
return await this.sdkInstance!.files.upload({
file: file.path,
config: {
mimeType: 'application/pdf',
name: file.id,
displayName: file.origin_name
}
})
}
private async base64File(file: FileType) {
const { data } = await window.api.file.base64File(file.id + file.ext)
return {
data,
mimeType: 'application/pdf'
}
}
private async retrieveFile(file: FileType): Promise<File | undefined> {
const cachedResponse = CacheService.get<any>('gemini_file_list')
if (cachedResponse) {
return this.processResponse(cachedResponse, file)
}
const response = await this.sdkInstance!.files.list()
CacheService.set('gemini_file_list', response, 3000)
return this.processResponse(response, file)
}
private async processResponse(response: Pager<File>, file: FileType) {
for await (const f of response) {
if (f.state === FileState.ACTIVE) {
if (f.displayName === file.origin_name && Number(f.sizeBytes) === file.size) {
return f
}
}
}
return undefined
}
// @ts-ignore unused
private async listFiles(): Promise<File[]> {
const files: File[] = []
for await (const f of await this.sdkInstance!.files.list()) {
files.push(f)
}
return files
}
// @ts-ignore unused
private async deleteFile(fileId: string) {
await this.sdkInstance!.files.delete({ name: fileId })
}
}
+6
View File
@@ -0,0 +1,6 @@
export * from './ApiClientFactory'
export * from './BaseApiClient'
export * from './types'
// Export specific clients from subdirectories
export * from './openai/OpenAIApiClient'
@@ -0,0 +1,682 @@
import { DEFAULT_MAX_TOKENS } from '@renderer/config/constant'
import Logger from '@renderer/config/logger'
import {
findTokenLimit,
GEMINI_FLASH_MODEL_REGEX,
getOpenAIWebSearchParams,
isDoubaoThinkingAutoModel,
isReasoningModel,
isSupportedReasoningEffortGrokModel,
isSupportedReasoningEffortModel,
isSupportedReasoningEffortOpenAIModel,
isSupportedThinkingTokenClaudeModel,
isSupportedThinkingTokenDoubaoModel,
isSupportedThinkingTokenGeminiModel,
isSupportedThinkingTokenModel,
isSupportedThinkingTokenQwenModel,
isVisionModel
} from '@renderer/config/models'
import { processPostsuffixQwen3Model, processReqMessages } from '@renderer/services/ModelMessageService'
import { estimateTextTokens } from '@renderer/services/TokenService'
// For Copilot token
import {
Assistant,
EFFORT_RATIO,
FileTypes,
MCPCallToolResponse,
MCPTool,
MCPToolResponse,
Model,
Provider,
ToolCallResponse,
WebSearchSource
} from '@renderer/types'
import { ChunkType } from '@renderer/types/chunk'
import { Message } from '@renderer/types/newMessage'
import {
OpenAISdkMessageParam,
OpenAISdkParams,
OpenAISdkRawChunk,
OpenAISdkRawContentSource,
OpenAISdkRawOutput,
ReasoningEffortOptionalParams
} from '@renderer/types/sdk'
import { addImageFileToContents } from '@renderer/utils/formats'
import {
isEnabledToolUse,
mcpToolCallResponseToOpenAICompatibleMessage,
mcpToolsToOpenAIChatTools,
openAIToolsToMcpTool
} from '@renderer/utils/mcp-tools'
import { findFileBlocks, findImageBlocks } from '@renderer/utils/messageUtils/find'
import { buildSystemPrompt } from '@renderer/utils/prompt'
import OpenAI, { AzureOpenAI } from 'openai'
import { ChatCompletionContentPart, ChatCompletionContentPartRefusal, ChatCompletionTool } from 'openai/resources'
import { GenericChunk } from '../../middleware/schemas'
import { RequestTransformer, ResponseChunkTransformer, ResponseChunkTransformerContext } from '../types'
import { OpenAIBaseClient } from './OpenAIBaseClient'
export class OpenAIAPIClient extends OpenAIBaseClient<
OpenAI | AzureOpenAI,
OpenAISdkParams,
OpenAISdkRawOutput,
OpenAISdkRawChunk,
OpenAISdkMessageParam,
OpenAI.Chat.Completions.ChatCompletionMessageToolCall,
ChatCompletionTool
> {
constructor(provider: Provider) {
super(provider)
}
override async createCompletions(
payload: OpenAISdkParams,
options?: OpenAI.RequestOptions
): Promise<OpenAISdkRawOutput> {
const sdk = await this.getSdkInstance()
// @ts-ignore - SDK参数可能有额外的字段
return await sdk.chat.completions.create(payload, options)
}
/**
* Get the reasoning effort for the assistant
* @param assistant - The assistant
* @param model - The model
* @returns The reasoning effort
*/
// Method for reasoning effort, moved from OpenAIProvider
override getReasoningEffort(assistant: Assistant, model: Model): ReasoningEffortOptionalParams {
if (this.provider.id === 'groq') {
return {}
}
if (!isReasoningModel(model)) {
return {}
}
const reasoningEffort = assistant?.settings?.reasoning_effort
// Doubao 思考模式支持
if (isSupportedThinkingTokenDoubaoModel(model)) {
// reasoningEffort 为空,默认开启 enabled
if (!reasoningEffort) {
return { thinking: { type: 'disabled' } }
}
if (reasoningEffort === 'high') {
return { thinking: { type: 'enabled' } }
}
if (reasoningEffort === 'auto' && isDoubaoThinkingAutoModel(model)) {
return { thinking: { type: 'auto' } }
}
// 其他情况不带 thinking 字段
return {}
}
if (!reasoningEffort) {
if (isSupportedThinkingTokenQwenModel(model)) {
return { enable_thinking: false }
}
if (isSupportedThinkingTokenClaudeModel(model)) {
return {}
}
if (isSupportedThinkingTokenGeminiModel(model)) {
// openrouter没有提供一个不推理的选项,先隐藏
if (this.provider.id === 'openrouter') {
return { reasoning: { max_tokens: 0, exclude: true } }
}
if (GEMINI_FLASH_MODEL_REGEX.test(model.id)) {
return { reasoning_effort: 'none' }
}
return {}
}
if (isSupportedThinkingTokenDoubaoModel(model)) {
return { thinking: { type: 'disabled' } }
}
return {}
}
const effortRatio = EFFORT_RATIO[reasoningEffort]
const budgetTokens = Math.floor(
(findTokenLimit(model.id)?.max! - findTokenLimit(model.id)?.min!) * effortRatio + findTokenLimit(model.id)?.min!
)
// OpenRouter models
if (model.provider === 'openrouter') {
if (isSupportedReasoningEffortModel(model) || isSupportedThinkingTokenModel(model)) {
return {
reasoning: {
effort: reasoningEffort === 'auto' ? 'medium' : reasoningEffort
}
}
}
}
// Qwen models
if (isSupportedThinkingTokenQwenModel(model)) {
return {
enable_thinking: true,
thinking_budget: budgetTokens
}
}
// Grok models
if (isSupportedReasoningEffortGrokModel(model)) {
return {
reasoning_effort: reasoningEffort
}
}
// OpenAI models
if (isSupportedReasoningEffortOpenAIModel(model) || isSupportedThinkingTokenGeminiModel(model)) {
return {
reasoning_effort: reasoningEffort
}
}
// Claude models
if (isSupportedThinkingTokenClaudeModel(model)) {
const maxTokens = assistant.settings?.maxTokens
return {
thinking: {
type: 'enabled',
budget_tokens: Math.floor(
Math.max(1024, Math.min(budgetTokens, (maxTokens || DEFAULT_MAX_TOKENS) * effortRatio))
)
}
}
}
// Doubao models
if (isSupportedThinkingTokenDoubaoModel(model)) {
if (assistant.settings?.reasoning_effort === 'high') {
return {
thinking: {
type: 'enabled'
}
}
}
}
// Default case: no special thinking settings
return {}
}
/**
* Check if the provider does not support files
* @returns True if the provider does not support files, false otherwise
*/
private get isNotSupportFiles() {
if (this.provider?.isNotSupportArrayContent) {
return true
}
const providers = ['deepseek', 'baichuan', 'minimax', 'xirang']
return providers.includes(this.provider.id)
}
/**
* Get the message parameter
* @param message - The message
* @param model - The model
* @returns The message parameter
*/
public async convertMessageToSdkParam(message: Message, model: Model): Promise<OpenAISdkMessageParam> {
const isVision = isVisionModel(model)
const content = await this.getMessageContent(message)
const fileBlocks = findFileBlocks(message)
const imageBlocks = findImageBlocks(message)
if (fileBlocks.length === 0 && imageBlocks.length === 0) {
return {
role: message.role === 'system' ? 'user' : message.role,
content
} as OpenAISdkMessageParam
}
// If the model does not support files, extract the file content
if (this.isNotSupportFiles) {
const fileContent = await this.extractFileContent(message)
return {
role: message.role === 'system' ? 'user' : message.role,
content: content + '\n\n---\n\n' + fileContent
} as OpenAISdkMessageParam
}
// If the model supports files, add the file content to the message
const parts: ChatCompletionContentPart[] = []
if (content) {
parts.push({ type: 'text', text: content })
}
for (const imageBlock of imageBlocks) {
if (isVision) {
if (imageBlock.file) {
const image = await window.api.file.base64Image(imageBlock.file.id + imageBlock.file.ext)
parts.push({ type: 'image_url', image_url: { url: image.data } })
} else if (imageBlock.url && imageBlock.url.startsWith('data:')) {
parts.push({ type: 'image_url', image_url: { url: imageBlock.url } })
}
}
}
for (const fileBlock of fileBlocks) {
const file = fileBlock.file
if (!file) {
continue
}
if ([FileTypes.TEXT, FileTypes.DOCUMENT].includes(file.type)) {
const fileContent = await (await window.api.file.read(file.id + file.ext)).trim()
parts.push({
type: 'text',
text: file.origin_name + '\n' + fileContent
})
}
}
return {
role: message.role === 'system' ? 'user' : message.role,
content: parts
} as OpenAISdkMessageParam
}
public convertMcpToolsToSdkTools(mcpTools: MCPTool[]): ChatCompletionTool[] {
return mcpToolsToOpenAIChatTools(mcpTools)
}
public convertSdkToolCallToMcp(
toolCall: OpenAI.Chat.Completions.ChatCompletionMessageToolCall,
mcpTools: MCPTool[]
): MCPTool | undefined {
return openAIToolsToMcpTool(mcpTools, toolCall)
}
public convertSdkToolCallToMcpToolResponse(
toolCall: OpenAI.Chat.Completions.ChatCompletionMessageToolCall,
mcpTool: MCPTool
): ToolCallResponse {
let parsedArgs: any
try {
parsedArgs = JSON.parse(toolCall.function.arguments)
} catch {
parsedArgs = toolCall.function.arguments
}
return {
id: toolCall.id,
toolCallId: toolCall.id,
tool: mcpTool,
arguments: parsedArgs,
status: 'pending'
} as ToolCallResponse
}
public convertMcpToolResponseToSdkMessageParam(
mcpToolResponse: MCPToolResponse,
resp: MCPCallToolResponse,
model: Model
): OpenAISdkMessageParam | undefined {
if ('toolUseId' in mcpToolResponse && mcpToolResponse.toolUseId) {
// This case is for Anthropic/Claude like tool usage, OpenAI uses tool_call_id
// For OpenAI, we primarily expect toolCallId. This might need adjustment if mixing provider concepts.
return mcpToolCallResponseToOpenAICompatibleMessage(mcpToolResponse, resp, isVisionModel(model))
} else if ('toolCallId' in mcpToolResponse && mcpToolResponse.toolCallId) {
return {
role: 'tool',
tool_call_id: mcpToolResponse.toolCallId,
content: JSON.stringify(resp.content)
} as OpenAI.Chat.Completions.ChatCompletionToolMessageParam
}
return undefined
}
public buildSdkMessages(
currentReqMessages: OpenAISdkMessageParam[],
output: string,
toolResults: OpenAISdkMessageParam[],
toolCalls: OpenAI.Chat.Completions.ChatCompletionMessageToolCall[]
): OpenAISdkMessageParam[] {
const assistantMessage: OpenAISdkMessageParam = {
role: 'assistant',
content: output,
tool_calls: toolCalls.length > 0 ? toolCalls : undefined
}
const newReqMessages = [...currentReqMessages, assistantMessage, ...toolResults]
return newReqMessages
}
override estimateMessageTokens(message: OpenAISdkMessageParam): number {
let sum = 0
if (typeof message.content === 'string') {
sum += estimateTextTokens(message.content)
} else if (Array.isArray(message.content)) {
sum += (message.content || [])
.map((part: ChatCompletionContentPart | ChatCompletionContentPartRefusal) => {
switch (part.type) {
case 'text':
return estimateTextTokens(part.text)
case 'image_url':
return estimateTextTokens(part.image_url.url)
case 'input_audio':
return estimateTextTokens(part.input_audio.data)
case 'file':
return estimateTextTokens(part.file.file_data || '')
default:
return 0
}
})
.reduce((acc, curr) => acc + curr, 0)
}
if ('tool_calls' in message && message.tool_calls) {
sum += message.tool_calls.reduce((acc, toolCall) => {
return acc + estimateTextTokens(JSON.stringify(toolCall.function.arguments))
}, 0)
}
return sum
}
public extractMessagesFromSdkPayload(sdkPayload: OpenAISdkParams): OpenAISdkMessageParam[] {
return sdkPayload.messages || []
}
getRequestTransformer(): RequestTransformer<OpenAISdkParams, OpenAISdkMessageParam> {
return {
transform: async (
coreRequest,
assistant,
model,
isRecursiveCall,
recursiveSdkMessages
): Promise<{
payload: OpenAISdkParams
messages: OpenAISdkMessageParam[]
metadata: Record<string, any>
}> => {
const { messages, mcpTools, maxTokens, streamOutput, enableWebSearch } = coreRequest
// 1. 处理系统消息
let systemMessage = { role: 'system', content: assistant.prompt || '' }
if (isSupportedReasoningEffortOpenAIModel(model)) {
systemMessage = {
role: 'developer',
content: `Formatting re-enabled${systemMessage ? '\n' + systemMessage.content : ''}`
}
}
if (model.id.includes('o1-mini') || model.id.includes('o1-preview')) {
systemMessage.role = 'assistant'
}
// 2. 设置工具(必须在this.usesystemPromptForTools前面)
const { tools } = this.setupToolsConfig({
mcpTools: mcpTools,
model,
enableToolUse: isEnabledToolUse(assistant)
})
if (this.useSystemPromptForTools) {
systemMessage.content = await buildSystemPrompt(systemMessage.content || '', mcpTools)
}
// 3. 处理用户消息
const userMessages: OpenAISdkMessageParam[] = []
if (typeof messages === 'string') {
userMessages.push({ role: 'user', content: messages })
} else {
const processedMessages = addImageFileToContents(messages)
for (const message of processedMessages) {
userMessages.push(await this.convertMessageToSdkParam(message, model))
}
}
const lastUserMsg = userMessages.findLast((m) => m.role === 'user')
if (lastUserMsg && isSupportedThinkingTokenQwenModel(model)) {
const postsuffix = '/no_think'
const qwenThinkModeEnabled = assistant.settings?.qwenThinkMode === true
const currentContent = lastUserMsg.content
lastUserMsg.content = processPostsuffixQwen3Model(currentContent, postsuffix, qwenThinkModeEnabled) as any
}
// 4. 最终请求消息
let reqMessages: OpenAISdkMessageParam[]
if (!systemMessage.content) {
reqMessages = [...userMessages]
} else {
reqMessages = [systemMessage, ...userMessages].filter(Boolean) as OpenAISdkMessageParam[]
}
reqMessages = processReqMessages(model, reqMessages)
// 5. 创建通用参数
const commonParams = {
model: model.id,
messages:
isRecursiveCall && recursiveSdkMessages && recursiveSdkMessages.length > 0
? recursiveSdkMessages
: reqMessages,
temperature: this.getTemperature(assistant, model),
top_p: this.getTopP(assistant, model),
max_tokens: maxTokens,
tools: tools.length > 0 ? tools : undefined,
service_tier: this.getServiceTier(model),
...this.getProviderSpecificParameters(assistant, model),
...this.getReasoningEffort(assistant, model),
...getOpenAIWebSearchParams(model, enableWebSearch),
...this.getCustomParameters(assistant)
}
// Create the appropriate parameters object based on whether streaming is enabled
const sdkParams: OpenAISdkParams = streamOutput
? {
...commonParams,
stream: true
}
: {
...commonParams,
stream: false
}
const timeout = this.getTimeout(model)
return { payload: sdkParams, messages: reqMessages, metadata: { timeout } }
}
}
}
// 在RawSdkChunkToGenericChunkMiddleware中使用
getResponseChunkTransformer = (): ResponseChunkTransformer<OpenAISdkRawChunk> => {
let hasBeenCollectedWebSearch = false
const collectWebSearchData = (
chunk: OpenAISdkRawChunk,
contentSource: OpenAISdkRawContentSource,
context: ResponseChunkTransformerContext
) => {
if (hasBeenCollectedWebSearch) {
return
}
// OpenAI annotations
// @ts-ignore - annotations may not be in standard type definitions
const annotations = contentSource.annotations || chunk.annotations
if (annotations && annotations.length > 0 && annotations[0].type === 'url_citation') {
hasBeenCollectedWebSearch = true
return {
results: annotations,
source: WebSearchSource.OPENAI
}
}
// Grok citations
// @ts-ignore - citations may not be in standard type definitions
if (context.provider?.id === 'grok' && chunk.citations) {
hasBeenCollectedWebSearch = true
return {
// @ts-ignore - citations may not be in standard type definitions
results: chunk.citations,
source: WebSearchSource.GROK
}
}
// Perplexity citations
// @ts-ignore - citations may not be in standard type definitions
if (context.provider?.id === 'perplexity' && chunk.citations && chunk.citations.length > 0) {
hasBeenCollectedWebSearch = true
return {
// @ts-ignore - citations may not be in standard type definitions
results: chunk.citations,
source: WebSearchSource.PERPLEXITY
}
}
// OpenRouter citations
// @ts-ignore - citations may not be in standard type definitions
if (context.provider?.id === 'openrouter' && chunk.citations && chunk.citations.length > 0) {
hasBeenCollectedWebSearch = true
return {
// @ts-ignore - citations may not be in standard type definitions
results: chunk.citations,
source: WebSearchSource.OPENROUTER
}
}
// Zhipu web search
// @ts-ignore - web_search may not be in standard type definitions
if (context.provider?.id === 'zhipu' && chunk.web_search) {
hasBeenCollectedWebSearch = true
return {
// @ts-ignore - web_search may not be in standard type definitions
results: chunk.web_search,
source: WebSearchSource.ZHIPU
}
}
// Hunyuan web search
// @ts-ignore - search_info may not be in standard type definitions
if (context.provider?.id === 'hunyuan' && chunk.search_info?.search_results) {
hasBeenCollectedWebSearch = true
return {
// @ts-ignore - search_info may not be in standard type definitions
results: chunk.search_info.search_results,
source: WebSearchSource.HUNYUAN
}
}
// TODO: 放到AnthropicApiClient中
// // Other providers...
// // @ts-ignore - web_search may not be in standard type definitions
// if (chunk.web_search) {
// const sourceMap: Record<string, string> = {
// openai: 'openai',
// anthropic: 'anthropic',
// qwenlm: 'qwen'
// }
// const source = sourceMap[context.provider?.id] || 'openai_response'
// return {
// results: chunk.web_search,
// source: source as const
// }
// }
return null
}
const toolCalls: OpenAI.Chat.Completions.ChatCompletionMessageToolCall[] = []
return (context: ResponseChunkTransformerContext) => ({
async transform(chunk: OpenAISdkRawChunk, controller: TransformStreamDefaultController<GenericChunk>) {
// 处理chunk
if ('choices' in chunk && chunk.choices && chunk.choices.length > 0) {
const choice = chunk.choices[0]
if (!choice) return
// 对于流式响应,使用delta;对于非流式响应,使用message
const contentSource: OpenAISdkRawContentSource | null =
'delta' in choice ? choice.delta : 'message' in choice ? choice.message : null
if (!contentSource) return
const webSearchData = collectWebSearchData(chunk, contentSource, context)
if (webSearchData) {
controller.enqueue({
type: ChunkType.LLM_WEB_SEARCH_COMPLETE,
llm_web_search: webSearchData
})
}
// 处理推理内容 (e.g. from OpenRouter DeepSeek-R1)
// @ts-ignore - reasoning_content is not in standard OpenAI types but some providers use it
const reasoningText = contentSource.reasoning_content || contentSource.reasoning
if (reasoningText) {
controller.enqueue({
type: ChunkType.THINKING_DELTA,
text: reasoningText
})
}
// 处理文本内容
if (contentSource.content) {
controller.enqueue({
type: ChunkType.TEXT_DELTA,
text: contentSource.content
})
}
// 处理工具调用
if (contentSource.tool_calls) {
for (const toolCall of contentSource.tool_calls) {
if ('index' in toolCall) {
const { id, index, function: fun } = toolCall
if (fun?.name) {
toolCalls[index] = {
id: id || '',
function: {
name: fun.name,
arguments: fun.arguments || ''
},
type: 'function'
}
} else if (fun?.arguments) {
toolCalls[index].function.arguments += fun.arguments
}
} else {
toolCalls.push(toolCall)
}
}
}
// 处理finish_reason,发送流结束信号
if ('finish_reason' in choice && choice.finish_reason) {
Logger.debug(`[OpenAIApiClient] Stream finished with reason: ${choice.finish_reason}`)
if (toolCalls.length > 0) {
controller.enqueue({
type: ChunkType.MCP_TOOL_CREATED,
tool_calls: toolCalls
})
}
const webSearchData = collectWebSearchData(chunk, contentSource, context)
if (webSearchData) {
controller.enqueue({
type: ChunkType.LLM_WEB_SEARCH_COMPLETE,
llm_web_search: webSearchData
})
}
controller.enqueue({
type: ChunkType.LLM_RESPONSE_COMPLETE,
response: {
usage: {
prompt_tokens: chunk.usage?.prompt_tokens || 0,
completion_tokens: chunk.usage?.completion_tokens || 0,
total_tokens: (chunk.usage?.prompt_tokens || 0) + (chunk.usage?.completion_tokens || 0)
}
}
})
}
}
}
})
}
}
@@ -0,0 +1,258 @@
import {
isClaudeReasoningModel,
isNotSupportTemperatureAndTopP,
isOpenAIReasoningModel,
isSupportedModel,
isSupportedReasoningEffortOpenAIModel
} from '@renderer/config/models'
import { getStoreSetting } from '@renderer/hooks/useSettings'
import { getAssistantSettings } from '@renderer/services/AssistantService'
import store from '@renderer/store'
import { SettingsState } from '@renderer/store/settings'
import { Assistant, GenerateImageParams, Model, Provider } from '@renderer/types'
import {
OpenAIResponseSdkMessageParam,
OpenAIResponseSdkParams,
OpenAIResponseSdkRawChunk,
OpenAIResponseSdkRawOutput,
OpenAIResponseSdkTool,
OpenAIResponseSdkToolCall,
OpenAISdkMessageParam,
OpenAISdkParams,
OpenAISdkRawChunk,
OpenAISdkRawOutput,
ReasoningEffortOptionalParams
} from '@renderer/types/sdk'
import { formatApiHost } from '@renderer/utils/api'
import OpenAI, { AzureOpenAI } from 'openai'
import { BaseApiClient } from '../BaseApiClient'
/**
* 抽象的OpenAI基础客户端类,包含两个OpenAI客户端之间的共享功能
*/
export abstract class OpenAIBaseClient<
TSdkInstance extends OpenAI | AzureOpenAI,
TSdkParams extends OpenAISdkParams | OpenAIResponseSdkParams,
TRawOutput extends OpenAISdkRawOutput | OpenAIResponseSdkRawOutput,
TRawChunk extends OpenAISdkRawChunk | OpenAIResponseSdkRawChunk,
TMessageParam extends OpenAISdkMessageParam | OpenAIResponseSdkMessageParam,
TToolCall extends OpenAI.Chat.Completions.ChatCompletionMessageToolCall | OpenAIResponseSdkToolCall,
TSdkSpecificTool extends OpenAI.Chat.Completions.ChatCompletionTool | OpenAIResponseSdkTool
> extends BaseApiClient<TSdkInstance, TSdkParams, TRawOutput, TRawChunk, TMessageParam, TToolCall, TSdkSpecificTool> {
constructor(provider: Provider) {
super(provider)
}
// 仅适用于openai
override getBaseURL(): string {
const host = this.provider.apiHost
return formatApiHost(host)
}
override async generateImage({
model,
prompt,
negativePrompt,
imageSize,
batchSize,
seed,
numInferenceSteps,
guidanceScale,
signal,
promptEnhancement
}: GenerateImageParams): Promise<string[]> {
const sdk = await this.getSdkInstance()
const response = (await sdk.request({
method: 'post',
path: '/images/generations',
signal,
body: {
model,
prompt,
negative_prompt: negativePrompt,
image_size: imageSize,
batch_size: batchSize,
seed: seed ? parseInt(seed) : undefined,
num_inference_steps: numInferenceSteps,
guidance_scale: guidanceScale,
prompt_enhancement: promptEnhancement
}
})) as { data: Array<{ url: string }> }
return response.data.map((item) => item.url)
}
override async getEmbeddingDimensions(model: Model): Promise<number> {
const sdk = await this.getSdkInstance()
try {
const data = await sdk.embeddings.create({
model: model.id,
input: model?.provider === 'baidu-cloud' ? ['hi'] : 'hi',
encoding_format: 'float'
})
return data.data[0].embedding.length
} catch (e) {
return 0
}
}
override async listModels(): Promise<OpenAI.Models.Model[]> {
try {
const sdk = await this.getSdkInstance()
const response = await sdk.models.list()
if (this.provider.id === 'github') {
// @ts-ignore key is not typed
return response?.body
.map((model) => ({
id: model.name,
description: model.summary,
object: 'model',
owned_by: model.publisher
}))
.filter(isSupportedModel)
}
if (this.provider.id === 'together') {
// @ts-ignore key is not typed
return response?.body.map((model) => ({
id: model.id,
description: model.display_name,
object: 'model',
owned_by: model.organization
}))
}
const models = response.data || []
models.forEach((model) => {
model.id = model.id.trim()
})
return models.filter(isSupportedModel)
} catch (error) {
console.error('Error listing models:', error)
return []
}
}
override async getSdkInstance() {
if (this.sdkInstance) {
return this.sdkInstance
}
let apiKeyForSdkInstance = this.provider.apiKey
if (this.provider.id === 'copilot') {
const defaultHeaders = store.getState().copilot.defaultHeaders
const { token } = await window.api.copilot.getToken(defaultHeaders)
// this.provider.apiKey不允许修改
// this.provider.apiKey = token
apiKeyForSdkInstance = token
}
if (this.provider.id === 'azure-openai' || this.provider.type === 'azure-openai') {
this.sdkInstance = new AzureOpenAI({
dangerouslyAllowBrowser: true,
apiKey: apiKeyForSdkInstance,
apiVersion: this.provider.apiVersion,
endpoint: this.provider.apiHost
}) as TSdkInstance
} else {
this.sdkInstance = new OpenAI({
dangerouslyAllowBrowser: true,
apiKey: apiKeyForSdkInstance,
baseURL: this.getBaseURL(),
defaultHeaders: {
...this.defaultHeaders(),
...(this.provider.id === 'copilot' ? { 'editor-version': 'vscode/1.97.2' } : {}),
...(this.provider.id === 'copilot' ? { 'copilot-vision-request': 'true' } : {})
}
}) as TSdkInstance
}
return this.sdkInstance
}
override getTemperature(assistant: Assistant, model: Model): number | undefined {
if (
isNotSupportTemperatureAndTopP(model) ||
(assistant.settings?.reasoning_effort && isClaudeReasoningModel(model))
) {
return undefined
}
return assistant.settings?.temperature
}
override getTopP(assistant: Assistant, model: Model): number | undefined {
if (
isNotSupportTemperatureAndTopP(model) ||
(assistant.settings?.reasoning_effort && isClaudeReasoningModel(model))
) {
return undefined
}
return assistant.settings?.topP
}
/**
* Get the provider specific parameters for the assistant
* @param assistant - The assistant
* @param model - The model
* @returns The provider specific parameters
*/
protected getProviderSpecificParameters(assistant: Assistant, model: Model) {
const { maxTokens } = getAssistantSettings(assistant)
if (this.provider.id === 'openrouter') {
if (model.id.includes('deepseek-r1')) {
return {
include_reasoning: true
}
}
}
if (isOpenAIReasoningModel(model)) {
return {
max_tokens: undefined,
max_completion_tokens: maxTokens
}
}
return {}
}
/**
* Get the reasoning effort for the assistant
* @param assistant - The assistant
* @param model - The model
* @returns The reasoning effort
*/
protected getReasoningEffort(assistant: Assistant, model: Model): ReasoningEffortOptionalParams {
if (!isSupportedReasoningEffortOpenAIModel(model)) {
return {}
}
const openAI = getStoreSetting('openAI') as SettingsState['openAI']
const summaryText = openAI?.summaryText || 'off'
let summary: string | undefined = undefined
if (summaryText === 'off' || model.id.includes('o1-pro')) {
summary = undefined
} else {
summary = summaryText
}
const reasoningEffort = assistant?.settings?.reasoning_effort
if (!reasoningEffort) {
return {}
}
if (isSupportedReasoningEffortOpenAIModel(model)) {
return {
reasoning: {
effort: reasoningEffort as OpenAI.ReasoningEffort,
summary: summary
} as OpenAI.Reasoning
}
}
return {}
}
}
@@ -0,0 +1,532 @@
import { GenericChunk } from '@renderer/aiCore/middleware/schemas'
import {
isOpenAIChatCompletionOnlyModel,
isSupportedReasoningEffortOpenAIModel,
isVisionModel
} from '@renderer/config/models'
import { estimateTextTokens } from '@renderer/services/TokenService'
import {
FileTypes,
MCPCallToolResponse,
MCPTool,
MCPToolResponse,
Model,
Provider,
ToolCallResponse,
WebSearchSource
} from '@renderer/types'
import { ChunkType } from '@renderer/types/chunk'
import { Message } from '@renderer/types/newMessage'
import {
OpenAIResponseSdkMessageParam,
OpenAIResponseSdkParams,
OpenAIResponseSdkRawChunk,
OpenAIResponseSdkRawOutput,
OpenAIResponseSdkTool,
OpenAIResponseSdkToolCall
} from '@renderer/types/sdk'
import { addImageFileToContents } from '@renderer/utils/formats'
import {
isEnabledToolUse,
mcpToolCallResponseToOpenAIMessage,
mcpToolsToOpenAIResponseTools,
openAIToolsToMcpTool
} from '@renderer/utils/mcp-tools'
import { findFileBlocks, findImageBlocks } from '@renderer/utils/messageUtils/find'
import { buildSystemPrompt } from '@renderer/utils/prompt'
import { isEmpty } from 'lodash'
import OpenAI from 'openai'
import { RequestTransformer, ResponseChunkTransformer } from '../types'
import { OpenAIAPIClient } from './OpenAIApiClient'
import { OpenAIBaseClient } from './OpenAIBaseClient'
export class OpenAIResponseAPIClient extends OpenAIBaseClient<
OpenAI,
OpenAIResponseSdkParams,
OpenAIResponseSdkRawOutput,
OpenAIResponseSdkRawChunk,
OpenAIResponseSdkMessageParam,
OpenAIResponseSdkToolCall,
OpenAIResponseSdkTool
> {
private client: OpenAIAPIClient
constructor(provider: Provider) {
super(provider)
this.client = new OpenAIAPIClient(provider)
}
/**
* 根据模型特征选择合适的客户端
*/
public getClient(model: Model) {
if (isOpenAIChatCompletionOnlyModel(model)) {
return this.client
} else {
return this
}
}
override async getSdkInstance() {
if (this.sdkInstance) {
return this.sdkInstance
}
return new OpenAI({
dangerouslyAllowBrowser: true,
apiKey: this.provider.apiKey,
baseURL: this.getBaseURL(),
defaultHeaders: {
...this.defaultHeaders()
}
})
}
override async createCompletions(
payload: OpenAIResponseSdkParams,
options?: OpenAI.RequestOptions
): Promise<OpenAIResponseSdkRawOutput> {
const sdk = await this.getSdkInstance()
return await sdk.responses.create(payload, options)
}
public async convertMessageToSdkParam(message: Message, model: Model): Promise<OpenAIResponseSdkMessageParam> {
const isVision = isVisionModel(model)
const content = await this.getMessageContent(message)
const fileBlocks = findFileBlocks(message)
const imageBlocks = findImageBlocks(message)
if (fileBlocks.length === 0 && imageBlocks.length === 0) {
if (message.role === 'assistant') {
return {
role: 'assistant',
content: content
}
} else {
return {
role: message.role === 'system' ? 'user' : message.role,
content: content ? [{ type: 'input_text', text: content }] : []
} as OpenAI.Responses.EasyInputMessage
}
}
const parts: OpenAI.Responses.ResponseInputContent[] = []
if (content) {
parts.push({
type: 'input_text',
text: content
})
}
for (const imageBlock of imageBlocks) {
if (isVision) {
if (imageBlock.file) {
const image = await window.api.file.base64Image(imageBlock.file.id + imageBlock.file.ext)
parts.push({
detail: 'auto',
type: 'input_image',
image_url: image.data as string
})
} else if (imageBlock.url && imageBlock.url.startsWith('data:')) {
parts.push({
detail: 'auto',
type: 'input_image',
image_url: imageBlock.url
})
}
}
}
for (const fileBlock of fileBlocks) {
const file = fileBlock.file
if (!file) continue
if ([FileTypes.TEXT, FileTypes.DOCUMENT].includes(file.type)) {
const fileContent = (await window.api.file.read(file.id + file.ext)).trim()
parts.push({
type: 'input_text',
text: file.origin_name + '\n' + fileContent
})
}
}
return {
role: message.role === 'system' ? 'user' : message.role,
content: parts
}
}
public convertMcpToolsToSdkTools(mcpTools: MCPTool[]): OpenAI.Responses.Tool[] {
return mcpToolsToOpenAIResponseTools(mcpTools)
}
public convertSdkToolCallToMcp(toolCall: OpenAIResponseSdkToolCall, mcpTools: MCPTool[]): MCPTool | undefined {
return openAIToolsToMcpTool(mcpTools, toolCall)
}
public convertSdkToolCallToMcpToolResponse(toolCall: OpenAIResponseSdkToolCall, mcpTool: MCPTool): ToolCallResponse {
const parsedArgs = (() => {
try {
return JSON.parse(toolCall.arguments)
} catch {
return toolCall.arguments
}
})()
return {
id: toolCall.call_id,
toolCallId: toolCall.call_id,
tool: mcpTool,
arguments: parsedArgs,
status: 'pending'
}
}
public convertMcpToolResponseToSdkMessageParam(
mcpToolResponse: MCPToolResponse,
resp: MCPCallToolResponse,
model: Model
): OpenAIResponseSdkMessageParam | undefined {
if ('toolUseId' in mcpToolResponse && mcpToolResponse.toolUseId) {
return mcpToolCallResponseToOpenAIMessage(mcpToolResponse, resp, isVisionModel(model))
} else if ('toolCallId' in mcpToolResponse && mcpToolResponse.toolCallId) {
return {
type: 'function_call_output',
call_id: mcpToolResponse.toolCallId,
output: JSON.stringify(resp.content)
}
}
return
}
public buildSdkMessages(
currentReqMessages: OpenAIResponseSdkMessageParam[],
output: string,
toolResults: OpenAIResponseSdkMessageParam[],
toolCalls: OpenAIResponseSdkToolCall[]
): OpenAIResponseSdkMessageParam[] {
const assistantMessage: OpenAIResponseSdkMessageParam = {
role: 'assistant',
content: [{ type: 'input_text', text: output }]
}
const newReqMessages = [...currentReqMessages, assistantMessage, ...(toolCalls || []), ...(toolResults || [])]
return newReqMessages
}
override estimateMessageTokens(message: OpenAIResponseSdkMessageParam): number {
let sum = 0
if ('content' in message) {
if (typeof message.content === 'string') {
sum += estimateTextTokens(message.content)
} else if (Array.isArray(message.content)) {
for (const part of message.content) {
switch (part.type) {
case 'input_text':
sum += estimateTextTokens(part.text)
break
case 'input_image':
sum += estimateTextTokens(part.image_url || '')
break
default:
break
}
}
}
}
switch (message.type) {
case 'function_call_output':
sum += estimateTextTokens(message.output)
break
case 'function_call':
sum += estimateTextTokens(message.arguments)
break
default:
break
}
return sum
}
public extractMessagesFromSdkPayload(sdkPayload: OpenAIResponseSdkParams): OpenAIResponseSdkMessageParam[] {
if (typeof sdkPayload.input === 'string') {
return [{ role: 'user', content: sdkPayload.input }]
}
return sdkPayload.input
}
getRequestTransformer(): RequestTransformer<OpenAIResponseSdkParams, OpenAIResponseSdkMessageParam> {
return {
transform: async (
coreRequest,
assistant,
model,
isRecursiveCall,
recursiveSdkMessages
): Promise<{
payload: OpenAIResponseSdkParams
messages: OpenAIResponseSdkMessageParam[]
metadata: Record<string, any>
}> => {
const { messages, mcpTools, maxTokens, streamOutput, enableWebSearch, enableGenerateImage } = coreRequest
// 1. 处理系统消息
const systemMessage: OpenAI.Responses.EasyInputMessage = {
role: 'system',
content: []
}
const systemMessageContent: OpenAI.Responses.ResponseInputMessageContentList = []
const systemMessageInput: OpenAI.Responses.ResponseInputText = {
text: assistant.prompt || '',
type: 'input_text'
}
if (isSupportedReasoningEffortOpenAIModel(model)) {
systemMessage.role = 'developer'
}
// 2. 设置工具
let tools: OpenAI.Responses.Tool[] = []
const { tools: extraTools } = this.setupToolsConfig({
mcpTools: mcpTools,
model,
enableToolUse: isEnabledToolUse(assistant)
})
if (this.useSystemPromptForTools) {
systemMessageInput.text = await buildSystemPrompt(systemMessageInput.text || '', mcpTools)
}
systemMessageContent.push(systemMessageInput)
systemMessage.content = systemMessageContent
// 3. 处理用户消息
let userMessage: OpenAI.Responses.ResponseInputItem[] = []
if (typeof messages === 'string') {
userMessage.push({ role: 'user', content: messages })
} else {
const processedMessages = addImageFileToContents(messages)
for (const message of processedMessages) {
userMessage.push(await this.convertMessageToSdkParam(message, model))
}
}
// FIXME: 最好还是直接使用previous_response_id来处理(或者在数据库中存储image_generation_call的id
if (enableGenerateImage) {
const finalAssistantMessage = userMessage.findLast(
(m) => (m as OpenAI.Responses.EasyInputMessage).role === 'assistant'
) as OpenAI.Responses.EasyInputMessage
const finalUserMessage = userMessage.pop() as OpenAI.Responses.EasyInputMessage
if (
finalAssistantMessage &&
Array.isArray(finalAssistantMessage.content) &&
finalUserMessage &&
Array.isArray(finalUserMessage.content)
) {
finalAssistantMessage.content = [...finalAssistantMessage.content, ...finalUserMessage.content]
}
// 这里是故意将上条助手消息的内容(包含图片和文件)作为用户消息发送
userMessage = [{ ...finalAssistantMessage, role: 'user' } as OpenAI.Responses.EasyInputMessage]
}
// 4. 最终请求消息
let reqMessages: OpenAI.Responses.ResponseInput
if (!systemMessage.content) {
reqMessages = [...userMessage]
} else {
reqMessages = [systemMessage, ...userMessage].filter(Boolean) as OpenAI.Responses.EasyInputMessage[]
}
if (enableWebSearch) {
tools.push({
type: 'web_search_preview'
})
}
if (enableGenerateImage) {
tools.push({
type: 'image_generation',
partial_images: streamOutput ? 2 : undefined
})
}
const toolChoices: OpenAI.Responses.ToolChoiceTypes = {
type: 'web_search_preview'
}
tools = tools.concat(extraTools)
const commonParams = {
model: model.id,
input:
isRecursiveCall && recursiveSdkMessages && recursiveSdkMessages.length > 0
? recursiveSdkMessages
: reqMessages,
temperature: this.getTemperature(assistant, model),
top_p: this.getTopP(assistant, model),
max_output_tokens: maxTokens,
stream: streamOutput,
tools: !isEmpty(tools) ? tools : undefined,
tool_choice: enableWebSearch ? toolChoices : undefined,
service_tier: this.getServiceTier(model),
...(this.getReasoningEffort(assistant, model) as OpenAI.Reasoning),
...this.getCustomParameters(assistant)
}
const sdkParams: OpenAIResponseSdkParams = streamOutput
? {
...commonParams,
stream: true
}
: {
...commonParams,
stream: false
}
const timeout = this.getTimeout(model)
return { payload: sdkParams, messages: reqMessages, metadata: { timeout } }
}
}
}
getResponseChunkTransformer(): ResponseChunkTransformer<OpenAIResponseSdkRawChunk> {
const toolCalls: OpenAIResponseSdkToolCall[] = []
const outputItems: OpenAI.Responses.ResponseOutputItem[] = []
return () => ({
async transform(chunk: OpenAIResponseSdkRawChunk, controller: TransformStreamDefaultController<GenericChunk>) {
// 处理chunk
if ('output' in chunk) {
for (const output of chunk.output) {
switch (output.type) {
case 'message':
if (output.content[0].type === 'output_text') {
controller.enqueue({
type: ChunkType.TEXT_DELTA,
text: output.content[0].text
})
if (output.content[0].annotations && output.content[0].annotations.length > 0) {
controller.enqueue({
type: ChunkType.LLM_WEB_SEARCH_COMPLETE,
llm_web_search: {
source: WebSearchSource.OPENAI_RESPONSE,
results: output.content[0].annotations
}
})
}
}
break
case 'reasoning':
controller.enqueue({
type: ChunkType.THINKING_DELTA,
text: output.summary.map((s) => s.text).join('\n')
})
break
case 'function_call':
toolCalls.push(output)
break
case 'image_generation_call':
controller.enqueue({
type: ChunkType.IMAGE_CREATED
})
controller.enqueue({
type: ChunkType.IMAGE_COMPLETE,
image: {
type: 'base64',
images: [`data:image/png;base64,${output.result}`]
}
})
}
}
} else {
switch (chunk.type) {
case 'response.output_item.added':
if (chunk.item.type === 'function_call') {
outputItems.push(chunk.item)
}
break
case 'response.reasoning_summary_text.delta':
controller.enqueue({
type: ChunkType.THINKING_DELTA,
text: chunk.delta
})
break
case 'response.image_generation_call.generating':
controller.enqueue({
type: ChunkType.IMAGE_CREATED
})
break
case 'response.image_generation_call.partial_image':
controller.enqueue({
type: ChunkType.IMAGE_DELTA,
image: {
type: 'base64',
images: [`data:image/png;base64,${chunk.partial_image_b64}`]
}
})
break
case 'response.image_generation_call.completed':
controller.enqueue({
type: ChunkType.IMAGE_COMPLETE
})
break
case 'response.output_text.delta': {
controller.enqueue({
type: ChunkType.TEXT_DELTA,
text: chunk.delta
})
break
}
case 'response.function_call_arguments.done': {
const outputItem: OpenAI.Responses.ResponseOutputItem | undefined = outputItems.find(
(item) => item.id === chunk.item_id
)
if (outputItem) {
if (outputItem.type === 'function_call') {
toolCalls.push({
...outputItem,
arguments: chunk.arguments
})
}
}
break
}
case 'response.content_part.done': {
if (chunk.part.type === 'output_text' && chunk.part.annotations && chunk.part.annotations.length > 0) {
controller.enqueue({
type: ChunkType.LLM_WEB_SEARCH_COMPLETE,
llm_web_search: {
source: WebSearchSource.OPENAI_RESPONSE,
results: chunk.part.annotations
}
})
}
if (toolCalls.length > 0) {
controller.enqueue({
type: ChunkType.MCP_TOOL_CREATED,
tool_calls: toolCalls
})
}
break
}
case 'response.completed': {
const completion_tokens = chunk.response.usage?.output_tokens || 0
const total_tokens = chunk.response.usage?.total_tokens || 0
controller.enqueue({
type: ChunkType.LLM_RESPONSE_COMPLETE,
response: {
usage: {
prompt_tokens: chunk.response.usage?.input_tokens || 0,
completion_tokens: completion_tokens,
total_tokens: total_tokens
}
}
})
break
}
case 'error': {
controller.enqueue({
type: ChunkType.ERROR,
error: {
message: chunk.message,
code: chunk.code
}
})
break
}
}
}
}
})
}
}
+129
View File
@@ -0,0 +1,129 @@
import Anthropic from '@anthropic-ai/sdk'
import { Assistant, MCPTool, MCPToolResponse, Model, ToolCallResponse } from '@renderer/types'
import { Provider } from '@renderer/types'
import {
AnthropicSdkRawChunk,
OpenAISdkRawChunk,
SdkMessageParam,
SdkParams,
SdkRawChunk,
SdkRawOutput,
SdkTool,
SdkToolCall
} from '@renderer/types/sdk'
import OpenAI from 'openai'
import { CompletionsParams, GenericChunk } from '../middleware/schemas'
/**
* 原始流监听器接口
*/
export interface RawStreamListener<TRawChunk = SdkRawChunk> {
onChunk?: (chunk: TRawChunk) => void
onStart?: () => void
onEnd?: () => void
onError?: (error: Error) => void
}
/**
* OpenAI 专用的流监听器
*/
export interface OpenAIStreamListener extends RawStreamListener<OpenAISdkRawChunk> {
onChoice?: (choice: OpenAI.Chat.Completions.ChatCompletionChunk.Choice) => void
onFinishReason?: (reason: string) => void
}
/**
* Anthropic 专用的流监听器
*/
export interface AnthropicStreamListener<TChunk extends AnthropicSdkRawChunk = AnthropicSdkRawChunk>
extends RawStreamListener<TChunk> {
onContentBlock?: (contentBlock: Anthropic.Messages.ContentBlock) => void
onMessage?: (message: Anthropic.Messages.Message) => void
}
/**
* 请求转换器接口
*/
export interface RequestTransformer<
TSdkParams extends SdkParams = SdkParams,
TMessageParam extends SdkMessageParam = SdkMessageParam
> {
transform(
completionsParams: CompletionsParams,
assistant: Assistant,
model: Model,
isRecursiveCall?: boolean,
recursiveSdkMessages?: TMessageParam[]
): Promise<{
payload: TSdkParams
messages: TMessageParam[]
metadata?: Record<string, any>
}>
}
/**
* 响应块转换器接口
*/
export type ResponseChunkTransformer<TRawChunk extends SdkRawChunk = SdkRawChunk, TContext = any> = (
context?: TContext
) => Transformer<TRawChunk, GenericChunk>
export interface ResponseChunkTransformerContext {
isStreaming: boolean
isEnabledToolCalling: boolean
isEnabledWebSearch: boolean
isEnabledReasoning: boolean
mcpTools: MCPTool[]
provider: Provider
}
/**
* API客户端接口
*/
export interface ApiClient<
TSdkInstance = any,
TSdkParams extends SdkParams = SdkParams,
TRawOutput extends SdkRawOutput = SdkRawOutput,
TRawChunk extends SdkRawChunk = SdkRawChunk,
TMessageParam extends SdkMessageParam = SdkMessageParam,
TToolCall extends SdkToolCall = SdkToolCall,
TSdkSpecificTool extends SdkTool = SdkTool
> {
provider: Provider
// 核心方法 - 在中间件架构中,这个方法可能只是一个占位符
// 实际的SDK调用由SdkCallMiddleware处理
// completions(params: CompletionsParams): Promise<CompletionsResult>
createCompletions(payload: TSdkParams): Promise<TRawOutput>
// SDK相关方法
getSdkInstance(): Promise<TSdkInstance> | TSdkInstance
getRequestTransformer(): RequestTransformer<TSdkParams, TMessageParam>
getResponseChunkTransformer(): ResponseChunkTransformer<TRawChunk>
// 原始流监听方法
attachRawStreamListener?(rawOutput: TRawOutput, listener: RawStreamListener<TRawChunk>): TRawOutput
// 工具转换相关方法 (保持可选,因为不是所有Provider都支持工具)
convertMcpToolsToSdkTools(mcpTools: MCPTool[]): TSdkSpecificTool[]
convertMcpToolResponseToSdkMessageParam?(
mcpToolResponse: MCPToolResponse,
resp: any,
model: Model
): TMessageParam | undefined
convertSdkToolCallToMcp?(toolCall: TToolCall, mcpTools: MCPTool[]): MCPTool | undefined
convertSdkToolCallToMcpToolResponse(toolCall: TToolCall, mcpTool: MCPTool): ToolCallResponse
// 构建SDK特定的消息列表,用于工具调用后的递归调用
buildSdkMessages(
currentReqMessages: TMessageParam[],
output: TRawOutput | string,
toolResults: TMessageParam[],
toolCalls?: TToolCall[]
): TMessageParam[]
// 从SDK载荷中提取消息数组(用于中间件中的类型安全访问)
extractMessagesFromSdkPayload(sdkPayload: TSdkParams): TMessageParam[]
}
+130
View File
@@ -0,0 +1,130 @@
import { ApiClientFactory } from '@renderer/aiCore/clients/ApiClientFactory'
import { BaseApiClient } from '@renderer/aiCore/clients/BaseApiClient'
import { isDedicatedImageGenerationModel, isFunctionCallingModel } from '@renderer/config/models'
import type { GenerateImageParams, Model, Provider } from '@renderer/types'
import { RequestOptions, SdkModel } from '@renderer/types/sdk'
import { isEnabledToolUse } from '@renderer/utils/mcp-tools'
import { OpenAIAPIClient } from './clients'
import { AihubmixAPIClient } from './clients/AihubmixAPIClient'
import { AnthropicAPIClient } from './clients/anthropic/AnthropicAPIClient'
import { OpenAIResponseAPIClient } from './clients/openai/OpenAIResponseAPIClient'
import { CompletionsMiddlewareBuilder } from './middleware/builder'
import { MIDDLEWARE_NAME as AbortHandlerMiddlewareName } from './middleware/common/AbortHandlerMiddleware'
import { MIDDLEWARE_NAME as FinalChunkConsumerMiddlewareName } from './middleware/common/FinalChunkConsumerMiddleware'
import { applyCompletionsMiddlewares } from './middleware/composer'
import { MIDDLEWARE_NAME as McpToolChunkMiddlewareName } from './middleware/core/McpToolChunkMiddleware'
import { MIDDLEWARE_NAME as RawStreamListenerMiddlewareName } from './middleware/core/RawStreamListenerMiddleware'
import { MIDDLEWARE_NAME as ThinkChunkMiddlewareName } from './middleware/core/ThinkChunkMiddleware'
import { MIDDLEWARE_NAME as WebSearchMiddlewareName } from './middleware/core/WebSearchMiddleware'
import { MIDDLEWARE_NAME as ImageGenerationMiddlewareName } from './middleware/feat/ImageGenerationMiddleware'
import { MIDDLEWARE_NAME as ThinkingTagExtractionMiddlewareName } from './middleware/feat/ThinkingTagExtractionMiddleware'
import { MIDDLEWARE_NAME as ToolUseExtractionMiddlewareName } from './middleware/feat/ToolUseExtractionMiddleware'
import { MiddlewareRegistry } from './middleware/register'
import { CompletionsParams, CompletionsResult } from './middleware/schemas'
export default class AiProvider {
private apiClient: BaseApiClient
constructor(provider: Provider) {
// Use the new ApiClientFactory to get a BaseApiClient instance
this.apiClient = ApiClientFactory.create(provider)
}
public async completions(params: CompletionsParams, options?: RequestOptions): Promise<CompletionsResult> {
// 1. 根据模型识别正确的客户端
const model = params.assistant.model
if (!model) {
return Promise.reject(new Error('Model is required'))
}
// 根据client类型选择合适的处理方式
let client: BaseApiClient
if (this.apiClient instanceof AihubmixAPIClient) {
// AihubmixAPIClient: 根据模型选择合适的子client
client = this.apiClient.getClientForModel(model)
if (client instanceof OpenAIResponseAPIClient) {
client = client.getClient(model) as BaseApiClient
}
} else if (this.apiClient instanceof OpenAIResponseAPIClient) {
// OpenAIResponseAPIClient: 根据模型特征选择API类型
client = this.apiClient.getClient(model) as BaseApiClient
} else {
// 其他client直接使用
client = this.apiClient
}
// 2. 构建中间件链
const builder = CompletionsMiddlewareBuilder.withDefaults()
// images api
if (isDedicatedImageGenerationModel(model)) {
builder.clear()
builder
.add(MiddlewareRegistry[FinalChunkConsumerMiddlewareName])
.add(MiddlewareRegistry[AbortHandlerMiddlewareName])
.add(MiddlewareRegistry[ImageGenerationMiddlewareName])
} else {
// Existing logic for other models
if (!params.enableReasoning) {
builder.remove(ThinkingTagExtractionMiddlewareName)
builder.remove(ThinkChunkMiddlewareName)
}
// 注意:用client判断会导致typescript类型收窄
if (!(this.apiClient instanceof OpenAIAPIClient)) {
builder.remove(ThinkingTagExtractionMiddlewareName)
}
if (!(this.apiClient instanceof AnthropicAPIClient)) {
builder.remove(RawStreamListenerMiddlewareName)
}
if (!params.enableWebSearch) {
builder.remove(WebSearchMiddlewareName)
}
if (!params.mcpTools?.length) {
builder.remove(ToolUseExtractionMiddlewareName)
builder.remove(McpToolChunkMiddlewareName)
}
if (isEnabledToolUse(params.assistant) && isFunctionCallingModel(model)) {
builder.remove(ToolUseExtractionMiddlewareName)
}
if (params.callType !== 'chat') {
builder.remove(AbortHandlerMiddlewareName)
}
}
const middlewares = builder.build()
// 3. Create the wrapped SDK method with middlewares
const wrappedCompletionMethod = applyCompletionsMiddlewares(client, client.createCompletions, middlewares)
// 4. Execute the wrapped method with the original params
return wrappedCompletionMethod(params, options)
}
public async models(): Promise<SdkModel[]> {
return this.apiClient.listModels()
}
public async getEmbeddingDimensions(model: Model): Promise<number> {
try {
// Use the SDK instance to test embedding capabilities
const dimensions = await this.apiClient.getEmbeddingDimensions(model)
return dimensions
} catch (error) {
console.error('Error getting embedding dimensions:', error)
return 0
}
}
public async generateImage(params: GenerateImageParams): Promise<string[]> {
return this.apiClient.generateImage(params)
}
public getBaseURL(): string {
return this.apiClient.getBaseURL()
}
public getApiKey(): string {
return this.apiClient.getApiKey()
}
}
@@ -0,0 +1,182 @@
# MiddlewareBuilder 使用指南
`MiddlewareBuilder` 是一个用于动态构建和管理中间件链的工具,提供灵活的中间件组织和配置能力。
## 主要特性
### 1. 统一的中间件命名
所有中间件都通过导出的 `MIDDLEWARE_NAME` 常量标识:
```typescript
// 中间件文件示例
export const MIDDLEWARE_NAME = 'SdkCallMiddleware'
export const SdkCallMiddleware: CompletionsMiddleware = ...
```
### 2. NamedMiddleware 接口
中间件使用统一的 `NamedMiddleware` 接口格式:
```typescript
interface NamedMiddleware<TMiddleware = any> {
name: string
middleware: TMiddleware
}
```
### 3. 中间件注册表
通过 `MiddlewareRegistry` 集中管理所有可用中间件:
```typescript
import { MiddlewareRegistry } from './register'
// 通过名称获取中间件
const sdkCallMiddleware = MiddlewareRegistry['SdkCallMiddleware']
```
## 基本用法
### 1. 使用默认中间件链
```typescript
import { CompletionsMiddlewareBuilder } from './builder'
const builder = CompletionsMiddlewareBuilder.withDefaults()
const middlewares = builder.build()
```
### 2. 自定义中间件链
```typescript
import { createCompletionsBuilder, MiddlewareRegistry } from './builder'
const builder = createCompletionsBuilder([
MiddlewareRegistry['AbortHandlerMiddleware'],
MiddlewareRegistry['TextChunkMiddleware']
])
const middlewares = builder.build()
```
### 3. 动态调整中间件链
```typescript
const builder = CompletionsMiddlewareBuilder.withDefaults()
// 根据条件添加、移除、替换中间件
if (needsLogging) {
builder.prepend(MiddlewareRegistry['GenericLoggingMiddleware'])
}
if (disableTools) {
builder.remove('McpToolChunkMiddleware')
}
if (customThinking) {
builder.replace('ThinkingTagExtractionMiddleware', customThinkingMiddleware)
}
const middlewares = builder.build()
```
### 4. 链式操作
```typescript
const middlewares = CompletionsMiddlewareBuilder.withDefaults()
.add(MiddlewareRegistry['CustomMiddleware'])
.insertBefore('SdkCallMiddleware', MiddlewareRegistry['SecurityCheckMiddleware'])
.remove('WebSearchMiddleware')
.build()
```
## API 参考
### CompletionsMiddlewareBuilder
**静态方法:**
- `static withDefaults()`: 创建带有默认中间件链的构建器
**实例方法:**
- `add(middleware: NamedMiddleware)`: 在链末尾添加中间件
- `prepend(middleware: NamedMiddleware)`: 在链开头添加中间件
- `insertAfter(targetName: string, middleware: NamedMiddleware)`: 在指定中间件后插入
- `insertBefore(targetName: string, middleware: NamedMiddleware)`: 在指定中间件前插入
- `replace(targetName: string, middleware: NamedMiddleware)`: 替换指定中间件
- `remove(targetName: string)`: 移除指定中间件
- `has(name: string)`: 检查是否包含指定中间件
- `build()`: 构建最终的中间件数组
- `getChain()`: 获取当前链(包含名称信息)
- `clear()`: 清空中间件链
- `execute(context, params, middlewareExecutor)`: 直接执行构建好的中间件链
### 工厂函数
- `createCompletionsBuilder(baseChain?)`: 创建 Completions 中间件构建器
- `createMethodBuilder(baseChain?)`: 创建通用方法中间件构建器
- `addMiddlewareName(middleware, name)`: 为中间件添加名称属性的辅助函数
### 中间件注册表
- `MiddlewareRegistry`: 所有注册中间件的集中访问点
- `getMiddleware(name)`: 根据名称获取中间件
- `getRegisteredMiddlewareNames()`: 获取所有注册的中间件名称
- `DefaultCompletionsNamedMiddlewares`: 默认的 Completions 中间件链(NamedMiddleware 格式)
## 类型安全
构建器提供完整的 TypeScript 类型支持:
- `CompletionsMiddlewareBuilder` 专门用于 `CompletionsMiddleware` 类型
- `MethodMiddlewareBuilder` 用于通用的 `MethodMiddleware` 类型
- 所有中间件操作都基于 `NamedMiddleware<TMiddleware>` 接口
## 默认中间件链
默认的 Completions 中间件执行顺序:
1. `FinalChunkConsumerMiddleware` - 最终消费者
2. `TransformCoreToSdkParamsMiddleware` - 参数转换
3. `AbortHandlerMiddleware` - 中止处理
4. `McpToolChunkMiddleware` - 工具处理
5. `WebSearchMiddleware` - Web搜索处理
6. `TextChunkMiddleware` - 文本处理
7. `ThinkingTagExtractionMiddleware` - 思考标签提取处理
8. `ThinkChunkMiddleware` - 思考处理
9. `ResponseTransformMiddleware` - 响应转换
10. `StreamAdapterMiddleware` - 流适配器
11. `SdkCallMiddleware` - SDK调用
## 在 AiProvider 中的使用
```typescript
export default class AiProvider {
public async completions(params: CompletionsParams): Promise<CompletionsResult> {
// 1. 构建中间件链
const builder = CompletionsMiddlewareBuilder.withDefaults()
// 2. 根据参数动态调整
if (params.enableCustomFeature) {
builder.insertAfter('StreamAdapterMiddleware', customFeatureMiddleware)
}
// 3. 应用中间件
const middlewares = builder.build()
const wrappedMethod = applyCompletionsMiddlewares(this.apiClient, this.apiClient.createCompletions, middlewares)
return wrappedMethod(params)
}
}
```
## 注意事项
1. **类型兼容性**`MethodMiddleware``CompletionsMiddleware` 不兼容,需要使用对应的构建器
2. **中间件名称**:所有中间件必须导出 `MIDDLEWARE_NAME` 常量用于标识
3. **注册表管理**:新增中间件需要在 `register.ts` 中注册
4. **默认链**:默认链通过 `DefaultCompletionsNamedMiddlewares` 提供,支持延迟加载避免循环依赖
这种设计使得中间件链的构建既灵活又类型安全,同时保持了简洁的 API 接口。
@@ -0,0 +1,175 @@
# Cherry Studio 中间件规范
本文档定义了 Cherry Studio `aiCore` 模块中中间件的设计、实现和使用规范。目标是建立一个灵活、可维护且易于扩展的中间件系统。
## 1. 核心概念
### 1.1. 中间件 (Middleware)
中间件是一个函数或对象,它在 AI 请求的处理流程中的特定阶段执行,可以访问和修改请求上下文 (`AiProviderMiddlewareContext`)、请求参数 (`Params`),并控制是否将请求传递给下一个中间件或终止流程。
每个中间件应该专注于一个单一的横切关注点,例如日志记录、错误处理、流适配、特性解析等。
### 1.2. `AiProviderMiddlewareContext` (上下文对象)
这是一个在整个中间件链执行过程中传递的对象,包含以下核心信息:
- `_apiClientInstance: ApiClient<any,any,any>`: 当前选定的、已实例化的 AI Provider 客户端。
- `_coreRequest: CoreRequestType`: 标准化的内部核心请求对象。
- `resolvePromise: (value: AggregatedResultType) => void`: 用于在整个操作成功完成时解析 `AiCoreService` 返回的 Promise。
- `rejectPromise: (reason?: any) => void`: 用于在发生错误时拒绝 `AiCoreService` 返回的 Promise。
- `onChunk?: (chunk: Chunk) => void`: 应用层提供的流式数据块回调。
- `abortController?: AbortController`: 用于中止请求的控制器。
- 其他中间件可能读写的、与当前请求相关的动态数据。
### 1.3. `MiddlewareName` (中间件名称)
为了方便动态操作(如插入、替换、移除)中间件,每个重要的、可能被其他逻辑引用的中间件都应该有一个唯一的、可识别的名称。推荐使用 TypeScript 的 `enum` 来定义:
```typescript
// example
export enum MiddlewareName {
LOGGING_START = 'LoggingStartMiddleware',
LOGGING_END = 'LoggingEndMiddleware',
ERROR_HANDLING = 'ErrorHandlingMiddleware',
ABORT_HANDLER = 'AbortHandlerMiddleware',
// Core Flow
TRANSFORM_CORE_TO_SDK_PARAMS = 'TransformCoreToSdkParamsMiddleware',
REQUEST_EXECUTION = 'RequestExecutionMiddleware',
STREAM_ADAPTER = 'StreamAdapterMiddleware',
RAW_SDK_CHUNK_TO_APP_CHUNK = 'RawSdkChunkToAppChunkMiddleware',
// Features
THINKING_TAG_EXTRACTION = 'ThinkingTagExtractionMiddleware',
TOOL_USE_TAG_EXTRACTION = 'ToolUseTagExtractionMiddleware',
MCP_TOOL_HANDLER = 'McpToolHandlerMiddleware',
// Finalization
FINAL_CHUNK_CONSUMER = 'FinalChunkConsumerAndNotifierMiddleware'
// Add more as needed
}
```
中间件实例需要某种方式暴露其 `MiddlewareName`,例如通过一个 `name` 属性。
### 1.4. 中间件执行结构
我们采用一种灵活的中间件执行结构。一个中间件通常是一个函数,它接收 `Context``Params`,以及一个 `next` 函数(用于调用链中的下一个中间件)。
```typescript
// 简化形式的中间件函数签名
type MiddlewareFunction = (
context: AiProviderMiddlewareContext,
params: any, // e.g., CompletionsParams
next: () => Promise<void> // next 通常返回 Promise 以支持异步操作
) => Promise<void> // 中间件自身也可能返回 Promise
// 或者更经典的 Koa/Express 风格 (三段式)
// type MiddlewareFactory = (api?: MiddlewareApi) =>
// (nextMiddleware: (ctx: AiProviderMiddlewareContext, params: any) => Promise<void>) =>
// (context: AiProviderMiddlewareContext, params: any) => Promise<void>;
// 当前设计更倾向于上述简化的 MiddlewareFunction,由 MiddlewareExecutor 负责 next 的编排。
```
`MiddlewareExecutor` (或 `applyMiddlewares`) 会负责管理 `next` 的调用。
## 2. `MiddlewareBuilder` (通用中间件构建器)
为了动态构建和管理中间件链,我们引入一个通用的 `MiddlewareBuilder` 类。
### 2.1. 设计理念
`MiddlewareBuilder` 提供了一个流式 API,用于以声明式的方式构建中间件链。它允许从一个基础链开始,然后根据特定条件添加、插入、替换或移除中间件。
### 2.2. API 概览
```typescript
class MiddlewareBuilder {
constructor(baseChain?: Middleware[])
add(middleware: Middleware): this
prepend(middleware: Middleware): this
insertAfter(targetName: MiddlewareName, middlewareToInsert: Middleware): this
insertBefore(targetName: MiddlewareName, middlewareToInsert: Middleware): this
replace(targetName: MiddlewareName, newMiddleware: Middleware): this
remove(targetName: MiddlewareName): this
build(): Middleware[] // 返回构建好的中间件数组
// 可选:直接执行链
execute(
context: AiProviderMiddlewareContext,
params: any,
middlewareExecutor: (chain: Middleware[], context: AiProviderMiddlewareContext, params: any) => void
): void
}
```
### 2.3. 使用示例
```typescript
// 1. 定义一些中间件实例 (假设它们有 .name 属性)
const loggingStart = { name: MiddlewareName.LOGGING_START, fn: loggingStartFn }
const requestExec = { name: MiddlewareName.REQUEST_EXECUTION, fn: requestExecFn }
const streamAdapter = { name: MiddlewareName.STREAM_ADAPTER, fn: streamAdapterFn }
const customFeature = { name: MiddlewareName.CUSTOM_FEATURE, fn: customFeatureFn } // 假设自定义
// 2. 定义一个基础链 (可选)
const BASE_CHAIN: Middleware[] = [loggingStart, requestExec, streamAdapter]
// 3. 使用 MiddlewareBuilder
const builder = new MiddlewareBuilder(BASE_CHAIN)
if (params.needsCustomFeature) {
builder.insertAfter(MiddlewareName.STREAM_ADAPTER, customFeature)
}
if (params.isHighSecurityContext) {
builder.insertBefore(MiddlewareName.REQUEST_EXECUTION, высокоSecurityCheckMiddleware)
}
if (params.overrideLogging) {
builder.replace(MiddlewareName.LOGGING_START, newSpecialLoggingMiddleware)
}
// 4. 获取最终链
const finalChain = builder.build()
// 5. 执行 (通过外部执行器)
// middlewareExecutor(finalChain, context, params);
// 或者 builder.execute(context, params, middlewareExecutor);
```
## 3. `MiddlewareExecutor` / `applyMiddlewares` (中间件执行器)
这是负责接收 `MiddlewareBuilder` 构建的中间件链并实际执行它们的组件。
### 3.1. 职责
- 接收 `Middleware[]`, `AiProviderMiddlewareContext`, `Params`
- 按顺序迭代中间件。
- 为每个中间件提供正确的 `next` 函数,该函数在被调用时会执行链中的下一个中间件。
- 处理中间件执行过程中的Promise(如果中间件是异步的)。
- 基础的错误捕获(具体错误处理应由链内的 `ErrorHandlingMiddleware` 负责)。
## 4. 在 `AiCoreService` 中使用
`AiCoreService` 中的每个核心业务方法 (如 `executeCompletions`) 将负责:
1. 准备基础数据:实例化 `ApiClient`,转换 `Params``CoreRequest`
2. 实例化 `MiddlewareBuilder`,可能会传入一个特定于该业务方法的基础中间件链。
3. 根据 `Params``CoreRequest` 中的条件,调用 `MiddlewareBuilder` 的方法来动态调整中间件链。
4. 调用 `MiddlewareBuilder.build()` 获取最终的中间件链。
5. 创建完整的 `AiProviderMiddlewareContext` (包含 `resolvePromise`, `rejectPromise` 等)。
6. 调用 `MiddlewareExecutor` (或 `applyMiddlewares`) 来执行构建好的链。
## 5. 组合功能
对于组合功能(例如 "Completions then Translate"):
- 不推荐创建一个单一、庞大的 `MiddlewareBuilder` 来处理整个组合流程。
- 推荐在 `AiCoreService` 中创建一个新的方法,该方法按顺序 `await` 调用底层的原子 `AiCoreService` 方法(例如,先 `await this.executeCompletions(...)`,然后用其结果 `await this.translateText(...)`)。
- 每个被调用的原子方法内部会使用其自身的 `MiddlewareBuilder` 实例来构建和执行其特定阶段的中间件链。
- 这种方式最大化了复用,并保持了各部分职责的清晰。
## 6. 中间件命名和发现
为中间件赋予唯一的 `MiddlewareName` 对于 `MiddlewareBuilder``insertAfter`, `insertBefore`, `replace`, `remove` 等操作至关重要。确保中间件实例能够以某种方式暴露其名称(例如,一个 `name` 属性)。
@@ -0,0 +1,241 @@
import { DefaultCompletionsNamedMiddlewares } from './register'
import { BaseContext, CompletionsMiddleware, MethodMiddleware } from './types'
/**
* 带有名称标识的中间件接口
*/
export interface NamedMiddleware<TMiddleware = any> {
name: string
middleware: TMiddleware
}
/**
* 中间件执行器函数类型
*/
export type MiddlewareExecutor<TContext extends BaseContext = BaseContext> = (
chain: any[],
context: TContext,
params: any
) => Promise<any>
/**
* 通用中间件构建器类
* 提供流式 API 用于动态构建和管理中间件链
*
* 注意:所有中间件都通过 MiddlewareRegistry 管理,使用 NamedMiddleware 格式
*/
export class MiddlewareBuilder<TMiddleware = any> {
private middlewares: NamedMiddleware<TMiddleware>[]
/**
* 构造函数
* @param baseChain - 可选的基础中间件链(NamedMiddleware 格式)
*/
constructor(baseChain?: NamedMiddleware<TMiddleware>[]) {
this.middlewares = baseChain ? [...baseChain] : []
}
/**
* 在链的末尾添加中间件
* @param middleware - 要添加的具名中间件
* @returns this,支持链式调用
*/
add(middleware: NamedMiddleware<TMiddleware>): this {
this.middlewares.push(middleware)
return this
}
/**
* 在链的开头添加中间件
* @param middleware - 要添加的具名中间件
* @returns this,支持链式调用
*/
prepend(middleware: NamedMiddleware<TMiddleware>): this {
this.middlewares.unshift(middleware)
return this
}
/**
* 在指定中间件之后插入新中间件
* @param targetName - 目标中间件名称
* @param middlewareToInsert - 要插入的具名中间件
* @returns this,支持链式调用
*/
insertAfter(targetName: string, middlewareToInsert: NamedMiddleware<TMiddleware>): this {
const index = this.findMiddlewareIndex(targetName)
if (index !== -1) {
this.middlewares.splice(index + 1, 0, middlewareToInsert)
} else {
console.warn(`MiddlewareBuilder: 未找到名为 '${targetName}' 的中间件,无法插入`)
}
return this
}
/**
* 在指定中间件之前插入新中间件
* @param targetName - 目标中间件名称
* @param middlewareToInsert - 要插入的具名中间件
* @returns this,支持链式调用
*/
insertBefore(targetName: string, middlewareToInsert: NamedMiddleware<TMiddleware>): this {
const index = this.findMiddlewareIndex(targetName)
if (index !== -1) {
this.middlewares.splice(index, 0, middlewareToInsert)
} else {
console.warn(`MiddlewareBuilder: 未找到名为 '${targetName}' 的中间件,无法插入`)
}
return this
}
/**
* 替换指定的中间件
* @param targetName - 要替换的中间件名称
* @param newMiddleware - 新的具名中间件
* @returns this,支持链式调用
*/
replace(targetName: string, newMiddleware: NamedMiddleware<TMiddleware>): this {
const index = this.findMiddlewareIndex(targetName)
if (index !== -1) {
this.middlewares[index] = newMiddleware
} else {
console.warn(`MiddlewareBuilder: 未找到名为 '${targetName}' 的中间件,无法替换`)
}
return this
}
/**
* 移除指定的中间件
* @param targetName - 要移除的中间件名称
* @returns this,支持链式调用
*/
remove(targetName: string): this {
const index = this.findMiddlewareIndex(targetName)
if (index !== -1) {
this.middlewares.splice(index, 1)
}
return this
}
/**
* 构建最终的中间件数组
* @returns 构建好的中间件数组
*/
build(): TMiddleware[] {
return this.middlewares.map((item) => item.middleware)
}
/**
* 获取当前中间件链的副本(包含名称信息)
* @returns 当前中间件链的副本
*/
getChain(): NamedMiddleware<TMiddleware>[] {
return [...this.middlewares]
}
/**
* 检查是否包含指定名称的中间件
* @param name - 中间件名称
* @returns 是否包含该中间件
*/
has(name: string): boolean {
return this.findMiddlewareIndex(name) !== -1
}
/**
* 获取中间件链的长度
* @returns 中间件数量
*/
get length(): number {
return this.middlewares.length
}
/**
* 清空中间件链
* @returns this,支持链式调用
*/
clear(): this {
this.middlewares = []
return this
}
/**
* 直接执行构建好的中间件链
* @param context - 中间件上下文
* @param params - 参数
* @param middlewareExecutor - 中间件执行器
* @returns 执行结果
*/
execute<TContext extends BaseContext>(
context: TContext,
params: any,
middlewareExecutor: MiddlewareExecutor<TContext>
): Promise<any> {
const chain = this.build()
return middlewareExecutor(chain, context, params)
}
/**
* 查找中间件在链中的索引
* @param name - 中间件名称
* @returns 索引,如果未找到返回 -1
*/
private findMiddlewareIndex(name: string): number {
return this.middlewares.findIndex((item) => item.name === name)
}
}
/**
* Completions 中间件构建器
*/
export class CompletionsMiddlewareBuilder extends MiddlewareBuilder<CompletionsMiddleware> {
constructor(baseChain?: NamedMiddleware<CompletionsMiddleware>[]) {
super(baseChain)
}
/**
* 使用默认的 Completions 中间件链
* @returns CompletionsMiddlewareBuilder 实例
*/
static withDefaults(): CompletionsMiddlewareBuilder {
return new CompletionsMiddlewareBuilder(DefaultCompletionsNamedMiddlewares)
}
}
/**
* 通用方法中间件构建器
*/
export class MethodMiddlewareBuilder extends MiddlewareBuilder<MethodMiddleware> {
constructor(baseChain?: NamedMiddleware<MethodMiddleware>[]) {
super(baseChain)
}
}
// 便捷的工厂函数
/**
* 创建 Completions 中间件构建器
* @param baseChain - 可选的基础链
* @returns Completions 中间件构建器实例
*/
export function createCompletionsBuilder(
baseChain?: NamedMiddleware<CompletionsMiddleware>[]
): CompletionsMiddlewareBuilder {
return new CompletionsMiddlewareBuilder(baseChain)
}
/**
* 创建通用方法中间件构建器
* @param baseChain - 可选的基础链
* @returns 通用方法中间件构建器实例
*/
export function createMethodBuilder(baseChain?: NamedMiddleware<MethodMiddleware>[]): MethodMiddlewareBuilder {
return new MethodMiddlewareBuilder(baseChain)
}
/**
* 为中间件添加名称属性的辅助函数
* 可以用于给现有的中间件添加名称属性
*/
export function addMiddlewareName<T extends object>(middleware: T, name: string): T & { MIDDLEWARE_NAME: string } {
return Object.assign(middleware, { MIDDLEWARE_NAME: name })
}
@@ -0,0 +1,106 @@
import { Chunk, ChunkType, ErrorChunk } from '@renderer/types/chunk'
import { addAbortController, removeAbortController } from '@renderer/utils/abortController'
import { CompletionsParams, CompletionsResult } from '../schemas'
import type { CompletionsContext, CompletionsMiddleware } from '../types'
export const MIDDLEWARE_NAME = 'AbortHandlerMiddleware'
export const AbortHandlerMiddleware: CompletionsMiddleware =
() =>
(next) =>
async (ctx: CompletionsContext, params: CompletionsParams): Promise<CompletionsResult> => {
const isRecursiveCall = ctx._internal?.toolProcessingState?.isRecursiveCall || false
// 在递归调用中,跳过 AbortController 的创建,直接使用已有的
if (isRecursiveCall) {
const result = await next(ctx, params)
return result
}
// 获取当前消息的ID用于abort管理
// 优先使用处理过的消息,如果没有则使用原始消息
let messageId: string | undefined
if (typeof params.messages === 'string') {
messageId = `message-${Date.now()}-${Math.random().toString(36).substring(2, 9)}`
} else {
const processedMessages = params.messages
const lastUserMessage = processedMessages.findLast((m) => m.role === 'user')
messageId = lastUserMessage?.id
}
if (!messageId) {
console.warn(`[${MIDDLEWARE_NAME}] No messageId found, abort functionality will not be available.`)
return next(ctx, params)
}
const abortController = new AbortController()
const abortFn = (): void => abortController.abort()
addAbortController(messageId, abortFn)
let abortSignal: AbortSignal | null = abortController.signal
const cleanup = (): void => {
removeAbortController(messageId as string, abortFn)
if (ctx._internal?.flowControl) {
ctx._internal.flowControl.abortController = undefined
ctx._internal.flowControl.abortSignal = undefined
ctx._internal.flowControl.cleanup = undefined
}
abortSignal = null
}
// 将controller添加到_internal中的flowControl状态
if (!ctx._internal.flowControl) {
ctx._internal.flowControl = {}
}
ctx._internal.flowControl.abortController = abortController
ctx._internal.flowControl.abortSignal = abortSignal
ctx._internal.flowControl.cleanup = cleanup
const result = await next(ctx, params)
const error = new DOMException('Request was aborted', 'AbortError')
const streamWithAbortHandler = (result.stream as ReadableStream<Chunk>).pipeThrough(
new TransformStream<Chunk, Chunk | ErrorChunk>({
transform(chunk, controller) {
// 检查 abort 状态
if (abortSignal?.aborted) {
// 转换为 ErrorChunk
const errorChunk: ErrorChunk = {
type: ChunkType.ERROR,
error
}
controller.enqueue(errorChunk)
cleanup()
return
}
// 正常传递 chunk
controller.enqueue(chunk)
},
flush(controller) {
// 在流结束时再次检查 abort 状态
if (abortSignal?.aborted) {
const errorChunk: ErrorChunk = {
type: ChunkType.ERROR,
error
}
controller.enqueue(errorChunk)
}
// 在流完全处理完成后清理 AbortController
cleanup()
}
})
)
return {
...result,
stream: streamWithAbortHandler
}
}
@@ -0,0 +1,60 @@
import { Chunk } from '@renderer/types/chunk'
import { isAbortError } from '@renderer/utils/error'
import { CompletionsResult } from '../schemas'
import { CompletionsContext } from '../types'
import { createErrorChunk } from '../utils'
export const MIDDLEWARE_NAME = 'ErrorHandlerMiddleware'
/**
* 创建一个错误处理中间件。
*
* 这是一个高阶函数,它接收配置并返回一个标准的中间件。
* 它的主要职责是捕获下游中间件或API调用中发生的任何错误。
*
* @param config - 中间件的配置。
* @returns 一个配置好的CompletionsMiddleware。
*/
export const ErrorHandlerMiddleware =
() =>
(next) =>
async (ctx: CompletionsContext, params): Promise<CompletionsResult> => {
const { shouldThrow } = params
try {
// 尝试执行下一个中间件
return await next(ctx, params)
} catch (error: any) {
let errorStream: ReadableStream<Chunk> | undefined
// 有些sdk的abort error 是直接抛出的
if (!isAbortError(error)) {
// 1. 使用通用的工具函数将错误解析为标准格式
const errorChunk = createErrorChunk(error)
// 2. 调用从外部传入的 onError 回调
if (params.onError) {
params.onError(error)
}
// 3. 根据配置决定是重新抛出错误,还是将其作为流的一部分向下传递
if (shouldThrow) {
throw error
}
// 如果不抛出,则创建一个只包含该错误块的流并向下传递
errorStream = new ReadableStream<Chunk>({
start(controller) {
controller.enqueue(errorChunk)
controller.close()
}
})
}
return {
rawOutput: undefined,
stream: errorStream, // 将包含错误的流传递下去
controller: undefined,
getText: () => '' // 错误情况下没有文本结果
}
}
}
@@ -0,0 +1,183 @@
import Logger from '@renderer/config/logger'
import { Usage } from '@renderer/types'
import type { Chunk } from '@renderer/types/chunk'
import { ChunkType } from '@renderer/types/chunk'
import { CompletionsParams, CompletionsResult, GenericChunk } from '../schemas'
import { CompletionsContext, CompletionsMiddleware } from '../types'
export const MIDDLEWARE_NAME = 'FinalChunkConsumerAndNotifierMiddleware'
/**
* 最终Chunk消费和通知中间件
*
* 职责:
* 1. 消费所有GenericChunk流中的chunks并转发给onChunk回调
* 2. 累加usage/metrics数据(从原始SDK chunks或GenericChunk中提取)
* 3. 在检测到LLM_RESPONSE_COMPLETE时发送包含累计数据的BLOCK_COMPLETE
* 4. 处理MCP工具调用的多轮请求中的数据累加
*/
const FinalChunkConsumerMiddleware: CompletionsMiddleware =
() =>
(next) =>
async (ctx: CompletionsContext, params: CompletionsParams): Promise<CompletionsResult> => {
const isRecursiveCall =
params._internal?.toolProcessingState?.isRecursiveCall ||
ctx._internal?.toolProcessingState?.isRecursiveCall ||
false
// 初始化累计数据(只在顶层调用时初始化)
if (!isRecursiveCall) {
if (!ctx._internal.customState) {
ctx._internal.customState = {}
}
ctx._internal.observer = {
usage: {
prompt_tokens: 0,
completion_tokens: 0,
total_tokens: 0
},
metrics: {
completion_tokens: 0,
time_completion_millsec: 0,
time_first_token_millsec: 0,
time_thinking_millsec: 0
}
}
// 初始化文本累积器
ctx._internal.customState.accumulatedText = ''
ctx._internal.customState.startTimestamp = Date.now()
}
// 调用下游中间件
const result = await next(ctx, params)
// 响应后处理:处理GenericChunk流式响应
if (result.stream) {
const resultFromUpstream = result.stream
if (resultFromUpstream && resultFromUpstream instanceof ReadableStream) {
const reader = resultFromUpstream.getReader()
try {
while (true) {
const { done, value: chunk } = await reader.read()
if (done) {
Logger.debug(`[${MIDDLEWARE_NAME}] Input stream finished.`)
break
}
if (chunk) {
const genericChunk = chunk as GenericChunk
// 提取并累加usage/metrics数据
extractAndAccumulateUsageMetrics(ctx, genericChunk)
const shouldSkipChunk =
isRecursiveCall &&
(genericChunk.type === ChunkType.BLOCK_COMPLETE ||
genericChunk.type === ChunkType.LLM_RESPONSE_COMPLETE)
if (!shouldSkipChunk) params.onChunk?.(genericChunk)
} else {
Logger.warn(`[${MIDDLEWARE_NAME}] Received undefined chunk before stream was done.`)
}
}
} catch (error) {
Logger.error(`[${MIDDLEWARE_NAME}] Error consuming stream:`, error)
throw error
} finally {
if (params.onChunk && !isRecursiveCall) {
params.onChunk({
type: ChunkType.BLOCK_COMPLETE,
response: {
usage: ctx._internal.observer?.usage ? { ...ctx._internal.observer.usage } : undefined,
metrics: ctx._internal.observer?.metrics ? { ...ctx._internal.observer.metrics } : undefined
}
} as Chunk)
if (ctx._internal.toolProcessingState) {
ctx._internal.toolProcessingState = {}
}
}
}
// 为流式输出添加getText方法
const modifiedResult = {
...result,
stream: new ReadableStream<GenericChunk>({
start(controller) {
controller.close()
}
}),
getText: () => {
return ctx._internal.customState?.accumulatedText || ''
}
}
return modifiedResult
} else {
Logger.debug(`[${MIDDLEWARE_NAME}] No GenericChunk stream to process.`)
}
}
return result
}
/**
* 从GenericChunk或原始SDK chunks中提取usage/metrics数据并累加
*/
function extractAndAccumulateUsageMetrics(ctx: CompletionsContext, chunk: GenericChunk): void {
if (!ctx._internal.observer?.usage || !ctx._internal.observer?.metrics) {
return
}
try {
if (ctx._internal.customState && !ctx._internal.customState?.firstTokenTimestamp) {
ctx._internal.customState.firstTokenTimestamp = Date.now()
Logger.debug(`[${MIDDLEWARE_NAME}] First token timestamp: ${ctx._internal.customState.firstTokenTimestamp}`)
}
if (chunk.type === ChunkType.LLM_RESPONSE_COMPLETE) {
Logger.debug(`[${MIDDLEWARE_NAME}] LLM_RESPONSE_COMPLETE chunk received:`, ctx._internal)
// 从LLM_RESPONSE_COMPLETE chunk中提取usage数据
if (chunk.response?.usage) {
accumulateUsage(ctx._internal.observer.usage, chunk.response.usage)
}
if (ctx._internal.customState && ctx._internal.customState?.firstTokenTimestamp) {
ctx._internal.observer.metrics.time_first_token_millsec =
ctx._internal.customState.firstTokenTimestamp - ctx._internal.customState.startTimestamp
ctx._internal.observer.metrics.time_completion_millsec +=
Date.now() - ctx._internal.customState.firstTokenTimestamp
}
}
// 也可以从其他chunk类型中提取metrics数据
if (chunk.type === ChunkType.THINKING_COMPLETE && chunk.thinking_millsec && ctx._internal.observer?.metrics) {
ctx._internal.observer.metrics.time_thinking_millsec = Math.max(
ctx._internal.observer.metrics.time_thinking_millsec || 0,
chunk.thinking_millsec
)
}
} catch (error) {
console.error(`[${MIDDLEWARE_NAME}] Error extracting usage/metrics from chunk:`, error)
}
}
/**
* 累加usage数据
*/
function accumulateUsage(accumulated: Usage, newUsage: Usage): void {
if (newUsage.prompt_tokens !== undefined) {
accumulated.prompt_tokens += newUsage.prompt_tokens
}
if (newUsage.completion_tokens !== undefined) {
accumulated.completion_tokens += newUsage.completion_tokens
}
if (newUsage.total_tokens !== undefined) {
accumulated.total_tokens += newUsage.total_tokens
}
if (newUsage.thoughts_tokens !== undefined) {
accumulated.thoughts_tokens = (accumulated.thoughts_tokens || 0) + newUsage.thoughts_tokens
}
}
export default FinalChunkConsumerMiddleware
@@ -0,0 +1,64 @@
import { BaseContext, MethodMiddleware, MiddlewareAPI } from '../types'
export const MIDDLEWARE_NAME = 'GenericLoggingMiddlewares'
/**
* Helper function to safely stringify arguments for logging, handling circular references and large objects.
* 安全地字符串化日志参数的辅助函数,处理循环引用和大型对象。
* @param args - The arguments array to stringify. 要字符串化的参数数组。
* @returns A string representation of the arguments. 参数的字符串表示形式。
*/
const stringifyArgsForLogging = (args: any[]): string => {
try {
return args
.map((arg) => {
if (typeof arg === 'function') return '[Function]'
if (typeof arg === 'object' && arg !== null && arg.constructor === Object && Object.keys(arg).length > 20) {
return '[Object with >20 keys]'
}
// Truncate long strings to avoid flooding logs 截断长字符串以避免日志泛滥
const stringifiedArg = JSON.stringify(arg, null, 2)
return stringifiedArg && stringifiedArg.length > 200 ? stringifiedArg.substring(0, 200) + '...' : stringifiedArg
})
.join(', ')
} catch (e) {
return '[Error serializing arguments]' // Handle potential errors during stringification 处理字符串化期间的潜在错误
}
}
/**
* Generic logging middleware for provider methods.
* 为提供者方法创建一个通用的日志中间件。
* This middleware logs the initiation, success/failure, and duration of a method call.
* 此中间件记录方法调用的启动、成功/失败以及持续时间。
*/
/**
* Creates a generic logging middleware for provider methods.
* 为提供者方法创建一个通用的日志中间件。
* @returns A `MethodMiddleware` instance. 一个 `MethodMiddleware` 实例。
*/
export const createGenericLoggingMiddleware: () => MethodMiddleware = () => {
const middlewareName = 'GenericLoggingMiddleware'
// eslint-disable-next-line @typescript-eslint/no-unused-vars
return (_: MiddlewareAPI<BaseContext, any[]>) => (next) => async (ctx, args) => {
const methodName = ctx.methodName
const logPrefix = `[${middlewareName} (${methodName})]`
console.log(`${logPrefix} Initiating. Args:`, stringifyArgsForLogging(args))
const startTime = Date.now()
try {
const result = await next(ctx, args)
const duration = Date.now() - startTime
// Log successful completion of the method call with duration. /
// 记录方法调用成功完成及其持续时间。
console.log(`${logPrefix} Successful. Duration: ${duration}ms`)
return result
} catch (error) {
const duration = Date.now() - startTime
// Log failure of the method call with duration and error information. /
// 记录方法调用失败及其持续时间和错误信息。
console.error(`${logPrefix} Failed. Duration: ${duration}ms`, error)
throw error // Re-throw the error to be handled by subsequent layers or the caller / 重新抛出错误,由后续层或调用者处理
}
}
}
@@ -0,0 +1,285 @@
import {
RequestOptions,
SdkInstance,
SdkMessageParam,
SdkParams,
SdkRawChunk,
SdkRawOutput,
SdkTool,
SdkToolCall
} from '@renderer/types/sdk'
import { BaseApiClient } from '../clients'
import { CompletionsParams, CompletionsResult } from './schemas'
import {
BaseContext,
CompletionsContext,
CompletionsMiddleware,
MethodMiddleware,
MIDDLEWARE_CONTEXT_SYMBOL,
MiddlewareAPI
} from './types'
/**
* Creates the initial context for a method call, populating method-specific fields. /
* 为方法调用创建初始上下文,并填充特定于该方法的字段。
* @param methodName - The name of the method being called. / 被调用的方法名。
* @param originalCallArgs - The actual arguments array from the proxy/method call. / 代理/方法调用的实际参数数组。
* @param providerId - The ID of the provider, if available. / 提供者的ID(如果可用)。
* @param providerInstance - The instance of the provider. / 提供者实例。
* @param specificContextFactory - An optional factory function to create a specific context type from the base context and original call arguments. / 一个可选的工厂函数,用于从基础上下文和原始调用参数创建特定的上下文类型。
* @returns The created context object. / 创建的上下文对象。
*/
function createInitialCallContext<TContext extends BaseContext, TCallArgs extends unknown[]>(
methodName: string,
originalCallArgs: TCallArgs, // Renamed from originalArgs to avoid confusion with context.originalArgs
// Factory to create specific context from base and the *original call arguments array*
specificContextFactory?: (base: BaseContext, callArgs: TCallArgs) => TContext
): TContext {
const baseContext: BaseContext = {
[MIDDLEWARE_CONTEXT_SYMBOL]: true,
methodName,
originalArgs: originalCallArgs // Store the full original arguments array in the context
}
if (specificContextFactory) {
return specificContextFactory(baseContext, originalCallArgs)
}
return baseContext as TContext // Fallback to base context if no specific factory
}
/**
* Composes an array of functions from right to left. /
* 从右到左组合一个函数数组。
* `compose(f, g, h)` is `(...args) => f(g(h(...args)))`. /
* `compose(f, g, h)` 等同于 `(...args) => f(g(h(...args)))`。
* Each function in funcs is expected to take the result of the next function
* (or the initial value for the rightmost function) as its argument. /
* `funcs` 中的每个函数都期望接收下一个函数的结果(或最右侧函数的初始值)作为其参数。
* @param funcs - Array of functions to compose. / 要组合的函数数组。
* @returns The composed function. / 组合后的函数。
*/
function compose(...funcs: Array<(...args: any[]) => any>): (...args: any[]) => any {
if (funcs.length === 0) {
// If no functions to compose, return a function that returns its first argument, or undefined if no args. /
// 如果没有要组合的函数,则返回一个函数,该函数返回其第一个参数,如果没有参数则返回undefined。
return (...args: any[]) => (args.length > 0 ? args[0] : undefined)
}
if (funcs.length === 1) {
return funcs[0]
}
return funcs.reduce(
(a, b) =>
(...args: any[]) =>
a(b(...args))
)
}
/**
* Applies an array of Redux-style middlewares to a generic provider method. /
* 将一组Redux风格的中间件应用于一个通用的提供者方法。
* This version keeps arguments as an array throughout the middleware chain. /
* 此版本在整个中间件链中将参数保持为数组形式。
* @param originalProviderInstance - The original provider instance. / 原始提供者实例。
* @param methodName - The name of the method to be enhanced. / 需要增强的方法名。
* @param originalMethod - The original method to be wrapped. / 需要包装的原始方法。
* @param middlewares - An array of `ProviderMethodMiddleware` to apply. / 要应用的 `ProviderMethodMiddleware` 数组。
* @param specificContextFactory - An optional factory to create a specific context for this method. / 可选的工厂函数,用于为此方法创建特定的上下文。
* @returns An enhanced method with the middlewares applied. / 应用了中间件的增强方法。
*/
export function applyMethodMiddlewares<
TArgs extends unknown[] = unknown[], // Original method's arguments array type / 原始方法的参数数组类型
TResult = unknown,
TContext extends BaseContext = BaseContext
>(
methodName: string,
originalMethod: (...args: TArgs) => Promise<TResult>,
middlewares: MethodMiddleware[], // Expects generic middlewares / 期望通用中间件
specificContextFactory?: (base: BaseContext, callArgs: TArgs) => TContext
): (...args: TArgs) => Promise<TResult> {
// Returns a function matching the original method signature. /
// 返回一个与原始方法签名匹配的函数。
return async function enhancedMethod(...methodCallArgs: TArgs): Promise<TResult> {
const ctx = createInitialCallContext<TContext, TArgs>(
methodName,
methodCallArgs, // Pass the actual call arguments array / 传递实际的调用参数数组
specificContextFactory
)
const api: MiddlewareAPI<TContext, TArgs> = {
getContext: () => ctx,
getOriginalArgs: () => methodCallArgs // API provides the original arguments array / API提供原始参数数组
}
// `finalDispatch` is the function that will ultimately call the original provider method. /
// `finalDispatch` 是最终将调用原始提供者方法的函数。
// It receives the current context and arguments, which may have been transformed by middlewares. /
// 它接收当前的上下文和参数,这些参数可能已被中间件转换。
const finalDispatch = async (
_: TContext,
currentArgs: TArgs // Generic final dispatch expects args array / 通用finalDispatch期望参数数组
): Promise<TResult> => {
return originalMethod.apply(currentArgs)
}
const chain = middlewares.map((middleware) => middleware(api)) // Cast API if TContext/TArgs mismatch general ProviderMethodMiddleware / 如果TContext/TArgs与通用的ProviderMethodMiddleware不匹配,则转换API
const composedMiddlewareLogic = compose(...chain)
const enhancedDispatch = composedMiddlewareLogic(finalDispatch)
return enhancedDispatch(ctx, methodCallArgs) // Pass context and original args array / 传递上下文和原始参数数组
}
}
/**
* Applies an array of `CompletionsMiddleware` to the `completions` method. /
* 将一组 `CompletionsMiddleware` 应用于 `completions` 方法。
* This version adapts for `CompletionsMiddleware` expecting a single `params` object. /
* 此版本适配了期望单个 `params` 对象的 `CompletionsMiddleware`。
* @param originalProviderInstance - The original provider instance. / 原始提供者实例。
* @param originalCompletionsMethod - The original SDK `createCompletions` method. / 原始的 SDK `createCompletions` 方法。
* @param middlewares - An array of `CompletionsMiddleware` to apply. / 要应用的 `CompletionsMiddleware` 数组。
* @returns An enhanced `completions` method with the middlewares applied. / 应用了中间件的增强版 `completions` 方法。
*/
export function applyCompletionsMiddlewares<
TSdkInstance extends SdkInstance = SdkInstance,
TSdkParams extends SdkParams = SdkParams,
TRawOutput extends SdkRawOutput = SdkRawOutput,
TRawChunk extends SdkRawChunk = SdkRawChunk,
TMessageParam extends SdkMessageParam = SdkMessageParam,
TToolCall extends SdkToolCall = SdkToolCall,
TSdkSpecificTool extends SdkTool = SdkTool
>(
originalApiClientInstance: BaseApiClient<
TSdkInstance,
TSdkParams,
TRawOutput,
TRawChunk,
TMessageParam,
TToolCall,
TSdkSpecificTool
>,
originalCompletionsMethod: (payload: TSdkParams, options?: RequestOptions) => Promise<TRawOutput>,
middlewares: CompletionsMiddleware<
TSdkParams,
TMessageParam,
TToolCall,
TSdkInstance,
TRawOutput,
TRawChunk,
TSdkSpecificTool
>[]
): (params: CompletionsParams, options?: RequestOptions) => Promise<CompletionsResult> {
// Returns a function matching the original method signature. /
// 返回一个与原始方法签名匹配的函数。
const methodName = 'completions'
// Factory to create AiProviderMiddlewareCompletionsContext. /
// 用于创建 AiProviderMiddlewareCompletionsContext 的工厂函数。
const completionsContextFactory = (
base: BaseContext,
callArgs: [CompletionsParams]
): CompletionsContext<
TSdkParams,
TMessageParam,
TToolCall,
TSdkInstance,
TRawOutput,
TRawChunk,
TSdkSpecificTool
> => {
return {
...base,
methodName,
apiClientInstance: originalApiClientInstance,
originalArgs: callArgs,
_internal: {
toolProcessingState: {
recursionDepth: 0,
isRecursiveCall: false
},
observer: {}
}
}
}
return async function enhancedCompletionsMethod(
params: CompletionsParams,
options?: RequestOptions
): Promise<CompletionsResult> {
// `originalCallArgs` for context creation is `[params]`. /
// 用于上下文创建的 `originalCallArgs` 是 `[params]`。
const originalCallArgs: [CompletionsParams] = [params]
const baseContext: BaseContext = {
[MIDDLEWARE_CONTEXT_SYMBOL]: true,
methodName,
originalArgs: originalCallArgs
}
const ctx = completionsContextFactory(baseContext, originalCallArgs)
const api: MiddlewareAPI<
CompletionsContext<TSdkParams, TMessageParam, TToolCall, TSdkInstance, TRawOutput, TRawChunk, TSdkSpecificTool>,
[CompletionsParams]
> = {
getContext: () => ctx,
getOriginalArgs: () => originalCallArgs // API provides [CompletionsParams] / API提供 `[CompletionsParams]`
}
// `finalDispatch` for CompletionsMiddleware: expects (context, params) not (context, args_array). /
// `CompletionsMiddleware` 的 `finalDispatch`:期望 (context, params) 而不是 (context, args_array)。
const finalDispatch = async (
context: CompletionsContext<
TSdkParams,
TMessageParam,
TToolCall,
TSdkInstance,
TRawOutput,
TRawChunk,
TSdkSpecificTool
> // Context passed through / 上下文透传
// _currentParams: CompletionsParams // Directly takes params / 直接接收参数 (unused but required for middleware signature)
): Promise<CompletionsResult> => {
// At this point, middleware should have transformed CompletionsParams to SDK params
// and stored them in context. If no transformation happened, we need to handle it.
// 此时,中间件应该已经将 CompletionsParams 转换为 SDK 参数并存储在上下文中。
// 如果没有进行转换,我们需要处理它。
const sdkPayload = context._internal?.sdkPayload
if (!sdkPayload) {
throw new Error('SDK payload not found in context. Middleware chain should have transformed parameters.')
}
const abortSignal = context._internal.flowControl?.abortSignal
const timeout = context._internal.customState?.sdkMetadata?.timeout
// Call the original SDK method with transformed parameters
// 使用转换后的参数调用原始 SDK 方法
const rawOutput = await originalCompletionsMethod.call(originalApiClientInstance, sdkPayload, {
...options,
signal: abortSignal,
timeout
})
// Return result wrapped in CompletionsResult format
// 以 CompletionsResult 格式返回包装的结果
return {
rawOutput
} as CompletionsResult
}
const chain = middlewares.map((middleware) => middleware(api))
const composedMiddlewareLogic = compose(...chain)
// `enhancedDispatch` has the signature `(context, params) => Promise<CompletionsResult>`. /
// `enhancedDispatch` 的签名为 `(context, params) => Promise<CompletionsResult>`。
const enhancedDispatch = composedMiddlewareLogic(finalDispatch)
// 将 enhancedDispatch 保存到 context 中,供中间件进行递归调用
// 这样可以避免重复执行整个中间件链
ctx._internal.enhancedDispatch = enhancedDispatch
// Execute with context and the single params object. /
// 使用上下文和单个参数对象执行。
return enhancedDispatch(ctx, params)
}
}
@@ -0,0 +1,306 @@
import Logger from '@renderer/config/logger'
import { MCPTool, MCPToolResponse, Model, ToolCallResponse } from '@renderer/types'
import { ChunkType, MCPToolCreatedChunk } from '@renderer/types/chunk'
import { SdkMessageParam, SdkRawOutput, SdkToolCall } from '@renderer/types/sdk'
import { parseAndCallTools } from '@renderer/utils/mcp-tools'
import { CompletionsParams, CompletionsResult, GenericChunk } from '../schemas'
import { CompletionsContext, CompletionsMiddleware } from '../types'
export const MIDDLEWARE_NAME = 'McpToolChunkMiddleware'
const MAX_TOOL_RECURSION_DEPTH = 20 // 防止无限递归
/**
* MCP工具处理中间件
*
* 职责:
* 1. 检测并拦截MCP工具进展chunkFunction Call方式和Tool Use方式)
* 2. 执行工具调用
* 3. 递归处理工具结果
* 4. 管理工具调用状态和递归深度
*/
export const McpToolChunkMiddleware: CompletionsMiddleware =
() =>
(next) =>
async (ctx: CompletionsContext, params: CompletionsParams): Promise<CompletionsResult> => {
const mcpTools = params.mcpTools || []
// 如果没有工具,直接调用下一个中间件
if (!mcpTools || mcpTools.length === 0) {
return next(ctx, params)
}
const executeWithToolHandling = async (currentParams: CompletionsParams, depth = 0): Promise<CompletionsResult> => {
if (depth >= MAX_TOOL_RECURSION_DEPTH) {
Logger.error(`🔧 [${MIDDLEWARE_NAME}] Maximum recursion depth ${MAX_TOOL_RECURSION_DEPTH} exceeded`)
throw new Error(`Maximum tool recursion depth ${MAX_TOOL_RECURSION_DEPTH} exceeded`)
}
let result: CompletionsResult
if (depth === 0) {
result = await next(ctx, currentParams)
} else {
const enhancedCompletions = ctx._internal.enhancedDispatch
if (!enhancedCompletions) {
Logger.error(`🔧 [${MIDDLEWARE_NAME}] Enhanced completions method not found, cannot perform recursive call`)
throw new Error('Enhanced completions method not found')
}
ctx._internal.toolProcessingState!.isRecursiveCall = true
ctx._internal.toolProcessingState!.recursionDepth = depth
result = await enhancedCompletions(ctx, currentParams)
}
if (!result.stream) {
Logger.error(`🔧 [${MIDDLEWARE_NAME}] No stream returned from enhanced completions`)
throw new Error('No stream returned from enhanced completions')
}
const resultFromUpstream = result.stream as ReadableStream<GenericChunk>
const toolHandlingStream = resultFromUpstream.pipeThrough(
createToolHandlingTransform(ctx, currentParams, mcpTools, depth, executeWithToolHandling)
)
return {
...result,
stream: toolHandlingStream
}
}
return executeWithToolHandling(params, 0)
}
/**
* 创建工具处理的 TransformStream
*/
function createToolHandlingTransform(
ctx: CompletionsContext,
currentParams: CompletionsParams,
mcpTools: MCPTool[],
depth: number,
executeWithToolHandling: (params: CompletionsParams, depth: number) => Promise<CompletionsResult>
): TransformStream<GenericChunk, GenericChunk> {
const toolCalls: SdkToolCall[] = []
const toolUseResponses: MCPToolResponse[] = []
const allToolResponses: MCPToolResponse[] = [] // 统一的工具响应状态管理数组
let hasToolCalls = false
let hasToolUseResponses = false
let streamEnded = false
return new TransformStream({
async transform(chunk: GenericChunk, controller) {
try {
// 处理MCP工具进展chunk
if (chunk.type === ChunkType.MCP_TOOL_CREATED) {
const createdChunk = chunk as MCPToolCreatedChunk
// 1. 处理Function Call方式的工具调用
if (createdChunk.tool_calls && createdChunk.tool_calls.length > 0) {
toolCalls.push(...createdChunk.tool_calls)
hasToolCalls = true
}
// 2. 处理Tool Use方式的工具调用
if (createdChunk.tool_use_responses && createdChunk.tool_use_responses.length > 0) {
toolUseResponses.push(...createdChunk.tool_use_responses)
hasToolUseResponses = true
}
// 不转发MCP工具进展chunks,避免重复处理
return
}
// 转发其他所有chunk
controller.enqueue(chunk)
} catch (error) {
console.error(`🔧 [${MIDDLEWARE_NAME}] Error processing chunk:`, error)
controller.error(error)
}
},
async flush(controller) {
const shouldExecuteToolCalls = hasToolCalls && toolCalls.length > 0
const shouldExecuteToolUseResponses = hasToolUseResponses && toolUseResponses.length > 0
if (!streamEnded && (shouldExecuteToolCalls || shouldExecuteToolUseResponses)) {
streamEnded = true
try {
let toolResult: SdkMessageParam[] = []
if (shouldExecuteToolCalls) {
toolResult = await executeToolCalls(
ctx,
toolCalls,
mcpTools,
allToolResponses,
currentParams.onChunk,
currentParams.assistant.model!
)
} else if (shouldExecuteToolUseResponses) {
toolResult = await executeToolUseResponses(
ctx,
toolUseResponses,
mcpTools,
allToolResponses,
currentParams.onChunk,
currentParams.assistant.model!
)
}
if (toolResult.length > 0) {
const output = ctx._internal.toolProcessingState?.output
const newParams = buildParamsWithToolResults(ctx, currentParams, output!, toolResult, toolCalls)
await executeWithToolHandling(newParams, depth + 1)
}
} catch (error) {
console.error(`🔧 [${MIDDLEWARE_NAME}] Error in tool processing:`, error)
controller.error(error)
} finally {
hasToolCalls = false
hasToolUseResponses = false
}
}
}
})
}
/**
* 执行工具调用(Function Call 方式)
*/
async function executeToolCalls(
ctx: CompletionsContext,
toolCalls: SdkToolCall[],
mcpTools: MCPTool[],
allToolResponses: MCPToolResponse[],
onChunk: CompletionsParams['onChunk'],
model: Model
): Promise<SdkMessageParam[]> {
// 转换为MCPToolResponse格式
const mcpToolResponses: ToolCallResponse[] = toolCalls
.map((toolCall) => {
const mcpTool = ctx.apiClientInstance.convertSdkToolCallToMcp(toolCall, mcpTools)
if (!mcpTool) {
return undefined
}
return ctx.apiClientInstance.convertSdkToolCallToMcpToolResponse(toolCall, mcpTool)
})
.filter((t): t is ToolCallResponse => typeof t !== 'undefined')
if (mcpToolResponses.length === 0) {
console.warn(`🔧 [${MIDDLEWARE_NAME}] No valid MCP tool responses to execute`)
return []
}
// 使用现有的parseAndCallTools函数执行工具
const toolResults = await parseAndCallTools(
mcpToolResponses,
allToolResponses,
onChunk,
(mcpToolResponse, resp, model) => {
return ctx.apiClientInstance.convertMcpToolResponseToSdkMessageParam(mcpToolResponse, resp, model)
},
model,
mcpTools
)
return toolResults
}
/**
* 执行工具使用响应(Tool Use Response 方式)
* 处理已经解析好的 ToolUseResponse[],不需要重新解析字符串
*/
async function executeToolUseResponses(
ctx: CompletionsContext,
toolUseResponses: MCPToolResponse[],
mcpTools: MCPTool[],
allToolResponses: MCPToolResponse[],
onChunk: CompletionsParams['onChunk'],
model: Model
): Promise<SdkMessageParam[]> {
// 直接使用parseAndCallTools函数处理已经解析好的ToolUseResponse
const toolResults = await parseAndCallTools(
toolUseResponses,
allToolResponses,
onChunk,
(mcpToolResponse, resp, model) => {
return ctx.apiClientInstance.convertMcpToolResponseToSdkMessageParam(mcpToolResponse, resp, model)
},
model,
mcpTools
)
return toolResults
}
/**
* 构建包含工具结果的新参数
*/
function buildParamsWithToolResults(
ctx: CompletionsContext,
currentParams: CompletionsParams,
output: SdkRawOutput | string,
toolResults: SdkMessageParam[],
toolCalls: SdkToolCall[]
): CompletionsParams {
// 获取当前已经转换好的reqMessages,如果没有则使用原始messages
const currentReqMessages = getCurrentReqMessages(ctx)
const apiClient = ctx.apiClientInstance
// 从回复中构建助手消息
const newReqMessages = apiClient.buildSdkMessages(currentReqMessages, output, toolResults, toolCalls)
// 估算新增消息的 token 消耗并累加到 usage 中
if (ctx._internal.observer?.usage && newReqMessages.length > currentReqMessages.length) {
try {
const newMessages = newReqMessages.slice(currentReqMessages.length)
const additionalTokens = newMessages.reduce((acc, message) => {
return acc + ctx.apiClientInstance.estimateMessageTokens(message)
}, 0)
if (additionalTokens > 0) {
ctx._internal.observer.usage.prompt_tokens += additionalTokens
ctx._internal.observer.usage.total_tokens += additionalTokens
}
} catch (error) {
Logger.error(`🔧 [${MIDDLEWARE_NAME}] Error estimating token usage for new messages:`, error)
}
}
// 更新递归状态
if (!ctx._internal.toolProcessingState) {
ctx._internal.toolProcessingState = {}
}
ctx._internal.toolProcessingState.isRecursiveCall = true
ctx._internal.toolProcessingState.recursionDepth = (ctx._internal.toolProcessingState?.recursionDepth || 0) + 1
return {
...currentParams,
_internal: {
...ctx._internal,
sdkPayload: ctx._internal.sdkPayload,
newReqMessages: newReqMessages
}
}
}
/**
* 类型安全地获取当前请求消息
* 使用API客户端提供的抽象方法,保持中间件的provider无关性
*/
function getCurrentReqMessages(ctx: CompletionsContext): SdkMessageParam[] {
const sdkPayload = ctx._internal.sdkPayload
if (!sdkPayload) {
return []
}
// 使用API客户端的抽象方法来提取消息,保持provider无关性
return ctx.apiClientInstance.extractMessagesFromSdkPayload(sdkPayload)
}
export default McpToolChunkMiddleware
@@ -0,0 +1,48 @@
import { AnthropicAPIClient } from '@renderer/aiCore/clients/anthropic/AnthropicAPIClient'
import { AnthropicSdkRawChunk, AnthropicSdkRawOutput } from '@renderer/types/sdk'
import { AnthropicStreamListener } from '../../clients/types'
import { CompletionsParams, CompletionsResult } from '../schemas'
import { CompletionsContext, CompletionsMiddleware } from '../types'
export const MIDDLEWARE_NAME = 'RawStreamListenerMiddleware'
export const RawStreamListenerMiddleware: CompletionsMiddleware =
() =>
(next) =>
async (ctx: CompletionsContext, params: CompletionsParams): Promise<CompletionsResult> => {
const result = await next(ctx, params)
// 在这里可以监听到从SDK返回的最原始流
if (result.rawOutput) {
console.log(`[${MIDDLEWARE_NAME}] 检测到原始SDK输出,准备附加监听器`)
const providerType = ctx.apiClientInstance.provider.type
// TODO: 后面下放到AnthropicAPIClient
if (providerType === 'anthropic') {
const anthropicListener: AnthropicStreamListener<AnthropicSdkRawChunk> = {
onMessage: (message) => {
if (ctx._internal?.toolProcessingState) {
ctx._internal.toolProcessingState.output = message
}
}
// onContentBlock: (contentBlock) => {
// console.log(`[${MIDDLEWARE_NAME}] 📝 Anthropic content block:`, contentBlock.type)
// }
}
const specificApiClient = ctx.apiClientInstance as AnthropicAPIClient
const monitoredOutput = specificApiClient.attachRawStreamListener(
result.rawOutput as AnthropicSdkRawOutput,
anthropicListener
)
return {
...result,
rawOutput: monitoredOutput
}
}
}
return result
}
@@ -0,0 +1,85 @@
import Logger from '@renderer/config/logger'
import { SdkRawChunk } from '@renderer/types/sdk'
import { ResponseChunkTransformerContext } from '../../clients/types'
import { CompletionsParams, CompletionsResult, GenericChunk } from '../schemas'
import { CompletionsContext, CompletionsMiddleware } from '../types'
export const MIDDLEWARE_NAME = 'ResponseTransformMiddleware'
/**
* 响应转换中间件
*
* 职责:
* 1. 检测ReadableStream类型的响应流
* 2. 使用ApiClient的getResponseChunkTransformer()将原始SDK响应块转换为通用格式
* 3. 将转换后的ReadableStream保存到ctx._internal.apiCall.genericChunkStream,供下游中间件使用
*
* 注意:此中间件应该在StreamAdapterMiddleware之后执行
*/
export const ResponseTransformMiddleware: CompletionsMiddleware =
() =>
(next) =>
async (ctx: CompletionsContext, params: CompletionsParams): Promise<CompletionsResult> => {
// 调用下游中间件
const result = await next(ctx, params)
// 响应后处理:转换原始SDK响应块
if (result.stream) {
const adaptedStream = result.stream
// 处理ReadableStream类型的流
if (adaptedStream instanceof ReadableStream) {
const apiClient = ctx.apiClientInstance
if (!apiClient) {
console.error(`[${MIDDLEWARE_NAME}] ApiClient instance not found in context`)
throw new Error('ApiClient instance not found in context')
}
// 获取响应转换器
const responseChunkTransformer = apiClient.getResponseChunkTransformer?.()
if (!responseChunkTransformer) {
Logger.warn(`[${MIDDLEWARE_NAME}] No ResponseChunkTransformer available, skipping transformation`)
return result
}
const assistant = params.assistant
const model = assistant?.model
if (!assistant || !model) {
console.error(`[${MIDDLEWARE_NAME}] Assistant or Model not found for transformation`)
throw new Error('Assistant or Model not found for transformation')
}
const transformerContext: ResponseChunkTransformerContext = {
isStreaming: params.streamOutput || false,
isEnabledToolCalling: (params.mcpTools && params.mcpTools.length > 0) || false,
isEnabledWebSearch: params.enableWebSearch || false,
isEnabledReasoning: params.enableReasoning || false,
mcpTools: params.mcpTools || [],
provider: ctx.apiClientInstance?.provider
}
console.log(`[${MIDDLEWARE_NAME}] Transforming raw SDK chunks with context:`, transformerContext)
try {
// 创建转换后的流
const genericChunkTransformStream = (adaptedStream as ReadableStream<SdkRawChunk>).pipeThrough<GenericChunk>(
new TransformStream<SdkRawChunk, GenericChunk>(responseChunkTransformer(transformerContext))
)
// 将转换后的ReadableStream保存到result,供下游中间件使用
return {
...result,
stream: genericChunkTransformStream
}
} catch (error) {
Logger.error(`[${MIDDLEWARE_NAME}] Error during chunk transformation:`, error)
throw error
}
}
}
// 如果没有流或不是ReadableStream,返回原始结果
return result
}
@@ -0,0 +1,57 @@
import { SdkRawChunk } from '@renderer/types/sdk'
import { asyncGeneratorToReadableStream, createSingleChunkReadableStream } from '@renderer/utils/stream'
import { CompletionsParams, CompletionsResult } from '../schemas'
import { CompletionsContext, CompletionsMiddleware } from '../types'
import { isAsyncIterable } from '../utils'
export const MIDDLEWARE_NAME = 'StreamAdapterMiddleware'
/**
* 流适配器中间件
*
* 职责:
* 1. 检测ctx._internal.apiCall.rawSdkOutput(优先)或原始AsyncIterable流
* 2. 将AsyncIterable转换为WHATWG ReadableStream
* 3. 更新响应结果中的stream
*
* 注意:如果ResponseTransformMiddleware已处理过,会优先使用transformedStream
*/
export const StreamAdapterMiddleware: CompletionsMiddleware =
() =>
(next) =>
async (ctx: CompletionsContext, params: CompletionsParams): Promise<CompletionsResult> => {
// TODO:调用开始,因为这个是最靠近接口请求的地方,next执行代表着开始接口请求了
// 但是这个中间件的职责是流适配,是否在这调用优待商榷
// 调用下游中间件
const result = await next(ctx, params)
if (
result.rawOutput &&
!(result.rawOutput instanceof ReadableStream) &&
isAsyncIterable<SdkRawChunk>(result.rawOutput)
) {
const whatwgReadableStream: ReadableStream<SdkRawChunk> = asyncGeneratorToReadableStream<SdkRawChunk>(
result.rawOutput
)
return {
...result,
stream: whatwgReadableStream
}
} else if (result.rawOutput && result.rawOutput instanceof ReadableStream) {
return {
...result,
stream: result.rawOutput
}
} else if (result.rawOutput) {
// 非流式输出,强行变为可读流
const whatwgReadableStream: ReadableStream<SdkRawChunk> = createSingleChunkReadableStream<SdkRawChunk>(
result.rawOutput
)
return {
...result,
stream: whatwgReadableStream
}
}
return result
}
@@ -0,0 +1,99 @@
import Logger from '@renderer/config/logger'
import { ChunkType, TextDeltaChunk } from '@renderer/types/chunk'
import { CompletionsParams, CompletionsResult, GenericChunk } from '../schemas'
import { CompletionsContext, CompletionsMiddleware } from '../types'
export const MIDDLEWARE_NAME = 'TextChunkMiddleware'
/**
* 文本块处理中间件
*
* 职责:
* 1. 累积文本内容(TEXT_DELTA
* 2. 对文本内容进行智能链接转换
* 3. 生成TEXT_COMPLETE事件
* 4. 暂存Web搜索结果,用于最终链接完善
* 5. 处理 onResponse 回调,实时发送文本更新和最终完整文本
*/
export const TextChunkMiddleware: CompletionsMiddleware =
() =>
(next) =>
async (ctx: CompletionsContext, params: CompletionsParams): Promise<CompletionsResult> => {
// 调用下游中间件
const result = await next(ctx, params)
// 响应后处理:转换流式响应中的文本内容
if (result.stream) {
const resultFromUpstream = result.stream as ReadableStream<GenericChunk>
if (resultFromUpstream && resultFromUpstream instanceof ReadableStream) {
const assistant = params.assistant
const model = params.assistant?.model
if (!assistant || !model) {
Logger.warn(`[${MIDDLEWARE_NAME}] Missing assistant or model information, skipping text processing`)
return result
}
// 用于跨chunk的状态管理
let accumulatedTextContent = ''
let hasEnqueue = false
const enhancedTextStream = resultFromUpstream.pipeThrough(
new TransformStream<GenericChunk, GenericChunk>({
transform(chunk: GenericChunk, controller) {
if (chunk.type === ChunkType.TEXT_DELTA) {
const textChunk = chunk as TextDeltaChunk
accumulatedTextContent += textChunk.text
// 处理 onResponse 回调 - 发送增量文本更新
if (params.onResponse) {
params.onResponse(accumulatedTextContent, false)
}
// 创建新的chunk,包含处理后的文本
controller.enqueue(chunk)
} else if (accumulatedTextContent) {
if (chunk.type !== ChunkType.LLM_RESPONSE_COMPLETE) {
controller.enqueue(chunk)
hasEnqueue = true
}
const finalText = accumulatedTextContent
ctx._internal.customState!.accumulatedText = finalText
if (ctx._internal.toolProcessingState && !ctx._internal.toolProcessingState?.output) {
ctx._internal.toolProcessingState.output = finalText
}
// 处理 onResponse 回调 - 发送最终完整文本
if (params.onResponse) {
params.onResponse(finalText, true)
}
controller.enqueue({
type: ChunkType.TEXT_COMPLETE,
text: finalText
})
accumulatedTextContent = ''
if (!hasEnqueue) {
controller.enqueue(chunk)
}
} else {
// 其他类型的chunk直接传递
controller.enqueue(chunk)
}
}
})
)
// 更新响应结果
return {
...result,
stream: enhancedTextStream
}
} else {
Logger.warn(`[${MIDDLEWARE_NAME}] No stream to process or not a ReadableStream. Returning original result.`)
}
}
return result
}
@@ -0,0 +1,101 @@
import Logger from '@renderer/config/logger'
import { ChunkType, ThinkingCompleteChunk, ThinkingDeltaChunk } from '@renderer/types/chunk'
import { CompletionsParams, CompletionsResult, GenericChunk } from '../schemas'
import { CompletionsContext, CompletionsMiddleware } from '../types'
export const MIDDLEWARE_NAME = 'ThinkChunkMiddleware'
/**
* 处理思考内容的中间件
*
* 注意:从 v2 版本开始,流结束语义的判断已移至 ApiClient 层处理
* 此中间件现在主要负责:
* 1. 处理原始SDK chunk中的reasoning字段
* 2. 计算准确的思考时间
* 3. 在思考内容结束时生成THINKING_COMPLETE事件
*
* 职责:
* 1. 累积思考内容(THINKING_DELTA
* 2. 监听流结束信号,生成THINKING_COMPLETE事件
* 3. 计算准确的思考时间
*
*/
export const ThinkChunkMiddleware: CompletionsMiddleware =
() =>
(next) =>
async (ctx: CompletionsContext, params: CompletionsParams): Promise<CompletionsResult> => {
// 调用下游中间件
const result = await next(ctx, params)
// 响应后处理:处理思考内容
if (result.stream) {
const resultFromUpstream = result.stream as ReadableStream<GenericChunk>
// 检查是否启用reasoning
const enableReasoning = params.enableReasoning || false
if (!enableReasoning) {
return result
}
// 检查是否有流需要处理
if (resultFromUpstream && resultFromUpstream instanceof ReadableStream) {
// thinking 处理状态
let accumulatedThinkingContent = ''
let hasThinkingContent = false
let thinkingStartTime = 0
const processedStream = resultFromUpstream.pipeThrough(
new TransformStream<GenericChunk, GenericChunk>({
transform(chunk: GenericChunk, controller) {
if (chunk.type === ChunkType.THINKING_DELTA) {
const thinkingChunk = chunk as ThinkingDeltaChunk
// 第一次接收到思考内容时记录开始时间
if (!hasThinkingContent) {
hasThinkingContent = true
thinkingStartTime = Date.now()
}
accumulatedThinkingContent += thinkingChunk.text
// 更新思考时间并传递
const enhancedChunk: ThinkingDeltaChunk = {
...thinkingChunk,
thinking_millsec: thinkingStartTime > 0 ? Date.now() - thinkingStartTime : 0
}
controller.enqueue(enhancedChunk)
} else if (hasThinkingContent && thinkingStartTime > 0) {
// 收到任何非THINKING_DELTA的chunk时,如果有累积的思考内容,生成THINKING_COMPLETE
const thinkingCompleteChunk: ThinkingCompleteChunk = {
type: ChunkType.THINKING_COMPLETE,
text: accumulatedThinkingContent,
thinking_millsec: thinkingStartTime > 0 ? Date.now() - thinkingStartTime : 0
}
controller.enqueue(thinkingCompleteChunk)
hasThinkingContent = false
accumulatedThinkingContent = ''
thinkingStartTime = 0
// 继续传递当前chunk
controller.enqueue(chunk)
} else {
// 其他情况直接传递
controller.enqueue(chunk)
}
}
})
)
// 更新响应结果
return {
...result,
stream: processedStream
}
} else {
Logger.warn(`[${MIDDLEWARE_NAME}] No generic chunk stream to process or not a ReadableStream.`)
}
}
return result
}
@@ -0,0 +1,83 @@
import Logger from '@renderer/config/logger'
import { ChunkType } from '@renderer/types/chunk'
import { CompletionsParams, CompletionsResult } from '../schemas'
import { CompletionsContext, CompletionsMiddleware } from '../types'
export const MIDDLEWARE_NAME = 'TransformCoreToSdkParamsMiddleware'
/**
* 中间件:将CoreCompletionsRequest转换为SDK特定的参数
* 使用上下文中ApiClient实例的requestTransformer进行转换
*/
export const TransformCoreToSdkParamsMiddleware: CompletionsMiddleware =
() =>
(next) =>
async (ctx: CompletionsContext, params: CompletionsParams): Promise<CompletionsResult> => {
Logger.debug(`🔄 [${MIDDLEWARE_NAME}] Starting core to SDK params transformation:`, ctx)
const internal = ctx._internal
// 🔧 检测递归调用:检查 params 中是否携带了预处理的 SDK 消息
const isRecursiveCall = internal?.toolProcessingState?.isRecursiveCall || false
const newSdkMessages = params._internal?.newReqMessages
const apiClient = ctx.apiClientInstance
if (!apiClient) {
Logger.error(`🔄 [${MIDDLEWARE_NAME}] ApiClient instance not found in context.`)
throw new Error('ApiClient instance not found in context')
}
// 检查是否有requestTransformer方法
const requestTransformer = apiClient.getRequestTransformer()
if (!requestTransformer) {
Logger.warn(
`🔄 [${MIDDLEWARE_NAME}] ApiClient does not have getRequestTransformer method, skipping transformation`
)
const result = await next(ctx, params)
return result
}
// 确保assistant和model可用,它们是transformer所需的
const assistant = params.assistant
const model = params.assistant.model
if (!assistant || !model) {
console.error(`🔄 [${MIDDLEWARE_NAME}] Assistant or Model not found for transformation.`)
throw new Error('Assistant or Model not found for transformation')
}
try {
const transformResult = await requestTransformer.transform(
params,
assistant,
model,
isRecursiveCall,
newSdkMessages
)
const { payload: sdkPayload, metadata } = transformResult
// 将SDK特定的payload和metadata存储在状态中,供下游中间件使用
ctx._internal.sdkPayload = sdkPayload
if (metadata) {
ctx._internal.customState = {
...ctx._internal.customState,
sdkMetadata: metadata
}
}
if (params.enableGenerateImage) {
params.onChunk?.({
type: ChunkType.IMAGE_CREATED
})
}
return next(ctx, params)
} catch (error) {
Logger.error(`🔄 [${MIDDLEWARE_NAME}] Error during request transformation:`, error)
// 让错误向上传播,或者可以在这里进行特定的错误处理
throw error
}
}
@@ -0,0 +1,76 @@
import { ChunkType } from '@renderer/types/chunk'
import { smartLinkConverter } from '@renderer/utils/linkConverter'
import { CompletionsParams, CompletionsResult, GenericChunk } from '../schemas'
import { CompletionsContext, CompletionsMiddleware } from '../types'
export const MIDDLEWARE_NAME = 'WebSearchMiddleware'
/**
* Web搜索处理中间件 - 基于GenericChunk流处理
*
* 职责:
* 1. 监听和记录Web搜索事件
* 2. 可以在此处添加Web搜索结果的后处理逻辑
* 3. 维护Web搜索相关的状态
*
* 注意:Web搜索结果的识别和生成已在ApiClient的响应转换器中处理
*/
export const WebSearchMiddleware: CompletionsMiddleware =
() =>
(next) =>
async (ctx: CompletionsContext, params: CompletionsParams): Promise<CompletionsResult> => {
ctx._internal.webSearchState = {
results: undefined
}
// 调用下游中间件
const result = await next(ctx, params)
const model = params.assistant?.model!
let isFirstChunk = true
// 响应后处理:记录Web搜索事件
if (result.stream) {
const resultFromUpstream = result.stream
if (resultFromUpstream && resultFromUpstream instanceof ReadableStream) {
// Web搜索状态跟踪
const enhancedStream = (resultFromUpstream as ReadableStream<GenericChunk>).pipeThrough(
new TransformStream<GenericChunk, GenericChunk>({
transform(chunk: GenericChunk, controller) {
if (chunk.type === ChunkType.TEXT_DELTA) {
const providerType = model.provider || 'openai'
// 使用当前可用的Web搜索结果进行链接转换
const text = chunk.text
const processedText = smartLinkConverter(text, providerType, isFirstChunk)
if (isFirstChunk) {
isFirstChunk = false
}
controller.enqueue({
...chunk,
text: processedText
})
} else if (chunk.type === ChunkType.LLM_WEB_SEARCH_COMPLETE) {
// 暂存Web搜索结果用于链接完善
ctx._internal.webSearchState!.results = chunk.llm_web_search
// 将Web搜索完成事件继续传递下去
controller.enqueue(chunk)
} else {
controller.enqueue(chunk)
}
}
})
)
return {
...result,
stream: enhancedStream
}
} else {
console.log(`[${MIDDLEWARE_NAME}] No stream to process or not a ReadableStream.`)
}
}
return result
}
@@ -0,0 +1,132 @@
import { BaseApiClient } from '@renderer/aiCore/clients/BaseApiClient'
import { isDedicatedImageGenerationModel } from '@renderer/config/models'
import { ChunkType } from '@renderer/types/chunk'
import { findImageBlocks, getMainTextContent } from '@renderer/utils/messageUtils/find'
import OpenAI from 'openai'
import { toFile } from 'openai/uploads'
import { CompletionsParams, CompletionsResult, GenericChunk } from '../schemas'
import { CompletionsContext, CompletionsMiddleware } from '../types'
export const MIDDLEWARE_NAME = 'ImageGenerationMiddleware'
export const ImageGenerationMiddleware: CompletionsMiddleware =
() =>
(next) =>
async (context: CompletionsContext, params: CompletionsParams): Promise<CompletionsResult> => {
const { assistant, messages } = params
const client = context.apiClientInstance as BaseApiClient<OpenAI>
const signal = context._internal?.flowControl?.abortSignal
if (!assistant.model || !isDedicatedImageGenerationModel(assistant.model) || typeof messages === 'string') {
return next(context, params)
}
const stream = new ReadableStream<GenericChunk>({
async start(controller) {
const enqueue = (chunk: GenericChunk) => controller.enqueue(chunk)
try {
if (!assistant.model) {
throw new Error('Assistant model is not defined.')
}
const sdk = await client.getSdkInstance()
const lastUserMessage = messages.findLast((m) => m.role === 'user')
const lastAssistantMessage = messages.findLast((m) => m.role === 'assistant')
if (!lastUserMessage) {
throw new Error('No user message found for image generation.')
}
const prompt = getMainTextContent(lastUserMessage)
let imageFiles: Blob[] = []
// Collect images from user message
const userImageBlocks = findImageBlocks(lastUserMessage)
const userImages = await Promise.all(
userImageBlocks.map(async (block) => {
if (!block.file) return null
const binaryData: Uint8Array = await window.api.file.binaryImage(block.file.id)
const mimeType = `${block.file.type}/${block.file.ext.slice(1)}`
return await toFile(new Blob([binaryData]), block.file.origin_name || 'image.png', { type: mimeType })
})
)
imageFiles = imageFiles.concat(userImages.filter(Boolean) as Blob[])
// Collect images from last assistant message
if (lastAssistantMessage) {
const assistantImageBlocks = findImageBlocks(lastAssistantMessage)
const assistantImages = await Promise.all(
assistantImageBlocks.map(async (block) => {
const b64 = block.url?.replace(/^data:image\/\w+;base64,/, '')
if (!b64) return null
const binary = atob(b64)
const bytes = new Uint8Array(binary.length)
for (let i = 0; i < binary.length; i++) bytes[i] = binary.charCodeAt(i)
return await toFile(new Blob([bytes]), 'assistant_image.png', { type: 'image/png' })
})
)
imageFiles = imageFiles.concat(assistantImages.filter(Boolean) as Blob[])
}
enqueue({ type: ChunkType.IMAGE_CREATED })
const startTime = Date.now()
let response: OpenAI.Images.ImagesResponse
const options = { signal, timeout: 300_000 }
if (imageFiles.length > 0) {
response = await sdk.images.edit(
{
model: assistant.model.id,
image: imageFiles,
prompt: prompt || ''
},
options
)
} else {
response = await sdk.images.generate(
{
model: assistant.model.id,
prompt: prompt || '',
response_format: assistant.model.id.includes('gpt-image-1') ? undefined : 'b64_json'
},
options
)
}
const b64_json_array = response.data?.map((item) => `data:image/png;base64,${item.b64_json}`) || []
enqueue({
type: ChunkType.IMAGE_COMPLETE,
image: { type: 'base64', images: b64_json_array }
})
const usage = (response as any).usage || { prompt_tokens: 0, completion_tokens: 0, total_tokens: 0 }
enqueue({
type: ChunkType.LLM_RESPONSE_COMPLETE,
response: {
usage,
metrics: {
completion_tokens: usage.completion_tokens,
time_first_token_millsec: 0,
time_completion_millsec: Date.now() - startTime
}
}
})
} catch (error: any) {
enqueue({ type: ChunkType.ERROR, error })
} finally {
controller.close()
}
}
})
return {
stream,
getText: () => ''
}
}
@@ -0,0 +1,136 @@
import { Model } from '@renderer/types'
import { ChunkType, TextDeltaChunk, ThinkingCompleteChunk, ThinkingDeltaChunk } from '@renderer/types/chunk'
import { TagConfig, TagExtractor } from '@renderer/utils/tagExtraction'
import Logger from 'electron-log/renderer'
import { CompletionsParams, CompletionsResult, GenericChunk } from '../schemas'
import { CompletionsContext, CompletionsMiddleware } from '../types'
export const MIDDLEWARE_NAME = 'ThinkingTagExtractionMiddleware'
// 不同模型的思考标签配置
const reasoningTags: TagConfig[] = [
{ openingTag: '<think>', closingTag: '</think>', separator: '\n' },
{ openingTag: '###Thinking', closingTag: '###Response', separator: '\n' }
]
const getAppropriateTag = (model?: Model): TagConfig => {
if (model?.id?.includes('qwen3')) return reasoningTags[0]
// 可以在这里添加更多模型特定的标签配置
return reasoningTags[0] // 默认使用 <think> 标签
}
/**
* 处理文本流中思考标签提取的中间件
*
* 该中间件专门处理文本流中的思考标签内容(如 <think>...</think>
* 主要用于 OpenAI 等支持思考标签的 provider
*
* 职责:
* 1. 从文本流中提取思考标签内容
* 2. 将标签内的内容转换为 THINKING_DELTA chunk
* 3. 将标签外的内容作为正常文本输出
* 4. 处理不同模型的思考标签格式
* 5. 在思考内容结束时生成 THINKING_COMPLETE 事件
*/
export const ThinkingTagExtractionMiddleware: CompletionsMiddleware =
() =>
(next) =>
async (context: CompletionsContext, params: CompletionsParams): Promise<CompletionsResult> => {
// 调用下游中间件
const result = await next(context, params)
// 响应后处理:处理思考标签提取
if (result.stream) {
const resultFromUpstream = result.stream as ReadableStream<GenericChunk>
// 检查是否有流需要处理
if (resultFromUpstream && resultFromUpstream instanceof ReadableStream) {
// 获取当前模型的思考标签配置
const model = params.assistant?.model
const reasoningTag = getAppropriateTag(model)
// 创建标签提取器
const tagExtractor = new TagExtractor(reasoningTag)
// thinking 处理状态
let hasThinkingContent = false
let thinkingStartTime = 0
const processedStream = resultFromUpstream.pipeThrough(
new TransformStream<GenericChunk, GenericChunk>({
transform(chunk: GenericChunk, controller) {
if (chunk.type === ChunkType.TEXT_DELTA) {
const textChunk = chunk as TextDeltaChunk
// 使用 TagExtractor 处理文本
const extractionResults = tagExtractor.processText(textChunk.text)
for (const extractionResult of extractionResults) {
if (extractionResult.complete && extractionResult.tagContentExtracted) {
// 生成 THINKING_COMPLETE 事件
const thinkingCompleteChunk: ThinkingCompleteChunk = {
type: ChunkType.THINKING_COMPLETE,
text: extractionResult.tagContentExtracted,
thinking_millsec: thinkingStartTime > 0 ? Date.now() - thinkingStartTime : 0
}
controller.enqueue(thinkingCompleteChunk)
// 重置思考状态
hasThinkingContent = false
thinkingStartTime = 0
} else if (extractionResult.content.length > 0) {
if (extractionResult.isTagContent) {
// 第一次接收到思考内容时记录开始时间
if (!hasThinkingContent) {
hasThinkingContent = true
thinkingStartTime = Date.now()
}
const thinkingDeltaChunk: ThinkingDeltaChunk = {
type: ChunkType.THINKING_DELTA,
text: extractionResult.content,
thinking_millsec: thinkingStartTime > 0 ? Date.now() - thinkingStartTime : 0
}
controller.enqueue(thinkingDeltaChunk)
} else {
// 发送清理后的文本内容
const cleanTextChunk: TextDeltaChunk = {
...textChunk,
text: extractionResult.content
}
controller.enqueue(cleanTextChunk)
}
}
}
} else {
// 其他类型的chunk直接传递(包括 THINKING_DELTA, THINKING_COMPLETE 等)
controller.enqueue(chunk)
}
},
flush(controller) {
// 处理可能剩余的思考内容
const finalResult = tagExtractor.finalize()
if (finalResult?.tagContentExtracted) {
const thinkingCompleteChunk: ThinkingCompleteChunk = {
type: ChunkType.THINKING_COMPLETE,
text: finalResult.tagContentExtracted,
thinking_millsec: thinkingStartTime > 0 ? Date.now() - thinkingStartTime : 0
}
controller.enqueue(thinkingCompleteChunk)
}
}
})
)
// 更新响应结果
return {
...result,
stream: processedStream
}
} else {
Logger.warn(`[${MIDDLEWARE_NAME}] No generic chunk stream to process or not a ReadableStream.`)
}
}
return result
}
@@ -0,0 +1,124 @@
import { MCPTool } from '@renderer/types'
import { ChunkType, MCPToolCreatedChunk, TextDeltaChunk } from '@renderer/types/chunk'
import { parseToolUse } from '@renderer/utils/mcp-tools'
import { TagConfig, TagExtractor } from '@renderer/utils/tagExtraction'
import { CompletionsParams, CompletionsResult, GenericChunk } from '../schemas'
import { CompletionsContext, CompletionsMiddleware } from '../types'
export const MIDDLEWARE_NAME = 'ToolUseExtractionMiddleware'
// 工具使用标签配置
const TOOL_USE_TAG_CONFIG: TagConfig = {
openingTag: '<tool_use>',
closingTag: '</tool_use>',
separator: '\n'
}
/**
* 工具使用提取中间件
*
* 职责:
* 1. 从文本流中检测并提取 <tool_use></tool_use> 标签
* 2. 解析工具调用信息并转换为 ToolUseResponse 格式
* 3. 生成 MCP_TOOL_CREATED chunk 供 McpToolChunkMiddleware 处理
* 4. 清理文本流,移除工具使用标签但保留正常文本
*
* 注意:此中间件只负责提取和转换,实际工具调用由 McpToolChunkMiddleware 处理
*/
export const ToolUseExtractionMiddleware: CompletionsMiddleware =
() =>
(next) =>
async (ctx: CompletionsContext, params: CompletionsParams): Promise<CompletionsResult> => {
const mcpTools = params.mcpTools || []
// 如果没有工具,直接调用下一个中间件
if (!mcpTools || mcpTools.length === 0) return next(ctx, params)
// 调用下游中间件
const result = await next(ctx, params)
// 响应后处理:处理工具使用标签提取
if (result.stream) {
const resultFromUpstream = result.stream as ReadableStream<GenericChunk>
const processedStream = resultFromUpstream.pipeThrough(createToolUseExtractionTransform(ctx, mcpTools))
return {
...result,
stream: processedStream
}
}
return result
}
/**
* 创建工具使用提取的 TransformStream
*/
function createToolUseExtractionTransform(
_ctx: CompletionsContext,
mcpTools: MCPTool[]
): TransformStream<GenericChunk, GenericChunk> {
const tagExtractor = new TagExtractor(TOOL_USE_TAG_CONFIG)
return new TransformStream({
async transform(chunk: GenericChunk, controller) {
try {
// 处理文本内容,检测工具使用标签
if (chunk.type === ChunkType.TEXT_DELTA) {
const textChunk = chunk as TextDeltaChunk
const extractionResults = tagExtractor.processText(textChunk.text)
for (const result of extractionResults) {
if (result.complete && result.tagContentExtracted) {
// 提取到完整的工具使用内容,解析并转换为 SDK ToolCall 格式
const toolUseResponses = parseToolUse(result.tagContentExtracted, mcpTools)
if (toolUseResponses.length > 0) {
// 生成 MCP_TOOL_CREATED chunk,复用现有的处理流程
const mcpToolCreatedChunk: MCPToolCreatedChunk = {
type: ChunkType.MCP_TOOL_CREATED,
tool_use_responses: toolUseResponses
}
controller.enqueue(mcpToolCreatedChunk)
}
} else if (!result.isTagContent && result.content) {
// 发送标签外的正常文本内容
const cleanTextChunk: TextDeltaChunk = {
...textChunk,
text: result.content
}
controller.enqueue(cleanTextChunk)
}
// 注意:标签内的内容不会作为TEXT_DELTA转发,避免重复显示
}
return
}
// 转发其他所有chunk
controller.enqueue(chunk)
} catch (error) {
console.error(`🔧 [${MIDDLEWARE_NAME}] Error processing chunk:`, error)
controller.error(error)
}
},
async flush(controller) {
// 检查是否有未完成的标签内容
const finalResult = tagExtractor.finalize()
if (finalResult && finalResult.tagContentExtracted) {
const toolUseResponses = parseToolUse(finalResult.tagContentExtracted, mcpTools)
if (toolUseResponses.length > 0) {
const mcpToolCreatedChunk: MCPToolCreatedChunk = {
type: ChunkType.MCP_TOOL_CREATED,
tool_use_responses: toolUseResponses
}
controller.enqueue(mcpToolCreatedChunk)
}
}
}
})
}
export default ToolUseExtractionMiddleware
@@ -0,0 +1,88 @@
import { CompletionsMiddleware, MethodMiddleware } from './types'
// /**
// * Wraps a provider instance with middlewares.
// */
// export function wrapProviderWithMiddleware(
// apiClientInstance: BaseApiClient,
// middlewareConfig: MiddlewareConfig
// ): BaseApiClient {
// console.log(`[wrapProviderWithMiddleware] Wrapping provider: ${apiClientInstance.provider?.id}`)
// console.log(`[wrapProviderWithMiddleware] Middleware config:`, {
// completions: middlewareConfig.completions?.length || 0,
// methods: Object.keys(middlewareConfig.methods || {}).length
// })
// // Cache for already wrapped methods to avoid re-wrapping on every access.
// const wrappedMethodsCache = new Map<string, (...args: any[]) => Promise<any>>()
// const proxy = new Proxy(apiClientInstance, {
// get(target, propKey, receiver) {
// const methodName = typeof propKey === 'string' ? propKey : undefined
// if (!methodName) {
// return Reflect.get(target, propKey, receiver)
// }
// if (wrappedMethodsCache.has(methodName)) {
// console.log(`[wrapProviderWithMiddleware] Using cached wrapped method: ${methodName}`)
// return wrappedMethodsCache.get(methodName)
// }
// const originalMethod = Reflect.get(target, propKey, receiver)
// // If the property is not a function, return it directly.
// if (typeof originalMethod !== 'function') {
// return originalMethod
// }
// let wrappedMethod: ((...args: any[]) => Promise<any>) | undefined
// // Handle completions method
// if (methodName === 'completions' && middlewareConfig.completions?.length) {
// console.log(
// `[wrapProviderWithMiddleware] Wrapping completions method with ${middlewareConfig.completions.length} middlewares`
// )
// const completionsOriginalMethod = originalMethod as (params: CompletionsParams) => Promise<any>
// wrappedMethod = applyCompletionsMiddlewares(target, completionsOriginalMethod, middlewareConfig.completions)
// }
// // Handle other methods
// else {
// const methodMiddlewares = middlewareConfig.methods?.[methodName]
// if (methodMiddlewares?.length) {
// console.log(
// `[wrapProviderWithMiddleware] Wrapping method ${methodName} with ${methodMiddlewares.length} middlewares`
// )
// const genericOriginalMethod = originalMethod as (...args: any[]) => Promise<any>
// wrappedMethod = applyMethodMiddlewares(target, methodName, genericOriginalMethod, methodMiddlewares)
// }
// }
// if (wrappedMethod) {
// console.log(`[wrapProviderWithMiddleware] Successfully wrapped method: ${methodName}`)
// wrappedMethodsCache.set(methodName, wrappedMethod)
// return wrappedMethod
// }
// // If no middlewares are configured for this method, return the original method bound to the target. /
// // 如果没有为此方法配置中间件,则返回绑定到目标的原始方法。
// console.log(`[wrapProviderWithMiddleware] No middlewares for method ${methodName}, returning original`)
// return originalMethod.bind(target)
// }
// })
// return proxy as BaseApiClient
// }
// Export types for external use
export type { CompletionsMiddleware, MethodMiddleware }
// Export MiddlewareBuilder related types and classes
export {
CompletionsMiddlewareBuilder,
createCompletionsBuilder,
createMethodBuilder,
MethodMiddlewareBuilder,
MiddlewareBuilder,
type MiddlewareExecutor,
type NamedMiddleware
} from './builder'
@@ -0,0 +1,149 @@
import * as AbortHandlerModule from './common/AbortHandlerMiddleware'
import * as ErrorHandlerModule from './common/ErrorHandlerMiddleware'
import * as FinalChunkConsumerModule from './common/FinalChunkConsumerMiddleware'
import * as LoggingModule from './common/LoggingMiddleware'
import * as McpToolChunkModule from './core/McpToolChunkMiddleware'
import * as RawStreamListenerModule from './core/RawStreamListenerMiddleware'
import * as ResponseTransformModule from './core/ResponseTransformMiddleware'
// import * as SdkCallModule from './core/SdkCallMiddleware'
import * as StreamAdapterModule from './core/StreamAdapterMiddleware'
import * as TextChunkModule from './core/TextChunkMiddleware'
import * as ThinkChunkModule from './core/ThinkChunkMiddleware'
import * as TransformCoreToSdkParamsModule from './core/TransformCoreToSdkParamsMiddleware'
import * as WebSearchModule from './core/WebSearchMiddleware'
import * as ImageGenerationModule from './feat/ImageGenerationMiddleware'
import * as ThinkingTagExtractionModule from './feat/ThinkingTagExtractionMiddleware'
import * as ToolUseExtractionMiddleware from './feat/ToolUseExtractionMiddleware'
/**
* 中间件注册表 - 提供所有可用中间件的集中访问
* 注意:目前中间件文件还未导出 MIDDLEWARE_NAME,会有 linter 错误,这是正常的
*/
export const MiddlewareRegistry = {
[ErrorHandlerModule.MIDDLEWARE_NAME]: {
name: ErrorHandlerModule.MIDDLEWARE_NAME,
middleware: ErrorHandlerModule.ErrorHandlerMiddleware
},
// 通用中间件
[AbortHandlerModule.MIDDLEWARE_NAME]: {
name: AbortHandlerModule.MIDDLEWARE_NAME,
middleware: AbortHandlerModule.AbortHandlerMiddleware
},
[FinalChunkConsumerModule.MIDDLEWARE_NAME]: {
name: FinalChunkConsumerModule.MIDDLEWARE_NAME,
middleware: FinalChunkConsumerModule.default
},
// 核心流程中间件
[TransformCoreToSdkParamsModule.MIDDLEWARE_NAME]: {
name: TransformCoreToSdkParamsModule.MIDDLEWARE_NAME,
middleware: TransformCoreToSdkParamsModule.TransformCoreToSdkParamsMiddleware
},
// [SdkCallModule.MIDDLEWARE_NAME]: {
// name: SdkCallModule.MIDDLEWARE_NAME,
// middleware: SdkCallModule.SdkCallMiddleware
// },
[StreamAdapterModule.MIDDLEWARE_NAME]: {
name: StreamAdapterModule.MIDDLEWARE_NAME,
middleware: StreamAdapterModule.StreamAdapterMiddleware
},
[RawStreamListenerModule.MIDDLEWARE_NAME]: {
name: RawStreamListenerModule.MIDDLEWARE_NAME,
middleware: RawStreamListenerModule.RawStreamListenerMiddleware
},
[ResponseTransformModule.MIDDLEWARE_NAME]: {
name: ResponseTransformModule.MIDDLEWARE_NAME,
middleware: ResponseTransformModule.ResponseTransformMiddleware
},
// 特性处理中间件
[ThinkingTagExtractionModule.MIDDLEWARE_NAME]: {
name: ThinkingTagExtractionModule.MIDDLEWARE_NAME,
middleware: ThinkingTagExtractionModule.ThinkingTagExtractionMiddleware
},
[ToolUseExtractionMiddleware.MIDDLEWARE_NAME]: {
name: ToolUseExtractionMiddleware.MIDDLEWARE_NAME,
middleware: ToolUseExtractionMiddleware.ToolUseExtractionMiddleware
},
[ThinkChunkModule.MIDDLEWARE_NAME]: {
name: ThinkChunkModule.MIDDLEWARE_NAME,
middleware: ThinkChunkModule.ThinkChunkMiddleware
},
[McpToolChunkModule.MIDDLEWARE_NAME]: {
name: McpToolChunkModule.MIDDLEWARE_NAME,
middleware: McpToolChunkModule.McpToolChunkMiddleware
},
[WebSearchModule.MIDDLEWARE_NAME]: {
name: WebSearchModule.MIDDLEWARE_NAME,
middleware: WebSearchModule.WebSearchMiddleware
},
[TextChunkModule.MIDDLEWARE_NAME]: {
name: TextChunkModule.MIDDLEWARE_NAME,
middleware: TextChunkModule.TextChunkMiddleware
},
[ImageGenerationModule.MIDDLEWARE_NAME]: {
name: ImageGenerationModule.MIDDLEWARE_NAME,
middleware: ImageGenerationModule.ImageGenerationMiddleware
}
} as const
/**
* 根据名称获取中间件
* @param name - 中间件名称
* @returns 对应的中间件信息
*/
export function getMiddleware(name: string) {
return MiddlewareRegistry[name]
}
/**
* 获取所有注册的中间件名称
* @returns 中间件名称列表
*/
export function getRegisteredMiddlewareNames(): string[] {
return Object.keys(MiddlewareRegistry)
}
/**
* 默认的 Completions 中间件配置 - NamedMiddleware 格式,用于 MiddlewareBuilder
*/
export const DefaultCompletionsNamedMiddlewares = [
MiddlewareRegistry[FinalChunkConsumerModule.MIDDLEWARE_NAME], // 最终消费者
MiddlewareRegistry[ErrorHandlerModule.MIDDLEWARE_NAME], // 错误处理
MiddlewareRegistry[TransformCoreToSdkParamsModule.MIDDLEWARE_NAME], // 参数转换
MiddlewareRegistry[AbortHandlerModule.MIDDLEWARE_NAME], // 中止处理
MiddlewareRegistry[McpToolChunkModule.MIDDLEWARE_NAME], // 工具处理
MiddlewareRegistry[TextChunkModule.MIDDLEWARE_NAME], // 文本处理
MiddlewareRegistry[WebSearchModule.MIDDLEWARE_NAME], // Web搜索处理
MiddlewareRegistry[ToolUseExtractionMiddleware.MIDDLEWARE_NAME], // 工具使用提取处理
MiddlewareRegistry[ThinkingTagExtractionModule.MIDDLEWARE_NAME], // 思考标签提取处理(特定provider)
MiddlewareRegistry[ThinkChunkModule.MIDDLEWARE_NAME], // 思考处理(通用SDK
MiddlewareRegistry[ResponseTransformModule.MIDDLEWARE_NAME], // 响应转换
MiddlewareRegistry[StreamAdapterModule.MIDDLEWARE_NAME], // 流适配器
MiddlewareRegistry[RawStreamListenerModule.MIDDLEWARE_NAME] // 原始流监听器
]
/**
* 默认的通用方法中间件 - 例如翻译、摘要等
*/
export const DefaultMethodMiddlewares = {
translate: [LoggingModule.createGenericLoggingMiddleware()],
summaries: [LoggingModule.createGenericLoggingMiddleware()]
}
/**
* 导出所有中间件模块,方便外部使用
*/
export {
AbortHandlerModule,
FinalChunkConsumerModule,
LoggingModule,
McpToolChunkModule,
ResponseTransformModule,
StreamAdapterModule,
TextChunkModule,
ThinkChunkModule,
ThinkingTagExtractionModule,
TransformCoreToSdkParamsModule,
WebSearchModule
}
@@ -0,0 +1,77 @@
import { Assistant, MCPTool } from '@renderer/types'
import { Chunk } from '@renderer/types/chunk'
import { Message } from '@renderer/types/newMessage'
import { SdkRawChunk, SdkRawOutput } from '@renderer/types/sdk'
import { ProcessingState } from './types'
// ============================================================================
// Core Request Types - 核心请求结构
// ============================================================================
/**
* 标准化的内部核心请求结构,用于所有AI Provider的统一处理
* 这是应用层参数转换后的标准格式,不包含回调函数和控制逻辑
*/
export interface CompletionsParams {
/**
* 调用的业务场景类型,用于中间件判断是否执行
* 'chat': 主要对话流程
* 'translate': 翻译
* 'summary': 摘要
* 'search': 搜索摘要
* 'generate': 生成
* 'check': API检查
*/
callType?: 'chat' | 'translate' | 'summary' | 'search' | 'generate' | 'check'
// 基础对话数据
messages: Message[] | string // 联合类型方便判断是否为空
assistant: Assistant // 助手为基本单位
// model: Model
onChunk?: (chunk: Chunk) => void
onResponse?: (text: string, isComplete: boolean) => void
// 错误相关
onError?: (error: Error) => void
shouldThrow?: boolean
// 工具相关
mcpTools?: MCPTool[]
// 生成参数
temperature?: number
topP?: number
maxTokens?: number
// 功能开关
streamOutput: boolean
enableWebSearch?: boolean
enableReasoning?: boolean
enableGenerateImage?: boolean
// 上下文控制
contextCount?: number
_internal?: ProcessingState
}
export interface CompletionsResult {
rawOutput?: SdkRawOutput
stream?: ReadableStream<SdkRawChunk> | ReadableStream<Chunk> | AsyncIterable<Chunk>
controller?: AbortController
getText: () => string
}
// ============================================================================
// Generic Chunk Types - 通用数据块结构
// ============================================================================
/**
* 通用数据块类型
* 复用现有的 Chunk 类型,这是所有AI Provider都应该输出的标准化数据块格式
*/
export type GenericChunk = Chunk
+166
View File
@@ -0,0 +1,166 @@
import { MCPToolResponse, Metrics, Usage, WebSearchResponse } from '@renderer/types'
import { Chunk, ErrorChunk } from '@renderer/types/chunk'
import {
SdkInstance,
SdkMessageParam,
SdkParams,
SdkRawChunk,
SdkRawOutput,
SdkTool,
SdkToolCall
} from '@renderer/types/sdk'
import { BaseApiClient } from '../clients'
import { CompletionsParams, CompletionsResult } from './schemas'
/**
* Symbol to uniquely identify middleware context objects.
*/
export const MIDDLEWARE_CONTEXT_SYMBOL = Symbol.for('AiProviderMiddlewareContext')
/**
* Defines the structure for the onChunk callback function.
*/
export type OnChunkFunction = (chunk: Chunk | ErrorChunk) => void
/**
* Base context that carries information about the current method call.
*/
export interface BaseContext {
[MIDDLEWARE_CONTEXT_SYMBOL]: true
methodName: string
originalArgs: Readonly<any[]>
}
/**
* Processing state shared between middlewares.
*/
export interface ProcessingState<
TParams extends SdkParams = SdkParams,
TMessageParam extends SdkMessageParam = SdkMessageParam,
TToolCall extends SdkToolCall = SdkToolCall
> {
sdkPayload?: TParams
newReqMessages?: TMessageParam[]
observer?: {
usage?: Usage
metrics?: Metrics
}
toolProcessingState?: {
pendingToolCalls?: Array<TToolCall>
executingToolCalls?: Array<{
sdkToolCall: TToolCall
mcpToolResponse: MCPToolResponse
}>
output?: SdkRawOutput | string
isRecursiveCall?: boolean
recursionDepth?: number
}
webSearchState?: {
results?: WebSearchResponse
}
flowControl?: {
abortController?: AbortController
abortSignal?: AbortSignal
cleanup?: () => void
}
enhancedDispatch?: (context: CompletionsContext, params: CompletionsParams) => Promise<CompletionsResult>
customState?: Record<string, any>
}
/**
* Extended context for completions method.
*/
export interface CompletionsContext<
TSdkParams extends SdkParams = SdkParams,
TSdkMessageParam extends SdkMessageParam = SdkMessageParam,
TSdkToolCall extends SdkToolCall = SdkToolCall,
TSdkInstance extends SdkInstance = SdkInstance,
TRawOutput extends SdkRawOutput = SdkRawOutput,
TRawChunk extends SdkRawChunk = SdkRawChunk,
TSdkSpecificTool extends SdkTool = SdkTool
> extends BaseContext {
readonly methodName: 'completions' // 强制方法名为 'completions'
apiClientInstance: BaseApiClient<
TSdkInstance,
TSdkParams,
TRawOutput,
TRawChunk,
TSdkMessageParam,
TSdkToolCall,
TSdkSpecificTool
>
// --- Mutable internal state for the duration of the middleware chain ---
_internal: ProcessingState<TSdkParams, TSdkMessageParam, TSdkToolCall> // 包含所有可变的处理状态
}
export interface MiddlewareAPI<Ctx extends BaseContext = BaseContext, Args extends any[] = any[]> {
getContext: () => Ctx // Function to get the current context / 获取当前上下文的函数
getOriginalArgs: () => Args // Function to get the original arguments of the method call / 获取方法调用原始参数的函数
}
/**
* Base middleware type.
*/
export type Middleware<TContext extends BaseContext> = (
api: MiddlewareAPI<TContext>
) => (
next: (context: TContext, args: any[]) => Promise<unknown>
) => (context: TContext, args: any[]) => Promise<unknown>
export type MethodMiddleware = Middleware<BaseContext>
/**
* Completions middleware type.
*/
export type CompletionsMiddleware<
TSdkParams extends SdkParams = SdkParams,
TSdkMessageParam extends SdkMessageParam = SdkMessageParam,
TSdkToolCall extends SdkToolCall = SdkToolCall,
TSdkInstance extends SdkInstance = SdkInstance,
TRawOutput extends SdkRawOutput = SdkRawOutput,
TRawChunk extends SdkRawChunk = SdkRawChunk,
TSdkSpecificTool extends SdkTool = SdkTool
> = (
api: MiddlewareAPI<
CompletionsContext<
TSdkParams,
TSdkMessageParam,
TSdkToolCall,
TSdkInstance,
TRawOutput,
TRawChunk,
TSdkSpecificTool
>,
[CompletionsParams]
>
) => (
next: (
context: CompletionsContext<
TSdkParams,
TSdkMessageParam,
TSdkToolCall,
TSdkInstance,
TRawOutput,
TRawChunk,
TSdkSpecificTool
>,
params: CompletionsParams
) => Promise<CompletionsResult>
) => (
context: CompletionsContext<
TSdkParams,
TSdkMessageParam,
TSdkToolCall,
TSdkInstance,
TRawOutput,
TRawChunk,
TSdkSpecificTool
>,
params: CompletionsParams
) => Promise<CompletionsResult>
// Re-export for convenience
export type { Chunk as OnChunkArg } from '@renderer/types/chunk'
@@ -0,0 +1,57 @@
import { ChunkType, ErrorChunk } from '@renderer/types/chunk'
/**
* Creates an ErrorChunk object with a standardized structure.
* @param error The error object or message.
* @param chunkType The type of chunk, defaults to ChunkType.ERROR.
* @returns An ErrorChunk object.
*/
export function createErrorChunk(error: any, chunkType: ChunkType = ChunkType.ERROR): ErrorChunk {
let errorDetails: Record<string, any> = {}
if (error instanceof Error) {
errorDetails = {
message: error.message,
name: error.name,
stack: error.stack
}
} else if (typeof error === 'string') {
errorDetails = { message: error }
} else if (typeof error === 'object' && error !== null) {
errorDetails = Object.getOwnPropertyNames(error).reduce(
(acc, key) => {
acc[key] = error[key]
return acc
},
{} as Record<string, any>
)
if (!errorDetails.message && error.toString && typeof error.toString === 'function') {
const errMsg = error.toString()
if (errMsg !== '[object Object]') {
errorDetails.message = errMsg
}
}
}
return {
type: chunkType,
error: errorDetails
} as ErrorChunk
}
// Helper to capitalize method names for hook construction
export function capitalize(str: string): string {
if (!str) return ''
return str.charAt(0).toUpperCase() + str.slice(1)
}
/**
* 检查对象是否实现了AsyncIterable接口
*/
export function isAsyncIterable<T = unknown>(obj: unknown): obj is AsyncIterable<T> {
return (
obj !== null &&
typeof obj === 'object' &&
typeof (obj as Record<symbol, unknown>)[Symbol.asyncIterator] === 'function'
)
}
Binary file not shown.

After

Width:  |  Height:  |  Size: 20 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 6.9 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 5.9 KiB

+9 -5
View File
@@ -295,13 +295,16 @@ emoji-picker {
--border-size: 0;
}
.katex-display {
.katex,
mjx-container {
display: inline-block;
overflow-x: auto;
overflow-y: hidden;
}
mjx-container {
overflow-x: auto;
overflow-wrap: break-word;
vertical-align: middle;
max-width: 100%;
padding: 1px 2px;
margin-top: -2px;
}
/* CodeMirror 相关样式 */
@@ -318,6 +321,7 @@ mjx-container {
.cm-gutters {
line-height: 1.6;
border-right: none;
}
.cm-content {
@@ -22,6 +22,7 @@ const MermaidPreview: React.FC<Props> = ({ children, setTools }) => {
const diagramId = useRef<string>(`mermaid-${nanoid(6)}`).current
const [error, setError] = useState<string | null>(null)
const [isRendering, setIsRendering] = useState(false)
const [isVisible, setIsVisible] = useState(true)
// 使用通用图像工具
const { handleZoom, handleCopyImage, handleDownload } = usePreviewToolHandlers(mermaidRef, {
@@ -75,10 +76,55 @@ const MermaidPreview: React.FC<Props> = ({ children, setTools }) => {
[renderMermaid]
)
/**
* 监听可见性变化,用于触发重新渲染。
* 这是为了解决 `MessageGroup` 组件的 `fold` 布局中被 `display: none` 隐藏的图标无法正确渲染的问题。
* 监听时向上遍历到第一个有 `fold` className 的父节点为止(也就是目前的 `MessageWrapper`)。
* FIXME: 将来 mermaid-js 修复此问题后可以移除这里的相关逻辑。
*/
useEffect(() => {
if (!mermaidRef.current) return
const checkVisibility = () => {
const element = mermaidRef.current
if (!element) return
const currentlyVisible = element.offsetParent !== null
setIsVisible(currentlyVisible)
}
// 初始检查
checkVisibility()
const observer = new MutationObserver(() => {
checkVisibility()
})
let targetElement = mermaidRef.current.parentElement
while (targetElement) {
observer.observe(targetElement, {
attributes: true,
attributeFilter: ['class', 'style']
})
if (targetElement.className?.includes('fold')) {
break
}
targetElement = targetElement.parentElement
}
return () => {
observer.disconnect()
}
}, [])
// 触发渲染
useEffect(() => {
if (isLoadingMermaid) return
if (mermaidRef.current?.offsetParent === null) return
if (children) {
setIsRendering(true)
debouncedRender(children)
@@ -90,7 +136,7 @@ const MermaidPreview: React.FC<Props> = ({ children, setTools }) => {
return () => {
debouncedRender.cancel()
}
}, [children, isLoadingMermaid, debouncedRender])
}, [children, isLoadingMermaid, debouncedRender, isVisible])
const isLoading = isLoadingMermaid || isRendering
@@ -0,0 +1,221 @@
import { render, screen, waitFor } from '@testing-library/react'
import { act } from 'react'
import { afterEach, beforeEach, describe, expect, it, type Mock, vi } from 'vitest'
import MermaidPreview from '../CodeBlockView/MermaidPreview'
const mocks = vi.hoisted(() => ({
useMermaid: vi.fn(),
usePreviewToolHandlers: vi.fn(),
usePreviewTools: vi.fn()
}))
// Mock hooks
vi.mock('@renderer/hooks/useMermaid', () => ({
useMermaid: () => mocks.useMermaid()
}))
vi.mock('@renderer/components/CodeToolbar', () => ({
usePreviewToolHandlers: () => mocks.usePreviewToolHandlers(),
usePreviewTools: () => mocks.usePreviewTools()
}))
// Mock nanoid
vi.mock('@reduxjs/toolkit', () => ({
nanoid: () => 'test-id-123456'
}))
// Mock lodash debounce
vi.mock('lodash', async () => {
const actual = await import('lodash')
return {
...actual,
debounce: vi.fn((fn) => {
const debounced = (...args: any[]) => fn(...args)
debounced.cancel = vi.fn()
return debounced
})
}
})
// Mock antd components
vi.mock('antd', () => ({
Flex: ({ children, vertical, ...props }: any) => (
<div data-testid="flex" data-vertical={vertical} {...props}>
{children}
</div>
),
Spin: ({ children, spinning, indicator }: any) => (
<div data-testid="spin" data-spinning={spinning}>
{spinning && indicator}
{children}
</div>
)
}))
describe('MermaidPreview', () => {
const mockMermaid = {
parse: vi.fn(),
render: vi.fn()
}
beforeEach(() => {
vi.clearAllMocks()
mocks.useMermaid.mockReturnValue({
mermaid: mockMermaid,
isLoading: false,
error: null
})
mocks.usePreviewToolHandlers.mockReturnValue({
handleZoom: vi.fn(),
handleCopyImage: vi.fn(),
handleDownload: vi.fn()
})
mocks.usePreviewTools.mockReturnValue({})
mockMermaid.parse.mockResolvedValue(true)
mockMermaid.render.mockResolvedValue({
svg: '<svg class="flowchart" viewBox="0 0 100 100"><g>test diagram</g></svg>'
})
// Mock MutationObserver
global.MutationObserver = vi.fn().mockImplementation(() => ({
observe: vi.fn(),
disconnect: vi.fn(),
takeRecords: vi.fn()
}))
})
afterEach(() => {
vi.restoreAllMocks()
})
describe('visibility detection', () => {
it('should not render mermaid when element has display: none', async () => {
const mermaidCode = 'graph TD\nA-->B'
const { container } = render(<MermaidPreview>{mermaidCode}</MermaidPreview>)
// Mock offsetParent to be null (simulating display: none)
const mermaidElement = container.querySelector('.mermaid')
if (mermaidElement) {
Object.defineProperty(mermaidElement, 'offsetParent', {
get: () => null,
configurable: true
})
}
// Re-render to trigger the effect
render(<MermaidPreview>{mermaidCode}</MermaidPreview>)
// Should not call mermaid render when offsetParent is null
expect(mockMermaid.render).not.toHaveBeenCalled()
const svgElement = mermaidElement?.querySelector('svg.flowchart')
expect(svgElement).not.toBeInTheDocument()
})
it('should setup MutationObserver to monitor parent elements', () => {
const mermaidCode = 'graph TD\nA-->B'
render(<MermaidPreview>{mermaidCode}</MermaidPreview>)
expect(global.MutationObserver).toHaveBeenCalledWith(expect.any(Function))
})
it('should observe parent elements up to fold className', () => {
const mermaidCode = 'graph TD\nA-->B'
// Create a DOM structure that simulates MessageGroup fold layout
const foldContainer = document.createElement('div')
foldContainer.className = 'fold selected'
const messageWrapper = document.createElement('div')
messageWrapper.className = 'message-wrapper'
const codeBlock = document.createElement('div')
codeBlock.className = 'code-block'
foldContainer.appendChild(messageWrapper)
messageWrapper.appendChild(codeBlock)
document.body.appendChild(foldContainer)
render(<MermaidPreview>{mermaidCode}</MermaidPreview>, {
container: codeBlock
})
const observerInstance = (global.MutationObserver as Mock).mock.results[0]?.value
expect(observerInstance.observe).toHaveBeenCalled()
// Cleanup
document.body.removeChild(foldContainer)
})
it('should trigger re-render when visibility changes from hidden to visible', async () => {
const mermaidCode = 'graph TD\nA-->B'
const { container, rerender } = render(<MermaidPreview>{mermaidCode}</MermaidPreview>)
const mermaidElement = container.querySelector('.mermaid')
// Initially hidden (offsetParent is null)
Object.defineProperty(mermaidElement, 'offsetParent', {
get: () => null,
configurable: true
})
// Clear previous calls
mockMermaid.render.mockClear()
// Re-render with hidden state
rerender(<MermaidPreview>{mermaidCode}</MermaidPreview>)
// Should not render when hidden
expect(mockMermaid.render).not.toHaveBeenCalled()
// Now make it visible
Object.defineProperty(mermaidElement, 'offsetParent', {
get: () => document.body,
configurable: true
})
// Simulate MutationObserver callback
const observerCallback = (global.MutationObserver as Mock).mock.calls[0][0]
act(() => {
observerCallback([])
})
// Re-render to trigger visibility change effect
rerender(<MermaidPreview>{mermaidCode}</MermaidPreview>)
await waitFor(() => {
expect(mockMermaid.render).toHaveBeenCalledWith('mermaid-test-id-123456', mermaidCode, expect.any(Object))
const svgElement = mermaidElement?.querySelector('svg.flowchart')
expect(svgElement).toBeInTheDocument()
expect(svgElement).toHaveClass('flowchart')
})
})
it('should handle mermaid loading state', () => {
mocks.useMermaid.mockReturnValue({
mermaid: mockMermaid,
isLoading: true,
error: null
})
const mermaidCode = 'graph TD\nA-->B'
render(<MermaidPreview>{mermaidCode}</MermaidPreview>)
// Should not render when mermaid is loading
expect(mockMermaid.render).not.toHaveBeenCalled()
// Should show loading state
expect(screen.getByTestId('spin')).toHaveAttribute('data-spinning', 'true')
})
})
})
+230 -65
View File
@@ -55,6 +55,7 @@ import {
default as ChatGptModelLogoDakr,
default as ChatGPTo1ModelLogoDark
} from '@renderer/assets/images/models/gpt_dark.png'
import ChatGPTImageModelLogo from '@renderer/assets/images/models/gpt_image_1.png'
import ChatGPTo1ModelLogo from '@renderer/assets/images/models/gpt_o1.png'
import GrokModelLogo from '@renderer/assets/images/models/grok.png'
import GrokModelLogoDark from '@renderer/assets/images/models/grok_dark.png'
@@ -143,7 +144,7 @@ import YiModelLogoDark from '@renderer/assets/images/models/yi_dark.png'
import YoudaoLogo from '@renderer/assets/images/providers/netease-youdao.svg'
import NomicLogo from '@renderer/assets/images/providers/nomic.png'
import { getProviderByModel } from '@renderer/services/AssistantService'
import { Assistant, Model } from '@renderer/types'
import { Model } from '@renderer/types'
import OpenAI from 'openai'
import { WEB_SEARCH_PROMPT_FOR_OPENROUTER } from './prompts'
@@ -181,7 +182,8 @@ const visionAllowedModels = [
'o4(?:-[\\w-]+)?',
'deepseek-vl(?:[\\w-]+)?',
'kimi-latest',
'gemma-3(?:-[\\w-]+)'
'gemma-3(?:-[\\w-]+)',
'doubao-1.6-seed(?:-[\\w-]+)'
]
const visionExcludedModels = [
@@ -199,6 +201,11 @@ export const VISION_REGEX = new RegExp(
'i'
)
// For middleware to identify models that must use the dedicated Image API
export const DEDICATED_IMAGE_MODELS = ['grok-2-image', 'dall-e-3', 'dall-e-2', 'gpt-image-1']
export const isDedicatedImageGenerationModel = (model: Model): boolean =>
DEDICATED_IMAGE_MODELS.filter((m) => model.id.includes(m)).length > 0
// Text to image models
export const TEXT_TO_IMAGE_REGEX = /flux|diffusion|stabilityai|sd-|dall|cogview|janus/i
@@ -286,6 +293,7 @@ export function getModelLogo(modelId: string) {
o1: isLight ? ChatGPTo1ModelLogo : ChatGPTo1ModelLogoDark,
o3: isLight ? ChatGPTo1ModelLogo : ChatGPTo1ModelLogoDark,
o4: isLight ? ChatGPTo1ModelLogo : ChatGPTo1ModelLogoDark,
'gpt-image': ChatGPTImageModelLogo,
'gpt-3': isLight ? ChatGPT35ModelLogo : ChatGPT35ModelLogoDark,
'gpt-4': isLight ? ChatGPT4ModelLogo : ChatGPT4ModelLogoDark,
gpts: isLight ? ChatGPT4ModelLogo : ChatGPT4ModelLogoDark,
@@ -307,6 +315,7 @@ export function getModelLogo(modelId: string) {
mistral: isLight ? MistralModelLogo : MistralModelLogoDark,
codestral: CodestralModelLogo,
ministral: isLight ? MistralModelLogo : MistralModelLogoDark,
magistral: isLight ? MistralModelLogo : MistralModelLogoDark,
moonshot: isLight ? MoonshotModelLogo : MoonshotModelLogoDark,
kimi: isLight ? MoonshotModelLogo : MoonshotModelLogoDark,
phi: isLight ? MicrosoftModelLogo : MicrosoftModelLogoDark,
@@ -429,7 +438,86 @@ export const SYSTEM_MODELS: Record<string, Model[]> = {
group: 'deepseek-ai'
}
],
'302ai': [
{
id: 'deepseek-chat',
name: 'deepseek-chat',
provider: '302ai',
group: 'DeepSeek'
},
{
id: 'deepseek-reasoner',
name: 'deepseek-reasoner',
provider: '302ai',
group: 'DeepSeek'
},
{
id: 'chatgpt-4o-latest',
name: 'chatgpt-4o-latest',
provider: '302ai',
group: 'OpenAI'
},
{
id: 'gpt-4.1',
name: 'gpt-4.1',
provider: '302ai',
group: 'OpenAI'
},
{
id: 'o3',
name: 'o3',
provider: '302ai',
group: 'OpenAI'
},
{
id: 'o4-mini',
name: 'o4-mini',
provider: '302ai',
group: 'OpenAI'
},
{
id: 'qwen3-235b-a22b',
name: 'qwen3-235b-a22b',
provider: '302ai',
group: 'Qwen'
},
{
id: 'gemini-2.5-flash-preview-05-20',
name: 'gemini-2.5-flash-preview-05-20',
provider: '302ai',
group: 'Gemini'
},
{
id: 'gemini-2.5-pro-preview-06-05',
name: 'gemini-2.5-pro-preview-06-05',
provider: '302ai',
group: 'Gemini'
},
{
id: 'claude-sonnet-4-20250514',
provider: '302ai',
name: 'claude-sonnet-4-20250514',
group: 'Anthropic'
},
{
id: 'claude-opus-4-20250514',
provider: '302ai',
name: 'claude-opus-4-20250514',
group: 'Anthropic'
},
{
id: 'jina-clip-v2',
name: 'jina-clip-v2',
provider: '302ai',
group: 'Jina AI'
},
{
id: 'jina-reranker-m0',
name: 'jina-reranker-m0',
provider: '302ai',
group: 'Jina AI'
}
],
aihubmix: [
{
id: 'gpt-4o',
@@ -2082,6 +2170,14 @@ export const SYSTEM_MODELS: Record<string, Model[]> = {
name: 'Qwen Plus',
group: 'Qwen'
}
],
cephalon: [
{
id: 'DeepSeek-R1',
provider: 'cephalon',
name: 'DeepSeek-R1满血版',
group: 'DeepSeek'
}
]
}
@@ -2159,14 +2255,24 @@ export const TEXT_TO_IMAGES_MODELS_SUPPORT_IMAGE_ENHANCEMENT = [
'stabilityai/stable-diffusion-xl-base-1.0'
]
export const SUPPORTED_DISABLE_GENERATION_MODELS = [
'gemini-2.0-flash-exp',
'gpt-4o',
'gpt-4o-mini',
'gpt-4.1',
'gpt-4.1-mini',
'gpt-4.1-nano',
'o3'
]
export const GENERATE_IMAGE_MODELS = [
'gemini-2.0-flash-exp-image-generation',
'gemini-2.0-flash-preview-image-generation',
'gemini-2.0-flash-exp',
'grok-2-image-1212',
'grok-2-image',
'grok-2-image-latest',
'gpt-image-1'
'gpt-image-1',
...SUPPORTED_DISABLE_GENERATION_MODELS
]
export const GEMINI_SEARCH_MODELS = [
@@ -2275,10 +2381,32 @@ export function isSupportedReasoningEffortOpenAIModel(model: Model): boolean {
)
}
export function isOpenAIWebSearch(model: Model): boolean {
export function isOpenAIChatCompletionOnlyModel(model: Model): boolean {
if (!model) {
return false
}
return (
model.id.includes('gpt-4o-search-preview') ||
model.id.includes('gpt-4o-mini-search-preview') ||
model.id.includes('o1-mini') ||
model.id.includes('o1-preview')
)
}
export function isOpenAIWebSearchChatCompletionOnlyModel(model: Model): boolean {
return model.id.includes('gpt-4o-search-preview') || model.id.includes('gpt-4o-mini-search-preview')
}
export function isOpenAIWebSearchModel(model: Model): boolean {
return (
model.id.includes('gpt-4o-search-preview') ||
model.id.includes('gpt-4o-mini-search-preview') ||
(model.id.includes('gpt-4.1') && !model.id.includes('gpt-4.1-nano')) ||
(model.id.includes('gpt-4o') && !model.id.includes('gpt-4o-image'))
)
}
export function isSupportedThinkingTokenModel(model?: Model): boolean {
if (!model) {
return false
@@ -2287,7 +2415,8 @@ export function isSupportedThinkingTokenModel(model?: Model): boolean {
return (
isSupportedThinkingTokenGeminiModel(model) ||
isSupportedThinkingTokenQwenModel(model) ||
isSupportedThinkingTokenClaudeModel(model)
isSupportedThinkingTokenClaudeModel(model) ||
isSupportedThinkingTokenDoubaoModel(model)
)
}
@@ -2369,6 +2498,14 @@ export function isSupportedThinkingTokenQwenModel(model?: Model): boolean {
)
}
export function isSupportedThinkingTokenDoubaoModel(model?: Model): boolean {
if (!model) {
return false
}
return DOUBAO_THINKING_MODEL_REGEX.test(model.id)
}
export function isClaudeReasoningModel(model?: Model): boolean {
if (!model) {
return false
@@ -2389,7 +2526,12 @@ export function isReasoningModel(model?: Model): boolean {
}
if (model.provider === 'doubao') {
return REASONING_REGEX.test(model.name) || model.type?.includes('reasoning') || false
return (
REASONING_REGEX.test(model.name) ||
model.type?.includes('reasoning') ||
isSupportedThinkingTokenDoubaoModel(model) ||
false
)
}
if (
@@ -2398,7 +2540,8 @@ export function isReasoningModel(model?: Model): boolean {
isGeminiReasoningModel(model) ||
isQwenReasoningModel(model) ||
isGrokReasoningModel(model) ||
model.id.includes('glm-z1')
model.id.includes('glm-z1') ||
model.id.includes('magistral')
) {
return true
}
@@ -2419,7 +2562,7 @@ export function isNotSupportTemperatureAndTopP(model: Model): boolean {
return true
}
if (isOpenAIReasoningModel(model) || isOpenAIWebSearch(model)) {
if (isOpenAIReasoningModel(model) || isOpenAIChatCompletionOnlyModel(model)) {
return true
}
@@ -2449,17 +2592,13 @@ export function isWebSearchModel(model: Model): boolean {
return false
}
// 不管哪个供应商都判断了
if (model.id.includes('claude')) {
return CLAUDE_SUPPORTED_WEBSEARCH_REGEX.test(model.id)
}
if (provider.type === 'openai-response') {
if (
isOpenAILLMModel(model) &&
!isTextToImageModel(model) &&
!isOpenAIReasoningModel(model) &&
!GENERATE_IMAGE_MODELS.includes(model.id)
) {
if (isOpenAIWebSearchModel(model)) {
return true
}
@@ -2471,12 +2610,7 @@ export function isWebSearchModel(model: Model): boolean {
}
if (provider.id === 'aihubmix') {
if (
isOpenAILLMModel(model) &&
!isTextToImageModel(model) &&
!isOpenAIReasoningModel(model) &&
!GENERATE_IMAGE_MODELS.includes(model.id)
) {
if (isOpenAIWebSearchModel(model)) {
return true
}
@@ -2485,7 +2619,7 @@ export function isWebSearchModel(model: Model): boolean {
}
if (provider?.type === 'openai') {
if (GEMINI_SEARCH_MODELS.includes(model?.id) || isOpenAIWebSearch(model)) {
if (GEMINI_SEARCH_MODELS.includes(model?.id) || isOpenAIWebSearchModel(model)) {
return true
}
}
@@ -2519,6 +2653,20 @@ export function isWebSearchModel(model: Model): boolean {
return false
}
export function isOpenRouterBuiltInWebSearchModel(model: Model): boolean {
if (!model) {
return false
}
const provider = getProviderByModel(model)
if (provider.id !== 'openrouter') {
return false
}
return isOpenAIWebSearchModel(model) || model.id.includes('sonar')
}
export function isGenerateImageModel(model: Model): boolean {
if (!model) {
return false
@@ -2541,56 +2689,60 @@ export function isGenerateImageModel(model: Model): boolean {
return false
}
export function getOpenAIWebSearchParams(assistant: Assistant, model: Model): Record<string, any> {
if (isWebSearchModel(model)) {
if (assistant.enableWebSearch) {
const webSearchTools = getWebSearchTools(model)
export function isSupportedDisableGenerationModel(model: Model): boolean {
if (!model) {
return false
}
if (model.provider === 'grok') {
return {
search_parameters: {
mode: 'auto',
return_citations: true,
sources: [{ type: 'web' }, { type: 'x' }, { type: 'news' }]
}
}
}
return SUPPORTED_DISABLE_GENERATION_MODELS.includes(model.id)
}
if (model.provider === 'hunyuan') {
return { enable_enhancement: true, citation: true, search_info: true }
}
export function getOpenAIWebSearchParams(model: Model, isEnableWebSearch?: boolean): Record<string, any> {
if (!isEnableWebSearch) {
return {}
}
if (model.provider === 'dashscope') {
return {
enable_search: true,
search_options: {
forced_search: true
}
}
}
const webSearchTools = getWebSearchTools(model)
if (model.provider === 'openrouter') {
return {
plugins: [{ id: 'web', search_prompts: WEB_SEARCH_PROMPT_FOR_OPENROUTER }]
}
}
if (isOpenAIWebSearch(model)) {
return {
web_search_options: {}
}
}
return {
tools: webSearchTools
}
} else {
if (model.provider === 'hunyuan') {
return { enable_enhancement: false }
if (model.provider === 'grok') {
return {
search_parameters: {
mode: 'auto',
return_citations: true,
sources: [{ type: 'web' }, { type: 'x' }, { type: 'news' }]
}
}
}
if (model.provider === 'hunyuan') {
return { enable_enhancement: true, citation: true, search_info: true }
}
if (model.provider === 'dashscope') {
return {
enable_search: true,
search_options: {
forced_search: true
}
}
}
if (isOpenAIWebSearchChatCompletionOnlyModel(model)) {
return {
web_search_options: {}
}
}
if (model.provider === 'openrouter') {
return {
plugins: [{ id: 'web', search_prompts: WEB_SEARCH_PROMPT_FOR_OPENROUTER }]
}
}
return {
tools: webSearchTools
}
return {}
}
@@ -2671,3 +2823,16 @@ export const findTokenLimit = (modelId: string): { min: number; max: number } |
}
return undefined
}
// Doubao 支持思考模式的模型正则
export const DOUBAO_THINKING_MODEL_REGEX =
/doubao-(?:1(\.|-5)-thinking-vision-pro|1(\.|-)5-thinking-pro-m|seed-1\.6|seed-1\.6-flash)(?:-[\\w-]+)?/i
// 支持 auto 的 Doubao 模型
export const DOUBAO_THINKING_AUTO_MODEL_REGEX = /doubao-(?:1-5-thinking-pro-m|seed-1.6)(?:-[\\w-]+)?/i
export function isDoubaoThinkingAutoModel(model: Model): boolean {
return DOUBAO_THINKING_AUTO_MODEL_REGEX.test(model.id)
}
export const GEMINI_FLASH_MODEL_REGEX = new RegExp('gemini-.*-flash.*$')
+27 -1
View File
@@ -1,6 +1,7 @@
import ZhinaoProviderLogo from '@renderer/assets/images/models/360.png'
import HunyuanProviderLogo from '@renderer/assets/images/models/hunyuan.png'
import AzureProviderLogo from '@renderer/assets/images/models/microsoft.png'
import Ai302ProviderLogo from '@renderer/assets/images/providers/302ai.webp'
import AiHubMixProviderLogo from '@renderer/assets/images/providers/aihubmix.webp'
import AlayaNewProviderLogo from '@renderer/assets/images/providers/alayanew.webp'
import AnthropicProviderLogo from '@renderer/assets/images/providers/anthropic.png'
@@ -8,6 +9,7 @@ import BaichuanProviderLogo from '@renderer/assets/images/providers/baichuan.png
import BaiduCloudProviderLogo from '@renderer/assets/images/providers/baidu-cloud.svg'
import BailianProviderLogo from '@renderer/assets/images/providers/bailian.png'
import BurnCloudProviderLogo from '@renderer/assets/images/providers/burncloud.png'
import CephalonProviderLogo from '@renderer/assets/images/providers/cephalon.jpeg'
import DeepSeekProviderLogo from '@renderer/assets/images/providers/deepseek.png'
import DmxapiProviderLogo from '@renderer/assets/images/providers/DMXAPI.png'
import FireworksProviderLogo from '@renderer/assets/images/providers/fireworks.png'
@@ -48,6 +50,7 @@ import ZhipuProviderLogo from '@renderer/assets/images/providers/zhipu.png'
import { TOKENFLUX_HOST } from './constant'
const PROVIDER_LOGO_MAP = {
'302ai': Ai302ProviderLogo,
openai: OpenAiProviderLogo,
silicon: SiliconFlowProviderLogo,
deepseek: DeepSeekProviderLogo,
@@ -94,7 +97,8 @@ const PROVIDER_LOGO_MAP = {
alayanew: AlayaNewProviderLogo,
voyageai: VoyageAIProviderLogo,
qiniu: QiniuProviderLogo,
tokenflux: TokenFluxProviderLogo
tokenflux: TokenFluxProviderLogo,
cephalon: CephalonProviderLogo
} as const
export function getProviderLogo(providerId: string) {
@@ -106,6 +110,17 @@ export const NOT_SUPPORTED_REANK_PROVIDERS = ['ollama']
export const ONLY_SUPPORTED_DIMENSION_PROVIDERS = ['ollama', 'infini']
export const PROVIDER_CONFIG = {
'302ai': {
api: {
url: 'https://api.302.ai'
},
websites: {
official: 'https://302.ai',
apiKey: 'https://dash.302.ai/apis/list',
docs: 'https://302ai.apifox.cn/api-147522039',
models: 'https://302.ai/pricing/'
}
},
openai: {
api: {
url: 'https://api.openai.com'
@@ -612,5 +627,16 @@ export const PROVIDER_CONFIG = {
docs: `${TOKENFLUX_HOST}/docs`,
models: `${TOKENFLUX_HOST}/models`
}
},
cephalon: {
api: {
url: 'https://cephalon.cloud/user-center/v1/model'
},
websites: {
official: 'https://cephalon.cloud/share/register-landing?invite_id=jSdOYA',
apiKey: 'https://cephalon.cloud/api',
docs: 'https://cephalon.cloud/apitoken/1864244127731589124',
models: 'https://cephalon.cloud/model'
}
}
}
+68 -15
View File
@@ -4,6 +4,7 @@ import { EVENT_NAMES, EventEmitter } from '@renderer/services/EventService'
import { deleteMessageFiles } from '@renderer/services/MessagesService'
import store from '@renderer/store'
import { updateTopic } from '@renderer/store/assistants'
import { setNewlyRenamedTopics, setRenamingTopics } from '@renderer/store/runtime'
import { loadTopicMessagesThunk } from '@renderer/store/thunk/messageThunk'
import { Assistant, Topic } from '@renderer/types'
import { findMainTextBlocks } from '@renderer/utils/messageUtils/find'
@@ -13,8 +14,6 @@ import { useEffect, useState } from 'react'
import { useAssistant } from './useAssistant'
import { getStoreSetting } from './useSettings'
const renamingTopics = new Set<string>()
let _activeTopic: Topic
let _setActiveTopic: (topic: Topic) => void
@@ -58,13 +57,46 @@ export async function getTopicById(topicId: string) {
return { ...topic, messages } as Topic
}
/**
*
*/
export const startTopicRenaming = (topicId: string) => {
const currentIds = store.getState().runtime.chat.renamingTopics
if (!currentIds.includes(topicId)) {
store.dispatch(setRenamingTopics([...currentIds, topicId]))
}
}
/**
*
*/
export const finishTopicRenaming = (topicId: string) => {
const state = store.getState()
// 1. 立即从 renamingTopics 移除
const currentRenaming = state.runtime.chat.renamingTopics
store.dispatch(setRenamingTopics(currentRenaming.filter((id) => id !== topicId)))
// 2. 立即添加到 newlyRenamedTopics
const currentNewlyRenamed = state.runtime.chat.newlyRenamedTopics
store.dispatch(setNewlyRenamedTopics([...currentNewlyRenamed, topicId]))
// 3. 延迟从 newlyRenamedTopics 移除
setTimeout(() => {
const current = store.getState().runtime.chat.newlyRenamedTopics
store.dispatch(setNewlyRenamedTopics(current.filter((id) => id !== topicId)))
}, 700)
}
const topicRenamingLocks = new Set<string>()
export const autoRenameTopic = async (assistant: Assistant, topicId: string) => {
if (renamingTopics.has(topicId)) {
if (topicRenamingLocks.has(topicId)) {
return
}
try {
renamingTopics.add(topicId)
topicRenamingLocks.add(topicId)
const topic = await getTopicById(topicId)
const enableTopicNaming = getStoreSetting('enableTopicNaming')
@@ -85,24 +117,36 @@ export const autoRenameTopic = async (assistant: Assistant, topicId: string) =>
.join('\n\n')
.substring(0, 50)
if (topicName) {
const data = { ...topic, name: topicName } as Topic
_setActiveTopic(data)
store.dispatch(updateTopic({ assistantId: assistant.id, topic: data }))
try {
startTopicRenaming(topicId)
const data = { ...topic, name: topicName } as Topic
topic.id === _activeTopic.id && _setActiveTopic(data)
store.dispatch(updateTopic({ assistantId: assistant.id, topic: data }))
} finally {
finishTopicRenaming(topicId)
}
}
return
}
if (topic && topic.name === i18n.t('chat.default.topic.name') && topic.messages.length >= 2) {
const { fetchMessagesSummary } = await import('@renderer/services/ApiService')
const summaryText = await fetchMessagesSummary({ messages: topic.messages, assistant })
if (summaryText) {
const data = { ...topic, name: summaryText }
_setActiveTopic(data)
store.dispatch(updateTopic({ assistantId: assistant.id, topic: data }))
try {
startTopicRenaming(topicId)
const { fetchMessagesSummary } = await import('@renderer/services/ApiService')
const summaryText = await fetchMessagesSummary({ messages: topic.messages, assistant })
if (summaryText) {
const data = { ...topic, name: summaryText }
topic.id === _activeTopic.id && _setActiveTopic(data)
store.dispatch(updateTopic({ assistantId: assistant.id, topic: data }))
}
} finally {
finishTopicRenaming(topicId)
}
}
} finally {
renamingTopics.delete(topicId)
topicRenamingLocks.delete(topicId)
}
}
@@ -117,9 +161,18 @@ export const TopicManager = {
return await db.topics.toArray()
},
/**
*
*/
async getTopicMessages(id: string) {
const topic = await TopicManager.getTopic(id)
return topic ? topic.messages : []
if (!topic) return []
await store.dispatch(loadTopicMessagesThunk(id))
// 获取更新后的话题
const updatedTopic = await TopicManager.getTopic(id)
return updatedTopic?.messages || []
},
async removeTopic(id: string) {
+3 -1
View File
@@ -987,6 +987,7 @@
"azure-openai": "Azure OpenAI",
"baichuan": "Baichuan",
"baidu-cloud": "Baidu Cloud",
"cephalon": "Cephalon",
"copilot": "GitHub Copilot",
"dashscope": "Alibaba Cloud",
"deepseek": "DeepSeek",
@@ -1027,7 +1028,8 @@
"zhipu": "ZHIPU AI",
"voyageai": "Voyage AI",
"qiniu": "Qiniu AI",
"tokenflux": "TokenFlux"
"tokenflux": "TokenFlux",
"302ai": "302.AI"
},
"restore": {
"confirm": "Are you sure you want to restore data?",
+4 -2
View File
@@ -713,7 +713,7 @@
"error.yuque.no_config": "語雀のAPIアドレスまたはトークンが設定されていません",
"download.success": "ダウンロードに成功しました",
"download.failed": "ダウンロードに失敗しました",
"error.fetchTopicName": "トピックの命名に失敗しました"
"error.fetchTopicName": "トピック名の取得に失敗しました"
},
"minapp": {
"popup": {
@@ -1027,7 +1027,9 @@
"zhipu": "智譜AI",
"voyageai": "Voyage AI",
"qiniu": "七牛云 AI 推理",
"tokenflux": "TokenFlux"
"tokenflux": "TokenFlux",
"302ai": "302.AI",
"cephalon": "Cephalon"
},
"restore": {
"confirm": "データを復元しますか?",
+4 -2
View File
@@ -713,7 +713,7 @@
"warn.siyuan.exporting": "Экспортируется в Siyuan, пожалуйста, не отправляйте повторные запросы!",
"download.success": "Скачано успешно",
"download.failed": "Скачивание не удалось",
"error.fetchTopicName": "Не удалось назвать тему"
"error.fetchTopicName": "Не удалось назвать топик"
},
"minapp": {
"popup": {
@@ -987,6 +987,7 @@
"azure-openai": "Azure OpenAI",
"baichuan": "Baichuan",
"baidu-cloud": "Baidu Cloud",
"cephalon": "Cephalon",
"copilot": "GitHub Copilot",
"dashscope": "Alibaba Cloud",
"deepseek": "DeepSeek",
@@ -1027,7 +1028,8 @@
"zhipu": "ZHIPU AI",
"voyageai": "Voyage AI",
"qiniu": "Qiniu AI",
"tokenflux": "TokenFlux"
"tokenflux": "TokenFlux",
"302ai": "302.AI"
},
"restore": {
"confirm": "Вы уверены, что хотите восстановить данные?",
+3 -1
View File
@@ -987,6 +987,7 @@
"azure-openai": "Azure OpenAI",
"baichuan": "百川",
"baidu-cloud": "百度云千帆",
"cephalon": "Cephalon",
"copilot": "GitHub Copilot",
"dashscope": "阿里云百炼",
"deepseek": "深度求索",
@@ -1027,7 +1028,8 @@
"zhipu": "智谱AI",
"voyageai": "Voyage AI",
"qiniu": "七牛云 AI 推理",
"tokenflux": "TokenFlux"
"tokenflux": "TokenFlux",
"302ai": "302.AI"
},
"restore": {
"confirm": "确定要恢复数据吗?",
+3 -1
View File
@@ -987,6 +987,7 @@
"azure-openai": "Azure OpenAI",
"baichuan": "百川",
"baidu-cloud": "百度雲千帆",
"cephalon": "Cephalon",
"copilot": "GitHub Copilot",
"dashscope": "阿里雲百鍊",
"deepseek": "深度求索",
@@ -1027,7 +1028,8 @@
"zhipu": "智譜 AI",
"voyageai": "Voyage AI",
"qiniu": "七牛雲 AI 推理",
"tokenflux": "TokenFlux"
"tokenflux": "TokenFlux",
"302ai": "302.AI"
},
"restore": {
"confirm": "確定要復原資料嗎?",
@@ -840,6 +840,7 @@
"azure-openai": "Azure OpenAI",
"baichuan": "Παράκειμαι",
"baidu-cloud": "Baidu Cloud Qianfan",
"cephalon": "Cephalon",
"copilot": "GitHub Copilot",
"dashscope": "AliCloud Bailian",
"deepseek": "Βαθιά Αναζήτηση",
@@ -841,6 +841,7 @@
"azure-openai": "Azure OpenAI",
"baichuan": "BaiChuan",
"baidu-cloud": "Baidu Nube Qiánfān",
"cephalon": "Cephalon",
"copilot": "GitHub Copiloto",
"dashscope": "Álibaba Nube BaiLiàn",
"deepseek": "Profundo Buscar",
@@ -840,6 +840,7 @@
"azure-openai": "Azure OpenAI",
"baichuan": "BaiChuan",
"baidu-cloud": "Baidu Cloud Qianfan",
"cephalon": "Cephalon",
"copilot": "GitHub Copilote",
"dashscope": "AliCloud BaiLian",
"deepseek": "DeepSeek",
@@ -1,118 +0,0 @@
// Modified from https://github.com/vercel/ai/blob/845080d80b8538bb9c7e527d2213acb5f33ac9c2/packages/ai/core/middleware/extract-reasoning-middleware.ts
import { getPotentialStartIndex } from '../utils/getPotentialIndex'
export interface ExtractReasoningMiddlewareOptions {
openingTag: string
closingTag: string
separator?: string
enableReasoning?: boolean
}
function escapeRegExp(str: string) {
return str.replace(/[.*+?^${}()|[\\]\\]/g, '\\$&')
}
// 支持泛型 T,默认 T = { type: string; textDelta: string }
export function extractReasoningMiddleware<
T extends { type: string } & (
| { type: 'text-delta' | 'reasoning'; textDelta: string }
| { type: string } // 其他类型
) = { type: string; textDelta: string }
>({ openingTag, closingTag, separator = '\n', enableReasoning }: ExtractReasoningMiddlewareOptions) {
const openingTagEscaped = escapeRegExp(openingTag)
const closingTagEscaped = escapeRegExp(closingTag)
return {
wrapGenerate: async ({ doGenerate }: { doGenerate: () => Promise<{ text: string } & Record<string, any>> }) => {
const { text: rawText, ...rest } = await doGenerate()
if (rawText == null) {
return { text: rawText, ...rest }
}
const text = rawText
const regexp = new RegExp(`${openingTagEscaped}(.*?)${closingTagEscaped}`, 'gs')
const matches = Array.from(text.matchAll(regexp))
if (!matches.length) {
return { text, ...rest }
}
const reasoning = matches.map((match: RegExpMatchArray) => match[1]).join(separator)
let textWithoutReasoning = text
for (let i = matches.length - 1; i >= 0; i--) {
const match = matches[i] as RegExpMatchArray
const beforeMatch = textWithoutReasoning.slice(0, match.index as number)
const afterMatch = textWithoutReasoning.slice((match.index as number) + match[0].length)
textWithoutReasoning =
beforeMatch + (beforeMatch.length > 0 && afterMatch.length > 0 ? separator : '') + afterMatch
}
return { ...rest, text: textWithoutReasoning, reasoning }
},
wrapStream: async ({
doStream
}: {
doStream: () => Promise<{ stream: ReadableStream<T> } & Record<string, any>>
}) => {
const { stream, ...rest } = await doStream()
if (!enableReasoning) {
return {
stream,
...rest
}
}
let isFirstReasoning = true
let isFirstText = true
let afterSwitch = false
let isReasoning = false
let buffer = ''
return {
stream: stream.pipeThrough(
new TransformStream<T, T>({
transform: (chunk, controller) => {
if (chunk.type !== 'text-delta') {
controller.enqueue(chunk)
return
}
// textDelta 只在 text-delta/reasoning chunk 上
buffer += (chunk as { textDelta: string }).textDelta
function publish(text: string) {
if (text.length > 0) {
const prefix = afterSwitch && (isReasoning ? !isFirstReasoning : !isFirstText) ? separator : ''
controller.enqueue({
...chunk,
type: isReasoning ? 'reasoning' : 'text-delta',
textDelta: prefix + text
} as T)
afterSwitch = false
if (isReasoning) {
isFirstReasoning = false
} else {
isFirstText = false
}
}
}
while (true) {
const nextTag = isReasoning ? closingTag : openingTag
const startIndex = getPotentialStartIndex(buffer, nextTag)
if (startIndex == null) {
publish(buffer)
buffer = ''
break
}
publish(buffer.slice(0, startIndex))
const foundFullMatch = startIndex + nextTag.length <= buffer.length
if (foundFullMatch) {
buffer = buffer.slice(startIndex + nextTag.length)
isReasoning = !isReasoning
afterSwitch = true
} else {
buffer = buffer.slice(startIndex)
break
}
}
}
})
),
...rest
}
}
}
}
@@ -4,6 +4,7 @@ import TranslateButton from '@renderer/components/TranslateButton'
import Logger from '@renderer/config/logger'
import {
isGenerateImageModel,
isSupportedDisableGenerationModel,
isSupportedReasoningEffortModel,
isSupportedThinkingTokenModel,
isVisionModel,
@@ -727,7 +728,7 @@ const Inputbar: FC<Props> = ({ assistant: _assistant, setActiveTopic, topic }) =
if (!isGenerateImageModel(model) && assistant.enableGenerateImage) {
updateAssistant({ ...assistant, enableGenerateImage: false })
}
if (isGenerateImageModel(model) && !assistant.enableGenerateImage && model.id !== 'gemini-2.0-flash-exp') {
if (isGenerateImageModel(model) && !assistant.enableGenerateImage && !isSupportedDisableGenerationModel(model)) {
updateAssistant({ ...assistant, enableGenerateImage: true })
}
}, [assistant, model, updateAssistant])
@@ -7,7 +7,9 @@ import {
} from '@renderer/components/Icons/SVGIcon'
import { useQuickPanel } from '@renderer/components/QuickPanel'
import {
isDoubaoThinkingAutoModel,
isSupportedReasoningEffortGrokModel,
isSupportedThinkingTokenDoubaoModel,
isSupportedThinkingTokenGeminiModel,
isSupportedThinkingTokenQwenModel
} from '@renderer/config/models'
@@ -35,13 +37,14 @@ const MODEL_SUPPORTED_OPTIONS: Record<string, ThinkingOption[]> = {
default: ['off', 'low', 'medium', 'high'],
grok: ['off', 'low', 'high'],
gemini: ['off', 'low', 'medium', 'high', 'auto'],
qwen: ['off', 'low', 'medium', 'high']
qwen: ['off', 'low', 'medium', 'high'],
doubao: ['off', 'auto', 'high']
}
// 选项转换映射表:当选项不支持时使用的替代选项
const OPTION_FALLBACK: Record<ThinkingOption, ThinkingOption> = {
off: 'off',
low: 'low',
low: 'high',
medium: 'high', // medium -> high (for Grok models)
high: 'high',
auto: 'high' // auto -> high (for non-Gemini models)
@@ -55,6 +58,7 @@ const ThinkingButton: FC<Props> = ({ ref, model, assistant, ToolbarButton }): Re
const isGrokModel = isSupportedReasoningEffortGrokModel(model)
const isGeminiModel = isSupportedThinkingTokenGeminiModel(model)
const isQwenModel = isSupportedThinkingTokenQwenModel(model)
const isDoubaoModel = isSupportedThinkingTokenDoubaoModel(model)
const currentReasoningEffort = useMemo(() => {
return assistant.settings?.reasoning_effort || 'off'
@@ -65,13 +69,20 @@ const ThinkingButton: FC<Props> = ({ ref, model, assistant, ToolbarButton }): Re
if (isGeminiModel) return 'gemini'
if (isGrokModel) return 'grok'
if (isQwenModel) return 'qwen'
if (isDoubaoModel) return 'doubao'
return 'default'
}, [isGeminiModel, isGrokModel, isQwenModel])
}, [isGeminiModel, isGrokModel, isQwenModel, isDoubaoModel])
// 获取当前模型支持的选项
const supportedOptions = useMemo(() => {
if (modelType === 'doubao') {
if (isDoubaoThinkingAutoModel(model)) {
return ['off', 'auto', 'high'] as ThinkingOption[]
}
return ['off', 'high'] as ThinkingOption[]
}
return MODEL_SUPPORTED_OPTIONS[modelType]
}, [modelType])
}, [model, modelType])
// 检查当前设置是否与当前模型兼容
useEffect(() => {
@@ -24,6 +24,7 @@ import remarkMath from 'remark-math'
import CodeBlock from './CodeBlock'
import Link from './Link'
import Table from './Table'
const ALLOWED_ELEMENTS =
/<(style|p|div|span|b|i|strong|em|ul|ol|li|table|tr|td|th|thead|tbody|h[1-6]|blockquote|pre|code|br|hr|svg|path|circle|rect|line|polyline|polygon|text|g|defs|title|desc|tspan|sub|sup)/i
@@ -83,6 +84,7 @@ const Markdown: FC<Props> = ({ block }) => {
code: (props: any) => (
<CodeBlock {...props} id={getCodeBlockId(props?.node?.position?.start)} onSave={onSaveCodeBlock} />
),
table: (props: any) => <Table {...props} blockId={block.id} />,
img: (props: any) => <ImageViewer style={{ maxWidth: 500, maxHeight: 500 }} {...props} />,
pre: (props: any) => <pre style={{ overflow: 'visible' }} {...props} />,
p: (props) => {
@@ -91,7 +93,7 @@ const Markdown: FC<Props> = ({ block }) => {
return <p {...props} />
}
} as Partial<Components>
}, [onSaveCodeBlock])
}, [onSaveCodeBlock, block.id])
if (messageContent.includes('<style>')) {
components.style = MarkdownShadowDOMRenderer as any
@@ -0,0 +1,120 @@
import store from '@renderer/store'
import { messageBlocksSelectors } from '@renderer/store/messageBlock'
import { Tooltip } from 'antd'
import { Check, Copy } from 'lucide-react'
import React, { memo, useCallback, useState } from 'react'
import { useTranslation } from 'react-i18next'
import styled from 'styled-components'
interface Props {
children: React.ReactNode
node?: any
blockId?: string
}
/**
* Markdown copy
*/
const Table: React.FC<Props> = ({ children, node, blockId }) => {
const { t } = useTranslation()
const [copied, setCopied] = useState(false)
const handleCopyTable = useCallback(() => {
const tableMarkdown = extractTableMarkdown(blockId ?? '', node?.position)
if (!tableMarkdown) return
navigator.clipboard
.writeText(tableMarkdown)
.then(() => {
setCopied(true)
setTimeout(() => setCopied(false), 2000)
})
.catch((error) => {
window.message?.error({ content: `${t('message.copy.failed')}: ${error}`, key: 'copy-table-error' })
})
}, [node, blockId, t])
return (
<TableWrapper className="table-wrapper">
<table>{children}</table>
<ToolbarWrapper className="table-toolbar">
<Tooltip title={t('common.copy')} mouseEnterDelay={0.8}>
<ToolButton role="button" aria-label={t('common.copy')} onClick={handleCopyTable}>
{copied ? (
<Check size={14} style={{ color: 'var(--color-primary)' }} data-testid="check-icon" />
) : (
<Copy size={14} data-testid="copy-icon" />
)}
</ToolButton>
</Tooltip>
</ToolbarWrapper>
</TableWrapper>
)
}
/**
* Markdown
* @param blockId ID
* @param position
* @returns
*/
export function extractTableMarkdown(blockId: string, position: any): string {
if (!position || !blockId) return ''
const block = messageBlocksSelectors.selectById(store.getState(), blockId)
if (!block || !('content' in block) || typeof block.content !== 'string') return ''
const { start, end } = position
const lines = block.content.split('\n')
// 提取表格对应的行(行号从1开始,数组索引从0开始)
const tableLines = lines.slice(start.line - 1, end.line)
return tableLines.join('\n').trim()
}
const TableWrapper = styled.div`
position: relative;
.table-toolbar {
border-radius: 4px;
opacity: 0;
transition: opacity 0.2s ease;
transform: translateZ(0);
will-change: opacity;
}
&:hover {
.table-toolbar {
opacity: 1;
}
}
`
const ToolbarWrapper = styled.div`
position: absolute;
top: 8px;
right: 8px;
z-index: 10;
`
const ToolButton = styled.div`
display: flex;
align-items: center;
justify-content: center;
width: 24px;
height: 24px;
border-radius: 4px;
cursor: pointer;
user-select: none;
transition: all 0.2s ease;
opacity: 1;
color: var(--color-text-3);
background-color: var(--color-background-mute);
will-change: background-color, opacity;
&:hover {
background-color: var(--color-background-soft);
}
`
export default memo(Table)
@@ -78,6 +78,18 @@ vi.mock('../Link', () => ({
)
}))
vi.mock('../Table', () => ({
__esModule: true,
default: ({ children, blockId }: any) => (
<div data-testid="table-component" data-block-id={blockId}>
<table>{children}</table>
<button type="button" data-testid="copy-table-button">
Copy Table
</button>
</div>
)
}))
vi.mock('@renderer/components/MarkdownShadowDOMRenderer', () => ({
__esModule: true,
default: ({ children }: any) => <div data-testid="shadow-dom">{children}</div>
@@ -104,6 +116,11 @@ vi.mock('react-markdown', () => ({
{components.code({ children: 'test code', node: { position: { start: { line: 1 } } } })}
</div>
)}
{components?.table && (
<div data-testid="has-table-component">
{components.table({ children: 'test table', node: { position: { start: { line: 1 } } } })}
</div>
)}
{components?.img && <span data-testid="has-img-component">img</span>}
{components?.style && <span data-testid="has-style-component">style</span>}
</div>
@@ -300,6 +317,16 @@ describe('Markdown', () => {
})
})
it('should integrate Table component with copy functionality', () => {
const block = createMainTextBlock({ id: 'test-block-456' })
render(<Markdown block={block} />)
expect(screen.getByTestId('has-table-component')).toBeInTheDocument()
const tableComponent = screen.getByTestId('table-component')
expect(tableComponent).toHaveAttribute('data-block-id', 'test-block-456')
})
it('should integrate ImagePreview component', () => {
render(<Markdown block={createMainTextBlock()} />)
@@ -0,0 +1,316 @@
import { act, render, screen, waitFor } from '@testing-library/react'
import userEvent from '@testing-library/user-event'
import { afterAll, afterEach, beforeAll, beforeEach, describe, expect, it, vi } from 'vitest'
import Table, { extractTableMarkdown } from '../Table'
const mocks = vi.hoisted(() => {
return {
store: {
getState: vi.fn()
},
messageBlocksSelectors: {
selectById: vi.fn()
},
windowMessage: {
error: vi.fn()
}
}
})
// Mock dependencies
vi.mock('@renderer/store', () => ({
__esModule: true,
default: mocks.store
}))
vi.mock('@renderer/store/messageBlock', () => ({
messageBlocksSelectors: mocks.messageBlocksSelectors
}))
vi.mock('react-i18next', () => ({
useTranslation: () => ({
t: (key: string) => key
})
}))
vi.mock('antd', () => ({
Tooltip: ({ children, title }: any) => (
<div data-testid="tooltip" title={title}>
{children}
</div>
)
}))
Object.assign(window, {
message: mocks.windowMessage
})
describe('Table', () => {
beforeAll(() => {
vi.stubGlobal('jest', {
advanceTimersByTime: vi.advanceTimersByTime.bind(vi)
})
})
beforeEach(() => {
vi.clearAllMocks()
vi.useFakeTimers()
})
afterEach(() => {
vi.restoreAllMocks()
vi.runOnlyPendingTimers()
vi.useRealTimers()
})
afterAll(() => {
vi.unstubAllGlobals()
})
// https://testing-library.com/docs/user-event/clipboard/
const user = userEvent.setup({
advanceTimers: vi.advanceTimersByTime.bind(vi),
writeToClipboard: true
})
// Test data factories
const createMockBlock = (content: string = defaultTableContent) => ({
id: 'test-block-1',
content
})
const createTablePosition = (startLine = 1, endLine = 3) => ({
start: { line: startLine },
end: { line: endLine }
})
const defaultTableContent = `| Header 1 | Header 2 |
|----------|----------|
| Cell 1 | Cell 2 |`
const defaultProps = {
children: (
<tbody>
<tr>
<td>Cell 1</td>
<td>Cell 2</td>
</tr>
</tbody>
),
blockId: 'test-block-1',
node: { position: createTablePosition() }
}
const getCopyButton = () => screen.getByRole('button', { name: /common\.copy/i })
const getCopyIcon = () => screen.getByTestId('copy-icon')
const getCheckIcon = () => screen.getByTestId('check-icon')
const queryCheckIcon = () => screen.queryByTestId('check-icon')
const queryCopyIcon = () => screen.queryByTestId('copy-icon')
describe('rendering', () => {
it('should render table with children and toolbar', () => {
render(<Table {...defaultProps} />)
expect(screen.getByRole('table')).toBeInTheDocument()
expect(screen.getByText('Cell 1')).toBeInTheDocument()
expect(screen.getByText('Cell 2')).toBeInTheDocument()
expect(screen.getByTestId('tooltip')).toBeInTheDocument()
})
it('should render with table-wrapper and table-toolbar classes', () => {
const { container } = render(<Table {...defaultProps} />)
expect(container.querySelector('.table-wrapper')).toBeInTheDocument()
expect(container.querySelector('.table-toolbar')).toBeInTheDocument()
})
it('should render copy button with correct tooltip', () => {
render(<Table {...defaultProps} />)
const tooltip = screen.getByTestId('tooltip')
expect(tooltip).toHaveAttribute('title', 'common.copy')
})
it('should match snapshot', () => {
const { container } = render(<Table {...defaultProps} />)
expect(container.firstChild).toMatchSnapshot()
})
})
describe('extractTableMarkdown', () => {
beforeEach(() => {
mocks.store.getState.mockReturnValue({})
})
it('should extract table content from specified line range', () => {
const block = createMockBlock()
const position = createTablePosition(1, 3)
mocks.messageBlocksSelectors.selectById.mockReturnValue(block)
const result = extractTableMarkdown('test-block-1', position)
expect(result).toBe(defaultTableContent)
expect(mocks.messageBlocksSelectors.selectById).toHaveBeenCalledWith({}, 'test-block-1')
})
it('should handle line range extraction correctly', () => {
const multiLineContent = `Line 0
| Header 1 | Header 2 |
|----------|----------|
| Cell 1 | Cell 2 |
Line 4`
const block = createMockBlock(multiLineContent)
const position = createTablePosition(2, 4) // Extract lines 2-4 (table part)
mocks.messageBlocksSelectors.selectById.mockReturnValue(block)
const result = extractTableMarkdown('test-block-1', position)
expect(result).toBe(`| Header 1 | Header 2 |
|----------|----------|
| Cell 1 | Cell 2 |`)
})
it('should return empty string when blockId is empty', () => {
const result = extractTableMarkdown('', createTablePosition())
expect(result).toBe('')
expect(mocks.messageBlocksSelectors.selectById).not.toHaveBeenCalled()
})
it('should return empty string when position is null', () => {
const result = extractTableMarkdown('test-block-1', null)
expect(result).toBe('')
expect(mocks.messageBlocksSelectors.selectById).not.toHaveBeenCalled()
})
it('should return empty string when position is undefined', () => {
const result = extractTableMarkdown('test-block-1', undefined)
expect(result).toBe('')
expect(mocks.messageBlocksSelectors.selectById).not.toHaveBeenCalled()
})
it('should return empty string when block does not exist', () => {
mocks.messageBlocksSelectors.selectById.mockReturnValue(null)
const result = extractTableMarkdown('non-existent-block', createTablePosition())
expect(result).toBe('')
})
it('should return empty string when block has no content property', () => {
const blockWithoutContent = { id: 'test-block-1' }
mocks.messageBlocksSelectors.selectById.mockReturnValue(blockWithoutContent)
const result = extractTableMarkdown('test-block-1', createTablePosition())
expect(result).toBe('')
})
it('should return empty string when block content is not a string', () => {
const blockWithInvalidContent = { id: 'test-block-1', content: 123 }
mocks.messageBlocksSelectors.selectById.mockReturnValue(blockWithInvalidContent)
const result = extractTableMarkdown('test-block-1', createTablePosition())
expect(result).toBe('')
})
it('should handle boundary line numbers correctly', () => {
const block = createMockBlock('Line 1\nLine 2\nLine 3')
const position = createTablePosition(1, 3)
mocks.messageBlocksSelectors.selectById.mockReturnValue(block)
const result = extractTableMarkdown('test-block-1', position)
expect(result).toBe('Line 1\nLine 2\nLine 3')
})
})
describe('copy functionality', () => {
beforeEach(() => {
mocks.messageBlocksSelectors.selectById.mockReturnValue(createMockBlock())
})
it('should copy table content to clipboard on button click', async () => {
render(<Table {...defaultProps} />)
const copyButton = getCopyButton()
await user.click(copyButton)
await waitFor(() => {
expect(getCheckIcon()).toBeInTheDocument()
expect(queryCopyIcon()).not.toBeInTheDocument()
})
})
it('should show check icon after successful copy', async () => {
render(<Table {...defaultProps} />)
// Initially shows copy icon
expect(getCopyIcon()).toBeInTheDocument()
const copyButton = getCopyButton()
await user.click(copyButton)
await waitFor(() => {
expect(getCheckIcon()).toBeInTheDocument()
expect(queryCopyIcon()).not.toBeInTheDocument()
})
})
it('should reset to copy icon after 2 seconds', async () => {
render(<Table {...defaultProps} />)
const copyButton = getCopyButton()
await user.click(copyButton)
await waitFor(() => {
expect(getCheckIcon()).toBeInTheDocument()
})
// Fast forward 2 seconds
act(() => {
vi.advanceTimersByTime(2000)
})
await waitFor(() => {
expect(getCopyIcon()).toBeInTheDocument()
expect(queryCheckIcon()).not.toBeInTheDocument()
})
})
it('should not copy when extractTableMarkdown returns empty string', async () => {
mocks.messageBlocksSelectors.selectById.mockReturnValue(null)
render(<Table {...defaultProps} />)
const copyButton = getCopyButton()
await user.click(copyButton)
await waitFor(() => {
expect(getCopyIcon()).toBeInTheDocument()
expect(queryCheckIcon()).not.toBeInTheDocument()
})
})
})
describe('edge cases', () => {
it('should work without blockId', () => {
const propsWithoutBlockId = { ...defaultProps, blockId: undefined }
expect(() => render(<Table {...propsWithoutBlockId} />)).not.toThrow()
const copyButton = getCopyButton()
expect(copyButton).toBeInTheDocument()
})
it('should work without node position', () => {
const propsWithoutPosition = { ...defaultProps, node: undefined }
expect(() => render(<Table {...propsWithoutPosition} />)).not.toThrow()
const copyButton = getCopyButton()
expect(copyButton).toBeInTheDocument()
})
})
})
@@ -30,6 +30,24 @@ This is **bold** text.
</button>
</div>
</div>
<div
data-testid="has-table-component"
>
<div
data-block-id="test-block-1"
data-testid="table-component"
>
<table>
test table
</table>
<button
data-testid="copy-table-button"
type="button"
>
Copy Table
</button>
</div>
</div>
<span
data-testid="has-img-component"
>
@@ -0,0 +1,103 @@
// Vitest Snapshot v1, https://vitest.dev/guide/snapshot.html
exports[`Table > rendering > should match snapshot 1`] = `
.c0 {
position: relative;
}
.c0 .table-toolbar {
border-radius: 4px;
opacity: 0;
transition: opacity 0.2s ease;
transform: translateZ(0);
will-change: opacity;
}
.c0:hover .table-toolbar {
opacity: 1;
}
.c1 {
position: absolute;
top: 8px;
right: 8px;
z-index: 10;
}
.c2 {
display: flex;
align-items: center;
justify-content: center;
width: 24px;
height: 24px;
border-radius: 4px;
cursor: pointer;
user-select: none;
transition: all 0.2s ease;
opacity: 1;
color: var(--color-text-3);
background-color: var(--color-background-mute);
will-change: background-color,opacity;
}
.c2:hover {
background-color: var(--color-background-soft);
}
<div
class="c0 table-wrapper"
>
<table>
<tbody>
<tr>
<td>
Cell 1
</td>
<td>
Cell 2
</td>
</tr>
</tbody>
</table>
<div
class="c1 table-toolbar"
>
<div
data-testid="tooltip"
title="common.copy"
>
<div
aria-label="common.copy"
class="c2"
role="button"
>
<svg
class="lucide lucide-copy"
data-testid="copy-icon"
fill="none"
height="14"
stroke="currentColor"
stroke-linecap="round"
stroke-linejoin="round"
stroke-width="2"
viewBox="0 0 24 24"
width="14"
xmlns="http://www.w3.org/2000/svg"
>
<rect
height="14"
rx="2"
ry="2"
width="14"
x="8"
y="8"
/>
<path
d="M4 16c-1.1 0-2-.9-2-2V4c0-1.1.9-2 2-2h10c1.1 0 2 .9 2 2"
/>
</svg>
</div>
</div>
</div>
</div>
`;
@@ -40,7 +40,18 @@ function CitationBlock({ block }: { block: CitationMessageBlock }) {
__html:
(block.response?.results as GroundingMetadata)?.searchEntryPoint?.renderedContent
?.replace(/@media \(prefers-color-scheme: light\)/g, 'body[theme-mode="light"]')
.replace(/@media \(prefers-color-scheme: dark\)/g, 'body[theme-mode="dark"]') || ''
.replace(/@media \(prefers-color-scheme: dark\)/g, 'body[theme-mode="dark"]')
.replace(
/background-color\s*:\s*#[0-9a-fA-F]{3,6}\b|\bbackground-color\s*:\s*[a-zA-Z-]+\b/g,
'background-color: var(--color-background-soft)'
)
.replace(/\.gradient\s*{[^}]*background\s*:\s*[^};]+[;}]/g, (match) => {
// Remove the background property while preserving the rest
return match.replace(/background\s*:\s*[^};]+;?\s*/g, '')
})
.replace(/\.chip {\n/g, '.chip {\n background-color: var(--color-background)!important;\n')
.replace(/border-color\s*:\s*[^};]+;?\s*/g, '')
.replace(/border\s*:\s*[^};]+;?\s*/g, '') || ''
}}
/>
</>
@@ -1,6 +1,6 @@
import SvgSpinners180Ring from '@renderer/components/Icons/SvgSpinners180Ring'
import ImageViewer from '@renderer/components/ImageViewer'
import type { ImageMessageBlock } from '@renderer/types/newMessage'
import { type ImageMessageBlock, MessageBlockStatus } from '@renderer/types/newMessage'
import { Skeleton } from 'antd'
import React from 'react'
import styled from 'styled-components'
@@ -9,23 +9,26 @@ interface Props {
}
const ImageBlock: React.FC<Props> = ({ block }) => {
if (block.status !== 'success') return <SvgSpinners180Ring />
const images = block.metadata?.generateImageResponse?.images?.length
? block.metadata?.generateImageResponse?.images
: block?.file?.path
? [`file://${block?.file?.path}`]
: []
return (
<Container style={{ marginBottom: 8 }}>
{images.map((src, index) => (
<ImageViewer
src={src}
key={`image-${index}`}
style={{ maxWidth: 500, maxHeight: 500, padding: 5, borderRadius: 8 }}
/>
))}
</Container>
)
if (block.status === MessageBlockStatus.STREAMING || block.status === MessageBlockStatus.PROCESSING)
return <Skeleton.Image active style={{ width: 200, height: 200 }} />
if (block.status === MessageBlockStatus.SUCCESS) {
const images = block.metadata?.generateImageResponse?.images?.length
? block.metadata?.generateImageResponse?.images
: block?.file?.path
? [`file://${block?.file?.path}`]
: []
return (
<Container style={{ marginBottom: 8 }}>
{images.map((src, index) => (
<ImageViewer
src={src}
key={`image-${index}`}
style={{ maxWidth: 500, maxHeight: 500, padding: 5, borderRadius: 8 }}
/>
))}
</Container>
)
} else return null
}
const Container = styled.div`
display: flex;
@@ -33,5 +36,4 @@ const Container = styled.div`
gap: 10px;
margin-top: 8px;
`
export default React.memo(ImageBlock)
@@ -164,15 +164,14 @@ export default React.memo(MessageBlockRenderer)
const ImageBlockGroup = styled.div`
display: grid;
grid-template-columns: repeat(3, minmax(200px, 1fr));
grid-template-columns: repeat(auto-fit, minmax(200px, 1fr));
gap: 8px;
width: 100%;
max-width: 960px;
> * {
/* > * {
min-width: 200px;
}
} */
@media (min-width: 1536px) {
grid-template-columns: repeat(4, minmax(250px, 1fr));
grid-template-columns: repeat(auto-fit, minmax(250px, 1fr));
max-width: 1280px;
> * {
min-width: 250px;
@@ -24,7 +24,8 @@ const EXCLUDED_SELECTORS = [
'.ant-collapse-header',
'.group-menu-bar',
'.code-block',
'.message-editor'
'.message-editor',
'.table-wrapper'
]
// Gap between the navigation bar and the right element
@@ -53,15 +53,17 @@ const MessgeTokens: React.FC<MessageTokensProps> = ({ message }) => {
)
return (
<MessageMetadata className="message-tokens" onClick={locateMessage}>
{hasMetrics ? (
<Popover content={metrixs} placement="top" trigger="hover" styles={{ root: { fontSize: 11 } }}>
{showTokens && tokensInfo}
</Popover>
) : (
tokensInfo
)}
</MessageMetadata>
showTokens && (
<MessageMetadata className="message-tokens" onClick={locateMessage}>
{hasMetrics ? (
<Popover content={metrixs} placement="top" trigger="hover" styles={{ root: { fontSize: 11 } }}>
{tokensInfo}
</Popover>
) : (
tokensInfo
)}
</MessageMetadata>
)
)
}
+80 -10
View File
@@ -18,7 +18,7 @@ import { isMac } from '@renderer/config/constant'
import { useAssistant, useAssistants } from '@renderer/hooks/useAssistant'
import { modelGenerating } from '@renderer/hooks/useRuntime'
import { useSettings } from '@renderer/hooks/useSettings'
import { TopicManager } from '@renderer/hooks/useTopic'
import { finishTopicRenaming, startTopicRenaming, TopicManager } from '@renderer/hooks/useTopic'
import { fetchMessagesSummary } from '@renderer/services/ApiService'
import { EVENT_NAMES, EventEmitter } from '@renderer/services/EventService'
import store from '@renderer/store'
@@ -57,6 +57,9 @@ const Topics: FC<Props> = ({ assistant: _assistant, activeTopic, setActiveTopic
const { t } = useTranslation()
const { showTopicTime, pinTopicsToTop, setTopicPosition } = useSettings()
const renamingTopics = useSelector((state: RootState) => state.runtime.chat.renamingTopics)
const newlyRenamedTopics = useSelector((state: RootState) => state.runtime.chat.newlyRenamedTopics)
const borderRadius = showTopicTime ? 12 : 'var(--list-item-border-radius)'
const [deletingTopicId, setDeletingTopicId] = useState<string | null>(null)
@@ -84,6 +87,20 @@ const Topics: FC<Props> = ({ assistant: _assistant, activeTopic, setActiveTopic
[activeTopic.id, pendingTopics]
)
const isRenaming = useCallback(
(topicId: string) => {
return renamingTopics.includes(topicId)
},
[renamingTopics]
)
const isNewlyRenamed = useCallback(
(topicId: string) => {
return newlyRenamedTopics.includes(topicId)
},
[newlyRenamedTopics]
)
const handleDeleteClick = useCallback((topicId: string, e: React.MouseEvent) => {
e.stopPropagation()
@@ -170,16 +187,21 @@ const Topics: FC<Props> = ({ assistant: _assistant, activeTopic, setActiveTopic
label: t('chat.topics.auto_rename'),
key: 'auto-rename',
icon: <i className="iconfont icon-business-smart-assistant" style={{ fontSize: '14px' }} />,
disabled: isRenaming(topic.id),
async onClick() {
const messages = await TopicManager.getTopicMessages(topic.id)
if (messages.length >= 2) {
const summaryText = await fetchMessagesSummary({ messages, assistant })
if (summaryText) {
const updatedTopic = { ...topic, name: summaryText, isNameManuallyEdited: false }
updateTopic(updatedTopic)
topic.id === activeTopic.id && setActiveTopic(updatedTopic)
} else {
window.message?.error(t('message.error.fetchTopicName'))
startTopicRenaming(topic.id)
try {
const summaryText = await fetchMessagesSummary({ messages, assistant })
if (summaryText) {
const updatedTopic = { ...topic, name: summaryText, isNameManuallyEdited: false }
updateTopic(updatedTopic)
} else {
window.message?.error(t('message.error.fetchTopicName'))
}
} finally {
finishTopicRenaming(topic.id)
}
}
}
@@ -188,6 +210,7 @@ const Topics: FC<Props> = ({ assistant: _assistant, activeTopic, setActiveTopic
label: t('chat.topics.edit.title'),
key: 'rename',
icon: <EditOutlined />,
disabled: isRenaming(topic.id),
async onClick() {
const name = await PromptPopup.show({
title: t('chat.topics.edit.title'),
@@ -197,7 +220,6 @@ const Topics: FC<Props> = ({ assistant: _assistant, activeTopic, setActiveTopic
if (name && topic?.name !== name) {
const updatedTopic = { ...topic, name, isNameManuallyEdited: true }
updateTopic(updatedTopic)
topic.id === activeTopic.id && setActiveTopic(updatedTopic)
}
}
},
@@ -388,6 +410,7 @@ const Topics: FC<Props> = ({ assistant: _assistant, activeTopic, setActiveTopic
}, [
targetTopic,
t,
isRenaming,
exportMenuOptions.image,
exportMenuOptions.markdown,
exportMenuOptions.markdown_reason,
@@ -430,6 +453,13 @@ const Topics: FC<Props> = ({ assistant: _assistant, activeTopic, setActiveTopic
const topicName = topic.name.replace('`', '')
const topicPrompt = topic.prompt
const fullTopicPrompt = t('common.prompt') + ': ' + topicPrompt
const getTopicNameClassName = () => {
if (isRenaming(topic.id)) return 'shimmer'
if (isNewlyRenamed(topic.id)) return 'typing'
return ''
}
return (
<TopicListItem
onContextMenu={() => setTargetTopic(topic)}
@@ -438,7 +468,7 @@ const Topics: FC<Props> = ({ assistant: _assistant, activeTopic, setActiveTopic
style={{ borderRadius }}>
{isPending(topic.id) && !isActive && <PendingIndicator />}
<TopicNameContainer>
<TopicName className="name" title={topicName}>
<TopicName className={getTopicNameClassName()} title={topicName}>
{topicName}
</TopicName>
{isActive && !topic.pinned && (
@@ -544,6 +574,46 @@ const TopicName = styled.div`
-webkit-box-orient: vertical;
overflow: hidden;
font-size: 13px;
position: relative;
will-change: background-position, width;
--color-shimmer-mid: var(--color-text-1);
--color-shimmer-end: color-mix(in srgb, var(--color-text-1) 25%, transparent);
&.shimmer {
background: linear-gradient(to left, var(--color-shimmer-end), var(--color-shimmer-mid), var(--color-shimmer-end));
background-size: 200% 100%;
background-clip: text;
color: transparent;
animation: shimmer 3s linear infinite;
}
&.typing {
display: block;
-webkit-line-clamp: unset;
-webkit-box-orient: unset;
white-space: nowrap;
overflow: hidden;
animation: typewriter 0.5s steps(40, end);
}
@keyframes shimmer {
0% {
background-position: 200% 0;
}
100% {
background-position: -200% 0;
}
}
@keyframes typewriter {
from {
width: 0;
}
to {
width: 100%;
}
}
`
const PendingIndicator = styled.div.attrs({
@@ -1,124 +0,0 @@
import SvgSpinners180Ring from '@renderer/components/Icons/SvgSpinners180Ring'
import { fetchSuggestions } from '@renderer/services/ApiService'
import { getUserMessage } from '@renderer/services/MessagesService'
import { useAppDispatch } from '@renderer/store'
import { sendMessage } from '@renderer/store/thunk/messageThunk'
import { Assistant, Suggestion } from '@renderer/types'
import type { Message } from '@renderer/types/newMessage'
import { last } from 'lodash'
import { FC, memo, useEffect, useState } from 'react'
import styled from 'styled-components'
interface Props {
assistant: Assistant
messages: Message[]
}
const suggestionsMap = new Map<string, Suggestion[]>()
const Suggestions: FC<Props> = ({ assistant, messages }) => {
const dispatch = useAppDispatch()
const [suggestions, setSuggestions] = useState<Suggestion[]>(
suggestionsMap.get(messages[messages.length - 1]?.id) || []
)
const [loadingSuggestions, setLoadingSuggestions] = useState(false)
const handleSuggestionClick = async (content: string) => {
const { message: userMessage, blocks } = getUserMessage({
assistant,
topic: assistant.topics[0],
content
})
await dispatch(sendMessage(userMessage, blocks, assistant, assistant.topics[0].id))
}
const suggestionsHandle = async () => {
if (loadingSuggestions) return
try {
setLoadingSuggestions(true)
const _suggestions = await fetchSuggestions({
assistant,
messages
})
if (_suggestions.length) {
setSuggestions(_suggestions)
suggestionsMap.set(messages[messages.length - 1].id, _suggestions)
}
} finally {
setLoadingSuggestions(false)
}
}
useEffect(() => {
suggestionsHandle()
// eslint-disable-next-line react-hooks/exhaustive-deps
}, [])
useEffect(() => {
setSuggestions(suggestionsMap.get(messages[messages.length - 1]?.id) || [])
}, [messages])
if (last(messages)?.status !== 'success') {
return null
}
if (loadingSuggestions) {
return (
<Container>
<SvgSpinners180Ring color="var(--color-text-2)" />
</Container>
)
}
if (suggestions.length === 0) {
return null
}
return (
<Container>
<SuggestionsContainer>
{suggestions.map((s, i) => (
<SuggestionItem key={i} onClick={() => handleSuggestionClick(s.content)}>
{s.content}
</SuggestionItem>
))}
</SuggestionsContainer>
</Container>
)
}
const Container = styled.div`
display: flex;
flex-direction: column;
padding: 10px 10px 20px 65px;
display: flex;
width: 100%;
flex-direction: row;
flex-wrap: wrap;
gap: 15px;
`
const SuggestionsContainer = styled.div`
display: flex;
flex-direction: row;
flex-wrap: wrap;
gap: 10px;
`
const SuggestionItem = styled.div`
display: flex;
align-items: center;
width: fit-content;
padding: 5px 10px;
border-radius: 12px;
font-size: 12px;
color: var(--color-text);
background: var(--color-background-mute);
cursor: pointer;
&:hover {
opacity: 0.9;
}
`
export default memo(Suggestions)
@@ -1,4 +1,5 @@
import { InfoCircleOutlined, SettingOutlined, WarningOutlined } from '@ant-design/icons'
import AiProvider from '@renderer/aiCore'
import { TopView } from '@renderer/components/TopView'
import { DEFAULT_KNOWLEDGE_DOCUMENT_COUNT } from '@renderer/config/constant'
import { getEmbeddingMaxContext } from '@renderer/config/embedings'
@@ -9,7 +10,6 @@ import { useKnowledgeBases } from '@renderer/hooks/useKnowledge'
import { useOcrProviders } from '@renderer/hooks/useOcr'
import { usePreprocessProviders } from '@renderer/hooks/usePreprocess'
import { useProviders } from '@renderer/hooks/useProvider'
import AiProvider from '@renderer/providers/AiProvider'
import { getKnowledgeBaseParams } from '@renderer/services/KnowledgeService'
import { getModelUniqId } from '@renderer/services/ModelService'
import { KnowledgeBase, Model, OcrProvider, PreprocessProvider } from '@renderer/types'
@@ -426,48 +426,6 @@ const PopupContainer: React.FC<Props> = ({ title, resolve }) => {
destroyOnClose
centered
okButtonProps={{ loading }}>
{/* <Form form={form} layout="vertical">
<Form.Item
name="name"
label={t('common.name')}
rules={[{ required: true, message: t('message.error.enter.name') }]}>
<Input placeholder={t('common.name')} ref={nameInputRef} />
</Form.Item>
<Form.Item
name="model"
label={t('models.embedding_model')}
tooltip={{ title: t('models.embedding_model_tooltip'), placement: 'right' }}
rules={[{ required: true, message: t('message.error.enter.model') }]}>
<Select style={{ width: '100%' }} options={embeddingSelectOptions} placeholder={t('settings.models.empty')} />
</Form.Item>
<Form.Item
name="rerankModel"
label={t('models.rerank_model')}
tooltip={{ title: t('models.rerank_model_tooltip'), placement: 'right' }}
rules={[{ required: false, message: t('message.error.enter.model') }]}>
<Select style={{ width: '100%' }} options={rerankSelectOptions} placeholder={t('settings.models.empty')} />
</Form.Item>
<SettingHelpText style={{ marginTop: -15, marginBottom: 20 }}>
{t('models.rerank_model_not_support_provider', {
provider: NOT_SUPPORTED_REANK_PROVIDERS.map((id) => t(`provider.${id}`))
})}
</SettingHelpText>
<Form.Item
name="documentCount"
label={t('knowledge.document_count')}
initialValue={DEFAULT_KNOWLEDGE_DOCUMENT_COUNT} // 设置初始值
tooltip={{ title: t('knowledge.document_count_help') }}>
<Slider
style={{ width: '100%' }}
min={1}
max={30}
step={1}
marks={{ 1: '1', 6: t('knowledge.document_count_default'), 30: '30' }}
/>
</Form.Item>
</Form> */}
<div>
<Tabs style={{ minHeight: '50vh' }} defaultActiveKey="1" tabPosition={'left'} items={settingItems} />
</div>
@@ -11,7 +11,7 @@ import { usePaintings } from '@renderer/hooks/usePaintings'
import { useAllProviders } from '@renderer/hooks/useProvider'
import { useRuntime } from '@renderer/hooks/useRuntime'
import { useSettings } from '@renderer/hooks/useSettings'
import AiProvider from '@renderer/providers/AiProvider'
import AiProvider from '@renderer/aiCore'
import FileManager from '@renderer/services/FileManager'
import { translateText } from '@renderer/services/TranslateService'
import { useAppDispatch } from '@renderer/store'
@@ -182,11 +182,9 @@ const AihubmixPage: FC<{ Options: string[] }> = ({ Options }) => {
const base64s = await AI.generateImage({
prompt,
model: painting.model,
config: {
aspectRatio: painting.aspectRatio?.replace('ASPECT_', '').replace('_', ':'),
numberOfImages: painting.model.startsWith('imagen-4.0-ultra-generate-exp') ? 1 : painting.numberOfImages,
personGeneration: painting.personGeneration
}
imageSize: painting.aspectRatio?.replace('ASPECT_', '').replace('_', ':') || '1:1',
batchSize: painting.model.startsWith('imagen-4.0-ultra-generate-exp') ? 1 : painting.numberOfImages || 1,
personGeneration: painting.personGeneration
})
if (base64s?.length > 0) {
const validFiles = await Promise.all(
@@ -16,7 +16,7 @@ import { usePaintings } from '@renderer/hooks/usePaintings'
import { useAllProviders } from '@renderer/hooks/useProvider'
import { useRuntime } from '@renderer/hooks/useRuntime'
import { useSettings } from '@renderer/hooks/useSettings'
import AiProvider from '@renderer/providers/AiProvider'
import AiProvider from '@renderer/aiCore'
import { getProviderByModel } from '@renderer/services/AssistantService'
import FileManager from '@renderer/services/FileManager'
import { translateText } from '@renderer/services/TranslateService'
@@ -51,8 +51,8 @@ const PopupContainer: React.FC<Props> = ({ title, provider, model, apiKeys, type
try {
let valid = false
if (type === 'provider' && model) {
const result = await checkApi({ ...(provider as Provider), apiKey: status.key }, model)
valid = result.valid
await checkApi({ ...(provider as Provider), apiKey: status.key }, model)
valid = true
} else {
const result = await WebSearchService.checkSearch({
...(provider as WebSearchProvider),
@@ -65,7 +65,7 @@ const PopupContainer: React.FC<Props> = ({ title, provider, model, apiKeys, type
setKeyStatuses((prev) => prev.map((s, idx) => (idx === i ? { ...s, checking: false, isValid: valid } : s)))
return { index: i, valid }
} catch (error) {
} catch (error: unknown) {
// 处理错误情况
setKeyStatuses((prev) => prev.map((s, idx) => (idx === i ? { ...s, checking: false, isValid: false } : s)))
return { index: i, valid: false }
@@ -90,8 +90,8 @@ const PopupContainer: React.FC<Props> = ({ title, provider, model, apiKeys, type
try {
let valid = false
if (type === 'provider' && model) {
const result = await checkApi({ ...(provider as Provider), apiKey: keyStatuses[keyIndex].key }, model)
valid = result.valid
await checkApi({ ...(provider as Provider), apiKey: keyStatuses[keyIndex].key }, model)
valid = true
} else {
const result = await WebSearchService.checkSearch({
...(provider as WebSearchProvider),
@@ -103,7 +103,7 @@ const PopupContainer: React.FC<Props> = ({ title, provider, model, apiKeys, type
setKeyStatuses((prev) =>
prev.map((status, idx) => (idx === keyIndex ? { ...status, checking: false, isValid: valid } : status))
)
} catch (error) {
} catch (error: unknown) {
setKeyStatuses((prev) =>
prev.map((status, idx) => (idx === keyIndex ? { ...status, checking: false, isValid: false } : status))
)
@@ -145,14 +145,17 @@ const PopupContainer: React.FC<Props> = ({ provider: _provider, resolve }) => {
setListModels(
models
.map((model) => ({
id: model.id,
// @ts-ignore modelId
id: model?.id || model?.name,
// @ts-ignore name
name: model.name || model.id,
name: model?.display_name || model?.displayName || model?.name || model?.id,
provider: _provider.id,
group: getDefaultGroupName(model.id, _provider.id),
// @ts-ignore name
description: model?.description,
owned_by: model?.owned_by
// @ts-ignore group
group: getDefaultGroupName(model?.id || model?.name, _provider.id),
// @ts-ignore description
description: model?.description || '',
// @ts-ignore owned_by
owned_by: model?.owned_by || ''
}))
.filter((model) => !isEmpty(model.name))
)
@@ -7,7 +7,7 @@ import { PROVIDER_CONFIG } from '@renderer/config/providers'
import { useTheme } from '@renderer/context/ThemeProvider'
import { useAllProviders, useProvider, useProviders } from '@renderer/hooks/useProvider'
import i18n from '@renderer/i18n'
import { isOpenAIProvider } from '@renderer/providers/AiProvider/ProviderFactory'
import { isOpenAIProvider } from '@renderer/aiCore/clients/ApiClientFactory'
import { checkApi, formatApiKeys } from '@renderer/services/ApiService'
import { checkModelsHealth, getModelCheckSummary } from '@renderer/services/HealthCheckService'
import { isProviderSupportAuth } from '@renderer/services/ProviderService'
@@ -231,22 +231,32 @@ const ProviderSetting: FC<Props> = ({ provider: _provider }) => {
} else {
setApiChecking(true)
const { valid, error } = await checkApi({ ...provider, apiKey, apiHost }, model)
try {
await checkApi({ ...provider, apiKey, apiHost }, model)
const errorMessage = error && error?.message ? ' ' + error?.message : ''
window.message.success({
key: 'api-check',
style: { marginTop: '3vh' },
duration: 2,
content: i18n.t('message.api.connection.success')
})
window.message[valid ? 'success' : 'error']({
key: 'api-check',
style: { marginTop: '3vh' },
duration: valid ? 2 : 8,
content: valid
? i18n.t('message.api.connection.success')
: i18n.t('message.api.connection.failed') + errorMessage
})
setApiValid(true)
setTimeout(() => setApiValid(false), 3000)
} catch (error: any) {
const errorMessage = error?.message ? ' ' + error.message : ''
setApiValid(valid)
setApiChecking(false)
setTimeout(() => setApiValid(false), 3000)
window.message.error({
key: 'api-check',
style: { marginTop: '3vh' },
duration: 8,
content: i18n.t('message.api.connection.failed') + errorMessage
})
setApiValid(false)
} finally {
setApiChecking(false)
}
}
}
@@ -1,117 +0,0 @@
import { isOpenAILLMModel } from '@renderer/config/models'
import { getDefaultModel } from '@renderer/services/AssistantService'
import { Assistant, MCPCallToolResponse, MCPTool, MCPToolResponse, Model, Provider, Suggestion } from '@renderer/types'
import { Message } from '@renderer/types/newMessage'
import OpenAI from 'openai'
import { CompletionsParams } from '.'
import AnthropicProvider from './AnthropicProvider'
import BaseProvider from './BaseProvider'
import GeminiProvider from './GeminiProvider'
import OpenAIProvider from './OpenAIProvider'
import OpenAIResponseProvider from './OpenAIResponseProvider'
/**
* AihubmixProvider -
* 使
*/
export default class AihubmixProvider extends BaseProvider {
private providers: Map<string, BaseProvider> = new Map()
private defaultProvider: BaseProvider
private currentProvider: BaseProvider
constructor(provider: Provider) {
super(provider)
// 初始化各个提供商
this.providers.set('claude', new AnthropicProvider(provider))
this.providers.set('gemini', new GeminiProvider({ ...provider, apiHost: 'https://aihubmix.com/gemini' }))
this.providers.set('openai', new OpenAIResponseProvider(provider))
this.providers.set('default', new OpenAIProvider(provider))
// 设置默认提供商
this.defaultProvider = this.providers.get('default')!
this.currentProvider = this.defaultProvider
}
/**
*
*/
private getProvider(model: Model): BaseProvider {
const id = model.id.toLowerCase()
// claude开头
if (id.startsWith('claude')) {
return this.providers.get('claude')!
}
// gemini开头 或 imagen开头 且不以-nothink、-search结尾
if ((id.startsWith('gemini') || id.startsWith('imagen')) && !id.endsWith('-nothink') && !id.endsWith('-search')) {
return this.providers.get('gemini')!
}
if (isOpenAILLMModel(model)) {
return this.providers.get('openai')!
}
return this.defaultProvider
}
// 直接使用默认提供商的方法
public async models(): Promise<OpenAI.Models.Model[]> {
return this.defaultProvider.models()
}
public async generateText(params: { prompt: string; content: string }): Promise<string> {
return this.defaultProvider.generateText(params)
}
public async generateImage(params: any): Promise<string[]> {
return this.getProvider({
id: params.model
} as unknown as Model).generateImage(params)
}
public async generateImageByChat(params: any): Promise<void> {
return this.defaultProvider.generateImageByChat(params)
}
public async completions(params: CompletionsParams): Promise<void> {
const model = params.assistant.model
this.currentProvider = this.getProvider(model!)
return this.currentProvider.completions(params)
}
public async translate(
content: string,
assistant: Assistant,
onResponse?: (text: string, isComplete: boolean) => void
): Promise<string> {
return this.getProvider(assistant.model || getDefaultModel()).translate(content, assistant, onResponse)
}
public async summaries(messages: Message[], assistant: Assistant): Promise<string> {
return this.getProvider(assistant.model || getDefaultModel()).summaries(messages, assistant)
}
public async summaryForSearch(messages: Message[], assistant: Assistant): Promise<string | null> {
return this.getProvider(assistant.model || getDefaultModel()).summaryForSearch(messages, assistant)
}
public async suggestions(messages: Message[], assistant: Assistant): Promise<Suggestion[]> {
return this.getProvider(assistant.model || getDefaultModel()).suggestions(messages, assistant)
}
public async check(model: Model, stream: boolean = false): Promise<{ valid: boolean; error: Error | null }> {
return this.getProvider(model).check(model, stream)
}
public async getEmbeddingDimensions(model: Model): Promise<number> {
return this.getProvider(model).getEmbeddingDimensions(model)
}
public convertMcpTools<T>(mcpTools: MCPTool[]) {
return this.currentProvider.convertMcpTools(mcpTools) as T[]
}
public mcpToolCallResponseToMessage(mcpToolResponse: MCPToolResponse, resp: MCPCallToolResponse, model: Model) {
return this.currentProvider.mcpToolCallResponseToMessage(mcpToolResponse, resp, model)
}
}
@@ -1,802 +0,0 @@
import Anthropic from '@anthropic-ai/sdk'
import {
Base64ImageSource,
ImageBlockParam,
MessageCreateParamsNonStreaming,
MessageParam,
TextBlockParam,
ToolResultBlockParam,
ToolUnion,
ToolUseBlock,
WebSearchResultBlock,
WebSearchTool20250305,
WebSearchToolResultError
} from '@anthropic-ai/sdk/resources'
import { DEFAULT_MAX_TOKENS } from '@renderer/config/constant'
import { findTokenLimit, isClaudeReasoningModel, isReasoningModel, isWebSearchModel } from '@renderer/config/models'
import { getStoreSetting } from '@renderer/hooks/useSettings'
import i18n from '@renderer/i18n'
import { getAssistantSettings, getDefaultModel, getTopNamingModel } from '@renderer/services/AssistantService'
import FileManager from '@renderer/services/FileManager'
import {
filterContextMessages,
filterEmptyMessages,
filterUserRoleStartMessages
} from '@renderer/services/MessagesService'
import {
Assistant,
EFFORT_RATIO,
FileTypes,
MCPCallToolResponse,
MCPTool,
MCPToolResponse,
Metrics,
Model,
Provider,
Suggestion,
ToolCallResponse,
Usage,
WebSearchSource
} from '@renderer/types'
import { ChunkType } from '@renderer/types/chunk'
import type { Message } from '@renderer/types/newMessage'
import { removeSpecialCharactersForTopicName } from '@renderer/utils'
import {
anthropicToolUseToMcpTool,
isEnabledToolUse,
mcpToolCallResponseToAnthropicMessage,
mcpToolsToAnthropicTools,
parseAndCallTools
} from '@renderer/utils/mcp-tools'
import { findFileBlocks, findImageBlocks, getMainTextContent } from '@renderer/utils/messageUtils/find'
import { buildSystemPrompt } from '@renderer/utils/prompt'
import { first, flatten, takeRight } from 'lodash'
import OpenAI from 'openai'
import { CompletionsParams } from '.'
import BaseProvider from './BaseProvider'
interface ReasoningConfig {
type: 'enabled' | 'disabled'
budget_tokens?: number
}
export default class AnthropicProvider extends BaseProvider {
private sdk: Anthropic
constructor(provider: Provider) {
super(provider)
this.sdk = new Anthropic({
apiKey: this.apiKey,
baseURL: this.getBaseURL(),
dangerouslyAllowBrowser: true,
defaultHeaders: {
'anthropic-beta': 'output-128k-2025-02-19'
}
})
}
public getBaseURL(): string {
return this.provider.apiHost
}
/**
* Get the message parameter
* @param message - The message
* @returns The message parameter
*/
private async getMessageParam(message: Message): Promise<MessageParam> {
const parts: MessageParam['content'] = [
{
type: 'text',
text: getMainTextContent(message)
}
]
// Get and process image blocks
const imageBlocks = findImageBlocks(message)
for (const imageBlock of imageBlocks) {
if (imageBlock.file) {
// Handle uploaded file
const file = imageBlock.file
const base64Data = await window.api.file.base64Image(file.id + file.ext)
parts.push({
type: 'image',
source: {
data: base64Data.base64,
media_type: base64Data.mime.replace('jpg', 'jpeg') as any,
type: 'base64'
}
})
}
}
// Get and process file blocks
const fileBlocks = findFileBlocks(message)
for (const fileBlock of fileBlocks) {
const { file } = fileBlock
if ([FileTypes.TEXT, FileTypes.DOCUMENT].includes(file.type)) {
if (file.ext === '.pdf' && file.size < 32 * 1024 * 1024) {
const base64Data = await FileManager.readBase64File(file)
parts.push({
type: 'document',
source: {
type: 'base64',
media_type: 'application/pdf',
data: base64Data
}
})
} else {
const fileContent = await (await window.api.file.read(file.id + file.ext)).trim()
parts.push({
type: 'text',
text: file.origin_name + '\n' + fileContent
})
}
}
}
return {
role: message.role === 'system' ? 'user' : message.role,
content: parts
}
}
private async getWebSearchParams(model: Model): Promise<WebSearchTool20250305 | undefined> {
if (!isWebSearchModel(model)) {
return undefined
}
return {
type: 'web_search_20250305',
name: 'web_search',
max_uses: 5
} as WebSearchTool20250305
}
override getTemperature(assistant: Assistant, model: Model): number | undefined {
if (assistant.settings?.reasoning_effort && isClaudeReasoningModel(model)) {
return undefined
}
return assistant.settings?.temperature
}
override getTopP(assistant: Assistant, model: Model): number | undefined {
if (assistant.settings?.reasoning_effort && isClaudeReasoningModel(model)) {
return undefined
}
return assistant.settings?.topP
}
/**
* Get the reasoning effort
* @param assistant - The assistant
* @param model - The model
* @returns The reasoning effort
*/
private getBudgetToken(assistant: Assistant, model: Model): ReasoningConfig | undefined {
if (!isReasoningModel(model)) {
return undefined
}
const { maxTokens } = getAssistantSettings(assistant)
const reasoningEffort = assistant?.settings?.reasoning_effort
if (reasoningEffort === undefined) {
return {
type: 'disabled'
}
}
const effortRatio = EFFORT_RATIO[reasoningEffort]
const budgetTokens = Math.max(
1024,
Math.floor(
Math.min(
(findTokenLimit(model.id)?.max! - findTokenLimit(model.id)?.min!) * effortRatio +
findTokenLimit(model.id)?.min!,
(maxTokens || DEFAULT_MAX_TOKENS) * effortRatio
)
)
)
return {
type: 'enabled',
budget_tokens: budgetTokens
}
}
/**
* Generate completions
* @param messages - The messages
* @param assistant - The assistant
* @param mcpTools - The MCP tools
* @param onChunk - The onChunk callback
* @param onFilterMessages - The onFilterMessages callback
*/
public async completions({ messages, assistant, mcpTools, onChunk, onFilterMessages }: CompletionsParams) {
const defaultModel = getDefaultModel()
const model = assistant.model || defaultModel
const { contextCount, maxTokens, streamOutput } = getAssistantSettings(assistant)
const userMessagesParams: MessageParam[] = []
const _messages = filterUserRoleStartMessages(
filterContextMessages(filterEmptyMessages(takeRight(messages, contextCount + 2)))
)
onFilterMessages(_messages)
for (const message of _messages) {
userMessagesParams.push(await this.getMessageParam(message))
}
const userMessages = flatten(userMessagesParams)
const lastUserMessage = _messages.findLast((m) => m.role === 'user')
let systemPrompt = assistant.prompt
const { tools } = this.setupToolsConfig<ToolUnion>({
model,
mcpTools,
enableToolUse: isEnabledToolUse(assistant)
})
if (this.useSystemPromptForTools && mcpTools && mcpTools.length) {
systemPrompt = await buildSystemPrompt(systemPrompt, mcpTools)
}
let systemMessage: TextBlockParam | undefined = undefined
if (systemPrompt) {
systemMessage = {
type: 'text',
text: systemPrompt
}
}
const isEnabledBuiltinWebSearch = assistant.enableWebSearch && isWebSearchModel(model)
if (isEnabledBuiltinWebSearch) {
const webSearchTool = await this.getWebSearchParams(model)
if (webSearchTool) {
tools.push(webSearchTool)
}
}
const body: MessageCreateParamsNonStreaming = {
model: model.id,
messages: userMessages,
max_tokens: maxTokens || DEFAULT_MAX_TOKENS,
temperature: this.getTemperature(assistant, model),
top_p: this.getTopP(assistant, model),
system: systemMessage ? [systemMessage] : undefined,
// @ts-ignore thinking
thinking: this.getBudgetToken(assistant, model),
tools: tools,
...this.getCustomParameters(assistant)
}
const { abortController, cleanup } = this.createAbortController(lastUserMessage?.id)
const { signal } = abortController
const finalUsage: Usage = {
completion_tokens: 0,
prompt_tokens: 0,
total_tokens: 0
}
const finalMetrics: Metrics = {
completion_tokens: 0,
time_completion_millsec: 0,
time_first_token_millsec: 0
}
const toolResponses: MCPToolResponse[] = []
const processStream = async (body: MessageCreateParamsNonStreaming, idx: number) => {
let time_first_token_millsec = 0
if (!streamOutput) {
const message = await this.sdk.messages.create({ ...body, stream: false })
const time_completion_millsec = new Date().getTime() - start_time_millsec
let text = ''
let reasoning_content = ''
if (message.content && message.content.length > 0) {
const thinkingBlock = message.content.find((block) => block.type === 'thinking')
const textBlock = message.content.find((block) => block.type === 'text')
if (thinkingBlock && 'thinking' in thinkingBlock) {
reasoning_content = thinkingBlock.thinking
}
if (textBlock && 'text' in textBlock) {
text = textBlock.text
}
}
return onChunk({
type: ChunkType.BLOCK_COMPLETE,
response: {
text,
reasoning_content,
usage: message.usage as any,
metrics: {
completion_tokens: message.usage?.output_tokens || 0,
time_completion_millsec,
time_first_token_millsec: 0
}
}
})
}
let thinking_content = ''
let isFirstChunk = true
return new Promise<void>((resolve, reject) => {
// 等待接口返回流
const toolCalls: ToolUseBlock[] = []
this.sdk.messages
.stream({ ...body, stream: true }, { signal, timeout: 5 * 60 * 1000 })
.on('text', (text) => {
if (isFirstChunk) {
isFirstChunk = false
if (time_first_token_millsec == 0) {
time_first_token_millsec = new Date().getTime()
} else {
onChunk({
type: ChunkType.THINKING_COMPLETE,
text: thinking_content,
thinking_millsec: new Date().getTime() - time_first_token_millsec
})
}
}
onChunk({ type: ChunkType.TEXT_DELTA, text })
})
.on('contentBlock', (block) => {
if (block.type === 'server_tool_use' && block.name === 'web_search') {
onChunk({
type: ChunkType.LLM_WEB_SEARCH_IN_PROGRESS
})
} else if (block.type === 'web_search_tool_result') {
if (
block.content &&
(block.content as WebSearchToolResultError).type === 'web_search_tool_result_error'
) {
onChunk({
type: ChunkType.ERROR,
error: {
code: (block.content as WebSearchToolResultError).error_code,
message: (block.content as WebSearchToolResultError).error_code
}
})
} else {
onChunk({
type: ChunkType.LLM_WEB_SEARCH_COMPLETE,
llm_web_search: {
results: block.content as Array<WebSearchResultBlock>,
source: WebSearchSource.ANTHROPIC
}
})
}
}
if (block.type === 'tool_use') {
toolCalls.push(block)
}
})
.on('thinking', (thinking) => {
if (time_first_token_millsec == 0) {
time_first_token_millsec = new Date().getTime()
}
onChunk({
type: ChunkType.THINKING_DELTA,
text: thinking,
thinking_millsec: new Date().getTime() - time_first_token_millsec
})
thinking_content += thinking
})
.on('finalMessage', async (message) => {
const toolResults: Awaited<ReturnType<typeof parseAndCallTools>> = []
// tool call
if (toolCalls.length > 0) {
const mcpToolResponses = toolCalls
.map((toolCall) => {
const mcpTool = anthropicToolUseToMcpTool(mcpTools, toolCall)
if (!mcpTool) {
return undefined
}
return {
id: toolCall.id,
toolCallId: toolCall.id,
tool: mcpTool,
arguments: toolCall.input as Record<string, unknown>,
status: 'pending'
} as ToolCallResponse
})
.filter((t) => typeof t !== 'undefined')
toolResults.push(
...(await parseAndCallTools(
mcpToolResponses,
toolResponses,
onChunk,
this.mcpToolCallResponseToMessage,
model,
mcpTools
))
)
}
// tool use
const content = message.content[0]
if (content && content.type === 'text') {
onChunk({ type: ChunkType.TEXT_COMPLETE, text: content.text })
toolResults.push(
...(await parseAndCallTools(
content.text,
toolResponses,
onChunk,
this.mcpToolCallResponseToMessage,
model,
mcpTools
))
)
}
if (thinking_content) {
onChunk({
type: ChunkType.THINKING_COMPLETE,
text: thinking_content,
thinking_millsec: new Date().getTime() - time_first_token_millsec
})
}
userMessages.push({
role: message.role,
content: message.content
})
if (toolResults.length > 0) {
toolResults.forEach((ts) => userMessages.push(ts as MessageParam))
const newBody = body
newBody.messages = userMessages
onChunk({ type: ChunkType.LLM_RESPONSE_CREATED })
try {
await processStream(newBody, idx + 1)
} catch (error) {
console.error('Error processing stream:', error)
reject(error)
}
}
// 直接修改finalUsage对象会报错,TypeError: Cannot assign to read only property 'prompt_tokens' of object '#<Object>'
// 暂未找到原因
const updatedUsage: Usage = {
...finalUsage,
prompt_tokens: finalUsage.prompt_tokens + (message.usage?.input_tokens || 0),
completion_tokens: finalUsage.completion_tokens + (message.usage?.output_tokens || 0)
}
updatedUsage.total_tokens = updatedUsage.prompt_tokens + updatedUsage.completion_tokens
const updatedMetrics: Metrics = {
...finalMetrics,
completion_tokens: updatedUsage.completion_tokens,
time_completion_millsec:
finalMetrics.time_completion_millsec + (new Date().getTime() - start_time_millsec),
time_first_token_millsec: time_first_token_millsec - start_time_millsec
}
Object.assign(finalUsage, updatedUsage)
Object.assign(finalMetrics, updatedMetrics)
onChunk({
type: ChunkType.BLOCK_COMPLETE,
response: {
usage: updatedUsage,
metrics: updatedMetrics
}
})
resolve()
})
.on('error', (error) => reject(error))
.on('abort', () => {
reject(new Error('Request was aborted.'))
})
})
}
onChunk({ type: ChunkType.LLM_RESPONSE_CREATED })
const start_time_millsec = new Date().getTime()
await processStream(body, 0).finally(() => {
cleanup()
})
}
/**
* Translate a message
* @param content
* @param assistant - The assistant
* @param onResponse - The onResponse callback
* @returns The translated message
*/
public async translate(
content: string,
assistant: Assistant,
onResponse?: (text: string, isComplete: boolean) => void
) {
const defaultModel = getDefaultModel()
const model = assistant.model || defaultModel
const messagesForApi = [{ role: 'user' as const, content: content }]
const stream = !!onResponse
const body: MessageCreateParamsNonStreaming = {
model: model.id,
messages: messagesForApi,
max_tokens: 4096,
temperature: assistant?.settings?.temperature,
system: assistant.prompt
}
if (!stream) {
const response = await this.sdk.messages.create({ ...body, stream: false })
return response.content[0].type === 'text' ? response.content[0].text : ''
}
let text = ''
return new Promise<string>((resolve, reject) => {
this.sdk.messages
.stream({ ...body, stream: true })
.on('text', (_text) => {
text += _text
onResponse?.(text, false)
})
.on('finalMessage', () => {
onResponse?.(text, true)
resolve(text)
})
.on('error', (error) => reject(error))
})
}
/**
* Summarize a message
* @param messages - The messages
* @param assistant - The assistant
* @returns The summary
*/
public async summaries(messages: Message[], assistant: Assistant): Promise<string> {
const model = getTopNamingModel() || assistant.model || getDefaultModel()
const userMessages = takeRight(messages, 5).map((message) => ({
role: message.role,
content: getMainTextContent(message)
}))
if (first(userMessages)?.role === 'assistant') {
userMessages.shift()
}
const userMessageContent = userMessages.reduce((prev, curr) => {
const currentContent = curr.role === 'user' ? `User: ${curr.content}` : `Assistant: ${curr.content}`
return prev + (prev ? '\n' : '') + currentContent
}, '')
const systemMessage = {
role: 'system',
content: (getStoreSetting('topicNamingPrompt') as string) || i18n.t('prompts.title')
}
const userMessage = {
role: 'user',
content: userMessageContent
}
const message = await this.sdk.messages.create({
messages: [userMessage] as Anthropic.Messages.MessageParam[],
model: model.id,
system: systemMessage.content,
stream: false,
max_tokens: 4096
})
const responseContent = message.content[0].type === 'text' ? message.content[0].text : ''
return removeSpecialCharactersForTopicName(responseContent)
}
/**
* Summarize a message for search
* @param messages - The messages
* @param assistant - The assistant
* @returns The summary
*/
public async summaryForSearch(messages: Message[], assistant: Assistant): Promise<string | null> {
const model = assistant.model || getDefaultModel()
const systemMessage = { content: assistant.prompt }
const userMessageContent = messages.map((m) => getMainTextContent(m)).join('\n')
const userMessage = {
role: 'user' as const,
content: userMessageContent
}
const lastUserMessage = messages[messages.length - 1]
const { abortController, cleanup } = this.createAbortController(lastUserMessage?.id)
const { signal } = abortController
const response = await this.sdk.messages
.create(
{
messages: [userMessage],
model: model.id,
system: systemMessage.content,
stream: false,
max_tokens: 4096
},
{ timeout: 20 * 1000, signal }
)
.finally(cleanup)
return response.content[0].type === 'text' ? response.content[0].text : ''
}
/**
* Generate text
* @param prompt - The prompt
* @param content - The content
* @returns The generated text
*/
public async generateText({ prompt, content }: { prompt: string; content: string }): Promise<string> {
const model = getDefaultModel()
const message = await this.sdk.messages.create({
model: model.id,
system: prompt,
stream: false,
max_tokens: 4096,
messages: [
{
role: 'user',
content
}
]
})
return message.content[0].type === 'text' ? message.content[0].text : ''
}
/**
* Generate an image
* @returns The generated image
*/
public async generateImage(): Promise<string[]> {
return []
}
public async generateImageByChat(): Promise<void> {
throw new Error('Method not implemented.')
}
/**
* Generate suggestions
* @returns The suggestions
*/
public async suggestions(): Promise<Suggestion[]> {
return []
}
/**
* Check if the model is valid
* @param model - The model
* @param stream - Whether to use streaming interface
* @returns The validity of the model
*/
public async check(model: Model, stream: boolean = false): Promise<{ valid: boolean; error: Error | null }> {
if (!model) {
return { valid: false, error: new Error('No model found') }
}
const body = {
model: model.id,
messages: [{ role: 'user' as const, content: 'hi' }],
max_tokens: 2, // api文档写的 x>1
stream
}
try {
if (!stream) {
const message = await this.sdk.messages.create(body as MessageCreateParamsNonStreaming)
return {
valid: message.content.length > 0,
error: null
}
} else {
return await new Promise((resolve, reject) => {
let hasContent = false
this.sdk.messages
.stream(body)
.on('text', (text) => {
if (!hasContent && text) {
hasContent = true
resolve({ valid: true, error: null })
}
})
.on('finalMessage', (message) => {
if (!hasContent && message.content && message.content.length > 0) {
hasContent = true
resolve({ valid: true, error: null })
}
if (!hasContent) {
reject(new Error('Empty streaming response'))
}
})
.on('error', (error) => reject(error))
})
}
} catch (error: any) {
return {
valid: false,
error
}
}
}
/**
* Get the models
* @returns The models
*/
public async models(): Promise<OpenAI.Models.Model[]> {
return []
}
public async getEmbeddingDimensions(): Promise<number> {
return 0
}
public convertMcpTools<T>(mcpTools: MCPTool[]): T[] {
return mcpToolsToAnthropicTools(mcpTools) as T[]
}
public mcpToolCallResponseToMessage = (mcpToolResponse: MCPToolResponse, resp: MCPCallToolResponse, model: Model) => {
if ('toolUseId' in mcpToolResponse && mcpToolResponse.toolUseId) {
return mcpToolCallResponseToAnthropicMessage(mcpToolResponse, resp, model)
} else if ('toolCallId' in mcpToolResponse) {
return {
role: 'user',
content: [
{
type: 'tool_result',
tool_use_id: mcpToolResponse.toolCallId!,
content: resp.content
.map((item) => {
if (item.type === 'text') {
return {
type: 'text',
text: item.text || ''
} satisfies TextBlockParam
}
if (item.type === 'image') {
return {
type: 'image',
source: {
data: item.data || '',
media_type: (item.mimeType || 'image/png') as Base64ImageSource['media_type'],
type: 'base64'
}
} satisfies ImageBlockParam
}
return
})
.filter((n) => typeof n !== 'undefined'),
is_error: resp.isError
} satisfies ToolResultBlockParam
]
}
}
return
}
}
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
File diff suppressed because it is too large Load Diff
@@ -1,33 +0,0 @@
import { Provider } from '@renderer/types'
import AihubmixProvider from './AihubmixProvider'
import AnthropicProvider from './AnthropicProvider'
import BaseProvider from './BaseProvider'
import GeminiProvider from './GeminiProvider'
import OpenAIProvider from './OpenAIProvider'
import OpenAIResponseProvider from './OpenAIResponseProvider'
export default class ProviderFactory {
static create(provider: Provider): BaseProvider {
if (provider.id === 'aihubmix') {
return new AihubmixProvider(provider)
}
switch (provider.type) {
case 'openai':
return new OpenAIProvider(provider)
case 'openai-response':
return new OpenAIResponseProvider(provider)
case 'anthropic':
return new AnthropicProvider(provider)
case 'gemini':
return new GeminiProvider(provider)
default:
return new OpenAIProvider(provider)
}
}
}
export function isOpenAIProvider(provider: Provider) {
return !['anthropic', 'gemini'].includes(provider.type)
}
@@ -1,94 +0,0 @@
import { GenerateImagesParameters } from '@google/genai'
import BaseProvider from '@renderer/providers/AiProvider/BaseProvider'
import ProviderFactory from '@renderer/providers/AiProvider/ProviderFactory'
import type { Assistant, GenerateImageParams, MCPTool, Model, Provider, Suggestion } from '@renderer/types'
import { Chunk } from '@renderer/types/chunk'
import type { Message } from '@renderer/types/newMessage'
import OpenAI from 'openai'
export interface CompletionsParams {
messages: Message[]
assistant: Assistant
onChunk: (chunk: Chunk) => void
onFilterMessages: (messages: Message[]) => void
mcpTools?: MCPTool[]
}
export default class AiProvider {
private sdk: BaseProvider
constructor(provider: Provider) {
this.sdk = ProviderFactory.create(provider)
}
public async fakeCompletions(params: CompletionsParams): Promise<void> {
return this.sdk.fakeCompletions(params)
}
public async completions({
messages,
assistant,
mcpTools,
onChunk,
onFilterMessages
}: CompletionsParams): Promise<void> {
return this.sdk.completions({ messages, assistant, mcpTools, onChunk, onFilterMessages })
}
public async translate(
content: string,
assistant: Assistant,
onResponse?: (text: string, isComplete: boolean) => void
): Promise<string> {
return this.sdk.translate(content, assistant, onResponse)
}
public async summaries(messages: Message[], assistant: Assistant): Promise<string> {
return this.sdk.summaries(messages, assistant)
}
public async summaryForSearch(messages: Message[], assistant: Assistant): Promise<string | null> {
return this.sdk.summaryForSearch(messages, assistant)
}
public async suggestions(messages: Message[], assistant: Assistant): Promise<Suggestion[]> {
return this.sdk.suggestions(messages, assistant)
}
public async generateText({ prompt, content }: { prompt: string; content: string }): Promise<string> {
return this.sdk.generateText({ prompt, content })
}
public async check(model: Model, stream: boolean = false): Promise<{ valid: boolean; error: Error | null }> {
return this.sdk.check(model, stream)
}
public async models(): Promise<OpenAI.Models.Model[]> {
return this.sdk.models()
}
public getApiKey(): string {
return this.sdk.getApiKey()
}
public async generateImage(params: GenerateImageParams | GenerateImagesParameters): Promise<string[]> {
return this.sdk.generateImage(params as GenerateImageParams)
}
public async generateImageByChat({
messages,
assistant,
onChunk,
onFilterMessages
}: CompletionsParams): Promise<void> {
return this.sdk.generateImageByChat({ messages, assistant, onChunk, onFilterMessages })
}
public async getEmbeddingDimensions(model: Model): Promise<number> {
return this.sdk.getEmbeddingDimensions(model)
}
public getBaseURL(): string {
return this.sdk.getBaseURL()
}
}
+176 -92
View File
@@ -1,10 +1,21 @@
import { CompletionsParams } from '@renderer/aiCore/middleware/schemas'
import Logger from '@renderer/config/logger'
import { getOpenAIWebSearchParams, isOpenAIWebSearch } from '@renderer/config/models'
import {
isEmbeddingModel,
isGenerateImageModel,
isOpenRouterBuiltInWebSearchModel,
isReasoningModel,
isSupportedDisableGenerationModel,
isSupportedReasoningEffortModel,
isSupportedThinkingTokenModel,
isWebSearchModel
} from '@renderer/config/models'
import {
SEARCH_SUMMARY_PROMPT,
SEARCH_SUMMARY_PROMPT_KNOWLEDGE_ONLY,
SEARCH_SUMMARY_PROMPT_WEB_ONLY
} from '@renderer/config/prompts'
import { getStoreSetting } from '@renderer/hooks/useSettings'
import i18n from '@renderer/i18n'
import {
Assistant,
@@ -13,20 +24,22 @@ import {
MCPTool,
Model,
Provider,
Suggestion,
WebSearchResponse,
WebSearchSource
} from '@renderer/types'
import { type Chunk, ChunkType } from '@renderer/types/chunk'
import { Message } from '@renderer/types/newMessage'
import { SdkModel } from '@renderer/types/sdk'
import { removeSpecialCharactersForTopicName } from '@renderer/utils'
import { isAbortError } from '@renderer/utils/error'
import { extractInfoFromXML, ExtractResults } from '@renderer/utils/extract'
import { getKnowledgeBaseIds, getMainTextContent } from '@renderer/utils/messageUtils/find'
import { findLast, isEmpty } from 'lodash'
import { findLast, isEmpty, takeRight } from 'lodash'
import AiProvider from '../providers/AiProvider'
import AiProvider from '../aiCore'
import {
getAssistantProvider,
getAssistantSettings,
getDefaultModel,
getProviderByModel,
getTopNamingModel,
@@ -34,7 +47,13 @@ import {
} from './AssistantService'
import { getDefaultAssistant } from './AssistantService'
import { processKnowledgeSearch } from './KnowledgeService'
import { filterContextMessages, filterMessages, filterUsefulMessages } from './MessagesService'
import {
filterContextMessages,
filterEmptyMessages,
filterMessages,
filterUsefulMessages,
filterUserRoleStartMessages
} from './MessagesService'
import WebSearchService from './WebSearchService'
// TODO:考虑拆开
@@ -50,6 +69,7 @@ async function fetchExternalTool(
const knowledgeRecognition = assistant.knowledgeRecognition || 'on'
const webSearchProvider = WebSearchService.getWebSearchProvider(assistant.webSearchProviderId)
// 使用外部搜索工具
const shouldWebSearch = !!assistant.webSearchProviderId && webSearchProvider !== null
const shouldKnowledgeSearch = hasKnowledgeBase
@@ -83,14 +103,14 @@ async function fetchExternalTool(
summaryAssistant.prompt = prompt
try {
const keywords = await fetchSearchSummary({
const result = await fetchSearchSummary({
messages: lastAnswer ? [lastAnswer, lastUserMessage] : [lastUserMessage],
assistant: summaryAssistant
})
if (!keywords) return getFallbackResult()
if (!result) return getFallbackResult()
const extracted = extractInfoFromXML(keywords)
const extracted = extractInfoFromXML(result.getText())
// 根据需求过滤结果
return {
websearch: needWebExtract ? extracted?.websearch : undefined,
@@ -134,12 +154,6 @@ async function fetchExternalTool(
return undefined
}
// Pass the guaranteed model to the check function
const webSearchParams = getOpenAIWebSearchParams(assistant, assistant.model)
if (!isEmpty(webSearchParams) || isOpenAIWebSearch(assistant.model)) {
return
}
try {
// Use the consolidated processWebsearch function
WebSearchService.createAbortSignal(lastUserMessage.id)
@@ -238,7 +252,7 @@ async function fetchExternalTool(
// Get MCP tools (Fix duplicate declaration)
let mcpTools: MCPTool[] = [] // Initialize as empty array
const enabledMCPs = lastUserMessage?.enabledMCPs
const enabledMCPs = assistant.mcpServers
if (enabledMCPs && enabledMCPs.length > 0) {
try {
const toolPromises = enabledMCPs.map(async (mcpServer) => {
@@ -301,17 +315,52 @@ export async function fetchChatCompletion({
// NOTE: The search results are NOT added to the messages sent to the AI here.
// They will be retrieved and used by the messageThunk later to create CitationBlocks.
const { mcpTools } = await fetchExternalTool(lastUserMessage, assistant, onChunkReceived, lastAnswer)
const model = assistant.model || getDefaultModel()
const { maxTokens, contextCount } = getAssistantSettings(assistant)
const filteredMessages = filterUsefulMessages(messages)
const _messages = filterUserRoleStartMessages(
filterEmptyMessages(filterContextMessages(takeRight(filteredMessages, contextCount + 2))) // 取原来几个provider的最大值
)
const enableReasoning =
((isSupportedThinkingTokenModel(model) || isSupportedReasoningEffortModel(model)) &&
assistant.settings?.reasoning_effort !== undefined) ||
(isReasoningModel(model) && (!isSupportedThinkingTokenModel(model) || !isSupportedReasoningEffortModel(model)))
const enableWebSearch =
(assistant.enableWebSearch && isWebSearchModel(model)) ||
isOpenRouterBuiltInWebSearchModel(model) ||
model.id.includes('sonar') ||
false
const enableGenerateImage =
isGenerateImageModel(model) && (isSupportedDisableGenerationModel(model) ? assistant.enableGenerateImage : true)
// --- Call AI Completions ---
await AI.completions({
messages: filteredMessages,
assistant,
onFilterMessages: () => {},
onChunk: onChunkReceived,
mcpTools: mcpTools
})
onChunkReceived({ type: ChunkType.LLM_RESPONSE_CREATED })
if (enableWebSearch) {
onChunkReceived({ type: ChunkType.LLM_WEB_SEARCH_IN_PROGRESS })
}
await AI.completions(
{
callType: 'chat',
messages: _messages,
assistant,
onChunk: onChunkReceived,
mcpTools: mcpTools,
maxTokens,
streamOutput: assistant.settings?.streamOutput || false,
enableReasoning,
enableWebSearch,
enableGenerateImage
},
{
streamOutput: assistant.settings?.streamOutput || false
}
)
}
interface FetchTranslateProps {
@@ -321,7 +370,7 @@ interface FetchTranslateProps {
}
export async function fetchTranslate({ content, assistant, onResponse }: FetchTranslateProps) {
const model = getTranslateModel()
const model = getTranslateModel() || assistant.model || getDefaultModel()
if (!model) {
throw new Error(i18n.t('error.provider_disabled'))
@@ -333,17 +382,45 @@ export async function fetchTranslate({ content, assistant, onResponse }: FetchTr
throw new Error(i18n.t('error.no_api_key'))
}
const isSupportedStreamOutput = () => {
if (!onResponse) {
return false
}
return true
}
const stream = isSupportedStreamOutput()
const enableReasoning =
((isSupportedThinkingTokenModel(model) || isSupportedReasoningEffortModel(model)) &&
assistant.settings?.reasoning_effort !== undefined) ||
(isReasoningModel(model) && (!isSupportedThinkingTokenModel(model) || !isSupportedReasoningEffortModel(model)))
const params: CompletionsParams = {
callType: 'translate',
messages: content,
assistant: { ...assistant, model },
streamOutput: stream,
enableReasoning,
onResponse
}
const AI = new AiProvider(provider)
try {
return await AI.translate(content, assistant, onResponse)
return (await AI.completions(params)).getText() || ''
} catch (error: any) {
return ''
}
}
export async function fetchMessagesSummary({ messages, assistant }: { messages: Message[]; assistant: Assistant }) {
const prompt = (getStoreSetting('topicNamingPrompt') as string) || i18n.t('prompts.title')
const model = getTopNamingModel() || assistant.model || getDefaultModel()
const userMessages = takeRight(messages, 5).map((message) => ({
...message,
content: getMainTextContent(message)
}))
const provider = getProviderByModel(model)
if (!hasApiKey(provider)) {
@@ -352,9 +429,18 @@ export async function fetchMessagesSummary({ messages, assistant }: { messages:
const AI = new AiProvider(provider)
const params: CompletionsParams = {
callType: 'summary',
messages: filterMessages(userMessages),
assistant: { ...assistant, prompt, model },
maxTokens: 1000,
streamOutput: false
}
try {
const text = await AI.summaries(filterMessages(messages), assistant)
return text?.replace(/["']/g, '') || null
const { getText } = await AI.completions(params)
const text = getText()
return removeSpecialCharactersForTopicName(text) || null
} catch (error: any) {
return null
}
@@ -370,7 +456,14 @@ export async function fetchSearchSummary({ messages, assistant }: { messages: Me
const AI = new AiProvider(provider)
return await AI.summaryForSearch(messages, assistant)
const params: CompletionsParams = {
callType: 'search',
messages: messages,
assistant,
streamOutput: false
}
return await AI.completions(params)
}
export async function fetchGenerate({ prompt, content }: { prompt: string; content: string }): Promise<string> {
@@ -383,42 +476,32 @@ export async function fetchGenerate({ prompt, content }: { prompt: string; conte
const AI = new AiProvider(provider)
const assistant = getDefaultAssistant()
assistant.model = model
assistant.prompt = prompt
const params: CompletionsParams = {
callType: 'generate',
messages: content,
assistant,
streamOutput: false
}
try {
return await AI.generateText({ prompt, content })
const result = await AI.completions(params)
return result.getText() || ''
} catch (error: any) {
return ''
}
}
export async function fetchSuggestions({
messages,
assistant
}: {
messages: Message[]
assistant: Assistant
}): Promise<Suggestion[]> {
const model = assistant.model
if (!model || model.id.endsWith('global')) {
return []
}
const provider = getAssistantProvider(assistant)
const AI = new AiProvider(provider)
try {
return await AI.suggestions(filterMessages(messages), assistant)
} catch (error: any) {
return []
}
}
function hasApiKey(provider: Provider) {
if (!provider) return false
if (provider.id === 'ollama' || provider.id === 'lmstudio') return true
return !isEmpty(provider.apiKey)
}
export async function fetchModels(provider: Provider) {
export async function fetchModels(provider: Provider): Promise<SdkModel[]> {
const AI = new AiProvider(provider)
try {
@@ -432,68 +515,69 @@ export const formatApiKeys = (value: string) => {
return value.replaceAll('', ',').replaceAll(' ', ',').replaceAll(' ', '').replaceAll('\n', ',')
}
export function checkApiProvider(provider: Provider): {
valid: boolean
error: Error | null
} {
export function checkApiProvider(provider: Provider): void {
const key = 'api-check'
const style = { marginTop: '3vh' }
if (provider.id !== 'ollama' && provider.id !== 'lmstudio') {
if (!provider.apiKey) {
window.message.error({ content: i18n.t('message.error.enter.api.key'), key, style })
return {
valid: false,
error: new Error(i18n.t('message.error.enter.api.key'))
}
throw new Error(i18n.t('message.error.enter.api.key'))
}
}
if (!provider.apiHost) {
window.message.error({ content: i18n.t('message.error.enter.api.host'), key, style })
return {
valid: false,
error: new Error(i18n.t('message.error.enter.api.host'))
}
throw new Error(i18n.t('message.error.enter.api.host'))
}
if (isEmpty(provider.models)) {
window.message.error({ content: i18n.t('message.error.enter.model'), key, style })
return {
valid: false,
error: new Error(i18n.t('message.error.enter.model'))
}
}
return {
valid: true,
error: null
throw new Error(i18n.t('message.error.enter.model'))
}
}
export async function checkApi(provider: Provider, model: Model): Promise<{ valid: boolean; error: Error | null }> {
const validation = checkApiProvider(provider)
if (!validation.valid) {
return {
valid: validation.valid,
error: validation.error
}
}
export async function checkApi(provider: Provider, model: Model): Promise<void> {
checkApiProvider(provider)
const ai = new AiProvider(provider)
// Try streaming check first
const result = await ai.check(model, true)
const assistant = getDefaultAssistant()
assistant.model = model
try {
if (isEmbeddingModel(model)) {
const result = await ai.getEmbeddingDimensions(model)
if (result === 0) {
throw new Error(i18n.t('message.error.enter.model'))
}
} else {
const params: CompletionsParams = {
callType: 'check',
messages: 'hi',
assistant,
streamOutput: true
}
if (result.valid && !result.error) {
return result
}
// 不应该假设错误由流式引发。多次发起检测请求可能触发429,掩盖了真正的问题。
// 但这里错误类型做的很粗糙,暂时先这样
if (result.error && result.error.message.includes('stream')) {
return ai.check(model, false)
} else {
return result
// Try streaming check first
const result = await ai.completions(params)
if (!result.getText()) {
throw new Error('No response received')
}
}
} catch (error: any) {
if (error.message.includes('stream')) {
const params: CompletionsParams = {
callType: 'check',
messages: 'hi',
assistant,
streamOutput: false
}
const result = await ai.completions(params)
if (!result.getText()) {
throw new Error('No response received')
}
} else {
throw error
}
}
}
+28 -16
View File
@@ -98,14 +98,20 @@ export async function checkModelWithMultipleKeys(
if (isParallel) {
// Check all API keys in parallel
const keyPromises = apiKeys.map(async (key) => {
const result = await checkModel({ ...provider, apiKey: key }, model)
return {
key,
isValid: result.valid,
error: result.error?.message,
latency: result.latency
} as ApiKeyCheckStatus
try {
const result = await checkModel({ ...provider, apiKey: key }, model)
return {
key,
isValid: true,
latency: result.latency
} as ApiKeyCheckStatus
} catch (error: unknown) {
return {
key,
isValid: false,
error: error instanceof Error ? error.message.slice(0, 20) + '...' : String(error).slice(0, 20) + '...'
} as ApiKeyCheckStatus
}
})
const results = await Promise.allSettled(keyPromises)
@@ -125,14 +131,20 @@ export async function checkModelWithMultipleKeys(
} else {
// Check all API keys serially
for (const key of apiKeys) {
const result = await checkModel({ ...provider, apiKey: key }, model)
keyResults.push({
key,
isValid: result.valid,
error: result.error?.message,
latency: result.latency
})
try {
const result = await checkModel({ ...provider, apiKey: key }, model)
keyResults.push({
key,
isValid: true,
latency: result.latency
})
} catch (error: unknown) {
keyResults.push({
key,
isValid: false,
error: error instanceof Error ? error.message.slice(0, 20) + '...' : String(error).slice(0, 20) + '...'
})
}
}
}
@@ -1,8 +1,8 @@
import type { ExtractChunkData } from '@cherrystudio/embedjs-interfaces'
import AiProvider from '@renderer/aiCore'
import { DEFAULT_KNOWLEDGE_DOCUMENT_COUNT, DEFAULT_KNOWLEDGE_THRESHOLD } from '@renderer/config/constant'
import { getEmbeddingMaxContext } from '@renderer/config/embedings'
import Logger from '@renderer/config/logger'
import AiProvider from '@renderer/providers/AiProvider'
import store from '@renderer/store'
import { FileMetadata, KnowledgeBase, KnowledgeBaseParams, KnowledgeReference } from '@renderer/types'
import { ExtractResults } from '@renderer/utils/extract'
+12 -55
View File
@@ -1,11 +1,9 @@
import { isEmbeddingModel } from '@renderer/config/models'
import AiProvider from '@renderer/providers/AiProvider'
import store from '@renderer/store'
import { Model, Provider } from '@renderer/types'
import { t } from 'i18next'
import { pick } from 'lodash'
import { checkApiProvider } from './ApiService'
import { checkApi } from './ApiService'
export const getModelUniqId = (m?: Model) => {
return m?.id ? JSON.stringify(pick(m, ['id', 'provider'])) : ''
@@ -33,64 +31,23 @@ export function getModelName(model?: Model) {
return modelName
}
// Generic function to perform model checks
// Abstracts provider validation and error handling, allowing different types of check logic
// Generic function to perform model checks with exception handling
async function performModelCheck<T>(
provider: Provider,
model: Model,
checkFn: (ai: AiProvider, model: Model) => Promise<T>,
processResult: (result: T) => { valid: boolean; error: Error | null }
): Promise<{ valid: boolean; error: Error | null; latency?: number }> {
const validation = checkApiProvider(provider)
if (!validation.valid) {
return {
valid: validation.valid,
error: validation.error
}
}
checkFn: (provider: Provider, model: Model) => Promise<T>
): Promise<{ latency: number }> {
const startTime = performance.now()
await checkFn(provider, model)
const latency = performance.now() - startTime
const AI = new AiProvider(provider)
try {
const startTime = performance.now()
const result = await checkFn(AI, model)
const latency = performance.now() - startTime
return {
...processResult(result),
latency
}
} catch (error: any) {
return {
valid: false,
error
}
}
return { latency }
}
// Unified model check function
// Automatically selects appropriate check method based on model type
export async function checkModel(provider: Provider, model: Model) {
if (isEmbeddingModel(model)) {
return performModelCheck(
provider,
model,
(ai, model) => ai.getEmbeddingDimensions(model),
(dimensions) => ({ valid: dimensions > 0, error: null })
)
} else {
return performModelCheck(
provider,
model,
async (ai, model) => {
// Try streaming check first
const result = await ai.check(model, true)
if (result.valid && !result.error) {
return result
}
return ai.check(model, false)
},
({ valid, error }) => ({ valid, error: error || null })
)
}
export async function checkModel(provider: Provider, model: Model): Promise<{ latency: number }> {
return performModelCheck(provider, model, async (provider, model) => {
await checkApi(provider, model)
})
}
@@ -28,7 +28,9 @@ export interface StreamProcessorCallbacks {
onLLMWebSearchComplete?: (llmWebSearchResult: WebSearchResponse) => void
// Image generation chunk received
onImageCreated?: () => void
onImageGenerated?: (imageData: GenerateImageResponse) => void
onImageDelta?: (imageData: GenerateImageResponse) => void
onImageGenerated?: (imageData?: GenerateImageResponse) => void
onLLMResponseComplete?: (response?: Response) => void
// Called when an error occurs during chunk processing
onError?: (error: any) => void
// Called when the entire stream processing is signaled as complete (success or failure)
@@ -40,59 +42,84 @@ export function createStreamProcessor(callbacks: StreamProcessorCallbacks = {})
// The returned function processes a single chunk or a final signal
return (chunk: Chunk) => {
try {
// Logger.log(`[${new Date().toLocaleString()}] createStreamProcessor ${chunk.type}`, chunk)
// 1. Handle the manual final signal first
if (chunk?.type === ChunkType.BLOCK_COMPLETE) {
callbacks.onComplete?.(AssistantMessageStatus.SUCCESS, chunk?.response)
return
const data = chunk
switch (data.type) {
case ChunkType.BLOCK_COMPLETE: {
if (callbacks.onComplete) callbacks.onComplete(AssistantMessageStatus.SUCCESS, data?.response)
break
}
case ChunkType.LLM_RESPONSE_CREATED: {
if (callbacks.onLLMResponseCreated) callbacks.onLLMResponseCreated()
break
}
case ChunkType.TEXT_DELTA: {
if (callbacks.onTextChunk) callbacks.onTextChunk(data.text)
break
}
case ChunkType.TEXT_COMPLETE: {
if (callbacks.onTextComplete) callbacks.onTextComplete(data.text)
break
}
case ChunkType.THINKING_DELTA: {
if (callbacks.onThinkingChunk) callbacks.onThinkingChunk(data.text, data.thinking_millsec)
break
}
case ChunkType.THINKING_COMPLETE: {
if (callbacks.onThinkingComplete) callbacks.onThinkingComplete(data.text, data.thinking_millsec)
break
}
case ChunkType.MCP_TOOL_IN_PROGRESS: {
if (callbacks.onToolCallInProgress)
data.responses.forEach((toolResp) => callbacks.onToolCallInProgress!(toolResp))
break
}
case ChunkType.MCP_TOOL_COMPLETE: {
if (callbacks.onToolCallComplete && data.responses.length > 0) {
data.responses.forEach((toolResp) => callbacks.onToolCallComplete!(toolResp))
}
break
}
case ChunkType.EXTERNEL_TOOL_IN_PROGRESS: {
if (callbacks.onExternalToolInProgress) callbacks.onExternalToolInProgress()
break
}
case ChunkType.EXTERNEL_TOOL_COMPLETE: {
if (callbacks.onExternalToolComplete) callbacks.onExternalToolComplete(data.external_tool)
break
}
case ChunkType.LLM_WEB_SEARCH_IN_PROGRESS: {
if (callbacks.onLLMWebSearchInProgress) callbacks.onLLMWebSearchInProgress()
break
}
case ChunkType.LLM_WEB_SEARCH_COMPLETE: {
if (callbacks.onLLMWebSearchComplete) callbacks.onLLMWebSearchComplete(data.llm_web_search)
break
}
case ChunkType.IMAGE_CREATED: {
if (callbacks.onImageCreated) callbacks.onImageCreated()
break
}
case ChunkType.IMAGE_DELTA: {
if (callbacks.onImageDelta) callbacks.onImageDelta(data.image)
break
}
case ChunkType.IMAGE_COMPLETE: {
if (callbacks.onImageGenerated) callbacks.onImageGenerated(data.image)
break
}
case ChunkType.LLM_RESPONSE_COMPLETE: {
if (callbacks.onLLMResponseComplete) callbacks.onLLMResponseComplete(data.response)
break
}
case ChunkType.ERROR: {
if (callbacks.onError) callbacks.onError(data.error)
break
}
default: {
// Handle unknown chunk types or log an error
console.warn(`Unknown chunk type: ${data.type}`)
}
}
// 2. Process the actual ChunkCallbackData
const data = chunk // Cast after checking for 'final'
// Invoke callbacks based on the fields present in the chunk data
if (data.type === ChunkType.LLM_RESPONSE_CREATED && callbacks.onLLMResponseCreated) {
callbacks.onLLMResponseCreated()
}
if (data.type === ChunkType.TEXT_DELTA && callbacks.onTextChunk) {
callbacks.onTextChunk(data.text)
}
if (data.type === ChunkType.TEXT_COMPLETE && callbacks.onTextComplete) {
callbacks.onTextComplete(data.text)
}
if (data.type === ChunkType.THINKING_DELTA && callbacks.onThinkingChunk) {
callbacks.onThinkingChunk(data.text, data.thinking_millsec)
}
if (data.type === ChunkType.THINKING_COMPLETE && callbacks.onThinkingComplete) {
callbacks.onThinkingComplete(data.text, data.thinking_millsec)
}
if (data.type === ChunkType.MCP_TOOL_IN_PROGRESS && callbacks.onToolCallInProgress) {
data.responses.forEach((toolResp) => callbacks.onToolCallInProgress!(toolResp))
}
if (data.type === ChunkType.MCP_TOOL_COMPLETE && data.responses.length > 0 && callbacks.onToolCallComplete) {
data.responses.forEach((toolResp) => callbacks.onToolCallComplete!(toolResp))
}
if (data.type === ChunkType.EXTERNEL_TOOL_IN_PROGRESS && callbacks.onExternalToolInProgress) {
callbacks.onExternalToolInProgress()
}
if (data.type === ChunkType.EXTERNEL_TOOL_COMPLETE && callbacks.onExternalToolComplete) {
callbacks.onExternalToolComplete(data.external_tool)
}
if (data.type === ChunkType.LLM_WEB_SEARCH_IN_PROGRESS && callbacks.onLLMWebSearchInProgress) {
callbacks.onLLMWebSearchInProgress()
}
if (data.type === ChunkType.LLM_WEB_SEARCH_COMPLETE && callbacks.onLLMWebSearchComplete) {
callbacks.onLLMWebSearchComplete(data.llm_web_search)
}
if (data.type === ChunkType.IMAGE_CREATED && callbacks.onImageCreated) {
callbacks.onImageCreated()
}
if (data.type === ChunkType.IMAGE_COMPLETE && callbacks.onImageGenerated) {
callbacks.onImageGenerated(data.image)
}
if (data.type === ChunkType.ERROR && callbacks.onError) {
callbacks.onError(data.error)
}
// Note: Usage and Metrics are usually handled at the end or accumulated differently,
// so direct callbacks might not be the best fit here. They are often part of the final message state.
} catch (error) {
console.error('Error processing stream chunk:', error)
callbacks.onError?.(error)
+1 -1
View File
@@ -54,7 +54,7 @@ const persistedReducer = persistReducer(
{
key: 'cherry-studio',
storage,
version: 112,
version: 113,
blacklist: ['runtime', 'messages', 'messageBlocks'],
migrate
},
+14 -4
View File
@@ -127,12 +127,22 @@ export const INITIAL_PROVIDERS: Provider[] = [
enabled: false
},
{
id: 'o3',
name: 'O3',
id: '302ai',
name: '302.AI',
type: 'openai',
apiKey: '',
apiHost: 'https://api.o3.fan',
models: SYSTEM_MODELS.o3,
apiHost: 'https://api.302.ai',
models: SYSTEM_MODELS['302ai'],
isSystem: true,
enabled: false
},
{
id: 'cephalon',
name: 'Cephalon',
type: 'openai',
apiKey: '',
apiHost: 'https://cephalon.cloud/user-center/v1/model',
models: SYSTEM_MODELS.cephalon,
isSystem: true,
enabled: false
},
+11
View File
@@ -1564,6 +1564,17 @@ const migrateConfig = {
}
},
'112': (state: RootState) => {
try {
addProvider(state, 'cephalon')
addProvider(state, '302ai')
state.llm.providers = moveProvider(state.llm.providers, 'cephalon', 13)
state.llm.providers = moveProvider(state.llm.providers, '302ai', 14)
return state
} catch (error) {
return state
}
},
'113': (state: RootState) => {
try {
if (!state.settings.userId) {
state.settings.userId = uuid()
+16 -2
View File
@@ -7,6 +7,10 @@ export interface ChatState {
isMultiSelectMode: boolean
selectedMessageIds: string[]
activeTopic: Topic | null
/** topic ids that are currently being renamed */
renamingTopics: string[]
/** topic ids that are newly renamed */
newlyRenamedTopics: string[]
}
export interface UpdateState {
@@ -65,7 +69,9 @@ const initialState: RuntimeState = {
chat: {
isMultiSelectMode: false,
selectedMessageIds: [],
activeTopic: null
activeTopic: null,
renamingTopics: [],
newlyRenamedTopics: []
}
}
@@ -118,6 +124,12 @@ const runtimeSlice = createSlice({
},
setActiveTopic: (state, action: PayloadAction<Topic>) => {
state.chat.activeTopic = action.payload
},
setRenamingTopics: (state, action: PayloadAction<string[]>) => {
state.chat.renamingTopics = action.payload
},
setNewlyRenamedTopics: (state, action: PayloadAction<string[]>) => {
state.chat.newlyRenamedTopics = action.payload
}
}
})
@@ -137,7 +149,9 @@ export const {
// Chat related actions
toggleMultiSelectMode,
setSelectedMessageIds,
setActiveTopic
setActiveTopic,
setRenamingTopics,
setNewlyRenamedTopics
} = runtimeSlice.actions
export default runtimeSlice.reducer
+215 -164
View File
@@ -8,7 +8,6 @@ import { createStreamProcessor, type StreamProcessorCallbacks } from '@renderer/
import { estimateMessagesUsage } from '@renderer/services/TokenService'
import store from '@renderer/store'
import type { Assistant, ExternalToolResult, FileMetadata, MCPToolResponse, Model, Topic } from '@renderer/types'
import { WebSearchSource } from '@renderer/types'
import type {
CitationMessageBlock,
FileMessageBlock,
@@ -22,7 +21,6 @@ import { AssistantMessageStatus, MessageBlockStatus, MessageBlockType } from '@r
import { Response } from '@renderer/types/newMessage'
import { uuid } from '@renderer/utils'
import { formatErrorMessage, isAbortError } from '@renderer/utils/error'
import { extractUrlsFromMarkdown } from '@renderer/utils/linkConverter'
import {
createAssistantMessage,
createBaseMessageBlock,
@@ -35,7 +33,8 @@ import {
createTranslationBlock,
resetAssistantMessage
} from '@renderer/utils/messageUtils/create'
import { getTopicQueue, waitForTopicQueue } from '@renderer/utils/queue'
import { getMainTextContent } from '@renderer/utils/messageUtils/find'
import { getTopicQueue } from '@renderer/utils/queue'
import { isOnHomePage } from '@renderer/utils/window'
import { t } from 'i18next'
import { isEmpty, throttle } from 'lodash'
@@ -45,10 +44,10 @@ import type { AppDispatch, RootState } from '../index'
import { removeManyBlocks, updateOneBlock, upsertManyBlocks, upsertOneBlock } from '../messageBlock'
import { newMessagesActions, selectMessagesForTopic } from '../newMessage'
const handleChangeLoadingOfTopic = async (topicId: string) => {
await waitForTopicQueue(topicId)
store.dispatch(newMessagesActions.setTopicLoading({ topicId, loading: false }))
}
// const handleChangeLoadingOfTopic = async (topicId: string) => {
// await waitForTopicQueue(topicId)
// store.dispatch(newMessagesActions.setTopicLoading({ topicId, loading: false }))
// }
// TODO: 后续可以将db操作移到Listener Middleware中
export const saveMessageAndBlocksToDB = async (message: Message, blocks: MessageBlock[], messageIndex: number = -1) => {
try {
@@ -337,10 +336,17 @@ const fetchAndProcessAssistantResponseImpl = async (
let accumulatedContent = ''
let accumulatedThinking = ''
// 专注于管理UI焦点和块切换
let lastBlockId: string | null = null
let lastBlockType: MessageBlockType | null = null
// 专注于块内部的生命周期处理
let initialPlaceholderBlockId: string | null = null
let citationBlockId: string | null = null
let mainTextBlockId: string | null = null
let thinkingBlockId: string | null = null
let imageBlockId: string | null = null
let toolBlockId: string | null = null
let hasWebSearch = false
const toolCallIdToBlockIdMap = new Map<string, string>()
const notificationService = NotificationService.getInstance()
@@ -400,129 +406,129 @@ const fetchAndProcessAssistantResponseImpl = async (
}
callbacks = {
onLLMResponseCreated: () => {
onLLMResponseCreated: async () => {
const baseBlock = createBaseMessageBlock(assistantMsgId, MessageBlockType.UNKNOWN, {
status: MessageBlockStatus.PROCESSING
})
handleBlockTransition(baseBlock as PlaceholderMessageBlock, MessageBlockType.UNKNOWN)
initialPlaceholderBlockId = baseBlock.id
await handleBlockTransition(baseBlock as PlaceholderMessageBlock, MessageBlockType.UNKNOWN)
},
onTextChunk: (text) => {
onTextChunk: async (text) => {
accumulatedContent += text
if (lastBlockId) {
if (lastBlockType === MessageBlockType.UNKNOWN) {
const initialChanges: Partial<MessageBlock> = {
type: MessageBlockType.MAIN_TEXT,
content: accumulatedContent,
status: MessageBlockStatus.STREAMING,
citationReferences: citationBlockId ? [{ citationBlockId }] : []
}
mainTextBlockId = lastBlockId
lastBlockType = MessageBlockType.MAIN_TEXT
dispatch(updateOneBlock({ id: lastBlockId, changes: initialChanges }))
saveUpdatedBlockToDB(lastBlockId, assistantMsgId, topicId, getState)
} else if (lastBlockType === MessageBlockType.MAIN_TEXT) {
const blockChanges: Partial<MessageBlock> = {
content: accumulatedContent,
status: MessageBlockStatus.STREAMING
}
throttledBlockUpdate(lastBlockId, blockChanges)
// throttledBlockDbUpdate(lastBlockId, blockChanges)
} else {
const newBlock = createMainTextBlock(assistantMsgId, accumulatedContent, {
status: MessageBlockStatus.STREAMING,
citationReferences: citationBlockId ? [{ citationBlockId }] : []
})
handleBlockTransition(newBlock, MessageBlockType.MAIN_TEXT)
mainTextBlockId = newBlock.id
if (mainTextBlockId) {
const blockChanges: Partial<MessageBlock> = {
content: accumulatedContent,
status: MessageBlockStatus.STREAMING
}
throttledBlockUpdate(mainTextBlockId, blockChanges)
} else if (initialPlaceholderBlockId) {
// 将占位块转换为主文本块
const initialChanges: Partial<MessageBlock> = {
type: MessageBlockType.MAIN_TEXT,
content: accumulatedContent,
status: MessageBlockStatus.STREAMING,
citationReferences: citationBlockId ? [{ citationBlockId }] : []
}
mainTextBlockId = initialPlaceholderBlockId
// 清理占位块
initialPlaceholderBlockId = null
lastBlockType = MessageBlockType.MAIN_TEXT
dispatch(updateOneBlock({ id: mainTextBlockId, changes: initialChanges }))
saveUpdatedBlockToDB(mainTextBlockId, assistantMsgId, topicId, getState)
} else {
const newBlock = createMainTextBlock(assistantMsgId, accumulatedContent, {
status: MessageBlockStatus.STREAMING,
citationReferences: citationBlockId ? [{ citationBlockId }] : []
})
mainTextBlockId = newBlock.id // 立即设置ID,防止竞态条件
await handleBlockTransition(newBlock, MessageBlockType.MAIN_TEXT)
}
},
onTextComplete: async (finalText) => {
if (lastBlockType === MessageBlockType.MAIN_TEXT && lastBlockId) {
if (mainTextBlockId) {
const changes = {
content: finalText,
status: MessageBlockStatus.SUCCESS
}
cancelThrottledBlockUpdate(lastBlockId)
dispatch(updateOneBlock({ id: lastBlockId, changes }))
saveUpdatedBlockToDB(lastBlockId, assistantMsgId, topicId, getState)
if (assistant.enableWebSearch && assistant.model?.provider === 'openrouter') {
const extractedUrls = extractUrlsFromMarkdown(finalText)
if (extractedUrls.length > 0) {
const citationBlock = createCitationBlock(
assistantMsgId,
{ response: { source: WebSearchSource.OPENROUTER, results: extractedUrls } },
{ status: MessageBlockStatus.SUCCESS }
)
await handleBlockTransition(citationBlock, MessageBlockType.CITATION)
// saveUpdatedBlockToDB(citationBlock.id, assistantMsgId, topicId, getState)
}
}
cancelThrottledBlockUpdate(mainTextBlockId)
dispatch(updateOneBlock({ id: mainTextBlockId, changes }))
saveUpdatedBlockToDB(mainTextBlockId, assistantMsgId, topicId, getState)
mainTextBlockId = null
} else {
console.warn(
`[onTextComplete] Received text.complete but last block was not MAIN_TEXT (was ${lastBlockType}) or lastBlockId is null.`
`[onTextComplete] Received text.complete but last block was not MAIN_TEXT (was ${lastBlockType}) or lastBlockId is null.`
)
}
},
onThinkingChunk: (text, thinking_millsec) => {
accumulatedThinking += text
if (lastBlockId) {
if (lastBlockType === MessageBlockType.UNKNOWN) {
// First chunk for this block: Update type and status immediately
lastBlockType = MessageBlockType.THINKING
const initialChanges: Partial<MessageBlock> = {
type: MessageBlockType.THINKING,
content: accumulatedThinking,
status: MessageBlockStatus.STREAMING
}
dispatch(updateOneBlock({ id: lastBlockId, changes: initialChanges }))
saveUpdatedBlockToDB(lastBlockId, assistantMsgId, topicId, getState)
} else if (lastBlockType === MessageBlockType.THINKING) {
const blockChanges: Partial<MessageBlock> = {
content: accumulatedThinking,
status: MessageBlockStatus.STREAMING,
thinking_millsec: thinking_millsec
}
throttledBlockUpdate(lastBlockId, blockChanges)
// throttledBlockDbUpdate(lastBlockId, blockChanges)
} else {
const newBlock = createThinkingBlock(assistantMsgId, accumulatedThinking, {
status: MessageBlockStatus.STREAMING,
thinking_millsec: 0
})
handleBlockTransition(newBlock, MessageBlockType.THINKING)
if (citationBlockId && !hasWebSearch) {
const changes: Partial<CitationMessageBlock> = {
status: MessageBlockStatus.SUCCESS
}
dispatch(updateOneBlock({ id: citationBlockId, changes }))
saveUpdatedBlockToDB(citationBlockId, assistantMsgId, topicId, getState)
citationBlockId = null
}
},
onThinkingChunk: async (text, thinking_millsec) => {
accumulatedThinking += text
if (thinkingBlockId) {
const blockChanges: Partial<MessageBlock> = {
content: accumulatedThinking,
status: MessageBlockStatus.STREAMING,
thinking_millsec: thinking_millsec
}
throttledBlockUpdate(thinkingBlockId, blockChanges)
} else if (initialPlaceholderBlockId) {
// First chunk for this block: Update type and status immediately
lastBlockType = MessageBlockType.THINKING
const initialChanges: Partial<MessageBlock> = {
type: MessageBlockType.THINKING,
content: accumulatedThinking,
status: MessageBlockStatus.STREAMING
}
thinkingBlockId = initialPlaceholderBlockId
initialPlaceholderBlockId = null
dispatch(updateOneBlock({ id: thinkingBlockId, changes: initialChanges }))
saveUpdatedBlockToDB(thinkingBlockId, assistantMsgId, topicId, getState)
} else {
const newBlock = createThinkingBlock(assistantMsgId, accumulatedThinking, {
status: MessageBlockStatus.STREAMING,
thinking_millsec: 0
})
thinkingBlockId = newBlock.id // 立即设置ID,防止竞态条件
await handleBlockTransition(newBlock, MessageBlockType.THINKING)
}
},
onThinkingComplete: (finalText, final_thinking_millsec) => {
if (lastBlockType === MessageBlockType.THINKING && lastBlockId) {
if (thinkingBlockId) {
const changes = {
type: MessageBlockType.THINKING,
content: finalText,
status: MessageBlockStatus.SUCCESS,
thinking_millsec: final_thinking_millsec
}
cancelThrottledBlockUpdate(lastBlockId)
dispatch(updateOneBlock({ id: lastBlockId, changes }))
saveUpdatedBlockToDB(lastBlockId, assistantMsgId, topicId, getState)
cancelThrottledBlockUpdate(thinkingBlockId)
dispatch(updateOneBlock({ id: thinkingBlockId, changes }))
saveUpdatedBlockToDB(thinkingBlockId, assistantMsgId, topicId, getState)
} else {
console.warn(
`[onThinkingComplete] Received thinking.complete but last block was not THINKING (was ${lastBlockType}) or lastBlockId is null.`
`[onThinkingComplete] Received thinking.complete but last block was not THINKING (was ${lastBlockType}) or lastBlockId is null.`
)
}
thinkingBlockId = null
},
onToolCallInProgress: (toolResponse: MCPToolResponse) => {
if (lastBlockType === MessageBlockType.UNKNOWN && lastBlockId) {
if (initialPlaceholderBlockId) {
lastBlockType = MessageBlockType.TOOL
const changes = {
type: MessageBlockType.TOOL,
status: MessageBlockStatus.PROCESSING,
metadata: { rawMcpToolResponse: toolResponse }
}
dispatch(updateOneBlock({ id: lastBlockId, changes }))
saveUpdatedBlockToDB(lastBlockId, assistantMsgId, topicId, getState)
toolCallIdToBlockIdMap.set(toolResponse.id, lastBlockId)
toolBlockId = initialPlaceholderBlockId
initialPlaceholderBlockId = null
dispatch(updateOneBlock({ id: toolBlockId, changes }))
saveUpdatedBlockToDB(toolBlockId, assistantMsgId, topicId, getState)
toolCallIdToBlockIdMap.set(toolResponse.id, toolBlockId)
} else if (toolResponse.status === 'invoking') {
const toolBlock = createToolBlock(assistantMsgId, toolResponse.id, {
toolName: toolResponse.tool.name,
@@ -539,6 +545,7 @@ const fetchAndProcessAssistantResponseImpl = async (
},
onToolCallComplete: (toolResponse: MCPToolResponse) => {
const existingBlockId = toolCallIdToBlockIdMap.get(toolResponse.id)
toolCallIdToBlockIdMap.delete(toolResponse.id)
if (toolResponse.status === 'done' || toolResponse.status === 'error') {
if (!existingBlockId) {
console.error(
@@ -564,10 +571,10 @@ const fetchAndProcessAssistantResponseImpl = async (
)
}
},
onExternalToolInProgress: () => {
onExternalToolInProgress: async () => {
const citationBlock = createCitationBlock(assistantMsgId, {}, { status: MessageBlockStatus.PROCESSING })
citationBlockId = citationBlock.id
handleBlockTransition(citationBlock, MessageBlockType.CITATION)
await handleBlockTransition(citationBlock, MessageBlockType.CITATION)
// saveUpdatedBlockToDB(citationBlock.id, assistantMsgId, topicId, getState)
},
onExternalToolComplete: (externalToolResult: ExternalToolResult) => {
@@ -583,35 +590,39 @@ const fetchAndProcessAssistantResponseImpl = async (
console.error('[onExternalToolComplete] citationBlockId is null. Cannot update.')
}
},
onLLMWebSearchInProgress: () => {
const citationBlock = createCitationBlock(assistantMsgId, {}, { status: MessageBlockStatus.PROCESSING })
citationBlockId = citationBlock.id
handleBlockTransition(citationBlock, MessageBlockType.CITATION)
// saveUpdatedBlockToDB(citationBlock.id, assistantMsgId, topicId, getState)
onLLMWebSearchInProgress: async () => {
if (initialPlaceholderBlockId) {
lastBlockType = MessageBlockType.CITATION
citationBlockId = initialPlaceholderBlockId
const changes = {
type: MessageBlockType.CITATION,
status: MessageBlockStatus.PROCESSING
}
lastBlockType = MessageBlockType.CITATION
dispatch(updateOneBlock({ id: initialPlaceholderBlockId, changes }))
saveUpdatedBlockToDB(initialPlaceholderBlockId, assistantMsgId, topicId, getState)
initialPlaceholderBlockId = null
} else {
const citationBlock = createCitationBlock(assistantMsgId, {}, { status: MessageBlockStatus.PROCESSING })
citationBlockId = citationBlock.id
await handleBlockTransition(citationBlock, MessageBlockType.CITATION)
}
},
onLLMWebSearchComplete: async (llmWebSearchResult) => {
if (citationBlockId) {
hasWebSearch = true
const changes: Partial<CitationMessageBlock> = {
response: llmWebSearchResult,
status: MessageBlockStatus.SUCCESS
}
dispatch(updateOneBlock({ id: citationBlockId, changes }))
saveUpdatedBlockToDB(citationBlockId, assistantMsgId, topicId, getState)
} else {
const citationBlock = createCitationBlock(
assistantMsgId,
{ response: llmWebSearchResult },
{ status: MessageBlockStatus.SUCCESS }
)
citationBlockId = citationBlock.id
handleBlockTransition(citationBlock, MessageBlockType.CITATION)
}
if (mainTextBlockId) {
const state = getState()
const existingMainTextBlock = state.messageBlocks.entities[mainTextBlockId]
if (existingMainTextBlock && existingMainTextBlock.type === MessageBlockType.MAIN_TEXT) {
const currentRefs = existingMainTextBlock.citationReferences || []
if (!currentRefs.some((ref) => ref.citationBlockId === citationBlockId)) {
if (mainTextBlockId) {
const state = getState()
const existingMainTextBlock = state.messageBlocks.entities[mainTextBlockId]
if (existingMainTextBlock && existingMainTextBlock.type === MessageBlockType.MAIN_TEXT) {
const currentRefs = existingMainTextBlock.citationReferences || []
const mainTextChanges = {
citationReferences: [
...currentRefs,
@@ -621,40 +632,64 @@ const fetchAndProcessAssistantResponseImpl = async (
dispatch(updateOneBlock({ id: mainTextBlockId, changes: mainTextChanges }))
saveUpdatedBlockToDB(mainTextBlockId, assistantMsgId, topicId, getState)
}
mainTextBlockId = null
}
}
},
onImageCreated: () => {
if (lastBlockId) {
if (lastBlockType === MessageBlockType.UNKNOWN) {
const initialChanges: Partial<MessageBlock> = {
type: MessageBlockType.IMAGE,
status: MessageBlockStatus.STREAMING
}
lastBlockType = MessageBlockType.IMAGE
dispatch(updateOneBlock({ id: lastBlockId, changes: initialChanges }))
saveUpdatedBlockToDB(lastBlockId, assistantMsgId, topicId, getState)
} else {
const imageBlock = createImageBlock(assistantMsgId, {
status: MessageBlockStatus.PROCESSING
})
handleBlockTransition(imageBlock, MessageBlockType.IMAGE)
onImageCreated: async () => {
if (initialPlaceholderBlockId) {
lastBlockType = MessageBlockType.IMAGE
const initialChanges: Partial<MessageBlock> = {
type: MessageBlockType.IMAGE,
status: MessageBlockStatus.STREAMING
}
lastBlockType = MessageBlockType.IMAGE
imageBlockId = initialPlaceholderBlockId
initialPlaceholderBlockId = null
dispatch(updateOneBlock({ id: imageBlockId, changes: initialChanges }))
saveUpdatedBlockToDB(imageBlockId, assistantMsgId, topicId, getState)
} else if (!imageBlockId) {
const imageBlock = createImageBlock(assistantMsgId, {
status: MessageBlockStatus.STREAMING
})
imageBlockId = imageBlock.id
await handleBlockTransition(imageBlock, MessageBlockType.IMAGE)
}
},
onImageGenerated: (imageData) => {
onImageDelta: (imageData) => {
const imageUrl = imageData.images?.[0] || 'placeholder_image_url'
if (lastBlockId && lastBlockType === MessageBlockType.IMAGE) {
if (imageBlockId) {
const changes: Partial<ImageMessageBlock> = {
url: imageUrl,
metadata: { generateImageResponse: imageData },
status: MessageBlockStatus.SUCCESS
status: MessageBlockStatus.STREAMING
}
dispatch(updateOneBlock({ id: imageBlockId, changes }))
saveUpdatedBlockToDB(imageBlockId, assistantMsgId, topicId, getState)
}
},
onImageGenerated: (imageData) => {
if (imageBlockId) {
if (!imageData) {
const changes: Partial<ImageMessageBlock> = {
status: MessageBlockStatus.SUCCESS
}
dispatch(updateOneBlock({ id: imageBlockId, changes }))
saveUpdatedBlockToDB(imageBlockId, assistantMsgId, topicId, getState)
} else {
const imageUrl = imageData.images?.[0] || 'placeholder_image_url'
const changes: Partial<ImageMessageBlock> = {
url: imageUrl,
metadata: { generateImageResponse: imageData },
status: MessageBlockStatus.SUCCESS
}
dispatch(updateOneBlock({ id: imageBlockId, changes }))
saveUpdatedBlockToDB(imageBlockId, assistantMsgId, topicId, getState)
}
dispatch(updateOneBlock({ id: lastBlockId, changes }))
saveUpdatedBlockToDB(lastBlockId, assistantMsgId, topicId, getState)
} else {
console.error('[onImageGenerated] Last block was not an Image block or ID is missing.')
}
imageBlockId = null
},
onError: async (error) => {
console.dir(error, { depth: null })
@@ -683,15 +718,16 @@ const fetchAndProcessAssistantResponseImpl = async (
source: 'assistant'
})
}
if (lastBlockId) {
const possibleBlockId =
mainTextBlockId || thinkingBlockId || toolBlockId || imageBlockId || citationBlockId || lastBlockId
if (possibleBlockId) {
// 更改上一个block的状态为ERROR
const changes: Partial<MessageBlock> = {
status: isErrorTypeAbort ? MessageBlockStatus.PAUSED : MessageBlockStatus.ERROR
}
cancelThrottledBlockUpdate(lastBlockId)
dispatch(updateOneBlock({ id: lastBlockId, changes }))
saveUpdatedBlockToDB(lastBlockId, assistantMsgId, topicId, getState)
cancelThrottledBlockUpdate(possibleBlockId)
dispatch(updateOneBlock({ id: possibleBlockId, changes }))
saveUpdatedBlockToDB(possibleBlockId, assistantMsgId, topicId, getState)
}
const errorBlock = createErrorBlock(assistantMsgId, serializableError, { status: MessageBlockStatus.SUCCESS })
@@ -721,35 +757,45 @@ const fetchAndProcessAssistantResponseImpl = async (
const contextForUsage = userMsgIndex !== -1 ? orderedMsgs.slice(0, userMsgIndex + 1) : []
const finalContextWithAssistant = [...contextForUsage, finalAssistantMsg]
if (lastBlockId) {
const possibleBlockId =
mainTextBlockId || thinkingBlockId || toolBlockId || imageBlockId || citationBlockId || lastBlockId
if (possibleBlockId) {
const changes: Partial<MessageBlock> = {
status: MessageBlockStatus.SUCCESS
}
cancelThrottledBlockUpdate(lastBlockId)
dispatch(updateOneBlock({ id: lastBlockId, changes }))
saveUpdatedBlockToDB(lastBlockId, assistantMsgId, topicId, getState)
cancelThrottledBlockUpdate(possibleBlockId)
dispatch(updateOneBlock({ id: possibleBlockId, changes }))
saveUpdatedBlockToDB(possibleBlockId, assistantMsgId, topicId, getState)
}
// const content = getMainTextContent(finalAssistantMsg)
// if (!isOnHomePage()) {
// await notificationService.send({
// id: uuid(),
// type: 'success',
// title: t('notification.assistant'),
// message: content.length > 50 ? content.slice(0, 47) + '...' : content,
// silent: false,
// timestamp: Date.now(),
// source: 'assistant'
// })
// }
const endTime = Date.now()
const duration = endTime - startTime
const content = getMainTextContent(finalAssistantMsg)
if (!isOnHomePage() && duration > 60 * 1000) {
await notificationService.send({
id: uuid(),
type: 'success',
title: t('notification.assistant'),
message: content.length > 50 ? content.slice(0, 47) + '...' : content,
silent: false,
timestamp: Date.now(),
source: 'assistant'
})
}
// 更新topic的name
autoRenameTopic(assistant, topicId)
if (response && response.usage?.total_tokens === 0) {
if (
response &&
(response.usage?.total_tokens === 0 ||
response?.usage?.prompt_tokens === 0 ||
response?.usage?.completion_tokens === 0)
) {
const usage = await estimateMessagesUsage({ assistant, messages: finalContextWithAssistant })
response.usage = usage
}
dispatch(newMessagesActions.setTopicLoading({ topicId, loading: false }))
}
if (response && response.metrics) {
if (response.metrics.completion_tokens === 0 && response.usage?.completion_tokens) {
@@ -779,6 +825,7 @@ const fetchAndProcessAssistantResponseImpl = async (
const streamProcessorCallbacks = createStreamProcessor(callbacks)
const startTime = Date.now()
await fetchChatCompletion({
messages: messagesForContext,
assistant: assistant,
@@ -833,9 +880,10 @@ export const sendMessage =
}
} catch (error) {
console.error('Error in sendMessage thunk:', error)
} finally {
handleChangeLoadingOfTopic(topicId)
}
// finally {
// handleChangeLoadingOfTopic(topicId)
// }
}
/**
@@ -1069,9 +1117,10 @@ export const resendMessageThunk =
}
} catch (error) {
console.error(`[resendMessageThunk] Error resending user message ${userMessageToResend.id}:`, error)
} finally {
handleChangeLoadingOfTopic(topicId)
}
// finally {
// handleChangeLoadingOfTopic(topicId)
// }
}
/**
@@ -1179,10 +1228,11 @@ export const regenerateAssistantResponseThunk =
`[regenerateAssistantResponseThunk] Error regenerating response for assistant message ${assistantMessageToRegenerate.id}:`,
error
)
dispatch(newMessagesActions.setTopicLoading({ topicId, loading: false }))
} finally {
handleChangeLoadingOfTopic(topicId)
// dispatch(newMessagesActions.setTopicLoading({ topicId, loading: false }))
}
// finally {
// handleChangeLoadingOfTopic(topicId)
// }
}
// --- Thunk to initiate translation and create the initial block ---
@@ -1348,9 +1398,10 @@ export const appendAssistantResponseThunk =
console.error(`[appendAssistantResponseThunk] Error appending assistant response:`, error)
// Optionally dispatch an error action or notification
// Resetting loading state should be handled by the underlying fetchAndProcessAssistantResponseImpl
} finally {
handleChangeLoadingOfTopic(topicId)
}
// finally {
// handleChangeLoadingOfTopic(topicId)
// }
}
/**
+12 -3
View File
@@ -1,5 +1,6 @@
import { ExternalToolResult, KnowledgeReference, MCPToolResponse, WebSearchResponse } from '.'
import { ExternalToolResult, KnowledgeReference, MCPToolResponse, ToolUseResponse, WebSearchResponse } from '.'
import { Response, ResponseError } from './newMessage'
import { SdkToolCall } from './sdk'
// Define Enum for Chunk Types
// 目前用到的,并没有列出完整的生命周期
@@ -11,6 +12,7 @@ export enum ChunkType {
WEB_SEARCH_COMPLETE = 'web_search_complete',
KNOWLEDGE_SEARCH_IN_PROGRESS = 'knowledge_search_in_progress',
KNOWLEDGE_SEARCH_COMPLETE = 'knowledge_search_complete',
MCP_TOOL_CREATED = 'mcp_tool_created',
MCP_TOOL_IN_PROGRESS = 'mcp_tool_in_progress',
MCP_TOOL_COMPLETE = 'mcp_tool_complete',
EXTERNEL_TOOL_COMPLETE = 'externel_tool_complete',
@@ -118,7 +120,7 @@ export interface ImageDeltaChunk {
/**
* A chunk of Base64 encoded image data
*/
image: string
image: { type: 'base64'; images: string[] }
/**
* The type of the chunk
@@ -135,7 +137,7 @@ export interface ImageCompleteChunk {
/**
* The image content of the chunk
*/
image: { type: 'base64'; images: string[] }
image?: { type: 'base64'; images: string[] }
}
export interface ThinkingDeltaChunk {
@@ -253,6 +255,12 @@ export interface ExternalToolCompleteChunk {
type: ChunkType.EXTERNEL_TOOL_COMPLETE
}
export interface MCPToolCreatedChunk {
type: ChunkType.MCP_TOOL_CREATED
tool_calls?: SdkToolCall[] // 工具调用
tool_use_responses?: ToolUseResponse[] // 工具使用响应
}
export interface MCPToolInProgressChunk {
/**
* The type of the chunk
@@ -345,6 +353,7 @@ export type Chunk =
| WebSearchCompleteChunk // 互联网搜索完成
| KnowledgeSearchInProgressChunk // 知识库搜索进行中
| KnowledgeSearchCompleteChunk // 知识库搜索完成
| MCPToolCreatedChunk // MCP工具被大模型创建
| MCPToolInProgressChunk // MCP工具调用中
| MCPToolCompleteChunk // MCP工具调用完成
| ExternalToolCompleteChunk // 外部工具调用完成,外部工具包含搜索互联网,知识库,MCP服务器
+5 -4
View File
@@ -1,5 +1,5 @@
import type { WebSearchResultBlock } from '@anthropic-ai/sdk/resources'
import type { GenerateImagesConfig, GroundingMetadata } from '@google/genai'
import type { GenerateImagesConfig, GroundingMetadata, PersonGeneration } from '@google/genai'
import type OpenAI from 'openai'
import type { CSSProperties } from 'react'
@@ -455,10 +455,11 @@ export type GenerateImageParams = {
imageSize: string
batchSize: number
seed?: string
numInferenceSteps: number
guidanceScale: number
numInferenceSteps?: number
guidanceScale?: number
signal?: AbortSignal
promptEnhancement?: boolean
personGeneration?: PersonGeneration
}
export type GenerateImageResponse = {
@@ -531,7 +532,7 @@ export enum WebSearchSource {
}
export type WebSearchResponse = {
results: WebSearchResults
results?: WebSearchResults
source: WebSearchSource
}

Some files were not shown because too many files have changed in this diff Show More