Compare commits
160 Commits
main
...
v1.6.0-bet
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
9551c49452 | ||
|
|
e312c84a0e | ||
|
|
3d0fb97475 | ||
|
|
d10ba04047 | ||
|
|
4b7023f855 | ||
|
|
5f096ecf8c | ||
|
|
a12d627b65 | ||
|
|
7bda658022 | ||
|
|
bfcb215c16 | ||
|
|
1b3fcb2e55 | ||
|
|
9c01e24317 | ||
|
|
2ce9314a10 | ||
|
|
0c7e221b4e | ||
|
|
82d4637c9d | ||
|
|
84eef25ff9 | ||
|
|
53dcda6942 | ||
|
|
ead0e22c60 | ||
|
|
417f90df3b | ||
|
|
65c15c6d87 | ||
|
|
ca4e7e3d2b | ||
|
|
d34b640807 | ||
|
|
aa9ed3b9c8 | ||
|
|
d4da7d817d | ||
|
|
179b7af9bd | ||
|
|
5d0ab0a9a1 | ||
|
|
d93a36e5c9 | ||
|
|
c9c0616c91 | ||
|
|
356443babf | ||
|
|
b2c512082f | ||
|
|
0273c58050 | ||
|
|
239c849890 | ||
|
|
bbc472c169 | ||
|
|
b099c9b0b3 | ||
|
|
628919b562 | ||
|
|
d05ed94702 | ||
|
|
02f8e7a857 | ||
|
|
6c093f72d8 | ||
|
|
0bb1001d40 | ||
|
|
376020b23c | ||
|
|
3630133efd | ||
|
|
cb55f7a69b | ||
|
|
ff7ad52ad5 | ||
|
|
bf02afa841 | ||
|
|
e8b059c4db | ||
|
|
abfec7a228 | ||
|
|
eeafb99059 | ||
|
|
1ea8266280 | ||
|
|
def685921c | ||
|
|
a8dbae1715 | ||
|
|
71959f577d | ||
|
|
ecc08bd3f7 | ||
|
|
7216e9943c | ||
|
|
a05d7cbe2d | ||
|
|
0310648445 | ||
|
|
33db455e32 | ||
|
|
e690da840c | ||
|
|
eca9442907 | ||
|
|
4b62384fc5 | ||
|
|
addd5ffdfa | ||
|
|
fcc8836c95 | ||
|
|
61e3309cd2 | ||
|
|
786bc8dca9 | ||
|
|
c3a6456499 | ||
|
|
ef6be4a6f9 | ||
|
|
69e87ce21a | ||
|
|
608943bdbc | ||
|
|
1248e3c49a | ||
|
|
c3ad18b77e | ||
|
|
0bc5e3d24d | ||
|
|
36e20d545b | ||
|
|
45405213fc | ||
|
|
b83837708b | ||
|
|
4732c8f1bd | ||
|
|
ef8cf65ece | ||
|
|
e3c5c87e1b | ||
|
|
e7d5626055 | ||
|
|
650650a68f | ||
|
|
f38e4a87b8 | ||
|
|
a356492d6f | ||
|
|
8863e10df1 | ||
|
|
42bfa281a7 | ||
|
|
e7b4f1f934 | ||
|
|
0456094512 | ||
|
|
da455997ad | ||
|
|
0c4e8228af | ||
|
|
16e0154200 | ||
|
|
3ab904e789 | ||
|
|
42c7ebd193 | ||
|
|
a0623f2187 | ||
|
|
4bfff85dc8 | ||
|
|
8317ad55e7 | ||
|
|
b67cd9d145 | ||
|
|
234514d736 | ||
|
|
450d6228d4 | ||
|
|
3c955e69f1 | ||
|
|
4573e3f48f | ||
|
|
56c5e5a80f | ||
|
|
bb520910bc | ||
|
|
342c5ab82c | ||
|
|
fce8f2411c | ||
|
|
0a908a334b | ||
|
|
c72156b2da | ||
|
|
9e252d7eb0 | ||
|
|
4b0d8d7e65 | ||
|
|
448b5b5c9e | ||
|
|
f20d964be3 | ||
|
|
c92475b6bf | ||
|
|
89cbf80008 | ||
|
|
3e5969b97c | ||
|
|
cd42410d70 | ||
|
|
547e5785c0 | ||
|
|
13162edcb2 | ||
|
|
ac15930692 | ||
|
|
ff3b1fc38f | ||
|
|
b660e9d524 | ||
|
|
182ab6092c | ||
|
|
cf5ed8e858 | ||
|
|
007de81928 | ||
|
|
6c87b42607 | ||
|
|
592a7ddc3f | ||
|
|
60cb198f44 | ||
|
|
54c36040af | ||
|
|
ef616e1c3b | ||
|
|
dc106a8af7 | ||
|
|
1bcc716eaf | ||
|
|
30a288ce5d | ||
|
|
cbbaa3127c | ||
|
|
f61da8c2d6 | ||
|
|
d9eb9e86fe | ||
|
|
87f803b0d3 | ||
|
|
c934b45c09 | ||
|
|
ba121d04b4 | ||
|
|
9293f26612 | ||
|
|
8b67a45804 | ||
|
|
f23a026a28 | ||
|
|
e4c0ea035f | ||
|
|
7d8ed3a737 | ||
|
|
2a588fdab2 | ||
|
|
f08c444ffb | ||
|
|
f6c3794ac9 | ||
|
|
ebe85ba24a | ||
|
|
09080f0755 | ||
|
|
e421b81fca | ||
|
|
2f58b3360e | ||
|
|
f934b479b2 | ||
|
|
8ca6341609 | ||
|
|
c99a2fedb7 | ||
|
|
456e6c068e | ||
|
|
f206d4ec4c | ||
|
|
1af8be8768 | ||
|
|
e70174817e | ||
|
|
c5cb443de0 | ||
|
|
9318d9ffeb | ||
|
|
3771b24b52 | ||
|
|
1bccfd3170 | ||
|
|
43d55b7e45 | ||
|
|
1c5a30cf49 | ||
|
|
2df1cddb43 | ||
|
|
ed2363e561 | ||
|
|
a27d1bf506 |
@@ -7,3 +7,4 @@ tsconfig.*.json
|
||||
CHANGELOG*.md
|
||||
agents.json
|
||||
src/renderer/src/integration/nutstore/sso/lib
|
||||
AGENT.md
|
||||
|
||||
@@ -116,11 +116,37 @@ afterSign: scripts/notarize.js
|
||||
artifactBuildCompleted: scripts/artifact-build-completed.js
|
||||
releaseInfo:
|
||||
releaseNotes: |
|
||||
输入框快捷菜单增加清除按钮
|
||||
侧边栏增加代码工具入口,代码工具增加环境变量设置
|
||||
小程序增加多语言显示
|
||||
优化 MCP 服务器列表
|
||||
新增 Web 搜索图标
|
||||
优化 SVG 预览,优化 HTML 内容样式
|
||||
修复知识库文档预处理失败问题
|
||||
稳定性改进和错误修复
|
||||
🎉 新增功能:
|
||||
- 集成全新 AI SDK 架构,提升 AI 提供商管理和工具调用性能
|
||||
- 新增 OCR 图像文字识别功能,支持图片转文字和翻译
|
||||
- 新增代码工具页面,支持环境变量设置和命令行工具管理
|
||||
- 新增内置 Web 搜索工具和记忆搜索工具
|
||||
- 新增翻译历史搜索和收藏功能
|
||||
- 新增模型筛选和推理缓存功能
|
||||
- 新增快速模型检测和语言检测功能
|
||||
- 支持多种新模型:Qwen Flash、DeepSeek v3.1、Vertex AI 等
|
||||
|
||||
🔧 优化改进:
|
||||
- 优化 MCP 服务器列表,新增搜索功能
|
||||
- 优化 SVG 预览和 HTML 内容样式渲染
|
||||
- 优化代码块识别和语言别名支持
|
||||
- 优化拖拽列表组件,移除 antd 依赖
|
||||
- 优化选择工具栏样式和验证逻辑
|
||||
- 优化历史页面搜索性能和加载状态
|
||||
- 优化代理管理器的规则匹配和日志记录
|
||||
- 优化翻译服务,减少 token 消耗
|
||||
|
||||
🐛 问题修复:
|
||||
- 修复知识库文档预处理失败问题
|
||||
- 修复 Windows 平台代码工具路径空格处理
|
||||
- 修复选择文本范围验证和空字符串处理
|
||||
- 修复 Web 搜索引用丢失问题
|
||||
- 修复全屏模式意外退出问题
|
||||
- 修复各种模型 API 兼容性问题
|
||||
- 修复文件上传和附件预览相关问题
|
||||
- 修复多语言翻译格式和显示问题
|
||||
|
||||
⚡ 性能提升:
|
||||
- 优化消息处理和工具调用性能
|
||||
- 改进错误边界和异常处理
|
||||
- 优化内存使用和避免内存泄漏
|
||||
|
||||
@@ -81,7 +81,10 @@ export default defineConfig({
|
||||
'@shared': resolve('packages/shared'),
|
||||
'@logger': resolve('src/renderer/src/services/LoggerService'),
|
||||
'@mcp-trace/trace-core': resolve('packages/mcp-trace/trace-core'),
|
||||
'@mcp-trace/trace-web': resolve('packages/mcp-trace/trace-web')
|
||||
'@mcp-trace/trace-web': resolve('packages/mcp-trace/trace-web'),
|
||||
'@cherrystudio/ai-core/provider': resolve('packages/aiCore/src/core/providers'),
|
||||
'@cherrystudio/ai-core/built-in/plugins': resolve('packages/aiCore/src/core/plugins/built-in'),
|
||||
'@cherrystudio/ai-core': resolve('packages/aiCore/src')
|
||||
}
|
||||
},
|
||||
optimizeDeps: {
|
||||
|
||||
12
package.json
12
package.json
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "CherryStudio",
|
||||
"version": "1.5.7-rc.2",
|
||||
"version": "1.6.0-beta.1",
|
||||
"private": true,
|
||||
"description": "A powerful AI assistant for producer.",
|
||||
"main": "./out/main/index.js",
|
||||
@@ -87,12 +87,16 @@
|
||||
"@agentic/exa": "^7.3.3",
|
||||
"@agentic/searxng": "^7.3.3",
|
||||
"@agentic/tavily": "^7.3.3",
|
||||
"@ai-sdk/amazon-bedrock": "^3.0.0",
|
||||
"@ai-sdk/google-vertex": "^3.0.0",
|
||||
"@ai-sdk/mistral": "^2.0.0",
|
||||
"@ant-design/v5-patch-for-react-19": "^1.0.3",
|
||||
"@anthropic-ai/sdk": "^0.41.0",
|
||||
"@anthropic-ai/vertex-sdk": "patch:@anthropic-ai/vertex-sdk@npm%3A0.11.4#~/.yarn/patches/@anthropic-ai-vertex-sdk-npm-0.11.4-c19cb41edb.patch",
|
||||
"@aws-sdk/client-bedrock": "^3.840.0",
|
||||
"@aws-sdk/client-bedrock-runtime": "^3.840.0",
|
||||
"@aws-sdk/client-s3": "^3.840.0",
|
||||
"@cherrystudio/ai-core": "workspace:*",
|
||||
"@cherrystudio/embedjs": "^0.1.31",
|
||||
"@cherrystudio/embedjs-libsql": "^0.1.31",
|
||||
"@cherrystudio/embedjs-loader-csv": "^0.1.31",
|
||||
@@ -127,6 +131,7 @@
|
||||
"@modelcontextprotocol/sdk": "^1.17.0",
|
||||
"@mozilla/readability": "^0.6.0",
|
||||
"@notionhq/client": "^2.2.15",
|
||||
"@openrouter/ai-sdk-provider": "^1.1.2",
|
||||
"@opentelemetry/api": "^1.9.0",
|
||||
"@opentelemetry/core": "2.0.0",
|
||||
"@opentelemetry/exporter-trace-otlp-http": "^0.200.0",
|
||||
@@ -136,7 +141,7 @@
|
||||
"@playwright/test": "^1.52.0",
|
||||
"@reduxjs/toolkit": "^2.2.5",
|
||||
"@shikijs/markdown-it": "^3.9.1",
|
||||
"@swc/plugin-styled-components": "^7.1.5",
|
||||
"@swc/plugin-styled-components": "^8.0.4",
|
||||
"@tanstack/react-query": "^5.27.0",
|
||||
"@tanstack/react-virtual": "^3.13.12",
|
||||
"@testing-library/dom": "^10.4.0",
|
||||
@@ -169,6 +174,7 @@
|
||||
"@viz-js/lang-dot": "^1.0.5",
|
||||
"@viz-js/viz": "^3.14.0",
|
||||
"@xyflow/react": "^12.4.4",
|
||||
"ai": "^5.0.26",
|
||||
"antd": "patch:antd@npm%3A5.27.0#~/.yarn/patches/antd-npm-5.27.0-aa91c36546.patch",
|
||||
"archiver": "^7.0.1",
|
||||
"async-mutex": "^0.5.0",
|
||||
@@ -307,7 +313,7 @@
|
||||
"prettier --write",
|
||||
"eslint --fix"
|
||||
],
|
||||
"*.{json,md,yml,yaml,css,scss,html}": [
|
||||
"*.{json,yml,yaml,css,scss,html}": [
|
||||
"prettier --write"
|
||||
]
|
||||
}
|
||||
|
||||
514
packages/aiCore/AI_SDK_ARCHITECTURE.md
Normal file
514
packages/aiCore/AI_SDK_ARCHITECTURE.md
Normal file
@@ -0,0 +1,514 @@
|
||||
# AI Core 基于 Vercel AI SDK 的技术架构
|
||||
|
||||
## 1. 架构设计理念
|
||||
|
||||
### 1.1 设计目标
|
||||
|
||||
- **简化分层**:`models`(模型层)→ `runtime`(运行时层),清晰的职责分离
|
||||
- **统一接口**:使用 Vercel AI SDK 统一不同 AI Provider 的接口差异
|
||||
- **动态导入**:通过动态导入实现按需加载,减少打包体积
|
||||
- **最小包装**:直接使用 AI SDK 的类型和接口,避免重复定义
|
||||
- **插件系统**:基于钩子的通用插件架构,支持请求全生命周期扩展
|
||||
- **类型安全**:利用 TypeScript 和 AI SDK 的类型系统确保类型安全
|
||||
- **轻量级**:专注核心功能,保持包的轻量和高效
|
||||
- **包级独立**:作为独立包管理,便于复用和维护
|
||||
- **Agent就绪**:为将来集成 OpenAI Agents SDK 预留扩展空间
|
||||
|
||||
### 1.2 核心优势
|
||||
|
||||
- **标准化**:AI SDK 提供统一的模型接口,减少适配工作
|
||||
- **简化设计**:函数式API,避免过度抽象
|
||||
- **更好的开发体验**:完整的 TypeScript 支持和丰富的生态系统
|
||||
- **性能优化**:AI SDK 内置优化和最佳实践
|
||||
- **模块化设计**:独立包结构,支持跨项目复用
|
||||
- **可扩展插件**:通用的流转换和参数处理插件系统
|
||||
- **面向未来**:为 OpenAI Agents SDK 集成做好准备
|
||||
|
||||
## 2. 整体架构图
|
||||
|
||||
```mermaid
|
||||
graph TD
|
||||
subgraph "用户应用 (如 Cherry Studio)"
|
||||
UI["用户界面"]
|
||||
Components["应用组件"]
|
||||
end
|
||||
|
||||
subgraph "packages/aiCore (AI Core 包)"
|
||||
subgraph "Runtime Layer (运行时层)"
|
||||
RuntimeExecutor["RuntimeExecutor (运行时执行器)"]
|
||||
PluginEngine["PluginEngine (插件引擎)"]
|
||||
RuntimeAPI["Runtime API (便捷函数)"]
|
||||
end
|
||||
|
||||
subgraph "Models Layer (模型层)"
|
||||
ModelFactory["createModel() (模型工厂)"]
|
||||
ProviderCreator["ProviderCreator (提供商创建器)"]
|
||||
end
|
||||
|
||||
subgraph "Core Systems (核心系统)"
|
||||
subgraph "Plugins (插件)"
|
||||
PluginManager["PluginManager (插件管理)"]
|
||||
BuiltInPlugins["Built-in Plugins (内置插件)"]
|
||||
StreamTransforms["Stream Transforms (流转换)"]
|
||||
end
|
||||
|
||||
subgraph "Middleware (中间件)"
|
||||
MiddlewareWrapper["wrapModelWithMiddlewares() (中间件包装)"]
|
||||
end
|
||||
|
||||
subgraph "Providers (提供商)"
|
||||
Registry["Provider Registry (注册表)"]
|
||||
Factory["Provider Factory (工厂)"]
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
subgraph "Vercel AI SDK"
|
||||
AICore["ai (核心库)"]
|
||||
OpenAI["@ai-sdk/openai"]
|
||||
Anthropic["@ai-sdk/anthropic"]
|
||||
Google["@ai-sdk/google"]
|
||||
XAI["@ai-sdk/xai"]
|
||||
Others["其他 19+ Providers"]
|
||||
end
|
||||
|
||||
subgraph "Future: OpenAI Agents SDK"
|
||||
AgentSDK["@openai/agents (未来集成)"]
|
||||
AgentExtensions["Agent Extensions (预留)"]
|
||||
end
|
||||
|
||||
UI --> RuntimeAPI
|
||||
Components --> RuntimeExecutor
|
||||
RuntimeAPI --> RuntimeExecutor
|
||||
RuntimeExecutor --> PluginEngine
|
||||
RuntimeExecutor --> ModelFactory
|
||||
PluginEngine --> PluginManager
|
||||
ModelFactory --> ProviderCreator
|
||||
ModelFactory --> MiddlewareWrapper
|
||||
ProviderCreator --> Registry
|
||||
Registry --> Factory
|
||||
Factory --> OpenAI
|
||||
Factory --> Anthropic
|
||||
Factory --> Google
|
||||
Factory --> XAI
|
||||
Factory --> Others
|
||||
|
||||
RuntimeExecutor --> AICore
|
||||
AICore --> streamText
|
||||
AICore --> generateText
|
||||
AICore --> streamObject
|
||||
AICore --> generateObject
|
||||
|
||||
PluginManager --> StreamTransforms
|
||||
PluginManager --> BuiltInPlugins
|
||||
|
||||
%% 未来集成路径
|
||||
RuntimeExecutor -.-> AgentSDK
|
||||
AgentSDK -.-> AgentExtensions
|
||||
```
|
||||
|
||||
## 3. 包结构设计
|
||||
|
||||
### 3.1 新架构文件结构
|
||||
|
||||
```
|
||||
packages/aiCore/
|
||||
├── src/
|
||||
│ ├── core/ # 核心层 - 内部实现
|
||||
│ │ ├── models/ # 模型层 - 模型创建和配置
|
||||
│ │ │ ├── factory.ts # 模型工厂函数 ✅
|
||||
│ │ │ ├── ModelCreator.ts # 模型创建器 ✅
|
||||
│ │ │ ├── ConfigManager.ts # 配置管理器 ✅
|
||||
│ │ │ ├── types.ts # 模型类型定义 ✅
|
||||
│ │ │ └── index.ts # 模型层导出 ✅
|
||||
│ │ ├── runtime/ # 运行时层 - 执行和用户API
|
||||
│ │ │ ├── executor.ts # 运行时执行器 ✅
|
||||
│ │ │ ├── pluginEngine.ts # 插件引擎 ✅
|
||||
│ │ │ ├── types.ts # 运行时类型定义 ✅
|
||||
│ │ │ └── index.ts # 运行时导出 ✅
|
||||
│ │ ├── middleware/ # 中间件系统
|
||||
│ │ │ ├── wrapper.ts # 模型包装器 ✅
|
||||
│ │ │ ├── manager.ts # 中间件管理器 ✅
|
||||
│ │ │ ├── types.ts # 中间件类型 ✅
|
||||
│ │ │ └── index.ts # 中间件导出 ✅
|
||||
│ │ ├── plugins/ # 插件系统
|
||||
│ │ │ ├── types.ts # 插件类型定义 ✅
|
||||
│ │ │ ├── manager.ts # 插件管理器 ✅
|
||||
│ │ │ ├── built-in/ # 内置插件 ✅
|
||||
│ │ │ │ ├── logging.ts # 日志插件 ✅
|
||||
│ │ │ │ ├── webSearchPlugin/ # 网络搜索插件 ✅
|
||||
│ │ │ │ ├── toolUsePlugin/ # 工具使用插件 ✅
|
||||
│ │ │ │ └── index.ts # 内置插件导出 ✅
|
||||
│ │ │ ├── README.md # 插件文档 ✅
|
||||
│ │ │ └── index.ts # 插件导出 ✅
|
||||
│ │ ├── providers/ # 提供商管理
|
||||
│ │ │ ├── registry.ts # 提供商注册表 ✅
|
||||
│ │ │ ├── factory.ts # 提供商工厂 ✅
|
||||
│ │ │ ├── creator.ts # 提供商创建器 ✅
|
||||
│ │ │ ├── types.ts # 提供商类型 ✅
|
||||
│ │ │ ├── utils.ts # 工具函数 ✅
|
||||
│ │ │ └── index.ts # 提供商导出 ✅
|
||||
│ │ ├── options/ # 配置选项
|
||||
│ │ │ ├── factory.ts # 选项工厂 ✅
|
||||
│ │ │ ├── types.ts # 选项类型 ✅
|
||||
│ │ │ ├── xai.ts # xAI 选项 ✅
|
||||
│ │ │ ├── openrouter.ts # OpenRouter 选项 ✅
|
||||
│ │ │ ├── examples.ts # 示例配置 ✅
|
||||
│ │ │ └── index.ts # 选项导出 ✅
|
||||
│ │ └── index.ts # 核心层导出 ✅
|
||||
│ ├── types.ts # 全局类型定义 ✅
|
||||
│ └── index.ts # 包主入口文件 ✅
|
||||
├── package.json # 包配置文件 ✅
|
||||
├── tsconfig.json # TypeScript 配置 ✅
|
||||
├── README.md # 包说明文档 ✅
|
||||
└── AI_SDK_ARCHITECTURE.md # 本文档 ✅
|
||||
```
|
||||
|
||||
## 4. 架构分层详解
|
||||
|
||||
### 4.1 Models Layer (模型层)
|
||||
|
||||
**职责**:统一的模型创建和配置管理
|
||||
|
||||
**核心文件**:
|
||||
|
||||
- `factory.ts`: 模型工厂函数 (`createModel`, `createModels`)
|
||||
- `ProviderCreator.ts`: 底层提供商创建和模型实例化
|
||||
- `types.ts`: 模型配置类型定义
|
||||
|
||||
**设计特点**:
|
||||
|
||||
- 函数式设计,避免不必要的类抽象
|
||||
- 统一的模型配置接口
|
||||
- 自动处理中间件应用
|
||||
- 支持批量模型创建
|
||||
|
||||
**核心API**:
|
||||
|
||||
```typescript
|
||||
// 模型配置接口
|
||||
export interface ModelConfig {
|
||||
providerId: ProviderId
|
||||
modelId: string
|
||||
options: ProviderSettingsMap[ProviderId]
|
||||
middlewares?: LanguageModelV1Middleware[]
|
||||
}
|
||||
|
||||
// 核心模型创建函数
|
||||
export async function createModel(config: ModelConfig): Promise<LanguageModel>
|
||||
export async function createModels(configs: ModelConfig[]): Promise<LanguageModel[]>
|
||||
```
|
||||
|
||||
### 4.2 Runtime Layer (运行时层)
|
||||
|
||||
**职责**:运行时执行器和用户面向的API接口
|
||||
|
||||
**核心组件**:
|
||||
|
||||
- `executor.ts`: 运行时执行器类
|
||||
- `plugin-engine.ts`: 插件引擎(原PluginEnabledAiClient)
|
||||
- `index.ts`: 便捷函数和工厂方法
|
||||
|
||||
**设计特点**:
|
||||
|
||||
- 提供三种使用方式:类实例、静态工厂、函数式调用
|
||||
- 自动集成模型创建和插件处理
|
||||
- 完整的类型安全支持
|
||||
- 为 OpenAI Agents SDK 预留扩展接口
|
||||
|
||||
**核心API**:
|
||||
|
||||
```typescript
|
||||
// 运行时执行器
|
||||
export class RuntimeExecutor<T extends ProviderId = ProviderId> {
|
||||
static create<T extends ProviderId>(
|
||||
providerId: T,
|
||||
options: ProviderSettingsMap[T],
|
||||
plugins?: AiPlugin[]
|
||||
): RuntimeExecutor<T>
|
||||
|
||||
async streamText(modelId: string, params: StreamTextParams): Promise<StreamTextResult>
|
||||
async generateText(modelId: string, params: GenerateTextParams): Promise<GenerateTextResult>
|
||||
async streamObject(modelId: string, params: StreamObjectParams): Promise<StreamObjectResult>
|
||||
async generateObject(modelId: string, params: GenerateObjectParams): Promise<GenerateObjectResult>
|
||||
}
|
||||
|
||||
// 便捷函数式API
|
||||
export async function streamText<T extends ProviderId>(
|
||||
providerId: T,
|
||||
options: ProviderSettingsMap[T],
|
||||
modelId: string,
|
||||
params: StreamTextParams,
|
||||
plugins?: AiPlugin[]
|
||||
): Promise<StreamTextResult>
|
||||
```
|
||||
|
||||
### 4.3 Plugin System (插件系统)
|
||||
|
||||
**职责**:可扩展的插件架构
|
||||
|
||||
**核心组件**:
|
||||
|
||||
- `PluginManager`: 插件生命周期管理
|
||||
- `built-in/`: 内置插件集合
|
||||
- 流转换收集和应用
|
||||
|
||||
**设计特点**:
|
||||
|
||||
- 借鉴 Rollup 的钩子分类设计
|
||||
- 支持流转换 (`experimental_transform`)
|
||||
- 内置常用插件(日志、计数等)
|
||||
- 完整的生命周期钩子
|
||||
|
||||
**插件接口**:
|
||||
|
||||
```typescript
|
||||
export interface AiPlugin {
|
||||
name: string
|
||||
enforce?: 'pre' | 'post'
|
||||
|
||||
// 【First】首个钩子 - 只执行第一个返回值的插件
|
||||
resolveModel?: (modelId: string, context: AiRequestContext) => string | null | Promise<string | null>
|
||||
loadTemplate?: (templateName: string, context: AiRequestContext) => any | null | Promise<any | null>
|
||||
|
||||
// 【Sequential】串行钩子 - 链式执行,支持数据转换
|
||||
transformParams?: (params: any, context: AiRequestContext) => any | Promise<any>
|
||||
transformResult?: (result: any, context: AiRequestContext) => any | Promise<any>
|
||||
|
||||
// 【Parallel】并行钩子 - 不依赖顺序,用于副作用
|
||||
onRequestStart?: (context: AiRequestContext) => void | Promise<void>
|
||||
onRequestEnd?: (context: AiRequestContext, result: any) => void | Promise<void>
|
||||
onError?: (error: Error, context: AiRequestContext) => void | Promise<void>
|
||||
|
||||
// 【Stream】流处理
|
||||
transformStream?: () => TransformStream
|
||||
}
|
||||
```
|
||||
|
||||
### 4.4 Middleware System (中间件系统)
|
||||
|
||||
**职责**:AI SDK原生中间件支持
|
||||
|
||||
**核心组件**:
|
||||
|
||||
- `ModelWrapper.ts`: 模型包装函数
|
||||
|
||||
**设计哲学**:
|
||||
|
||||
- 直接使用AI SDK的 `wrapLanguageModel`
|
||||
- 与插件系统分离,职责明确
|
||||
- 函数式设计,简化使用
|
||||
|
||||
```typescript
|
||||
export function wrapModelWithMiddlewares(model: LanguageModel, middlewares: LanguageModelV1Middleware[]): LanguageModel
|
||||
```
|
||||
|
||||
### 4.5 Provider System (提供商系统)
|
||||
|
||||
**职责**:AI Provider注册表和动态导入
|
||||
|
||||
**核心组件**:
|
||||
|
||||
- `registry.ts`: 19+ Provider配置和类型
|
||||
- `factory.ts`: Provider配置工厂
|
||||
|
||||
**支持的Providers**:
|
||||
|
||||
- OpenAI, Anthropic, Google, XAI
|
||||
- Azure OpenAI, Amazon Bedrock, Google Vertex
|
||||
- Groq, Together.ai, Fireworks, DeepSeek
|
||||
- 等19+ AI SDK官方支持的providers
|
||||
|
||||
## 5. 使用方式
|
||||
|
||||
### 5.1 函数式调用 (推荐 - 简单场景)
|
||||
|
||||
```typescript
|
||||
import { streamText, generateText } from '@cherrystudio/ai-core/runtime'
|
||||
|
||||
// 直接函数调用
|
||||
const stream = await streamText(
|
||||
'anthropic',
|
||||
{ apiKey: 'your-api-key' },
|
||||
'claude-3',
|
||||
{ messages: [{ role: 'user', content: 'Hello!' }] },
|
||||
[loggingPlugin]
|
||||
)
|
||||
```
|
||||
|
||||
### 5.2 执行器实例 (推荐 - 复杂场景)
|
||||
|
||||
```typescript
|
||||
import { createExecutor } from '@cherrystudio/ai-core/runtime'
|
||||
|
||||
// 创建可复用的执行器
|
||||
const executor = createExecutor('openai', { apiKey: 'your-api-key' }, [plugin1, plugin2])
|
||||
|
||||
// 多次使用
|
||||
const stream = await executor.streamText('gpt-4', {
|
||||
messages: [{ role: 'user', content: 'Hello!' }]
|
||||
})
|
||||
|
||||
const result = await executor.generateText('gpt-4', {
|
||||
messages: [{ role: 'user', content: 'How are you?' }]
|
||||
})
|
||||
```
|
||||
|
||||
### 5.3 静态工厂方法
|
||||
|
||||
```typescript
|
||||
import { RuntimeExecutor } from '@cherrystudio/ai-core/runtime'
|
||||
|
||||
// 静态创建
|
||||
const executor = RuntimeExecutor.create('anthropic', { apiKey: 'your-api-key' })
|
||||
await executor.streamText('claude-3', { messages: [...] })
|
||||
```
|
||||
|
||||
### 5.4 直接模型创建 (高级用法)
|
||||
|
||||
```typescript
|
||||
import { createModel } from '@cherrystudio/ai-core/models'
|
||||
import { streamText } from 'ai'
|
||||
|
||||
// 直接创建模型使用
|
||||
const model = await createModel({
|
||||
providerId: 'openai',
|
||||
modelId: 'gpt-4',
|
||||
options: { apiKey: 'your-api-key' },
|
||||
middlewares: [middleware1, middleware2]
|
||||
})
|
||||
|
||||
// 直接使用 AI SDK
|
||||
const result = await streamText({ model, messages: [...] })
|
||||
```
|
||||
|
||||
## 6. 为 OpenAI Agents SDK 预留的设计
|
||||
|
||||
### 6.1 架构兼容性
|
||||
|
||||
当前架构完全兼容 OpenAI Agents SDK 的集成需求:
|
||||
|
||||
```typescript
|
||||
// 当前的模型创建
|
||||
const model = await createModel({
|
||||
providerId: 'anthropic',
|
||||
modelId: 'claude-3',
|
||||
options: { apiKey: 'xxx' }
|
||||
})
|
||||
|
||||
// 将来可以直接用于 OpenAI Agents SDK
|
||||
import { Agent, run } from '@openai/agents'
|
||||
|
||||
const agent = new Agent({
|
||||
model, // ✅ 直接兼容 LanguageModel 接口
|
||||
name: 'Assistant',
|
||||
instructions: '...',
|
||||
tools: [tool1, tool2]
|
||||
})
|
||||
|
||||
const result = await run(agent, 'user input')
|
||||
```
|
||||
|
||||
### 6.2 预留的扩展点
|
||||
|
||||
1. **runtime/agents/** 目录预留
|
||||
2. **AgentExecutor** 类预留
|
||||
3. **Agent工具转换插件** 预留
|
||||
4. **多Agent编排** 预留
|
||||
|
||||
### 6.3 未来架构扩展
|
||||
|
||||
```
|
||||
packages/aiCore/src/core/
|
||||
├── runtime/
|
||||
│ ├── agents/ # 🚀 未来添加
|
||||
│ │ ├── AgentExecutor.ts
|
||||
│ │ ├── WorkflowManager.ts
|
||||
│ │ └── ConversationManager.ts
|
||||
│ ├── executor.ts
|
||||
│ └── index.ts
|
||||
```
|
||||
|
||||
## 7. 架构优势
|
||||
|
||||
### 7.1 简化设计
|
||||
|
||||
- **移除过度抽象**:删除了orchestration层和creation层的复杂包装
|
||||
- **函数式优先**:models层使用函数而非类
|
||||
- **直接明了**:runtime层直接提供用户API
|
||||
|
||||
### 7.2 职责清晰
|
||||
|
||||
- **Models**: 专注模型创建和配置
|
||||
- **Runtime**: 专注执行和用户API
|
||||
- **Plugins**: 专注扩展功能
|
||||
- **Providers**: 专注AI Provider管理
|
||||
|
||||
### 7.3 类型安全
|
||||
|
||||
- 完整的 TypeScript 支持
|
||||
- AI SDK 类型的直接复用
|
||||
- 避免类型重复定义
|
||||
|
||||
### 7.4 灵活使用
|
||||
|
||||
- 三种使用模式满足不同需求
|
||||
- 从简单函数调用到复杂执行器
|
||||
- 支持直接AI SDK使用
|
||||
|
||||
### 7.5 面向未来
|
||||
|
||||
- 为 OpenAI Agents SDK 集成做好准备
|
||||
- 清晰的扩展点和架构边界
|
||||
- 模块化设计便于功能添加
|
||||
|
||||
## 8. 技术决策记录
|
||||
|
||||
### 8.1 为什么选择简化的两层架构?
|
||||
|
||||
- **职责分离**:models专注创建,runtime专注执行
|
||||
- **模块化**:每层都有清晰的边界和职责
|
||||
- **扩展性**:为Agent功能预留了清晰的扩展空间
|
||||
|
||||
### 8.2 为什么选择函数式设计?
|
||||
|
||||
- **简洁性**:避免不必要的类设计
|
||||
- **性能**:减少对象创建开销
|
||||
- **易用性**:函数调用更直观
|
||||
|
||||
### 8.3 为什么分离插件和中间件?
|
||||
|
||||
- **职责明确**: 插件处理应用特定需求
|
||||
- **原生支持**: 中间件使用AI SDK原生功能
|
||||
- **灵活性**: 两套系统可以独立演进
|
||||
|
||||
## 9. 总结
|
||||
|
||||
AI Core架构实现了:
|
||||
|
||||
### 9.1 核心特点
|
||||
|
||||
- ✅ **简化架构**: 2层核心架构,职责清晰
|
||||
- ✅ **函数式设计**: models层完全函数化
|
||||
- ✅ **类型安全**: 统一的类型定义和AI SDK类型复用
|
||||
- ✅ **插件扩展**: 强大的插件系统
|
||||
- ✅ **多种使用方式**: 满足不同复杂度需求
|
||||
- ✅ **Agent就绪**: 为OpenAI Agents SDK集成做好准备
|
||||
|
||||
### 9.2 核心价值
|
||||
|
||||
- **统一接口**: 一套API支持19+ AI providers
|
||||
- **灵活使用**: 函数式、实例式、静态工厂式
|
||||
- **强类型**: 完整的TypeScript支持
|
||||
- **可扩展**: 插件和中间件双重扩展能力
|
||||
- **高性能**: 最小化包装,直接使用AI SDK
|
||||
- **面向未来**: Agent SDK集成架构就绪
|
||||
|
||||
### 9.3 未来发展
|
||||
|
||||
这个架构提供了:
|
||||
|
||||
- **优秀的开发体验**: 简洁的API和清晰的使用模式
|
||||
- **强大的扩展能力**: 为Agent功能预留了完整的架构空间
|
||||
- **良好的维护性**: 职责分离明确,代码易于维护
|
||||
- **广泛的适用性**: 既适合简单调用也适合复杂应用
|
||||
433
packages/aiCore/README.md
Normal file
433
packages/aiCore/README.md
Normal file
@@ -0,0 +1,433 @@
|
||||
# @cherrystudio/ai-core
|
||||
|
||||
Cherry Studio AI Core 是一个基于 Vercel AI SDK 的统一 AI Provider 接口包,为 AI 应用提供强大的抽象层和插件化架构。
|
||||
|
||||
## ✨ 核心亮点
|
||||
|
||||
### 🏗️ 优雅的架构设计
|
||||
|
||||
- **简化分层**:`models`(模型层)→ `runtime`(运行时层),清晰的职责分离
|
||||
- **函数式优先**:避免过度抽象,提供简洁直观的 API
|
||||
- **类型安全**:完整的 TypeScript 支持,直接复用 AI SDK 类型系统
|
||||
- **最小包装**:直接使用 AI SDK 的接口,避免重复定义和性能损耗
|
||||
|
||||
### 🔌 强大的插件系统
|
||||
|
||||
- **生命周期钩子**:支持请求全生命周期的扩展点
|
||||
- **流转换支持**:基于 AI SDK 的 `experimental_transform` 实现流处理
|
||||
- **插件分类**:First、Sequential、Parallel 三种钩子类型,满足不同场景
|
||||
- **内置插件**:webSearch、logging、toolUse 等开箱即用的功能
|
||||
|
||||
### 🌐 统一多 Provider 接口
|
||||
|
||||
- **扩展注册**:支持自定义 Provider 注册,无限扩展能力
|
||||
- **配置统一**:统一的配置接口,简化多 Provider 管理
|
||||
|
||||
### 🚀 多种使用方式
|
||||
|
||||
- **函数式调用**:适合简单场景的直接函数调用
|
||||
- **执行器实例**:适合复杂场景的可复用执行器
|
||||
- **静态工厂**:便捷的静态创建方法
|
||||
- **原生兼容**:完全兼容 AI SDK 原生 Provider Registry
|
||||
|
||||
### 🔮 面向未来
|
||||
|
||||
- **Agent 就绪**:为 OpenAI Agents SDK 集成预留架构空间
|
||||
- **模块化设计**:独立包结构,支持跨项目复用
|
||||
- **渐进式迁移**:可以逐步从现有 AI SDK 代码迁移
|
||||
|
||||
## 特性
|
||||
|
||||
- 🚀 统一的 AI Provider 接口
|
||||
- 🔄 动态导入支持
|
||||
- 🛠️ TypeScript 支持
|
||||
- 📦 强大的插件系统
|
||||
- 🌍 内置webSearch(Openai,Google,Anthropic,xAI)
|
||||
- 🎯 多种使用模式(函数式/实例式/静态工厂)
|
||||
- 🔌 可扩展的 Provider 注册系统
|
||||
- 🧩 完整的中间件支持
|
||||
- 📊 插件统计和调试功能
|
||||
|
||||
## 支持的 Providers
|
||||
|
||||
基于 [AI SDK 官方支持的 providers](https://ai-sdk.dev/providers/ai-sdk-providers):
|
||||
|
||||
**核心 Providers(内置支持):**
|
||||
|
||||
- OpenAI
|
||||
- Anthropic
|
||||
- Google Generative AI
|
||||
- OpenAI-Compatible
|
||||
- xAI (Grok)
|
||||
- Azure OpenAI
|
||||
- DeepSeek
|
||||
|
||||
**扩展 Providers(通过注册API支持):**
|
||||
|
||||
- Google Vertex AI
|
||||
- ...
|
||||
- 自定义 Provider
|
||||
|
||||
## 安装
|
||||
|
||||
```bash
|
||||
npm install @cherrystudio/ai-core ai
|
||||
```
|
||||
|
||||
### React Native
|
||||
|
||||
如果你在 React Native 项目中使用此包,需要在 `metro.config.js` 中添加以下配置:
|
||||
|
||||
```javascript
|
||||
// metro.config.js
|
||||
const { getDefaultConfig } = require('expo/metro-config')
|
||||
|
||||
const config = getDefaultConfig(__dirname)
|
||||
|
||||
// 添加对 @cherrystudio/ai-core 的支持
|
||||
config.resolver.resolverMainFields = ['react-native', 'browser', 'main']
|
||||
config.resolver.platforms = ['ios', 'android', 'native', 'web']
|
||||
|
||||
module.exports = config
|
||||
```
|
||||
|
||||
还需要安装你要使用的 AI SDK provider:
|
||||
|
||||
```bash
|
||||
npm install @ai-sdk/openai @ai-sdk/anthropic @ai-sdk/google
|
||||
```
|
||||
|
||||
## 使用示例
|
||||
|
||||
### 基础用法
|
||||
|
||||
```typescript
|
||||
import { AiCore } from '@cherrystudio/ai-core'
|
||||
|
||||
// 创建 OpenAI executor
|
||||
const executor = AiCore.create('openai', {
|
||||
apiKey: 'your-api-key'
|
||||
})
|
||||
|
||||
// 流式生成
|
||||
const result = await executor.streamText('gpt-4', {
|
||||
messages: [{ role: 'user', content: 'Hello!' }]
|
||||
})
|
||||
|
||||
// 非流式生成
|
||||
const response = await executor.generateText('gpt-4', {
|
||||
messages: [{ role: 'user', content: 'Hello!' }]
|
||||
})
|
||||
```
|
||||
|
||||
### 便捷函数
|
||||
|
||||
```typescript
|
||||
import { createOpenAIExecutor } from '@cherrystudio/ai-core'
|
||||
|
||||
// 快速创建 OpenAI executor
|
||||
const executor = createOpenAIExecutor({
|
||||
apiKey: 'your-api-key'
|
||||
})
|
||||
|
||||
// 使用 executor
|
||||
const result = await executor.streamText('gpt-4', {
|
||||
messages: [{ role: 'user', content: 'Hello!' }]
|
||||
})
|
||||
```
|
||||
|
||||
### 多 Provider 支持
|
||||
|
||||
```typescript
|
||||
import { AiCore } from '@cherrystudio/ai-core'
|
||||
|
||||
// 支持多种 AI providers
|
||||
const openaiExecutor = AiCore.create('openai', { apiKey: 'openai-key' })
|
||||
const anthropicExecutor = AiCore.create('anthropic', { apiKey: 'anthropic-key' })
|
||||
const googleExecutor = AiCore.create('google', { apiKey: 'google-key' })
|
||||
const xaiExecutor = AiCore.create('xai', { apiKey: 'xai-key' })
|
||||
```
|
||||
|
||||
### 扩展 Provider 注册
|
||||
|
||||
对于非内置的 providers,可以通过注册 API 扩展支持:
|
||||
|
||||
```typescript
|
||||
import { registerProvider, AiCore } from '@cherrystudio/ai-core'
|
||||
|
||||
// 方式一:导入并注册第三方 provider
|
||||
import { createGroq } from '@ai-sdk/groq'
|
||||
|
||||
registerProvider({
|
||||
id: 'groq',
|
||||
name: 'Groq',
|
||||
creator: createGroq,
|
||||
supportsImageGeneration: false
|
||||
})
|
||||
|
||||
// 现在可以使用 Groq
|
||||
const groqExecutor = AiCore.create('groq', { apiKey: 'groq-key' })
|
||||
|
||||
// 方式二:动态导入方式注册
|
||||
registerProvider({
|
||||
id: 'mistral',
|
||||
name: 'Mistral AI',
|
||||
import: () => import('@ai-sdk/mistral'),
|
||||
creatorFunctionName: 'createMistral'
|
||||
})
|
||||
|
||||
const mistralExecutor = AiCore.create('mistral', { apiKey: 'mistral-key' })
|
||||
```
|
||||
|
||||
## 🔌 插件系统
|
||||
|
||||
AI Core 提供了强大的插件系统,支持请求全生命周期的扩展。
|
||||
|
||||
### 内置插件
|
||||
|
||||
#### webSearchPlugin - 网络搜索插件
|
||||
|
||||
为不同 AI Provider 提供统一的网络搜索能力:
|
||||
|
||||
```typescript
|
||||
import { webSearchPlugin } from '@cherrystudio/ai-core/built-in/plugins'
|
||||
|
||||
const executor = AiCore.create('openai', { apiKey: 'your-key' }, [
|
||||
webSearchPlugin({
|
||||
openai: {
|
||||
/* OpenAI 搜索配置 */
|
||||
},
|
||||
anthropic: { maxUses: 5 },
|
||||
google: {
|
||||
/* Google 搜索配置 */
|
||||
},
|
||||
xai: {
|
||||
mode: 'on',
|
||||
returnCitations: true,
|
||||
maxSearchResults: 5,
|
||||
sources: [{ type: 'web' }, { type: 'x' }, { type: 'news' }]
|
||||
}
|
||||
})
|
||||
])
|
||||
```
|
||||
|
||||
#### loggingPlugin - 日志插件
|
||||
|
||||
提供详细的请求日志记录:
|
||||
|
||||
```typescript
|
||||
import { createLoggingPlugin } from '@cherrystudio/ai-core/built-in/plugins'
|
||||
|
||||
const executor = AiCore.create('openai', { apiKey: 'your-key' }, [
|
||||
createLoggingPlugin({
|
||||
logLevel: 'info',
|
||||
includeParams: true,
|
||||
includeResult: false
|
||||
})
|
||||
])
|
||||
```
|
||||
|
||||
#### promptToolUsePlugin - 提示工具使用插件
|
||||
|
||||
为不支持原生 Function Call 的模型提供 prompt 方式的工具调用:
|
||||
|
||||
```typescript
|
||||
import { createPromptToolUsePlugin } from '@cherrystudio/ai-core/built-in/plugins'
|
||||
|
||||
// 对于不支持 function call 的模型
|
||||
const executor = AiCore.create(
|
||||
'providerId',
|
||||
{
|
||||
apiKey: 'your-key',
|
||||
baseURL: 'https://your-model-endpoint'
|
||||
},
|
||||
[
|
||||
createPromptToolUsePlugin({
|
||||
enabled: true,
|
||||
// 可选:自定义系统提示符构建
|
||||
buildSystemPrompt: (userPrompt, tools) => {
|
||||
return `${userPrompt}\n\nAvailable tools: ${Object.keys(tools).join(', ')}`
|
||||
}
|
||||
})
|
||||
]
|
||||
)
|
||||
```
|
||||
|
||||
### 自定义插件
|
||||
|
||||
创建自定义插件非常简单:
|
||||
|
||||
```typescript
|
||||
import { definePlugin } from '@cherrystudio/ai-core'
|
||||
|
||||
const customPlugin = definePlugin({
|
||||
name: 'custom-plugin',
|
||||
enforce: 'pre', // 'pre' | 'post' | undefined
|
||||
|
||||
// 在请求开始时记录日志
|
||||
onRequestStart: async (context) => {
|
||||
console.log(`Starting request for model: ${context.modelId}`)
|
||||
},
|
||||
|
||||
// 转换请求参数
|
||||
transformParams: async (params, context) => {
|
||||
// 添加自定义系统消息
|
||||
if (params.messages) {
|
||||
params.messages.unshift({
|
||||
role: 'system',
|
||||
content: 'You are a helpful assistant.'
|
||||
})
|
||||
}
|
||||
return params
|
||||
},
|
||||
|
||||
// 处理响应结果
|
||||
transformResult: async (result, context) => {
|
||||
// 添加元数据
|
||||
if (result.text) {
|
||||
result.metadata = {
|
||||
processedAt: new Date().toISOString(),
|
||||
modelId: context.modelId
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
})
|
||||
|
||||
// 使用自定义插件
|
||||
const executor = AiCore.create('openai', { apiKey: 'your-key' }, [customPlugin])
|
||||
```
|
||||
|
||||
### 使用 AI SDK 原生 Provider 注册表
|
||||
|
||||
> https://ai-sdk.dev/docs/reference/ai-sdk-core/provider-registry
|
||||
|
||||
除了使用内建的 provider 管理,你还可以使用 AI SDK 原生的 `createProviderRegistry` 来构建自己的 provider 注册表。
|
||||
|
||||
#### 基本用法示例
|
||||
|
||||
```typescript
|
||||
import { createClient } from '@cherrystudio/ai-core'
|
||||
import { createProviderRegistry } from 'ai'
|
||||
import { createOpenAI } from '@ai-sdk/openai'
|
||||
import { anthropic } from '@ai-sdk/anthropic'
|
||||
|
||||
// 1. 创建 AI SDK 原生注册表
|
||||
export const registry = createProviderRegistry({
|
||||
// register provider with prefix and default setup:
|
||||
anthropic,
|
||||
|
||||
// register provider with prefix and custom setup:
|
||||
openai: createOpenAI({
|
||||
apiKey: process.env.OPENAI_API_KEY
|
||||
})
|
||||
})
|
||||
|
||||
// 2. 创建client,'openai'可以传空或者传providerId(内建的provider)
|
||||
const client = PluginEnabledAiClient.create('openai', {
|
||||
apiKey: process.env.OPENAI_API_KEY
|
||||
})
|
||||
|
||||
// 3. 方式1:使用内建逻辑(传统方式)
|
||||
const result1 = await client.streamText('gpt-4', {
|
||||
messages: [{ role: 'user', content: 'Hello with built-in logic!' }]
|
||||
})
|
||||
|
||||
// 4. 方式2:使用自定义注册表(灵活方式)
|
||||
const result2 = await client.streamText({
|
||||
model: registry.languageModel('openai:gpt-4'),
|
||||
messages: [{ role: 'user', content: 'Hello with custom registry!' }]
|
||||
})
|
||||
|
||||
// 5. 支持的重载方法
|
||||
await client.generateObject({
|
||||
model: registry.languageModel('openai:gpt-4'),
|
||||
schema: z.object({ name: z.string() }),
|
||||
messages: [{ role: 'user', content: 'Generate a user' }]
|
||||
})
|
||||
|
||||
await client.streamObject({
|
||||
model: registry.languageModel('anthropic:claude-3-opus-20240229'),
|
||||
schema: z.object({ items: z.array(z.string()) }),
|
||||
messages: [{ role: 'user', content: 'Generate a list' }]
|
||||
})
|
||||
```
|
||||
|
||||
#### 与插件系统配合使用
|
||||
|
||||
更强大的是,你还可以将自定义注册表与 Cherry Studio 的插件系统结合使用:
|
||||
|
||||
```typescript
|
||||
import { PluginEnabledAiClient } from '@cherrystudio/ai-core'
|
||||
import { createProviderRegistry } from 'ai'
|
||||
import { createOpenAI } from '@ai-sdk/openai'
|
||||
import { anthropic } from '@ai-sdk/anthropic'
|
||||
|
||||
// 1. 创建带插件的客户端
|
||||
const client = PluginEnabledAiClient.create(
|
||||
'openai',
|
||||
{
|
||||
apiKey: process.env.OPENAI_API_KEY
|
||||
},
|
||||
[LoggingPlugin, RetryPlugin]
|
||||
)
|
||||
|
||||
// 2. 创建自定义注册表
|
||||
const registry = createProviderRegistry({
|
||||
openai: createOpenAI({ apiKey: process.env.OPENAI_API_KEY }),
|
||||
anthropic: anthropic({ apiKey: process.env.ANTHROPIC_API_KEY })
|
||||
})
|
||||
|
||||
// 3. 方式1:使用内建逻辑 + 完整插件系统
|
||||
await client.streamText('gpt-4', {
|
||||
messages: [{ role: 'user', content: 'Hello with plugins!' }]
|
||||
})
|
||||
|
||||
// 4. 方式2:使用自定义注册表 + 有限插件支持
|
||||
await client.streamText({
|
||||
model: registry.languageModel('anthropic:claude-3-opus-20240229'),
|
||||
messages: [{ role: 'user', content: 'Hello from Claude!' }]
|
||||
})
|
||||
|
||||
// 5. 支持的方法
|
||||
await client.generateObject({
|
||||
model: registry.languageModel('openai:gpt-4'),
|
||||
schema: z.object({ name: z.string() }),
|
||||
messages: [{ role: 'user', content: 'Generate a user' }]
|
||||
})
|
||||
|
||||
await client.streamObject({
|
||||
model: registry.languageModel('openai:gpt-4'),
|
||||
schema: z.object({ items: z.array(z.string()) }),
|
||||
messages: [{ role: 'user', content: 'Generate a list' }]
|
||||
})
|
||||
```
|
||||
|
||||
#### 混合使用的优势
|
||||
|
||||
- **灵活性**:可以根据需要选择使用内建逻辑或自定义注册表
|
||||
- **兼容性**:完全兼容 AI SDK 的 `createProviderRegistry` API
|
||||
- **渐进式**:可以逐步迁移现有代码,无需一次性重构
|
||||
- **插件支持**:自定义注册表仍可享受插件系统的部分功能
|
||||
- **最佳实践**:结合两种方式的优点,既有动态加载的性能优势,又有统一注册表的便利性
|
||||
|
||||
## 📚 相关资源
|
||||
|
||||
- [Vercel AI SDK 文档](https://ai-sdk.dev/)
|
||||
- [Cherry Studio 项目](https://github.com/CherryHQ/cherry-studio)
|
||||
- [AI SDK Providers](https://ai-sdk.dev/providers/ai-sdk-providers)
|
||||
|
||||
## 未来版本
|
||||
|
||||
- 🔮 多 Agent 编排
|
||||
- 🔮 可视化插件配置
|
||||
- 🔮 实时监控和分析
|
||||
- 🔮 云端插件同步
|
||||
|
||||
## 📄 License
|
||||
|
||||
MIT License - 详见 [LICENSE](https://github.com/CherryHQ/cherry-studio/blob/main/LICENSE) 文件
|
||||
|
||||
---
|
||||
|
||||
**Cherry Studio AI Core** - 让 AI 开发更简单、更强大、更灵活 🚀
|
||||
103
packages/aiCore/examples/hub-provider-usage.ts
Normal file
103
packages/aiCore/examples/hub-provider-usage.ts
Normal file
@@ -0,0 +1,103 @@
|
||||
/**
|
||||
* Hub Provider 使用示例
|
||||
*
|
||||
* 演示如何使用简化后的Hub Provider功能来路由到多个底层provider
|
||||
*/
|
||||
|
||||
import { createHubProvider, initializeProvider, providerRegistry } from '../src/index'
|
||||
|
||||
async function demonstrateHubProvider() {
|
||||
try {
|
||||
// 1. 初始化底层providers
|
||||
console.log('📦 初始化底层providers...')
|
||||
|
||||
initializeProvider('openai', {
|
||||
apiKey: process.env.OPENAI_API_KEY || 'sk-test-key'
|
||||
})
|
||||
|
||||
initializeProvider('anthropic', {
|
||||
apiKey: process.env.ANTHROPIC_API_KEY || 'sk-ant-test-key'
|
||||
})
|
||||
|
||||
// 2. 创建Hub Provider(自动包含所有已初始化的providers)
|
||||
console.log('🌐 创建Hub Provider...')
|
||||
|
||||
const aihubmixProvider = createHubProvider({
|
||||
hubId: 'aihubmix',
|
||||
debug: true
|
||||
})
|
||||
|
||||
// 3. 注册Hub Provider
|
||||
providerRegistry.registerProvider('aihubmix', aihubmixProvider)
|
||||
|
||||
console.log('✅ Hub Provider "aihubmix" 注册成功')
|
||||
|
||||
// 4. 使用Hub Provider访问不同的模型
|
||||
console.log('\n🚀 使用Hub模型...')
|
||||
|
||||
// 通过Hub路由到OpenAI
|
||||
const openaiModel = providerRegistry.languageModel('aihubmix:openai:gpt-4')
|
||||
console.log('✓ OpenAI模型已获取:', openaiModel.modelId)
|
||||
|
||||
// 通过Hub路由到Anthropic
|
||||
const anthropicModel = providerRegistry.languageModel('aihubmix:anthropic:claude-3.5-sonnet')
|
||||
console.log('✓ Anthropic模型已获取:', anthropicModel.modelId)
|
||||
|
||||
// 5. 演示错误处理
|
||||
console.log('\n❌ 演示错误处理...')
|
||||
|
||||
try {
|
||||
// 尝试访问未初始化的provider
|
||||
providerRegistry.languageModel('aihubmix:google:gemini-pro')
|
||||
} catch (error) {
|
||||
console.log('预期错误:', error.message)
|
||||
}
|
||||
|
||||
try {
|
||||
// 尝试使用错误的模型ID格式
|
||||
providerRegistry.languageModel('aihubmix:invalid-format')
|
||||
} catch (error) {
|
||||
console.log('预期错误:', error.message)
|
||||
}
|
||||
|
||||
// 6. 多个Hub Provider示例
|
||||
console.log('\n🔄 创建多个Hub Provider...')
|
||||
|
||||
const localHubProvider = createHubProvider({
|
||||
hubId: 'local-ai'
|
||||
})
|
||||
|
||||
providerRegistry.registerProvider('local-ai', localHubProvider)
|
||||
console.log('✅ Hub Provider "local-ai" 注册成功')
|
||||
|
||||
console.log('\n🎉 Hub Provider演示完成!')
|
||||
} catch (error) {
|
||||
console.error('💥 演示过程中发生错误:', error)
|
||||
}
|
||||
}
|
||||
|
||||
// 演示简化的使用方式
|
||||
function simplifiedUsageExample() {
|
||||
console.log('\n📝 简化使用示例:')
|
||||
console.log(`
|
||||
// 1. 初始化providers
|
||||
initializeProvider('openai', { apiKey: 'sk-xxx' })
|
||||
initializeProvider('anthropic', { apiKey: 'sk-ant-xxx' })
|
||||
|
||||
// 2. 创建并注册Hub Provider
|
||||
const hubProvider = createHubProvider({ hubId: 'aihubmix' })
|
||||
providerRegistry.registerProvider('aihubmix', hubProvider)
|
||||
|
||||
// 3. 直接使用
|
||||
const model1 = providerRegistry.languageModel('aihubmix:openai:gpt-4')
|
||||
const model2 = providerRegistry.languageModel('aihubmix:anthropic:claude-3.5-sonnet')
|
||||
`)
|
||||
}
|
||||
|
||||
// 运行演示
|
||||
if (require.main === module) {
|
||||
demonstrateHubProvider()
|
||||
simplifiedUsageExample()
|
||||
}
|
||||
|
||||
export { demonstrateHubProvider, simplifiedUsageExample }
|
||||
167
packages/aiCore/examples/image-generation.ts
Normal file
167
packages/aiCore/examples/image-generation.ts
Normal file
@@ -0,0 +1,167 @@
|
||||
/**
|
||||
* Image Generation Example
|
||||
* 演示如何使用 aiCore 的文生图功能
|
||||
*/
|
||||
|
||||
import { createExecutor, generateImage } from '../src/index'
|
||||
|
||||
async function main() {
|
||||
// 方式1: 使用执行器实例
|
||||
console.log('📸 创建 OpenAI 图像生成执行器...')
|
||||
const executor = createExecutor('openai', {
|
||||
apiKey: process.env.OPENAI_API_KEY!
|
||||
})
|
||||
|
||||
try {
|
||||
console.log('🎨 使用执行器生成图像...')
|
||||
const result1 = await executor.generateImage('dall-e-3', {
|
||||
prompt: 'A futuristic cityscape at sunset with flying cars',
|
||||
size: '1024x1024',
|
||||
n: 1
|
||||
})
|
||||
|
||||
console.log('✅ 图像生成成功!')
|
||||
console.log('📊 结果:', {
|
||||
imagesCount: result1.images.length,
|
||||
mediaType: result1.image.mediaType,
|
||||
hasBase64: !!result1.image.base64,
|
||||
providerMetadata: result1.providerMetadata
|
||||
})
|
||||
} catch (error) {
|
||||
console.error('❌ 执行器生成失败:', error)
|
||||
}
|
||||
|
||||
// 方式2: 使用直接调用 API
|
||||
try {
|
||||
console.log('🎨 使用直接 API 生成图像...')
|
||||
const result2 = await generateImage('openai', { apiKey: process.env.OPENAI_API_KEY! }, 'dall-e-3', {
|
||||
prompt: 'A magical forest with glowing mushrooms and fairy lights',
|
||||
aspectRatio: '16:9',
|
||||
providerOptions: {
|
||||
openai: {
|
||||
quality: 'hd',
|
||||
style: 'vivid'
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
console.log('✅ 直接 API 生成成功!')
|
||||
console.log('📊 结果:', {
|
||||
imagesCount: result2.images.length,
|
||||
mediaType: result2.image.mediaType,
|
||||
hasBase64: !!result2.image.base64
|
||||
})
|
||||
} catch (error) {
|
||||
console.error('❌ 直接 API 生成失败:', error)
|
||||
}
|
||||
|
||||
// 方式3: 支持其他提供商 (Google Imagen)
|
||||
if (process.env.GOOGLE_API_KEY) {
|
||||
try {
|
||||
console.log('🎨 使用 Google Imagen 生成图像...')
|
||||
const googleExecutor = createExecutor('google', {
|
||||
apiKey: process.env.GOOGLE_API_KEY!
|
||||
})
|
||||
|
||||
const result3 = await googleExecutor.generateImage('imagen-3.0-generate-002', {
|
||||
prompt: 'A serene mountain lake at dawn with mist rising from the water',
|
||||
aspectRatio: '1:1'
|
||||
})
|
||||
|
||||
console.log('✅ Google Imagen 生成成功!')
|
||||
console.log('📊 结果:', {
|
||||
imagesCount: result3.images.length,
|
||||
mediaType: result3.image.mediaType,
|
||||
hasBase64: !!result3.image.base64
|
||||
})
|
||||
} catch (error) {
|
||||
console.error('❌ Google Imagen 生成失败:', error)
|
||||
}
|
||||
}
|
||||
|
||||
// 方式4: 支持插件系统
|
||||
const pluginExample = async () => {
|
||||
console.log('🔌 演示插件系统...')
|
||||
|
||||
// 创建一个示例插件,用于修改提示词
|
||||
const promptEnhancerPlugin = {
|
||||
name: 'prompt-enhancer',
|
||||
transformParams: async (params: any) => {
|
||||
console.log('🔧 插件: 增强提示词...')
|
||||
return {
|
||||
...params,
|
||||
prompt: `${params.prompt}, highly detailed, cinematic lighting, 4K resolution`
|
||||
}
|
||||
},
|
||||
transformResult: async (result: any) => {
|
||||
console.log('🔧 插件: 处理结果...')
|
||||
return {
|
||||
...result,
|
||||
enhanced: true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const executorWithPlugin = createExecutor(
|
||||
'openai',
|
||||
{
|
||||
apiKey: process.env.OPENAI_API_KEY!
|
||||
},
|
||||
[promptEnhancerPlugin]
|
||||
)
|
||||
|
||||
try {
|
||||
const result4 = await executorWithPlugin.generateImage('dall-e-3', {
|
||||
prompt: 'A cute robot playing in a garden'
|
||||
})
|
||||
|
||||
console.log('✅ 插件系统生成成功!')
|
||||
console.log('📊 结果:', {
|
||||
imagesCount: result4.images.length,
|
||||
enhanced: (result4 as any).enhanced,
|
||||
mediaType: result4.image.mediaType
|
||||
})
|
||||
} catch (error) {
|
||||
console.error('❌ 插件系统生成失败:', error)
|
||||
}
|
||||
}
|
||||
|
||||
await pluginExample()
|
||||
}
|
||||
|
||||
// 错误处理演示
|
||||
async function errorHandlingExample() {
|
||||
console.log('⚠️ 演示错误处理...')
|
||||
|
||||
try {
|
||||
const executor = createExecutor('openai', {
|
||||
apiKey: 'invalid-key'
|
||||
})
|
||||
|
||||
await executor.generateImage('dall-e-3', {
|
||||
prompt: 'Test image'
|
||||
})
|
||||
} catch (error: any) {
|
||||
console.log('✅ 成功捕获错误:', error.constructor.name)
|
||||
console.log('📋 错误信息:', error.message)
|
||||
console.log('🏷️ 提供商ID:', error.providerId)
|
||||
console.log('🏷️ 模型ID:', error.modelId)
|
||||
}
|
||||
}
|
||||
|
||||
// 运行示例
|
||||
if (require.main === module) {
|
||||
main()
|
||||
.then(() => {
|
||||
console.log('🎉 所有示例完成!')
|
||||
return errorHandlingExample()
|
||||
})
|
||||
.then(() => {
|
||||
console.log('🎯 示例程序结束')
|
||||
process.exit(0)
|
||||
})
|
||||
.catch((error) => {
|
||||
console.error('💥 程序执行出错:', error)
|
||||
process.exit(1)
|
||||
})
|
||||
}
|
||||
85
packages/aiCore/package.json
Normal file
85
packages/aiCore/package.json
Normal file
@@ -0,0 +1,85 @@
|
||||
{
|
||||
"name": "@cherrystudio/ai-core",
|
||||
"version": "1.0.0-alpha.10",
|
||||
"description": "Cherry Studio AI Core - Unified AI Provider Interface Based on Vercel AI SDK",
|
||||
"main": "dist/index.js",
|
||||
"module": "dist/index.mjs",
|
||||
"types": "dist/index.d.ts",
|
||||
"react-native": "dist/index.js",
|
||||
"scripts": {
|
||||
"build": "tsdown",
|
||||
"dev": "tsc -w",
|
||||
"clean": "rm -rf dist",
|
||||
"test": "vitest run",
|
||||
"test:watch": "vitest"
|
||||
},
|
||||
"keywords": [
|
||||
"ai",
|
||||
"sdk",
|
||||
"openai",
|
||||
"anthropic",
|
||||
"google",
|
||||
"cherry-studio",
|
||||
"vercel-ai-sdk"
|
||||
],
|
||||
"author": "Cherry Studio",
|
||||
"license": "MIT",
|
||||
"repository": {
|
||||
"type": "git",
|
||||
"url": "git+https://github.com/CherryHQ/cherry-studio.git"
|
||||
},
|
||||
"bugs": {
|
||||
"url": "https://github.com/CherryHQ/cherry-studio/issues"
|
||||
},
|
||||
"homepage": "https://github.com/CherryHQ/cherry-studio#readme",
|
||||
"peerDependencies": {
|
||||
"ai": "^5.0.26"
|
||||
},
|
||||
"dependencies": {
|
||||
"@ai-sdk/anthropic": "^2.0.5",
|
||||
"@ai-sdk/azure": "^2.0.16",
|
||||
"@ai-sdk/deepseek": "^1.0.9",
|
||||
"@ai-sdk/google": "^2.0.7",
|
||||
"@ai-sdk/openai": "^2.0.19",
|
||||
"@ai-sdk/openai-compatible": "^1.0.9",
|
||||
"@ai-sdk/provider": "^2.0.0",
|
||||
"@ai-sdk/provider-utils": "^3.0.4",
|
||||
"@ai-sdk/xai": "^2.0.9",
|
||||
"zod": "^3.25.0"
|
||||
},
|
||||
"devDependencies": {
|
||||
"tsdown": "^0.12.9",
|
||||
"typescript": "^5.0.0",
|
||||
"vitest": "^3.2.4"
|
||||
},
|
||||
"sideEffects": false,
|
||||
"engines": {
|
||||
"node": ">=18.0.0"
|
||||
},
|
||||
"files": [
|
||||
"dist"
|
||||
],
|
||||
"exports": {
|
||||
".": {
|
||||
"types": "./dist/index.d.ts",
|
||||
"react-native": "./dist/index.js",
|
||||
"import": "./dist/index.mjs",
|
||||
"require": "./dist/index.js",
|
||||
"default": "./dist/index.js"
|
||||
},
|
||||
"./built-in/plugins": {
|
||||
"types": "./dist/built-in/plugins/index.d.ts",
|
||||
"react-native": "./dist/built-in/plugins/index.js",
|
||||
"import": "./dist/built-in/plugins/index.mjs",
|
||||
"require": "./dist/built-in/plugins/index.js",
|
||||
"default": "./dist/built-in/plugins/index.js"
|
||||
},
|
||||
"./provider": {
|
||||
"types": "./dist/provider/index.d.ts",
|
||||
"react-native": "./dist/provider/index.js",
|
||||
"import": "./dist/provider/index.mjs",
|
||||
"require": "./dist/provider/index.js",
|
||||
"default": "./dist/provider/index.js"
|
||||
}
|
||||
}
|
||||
}
|
||||
2
packages/aiCore/setupVitest.ts
Normal file
2
packages/aiCore/setupVitest.ts
Normal file
@@ -0,0 +1,2 @@
|
||||
// 模拟 Vite SSR helper,避免 Node 环境找不到时报错
|
||||
;(globalThis as any).__vite_ssr_exportName__ = (name: string, value: any) => value
|
||||
3
packages/aiCore/src/core/README.MD
Normal file
3
packages/aiCore/src/core/README.MD
Normal file
@@ -0,0 +1,3 @@
|
||||
# @cherryStudio-aiCore
|
||||
|
||||
Core
|
||||
17
packages/aiCore/src/core/index.ts
Normal file
17
packages/aiCore/src/core/index.ts
Normal file
@@ -0,0 +1,17 @@
|
||||
/**
|
||||
* Core 模块导出
|
||||
* 内部核心功能,供其他模块使用,不直接面向最终调用者
|
||||
*/
|
||||
|
||||
// 中间件系统
|
||||
export type { NamedMiddleware } from './middleware'
|
||||
export { createMiddlewares, wrapModelWithMiddlewares } from './middleware'
|
||||
|
||||
// 创建管理
|
||||
export { globalModelResolver, ModelResolver } from './models'
|
||||
export type { ModelConfig as ModelConfigType } from './models/types'
|
||||
|
||||
// 执行管理
|
||||
export type { ToolUseRequestContext } from './plugins/built-in/toolUsePlugin/type'
|
||||
export { createExecutor, createOpenAICompatibleExecutor } from './runtime'
|
||||
export type { RuntimeConfig } from './runtime/types'
|
||||
8
packages/aiCore/src/core/middleware/index.ts
Normal file
8
packages/aiCore/src/core/middleware/index.ts
Normal file
@@ -0,0 +1,8 @@
|
||||
/**
|
||||
* Middleware 模块导出
|
||||
* 提供通用的中间件管理能力
|
||||
*/
|
||||
|
||||
export { createMiddlewares } from './manager'
|
||||
export type { NamedMiddleware } from './types'
|
||||
export { wrapModelWithMiddlewares } from './wrapper'
|
||||
16
packages/aiCore/src/core/middleware/manager.ts
Normal file
16
packages/aiCore/src/core/middleware/manager.ts
Normal file
@@ -0,0 +1,16 @@
|
||||
/**
|
||||
* 中间件管理器
|
||||
* 专注于 AI SDK 中间件的管理,与插件系统分离
|
||||
*/
|
||||
import { LanguageModelV2Middleware } from '@ai-sdk/provider'
|
||||
|
||||
/**
|
||||
* 创建中间件列表
|
||||
* 合并用户提供的中间件
|
||||
*/
|
||||
export function createMiddlewares(userMiddlewares: LanguageModelV2Middleware[] = []): LanguageModelV2Middleware[] {
|
||||
// 未来可以在这里添加默认的中间件
|
||||
const defaultMiddlewares: LanguageModelV2Middleware[] = []
|
||||
|
||||
return [...defaultMiddlewares, ...userMiddlewares]
|
||||
}
|
||||
12
packages/aiCore/src/core/middleware/types.ts
Normal file
12
packages/aiCore/src/core/middleware/types.ts
Normal file
@@ -0,0 +1,12 @@
|
||||
/**
|
||||
* 中间件系统类型定义
|
||||
*/
|
||||
import { LanguageModelV2Middleware } from '@ai-sdk/provider'
|
||||
|
||||
/**
|
||||
* 具名中间件接口
|
||||
*/
|
||||
export interface NamedMiddleware {
|
||||
name: string
|
||||
middleware: LanguageModelV2Middleware
|
||||
}
|
||||
23
packages/aiCore/src/core/middleware/wrapper.ts
Normal file
23
packages/aiCore/src/core/middleware/wrapper.ts
Normal file
@@ -0,0 +1,23 @@
|
||||
/**
|
||||
* 模型包装工具函数
|
||||
* 用于将中间件应用到LanguageModel上
|
||||
*/
|
||||
import { LanguageModelV2, LanguageModelV2Middleware } from '@ai-sdk/provider'
|
||||
import { wrapLanguageModel } from 'ai'
|
||||
|
||||
/**
|
||||
* 使用中间件包装模型
|
||||
*/
|
||||
export function wrapModelWithMiddlewares(
|
||||
model: LanguageModelV2,
|
||||
middlewares: LanguageModelV2Middleware[]
|
||||
): LanguageModelV2 {
|
||||
if (middlewares.length === 0) {
|
||||
return model
|
||||
}
|
||||
|
||||
return wrapLanguageModel({
|
||||
model,
|
||||
middleware: middlewares
|
||||
})
|
||||
}
|
||||
130
packages/aiCore/src/core/models/ModelResolver.ts
Normal file
130
packages/aiCore/src/core/models/ModelResolver.ts
Normal file
@@ -0,0 +1,130 @@
|
||||
/**
|
||||
* 模型解析器 - models模块的核心
|
||||
* 负责将modelId解析为AI SDK的LanguageModel实例
|
||||
* 支持传统格式和命名空间格式
|
||||
* 集成了来自 ModelCreator 的特殊处理逻辑
|
||||
*/
|
||||
|
||||
import { EmbeddingModelV2, ImageModelV2, LanguageModelV2, LanguageModelV2Middleware } from '@ai-sdk/provider'
|
||||
|
||||
import { isOpenAIChatCompletionOnlyModel } from '../../utils/model'
|
||||
import { wrapModelWithMiddlewares } from '../middleware/wrapper'
|
||||
import { globalRegistryManagement } from '../providers/RegistryManagement'
|
||||
|
||||
export class ModelResolver {
|
||||
/**
|
||||
* 核心方法:解析任意格式的modelId为语言模型
|
||||
*
|
||||
* @param modelId 模型ID,支持 'gpt-4' 和 'anthropic>claude-3' 两种格式
|
||||
* @param fallbackProviderId 当modelId为传统格式时使用的providerId
|
||||
* @param providerOptions provider配置选项(用于OpenAI模式选择等)
|
||||
* @param middlewares 中间件数组,会应用到最终模型上
|
||||
*/
|
||||
async resolveLanguageModel(
|
||||
modelId: string,
|
||||
fallbackProviderId: string,
|
||||
providerOptions?: any,
|
||||
middlewares?: LanguageModelV2Middleware[]
|
||||
): Promise<LanguageModelV2> {
|
||||
let finalProviderId = fallbackProviderId
|
||||
let model: LanguageModelV2
|
||||
|
||||
// 🎯 处理 OpenAI 模式选择逻辑 (从 ModelCreator 迁移)
|
||||
if (fallbackProviderId === 'openai' && providerOptions?.mode === 'chat') {
|
||||
// 检查是否支持 chat 模式且不是只支持 chat 的模型
|
||||
if (!isOpenAIChatCompletionOnlyModel(modelId)) {
|
||||
finalProviderId = 'openai-chat'
|
||||
}
|
||||
// 否则使用默认的 openai (responses 模式)
|
||||
}
|
||||
|
||||
// 检查是否是命名空间格式
|
||||
if (modelId.includes(':')) {
|
||||
model = this.resolveNamespacedModel(modelId)
|
||||
} else {
|
||||
// 传统格式:使用处理后的 providerId + modelId
|
||||
model = this.resolveTraditionalModel(finalProviderId, modelId)
|
||||
}
|
||||
|
||||
// 🎯 应用中间件(如果有)
|
||||
if (middlewares && middlewares.length > 0) {
|
||||
model = wrapModelWithMiddlewares(model, middlewares)
|
||||
}
|
||||
|
||||
return model
|
||||
}
|
||||
|
||||
/**
|
||||
* 解析文本嵌入模型
|
||||
*/
|
||||
async resolveTextEmbeddingModel(modelId: string, fallbackProviderId: string): Promise<EmbeddingModelV2<string>> {
|
||||
if (modelId.includes(':')) {
|
||||
return this.resolveNamespacedEmbeddingModel(modelId)
|
||||
}
|
||||
|
||||
return this.resolveTraditionalEmbeddingModel(fallbackProviderId, modelId)
|
||||
}
|
||||
|
||||
/**
|
||||
* 解析图像模型
|
||||
*/
|
||||
async resolveImageModel(modelId: string, fallbackProviderId: string): Promise<ImageModelV2> {
|
||||
if (modelId.includes(':')) {
|
||||
return this.resolveNamespacedImageModel(modelId)
|
||||
}
|
||||
|
||||
return this.resolveTraditionalImageModel(fallbackProviderId, modelId)
|
||||
}
|
||||
|
||||
/**
|
||||
* 解析命名空间格式的语言模型
|
||||
* aihubmix:anthropic:claude-3 -> globalRegistryManagement.languageModel('aihubmix:anthropic:claude-3')
|
||||
*/
|
||||
private resolveNamespacedModel(modelId: string): LanguageModelV2 {
|
||||
return globalRegistryManagement.languageModel(modelId as any)
|
||||
}
|
||||
|
||||
/**
|
||||
* 解析传统格式的语言模型
|
||||
* providerId: 'openai', modelId: 'gpt-4' -> globalRegistryManagement.languageModel('openai:gpt-4')
|
||||
*/
|
||||
private resolveTraditionalModel(providerId: string, modelId: string): LanguageModelV2 {
|
||||
const fullModelId = `${providerId}:${modelId}`
|
||||
return globalRegistryManagement.languageModel(fullModelId as any)
|
||||
}
|
||||
|
||||
/**
|
||||
* 解析命名空间格式的嵌入模型
|
||||
*/
|
||||
private resolveNamespacedEmbeddingModel(modelId: string): EmbeddingModelV2<string> {
|
||||
return globalRegistryManagement.textEmbeddingModel(modelId as any)
|
||||
}
|
||||
|
||||
/**
|
||||
* 解析传统格式的嵌入模型
|
||||
*/
|
||||
private resolveTraditionalEmbeddingModel(providerId: string, modelId: string): EmbeddingModelV2<string> {
|
||||
const fullModelId = `${providerId}:${modelId}`
|
||||
return globalRegistryManagement.textEmbeddingModel(fullModelId as any)
|
||||
}
|
||||
|
||||
/**
|
||||
* 解析命名空间格式的图像模型
|
||||
*/
|
||||
private resolveNamespacedImageModel(modelId: string): ImageModelV2 {
|
||||
return globalRegistryManagement.imageModel(modelId as any)
|
||||
}
|
||||
|
||||
/**
|
||||
* 解析传统格式的图像模型
|
||||
*/
|
||||
private resolveTraditionalImageModel(providerId: string, modelId: string): ImageModelV2 {
|
||||
const fullModelId = `${providerId}:${modelId}`
|
||||
return globalRegistryManagement.imageModel(fullModelId as any)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 全局模型解析器实例
|
||||
*/
|
||||
export const globalModelResolver = new ModelResolver()
|
||||
9
packages/aiCore/src/core/models/index.ts
Normal file
9
packages/aiCore/src/core/models/index.ts
Normal file
@@ -0,0 +1,9 @@
|
||||
/**
|
||||
* Models 模块统一导出 - 简化版
|
||||
*/
|
||||
|
||||
// 核心模型解析器
|
||||
export { globalModelResolver, ModelResolver } from './ModelResolver'
|
||||
|
||||
// 保留的类型定义(可能被其他地方使用)
|
||||
export type { ModelConfig as ModelConfigType } from './types'
|
||||
15
packages/aiCore/src/core/models/types.ts
Normal file
15
packages/aiCore/src/core/models/types.ts
Normal file
@@ -0,0 +1,15 @@
|
||||
/**
|
||||
* Creation 模块类型定义
|
||||
*/
|
||||
import { LanguageModelV2Middleware } from '@ai-sdk/provider'
|
||||
|
||||
import type { ProviderId, ProviderSettingsMap } from '../providers/types'
|
||||
|
||||
export interface ModelConfig<T extends ProviderId = ProviderId> {
|
||||
providerId: T
|
||||
modelId: string
|
||||
providerSettings: ProviderSettingsMap[T] & { mode?: 'chat' | 'responses' }
|
||||
middlewares?: LanguageModelV2Middleware[]
|
||||
// 额外模型参数
|
||||
extraModelConfig?: Record<string, any>
|
||||
}
|
||||
87
packages/aiCore/src/core/options/examples.ts
Normal file
87
packages/aiCore/src/core/options/examples.ts
Normal file
@@ -0,0 +1,87 @@
|
||||
import { streamText } from 'ai'
|
||||
|
||||
import {
|
||||
createAnthropicOptions,
|
||||
createGenericProviderOptions,
|
||||
createGoogleOptions,
|
||||
createOpenAIOptions,
|
||||
mergeProviderOptions
|
||||
} from './factory'
|
||||
|
||||
// 示例1: 使用已知供应商的严格类型约束
|
||||
export function exampleOpenAIWithOptions() {
|
||||
const openaiOptions = createOpenAIOptions({
|
||||
reasoningEffort: 'medium'
|
||||
})
|
||||
|
||||
// 这里会有类型检查,确保选项符合OpenAI的设置
|
||||
return streamText({
|
||||
model: {} as any, // 实际使用时替换为真实模型
|
||||
prompt: 'Hello',
|
||||
providerOptions: openaiOptions
|
||||
})
|
||||
}
|
||||
|
||||
// 示例2: 使用Anthropic供应商选项
|
||||
export function exampleAnthropicWithOptions() {
|
||||
const anthropicOptions = createAnthropicOptions({
|
||||
thinking: {
|
||||
type: 'enabled',
|
||||
budgetTokens: 1000
|
||||
}
|
||||
})
|
||||
|
||||
return streamText({
|
||||
model: {} as any,
|
||||
prompt: 'Hello',
|
||||
providerOptions: anthropicOptions
|
||||
})
|
||||
}
|
||||
|
||||
// 示例3: 使用Google供应商选项
|
||||
export function exampleGoogleWithOptions() {
|
||||
const googleOptions = createGoogleOptions({
|
||||
thinkingConfig: {
|
||||
includeThoughts: true,
|
||||
thinkingBudget: 1000
|
||||
}
|
||||
})
|
||||
|
||||
return streamText({
|
||||
model: {} as any,
|
||||
prompt: 'Hello',
|
||||
providerOptions: googleOptions
|
||||
})
|
||||
}
|
||||
|
||||
// 示例4: 使用未知供应商(通用类型)
|
||||
export function exampleUnknownProviderWithOptions() {
|
||||
const customProviderOptions = createGenericProviderOptions('custom-provider', {
|
||||
temperature: 0.7,
|
||||
customSetting: 'value',
|
||||
anotherOption: true
|
||||
})
|
||||
|
||||
return streamText({
|
||||
model: {} as any,
|
||||
prompt: 'Hello',
|
||||
providerOptions: customProviderOptions
|
||||
})
|
||||
}
|
||||
|
||||
// 示例5: 合并多个供应商选项
|
||||
export function exampleMergedOptions() {
|
||||
const openaiOptions = createOpenAIOptions({})
|
||||
|
||||
const customOptions = createGenericProviderOptions('custom', {
|
||||
customParam: 'value'
|
||||
})
|
||||
|
||||
const mergedOptions = mergeProviderOptions(openaiOptions, customOptions)
|
||||
|
||||
return streamText({
|
||||
model: {} as any,
|
||||
prompt: 'Hello',
|
||||
providerOptions: mergedOptions
|
||||
})
|
||||
}
|
||||
71
packages/aiCore/src/core/options/factory.ts
Normal file
71
packages/aiCore/src/core/options/factory.ts
Normal file
@@ -0,0 +1,71 @@
|
||||
import { ExtractProviderOptions, ProviderOptionsMap, TypedProviderOptions } from './types'
|
||||
|
||||
/**
|
||||
* 创建特定供应商的选项
|
||||
* @param provider 供应商名称
|
||||
* @param options 供应商特定的选项
|
||||
* @returns 格式化的provider options
|
||||
*/
|
||||
export function createProviderOptions<T extends keyof ProviderOptionsMap>(
|
||||
provider: T,
|
||||
options: ExtractProviderOptions<T>
|
||||
): Record<T, ExtractProviderOptions<T>> {
|
||||
return { [provider]: options } as Record<T, ExtractProviderOptions<T>>
|
||||
}
|
||||
|
||||
/**
|
||||
* 创建任意供应商的选项(包括未知供应商)
|
||||
* @param provider 供应商名称
|
||||
* @param options 供应商选项
|
||||
* @returns 格式化的provider options
|
||||
*/
|
||||
export function createGenericProviderOptions<T extends string>(
|
||||
provider: T,
|
||||
options: Record<string, any>
|
||||
): Record<T, Record<string, any>> {
|
||||
return { [provider]: options } as Record<T, Record<string, any>>
|
||||
}
|
||||
|
||||
/**
|
||||
* 合并多个供应商的options
|
||||
* @param optionsMap 包含多个供应商选项的对象
|
||||
* @returns 合并后的TypedProviderOptions
|
||||
*/
|
||||
export function mergeProviderOptions(...optionsMap: Partial<TypedProviderOptions>[]): TypedProviderOptions {
|
||||
return Object.assign({}, ...optionsMap)
|
||||
}
|
||||
|
||||
/**
|
||||
* 创建OpenAI供应商选项的便捷函数
|
||||
*/
|
||||
export function createOpenAIOptions(options: ExtractProviderOptions<'openai'>) {
|
||||
return createProviderOptions('openai', options)
|
||||
}
|
||||
|
||||
/**
|
||||
* 创建Anthropic供应商选项的便捷函数
|
||||
*/
|
||||
export function createAnthropicOptions(options: ExtractProviderOptions<'anthropic'>) {
|
||||
return createProviderOptions('anthropic', options)
|
||||
}
|
||||
|
||||
/**
|
||||
* 创建Google供应商选项的便捷函数
|
||||
*/
|
||||
export function createGoogleOptions(options: ExtractProviderOptions<'google'>) {
|
||||
return createProviderOptions('google', options)
|
||||
}
|
||||
|
||||
/**
|
||||
* 创建OpenRouter供应商选项的便捷函数
|
||||
*/
|
||||
export function createOpenRouterOptions(options: ExtractProviderOptions<'openrouter'>) {
|
||||
return createProviderOptions('openrouter', options)
|
||||
}
|
||||
|
||||
/**
|
||||
* 创建XAI供应商选项的便捷函数
|
||||
*/
|
||||
export function createXaiOptions(options: ExtractProviderOptions<'xai'>) {
|
||||
return createProviderOptions('xai', options)
|
||||
}
|
||||
2
packages/aiCore/src/core/options/index.ts
Normal file
2
packages/aiCore/src/core/options/index.ts
Normal file
@@ -0,0 +1,2 @@
|
||||
export * from './factory'
|
||||
export * from './types'
|
||||
38
packages/aiCore/src/core/options/openrouter.ts
Normal file
38
packages/aiCore/src/core/options/openrouter.ts
Normal file
@@ -0,0 +1,38 @@
|
||||
export type OpenRouterProviderOptions = {
|
||||
models?: string[]
|
||||
|
||||
/**
|
||||
* https://openrouter.ai/docs/use-cases/reasoning-tokens
|
||||
* One of `max_tokens` or `effort` is required.
|
||||
* If `exclude` is true, reasoning will be removed from the response. Default is false.
|
||||
*/
|
||||
reasoning?: {
|
||||
exclude?: boolean
|
||||
} & (
|
||||
| {
|
||||
max_tokens: number
|
||||
}
|
||||
| {
|
||||
effort: 'high' | 'medium' | 'low'
|
||||
}
|
||||
)
|
||||
|
||||
/**
|
||||
* A unique identifier representing your end-user, which can
|
||||
* help OpenRouter to monitor and detect abuse.
|
||||
*/
|
||||
user?: string
|
||||
|
||||
extraBody?: Record<string, unknown>
|
||||
|
||||
/**
|
||||
* Enable usage accounting to get detailed token usage information.
|
||||
* https://openrouter.ai/docs/use-cases/usage-accounting
|
||||
*/
|
||||
usage?: {
|
||||
/**
|
||||
* When true, includes token usage information in the response.
|
||||
*/
|
||||
include: boolean
|
||||
}
|
||||
}
|
||||
33
packages/aiCore/src/core/options/types.ts
Normal file
33
packages/aiCore/src/core/options/types.ts
Normal file
@@ -0,0 +1,33 @@
|
||||
import { type AnthropicProviderOptions } from '@ai-sdk/anthropic'
|
||||
import { type GoogleGenerativeAIProviderOptions } from '@ai-sdk/google'
|
||||
import { type OpenAIResponsesProviderOptions } from '@ai-sdk/openai'
|
||||
import { type SharedV2ProviderMetadata } from '@ai-sdk/provider'
|
||||
|
||||
import { type OpenRouterProviderOptions } from './openrouter'
|
||||
import { type XaiProviderOptions } from './xai'
|
||||
|
||||
export type ProviderOptions<T extends keyof SharedV2ProviderMetadata> = SharedV2ProviderMetadata[T]
|
||||
|
||||
/**
|
||||
* 供应商选项类型,如果map中没有,说明没有约束
|
||||
*/
|
||||
export type ProviderOptionsMap = {
|
||||
openai: OpenAIResponsesProviderOptions
|
||||
anthropic: AnthropicProviderOptions
|
||||
google: GoogleGenerativeAIProviderOptions
|
||||
openrouter: OpenRouterProviderOptions
|
||||
xai: XaiProviderOptions
|
||||
}
|
||||
|
||||
// 工具类型,用于从ProviderOptionsMap中提取特定供应商的选项类型
|
||||
export type ExtractProviderOptions<T extends keyof ProviderOptionsMap> = ProviderOptionsMap[T]
|
||||
|
||||
/**
|
||||
* 类型安全的ProviderOptions
|
||||
* 对于已知供应商使用严格类型,对于未知供应商允许任意Record<string, JSONValue>
|
||||
*/
|
||||
export type TypedProviderOptions = {
|
||||
[K in keyof ProviderOptionsMap]?: ProviderOptionsMap[K]
|
||||
} & {
|
||||
[K in string]?: Record<string, any>
|
||||
} & SharedV2ProviderMetadata
|
||||
86
packages/aiCore/src/core/options/xai.ts
Normal file
86
packages/aiCore/src/core/options/xai.ts
Normal file
@@ -0,0 +1,86 @@
|
||||
// copy from @ai-sdk/xai/xai-chat-options.ts
|
||||
// 如果@ai-sdk/xai暴露出了xaiProviderOptions就删除这个文件
|
||||
|
||||
import * as z from 'zod/v4'
|
||||
|
||||
const webSourceSchema = z.object({
|
||||
type: z.literal('web'),
|
||||
country: z.string().length(2).optional(),
|
||||
excludedWebsites: z.array(z.string()).max(5).optional(),
|
||||
allowedWebsites: z.array(z.string()).max(5).optional(),
|
||||
safeSearch: z.boolean().optional()
|
||||
})
|
||||
|
||||
const xSourceSchema = z.object({
|
||||
type: z.literal('x'),
|
||||
xHandles: z.array(z.string()).optional()
|
||||
})
|
||||
|
||||
const newsSourceSchema = z.object({
|
||||
type: z.literal('news'),
|
||||
country: z.string().length(2).optional(),
|
||||
excludedWebsites: z.array(z.string()).max(5).optional(),
|
||||
safeSearch: z.boolean().optional()
|
||||
})
|
||||
|
||||
const rssSourceSchema = z.object({
|
||||
type: z.literal('rss'),
|
||||
links: z.array(z.url()).max(1) // currently only supports one RSS link
|
||||
})
|
||||
|
||||
const searchSourceSchema = z.discriminatedUnion('type', [
|
||||
webSourceSchema,
|
||||
xSourceSchema,
|
||||
newsSourceSchema,
|
||||
rssSourceSchema
|
||||
])
|
||||
|
||||
export const xaiProviderOptions = z.object({
|
||||
/**
|
||||
* reasoning effort for reasoning models
|
||||
* only supported by grok-3-mini and grok-3-mini-fast models
|
||||
*/
|
||||
reasoningEffort: z.enum(['low', 'high']).optional(),
|
||||
|
||||
searchParameters: z
|
||||
.object({
|
||||
/**
|
||||
* search mode preference
|
||||
* - "off": disables search completely
|
||||
* - "auto": model decides whether to search (default)
|
||||
* - "on": always enables search
|
||||
*/
|
||||
mode: z.enum(['off', 'auto', 'on']),
|
||||
|
||||
/**
|
||||
* whether to return citations in the response
|
||||
* defaults to true
|
||||
*/
|
||||
returnCitations: z.boolean().optional(),
|
||||
|
||||
/**
|
||||
* start date for search data (ISO8601 format: YYYY-MM-DD)
|
||||
*/
|
||||
fromDate: z.string().optional(),
|
||||
|
||||
/**
|
||||
* end date for search data (ISO8601 format: YYYY-MM-DD)
|
||||
*/
|
||||
toDate: z.string().optional(),
|
||||
|
||||
/**
|
||||
* maximum number of search results to consider
|
||||
* defaults to 20
|
||||
*/
|
||||
maxSearchResults: z.number().min(1).max(50).optional(),
|
||||
|
||||
/**
|
||||
* data sources to search from
|
||||
* defaults to ["web", "x"] if not specified
|
||||
*/
|
||||
sources: z.array(searchSourceSchema).optional()
|
||||
})
|
||||
.optional()
|
||||
})
|
||||
|
||||
export type XaiProviderOptions = z.infer<typeof xaiProviderOptions>
|
||||
257
packages/aiCore/src/core/plugins/README.md
Normal file
257
packages/aiCore/src/core/plugins/README.md
Normal file
@@ -0,0 +1,257 @@
|
||||
# AI Core 插件系统
|
||||
|
||||
支持四种钩子类型:**First**、**Sequential**、**Parallel** 和 **Stream**。
|
||||
|
||||
## 🎯 设计理念
|
||||
|
||||
- **语义清晰**:不同钩子有不同的执行语义
|
||||
- **类型安全**:TypeScript 完整支持
|
||||
- **性能优化**:First 短路、Parallel 并发、Sequential 链式
|
||||
- **易于扩展**:`enforce` 排序 + 功能分类
|
||||
|
||||
## 📋 钩子类型
|
||||
|
||||
### 🥇 First 钩子 - 首个有效结果
|
||||
|
||||
```typescript
|
||||
// 只执行第一个返回值的插件,用于解析和查找
|
||||
resolveModel?: (modelId: string, context: AiRequestContext) => string | null
|
||||
loadTemplate?: (templateName: string, context: AiRequestContext) => any | null
|
||||
```
|
||||
|
||||
### 🔄 Sequential 钩子 - 链式数据转换
|
||||
|
||||
```typescript
|
||||
// 按顺序链式执行,每个插件可以修改数据
|
||||
transformParams?: (params: any, context: AiRequestContext) => any
|
||||
transformResult?: (result: any, context: AiRequestContext) => any
|
||||
```
|
||||
|
||||
### ⚡ Parallel 钩子 - 并行副作用
|
||||
|
||||
```typescript
|
||||
// 并发执行,用于日志、监控等副作用
|
||||
onRequestStart?: (context: AiRequestContext) => void
|
||||
onRequestEnd?: (context: AiRequestContext, result: any) => void
|
||||
onError?: (error: Error, context: AiRequestContext) => void
|
||||
```
|
||||
|
||||
### 🌊 Stream 钩子 - 流处理
|
||||
|
||||
```typescript
|
||||
// 直接使用 AI SDK 的 TransformStream
|
||||
transformStream?: () => (options) => TransformStream<TextStreamPart, TextStreamPart>
|
||||
```
|
||||
|
||||
## 🚀 快速开始
|
||||
|
||||
### 基础用法
|
||||
|
||||
```typescript
|
||||
import { PluginManager, createContext, definePlugin } from '@cherrystudio/ai-core/middleware'
|
||||
|
||||
// 创建插件管理器
|
||||
const pluginManager = new PluginManager()
|
||||
|
||||
// 添加插件
|
||||
pluginManager.use({
|
||||
name: 'my-plugin',
|
||||
async transformParams(params, context) {
|
||||
return { ...params, temperature: 0.7 }
|
||||
}
|
||||
})
|
||||
|
||||
// 使用插件
|
||||
const context = createContext('openai', 'gpt-4', { messages: [] })
|
||||
const transformedParams = await pluginManager.executeSequential(
|
||||
'transformParams',
|
||||
{ messages: [{ role: 'user', content: 'Hello' }] },
|
||||
context
|
||||
)
|
||||
```
|
||||
|
||||
### 完整示例
|
||||
|
||||
```typescript
|
||||
import {
|
||||
PluginManager,
|
||||
ModelAliasPlugin,
|
||||
LoggingPlugin,
|
||||
ParamsValidationPlugin,
|
||||
createContext
|
||||
} from '@cherrystudio/ai-core/middleware'
|
||||
|
||||
// 创建插件管理器
|
||||
const manager = new PluginManager([
|
||||
ModelAliasPlugin, // 模型别名解析
|
||||
ParamsValidationPlugin, // 参数验证
|
||||
LoggingPlugin // 日志记录
|
||||
])
|
||||
|
||||
// AI 请求流程
|
||||
async function aiRequest(providerId: string, modelId: string, params: any) {
|
||||
const context = createContext(providerId, modelId, params)
|
||||
|
||||
try {
|
||||
// 1. 【并行】触发请求开始事件
|
||||
await manager.executeParallel('onRequestStart', context)
|
||||
|
||||
// 2. 【首个】解析模型别名
|
||||
const resolvedModel = await manager.executeFirst('resolveModel', modelId, context)
|
||||
context.modelId = resolvedModel || modelId
|
||||
|
||||
// 3. 【串行】转换请求参数
|
||||
const transformedParams = await manager.executeSequential('transformParams', params, context)
|
||||
|
||||
// 4. 【流处理】收集流转换器(AI SDK 原生支持数组)
|
||||
const streamTransforms = manager.collectStreamTransforms()
|
||||
|
||||
// 5. 调用 AI SDK(这里省略具体实现)
|
||||
const result = await callAiSdk(transformedParams, streamTransforms)
|
||||
|
||||
// 6. 【串行】转换响应结果
|
||||
const transformedResult = await manager.executeSequential('transformResult', result, context)
|
||||
|
||||
// 7. 【并行】触发请求完成事件
|
||||
await manager.executeParallel('onRequestEnd', context, transformedResult)
|
||||
|
||||
return transformedResult
|
||||
} catch (error) {
|
||||
// 8. 【并行】触发错误事件
|
||||
await manager.executeParallel('onError', context, undefined, error)
|
||||
throw error
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## 🔧 自定义插件
|
||||
|
||||
### 模型别名插件
|
||||
|
||||
```typescript
|
||||
const ModelAliasPlugin = definePlugin({
|
||||
name: 'model-alias',
|
||||
enforce: 'pre', // 最先执行
|
||||
|
||||
async resolveModel(modelId) {
|
||||
const aliases = {
|
||||
gpt4: 'gpt-4-turbo-preview',
|
||||
claude: 'claude-3-sonnet-20240229'
|
||||
}
|
||||
return aliases[modelId] || null
|
||||
}
|
||||
})
|
||||
```
|
||||
|
||||
### 参数验证插件
|
||||
|
||||
```typescript
|
||||
const ValidationPlugin = definePlugin({
|
||||
name: 'validation',
|
||||
|
||||
async transformParams(params) {
|
||||
if (!params.messages) {
|
||||
throw new Error('messages is required')
|
||||
}
|
||||
|
||||
return {
|
||||
...params,
|
||||
temperature: params.temperature ?? 0.7,
|
||||
max_tokens: params.max_tokens ?? 4096
|
||||
}
|
||||
}
|
||||
})
|
||||
```
|
||||
|
||||
### 监控插件
|
||||
|
||||
```typescript
|
||||
const MonitoringPlugin = definePlugin({
|
||||
name: 'monitoring',
|
||||
enforce: 'post', // 最后执行
|
||||
|
||||
async onRequestEnd(context, result) {
|
||||
const duration = Date.now() - context.startTime
|
||||
console.log(`请求耗时: ${duration}ms`)
|
||||
}
|
||||
})
|
||||
```
|
||||
|
||||
### 内容过滤插件
|
||||
|
||||
```typescript
|
||||
const FilterPlugin = definePlugin({
|
||||
name: 'content-filter',
|
||||
|
||||
transformStream() {
|
||||
return () =>
|
||||
new TransformStream({
|
||||
transform(chunk, controller) {
|
||||
if (chunk.type === 'text-delta') {
|
||||
const filtered = chunk.textDelta.replace(/敏感词/g, '***')
|
||||
controller.enqueue({ ...chunk, textDelta: filtered })
|
||||
} else {
|
||||
controller.enqueue(chunk)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
```
|
||||
|
||||
## 📊 执行顺序
|
||||
|
||||
### 插件排序
|
||||
|
||||
```
|
||||
enforce: 'pre' → normal → enforce: 'post'
|
||||
```
|
||||
|
||||
### 钩子执行流程
|
||||
|
||||
```mermaid
|
||||
graph TD
|
||||
A[请求开始] --> B[onRequestStart 并行执行]
|
||||
B --> C[resolveModel 首个有效]
|
||||
C --> D[loadTemplate 首个有效]
|
||||
D --> E[transformParams 串行执行]
|
||||
E --> F[collectStreamTransforms]
|
||||
F --> G[AI SDK 调用]
|
||||
G --> H[transformResult 串行执行]
|
||||
H --> I[onRequestEnd 并行执行]
|
||||
|
||||
G --> J[异常处理]
|
||||
J --> K[onError 并行执行]
|
||||
```
|
||||
|
||||
## 💡 最佳实践
|
||||
|
||||
1. **功能单一**:每个插件专注一个功能
|
||||
2. **幂等性**:插件应该是幂等的,重复执行不会产生副作用
|
||||
3. **错误处理**:插件内部处理异常,不要让异常向上传播
|
||||
4. **性能优化**:使用合适的钩子类型(First vs Sequential vs Parallel)
|
||||
5. **命名规范**:使用语义化的插件名称
|
||||
|
||||
## 🔍 调试工具
|
||||
|
||||
```typescript
|
||||
// 查看插件统计信息
|
||||
const stats = manager.getStats()
|
||||
console.log('插件统计:', stats)
|
||||
|
||||
// 查看所有插件
|
||||
const plugins = manager.getPlugins()
|
||||
console.log(
|
||||
'已注册插件:',
|
||||
plugins.map((p) => p.name)
|
||||
)
|
||||
```
|
||||
|
||||
## ⚡ 性能优势
|
||||
|
||||
- **First 钩子**:一旦找到结果立即停止,避免无效计算
|
||||
- **Parallel 钩子**:真正并发执行,不阻塞主流程
|
||||
- **Sequential 钩子**:保证数据转换的顺序性
|
||||
- **Stream 钩子**:直接集成 AI SDK,零开销
|
||||
|
||||
这个设计兼顾了简洁性和强大功能,为 AI Core 提供了灵活而高效的扩展机制。
|
||||
10
packages/aiCore/src/core/plugins/built-in/index.ts
Normal file
10
packages/aiCore/src/core/plugins/built-in/index.ts
Normal file
@@ -0,0 +1,10 @@
|
||||
/**
|
||||
* 内置插件命名空间
|
||||
* 所有内置插件都以 'built-in:' 为前缀
|
||||
*/
|
||||
export const BUILT_IN_PLUGIN_PREFIX = 'built-in:'
|
||||
|
||||
export { createLoggingPlugin } from './logging'
|
||||
export { createPromptToolUsePlugin } from './toolUsePlugin/promptToolUsePlugin'
|
||||
export type { PromptToolUseConfig, ToolUseRequestContext, ToolUseResult } from './toolUsePlugin/type'
|
||||
export { webSearchPlugin } from './webSearchPlugin'
|
||||
86
packages/aiCore/src/core/plugins/built-in/logging.ts
Normal file
86
packages/aiCore/src/core/plugins/built-in/logging.ts
Normal file
@@ -0,0 +1,86 @@
|
||||
/**
|
||||
* 内置插件:日志记录
|
||||
* 记录AI调用的关键信息,支持性能监控和调试
|
||||
*/
|
||||
import { definePlugin } from '../index'
|
||||
import type { AiRequestContext } from '../types'
|
||||
|
||||
export interface LoggingConfig {
|
||||
// 日志级别
|
||||
level?: 'debug' | 'info' | 'warn' | 'error'
|
||||
// 是否记录参数
|
||||
logParams?: boolean
|
||||
// 是否记录结果
|
||||
logResult?: boolean
|
||||
// 是否记录性能数据
|
||||
logPerformance?: boolean
|
||||
// 自定义日志函数
|
||||
logger?: (level: string, message: string, data?: any) => void
|
||||
}
|
||||
|
||||
/**
|
||||
* 创建日志插件
|
||||
*/
|
||||
export function createLoggingPlugin(config: LoggingConfig = {}) {
|
||||
const { level = 'info', logParams = true, logResult = false, logPerformance = true, logger = console.log } = config
|
||||
|
||||
const startTimes = new Map<string, number>()
|
||||
|
||||
return definePlugin({
|
||||
name: 'built-in:logging',
|
||||
|
||||
onRequestStart: (context: AiRequestContext) => {
|
||||
const requestId = context.requestId
|
||||
startTimes.set(requestId, Date.now())
|
||||
|
||||
logger(level, `🚀 AI Request Started`, {
|
||||
requestId,
|
||||
providerId: context.providerId,
|
||||
modelId: context.modelId,
|
||||
originalParams: logParams ? context.originalParams : '[hidden]'
|
||||
})
|
||||
},
|
||||
|
||||
onRequestEnd: (context: AiRequestContext, result: any) => {
|
||||
const requestId = context.requestId
|
||||
const startTime = startTimes.get(requestId)
|
||||
const duration = startTime ? Date.now() - startTime : undefined
|
||||
startTimes.delete(requestId)
|
||||
|
||||
const logData: any = {
|
||||
requestId,
|
||||
providerId: context.providerId,
|
||||
modelId: context.modelId
|
||||
}
|
||||
|
||||
if (logPerformance && duration) {
|
||||
logData.duration = `${duration}ms`
|
||||
}
|
||||
|
||||
if (logResult) {
|
||||
logData.result = result
|
||||
}
|
||||
|
||||
logger(level, `✅ AI Request Completed`, logData)
|
||||
},
|
||||
|
||||
onError: (error: Error, context: AiRequestContext) => {
|
||||
const requestId = context.requestId
|
||||
const startTime = startTimes.get(requestId)
|
||||
const duration = startTime ? Date.now() - startTime : undefined
|
||||
startTimes.delete(requestId)
|
||||
|
||||
logger('error', `❌ AI Request Failed`, {
|
||||
requestId,
|
||||
providerId: context.providerId,
|
||||
modelId: context.modelId,
|
||||
duration: duration ? `${duration}ms` : undefined,
|
||||
error: {
|
||||
name: error.name,
|
||||
message: error.message,
|
||||
stack: error.stack
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,139 @@
|
||||
/**
|
||||
* 流事件管理器
|
||||
*
|
||||
* 负责处理 AI SDK 流事件的发送和管理
|
||||
* 从 promptToolUsePlugin.ts 中提取出来以降低复杂度
|
||||
*/
|
||||
import type { ModelMessage } from 'ai'
|
||||
|
||||
import type { AiRequestContext } from '../../types'
|
||||
import type { StreamController } from './ToolExecutor'
|
||||
|
||||
/**
|
||||
* 流事件管理器类
|
||||
*/
|
||||
export class StreamEventManager {
|
||||
/**
|
||||
* 发送工具调用步骤开始事件
|
||||
*/
|
||||
sendStepStartEvent(controller: StreamController): void {
|
||||
controller.enqueue({
|
||||
type: 'start-step',
|
||||
request: {},
|
||||
warnings: []
|
||||
})
|
||||
}
|
||||
|
||||
/**
|
||||
* 发送步骤完成事件
|
||||
*/
|
||||
sendStepFinishEvent(controller: StreamController, chunk: any): void {
|
||||
controller.enqueue({
|
||||
type: 'finish-step',
|
||||
finishReason: 'stop',
|
||||
response: chunk.response,
|
||||
usage: chunk.usage,
|
||||
providerMetadata: chunk.providerMetadata
|
||||
})
|
||||
}
|
||||
|
||||
/**
|
||||
* 处理递归调用并将结果流接入当前流
|
||||
*/
|
||||
async handleRecursiveCall(
|
||||
controller: StreamController,
|
||||
recursiveParams: any,
|
||||
context: AiRequestContext,
|
||||
stepId: string
|
||||
): Promise<void> {
|
||||
try {
|
||||
console.log('[MCP Prompt] Starting recursive call after tool execution...')
|
||||
|
||||
const recursiveResult = await context.recursiveCall(recursiveParams)
|
||||
|
||||
if (recursiveResult && recursiveResult.fullStream) {
|
||||
await this.pipeRecursiveStream(controller, recursiveResult.fullStream)
|
||||
} else {
|
||||
console.warn('[MCP Prompt] No fullstream found in recursive result:', recursiveResult)
|
||||
}
|
||||
} catch (error) {
|
||||
this.handleRecursiveCallError(controller, error, stepId)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 将递归流的数据传递到当前流
|
||||
*/
|
||||
private async pipeRecursiveStream(controller: StreamController, recursiveStream: ReadableStream): Promise<void> {
|
||||
const reader = recursiveStream.getReader()
|
||||
try {
|
||||
while (true) {
|
||||
const { done, value } = await reader.read()
|
||||
if (done) {
|
||||
break
|
||||
}
|
||||
if (value.type === 'finish') {
|
||||
// 迭代的流不发finish
|
||||
break
|
||||
}
|
||||
// 将递归流的数据传递到当前流
|
||||
controller.enqueue(value)
|
||||
}
|
||||
} finally {
|
||||
reader.releaseLock()
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 处理递归调用错误
|
||||
*/
|
||||
private handleRecursiveCallError(controller: StreamController, error: unknown, stepId: string): void {
|
||||
console.error('[MCP Prompt] Recursive call failed:', error)
|
||||
|
||||
// 使用 AI SDK 标准错误格式,但不中断流
|
||||
controller.enqueue({
|
||||
type: 'error',
|
||||
error: {
|
||||
message: error instanceof Error ? error.message : String(error),
|
||||
name: error instanceof Error ? error.name : 'RecursiveCallError'
|
||||
}
|
||||
})
|
||||
|
||||
// 继续发送文本增量,保持流的连续性
|
||||
controller.enqueue({
|
||||
type: 'text-delta',
|
||||
id: stepId,
|
||||
text: '\n\n[工具执行后递归调用失败,继续对话...]'
|
||||
})
|
||||
}
|
||||
|
||||
/**
|
||||
* 构建递归调用的参数
|
||||
*/
|
||||
buildRecursiveParams(context: AiRequestContext, textBuffer: string, toolResultsText: string, tools: any): any {
|
||||
// 构建新的对话消息
|
||||
const newMessages: ModelMessage[] = [
|
||||
...(context.originalParams.messages || []),
|
||||
{
|
||||
role: 'assistant',
|
||||
content: textBuffer
|
||||
},
|
||||
{
|
||||
role: 'user',
|
||||
content: toolResultsText
|
||||
}
|
||||
]
|
||||
|
||||
// 递归调用,继续对话,重新传递 tools
|
||||
const recursiveParams = {
|
||||
...context.originalParams,
|
||||
messages: newMessages,
|
||||
tools: tools
|
||||
}
|
||||
|
||||
// 更新上下文中的消息
|
||||
context.originalParams.messages = newMessages
|
||||
|
||||
return recursiveParams
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,156 @@
|
||||
/**
|
||||
* 工具执行器
|
||||
*
|
||||
* 负责工具的执行、结果格式化和相关事件发送
|
||||
* 从 promptToolUsePlugin.ts 中提取出来以降低复杂度
|
||||
*/
|
||||
import type { ToolSet } from 'ai'
|
||||
|
||||
import type { ToolUseResult } from './type'
|
||||
|
||||
/**
|
||||
* 工具执行结果
|
||||
*/
|
||||
export interface ExecutedResult {
|
||||
toolCallId: string
|
||||
toolName: string
|
||||
result: any
|
||||
isError?: boolean
|
||||
}
|
||||
|
||||
/**
|
||||
* 流控制器类型(从 AI SDK 提取)
|
||||
*/
|
||||
export interface StreamController {
|
||||
enqueue(chunk: any): void
|
||||
}
|
||||
|
||||
/**
|
||||
* 工具执行器类
|
||||
*/
|
||||
export class ToolExecutor {
|
||||
/**
|
||||
* 执行多个工具调用
|
||||
*/
|
||||
async executeTools(
|
||||
toolUses: ToolUseResult[],
|
||||
tools: ToolSet,
|
||||
controller: StreamController
|
||||
): Promise<ExecutedResult[]> {
|
||||
const executedResults: ExecutedResult[] = []
|
||||
|
||||
for (const toolUse of toolUses) {
|
||||
try {
|
||||
const tool = tools[toolUse.toolName]
|
||||
if (!tool || typeof tool.execute !== 'function') {
|
||||
throw new Error(`Tool "${toolUse.toolName}" has no execute method`)
|
||||
}
|
||||
|
||||
// 发送工具调用开始事件
|
||||
this.sendToolStartEvents(controller, toolUse)
|
||||
|
||||
console.log(`[MCP Prompt Stream] Executing tool: ${toolUse.toolName}`, toolUse.arguments)
|
||||
|
||||
// 发送 tool-call 事件
|
||||
controller.enqueue({
|
||||
type: 'tool-call',
|
||||
toolCallId: toolUse.id,
|
||||
toolName: toolUse.toolName,
|
||||
input: tool.inputSchema
|
||||
})
|
||||
|
||||
const result = await tool.execute(toolUse.arguments, {
|
||||
toolCallId: toolUse.id,
|
||||
messages: [],
|
||||
abortSignal: new AbortController().signal
|
||||
})
|
||||
|
||||
// 发送 tool-result 事件
|
||||
controller.enqueue({
|
||||
type: 'tool-result',
|
||||
toolCallId: toolUse.id,
|
||||
toolName: toolUse.toolName,
|
||||
input: toolUse.arguments,
|
||||
output: result
|
||||
})
|
||||
|
||||
executedResults.push({
|
||||
toolCallId: toolUse.id,
|
||||
toolName: toolUse.toolName,
|
||||
result,
|
||||
isError: false
|
||||
})
|
||||
} catch (error) {
|
||||
console.error(`[MCP Prompt Stream] Tool execution failed: ${toolUse.toolName}`, error)
|
||||
|
||||
// 处理错误情况
|
||||
const errorResult = this.handleToolError(toolUse, error, controller)
|
||||
executedResults.push(errorResult)
|
||||
}
|
||||
}
|
||||
|
||||
return executedResults
|
||||
}
|
||||
|
||||
/**
|
||||
* 格式化工具结果为 Cherry Studio 标准格式
|
||||
*/
|
||||
formatToolResults(executedResults: ExecutedResult[]): string {
|
||||
return executedResults
|
||||
.map((tr) => {
|
||||
if (!tr.isError) {
|
||||
return `<tool_use_result>\n <name>${tr.toolName}</name>\n <result>${JSON.stringify(tr.result)}</result>\n</tool_use_result>`
|
||||
} else {
|
||||
const error = tr.result || 'Unknown error'
|
||||
return `<tool_use_result>\n <name>${tr.toolName}</name>\n <error>${error}</error>\n</tool_use_result>`
|
||||
}
|
||||
})
|
||||
.join('\n\n')
|
||||
}
|
||||
|
||||
/**
|
||||
* 发送工具调用开始相关事件
|
||||
*/
|
||||
private sendToolStartEvents(controller: StreamController, toolUse: ToolUseResult): void {
|
||||
// 发送 tool-input-start 事件
|
||||
controller.enqueue({
|
||||
type: 'tool-input-start',
|
||||
id: toolUse.id,
|
||||
toolName: toolUse.toolName
|
||||
})
|
||||
}
|
||||
|
||||
/**
|
||||
* 处理工具执行错误
|
||||
*/
|
||||
private handleToolError(
|
||||
toolUse: ToolUseResult,
|
||||
error: unknown,
|
||||
controller: StreamController
|
||||
// _tools: ToolSet
|
||||
): ExecutedResult {
|
||||
// 使用 AI SDK 标准错误格式
|
||||
// const toolError: TypedToolError<typeof _tools> = {
|
||||
// type: 'tool-error',
|
||||
// toolCallId: toolUse.id,
|
||||
// toolName: toolUse.toolName,
|
||||
// input: toolUse.arguments,
|
||||
// error: error instanceof Error ? error.message : String(error)
|
||||
// }
|
||||
|
||||
// controller.enqueue(toolError)
|
||||
|
||||
// 发送标准错误事件
|
||||
controller.enqueue({
|
||||
type: 'error',
|
||||
error: error instanceof Error ? error.message : String(error)
|
||||
})
|
||||
|
||||
return {
|
||||
toolCallId: toolUse.id,
|
||||
toolName: toolUse.toolName,
|
||||
result: error instanceof Error ? error.message : String(error),
|
||||
isError: true
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,373 @@
|
||||
/**
|
||||
* 内置插件:MCP Prompt 模式
|
||||
* 为不支持原生 Function Call 的模型提供 prompt 方式的工具调用
|
||||
* 内置默认逻辑,支持自定义覆盖
|
||||
*/
|
||||
import type { TextStreamPart, ToolSet } from 'ai'
|
||||
|
||||
import { definePlugin } from '../../index'
|
||||
import type { AiRequestContext } from '../../types'
|
||||
import { StreamEventManager } from './StreamEventManager'
|
||||
import { ToolExecutor } from './ToolExecutor'
|
||||
import { PromptToolUseConfig, ToolUseResult } from './type'
|
||||
|
||||
/**
|
||||
* 默认系统提示符模板(提取自 Cherry Studio)
|
||||
*/
|
||||
const DEFAULT_SYSTEM_PROMPT = `In this environment you have access to a set of tools you can use to answer the user's question. \\
|
||||
You can use one tool per message, and will receive the result of that tool use in the user's response. You use tools step-by-step to accomplish a given task, with each tool use informed by the result of the previous tool use.
|
||||
|
||||
## Tool Use Formatting
|
||||
|
||||
Tool use is formatted using XML-style tags. The tool name is enclosed in opening and closing tags, and each parameter is similarly enclosed within its own set of tags. Here's the structure:
|
||||
|
||||
<tool_use>
|
||||
<name>{tool_name}</name>
|
||||
<arguments>{json_arguments}</arguments>
|
||||
</tool_use>
|
||||
|
||||
The tool name should be the exact name of the tool you are using, and the arguments should be a JSON object containing the parameters required by that tool. For example:
|
||||
<tool_use>
|
||||
<name>python_interpreter</name>
|
||||
<arguments>{"code": "5 + 3 + 1294.678"}</arguments>
|
||||
</tool_use>
|
||||
|
||||
The user will respond with the result of the tool use, which should be formatted as follows:
|
||||
|
||||
<tool_use_result>
|
||||
<name>{tool_name}</name>
|
||||
<result>{result}</result>
|
||||
</tool_use_result>
|
||||
|
||||
The result should be a string, which can represent a file or any other output type. You can use this result as input for the next action.
|
||||
For example, if the result of the tool use is an image file, you can use it in the next action like this:
|
||||
|
||||
<tool_use>
|
||||
<name>image_transformer</name>
|
||||
<arguments>{"image": "image_1.jpg"}</arguments>
|
||||
</tool_use>
|
||||
|
||||
Always adhere to this format for the tool use to ensure proper parsing and execution.
|
||||
|
||||
## Tool Use Examples
|
||||
{{ TOOL_USE_EXAMPLES }}
|
||||
|
||||
## Tool Use Available Tools
|
||||
Above example were using notional tools that might not exist for you. You only have access to these tools:
|
||||
{{ AVAILABLE_TOOLS }}
|
||||
|
||||
## Tool Use Rules
|
||||
Here are the rules you should always follow to solve your task:
|
||||
1. Always use the right arguments for the tools. Never use variable names as the action arguments, use the value instead.
|
||||
2. Call a tool only when needed: do not call the search agent if you do not need information, try to solve the task yourself.
|
||||
3. If no tool call is needed, just answer the question directly.
|
||||
4. Never re-do a tool call that you previously did with the exact same parameters.
|
||||
5. For tool use, MAKE SURE use XML tag format as shown in the examples above. Do not use any other format.
|
||||
|
||||
# User Instructions
|
||||
{{ USER_SYSTEM_PROMPT }}
|
||||
|
||||
Now Begin! If you solve the task correctly, you will receive a reward of $1,000,000.`
|
||||
|
||||
/**
|
||||
* 默认工具使用示例(提取自 Cherry Studio)
|
||||
*/
|
||||
const DEFAULT_TOOL_USE_EXAMPLES = `
|
||||
Here are a few examples using notional tools:
|
||||
---
|
||||
User: Generate an image of the oldest person in this document.
|
||||
|
||||
A: I can use the document_qa tool to find out who the oldest person is in the document.
|
||||
<tool_use>
|
||||
<name>document_qa</name>
|
||||
<arguments>{"document": "document.pdf", "question": "Who is the oldest person mentioned?"}</arguments>
|
||||
</tool_use>
|
||||
|
||||
User: <tool_use_result>
|
||||
<name>document_qa</name>
|
||||
<result>John Doe, a 55 year old lumberjack living in Newfoundland.</result>
|
||||
</tool_use_result>
|
||||
|
||||
A: I can use the image_generator tool to create a portrait of John Doe.
|
||||
<tool_use>
|
||||
<name>image_generator</name>
|
||||
<arguments>{"prompt": "A portrait of John Doe, a 55-year-old man living in Canada."}</arguments>
|
||||
</tool_use>
|
||||
|
||||
User: <tool_use_result>
|
||||
<name>image_generator</name>
|
||||
<result>image.png</result>
|
||||
</tool_use_result>
|
||||
|
||||
A: the image is generated as image.png
|
||||
|
||||
---
|
||||
User: "What is the result of the following operation: 5 + 3 + 1294.678?"
|
||||
|
||||
A: I can use the python_interpreter tool to calculate the result of the operation.
|
||||
<tool_use>
|
||||
<name>python_interpreter</name>
|
||||
<arguments>{"code": "5 + 3 + 1294.678"}</arguments>
|
||||
</tool_use>
|
||||
|
||||
User: <tool_use_result>
|
||||
<name>python_interpreter</name>
|
||||
<result>1302.678</result>
|
||||
</tool_use_result>
|
||||
|
||||
A: The result of the operation is 1302.678.
|
||||
|
||||
---
|
||||
User: "Which city has the highest population , Guangzhou or Shanghai?"
|
||||
|
||||
A: I can use the search tool to find the population of Guangzhou.
|
||||
<tool_use>
|
||||
<name>search</name>
|
||||
<arguments>{"query": "Population Guangzhou"}</arguments>
|
||||
</tool_use>
|
||||
|
||||
User: <tool_use_result>
|
||||
<name>search</name>
|
||||
<result>Guangzhou has a population of 15 million inhabitants as of 2021.</result>
|
||||
</tool_use_result>
|
||||
|
||||
A: I can use the search tool to find the population of Shanghai.
|
||||
<tool_use>
|
||||
<name>search</name>
|
||||
<arguments>{"query": "Population Shanghai"}</arguments>
|
||||
</tool_use>
|
||||
|
||||
User: <tool_use_result>
|
||||
<name>search</name>
|
||||
<result>26 million (2019)</result>
|
||||
</tool_use_result>
|
||||
Assistant: The population of Shanghai is 26 million, while Guangzhou has a population of 15 million. Therefore, Shanghai has the highest population.`
|
||||
|
||||
/**
|
||||
* 构建可用工具部分(提取自 Cherry Studio)
|
||||
*/
|
||||
function buildAvailableTools(tools: ToolSet): string {
|
||||
const availableTools = Object.keys(tools)
|
||||
.map((toolName: string) => {
|
||||
const tool = tools[toolName]
|
||||
return `
|
||||
<tool>
|
||||
<name>${toolName}</name>
|
||||
<description>${tool.description || ''}</description>
|
||||
<arguments>
|
||||
${tool.inputSchema ? JSON.stringify(tool.inputSchema) : ''}
|
||||
</arguments>
|
||||
</tool>
|
||||
`
|
||||
})
|
||||
.join('\n')
|
||||
return `<tools>
|
||||
${availableTools}
|
||||
</tools>`
|
||||
}
|
||||
|
||||
/**
|
||||
* 默认的系统提示符构建函数(提取自 Cherry Studio)
|
||||
*/
|
||||
function defaultBuildSystemPrompt(userSystemPrompt: string, tools: ToolSet): string {
|
||||
const availableTools = buildAvailableTools(tools)
|
||||
|
||||
const fullPrompt = DEFAULT_SYSTEM_PROMPT.replace('{{ TOOL_USE_EXAMPLES }}', DEFAULT_TOOL_USE_EXAMPLES)
|
||||
.replace('{{ AVAILABLE_TOOLS }}', availableTools)
|
||||
.replace('{{ USER_SYSTEM_PROMPT }}', userSystemPrompt || '')
|
||||
|
||||
return fullPrompt
|
||||
}
|
||||
|
||||
/**
|
||||
* 默认工具解析函数(提取自 Cherry Studio)
|
||||
* 解析 XML 格式的工具调用
|
||||
*/
|
||||
function defaultParseToolUse(content: string, tools: ToolSet): { results: ToolUseResult[]; content: string } {
|
||||
if (!content || !tools || Object.keys(tools).length === 0) {
|
||||
return { results: [], content: content }
|
||||
}
|
||||
|
||||
// 支持两种格式:
|
||||
// 1. 完整的 <tool_use></tool_use> 标签包围的内容
|
||||
// 2. 只有内部内容(从 TagExtractor 提取出来的)
|
||||
|
||||
let contentToProcess = content
|
||||
// 如果内容不包含 <tool_use> 标签,说明是从 TagExtractor 提取的内部内容,需要包装
|
||||
if (!content.includes('<tool_use>')) {
|
||||
contentToProcess = `<tool_use>\n${content}\n</tool_use>`
|
||||
}
|
||||
|
||||
const toolUsePattern =
|
||||
/<tool_use>([\s\S]*?)<name>([\s\S]*?)<\/name>([\s\S]*?)<arguments>([\s\S]*?)<\/arguments>([\s\S]*?)<\/tool_use>/g
|
||||
const results: ToolUseResult[] = []
|
||||
let match
|
||||
let idx = 0
|
||||
|
||||
// Find all tool use blocks
|
||||
while ((match = toolUsePattern.exec(contentToProcess)) !== null) {
|
||||
const fullMatch = match[0]
|
||||
const toolName = match[2].trim()
|
||||
const toolArgs = match[4].trim()
|
||||
|
||||
// Try to parse the arguments as JSON
|
||||
let parsedArgs
|
||||
try {
|
||||
parsedArgs = JSON.parse(toolArgs)
|
||||
} catch (error) {
|
||||
// If parsing fails, use the string as is
|
||||
parsedArgs = toolArgs
|
||||
}
|
||||
|
||||
// Find the corresponding tool
|
||||
const tool = tools[toolName]
|
||||
if (!tool) {
|
||||
console.warn(`Tool "${toolName}" not found in available tools`)
|
||||
continue
|
||||
}
|
||||
|
||||
// Add to results array
|
||||
results.push({
|
||||
id: `${toolName}-${idx++}`, // Unique ID for each tool use
|
||||
toolName: toolName,
|
||||
arguments: parsedArgs,
|
||||
status: 'pending'
|
||||
})
|
||||
contentToProcess = contentToProcess.replace(fullMatch, '')
|
||||
}
|
||||
return { results, content: contentToProcess }
|
||||
}
|
||||
|
||||
export const createPromptToolUsePlugin = (config: PromptToolUseConfig = {}) => {
|
||||
const { enabled = true, buildSystemPrompt = defaultBuildSystemPrompt, parseToolUse = defaultParseToolUse } = config
|
||||
|
||||
return definePlugin({
|
||||
name: 'built-in:prompt-tool-use',
|
||||
transformParams: (params: any, context: AiRequestContext) => {
|
||||
if (!enabled || !params.tools || typeof params.tools !== 'object') {
|
||||
return params
|
||||
}
|
||||
|
||||
context.mcpTools = params.tools
|
||||
console.log('tools stored in context', params.tools)
|
||||
|
||||
// 构建系统提示符
|
||||
const userSystemPrompt = typeof params.system === 'string' ? params.system : ''
|
||||
const systemPrompt = buildSystemPrompt(userSystemPrompt, params.tools)
|
||||
let systemMessage: string | null = systemPrompt
|
||||
console.log('config.context', context)
|
||||
if (config.createSystemMessage) {
|
||||
// 🎯 如果用户提供了自定义处理函数,使用它
|
||||
systemMessage = config.createSystemMessage(systemPrompt, params, context)
|
||||
}
|
||||
|
||||
// 移除 tools,改为 prompt 模式
|
||||
const transformedParams = {
|
||||
...params,
|
||||
...(systemMessage ? { system: systemMessage } : {}),
|
||||
tools: undefined
|
||||
}
|
||||
context.originalParams = transformedParams
|
||||
console.log('transformedParams', transformedParams)
|
||||
return transformedParams
|
||||
},
|
||||
transformStream: (_: any, context: AiRequestContext) => () => {
|
||||
let textBuffer = ''
|
||||
let stepId = ''
|
||||
|
||||
if (!context.mcpTools) {
|
||||
throw new Error('No tools available')
|
||||
}
|
||||
|
||||
// 创建工具执行器和流事件管理器
|
||||
const toolExecutor = new ToolExecutor()
|
||||
const streamEventManager = new StreamEventManager()
|
||||
|
||||
type TOOLS = NonNullable<typeof context.mcpTools>
|
||||
return new TransformStream<TextStreamPart<TOOLS>, TextStreamPart<TOOLS>>({
|
||||
async transform(
|
||||
chunk: TextStreamPart<TOOLS>,
|
||||
controller: TransformStreamDefaultController<TextStreamPart<TOOLS>>
|
||||
) {
|
||||
// 收集文本内容
|
||||
if (chunk.type === 'text-delta') {
|
||||
textBuffer += chunk.text || ''
|
||||
stepId = chunk.id || ''
|
||||
controller.enqueue(chunk)
|
||||
return
|
||||
}
|
||||
|
||||
if (chunk.type === 'text-end' || chunk.type === 'finish-step') {
|
||||
const tools = context.mcpTools
|
||||
if (!tools || Object.keys(tools).length === 0) {
|
||||
controller.enqueue(chunk)
|
||||
return
|
||||
}
|
||||
|
||||
// 解析工具调用
|
||||
const { results: parsedTools, content: parsedContent } = parseToolUse(textBuffer, tools)
|
||||
const validToolUses = parsedTools.filter((t) => t.status === 'pending')
|
||||
|
||||
// 如果没有有效的工具调用,直接传递原始事件
|
||||
if (validToolUses.length === 0) {
|
||||
controller.enqueue(chunk)
|
||||
return
|
||||
}
|
||||
|
||||
if (chunk.type === 'text-end') {
|
||||
controller.enqueue({
|
||||
type: 'text-end',
|
||||
id: stepId,
|
||||
providerMetadata: {
|
||||
text: {
|
||||
value: parsedContent
|
||||
}
|
||||
}
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
controller.enqueue({
|
||||
...chunk,
|
||||
finishReason: 'tool-calls'
|
||||
})
|
||||
|
||||
// 发送步骤开始事件
|
||||
streamEventManager.sendStepStartEvent(controller)
|
||||
|
||||
// 执行工具调用
|
||||
const executedResults = await toolExecutor.executeTools(validToolUses, tools, controller)
|
||||
|
||||
// 发送步骤完成事件
|
||||
streamEventManager.sendStepFinishEvent(controller, chunk)
|
||||
|
||||
// 处理递归调用
|
||||
if (validToolUses.length > 0) {
|
||||
const toolResultsText = toolExecutor.formatToolResults(executedResults)
|
||||
const recursiveParams = streamEventManager.buildRecursiveParams(
|
||||
context,
|
||||
textBuffer,
|
||||
toolResultsText,
|
||||
tools
|
||||
)
|
||||
|
||||
await streamEventManager.handleRecursiveCall(controller, recursiveParams, context, stepId)
|
||||
}
|
||||
|
||||
// 清理状态
|
||||
textBuffer = ''
|
||||
return
|
||||
}
|
||||
|
||||
// 对于其他类型的事件,直接传递
|
||||
controller.enqueue(chunk)
|
||||
},
|
||||
|
||||
flush() {
|
||||
// 流结束时的清理工作
|
||||
console.log('[MCP Prompt] Stream ended, cleaning up...')
|
||||
}
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,196 @@
|
||||
// Copied from https://github.com/vercel/ai/blob/main/packages/ai/core/util/get-potential-start-index.ts
|
||||
|
||||
/**
|
||||
* Returns the index of the start of the searchedText in the text, or null if it
|
||||
* is not found.
|
||||
*/
|
||||
export function getPotentialStartIndex(text: string, searchedText: string): number | null {
|
||||
// Return null immediately if searchedText is empty.
|
||||
if (searchedText.length === 0) {
|
||||
return null
|
||||
}
|
||||
|
||||
// Check if the searchedText exists as a direct substring of text.
|
||||
const directIndex = text.indexOf(searchedText)
|
||||
if (directIndex !== -1) {
|
||||
return directIndex
|
||||
}
|
||||
|
||||
// Otherwise, look for the largest suffix of "text" that matches
|
||||
// a prefix of "searchedText". We go from the end of text inward.
|
||||
for (let i = text.length - 1; i >= 0; i--) {
|
||||
const suffix = text.substring(i)
|
||||
if (searchedText.startsWith(suffix)) {
|
||||
return i
|
||||
}
|
||||
}
|
||||
|
||||
return null
|
||||
}
|
||||
|
||||
export interface TagConfig {
|
||||
openingTag: string
|
||||
closingTag: string
|
||||
separator?: string
|
||||
}
|
||||
|
||||
export interface TagExtractionState {
|
||||
textBuffer: string
|
||||
isInsideTag: boolean
|
||||
isFirstTag: boolean
|
||||
isFirstText: boolean
|
||||
afterSwitch: boolean
|
||||
accumulatedTagContent: string
|
||||
hasTagContent: boolean
|
||||
}
|
||||
|
||||
export interface TagExtractionResult {
|
||||
content: string
|
||||
isTagContent: boolean
|
||||
complete: boolean
|
||||
tagContentExtracted?: string
|
||||
}
|
||||
|
||||
/**
|
||||
* 通用标签提取处理器
|
||||
* 可以处理各种形式的标签对,如 <think>...</think>, <tool_use>...</tool_use> 等
|
||||
*/
|
||||
export class TagExtractor {
|
||||
private config: TagConfig
|
||||
private state: TagExtractionState
|
||||
|
||||
constructor(config: TagConfig) {
|
||||
this.config = config
|
||||
this.state = {
|
||||
textBuffer: '',
|
||||
isInsideTag: false,
|
||||
isFirstTag: true,
|
||||
isFirstText: true,
|
||||
afterSwitch: false,
|
||||
accumulatedTagContent: '',
|
||||
hasTagContent: false
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 处理文本块,返回处理结果
|
||||
*/
|
||||
processText(newText: string): TagExtractionResult[] {
|
||||
this.state.textBuffer += newText
|
||||
const results: TagExtractionResult[] = []
|
||||
|
||||
// 处理标签提取逻辑
|
||||
while (true) {
|
||||
const nextTag = this.state.isInsideTag ? this.config.closingTag : this.config.openingTag
|
||||
const startIndex = getPotentialStartIndex(this.state.textBuffer, nextTag)
|
||||
|
||||
if (startIndex == null) {
|
||||
const content = this.state.textBuffer
|
||||
if (content.length > 0) {
|
||||
results.push({
|
||||
content: this.addPrefix(content),
|
||||
isTagContent: this.state.isInsideTag,
|
||||
complete: false
|
||||
})
|
||||
|
||||
if (this.state.isInsideTag) {
|
||||
this.state.accumulatedTagContent += this.addPrefix(content)
|
||||
this.state.hasTagContent = true
|
||||
}
|
||||
}
|
||||
this.state.textBuffer = ''
|
||||
break
|
||||
}
|
||||
|
||||
// 处理标签前的内容
|
||||
const contentBeforeTag = this.state.textBuffer.slice(0, startIndex)
|
||||
if (contentBeforeTag.length > 0) {
|
||||
results.push({
|
||||
content: this.addPrefix(contentBeforeTag),
|
||||
isTagContent: this.state.isInsideTag,
|
||||
complete: false
|
||||
})
|
||||
|
||||
if (this.state.isInsideTag) {
|
||||
this.state.accumulatedTagContent += this.addPrefix(contentBeforeTag)
|
||||
this.state.hasTagContent = true
|
||||
}
|
||||
}
|
||||
|
||||
const foundFullMatch = startIndex + nextTag.length <= this.state.textBuffer.length
|
||||
|
||||
if (foundFullMatch) {
|
||||
// 如果找到完整的标签
|
||||
this.state.textBuffer = this.state.textBuffer.slice(startIndex + nextTag.length)
|
||||
|
||||
// 如果刚刚结束一个标签内容,生成完整的标签内容结果
|
||||
if (this.state.isInsideTag && this.state.hasTagContent) {
|
||||
results.push({
|
||||
content: '',
|
||||
isTagContent: false,
|
||||
complete: true,
|
||||
tagContentExtracted: this.state.accumulatedTagContent
|
||||
})
|
||||
this.state.accumulatedTagContent = ''
|
||||
this.state.hasTagContent = false
|
||||
}
|
||||
|
||||
this.state.isInsideTag = !this.state.isInsideTag
|
||||
this.state.afterSwitch = true
|
||||
|
||||
if (this.state.isInsideTag) {
|
||||
this.state.isFirstTag = false
|
||||
} else {
|
||||
this.state.isFirstText = false
|
||||
}
|
||||
} else {
|
||||
this.state.textBuffer = this.state.textBuffer.slice(startIndex)
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return results
|
||||
}
|
||||
|
||||
/**
|
||||
* 完成处理,返回任何剩余的标签内容
|
||||
*/
|
||||
finalize(): TagExtractionResult | null {
|
||||
if (this.state.hasTagContent && this.state.accumulatedTagContent) {
|
||||
const result = {
|
||||
content: '',
|
||||
isTagContent: false,
|
||||
complete: true,
|
||||
tagContentExtracted: this.state.accumulatedTagContent
|
||||
}
|
||||
this.state.accumulatedTagContent = ''
|
||||
this.state.hasTagContent = false
|
||||
return result
|
||||
}
|
||||
return null
|
||||
}
|
||||
|
||||
private addPrefix(text: string): string {
|
||||
const needsPrefix =
|
||||
this.state.afterSwitch && (this.state.isInsideTag ? !this.state.isFirstTag : !this.state.isFirstText)
|
||||
|
||||
const prefix = needsPrefix && this.config.separator ? this.config.separator : ''
|
||||
this.state.afterSwitch = false
|
||||
return prefix + text
|
||||
}
|
||||
|
||||
/**
|
||||
* 重置状态
|
||||
*/
|
||||
reset(): void {
|
||||
this.state = {
|
||||
textBuffer: '',
|
||||
isInsideTag: false,
|
||||
isFirstTag: true,
|
||||
isFirstText: true,
|
||||
afterSwitch: false,
|
||||
accumulatedTagContent: '',
|
||||
hasTagContent: false
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,33 @@
|
||||
import { ToolSet } from 'ai'
|
||||
|
||||
import { AiRequestContext } from '../..'
|
||||
|
||||
/**
|
||||
* 解析结果类型
|
||||
* 表示从AI响应中解析出的工具使用意图
|
||||
*/
|
||||
export interface ToolUseResult {
|
||||
id: string
|
||||
toolName: string
|
||||
arguments: any
|
||||
status: 'pending' | 'invoking' | 'done' | 'error'
|
||||
}
|
||||
|
||||
export interface BaseToolUsePluginConfig {
|
||||
enabled?: boolean
|
||||
}
|
||||
|
||||
export interface PromptToolUseConfig extends BaseToolUsePluginConfig {
|
||||
// 自定义系统提示符构建函数(可选,有默认实现)
|
||||
buildSystemPrompt?: (userSystemPrompt: string, tools: ToolSet) => string
|
||||
// 自定义工具解析函数(可选,有默认实现)
|
||||
parseToolUse?: (content: string, tools: ToolSet) => { results: ToolUseResult[]; content: string }
|
||||
createSystemMessage?: (systemPrompt: string, originalParams: any, context: AiRequestContext) => string | null
|
||||
}
|
||||
|
||||
/**
|
||||
* 扩展的 AI 请求上下文,支持 MCP 工具存储
|
||||
*/
|
||||
export interface ToolUseRequestContext extends AiRequestContext {
|
||||
mcpTools: ToolSet
|
||||
}
|
||||
@@ -0,0 +1,67 @@
|
||||
import { anthropic } from '@ai-sdk/anthropic'
|
||||
import { google } from '@ai-sdk/google'
|
||||
import { openai } from '@ai-sdk/openai'
|
||||
|
||||
import { ProviderOptionsMap } from '../../../options/types'
|
||||
|
||||
/**
|
||||
* 从 AI SDK 的工具函数中提取参数类型,以确保类型安全。
|
||||
*/
|
||||
type OpenAISearchConfig = Parameters<typeof openai.tools.webSearchPreview>[0]
|
||||
type AnthropicSearchConfig = Parameters<typeof anthropic.tools.webSearch_20250305>[0]
|
||||
type GoogleSearchConfig = Parameters<typeof google.tools.googleSearch>[0]
|
||||
|
||||
/**
|
||||
* 插件初始化时接收的完整配置对象
|
||||
*
|
||||
* 其结构与 ProviderOptions 保持一致,方便上游统一管理配置
|
||||
*/
|
||||
export interface WebSearchPluginConfig {
|
||||
openai?: OpenAISearchConfig
|
||||
anthropic?: AnthropicSearchConfig
|
||||
xai?: ProviderOptionsMap['xai']['searchParameters']
|
||||
google?: GoogleSearchConfig
|
||||
'google-vertex'?: GoogleSearchConfig
|
||||
}
|
||||
|
||||
/**
|
||||
* 插件的默认配置
|
||||
*/
|
||||
export const DEFAULT_WEB_SEARCH_CONFIG: WebSearchPluginConfig = {
|
||||
google: {},
|
||||
'google-vertex': {},
|
||||
openai: {},
|
||||
xai: {
|
||||
mode: 'on',
|
||||
returnCitations: true,
|
||||
maxSearchResults: 5,
|
||||
sources: [{ type: 'web' }, { type: 'x' }, { type: 'news' }]
|
||||
},
|
||||
anthropic: {
|
||||
maxUses: 5
|
||||
}
|
||||
}
|
||||
|
||||
export type WebSearchToolOutputSchema = {
|
||||
// Anthropic 工具 - 手动定义
|
||||
anthropicWebSearch: Array<{
|
||||
url: string
|
||||
title: string
|
||||
pageAge: string | null
|
||||
encryptedContent: string
|
||||
type: string
|
||||
}>
|
||||
|
||||
// OpenAI 工具 - 基于实际输出
|
||||
openaiWebSearch: {
|
||||
status: 'completed' | 'failed'
|
||||
}
|
||||
|
||||
// Google 工具
|
||||
googleSearch: {
|
||||
webSearchQueries?: string[]
|
||||
groundingChunks?: Array<{
|
||||
web?: { uri: string; title: string }
|
||||
}>
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,69 @@
|
||||
/**
|
||||
* Web Search Plugin
|
||||
* 提供统一的网络搜索能力,支持多个 AI Provider
|
||||
*/
|
||||
import { anthropic } from '@ai-sdk/anthropic'
|
||||
import { google } from '@ai-sdk/google'
|
||||
import { openai } from '@ai-sdk/openai'
|
||||
|
||||
import { createXaiOptions, mergeProviderOptions } from '../../../options'
|
||||
import { definePlugin } from '../../'
|
||||
import type { AiRequestContext } from '../../types'
|
||||
import { DEFAULT_WEB_SEARCH_CONFIG, WebSearchPluginConfig } from './helper'
|
||||
|
||||
/**
|
||||
* 网络搜索插件
|
||||
*
|
||||
* @param config - 在插件初始化时传入的静态配置
|
||||
*/
|
||||
export const webSearchPlugin = (config: WebSearchPluginConfig = DEFAULT_WEB_SEARCH_CONFIG) =>
|
||||
definePlugin({
|
||||
name: 'webSearch',
|
||||
enforce: 'pre',
|
||||
|
||||
transformParams: async (params: any, context: AiRequestContext) => {
|
||||
const { providerId } = context
|
||||
switch (providerId) {
|
||||
case 'openai': {
|
||||
if (config.openai) {
|
||||
if (!params.tools) params.tools = {}
|
||||
params.tools.web_search_preview = openai.tools.webSearchPreview(config.openai)
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
case 'anthropic': {
|
||||
if (config.anthropic) {
|
||||
if (!params.tools) params.tools = {}
|
||||
params.tools.web_search = anthropic.tools.webSearch_20250305(config.anthropic)
|
||||
}
|
||||
break
|
||||
}
|
||||
|
||||
case 'google': {
|
||||
// case 'google-vertex':
|
||||
if (!params.tools) params.tools = {}
|
||||
params.tools.web_search = google.tools.googleSearch(config.google || {})
|
||||
break
|
||||
}
|
||||
|
||||
case 'xai': {
|
||||
if (config.xai) {
|
||||
const searchOptions = createXaiOptions({
|
||||
searchParameters: { ...config.xai, mode: 'on' }
|
||||
})
|
||||
params.providerOptions = mergeProviderOptions(params.providerOptions, searchOptions)
|
||||
}
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
return params
|
||||
}
|
||||
})
|
||||
|
||||
// 导出类型定义供开发者使用
|
||||
export type { WebSearchPluginConfig, WebSearchToolOutputSchema } from './helper'
|
||||
|
||||
// 默认导出
|
||||
export default webSearchPlugin
|
||||
32
packages/aiCore/src/core/plugins/index.ts
Normal file
32
packages/aiCore/src/core/plugins/index.ts
Normal file
@@ -0,0 +1,32 @@
|
||||
// 核心类型和接口
|
||||
export type { AiPlugin, AiRequestContext, HookResult, PluginManagerConfig } from './types'
|
||||
import type { ProviderId } from '../providers'
|
||||
import type { AiPlugin, AiRequestContext } from './types'
|
||||
|
||||
// 插件管理器
|
||||
export { PluginManager } from './manager'
|
||||
|
||||
// 工具函数
|
||||
export function createContext<T extends ProviderId>(
|
||||
providerId: T,
|
||||
modelId: string,
|
||||
originalParams: any
|
||||
): AiRequestContext {
|
||||
return {
|
||||
providerId,
|
||||
modelId,
|
||||
originalParams,
|
||||
metadata: {},
|
||||
startTime: Date.now(),
|
||||
requestId: `${providerId}-${modelId}-${Date.now()}-${Math.random().toString(36).slice(2)}`,
|
||||
// 占位
|
||||
recursiveCall: () => Promise.resolve(null)
|
||||
}
|
||||
}
|
||||
|
||||
// 插件构建器 - 便于创建插件
|
||||
export function definePlugin(plugin: AiPlugin): AiPlugin
|
||||
export function definePlugin<T extends (...args: any[]) => AiPlugin>(pluginFactory: T): T
|
||||
export function definePlugin(plugin: AiPlugin | ((...args: any[]) => AiPlugin)) {
|
||||
return plugin
|
||||
}
|
||||
184
packages/aiCore/src/core/plugins/manager.ts
Normal file
184
packages/aiCore/src/core/plugins/manager.ts
Normal file
@@ -0,0 +1,184 @@
|
||||
import { AiPlugin, AiRequestContext } from './types'
|
||||
|
||||
/**
|
||||
* 插件管理器
|
||||
*/
|
||||
export class PluginManager {
|
||||
private plugins: AiPlugin[] = []
|
||||
|
||||
constructor(plugins: AiPlugin[] = []) {
|
||||
this.plugins = this.sortPlugins(plugins)
|
||||
}
|
||||
|
||||
/**
|
||||
* 添加插件
|
||||
*/
|
||||
use(plugin: AiPlugin): this {
|
||||
this.plugins = this.sortPlugins([...this.plugins, plugin])
|
||||
return this
|
||||
}
|
||||
|
||||
/**
|
||||
* 移除插件
|
||||
*/
|
||||
remove(pluginName: string): this {
|
||||
this.plugins = this.plugins.filter((p) => p.name !== pluginName)
|
||||
return this
|
||||
}
|
||||
|
||||
/**
|
||||
* 插件排序:pre -> normal -> post
|
||||
*/
|
||||
private sortPlugins(plugins: AiPlugin[]): AiPlugin[] {
|
||||
const pre: AiPlugin[] = []
|
||||
const normal: AiPlugin[] = []
|
||||
const post: AiPlugin[] = []
|
||||
|
||||
plugins.forEach((plugin) => {
|
||||
if (plugin.enforce === 'pre') {
|
||||
pre.push(plugin)
|
||||
} else if (plugin.enforce === 'post') {
|
||||
post.push(plugin)
|
||||
} else {
|
||||
normal.push(plugin)
|
||||
}
|
||||
})
|
||||
|
||||
return [...pre, ...normal, ...post]
|
||||
}
|
||||
|
||||
/**
|
||||
* 执行 First 钩子 - 返回第一个有效结果
|
||||
*/
|
||||
async executeFirst<T>(
|
||||
hookName: 'resolveModel' | 'loadTemplate',
|
||||
arg: any,
|
||||
context: AiRequestContext
|
||||
): Promise<T | null> {
|
||||
for (const plugin of this.plugins) {
|
||||
const hook = plugin[hookName]
|
||||
if (hook) {
|
||||
const result = await hook(arg, context)
|
||||
if (result !== null && result !== undefined) {
|
||||
return result as T
|
||||
}
|
||||
}
|
||||
}
|
||||
return null
|
||||
}
|
||||
|
||||
/**
|
||||
* 执行 Sequential 钩子 - 链式数据转换
|
||||
*/
|
||||
async executeSequential<T>(
|
||||
hookName: 'transformParams' | 'transformResult',
|
||||
initialValue: T,
|
||||
context: AiRequestContext
|
||||
): Promise<T> {
|
||||
let result = initialValue
|
||||
|
||||
for (const plugin of this.plugins) {
|
||||
const hook = plugin[hookName]
|
||||
if (hook) {
|
||||
result = await hook<T>(result, context)
|
||||
}
|
||||
}
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
/**
|
||||
* 执行 ConfigureContext 钩子 - 串行配置上下文
|
||||
*/
|
||||
async executeConfigureContext(context: AiRequestContext): Promise<void> {
|
||||
for (const plugin of this.plugins) {
|
||||
const hook = plugin.configureContext
|
||||
if (hook) {
|
||||
await hook(context)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 执行 Parallel 钩子 - 并行副作用
|
||||
*/
|
||||
async executeParallel(
|
||||
hookName: 'onRequestStart' | 'onRequestEnd' | 'onError',
|
||||
context: AiRequestContext,
|
||||
result?: any,
|
||||
error?: Error
|
||||
): Promise<void> {
|
||||
const promises = this.plugins
|
||||
.map((plugin) => {
|
||||
const hook = plugin[hookName]
|
||||
if (!hook) return null
|
||||
|
||||
if (hookName === 'onError' && error) {
|
||||
return (hook as any)(error, context)
|
||||
} else if (hookName === 'onRequestEnd' && result !== undefined) {
|
||||
return (hook as any)(context, result)
|
||||
} else if (hookName === 'onRequestStart') {
|
||||
return (hook as any)(context)
|
||||
}
|
||||
return null
|
||||
})
|
||||
.filter(Boolean)
|
||||
|
||||
// 使用 Promise.all 而不是 allSettled,让插件错误能够抛出
|
||||
await Promise.all(promises)
|
||||
}
|
||||
|
||||
/**
|
||||
* 收集所有流转换器(返回数组,AI SDK 原生支持)
|
||||
*/
|
||||
collectStreamTransforms(params: any, context: AiRequestContext) {
|
||||
return this.plugins
|
||||
.filter((plugin) => plugin.transformStream)
|
||||
.map((plugin) => plugin.transformStream?.(params, context))
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取所有插件信息
|
||||
*/
|
||||
getPlugins(): AiPlugin[] {
|
||||
return [...this.plugins]
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取插件统计信息
|
||||
*/
|
||||
getStats() {
|
||||
const stats = {
|
||||
total: this.plugins.length,
|
||||
pre: 0,
|
||||
normal: 0,
|
||||
post: 0,
|
||||
hooks: {
|
||||
resolveModel: 0,
|
||||
loadTemplate: 0,
|
||||
transformParams: 0,
|
||||
transformResult: 0,
|
||||
onRequestStart: 0,
|
||||
onRequestEnd: 0,
|
||||
onError: 0,
|
||||
transformStream: 0
|
||||
}
|
||||
}
|
||||
|
||||
this.plugins.forEach((plugin) => {
|
||||
// 统计 enforce 类型
|
||||
if (plugin.enforce === 'pre') stats.pre++
|
||||
else if (plugin.enforce === 'post') stats.post++
|
||||
else stats.normal++
|
||||
|
||||
// 统计钩子数量
|
||||
Object.keys(stats.hooks).forEach((hookName) => {
|
||||
if (plugin[hookName as keyof AiPlugin]) {
|
||||
stats.hooks[hookName as keyof typeof stats.hooks]++
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
return stats
|
||||
}
|
||||
}
|
||||
79
packages/aiCore/src/core/plugins/types.ts
Normal file
79
packages/aiCore/src/core/plugins/types.ts
Normal file
@@ -0,0 +1,79 @@
|
||||
import type { ImageModelV2 } from '@ai-sdk/provider'
|
||||
import type { LanguageModel, TextStreamPart, ToolSet } from 'ai'
|
||||
|
||||
import { type ProviderId } from '../providers/types'
|
||||
|
||||
/**
|
||||
* 递归调用函数类型
|
||||
* 使用 any 是因为递归调用时参数和返回类型可能完全不同
|
||||
*/
|
||||
export type RecursiveCallFn = (newParams: any) => Promise<any>
|
||||
|
||||
/**
|
||||
* AI 请求上下文
|
||||
*/
|
||||
export interface AiRequestContext {
|
||||
providerId: ProviderId
|
||||
modelId: string
|
||||
originalParams: any
|
||||
metadata: Record<string, any>
|
||||
startTime: number
|
||||
requestId: string
|
||||
recursiveCall: RecursiveCallFn
|
||||
isRecursiveCall?: boolean
|
||||
mcpTools?: ToolSet
|
||||
[key: string]: any
|
||||
}
|
||||
|
||||
/**
|
||||
* 钩子分类
|
||||
*/
|
||||
export interface AiPlugin {
|
||||
name: string
|
||||
enforce?: 'pre' | 'post'
|
||||
|
||||
// 【First】首个钩子 - 只执行第一个返回值的插件
|
||||
resolveModel?: (
|
||||
modelId: string,
|
||||
context: AiRequestContext
|
||||
) => Promise<LanguageModel | ImageModelV2 | null> | LanguageModel | ImageModelV2 | null
|
||||
loadTemplate?: (templateName: string, context: AiRequestContext) => any | null | Promise<any | null>
|
||||
|
||||
// 【Sequential】串行钩子 - 链式执行,支持数据转换
|
||||
configureContext?: (context: AiRequestContext) => void | Promise<void>
|
||||
transformParams?: <T>(params: T, context: AiRequestContext) => T | Promise<T>
|
||||
transformResult?: <T>(result: T, context: AiRequestContext) => T | Promise<T>
|
||||
|
||||
// 【Parallel】并行钩子 - 不依赖顺序,用于副作用
|
||||
onRequestStart?: (context: AiRequestContext) => void | Promise<void>
|
||||
onRequestEnd?: (context: AiRequestContext, result: any) => void | Promise<void>
|
||||
onError?: (error: Error, context: AiRequestContext) => void | Promise<void>
|
||||
|
||||
// 【Stream】流处理 - 直接使用 AI SDK
|
||||
transformStream?: (
|
||||
params: any,
|
||||
context: AiRequestContext
|
||||
) => <TOOLS extends ToolSet>(options?: {
|
||||
tools: TOOLS
|
||||
stopStream: () => void
|
||||
}) => TransformStream<TextStreamPart<TOOLS>, TextStreamPart<TOOLS>>
|
||||
|
||||
// AI SDK 原生中间件
|
||||
// aiSdkMiddlewares?: LanguageModelV1Middleware[]
|
||||
}
|
||||
|
||||
/**
|
||||
* 插件管理器配置
|
||||
*/
|
||||
export interface PluginManagerConfig {
|
||||
plugins: AiPlugin[]
|
||||
context: Partial<AiRequestContext>
|
||||
}
|
||||
|
||||
/**
|
||||
* 钩子执行结果
|
||||
*/
|
||||
export interface HookResult<T = any> {
|
||||
value: T
|
||||
stop?: boolean
|
||||
}
|
||||
96
packages/aiCore/src/core/providers/HubProvider.ts
Normal file
96
packages/aiCore/src/core/providers/HubProvider.ts
Normal file
@@ -0,0 +1,96 @@
|
||||
/**
|
||||
* Hub Provider - 支持路由到多个底层provider
|
||||
*
|
||||
* 支持格式: hubId:providerId:modelId
|
||||
* 例如: aihubmix:anthropic:claude-3.5-sonnet
|
||||
*/
|
||||
|
||||
import { EmbeddingModelV2, ImageModelV2, ProviderV2, SpeechModelV2, TranscriptionModelV2 } from '@ai-sdk/provider'
|
||||
import { customProvider } from 'ai'
|
||||
|
||||
import { globalRegistryManagement } from './RegistryManagement'
|
||||
|
||||
export interface HubProviderConfig {
|
||||
/** Hub的唯一标识符 */
|
||||
hubId: string
|
||||
/** 是否启用调试日志 */
|
||||
debug?: boolean
|
||||
}
|
||||
|
||||
export class HubProviderError extends Error {
|
||||
constructor(
|
||||
message: string,
|
||||
public readonly hubId: string,
|
||||
public readonly providerId?: string,
|
||||
public readonly originalError?: Error
|
||||
) {
|
||||
super(message)
|
||||
this.name = 'HubProviderError'
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 解析Hub模型ID
|
||||
*/
|
||||
function parseHubModelId(modelId: string): { provider: string; actualModelId: string } {
|
||||
const parts = modelId.split(':')
|
||||
if (parts.length !== 2) {
|
||||
throw new HubProviderError(`Invalid hub model ID format. Expected "provider:modelId", got: ${modelId}`, 'unknown')
|
||||
}
|
||||
return {
|
||||
provider: parts[0],
|
||||
actualModelId: parts[1]
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 创建Hub Provider
|
||||
*/
|
||||
export function createHubProvider(config: HubProviderConfig): ProviderV2 {
|
||||
const { hubId } = config
|
||||
|
||||
function getTargetProvider(providerId: string): ProviderV2 {
|
||||
// 从全局注册表获取provider实例
|
||||
try {
|
||||
const provider = (globalRegistryManagement as any).getProvider(providerId)
|
||||
if (!provider) {
|
||||
throw new HubProviderError(
|
||||
`Provider "${providerId}" is not initialized. Please call initializeProvider("${providerId}", options) first.`,
|
||||
hubId,
|
||||
providerId
|
||||
)
|
||||
}
|
||||
return provider
|
||||
} catch (error) {
|
||||
throw new HubProviderError(
|
||||
`Failed to get provider "${providerId}": ${error instanceof Error ? error.message : 'Unknown error'}`,
|
||||
hubId,
|
||||
providerId,
|
||||
error instanceof Error ? error : undefined
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
function resolveModel<T>(modelId: string, modelType: string, methodName: keyof ProviderV2): T {
|
||||
const { provider, actualModelId } = parseHubModelId(modelId)
|
||||
const targetProvider = getTargetProvider(provider)
|
||||
|
||||
if (!targetProvider[methodName]) {
|
||||
throw new HubProviderError(`Provider "${provider}" does not support ${modelType}`, hubId, provider)
|
||||
}
|
||||
|
||||
return (targetProvider[methodName] as any)(actualModelId)
|
||||
}
|
||||
|
||||
return customProvider({
|
||||
fallbackProvider: {
|
||||
languageModel: (modelId: string) => resolveModel(modelId, 'language models', 'languageModel'),
|
||||
textEmbeddingModel: (modelId: string) =>
|
||||
resolveModel<EmbeddingModelV2<string>>(modelId, 'text embedding models', 'textEmbeddingModel'),
|
||||
imageModel: (modelId: string) => resolveModel<ImageModelV2>(modelId, 'image models', 'imageModel'),
|
||||
transcriptionModel: (modelId: string) =>
|
||||
resolveModel<TranscriptionModelV2>(modelId, 'transcription models', 'transcriptionModel'),
|
||||
speechModel: (modelId: string) => resolveModel<SpeechModelV2>(modelId, 'speech models', 'speechModel')
|
||||
}
|
||||
})
|
||||
}
|
||||
220
packages/aiCore/src/core/providers/RegistryManagement.ts
Normal file
220
packages/aiCore/src/core/providers/RegistryManagement.ts
Normal file
@@ -0,0 +1,220 @@
|
||||
/**
|
||||
* Provider 注册表管理器
|
||||
* 纯粹的管理功能:存储、检索已配置好的 provider 实例
|
||||
* 基于 AI SDK 原生的 createProviderRegistry
|
||||
*/
|
||||
|
||||
import { EmbeddingModelV2, ImageModelV2, LanguageModelV2, ProviderV2 } from '@ai-sdk/provider'
|
||||
import { createProviderRegistry, type ProviderRegistryProvider } from 'ai'
|
||||
|
||||
type PROVIDERS = Record<string, ProviderV2>
|
||||
|
||||
export const DEFAULT_SEPARATOR = ':'
|
||||
|
||||
// export type MODEL_ID = `${string}${typeof DEFAULT_SEPARATOR}${string}`
|
||||
|
||||
export class RegistryManagement<SEPARATOR extends string = typeof DEFAULT_SEPARATOR> {
|
||||
private providers: PROVIDERS = {}
|
||||
private aliases: Set<string> = new Set() // 记录哪些key是别名
|
||||
private separator: SEPARATOR
|
||||
private registry: ProviderRegistryProvider<PROVIDERS, SEPARATOR> | null = null
|
||||
|
||||
constructor(options: { separator: SEPARATOR } = { separator: DEFAULT_SEPARATOR as SEPARATOR }) {
|
||||
this.separator = options.separator
|
||||
}
|
||||
|
||||
/**
|
||||
* 注册已配置好的 provider 实例
|
||||
*/
|
||||
registerProvider(id: string, provider: ProviderV2, aliases?: string[]): this {
|
||||
// 注册主provider
|
||||
this.providers[id] = provider
|
||||
|
||||
// 注册别名(都指向同一个provider实例)
|
||||
if (aliases) {
|
||||
aliases.forEach((alias) => {
|
||||
this.providers[alias] = provider // 直接存储引用
|
||||
this.aliases.add(alias) // 标记为别名
|
||||
})
|
||||
}
|
||||
|
||||
this.rebuildRegistry()
|
||||
return this
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取已注册的provider实例
|
||||
*/
|
||||
getProvider(id: string): ProviderV2 | undefined {
|
||||
return this.providers[id]
|
||||
}
|
||||
|
||||
/**
|
||||
* 批量注册 providers
|
||||
*/
|
||||
registerProviders(providers: Record<string, ProviderV2>): this {
|
||||
Object.assign(this.providers, providers)
|
||||
this.rebuildRegistry()
|
||||
return this
|
||||
}
|
||||
|
||||
/**
|
||||
* 移除 provider(同时清理相关别名)
|
||||
*/
|
||||
unregisterProvider(id: string): this {
|
||||
const provider = this.providers[id]
|
||||
if (!provider) return this
|
||||
|
||||
// 如果移除的是真实ID,需要清理所有指向它的别名
|
||||
if (!this.aliases.has(id)) {
|
||||
// 找到所有指向此provider的别名并删除
|
||||
const aliasesToRemove: string[] = []
|
||||
this.aliases.forEach((alias) => {
|
||||
if (this.providers[alias] === provider) {
|
||||
aliasesToRemove.push(alias)
|
||||
}
|
||||
})
|
||||
|
||||
aliasesToRemove.forEach((alias) => {
|
||||
delete this.providers[alias]
|
||||
this.aliases.delete(alias)
|
||||
})
|
||||
} else {
|
||||
// 如果移除的是别名,只删除别名记录
|
||||
this.aliases.delete(id)
|
||||
}
|
||||
|
||||
delete this.providers[id]
|
||||
this.rebuildRegistry()
|
||||
return this
|
||||
}
|
||||
|
||||
/**
|
||||
* 立即重建 registry - 每次变更都重建
|
||||
*/
|
||||
private rebuildRegistry(): void {
|
||||
if (Object.keys(this.providers).length === 0) {
|
||||
this.registry = null
|
||||
return
|
||||
}
|
||||
|
||||
this.registry = createProviderRegistry<PROVIDERS, SEPARATOR>(this.providers, {
|
||||
separator: this.separator
|
||||
})
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取语言模型 - AI SDK 原生方法
|
||||
*/
|
||||
languageModel(id: `${string}${SEPARATOR}${string}`): LanguageModelV2 {
|
||||
if (!this.registry) {
|
||||
throw new Error('No providers registered')
|
||||
}
|
||||
return this.registry.languageModel(id)
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取文本嵌入模型 - AI SDK 原生方法
|
||||
*/
|
||||
textEmbeddingModel(id: `${string}${SEPARATOR}${string}`): EmbeddingModelV2<string> {
|
||||
if (!this.registry) {
|
||||
throw new Error('No providers registered')
|
||||
}
|
||||
return this.registry.textEmbeddingModel(id)
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取图像模型 - AI SDK 原生方法
|
||||
*/
|
||||
imageModel(id: `${string}${SEPARATOR}${string}`): ImageModelV2 {
|
||||
if (!this.registry) {
|
||||
throw new Error('No providers registered')
|
||||
}
|
||||
return this.registry.imageModel(id)
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取转录模型 - AI SDK 原生方法
|
||||
*/
|
||||
transcriptionModel(id: `${string}${SEPARATOR}${string}`): any {
|
||||
if (!this.registry) {
|
||||
throw new Error('No providers registered')
|
||||
}
|
||||
return this.registry.transcriptionModel(id)
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取语音模型 - AI SDK 原生方法
|
||||
*/
|
||||
speechModel(id: `${string}${SEPARATOR}${string}`): any {
|
||||
if (!this.registry) {
|
||||
throw new Error('No providers registered')
|
||||
}
|
||||
return this.registry.speechModel(id)
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取已注册的 provider 列表
|
||||
*/
|
||||
getRegisteredProviders(): string[] {
|
||||
return Object.keys(this.providers)
|
||||
}
|
||||
|
||||
/**
|
||||
* 检查是否有已注册的 providers
|
||||
*/
|
||||
hasProviders(): boolean {
|
||||
return Object.keys(this.providers).length > 0
|
||||
}
|
||||
|
||||
/**
|
||||
* 清除所有 providers
|
||||
*/
|
||||
clear(): this {
|
||||
this.providers = {}
|
||||
this.aliases.clear()
|
||||
this.registry = null
|
||||
return this
|
||||
}
|
||||
|
||||
/**
|
||||
* 解析真实的Provider ID(供getAiSdkProviderId使用)
|
||||
* 如果传入的是别名,返回真实的Provider ID
|
||||
* 如果传入的是真实ID,直接返回
|
||||
*/
|
||||
resolveProviderId(id: string): string {
|
||||
if (!this.aliases.has(id)) return id // 不是别名,直接返回
|
||||
|
||||
// 是别名,找到真实ID
|
||||
const targetProvider = this.providers[id]
|
||||
for (const [realId, provider] of Object.entries(this.providers)) {
|
||||
if (provider === targetProvider && !this.aliases.has(realId)) {
|
||||
return realId
|
||||
}
|
||||
}
|
||||
return id
|
||||
}
|
||||
|
||||
/**
|
||||
* 检查是否为别名
|
||||
*/
|
||||
isAlias(id: string): boolean {
|
||||
return this.aliases.has(id)
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取所有别名映射关系
|
||||
*/
|
||||
getAllAliases(): Record<string, string> {
|
||||
const result: Record<string, string> = {}
|
||||
this.aliases.forEach((alias) => {
|
||||
result[alias] = this.resolveProviderId(alias)
|
||||
})
|
||||
return result
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 全局注册表管理器实例
|
||||
*/
|
||||
export const globalRegistryManagement = new RegistryManagement<':'>()
|
||||
@@ -0,0 +1,632 @@
|
||||
/**
|
||||
* 测试真正的 AiProviderRegistry 功能
|
||||
*/
|
||||
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
// 模拟 AI SDK
|
||||
vi.mock('@ai-sdk/openai', () => ({
|
||||
createOpenAI: vi.fn(() => ({ name: 'openai-mock' }))
|
||||
}))
|
||||
|
||||
vi.mock('@ai-sdk/anthropic', () => ({
|
||||
createAnthropic: vi.fn(() => ({ name: 'anthropic-mock' }))
|
||||
}))
|
||||
|
||||
vi.mock('@ai-sdk/azure', () => ({
|
||||
createAzure: vi.fn(() => ({ name: 'azure-mock' }))
|
||||
}))
|
||||
|
||||
vi.mock('@ai-sdk/deepseek', () => ({
|
||||
createDeepSeek: vi.fn(() => ({ name: 'deepseek-mock' }))
|
||||
}))
|
||||
|
||||
vi.mock('@ai-sdk/google', () => ({
|
||||
createGoogleGenerativeAI: vi.fn(() => ({ name: 'google-mock' }))
|
||||
}))
|
||||
|
||||
vi.mock('@ai-sdk/openai-compatible', () => ({
|
||||
createOpenAICompatible: vi.fn(() => ({ name: 'openai-compatible-mock' }))
|
||||
}))
|
||||
|
||||
vi.mock('@ai-sdk/xai', () => ({
|
||||
createXai: vi.fn(() => ({ name: 'xai-mock' }))
|
||||
}))
|
||||
|
||||
import {
|
||||
cleanup,
|
||||
clearAllProviders,
|
||||
createAndRegisterProvider,
|
||||
createProvider,
|
||||
getAllProviderConfigAliases,
|
||||
getAllProviderConfigs,
|
||||
getInitializedProviders,
|
||||
getLanguageModel,
|
||||
getProviderConfig,
|
||||
getProviderConfigByAlias,
|
||||
getSupportedProviders,
|
||||
hasInitializedProviders,
|
||||
hasProviderConfig,
|
||||
hasProviderConfigByAlias,
|
||||
isProviderConfigAlias,
|
||||
ProviderInitializationError,
|
||||
providerRegistry,
|
||||
registerMultipleProviderConfigs,
|
||||
registerProvider,
|
||||
registerProviderConfig,
|
||||
resolveProviderConfigId
|
||||
} from '../registry'
|
||||
import type { ProviderConfig } from '../schemas'
|
||||
|
||||
describe('Provider Registry 功能测试', () => {
|
||||
beforeEach(() => {
|
||||
// 清理状态
|
||||
cleanup()
|
||||
})
|
||||
|
||||
describe('基础功能', () => {
|
||||
it('能够获取支持的 providers 列表', () => {
|
||||
const providers = getSupportedProviders()
|
||||
expect(Array.isArray(providers)).toBe(true)
|
||||
expect(providers.length).toBeGreaterThan(0)
|
||||
|
||||
// 检查返回的数据结构
|
||||
providers.forEach((provider) => {
|
||||
expect(provider).toHaveProperty('id')
|
||||
expect(provider).toHaveProperty('name')
|
||||
expect(typeof provider.id).toBe('string')
|
||||
expect(typeof provider.name).toBe('string')
|
||||
})
|
||||
|
||||
// 包含基础 providers
|
||||
const providerIds = providers.map((p) => p.id)
|
||||
expect(providerIds).toContain('openai')
|
||||
expect(providerIds).toContain('anthropic')
|
||||
expect(providerIds).toContain('google')
|
||||
})
|
||||
|
||||
it('能够获取已初始化的 providers', () => {
|
||||
// 初始状态下没有已初始化的 providers
|
||||
expect(getInitializedProviders()).toEqual([])
|
||||
expect(hasInitializedProviders()).toBe(false)
|
||||
})
|
||||
|
||||
it('能够访问全局注册管理器', () => {
|
||||
expect(providerRegistry).toBeDefined()
|
||||
expect(typeof providerRegistry.clear).toBe('function')
|
||||
expect(typeof providerRegistry.getRegisteredProviders).toBe('function')
|
||||
expect(typeof providerRegistry.hasProviders).toBe('function')
|
||||
})
|
||||
|
||||
it('能够获取语言模型', () => {
|
||||
// 在没有注册 provider 的情况下,这个函数可能会抛出错误或返回 undefined
|
||||
expect(() => getLanguageModel('non-existent')).not.toThrow()
|
||||
})
|
||||
})
|
||||
|
||||
describe('Provider 配置注册', () => {
|
||||
it('能够注册自定义 provider 配置', () => {
|
||||
const config: ProviderConfig = {
|
||||
id: 'custom-provider',
|
||||
name: 'Custom Provider',
|
||||
creator: vi.fn(() => ({ name: 'custom' })),
|
||||
supportsImageGeneration: false
|
||||
}
|
||||
|
||||
const success = registerProviderConfig(config)
|
||||
expect(success).toBe(true)
|
||||
|
||||
expect(hasProviderConfig('custom-provider')).toBe(true)
|
||||
expect(getProviderConfig('custom-provider')).toEqual(config)
|
||||
})
|
||||
|
||||
it('能够注册带别名的 provider 配置', () => {
|
||||
const config: ProviderConfig = {
|
||||
id: 'custom-provider-with-aliases',
|
||||
name: 'Custom Provider with Aliases',
|
||||
creator: vi.fn(() => ({ name: 'custom-aliased' })),
|
||||
supportsImageGeneration: false,
|
||||
aliases: ['alias-1', 'alias-2']
|
||||
}
|
||||
|
||||
const success = registerProviderConfig(config)
|
||||
expect(success).toBe(true)
|
||||
|
||||
expect(hasProviderConfigByAlias('alias-1')).toBe(true)
|
||||
expect(hasProviderConfigByAlias('alias-2')).toBe(true)
|
||||
expect(getProviderConfigByAlias('alias-1')).toEqual(config)
|
||||
expect(resolveProviderConfigId('alias-1')).toBe('custom-provider-with-aliases')
|
||||
})
|
||||
|
||||
it('拒绝无效的配置', () => {
|
||||
// 缺少必要字段
|
||||
const invalidConfig = {
|
||||
id: 'invalid-provider'
|
||||
// 缺少 name, creator 等
|
||||
}
|
||||
|
||||
const success = registerProviderConfig(invalidConfig as any)
|
||||
expect(success).toBe(false)
|
||||
})
|
||||
|
||||
it('能够批量注册 provider 配置', () => {
|
||||
const configs: ProviderConfig[] = [
|
||||
{
|
||||
id: 'provider-1',
|
||||
name: 'Provider 1',
|
||||
creator: vi.fn(() => ({ name: 'provider-1' })),
|
||||
supportsImageGeneration: false
|
||||
},
|
||||
{
|
||||
id: 'provider-2',
|
||||
name: 'Provider 2',
|
||||
creator: vi.fn(() => ({ name: 'provider-2' })),
|
||||
supportsImageGeneration: true
|
||||
},
|
||||
{
|
||||
id: '', // 无效配置
|
||||
name: 'Invalid Provider',
|
||||
creator: vi.fn(() => ({ name: 'invalid' })),
|
||||
supportsImageGeneration: false
|
||||
} as any
|
||||
]
|
||||
|
||||
const successCount = registerMultipleProviderConfigs(configs)
|
||||
expect(successCount).toBe(2) // 只有前两个成功
|
||||
|
||||
expect(hasProviderConfig('provider-1')).toBe(true)
|
||||
expect(hasProviderConfig('provider-2')).toBe(true)
|
||||
expect(hasProviderConfig('')).toBe(false)
|
||||
})
|
||||
|
||||
it('能够获取所有配置和别名信息', () => {
|
||||
// 注册一些配置
|
||||
registerProviderConfig({
|
||||
id: 'test-provider',
|
||||
name: 'Test Provider',
|
||||
creator: vi.fn(),
|
||||
supportsImageGeneration: false,
|
||||
aliases: ['test-alias']
|
||||
})
|
||||
|
||||
const allConfigs = getAllProviderConfigs()
|
||||
expect(Array.isArray(allConfigs)).toBe(true)
|
||||
expect(allConfigs.some((config) => config.id === 'test-provider')).toBe(true)
|
||||
|
||||
const aliases = getAllProviderConfigAliases()
|
||||
expect(aliases['test-alias']).toBe('test-provider')
|
||||
expect(isProviderConfigAlias('test-alias')).toBe(true)
|
||||
})
|
||||
})
|
||||
|
||||
describe('Provider 创建和注册', () => {
|
||||
it('能够创建 provider 实例', async () => {
|
||||
const config: ProviderConfig = {
|
||||
id: 'test-create-provider',
|
||||
name: 'Test Create Provider',
|
||||
creator: vi.fn(() => ({ name: 'test-created' })),
|
||||
supportsImageGeneration: false
|
||||
}
|
||||
|
||||
// 先注册配置
|
||||
registerProviderConfig(config)
|
||||
|
||||
// 创建 provider 实例
|
||||
const provider = await createProvider('test-create-provider', { apiKey: 'test' })
|
||||
expect(provider).toBeDefined()
|
||||
expect(config.creator).toHaveBeenCalledWith({ apiKey: 'test' })
|
||||
})
|
||||
|
||||
it('能够注册 provider 到全局管理器', () => {
|
||||
const mockProvider = { name: 'mock-provider' }
|
||||
const config: ProviderConfig = {
|
||||
id: 'test-register-provider',
|
||||
name: 'Test Register Provider',
|
||||
creator: vi.fn(() => mockProvider),
|
||||
supportsImageGeneration: false
|
||||
}
|
||||
|
||||
// 先注册配置
|
||||
registerProviderConfig(config)
|
||||
|
||||
// 注册 provider 到全局管理器
|
||||
const success = registerProvider('test-register-provider', mockProvider)
|
||||
expect(success).toBe(true)
|
||||
|
||||
// 验证注册成功
|
||||
const registeredProviders = getInitializedProviders()
|
||||
expect(registeredProviders).toContain('test-register-provider')
|
||||
expect(hasInitializedProviders()).toBe(true)
|
||||
})
|
||||
|
||||
it('能够一步完成创建和注册', async () => {
|
||||
const config: ProviderConfig = {
|
||||
id: 'test-create-and-register',
|
||||
name: 'Test Create and Register',
|
||||
creator: vi.fn(() => ({ name: 'test-both' })),
|
||||
supportsImageGeneration: false
|
||||
}
|
||||
|
||||
// 先注册配置
|
||||
registerProviderConfig(config)
|
||||
|
||||
// 一步完成创建和注册
|
||||
const success = await createAndRegisterProvider('test-create-and-register', { apiKey: 'test' })
|
||||
expect(success).toBe(true)
|
||||
|
||||
// 验证注册成功
|
||||
const registeredProviders = getInitializedProviders()
|
||||
expect(registeredProviders).toContain('test-create-and-register')
|
||||
})
|
||||
})
|
||||
|
||||
describe('Registry 管理', () => {
|
||||
it('能够清理所有配置和注册的 providers', () => {
|
||||
// 注册一些配置
|
||||
registerProviderConfig({
|
||||
id: 'temp-provider',
|
||||
name: 'Temp Provider',
|
||||
creator: vi.fn(() => ({ name: 'temp' })),
|
||||
supportsImageGeneration: false
|
||||
})
|
||||
|
||||
expect(hasProviderConfig('temp-provider')).toBe(true)
|
||||
|
||||
// 清理
|
||||
cleanup()
|
||||
|
||||
expect(hasProviderConfig('temp-provider')).toBe(false)
|
||||
// 但基础配置应该重新加载
|
||||
expect(hasProviderConfig('openai')).toBe(true) // 基础 providers 会重新初始化
|
||||
})
|
||||
|
||||
it('能够单独清理已注册的 providers', () => {
|
||||
// 清理所有 providers
|
||||
clearAllProviders()
|
||||
|
||||
expect(getInitializedProviders()).toEqual([])
|
||||
expect(hasInitializedProviders()).toBe(false)
|
||||
})
|
||||
|
||||
it('ProviderInitializationError 错误类工作正常', () => {
|
||||
const error = new ProviderInitializationError('Test error', 'test-provider')
|
||||
expect(error.message).toBe('Test error')
|
||||
expect(error.providerId).toBe('test-provider')
|
||||
expect(error.name).toBe('ProviderInitializationError')
|
||||
})
|
||||
})
|
||||
|
||||
describe('错误处理', () => {
|
||||
it('优雅处理空配置', () => {
|
||||
const success = registerProviderConfig(null as any)
|
||||
expect(success).toBe(false)
|
||||
})
|
||||
|
||||
it('优雅处理未定义配置', () => {
|
||||
const success = registerProviderConfig(undefined as any)
|
||||
expect(success).toBe(false)
|
||||
})
|
||||
|
||||
it('处理空字符串 ID', () => {
|
||||
const config = {
|
||||
id: '',
|
||||
name: 'Empty ID Provider',
|
||||
creator: vi.fn(() => ({ name: 'empty' })),
|
||||
supportsImageGeneration: false
|
||||
}
|
||||
|
||||
const success = registerProviderConfig(config)
|
||||
expect(success).toBe(false)
|
||||
})
|
||||
|
||||
it('处理创建不存在配置的 provider', async () => {
|
||||
await expect(createProvider('non-existent-provider', {})).rejects.toThrow(
|
||||
'ProviderConfig not found for id: non-existent-provider'
|
||||
)
|
||||
})
|
||||
|
||||
it('处理注册不存在配置的 provider', () => {
|
||||
const mockProvider = { name: 'mock' }
|
||||
const success = registerProvider('non-existent-provider', mockProvider)
|
||||
expect(success).toBe(false)
|
||||
})
|
||||
|
||||
it('处理获取不存在配置的情况', () => {
|
||||
expect(getProviderConfig('non-existent')).toBeUndefined()
|
||||
expect(getProviderConfigByAlias('non-existent-alias')).toBeUndefined()
|
||||
expect(hasProviderConfig('non-existent')).toBe(false)
|
||||
expect(hasProviderConfigByAlias('non-existent-alias')).toBe(false)
|
||||
})
|
||||
|
||||
it('处理批量注册时的部分失败', () => {
|
||||
const mixedConfigs: ProviderConfig[] = [
|
||||
{
|
||||
id: 'valid-provider-1',
|
||||
name: 'Valid Provider 1',
|
||||
creator: vi.fn(() => ({ name: 'valid-1' })),
|
||||
supportsImageGeneration: false
|
||||
},
|
||||
{
|
||||
id: '', // 无效配置
|
||||
name: 'Invalid Provider',
|
||||
creator: vi.fn(() => ({ name: 'invalid' })),
|
||||
supportsImageGeneration: false
|
||||
} as any,
|
||||
{
|
||||
id: 'valid-provider-2',
|
||||
name: 'Valid Provider 2',
|
||||
creator: vi.fn(() => ({ name: 'valid-2' })),
|
||||
supportsImageGeneration: true
|
||||
}
|
||||
]
|
||||
|
||||
const successCount = registerMultipleProviderConfigs(mixedConfigs)
|
||||
expect(successCount).toBe(2) // 只有两个有效配置成功
|
||||
|
||||
expect(hasProviderConfig('valid-provider-1')).toBe(true)
|
||||
expect(hasProviderConfig('valid-provider-2')).toBe(true)
|
||||
expect(hasProviderConfig('')).toBe(false)
|
||||
})
|
||||
|
||||
it('处理动态导入失败的情况', async () => {
|
||||
const config: ProviderConfig = {
|
||||
id: 'import-test-provider',
|
||||
name: 'Import Test Provider',
|
||||
import: vi.fn().mockRejectedValue(new Error('Import failed')),
|
||||
creatorFunctionName: 'createTest',
|
||||
supportsImageGeneration: false
|
||||
}
|
||||
|
||||
registerProviderConfig(config)
|
||||
|
||||
await expect(createProvider('import-test-provider', {})).rejects.toThrow('Import failed')
|
||||
})
|
||||
})
|
||||
|
||||
describe('集成测试', () => {
|
||||
it('正确处理复杂的配置、创建、注册和清理场景', async () => {
|
||||
// 初始状态验证
|
||||
const initialConfigs = getAllProviderConfigs()
|
||||
expect(initialConfigs.length).toBeGreaterThan(0) // 有基础配置
|
||||
expect(getInitializedProviders()).toEqual([])
|
||||
|
||||
// 注册多个带别名的 provider 配置
|
||||
const configs: ProviderConfig[] = [
|
||||
{
|
||||
id: 'integration-provider-1',
|
||||
name: 'Integration Provider 1',
|
||||
creator: vi.fn(() => ({ name: 'integration-1' })),
|
||||
supportsImageGeneration: false,
|
||||
aliases: ['alias-1', 'short-name-1']
|
||||
},
|
||||
{
|
||||
id: 'integration-provider-2',
|
||||
name: 'Integration Provider 2',
|
||||
creator: vi.fn(() => ({ name: 'integration-2' })),
|
||||
supportsImageGeneration: true,
|
||||
aliases: ['alias-2', 'short-name-2']
|
||||
}
|
||||
]
|
||||
|
||||
const successCount = registerMultipleProviderConfigs(configs)
|
||||
expect(successCount).toBe(2)
|
||||
|
||||
// 验证配置注册成功
|
||||
expect(hasProviderConfig('integration-provider-1')).toBe(true)
|
||||
expect(hasProviderConfig('integration-provider-2')).toBe(true)
|
||||
expect(hasProviderConfigByAlias('alias-1')).toBe(true)
|
||||
expect(hasProviderConfigByAlias('alias-2')).toBe(true)
|
||||
|
||||
// 验证别名映射
|
||||
const aliases = getAllProviderConfigAliases()
|
||||
expect(aliases['alias-1']).toBe('integration-provider-1')
|
||||
expect(aliases['alias-2']).toBe('integration-provider-2')
|
||||
|
||||
// 创建和注册 providers
|
||||
const success1 = await createAndRegisterProvider('integration-provider-1', { apiKey: 'test1' })
|
||||
const success2 = await createAndRegisterProvider('integration-provider-2', { apiKey: 'test2' })
|
||||
expect(success1).toBe(true)
|
||||
expect(success2).toBe(true)
|
||||
|
||||
// 验证注册成功
|
||||
const registeredProviders = getInitializedProviders()
|
||||
expect(registeredProviders).toContain('integration-provider-1')
|
||||
expect(registeredProviders).toContain('integration-provider-2')
|
||||
expect(hasInitializedProviders()).toBe(true)
|
||||
|
||||
// 清理
|
||||
cleanup()
|
||||
|
||||
// 验证清理后的状态
|
||||
expect(getInitializedProviders()).toEqual([])
|
||||
expect(hasProviderConfig('integration-provider-1')).toBe(false)
|
||||
expect(hasProviderConfig('integration-provider-2')).toBe(false)
|
||||
expect(getAllProviderConfigAliases()).toEqual({})
|
||||
|
||||
// 基础配置应该重新加载
|
||||
expect(hasProviderConfig('openai')).toBe(true)
|
||||
})
|
||||
|
||||
it('正确处理动态导入配置的 provider', async () => {
|
||||
const mockModule = { createCustomProvider: vi.fn(() => ({ name: 'custom-dynamic' })) }
|
||||
const dynamicImportConfig: ProviderConfig = {
|
||||
id: 'dynamic-import-provider',
|
||||
name: 'Dynamic Import Provider',
|
||||
import: vi.fn().mockResolvedValue(mockModule),
|
||||
creatorFunctionName: 'createCustomProvider',
|
||||
supportsImageGeneration: false
|
||||
}
|
||||
|
||||
// 注册配置
|
||||
const configSuccess = registerProviderConfig(dynamicImportConfig)
|
||||
expect(configSuccess).toBe(true)
|
||||
|
||||
// 创建和注册 provider
|
||||
const registerSuccess = await createAndRegisterProvider('dynamic-import-provider', { apiKey: 'test' })
|
||||
expect(registerSuccess).toBe(true)
|
||||
|
||||
// 验证导入函数被调用
|
||||
expect(dynamicImportConfig.import).toHaveBeenCalled()
|
||||
expect(mockModule.createCustomProvider).toHaveBeenCalledWith({ apiKey: 'test' })
|
||||
|
||||
// 验证注册成功
|
||||
expect(getInitializedProviders()).toContain('dynamic-import-provider')
|
||||
})
|
||||
|
||||
it('正确处理大量配置的注册和管理', () => {
|
||||
const largeConfigList: ProviderConfig[] = []
|
||||
|
||||
// 生成50个配置
|
||||
for (let i = 0; i < 50; i++) {
|
||||
largeConfigList.push({
|
||||
id: `bulk-provider-${i}`,
|
||||
name: `Bulk Provider ${i}`,
|
||||
creator: vi.fn(() => ({ name: `bulk-${i}` })),
|
||||
supportsImageGeneration: i % 2 === 0, // 偶数支持图像生成
|
||||
aliases: [`alias-${i}`, `short-${i}`]
|
||||
})
|
||||
}
|
||||
|
||||
const successCount = registerMultipleProviderConfigs(largeConfigList)
|
||||
expect(successCount).toBe(50)
|
||||
|
||||
// 验证所有配置都被正确注册
|
||||
const allConfigs = getAllProviderConfigs()
|
||||
expect(allConfigs.filter((config) => config.id.startsWith('bulk-provider-')).length).toBe(50)
|
||||
|
||||
// 验证别名数量
|
||||
const aliases = getAllProviderConfigAliases()
|
||||
const bulkAliases = Object.keys(aliases).filter(
|
||||
(alias) => alias.startsWith('alias-') || alias.startsWith('short-')
|
||||
)
|
||||
expect(bulkAliases.length).toBe(100) // 每个 provider 有2个别名
|
||||
|
||||
// 随机验证几个配置
|
||||
expect(hasProviderConfig('bulk-provider-0')).toBe(true)
|
||||
expect(hasProviderConfig('bulk-provider-25')).toBe(true)
|
||||
expect(hasProviderConfig('bulk-provider-49')).toBe(true)
|
||||
|
||||
// 验证别名工作正常
|
||||
expect(resolveProviderConfigId('alias-25')).toBe('bulk-provider-25')
|
||||
expect(isProviderConfigAlias('short-30')).toBe(true)
|
||||
|
||||
// 清理能正确处理大量数据
|
||||
cleanup()
|
||||
const cleanupAliases = getAllProviderConfigAliases()
|
||||
expect(
|
||||
Object.keys(cleanupAliases).filter((alias) => alias.startsWith('alias-') || alias.startsWith('short-'))
|
||||
).toEqual([])
|
||||
})
|
||||
})
|
||||
|
||||
describe('边界测试', () => {
|
||||
it('处理包含特殊字符的 provider IDs', () => {
|
||||
const specialCharsConfigs: ProviderConfig[] = [
|
||||
{
|
||||
id: 'provider-with-dashes',
|
||||
name: 'Provider With Dashes',
|
||||
creator: vi.fn(() => ({ name: 'dashes' })),
|
||||
supportsImageGeneration: false
|
||||
},
|
||||
{
|
||||
id: 'provider_with_underscores',
|
||||
name: 'Provider With Underscores',
|
||||
creator: vi.fn(() => ({ name: 'underscores' })),
|
||||
supportsImageGeneration: false
|
||||
},
|
||||
{
|
||||
id: 'provider.with.dots',
|
||||
name: 'Provider With Dots',
|
||||
creator: vi.fn(() => ({ name: 'dots' })),
|
||||
supportsImageGeneration: false
|
||||
}
|
||||
]
|
||||
|
||||
const successCount = registerMultipleProviderConfigs(specialCharsConfigs)
|
||||
expect(successCount).toBeGreaterThan(0) // 至少有一些成功
|
||||
|
||||
// 验证支持的特殊字符格式
|
||||
if (hasProviderConfig('provider-with-dashes')) {
|
||||
expect(getProviderConfig('provider-with-dashes')).toBeDefined()
|
||||
}
|
||||
if (hasProviderConfig('provider_with_underscores')) {
|
||||
expect(getProviderConfig('provider_with_underscores')).toBeDefined()
|
||||
}
|
||||
})
|
||||
|
||||
it('处理空的批量注册', () => {
|
||||
const successCount = registerMultipleProviderConfigs([])
|
||||
expect(successCount).toBe(0)
|
||||
|
||||
// 确保没有额外的配置被添加
|
||||
const configsBefore = getAllProviderConfigs().length
|
||||
expect(configsBefore).toBeGreaterThan(0) // 应该有基础配置
|
||||
})
|
||||
|
||||
it('处理重复的配置注册', () => {
|
||||
const config: ProviderConfig = {
|
||||
id: 'duplicate-test-provider',
|
||||
name: 'Duplicate Test Provider',
|
||||
creator: vi.fn(() => ({ name: 'duplicate' })),
|
||||
supportsImageGeneration: false
|
||||
}
|
||||
|
||||
// 第一次注册成功
|
||||
expect(registerProviderConfig(config)).toBe(true)
|
||||
expect(hasProviderConfig('duplicate-test-provider')).toBe(true)
|
||||
|
||||
// 重复注册相同的配置(允许覆盖)
|
||||
const updatedConfig: ProviderConfig = {
|
||||
...config,
|
||||
name: 'Updated Duplicate Test Provider'
|
||||
}
|
||||
expect(registerProviderConfig(updatedConfig)).toBe(true)
|
||||
expect(hasProviderConfig('duplicate-test-provider')).toBe(true)
|
||||
|
||||
// 验证配置被更新
|
||||
const retrievedConfig = getProviderConfig('duplicate-test-provider')
|
||||
expect(retrievedConfig?.name).toBe('Updated Duplicate Test Provider')
|
||||
})
|
||||
|
||||
it('处理极长的 ID 和名称', () => {
|
||||
const longId = 'very-long-provider-id-' + 'x'.repeat(100)
|
||||
const longName = 'Very Long Provider Name ' + 'Y'.repeat(100)
|
||||
|
||||
const config: ProviderConfig = {
|
||||
id: longId,
|
||||
name: longName,
|
||||
creator: vi.fn(() => ({ name: 'long-test' })),
|
||||
supportsImageGeneration: false
|
||||
}
|
||||
|
||||
const success = registerProviderConfig(config)
|
||||
expect(success).toBe(true)
|
||||
expect(hasProviderConfig(longId)).toBe(true)
|
||||
|
||||
const retrievedConfig = getProviderConfig(longId)
|
||||
expect(retrievedConfig?.name).toBe(longName)
|
||||
})
|
||||
|
||||
it('处理大量别名的配置', () => {
|
||||
const manyAliases = Array.from({ length: 50 }, (_, i) => `alias-${i}`)
|
||||
|
||||
const config: ProviderConfig = {
|
||||
id: 'provider-with-many-aliases',
|
||||
name: 'Provider With Many Aliases',
|
||||
creator: vi.fn(() => ({ name: 'many-aliases' })),
|
||||
supportsImageGeneration: false,
|
||||
aliases: manyAliases
|
||||
}
|
||||
|
||||
const success = registerProviderConfig(config)
|
||||
expect(success).toBe(true)
|
||||
|
||||
// 验证所有别名都能正确解析
|
||||
manyAliases.forEach((alias) => {
|
||||
expect(hasProviderConfigByAlias(alias)).toBe(true)
|
||||
expect(resolveProviderConfigId(alias)).toBe('provider-with-many-aliases')
|
||||
expect(isProviderConfigAlias(alias)).toBe(true)
|
||||
})
|
||||
})
|
||||
})
|
||||
})
|
||||
264
packages/aiCore/src/core/providers/__tests__/schemas.test.ts
Normal file
264
packages/aiCore/src/core/providers/__tests__/schemas.test.ts
Normal file
@@ -0,0 +1,264 @@
|
||||
import { describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import {
|
||||
type BaseProviderId,
|
||||
baseProviderIds,
|
||||
baseProviderIdSchema,
|
||||
baseProviders,
|
||||
type CustomProviderId,
|
||||
customProviderIdSchema,
|
||||
providerConfigSchema,
|
||||
type ProviderId,
|
||||
providerIdSchema
|
||||
} from '../schemas'
|
||||
|
||||
describe('Provider Schemas', () => {
|
||||
describe('baseProviders', () => {
|
||||
it('包含所有预期的基础 providers', () => {
|
||||
expect(baseProviders).toBeDefined()
|
||||
expect(Array.isArray(baseProviders)).toBe(true)
|
||||
expect(baseProviders.length).toBeGreaterThan(0)
|
||||
|
||||
const expectedIds = [
|
||||
'openai',
|
||||
'openai-responses',
|
||||
'openai-compatible',
|
||||
'anthropic',
|
||||
'google',
|
||||
'xai',
|
||||
'azure',
|
||||
'deepseek'
|
||||
]
|
||||
const actualIds = baseProviders.map((p) => p.id)
|
||||
expectedIds.forEach((id) => {
|
||||
expect(actualIds).toContain(id)
|
||||
})
|
||||
})
|
||||
|
||||
it('每个基础 provider 有必要的属性', () => {
|
||||
baseProviders.forEach((provider) => {
|
||||
expect(provider).toHaveProperty('id')
|
||||
expect(provider).toHaveProperty('name')
|
||||
expect(provider).toHaveProperty('creator')
|
||||
expect(provider).toHaveProperty('supportsImageGeneration')
|
||||
|
||||
expect(typeof provider.id).toBe('string')
|
||||
expect(typeof provider.name).toBe('string')
|
||||
expect(typeof provider.creator).toBe('function')
|
||||
expect(typeof provider.supportsImageGeneration).toBe('boolean')
|
||||
})
|
||||
})
|
||||
|
||||
it('provider ID 是唯一的', () => {
|
||||
const ids = baseProviders.map((p) => p.id)
|
||||
const uniqueIds = [...new Set(ids)]
|
||||
expect(ids).toEqual(uniqueIds)
|
||||
})
|
||||
})
|
||||
|
||||
describe('baseProviderIds', () => {
|
||||
it('正确提取所有基础 provider IDs', () => {
|
||||
expect(baseProviderIds).toBeDefined()
|
||||
expect(Array.isArray(baseProviderIds)).toBe(true)
|
||||
expect(baseProviderIds.length).toBe(baseProviders.length)
|
||||
|
||||
baseProviders.forEach((provider) => {
|
||||
expect(baseProviderIds).toContain(provider.id)
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe('baseProviderIdSchema', () => {
|
||||
it('验证有效的基础 provider IDs', () => {
|
||||
baseProviderIds.forEach((id) => {
|
||||
expect(baseProviderIdSchema.safeParse(id).success).toBe(true)
|
||||
})
|
||||
})
|
||||
|
||||
it('拒绝无效的基础 provider IDs', () => {
|
||||
const invalidIds = ['invalid', 'not-exists', '']
|
||||
invalidIds.forEach((id) => {
|
||||
expect(baseProviderIdSchema.safeParse(id).success).toBe(false)
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe('customProviderIdSchema', () => {
|
||||
it('接受有效的自定义 provider IDs', () => {
|
||||
const validIds = ['custom-provider', 'my-ai-service', 'company-llm-v2']
|
||||
validIds.forEach((id) => {
|
||||
expect(customProviderIdSchema.safeParse(id).success).toBe(true)
|
||||
})
|
||||
})
|
||||
|
||||
it('拒绝与基础 provider IDs 冲突的 IDs', () => {
|
||||
baseProviderIds.forEach((id) => {
|
||||
expect(customProviderIdSchema.safeParse(id).success).toBe(false)
|
||||
})
|
||||
})
|
||||
|
||||
it('拒绝空字符串', () => {
|
||||
expect(customProviderIdSchema.safeParse('').success).toBe(false)
|
||||
})
|
||||
})
|
||||
|
||||
describe('providerIdSchema', () => {
|
||||
it('接受基础 provider IDs', () => {
|
||||
baseProviderIds.forEach((id) => {
|
||||
expect(providerIdSchema.safeParse(id).success).toBe(true)
|
||||
})
|
||||
})
|
||||
|
||||
it('接受有效的自定义 provider IDs', () => {
|
||||
const validCustomIds = ['custom-provider', 'my-ai-service']
|
||||
validCustomIds.forEach((id) => {
|
||||
expect(providerIdSchema.safeParse(id).success).toBe(true)
|
||||
})
|
||||
})
|
||||
|
||||
it('拒绝无效的 IDs', () => {
|
||||
const invalidIds = ['', undefined, null, 123]
|
||||
invalidIds.forEach((id) => {
|
||||
expect(providerIdSchema.safeParse(id).success).toBe(false)
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe('providerConfigSchema', () => {
|
||||
it('验证带有 creator 的有效配置', () => {
|
||||
const validConfig = {
|
||||
id: 'custom-provider',
|
||||
name: 'Custom Provider',
|
||||
creator: vi.fn(),
|
||||
supportsImageGeneration: true
|
||||
}
|
||||
expect(providerConfigSchema.safeParse(validConfig).success).toBe(true)
|
||||
})
|
||||
|
||||
it('验证带有 import 配置的有效配置', () => {
|
||||
const validConfig = {
|
||||
id: 'custom-provider',
|
||||
name: 'Custom Provider',
|
||||
import: vi.fn(),
|
||||
creatorFunctionName: 'createCustom',
|
||||
supportsImageGeneration: false
|
||||
}
|
||||
expect(providerConfigSchema.safeParse(validConfig).success).toBe(true)
|
||||
})
|
||||
|
||||
it('拒绝既没有 creator 也没有 import 配置的配置', () => {
|
||||
const invalidConfig = {
|
||||
id: 'invalid',
|
||||
name: 'Invalid Provider',
|
||||
supportsImageGeneration: false
|
||||
}
|
||||
expect(providerConfigSchema.safeParse(invalidConfig).success).toBe(false)
|
||||
})
|
||||
|
||||
it('为 supportsImageGeneration 设置默认值', () => {
|
||||
const config = {
|
||||
id: 'test',
|
||||
name: 'Test',
|
||||
creator: vi.fn()
|
||||
}
|
||||
const result = providerConfigSchema.safeParse(config)
|
||||
expect(result.success).toBe(true)
|
||||
if (result.success) {
|
||||
expect(result.data.supportsImageGeneration).toBe(false)
|
||||
}
|
||||
})
|
||||
|
||||
it('拒绝使用基础 provider ID 的配置', () => {
|
||||
const invalidConfig = {
|
||||
id: 'openai', // 基础 provider ID
|
||||
name: 'Should Fail',
|
||||
creator: vi.fn()
|
||||
}
|
||||
expect(providerConfigSchema.safeParse(invalidConfig).success).toBe(false)
|
||||
})
|
||||
|
||||
it('拒绝缺少必需字段的配置', () => {
|
||||
const invalidConfigs = [
|
||||
{ name: 'Missing ID', creator: vi.fn() },
|
||||
{ id: 'missing-name', creator: vi.fn() },
|
||||
{ id: '', name: 'Empty ID', creator: vi.fn() },
|
||||
{ id: 'valid-custom', name: '', creator: vi.fn() }
|
||||
]
|
||||
|
||||
invalidConfigs.forEach((config) => {
|
||||
expect(providerConfigSchema.safeParse(config).success).toBe(false)
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe('Schema 验证功能', () => {
|
||||
it('baseProviderIdSchema 正确验证基础 provider IDs', () => {
|
||||
baseProviderIds.forEach((id) => {
|
||||
expect(baseProviderIdSchema.safeParse(id).success).toBe(true)
|
||||
})
|
||||
|
||||
expect(baseProviderIdSchema.safeParse('invalid-id').success).toBe(false)
|
||||
})
|
||||
|
||||
it('customProviderIdSchema 正确验证自定义 provider IDs', () => {
|
||||
const customIds = ['custom-provider', 'my-service', 'company-llm']
|
||||
customIds.forEach((id) => {
|
||||
expect(customProviderIdSchema.safeParse(id).success).toBe(true)
|
||||
})
|
||||
|
||||
// 拒绝基础 provider IDs
|
||||
baseProviderIds.forEach((id) => {
|
||||
expect(customProviderIdSchema.safeParse(id).success).toBe(false)
|
||||
})
|
||||
})
|
||||
|
||||
it('providerIdSchema 接受基础和自定义 provider IDs', () => {
|
||||
// 基础 IDs
|
||||
baseProviderIds.forEach((id) => {
|
||||
expect(providerIdSchema.safeParse(id).success).toBe(true)
|
||||
})
|
||||
|
||||
// 自定义 IDs
|
||||
const customIds = ['custom-provider', 'my-service']
|
||||
customIds.forEach((id) => {
|
||||
expect(providerIdSchema.safeParse(id).success).toBe(true)
|
||||
})
|
||||
})
|
||||
|
||||
it('providerConfigSchema 验证完整的 provider 配置', () => {
|
||||
const validConfig = {
|
||||
id: 'custom-provider',
|
||||
name: 'Custom Provider',
|
||||
creator: vi.fn(),
|
||||
supportsImageGeneration: true
|
||||
}
|
||||
expect(providerConfigSchema.safeParse(validConfig).success).toBe(true)
|
||||
|
||||
const invalidConfig = {
|
||||
id: 'openai', // 不允许基础 provider ID
|
||||
name: 'OpenAI',
|
||||
creator: vi.fn()
|
||||
}
|
||||
expect(providerConfigSchema.safeParse(invalidConfig).success).toBe(false)
|
||||
})
|
||||
})
|
||||
|
||||
describe('类型推导', () => {
|
||||
it('BaseProviderId 类型正确', () => {
|
||||
const id: BaseProviderId = 'openai'
|
||||
expect(baseProviderIds).toContain(id)
|
||||
})
|
||||
|
||||
it('CustomProviderId 类型是字符串', () => {
|
||||
const id: CustomProviderId = 'custom-provider'
|
||||
expect(typeof id).toBe('string')
|
||||
})
|
||||
|
||||
it('ProviderId 类型支持基础和自定义 IDs', () => {
|
||||
const baseId: ProviderId = 'openai'
|
||||
const customId: ProviderId = 'custom-provider'
|
||||
expect(typeof baseId).toBe('string')
|
||||
expect(typeof customId).toBe('string')
|
||||
})
|
||||
})
|
||||
})
|
||||
291
packages/aiCore/src/core/providers/factory.ts
Normal file
291
packages/aiCore/src/core/providers/factory.ts
Normal file
@@ -0,0 +1,291 @@
|
||||
/**
|
||||
* AI Provider 配置工厂
|
||||
* 提供类型安全的 Provider 配置构建器
|
||||
*/
|
||||
|
||||
import type { ProviderId, ProviderSettingsMap } from './types'
|
||||
|
||||
/**
|
||||
* 通用配置基础类型,包含所有 Provider 共有的属性
|
||||
*/
|
||||
export interface BaseProviderConfig {
|
||||
apiKey?: string
|
||||
baseURL?: string
|
||||
timeout?: number
|
||||
headers?: Record<string, string>
|
||||
fetch?: typeof globalThis.fetch
|
||||
}
|
||||
|
||||
/**
|
||||
* 完整的配置类型,结合基础配置、AI SDK 配置和特定 Provider 配置
|
||||
*/
|
||||
type CompleteProviderConfig<T extends ProviderId> = BaseProviderConfig & Partial<ProviderSettingsMap[T]>
|
||||
|
||||
type ConfigHandler<T extends ProviderId> = (
|
||||
builder: ProviderConfigBuilder<T>,
|
||||
provider: CompleteProviderConfig<T>
|
||||
) => void
|
||||
|
||||
const configHandlers: {
|
||||
[K in ProviderId]?: ConfigHandler<K>
|
||||
} = {
|
||||
azure: (builder, provider) => {
|
||||
const azureBuilder = builder as ProviderConfigBuilder<'azure'>
|
||||
const azureProvider = provider as CompleteProviderConfig<'azure'>
|
||||
azureBuilder.withAzureConfig({
|
||||
apiVersion: azureProvider.apiVersion,
|
||||
resourceName: azureProvider.resourceName
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
export class ProviderConfigBuilder<T extends ProviderId = ProviderId> {
|
||||
private config: CompleteProviderConfig<T> = {} as CompleteProviderConfig<T>
|
||||
|
||||
constructor(private providerId: T) {}
|
||||
|
||||
/**
|
||||
* 设置 API Key
|
||||
*/
|
||||
withApiKey(apiKey: string): this
|
||||
withApiKey(apiKey: string, options: T extends 'openai' ? { organization?: string; project?: string } : never): this
|
||||
withApiKey(apiKey: string, options?: any): this {
|
||||
this.config.apiKey = apiKey
|
||||
|
||||
// 类型安全的 OpenAI 特定配置
|
||||
if (this.providerId === 'openai' && options) {
|
||||
const openaiConfig = this.config as CompleteProviderConfig<'openai'>
|
||||
if (options.organization) {
|
||||
openaiConfig.organization = options.organization
|
||||
}
|
||||
if (options.project) {
|
||||
openaiConfig.project = options.project
|
||||
}
|
||||
}
|
||||
|
||||
return this
|
||||
}
|
||||
|
||||
/**
|
||||
* 设置基础 URL
|
||||
*/
|
||||
withBaseURL(baseURL: string) {
|
||||
this.config.baseURL = baseURL
|
||||
return this
|
||||
}
|
||||
|
||||
/**
|
||||
* 设置请求配置
|
||||
*/
|
||||
withRequestConfig(options: { headers?: Record<string, string>; fetch?: typeof fetch }): this {
|
||||
if (options.headers) {
|
||||
this.config.headers = { ...this.config.headers, ...options.headers }
|
||||
}
|
||||
if (options.fetch) {
|
||||
this.config.fetch = options.fetch
|
||||
}
|
||||
return this
|
||||
}
|
||||
|
||||
/**
|
||||
* Azure OpenAI 特定配置
|
||||
*/
|
||||
withAzureConfig(options: { apiVersion?: string; resourceName?: string }): T extends 'azure' ? this : never
|
||||
withAzureConfig(options: any): any {
|
||||
if (this.providerId === 'azure') {
|
||||
const azureConfig = this.config as CompleteProviderConfig<'azure'>
|
||||
if (options.apiVersion) {
|
||||
azureConfig.apiVersion = options.apiVersion
|
||||
}
|
||||
if (options.resourceName) {
|
||||
azureConfig.resourceName = options.resourceName
|
||||
}
|
||||
}
|
||||
return this
|
||||
}
|
||||
|
||||
/**
|
||||
* 设置自定义参数
|
||||
*/
|
||||
withCustomParams(params: Record<string, any>) {
|
||||
Object.assign(this.config, params)
|
||||
return this
|
||||
}
|
||||
|
||||
/**
|
||||
* 构建最终配置
|
||||
*/
|
||||
build(): ProviderSettingsMap[T] {
|
||||
return this.config as ProviderSettingsMap[T]
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Provider 配置工厂
|
||||
* 提供便捷的配置创建方法
|
||||
*/
|
||||
export class ProviderConfigFactory {
|
||||
/**
|
||||
* 创建配置构建器
|
||||
*/
|
||||
static builder<T extends ProviderId>(providerId: T): ProviderConfigBuilder<T> {
|
||||
return new ProviderConfigBuilder(providerId)
|
||||
}
|
||||
|
||||
/**
|
||||
* 从通用Provider对象创建配置 - 使用更优雅的处理器模式
|
||||
*/
|
||||
static fromProvider<T extends ProviderId>(
|
||||
providerId: T,
|
||||
provider: CompleteProviderConfig<T>,
|
||||
options?: {
|
||||
headers?: Record<string, string>
|
||||
[key: string]: any
|
||||
}
|
||||
): ProviderSettingsMap[T] {
|
||||
const builder = new ProviderConfigBuilder<T>(providerId)
|
||||
|
||||
// 设置基本配置
|
||||
if (provider.apiKey) {
|
||||
builder.withApiKey(provider.apiKey)
|
||||
}
|
||||
|
||||
if (provider.baseURL) {
|
||||
builder.withBaseURL(provider.baseURL)
|
||||
}
|
||||
|
||||
// 设置请求配置
|
||||
if (options?.headers) {
|
||||
builder.withRequestConfig({
|
||||
headers: options.headers
|
||||
})
|
||||
}
|
||||
|
||||
// 使用配置处理器模式 - 更加优雅和可扩展
|
||||
const handler = configHandlers[providerId]
|
||||
if (handler) {
|
||||
handler(builder, provider)
|
||||
}
|
||||
|
||||
// 添加其他自定义参数
|
||||
if (options) {
|
||||
const customOptions = { ...options }
|
||||
delete customOptions.headers // 已经处理过了
|
||||
if (Object.keys(customOptions).length > 0) {
|
||||
builder.withCustomParams(customOptions)
|
||||
}
|
||||
}
|
||||
|
||||
return builder.build()
|
||||
}
|
||||
|
||||
/**
|
||||
* 快速创建 OpenAI 配置
|
||||
*/
|
||||
static createOpenAI(
|
||||
apiKey: string,
|
||||
options?: {
|
||||
baseURL?: string
|
||||
organization?: string
|
||||
project?: string
|
||||
}
|
||||
) {
|
||||
const builder = this.builder('openai')
|
||||
|
||||
// 使用类型安全的重载
|
||||
if (options?.organization || options?.project) {
|
||||
builder.withApiKey(apiKey, {
|
||||
organization: options.organization,
|
||||
project: options.project
|
||||
})
|
||||
} else {
|
||||
builder.withApiKey(apiKey)
|
||||
}
|
||||
|
||||
return builder.withBaseURL(options?.baseURL || 'https://api.openai.com').build()
|
||||
}
|
||||
|
||||
/**
|
||||
* 快速创建 Anthropic 配置
|
||||
*/
|
||||
static createAnthropic(
|
||||
apiKey: string,
|
||||
options?: {
|
||||
baseURL?: string
|
||||
}
|
||||
) {
|
||||
return this.builder('anthropic')
|
||||
.withApiKey(apiKey)
|
||||
.withBaseURL(options?.baseURL || 'https://api.anthropic.com')
|
||||
.build()
|
||||
}
|
||||
|
||||
/**
|
||||
* 快速创建 Azure OpenAI 配置
|
||||
*/
|
||||
static createAzureOpenAI(
|
||||
apiKey: string,
|
||||
options: {
|
||||
baseURL: string
|
||||
apiVersion?: string
|
||||
resourceName?: string
|
||||
}
|
||||
) {
|
||||
return this.builder('azure')
|
||||
.withApiKey(apiKey)
|
||||
.withBaseURL(options.baseURL)
|
||||
.withAzureConfig({
|
||||
apiVersion: options.apiVersion,
|
||||
resourceName: options.resourceName
|
||||
})
|
||||
.build()
|
||||
}
|
||||
|
||||
/**
|
||||
* 快速创建 Google 配置
|
||||
*/
|
||||
static createGoogle(
|
||||
apiKey: string,
|
||||
options?: {
|
||||
baseURL?: string
|
||||
projectId?: string
|
||||
location?: string
|
||||
}
|
||||
) {
|
||||
return this.builder('google')
|
||||
.withApiKey(apiKey)
|
||||
.withBaseURL(options?.baseURL || 'https://generativelanguage.googleapis.com')
|
||||
.build()
|
||||
}
|
||||
|
||||
/**
|
||||
* 快速创建 Vertex AI 配置
|
||||
*/
|
||||
static createVertexAI() {
|
||||
// credentials: {
|
||||
// clientEmail: string
|
||||
// privateKey: string
|
||||
// },
|
||||
// options?: {
|
||||
// project?: string
|
||||
// location?: string
|
||||
// }
|
||||
// return this.builder('google-vertex')
|
||||
// .withGoogleCredentials(credentials)
|
||||
// .withGoogleVertexConfig({
|
||||
// project: options?.project,
|
||||
// location: options?.location
|
||||
// })
|
||||
// .build()
|
||||
}
|
||||
|
||||
static createOpenAICompatible(baseURL: string, apiKey: string) {
|
||||
return this.builder('openai-compatible').withBaseURL(baseURL).withApiKey(apiKey).build()
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 便捷的配置创建函数
|
||||
*/
|
||||
export const createProviderConfig = ProviderConfigFactory.fromProvider
|
||||
export const providerConfigBuilder = ProviderConfigFactory.builder
|
||||
83
packages/aiCore/src/core/providers/index.ts
Normal file
83
packages/aiCore/src/core/providers/index.ts
Normal file
@@ -0,0 +1,83 @@
|
||||
/**
|
||||
* Providers 模块统一导出 - 独立Provider包
|
||||
*/
|
||||
|
||||
// ==================== 核心管理器 ====================
|
||||
|
||||
// Provider 注册表管理器
|
||||
export { globalRegistryManagement, RegistryManagement } from './RegistryManagement'
|
||||
|
||||
// Provider 核心功能
|
||||
export {
|
||||
// 状态管理
|
||||
cleanup,
|
||||
clearAllProviders,
|
||||
createAndRegisterProvider,
|
||||
createProvider,
|
||||
getAllProviderConfigAliases,
|
||||
getAllProviderConfigs,
|
||||
getImageModel,
|
||||
// 工具函数
|
||||
getInitializedProviders,
|
||||
getLanguageModel,
|
||||
getProviderConfig,
|
||||
getProviderConfigByAlias,
|
||||
getSupportedProviders,
|
||||
getTextEmbeddingModel,
|
||||
hasInitializedProviders,
|
||||
// 工具函数
|
||||
hasProviderConfig,
|
||||
// 别名支持
|
||||
hasProviderConfigByAlias,
|
||||
isProviderConfigAlias,
|
||||
// 错误类型
|
||||
ProviderInitializationError,
|
||||
// 全局访问
|
||||
providerRegistry,
|
||||
registerMultipleProviderConfigs,
|
||||
registerProvider,
|
||||
// 统一Provider系统
|
||||
registerProviderConfig,
|
||||
resolveProviderConfigId
|
||||
} from './registry'
|
||||
|
||||
// ==================== 基础数据和类型 ====================
|
||||
|
||||
// 基础Provider数据源
|
||||
export { baseProviderIds, baseProviders } from './schemas'
|
||||
|
||||
// 类型定义和Schema
|
||||
export type {
|
||||
BaseProviderId,
|
||||
CustomProviderId,
|
||||
DynamicProviderRegistration,
|
||||
ProviderConfig,
|
||||
ProviderId
|
||||
} from './schemas' // 从 schemas 导出的类型
|
||||
export { baseProviderIdSchema, customProviderIdSchema, providerConfigSchema, providerIdSchema } from './schemas' // Schema 导出
|
||||
export type {
|
||||
DynamicProviderRegistry,
|
||||
ExtensibleProviderSettingsMap,
|
||||
ProviderError,
|
||||
ProviderSettingsMap,
|
||||
ProviderTypeRegistrar
|
||||
} from './types'
|
||||
|
||||
// ==================== 工具函数 ====================
|
||||
|
||||
// Provider配置工厂
|
||||
export {
|
||||
type BaseProviderConfig,
|
||||
createProviderConfig,
|
||||
ProviderConfigBuilder,
|
||||
providerConfigBuilder,
|
||||
ProviderConfigFactory
|
||||
} from './factory'
|
||||
|
||||
// 工具函数
|
||||
export { formatPrivateKey } from './utils'
|
||||
|
||||
// ==================== 扩展功能 ====================
|
||||
|
||||
// Hub Provider 功能
|
||||
export { createHubProvider, type HubProviderConfig, HubProviderError } from './HubProvider'
|
||||
310
packages/aiCore/src/core/providers/registry.ts
Normal file
310
packages/aiCore/src/core/providers/registry.ts
Normal file
@@ -0,0 +1,310 @@
|
||||
/**
|
||||
* Provider 初始化器
|
||||
* 负责根据配置创建 providers 并注册到全局管理器
|
||||
* 集成了来自 ModelCreator 的特殊处理逻辑
|
||||
*/
|
||||
|
||||
import { customProvider } from 'ai'
|
||||
|
||||
import { globalRegistryManagement } from './RegistryManagement'
|
||||
import { baseProviders, type ProviderConfig } from './schemas'
|
||||
|
||||
/**
|
||||
* Provider 初始化错误类型
|
||||
*/
|
||||
class ProviderInitializationError extends Error {
|
||||
constructor(
|
||||
message: string,
|
||||
public providerId?: string,
|
||||
public cause?: Error
|
||||
) {
|
||||
super(message)
|
||||
this.name = 'ProviderInitializationError'
|
||||
}
|
||||
}
|
||||
|
||||
// ==================== 全局管理器导出 ====================
|
||||
|
||||
export { globalRegistryManagement as providerRegistry }
|
||||
|
||||
// ==================== 便捷访问方法 ====================
|
||||
|
||||
export const getLanguageModel = (id: string) => globalRegistryManagement.languageModel(id as any)
|
||||
export const getTextEmbeddingModel = (id: string) => globalRegistryManagement.textEmbeddingModel(id as any)
|
||||
export const getImageModel = (id: string) => globalRegistryManagement.imageModel(id as any)
|
||||
|
||||
// ==================== 工具函数 ====================
|
||||
|
||||
/**
|
||||
* 获取支持的 Providers 列表
|
||||
*/
|
||||
export function getSupportedProviders(): Array<{
|
||||
id: string
|
||||
name: string
|
||||
}> {
|
||||
return baseProviders.map((provider) => ({
|
||||
id: provider.id,
|
||||
name: provider.name
|
||||
}))
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取所有已初始化的 providers
|
||||
*/
|
||||
export function getInitializedProviders(): string[] {
|
||||
return globalRegistryManagement.getRegisteredProviders()
|
||||
}
|
||||
|
||||
/**
|
||||
* 检查是否有任何已初始化的 providers
|
||||
*/
|
||||
export function hasInitializedProviders(): boolean {
|
||||
return globalRegistryManagement.hasProviders()
|
||||
}
|
||||
|
||||
// ==================== 统一Provider配置系统 ====================
|
||||
|
||||
// 全局Provider配置存储
|
||||
const providerConfigs = new Map<string, ProviderConfig>()
|
||||
// 全局ProviderConfig别名映射 - 借鉴RegistryManagement模式
|
||||
const providerConfigAliases = new Map<string, string>() // alias -> realId
|
||||
|
||||
/**
|
||||
* 初始化内置配置 - 将baseProviders转换为统一格式
|
||||
*/
|
||||
function initializeBuiltInConfigs(): void {
|
||||
baseProviders.forEach((provider) => {
|
||||
const config: ProviderConfig = {
|
||||
id: provider.id,
|
||||
name: provider.name,
|
||||
creator: provider.creator as any, // 类型转换以兼容多种creator签名
|
||||
supportsImageGeneration: provider.supportsImageGeneration || false
|
||||
}
|
||||
providerConfigs.set(provider.id, config)
|
||||
})
|
||||
}
|
||||
|
||||
// 启动时自动注册内置配置
|
||||
initializeBuiltInConfigs()
|
||||
|
||||
/**
|
||||
* 步骤1: 注册Provider配置 - 仅存储配置,不执行创建
|
||||
*/
|
||||
export function registerProviderConfig(config: ProviderConfig): boolean {
|
||||
try {
|
||||
// 验证配置
|
||||
if (!config.id || !config.name) {
|
||||
return false
|
||||
}
|
||||
|
||||
// 检查是否与已有配置冲突(包括内置配置)
|
||||
if (providerConfigs.has(config.id)) {
|
||||
console.warn(`ProviderConfig "${config.id}" already exists, will override`)
|
||||
}
|
||||
|
||||
// 存储配置(内置和用户配置统一处理)
|
||||
providerConfigs.set(config.id, config)
|
||||
|
||||
// 处理别名
|
||||
if (config.aliases && config.aliases.length > 0) {
|
||||
config.aliases.forEach((alias) => {
|
||||
if (providerConfigAliases.has(alias)) {
|
||||
console.warn(`ProviderConfig alias "${alias}" already exists, will override`)
|
||||
}
|
||||
providerConfigAliases.set(alias, config.id)
|
||||
})
|
||||
}
|
||||
|
||||
return true
|
||||
} catch (error) {
|
||||
console.error(`Failed to register ProviderConfig:`, error)
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 步骤2: 创建Provider - 根据配置执行实际创建
|
||||
*/
|
||||
export async function createProvider(providerId: string, options: any): Promise<any> {
|
||||
// 支持通过别名查找配置
|
||||
const config = getProviderConfigByAlias(providerId)
|
||||
|
||||
if (!config) {
|
||||
throw new Error(`ProviderConfig not found for id: ${providerId}`)
|
||||
}
|
||||
|
||||
try {
|
||||
let creator: (options: any) => any
|
||||
|
||||
if (config.creator) {
|
||||
// 方式1: 直接执行 creator
|
||||
creator = config.creator
|
||||
} else if (config.import && config.creatorFunctionName) {
|
||||
// 方式2: 动态导入并执行
|
||||
const module = await config.import()
|
||||
creator = (module as any)[config.creatorFunctionName]
|
||||
|
||||
if (!creator || typeof creator !== 'function') {
|
||||
throw new Error(`Creator function "${config.creatorFunctionName}" not found in imported module`)
|
||||
}
|
||||
} else {
|
||||
throw new Error('No valid creator method provided in ProviderConfig')
|
||||
}
|
||||
|
||||
// 使用真实配置创建provider实例
|
||||
return creator(options)
|
||||
} catch (error) {
|
||||
console.error(`Failed to create provider "${providerId}":`, error)
|
||||
throw error
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 步骤3: 注册Provider到全局管理器
|
||||
*/
|
||||
export function registerProvider(providerId: string, provider: any): boolean {
|
||||
try {
|
||||
const config = providerConfigs.get(providerId)
|
||||
if (!config) {
|
||||
console.error(`ProviderConfig not found for id: ${providerId}`)
|
||||
return false
|
||||
}
|
||||
|
||||
// 获取aliases配置
|
||||
const aliases = config.aliases
|
||||
|
||||
// 处理特殊provider逻辑
|
||||
if (providerId === 'openai') {
|
||||
// 注册默认 openai
|
||||
globalRegistryManagement.registerProvider('openai', provider, aliases)
|
||||
|
||||
// 创建并注册 openai-chat 变体
|
||||
const openaiChatProvider = customProvider({
|
||||
fallbackProvider: {
|
||||
...provider,
|
||||
languageModel: (modelId: string) => provider.chat(modelId)
|
||||
}
|
||||
})
|
||||
globalRegistryManagement.registerProvider('openai-chat', openaiChatProvider)
|
||||
} else {
|
||||
// 其他provider直接注册
|
||||
globalRegistryManagement.registerProvider(providerId, provider, aliases)
|
||||
}
|
||||
|
||||
return true
|
||||
} catch (error) {
|
||||
console.error(`Failed to register provider "${providerId}" to global registry:`, error)
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 便捷函数: 一次性完成创建+注册
|
||||
*/
|
||||
export async function createAndRegisterProvider(providerId: string, options: any): Promise<boolean> {
|
||||
try {
|
||||
// 步骤2: 创建provider
|
||||
const provider = await createProvider(providerId, options)
|
||||
|
||||
// 步骤3: 注册到全局管理器
|
||||
return registerProvider(providerId, provider)
|
||||
} catch (error) {
|
||||
console.error(`Failed to create and register provider "${providerId}":`, error)
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 批量注册Provider配置
|
||||
*/
|
||||
export function registerMultipleProviderConfigs(configs: ProviderConfig[]): number {
|
||||
let successCount = 0
|
||||
configs.forEach((config) => {
|
||||
if (registerProviderConfig(config)) {
|
||||
successCount++
|
||||
}
|
||||
})
|
||||
return successCount
|
||||
}
|
||||
|
||||
/**
|
||||
* 检查是否有对应的Provider配置
|
||||
*/
|
||||
export function hasProviderConfig(providerId: string): boolean {
|
||||
return providerConfigs.has(providerId)
|
||||
}
|
||||
|
||||
/**
|
||||
* 通过别名或ID检查是否有对应的Provider配置
|
||||
*/
|
||||
export function hasProviderConfigByAlias(aliasOrId: string): boolean {
|
||||
const realId = resolveProviderConfigId(aliasOrId)
|
||||
return providerConfigs.has(realId)
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取所有Provider配置
|
||||
*/
|
||||
export function getAllProviderConfigs(): ProviderConfig[] {
|
||||
return Array.from(providerConfigs.values())
|
||||
}
|
||||
|
||||
/**
|
||||
* 根据ID获取Provider配置
|
||||
*/
|
||||
export function getProviderConfig(providerId: string): ProviderConfig | undefined {
|
||||
return providerConfigs.get(providerId)
|
||||
}
|
||||
|
||||
/**
|
||||
* 通过别名或ID获取Provider配置
|
||||
*/
|
||||
export function getProviderConfigByAlias(aliasOrId: string): ProviderConfig | undefined {
|
||||
// 先检查是否为别名,如果是则解析为真实ID
|
||||
const realId = providerConfigAliases.get(aliasOrId) || aliasOrId
|
||||
return providerConfigs.get(realId)
|
||||
}
|
||||
|
||||
/**
|
||||
* 解析真实的ProviderConfig ID(去别名化)
|
||||
*/
|
||||
export function resolveProviderConfigId(aliasOrId: string): string {
|
||||
return providerConfigAliases.get(aliasOrId) || aliasOrId
|
||||
}
|
||||
|
||||
/**
|
||||
* 检查是否为ProviderConfig别名
|
||||
*/
|
||||
export function isProviderConfigAlias(id: string): boolean {
|
||||
return providerConfigAliases.has(id)
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取所有ProviderConfig别名映射关系
|
||||
*/
|
||||
export function getAllProviderConfigAliases(): Record<string, string> {
|
||||
const result: Record<string, string> = {}
|
||||
providerConfigAliases.forEach((realId, alias) => {
|
||||
result[alias] = realId
|
||||
})
|
||||
return result
|
||||
}
|
||||
|
||||
/**
|
||||
* 清理所有Provider配置和已注册的providers
|
||||
*/
|
||||
export function cleanup(): void {
|
||||
providerConfigs.clear()
|
||||
providerConfigAliases.clear() // 清理别名映射
|
||||
globalRegistryManagement.clear()
|
||||
// 重新初始化内置配置
|
||||
initializeBuiltInConfigs()
|
||||
}
|
||||
|
||||
export function clearAllProviders(): void {
|
||||
globalRegistryManagement.clear()
|
||||
}
|
||||
|
||||
// ==================== 导出错误类型 ====================
|
||||
|
||||
export { ProviderInitializationError }
|
||||
132
packages/aiCore/src/core/providers/schemas.ts
Normal file
132
packages/aiCore/src/core/providers/schemas.ts
Normal file
@@ -0,0 +1,132 @@
|
||||
/**
|
||||
* Provider Config 定义
|
||||
*/
|
||||
|
||||
import { createAnthropic } from '@ai-sdk/anthropic'
|
||||
import { createAzure } from '@ai-sdk/azure'
|
||||
import { createDeepSeek } from '@ai-sdk/deepseek'
|
||||
import { createGoogleGenerativeAI } from '@ai-sdk/google'
|
||||
import { createOpenAI, type OpenAIProviderSettings } from '@ai-sdk/openai'
|
||||
import { createOpenAICompatible } from '@ai-sdk/openai-compatible'
|
||||
import { createXai } from '@ai-sdk/xai'
|
||||
import * as z from 'zod'
|
||||
|
||||
/**
|
||||
* 基础 Providers 定义
|
||||
* 作为唯一数据源,避免重复维护
|
||||
*/
|
||||
export const baseProviders = [
|
||||
{
|
||||
id: 'openai',
|
||||
name: 'OpenAI',
|
||||
creator: createOpenAI,
|
||||
supportsImageGeneration: true
|
||||
},
|
||||
{
|
||||
id: 'openai-responses',
|
||||
name: 'OpenAI Responses',
|
||||
creator: (options: OpenAIProviderSettings) => createOpenAI(options).responses,
|
||||
supportsImageGeneration: true
|
||||
},
|
||||
{
|
||||
id: 'openai-compatible',
|
||||
name: 'OpenAI Compatible',
|
||||
creator: createOpenAICompatible,
|
||||
supportsImageGeneration: true
|
||||
},
|
||||
{
|
||||
id: 'anthropic',
|
||||
name: 'Anthropic',
|
||||
creator: createAnthropic,
|
||||
supportsImageGeneration: false
|
||||
},
|
||||
{
|
||||
id: 'google',
|
||||
name: 'Google Generative AI',
|
||||
creator: createGoogleGenerativeAI,
|
||||
supportsImageGeneration: true
|
||||
},
|
||||
{
|
||||
id: 'xai',
|
||||
name: 'xAI (Grok)',
|
||||
creator: createXai,
|
||||
supportsImageGeneration: true
|
||||
},
|
||||
{
|
||||
id: 'azure',
|
||||
name: 'Azure OpenAI',
|
||||
creator: createAzure,
|
||||
supportsImageGeneration: true
|
||||
},
|
||||
{
|
||||
id: 'deepseek',
|
||||
name: 'DeepSeek',
|
||||
creator: createDeepSeek,
|
||||
supportsImageGeneration: false
|
||||
}
|
||||
] as const
|
||||
|
||||
/**
|
||||
* 基础 Provider IDs
|
||||
* 从 baseProviders 动态生成
|
||||
*/
|
||||
export const baseProviderIds = baseProviders.map((p) => p.id) as unknown as readonly [string, ...string[]]
|
||||
|
||||
/**
|
||||
* 基础 Provider ID Schema
|
||||
*/
|
||||
export const baseProviderIdSchema = z.enum(baseProviderIds)
|
||||
|
||||
/**
|
||||
* 用户自定义 Provider ID Schema
|
||||
* 允许任意字符串,但排除基础 provider IDs 以避免冲突
|
||||
*/
|
||||
export const customProviderIdSchema = z
|
||||
.string()
|
||||
.min(1)
|
||||
.refine((id) => !baseProviderIds.includes(id as any), {
|
||||
message: 'Custom provider ID cannot conflict with base provider IDs'
|
||||
})
|
||||
|
||||
/**
|
||||
* Provider ID Schema - 支持基础和自定义
|
||||
*/
|
||||
export const providerIdSchema = z.union([baseProviderIdSchema, customProviderIdSchema])
|
||||
|
||||
/**
|
||||
* Provider 配置 Schema
|
||||
* 用于Provider的配置验证
|
||||
*/
|
||||
export const providerConfigSchema = z
|
||||
.object({
|
||||
id: customProviderIdSchema, // 只允许自定义ID
|
||||
name: z.string().min(1),
|
||||
creator: z.function().optional(),
|
||||
import: z.function().optional(),
|
||||
creatorFunctionName: z.string().optional(),
|
||||
supportsImageGeneration: z.boolean().default(false),
|
||||
imageCreator: z.function().optional(),
|
||||
validateOptions: z.function().optional(),
|
||||
aliases: z.array(z.string()).optional()
|
||||
})
|
||||
.refine((data) => data.creator || (data.import && data.creatorFunctionName), {
|
||||
message: 'Must provide either creator function or import configuration'
|
||||
})
|
||||
|
||||
/**
|
||||
* Provider ID 类型 - 基于 zod schema 推导
|
||||
*/
|
||||
export type ProviderId = z.infer<typeof providerIdSchema>
|
||||
export type BaseProviderId = z.infer<typeof baseProviderIdSchema>
|
||||
export type CustomProviderId = z.infer<typeof customProviderIdSchema>
|
||||
|
||||
/**
|
||||
* Provider 配置类型
|
||||
*/
|
||||
export type ProviderConfig = z.infer<typeof providerConfigSchema>
|
||||
|
||||
/**
|
||||
* 兼容性类型别名
|
||||
* @deprecated 使用 ProviderConfig 替代
|
||||
*/
|
||||
export type DynamicProviderRegistration = ProviderConfig
|
||||
88
packages/aiCore/src/core/providers/types.ts
Normal file
88
packages/aiCore/src/core/providers/types.ts
Normal file
@@ -0,0 +1,88 @@
|
||||
import { type AnthropicProviderSettings } from '@ai-sdk/anthropic'
|
||||
import { type AzureOpenAIProviderSettings } from '@ai-sdk/azure'
|
||||
import { type DeepSeekProviderSettings } from '@ai-sdk/deepseek'
|
||||
import { type GoogleGenerativeAIProviderSettings } from '@ai-sdk/google'
|
||||
import { type OpenAIProviderSettings } from '@ai-sdk/openai'
|
||||
import { type OpenAICompatibleProviderSettings } from '@ai-sdk/openai-compatible'
|
||||
import { type XaiProviderSettings } from '@ai-sdk/xai'
|
||||
|
||||
// 导入基于 Zod 的 ProviderId 类型
|
||||
import { type ProviderId as ZodProviderId } from './schemas'
|
||||
|
||||
export interface ExtensibleProviderSettingsMap {
|
||||
// 基础的静态providers
|
||||
openai: OpenAIProviderSettings
|
||||
'openai-responses': OpenAIProviderSettings
|
||||
'openai-compatible': OpenAICompatibleProviderSettings
|
||||
anthropic: AnthropicProviderSettings
|
||||
google: GoogleGenerativeAIProviderSettings
|
||||
xai: XaiProviderSettings
|
||||
azure: AzureOpenAIProviderSettings
|
||||
deepseek: DeepSeekProviderSettings
|
||||
}
|
||||
|
||||
// 动态扩展的provider类型注册表
|
||||
export interface DynamicProviderRegistry {
|
||||
[key: string]: any
|
||||
}
|
||||
|
||||
// 合并基础和动态provider类型
|
||||
export type ProviderSettingsMap = ExtensibleProviderSettingsMap & DynamicProviderRegistry
|
||||
|
||||
/**
|
||||
* Provider 相关核心类型定义
|
||||
* 只定义必要的接口,其他类型直接使用 AI SDK
|
||||
*/
|
||||
|
||||
// Provider 配置接口 - 支持灵活的创建方式
|
||||
// export interface ProviderConfig {
|
||||
// id: string
|
||||
// name: string
|
||||
|
||||
// // 方式一:直接提供 creator 函数(推荐用于自定义)
|
||||
// creator?: (options: any) => any
|
||||
|
||||
// // 方式二:动态导入 + 函数名(用于包导入)
|
||||
// import?: () => Promise<any>
|
||||
// creatorFunctionName?: string
|
||||
|
||||
// // 图片生成支持
|
||||
// supportsImageGeneration?: boolean
|
||||
// imageCreator?: (options: any) => any
|
||||
|
||||
// // 可选的验证函数
|
||||
// validateOptions?: (options: any) => boolean
|
||||
// }
|
||||
|
||||
// 错误类型
|
||||
export class ProviderError extends Error {
|
||||
constructor(
|
||||
message: string,
|
||||
public providerId: string,
|
||||
public code?: string,
|
||||
public cause?: Error
|
||||
) {
|
||||
super(message)
|
||||
this.name = 'ProviderError'
|
||||
}
|
||||
}
|
||||
|
||||
// 动态ProviderId类型 - 基于 Zod Schema,支持运行时扩展和验证
|
||||
export type ProviderId = ZodProviderId
|
||||
|
||||
export interface ProviderTypeRegistrar {
|
||||
registerProviderType<T extends string, S>(providerId: T, settingsType: S): void
|
||||
getProviderSettings<T extends string>(providerId: T): any
|
||||
}
|
||||
|
||||
// 重新导出所有类型供外部使用
|
||||
export type {
|
||||
AnthropicProviderSettings,
|
||||
AzureOpenAIProviderSettings,
|
||||
DeepSeekProviderSettings,
|
||||
GoogleGenerativeAIProviderSettings,
|
||||
OpenAICompatibleProviderSettings,
|
||||
OpenAIProviderSettings,
|
||||
XaiProviderSettings
|
||||
}
|
||||
// 新的provider类型已经在上面直接export,不需要重复导出
|
||||
86
packages/aiCore/src/core/providers/utils.ts
Normal file
86
packages/aiCore/src/core/providers/utils.ts
Normal file
@@ -0,0 +1,86 @@
|
||||
/**
|
||||
* 格式化私钥,确保它包含正确的PEM头部和尾部
|
||||
*/
|
||||
export function formatPrivateKey(privateKey: string): string {
|
||||
if (!privateKey || typeof privateKey !== 'string') {
|
||||
throw new Error('Private key must be a non-empty string')
|
||||
}
|
||||
|
||||
// 先处理 JSON 字符串中的转义换行符
|
||||
const key = privateKey.replace(/\\n/g, '\n')
|
||||
|
||||
// 检查是否已经是正确格式的 PEM 私钥
|
||||
const hasBeginMarker = key.includes('-----BEGIN PRIVATE KEY-----')
|
||||
const hasEndMarker = key.includes('-----END PRIVATE KEY-----')
|
||||
|
||||
if (hasBeginMarker && hasEndMarker) {
|
||||
// 已经是 PEM 格式,但可能格式不规范,重新格式化
|
||||
return normalizePemFormat(key)
|
||||
}
|
||||
|
||||
// 如果没有完整的 PEM 头尾,尝试重新构建
|
||||
return reconstructPemKey(key)
|
||||
}
|
||||
|
||||
/**
|
||||
* 标准化 PEM 格式
|
||||
*/
|
||||
function normalizePemFormat(pemKey: string): string {
|
||||
// 分离头部、内容和尾部
|
||||
const lines = pemKey
|
||||
.split('\n')
|
||||
.map((line) => line.trim())
|
||||
.filter((line) => line.length > 0)
|
||||
|
||||
let keyContent = ''
|
||||
let foundBegin = false
|
||||
let foundEnd = false
|
||||
|
||||
for (const line of lines) {
|
||||
if (line === '-----BEGIN PRIVATE KEY-----') {
|
||||
foundBegin = true
|
||||
continue
|
||||
}
|
||||
if (line === '-----END PRIVATE KEY-----') {
|
||||
foundEnd = true
|
||||
break
|
||||
}
|
||||
if (foundBegin && !foundEnd) {
|
||||
keyContent += line
|
||||
}
|
||||
}
|
||||
|
||||
if (!foundBegin || !foundEnd || !keyContent) {
|
||||
throw new Error('Invalid PEM format: missing BEGIN/END markers or key content')
|
||||
}
|
||||
|
||||
// 重新格式化为 64 字符一行
|
||||
const formattedContent = keyContent.match(/.{1,64}/g)?.join('\n') || keyContent
|
||||
|
||||
return `-----BEGIN PRIVATE KEY-----\n${formattedContent}\n-----END PRIVATE KEY-----`
|
||||
}
|
||||
|
||||
/**
|
||||
* 重新构建 PEM 私钥
|
||||
*/
|
||||
function reconstructPemKey(key: string): string {
|
||||
// 移除所有空白字符和可能存在的不完整头尾
|
||||
let cleanKey = key.replace(/\s+/g, '')
|
||||
cleanKey = cleanKey.replace(/-----BEGIN[^-]*-----/g, '')
|
||||
cleanKey = cleanKey.replace(/-----END[^-]*-----/g, '')
|
||||
|
||||
// 确保私钥内容不为空
|
||||
if (!cleanKey) {
|
||||
throw new Error('Private key content is empty after cleaning')
|
||||
}
|
||||
|
||||
// 验证是否是有效的 Base64 字符
|
||||
if (!/^[A-Za-z0-9+/=]+$/.test(cleanKey)) {
|
||||
throw new Error('Private key contains invalid characters (not valid Base64)')
|
||||
}
|
||||
|
||||
// 格式化为 64 字符一行
|
||||
const formattedKey = cleanKey.match(/.{1,64}/g)?.join('\n') || cleanKey
|
||||
|
||||
return `-----BEGIN PRIVATE KEY-----\n${formattedKey}\n-----END PRIVATE KEY-----`
|
||||
}
|
||||
540
packages/aiCore/src/core/runtime/__tests__/generateImage.test.ts
Normal file
540
packages/aiCore/src/core/runtime/__tests__/generateImage.test.ts
Normal file
@@ -0,0 +1,540 @@
|
||||
import { ImageModelV2 } from '@ai-sdk/provider'
|
||||
import { experimental_generateImage as aiGenerateImage, NoImageGeneratedError } from 'ai'
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import { type AiPlugin } from '../../plugins'
|
||||
import { globalRegistryManagement } from '../../providers/RegistryManagement'
|
||||
import { ImageGenerationError } from '../errors'
|
||||
import { RuntimeExecutor } from '../executor'
|
||||
|
||||
// Mock dependencies
|
||||
vi.mock('ai', () => ({
|
||||
experimental_generateImage: vi.fn(),
|
||||
NoImageGeneratedError: class NoImageGeneratedError extends Error {
|
||||
static isInstance = vi.fn()
|
||||
constructor() {
|
||||
super('No image generated')
|
||||
this.name = 'NoImageGeneratedError'
|
||||
}
|
||||
}
|
||||
}))
|
||||
|
||||
vi.mock('../../providers/RegistryManagement', () => ({
|
||||
globalRegistryManagement: {
|
||||
imageModel: vi.fn()
|
||||
}
|
||||
}))
|
||||
|
||||
describe('RuntimeExecutor.generateImage', () => {
|
||||
let executor: RuntimeExecutor<'openai'>
|
||||
let mockImageModel: ImageModelV2
|
||||
let mockGenerateImageResult: any
|
||||
|
||||
beforeEach(() => {
|
||||
// Reset all mocks
|
||||
vi.clearAllMocks()
|
||||
|
||||
// Create executor instance
|
||||
executor = RuntimeExecutor.create('openai', {
|
||||
apiKey: 'test-key'
|
||||
})
|
||||
|
||||
// Mock image model
|
||||
mockImageModel = {
|
||||
modelId: 'dall-e-3',
|
||||
provider: 'openai'
|
||||
} as ImageModelV2
|
||||
|
||||
// Mock generateImage result
|
||||
mockGenerateImageResult = {
|
||||
image: {
|
||||
base64: 'base64-encoded-image-data',
|
||||
uint8Array: new Uint8Array([1, 2, 3]),
|
||||
mediaType: 'image/png'
|
||||
},
|
||||
images: [
|
||||
{
|
||||
base64: 'base64-encoded-image-data',
|
||||
uint8Array: new Uint8Array([1, 2, 3]),
|
||||
mediaType: 'image/png'
|
||||
}
|
||||
],
|
||||
warnings: [],
|
||||
providerMetadata: {
|
||||
openai: {
|
||||
images: [{ revisedPrompt: 'A detailed prompt' }]
|
||||
}
|
||||
},
|
||||
responses: []
|
||||
}
|
||||
|
||||
// Setup mocks
|
||||
vi.mocked(globalRegistryManagement.imageModel).mockReturnValue(mockImageModel)
|
||||
vi.mocked(aiGenerateImage).mockResolvedValue(mockGenerateImageResult)
|
||||
|
||||
// Reset mock implementation in case it was changed by previous tests
|
||||
vi.mocked(globalRegistryManagement.imageModel).mockImplementation(() => mockImageModel)
|
||||
})
|
||||
|
||||
describe('Basic functionality', () => {
|
||||
it('should generate a single image with minimal parameters', async () => {
|
||||
const result = await executor.generateImage('dall-e-3', {
|
||||
prompt: 'A futuristic cityscape at sunset'
|
||||
})
|
||||
|
||||
expect(globalRegistryManagement.imageModel).toHaveBeenCalledWith('openai:dall-e-3')
|
||||
|
||||
expect(aiGenerateImage).toHaveBeenCalledWith({
|
||||
model: mockImageModel,
|
||||
prompt: 'A futuristic cityscape at sunset'
|
||||
})
|
||||
|
||||
expect(result).toEqual(mockGenerateImageResult)
|
||||
})
|
||||
|
||||
it('should generate image with pre-created model', async () => {
|
||||
const result = await executor.generateImage(mockImageModel, {
|
||||
prompt: 'A beautiful landscape'
|
||||
})
|
||||
|
||||
// Note: globalRegistryManagement.imageModel may still be called due to resolveImageModel logic
|
||||
expect(aiGenerateImage).toHaveBeenCalledWith({
|
||||
model: mockImageModel,
|
||||
prompt: 'A beautiful landscape'
|
||||
})
|
||||
|
||||
expect(result).toEqual(mockGenerateImageResult)
|
||||
})
|
||||
|
||||
it('should support multiple images generation', async () => {
|
||||
await executor.generateImage('dall-e-3', {
|
||||
prompt: 'A futuristic cityscape',
|
||||
n: 3
|
||||
})
|
||||
|
||||
expect(aiGenerateImage).toHaveBeenCalledWith({
|
||||
model: mockImageModel,
|
||||
prompt: 'A futuristic cityscape',
|
||||
n: 3
|
||||
})
|
||||
})
|
||||
|
||||
it('should support size specification', async () => {
|
||||
await executor.generateImage('dall-e-3', {
|
||||
prompt: 'A beautiful sunset',
|
||||
size: '1024x1024'
|
||||
})
|
||||
|
||||
expect(aiGenerateImage).toHaveBeenCalledWith({
|
||||
model: mockImageModel,
|
||||
prompt: 'A beautiful sunset',
|
||||
size: '1024x1024'
|
||||
})
|
||||
})
|
||||
|
||||
it('should support aspect ratio specification', async () => {
|
||||
await executor.generateImage('dall-e-3', {
|
||||
prompt: 'A mountain landscape',
|
||||
aspectRatio: '16:9'
|
||||
})
|
||||
|
||||
expect(aiGenerateImage).toHaveBeenCalledWith({
|
||||
model: mockImageModel,
|
||||
prompt: 'A mountain landscape',
|
||||
aspectRatio: '16:9'
|
||||
})
|
||||
})
|
||||
|
||||
it('should support seed for consistent output', async () => {
|
||||
await executor.generateImage('dall-e-3', {
|
||||
prompt: 'A cat in space',
|
||||
seed: 1234567890
|
||||
})
|
||||
|
||||
expect(aiGenerateImage).toHaveBeenCalledWith({
|
||||
model: mockImageModel,
|
||||
prompt: 'A cat in space',
|
||||
seed: 1234567890
|
||||
})
|
||||
})
|
||||
|
||||
it('should support abort signal', async () => {
|
||||
const abortController = new AbortController()
|
||||
|
||||
await executor.generateImage('dall-e-3', {
|
||||
prompt: 'A cityscape',
|
||||
abortSignal: abortController.signal
|
||||
})
|
||||
|
||||
expect(aiGenerateImage).toHaveBeenCalledWith({
|
||||
model: mockImageModel,
|
||||
prompt: 'A cityscape',
|
||||
abortSignal: abortController.signal
|
||||
})
|
||||
})
|
||||
|
||||
it('should support provider-specific options', async () => {
|
||||
await executor.generateImage('dall-e-3', {
|
||||
prompt: 'A space station',
|
||||
providerOptions: {
|
||||
openai: {
|
||||
quality: 'hd',
|
||||
style: 'vivid'
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
expect(aiGenerateImage).toHaveBeenCalledWith({
|
||||
model: mockImageModel,
|
||||
prompt: 'A space station',
|
||||
providerOptions: {
|
||||
openai: {
|
||||
quality: 'hd',
|
||||
style: 'vivid'
|
||||
}
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
it('should support custom headers', async () => {
|
||||
await executor.generateImage('dall-e-3', {
|
||||
prompt: 'A robot',
|
||||
headers: {
|
||||
'X-Custom-Header': 'test-value'
|
||||
}
|
||||
})
|
||||
|
||||
expect(aiGenerateImage).toHaveBeenCalledWith({
|
||||
model: mockImageModel,
|
||||
prompt: 'A robot',
|
||||
headers: {
|
||||
'X-Custom-Header': 'test-value'
|
||||
}
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe('Plugin integration', () => {
|
||||
it('should execute plugins in correct order', async () => {
|
||||
const pluginCallOrder: string[] = []
|
||||
|
||||
const testPlugin: AiPlugin = {
|
||||
name: 'test-plugin',
|
||||
onRequestStart: vi.fn(async () => {
|
||||
pluginCallOrder.push('onRequestStart')
|
||||
}),
|
||||
transformParams: vi.fn(async (params) => {
|
||||
pluginCallOrder.push('transformParams')
|
||||
return { ...params, size: '512x512' }
|
||||
}),
|
||||
transformResult: vi.fn(async (result) => {
|
||||
pluginCallOrder.push('transformResult')
|
||||
return { ...result, processed: true }
|
||||
}),
|
||||
onRequestEnd: vi.fn(async () => {
|
||||
pluginCallOrder.push('onRequestEnd')
|
||||
})
|
||||
}
|
||||
|
||||
const executorWithPlugin = RuntimeExecutor.create(
|
||||
'openai',
|
||||
{
|
||||
apiKey: 'test-key'
|
||||
},
|
||||
[testPlugin]
|
||||
)
|
||||
|
||||
const result = await executorWithPlugin.generateImage('dall-e-3', {
|
||||
prompt: 'A test image'
|
||||
})
|
||||
|
||||
expect(pluginCallOrder).toEqual(['onRequestStart', 'transformParams', 'transformResult', 'onRequestEnd'])
|
||||
|
||||
expect(testPlugin.transformParams).toHaveBeenCalledWith(
|
||||
{ prompt: 'A test image' },
|
||||
expect.objectContaining({
|
||||
providerId: 'openai',
|
||||
modelId: 'dall-e-3'
|
||||
})
|
||||
)
|
||||
|
||||
expect(aiGenerateImage).toHaveBeenCalledWith({
|
||||
model: mockImageModel,
|
||||
prompt: 'A test image',
|
||||
size: '512x512' // Should be transformed by plugin
|
||||
})
|
||||
|
||||
expect(result).toEqual({
|
||||
...mockGenerateImageResult,
|
||||
processed: true // Should be transformed by plugin
|
||||
})
|
||||
})
|
||||
|
||||
it('should handle model resolution through plugins', async () => {
|
||||
const customImageModel = {
|
||||
modelId: 'custom-model',
|
||||
provider: 'openai'
|
||||
} as ImageModelV2
|
||||
|
||||
const modelResolutionPlugin: AiPlugin = {
|
||||
name: 'model-resolver',
|
||||
resolveModel: vi.fn(async () => customImageModel)
|
||||
}
|
||||
|
||||
const executorWithPlugin = RuntimeExecutor.create(
|
||||
'openai',
|
||||
{
|
||||
apiKey: 'test-key'
|
||||
},
|
||||
[modelResolutionPlugin]
|
||||
)
|
||||
|
||||
await executorWithPlugin.generateImage('dall-e-3', {
|
||||
prompt: 'A test image'
|
||||
})
|
||||
|
||||
expect(modelResolutionPlugin.resolveModel).toHaveBeenCalledWith(
|
||||
'dall-e-3',
|
||||
expect.objectContaining({
|
||||
providerId: 'openai',
|
||||
modelId: 'dall-e-3'
|
||||
})
|
||||
)
|
||||
|
||||
expect(aiGenerateImage).toHaveBeenCalledWith({
|
||||
model: customImageModel,
|
||||
prompt: 'A test image'
|
||||
})
|
||||
})
|
||||
|
||||
it('should support recursive calls from plugins', async () => {
|
||||
const recursivePlugin: AiPlugin = {
|
||||
name: 'recursive-plugin',
|
||||
transformParams: vi.fn(async (params, context) => {
|
||||
if (!context.isRecursiveCall && params.prompt === 'original') {
|
||||
// Make a recursive call with modified prompt
|
||||
await context.recursiveCall({
|
||||
prompt: 'modified'
|
||||
})
|
||||
}
|
||||
return params
|
||||
})
|
||||
}
|
||||
|
||||
const executorWithPlugin = RuntimeExecutor.create(
|
||||
'openai',
|
||||
{
|
||||
apiKey: 'test-key'
|
||||
},
|
||||
[recursivePlugin]
|
||||
)
|
||||
|
||||
await executorWithPlugin.generateImage('dall-e-3', {
|
||||
prompt: 'original'
|
||||
})
|
||||
|
||||
expect(recursivePlugin.transformParams).toHaveBeenCalledTimes(2)
|
||||
expect(aiGenerateImage).toHaveBeenCalledTimes(2)
|
||||
})
|
||||
})
|
||||
|
||||
describe('Error handling', () => {
|
||||
it('should handle model creation errors', async () => {
|
||||
const modelError = new Error('Failed to get image model')
|
||||
vi.mocked(globalRegistryManagement.imageModel).mockImplementation(() => {
|
||||
throw modelError
|
||||
})
|
||||
|
||||
await expect(
|
||||
executor.generateImage('invalid-model', {
|
||||
prompt: 'A test image'
|
||||
})
|
||||
).rejects.toThrow(ImageGenerationError)
|
||||
})
|
||||
|
||||
it('should handle image generation API errors', async () => {
|
||||
const apiError = new Error('API request failed')
|
||||
vi.mocked(aiGenerateImage).mockRejectedValue(apiError)
|
||||
|
||||
await expect(
|
||||
executor.generateImage('dall-e-3', {
|
||||
prompt: 'A test image'
|
||||
})
|
||||
).rejects.toThrow('Failed to generate image: API request failed')
|
||||
})
|
||||
|
||||
it('should handle NoImageGeneratedError', async () => {
|
||||
const noImageError = new NoImageGeneratedError({
|
||||
cause: new Error('No image generated'),
|
||||
responses: []
|
||||
})
|
||||
|
||||
vi.mocked(aiGenerateImage).mockRejectedValue(noImageError)
|
||||
vi.mocked(NoImageGeneratedError.isInstance).mockReturnValue(true)
|
||||
|
||||
await expect(
|
||||
executor.generateImage('dall-e-3', {
|
||||
prompt: 'A test image'
|
||||
})
|
||||
).rejects.toThrow('Failed to generate image: No image generated')
|
||||
})
|
||||
|
||||
it('should execute onError plugin hook on failure', async () => {
|
||||
const error = new Error('Generation failed')
|
||||
vi.mocked(aiGenerateImage).mockRejectedValue(error)
|
||||
|
||||
const errorPlugin: AiPlugin = {
|
||||
name: 'error-handler',
|
||||
onError: vi.fn()
|
||||
}
|
||||
|
||||
const executorWithPlugin = RuntimeExecutor.create(
|
||||
'openai',
|
||||
{
|
||||
apiKey: 'test-key'
|
||||
},
|
||||
[errorPlugin]
|
||||
)
|
||||
|
||||
await expect(
|
||||
executorWithPlugin.generateImage('dall-e-3', {
|
||||
prompt: 'A test image'
|
||||
})
|
||||
).rejects.toThrow('Failed to generate image: Generation failed')
|
||||
|
||||
expect(errorPlugin.onError).toHaveBeenCalledWith(
|
||||
error,
|
||||
expect.objectContaining({
|
||||
providerId: 'openai',
|
||||
modelId: 'dall-e-3'
|
||||
})
|
||||
)
|
||||
})
|
||||
|
||||
it('should handle abort signal timeout', async () => {
|
||||
const abortError = new Error('Operation was aborted')
|
||||
abortError.name = 'AbortError'
|
||||
vi.mocked(aiGenerateImage).mockRejectedValue(abortError)
|
||||
|
||||
const abortController = new AbortController()
|
||||
setTimeout(() => abortController.abort(), 10)
|
||||
|
||||
await expect(
|
||||
executor.generateImage('dall-e-3', {
|
||||
prompt: 'A test image',
|
||||
abortSignal: abortController.signal
|
||||
})
|
||||
).rejects.toThrow('Operation was aborted')
|
||||
})
|
||||
})
|
||||
|
||||
describe('Multiple providers support', () => {
|
||||
it('should work with different providers', async () => {
|
||||
const googleExecutor = RuntimeExecutor.create('google', {
|
||||
apiKey: 'google-key'
|
||||
})
|
||||
|
||||
await googleExecutor.generateImage('imagen-3.0-generate-002', {
|
||||
prompt: 'A landscape'
|
||||
})
|
||||
|
||||
expect(globalRegistryManagement.imageModel).toHaveBeenCalledWith('google:imagen-3.0-generate-002')
|
||||
})
|
||||
|
||||
it('should support xAI Grok image models', async () => {
|
||||
const xaiExecutor = RuntimeExecutor.create('xai', {
|
||||
apiKey: 'xai-key'
|
||||
})
|
||||
|
||||
await xaiExecutor.generateImage('grok-2-image', {
|
||||
prompt: 'A futuristic robot'
|
||||
})
|
||||
|
||||
expect(globalRegistryManagement.imageModel).toHaveBeenCalledWith('xai:grok-2-image')
|
||||
})
|
||||
})
|
||||
|
||||
describe('Advanced features', () => {
|
||||
it('should support batch image generation with maxImagesPerCall', async () => {
|
||||
await executor.generateImage('dall-e-3', {
|
||||
prompt: 'A test image',
|
||||
n: 10,
|
||||
maxImagesPerCall: 5
|
||||
})
|
||||
|
||||
expect(aiGenerateImage).toHaveBeenCalledWith({
|
||||
model: mockImageModel,
|
||||
prompt: 'A test image',
|
||||
n: 10,
|
||||
maxImagesPerCall: 5
|
||||
})
|
||||
})
|
||||
|
||||
it('should support retries with maxRetries', async () => {
|
||||
await executor.generateImage('dall-e-3', {
|
||||
prompt: 'A test image',
|
||||
maxRetries: 3
|
||||
})
|
||||
|
||||
expect(aiGenerateImage).toHaveBeenCalledWith({
|
||||
model: mockImageModel,
|
||||
prompt: 'A test image',
|
||||
maxRetries: 3
|
||||
})
|
||||
})
|
||||
|
||||
it('should handle warnings from the model', async () => {
|
||||
const resultWithWarnings = {
|
||||
...mockGenerateImageResult,
|
||||
warnings: [
|
||||
{
|
||||
type: 'unsupported-setting',
|
||||
message: 'Size parameter not supported for this model'
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
vi.mocked(aiGenerateImage).mockResolvedValue(resultWithWarnings)
|
||||
|
||||
const result = await executor.generateImage('dall-e-3', {
|
||||
prompt: 'A test image',
|
||||
size: '2048x2048' // Unsupported size
|
||||
})
|
||||
|
||||
expect(result.warnings).toHaveLength(1)
|
||||
expect(result.warnings[0].type).toBe('unsupported-setting')
|
||||
})
|
||||
|
||||
it('should provide access to provider metadata', async () => {
|
||||
const result = await executor.generateImage('dall-e-3', {
|
||||
prompt: 'A test image'
|
||||
})
|
||||
|
||||
expect(result.providerMetadata).toBeDefined()
|
||||
expect(result.providerMetadata.openai).toBeDefined()
|
||||
})
|
||||
|
||||
it('should provide response metadata', async () => {
|
||||
const resultWithMetadata = {
|
||||
...mockGenerateImageResult,
|
||||
responses: [
|
||||
{
|
||||
timestamp: new Date(),
|
||||
modelId: 'dall-e-3',
|
||||
headers: { 'x-request-id': 'test-123' }
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
vi.mocked(aiGenerateImage).mockResolvedValue(resultWithMetadata)
|
||||
|
||||
const result = await executor.generateImage('dall-e-3', {
|
||||
prompt: 'A test image'
|
||||
})
|
||||
|
||||
expect(result.responses).toHaveLength(1)
|
||||
expect(result.responses[0].modelId).toBe('dall-e-3')
|
||||
expect(result.responses[0].headers).toEqual({ 'x-request-id': 'test-123' })
|
||||
})
|
||||
})
|
||||
})
|
||||
38
packages/aiCore/src/core/runtime/errors.ts
Normal file
38
packages/aiCore/src/core/runtime/errors.ts
Normal file
@@ -0,0 +1,38 @@
|
||||
/**
|
||||
* Error classes for runtime operations
|
||||
*/
|
||||
|
||||
/**
|
||||
* Error thrown when image generation fails
|
||||
*/
|
||||
export class ImageGenerationError extends Error {
|
||||
constructor(
|
||||
message: string,
|
||||
public providerId?: string,
|
||||
public modelId?: string,
|
||||
public cause?: Error
|
||||
) {
|
||||
super(message)
|
||||
this.name = 'ImageGenerationError'
|
||||
|
||||
// Maintain proper stack trace (for V8 engines)
|
||||
if (Error.captureStackTrace) {
|
||||
Error.captureStackTrace(this, ImageGenerationError)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Error thrown when model resolution fails during image generation
|
||||
*/
|
||||
export class ImageModelResolutionError extends ImageGenerationError {
|
||||
constructor(modelId: string, providerId?: string, cause?: Error) {
|
||||
super(
|
||||
`Failed to resolve image model: ${modelId}${providerId ? ` for provider: ${providerId}` : ''}`,
|
||||
providerId,
|
||||
modelId,
|
||||
cause
|
||||
)
|
||||
this.name = 'ImageModelResolutionError'
|
||||
}
|
||||
}
|
||||
346
packages/aiCore/src/core/runtime/executor.ts
Normal file
346
packages/aiCore/src/core/runtime/executor.ts
Normal file
@@ -0,0 +1,346 @@
|
||||
/**
|
||||
* 运行时执行器
|
||||
* 专注于插件化的AI调用处理
|
||||
*/
|
||||
import { ImageModelV2, LanguageModelV2, LanguageModelV2Middleware } from '@ai-sdk/provider'
|
||||
import {
|
||||
experimental_generateImage as generateImage,
|
||||
generateObject,
|
||||
generateText,
|
||||
LanguageModel,
|
||||
streamObject,
|
||||
streamText
|
||||
} from 'ai'
|
||||
|
||||
import { globalModelResolver } from '../models'
|
||||
import { type ModelConfig } from '../models/types'
|
||||
import { type AiPlugin, type AiRequestContext, definePlugin } from '../plugins'
|
||||
import { type ProviderId } from '../providers'
|
||||
import { ImageGenerationError, ImageModelResolutionError } from './errors'
|
||||
import { PluginEngine } from './pluginEngine'
|
||||
import { type RuntimeConfig } from './types'
|
||||
|
||||
export class RuntimeExecutor<T extends ProviderId = ProviderId> {
|
||||
public pluginEngine: PluginEngine<T>
|
||||
// private options: ProviderSettingsMap[T]
|
||||
private config: RuntimeConfig<T>
|
||||
|
||||
constructor(config: RuntimeConfig<T>) {
|
||||
// if (!isProviderSupported(config.providerId)) {
|
||||
// throw new Error(`Unsupported provider: ${config.providerId}`)
|
||||
// }
|
||||
|
||||
// 存储options供后续使用
|
||||
// this.options = config.options
|
||||
this.config = config
|
||||
// 创建插件客户端
|
||||
this.pluginEngine = new PluginEngine(config.providerId, config.plugins || [])
|
||||
}
|
||||
|
||||
private createResolveModelPlugin(middlewares?: LanguageModelV2Middleware[]) {
|
||||
return definePlugin({
|
||||
name: '_internal_resolveModel',
|
||||
enforce: 'post',
|
||||
|
||||
resolveModel: async (modelId: string) => {
|
||||
// 注意:extraModelConfig 暂时不支持,已在新架构中移除
|
||||
return await this.resolveModel(modelId, middlewares)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
private createResolveImageModelPlugin() {
|
||||
return definePlugin({
|
||||
name: '_internal_resolveImageModel',
|
||||
enforce: 'post',
|
||||
|
||||
resolveModel: async (modelId: string) => {
|
||||
return await this.resolveImageModel(modelId)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
private createConfigureContextPlugin() {
|
||||
return definePlugin({
|
||||
name: '_internal_configureContext',
|
||||
configureContext: async (context: AiRequestContext) => {
|
||||
context.executor = this
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// === 高阶重载:直接使用模型 ===
|
||||
|
||||
/**
|
||||
* 流式文本生成 - 使用已创建的模型(高级用法)
|
||||
*/
|
||||
async streamText(
|
||||
model: LanguageModel,
|
||||
params: Omit<Parameters<typeof streamText>[0], 'model'>
|
||||
): Promise<ReturnType<typeof streamText>>
|
||||
async streamText(
|
||||
modelId: string,
|
||||
params: Omit<Parameters<typeof streamText>[0], 'model'>,
|
||||
options?: {
|
||||
middlewares?: LanguageModelV2Middleware[]
|
||||
}
|
||||
): Promise<ReturnType<typeof streamText>>
|
||||
async streamText(
|
||||
modelOrId: LanguageModel,
|
||||
params: Omit<Parameters<typeof streamText>[0], 'model'>,
|
||||
options?: {
|
||||
middlewares?: LanguageModelV2Middleware[]
|
||||
}
|
||||
): Promise<ReturnType<typeof streamText>> {
|
||||
this.pluginEngine.usePlugins([
|
||||
this.createResolveModelPlugin(options?.middlewares),
|
||||
this.createConfigureContextPlugin()
|
||||
])
|
||||
|
||||
// 2. 执行插件处理
|
||||
return this.pluginEngine.executeStreamWithPlugins(
|
||||
'streamText',
|
||||
typeof modelOrId === 'string' ? modelOrId : modelOrId.modelId,
|
||||
params,
|
||||
async (model, transformedParams, streamTransforms) => {
|
||||
const experimental_transform =
|
||||
params?.experimental_transform ?? (streamTransforms.length > 0 ? streamTransforms : undefined)
|
||||
|
||||
const finalParams = {
|
||||
model,
|
||||
...transformedParams,
|
||||
experimental_transform
|
||||
} as Parameters<typeof streamText>[0]
|
||||
|
||||
return await streamText(finalParams)
|
||||
}
|
||||
)
|
||||
}
|
||||
|
||||
// === 其他方法的重载 ===
|
||||
|
||||
/**
|
||||
* 生成文本 - 使用已创建的模型
|
||||
*/
|
||||
async generateText(
|
||||
model: LanguageModel,
|
||||
params: Omit<Parameters<typeof generateText>[0], 'model'>
|
||||
): Promise<ReturnType<typeof generateText>>
|
||||
async generateText(
|
||||
modelId: string,
|
||||
params: Omit<Parameters<typeof generateText>[0], 'model'>,
|
||||
options?: {
|
||||
middlewares?: LanguageModelV2Middleware[]
|
||||
}
|
||||
): Promise<ReturnType<typeof generateText>>
|
||||
async generateText(
|
||||
modelOrId: LanguageModel | string,
|
||||
params: Omit<Parameters<typeof generateText>[0], 'model'>,
|
||||
options?: {
|
||||
middlewares?: LanguageModelV2Middleware[]
|
||||
}
|
||||
): Promise<ReturnType<typeof generateText>> {
|
||||
this.pluginEngine.usePlugins([
|
||||
this.createResolveModelPlugin(options?.middlewares),
|
||||
this.createConfigureContextPlugin()
|
||||
])
|
||||
|
||||
return this.pluginEngine.executeWithPlugins(
|
||||
'generateText',
|
||||
typeof modelOrId === 'string' ? modelOrId : modelOrId.modelId,
|
||||
params,
|
||||
async (model, transformedParams) =>
|
||||
generateText({ model, ...transformedParams } as Parameters<typeof generateText>[0])
|
||||
)
|
||||
}
|
||||
|
||||
/**
|
||||
* 生成结构化对象 - 使用已创建的模型
|
||||
*/
|
||||
async generateObject(
|
||||
model: LanguageModel,
|
||||
params: Omit<Parameters<typeof generateObject>[0], 'model'>
|
||||
): Promise<ReturnType<typeof generateObject>>
|
||||
async generateObject(
|
||||
modelOrId: string,
|
||||
params: Omit<Parameters<typeof generateObject>[0], 'model'>,
|
||||
options?: {
|
||||
middlewares?: LanguageModelV2Middleware[]
|
||||
}
|
||||
): Promise<ReturnType<typeof generateObject>>
|
||||
async generateObject(
|
||||
modelOrId: LanguageModel | string,
|
||||
params: Omit<Parameters<typeof generateObject>[0], 'model'>,
|
||||
options?: {
|
||||
middlewares?: LanguageModelV2Middleware[]
|
||||
}
|
||||
): Promise<ReturnType<typeof generateObject>> {
|
||||
this.pluginEngine.usePlugins([
|
||||
this.createResolveModelPlugin(options?.middlewares),
|
||||
this.createConfigureContextPlugin()
|
||||
])
|
||||
|
||||
return this.pluginEngine.executeWithPlugins(
|
||||
'generateObject',
|
||||
typeof modelOrId === 'string' ? modelOrId : modelOrId.modelId,
|
||||
params,
|
||||
async (model, transformedParams) =>
|
||||
generateObject({ model, ...transformedParams } as Parameters<typeof generateObject>[0])
|
||||
)
|
||||
}
|
||||
|
||||
/**
|
||||
* 流式生成结构化对象 - 使用已创建的模型
|
||||
*/
|
||||
async streamObject(
|
||||
model: LanguageModel,
|
||||
params: Omit<Parameters<typeof streamObject>[0], 'model'>
|
||||
): Promise<ReturnType<typeof streamObject>>
|
||||
async streamObject(
|
||||
modelId: string,
|
||||
params: Omit<Parameters<typeof streamObject>[0], 'model'>,
|
||||
options?: {
|
||||
middlewares?: LanguageModelV2Middleware[]
|
||||
}
|
||||
): Promise<ReturnType<typeof streamObject>>
|
||||
async streamObject(
|
||||
modelOrId: LanguageModel | string,
|
||||
params: Omit<Parameters<typeof streamObject>[0], 'model'>,
|
||||
options?: {
|
||||
middlewares?: LanguageModelV2Middleware[]
|
||||
}
|
||||
): Promise<ReturnType<typeof streamObject>> {
|
||||
this.pluginEngine.usePlugins([
|
||||
this.createResolveModelPlugin(options?.middlewares),
|
||||
this.createConfigureContextPlugin()
|
||||
])
|
||||
|
||||
return this.pluginEngine.executeWithPlugins(
|
||||
'streamObject',
|
||||
typeof modelOrId === 'string' ? modelOrId : modelOrId.modelId,
|
||||
params,
|
||||
async (model, transformedParams) =>
|
||||
streamObject({ model, ...transformedParams } as Parameters<typeof streamObject>[0])
|
||||
)
|
||||
}
|
||||
|
||||
/**
|
||||
* 生成图像 - 使用已创建的图像模型
|
||||
*/
|
||||
async generateImage(
|
||||
model: ImageModelV2,
|
||||
params: Omit<Parameters<typeof generateImage>[0], 'model'>
|
||||
): Promise<ReturnType<typeof generateImage>>
|
||||
async generateImage(
|
||||
modelId: string,
|
||||
params: Omit<Parameters<typeof generateImage>[0], 'model'>,
|
||||
options?: {
|
||||
middlewares?: LanguageModelV2Middleware[]
|
||||
}
|
||||
): Promise<ReturnType<typeof generateImage>>
|
||||
async generateImage(
|
||||
modelOrId: ImageModelV2 | string,
|
||||
params: Omit<Parameters<typeof generateImage>[0], 'model'>
|
||||
): Promise<ReturnType<typeof generateImage>> {
|
||||
try {
|
||||
this.pluginEngine.usePlugins([this.createResolveImageModelPlugin(), this.createConfigureContextPlugin()])
|
||||
|
||||
return await this.pluginEngine.executeImageWithPlugins(
|
||||
'generateImage',
|
||||
typeof modelOrId === 'string' ? modelOrId : modelOrId.modelId,
|
||||
params,
|
||||
async (model, transformedParams) => {
|
||||
return await generateImage({ model, ...transformedParams })
|
||||
}
|
||||
)
|
||||
} catch (error) {
|
||||
if (error instanceof Error) {
|
||||
throw new ImageGenerationError(
|
||||
`Failed to generate image: ${error.message}`,
|
||||
this.config.providerId,
|
||||
typeof modelOrId === 'string' ? modelOrId : modelOrId.modelId,
|
||||
error
|
||||
)
|
||||
}
|
||||
throw error
|
||||
}
|
||||
}
|
||||
|
||||
// === 辅助方法 ===
|
||||
|
||||
/**
|
||||
* 解析模型:如果是字符串则创建模型,如果是模型则直接返回
|
||||
*/
|
||||
private async resolveModel(
|
||||
modelOrId: LanguageModel,
|
||||
middlewares?: LanguageModelV2Middleware[]
|
||||
): Promise<LanguageModelV2> {
|
||||
if (typeof modelOrId === 'string') {
|
||||
// 🎯 字符串modelId,使用新的ModelResolver解析,传递完整参数
|
||||
return await globalModelResolver.resolveLanguageModel(
|
||||
modelOrId, // 支持 'gpt-4' 和 'aihubmix:anthropic:claude-3.5-sonnet'
|
||||
this.config.providerId, // fallback provider
|
||||
this.config.providerSettings, // provider options
|
||||
middlewares // 中间件数组
|
||||
)
|
||||
} else {
|
||||
// 已经是模型,直接返回
|
||||
return modelOrId
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 解析图像模型:如果是字符串则创建图像模型,如果是模型则直接返回
|
||||
*/
|
||||
private async resolveImageModel(modelOrId: ImageModelV2 | string): Promise<ImageModelV2> {
|
||||
try {
|
||||
if (typeof modelOrId === 'string') {
|
||||
// 字符串modelId,使用新的ModelResolver解析
|
||||
return await globalModelResolver.resolveImageModel(
|
||||
modelOrId, // 支持 'dall-e-3' 和 'aihubmix:openai:dall-e-3'
|
||||
this.config.providerId // fallback provider
|
||||
)
|
||||
} else {
|
||||
// 已经是模型,直接返回
|
||||
return modelOrId
|
||||
}
|
||||
} catch (error) {
|
||||
throw new ImageModelResolutionError(
|
||||
typeof modelOrId === 'string' ? modelOrId : modelOrId.modelId,
|
||||
this.config.providerId,
|
||||
error instanceof Error ? error : undefined
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// === 静态工厂方法 ===
|
||||
|
||||
/**
|
||||
* 创建执行器 - 支持已知provider的类型安全
|
||||
*/
|
||||
static create<T extends ProviderId>(
|
||||
providerId: T,
|
||||
options: ModelConfig<T>['providerSettings'],
|
||||
plugins?: AiPlugin[]
|
||||
): RuntimeExecutor<T> {
|
||||
return new RuntimeExecutor({
|
||||
providerId,
|
||||
providerSettings: options,
|
||||
plugins
|
||||
})
|
||||
}
|
||||
|
||||
/**
|
||||
* 创建OpenAI Compatible执行器
|
||||
*/
|
||||
static createOpenAICompatible(
|
||||
options: ModelConfig<'openai-compatible'>['providerSettings'],
|
||||
plugins: AiPlugin[] = []
|
||||
): RuntimeExecutor<'openai-compatible'> {
|
||||
return new RuntimeExecutor({
|
||||
providerId: 'openai-compatible',
|
||||
providerSettings: options,
|
||||
plugins
|
||||
})
|
||||
}
|
||||
}
|
||||
123
packages/aiCore/src/core/runtime/index.ts
Normal file
123
packages/aiCore/src/core/runtime/index.ts
Normal file
@@ -0,0 +1,123 @@
|
||||
/**
|
||||
* Runtime 模块导出
|
||||
* 专注于运行时插件化AI调用处理
|
||||
*/
|
||||
|
||||
// 主要的运行时执行器
|
||||
export { RuntimeExecutor } from './executor'
|
||||
|
||||
// 导出类型
|
||||
export type { RuntimeConfig } from './types'
|
||||
|
||||
// === 便捷工厂函数 ===
|
||||
|
||||
import { LanguageModelV2Middleware } from '@ai-sdk/provider'
|
||||
|
||||
import { type AiPlugin } from '../plugins'
|
||||
import { type ProviderId, type ProviderSettingsMap } from '../providers/types'
|
||||
import { RuntimeExecutor } from './executor'
|
||||
|
||||
/**
|
||||
* 创建运行时执行器 - 支持类型安全的已知provider
|
||||
*/
|
||||
export function createExecutor<T extends ProviderId>(
|
||||
providerId: T,
|
||||
options: ProviderSettingsMap[T] & { mode?: 'chat' | 'responses' },
|
||||
plugins?: AiPlugin[]
|
||||
): RuntimeExecutor<T> {
|
||||
return RuntimeExecutor.create(providerId, options, plugins)
|
||||
}
|
||||
|
||||
/**
|
||||
* 创建OpenAI Compatible执行器
|
||||
*/
|
||||
export function createOpenAICompatibleExecutor(
|
||||
options: ProviderSettingsMap['openai-compatible'] & { mode?: 'chat' | 'responses' },
|
||||
plugins: AiPlugin[] = []
|
||||
): RuntimeExecutor<'openai-compatible'> {
|
||||
return RuntimeExecutor.createOpenAICompatible(options, plugins)
|
||||
}
|
||||
|
||||
// === 直接调用API(无需创建executor实例)===
|
||||
|
||||
/**
|
||||
* 直接流式文本生成 - 支持middlewares
|
||||
*/
|
||||
export async function streamText<T extends ProviderId>(
|
||||
providerId: T,
|
||||
options: ProviderSettingsMap[T] & { mode?: 'chat' | 'responses' },
|
||||
modelId: string,
|
||||
params: Parameters<RuntimeExecutor<T>['streamText']>[1],
|
||||
plugins?: AiPlugin[],
|
||||
middlewares?: LanguageModelV2Middleware[]
|
||||
): Promise<ReturnType<RuntimeExecutor<T>['streamText']>> {
|
||||
const executor = createExecutor(providerId, options, plugins)
|
||||
return executor.streamText(modelId, params, { middlewares })
|
||||
}
|
||||
|
||||
/**
|
||||
* 直接生成文本 - 支持middlewares
|
||||
*/
|
||||
export async function generateText<T extends ProviderId>(
|
||||
providerId: T,
|
||||
options: ProviderSettingsMap[T] & { mode?: 'chat' | 'responses' },
|
||||
modelId: string,
|
||||
params: Parameters<RuntimeExecutor<T>['generateText']>[1],
|
||||
plugins?: AiPlugin[],
|
||||
middlewares?: LanguageModelV2Middleware[]
|
||||
): Promise<ReturnType<RuntimeExecutor<T>['generateText']>> {
|
||||
const executor = createExecutor(providerId, options, plugins)
|
||||
return executor.generateText(modelId, params, { middlewares })
|
||||
}
|
||||
|
||||
/**
|
||||
* 直接生成结构化对象 - 支持middlewares
|
||||
*/
|
||||
export async function generateObject<T extends ProviderId>(
|
||||
providerId: T,
|
||||
options: ProviderSettingsMap[T] & { mode?: 'chat' | 'responses' },
|
||||
modelId: string,
|
||||
params: Parameters<RuntimeExecutor<T>['generateObject']>[1],
|
||||
plugins?: AiPlugin[],
|
||||
middlewares?: LanguageModelV2Middleware[]
|
||||
): Promise<ReturnType<RuntimeExecutor<T>['generateObject']>> {
|
||||
const executor = createExecutor(providerId, options, plugins)
|
||||
return executor.generateObject(modelId, params, { middlewares })
|
||||
}
|
||||
|
||||
/**
|
||||
* 直接流式生成结构化对象 - 支持middlewares
|
||||
*/
|
||||
export async function streamObject<T extends ProviderId>(
|
||||
providerId: T,
|
||||
options: ProviderSettingsMap[T] & { mode?: 'chat' | 'responses' },
|
||||
modelId: string,
|
||||
params: Parameters<RuntimeExecutor<T>['streamObject']>[1],
|
||||
plugins?: AiPlugin[],
|
||||
middlewares?: LanguageModelV2Middleware[]
|
||||
): Promise<ReturnType<RuntimeExecutor<T>['streamObject']>> {
|
||||
const executor = createExecutor(providerId, options, plugins)
|
||||
return executor.streamObject(modelId, params, { middlewares })
|
||||
}
|
||||
|
||||
/**
|
||||
* 直接生成图像 - 支持middlewares
|
||||
*/
|
||||
export async function generateImage<T extends ProviderId>(
|
||||
providerId: T,
|
||||
options: ProviderSettingsMap[T] & { mode?: 'chat' | 'responses' },
|
||||
modelId: string,
|
||||
params: Parameters<RuntimeExecutor<T>['generateImage']>[1],
|
||||
plugins?: AiPlugin[],
|
||||
middlewares?: LanguageModelV2Middleware[]
|
||||
): Promise<ReturnType<RuntimeExecutor<T>['generateImage']>> {
|
||||
const executor = createExecutor(providerId, options, plugins)
|
||||
return executor.generateImage(modelId, params, { middlewares })
|
||||
}
|
||||
|
||||
// === Agent 功能预留 ===
|
||||
// 未来将在 ../agents/ 文件夹中添加:
|
||||
// - AgentExecutor.ts
|
||||
// - WorkflowManager.ts
|
||||
// - ConversationManager.ts
|
||||
// 并在此处导出相关API
|
||||
231
packages/aiCore/src/core/runtime/pluginEngine.ts
Normal file
231
packages/aiCore/src/core/runtime/pluginEngine.ts
Normal file
@@ -0,0 +1,231 @@
|
||||
/* eslint-disable @eslint-react/naming-convention/context-name */
|
||||
import { ImageModelV2 } from '@ai-sdk/provider'
|
||||
import { LanguageModel } from 'ai'
|
||||
|
||||
import { type AiPlugin, createContext, PluginManager } from '../plugins'
|
||||
import { type ProviderId } from '../providers/types'
|
||||
|
||||
/**
|
||||
* 插件增强的 AI 客户端
|
||||
* 专注于插件处理,不暴露用户API
|
||||
*/
|
||||
export class PluginEngine<T extends ProviderId = ProviderId> {
|
||||
private pluginManager: PluginManager
|
||||
|
||||
constructor(
|
||||
private readonly providerId: T,
|
||||
// private readonly options: ProviderSettingsMap[T],
|
||||
plugins: AiPlugin[] = []
|
||||
) {
|
||||
this.pluginManager = new PluginManager(plugins)
|
||||
}
|
||||
|
||||
/**
|
||||
* 添加插件
|
||||
*/
|
||||
use(plugin: AiPlugin): this {
|
||||
this.pluginManager.use(plugin)
|
||||
return this
|
||||
}
|
||||
|
||||
/**
|
||||
* 批量添加插件
|
||||
*/
|
||||
usePlugins(plugins: AiPlugin[]): this {
|
||||
plugins.forEach((plugin) => this.use(plugin))
|
||||
return this
|
||||
}
|
||||
|
||||
/**
|
||||
* 移除插件
|
||||
*/
|
||||
removePlugin(pluginName: string): this {
|
||||
this.pluginManager.remove(pluginName)
|
||||
return this
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取插件统计
|
||||
*/
|
||||
getPluginStats() {
|
||||
return this.pluginManager.getStats()
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取所有插件
|
||||
*/
|
||||
getPlugins() {
|
||||
return this.pluginManager.getPlugins()
|
||||
}
|
||||
|
||||
/**
|
||||
* 执行带插件的操作(非流式)
|
||||
* 提供给AiExecutor使用
|
||||
*/
|
||||
async executeWithPlugins<TParams, TResult>(
|
||||
methodName: string,
|
||||
modelId: string,
|
||||
params: TParams,
|
||||
executor: (model: LanguageModel, transformedParams: TParams) => Promise<TResult>,
|
||||
_context?: ReturnType<typeof createContext>
|
||||
): Promise<TResult> {
|
||||
// 使用正确的createContext创建请求上下文
|
||||
const context = _context ? _context : createContext(this.providerId, modelId, params)
|
||||
|
||||
// 🔥 为上下文添加递归调用能力
|
||||
context.recursiveCall = async (newParams: any): Promise<TResult> => {
|
||||
// 递归调用自身,重新走完整的插件流程
|
||||
context.isRecursiveCall = true
|
||||
const result = await this.executeWithPlugins(methodName, modelId, newParams, executor, context)
|
||||
context.isRecursiveCall = false
|
||||
return result
|
||||
}
|
||||
|
||||
try {
|
||||
// 0. 配置上下文
|
||||
await this.pluginManager.executeConfigureContext(context)
|
||||
|
||||
// 1. 触发请求开始事件
|
||||
await this.pluginManager.executeParallel('onRequestStart', context)
|
||||
|
||||
// 2. 解析模型
|
||||
const model = await this.pluginManager.executeFirst<LanguageModel>('resolveModel', modelId, context)
|
||||
if (!model) {
|
||||
throw new Error(`Failed to resolve model: ${modelId}`)
|
||||
}
|
||||
|
||||
// 3. 转换请求参数
|
||||
const transformedParams = await this.pluginManager.executeSequential('transformParams', params, context)
|
||||
|
||||
// 4. 执行具体的 API 调用
|
||||
const result = await executor(model, transformedParams)
|
||||
|
||||
// 5. 转换结果(对于非流式调用)
|
||||
const transformedResult = await this.pluginManager.executeSequential('transformResult', result, context)
|
||||
|
||||
// 6. 触发完成事件
|
||||
await this.pluginManager.executeParallel('onRequestEnd', context, transformedResult)
|
||||
|
||||
return transformedResult
|
||||
} catch (error) {
|
||||
// 7. 触发错误事件
|
||||
await this.pluginManager.executeParallel('onError', context, undefined, error as Error)
|
||||
throw error
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 执行带插件的图像生成操作
|
||||
* 提供给AiExecutor使用
|
||||
*/
|
||||
async executeImageWithPlugins<TParams, TResult>(
|
||||
methodName: string,
|
||||
modelId: string,
|
||||
params: TParams,
|
||||
executor: (model: ImageModelV2, transformedParams: TParams) => Promise<TResult>,
|
||||
_context?: ReturnType<typeof createContext>
|
||||
): Promise<TResult> {
|
||||
// 使用正确的createContext创建请求上下文
|
||||
const context = _context ? _context : createContext(this.providerId, modelId, params)
|
||||
|
||||
// 🔥 为上下文添加递归调用能力
|
||||
context.recursiveCall = async (newParams: any): Promise<TResult> => {
|
||||
// 递归调用自身,重新走完整的插件流程
|
||||
context.isRecursiveCall = true
|
||||
const result = await this.executeImageWithPlugins(methodName, modelId, newParams, executor, context)
|
||||
context.isRecursiveCall = false
|
||||
return result
|
||||
}
|
||||
|
||||
try {
|
||||
// 0. 配置上下文
|
||||
await this.pluginManager.executeConfigureContext(context)
|
||||
|
||||
// 1. 触发请求开始事件
|
||||
await this.pluginManager.executeParallel('onRequestStart', context)
|
||||
|
||||
// 2. 解析模型
|
||||
const model = await this.pluginManager.executeFirst<ImageModelV2>('resolveModel', modelId, context)
|
||||
if (!model) {
|
||||
throw new Error(`Failed to resolve image model: ${modelId}`)
|
||||
}
|
||||
|
||||
// 3. 转换请求参数
|
||||
const transformedParams = await this.pluginManager.executeSequential('transformParams', params, context)
|
||||
|
||||
// 4. 执行具体的 API 调用
|
||||
const result = await executor(model, transformedParams)
|
||||
|
||||
// 5. 转换结果
|
||||
const transformedResult = await this.pluginManager.executeSequential('transformResult', result, context)
|
||||
|
||||
// 6. 触发完成事件
|
||||
await this.pluginManager.executeParallel('onRequestEnd', context, transformedResult)
|
||||
|
||||
return transformedResult
|
||||
} catch (error) {
|
||||
// 7. 触发错误事件
|
||||
await this.pluginManager.executeParallel('onError', context, undefined, error as Error)
|
||||
throw error
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 执行流式调用的通用逻辑(支持流转换器)
|
||||
* 提供给AiExecutor使用
|
||||
*/
|
||||
async executeStreamWithPlugins<TParams, TResult>(
|
||||
methodName: string,
|
||||
modelId: string,
|
||||
params: TParams,
|
||||
executor: (model: LanguageModel, transformedParams: TParams, streamTransforms: any[]) => Promise<TResult>,
|
||||
_context?: ReturnType<typeof createContext>
|
||||
): Promise<TResult> {
|
||||
// 创建请求上下文
|
||||
const context = _context ? _context : createContext(this.providerId, modelId, params)
|
||||
|
||||
// 🔥 为上下文添加递归调用能力
|
||||
context.recursiveCall = async (newParams: any): Promise<TResult> => {
|
||||
// 递归调用自身,重新走完整的插件流程
|
||||
context.isRecursiveCall = true
|
||||
const result = await this.executeStreamWithPlugins(methodName, modelId, newParams, executor, context)
|
||||
context.isRecursiveCall = false
|
||||
return result
|
||||
}
|
||||
|
||||
try {
|
||||
// 0. 配置上下文
|
||||
await this.pluginManager.executeConfigureContext(context)
|
||||
|
||||
// 1. 触发请求开始事件
|
||||
await this.pluginManager.executeParallel('onRequestStart', context)
|
||||
|
||||
// 2. 解析模型
|
||||
const model = await this.pluginManager.executeFirst<LanguageModel>('resolveModel', modelId, context)
|
||||
|
||||
if (!model) {
|
||||
throw new Error(`Failed to resolve model: ${modelId}`)
|
||||
}
|
||||
|
||||
// 3. 转换请求参数
|
||||
const transformedParams = await this.pluginManager.executeSequential('transformParams', params, context)
|
||||
|
||||
// 4. 收集流转换器
|
||||
const streamTransforms = this.pluginManager.collectStreamTransforms(transformedParams, context)
|
||||
|
||||
// 5. 执行流式 API 调用
|
||||
const result = await executor(model, transformedParams, streamTransforms)
|
||||
|
||||
const transformedResult = await this.pluginManager.executeSequential('transformResult', result, context)
|
||||
|
||||
// 6. 触发完成事件(注意:对于流式调用,这里触发的是开始流式响应的事件)
|
||||
await this.pluginManager.executeParallel('onRequestEnd', context, transformedResult)
|
||||
|
||||
return transformedResult
|
||||
} catch (error) {
|
||||
// 7. 触发错误事件
|
||||
await this.pluginManager.executeParallel('onError', context, undefined, error as Error)
|
||||
throw error
|
||||
}
|
||||
}
|
||||
}
|
||||
15
packages/aiCore/src/core/runtime/types.ts
Normal file
15
packages/aiCore/src/core/runtime/types.ts
Normal file
@@ -0,0 +1,15 @@
|
||||
/**
|
||||
* Runtime 层类型定义
|
||||
*/
|
||||
import { type ModelConfig } from '../models/types'
|
||||
import { type AiPlugin } from '../plugins'
|
||||
import { type ProviderId } from '../providers/types'
|
||||
|
||||
/**
|
||||
* 运行时执行器配置
|
||||
*/
|
||||
export interface RuntimeConfig<T extends ProviderId = ProviderId> {
|
||||
providerId: T
|
||||
providerSettings: ModelConfig<T>['providerSettings'] & { mode?: 'chat' | 'responses' }
|
||||
plugins?: AiPlugin[]
|
||||
}
|
||||
46
packages/aiCore/src/index.ts
Normal file
46
packages/aiCore/src/index.ts
Normal file
@@ -0,0 +1,46 @@
|
||||
/**
|
||||
* Cherry Studio AI Core Package
|
||||
* 基于 Vercel AI SDK 的统一 AI Provider 接口
|
||||
*/
|
||||
|
||||
// 导入内部使用的类和函数
|
||||
|
||||
// ==================== 主要用户接口 ====================
|
||||
export {
|
||||
createExecutor,
|
||||
createOpenAICompatibleExecutor,
|
||||
generateImage,
|
||||
generateObject,
|
||||
generateText,
|
||||
streamText
|
||||
} from './core/runtime'
|
||||
|
||||
// ==================== 高级API ====================
|
||||
export { globalModelResolver as modelResolver } from './core/models'
|
||||
|
||||
// ==================== 插件系统 ====================
|
||||
export type { AiPlugin, AiRequestContext, HookResult, PluginManagerConfig } from './core/plugins'
|
||||
export { createContext, definePlugin, PluginManager } from './core/plugins'
|
||||
// export { createPromptToolUsePlugin, webSearchPlugin } from './core/plugins/built-in'
|
||||
export { PluginEngine } from './core/runtime/pluginEngine'
|
||||
|
||||
// ==================== AI SDK 常用类型导出 ====================
|
||||
// 直接导出 AI SDK 的常用类型,方便使用
|
||||
export type { LanguageModelV2Middleware, LanguageModelV2StreamPart } from '@ai-sdk/provider'
|
||||
export type { ToolCall } from '@ai-sdk/provider-utils'
|
||||
export type { ReasoningPart } from '@ai-sdk/provider-utils'
|
||||
|
||||
// ==================== 选项 ====================
|
||||
export {
|
||||
createAnthropicOptions,
|
||||
createGoogleOptions,
|
||||
createOpenAIOptions,
|
||||
type ExtractProviderOptions,
|
||||
mergeProviderOptions,
|
||||
type ProviderOptionsMap,
|
||||
type TypedProviderOptions
|
||||
} from './core/options'
|
||||
|
||||
// ==================== 包信息 ====================
|
||||
export const AI_CORE_VERSION = '1.0.0'
|
||||
export const AI_CORE_NAME = '@cherrystudio/ai-core'
|
||||
2
packages/aiCore/src/types.ts
Normal file
2
packages/aiCore/src/types.ts
Normal file
@@ -0,0 +1,2 @@
|
||||
// 重新导出插件类型
|
||||
export type { AiPlugin, AiRequestContext, HookResult, PluginManagerConfig } from './core/plugins/types'
|
||||
49
packages/aiCore/src/utils/model.ts
Normal file
49
packages/aiCore/src/utils/model.ts
Normal file
@@ -0,0 +1,49 @@
|
||||
import { type ProviderId } from '../core/providers/types'
|
||||
|
||||
export function isOpenAIChatCompletionOnlyModel(modelId: string): boolean {
|
||||
if (!modelId) {
|
||||
return false
|
||||
}
|
||||
|
||||
return (
|
||||
modelId.includes('gpt-4o-search-preview') ||
|
||||
modelId.includes('gpt-4o-mini-search-preview') ||
|
||||
modelId.includes('o1-mini') ||
|
||||
modelId.includes('o1-preview')
|
||||
)
|
||||
}
|
||||
|
||||
export function isOpenAIReasoningModel(modelId: string): boolean {
|
||||
return modelId.includes('o1') || modelId.includes('o3') || modelId.includes('o4')
|
||||
}
|
||||
|
||||
export function isOpenAILLMModel(modelId: string): boolean {
|
||||
if (modelId.includes('gpt-4o-image')) {
|
||||
return false
|
||||
}
|
||||
if (isOpenAIReasoningModel(modelId)) {
|
||||
return true
|
||||
}
|
||||
if (modelId.includes('gpt')) {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
export function getModelToProviderId(modelId: string): ProviderId | 'openai-compatible' {
|
||||
const id = modelId.toLowerCase()
|
||||
|
||||
if (id.startsWith('claude')) {
|
||||
return 'anthropic'
|
||||
}
|
||||
|
||||
if ((id.startsWith('gemini') || id.startsWith('imagen')) && !id.endsWith('-nothink') && !id.endsWith('-search')) {
|
||||
return 'google'
|
||||
}
|
||||
|
||||
if (isOpenAILLMModel(modelId)) {
|
||||
return 'openai'
|
||||
}
|
||||
|
||||
return 'openai-compatible'
|
||||
}
|
||||
26
packages/aiCore/tsconfig.json
Normal file
26
packages/aiCore/tsconfig.json
Normal file
@@ -0,0 +1,26 @@
|
||||
{
|
||||
"compilerOptions": {
|
||||
"target": "ES2020",
|
||||
"module": "ESNext",
|
||||
"moduleResolution": "bundler",
|
||||
"declaration": true,
|
||||
"outDir": "./dist",
|
||||
"rootDir": "./src",
|
||||
"strict": true,
|
||||
"esModuleInterop": true,
|
||||
"skipLibCheck": true,
|
||||
"forceConsistentCasingInFileNames": true,
|
||||
"resolveJsonModule": true,
|
||||
"allowSyntheticDefaultImports": true,
|
||||
"noEmitOnError": false,
|
||||
"experimentalDecorators": true,
|
||||
"emitDecoratorMetadata": true
|
||||
},
|
||||
"include": [
|
||||
"src/**/*"
|
||||
],
|
||||
"exclude": [
|
||||
"node_modules",
|
||||
"dist"
|
||||
]
|
||||
}
|
||||
14
packages/aiCore/tsdown.config.ts
Normal file
14
packages/aiCore/tsdown.config.ts
Normal file
@@ -0,0 +1,14 @@
|
||||
import { defineConfig } from 'tsdown'
|
||||
|
||||
export default defineConfig({
|
||||
entry: {
|
||||
index: 'src/index.ts',
|
||||
'built-in/plugins/index': 'src/core/plugins/built-in/index.ts',
|
||||
'provider/index': 'src/core/providers/index.ts'
|
||||
},
|
||||
outDir: 'dist',
|
||||
format: ['esm', 'cjs'],
|
||||
clean: true,
|
||||
dts: true,
|
||||
tsconfig: 'tsconfig.json'
|
||||
})
|
||||
15
packages/aiCore/vitest.config.ts
Normal file
15
packages/aiCore/vitest.config.ts
Normal file
@@ -0,0 +1,15 @@
|
||||
import { defineConfig } from 'vitest/config'
|
||||
|
||||
export default defineConfig({
|
||||
test: {
|
||||
globals: true
|
||||
},
|
||||
resolve: {
|
||||
alias: {
|
||||
'@': './src'
|
||||
}
|
||||
},
|
||||
esbuild: {
|
||||
target: 'node18'
|
||||
}
|
||||
})
|
||||
@@ -570,7 +570,8 @@ class McpService {
|
||||
...tool,
|
||||
id: buildFunctionCallToolName(server.name, tool.name),
|
||||
serverId: server.id,
|
||||
serverName: server.name
|
||||
serverName: server.name,
|
||||
type: 'mcp'
|
||||
}
|
||||
serverTools.push(serverTool)
|
||||
})
|
||||
|
||||
306
src/renderer/src/aiCore/chunk/AiSdkToChunkAdapter.ts
Normal file
306
src/renderer/src/aiCore/chunk/AiSdkToChunkAdapter.ts
Normal file
@@ -0,0 +1,306 @@
|
||||
/**
|
||||
* AI SDK 到 Cherry Studio Chunk 适配器
|
||||
* 用于将 AI SDK 的 fullStream 转换为 Cherry Studio 的 chunk 格式
|
||||
*/
|
||||
|
||||
import { loggerService } from '@logger'
|
||||
import { MCPTool, WebSearchResults, WebSearchSource } from '@renderer/types'
|
||||
import { Chunk, ChunkType } from '@renderer/types/chunk'
|
||||
import type { TextStreamPart, ToolSet } from 'ai'
|
||||
|
||||
import { ToolCallChunkHandler } from './handleToolCallChunk'
|
||||
|
||||
const logger = loggerService.withContext('AiSdkToChunkAdapter')
|
||||
|
||||
export interface CherryStudioChunk {
|
||||
type: 'text-delta' | 'text-complete' | 'tool-call' | 'tool-result' | 'finish' | 'error'
|
||||
text?: string
|
||||
toolCall?: any
|
||||
toolResult?: any
|
||||
finishReason?: string
|
||||
usage?: any
|
||||
error?: any
|
||||
}
|
||||
|
||||
/**
|
||||
* AI SDK 到 Cherry Studio Chunk 适配器类
|
||||
* 处理 fullStream 到 Cherry Studio chunk 的转换
|
||||
*/
|
||||
export class AiSdkToChunkAdapter {
|
||||
toolCallHandler: ToolCallChunkHandler
|
||||
constructor(
|
||||
private onChunk: (chunk: Chunk) => void,
|
||||
mcpTools: MCPTool[] = []
|
||||
) {
|
||||
this.toolCallHandler = new ToolCallChunkHandler(onChunk, mcpTools)
|
||||
}
|
||||
|
||||
/**
|
||||
* 处理 AI SDK 流结果
|
||||
* @param aiSdkResult AI SDK 的流结果对象
|
||||
* @returns 最终的文本内容
|
||||
*/
|
||||
async processStream(aiSdkResult: any): Promise<string> {
|
||||
// 如果是流式且有 fullStream
|
||||
if (aiSdkResult.fullStream) {
|
||||
await this.readFullStream(aiSdkResult.fullStream)
|
||||
}
|
||||
|
||||
// 使用 streamResult.text 获取最终结果
|
||||
return await aiSdkResult.text
|
||||
}
|
||||
|
||||
/**
|
||||
* 读取 fullStream 并转换为 Cherry Studio chunks
|
||||
* @param fullStream AI SDK 的 fullStream (ReadableStream)
|
||||
*/
|
||||
private async readFullStream(fullStream: ReadableStream<TextStreamPart<ToolSet>>) {
|
||||
const reader = fullStream.getReader()
|
||||
const final = {
|
||||
text: '',
|
||||
reasoningContent: '',
|
||||
webSearchResults: [],
|
||||
reasoningId: ''
|
||||
}
|
||||
try {
|
||||
while (true) {
|
||||
const { done, value } = await reader.read()
|
||||
|
||||
if (done) {
|
||||
break
|
||||
}
|
||||
|
||||
// 转换并发送 chunk
|
||||
this.convertAndEmitChunk(value, final)
|
||||
}
|
||||
} finally {
|
||||
reader.releaseLock()
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 转换 AI SDK chunk 为 Cherry Studio chunk 并调用回调
|
||||
* @param chunk AI SDK 的 chunk 数据
|
||||
*/
|
||||
private convertAndEmitChunk(
|
||||
chunk: TextStreamPart<any>,
|
||||
final: { text: string; reasoningContent: string; webSearchResults: any[]; reasoningId: string }
|
||||
) {
|
||||
logger.info(`AI SDK chunk type: ${chunk.type}`, chunk)
|
||||
switch (chunk.type) {
|
||||
// === 文本相关事件 ===
|
||||
case 'text-start':
|
||||
this.onChunk({
|
||||
type: ChunkType.TEXT_START
|
||||
})
|
||||
break
|
||||
case 'text-delta':
|
||||
final.text += chunk.text || ''
|
||||
this.onChunk({
|
||||
type: ChunkType.TEXT_DELTA,
|
||||
text: final.text || ''
|
||||
})
|
||||
break
|
||||
case 'text-end':
|
||||
this.onChunk({
|
||||
type: ChunkType.TEXT_COMPLETE,
|
||||
text: (chunk.providerMetadata?.text?.value as string) ?? final.text ?? ''
|
||||
})
|
||||
final.text = ''
|
||||
break
|
||||
case 'reasoning-start':
|
||||
// if (final.reasoningId !== chunk.id) {
|
||||
final.reasoningId = chunk.id
|
||||
this.onChunk({
|
||||
type: ChunkType.THINKING_START
|
||||
})
|
||||
// }
|
||||
break
|
||||
case 'reasoning-delta':
|
||||
final.reasoningContent += chunk.text || ''
|
||||
this.onChunk({
|
||||
type: ChunkType.THINKING_DELTA,
|
||||
text: final.reasoningContent || '',
|
||||
thinking_millsec: (chunk.providerMetadata?.metadata?.thinking_millsec as number) || 0
|
||||
})
|
||||
break
|
||||
case 'reasoning-end':
|
||||
this.onChunk({
|
||||
type: ChunkType.THINKING_COMPLETE,
|
||||
text: (chunk.providerMetadata?.metadata?.thinking_content as string) || '',
|
||||
thinking_millsec: (chunk.providerMetadata?.metadata?.thinking_millsec as number) || 0
|
||||
})
|
||||
final.reasoningContent = ''
|
||||
break
|
||||
|
||||
// === 工具调用相关事件(原始 AI SDK 事件,如果没有被中间件处理) ===
|
||||
|
||||
// case 'tool-input-start':
|
||||
// case 'tool-input-delta':
|
||||
// case 'tool-input-end':
|
||||
// this.toolCallHandler.handleToolCallCreated(chunk)
|
||||
// break
|
||||
|
||||
// case 'tool-input-delta':
|
||||
// this.toolCallHandler.handleToolCallCreated(chunk)
|
||||
// break
|
||||
case 'tool-call':
|
||||
// 原始的工具调用(未被中间件处理)
|
||||
this.toolCallHandler.handleToolCall(chunk)
|
||||
break
|
||||
|
||||
case 'tool-result':
|
||||
// 原始的工具调用结果(未被中间件处理)
|
||||
this.toolCallHandler.handleToolResult(chunk)
|
||||
break
|
||||
|
||||
// === 步骤相关事件 ===
|
||||
// case 'start':
|
||||
// this.onChunk({
|
||||
// type: ChunkType.LLM_RESPONSE_CREATED
|
||||
// })
|
||||
// break
|
||||
// TODO: 需要区分接口开始和步骤开始
|
||||
// case 'start-step':
|
||||
// this.onChunk({
|
||||
// type: ChunkType.BLOCK_CREATED
|
||||
// })
|
||||
// break
|
||||
// case 'step-finish':
|
||||
// this.onChunk({
|
||||
// type: ChunkType.TEXT_COMPLETE,
|
||||
// text: final.text || '' // TEXT_COMPLETE 需要 text 字段
|
||||
// })
|
||||
// final.text = ''
|
||||
// break
|
||||
|
||||
case 'finish-step': {
|
||||
const { providerMetadata, finishReason } = chunk
|
||||
// googel web search
|
||||
if (providerMetadata?.google?.groundingMetadata) {
|
||||
this.onChunk({
|
||||
type: ChunkType.LLM_WEB_SEARCH_COMPLETE,
|
||||
llm_web_search: {
|
||||
results: providerMetadata.google?.groundingMetadata as WebSearchResults,
|
||||
source: WebSearchSource.GEMINI
|
||||
}
|
||||
})
|
||||
} else if (final.webSearchResults.length) {
|
||||
const providerName = Object.keys(providerMetadata || {})[0]
|
||||
switch (providerName) {
|
||||
case WebSearchSource.OPENAI:
|
||||
this.onChunk({
|
||||
type: ChunkType.LLM_WEB_SEARCH_COMPLETE,
|
||||
llm_web_search: {
|
||||
results: final.webSearchResults,
|
||||
source: WebSearchSource.OPENAI_RESPONSE
|
||||
}
|
||||
})
|
||||
break
|
||||
default:
|
||||
this.onChunk({
|
||||
type: ChunkType.LLM_WEB_SEARCH_COMPLETE,
|
||||
llm_web_search: {
|
||||
results: final.webSearchResults,
|
||||
source: WebSearchSource.AISDK
|
||||
}
|
||||
})
|
||||
break
|
||||
}
|
||||
}
|
||||
if (finishReason === 'tool-calls') {
|
||||
this.onChunk({ type: ChunkType.LLM_RESPONSE_CREATED })
|
||||
}
|
||||
|
||||
final.webSearchResults = []
|
||||
// final.reasoningId = ''
|
||||
break
|
||||
}
|
||||
|
||||
case 'finish':
|
||||
this.onChunk({
|
||||
type: ChunkType.BLOCK_COMPLETE,
|
||||
response: {
|
||||
text: final.text || '',
|
||||
reasoning_content: final.reasoningContent || '',
|
||||
usage: {
|
||||
completion_tokens: chunk.totalUsage.outputTokens || 0,
|
||||
prompt_tokens: chunk.totalUsage.inputTokens || 0,
|
||||
total_tokens: chunk.totalUsage.totalTokens || 0
|
||||
},
|
||||
metrics: chunk.totalUsage
|
||||
? {
|
||||
completion_tokens: chunk.totalUsage.outputTokens || 0,
|
||||
time_completion_millsec: 0
|
||||
}
|
||||
: undefined
|
||||
}
|
||||
})
|
||||
this.onChunk({
|
||||
type: ChunkType.LLM_RESPONSE_COMPLETE,
|
||||
response: {
|
||||
text: final.text || '',
|
||||
reasoning_content: final.reasoningContent || '',
|
||||
usage: {
|
||||
completion_tokens: chunk.totalUsage.outputTokens || 0,
|
||||
prompt_tokens: chunk.totalUsage.inputTokens || 0,
|
||||
total_tokens: chunk.totalUsage.totalTokens || 0
|
||||
},
|
||||
metrics: chunk.totalUsage
|
||||
? {
|
||||
completion_tokens: chunk.totalUsage.outputTokens || 0,
|
||||
time_completion_millsec: 0
|
||||
}
|
||||
: undefined
|
||||
}
|
||||
})
|
||||
break
|
||||
|
||||
// === 源和文件相关事件 ===
|
||||
case 'source':
|
||||
if (chunk.sourceType === 'url') {
|
||||
// if (final.webSearchResults.length === 0) {
|
||||
// eslint-disable-next-line @typescript-eslint/no-unused-vars
|
||||
const { sourceType: _, ...rest } = chunk
|
||||
final.webSearchResults.push(rest)
|
||||
// }
|
||||
// this.onChunk({
|
||||
// type: ChunkType.LLM_WEB_SEARCH_COMPLETE,
|
||||
// llm_web_search: {
|
||||
// source: WebSearchSource.AISDK,
|
||||
// results: final.webSearchResults
|
||||
// }
|
||||
// })
|
||||
}
|
||||
break
|
||||
case 'file':
|
||||
// 文件相关事件,可能是图片生成
|
||||
this.onChunk({
|
||||
type: ChunkType.IMAGE_COMPLETE,
|
||||
image: {
|
||||
type: 'base64',
|
||||
images: [`data:${chunk.file.mediaType};base64,${chunk.file.base64}`]
|
||||
}
|
||||
})
|
||||
break
|
||||
case 'abort':
|
||||
this.onChunk({
|
||||
type: ChunkType.ERROR,
|
||||
error: new DOMException('Request was aborted', 'AbortError')
|
||||
})
|
||||
break
|
||||
case 'error':
|
||||
this.onChunk({
|
||||
type: ChunkType.ERROR,
|
||||
error: chunk.error as Record<string, any>
|
||||
})
|
||||
break
|
||||
|
||||
default:
|
||||
// 其他类型的 chunk 可以忽略或记录日志
|
||||
// console.log('Unhandled AI SDK chunk type:', chunk.type, chunk)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
export default AiSdkToChunkAdapter
|
||||
266
src/renderer/src/aiCore/chunk/handleToolCallChunk.ts
Normal file
266
src/renderer/src/aiCore/chunk/handleToolCallChunk.ts
Normal file
@@ -0,0 +1,266 @@
|
||||
/**
|
||||
* 工具调用 Chunk 处理模块
|
||||
* TODO: Tool包含了providerTool和普通的Tool还有MCPTool,后面需要重构
|
||||
* 提供工具调用相关的处理API,每个交互使用一个新的实例
|
||||
*/
|
||||
|
||||
import { loggerService } from '@logger'
|
||||
import { BaseTool, MCPTool, MCPToolResponse, NormalToolResponse } from '@renderer/types'
|
||||
import { Chunk, ChunkType } from '@renderer/types/chunk'
|
||||
import type { ProviderMetadata, ToolSet, TypedToolCall, TypedToolResult } from 'ai'
|
||||
// import type {
|
||||
// AnthropicSearchOutput,
|
||||
// WebSearchPluginConfig
|
||||
// } from '@cherrystudio/ai-core/core/plugins/built-in/webSearchPlugin'
|
||||
|
||||
const logger = loggerService.withContext('ToolCallChunkHandler')
|
||||
|
||||
/**
|
||||
* 工具调用处理器类
|
||||
*/
|
||||
export class ToolCallChunkHandler {
|
||||
// private onChunk: (chunk: Chunk) => void
|
||||
private activeToolCalls = new Map<
|
||||
string,
|
||||
{
|
||||
toolCallId: string
|
||||
toolName: string
|
||||
args: any
|
||||
// mcpTool 现在可以是 MCPTool 或我们为 Provider 工具创建的通用类型
|
||||
tool: BaseTool
|
||||
}
|
||||
>()
|
||||
constructor(
|
||||
private onChunk: (chunk: Chunk) => void,
|
||||
private mcpTools: MCPTool[]
|
||||
) {}
|
||||
|
||||
// /**
|
||||
// * 设置 onChunk 回调
|
||||
// */
|
||||
// public setOnChunk(callback: (chunk: Chunk) => void): void {
|
||||
// this.onChunk = callback
|
||||
// }
|
||||
|
||||
handleToolCallCreated(
|
||||
chunk:
|
||||
| {
|
||||
type: 'tool-input-start'
|
||||
id: string
|
||||
toolName: string
|
||||
providerMetadata?: ProviderMetadata
|
||||
providerExecuted?: boolean
|
||||
}
|
||||
| {
|
||||
type: 'tool-input-end'
|
||||
id: string
|
||||
providerMetadata?: ProviderMetadata
|
||||
}
|
||||
| {
|
||||
type: 'tool-input-delta'
|
||||
id: string
|
||||
delta: string
|
||||
providerMetadata?: ProviderMetadata
|
||||
}
|
||||
): void {
|
||||
switch (chunk.type) {
|
||||
case 'tool-input-start': {
|
||||
// 能拿到说明是mcpTool
|
||||
// if (this.activeToolCalls.get(chunk.id)) return
|
||||
|
||||
const tool: BaseTool | MCPTool = {
|
||||
id: chunk.id,
|
||||
name: chunk.toolName,
|
||||
description: chunk.toolName,
|
||||
type: chunk.toolName.startsWith('builtin_') ? 'builtin' : 'provider'
|
||||
}
|
||||
this.activeToolCalls.set(chunk.id, {
|
||||
toolCallId: chunk.id,
|
||||
toolName: chunk.toolName,
|
||||
args: '',
|
||||
tool
|
||||
})
|
||||
const toolResponse: MCPToolResponse | NormalToolResponse = {
|
||||
id: chunk.id,
|
||||
tool: tool,
|
||||
arguments: {},
|
||||
status: 'pending',
|
||||
toolCallId: chunk.id
|
||||
}
|
||||
this.onChunk({
|
||||
type: ChunkType.MCP_TOOL_PENDING,
|
||||
responses: [toolResponse]
|
||||
})
|
||||
break
|
||||
}
|
||||
case 'tool-input-delta': {
|
||||
const toolCall = this.activeToolCalls.get(chunk.id)
|
||||
if (!toolCall) {
|
||||
logger.warn(`🔧 [ToolCallChunkHandler] Tool call not found: ${chunk.id}`)
|
||||
return
|
||||
}
|
||||
toolCall.args += chunk.delta
|
||||
break
|
||||
}
|
||||
case 'tool-input-end': {
|
||||
const toolCall = this.activeToolCalls.get(chunk.id)
|
||||
this.activeToolCalls.delete(chunk.id)
|
||||
if (!toolCall) {
|
||||
logger.warn(`🔧 [ToolCallChunkHandler] Tool call not found: ${chunk.id}`)
|
||||
return
|
||||
}
|
||||
// const toolResponse: ToolCallResponse = {
|
||||
// id: toolCall.toolCallId,
|
||||
// tool: toolCall.tool,
|
||||
// arguments: toolCall.args,
|
||||
// status: 'pending',
|
||||
// toolCallId: toolCall.toolCallId
|
||||
// }
|
||||
// logger.debug('toolResponse', toolResponse)
|
||||
// this.onChunk({
|
||||
// type: ChunkType.MCP_TOOL_PENDING,
|
||||
// responses: [toolResponse]
|
||||
// })
|
||||
break
|
||||
}
|
||||
}
|
||||
// if (!toolCall) {
|
||||
// Logger.warn(`🔧 [ToolCallChunkHandler] Tool call not found: ${chunk.id}`)
|
||||
// return
|
||||
// }
|
||||
// this.onChunk({
|
||||
// type: ChunkType.MCP_TOOL_CREATED,
|
||||
// tool_calls: [
|
||||
// {
|
||||
// id: chunk.id,
|
||||
// name: chunk.toolName,
|
||||
// status: 'pending'
|
||||
// }
|
||||
// ]
|
||||
// })
|
||||
}
|
||||
|
||||
/**
|
||||
* 处理工具调用事件
|
||||
*/
|
||||
public handleToolCall(
|
||||
chunk: {
|
||||
type: 'tool-call'
|
||||
} & TypedToolCall<ToolSet>
|
||||
): void {
|
||||
const { toolCallId, toolName, input: args, providerExecuted } = chunk
|
||||
|
||||
if (!toolCallId || !toolName) {
|
||||
logger.warn(`🔧 [ToolCallChunkHandler] Invalid tool call chunk: missing toolCallId or toolName`)
|
||||
return
|
||||
}
|
||||
|
||||
let tool: BaseTool
|
||||
let mcpTool: MCPTool | undefined
|
||||
|
||||
// 根据 providerExecuted 标志区分处理逻辑
|
||||
if (providerExecuted) {
|
||||
// 如果是 Provider 执行的工具(如 web_search)
|
||||
logger.info(`[ToolCallChunkHandler] Handling provider-executed tool: ${toolName}`)
|
||||
tool = {
|
||||
id: toolCallId,
|
||||
name: toolName,
|
||||
description: toolName,
|
||||
type: 'provider'
|
||||
} as BaseTool
|
||||
} else if (toolName.startsWith('builtin_')) {
|
||||
// 如果是内置工具,沿用现有逻辑
|
||||
logger.info(`[ToolCallChunkHandler] Handling builtin tool: ${toolName}`)
|
||||
tool = {
|
||||
id: toolCallId,
|
||||
name: toolName,
|
||||
description: toolName,
|
||||
type: 'builtin'
|
||||
} as BaseTool
|
||||
} else if ((mcpTool = this.mcpTools.find((t) => t.name === toolName) as MCPTool)) {
|
||||
// 如果是客户端执行的 MCP 工具,沿用现有逻辑
|
||||
logger.info(`[ToolCallChunkHandler] Handling client-side MCP tool: ${toolName}`)
|
||||
// mcpTool = this.mcpTools.find((t) => t.name === toolName) as MCPTool
|
||||
// if (!mcpTool) {
|
||||
// logger.warn(`[ToolCallChunkHandler] MCP tool not found: ${toolName}`)
|
||||
// return
|
||||
// }
|
||||
tool = mcpTool
|
||||
} else {
|
||||
tool = {
|
||||
id: toolCallId,
|
||||
name: toolName,
|
||||
description: toolName,
|
||||
type: 'provider'
|
||||
}
|
||||
}
|
||||
|
||||
// 记录活跃的工具调用
|
||||
this.activeToolCalls.set(toolCallId, {
|
||||
toolCallId,
|
||||
toolName,
|
||||
args,
|
||||
tool
|
||||
})
|
||||
|
||||
// 创建 MCPToolResponse 格式
|
||||
const toolResponse: MCPToolResponse | NormalToolResponse = {
|
||||
id: toolCallId,
|
||||
tool: tool,
|
||||
arguments: args,
|
||||
status: 'pending',
|
||||
toolCallId: toolCallId
|
||||
}
|
||||
|
||||
// 调用 onChunk
|
||||
if (this.onChunk) {
|
||||
this.onChunk({
|
||||
type: ChunkType.MCP_TOOL_PENDING,
|
||||
responses: [toolResponse]
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 处理工具调用结果事件
|
||||
*/
|
||||
public handleToolResult(
|
||||
chunk: {
|
||||
type: 'tool-result'
|
||||
} & TypedToolResult<ToolSet>
|
||||
): void {
|
||||
const { toolCallId, output, input } = chunk
|
||||
|
||||
if (!toolCallId) {
|
||||
logger.warn(`🔧 [ToolCallChunkHandler] Invalid tool result chunk: missing toolCallId`)
|
||||
return
|
||||
}
|
||||
|
||||
// 查找对应的工具调用信息
|
||||
const toolCallInfo = this.activeToolCalls.get(toolCallId)
|
||||
if (!toolCallInfo) {
|
||||
logger.warn(`🔧 [ToolCallChunkHandler] Tool call info not found for ID: ${toolCallId}`)
|
||||
return
|
||||
}
|
||||
|
||||
// 创建工具调用结果的 MCPToolResponse 格式
|
||||
const toolResponse: MCPToolResponse | NormalToolResponse = {
|
||||
id: toolCallInfo.toolCallId,
|
||||
tool: toolCallInfo.tool,
|
||||
arguments: input,
|
||||
status: 'done',
|
||||
response: output,
|
||||
toolCallId: toolCallId
|
||||
}
|
||||
// 从活跃调用中移除(交互结束后整个实例会被丢弃)
|
||||
this.activeToolCalls.delete(toolCallId)
|
||||
|
||||
// 调用 onChunk
|
||||
if (this.onChunk) {
|
||||
this.onChunk({
|
||||
type: ChunkType.MCP_TOOL_COMPLETE,
|
||||
responses: [toolResponse]
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,189 +1,16 @@
|
||||
import { loggerService } from '@logger'
|
||||
import { ApiClientFactory } from '@renderer/aiCore/clients/ApiClientFactory'
|
||||
import { BaseApiClient } from '@renderer/aiCore/clients/BaseApiClient'
|
||||
import { isDedicatedImageGenerationModel, isFunctionCallingModel } from '@renderer/config/models'
|
||||
import { getProviderByModel } from '@renderer/services/AssistantService'
|
||||
import { withSpanResult } from '@renderer/services/SpanManagerService'
|
||||
import { StartSpanParams } from '@renderer/trace/types/ModelSpanEntity'
|
||||
import type { GenerateImageParams, Model, Provider } from '@renderer/types'
|
||||
import type { RequestOptions, SdkModel } from '@renderer/types/sdk'
|
||||
import { isEnabledToolUse } from '@renderer/utils/mcp-tools'
|
||||
/**
|
||||
* Cherry Studio AI Core - 统一入口点
|
||||
*
|
||||
* 这是新的统一入口,保持向后兼容性
|
||||
* 默认导出legacy AiProvider以保持现有代码的兼容性
|
||||
*/
|
||||
|
||||
import { AihubmixAPIClient } from './clients/AihubmixAPIClient'
|
||||
import { VertexAPIClient } from './clients/gemini/VertexAPIClient'
|
||||
import { NewAPIClient } from './clients/NewAPIClient'
|
||||
import { OpenAIResponseAPIClient } from './clients/openai/OpenAIResponseAPIClient'
|
||||
import { CompletionsMiddlewareBuilder } from './middleware/builder'
|
||||
import { MIDDLEWARE_NAME as AbortHandlerMiddlewareName } from './middleware/common/AbortHandlerMiddleware'
|
||||
import { MIDDLEWARE_NAME as ErrorHandlerMiddlewareName } from './middleware/common/ErrorHandlerMiddleware'
|
||||
import { MIDDLEWARE_NAME as FinalChunkConsumerMiddlewareName } from './middleware/common/FinalChunkConsumerMiddleware'
|
||||
import { applyCompletionsMiddlewares } from './middleware/composer'
|
||||
import { MIDDLEWARE_NAME as McpToolChunkMiddlewareName } from './middleware/core/McpToolChunkMiddleware'
|
||||
import { MIDDLEWARE_NAME as RawStreamListenerMiddlewareName } from './middleware/core/RawStreamListenerMiddleware'
|
||||
import { MIDDLEWARE_NAME as WebSearchMiddlewareName } from './middleware/core/WebSearchMiddleware'
|
||||
import { MIDDLEWARE_NAME as ImageGenerationMiddlewareName } from './middleware/feat/ImageGenerationMiddleware'
|
||||
import { MIDDLEWARE_NAME as ThinkingTagExtractionMiddlewareName } from './middleware/feat/ThinkingTagExtractionMiddleware'
|
||||
import { MIDDLEWARE_NAME as ToolUseExtractionMiddlewareName } from './middleware/feat/ToolUseExtractionMiddleware'
|
||||
import { MiddlewareRegistry } from './middleware/register'
|
||||
import type { CompletionsParams, CompletionsResult } from './middleware/schemas'
|
||||
// 导出Legacy AiProvider作为默认导出(保持向后兼容)
|
||||
export { default } from './legacy/index'
|
||||
|
||||
const logger = loggerService.withContext('AiProvider')
|
||||
// 同时导出Modern AiProvider供新代码使用
|
||||
export { default as ModernAiProvider } from './index_new'
|
||||
|
||||
export default class AiProvider {
|
||||
private apiClient: BaseApiClient
|
||||
|
||||
constructor(provider: Provider) {
|
||||
// Use the new ApiClientFactory to get a BaseApiClient instance
|
||||
this.apiClient = ApiClientFactory.create(provider)
|
||||
}
|
||||
|
||||
public async completions(params: CompletionsParams, options?: RequestOptions): Promise<CompletionsResult> {
|
||||
// 1. 根据模型识别正确的客户端
|
||||
const model = params.assistant.model
|
||||
if (!model) {
|
||||
return Promise.reject(new Error('Model is required'))
|
||||
}
|
||||
|
||||
// 根据client类型选择合适的处理方式
|
||||
let client: BaseApiClient
|
||||
|
||||
if (this.apiClient instanceof AihubmixAPIClient) {
|
||||
// AihubmixAPIClient: 根据模型选择合适的子client
|
||||
client = this.apiClient.getClientForModel(model)
|
||||
if (client instanceof OpenAIResponseAPIClient) {
|
||||
client = client.getClient(model) as BaseApiClient
|
||||
}
|
||||
} else if (this.apiClient instanceof NewAPIClient) {
|
||||
client = this.apiClient.getClientForModel(model)
|
||||
if (client instanceof OpenAIResponseAPIClient) {
|
||||
client = client.getClient(model) as BaseApiClient
|
||||
}
|
||||
} else if (this.apiClient instanceof OpenAIResponseAPIClient) {
|
||||
// OpenAIResponseAPIClient: 根据模型特征选择API类型
|
||||
client = this.apiClient.getClient(model) as BaseApiClient
|
||||
} else if (this.apiClient instanceof VertexAPIClient) {
|
||||
client = this.apiClient.getClient(model) as BaseApiClient
|
||||
} else {
|
||||
// 其他client直接使用
|
||||
client = this.apiClient
|
||||
}
|
||||
|
||||
// 2. 构建中间件链
|
||||
const builder = CompletionsMiddlewareBuilder.withDefaults()
|
||||
// images api
|
||||
if (isDedicatedImageGenerationModel(model)) {
|
||||
builder.clear()
|
||||
builder
|
||||
.add(MiddlewareRegistry[FinalChunkConsumerMiddlewareName])
|
||||
.add(MiddlewareRegistry[ErrorHandlerMiddlewareName])
|
||||
.add(MiddlewareRegistry[AbortHandlerMiddlewareName])
|
||||
.add(MiddlewareRegistry[ImageGenerationMiddlewareName])
|
||||
} else {
|
||||
// Existing logic for other models
|
||||
logger.silly('Builder Params', params)
|
||||
// 使用兼容性类型检查,避免typescript类型收窄和装饰器模式的问题
|
||||
const clientTypes = client.getClientCompatibilityType(model)
|
||||
const isOpenAICompatible =
|
||||
clientTypes.includes('OpenAIAPIClient') || clientTypes.includes('OpenAIResponseAPIClient')
|
||||
if (!isOpenAICompatible) {
|
||||
logger.silly('ThinkingTagExtractionMiddleware is removed')
|
||||
builder.remove(ThinkingTagExtractionMiddlewareName)
|
||||
}
|
||||
|
||||
const isAnthropicOrOpenAIResponseCompatible =
|
||||
clientTypes.includes('AnthropicAPIClient') ||
|
||||
clientTypes.includes('OpenAIResponseAPIClient') ||
|
||||
clientTypes.includes('AnthropicVertexAPIClient')
|
||||
if (!isAnthropicOrOpenAIResponseCompatible) {
|
||||
logger.silly('RawStreamListenerMiddleware is removed')
|
||||
builder.remove(RawStreamListenerMiddlewareName)
|
||||
}
|
||||
if (!params.enableWebSearch) {
|
||||
logger.silly('WebSearchMiddleware is removed')
|
||||
builder.remove(WebSearchMiddlewareName)
|
||||
}
|
||||
if (!params.mcpTools?.length) {
|
||||
builder.remove(ToolUseExtractionMiddlewareName)
|
||||
logger.silly('ToolUseExtractionMiddleware is removed')
|
||||
builder.remove(McpToolChunkMiddlewareName)
|
||||
logger.silly('McpToolChunkMiddleware is removed')
|
||||
}
|
||||
if (isEnabledToolUse(params.assistant) && isFunctionCallingModel(model)) {
|
||||
builder.remove(ToolUseExtractionMiddlewareName)
|
||||
logger.silly('ToolUseExtractionMiddleware is removed')
|
||||
}
|
||||
if (params.callType !== 'chat' && params.callType !== 'check' && params.callType !== 'translate') {
|
||||
logger.silly('AbortHandlerMiddleware is removed')
|
||||
builder.remove(AbortHandlerMiddlewareName)
|
||||
}
|
||||
if (params.callType === 'test') {
|
||||
builder.remove(ErrorHandlerMiddlewareName)
|
||||
logger.silly('ErrorHandlerMiddleware is removed')
|
||||
builder.remove(FinalChunkConsumerMiddlewareName)
|
||||
logger.silly('FinalChunkConsumerMiddleware is removed')
|
||||
}
|
||||
}
|
||||
|
||||
const middlewares = builder.build()
|
||||
logger.silly(
|
||||
'middlewares',
|
||||
middlewares.map((m) => m.name)
|
||||
)
|
||||
|
||||
// 3. Create the wrapped SDK method with middlewares
|
||||
const wrappedCompletionMethod = applyCompletionsMiddlewares(client, client.createCompletions, middlewares)
|
||||
|
||||
// 4. Execute the wrapped method with the original params
|
||||
const result = wrappedCompletionMethod(params, options)
|
||||
return result
|
||||
}
|
||||
|
||||
public async completionsForTrace(params: CompletionsParams, options?: RequestOptions): Promise<CompletionsResult> {
|
||||
const traceName = params.assistant.model?.name
|
||||
? `${params.assistant.model?.name}.${params.callType}`
|
||||
: `LLM.${params.callType}`
|
||||
|
||||
const traceParams: StartSpanParams = {
|
||||
name: traceName,
|
||||
tag: 'LLM',
|
||||
topicId: params.topicId || '',
|
||||
modelName: params.assistant.model?.name
|
||||
}
|
||||
|
||||
return await withSpanResult(this.completions.bind(this), traceParams, params, options)
|
||||
}
|
||||
|
||||
public async models(): Promise<SdkModel[]> {
|
||||
return this.apiClient.listModels()
|
||||
}
|
||||
|
||||
public async getEmbeddingDimensions(model: Model): Promise<number> {
|
||||
try {
|
||||
// Use the SDK instance to test embedding capabilities
|
||||
if (this.apiClient instanceof OpenAIResponseAPIClient && getProviderByModel(model).type === 'azure-openai') {
|
||||
this.apiClient = this.apiClient.getClient(model) as BaseApiClient
|
||||
}
|
||||
const dimensions = await this.apiClient.getEmbeddingDimensions(model)
|
||||
return dimensions
|
||||
} catch (error) {
|
||||
logger.error('Error getting embedding dimensions:', error as Error)
|
||||
throw error
|
||||
}
|
||||
}
|
||||
|
||||
public async generateImage(params: GenerateImageParams): Promise<string[]> {
|
||||
if (this.apiClient instanceof AihubmixAPIClient) {
|
||||
const client = this.apiClient.getClientForModel({ id: params.model } as Model)
|
||||
return client.generateImage(params)
|
||||
}
|
||||
return this.apiClient.generateImage(params)
|
||||
}
|
||||
|
||||
public getBaseURL(): string {
|
||||
return this.apiClient.getBaseURL()
|
||||
}
|
||||
|
||||
public getApiKey(): string {
|
||||
return this.apiClient.getApiKey()
|
||||
}
|
||||
}
|
||||
// 导出一些常用的类型和工具
|
||||
export * from './legacy/clients/types'
|
||||
export * from './legacy/middleware/schemas'
|
||||
|
||||
487
src/renderer/src/aiCore/index_new.ts
Normal file
487
src/renderer/src/aiCore/index_new.ts
Normal file
@@ -0,0 +1,487 @@
|
||||
/**
|
||||
* Cherry Studio AI Core - 新版本入口
|
||||
* 集成 @cherrystudio/ai-core 库的渐进式重构方案
|
||||
*
|
||||
* 融合方案:简化实现,专注于核心功能
|
||||
* 1. 优先使用新AI SDK
|
||||
* 2. 暂时保持接口兼容性
|
||||
*/
|
||||
|
||||
import { createExecutor, generateImage } from '@cherrystudio/ai-core'
|
||||
import { createAndRegisterProvider } from '@cherrystudio/ai-core/provider'
|
||||
import { loggerService } from '@logger'
|
||||
import { isNotSupportedImageSizeModel } from '@renderer/config/models'
|
||||
import { addSpan, endSpan } from '@renderer/services/SpanManagerService'
|
||||
import { StartSpanParams } from '@renderer/trace/types/ModelSpanEntity'
|
||||
import type { Assistant, GenerateImageParams, Model, Provider } from '@renderer/types'
|
||||
import { ChunkType } from '@renderer/types/chunk'
|
||||
|
||||
import AiSdkToChunkAdapter from './chunk/AiSdkToChunkAdapter'
|
||||
import LegacyAiProvider from './legacy/index'
|
||||
import { CompletionsResult } from './legacy/middleware/schemas'
|
||||
import { AiSdkMiddlewareConfig, buildAiSdkMiddlewares } from './middleware/AiSdkMiddlewareBuilder'
|
||||
import { buildPlugins } from './plugins/PluginBuilder'
|
||||
import { getActualProvider, isModernSdkSupported, providerToAiSdkConfig } from './provider/ProviderConfigProcessor'
|
||||
import type { StreamTextParams } from './types'
|
||||
|
||||
const logger = loggerService.withContext('ModernAiProvider')
|
||||
|
||||
export default class ModernAiProvider {
|
||||
private legacyProvider: LegacyAiProvider
|
||||
private config: ReturnType<typeof providerToAiSdkConfig>
|
||||
private actualProvider: Provider
|
||||
|
||||
constructor(model: Model, provider?: Provider) {
|
||||
this.actualProvider = provider || getActualProvider(model)
|
||||
this.legacyProvider = new LegacyAiProvider(this.actualProvider)
|
||||
|
||||
// 只保存配置,不预先创建executor
|
||||
this.config = providerToAiSdkConfig(this.actualProvider)
|
||||
}
|
||||
|
||||
public getActualProvider() {
|
||||
return this.actualProvider
|
||||
}
|
||||
|
||||
public async completions(
|
||||
modelId: string,
|
||||
params: StreamTextParams,
|
||||
config: AiSdkMiddlewareConfig & {
|
||||
assistant: Assistant
|
||||
// topicId for tracing
|
||||
topicId?: string
|
||||
callType: string
|
||||
}
|
||||
): Promise<CompletionsResult> {
|
||||
// 初始化 provider 到全局管理器
|
||||
try {
|
||||
await createAndRegisterProvider(this.config.providerId, this.config.options)
|
||||
logger.debug('Provider initialized successfully', {
|
||||
providerId: this.config.providerId,
|
||||
hasOptions: !!this.config.options
|
||||
})
|
||||
} catch (error) {
|
||||
// 如果 provider 已经初始化过,可能会抛出错误,这里可以忽略
|
||||
logger.debug('Provider initialization skipped (may already be initialized)', {
|
||||
providerId: this.config.providerId,
|
||||
error: error instanceof Error ? error.message : String(error)
|
||||
})
|
||||
}
|
||||
|
||||
if (config.isImageGenerationEndpoint) {
|
||||
return await this.modernImageGeneration(modelId, params, config)
|
||||
}
|
||||
|
||||
return await this.modernCompletions(modelId, params, config)
|
||||
}
|
||||
|
||||
/**
|
||||
* 带trace支持的completions方法
|
||||
* 类似于legacy的completionsForTrace,确保AI SDK spans在正确的trace上下文中
|
||||
*/
|
||||
public async completionsForTrace(
|
||||
modelId: string,
|
||||
params: StreamTextParams,
|
||||
config: AiSdkMiddlewareConfig & {
|
||||
assistant: Assistant
|
||||
// topicId for tracing
|
||||
topicId: string
|
||||
callType: string
|
||||
}
|
||||
): Promise<CompletionsResult> {
|
||||
const traceName = `${this.actualProvider.name}.${modelId}.${config.callType}`
|
||||
const traceParams: StartSpanParams = {
|
||||
name: traceName,
|
||||
tag: 'LLM',
|
||||
topicId: config.topicId,
|
||||
modelName: config.assistant.model?.name, // 使用modelId而不是provider名称
|
||||
inputs: params
|
||||
}
|
||||
|
||||
logger.info('Starting AI SDK trace span', {
|
||||
traceName,
|
||||
topicId: config.topicId,
|
||||
modelId,
|
||||
hasTools: !!params.tools && Object.keys(params.tools).length > 0,
|
||||
toolNames: params.tools ? Object.keys(params.tools) : [],
|
||||
isImageGeneration: config.isImageGenerationEndpoint
|
||||
})
|
||||
|
||||
const span = addSpan(traceParams)
|
||||
if (!span) {
|
||||
logger.warn('Failed to create span, falling back to regular completions', {
|
||||
topicId: config.topicId,
|
||||
modelId,
|
||||
traceName
|
||||
})
|
||||
return await this.completions(modelId, params, config)
|
||||
}
|
||||
|
||||
try {
|
||||
logger.info('Created parent span, now calling completions', {
|
||||
spanId: span.spanContext().spanId,
|
||||
traceId: span.spanContext().traceId,
|
||||
topicId: config.topicId,
|
||||
modelId,
|
||||
parentSpanCreated: true
|
||||
})
|
||||
|
||||
const result = await this.completions(modelId, params, config)
|
||||
|
||||
logger.info('Completions finished, ending parent span', {
|
||||
spanId: span.spanContext().spanId,
|
||||
traceId: span.spanContext().traceId,
|
||||
topicId: config.topicId,
|
||||
modelId,
|
||||
resultLength: result.getText().length
|
||||
})
|
||||
|
||||
// 标记span完成
|
||||
endSpan({
|
||||
topicId: config.topicId,
|
||||
outputs: result,
|
||||
span,
|
||||
modelName: modelId // 使用modelId保持一致性
|
||||
})
|
||||
|
||||
return result
|
||||
} catch (error) {
|
||||
logger.error('Error in completionsForTrace, ending parent span with error', error as Error, {
|
||||
spanId: span.spanContext().spanId,
|
||||
traceId: span.spanContext().traceId,
|
||||
topicId: config.topicId,
|
||||
modelId
|
||||
})
|
||||
|
||||
// 标记span出错
|
||||
endSpan({
|
||||
topicId: config.topicId,
|
||||
error: error as Error,
|
||||
span,
|
||||
modelName: modelId // 使用modelId保持一致性
|
||||
})
|
||||
throw error
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 使用现代化AI SDK的completions实现
|
||||
*/
|
||||
private async modernCompletions(
|
||||
modelId: string,
|
||||
params: StreamTextParams,
|
||||
config: AiSdkMiddlewareConfig & {
|
||||
assistant: Assistant
|
||||
// topicId for tracing
|
||||
topicId?: string
|
||||
callType: string
|
||||
}
|
||||
): Promise<CompletionsResult> {
|
||||
logger.info('Starting modernCompletions', {
|
||||
modelId,
|
||||
providerId: this.config.providerId,
|
||||
topicId: config.topicId,
|
||||
hasOnChunk: !!config.onChunk,
|
||||
hasTools: !!params.tools && Object.keys(params.tools).length > 0,
|
||||
toolCount: params.tools ? Object.keys(params.tools).length : 0
|
||||
})
|
||||
|
||||
// 根据条件构建插件数组
|
||||
const plugins = buildPlugins(config)
|
||||
logger.debug('Built plugins for AI SDK', {
|
||||
pluginCount: plugins.length,
|
||||
pluginNames: plugins.map((p) => p.name),
|
||||
providerId: this.config.providerId,
|
||||
topicId: config.topicId
|
||||
})
|
||||
|
||||
// 用构建好的插件数组创建executor
|
||||
const executor = createExecutor(this.config.providerId, this.config.options, plugins)
|
||||
logger.debug('Created AI SDK executor', {
|
||||
providerId: this.config.providerId,
|
||||
hasOptions: !!this.config.options,
|
||||
pluginCount: plugins.length
|
||||
})
|
||||
|
||||
// 动态构建中间件数组
|
||||
const middlewares = buildAiSdkMiddlewares(config)
|
||||
logger.debug('Built AI SDK middlewares', {
|
||||
middlewareCount: middlewares.length,
|
||||
topicId: config.topicId
|
||||
})
|
||||
|
||||
// 创建带有中间件的执行器
|
||||
if (config.onChunk) {
|
||||
// 流式处理 - 使用适配器
|
||||
logger.info('Starting streaming with chunk adapter', {
|
||||
modelId,
|
||||
hasMiddlewares: middlewares.length > 0,
|
||||
middlewareCount: middlewares.length,
|
||||
hasMcpTools: !!config.mcpTools,
|
||||
mcpToolCount: config.mcpTools?.length || 0,
|
||||
topicId: config.topicId
|
||||
})
|
||||
|
||||
const adapter = new AiSdkToChunkAdapter(config.onChunk, config.mcpTools)
|
||||
|
||||
logger.debug('Final params before streamText', {
|
||||
modelId,
|
||||
hasMessages: !!params.messages,
|
||||
messageCount: params.messages?.length || 0,
|
||||
hasTools: !!params.tools && Object.keys(params.tools).length > 0,
|
||||
toolNames: params.tools ? Object.keys(params.tools) : [],
|
||||
hasSystem: !!params.system,
|
||||
topicId: config.topicId
|
||||
})
|
||||
|
||||
const streamResult = await executor.streamText(
|
||||
modelId,
|
||||
{ ...params, experimental_context: { onChunk: config.onChunk } },
|
||||
middlewares.length > 0 ? { middlewares } : undefined
|
||||
)
|
||||
|
||||
logger.info('StreamText call successful, processing stream', {
|
||||
modelId,
|
||||
topicId: config.topicId,
|
||||
hasFullStream: !!streamResult.fullStream
|
||||
})
|
||||
|
||||
const finalText = await adapter.processStream(streamResult)
|
||||
|
||||
logger.info('Stream processing completed', {
|
||||
modelId,
|
||||
topicId: config.topicId,
|
||||
finalTextLength: finalText.length
|
||||
})
|
||||
|
||||
return {
|
||||
getText: () => finalText
|
||||
}
|
||||
} else {
|
||||
// 流式处理但没有 onChunk 回调
|
||||
logger.info('Starting streaming without chunk callback', {
|
||||
modelId,
|
||||
hasMiddlewares: middlewares.length > 0,
|
||||
middlewareCount: middlewares.length,
|
||||
topicId: config.topicId
|
||||
})
|
||||
|
||||
const streamResult = await executor.streamText(
|
||||
modelId,
|
||||
params,
|
||||
middlewares.length > 0 ? { middlewares } : undefined
|
||||
)
|
||||
|
||||
logger.info('StreamText call successful, waiting for text', {
|
||||
modelId,
|
||||
topicId: config.topicId
|
||||
})
|
||||
// 强制消费流,不然await streamResult.text会阻塞
|
||||
await streamResult?.consumeStream()
|
||||
|
||||
const finalText = await streamResult.text
|
||||
|
||||
logger.info('Text extraction completed', {
|
||||
modelId,
|
||||
topicId: config.topicId,
|
||||
finalTextLength: finalText.length
|
||||
})
|
||||
|
||||
return {
|
||||
getText: () => finalText
|
||||
}
|
||||
}
|
||||
// }
|
||||
// catch (error) {
|
||||
// console.error('Modern AI SDK error:', error)
|
||||
// throw error
|
||||
// }
|
||||
}
|
||||
|
||||
/**
|
||||
* 使用现代化 AI SDK 的图像生成实现,支持流式输出
|
||||
*/
|
||||
private async modernImageGeneration(
|
||||
modelId: string,
|
||||
params: StreamTextParams,
|
||||
config: AiSdkMiddlewareConfig & {
|
||||
assistant: Assistant
|
||||
// topicId for tracing
|
||||
topicId?: string
|
||||
callType: string
|
||||
}
|
||||
): Promise<CompletionsResult> {
|
||||
const { onChunk } = config
|
||||
|
||||
try {
|
||||
// 检查 messages 是否存在
|
||||
if (!params.messages || params.messages.length === 0) {
|
||||
throw new Error('No messages provided for image generation.')
|
||||
}
|
||||
|
||||
// 从最后一条用户消息中提取 prompt
|
||||
const lastUserMessage = params.messages.findLast((m) => m.role === 'user')
|
||||
if (!lastUserMessage) {
|
||||
throw new Error('No user message found for image generation.')
|
||||
}
|
||||
|
||||
// 直接使用消息内容,避免类型转换问题
|
||||
const prompt =
|
||||
typeof lastUserMessage.content === 'string'
|
||||
? lastUserMessage.content
|
||||
: lastUserMessage.content?.map((part) => ('text' in part ? part.text : '')).join('') || ''
|
||||
|
||||
if (!prompt) {
|
||||
throw new Error('No prompt found in user message.')
|
||||
}
|
||||
|
||||
const startTime = Date.now()
|
||||
|
||||
// 发送图像生成开始事件
|
||||
if (onChunk) {
|
||||
onChunk({ type: ChunkType.IMAGE_CREATED })
|
||||
}
|
||||
|
||||
// 构建图像生成参数
|
||||
const imageParams = {
|
||||
prompt,
|
||||
size: isNotSupportedImageSizeModel(config.model) ? undefined : ('1024x1024' as `${number}x${number}`), // 默认尺寸,使用正确的类型
|
||||
n: 1,
|
||||
...(params.abortSignal && { abortSignal: params.abortSignal })
|
||||
}
|
||||
|
||||
// 调用新 AI SDK 的图像生成功能
|
||||
const result = await generateImage(this.config.providerId, this.config.options, modelId, imageParams)
|
||||
|
||||
// 转换结果格式
|
||||
const images: string[] = []
|
||||
const imageType: 'url' | 'base64' = 'base64'
|
||||
|
||||
if (result.images) {
|
||||
for (const image of result.images) {
|
||||
if ('base64' in image && image.base64) {
|
||||
images.push(`data:${image.mediaType};base64,${image.base64}`)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 发送图像生成完成事件
|
||||
if (onChunk && images.length > 0) {
|
||||
onChunk({
|
||||
type: ChunkType.IMAGE_COMPLETE,
|
||||
image: { type: imageType, images }
|
||||
})
|
||||
}
|
||||
|
||||
// 发送块完成事件(类似于 modernCompletions 的处理)
|
||||
if (onChunk) {
|
||||
const usage = {
|
||||
prompt_tokens: prompt.length, // 估算的 token 数量
|
||||
completion_tokens: 0, // 图像生成没有 completion tokens
|
||||
total_tokens: prompt.length
|
||||
}
|
||||
|
||||
onChunk({
|
||||
type: ChunkType.BLOCK_COMPLETE,
|
||||
response: {
|
||||
usage,
|
||||
metrics: {
|
||||
completion_tokens: usage.completion_tokens,
|
||||
time_first_token_millsec: 0,
|
||||
time_completion_millsec: Date.now() - startTime
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
// 发送 LLM 响应完成事件
|
||||
onChunk({
|
||||
type: ChunkType.LLM_RESPONSE_COMPLETE,
|
||||
response: {
|
||||
usage,
|
||||
metrics: {
|
||||
completion_tokens: usage.completion_tokens,
|
||||
time_first_token_millsec: 0,
|
||||
time_completion_millsec: Date.now() - startTime
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
return {
|
||||
getText: () => '' // 图像生成不返回文本
|
||||
}
|
||||
} catch (error) {
|
||||
// 发送错误事件
|
||||
if (onChunk) {
|
||||
onChunk({ type: ChunkType.ERROR, error: error as any })
|
||||
}
|
||||
throw error
|
||||
}
|
||||
}
|
||||
|
||||
// 代理其他方法到原有实现
|
||||
public async models() {
|
||||
return this.legacyProvider.models()
|
||||
}
|
||||
|
||||
public async getEmbeddingDimensions(model: Model): Promise<number> {
|
||||
return this.legacyProvider.getEmbeddingDimensions(model)
|
||||
}
|
||||
|
||||
public async generateImage(params: GenerateImageParams): Promise<string[]> {
|
||||
// 如果支持新的 AI SDK,使用现代化实现
|
||||
if (isModernSdkSupported(this.actualProvider)) {
|
||||
try {
|
||||
const result = await this.modernGenerateImage(params)
|
||||
return result
|
||||
} catch (error) {
|
||||
logger.warn('Modern AI SDK generateImage failed, falling back to legacy:', error as Error)
|
||||
// fallback 到传统实现
|
||||
return this.legacyProvider.generateImage(params)
|
||||
}
|
||||
}
|
||||
|
||||
// 直接使用传统实现
|
||||
return this.legacyProvider.generateImage(params)
|
||||
}
|
||||
|
||||
/**
|
||||
* 使用现代化 AI SDK 的图像生成实现
|
||||
*/
|
||||
private async modernGenerateImage(params: GenerateImageParams): Promise<string[]> {
|
||||
const { model, prompt, imageSize, batchSize, signal } = params
|
||||
|
||||
// 转换参数格式
|
||||
const aiSdkParams = {
|
||||
prompt,
|
||||
size: (imageSize || '1024x1024') as `${number}x${number}`,
|
||||
n: batchSize || 1,
|
||||
...(signal && { abortSignal: signal })
|
||||
}
|
||||
|
||||
const result = await generateImage(this.config.providerId, this.config.options, model, aiSdkParams)
|
||||
|
||||
// 转换结果格式
|
||||
const images: string[] = []
|
||||
if (result.images) {
|
||||
for (const image of result.images) {
|
||||
if ('base64' in image && image.base64) {
|
||||
images.push(`data:image/png;base64,${image.base64}`)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return images
|
||||
}
|
||||
|
||||
public getBaseURL(): string {
|
||||
return this.legacyProvider.getBaseURL()
|
||||
}
|
||||
|
||||
public getApiKey(): string {
|
||||
return this.legacyProvider.getApiKey()
|
||||
}
|
||||
}
|
||||
|
||||
// 为了方便调试,导出一些工具函数
|
||||
export { isModernSdkSupported, providerToAiSdkConfig }
|
||||
@@ -1,4 +1,5 @@
|
||||
import { loggerService } from '@logger'
|
||||
import { isVertexAIConfigured } from '@renderer/hooks/useVertexAI'
|
||||
import { Provider } from '@renderer/types'
|
||||
|
||||
import { AihubmixAPIClient } from './AihubmixAPIClient'
|
||||
@@ -61,6 +62,13 @@ export class ApiClientFactory {
|
||||
instance = new GeminiAPIClient(provider) as BaseApiClient
|
||||
break
|
||||
case 'vertexai':
|
||||
logger.debug(`Creating VertexAPIClient for provider: ${provider.id}`)
|
||||
// 检查 VertexAI 配置
|
||||
if (!isVertexAIConfigured()) {
|
||||
throw new Error(
|
||||
'VertexAI is not configured. Please configure project, location and service account credentials.'
|
||||
)
|
||||
}
|
||||
instance = new VertexAPIClient(provider) as BaseApiClient
|
||||
break
|
||||
case 'anthropic':
|
||||
@@ -1,11 +1,11 @@
|
||||
import { AihubmixAPIClient } from '@renderer/aiCore/clients/AihubmixAPIClient'
|
||||
import { AnthropicAPIClient } from '@renderer/aiCore/clients/anthropic/AnthropicAPIClient'
|
||||
import { ApiClientFactory } from '@renderer/aiCore/clients/ApiClientFactory'
|
||||
import { GeminiAPIClient } from '@renderer/aiCore/clients/gemini/GeminiAPIClient'
|
||||
import { VertexAPIClient } from '@renderer/aiCore/clients/gemini/VertexAPIClient'
|
||||
import { NewAPIClient } from '@renderer/aiCore/clients/NewAPIClient'
|
||||
import { OpenAIAPIClient } from '@renderer/aiCore/clients/openai/OpenAIApiClient'
|
||||
import { OpenAIResponseAPIClient } from '@renderer/aiCore/clients/openai/OpenAIResponseAPIClient'
|
||||
import { AihubmixAPIClient } from '@renderer/aiCore/legacy/clients/AihubmixAPIClient'
|
||||
import { AnthropicAPIClient } from '@renderer/aiCore/legacy/clients/anthropic/AnthropicAPIClient'
|
||||
import { ApiClientFactory } from '@renderer/aiCore/legacy/clients/ApiClientFactory'
|
||||
import { GeminiAPIClient } from '@renderer/aiCore/legacy/clients/gemini/GeminiAPIClient'
|
||||
import { VertexAPIClient } from '@renderer/aiCore/legacy/clients/gemini/VertexAPIClient'
|
||||
import { NewAPIClient } from '@renderer/aiCore/legacy/clients/NewAPIClient'
|
||||
import { OpenAIAPIClient } from '@renderer/aiCore/legacy/clients/openai/OpenAIApiClient'
|
||||
import { OpenAIResponseAPIClient } from '@renderer/aiCore/legacy/clients/openai/OpenAIResponseAPIClient'
|
||||
import { EndpointType, Model, Provider } from '@renderer/types'
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
@@ -25,7 +25,6 @@ import {
|
||||
import { MessageStream } from '@anthropic-ai/sdk/resources/messages/messages'
|
||||
import AnthropicVertex from '@anthropic-ai/vertex-sdk'
|
||||
import { loggerService } from '@logger'
|
||||
import { GenericChunk } from '@renderer/aiCore/middleware/schemas'
|
||||
import { DEFAULT_MAX_TOKENS } from '@renderer/config/constant'
|
||||
import { findTokenLimit, isClaudeReasoningModel, isReasoningModel, isWebSearchModel } from '@renderer/config/models'
|
||||
import { getAssistantSettings } from '@renderer/services/AssistantService'
|
||||
@@ -64,13 +63,14 @@ import {
|
||||
import { addImageFileToContents } from '@renderer/utils/formats'
|
||||
import {
|
||||
anthropicToolUseToMcpTool,
|
||||
isEnabledToolUse,
|
||||
isSupportedToolUse,
|
||||
mcpToolCallResponseToAnthropicMessage,
|
||||
mcpToolsToAnthropicTools
|
||||
} from '@renderer/utils/mcp-tools'
|
||||
import { findFileBlocks, findImageBlocks } from '@renderer/utils/messageUtils/find'
|
||||
import { t } from 'i18next'
|
||||
|
||||
import { GenericChunk } from '../../middleware/schemas'
|
||||
import { BaseApiClient } from '../BaseApiClient'
|
||||
import { AnthropicStreamListener, RawStreamListener, RequestTransformer, ResponseChunkTransformer } from '../types'
|
||||
|
||||
@@ -457,7 +457,7 @@ export class AnthropicAPIClient extends BaseApiClient<
|
||||
const { tools } = this.setupToolsConfig({
|
||||
mcpTools: mcpTools,
|
||||
model,
|
||||
enableToolUse: isEnabledToolUse(assistant)
|
||||
enableToolUse: isSupportedToolUse(assistant)
|
||||
})
|
||||
|
||||
const systemMessage: TextBlockParam | undefined = systemPrompt
|
||||
@@ -6,7 +6,7 @@ import {
|
||||
InvokeModelWithResponseStreamCommand
|
||||
} from '@aws-sdk/client-bedrock-runtime'
|
||||
import { loggerService } from '@logger'
|
||||
import { GenericChunk } from '@renderer/aiCore/middleware/schemas'
|
||||
import { GenericChunk } from '@renderer/aiCore/legacy/middleware/schemas'
|
||||
import { DEFAULT_MAX_TOKENS } from '@renderer/config/constant'
|
||||
import { findTokenLimit, isReasoningModel } from '@renderer/config/models'
|
||||
import {
|
||||
@@ -50,7 +50,7 @@ import {
|
||||
import { convertBase64ImageToAwsBedrockFormat } from '@renderer/utils/aws-bedrock-utils'
|
||||
import {
|
||||
awsBedrockToolUseToMcpTool,
|
||||
isEnabledToolUse,
|
||||
isSupportedToolUse,
|
||||
mcpToolCallResponseToAwsBedrockMessage,
|
||||
mcpToolsToAwsBedrockTools
|
||||
} from '@renderer/utils/mcp-tools'
|
||||
@@ -739,7 +739,7 @@ export class AwsBedrockAPIClient extends BaseApiClient<
|
||||
const { tools } = this.setupToolsConfig({
|
||||
mcpTools: mcpTools,
|
||||
model,
|
||||
enableToolUse: isEnabledToolUse(assistant)
|
||||
enableToolUse: isSupportedToolUse(assistant)
|
||||
})
|
||||
|
||||
// 3. 处理消息
|
||||
@@ -18,7 +18,6 @@ import {
|
||||
} from '@google/genai'
|
||||
import { loggerService } from '@logger'
|
||||
import { nanoid } from '@reduxjs/toolkit'
|
||||
import { GenericChunk } from '@renderer/aiCore/middleware/schemas'
|
||||
import {
|
||||
findTokenLimit,
|
||||
GEMINI_FLASH_MODEL_REGEX,
|
||||
@@ -55,7 +54,7 @@ import {
|
||||
import { isToolUseModeFunction } from '@renderer/utils/assistant'
|
||||
import {
|
||||
geminiFunctionCallToMcpTool,
|
||||
isEnabledToolUse,
|
||||
isSupportedToolUse,
|
||||
mcpToolCallResponseToGeminiMessage,
|
||||
mcpToolsToGeminiTools
|
||||
} from '@renderer/utils/mcp-tools'
|
||||
@@ -63,6 +62,7 @@ import { findFileBlocks, findImageBlocks, getMainTextContent } from '@renderer/u
|
||||
import { defaultTimeout, MB } from '@shared/config/constant'
|
||||
import { t } from 'i18next'
|
||||
|
||||
import { GenericChunk } from '../../middleware/schemas'
|
||||
import { BaseApiClient } from '../BaseApiClient'
|
||||
import { RequestTransformer, ResponseChunkTransformer } from '../types'
|
||||
|
||||
@@ -454,7 +454,7 @@ export class GeminiAPIClient extends BaseApiClient<
|
||||
const { tools } = this.setupToolsConfig({
|
||||
mcpTools,
|
||||
model,
|
||||
enableToolUse: isEnabledToolUse(assistant)
|
||||
enableToolUse: isSupportedToolUse(assistant)
|
||||
})
|
||||
|
||||
let messageContents: Content = { role: 'user', parts: [] } // Initialize messageContents
|
||||
@@ -1,7 +1,7 @@
|
||||
import { GoogleGenAI } from '@google/genai'
|
||||
import { loggerService } from '@logger'
|
||||
import { getVertexAILocation, getVertexAIProjectId, getVertexAIServiceAccount } from '@renderer/hooks/useVertexAI'
|
||||
import { Model, Provider } from '@renderer/types'
|
||||
import { createVertexProvider, isVertexProvider } from '@renderer/hooks/useVertexAI'
|
||||
import { Model, Provider, VertexProvider } from '@renderer/types'
|
||||
import { isEmpty } from 'lodash'
|
||||
|
||||
import { AnthropicVertexClient } from '../anthropic/AnthropicVertexClient'
|
||||
@@ -12,10 +12,17 @@ export class VertexAPIClient extends GeminiAPIClient {
|
||||
private authHeaders?: Record<string, string>
|
||||
private authHeadersExpiry?: number
|
||||
private anthropicVertexClient: AnthropicVertexClient
|
||||
private vertexProvider: VertexProvider
|
||||
|
||||
constructor(provider: Provider) {
|
||||
super(provider)
|
||||
this.anthropicVertexClient = new AnthropicVertexClient(provider)
|
||||
// 如果传入的是普通 Provider,转换为 VertexProvider
|
||||
if (isVertexProvider(provider)) {
|
||||
this.vertexProvider = provider
|
||||
} else {
|
||||
this.vertexProvider = createVertexProvider(provider)
|
||||
}
|
||||
}
|
||||
|
||||
override getClientCompatibilityType(model?: Model): string[] {
|
||||
@@ -56,11 +63,9 @@ export class VertexAPIClient extends GeminiAPIClient {
|
||||
return this.sdkInstance
|
||||
}
|
||||
|
||||
const serviceAccount = getVertexAIServiceAccount()
|
||||
const projectId = getVertexAIProjectId()
|
||||
const location = getVertexAILocation()
|
||||
const { googleCredentials, project, location } = this.vertexProvider
|
||||
|
||||
if (!serviceAccount.privateKey || !serviceAccount.clientEmail || !projectId || !location) {
|
||||
if (!googleCredentials.privateKey || !googleCredentials.clientEmail || !project || !location) {
|
||||
throw new Error('Vertex AI settings are not configured')
|
||||
}
|
||||
|
||||
@@ -68,7 +73,7 @@ export class VertexAPIClient extends GeminiAPIClient {
|
||||
|
||||
this.sdkInstance = new GoogleGenAI({
|
||||
vertexai: true,
|
||||
project: projectId,
|
||||
project: project,
|
||||
location: location,
|
||||
httpOptions: {
|
||||
apiVersion: this.getApiVersion(),
|
||||
@@ -84,11 +89,10 @@ export class VertexAPIClient extends GeminiAPIClient {
|
||||
* 获取认证头,如果配置了 service account 则从主进程获取
|
||||
*/
|
||||
private async getServiceAccountAuthHeaders(): Promise<Record<string, string> | undefined> {
|
||||
const serviceAccount = getVertexAIServiceAccount()
|
||||
const projectId = getVertexAIProjectId()
|
||||
const { googleCredentials, project } = this.vertexProvider
|
||||
|
||||
// 检查是否配置了 service account
|
||||
if (!serviceAccount.privateKey || !serviceAccount.clientEmail || !projectId) {
|
||||
if (!googleCredentials.privateKey || !googleCredentials.clientEmail || !project) {
|
||||
return undefined
|
||||
}
|
||||
|
||||
@@ -101,10 +105,10 @@ export class VertexAPIClient extends GeminiAPIClient {
|
||||
try {
|
||||
// 从主进程获取认证头
|
||||
this.authHeaders = await window.api.vertexAI.getAuthHeaders({
|
||||
projectId,
|
||||
projectId: project,
|
||||
serviceAccount: {
|
||||
privateKey: serviceAccount.privateKey,
|
||||
clientEmail: serviceAccount.clientEmail
|
||||
privateKey: googleCredentials.privateKey,
|
||||
clientEmail: googleCredentials.clientEmail
|
||||
}
|
||||
})
|
||||
|
||||
@@ -125,11 +129,10 @@ export class VertexAPIClient extends GeminiAPIClient {
|
||||
this.authHeaders = undefined
|
||||
this.authHeadersExpiry = undefined
|
||||
|
||||
const serviceAccount = getVertexAIServiceAccount()
|
||||
const projectId = getVertexAIProjectId()
|
||||
const { googleCredentials, project } = this.vertexProvider
|
||||
|
||||
if (projectId && serviceAccount.clientEmail) {
|
||||
window.api.vertexAI.clearAuthCache(projectId, serviceAccount.clientEmail)
|
||||
if (project && googleCredentials.clientEmail) {
|
||||
window.api.vertexAI.clearAuthCache(project, googleCredentials.clientEmail)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -69,7 +69,7 @@ import {
|
||||
} from '@renderer/types/sdk'
|
||||
import { addImageFileToContents } from '@renderer/utils/formats'
|
||||
import {
|
||||
isEnabledToolUse,
|
||||
isSupportedToolUse,
|
||||
mcpToolCallResponseToOpenAICompatibleMessage,
|
||||
mcpToolsToOpenAIChatTools,
|
||||
openAIToolsToMcpTool
|
||||
@@ -598,7 +598,7 @@ export class OpenAIAPIClient extends OpenAIBaseClient<
|
||||
const { tools } = this.setupToolsConfig({
|
||||
mcpTools: mcpTools,
|
||||
model,
|
||||
enableToolUse: isEnabledToolUse(assistant)
|
||||
enableToolUse: isSupportedToolUse(assistant)
|
||||
})
|
||||
|
||||
// 3. 处理用户消息
|
||||
@@ -1,6 +1,6 @@
|
||||
import { loggerService } from '@logger'
|
||||
import { GenericChunk } from '@renderer/aiCore/middleware/schemas'
|
||||
import { CompletionsContext } from '@renderer/aiCore/middleware/types'
|
||||
import { GenericChunk } from '@renderer/aiCore/legacy/middleware/schemas'
|
||||
import { CompletionsContext } from '@renderer/aiCore/legacy/middleware/types'
|
||||
import {
|
||||
isGPT5SeriesModel,
|
||||
isOpenAIChatCompletionOnlyModel,
|
||||
@@ -36,7 +36,7 @@ import {
|
||||
} from '@renderer/types/sdk'
|
||||
import { addImageFileToContents } from '@renderer/utils/formats'
|
||||
import {
|
||||
isEnabledToolUse,
|
||||
isSupportedToolUse,
|
||||
mcpToolCallResponseToOpenAIMessage,
|
||||
mcpToolsToOpenAIResponseTools,
|
||||
openAIToolsToMcpTool
|
||||
@@ -388,7 +388,7 @@ export class OpenAIResponseAPIClient extends OpenAIBaseClient<
|
||||
const { tools: extraTools } = this.setupToolsConfig({
|
||||
mcpTools: mcpTools,
|
||||
model,
|
||||
enableToolUse: isEnabledToolUse(assistant)
|
||||
enableToolUse: isSupportedToolUse(assistant)
|
||||
})
|
||||
|
||||
systemMessageContent.push(systemMessageInput)
|
||||
189
src/renderer/src/aiCore/legacy/index.ts
Normal file
189
src/renderer/src/aiCore/legacy/index.ts
Normal file
@@ -0,0 +1,189 @@
|
||||
import { loggerService } from '@logger'
|
||||
import { ApiClientFactory } from '@renderer/aiCore/legacy/clients/ApiClientFactory'
|
||||
import { BaseApiClient } from '@renderer/aiCore/legacy/clients/BaseApiClient'
|
||||
import { isDedicatedImageGenerationModel, isFunctionCallingModel } from '@renderer/config/models'
|
||||
import { getProviderByModel } from '@renderer/services/AssistantService'
|
||||
import { withSpanResult } from '@renderer/services/SpanManagerService'
|
||||
import { StartSpanParams } from '@renderer/trace/types/ModelSpanEntity'
|
||||
import type { GenerateImageParams, Model, Provider } from '@renderer/types'
|
||||
import type { RequestOptions, SdkModel } from '@renderer/types/sdk'
|
||||
import { isSupportedToolUse } from '@renderer/utils/mcp-tools'
|
||||
|
||||
import { AihubmixAPIClient } from './clients/AihubmixAPIClient'
|
||||
import { VertexAPIClient } from './clients/gemini/VertexAPIClient'
|
||||
import { NewAPIClient } from './clients/NewAPIClient'
|
||||
import { OpenAIResponseAPIClient } from './clients/openai/OpenAIResponseAPIClient'
|
||||
import { CompletionsMiddlewareBuilder } from './middleware/builder'
|
||||
import { MIDDLEWARE_NAME as AbortHandlerMiddlewareName } from './middleware/common/AbortHandlerMiddleware'
|
||||
import { MIDDLEWARE_NAME as ErrorHandlerMiddlewareName } from './middleware/common/ErrorHandlerMiddleware'
|
||||
import { MIDDLEWARE_NAME as FinalChunkConsumerMiddlewareName } from './middleware/common/FinalChunkConsumerMiddleware'
|
||||
import { applyCompletionsMiddlewares } from './middleware/composer'
|
||||
import { MIDDLEWARE_NAME as McpToolChunkMiddlewareName } from './middleware/core/McpToolChunkMiddleware'
|
||||
import { MIDDLEWARE_NAME as RawStreamListenerMiddlewareName } from './middleware/core/RawStreamListenerMiddleware'
|
||||
import { MIDDLEWARE_NAME as WebSearchMiddlewareName } from './middleware/core/WebSearchMiddleware'
|
||||
import { MIDDLEWARE_NAME as ImageGenerationMiddlewareName } from './middleware/feat/ImageGenerationMiddleware'
|
||||
import { MIDDLEWARE_NAME as ThinkingTagExtractionMiddlewareName } from './middleware/feat/ThinkingTagExtractionMiddleware'
|
||||
import { MIDDLEWARE_NAME as ToolUseExtractionMiddlewareName } from './middleware/feat/ToolUseExtractionMiddleware'
|
||||
import { MiddlewareRegistry } from './middleware/register'
|
||||
import type { CompletionsParams, CompletionsResult } from './middleware/schemas'
|
||||
|
||||
const logger = loggerService.withContext('AiProvider')
|
||||
|
||||
export default class AiProvider {
|
||||
private apiClient: BaseApiClient
|
||||
|
||||
constructor(provider: Provider) {
|
||||
// Use the new ApiClientFactory to get a BaseApiClient instance
|
||||
this.apiClient = ApiClientFactory.create(provider)
|
||||
}
|
||||
|
||||
public async completions(params: CompletionsParams, options?: RequestOptions): Promise<CompletionsResult> {
|
||||
// 1. 根据模型识别正确的客户端
|
||||
const model = params.assistant.model
|
||||
if (!model) {
|
||||
return Promise.reject(new Error('Model is required'))
|
||||
}
|
||||
|
||||
// 根据client类型选择合适的处理方式
|
||||
let client: BaseApiClient
|
||||
|
||||
if (this.apiClient instanceof AihubmixAPIClient) {
|
||||
// AihubmixAPIClient: 根据模型选择合适的子client
|
||||
client = this.apiClient.getClientForModel(model)
|
||||
if (client instanceof OpenAIResponseAPIClient) {
|
||||
client = client.getClient(model) as BaseApiClient
|
||||
}
|
||||
} else if (this.apiClient instanceof NewAPIClient) {
|
||||
client = this.apiClient.getClientForModel(model)
|
||||
if (client instanceof OpenAIResponseAPIClient) {
|
||||
client = client.getClient(model) as BaseApiClient
|
||||
}
|
||||
} else if (this.apiClient instanceof OpenAIResponseAPIClient) {
|
||||
// OpenAIResponseAPIClient: 根据模型特征选择API类型
|
||||
client = this.apiClient.getClient(model) as BaseApiClient
|
||||
} else if (this.apiClient instanceof VertexAPIClient) {
|
||||
client = this.apiClient.getClient(model) as BaseApiClient
|
||||
} else {
|
||||
// 其他client直接使用
|
||||
client = this.apiClient
|
||||
}
|
||||
|
||||
// 2. 构建中间件链
|
||||
const builder = CompletionsMiddlewareBuilder.withDefaults()
|
||||
// images api
|
||||
if (isDedicatedImageGenerationModel(model)) {
|
||||
builder.clear()
|
||||
builder
|
||||
.add(MiddlewareRegistry[FinalChunkConsumerMiddlewareName])
|
||||
.add(MiddlewareRegistry[ErrorHandlerMiddlewareName])
|
||||
.add(MiddlewareRegistry[AbortHandlerMiddlewareName])
|
||||
.add(MiddlewareRegistry[ImageGenerationMiddlewareName])
|
||||
} else {
|
||||
// Existing logic for other models
|
||||
logger.silly('Builder Params', params)
|
||||
// 使用兼容性类型检查,避免typescript类型收窄和装饰器模式的问题
|
||||
const clientTypes = client.getClientCompatibilityType(model)
|
||||
const isOpenAICompatible =
|
||||
clientTypes.includes('OpenAIAPIClient') || clientTypes.includes('OpenAIResponseAPIClient')
|
||||
if (!isOpenAICompatible) {
|
||||
logger.silly('ThinkingTagExtractionMiddleware is removed')
|
||||
builder.remove(ThinkingTagExtractionMiddlewareName)
|
||||
}
|
||||
|
||||
const isAnthropicOrOpenAIResponseCompatible =
|
||||
clientTypes.includes('AnthropicAPIClient') ||
|
||||
clientTypes.includes('OpenAIResponseAPIClient') ||
|
||||
clientTypes.includes('AnthropicVertexAPIClient')
|
||||
if (!isAnthropicOrOpenAIResponseCompatible) {
|
||||
logger.silly('RawStreamListenerMiddleware is removed')
|
||||
builder.remove(RawStreamListenerMiddlewareName)
|
||||
}
|
||||
if (!params.enableWebSearch) {
|
||||
logger.silly('WebSearchMiddleware is removed')
|
||||
builder.remove(WebSearchMiddlewareName)
|
||||
}
|
||||
if (!params.mcpTools?.length) {
|
||||
builder.remove(ToolUseExtractionMiddlewareName)
|
||||
logger.silly('ToolUseExtractionMiddleware is removed')
|
||||
builder.remove(McpToolChunkMiddlewareName)
|
||||
logger.silly('McpToolChunkMiddleware is removed')
|
||||
}
|
||||
if (isSupportedToolUse(params.assistant) && isFunctionCallingModel(model)) {
|
||||
builder.remove(ToolUseExtractionMiddlewareName)
|
||||
logger.silly('ToolUseExtractionMiddleware is removed')
|
||||
}
|
||||
if (params.callType !== 'chat' && params.callType !== 'check' && params.callType !== 'translate') {
|
||||
logger.silly('AbortHandlerMiddleware is removed')
|
||||
builder.remove(AbortHandlerMiddlewareName)
|
||||
}
|
||||
if (params.callType === 'test') {
|
||||
builder.remove(ErrorHandlerMiddlewareName)
|
||||
logger.silly('ErrorHandlerMiddleware is removed')
|
||||
builder.remove(FinalChunkConsumerMiddlewareName)
|
||||
logger.silly('FinalChunkConsumerMiddleware is removed')
|
||||
}
|
||||
}
|
||||
|
||||
const middlewares = builder.build()
|
||||
logger.silly(
|
||||
'middlewares',
|
||||
middlewares.map((m) => m.name)
|
||||
)
|
||||
|
||||
// 3. Create the wrapped SDK method with middlewares
|
||||
const wrappedCompletionMethod = applyCompletionsMiddlewares(client, client.createCompletions, middlewares)
|
||||
|
||||
// 4. Execute the wrapped method with the original params
|
||||
const result = wrappedCompletionMethod(params, options)
|
||||
return result
|
||||
}
|
||||
|
||||
public async completionsForTrace(params: CompletionsParams, options?: RequestOptions): Promise<CompletionsResult> {
|
||||
const traceName = params.assistant.model?.name
|
||||
? `${params.assistant.model?.name}.${params.callType}`
|
||||
: `LLM.${params.callType}`
|
||||
|
||||
const traceParams: StartSpanParams = {
|
||||
name: traceName,
|
||||
tag: 'LLM',
|
||||
topicId: params.topicId || '',
|
||||
modelName: params.assistant.model?.name
|
||||
}
|
||||
|
||||
return await withSpanResult(this.completions.bind(this), traceParams, params, options)
|
||||
}
|
||||
|
||||
public async models(): Promise<SdkModel[]> {
|
||||
return this.apiClient.listModels()
|
||||
}
|
||||
|
||||
public async getEmbeddingDimensions(model: Model): Promise<number> {
|
||||
try {
|
||||
// Use the SDK instance to test embedding capabilities
|
||||
if (this.apiClient instanceof OpenAIResponseAPIClient && getProviderByModel(model).type === 'azure-openai') {
|
||||
this.apiClient = this.apiClient.getClient(model) as BaseApiClient
|
||||
}
|
||||
const dimensions = await this.apiClient.getEmbeddingDimensions(model)
|
||||
return dimensions
|
||||
} catch (error) {
|
||||
logger.error('Error getting embedding dimensions:', error as Error)
|
||||
throw error
|
||||
}
|
||||
}
|
||||
|
||||
public async generateImage(params: GenerateImageParams): Promise<string[]> {
|
||||
if (this.apiClient instanceof AihubmixAPIClient) {
|
||||
const client = this.apiClient.getClientForModel({ id: params.model } as Model)
|
||||
return client.generateImage(params)
|
||||
}
|
||||
return this.apiClient.generateImage(params)
|
||||
}
|
||||
|
||||
public getBaseURL(): string {
|
||||
return this.apiClient.getBaseURL()
|
||||
}
|
||||
|
||||
public getApiKey(): string {
|
||||
return this.apiClient.getApiKey()
|
||||
}
|
||||
}
|
||||
@@ -1,5 +1,5 @@
|
||||
import { loggerService } from '@logger'
|
||||
import { MCPCallToolResponse, MCPTool, MCPToolResponse, Model, ToolCallResponse } from '@renderer/types'
|
||||
import { MCPCallToolResponse, MCPTool, MCPToolResponse, Model } from '@renderer/types'
|
||||
import { ChunkType, MCPToolCreatedChunk } from '@renderer/types/chunk'
|
||||
import { SdkMessageParam, SdkRawOutput, SdkToolCall } from '@renderer/types/sdk'
|
||||
import {
|
||||
@@ -230,7 +230,7 @@ async function executeToolCalls(
|
||||
model: Model,
|
||||
topicId?: string
|
||||
): Promise<{ toolResults: SdkMessageParam[]; confirmedToolCalls: SdkToolCall[] }> {
|
||||
const mcpToolResponses: ToolCallResponse[] = toolCalls
|
||||
const mcpToolResponses: MCPToolResponse[] = toolCalls
|
||||
.map((toolCall) => {
|
||||
const mcpTool = ctx.apiClientInstance.convertSdkToolCallToMcp(toolCall, mcpTools)
|
||||
if (!mcpTool) {
|
||||
@@ -238,7 +238,7 @@ async function executeToolCalls(
|
||||
}
|
||||
return ctx.apiClientInstance.convertSdkToolCallToMcpToolResponse(toolCall, mcpTool)
|
||||
})
|
||||
.filter((t): t is ToolCallResponse => typeof t !== 'undefined')
|
||||
.filter((t): t is MCPToolResponse => typeof t !== 'undefined')
|
||||
|
||||
if (mcpToolResponses.length === 0) {
|
||||
logger.warn(`No valid MCP tool responses to execute`)
|
||||
@@ -1,4 +1,4 @@
|
||||
import { AnthropicAPIClient } from '@renderer/aiCore/clients/anthropic/AnthropicAPIClient'
|
||||
import { AnthropicAPIClient } from '@renderer/aiCore/legacy/clients/anthropic/AnthropicAPIClient'
|
||||
import { AnthropicSdkRawChunk, AnthropicSdkRawOutput } from '@renderer/types/sdk'
|
||||
|
||||
import { AnthropicStreamListener } from '../../clients/types'
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user