Compare commits
34 Commits
v1.5.9
...
feat/claud
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
bd6d6bd56e | ||
|
|
7a23386de4 | ||
|
|
a227f6dcb9 | ||
|
|
9ff4acf092 | ||
|
|
128b1fe9bc | ||
|
|
9a92372c3e | ||
|
|
0a36869b3c | ||
|
|
a9a38f88bb | ||
|
|
aca1fcad18 | ||
|
|
24bc878c27 | ||
|
|
b1a9fbc6fd | ||
|
|
8a4c635c97 | ||
|
|
16d5f5c299 | ||
|
|
69a5a0434a | ||
|
|
6d1f3a5729 | ||
|
|
b725400428 | ||
|
|
9f7d2be463 | ||
|
|
fdee510c8c | ||
|
|
76ac1bd8f7 | ||
|
|
362658339a | ||
|
|
925d7e2a25 | ||
|
|
089477eb1e | ||
|
|
f153f77a7e | ||
|
|
a34141c912 | ||
|
|
94374e7de2 | ||
|
|
bdf6748956 | ||
|
|
d6dcb471f9 | ||
|
|
2c0391da81 | ||
|
|
77c2255da4 | ||
|
|
5ce7261678 | ||
|
|
001253b32d | ||
|
|
2480822690 | ||
|
|
16b9f49cc8 | ||
|
|
1295d37ff6 |
3
.gitignore
vendored
3
.gitignore
vendored
@@ -60,6 +60,9 @@ coverage
|
||||
.vitest-cache
|
||||
vitest.config.*.timestamp-*
|
||||
|
||||
# TypeScript incremental build
|
||||
.tsbuildinfo
|
||||
|
||||
# playwright
|
||||
playwright-report
|
||||
test-results
|
||||
|
||||
@@ -7,4 +7,5 @@ tsconfig.*.json
|
||||
CHANGELOG*.md
|
||||
agents.json
|
||||
src/renderer/src/integration/nutstore/sso/lib
|
||||
AGENT.md
|
||||
src/main/integration/cherryin/index.js
|
||||
|
||||
30
.yarn/patches/@napi-rs-system-ocr-npm-1.0.2-59e7a78e8b.patch
vendored
Normal file
30
.yarn/patches/@napi-rs-system-ocr-npm-1.0.2-59e7a78e8b.patch
vendored
Normal file
@@ -0,0 +1,30 @@
|
||||
diff --git a/index.js b/index.js
|
||||
index dc071739e79876dff88e1be06a9168e294222d13..b9df7525c62bdf777e89e732e1b0c81f84d872f2 100644
|
||||
--- a/index.js
|
||||
+++ b/index.js
|
||||
@@ -380,7 +380,7 @@ if (!nativeBinding || process.env.NAPI_RS_FORCE_WASI) {
|
||||
}
|
||||
}
|
||||
|
||||
-if (!nativeBinding) {
|
||||
+if (!nativeBinding && process.platform !== 'linux') {
|
||||
if (loadErrors.length > 0) {
|
||||
throw new Error(
|
||||
`Cannot find native binding. ` +
|
||||
@@ -392,6 +392,13 @@ if (!nativeBinding) {
|
||||
throw new Error(`Failed to load native binding`)
|
||||
}
|
||||
|
||||
-module.exports = nativeBinding
|
||||
-module.exports.OcrAccuracy = nativeBinding.OcrAccuracy
|
||||
-module.exports.recognize = nativeBinding.recognize
|
||||
+if (process.platform === 'linux') {
|
||||
+ module.exports = {OcrAccuracy: {
|
||||
+ Fast: 0,
|
||||
+ Accurate: 1
|
||||
+ }, recognize: () => Promise.resolve({text: '', confidence: 1.0})}
|
||||
+}else{
|
||||
+ module.exports = nativeBinding
|
||||
+ module.exports.OcrAccuracy = nativeBinding.OcrAccuracy
|
||||
+ module.exports.recognize = nativeBinding.recognize
|
||||
+}
|
||||
@@ -121,24 +121,12 @@ afterSign: scripts/notarize.js
|
||||
artifactBuildCompleted: scripts/artifact-build-completed.js
|
||||
releaseInfo:
|
||||
releaseNotes: |
|
||||
✨ 重要更新:
|
||||
- 新增笔记模块,支持富文本编辑和管理
|
||||
- 内置 GLM-4.5-Flash 免费模型(由智谱开放平台提供)
|
||||
- 内置 Qwen3-8B 免费模型(由硅基流动提供)
|
||||
- 新增 Nano Banana(Gemini 2.5 Flash Image)模型支持
|
||||
- 新增系统 OCR 功能 (macOS & Windows)
|
||||
- 新增图片 OCR 识别和翻译功能
|
||||
- 模型切换支持通过标签筛选
|
||||
- 翻译功能增强:历史搜索和收藏
|
||||
|
||||
🔧 性能优化:
|
||||
- 优化历史页面搜索性能
|
||||
- 优化拖拽列表组件交互
|
||||
- 升级 Electron 到 37.4.0
|
||||
- 优化AI服务连接方式,提升响应速度和稳定性
|
||||
- 改进模型列表获取功能,减少不必要的网络请求
|
||||
- 增强各AI服务商的兼容性和连接可靠性
|
||||
|
||||
🐛 修复问题:
|
||||
- 修复知识库加密 PDF 文档处理
|
||||
- 修复导航栏在左侧时笔记侧边栏按钮缺失
|
||||
- 修复多个模型兼容性问题
|
||||
- 修复 MCP 相关问题
|
||||
- 其他稳定性改进
|
||||
🐛 问题修复:
|
||||
- 修复部分AI服务商连接失败的问题
|
||||
- 修复模型配置加载时的潜在错误
|
||||
- 提升应用整体稳定性和容错能力
|
||||
|
||||
@@ -4,6 +4,8 @@ import { defineConfig, externalizeDepsPlugin } from 'electron-vite'
|
||||
import { resolve } from 'path'
|
||||
import { visualizer } from 'rollup-plugin-visualizer'
|
||||
|
||||
import pkg from './package.json' assert { type: 'json' }
|
||||
|
||||
const visualizerPlugin = (type: 'renderer' | 'main') => {
|
||||
return process.env[`VISUALIZER_${type.toUpperCase()}`] ? [visualizer({ open: true })] : []
|
||||
}
|
||||
@@ -21,12 +23,15 @@ export default defineConfig({
|
||||
'@shared': resolve('packages/shared'),
|
||||
'@logger': resolve('src/main/services/LoggerService'),
|
||||
'@mcp-trace/trace-core': resolve('packages/mcp-trace/trace-core'),
|
||||
'@mcp-trace/trace-node': resolve('packages/mcp-trace/trace-node')
|
||||
'@mcp-trace/trace-node': resolve('packages/mcp-trace/trace-node'),
|
||||
'@cherrystudio/ai-core/provider': resolve('packages/aiCore/src/core/providers'),
|
||||
'@cherrystudio/ai-core/built-in/plugins': resolve('packages/aiCore/src/core/plugins/built-in'),
|
||||
'@cherrystudio/ai-core': resolve('packages/aiCore/src')
|
||||
}
|
||||
},
|
||||
build: {
|
||||
rollupOptions: {
|
||||
external: ['@libsql/client', 'bufferutil', 'utf-8-validate'],
|
||||
external: ['bufferutil', 'utf-8-validate', 'electron', ...Object.keys(pkg.dependencies)],
|
||||
output: {
|
||||
manualChunks: undefined, // 彻底禁用代码分割 - 返回 null 强制单文件打包
|
||||
inlineDynamicImports: true // 内联所有动态导入,这是关键配置
|
||||
@@ -82,6 +87,9 @@ export default defineConfig({
|
||||
'@logger': resolve('src/renderer/src/services/LoggerService'),
|
||||
'@mcp-trace/trace-core': resolve('packages/mcp-trace/trace-core'),
|
||||
'@mcp-trace/trace-web': resolve('packages/mcp-trace/trace-web'),
|
||||
'@cherrystudio/ai-core/provider': resolve('packages/aiCore/src/core/providers'),
|
||||
'@cherrystudio/ai-core/built-in/plugins': resolve('packages/aiCore/src/core/plugins/built-in'),
|
||||
'@cherrystudio/ai-core': resolve('packages/aiCore/src'),
|
||||
'@cherrystudio/extension-table-plus': resolve('packages/extension-table-plus/src')
|
||||
}
|
||||
},
|
||||
|
||||
28
package.json
28
package.json
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "CherryStudio",
|
||||
"version": "1.5.9",
|
||||
"version": "1.6.0-beta.6",
|
||||
"private": true,
|
||||
"description": "A powerful AI assistant for producer.",
|
||||
"main": "./out/main/index.js",
|
||||
@@ -47,7 +47,7 @@
|
||||
"generate:icons": "electron-icon-builder --input=./build/logo.png --output=build",
|
||||
"analyze:renderer": "VISUALIZER_RENDERER=true yarn build",
|
||||
"analyze:main": "VISUALIZER_MAIN=true yarn build",
|
||||
"typecheck": "npm run typecheck:node && npm run typecheck:web",
|
||||
"typecheck": "concurrently -n \"node,web\" -c \"cyan,magenta\" \"npm run typecheck:node\" \"npm run typecheck:web\"",
|
||||
"typecheck:node": "tsc --noEmit -p tsconfig.node.json --composite false",
|
||||
"typecheck:web": "tsc --noEmit -p tsconfig.web.json --composite false",
|
||||
"check:i18n": "tsx scripts/check-i18n.ts",
|
||||
@@ -72,8 +72,10 @@
|
||||
"dependencies": {
|
||||
"@libsql/client": "0.14.0",
|
||||
"@libsql/win32-x64-msvc": "^0.4.7",
|
||||
"@napi-rs/system-ocr": "^1.0.2",
|
||||
"@napi-rs/system-ocr": "patch:@napi-rs/system-ocr@npm%3A1.0.2#~/.yarn/patches/@napi-rs-system-ocr-npm-1.0.2-59e7a78e8b.patch",
|
||||
"@strongtz/win32-arm64-msvc": "^0.4.7",
|
||||
"ai-sdk-provider-claude-code": "^1.1.3",
|
||||
"express": "^5.1.0",
|
||||
"graceful-fs": "^4.2.11",
|
||||
"jsdom": "26.1.0",
|
||||
"node-stream-zip": "^1.15.0",
|
||||
@@ -88,12 +90,16 @@
|
||||
"@agentic/exa": "^7.3.3",
|
||||
"@agentic/searxng": "^7.3.3",
|
||||
"@agentic/tavily": "^7.3.3",
|
||||
"@ai-sdk/amazon-bedrock": "^3.0.0",
|
||||
"@ai-sdk/google-vertex": "^3.0.0",
|
||||
"@ai-sdk/mistral": "^2.0.0",
|
||||
"@ant-design/v5-patch-for-react-19": "^1.0.3",
|
||||
"@anthropic-ai/sdk": "^0.41.0",
|
||||
"@anthropic-ai/vertex-sdk": "patch:@anthropic-ai/vertex-sdk@npm%3A0.11.4#~/.yarn/patches/@anthropic-ai-vertex-sdk-npm-0.11.4-c19cb41edb.patch",
|
||||
"@aws-sdk/client-bedrock": "^3.840.0",
|
||||
"@aws-sdk/client-bedrock-runtime": "^3.840.0",
|
||||
"@aws-sdk/client-s3": "^3.840.0",
|
||||
"@cherrystudio/ai-core": "workspace:*",
|
||||
"@cherrystudio/embedjs": "^0.1.31",
|
||||
"@cherrystudio/embedjs-libsql": "^0.1.31",
|
||||
"@cherrystudio/embedjs-loader-csv": "^0.1.31",
|
||||
@@ -129,6 +135,7 @@
|
||||
"@modelcontextprotocol/sdk": "^1.17.0",
|
||||
"@mozilla/readability": "^0.6.0",
|
||||
"@notionhq/client": "^2.2.15",
|
||||
"@openrouter/ai-sdk-provider": "^1.1.2",
|
||||
"@opentelemetry/api": "^1.9.0",
|
||||
"@opentelemetry/core": "2.0.0",
|
||||
"@opentelemetry/exporter-trace-otlp-http": "^0.200.0",
|
||||
@@ -138,7 +145,7 @@
|
||||
"@playwright/test": "^1.52.0",
|
||||
"@reduxjs/toolkit": "^2.2.5",
|
||||
"@shikijs/markdown-it": "^3.12.0",
|
||||
"@swc/plugin-styled-components": "^7.1.5",
|
||||
"@swc/plugin-styled-components": "^8.0.4",
|
||||
"@tanstack/react-query": "^5.85.5",
|
||||
"@tanstack/react-virtual": "^3.13.12",
|
||||
"@testing-library/dom": "^10.4.0",
|
||||
@@ -164,6 +171,7 @@
|
||||
"@truto/turndown-plugin-gfm": "^1.0.2",
|
||||
"@tryfabric/martian": "^1.2.4",
|
||||
"@types/cli-progress": "^3",
|
||||
"@types/express": "^5.0.3",
|
||||
"@types/fs-extra": "^11",
|
||||
"@types/he": "^1",
|
||||
"@types/lodash": "^4.17.5",
|
||||
@@ -189,6 +197,7 @@
|
||||
"@viz-js/lang-dot": "^1.0.5",
|
||||
"@viz-js/viz": "^3.14.0",
|
||||
"@xyflow/react": "^12.4.4",
|
||||
"ai": "^5.0.29",
|
||||
"antd": "patch:antd@npm%3A5.27.0#~/.yarn/patches/antd-npm-5.27.0-aa91c36546.patch",
|
||||
"archiver": "^7.0.1",
|
||||
"async-mutex": "^0.5.0",
|
||||
@@ -199,6 +208,7 @@
|
||||
"cli-progress": "^3.12.0",
|
||||
"code-inspector-plugin": "^0.20.14",
|
||||
"color": "^5.0.0",
|
||||
"concurrently": "^9.2.1",
|
||||
"country-flag-emoji-polyfill": "0.1.8",
|
||||
"dayjs": "^1.11.11",
|
||||
"dexie": "^4.0.8",
|
||||
@@ -328,7 +338,13 @@
|
||||
"pkce-challenge@npm:^4.1.0": "patch:pkce-challenge@npm%3A4.1.0#~/.yarn/patches/pkce-challenge-npm-4.1.0-fbc51695a3.patch",
|
||||
"undici": "6.21.2",
|
||||
"vite": "npm:rolldown-vite@latest",
|
||||
"tesseract.js@npm:*": "patch:tesseract.js@npm%3A6.0.1#~/.yarn/patches/tesseract.js-npm-6.0.1-2562a7e46d.patch"
|
||||
"tesseract.js@npm:*": "patch:tesseract.js@npm%3A6.0.1#~/.yarn/patches/tesseract.js-npm-6.0.1-2562a7e46d.patch",
|
||||
"@img/sharp-darwin-arm64": "0.34.3",
|
||||
"@img/sharp-darwin-x64": "0.34.3",
|
||||
"@img/sharp-linux-arm": "0.34.3",
|
||||
"@img/sharp-linux-arm64": "0.34.3",
|
||||
"@img/sharp-linux-x64": "0.34.3",
|
||||
"@img/sharp-win32-x64": "0.34.3"
|
||||
},
|
||||
"packageManager": "yarn@4.9.1",
|
||||
"lint-staged": {
|
||||
@@ -336,7 +352,7 @@
|
||||
"prettier --write",
|
||||
"eslint --fix"
|
||||
],
|
||||
"*.{json,md,yml,yaml,css,scss,html}": [
|
||||
"*.{json,yml,yaml,css,scss,html}": [
|
||||
"prettier --write"
|
||||
]
|
||||
}
|
||||
|
||||
514
packages/aiCore/AI_SDK_ARCHITECTURE.md
Normal file
514
packages/aiCore/AI_SDK_ARCHITECTURE.md
Normal file
@@ -0,0 +1,514 @@
|
||||
# AI Core 基于 Vercel AI SDK 的技术架构
|
||||
|
||||
## 1. 架构设计理念
|
||||
|
||||
### 1.1 设计目标
|
||||
|
||||
- **简化分层**:`models`(模型层)→ `runtime`(运行时层),清晰的职责分离
|
||||
- **统一接口**:使用 Vercel AI SDK 统一不同 AI Provider 的接口差异
|
||||
- **动态导入**:通过动态导入实现按需加载,减少打包体积
|
||||
- **最小包装**:直接使用 AI SDK 的类型和接口,避免重复定义
|
||||
- **插件系统**:基于钩子的通用插件架构,支持请求全生命周期扩展
|
||||
- **类型安全**:利用 TypeScript 和 AI SDK 的类型系统确保类型安全
|
||||
- **轻量级**:专注核心功能,保持包的轻量和高效
|
||||
- **包级独立**:作为独立包管理,便于复用和维护
|
||||
- **Agent就绪**:为将来集成 OpenAI Agents SDK 预留扩展空间
|
||||
|
||||
### 1.2 核心优势
|
||||
|
||||
- **标准化**:AI SDK 提供统一的模型接口,减少适配工作
|
||||
- **简化设计**:函数式API,避免过度抽象
|
||||
- **更好的开发体验**:完整的 TypeScript 支持和丰富的生态系统
|
||||
- **性能优化**:AI SDK 内置优化和最佳实践
|
||||
- **模块化设计**:独立包结构,支持跨项目复用
|
||||
- **可扩展插件**:通用的流转换和参数处理插件系统
|
||||
- **面向未来**:为 OpenAI Agents SDK 集成做好准备
|
||||
|
||||
## 2. 整体架构图
|
||||
|
||||
```mermaid
|
||||
graph TD
|
||||
subgraph "用户应用 (如 Cherry Studio)"
|
||||
UI["用户界面"]
|
||||
Components["应用组件"]
|
||||
end
|
||||
|
||||
subgraph "packages/aiCore (AI Core 包)"
|
||||
subgraph "Runtime Layer (运行时层)"
|
||||
RuntimeExecutor["RuntimeExecutor (运行时执行器)"]
|
||||
PluginEngine["PluginEngine (插件引擎)"]
|
||||
RuntimeAPI["Runtime API (便捷函数)"]
|
||||
end
|
||||
|
||||
subgraph "Models Layer (模型层)"
|
||||
ModelFactory["createModel() (模型工厂)"]
|
||||
ProviderCreator["ProviderCreator (提供商创建器)"]
|
||||
end
|
||||
|
||||
subgraph "Core Systems (核心系统)"
|
||||
subgraph "Plugins (插件)"
|
||||
PluginManager["PluginManager (插件管理)"]
|
||||
BuiltInPlugins["Built-in Plugins (内置插件)"]
|
||||
StreamTransforms["Stream Transforms (流转换)"]
|
||||
end
|
||||
|
||||
subgraph "Middleware (中间件)"
|
||||
MiddlewareWrapper["wrapModelWithMiddlewares() (中间件包装)"]
|
||||
end
|
||||
|
||||
subgraph "Providers (提供商)"
|
||||
Registry["Provider Registry (注册表)"]
|
||||
Factory["Provider Factory (工厂)"]
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
subgraph "Vercel AI SDK"
|
||||
AICore["ai (核心库)"]
|
||||
OpenAI["@ai-sdk/openai"]
|
||||
Anthropic["@ai-sdk/anthropic"]
|
||||
Google["@ai-sdk/google"]
|
||||
XAI["@ai-sdk/xai"]
|
||||
Others["其他 19+ Providers"]
|
||||
end
|
||||
|
||||
subgraph "Future: OpenAI Agents SDK"
|
||||
AgentSDK["@openai/agents (未来集成)"]
|
||||
AgentExtensions["Agent Extensions (预留)"]
|
||||
end
|
||||
|
||||
UI --> RuntimeAPI
|
||||
Components --> RuntimeExecutor
|
||||
RuntimeAPI --> RuntimeExecutor
|
||||
RuntimeExecutor --> PluginEngine
|
||||
RuntimeExecutor --> ModelFactory
|
||||
PluginEngine --> PluginManager
|
||||
ModelFactory --> ProviderCreator
|
||||
ModelFactory --> MiddlewareWrapper
|
||||
ProviderCreator --> Registry
|
||||
Registry --> Factory
|
||||
Factory --> OpenAI
|
||||
Factory --> Anthropic
|
||||
Factory --> Google
|
||||
Factory --> XAI
|
||||
Factory --> Others
|
||||
|
||||
RuntimeExecutor --> AICore
|
||||
AICore --> streamText
|
||||
AICore --> generateText
|
||||
AICore --> streamObject
|
||||
AICore --> generateObject
|
||||
|
||||
PluginManager --> StreamTransforms
|
||||
PluginManager --> BuiltInPlugins
|
||||
|
||||
%% 未来集成路径
|
||||
RuntimeExecutor -.-> AgentSDK
|
||||
AgentSDK -.-> AgentExtensions
|
||||
```
|
||||
|
||||
## 3. 包结构设计
|
||||
|
||||
### 3.1 新架构文件结构
|
||||
|
||||
```
|
||||
packages/aiCore/
|
||||
├── src/
|
||||
│ ├── core/ # 核心层 - 内部实现
|
||||
│ │ ├── models/ # 模型层 - 模型创建和配置
|
||||
│ │ │ ├── factory.ts # 模型工厂函数 ✅
|
||||
│ │ │ ├── ModelCreator.ts # 模型创建器 ✅
|
||||
│ │ │ ├── ConfigManager.ts # 配置管理器 ✅
|
||||
│ │ │ ├── types.ts # 模型类型定义 ✅
|
||||
│ │ │ └── index.ts # 模型层导出 ✅
|
||||
│ │ ├── runtime/ # 运行时层 - 执行和用户API
|
||||
│ │ │ ├── executor.ts # 运行时执行器 ✅
|
||||
│ │ │ ├── pluginEngine.ts # 插件引擎 ✅
|
||||
│ │ │ ├── types.ts # 运行时类型定义 ✅
|
||||
│ │ │ └── index.ts # 运行时导出 ✅
|
||||
│ │ ├── middleware/ # 中间件系统
|
||||
│ │ │ ├── wrapper.ts # 模型包装器 ✅
|
||||
│ │ │ ├── manager.ts # 中间件管理器 ✅
|
||||
│ │ │ ├── types.ts # 中间件类型 ✅
|
||||
│ │ │ └── index.ts # 中间件导出 ✅
|
||||
│ │ ├── plugins/ # 插件系统
|
||||
│ │ │ ├── types.ts # 插件类型定义 ✅
|
||||
│ │ │ ├── manager.ts # 插件管理器 ✅
|
||||
│ │ │ ├── built-in/ # 内置插件 ✅
|
||||
│ │ │ │ ├── logging.ts # 日志插件 ✅
|
||||
│ │ │ │ ├── webSearchPlugin/ # 网络搜索插件 ✅
|
||||
│ │ │ │ ├── toolUsePlugin/ # 工具使用插件 ✅
|
||||
│ │ │ │ └── index.ts # 内置插件导出 ✅
|
||||
│ │ │ ├── README.md # 插件文档 ✅
|
||||
│ │ │ └── index.ts # 插件导出 ✅
|
||||
│ │ ├── providers/ # 提供商管理
|
||||
│ │ │ ├── registry.ts # 提供商注册表 ✅
|
||||
│ │ │ ├── factory.ts # 提供商工厂 ✅
|
||||
│ │ │ ├── creator.ts # 提供商创建器 ✅
|
||||
│ │ │ ├── types.ts # 提供商类型 ✅
|
||||
│ │ │ ├── utils.ts # 工具函数 ✅
|
||||
│ │ │ └── index.ts # 提供商导出 ✅
|
||||
│ │ ├── options/ # 配置选项
|
||||
│ │ │ ├── factory.ts # 选项工厂 ✅
|
||||
│ │ │ ├── types.ts # 选项类型 ✅
|
||||
│ │ │ ├── xai.ts # xAI 选项 ✅
|
||||
│ │ │ ├── openrouter.ts # OpenRouter 选项 ✅
|
||||
│ │ │ ├── examples.ts # 示例配置 ✅
|
||||
│ │ │ └── index.ts # 选项导出 ✅
|
||||
│ │ └── index.ts # 核心层导出 ✅
|
||||
│ ├── types.ts # 全局类型定义 ✅
|
||||
│ └── index.ts # 包主入口文件 ✅
|
||||
├── package.json # 包配置文件 ✅
|
||||
├── tsconfig.json # TypeScript 配置 ✅
|
||||
├── README.md # 包说明文档 ✅
|
||||
└── AI_SDK_ARCHITECTURE.md # 本文档 ✅
|
||||
```
|
||||
|
||||
## 4. 架构分层详解
|
||||
|
||||
### 4.1 Models Layer (模型层)
|
||||
|
||||
**职责**:统一的模型创建和配置管理
|
||||
|
||||
**核心文件**:
|
||||
|
||||
- `factory.ts`: 模型工厂函数 (`createModel`, `createModels`)
|
||||
- `ProviderCreator.ts`: 底层提供商创建和模型实例化
|
||||
- `types.ts`: 模型配置类型定义
|
||||
|
||||
**设计特点**:
|
||||
|
||||
- 函数式设计,避免不必要的类抽象
|
||||
- 统一的模型配置接口
|
||||
- 自动处理中间件应用
|
||||
- 支持批量模型创建
|
||||
|
||||
**核心API**:
|
||||
|
||||
```typescript
|
||||
// 模型配置接口
|
||||
export interface ModelConfig {
|
||||
providerId: ProviderId
|
||||
modelId: string
|
||||
options: ProviderSettingsMap[ProviderId]
|
||||
middlewares?: LanguageModelV1Middleware[]
|
||||
}
|
||||
|
||||
// 核心模型创建函数
|
||||
export async function createModel(config: ModelConfig): Promise<LanguageModel>
|
||||
export async function createModels(configs: ModelConfig[]): Promise<LanguageModel[]>
|
||||
```
|
||||
|
||||
### 4.2 Runtime Layer (运行时层)
|
||||
|
||||
**职责**:运行时执行器和用户面向的API接口
|
||||
|
||||
**核心组件**:
|
||||
|
||||
- `executor.ts`: 运行时执行器类
|
||||
- `plugin-engine.ts`: 插件引擎(原PluginEnabledAiClient)
|
||||
- `index.ts`: 便捷函数和工厂方法
|
||||
|
||||
**设计特点**:
|
||||
|
||||
- 提供三种使用方式:类实例、静态工厂、函数式调用
|
||||
- 自动集成模型创建和插件处理
|
||||
- 完整的类型安全支持
|
||||
- 为 OpenAI Agents SDK 预留扩展接口
|
||||
|
||||
**核心API**:
|
||||
|
||||
```typescript
|
||||
// 运行时执行器
|
||||
export class RuntimeExecutor<T extends ProviderId = ProviderId> {
|
||||
static create<T extends ProviderId>(
|
||||
providerId: T,
|
||||
options: ProviderSettingsMap[T],
|
||||
plugins?: AiPlugin[]
|
||||
): RuntimeExecutor<T>
|
||||
|
||||
async streamText(modelId: string, params: StreamTextParams): Promise<StreamTextResult>
|
||||
async generateText(modelId: string, params: GenerateTextParams): Promise<GenerateTextResult>
|
||||
async streamObject(modelId: string, params: StreamObjectParams): Promise<StreamObjectResult>
|
||||
async generateObject(modelId: string, params: GenerateObjectParams): Promise<GenerateObjectResult>
|
||||
}
|
||||
|
||||
// 便捷函数式API
|
||||
export async function streamText<T extends ProviderId>(
|
||||
providerId: T,
|
||||
options: ProviderSettingsMap[T],
|
||||
modelId: string,
|
||||
params: StreamTextParams,
|
||||
plugins?: AiPlugin[]
|
||||
): Promise<StreamTextResult>
|
||||
```
|
||||
|
||||
### 4.3 Plugin System (插件系统)
|
||||
|
||||
**职责**:可扩展的插件架构
|
||||
|
||||
**核心组件**:
|
||||
|
||||
- `PluginManager`: 插件生命周期管理
|
||||
- `built-in/`: 内置插件集合
|
||||
- 流转换收集和应用
|
||||
|
||||
**设计特点**:
|
||||
|
||||
- 借鉴 Rollup 的钩子分类设计
|
||||
- 支持流转换 (`experimental_transform`)
|
||||
- 内置常用插件(日志、计数等)
|
||||
- 完整的生命周期钩子
|
||||
|
||||
**插件接口**:
|
||||
|
||||
```typescript
|
||||
export interface AiPlugin {
|
||||
name: string
|
||||
enforce?: 'pre' | 'post'
|
||||
|
||||
// 【First】首个钩子 - 只执行第一个返回值的插件
|
||||
resolveModel?: (modelId: string, context: AiRequestContext) => string | null | Promise<string | null>
|
||||
loadTemplate?: (templateName: string, context: AiRequestContext) => any | null | Promise<any | null>
|
||||
|
||||
// 【Sequential】串行钩子 - 链式执行,支持数据转换
|
||||
transformParams?: (params: any, context: AiRequestContext) => any | Promise<any>
|
||||
transformResult?: (result: any, context: AiRequestContext) => any | Promise<any>
|
||||
|
||||
// 【Parallel】并行钩子 - 不依赖顺序,用于副作用
|
||||
onRequestStart?: (context: AiRequestContext) => void | Promise<void>
|
||||
onRequestEnd?: (context: AiRequestContext, result: any) => void | Promise<void>
|
||||
onError?: (error: Error, context: AiRequestContext) => void | Promise<void>
|
||||
|
||||
// 【Stream】流处理
|
||||
transformStream?: () => TransformStream
|
||||
}
|
||||
```
|
||||
|
||||
### 4.4 Middleware System (中间件系统)
|
||||
|
||||
**职责**:AI SDK原生中间件支持
|
||||
|
||||
**核心组件**:
|
||||
|
||||
- `ModelWrapper.ts`: 模型包装函数
|
||||
|
||||
**设计哲学**:
|
||||
|
||||
- 直接使用AI SDK的 `wrapLanguageModel`
|
||||
- 与插件系统分离,职责明确
|
||||
- 函数式设计,简化使用
|
||||
|
||||
```typescript
|
||||
export function wrapModelWithMiddlewares(model: LanguageModel, middlewares: LanguageModelV1Middleware[]): LanguageModel
|
||||
```
|
||||
|
||||
### 4.5 Provider System (提供商系统)
|
||||
|
||||
**职责**:AI Provider注册表和动态导入
|
||||
|
||||
**核心组件**:
|
||||
|
||||
- `registry.ts`: 19+ Provider配置和类型
|
||||
- `factory.ts`: Provider配置工厂
|
||||
|
||||
**支持的Providers**:
|
||||
|
||||
- OpenAI, Anthropic, Google, XAI
|
||||
- Azure OpenAI, Amazon Bedrock, Google Vertex
|
||||
- Groq, Together.ai, Fireworks, DeepSeek
|
||||
- 等19+ AI SDK官方支持的providers
|
||||
|
||||
## 5. 使用方式
|
||||
|
||||
### 5.1 函数式调用 (推荐 - 简单场景)
|
||||
|
||||
```typescript
|
||||
import { streamText, generateText } from '@cherrystudio/ai-core/runtime'
|
||||
|
||||
// 直接函数调用
|
||||
const stream = await streamText(
|
||||
'anthropic',
|
||||
{ apiKey: 'your-api-key' },
|
||||
'claude-3',
|
||||
{ messages: [{ role: 'user', content: 'Hello!' }] },
|
||||
[loggingPlugin]
|
||||
)
|
||||
```
|
||||
|
||||
### 5.2 执行器实例 (推荐 - 复杂场景)
|
||||
|
||||
```typescript
|
||||
import { createExecutor } from '@cherrystudio/ai-core/runtime'
|
||||
|
||||
// 创建可复用的执行器
|
||||
const executor = createExecutor('openai', { apiKey: 'your-api-key' }, [plugin1, plugin2])
|
||||
|
||||
// 多次使用
|
||||
const stream = await executor.streamText('gpt-4', {
|
||||
messages: [{ role: 'user', content: 'Hello!' }]
|
||||
})
|
||||
|
||||
const result = await executor.generateText('gpt-4', {
|
||||
messages: [{ role: 'user', content: 'How are you?' }]
|
||||
})
|
||||
```
|
||||
|
||||
### 5.3 静态工厂方法
|
||||
|
||||
```typescript
|
||||
import { RuntimeExecutor } from '@cherrystudio/ai-core/runtime'
|
||||
|
||||
// 静态创建
|
||||
const executor = RuntimeExecutor.create('anthropic', { apiKey: 'your-api-key' })
|
||||
await executor.streamText('claude-3', { messages: [...] })
|
||||
```
|
||||
|
||||
### 5.4 直接模型创建 (高级用法)
|
||||
|
||||
```typescript
|
||||
import { createModel } from '@cherrystudio/ai-core/models'
|
||||
import { streamText } from 'ai'
|
||||
|
||||
// 直接创建模型使用
|
||||
const model = await createModel({
|
||||
providerId: 'openai',
|
||||
modelId: 'gpt-4',
|
||||
options: { apiKey: 'your-api-key' },
|
||||
middlewares: [middleware1, middleware2]
|
||||
})
|
||||
|
||||
// 直接使用 AI SDK
|
||||
const result = await streamText({ model, messages: [...] })
|
||||
```
|
||||
|
||||
## 6. 为 OpenAI Agents SDK 预留的设计
|
||||
|
||||
### 6.1 架构兼容性
|
||||
|
||||
当前架构完全兼容 OpenAI Agents SDK 的集成需求:
|
||||
|
||||
```typescript
|
||||
// 当前的模型创建
|
||||
const model = await createModel({
|
||||
providerId: 'anthropic',
|
||||
modelId: 'claude-3',
|
||||
options: { apiKey: 'xxx' }
|
||||
})
|
||||
|
||||
// 将来可以直接用于 OpenAI Agents SDK
|
||||
import { Agent, run } from '@openai/agents'
|
||||
|
||||
const agent = new Agent({
|
||||
model, // ✅ 直接兼容 LanguageModel 接口
|
||||
name: 'Assistant',
|
||||
instructions: '...',
|
||||
tools: [tool1, tool2]
|
||||
})
|
||||
|
||||
const result = await run(agent, 'user input')
|
||||
```
|
||||
|
||||
### 6.2 预留的扩展点
|
||||
|
||||
1. **runtime/agents/** 目录预留
|
||||
2. **AgentExecutor** 类预留
|
||||
3. **Agent工具转换插件** 预留
|
||||
4. **多Agent编排** 预留
|
||||
|
||||
### 6.3 未来架构扩展
|
||||
|
||||
```
|
||||
packages/aiCore/src/core/
|
||||
├── runtime/
|
||||
│ ├── agents/ # 🚀 未来添加
|
||||
│ │ ├── AgentExecutor.ts
|
||||
│ │ ├── WorkflowManager.ts
|
||||
│ │ └── ConversationManager.ts
|
||||
│ ├── executor.ts
|
||||
│ └── index.ts
|
||||
```
|
||||
|
||||
## 7. 架构优势
|
||||
|
||||
### 7.1 简化设计
|
||||
|
||||
- **移除过度抽象**:删除了orchestration层和creation层的复杂包装
|
||||
- **函数式优先**:models层使用函数而非类
|
||||
- **直接明了**:runtime层直接提供用户API
|
||||
|
||||
### 7.2 职责清晰
|
||||
|
||||
- **Models**: 专注模型创建和配置
|
||||
- **Runtime**: 专注执行和用户API
|
||||
- **Plugins**: 专注扩展功能
|
||||
- **Providers**: 专注AI Provider管理
|
||||
|
||||
### 7.3 类型安全
|
||||
|
||||
- 完整的 TypeScript 支持
|
||||
- AI SDK 类型的直接复用
|
||||
- 避免类型重复定义
|
||||
|
||||
### 7.4 灵活使用
|
||||
|
||||
- 三种使用模式满足不同需求
|
||||
- 从简单函数调用到复杂执行器
|
||||
- 支持直接AI SDK使用
|
||||
|
||||
### 7.5 面向未来
|
||||
|
||||
- 为 OpenAI Agents SDK 集成做好准备
|
||||
- 清晰的扩展点和架构边界
|
||||
- 模块化设计便于功能添加
|
||||
|
||||
## 8. 技术决策记录
|
||||
|
||||
### 8.1 为什么选择简化的两层架构?
|
||||
|
||||
- **职责分离**:models专注创建,runtime专注执行
|
||||
- **模块化**:每层都有清晰的边界和职责
|
||||
- **扩展性**:为Agent功能预留了清晰的扩展空间
|
||||
|
||||
### 8.2 为什么选择函数式设计?
|
||||
|
||||
- **简洁性**:避免不必要的类设计
|
||||
- **性能**:减少对象创建开销
|
||||
- **易用性**:函数调用更直观
|
||||
|
||||
### 8.3 为什么分离插件和中间件?
|
||||
|
||||
- **职责明确**: 插件处理应用特定需求
|
||||
- **原生支持**: 中间件使用AI SDK原生功能
|
||||
- **灵活性**: 两套系统可以独立演进
|
||||
|
||||
## 9. 总结
|
||||
|
||||
AI Core架构实现了:
|
||||
|
||||
### 9.1 核心特点
|
||||
|
||||
- ✅ **简化架构**: 2层核心架构,职责清晰
|
||||
- ✅ **函数式设计**: models层完全函数化
|
||||
- ✅ **类型安全**: 统一的类型定义和AI SDK类型复用
|
||||
- ✅ **插件扩展**: 强大的插件系统
|
||||
- ✅ **多种使用方式**: 满足不同复杂度需求
|
||||
- ✅ **Agent就绪**: 为OpenAI Agents SDK集成做好准备
|
||||
|
||||
### 9.2 核心价值
|
||||
|
||||
- **统一接口**: 一套API支持19+ AI providers
|
||||
- **灵活使用**: 函数式、实例式、静态工厂式
|
||||
- **强类型**: 完整的TypeScript支持
|
||||
- **可扩展**: 插件和中间件双重扩展能力
|
||||
- **高性能**: 最小化包装,直接使用AI SDK
|
||||
- **面向未来**: Agent SDK集成架构就绪
|
||||
|
||||
### 9.3 未来发展
|
||||
|
||||
这个架构提供了:
|
||||
|
||||
- **优秀的开发体验**: 简洁的API和清晰的使用模式
|
||||
- **强大的扩展能力**: 为Agent功能预留了完整的架构空间
|
||||
- **良好的维护性**: 职责分离明确,代码易于维护
|
||||
- **广泛的适用性**: 既适合简单调用也适合复杂应用
|
||||
433
packages/aiCore/README.md
Normal file
433
packages/aiCore/README.md
Normal file
@@ -0,0 +1,433 @@
|
||||
# @cherrystudio/ai-core
|
||||
|
||||
Cherry Studio AI Core 是一个基于 Vercel AI SDK 的统一 AI Provider 接口包,为 AI 应用提供强大的抽象层和插件化架构。
|
||||
|
||||
## ✨ 核心亮点
|
||||
|
||||
### 🏗️ 优雅的架构设计
|
||||
|
||||
- **简化分层**:`models`(模型层)→ `runtime`(运行时层),清晰的职责分离
|
||||
- **函数式优先**:避免过度抽象,提供简洁直观的 API
|
||||
- **类型安全**:完整的 TypeScript 支持,直接复用 AI SDK 类型系统
|
||||
- **最小包装**:直接使用 AI SDK 的接口,避免重复定义和性能损耗
|
||||
|
||||
### 🔌 强大的插件系统
|
||||
|
||||
- **生命周期钩子**:支持请求全生命周期的扩展点
|
||||
- **流转换支持**:基于 AI SDK 的 `experimental_transform` 实现流处理
|
||||
- **插件分类**:First、Sequential、Parallel 三种钩子类型,满足不同场景
|
||||
- **内置插件**:webSearch、logging、toolUse 等开箱即用的功能
|
||||
|
||||
### 🌐 统一多 Provider 接口
|
||||
|
||||
- **扩展注册**:支持自定义 Provider 注册,无限扩展能力
|
||||
- **配置统一**:统一的配置接口,简化多 Provider 管理
|
||||
|
||||
### 🚀 多种使用方式
|
||||
|
||||
- **函数式调用**:适合简单场景的直接函数调用
|
||||
- **执行器实例**:适合复杂场景的可复用执行器
|
||||
- **静态工厂**:便捷的静态创建方法
|
||||
- **原生兼容**:完全兼容 AI SDK 原生 Provider Registry
|
||||
|
||||
### 🔮 面向未来
|
||||
|
||||
- **Agent 就绪**:为 OpenAI Agents SDK 集成预留架构空间
|
||||
- **模块化设计**:独立包结构,支持跨项目复用
|
||||
- **渐进式迁移**:可以逐步从现有 AI SDK 代码迁移
|
||||
|
||||
## 特性
|
||||
|
||||
- 🚀 统一的 AI Provider 接口
|
||||
- 🔄 动态导入支持
|
||||
- 🛠️ TypeScript 支持
|
||||
- 📦 强大的插件系统
|
||||
- 🌍 内置webSearch(Openai,Google,Anthropic,xAI)
|
||||
- 🎯 多种使用模式(函数式/实例式/静态工厂)
|
||||
- 🔌 可扩展的 Provider 注册系统
|
||||
- 🧩 完整的中间件支持
|
||||
- 📊 插件统计和调试功能
|
||||
|
||||
## 支持的 Providers
|
||||
|
||||
基于 [AI SDK 官方支持的 providers](https://ai-sdk.dev/providers/ai-sdk-providers):
|
||||
|
||||
**核心 Providers(内置支持):**
|
||||
|
||||
- OpenAI
|
||||
- Anthropic
|
||||
- Google Generative AI
|
||||
- OpenAI-Compatible
|
||||
- xAI (Grok)
|
||||
- Azure OpenAI
|
||||
- DeepSeek
|
||||
|
||||
**扩展 Providers(通过注册API支持):**
|
||||
|
||||
- Google Vertex AI
|
||||
- ...
|
||||
- 自定义 Provider
|
||||
|
||||
## 安装
|
||||
|
||||
```bash
|
||||
npm install @cherrystudio/ai-core ai
|
||||
```
|
||||
|
||||
### React Native
|
||||
|
||||
如果你在 React Native 项目中使用此包,需要在 `metro.config.js` 中添加以下配置:
|
||||
|
||||
```javascript
|
||||
// metro.config.js
|
||||
const { getDefaultConfig } = require('expo/metro-config')
|
||||
|
||||
const config = getDefaultConfig(__dirname)
|
||||
|
||||
// 添加对 @cherrystudio/ai-core 的支持
|
||||
config.resolver.resolverMainFields = ['react-native', 'browser', 'main']
|
||||
config.resolver.platforms = ['ios', 'android', 'native', 'web']
|
||||
|
||||
module.exports = config
|
||||
```
|
||||
|
||||
还需要安装你要使用的 AI SDK provider:
|
||||
|
||||
```bash
|
||||
npm install @ai-sdk/openai @ai-sdk/anthropic @ai-sdk/google
|
||||
```
|
||||
|
||||
## 使用示例
|
||||
|
||||
### 基础用法
|
||||
|
||||
```typescript
|
||||
import { AiCore } from '@cherrystudio/ai-core'
|
||||
|
||||
// 创建 OpenAI executor
|
||||
const executor = AiCore.create('openai', {
|
||||
apiKey: 'your-api-key'
|
||||
})
|
||||
|
||||
// 流式生成
|
||||
const result = await executor.streamText('gpt-4', {
|
||||
messages: [{ role: 'user', content: 'Hello!' }]
|
||||
})
|
||||
|
||||
// 非流式生成
|
||||
const response = await executor.generateText('gpt-4', {
|
||||
messages: [{ role: 'user', content: 'Hello!' }]
|
||||
})
|
||||
```
|
||||
|
||||
### 便捷函数
|
||||
|
||||
```typescript
|
||||
import { createOpenAIExecutor } from '@cherrystudio/ai-core'
|
||||
|
||||
// 快速创建 OpenAI executor
|
||||
const executor = createOpenAIExecutor({
|
||||
apiKey: 'your-api-key'
|
||||
})
|
||||
|
||||
// 使用 executor
|
||||
const result = await executor.streamText('gpt-4', {
|
||||
messages: [{ role: 'user', content: 'Hello!' }]
|
||||
})
|
||||
```
|
||||
|
||||
### 多 Provider 支持
|
||||
|
||||
```typescript
|
||||
import { AiCore } from '@cherrystudio/ai-core'
|
||||
|
||||
// 支持多种 AI providers
|
||||
const openaiExecutor = AiCore.create('openai', { apiKey: 'openai-key' })
|
||||
const anthropicExecutor = AiCore.create('anthropic', { apiKey: 'anthropic-key' })
|
||||
const googleExecutor = AiCore.create('google', { apiKey: 'google-key' })
|
||||
const xaiExecutor = AiCore.create('xai', { apiKey: 'xai-key' })
|
||||
```
|
||||
|
||||
### 扩展 Provider 注册
|
||||
|
||||
对于非内置的 providers,可以通过注册 API 扩展支持:
|
||||
|
||||
```typescript
|
||||
import { registerProvider, AiCore } from '@cherrystudio/ai-core'
|
||||
|
||||
// 方式一:导入并注册第三方 provider
|
||||
import { createGroq } from '@ai-sdk/groq'
|
||||
|
||||
registerProvider({
|
||||
id: 'groq',
|
||||
name: 'Groq',
|
||||
creator: createGroq,
|
||||
supportsImageGeneration: false
|
||||
})
|
||||
|
||||
// 现在可以使用 Groq
|
||||
const groqExecutor = AiCore.create('groq', { apiKey: 'groq-key' })
|
||||
|
||||
// 方式二:动态导入方式注册
|
||||
registerProvider({
|
||||
id: 'mistral',
|
||||
name: 'Mistral AI',
|
||||
import: () => import('@ai-sdk/mistral'),
|
||||
creatorFunctionName: 'createMistral'
|
||||
})
|
||||
|
||||
const mistralExecutor = AiCore.create('mistral', { apiKey: 'mistral-key' })
|
||||
```
|
||||
|
||||
## 🔌 插件系统
|
||||
|
||||
AI Core 提供了强大的插件系统,支持请求全生命周期的扩展。
|
||||
|
||||
### 内置插件
|
||||
|
||||
#### webSearchPlugin - 网络搜索插件
|
||||
|
||||
为不同 AI Provider 提供统一的网络搜索能力:
|
||||
|
||||
```typescript
|
||||
import { webSearchPlugin } from '@cherrystudio/ai-core/built-in/plugins'
|
||||
|
||||
const executor = AiCore.create('openai', { apiKey: 'your-key' }, [
|
||||
webSearchPlugin({
|
||||
openai: {
|
||||
/* OpenAI 搜索配置 */
|
||||
},
|
||||
anthropic: { maxUses: 5 },
|
||||
google: {
|
||||
/* Google 搜索配置 */
|
||||
},
|
||||
xai: {
|
||||
mode: 'on',
|
||||
returnCitations: true,
|
||||
maxSearchResults: 5,
|
||||
sources: [{ type: 'web' }, { type: 'x' }, { type: 'news' }]
|
||||
}
|
||||
})
|
||||
])
|
||||
```
|
||||
|
||||
#### loggingPlugin - 日志插件
|
||||
|
||||
提供详细的请求日志记录:
|
||||
|
||||
```typescript
|
||||
import { createLoggingPlugin } from '@cherrystudio/ai-core/built-in/plugins'
|
||||
|
||||
const executor = AiCore.create('openai', { apiKey: 'your-key' }, [
|
||||
createLoggingPlugin({
|
||||
logLevel: 'info',
|
||||
includeParams: true,
|
||||
includeResult: false
|
||||
})
|
||||
])
|
||||
```
|
||||
|
||||
#### promptToolUsePlugin - 提示工具使用插件
|
||||
|
||||
为不支持原生 Function Call 的模型提供 prompt 方式的工具调用:
|
||||
|
||||
```typescript
|
||||
import { createPromptToolUsePlugin } from '@cherrystudio/ai-core/built-in/plugins'
|
||||
|
||||
// 对于不支持 function call 的模型
|
||||
const executor = AiCore.create(
|
||||
'providerId',
|
||||
{
|
||||
apiKey: 'your-key',
|
||||
baseURL: 'https://your-model-endpoint'
|
||||
},
|
||||
[
|
||||
createPromptToolUsePlugin({
|
||||
enabled: true,
|
||||
// 可选:自定义系统提示符构建
|
||||
buildSystemPrompt: (userPrompt, tools) => {
|
||||
return `${userPrompt}\n\nAvailable tools: ${Object.keys(tools).join(', ')}`
|
||||
}
|
||||
})
|
||||
]
|
||||
)
|
||||
```
|
||||
|
||||
### 自定义插件
|
||||
|
||||
创建自定义插件非常简单:
|
||||
|
||||
```typescript
|
||||
import { definePlugin } from '@cherrystudio/ai-core'
|
||||
|
||||
const customPlugin = definePlugin({
|
||||
name: 'custom-plugin',
|
||||
enforce: 'pre', // 'pre' | 'post' | undefined
|
||||
|
||||
// 在请求开始时记录日志
|
||||
onRequestStart: async (context) => {
|
||||
console.log(`Starting request for model: ${context.modelId}`)
|
||||
},
|
||||
|
||||
// 转换请求参数
|
||||
transformParams: async (params, context) => {
|
||||
// 添加自定义系统消息
|
||||
if (params.messages) {
|
||||
params.messages.unshift({
|
||||
role: 'system',
|
||||
content: 'You are a helpful assistant.'
|
||||
})
|
||||
}
|
||||
return params
|
||||
},
|
||||
|
||||
// 处理响应结果
|
||||
transformResult: async (result, context) => {
|
||||
// 添加元数据
|
||||
if (result.text) {
|
||||
result.metadata = {
|
||||
processedAt: new Date().toISOString(),
|
||||
modelId: context.modelId
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
})
|
||||
|
||||
// 使用自定义插件
|
||||
const executor = AiCore.create('openai', { apiKey: 'your-key' }, [customPlugin])
|
||||
```
|
||||
|
||||
### 使用 AI SDK 原生 Provider 注册表
|
||||
|
||||
> https://ai-sdk.dev/docs/reference/ai-sdk-core/provider-registry
|
||||
|
||||
除了使用内建的 provider 管理,你还可以使用 AI SDK 原生的 `createProviderRegistry` 来构建自己的 provider 注册表。
|
||||
|
||||
#### 基本用法示例
|
||||
|
||||
```typescript
|
||||
import { createClient } from '@cherrystudio/ai-core'
|
||||
import { createProviderRegistry } from 'ai'
|
||||
import { createOpenAI } from '@ai-sdk/openai'
|
||||
import { anthropic } from '@ai-sdk/anthropic'
|
||||
|
||||
// 1. 创建 AI SDK 原生注册表
|
||||
export const registry = createProviderRegistry({
|
||||
// register provider with prefix and default setup:
|
||||
anthropic,
|
||||
|
||||
// register provider with prefix and custom setup:
|
||||
openai: createOpenAI({
|
||||
apiKey: process.env.OPENAI_API_KEY
|
||||
})
|
||||
})
|
||||
|
||||
// 2. 创建client,'openai'可以传空或者传providerId(内建的provider)
|
||||
const client = createClient('openai', {
|
||||
apiKey: process.env.OPENAI_API_KEY
|
||||
})
|
||||
|
||||
// 3. 方式1:使用内建逻辑(传统方式)
|
||||
const result1 = await client.streamText('gpt-4', {
|
||||
messages: [{ role: 'user', content: 'Hello with built-in logic!' }]
|
||||
})
|
||||
|
||||
// 4. 方式2:使用自定义注册表(灵活方式)
|
||||
const result2 = await client.streamText({
|
||||
model: registry.languageModel('openai:gpt-4'),
|
||||
messages: [{ role: 'user', content: 'Hello with custom registry!' }]
|
||||
})
|
||||
|
||||
// 5. 支持的重载方法
|
||||
await client.generateObject({
|
||||
model: registry.languageModel('openai:gpt-4'),
|
||||
schema: z.object({ name: z.string() }),
|
||||
messages: [{ role: 'user', content: 'Generate a user' }]
|
||||
})
|
||||
|
||||
await client.streamObject({
|
||||
model: registry.languageModel('anthropic:claude-3-opus-20240229'),
|
||||
schema: z.object({ items: z.array(z.string()) }),
|
||||
messages: [{ role: 'user', content: 'Generate a list' }]
|
||||
})
|
||||
```
|
||||
|
||||
#### 与插件系统配合使用
|
||||
|
||||
更强大的是,你还可以将自定义注册表与 Cherry Studio 的插件系统结合使用:
|
||||
|
||||
```typescript
|
||||
import { PluginEnabledAiClient } from '@cherrystudio/ai-core'
|
||||
import { createProviderRegistry } from 'ai'
|
||||
import { createOpenAI } from '@ai-sdk/openai'
|
||||
import { anthropic } from '@ai-sdk/anthropic'
|
||||
|
||||
// 1. 创建带插件的客户端
|
||||
const client = PluginEnabledAiClient.create(
|
||||
'openai',
|
||||
{
|
||||
apiKey: process.env.OPENAI_API_KEY
|
||||
},
|
||||
[LoggingPlugin, RetryPlugin]
|
||||
)
|
||||
|
||||
// 2. 创建自定义注册表
|
||||
const registry = createProviderRegistry({
|
||||
openai: createOpenAI({ apiKey: process.env.OPENAI_API_KEY }),
|
||||
  anthropic // default instance reads ANTHROPIC_API_KEY from the environment
|
||||
})
|
||||
|
||||
// 3. 方式1:使用内建逻辑 + 完整插件系统
|
||||
await client.streamText('gpt-4', {
|
||||
messages: [{ role: 'user', content: 'Hello with plugins!' }]
|
||||
})
|
||||
|
||||
// 4. 方式2:使用自定义注册表 + 有限插件支持
|
||||
await client.streamText({
|
||||
model: registry.languageModel('anthropic:claude-3-opus-20240229'),
|
||||
messages: [{ role: 'user', content: 'Hello from Claude!' }]
|
||||
})
|
||||
|
||||
// 5. 支持的方法
|
||||
await client.generateObject({
|
||||
model: registry.languageModel('openai:gpt-4'),
|
||||
schema: z.object({ name: z.string() }),
|
||||
messages: [{ role: 'user', content: 'Generate a user' }]
|
||||
})
|
||||
|
||||
await client.streamObject({
|
||||
model: registry.languageModel('openai:gpt-4'),
|
||||
schema: z.object({ items: z.array(z.string()) }),
|
||||
messages: [{ role: 'user', content: 'Generate a list' }]
|
||||
})
|
||||
```
|
||||
|
||||
#### 混合使用的优势
|
||||
|
||||
- **灵活性**:可以根据需要选择使用内建逻辑或自定义注册表
|
||||
- **兼容性**:完全兼容 AI SDK 的 `createProviderRegistry` API
|
||||
- **渐进式**:可以逐步迁移现有代码,无需一次性重构
|
||||
- **插件支持**:自定义注册表仍可享受插件系统的部分功能
|
||||
- **最佳实践**:结合两种方式的优点,既有动态加载的性能优势,又有统一注册表的便利性
|
||||
|
||||
## 📚 相关资源
|
||||
|
||||
- [Vercel AI SDK 文档](https://ai-sdk.dev/)
|
||||
- [Cherry Studio 项目](https://github.com/CherryHQ/cherry-studio)
|
||||
- [AI SDK Providers](https://ai-sdk.dev/providers/ai-sdk-providers)
|
||||
|
||||
## 未来版本
|
||||
|
||||
- 🔮 多 Agent 编排
|
||||
- 🔮 可视化插件配置
|
||||
- 🔮 实时监控和分析
|
||||
- 🔮 云端插件同步
|
||||
|
||||
## 📄 License
|
||||
|
||||
MIT License - 详见 [LICENSE](https://github.com/CherryHQ/cherry-studio/blob/main/LICENSE) 文件
|
||||
|
||||
---
|
||||
|
||||
**Cherry Studio AI Core** - 让 AI 开发更简单、更强大、更灵活 🚀
|
||||
103
packages/aiCore/examples/hub-provider-usage.ts
Normal file
103
packages/aiCore/examples/hub-provider-usage.ts
Normal file
@@ -0,0 +1,103 @@
|
||||
/**
|
||||
* Hub Provider 使用示例
|
||||
*
|
||||
* 演示如何使用简化后的Hub Provider功能来路由到多个底层provider
|
||||
*/
|
||||
|
||||
import { createHubProvider, initializeProvider, providerRegistry } from '../src/index'
|
||||
|
||||
async function demonstrateHubProvider() {
|
||||
try {
|
||||
// 1. 初始化底层providers
|
||||
console.log('📦 初始化底层providers...')
|
||||
|
||||
initializeProvider('openai', {
|
||||
apiKey: process.env.OPENAI_API_KEY || 'sk-test-key'
|
||||
})
|
||||
|
||||
initializeProvider('anthropic', {
|
||||
apiKey: process.env.ANTHROPIC_API_KEY || 'sk-ant-test-key'
|
||||
})
|
||||
|
||||
// 2. 创建Hub Provider(自动包含所有已初始化的providers)
|
||||
console.log('🌐 创建Hub Provider...')
|
||||
|
||||
const aihubmixProvider = createHubProvider({
|
||||
hubId: 'aihubmix',
|
||||
debug: true
|
||||
})
|
||||
|
||||
// 3. 注册Hub Provider
|
||||
providerRegistry.registerProvider('aihubmix', aihubmixProvider)
|
||||
|
||||
console.log('✅ Hub Provider "aihubmix" 注册成功')
|
||||
|
||||
// 4. 使用Hub Provider访问不同的模型
|
||||
console.log('\n🚀 使用Hub模型...')
|
||||
|
||||
// 通过Hub路由到OpenAI
|
||||
const openaiModel = providerRegistry.languageModel('aihubmix:openai:gpt-4')
|
||||
console.log('✓ OpenAI模型已获取:', openaiModel.modelId)
|
||||
|
||||
// 通过Hub路由到Anthropic
|
||||
const anthropicModel = providerRegistry.languageModel('aihubmix:anthropic:claude-3.5-sonnet')
|
||||
console.log('✓ Anthropic模型已获取:', anthropicModel.modelId)
|
||||
|
||||
// 5. 演示错误处理
|
||||
console.log('\n❌ 演示错误处理...')
|
||||
|
||||
try {
|
||||
// 尝试访问未初始化的provider
|
||||
providerRegistry.languageModel('aihubmix:google:gemini-pro')
|
||||
} catch (error) {
|
||||
console.log('预期错误:', error.message)
|
||||
}
|
||||
|
||||
try {
|
||||
// 尝试使用错误的模型ID格式
|
||||
providerRegistry.languageModel('aihubmix:invalid-format')
|
||||
} catch (error) {
|
||||
console.log('预期错误:', error.message)
|
||||
}
|
||||
|
||||
// 6. 多个Hub Provider示例
|
||||
console.log('\n🔄 创建多个Hub Provider...')
|
||||
|
||||
const localHubProvider = createHubProvider({
|
||||
hubId: 'local-ai'
|
||||
})
|
||||
|
||||
providerRegistry.registerProvider('local-ai', localHubProvider)
|
||||
console.log('✅ Hub Provider "local-ai" 注册成功')
|
||||
|
||||
console.log('\n🎉 Hub Provider演示完成!')
|
||||
} catch (error) {
|
||||
console.error('💥 演示过程中发生错误:', error)
|
||||
}
|
||||
}
|
||||
|
||||
// 演示简化的使用方式
|
||||
function simplifiedUsageExample() {
|
||||
console.log('\n📝 简化使用示例:')
|
||||
console.log(`
|
||||
// 1. 初始化providers
|
||||
initializeProvider('openai', { apiKey: 'sk-xxx' })
|
||||
initializeProvider('anthropic', { apiKey: 'sk-ant-xxx' })
|
||||
|
||||
// 2. 创建并注册Hub Provider
|
||||
const hubProvider = createHubProvider({ hubId: 'aihubmix' })
|
||||
providerRegistry.registerProvider('aihubmix', hubProvider)
|
||||
|
||||
// 3. 直接使用
|
||||
const model1 = providerRegistry.languageModel('aihubmix:openai:gpt-4')
|
||||
const model2 = providerRegistry.languageModel('aihubmix:anthropic:claude-3.5-sonnet')
|
||||
`)
|
||||
}
|
||||
|
||||
// Run the demos only when this file is executed directly (not when imported).
// demonstrateHubProvider handles its own errors internally, so the
// un-awaited call here cannot produce an unhandled rejection.
if (require.main === module) {
  demonstrateHubProvider()
  simplifiedUsageExample()
}

export { demonstrateHubProvider, simplifiedUsageExample }
|
||||
167
packages/aiCore/examples/image-generation.ts
Normal file
167
packages/aiCore/examples/image-generation.ts
Normal file
@@ -0,0 +1,167 @@
|
||||
/**
|
||||
* Image Generation Example
|
||||
* 演示如何使用 aiCore 的文生图功能
|
||||
*/
|
||||
|
||||
import { createExecutor, generateImage } from '../src/index'
|
||||
|
||||
/**
 * Demonstrates the supported ways of generating images:
 * 1) via an executor instance, 2) via the standalone `generateImage` API,
 * 3) via another provider (Google Imagen, only if a key is configured),
 * and 4) through the plugin system. Each attempt logs its outcome and
 * failures are caught and logged rather than propagated.
 */
async function main() {
  // Approach 1: use an executor instance
  console.log('📸 创建 OpenAI 图像生成执行器...')
  const executor = createExecutor('openai', {
    apiKey: process.env.OPENAI_API_KEY!
  })

  try {
    console.log('🎨 使用执行器生成图像...')
    const result1 = await executor.generateImage('dall-e-3', {
      prompt: 'A futuristic cityscape at sunset with flying cars',
      size: '1024x1024',
      n: 1
    })

    console.log('✅ 图像生成成功!')
    console.log('📊 结果:', {
      imagesCount: result1.images.length,
      mediaType: result1.image.mediaType,
      hasBase64: !!result1.image.base64,
      providerMetadata: result1.providerMetadata
    })
  } catch (error) {
    console.error('❌ 执行器生成失败:', error)
  }

  // Approach 2: call the standalone generateImage API directly
  try {
    console.log('🎨 使用直接 API 生成图像...')
    const result2 = await generateImage('openai', { apiKey: process.env.OPENAI_API_KEY! }, 'dall-e-3', {
      prompt: 'A magical forest with glowing mushrooms and fairy lights',
      aspectRatio: '16:9',
      providerOptions: {
        openai: {
          quality: 'hd',
          style: 'vivid'
        }
      }
    })

    console.log('✅ 直接 API 生成成功!')
    console.log('📊 结果:', {
      imagesCount: result2.images.length,
      mediaType: result2.image.mediaType,
      hasBase64: !!result2.image.base64
    })
  } catch (error) {
    console.error('❌ 直接 API 生成失败:', error)
  }

  // Approach 3: other providers (Google Imagen) — runs only when a key is set
  if (process.env.GOOGLE_API_KEY) {
    try {
      console.log('🎨 使用 Google Imagen 生成图像...')
      const googleExecutor = createExecutor('google', {
        apiKey: process.env.GOOGLE_API_KEY!
      })

      const result3 = await googleExecutor.generateImage('imagen-3.0-generate-002', {
        prompt: 'A serene mountain lake at dawn with mist rising from the water',
        aspectRatio: '1:1'
      })

      console.log('✅ Google Imagen 生成成功!')
      console.log('📊 结果:', {
        imagesCount: result3.images.length,
        mediaType: result3.image.mediaType,
        hasBase64: !!result3.image.base64
      })
    } catch (error) {
      console.error('❌ Google Imagen 生成失败:', error)
    }
  }

  // Approach 4: the plugin system — plugins may rewrite params and results
  const pluginExample = async () => {
    console.log('🔌 演示插件系统...')

    // Sample plugin: enriches the prompt before the request and
    // tags the result with `enhanced: true` afterwards.
    const promptEnhancerPlugin = {
      name: 'prompt-enhancer',
      transformParams: async (params: any) => {
        console.log('🔧 插件: 增强提示词...')
        return {
          ...params,
          prompt: `${params.prompt}, highly detailed, cinematic lighting, 4K resolution`
        }
      },
      transformResult: async (result: any) => {
        console.log('🔧 插件: 处理结果...')
        return {
          ...result,
          enhanced: true
        }
      }
    }

    const executorWithPlugin = createExecutor(
      'openai',
      {
        apiKey: process.env.OPENAI_API_KEY!
      },
      [promptEnhancerPlugin]
    )

    try {
      const result4 = await executorWithPlugin.generateImage('dall-e-3', {
        prompt: 'A cute robot playing in a garden'
      })

      console.log('✅ 插件系统生成成功!')
      console.log('📊 结果:', {
        imagesCount: result4.images.length,
        enhanced: (result4 as any).enhanced,
        mediaType: result4.image.mediaType
      })
    } catch (error) {
      console.error('❌ 插件系统生成失败:', error)
    }
  }

  await pluginExample()
}
|
||||
|
||||
// 错误处理演示
|
||||
async function errorHandlingExample() {
|
||||
console.log('⚠️ 演示错误处理...')
|
||||
|
||||
try {
|
||||
const executor = createExecutor('openai', {
|
||||
apiKey: 'invalid-key'
|
||||
})
|
||||
|
||||
await executor.generateImage('dall-e-3', {
|
||||
prompt: 'Test image'
|
||||
})
|
||||
} catch (error: any) {
|
||||
console.log('✅ 成功捕获错误:', error.constructor.name)
|
||||
console.log('📋 错误信息:', error.message)
|
||||
console.log('🏷️ 提供商ID:', error.providerId)
|
||||
console.log('🏷️ 模型ID:', error.modelId)
|
||||
}
|
||||
}
|
||||
|
||||
// Run the examples only when this file is executed directly (not imported).
// main() runs first, then the error-handling demo; the process exits
// explicitly with 0 on success or 1 on an unexpected failure.
if (require.main === module) {
  main()
    .then(() => {
      console.log('🎉 所有示例完成!')
      return errorHandlingExample()
    })
    .then(() => {
      console.log('🎯 示例程序结束')
      process.exit(0)
    })
    .catch((error) => {
      console.error('💥 程序执行出错:', error)
      process.exit(1)
    })
}
|
||||
85
packages/aiCore/package.json
Normal file
85
packages/aiCore/package.json
Normal file
@@ -0,0 +1,85 @@
|
||||
{
|
||||
"name": "@cherrystudio/ai-core",
|
||||
"version": "1.0.0-alpha.11",
|
||||
"description": "Cherry Studio AI Core - Unified AI Provider Interface Based on Vercel AI SDK",
|
||||
"main": "dist/index.js",
|
||||
"module": "dist/index.mjs",
|
||||
"types": "dist/index.d.ts",
|
||||
"react-native": "dist/index.js",
|
||||
"scripts": {
|
||||
"build": "tsdown",
|
||||
"dev": "tsc -w",
|
||||
"clean": "rm -rf dist",
|
||||
"test": "vitest run",
|
||||
"test:watch": "vitest"
|
||||
},
|
||||
"keywords": [
|
||||
"ai",
|
||||
"sdk",
|
||||
"openai",
|
||||
"anthropic",
|
||||
"google",
|
||||
"cherry-studio",
|
||||
"vercel-ai-sdk"
|
||||
],
|
||||
"author": "Cherry Studio",
|
||||
"license": "MIT",
|
||||
"repository": {
|
||||
"type": "git",
|
||||
"url": "git+https://github.com/CherryHQ/cherry-studio.git"
|
||||
},
|
||||
"bugs": {
|
||||
"url": "https://github.com/CherryHQ/cherry-studio/issues"
|
||||
},
|
||||
"homepage": "https://github.com/CherryHQ/cherry-studio#readme",
|
||||
"peerDependencies": {
|
||||
"ai": "^5.0.26"
|
||||
},
|
||||
"dependencies": {
|
||||
"@ai-sdk/anthropic": "^2.0.5",
|
||||
"@ai-sdk/azure": "^2.0.16",
|
||||
"@ai-sdk/deepseek": "^1.0.9",
|
||||
"@ai-sdk/google": "^2.0.7",
|
||||
"@ai-sdk/openai": "^2.0.19",
|
||||
"@ai-sdk/openai-compatible": "^1.0.9",
|
||||
"@ai-sdk/provider": "^2.0.0",
|
||||
"@ai-sdk/provider-utils": "^3.0.4",
|
||||
"@ai-sdk/xai": "^2.0.9",
|
||||
"zod": "^3.25.0"
|
||||
},
|
||||
"devDependencies": {
|
||||
"tsdown": "^0.12.9",
|
||||
"typescript": "^5.0.0",
|
||||
"vitest": "^3.2.4"
|
||||
},
|
||||
"sideEffects": false,
|
||||
"engines": {
|
||||
"node": ">=18.0.0"
|
||||
},
|
||||
"files": [
|
||||
"dist"
|
||||
],
|
||||
"exports": {
|
||||
".": {
|
||||
"types": "./dist/index.d.ts",
|
||||
"react-native": "./dist/index.js",
|
||||
"import": "./dist/index.mjs",
|
||||
"require": "./dist/index.js",
|
||||
"default": "./dist/index.js"
|
||||
},
|
||||
"./built-in/plugins": {
|
||||
"types": "./dist/built-in/plugins/index.d.ts",
|
||||
"react-native": "./dist/built-in/plugins/index.js",
|
||||
"import": "./dist/built-in/plugins/index.mjs",
|
||||
"require": "./dist/built-in/plugins/index.js",
|
||||
"default": "./dist/built-in/plugins/index.js"
|
||||
},
|
||||
"./provider": {
|
||||
"types": "./dist/provider/index.d.ts",
|
||||
"react-native": "./dist/provider/index.js",
|
||||
"import": "./dist/provider/index.mjs",
|
||||
"require": "./dist/provider/index.js",
|
||||
"default": "./dist/provider/index.js"
|
||||
}
|
||||
}
|
||||
}
|
||||
2
packages/aiCore/setupVitest.ts
Normal file
2
packages/aiCore/setupVitest.ts
Normal file
@@ -0,0 +1,2 @@
|
||||
// Stub the Vite SSR export helper so transpiled code that references
// __vite_ssr_exportName__ does not crash when run under plain Node.
;(globalThis as any).__vite_ssr_exportName__ = (name: string, value: any) => value
|
||||
3
packages/aiCore/src/core/README.MD
Normal file
3
packages/aiCore/src/core/README.MD
Normal file
@@ -0,0 +1,3 @@
|
||||
# @cherrystudio/ai-core
|
||||
|
||||
Core
|
||||
17
packages/aiCore/src/core/index.ts
Normal file
17
packages/aiCore/src/core/index.ts
Normal file
@@ -0,0 +1,17 @@
|
||||
/**
 * Core module exports.
 * Internal core functionality consumed by other modules; not intended
 * for direct use by end callers.
 */

// Middleware system
export type { NamedMiddleware } from './middleware'
export { createMiddlewares, wrapModelWithMiddlewares } from './middleware'

// Model creation / resolution
export { globalModelResolver, ModelResolver } from './models'
export type { ModelConfig as ModelConfigType } from './models/types'

// Execution management
export type { ToolUseRequestContext } from './plugins/built-in/toolUsePlugin/type'
export { createExecutor, createOpenAICompatibleExecutor } from './runtime'
export type { RuntimeConfig } from './runtime/types'
|
||||
8
packages/aiCore/src/core/middleware/index.ts
Normal file
8
packages/aiCore/src/core/middleware/index.ts
Normal file
@@ -0,0 +1,8 @@
|
||||
/**
 * Middleware module exports.
 * Provides generic middleware management capabilities.
 */

export { createMiddlewares } from './manager'
export type { NamedMiddleware } from './types'
export { wrapModelWithMiddlewares } from './wrapper'
|
||||
16
packages/aiCore/src/core/middleware/manager.ts
Normal file
16
packages/aiCore/src/core/middleware/manager.ts
Normal file
@@ -0,0 +1,16 @@
|
||||
/**
|
||||
* 中间件管理器
|
||||
* 专注于 AI SDK 中间件的管理,与插件系统分离
|
||||
*/
|
||||
import { LanguageModelV2Middleware } from '@ai-sdk/provider'
|
||||
|
||||
/**
|
||||
* 创建中间件列表
|
||||
* 合并用户提供的中间件
|
||||
*/
|
||||
export function createMiddlewares(userMiddlewares: LanguageModelV2Middleware[] = []): LanguageModelV2Middleware[] {
|
||||
// 未来可以在这里添加默认的中间件
|
||||
const defaultMiddlewares: LanguageModelV2Middleware[] = []
|
||||
|
||||
return [...defaultMiddlewares, ...userMiddlewares]
|
||||
}
|
||||
12
packages/aiCore/src/core/middleware/types.ts
Normal file
12
packages/aiCore/src/core/middleware/types.ts
Normal file
@@ -0,0 +1,12 @@
|
||||
/**
 * Type definitions for the middleware system.
 */
import { LanguageModelV2Middleware } from '@ai-sdk/provider'

/**
 * A middleware bundled with an identifying name.
 */
export interface NamedMiddleware {
  name: string
  middleware: LanguageModelV2Middleware
}
|
||||
23
packages/aiCore/src/core/middleware/wrapper.ts
Normal file
23
packages/aiCore/src/core/middleware/wrapper.ts
Normal file
@@ -0,0 +1,23 @@
|
||||
/**
|
||||
* 模型包装工具函数
|
||||
* 用于将中间件应用到LanguageModel上
|
||||
*/
|
||||
import { LanguageModelV2, LanguageModelV2Middleware } from '@ai-sdk/provider'
|
||||
import { wrapLanguageModel } from 'ai'
|
||||
|
||||
/**
|
||||
* 使用中间件包装模型
|
||||
*/
|
||||
export function wrapModelWithMiddlewares(
|
||||
model: LanguageModelV2,
|
||||
middlewares: LanguageModelV2Middleware[]
|
||||
): LanguageModelV2 {
|
||||
if (middlewares.length === 0) {
|
||||
return model
|
||||
}
|
||||
|
||||
return wrapLanguageModel({
|
||||
model,
|
||||
middleware: middlewares
|
||||
})
|
||||
}
|
||||
125
packages/aiCore/src/core/models/ModelResolver.ts
Normal file
125
packages/aiCore/src/core/models/ModelResolver.ts
Normal file
@@ -0,0 +1,125 @@
|
||||
/**
|
||||
* 模型解析器 - models模块的核心
|
||||
* 负责将modelId解析为AI SDK的LanguageModel实例
|
||||
* 支持传统格式和命名空间格式
|
||||
* 集成了来自 ModelCreator 的特殊处理逻辑
|
||||
*/
|
||||
|
||||
import { EmbeddingModelV2, ImageModelV2, LanguageModelV2, LanguageModelV2Middleware } from '@ai-sdk/provider'
|
||||
|
||||
import { wrapModelWithMiddlewares } from '../middleware/wrapper'
|
||||
import { DEFAULT_SEPARATOR, globalRegistryManagement } from '../providers/RegistryManagement'
|
||||
|
||||
export class ModelResolver {
|
||||
/**
|
||||
* 核心方法:解析任意格式的modelId为语言模型
|
||||
*
|
||||
* @param modelId 模型ID,支持 'gpt-4' 和 'anthropic>claude-3' 两种格式
|
||||
* @param fallbackProviderId 当modelId为传统格式时使用的providerId
|
||||
* @param providerOptions provider配置选项(用于OpenAI模式选择等)
|
||||
* @param middlewares 中间件数组,会应用到最终模型上
|
||||
*/
|
||||
async resolveLanguageModel(
|
||||
modelId: string,
|
||||
fallbackProviderId: string,
|
||||
providerOptions?: any,
|
||||
middlewares?: LanguageModelV2Middleware[]
|
||||
): Promise<LanguageModelV2> {
|
||||
let finalProviderId = fallbackProviderId
|
||||
let model: LanguageModelV2
|
||||
// 🎯 处理 OpenAI 模式选择逻辑 (从 ModelCreator 迁移)
|
||||
if ((fallbackProviderId === 'openai' || fallbackProviderId === 'azure') && providerOptions?.mode === 'chat') {
|
||||
finalProviderId = `${fallbackProviderId}-chat`
|
||||
}
|
||||
|
||||
// 检查是否是命名空间格式
|
||||
if (modelId.includes(DEFAULT_SEPARATOR)) {
|
||||
model = this.resolveNamespacedModel(modelId)
|
||||
} else {
|
||||
// 传统格式:使用处理后的 providerId + modelId
|
||||
model = this.resolveTraditionalModel(finalProviderId, modelId)
|
||||
}
|
||||
|
||||
// 🎯 应用中间件(如果有)
|
||||
if (middlewares && middlewares.length > 0) {
|
||||
model = wrapModelWithMiddlewares(model, middlewares)
|
||||
}
|
||||
|
||||
return model
|
||||
}
|
||||
|
||||
/**
|
||||
* 解析文本嵌入模型
|
||||
*/
|
||||
async resolveTextEmbeddingModel(modelId: string, fallbackProviderId: string): Promise<EmbeddingModelV2<string>> {
|
||||
if (modelId.includes(DEFAULT_SEPARATOR)) {
|
||||
return this.resolveNamespacedEmbeddingModel(modelId)
|
||||
}
|
||||
|
||||
return this.resolveTraditionalEmbeddingModel(fallbackProviderId, modelId)
|
||||
}
|
||||
|
||||
/**
|
||||
* 解析图像模型
|
||||
*/
|
||||
async resolveImageModel(modelId: string, fallbackProviderId: string): Promise<ImageModelV2> {
|
||||
if (modelId.includes(DEFAULT_SEPARATOR)) {
|
||||
return this.resolveNamespacedImageModel(modelId)
|
||||
}
|
||||
|
||||
return this.resolveTraditionalImageModel(fallbackProviderId, modelId)
|
||||
}
|
||||
|
||||
/**
|
||||
* 解析命名空间格式的语言模型
|
||||
* aihubmix:anthropic:claude-3 -> globalRegistryManagement.languageModel('aihubmix:anthropic:claude-3')
|
||||
*/
|
||||
private resolveNamespacedModel(modelId: string): LanguageModelV2 {
|
||||
return globalRegistryManagement.languageModel(modelId as any)
|
||||
}
|
||||
|
||||
/**
|
||||
* 解析传统格式的语言模型
|
||||
* providerId: 'openai', modelId: 'gpt-4' -> globalRegistryManagement.languageModel('openai:gpt-4')
|
||||
*/
|
||||
private resolveTraditionalModel(providerId: string, modelId: string): LanguageModelV2 {
|
||||
const fullModelId = `${providerId}${DEFAULT_SEPARATOR}${modelId}`
|
||||
console.log('fullModelId', fullModelId)
|
||||
return globalRegistryManagement.languageModel(fullModelId as any)
|
||||
}
|
||||
|
||||
/**
|
||||
* 解析命名空间格式的嵌入模型
|
||||
*/
|
||||
private resolveNamespacedEmbeddingModel(modelId: string): EmbeddingModelV2<string> {
|
||||
return globalRegistryManagement.textEmbeddingModel(modelId as any)
|
||||
}
|
||||
|
||||
/**
|
||||
* 解析传统格式的嵌入模型
|
||||
*/
|
||||
private resolveTraditionalEmbeddingModel(providerId: string, modelId: string): EmbeddingModelV2<string> {
|
||||
const fullModelId = `${providerId}${DEFAULT_SEPARATOR}${modelId}`
|
||||
return globalRegistryManagement.textEmbeddingModel(fullModelId as any)
|
||||
}
|
||||
|
||||
/**
|
||||
* 解析命名空间格式的图像模型
|
||||
*/
|
||||
private resolveNamespacedImageModel(modelId: string): ImageModelV2 {
|
||||
return globalRegistryManagement.imageModel(modelId as any)
|
||||
}
|
||||
|
||||
/**
|
||||
* 解析传统格式的图像模型
|
||||
*/
|
||||
private resolveTraditionalImageModel(providerId: string, modelId: string): ImageModelV2 {
|
||||
const fullModelId = `${providerId}${DEFAULT_SEPARATOR}${modelId}`
|
||||
return globalRegistryManagement.imageModel(fullModelId as any)
|
||||
}
|
||||
}
|
||||
|
||||
/**
 * Global model resolver instance shared across the package.
 */
export const globalModelResolver = new ModelResolver()
|
||||
9
packages/aiCore/src/core/models/index.ts
Normal file
9
packages/aiCore/src/core/models/index.ts
Normal file
@@ -0,0 +1,9 @@
|
||||
/**
 * Unified exports for the models module (simplified version).
 */

// Core model resolver
export { globalModelResolver, ModelResolver } from './ModelResolver'

// Retained type definitions (may still be referenced elsewhere)
export type { ModelConfig as ModelConfigType } from './types'
|
||||
15
packages/aiCore/src/core/models/types.ts
Normal file
15
packages/aiCore/src/core/models/types.ts
Normal file
@@ -0,0 +1,15 @@
|
||||
/**
 * Type definitions for the creation module.
 */
import { LanguageModelV2Middleware } from '@ai-sdk/provider'

import type { ProviderId, ProviderSettingsMap } from '../providers/types'

/**
 * Everything required to materialize a model: which provider, which model,
 * the provider's settings, and optional middlewares.
 */
export interface ModelConfig<T extends ProviderId = ProviderId> {
  providerId: T
  modelId: string
  // Provider settings plus the OpenAI-style 'chat' | 'responses' mode switch.
  providerSettings: ProviderSettingsMap[T] & { mode?: 'chat' | 'responses' }
  middlewares?: LanguageModelV2Middleware[]
  // Extra model parameters
  extraModelConfig?: Record<string, any>
}
|
||||
87
packages/aiCore/src/core/options/examples.ts
Normal file
87
packages/aiCore/src/core/options/examples.ts
Normal file
@@ -0,0 +1,87 @@
|
||||
import { streamText } from 'ai'
|
||||
|
||||
import {
|
||||
createAnthropicOptions,
|
||||
createGenericProviderOptions,
|
||||
createGoogleOptions,
|
||||
createOpenAIOptions,
|
||||
mergeProviderOptions
|
||||
} from './factory'
|
||||
|
||||
// Example 1: strictly typed options for a known provider
export function exampleOpenAIWithOptions() {
  const openaiOptions = createOpenAIOptions({
    reasoningEffort: 'medium'
  })

  // Type-checked here: options must match the OpenAI provider settings.
  return streamText({
    model: {} as any, // replace with a real model in actual use
    prompt: 'Hello',
    providerOptions: openaiOptions
  })
}
|
||||
|
||||
// Example 2: Anthropic provider options
export function exampleAnthropicWithOptions() {
  const anthropicOptions = createAnthropicOptions({
    thinking: {
      type: 'enabled',
      budgetTokens: 1000
    }
  })

  return streamText({
    model: {} as any,
    prompt: 'Hello',
    providerOptions: anthropicOptions
  })
}
|
||||
|
||||
// Example 3: Google provider options
export function exampleGoogleWithOptions() {
  const googleOptions = createGoogleOptions({
    thinkingConfig: {
      includeThoughts: true,
      thinkingBudget: 1000
    }
  })

  return streamText({
    model: {} as any,
    prompt: 'Hello',
    providerOptions: googleOptions
  })
}
|
||||
|
||||
// Example 4: an unknown provider (generic, untyped options)
export function exampleUnknownProviderWithOptions() {
  const customProviderOptions = createGenericProviderOptions('custom-provider', {
    temperature: 0.7,
    customSetting: 'value',
    anotherOption: true
  })

  return streamText({
    model: {} as any,
    prompt: 'Hello',
    providerOptions: customProviderOptions
  })
}
|
||||
|
||||
// Example 5: merging options from multiple providers
export function exampleMergedOptions() {
  const openaiOptions = createOpenAIOptions({})

  const customOptions = createGenericProviderOptions('custom', {
    customParam: 'value'
  })

  const mergedOptions = mergeProviderOptions(openaiOptions, customOptions)

  return streamText({
    model: {} as any,
    prompt: 'Hello',
    providerOptions: mergedOptions
  })
}
|
||||
71
packages/aiCore/src/core/options/factory.ts
Normal file
71
packages/aiCore/src/core/options/factory.ts
Normal file
@@ -0,0 +1,71 @@
|
||||
import { ExtractProviderOptions, ProviderOptionsMap, TypedProviderOptions } from './types'
|
||||
|
||||
/**
 * Create options for a specific, known provider.
 * @param provider provider name (a key of ProviderOptionsMap)
 * @param options provider-specific options
 * @returns the options keyed under the provider name
 */
export function createProviderOptions<T extends keyof ProviderOptionsMap>(
  provider: T,
  options: ExtractProviderOptions<T>
): Record<T, ExtractProviderOptions<T>> {
  return { [provider]: options } as Record<T, ExtractProviderOptions<T>>
}
|
||||
|
||||
/**
 * Create options for an arbitrary provider (including unknown ones).
 * @param provider provider name
 * @param options provider options (unconstrained)
 * @returns the options keyed under the provider name
 */
export function createGenericProviderOptions<T extends string>(
  provider: T,
  options: Record<string, any>
): Record<T, Record<string, any>> {
  return { [provider]: options } as Record<T, Record<string, any>>
}
|
||||
|
||||
/**
|
||||
* 合并多个供应商的options
|
||||
* @param optionsMap 包含多个供应商选项的对象
|
||||
* @returns 合并后的TypedProviderOptions
|
||||
*/
|
||||
export function mergeProviderOptions(...optionsMap: Partial<TypedProviderOptions>[]): TypedProviderOptions {
|
||||
return Object.assign({}, ...optionsMap)
|
||||
}
|
||||
|
||||
/**
 * Convenience helper for OpenAI provider options.
 */
export function createOpenAIOptions(options: ExtractProviderOptions<'openai'>) {
  return createProviderOptions('openai', options)
}

/**
 * Convenience helper for Anthropic provider options.
 */
export function createAnthropicOptions(options: ExtractProviderOptions<'anthropic'>) {
  return createProviderOptions('anthropic', options)
}

/**
 * Convenience helper for Google provider options.
 */
export function createGoogleOptions(options: ExtractProviderOptions<'google'>) {
  return createProviderOptions('google', options)
}

/**
 * Convenience helper for OpenRouter provider options.
 */
export function createOpenRouterOptions(options: ExtractProviderOptions<'openrouter'>) {
  return createProviderOptions('openrouter', options)
}

/**
 * Convenience helper for xAI provider options.
 */
export function createXaiOptions(options: ExtractProviderOptions<'xai'>) {
  return createProviderOptions('xai', options)
}
|
||||
2
packages/aiCore/src/core/options/index.ts
Normal file
2
packages/aiCore/src/core/options/index.ts
Normal file
@@ -0,0 +1,2 @@
|
||||
// Barrel file for the provider-options module.
export * from './factory'
export * from './types'
|
||||
38
packages/aiCore/src/core/options/openrouter.ts
Normal file
38
packages/aiCore/src/core/options/openrouter.ts
Normal file
@@ -0,0 +1,38 @@
|
||||
/**
 * Provider-specific options accepted by OpenRouter requests.
 */
export type OpenRouterProviderOptions = {
  models?: string[]

  /**
   * https://openrouter.ai/docs/use-cases/reasoning-tokens
   * One of `max_tokens` or `effort` is required.
   * If `exclude` is true, reasoning will be removed from the response. Default is false.
   */
  reasoning?: {
    exclude?: boolean
  } & (
    | {
        max_tokens: number
      }
    | {
        effort: 'high' | 'medium' | 'low'
      }
  )

  /**
   * A unique identifier representing your end-user, which can
   * help OpenRouter to monitor and detect abuse.
   */
  user?: string

  // Arbitrary extra fields forwarded in the request body.
  extraBody?: Record<string, unknown>

  /**
   * Enable usage accounting to get detailed token usage information.
   * https://openrouter.ai/docs/use-cases/usage-accounting
   */
  usage?: {
    /**
     * When true, includes token usage information in the response.
     */
    include: boolean
  }
}
|
||||
33
packages/aiCore/src/core/options/types.ts
Normal file
33
packages/aiCore/src/core/options/types.ts
Normal file
@@ -0,0 +1,33 @@
|
||||
import { type AnthropicProviderOptions } from '@ai-sdk/anthropic'
import { type GoogleGenerativeAIProviderOptions } from '@ai-sdk/google'
import { type OpenAIResponsesProviderOptions } from '@ai-sdk/openai'
import { type SharedV2ProviderMetadata } from '@ai-sdk/provider'

import { type OpenRouterProviderOptions } from './openrouter'
import { type XaiProviderOptions } from './xai'

export type ProviderOptions<T extends keyof SharedV2ProviderMetadata> = SharedV2ProviderMetadata[T]

/**
 * Map of provider name -> option type.
 * Providers absent from this map carry no constraints.
 */
export type ProviderOptionsMap = {
  openai: OpenAIResponsesProviderOptions
  anthropic: AnthropicProviderOptions
  google: GoogleGenerativeAIProviderOptions
  openrouter: OpenRouterProviderOptions
  xai: XaiProviderOptions
}

// Utility type extracting a specific provider's option type from ProviderOptionsMap
export type ExtractProviderOptions<T extends keyof ProviderOptionsMap> = ProviderOptionsMap[T]

/**
 * Type-safe provider options:
 * strict types for known providers, arbitrary Record<string, any> for unknown ones.
 */
export type TypedProviderOptions = {
  [K in keyof ProviderOptionsMap]?: ProviderOptionsMap[K]
} & {
  [K in string]?: Record<string, any>
} & SharedV2ProviderMetadata
|
||||
86
packages/aiCore/src/core/options/xai.ts
Normal file
86
packages/aiCore/src/core/options/xai.ts
Normal file
@@ -0,0 +1,86 @@
|
||||
// copy from @ai-sdk/xai/xai-chat-options.ts
|
||||
// 如果@ai-sdk/xai暴露出了xaiProviderOptions就删除这个文件
|
||||
|
||||
import * as z from 'zod/v4'
|
||||
|
||||
// Web search source: optional 2-letter country code plus allow/deny site
// lists, each capped at 5 entries.
const webSourceSchema = z.object({
  type: z.literal('web'),
  country: z.string().length(2).optional(),
  excludedWebsites: z.array(z.string()).max(5).optional(),
  allowedWebsites: z.array(z.string()).max(5).optional(),
  safeSearch: z.boolean().optional()
})

// X (Twitter) search source, optionally restricted to specific handles.
const xSourceSchema = z.object({
  type: z.literal('x'),
  xHandles: z.array(z.string()).optional()
})

// News search source: like web, but with no allow-list.
const newsSourceSchema = z.object({
  type: z.literal('news'),
  country: z.string().length(2).optional(),
  excludedWebsites: z.array(z.string()).max(5).optional(),
  safeSearch: z.boolean().optional()
})

// RSS feed source.
const rssSourceSchema = z.object({
  type: z.literal('rss'),
  links: z.array(z.url()).max(1) // currently only supports one RSS link
})

// Discriminated union over the `type` literal of each source schema.
const searchSourceSchema = z.discriminatedUnion('type', [
  webSourceSchema,
  xSourceSchema,
  newsSourceSchema,
  rssSourceSchema
])

export const xaiProviderOptions = z.object({
  /**
   * reasoning effort for reasoning models
   * only supported by grok-3-mini and grok-3-mini-fast models
   */
  reasoningEffort: z.enum(['low', 'high']).optional(),

  searchParameters: z
    .object({
      /**
       * search mode preference
       * - "off": disables search completely
       * - "auto": model decides whether to search (default)
       * - "on": always enables search
       */
      mode: z.enum(['off', 'auto', 'on']),

      /**
       * whether to return citations in the response
       * defaults to true
       */
      returnCitations: z.boolean().optional(),

      /**
       * start date for search data (ISO8601 format: YYYY-MM-DD)
       */
      fromDate: z.string().optional(),

      /**
       * end date for search data (ISO8601 format: YYYY-MM-DD)
       */
      toDate: z.string().optional(),

      /**
       * maximum number of search results to consider
       * defaults to 20
       */
      maxSearchResults: z.number().min(1).max(50).optional(),

      /**
       * data sources to search from
       * defaults to ["web", "x"] if not specified
       */
      sources: z.array(searchSourceSchema).optional()
    })
    .optional()
})

// Inferred TypeScript type for the xAI provider options.
export type XaiProviderOptions = z.infer<typeof xaiProviderOptions>
|
||||
257
packages/aiCore/src/core/plugins/README.md
Normal file
257
packages/aiCore/src/core/plugins/README.md
Normal file
@@ -0,0 +1,257 @@
|
||||
# AI Core 插件系统
|
||||
|
||||
支持四种钩子类型:**First**、**Sequential**、**Parallel** 和 **Stream**。
|
||||
|
||||
## 🎯 设计理念
|
||||
|
||||
- **语义清晰**:不同钩子有不同的执行语义
|
||||
- **类型安全**:TypeScript 完整支持
|
||||
- **性能优化**:First 短路、Parallel 并发、Sequential 链式
|
||||
- **易于扩展**:`enforce` 排序 + 功能分类
|
||||
|
||||
## 📋 钩子类型
|
||||
|
||||
### 🥇 First 钩子 - 首个有效结果
|
||||
|
||||
```typescript
|
||||
// 只执行第一个返回值的插件,用于解析和查找
|
||||
resolveModel?: (modelId: string, context: AiRequestContext) => string | null
|
||||
loadTemplate?: (templateName: string, context: AiRequestContext) => any | null
|
||||
```
|
||||
|
||||
### 🔄 Sequential 钩子 - 链式数据转换
|
||||
|
||||
```typescript
|
||||
// 按顺序链式执行,每个插件可以修改数据
|
||||
transformParams?: (params: any, context: AiRequestContext) => any
|
||||
transformResult?: (result: any, context: AiRequestContext) => any
|
||||
```
|
||||
|
||||
### ⚡ Parallel 钩子 - 并行副作用
|
||||
|
||||
```typescript
|
||||
// 并发执行,用于日志、监控等副作用
|
||||
onRequestStart?: (context: AiRequestContext) => void
|
||||
onRequestEnd?: (context: AiRequestContext, result: any) => void
|
||||
onError?: (error: Error, context: AiRequestContext) => void
|
||||
```
|
||||
|
||||
### 🌊 Stream 钩子 - 流处理
|
||||
|
||||
```typescript
|
||||
// 直接使用 AI SDK 的 TransformStream
|
||||
transformStream?: () => (options) => TransformStream<TextStreamPart, TextStreamPart>
|
||||
```
|
||||
|
||||
## 🚀 快速开始
|
||||
|
||||
### 基础用法
|
||||
|
||||
```typescript
|
||||
import { PluginManager, createContext, definePlugin } from '@cherrystudio/ai-core/middleware'
|
||||
|
||||
// 创建插件管理器
|
||||
const pluginManager = new PluginManager()
|
||||
|
||||
// 添加插件
|
||||
pluginManager.use({
|
||||
name: 'my-plugin',
|
||||
async transformParams(params, context) {
|
||||
return { ...params, temperature: 0.7 }
|
||||
}
|
||||
})
|
||||
|
||||
// 使用插件
|
||||
const context = createContext('openai', 'gpt-4', { messages: [] })
|
||||
const transformedParams = await pluginManager.executeSequential(
|
||||
'transformParams',
|
||||
{ messages: [{ role: 'user', content: 'Hello' }] },
|
||||
context
|
||||
)
|
||||
```
|
||||
|
||||
### 完整示例
|
||||
|
||||
```typescript
|
||||
import {
|
||||
PluginManager,
|
||||
ModelAliasPlugin,
|
||||
LoggingPlugin,
|
||||
ParamsValidationPlugin,
|
||||
createContext
|
||||
} from '@cherrystudio/ai-core/middleware'
|
||||
|
||||
// 创建插件管理器
|
||||
const manager = new PluginManager([
|
||||
ModelAliasPlugin, // 模型别名解析
|
||||
ParamsValidationPlugin, // 参数验证
|
||||
LoggingPlugin // 日志记录
|
||||
])
|
||||
|
||||
// AI 请求流程
|
||||
async function aiRequest(providerId: string, modelId: string, params: any) {
|
||||
const context = createContext(providerId, modelId, params)
|
||||
|
||||
try {
|
||||
// 1. 【并行】触发请求开始事件
|
||||
await manager.executeParallel('onRequestStart', context)
|
||||
|
||||
// 2. 【首个】解析模型别名
|
||||
const resolvedModel = await manager.executeFirst('resolveModel', modelId, context)
|
||||
context.modelId = resolvedModel || modelId
|
||||
|
||||
// 3. 【串行】转换请求参数
|
||||
const transformedParams = await manager.executeSequential('transformParams', params, context)
|
||||
|
||||
// 4. 【流处理】收集流转换器(AI SDK 原生支持数组)
|
||||
const streamTransforms = manager.collectStreamTransforms()
|
||||
|
||||
// 5. 调用 AI SDK(这里省略具体实现)
|
||||
const result = await callAiSdk(transformedParams, streamTransforms)
|
||||
|
||||
// 6. 【串行】转换响应结果
|
||||
const transformedResult = await manager.executeSequential('transformResult', result, context)
|
||||
|
||||
// 7. 【并行】触发请求完成事件
|
||||
await manager.executeParallel('onRequestEnd', context, transformedResult)
|
||||
|
||||
return transformedResult
|
||||
} catch (error) {
|
||||
// 8. 【并行】触发错误事件
|
||||
await manager.executeParallel('onError', context, undefined, error)
|
||||
throw error
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## 🔧 自定义插件
|
||||
|
||||
### 模型别名插件
|
||||
|
||||
```typescript
|
||||
const ModelAliasPlugin = definePlugin({
|
||||
name: 'model-alias',
|
||||
enforce: 'pre', // 最先执行
|
||||
|
||||
async resolveModel(modelId) {
|
||||
const aliases = {
|
||||
gpt4: 'gpt-4-turbo-preview',
|
||||
claude: 'claude-3-sonnet-20240229'
|
||||
}
|
||||
return aliases[modelId] || null
|
||||
}
|
||||
})
|
||||
```
|
||||
|
||||
### 参数验证插件
|
||||
|
||||
```typescript
|
||||
const ValidationPlugin = definePlugin({
|
||||
name: 'validation',
|
||||
|
||||
async transformParams(params) {
|
||||
if (!params.messages) {
|
||||
throw new Error('messages is required')
|
||||
}
|
||||
|
||||
return {
|
||||
...params,
|
||||
temperature: params.temperature ?? 0.7,
|
||||
max_tokens: params.max_tokens ?? 4096
|
||||
}
|
||||
}
|
||||
})
|
||||
```
|
||||
|
||||
### 监控插件
|
||||
|
||||
```typescript
|
||||
const MonitoringPlugin = definePlugin({
|
||||
name: 'monitoring',
|
||||
enforce: 'post', // 最后执行
|
||||
|
||||
async onRequestEnd(context, result) {
|
||||
const duration = Date.now() - context.startTime
|
||||
console.log(`请求耗时: ${duration}ms`)
|
||||
}
|
||||
})
|
||||
```
|
||||
|
||||
### 内容过滤插件
|
||||
|
||||
```typescript
|
||||
const FilterPlugin = definePlugin({
|
||||
name: 'content-filter',
|
||||
|
||||
transformStream() {
|
||||
return () =>
|
||||
new TransformStream({
|
||||
transform(chunk, controller) {
|
||||
if (chunk.type === 'text-delta') {
|
||||
const filtered = chunk.textDelta.replace(/敏感词/g, '***')
|
||||
controller.enqueue({ ...chunk, textDelta: filtered })
|
||||
} else {
|
||||
controller.enqueue(chunk)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
```
|
||||
|
||||
## 📊 执行顺序
|
||||
|
||||
### 插件排序
|
||||
|
||||
```
|
||||
enforce: 'pre' → normal → enforce: 'post'
|
||||
```
|
||||
|
||||
### 钩子执行流程
|
||||
|
||||
```mermaid
|
||||
graph TD
|
||||
A[请求开始] --> B[onRequestStart 并行执行]
|
||||
B --> C[resolveModel 首个有效]
|
||||
C --> D[loadTemplate 首个有效]
|
||||
D --> E[transformParams 串行执行]
|
||||
E --> F[collectStreamTransforms]
|
||||
F --> G[AI SDK 调用]
|
||||
G --> H[transformResult 串行执行]
|
||||
H --> I[onRequestEnd 并行执行]
|
||||
|
||||
G --> J[异常处理]
|
||||
J --> K[onError 并行执行]
|
||||
```
|
||||
|
||||
## 💡 最佳实践
|
||||
|
||||
1. **功能单一**:每个插件专注一个功能
|
||||
2. **幂等性**:插件应该是幂等的,重复执行不会产生副作用
|
||||
3. **错误处理**:插件内部处理异常,不要让异常向上传播
|
||||
4. **性能优化**:使用合适的钩子类型(First vs Sequential vs Parallel)
|
||||
5. **命名规范**:使用语义化的插件名称
|
||||
|
||||
## 🔍 调试工具
|
||||
|
||||
```typescript
|
||||
// 查看插件统计信息
|
||||
const stats = manager.getStats()
|
||||
console.log('插件统计:', stats)
|
||||
|
||||
// 查看所有插件
|
||||
const plugins = manager.getPlugins()
|
||||
console.log(
|
||||
'已注册插件:',
|
||||
plugins.map((p) => p.name)
|
||||
)
|
||||
```
|
||||
|
||||
## ⚡ 性能优势
|
||||
|
||||
- **First 钩子**:一旦找到结果立即停止,避免无效计算
|
||||
- **Parallel 钩子**:真正并发执行,不阻塞主流程
|
||||
- **Sequential 钩子**:保证数据转换的顺序性
|
||||
- **Stream 钩子**:直接集成 AI SDK,零开销
|
||||
|
||||
这个设计兼顾了简洁性和强大功能,为 AI Core 提供了灵活而高效的扩展机制。
|
||||
10
packages/aiCore/src/core/plugins/built-in/index.ts
Normal file
10
packages/aiCore/src/core/plugins/built-in/index.ts
Normal file
@@ -0,0 +1,10 @@
|
||||
/**
 * Built-in plugin namespace.
 * All built-in plugins are prefixed with 'built-in:'.
 */
export const BUILT_IN_PLUGIN_PREFIX = 'built-in:'

export { createLoggingPlugin } from './logging'
export { createPromptToolUsePlugin } from './toolUsePlugin/promptToolUsePlugin'
export type { PromptToolUseConfig, ToolUseRequestContext, ToolUseResult } from './toolUsePlugin/type'
export { webSearchPlugin } from './webSearchPlugin'
|
||||
86
packages/aiCore/src/core/plugins/built-in/logging.ts
Normal file
86
packages/aiCore/src/core/plugins/built-in/logging.ts
Normal file
@@ -0,0 +1,86 @@
|
||||
/**
|
||||
* 内置插件:日志记录
|
||||
* 记录AI调用的关键信息,支持性能监控和调试
|
||||
*/
|
||||
import { definePlugin } from '../index'
|
||||
import type { AiRequestContext } from '../types'
|
||||
|
||||
export interface LoggingConfig {
  // Log level forwarded to the logger for start/end events
  level?: 'debug' | 'info' | 'warn' | 'error'
  // Whether request parameters are included in the start log (default: true)
  logParams?: boolean
  // Whether the result object is included in the completion log (default: false)
  logResult?: boolean
  // Whether request duration is included in the logs (default: true)
  logPerformance?: boolean
  // Custom log sink; defaults to console.log
  logger?: (level: string, message: string, data?: any) => void
}
|
||||
|
||||
/**
|
||||
* 创建日志插件
|
||||
*/
|
||||
export function createLoggingPlugin(config: LoggingConfig = {}) {
|
||||
const { level = 'info', logParams = true, logResult = false, logPerformance = true, logger = console.log } = config
|
||||
|
||||
const startTimes = new Map<string, number>()
|
||||
|
||||
return definePlugin({
|
||||
name: 'built-in:logging',
|
||||
|
||||
onRequestStart: (context: AiRequestContext) => {
|
||||
const requestId = context.requestId
|
||||
startTimes.set(requestId, Date.now())
|
||||
|
||||
logger(level, `🚀 AI Request Started`, {
|
||||
requestId,
|
||||
providerId: context.providerId,
|
||||
modelId: context.modelId,
|
||||
originalParams: logParams ? context.originalParams : '[hidden]'
|
||||
})
|
||||
},
|
||||
|
||||
onRequestEnd: (context: AiRequestContext, result: any) => {
|
||||
const requestId = context.requestId
|
||||
const startTime = startTimes.get(requestId)
|
||||
const duration = startTime ? Date.now() - startTime : undefined
|
||||
startTimes.delete(requestId)
|
||||
|
||||
const logData: any = {
|
||||
requestId,
|
||||
providerId: context.providerId,
|
||||
modelId: context.modelId
|
||||
}
|
||||
|
||||
if (logPerformance && duration) {
|
||||
logData.duration = `${duration}ms`
|
||||
}
|
||||
|
||||
if (logResult) {
|
||||
logData.result = result
|
||||
}
|
||||
|
||||
logger(level, `✅ AI Request Completed`, logData)
|
||||
},
|
||||
|
||||
onError: (error: Error, context: AiRequestContext) => {
|
||||
const requestId = context.requestId
|
||||
const startTime = startTimes.get(requestId)
|
||||
const duration = startTime ? Date.now() - startTime : undefined
|
||||
startTimes.delete(requestId)
|
||||
|
||||
logger('error', `❌ AI Request Failed`, {
|
||||
requestId,
|
||||
providerId: context.providerId,
|
||||
modelId: context.modelId,
|
||||
duration: duration ? `${duration}ms` : undefined,
|
||||
error: {
|
||||
name: error.name,
|
||||
message: error.message,
|
||||
stack: error.stack
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,139 @@
|
||||
/**
|
||||
* 流事件管理器
|
||||
*
|
||||
* 负责处理 AI SDK 流事件的发送和管理
|
||||
* 从 promptToolUsePlugin.ts 中提取出来以降低复杂度
|
||||
*/
|
||||
import type { ModelMessage } from 'ai'
|
||||
|
||||
import type { AiRequestContext } from '../../types'
|
||||
import type { StreamController } from './ToolExecutor'
|
||||
|
||||
/**
 * Stream event manager.
 *
 * Emits AI SDK stream events for the prompt-based tool-use flow and pipes the
 * output of recursive follow-up calls back into the current stream.
 * (Extracted from promptToolUsePlugin.ts to reduce its complexity.)
 */
export class StreamEventManager {
  /**
   * Emit a 'start-step' event marking the beginning of a tool-call step.
   */
  sendStepStartEvent(controller: StreamController): void {
    controller.enqueue({
      type: 'start-step',
      request: {},
      warnings: []
    })
  }

  /**
   * Emit a 'finish-step' event, reusing response/usage/providerMetadata
   * from the triggering chunk.
   */
  sendStepFinishEvent(controller: StreamController, chunk: any): void {
    controller.enqueue({
      type: 'finish-step',
      finishReason: 'stop',
      response: chunk.response,
      usage: chunk.usage,
      providerMetadata: chunk.providerMetadata
    })
  }

  /**
   * Run the recursive follow-up call and pipe its fullStream into the
   * current stream. Failures are reported on the stream (see
   * handleRecursiveCallError) instead of being rethrown.
   */
  async handleRecursiveCall(
    controller: StreamController,
    recursiveParams: any,
    context: AiRequestContext,
    stepId: string
  ): Promise<void> {
    try {
      console.log('[MCP Prompt] Starting recursive call after tool execution...')

      const recursiveResult = await context.recursiveCall(recursiveParams)

      if (recursiveResult && recursiveResult.fullStream) {
        await this.pipeRecursiveStream(controller, recursiveResult.fullStream)
      } else {
        console.warn('[MCP Prompt] No fullstream found in recursive result:', recursiveResult)
      }
    } catch (error) {
      this.handleRecursiveCallError(controller, error, stepId)
    }
  }

  /**
   * Forward every chunk of the recursive stream into the current stream.
   * The nested stream's 'finish' event is swallowed so only the outer
   * stream terminates the overall response.
   */
  private async pipeRecursiveStream(controller: StreamController, recursiveStream: ReadableStream): Promise<void> {
    const reader = recursiveStream.getReader()
    try {
      while (true) {
        const { done, value } = await reader.read()
        if (done) {
          break
        }
        if (value.type === 'finish') {
          // Do not forward 'finish' from the nested stream
          break
        }
        // Forward the recursive stream's chunk to the current stream
        controller.enqueue(value)
      }
    } finally {
      reader.releaseLock()
    }
  }

  /**
   * Report a recursive-call failure on the stream without aborting it.
   */
  private handleRecursiveCallError(controller: StreamController, error: unknown, stepId: string): void {
    console.error('[MCP Prompt] Recursive call failed:', error)

    // AI SDK standard error shape; the stream itself is kept alive
    controller.enqueue({
      type: 'error',
      error: {
        message: error instanceof Error ? error.message : String(error),
        name: error instanceof Error ? error.name : 'RecursiveCallError'
      }
    })

    // Emit a text-delta so the stream stays continuous after the error
    controller.enqueue({
      type: 'text-delta',
      id: stepId,
      text: '\n\n[工具执行后递归调用失败,继续对话...]'
    })
  }

  /**
   * Build params for the recursive follow-up call: prior messages plus the
   * assistant's buffered text and a user message carrying the tool results.
   * NOTE: also mutates context.originalParams.messages to the new list.
   */
  buildRecursiveParams(context: AiRequestContext, textBuffer: string, toolResultsText: string, tools: any): any {
    // Build the new conversation history
    const newMessages: ModelMessage[] = [
      ...(context.originalParams.messages || []),
      {
        role: 'assistant',
        content: textBuffer
      },
      {
        role: 'user',
        content: toolResultsText
      }
    ]

    // Continue the conversation recursively, re-passing tools
    const recursiveParams = {
      ...context.originalParams,
      messages: newMessages,
      tools: tools
    }

    // Keep the context's message history in sync
    context.originalParams.messages = newMessages

    return recursiveParams
  }
}
|
||||
@@ -0,0 +1,156 @@
|
||||
/**
|
||||
* 工具执行器
|
||||
*
|
||||
* 负责工具的执行、结果格式化和相关事件发送
|
||||
* 从 promptToolUsePlugin.ts 中提取出来以降低复杂度
|
||||
*/
|
||||
import type { ToolSet } from 'ai'
|
||||
|
||||
import type { ToolUseResult } from './type'
|
||||
|
||||
/**
 * Result of executing one tool call.
 */
export interface ExecutedResult {
  toolCallId: string
  toolName: string
  // Tool return value on success; the error message string when isError is true
  result: any
  isError?: boolean
}

/**
 * Minimal stream-controller shape (subset of the AI SDK controller).
 */
export interface StreamController {
  enqueue(chunk: any): void
}
|
||||
|
||||
/**
|
||||
* 工具执行器类
|
||||
*/
|
||||
export class ToolExecutor {
|
||||
/**
|
||||
* 执行多个工具调用
|
||||
*/
|
||||
async executeTools(
|
||||
toolUses: ToolUseResult[],
|
||||
tools: ToolSet,
|
||||
controller: StreamController
|
||||
): Promise<ExecutedResult[]> {
|
||||
const executedResults: ExecutedResult[] = []
|
||||
|
||||
for (const toolUse of toolUses) {
|
||||
try {
|
||||
const tool = tools[toolUse.toolName]
|
||||
if (!tool || typeof tool.execute !== 'function') {
|
||||
throw new Error(`Tool "${toolUse.toolName}" has no execute method`)
|
||||
}
|
||||
|
||||
// 发送工具调用开始事件
|
||||
this.sendToolStartEvents(controller, toolUse)
|
||||
|
||||
console.log(`[MCP Prompt Stream] Executing tool: ${toolUse.toolName}`, toolUse.arguments)
|
||||
|
||||
// 发送 tool-call 事件
|
||||
controller.enqueue({
|
||||
type: 'tool-call',
|
||||
toolCallId: toolUse.id,
|
||||
toolName: toolUse.toolName,
|
||||
input: tool.inputSchema
|
||||
})
|
||||
|
||||
const result = await tool.execute(toolUse.arguments, {
|
||||
toolCallId: toolUse.id,
|
||||
messages: [],
|
||||
abortSignal: new AbortController().signal
|
||||
})
|
||||
|
||||
// 发送 tool-result 事件
|
||||
controller.enqueue({
|
||||
type: 'tool-result',
|
||||
toolCallId: toolUse.id,
|
||||
toolName: toolUse.toolName,
|
||||
input: toolUse.arguments,
|
||||
output: result
|
||||
})
|
||||
|
||||
executedResults.push({
|
||||
toolCallId: toolUse.id,
|
||||
toolName: toolUse.toolName,
|
||||
result,
|
||||
isError: false
|
||||
})
|
||||
} catch (error) {
|
||||
console.error(`[MCP Prompt Stream] Tool execution failed: ${toolUse.toolName}`, error)
|
||||
|
||||
// 处理错误情况
|
||||
const errorResult = this.handleToolError(toolUse, error, controller)
|
||||
executedResults.push(errorResult)
|
||||
}
|
||||
}
|
||||
|
||||
return executedResults
|
||||
}
|
||||
|
||||
/**
|
||||
* 格式化工具结果为 Cherry Studio 标准格式
|
||||
*/
|
||||
formatToolResults(executedResults: ExecutedResult[]): string {
|
||||
return executedResults
|
||||
.map((tr) => {
|
||||
if (!tr.isError) {
|
||||
return `<tool_use_result>\n <name>${tr.toolName}</name>\n <result>${JSON.stringify(tr.result)}</result>\n</tool_use_result>`
|
||||
} else {
|
||||
const error = tr.result || 'Unknown error'
|
||||
return `<tool_use_result>\n <name>${tr.toolName}</name>\n <error>${error}</error>\n</tool_use_result>`
|
||||
}
|
||||
})
|
||||
.join('\n\n')
|
||||
}
|
||||
|
||||
/**
|
||||
* 发送工具调用开始相关事件
|
||||
*/
|
||||
private sendToolStartEvents(controller: StreamController, toolUse: ToolUseResult): void {
|
||||
// 发送 tool-input-start 事件
|
||||
controller.enqueue({
|
||||
type: 'tool-input-start',
|
||||
id: toolUse.id,
|
||||
toolName: toolUse.toolName
|
||||
})
|
||||
}
|
||||
|
||||
/**
|
||||
* 处理工具执行错误
|
||||
*/
|
||||
private handleToolError(
|
||||
toolUse: ToolUseResult,
|
||||
error: unknown,
|
||||
controller: StreamController
|
||||
// _tools: ToolSet
|
||||
): ExecutedResult {
|
||||
// 使用 AI SDK 标准错误格式
|
||||
// const toolError: TypedToolError<typeof _tools> = {
|
||||
// type: 'tool-error',
|
||||
// toolCallId: toolUse.id,
|
||||
// toolName: toolUse.toolName,
|
||||
// input: toolUse.arguments,
|
||||
// error: error instanceof Error ? error.message : String(error)
|
||||
// }
|
||||
|
||||
// controller.enqueue(toolError)
|
||||
|
||||
// 发送标准错误事件
|
||||
controller.enqueue({
|
||||
type: 'error',
|
||||
error: error instanceof Error ? error.message : String(error)
|
||||
})
|
||||
|
||||
return {
|
||||
toolCallId: toolUse.id,
|
||||
toolName: toolUse.toolName,
|
||||
result: error instanceof Error ? error.message : String(error),
|
||||
isError: true
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,373 @@
|
||||
/**
|
||||
* 内置插件:MCP Prompt 模式
|
||||
* 为不支持原生 Function Call 的模型提供 prompt 方式的工具调用
|
||||
* 内置默认逻辑,支持自定义覆盖
|
||||
*/
|
||||
import type { TextStreamPart, ToolSet } from 'ai'
|
||||
|
||||
import { definePlugin } from '../../index'
|
||||
import type { AiRequestContext } from '../../types'
|
||||
import { StreamEventManager } from './StreamEventManager'
|
||||
import { ToolExecutor } from './ToolExecutor'
|
||||
import { PromptToolUseConfig, ToolUseResult } from './type'
|
||||
|
||||
/**
|
||||
* 默认系统提示符模板(提取自 Cherry Studio)
|
||||
*/
|
||||
const DEFAULT_SYSTEM_PROMPT = `In this environment you have access to a set of tools you can use to answer the user's question. \\
|
||||
You can use one tool per message, and will receive the result of that tool use in the user's response. You use tools step-by-step to accomplish a given task, with each tool use informed by the result of the previous tool use.
|
||||
|
||||
## Tool Use Formatting
|
||||
|
||||
Tool use is formatted using XML-style tags. The tool name is enclosed in opening and closing tags, and each parameter is similarly enclosed within its own set of tags. Here's the structure:
|
||||
|
||||
<tool_use>
|
||||
<name>{tool_name}</name>
|
||||
<arguments>{json_arguments}</arguments>
|
||||
</tool_use>
|
||||
|
||||
The tool name should be the exact name of the tool you are using, and the arguments should be a JSON object containing the parameters required by that tool. For example:
|
||||
<tool_use>
|
||||
<name>python_interpreter</name>
|
||||
<arguments>{"code": "5 + 3 + 1294.678"}</arguments>
|
||||
</tool_use>
|
||||
|
||||
The user will respond with the result of the tool use, which should be formatted as follows:
|
||||
|
||||
<tool_use_result>
|
||||
<name>{tool_name}</name>
|
||||
<result>{result}</result>
|
||||
</tool_use_result>
|
||||
|
||||
The result should be a string, which can represent a file or any other output type. You can use this result as input for the next action.
|
||||
For example, if the result of the tool use is an image file, you can use it in the next action like this:
|
||||
|
||||
<tool_use>
|
||||
<name>image_transformer</name>
|
||||
<arguments>{"image": "image_1.jpg"}</arguments>
|
||||
</tool_use>
|
||||
|
||||
Always adhere to this format for the tool use to ensure proper parsing and execution.
|
||||
|
||||
## Tool Use Examples
|
||||
{{ TOOL_USE_EXAMPLES }}
|
||||
|
||||
## Tool Use Available Tools
|
||||
Above example were using notional tools that might not exist for you. You only have access to these tools:
|
||||
{{ AVAILABLE_TOOLS }}
|
||||
|
||||
## Tool Use Rules
|
||||
Here are the rules you should always follow to solve your task:
|
||||
1. Always use the right arguments for the tools. Never use variable names as the action arguments, use the value instead.
|
||||
2. Call a tool only when needed: do not call the search agent if you do not need information, try to solve the task yourself.
|
||||
3. If no tool call is needed, just answer the question directly.
|
||||
4. Never re-do a tool call that you previously did with the exact same parameters.
|
||||
5. For tool use, MAKE SURE use XML tag format as shown in the examples above. Do not use any other format.
|
||||
|
||||
# User Instructions
|
||||
{{ USER_SYSTEM_PROMPT }}
|
||||
|
||||
Now Begin! If you solve the task correctly, you will receive a reward of $1,000,000.`
|
||||
|
||||
/**
|
||||
* 默认工具使用示例(提取自 Cherry Studio)
|
||||
*/
|
||||
const DEFAULT_TOOL_USE_EXAMPLES = `
|
||||
Here are a few examples using notional tools:
|
||||
---
|
||||
User: Generate an image of the oldest person in this document.
|
||||
|
||||
A: I can use the document_qa tool to find out who the oldest person is in the document.
|
||||
<tool_use>
|
||||
<name>document_qa</name>
|
||||
<arguments>{"document": "document.pdf", "question": "Who is the oldest person mentioned?"}</arguments>
|
||||
</tool_use>
|
||||
|
||||
User: <tool_use_result>
|
||||
<name>document_qa</name>
|
||||
<result>John Doe, a 55 year old lumberjack living in Newfoundland.</result>
|
||||
</tool_use_result>
|
||||
|
||||
A: I can use the image_generator tool to create a portrait of John Doe.
|
||||
<tool_use>
|
||||
<name>image_generator</name>
|
||||
<arguments>{"prompt": "A portrait of John Doe, a 55-year-old man living in Canada."}</arguments>
|
||||
</tool_use>
|
||||
|
||||
User: <tool_use_result>
|
||||
<name>image_generator</name>
|
||||
<result>image.png</result>
|
||||
</tool_use_result>
|
||||
|
||||
A: the image is generated as image.png
|
||||
|
||||
---
|
||||
User: "What is the result of the following operation: 5 + 3 + 1294.678?"
|
||||
|
||||
A: I can use the python_interpreter tool to calculate the result of the operation.
|
||||
<tool_use>
|
||||
<name>python_interpreter</name>
|
||||
<arguments>{"code": "5 + 3 + 1294.678"}</arguments>
|
||||
</tool_use>
|
||||
|
||||
User: <tool_use_result>
|
||||
<name>python_interpreter</name>
|
||||
<result>1302.678</result>
|
||||
</tool_use_result>
|
||||
|
||||
A: The result of the operation is 1302.678.
|
||||
|
||||
---
|
||||
User: "Which city has the highest population , Guangzhou or Shanghai?"
|
||||
|
||||
A: I can use the search tool to find the population of Guangzhou.
|
||||
<tool_use>
|
||||
<name>search</name>
|
||||
<arguments>{"query": "Population Guangzhou"}</arguments>
|
||||
</tool_use>
|
||||
|
||||
User: <tool_use_result>
|
||||
<name>search</name>
|
||||
<result>Guangzhou has a population of 15 million inhabitants as of 2021.</result>
|
||||
</tool_use_result>
|
||||
|
||||
A: I can use the search tool to find the population of Shanghai.
|
||||
<tool_use>
|
||||
<name>search</name>
|
||||
<arguments>{"query": "Population Shanghai"}</arguments>
|
||||
</tool_use>
|
||||
|
||||
User: <tool_use_result>
|
||||
<name>search</name>
|
||||
<result>26 million (2019)</result>
|
||||
</tool_use_result>
|
||||
Assistant: The population of Shanghai is 26 million, while Guangzhou has a population of 15 million. Therefore, Shanghai has the highest population.`
|
||||
|
||||
/**
|
||||
* 构建可用工具部分(提取自 Cherry Studio)
|
||||
*/
|
||||
function buildAvailableTools(tools: ToolSet): string {
|
||||
const availableTools = Object.keys(tools)
|
||||
.map((toolName: string) => {
|
||||
const tool = tools[toolName]
|
||||
return `
|
||||
<tool>
|
||||
<name>${toolName}</name>
|
||||
<description>${tool.description || ''}</description>
|
||||
<arguments>
|
||||
${tool.inputSchema ? JSON.stringify(tool.inputSchema) : ''}
|
||||
</arguments>
|
||||
</tool>
|
||||
`
|
||||
})
|
||||
.join('\n')
|
||||
return `<tools>
|
||||
${availableTools}
|
||||
</tools>`
|
||||
}
|
||||
|
||||
/**
|
||||
* 默认的系统提示符构建函数(提取自 Cherry Studio)
|
||||
*/
|
||||
function defaultBuildSystemPrompt(userSystemPrompt: string, tools: ToolSet): string {
|
||||
const availableTools = buildAvailableTools(tools)
|
||||
|
||||
const fullPrompt = DEFAULT_SYSTEM_PROMPT.replace('{{ TOOL_USE_EXAMPLES }}', DEFAULT_TOOL_USE_EXAMPLES)
|
||||
.replace('{{ AVAILABLE_TOOLS }}', availableTools)
|
||||
.replace('{{ USER_SYSTEM_PROMPT }}', userSystemPrompt || '')
|
||||
|
||||
return fullPrompt
|
||||
}
|
||||
|
||||
/**
|
||||
* 默认工具解析函数(提取自 Cherry Studio)
|
||||
* 解析 XML 格式的工具调用
|
||||
*/
|
||||
function defaultParseToolUse(content: string, tools: ToolSet): { results: ToolUseResult[]; content: string } {
|
||||
if (!content || !tools || Object.keys(tools).length === 0) {
|
||||
return { results: [], content: content }
|
||||
}
|
||||
|
||||
// 支持两种格式:
|
||||
// 1. 完整的 <tool_use></tool_use> 标签包围的内容
|
||||
// 2. 只有内部内容(从 TagExtractor 提取出来的)
|
||||
|
||||
let contentToProcess = content
|
||||
// 如果内容不包含 <tool_use> 标签,说明是从 TagExtractor 提取的内部内容,需要包装
|
||||
if (!content.includes('<tool_use>')) {
|
||||
contentToProcess = `<tool_use>\n${content}\n</tool_use>`
|
||||
}
|
||||
|
||||
const toolUsePattern =
|
||||
/<tool_use>([\s\S]*?)<name>([\s\S]*?)<\/name>([\s\S]*?)<arguments>([\s\S]*?)<\/arguments>([\s\S]*?)<\/tool_use>/g
|
||||
const results: ToolUseResult[] = []
|
||||
let match
|
||||
let idx = 0
|
||||
|
||||
// Find all tool use blocks
|
||||
while ((match = toolUsePattern.exec(contentToProcess)) !== null) {
|
||||
const fullMatch = match[0]
|
||||
const toolName = match[2].trim()
|
||||
const toolArgs = match[4].trim()
|
||||
|
||||
// Try to parse the arguments as JSON
|
||||
let parsedArgs
|
||||
try {
|
||||
parsedArgs = JSON.parse(toolArgs)
|
||||
} catch (error) {
|
||||
// If parsing fails, use the string as is
|
||||
parsedArgs = toolArgs
|
||||
}
|
||||
|
||||
// Find the corresponding tool
|
||||
const tool = tools[toolName]
|
||||
if (!tool) {
|
||||
console.warn(`Tool "${toolName}" not found in available tools`)
|
||||
continue
|
||||
}
|
||||
|
||||
// Add to results array
|
||||
results.push({
|
||||
id: `${toolName}-${idx++}`, // Unique ID for each tool use
|
||||
toolName: toolName,
|
||||
arguments: parsedArgs,
|
||||
status: 'pending'
|
||||
})
|
||||
contentToProcess = contentToProcess.replace(fullMatch, '')
|
||||
}
|
||||
return { results, content: contentToProcess }
|
||||
}
|
||||
|
||||
/**
 * Built-in plugin that emulates tool calling via prompts ("prompt tool use").
 *
 * Inbound (transformParams): removes native `tools` from the request, stashes
 * them on the context as `mcpTools`, and injects a system prompt that teaches
 * the model to emit tool calls as text.
 * Outbound (transformStream): buffers text deltas, parses tool-use blocks out
 * of the buffered text at step end, executes the tools, and recursively
 * continues the conversation with the tool results.
 */
export const createPromptToolUsePlugin = (config: PromptToolUseConfig = {}) => {
  // Fall back to the built-in prompt builder / parser when not customized.
  const { enabled = true, buildSystemPrompt = defaultBuildSystemPrompt, parseToolUse = defaultParseToolUse } = config

  return definePlugin({
    name: 'built-in:prompt-tool-use',
    transformParams: (params: any, context: AiRequestContext) => {
      // No-op when disabled or when the request carries no tools.
      if (!enabled || !params.tools || typeof params.tools !== 'object') {
        return params
      }

      // Keep the original tool set on the context; transformStream needs it.
      context.mcpTools = params.tools
      console.log('tools stored in context', params.tools)

      // Build the system prompt (user system text + tool-use instructions).
      const userSystemPrompt = typeof params.system === 'string' ? params.system : ''
      const systemPrompt = buildSystemPrompt(userSystemPrompt, params.tools)
      let systemMessage: string | null = systemPrompt
      console.log('config.context', context)
      if (config.createSystemMessage) {
        // If the user supplied a custom system-message factory, use it.
        systemMessage = config.createSystemMessage(systemPrompt, params, context)
      }

      // Strip native tools and switch the request to prompt mode.
      const transformedParams = {
        ...params,
        ...(systemMessage ? { system: systemMessage } : {}),
        tools: undefined
      }
      context.originalParams = transformedParams
      console.log('transformedParams', transformedParams)
      return transformedParams
    },
    transformStream: (_: any, context: AiRequestContext) => () => {
      // Accumulates the text deltas of the current step so tool-use blocks can
      // be parsed out of the complete text when the step ends.
      let textBuffer = ''
      let stepId = ''

      if (!context.mcpTools) {
        throw new Error('No tools available')
      }

      // Create the tool executor and stream-event manager (defined elsewhere
      // in this package).
      const toolExecutor = new ToolExecutor()
      const streamEventManager = new StreamEventManager()

      type TOOLS = NonNullable<typeof context.mcpTools>
      return new TransformStream<TextStreamPart<TOOLS>, TextStreamPart<TOOLS>>({
        async transform(
          chunk: TextStreamPart<TOOLS>,
          controller: TransformStreamDefaultController<TextStreamPart<TOOLS>>
        ) {
          // Collect text content and forward the delta unchanged.
          if (chunk.type === 'text-delta') {
            textBuffer += chunk.text || ''
            stepId = chunk.id || ''
            controller.enqueue(chunk)
            return
          }

          if (chunk.type === 'text-end' || chunk.type === 'finish-step') {
            const tools = context.mcpTools
            if (!tools || Object.keys(tools).length === 0) {
              controller.enqueue(chunk)
              return
            }

            // Parse tool calls out of the buffered text.
            const { results: parsedTools, content: parsedContent } = parseToolUse(textBuffer, tools)
            const validToolUses = parsedTools.filter((t) => t.status === 'pending')

            // No valid tool calls: pass the original event through.
            if (validToolUses.length === 0) {
              controller.enqueue(chunk)
              return
            }

            if (chunk.type === 'text-end') {
              // Re-emit text-end carrying the text with tool-use markup removed.
              controller.enqueue({
                type: 'text-end',
                id: stepId,
                providerMetadata: {
                  text: {
                    value: parsedContent
                  }
                }
              })
              return
            }

            // finish-step: report that the step ended because of tool calls.
            controller.enqueue({
              ...chunk,
              finishReason: 'tool-calls'
            })

            // Emit the start of the synthetic tool-execution step.
            streamEventManager.sendStepStartEvent(controller)

            // Execute the parsed tool calls.
            const executedResults = await toolExecutor.executeTools(validToolUses, tools, controller)

            // Emit the end of the synthetic tool-execution step.
            streamEventManager.sendStepFinishEvent(controller, chunk)

            // Recursively continue the conversation with the tool results.
            // NOTE(review): validToolUses.length > 0 always holds here (the
            // empty case returned above), so this guard is redundant.
            if (validToolUses.length > 0) {
              const toolResultsText = toolExecutor.formatToolResults(executedResults)
              const recursiveParams = streamEventManager.buildRecursiveParams(
                context,
                textBuffer,
                toolResultsText,
                tools
              )

              await streamEventManager.handleRecursiveCall(controller, recursiveParams, context, stepId)
            }

            // Reset per-step state.
            textBuffer = ''
            return
          }

          // Forward every other event type untouched.
          controller.enqueue(chunk)
        },

        flush() {
          // End-of-stream cleanup.
          console.log('[MCP Prompt] Stream ended, cleaning up...')
        }
      })
    }
  })
}
|
||||
@@ -0,0 +1,196 @@
|
||||
// Copied from https://github.com/vercel/ai/blob/main/packages/ai/core/util/get-potential-start-index.ts
|
||||
|
||||
/**
|
||||
* Returns the index of the start of the searchedText in the text, or null if it
|
||||
* is not found.
|
||||
*/
|
||||
export function getPotentialStartIndex(text: string, searchedText: string): number | null {
|
||||
// Return null immediately if searchedText is empty.
|
||||
if (searchedText.length === 0) {
|
||||
return null
|
||||
}
|
||||
|
||||
// Check if the searchedText exists as a direct substring of text.
|
||||
const directIndex = text.indexOf(searchedText)
|
||||
if (directIndex !== -1) {
|
||||
return directIndex
|
||||
}
|
||||
|
||||
// Otherwise, look for the largest suffix of "text" that matches
|
||||
// a prefix of "searchedText". We go from the end of text inward.
|
||||
for (let i = text.length - 1; i >= 0; i--) {
|
||||
const suffix = text.substring(i)
|
||||
if (searchedText.startsWith(suffix)) {
|
||||
return i
|
||||
}
|
||||
}
|
||||
|
||||
return null
|
||||
}
|
||||
|
||||
/** Configuration for a tag pair the extractor scans for. */
export interface TagConfig {
  // Literal opening marker, e.g. "<think>".
  openingTag: string
  // Literal closing marker, e.g. "</think>".
  closingTag: string
  // Optional separator inserted between consecutive emitted segments.
  separator?: string
}

/** Mutable parser state carried across TagExtractor.processText() calls. */
export interface TagExtractionState {
  // Unconsumed text (may end with a partially-received tag).
  textBuffer: string
  // True while between an opening and a closing tag.
  isInsideTag: boolean
  // True until the first tag section has been entered.
  isFirstTag: boolean
  // True until the first plain-text section has been emitted.
  isFirstText: boolean
  // True immediately after crossing a tag boundary (drives separator insertion).
  afterSwitch: boolean
  // Tag-interior text accumulated for the current (unclosed) tag.
  accumulatedTagContent: string
  // Whether accumulatedTagContent holds anything for the current tag.
  hasTagContent: boolean
}

/** One incremental output of TagExtractor.processText(). */
export interface TagExtractionResult {
  // Text emitted for this increment (may be empty for a completion record).
  content: string
  // Whether `content` lies inside a tag pair.
  isTagContent: boolean
  // True when a tag pair just closed and its full interior is available.
  complete: boolean
  // Full accumulated tag interior; set only when `complete` is true.
  tagContentExtracted?: string
}
|
||||
|
||||
/**
 * Generic tag-pair extractor for streamed text.
 *
 * Handles tag pairs such as <think>...</think> or <tool_use>...</tool_use>,
 * emitting incremental results as chunks arrive and a `complete` result once a
 * closing tag is seen. Uses getPotentialStartIndex to hold back buffer tail
 * text that may be the start of a tag split across chunks.
 */
export class TagExtractor {
  // Tag pair (and optional separator) this extractor looks for.
  private config: TagConfig
  // Mutable parser state, carried across processText() calls.
  private state: TagExtractionState

  constructor(config: TagConfig) {
    this.config = config
    this.state = {
      textBuffer: '',
      isInsideTag: false,
      isFirstTag: true,
      isFirstText: true,
      afterSwitch: false,
      accumulatedTagContent: '',
      hasTagContent: false
    }
  }

  /**
   * Feed a chunk of streamed text; returns the incremental extraction results
   * produced by consuming as much of the buffer as possible.
   */
  processText(newText: string): TagExtractionResult[] {
    this.state.textBuffer += newText
    const results: TagExtractionResult[] = []

    // Tag-extraction loop: alternate between looking for the opening and the
    // closing tag until the buffer is exhausted or ends in a partial tag.
    while (true) {
      const nextTag = this.state.isInsideTag ? this.config.closingTag : this.config.openingTag
      const startIndex = getPotentialStartIndex(this.state.textBuffer, nextTag)

      if (startIndex == null) {
        // No (partial) tag in the buffer: flush it all as content.
        const content = this.state.textBuffer
        if (content.length > 0) {
          results.push({
            content: this.addPrefix(content),
            isTagContent: this.state.isInsideTag,
            complete: false
          })

          if (this.state.isInsideTag) {
            // NOTE(review): addPrefix is called a second time here; its
            // afterSwitch flag was already cleared by the call above, so no
            // duplicate separator is produced — confirm this is intended.
            this.state.accumulatedTagContent += this.addPrefix(content)
            this.state.hasTagContent = true
          }
        }
        this.state.textBuffer = ''
        break
      }

      // Emit the content that precedes the (potential) tag.
      const contentBeforeTag = this.state.textBuffer.slice(0, startIndex)
      if (contentBeforeTag.length > 0) {
        results.push({
          content: this.addPrefix(contentBeforeTag),
          isTagContent: this.state.isInsideTag,
          complete: false
        })

        if (this.state.isInsideTag) {
          this.state.accumulatedTagContent += this.addPrefix(contentBeforeTag)
          this.state.hasTagContent = true
        }
      }

      // Whether the buffer holds the ENTIRE tag (not just a prefix of it).
      const foundFullMatch = startIndex + nextTag.length <= this.state.textBuffer.length

      if (foundFullMatch) {
        // Consume the complete tag from the buffer.
        this.state.textBuffer = this.state.textBuffer.slice(startIndex + nextTag.length)

        // If a tag section just closed, emit its accumulated interior as one
        // `complete` result.
        if (this.state.isInsideTag && this.state.hasTagContent) {
          results.push({
            content: '',
            isTagContent: false,
            complete: true,
            tagContentExtracted: this.state.accumulatedTagContent
          })
          this.state.accumulatedTagContent = ''
          this.state.hasTagContent = false
        }

        // Cross the boundary: toggle inside/outside and note the switch.
        this.state.isInsideTag = !this.state.isInsideTag
        this.state.afterSwitch = true

        if (this.state.isInsideTag) {
          this.state.isFirstTag = false
        } else {
          this.state.isFirstText = false
        }
      } else {
        // Partial tag at the end of the buffer: keep it and wait for more text.
        this.state.textBuffer = this.state.textBuffer.slice(startIndex)
        break
      }
    }

    return results
  }

  /**
   * Finish processing; returns a final `complete` result for any tag content
   * still accumulated (e.g. the stream ended before the closing tag arrived).
   */
  finalize(): TagExtractionResult | null {
    if (this.state.hasTagContent && this.state.accumulatedTagContent) {
      const result = {
        content: '',
        isTagContent: false,
        complete: true,
        tagContentExtracted: this.state.accumulatedTagContent
      }
      this.state.accumulatedTagContent = ''
      this.state.hasTagContent = false
      return result
    }
    return null
  }

  // Prepends the configured separator when emitting the first content after a
  // tag boundary (but not before the very first tag/text section).
  private addPrefix(text: string): string {
    const needsPrefix =
      this.state.afterSwitch && (this.state.isInsideTag ? !this.state.isFirstTag : !this.state.isFirstText)

    const prefix = needsPrefix && this.config.separator ? this.config.separator : ''
    this.state.afterSwitch = false
    return prefix + text
  }

  /**
   * Reset all parser state, making the extractor reusable for a new stream.
   */
  reset(): void {
    this.state = {
      textBuffer: '',
      isInsideTag: false,
      isFirstTag: true,
      isFirstText: true,
      afterSwitch: false,
      accumulatedTagContent: '',
      hasTagContent: false
    }
  }
}
|
||||
@@ -0,0 +1,33 @@
|
||||
import { ToolSet } from 'ai'
|
||||
|
||||
import { AiRequestContext } from '../..'
|
||||
|
||||
/**
 * Parse result type.
 * Represents one tool-use intent parsed out of an AI text response.
 */
export interface ToolUseResult {
  id: string
  toolName: string
  arguments: any
  status: 'pending' | 'invoking' | 'done' | 'error'
}

/** Options shared by all tool-use plugins. */
export interface BaseToolUsePluginConfig {
  enabled?: boolean
}

/** Configuration for the prompt-based tool-use plugin. */
export interface PromptToolUseConfig extends BaseToolUsePluginConfig {
  // Custom system-prompt builder (optional; a default implementation exists).
  buildSystemPrompt?: (userSystemPrompt: string, tools: ToolSet) => string
  // Custom tool-use parser (optional; a default implementation exists).
  parseToolUse?: (content: string, tools: ToolSet) => { results: ToolUseResult[]; content: string }
  // Custom system-message factory; returning null suppresses the system message.
  createSystemMessage?: (systemPrompt: string, originalParams: any, context: AiRequestContext) => string | null
}

/**
 * Extended AI request context that carries the MCP tool set.
 */
export interface ToolUseRequestContext extends AiRequestContext {
  mcpTools: ToolSet
}
|
||||
@@ -0,0 +1,67 @@
|
||||
import { anthropic } from '@ai-sdk/anthropic'
|
||||
import { google } from '@ai-sdk/google'
|
||||
import { openai } from '@ai-sdk/openai'
|
||||
|
||||
import { ProviderOptionsMap } from '../../../options/types'
|
||||
|
||||
/**
 * Parameter types extracted from the AI SDK tool factory functions, keeping
 * this config type-safe against the SDK's own signatures.
 */
type OpenAISearchConfig = Parameters<typeof openai.tools.webSearchPreview>[0]
type AnthropicSearchConfig = Parameters<typeof anthropic.tools.webSearch_20250305>[0]
type GoogleSearchConfig = Parameters<typeof google.tools.googleSearch>[0]

/**
 * Full configuration object the plugin receives at initialization.
 *
 * Its shape mirrors ProviderOptions so upstream code can manage configuration
 * uniformly.
 */
export interface WebSearchPluginConfig {
  openai?: OpenAISearchConfig
  anthropic?: AnthropicSearchConfig
  xai?: ProviderOptionsMap['xai']['searchParameters']
  google?: GoogleSearchConfig
  'google-vertex'?: GoogleSearchConfig
}
|
||||
|
||||
/**
 * Default plugin configuration: search enabled for every supported provider.
 */
export const DEFAULT_WEB_SEARCH_CONFIG: WebSearchPluginConfig = {
  google: {},
  'google-vertex': {},
  openai: {},
  xai: {
    mode: 'on',
    returnCitations: true,
    maxSearchResults: 5,
    sources: [{ type: 'web' }, { type: 'x' }, { type: 'news' }]
  },
  anthropic: {
    maxUses: 5
  }
}
|
||||
|
||||
/** Output shapes of the provider-native web-search tools. */
export type WebSearchToolOutputSchema = {
  // Anthropic tool — defined manually.
  anthropicWebSearch: Array<{
    url: string
    title: string
    pageAge: string | null
    encryptedContent: string
    type: string
  }>

  // OpenAI tool — based on observed output.
  openaiWebSearch: {
    status: 'completed' | 'failed'
  }

  // Google tool.
  googleSearch: {
    webSearchQueries?: string[]
    groundingChunks?: Array<{
      web?: { uri: string; title: string }
    }>
  }
}
|
||||
@@ -0,0 +1,69 @@
|
||||
/**
|
||||
* Web Search Plugin
|
||||
* 提供统一的网络搜索能力,支持多个 AI Provider
|
||||
*/
|
||||
import { anthropic } from '@ai-sdk/anthropic'
|
||||
import { google } from '@ai-sdk/google'
|
||||
import { openai } from '@ai-sdk/openai'
|
||||
|
||||
import { createXaiOptions, mergeProviderOptions } from '../../../options'
|
||||
import { definePlugin } from '../../'
|
||||
import type { AiRequestContext } from '../../types'
|
||||
import { DEFAULT_WEB_SEARCH_CONFIG, WebSearchPluginConfig } from './helper'
|
||||
|
||||
/**
|
||||
* 网络搜索插件
|
||||
*
|
||||
* @param config - 在插件初始化时传入的静态配置
|
||||
*/
|
||||
export const webSearchPlugin = (config: WebSearchPluginConfig = DEFAULT_WEB_SEARCH_CONFIG) =>
|
||||
definePlugin({
|
||||
name: 'webSearch',
|
||||
enforce: 'pre',
|
||||
|
||||
transformParams: async (params: any, context: AiRequestContext) => {
|
||||
const { providerId } = context
|
||||
switch (providerId) {
|
||||
case 'openai': {
|
||||
if (config.openai) {
|
||||
if (!params.tools) params.tools = {}
|
||||
params.tools.web_search_preview = openai.tools.webSearchPreview(config.openai)
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
case 'anthropic': {
|
||||
if (config.anthropic) {
|
||||
if (!params.tools) params.tools = {}
|
||||
params.tools.web_search = anthropic.tools.webSearch_20250305(config.anthropic)
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
case 'google': {
|
||||
// case 'google-vertex':
|
||||
if (!params.tools) params.tools = {}
|
||||
params.tools.web_search = google.tools.googleSearch(config.google || {})
|
||||
break
|
||||
}
|
||||
|
||||
case 'xai': {
|
||||
if (config.xai) {
|
||||
const searchOptions = createXaiOptions({
|
||||
searchParameters: { ...config.xai, mode: 'on' }
|
||||
})
|
||||
params.providerOptions = mergeProviderOptions(params.providerOptions, searchOptions)
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return params
|
||||
}
|
||||
})
|
||||
|
||||
// 导出类型定义供开发者使用
|
||||
export type { WebSearchPluginConfig, WebSearchToolOutputSchema } from './helper'
|
||||
|
||||
// 默认导出
|
||||
export default webSearchPlugin
|
||||
32
packages/aiCore/src/core/plugins/index.ts
Normal file
32
packages/aiCore/src/core/plugins/index.ts
Normal file
@@ -0,0 +1,32 @@
|
||||
// 核心类型和接口
|
||||
export type { AiPlugin, AiRequestContext, HookResult, PluginManagerConfig } from './types'
|
||||
import type { ProviderId } from '../providers'
|
||||
import type { AiPlugin, AiRequestContext } from './types'
|
||||
|
||||
// 插件管理器
|
||||
export { PluginManager } from './manager'
|
||||
|
||||
// 工具函数
|
||||
export function createContext<T extends ProviderId>(
|
||||
providerId: T,
|
||||
modelId: string,
|
||||
originalParams: any
|
||||
): AiRequestContext {
|
||||
return {
|
||||
providerId,
|
||||
modelId,
|
||||
originalParams,
|
||||
metadata: {},
|
||||
startTime: Date.now(),
|
||||
requestId: `${providerId}-${modelId}-${Date.now()}-${Math.random().toString(36).slice(2)}`,
|
||||
// 占位
|
||||
recursiveCall: () => Promise.resolve(null)
|
||||
}
|
||||
}
|
||||
|
||||
// Plugin builder — an identity helper that gives plugin objects (or plugin
// factories) proper type inference at the definition site.
export function definePlugin(plugin: AiPlugin): AiPlugin
export function definePlugin<T extends (...args: any[]) => AiPlugin>(pluginFactory: T): T
export function definePlugin(plugin: AiPlugin | ((...args: any[]) => AiPlugin)) {
  return plugin
}
|
||||
184
packages/aiCore/src/core/plugins/manager.ts
Normal file
184
packages/aiCore/src/core/plugins/manager.ts
Normal file
@@ -0,0 +1,184 @@
|
||||
import { AiPlugin, AiRequestContext } from './types'
|
||||
|
||||
/**
 * Plugin manager.
 *
 * Holds an ordered list of plugins (pre -> normal -> post) and offers one
 * execution strategy per hook family: first-wins, sequential transform,
 * serial context configuration, and parallel side effects.
 */
export class PluginManager {
  private plugins: AiPlugin[] = []

  constructor(plugins: AiPlugin[] = []) {
    this.plugins = this.sortPlugins(plugins)
  }

  /**
   * Add a plugin (the list is re-sorted by enforce order).
   */
  use(plugin: AiPlugin): this {
    this.plugins = this.sortPlugins([...this.plugins, plugin])
    return this
  }

  /**
   * Remove a plugin by name.
   */
  remove(pluginName: string): this {
    this.plugins = this.plugins.filter((p) => p.name !== pluginName)
    return this
  }

  /**
   * Plugin ordering: pre -> normal -> post (stable within each group).
   */
  private sortPlugins(plugins: AiPlugin[]): AiPlugin[] {
    const pre: AiPlugin[] = []
    const normal: AiPlugin[] = []
    const post: AiPlugin[] = []

    plugins.forEach((plugin) => {
      if (plugin.enforce === 'pre') {
        pre.push(plugin)
      } else if (plugin.enforce === 'post') {
        post.push(plugin)
      } else {
        normal.push(plugin)
      }
    })

    return [...pre, ...normal, ...post]
  }

  /**
   * Execute a "First" hook — returns the first non-null/undefined result.
   */
  async executeFirst<T>(
    hookName: 'resolveModel' | 'loadTemplate',
    arg: any,
    context: AiRequestContext
  ): Promise<T | null> {
    for (const plugin of this.plugins) {
      const hook = plugin[hookName]
      if (hook) {
        const result = await hook(arg, context)
        if (result !== null && result !== undefined) {
          return result as T
        }
      }
    }
    return null
  }

  /**
   * Execute a "Sequential" hook — each plugin transforms the previous result.
   */
  async executeSequential<T>(
    hookName: 'transformParams' | 'transformResult',
    initialValue: T,
    context: AiRequestContext
  ): Promise<T> {
    let result = initialValue

    for (const plugin of this.plugins) {
      const hook = plugin[hookName]
      if (hook) {
        result = await hook<T>(result, context)
      }
    }

    return result
  }

  /**
   * Execute the configureContext hook — serially, in plugin order.
   */
  async executeConfigureContext(context: AiRequestContext): Promise<void> {
    for (const plugin of this.plugins) {
      const hook = plugin.configureContext
      if (hook) {
        await hook(context)
      }
    }
  }

  /**
   * Execute a "Parallel" hook — order-independent side effects.
   */
  async executeParallel(
    hookName: 'onRequestStart' | 'onRequestEnd' | 'onError',
    context: AiRequestContext,
    result?: any,
    error?: Error
  ): Promise<void> {
    const promises = this.plugins
      .map((plugin) => {
        const hook = plugin[hookName]
        if (!hook) return null

        // Each hook kind has a different argument shape.
        if (hookName === 'onError' && error) {
          return (hook as any)(error, context)
        } else if (hookName === 'onRequestEnd' && result !== undefined) {
          return (hook as any)(context, result)
        } else if (hookName === 'onRequestStart') {
          return (hook as any)(context)
        }
        return null
      })
      .filter(Boolean)

    // Promise.all (not allSettled) on purpose, so plugin errors propagate.
    await Promise.all(promises)
  }

  /**
   * Collect all stream transforms (returned as an array; AI SDK supports
   * arrays of transforms natively).
   */
  collectStreamTransforms(params: any, context: AiRequestContext) {
    return this.plugins
      .filter((plugin) => plugin.transformStream)
      .map((plugin) => plugin.transformStream?.(params, context))
  }

  /**
   * Return a copy of the plugin list.
   */
  getPlugins(): AiPlugin[] {
    return [...this.plugins]
  }

  /**
   * Return plugin statistics: counts per enforce group and per hook.
   */
  getStats() {
    const stats = {
      total: this.plugins.length,
      pre: 0,
      normal: 0,
      post: 0,
      hooks: {
        resolveModel: 0,
        loadTemplate: 0,
        transformParams: 0,
        transformResult: 0,
        onRequestStart: 0,
        onRequestEnd: 0,
        onError: 0,
        transformStream: 0
      }
    }

    this.plugins.forEach((plugin) => {
      // Count by enforce group.
      if (plugin.enforce === 'pre') stats.pre++
      else if (plugin.enforce === 'post') stats.post++
      else stats.normal++

      // Count implemented hooks.
      Object.keys(stats.hooks).forEach((hookName) => {
        if (plugin[hookName as keyof AiPlugin]) {
          stats.hooks[hookName as keyof typeof stats.hooks]++
        }
      })
    })

    return stats
  }
}
|
||||
79
packages/aiCore/src/core/plugins/types.ts
Normal file
79
packages/aiCore/src/core/plugins/types.ts
Normal file
@@ -0,0 +1,79 @@
|
||||
import type { ImageModelV2 } from '@ai-sdk/provider'
|
||||
import type { LanguageModel, TextStreamPart, ToolSet } from 'ai'
|
||||
|
||||
import { type ProviderId } from '../providers/types'
|
||||
|
||||
/**
 * Recursive-call function type.
 * `any` is used because a recursive call may carry completely different
 * parameter and return types.
 */
export type RecursiveCallFn = (newParams: any) => Promise<any>

/**
 * AI request context, shared across all plugin hooks for one request.
 */
export interface AiRequestContext {
  providerId: ProviderId
  modelId: string
  originalParams: any
  // Free-form scratch space for plugins.
  metadata: Record<string, any>
  startTime: number
  requestId: string
  recursiveCall: RecursiveCallFn
  isRecursiveCall?: boolean
  // Tool set stashed by the prompt-tool-use plugin.
  mcpTools?: ToolSet
  [key: string]: any
}
|
||||
|
||||
/**
 * Plugin hook families.
 */
export interface AiPlugin {
  name: string
  enforce?: 'pre' | 'post'

  // [First] — only the first plugin returning a value wins.
  resolveModel?: (
    modelId: string,
    context: AiRequestContext
  ) => Promise<LanguageModel | ImageModelV2 | null> | LanguageModel | ImageModelV2 | null
  loadTemplate?: (templateName: string, context: AiRequestContext) => any | null | Promise<any | null>

  // [Sequential] — executed in order, chaining transformed data.
  configureContext?: (context: AiRequestContext) => void | Promise<void>
  transformParams?: <T>(params: T, context: AiRequestContext) => T | Promise<T>
  transformResult?: <T>(result: T, context: AiRequestContext) => T | Promise<T>

  // [Parallel] — order-independent, for side effects.
  onRequestStart?: (context: AiRequestContext) => void | Promise<void>
  onRequestEnd?: (context: AiRequestContext, result: any) => void | Promise<void>
  onError?: (error: Error, context: AiRequestContext) => void | Promise<void>

  // [Stream] — stream processing, using the AI SDK's transform shape directly.
  transformStream?: (
    params: any,
    context: AiRequestContext
  ) => <TOOLS extends ToolSet>(options?: {
    tools: TOOLS
    stopStream: () => void
  }) => TransformStream<TextStreamPart<TOOLS>, TextStreamPart<TOOLS>>

  // AI SDK native middleware
  // aiSdkMiddlewares?: LanguageModelV1Middleware[]
}
|
||||
|
||||
/**
 * Plugin manager configuration.
 */
export interface PluginManagerConfig {
  plugins: AiPlugin[]
  context: Partial<AiRequestContext>
}

/**
 * Result of a single hook execution.
 */
export interface HookResult<T = any> {
  value: T
  // When true, stop executing the remaining hooks.
  stop?: boolean
}
|
||||
101
packages/aiCore/src/core/providers/HubProvider.ts
Normal file
101
packages/aiCore/src/core/providers/HubProvider.ts
Normal file
@@ -0,0 +1,101 @@
|
||||
/**
|
||||
* Hub Provider - 支持路由到多个底层provider
|
||||
*
|
||||
* 支持格式: hubId:providerId:modelId
|
||||
* 例如: aihubmix:anthropic:claude-3.5-sonnet
|
||||
*/
|
||||
|
||||
import { ProviderV2 } from '@ai-sdk/provider'
|
||||
import { customProvider } from 'ai'
|
||||
|
||||
import { globalRegistryManagement } from './RegistryManagement'
|
||||
import type { AiSdkMethodName, AiSdkModelReturn, AiSdkModelType } from './types'
|
||||
|
||||
export interface HubProviderConfig {
  /** Unique identifier of the hub. */
  hubId: string
  /** Enable debug logging. */
  // NOTE(review): `debug` is not read anywhere in this file — confirm it is used.
  debug?: boolean
}
|
||||
|
||||
export class HubProviderError extends Error {
|
||||
constructor(
|
||||
message: string,
|
||||
public readonly hubId: string,
|
||||
public readonly providerId?: string,
|
||||
public readonly originalError?: Error
|
||||
) {
|
||||
super(message)
|
||||
this.name = 'HubProviderError'
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 解析Hub模型ID
|
||||
*/
|
||||
function parseHubModelId(modelId: string): { provider: string; actualModelId: string } {
|
||||
const parts = modelId.split(':')
|
||||
if (parts.length !== 2) {
|
||||
throw new HubProviderError(`Invalid hub model ID format. Expected "provider:modelId", got: ${modelId}`, 'unknown')
|
||||
}
|
||||
return {
|
||||
provider: parts[0],
|
||||
actualModelId: parts[1]
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 创建Hub Provider
|
||||
*/
|
||||
export function createHubProvider(config: HubProviderConfig): ProviderV2 {
|
||||
const { hubId } = config
|
||||
|
||||
function getTargetProvider(providerId: string): ProviderV2 {
|
||||
// 从全局注册表获取provider实例
|
||||
try {
|
||||
const provider = globalRegistryManagement.getProvider(providerId)
|
||||
if (!provider) {
|
||||
throw new HubProviderError(
|
||||
`Provider "${providerId}" is not initialized. Please call initializeProvider("${providerId}", options) first.`,
|
||||
hubId,
|
||||
providerId
|
||||
)
|
||||
}
|
||||
return provider
|
||||
} catch (error) {
|
||||
throw new HubProviderError(
|
||||
`Failed to get provider "${providerId}": ${error instanceof Error ? error.message : 'Unknown error'}`,
|
||||
hubId,
|
||||
providerId,
|
||||
error instanceof Error ? error : undefined
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
function resolveModel<T extends AiSdkModelType>(
|
||||
modelId: string,
|
||||
modelType: T,
|
||||
methodName: AiSdkMethodName<T>
|
||||
): AiSdkModelReturn<T> {
|
||||
const { provider, actualModelId } = parseHubModelId(modelId)
|
||||
const targetProvider = getTargetProvider(provider)
|
||||
|
||||
const fn = targetProvider[methodName] as (id: string) => AiSdkModelReturn<T>
|
||||
|
||||
if (!fn) {
|
||||
throw new HubProviderError(`Provider "${provider}" does not support ${modelType}`, hubId, provider)
|
||||
}
|
||||
|
||||
return fn(actualModelId)
|
||||
}
|
||||
|
||||
return customProvider({
|
||||
fallbackProvider: {
|
||||
languageModel: (modelId: string) => resolveModel(modelId, 'text', 'languageModel'),
|
||||
textEmbeddingModel: (modelId: string) => resolveModel(modelId, 'embedding', 'textEmbeddingModel'),
|
||||
imageModel: (modelId: string) => resolveModel(modelId, 'image', 'imageModel'),
|
||||
transcriptionModel: (modelId: string) => resolveModel(modelId, 'transcription', 'transcriptionModel'),
|
||||
speechModel: (modelId: string) => resolveModel(modelId, 'speech', 'speechModel')
|
||||
}
|
||||
})
|
||||
}
|
||||
221
packages/aiCore/src/core/providers/RegistryManagement.ts
Normal file
221
packages/aiCore/src/core/providers/RegistryManagement.ts
Normal file
@@ -0,0 +1,221 @@
|
||||
/**
|
||||
* Provider 注册表管理器
|
||||
* 纯粹的管理功能:存储、检索已配置好的 provider 实例
|
||||
* 基于 AI SDK 原生的 createProviderRegistry
|
||||
*/
|
||||
|
||||
import { EmbeddingModelV2, ImageModelV2, LanguageModelV2, ProviderV2 } from '@ai-sdk/provider'
|
||||
import { createProviderRegistry, type ProviderRegistryProvider } from 'ai'
|
||||
|
||||
// Map of provider id (or alias) -> configured provider instance.
type PROVIDERS = Record<string, ProviderV2>

// Default id/model separator; see the note on globalRegistryManagement below.
export const DEFAULT_SEPARATOR = '|'

// export type MODEL_ID = `${string}${typeof DEFAULT_SEPARATOR}${string}`
|
||||
/**
 * Provider registry manager.
 *
 * Pure management: stores and retrieves already-configured provider
 * instances, delegating model lookup to the AI SDK's native
 * createProviderRegistry. Supports alias keys that point at the same
 * provider instance as their real id.
 */
export class RegistryManagement<SEPARATOR extends string = typeof DEFAULT_SEPARATOR> {
  private providers: PROVIDERS = {}
  private aliases: Set<string> = new Set() // records which keys are aliases
  private separator: SEPARATOR
  // Rebuilt after every mutation; null while no providers are registered.
  private registry: ProviderRegistryProvider<PROVIDERS, SEPARATOR> | null = null

  constructor(options: { separator: SEPARATOR } = { separator: DEFAULT_SEPARATOR as SEPARATOR }) {
    this.separator = options.separator
  }

  /**
   * Register an already-configured provider instance, optionally under aliases.
   */
  registerProvider(id: string, provider: ProviderV2, aliases?: string[]): this {
    // Register the primary provider.
    this.providers[id] = provider

    // Register aliases (all pointing at the same provider instance).
    if (aliases) {
      aliases.forEach((alias) => {
        this.providers[alias] = provider // store the same reference
        this.aliases.add(alias) // mark as alias
      })
    }

    this.rebuildRegistry()
    return this
  }

  /**
   * Fetch a registered provider instance (by id or alias).
   */
  getProvider(id: string): ProviderV2 | undefined {
    return this.providers[id]
  }

  /**
   * Register several providers at once.
   * NOTE(review): this bypasses the alias bookkeeping — keys added here are
   * always treated as real ids.
   */
  registerProviders(providers: Record<string, ProviderV2>): this {
    Object.assign(this.providers, providers)
    this.rebuildRegistry()
    return this
  }

  /**
   * Remove a provider (cleaning up any aliases that point to it).
   */
  unregisterProvider(id: string): this {
    const provider = this.providers[id]
    if (!provider) return this

    // Removing a real id: also drop every alias pointing at it.
    if (!this.aliases.has(id)) {
      // Collect all aliases that reference this provider instance.
      const aliasesToRemove: string[] = []
      this.aliases.forEach((alias) => {
        if (this.providers[alias] === provider) {
          aliasesToRemove.push(alias)
        }
      })

      aliasesToRemove.forEach((alias) => {
        delete this.providers[alias]
        this.aliases.delete(alias)
      })
    } else {
      // Removing an alias: only drop the alias record.
      this.aliases.delete(id)
    }

    delete this.providers[id]
    this.rebuildRegistry()
    return this
  }

  /**
   * Rebuild the registry eagerly — on every mutation.
   */
  private rebuildRegistry(): void {
    if (Object.keys(this.providers).length === 0) {
      this.registry = null
      return
    }

    this.registry = createProviderRegistry<PROVIDERS, SEPARATOR>(this.providers, {
      separator: this.separator
    })
  }

  /**
   * Get a language model — AI SDK native lookup.
   */
  languageModel(id: `${string}${SEPARATOR}${string}`): LanguageModelV2 {
    if (!this.registry) {
      throw new Error('No providers registered')
    }
    return this.registry.languageModel(id)
  }

  /**
   * Get a text-embedding model — AI SDK native lookup.
   */
  textEmbeddingModel(id: `${string}${SEPARATOR}${string}`): EmbeddingModelV2<string> {
    if (!this.registry) {
      throw new Error('No providers registered')
    }
    return this.registry.textEmbeddingModel(id)
  }

  /**
   * Get an image model — AI SDK native lookup.
   */
  imageModel(id: `${string}${SEPARATOR}${string}`): ImageModelV2 {
    if (!this.registry) {
      throw new Error('No providers registered')
    }
    return this.registry.imageModel(id)
  }

  /**
   * Get a transcription model — AI SDK native lookup.
   */
  transcriptionModel(id: `${string}${SEPARATOR}${string}`): any {
    if (!this.registry) {
      throw new Error('No providers registered')
    }
    return this.registry.transcriptionModel(id)
  }

  /**
   * Get a speech model — AI SDK native lookup.
   */
  speechModel(id: `${string}${SEPARATOR}${string}`): any {
    if (!this.registry) {
      throw new Error('No providers registered')
    }
    return this.registry.speechModel(id)
  }

  /**
   * List registered provider keys (includes alias keys).
   */
  getRegisteredProviders(): string[] {
    return Object.keys(this.providers)
  }

  /**
   * Whether any providers are registered.
   */
  hasProviders(): boolean {
    return Object.keys(this.providers).length > 0
  }

  /**
   * Remove all providers and aliases.
   */
  clear(): this {
    this.providers = {}
    this.aliases.clear()
    this.registry = null
    return this
  }

  /**
   * Resolve the real provider id (used by getAiSdkProviderId).
   * Given an alias, returns the real provider id; a real id is returned as-is.
   */
  resolveProviderId(id: string): string {
    if (!this.aliases.has(id)) return id // not an alias — return directly

    // It is an alias: find the real id sharing the same provider instance.
    const targetProvider = this.providers[id]
    for (const [realId, provider] of Object.entries(this.providers)) {
      if (provider === targetProvider && !this.aliases.has(realId)) {
        return realId
      }
    }
    return id
  }

  /**
   * Whether the given key is an alias.
   */
  isAlias(id: string): boolean {
    return this.aliases.has(id)
  }

  /**
   * Return the full alias -> real-id mapping.
   */
  getAllAliases(): Record<string, string> {
    const result: Record<string, string> = {}
    this.aliases.forEach((alias) => {
      result[alias] = this.resolveProviderId(alias)
    })
    return result
  }
}
|
||||
|
||||
/**
 * Global registry manager instance.
 * Uses '|' as the separator because ':' would clash with model-id suffixes
 * such as ':free'.
 */
export const globalRegistryManagement = new RegistryManagement()
|
||||
@@ -0,0 +1,632 @@
|
||||
/**
 * Tests for the real AiProviderRegistry functionality.
 */

import { beforeEach, describe, expect, it, vi } from 'vitest'

// Mock every AI SDK provider package so no real SDK code (or network) runs.
vi.mock('@ai-sdk/openai', () => ({
  createOpenAI: vi.fn(() => ({ name: 'openai-mock' }))
}))

vi.mock('@ai-sdk/anthropic', () => ({
  createAnthropic: vi.fn(() => ({ name: 'anthropic-mock' }))
}))

vi.mock('@ai-sdk/azure', () => ({
  createAzure: vi.fn(() => ({ name: 'azure-mock' }))
}))

vi.mock('@ai-sdk/deepseek', () => ({
  createDeepSeek: vi.fn(() => ({ name: 'deepseek-mock' }))
}))

vi.mock('@ai-sdk/google', () => ({
  createGoogleGenerativeAI: vi.fn(() => ({ name: 'google-mock' }))
}))

vi.mock('@ai-sdk/openai-compatible', () => ({
  createOpenAICompatible: vi.fn(() => ({ name: 'openai-compatible-mock' }))
}))

vi.mock('@ai-sdk/xai', () => ({
  createXai: vi.fn(() => ({ name: 'xai-mock' }))
}))
|
||||
|
||||
import {
|
||||
cleanup,
|
||||
clearAllProviders,
|
||||
createAndRegisterProvider,
|
||||
createProvider,
|
||||
getAllProviderConfigAliases,
|
||||
getAllProviderConfigs,
|
||||
getInitializedProviders,
|
||||
getLanguageModel,
|
||||
getProviderConfig,
|
||||
getProviderConfigByAlias,
|
||||
getSupportedProviders,
|
||||
hasInitializedProviders,
|
||||
hasProviderConfig,
|
||||
hasProviderConfigByAlias,
|
||||
isProviderConfigAlias,
|
||||
ProviderInitializationError,
|
||||
providerRegistry,
|
||||
registerMultipleProviderConfigs,
|
||||
registerProvider,
|
||||
registerProviderConfig,
|
||||
resolveProviderConfigId
|
||||
} from '../registry'
|
||||
import type { ProviderConfig } from '../schemas'
|
||||
|
||||
describe('Provider Registry 功能测试', () => {
|
||||
beforeEach(() => {
|
||||
// 清理状态
|
||||
cleanup()
|
||||
})
|
||||
|
||||
describe('基础功能', () => {
|
||||
it('能够获取支持的 providers 列表', () => {
|
||||
const providers = getSupportedProviders()
|
||||
expect(Array.isArray(providers)).toBe(true)
|
||||
expect(providers.length).toBeGreaterThan(0)
|
||||
|
||||
// 检查返回的数据结构
|
||||
providers.forEach((provider) => {
|
||||
expect(provider).toHaveProperty('id')
|
||||
expect(provider).toHaveProperty('name')
|
||||
expect(typeof provider.id).toBe('string')
|
||||
expect(typeof provider.name).toBe('string')
|
||||
})
|
||||
|
||||
// 包含基础 providers
|
||||
const providerIds = providers.map((p) => p.id)
|
||||
expect(providerIds).toContain('openai')
|
||||
expect(providerIds).toContain('anthropic')
|
||||
expect(providerIds).toContain('google')
|
||||
})
|
||||
|
||||
it('能够获取已初始化的 providers', () => {
|
||||
// 初始状态下没有已初始化的 providers
|
||||
expect(getInitializedProviders()).toEqual([])
|
||||
expect(hasInitializedProviders()).toBe(false)
|
||||
})
|
||||
|
||||
it('能够访问全局注册管理器', () => {
|
||||
expect(providerRegistry).toBeDefined()
|
||||
expect(typeof providerRegistry.clear).toBe('function')
|
||||
expect(typeof providerRegistry.getRegisteredProviders).toBe('function')
|
||||
expect(typeof providerRegistry.hasProviders).toBe('function')
|
||||
})
|
||||
|
||||
it('能够获取语言模型', () => {
|
||||
// 在没有注册 provider 的情况下,这个函数应该会抛出错误
|
||||
expect(() => getLanguageModel('non-existent')).toThrow('No providers registered')
|
||||
})
|
||||
})
|
||||
|
||||
describe('Provider 配置注册', () => {
|
||||
it('能够注册自定义 provider 配置', () => {
|
||||
const config: ProviderConfig = {
|
||||
id: 'custom-provider',
|
||||
name: 'Custom Provider',
|
||||
creator: vi.fn(() => ({ name: 'custom' })),
|
||||
supportsImageGeneration: false
|
||||
}
|
||||
|
||||
const success = registerProviderConfig(config)
|
||||
expect(success).toBe(true)
|
||||
|
||||
expect(hasProviderConfig('custom-provider')).toBe(true)
|
||||
expect(getProviderConfig('custom-provider')).toEqual(config)
|
||||
})
|
||||
|
||||
it('能够注册带别名的 provider 配置', () => {
|
||||
const config: ProviderConfig = {
|
||||
id: 'custom-provider-with-aliases',
|
||||
name: 'Custom Provider with Aliases',
|
||||
creator: vi.fn(() => ({ name: 'custom-aliased' })),
|
||||
supportsImageGeneration: false,
|
||||
aliases: ['alias-1', 'alias-2']
|
||||
}
|
||||
|
||||
const success = registerProviderConfig(config)
|
||||
expect(success).toBe(true)
|
||||
|
||||
expect(hasProviderConfigByAlias('alias-1')).toBe(true)
|
||||
expect(hasProviderConfigByAlias('alias-2')).toBe(true)
|
||||
expect(getProviderConfigByAlias('alias-1')).toEqual(config)
|
||||
expect(resolveProviderConfigId('alias-1')).toBe('custom-provider-with-aliases')
|
||||
})
|
||||
|
||||
it('拒绝无效的配置', () => {
|
||||
// 缺少必要字段
|
||||
const invalidConfig = {
|
||||
id: 'invalid-provider'
|
||||
// 缺少 name, creator 等
|
||||
}
|
||||
|
||||
const success = registerProviderConfig(invalidConfig as any)
|
||||
expect(success).toBe(false)
|
||||
})
|
||||
|
||||
it('能够批量注册 provider 配置', () => {
|
||||
const configs: ProviderConfig[] = [
|
||||
{
|
||||
id: 'provider-1',
|
||||
name: 'Provider 1',
|
||||
creator: vi.fn(() => ({ name: 'provider-1' })),
|
||||
supportsImageGeneration: false
|
||||
},
|
||||
{
|
||||
id: 'provider-2',
|
||||
name: 'Provider 2',
|
||||
creator: vi.fn(() => ({ name: 'provider-2' })),
|
||||
supportsImageGeneration: true
|
||||
},
|
||||
{
|
||||
id: '', // 无效配置
|
||||
name: 'Invalid Provider',
|
||||
creator: vi.fn(() => ({ name: 'invalid' })),
|
||||
supportsImageGeneration: false
|
||||
} as any
|
||||
]
|
||||
|
||||
const successCount = registerMultipleProviderConfigs(configs)
|
||||
expect(successCount).toBe(2) // 只有前两个成功
|
||||
|
||||
expect(hasProviderConfig('provider-1')).toBe(true)
|
||||
expect(hasProviderConfig('provider-2')).toBe(true)
|
||||
expect(hasProviderConfig('')).toBe(false)
|
||||
})
|
||||
|
||||
it('能够获取所有配置和别名信息', () => {
|
||||
// 注册一些配置
|
||||
registerProviderConfig({
|
||||
id: 'test-provider',
|
||||
name: 'Test Provider',
|
||||
creator: vi.fn(),
|
||||
supportsImageGeneration: false,
|
||||
aliases: ['test-alias']
|
||||
})
|
||||
|
||||
const allConfigs = getAllProviderConfigs()
|
||||
expect(Array.isArray(allConfigs)).toBe(true)
|
||||
expect(allConfigs.some((config) => config.id === 'test-provider')).toBe(true)
|
||||
|
||||
const aliases = getAllProviderConfigAliases()
|
||||
expect(aliases['test-alias']).toBe('test-provider')
|
||||
expect(isProviderConfigAlias('test-alias')).toBe(true)
|
||||
})
|
||||
})
|
||||
|
||||
describe('Provider 创建和注册', () => {
|
||||
it('能够创建 provider 实例', async () => {
|
||||
const config: ProviderConfig = {
|
||||
id: 'test-create-provider',
|
||||
name: 'Test Create Provider',
|
||||
creator: vi.fn(() => ({ name: 'test-created' })),
|
||||
supportsImageGeneration: false
|
||||
}
|
||||
|
||||
// 先注册配置
|
||||
registerProviderConfig(config)
|
||||
|
||||
// 创建 provider 实例
|
||||
const provider = await createProvider('test-create-provider', { apiKey: 'test' })
|
||||
expect(provider).toBeDefined()
|
||||
expect(config.creator).toHaveBeenCalledWith({ apiKey: 'test' })
|
||||
})
|
||||
|
||||
it('能够注册 provider 到全局管理器', () => {
|
||||
const mockProvider = { name: 'mock-provider' }
|
||||
const config: ProviderConfig = {
|
||||
id: 'test-register-provider',
|
||||
name: 'Test Register Provider',
|
||||
creator: vi.fn(() => mockProvider),
|
||||
supportsImageGeneration: false
|
||||
}
|
||||
|
||||
// 先注册配置
|
||||
registerProviderConfig(config)
|
||||
|
||||
// 注册 provider 到全局管理器
|
||||
const success = registerProvider('test-register-provider', mockProvider)
|
||||
expect(success).toBe(true)
|
||||
|
||||
// 验证注册成功
|
||||
const registeredProviders = getInitializedProviders()
|
||||
expect(registeredProviders).toContain('test-register-provider')
|
||||
expect(hasInitializedProviders()).toBe(true)
|
||||
})
|
||||
|
||||
it('能够一步完成创建和注册', async () => {
|
||||
const config: ProviderConfig = {
|
||||
id: 'test-create-and-register',
|
||||
name: 'Test Create and Register',
|
||||
creator: vi.fn(() => ({ name: 'test-both' })),
|
||||
supportsImageGeneration: false
|
||||
}
|
||||
|
||||
// 先注册配置
|
||||
registerProviderConfig(config)
|
||||
|
||||
// 一步完成创建和注册
|
||||
const success = await createAndRegisterProvider('test-create-and-register', { apiKey: 'test' })
|
||||
expect(success).toBe(true)
|
||||
|
||||
// 验证注册成功
|
||||
const registeredProviders = getInitializedProviders()
|
||||
expect(registeredProviders).toContain('test-create-and-register')
|
||||
})
|
||||
})
|
||||
|
||||
describe('Registry 管理', () => {
|
||||
it('能够清理所有配置和注册的 providers', () => {
|
||||
// 注册一些配置
|
||||
registerProviderConfig({
|
||||
id: 'temp-provider',
|
||||
name: 'Temp Provider',
|
||||
creator: vi.fn(() => ({ name: 'temp' })),
|
||||
supportsImageGeneration: false
|
||||
})
|
||||
|
||||
expect(hasProviderConfig('temp-provider')).toBe(true)
|
||||
|
||||
// 清理
|
||||
cleanup()
|
||||
|
||||
expect(hasProviderConfig('temp-provider')).toBe(false)
|
||||
// 但基础配置应该重新加载
|
||||
expect(hasProviderConfig('openai')).toBe(true) // 基础 providers 会重新初始化
|
||||
})
|
||||
|
||||
it('能够单独清理已注册的 providers', () => {
|
||||
// 清理所有 providers
|
||||
clearAllProviders()
|
||||
|
||||
expect(getInitializedProviders()).toEqual([])
|
||||
expect(hasInitializedProviders()).toBe(false)
|
||||
})
|
||||
|
||||
it('ProviderInitializationError 错误类工作正常', () => {
|
||||
const error = new ProviderInitializationError('Test error', 'test-provider')
|
||||
expect(error.message).toBe('Test error')
|
||||
expect(error.providerId).toBe('test-provider')
|
||||
expect(error.name).toBe('ProviderInitializationError')
|
||||
})
|
||||
})
|
||||
|
||||
describe('错误处理', () => {
|
||||
it('优雅处理空配置', () => {
|
||||
const success = registerProviderConfig(null as any)
|
||||
expect(success).toBe(false)
|
||||
})
|
||||
|
||||
it('优雅处理未定义配置', () => {
|
||||
const success = registerProviderConfig(undefined as any)
|
||||
expect(success).toBe(false)
|
||||
})
|
||||
|
||||
it('处理空字符串 ID', () => {
|
||||
const config = {
|
||||
id: '',
|
||||
name: 'Empty ID Provider',
|
||||
creator: vi.fn(() => ({ name: 'empty' })),
|
||||
supportsImageGeneration: false
|
||||
}
|
||||
|
||||
const success = registerProviderConfig(config)
|
||||
expect(success).toBe(false)
|
||||
})
|
||||
|
||||
it('处理创建不存在配置的 provider', async () => {
|
||||
await expect(createProvider('non-existent-provider', {})).rejects.toThrow(
|
||||
'ProviderConfig not found for id: non-existent-provider'
|
||||
)
|
||||
})
|
||||
|
||||
it('处理注册不存在配置的 provider', () => {
|
||||
const mockProvider = { name: 'mock' }
|
||||
const success = registerProvider('non-existent-provider', mockProvider)
|
||||
expect(success).toBe(false)
|
||||
})
|
||||
|
||||
it('处理获取不存在配置的情况', () => {
|
||||
expect(getProviderConfig('non-existent')).toBeUndefined()
|
||||
expect(getProviderConfigByAlias('non-existent-alias')).toBeUndefined()
|
||||
expect(hasProviderConfig('non-existent')).toBe(false)
|
||||
expect(hasProviderConfigByAlias('non-existent-alias')).toBe(false)
|
||||
})
|
||||
|
||||
it('处理批量注册时的部分失败', () => {
|
||||
const mixedConfigs: ProviderConfig[] = [
|
||||
{
|
||||
id: 'valid-provider-1',
|
||||
name: 'Valid Provider 1',
|
||||
creator: vi.fn(() => ({ name: 'valid-1' })),
|
||||
supportsImageGeneration: false
|
||||
},
|
||||
{
|
||||
id: '', // 无效配置
|
||||
name: 'Invalid Provider',
|
||||
creator: vi.fn(() => ({ name: 'invalid' })),
|
||||
supportsImageGeneration: false
|
||||
} as any,
|
||||
{
|
||||
id: 'valid-provider-2',
|
||||
name: 'Valid Provider 2',
|
||||
creator: vi.fn(() => ({ name: 'valid-2' })),
|
||||
supportsImageGeneration: true
|
||||
}
|
||||
]
|
||||
|
||||
const successCount = registerMultipleProviderConfigs(mixedConfigs)
|
||||
expect(successCount).toBe(2) // 只有两个有效配置成功
|
||||
|
||||
expect(hasProviderConfig('valid-provider-1')).toBe(true)
|
||||
expect(hasProviderConfig('valid-provider-2')).toBe(true)
|
||||
expect(hasProviderConfig('')).toBe(false)
|
||||
})
|
||||
|
||||
it('处理动态导入失败的情况', async () => {
|
||||
const config: ProviderConfig = {
|
||||
id: 'import-test-provider',
|
||||
name: 'Import Test Provider',
|
||||
import: vi.fn().mockRejectedValue(new Error('Import failed')),
|
||||
creatorFunctionName: 'createTest',
|
||||
supportsImageGeneration: false
|
||||
}
|
||||
|
||||
registerProviderConfig(config)
|
||||
|
||||
await expect(createProvider('import-test-provider', {})).rejects.toThrow('Import failed')
|
||||
})
|
||||
})
|
||||
|
||||
describe('集成测试', () => {
|
||||
it('正确处理复杂的配置、创建、注册和清理场景', async () => {
|
||||
// 初始状态验证
|
||||
const initialConfigs = getAllProviderConfigs()
|
||||
expect(initialConfigs.length).toBeGreaterThan(0) // 有基础配置
|
||||
expect(getInitializedProviders()).toEqual([])
|
||||
|
||||
// 注册多个带别名的 provider 配置
|
||||
const configs: ProviderConfig[] = [
|
||||
{
|
||||
id: 'integration-provider-1',
|
||||
name: 'Integration Provider 1',
|
||||
creator: vi.fn(() => ({ name: 'integration-1' })),
|
||||
supportsImageGeneration: false,
|
||||
aliases: ['alias-1', 'short-name-1']
|
||||
},
|
||||
{
|
||||
id: 'integration-provider-2',
|
||||
name: 'Integration Provider 2',
|
||||
creator: vi.fn(() => ({ name: 'integration-2' })),
|
||||
supportsImageGeneration: true,
|
||||
aliases: ['alias-2', 'short-name-2']
|
||||
}
|
||||
]
|
||||
|
||||
const successCount = registerMultipleProviderConfigs(configs)
|
||||
expect(successCount).toBe(2)
|
||||
|
||||
// 验证配置注册成功
|
||||
expect(hasProviderConfig('integration-provider-1')).toBe(true)
|
||||
expect(hasProviderConfig('integration-provider-2')).toBe(true)
|
||||
expect(hasProviderConfigByAlias('alias-1')).toBe(true)
|
||||
expect(hasProviderConfigByAlias('alias-2')).toBe(true)
|
||||
|
||||
// 验证别名映射
|
||||
const aliases = getAllProviderConfigAliases()
|
||||
expect(aliases['alias-1']).toBe('integration-provider-1')
|
||||
expect(aliases['alias-2']).toBe('integration-provider-2')
|
||||
|
||||
// 创建和注册 providers
|
||||
const success1 = await createAndRegisterProvider('integration-provider-1', { apiKey: 'test1' })
|
||||
const success2 = await createAndRegisterProvider('integration-provider-2', { apiKey: 'test2' })
|
||||
expect(success1).toBe(true)
|
||||
expect(success2).toBe(true)
|
||||
|
||||
// 验证注册成功
|
||||
const registeredProviders = getInitializedProviders()
|
||||
expect(registeredProviders).toContain('integration-provider-1')
|
||||
expect(registeredProviders).toContain('integration-provider-2')
|
||||
expect(hasInitializedProviders()).toBe(true)
|
||||
|
||||
// 清理
|
||||
cleanup()
|
||||
|
||||
// 验证清理后的状态
|
||||
expect(getInitializedProviders()).toEqual([])
|
||||
expect(hasProviderConfig('integration-provider-1')).toBe(false)
|
||||
expect(hasProviderConfig('integration-provider-2')).toBe(false)
|
||||
expect(getAllProviderConfigAliases()).toEqual({})
|
||||
|
||||
// 基础配置应该重新加载
|
||||
expect(hasProviderConfig('openai')).toBe(true)
|
||||
})
|
||||
|
||||
it('正确处理动态导入配置的 provider', async () => {
|
||||
const mockModule = { createCustomProvider: vi.fn(() => ({ name: 'custom-dynamic' })) }
|
||||
const dynamicImportConfig: ProviderConfig = {
|
||||
id: 'dynamic-import-provider',
|
||||
name: 'Dynamic Import Provider',
|
||||
import: vi.fn().mockResolvedValue(mockModule),
|
||||
creatorFunctionName: 'createCustomProvider',
|
||||
supportsImageGeneration: false
|
||||
}
|
||||
|
||||
// 注册配置
|
||||
const configSuccess = registerProviderConfig(dynamicImportConfig)
|
||||
expect(configSuccess).toBe(true)
|
||||
|
||||
// 创建和注册 provider
|
||||
const registerSuccess = await createAndRegisterProvider('dynamic-import-provider', { apiKey: 'test' })
|
||||
expect(registerSuccess).toBe(true)
|
||||
|
||||
// 验证导入函数被调用
|
||||
expect(dynamicImportConfig.import).toHaveBeenCalled()
|
||||
expect(mockModule.createCustomProvider).toHaveBeenCalledWith({ apiKey: 'test' })
|
||||
|
||||
// 验证注册成功
|
||||
expect(getInitializedProviders()).toContain('dynamic-import-provider')
|
||||
})
|
||||
|
||||
it('正确处理大量配置的注册和管理', () => {
|
||||
const largeConfigList: ProviderConfig[] = []
|
||||
|
||||
// 生成50个配置
|
||||
for (let i = 0; i < 50; i++) {
|
||||
largeConfigList.push({
|
||||
id: `bulk-provider-${i}`,
|
||||
name: `Bulk Provider ${i}`,
|
||||
creator: vi.fn(() => ({ name: `bulk-${i}` })),
|
||||
supportsImageGeneration: i % 2 === 0, // 偶数支持图像生成
|
||||
aliases: [`alias-${i}`, `short-${i}`]
|
||||
})
|
||||
}
|
||||
|
||||
const successCount = registerMultipleProviderConfigs(largeConfigList)
|
||||
expect(successCount).toBe(50)
|
||||
|
||||
// 验证所有配置都被正确注册
|
||||
const allConfigs = getAllProviderConfigs()
|
||||
expect(allConfigs.filter((config) => config.id.startsWith('bulk-provider-')).length).toBe(50)
|
||||
|
||||
// 验证别名数量
|
||||
const aliases = getAllProviderConfigAliases()
|
||||
const bulkAliases = Object.keys(aliases).filter(
|
||||
(alias) => alias.startsWith('alias-') || alias.startsWith('short-')
|
||||
)
|
||||
expect(bulkAliases.length).toBe(100) // 每个 provider 有2个别名
|
||||
|
||||
// 随机验证几个配置
|
||||
expect(hasProviderConfig('bulk-provider-0')).toBe(true)
|
||||
expect(hasProviderConfig('bulk-provider-25')).toBe(true)
|
||||
expect(hasProviderConfig('bulk-provider-49')).toBe(true)
|
||||
|
||||
// 验证别名工作正常
|
||||
expect(resolveProviderConfigId('alias-25')).toBe('bulk-provider-25')
|
||||
expect(isProviderConfigAlias('short-30')).toBe(true)
|
||||
|
||||
// 清理能正确处理大量数据
|
||||
cleanup()
|
||||
const cleanupAliases = getAllProviderConfigAliases()
|
||||
expect(
|
||||
Object.keys(cleanupAliases).filter((alias) => alias.startsWith('alias-') || alias.startsWith('short-'))
|
||||
).toEqual([])
|
||||
})
|
||||
})
|
||||
|
||||
describe('边界测试', () => {
|
||||
it('处理包含特殊字符的 provider IDs', () => {
|
||||
const specialCharsConfigs: ProviderConfig[] = [
|
||||
{
|
||||
id: 'provider-with-dashes',
|
||||
name: 'Provider With Dashes',
|
||||
creator: vi.fn(() => ({ name: 'dashes' })),
|
||||
supportsImageGeneration: false
|
||||
},
|
||||
{
|
||||
id: 'provider_with_underscores',
|
||||
name: 'Provider With Underscores',
|
||||
creator: vi.fn(() => ({ name: 'underscores' })),
|
||||
supportsImageGeneration: false
|
||||
},
|
||||
{
|
||||
id: 'provider.with.dots',
|
||||
name: 'Provider With Dots',
|
||||
creator: vi.fn(() => ({ name: 'dots' })),
|
||||
supportsImageGeneration: false
|
||||
}
|
||||
]
|
||||
|
||||
const successCount = registerMultipleProviderConfigs(specialCharsConfigs)
|
||||
expect(successCount).toBeGreaterThan(0) // 至少有一些成功
|
||||
|
||||
// 验证支持的特殊字符格式
|
||||
if (hasProviderConfig('provider-with-dashes')) {
|
||||
expect(getProviderConfig('provider-with-dashes')).toBeDefined()
|
||||
}
|
||||
if (hasProviderConfig('provider_with_underscores')) {
|
||||
expect(getProviderConfig('provider_with_underscores')).toBeDefined()
|
||||
}
|
||||
})
|
||||
|
||||
it('处理空的批量注册', () => {
|
||||
const successCount = registerMultipleProviderConfigs([])
|
||||
expect(successCount).toBe(0)
|
||||
|
||||
// 确保没有额外的配置被添加
|
||||
const configsBefore = getAllProviderConfigs().length
|
||||
expect(configsBefore).toBeGreaterThan(0) // 应该有基础配置
|
||||
})
|
||||
|
||||
it('处理重复的配置注册', () => {
|
||||
const config: ProviderConfig = {
|
||||
id: 'duplicate-test-provider',
|
||||
name: 'Duplicate Test Provider',
|
||||
creator: vi.fn(() => ({ name: 'duplicate' })),
|
||||
supportsImageGeneration: false
|
||||
}
|
||||
|
||||
// 第一次注册成功
|
||||
expect(registerProviderConfig(config)).toBe(true)
|
||||
expect(hasProviderConfig('duplicate-test-provider')).toBe(true)
|
||||
|
||||
// 重复注册相同的配置(允许覆盖)
|
||||
const updatedConfig: ProviderConfig = {
|
||||
...config,
|
||||
name: 'Updated Duplicate Test Provider'
|
||||
}
|
||||
expect(registerProviderConfig(updatedConfig)).toBe(true)
|
||||
expect(hasProviderConfig('duplicate-test-provider')).toBe(true)
|
||||
|
||||
// 验证配置被更新
|
||||
const retrievedConfig = getProviderConfig('duplicate-test-provider')
|
||||
expect(retrievedConfig?.name).toBe('Updated Duplicate Test Provider')
|
||||
})
|
||||
|
||||
it('处理极长的 ID 和名称', () => {
|
||||
const longId = 'very-long-provider-id-' + 'x'.repeat(100)
|
||||
const longName = 'Very Long Provider Name ' + 'Y'.repeat(100)
|
||||
|
||||
const config: ProviderConfig = {
|
||||
id: longId,
|
||||
name: longName,
|
||||
creator: vi.fn(() => ({ name: 'long-test' })),
|
||||
supportsImageGeneration: false
|
||||
}
|
||||
|
||||
const success = registerProviderConfig(config)
|
||||
expect(success).toBe(true)
|
||||
expect(hasProviderConfig(longId)).toBe(true)
|
||||
|
||||
const retrievedConfig = getProviderConfig(longId)
|
||||
expect(retrievedConfig?.name).toBe(longName)
|
||||
})
|
||||
|
||||
it('处理大量别名的配置', () => {
|
||||
const manyAliases = Array.from({ length: 50 }, (_, i) => `alias-${i}`)
|
||||
|
||||
const config: ProviderConfig = {
|
||||
id: 'provider-with-many-aliases',
|
||||
name: 'Provider With Many Aliases',
|
||||
creator: vi.fn(() => ({ name: 'many-aliases' })),
|
||||
supportsImageGeneration: false,
|
||||
aliases: manyAliases
|
||||
}
|
||||
|
||||
const success = registerProviderConfig(config)
|
||||
expect(success).toBe(true)
|
||||
|
||||
// 验证所有别名都能正确解析
|
||||
manyAliases.forEach((alias) => {
|
||||
expect(hasProviderConfigByAlias(alias)).toBe(true)
|
||||
expect(resolveProviderConfigId(alias)).toBe('provider-with-many-aliases')
|
||||
expect(isProviderConfigAlias(alias)).toBe(true)
|
||||
})
|
||||
})
|
||||
})
|
||||
})
|
||||
264
packages/aiCore/src/core/providers/__tests__/schemas.test.ts
Normal file
264
packages/aiCore/src/core/providers/__tests__/schemas.test.ts
Normal file
@@ -0,0 +1,264 @@
|
||||
import { describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import {
|
||||
type BaseProviderId,
|
||||
baseProviderIds,
|
||||
baseProviderIdSchema,
|
||||
baseProviders,
|
||||
type CustomProviderId,
|
||||
customProviderIdSchema,
|
||||
providerConfigSchema,
|
||||
type ProviderId,
|
||||
providerIdSchema
|
||||
} from '../schemas'
|
||||
|
||||
describe('Provider Schemas', () => {
|
||||
describe('baseProviders', () => {
|
||||
it('包含所有预期的基础 providers', () => {
|
||||
expect(baseProviders).toBeDefined()
|
||||
expect(Array.isArray(baseProviders)).toBe(true)
|
||||
expect(baseProviders.length).toBeGreaterThan(0)
|
||||
|
||||
const expectedIds = [
|
||||
'openai',
|
||||
'openai-responses',
|
||||
'openai-compatible',
|
||||
'anthropic',
|
||||
'google',
|
||||
'xai',
|
||||
'azure',
|
||||
'deepseek'
|
||||
]
|
||||
const actualIds = baseProviders.map((p) => p.id)
|
||||
expectedIds.forEach((id) => {
|
||||
expect(actualIds).toContain(id)
|
||||
})
|
||||
})
|
||||
|
||||
it('每个基础 provider 有必要的属性', () => {
|
||||
baseProviders.forEach((provider) => {
|
||||
expect(provider).toHaveProperty('id')
|
||||
expect(provider).toHaveProperty('name')
|
||||
expect(provider).toHaveProperty('creator')
|
||||
expect(provider).toHaveProperty('supportsImageGeneration')
|
||||
|
||||
expect(typeof provider.id).toBe('string')
|
||||
expect(typeof provider.name).toBe('string')
|
||||
expect(typeof provider.creator).toBe('function')
|
||||
expect(typeof provider.supportsImageGeneration).toBe('boolean')
|
||||
})
|
||||
})
|
||||
|
||||
it('provider ID 是唯一的', () => {
|
||||
const ids = baseProviders.map((p) => p.id)
|
||||
const uniqueIds = [...new Set(ids)]
|
||||
expect(ids).toEqual(uniqueIds)
|
||||
})
|
||||
})
|
||||
|
||||
describe('baseProviderIds', () => {
|
||||
it('正确提取所有基础 provider IDs', () => {
|
||||
expect(baseProviderIds).toBeDefined()
|
||||
expect(Array.isArray(baseProviderIds)).toBe(true)
|
||||
expect(baseProviderIds.length).toBe(baseProviders.length)
|
||||
|
||||
baseProviders.forEach((provider) => {
|
||||
expect(baseProviderIds).toContain(provider.id)
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe('baseProviderIdSchema', () => {
|
||||
it('验证有效的基础 provider IDs', () => {
|
||||
baseProviderIds.forEach((id) => {
|
||||
expect(baseProviderIdSchema.safeParse(id).success).toBe(true)
|
||||
})
|
||||
})
|
||||
|
||||
it('拒绝无效的基础 provider IDs', () => {
|
||||
const invalidIds = ['invalid', 'not-exists', '']
|
||||
invalidIds.forEach((id) => {
|
||||
expect(baseProviderIdSchema.safeParse(id).success).toBe(false)
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe('customProviderIdSchema', () => {
|
||||
it('接受有效的自定义 provider IDs', () => {
|
||||
const validIds = ['custom-provider', 'my-ai-service', 'company-llm-v2']
|
||||
validIds.forEach((id) => {
|
||||
expect(customProviderIdSchema.safeParse(id).success).toBe(true)
|
||||
})
|
||||
})
|
||||
|
||||
it('拒绝与基础 provider IDs 冲突的 IDs', () => {
|
||||
baseProviderIds.forEach((id) => {
|
||||
expect(customProviderIdSchema.safeParse(id).success).toBe(false)
|
||||
})
|
||||
})
|
||||
|
||||
it('拒绝空字符串', () => {
|
||||
expect(customProviderIdSchema.safeParse('').success).toBe(false)
|
||||
})
|
||||
})
|
||||
|
||||
describe('providerIdSchema', () => {
|
||||
it('接受基础 provider IDs', () => {
|
||||
baseProviderIds.forEach((id) => {
|
||||
expect(providerIdSchema.safeParse(id).success).toBe(true)
|
||||
})
|
||||
})
|
||||
|
||||
it('接受有效的自定义 provider IDs', () => {
|
||||
const validCustomIds = ['custom-provider', 'my-ai-service']
|
||||
validCustomIds.forEach((id) => {
|
||||
expect(providerIdSchema.safeParse(id).success).toBe(true)
|
||||
})
|
||||
})
|
||||
|
||||
it('拒绝无效的 IDs', () => {
|
||||
const invalidIds = ['', undefined, null, 123]
|
||||
invalidIds.forEach((id) => {
|
||||
expect(providerIdSchema.safeParse(id).success).toBe(false)
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe('providerConfigSchema', () => {
|
||||
it('验证带有 creator 的有效配置', () => {
|
||||
const validConfig = {
|
||||
id: 'custom-provider',
|
||||
name: 'Custom Provider',
|
||||
creator: vi.fn(),
|
||||
supportsImageGeneration: true
|
||||
}
|
||||
expect(providerConfigSchema.safeParse(validConfig).success).toBe(true)
|
||||
})
|
||||
|
||||
it('验证带有 import 配置的有效配置', () => {
|
||||
const validConfig = {
|
||||
id: 'custom-provider',
|
||||
name: 'Custom Provider',
|
||||
import: vi.fn(),
|
||||
creatorFunctionName: 'createCustom',
|
||||
supportsImageGeneration: false
|
||||
}
|
||||
expect(providerConfigSchema.safeParse(validConfig).success).toBe(true)
|
||||
})
|
||||
|
||||
it('拒绝既没有 creator 也没有 import 配置的配置', () => {
|
||||
const invalidConfig = {
|
||||
id: 'invalid',
|
||||
name: 'Invalid Provider',
|
||||
supportsImageGeneration: false
|
||||
}
|
||||
expect(providerConfigSchema.safeParse(invalidConfig).success).toBe(false)
|
||||
})
|
||||
|
||||
it('为 supportsImageGeneration 设置默认值', () => {
|
||||
const config = {
|
||||
id: 'test',
|
||||
name: 'Test',
|
||||
creator: vi.fn()
|
||||
}
|
||||
const result = providerConfigSchema.safeParse(config)
|
||||
expect(result.success).toBe(true)
|
||||
if (result.success) {
|
||||
expect(result.data.supportsImageGeneration).toBe(false)
|
||||
}
|
||||
})
|
||||
|
||||
it('拒绝使用基础 provider ID 的配置', () => {
|
||||
const invalidConfig = {
|
||||
id: 'openai', // 基础 provider ID
|
||||
name: 'Should Fail',
|
||||
creator: vi.fn()
|
||||
}
|
||||
expect(providerConfigSchema.safeParse(invalidConfig).success).toBe(false)
|
||||
})
|
||||
|
||||
it('拒绝缺少必需字段的配置', () => {
|
||||
const invalidConfigs = [
|
||||
{ name: 'Missing ID', creator: vi.fn() },
|
||||
{ id: 'missing-name', creator: vi.fn() },
|
||||
{ id: '', name: 'Empty ID', creator: vi.fn() },
|
||||
{ id: 'valid-custom', name: '', creator: vi.fn() }
|
||||
]
|
||||
|
||||
invalidConfigs.forEach((config) => {
|
||||
expect(providerConfigSchema.safeParse(config).success).toBe(false)
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe('Schema 验证功能', () => {
|
||||
it('baseProviderIdSchema 正确验证基础 provider IDs', () => {
|
||||
baseProviderIds.forEach((id) => {
|
||||
expect(baseProviderIdSchema.safeParse(id).success).toBe(true)
|
||||
})
|
||||
|
||||
expect(baseProviderIdSchema.safeParse('invalid-id').success).toBe(false)
|
||||
})
|
||||
|
||||
it('customProviderIdSchema 正确验证自定义 provider IDs', () => {
|
||||
const customIds = ['custom-provider', 'my-service', 'company-llm']
|
||||
customIds.forEach((id) => {
|
||||
expect(customProviderIdSchema.safeParse(id).success).toBe(true)
|
||||
})
|
||||
|
||||
// 拒绝基础 provider IDs
|
||||
baseProviderIds.forEach((id) => {
|
||||
expect(customProviderIdSchema.safeParse(id).success).toBe(false)
|
||||
})
|
||||
})
|
||||
|
||||
it('providerIdSchema 接受基础和自定义 provider IDs', () => {
|
||||
// 基础 IDs
|
||||
baseProviderIds.forEach((id) => {
|
||||
expect(providerIdSchema.safeParse(id).success).toBe(true)
|
||||
})
|
||||
|
||||
// 自定义 IDs
|
||||
const customIds = ['custom-provider', 'my-service']
|
||||
customIds.forEach((id) => {
|
||||
expect(providerIdSchema.safeParse(id).success).toBe(true)
|
||||
})
|
||||
})
|
||||
|
||||
it('providerConfigSchema 验证完整的 provider 配置', () => {
|
||||
const validConfig = {
|
||||
id: 'custom-provider',
|
||||
name: 'Custom Provider',
|
||||
creator: vi.fn(),
|
||||
supportsImageGeneration: true
|
||||
}
|
||||
expect(providerConfigSchema.safeParse(validConfig).success).toBe(true)
|
||||
|
||||
const invalidConfig = {
|
||||
id: 'openai', // 不允许基础 provider ID
|
||||
name: 'OpenAI',
|
||||
creator: vi.fn()
|
||||
}
|
||||
expect(providerConfigSchema.safeParse(invalidConfig).success).toBe(false)
|
||||
})
|
||||
})
|
||||
|
||||
describe('类型推导', () => {
|
||||
it('BaseProviderId 类型正确', () => {
|
||||
const id: BaseProviderId = 'openai'
|
||||
expect(baseProviderIds).toContain(id)
|
||||
})
|
||||
|
||||
it('CustomProviderId 类型是字符串', () => {
|
||||
const id: CustomProviderId = 'custom-provider'
|
||||
expect(typeof id).toBe('string')
|
||||
})
|
||||
|
||||
it('ProviderId 类型支持基础和自定义 IDs', () => {
|
||||
const baseId: ProviderId = 'openai'
|
||||
const customId: ProviderId = 'custom-provider'
|
||||
expect(typeof baseId).toBe('string')
|
||||
expect(typeof customId).toBe('string')
|
||||
})
|
||||
})
|
||||
})
|
||||
291
packages/aiCore/src/core/providers/factory.ts
Normal file
291
packages/aiCore/src/core/providers/factory.ts
Normal file
@@ -0,0 +1,291 @@
|
||||
/**
|
||||
* AI Provider 配置工厂
|
||||
* 提供类型安全的 Provider 配置构建器
|
||||
*/
|
||||
|
||||
import type { ProviderId, ProviderSettingsMap } from './types'
|
||||
|
||||
/**
 * Base configuration shared by every provider: credentials, endpoint,
 * and request-level transport options.
 */
export interface BaseProviderConfig {
  apiKey?: string
  baseURL?: string
  timeout?: number
  headers?: Record<string, string>
  // Custom fetch implementation, e.g. for proxying or instrumentation.
  fetch?: typeof globalThis.fetch
}
|
||||
|
||||
/**
 * Complete configuration type: the shared base options merged with the
 * provider-specific settings for provider `T`.
 */
type CompleteProviderConfig<T extends ProviderId> = BaseProviderConfig & Partial<ProviderSettingsMap[T]>

// Callback that copies provider-specific fields from a raw config onto a builder.
type ConfigHandler<T extends ProviderId> = (
  builder: ProviderConfigBuilder<T>,
  provider: CompleteProviderConfig<T>
) => void

// Per-provider handlers for settings that need a dedicated builder call.
// Currently only Azure has extra fields (apiVersion / resourceName).
const configHandlers: {
  [K in ProviderId]?: ConfigHandler<K>
} = {
  azure: (builder, provider) => {
    const azureBuilder = builder as ProviderConfigBuilder<'azure'>
    const azureProvider = provider as CompleteProviderConfig<'azure'>
    azureBuilder.withAzureConfig({
      apiVersion: azureProvider.apiVersion,
      resourceName: azureProvider.resourceName
    })
  }
}
|
||||
|
||||
/**
 * Fluent builder for a single provider's settings object.
 *
 * The type parameter T pins the provider id so provider-specific methods
 * (OpenAI key extras, Azure options) are only meaningful for that provider.
 */
export class ProviderConfigBuilder<T extends ProviderId = ProviderId> {
  // Accumulated settings; cast is safe because fields are filled incrementally.
  private config: CompleteProviderConfig<T> = {} as CompleteProviderConfig<T>

  constructor(private providerId: T) {}

  /**
   * Set the API key. For 'openai', an optional second argument may carry
   * organization/project identifiers (enforced via the overloads).
   */
  withApiKey(apiKey: string): this
  withApiKey(apiKey: string, options: T extends 'openai' ? { organization?: string; project?: string } : never): this
  withApiKey(apiKey: string, options?: any): this {
    this.config.apiKey = apiKey

    // OpenAI-specific extras; silently ignored for every other provider.
    if (this.providerId === 'openai' && options) {
      const openaiConfig = this.config as CompleteProviderConfig<'openai'>
      if (options.organization) {
        openaiConfig.organization = options.organization
      }
      if (options.project) {
        openaiConfig.project = options.project
      }
    }

    return this
  }

  /**
   * Set the base URL used for API requests.
   */
  withBaseURL(baseURL: string) {
    this.config.baseURL = baseURL
    return this
  }

  /**
   * Set request-level options: extra headers are merged into any existing
   * ones, the fetch implementation is replaced.
   */
  withRequestConfig(options: { headers?: Record<string, string>; fetch?: typeof fetch }): this {
    if (options.headers) {
      this.config.headers = { ...this.config.headers, ...options.headers }
    }
    if (options.fetch) {
      this.config.fetch = options.fetch
    }
    return this
  }

  /**
   * Azure OpenAI specific options. A runtime no-op for any other provider
   * id (the conditional return type documents the intended restriction).
   */
  withAzureConfig(options: { apiVersion?: string; resourceName?: string }): T extends 'azure' ? this : never
  withAzureConfig(options: any): any {
    if (this.providerId === 'azure') {
      const azureConfig = this.config as CompleteProviderConfig<'azure'>
      if (options.apiVersion) {
        azureConfig.apiVersion = options.apiVersion
      }
      if (options.resourceName) {
        azureConfig.resourceName = options.resourceName
      }
    }
    return this
  }

  /**
   * Merge arbitrary extra parameters into the configuration.
   * NOTE(review): this can overwrite previously-set fields — callers rely
   * on last-write-wins semantics.
   */
  withCustomParams(params: Record<string, any>) {
    Object.assign(this.config, params)
    return this
  }

  /**
   * Return the accumulated settings object.
   */
  build(): ProviderSettingsMap[T] {
    return this.config as ProviderSettingsMap[T]
  }
}
|
||||
|
||||
/**
 * Provider configuration factory.
 * Convenience creators built on top of ProviderConfigBuilder.
 */
export class ProviderConfigFactory {
  /**
   * Create a builder for the given provider id.
   */
  static builder<T extends ProviderId>(providerId: T): ProviderConfigBuilder<T> {
    return new ProviderConfigBuilder(providerId)
  }

  /**
   * Build settings from a generic provider object, applying the
   * provider-specific handler (see configHandlers) where one exists.
   *
   * @param providerId target provider id
   * @param provider   source object carrying apiKey/baseURL plus extras
   * @param options    extra headers plus arbitrary custom parameters
   */
  static fromProvider<T extends ProviderId>(
    providerId: T,
    provider: CompleteProviderConfig<T>,
    options?: {
      headers?: Record<string, string>
      [key: string]: any
    }
  ): ProviderSettingsMap[T] {
    const builder = new ProviderConfigBuilder<T>(providerId)

    // Base settings.
    if (provider.apiKey) {
      builder.withApiKey(provider.apiKey)
    }

    if (provider.baseURL) {
      builder.withBaseURL(provider.baseURL)
    }

    // Request-level settings.
    if (options?.headers) {
      builder.withRequestConfig({
        headers: options.headers
      })
    }

    // Handler pattern keeps provider-specific logic extensible.
    const handler = configHandlers[providerId]
    if (handler) {
      handler(builder, provider)
    }

    // Any remaining options become custom parameters.
    if (options) {
      const customOptions = { ...options }
      delete customOptions.headers // already applied above
      if (Object.keys(customOptions).length > 0) {
        builder.withCustomParams(customOptions)
      }
    }

    return builder.build()
  }

  /**
   * Quick OpenAI configuration.
   */
  static createOpenAI(
    apiKey: string,
    options?: {
      baseURL?: string
      organization?: string
      project?: string
    }
  ) {
    const builder = this.builder('openai')

    // Use the type-safe overload when org/project extras are present.
    if (options?.organization || options?.project) {
      builder.withApiKey(apiKey, {
        organization: options.organization,
        project: options.project
      })
    } else {
      builder.withApiKey(apiKey)
    }

    return builder.withBaseURL(options?.baseURL || 'https://api.openai.com').build()
  }

  /**
   * Quick Anthropic configuration.
   */
  static createAnthropic(
    apiKey: string,
    options?: {
      baseURL?: string
    }
  ) {
    return this.builder('anthropic')
      .withApiKey(apiKey)
      .withBaseURL(options?.baseURL || 'https://api.anthropic.com')
      .build()
  }

  /**
   * Quick Azure OpenAI configuration.
   */
  static createAzureOpenAI(
    apiKey: string,
    options: {
      baseURL: string
      apiVersion?: string
      resourceName?: string
    }
  ) {
    return this.builder('azure')
      .withApiKey(apiKey)
      .withBaseURL(options.baseURL)
      .withAzureConfig({
        apiVersion: options.apiVersion,
        resourceName: options.resourceName
      })
      .build()
  }

  /**
   * Quick Google Generative AI configuration.
   * NOTE(review): projectId/location are accepted but currently unused.
   */
  static createGoogle(
    apiKey: string,
    options?: {
      baseURL?: string
      projectId?: string
      location?: string
    }
  ) {
    return this.builder('google')
      .withApiKey(apiKey)
      .withBaseURL(options?.baseURL || 'https://generativelanguage.googleapis.com')
      .build()
  }

  /**
   * Quick Vertex AI configuration.
   * NOTE(review): not implemented — the body is commented out and this
   * currently returns undefined.
   */
  static createVertexAI() {
    // credentials: {
    //   clientEmail: string
    //   privateKey: string
    // },
    // options?: {
    //   project?: string
    //   location?: string
    // }
    // return this.builder('google-vertex')
    //   .withGoogleCredentials(credentials)
    //   .withGoogleVertexConfig({
    //     project: options?.project,
    //     location: options?.location
    //   })
    //   .build()
  }

  /**
   * Quick OpenAI-compatible endpoint configuration.
   */
  static createOpenAICompatible(baseURL: string, apiKey: string) {
    return this.builder('openai-compatible').withBaseURL(baseURL).withApiKey(apiKey).build()
  }
}
|
||||
|
||||
/**
 * Convenience aliases for the factory's static methods.
 * NOTE(review): both referenced statics avoid `this` internally except via
 * `this.builder` in other creators; these two (`fromProvider`, `builder`)
 * construct ProviderConfigBuilder directly, so unbound usage is safe.
 */
export const createProviderConfig = ProviderConfigFactory.fromProvider
export const providerConfigBuilder = ProviderConfigFactory.builder
|
||||
83
packages/aiCore/src/core/providers/index.ts
Normal file
83
packages/aiCore/src/core/providers/index.ts
Normal file
@@ -0,0 +1,83 @@
|
||||
/**
 * Unified exports for the providers module (standalone provider package).
 */

// ==================== Core managers ====================

// Provider registry manager
export { globalRegistryManagement, RegistryManagement } from './RegistryManagement'

// Core provider functionality
export {
  // State management
  cleanup,
  clearAllProviders,
  createAndRegisterProvider,
  createProvider,
  getAllProviderConfigAliases,
  getAllProviderConfigs,
  getImageModel,
  // Utility functions
  getInitializedProviders,
  getLanguageModel,
  getProviderConfig,
  getProviderConfigByAlias,
  getSupportedProviders,
  getTextEmbeddingModel,
  hasInitializedProviders,
  // Utility functions
  hasProviderConfig,
  // Alias support
  hasProviderConfigByAlias,
  isProviderConfigAlias,
  // Error types
  ProviderInitializationError,
  // Global access
  providerRegistry,
  registerMultipleProviderConfigs,
  registerProvider,
  // Unified provider system
  registerProviderConfig,
  resolveProviderConfigId
} from './registry'

// ==================== Base data and types ====================

// Base provider data source
export { baseProviderIds, baseProviders } from './schemas'

// Type definitions and schemas
export type {
  BaseProviderId,
  CustomProviderId,
  DynamicProviderRegistration,
  ProviderConfig,
  ProviderId
} from './schemas' // types re-exported from schemas
export { baseProviderIdSchema, customProviderIdSchema, providerConfigSchema, providerIdSchema } from './schemas' // schema exports
// NOTE(review): ProviderError is a class, but is re-exported type-only here —
// its runtime constructor is NOT re-exported from this barrel; confirm intent.
export type {
  DynamicProviderRegistry,
  ExtensibleProviderSettingsMap,
  ProviderError,
  ProviderSettingsMap,
  ProviderTypeRegistrar
} from './types'

// ==================== Utilities ====================

// Provider config factory
export {
  type BaseProviderConfig,
  createProviderConfig,
  ProviderConfigBuilder,
  providerConfigBuilder,
  ProviderConfigFactory
} from './factory'

// Utility functions
export { formatPrivateKey } from './utils'

// ==================== Extensions ====================

// Hub provider functionality
export { createHubProvider, type HubProviderConfig, HubProviderError } from './HubProvider'
|
||||
320
packages/aiCore/src/core/providers/registry.ts
Normal file
320
packages/aiCore/src/core/providers/registry.ts
Normal file
@@ -0,0 +1,320 @@
|
||||
/**
 * Provider initializer.
 * Creates providers from registered configs and registers them with the
 * global registry manager. Integrates special-case logic migrated from
 * ModelCreator.
 */

import { customProvider } from 'ai'

import { globalRegistryManagement } from './RegistryManagement'
import { baseProviders, type ProviderConfig } from './schemas'

/**
 * Error type for provider initialization failures, carrying the provider id
 * and the underlying cause when available.
 */
class ProviderInitializationError extends Error {
  constructor(
    message: string,
    public providerId?: string,
    public cause?: Error
  ) {
    super(message)
    this.name = 'ProviderInitializationError'
  }
}
|
||||
|
||||
// ==================== Global manager export ====================

export { globalRegistryManagement as providerRegistry }

// ==================== Convenience accessors ====================

// Thin wrappers over the global registry. NOTE(review): ids appear to be
// combined "provider|model" keys (see registry separator) — confirm.
export const getLanguageModel = (id: string) => globalRegistryManagement.languageModel(id as any)
export const getTextEmbeddingModel = (id: string) => globalRegistryManagement.textEmbeddingModel(id as any)
export const getImageModel = (id: string) => globalRegistryManagement.imageModel(id as any)
|
||||
|
||||
// ==================== 工具函数 ====================
|
||||
|
||||
/**
|
||||
* 获取支持的 Providers 列表
|
||||
*/
|
||||
export function getSupportedProviders(): Array<{
|
||||
id: string
|
||||
name: string
|
||||
}> {
|
||||
return baseProviders.map((provider) => ({
|
||||
id: provider.id,
|
||||
name: provider.name
|
||||
}))
|
||||
}
|
||||
|
||||
/**
 * List the ids of all providers registered with the global manager.
 */
export function getInitializedProviders(): string[] {
  return globalRegistryManagement.getRegisteredProviders()
}

/**
 * True when at least one provider has been registered.
 */
export function hasInitializedProviders(): boolean {
  return globalRegistryManagement.hasProviders()
}
|
||||
|
||||
// ==================== Unified provider-config system ====================

// Global store of provider configs (built-in + user-registered).
const providerConfigs = new Map<string, ProviderConfig>()
// Global alias map (alias -> real id), mirroring the RegistryManagement pattern.
const providerConfigAliases = new Map<string, string>()

/**
 * Seed the store with the built-in providers, normalized into the unified
 * ProviderConfig shape.
 */
function initializeBuiltInConfigs(): void {
  baseProviders.forEach((provider) => {
    const config: ProviderConfig = {
      id: provider.id,
      name: provider.name,
      creator: provider.creator as any, // cast: creators have heterogeneous signatures
      supportsImageGeneration: provider.supportsImageGeneration || false
    }
    providerConfigs.set(provider.id, config)
  })
}

// Register built-in configs at module load.
initializeBuiltInConfigs()
|
||||
|
||||
/**
|
||||
* 步骤1: 注册Provider配置 - 仅存储配置,不执行创建
|
||||
*/
|
||||
export function registerProviderConfig(config: ProviderConfig): boolean {
|
||||
try {
|
||||
// 验证配置
|
||||
if (!config || !config.id || !config.name) {
|
||||
return false
|
||||
}
|
||||
|
||||
// 检查是否与已有配置冲突(包括内置配置)
|
||||
if (providerConfigs.has(config.id)) {
|
||||
console.warn(`ProviderConfig "${config.id}" already exists, will override`)
|
||||
}
|
||||
|
||||
// 存储配置(内置和用户配置统一处理)
|
||||
providerConfigs.set(config.id, config)
|
||||
|
||||
// 处理别名
|
||||
if (config.aliases && config.aliases.length > 0) {
|
||||
config.aliases.forEach((alias) => {
|
||||
if (providerConfigAliases.has(alias)) {
|
||||
console.warn(`ProviderConfig alias "${alias}" already exists, will override`)
|
||||
}
|
||||
providerConfigAliases.set(alias, config.id)
|
||||
})
|
||||
}
|
||||
|
||||
return true
|
||||
} catch (error) {
|
||||
console.error(`Failed to register ProviderConfig:`, error)
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 步骤2: 创建Provider - 根据配置执行实际创建
|
||||
*/
|
||||
export async function createProvider(providerId: string, options: any): Promise<any> {
|
||||
// 支持通过别名查找配置
|
||||
const config = getProviderConfigByAlias(providerId)
|
||||
|
||||
if (!config) {
|
||||
throw new Error(`ProviderConfig not found for id: ${providerId}`)
|
||||
}
|
||||
|
||||
try {
|
||||
let creator: (options: any) => any
|
||||
|
||||
if (config.creator) {
|
||||
// 方式1: 直接执行 creator
|
||||
creator = config.creator
|
||||
} else if (config.import && config.creatorFunctionName) {
|
||||
// 方式2: 动态导入并执行
|
||||
const module = await config.import()
|
||||
creator = (module as any)[config.creatorFunctionName]
|
||||
|
||||
if (!creator || typeof creator !== 'function') {
|
||||
throw new Error(`Creator function "${config.creatorFunctionName}" not found in imported module`)
|
||||
}
|
||||
} else {
|
||||
throw new Error('No valid creator method provided in ProviderConfig')
|
||||
}
|
||||
|
||||
// 使用真实配置创建provider实例
|
||||
return creator(options)
|
||||
} catch (error) {
|
||||
console.error(`Failed to create provider "${providerId}":`, error)
|
||||
throw error
|
||||
}
|
||||
}
|
||||
|
||||
/**
 * Step 3: register a created provider instance with the global manager.
 *
 * Special cases:
 * - 'openai': registers the instance as-is, plus an 'openai-chat' variant
 *   whose languageModel routes through provider.chat().
 * - 'azure': the opposite — the instance is registered as 'azure-chat'
 *   (its creator's default already uses chat), and a variant routing
 *   through provider.responses() is registered under the plain 'azure' id.
 *
 * Returns false (never throws) when the config is missing or registration
 * fails.
 */
export function registerProvider(providerId: string, provider: any): boolean {
  try {
    const config = providerConfigs.get(providerId)
    if (!config) {
      console.error(`ProviderConfig not found for id: ${providerId}`)
      return false
    }

    // Aliases configured for this provider (may be undefined).
    const aliases = config.aliases

    // Provider-specific registration logic.
    if (providerId === 'openai') {
      // Register the default openai instance.
      globalRegistryManagement.registerProvider(providerId, provider, aliases)

      // Create and register the openai-chat variant.
      const openaiChatProvider = customProvider({
        fallbackProvider: {
          ...provider,
          languageModel: (modelId: string) => provider.chat(modelId)
        }
      })
      globalRegistryManagement.registerProvider(`${providerId}-chat`, openaiChatProvider)
    } else if (providerId === 'azure') {
      globalRegistryManagement.registerProvider(`${providerId}-chat`, provider, aliases)
      // Opposite of the openai case: the creator's default already uses chat.
      const azureResponsesProvider = customProvider({
        fallbackProvider: {
          ...provider,
          languageModel: (modelId: string) => provider.responses(modelId)
        }
      })
      globalRegistryManagement.registerProvider(providerId, azureResponsesProvider)
    } else {
      // All other providers are registered directly.
      globalRegistryManagement.registerProvider(providerId, provider, aliases)
    }

    return true
  } catch (error) {
    console.error(`Failed to register provider "${providerId}" to global registry:`, error)
    return false
  }
}
|
||||
|
||||
/**
|
||||
* 便捷函数: 一次性完成创建+注册
|
||||
*/
|
||||
export async function createAndRegisterProvider(providerId: string, options: any): Promise<boolean> {
|
||||
try {
|
||||
// 步骤2: 创建provider
|
||||
const provider = await createProvider(providerId, options)
|
||||
|
||||
// 步骤3: 注册到全局管理器
|
||||
return registerProvider(providerId, provider)
|
||||
} catch (error) {
|
||||
console.error(`Failed to create and register provider "${providerId}":`, error)
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 批量注册Provider配置
|
||||
*/
|
||||
export function registerMultipleProviderConfigs(configs: ProviderConfig[]): number {
|
||||
let successCount = 0
|
||||
configs.forEach((config) => {
|
||||
if (registerProviderConfig(config)) {
|
||||
successCount++
|
||||
}
|
||||
})
|
||||
return successCount
|
||||
}
|
||||
|
||||
/**
|
||||
* 检查是否有对应的Provider配置
|
||||
*/
|
||||
export function hasProviderConfig(providerId: string): boolean {
|
||||
return providerConfigs.has(providerId)
|
||||
}
|
||||
|
||||
/**
|
||||
* 通过别名或ID检查是否有对应的Provider配置
|
||||
*/
|
||||
export function hasProviderConfigByAlias(aliasOrId: string): boolean {
|
||||
const realId = resolveProviderConfigId(aliasOrId)
|
||||
return providerConfigs.has(realId)
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取所有Provider配置
|
||||
*/
|
||||
export function getAllProviderConfigs(): ProviderConfig[] {
|
||||
return Array.from(providerConfigs.values())
|
||||
}
|
||||
|
||||
/**
|
||||
* 根据ID获取Provider配置
|
||||
*/
|
||||
export function getProviderConfig(providerId: string): ProviderConfig | undefined {
|
||||
return providerConfigs.get(providerId)
|
||||
}
|
||||
|
||||
/**
|
||||
* 通过别名或ID获取Provider配置
|
||||
*/
|
||||
export function getProviderConfigByAlias(aliasOrId: string): ProviderConfig | undefined {
|
||||
// 先检查是否为别名,如果是则解析为真实ID
|
||||
const realId = providerConfigAliases.get(aliasOrId) || aliasOrId
|
||||
return providerConfigs.get(realId)
|
||||
}
|
||||
|
||||
/**
|
||||
* 解析真实的ProviderConfig ID(去别名化)
|
||||
*/
|
||||
export function resolveProviderConfigId(aliasOrId: string): string {
|
||||
return providerConfigAliases.get(aliasOrId) || aliasOrId
|
||||
}
|
||||
|
||||
/**
|
||||
* 检查是否为ProviderConfig别名
|
||||
*/
|
||||
export function isProviderConfigAlias(id: string): boolean {
|
||||
return providerConfigAliases.has(id)
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取所有ProviderConfig别名映射关系
|
||||
*/
|
||||
export function getAllProviderConfigAliases(): Record<string, string> {
|
||||
const result: Record<string, string> = {}
|
||||
providerConfigAliases.forEach((realId, alias) => {
|
||||
result[alias] = realId
|
||||
})
|
||||
return result
|
||||
}
|
||||
|
||||
/**
 * Reset the module: drop every config, alias, and registered provider, then
 * re-seed the built-in configs so the module returns to its initial state.
 */
export function cleanup(): void {
  providerConfigs.clear()
  providerConfigAliases.clear() // clear alias mappings too
  globalRegistryManagement.clear()
  // Restore the built-in configs.
  initializeBuiltInConfigs()
}

/**
 * Remove all registered provider instances; configs are left intact.
 */
export function clearAllProviders(): void {
  globalRegistryManagement.clear()
}

// ==================== Error type export ====================

export { ProviderInitializationError }
|
||||
178
packages/aiCore/src/core/providers/schemas.ts
Normal file
178
packages/aiCore/src/core/providers/schemas.ts
Normal file
@@ -0,0 +1,178 @@
|
||||
/**
 * Provider config definitions: built-in provider ids, creator functions,
 * and the Zod schemas used to validate user-supplied provider configs.
 */

import { createAnthropic } from '@ai-sdk/anthropic'
import { createAzure } from '@ai-sdk/azure'
import { type AzureOpenAIProviderSettings } from '@ai-sdk/azure'
import { createDeepSeek } from '@ai-sdk/deepseek'
import { createGoogleGenerativeAI } from '@ai-sdk/google'
import { createOpenAI, type OpenAIProviderSettings } from '@ai-sdk/openai'
import { createOpenAICompatible } from '@ai-sdk/openai-compatible'
import { createXai } from '@ai-sdk/xai'
import { customProvider, type Provider } from 'ai'
import * as z from 'zod'

/**
 * Built-in provider ids.
 */
export const baseProviderIds = [
  'openai',
  'openai-chat',
  'openai-compatible',
  'anthropic',
  'google',
  'xai',
  'azure',
  'azure-responses',
  'deepseek'
] as const

/**
 * Zod schema for a built-in provider id.
 */
export const baseProviderIdSchema = z.enum(baseProviderIds)

/**
 * Built-in provider id type.
 */
export type BaseProviderId = z.infer<typeof baseProviderIdSchema>
|
||||
|
||||
/**
 * Schema describing one built-in provider entry.
 * `creator` turns a provider settings object into an AI SDK Provider.
 */
export const baseProviderSchema = z.object({
  id: baseProviderIdSchema,
  name: z.string(),
  creator: z.function().args(z.any()).returns(z.any()) as z.ZodType<(options: any) => Provider>,
  supportsImageGeneration: z.boolean()
})

export type BaseProvider = z.infer<typeof baseProviderSchema>
|
||||
|
||||
/**
|
||||
* 基础 Providers 定义
|
||||
* 作为唯一数据源,避免重复维护
|
||||
*/
|
||||
export const baseProviders = [
|
||||
{
|
||||
id: 'openai',
|
||||
name: 'OpenAI',
|
||||
creator: createOpenAI,
|
||||
supportsImageGeneration: true
|
||||
},
|
||||
{
|
||||
id: 'openai-chat',
|
||||
name: 'OpenAI Chat',
|
||||
creator: (options: OpenAIProviderSettings) => {
|
||||
const provider = createOpenAI(options)
|
||||
return customProvider({
|
||||
fallbackProvider: {
|
||||
...provider,
|
||||
languageModel: (modelId: string) => provider.chat(modelId)
|
||||
}
|
||||
})
|
||||
},
|
||||
supportsImageGeneration: true
|
||||
},
|
||||
{
|
||||
id: 'openai-compatible',
|
||||
name: 'OpenAI Compatible',
|
||||
creator: createOpenAICompatible,
|
||||
supportsImageGeneration: true
|
||||
},
|
||||
{
|
||||
id: 'anthropic',
|
||||
name: 'Anthropic',
|
||||
creator: createAnthropic,
|
||||
supportsImageGeneration: false
|
||||
},
|
||||
{
|
||||
id: 'google',
|
||||
name: 'Google Generative AI',
|
||||
creator: createGoogleGenerativeAI,
|
||||
supportsImageGeneration: true
|
||||
},
|
||||
{
|
||||
id: 'xai',
|
||||
name: 'xAI (Grok)',
|
||||
creator: createXai,
|
||||
supportsImageGeneration: true
|
||||
},
|
||||
{
|
||||
id: 'azure',
|
||||
name: 'Azure OpenAI',
|
||||
creator: createAzure,
|
||||
supportsImageGeneration: true
|
||||
},
|
||||
{
|
||||
id: 'azure-responses',
|
||||
name: 'Azure OpenAI Responses',
|
||||
creator: (options: AzureOpenAIProviderSettings) => {
|
||||
const provider = createAzure(options)
|
||||
return customProvider({
|
||||
fallbackProvider: {
|
||||
...provider,
|
||||
languageModel: (modelId: string) => provider.responses(modelId)
|
||||
}
|
||||
})
|
||||
},
|
||||
supportsImageGeneration: true
|
||||
},
|
||||
{
|
||||
id: 'deepseek',
|
||||
name: 'DeepSeek',
|
||||
creator: createDeepSeek,
|
||||
supportsImageGeneration: false
|
||||
}
|
||||
] as const satisfies BaseProvider[]
|
||||
|
||||
/**
 * User-defined provider id schema.
 * Any non-empty string except the built-in provider ids, to avoid
 * collisions.
 */
export const customProviderIdSchema = z
  .string()
  .min(1)
  .refine((id) => !baseProviderIds.includes(id as any), {
    message: 'Custom provider ID cannot conflict with base provider IDs'
  })

/**
 * Provider id schema — built-in or custom.
 */
export const providerIdSchema = z.union([baseProviderIdSchema, customProviderIdSchema])

/**
 * Provider config schema, used to validate user-registered providers.
 * A config must supply either a direct `creator` or a dynamic
 * `import` + `creatorFunctionName` pair (enforced by the refine below).
 */
export const providerConfigSchema = z
  .object({
    id: customProviderIdSchema, // custom ids only
    name: z.string().min(1),
    creator: z.function().optional(),
    import: z.function().optional(),
    creatorFunctionName: z.string().optional(),
    supportsImageGeneration: z.boolean().default(false),
    imageCreator: z.function().optional(),
    validateOptions: z.function().optional(),
    aliases: z.array(z.string()).optional()
  })
  .refine((data) => data.creator || (data.import && data.creatorFunctionName), {
    message: 'Must provide either creator function or import configuration'
  })

/**
 * Provider id types derived from the Zod schemas.
 */
export type ProviderId = z.infer<typeof providerIdSchema>
export type CustomProviderId = z.infer<typeof customProviderIdSchema>

/**
 * Provider config type.
 */
export type ProviderConfig = z.infer<typeof providerConfigSchema>

/**
 * Compatibility alias.
 * @deprecated Use ProviderConfig instead.
 */
export type DynamicProviderRegistration = ProviderConfig
|
||||
96
packages/aiCore/src/core/providers/types.ts
Normal file
96
packages/aiCore/src/core/providers/types.ts
Normal file
@@ -0,0 +1,96 @@
|
||||
import { type AnthropicProviderSettings } from '@ai-sdk/anthropic'
import { type AzureOpenAIProviderSettings } from '@ai-sdk/azure'
import { type DeepSeekProviderSettings } from '@ai-sdk/deepseek'
import { type GoogleGenerativeAIProviderSettings } from '@ai-sdk/google'
import { type OpenAIProviderSettings } from '@ai-sdk/openai'
import { type OpenAICompatibleProviderSettings } from '@ai-sdk/openai-compatible'
import {
  EmbeddingModelV2 as EmbeddingModel,
  ImageModelV2 as ImageModel,
  LanguageModelV2 as LanguageModel,
  ProviderV2,
  SpeechModelV2 as SpeechModel,
  TranscriptionModelV2 as TranscriptionModel
} from '@ai-sdk/provider'
import { type XaiProviderSettings } from '@ai-sdk/xai'

// Zod-derived ProviderId type.
import { type ProviderId as ZodProviderId } from './schemas'

/**
 * Settings types for the built-in static providers.
 * NOTE(review): keys here do not match schemas.baseProviderIds exactly —
 * 'openai-responses' appears instead of 'openai-chat'/'azure-responses';
 * confirm which list is authoritative.
 */
export interface ExtensibleProviderSettingsMap {
  // Built-in static providers
  openai: OpenAIProviderSettings
  'openai-responses': OpenAIProviderSettings
  'openai-compatible': OpenAICompatibleProviderSettings
  anthropic: AnthropicProviderSettings
  google: GoogleGenerativeAIProviderSettings
  xai: XaiProviderSettings
  azure: AzureOpenAIProviderSettings
  deepseek: DeepSeekProviderSettings
}

/**
 * Registry for dynamically added provider settings types (untyped).
 */
export interface DynamicProviderRegistry {
  [key: string]: any
}

// Combined static + dynamic provider settings map.
export type ProviderSettingsMap = ExtensibleProviderSettingsMap & DynamicProviderRegistry
|
||||
|
||||
// Error type
/**
 * Generic provider error carrying the provider id, an optional
 * machine-readable code, and the underlying cause when available.
 */
export class ProviderError extends Error {
  constructor(
    message: string,
    public providerId: string,
    public code?: string,
    public cause?: Error
  ) {
    super(message)
    this.name = 'ProviderError'
  }
}
|
||||
|
||||
// Dynamic ProviderId type — Zod-backed; supports runtime extension and validation.
export type ProviderId = ZodProviderId

/**
 * Contract for registering provider settings types at runtime.
 */
export interface ProviderTypeRegistrar {
  registerProviderType<T extends string, S>(providerId: T, settingsType: S): void
  getProviderSettings<T extends string>(providerId: T): any
}

// Re-export all settings types for external consumers.
export type {
  AnthropicProviderSettings,
  AzureOpenAIProviderSettings,
  DeepSeekProviderSettings,
  GoogleGenerativeAIProviderSettings,
  OpenAICompatibleProviderSettings,
  OpenAIProviderSettings,
  XaiProviderSettings
}

// Union of all AI SDK model interfaces.
export type AiSdkModel = LanguageModel | ImageModel | EmbeddingModel<string> | TranscriptionModel | SpeechModel

export type AiSdkModelType = 'text' | 'image' | 'embedding' | 'transcription' | 'speech'

// Maps a model kind to the ProviderV2 accessor method that resolves it.
export const METHOD_MAP = {
  text: 'languageModel',
  image: 'imageModel',
  embedding: 'textEmbeddingModel',
  transcription: 'transcriptionModel',
  speech: 'speechModel'
} as const satisfies Record<AiSdkModelType, keyof ProviderV2>

export type AiSdkModelMethodMap = Record<AiSdkModelType, keyof ProviderV2>

// Maps a model kind to the concrete model type it yields.
export type AiSdkModelReturnMap = {
  text: LanguageModel
  image: ImageModel
  embedding: EmbeddingModel<string>
  transcription: TranscriptionModel
  speech: SpeechModel
}

export type AiSdkMethodName<T extends AiSdkModelType> = (typeof METHOD_MAP)[T]

export type AiSdkModelReturn<T extends AiSdkModelType> = AiSdkModelReturnMap[T]
|
||||
86
packages/aiCore/src/core/providers/utils.ts
Normal file
86
packages/aiCore/src/core/providers/utils.ts
Normal file
@@ -0,0 +1,86 @@
|
||||
/**
|
||||
* 格式化私钥,确保它包含正确的PEM头部和尾部
|
||||
*/
|
||||
export function formatPrivateKey(privateKey: string): string {
|
||||
if (!privateKey || typeof privateKey !== 'string') {
|
||||
throw new Error('Private key must be a non-empty string')
|
||||
}
|
||||
|
||||
// 先处理 JSON 字符串中的转义换行符
|
||||
const key = privateKey.replace(/\\n/g, '\n')
|
||||
|
||||
// 检查是否已经是正确格式的 PEM 私钥
|
||||
const hasBeginMarker = key.includes('-----BEGIN PRIVATE KEY-----')
|
||||
const hasEndMarker = key.includes('-----END PRIVATE KEY-----')
|
||||
|
||||
if (hasBeginMarker && hasEndMarker) {
|
||||
// 已经是 PEM 格式,但可能格式不规范,重新格式化
|
||||
return normalizePemFormat(key)
|
||||
}
|
||||
|
||||
// 如果没有完整的 PEM 头尾,尝试重新构建
|
||||
return reconstructPemKey(key)
|
||||
}
|
||||
|
||||
/**
|
||||
* 标准化 PEM 格式
|
||||
*/
|
||||
function normalizePemFormat(pemKey: string): string {
|
||||
// 分离头部、内容和尾部
|
||||
const lines = pemKey
|
||||
.split('\n')
|
||||
.map((line) => line.trim())
|
||||
.filter((line) => line.length > 0)
|
||||
|
||||
let keyContent = ''
|
||||
let foundBegin = false
|
||||
let foundEnd = false
|
||||
|
||||
for (const line of lines) {
|
||||
if (line === '-----BEGIN PRIVATE KEY-----') {
|
||||
foundBegin = true
|
||||
continue
|
||||
}
|
||||
if (line === '-----END PRIVATE KEY-----') {
|
||||
foundEnd = true
|
||||
break
|
||||
}
|
||||
if (foundBegin && !foundEnd) {
|
||||
keyContent += line
|
||||
}
|
||||
}
|
||||
|
||||
if (!foundBegin || !foundEnd || !keyContent) {
|
||||
throw new Error('Invalid PEM format: missing BEGIN/END markers or key content')
|
||||
}
|
||||
|
||||
// 重新格式化为 64 字符一行
|
||||
const formattedContent = keyContent.match(/.{1,64}/g)?.join('\n') || keyContent
|
||||
|
||||
return `-----BEGIN PRIVATE KEY-----\n${formattedContent}\n-----END PRIVATE KEY-----`
|
||||
}
|
||||
|
||||
/**
|
||||
* 重新构建 PEM 私钥
|
||||
*/
|
||||
function reconstructPemKey(key: string): string {
|
||||
// 移除所有空白字符和可能存在的不完整头尾
|
||||
let cleanKey = key.replace(/\s+/g, '')
|
||||
cleanKey = cleanKey.replace(/-----BEGIN[^-]*-----/g, '')
|
||||
cleanKey = cleanKey.replace(/-----END[^-]*-----/g, '')
|
||||
|
||||
// 确保私钥内容不为空
|
||||
if (!cleanKey) {
|
||||
throw new Error('Private key content is empty after cleaning')
|
||||
}
|
||||
|
||||
// 验证是否是有效的 Base64 字符
|
||||
if (!/^[A-Za-z0-9+/=]+$/.test(cleanKey)) {
|
||||
throw new Error('Private key contains invalid characters (not valid Base64)')
|
||||
}
|
||||
|
||||
// 格式化为 64 字符一行
|
||||
const formattedKey = cleanKey.match(/.{1,64}/g)?.join('\n') || cleanKey
|
||||
|
||||
return `-----BEGIN PRIVATE KEY-----\n${formattedKey}\n-----END PRIVATE KEY-----`
|
||||
}
|
||||
523
packages/aiCore/src/core/runtime/__tests__/generateImage.test.ts
Normal file
523
packages/aiCore/src/core/runtime/__tests__/generateImage.test.ts
Normal file
@@ -0,0 +1,523 @@
|
||||
import { ImageModelV2 } from '@ai-sdk/provider'
|
||||
import { experimental_generateImage as aiGenerateImage, NoImageGeneratedError } from 'ai'
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import { type AiPlugin } from '../../plugins'
|
||||
import { globalRegistryManagement } from '../../providers/RegistryManagement'
|
||||
import { ImageGenerationError, ImageModelResolutionError } from '../errors'
|
||||
import { RuntimeExecutor } from '../executor'
|
||||
|
||||
// Mock dependencies
// The real 'ai' module and the provider registry are replaced so the suite can
// observe exactly what RuntimeExecutor.generateImage forwards to them.
// NOTE(review): the mocked NoImageGeneratedError constructor takes no arguments,
// while the test below constructs it with an options object — verify against the
// real 'ai' constructor signature.
vi.mock('ai', () => ({
  experimental_generateImage: vi.fn(),
  NoImageGeneratedError: class NoImageGeneratedError extends Error {
    static isInstance = vi.fn()
    constructor() {
      super('No image generated')
      this.name = 'NoImageGeneratedError'
    }
  }
}))

vi.mock('../../providers/RegistryManagement', () => ({
  globalRegistryManagement: {
    imageModel: vi.fn()
  },
  DEFAULT_SEPARATOR: '|'
}))

// End-to-end unit tests for RuntimeExecutor.generateImage: parameter pass-through,
// plugin hook ordering, error wrapping, multi-provider ids, and result metadata.
describe('RuntimeExecutor.generateImage', () => {
  let executor: RuntimeExecutor<'openai'>
  let mockImageModel: ImageModelV2
  let mockGenerateImageResult: any

  beforeEach(() => {
    // Reset all mocks
    vi.clearAllMocks()

    // Create executor instance
    executor = RuntimeExecutor.create('openai', {
      apiKey: 'test-key'
    })

    // Mock image model
    mockImageModel = {
      modelId: 'dall-e-3',
      provider: 'openai'
    } as ImageModelV2

    // Mock generateImage result
    mockGenerateImageResult = {
      image: {
        base64: 'base64-encoded-image-data',
        uint8Array: new Uint8Array([1, 2, 3]),
        mediaType: 'image/png'
      },
      images: [
        {
          base64: 'base64-encoded-image-data',
          uint8Array: new Uint8Array([1, 2, 3]),
          mediaType: 'image/png'
        }
      ],
      warnings: [],
      providerMetadata: {
        openai: {
          images: [{ revisedPrompt: 'A detailed prompt' }]
        }
      },
      responses: []
    }

    // Setup mocks to avoid "No providers registered" error
    vi.mocked(globalRegistryManagement.imageModel).mockReturnValue(mockImageModel)
    vi.mocked(aiGenerateImage).mockResolvedValue(mockGenerateImageResult)
  })

  // Each parameter (n, size, aspectRatio, seed, abortSignal, providerOptions,
  // headers) must be forwarded unchanged to the SDK call.
  describe('Basic functionality', () => {
    it('should generate a single image with minimal parameters', async () => {
      const result = await executor.generateImage({ model: 'dall-e-3', prompt: 'A futuristic cityscape at sunset' })

      // String model ids are looked up as '<provider>|<modelId>'.
      expect(globalRegistryManagement.imageModel).toHaveBeenCalledWith('openai|dall-e-3')

      expect(aiGenerateImage).toHaveBeenCalledWith({
        model: mockImageModel,
        prompt: 'A futuristic cityscape at sunset'
      })

      expect(result).toEqual(mockGenerateImageResult)
    })

    it('should generate image with pre-created model', async () => {
      const result = await executor.generateImage({
        model: mockImageModel,
        prompt: 'A beautiful landscape'
      })

      // Note: globalRegistryManagement.imageModel may still be called due to resolveImageModel logic
      expect(aiGenerateImage).toHaveBeenCalledWith({
        model: mockImageModel,
        prompt: 'A beautiful landscape'
      })

      expect(result).toEqual(mockGenerateImageResult)
    })

    it('should support multiple images generation', async () => {
      await executor.generateImage({ model: 'dall-e-3', prompt: 'A futuristic cityscape', n: 3 })

      expect(aiGenerateImage).toHaveBeenCalledWith({
        model: mockImageModel,
        prompt: 'A futuristic cityscape',
        n: 3
      })
    })

    it('should support size specification', async () => {
      await executor.generateImage({ model: 'dall-e-3', prompt: 'A beautiful sunset', size: '1024x1024' })

      expect(aiGenerateImage).toHaveBeenCalledWith({
        model: mockImageModel,
        prompt: 'A beautiful sunset',
        size: '1024x1024'
      })
    })

    it('should support aspect ratio specification', async () => {
      await executor.generateImage({ model: 'dall-e-3', prompt: 'A mountain landscape', aspectRatio: '16:9' })

      expect(aiGenerateImage).toHaveBeenCalledWith({
        model: mockImageModel,
        prompt: 'A mountain landscape',
        aspectRatio: '16:9'
      })
    })

    it('should support seed for consistent output', async () => {
      await executor.generateImage({ model: 'dall-e-3', prompt: 'A cat in space', seed: 1234567890 })

      expect(aiGenerateImage).toHaveBeenCalledWith({
        model: mockImageModel,
        prompt: 'A cat in space',
        seed: 1234567890
      })
    })

    it('should support abort signal', async () => {
      const abortController = new AbortController()

      await executor.generateImage({ model: 'dall-e-3', prompt: 'A cityscape', abortSignal: abortController.signal })

      expect(aiGenerateImage).toHaveBeenCalledWith({
        model: mockImageModel,
        prompt: 'A cityscape',
        abortSignal: abortController.signal
      })
    })

    it('should support provider-specific options', async () => {
      await executor.generateImage({
        model: 'dall-e-3',
        prompt: 'A space station',
        providerOptions: {
          openai: {
            quality: 'hd',
            style: 'vivid'
          }
        }
      })

      expect(aiGenerateImage).toHaveBeenCalledWith({
        model: mockImageModel,
        prompt: 'A space station',
        providerOptions: {
          openai: {
            quality: 'hd',
            style: 'vivid'
          }
        }
      })
    })

    it('should support custom headers', async () => {
      await executor.generateImage({
        model: 'dall-e-3',
        prompt: 'A robot',
        headers: {
          'X-Custom-Header': 'test-value'
        }
      })

      expect(aiGenerateImage).toHaveBeenCalledWith({
        model: mockImageModel,
        prompt: 'A robot',
        headers: {
          'X-Custom-Header': 'test-value'
        }
      })
    })
  })

  // Plugin hooks must fire in lifecycle order and be able to rewrite params,
  // rewrite results, resolve models, and recursively re-enter the pipeline.
  describe('Plugin integration', () => {
    it('should execute plugins in correct order', async () => {
      const pluginCallOrder: string[] = []

      const testPlugin: AiPlugin = {
        name: 'test-plugin',
        onRequestStart: vi.fn(async () => {
          pluginCallOrder.push('onRequestStart')
        }),
        transformParams: vi.fn(async (params) => {
          pluginCallOrder.push('transformParams')
          return { ...params, size: '512x512' }
        }),
        transformResult: vi.fn(async (result) => {
          pluginCallOrder.push('transformResult')
          return { ...result, processed: true }
        }),
        onRequestEnd: vi.fn(async () => {
          pluginCallOrder.push('onRequestEnd')
        })
      }

      const executorWithPlugin = RuntimeExecutor.create(
        'openai',
        {
          apiKey: 'test-key'
        },
        [testPlugin]
      )

      const result = await executorWithPlugin.generateImage({ model: 'dall-e-3', prompt: 'A test image' })

      expect(pluginCallOrder).toEqual(['onRequestStart', 'transformParams', 'transformResult', 'onRequestEnd'])

      expect(testPlugin.transformParams).toHaveBeenCalledWith(
        { prompt: 'A test image' },
        expect.objectContaining({
          providerId: 'openai',
          modelId: 'dall-e-3'
        })
      )

      expect(aiGenerateImage).toHaveBeenCalledWith({
        model: mockImageModel,
        prompt: 'A test image',
        size: '512x512' // Should be transformed by plugin
      })

      expect(result).toEqual({
        ...mockGenerateImageResult,
        processed: true // Should be transformed by plugin
      })
    })

    it('should handle model resolution through plugins', async () => {
      const customImageModel = {
        modelId: 'custom-model',
        provider: 'openai'
      } as ImageModelV2

      const modelResolutionPlugin: AiPlugin = {
        name: 'model-resolver',
        resolveModel: vi.fn(async () => customImageModel)
      }

      const executorWithPlugin = RuntimeExecutor.create(
        'openai',
        {
          apiKey: 'test-key'
        },
        [modelResolutionPlugin]
      )

      await executorWithPlugin.generateImage({ model: 'dall-e-3', prompt: 'A test image' })

      expect(modelResolutionPlugin.resolveModel).toHaveBeenCalledWith(
        'dall-e-3',
        expect.objectContaining({
          providerId: 'openai',
          modelId: 'dall-e-3'
        })
      )

      // The plugin-resolved model wins over the registry lookup.
      expect(aiGenerateImage).toHaveBeenCalledWith({
        model: customImageModel,
        prompt: 'A test image'
      })
    })

    it('should support recursive calls from plugins', async () => {
      const recursivePlugin: AiPlugin = {
        name: 'recursive-plugin',
        transformParams: vi.fn(async (params, context) => {
          // Guard on isRecursiveCall so the nested call does not recurse forever.
          if (!context.isRecursiveCall && params.prompt === 'original') {
            // Make a recursive call with modified prompt
            await context.recursiveCall({
              model: 'dall-e-3',
              prompt: 'modified'
            })
          }
          return params
        })
      }

      const executorWithPlugin = RuntimeExecutor.create(
        'openai',
        {
          apiKey: 'test-key'
        },
        [recursivePlugin]
      )

      await executorWithPlugin.generateImage({ model: 'dall-e-3', prompt: 'original' })

      expect(recursivePlugin.transformParams).toHaveBeenCalledTimes(2)
      expect(aiGenerateImage).toHaveBeenCalledTimes(2)
    })
  })

  // All failures — registry lookup, resolution, SDK call, aborts — must surface
  // as ImageGenerationError with provider/model context and the original cause.
  describe('Error handling', () => {
    it('should handle model creation errors', async () => {
      const modelError = new Error('Failed to get image model')
      vi.mocked(globalRegistryManagement.imageModel).mockImplementation(() => {
        throw modelError
      })

      await expect(executor.generateImage({ model: 'invalid-model', prompt: 'A test image' })).rejects.toThrow(
        ImageGenerationError
      )
    })

    it('should handle ImageModelResolutionError correctly', async () => {
      const resolutionError = new ImageModelResolutionError('invalid-model', 'openai', new Error('Model not found'))
      vi.mocked(globalRegistryManagement.imageModel).mockImplementation(() => {
        throw resolutionError
      })

      const thrownError = await executor
        .generateImage({ model: 'invalid-model', prompt: 'A test image' })
        .catch((error) => error)

      // A resolution error is re-wrapped: the outer error is ImageGenerationError
      // and the resolution error is preserved as its cause.
      expect(thrownError).toBeInstanceOf(ImageGenerationError)
      expect(thrownError.message).toContain('Failed to generate image:')
      expect(thrownError.providerId).toBe('openai')
      expect(thrownError.modelId).toBe('invalid-model')
      expect(thrownError.cause).toBeInstanceOf(ImageModelResolutionError)
      expect(thrownError.cause.message).toContain('Failed to resolve image model: invalid-model')
    })

    it('should handle ImageModelResolutionError without provider', async () => {
      const resolutionError = new ImageModelResolutionError('unknown-model')
      vi.mocked(globalRegistryManagement.imageModel).mockImplementation(() => {
        throw resolutionError
      })

      await expect(executor.generateImage({ model: 'unknown-model', prompt: 'A test image' })).rejects.toThrow(
        ImageGenerationError
      )
    })

    it('should handle image generation API errors', async () => {
      const apiError = new Error('API request failed')
      vi.mocked(aiGenerateImage).mockRejectedValue(apiError)

      await expect(executor.generateImage({ model: 'dall-e-3', prompt: 'A test image' })).rejects.toThrow(
        'Failed to generate image:'
      )
    })

    it('should handle NoImageGeneratedError', async () => {
      const noImageError = new NoImageGeneratedError({
        cause: new Error('No image generated'),
        responses: []
      })

      vi.mocked(aiGenerateImage).mockRejectedValue(noImageError)
      vi.mocked(NoImageGeneratedError.isInstance).mockReturnValue(true)

      await expect(executor.generateImage({ model: 'dall-e-3', prompt: 'A test image' })).rejects.toThrow(
        'Failed to generate image:'
      )
    })

    it('should execute onError plugin hook on failure', async () => {
      const error = new Error('Generation failed')
      vi.mocked(aiGenerateImage).mockRejectedValue(error)

      const errorPlugin: AiPlugin = {
        name: 'error-handler',
        onError: vi.fn()
      }

      const executorWithPlugin = RuntimeExecutor.create(
        'openai',
        {
          apiKey: 'test-key'
        },
        [errorPlugin]
      )

      await expect(executorWithPlugin.generateImage({ model: 'dall-e-3', prompt: 'A test image' })).rejects.toThrow(
        'Failed to generate image:'
      )

      // onError receives the raw (unwrapped) failure plus the request context.
      expect(errorPlugin.onError).toHaveBeenCalledWith(
        error,
        expect.objectContaining({
          providerId: 'openai',
          modelId: 'dall-e-3'
        })
      )
    })

    it('should handle abort signal timeout', async () => {
      const abortError = new Error('Operation was aborted')
      abortError.name = 'AbortError'
      vi.mocked(aiGenerateImage).mockRejectedValue(abortError)

      const abortController = new AbortController()
      setTimeout(() => abortController.abort(), 10)

      await expect(
        executor.generateImage({ model: 'dall-e-3', prompt: 'A test image', abortSignal: abortController.signal })
      ).rejects.toThrow('Failed to generate image:')
    })
  })

  // The '<provider>|<model>' registry key must reflect each executor's provider.
  describe('Multiple providers support', () => {
    it('should work with different providers', async () => {
      const googleExecutor = RuntimeExecutor.create('google', {
        apiKey: 'google-key'
      })

      await googleExecutor.generateImage({ model: 'imagen-3.0-generate-002', prompt: 'A landscape' })

      expect(globalRegistryManagement.imageModel).toHaveBeenCalledWith('google|imagen-3.0-generate-002')
    })

    it('should support xAI Grok image models', async () => {
      const xaiExecutor = RuntimeExecutor.create('xai', {
        apiKey: 'xai-key'
      })

      await xaiExecutor.generateImage({ model: 'grok-2-image', prompt: 'A futuristic robot' })

      expect(globalRegistryManagement.imageModel).toHaveBeenCalledWith('xai|grok-2-image')
    })
  })

  // Batching/retry knobs pass through; warnings and response/provider metadata
  // from the SDK result are exposed unchanged to the caller.
  describe('Advanced features', () => {
    it('should support batch image generation with maxImagesPerCall', async () => {
      await executor.generateImage({ model: 'dall-e-3', prompt: 'A test image', n: 10, maxImagesPerCall: 5 })

      expect(aiGenerateImage).toHaveBeenCalledWith({
        model: mockImageModel,
        prompt: 'A test image',
        n: 10,
        maxImagesPerCall: 5
      })
    })

    it('should support retries with maxRetries', async () => {
      await executor.generateImage({ model: 'dall-e-3', prompt: 'A test image', maxRetries: 3 })

      expect(aiGenerateImage).toHaveBeenCalledWith({
        model: mockImageModel,
        prompt: 'A test image',
        maxRetries: 3
      })
    })

    it('should handle warnings from the model', async () => {
      const resultWithWarnings = {
        ...mockGenerateImageResult,
        warnings: [
          {
            type: 'unsupported-setting',
            message: 'Size parameter not supported for this model'
          }
        ]
      }

      vi.mocked(aiGenerateImage).mockResolvedValue(resultWithWarnings)

      const result = await executor.generateImage({
        model: 'dall-e-3',
        prompt: 'A test image',
        size: '2048x2048' // Unsupported size
      })

      expect(result.warnings).toHaveLength(1)
      expect(result.warnings[0].type).toBe('unsupported-setting')
    })

    it('should provide access to provider metadata', async () => {
      const result = await executor.generateImage({ model: 'dall-e-3', prompt: 'A test image' })

      expect(result.providerMetadata).toBeDefined()
      expect(result.providerMetadata.openai).toBeDefined()
    })

    it('should provide response metadata', async () => {
      const resultWithMetadata = {
        ...mockGenerateImageResult,
        responses: [
          {
            timestamp: new Date(),
            modelId: 'dall-e-3',
            headers: { 'x-request-id': 'test-123' }
          }
        ]
      }

      vi.mocked(aiGenerateImage).mockResolvedValue(resultWithMetadata)

      const result = await executor.generateImage({ model: 'dall-e-3', prompt: 'A test image' })

      expect(result.responses).toHaveLength(1)
      expect(result.responses[0].modelId).toBe('dall-e-3')
      expect(result.responses[0].headers).toEqual({ 'x-request-id': 'test-123' })
    })
  })
})
|
||||
38
packages/aiCore/src/core/runtime/errors.ts
Normal file
38
packages/aiCore/src/core/runtime/errors.ts
Normal file
@@ -0,0 +1,38 @@
|
||||
/**
|
||||
* Error classes for runtime operations
|
||||
*/
|
||||
|
||||
/**
|
||||
* Error thrown when image generation fails
|
||||
*/
|
||||
export class ImageGenerationError extends Error {
|
||||
constructor(
|
||||
message: string,
|
||||
public providerId?: string,
|
||||
public modelId?: string,
|
||||
public cause?: Error
|
||||
) {
|
||||
super(message)
|
||||
this.name = 'ImageGenerationError'
|
||||
|
||||
// Maintain proper stack trace (for V8 engines)
|
||||
if (Error.captureStackTrace) {
|
||||
Error.captureStackTrace(this, ImageGenerationError)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Error thrown when model resolution fails during image generation
|
||||
*/
|
||||
export class ImageModelResolutionError extends ImageGenerationError {
|
||||
constructor(modelId: string, providerId?: string, cause?: Error) {
|
||||
super(
|
||||
`Failed to resolve image model: ${modelId}${providerId ? ` for provider: ${providerId}` : ''}`,
|
||||
providerId,
|
||||
modelId,
|
||||
cause
|
||||
)
|
||||
this.name = 'ImageModelResolutionError'
|
||||
}
|
||||
}
|
||||
321
packages/aiCore/src/core/runtime/executor.ts
Normal file
321
packages/aiCore/src/core/runtime/executor.ts
Normal file
@@ -0,0 +1,321 @@
|
||||
/**
|
||||
* 运行时执行器
|
||||
* 专注于插件化的AI调用处理
|
||||
*/
|
||||
import { ImageModelV2, LanguageModelV2, LanguageModelV2Middleware } from '@ai-sdk/provider'
|
||||
import {
|
||||
experimental_generateImage as generateImage,
|
||||
generateObject,
|
||||
generateText,
|
||||
LanguageModel,
|
||||
streamObject,
|
||||
streamText
|
||||
} from 'ai'
|
||||
|
||||
import { globalModelResolver } from '../models'
|
||||
import { type ModelConfig } from '../models/types'
|
||||
import { type AiPlugin, type AiRequestContext, definePlugin } from '../plugins'
|
||||
import { type ProviderId } from '../providers'
|
||||
import { ImageGenerationError, ImageModelResolutionError } from './errors'
|
||||
import { PluginEngine } from './pluginEngine'
|
||||
import { type RuntimeConfig } from './types'
|
||||
|
||||
/**
 * Plugin-aware runtime executor bound to a single provider.
 * Wraps a PluginEngine, resolves string model ids to concrete models via
 * internal plugins, and delegates the actual work to the AI SDK functions
 * (streamText / generateText / generateObject / streamObject / generateImage).
 */
export class RuntimeExecutor<T extends ProviderId = ProviderId> {
  public pluginEngine: PluginEngine<T>
  // private options: ProviderSettingsMap[T]
  private config: RuntimeConfig<T>

  constructor(config: RuntimeConfig<T>) {
    // if (!isProviderSupported(config.providerId)) {
    //   throw new Error(`Unsupported provider: ${config.providerId}`)
    // }

    // Keep the full runtime config; providerId/providerSettings are read later
    // by resolveModel / resolveImageModel.
    // this.options = config.options
    this.config = config
    // Create the plugin engine with any user-supplied plugins.
    this.pluginEngine = new PluginEngine(config.providerId, config.plugins || [])
  }

  // Internal plugin: resolves a string model id to a language model, threading
  // through any caller-provided middlewares.
  private createResolveModelPlugin(middlewares?: LanguageModelV2Middleware[]) {
    return definePlugin({
      name: '_internal_resolveModel',
      enforce: 'post',

      resolveModel: async (modelId: string) => {
        // Note: extraModelConfig is not supported for now; removed in the new architecture.
        return await this.resolveModel(modelId, middlewares)
      }
    })
  }

  // Internal plugin: resolves a string model id to an image model.
  private createResolveImageModelPlugin() {
    return definePlugin({
      name: '_internal_resolveImageModel',
      enforce: 'post',

      resolveModel: async (modelId: string) => {
        return await this.resolveImageModel(modelId)
      }
    })
  }

  // Internal plugin: exposes this executor on the request context so plugins
  // (e.g. recursive calls) can reach back into it.
  private createConfigureContextPlugin() {
    return definePlugin({
      name: '_internal_configureContext',
      configureContext: async (context: AiRequestContext) => {
        context.executor = this
      }
    })
  }

  // === High-level overloads: accept a string model id or a concrete model ===

  /**
   * Streaming text generation.
   *
   * NOTE(review): every call appends the internal plugins via usePlugins (same
   * pattern in generateText/generateObject/streamObject/generateImage). Verify
   * that PluginManager.use deduplicates by plugin name; otherwise repeated
   * calls on one executor accumulate duplicate internal plugins.
   */
  async streamText(
    params: Parameters<typeof streamText>[0],
    options?: {
      middlewares?: LanguageModelV2Middleware[]
    }
  ): Promise<ReturnType<typeof streamText>> {
    const { model, ...restParams } = params

    // Choose the internal plugin set based on whether the model still needs resolving.
    if (typeof model === 'string') {
      this.pluginEngine.usePlugins([
        this.createResolveModelPlugin(options?.middlewares),
        this.createConfigureContextPlugin()
      ])
    } else {
      this.pluginEngine.usePlugins([this.createConfigureContextPlugin()])
    }

    return this.pluginEngine.executeStreamWithPlugins(
      'streamText',
      model,
      restParams,
      async (resolvedModel, transformedParams, streamTransforms) => {
        // Caller-supplied experimental_transform wins; otherwise use any
        // transforms contributed by plugins (if there are any).
        const experimental_transform =
          params?.experimental_transform ?? (streamTransforms.length > 0 ? streamTransforms : undefined)

        const finalParams = {
          model: resolvedModel,
          ...transformedParams,
          experimental_transform
        } as Parameters<typeof streamText>[0]

        return await streamText(finalParams)
      }
    )
  }

  // === Other method overloads ===

  /**
   * Text generation (non-streaming). Same plugin wiring as streamText.
   */
  async generateText(
    params: Parameters<typeof generateText>[0],
    options?: {
      middlewares?: LanguageModelV2Middleware[]
    }
  ): Promise<ReturnType<typeof generateText>> {
    const { model, ...restParams } = params

    // Choose the internal plugin set based on the model argument's type.
    if (typeof model === 'string') {
      this.pluginEngine.usePlugins([
        this.createResolveModelPlugin(options?.middlewares),
        this.createConfigureContextPlugin()
      ])
    } else {
      this.pluginEngine.usePlugins([this.createConfigureContextPlugin()])
    }

    return this.pluginEngine.executeWithPlugins(
      'generateText',
      model,
      restParams,
      async (resolvedModel, transformedParams) =>
        generateText({ model: resolvedModel, ...transformedParams } as Parameters<typeof generateText>[0])
    )
  }

  /**
   * Structured object generation (non-streaming).
   */
  async generateObject(
    params: Parameters<typeof generateObject>[0],
    options?: {
      middlewares?: LanguageModelV2Middleware[]
    }
  ): Promise<ReturnType<typeof generateObject>> {
    const { model, ...restParams } = params

    // Choose the internal plugin set based on the model argument's type.
    if (typeof model === 'string') {
      this.pluginEngine.usePlugins([
        this.createResolveModelPlugin(options?.middlewares),
        this.createConfigureContextPlugin()
      ])
    } else {
      this.pluginEngine.usePlugins([this.createConfigureContextPlugin()])
    }

    return this.pluginEngine.executeWithPlugins(
      'generateObject',
      model,
      restParams,
      async (resolvedModel, transformedParams) =>
        generateObject({ model: resolvedModel, ...transformedParams } as Parameters<typeof generateObject>[0])
    )
  }

  /**
   * Streaming structured object generation.
   */
  async streamObject(
    params: Parameters<typeof streamObject>[0],
    options?: {
      middlewares?: LanguageModelV2Middleware[]
    }
  ): Promise<ReturnType<typeof streamObject>> {
    const { model, ...restParams } = params

    // Choose the internal plugin set based on the model argument's type.
    if (typeof model === 'string') {
      this.pluginEngine.usePlugins([
        this.createResolveModelPlugin(options?.middlewares),
        this.createConfigureContextPlugin()
      ])
    } else {
      this.pluginEngine.usePlugins([this.createConfigureContextPlugin()])
    }

    return this.pluginEngine.executeWithPlugins(
      'streamObject',
      model,
      restParams,
      async (resolvedModel, transformedParams) =>
        streamObject({ model: resolvedModel, ...transformedParams } as Parameters<typeof streamObject>[0])
    )
  }

  /**
   * Image generation.
   *
   * Any Error thrown along the way (including an ImageModelResolutionError
   * from resolution) is wrapped in an ImageGenerationError that carries this
   * executor's providerId and the requested modelId, with the original error
   * preserved as `cause`.
   */
  async generateImage(
    params: Omit<Parameters<typeof generateImage>[0], 'model'> & { model: string | ImageModelV2 }
  ): Promise<ReturnType<typeof generateImage>> {
    try {
      const { model, ...restParams } = params

      // Choose the internal plugin set based on the model argument's type.
      if (typeof model === 'string') {
        this.pluginEngine.usePlugins([this.createResolveImageModelPlugin(), this.createConfigureContextPlugin()])
      } else {
        this.pluginEngine.usePlugins([this.createConfigureContextPlugin()])
      }

      return await this.pluginEngine.executeImageWithPlugins(
        'generateImage',
        model,
        restParams,
        async (resolvedModel, transformedParams) => {
          return await generateImage({ model: resolvedModel, ...transformedParams })
        }
      )
    } catch (error) {
      if (error instanceof Error) {
        const modelId = typeof params.model === 'string' ? params.model : params.model.modelId
        throw new ImageGenerationError(
          `Failed to generate image: ${error.message}`,
          this.config.providerId,
          modelId,
          error
        )
      }
      // Non-Error throwables are re-thrown untouched.
      throw error
    }
  }

  // === Helpers ===

  /**
   * Resolve a model: a string id is resolved through the global model
   * resolver; an already-constructed model is returned as-is.
   */
  private async resolveModel(
    modelOrId: LanguageModel,
    middlewares?: LanguageModelV2Middleware[]
  ): Promise<LanguageModelV2> {
    if (typeof modelOrId === 'string') {
      // String model id: resolve via the new ModelResolver with full parameters.
      return await globalModelResolver.resolveLanguageModel(
        modelOrId, // supports both 'gpt-4' and 'aihubmix:anthropic:claude-3.5-sonnet'
        this.config.providerId, // fallback provider
        this.config.providerSettings, // provider options
        middlewares // middleware array
      )
    } else {
      // Already a model instance: return unchanged.
      return modelOrId
    }
  }

  /**
   * Resolve an image model: a string id goes through the global model
   * resolver; an already-constructed model is returned as-is.
   * Failures are normalized into ImageModelResolutionError.
   */
  private async resolveImageModel(modelOrId: ImageModelV2 | string): Promise<ImageModelV2> {
    try {
      if (typeof modelOrId === 'string') {
        // String model id: resolve via the new ModelResolver.
        return await globalModelResolver.resolveImageModel(
          modelOrId, // supports both 'dall-e-3' and 'aihubmix:openai:dall-e-3'
          this.config.providerId // fallback provider
        )
      } else {
        // Already a model instance: return unchanged.
        return modelOrId
      }
    } catch (error) {
      throw new ImageModelResolutionError(
        typeof modelOrId === 'string' ? modelOrId : modelOrId.modelId,
        this.config.providerId,
        error instanceof Error ? error : undefined
      )
    }
  }

  // === Static factories ===

  /**
   * Create an executor with compile-time type safety for a known provider id.
   */
  static create<T extends ProviderId>(
    providerId: T,
    options: ModelConfig<T>['providerSettings'],
    plugins?: AiPlugin[]
  ): RuntimeExecutor<T> {
    return new RuntimeExecutor({
      providerId,
      providerSettings: options,
      plugins
    })
  }

  /**
   * Create an executor for an OpenAI-compatible endpoint.
   */
  static createOpenAICompatible(
    options: ModelConfig<'openai-compatible'>['providerSettings'],
    plugins: AiPlugin[] = []
  ): RuntimeExecutor<'openai-compatible'> {
    return new RuntimeExecutor({
      providerId: 'openai-compatible',
      providerSettings: options,
      plugins
    })
  }
}
|
||||
117
packages/aiCore/src/core/runtime/index.ts
Normal file
117
packages/aiCore/src/core/runtime/index.ts
Normal file
@@ -0,0 +1,117 @@
|
||||
/**
|
||||
* Runtime 模块导出
|
||||
* 专注于运行时插件化AI调用处理
|
||||
*/
|
||||
|
||||
// 主要的运行时执行器
|
||||
export { RuntimeExecutor } from './executor'
|
||||
|
||||
// 导出类型
|
||||
export type { RuntimeConfig } from './types'
|
||||
|
||||
// === 便捷工厂函数 ===
|
||||
|
||||
import { LanguageModelV2Middleware } from '@ai-sdk/provider'
|
||||
|
||||
import { type AiPlugin } from '../plugins'
|
||||
import { type ProviderId, type ProviderSettingsMap } from '../providers/types'
|
||||
import { RuntimeExecutor } from './executor'
|
||||
|
||||
/**
|
||||
* 创建运行时执行器 - 支持类型安全的已知provider
|
||||
*/
|
||||
export function createExecutor<T extends ProviderId>(
|
||||
providerId: T,
|
||||
options: ProviderSettingsMap[T] & { mode?: 'chat' | 'responses' },
|
||||
plugins?: AiPlugin[]
|
||||
): RuntimeExecutor<T> {
|
||||
return RuntimeExecutor.create(providerId, options, plugins)
|
||||
}
|
||||
|
||||
/**
|
||||
* 创建OpenAI Compatible执行器
|
||||
*/
|
||||
export function createOpenAICompatibleExecutor(
|
||||
options: ProviderSettingsMap['openai-compatible'] & { mode?: 'chat' | 'responses' },
|
||||
plugins: AiPlugin[] = []
|
||||
): RuntimeExecutor<'openai-compatible'> {
|
||||
return RuntimeExecutor.createOpenAICompatible(options, plugins)
|
||||
}
|
||||
|
||||
// === 直接调用API(无需创建executor实例)===
|
||||
|
||||
/**
|
||||
* 直接流式文本生成 - 支持middlewares
|
||||
*/
|
||||
export async function streamText<T extends ProviderId>(
|
||||
providerId: T,
|
||||
options: ProviderSettingsMap[T] & { mode?: 'chat' | 'responses' },
|
||||
params: Parameters<RuntimeExecutor<T>['streamText']>[0],
|
||||
plugins?: AiPlugin[],
|
||||
middlewares?: LanguageModelV2Middleware[]
|
||||
): Promise<ReturnType<RuntimeExecutor<T>['streamText']>> {
|
||||
const executor = createExecutor(providerId, options, plugins)
|
||||
return executor.streamText(params, { middlewares })
|
||||
}
|
||||
|
||||
/**
|
||||
* 直接生成文本 - 支持middlewares
|
||||
*/
|
||||
export async function generateText<T extends ProviderId>(
|
||||
providerId: T,
|
||||
options: ProviderSettingsMap[T] & { mode?: 'chat' | 'responses' },
|
||||
params: Parameters<RuntimeExecutor<T>['generateText']>[0],
|
||||
plugins?: AiPlugin[],
|
||||
middlewares?: LanguageModelV2Middleware[]
|
||||
): Promise<ReturnType<RuntimeExecutor<T>['generateText']>> {
|
||||
const executor = createExecutor(providerId, options, plugins)
|
||||
return executor.generateText(params, { middlewares })
|
||||
}
|
||||
|
||||
/**
|
||||
* 直接生成结构化对象 - 支持middlewares
|
||||
*/
|
||||
export async function generateObject<T extends ProviderId>(
|
||||
providerId: T,
|
||||
options: ProviderSettingsMap[T] & { mode?: 'chat' | 'responses' },
|
||||
params: Parameters<RuntimeExecutor<T>['generateObject']>[0],
|
||||
plugins?: AiPlugin[],
|
||||
middlewares?: LanguageModelV2Middleware[]
|
||||
): Promise<ReturnType<RuntimeExecutor<T>['generateObject']>> {
|
||||
const executor = createExecutor(providerId, options, plugins)
|
||||
return executor.generateObject(params, { middlewares })
|
||||
}
|
||||
|
||||
/**
|
||||
* 直接流式生成结构化对象 - 支持middlewares
|
||||
*/
|
||||
export async function streamObject<T extends ProviderId>(
|
||||
providerId: T,
|
||||
options: ProviderSettingsMap[T] & { mode?: 'chat' | 'responses' },
|
||||
params: Parameters<RuntimeExecutor<T>['streamObject']>[0],
|
||||
plugins?: AiPlugin[],
|
||||
middlewares?: LanguageModelV2Middleware[]
|
||||
): Promise<ReturnType<RuntimeExecutor<T>['streamObject']>> {
|
||||
const executor = createExecutor(providerId, options, plugins)
|
||||
return executor.streamObject(params, { middlewares })
|
||||
}
|
||||
|
||||
/**
|
||||
* 直接生成图像 - 支持middlewares
|
||||
*/
|
||||
export async function generateImage<T extends ProviderId>(
|
||||
providerId: T,
|
||||
options: ProviderSettingsMap[T] & { mode?: 'chat' | 'responses' },
|
||||
params: Parameters<RuntimeExecutor<T>['generateImage']>[0],
|
||||
plugins?: AiPlugin[]
|
||||
): Promise<ReturnType<RuntimeExecutor<T>['generateImage']>> {
|
||||
const executor = createExecutor(providerId, options, plugins)
|
||||
return executor.generateImage(params)
|
||||
}
|
||||
|
||||
// === Agent 功能预留 ===
|
||||
// 未来将在 ../agents/ 文件夹中添加:
|
||||
// - AgentExecutor.ts
|
||||
// - WorkflowManager.ts
|
||||
// - ConversationManager.ts
|
||||
// 并在此处导出相关API
|
||||
290
packages/aiCore/src/core/runtime/pluginEngine.ts
Normal file
290
packages/aiCore/src/core/runtime/pluginEngine.ts
Normal file
@@ -0,0 +1,290 @@
|
||||
/* eslint-disable @eslint-react/naming-convention/context-name */
|
||||
import { ImageModelV2 } from '@ai-sdk/provider'
|
||||
import { LanguageModel } from 'ai'
|
||||
|
||||
import { type AiPlugin, createContext, PluginManager } from '../plugins'
|
||||
import { type ProviderId } from '../providers/types'
|
||||
|
||||
/**
|
||||
* 插件增强的 AI 客户端
|
||||
* 专注于插件处理,不暴露用户API
|
||||
*/
|
||||
export class PluginEngine<T extends ProviderId = ProviderId> {
|
||||
private pluginManager: PluginManager
|
||||
|
||||
constructor(
|
||||
private readonly providerId: T,
|
||||
// private readonly options: ProviderSettingsMap[T],
|
||||
plugins: AiPlugin[] = []
|
||||
) {
|
||||
this.pluginManager = new PluginManager(plugins)
|
||||
}
|
||||
|
||||
/**
|
||||
* 添加插件
|
||||
*/
|
||||
use(plugin: AiPlugin): this {
|
||||
this.pluginManager.use(plugin)
|
||||
return this
|
||||
}
|
||||
|
||||
/**
|
||||
* 批量添加插件
|
||||
*/
|
||||
usePlugins(plugins: AiPlugin[]): this {
|
||||
plugins.forEach((plugin) => this.use(plugin))
|
||||
return this
|
||||
}
|
||||
|
||||
/**
|
||||
* 移除插件
|
||||
*/
|
||||
removePlugin(pluginName: string): this {
|
||||
this.pluginManager.remove(pluginName)
|
||||
return this
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取插件统计
|
||||
*/
|
||||
getPluginStats() {
|
||||
return this.pluginManager.getStats()
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取所有插件
|
||||
*/
|
||||
getPlugins() {
|
||||
return this.pluginManager.getPlugins()
|
||||
}
|
||||
|
||||
/**
|
||||
* 执行带插件的操作(非流式)
|
||||
* 提供给AiExecutor使用
|
||||
*/
|
||||
async executeWithPlugins<TParams, TResult>(
|
||||
methodName: string,
|
||||
model: LanguageModel,
|
||||
params: TParams,
|
||||
executor: (model: LanguageModel, transformedParams: TParams) => Promise<TResult>,
|
||||
_context?: ReturnType<typeof createContext>
|
||||
): Promise<TResult> {
|
||||
// 统一处理模型解析
|
||||
let resolvedModel: LanguageModel | undefined
|
||||
let modelId: string
|
||||
|
||||
if (typeof model === 'string') {
|
||||
// 字符串:需要通过插件解析
|
||||
modelId = model
|
||||
} else {
|
||||
// 模型对象:直接使用
|
||||
resolvedModel = model
|
||||
modelId = model.modelId
|
||||
}
|
||||
|
||||
// 使用正确的createContext创建请求上下文
|
||||
const context = _context ? _context : createContext(this.providerId, modelId, params)
|
||||
|
||||
// 🔥 为上下文添加递归调用能力
|
||||
context.recursiveCall = async (newParams: any): Promise<TResult> => {
|
||||
// 递归调用自身,重新走完整的插件流程
|
||||
context.isRecursiveCall = true
|
||||
const result = await this.executeWithPlugins(methodName, model, newParams, executor, context)
|
||||
context.isRecursiveCall = false
|
||||
return result
|
||||
}
|
||||
|
||||
try {
|
||||
// 0. 配置上下文
|
||||
await this.pluginManager.executeConfigureContext(context)
|
||||
|
||||
// 1. 触发请求开始事件
|
||||
await this.pluginManager.executeParallel('onRequestStart', context)
|
||||
|
||||
// 2. 解析模型(如果是字符串)
|
||||
if (typeof model === 'string') {
|
||||
const resolved = await this.pluginManager.executeFirst<LanguageModel>('resolveModel', modelId, context)
|
||||
if (!resolved) {
|
||||
throw new Error(`Failed to resolve model: ${modelId}`)
|
||||
}
|
||||
resolvedModel = resolved
|
||||
}
|
||||
|
||||
if (!resolvedModel) {
|
||||
throw new Error(`Model resolution failed: no model available`)
|
||||
}
|
||||
|
||||
// 3. 转换请求参数
|
||||
const transformedParams = await this.pluginManager.executeSequential('transformParams', params, context)
|
||||
|
||||
// 4. 执行具体的 API 调用
|
||||
const result = await executor(resolvedModel, transformedParams)
|
||||
|
||||
// 5. 转换结果(对于非流式调用)
|
||||
const transformedResult = await this.pluginManager.executeSequential('transformResult', result, context)
|
||||
|
||||
// 6. 触发完成事件
|
||||
await this.pluginManager.executeParallel('onRequestEnd', context, transformedResult)
|
||||
|
||||
return transformedResult
|
||||
} catch (error) {
|
||||
// 7. 触发错误事件
|
||||
await this.pluginManager.executeParallel('onError', context, undefined, error as Error)
|
||||
throw error
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 执行带插件的图像生成操作
|
||||
* 提供给AiExecutor使用
|
||||
*/
|
||||
async executeImageWithPlugins<TParams, TResult>(
|
||||
methodName: string,
|
||||
model: ImageModelV2 | string,
|
||||
params: TParams,
|
||||
executor: (model: ImageModelV2, transformedParams: TParams) => Promise<TResult>,
|
||||
_context?: ReturnType<typeof createContext>
|
||||
): Promise<TResult> {
|
||||
// 统一处理模型解析
|
||||
let resolvedModel: ImageModelV2 | undefined
|
||||
let modelId: string
|
||||
|
||||
if (typeof model === 'string') {
|
||||
// 字符串:需要通过插件解析
|
||||
modelId = model
|
||||
} else {
|
||||
// 模型对象:直接使用
|
||||
resolvedModel = model
|
||||
modelId = model.modelId
|
||||
}
|
||||
|
||||
// 使用正确的createContext创建请求上下文
|
||||
const context = _context ? _context : createContext(this.providerId, modelId, params)
|
||||
|
||||
// 🔥 为上下文添加递归调用能力
|
||||
context.recursiveCall = async (newParams: any): Promise<TResult> => {
|
||||
// 递归调用自身,重新走完整的插件流程
|
||||
context.isRecursiveCall = true
|
||||
const result = await this.executeImageWithPlugins(methodName, model, newParams, executor, context)
|
||||
context.isRecursiveCall = false
|
||||
return result
|
||||
}
|
||||
|
||||
try {
|
||||
// 0. 配置上下文
|
||||
await this.pluginManager.executeConfigureContext(context)
|
||||
|
||||
// 1. 触发请求开始事件
|
||||
await this.pluginManager.executeParallel('onRequestStart', context)
|
||||
|
||||
// 2. 解析模型(如果是字符串)
|
||||
if (typeof model === 'string') {
|
||||
const resolved = await this.pluginManager.executeFirst<ImageModelV2>('resolveModel', modelId, context)
|
||||
if (!resolved) {
|
||||
throw new Error(`Failed to resolve image model: ${modelId}`)
|
||||
}
|
||||
resolvedModel = resolved
|
||||
}
|
||||
|
||||
if (!resolvedModel) {
|
||||
throw new Error(`Image model resolution failed: no model available`)
|
||||
}
|
||||
|
||||
// 3. 转换请求参数
|
||||
const transformedParams = await this.pluginManager.executeSequential('transformParams', params, context)
|
||||
|
||||
// 4. 执行具体的 API 调用
|
||||
const result = await executor(resolvedModel, transformedParams)
|
||||
|
||||
// 5. 转换结果
|
||||
const transformedResult = await this.pluginManager.executeSequential('transformResult', result, context)
|
||||
|
||||
// 6. 触发完成事件
|
||||
await this.pluginManager.executeParallel('onRequestEnd', context, transformedResult)
|
||||
|
||||
return transformedResult
|
||||
} catch (error) {
|
||||
// 7. 触发错误事件
|
||||
await this.pluginManager.executeParallel('onError', context, undefined, error as Error)
|
||||
throw error
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 执行流式调用的通用逻辑(支持流转换器)
|
||||
* 提供给AiExecutor使用
|
||||
*/
|
||||
async executeStreamWithPlugins<TParams, TResult>(
|
||||
methodName: string,
|
||||
model: LanguageModel,
|
||||
params: TParams,
|
||||
executor: (model: LanguageModel, transformedParams: TParams, streamTransforms: any[]) => Promise<TResult>,
|
||||
_context?: ReturnType<typeof createContext>
|
||||
): Promise<TResult> {
|
||||
// 统一处理模型解析
|
||||
let resolvedModel: LanguageModel | undefined
|
||||
let modelId: string
|
||||
|
||||
if (typeof model === 'string') {
|
||||
// 字符串:需要通过插件解析
|
||||
modelId = model
|
||||
} else {
|
||||
// 模型对象:直接使用
|
||||
resolvedModel = model
|
||||
modelId = model.modelId
|
||||
}
|
||||
|
||||
// 创建请求上下文
|
||||
const context = _context ? _context : createContext(this.providerId, modelId, params)
|
||||
|
||||
// 🔥 为上下文添加递归调用能力
|
||||
context.recursiveCall = async (newParams: any): Promise<TResult> => {
|
||||
// 递归调用自身,重新走完整的插件流程
|
||||
context.isRecursiveCall = true
|
||||
const result = await this.executeStreamWithPlugins(methodName, model, newParams, executor, context)
|
||||
context.isRecursiveCall = false
|
||||
return result
|
||||
}
|
||||
|
||||
try {
|
||||
// 0. 配置上下文
|
||||
await this.pluginManager.executeConfigureContext(context)
|
||||
|
||||
// 1. 触发请求开始事件
|
||||
await this.pluginManager.executeParallel('onRequestStart', context)
|
||||
|
||||
// 2. 解析模型(如果是字符串)
|
||||
if (typeof model === 'string') {
|
||||
const resolved = await this.pluginManager.executeFirst<LanguageModel>('resolveModel', modelId, context)
|
||||
if (!resolved) {
|
||||
throw new Error(`Failed to resolve model: ${modelId}`)
|
||||
}
|
||||
resolvedModel = resolved
|
||||
}
|
||||
|
||||
if (!resolvedModel) {
|
||||
throw new Error(`Model resolution failed: no model available`)
|
||||
}
|
||||
|
||||
// 3. 转换请求参数
|
||||
const transformedParams = await this.pluginManager.executeSequential('transformParams', params, context)
|
||||
|
||||
// 4. 收集流转换器
|
||||
const streamTransforms = this.pluginManager.collectStreamTransforms(transformedParams, context)
|
||||
|
||||
// 5. 执行流式 API 调用
|
||||
const result = await executor(resolvedModel, transformedParams, streamTransforms)
|
||||
|
||||
const transformedResult = await this.pluginManager.executeSequential('transformResult', result, context)
|
||||
|
||||
// 6. 触发完成事件(注意:对于流式调用,这里触发的是开始流式响应的事件)
|
||||
await this.pluginManager.executeParallel('onRequestEnd', context, transformedResult)
|
||||
|
||||
return transformedResult
|
||||
} catch (error) {
|
||||
// 7. 触发错误事件
|
||||
await this.pluginManager.executeParallel('onError', context, undefined, error as Error)
|
||||
throw error
|
||||
}
|
||||
}
|
||||
}
|
||||
15
packages/aiCore/src/core/runtime/types.ts
Normal file
15
packages/aiCore/src/core/runtime/types.ts
Normal file
@@ -0,0 +1,15 @@
|
||||
/**
|
||||
* Runtime 层类型定义
|
||||
*/
|
||||
import { type ModelConfig } from '../models/types'
|
||||
import { type AiPlugin } from '../plugins'
|
||||
import { type ProviderId } from '../providers/types'
|
||||
|
||||
/**
|
||||
* 运行时执行器配置
|
||||
*/
|
||||
export interface RuntimeConfig<T extends ProviderId = ProviderId> {
|
||||
providerId: T
|
||||
providerSettings: ModelConfig<T>['providerSettings'] & { mode?: 'chat' | 'responses' }
|
||||
plugins?: AiPlugin[]
|
||||
}
|
||||
46
packages/aiCore/src/index.ts
Normal file
46
packages/aiCore/src/index.ts
Normal file
@@ -0,0 +1,46 @@
|
||||
/**
|
||||
* Cherry Studio AI Core Package
|
||||
* 基于 Vercel AI SDK 的统一 AI Provider 接口
|
||||
*/
|
||||
|
||||
// 导入内部使用的类和函数
|
||||
|
||||
// ==================== 主要用户接口 ====================
|
||||
export {
|
||||
createExecutor,
|
||||
createOpenAICompatibleExecutor,
|
||||
generateImage,
|
||||
generateObject,
|
||||
generateText,
|
||||
streamText
|
||||
} from './core/runtime'
|
||||
|
||||
// ==================== 高级API ====================
|
||||
export { globalModelResolver as modelResolver } from './core/models'
|
||||
|
||||
// ==================== 插件系统 ====================
|
||||
export type { AiPlugin, AiRequestContext, HookResult, PluginManagerConfig } from './core/plugins'
|
||||
export { createContext, definePlugin, PluginManager } from './core/plugins'
|
||||
// export { createPromptToolUsePlugin, webSearchPlugin } from './core/plugins/built-in'
|
||||
export { PluginEngine } from './core/runtime/pluginEngine'
|
||||
|
||||
// ==================== AI SDK 常用类型导出 ====================
|
||||
// 直接导出 AI SDK 的常用类型,方便使用
|
||||
export type { LanguageModelV2Middleware, LanguageModelV2StreamPart } from '@ai-sdk/provider'
|
||||
export type { ToolCall } from '@ai-sdk/provider-utils'
|
||||
export type { ReasoningPart } from '@ai-sdk/provider-utils'
|
||||
|
||||
// ==================== 选项 ====================
|
||||
export {
|
||||
createAnthropicOptions,
|
||||
createGoogleOptions,
|
||||
createOpenAIOptions,
|
||||
type ExtractProviderOptions,
|
||||
mergeProviderOptions,
|
||||
type ProviderOptionsMap,
|
||||
type TypedProviderOptions
|
||||
} from './core/options'
|
||||
|
||||
// ==================== 包信息 ====================
|
||||
export const AI_CORE_VERSION = '1.0.0'
|
||||
export const AI_CORE_NAME = '@cherrystudio/ai-core'
|
||||
2
packages/aiCore/src/types.ts
Normal file
2
packages/aiCore/src/types.ts
Normal file
@@ -0,0 +1,2 @@
|
||||
// 重新导出插件类型
|
||||
export type { AiPlugin, AiRequestContext, HookResult, PluginManagerConfig } from './core/plugins/types'
|
||||
26
packages/aiCore/tsconfig.json
Normal file
26
packages/aiCore/tsconfig.json
Normal file
@@ -0,0 +1,26 @@
|
||||
{
|
||||
"compilerOptions": {
|
||||
"target": "ES2020",
|
||||
"module": "ESNext",
|
||||
"moduleResolution": "bundler",
|
||||
"declaration": true,
|
||||
"outDir": "./dist",
|
||||
"rootDir": "./src",
|
||||
"strict": true,
|
||||
"esModuleInterop": true,
|
||||
"skipLibCheck": true,
|
||||
"forceConsistentCasingInFileNames": true,
|
||||
"resolveJsonModule": true,
|
||||
"allowSyntheticDefaultImports": true,
|
||||
"noEmitOnError": false,
|
||||
"experimentalDecorators": true,
|
||||
"emitDecoratorMetadata": true
|
||||
},
|
||||
"include": [
|
||||
"src/**/*"
|
||||
],
|
||||
"exclude": [
|
||||
"node_modules",
|
||||
"dist"
|
||||
]
|
||||
}
|
||||
14
packages/aiCore/tsdown.config.ts
Normal file
14
packages/aiCore/tsdown.config.ts
Normal file
@@ -0,0 +1,14 @@
|
||||
import { defineConfig } from 'tsdown'
|
||||
|
||||
export default defineConfig({
|
||||
entry: {
|
||||
index: 'src/index.ts',
|
||||
'built-in/plugins/index': 'src/core/plugins/built-in/index.ts',
|
||||
'provider/index': 'src/core/providers/index.ts'
|
||||
},
|
||||
outDir: 'dist',
|
||||
format: ['esm', 'cjs'],
|
||||
clean: true,
|
||||
dts: true,
|
||||
tsconfig: 'tsconfig.json'
|
||||
})
|
||||
15
packages/aiCore/vitest.config.ts
Normal file
15
packages/aiCore/vitest.config.ts
Normal file
@@ -0,0 +1,15 @@
|
||||
import { defineConfig } from 'vitest/config'
|
||||
|
||||
export default defineConfig({
|
||||
test: {
|
||||
globals: true
|
||||
},
|
||||
resolve: {
|
||||
alias: {
|
||||
'@': './src'
|
||||
}
|
||||
},
|
||||
esbuild: {
|
||||
target: 'node18'
|
||||
}
|
||||
})
|
||||
@@ -250,6 +250,7 @@ export enum IpcChannel {
|
||||
|
||||
// Provider
|
||||
Provider_AddKey = 'provider:add-key',
|
||||
Provider_GetClaudeCodePort = 'provider:get-claude-code-port',
|
||||
|
||||
//Selection Assistant
|
||||
Selection_TextSelected = 'selection:text-selected',
|
||||
|
||||
@@ -2089,7 +2089,7 @@
|
||||
"Design",
|
||||
"Education"
|
||||
],
|
||||
"prompt": "I want you to act as a Graphviz DOT generator, an expert to create meaningful diagrams. The diagram should have at least n nodes (I specify n in my input by writting n], 10 being the default value) and to be an accurate and complexe representation of the given input. Each node is indexed by a number to reduce the size of the output, should not include any styling, and with layout=neato, overlap=false, node shape=rectangle] as parameters. The code should be valid, bugless and returned on a single line, without any explanation. Provide a clear and organized diagram, the relationships between the nodes have to make sense for an expert of that input. My first diagram is: \"The water cycle 8]\".\n\n",
|
||||
"prompt": "I want you to act as a Graphviz DOT generator, an expert to create meaningful diagrams. The diagram should have at least n nodes (I specify n in my input by writing n], 10 being the default value) and to be an accurate and complex representation of the given input. Each node is indexed by a number to reduce the size of the output, should not include any styling, and with layout=neato, overlap=false, node shape=rectangle] as parameters. The code should be valid, bugless and returned on a single line, without any explanation. Provide a clear and organized diagram, the relationships between the nodes have to make sense for an expert of that input. My first diagram is: \"The water cycle 8]\".\n\n",
|
||||
"description": "Generate meaningful charts."
|
||||
},
|
||||
{
|
||||
@@ -2148,7 +2148,7 @@
|
||||
"Career",
|
||||
"Business"
|
||||
],
|
||||
"prompt": "Please acknowledge my following request. Please respond to me as a product manager. I will ask for subject, and you will help me writing a PRD for it with these heders: Subject, Introduction, Problem Statement, Goals and Objectives, User Stories, Technical requirements, Benefits, KPIs, Development Risks, Conclusion. Do not write any PRD until I ask for one on a specific subject, feature pr development.\n\n",
|
||||
"prompt": "Please acknowledge my following request. Please respond to me as a product manager. I will ask for subject, and you will help me writing a PRD for it with these headers: Subject, Introduction, Problem Statement, Goals and Objectives, User Stories, Technical requirements, Benefits, KPIs, Development Risks, Conclusion. Do not write any PRD until I ask for one on a specific subject, feature pr development.\n\n",
|
||||
"description": "Help draft the Product Requirements Document."
|
||||
},
|
||||
{
|
||||
@@ -2159,7 +2159,7 @@
|
||||
"Entertainment",
|
||||
"General"
|
||||
],
|
||||
"prompt": "I want you to act as a drunk person. You will only answer like a very drunk person texting and nothing else. Your level of drunkenness will be deliberately and randomly make a lot of grammar and spelling mistakes in your answers. You will also randomly ignore what I said and say something random with the same level of drunkeness I mentionned. Do not write explanations on replies. My first sentence is \"how are you?",
|
||||
"prompt": "I want you to act as a drunk person. You will only answer like a very drunk person texting and nothing else. Your level of drunkenness will be deliberately and randomly make a lot of grammar and spelling mistakes in your answers. You will also randomly ignore what I said and say something random with the same level of drunkenness I mentioned. Do not write explanations on replies. My first sentence is \"how are you?",
|
||||
"description": "Mimic the speech pattern of a drunk person."
|
||||
},
|
||||
{
|
||||
@@ -3517,7 +3517,7 @@
|
||||
"Tools",
|
||||
"Copywriting"
|
||||
],
|
||||
"prompt": "I want you to act as a scientific manuscript matcher. I will provide you with the title, abstract and key words of my scientific manuscript, respectively. Your task is analyzing my title, abstract and key words synthetically to find the most related, reputable journals for potential publication of my research based on an analysis of tens of millions of citation connections in database, such as Web of Science, Pubmed, Scopus, ScienceDirect and so on. You only need to provide me with the 15 most suitable journals. Your reply should include the name of journal, the cooresponding match score (The full score is ten). I want you to reply in text-based excel sheet and sort by matching scores in reverse order.\nMy title is \"XXX\" My abstract is \"XXX\" My key words are \"XXX\"\n\n",
|
||||
"prompt": "I want you to act as a scientific manuscript matcher. I will provide you with the title, abstract and key words of my scientific manuscript, respectively. Your task is analyzing my title, abstract and key words synthetically to find the most related, reputable journals for potential publication of my research based on an analysis of tens of millions of citation connections in database, such as Web of Science, Pubmed, Scopus, ScienceDirect and so on. You only need to provide me with the 15 most suitable journals. Your reply should include the name of journal, the corresponding match score (The full score is ten). I want you to reply in text-based excel sheet and sort by matching scores in reverse order.\nMy title is \"XXX\" My abstract is \"XXX\" My key words are \"XXX\"\n\n",
|
||||
"description": ""
|
||||
},
|
||||
{
|
||||
|
||||
@@ -7,15 +7,12 @@ const allArm64 = {
|
||||
'@img/sharp-darwin-arm64': '0.34.3',
|
||||
'@img/sharp-win32-arm64': '0.34.3',
|
||||
'@img/sharp-linux-arm64': '0.34.3',
|
||||
'@img/sharp-linuxmusl-arm64': '0.34.3',
|
||||
|
||||
'@img/sharp-libvips-darwin-arm64': '1.2.0',
|
||||
'@img/sharp-libvips-linux-arm64': '1.2.0',
|
||||
'@img/sharp-libvips-linuxmusl-arm64': '1.2.0',
|
||||
|
||||
'@libsql/darwin-arm64': '0.4.7',
|
||||
'@libsql/linux-arm64-gnu': '0.4.7',
|
||||
'@libsql/linux-arm64-musl': '0.4.7',
|
||||
'@strongtz/win32-arm64-msvc': '0.4.7',
|
||||
|
||||
'@napi-rs/system-ocr-darwin-arm64': '1.0.2',
|
||||
@@ -25,16 +22,13 @@ const allArm64 = {
|
||||
const allX64 = {
|
||||
'@img/sharp-darwin-x64': '0.34.3',
|
||||
'@img/sharp-linux-x64': '0.34.3',
|
||||
'@img/sharp-linuxmusl-x64': '0.34.3',
|
||||
'@img/sharp-win32-x64': '0.34.3',
|
||||
|
||||
'@img/sharp-libvips-darwin-x64': '1.2.0',
|
||||
'@img/sharp-libvips-linux-x64': '1.2.0',
|
||||
'@img/sharp-libvips-linuxmusl-x64': '1.2.0',
|
||||
|
||||
'@libsql/darwin-x64': '0.4.7',
|
||||
'@libsql/linux-x64-gnu': '0.4.7',
|
||||
'@libsql/linux-x64-musl': '0.4.7',
|
||||
'@libsql/win32-x64-msvc': '0.4.7',
|
||||
|
||||
'@napi-rs/system-ocr-darwin-x64': '1.0.2',
|
||||
|
||||
@@ -13,6 +13,7 @@ import installExtension, { REACT_DEVELOPER_TOOLS, REDUX_DEVTOOLS } from 'electro
|
||||
|
||||
import { isDev, isLinux, isWin } from './constant'
|
||||
import { registerIpc } from './ipc'
|
||||
import { claudeCodeService } from './services/ClaudeCodeService'
|
||||
import { configManager } from './services/ConfigManager'
|
||||
import mcpService from './services/MCPService'
|
||||
import { nodeTraceService } from './services/NodeTraceService'
|
||||
@@ -119,6 +120,14 @@ if (!app.requestSingleInstanceLock()) {
|
||||
|
||||
nodeTraceService.init()
|
||||
|
||||
// Start Claude-code HTTP service
|
||||
try {
|
||||
await claudeCodeService.start()
|
||||
logger.info('Claude-code HTTP service started successfully')
|
||||
} catch (error) {
|
||||
logger.error('Failed to start Claude-code HTTP service:', error as Error)
|
||||
}
|
||||
|
||||
app.on('activate', function () {
|
||||
const mainWindow = windowService.getMainWindow()
|
||||
if (!mainWindow || mainWindow.isDestroyed()) {
|
||||
@@ -193,6 +202,15 @@ if (!app.requestSingleInstanceLock()) {
|
||||
} catch (error) {
|
||||
logger.warn('Error cleaning up MCP service:', error as Error)
|
||||
}
|
||||
|
||||
// Stop Claude-code HTTP service
|
||||
try {
|
||||
await claudeCodeService.stop()
|
||||
logger.info('Claude-code HTTP service stopped')
|
||||
} catch (error) {
|
||||
logger.warn('Error stopping Claude-code HTTP service:', error as Error)
|
||||
}
|
||||
|
||||
// finish the logger
|
||||
logger.finish()
|
||||
})
|
||||
|
||||
@@ -11,6 +11,7 @@ import { SpanEntity, TokenUsage } from '@mcp-trace/trace-core'
|
||||
import { MIN_WINDOW_HEIGHT, MIN_WINDOW_WIDTH, UpgradeChannel } from '@shared/config/constant'
|
||||
import { IpcChannel } from '@shared/IpcChannel'
|
||||
import { FileMetadata, Provider, Shortcut, ThemeMode } from '@types'
|
||||
import { claudeCodeService } from './services/ClaudeCodeService'
|
||||
import { BrowserWindow, dialog, ipcMain, ProxyConfig, session, shell, systemPreferences, webContents } from 'electron'
|
||||
import { Notification } from 'src/renderer/src/types/notification'
|
||||
|
||||
@@ -755,4 +756,9 @@ export function registerIpc(mainWindow: BrowserWindow, app: Electron.App) {
|
||||
|
||||
// CherryIN
|
||||
ipcMain.handle(IpcChannel.Cherryin_GetSignature, (_, params) => generateSignature(params))
|
||||
|
||||
// Provider
|
||||
ipcMain.handle(IpcChannel.Provider_GetClaudeCodePort, () => {
|
||||
return claudeCodeService.getPort()
|
||||
})
|
||||
}
|
||||
|
||||
158
src/main/services/ClaudeCodeService.ts
Normal file
158
src/main/services/ClaudeCodeService.ts
Normal file
@@ -0,0 +1,158 @@
|
||||
import { createExecutor } from '@cherrystudio/ai-core'
|
||||
import { loggerService } from '@logger'
|
||||
import { createClaudeCode } from 'ai-sdk-provider-claude-code'
|
||||
import express, { Request, Response } from 'express'
|
||||
import { Server } from 'http'
|
||||
|
||||
const logger = loggerService.withContext('ClaudeCodeService')
|
||||
|
||||
export class ClaudeCodeService {
|
||||
private app: express.Application
|
||||
private server: Server | null = null
|
||||
private port: number = 0
|
||||
private claudeCodeProvider: any = null
|
||||
|
||||
constructor() {
|
||||
this.app = express()
|
||||
this.setupMiddleware()
|
||||
this.setupRoutes()
|
||||
}
|
||||
|
||||
private setupMiddleware() {
|
||||
this.app.use(express.json())
|
||||
this.app.use(express.text())
|
||||
}
|
||||
|
||||
private setupRoutes() {
|
||||
// Health check endpoint
|
||||
this.app.get('/health', (_req: Request, res: Response) => {
|
||||
res.json({ status: 'ok', timestamp: new Date().toISOString() })
|
||||
})
|
||||
|
||||
// Initialize claude-code provider
|
||||
this.app.post('/init', async (req: Request, res: Response) => {
|
||||
try {
|
||||
const config = req.body
|
||||
logger.info('Initializing claude-code provider with config', config)
|
||||
|
||||
this.claudeCodeProvider = createClaudeCode()
|
||||
|
||||
res.json({
|
||||
success: true,
|
||||
message: 'Claude-code provider initialized successfully'
|
||||
})
|
||||
} catch (error) {
|
||||
logger.error('Failed to initialize claude-code provider', error as Error)
|
||||
res.status(500).json({
|
||||
success: false,
|
||||
error: (error as Error).message
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
// Stream text completion endpoint
|
||||
this.app.post('/completions', async (req: Request, res: Response): Promise<void> => {
|
||||
try {
|
||||
if (!this.claudeCodeProvider) {
|
||||
res.status(400).json({
|
||||
success: false,
|
||||
error: 'Claude-code provider not initialized. Call /init first.'
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
const { modelId, params, options } = req.body
|
||||
logger.info('Processing completions request', { modelId, hasParams: !!params })
|
||||
|
||||
// 创建执行器
|
||||
const executor = createExecutor('claude-code', options || {}, [])
|
||||
const model = this.claudeCodeProvider.languageModel('opus')
|
||||
|
||||
// 执行流式文本生成
|
||||
const result = await executor.streamText({
|
||||
...params,
|
||||
model,
|
||||
abortSignal: new AbortController().signal
|
||||
})
|
||||
console.log('result', result)
|
||||
// 使用 AI SDK 提供的便捷函数处理流式响应
|
||||
result.pipeUIMessageStreamToResponse(res)
|
||||
|
||||
logger.info('Completions request completed successfully')
|
||||
} catch (error) {
|
||||
logger.error('Error in completions endpoint', error as Error)
|
||||
if (!res.headersSent) {
|
||||
res.status(500).json({
|
||||
success: false,
|
||||
error: (error as Error).message
|
||||
})
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
public async start(): Promise<number> {
|
||||
return new Promise((resolve, reject) => {
|
||||
// 尝试使用固定端口,如果失败则使用系统分配端口
|
||||
const preferredPort = 23456
|
||||
|
||||
this.server = this.app.listen(preferredPort, 'localhost', () => {
|
||||
if (this.server?.address()) {
|
||||
this.port = (this.server.address() as any)?.port || 0
|
||||
logger.info(`Claude-code HTTP service started on port ${this.port}`)
|
||||
resolve(this.port)
|
||||
} else {
|
||||
reject(new Error('Failed to start server'))
|
||||
}
|
||||
})
|
||||
|
||||
this.server.on('error', (error: any) => {
|
||||
if (error.code === 'EADDRINUSE') {
|
||||
logger.warn(`Port ${preferredPort} is in use, trying with dynamic port`)
|
||||
// 如果固定端口被占用,使用动态端口
|
||||
this.server = this.app.listen(0, 'localhost', () => {
|
||||
if (this.server?.address()) {
|
||||
this.port = (this.server.address() as any)?.port || 0
|
||||
logger.info(`Claude-code HTTP service started on dynamic port ${this.port}`)
|
||||
resolve(this.port)
|
||||
} else {
|
||||
reject(new Error('Failed to start server'))
|
||||
}
|
||||
})
|
||||
|
||||
this.server.on('error', (dynamicError) => {
|
||||
logger.error('Server error on dynamic port', dynamicError)
|
||||
reject(dynamicError)
|
||||
})
|
||||
} else {
|
||||
logger.error('Server error', error)
|
||||
reject(error)
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
public async stop(): Promise<void> {
|
||||
return new Promise((resolve) => {
|
||||
if (this.server) {
|
||||
this.server.close(() => {
|
||||
logger.info('Claude-code HTTP service stopped')
|
||||
resolve()
|
||||
})
|
||||
} else {
|
||||
resolve()
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
public getPort(): number {
|
||||
return this.port
|
||||
}
|
||||
|
||||
public isRunning(): boolean {
|
||||
return this.server !== null && this.server.listening
|
||||
}
|
||||
}
|
||||
|
||||
// 单例实例
|
||||
export const claudeCodeService = new ClaudeCodeService()
|
||||
@@ -323,7 +323,7 @@ class CodeToolsService {
|
||||
? `set "BUN_INSTALL=${bunInstallPath}" && set "NPM_CONFIG_REGISTRY=${registryUrl}" &&`
|
||||
: `export BUN_INSTALL="${bunInstallPath}" && export NPM_CONFIG_REGISTRY="${registryUrl}" &&`
|
||||
|
||||
const installCommand = `${installEnvPrefix} ${bunPath} install -g ${packageName}`
|
||||
const installCommand = `${installEnvPrefix} "${bunPath}" install -g ${packageName}`
|
||||
baseCommand = `echo "Installing ${packageName}..." && ${installCommand} && echo "Installation complete, starting ${cliTool}..." && ${baseCommand}`
|
||||
}
|
||||
|
||||
|
||||
@@ -570,7 +570,8 @@ class McpService {
|
||||
...tool,
|
||||
id: buildFunctionCallToolName(server.name, tool.name),
|
||||
serverId: server.id,
|
||||
serverName: server.name
|
||||
serverName: server.name,
|
||||
type: 'mcp'
|
||||
}
|
||||
serverTools.push(serverTool)
|
||||
})
|
||||
|
||||
@@ -32,7 +32,8 @@ class ObsidianVaultService {
|
||||
)
|
||||
} else {
|
||||
// Linux
|
||||
this.obsidianConfigPath = path.join(app.getPath('home'), '.config', 'obsidian', 'obsidian.json')
|
||||
this.obsidianConfigPath = this.resolveLinuxObsidianConfigPath()
|
||||
logger.debug(`Resolved Obsidian config path (linux): ${this.obsidianConfigPath}`)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -164,6 +165,57 @@ class ObsidianVaultService {
|
||||
return []
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 在 Linux 下解析 Obsidian 配置文件路径,兼容多种安装方式。
|
||||
* 优先返回第一个存在的路径;若均不存在,则返回 XDG 默认路径。
|
||||
*/
|
||||
private resolveLinuxObsidianConfigPath(): string {
|
||||
const home = app.getPath('home')
|
||||
const xdgConfigHome = process.env.XDG_CONFIG_HOME || path.join(home, '.config')
|
||||
|
||||
// 常见目录名与文件名大小写差异做兼容
|
||||
const configDirs = ['obsidian', 'Obsidian']
|
||||
const fileNames = ['obsidian.json', 'Obsidian.json']
|
||||
|
||||
const candidates: string[] = []
|
||||
|
||||
// 1) AppImage/DEB(XDG 标准路径)
|
||||
for (const dir of configDirs) {
|
||||
for (const file of fileNames) {
|
||||
candidates.push(path.join(xdgConfigHome, dir, file))
|
||||
}
|
||||
}
|
||||
|
||||
// 2) Snap 安装:
|
||||
// - 常见:~/snap/obsidian/current/.config/obsidian/obsidian.json
|
||||
// - 兼容:~/snap/obsidian/common/.config/obsidian/obsidian.json
|
||||
for (const dir of configDirs) {
|
||||
for (const file of fileNames) {
|
||||
candidates.push(path.join(home, 'snap', 'obsidian', 'current', '.config', dir, file))
|
||||
candidates.push(path.join(home, 'snap', 'obsidian', 'common', '.config', dir, file))
|
||||
}
|
||||
}
|
||||
|
||||
// 3) Flatpak 安装:~/.var/app/md.obsidian.Obsidian/config/obsidian/obsidian.json
|
||||
for (const dir of configDirs) {
|
||||
for (const file of fileNames) {
|
||||
candidates.push(path.join(home, '.var', 'app', 'md.obsidian.Obsidian', 'config', dir, file))
|
||||
}
|
||||
}
|
||||
|
||||
const existing = candidates.find((p) => {
|
||||
try {
|
||||
return fs.existsSync(p)
|
||||
} catch {
|
||||
return false
|
||||
}
|
||||
})
|
||||
|
||||
if (existing) return existing
|
||||
|
||||
return path.join(xdgConfigHome, 'obsidian', 'obsidian.json')
|
||||
}
|
||||
}
|
||||
|
||||
export default ObsidianVaultService
|
||||
|
||||
@@ -2,6 +2,7 @@ import { loggerService } from '@logger'
|
||||
import { isLinux } from '@main/constant'
|
||||
import { BuiltinOcrProviderIds, OcrHandler, OcrProvider, OcrResult, SupportedOcrFile } from '@types'
|
||||
|
||||
import { systemOcrService } from './builtin/SystemOcrService'
|
||||
import { tesseractService } from './builtin/TesseractService'
|
||||
|
||||
const logger = loggerService.withContext('OcrService')
|
||||
@@ -34,7 +35,4 @@ export const ocrService = new OcrService()
|
||||
// Register built-in providers
|
||||
ocrService.register(BuiltinOcrProviderIds.tesseract, tesseractService.ocr.bind(tesseractService))
|
||||
|
||||
if (!isLinux) {
|
||||
const { systemOcrService } = require('./builtin/SystemOcrService')
|
||||
ocrService.register(BuiltinOcrProviderIds.system, systemOcrService.ocr.bind(systemOcrService))
|
||||
}
|
||||
!isLinux && ocrService.register(BuiltinOcrProviderIds.system, systemOcrService.ocr.bind(systemOcrService))
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import { isLinux, isWin } from '@main/constant'
|
||||
import { loadOcrImage } from '@main/utils/ocr'
|
||||
import { OcrAccuracy, recognize } from '@napi-rs/system-ocr'
|
||||
import {
|
||||
ImageFileMetadata,
|
||||
isImageFileMetadata as isImageFileMetadata,
|
||||
@@ -20,8 +21,6 @@ export class SystemOcrService extends OcrBaseService {
|
||||
if (isLinux) {
|
||||
return { text: '' }
|
||||
}
|
||||
|
||||
const { OcrAccuracy, recognize } = require('@napi-rs/system-ocr')
|
||||
const buffer = await loadOcrImage(file)
|
||||
const langs = isWin ? options?.langs : undefined
|
||||
const result = await recognize(buffer, OcrAccuracy.Accurate, langs)
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
import { ImageFileMetadata } from '@types'
|
||||
import { readFile } from 'fs/promises'
|
||||
import sharp from 'sharp'
|
||||
|
||||
const preprocessImage = async (buffer: Buffer): Promise<Buffer> => {
|
||||
const sharp = require('sharp')
|
||||
return sharp(buffer)
|
||||
.grayscale() // 转为灰度
|
||||
.normalize()
|
||||
|
||||
@@ -437,6 +437,9 @@ const api = {
|
||||
cherryin: {
|
||||
generateSignature: (params: { method: string; path: string; query: string; body: Record<string, any> }) =>
|
||||
ipcRenderer.invoke(IpcChannel.Cherryin_GetSignature, params)
|
||||
},
|
||||
provider: {
|
||||
getClaudeCodePort: () => ipcRenderer.invoke(IpcChannel.Provider_GetClaudeCodePort)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
363
src/renderer/src/aiCore/chunk/AiSdkToChunkAdapter.ts
Normal file
363
src/renderer/src/aiCore/chunk/AiSdkToChunkAdapter.ts
Normal file
@@ -0,0 +1,363 @@
|
||||
/**
|
||||
* AI SDK 到 Cherry Studio Chunk 适配器
|
||||
* 用于将 AI SDK 的 fullStream 转换为 Cherry Studio 的 chunk 格式
|
||||
*/
|
||||
|
||||
import { loggerService } from '@logger'
|
||||
import { MCPTool, WebSearchResults, WebSearchSource } from '@renderer/types'
|
||||
import { Chunk, ChunkType } from '@renderer/types/chunk'
|
||||
import type { TextStreamPart, ToolSet } from 'ai'
|
||||
|
||||
import { ToolCallChunkHandler } from './handleToolCallChunk'
|
||||
|
||||
const logger = loggerService.withContext('AiSdkToChunkAdapter')
|
||||
|
||||
export interface CherryStudioChunk {
|
||||
type: 'text-delta' | 'text-complete' | 'tool-call' | 'tool-result' | 'finish' | 'error'
|
||||
text?: string
|
||||
toolCall?: any
|
||||
toolResult?: any
|
||||
finishReason?: string
|
||||
usage?: any
|
||||
error?: any
|
||||
}
|
||||
|
||||
/**
|
||||
* AI SDK 到 Cherry Studio Chunk 适配器类
|
||||
* 处理 fullStream 到 Cherry Studio chunk 的转换
|
||||
*/
|
||||
export class AiSdkToChunkAdapter {
|
||||
toolCallHandler: ToolCallChunkHandler
|
||||
private accumulate: boolean | undefined
|
||||
constructor(
|
||||
private onChunk: (chunk: Chunk) => void,
|
||||
mcpTools: MCPTool[] = [],
|
||||
accumulate?: boolean
|
||||
) {
|
||||
this.toolCallHandler = new ToolCallChunkHandler(onChunk, mcpTools)
|
||||
this.accumulate = accumulate
|
||||
}
|
||||
|
||||
/**
|
||||
* 处理 AI SDK 流结果
|
||||
* @param aiSdkResult AI SDK 的流结果对象
|
||||
* @returns 最终的文本内容
|
||||
*/
|
||||
async processStream(aiSdkResult: any): Promise<string> {
|
||||
// 如果是流式且有 fullStream
|
||||
if (aiSdkResult.fullStream) {
|
||||
await this.readFullStream(aiSdkResult.fullStream)
|
||||
}
|
||||
|
||||
// 使用 streamResult.text 获取最终结果
|
||||
return await aiSdkResult.text
|
||||
}
|
||||
|
||||
/**
|
||||
* 直接处理单个 chunk 数据
|
||||
* @param chunk AI SDK 的 chunk 数据
|
||||
*/
|
||||
async processChunk(response: ReadableStream<TextStreamPart<any>>): Promise<void> {
|
||||
const reader = response.getReader()
|
||||
const final = {
|
||||
text: '',
|
||||
reasoningContent: '',
|
||||
webSearchResults: [],
|
||||
reasoningId: ''
|
||||
}
|
||||
try {
|
||||
let buffer = ''
|
||||
const decoder = new TextDecoder()
|
||||
while (true) {
|
||||
const { done, value } = await reader.read()
|
||||
if (done) break
|
||||
|
||||
const chunk = decoder.decode(value, { stream: true })
|
||||
buffer += chunk
|
||||
|
||||
// 按行处理 SSE 数据
|
||||
const lines = buffer.split('\n')
|
||||
buffer = lines.pop() || '' // 保留最后一行(可能不完整)
|
||||
|
||||
for (const line of lines) {
|
||||
if (line.startsWith('data: ')) {
|
||||
const dataStr = line.slice(6) // 移除 "data: " 前缀
|
||||
|
||||
if (dataStr === '[DONE]') {
|
||||
break
|
||||
}
|
||||
try {
|
||||
const data = JSON.parse(dataStr)
|
||||
this.convertAndEmitChunk(data, final)
|
||||
} catch (parseError) {
|
||||
// 忽略无法解析的数据
|
||||
// logger.debug('Failed to parse streamed data:', parseError as Error, line)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
} finally {
|
||||
reader.releaseLock()
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 读取 fullStream 并转换为 Cherry Studio chunks
|
||||
* @param fullStream AI SDK 的 fullStream (ReadableStream)
|
||||
*/
|
||||
private async readFullStream(fullStream: ReadableStream<TextStreamPart<ToolSet>>) {
|
||||
const reader = fullStream.getReader()
|
||||
const final = {
|
||||
text: '',
|
||||
reasoningContent: '',
|
||||
webSearchResults: [],
|
||||
reasoningId: ''
|
||||
}
|
||||
try {
|
||||
while (true) {
|
||||
const { done, value } = await reader.read()
|
||||
|
||||
if (done) {
|
||||
break
|
||||
}
|
||||
|
||||
// 转换并发送 chunk
|
||||
this.convertAndEmitChunk(value, final)
|
||||
}
|
||||
} finally {
|
||||
reader.releaseLock()
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 转换 AI SDK chunk 为 Cherry Studio chunk 并调用回调
|
||||
* @param chunk AI SDK 的 chunk 数据
|
||||
*/
|
||||
private convertAndEmitChunk(
|
||||
chunk: TextStreamPart<any>,
|
||||
final: { text: string; reasoningContent: string; webSearchResults: any[]; reasoningId: string }
|
||||
) {
|
||||
logger.info(`AI SDK chunk type: ${chunk.type}`, chunk)
|
||||
console.log('final', final)
|
||||
switch (chunk.type) {
|
||||
// === 文本相关事件 ===
|
||||
case 'text-start':
|
||||
this.onChunk({
|
||||
type: ChunkType.TEXT_START
|
||||
})
|
||||
break
|
||||
case 'text-delta':
|
||||
if (this.accumulate) {
|
||||
final.text += chunk.delta || ''
|
||||
} else {
|
||||
final.text = chunk.text || ''
|
||||
}
|
||||
this.onChunk({
|
||||
type: ChunkType.TEXT_DELTA,
|
||||
text: final.text || ''
|
||||
})
|
||||
break
|
||||
case 'text-end':
|
||||
this.onChunk({
|
||||
type: ChunkType.TEXT_COMPLETE,
|
||||
text: (chunk.providerMetadata?.text?.value as string) ?? final.text ?? ''
|
||||
})
|
||||
final.text = ''
|
||||
break
|
||||
case 'reasoning-start':
|
||||
// if (final.reasoningId !== chunk.id) {
|
||||
final.reasoningId = chunk.id
|
||||
this.onChunk({
|
||||
type: ChunkType.THINKING_START
|
||||
})
|
||||
// }
|
||||
break
|
||||
case 'reasoning-delta':
|
||||
final.reasoningContent += chunk.text || ''
|
||||
this.onChunk({
|
||||
type: ChunkType.THINKING_DELTA,
|
||||
text: final.reasoningContent || '',
|
||||
thinking_millsec: (chunk.providerMetadata?.metadata?.thinking_millsec as number) || 0
|
||||
})
|
||||
break
|
||||
case 'reasoning-end':
|
||||
this.onChunk({
|
||||
type: ChunkType.THINKING_COMPLETE,
|
||||
text: (chunk.providerMetadata?.metadata?.thinking_content as string) || '',
|
||||
thinking_millsec: (chunk.providerMetadata?.metadata?.thinking_millsec as number) || 0
|
||||
})
|
||||
final.reasoningContent = ''
|
||||
break
|
||||
|
||||
// === 工具调用相关事件(原始 AI SDK 事件,如果没有被中间件处理) ===
|
||||
|
||||
// case 'tool-input-start':
|
||||
// case 'tool-input-delta':
|
||||
// case 'tool-input-end':
|
||||
// this.toolCallHandler.handleToolCallCreated(chunk)
|
||||
// break
|
||||
|
||||
// case 'tool-input-delta':
|
||||
// this.toolCallHandler.handleToolCallCreated(chunk)
|
||||
// break
|
||||
case 'tool-call':
|
||||
// 原始的工具调用(未被中间件处理)
|
||||
this.toolCallHandler.handleToolCall(chunk)
|
||||
break
|
||||
|
||||
case 'tool-result':
|
||||
// 原始的工具调用结果(未被中间件处理)
|
||||
this.toolCallHandler.handleToolResult(chunk)
|
||||
break
|
||||
|
||||
// === 步骤相关事件 ===
|
||||
// case 'start':
|
||||
// this.onChunk({
|
||||
// type: ChunkType.LLM_RESPONSE_CREATED
|
||||
// })
|
||||
// break
|
||||
// TODO: 需要区分接口开始和步骤开始
|
||||
// case 'start-step':
|
||||
// this.onChunk({
|
||||
// type: ChunkType.BLOCK_CREATED
|
||||
// })
|
||||
// break
|
||||
// case 'step-finish':
|
||||
// this.onChunk({
|
||||
// type: ChunkType.TEXT_COMPLETE,
|
||||
// text: final.text || '' // TEXT_COMPLETE 需要 text 字段
|
||||
// })
|
||||
// final.text = ''
|
||||
// break
|
||||
|
||||
case 'finish-step': {
|
||||
const { providerMetadata, finishReason } = chunk
|
||||
// googel web search
|
||||
if (providerMetadata?.google?.groundingMetadata) {
|
||||
this.onChunk({
|
||||
type: ChunkType.LLM_WEB_SEARCH_COMPLETE,
|
||||
llm_web_search: {
|
||||
results: providerMetadata.google?.groundingMetadata as WebSearchResults,
|
||||
source: WebSearchSource.GEMINI
|
||||
}
|
||||
})
|
||||
} else if (final.webSearchResults.length) {
|
||||
const providerName = Object.keys(providerMetadata || {})[0]
|
||||
const sourceMap: Record<string, WebSearchSource> = {
|
||||
[WebSearchSource.OPENAI]: WebSearchSource.OPENAI_RESPONSE,
|
||||
[WebSearchSource.ANTHROPIC]: WebSearchSource.ANTHROPIC,
|
||||
[WebSearchSource.OPENROUTER]: WebSearchSource.OPENROUTER,
|
||||
[WebSearchSource.GEMINI]: WebSearchSource.GEMINI,
|
||||
[WebSearchSource.PERPLEXITY]: WebSearchSource.PERPLEXITY,
|
||||
[WebSearchSource.QWEN]: WebSearchSource.QWEN,
|
||||
[WebSearchSource.HUNYUAN]: WebSearchSource.HUNYUAN,
|
||||
[WebSearchSource.ZHIPU]: WebSearchSource.ZHIPU,
|
||||
[WebSearchSource.GROK]: WebSearchSource.GROK,
|
||||
[WebSearchSource.WEBSEARCH]: WebSearchSource.WEBSEARCH
|
||||
}
|
||||
const source = sourceMap[providerName] || WebSearchSource.AISDK
|
||||
|
||||
this.onChunk({
|
||||
type: ChunkType.LLM_WEB_SEARCH_COMPLETE,
|
||||
llm_web_search: {
|
||||
results: final.webSearchResults,
|
||||
source
|
||||
}
|
||||
})
|
||||
}
|
||||
if (finishReason === 'tool-calls') {
|
||||
this.onChunk({ type: ChunkType.LLM_RESPONSE_CREATED })
|
||||
}
|
||||
|
||||
final.webSearchResults = []
|
||||
// final.reasoningId = ''
|
||||
break
|
||||
}
|
||||
|
||||
case 'finish':
|
||||
this.onChunk({
|
||||
type: ChunkType.BLOCK_COMPLETE,
|
||||
response: {
|
||||
text: final.text || '',
|
||||
reasoning_content: final.reasoningContent || '',
|
||||
usage: {
|
||||
completion_tokens: chunk?.totalUsage?.outputTokens || 0,
|
||||
prompt_tokens: chunk?.totalUsage?.inputTokens || 0,
|
||||
total_tokens: chunk?.totalUsage?.totalTokens || 0
|
||||
},
|
||||
metrics: chunk?.totalUsage
|
||||
? {
|
||||
completion_tokens: chunk?.totalUsage?.outputTokens || 0,
|
||||
time_completion_millsec: 0
|
||||
}
|
||||
: undefined
|
||||
}
|
||||
})
|
||||
this.onChunk({
|
||||
type: ChunkType.LLM_RESPONSE_COMPLETE,
|
||||
response: {
|
||||
text: final.text || '',
|
||||
reasoning_content: final.reasoningContent || '',
|
||||
usage: {
|
||||
completion_tokens: chunk?.totalUsage?.outputTokens || 0,
|
||||
prompt_tokens: chunk?.totalUsage?.inputTokens || 0,
|
||||
total_tokens: chunk?.totalUsage?.totalTokens || 0
|
||||
},
|
||||
metrics: chunk?.totalUsage
|
||||
? {
|
||||
completion_tokens: chunk?.totalUsage?.outputTokens || 0,
|
||||
time_completion_millsec: 0
|
||||
}
|
||||
: undefined
|
||||
}
|
||||
})
|
||||
break
|
||||
|
||||
// === 源和文件相关事件 ===
|
||||
case 'source':
|
||||
if (chunk.sourceType === 'url') {
|
||||
// if (final.webSearchResults.length === 0) {
|
||||
// eslint-disable-next-line @typescript-eslint/no-unused-vars
|
||||
const { sourceType: _, ...rest } = chunk
|
||||
final.webSearchResults.push(rest)
|
||||
// }
|
||||
// this.onChunk({
|
||||
// type: ChunkType.LLM_WEB_SEARCH_COMPLETE,
|
||||
// llm_web_search: {
|
||||
// source: WebSearchSource.AISDK,
|
||||
// results: final.webSearchResults
|
||||
// }
|
||||
// })
|
||||
}
|
||||
break
|
||||
case 'file':
|
||||
// 文件相关事件,可能是图片生成
|
||||
this.onChunk({
|
||||
type: ChunkType.IMAGE_COMPLETE,
|
||||
image: {
|
||||
type: 'base64',
|
||||
images: [`data:${chunk.file.mediaType};base64,${chunk.file.base64}`]
|
||||
}
|
||||
})
|
||||
break
|
||||
case 'abort':
|
||||
this.onChunk({
|
||||
type: ChunkType.ERROR,
|
||||
error: new DOMException('Request was aborted', 'AbortError')
|
||||
})
|
||||
break
|
||||
case 'error':
|
||||
this.onChunk({
|
||||
type: ChunkType.ERROR,
|
||||
error: chunk.error as Record<string, any>
|
||||
})
|
||||
break
|
||||
|
||||
default:
|
||||
// 其他类型的 chunk 可以忽略或记录日志
|
||||
// console.log('Unhandled AI SDK chunk type:', chunk.type, chunk)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
export default AiSdkToChunkAdapter
|
||||
266
src/renderer/src/aiCore/chunk/handleToolCallChunk.ts
Normal file
266
src/renderer/src/aiCore/chunk/handleToolCallChunk.ts
Normal file
@@ -0,0 +1,266 @@
|
||||
/**
|
||||
* 工具调用 Chunk 处理模块
|
||||
* TODO: Tool包含了providerTool和普通的Tool还有MCPTool,后面需要重构
|
||||
* 提供工具调用相关的处理API,每个交互使用一个新的实例
|
||||
*/
|
||||
|
||||
import { loggerService } from '@logger'
|
||||
import { BaseTool, MCPTool, MCPToolResponse, NormalToolResponse } from '@renderer/types'
|
||||
import { Chunk, ChunkType } from '@renderer/types/chunk'
|
||||
import type { ProviderMetadata, ToolSet, TypedToolCall, TypedToolResult } from 'ai'
|
||||
// import type {
|
||||
// AnthropicSearchOutput,
|
||||
// WebSearchPluginConfig
|
||||
// } from '@cherrystudio/ai-core/core/plugins/built-in/webSearchPlugin'
|
||||
|
||||
const logger = loggerService.withContext('ToolCallChunkHandler')
|
||||
|
||||
/**
|
||||
* 工具调用处理器类
|
||||
*/
|
||||
export class ToolCallChunkHandler {
|
||||
// private onChunk: (chunk: Chunk) => void
|
||||
private activeToolCalls = new Map<
|
||||
string,
|
||||
{
|
||||
toolCallId: string
|
||||
toolName: string
|
||||
args: any
|
||||
// mcpTool 现在可以是 MCPTool 或我们为 Provider 工具创建的通用类型
|
||||
tool: BaseTool
|
||||
}
|
||||
>()
|
||||
constructor(
|
||||
private onChunk: (chunk: Chunk) => void,
|
||||
private mcpTools: MCPTool[]
|
||||
) {}
|
||||
|
||||
// /**
|
||||
// * 设置 onChunk 回调
|
||||
// */
|
||||
// public setOnChunk(callback: (chunk: Chunk) => void): void {
|
||||
// this.onChunk = callback
|
||||
// }
|
||||
|
||||
handleToolCallCreated(
|
||||
chunk:
|
||||
| {
|
||||
type: 'tool-input-start'
|
||||
id: string
|
||||
toolName: string
|
||||
providerMetadata?: ProviderMetadata
|
||||
providerExecuted?: boolean
|
||||
}
|
||||
| {
|
||||
type: 'tool-input-end'
|
||||
id: string
|
||||
providerMetadata?: ProviderMetadata
|
||||
}
|
||||
| {
|
||||
type: 'tool-input-delta'
|
||||
id: string
|
||||
delta: string
|
||||
providerMetadata?: ProviderMetadata
|
||||
}
|
||||
): void {
|
||||
switch (chunk.type) {
|
||||
case 'tool-input-start': {
|
||||
// 能拿到说明是mcpTool
|
||||
// if (this.activeToolCalls.get(chunk.id)) return
|
||||
|
||||
const tool: BaseTool | MCPTool = {
|
||||
id: chunk.id,
|
||||
name: chunk.toolName,
|
||||
description: chunk.toolName,
|
||||
type: chunk.toolName.startsWith('builtin_') ? 'builtin' : 'provider'
|
||||
}
|
||||
this.activeToolCalls.set(chunk.id, {
|
||||
toolCallId: chunk.id,
|
||||
toolName: chunk.toolName,
|
||||
args: '',
|
||||
tool
|
||||
})
|
||||
const toolResponse: MCPToolResponse | NormalToolResponse = {
|
||||
id: chunk.id,
|
||||
tool: tool,
|
||||
arguments: {},
|
||||
status: 'pending',
|
||||
toolCallId: chunk.id
|
||||
}
|
||||
this.onChunk({
|
||||
type: ChunkType.MCP_TOOL_PENDING,
|
||||
responses: [toolResponse]
|
||||
})
|
||||
break
|
||||
}
|
||||
case 'tool-input-delta': {
|
||||
const toolCall = this.activeToolCalls.get(chunk.id)
|
||||
if (!toolCall) {
|
||||
logger.warn(`🔧 [ToolCallChunkHandler] Tool call not found: ${chunk.id}`)
|
||||
return
|
||||
}
|
||||
toolCall.args += chunk.delta
|
||||
break
|
||||
}
|
||||
case 'tool-input-end': {
|
||||
const toolCall = this.activeToolCalls.get(chunk.id)
|
||||
this.activeToolCalls.delete(chunk.id)
|
||||
if (!toolCall) {
|
||||
logger.warn(`🔧 [ToolCallChunkHandler] Tool call not found: ${chunk.id}`)
|
||||
return
|
||||
}
|
||||
// const toolResponse: ToolCallResponse = {
|
||||
// id: toolCall.toolCallId,
|
||||
// tool: toolCall.tool,
|
||||
// arguments: toolCall.args,
|
||||
// status: 'pending',
|
||||
// toolCallId: toolCall.toolCallId
|
||||
// }
|
||||
// logger.debug('toolResponse', toolResponse)
|
||||
// this.onChunk({
|
||||
// type: ChunkType.MCP_TOOL_PENDING,
|
||||
// responses: [toolResponse]
|
||||
// })
|
||||
break
|
||||
}
|
||||
}
|
||||
// if (!toolCall) {
|
||||
// Logger.warn(`🔧 [ToolCallChunkHandler] Tool call not found: ${chunk.id}`)
|
||||
// return
|
||||
// }
|
||||
// this.onChunk({
|
||||
// type: ChunkType.MCP_TOOL_CREATED,
|
||||
// tool_calls: [
|
||||
// {
|
||||
// id: chunk.id,
|
||||
// name: chunk.toolName,
|
||||
// status: 'pending'
|
||||
// }
|
||||
// ]
|
||||
// })
|
||||
}
|
||||
|
||||
/**
|
||||
* 处理工具调用事件
|
||||
*/
|
||||
public handleToolCall(
|
||||
chunk: {
|
||||
type: 'tool-call'
|
||||
} & TypedToolCall<ToolSet>
|
||||
): void {
|
||||
const { toolCallId, toolName, input: args, providerExecuted } = chunk
|
||||
|
||||
if (!toolCallId || !toolName) {
|
||||
logger.warn(`🔧 [ToolCallChunkHandler] Invalid tool call chunk: missing toolCallId or toolName`)
|
||||
return
|
||||
}
|
||||
|
||||
let tool: BaseTool
|
||||
let mcpTool: MCPTool | undefined
|
||||
|
||||
// 根据 providerExecuted 标志区分处理逻辑
|
||||
if (providerExecuted) {
|
||||
// 如果是 Provider 执行的工具(如 web_search)
|
||||
logger.info(`[ToolCallChunkHandler] Handling provider-executed tool: ${toolName}`)
|
||||
tool = {
|
||||
id: toolCallId,
|
||||
name: toolName,
|
||||
description: toolName,
|
||||
type: 'provider'
|
||||
} as BaseTool
|
||||
} else if (toolName.startsWith('builtin_')) {
|
||||
// 如果是内置工具,沿用现有逻辑
|
||||
logger.info(`[ToolCallChunkHandler] Handling builtin tool: ${toolName}`)
|
||||
tool = {
|
||||
id: toolCallId,
|
||||
name: toolName,
|
||||
description: toolName,
|
||||
type: 'builtin'
|
||||
} as BaseTool
|
||||
} else if ((mcpTool = this.mcpTools.find((t) => t.name === toolName) as MCPTool)) {
|
||||
// 如果是客户端执行的 MCP 工具,沿用现有逻辑
|
||||
logger.info(`[ToolCallChunkHandler] Handling client-side MCP tool: ${toolName}`)
|
||||
// mcpTool = this.mcpTools.find((t) => t.name === toolName) as MCPTool
|
||||
// if (!mcpTool) {
|
||||
// logger.warn(`[ToolCallChunkHandler] MCP tool not found: ${toolName}`)
|
||||
// return
|
||||
// }
|
||||
tool = mcpTool
|
||||
} else {
|
||||
tool = {
|
||||
id: toolCallId,
|
||||
name: toolName,
|
||||
description: toolName,
|
||||
type: 'provider'
|
||||
}
|
||||
}
|
||||
|
||||
// 记录活跃的工具调用
|
||||
this.activeToolCalls.set(toolCallId, {
|
||||
toolCallId,
|
||||
toolName,
|
||||
args,
|
||||
tool
|
||||
})
|
||||
|
||||
// 创建 MCPToolResponse 格式
|
||||
const toolResponse: MCPToolResponse | NormalToolResponse = {
|
||||
id: toolCallId,
|
||||
tool: tool,
|
||||
arguments: args,
|
||||
status: 'pending',
|
||||
toolCallId: toolCallId
|
||||
}
|
||||
|
||||
// 调用 onChunk
|
||||
if (this.onChunk) {
|
||||
this.onChunk({
|
||||
type: ChunkType.MCP_TOOL_PENDING,
|
||||
responses: [toolResponse]
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 处理工具调用结果事件
|
||||
*/
|
||||
public handleToolResult(
|
||||
chunk: {
|
||||
type: 'tool-result'
|
||||
} & TypedToolResult<ToolSet>
|
||||
): void {
|
||||
const { toolCallId, output, input } = chunk
|
||||
|
||||
if (!toolCallId) {
|
||||
logger.warn(`🔧 [ToolCallChunkHandler] Invalid tool result chunk: missing toolCallId`)
|
||||
return
|
||||
}
|
||||
|
||||
// 查找对应的工具调用信息
|
||||
const toolCallInfo = this.activeToolCalls.get(toolCallId)
|
||||
if (!toolCallInfo) {
|
||||
logger.warn(`🔧 [ToolCallChunkHandler] Tool call info not found for ID: ${toolCallId}`)
|
||||
return
|
||||
}
|
||||
|
||||
// 创建工具调用结果的 MCPToolResponse 格式
|
||||
const toolResponse: MCPToolResponse | NormalToolResponse = {
|
||||
id: toolCallInfo.toolCallId,
|
||||
tool: toolCallInfo.tool,
|
||||
arguments: input,
|
||||
status: 'done',
|
||||
response: output,
|
||||
toolCallId: toolCallId
|
||||
}
|
||||
// 从活跃调用中移除(交互结束后整个实例会被丢弃)
|
||||
this.activeToolCalls.delete(toolCallId)
|
||||
|
||||
// 调用 onChunk
|
||||
if (this.onChunk) {
|
||||
this.onChunk({
|
||||
type: ChunkType.MCP_TOOL_COMPLETE,
|
||||
responses: [toolResponse]
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,189 +1,16 @@
|
||||
import { loggerService } from '@logger'
|
||||
import { ApiClientFactory } from '@renderer/aiCore/clients/ApiClientFactory'
|
||||
import { BaseApiClient } from '@renderer/aiCore/clients/BaseApiClient'
|
||||
import { isDedicatedImageGenerationModel, isFunctionCallingModel } from '@renderer/config/models'
|
||||
import { getProviderByModel } from '@renderer/services/AssistantService'
|
||||
import { withSpanResult } from '@renderer/services/SpanManagerService'
|
||||
import { StartSpanParams } from '@renderer/trace/types/ModelSpanEntity'
|
||||
import type { GenerateImageParams, Model, Provider } from '@renderer/types'
|
||||
import type { RequestOptions, SdkModel } from '@renderer/types/sdk'
|
||||
import { isEnabledToolUse } from '@renderer/utils/mcp-tools'
|
||||
/**
|
||||
* Cherry Studio AI Core - 统一入口点
|
||||
*
|
||||
* 这是新的统一入口,保持向后兼容性
|
||||
* 默认导出legacy AiProvider以保持现有代码的兼容性
|
||||
*/
|
||||
|
||||
import { AihubmixAPIClient } from './clients/aihubmix/AihubmixAPIClient'
|
||||
import { VertexAPIClient } from './clients/gemini/VertexAPIClient'
|
||||
import { NewAPIClient } from './clients/newapi/NewAPIClient'
|
||||
import { OpenAIResponseAPIClient } from './clients/openai/OpenAIResponseAPIClient'
|
||||
import { CompletionsMiddlewareBuilder } from './middleware/builder'
|
||||
import { MIDDLEWARE_NAME as AbortHandlerMiddlewareName } from './middleware/common/AbortHandlerMiddleware'
|
||||
import { MIDDLEWARE_NAME as ErrorHandlerMiddlewareName } from './middleware/common/ErrorHandlerMiddleware'
|
||||
import { MIDDLEWARE_NAME as FinalChunkConsumerMiddlewareName } from './middleware/common/FinalChunkConsumerMiddleware'
|
||||
import { applyCompletionsMiddlewares } from './middleware/composer'
|
||||
import { MIDDLEWARE_NAME as McpToolChunkMiddlewareName } from './middleware/core/McpToolChunkMiddleware'
|
||||
import { MIDDLEWARE_NAME as RawStreamListenerMiddlewareName } from './middleware/core/RawStreamListenerMiddleware'
|
||||
import { MIDDLEWARE_NAME as WebSearchMiddlewareName } from './middleware/core/WebSearchMiddleware'
|
||||
import { MIDDLEWARE_NAME as ImageGenerationMiddlewareName } from './middleware/feat/ImageGenerationMiddleware'
|
||||
import { MIDDLEWARE_NAME as ThinkingTagExtractionMiddlewareName } from './middleware/feat/ThinkingTagExtractionMiddleware'
|
||||
import { MIDDLEWARE_NAME as ToolUseExtractionMiddlewareName } from './middleware/feat/ToolUseExtractionMiddleware'
|
||||
import { MiddlewareRegistry } from './middleware/register'
|
||||
import type { CompletionsParams, CompletionsResult } from './middleware/schemas'
|
||||
// 导出Legacy AiProvider作为默认导出(保持向后兼容)
|
||||
export { default } from './legacy/index'
|
||||
|
||||
const logger = loggerService.withContext('AiProvider')
|
||||
// 同时导出Modern AiProvider供新代码使用
|
||||
export { default as ModernAiProvider } from './index_new'
|
||||
|
||||
export default class AiProvider {
|
||||
private apiClient: BaseApiClient
|
||||
|
||||
constructor(provider: Provider) {
|
||||
// Use the new ApiClientFactory to get a BaseApiClient instance
|
||||
this.apiClient = ApiClientFactory.create(provider)
|
||||
}
|
||||
|
||||
public async completions(params: CompletionsParams, options?: RequestOptions): Promise<CompletionsResult> {
|
||||
// 1. 根据模型识别正确的客户端
|
||||
const model = params.assistant.model
|
||||
if (!model) {
|
||||
return Promise.reject(new Error('Model is required'))
|
||||
}
|
||||
|
||||
// 根据client类型选择合适的处理方式
|
||||
let client: BaseApiClient
|
||||
|
||||
if (this.apiClient instanceof AihubmixAPIClient) {
|
||||
// AihubmixAPIClient: 根据模型选择合适的子client
|
||||
client = this.apiClient.getClientForModel(model)
|
||||
if (client instanceof OpenAIResponseAPIClient) {
|
||||
client = client.getClient(model) as BaseApiClient
|
||||
}
|
||||
} else if (this.apiClient instanceof NewAPIClient) {
|
||||
client = this.apiClient.getClientForModel(model)
|
||||
if (client instanceof OpenAIResponseAPIClient) {
|
||||
client = client.getClient(model) as BaseApiClient
|
||||
}
|
||||
} else if (this.apiClient instanceof OpenAIResponseAPIClient) {
|
||||
// OpenAIResponseAPIClient: 根据模型特征选择API类型
|
||||
client = this.apiClient.getClient(model) as BaseApiClient
|
||||
} else if (this.apiClient instanceof VertexAPIClient) {
|
||||
client = this.apiClient.getClient(model) as BaseApiClient
|
||||
} else {
|
||||
// 其他client直接使用
|
||||
client = this.apiClient
|
||||
}
|
||||
|
||||
// 2. 构建中间件链
|
||||
const builder = CompletionsMiddlewareBuilder.withDefaults()
|
||||
// images api
|
||||
if (isDedicatedImageGenerationModel(model)) {
|
||||
builder.clear()
|
||||
builder
|
||||
.add(MiddlewareRegistry[FinalChunkConsumerMiddlewareName])
|
||||
.add(MiddlewareRegistry[ErrorHandlerMiddlewareName])
|
||||
.add(MiddlewareRegistry[AbortHandlerMiddlewareName])
|
||||
.add(MiddlewareRegistry[ImageGenerationMiddlewareName])
|
||||
} else {
|
||||
// Existing logic for other models
|
||||
logger.silly('Builder Params', params)
|
||||
// 使用兼容性类型检查,避免typescript类型收窄和装饰器模式的问题
|
||||
const clientTypes = client.getClientCompatibilityType(model)
|
||||
const isOpenAICompatible =
|
||||
clientTypes.includes('OpenAIAPIClient') || clientTypes.includes('OpenAIResponseAPIClient')
|
||||
if (!isOpenAICompatible) {
|
||||
logger.silly('ThinkingTagExtractionMiddleware is removed')
|
||||
builder.remove(ThinkingTagExtractionMiddlewareName)
|
||||
}
|
||||
|
||||
const isAnthropicOrOpenAIResponseCompatible =
|
||||
clientTypes.includes('AnthropicAPIClient') ||
|
||||
clientTypes.includes('OpenAIResponseAPIClient') ||
|
||||
clientTypes.includes('AnthropicVertexAPIClient')
|
||||
if (!isAnthropicOrOpenAIResponseCompatible) {
|
||||
logger.silly('RawStreamListenerMiddleware is removed')
|
||||
builder.remove(RawStreamListenerMiddlewareName)
|
||||
}
|
||||
if (!params.enableWebSearch) {
|
||||
logger.silly('WebSearchMiddleware is removed')
|
||||
builder.remove(WebSearchMiddlewareName)
|
||||
}
|
||||
if (!params.mcpTools?.length) {
|
||||
builder.remove(ToolUseExtractionMiddlewareName)
|
||||
logger.silly('ToolUseExtractionMiddleware is removed')
|
||||
builder.remove(McpToolChunkMiddlewareName)
|
||||
logger.silly('McpToolChunkMiddleware is removed')
|
||||
}
|
||||
if (isEnabledToolUse(params.assistant) && isFunctionCallingModel(model)) {
|
||||
builder.remove(ToolUseExtractionMiddlewareName)
|
||||
logger.silly('ToolUseExtractionMiddleware is removed')
|
||||
}
|
||||
if (params.callType !== 'chat' && params.callType !== 'check' && params.callType !== 'translate') {
|
||||
logger.silly('AbortHandlerMiddleware is removed')
|
||||
builder.remove(AbortHandlerMiddlewareName)
|
||||
}
|
||||
if (params.callType === 'test') {
|
||||
builder.remove(ErrorHandlerMiddlewareName)
|
||||
logger.silly('ErrorHandlerMiddleware is removed')
|
||||
builder.remove(FinalChunkConsumerMiddlewareName)
|
||||
logger.silly('FinalChunkConsumerMiddleware is removed')
|
||||
}
|
||||
}
|
||||
|
||||
const middlewares = builder.build()
|
||||
logger.silly(
|
||||
'middlewares',
|
||||
middlewares.map((m) => m.name)
|
||||
)
|
||||
|
||||
// 3. Create the wrapped SDK method with middlewares
|
||||
const wrappedCompletionMethod = applyCompletionsMiddlewares(client, client.createCompletions, middlewares)
|
||||
|
||||
// 4. Execute the wrapped method with the original params
|
||||
const result = wrappedCompletionMethod(params, options)
|
||||
return result
|
||||
}
|
||||
|
||||
public async completionsForTrace(params: CompletionsParams, options?: RequestOptions): Promise<CompletionsResult> {
|
||||
const traceName = params.assistant.model?.name
|
||||
? `${params.assistant.model?.name}.${params.callType}`
|
||||
: `LLM.${params.callType}`
|
||||
|
||||
const traceParams: StartSpanParams = {
|
||||
name: traceName,
|
||||
tag: 'LLM',
|
||||
topicId: params.topicId || '',
|
||||
modelName: params.assistant.model?.name
|
||||
}
|
||||
|
||||
return await withSpanResult(this.completions.bind(this), traceParams, params, options)
|
||||
}
|
||||
|
||||
public async models(): Promise<SdkModel[]> {
|
||||
return this.apiClient.listModels()
|
||||
}
|
||||
|
||||
public async getEmbeddingDimensions(model: Model): Promise<number> {
|
||||
try {
|
||||
// Use the SDK instance to test embedding capabilities
|
||||
if (this.apiClient instanceof OpenAIResponseAPIClient && getProviderByModel(model).type === 'azure-openai') {
|
||||
this.apiClient = this.apiClient.getClient(model) as BaseApiClient
|
||||
}
|
||||
const dimensions = await this.apiClient.getEmbeddingDimensions(model)
|
||||
return dimensions
|
||||
} catch (error) {
|
||||
logger.error('Error getting embedding dimensions:', error as Error)
|
||||
throw error
|
||||
}
|
||||
}
|
||||
|
||||
public async generateImage(params: GenerateImageParams): Promise<string[]> {
|
||||
if (this.apiClient instanceof AihubmixAPIClient) {
|
||||
const client = this.apiClient.getClientForModel({ id: params.model } as Model)
|
||||
return client.generateImage(params)
|
||||
}
|
||||
return this.apiClient.generateImage(params)
|
||||
}
|
||||
|
||||
public getBaseURL(): string {
|
||||
return this.apiClient.getBaseURL()
|
||||
}
|
||||
|
||||
public getApiKey(): string {
|
||||
return this.apiClient.getApiKey()
|
||||
}
|
||||
}
|
||||
// 导出一些常用的类型和工具
|
||||
export * from './legacy/clients/types'
|
||||
export * from './legacy/middleware/schemas'
|
||||
|
||||
584
src/renderer/src/aiCore/index_new.ts
Normal file
584
src/renderer/src/aiCore/index_new.ts
Normal file
@@ -0,0 +1,584 @@
|
||||
/**
|
||||
* Cherry Studio AI Core - 新版本入口
|
||||
* 集成 @cherrystudio/ai-core 库的渐进式重构方案
|
||||
*
|
||||
* 融合方案:简化实现,专注于核心功能
|
||||
* 1. 优先使用新AI SDK
|
||||
* 2. 暂时保持接口兼容性
|
||||
*/
|
||||
|
||||
import { createExecutor } from '@cherrystudio/ai-core'
|
||||
import { loggerService } from '@logger'
|
||||
import { getEnableDeveloperMode } from '@renderer/hooks/useSettings'
|
||||
import { addSpan, endSpan } from '@renderer/services/SpanManagerService'
|
||||
import { StartSpanParams } from '@renderer/trace/types/ModelSpanEntity'
|
||||
import type { Assistant, GenerateImageParams, Model, Provider } from '@renderer/types'
|
||||
import type { AiSdkModel, StreamTextParams } from '@renderer/types/aiCoreTypes'
|
||||
import { type ImageModel, type LanguageModel, type Provider as AiSdkProvider, wrapLanguageModel } from 'ai'
|
||||
|
||||
import AiSdkToChunkAdapter from './chunk/AiSdkToChunkAdapter'
|
||||
import LegacyAiProvider from './legacy/index'
|
||||
import { CompletionsParams, CompletionsResult } from './legacy/middleware/schemas'
|
||||
import { AiSdkMiddlewareConfig, buildAiSdkMiddlewares } from './middleware/AiSdkMiddlewareBuilder'
|
||||
import { buildPlugins } from './plugins/PluginBuilder'
|
||||
import { createAiSdkProvider } from './provider/factory'
|
||||
import {
|
||||
getActualProvider,
|
||||
isModernSdkSupported,
|
||||
prepareSpecialProviderConfig,
|
||||
providerToAiSdkConfig
|
||||
} from './provider/providerConfig'
|
||||
|
||||
const logger = loggerService.withContext('ModernAiProvider')
|
||||
|
||||
export type ModernAiProviderConfig = AiSdkMiddlewareConfig & {
|
||||
assistant: Assistant
|
||||
// topicId for tracing
|
||||
topicId?: string
|
||||
callType: string
|
||||
}
|
||||
|
||||
export default class ModernAiProvider {
|
||||
  // Legacy implementation used for delegation (models, embeddings, image fallback)
  private legacyProvider: LegacyAiProvider
  // AI-SDK config derived from the provider + model; absent for provider-only construction
  private config?: ReturnType<typeof providerToAiSdkConfig>
  // Resolved provider backing this instance
  private actualProvider: Provider
  // Optional: some operations (e.g. fetching models) do not require a model
  private model?: Model
  // Lazily created local AI-SDK provider instance, cached after first creation
  private localProvider: Awaited<AiSdkProvider> | null = null

  // Constructor overload signatures
  constructor(model: Model, provider?: Provider)
  constructor(provider: Provider)
  constructor(modelOrProvider: Model | Provider, provider?: Provider)
  constructor(modelOrProvider: Model | Provider, provider?: Provider) {
    if (this.isModel(modelOrProvider)) {
      // A Model was passed in
      this.model = modelOrProvider
      this.actualProvider = provider || getActualProvider(modelOrProvider)
      // Only store the config here; the executor is created lazily, not up-front
      this.config = providerToAiSdkConfig(this.actualProvider, modelOrProvider)
    } else {
      // A Provider was passed in
      this.actualProvider = modelOrProvider
      // model stays undefined; some operations (e.g. fetchModels) do not need one
    }

    this.legacyProvider = new LegacyAiProvider(this.actualProvider)
  }
|
||||
|
||||
/**
|
||||
* 类型守卫函数:通过 provider 属性区分 Model 和 Provider
|
||||
*/
|
||||
private isModel(obj: Model | Provider): obj is Model {
|
||||
return 'provider' in obj && typeof obj.provider === 'string'
|
||||
}
|
||||
|
||||
  /** Returns the resolved Provider backing this instance. */
  public getActualProvider() {
    return this.actualProvider
  }
|
||||
|
||||
public async completions(modelId: string, params: StreamTextParams, config: ModernAiProviderConfig) {
|
||||
// 检查model是否存在
|
||||
if (!this.model) {
|
||||
throw new Error('Model is required for completions. Please use constructor with model parameter.')
|
||||
}
|
||||
|
||||
// 确保配置存在
|
||||
if (!this.config) {
|
||||
this.config = providerToAiSdkConfig(this.actualProvider, this.model)
|
||||
}
|
||||
|
||||
// 准备特殊配置
|
||||
await prepareSpecialProviderConfig(this.actualProvider, this.config)
|
||||
|
||||
// 特殊处理 claude-code provider,通过本地 HTTP 服务器
|
||||
// if (this.config.providerId === 'claude-code') {
|
||||
return await this._completionsViaHttpService(modelId, params, config)
|
||||
// }
|
||||
|
||||
// 提前创建本地 provider 实例
|
||||
if (!this.localProvider) {
|
||||
this.localProvider = await createAiSdkProvider(this.config)
|
||||
}
|
||||
|
||||
// 提前构建中间件
|
||||
const middlewares = buildAiSdkMiddlewares({
|
||||
...config,
|
||||
provider: this.actualProvider
|
||||
})
|
||||
logger.debug('Built middlewares in completions', {
|
||||
middlewareCount: middlewares.length,
|
||||
isImageGeneration: config.isImageGenerationEndpoint
|
||||
})
|
||||
if (!this.localProvider) {
|
||||
throw new Error('Local provider not created')
|
||||
}
|
||||
|
||||
// 根据endpoint类型创建对应的模型
|
||||
let model: AiSdkModel | undefined
|
||||
if (config.isImageGenerationEndpoint) {
|
||||
model = this.localProvider.imageModel(modelId)
|
||||
} else {
|
||||
model = this.localProvider.languageModel(modelId)
|
||||
// 如果有中间件,应用到语言模型上
|
||||
if (middlewares.length > 0 && typeof model === 'object') {
|
||||
model = wrapLanguageModel({ model, middleware: middlewares })
|
||||
}
|
||||
}
|
||||
|
||||
if (config.topicId && getEnableDeveloperMode()) {
|
||||
// TypeScript类型窄化:确保topicId是string类型
|
||||
const traceConfig = {
|
||||
...config,
|
||||
topicId: config.topicId
|
||||
}
|
||||
return await this._completionsForTrace(model, params, traceConfig)
|
||||
} else {
|
||||
return await this._completionsOrImageGeneration(model, params, config)
|
||||
}
|
||||
}
|
||||
|
||||
  /**
   * Dispatches either to the legacy image-generation pipeline or to the
   * modern SDK completions implementation, depending on the endpoint type.
   */
  private async _completionsOrImageGeneration(
    model: AiSdkModel,
    params: StreamTextParams,
    config: ModernAiProviderConfig
  ): Promise<CompletionsResult> {
    if (config.isImageGenerationEndpoint) {
      // Image generation uses the legacy implementation, which supports
      // advanced features such as image editing.
      if (!config.uiMessages) {
        throw new Error('uiMessages is required for image generation endpoint')
      }

      const legacyParams: CompletionsParams = {
        callType: 'chat',
        messages: config.uiMessages, // the original UI message format, not SDK messages
        assistant: config.assistant,
        streamOutput: config.streamOutput ?? true,
        onChunk: config.onChunk,
        topicId: config.topicId,
        mcpTools: config.mcpTools,
        enableWebSearch: config.enableWebSearch
      }

      // Legacy completions automatically applies ImageGenerationMiddleware.
      return await this.legacyProvider.completions(legacyParams)
    }

    return await this.modernCompletions(model as LanguageModel, params, config)
  }
|
||||
|
||||
  /**
   * Completions with trace support.
   * Mirrors the legacy completionsForTrace: opens a parent span so the AI
   * SDK's own spans attach to the correct trace context, and closes it with
   * the output (or the error) when the call finishes.
   */
  private async _completionsForTrace(
    model: AiSdkModel,
    params: StreamTextParams,
    config: ModernAiProviderConfig & { topicId: string }
  ): Promise<CompletionsResult> {
    const modelId = this.model!.id
    const traceName = `${this.actualProvider.name}.${modelId}.${config.callType}`
    const traceParams: StartSpanParams = {
      name: traceName,
      tag: 'LLM',
      topicId: config.topicId,
      modelName: config.assistant.model?.name, // the assistant's model name, not the provider name
      inputs: params
    }

    logger.info('Starting AI SDK trace span', {
      traceName,
      topicId: config.topicId,
      modelId,
      hasTools: !!params.tools && Object.keys(params.tools).length > 0,
      toolNames: params.tools ? Object.keys(params.tools) : [],
      isImageGeneration: config.isImageGenerationEndpoint
    })

    const span = addSpan(traceParams)
    if (!span) {
      // Tracing is best-effort: fall back to the untraced path rather than failing.
      logger.warn('Failed to create span, falling back to regular completions', {
        topicId: config.topicId,
        modelId,
        traceName
      })
      return await this._completionsOrImageGeneration(model, params, config)
    }

    try {
      logger.info('Created parent span, now calling completions', {
        spanId: span.spanContext().spanId,
        traceId: span.spanContext().traceId,
        topicId: config.topicId,
        modelId,
        parentSpanCreated: true
      })

      const result = await this._completionsOrImageGeneration(model, params, config)

      logger.info('Completions finished, ending parent span', {
        spanId: span.spanContext().spanId,
        traceId: span.spanContext().traceId,
        topicId: config.topicId,
        modelId,
        resultLength: result.getText().length
      })

      // Mark the span as completed.
      endSpan({
        topicId: config.topicId,
        outputs: result,
        span,
        modelName: modelId // use modelId here for consistency
      })

      return result
    } catch (error) {
      logger.error('Error in completionsForTrace, ending parent span with error', error as Error, {
        spanId: span.spanContext().spanId,
        traceId: span.spanContext().traceId,
        topicId: config.topicId,
        modelId
      })

      // Mark the span as failed, then rethrow for the caller.
      endSpan({
        topicId: config.topicId,
        error: error as Error,
        span,
        modelName: modelId // use modelId here for consistency
      })
      throw error
    }
  }
|
||||
|
||||
/**
|
||||
* 通过本地 HTTP 服务器处理 claude-code completions
|
||||
*/
|
||||
private async _completionsViaHttpService(
|
||||
modelId: string,
|
||||
params: StreamTextParams,
|
||||
config: ModernAiProviderConfig
|
||||
): Promise<CompletionsResult> {
|
||||
logger.info('Starting claude-code completions via HTTP service', {
|
||||
modelId,
|
||||
providerId: this.config!.providerId,
|
||||
topicId: config.topicId,
|
||||
hasOnChunk: !!config.onChunk
|
||||
})
|
||||
|
||||
try {
|
||||
// 初始化 claude-code provider
|
||||
const initResponse = await fetch('http://localhost:' + (await this.getClaudeCodePort()) + '/init', {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json'
|
||||
},
|
||||
body: JSON.stringify(this.config!.options)
|
||||
})
|
||||
|
||||
if (!initResponse.ok) {
|
||||
throw new Error(`Failed to initialize claude-code provider: ${initResponse.statusText}`)
|
||||
}
|
||||
|
||||
// 发送 completions 请求
|
||||
const completionsResponse = await fetch('http://localhost:' + (await this.getClaudeCodePort()) + '/completions', {
|
||||
method: 'POST',
|
||||
headers: {
|
||||
'Content-Type': 'application/json'
|
||||
},
|
||||
body: JSON.stringify({
|
||||
modelId,
|
||||
params,
|
||||
options: this.config!.options
|
||||
})
|
||||
})
|
||||
|
||||
if (!completionsResponse.ok) {
|
||||
throw new Error(`Failed to get completions: ${completionsResponse.statusText}`)
|
||||
}
|
||||
|
||||
let finalText = ''
|
||||
|
||||
if (config.onChunk && completionsResponse.body) {
|
||||
// 创建 adapter 来处理 chunk 数据
|
||||
const accumulate = this.model!.supported_text_delta !== false
|
||||
const adapter = new AiSdkToChunkAdapter(config.onChunk, config.mcpTools, accumulate)
|
||||
await adapter.processChunk(completionsResponse.body)
|
||||
} else {
|
||||
finalText = await completionsResponse.text()
|
||||
}
|
||||
|
||||
return {
|
||||
getText: () => finalText
|
||||
}
|
||||
} catch (error) {
|
||||
logger.error('Error in claude-code HTTP service completions', error as Error)
|
||||
throw error
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取 Claude-code HTTP 服务端口
|
||||
*/
|
||||
private async getClaudeCodePort(): Promise<number> {
|
||||
return await window.api.provider.getClaudeCodePort()
|
||||
}
|
||||
|
||||
/**
|
||||
* 使用现代化AI SDK的completions实现
|
||||
*/
|
||||
private async modernCompletions(
|
||||
model: LanguageModel,
|
||||
params: StreamTextParams,
|
||||
config: ModernAiProviderConfig
|
||||
): Promise<CompletionsResult> {
|
||||
const modelId = this.model!.id
|
||||
logger.info('Starting modernCompletions', {
|
||||
modelId,
|
||||
providerId: this.config!.providerId,
|
||||
topicId: config.topicId,
|
||||
hasOnChunk: !!config.onChunk,
|
||||
hasTools: !!params.tools && Object.keys(params.tools).length > 0,
|
||||
toolCount: params.tools ? Object.keys(params.tools).length : 0
|
||||
})
|
||||
|
||||
// 根据条件构建插件数组
|
||||
const plugins = buildPlugins(config)
|
||||
|
||||
// 用构建好的插件数组创建executor
|
||||
const executor = createExecutor(this.config!.providerId, this.config!.options, plugins)
|
||||
|
||||
// 创建带有中间件的执行器
|
||||
if (config.onChunk) {
|
||||
const accumulate = this.model!.supported_text_delta !== false // true and undefined
|
||||
const adapter = new AiSdkToChunkAdapter(config.onChunk, config.mcpTools, accumulate)
|
||||
|
||||
const streamResult = await executor.streamText({
|
||||
...params,
|
||||
model,
|
||||
experimental_context: { onChunk: config.onChunk }
|
||||
})
|
||||
|
||||
const finalText = await adapter.processStream(streamResult)
|
||||
|
||||
return {
|
||||
getText: () => finalText
|
||||
}
|
||||
} else {
|
||||
const streamResult = await executor.streamText({
|
||||
...params,
|
||||
model
|
||||
})
|
||||
|
||||
// 强制消费流,不然await streamResult.text会阻塞
|
||||
await streamResult?.consumeStream()
|
||||
|
||||
const finalText = await streamResult.text
|
||||
|
||||
return {
|
||||
getText: () => finalText
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 使用现代化 AI SDK 的图像生成实现,支持流式输出
|
||||
* @deprecated 已改为使用 legacy 实现以支持图片编辑等高级功能
|
||||
*/
|
||||
/*
|
||||
private async modernImageGeneration(
|
||||
model: ImageModel,
|
||||
params: StreamTextParams,
|
||||
config: ModernAiProviderConfig
|
||||
): Promise<CompletionsResult> {
|
||||
const { onChunk } = config
|
||||
|
||||
try {
|
||||
// 检查 messages 是否存在
|
||||
if (!params.messages || params.messages.length === 0) {
|
||||
throw new Error('No messages provided for image generation.')
|
||||
}
|
||||
|
||||
// 从最后一条用户消息中提取 prompt
|
||||
const lastUserMessage = params.messages.findLast((m) => m.role === 'user')
|
||||
if (!lastUserMessage) {
|
||||
throw new Error('No user message found for image generation.')
|
||||
}
|
||||
|
||||
// 直接使用消息内容,避免类型转换问题
|
||||
const prompt =
|
||||
typeof lastUserMessage.content === 'string'
|
||||
? lastUserMessage.content
|
||||
: lastUserMessage.content?.map((part) => ('text' in part ? part.text : '')).join('') || ''
|
||||
|
||||
if (!prompt) {
|
||||
throw new Error('No prompt found in user message.')
|
||||
}
|
||||
|
||||
const startTime = Date.now()
|
||||
|
||||
// 发送图像生成开始事件
|
||||
if (onChunk) {
|
||||
onChunk({ type: ChunkType.IMAGE_CREATED })
|
||||
}
|
||||
|
||||
// 构建图像生成参数
|
||||
const imageParams = {
|
||||
prompt,
|
||||
size: isNotSupportedImageSizeModel(config.model) ? undefined : ('1024x1024' as `${number}x${number}`), // 默认尺寸,使用正确的类型
|
||||
n: 1,
|
||||
...(params.abortSignal && { abortSignal: params.abortSignal })
|
||||
}
|
||||
|
||||
// 调用新 AI SDK 的图像生成功能
|
||||
const executor = createExecutor(this.config!.providerId, this.config!.options, [])
|
||||
const result = await executor.generateImage({
|
||||
model,
|
||||
...imageParams
|
||||
})
|
||||
|
||||
// 转换结果格式
|
||||
const images: string[] = []
|
||||
const imageType: 'url' | 'base64' = 'base64'
|
||||
|
||||
if (result.images) {
|
||||
for (const image of result.images) {
|
||||
if ('base64' in image && image.base64) {
|
||||
images.push(`data:${image.mediaType};base64,${image.base64}`)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 发送图像生成完成事件
|
||||
if (onChunk && images.length > 0) {
|
||||
onChunk({
|
||||
type: ChunkType.IMAGE_COMPLETE,
|
||||
image: { type: imageType, images }
|
||||
})
|
||||
}
|
||||
|
||||
// 发送块完成事件(类似于 modernCompletions 的处理)
|
||||
if (onChunk) {
|
||||
const usage = {
|
||||
prompt_tokens: prompt.length, // 估算的 token 数量
|
||||
completion_tokens: 0, // 图像生成没有 completion tokens
|
||||
total_tokens: prompt.length
|
||||
}
|
||||
|
||||
onChunk({
|
||||
type: ChunkType.BLOCK_COMPLETE,
|
||||
response: {
|
||||
usage,
|
||||
metrics: {
|
||||
completion_tokens: usage.completion_tokens,
|
||||
time_first_token_millsec: 0,
|
||||
time_completion_millsec: Date.now() - startTime
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
// 发送 LLM 响应完成事件
|
||||
onChunk({
|
||||
type: ChunkType.LLM_RESPONSE_COMPLETE,
|
||||
response: {
|
||||
usage,
|
||||
metrics: {
|
||||
completion_tokens: usage.completion_tokens,
|
||||
time_first_token_millsec: 0,
|
||||
time_completion_millsec: Date.now() - startTime
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
return {
|
||||
getText: () => '' // 图像生成不返回文本
|
||||
}
|
||||
} catch (error) {
|
||||
// 发送错误事件
|
||||
if (onChunk) {
|
||||
onChunk({ type: ChunkType.ERROR, error: error as any })
|
||||
}
|
||||
throw error
|
||||
}
|
||||
}
|
||||
*/
|
||||
|
||||
  // Delegate the remaining operations to the legacy implementation.

  /** Lists the models available from this provider (legacy implementation). */
  public async models() {
    return this.legacyProvider.models()
  }

  /** Probes the embedding dimensions of the given model (legacy implementation). */
  public async getEmbeddingDimensions(model: Model): Promise<number> {
    return this.legacyProvider.getEmbeddingDimensions(model)
  }
|
||||
|
||||
public async generateImage(params: GenerateImageParams): Promise<string[]> {
|
||||
// 如果支持新的 AI SDK,使用现代化实现
|
||||
if (isModernSdkSupported(this.actualProvider)) {
|
||||
try {
|
||||
// 确保本地provider已创建
|
||||
if (!this.localProvider) {
|
||||
this.localProvider = await createAiSdkProvider(this.config)
|
||||
if (!this.localProvider) {
|
||||
throw new Error('Local provider not created')
|
||||
}
|
||||
}
|
||||
|
||||
const result = await this.modernGenerateImage(params)
|
||||
return result
|
||||
} catch (error) {
|
||||
logger.warn('Modern AI SDK generateImage failed, falling back to legacy:', error as Error)
|
||||
// fallback 到传统实现
|
||||
return this.legacyProvider.generateImage(params)
|
||||
}
|
||||
}
|
||||
|
||||
// 直接使用传统实现
|
||||
return this.legacyProvider.generateImage(params)
|
||||
}
|
||||
|
||||
/**
|
||||
* 使用现代化 AI SDK 的图像生成实现
|
||||
*/
|
||||
private async modernGenerateImage(params: GenerateImageParams): Promise<string[]> {
|
||||
const { model, prompt, imageSize, batchSize, signal } = params
|
||||
|
||||
// 转换参数格式
|
||||
const aiSdkParams = {
|
||||
prompt,
|
||||
size: (imageSize || '1024x1024') as `${number}x${number}`,
|
||||
n: batchSize || 1,
|
||||
...(signal && { abortSignal: signal })
|
||||
}
|
||||
|
||||
const executor = createExecutor(this.config!.providerId, this.config!.options, [])
|
||||
const result = await executor.generateImage({
|
||||
model: this.localProvider?.imageModel(model) as ImageModel,
|
||||
...aiSdkParams
|
||||
})
|
||||
|
||||
// 转换结果格式
|
||||
const images: string[] = []
|
||||
if (result.images) {
|
||||
for (const image of result.images) {
|
||||
if ('base64' in image && image.base64) {
|
||||
images.push(`data:image/png;base64,${image.base64}`)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return images
|
||||
}
|
||||
|
||||
  /** Base URL of the underlying provider (legacy implementation). */
  public getBaseURL(): string {
    return this.legacyProvider.getBaseURL()
  }

  /** API key of the underlying provider (legacy implementation). */
  public getApiKey(): string {
    return this.legacyProvider.getApiKey()
  }
|
||||
}
|
||||
|
||||
// 为了方便调试,导出一些工具函数
|
||||
export { isModernSdkSupported, providerToAiSdkConfig }
|
||||
@@ -75,6 +75,7 @@ export class ApiClientFactory {
|
||||
instance = new GeminiAPIClient(provider) as BaseApiClient
|
||||
break
|
||||
case 'vertexai':
|
||||
logger.debug(`Creating VertexAPIClient for provider: ${provider.id}`)
|
||||
instance = new VertexAPIClient(provider) as BaseApiClient
|
||||
break
|
||||
case 'anthropic':
|
||||
@@ -66,7 +66,8 @@ vi.mock('@renderer/config/models', () => ({
|
||||
SYSTEM_MODELS: {
|
||||
silicon: [],
|
||||
defaultModel: []
|
||||
}
|
||||
},
|
||||
isOpenAIModel: vi.fn(() => false)
|
||||
}))
|
||||
|
||||
describe('ApiClientFactory', () => {
|
||||
@@ -1,11 +1,11 @@
|
||||
import { AihubmixAPIClient } from '@renderer/aiCore/clients/aihubmix/AihubmixAPIClient'
|
||||
import { AnthropicAPIClient } from '@renderer/aiCore/clients/anthropic/AnthropicAPIClient'
|
||||
import { ApiClientFactory } from '@renderer/aiCore/clients/ApiClientFactory'
|
||||
import { GeminiAPIClient } from '@renderer/aiCore/clients/gemini/GeminiAPIClient'
|
||||
import { VertexAPIClient } from '@renderer/aiCore/clients/gemini/VertexAPIClient'
|
||||
import { NewAPIClient } from '@renderer/aiCore/clients/newapi/NewAPIClient'
|
||||
import { OpenAIAPIClient } from '@renderer/aiCore/clients/openai/OpenAIApiClient'
|
||||
import { OpenAIResponseAPIClient } from '@renderer/aiCore/clients/openai/OpenAIResponseAPIClient'
|
||||
import { AihubmixAPIClient } from '@renderer/aiCore/legacy/clients/aihubmix/AihubmixAPIClient'
|
||||
import { AnthropicAPIClient } from '@renderer/aiCore/legacy/clients/anthropic/AnthropicAPIClient'
|
||||
import { ApiClientFactory } from '@renderer/aiCore/legacy/clients/ApiClientFactory'
|
||||
import { GeminiAPIClient } from '@renderer/aiCore/legacy/clients/gemini/GeminiAPIClient'
|
||||
import { VertexAPIClient } from '@renderer/aiCore/legacy/clients/gemini/VertexAPIClient'
|
||||
import { NewAPIClient } from '@renderer/aiCore/legacy/clients/newapi/NewAPIClient'
|
||||
import { OpenAIAPIClient } from '@renderer/aiCore/legacy/clients/openai/OpenAIApiClient'
|
||||
import { OpenAIResponseAPIClient } from '@renderer/aiCore/legacy/clients/openai/OpenAIResponseAPIClient'
|
||||
import { EndpointType, Model, Provider } from '@renderer/types'
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
@@ -22,6 +22,7 @@ vi.mock('@renderer/config/models', () => ({
|
||||
anthropic: [],
|
||||
gemini: []
|
||||
},
|
||||
isOpenAIModel: vi.fn().mockReturnValue(true),
|
||||
isOpenAILLMModel: vi.fn().mockReturnValue(true),
|
||||
isOpenAIChatCompletionOnlyModel: vi.fn().mockReturnValue(false),
|
||||
isAnthropicLLMModel: vi.fn().mockReturnValue(false),
|
||||
@@ -80,6 +81,7 @@ vi.mock('@logger', () => ({
|
||||
}
|
||||
}))
|
||||
|
||||
// 到底是谁想出来的在服务层调用 React Hook ?????????
|
||||
// Mock additional services and hooks that might be imported
|
||||
vi.mock('@renderer/hooks/useVertexAI', () => ({
|
||||
getVertexAILocation: vi.fn().mockReturnValue('us-central1'),
|
||||
@@ -87,7 +89,9 @@ vi.mock('@renderer/hooks/useVertexAI', () => ({
|
||||
getVertexAIServiceAccount: vi.fn().mockReturnValue({
|
||||
privateKey: 'test-key',
|
||||
clientEmail: 'test@example.com'
|
||||
})
|
||||
}),
|
||||
isVertexAIConfigured: vi.fn().mockReturnValue(true),
|
||||
isVertexProvider: vi.fn().mockReturnValue(true)
|
||||
}))
|
||||
|
||||
vi.mock('@renderer/hooks/useSettings', () => ({
|
||||
@@ -131,7 +135,7 @@ vi.mock('@google-cloud/vertexai', () => ({
|
||||
}))
|
||||
|
||||
// Mock the circular dependency between VertexAPIClient and AnthropicVertexClient
|
||||
vi.mock('@renderer/aiCore/clients/anthropic/AnthropicVertexClient', () => {
|
||||
vi.mock('@renderer/aiCore/legacy/clients/anthropic/AnthropicVertexClient', () => {
|
||||
const MockAnthropicVertexClient = vi.fn()
|
||||
MockAnthropicVertexClient.prototype.getClientCompatibilityType = vi.fn().mockReturnValue(['AnthropicVertexAPIClient'])
|
||||
return {
|
||||
@@ -25,7 +25,6 @@ import {
|
||||
import { MessageStream } from '@anthropic-ai/sdk/resources/messages/messages'
|
||||
import AnthropicVertex from '@anthropic-ai/vertex-sdk'
|
||||
import { loggerService } from '@logger'
|
||||
import { GenericChunk } from '@renderer/aiCore/middleware/schemas'
|
||||
import { DEFAULT_MAX_TOKENS } from '@renderer/config/constant'
|
||||
import { findTokenLimit, isClaudeReasoningModel, isReasoningModel, isWebSearchModel } from '@renderer/config/models'
|
||||
import { getAssistantSettings } from '@renderer/services/AssistantService'
|
||||
@@ -64,13 +63,14 @@ import {
|
||||
import { addImageFileToContents } from '@renderer/utils/formats'
|
||||
import {
|
||||
anthropicToolUseToMcpTool,
|
||||
isEnabledToolUse,
|
||||
isSupportedToolUse,
|
||||
mcpToolCallResponseToAnthropicMessage,
|
||||
mcpToolsToAnthropicTools
|
||||
} from '@renderer/utils/mcp-tools'
|
||||
import { findFileBlocks, findImageBlocks } from '@renderer/utils/messageUtils/find'
|
||||
import { t } from 'i18next'
|
||||
|
||||
import { GenericChunk } from '../../middleware/schemas'
|
||||
import { BaseApiClient } from '../BaseApiClient'
|
||||
import { AnthropicStreamListener, RawStreamListener, RequestTransformer, ResponseChunkTransformer } from '../types'
|
||||
|
||||
@@ -457,7 +457,7 @@ export class AnthropicAPIClient extends BaseApiClient<
|
||||
const { tools } = this.setupToolsConfig({
|
||||
mcpTools: mcpTools,
|
||||
model,
|
||||
enableToolUse: isEnabledToolUse(assistant)
|
||||
enableToolUse: isSupportedToolUse(assistant)
|
||||
})
|
||||
|
||||
const systemMessage: TextBlockParam | undefined = systemPrompt
|
||||
@@ -6,7 +6,7 @@ import {
|
||||
InvokeModelWithResponseStreamCommand
|
||||
} from '@aws-sdk/client-bedrock-runtime'
|
||||
import { loggerService } from '@logger'
|
||||
import { GenericChunk } from '@renderer/aiCore/middleware/schemas'
|
||||
import { GenericChunk } from '@renderer/aiCore/legacy/middleware/schemas'
|
||||
import { DEFAULT_MAX_TOKENS } from '@renderer/config/constant'
|
||||
import { findTokenLimit, isReasoningModel } from '@renderer/config/models'
|
||||
import {
|
||||
@@ -50,7 +50,7 @@ import {
|
||||
import { convertBase64ImageToAwsBedrockFormat } from '@renderer/utils/aws-bedrock-utils'
|
||||
import {
|
||||
awsBedrockToolUseToMcpTool,
|
||||
isEnabledToolUse,
|
||||
isSupportedToolUse,
|
||||
mcpToolCallResponseToAwsBedrockMessage,
|
||||
mcpToolsToAwsBedrockTools
|
||||
} from '@renderer/utils/mcp-tools'
|
||||
@@ -739,7 +739,7 @@ export class AwsBedrockAPIClient extends BaseApiClient<
|
||||
const { tools } = this.setupToolsConfig({
|
||||
mcpTools: mcpTools,
|
||||
model,
|
||||
enableToolUse: isEnabledToolUse(assistant)
|
||||
enableToolUse: isSupportedToolUse(assistant)
|
||||
})
|
||||
|
||||
// 3. 处理消息
|
||||
@@ -18,7 +18,6 @@ import {
|
||||
} from '@google/genai'
|
||||
import { loggerService } from '@logger'
|
||||
import { nanoid } from '@reduxjs/toolkit'
|
||||
import { GenericChunk } from '@renderer/aiCore/middleware/schemas'
|
||||
import {
|
||||
findTokenLimit,
|
||||
GEMINI_FLASH_MODEL_REGEX,
|
||||
@@ -55,7 +54,7 @@ import {
|
||||
import { isToolUseModeFunction } from '@renderer/utils/assistant'
|
||||
import {
|
||||
geminiFunctionCallToMcpTool,
|
||||
isEnabledToolUse,
|
||||
isSupportedToolUse,
|
||||
mcpToolCallResponseToGeminiMessage,
|
||||
mcpToolsToGeminiTools
|
||||
} from '@renderer/utils/mcp-tools'
|
||||
@@ -63,6 +62,7 @@ import { findFileBlocks, findImageBlocks, getMainTextContent } from '@renderer/u
|
||||
import { defaultTimeout, MB } from '@shared/config/constant'
|
||||
import { t } from 'i18next'
|
||||
|
||||
import { GenericChunk } from '../../middleware/schemas'
|
||||
import { BaseApiClient } from '../BaseApiClient'
|
||||
import { RequestTransformer, ResponseChunkTransformer } from '../types'
|
||||
|
||||
@@ -454,7 +454,7 @@ export class GeminiAPIClient extends BaseApiClient<
|
||||
const { tools } = this.setupToolsConfig({
|
||||
mcpTools,
|
||||
model,
|
||||
enableToolUse: isEnabledToolUse(assistant)
|
||||
enableToolUse: isSupportedToolUse(assistant)
|
||||
})
|
||||
|
||||
let messageContents: Content = { role: 'user', parts: [] } // Initialize messageContents
|
||||
@@ -1,7 +1,7 @@
|
||||
import { GoogleGenAI } from '@google/genai'
|
||||
import { loggerService } from '@logger'
|
||||
import { getVertexAILocation, getVertexAIProjectId, getVertexAIServiceAccount } from '@renderer/hooks/useVertexAI'
|
||||
import { Model, Provider } from '@renderer/types'
|
||||
import { createVertexProvider, isVertexAIConfigured, isVertexProvider } from '@renderer/hooks/useVertexAI'
|
||||
import { Model, Provider, VertexProvider } from '@renderer/types'
|
||||
import { isEmpty } from 'lodash'
|
||||
|
||||
import { AnthropicVertexClient } from '../anthropic/AnthropicVertexClient'
|
||||
@@ -12,10 +12,21 @@ export class VertexAPIClient extends GeminiAPIClient {
|
||||
private authHeaders?: Record<string, string>
|
||||
private authHeadersExpiry?: number
|
||||
private anthropicVertexClient: AnthropicVertexClient
|
||||
private vertexProvider: VertexProvider
|
||||
|
||||
constructor(provider: Provider) {
|
||||
super(provider)
|
||||
// 检查 VertexAI 配置
|
||||
if (!isVertexAIConfigured()) {
|
||||
throw new Error('VertexAI is not configured. Please configure project, location and service account credentials.')
|
||||
}
|
||||
this.anthropicVertexClient = new AnthropicVertexClient(provider)
|
||||
// 如果传入的是普通 Provider,转换为 VertexProvider
|
||||
if (isVertexProvider(provider)) {
|
||||
this.vertexProvider = provider
|
||||
} else {
|
||||
this.vertexProvider = createVertexProvider(provider)
|
||||
}
|
||||
}
|
||||
|
||||
override getClientCompatibilityType(model?: Model): string[] {
|
||||
@@ -56,11 +67,9 @@ export class VertexAPIClient extends GeminiAPIClient {
|
||||
return this.sdkInstance
|
||||
}
|
||||
|
||||
const serviceAccount = getVertexAIServiceAccount()
|
||||
const projectId = getVertexAIProjectId()
|
||||
const location = getVertexAILocation()
|
||||
const { googleCredentials, project, location } = this.vertexProvider
|
||||
|
||||
if (!serviceAccount.privateKey || !serviceAccount.clientEmail || !projectId || !location) {
|
||||
if (!googleCredentials.privateKey || !googleCredentials.clientEmail || !project || !location) {
|
||||
throw new Error('Vertex AI settings are not configured')
|
||||
}
|
||||
|
||||
@@ -68,7 +77,7 @@ export class VertexAPIClient extends GeminiAPIClient {
|
||||
|
||||
this.sdkInstance = new GoogleGenAI({
|
||||
vertexai: true,
|
||||
project: projectId,
|
||||
project: project,
|
||||
location: location,
|
||||
httpOptions: {
|
||||
apiVersion: this.getApiVersion(),
|
||||
@@ -84,11 +93,10 @@ export class VertexAPIClient extends GeminiAPIClient {
|
||||
* 获取认证头,如果配置了 service account 则从主进程获取
|
||||
*/
|
||||
private async getServiceAccountAuthHeaders(): Promise<Record<string, string> | undefined> {
|
||||
const serviceAccount = getVertexAIServiceAccount()
|
||||
const projectId = getVertexAIProjectId()
|
||||
const { googleCredentials, project } = this.vertexProvider
|
||||
|
||||
// 检查是否配置了 service account
|
||||
if (!serviceAccount.privateKey || !serviceAccount.clientEmail || !projectId) {
|
||||
if (!googleCredentials.privateKey || !googleCredentials.clientEmail || !project) {
|
||||
return undefined
|
||||
}
|
||||
|
||||
@@ -101,10 +109,10 @@ export class VertexAPIClient extends GeminiAPIClient {
|
||||
try {
|
||||
// 从主进程获取认证头
|
||||
this.authHeaders = await window.api.vertexAI.getAuthHeaders({
|
||||
projectId,
|
||||
projectId: project,
|
||||
serviceAccount: {
|
||||
privateKey: serviceAccount.privateKey,
|
||||
clientEmail: serviceAccount.clientEmail
|
||||
privateKey: googleCredentials.privateKey,
|
||||
clientEmail: googleCredentials.clientEmail
|
||||
}
|
||||
})
|
||||
|
||||
@@ -125,11 +133,10 @@ export class VertexAPIClient extends GeminiAPIClient {
|
||||
this.authHeaders = undefined
|
||||
this.authHeadersExpiry = undefined
|
||||
|
||||
const serviceAccount = getVertexAIServiceAccount()
|
||||
const projectId = getVertexAIProjectId()
|
||||
const { googleCredentials, project } = this.vertexProvider
|
||||
|
||||
if (projectId && serviceAccount.clientEmail) {
|
||||
window.api.vertexAI.clearAuthCache(projectId, serviceAccount.clientEmail)
|
||||
if (project && googleCredentials.clientEmail) {
|
||||
window.api.vertexAI.clearAuthCache(project, googleCredentials.clientEmail)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -71,7 +71,7 @@ import {
|
||||
} from '@renderer/types/sdk'
|
||||
import { addImageFileToContents } from '@renderer/utils/formats'
|
||||
import {
|
||||
isEnabledToolUse,
|
||||
isSupportedToolUse,
|
||||
mcpToolCallResponseToOpenAICompatibleMessage,
|
||||
mcpToolsToOpenAIChatTools,
|
||||
openAIToolsToMcpTool
|
||||
@@ -611,7 +611,7 @@ export class OpenAIAPIClient extends OpenAIBaseClient<
|
||||
const { tools } = this.setupToolsConfig({
|
||||
mcpTools: mcpTools,
|
||||
model,
|
||||
enableToolUse: isEnabledToolUse(assistant)
|
||||
enableToolUse: isSupportedToolUse(assistant)
|
||||
})
|
||||
|
||||
// 3. 处理用户消息
|
||||
@@ -1,6 +1,6 @@
|
||||
import { loggerService } from '@logger'
|
||||
import { GenericChunk } from '@renderer/aiCore/middleware/schemas'
|
||||
import { CompletionsContext } from '@renderer/aiCore/middleware/types'
|
||||
import { GenericChunk } from '@renderer/aiCore/legacy/middleware/schemas'
|
||||
import { CompletionsContext } from '@renderer/aiCore/legacy/middleware/types'
|
||||
import {
|
||||
isGPT5SeriesModel,
|
||||
isOpenAIChatCompletionOnlyModel,
|
||||
@@ -36,7 +36,7 @@ import {
|
||||
} from '@renderer/types/sdk'
|
||||
import { addImageFileToContents } from '@renderer/utils/formats'
|
||||
import {
|
||||
isEnabledToolUse,
|
||||
isSupportedToolUse,
|
||||
mcpToolCallResponseToOpenAIMessage,
|
||||
mcpToolsToOpenAIResponseTools,
|
||||
openAIToolsToMcpTool
|
||||
@@ -388,7 +388,7 @@ export class OpenAIResponseAPIClient extends OpenAIBaseClient<
|
||||
const { tools: extraTools } = this.setupToolsConfig({
|
||||
mcpTools: mcpTools,
|
||||
model,
|
||||
enableToolUse: isEnabledToolUse(assistant)
|
||||
enableToolUse: isSupportedToolUse(assistant)
|
||||
})
|
||||
|
||||
systemMessageContent.push(systemMessageInput)
|
||||
189
src/renderer/src/aiCore/legacy/index.ts
Normal file
189
src/renderer/src/aiCore/legacy/index.ts
Normal file
@@ -0,0 +1,189 @@
|
||||
import { loggerService } from '@logger'
|
||||
import { ApiClientFactory } from '@renderer/aiCore/legacy/clients/ApiClientFactory'
|
||||
import { BaseApiClient } from '@renderer/aiCore/legacy/clients/BaseApiClient'
|
||||
import { isDedicatedImageGenerationModel, isFunctionCallingModel } from '@renderer/config/models'
|
||||
import { getProviderByModel } from '@renderer/services/AssistantService'
|
||||
import { withSpanResult } from '@renderer/services/SpanManagerService'
|
||||
import { StartSpanParams } from '@renderer/trace/types/ModelSpanEntity'
|
||||
import type { GenerateImageParams, Model, Provider } from '@renderer/types'
|
||||
import type { RequestOptions, SdkModel } from '@renderer/types/sdk'
|
||||
import { isSupportedToolUse } from '@renderer/utils/mcp-tools'
|
||||
|
||||
import { AihubmixAPIClient } from './clients/aihubmix/AihubmixAPIClient'
|
||||
import { VertexAPIClient } from './clients/gemini/VertexAPIClient'
|
||||
import { NewAPIClient } from './clients/newapi/NewAPIClient'
|
||||
import { OpenAIResponseAPIClient } from './clients/openai/OpenAIResponseAPIClient'
|
||||
import { CompletionsMiddlewareBuilder } from './middleware/builder'
|
||||
import { MIDDLEWARE_NAME as AbortHandlerMiddlewareName } from './middleware/common/AbortHandlerMiddleware'
|
||||
import { MIDDLEWARE_NAME as ErrorHandlerMiddlewareName } from './middleware/common/ErrorHandlerMiddleware'
|
||||
import { MIDDLEWARE_NAME as FinalChunkConsumerMiddlewareName } from './middleware/common/FinalChunkConsumerMiddleware'
|
||||
import { applyCompletionsMiddlewares } from './middleware/composer'
|
||||
import { MIDDLEWARE_NAME as McpToolChunkMiddlewareName } from './middleware/core/McpToolChunkMiddleware'
|
||||
import { MIDDLEWARE_NAME as RawStreamListenerMiddlewareName } from './middleware/core/RawStreamListenerMiddleware'
|
||||
import { MIDDLEWARE_NAME as WebSearchMiddlewareName } from './middleware/core/WebSearchMiddleware'
|
||||
import { MIDDLEWARE_NAME as ImageGenerationMiddlewareName } from './middleware/feat/ImageGenerationMiddleware'
|
||||
import { MIDDLEWARE_NAME as ThinkingTagExtractionMiddlewareName } from './middleware/feat/ThinkingTagExtractionMiddleware'
|
||||
import { MIDDLEWARE_NAME as ToolUseExtractionMiddlewareName } from './middleware/feat/ToolUseExtractionMiddleware'
|
||||
import { MiddlewareRegistry } from './middleware/register'
|
||||
import type { CompletionsParams, CompletionsResult } from './middleware/schemas'
|
||||
|
||||
const logger = loggerService.withContext('AiProvider')
|
||||
|
||||
export default class AiProvider {
|
||||
private apiClient: BaseApiClient
|
||||
|
||||
constructor(provider: Provider) {
|
||||
// Use the new ApiClientFactory to get a BaseApiClient instance
|
||||
this.apiClient = ApiClientFactory.create(provider)
|
||||
}
|
||||
|
||||
public async completions(params: CompletionsParams, options?: RequestOptions): Promise<CompletionsResult> {
|
||||
// 1. 根据模型识别正确的客户端
|
||||
const model = params.assistant.model
|
||||
if (!model) {
|
||||
return Promise.reject(new Error('Model is required'))
|
||||
}
|
||||
|
||||
// 根据client类型选择合适的处理方式
|
||||
let client: BaseApiClient
|
||||
|
||||
if (this.apiClient instanceof AihubmixAPIClient) {
|
||||
// AihubmixAPIClient: 根据模型选择合适的子client
|
||||
client = this.apiClient.getClientForModel(model)
|
||||
if (client instanceof OpenAIResponseAPIClient) {
|
||||
client = client.getClient(model) as BaseApiClient
|
||||
}
|
||||
} else if (this.apiClient instanceof NewAPIClient) {
|
||||
client = this.apiClient.getClientForModel(model)
|
||||
if (client instanceof OpenAIResponseAPIClient) {
|
||||
client = client.getClient(model) as BaseApiClient
|
||||
}
|
||||
} else if (this.apiClient instanceof OpenAIResponseAPIClient) {
|
||||
// OpenAIResponseAPIClient: 根据模型特征选择API类型
|
||||
client = this.apiClient.getClient(model) as BaseApiClient
|
||||
} else if (this.apiClient instanceof VertexAPIClient) {
|
||||
client = this.apiClient.getClient(model) as BaseApiClient
|
||||
} else {
|
||||
// 其他client直接使用
|
||||
client = this.apiClient
|
||||
}
|
||||
|
||||
// 2. 构建中间件链
|
||||
const builder = CompletionsMiddlewareBuilder.withDefaults()
|
||||
// images api
|
||||
if (isDedicatedImageGenerationModel(model)) {
|
||||
builder.clear()
|
||||
builder
|
||||
.add(MiddlewareRegistry[FinalChunkConsumerMiddlewareName])
|
||||
.add(MiddlewareRegistry[ErrorHandlerMiddlewareName])
|
||||
.add(MiddlewareRegistry[AbortHandlerMiddlewareName])
|
||||
.add(MiddlewareRegistry[ImageGenerationMiddlewareName])
|
||||
} else {
|
||||
// Existing logic for other models
|
||||
logger.silly('Builder Params', params)
|
||||
// 使用兼容性类型检查,避免typescript类型收窄和装饰器模式的问题
|
||||
const clientTypes = client.getClientCompatibilityType(model)
|
||||
const isOpenAICompatible =
|
||||
clientTypes.includes('OpenAIAPIClient') || clientTypes.includes('OpenAIResponseAPIClient')
|
||||
if (!isOpenAICompatible) {
|
||||
logger.silly('ThinkingTagExtractionMiddleware is removed')
|
||||
builder.remove(ThinkingTagExtractionMiddlewareName)
|
||||
}
|
||||
|
||||
const isAnthropicOrOpenAIResponseCompatible =
|
||||
clientTypes.includes('AnthropicAPIClient') ||
|
||||
clientTypes.includes('OpenAIResponseAPIClient') ||
|
||||
clientTypes.includes('AnthropicVertexAPIClient')
|
||||
if (!isAnthropicOrOpenAIResponseCompatible) {
|
||||
logger.silly('RawStreamListenerMiddleware is removed')
|
||||
builder.remove(RawStreamListenerMiddlewareName)
|
||||
}
|
||||
if (!params.enableWebSearch) {
|
||||
logger.silly('WebSearchMiddleware is removed')
|
||||
builder.remove(WebSearchMiddlewareName)
|
||||
}
|
||||
if (!params.mcpTools?.length) {
|
||||
builder.remove(ToolUseExtractionMiddlewareName)
|
||||
logger.silly('ToolUseExtractionMiddleware is removed')
|
||||
builder.remove(McpToolChunkMiddlewareName)
|
||||
logger.silly('McpToolChunkMiddleware is removed')
|
||||
}
|
||||
if (isSupportedToolUse(params.assistant) && isFunctionCallingModel(model)) {
|
||||
builder.remove(ToolUseExtractionMiddlewareName)
|
||||
logger.silly('ToolUseExtractionMiddleware is removed')
|
||||
}
|
||||
if (params.callType !== 'chat' && params.callType !== 'check' && params.callType !== 'translate') {
|
||||
logger.silly('AbortHandlerMiddleware is removed')
|
||||
builder.remove(AbortHandlerMiddlewareName)
|
||||
}
|
||||
if (params.callType === 'test') {
|
||||
builder.remove(ErrorHandlerMiddlewareName)
|
||||
logger.silly('ErrorHandlerMiddleware is removed')
|
||||
builder.remove(FinalChunkConsumerMiddlewareName)
|
||||
logger.silly('FinalChunkConsumerMiddleware is removed')
|
||||
}
|
||||
}
|
||||
|
||||
const middlewares = builder.build()
|
||||
logger.silly(
|
||||
'middlewares',
|
||||
middlewares.map((m) => m.name)
|
||||
)
|
||||
|
||||
// 3. Create the wrapped SDK method with middlewares
|
||||
const wrappedCompletionMethod = applyCompletionsMiddlewares(client, client.createCompletions, middlewares)
|
||||
|
||||
// 4. Execute the wrapped method with the original params
|
||||
const result = wrappedCompletionMethod(params, options)
|
||||
return result
|
||||
}
|
||||
|
||||
public async completionsForTrace(params: CompletionsParams, options?: RequestOptions): Promise<CompletionsResult> {
|
||||
const traceName = params.assistant.model?.name
|
||||
? `${params.assistant.model?.name}.${params.callType}`
|
||||
: `LLM.${params.callType}`
|
||||
|
||||
const traceParams: StartSpanParams = {
|
||||
name: traceName,
|
||||
tag: 'LLM',
|
||||
topicId: params.topicId || '',
|
||||
modelName: params.assistant.model?.name
|
||||
}
|
||||
|
||||
return await withSpanResult(this.completions.bind(this), traceParams, params, options)
|
||||
}
|
||||
|
||||
public async models(): Promise<SdkModel[]> {
|
||||
return this.apiClient.listModels()
|
||||
}
|
||||
|
||||
public async getEmbeddingDimensions(model: Model): Promise<number> {
|
||||
try {
|
||||
// Use the SDK instance to test embedding capabilities
|
||||
if (this.apiClient instanceof OpenAIResponseAPIClient && getProviderByModel(model).type === 'azure-openai') {
|
||||
this.apiClient = this.apiClient.getClient(model) as BaseApiClient
|
||||
}
|
||||
const dimensions = await this.apiClient.getEmbeddingDimensions(model)
|
||||
return dimensions
|
||||
} catch (error) {
|
||||
logger.error('Error getting embedding dimensions:', error as Error)
|
||||
throw error
|
||||
}
|
||||
}
|
||||
|
||||
public async generateImage(params: GenerateImageParams): Promise<string[]> {
|
||||
if (this.apiClient instanceof AihubmixAPIClient) {
|
||||
const client = this.apiClient.getClientForModel({ id: params.model } as Model)
|
||||
return client.generateImage(params)
|
||||
}
|
||||
return this.apiClient.generateImage(params)
|
||||
}
|
||||
|
||||
public getBaseURL(): string {
|
||||
return this.apiClient.getBaseURL()
|
||||
}
|
||||
|
||||
public getApiKey(): string {
|
||||
return this.apiClient.getApiKey()
|
||||
}
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user