Compare commits
230 Commits
main
...
v1.6.0-bet
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
4b7b0a3823 | ||
|
|
7f87fb9c26 | ||
|
|
01c4777691 | ||
|
|
976d246cac | ||
|
|
4e91db43c9 | ||
|
|
8032f79aca | ||
|
|
00602ddc40 | ||
|
|
782953cca1 | ||
|
|
4cceddc179 | ||
|
|
e84d00fb0d | ||
|
|
cefd32ac7a | ||
|
|
5ad170ec51 | ||
|
|
02b85afefb | ||
|
|
e74c5a8ba3 | ||
|
|
e39e349dd3 | ||
|
|
1d331c8ffc | ||
|
|
837e929731 | ||
|
|
6b92e676dc | ||
|
|
c982976fe0 | ||
|
|
461b54c53b | ||
|
|
2f3b1d767d | ||
|
|
895c30057f | ||
|
|
d62e336dad | ||
|
|
7f643d058c | ||
|
|
4cf239f165 | ||
|
|
aaa56da614 | ||
|
|
d468bb6ad3 | ||
|
|
1ef63aed3a | ||
|
|
415fc23fee | ||
|
|
e8bf1552d0 | ||
|
|
c1c91db9d2 | ||
|
|
27fa16daa3 | ||
|
|
e0f86688a2 | ||
|
|
7a5050d2a2 | ||
|
|
e653a52265 | ||
|
|
f76952b0dd | ||
|
|
01d7f784f7 | ||
|
|
4960eb712b | ||
|
|
694ecc5243 | ||
|
|
84f920e54e | ||
|
|
d4a22d3b0c | ||
|
|
ddb203170e | ||
|
|
1460a0b5b9 | ||
|
|
aaa51c435e | ||
|
|
9bde8b3cae | ||
|
|
e56218f3ac | ||
|
|
58874a954c | ||
|
|
6fa82533d5 | ||
|
|
a7b8b40301 | ||
|
|
bdbb2c2c75 | ||
|
|
005cd730b0 | ||
|
|
3d131dc213 | ||
|
|
b5c1530d97 | ||
|
|
a2326ee825 | ||
|
|
70c6478278 | ||
|
|
0944faf8e5 | ||
|
|
6a1918deef | ||
|
|
c9d0265872 | ||
|
|
2ca5116769 | ||
|
|
efeada281a | ||
|
|
c769e3aa41 | ||
|
|
2212aac6f7 | ||
|
|
49cd9d6723 | ||
|
|
1735a9efb6 | ||
|
|
1b86997f14 | ||
|
|
cf777ba62b | ||
|
|
4918628131 | ||
|
|
fee6ad58d1 | ||
|
|
a30df46c40 | ||
|
|
3004f84be3 | ||
|
|
9551c49452 | ||
|
|
e312c84a0e | ||
|
|
3d0fb97475 | ||
|
|
d10ba04047 | ||
|
|
4b7023f855 | ||
|
|
5f096ecf8c | ||
|
|
a12d627b65 | ||
|
|
7bda658022 | ||
|
|
bfcb215c16 | ||
|
|
1b3fcb2e55 | ||
|
|
9c01e24317 | ||
|
|
2ce9314a10 | ||
|
|
0c7e221b4e | ||
|
|
82d4637c9d | ||
|
|
84eef25ff9 | ||
|
|
53dcda6942 | ||
|
|
ead0e22c60 | ||
|
|
417f90df3b | ||
|
|
65c15c6d87 | ||
|
|
ca4e7e3d2b | ||
|
|
d34b640807 | ||
|
|
aa9ed3b9c8 | ||
|
|
d4da7d817d | ||
|
|
179b7af9bd | ||
|
|
5d0ab0a9a1 | ||
|
|
d93a36e5c9 | ||
|
|
c9c0616c91 | ||
|
|
356443babf | ||
|
|
b2c512082f | ||
|
|
0273c58050 | ||
|
|
239c849890 | ||
|
|
bbc472c169 | ||
|
|
b099c9b0b3 | ||
|
|
628919b562 | ||
|
|
d05ed94702 | ||
|
|
02f8e7a857 | ||
|
|
6c093f72d8 | ||
|
|
0bb1001d40 | ||
|
|
376020b23c | ||
|
|
3630133efd | ||
|
|
cb55f7a69b | ||
|
|
ff7ad52ad5 | ||
|
|
bf02afa841 | ||
|
|
e8b059c4db | ||
|
|
abfec7a228 | ||
|
|
eeafb99059 | ||
|
|
1ea8266280 | ||
|
|
def685921c | ||
|
|
a8dbae1715 | ||
|
|
71959f577d | ||
|
|
ecc08bd3f7 | ||
|
|
7216e9943c | ||
|
|
a05d7cbe2d | ||
|
|
0310648445 | ||
|
|
33db455e32 | ||
|
|
e690da840c | ||
|
|
eca9442907 | ||
|
|
4b62384fc5 | ||
|
|
addd5ffdfa | ||
|
|
fcc8836c95 | ||
|
|
61e3309cd2 | ||
|
|
786bc8dca9 | ||
|
|
c3a6456499 | ||
|
|
ef6be4a6f9 | ||
|
|
69e87ce21a | ||
|
|
608943bdbc | ||
|
|
1248e3c49a | ||
|
|
c3ad18b77e | ||
|
|
0bc5e3d24d | ||
|
|
36e20d545b | ||
|
|
45405213fc | ||
|
|
b83837708b | ||
|
|
4732c8f1bd | ||
|
|
ef8cf65ece | ||
|
|
e3c5c87e1b | ||
|
|
e7d5626055 | ||
|
|
650650a68f | ||
|
|
f38e4a87b8 | ||
|
|
a356492d6f | ||
|
|
8863e10df1 | ||
|
|
42bfa281a7 | ||
|
|
e7b4f1f934 | ||
|
|
0456094512 | ||
|
|
da455997ad | ||
|
|
0c4e8228af | ||
|
|
16e0154200 | ||
|
|
3ab904e789 | ||
|
|
42c7ebd193 | ||
|
|
a0623f2187 | ||
|
|
4bfff85dc8 | ||
|
|
8317ad55e7 | ||
|
|
b67cd9d145 | ||
|
|
234514d736 | ||
|
|
450d6228d4 | ||
|
|
3c955e69f1 | ||
|
|
4573e3f48f | ||
|
|
56c5e5a80f | ||
|
|
bb520910bc | ||
|
|
342c5ab82c | ||
|
|
fce8f2411c | ||
|
|
0a908a334b | ||
|
|
c72156b2da | ||
|
|
9e252d7eb0 | ||
|
|
4b0d8d7e65 | ||
|
|
448b5b5c9e | ||
|
|
f20d964be3 | ||
|
|
c92475b6bf | ||
|
|
89cbf80008 | ||
|
|
3e5969b97c | ||
|
|
cd42410d70 | ||
|
|
547e5785c0 | ||
|
|
13162edcb2 | ||
|
|
ac15930692 | ||
|
|
ff3b1fc38f | ||
|
|
b660e9d524 | ||
|
|
182ab6092c | ||
|
|
cf5ed8e858 | ||
|
|
007de81928 | ||
|
|
6c87b42607 | ||
|
|
592a7ddc3f | ||
|
|
60cb198f44 | ||
|
|
54c36040af | ||
|
|
ef616e1c3b | ||
|
|
dc106a8af7 | ||
|
|
1bcc716eaf | ||
|
|
30a288ce5d | ||
|
|
cbbaa3127c | ||
|
|
f61da8c2d6 | ||
|
|
d9eb9e86fe | ||
|
|
87f803b0d3 | ||
|
|
c934b45c09 | ||
|
|
ba121d04b4 | ||
|
|
9293f26612 | ||
|
|
8b67a45804 | ||
|
|
f23a026a28 | ||
|
|
e4c0ea035f | ||
|
|
7d8ed3a737 | ||
|
|
2a588fdab2 | ||
|
|
f08c444ffb | ||
|
|
f6c3794ac9 | ||
|
|
ebe85ba24a | ||
|
|
09080f0755 | ||
|
|
e421b81fca | ||
|
|
2f58b3360e | ||
|
|
f934b479b2 | ||
|
|
8ca6341609 | ||
|
|
c99a2fedb7 | ||
|
|
456e6c068e | ||
|
|
f206d4ec4c | ||
|
|
1af8be8768 | ||
|
|
e70174817e | ||
|
|
c5cb443de0 | ||
|
|
9318d9ffeb | ||
|
|
3771b24b52 | ||
|
|
1bccfd3170 | ||
|
|
43d55b7e45 | ||
|
|
1c5a30cf49 | ||
|
|
2df1cddb43 | ||
|
|
ed2363e561 | ||
|
|
a27d1bf506 |
@@ -7,4 +7,5 @@ tsconfig.*.json
|
|||||||
CHANGELOG*.md
|
CHANGELOG*.md
|
||||||
agents.json
|
agents.json
|
||||||
src/renderer/src/integration/nutstore/sso/lib
|
src/renderer/src/integration/nutstore/sso/lib
|
||||||
|
AGENT.md
|
||||||
src/main/integration/cherryin/index.js
|
src/main/integration/cherryin/index.js
|
||||||
|
|||||||
@@ -121,24 +121,12 @@ afterSign: scripts/notarize.js
|
|||||||
artifactBuildCompleted: scripts/artifact-build-completed.js
|
artifactBuildCompleted: scripts/artifact-build-completed.js
|
||||||
releaseInfo:
|
releaseInfo:
|
||||||
releaseNotes: |
|
releaseNotes: |
|
||||||
✨ 重要更新:
|
|
||||||
- 新增笔记模块,支持富文本编辑和管理
|
|
||||||
- 内置 GLM-4.5-Flash 免费模型(由智谱开放平台提供)
|
|
||||||
- 内置 Qwen3-8B 免费模型(由硅基流动提供)
|
|
||||||
- 新增 Nano Banana(Gemini 2.5 Flash Image)模型支持
|
|
||||||
- 新增系统 OCR 功能 (macOS & Windows)
|
|
||||||
- 新增图片 OCR 识别和翻译功能
|
|
||||||
- 模型切换支持通过标签筛选
|
|
||||||
- 翻译功能增强:历史搜索和收藏
|
|
||||||
|
|
||||||
🔧 性能优化:
|
🔧 性能优化:
|
||||||
- 优化历史页面搜索性能
|
- 优化AI服务连接方式,提升响应速度和稳定性
|
||||||
- 优化拖拽列表组件交互
|
- 改进模型列表获取功能,减少不必要的网络请求
|
||||||
- 升级 Electron 到 37.4.0
|
- 增强各AI服务商的兼容性和连接可靠性
|
||||||
|
|
||||||
🐛 修复问题:
|
🐛 问题修复:
|
||||||
- 修复知识库加密 PDF 文档处理
|
- 修复部分AI服务商连接失败的问题
|
||||||
- 修复导航栏在左侧时笔记侧边栏按钮缺失
|
- 修复模型配置加载时的潜在错误
|
||||||
- 修复多个模型兼容性问题
|
- 提升应用整体稳定性和容错能力
|
||||||
- 修复 MCP 相关问题
|
|
||||||
- 其他稳定性改进
|
|
||||||
|
|||||||
@@ -95,6 +95,9 @@ export default defineConfig({
|
|||||||
'@logger': resolve('src/renderer/src/services/LoggerService'),
|
'@logger': resolve('src/renderer/src/services/LoggerService'),
|
||||||
'@mcp-trace/trace-core': resolve('packages/mcp-trace/trace-core'),
|
'@mcp-trace/trace-core': resolve('packages/mcp-trace/trace-core'),
|
||||||
'@mcp-trace/trace-web': resolve('packages/mcp-trace/trace-web'),
|
'@mcp-trace/trace-web': resolve('packages/mcp-trace/trace-web'),
|
||||||
|
'@cherrystudio/ai-core/provider': resolve('packages/aiCore/src/core/providers'),
|
||||||
|
'@cherrystudio/ai-core/built-in/plugins': resolve('packages/aiCore/src/core/plugins/built-in'),
|
||||||
|
'@cherrystudio/ai-core': resolve('packages/aiCore/src'),
|
||||||
'@cherrystudio/extension-table-plus': resolve('packages/extension-table-plus/src')
|
'@cherrystudio/extension-table-plus': resolve('packages/extension-table-plus/src')
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
|
|||||||
12
package.json
12
package.json
@@ -1,6 +1,6 @@
|
|||||||
{
|
{
|
||||||
"name": "CherryStudio",
|
"name": "CherryStudio",
|
||||||
"version": "1.5.9",
|
"version": "1.6.0-beta.6",
|
||||||
"private": true,
|
"private": true,
|
||||||
"description": "A powerful AI assistant for producer.",
|
"description": "A powerful AI assistant for producer.",
|
||||||
"main": "./out/main/index.js",
|
"main": "./out/main/index.js",
|
||||||
@@ -89,12 +89,16 @@
|
|||||||
"@agentic/exa": "^7.3.3",
|
"@agentic/exa": "^7.3.3",
|
||||||
"@agentic/searxng": "^7.3.3",
|
"@agentic/searxng": "^7.3.3",
|
||||||
"@agentic/tavily": "^7.3.3",
|
"@agentic/tavily": "^7.3.3",
|
||||||
|
"@ai-sdk/amazon-bedrock": "^3.0.0",
|
||||||
|
"@ai-sdk/google-vertex": "^3.0.0",
|
||||||
|
"@ai-sdk/mistral": "^2.0.0",
|
||||||
"@ant-design/v5-patch-for-react-19": "^1.0.3",
|
"@ant-design/v5-patch-for-react-19": "^1.0.3",
|
||||||
"@anthropic-ai/sdk": "^0.41.0",
|
"@anthropic-ai/sdk": "^0.41.0",
|
||||||
"@anthropic-ai/vertex-sdk": "patch:@anthropic-ai/vertex-sdk@npm%3A0.11.4#~/.yarn/patches/@anthropic-ai-vertex-sdk-npm-0.11.4-c19cb41edb.patch",
|
"@anthropic-ai/vertex-sdk": "patch:@anthropic-ai/vertex-sdk@npm%3A0.11.4#~/.yarn/patches/@anthropic-ai-vertex-sdk-npm-0.11.4-c19cb41edb.patch",
|
||||||
"@aws-sdk/client-bedrock": "^3.840.0",
|
"@aws-sdk/client-bedrock": "^3.840.0",
|
||||||
"@aws-sdk/client-bedrock-runtime": "^3.840.0",
|
"@aws-sdk/client-bedrock-runtime": "^3.840.0",
|
||||||
"@aws-sdk/client-s3": "^3.840.0",
|
"@aws-sdk/client-s3": "^3.840.0",
|
||||||
|
"@cherrystudio/ai-core": "workspace:*",
|
||||||
"@cherrystudio/embedjs": "^0.1.31",
|
"@cherrystudio/embedjs": "^0.1.31",
|
||||||
"@cherrystudio/embedjs-libsql": "^0.1.31",
|
"@cherrystudio/embedjs-libsql": "^0.1.31",
|
||||||
"@cherrystudio/embedjs-loader-csv": "^0.1.31",
|
"@cherrystudio/embedjs-loader-csv": "^0.1.31",
|
||||||
@@ -130,6 +134,7 @@
|
|||||||
"@modelcontextprotocol/sdk": "^1.17.0",
|
"@modelcontextprotocol/sdk": "^1.17.0",
|
||||||
"@mozilla/readability": "^0.6.0",
|
"@mozilla/readability": "^0.6.0",
|
||||||
"@notionhq/client": "^2.2.15",
|
"@notionhq/client": "^2.2.15",
|
||||||
|
"@openrouter/ai-sdk-provider": "^1.1.2",
|
||||||
"@opentelemetry/api": "^1.9.0",
|
"@opentelemetry/api": "^1.9.0",
|
||||||
"@opentelemetry/core": "2.0.0",
|
"@opentelemetry/core": "2.0.0",
|
||||||
"@opentelemetry/exporter-trace-otlp-http": "^0.200.0",
|
"@opentelemetry/exporter-trace-otlp-http": "^0.200.0",
|
||||||
@@ -139,7 +144,7 @@
|
|||||||
"@playwright/test": "^1.52.0",
|
"@playwright/test": "^1.52.0",
|
||||||
"@reduxjs/toolkit": "^2.2.5",
|
"@reduxjs/toolkit": "^2.2.5",
|
||||||
"@shikijs/markdown-it": "^3.12.0",
|
"@shikijs/markdown-it": "^3.12.0",
|
||||||
"@swc/plugin-styled-components": "^7.1.5",
|
"@swc/plugin-styled-components": "^8.0.4",
|
||||||
"@tanstack/react-query": "^5.85.5",
|
"@tanstack/react-query": "^5.85.5",
|
||||||
"@tanstack/react-virtual": "^3.13.12",
|
"@tanstack/react-virtual": "^3.13.12",
|
||||||
"@testing-library/dom": "^10.4.0",
|
"@testing-library/dom": "^10.4.0",
|
||||||
@@ -190,6 +195,7 @@
|
|||||||
"@viz-js/lang-dot": "^1.0.5",
|
"@viz-js/lang-dot": "^1.0.5",
|
||||||
"@viz-js/viz": "^3.14.0",
|
"@viz-js/viz": "^3.14.0",
|
||||||
"@xyflow/react": "^12.4.4",
|
"@xyflow/react": "^12.4.4",
|
||||||
|
"ai": "^5.0.29",
|
||||||
"antd": "patch:antd@npm%3A5.27.0#~/.yarn/patches/antd-npm-5.27.0-aa91c36546.patch",
|
"antd": "patch:antd@npm%3A5.27.0#~/.yarn/patches/antd-npm-5.27.0-aa91c36546.patch",
|
||||||
"archiver": "^7.0.1",
|
"archiver": "^7.0.1",
|
||||||
"async-mutex": "^0.5.0",
|
"async-mutex": "^0.5.0",
|
||||||
@@ -338,7 +344,7 @@
|
|||||||
"prettier --write",
|
"prettier --write",
|
||||||
"eslint --fix"
|
"eslint --fix"
|
||||||
],
|
],
|
||||||
"*.{json,md,yml,yaml,css,scss,html}": [
|
"*.{json,yml,yaml,css,scss,html}": [
|
||||||
"prettier --write"
|
"prettier --write"
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
|
|||||||
514
packages/aiCore/AI_SDK_ARCHITECTURE.md
Normal file
514
packages/aiCore/AI_SDK_ARCHITECTURE.md
Normal file
@@ -0,0 +1,514 @@
|
|||||||
|
# AI Core 基于 Vercel AI SDK 的技术架构
|
||||||
|
|
||||||
|
## 1. 架构设计理念
|
||||||
|
|
||||||
|
### 1.1 设计目标
|
||||||
|
|
||||||
|
- **简化分层**:`models`(模型层)→ `runtime`(运行时层),清晰的职责分离
|
||||||
|
- **统一接口**:使用 Vercel AI SDK 统一不同 AI Provider 的接口差异
|
||||||
|
- **动态导入**:通过动态导入实现按需加载,减少打包体积
|
||||||
|
- **最小包装**:直接使用 AI SDK 的类型和接口,避免重复定义
|
||||||
|
- **插件系统**:基于钩子的通用插件架构,支持请求全生命周期扩展
|
||||||
|
- **类型安全**:利用 TypeScript 和 AI SDK 的类型系统确保类型安全
|
||||||
|
- **轻量级**:专注核心功能,保持包的轻量和高效
|
||||||
|
- **包级独立**:作为独立包管理,便于复用和维护
|
||||||
|
- **Agent就绪**:为将来集成 OpenAI Agents SDK 预留扩展空间
|
||||||
|
|
||||||
|
### 1.2 核心优势
|
||||||
|
|
||||||
|
- **标准化**:AI SDK 提供统一的模型接口,减少适配工作
|
||||||
|
- **简化设计**:函数式API,避免过度抽象
|
||||||
|
- **更好的开发体验**:完整的 TypeScript 支持和丰富的生态系统
|
||||||
|
- **性能优化**:AI SDK 内置优化和最佳实践
|
||||||
|
- **模块化设计**:独立包结构,支持跨项目复用
|
||||||
|
- **可扩展插件**:通用的流转换和参数处理插件系统
|
||||||
|
- **面向未来**:为 OpenAI Agents SDK 集成做好准备
|
||||||
|
|
||||||
|
## 2. 整体架构图
|
||||||
|
|
||||||
|
```mermaid
|
||||||
|
graph TD
|
||||||
|
subgraph "用户应用 (如 Cherry Studio)"
|
||||||
|
UI["用户界面"]
|
||||||
|
Components["应用组件"]
|
||||||
|
end
|
||||||
|
|
||||||
|
subgraph "packages/aiCore (AI Core 包)"
|
||||||
|
subgraph "Runtime Layer (运行时层)"
|
||||||
|
RuntimeExecutor["RuntimeExecutor (运行时执行器)"]
|
||||||
|
PluginEngine["PluginEngine (插件引擎)"]
|
||||||
|
RuntimeAPI["Runtime API (便捷函数)"]
|
||||||
|
end
|
||||||
|
|
||||||
|
subgraph "Models Layer (模型层)"
|
||||||
|
ModelFactory["createModel() (模型工厂)"]
|
||||||
|
ProviderCreator["ProviderCreator (提供商创建器)"]
|
||||||
|
end
|
||||||
|
|
||||||
|
subgraph "Core Systems (核心系统)"
|
||||||
|
subgraph "Plugins (插件)"
|
||||||
|
PluginManager["PluginManager (插件管理)"]
|
||||||
|
BuiltInPlugins["Built-in Plugins (内置插件)"]
|
||||||
|
StreamTransforms["Stream Transforms (流转换)"]
|
||||||
|
end
|
||||||
|
|
||||||
|
subgraph "Middleware (中间件)"
|
||||||
|
MiddlewareWrapper["wrapModelWithMiddlewares() (中间件包装)"]
|
||||||
|
end
|
||||||
|
|
||||||
|
subgraph "Providers (提供商)"
|
||||||
|
Registry["Provider Registry (注册表)"]
|
||||||
|
Factory["Provider Factory (工厂)"]
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
subgraph "Vercel AI SDK"
|
||||||
|
AICore["ai (核心库)"]
|
||||||
|
OpenAI["@ai-sdk/openai"]
|
||||||
|
Anthropic["@ai-sdk/anthropic"]
|
||||||
|
Google["@ai-sdk/google"]
|
||||||
|
XAI["@ai-sdk/xai"]
|
||||||
|
Others["其他 19+ Providers"]
|
||||||
|
end
|
||||||
|
|
||||||
|
subgraph "Future: OpenAI Agents SDK"
|
||||||
|
AgentSDK["@openai/agents (未来集成)"]
|
||||||
|
AgentExtensions["Agent Extensions (预留)"]
|
||||||
|
end
|
||||||
|
|
||||||
|
UI --> RuntimeAPI
|
||||||
|
Components --> RuntimeExecutor
|
||||||
|
RuntimeAPI --> RuntimeExecutor
|
||||||
|
RuntimeExecutor --> PluginEngine
|
||||||
|
RuntimeExecutor --> ModelFactory
|
||||||
|
PluginEngine --> PluginManager
|
||||||
|
ModelFactory --> ProviderCreator
|
||||||
|
ModelFactory --> MiddlewareWrapper
|
||||||
|
ProviderCreator --> Registry
|
||||||
|
Registry --> Factory
|
||||||
|
Factory --> OpenAI
|
||||||
|
Factory --> Anthropic
|
||||||
|
Factory --> Google
|
||||||
|
Factory --> XAI
|
||||||
|
Factory --> Others
|
||||||
|
|
||||||
|
RuntimeExecutor --> AICore
|
||||||
|
AICore --> streamText
|
||||||
|
AICore --> generateText
|
||||||
|
AICore --> streamObject
|
||||||
|
AICore --> generateObject
|
||||||
|
|
||||||
|
PluginManager --> StreamTransforms
|
||||||
|
PluginManager --> BuiltInPlugins
|
||||||
|
|
||||||
|
%% 未来集成路径
|
||||||
|
RuntimeExecutor -.-> AgentSDK
|
||||||
|
AgentSDK -.-> AgentExtensions
|
||||||
|
```
|
||||||
|
|
||||||
|
## 3. 包结构设计
|
||||||
|
|
||||||
|
### 3.1 新架构文件结构
|
||||||
|
|
||||||
|
```
|
||||||
|
packages/aiCore/
|
||||||
|
├── src/
|
||||||
|
│ ├── core/ # 核心层 - 内部实现
|
||||||
|
│ │ ├── models/ # 模型层 - 模型创建和配置
|
||||||
|
│ │ │ ├── factory.ts # 模型工厂函数 ✅
|
||||||
|
│ │ │ ├── ModelCreator.ts # 模型创建器 ✅
|
||||||
|
│ │ │ ├── ConfigManager.ts # 配置管理器 ✅
|
||||||
|
│ │ │ ├── types.ts # 模型类型定义 ✅
|
||||||
|
│ │ │ └── index.ts # 模型层导出 ✅
|
||||||
|
│ │ ├── runtime/ # 运行时层 - 执行和用户API
|
||||||
|
│ │ │ ├── executor.ts # 运行时执行器 ✅
|
||||||
|
│ │ │ ├── pluginEngine.ts # 插件引擎 ✅
|
||||||
|
│ │ │ ├── types.ts # 运行时类型定义 ✅
|
||||||
|
│ │ │ └── index.ts # 运行时导出 ✅
|
||||||
|
│ │ ├── middleware/ # 中间件系统
|
||||||
|
│ │ │ ├── wrapper.ts # 模型包装器 ✅
|
||||||
|
│ │ │ ├── manager.ts # 中间件管理器 ✅
|
||||||
|
│ │ │ ├── types.ts # 中间件类型 ✅
|
||||||
|
│ │ │ └── index.ts # 中间件导出 ✅
|
||||||
|
│ │ ├── plugins/ # 插件系统
|
||||||
|
│ │ │ ├── types.ts # 插件类型定义 ✅
|
||||||
|
│ │ │ ├── manager.ts # 插件管理器 ✅
|
||||||
|
│ │ │ ├── built-in/ # 内置插件 ✅
|
||||||
|
│ │ │ │ ├── logging.ts # 日志插件 ✅
|
||||||
|
│ │ │ │ ├── webSearchPlugin/ # 网络搜索插件 ✅
|
||||||
|
│ │ │ │ ├── toolUsePlugin/ # 工具使用插件 ✅
|
||||||
|
│ │ │ │ └── index.ts # 内置插件导出 ✅
|
||||||
|
│ │ │ ├── README.md # 插件文档 ✅
|
||||||
|
│ │ │ └── index.ts # 插件导出 ✅
|
||||||
|
│ │ ├── providers/ # 提供商管理
|
||||||
|
│ │ │ ├── registry.ts # 提供商注册表 ✅
|
||||||
|
│ │ │ ├── factory.ts # 提供商工厂 ✅
|
||||||
|
│ │ │ ├── creator.ts # 提供商创建器 ✅
|
||||||
|
│ │ │ ├── types.ts # 提供商类型 ✅
|
||||||
|
│ │ │ ├── utils.ts # 工具函数 ✅
|
||||||
|
│ │ │ └── index.ts # 提供商导出 ✅
|
||||||
|
│ │ ├── options/ # 配置选项
|
||||||
|
│ │ │ ├── factory.ts # 选项工厂 ✅
|
||||||
|
│ │ │ ├── types.ts # 选项类型 ✅
|
||||||
|
│ │ │ ├── xai.ts # xAI 选项 ✅
|
||||||
|
│ │ │ ├── openrouter.ts # OpenRouter 选项 ✅
|
||||||
|
│ │ │ ├── examples.ts # 示例配置 ✅
|
||||||
|
│ │ │ └── index.ts # 选项导出 ✅
|
||||||
|
│ │ └── index.ts # 核心层导出 ✅
|
||||||
|
│ ├── types.ts # 全局类型定义 ✅
|
||||||
|
│ └── index.ts # 包主入口文件 ✅
|
||||||
|
├── package.json # 包配置文件 ✅
|
||||||
|
├── tsconfig.json # TypeScript 配置 ✅
|
||||||
|
├── README.md # 包说明文档 ✅
|
||||||
|
└── AI_SDK_ARCHITECTURE.md # 本文档 ✅
|
||||||
|
```
|
||||||
|
|
||||||
|
## 4. 架构分层详解
|
||||||
|
|
||||||
|
### 4.1 Models Layer (模型层)
|
||||||
|
|
||||||
|
**职责**:统一的模型创建和配置管理
|
||||||
|
|
||||||
|
**核心文件**:
|
||||||
|
|
||||||
|
- `factory.ts`: 模型工厂函数 (`createModel`, `createModels`)
|
||||||
|
- `ProviderCreator.ts`: 底层提供商创建和模型实例化
|
||||||
|
- `types.ts`: 模型配置类型定义
|
||||||
|
|
||||||
|
**设计特点**:
|
||||||
|
|
||||||
|
- 函数式设计,避免不必要的类抽象
|
||||||
|
- 统一的模型配置接口
|
||||||
|
- 自动处理中间件应用
|
||||||
|
- 支持批量模型创建
|
||||||
|
|
||||||
|
**核心API**:
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
// 模型配置接口
|
||||||
|
export interface ModelConfig {
|
||||||
|
providerId: ProviderId
|
||||||
|
modelId: string
|
||||||
|
options: ProviderSettingsMap[ProviderId]
|
||||||
|
middlewares?: LanguageModelV1Middleware[]
|
||||||
|
}
|
||||||
|
|
||||||
|
// 核心模型创建函数
|
||||||
|
export async function createModel(config: ModelConfig): Promise<LanguageModel>
|
||||||
|
export async function createModels(configs: ModelConfig[]): Promise<LanguageModel[]>
|
||||||
|
```
|
||||||
|
|
||||||
|
### 4.2 Runtime Layer (运行时层)
|
||||||
|
|
||||||
|
**职责**:运行时执行器和用户面向的API接口
|
||||||
|
|
||||||
|
**核心组件**:
|
||||||
|
|
||||||
|
- `executor.ts`: 运行时执行器类
|
||||||
|
- `plugin-engine.ts`: 插件引擎(原PluginEnabledAiClient)
|
||||||
|
- `index.ts`: 便捷函数和工厂方法
|
||||||
|
|
||||||
|
**设计特点**:
|
||||||
|
|
||||||
|
- 提供三种使用方式:类实例、静态工厂、函数式调用
|
||||||
|
- 自动集成模型创建和插件处理
|
||||||
|
- 完整的类型安全支持
|
||||||
|
- 为 OpenAI Agents SDK 预留扩展接口
|
||||||
|
|
||||||
|
**核心API**:
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
// 运行时执行器
|
||||||
|
export class RuntimeExecutor<T extends ProviderId = ProviderId> {
|
||||||
|
static create<T extends ProviderId>(
|
||||||
|
providerId: T,
|
||||||
|
options: ProviderSettingsMap[T],
|
||||||
|
plugins?: AiPlugin[]
|
||||||
|
): RuntimeExecutor<T>
|
||||||
|
|
||||||
|
async streamText(modelId: string, params: StreamTextParams): Promise<StreamTextResult>
|
||||||
|
async generateText(modelId: string, params: GenerateTextParams): Promise<GenerateTextResult>
|
||||||
|
async streamObject(modelId: string, params: StreamObjectParams): Promise<StreamObjectResult>
|
||||||
|
async generateObject(modelId: string, params: GenerateObjectParams): Promise<GenerateObjectResult>
|
||||||
|
}
|
||||||
|
|
||||||
|
// 便捷函数式API
|
||||||
|
export async function streamText<T extends ProviderId>(
|
||||||
|
providerId: T,
|
||||||
|
options: ProviderSettingsMap[T],
|
||||||
|
modelId: string,
|
||||||
|
params: StreamTextParams,
|
||||||
|
plugins?: AiPlugin[]
|
||||||
|
): Promise<StreamTextResult>
|
||||||
|
```
|
||||||
|
|
||||||
|
### 4.3 Plugin System (插件系统)
|
||||||
|
|
||||||
|
**职责**:可扩展的插件架构
|
||||||
|
|
||||||
|
**核心组件**:
|
||||||
|
|
||||||
|
- `PluginManager`: 插件生命周期管理
|
||||||
|
- `built-in/`: 内置插件集合
|
||||||
|
- 流转换收集和应用
|
||||||
|
|
||||||
|
**设计特点**:
|
||||||
|
|
||||||
|
- 借鉴 Rollup 的钩子分类设计
|
||||||
|
- 支持流转换 (`experimental_transform`)
|
||||||
|
- 内置常用插件(日志、计数等)
|
||||||
|
- 完整的生命周期钩子
|
||||||
|
|
||||||
|
**插件接口**:
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
export interface AiPlugin {
|
||||||
|
name: string
|
||||||
|
enforce?: 'pre' | 'post'
|
||||||
|
|
||||||
|
// 【First】首个钩子 - 只执行第一个返回值的插件
|
||||||
|
resolveModel?: (modelId: string, context: AiRequestContext) => string | null | Promise<string | null>
|
||||||
|
loadTemplate?: (templateName: string, context: AiRequestContext) => any | null | Promise<any | null>
|
||||||
|
|
||||||
|
// 【Sequential】串行钩子 - 链式执行,支持数据转换
|
||||||
|
transformParams?: (params: any, context: AiRequestContext) => any | Promise<any>
|
||||||
|
transformResult?: (result: any, context: AiRequestContext) => any | Promise<any>
|
||||||
|
|
||||||
|
// 【Parallel】并行钩子 - 不依赖顺序,用于副作用
|
||||||
|
onRequestStart?: (context: AiRequestContext) => void | Promise<void>
|
||||||
|
onRequestEnd?: (context: AiRequestContext, result: any) => void | Promise<void>
|
||||||
|
onError?: (error: Error, context: AiRequestContext) => void | Promise<void>
|
||||||
|
|
||||||
|
// 【Stream】流处理
|
||||||
|
transformStream?: () => TransformStream
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### 4.4 Middleware System (中间件系统)
|
||||||
|
|
||||||
|
**职责**:AI SDK原生中间件支持
|
||||||
|
|
||||||
|
**核心组件**:
|
||||||
|
|
||||||
|
- `ModelWrapper.ts`: 模型包装函数
|
||||||
|
|
||||||
|
**设计哲学**:
|
||||||
|
|
||||||
|
- 直接使用AI SDK的 `wrapLanguageModel`
|
||||||
|
- 与插件系统分离,职责明确
|
||||||
|
- 函数式设计,简化使用
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
export function wrapModelWithMiddlewares(model: LanguageModel, middlewares: LanguageModelV1Middleware[]): LanguageModel
|
||||||
|
```
|
||||||
|
|
||||||
|
### 4.5 Provider System (提供商系统)
|
||||||
|
|
||||||
|
**职责**:AI Provider注册表和动态导入
|
||||||
|
|
||||||
|
**核心组件**:
|
||||||
|
|
||||||
|
- `registry.ts`: 19+ Provider配置和类型
|
||||||
|
- `factory.ts`: Provider配置工厂
|
||||||
|
|
||||||
|
**支持的Providers**:
|
||||||
|
|
||||||
|
- OpenAI, Anthropic, Google, XAI
|
||||||
|
- Azure OpenAI, Amazon Bedrock, Google Vertex
|
||||||
|
- Groq, Together.ai, Fireworks, DeepSeek
|
||||||
|
- 等19+ AI SDK官方支持的providers
|
||||||
|
|
||||||
|
## 5. 使用方式
|
||||||
|
|
||||||
|
### 5.1 函数式调用 (推荐 - 简单场景)
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
import { streamText, generateText } from '@cherrystudio/ai-core/runtime'
|
||||||
|
|
||||||
|
// 直接函数调用
|
||||||
|
const stream = await streamText(
|
||||||
|
'anthropic',
|
||||||
|
{ apiKey: 'your-api-key' },
|
||||||
|
'claude-3',
|
||||||
|
{ messages: [{ role: 'user', content: 'Hello!' }] },
|
||||||
|
[loggingPlugin]
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
### 5.2 执行器实例 (推荐 - 复杂场景)
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
import { createExecutor } from '@cherrystudio/ai-core/runtime'
|
||||||
|
|
||||||
|
// 创建可复用的执行器
|
||||||
|
const executor = createExecutor('openai', { apiKey: 'your-api-key' }, [plugin1, plugin2])
|
||||||
|
|
||||||
|
// 多次使用
|
||||||
|
const stream = await executor.streamText('gpt-4', {
|
||||||
|
messages: [{ role: 'user', content: 'Hello!' }]
|
||||||
|
})
|
||||||
|
|
||||||
|
const result = await executor.generateText('gpt-4', {
|
||||||
|
messages: [{ role: 'user', content: 'How are you?' }]
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
|
### 5.3 静态工厂方法
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
import { RuntimeExecutor } from '@cherrystudio/ai-core/runtime'
|
||||||
|
|
||||||
|
// 静态创建
|
||||||
|
const executor = RuntimeExecutor.create('anthropic', { apiKey: 'your-api-key' })
|
||||||
|
await executor.streamText('claude-3', { messages: [...] })
|
||||||
|
```
|
||||||
|
|
||||||
|
### 5.4 直接模型创建 (高级用法)
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
import { createModel } from '@cherrystudio/ai-core/models'
|
||||||
|
import { streamText } from 'ai'
|
||||||
|
|
||||||
|
// 直接创建模型使用
|
||||||
|
const model = await createModel({
|
||||||
|
providerId: 'openai',
|
||||||
|
modelId: 'gpt-4',
|
||||||
|
options: { apiKey: 'your-api-key' },
|
||||||
|
middlewares: [middleware1, middleware2]
|
||||||
|
})
|
||||||
|
|
||||||
|
// 直接使用 AI SDK
|
||||||
|
const result = await streamText({ model, messages: [...] })
|
||||||
|
```
|
||||||
|
|
||||||
|
## 6. 为 OpenAI Agents SDK 预留的设计
|
||||||
|
|
||||||
|
### 6.1 架构兼容性
|
||||||
|
|
||||||
|
当前架构完全兼容 OpenAI Agents SDK 的集成需求:
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
// 当前的模型创建
|
||||||
|
const model = await createModel({
|
||||||
|
providerId: 'anthropic',
|
||||||
|
modelId: 'claude-3',
|
||||||
|
options: { apiKey: 'xxx' }
|
||||||
|
})
|
||||||
|
|
||||||
|
// 将来可以直接用于 OpenAI Agents SDK
|
||||||
|
import { Agent, run } from '@openai/agents'
|
||||||
|
|
||||||
|
const agent = new Agent({
|
||||||
|
model, // ✅ 直接兼容 LanguageModel 接口
|
||||||
|
name: 'Assistant',
|
||||||
|
instructions: '...',
|
||||||
|
tools: [tool1, tool2]
|
||||||
|
})
|
||||||
|
|
||||||
|
const result = await run(agent, 'user input')
|
||||||
|
```
|
||||||
|
|
||||||
|
### 6.2 预留的扩展点
|
||||||
|
|
||||||
|
1. **runtime/agents/** 目录预留
|
||||||
|
2. **AgentExecutor** 类预留
|
||||||
|
3. **Agent工具转换插件** 预留
|
||||||
|
4. **多Agent编排** 预留
|
||||||
|
|
||||||
|
### 6.3 未来架构扩展
|
||||||
|
|
||||||
|
```
|
||||||
|
packages/aiCore/src/core/
|
||||||
|
├── runtime/
|
||||||
|
│ ├── agents/ # 🚀 未来添加
|
||||||
|
│ │ ├── AgentExecutor.ts
|
||||||
|
│ │ ├── WorkflowManager.ts
|
||||||
|
│ │ └── ConversationManager.ts
|
||||||
|
│ ├── executor.ts
|
||||||
|
│ └── index.ts
|
||||||
|
```
|
||||||
|
|
||||||
|
## 7. 架构优势
|
||||||
|
|
||||||
|
### 7.1 简化设计
|
||||||
|
|
||||||
|
- **移除过度抽象**:删除了orchestration层和creation层的复杂包装
|
||||||
|
- **函数式优先**:models层使用函数而非类
|
||||||
|
- **直接明了**:runtime层直接提供用户API
|
||||||
|
|
||||||
|
### 7.2 职责清晰
|
||||||
|
|
||||||
|
- **Models**: 专注模型创建和配置
|
||||||
|
- **Runtime**: 专注执行和用户API
|
||||||
|
- **Plugins**: 专注扩展功能
|
||||||
|
- **Providers**: 专注AI Provider管理
|
||||||
|
|
||||||
|
### 7.3 类型安全
|
||||||
|
|
||||||
|
- 完整的 TypeScript 支持
|
||||||
|
- AI SDK 类型的直接复用
|
||||||
|
- 避免类型重复定义
|
||||||
|
|
||||||
|
### 7.4 灵活使用
|
||||||
|
|
||||||
|
- 三种使用模式满足不同需求
|
||||||
|
- 从简单函数调用到复杂执行器
|
||||||
|
- 支持直接AI SDK使用
|
||||||
|
|
||||||
|
### 7.5 面向未来
|
||||||
|
|
||||||
|
- 为 OpenAI Agents SDK 集成做好准备
|
||||||
|
- 清晰的扩展点和架构边界
|
||||||
|
- 模块化设计便于功能添加
|
||||||
|
|
||||||
|
## 8. 技术决策记录
|
||||||
|
|
||||||
|
### 8.1 为什么选择简化的两层架构?
|
||||||
|
|
||||||
|
- **职责分离**:models专注创建,runtime专注执行
|
||||||
|
- **模块化**:每层都有清晰的边界和职责
|
||||||
|
- **扩展性**:为Agent功能预留了清晰的扩展空间
|
||||||
|
|
||||||
|
### 8.2 为什么选择函数式设计?
|
||||||
|
|
||||||
|
- **简洁性**:避免不必要的类设计
|
||||||
|
- **性能**:减少对象创建开销
|
||||||
|
- **易用性**:函数调用更直观
|
||||||
|
|
||||||
|
### 8.3 为什么分离插件和中间件?
|
||||||
|
|
||||||
|
- **职责明确**: 插件处理应用特定需求
|
||||||
|
- **原生支持**: 中间件使用AI SDK原生功能
|
||||||
|
- **灵活性**: 两套系统可以独立演进
|
||||||
|
|
||||||
|
## 9. 总结
|
||||||
|
|
||||||
|
AI Core架构实现了:
|
||||||
|
|
||||||
|
### 9.1 核心特点
|
||||||
|
|
||||||
|
- ✅ **简化架构**: 2层核心架构,职责清晰
|
||||||
|
- ✅ **函数式设计**: models层完全函数化
|
||||||
|
- ✅ **类型安全**: 统一的类型定义和AI SDK类型复用
|
||||||
|
- ✅ **插件扩展**: 强大的插件系统
|
||||||
|
- ✅ **多种使用方式**: 满足不同复杂度需求
|
||||||
|
- ✅ **Agent就绪**: 为OpenAI Agents SDK集成做好准备
|
||||||
|
|
||||||
|
### 9.2 核心价值
|
||||||
|
|
||||||
|
- **统一接口**: 一套API支持19+ AI providers
|
||||||
|
- **灵活使用**: 函数式、实例式、静态工厂式
|
||||||
|
- **强类型**: 完整的TypeScript支持
|
||||||
|
- **可扩展**: 插件和中间件双重扩展能力
|
||||||
|
- **高性能**: 最小化包装,直接使用AI SDK
|
||||||
|
- **面向未来**: Agent SDK集成架构就绪
|
||||||
|
|
||||||
|
### 9.3 未来发展
|
||||||
|
|
||||||
|
这个架构提供了:
|
||||||
|
|
||||||
|
- **优秀的开发体验**: 简洁的API和清晰的使用模式
|
||||||
|
- **强大的扩展能力**: 为Agent功能预留了完整的架构空间
|
||||||
|
- **良好的维护性**: 职责分离明确,代码易于维护
|
||||||
|
- **广泛的适用性**: 既适合简单调用也适合复杂应用
|
||||||
433
packages/aiCore/README.md
Normal file
433
packages/aiCore/README.md
Normal file
@@ -0,0 +1,433 @@
|
|||||||
|
# @cherrystudio/ai-core
|
||||||
|
|
||||||
|
Cherry Studio AI Core 是一个基于 Vercel AI SDK 的统一 AI Provider 接口包,为 AI 应用提供强大的抽象层和插件化架构。
|
||||||
|
|
||||||
|
## ✨ 核心亮点
|
||||||
|
|
||||||
|
### 🏗️ 优雅的架构设计
|
||||||
|
|
||||||
|
- **简化分层**:`models`(模型层)→ `runtime`(运行时层),清晰的职责分离
|
||||||
|
- **函数式优先**:避免过度抽象,提供简洁直观的 API
|
||||||
|
- **类型安全**:完整的 TypeScript 支持,直接复用 AI SDK 类型系统
|
||||||
|
- **最小包装**:直接使用 AI SDK 的接口,避免重复定义和性能损耗
|
||||||
|
|
||||||
|
### 🔌 强大的插件系统
|
||||||
|
|
||||||
|
- **生命周期钩子**:支持请求全生命周期的扩展点
|
||||||
|
- **流转换支持**:基于 AI SDK 的 `experimental_transform` 实现流处理
|
||||||
|
- **插件分类**:First、Sequential、Parallel 三种钩子类型,满足不同场景
|
||||||
|
- **内置插件**:webSearch、logging、toolUse 等开箱即用的功能
|
||||||
|
|
||||||
|
### 🌐 统一多 Provider 接口
|
||||||
|
|
||||||
|
- **扩展注册**:支持自定义 Provider 注册,无限扩展能力
|
||||||
|
- **配置统一**:统一的配置接口,简化多 Provider 管理
|
||||||
|
|
||||||
|
### 🚀 多种使用方式
|
||||||
|
|
||||||
|
- **函数式调用**:适合简单场景的直接函数调用
|
||||||
|
- **执行器实例**:适合复杂场景的可复用执行器
|
||||||
|
- **静态工厂**:便捷的静态创建方法
|
||||||
|
- **原生兼容**:完全兼容 AI SDK 原生 Provider Registry
|
||||||
|
|
||||||
|
### 🔮 面向未来
|
||||||
|
|
||||||
|
- **Agent 就绪**:为 OpenAI Agents SDK 集成预留架构空间
|
||||||
|
- **模块化设计**:独立包结构,支持跨项目复用
|
||||||
|
- **渐进式迁移**:可以逐步从现有 AI SDK 代码迁移
|
||||||
|
|
||||||
|
## 特性
|
||||||
|
|
||||||
|
- 🚀 统一的 AI Provider 接口
|
||||||
|
- 🔄 动态导入支持
|
||||||
|
- 🛠️ TypeScript 支持
|
||||||
|
- 📦 强大的插件系统
|
||||||
|
- 🌍 内置webSearch(Openai,Google,Anthropic,xAI)
|
||||||
|
- 🎯 多种使用模式(函数式/实例式/静态工厂)
|
||||||
|
- 🔌 可扩展的 Provider 注册系统
|
||||||
|
- 🧩 完整的中间件支持
|
||||||
|
- 📊 插件统计和调试功能
|
||||||
|
|
||||||
|
## 支持的 Providers
|
||||||
|
|
||||||
|
基于 [AI SDK 官方支持的 providers](https://ai-sdk.dev/providers/ai-sdk-providers):
|
||||||
|
|
||||||
|
**核心 Providers(内置支持):**
|
||||||
|
|
||||||
|
- OpenAI
|
||||||
|
- Anthropic
|
||||||
|
- Google Generative AI
|
||||||
|
- OpenAI-Compatible
|
||||||
|
- xAI (Grok)
|
||||||
|
- Azure OpenAI
|
||||||
|
- DeepSeek
|
||||||
|
|
||||||
|
**扩展 Providers(通过注册API支持):**
|
||||||
|
|
||||||
|
- Google Vertex AI
|
||||||
|
- ...
|
||||||
|
- 自定义 Provider
|
||||||
|
|
||||||
|
## 安装
|
||||||
|
|
||||||
|
```bash
|
||||||
|
npm install @cherrystudio/ai-core ai
|
||||||
|
```
|
||||||
|
|
||||||
|
### React Native
|
||||||
|
|
||||||
|
如果你在 React Native 项目中使用此包,需要在 `metro.config.js` 中添加以下配置:
|
||||||
|
|
||||||
|
```javascript
|
||||||
|
// metro.config.js
|
||||||
|
const { getDefaultConfig } = require('expo/metro-config')
|
||||||
|
|
||||||
|
const config = getDefaultConfig(__dirname)
|
||||||
|
|
||||||
|
// 添加对 @cherrystudio/ai-core 的支持
|
||||||
|
config.resolver.resolverMainFields = ['react-native', 'browser', 'main']
|
||||||
|
config.resolver.platforms = ['ios', 'android', 'native', 'web']
|
||||||
|
|
||||||
|
module.exports = config
|
||||||
|
```
|
||||||
|
|
||||||
|
还需要安装你要使用的 AI SDK provider:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
npm install @ai-sdk/openai @ai-sdk/anthropic @ai-sdk/google
|
||||||
|
```
|
||||||
|
|
||||||
|
## 使用示例
|
||||||
|
|
||||||
|
### 基础用法
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
import { AiCore } from '@cherrystudio/ai-core'
|
||||||
|
|
||||||
|
// 创建 OpenAI executor
|
||||||
|
const executor = AiCore.create('openai', {
|
||||||
|
apiKey: 'your-api-key'
|
||||||
|
})
|
||||||
|
|
||||||
|
// 流式生成
|
||||||
|
const result = await executor.streamText('gpt-4', {
|
||||||
|
messages: [{ role: 'user', content: 'Hello!' }]
|
||||||
|
})
|
||||||
|
|
||||||
|
// 非流式生成
|
||||||
|
const response = await executor.generateText('gpt-4', {
|
||||||
|
messages: [{ role: 'user', content: 'Hello!' }]
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
|
### 便捷函数
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
import { createOpenAIExecutor } from '@cherrystudio/ai-core'
|
||||||
|
|
||||||
|
// 快速创建 OpenAI executor
|
||||||
|
const executor = createOpenAIExecutor({
|
||||||
|
apiKey: 'your-api-key'
|
||||||
|
})
|
||||||
|
|
||||||
|
// 使用 executor
|
||||||
|
const result = await executor.streamText('gpt-4', {
|
||||||
|
messages: [{ role: 'user', content: 'Hello!' }]
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
|
### 多 Provider 支持
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
import { AiCore } from '@cherrystudio/ai-core'
|
||||||
|
|
||||||
|
// 支持多种 AI providers
|
||||||
|
const openaiExecutor = AiCore.create('openai', { apiKey: 'openai-key' })
|
||||||
|
const anthropicExecutor = AiCore.create('anthropic', { apiKey: 'anthropic-key' })
|
||||||
|
const googleExecutor = AiCore.create('google', { apiKey: 'google-key' })
|
||||||
|
const xaiExecutor = AiCore.create('xai', { apiKey: 'xai-key' })
|
||||||
|
```
|
||||||
|
|
||||||
|
### 扩展 Provider 注册
|
||||||
|
|
||||||
|
对于非内置的 providers,可以通过注册 API 扩展支持:
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
import { registerProvider, AiCore } from '@cherrystudio/ai-core'
|
||||||
|
|
||||||
|
// 方式一:导入并注册第三方 provider
|
||||||
|
import { createGroq } from '@ai-sdk/groq'
|
||||||
|
|
||||||
|
registerProvider({
|
||||||
|
id: 'groq',
|
||||||
|
name: 'Groq',
|
||||||
|
creator: createGroq,
|
||||||
|
supportsImageGeneration: false
|
||||||
|
})
|
||||||
|
|
||||||
|
// 现在可以使用 Groq
|
||||||
|
const groqExecutor = AiCore.create('groq', { apiKey: 'groq-key' })
|
||||||
|
|
||||||
|
// 方式二:动态导入方式注册
|
||||||
|
registerProvider({
|
||||||
|
id: 'mistral',
|
||||||
|
name: 'Mistral AI',
|
||||||
|
import: () => import('@ai-sdk/mistral'),
|
||||||
|
creatorFunctionName: 'createMistral'
|
||||||
|
})
|
||||||
|
|
||||||
|
const mistralExecutor = AiCore.create('mistral', { apiKey: 'mistral-key' })
|
||||||
|
```
|
||||||
|
|
||||||
|
## 🔌 插件系统
|
||||||
|
|
||||||
|
AI Core 提供了强大的插件系统,支持请求全生命周期的扩展。
|
||||||
|
|
||||||
|
### 内置插件
|
||||||
|
|
||||||
|
#### webSearchPlugin - 网络搜索插件
|
||||||
|
|
||||||
|
为不同 AI Provider 提供统一的网络搜索能力:
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
import { webSearchPlugin } from '@cherrystudio/ai-core/built-in/plugins'
|
||||||
|
|
||||||
|
const executor = AiCore.create('openai', { apiKey: 'your-key' }, [
|
||||||
|
webSearchPlugin({
|
||||||
|
openai: {
|
||||||
|
/* OpenAI 搜索配置 */
|
||||||
|
},
|
||||||
|
anthropic: { maxUses: 5 },
|
||||||
|
google: {
|
||||||
|
/* Google 搜索配置 */
|
||||||
|
},
|
||||||
|
xai: {
|
||||||
|
mode: 'on',
|
||||||
|
returnCitations: true,
|
||||||
|
maxSearchResults: 5,
|
||||||
|
sources: [{ type: 'web' }, { type: 'x' }, { type: 'news' }]
|
||||||
|
}
|
||||||
|
})
|
||||||
|
])
|
||||||
|
```
|
||||||
|
|
||||||
|
#### loggingPlugin - 日志插件
|
||||||
|
|
||||||
|
提供详细的请求日志记录:
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
import { createLoggingPlugin } from '@cherrystudio/ai-core/built-in/plugins'
|
||||||
|
|
||||||
|
const executor = AiCore.create('openai', { apiKey: 'your-key' }, [
|
||||||
|
createLoggingPlugin({
|
||||||
|
logLevel: 'info',
|
||||||
|
includeParams: true,
|
||||||
|
includeResult: false
|
||||||
|
})
|
||||||
|
])
|
||||||
|
```
|
||||||
|
|
||||||
|
#### promptToolUsePlugin - 提示工具使用插件
|
||||||
|
|
||||||
|
为不支持原生 Function Call 的模型提供 prompt 方式的工具调用:
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
import { createPromptToolUsePlugin } from '@cherrystudio/ai-core/built-in/plugins'
|
||||||
|
|
||||||
|
// 对于不支持 function call 的模型
|
||||||
|
const executor = AiCore.create(
|
||||||
|
'providerId',
|
||||||
|
{
|
||||||
|
apiKey: 'your-key',
|
||||||
|
baseURL: 'https://your-model-endpoint'
|
||||||
|
},
|
||||||
|
[
|
||||||
|
createPromptToolUsePlugin({
|
||||||
|
enabled: true,
|
||||||
|
// 可选:自定义系统提示符构建
|
||||||
|
buildSystemPrompt: (userPrompt, tools) => {
|
||||||
|
return `${userPrompt}\n\nAvailable tools: ${Object.keys(tools).join(', ')}`
|
||||||
|
}
|
||||||
|
})
|
||||||
|
]
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
### 自定义插件
|
||||||
|
|
||||||
|
创建自定义插件非常简单:
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
import { definePlugin } from '@cherrystudio/ai-core'
|
||||||
|
|
||||||
|
const customPlugin = definePlugin({
|
||||||
|
name: 'custom-plugin',
|
||||||
|
enforce: 'pre', // 'pre' | 'post' | undefined
|
||||||
|
|
||||||
|
// 在请求开始时记录日志
|
||||||
|
onRequestStart: async (context) => {
|
||||||
|
console.log(`Starting request for model: ${context.modelId}`)
|
||||||
|
},
|
||||||
|
|
||||||
|
// 转换请求参数
|
||||||
|
transformParams: async (params, context) => {
|
||||||
|
// 添加自定义系统消息
|
||||||
|
if (params.messages) {
|
||||||
|
params.messages.unshift({
|
||||||
|
role: 'system',
|
||||||
|
content: 'You are a helpful assistant.'
|
||||||
|
})
|
||||||
|
}
|
||||||
|
return params
|
||||||
|
},
|
||||||
|
|
||||||
|
// 处理响应结果
|
||||||
|
transformResult: async (result, context) => {
|
||||||
|
// 添加元数据
|
||||||
|
if (result.text) {
|
||||||
|
result.metadata = {
|
||||||
|
processedAt: new Date().toISOString(),
|
||||||
|
modelId: context.modelId
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
// 使用自定义插件
|
||||||
|
const executor = AiCore.create('openai', { apiKey: 'your-key' }, [customPlugin])
|
||||||
|
```
|
||||||
|
|
||||||
|
### 使用 AI SDK 原生 Provider 注册表
|
||||||
|
|
||||||
|
> https://ai-sdk.dev/docs/reference/ai-sdk-core/provider-registry
|
||||||
|
|
||||||
|
除了使用内建的 provider 管理,你还可以使用 AI SDK 原生的 `createProviderRegistry` 来构建自己的 provider 注册表。
|
||||||
|
|
||||||
|
#### 基本用法示例
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
import { createClient } from '@cherrystudio/ai-core'
|
||||||
|
import { createProviderRegistry } from 'ai'
|
||||||
|
import { createOpenAI } from '@ai-sdk/openai'
|
||||||
|
import { anthropic } from '@ai-sdk/anthropic'
|
||||||
|
|
||||||
|
// 1. 创建 AI SDK 原生注册表
|
||||||
|
export const registry = createProviderRegistry({
|
||||||
|
// register provider with prefix and default setup:
|
||||||
|
anthropic,
|
||||||
|
|
||||||
|
// register provider with prefix and custom setup:
|
||||||
|
openai: createOpenAI({
|
||||||
|
apiKey: process.env.OPENAI_API_KEY
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
// 2. 创建client,'openai'可以传空或者传providerId(内建的provider)
|
||||||
|
const client = PluginEnabledAiClient.create('openai', {
|
||||||
|
apiKey: process.env.OPENAI_API_KEY
|
||||||
|
})
|
||||||
|
|
||||||
|
// 3. 方式1:使用内建逻辑(传统方式)
|
||||||
|
const result1 = await client.streamText('gpt-4', {
|
||||||
|
messages: [{ role: 'user', content: 'Hello with built-in logic!' }]
|
||||||
|
})
|
||||||
|
|
||||||
|
// 4. 方式2:使用自定义注册表(灵活方式)
|
||||||
|
const result2 = await client.streamText({
|
||||||
|
model: registry.languageModel('openai:gpt-4'),
|
||||||
|
messages: [{ role: 'user', content: 'Hello with custom registry!' }]
|
||||||
|
})
|
||||||
|
|
||||||
|
// 5. 支持的重载方法
|
||||||
|
await client.generateObject({
|
||||||
|
model: registry.languageModel('openai:gpt-4'),
|
||||||
|
schema: z.object({ name: z.string() }),
|
||||||
|
messages: [{ role: 'user', content: 'Generate a user' }]
|
||||||
|
})
|
||||||
|
|
||||||
|
await client.streamObject({
|
||||||
|
model: registry.languageModel('anthropic:claude-3-opus-20240229'),
|
||||||
|
schema: z.object({ items: z.array(z.string()) }),
|
||||||
|
messages: [{ role: 'user', content: 'Generate a list' }]
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
|
#### 与插件系统配合使用
|
||||||
|
|
||||||
|
更强大的是,你还可以将自定义注册表与 Cherry Studio 的插件系统结合使用:
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
import { PluginEnabledAiClient } from '@cherrystudio/ai-core'
|
||||||
|
import { createProviderRegistry } from 'ai'
|
||||||
|
import { createOpenAI } from '@ai-sdk/openai'
|
||||||
|
import { anthropic } from '@ai-sdk/anthropic'
|
||||||
|
|
||||||
|
// 1. 创建带插件的客户端
|
||||||
|
const client = PluginEnabledAiClient.create(
|
||||||
|
'openai',
|
||||||
|
{
|
||||||
|
apiKey: process.env.OPENAI_API_KEY
|
||||||
|
},
|
||||||
|
[LoggingPlugin, RetryPlugin]
|
||||||
|
)
|
||||||
|
|
||||||
|
// 2. 创建自定义注册表
|
||||||
|
const registry = createProviderRegistry({
|
||||||
|
openai: createOpenAI({ apiKey: process.env.OPENAI_API_KEY }),
|
||||||
|
anthropic: anthropic({ apiKey: process.env.ANTHROPIC_API_KEY })
|
||||||
|
})
|
||||||
|
|
||||||
|
// 3. 方式1:使用内建逻辑 + 完整插件系统
|
||||||
|
await client.streamText('gpt-4', {
|
||||||
|
messages: [{ role: 'user', content: 'Hello with plugins!' }]
|
||||||
|
})
|
||||||
|
|
||||||
|
// 4. 方式2:使用自定义注册表 + 有限插件支持
|
||||||
|
await client.streamText({
|
||||||
|
model: registry.languageModel('anthropic:claude-3-opus-20240229'),
|
||||||
|
messages: [{ role: 'user', content: 'Hello from Claude!' }]
|
||||||
|
})
|
||||||
|
|
||||||
|
// 5. 支持的方法
|
||||||
|
await client.generateObject({
|
||||||
|
model: registry.languageModel('openai:gpt-4'),
|
||||||
|
schema: z.object({ name: z.string() }),
|
||||||
|
messages: [{ role: 'user', content: 'Generate a user' }]
|
||||||
|
})
|
||||||
|
|
||||||
|
await client.streamObject({
|
||||||
|
model: registry.languageModel('openai:gpt-4'),
|
||||||
|
schema: z.object({ items: z.array(z.string()) }),
|
||||||
|
messages: [{ role: 'user', content: 'Generate a list' }]
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
|
#### 混合使用的优势
|
||||||
|
|
||||||
|
- **灵活性**:可以根据需要选择使用内建逻辑或自定义注册表
|
||||||
|
- **兼容性**:完全兼容 AI SDK 的 `createProviderRegistry` API
|
||||||
|
- **渐进式**:可以逐步迁移现有代码,无需一次性重构
|
||||||
|
- **插件支持**:自定义注册表仍可享受插件系统的部分功能
|
||||||
|
- **最佳实践**:结合两种方式的优点,既有动态加载的性能优势,又有统一注册表的便利性
|
||||||
|
|
||||||
|
## 📚 相关资源
|
||||||
|
|
||||||
|
- [Vercel AI SDK 文档](https://ai-sdk.dev/)
|
||||||
|
- [Cherry Studio 项目](https://github.com/CherryHQ/cherry-studio)
|
||||||
|
- [AI SDK Providers](https://ai-sdk.dev/providers/ai-sdk-providers)
|
||||||
|
|
||||||
|
## 未来版本
|
||||||
|
|
||||||
|
- 🔮 多 Agent 编排
|
||||||
|
- 🔮 可视化插件配置
|
||||||
|
- 🔮 实时监控和分析
|
||||||
|
- 🔮 云端插件同步
|
||||||
|
|
||||||
|
## 📄 License
|
||||||
|
|
||||||
|
MIT License - 详见 [LICENSE](https://github.com/CherryHQ/cherry-studio/blob/main/LICENSE) 文件
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
**Cherry Studio AI Core** - 让 AI 开发更简单、更强大、更灵活 🚀
|
||||||
103
packages/aiCore/examples/hub-provider-usage.ts
Normal file
103
packages/aiCore/examples/hub-provider-usage.ts
Normal file
@@ -0,0 +1,103 @@
|
|||||||
|
/**
|
||||||
|
* Hub Provider 使用示例
|
||||||
|
*
|
||||||
|
* 演示如何使用简化后的Hub Provider功能来路由到多个底层provider
|
||||||
|
*/
|
||||||
|
|
||||||
|
import { createHubProvider, initializeProvider, providerRegistry } from '../src/index'
|
||||||
|
|
||||||
|
async function demonstrateHubProvider() {
|
||||||
|
try {
|
||||||
|
// 1. 初始化底层providers
|
||||||
|
console.log('📦 初始化底层providers...')
|
||||||
|
|
||||||
|
initializeProvider('openai', {
|
||||||
|
apiKey: process.env.OPENAI_API_KEY || 'sk-test-key'
|
||||||
|
})
|
||||||
|
|
||||||
|
initializeProvider('anthropic', {
|
||||||
|
apiKey: process.env.ANTHROPIC_API_KEY || 'sk-ant-test-key'
|
||||||
|
})
|
||||||
|
|
||||||
|
// 2. 创建Hub Provider(自动包含所有已初始化的providers)
|
||||||
|
console.log('🌐 创建Hub Provider...')
|
||||||
|
|
||||||
|
const aihubmixProvider = createHubProvider({
|
||||||
|
hubId: 'aihubmix',
|
||||||
|
debug: true
|
||||||
|
})
|
||||||
|
|
||||||
|
// 3. 注册Hub Provider
|
||||||
|
providerRegistry.registerProvider('aihubmix', aihubmixProvider)
|
||||||
|
|
||||||
|
console.log('✅ Hub Provider "aihubmix" 注册成功')
|
||||||
|
|
||||||
|
// 4. 使用Hub Provider访问不同的模型
|
||||||
|
console.log('\n🚀 使用Hub模型...')
|
||||||
|
|
||||||
|
// 通过Hub路由到OpenAI
|
||||||
|
const openaiModel = providerRegistry.languageModel('aihubmix:openai:gpt-4')
|
||||||
|
console.log('✓ OpenAI模型已获取:', openaiModel.modelId)
|
||||||
|
|
||||||
|
// 通过Hub路由到Anthropic
|
||||||
|
const anthropicModel = providerRegistry.languageModel('aihubmix:anthropic:claude-3.5-sonnet')
|
||||||
|
console.log('✓ Anthropic模型已获取:', anthropicModel.modelId)
|
||||||
|
|
||||||
|
// 5. 演示错误处理
|
||||||
|
console.log('\n❌ 演示错误处理...')
|
||||||
|
|
||||||
|
try {
|
||||||
|
// 尝试访问未初始化的provider
|
||||||
|
providerRegistry.languageModel('aihubmix:google:gemini-pro')
|
||||||
|
} catch (error) {
|
||||||
|
console.log('预期错误:', error.message)
|
||||||
|
}
|
||||||
|
|
||||||
|
try {
|
||||||
|
// 尝试使用错误的模型ID格式
|
||||||
|
providerRegistry.languageModel('aihubmix:invalid-format')
|
||||||
|
} catch (error) {
|
||||||
|
console.log('预期错误:', error.message)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 6. 多个Hub Provider示例
|
||||||
|
console.log('\n🔄 创建多个Hub Provider...')
|
||||||
|
|
||||||
|
const localHubProvider = createHubProvider({
|
||||||
|
hubId: 'local-ai'
|
||||||
|
})
|
||||||
|
|
||||||
|
providerRegistry.registerProvider('local-ai', localHubProvider)
|
||||||
|
console.log('✅ Hub Provider "local-ai" 注册成功')
|
||||||
|
|
||||||
|
console.log('\n🎉 Hub Provider演示完成!')
|
||||||
|
} catch (error) {
|
||||||
|
console.error('💥 演示过程中发生错误:', error)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 演示简化的使用方式
|
||||||
|
function simplifiedUsageExample() {
|
||||||
|
console.log('\n📝 简化使用示例:')
|
||||||
|
console.log(`
|
||||||
|
// 1. 初始化providers
|
||||||
|
initializeProvider('openai', { apiKey: 'sk-xxx' })
|
||||||
|
initializeProvider('anthropic', { apiKey: 'sk-ant-xxx' })
|
||||||
|
|
||||||
|
// 2. 创建并注册Hub Provider
|
||||||
|
const hubProvider = createHubProvider({ hubId: 'aihubmix' })
|
||||||
|
providerRegistry.registerProvider('aihubmix', hubProvider)
|
||||||
|
|
||||||
|
// 3. 直接使用
|
||||||
|
const model1 = providerRegistry.languageModel('aihubmix:openai:gpt-4')
|
||||||
|
const model2 = providerRegistry.languageModel('aihubmix:anthropic:claude-3.5-sonnet')
|
||||||
|
`)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 运行演示
|
||||||
|
if (require.main === module) {
|
||||||
|
demonstrateHubProvider()
|
||||||
|
simplifiedUsageExample()
|
||||||
|
}
|
||||||
|
|
||||||
|
export { demonstrateHubProvider, simplifiedUsageExample }
|
||||||
167
packages/aiCore/examples/image-generation.ts
Normal file
167
packages/aiCore/examples/image-generation.ts
Normal file
@@ -0,0 +1,167 @@
|
|||||||
|
/**
|
||||||
|
* Image Generation Example
|
||||||
|
* 演示如何使用 aiCore 的文生图功能
|
||||||
|
*/
|
||||||
|
|
||||||
|
import { createExecutor, generateImage } from '../src/index'
|
||||||
|
|
||||||
|
async function main() {
|
||||||
|
// 方式1: 使用执行器实例
|
||||||
|
console.log('📸 创建 OpenAI 图像生成执行器...')
|
||||||
|
const executor = createExecutor('openai', {
|
||||||
|
apiKey: process.env.OPENAI_API_KEY!
|
||||||
|
})
|
||||||
|
|
||||||
|
try {
|
||||||
|
console.log('🎨 使用执行器生成图像...')
|
||||||
|
const result1 = await executor.generateImage('dall-e-3', {
|
||||||
|
prompt: 'A futuristic cityscape at sunset with flying cars',
|
||||||
|
size: '1024x1024',
|
||||||
|
n: 1
|
||||||
|
})
|
||||||
|
|
||||||
|
console.log('✅ 图像生成成功!')
|
||||||
|
console.log('📊 结果:', {
|
||||||
|
imagesCount: result1.images.length,
|
||||||
|
mediaType: result1.image.mediaType,
|
||||||
|
hasBase64: !!result1.image.base64,
|
||||||
|
providerMetadata: result1.providerMetadata
|
||||||
|
})
|
||||||
|
} catch (error) {
|
||||||
|
console.error('❌ 执行器生成失败:', error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 方式2: 使用直接调用 API
|
||||||
|
try {
|
||||||
|
console.log('🎨 使用直接 API 生成图像...')
|
||||||
|
const result2 = await generateImage('openai', { apiKey: process.env.OPENAI_API_KEY! }, 'dall-e-3', {
|
||||||
|
prompt: 'A magical forest with glowing mushrooms and fairy lights',
|
||||||
|
aspectRatio: '16:9',
|
||||||
|
providerOptions: {
|
||||||
|
openai: {
|
||||||
|
quality: 'hd',
|
||||||
|
style: 'vivid'
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
console.log('✅ 直接 API 生成成功!')
|
||||||
|
console.log('📊 结果:', {
|
||||||
|
imagesCount: result2.images.length,
|
||||||
|
mediaType: result2.image.mediaType,
|
||||||
|
hasBase64: !!result2.image.base64
|
||||||
|
})
|
||||||
|
} catch (error) {
|
||||||
|
console.error('❌ 直接 API 生成失败:', error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 方式3: 支持其他提供商 (Google Imagen)
|
||||||
|
if (process.env.GOOGLE_API_KEY) {
|
||||||
|
try {
|
||||||
|
console.log('🎨 使用 Google Imagen 生成图像...')
|
||||||
|
const googleExecutor = createExecutor('google', {
|
||||||
|
apiKey: process.env.GOOGLE_API_KEY!
|
||||||
|
})
|
||||||
|
|
||||||
|
const result3 = await googleExecutor.generateImage('imagen-3.0-generate-002', {
|
||||||
|
prompt: 'A serene mountain lake at dawn with mist rising from the water',
|
||||||
|
aspectRatio: '1:1'
|
||||||
|
})
|
||||||
|
|
||||||
|
console.log('✅ Google Imagen 生成成功!')
|
||||||
|
console.log('📊 结果:', {
|
||||||
|
imagesCount: result3.images.length,
|
||||||
|
mediaType: result3.image.mediaType,
|
||||||
|
hasBase64: !!result3.image.base64
|
||||||
|
})
|
||||||
|
} catch (error) {
|
||||||
|
console.error('❌ Google Imagen 生成失败:', error)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 方式4: 支持插件系统
|
||||||
|
const pluginExample = async () => {
|
||||||
|
console.log('🔌 演示插件系统...')
|
||||||
|
|
||||||
|
// 创建一个示例插件,用于修改提示词
|
||||||
|
const promptEnhancerPlugin = {
|
||||||
|
name: 'prompt-enhancer',
|
||||||
|
transformParams: async (params: any) => {
|
||||||
|
console.log('🔧 插件: 增强提示词...')
|
||||||
|
return {
|
||||||
|
...params,
|
||||||
|
prompt: `${params.prompt}, highly detailed, cinematic lighting, 4K resolution`
|
||||||
|
}
|
||||||
|
},
|
||||||
|
transformResult: async (result: any) => {
|
||||||
|
console.log('🔧 插件: 处理结果...')
|
||||||
|
return {
|
||||||
|
...result,
|
||||||
|
enhanced: true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const executorWithPlugin = createExecutor(
|
||||||
|
'openai',
|
||||||
|
{
|
||||||
|
apiKey: process.env.OPENAI_API_KEY!
|
||||||
|
},
|
||||||
|
[promptEnhancerPlugin]
|
||||||
|
)
|
||||||
|
|
||||||
|
try {
|
||||||
|
const result4 = await executorWithPlugin.generateImage('dall-e-3', {
|
||||||
|
prompt: 'A cute robot playing in a garden'
|
||||||
|
})
|
||||||
|
|
||||||
|
console.log('✅ 插件系统生成成功!')
|
||||||
|
console.log('📊 结果:', {
|
||||||
|
imagesCount: result4.images.length,
|
||||||
|
enhanced: (result4 as any).enhanced,
|
||||||
|
mediaType: result4.image.mediaType
|
||||||
|
})
|
||||||
|
} catch (error) {
|
||||||
|
console.error('❌ 插件系统生成失败:', error)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
await pluginExample()
|
||||||
|
}
|
||||||
|
|
||||||
|
// 错误处理演示
|
||||||
|
async function errorHandlingExample() {
|
||||||
|
console.log('⚠️ 演示错误处理...')
|
||||||
|
|
||||||
|
try {
|
||||||
|
const executor = createExecutor('openai', {
|
||||||
|
apiKey: 'invalid-key'
|
||||||
|
})
|
||||||
|
|
||||||
|
await executor.generateImage('dall-e-3', {
|
||||||
|
prompt: 'Test image'
|
||||||
|
})
|
||||||
|
} catch (error: any) {
|
||||||
|
console.log('✅ 成功捕获错误:', error.constructor.name)
|
||||||
|
console.log('📋 错误信息:', error.message)
|
||||||
|
console.log('🏷️ 提供商ID:', error.providerId)
|
||||||
|
console.log('🏷️ 模型ID:', error.modelId)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 运行示例
|
||||||
|
if (require.main === module) {
|
||||||
|
main()
|
||||||
|
.then(() => {
|
||||||
|
console.log('🎉 所有示例完成!')
|
||||||
|
return errorHandlingExample()
|
||||||
|
})
|
||||||
|
.then(() => {
|
||||||
|
console.log('🎯 示例程序结束')
|
||||||
|
process.exit(0)
|
||||||
|
})
|
||||||
|
.catch((error) => {
|
||||||
|
console.error('💥 程序执行出错:', error)
|
||||||
|
process.exit(1)
|
||||||
|
})
|
||||||
|
}
|
||||||
85
packages/aiCore/package.json
Normal file
85
packages/aiCore/package.json
Normal file
@@ -0,0 +1,85 @@
|
|||||||
|
{
|
||||||
|
"name": "@cherrystudio/ai-core",
|
||||||
|
"version": "1.0.0-alpha.11",
|
||||||
|
"description": "Cherry Studio AI Core - Unified AI Provider Interface Based on Vercel AI SDK",
|
||||||
|
"main": "dist/index.js",
|
||||||
|
"module": "dist/index.mjs",
|
||||||
|
"types": "dist/index.d.ts",
|
||||||
|
"react-native": "dist/index.js",
|
||||||
|
"scripts": {
|
||||||
|
"build": "tsdown",
|
||||||
|
"dev": "tsc -w",
|
||||||
|
"clean": "rm -rf dist",
|
||||||
|
"test": "vitest run",
|
||||||
|
"test:watch": "vitest"
|
||||||
|
},
|
||||||
|
"keywords": [
|
||||||
|
"ai",
|
||||||
|
"sdk",
|
||||||
|
"openai",
|
||||||
|
"anthropic",
|
||||||
|
"google",
|
||||||
|
"cherry-studio",
|
||||||
|
"vercel-ai-sdk"
|
||||||
|
],
|
||||||
|
"author": "Cherry Studio",
|
||||||
|
"license": "MIT",
|
||||||
|
"repository": {
|
||||||
|
"type": "git",
|
||||||
|
"url": "git+https://github.com/CherryHQ/cherry-studio.git"
|
||||||
|
},
|
||||||
|
"bugs": {
|
||||||
|
"url": "https://github.com/CherryHQ/cherry-studio/issues"
|
||||||
|
},
|
||||||
|
"homepage": "https://github.com/CherryHQ/cherry-studio#readme",
|
||||||
|
"peerDependencies": {
|
||||||
|
"ai": "^5.0.26"
|
||||||
|
},
|
||||||
|
"dependencies": {
|
||||||
|
"@ai-sdk/anthropic": "^2.0.5",
|
||||||
|
"@ai-sdk/azure": "^2.0.16",
|
||||||
|
"@ai-sdk/deepseek": "^1.0.9",
|
||||||
|
"@ai-sdk/google": "^2.0.7",
|
||||||
|
"@ai-sdk/openai": "^2.0.19",
|
||||||
|
"@ai-sdk/openai-compatible": "^1.0.9",
|
||||||
|
"@ai-sdk/provider": "^2.0.0",
|
||||||
|
"@ai-sdk/provider-utils": "^3.0.4",
|
||||||
|
"@ai-sdk/xai": "^2.0.9",
|
||||||
|
"zod": "^3.25.0"
|
||||||
|
},
|
||||||
|
"devDependencies": {
|
||||||
|
"tsdown": "^0.12.9",
|
||||||
|
"typescript": "^5.0.0",
|
||||||
|
"vitest": "^3.2.4"
|
||||||
|
},
|
||||||
|
"sideEffects": false,
|
||||||
|
"engines": {
|
||||||
|
"node": ">=18.0.0"
|
||||||
|
},
|
||||||
|
"files": [
|
||||||
|
"dist"
|
||||||
|
],
|
||||||
|
"exports": {
|
||||||
|
".": {
|
||||||
|
"types": "./dist/index.d.ts",
|
||||||
|
"react-native": "./dist/index.js",
|
||||||
|
"import": "./dist/index.mjs",
|
||||||
|
"require": "./dist/index.js",
|
||||||
|
"default": "./dist/index.js"
|
||||||
|
},
|
||||||
|
"./built-in/plugins": {
|
||||||
|
"types": "./dist/built-in/plugins/index.d.ts",
|
||||||
|
"react-native": "./dist/built-in/plugins/index.js",
|
||||||
|
"import": "./dist/built-in/plugins/index.mjs",
|
||||||
|
"require": "./dist/built-in/plugins/index.js",
|
||||||
|
"default": "./dist/built-in/plugins/index.js"
|
||||||
|
},
|
||||||
|
"./provider": {
|
||||||
|
"types": "./dist/provider/index.d.ts",
|
||||||
|
"react-native": "./dist/provider/index.js",
|
||||||
|
"import": "./dist/provider/index.mjs",
|
||||||
|
"require": "./dist/provider/index.js",
|
||||||
|
"default": "./dist/provider/index.js"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
2
packages/aiCore/setupVitest.ts
Normal file
2
packages/aiCore/setupVitest.ts
Normal file
@@ -0,0 +1,2 @@
|
|||||||
|
// 模拟 Vite SSR helper,避免 Node 环境找不到时报错
|
||||||
|
;(globalThis as any).__vite_ssr_exportName__ = (name: string, value: any) => value
|
||||||
3
packages/aiCore/src/core/README.MD
Normal file
3
packages/aiCore/src/core/README.MD
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
# @cherryStudio-aiCore
|
||||||
|
|
||||||
|
Core
|
||||||
17
packages/aiCore/src/core/index.ts
Normal file
17
packages/aiCore/src/core/index.ts
Normal file
@@ -0,0 +1,17 @@
|
|||||||
|
/**
|
||||||
|
* Core 模块导出
|
||||||
|
* 内部核心功能,供其他模块使用,不直接面向最终调用者
|
||||||
|
*/
|
||||||
|
|
||||||
|
// 中间件系统
|
||||||
|
export type { NamedMiddleware } from './middleware'
|
||||||
|
export { createMiddlewares, wrapModelWithMiddlewares } from './middleware'
|
||||||
|
|
||||||
|
// 创建管理
|
||||||
|
export { globalModelResolver, ModelResolver } from './models'
|
||||||
|
export type { ModelConfig as ModelConfigType } from './models/types'
|
||||||
|
|
||||||
|
// 执行管理
|
||||||
|
export type { ToolUseRequestContext } from './plugins/built-in/toolUsePlugin/type'
|
||||||
|
export { createExecutor, createOpenAICompatibleExecutor } from './runtime'
|
||||||
|
export type { RuntimeConfig } from './runtime/types'
|
||||||
8
packages/aiCore/src/core/middleware/index.ts
Normal file
8
packages/aiCore/src/core/middleware/index.ts
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
/**
|
||||||
|
* Middleware 模块导出
|
||||||
|
* 提供通用的中间件管理能力
|
||||||
|
*/
|
||||||
|
|
||||||
|
export { createMiddlewares } from './manager'
|
||||||
|
export type { NamedMiddleware } from './types'
|
||||||
|
export { wrapModelWithMiddlewares } from './wrapper'
|
||||||
16
packages/aiCore/src/core/middleware/manager.ts
Normal file
16
packages/aiCore/src/core/middleware/manager.ts
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
/**
|
||||||
|
* 中间件管理器
|
||||||
|
* 专注于 AI SDK 中间件的管理,与插件系统分离
|
||||||
|
*/
|
||||||
|
import { LanguageModelV2Middleware } from '@ai-sdk/provider'
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 创建中间件列表
|
||||||
|
* 合并用户提供的中间件
|
||||||
|
*/
|
||||||
|
export function createMiddlewares(userMiddlewares: LanguageModelV2Middleware[] = []): LanguageModelV2Middleware[] {
|
||||||
|
// 未来可以在这里添加默认的中间件
|
||||||
|
const defaultMiddlewares: LanguageModelV2Middleware[] = []
|
||||||
|
|
||||||
|
return [...defaultMiddlewares, ...userMiddlewares]
|
||||||
|
}
|
||||||
12
packages/aiCore/src/core/middleware/types.ts
Normal file
12
packages/aiCore/src/core/middleware/types.ts
Normal file
@@ -0,0 +1,12 @@
|
|||||||
|
/**
|
||||||
|
* 中间件系统类型定义
|
||||||
|
*/
|
||||||
|
import { LanguageModelV2Middleware } from '@ai-sdk/provider'
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 具名中间件接口
|
||||||
|
*/
|
||||||
|
export interface NamedMiddleware {
|
||||||
|
name: string
|
||||||
|
middleware: LanguageModelV2Middleware
|
||||||
|
}
|
||||||
23
packages/aiCore/src/core/middleware/wrapper.ts
Normal file
23
packages/aiCore/src/core/middleware/wrapper.ts
Normal file
@@ -0,0 +1,23 @@
|
|||||||
|
/**
|
||||||
|
* 模型包装工具函数
|
||||||
|
* 用于将中间件应用到LanguageModel上
|
||||||
|
*/
|
||||||
|
import { LanguageModelV2, LanguageModelV2Middleware } from '@ai-sdk/provider'
|
||||||
|
import { wrapLanguageModel } from 'ai'
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 使用中间件包装模型
|
||||||
|
*/
|
||||||
|
export function wrapModelWithMiddlewares(
|
||||||
|
model: LanguageModelV2,
|
||||||
|
middlewares: LanguageModelV2Middleware[]
|
||||||
|
): LanguageModelV2 {
|
||||||
|
if (middlewares.length === 0) {
|
||||||
|
return model
|
||||||
|
}
|
||||||
|
|
||||||
|
return wrapLanguageModel({
|
||||||
|
model,
|
||||||
|
middleware: middlewares
|
||||||
|
})
|
||||||
|
}
|
||||||
125
packages/aiCore/src/core/models/ModelResolver.ts
Normal file
125
packages/aiCore/src/core/models/ModelResolver.ts
Normal file
@@ -0,0 +1,125 @@
|
|||||||
|
/**
|
||||||
|
* 模型解析器 - models模块的核心
|
||||||
|
* 负责将modelId解析为AI SDK的LanguageModel实例
|
||||||
|
* 支持传统格式和命名空间格式
|
||||||
|
* 集成了来自 ModelCreator 的特殊处理逻辑
|
||||||
|
*/
|
||||||
|
|
||||||
|
import { EmbeddingModelV2, ImageModelV2, LanguageModelV2, LanguageModelV2Middleware } from '@ai-sdk/provider'
|
||||||
|
|
||||||
|
import { wrapModelWithMiddlewares } from '../middleware/wrapper'
|
||||||
|
import { DEFAULT_SEPARATOR, globalRegistryManagement } from '../providers/RegistryManagement'
|
||||||
|
|
||||||
|
export class ModelResolver {
|
||||||
|
/**
|
||||||
|
* 核心方法:解析任意格式的modelId为语言模型
|
||||||
|
*
|
||||||
|
* @param modelId 模型ID,支持 'gpt-4' 和 'anthropic>claude-3' 两种格式
|
||||||
|
* @param fallbackProviderId 当modelId为传统格式时使用的providerId
|
||||||
|
* @param providerOptions provider配置选项(用于OpenAI模式选择等)
|
||||||
|
* @param middlewares 中间件数组,会应用到最终模型上
|
||||||
|
*/
|
||||||
|
async resolveLanguageModel(
|
||||||
|
modelId: string,
|
||||||
|
fallbackProviderId: string,
|
||||||
|
providerOptions?: any,
|
||||||
|
middlewares?: LanguageModelV2Middleware[]
|
||||||
|
): Promise<LanguageModelV2> {
|
||||||
|
let finalProviderId = fallbackProviderId
|
||||||
|
let model: LanguageModelV2
|
||||||
|
// 🎯 处理 OpenAI 模式选择逻辑 (从 ModelCreator 迁移)
|
||||||
|
if ((fallbackProviderId === 'openai' || fallbackProviderId === 'azure') && providerOptions?.mode === 'chat') {
|
||||||
|
finalProviderId = `${fallbackProviderId}-chat`
|
||||||
|
}
|
||||||
|
|
||||||
|
// 检查是否是命名空间格式
|
||||||
|
if (modelId.includes(DEFAULT_SEPARATOR)) {
|
||||||
|
model = this.resolveNamespacedModel(modelId)
|
||||||
|
} else {
|
||||||
|
// 传统格式:使用处理后的 providerId + modelId
|
||||||
|
model = this.resolveTraditionalModel(finalProviderId, modelId)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 🎯 应用中间件(如果有)
|
||||||
|
if (middlewares && middlewares.length > 0) {
|
||||||
|
model = wrapModelWithMiddlewares(model, middlewares)
|
||||||
|
}
|
||||||
|
|
||||||
|
return model
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 解析文本嵌入模型
|
||||||
|
*/
|
||||||
|
async resolveTextEmbeddingModel(modelId: string, fallbackProviderId: string): Promise<EmbeddingModelV2<string>> {
|
||||||
|
if (modelId.includes(DEFAULT_SEPARATOR)) {
|
||||||
|
return this.resolveNamespacedEmbeddingModel(modelId)
|
||||||
|
}
|
||||||
|
|
||||||
|
return this.resolveTraditionalEmbeddingModel(fallbackProviderId, modelId)
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 解析图像模型
|
||||||
|
*/
|
||||||
|
async resolveImageModel(modelId: string, fallbackProviderId: string): Promise<ImageModelV2> {
|
||||||
|
if (modelId.includes(DEFAULT_SEPARATOR)) {
|
||||||
|
return this.resolveNamespacedImageModel(modelId)
|
||||||
|
}
|
||||||
|
|
||||||
|
return this.resolveTraditionalImageModel(fallbackProviderId, modelId)
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 解析命名空间格式的语言模型
|
||||||
|
* aihubmix:anthropic:claude-3 -> globalRegistryManagement.languageModel('aihubmix:anthropic:claude-3')
|
||||||
|
*/
|
||||||
|
private resolveNamespacedModel(modelId: string): LanguageModelV2 {
|
||||||
|
return globalRegistryManagement.languageModel(modelId as any)
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 解析传统格式的语言模型
|
||||||
|
* providerId: 'openai', modelId: 'gpt-4' -> globalRegistryManagement.languageModel('openai:gpt-4')
|
||||||
|
*/
|
||||||
|
private resolveTraditionalModel(providerId: string, modelId: string): LanguageModelV2 {
|
||||||
|
const fullModelId = `${providerId}${DEFAULT_SEPARATOR}${modelId}`
|
||||||
|
console.log('fullModelId', fullModelId)
|
||||||
|
return globalRegistryManagement.languageModel(fullModelId as any)
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 解析命名空间格式的嵌入模型
|
||||||
|
*/
|
||||||
|
private resolveNamespacedEmbeddingModel(modelId: string): EmbeddingModelV2<string> {
|
||||||
|
return globalRegistryManagement.textEmbeddingModel(modelId as any)
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 解析传统格式的嵌入模型
|
||||||
|
*/
|
||||||
|
private resolveTraditionalEmbeddingModel(providerId: string, modelId: string): EmbeddingModelV2<string> {
|
||||||
|
const fullModelId = `${providerId}${DEFAULT_SEPARATOR}${modelId}`
|
||||||
|
return globalRegistryManagement.textEmbeddingModel(fullModelId as any)
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 解析命名空间格式的图像模型
|
||||||
|
*/
|
||||||
|
private resolveNamespacedImageModel(modelId: string): ImageModelV2 {
|
||||||
|
return globalRegistryManagement.imageModel(modelId as any)
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 解析传统格式的图像模型
|
||||||
|
*/
|
||||||
|
private resolveTraditionalImageModel(providerId: string, modelId: string): ImageModelV2 {
|
||||||
|
const fullModelId = `${providerId}${DEFAULT_SEPARATOR}${modelId}`
|
||||||
|
return globalRegistryManagement.imageModel(fullModelId as any)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 全局模型解析器实例
|
||||||
|
*/
|
||||||
|
export const globalModelResolver = new ModelResolver()
|
||||||
9
packages/aiCore/src/core/models/index.ts
Normal file
9
packages/aiCore/src/core/models/index.ts
Normal file
@@ -0,0 +1,9 @@
|
|||||||
|
/**
|
||||||
|
* Models 模块统一导出 - 简化版
|
||||||
|
*/
|
||||||
|
|
||||||
|
// 核心模型解析器
|
||||||
|
export { globalModelResolver, ModelResolver } from './ModelResolver'
|
||||||
|
|
||||||
|
// 保留的类型定义(可能被其他地方使用)
|
||||||
|
export type { ModelConfig as ModelConfigType } from './types'
|
||||||
15
packages/aiCore/src/core/models/types.ts
Normal file
15
packages/aiCore/src/core/models/types.ts
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
/**
|
||||||
|
* Creation 模块类型定义
|
||||||
|
*/
|
||||||
|
import { LanguageModelV2Middleware } from '@ai-sdk/provider'
|
||||||
|
|
||||||
|
import type { ProviderId, ProviderSettingsMap } from '../providers/types'
|
||||||
|
|
||||||
|
export interface ModelConfig<T extends ProviderId = ProviderId> {
|
||||||
|
providerId: T
|
||||||
|
modelId: string
|
||||||
|
providerSettings: ProviderSettingsMap[T] & { mode?: 'chat' | 'responses' }
|
||||||
|
middlewares?: LanguageModelV2Middleware[]
|
||||||
|
// 额外模型参数
|
||||||
|
extraModelConfig?: Record<string, any>
|
||||||
|
}
|
||||||
87
packages/aiCore/src/core/options/examples.ts
Normal file
87
packages/aiCore/src/core/options/examples.ts
Normal file
@@ -0,0 +1,87 @@
|
|||||||
|
import { streamText } from 'ai'
|
||||||
|
|
||||||
|
import {
|
||||||
|
createAnthropicOptions,
|
||||||
|
createGenericProviderOptions,
|
||||||
|
createGoogleOptions,
|
||||||
|
createOpenAIOptions,
|
||||||
|
mergeProviderOptions
|
||||||
|
} from './factory'
|
||||||
|
|
||||||
|
// 示例1: 使用已知供应商的严格类型约束
|
||||||
|
export function exampleOpenAIWithOptions() {
|
||||||
|
const openaiOptions = createOpenAIOptions({
|
||||||
|
reasoningEffort: 'medium'
|
||||||
|
})
|
||||||
|
|
||||||
|
// 这里会有类型检查,确保选项符合OpenAI的设置
|
||||||
|
return streamText({
|
||||||
|
model: {} as any, // 实际使用时替换为真实模型
|
||||||
|
prompt: 'Hello',
|
||||||
|
providerOptions: openaiOptions
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// 示例2: 使用Anthropic供应商选项
|
||||||
|
export function exampleAnthropicWithOptions() {
|
||||||
|
const anthropicOptions = createAnthropicOptions({
|
||||||
|
thinking: {
|
||||||
|
type: 'enabled',
|
||||||
|
budgetTokens: 1000
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
return streamText({
|
||||||
|
model: {} as any,
|
||||||
|
prompt: 'Hello',
|
||||||
|
providerOptions: anthropicOptions
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// 示例3: 使用Google供应商选项
|
||||||
|
export function exampleGoogleWithOptions() {
|
||||||
|
const googleOptions = createGoogleOptions({
|
||||||
|
thinkingConfig: {
|
||||||
|
includeThoughts: true,
|
||||||
|
thinkingBudget: 1000
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
return streamText({
|
||||||
|
model: {} as any,
|
||||||
|
prompt: 'Hello',
|
||||||
|
providerOptions: googleOptions
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// 示例4: 使用未知供应商(通用类型)
|
||||||
|
export function exampleUnknownProviderWithOptions() {
|
||||||
|
const customProviderOptions = createGenericProviderOptions('custom-provider', {
|
||||||
|
temperature: 0.7,
|
||||||
|
customSetting: 'value',
|
||||||
|
anotherOption: true
|
||||||
|
})
|
||||||
|
|
||||||
|
return streamText({
|
||||||
|
model: {} as any,
|
||||||
|
prompt: 'Hello',
|
||||||
|
providerOptions: customProviderOptions
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// 示例5: 合并多个供应商选项
|
||||||
|
export function exampleMergedOptions() {
|
||||||
|
const openaiOptions = createOpenAIOptions({})
|
||||||
|
|
||||||
|
const customOptions = createGenericProviderOptions('custom', {
|
||||||
|
customParam: 'value'
|
||||||
|
})
|
||||||
|
|
||||||
|
const mergedOptions = mergeProviderOptions(openaiOptions, customOptions)
|
||||||
|
|
||||||
|
return streamText({
|
||||||
|
model: {} as any,
|
||||||
|
prompt: 'Hello',
|
||||||
|
providerOptions: mergedOptions
|
||||||
|
})
|
||||||
|
}
|
||||||
71
packages/aiCore/src/core/options/factory.ts
Normal file
71
packages/aiCore/src/core/options/factory.ts
Normal file
@@ -0,0 +1,71 @@
|
|||||||
|
import { ExtractProviderOptions, ProviderOptionsMap, TypedProviderOptions } from './types'
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 创建特定供应商的选项
|
||||||
|
* @param provider 供应商名称
|
||||||
|
* @param options 供应商特定的选项
|
||||||
|
* @returns 格式化的provider options
|
||||||
|
*/
|
||||||
|
export function createProviderOptions<T extends keyof ProviderOptionsMap>(
|
||||||
|
provider: T,
|
||||||
|
options: ExtractProviderOptions<T>
|
||||||
|
): Record<T, ExtractProviderOptions<T>> {
|
||||||
|
return { [provider]: options } as Record<T, ExtractProviderOptions<T>>
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 创建任意供应商的选项(包括未知供应商)
|
||||||
|
* @param provider 供应商名称
|
||||||
|
* @param options 供应商选项
|
||||||
|
* @returns 格式化的provider options
|
||||||
|
*/
|
||||||
|
export function createGenericProviderOptions<T extends string>(
|
||||||
|
provider: T,
|
||||||
|
options: Record<string, any>
|
||||||
|
): Record<T, Record<string, any>> {
|
||||||
|
return { [provider]: options } as Record<T, Record<string, any>>
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 合并多个供应商的options
|
||||||
|
* @param optionsMap 包含多个供应商选项的对象
|
||||||
|
* @returns 合并后的TypedProviderOptions
|
||||||
|
*/
|
||||||
|
export function mergeProviderOptions(...optionsMap: Partial<TypedProviderOptions>[]): TypedProviderOptions {
|
||||||
|
return Object.assign({}, ...optionsMap)
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 创建OpenAI供应商选项的便捷函数
|
||||||
|
*/
|
||||||
|
export function createOpenAIOptions(options: ExtractProviderOptions<'openai'>) {
|
||||||
|
return createProviderOptions('openai', options)
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 创建Anthropic供应商选项的便捷函数
|
||||||
|
*/
|
||||||
|
export function createAnthropicOptions(options: ExtractProviderOptions<'anthropic'>) {
|
||||||
|
return createProviderOptions('anthropic', options)
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 创建Google供应商选项的便捷函数
|
||||||
|
*/
|
||||||
|
export function createGoogleOptions(options: ExtractProviderOptions<'google'>) {
|
||||||
|
return createProviderOptions('google', options)
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 创建OpenRouter供应商选项的便捷函数
|
||||||
|
*/
|
||||||
|
export function createOpenRouterOptions(options: ExtractProviderOptions<'openrouter'>) {
|
||||||
|
return createProviderOptions('openrouter', options)
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 创建XAI供应商选项的便捷函数
|
||||||
|
*/
|
||||||
|
export function createXaiOptions(options: ExtractProviderOptions<'xai'>) {
|
||||||
|
return createProviderOptions('xai', options)
|
||||||
|
}
|
||||||
2
packages/aiCore/src/core/options/index.ts
Normal file
2
packages/aiCore/src/core/options/index.ts
Normal file
@@ -0,0 +1,2 @@
|
|||||||
|
export * from './factory'
|
||||||
|
export * from './types'
|
||||||
38
packages/aiCore/src/core/options/openrouter.ts
Normal file
38
packages/aiCore/src/core/options/openrouter.ts
Normal file
@@ -0,0 +1,38 @@
|
|||||||
|
export type OpenRouterProviderOptions = {
|
||||||
|
models?: string[]
|
||||||
|
|
||||||
|
/**
|
||||||
|
* https://openrouter.ai/docs/use-cases/reasoning-tokens
|
||||||
|
* One of `max_tokens` or `effort` is required.
|
||||||
|
* If `exclude` is true, reasoning will be removed from the response. Default is false.
|
||||||
|
*/
|
||||||
|
reasoning?: {
|
||||||
|
exclude?: boolean
|
||||||
|
} & (
|
||||||
|
| {
|
||||||
|
max_tokens: number
|
||||||
|
}
|
||||||
|
| {
|
||||||
|
effort: 'high' | 'medium' | 'low'
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A unique identifier representing your end-user, which can
|
||||||
|
* help OpenRouter to monitor and detect abuse.
|
||||||
|
*/
|
||||||
|
user?: string
|
||||||
|
|
||||||
|
extraBody?: Record<string, unknown>
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Enable usage accounting to get detailed token usage information.
|
||||||
|
* https://openrouter.ai/docs/use-cases/usage-accounting
|
||||||
|
*/
|
||||||
|
usage?: {
|
||||||
|
/**
|
||||||
|
* When true, includes token usage information in the response.
|
||||||
|
*/
|
||||||
|
include: boolean
|
||||||
|
}
|
||||||
|
}
|
||||||
33
packages/aiCore/src/core/options/types.ts
Normal file
33
packages/aiCore/src/core/options/types.ts
Normal file
@@ -0,0 +1,33 @@
|
|||||||
|
import { type AnthropicProviderOptions } from '@ai-sdk/anthropic'
|
||||||
|
import { type GoogleGenerativeAIProviderOptions } from '@ai-sdk/google'
|
||||||
|
import { type OpenAIResponsesProviderOptions } from '@ai-sdk/openai'
|
||||||
|
import { type SharedV2ProviderMetadata } from '@ai-sdk/provider'
|
||||||
|
|
||||||
|
import { type OpenRouterProviderOptions } from './openrouter'
|
||||||
|
import { type XaiProviderOptions } from './xai'
|
||||||
|
|
||||||
|
export type ProviderOptions<T extends keyof SharedV2ProviderMetadata> = SharedV2ProviderMetadata[T]
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 供应商选项类型,如果map中没有,说明没有约束
|
||||||
|
*/
|
||||||
|
export type ProviderOptionsMap = {
|
||||||
|
openai: OpenAIResponsesProviderOptions
|
||||||
|
anthropic: AnthropicProviderOptions
|
||||||
|
google: GoogleGenerativeAIProviderOptions
|
||||||
|
openrouter: OpenRouterProviderOptions
|
||||||
|
xai: XaiProviderOptions
|
||||||
|
}
|
||||||
|
|
||||||
|
// 工具类型,用于从ProviderOptionsMap中提取特定供应商的选项类型
|
||||||
|
export type ExtractProviderOptions<T extends keyof ProviderOptionsMap> = ProviderOptionsMap[T]
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 类型安全的ProviderOptions
|
||||||
|
* 对于已知供应商使用严格类型,对于未知供应商允许任意Record<string, JSONValue>
|
||||||
|
*/
|
||||||
|
export type TypedProviderOptions = {
|
||||||
|
[K in keyof ProviderOptionsMap]?: ProviderOptionsMap[K]
|
||||||
|
} & {
|
||||||
|
[K in string]?: Record<string, any>
|
||||||
|
} & SharedV2ProviderMetadata
|
||||||
86
packages/aiCore/src/core/options/xai.ts
Normal file
86
packages/aiCore/src/core/options/xai.ts
Normal file
@@ -0,0 +1,86 @@
|
|||||||
|
// copy from @ai-sdk/xai/xai-chat-options.ts
|
||||||
|
// 如果@ai-sdk/xai暴露出了xaiProviderOptions就删除这个文件
|
||||||
|
|
||||||
|
import * as z from 'zod/v4'
|
||||||
|
|
||||||
|
const webSourceSchema = z.object({
|
||||||
|
type: z.literal('web'),
|
||||||
|
country: z.string().length(2).optional(),
|
||||||
|
excludedWebsites: z.array(z.string()).max(5).optional(),
|
||||||
|
allowedWebsites: z.array(z.string()).max(5).optional(),
|
||||||
|
safeSearch: z.boolean().optional()
|
||||||
|
})
|
||||||
|
|
||||||
|
const xSourceSchema = z.object({
|
||||||
|
type: z.literal('x'),
|
||||||
|
xHandles: z.array(z.string()).optional()
|
||||||
|
})
|
||||||
|
|
||||||
|
const newsSourceSchema = z.object({
|
||||||
|
type: z.literal('news'),
|
||||||
|
country: z.string().length(2).optional(),
|
||||||
|
excludedWebsites: z.array(z.string()).max(5).optional(),
|
||||||
|
safeSearch: z.boolean().optional()
|
||||||
|
})
|
||||||
|
|
||||||
|
const rssSourceSchema = z.object({
|
||||||
|
type: z.literal('rss'),
|
||||||
|
links: z.array(z.url()).max(1) // currently only supports one RSS link
|
||||||
|
})
|
||||||
|
|
||||||
|
const searchSourceSchema = z.discriminatedUnion('type', [
|
||||||
|
webSourceSchema,
|
||||||
|
xSourceSchema,
|
||||||
|
newsSourceSchema,
|
||||||
|
rssSourceSchema
|
||||||
|
])
|
||||||
|
|
||||||
|
export const xaiProviderOptions = z.object({
|
||||||
|
/**
|
||||||
|
* reasoning effort for reasoning models
|
||||||
|
* only supported by grok-3-mini and grok-3-mini-fast models
|
||||||
|
*/
|
||||||
|
reasoningEffort: z.enum(['low', 'high']).optional(),
|
||||||
|
|
||||||
|
searchParameters: z
|
||||||
|
.object({
|
||||||
|
/**
|
||||||
|
* search mode preference
|
||||||
|
* - "off": disables search completely
|
||||||
|
* - "auto": model decides whether to search (default)
|
||||||
|
* - "on": always enables search
|
||||||
|
*/
|
||||||
|
mode: z.enum(['off', 'auto', 'on']),
|
||||||
|
|
||||||
|
/**
|
||||||
|
* whether to return citations in the response
|
||||||
|
* defaults to true
|
||||||
|
*/
|
||||||
|
returnCitations: z.boolean().optional(),
|
||||||
|
|
||||||
|
/**
|
||||||
|
* start date for search data (ISO8601 format: YYYY-MM-DD)
|
||||||
|
*/
|
||||||
|
fromDate: z.string().optional(),
|
||||||
|
|
||||||
|
/**
|
||||||
|
* end date for search data (ISO8601 format: YYYY-MM-DD)
|
||||||
|
*/
|
||||||
|
toDate: z.string().optional(),
|
||||||
|
|
||||||
|
/**
|
||||||
|
* maximum number of search results to consider
|
||||||
|
* defaults to 20
|
||||||
|
*/
|
||||||
|
maxSearchResults: z.number().min(1).max(50).optional(),
|
||||||
|
|
||||||
|
/**
|
||||||
|
* data sources to search from
|
||||||
|
* defaults to ["web", "x"] if not specified
|
||||||
|
*/
|
||||||
|
sources: z.array(searchSourceSchema).optional()
|
||||||
|
})
|
||||||
|
.optional()
|
||||||
|
})
|
||||||
|
|
||||||
|
export type XaiProviderOptions = z.infer<typeof xaiProviderOptions>
|
||||||
257
packages/aiCore/src/core/plugins/README.md
Normal file
257
packages/aiCore/src/core/plugins/README.md
Normal file
@@ -0,0 +1,257 @@
|
|||||||
|
# AI Core 插件系统
|
||||||
|
|
||||||
|
支持四种钩子类型:**First**、**Sequential**、**Parallel** 和 **Stream**。
|
||||||
|
|
||||||
|
## 🎯 设计理念
|
||||||
|
|
||||||
|
- **语义清晰**:不同钩子有不同的执行语义
|
||||||
|
- **类型安全**:TypeScript 完整支持
|
||||||
|
- **性能优化**:First 短路、Parallel 并发、Sequential 链式
|
||||||
|
- **易于扩展**:`enforce` 排序 + 功能分类
|
||||||
|
|
||||||
|
## 📋 钩子类型
|
||||||
|
|
||||||
|
### 🥇 First 钩子 - 首个有效结果
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
// 只执行第一个返回值的插件,用于解析和查找
|
||||||
|
resolveModel?: (modelId: string, context: AiRequestContext) => string | null
|
||||||
|
loadTemplate?: (templateName: string, context: AiRequestContext) => any | null
|
||||||
|
```
|
||||||
|
|
||||||
|
### 🔄 Sequential 钩子 - 链式数据转换
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
// 按顺序链式执行,每个插件可以修改数据
|
||||||
|
transformParams?: (params: any, context: AiRequestContext) => any
|
||||||
|
transformResult?: (result: any, context: AiRequestContext) => any
|
||||||
|
```
|
||||||
|
|
||||||
|
### ⚡ Parallel 钩子 - 并行副作用
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
// 并发执行,用于日志、监控等副作用
|
||||||
|
onRequestStart?: (context: AiRequestContext) => void
|
||||||
|
onRequestEnd?: (context: AiRequestContext, result: any) => void
|
||||||
|
onError?: (error: Error, context: AiRequestContext) => void
|
||||||
|
```
|
||||||
|
|
||||||
|
### 🌊 Stream 钩子 - 流处理
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
// 直接使用 AI SDK 的 TransformStream
|
||||||
|
transformStream?: () => (options) => TransformStream<TextStreamPart, TextStreamPart>
|
||||||
|
```
|
||||||
|
|
||||||
|
## 🚀 快速开始
|
||||||
|
|
||||||
|
### 基础用法
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
import { PluginManager, createContext, definePlugin } from '@cherrystudio/ai-core/middleware'
|
||||||
|
|
||||||
|
// 创建插件管理器
|
||||||
|
const pluginManager = new PluginManager()
|
||||||
|
|
||||||
|
// 添加插件
|
||||||
|
pluginManager.use({
|
||||||
|
name: 'my-plugin',
|
||||||
|
async transformParams(params, context) {
|
||||||
|
return { ...params, temperature: 0.7 }
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
// 使用插件
|
||||||
|
const context = createContext('openai', 'gpt-4', { messages: [] })
|
||||||
|
const transformedParams = await pluginManager.executeSequential(
|
||||||
|
'transformParams',
|
||||||
|
{ messages: [{ role: 'user', content: 'Hello' }] },
|
||||||
|
context
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
### 完整示例
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
import {
|
||||||
|
PluginManager,
|
||||||
|
ModelAliasPlugin,
|
||||||
|
LoggingPlugin,
|
||||||
|
ParamsValidationPlugin,
|
||||||
|
createContext
|
||||||
|
} from '@cherrystudio/ai-core/middleware'
|
||||||
|
|
||||||
|
// 创建插件管理器
|
||||||
|
const manager = new PluginManager([
|
||||||
|
ModelAliasPlugin, // 模型别名解析
|
||||||
|
ParamsValidationPlugin, // 参数验证
|
||||||
|
LoggingPlugin // 日志记录
|
||||||
|
])
|
||||||
|
|
||||||
|
// AI 请求流程
|
||||||
|
async function aiRequest(providerId: string, modelId: string, params: any) {
|
||||||
|
const context = createContext(providerId, modelId, params)
|
||||||
|
|
||||||
|
try {
|
||||||
|
// 1. 【并行】触发请求开始事件
|
||||||
|
await manager.executeParallel('onRequestStart', context)
|
||||||
|
|
||||||
|
// 2. 【首个】解析模型别名
|
||||||
|
const resolvedModel = await manager.executeFirst('resolveModel', modelId, context)
|
||||||
|
context.modelId = resolvedModel || modelId
|
||||||
|
|
||||||
|
// 3. 【串行】转换请求参数
|
||||||
|
const transformedParams = await manager.executeSequential('transformParams', params, context)
|
||||||
|
|
||||||
|
// 4. 【流处理】收集流转换器(AI SDK 原生支持数组)
|
||||||
|
const streamTransforms = manager.collectStreamTransforms()
|
||||||
|
|
||||||
|
// 5. 调用 AI SDK(这里省略具体实现)
|
||||||
|
const result = await callAiSdk(transformedParams, streamTransforms)
|
||||||
|
|
||||||
|
// 6. 【串行】转换响应结果
|
||||||
|
const transformedResult = await manager.executeSequential('transformResult', result, context)
|
||||||
|
|
||||||
|
// 7. 【并行】触发请求完成事件
|
||||||
|
await manager.executeParallel('onRequestEnd', context, transformedResult)
|
||||||
|
|
||||||
|
return transformedResult
|
||||||
|
} catch (error) {
|
||||||
|
// 8. 【并行】触发错误事件
|
||||||
|
await manager.executeParallel('onError', context, undefined, error)
|
||||||
|
throw error
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## 🔧 自定义插件
|
||||||
|
|
||||||
|
### 模型别名插件
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
const ModelAliasPlugin = definePlugin({
|
||||||
|
name: 'model-alias',
|
||||||
|
enforce: 'pre', // 最先执行
|
||||||
|
|
||||||
|
async resolveModel(modelId) {
|
||||||
|
const aliases = {
|
||||||
|
gpt4: 'gpt-4-turbo-preview',
|
||||||
|
claude: 'claude-3-sonnet-20240229'
|
||||||
|
}
|
||||||
|
return aliases[modelId] || null
|
||||||
|
}
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
|
### 参数验证插件
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
const ValidationPlugin = definePlugin({
|
||||||
|
name: 'validation',
|
||||||
|
|
||||||
|
async transformParams(params) {
|
||||||
|
if (!params.messages) {
|
||||||
|
throw new Error('messages is required')
|
||||||
|
}
|
||||||
|
|
||||||
|
return {
|
||||||
|
...params,
|
||||||
|
temperature: params.temperature ?? 0.7,
|
||||||
|
max_tokens: params.max_tokens ?? 4096
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
|
### 监控插件
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
const MonitoringPlugin = definePlugin({
|
||||||
|
name: 'monitoring',
|
||||||
|
enforce: 'post', // 最后执行
|
||||||
|
|
||||||
|
async onRequestEnd(context, result) {
|
||||||
|
const duration = Date.now() - context.startTime
|
||||||
|
console.log(`请求耗时: ${duration}ms`)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
|
### 内容过滤插件
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
const FilterPlugin = definePlugin({
|
||||||
|
name: 'content-filter',
|
||||||
|
|
||||||
|
transformStream() {
|
||||||
|
return () =>
|
||||||
|
new TransformStream({
|
||||||
|
transform(chunk, controller) {
|
||||||
|
if (chunk.type === 'text-delta') {
|
||||||
|
const filtered = chunk.textDelta.replace(/敏感词/g, '***')
|
||||||
|
controller.enqueue({ ...chunk, textDelta: filtered })
|
||||||
|
} else {
|
||||||
|
controller.enqueue(chunk)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
})
|
||||||
|
```
|
||||||
|
|
||||||
|
## 📊 执行顺序
|
||||||
|
|
||||||
|
### 插件排序
|
||||||
|
|
||||||
|
```
|
||||||
|
enforce: 'pre' → normal → enforce: 'post'
|
||||||
|
```
|
||||||
|
|
||||||
|
### 钩子执行流程
|
||||||
|
|
||||||
|
```mermaid
|
||||||
|
graph TD
|
||||||
|
A[请求开始] --> B[onRequestStart 并行执行]
|
||||||
|
B --> C[resolveModel 首个有效]
|
||||||
|
C --> D[loadTemplate 首个有效]
|
||||||
|
D --> E[transformParams 串行执行]
|
||||||
|
E --> F[collectStreamTransforms]
|
||||||
|
F --> G[AI SDK 调用]
|
||||||
|
G --> H[transformResult 串行执行]
|
||||||
|
H --> I[onRequestEnd 并行执行]
|
||||||
|
|
||||||
|
G --> J[异常处理]
|
||||||
|
J --> K[onError 并行执行]
|
||||||
|
```
|
||||||
|
|
||||||
|
## 💡 最佳实践
|
||||||
|
|
||||||
|
1. **功能单一**:每个插件专注一个功能
|
||||||
|
2. **幂等性**:插件应该是幂等的,重复执行不会产生副作用
|
||||||
|
3. **错误处理**:插件内部处理异常,不要让异常向上传播
|
||||||
|
4. **性能优化**:使用合适的钩子类型(First vs Sequential vs Parallel)
|
||||||
|
5. **命名规范**:使用语义化的插件名称
|
||||||
|
|
||||||
|
## 🔍 调试工具
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
// 查看插件统计信息
|
||||||
|
const stats = manager.getStats()
|
||||||
|
console.log('插件统计:', stats)
|
||||||
|
|
||||||
|
// 查看所有插件
|
||||||
|
const plugins = manager.getPlugins()
|
||||||
|
console.log(
|
||||||
|
'已注册插件:',
|
||||||
|
plugins.map((p) => p.name)
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
## ⚡ 性能优势
|
||||||
|
|
||||||
|
- **First 钩子**:一旦找到结果立即停止,避免无效计算
|
||||||
|
- **Parallel 钩子**:真正并发执行,不阻塞主流程
|
||||||
|
- **Sequential 钩子**:保证数据转换的顺序性
|
||||||
|
- **Stream 钩子**:直接集成 AI SDK,零开销
|
||||||
|
|
||||||
|
这个设计兼顾了简洁性和强大功能,为 AI Core 提供了灵活而高效的扩展机制。
|
||||||
10
packages/aiCore/src/core/plugins/built-in/index.ts
Normal file
10
packages/aiCore/src/core/plugins/built-in/index.ts
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
/**
|
||||||
|
* 内置插件命名空间
|
||||||
|
* 所有内置插件都以 'built-in:' 为前缀
|
||||||
|
*/
|
||||||
|
export const BUILT_IN_PLUGIN_PREFIX = 'built-in:'
|
||||||
|
|
||||||
|
export { createLoggingPlugin } from './logging'
|
||||||
|
export { createPromptToolUsePlugin } from './toolUsePlugin/promptToolUsePlugin'
|
||||||
|
export type { PromptToolUseConfig, ToolUseRequestContext, ToolUseResult } from './toolUsePlugin/type'
|
||||||
|
export { webSearchPlugin } from './webSearchPlugin'
|
||||||
86
packages/aiCore/src/core/plugins/built-in/logging.ts
Normal file
86
packages/aiCore/src/core/plugins/built-in/logging.ts
Normal file
@@ -0,0 +1,86 @@
|
|||||||
|
/**
|
||||||
|
* 内置插件:日志记录
|
||||||
|
* 记录AI调用的关键信息,支持性能监控和调试
|
||||||
|
*/
|
||||||
|
import { definePlugin } from '../index'
|
||||||
|
import type { AiRequestContext } from '../types'
|
||||||
|
|
||||||
|
export interface LoggingConfig {
|
||||||
|
// 日志级别
|
||||||
|
level?: 'debug' | 'info' | 'warn' | 'error'
|
||||||
|
// 是否记录参数
|
||||||
|
logParams?: boolean
|
||||||
|
// 是否记录结果
|
||||||
|
logResult?: boolean
|
||||||
|
// 是否记录性能数据
|
||||||
|
logPerformance?: boolean
|
||||||
|
// 自定义日志函数
|
||||||
|
logger?: (level: string, message: string, data?: any) => void
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 创建日志插件
|
||||||
|
*/
|
||||||
|
export function createLoggingPlugin(config: LoggingConfig = {}) {
|
||||||
|
const { level = 'info', logParams = true, logResult = false, logPerformance = true, logger = console.log } = config
|
||||||
|
|
||||||
|
const startTimes = new Map<string, number>()
|
||||||
|
|
||||||
|
return definePlugin({
|
||||||
|
name: 'built-in:logging',
|
||||||
|
|
||||||
|
onRequestStart: (context: AiRequestContext) => {
|
||||||
|
const requestId = context.requestId
|
||||||
|
startTimes.set(requestId, Date.now())
|
||||||
|
|
||||||
|
logger(level, `🚀 AI Request Started`, {
|
||||||
|
requestId,
|
||||||
|
providerId: context.providerId,
|
||||||
|
modelId: context.modelId,
|
||||||
|
originalParams: logParams ? context.originalParams : '[hidden]'
|
||||||
|
})
|
||||||
|
},
|
||||||
|
|
||||||
|
onRequestEnd: (context: AiRequestContext, result: any) => {
|
||||||
|
const requestId = context.requestId
|
||||||
|
const startTime = startTimes.get(requestId)
|
||||||
|
const duration = startTime ? Date.now() - startTime : undefined
|
||||||
|
startTimes.delete(requestId)
|
||||||
|
|
||||||
|
const logData: any = {
|
||||||
|
requestId,
|
||||||
|
providerId: context.providerId,
|
||||||
|
modelId: context.modelId
|
||||||
|
}
|
||||||
|
|
||||||
|
if (logPerformance && duration) {
|
||||||
|
logData.duration = `${duration}ms`
|
||||||
|
}
|
||||||
|
|
||||||
|
if (logResult) {
|
||||||
|
logData.result = result
|
||||||
|
}
|
||||||
|
|
||||||
|
logger(level, `✅ AI Request Completed`, logData)
|
||||||
|
},
|
||||||
|
|
||||||
|
onError: (error: Error, context: AiRequestContext) => {
|
||||||
|
const requestId = context.requestId
|
||||||
|
const startTime = startTimes.get(requestId)
|
||||||
|
const duration = startTime ? Date.now() - startTime : undefined
|
||||||
|
startTimes.delete(requestId)
|
||||||
|
|
||||||
|
logger('error', `❌ AI Request Failed`, {
|
||||||
|
requestId,
|
||||||
|
providerId: context.providerId,
|
||||||
|
modelId: context.modelId,
|
||||||
|
duration: duration ? `${duration}ms` : undefined,
|
||||||
|
error: {
|
||||||
|
name: error.name,
|
||||||
|
message: error.message,
|
||||||
|
stack: error.stack
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
@@ -0,0 +1,139 @@
|
|||||||
|
/**
|
||||||
|
* 流事件管理器
|
||||||
|
*
|
||||||
|
* 负责处理 AI SDK 流事件的发送和管理
|
||||||
|
* 从 promptToolUsePlugin.ts 中提取出来以降低复杂度
|
||||||
|
*/
|
||||||
|
import type { ModelMessage } from 'ai'
|
||||||
|
|
||||||
|
import type { AiRequestContext } from '../../types'
|
||||||
|
import type { StreamController } from './ToolExecutor'
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 流事件管理器类
|
||||||
|
*/
|
||||||
|
export class StreamEventManager {
|
||||||
|
/**
|
||||||
|
* 发送工具调用步骤开始事件
|
||||||
|
*/
|
||||||
|
sendStepStartEvent(controller: StreamController): void {
|
||||||
|
controller.enqueue({
|
||||||
|
type: 'start-step',
|
||||||
|
request: {},
|
||||||
|
warnings: []
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 发送步骤完成事件
|
||||||
|
*/
|
||||||
|
sendStepFinishEvent(controller: StreamController, chunk: any): void {
|
||||||
|
controller.enqueue({
|
||||||
|
type: 'finish-step',
|
||||||
|
finishReason: 'stop',
|
||||||
|
response: chunk.response,
|
||||||
|
usage: chunk.usage,
|
||||||
|
providerMetadata: chunk.providerMetadata
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 处理递归调用并将结果流接入当前流
|
||||||
|
*/
|
||||||
|
async handleRecursiveCall(
|
||||||
|
controller: StreamController,
|
||||||
|
recursiveParams: any,
|
||||||
|
context: AiRequestContext,
|
||||||
|
stepId: string
|
||||||
|
): Promise<void> {
|
||||||
|
try {
|
||||||
|
console.log('[MCP Prompt] Starting recursive call after tool execution...')
|
||||||
|
|
||||||
|
const recursiveResult = await context.recursiveCall(recursiveParams)
|
||||||
|
|
||||||
|
if (recursiveResult && recursiveResult.fullStream) {
|
||||||
|
await this.pipeRecursiveStream(controller, recursiveResult.fullStream)
|
||||||
|
} else {
|
||||||
|
console.warn('[MCP Prompt] No fullstream found in recursive result:', recursiveResult)
|
||||||
|
}
|
||||||
|
} catch (error) {
|
||||||
|
this.handleRecursiveCallError(controller, error, stepId)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 将递归流的数据传递到当前流
|
||||||
|
*/
|
||||||
|
private async pipeRecursiveStream(controller: StreamController, recursiveStream: ReadableStream): Promise<void> {
|
||||||
|
const reader = recursiveStream.getReader()
|
||||||
|
try {
|
||||||
|
while (true) {
|
||||||
|
const { done, value } = await reader.read()
|
||||||
|
if (done) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if (value.type === 'finish') {
|
||||||
|
// 迭代的流不发finish
|
||||||
|
break
|
||||||
|
}
|
||||||
|
// 将递归流的数据传递到当前流
|
||||||
|
controller.enqueue(value)
|
||||||
|
}
|
||||||
|
} finally {
|
||||||
|
reader.releaseLock()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 处理递归调用错误
|
||||||
|
*/
|
||||||
|
private handleRecursiveCallError(controller: StreamController, error: unknown, stepId: string): void {
|
||||||
|
console.error('[MCP Prompt] Recursive call failed:', error)
|
||||||
|
|
||||||
|
// 使用 AI SDK 标准错误格式,但不中断流
|
||||||
|
controller.enqueue({
|
||||||
|
type: 'error',
|
||||||
|
error: {
|
||||||
|
message: error instanceof Error ? error.message : String(error),
|
||||||
|
name: error instanceof Error ? error.name : 'RecursiveCallError'
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
// 继续发送文本增量,保持流的连续性
|
||||||
|
controller.enqueue({
|
||||||
|
type: 'text-delta',
|
||||||
|
id: stepId,
|
||||||
|
text: '\n\n[工具执行后递归调用失败,继续对话...]'
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 构建递归调用的参数
|
||||||
|
*/
|
||||||
|
buildRecursiveParams(context: AiRequestContext, textBuffer: string, toolResultsText: string, tools: any): any {
|
||||||
|
// 构建新的对话消息
|
||||||
|
const newMessages: ModelMessage[] = [
|
||||||
|
...(context.originalParams.messages || []),
|
||||||
|
{
|
||||||
|
role: 'assistant',
|
||||||
|
content: textBuffer
|
||||||
|
},
|
||||||
|
{
|
||||||
|
role: 'user',
|
||||||
|
content: toolResultsText
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
// 递归调用,继续对话,重新传递 tools
|
||||||
|
const recursiveParams = {
|
||||||
|
...context.originalParams,
|
||||||
|
messages: newMessages,
|
||||||
|
tools: tools
|
||||||
|
}
|
||||||
|
|
||||||
|
// 更新上下文中的消息
|
||||||
|
context.originalParams.messages = newMessages
|
||||||
|
|
||||||
|
return recursiveParams
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,156 @@
|
|||||||
|
/**
|
||||||
|
* 工具执行器
|
||||||
|
*
|
||||||
|
* 负责工具的执行、结果格式化和相关事件发送
|
||||||
|
* 从 promptToolUsePlugin.ts 中提取出来以降低复杂度
|
||||||
|
*/
|
||||||
|
import type { ToolSet } from 'ai'
|
||||||
|
|
||||||
|
import type { ToolUseResult } from './type'
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 工具执行结果
|
||||||
|
*/
|
||||||
|
export interface ExecutedResult {
|
||||||
|
toolCallId: string
|
||||||
|
toolName: string
|
||||||
|
result: any
|
||||||
|
isError?: boolean
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 流控制器类型(从 AI SDK 提取)
|
||||||
|
*/
|
||||||
|
export interface StreamController {
|
||||||
|
enqueue(chunk: any): void
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 工具执行器类
|
||||||
|
*/
|
||||||
|
export class ToolExecutor {
|
||||||
|
/**
|
||||||
|
* 执行多个工具调用
|
||||||
|
*/
|
||||||
|
async executeTools(
|
||||||
|
toolUses: ToolUseResult[],
|
||||||
|
tools: ToolSet,
|
||||||
|
controller: StreamController
|
||||||
|
): Promise<ExecutedResult[]> {
|
||||||
|
const executedResults: ExecutedResult[] = []
|
||||||
|
|
||||||
|
for (const toolUse of toolUses) {
|
||||||
|
try {
|
||||||
|
const tool = tools[toolUse.toolName]
|
||||||
|
if (!tool || typeof tool.execute !== 'function') {
|
||||||
|
throw new Error(`Tool "${toolUse.toolName}" has no execute method`)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 发送工具调用开始事件
|
||||||
|
this.sendToolStartEvents(controller, toolUse)
|
||||||
|
|
||||||
|
console.log(`[MCP Prompt Stream] Executing tool: ${toolUse.toolName}`, toolUse.arguments)
|
||||||
|
|
||||||
|
// 发送 tool-call 事件
|
||||||
|
controller.enqueue({
|
||||||
|
type: 'tool-call',
|
||||||
|
toolCallId: toolUse.id,
|
||||||
|
toolName: toolUse.toolName,
|
||||||
|
input: tool.inputSchema
|
||||||
|
})
|
||||||
|
|
||||||
|
const result = await tool.execute(toolUse.arguments, {
|
||||||
|
toolCallId: toolUse.id,
|
||||||
|
messages: [],
|
||||||
|
abortSignal: new AbortController().signal
|
||||||
|
})
|
||||||
|
|
||||||
|
// 发送 tool-result 事件
|
||||||
|
controller.enqueue({
|
||||||
|
type: 'tool-result',
|
||||||
|
toolCallId: toolUse.id,
|
||||||
|
toolName: toolUse.toolName,
|
||||||
|
input: toolUse.arguments,
|
||||||
|
output: result
|
||||||
|
})
|
||||||
|
|
||||||
|
executedResults.push({
|
||||||
|
toolCallId: toolUse.id,
|
||||||
|
toolName: toolUse.toolName,
|
||||||
|
result,
|
||||||
|
isError: false
|
||||||
|
})
|
||||||
|
} catch (error) {
|
||||||
|
console.error(`[MCP Prompt Stream] Tool execution failed: ${toolUse.toolName}`, error)
|
||||||
|
|
||||||
|
// 处理错误情况
|
||||||
|
const errorResult = this.handleToolError(toolUse, error, controller)
|
||||||
|
executedResults.push(errorResult)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return executedResults
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 格式化工具结果为 Cherry Studio 标准格式
|
||||||
|
*/
|
||||||
|
formatToolResults(executedResults: ExecutedResult[]): string {
|
||||||
|
return executedResults
|
||||||
|
.map((tr) => {
|
||||||
|
if (!tr.isError) {
|
||||||
|
return `<tool_use_result>\n <name>${tr.toolName}</name>\n <result>${JSON.stringify(tr.result)}</result>\n</tool_use_result>`
|
||||||
|
} else {
|
||||||
|
const error = tr.result || 'Unknown error'
|
||||||
|
return `<tool_use_result>\n <name>${tr.toolName}</name>\n <error>${error}</error>\n</tool_use_result>`
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.join('\n\n')
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 发送工具调用开始相关事件
|
||||||
|
*/
|
||||||
|
private sendToolStartEvents(controller: StreamController, toolUse: ToolUseResult): void {
|
||||||
|
// 发送 tool-input-start 事件
|
||||||
|
controller.enqueue({
|
||||||
|
type: 'tool-input-start',
|
||||||
|
id: toolUse.id,
|
||||||
|
toolName: toolUse.toolName
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 处理工具执行错误
|
||||||
|
*/
|
||||||
|
private handleToolError(
|
||||||
|
toolUse: ToolUseResult,
|
||||||
|
error: unknown,
|
||||||
|
controller: StreamController
|
||||||
|
// _tools: ToolSet
|
||||||
|
): ExecutedResult {
|
||||||
|
// 使用 AI SDK 标准错误格式
|
||||||
|
// const toolError: TypedToolError<typeof _tools> = {
|
||||||
|
// type: 'tool-error',
|
||||||
|
// toolCallId: toolUse.id,
|
||||||
|
// toolName: toolUse.toolName,
|
||||||
|
// input: toolUse.arguments,
|
||||||
|
// error: error instanceof Error ? error.message : String(error)
|
||||||
|
// }
|
||||||
|
|
||||||
|
// controller.enqueue(toolError)
|
||||||
|
|
||||||
|
// 发送标准错误事件
|
||||||
|
controller.enqueue({
|
||||||
|
type: 'error',
|
||||||
|
error: error instanceof Error ? error.message : String(error)
|
||||||
|
})
|
||||||
|
|
||||||
|
return {
|
||||||
|
toolCallId: toolUse.id,
|
||||||
|
toolName: toolUse.toolName,
|
||||||
|
result: error instanceof Error ? error.message : String(error),
|
||||||
|
isError: true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,373 @@
|
|||||||
|
/**
|
||||||
|
* 内置插件:MCP Prompt 模式
|
||||||
|
* 为不支持原生 Function Call 的模型提供 prompt 方式的工具调用
|
||||||
|
* 内置默认逻辑,支持自定义覆盖
|
||||||
|
*/
|
||||||
|
import type { TextStreamPart, ToolSet } from 'ai'
|
||||||
|
|
||||||
|
import { definePlugin } from '../../index'
|
||||||
|
import type { AiRequestContext } from '../../types'
|
||||||
|
import { StreamEventManager } from './StreamEventManager'
|
||||||
|
import { ToolExecutor } from './ToolExecutor'
|
||||||
|
import { PromptToolUseConfig, ToolUseResult } from './type'
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 默认系统提示符模板(提取自 Cherry Studio)
|
||||||
|
*/
|
||||||
|
const DEFAULT_SYSTEM_PROMPT = `In this environment you have access to a set of tools you can use to answer the user's question. \\
|
||||||
|
You can use one tool per message, and will receive the result of that tool use in the user's response. You use tools step-by-step to accomplish a given task, with each tool use informed by the result of the previous tool use.
|
||||||
|
|
||||||
|
## Tool Use Formatting
|
||||||
|
|
||||||
|
Tool use is formatted using XML-style tags. The tool name is enclosed in opening and closing tags, and each parameter is similarly enclosed within its own set of tags. Here's the structure:
|
||||||
|
|
||||||
|
<tool_use>
|
||||||
|
<name>{tool_name}</name>
|
||||||
|
<arguments>{json_arguments}</arguments>
|
||||||
|
</tool_use>
|
||||||
|
|
||||||
|
The tool name should be the exact name of the tool you are using, and the arguments should be a JSON object containing the parameters required by that tool. For example:
|
||||||
|
<tool_use>
|
||||||
|
<name>python_interpreter</name>
|
||||||
|
<arguments>{"code": "5 + 3 + 1294.678"}</arguments>
|
||||||
|
</tool_use>
|
||||||
|
|
||||||
|
The user will respond with the result of the tool use, which should be formatted as follows:
|
||||||
|
|
||||||
|
<tool_use_result>
|
||||||
|
<name>{tool_name}</name>
|
||||||
|
<result>{result}</result>
|
||||||
|
</tool_use_result>
|
||||||
|
|
||||||
|
The result should be a string, which can represent a file or any other output type. You can use this result as input for the next action.
|
||||||
|
For example, if the result of the tool use is an image file, you can use it in the next action like this:
|
||||||
|
|
||||||
|
<tool_use>
|
||||||
|
<name>image_transformer</name>
|
||||||
|
<arguments>{"image": "image_1.jpg"}</arguments>
|
||||||
|
</tool_use>
|
||||||
|
|
||||||
|
Always adhere to this format for the tool use to ensure proper parsing and execution.
|
||||||
|
|
||||||
|
## Tool Use Examples
|
||||||
|
{{ TOOL_USE_EXAMPLES }}
|
||||||
|
|
||||||
|
## Tool Use Available Tools
|
||||||
|
Above example were using notional tools that might not exist for you. You only have access to these tools:
|
||||||
|
{{ AVAILABLE_TOOLS }}
|
||||||
|
|
||||||
|
## Tool Use Rules
|
||||||
|
Here are the rules you should always follow to solve your task:
|
||||||
|
1. Always use the right arguments for the tools. Never use variable names as the action arguments, use the value instead.
|
||||||
|
2. Call a tool only when needed: do not call the search agent if you do not need information, try to solve the task yourself.
|
||||||
|
3. If no tool call is needed, just answer the question directly.
|
||||||
|
4. Never re-do a tool call that you previously did with the exact same parameters.
|
||||||
|
5. For tool use, MAKE SURE use XML tag format as shown in the examples above. Do not use any other format.
|
||||||
|
|
||||||
|
# User Instructions
|
||||||
|
{{ USER_SYSTEM_PROMPT }}
|
||||||
|
|
||||||
|
Now Begin! If you solve the task correctly, you will receive a reward of $1,000,000.`
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 默认工具使用示例(提取自 Cherry Studio)
|
||||||
|
*/
|
||||||
|
const DEFAULT_TOOL_USE_EXAMPLES = `
|
||||||
|
Here are a few examples using notional tools:
|
||||||
|
---
|
||||||
|
User: Generate an image of the oldest person in this document.
|
||||||
|
|
||||||
|
A: I can use the document_qa tool to find out who the oldest person is in the document.
|
||||||
|
<tool_use>
|
||||||
|
<name>document_qa</name>
|
||||||
|
<arguments>{"document": "document.pdf", "question": "Who is the oldest person mentioned?"}</arguments>
|
||||||
|
</tool_use>
|
||||||
|
|
||||||
|
User: <tool_use_result>
|
||||||
|
<name>document_qa</name>
|
||||||
|
<result>John Doe, a 55 year old lumberjack living in Newfoundland.</result>
|
||||||
|
</tool_use_result>
|
||||||
|
|
||||||
|
A: I can use the image_generator tool to create a portrait of John Doe.
|
||||||
|
<tool_use>
|
||||||
|
<name>image_generator</name>
|
||||||
|
<arguments>{"prompt": "A portrait of John Doe, a 55-year-old man living in Canada."}</arguments>
|
||||||
|
</tool_use>
|
||||||
|
|
||||||
|
User: <tool_use_result>
|
||||||
|
<name>image_generator</name>
|
||||||
|
<result>image.png</result>
|
||||||
|
</tool_use_result>
|
||||||
|
|
||||||
|
A: the image is generated as image.png
|
||||||
|
|
||||||
|
---
|
||||||
|
User: "What is the result of the following operation: 5 + 3 + 1294.678?"
|
||||||
|
|
||||||
|
A: I can use the python_interpreter tool to calculate the result of the operation.
|
||||||
|
<tool_use>
|
||||||
|
<name>python_interpreter</name>
|
||||||
|
<arguments>{"code": "5 + 3 + 1294.678"}</arguments>
|
||||||
|
</tool_use>
|
||||||
|
|
||||||
|
User: <tool_use_result>
|
||||||
|
<name>python_interpreter</name>
|
||||||
|
<result>1302.678</result>
|
||||||
|
</tool_use_result>
|
||||||
|
|
||||||
|
A: The result of the operation is 1302.678.
|
||||||
|
|
||||||
|
---
|
||||||
|
User: "Which city has the highest population , Guangzhou or Shanghai?"
|
||||||
|
|
||||||
|
A: I can use the search tool to find the population of Guangzhou.
|
||||||
|
<tool_use>
|
||||||
|
<name>search</name>
|
||||||
|
<arguments>{"query": "Population Guangzhou"}</arguments>
|
||||||
|
</tool_use>
|
||||||
|
|
||||||
|
User: <tool_use_result>
|
||||||
|
<name>search</name>
|
||||||
|
<result>Guangzhou has a population of 15 million inhabitants as of 2021.</result>
|
||||||
|
</tool_use_result>
|
||||||
|
|
||||||
|
A: I can use the search tool to find the population of Shanghai.
|
||||||
|
<tool_use>
|
||||||
|
<name>search</name>
|
||||||
|
<arguments>{"query": "Population Shanghai"}</arguments>
|
||||||
|
</tool_use>
|
||||||
|
|
||||||
|
User: <tool_use_result>
|
||||||
|
<name>search</name>
|
||||||
|
<result>26 million (2019)</result>
|
||||||
|
</tool_use_result>
|
||||||
|
Assistant: The population of Shanghai is 26 million, while Guangzhou has a population of 15 million. Therefore, Shanghai has the highest population.`
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 构建可用工具部分(提取自 Cherry Studio)
|
||||||
|
*/
|
||||||
|
function buildAvailableTools(tools: ToolSet): string {
|
||||||
|
const availableTools = Object.keys(tools)
|
||||||
|
.map((toolName: string) => {
|
||||||
|
const tool = tools[toolName]
|
||||||
|
return `
|
||||||
|
<tool>
|
||||||
|
<name>${toolName}</name>
|
||||||
|
<description>${tool.description || ''}</description>
|
||||||
|
<arguments>
|
||||||
|
${tool.inputSchema ? JSON.stringify(tool.inputSchema) : ''}
|
||||||
|
</arguments>
|
||||||
|
</tool>
|
||||||
|
`
|
||||||
|
})
|
||||||
|
.join('\n')
|
||||||
|
return `<tools>
|
||||||
|
${availableTools}
|
||||||
|
</tools>`
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 默认的系统提示符构建函数(提取自 Cherry Studio)
|
||||||
|
*/
|
||||||
|
function defaultBuildSystemPrompt(userSystemPrompt: string, tools: ToolSet): string {
|
||||||
|
const availableTools = buildAvailableTools(tools)
|
||||||
|
|
||||||
|
const fullPrompt = DEFAULT_SYSTEM_PROMPT.replace('{{ TOOL_USE_EXAMPLES }}', DEFAULT_TOOL_USE_EXAMPLES)
|
||||||
|
.replace('{{ AVAILABLE_TOOLS }}', availableTools)
|
||||||
|
.replace('{{ USER_SYSTEM_PROMPT }}', userSystemPrompt || '')
|
||||||
|
|
||||||
|
return fullPrompt
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 默认工具解析函数(提取自 Cherry Studio)
|
||||||
|
* 解析 XML 格式的工具调用
|
||||||
|
*/
|
||||||
|
function defaultParseToolUse(content: string, tools: ToolSet): { results: ToolUseResult[]; content: string } {
|
||||||
|
if (!content || !tools || Object.keys(tools).length === 0) {
|
||||||
|
return { results: [], content: content }
|
||||||
|
}
|
||||||
|
|
||||||
|
// 支持两种格式:
|
||||||
|
// 1. 完整的 <tool_use></tool_use> 标签包围的内容
|
||||||
|
// 2. 只有内部内容(从 TagExtractor 提取出来的)
|
||||||
|
|
||||||
|
let contentToProcess = content
|
||||||
|
// 如果内容不包含 <tool_use> 标签,说明是从 TagExtractor 提取的内部内容,需要包装
|
||||||
|
if (!content.includes('<tool_use>')) {
|
||||||
|
contentToProcess = `<tool_use>\n${content}\n</tool_use>`
|
||||||
|
}
|
||||||
|
|
||||||
|
const toolUsePattern =
|
||||||
|
/<tool_use>([\s\S]*?)<name>([\s\S]*?)<\/name>([\s\S]*?)<arguments>([\s\S]*?)<\/arguments>([\s\S]*?)<\/tool_use>/g
|
||||||
|
const results: ToolUseResult[] = []
|
||||||
|
let match
|
||||||
|
let idx = 0
|
||||||
|
|
||||||
|
// Find all tool use blocks
|
||||||
|
while ((match = toolUsePattern.exec(contentToProcess)) !== null) {
|
||||||
|
const fullMatch = match[0]
|
||||||
|
const toolName = match[2].trim()
|
||||||
|
const toolArgs = match[4].trim()
|
||||||
|
|
||||||
|
// Try to parse the arguments as JSON
|
||||||
|
let parsedArgs
|
||||||
|
try {
|
||||||
|
parsedArgs = JSON.parse(toolArgs)
|
||||||
|
} catch (error) {
|
||||||
|
// If parsing fails, use the string as is
|
||||||
|
parsedArgs = toolArgs
|
||||||
|
}
|
||||||
|
|
||||||
|
// Find the corresponding tool
|
||||||
|
const tool = tools[toolName]
|
||||||
|
if (!tool) {
|
||||||
|
console.warn(`Tool "${toolName}" not found in available tools`)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add to results array
|
||||||
|
results.push({
|
||||||
|
id: `${toolName}-${idx++}`, // Unique ID for each tool use
|
||||||
|
toolName: toolName,
|
||||||
|
arguments: parsedArgs,
|
||||||
|
status: 'pending'
|
||||||
|
})
|
||||||
|
contentToProcess = contentToProcess.replace(fullMatch, '')
|
||||||
|
}
|
||||||
|
return { results, content: contentToProcess }
|
||||||
|
}
|
||||||
|
|
||||||
|
export const createPromptToolUsePlugin = (config: PromptToolUseConfig = {}) => {
|
||||||
|
const { enabled = true, buildSystemPrompt = defaultBuildSystemPrompt, parseToolUse = defaultParseToolUse } = config
|
||||||
|
|
||||||
|
return definePlugin({
|
||||||
|
name: 'built-in:prompt-tool-use',
|
||||||
|
transformParams: (params: any, context: AiRequestContext) => {
|
||||||
|
if (!enabled || !params.tools || typeof params.tools !== 'object') {
|
||||||
|
return params
|
||||||
|
}
|
||||||
|
|
||||||
|
context.mcpTools = params.tools
|
||||||
|
console.log('tools stored in context', params.tools)
|
||||||
|
|
||||||
|
// 构建系统提示符
|
||||||
|
const userSystemPrompt = typeof params.system === 'string' ? params.system : ''
|
||||||
|
const systemPrompt = buildSystemPrompt(userSystemPrompt, params.tools)
|
||||||
|
let systemMessage: string | null = systemPrompt
|
||||||
|
console.log('config.context', context)
|
||||||
|
if (config.createSystemMessage) {
|
||||||
|
// 🎯 如果用户提供了自定义处理函数,使用它
|
||||||
|
systemMessage = config.createSystemMessage(systemPrompt, params, context)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 移除 tools,改为 prompt 模式
|
||||||
|
const transformedParams = {
|
||||||
|
...params,
|
||||||
|
...(systemMessage ? { system: systemMessage } : {}),
|
||||||
|
tools: undefined
|
||||||
|
}
|
||||||
|
context.originalParams = transformedParams
|
||||||
|
console.log('transformedParams', transformedParams)
|
||||||
|
return transformedParams
|
||||||
|
},
|
||||||
|
transformStream: (_: any, context: AiRequestContext) => () => {
|
||||||
|
let textBuffer = ''
|
||||||
|
let stepId = ''
|
||||||
|
|
||||||
|
if (!context.mcpTools) {
|
||||||
|
throw new Error('No tools available')
|
||||||
|
}
|
||||||
|
|
||||||
|
// 创建工具执行器和流事件管理器
|
||||||
|
const toolExecutor = new ToolExecutor()
|
||||||
|
const streamEventManager = new StreamEventManager()
|
||||||
|
|
||||||
|
type TOOLS = NonNullable<typeof context.mcpTools>
|
||||||
|
return new TransformStream<TextStreamPart<TOOLS>, TextStreamPart<TOOLS>>({
|
||||||
|
async transform(
|
||||||
|
chunk: TextStreamPart<TOOLS>,
|
||||||
|
controller: TransformStreamDefaultController<TextStreamPart<TOOLS>>
|
||||||
|
) {
|
||||||
|
// 收集文本内容
|
||||||
|
if (chunk.type === 'text-delta') {
|
||||||
|
textBuffer += chunk.text || ''
|
||||||
|
stepId = chunk.id || ''
|
||||||
|
controller.enqueue(chunk)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if (chunk.type === 'text-end' || chunk.type === 'finish-step') {
|
||||||
|
const tools = context.mcpTools
|
||||||
|
if (!tools || Object.keys(tools).length === 0) {
|
||||||
|
controller.enqueue(chunk)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 解析工具调用
|
||||||
|
const { results: parsedTools, content: parsedContent } = parseToolUse(textBuffer, tools)
|
||||||
|
const validToolUses = parsedTools.filter((t) => t.status === 'pending')
|
||||||
|
|
||||||
|
// 如果没有有效的工具调用,直接传递原始事件
|
||||||
|
if (validToolUses.length === 0) {
|
||||||
|
controller.enqueue(chunk)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if (chunk.type === 'text-end') {
|
||||||
|
controller.enqueue({
|
||||||
|
type: 'text-end',
|
||||||
|
id: stepId,
|
||||||
|
providerMetadata: {
|
||||||
|
text: {
|
||||||
|
value: parsedContent
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
controller.enqueue({
|
||||||
|
...chunk,
|
||||||
|
finishReason: 'tool-calls'
|
||||||
|
})
|
||||||
|
|
||||||
|
// 发送步骤开始事件
|
||||||
|
streamEventManager.sendStepStartEvent(controller)
|
||||||
|
|
||||||
|
// 执行工具调用
|
||||||
|
const executedResults = await toolExecutor.executeTools(validToolUses, tools, controller)
|
||||||
|
|
||||||
|
// 发送步骤完成事件
|
||||||
|
streamEventManager.sendStepFinishEvent(controller, chunk)
|
||||||
|
|
||||||
|
// 处理递归调用
|
||||||
|
if (validToolUses.length > 0) {
|
||||||
|
const toolResultsText = toolExecutor.formatToolResults(executedResults)
|
||||||
|
const recursiveParams = streamEventManager.buildRecursiveParams(
|
||||||
|
context,
|
||||||
|
textBuffer,
|
||||||
|
toolResultsText,
|
||||||
|
tools
|
||||||
|
)
|
||||||
|
|
||||||
|
await streamEventManager.handleRecursiveCall(controller, recursiveParams, context, stepId)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 清理状态
|
||||||
|
textBuffer = ''
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 对于其他类型的事件,直接传递
|
||||||
|
controller.enqueue(chunk)
|
||||||
|
},
|
||||||
|
|
||||||
|
flush() {
|
||||||
|
// 流结束时的清理工作
|
||||||
|
console.log('[MCP Prompt] Stream ended, cleaning up...')
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
@@ -0,0 +1,196 @@
|
|||||||
|
// Copied from https://github.com/vercel/ai/blob/main/packages/ai/core/util/get-potential-start-index.ts
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Returns the index of the start of the searchedText in the text, or null if it
|
||||||
|
* is not found.
|
||||||
|
*/
|
||||||
|
export function getPotentialStartIndex(text: string, searchedText: string): number | null {
|
||||||
|
// Return null immediately if searchedText is empty.
|
||||||
|
if (searchedText.length === 0) {
|
||||||
|
return null
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if the searchedText exists as a direct substring of text.
|
||||||
|
const directIndex = text.indexOf(searchedText)
|
||||||
|
if (directIndex !== -1) {
|
||||||
|
return directIndex
|
||||||
|
}
|
||||||
|
|
||||||
|
// Otherwise, look for the largest suffix of "text" that matches
|
||||||
|
// a prefix of "searchedText". We go from the end of text inward.
|
||||||
|
for (let i = text.length - 1; i >= 0; i--) {
|
||||||
|
const suffix = text.substring(i)
|
||||||
|
if (searchedText.startsWith(suffix)) {
|
||||||
|
return i
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return null
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface TagConfig {
|
||||||
|
openingTag: string
|
||||||
|
closingTag: string
|
||||||
|
separator?: string
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface TagExtractionState {
|
||||||
|
textBuffer: string
|
||||||
|
isInsideTag: boolean
|
||||||
|
isFirstTag: boolean
|
||||||
|
isFirstText: boolean
|
||||||
|
afterSwitch: boolean
|
||||||
|
accumulatedTagContent: string
|
||||||
|
hasTagContent: boolean
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface TagExtractionResult {
|
||||||
|
content: string
|
||||||
|
isTagContent: boolean
|
||||||
|
complete: boolean
|
||||||
|
tagContentExtracted?: string
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 通用标签提取处理器
|
||||||
|
* 可以处理各种形式的标签对,如 <think>...</think>, <tool_use>...</tool_use> 等
|
||||||
|
*/
|
||||||
|
export class TagExtractor {
|
||||||
|
private config: TagConfig
|
||||||
|
private state: TagExtractionState
|
||||||
|
|
||||||
|
constructor(config: TagConfig) {
|
||||||
|
this.config = config
|
||||||
|
this.state = {
|
||||||
|
textBuffer: '',
|
||||||
|
isInsideTag: false,
|
||||||
|
isFirstTag: true,
|
||||||
|
isFirstText: true,
|
||||||
|
afterSwitch: false,
|
||||||
|
accumulatedTagContent: '',
|
||||||
|
hasTagContent: false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 处理文本块,返回处理结果
|
||||||
|
*/
|
||||||
|
processText(newText: string): TagExtractionResult[] {
|
||||||
|
this.state.textBuffer += newText
|
||||||
|
const results: TagExtractionResult[] = []
|
||||||
|
|
||||||
|
// 处理标签提取逻辑
|
||||||
|
while (true) {
|
||||||
|
const nextTag = this.state.isInsideTag ? this.config.closingTag : this.config.openingTag
|
||||||
|
const startIndex = getPotentialStartIndex(this.state.textBuffer, nextTag)
|
||||||
|
|
||||||
|
if (startIndex == null) {
|
||||||
|
const content = this.state.textBuffer
|
||||||
|
if (content.length > 0) {
|
||||||
|
results.push({
|
||||||
|
content: this.addPrefix(content),
|
||||||
|
isTagContent: this.state.isInsideTag,
|
||||||
|
complete: false
|
||||||
|
})
|
||||||
|
|
||||||
|
if (this.state.isInsideTag) {
|
||||||
|
this.state.accumulatedTagContent += this.addPrefix(content)
|
||||||
|
this.state.hasTagContent = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
this.state.textBuffer = ''
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
// 处理标签前的内容
|
||||||
|
const contentBeforeTag = this.state.textBuffer.slice(0, startIndex)
|
||||||
|
if (contentBeforeTag.length > 0) {
|
||||||
|
results.push({
|
||||||
|
content: this.addPrefix(contentBeforeTag),
|
||||||
|
isTagContent: this.state.isInsideTag,
|
||||||
|
complete: false
|
||||||
|
})
|
||||||
|
|
||||||
|
if (this.state.isInsideTag) {
|
||||||
|
this.state.accumulatedTagContent += this.addPrefix(contentBeforeTag)
|
||||||
|
this.state.hasTagContent = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const foundFullMatch = startIndex + nextTag.length <= this.state.textBuffer.length
|
||||||
|
|
||||||
|
if (foundFullMatch) {
|
||||||
|
// 如果找到完整的标签
|
||||||
|
this.state.textBuffer = this.state.textBuffer.slice(startIndex + nextTag.length)
|
||||||
|
|
||||||
|
// 如果刚刚结束一个标签内容,生成完整的标签内容结果
|
||||||
|
if (this.state.isInsideTag && this.state.hasTagContent) {
|
||||||
|
results.push({
|
||||||
|
content: '',
|
||||||
|
isTagContent: false,
|
||||||
|
complete: true,
|
||||||
|
tagContentExtracted: this.state.accumulatedTagContent
|
||||||
|
})
|
||||||
|
this.state.accumulatedTagContent = ''
|
||||||
|
this.state.hasTagContent = false
|
||||||
|
}
|
||||||
|
|
||||||
|
this.state.isInsideTag = !this.state.isInsideTag
|
||||||
|
this.state.afterSwitch = true
|
||||||
|
|
||||||
|
if (this.state.isInsideTag) {
|
||||||
|
this.state.isFirstTag = false
|
||||||
|
} else {
|
||||||
|
this.state.isFirstText = false
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
this.state.textBuffer = this.state.textBuffer.slice(startIndex)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return results
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 完成处理,返回任何剩余的标签内容
|
||||||
|
*/
|
||||||
|
finalize(): TagExtractionResult | null {
|
||||||
|
if (this.state.hasTagContent && this.state.accumulatedTagContent) {
|
||||||
|
const result = {
|
||||||
|
content: '',
|
||||||
|
isTagContent: false,
|
||||||
|
complete: true,
|
||||||
|
tagContentExtracted: this.state.accumulatedTagContent
|
||||||
|
}
|
||||||
|
this.state.accumulatedTagContent = ''
|
||||||
|
this.state.hasTagContent = false
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
return null
|
||||||
|
}
|
||||||
|
|
||||||
|
private addPrefix(text: string): string {
|
||||||
|
const needsPrefix =
|
||||||
|
this.state.afterSwitch && (this.state.isInsideTag ? !this.state.isFirstTag : !this.state.isFirstText)
|
||||||
|
|
||||||
|
const prefix = needsPrefix && this.config.separator ? this.config.separator : ''
|
||||||
|
this.state.afterSwitch = false
|
||||||
|
return prefix + text
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 重置状态
|
||||||
|
*/
|
||||||
|
reset(): void {
|
||||||
|
this.state = {
|
||||||
|
textBuffer: '',
|
||||||
|
isInsideTag: false,
|
||||||
|
isFirstTag: true,
|
||||||
|
isFirstText: true,
|
||||||
|
afterSwitch: false,
|
||||||
|
accumulatedTagContent: '',
|
||||||
|
hasTagContent: false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,33 @@
|
|||||||
|
import { ToolSet } from 'ai'
|
||||||
|
|
||||||
|
import { AiRequestContext } from '../..'
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 解析结果类型
|
||||||
|
* 表示从AI响应中解析出的工具使用意图
|
||||||
|
*/
|
||||||
|
export interface ToolUseResult {
|
||||||
|
id: string
|
||||||
|
toolName: string
|
||||||
|
arguments: any
|
||||||
|
status: 'pending' | 'invoking' | 'done' | 'error'
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface BaseToolUsePluginConfig {
|
||||||
|
enabled?: boolean
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface PromptToolUseConfig extends BaseToolUsePluginConfig {
|
||||||
|
// 自定义系统提示符构建函数(可选,有默认实现)
|
||||||
|
buildSystemPrompt?: (userSystemPrompt: string, tools: ToolSet) => string
|
||||||
|
// 自定义工具解析函数(可选,有默认实现)
|
||||||
|
parseToolUse?: (content: string, tools: ToolSet) => { results: ToolUseResult[]; content: string }
|
||||||
|
createSystemMessage?: (systemPrompt: string, originalParams: any, context: AiRequestContext) => string | null
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 扩展的 AI 请求上下文,支持 MCP 工具存储
|
||||||
|
*/
|
||||||
|
export interface ToolUseRequestContext extends AiRequestContext {
|
||||||
|
mcpTools: ToolSet
|
||||||
|
}
|
||||||
@@ -0,0 +1,67 @@
|
|||||||
|
import { anthropic } from '@ai-sdk/anthropic'
|
||||||
|
import { google } from '@ai-sdk/google'
|
||||||
|
import { openai } from '@ai-sdk/openai'
|
||||||
|
|
||||||
|
import { ProviderOptionsMap } from '../../../options/types'
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 从 AI SDK 的工具函数中提取参数类型,以确保类型安全。
|
||||||
|
*/
|
||||||
|
type OpenAISearchConfig = Parameters<typeof openai.tools.webSearchPreview>[0]
|
||||||
|
type AnthropicSearchConfig = Parameters<typeof anthropic.tools.webSearch_20250305>[0]
|
||||||
|
type GoogleSearchConfig = Parameters<typeof google.tools.googleSearch>[0]
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 插件初始化时接收的完整配置对象
|
||||||
|
*
|
||||||
|
* 其结构与 ProviderOptions 保持一致,方便上游统一管理配置
|
||||||
|
*/
|
||||||
|
export interface WebSearchPluginConfig {
|
||||||
|
openai?: OpenAISearchConfig
|
||||||
|
anthropic?: AnthropicSearchConfig
|
||||||
|
xai?: ProviderOptionsMap['xai']['searchParameters']
|
||||||
|
google?: GoogleSearchConfig
|
||||||
|
'google-vertex'?: GoogleSearchConfig
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 插件的默认配置
|
||||||
|
*/
|
||||||
|
export const DEFAULT_WEB_SEARCH_CONFIG: WebSearchPluginConfig = {
|
||||||
|
google: {},
|
||||||
|
'google-vertex': {},
|
||||||
|
openai: {},
|
||||||
|
xai: {
|
||||||
|
mode: 'on',
|
||||||
|
returnCitations: true,
|
||||||
|
maxSearchResults: 5,
|
||||||
|
sources: [{ type: 'web' }, { type: 'x' }, { type: 'news' }]
|
||||||
|
},
|
||||||
|
anthropic: {
|
||||||
|
maxUses: 5
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
export type WebSearchToolOutputSchema = {
|
||||||
|
// Anthropic 工具 - 手动定义
|
||||||
|
anthropicWebSearch: Array<{
|
||||||
|
url: string
|
||||||
|
title: string
|
||||||
|
pageAge: string | null
|
||||||
|
encryptedContent: string
|
||||||
|
type: string
|
||||||
|
}>
|
||||||
|
|
||||||
|
// OpenAI 工具 - 基于实际输出
|
||||||
|
openaiWebSearch: {
|
||||||
|
status: 'completed' | 'failed'
|
||||||
|
}
|
||||||
|
|
||||||
|
// Google 工具
|
||||||
|
googleSearch: {
|
||||||
|
webSearchQueries?: string[]
|
||||||
|
groundingChunks?: Array<{
|
||||||
|
web?: { uri: string; title: string }
|
||||||
|
}>
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,69 @@
|
|||||||
|
/**
|
||||||
|
* Web Search Plugin
|
||||||
|
* 提供统一的网络搜索能力,支持多个 AI Provider
|
||||||
|
*/
|
||||||
|
import { anthropic } from '@ai-sdk/anthropic'
|
||||||
|
import { google } from '@ai-sdk/google'
|
||||||
|
import { openai } from '@ai-sdk/openai'
|
||||||
|
|
||||||
|
import { createXaiOptions, mergeProviderOptions } from '../../../options'
|
||||||
|
import { definePlugin } from '../../'
|
||||||
|
import type { AiRequestContext } from '../../types'
|
||||||
|
import { DEFAULT_WEB_SEARCH_CONFIG, WebSearchPluginConfig } from './helper'
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 网络搜索插件
|
||||||
|
*
|
||||||
|
* @param config - 在插件初始化时传入的静态配置
|
||||||
|
*/
|
||||||
|
export const webSearchPlugin = (config: WebSearchPluginConfig = DEFAULT_WEB_SEARCH_CONFIG) =>
|
||||||
|
definePlugin({
|
||||||
|
name: 'webSearch',
|
||||||
|
enforce: 'pre',
|
||||||
|
|
||||||
|
transformParams: async (params: any, context: AiRequestContext) => {
|
||||||
|
const { providerId } = context
|
||||||
|
switch (providerId) {
|
||||||
|
case 'openai': {
|
||||||
|
if (config.openai) {
|
||||||
|
if (!params.tools) params.tools = {}
|
||||||
|
params.tools.web_search_preview = openai.tools.webSearchPreview(config.openai)
|
||||||
|
}
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
case 'anthropic': {
|
||||||
|
if (config.anthropic) {
|
||||||
|
if (!params.tools) params.tools = {}
|
||||||
|
params.tools.web_search = anthropic.tools.webSearch_20250305(config.anthropic)
|
||||||
|
}
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
case 'google': {
|
||||||
|
// case 'google-vertex':
|
||||||
|
if (!params.tools) params.tools = {}
|
||||||
|
params.tools.web_search = google.tools.googleSearch(config.google || {})
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
case 'xai': {
|
||||||
|
if (config.xai) {
|
||||||
|
const searchOptions = createXaiOptions({
|
||||||
|
searchParameters: { ...config.xai, mode: 'on' }
|
||||||
|
})
|
||||||
|
params.providerOptions = mergeProviderOptions(params.providerOptions, searchOptions)
|
||||||
|
}
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return params
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
// 导出类型定义供开发者使用
|
||||||
|
export type { WebSearchPluginConfig, WebSearchToolOutputSchema } from './helper'
|
||||||
|
|
||||||
|
// 默认导出
|
||||||
|
export default webSearchPlugin
|
||||||
32
packages/aiCore/src/core/plugins/index.ts
Normal file
32
packages/aiCore/src/core/plugins/index.ts
Normal file
@@ -0,0 +1,32 @@
|
|||||||
|
// 核心类型和接口
|
||||||
|
export type { AiPlugin, AiRequestContext, HookResult, PluginManagerConfig } from './types'
|
||||||
|
import type { ProviderId } from '../providers'
|
||||||
|
import type { AiPlugin, AiRequestContext } from './types'
|
||||||
|
|
||||||
|
// 插件管理器
|
||||||
|
export { PluginManager } from './manager'
|
||||||
|
|
||||||
|
// 工具函数
|
||||||
|
export function createContext<T extends ProviderId>(
|
||||||
|
providerId: T,
|
||||||
|
modelId: string,
|
||||||
|
originalParams: any
|
||||||
|
): AiRequestContext {
|
||||||
|
return {
|
||||||
|
providerId,
|
||||||
|
modelId,
|
||||||
|
originalParams,
|
||||||
|
metadata: {},
|
||||||
|
startTime: Date.now(),
|
||||||
|
requestId: `${providerId}-${modelId}-${Date.now()}-${Math.random().toString(36).slice(2)}`,
|
||||||
|
// 占位
|
||||||
|
recursiveCall: () => Promise.resolve(null)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 插件构建器 - 便于创建插件
|
||||||
|
export function definePlugin(plugin: AiPlugin): AiPlugin
|
||||||
|
export function definePlugin<T extends (...args: any[]) => AiPlugin>(pluginFactory: T): T
|
||||||
|
export function definePlugin(plugin: AiPlugin | ((...args: any[]) => AiPlugin)) {
|
||||||
|
return plugin
|
||||||
|
}
|
||||||
184
packages/aiCore/src/core/plugins/manager.ts
Normal file
184
packages/aiCore/src/core/plugins/manager.ts
Normal file
@@ -0,0 +1,184 @@
|
|||||||
|
import { AiPlugin, AiRequestContext } from './types'
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 插件管理器
|
||||||
|
*/
|
||||||
|
export class PluginManager {
|
||||||
|
private plugins: AiPlugin[] = []
|
||||||
|
|
||||||
|
constructor(plugins: AiPlugin[] = []) {
|
||||||
|
this.plugins = this.sortPlugins(plugins)
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 添加插件
|
||||||
|
*/
|
||||||
|
use(plugin: AiPlugin): this {
|
||||||
|
this.plugins = this.sortPlugins([...this.plugins, plugin])
|
||||||
|
return this
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 移除插件
|
||||||
|
*/
|
||||||
|
remove(pluginName: string): this {
|
||||||
|
this.plugins = this.plugins.filter((p) => p.name !== pluginName)
|
||||||
|
return this
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 插件排序:pre -> normal -> post
|
||||||
|
*/
|
||||||
|
private sortPlugins(plugins: AiPlugin[]): AiPlugin[] {
|
||||||
|
const pre: AiPlugin[] = []
|
||||||
|
const normal: AiPlugin[] = []
|
||||||
|
const post: AiPlugin[] = []
|
||||||
|
|
||||||
|
plugins.forEach((plugin) => {
|
||||||
|
if (plugin.enforce === 'pre') {
|
||||||
|
pre.push(plugin)
|
||||||
|
} else if (plugin.enforce === 'post') {
|
||||||
|
post.push(plugin)
|
||||||
|
} else {
|
||||||
|
normal.push(plugin)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
return [...pre, ...normal, ...post]
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 执行 First 钩子 - 返回第一个有效结果
|
||||||
|
*/
|
||||||
|
async executeFirst<T>(
|
||||||
|
hookName: 'resolveModel' | 'loadTemplate',
|
||||||
|
arg: any,
|
||||||
|
context: AiRequestContext
|
||||||
|
): Promise<T | null> {
|
||||||
|
for (const plugin of this.plugins) {
|
||||||
|
const hook = plugin[hookName]
|
||||||
|
if (hook) {
|
||||||
|
const result = await hook(arg, context)
|
||||||
|
if (result !== null && result !== undefined) {
|
||||||
|
return result as T
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return null
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 执行 Sequential 钩子 - 链式数据转换
|
||||||
|
*/
|
||||||
|
async executeSequential<T>(
|
||||||
|
hookName: 'transformParams' | 'transformResult',
|
||||||
|
initialValue: T,
|
||||||
|
context: AiRequestContext
|
||||||
|
): Promise<T> {
|
||||||
|
let result = initialValue
|
||||||
|
|
||||||
|
for (const plugin of this.plugins) {
|
||||||
|
const hook = plugin[hookName]
|
||||||
|
if (hook) {
|
||||||
|
result = await hook<T>(result, context)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 执行 ConfigureContext 钩子 - 串行配置上下文
|
||||||
|
*/
|
||||||
|
async executeConfigureContext(context: AiRequestContext): Promise<void> {
|
||||||
|
for (const plugin of this.plugins) {
|
||||||
|
const hook = plugin.configureContext
|
||||||
|
if (hook) {
|
||||||
|
await hook(context)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 执行 Parallel 钩子 - 并行副作用
|
||||||
|
*/
|
||||||
|
async executeParallel(
|
||||||
|
hookName: 'onRequestStart' | 'onRequestEnd' | 'onError',
|
||||||
|
context: AiRequestContext,
|
||||||
|
result?: any,
|
||||||
|
error?: Error
|
||||||
|
): Promise<void> {
|
||||||
|
const promises = this.plugins
|
||||||
|
.map((plugin) => {
|
||||||
|
const hook = plugin[hookName]
|
||||||
|
if (!hook) return null
|
||||||
|
|
||||||
|
if (hookName === 'onError' && error) {
|
||||||
|
return (hook as any)(error, context)
|
||||||
|
} else if (hookName === 'onRequestEnd' && result !== undefined) {
|
||||||
|
return (hook as any)(context, result)
|
||||||
|
} else if (hookName === 'onRequestStart') {
|
||||||
|
return (hook as any)(context)
|
||||||
|
}
|
||||||
|
return null
|
||||||
|
})
|
||||||
|
.filter(Boolean)
|
||||||
|
|
||||||
|
// 使用 Promise.all 而不是 allSettled,让插件错误能够抛出
|
||||||
|
await Promise.all(promises)
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 收集所有流转换器(返回数组,AI SDK 原生支持)
|
||||||
|
*/
|
||||||
|
collectStreamTransforms(params: any, context: AiRequestContext) {
|
||||||
|
return this.plugins
|
||||||
|
.filter((plugin) => plugin.transformStream)
|
||||||
|
.map((plugin) => plugin.transformStream?.(params, context))
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 获取所有插件信息
|
||||||
|
*/
|
||||||
|
getPlugins(): AiPlugin[] {
|
||||||
|
return [...this.plugins]
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 获取插件统计信息
|
||||||
|
*/
|
||||||
|
getStats() {
|
||||||
|
const stats = {
|
||||||
|
total: this.plugins.length,
|
||||||
|
pre: 0,
|
||||||
|
normal: 0,
|
||||||
|
post: 0,
|
||||||
|
hooks: {
|
||||||
|
resolveModel: 0,
|
||||||
|
loadTemplate: 0,
|
||||||
|
transformParams: 0,
|
||||||
|
transformResult: 0,
|
||||||
|
onRequestStart: 0,
|
||||||
|
onRequestEnd: 0,
|
||||||
|
onError: 0,
|
||||||
|
transformStream: 0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
this.plugins.forEach((plugin) => {
|
||||||
|
// 统计 enforce 类型
|
||||||
|
if (plugin.enforce === 'pre') stats.pre++
|
||||||
|
else if (plugin.enforce === 'post') stats.post++
|
||||||
|
else stats.normal++
|
||||||
|
|
||||||
|
// 统计钩子数量
|
||||||
|
Object.keys(stats.hooks).forEach((hookName) => {
|
||||||
|
if (plugin[hookName as keyof AiPlugin]) {
|
||||||
|
stats.hooks[hookName as keyof typeof stats.hooks]++
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
return stats
|
||||||
|
}
|
||||||
|
}
|
||||||
79
packages/aiCore/src/core/plugins/types.ts
Normal file
79
packages/aiCore/src/core/plugins/types.ts
Normal file
@@ -0,0 +1,79 @@
|
|||||||
|
import type { ImageModelV2 } from '@ai-sdk/provider'
|
||||||
|
import type { LanguageModel, TextStreamPart, ToolSet } from 'ai'
|
||||||
|
|
||||||
|
import { type ProviderId } from '../providers/types'
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 递归调用函数类型
|
||||||
|
* 使用 any 是因为递归调用时参数和返回类型可能完全不同
|
||||||
|
*/
|
||||||
|
export type RecursiveCallFn = (newParams: any) => Promise<any>
|
||||||
|
|
||||||
|
/**
|
||||||
|
* AI 请求上下文
|
||||||
|
*/
|
||||||
|
export interface AiRequestContext {
|
||||||
|
providerId: ProviderId
|
||||||
|
modelId: string
|
||||||
|
originalParams: any
|
||||||
|
metadata: Record<string, any>
|
||||||
|
startTime: number
|
||||||
|
requestId: string
|
||||||
|
recursiveCall: RecursiveCallFn
|
||||||
|
isRecursiveCall?: boolean
|
||||||
|
mcpTools?: ToolSet
|
||||||
|
[key: string]: any
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 钩子分类
|
||||||
|
*/
|
||||||
|
export interface AiPlugin {
|
||||||
|
name: string
|
||||||
|
enforce?: 'pre' | 'post'
|
||||||
|
|
||||||
|
// 【First】首个钩子 - 只执行第一个返回值的插件
|
||||||
|
resolveModel?: (
|
||||||
|
modelId: string,
|
||||||
|
context: AiRequestContext
|
||||||
|
) => Promise<LanguageModel | ImageModelV2 | null> | LanguageModel | ImageModelV2 | null
|
||||||
|
loadTemplate?: (templateName: string, context: AiRequestContext) => any | null | Promise<any | null>
|
||||||
|
|
||||||
|
// 【Sequential】串行钩子 - 链式执行,支持数据转换
|
||||||
|
configureContext?: (context: AiRequestContext) => void | Promise<void>
|
||||||
|
transformParams?: <T>(params: T, context: AiRequestContext) => T | Promise<T>
|
||||||
|
transformResult?: <T>(result: T, context: AiRequestContext) => T | Promise<T>
|
||||||
|
|
||||||
|
// 【Parallel】并行钩子 - 不依赖顺序,用于副作用
|
||||||
|
onRequestStart?: (context: AiRequestContext) => void | Promise<void>
|
||||||
|
onRequestEnd?: (context: AiRequestContext, result: any) => void | Promise<void>
|
||||||
|
onError?: (error: Error, context: AiRequestContext) => void | Promise<void>
|
||||||
|
|
||||||
|
// 【Stream】流处理 - 直接使用 AI SDK
|
||||||
|
transformStream?: (
|
||||||
|
params: any,
|
||||||
|
context: AiRequestContext
|
||||||
|
) => <TOOLS extends ToolSet>(options?: {
|
||||||
|
tools: TOOLS
|
||||||
|
stopStream: () => void
|
||||||
|
}) => TransformStream<TextStreamPart<TOOLS>, TextStreamPart<TOOLS>>
|
||||||
|
|
||||||
|
// AI SDK 原生中间件
|
||||||
|
// aiSdkMiddlewares?: LanguageModelV1Middleware[]
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 插件管理器配置
|
||||||
|
*/
|
||||||
|
export interface PluginManagerConfig {
|
||||||
|
plugins: AiPlugin[]
|
||||||
|
context: Partial<AiRequestContext>
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 钩子执行结果
|
||||||
|
*/
|
||||||
|
export interface HookResult<T = any> {
|
||||||
|
value: T
|
||||||
|
stop?: boolean
|
||||||
|
}
|
||||||
101
packages/aiCore/src/core/providers/HubProvider.ts
Normal file
101
packages/aiCore/src/core/providers/HubProvider.ts
Normal file
@@ -0,0 +1,101 @@
|
|||||||
|
/**
|
||||||
|
* Hub Provider - 支持路由到多个底层provider
|
||||||
|
*
|
||||||
|
* 支持格式: hubId:providerId:modelId
|
||||||
|
* 例如: aihubmix:anthropic:claude-3.5-sonnet
|
||||||
|
*/
|
||||||
|
|
||||||
|
import { ProviderV2 } from '@ai-sdk/provider'
|
||||||
|
import { customProvider } from 'ai'
|
||||||
|
|
||||||
|
import { globalRegistryManagement } from './RegistryManagement'
|
||||||
|
import type { AiSdkMethodName, AiSdkModelReturn, AiSdkModelType } from './types'
|
||||||
|
|
||||||
|
export interface HubProviderConfig {
|
||||||
|
/** Hub的唯一标识符 */
|
||||||
|
hubId: string
|
||||||
|
/** 是否启用调试日志 */
|
||||||
|
debug?: boolean
|
||||||
|
}
|
||||||
|
|
||||||
|
export class HubProviderError extends Error {
|
||||||
|
constructor(
|
||||||
|
message: string,
|
||||||
|
public readonly hubId: string,
|
||||||
|
public readonly providerId?: string,
|
||||||
|
public readonly originalError?: Error
|
||||||
|
) {
|
||||||
|
super(message)
|
||||||
|
this.name = 'HubProviderError'
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 解析Hub模型ID
|
||||||
|
*/
|
||||||
|
function parseHubModelId(modelId: string): { provider: string; actualModelId: string } {
|
||||||
|
const parts = modelId.split(':')
|
||||||
|
if (parts.length !== 2) {
|
||||||
|
throw new HubProviderError(`Invalid hub model ID format. Expected "provider:modelId", got: ${modelId}`, 'unknown')
|
||||||
|
}
|
||||||
|
return {
|
||||||
|
provider: parts[0],
|
||||||
|
actualModelId: parts[1]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 创建Hub Provider
|
||||||
|
*/
|
||||||
|
export function createHubProvider(config: HubProviderConfig): ProviderV2 {
|
||||||
|
const { hubId } = config
|
||||||
|
|
||||||
|
function getTargetProvider(providerId: string): ProviderV2 {
|
||||||
|
// 从全局注册表获取provider实例
|
||||||
|
try {
|
||||||
|
const provider = globalRegistryManagement.getProvider(providerId)
|
||||||
|
if (!provider) {
|
||||||
|
throw new HubProviderError(
|
||||||
|
`Provider "${providerId}" is not initialized. Please call initializeProvider("${providerId}", options) first.`,
|
||||||
|
hubId,
|
||||||
|
providerId
|
||||||
|
)
|
||||||
|
}
|
||||||
|
return provider
|
||||||
|
} catch (error) {
|
||||||
|
throw new HubProviderError(
|
||||||
|
`Failed to get provider "${providerId}": ${error instanceof Error ? error.message : 'Unknown error'}`,
|
||||||
|
hubId,
|
||||||
|
providerId,
|
||||||
|
error instanceof Error ? error : undefined
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
function resolveModel<T extends AiSdkModelType>(
|
||||||
|
modelId: string,
|
||||||
|
modelType: T,
|
||||||
|
methodName: AiSdkMethodName<T>
|
||||||
|
): AiSdkModelReturn<T> {
|
||||||
|
const { provider, actualModelId } = parseHubModelId(modelId)
|
||||||
|
const targetProvider = getTargetProvider(provider)
|
||||||
|
|
||||||
|
const fn = targetProvider[methodName] as (id: string) => AiSdkModelReturn<T>
|
||||||
|
|
||||||
|
if (!fn) {
|
||||||
|
throw new HubProviderError(`Provider "${provider}" does not support ${modelType}`, hubId, provider)
|
||||||
|
}
|
||||||
|
|
||||||
|
return fn(actualModelId)
|
||||||
|
}
|
||||||
|
|
||||||
|
return customProvider({
|
||||||
|
fallbackProvider: {
|
||||||
|
languageModel: (modelId: string) => resolveModel(modelId, 'text', 'languageModel'),
|
||||||
|
textEmbeddingModel: (modelId: string) => resolveModel(modelId, 'embedding', 'textEmbeddingModel'),
|
||||||
|
imageModel: (modelId: string) => resolveModel(modelId, 'image', 'imageModel'),
|
||||||
|
transcriptionModel: (modelId: string) => resolveModel(modelId, 'transcription', 'transcriptionModel'),
|
||||||
|
speechModel: (modelId: string) => resolveModel(modelId, 'speech', 'speechModel')
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
221
packages/aiCore/src/core/providers/RegistryManagement.ts
Normal file
221
packages/aiCore/src/core/providers/RegistryManagement.ts
Normal file
@@ -0,0 +1,221 @@
|
|||||||
|
/**
|
||||||
|
* Provider 注册表管理器
|
||||||
|
* 纯粹的管理功能:存储、检索已配置好的 provider 实例
|
||||||
|
* 基于 AI SDK 原生的 createProviderRegistry
|
||||||
|
*/
|
||||||
|
|
||||||
|
import { EmbeddingModelV2, ImageModelV2, LanguageModelV2, ProviderV2 } from '@ai-sdk/provider'
|
||||||
|
import { createProviderRegistry, type ProviderRegistryProvider } from 'ai'
|
||||||
|
|
||||||
|
type PROVIDERS = Record<string, ProviderV2>
|
||||||
|
|
||||||
|
export const DEFAULT_SEPARATOR = '|'
|
||||||
|
|
||||||
|
// export type MODEL_ID = `${string}${typeof DEFAULT_SEPARATOR}${string}`
|
||||||
|
|
||||||
|
export class RegistryManagement<SEPARATOR extends string = typeof DEFAULT_SEPARATOR> {
|
||||||
|
private providers: PROVIDERS = {}
|
||||||
|
private aliases: Set<string> = new Set() // 记录哪些key是别名
|
||||||
|
private separator: SEPARATOR
|
||||||
|
private registry: ProviderRegistryProvider<PROVIDERS, SEPARATOR> | null = null
|
||||||
|
|
||||||
|
constructor(options: { separator: SEPARATOR } = { separator: DEFAULT_SEPARATOR as SEPARATOR }) {
|
||||||
|
this.separator = options.separator
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 注册已配置好的 provider 实例
|
||||||
|
*/
|
||||||
|
registerProvider(id: string, provider: ProviderV2, aliases?: string[]): this {
|
||||||
|
// 注册主provider
|
||||||
|
this.providers[id] = provider
|
||||||
|
|
||||||
|
// 注册别名(都指向同一个provider实例)
|
||||||
|
if (aliases) {
|
||||||
|
aliases.forEach((alias) => {
|
||||||
|
this.providers[alias] = provider // 直接存储引用
|
||||||
|
this.aliases.add(alias) // 标记为别名
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
this.rebuildRegistry()
|
||||||
|
return this
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 获取已注册的provider实例
|
||||||
|
*/
|
||||||
|
getProvider(id: string): ProviderV2 | undefined {
|
||||||
|
return this.providers[id]
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 批量注册 providers
|
||||||
|
*/
|
||||||
|
registerProviders(providers: Record<string, ProviderV2>): this {
|
||||||
|
Object.assign(this.providers, providers)
|
||||||
|
this.rebuildRegistry()
|
||||||
|
return this
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 移除 provider(同时清理相关别名)
|
||||||
|
*/
|
||||||
|
unregisterProvider(id: string): this {
|
||||||
|
const provider = this.providers[id]
|
||||||
|
if (!provider) return this
|
||||||
|
|
||||||
|
// 如果移除的是真实ID,需要清理所有指向它的别名
|
||||||
|
if (!this.aliases.has(id)) {
|
||||||
|
// 找到所有指向此provider的别名并删除
|
||||||
|
const aliasesToRemove: string[] = []
|
||||||
|
this.aliases.forEach((alias) => {
|
||||||
|
if (this.providers[alias] === provider) {
|
||||||
|
aliasesToRemove.push(alias)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
aliasesToRemove.forEach((alias) => {
|
||||||
|
delete this.providers[alias]
|
||||||
|
this.aliases.delete(alias)
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
// 如果移除的是别名,只删除别名记录
|
||||||
|
this.aliases.delete(id)
|
||||||
|
}
|
||||||
|
|
||||||
|
delete this.providers[id]
|
||||||
|
this.rebuildRegistry()
|
||||||
|
return this
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 立即重建 registry - 每次变更都重建
|
||||||
|
*/
|
||||||
|
private rebuildRegistry(): void {
|
||||||
|
if (Object.keys(this.providers).length === 0) {
|
||||||
|
this.registry = null
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
this.registry = createProviderRegistry<PROVIDERS, SEPARATOR>(this.providers, {
|
||||||
|
separator: this.separator
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 获取语言模型 - AI SDK 原生方法
|
||||||
|
*/
|
||||||
|
languageModel(id: `${string}${SEPARATOR}${string}`): LanguageModelV2 {
|
||||||
|
if (!this.registry) {
|
||||||
|
throw new Error('No providers registered')
|
||||||
|
}
|
||||||
|
return this.registry.languageModel(id)
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 获取文本嵌入模型 - AI SDK 原生方法
|
||||||
|
*/
|
||||||
|
textEmbeddingModel(id: `${string}${SEPARATOR}${string}`): EmbeddingModelV2<string> {
|
||||||
|
if (!this.registry) {
|
||||||
|
throw new Error('No providers registered')
|
||||||
|
}
|
||||||
|
return this.registry.textEmbeddingModel(id)
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 获取图像模型 - AI SDK 原生方法
|
||||||
|
*/
|
||||||
|
imageModel(id: `${string}${SEPARATOR}${string}`): ImageModelV2 {
|
||||||
|
if (!this.registry) {
|
||||||
|
throw new Error('No providers registered')
|
||||||
|
}
|
||||||
|
return this.registry.imageModel(id)
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 获取转录模型 - AI SDK 原生方法
|
||||||
|
*/
|
||||||
|
transcriptionModel(id: `${string}${SEPARATOR}${string}`): any {
|
||||||
|
if (!this.registry) {
|
||||||
|
throw new Error('No providers registered')
|
||||||
|
}
|
||||||
|
return this.registry.transcriptionModel(id)
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 获取语音模型 - AI SDK 原生方法
|
||||||
|
*/
|
||||||
|
speechModel(id: `${string}${SEPARATOR}${string}`): any {
|
||||||
|
if (!this.registry) {
|
||||||
|
throw new Error('No providers registered')
|
||||||
|
}
|
||||||
|
return this.registry.speechModel(id)
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 获取已注册的 provider 列表
|
||||||
|
*/
|
||||||
|
getRegisteredProviders(): string[] {
|
||||||
|
return Object.keys(this.providers)
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 检查是否有已注册的 providers
|
||||||
|
*/
|
||||||
|
hasProviders(): boolean {
|
||||||
|
return Object.keys(this.providers).length > 0
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 清除所有 providers
|
||||||
|
*/
|
||||||
|
clear(): this {
|
||||||
|
this.providers = {}
|
||||||
|
this.aliases.clear()
|
||||||
|
this.registry = null
|
||||||
|
return this
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 解析真实的Provider ID(供getAiSdkProviderId使用)
|
||||||
|
* 如果传入的是别名,返回真实的Provider ID
|
||||||
|
* 如果传入的是真实ID,直接返回
|
||||||
|
*/
|
||||||
|
resolveProviderId(id: string): string {
|
||||||
|
if (!this.aliases.has(id)) return id // 不是别名,直接返回
|
||||||
|
|
||||||
|
// 是别名,找到真实ID
|
||||||
|
const targetProvider = this.providers[id]
|
||||||
|
for (const [realId, provider] of Object.entries(this.providers)) {
|
||||||
|
if (provider === targetProvider && !this.aliases.has(realId)) {
|
||||||
|
return realId
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return id
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 检查是否为别名
|
||||||
|
*/
|
||||||
|
isAlias(id: string): boolean {
|
||||||
|
return this.aliases.has(id)
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 获取所有别名映射关系
|
||||||
|
*/
|
||||||
|
getAllAliases(): Record<string, string> {
|
||||||
|
const result: Record<string, string> = {}
|
||||||
|
this.aliases.forEach((alias) => {
|
||||||
|
result[alias] = this.resolveProviderId(alias)
|
||||||
|
})
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 全局注册表管理器实例
|
||||||
|
* 使用 | 作为分隔符,因为 : 会和 :free 等suffix冲突
|
||||||
|
*/
|
||||||
|
export const globalRegistryManagement = new RegistryManagement()
|
||||||
@@ -0,0 +1,632 @@
|
|||||||
|
/**
|
||||||
|
* 测试真正的 AiProviderRegistry 功能
|
||||||
|
*/
|
||||||
|
|
||||||
|
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||||
|
|
||||||
|
// 模拟 AI SDK
|
||||||
|
vi.mock('@ai-sdk/openai', () => ({
|
||||||
|
createOpenAI: vi.fn(() => ({ name: 'openai-mock' }))
|
||||||
|
}))
|
||||||
|
|
||||||
|
vi.mock('@ai-sdk/anthropic', () => ({
|
||||||
|
createAnthropic: vi.fn(() => ({ name: 'anthropic-mock' }))
|
||||||
|
}))
|
||||||
|
|
||||||
|
vi.mock('@ai-sdk/azure', () => ({
|
||||||
|
createAzure: vi.fn(() => ({ name: 'azure-mock' }))
|
||||||
|
}))
|
||||||
|
|
||||||
|
vi.mock('@ai-sdk/deepseek', () => ({
|
||||||
|
createDeepSeek: vi.fn(() => ({ name: 'deepseek-mock' }))
|
||||||
|
}))
|
||||||
|
|
||||||
|
vi.mock('@ai-sdk/google', () => ({
|
||||||
|
createGoogleGenerativeAI: vi.fn(() => ({ name: 'google-mock' }))
|
||||||
|
}))
|
||||||
|
|
||||||
|
vi.mock('@ai-sdk/openai-compatible', () => ({
|
||||||
|
createOpenAICompatible: vi.fn(() => ({ name: 'openai-compatible-mock' }))
|
||||||
|
}))
|
||||||
|
|
||||||
|
vi.mock('@ai-sdk/xai', () => ({
|
||||||
|
createXai: vi.fn(() => ({ name: 'xai-mock' }))
|
||||||
|
}))
|
||||||
|
|
||||||
|
import {
|
||||||
|
cleanup,
|
||||||
|
clearAllProviders,
|
||||||
|
createAndRegisterProvider,
|
||||||
|
createProvider,
|
||||||
|
getAllProviderConfigAliases,
|
||||||
|
getAllProviderConfigs,
|
||||||
|
getInitializedProviders,
|
||||||
|
getLanguageModel,
|
||||||
|
getProviderConfig,
|
||||||
|
getProviderConfigByAlias,
|
||||||
|
getSupportedProviders,
|
||||||
|
hasInitializedProviders,
|
||||||
|
hasProviderConfig,
|
||||||
|
hasProviderConfigByAlias,
|
||||||
|
isProviderConfigAlias,
|
||||||
|
ProviderInitializationError,
|
||||||
|
providerRegistry,
|
||||||
|
registerMultipleProviderConfigs,
|
||||||
|
registerProvider,
|
||||||
|
registerProviderConfig,
|
||||||
|
resolveProviderConfigId
|
||||||
|
} from '../registry'
|
||||||
|
import type { ProviderConfig } from '../schemas'
|
||||||
|
|
||||||
|
describe('Provider Registry 功能测试', () => {
|
||||||
|
beforeEach(() => {
|
||||||
|
// 清理状态
|
||||||
|
cleanup()
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('基础功能', () => {
|
||||||
|
it('能够获取支持的 providers 列表', () => {
|
||||||
|
const providers = getSupportedProviders()
|
||||||
|
expect(Array.isArray(providers)).toBe(true)
|
||||||
|
expect(providers.length).toBeGreaterThan(0)
|
||||||
|
|
||||||
|
// 检查返回的数据结构
|
||||||
|
providers.forEach((provider) => {
|
||||||
|
expect(provider).toHaveProperty('id')
|
||||||
|
expect(provider).toHaveProperty('name')
|
||||||
|
expect(typeof provider.id).toBe('string')
|
||||||
|
expect(typeof provider.name).toBe('string')
|
||||||
|
})
|
||||||
|
|
||||||
|
// 包含基础 providers
|
||||||
|
const providerIds = providers.map((p) => p.id)
|
||||||
|
expect(providerIds).toContain('openai')
|
||||||
|
expect(providerIds).toContain('anthropic')
|
||||||
|
expect(providerIds).toContain('google')
|
||||||
|
})
|
||||||
|
|
||||||
|
it('能够获取已初始化的 providers', () => {
|
||||||
|
// 初始状态下没有已初始化的 providers
|
||||||
|
expect(getInitializedProviders()).toEqual([])
|
||||||
|
expect(hasInitializedProviders()).toBe(false)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('能够访问全局注册管理器', () => {
|
||||||
|
expect(providerRegistry).toBeDefined()
|
||||||
|
expect(typeof providerRegistry.clear).toBe('function')
|
||||||
|
expect(typeof providerRegistry.getRegisteredProviders).toBe('function')
|
||||||
|
expect(typeof providerRegistry.hasProviders).toBe('function')
|
||||||
|
})
|
||||||
|
|
||||||
|
it('能够获取语言模型', () => {
|
||||||
|
// 在没有注册 provider 的情况下,这个函数应该会抛出错误
|
||||||
|
expect(() => getLanguageModel('non-existent')).toThrow('No providers registered')
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('Provider 配置注册', () => {
|
||||||
|
it('能够注册自定义 provider 配置', () => {
|
||||||
|
const config: ProviderConfig = {
|
||||||
|
id: 'custom-provider',
|
||||||
|
name: 'Custom Provider',
|
||||||
|
creator: vi.fn(() => ({ name: 'custom' })),
|
||||||
|
supportsImageGeneration: false
|
||||||
|
}
|
||||||
|
|
||||||
|
const success = registerProviderConfig(config)
|
||||||
|
expect(success).toBe(true)
|
||||||
|
|
||||||
|
expect(hasProviderConfig('custom-provider')).toBe(true)
|
||||||
|
expect(getProviderConfig('custom-provider')).toEqual(config)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('能够注册带别名的 provider 配置', () => {
|
||||||
|
const config: ProviderConfig = {
|
||||||
|
id: 'custom-provider-with-aliases',
|
||||||
|
name: 'Custom Provider with Aliases',
|
||||||
|
creator: vi.fn(() => ({ name: 'custom-aliased' })),
|
||||||
|
supportsImageGeneration: false,
|
||||||
|
aliases: ['alias-1', 'alias-2']
|
||||||
|
}
|
||||||
|
|
||||||
|
const success = registerProviderConfig(config)
|
||||||
|
expect(success).toBe(true)
|
||||||
|
|
||||||
|
expect(hasProviderConfigByAlias('alias-1')).toBe(true)
|
||||||
|
expect(hasProviderConfigByAlias('alias-2')).toBe(true)
|
||||||
|
expect(getProviderConfigByAlias('alias-1')).toEqual(config)
|
||||||
|
expect(resolveProviderConfigId('alias-1')).toBe('custom-provider-with-aliases')
|
||||||
|
})
|
||||||
|
|
||||||
|
it('拒绝无效的配置', () => {
|
||||||
|
// 缺少必要字段
|
||||||
|
const invalidConfig = {
|
||||||
|
id: 'invalid-provider'
|
||||||
|
// 缺少 name, creator 等
|
||||||
|
}
|
||||||
|
|
||||||
|
const success = registerProviderConfig(invalidConfig as any)
|
||||||
|
expect(success).toBe(false)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('能够批量注册 provider 配置', () => {
|
||||||
|
const configs: ProviderConfig[] = [
|
||||||
|
{
|
||||||
|
id: 'provider-1',
|
||||||
|
name: 'Provider 1',
|
||||||
|
creator: vi.fn(() => ({ name: 'provider-1' })),
|
||||||
|
supportsImageGeneration: false
|
||||||
|
},
|
||||||
|
{
|
||||||
|
id: 'provider-2',
|
||||||
|
name: 'Provider 2',
|
||||||
|
creator: vi.fn(() => ({ name: 'provider-2' })),
|
||||||
|
supportsImageGeneration: true
|
||||||
|
},
|
||||||
|
{
|
||||||
|
id: '', // 无效配置
|
||||||
|
name: 'Invalid Provider',
|
||||||
|
creator: vi.fn(() => ({ name: 'invalid' })),
|
||||||
|
supportsImageGeneration: false
|
||||||
|
} as any
|
||||||
|
]
|
||||||
|
|
||||||
|
const successCount = registerMultipleProviderConfigs(configs)
|
||||||
|
expect(successCount).toBe(2) // 只有前两个成功
|
||||||
|
|
||||||
|
expect(hasProviderConfig('provider-1')).toBe(true)
|
||||||
|
expect(hasProviderConfig('provider-2')).toBe(true)
|
||||||
|
expect(hasProviderConfig('')).toBe(false)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('能够获取所有配置和别名信息', () => {
|
||||||
|
// 注册一些配置
|
||||||
|
registerProviderConfig({
|
||||||
|
id: 'test-provider',
|
||||||
|
name: 'Test Provider',
|
||||||
|
creator: vi.fn(),
|
||||||
|
supportsImageGeneration: false,
|
||||||
|
aliases: ['test-alias']
|
||||||
|
})
|
||||||
|
|
||||||
|
const allConfigs = getAllProviderConfigs()
|
||||||
|
expect(Array.isArray(allConfigs)).toBe(true)
|
||||||
|
expect(allConfigs.some((config) => config.id === 'test-provider')).toBe(true)
|
||||||
|
|
||||||
|
const aliases = getAllProviderConfigAliases()
|
||||||
|
expect(aliases['test-alias']).toBe('test-provider')
|
||||||
|
expect(isProviderConfigAlias('test-alias')).toBe(true)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('Provider 创建和注册', () => {
|
||||||
|
it('能够创建 provider 实例', async () => {
|
||||||
|
const config: ProviderConfig = {
|
||||||
|
id: 'test-create-provider',
|
||||||
|
name: 'Test Create Provider',
|
||||||
|
creator: vi.fn(() => ({ name: 'test-created' })),
|
||||||
|
supportsImageGeneration: false
|
||||||
|
}
|
||||||
|
|
||||||
|
// 先注册配置
|
||||||
|
registerProviderConfig(config)
|
||||||
|
|
||||||
|
// 创建 provider 实例
|
||||||
|
const provider = await createProvider('test-create-provider', { apiKey: 'test' })
|
||||||
|
expect(provider).toBeDefined()
|
||||||
|
expect(config.creator).toHaveBeenCalledWith({ apiKey: 'test' })
|
||||||
|
})
|
||||||
|
|
||||||
|
it('能够注册 provider 到全局管理器', () => {
|
||||||
|
const mockProvider = { name: 'mock-provider' }
|
||||||
|
const config: ProviderConfig = {
|
||||||
|
id: 'test-register-provider',
|
||||||
|
name: 'Test Register Provider',
|
||||||
|
creator: vi.fn(() => mockProvider),
|
||||||
|
supportsImageGeneration: false
|
||||||
|
}
|
||||||
|
|
||||||
|
// 先注册配置
|
||||||
|
registerProviderConfig(config)
|
||||||
|
|
||||||
|
// 注册 provider 到全局管理器
|
||||||
|
const success = registerProvider('test-register-provider', mockProvider)
|
||||||
|
expect(success).toBe(true)
|
||||||
|
|
||||||
|
// 验证注册成功
|
||||||
|
const registeredProviders = getInitializedProviders()
|
||||||
|
expect(registeredProviders).toContain('test-register-provider')
|
||||||
|
expect(hasInitializedProviders()).toBe(true)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('能够一步完成创建和注册', async () => {
|
||||||
|
const config: ProviderConfig = {
|
||||||
|
id: 'test-create-and-register',
|
||||||
|
name: 'Test Create and Register',
|
||||||
|
creator: vi.fn(() => ({ name: 'test-both' })),
|
||||||
|
supportsImageGeneration: false
|
||||||
|
}
|
||||||
|
|
||||||
|
// 先注册配置
|
||||||
|
registerProviderConfig(config)
|
||||||
|
|
||||||
|
// 一步完成创建和注册
|
||||||
|
const success = await createAndRegisterProvider('test-create-and-register', { apiKey: 'test' })
|
||||||
|
expect(success).toBe(true)
|
||||||
|
|
||||||
|
// 验证注册成功
|
||||||
|
const registeredProviders = getInitializedProviders()
|
||||||
|
expect(registeredProviders).toContain('test-create-and-register')
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('Registry 管理', () => {
|
||||||
|
it('能够清理所有配置和注册的 providers', () => {
|
||||||
|
// 注册一些配置
|
||||||
|
registerProviderConfig({
|
||||||
|
id: 'temp-provider',
|
||||||
|
name: 'Temp Provider',
|
||||||
|
creator: vi.fn(() => ({ name: 'temp' })),
|
||||||
|
supportsImageGeneration: false
|
||||||
|
})
|
||||||
|
|
||||||
|
expect(hasProviderConfig('temp-provider')).toBe(true)
|
||||||
|
|
||||||
|
// 清理
|
||||||
|
cleanup()
|
||||||
|
|
||||||
|
expect(hasProviderConfig('temp-provider')).toBe(false)
|
||||||
|
// 但基础配置应该重新加载
|
||||||
|
expect(hasProviderConfig('openai')).toBe(true) // 基础 providers 会重新初始化
|
||||||
|
})
|
||||||
|
|
||||||
|
it('能够单独清理已注册的 providers', () => {
|
||||||
|
// 清理所有 providers
|
||||||
|
clearAllProviders()
|
||||||
|
|
||||||
|
expect(getInitializedProviders()).toEqual([])
|
||||||
|
expect(hasInitializedProviders()).toBe(false)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('ProviderInitializationError 错误类工作正常', () => {
|
||||||
|
const error = new ProviderInitializationError('Test error', 'test-provider')
|
||||||
|
expect(error.message).toBe('Test error')
|
||||||
|
expect(error.providerId).toBe('test-provider')
|
||||||
|
expect(error.name).toBe('ProviderInitializationError')
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('错误处理', () => {
|
||||||
|
it('优雅处理空配置', () => {
|
||||||
|
const success = registerProviderConfig(null as any)
|
||||||
|
expect(success).toBe(false)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('优雅处理未定义配置', () => {
|
||||||
|
const success = registerProviderConfig(undefined as any)
|
||||||
|
expect(success).toBe(false)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('处理空字符串 ID', () => {
|
||||||
|
const config = {
|
||||||
|
id: '',
|
||||||
|
name: 'Empty ID Provider',
|
||||||
|
creator: vi.fn(() => ({ name: 'empty' })),
|
||||||
|
supportsImageGeneration: false
|
||||||
|
}
|
||||||
|
|
||||||
|
const success = registerProviderConfig(config)
|
||||||
|
expect(success).toBe(false)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('处理创建不存在配置的 provider', async () => {
|
||||||
|
await expect(createProvider('non-existent-provider', {})).rejects.toThrow(
|
||||||
|
'ProviderConfig not found for id: non-existent-provider'
|
||||||
|
)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('处理注册不存在配置的 provider', () => {
|
||||||
|
const mockProvider = { name: 'mock' }
|
||||||
|
const success = registerProvider('non-existent-provider', mockProvider)
|
||||||
|
expect(success).toBe(false)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('处理获取不存在配置的情况', () => {
|
||||||
|
expect(getProviderConfig('non-existent')).toBeUndefined()
|
||||||
|
expect(getProviderConfigByAlias('non-existent-alias')).toBeUndefined()
|
||||||
|
expect(hasProviderConfig('non-existent')).toBe(false)
|
||||||
|
expect(hasProviderConfigByAlias('non-existent-alias')).toBe(false)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('处理批量注册时的部分失败', () => {
|
||||||
|
const mixedConfigs: ProviderConfig[] = [
|
||||||
|
{
|
||||||
|
id: 'valid-provider-1',
|
||||||
|
name: 'Valid Provider 1',
|
||||||
|
creator: vi.fn(() => ({ name: 'valid-1' })),
|
||||||
|
supportsImageGeneration: false
|
||||||
|
},
|
||||||
|
{
|
||||||
|
id: '', // 无效配置
|
||||||
|
name: 'Invalid Provider',
|
||||||
|
creator: vi.fn(() => ({ name: 'invalid' })),
|
||||||
|
supportsImageGeneration: false
|
||||||
|
} as any,
|
||||||
|
{
|
||||||
|
id: 'valid-provider-2',
|
||||||
|
name: 'Valid Provider 2',
|
||||||
|
creator: vi.fn(() => ({ name: 'valid-2' })),
|
||||||
|
supportsImageGeneration: true
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
const successCount = registerMultipleProviderConfigs(mixedConfigs)
|
||||||
|
expect(successCount).toBe(2) // 只有两个有效配置成功
|
||||||
|
|
||||||
|
expect(hasProviderConfig('valid-provider-1')).toBe(true)
|
||||||
|
expect(hasProviderConfig('valid-provider-2')).toBe(true)
|
||||||
|
expect(hasProviderConfig('')).toBe(false)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('处理动态导入失败的情况', async () => {
|
||||||
|
const config: ProviderConfig = {
|
||||||
|
id: 'import-test-provider',
|
||||||
|
name: 'Import Test Provider',
|
||||||
|
import: vi.fn().mockRejectedValue(new Error('Import failed')),
|
||||||
|
creatorFunctionName: 'createTest',
|
||||||
|
supportsImageGeneration: false
|
||||||
|
}
|
||||||
|
|
||||||
|
registerProviderConfig(config)
|
||||||
|
|
||||||
|
await expect(createProvider('import-test-provider', {})).rejects.toThrow('Import failed')
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('集成测试', () => {
|
||||||
|
it('正确处理复杂的配置、创建、注册和清理场景', async () => {
|
||||||
|
// 初始状态验证
|
||||||
|
const initialConfigs = getAllProviderConfigs()
|
||||||
|
expect(initialConfigs.length).toBeGreaterThan(0) // 有基础配置
|
||||||
|
expect(getInitializedProviders()).toEqual([])
|
||||||
|
|
||||||
|
// 注册多个带别名的 provider 配置
|
||||||
|
const configs: ProviderConfig[] = [
|
||||||
|
{
|
||||||
|
id: 'integration-provider-1',
|
||||||
|
name: 'Integration Provider 1',
|
||||||
|
creator: vi.fn(() => ({ name: 'integration-1' })),
|
||||||
|
supportsImageGeneration: false,
|
||||||
|
aliases: ['alias-1', 'short-name-1']
|
||||||
|
},
|
||||||
|
{
|
||||||
|
id: 'integration-provider-2',
|
||||||
|
name: 'Integration Provider 2',
|
||||||
|
creator: vi.fn(() => ({ name: 'integration-2' })),
|
||||||
|
supportsImageGeneration: true,
|
||||||
|
aliases: ['alias-2', 'short-name-2']
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
const successCount = registerMultipleProviderConfigs(configs)
|
||||||
|
expect(successCount).toBe(2)
|
||||||
|
|
||||||
|
// 验证配置注册成功
|
||||||
|
expect(hasProviderConfig('integration-provider-1')).toBe(true)
|
||||||
|
expect(hasProviderConfig('integration-provider-2')).toBe(true)
|
||||||
|
expect(hasProviderConfigByAlias('alias-1')).toBe(true)
|
||||||
|
expect(hasProviderConfigByAlias('alias-2')).toBe(true)
|
||||||
|
|
||||||
|
// 验证别名映射
|
||||||
|
const aliases = getAllProviderConfigAliases()
|
||||||
|
expect(aliases['alias-1']).toBe('integration-provider-1')
|
||||||
|
expect(aliases['alias-2']).toBe('integration-provider-2')
|
||||||
|
|
||||||
|
// 创建和注册 providers
|
||||||
|
const success1 = await createAndRegisterProvider('integration-provider-1', { apiKey: 'test1' })
|
||||||
|
const success2 = await createAndRegisterProvider('integration-provider-2', { apiKey: 'test2' })
|
||||||
|
expect(success1).toBe(true)
|
||||||
|
expect(success2).toBe(true)
|
||||||
|
|
||||||
|
// 验证注册成功
|
||||||
|
const registeredProviders = getInitializedProviders()
|
||||||
|
expect(registeredProviders).toContain('integration-provider-1')
|
||||||
|
expect(registeredProviders).toContain('integration-provider-2')
|
||||||
|
expect(hasInitializedProviders()).toBe(true)
|
||||||
|
|
||||||
|
// 清理
|
||||||
|
cleanup()
|
||||||
|
|
||||||
|
// 验证清理后的状态
|
||||||
|
expect(getInitializedProviders()).toEqual([])
|
||||||
|
expect(hasProviderConfig('integration-provider-1')).toBe(false)
|
||||||
|
expect(hasProviderConfig('integration-provider-2')).toBe(false)
|
||||||
|
expect(getAllProviderConfigAliases()).toEqual({})
|
||||||
|
|
||||||
|
// 基础配置应该重新加载
|
||||||
|
expect(hasProviderConfig('openai')).toBe(true)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('正确处理动态导入配置的 provider', async () => {
|
||||||
|
const mockModule = { createCustomProvider: vi.fn(() => ({ name: 'custom-dynamic' })) }
|
||||||
|
const dynamicImportConfig: ProviderConfig = {
|
||||||
|
id: 'dynamic-import-provider',
|
||||||
|
name: 'Dynamic Import Provider',
|
||||||
|
import: vi.fn().mockResolvedValue(mockModule),
|
||||||
|
creatorFunctionName: 'createCustomProvider',
|
||||||
|
supportsImageGeneration: false
|
||||||
|
}
|
||||||
|
|
||||||
|
// 注册配置
|
||||||
|
const configSuccess = registerProviderConfig(dynamicImportConfig)
|
||||||
|
expect(configSuccess).toBe(true)
|
||||||
|
|
||||||
|
// 创建和注册 provider
|
||||||
|
const registerSuccess = await createAndRegisterProvider('dynamic-import-provider', { apiKey: 'test' })
|
||||||
|
expect(registerSuccess).toBe(true)
|
||||||
|
|
||||||
|
// 验证导入函数被调用
|
||||||
|
expect(dynamicImportConfig.import).toHaveBeenCalled()
|
||||||
|
expect(mockModule.createCustomProvider).toHaveBeenCalledWith({ apiKey: 'test' })
|
||||||
|
|
||||||
|
// 验证注册成功
|
||||||
|
expect(getInitializedProviders()).toContain('dynamic-import-provider')
|
||||||
|
})
|
||||||
|
|
||||||
|
it('正确处理大量配置的注册和管理', () => {
|
||||||
|
const largeConfigList: ProviderConfig[] = []
|
||||||
|
|
||||||
|
// 生成50个配置
|
||||||
|
for (let i = 0; i < 50; i++) {
|
||||||
|
largeConfigList.push({
|
||||||
|
id: `bulk-provider-${i}`,
|
||||||
|
name: `Bulk Provider ${i}`,
|
||||||
|
creator: vi.fn(() => ({ name: `bulk-${i}` })),
|
||||||
|
supportsImageGeneration: i % 2 === 0, // 偶数支持图像生成
|
||||||
|
aliases: [`alias-${i}`, `short-${i}`]
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
const successCount = registerMultipleProviderConfigs(largeConfigList)
|
||||||
|
expect(successCount).toBe(50)
|
||||||
|
|
||||||
|
// 验证所有配置都被正确注册
|
||||||
|
const allConfigs = getAllProviderConfigs()
|
||||||
|
expect(allConfigs.filter((config) => config.id.startsWith('bulk-provider-')).length).toBe(50)
|
||||||
|
|
||||||
|
// 验证别名数量
|
||||||
|
const aliases = getAllProviderConfigAliases()
|
||||||
|
const bulkAliases = Object.keys(aliases).filter(
|
||||||
|
(alias) => alias.startsWith('alias-') || alias.startsWith('short-')
|
||||||
|
)
|
||||||
|
expect(bulkAliases.length).toBe(100) // 每个 provider 有2个别名
|
||||||
|
|
||||||
|
// 随机验证几个配置
|
||||||
|
expect(hasProviderConfig('bulk-provider-0')).toBe(true)
|
||||||
|
expect(hasProviderConfig('bulk-provider-25')).toBe(true)
|
||||||
|
expect(hasProviderConfig('bulk-provider-49')).toBe(true)
|
||||||
|
|
||||||
|
// 验证别名工作正常
|
||||||
|
expect(resolveProviderConfigId('alias-25')).toBe('bulk-provider-25')
|
||||||
|
expect(isProviderConfigAlias('short-30')).toBe(true)
|
||||||
|
|
||||||
|
// 清理能正确处理大量数据
|
||||||
|
cleanup()
|
||||||
|
const cleanupAliases = getAllProviderConfigAliases()
|
||||||
|
expect(
|
||||||
|
Object.keys(cleanupAliases).filter((alias) => alias.startsWith('alias-') || alias.startsWith('short-'))
|
||||||
|
).toEqual([])
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('边界测试', () => {
|
||||||
|
it('处理包含特殊字符的 provider IDs', () => {
|
||||||
|
const specialCharsConfigs: ProviderConfig[] = [
|
||||||
|
{
|
||||||
|
id: 'provider-with-dashes',
|
||||||
|
name: 'Provider With Dashes',
|
||||||
|
creator: vi.fn(() => ({ name: 'dashes' })),
|
||||||
|
supportsImageGeneration: false
|
||||||
|
},
|
||||||
|
{
|
||||||
|
id: 'provider_with_underscores',
|
||||||
|
name: 'Provider With Underscores',
|
||||||
|
creator: vi.fn(() => ({ name: 'underscores' })),
|
||||||
|
supportsImageGeneration: false
|
||||||
|
},
|
||||||
|
{
|
||||||
|
id: 'provider.with.dots',
|
||||||
|
name: 'Provider With Dots',
|
||||||
|
creator: vi.fn(() => ({ name: 'dots' })),
|
||||||
|
supportsImageGeneration: false
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
const successCount = registerMultipleProviderConfigs(specialCharsConfigs)
|
||||||
|
expect(successCount).toBeGreaterThan(0) // 至少有一些成功
|
||||||
|
|
||||||
|
// 验证支持的特殊字符格式
|
||||||
|
if (hasProviderConfig('provider-with-dashes')) {
|
||||||
|
expect(getProviderConfig('provider-with-dashes')).toBeDefined()
|
||||||
|
}
|
||||||
|
if (hasProviderConfig('provider_with_underscores')) {
|
||||||
|
expect(getProviderConfig('provider_with_underscores')).toBeDefined()
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
it('处理空的批量注册', () => {
|
||||||
|
const successCount = registerMultipleProviderConfigs([])
|
||||||
|
expect(successCount).toBe(0)
|
||||||
|
|
||||||
|
// 确保没有额外的配置被添加
|
||||||
|
const configsBefore = getAllProviderConfigs().length
|
||||||
|
expect(configsBefore).toBeGreaterThan(0) // 应该有基础配置
|
||||||
|
})
|
||||||
|
|
||||||
|
it('处理重复的配置注册', () => {
|
||||||
|
const config: ProviderConfig = {
|
||||||
|
id: 'duplicate-test-provider',
|
||||||
|
name: 'Duplicate Test Provider',
|
||||||
|
creator: vi.fn(() => ({ name: 'duplicate' })),
|
||||||
|
supportsImageGeneration: false
|
||||||
|
}
|
||||||
|
|
||||||
|
// 第一次注册成功
|
||||||
|
expect(registerProviderConfig(config)).toBe(true)
|
||||||
|
expect(hasProviderConfig('duplicate-test-provider')).toBe(true)
|
||||||
|
|
||||||
|
// 重复注册相同的配置(允许覆盖)
|
||||||
|
const updatedConfig: ProviderConfig = {
|
||||||
|
...config,
|
||||||
|
name: 'Updated Duplicate Test Provider'
|
||||||
|
}
|
||||||
|
expect(registerProviderConfig(updatedConfig)).toBe(true)
|
||||||
|
expect(hasProviderConfig('duplicate-test-provider')).toBe(true)
|
||||||
|
|
||||||
|
// 验证配置被更新
|
||||||
|
const retrievedConfig = getProviderConfig('duplicate-test-provider')
|
||||||
|
expect(retrievedConfig?.name).toBe('Updated Duplicate Test Provider')
|
||||||
|
})
|
||||||
|
|
||||||
|
it('处理极长的 ID 和名称', () => {
|
||||||
|
const longId = 'very-long-provider-id-' + 'x'.repeat(100)
|
||||||
|
const longName = 'Very Long Provider Name ' + 'Y'.repeat(100)
|
||||||
|
|
||||||
|
const config: ProviderConfig = {
|
||||||
|
id: longId,
|
||||||
|
name: longName,
|
||||||
|
creator: vi.fn(() => ({ name: 'long-test' })),
|
||||||
|
supportsImageGeneration: false
|
||||||
|
}
|
||||||
|
|
||||||
|
const success = registerProviderConfig(config)
|
||||||
|
expect(success).toBe(true)
|
||||||
|
expect(hasProviderConfig(longId)).toBe(true)
|
||||||
|
|
||||||
|
const retrievedConfig = getProviderConfig(longId)
|
||||||
|
expect(retrievedConfig?.name).toBe(longName)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('处理大量别名的配置', () => {
|
||||||
|
const manyAliases = Array.from({ length: 50 }, (_, i) => `alias-${i}`)
|
||||||
|
|
||||||
|
const config: ProviderConfig = {
|
||||||
|
id: 'provider-with-many-aliases',
|
||||||
|
name: 'Provider With Many Aliases',
|
||||||
|
creator: vi.fn(() => ({ name: 'many-aliases' })),
|
||||||
|
supportsImageGeneration: false,
|
||||||
|
aliases: manyAliases
|
||||||
|
}
|
||||||
|
|
||||||
|
const success = registerProviderConfig(config)
|
||||||
|
expect(success).toBe(true)
|
||||||
|
|
||||||
|
// 验证所有别名都能正确解析
|
||||||
|
manyAliases.forEach((alias) => {
|
||||||
|
expect(hasProviderConfigByAlias(alias)).toBe(true)
|
||||||
|
expect(resolveProviderConfigId(alias)).toBe('provider-with-many-aliases')
|
||||||
|
expect(isProviderConfigAlias(alias)).toBe(true)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
264
packages/aiCore/src/core/providers/__tests__/schemas.test.ts
Normal file
264
packages/aiCore/src/core/providers/__tests__/schemas.test.ts
Normal file
@@ -0,0 +1,264 @@
|
|||||||
|
import { describe, expect, it, vi } from 'vitest'
|
||||||
|
|
||||||
|
import {
|
||||||
|
type BaseProviderId,
|
||||||
|
baseProviderIds,
|
||||||
|
baseProviderIdSchema,
|
||||||
|
baseProviders,
|
||||||
|
type CustomProviderId,
|
||||||
|
customProviderIdSchema,
|
||||||
|
providerConfigSchema,
|
||||||
|
type ProviderId,
|
||||||
|
providerIdSchema
|
||||||
|
} from '../schemas'
|
||||||
|
|
||||||
|
describe('Provider Schemas', () => {
|
||||||
|
describe('baseProviders', () => {
|
||||||
|
it('包含所有预期的基础 providers', () => {
|
||||||
|
expect(baseProviders).toBeDefined()
|
||||||
|
expect(Array.isArray(baseProviders)).toBe(true)
|
||||||
|
expect(baseProviders.length).toBeGreaterThan(0)
|
||||||
|
|
||||||
|
const expectedIds = [
|
||||||
|
'openai',
|
||||||
|
'openai-responses',
|
||||||
|
'openai-compatible',
|
||||||
|
'anthropic',
|
||||||
|
'google',
|
||||||
|
'xai',
|
||||||
|
'azure',
|
||||||
|
'deepseek'
|
||||||
|
]
|
||||||
|
const actualIds = baseProviders.map((p) => p.id)
|
||||||
|
expectedIds.forEach((id) => {
|
||||||
|
expect(actualIds).toContain(id)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
it('每个基础 provider 有必要的属性', () => {
|
||||||
|
baseProviders.forEach((provider) => {
|
||||||
|
expect(provider).toHaveProperty('id')
|
||||||
|
expect(provider).toHaveProperty('name')
|
||||||
|
expect(provider).toHaveProperty('creator')
|
||||||
|
expect(provider).toHaveProperty('supportsImageGeneration')
|
||||||
|
|
||||||
|
expect(typeof provider.id).toBe('string')
|
||||||
|
expect(typeof provider.name).toBe('string')
|
||||||
|
expect(typeof provider.creator).toBe('function')
|
||||||
|
expect(typeof provider.supportsImageGeneration).toBe('boolean')
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
it('provider ID 是唯一的', () => {
|
||||||
|
const ids = baseProviders.map((p) => p.id)
|
||||||
|
const uniqueIds = [...new Set(ids)]
|
||||||
|
expect(ids).toEqual(uniqueIds)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('baseProviderIds', () => {
|
||||||
|
it('正确提取所有基础 provider IDs', () => {
|
||||||
|
expect(baseProviderIds).toBeDefined()
|
||||||
|
expect(Array.isArray(baseProviderIds)).toBe(true)
|
||||||
|
expect(baseProviderIds.length).toBe(baseProviders.length)
|
||||||
|
|
||||||
|
baseProviders.forEach((provider) => {
|
||||||
|
expect(baseProviderIds).toContain(provider.id)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('baseProviderIdSchema', () => {
|
||||||
|
it('验证有效的基础 provider IDs', () => {
|
||||||
|
baseProviderIds.forEach((id) => {
|
||||||
|
expect(baseProviderIdSchema.safeParse(id).success).toBe(true)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
it('拒绝无效的基础 provider IDs', () => {
|
||||||
|
const invalidIds = ['invalid', 'not-exists', '']
|
||||||
|
invalidIds.forEach((id) => {
|
||||||
|
expect(baseProviderIdSchema.safeParse(id).success).toBe(false)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('customProviderIdSchema', () => {
|
||||||
|
it('接受有效的自定义 provider IDs', () => {
|
||||||
|
const validIds = ['custom-provider', 'my-ai-service', 'company-llm-v2']
|
||||||
|
validIds.forEach((id) => {
|
||||||
|
expect(customProviderIdSchema.safeParse(id).success).toBe(true)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
it('拒绝与基础 provider IDs 冲突的 IDs', () => {
|
||||||
|
baseProviderIds.forEach((id) => {
|
||||||
|
expect(customProviderIdSchema.safeParse(id).success).toBe(false)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
it('拒绝空字符串', () => {
|
||||||
|
expect(customProviderIdSchema.safeParse('').success).toBe(false)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('providerIdSchema', () => {
|
||||||
|
it('接受基础 provider IDs', () => {
|
||||||
|
baseProviderIds.forEach((id) => {
|
||||||
|
expect(providerIdSchema.safeParse(id).success).toBe(true)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
it('接受有效的自定义 provider IDs', () => {
|
||||||
|
const validCustomIds = ['custom-provider', 'my-ai-service']
|
||||||
|
validCustomIds.forEach((id) => {
|
||||||
|
expect(providerIdSchema.safeParse(id).success).toBe(true)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
it('拒绝无效的 IDs', () => {
|
||||||
|
const invalidIds = ['', undefined, null, 123]
|
||||||
|
invalidIds.forEach((id) => {
|
||||||
|
expect(providerIdSchema.safeParse(id).success).toBe(false)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('providerConfigSchema', () => {
|
||||||
|
it('验证带有 creator 的有效配置', () => {
|
||||||
|
const validConfig = {
|
||||||
|
id: 'custom-provider',
|
||||||
|
name: 'Custom Provider',
|
||||||
|
creator: vi.fn(),
|
||||||
|
supportsImageGeneration: true
|
||||||
|
}
|
||||||
|
expect(providerConfigSchema.safeParse(validConfig).success).toBe(true)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('验证带有 import 配置的有效配置', () => {
|
||||||
|
const validConfig = {
|
||||||
|
id: 'custom-provider',
|
||||||
|
name: 'Custom Provider',
|
||||||
|
import: vi.fn(),
|
||||||
|
creatorFunctionName: 'createCustom',
|
||||||
|
supportsImageGeneration: false
|
||||||
|
}
|
||||||
|
expect(providerConfigSchema.safeParse(validConfig).success).toBe(true)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('拒绝既没有 creator 也没有 import 配置的配置', () => {
|
||||||
|
const invalidConfig = {
|
||||||
|
id: 'invalid',
|
||||||
|
name: 'Invalid Provider',
|
||||||
|
supportsImageGeneration: false
|
||||||
|
}
|
||||||
|
expect(providerConfigSchema.safeParse(invalidConfig).success).toBe(false)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('为 supportsImageGeneration 设置默认值', () => {
|
||||||
|
const config = {
|
||||||
|
id: 'test',
|
||||||
|
name: 'Test',
|
||||||
|
creator: vi.fn()
|
||||||
|
}
|
||||||
|
const result = providerConfigSchema.safeParse(config)
|
||||||
|
expect(result.success).toBe(true)
|
||||||
|
if (result.success) {
|
||||||
|
expect(result.data.supportsImageGeneration).toBe(false)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
it('拒绝使用基础 provider ID 的配置', () => {
|
||||||
|
const invalidConfig = {
|
||||||
|
id: 'openai', // 基础 provider ID
|
||||||
|
name: 'Should Fail',
|
||||||
|
creator: vi.fn()
|
||||||
|
}
|
||||||
|
expect(providerConfigSchema.safeParse(invalidConfig).success).toBe(false)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('拒绝缺少必需字段的配置', () => {
|
||||||
|
const invalidConfigs = [
|
||||||
|
{ name: 'Missing ID', creator: vi.fn() },
|
||||||
|
{ id: 'missing-name', creator: vi.fn() },
|
||||||
|
{ id: '', name: 'Empty ID', creator: vi.fn() },
|
||||||
|
{ id: 'valid-custom', name: '', creator: vi.fn() }
|
||||||
|
]
|
||||||
|
|
||||||
|
invalidConfigs.forEach((config) => {
|
||||||
|
expect(providerConfigSchema.safeParse(config).success).toBe(false)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('Schema 验证功能', () => {
|
||||||
|
it('baseProviderIdSchema 正确验证基础 provider IDs', () => {
|
||||||
|
baseProviderIds.forEach((id) => {
|
||||||
|
expect(baseProviderIdSchema.safeParse(id).success).toBe(true)
|
||||||
|
})
|
||||||
|
|
||||||
|
expect(baseProviderIdSchema.safeParse('invalid-id').success).toBe(false)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('customProviderIdSchema 正确验证自定义 provider IDs', () => {
|
||||||
|
const customIds = ['custom-provider', 'my-service', 'company-llm']
|
||||||
|
customIds.forEach((id) => {
|
||||||
|
expect(customProviderIdSchema.safeParse(id).success).toBe(true)
|
||||||
|
})
|
||||||
|
|
||||||
|
// 拒绝基础 provider IDs
|
||||||
|
baseProviderIds.forEach((id) => {
|
||||||
|
expect(customProviderIdSchema.safeParse(id).success).toBe(false)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
it('providerIdSchema 接受基础和自定义 provider IDs', () => {
|
||||||
|
// 基础 IDs
|
||||||
|
baseProviderIds.forEach((id) => {
|
||||||
|
expect(providerIdSchema.safeParse(id).success).toBe(true)
|
||||||
|
})
|
||||||
|
|
||||||
|
// 自定义 IDs
|
||||||
|
const customIds = ['custom-provider', 'my-service']
|
||||||
|
customIds.forEach((id) => {
|
||||||
|
expect(providerIdSchema.safeParse(id).success).toBe(true)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
it('providerConfigSchema 验证完整的 provider 配置', () => {
|
||||||
|
const validConfig = {
|
||||||
|
id: 'custom-provider',
|
||||||
|
name: 'Custom Provider',
|
||||||
|
creator: vi.fn(),
|
||||||
|
supportsImageGeneration: true
|
||||||
|
}
|
||||||
|
expect(providerConfigSchema.safeParse(validConfig).success).toBe(true)
|
||||||
|
|
||||||
|
const invalidConfig = {
|
||||||
|
id: 'openai', // 不允许基础 provider ID
|
||||||
|
name: 'OpenAI',
|
||||||
|
creator: vi.fn()
|
||||||
|
}
|
||||||
|
expect(providerConfigSchema.safeParse(invalidConfig).success).toBe(false)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('类型推导', () => {
|
||||||
|
it('BaseProviderId 类型正确', () => {
|
||||||
|
const id: BaseProviderId = 'openai'
|
||||||
|
expect(baseProviderIds).toContain(id)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('CustomProviderId 类型是字符串', () => {
|
||||||
|
const id: CustomProviderId = 'custom-provider'
|
||||||
|
expect(typeof id).toBe('string')
|
||||||
|
})
|
||||||
|
|
||||||
|
it('ProviderId 类型支持基础和自定义 IDs', () => {
|
||||||
|
const baseId: ProviderId = 'openai'
|
||||||
|
const customId: ProviderId = 'custom-provider'
|
||||||
|
expect(typeof baseId).toBe('string')
|
||||||
|
expect(typeof customId).toBe('string')
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
291
packages/aiCore/src/core/providers/factory.ts
Normal file
291
packages/aiCore/src/core/providers/factory.ts
Normal file
@@ -0,0 +1,291 @@
|
|||||||
|
/**
|
||||||
|
* AI Provider 配置工厂
|
||||||
|
* 提供类型安全的 Provider 配置构建器
|
||||||
|
*/
|
||||||
|
|
||||||
|
import type { ProviderId, ProviderSettingsMap } from './types'
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 通用配置基础类型,包含所有 Provider 共有的属性
|
||||||
|
*/
|
||||||
|
export interface BaseProviderConfig {
|
||||||
|
apiKey?: string
|
||||||
|
baseURL?: string
|
||||||
|
timeout?: number
|
||||||
|
headers?: Record<string, string>
|
||||||
|
fetch?: typeof globalThis.fetch
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 完整的配置类型,结合基础配置、AI SDK 配置和特定 Provider 配置
|
||||||
|
*/
|
||||||
|
type CompleteProviderConfig<T extends ProviderId> = BaseProviderConfig & Partial<ProviderSettingsMap[T]>
|
||||||
|
|
||||||
|
type ConfigHandler<T extends ProviderId> = (
|
||||||
|
builder: ProviderConfigBuilder<T>,
|
||||||
|
provider: CompleteProviderConfig<T>
|
||||||
|
) => void
|
||||||
|
|
||||||
|
const configHandlers: {
|
||||||
|
[K in ProviderId]?: ConfigHandler<K>
|
||||||
|
} = {
|
||||||
|
azure: (builder, provider) => {
|
||||||
|
const azureBuilder = builder as ProviderConfigBuilder<'azure'>
|
||||||
|
const azureProvider = provider as CompleteProviderConfig<'azure'>
|
||||||
|
azureBuilder.withAzureConfig({
|
||||||
|
apiVersion: azureProvider.apiVersion,
|
||||||
|
resourceName: azureProvider.resourceName
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
export class ProviderConfigBuilder<T extends ProviderId = ProviderId> {
|
||||||
|
private config: CompleteProviderConfig<T> = {} as CompleteProviderConfig<T>
|
||||||
|
|
||||||
|
constructor(private providerId: T) {}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 设置 API Key
|
||||||
|
*/
|
||||||
|
withApiKey(apiKey: string): this
|
||||||
|
withApiKey(apiKey: string, options: T extends 'openai' ? { organization?: string; project?: string } : never): this
|
||||||
|
withApiKey(apiKey: string, options?: any): this {
|
||||||
|
this.config.apiKey = apiKey
|
||||||
|
|
||||||
|
// 类型安全的 OpenAI 特定配置
|
||||||
|
if (this.providerId === 'openai' && options) {
|
||||||
|
const openaiConfig = this.config as CompleteProviderConfig<'openai'>
|
||||||
|
if (options.organization) {
|
||||||
|
openaiConfig.organization = options.organization
|
||||||
|
}
|
||||||
|
if (options.project) {
|
||||||
|
openaiConfig.project = options.project
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return this
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 设置基础 URL
|
||||||
|
*/
|
||||||
|
withBaseURL(baseURL: string) {
|
||||||
|
this.config.baseURL = baseURL
|
||||||
|
return this
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 设置请求配置
|
||||||
|
*/
|
||||||
|
withRequestConfig(options: { headers?: Record<string, string>; fetch?: typeof fetch }): this {
|
||||||
|
if (options.headers) {
|
||||||
|
this.config.headers = { ...this.config.headers, ...options.headers }
|
||||||
|
}
|
||||||
|
if (options.fetch) {
|
||||||
|
this.config.fetch = options.fetch
|
||||||
|
}
|
||||||
|
return this
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Azure OpenAI 特定配置
|
||||||
|
*/
|
||||||
|
withAzureConfig(options: { apiVersion?: string; resourceName?: string }): T extends 'azure' ? this : never
|
||||||
|
withAzureConfig(options: any): any {
|
||||||
|
if (this.providerId === 'azure') {
|
||||||
|
const azureConfig = this.config as CompleteProviderConfig<'azure'>
|
||||||
|
if (options.apiVersion) {
|
||||||
|
azureConfig.apiVersion = options.apiVersion
|
||||||
|
}
|
||||||
|
if (options.resourceName) {
|
||||||
|
azureConfig.resourceName = options.resourceName
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return this
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 设置自定义参数
|
||||||
|
*/
|
||||||
|
withCustomParams(params: Record<string, any>) {
|
||||||
|
Object.assign(this.config, params)
|
||||||
|
return this
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 构建最终配置
|
||||||
|
*/
|
||||||
|
build(): ProviderSettingsMap[T] {
|
||||||
|
return this.config as ProviderSettingsMap[T]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Provider 配置工厂
|
||||||
|
* 提供便捷的配置创建方法
|
||||||
|
*/
|
||||||
|
export class ProviderConfigFactory {
|
||||||
|
/**
|
||||||
|
* 创建配置构建器
|
||||||
|
*/
|
||||||
|
static builder<T extends ProviderId>(providerId: T): ProviderConfigBuilder<T> {
|
||||||
|
return new ProviderConfigBuilder(providerId)
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 从通用Provider对象创建配置 - 使用更优雅的处理器模式
|
||||||
|
*/
|
||||||
|
static fromProvider<T extends ProviderId>(
|
||||||
|
providerId: T,
|
||||||
|
provider: CompleteProviderConfig<T>,
|
||||||
|
options?: {
|
||||||
|
headers?: Record<string, string>
|
||||||
|
[key: string]: any
|
||||||
|
}
|
||||||
|
): ProviderSettingsMap[T] {
|
||||||
|
const builder = new ProviderConfigBuilder<T>(providerId)
|
||||||
|
|
||||||
|
// 设置基本配置
|
||||||
|
if (provider.apiKey) {
|
||||||
|
builder.withApiKey(provider.apiKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
if (provider.baseURL) {
|
||||||
|
builder.withBaseURL(provider.baseURL)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 设置请求配置
|
||||||
|
if (options?.headers) {
|
||||||
|
builder.withRequestConfig({
|
||||||
|
headers: options.headers
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// 使用配置处理器模式 - 更加优雅和可扩展
|
||||||
|
const handler = configHandlers[providerId]
|
||||||
|
if (handler) {
|
||||||
|
handler(builder, provider)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 添加其他自定义参数
|
||||||
|
if (options) {
|
||||||
|
const customOptions = { ...options }
|
||||||
|
delete customOptions.headers // 已经处理过了
|
||||||
|
if (Object.keys(customOptions).length > 0) {
|
||||||
|
builder.withCustomParams(customOptions)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return builder.build()
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 快速创建 OpenAI 配置
|
||||||
|
*/
|
||||||
|
static createOpenAI(
|
||||||
|
apiKey: string,
|
||||||
|
options?: {
|
||||||
|
baseURL?: string
|
||||||
|
organization?: string
|
||||||
|
project?: string
|
||||||
|
}
|
||||||
|
) {
|
||||||
|
const builder = this.builder('openai')
|
||||||
|
|
||||||
|
// 使用类型安全的重载
|
||||||
|
if (options?.organization || options?.project) {
|
||||||
|
builder.withApiKey(apiKey, {
|
||||||
|
organization: options.organization,
|
||||||
|
project: options.project
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
builder.withApiKey(apiKey)
|
||||||
|
}
|
||||||
|
|
||||||
|
return builder.withBaseURL(options?.baseURL || 'https://api.openai.com').build()
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 快速创建 Anthropic 配置
|
||||||
|
*/
|
||||||
|
static createAnthropic(
|
||||||
|
apiKey: string,
|
||||||
|
options?: {
|
||||||
|
baseURL?: string
|
||||||
|
}
|
||||||
|
) {
|
||||||
|
return this.builder('anthropic')
|
||||||
|
.withApiKey(apiKey)
|
||||||
|
.withBaseURL(options?.baseURL || 'https://api.anthropic.com')
|
||||||
|
.build()
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 快速创建 Azure OpenAI 配置
|
||||||
|
*/
|
||||||
|
static createAzureOpenAI(
|
||||||
|
apiKey: string,
|
||||||
|
options: {
|
||||||
|
baseURL: string
|
||||||
|
apiVersion?: string
|
||||||
|
resourceName?: string
|
||||||
|
}
|
||||||
|
) {
|
||||||
|
return this.builder('azure')
|
||||||
|
.withApiKey(apiKey)
|
||||||
|
.withBaseURL(options.baseURL)
|
||||||
|
.withAzureConfig({
|
||||||
|
apiVersion: options.apiVersion,
|
||||||
|
resourceName: options.resourceName
|
||||||
|
})
|
||||||
|
.build()
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 快速创建 Google 配置
|
||||||
|
*/
|
||||||
|
static createGoogle(
|
||||||
|
apiKey: string,
|
||||||
|
options?: {
|
||||||
|
baseURL?: string
|
||||||
|
projectId?: string
|
||||||
|
location?: string
|
||||||
|
}
|
||||||
|
) {
|
||||||
|
return this.builder('google')
|
||||||
|
.withApiKey(apiKey)
|
||||||
|
.withBaseURL(options?.baseURL || 'https://generativelanguage.googleapis.com')
|
||||||
|
.build()
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 快速创建 Vertex AI 配置
|
||||||
|
*/
|
||||||
|
static createVertexAI() {
|
||||||
|
// credentials: {
|
||||||
|
// clientEmail: string
|
||||||
|
// privateKey: string
|
||||||
|
// },
|
||||||
|
// options?: {
|
||||||
|
// project?: string
|
||||||
|
// location?: string
|
||||||
|
// }
|
||||||
|
// return this.builder('google-vertex')
|
||||||
|
// .withGoogleCredentials(credentials)
|
||||||
|
// .withGoogleVertexConfig({
|
||||||
|
// project: options?.project,
|
||||||
|
// location: options?.location
|
||||||
|
// })
|
||||||
|
// .build()
|
||||||
|
}
|
||||||
|
|
||||||
|
static createOpenAICompatible(baseURL: string, apiKey: string) {
|
||||||
|
return this.builder('openai-compatible').withBaseURL(baseURL).withApiKey(apiKey).build()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 便捷的配置创建函数
|
||||||
|
*/
|
||||||
|
export const createProviderConfig = ProviderConfigFactory.fromProvider
|
||||||
|
export const providerConfigBuilder = ProviderConfigFactory.builder
|
||||||
83
packages/aiCore/src/core/providers/index.ts
Normal file
83
packages/aiCore/src/core/providers/index.ts
Normal file
@@ -0,0 +1,83 @@
|
|||||||
|
/**
|
||||||
|
* Providers 模块统一导出 - 独立Provider包
|
||||||
|
*/
|
||||||
|
|
||||||
|
// ==================== 核心管理器 ====================
|
||||||
|
|
||||||
|
// Provider 注册表管理器
|
||||||
|
export { globalRegistryManagement, RegistryManagement } from './RegistryManagement'
|
||||||
|
|
||||||
|
// Provider 核心功能
|
||||||
|
export {
|
||||||
|
// 状态管理
|
||||||
|
cleanup,
|
||||||
|
clearAllProviders,
|
||||||
|
createAndRegisterProvider,
|
||||||
|
createProvider,
|
||||||
|
getAllProviderConfigAliases,
|
||||||
|
getAllProviderConfigs,
|
||||||
|
getImageModel,
|
||||||
|
// 工具函数
|
||||||
|
getInitializedProviders,
|
||||||
|
getLanguageModel,
|
||||||
|
getProviderConfig,
|
||||||
|
getProviderConfigByAlias,
|
||||||
|
getSupportedProviders,
|
||||||
|
getTextEmbeddingModel,
|
||||||
|
hasInitializedProviders,
|
||||||
|
// 工具函数
|
||||||
|
hasProviderConfig,
|
||||||
|
// 别名支持
|
||||||
|
hasProviderConfigByAlias,
|
||||||
|
isProviderConfigAlias,
|
||||||
|
// 错误类型
|
||||||
|
ProviderInitializationError,
|
||||||
|
// 全局访问
|
||||||
|
providerRegistry,
|
||||||
|
registerMultipleProviderConfigs,
|
||||||
|
registerProvider,
|
||||||
|
// 统一Provider系统
|
||||||
|
registerProviderConfig,
|
||||||
|
resolveProviderConfigId
|
||||||
|
} from './registry'
|
||||||
|
|
||||||
|
// ==================== 基础数据和类型 ====================
|
||||||
|
|
||||||
|
// 基础Provider数据源
|
||||||
|
export { baseProviderIds, baseProviders } from './schemas'
|
||||||
|
|
||||||
|
// 类型定义和Schema
|
||||||
|
export type {
|
||||||
|
BaseProviderId,
|
||||||
|
CustomProviderId,
|
||||||
|
DynamicProviderRegistration,
|
||||||
|
ProviderConfig,
|
||||||
|
ProviderId
|
||||||
|
} from './schemas' // 从 schemas 导出的类型
|
||||||
|
export { baseProviderIdSchema, customProviderIdSchema, providerConfigSchema, providerIdSchema } from './schemas' // Schema 导出
|
||||||
|
export type {
|
||||||
|
DynamicProviderRegistry,
|
||||||
|
ExtensibleProviderSettingsMap,
|
||||||
|
ProviderError,
|
||||||
|
ProviderSettingsMap,
|
||||||
|
ProviderTypeRegistrar
|
||||||
|
} from './types'
|
||||||
|
|
||||||
|
// ==================== 工具函数 ====================
|
||||||
|
|
||||||
|
// Provider配置工厂
|
||||||
|
export {
|
||||||
|
type BaseProviderConfig,
|
||||||
|
createProviderConfig,
|
||||||
|
ProviderConfigBuilder,
|
||||||
|
providerConfigBuilder,
|
||||||
|
ProviderConfigFactory
|
||||||
|
} from './factory'
|
||||||
|
|
||||||
|
// 工具函数
|
||||||
|
export { formatPrivateKey } from './utils'
|
||||||
|
|
||||||
|
// ==================== 扩展功能 ====================
|
||||||
|
|
||||||
|
// Hub Provider 功能
|
||||||
|
export { createHubProvider, type HubProviderConfig, HubProviderError } from './HubProvider'
|
||||||
320
packages/aiCore/src/core/providers/registry.ts
Normal file
320
packages/aiCore/src/core/providers/registry.ts
Normal file
@@ -0,0 +1,320 @@
|
|||||||
|
/**
|
||||||
|
* Provider 初始化器
|
||||||
|
* 负责根据配置创建 providers 并注册到全局管理器
|
||||||
|
* 集成了来自 ModelCreator 的特殊处理逻辑
|
||||||
|
*/
|
||||||
|
|
||||||
|
import { customProvider } from 'ai'
|
||||||
|
|
||||||
|
import { globalRegistryManagement } from './RegistryManagement'
|
||||||
|
import { baseProviders, type ProviderConfig } from './schemas'
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Provider 初始化错误类型
|
||||||
|
*/
|
||||||
|
class ProviderInitializationError extends Error {
|
||||||
|
constructor(
|
||||||
|
message: string,
|
||||||
|
public providerId?: string,
|
||||||
|
public cause?: Error
|
||||||
|
) {
|
||||||
|
super(message)
|
||||||
|
this.name = 'ProviderInitializationError'
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ==================== 全局管理器导出 ====================
|
||||||
|
|
||||||
|
export { globalRegistryManagement as providerRegistry }
|
||||||
|
|
||||||
|
// ==================== 便捷访问方法 ====================
|
||||||
|
|
||||||
|
export const getLanguageModel = (id: string) => globalRegistryManagement.languageModel(id as any)
|
||||||
|
export const getTextEmbeddingModel = (id: string) => globalRegistryManagement.textEmbeddingModel(id as any)
|
||||||
|
export const getImageModel = (id: string) => globalRegistryManagement.imageModel(id as any)
|
||||||
|
|
||||||
|
// ==================== 工具函数 ====================
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 获取支持的 Providers 列表
|
||||||
|
*/
|
||||||
|
export function getSupportedProviders(): Array<{
|
||||||
|
id: string
|
||||||
|
name: string
|
||||||
|
}> {
|
||||||
|
return baseProviders.map((provider) => ({
|
||||||
|
id: provider.id,
|
||||||
|
name: provider.name
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 获取所有已初始化的 providers
|
||||||
|
*/
|
||||||
|
export function getInitializedProviders(): string[] {
|
||||||
|
return globalRegistryManagement.getRegisteredProviders()
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 检查是否有任何已初始化的 providers
|
||||||
|
*/
|
||||||
|
export function hasInitializedProviders(): boolean {
|
||||||
|
return globalRegistryManagement.hasProviders()
|
||||||
|
}
|
||||||
|
|
||||||
|
// ==================== 统一Provider配置系统 ====================
|
||||||
|
|
||||||
|
// 全局Provider配置存储
|
||||||
|
const providerConfigs = new Map<string, ProviderConfig>()
|
||||||
|
// 全局ProviderConfig别名映射 - 借鉴RegistryManagement模式
|
||||||
|
const providerConfigAliases = new Map<string, string>() // alias -> realId
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 初始化内置配置 - 将baseProviders转换为统一格式
|
||||||
|
*/
|
||||||
|
function initializeBuiltInConfigs(): void {
|
||||||
|
baseProviders.forEach((provider) => {
|
||||||
|
const config: ProviderConfig = {
|
||||||
|
id: provider.id,
|
||||||
|
name: provider.name,
|
||||||
|
creator: provider.creator as any, // 类型转换以兼容多种creator签名
|
||||||
|
supportsImageGeneration: provider.supportsImageGeneration || false
|
||||||
|
}
|
||||||
|
providerConfigs.set(provider.id, config)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// 启动时自动注册内置配置
|
||||||
|
initializeBuiltInConfigs()
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 步骤1: 注册Provider配置 - 仅存储配置,不执行创建
|
||||||
|
*/
|
||||||
|
export function registerProviderConfig(config: ProviderConfig): boolean {
|
||||||
|
try {
|
||||||
|
// 验证配置
|
||||||
|
if (!config || !config.id || !config.name) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// 检查是否与已有配置冲突(包括内置配置)
|
||||||
|
if (providerConfigs.has(config.id)) {
|
||||||
|
console.warn(`ProviderConfig "${config.id}" already exists, will override`)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 存储配置(内置和用户配置统一处理)
|
||||||
|
providerConfigs.set(config.id, config)
|
||||||
|
|
||||||
|
// 处理别名
|
||||||
|
if (config.aliases && config.aliases.length > 0) {
|
||||||
|
config.aliases.forEach((alias) => {
|
||||||
|
if (providerConfigAliases.has(alias)) {
|
||||||
|
console.warn(`ProviderConfig alias "${alias}" already exists, will override`)
|
||||||
|
}
|
||||||
|
providerConfigAliases.set(alias, config.id)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return true
|
||||||
|
} catch (error) {
|
||||||
|
console.error(`Failed to register ProviderConfig:`, error)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 步骤2: 创建Provider - 根据配置执行实际创建
|
||||||
|
*/
|
||||||
|
export async function createProvider(providerId: string, options: any): Promise<any> {
|
||||||
|
// 支持通过别名查找配置
|
||||||
|
const config = getProviderConfigByAlias(providerId)
|
||||||
|
|
||||||
|
if (!config) {
|
||||||
|
throw new Error(`ProviderConfig not found for id: ${providerId}`)
|
||||||
|
}
|
||||||
|
|
||||||
|
try {
|
||||||
|
let creator: (options: any) => any
|
||||||
|
|
||||||
|
if (config.creator) {
|
||||||
|
// 方式1: 直接执行 creator
|
||||||
|
creator = config.creator
|
||||||
|
} else if (config.import && config.creatorFunctionName) {
|
||||||
|
// 方式2: 动态导入并执行
|
||||||
|
const module = await config.import()
|
||||||
|
creator = (module as any)[config.creatorFunctionName]
|
||||||
|
|
||||||
|
if (!creator || typeof creator !== 'function') {
|
||||||
|
throw new Error(`Creator function "${config.creatorFunctionName}" not found in imported module`)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
throw new Error('No valid creator method provided in ProviderConfig')
|
||||||
|
}
|
||||||
|
|
||||||
|
// 使用真实配置创建provider实例
|
||||||
|
return creator(options)
|
||||||
|
} catch (error) {
|
||||||
|
console.error(`Failed to create provider "${providerId}":`, error)
|
||||||
|
throw error
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 步骤3: 注册Provider到全局管理器
|
||||||
|
*/
|
||||||
|
export function registerProvider(providerId: string, provider: any): boolean {
|
||||||
|
try {
|
||||||
|
const config = providerConfigs.get(providerId)
|
||||||
|
if (!config) {
|
||||||
|
console.error(`ProviderConfig not found for id: ${providerId}`)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// 获取aliases配置
|
||||||
|
const aliases = config.aliases
|
||||||
|
|
||||||
|
// 处理特殊provider逻辑
|
||||||
|
if (providerId === 'openai') {
|
||||||
|
// 注册默认 openai
|
||||||
|
globalRegistryManagement.registerProvider(providerId, provider, aliases)
|
||||||
|
|
||||||
|
// 创建并注册 openai-chat 变体
|
||||||
|
const openaiChatProvider = customProvider({
|
||||||
|
fallbackProvider: {
|
||||||
|
...provider,
|
||||||
|
languageModel: (modelId: string) => provider.chat(modelId)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
globalRegistryManagement.registerProvider(`${providerId}-chat`, openaiChatProvider)
|
||||||
|
} else if (providerId === 'azure') {
|
||||||
|
globalRegistryManagement.registerProvider(`${providerId}-chat`, provider, aliases)
|
||||||
|
// 跟上面相反,creator产出的默认会调用chat
|
||||||
|
const azureResponsesProvider = customProvider({
|
||||||
|
fallbackProvider: {
|
||||||
|
...provider,
|
||||||
|
languageModel: (modelId: string) => provider.responses(modelId)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
globalRegistryManagement.registerProvider(providerId, azureResponsesProvider)
|
||||||
|
} else {
|
||||||
|
// 其他provider直接注册
|
||||||
|
globalRegistryManagement.registerProvider(providerId, provider, aliases)
|
||||||
|
}
|
||||||
|
|
||||||
|
return true
|
||||||
|
} catch (error) {
|
||||||
|
console.error(`Failed to register provider "${providerId}" to global registry:`, error)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 便捷函数: 一次性完成创建+注册
|
||||||
|
*/
|
||||||
|
export async function createAndRegisterProvider(providerId: string, options: any): Promise<boolean> {
|
||||||
|
try {
|
||||||
|
// 步骤2: 创建provider
|
||||||
|
const provider = await createProvider(providerId, options)
|
||||||
|
|
||||||
|
// 步骤3: 注册到全局管理器
|
||||||
|
return registerProvider(providerId, provider)
|
||||||
|
} catch (error) {
|
||||||
|
console.error(`Failed to create and register provider "${providerId}":`, error)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 批量注册Provider配置
|
||||||
|
*/
|
||||||
|
export function registerMultipleProviderConfigs(configs: ProviderConfig[]): number {
|
||||||
|
let successCount = 0
|
||||||
|
configs.forEach((config) => {
|
||||||
|
if (registerProviderConfig(config)) {
|
||||||
|
successCount++
|
||||||
|
}
|
||||||
|
})
|
||||||
|
return successCount
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 检查是否有对应的Provider配置
|
||||||
|
*/
|
||||||
|
export function hasProviderConfig(providerId: string): boolean {
|
||||||
|
return providerConfigs.has(providerId)
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 通过别名或ID检查是否有对应的Provider配置
|
||||||
|
*/
|
||||||
|
export function hasProviderConfigByAlias(aliasOrId: string): boolean {
|
||||||
|
const realId = resolveProviderConfigId(aliasOrId)
|
||||||
|
return providerConfigs.has(realId)
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 获取所有Provider配置
|
||||||
|
*/
|
||||||
|
export function getAllProviderConfigs(): ProviderConfig[] {
|
||||||
|
return Array.from(providerConfigs.values())
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 根据ID获取Provider配置
|
||||||
|
*/
|
||||||
|
export function getProviderConfig(providerId: string): ProviderConfig | undefined {
|
||||||
|
return providerConfigs.get(providerId)
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 通过别名或ID获取Provider配置
|
||||||
|
*/
|
||||||
|
export function getProviderConfigByAlias(aliasOrId: string): ProviderConfig | undefined {
|
||||||
|
// 先检查是否为别名,如果是则解析为真实ID
|
||||||
|
const realId = providerConfigAliases.get(aliasOrId) || aliasOrId
|
||||||
|
return providerConfigs.get(realId)
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 解析真实的ProviderConfig ID(去别名化)
|
||||||
|
*/
|
||||||
|
export function resolveProviderConfigId(aliasOrId: string): string {
|
||||||
|
return providerConfigAliases.get(aliasOrId) || aliasOrId
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 检查是否为ProviderConfig别名
|
||||||
|
*/
|
||||||
|
export function isProviderConfigAlias(id: string): boolean {
|
||||||
|
return providerConfigAliases.has(id)
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 获取所有ProviderConfig别名映射关系
|
||||||
|
*/
|
||||||
|
export function getAllProviderConfigAliases(): Record<string, string> {
|
||||||
|
const result: Record<string, string> = {}
|
||||||
|
providerConfigAliases.forEach((realId, alias) => {
|
||||||
|
result[alias] = realId
|
||||||
|
})
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 清理所有Provider配置和已注册的providers
|
||||||
|
*/
|
||||||
|
export function cleanup(): void {
|
||||||
|
providerConfigs.clear()
|
||||||
|
providerConfigAliases.clear() // 清理别名映射
|
||||||
|
globalRegistryManagement.clear()
|
||||||
|
// 重新初始化内置配置
|
||||||
|
initializeBuiltInConfigs()
|
||||||
|
}
|
||||||
|
|
||||||
|
export function clearAllProviders(): void {
|
||||||
|
globalRegistryManagement.clear()
|
||||||
|
}
|
||||||
|
|
||||||
|
// ==================== 导出错误类型 ====================
|
||||||
|
|
||||||
|
export { ProviderInitializationError }
|
||||||
178
packages/aiCore/src/core/providers/schemas.ts
Normal file
178
packages/aiCore/src/core/providers/schemas.ts
Normal file
@@ -0,0 +1,178 @@
|
|||||||
|
/**
|
||||||
|
* Provider Config 定义
|
||||||
|
*/
|
||||||
|
|
||||||
|
import { createAnthropic } from '@ai-sdk/anthropic'
|
||||||
|
import { createAzure } from '@ai-sdk/azure'
|
||||||
|
import { type AzureOpenAIProviderSettings } from '@ai-sdk/azure'
|
||||||
|
import { createDeepSeek } from '@ai-sdk/deepseek'
|
||||||
|
import { createGoogleGenerativeAI } from '@ai-sdk/google'
|
||||||
|
import { createOpenAI, type OpenAIProviderSettings } from '@ai-sdk/openai'
|
||||||
|
import { createOpenAICompatible } from '@ai-sdk/openai-compatible'
|
||||||
|
import { createXai } from '@ai-sdk/xai'
|
||||||
|
import { customProvider, type Provider } from 'ai'
|
||||||
|
import * as z from 'zod'
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 基础 Provider IDs
|
||||||
|
*/
|
||||||
|
export const baseProviderIds = [
|
||||||
|
'openai',
|
||||||
|
'openai-chat',
|
||||||
|
'openai-compatible',
|
||||||
|
'anthropic',
|
||||||
|
'google',
|
||||||
|
'xai',
|
||||||
|
'azure',
|
||||||
|
'azure-responses',
|
||||||
|
'deepseek'
|
||||||
|
] as const
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 基础 Provider ID Schema
|
||||||
|
*/
|
||||||
|
export const baseProviderIdSchema = z.enum(baseProviderIds)
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 基础 Provider ID
|
||||||
|
*/
|
||||||
|
export type BaseProviderId = z.infer<typeof baseProviderIdSchema>
|
||||||
|
|
||||||
|
export const baseProviderSchema = z.object({
|
||||||
|
id: baseProviderIdSchema,
|
||||||
|
name: z.string(),
|
||||||
|
creator: z.function().args(z.any()).returns(z.any()) as z.ZodType<(options: any) => Provider>,
|
||||||
|
supportsImageGeneration: z.boolean()
|
||||||
|
})
|
||||||
|
|
||||||
|
export type BaseProvider = z.infer<typeof baseProviderSchema>
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 基础 Providers 定义
|
||||||
|
* 作为唯一数据源,避免重复维护
|
||||||
|
*/
|
||||||
|
export const baseProviders = [
|
||||||
|
{
|
||||||
|
id: 'openai',
|
||||||
|
name: 'OpenAI',
|
||||||
|
creator: createOpenAI,
|
||||||
|
supportsImageGeneration: true
|
||||||
|
},
|
||||||
|
{
|
||||||
|
id: 'openai-chat',
|
||||||
|
name: 'OpenAI Chat',
|
||||||
|
creator: (options: OpenAIProviderSettings) => {
|
||||||
|
const provider = createOpenAI(options)
|
||||||
|
return customProvider({
|
||||||
|
fallbackProvider: {
|
||||||
|
...provider,
|
||||||
|
languageModel: (modelId: string) => provider.chat(modelId)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
},
|
||||||
|
supportsImageGeneration: true
|
||||||
|
},
|
||||||
|
{
|
||||||
|
id: 'openai-compatible',
|
||||||
|
name: 'OpenAI Compatible',
|
||||||
|
creator: createOpenAICompatible,
|
||||||
|
supportsImageGeneration: true
|
||||||
|
},
|
||||||
|
{
|
||||||
|
id: 'anthropic',
|
||||||
|
name: 'Anthropic',
|
||||||
|
creator: createAnthropic,
|
||||||
|
supportsImageGeneration: false
|
||||||
|
},
|
||||||
|
{
|
||||||
|
id: 'google',
|
||||||
|
name: 'Google Generative AI',
|
||||||
|
creator: createGoogleGenerativeAI,
|
||||||
|
supportsImageGeneration: true
|
||||||
|
},
|
||||||
|
{
|
||||||
|
id: 'xai',
|
||||||
|
name: 'xAI (Grok)',
|
||||||
|
creator: createXai,
|
||||||
|
supportsImageGeneration: true
|
||||||
|
},
|
||||||
|
{
|
||||||
|
id: 'azure',
|
||||||
|
name: 'Azure OpenAI',
|
||||||
|
creator: createAzure,
|
||||||
|
supportsImageGeneration: true
|
||||||
|
},
|
||||||
|
{
|
||||||
|
id: 'azure-responses',
|
||||||
|
name: 'Azure OpenAI Responses',
|
||||||
|
creator: (options: AzureOpenAIProviderSettings) => {
|
||||||
|
const provider = createAzure(options)
|
||||||
|
return customProvider({
|
||||||
|
fallbackProvider: {
|
||||||
|
...provider,
|
||||||
|
languageModel: (modelId: string) => provider.responses(modelId)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
},
|
||||||
|
supportsImageGeneration: true
|
||||||
|
},
|
||||||
|
{
|
||||||
|
id: 'deepseek',
|
||||||
|
name: 'DeepSeek',
|
||||||
|
creator: createDeepSeek,
|
||||||
|
supportsImageGeneration: false
|
||||||
|
}
|
||||||
|
] as const satisfies BaseProvider[]
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 用户自定义 Provider ID Schema
|
||||||
|
* 允许任意字符串,但排除基础 provider IDs 以避免冲突
|
||||||
|
*/
|
||||||
|
export const customProviderIdSchema = z
|
||||||
|
.string()
|
||||||
|
.min(1)
|
||||||
|
.refine((id) => !baseProviderIds.includes(id as any), {
|
||||||
|
message: 'Custom provider ID cannot conflict with base provider IDs'
|
||||||
|
})
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Provider ID Schema - 支持基础和自定义
|
||||||
|
*/
|
||||||
|
export const providerIdSchema = z.union([baseProviderIdSchema, customProviderIdSchema])
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Provider 配置 Schema
|
||||||
|
* 用于Provider的配置验证
|
||||||
|
*/
|
||||||
|
export const providerConfigSchema = z
|
||||||
|
.object({
|
||||||
|
id: customProviderIdSchema, // 只允许自定义ID
|
||||||
|
name: z.string().min(1),
|
||||||
|
creator: z.function().optional(),
|
||||||
|
import: z.function().optional(),
|
||||||
|
creatorFunctionName: z.string().optional(),
|
||||||
|
supportsImageGeneration: z.boolean().default(false),
|
||||||
|
imageCreator: z.function().optional(),
|
||||||
|
validateOptions: z.function().optional(),
|
||||||
|
aliases: z.array(z.string()).optional()
|
||||||
|
})
|
||||||
|
.refine((data) => data.creator || (data.import && data.creatorFunctionName), {
|
||||||
|
message: 'Must provide either creator function or import configuration'
|
||||||
|
})
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Provider ID 类型 - 基于 zod schema 推导
|
||||||
|
*/
|
||||||
|
export type ProviderId = z.infer<typeof providerIdSchema>
|
||||||
|
export type CustomProviderId = z.infer<typeof customProviderIdSchema>
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Provider 配置类型
|
||||||
|
*/
|
||||||
|
export type ProviderConfig = z.infer<typeof providerConfigSchema>
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 兼容性类型别名
|
||||||
|
* @deprecated 使用 ProviderConfig 替代
|
||||||
|
*/
|
||||||
|
export type DynamicProviderRegistration = ProviderConfig
|
||||||
96
packages/aiCore/src/core/providers/types.ts
Normal file
96
packages/aiCore/src/core/providers/types.ts
Normal file
@@ -0,0 +1,96 @@
|
|||||||
|
import { type AnthropicProviderSettings } from '@ai-sdk/anthropic'
|
||||||
|
import { type AzureOpenAIProviderSettings } from '@ai-sdk/azure'
|
||||||
|
import { type DeepSeekProviderSettings } from '@ai-sdk/deepseek'
|
||||||
|
import { type GoogleGenerativeAIProviderSettings } from '@ai-sdk/google'
|
||||||
|
import { type OpenAIProviderSettings } from '@ai-sdk/openai'
|
||||||
|
import { type OpenAICompatibleProviderSettings } from '@ai-sdk/openai-compatible'
|
||||||
|
import {
|
||||||
|
EmbeddingModelV2 as EmbeddingModel,
|
||||||
|
ImageModelV2 as ImageModel,
|
||||||
|
LanguageModelV2 as LanguageModel,
|
||||||
|
ProviderV2,
|
||||||
|
SpeechModelV2 as SpeechModel,
|
||||||
|
TranscriptionModelV2 as TranscriptionModel
|
||||||
|
} from '@ai-sdk/provider'
|
||||||
|
import { type XaiProviderSettings } from '@ai-sdk/xai'
|
||||||
|
|
||||||
|
// 导入基于 Zod 的 ProviderId 类型
|
||||||
|
import { type ProviderId as ZodProviderId } from './schemas'
|
||||||
|
|
||||||
|
export interface ExtensibleProviderSettingsMap {
|
||||||
|
// 基础的静态providers
|
||||||
|
openai: OpenAIProviderSettings
|
||||||
|
'openai-responses': OpenAIProviderSettings
|
||||||
|
'openai-compatible': OpenAICompatibleProviderSettings
|
||||||
|
anthropic: AnthropicProviderSettings
|
||||||
|
google: GoogleGenerativeAIProviderSettings
|
||||||
|
xai: XaiProviderSettings
|
||||||
|
azure: AzureOpenAIProviderSettings
|
||||||
|
deepseek: DeepSeekProviderSettings
|
||||||
|
}
|
||||||
|
|
||||||
|
// 动态扩展的provider类型注册表
|
||||||
|
export interface DynamicProviderRegistry {
|
||||||
|
[key: string]: any
|
||||||
|
}
|
||||||
|
|
||||||
|
// 合并基础和动态provider类型
|
||||||
|
export type ProviderSettingsMap = ExtensibleProviderSettingsMap & DynamicProviderRegistry
|
||||||
|
|
||||||
|
// 错误类型
|
||||||
|
export class ProviderError extends Error {
|
||||||
|
constructor(
|
||||||
|
message: string,
|
||||||
|
public providerId: string,
|
||||||
|
public code?: string,
|
||||||
|
public cause?: Error
|
||||||
|
) {
|
||||||
|
super(message)
|
||||||
|
this.name = 'ProviderError'
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 动态ProviderId类型 - 基于 Zod Schema,支持运行时扩展和验证
|
||||||
|
export type ProviderId = ZodProviderId
|
||||||
|
|
||||||
|
export interface ProviderTypeRegistrar {
|
||||||
|
registerProviderType<T extends string, S>(providerId: T, settingsType: S): void
|
||||||
|
getProviderSettings<T extends string>(providerId: T): any
|
||||||
|
}
|
||||||
|
|
||||||
|
// 重新导出所有类型供外部使用
|
||||||
|
export type {
|
||||||
|
AnthropicProviderSettings,
|
||||||
|
AzureOpenAIProviderSettings,
|
||||||
|
DeepSeekProviderSettings,
|
||||||
|
GoogleGenerativeAIProviderSettings,
|
||||||
|
OpenAICompatibleProviderSettings,
|
||||||
|
OpenAIProviderSettings,
|
||||||
|
XaiProviderSettings
|
||||||
|
}
|
||||||
|
|
||||||
|
export type AiSdkModel = LanguageModel | ImageModel | EmbeddingModel<string> | TranscriptionModel | SpeechModel
|
||||||
|
|
||||||
|
export type AiSdkModelType = 'text' | 'image' | 'embedding' | 'transcription' | 'speech'
|
||||||
|
|
||||||
|
export const METHOD_MAP = {
|
||||||
|
text: 'languageModel',
|
||||||
|
image: 'imageModel',
|
||||||
|
embedding: 'textEmbeddingModel',
|
||||||
|
transcription: 'transcriptionModel',
|
||||||
|
speech: 'speechModel'
|
||||||
|
} as const satisfies Record<AiSdkModelType, keyof ProviderV2>
|
||||||
|
|
||||||
|
export type AiSdkModelMethodMap = Record<AiSdkModelType, keyof ProviderV2>
|
||||||
|
|
||||||
|
export type AiSdkModelReturnMap = {
|
||||||
|
text: LanguageModel
|
||||||
|
image: ImageModel
|
||||||
|
embedding: EmbeddingModel<string>
|
||||||
|
transcription: TranscriptionModel
|
||||||
|
speech: SpeechModel
|
||||||
|
}
|
||||||
|
|
||||||
|
export type AiSdkMethodName<T extends AiSdkModelType> = (typeof METHOD_MAP)[T]
|
||||||
|
|
||||||
|
export type AiSdkModelReturn<T extends AiSdkModelType> = AiSdkModelReturnMap[T]
|
||||||
86
packages/aiCore/src/core/providers/utils.ts
Normal file
86
packages/aiCore/src/core/providers/utils.ts
Normal file
@@ -0,0 +1,86 @@
|
|||||||
|
/**
|
||||||
|
* 格式化私钥,确保它包含正确的PEM头部和尾部
|
||||||
|
*/
|
||||||
|
export function formatPrivateKey(privateKey: string): string {
|
||||||
|
if (!privateKey || typeof privateKey !== 'string') {
|
||||||
|
throw new Error('Private key must be a non-empty string')
|
||||||
|
}
|
||||||
|
|
||||||
|
// 先处理 JSON 字符串中的转义换行符
|
||||||
|
const key = privateKey.replace(/\\n/g, '\n')
|
||||||
|
|
||||||
|
// 检查是否已经是正确格式的 PEM 私钥
|
||||||
|
const hasBeginMarker = key.includes('-----BEGIN PRIVATE KEY-----')
|
||||||
|
const hasEndMarker = key.includes('-----END PRIVATE KEY-----')
|
||||||
|
|
||||||
|
if (hasBeginMarker && hasEndMarker) {
|
||||||
|
// 已经是 PEM 格式,但可能格式不规范,重新格式化
|
||||||
|
return normalizePemFormat(key)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 如果没有完整的 PEM 头尾,尝试重新构建
|
||||||
|
return reconstructPemKey(key)
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 标准化 PEM 格式
|
||||||
|
*/
|
||||||
|
function normalizePemFormat(pemKey: string): string {
|
||||||
|
// 分离头部、内容和尾部
|
||||||
|
const lines = pemKey
|
||||||
|
.split('\n')
|
||||||
|
.map((line) => line.trim())
|
||||||
|
.filter((line) => line.length > 0)
|
||||||
|
|
||||||
|
let keyContent = ''
|
||||||
|
let foundBegin = false
|
||||||
|
let foundEnd = false
|
||||||
|
|
||||||
|
for (const line of lines) {
|
||||||
|
if (line === '-----BEGIN PRIVATE KEY-----') {
|
||||||
|
foundBegin = true
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if (line === '-----END PRIVATE KEY-----') {
|
||||||
|
foundEnd = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
if (foundBegin && !foundEnd) {
|
||||||
|
keyContent += line
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!foundBegin || !foundEnd || !keyContent) {
|
||||||
|
throw new Error('Invalid PEM format: missing BEGIN/END markers or key content')
|
||||||
|
}
|
||||||
|
|
||||||
|
// 重新格式化为 64 字符一行
|
||||||
|
const formattedContent = keyContent.match(/.{1,64}/g)?.join('\n') || keyContent
|
||||||
|
|
||||||
|
return `-----BEGIN PRIVATE KEY-----\n${formattedContent}\n-----END PRIVATE KEY-----`
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 重新构建 PEM 私钥
|
||||||
|
*/
|
||||||
|
function reconstructPemKey(key: string): string {
|
||||||
|
// 移除所有空白字符和可能存在的不完整头尾
|
||||||
|
let cleanKey = key.replace(/\s+/g, '')
|
||||||
|
cleanKey = cleanKey.replace(/-----BEGIN[^-]*-----/g, '')
|
||||||
|
cleanKey = cleanKey.replace(/-----END[^-]*-----/g, '')
|
||||||
|
|
||||||
|
// 确保私钥内容不为空
|
||||||
|
if (!cleanKey) {
|
||||||
|
throw new Error('Private key content is empty after cleaning')
|
||||||
|
}
|
||||||
|
|
||||||
|
// 验证是否是有效的 Base64 字符
|
||||||
|
if (!/^[A-Za-z0-9+/=]+$/.test(cleanKey)) {
|
||||||
|
throw new Error('Private key contains invalid characters (not valid Base64)')
|
||||||
|
}
|
||||||
|
|
||||||
|
// 格式化为 64 字符一行
|
||||||
|
const formattedKey = cleanKey.match(/.{1,64}/g)?.join('\n') || cleanKey
|
||||||
|
|
||||||
|
return `-----BEGIN PRIVATE KEY-----\n${formattedKey}\n-----END PRIVATE KEY-----`
|
||||||
|
}
|
||||||
523
packages/aiCore/src/core/runtime/__tests__/generateImage.test.ts
Normal file
523
packages/aiCore/src/core/runtime/__tests__/generateImage.test.ts
Normal file
@@ -0,0 +1,523 @@
|
|||||||
|
import { ImageModelV2 } from '@ai-sdk/provider'
|
||||||
|
import { experimental_generateImage as aiGenerateImage, NoImageGeneratedError } from 'ai'
|
||||||
|
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||||
|
|
||||||
|
import { type AiPlugin } from '../../plugins'
|
||||||
|
import { globalRegistryManagement } from '../../providers/RegistryManagement'
|
||||||
|
import { ImageGenerationError, ImageModelResolutionError } from '../errors'
|
||||||
|
import { RuntimeExecutor } from '../executor'
|
||||||
|
|
||||||
|
// Mock dependencies
|
||||||
|
vi.mock('ai', () => ({
|
||||||
|
experimental_generateImage: vi.fn(),
|
||||||
|
NoImageGeneratedError: class NoImageGeneratedError extends Error {
|
||||||
|
static isInstance = vi.fn()
|
||||||
|
constructor() {
|
||||||
|
super('No image generated')
|
||||||
|
this.name = 'NoImageGeneratedError'
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
|
||||||
|
vi.mock('../../providers/RegistryManagement', () => ({
|
||||||
|
globalRegistryManagement: {
|
||||||
|
imageModel: vi.fn()
|
||||||
|
},
|
||||||
|
DEFAULT_SEPARATOR: '|'
|
||||||
|
}))
|
||||||
|
|
||||||
|
describe('RuntimeExecutor.generateImage', () => {
|
||||||
|
let executor: RuntimeExecutor<'openai'>
|
||||||
|
let mockImageModel: ImageModelV2
|
||||||
|
let mockGenerateImageResult: any
|
||||||
|
|
||||||
|
beforeEach(() => {
|
||||||
|
// Reset all mocks
|
||||||
|
vi.clearAllMocks()
|
||||||
|
|
||||||
|
// Create executor instance
|
||||||
|
executor = RuntimeExecutor.create('openai', {
|
||||||
|
apiKey: 'test-key'
|
||||||
|
})
|
||||||
|
|
||||||
|
// Mock image model
|
||||||
|
mockImageModel = {
|
||||||
|
modelId: 'dall-e-3',
|
||||||
|
provider: 'openai'
|
||||||
|
} as ImageModelV2
|
||||||
|
|
||||||
|
// Mock generateImage result
|
||||||
|
mockGenerateImageResult = {
|
||||||
|
image: {
|
||||||
|
base64: 'base64-encoded-image-data',
|
||||||
|
uint8Array: new Uint8Array([1, 2, 3]),
|
||||||
|
mediaType: 'image/png'
|
||||||
|
},
|
||||||
|
images: [
|
||||||
|
{
|
||||||
|
base64: 'base64-encoded-image-data',
|
||||||
|
uint8Array: new Uint8Array([1, 2, 3]),
|
||||||
|
mediaType: 'image/png'
|
||||||
|
}
|
||||||
|
],
|
||||||
|
warnings: [],
|
||||||
|
providerMetadata: {
|
||||||
|
openai: {
|
||||||
|
images: [{ revisedPrompt: 'A detailed prompt' }]
|
||||||
|
}
|
||||||
|
},
|
||||||
|
responses: []
|
||||||
|
}
|
||||||
|
|
||||||
|
// Setup mocks to avoid "No providers registered" error
|
||||||
|
vi.mocked(globalRegistryManagement.imageModel).mockReturnValue(mockImageModel)
|
||||||
|
vi.mocked(aiGenerateImage).mockResolvedValue(mockGenerateImageResult)
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('Basic functionality', () => {
|
||||||
|
it('should generate a single image with minimal parameters', async () => {
|
||||||
|
const result = await executor.generateImage({ model: 'dall-e-3', prompt: 'A futuristic cityscape at sunset' })
|
||||||
|
|
||||||
|
expect(globalRegistryManagement.imageModel).toHaveBeenCalledWith('openai|dall-e-3')
|
||||||
|
|
||||||
|
expect(aiGenerateImage).toHaveBeenCalledWith({
|
||||||
|
model: mockImageModel,
|
||||||
|
prompt: 'A futuristic cityscape at sunset'
|
||||||
|
})
|
||||||
|
|
||||||
|
expect(result).toEqual(mockGenerateImageResult)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should generate image with pre-created model', async () => {
|
||||||
|
const result = await executor.generateImage({
|
||||||
|
model: mockImageModel,
|
||||||
|
prompt: 'A beautiful landscape'
|
||||||
|
})
|
||||||
|
|
||||||
|
// Note: globalRegistryManagement.imageModel may still be called due to resolveImageModel logic
|
||||||
|
expect(aiGenerateImage).toHaveBeenCalledWith({
|
||||||
|
model: mockImageModel,
|
||||||
|
prompt: 'A beautiful landscape'
|
||||||
|
})
|
||||||
|
|
||||||
|
expect(result).toEqual(mockGenerateImageResult)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should support multiple images generation', async () => {
|
||||||
|
await executor.generateImage({ model: 'dall-e-3', prompt: 'A futuristic cityscape', n: 3 })
|
||||||
|
|
||||||
|
expect(aiGenerateImage).toHaveBeenCalledWith({
|
||||||
|
model: mockImageModel,
|
||||||
|
prompt: 'A futuristic cityscape',
|
||||||
|
n: 3
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should support size specification', async () => {
|
||||||
|
await executor.generateImage({ model: 'dall-e-3', prompt: 'A beautiful sunset', size: '1024x1024' })
|
||||||
|
|
||||||
|
expect(aiGenerateImage).toHaveBeenCalledWith({
|
||||||
|
model: mockImageModel,
|
||||||
|
prompt: 'A beautiful sunset',
|
||||||
|
size: '1024x1024'
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should support aspect ratio specification', async () => {
|
||||||
|
await executor.generateImage({ model: 'dall-e-3', prompt: 'A mountain landscape', aspectRatio: '16:9' })
|
||||||
|
|
||||||
|
expect(aiGenerateImage).toHaveBeenCalledWith({
|
||||||
|
model: mockImageModel,
|
||||||
|
prompt: 'A mountain landscape',
|
||||||
|
aspectRatio: '16:9'
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should support seed for consistent output', async () => {
|
||||||
|
await executor.generateImage({ model: 'dall-e-3', prompt: 'A cat in space', seed: 1234567890 })
|
||||||
|
|
||||||
|
expect(aiGenerateImage).toHaveBeenCalledWith({
|
||||||
|
model: mockImageModel,
|
||||||
|
prompt: 'A cat in space',
|
||||||
|
seed: 1234567890
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should support abort signal', async () => {
|
||||||
|
const abortController = new AbortController()
|
||||||
|
|
||||||
|
await executor.generateImage({ model: 'dall-e-3', prompt: 'A cityscape', abortSignal: abortController.signal })
|
||||||
|
|
||||||
|
expect(aiGenerateImage).toHaveBeenCalledWith({
|
||||||
|
model: mockImageModel,
|
||||||
|
prompt: 'A cityscape',
|
||||||
|
abortSignal: abortController.signal
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should support provider-specific options', async () => {
|
||||||
|
await executor.generateImage({
|
||||||
|
model: 'dall-e-3',
|
||||||
|
prompt: 'A space station',
|
||||||
|
providerOptions: {
|
||||||
|
openai: {
|
||||||
|
quality: 'hd',
|
||||||
|
style: 'vivid'
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
expect(aiGenerateImage).toHaveBeenCalledWith({
|
||||||
|
model: mockImageModel,
|
||||||
|
prompt: 'A space station',
|
||||||
|
providerOptions: {
|
||||||
|
openai: {
|
||||||
|
quality: 'hd',
|
||||||
|
style: 'vivid'
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should support custom headers', async () => {
|
||||||
|
await executor.generateImage({
|
||||||
|
model: 'dall-e-3',
|
||||||
|
prompt: 'A robot',
|
||||||
|
headers: {
|
||||||
|
'X-Custom-Header': 'test-value'
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
expect(aiGenerateImage).toHaveBeenCalledWith({
|
||||||
|
model: mockImageModel,
|
||||||
|
prompt: 'A robot',
|
||||||
|
headers: {
|
||||||
|
'X-Custom-Header': 'test-value'
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('Plugin integration', () => {
|
||||||
|
it('should execute plugins in correct order', async () => {
|
||||||
|
const pluginCallOrder: string[] = []
|
||||||
|
|
||||||
|
const testPlugin: AiPlugin = {
|
||||||
|
name: 'test-plugin',
|
||||||
|
onRequestStart: vi.fn(async () => {
|
||||||
|
pluginCallOrder.push('onRequestStart')
|
||||||
|
}),
|
||||||
|
transformParams: vi.fn(async (params) => {
|
||||||
|
pluginCallOrder.push('transformParams')
|
||||||
|
return { ...params, size: '512x512' }
|
||||||
|
}),
|
||||||
|
transformResult: vi.fn(async (result) => {
|
||||||
|
pluginCallOrder.push('transformResult')
|
||||||
|
return { ...result, processed: true }
|
||||||
|
}),
|
||||||
|
onRequestEnd: vi.fn(async () => {
|
||||||
|
pluginCallOrder.push('onRequestEnd')
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
const executorWithPlugin = RuntimeExecutor.create(
|
||||||
|
'openai',
|
||||||
|
{
|
||||||
|
apiKey: 'test-key'
|
||||||
|
},
|
||||||
|
[testPlugin]
|
||||||
|
)
|
||||||
|
|
||||||
|
const result = await executorWithPlugin.generateImage({ model: 'dall-e-3', prompt: 'A test image' })
|
||||||
|
|
||||||
|
expect(pluginCallOrder).toEqual(['onRequestStart', 'transformParams', 'transformResult', 'onRequestEnd'])
|
||||||
|
|
||||||
|
expect(testPlugin.transformParams).toHaveBeenCalledWith(
|
||||||
|
{ prompt: 'A test image' },
|
||||||
|
expect.objectContaining({
|
||||||
|
providerId: 'openai',
|
||||||
|
modelId: 'dall-e-3'
|
||||||
|
})
|
||||||
|
)
|
||||||
|
|
||||||
|
expect(aiGenerateImage).toHaveBeenCalledWith({
|
||||||
|
model: mockImageModel,
|
||||||
|
prompt: 'A test image',
|
||||||
|
size: '512x512' // Should be transformed by plugin
|
||||||
|
})
|
||||||
|
|
||||||
|
expect(result).toEqual({
|
||||||
|
...mockGenerateImageResult,
|
||||||
|
processed: true // Should be transformed by plugin
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should handle model resolution through plugins', async () => {
|
||||||
|
const customImageModel = {
|
||||||
|
modelId: 'custom-model',
|
||||||
|
provider: 'openai'
|
||||||
|
} as ImageModelV2
|
||||||
|
|
||||||
|
const modelResolutionPlugin: AiPlugin = {
|
||||||
|
name: 'model-resolver',
|
||||||
|
resolveModel: vi.fn(async () => customImageModel)
|
||||||
|
}
|
||||||
|
|
||||||
|
const executorWithPlugin = RuntimeExecutor.create(
|
||||||
|
'openai',
|
||||||
|
{
|
||||||
|
apiKey: 'test-key'
|
||||||
|
},
|
||||||
|
[modelResolutionPlugin]
|
||||||
|
)
|
||||||
|
|
||||||
|
await executorWithPlugin.generateImage({ model: 'dall-e-3', prompt: 'A test image' })
|
||||||
|
|
||||||
|
expect(modelResolutionPlugin.resolveModel).toHaveBeenCalledWith(
|
||||||
|
'dall-e-3',
|
||||||
|
expect.objectContaining({
|
||||||
|
providerId: 'openai',
|
||||||
|
modelId: 'dall-e-3'
|
||||||
|
})
|
||||||
|
)
|
||||||
|
|
||||||
|
expect(aiGenerateImage).toHaveBeenCalledWith({
|
||||||
|
model: customImageModel,
|
||||||
|
prompt: 'A test image'
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should support recursive calls from plugins', async () => {
|
||||||
|
const recursivePlugin: AiPlugin = {
|
||||||
|
name: 'recursive-plugin',
|
||||||
|
transformParams: vi.fn(async (params, context) => {
|
||||||
|
if (!context.isRecursiveCall && params.prompt === 'original') {
|
||||||
|
// Make a recursive call with modified prompt
|
||||||
|
await context.recursiveCall({
|
||||||
|
model: 'dall-e-3',
|
||||||
|
prompt: 'modified'
|
||||||
|
})
|
||||||
|
}
|
||||||
|
return params
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
const executorWithPlugin = RuntimeExecutor.create(
|
||||||
|
'openai',
|
||||||
|
{
|
||||||
|
apiKey: 'test-key'
|
||||||
|
},
|
||||||
|
[recursivePlugin]
|
||||||
|
)
|
||||||
|
|
||||||
|
await executorWithPlugin.generateImage({ model: 'dall-e-3', prompt: 'original' })
|
||||||
|
|
||||||
|
expect(recursivePlugin.transformParams).toHaveBeenCalledTimes(2)
|
||||||
|
expect(aiGenerateImage).toHaveBeenCalledTimes(2)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('Error handling', () => {
|
||||||
|
it('should handle model creation errors', async () => {
|
||||||
|
const modelError = new Error('Failed to get image model')
|
||||||
|
vi.mocked(globalRegistryManagement.imageModel).mockImplementation(() => {
|
||||||
|
throw modelError
|
||||||
|
})
|
||||||
|
|
||||||
|
await expect(executor.generateImage({ model: 'invalid-model', prompt: 'A test image' })).rejects.toThrow(
|
||||||
|
ImageGenerationError
|
||||||
|
)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should handle ImageModelResolutionError correctly', async () => {
|
||||||
|
const resolutionError = new ImageModelResolutionError('invalid-model', 'openai', new Error('Model not found'))
|
||||||
|
vi.mocked(globalRegistryManagement.imageModel).mockImplementation(() => {
|
||||||
|
throw resolutionError
|
||||||
|
})
|
||||||
|
|
||||||
|
const thrownError = await executor
|
||||||
|
.generateImage({ model: 'invalid-model', prompt: 'A test image' })
|
||||||
|
.catch((error) => error)
|
||||||
|
|
||||||
|
expect(thrownError).toBeInstanceOf(ImageGenerationError)
|
||||||
|
expect(thrownError.message).toContain('Failed to generate image:')
|
||||||
|
expect(thrownError.providerId).toBe('openai')
|
||||||
|
expect(thrownError.modelId).toBe('invalid-model')
|
||||||
|
expect(thrownError.cause).toBeInstanceOf(ImageModelResolutionError)
|
||||||
|
expect(thrownError.cause.message).toContain('Failed to resolve image model: invalid-model')
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should handle ImageModelResolutionError without provider', async () => {
|
||||||
|
const resolutionError = new ImageModelResolutionError('unknown-model')
|
||||||
|
vi.mocked(globalRegistryManagement.imageModel).mockImplementation(() => {
|
||||||
|
throw resolutionError
|
||||||
|
})
|
||||||
|
|
||||||
|
await expect(executor.generateImage({ model: 'unknown-model', prompt: 'A test image' })).rejects.toThrow(
|
||||||
|
ImageGenerationError
|
||||||
|
)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should handle image generation API errors', async () => {
|
||||||
|
const apiError = new Error('API request failed')
|
||||||
|
vi.mocked(aiGenerateImage).mockRejectedValue(apiError)
|
||||||
|
|
||||||
|
await expect(executor.generateImage({ model: 'dall-e-3', prompt: 'A test image' })).rejects.toThrow(
|
||||||
|
'Failed to generate image:'
|
||||||
|
)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should handle NoImageGeneratedError', async () => {
|
||||||
|
const noImageError = new NoImageGeneratedError({
|
||||||
|
cause: new Error('No image generated'),
|
||||||
|
responses: []
|
||||||
|
})
|
||||||
|
|
||||||
|
vi.mocked(aiGenerateImage).mockRejectedValue(noImageError)
|
||||||
|
vi.mocked(NoImageGeneratedError.isInstance).mockReturnValue(true)
|
||||||
|
|
||||||
|
await expect(executor.generateImage({ model: 'dall-e-3', prompt: 'A test image' })).rejects.toThrow(
|
||||||
|
'Failed to generate image:'
|
||||||
|
)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should execute onError plugin hook on failure', async () => {
|
||||||
|
const error = new Error('Generation failed')
|
||||||
|
vi.mocked(aiGenerateImage).mockRejectedValue(error)
|
||||||
|
|
||||||
|
const errorPlugin: AiPlugin = {
|
||||||
|
name: 'error-handler',
|
||||||
|
onError: vi.fn()
|
||||||
|
}
|
||||||
|
|
||||||
|
const executorWithPlugin = RuntimeExecutor.create(
|
||||||
|
'openai',
|
||||||
|
{
|
||||||
|
apiKey: 'test-key'
|
||||||
|
},
|
||||||
|
[errorPlugin]
|
||||||
|
)
|
||||||
|
|
||||||
|
await expect(executorWithPlugin.generateImage({ model: 'dall-e-3', prompt: 'A test image' })).rejects.toThrow(
|
||||||
|
'Failed to generate image:'
|
||||||
|
)
|
||||||
|
|
||||||
|
expect(errorPlugin.onError).toHaveBeenCalledWith(
|
||||||
|
error,
|
||||||
|
expect.objectContaining({
|
||||||
|
providerId: 'openai',
|
||||||
|
modelId: 'dall-e-3'
|
||||||
|
})
|
||||||
|
)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should handle abort signal timeout', async () => {
|
||||||
|
const abortError = new Error('Operation was aborted')
|
||||||
|
abortError.name = 'AbortError'
|
||||||
|
vi.mocked(aiGenerateImage).mockRejectedValue(abortError)
|
||||||
|
|
||||||
|
const abortController = new AbortController()
|
||||||
|
setTimeout(() => abortController.abort(), 10)
|
||||||
|
|
||||||
|
await expect(
|
||||||
|
executor.generateImage({ model: 'dall-e-3', prompt: 'A test image', abortSignal: abortController.signal })
|
||||||
|
).rejects.toThrow('Failed to generate image:')
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('Multiple providers support', () => {
|
||||||
|
it('should work with different providers', async () => {
|
||||||
|
const googleExecutor = RuntimeExecutor.create('google', {
|
||||||
|
apiKey: 'google-key'
|
||||||
|
})
|
||||||
|
|
||||||
|
await googleExecutor.generateImage({ model: 'imagen-3.0-generate-002', prompt: 'A landscape' })
|
||||||
|
|
||||||
|
expect(globalRegistryManagement.imageModel).toHaveBeenCalledWith('google|imagen-3.0-generate-002')
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should support xAI Grok image models', async () => {
|
||||||
|
const xaiExecutor = RuntimeExecutor.create('xai', {
|
||||||
|
apiKey: 'xai-key'
|
||||||
|
})
|
||||||
|
|
||||||
|
await xaiExecutor.generateImage({ model: 'grok-2-image', prompt: 'A futuristic robot' })
|
||||||
|
|
||||||
|
expect(globalRegistryManagement.imageModel).toHaveBeenCalledWith('xai|grok-2-image')
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('Advanced features', () => {
|
||||||
|
it('should support batch image generation with maxImagesPerCall', async () => {
|
||||||
|
await executor.generateImage({ model: 'dall-e-3', prompt: 'A test image', n: 10, maxImagesPerCall: 5 })
|
||||||
|
|
||||||
|
expect(aiGenerateImage).toHaveBeenCalledWith({
|
||||||
|
model: mockImageModel,
|
||||||
|
prompt: 'A test image',
|
||||||
|
n: 10,
|
||||||
|
maxImagesPerCall: 5
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should support retries with maxRetries', async () => {
|
||||||
|
await executor.generateImage({ model: 'dall-e-3', prompt: 'A test image', maxRetries: 3 })
|
||||||
|
|
||||||
|
expect(aiGenerateImage).toHaveBeenCalledWith({
|
||||||
|
model: mockImageModel,
|
||||||
|
prompt: 'A test image',
|
||||||
|
maxRetries: 3
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should handle warnings from the model', async () => {
|
||||||
|
const resultWithWarnings = {
|
||||||
|
...mockGenerateImageResult,
|
||||||
|
warnings: [
|
||||||
|
{
|
||||||
|
type: 'unsupported-setting',
|
||||||
|
message: 'Size parameter not supported for this model'
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
vi.mocked(aiGenerateImage).mockResolvedValue(resultWithWarnings)
|
||||||
|
|
||||||
|
const result = await executor.generateImage({
|
||||||
|
model: 'dall-e-3',
|
||||||
|
prompt: 'A test image',
|
||||||
|
size: '2048x2048' // Unsupported size
|
||||||
|
})
|
||||||
|
|
||||||
|
expect(result.warnings).toHaveLength(1)
|
||||||
|
expect(result.warnings[0].type).toBe('unsupported-setting')
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should provide access to provider metadata', async () => {
|
||||||
|
const result = await executor.generateImage({ model: 'dall-e-3', prompt: 'A test image' })
|
||||||
|
|
||||||
|
expect(result.providerMetadata).toBeDefined()
|
||||||
|
expect(result.providerMetadata.openai).toBeDefined()
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should provide response metadata', async () => {
|
||||||
|
const resultWithMetadata = {
|
||||||
|
...mockGenerateImageResult,
|
||||||
|
responses: [
|
||||||
|
{
|
||||||
|
timestamp: new Date(),
|
||||||
|
modelId: 'dall-e-3',
|
||||||
|
headers: { 'x-request-id': 'test-123' }
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
vi.mocked(aiGenerateImage).mockResolvedValue(resultWithMetadata)
|
||||||
|
|
||||||
|
const result = await executor.generateImage({ model: 'dall-e-3', prompt: 'A test image' })
|
||||||
|
|
||||||
|
expect(result.responses).toHaveLength(1)
|
||||||
|
expect(result.responses[0].modelId).toBe('dall-e-3')
|
||||||
|
expect(result.responses[0].headers).toEqual({ 'x-request-id': 'test-123' })
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
38
packages/aiCore/src/core/runtime/errors.ts
Normal file
38
packages/aiCore/src/core/runtime/errors.ts
Normal file
@@ -0,0 +1,38 @@
|
|||||||
|
/**
|
||||||
|
* Error classes for runtime operations
|
||||||
|
*/
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Error thrown when image generation fails
|
||||||
|
*/
|
||||||
|
export class ImageGenerationError extends Error {
|
||||||
|
constructor(
|
||||||
|
message: string,
|
||||||
|
public providerId?: string,
|
||||||
|
public modelId?: string,
|
||||||
|
public cause?: Error
|
||||||
|
) {
|
||||||
|
super(message)
|
||||||
|
this.name = 'ImageGenerationError'
|
||||||
|
|
||||||
|
// Maintain proper stack trace (for V8 engines)
|
||||||
|
if (Error.captureStackTrace) {
|
||||||
|
Error.captureStackTrace(this, ImageGenerationError)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Error thrown when model resolution fails during image generation
|
||||||
|
*/
|
||||||
|
export class ImageModelResolutionError extends ImageGenerationError {
|
||||||
|
constructor(modelId: string, providerId?: string, cause?: Error) {
|
||||||
|
super(
|
||||||
|
`Failed to resolve image model: ${modelId}${providerId ? ` for provider: ${providerId}` : ''}`,
|
||||||
|
providerId,
|
||||||
|
modelId,
|
||||||
|
cause
|
||||||
|
)
|
||||||
|
this.name = 'ImageModelResolutionError'
|
||||||
|
}
|
||||||
|
}
|
||||||
321
packages/aiCore/src/core/runtime/executor.ts
Normal file
321
packages/aiCore/src/core/runtime/executor.ts
Normal file
@@ -0,0 +1,321 @@
|
|||||||
|
/**
|
||||||
|
* 运行时执行器
|
||||||
|
* 专注于插件化的AI调用处理
|
||||||
|
*/
|
||||||
|
import { ImageModelV2, LanguageModelV2, LanguageModelV2Middleware } from '@ai-sdk/provider'
|
||||||
|
import {
|
||||||
|
experimental_generateImage as generateImage,
|
||||||
|
generateObject,
|
||||||
|
generateText,
|
||||||
|
LanguageModel,
|
||||||
|
streamObject,
|
||||||
|
streamText
|
||||||
|
} from 'ai'
|
||||||
|
|
||||||
|
import { globalModelResolver } from '../models'
|
||||||
|
import { type ModelConfig } from '../models/types'
|
||||||
|
import { type AiPlugin, type AiRequestContext, definePlugin } from '../plugins'
|
||||||
|
import { type ProviderId } from '../providers'
|
||||||
|
import { ImageGenerationError, ImageModelResolutionError } from './errors'
|
||||||
|
import { PluginEngine } from './pluginEngine'
|
||||||
|
import { type RuntimeConfig } from './types'
|
||||||
|
|
||||||
|
export class RuntimeExecutor<T extends ProviderId = ProviderId> {
|
||||||
|
public pluginEngine: PluginEngine<T>
|
||||||
|
// private options: ProviderSettingsMap[T]
|
||||||
|
private config: RuntimeConfig<T>
|
||||||
|
|
||||||
|
constructor(config: RuntimeConfig<T>) {
|
||||||
|
// if (!isProviderSupported(config.providerId)) {
|
||||||
|
// throw new Error(`Unsupported provider: ${config.providerId}`)
|
||||||
|
// }
|
||||||
|
|
||||||
|
// 存储options供后续使用
|
||||||
|
// this.options = config.options
|
||||||
|
this.config = config
|
||||||
|
// 创建插件客户端
|
||||||
|
this.pluginEngine = new PluginEngine(config.providerId, config.plugins || [])
|
||||||
|
}
|
||||||
|
|
||||||
|
private createResolveModelPlugin(middlewares?: LanguageModelV2Middleware[]) {
|
||||||
|
return definePlugin({
|
||||||
|
name: '_internal_resolveModel',
|
||||||
|
enforce: 'post',
|
||||||
|
|
||||||
|
resolveModel: async (modelId: string) => {
|
||||||
|
// 注意:extraModelConfig 暂时不支持,已在新架构中移除
|
||||||
|
return await this.resolveModel(modelId, middlewares)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
private createResolveImageModelPlugin() {
|
||||||
|
return definePlugin({
|
||||||
|
name: '_internal_resolveImageModel',
|
||||||
|
enforce: 'post',
|
||||||
|
|
||||||
|
resolveModel: async (modelId: string) => {
|
||||||
|
return await this.resolveImageModel(modelId)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
private createConfigureContextPlugin() {
|
||||||
|
return definePlugin({
|
||||||
|
name: '_internal_configureContext',
|
||||||
|
configureContext: async (context: AiRequestContext) => {
|
||||||
|
context.executor = this
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// === 高阶重载:直接使用模型 ===
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 流式文本生成
|
||||||
|
*/
|
||||||
|
async streamText(
|
||||||
|
params: Parameters<typeof streamText>[0],
|
||||||
|
options?: {
|
||||||
|
middlewares?: LanguageModelV2Middleware[]
|
||||||
|
}
|
||||||
|
): Promise<ReturnType<typeof streamText>> {
|
||||||
|
const { model, ...restParams } = params
|
||||||
|
|
||||||
|
// 根据 model 类型决定插件配置
|
||||||
|
if (typeof model === 'string') {
|
||||||
|
this.pluginEngine.usePlugins([
|
||||||
|
this.createResolveModelPlugin(options?.middlewares),
|
||||||
|
this.createConfigureContextPlugin()
|
||||||
|
])
|
||||||
|
} else {
|
||||||
|
this.pluginEngine.usePlugins([this.createConfigureContextPlugin()])
|
||||||
|
}
|
||||||
|
|
||||||
|
return this.pluginEngine.executeStreamWithPlugins(
|
||||||
|
'streamText',
|
||||||
|
model,
|
||||||
|
restParams,
|
||||||
|
async (resolvedModel, transformedParams, streamTransforms) => {
|
||||||
|
const experimental_transform =
|
||||||
|
params?.experimental_transform ?? (streamTransforms.length > 0 ? streamTransforms : undefined)
|
||||||
|
|
||||||
|
const finalParams = {
|
||||||
|
model: resolvedModel,
|
||||||
|
...transformedParams,
|
||||||
|
experimental_transform
|
||||||
|
} as Parameters<typeof streamText>[0]
|
||||||
|
|
||||||
|
return await streamText(finalParams)
|
||||||
|
}
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// === 其他方法的重载 ===
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 生成文本
|
||||||
|
*/
|
||||||
|
async generateText(
|
||||||
|
params: Parameters<typeof generateText>[0],
|
||||||
|
options?: {
|
||||||
|
middlewares?: LanguageModelV2Middleware[]
|
||||||
|
}
|
||||||
|
): Promise<ReturnType<typeof generateText>> {
|
||||||
|
const { model, ...restParams } = params
|
||||||
|
|
||||||
|
// 根据 model 类型决定插件配置
|
||||||
|
if (typeof model === 'string') {
|
||||||
|
this.pluginEngine.usePlugins([
|
||||||
|
this.createResolveModelPlugin(options?.middlewares),
|
||||||
|
this.createConfigureContextPlugin()
|
||||||
|
])
|
||||||
|
} else {
|
||||||
|
this.pluginEngine.usePlugins([this.createConfigureContextPlugin()])
|
||||||
|
}
|
||||||
|
|
||||||
|
return this.pluginEngine.executeWithPlugins(
|
||||||
|
'generateText',
|
||||||
|
model,
|
||||||
|
restParams,
|
||||||
|
async (resolvedModel, transformedParams) =>
|
||||||
|
generateText({ model: resolvedModel, ...transformedParams } as Parameters<typeof generateText>[0])
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 生成结构化对象
|
||||||
|
*/
|
||||||
|
async generateObject(
|
||||||
|
params: Parameters<typeof generateObject>[0],
|
||||||
|
options?: {
|
||||||
|
middlewares?: LanguageModelV2Middleware[]
|
||||||
|
}
|
||||||
|
): Promise<ReturnType<typeof generateObject>> {
|
||||||
|
const { model, ...restParams } = params
|
||||||
|
|
||||||
|
// 根据 model 类型决定插件配置
|
||||||
|
if (typeof model === 'string') {
|
||||||
|
this.pluginEngine.usePlugins([
|
||||||
|
this.createResolveModelPlugin(options?.middlewares),
|
||||||
|
this.createConfigureContextPlugin()
|
||||||
|
])
|
||||||
|
} else {
|
||||||
|
this.pluginEngine.usePlugins([this.createConfigureContextPlugin()])
|
||||||
|
}
|
||||||
|
|
||||||
|
return this.pluginEngine.executeWithPlugins(
|
||||||
|
'generateObject',
|
||||||
|
model,
|
||||||
|
restParams,
|
||||||
|
async (resolvedModel, transformedParams) =>
|
||||||
|
generateObject({ model: resolvedModel, ...transformedParams } as Parameters<typeof generateObject>[0])
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 流式生成结构化对象
|
||||||
|
*/
|
||||||
|
async streamObject(
|
||||||
|
params: Parameters<typeof streamObject>[0],
|
||||||
|
options?: {
|
||||||
|
middlewares?: LanguageModelV2Middleware[]
|
||||||
|
}
|
||||||
|
): Promise<ReturnType<typeof streamObject>> {
|
||||||
|
const { model, ...restParams } = params
|
||||||
|
|
||||||
|
// 根据 model 类型决定插件配置
|
||||||
|
if (typeof model === 'string') {
|
||||||
|
this.pluginEngine.usePlugins([
|
||||||
|
this.createResolveModelPlugin(options?.middlewares),
|
||||||
|
this.createConfigureContextPlugin()
|
||||||
|
])
|
||||||
|
} else {
|
||||||
|
this.pluginEngine.usePlugins([this.createConfigureContextPlugin()])
|
||||||
|
}
|
||||||
|
|
||||||
|
return this.pluginEngine.executeWithPlugins(
|
||||||
|
'streamObject',
|
||||||
|
model,
|
||||||
|
restParams,
|
||||||
|
async (resolvedModel, transformedParams) =>
|
||||||
|
streamObject({ model: resolvedModel, ...transformedParams } as Parameters<typeof streamObject>[0])
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 生成图像
|
||||||
|
*/
|
||||||
|
async generateImage(
|
||||||
|
params: Omit<Parameters<typeof generateImage>[0], 'model'> & { model: string | ImageModelV2 }
|
||||||
|
): Promise<ReturnType<typeof generateImage>> {
|
||||||
|
try {
|
||||||
|
const { model, ...restParams } = params
|
||||||
|
|
||||||
|
// 根据 model 类型决定插件配置
|
||||||
|
if (typeof model === 'string') {
|
||||||
|
this.pluginEngine.usePlugins([this.createResolveImageModelPlugin(), this.createConfigureContextPlugin()])
|
||||||
|
} else {
|
||||||
|
this.pluginEngine.usePlugins([this.createConfigureContextPlugin()])
|
||||||
|
}
|
||||||
|
|
||||||
|
return await this.pluginEngine.executeImageWithPlugins(
|
||||||
|
'generateImage',
|
||||||
|
model,
|
||||||
|
restParams,
|
||||||
|
async (resolvedModel, transformedParams) => {
|
||||||
|
return await generateImage({ model: resolvedModel, ...transformedParams })
|
||||||
|
}
|
||||||
|
)
|
||||||
|
} catch (error) {
|
||||||
|
if (error instanceof Error) {
|
||||||
|
const modelId = typeof params.model === 'string' ? params.model : params.model.modelId
|
||||||
|
throw new ImageGenerationError(
|
||||||
|
`Failed to generate image: ${error.message}`,
|
||||||
|
this.config.providerId,
|
||||||
|
modelId,
|
||||||
|
error
|
||||||
|
)
|
||||||
|
}
|
||||||
|
throw error
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// === 辅助方法 ===
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 解析模型:如果是字符串则创建模型,如果是模型则直接返回
|
||||||
|
*/
|
||||||
|
private async resolveModel(
|
||||||
|
modelOrId: LanguageModel,
|
||||||
|
middlewares?: LanguageModelV2Middleware[]
|
||||||
|
): Promise<LanguageModelV2> {
|
||||||
|
if (typeof modelOrId === 'string') {
|
||||||
|
// 🎯 字符串modelId,使用新的ModelResolver解析,传递完整参数
|
||||||
|
return await globalModelResolver.resolveLanguageModel(
|
||||||
|
modelOrId, // 支持 'gpt-4' 和 'aihubmix:anthropic:claude-3.5-sonnet'
|
||||||
|
this.config.providerId, // fallback provider
|
||||||
|
this.config.providerSettings, // provider options
|
||||||
|
middlewares // 中间件数组
|
||||||
|
)
|
||||||
|
} else {
|
||||||
|
// 已经是模型,直接返回
|
||||||
|
return modelOrId
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 解析图像模型:如果是字符串则创建图像模型,如果是模型则直接返回
|
||||||
|
*/
|
||||||
|
private async resolveImageModel(modelOrId: ImageModelV2 | string): Promise<ImageModelV2> {
|
||||||
|
try {
|
||||||
|
if (typeof modelOrId === 'string') {
|
||||||
|
// 字符串modelId,使用新的ModelResolver解析
|
||||||
|
return await globalModelResolver.resolveImageModel(
|
||||||
|
modelOrId, // 支持 'dall-e-3' 和 'aihubmix:openai:dall-e-3'
|
||||||
|
this.config.providerId // fallback provider
|
||||||
|
)
|
||||||
|
} else {
|
||||||
|
// 已经是模型,直接返回
|
||||||
|
return modelOrId
|
||||||
|
}
|
||||||
|
} catch (error) {
|
||||||
|
throw new ImageModelResolutionError(
|
||||||
|
typeof modelOrId === 'string' ? modelOrId : modelOrId.modelId,
|
||||||
|
this.config.providerId,
|
||||||
|
error instanceof Error ? error : undefined
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// === 静态工厂方法 ===
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 创建执行器 - 支持已知provider的类型安全
|
||||||
|
*/
|
||||||
|
static create<T extends ProviderId>(
|
||||||
|
providerId: T,
|
||||||
|
options: ModelConfig<T>['providerSettings'],
|
||||||
|
plugins?: AiPlugin[]
|
||||||
|
): RuntimeExecutor<T> {
|
||||||
|
return new RuntimeExecutor({
|
||||||
|
providerId,
|
||||||
|
providerSettings: options,
|
||||||
|
plugins
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 创建OpenAI Compatible执行器
|
||||||
|
*/
|
||||||
|
static createOpenAICompatible(
|
||||||
|
options: ModelConfig<'openai-compatible'>['providerSettings'],
|
||||||
|
plugins: AiPlugin[] = []
|
||||||
|
): RuntimeExecutor<'openai-compatible'> {
|
||||||
|
return new RuntimeExecutor({
|
||||||
|
providerId: 'openai-compatible',
|
||||||
|
providerSettings: options,
|
||||||
|
plugins
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
117
packages/aiCore/src/core/runtime/index.ts
Normal file
117
packages/aiCore/src/core/runtime/index.ts
Normal file
@@ -0,0 +1,117 @@
|
|||||||
|
/**
|
||||||
|
* Runtime 模块导出
|
||||||
|
* 专注于运行时插件化AI调用处理
|
||||||
|
*/
|
||||||
|
|
||||||
|
// 主要的运行时执行器
|
||||||
|
export { RuntimeExecutor } from './executor'
|
||||||
|
|
||||||
|
// 导出类型
|
||||||
|
export type { RuntimeConfig } from './types'
|
||||||
|
|
||||||
|
// === 便捷工厂函数 ===
|
||||||
|
|
||||||
|
import { LanguageModelV2Middleware } from '@ai-sdk/provider'
|
||||||
|
|
||||||
|
import { type AiPlugin } from '../plugins'
|
||||||
|
import { type ProviderId, type ProviderSettingsMap } from '../providers/types'
|
||||||
|
import { RuntimeExecutor } from './executor'
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 创建运行时执行器 - 支持类型安全的已知provider
|
||||||
|
*/
|
||||||
|
export function createExecutor<T extends ProviderId>(
|
||||||
|
providerId: T,
|
||||||
|
options: ProviderSettingsMap[T] & { mode?: 'chat' | 'responses' },
|
||||||
|
plugins?: AiPlugin[]
|
||||||
|
): RuntimeExecutor<T> {
|
||||||
|
return RuntimeExecutor.create(providerId, options, plugins)
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 创建OpenAI Compatible执行器
|
||||||
|
*/
|
||||||
|
export function createOpenAICompatibleExecutor(
|
||||||
|
options: ProviderSettingsMap['openai-compatible'] & { mode?: 'chat' | 'responses' },
|
||||||
|
plugins: AiPlugin[] = []
|
||||||
|
): RuntimeExecutor<'openai-compatible'> {
|
||||||
|
return RuntimeExecutor.createOpenAICompatible(options, plugins)
|
||||||
|
}
|
||||||
|
|
||||||
|
// === 直接调用API(无需创建executor实例)===
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 直接流式文本生成 - 支持middlewares
|
||||||
|
*/
|
||||||
|
export async function streamText<T extends ProviderId>(
|
||||||
|
providerId: T,
|
||||||
|
options: ProviderSettingsMap[T] & { mode?: 'chat' | 'responses' },
|
||||||
|
params: Parameters<RuntimeExecutor<T>['streamText']>[0],
|
||||||
|
plugins?: AiPlugin[],
|
||||||
|
middlewares?: LanguageModelV2Middleware[]
|
||||||
|
): Promise<ReturnType<RuntimeExecutor<T>['streamText']>> {
|
||||||
|
const executor = createExecutor(providerId, options, plugins)
|
||||||
|
return executor.streamText(params, { middlewares })
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 直接生成文本 - 支持middlewares
|
||||||
|
*/
|
||||||
|
export async function generateText<T extends ProviderId>(
|
||||||
|
providerId: T,
|
||||||
|
options: ProviderSettingsMap[T] & { mode?: 'chat' | 'responses' },
|
||||||
|
params: Parameters<RuntimeExecutor<T>['generateText']>[0],
|
||||||
|
plugins?: AiPlugin[],
|
||||||
|
middlewares?: LanguageModelV2Middleware[]
|
||||||
|
): Promise<ReturnType<RuntimeExecutor<T>['generateText']>> {
|
||||||
|
const executor = createExecutor(providerId, options, plugins)
|
||||||
|
return executor.generateText(params, { middlewares })
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 直接生成结构化对象 - 支持middlewares
|
||||||
|
*/
|
||||||
|
export async function generateObject<T extends ProviderId>(
|
||||||
|
providerId: T,
|
||||||
|
options: ProviderSettingsMap[T] & { mode?: 'chat' | 'responses' },
|
||||||
|
params: Parameters<RuntimeExecutor<T>['generateObject']>[0],
|
||||||
|
plugins?: AiPlugin[],
|
||||||
|
middlewares?: LanguageModelV2Middleware[]
|
||||||
|
): Promise<ReturnType<RuntimeExecutor<T>['generateObject']>> {
|
||||||
|
const executor = createExecutor(providerId, options, plugins)
|
||||||
|
return executor.generateObject(params, { middlewares })
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 直接流式生成结构化对象 - 支持middlewares
|
||||||
|
*/
|
||||||
|
export async function streamObject<T extends ProviderId>(
|
||||||
|
providerId: T,
|
||||||
|
options: ProviderSettingsMap[T] & { mode?: 'chat' | 'responses' },
|
||||||
|
params: Parameters<RuntimeExecutor<T>['streamObject']>[0],
|
||||||
|
plugins?: AiPlugin[],
|
||||||
|
middlewares?: LanguageModelV2Middleware[]
|
||||||
|
): Promise<ReturnType<RuntimeExecutor<T>['streamObject']>> {
|
||||||
|
const executor = createExecutor(providerId, options, plugins)
|
||||||
|
return executor.streamObject(params, { middlewares })
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 直接生成图像 - 支持middlewares
|
||||||
|
*/
|
||||||
|
export async function generateImage<T extends ProviderId>(
|
||||||
|
providerId: T,
|
||||||
|
options: ProviderSettingsMap[T] & { mode?: 'chat' | 'responses' },
|
||||||
|
params: Parameters<RuntimeExecutor<T>['generateImage']>[0],
|
||||||
|
plugins?: AiPlugin[]
|
||||||
|
): Promise<ReturnType<RuntimeExecutor<T>['generateImage']>> {
|
||||||
|
const executor = createExecutor(providerId, options, plugins)
|
||||||
|
return executor.generateImage(params)
|
||||||
|
}
|
||||||
|
|
||||||
|
// === Agent 功能预留 ===
|
||||||
|
// 未来将在 ../agents/ 文件夹中添加:
|
||||||
|
// - AgentExecutor.ts
|
||||||
|
// - WorkflowManager.ts
|
||||||
|
// - ConversationManager.ts
|
||||||
|
// 并在此处导出相关API
|
||||||
290
packages/aiCore/src/core/runtime/pluginEngine.ts
Normal file
290
packages/aiCore/src/core/runtime/pluginEngine.ts
Normal file
@@ -0,0 +1,290 @@
|
|||||||
|
/* eslint-disable @eslint-react/naming-convention/context-name */
|
||||||
|
import { ImageModelV2 } from '@ai-sdk/provider'
|
||||||
|
import { LanguageModel } from 'ai'
|
||||||
|
|
||||||
|
import { type AiPlugin, createContext, PluginManager } from '../plugins'
|
||||||
|
import { type ProviderId } from '../providers/types'
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 插件增强的 AI 客户端
|
||||||
|
* 专注于插件处理,不暴露用户API
|
||||||
|
*/
|
||||||
|
export class PluginEngine<T extends ProviderId = ProviderId> {
|
||||||
|
private pluginManager: PluginManager
|
||||||
|
|
||||||
|
constructor(
|
||||||
|
private readonly providerId: T,
|
||||||
|
// private readonly options: ProviderSettingsMap[T],
|
||||||
|
plugins: AiPlugin[] = []
|
||||||
|
) {
|
||||||
|
this.pluginManager = new PluginManager(plugins)
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 添加插件
|
||||||
|
*/
|
||||||
|
use(plugin: AiPlugin): this {
|
||||||
|
this.pluginManager.use(plugin)
|
||||||
|
return this
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 批量添加插件
|
||||||
|
*/
|
||||||
|
usePlugins(plugins: AiPlugin[]): this {
|
||||||
|
plugins.forEach((plugin) => this.use(plugin))
|
||||||
|
return this
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 移除插件
|
||||||
|
*/
|
||||||
|
removePlugin(pluginName: string): this {
|
||||||
|
this.pluginManager.remove(pluginName)
|
||||||
|
return this
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 获取插件统计
|
||||||
|
*/
|
||||||
|
getPluginStats() {
|
||||||
|
return this.pluginManager.getStats()
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 获取所有插件
|
||||||
|
*/
|
||||||
|
getPlugins() {
|
||||||
|
return this.pluginManager.getPlugins()
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 执行带插件的操作(非流式)
|
||||||
|
* 提供给AiExecutor使用
|
||||||
|
*/
|
||||||
|
async executeWithPlugins<TParams, TResult>(
|
||||||
|
methodName: string,
|
||||||
|
model: LanguageModel,
|
||||||
|
params: TParams,
|
||||||
|
executor: (model: LanguageModel, transformedParams: TParams) => Promise<TResult>,
|
||||||
|
_context?: ReturnType<typeof createContext>
|
||||||
|
): Promise<TResult> {
|
||||||
|
// 统一处理模型解析
|
||||||
|
let resolvedModel: LanguageModel | undefined
|
||||||
|
let modelId: string
|
||||||
|
|
||||||
|
if (typeof model === 'string') {
|
||||||
|
// 字符串:需要通过插件解析
|
||||||
|
modelId = model
|
||||||
|
} else {
|
||||||
|
// 模型对象:直接使用
|
||||||
|
resolvedModel = model
|
||||||
|
modelId = model.modelId
|
||||||
|
}
|
||||||
|
|
||||||
|
// 使用正确的createContext创建请求上下文
|
||||||
|
const context = _context ? _context : createContext(this.providerId, modelId, params)
|
||||||
|
|
||||||
|
// 🔥 为上下文添加递归调用能力
|
||||||
|
context.recursiveCall = async (newParams: any): Promise<TResult> => {
|
||||||
|
// 递归调用自身,重新走完整的插件流程
|
||||||
|
context.isRecursiveCall = true
|
||||||
|
const result = await this.executeWithPlugins(methodName, model, newParams, executor, context)
|
||||||
|
context.isRecursiveCall = false
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
try {
|
||||||
|
// 0. 配置上下文
|
||||||
|
await this.pluginManager.executeConfigureContext(context)
|
||||||
|
|
||||||
|
// 1. 触发请求开始事件
|
||||||
|
await this.pluginManager.executeParallel('onRequestStart', context)
|
||||||
|
|
||||||
|
// 2. 解析模型(如果是字符串)
|
||||||
|
if (typeof model === 'string') {
|
||||||
|
const resolved = await this.pluginManager.executeFirst<LanguageModel>('resolveModel', modelId, context)
|
||||||
|
if (!resolved) {
|
||||||
|
throw new Error(`Failed to resolve model: ${modelId}`)
|
||||||
|
}
|
||||||
|
resolvedModel = resolved
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!resolvedModel) {
|
||||||
|
throw new Error(`Model resolution failed: no model available`)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 3. 转换请求参数
|
||||||
|
const transformedParams = await this.pluginManager.executeSequential('transformParams', params, context)
|
||||||
|
|
||||||
|
// 4. 执行具体的 API 调用
|
||||||
|
const result = await executor(resolvedModel, transformedParams)
|
||||||
|
|
||||||
|
// 5. 转换结果(对于非流式调用)
|
||||||
|
const transformedResult = await this.pluginManager.executeSequential('transformResult', result, context)
|
||||||
|
|
||||||
|
// 6. 触发完成事件
|
||||||
|
await this.pluginManager.executeParallel('onRequestEnd', context, transformedResult)
|
||||||
|
|
||||||
|
return transformedResult
|
||||||
|
} catch (error) {
|
||||||
|
// 7. 触发错误事件
|
||||||
|
await this.pluginManager.executeParallel('onError', context, undefined, error as Error)
|
||||||
|
throw error
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 执行带插件的图像生成操作
|
||||||
|
* 提供给AiExecutor使用
|
||||||
|
*/
|
||||||
|
async executeImageWithPlugins<TParams, TResult>(
|
||||||
|
methodName: string,
|
||||||
|
model: ImageModelV2 | string,
|
||||||
|
params: TParams,
|
||||||
|
executor: (model: ImageModelV2, transformedParams: TParams) => Promise<TResult>,
|
||||||
|
_context?: ReturnType<typeof createContext>
|
||||||
|
): Promise<TResult> {
|
||||||
|
// 统一处理模型解析
|
||||||
|
let resolvedModel: ImageModelV2 | undefined
|
||||||
|
let modelId: string
|
||||||
|
|
||||||
|
if (typeof model === 'string') {
|
||||||
|
// 字符串:需要通过插件解析
|
||||||
|
modelId = model
|
||||||
|
} else {
|
||||||
|
// 模型对象:直接使用
|
||||||
|
resolvedModel = model
|
||||||
|
modelId = model.modelId
|
||||||
|
}
|
||||||
|
|
||||||
|
// 使用正确的createContext创建请求上下文
|
||||||
|
const context = _context ? _context : createContext(this.providerId, modelId, params)
|
||||||
|
|
||||||
|
// 🔥 为上下文添加递归调用能力
|
||||||
|
context.recursiveCall = async (newParams: any): Promise<TResult> => {
|
||||||
|
// 递归调用自身,重新走完整的插件流程
|
||||||
|
context.isRecursiveCall = true
|
||||||
|
const result = await this.executeImageWithPlugins(methodName, model, newParams, executor, context)
|
||||||
|
context.isRecursiveCall = false
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
try {
|
||||||
|
// 0. 配置上下文
|
||||||
|
await this.pluginManager.executeConfigureContext(context)
|
||||||
|
|
||||||
|
// 1. 触发请求开始事件
|
||||||
|
await this.pluginManager.executeParallel('onRequestStart', context)
|
||||||
|
|
||||||
|
// 2. 解析模型(如果是字符串)
|
||||||
|
if (typeof model === 'string') {
|
||||||
|
const resolved = await this.pluginManager.executeFirst<ImageModelV2>('resolveModel', modelId, context)
|
||||||
|
if (!resolved) {
|
||||||
|
throw new Error(`Failed to resolve image model: ${modelId}`)
|
||||||
|
}
|
||||||
|
resolvedModel = resolved
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!resolvedModel) {
|
||||||
|
throw new Error(`Image model resolution failed: no model available`)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 3. 转换请求参数
|
||||||
|
const transformedParams = await this.pluginManager.executeSequential('transformParams', params, context)
|
||||||
|
|
||||||
|
// 4. 执行具体的 API 调用
|
||||||
|
const result = await executor(resolvedModel, transformedParams)
|
||||||
|
|
||||||
|
// 5. 转换结果
|
||||||
|
const transformedResult = await this.pluginManager.executeSequential('transformResult', result, context)
|
||||||
|
|
||||||
|
// 6. 触发完成事件
|
||||||
|
await this.pluginManager.executeParallel('onRequestEnd', context, transformedResult)
|
||||||
|
|
||||||
|
return transformedResult
|
||||||
|
} catch (error) {
|
||||||
|
// 7. 触发错误事件
|
||||||
|
await this.pluginManager.executeParallel('onError', context, undefined, error as Error)
|
||||||
|
throw error
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 执行流式调用的通用逻辑(支持流转换器)
|
||||||
|
* 提供给AiExecutor使用
|
||||||
|
*/
|
||||||
|
async executeStreamWithPlugins<TParams, TResult>(
|
||||||
|
methodName: string,
|
||||||
|
model: LanguageModel,
|
||||||
|
params: TParams,
|
||||||
|
executor: (model: LanguageModel, transformedParams: TParams, streamTransforms: any[]) => Promise<TResult>,
|
||||||
|
_context?: ReturnType<typeof createContext>
|
||||||
|
): Promise<TResult> {
|
||||||
|
// 统一处理模型解析
|
||||||
|
let resolvedModel: LanguageModel | undefined
|
||||||
|
let modelId: string
|
||||||
|
|
||||||
|
if (typeof model === 'string') {
|
||||||
|
// 字符串:需要通过插件解析
|
||||||
|
modelId = model
|
||||||
|
} else {
|
||||||
|
// 模型对象:直接使用
|
||||||
|
resolvedModel = model
|
||||||
|
modelId = model.modelId
|
||||||
|
}
|
||||||
|
|
||||||
|
// 创建请求上下文
|
||||||
|
const context = _context ? _context : createContext(this.providerId, modelId, params)
|
||||||
|
|
||||||
|
// 🔥 为上下文添加递归调用能力
|
||||||
|
context.recursiveCall = async (newParams: any): Promise<TResult> => {
|
||||||
|
// 递归调用自身,重新走完整的插件流程
|
||||||
|
context.isRecursiveCall = true
|
||||||
|
const result = await this.executeStreamWithPlugins(methodName, model, newParams, executor, context)
|
||||||
|
context.isRecursiveCall = false
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
try {
|
||||||
|
// 0. 配置上下文
|
||||||
|
await this.pluginManager.executeConfigureContext(context)
|
||||||
|
|
||||||
|
// 1. 触发请求开始事件
|
||||||
|
await this.pluginManager.executeParallel('onRequestStart', context)
|
||||||
|
|
||||||
|
// 2. 解析模型(如果是字符串)
|
||||||
|
if (typeof model === 'string') {
|
||||||
|
const resolved = await this.pluginManager.executeFirst<LanguageModel>('resolveModel', modelId, context)
|
||||||
|
if (!resolved) {
|
||||||
|
throw new Error(`Failed to resolve model: ${modelId}`)
|
||||||
|
}
|
||||||
|
resolvedModel = resolved
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!resolvedModel) {
|
||||||
|
throw new Error(`Model resolution failed: no model available`)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 3. 转换请求参数
|
||||||
|
const transformedParams = await this.pluginManager.executeSequential('transformParams', params, context)
|
||||||
|
|
||||||
|
// 4. 收集流转换器
|
||||||
|
const streamTransforms = this.pluginManager.collectStreamTransforms(transformedParams, context)
|
||||||
|
|
||||||
|
// 5. 执行流式 API 调用
|
||||||
|
const result = await executor(resolvedModel, transformedParams, streamTransforms)
|
||||||
|
|
||||||
|
const transformedResult = await this.pluginManager.executeSequential('transformResult', result, context)
|
||||||
|
|
||||||
|
// 6. 触发完成事件(注意:对于流式调用,这里触发的是开始流式响应的事件)
|
||||||
|
await this.pluginManager.executeParallel('onRequestEnd', context, transformedResult)
|
||||||
|
|
||||||
|
return transformedResult
|
||||||
|
} catch (error) {
|
||||||
|
// 7. 触发错误事件
|
||||||
|
await this.pluginManager.executeParallel('onError', context, undefined, error as Error)
|
||||||
|
throw error
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
15
packages/aiCore/src/core/runtime/types.ts
Normal file
15
packages/aiCore/src/core/runtime/types.ts
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
/**
|
||||||
|
* Runtime 层类型定义
|
||||||
|
*/
|
||||||
|
import { type ModelConfig } from '../models/types'
|
||||||
|
import { type AiPlugin } from '../plugins'
|
||||||
|
import { type ProviderId } from '../providers/types'
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 运行时执行器配置
|
||||||
|
*/
|
||||||
|
export interface RuntimeConfig<T extends ProviderId = ProviderId> {
|
||||||
|
providerId: T
|
||||||
|
providerSettings: ModelConfig<T>['providerSettings'] & { mode?: 'chat' | 'responses' }
|
||||||
|
plugins?: AiPlugin[]
|
||||||
|
}
|
||||||
46
packages/aiCore/src/index.ts
Normal file
46
packages/aiCore/src/index.ts
Normal file
@@ -0,0 +1,46 @@
|
|||||||
|
/**
|
||||||
|
* Cherry Studio AI Core Package
|
||||||
|
* 基于 Vercel AI SDK 的统一 AI Provider 接口
|
||||||
|
*/
|
||||||
|
|
||||||
|
// 导入内部使用的类和函数
|
||||||
|
|
||||||
|
// ==================== 主要用户接口 ====================
|
||||||
|
export {
|
||||||
|
createExecutor,
|
||||||
|
createOpenAICompatibleExecutor,
|
||||||
|
generateImage,
|
||||||
|
generateObject,
|
||||||
|
generateText,
|
||||||
|
streamText
|
||||||
|
} from './core/runtime'
|
||||||
|
|
||||||
|
// ==================== 高级API ====================
|
||||||
|
export { globalModelResolver as modelResolver } from './core/models'
|
||||||
|
|
||||||
|
// ==================== 插件系统 ====================
|
||||||
|
export type { AiPlugin, AiRequestContext, HookResult, PluginManagerConfig } from './core/plugins'
|
||||||
|
export { createContext, definePlugin, PluginManager } from './core/plugins'
|
||||||
|
// export { createPromptToolUsePlugin, webSearchPlugin } from './core/plugins/built-in'
|
||||||
|
export { PluginEngine } from './core/runtime/pluginEngine'
|
||||||
|
|
||||||
|
// ==================== AI SDK 常用类型导出 ====================
|
||||||
|
// 直接导出 AI SDK 的常用类型,方便使用
|
||||||
|
export type { LanguageModelV2Middleware, LanguageModelV2StreamPart } from '@ai-sdk/provider'
|
||||||
|
export type { ToolCall } from '@ai-sdk/provider-utils'
|
||||||
|
export type { ReasoningPart } from '@ai-sdk/provider-utils'
|
||||||
|
|
||||||
|
// ==================== 选项 ====================
|
||||||
|
export {
|
||||||
|
createAnthropicOptions,
|
||||||
|
createGoogleOptions,
|
||||||
|
createOpenAIOptions,
|
||||||
|
type ExtractProviderOptions,
|
||||||
|
mergeProviderOptions,
|
||||||
|
type ProviderOptionsMap,
|
||||||
|
type TypedProviderOptions
|
||||||
|
} from './core/options'
|
||||||
|
|
||||||
|
// ==================== 包信息 ====================
|
||||||
|
export const AI_CORE_VERSION = '1.0.0'
|
||||||
|
export const AI_CORE_NAME = '@cherrystudio/ai-core'
|
||||||
2
packages/aiCore/src/types.ts
Normal file
2
packages/aiCore/src/types.ts
Normal file
@@ -0,0 +1,2 @@
|
|||||||
|
// 重新导出插件类型
|
||||||
|
export type { AiPlugin, AiRequestContext, HookResult, PluginManagerConfig } from './core/plugins/types'
|
||||||
26
packages/aiCore/tsconfig.json
Normal file
26
packages/aiCore/tsconfig.json
Normal file
@@ -0,0 +1,26 @@
|
|||||||
|
{
|
||||||
|
"compilerOptions": {
|
||||||
|
"target": "ES2020",
|
||||||
|
"module": "ESNext",
|
||||||
|
"moduleResolution": "bundler",
|
||||||
|
"declaration": true,
|
||||||
|
"outDir": "./dist",
|
||||||
|
"rootDir": "./src",
|
||||||
|
"strict": true,
|
||||||
|
"esModuleInterop": true,
|
||||||
|
"skipLibCheck": true,
|
||||||
|
"forceConsistentCasingInFileNames": true,
|
||||||
|
"resolveJsonModule": true,
|
||||||
|
"allowSyntheticDefaultImports": true,
|
||||||
|
"noEmitOnError": false,
|
||||||
|
"experimentalDecorators": true,
|
||||||
|
"emitDecoratorMetadata": true
|
||||||
|
},
|
||||||
|
"include": [
|
||||||
|
"src/**/*"
|
||||||
|
],
|
||||||
|
"exclude": [
|
||||||
|
"node_modules",
|
||||||
|
"dist"
|
||||||
|
]
|
||||||
|
}
|
||||||
14
packages/aiCore/tsdown.config.ts
Normal file
14
packages/aiCore/tsdown.config.ts
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
import { defineConfig } from 'tsdown'
|
||||||
|
|
||||||
|
export default defineConfig({
|
||||||
|
entry: {
|
||||||
|
index: 'src/index.ts',
|
||||||
|
'built-in/plugins/index': 'src/core/plugins/built-in/index.ts',
|
||||||
|
'provider/index': 'src/core/providers/index.ts'
|
||||||
|
},
|
||||||
|
outDir: 'dist',
|
||||||
|
format: ['esm', 'cjs'],
|
||||||
|
clean: true,
|
||||||
|
dts: true,
|
||||||
|
tsconfig: 'tsconfig.json'
|
||||||
|
})
|
||||||
15
packages/aiCore/vitest.config.ts
Normal file
15
packages/aiCore/vitest.config.ts
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
import { defineConfig } from 'vitest/config'
|
||||||
|
|
||||||
|
export default defineConfig({
|
||||||
|
test: {
|
||||||
|
globals: true
|
||||||
|
},
|
||||||
|
resolve: {
|
||||||
|
alias: {
|
||||||
|
'@': './src'
|
||||||
|
}
|
||||||
|
},
|
||||||
|
esbuild: {
|
||||||
|
target: 'node18'
|
||||||
|
}
|
||||||
|
})
|
||||||
@@ -570,7 +570,8 @@ class McpService {
|
|||||||
...tool,
|
...tool,
|
||||||
id: buildFunctionCallToolName(server.name, tool.name),
|
id: buildFunctionCallToolName(server.name, tool.name),
|
||||||
serverId: server.id,
|
serverId: server.id,
|
||||||
serverName: server.name
|
serverName: server.name,
|
||||||
|
type: 'mcp'
|
||||||
}
|
}
|
||||||
serverTools.push(serverTool)
|
serverTools.push(serverTool)
|
||||||
})
|
})
|
||||||
|
|||||||
314
src/renderer/src/aiCore/chunk/AiSdkToChunkAdapter.ts
Normal file
314
src/renderer/src/aiCore/chunk/AiSdkToChunkAdapter.ts
Normal file
@@ -0,0 +1,314 @@
|
|||||||
|
/**
|
||||||
|
* AI SDK 到 Cherry Studio Chunk 适配器
|
||||||
|
* 用于将 AI SDK 的 fullStream 转换为 Cherry Studio 的 chunk 格式
|
||||||
|
*/
|
||||||
|
|
||||||
|
import { loggerService } from '@logger'
|
||||||
|
import { MCPTool, WebSearchResults, WebSearchSource } from '@renderer/types'
|
||||||
|
import { Chunk, ChunkType } from '@renderer/types/chunk'
|
||||||
|
import type { TextStreamPart, ToolSet } from 'ai'
|
||||||
|
|
||||||
|
import { ToolCallChunkHandler } from './handleToolCallChunk'
|
||||||
|
|
||||||
|
const logger = loggerService.withContext('AiSdkToChunkAdapter')
|
||||||
|
|
||||||
|
export interface CherryStudioChunk {
|
||||||
|
type: 'text-delta' | 'text-complete' | 'tool-call' | 'tool-result' | 'finish' | 'error'
|
||||||
|
text?: string
|
||||||
|
toolCall?: any
|
||||||
|
toolResult?: any
|
||||||
|
finishReason?: string
|
||||||
|
usage?: any
|
||||||
|
error?: any
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* AI SDK 到 Cherry Studio Chunk 适配器类
|
||||||
|
* 处理 fullStream 到 Cherry Studio chunk 的转换
|
||||||
|
*/
|
||||||
|
export class AiSdkToChunkAdapter {
|
||||||
|
toolCallHandler: ToolCallChunkHandler
|
||||||
|
private accumulate: boolean | undefined
|
||||||
|
constructor(
|
||||||
|
private onChunk: (chunk: Chunk) => void,
|
||||||
|
mcpTools: MCPTool[] = [],
|
||||||
|
accumulate?: boolean
|
||||||
|
) {
|
||||||
|
this.toolCallHandler = new ToolCallChunkHandler(onChunk, mcpTools)
|
||||||
|
this.accumulate = accumulate
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 处理 AI SDK 流结果
|
||||||
|
* @param aiSdkResult AI SDK 的流结果对象
|
||||||
|
* @returns 最终的文本内容
|
||||||
|
*/
|
||||||
|
async processStream(aiSdkResult: any): Promise<string> {
|
||||||
|
// 如果是流式且有 fullStream
|
||||||
|
if (aiSdkResult.fullStream) {
|
||||||
|
await this.readFullStream(aiSdkResult.fullStream)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 使用 streamResult.text 获取最终结果
|
||||||
|
return await aiSdkResult.text
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 读取 fullStream 并转换为 Cherry Studio chunks
|
||||||
|
* @param fullStream AI SDK 的 fullStream (ReadableStream)
|
||||||
|
*/
|
||||||
|
private async readFullStream(fullStream: ReadableStream<TextStreamPart<ToolSet>>) {
|
||||||
|
const reader = fullStream.getReader()
|
||||||
|
const final = {
|
||||||
|
text: '',
|
||||||
|
reasoningContent: '',
|
||||||
|
webSearchResults: [],
|
||||||
|
reasoningId: ''
|
||||||
|
}
|
||||||
|
try {
|
||||||
|
while (true) {
|
||||||
|
const { done, value } = await reader.read()
|
||||||
|
|
||||||
|
if (done) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
// 转换并发送 chunk
|
||||||
|
this.convertAndEmitChunk(value, final)
|
||||||
|
}
|
||||||
|
} finally {
|
||||||
|
reader.releaseLock()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 转换 AI SDK chunk 为 Cherry Studio chunk 并调用回调
|
||||||
|
* @param chunk AI SDK 的 chunk 数据
|
||||||
|
*/
|
||||||
|
private convertAndEmitChunk(
|
||||||
|
chunk: TextStreamPart<any>,
|
||||||
|
final: { text: string; reasoningContent: string; webSearchResults: any[]; reasoningId: string }
|
||||||
|
) {
|
||||||
|
logger.info(`AI SDK chunk type: ${chunk.type}`, chunk)
|
||||||
|
switch (chunk.type) {
|
||||||
|
// === 文本相关事件 ===
|
||||||
|
case 'text-start':
|
||||||
|
this.onChunk({
|
||||||
|
type: ChunkType.TEXT_START
|
||||||
|
})
|
||||||
|
break
|
||||||
|
case 'text-delta':
|
||||||
|
if (this.accumulate) {
|
||||||
|
final.text += chunk.text || ''
|
||||||
|
} else {
|
||||||
|
final.text = chunk.text || ''
|
||||||
|
}
|
||||||
|
this.onChunk({
|
||||||
|
type: ChunkType.TEXT_DELTA,
|
||||||
|
text: final.text || ''
|
||||||
|
})
|
||||||
|
break
|
||||||
|
case 'text-end':
|
||||||
|
this.onChunk({
|
||||||
|
type: ChunkType.TEXT_COMPLETE,
|
||||||
|
text: (chunk.providerMetadata?.text?.value as string) ?? final.text ?? ''
|
||||||
|
})
|
||||||
|
final.text = ''
|
||||||
|
break
|
||||||
|
case 'reasoning-start':
|
||||||
|
// if (final.reasoningId !== chunk.id) {
|
||||||
|
final.reasoningId = chunk.id
|
||||||
|
this.onChunk({
|
||||||
|
type: ChunkType.THINKING_START
|
||||||
|
})
|
||||||
|
// }
|
||||||
|
break
|
||||||
|
case 'reasoning-delta':
|
||||||
|
final.reasoningContent += chunk.text || ''
|
||||||
|
this.onChunk({
|
||||||
|
type: ChunkType.THINKING_DELTA,
|
||||||
|
text: final.reasoningContent || '',
|
||||||
|
thinking_millsec: (chunk.providerMetadata?.metadata?.thinking_millsec as number) || 0
|
||||||
|
})
|
||||||
|
break
|
||||||
|
case 'reasoning-end':
|
||||||
|
this.onChunk({
|
||||||
|
type: ChunkType.THINKING_COMPLETE,
|
||||||
|
text: (chunk.providerMetadata?.metadata?.thinking_content as string) || '',
|
||||||
|
thinking_millsec: (chunk.providerMetadata?.metadata?.thinking_millsec as number) || 0
|
||||||
|
})
|
||||||
|
final.reasoningContent = ''
|
||||||
|
break
|
||||||
|
|
||||||
|
// === 工具调用相关事件(原始 AI SDK 事件,如果没有被中间件处理) ===
|
||||||
|
|
||||||
|
// case 'tool-input-start':
|
||||||
|
// case 'tool-input-delta':
|
||||||
|
// case 'tool-input-end':
|
||||||
|
// this.toolCallHandler.handleToolCallCreated(chunk)
|
||||||
|
// break
|
||||||
|
|
||||||
|
// case 'tool-input-delta':
|
||||||
|
// this.toolCallHandler.handleToolCallCreated(chunk)
|
||||||
|
// break
|
||||||
|
case 'tool-call':
|
||||||
|
// 原始的工具调用(未被中间件处理)
|
||||||
|
this.toolCallHandler.handleToolCall(chunk)
|
||||||
|
break
|
||||||
|
|
||||||
|
case 'tool-result':
|
||||||
|
// 原始的工具调用结果(未被中间件处理)
|
||||||
|
this.toolCallHandler.handleToolResult(chunk)
|
||||||
|
break
|
||||||
|
|
||||||
|
// === 步骤相关事件 ===
|
||||||
|
// case 'start':
|
||||||
|
// this.onChunk({
|
||||||
|
// type: ChunkType.LLM_RESPONSE_CREATED
|
||||||
|
// })
|
||||||
|
// break
|
||||||
|
// TODO: 需要区分接口开始和步骤开始
|
||||||
|
// case 'start-step':
|
||||||
|
// this.onChunk({
|
||||||
|
// type: ChunkType.BLOCK_CREATED
|
||||||
|
// })
|
||||||
|
// break
|
||||||
|
// case 'step-finish':
|
||||||
|
// this.onChunk({
|
||||||
|
// type: ChunkType.TEXT_COMPLETE,
|
||||||
|
// text: final.text || '' // TEXT_COMPLETE 需要 text 字段
|
||||||
|
// })
|
||||||
|
// final.text = ''
|
||||||
|
// break
|
||||||
|
|
||||||
|
case 'finish-step': {
|
||||||
|
const { providerMetadata, finishReason } = chunk
|
||||||
|
// googel web search
|
||||||
|
if (providerMetadata?.google?.groundingMetadata) {
|
||||||
|
this.onChunk({
|
||||||
|
type: ChunkType.LLM_WEB_SEARCH_COMPLETE,
|
||||||
|
llm_web_search: {
|
||||||
|
results: providerMetadata.google?.groundingMetadata as WebSearchResults,
|
||||||
|
source: WebSearchSource.GEMINI
|
||||||
|
}
|
||||||
|
})
|
||||||
|
} else if (final.webSearchResults.length) {
|
||||||
|
const providerName = Object.keys(providerMetadata || {})[0]
|
||||||
|
const sourceMap: Record<string, WebSearchSource> = {
|
||||||
|
[WebSearchSource.OPENAI]: WebSearchSource.OPENAI_RESPONSE,
|
||||||
|
[WebSearchSource.ANTHROPIC]: WebSearchSource.ANTHROPIC,
|
||||||
|
[WebSearchSource.OPENROUTER]: WebSearchSource.OPENROUTER,
|
||||||
|
[WebSearchSource.GEMINI]: WebSearchSource.GEMINI,
|
||||||
|
[WebSearchSource.PERPLEXITY]: WebSearchSource.PERPLEXITY,
|
||||||
|
[WebSearchSource.QWEN]: WebSearchSource.QWEN,
|
||||||
|
[WebSearchSource.HUNYUAN]: WebSearchSource.HUNYUAN,
|
||||||
|
[WebSearchSource.ZHIPU]: WebSearchSource.ZHIPU,
|
||||||
|
[WebSearchSource.GROK]: WebSearchSource.GROK,
|
||||||
|
[WebSearchSource.WEBSEARCH]: WebSearchSource.WEBSEARCH
|
||||||
|
}
|
||||||
|
const source = sourceMap[providerName] || WebSearchSource.AISDK
|
||||||
|
|
||||||
|
this.onChunk({
|
||||||
|
type: ChunkType.LLM_WEB_SEARCH_COMPLETE,
|
||||||
|
llm_web_search: {
|
||||||
|
results: final.webSearchResults,
|
||||||
|
source
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
if (finishReason === 'tool-calls') {
|
||||||
|
this.onChunk({ type: ChunkType.LLM_RESPONSE_CREATED })
|
||||||
|
}
|
||||||
|
|
||||||
|
final.webSearchResults = []
|
||||||
|
// final.reasoningId = ''
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
case 'finish':
|
||||||
|
this.onChunk({
|
||||||
|
type: ChunkType.BLOCK_COMPLETE,
|
||||||
|
response: {
|
||||||
|
text: final.text || '',
|
||||||
|
reasoning_content: final.reasoningContent || '',
|
||||||
|
usage: {
|
||||||
|
completion_tokens: chunk.totalUsage.outputTokens || 0,
|
||||||
|
prompt_tokens: chunk.totalUsage.inputTokens || 0,
|
||||||
|
total_tokens: chunk.totalUsage.totalTokens || 0
|
||||||
|
},
|
||||||
|
metrics: chunk.totalUsage
|
||||||
|
? {
|
||||||
|
completion_tokens: chunk.totalUsage.outputTokens || 0,
|
||||||
|
time_completion_millsec: 0
|
||||||
|
}
|
||||||
|
: undefined
|
||||||
|
}
|
||||||
|
})
|
||||||
|
this.onChunk({
|
||||||
|
type: ChunkType.LLM_RESPONSE_COMPLETE,
|
||||||
|
response: {
|
||||||
|
text: final.text || '',
|
||||||
|
reasoning_content: final.reasoningContent || '',
|
||||||
|
usage: {
|
||||||
|
completion_tokens: chunk.totalUsage.outputTokens || 0,
|
||||||
|
prompt_tokens: chunk.totalUsage.inputTokens || 0,
|
||||||
|
total_tokens: chunk.totalUsage.totalTokens || 0
|
||||||
|
},
|
||||||
|
metrics: chunk.totalUsage
|
||||||
|
? {
|
||||||
|
completion_tokens: chunk.totalUsage.outputTokens || 0,
|
||||||
|
time_completion_millsec: 0
|
||||||
|
}
|
||||||
|
: undefined
|
||||||
|
}
|
||||||
|
})
|
||||||
|
break
|
||||||
|
|
||||||
|
// === 源和文件相关事件 ===
|
||||||
|
case 'source':
|
||||||
|
if (chunk.sourceType === 'url') {
|
||||||
|
// if (final.webSearchResults.length === 0) {
|
||||||
|
// eslint-disable-next-line @typescript-eslint/no-unused-vars
|
||||||
|
const { sourceType: _, ...rest } = chunk
|
||||||
|
final.webSearchResults.push(rest)
|
||||||
|
// }
|
||||||
|
// this.onChunk({
|
||||||
|
// type: ChunkType.LLM_WEB_SEARCH_COMPLETE,
|
||||||
|
// llm_web_search: {
|
||||||
|
// source: WebSearchSource.AISDK,
|
||||||
|
// results: final.webSearchResults
|
||||||
|
// }
|
||||||
|
// })
|
||||||
|
}
|
||||||
|
break
|
||||||
|
case 'file':
|
||||||
|
// 文件相关事件,可能是图片生成
|
||||||
|
this.onChunk({
|
||||||
|
type: ChunkType.IMAGE_COMPLETE,
|
||||||
|
image: {
|
||||||
|
type: 'base64',
|
||||||
|
images: [`data:${chunk.file.mediaType};base64,${chunk.file.base64}`]
|
||||||
|
}
|
||||||
|
})
|
||||||
|
break
|
||||||
|
case 'abort':
|
||||||
|
this.onChunk({
|
||||||
|
type: ChunkType.ERROR,
|
||||||
|
error: new DOMException('Request was aborted', 'AbortError')
|
||||||
|
})
|
||||||
|
break
|
||||||
|
case 'error':
|
||||||
|
this.onChunk({
|
||||||
|
type: ChunkType.ERROR,
|
||||||
|
error: chunk.error as Record<string, any>
|
||||||
|
})
|
||||||
|
break
|
||||||
|
|
||||||
|
default:
|
||||||
|
// 其他类型的 chunk 可以忽略或记录日志
|
||||||
|
// console.log('Unhandled AI SDK chunk type:', chunk.type, chunk)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
export default AiSdkToChunkAdapter
|
||||||
266
src/renderer/src/aiCore/chunk/handleToolCallChunk.ts
Normal file
266
src/renderer/src/aiCore/chunk/handleToolCallChunk.ts
Normal file
@@ -0,0 +1,266 @@
|
|||||||
|
/**
|
||||||
|
* 工具调用 Chunk 处理模块
|
||||||
|
* TODO: Tool包含了providerTool和普通的Tool还有MCPTool,后面需要重构
|
||||||
|
* 提供工具调用相关的处理API,每个交互使用一个新的实例
|
||||||
|
*/
|
||||||
|
|
||||||
|
import { loggerService } from '@logger'
|
||||||
|
import { BaseTool, MCPTool, MCPToolResponse, NormalToolResponse } from '@renderer/types'
|
||||||
|
import { Chunk, ChunkType } from '@renderer/types/chunk'
|
||||||
|
import type { ProviderMetadata, ToolSet, TypedToolCall, TypedToolResult } from 'ai'
|
||||||
|
// import type {
|
||||||
|
// AnthropicSearchOutput,
|
||||||
|
// WebSearchPluginConfig
|
||||||
|
// } from '@cherrystudio/ai-core/core/plugins/built-in/webSearchPlugin'
|
||||||
|
|
||||||
|
const logger = loggerService.withContext('ToolCallChunkHandler')
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 工具调用处理器类
|
||||||
|
*/
|
||||||
|
export class ToolCallChunkHandler {
|
||||||
|
// private onChunk: (chunk: Chunk) => void
|
||||||
|
private activeToolCalls = new Map<
|
||||||
|
string,
|
||||||
|
{
|
||||||
|
toolCallId: string
|
||||||
|
toolName: string
|
||||||
|
args: any
|
||||||
|
// mcpTool 现在可以是 MCPTool 或我们为 Provider 工具创建的通用类型
|
||||||
|
tool: BaseTool
|
||||||
|
}
|
||||||
|
>()
|
||||||
|
constructor(
|
||||||
|
private onChunk: (chunk: Chunk) => void,
|
||||||
|
private mcpTools: MCPTool[]
|
||||||
|
) {}
|
||||||
|
|
||||||
|
// /**
|
||||||
|
// * 设置 onChunk 回调
|
||||||
|
// */
|
||||||
|
// public setOnChunk(callback: (chunk: Chunk) => void): void {
|
||||||
|
// this.onChunk = callback
|
||||||
|
// }
|
||||||
|
|
||||||
|
handleToolCallCreated(
|
||||||
|
chunk:
|
||||||
|
| {
|
||||||
|
type: 'tool-input-start'
|
||||||
|
id: string
|
||||||
|
toolName: string
|
||||||
|
providerMetadata?: ProviderMetadata
|
||||||
|
providerExecuted?: boolean
|
||||||
|
}
|
||||||
|
| {
|
||||||
|
type: 'tool-input-end'
|
||||||
|
id: string
|
||||||
|
providerMetadata?: ProviderMetadata
|
||||||
|
}
|
||||||
|
| {
|
||||||
|
type: 'tool-input-delta'
|
||||||
|
id: string
|
||||||
|
delta: string
|
||||||
|
providerMetadata?: ProviderMetadata
|
||||||
|
}
|
||||||
|
): void {
|
||||||
|
switch (chunk.type) {
|
||||||
|
case 'tool-input-start': {
|
||||||
|
// 能拿到说明是mcpTool
|
||||||
|
// if (this.activeToolCalls.get(chunk.id)) return
|
||||||
|
|
||||||
|
const tool: BaseTool | MCPTool = {
|
||||||
|
id: chunk.id,
|
||||||
|
name: chunk.toolName,
|
||||||
|
description: chunk.toolName,
|
||||||
|
type: chunk.toolName.startsWith('builtin_') ? 'builtin' : 'provider'
|
||||||
|
}
|
||||||
|
this.activeToolCalls.set(chunk.id, {
|
||||||
|
toolCallId: chunk.id,
|
||||||
|
toolName: chunk.toolName,
|
||||||
|
args: '',
|
||||||
|
tool
|
||||||
|
})
|
||||||
|
const toolResponse: MCPToolResponse | NormalToolResponse = {
|
||||||
|
id: chunk.id,
|
||||||
|
tool: tool,
|
||||||
|
arguments: {},
|
||||||
|
status: 'pending',
|
||||||
|
toolCallId: chunk.id
|
||||||
|
}
|
||||||
|
this.onChunk({
|
||||||
|
type: ChunkType.MCP_TOOL_PENDING,
|
||||||
|
responses: [toolResponse]
|
||||||
|
})
|
||||||
|
break
|
||||||
|
}
|
||||||
|
case 'tool-input-delta': {
|
||||||
|
const toolCall = this.activeToolCalls.get(chunk.id)
|
||||||
|
if (!toolCall) {
|
||||||
|
logger.warn(`🔧 [ToolCallChunkHandler] Tool call not found: ${chunk.id}`)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
toolCall.args += chunk.delta
|
||||||
|
break
|
||||||
|
}
|
||||||
|
case 'tool-input-end': {
|
||||||
|
const toolCall = this.activeToolCalls.get(chunk.id)
|
||||||
|
this.activeToolCalls.delete(chunk.id)
|
||||||
|
if (!toolCall) {
|
||||||
|
logger.warn(`🔧 [ToolCallChunkHandler] Tool call not found: ${chunk.id}`)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// const toolResponse: ToolCallResponse = {
|
||||||
|
// id: toolCall.toolCallId,
|
||||||
|
// tool: toolCall.tool,
|
||||||
|
// arguments: toolCall.args,
|
||||||
|
// status: 'pending',
|
||||||
|
// toolCallId: toolCall.toolCallId
|
||||||
|
// }
|
||||||
|
// logger.debug('toolResponse', toolResponse)
|
||||||
|
// this.onChunk({
|
||||||
|
// type: ChunkType.MCP_TOOL_PENDING,
|
||||||
|
// responses: [toolResponse]
|
||||||
|
// })
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// if (!toolCall) {
|
||||||
|
// Logger.warn(`🔧 [ToolCallChunkHandler] Tool call not found: ${chunk.id}`)
|
||||||
|
// return
|
||||||
|
// }
|
||||||
|
// this.onChunk({
|
||||||
|
// type: ChunkType.MCP_TOOL_CREATED,
|
||||||
|
// tool_calls: [
|
||||||
|
// {
|
||||||
|
// id: chunk.id,
|
||||||
|
// name: chunk.toolName,
|
||||||
|
// status: 'pending'
|
||||||
|
// }
|
||||||
|
// ]
|
||||||
|
// })
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 处理工具调用事件
|
||||||
|
*/
|
||||||
|
public handleToolCall(
|
||||||
|
chunk: {
|
||||||
|
type: 'tool-call'
|
||||||
|
} & TypedToolCall<ToolSet>
|
||||||
|
): void {
|
||||||
|
const { toolCallId, toolName, input: args, providerExecuted } = chunk
|
||||||
|
|
||||||
|
if (!toolCallId || !toolName) {
|
||||||
|
logger.warn(`🔧 [ToolCallChunkHandler] Invalid tool call chunk: missing toolCallId or toolName`)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
let tool: BaseTool
|
||||||
|
let mcpTool: MCPTool | undefined
|
||||||
|
|
||||||
|
// 根据 providerExecuted 标志区分处理逻辑
|
||||||
|
if (providerExecuted) {
|
||||||
|
// 如果是 Provider 执行的工具(如 web_search)
|
||||||
|
logger.info(`[ToolCallChunkHandler] Handling provider-executed tool: ${toolName}`)
|
||||||
|
tool = {
|
||||||
|
id: toolCallId,
|
||||||
|
name: toolName,
|
||||||
|
description: toolName,
|
||||||
|
type: 'provider'
|
||||||
|
} as BaseTool
|
||||||
|
} else if (toolName.startsWith('builtin_')) {
|
||||||
|
// 如果是内置工具,沿用现有逻辑
|
||||||
|
logger.info(`[ToolCallChunkHandler] Handling builtin tool: ${toolName}`)
|
||||||
|
tool = {
|
||||||
|
id: toolCallId,
|
||||||
|
name: toolName,
|
||||||
|
description: toolName,
|
||||||
|
type: 'builtin'
|
||||||
|
} as BaseTool
|
||||||
|
} else if ((mcpTool = this.mcpTools.find((t) => t.name === toolName) as MCPTool)) {
|
||||||
|
// 如果是客户端执行的 MCP 工具,沿用现有逻辑
|
||||||
|
logger.info(`[ToolCallChunkHandler] Handling client-side MCP tool: ${toolName}`)
|
||||||
|
// mcpTool = this.mcpTools.find((t) => t.name === toolName) as MCPTool
|
||||||
|
// if (!mcpTool) {
|
||||||
|
// logger.warn(`[ToolCallChunkHandler] MCP tool not found: ${toolName}`)
|
||||||
|
// return
|
||||||
|
// }
|
||||||
|
tool = mcpTool
|
||||||
|
} else {
|
||||||
|
tool = {
|
||||||
|
id: toolCallId,
|
||||||
|
name: toolName,
|
||||||
|
description: toolName,
|
||||||
|
type: 'provider'
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 记录活跃的工具调用
|
||||||
|
this.activeToolCalls.set(toolCallId, {
|
||||||
|
toolCallId,
|
||||||
|
toolName,
|
||||||
|
args,
|
||||||
|
tool
|
||||||
|
})
|
||||||
|
|
||||||
|
// 创建 MCPToolResponse 格式
|
||||||
|
const toolResponse: MCPToolResponse | NormalToolResponse = {
|
||||||
|
id: toolCallId,
|
||||||
|
tool: tool,
|
||||||
|
arguments: args,
|
||||||
|
status: 'pending',
|
||||||
|
toolCallId: toolCallId
|
||||||
|
}
|
||||||
|
|
||||||
|
// 调用 onChunk
|
||||||
|
if (this.onChunk) {
|
||||||
|
this.onChunk({
|
||||||
|
type: ChunkType.MCP_TOOL_PENDING,
|
||||||
|
responses: [toolResponse]
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 处理工具调用结果事件
|
||||||
|
*/
|
||||||
|
public handleToolResult(
|
||||||
|
chunk: {
|
||||||
|
type: 'tool-result'
|
||||||
|
} & TypedToolResult<ToolSet>
|
||||||
|
): void {
|
||||||
|
const { toolCallId, output, input } = chunk
|
||||||
|
|
||||||
|
if (!toolCallId) {
|
||||||
|
logger.warn(`🔧 [ToolCallChunkHandler] Invalid tool result chunk: missing toolCallId`)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 查找对应的工具调用信息
|
||||||
|
const toolCallInfo = this.activeToolCalls.get(toolCallId)
|
||||||
|
if (!toolCallInfo) {
|
||||||
|
logger.warn(`🔧 [ToolCallChunkHandler] Tool call info not found for ID: ${toolCallId}`)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 创建工具调用结果的 MCPToolResponse 格式
|
||||||
|
const toolResponse: MCPToolResponse | NormalToolResponse = {
|
||||||
|
id: toolCallInfo.toolCallId,
|
||||||
|
tool: toolCallInfo.tool,
|
||||||
|
arguments: input,
|
||||||
|
status: 'done',
|
||||||
|
response: output,
|
||||||
|
toolCallId: toolCallId
|
||||||
|
}
|
||||||
|
// 从活跃调用中移除(交互结束后整个实例会被丢弃)
|
||||||
|
this.activeToolCalls.delete(toolCallId)
|
||||||
|
|
||||||
|
// 调用 onChunk
|
||||||
|
if (this.onChunk) {
|
||||||
|
this.onChunk({
|
||||||
|
type: ChunkType.MCP_TOOL_COMPLETE,
|
||||||
|
responses: [toolResponse]
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,189 +1,16 @@
|
|||||||
import { loggerService } from '@logger'
|
/**
|
||||||
import { ApiClientFactory } from '@renderer/aiCore/clients/ApiClientFactory'
|
* Cherry Studio AI Core - 统一入口点
|
||||||
import { BaseApiClient } from '@renderer/aiCore/clients/BaseApiClient'
|
*
|
||||||
import { isDedicatedImageGenerationModel, isFunctionCallingModel } from '@renderer/config/models'
|
* 这是新的统一入口,保持向后兼容性
|
||||||
import { getProviderByModel } from '@renderer/services/AssistantService'
|
* 默认导出legacy AiProvider以保持现有代码的兼容性
|
||||||
import { withSpanResult } from '@renderer/services/SpanManagerService'
|
*/
|
||||||
import { StartSpanParams } from '@renderer/trace/types/ModelSpanEntity'
|
|
||||||
import type { GenerateImageParams, Model, Provider } from '@renderer/types'
|
|
||||||
import type { RequestOptions, SdkModel } from '@renderer/types/sdk'
|
|
||||||
import { isEnabledToolUse } from '@renderer/utils/mcp-tools'
|
|
||||||
|
|
||||||
import { AihubmixAPIClient } from './clients/aihubmix/AihubmixAPIClient'
|
// 导出Legacy AiProvider作为默认导出(保持向后兼容)
|
||||||
import { VertexAPIClient } from './clients/gemini/VertexAPIClient'
|
export { default } from './legacy/index'
|
||||||
import { NewAPIClient } from './clients/newapi/NewAPIClient'
|
|
||||||
import { OpenAIResponseAPIClient } from './clients/openai/OpenAIResponseAPIClient'
|
|
||||||
import { CompletionsMiddlewareBuilder } from './middleware/builder'
|
|
||||||
import { MIDDLEWARE_NAME as AbortHandlerMiddlewareName } from './middleware/common/AbortHandlerMiddleware'
|
|
||||||
import { MIDDLEWARE_NAME as ErrorHandlerMiddlewareName } from './middleware/common/ErrorHandlerMiddleware'
|
|
||||||
import { MIDDLEWARE_NAME as FinalChunkConsumerMiddlewareName } from './middleware/common/FinalChunkConsumerMiddleware'
|
|
||||||
import { applyCompletionsMiddlewares } from './middleware/composer'
|
|
||||||
import { MIDDLEWARE_NAME as McpToolChunkMiddlewareName } from './middleware/core/McpToolChunkMiddleware'
|
|
||||||
import { MIDDLEWARE_NAME as RawStreamListenerMiddlewareName } from './middleware/core/RawStreamListenerMiddleware'
|
|
||||||
import { MIDDLEWARE_NAME as WebSearchMiddlewareName } from './middleware/core/WebSearchMiddleware'
|
|
||||||
import { MIDDLEWARE_NAME as ImageGenerationMiddlewareName } from './middleware/feat/ImageGenerationMiddleware'
|
|
||||||
import { MIDDLEWARE_NAME as ThinkingTagExtractionMiddlewareName } from './middleware/feat/ThinkingTagExtractionMiddleware'
|
|
||||||
import { MIDDLEWARE_NAME as ToolUseExtractionMiddlewareName } from './middleware/feat/ToolUseExtractionMiddleware'
|
|
||||||
import { MiddlewareRegistry } from './middleware/register'
|
|
||||||
import type { CompletionsParams, CompletionsResult } from './middleware/schemas'
|
|
||||||
|
|
||||||
const logger = loggerService.withContext('AiProvider')
|
// 同时导出Modern AiProvider供新代码使用
|
||||||
|
export { default as ModernAiProvider } from './index_new'
|
||||||
|
|
||||||
export default class AiProvider {
|
// 导出一些常用的类型和工具
|
||||||
private apiClient: BaseApiClient
|
export * from './legacy/clients/types'
|
||||||
|
export * from './legacy/middleware/schemas'
|
||||||
constructor(provider: Provider) {
|
|
||||||
// Use the new ApiClientFactory to get a BaseApiClient instance
|
|
||||||
this.apiClient = ApiClientFactory.create(provider)
|
|
||||||
}
|
|
||||||
|
|
||||||
public async completions(params: CompletionsParams, options?: RequestOptions): Promise<CompletionsResult> {
|
|
||||||
// 1. 根据模型识别正确的客户端
|
|
||||||
const model = params.assistant.model
|
|
||||||
if (!model) {
|
|
||||||
return Promise.reject(new Error('Model is required'))
|
|
||||||
}
|
|
||||||
|
|
||||||
// 根据client类型选择合适的处理方式
|
|
||||||
let client: BaseApiClient
|
|
||||||
|
|
||||||
if (this.apiClient instanceof AihubmixAPIClient) {
|
|
||||||
// AihubmixAPIClient: 根据模型选择合适的子client
|
|
||||||
client = this.apiClient.getClientForModel(model)
|
|
||||||
if (client instanceof OpenAIResponseAPIClient) {
|
|
||||||
client = client.getClient(model) as BaseApiClient
|
|
||||||
}
|
|
||||||
} else if (this.apiClient instanceof NewAPIClient) {
|
|
||||||
client = this.apiClient.getClientForModel(model)
|
|
||||||
if (client instanceof OpenAIResponseAPIClient) {
|
|
||||||
client = client.getClient(model) as BaseApiClient
|
|
||||||
}
|
|
||||||
} else if (this.apiClient instanceof OpenAIResponseAPIClient) {
|
|
||||||
// OpenAIResponseAPIClient: 根据模型特征选择API类型
|
|
||||||
client = this.apiClient.getClient(model) as BaseApiClient
|
|
||||||
} else if (this.apiClient instanceof VertexAPIClient) {
|
|
||||||
client = this.apiClient.getClient(model) as BaseApiClient
|
|
||||||
} else {
|
|
||||||
// 其他client直接使用
|
|
||||||
client = this.apiClient
|
|
||||||
}
|
|
||||||
|
|
||||||
// 2. 构建中间件链
|
|
||||||
const builder = CompletionsMiddlewareBuilder.withDefaults()
|
|
||||||
// images api
|
|
||||||
if (isDedicatedImageGenerationModel(model)) {
|
|
||||||
builder.clear()
|
|
||||||
builder
|
|
||||||
.add(MiddlewareRegistry[FinalChunkConsumerMiddlewareName])
|
|
||||||
.add(MiddlewareRegistry[ErrorHandlerMiddlewareName])
|
|
||||||
.add(MiddlewareRegistry[AbortHandlerMiddlewareName])
|
|
||||||
.add(MiddlewareRegistry[ImageGenerationMiddlewareName])
|
|
||||||
} else {
|
|
||||||
// Existing logic for other models
|
|
||||||
logger.silly('Builder Params', params)
|
|
||||||
// 使用兼容性类型检查,避免typescript类型收窄和装饰器模式的问题
|
|
||||||
const clientTypes = client.getClientCompatibilityType(model)
|
|
||||||
const isOpenAICompatible =
|
|
||||||
clientTypes.includes('OpenAIAPIClient') || clientTypes.includes('OpenAIResponseAPIClient')
|
|
||||||
if (!isOpenAICompatible) {
|
|
||||||
logger.silly('ThinkingTagExtractionMiddleware is removed')
|
|
||||||
builder.remove(ThinkingTagExtractionMiddlewareName)
|
|
||||||
}
|
|
||||||
|
|
||||||
const isAnthropicOrOpenAIResponseCompatible =
|
|
||||||
clientTypes.includes('AnthropicAPIClient') ||
|
|
||||||
clientTypes.includes('OpenAIResponseAPIClient') ||
|
|
||||||
clientTypes.includes('AnthropicVertexAPIClient')
|
|
||||||
if (!isAnthropicOrOpenAIResponseCompatible) {
|
|
||||||
logger.silly('RawStreamListenerMiddleware is removed')
|
|
||||||
builder.remove(RawStreamListenerMiddlewareName)
|
|
||||||
}
|
|
||||||
if (!params.enableWebSearch) {
|
|
||||||
logger.silly('WebSearchMiddleware is removed')
|
|
||||||
builder.remove(WebSearchMiddlewareName)
|
|
||||||
}
|
|
||||||
if (!params.mcpTools?.length) {
|
|
||||||
builder.remove(ToolUseExtractionMiddlewareName)
|
|
||||||
logger.silly('ToolUseExtractionMiddleware is removed')
|
|
||||||
builder.remove(McpToolChunkMiddlewareName)
|
|
||||||
logger.silly('McpToolChunkMiddleware is removed')
|
|
||||||
}
|
|
||||||
if (isEnabledToolUse(params.assistant) && isFunctionCallingModel(model)) {
|
|
||||||
builder.remove(ToolUseExtractionMiddlewareName)
|
|
||||||
logger.silly('ToolUseExtractionMiddleware is removed')
|
|
||||||
}
|
|
||||||
if (params.callType !== 'chat' && params.callType !== 'check' && params.callType !== 'translate') {
|
|
||||||
logger.silly('AbortHandlerMiddleware is removed')
|
|
||||||
builder.remove(AbortHandlerMiddlewareName)
|
|
||||||
}
|
|
||||||
if (params.callType === 'test') {
|
|
||||||
builder.remove(ErrorHandlerMiddlewareName)
|
|
||||||
logger.silly('ErrorHandlerMiddleware is removed')
|
|
||||||
builder.remove(FinalChunkConsumerMiddlewareName)
|
|
||||||
logger.silly('FinalChunkConsumerMiddleware is removed')
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
const middlewares = builder.build()
|
|
||||||
logger.silly(
|
|
||||||
'middlewares',
|
|
||||||
middlewares.map((m) => m.name)
|
|
||||||
)
|
|
||||||
|
|
||||||
// 3. Create the wrapped SDK method with middlewares
|
|
||||||
const wrappedCompletionMethod = applyCompletionsMiddlewares(client, client.createCompletions, middlewares)
|
|
||||||
|
|
||||||
// 4. Execute the wrapped method with the original params
|
|
||||||
const result = wrappedCompletionMethod(params, options)
|
|
||||||
return result
|
|
||||||
}
|
|
||||||
|
|
||||||
public async completionsForTrace(params: CompletionsParams, options?: RequestOptions): Promise<CompletionsResult> {
|
|
||||||
const traceName = params.assistant.model?.name
|
|
||||||
? `${params.assistant.model?.name}.${params.callType}`
|
|
||||||
: `LLM.${params.callType}`
|
|
||||||
|
|
||||||
const traceParams: StartSpanParams = {
|
|
||||||
name: traceName,
|
|
||||||
tag: 'LLM',
|
|
||||||
topicId: params.topicId || '',
|
|
||||||
modelName: params.assistant.model?.name
|
|
||||||
}
|
|
||||||
|
|
||||||
return await withSpanResult(this.completions.bind(this), traceParams, params, options)
|
|
||||||
}
|
|
||||||
|
|
||||||
public async models(): Promise<SdkModel[]> {
|
|
||||||
return this.apiClient.listModels()
|
|
||||||
}
|
|
||||||
|
|
||||||
public async getEmbeddingDimensions(model: Model): Promise<number> {
|
|
||||||
try {
|
|
||||||
// Use the SDK instance to test embedding capabilities
|
|
||||||
if (this.apiClient instanceof OpenAIResponseAPIClient && getProviderByModel(model).type === 'azure-openai') {
|
|
||||||
this.apiClient = this.apiClient.getClient(model) as BaseApiClient
|
|
||||||
}
|
|
||||||
const dimensions = await this.apiClient.getEmbeddingDimensions(model)
|
|
||||||
return dimensions
|
|
||||||
} catch (error) {
|
|
||||||
logger.error('Error getting embedding dimensions:', error as Error)
|
|
||||||
throw error
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
public async generateImage(params: GenerateImageParams): Promise<string[]> {
|
|
||||||
if (this.apiClient instanceof AihubmixAPIClient) {
|
|
||||||
const client = this.apiClient.getClientForModel({ id: params.model } as Model)
|
|
||||||
return client.generateImage(params)
|
|
||||||
}
|
|
||||||
return this.apiClient.generateImage(params)
|
|
||||||
}
|
|
||||||
|
|
||||||
public getBaseURL(): string {
|
|
||||||
return this.apiClient.getBaseURL()
|
|
||||||
}
|
|
||||||
|
|
||||||
public getApiKey(): string {
|
|
||||||
return this.apiClient.getApiKey()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|||||||
488
src/renderer/src/aiCore/index_new.ts
Normal file
488
src/renderer/src/aiCore/index_new.ts
Normal file
@@ -0,0 +1,488 @@
|
|||||||
|
/**
|
||||||
|
* Cherry Studio AI Core - 新版本入口
|
||||||
|
* 集成 @cherrystudio/ai-core 库的渐进式重构方案
|
||||||
|
*
|
||||||
|
* 融合方案:简化实现,专注于核心功能
|
||||||
|
* 1. 优先使用新AI SDK
|
||||||
|
* 2. 暂时保持接口兼容性
|
||||||
|
*/
|
||||||
|
|
||||||
|
import { createExecutor } from '@cherrystudio/ai-core'
|
||||||
|
import { loggerService } from '@logger'
|
||||||
|
import { isNotSupportedImageSizeModel } from '@renderer/config/models'
|
||||||
|
import { getEnableDeveloperMode } from '@renderer/hooks/useSettings'
|
||||||
|
import { addSpan, endSpan } from '@renderer/services/SpanManagerService'
|
||||||
|
import { StartSpanParams } from '@renderer/trace/types/ModelSpanEntity'
|
||||||
|
import type { Assistant, GenerateImageParams, Model, Provider } from '@renderer/types'
|
||||||
|
import type { AiSdkModel, StreamTextParams } from '@renderer/types/aiCoreTypes'
|
||||||
|
import { ChunkType } from '@renderer/types/chunk'
|
||||||
|
import { type ImageModel, type LanguageModel, type Provider as AiSdkProvider, wrapLanguageModel } from 'ai'
|
||||||
|
|
||||||
|
import AiSdkToChunkAdapter from './chunk/AiSdkToChunkAdapter'
|
||||||
|
import LegacyAiProvider from './legacy/index'
|
||||||
|
import { CompletionsResult } from './legacy/middleware/schemas'
|
||||||
|
import { AiSdkMiddlewareConfig, buildAiSdkMiddlewares } from './middleware/AiSdkMiddlewareBuilder'
|
||||||
|
import { buildPlugins } from './plugins/PluginBuilder'
|
||||||
|
import { createAiSdkProvider } from './provider/factory'
|
||||||
|
import {
|
||||||
|
getActualProvider,
|
||||||
|
isModernSdkSupported,
|
||||||
|
prepareSpecialProviderConfig,
|
||||||
|
providerToAiSdkConfig
|
||||||
|
} from './provider/providerConfig'
|
||||||
|
|
||||||
|
const logger = loggerService.withContext('ModernAiProvider')
|
||||||
|
|
||||||
|
export type ModernAiProviderConfig = AiSdkMiddlewareConfig & {
|
||||||
|
assistant: Assistant
|
||||||
|
// topicId for tracing
|
||||||
|
topicId?: string
|
||||||
|
callType: string
|
||||||
|
}
|
||||||
|
|
||||||
|
export default class ModernAiProvider {
|
||||||
|
private legacyProvider: LegacyAiProvider
|
||||||
|
private config?: ReturnType<typeof providerToAiSdkConfig>
|
||||||
|
private actualProvider: Provider
|
||||||
|
private model?: Model
|
||||||
|
private localProvider: Awaited<AiSdkProvider> | null = null
|
||||||
|
|
||||||
|
// 构造函数重载签名
|
||||||
|
constructor(model: Model, provider?: Provider)
|
||||||
|
constructor(provider: Provider)
|
||||||
|
constructor(modelOrProvider: Model | Provider, provider?: Provider)
|
||||||
|
constructor(modelOrProvider: Model | Provider, provider?: Provider) {
|
||||||
|
if (this.isModel(modelOrProvider)) {
|
||||||
|
// 传入的是 Model
|
||||||
|
this.model = modelOrProvider
|
||||||
|
this.actualProvider = provider || getActualProvider(modelOrProvider)
|
||||||
|
// 只保存配置,不预先创建executor
|
||||||
|
this.config = providerToAiSdkConfig(this.actualProvider, modelOrProvider)
|
||||||
|
} else {
|
||||||
|
// 传入的是 Provider
|
||||||
|
this.actualProvider = modelOrProvider
|
||||||
|
// model为可选,某些操作(如fetchModels)不需要model
|
||||||
|
}
|
||||||
|
|
||||||
|
this.legacyProvider = new LegacyAiProvider(this.actualProvider)
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 类型守卫函数:通过 provider 属性区分 Model 和 Provider
|
||||||
|
*/
|
||||||
|
private isModel(obj: Model | Provider): obj is Model {
|
||||||
|
return 'provider' in obj && typeof obj.provider === 'string'
|
||||||
|
}
|
||||||
|
|
||||||
|
public getActualProvider() {
|
||||||
|
return this.actualProvider
|
||||||
|
}
|
||||||
|
|
||||||
|
public async completions(modelId: string, params: StreamTextParams, config: ModernAiProviderConfig) {
|
||||||
|
// 检查model是否存在
|
||||||
|
if (!this.model) {
|
||||||
|
throw new Error('Model is required for completions. Please use constructor with model parameter.')
|
||||||
|
}
|
||||||
|
|
||||||
|
// 确保配置存在
|
||||||
|
if (!this.config) {
|
||||||
|
this.config = providerToAiSdkConfig(this.actualProvider, this.model)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 准备特殊配置
|
||||||
|
await prepareSpecialProviderConfig(this.actualProvider, this.config)
|
||||||
|
|
||||||
|
// 提前创建本地 provider 实例
|
||||||
|
if (!this.localProvider) {
|
||||||
|
this.localProvider = await createAiSdkProvider(this.config)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 提前构建中间件
|
||||||
|
const middlewares = buildAiSdkMiddlewares({
|
||||||
|
...config,
|
||||||
|
provider: this.actualProvider
|
||||||
|
})
|
||||||
|
logger.debug('Built middlewares in completions', {
|
||||||
|
middlewareCount: middlewares.length,
|
||||||
|
isImageGeneration: config.isImageGenerationEndpoint
|
||||||
|
})
|
||||||
|
if (!this.localProvider) {
|
||||||
|
throw new Error('Local provider not created')
|
||||||
|
}
|
||||||
|
|
||||||
|
// 根据endpoint类型创建对应的模型
|
||||||
|
let model: AiSdkModel | undefined
|
||||||
|
if (config.isImageGenerationEndpoint) {
|
||||||
|
model = this.localProvider.imageModel(modelId)
|
||||||
|
} else {
|
||||||
|
model = this.localProvider.languageModel(modelId)
|
||||||
|
// 如果有中间件,应用到语言模型上
|
||||||
|
if (middlewares.length > 0 && typeof model === 'object') {
|
||||||
|
model = wrapLanguageModel({ model, middleware: middlewares })
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (config.topicId && getEnableDeveloperMode()) {
|
||||||
|
// TypeScript类型窄化:确保topicId是string类型
|
||||||
|
const traceConfig = {
|
||||||
|
...config,
|
||||||
|
topicId: config.topicId
|
||||||
|
}
|
||||||
|
return await this._completionsForTrace(model, params, traceConfig)
|
||||||
|
} else {
|
||||||
|
return await this._completionsOrImageGeneration(model, params, config)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private async _completionsOrImageGeneration(
|
||||||
|
model: AiSdkModel,
|
||||||
|
params: StreamTextParams,
|
||||||
|
config: ModernAiProviderConfig
|
||||||
|
): Promise<CompletionsResult> {
|
||||||
|
if (config.isImageGenerationEndpoint) {
|
||||||
|
return await this.modernImageGeneration(model as ImageModel, params, config)
|
||||||
|
}
|
||||||
|
|
||||||
|
return await this.modernCompletions(model as LanguageModel, params, config)
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 带trace支持的completions方法
|
||||||
|
* 类似于legacy的completionsForTrace,确保AI SDK spans在正确的trace上下文中
|
||||||
|
*/
|
||||||
|
private async _completionsForTrace(
|
||||||
|
model: AiSdkModel,
|
||||||
|
params: StreamTextParams,
|
||||||
|
config: ModernAiProviderConfig & { topicId: string }
|
||||||
|
): Promise<CompletionsResult> {
|
||||||
|
const modelId = this.model!.id
|
||||||
|
const traceName = `${this.actualProvider.name}.${modelId}.${config.callType}`
|
||||||
|
const traceParams: StartSpanParams = {
|
||||||
|
name: traceName,
|
||||||
|
tag: 'LLM',
|
||||||
|
topicId: config.topicId,
|
||||||
|
modelName: config.assistant.model?.name, // 使用modelId而不是provider名称
|
||||||
|
inputs: params
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.info('Starting AI SDK trace span', {
|
||||||
|
traceName,
|
||||||
|
topicId: config.topicId,
|
||||||
|
modelId,
|
||||||
|
hasTools: !!params.tools && Object.keys(params.tools).length > 0,
|
||||||
|
toolNames: params.tools ? Object.keys(params.tools) : [],
|
||||||
|
isImageGeneration: config.isImageGenerationEndpoint
|
||||||
|
})
|
||||||
|
|
||||||
|
const span = addSpan(traceParams)
|
||||||
|
if (!span) {
|
||||||
|
logger.warn('Failed to create span, falling back to regular completions', {
|
||||||
|
topicId: config.topicId,
|
||||||
|
modelId,
|
||||||
|
traceName
|
||||||
|
})
|
||||||
|
return await this._completionsOrImageGeneration(model, params, config)
|
||||||
|
}
|
||||||
|
|
||||||
|
try {
|
||||||
|
logger.info('Created parent span, now calling completions', {
|
||||||
|
spanId: span.spanContext().spanId,
|
||||||
|
traceId: span.spanContext().traceId,
|
||||||
|
topicId: config.topicId,
|
||||||
|
modelId,
|
||||||
|
parentSpanCreated: true
|
||||||
|
})
|
||||||
|
|
||||||
|
const result = await this._completionsOrImageGeneration(model, params, config)
|
||||||
|
|
||||||
|
logger.info('Completions finished, ending parent span', {
|
||||||
|
spanId: span.spanContext().spanId,
|
||||||
|
traceId: span.spanContext().traceId,
|
||||||
|
topicId: config.topicId,
|
||||||
|
modelId,
|
||||||
|
resultLength: result.getText().length
|
||||||
|
})
|
||||||
|
|
||||||
|
// 标记span完成
|
||||||
|
endSpan({
|
||||||
|
topicId: config.topicId,
|
||||||
|
outputs: result,
|
||||||
|
span,
|
||||||
|
modelName: modelId // 使用modelId保持一致性
|
||||||
|
})
|
||||||
|
|
||||||
|
return result
|
||||||
|
} catch (error) {
|
||||||
|
logger.error('Error in completionsForTrace, ending parent span with error', error as Error, {
|
||||||
|
spanId: span.spanContext().spanId,
|
||||||
|
traceId: span.spanContext().traceId,
|
||||||
|
topicId: config.topicId,
|
||||||
|
modelId
|
||||||
|
})
|
||||||
|
|
||||||
|
// 标记span出错
|
||||||
|
endSpan({
|
||||||
|
topicId: config.topicId,
|
||||||
|
error: error as Error,
|
||||||
|
span,
|
||||||
|
modelName: modelId // 使用modelId保持一致性
|
||||||
|
})
|
||||||
|
throw error
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 使用现代化AI SDK的completions实现
|
||||||
|
*/
|
||||||
|
private async modernCompletions(
|
||||||
|
model: LanguageModel,
|
||||||
|
params: StreamTextParams,
|
||||||
|
config: ModernAiProviderConfig
|
||||||
|
): Promise<CompletionsResult> {
|
||||||
|
const modelId = this.model!.id
|
||||||
|
logger.info('Starting modernCompletions', {
|
||||||
|
modelId,
|
||||||
|
providerId: this.config!.providerId,
|
||||||
|
topicId: config.topicId,
|
||||||
|
hasOnChunk: !!config.onChunk,
|
||||||
|
hasTools: !!params.tools && Object.keys(params.tools).length > 0,
|
||||||
|
toolCount: params.tools ? Object.keys(params.tools).length : 0
|
||||||
|
})
|
||||||
|
|
||||||
|
// 根据条件构建插件数组
|
||||||
|
const plugins = buildPlugins(config)
|
||||||
|
|
||||||
|
// 用构建好的插件数组创建executor
|
||||||
|
const executor = createExecutor(this.config!.providerId, this.config!.options, plugins)
|
||||||
|
|
||||||
|
// 创建带有中间件的执行器
|
||||||
|
if (config.onChunk) {
|
||||||
|
const accumulate = this.model!.supported_text_delta !== false // true and undefined
|
||||||
|
const adapter = new AiSdkToChunkAdapter(config.onChunk, config.mcpTools, accumulate)
|
||||||
|
|
||||||
|
const streamResult = await executor.streamText({
|
||||||
|
...params,
|
||||||
|
model,
|
||||||
|
experimental_context: { onChunk: config.onChunk }
|
||||||
|
})
|
||||||
|
|
||||||
|
const finalText = await adapter.processStream(streamResult)
|
||||||
|
|
||||||
|
return {
|
||||||
|
getText: () => finalText
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
const streamResult = await executor.streamText({
|
||||||
|
...params,
|
||||||
|
model
|
||||||
|
})
|
||||||
|
|
||||||
|
// 强制消费流,不然await streamResult.text会阻塞
|
||||||
|
await streamResult?.consumeStream()
|
||||||
|
|
||||||
|
const finalText = await streamResult.text
|
||||||
|
|
||||||
|
return {
|
||||||
|
getText: () => finalText
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 使用现代化 AI SDK 的图像生成实现,支持流式输出
|
||||||
|
*/
|
||||||
|
private async modernImageGeneration(
|
||||||
|
model: ImageModel,
|
||||||
|
params: StreamTextParams,
|
||||||
|
config: ModernAiProviderConfig
|
||||||
|
): Promise<CompletionsResult> {
|
||||||
|
const { onChunk } = config
|
||||||
|
|
||||||
|
try {
|
||||||
|
// 检查 messages 是否存在
|
||||||
|
if (!params.messages || params.messages.length === 0) {
|
||||||
|
throw new Error('No messages provided for image generation.')
|
||||||
|
}
|
||||||
|
|
||||||
|
// 从最后一条用户消息中提取 prompt
|
||||||
|
const lastUserMessage = params.messages.findLast((m) => m.role === 'user')
|
||||||
|
if (!lastUserMessage) {
|
||||||
|
throw new Error('No user message found for image generation.')
|
||||||
|
}
|
||||||
|
|
||||||
|
// 直接使用消息内容,避免类型转换问题
|
||||||
|
const prompt =
|
||||||
|
typeof lastUserMessage.content === 'string'
|
||||||
|
? lastUserMessage.content
|
||||||
|
: lastUserMessage.content?.map((part) => ('text' in part ? part.text : '')).join('') || ''
|
||||||
|
|
||||||
|
if (!prompt) {
|
||||||
|
throw new Error('No prompt found in user message.')
|
||||||
|
}
|
||||||
|
|
||||||
|
const startTime = Date.now()
|
||||||
|
|
||||||
|
// 发送图像生成开始事件
|
||||||
|
if (onChunk) {
|
||||||
|
onChunk({ type: ChunkType.IMAGE_CREATED })
|
||||||
|
}
|
||||||
|
|
||||||
|
// 构建图像生成参数
|
||||||
|
const imageParams = {
|
||||||
|
prompt,
|
||||||
|
size: isNotSupportedImageSizeModel(config.model) ? undefined : ('1024x1024' as `${number}x${number}`), // 默认尺寸,使用正确的类型
|
||||||
|
n: 1,
|
||||||
|
...(params.abortSignal && { abortSignal: params.abortSignal })
|
||||||
|
}
|
||||||
|
|
||||||
|
// 调用新 AI SDK 的图像生成功能
|
||||||
|
const executor = createExecutor(this.config!.providerId, this.config!.options, [])
|
||||||
|
const result = await executor.generateImage({
|
||||||
|
model,
|
||||||
|
...imageParams
|
||||||
|
})
|
||||||
|
|
||||||
|
// 转换结果格式
|
||||||
|
const images: string[] = []
|
||||||
|
const imageType: 'url' | 'base64' = 'base64'
|
||||||
|
|
||||||
|
if (result.images) {
|
||||||
|
for (const image of result.images) {
|
||||||
|
if ('base64' in image && image.base64) {
|
||||||
|
images.push(`data:${image.mediaType};base64,${image.base64}`)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 发送图像生成完成事件
|
||||||
|
if (onChunk && images.length > 0) {
|
||||||
|
onChunk({
|
||||||
|
type: ChunkType.IMAGE_COMPLETE,
|
||||||
|
image: { type: imageType, images }
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// 发送块完成事件(类似于 modernCompletions 的处理)
|
||||||
|
if (onChunk) {
|
||||||
|
const usage = {
|
||||||
|
prompt_tokens: prompt.length, // 估算的 token 数量
|
||||||
|
completion_tokens: 0, // 图像生成没有 completion tokens
|
||||||
|
total_tokens: prompt.length
|
||||||
|
}
|
||||||
|
|
||||||
|
onChunk({
|
||||||
|
type: ChunkType.BLOCK_COMPLETE,
|
||||||
|
response: {
|
||||||
|
usage,
|
||||||
|
metrics: {
|
||||||
|
completion_tokens: usage.completion_tokens,
|
||||||
|
time_first_token_millsec: 0,
|
||||||
|
time_completion_millsec: Date.now() - startTime
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
// 发送 LLM 响应完成事件
|
||||||
|
onChunk({
|
||||||
|
type: ChunkType.LLM_RESPONSE_COMPLETE,
|
||||||
|
response: {
|
||||||
|
usage,
|
||||||
|
metrics: {
|
||||||
|
completion_tokens: usage.completion_tokens,
|
||||||
|
time_first_token_millsec: 0,
|
||||||
|
time_completion_millsec: Date.now() - startTime
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return {
|
||||||
|
getText: () => '' // 图像生成不返回文本
|
||||||
|
}
|
||||||
|
} catch (error) {
|
||||||
|
// 发送错误事件
|
||||||
|
if (onChunk) {
|
||||||
|
onChunk({ type: ChunkType.ERROR, error: error as any })
|
||||||
|
}
|
||||||
|
throw error
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 代理其他方法到原有实现
|
||||||
|
public async models() {
|
||||||
|
return this.legacyProvider.models()
|
||||||
|
}
|
||||||
|
|
||||||
|
public async getEmbeddingDimensions(model: Model): Promise<number> {
|
||||||
|
return this.legacyProvider.getEmbeddingDimensions(model)
|
||||||
|
}
|
||||||
|
|
||||||
|
public async generateImage(params: GenerateImageParams): Promise<string[]> {
|
||||||
|
// 如果支持新的 AI SDK,使用现代化实现
|
||||||
|
if (isModernSdkSupported(this.actualProvider)) {
|
||||||
|
try {
|
||||||
|
// 确保本地provider已创建
|
||||||
|
if (!this.localProvider) {
|
||||||
|
this.localProvider = await createAiSdkProvider(this.config)
|
||||||
|
if (!this.localProvider) {
|
||||||
|
throw new Error('Local provider not created')
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const result = await this.modernGenerateImage(params)
|
||||||
|
return result
|
||||||
|
} catch (error) {
|
||||||
|
logger.warn('Modern AI SDK generateImage failed, falling back to legacy:', error as Error)
|
||||||
|
// fallback 到传统实现
|
||||||
|
return this.legacyProvider.generateImage(params)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 直接使用传统实现
|
||||||
|
return this.legacyProvider.generateImage(params)
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 使用现代化 AI SDK 的图像生成实现
|
||||||
|
*/
|
||||||
|
private async modernGenerateImage(params: GenerateImageParams): Promise<string[]> {
|
||||||
|
const { model, prompt, imageSize, batchSize, signal } = params
|
||||||
|
|
||||||
|
// 转换参数格式
|
||||||
|
const aiSdkParams = {
|
||||||
|
prompt,
|
||||||
|
size: (imageSize || '1024x1024') as `${number}x${number}`,
|
||||||
|
n: batchSize || 1,
|
||||||
|
...(signal && { abortSignal: signal })
|
||||||
|
}
|
||||||
|
|
||||||
|
const executor = createExecutor(this.config!.providerId, this.config!.options, [])
|
||||||
|
const result = await executor.generateImage({
|
||||||
|
model: this.localProvider?.imageModel(model) as ImageModel,
|
||||||
|
...aiSdkParams
|
||||||
|
})
|
||||||
|
|
||||||
|
// 转换结果格式
|
||||||
|
const images: string[] = []
|
||||||
|
if (result.images) {
|
||||||
|
for (const image of result.images) {
|
||||||
|
if ('base64' in image && image.base64) {
|
||||||
|
images.push(`data:image/png;base64,${image.base64}`)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return images
|
||||||
|
}
|
||||||
|
|
||||||
|
public getBaseURL(): string {
|
||||||
|
return this.legacyProvider.getBaseURL()
|
||||||
|
}
|
||||||
|
|
||||||
|
public getApiKey(): string {
|
||||||
|
return this.legacyProvider.getApiKey()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 为了方便调试,导出一些工具函数
|
||||||
|
export { isModernSdkSupported, providerToAiSdkConfig }
|
||||||
@@ -75,6 +75,7 @@ export class ApiClientFactory {
|
|||||||
instance = new GeminiAPIClient(provider) as BaseApiClient
|
instance = new GeminiAPIClient(provider) as BaseApiClient
|
||||||
break
|
break
|
||||||
case 'vertexai':
|
case 'vertexai':
|
||||||
|
logger.debug(`Creating VertexAPIClient for provider: ${provider.id}`)
|
||||||
instance = new VertexAPIClient(provider) as BaseApiClient
|
instance = new VertexAPIClient(provider) as BaseApiClient
|
||||||
break
|
break
|
||||||
case 'anthropic':
|
case 'anthropic':
|
||||||
@@ -66,7 +66,8 @@ vi.mock('@renderer/config/models', () => ({
|
|||||||
SYSTEM_MODELS: {
|
SYSTEM_MODELS: {
|
||||||
silicon: [],
|
silicon: [],
|
||||||
defaultModel: []
|
defaultModel: []
|
||||||
}
|
},
|
||||||
|
isOpenAIModel: vi.fn(() => false)
|
||||||
}))
|
}))
|
||||||
|
|
||||||
describe('ApiClientFactory', () => {
|
describe('ApiClientFactory', () => {
|
||||||
@@ -1,11 +1,11 @@
|
|||||||
import { AihubmixAPIClient } from '@renderer/aiCore/clients/aihubmix/AihubmixAPIClient'
|
import { AihubmixAPIClient } from '@renderer/aiCore/legacy/clients/aihubmix/AihubmixAPIClient'
|
||||||
import { AnthropicAPIClient } from '@renderer/aiCore/clients/anthropic/AnthropicAPIClient'
|
import { AnthropicAPIClient } from '@renderer/aiCore/legacy/clients/anthropic/AnthropicAPIClient'
|
||||||
import { ApiClientFactory } from '@renderer/aiCore/clients/ApiClientFactory'
|
import { ApiClientFactory } from '@renderer/aiCore/legacy/clients/ApiClientFactory'
|
||||||
import { GeminiAPIClient } from '@renderer/aiCore/clients/gemini/GeminiAPIClient'
|
import { GeminiAPIClient } from '@renderer/aiCore/legacy/clients/gemini/GeminiAPIClient'
|
||||||
import { VertexAPIClient } from '@renderer/aiCore/clients/gemini/VertexAPIClient'
|
import { VertexAPIClient } from '@renderer/aiCore/legacy/clients/gemini/VertexAPIClient'
|
||||||
import { NewAPIClient } from '@renderer/aiCore/clients/newapi/NewAPIClient'
|
import { NewAPIClient } from '@renderer/aiCore/legacy/clients/newapi/NewAPIClient'
|
||||||
import { OpenAIAPIClient } from '@renderer/aiCore/clients/openai/OpenAIApiClient'
|
import { OpenAIAPIClient } from '@renderer/aiCore/legacy/clients/openai/OpenAIApiClient'
|
||||||
import { OpenAIResponseAPIClient } from '@renderer/aiCore/clients/openai/OpenAIResponseAPIClient'
|
import { OpenAIResponseAPIClient } from '@renderer/aiCore/legacy/clients/openai/OpenAIResponseAPIClient'
|
||||||
import { EndpointType, Model, Provider } from '@renderer/types'
|
import { EndpointType, Model, Provider } from '@renderer/types'
|
||||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||||
|
|
||||||
@@ -22,6 +22,7 @@ vi.mock('@renderer/config/models', () => ({
|
|||||||
anthropic: [],
|
anthropic: [],
|
||||||
gemini: []
|
gemini: []
|
||||||
},
|
},
|
||||||
|
isOpenAIModel: vi.fn().mockReturnValue(true),
|
||||||
isOpenAILLMModel: vi.fn().mockReturnValue(true),
|
isOpenAILLMModel: vi.fn().mockReturnValue(true),
|
||||||
isOpenAIChatCompletionOnlyModel: vi.fn().mockReturnValue(false),
|
isOpenAIChatCompletionOnlyModel: vi.fn().mockReturnValue(false),
|
||||||
isAnthropicLLMModel: vi.fn().mockReturnValue(false),
|
isAnthropicLLMModel: vi.fn().mockReturnValue(false),
|
||||||
@@ -80,6 +81,7 @@ vi.mock('@logger', () => ({
|
|||||||
}
|
}
|
||||||
}))
|
}))
|
||||||
|
|
||||||
|
// 到底是谁想出来的在服务层调用 React Hook ?????????
|
||||||
// Mock additional services and hooks that might be imported
|
// Mock additional services and hooks that might be imported
|
||||||
vi.mock('@renderer/hooks/useVertexAI', () => ({
|
vi.mock('@renderer/hooks/useVertexAI', () => ({
|
||||||
getVertexAILocation: vi.fn().mockReturnValue('us-central1'),
|
getVertexAILocation: vi.fn().mockReturnValue('us-central1'),
|
||||||
@@ -87,7 +89,9 @@ vi.mock('@renderer/hooks/useVertexAI', () => ({
|
|||||||
getVertexAIServiceAccount: vi.fn().mockReturnValue({
|
getVertexAIServiceAccount: vi.fn().mockReturnValue({
|
||||||
privateKey: 'test-key',
|
privateKey: 'test-key',
|
||||||
clientEmail: 'test@example.com'
|
clientEmail: 'test@example.com'
|
||||||
})
|
}),
|
||||||
|
isVertexAIConfigured: vi.fn().mockReturnValue(true),
|
||||||
|
isVertexProvider: vi.fn().mockReturnValue(true)
|
||||||
}))
|
}))
|
||||||
|
|
||||||
vi.mock('@renderer/hooks/useSettings', () => ({
|
vi.mock('@renderer/hooks/useSettings', () => ({
|
||||||
@@ -131,7 +135,7 @@ vi.mock('@google-cloud/vertexai', () => ({
|
|||||||
}))
|
}))
|
||||||
|
|
||||||
// Mock the circular dependency between VertexAPIClient and AnthropicVertexClient
|
// Mock the circular dependency between VertexAPIClient and AnthropicVertexClient
|
||||||
vi.mock('@renderer/aiCore/clients/anthropic/AnthropicVertexClient', () => {
|
vi.mock('@renderer/aiCore/legacy/clients/anthropic/AnthropicVertexClient', () => {
|
||||||
const MockAnthropicVertexClient = vi.fn()
|
const MockAnthropicVertexClient = vi.fn()
|
||||||
MockAnthropicVertexClient.prototype.getClientCompatibilityType = vi.fn().mockReturnValue(['AnthropicVertexAPIClient'])
|
MockAnthropicVertexClient.prototype.getClientCompatibilityType = vi.fn().mockReturnValue(['AnthropicVertexAPIClient'])
|
||||||
return {
|
return {
|
||||||
@@ -25,7 +25,6 @@ import {
|
|||||||
import { MessageStream } from '@anthropic-ai/sdk/resources/messages/messages'
|
import { MessageStream } from '@anthropic-ai/sdk/resources/messages/messages'
|
||||||
import AnthropicVertex from '@anthropic-ai/vertex-sdk'
|
import AnthropicVertex from '@anthropic-ai/vertex-sdk'
|
||||||
import { loggerService } from '@logger'
|
import { loggerService } from '@logger'
|
||||||
import { GenericChunk } from '@renderer/aiCore/middleware/schemas'
|
|
||||||
import { DEFAULT_MAX_TOKENS } from '@renderer/config/constant'
|
import { DEFAULT_MAX_TOKENS } from '@renderer/config/constant'
|
||||||
import { findTokenLimit, isClaudeReasoningModel, isReasoningModel, isWebSearchModel } from '@renderer/config/models'
|
import { findTokenLimit, isClaudeReasoningModel, isReasoningModel, isWebSearchModel } from '@renderer/config/models'
|
||||||
import { getAssistantSettings } from '@renderer/services/AssistantService'
|
import { getAssistantSettings } from '@renderer/services/AssistantService'
|
||||||
@@ -64,13 +63,14 @@ import {
|
|||||||
import { addImageFileToContents } from '@renderer/utils/formats'
|
import { addImageFileToContents } from '@renderer/utils/formats'
|
||||||
import {
|
import {
|
||||||
anthropicToolUseToMcpTool,
|
anthropicToolUseToMcpTool,
|
||||||
isEnabledToolUse,
|
isSupportedToolUse,
|
||||||
mcpToolCallResponseToAnthropicMessage,
|
mcpToolCallResponseToAnthropicMessage,
|
||||||
mcpToolsToAnthropicTools
|
mcpToolsToAnthropicTools
|
||||||
} from '@renderer/utils/mcp-tools'
|
} from '@renderer/utils/mcp-tools'
|
||||||
import { findFileBlocks, findImageBlocks } from '@renderer/utils/messageUtils/find'
|
import { findFileBlocks, findImageBlocks } from '@renderer/utils/messageUtils/find'
|
||||||
import { t } from 'i18next'
|
import { t } from 'i18next'
|
||||||
|
|
||||||
|
import { GenericChunk } from '../../middleware/schemas'
|
||||||
import { BaseApiClient } from '../BaseApiClient'
|
import { BaseApiClient } from '../BaseApiClient'
|
||||||
import { AnthropicStreamListener, RawStreamListener, RequestTransformer, ResponseChunkTransformer } from '../types'
|
import { AnthropicStreamListener, RawStreamListener, RequestTransformer, ResponseChunkTransformer } from '../types'
|
||||||
|
|
||||||
@@ -457,7 +457,7 @@ export class AnthropicAPIClient extends BaseApiClient<
|
|||||||
const { tools } = this.setupToolsConfig({
|
const { tools } = this.setupToolsConfig({
|
||||||
mcpTools: mcpTools,
|
mcpTools: mcpTools,
|
||||||
model,
|
model,
|
||||||
enableToolUse: isEnabledToolUse(assistant)
|
enableToolUse: isSupportedToolUse(assistant)
|
||||||
})
|
})
|
||||||
|
|
||||||
const systemMessage: TextBlockParam | undefined = systemPrompt
|
const systemMessage: TextBlockParam | undefined = systemPrompt
|
||||||
@@ -6,7 +6,7 @@ import {
|
|||||||
InvokeModelWithResponseStreamCommand
|
InvokeModelWithResponseStreamCommand
|
||||||
} from '@aws-sdk/client-bedrock-runtime'
|
} from '@aws-sdk/client-bedrock-runtime'
|
||||||
import { loggerService } from '@logger'
|
import { loggerService } from '@logger'
|
||||||
import { GenericChunk } from '@renderer/aiCore/middleware/schemas'
|
import { GenericChunk } from '@renderer/aiCore/legacy/middleware/schemas'
|
||||||
import { DEFAULT_MAX_TOKENS } from '@renderer/config/constant'
|
import { DEFAULT_MAX_TOKENS } from '@renderer/config/constant'
|
||||||
import { findTokenLimit, isReasoningModel } from '@renderer/config/models'
|
import { findTokenLimit, isReasoningModel } from '@renderer/config/models'
|
||||||
import {
|
import {
|
||||||
@@ -50,7 +50,7 @@ import {
|
|||||||
import { convertBase64ImageToAwsBedrockFormat } from '@renderer/utils/aws-bedrock-utils'
|
import { convertBase64ImageToAwsBedrockFormat } from '@renderer/utils/aws-bedrock-utils'
|
||||||
import {
|
import {
|
||||||
awsBedrockToolUseToMcpTool,
|
awsBedrockToolUseToMcpTool,
|
||||||
isEnabledToolUse,
|
isSupportedToolUse,
|
||||||
mcpToolCallResponseToAwsBedrockMessage,
|
mcpToolCallResponseToAwsBedrockMessage,
|
||||||
mcpToolsToAwsBedrockTools
|
mcpToolsToAwsBedrockTools
|
||||||
} from '@renderer/utils/mcp-tools'
|
} from '@renderer/utils/mcp-tools'
|
||||||
@@ -739,7 +739,7 @@ export class AwsBedrockAPIClient extends BaseApiClient<
|
|||||||
const { tools } = this.setupToolsConfig({
|
const { tools } = this.setupToolsConfig({
|
||||||
mcpTools: mcpTools,
|
mcpTools: mcpTools,
|
||||||
model,
|
model,
|
||||||
enableToolUse: isEnabledToolUse(assistant)
|
enableToolUse: isSupportedToolUse(assistant)
|
||||||
})
|
})
|
||||||
|
|
||||||
// 3. 处理消息
|
// 3. 处理消息
|
||||||
@@ -18,7 +18,6 @@ import {
|
|||||||
} from '@google/genai'
|
} from '@google/genai'
|
||||||
import { loggerService } from '@logger'
|
import { loggerService } from '@logger'
|
||||||
import { nanoid } from '@reduxjs/toolkit'
|
import { nanoid } from '@reduxjs/toolkit'
|
||||||
import { GenericChunk } from '@renderer/aiCore/middleware/schemas'
|
|
||||||
import {
|
import {
|
||||||
findTokenLimit,
|
findTokenLimit,
|
||||||
GEMINI_FLASH_MODEL_REGEX,
|
GEMINI_FLASH_MODEL_REGEX,
|
||||||
@@ -55,7 +54,7 @@ import {
|
|||||||
import { isToolUseModeFunction } from '@renderer/utils/assistant'
|
import { isToolUseModeFunction } from '@renderer/utils/assistant'
|
||||||
import {
|
import {
|
||||||
geminiFunctionCallToMcpTool,
|
geminiFunctionCallToMcpTool,
|
||||||
isEnabledToolUse,
|
isSupportedToolUse,
|
||||||
mcpToolCallResponseToGeminiMessage,
|
mcpToolCallResponseToGeminiMessage,
|
||||||
mcpToolsToGeminiTools
|
mcpToolsToGeminiTools
|
||||||
} from '@renderer/utils/mcp-tools'
|
} from '@renderer/utils/mcp-tools'
|
||||||
@@ -63,6 +62,7 @@ import { findFileBlocks, findImageBlocks, getMainTextContent } from '@renderer/u
|
|||||||
import { defaultTimeout, MB } from '@shared/config/constant'
|
import { defaultTimeout, MB } from '@shared/config/constant'
|
||||||
import { t } from 'i18next'
|
import { t } from 'i18next'
|
||||||
|
|
||||||
|
import { GenericChunk } from '../../middleware/schemas'
|
||||||
import { BaseApiClient } from '../BaseApiClient'
|
import { BaseApiClient } from '../BaseApiClient'
|
||||||
import { RequestTransformer, ResponseChunkTransformer } from '../types'
|
import { RequestTransformer, ResponseChunkTransformer } from '../types'
|
||||||
|
|
||||||
@@ -454,7 +454,7 @@ export class GeminiAPIClient extends BaseApiClient<
|
|||||||
const { tools } = this.setupToolsConfig({
|
const { tools } = this.setupToolsConfig({
|
||||||
mcpTools,
|
mcpTools,
|
||||||
model,
|
model,
|
||||||
enableToolUse: isEnabledToolUse(assistant)
|
enableToolUse: isSupportedToolUse(assistant)
|
||||||
})
|
})
|
||||||
|
|
||||||
let messageContents: Content = { role: 'user', parts: [] } // Initialize messageContents
|
let messageContents: Content = { role: 'user', parts: [] } // Initialize messageContents
|
||||||
@@ -1,7 +1,7 @@
|
|||||||
import { GoogleGenAI } from '@google/genai'
|
import { GoogleGenAI } from '@google/genai'
|
||||||
import { loggerService } from '@logger'
|
import { loggerService } from '@logger'
|
||||||
import { getVertexAILocation, getVertexAIProjectId, getVertexAIServiceAccount } from '@renderer/hooks/useVertexAI'
|
import { createVertexProvider, isVertexAIConfigured, isVertexProvider } from '@renderer/hooks/useVertexAI'
|
||||||
import { Model, Provider } from '@renderer/types'
|
import { Model, Provider, VertexProvider } from '@renderer/types'
|
||||||
import { isEmpty } from 'lodash'
|
import { isEmpty } from 'lodash'
|
||||||
|
|
||||||
import { AnthropicVertexClient } from '../anthropic/AnthropicVertexClient'
|
import { AnthropicVertexClient } from '../anthropic/AnthropicVertexClient'
|
||||||
@@ -12,10 +12,21 @@ export class VertexAPIClient extends GeminiAPIClient {
|
|||||||
private authHeaders?: Record<string, string>
|
private authHeaders?: Record<string, string>
|
||||||
private authHeadersExpiry?: number
|
private authHeadersExpiry?: number
|
||||||
private anthropicVertexClient: AnthropicVertexClient
|
private anthropicVertexClient: AnthropicVertexClient
|
||||||
|
private vertexProvider: VertexProvider
|
||||||
|
|
||||||
constructor(provider: Provider) {
|
constructor(provider: Provider) {
|
||||||
super(provider)
|
super(provider)
|
||||||
|
// 检查 VertexAI 配置
|
||||||
|
if (!isVertexAIConfigured()) {
|
||||||
|
throw new Error('VertexAI is not configured. Please configure project, location and service account credentials.')
|
||||||
|
}
|
||||||
this.anthropicVertexClient = new AnthropicVertexClient(provider)
|
this.anthropicVertexClient = new AnthropicVertexClient(provider)
|
||||||
|
// 如果传入的是普通 Provider,转换为 VertexProvider
|
||||||
|
if (isVertexProvider(provider)) {
|
||||||
|
this.vertexProvider = provider
|
||||||
|
} else {
|
||||||
|
this.vertexProvider = createVertexProvider(provider)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
override getClientCompatibilityType(model?: Model): string[] {
|
override getClientCompatibilityType(model?: Model): string[] {
|
||||||
@@ -56,11 +67,9 @@ export class VertexAPIClient extends GeminiAPIClient {
|
|||||||
return this.sdkInstance
|
return this.sdkInstance
|
||||||
}
|
}
|
||||||
|
|
||||||
const serviceAccount = getVertexAIServiceAccount()
|
const { googleCredentials, project, location } = this.vertexProvider
|
||||||
const projectId = getVertexAIProjectId()
|
|
||||||
const location = getVertexAILocation()
|
|
||||||
|
|
||||||
if (!serviceAccount.privateKey || !serviceAccount.clientEmail || !projectId || !location) {
|
if (!googleCredentials.privateKey || !googleCredentials.clientEmail || !project || !location) {
|
||||||
throw new Error('Vertex AI settings are not configured')
|
throw new Error('Vertex AI settings are not configured')
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -68,7 +77,7 @@ export class VertexAPIClient extends GeminiAPIClient {
|
|||||||
|
|
||||||
this.sdkInstance = new GoogleGenAI({
|
this.sdkInstance = new GoogleGenAI({
|
||||||
vertexai: true,
|
vertexai: true,
|
||||||
project: projectId,
|
project: project,
|
||||||
location: location,
|
location: location,
|
||||||
httpOptions: {
|
httpOptions: {
|
||||||
apiVersion: this.getApiVersion(),
|
apiVersion: this.getApiVersion(),
|
||||||
@@ -84,11 +93,10 @@ export class VertexAPIClient extends GeminiAPIClient {
|
|||||||
* 获取认证头,如果配置了 service account 则从主进程获取
|
* 获取认证头,如果配置了 service account 则从主进程获取
|
||||||
*/
|
*/
|
||||||
private async getServiceAccountAuthHeaders(): Promise<Record<string, string> | undefined> {
|
private async getServiceAccountAuthHeaders(): Promise<Record<string, string> | undefined> {
|
||||||
const serviceAccount = getVertexAIServiceAccount()
|
const { googleCredentials, project } = this.vertexProvider
|
||||||
const projectId = getVertexAIProjectId()
|
|
||||||
|
|
||||||
// 检查是否配置了 service account
|
// 检查是否配置了 service account
|
||||||
if (!serviceAccount.privateKey || !serviceAccount.clientEmail || !projectId) {
|
if (!googleCredentials.privateKey || !googleCredentials.clientEmail || !project) {
|
||||||
return undefined
|
return undefined
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -101,10 +109,10 @@ export class VertexAPIClient extends GeminiAPIClient {
|
|||||||
try {
|
try {
|
||||||
// 从主进程获取认证头
|
// 从主进程获取认证头
|
||||||
this.authHeaders = await window.api.vertexAI.getAuthHeaders({
|
this.authHeaders = await window.api.vertexAI.getAuthHeaders({
|
||||||
projectId,
|
projectId: project,
|
||||||
serviceAccount: {
|
serviceAccount: {
|
||||||
privateKey: serviceAccount.privateKey,
|
privateKey: googleCredentials.privateKey,
|
||||||
clientEmail: serviceAccount.clientEmail
|
clientEmail: googleCredentials.clientEmail
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
@@ -125,11 +133,10 @@ export class VertexAPIClient extends GeminiAPIClient {
|
|||||||
this.authHeaders = undefined
|
this.authHeaders = undefined
|
||||||
this.authHeadersExpiry = undefined
|
this.authHeadersExpiry = undefined
|
||||||
|
|
||||||
const serviceAccount = getVertexAIServiceAccount()
|
const { googleCredentials, project } = this.vertexProvider
|
||||||
const projectId = getVertexAIProjectId()
|
|
||||||
|
|
||||||
if (projectId && serviceAccount.clientEmail) {
|
if (project && googleCredentials.clientEmail) {
|
||||||
window.api.vertexAI.clearAuthCache(projectId, serviceAccount.clientEmail)
|
window.api.vertexAI.clearAuthCache(project, googleCredentials.clientEmail)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -71,7 +71,7 @@ import {
|
|||||||
} from '@renderer/types/sdk'
|
} from '@renderer/types/sdk'
|
||||||
import { addImageFileToContents } from '@renderer/utils/formats'
|
import { addImageFileToContents } from '@renderer/utils/formats'
|
||||||
import {
|
import {
|
||||||
isEnabledToolUse,
|
isSupportedToolUse,
|
||||||
mcpToolCallResponseToOpenAICompatibleMessage,
|
mcpToolCallResponseToOpenAICompatibleMessage,
|
||||||
mcpToolsToOpenAIChatTools,
|
mcpToolsToOpenAIChatTools,
|
||||||
openAIToolsToMcpTool
|
openAIToolsToMcpTool
|
||||||
@@ -611,7 +611,7 @@ export class OpenAIAPIClient extends OpenAIBaseClient<
|
|||||||
const { tools } = this.setupToolsConfig({
|
const { tools } = this.setupToolsConfig({
|
||||||
mcpTools: mcpTools,
|
mcpTools: mcpTools,
|
||||||
model,
|
model,
|
||||||
enableToolUse: isEnabledToolUse(assistant)
|
enableToolUse: isSupportedToolUse(assistant)
|
||||||
})
|
})
|
||||||
|
|
||||||
// 3. 处理用户消息
|
// 3. 处理用户消息
|
||||||
@@ -1,6 +1,6 @@
|
|||||||
import { loggerService } from '@logger'
|
import { loggerService } from '@logger'
|
||||||
import { GenericChunk } from '@renderer/aiCore/middleware/schemas'
|
import { GenericChunk } from '@renderer/aiCore/legacy/middleware/schemas'
|
||||||
import { CompletionsContext } from '@renderer/aiCore/middleware/types'
|
import { CompletionsContext } from '@renderer/aiCore/legacy/middleware/types'
|
||||||
import {
|
import {
|
||||||
isGPT5SeriesModel,
|
isGPT5SeriesModel,
|
||||||
isOpenAIChatCompletionOnlyModel,
|
isOpenAIChatCompletionOnlyModel,
|
||||||
@@ -36,7 +36,7 @@ import {
|
|||||||
} from '@renderer/types/sdk'
|
} from '@renderer/types/sdk'
|
||||||
import { addImageFileToContents } from '@renderer/utils/formats'
|
import { addImageFileToContents } from '@renderer/utils/formats'
|
||||||
import {
|
import {
|
||||||
isEnabledToolUse,
|
isSupportedToolUse,
|
||||||
mcpToolCallResponseToOpenAIMessage,
|
mcpToolCallResponseToOpenAIMessage,
|
||||||
mcpToolsToOpenAIResponseTools,
|
mcpToolsToOpenAIResponseTools,
|
||||||
openAIToolsToMcpTool
|
openAIToolsToMcpTool
|
||||||
@@ -388,7 +388,7 @@ export class OpenAIResponseAPIClient extends OpenAIBaseClient<
|
|||||||
const { tools: extraTools } = this.setupToolsConfig({
|
const { tools: extraTools } = this.setupToolsConfig({
|
||||||
mcpTools: mcpTools,
|
mcpTools: mcpTools,
|
||||||
model,
|
model,
|
||||||
enableToolUse: isEnabledToolUse(assistant)
|
enableToolUse: isSupportedToolUse(assistant)
|
||||||
})
|
})
|
||||||
|
|
||||||
systemMessageContent.push(systemMessageInput)
|
systemMessageContent.push(systemMessageInput)
|
||||||
189
src/renderer/src/aiCore/legacy/index.ts
Normal file
189
src/renderer/src/aiCore/legacy/index.ts
Normal file
@@ -0,0 +1,189 @@
|
|||||||
|
import { loggerService } from '@logger'
|
||||||
|
import { ApiClientFactory } from '@renderer/aiCore/legacy/clients/ApiClientFactory'
|
||||||
|
import { BaseApiClient } from '@renderer/aiCore/legacy/clients/BaseApiClient'
|
||||||
|
import { isDedicatedImageGenerationModel, isFunctionCallingModel } from '@renderer/config/models'
|
||||||
|
import { getProviderByModel } from '@renderer/services/AssistantService'
|
||||||
|
import { withSpanResult } from '@renderer/services/SpanManagerService'
|
||||||
|
import { StartSpanParams } from '@renderer/trace/types/ModelSpanEntity'
|
||||||
|
import type { GenerateImageParams, Model, Provider } from '@renderer/types'
|
||||||
|
import type { RequestOptions, SdkModel } from '@renderer/types/sdk'
|
||||||
|
import { isSupportedToolUse } from '@renderer/utils/mcp-tools'
|
||||||
|
|
||||||
|
import { AihubmixAPIClient } from './clients/aihubmix/AihubmixAPIClient'
|
||||||
|
import { VertexAPIClient } from './clients/gemini/VertexAPIClient'
|
||||||
|
import { NewAPIClient } from './clients/newapi/NewAPIClient'
|
||||||
|
import { OpenAIResponseAPIClient } from './clients/openai/OpenAIResponseAPIClient'
|
||||||
|
import { CompletionsMiddlewareBuilder } from './middleware/builder'
|
||||||
|
import { MIDDLEWARE_NAME as AbortHandlerMiddlewareName } from './middleware/common/AbortHandlerMiddleware'
|
||||||
|
import { MIDDLEWARE_NAME as ErrorHandlerMiddlewareName } from './middleware/common/ErrorHandlerMiddleware'
|
||||||
|
import { MIDDLEWARE_NAME as FinalChunkConsumerMiddlewareName } from './middleware/common/FinalChunkConsumerMiddleware'
|
||||||
|
import { applyCompletionsMiddlewares } from './middleware/composer'
|
||||||
|
import { MIDDLEWARE_NAME as McpToolChunkMiddlewareName } from './middleware/core/McpToolChunkMiddleware'
|
||||||
|
import { MIDDLEWARE_NAME as RawStreamListenerMiddlewareName } from './middleware/core/RawStreamListenerMiddleware'
|
||||||
|
import { MIDDLEWARE_NAME as WebSearchMiddlewareName } from './middleware/core/WebSearchMiddleware'
|
||||||
|
import { MIDDLEWARE_NAME as ImageGenerationMiddlewareName } from './middleware/feat/ImageGenerationMiddleware'
|
||||||
|
import { MIDDLEWARE_NAME as ThinkingTagExtractionMiddlewareName } from './middleware/feat/ThinkingTagExtractionMiddleware'
|
||||||
|
import { MIDDLEWARE_NAME as ToolUseExtractionMiddlewareName } from './middleware/feat/ToolUseExtractionMiddleware'
|
||||||
|
import { MiddlewareRegistry } from './middleware/register'
|
||||||
|
import type { CompletionsParams, CompletionsResult } from './middleware/schemas'
|
||||||
|
|
||||||
|
const logger = loggerService.withContext('AiProvider')
|
||||||
|
|
||||||
|
export default class AiProvider {
|
||||||
|
private apiClient: BaseApiClient
|
||||||
|
|
||||||
|
constructor(provider: Provider) {
|
||||||
|
// Use the new ApiClientFactory to get a BaseApiClient instance
|
||||||
|
this.apiClient = ApiClientFactory.create(provider)
|
||||||
|
}
|
||||||
|
|
||||||
|
public async completions(params: CompletionsParams, options?: RequestOptions): Promise<CompletionsResult> {
|
||||||
|
// 1. 根据模型识别正确的客户端
|
||||||
|
const model = params.assistant.model
|
||||||
|
if (!model) {
|
||||||
|
return Promise.reject(new Error('Model is required'))
|
||||||
|
}
|
||||||
|
|
||||||
|
// 根据client类型选择合适的处理方式
|
||||||
|
let client: BaseApiClient
|
||||||
|
|
||||||
|
if (this.apiClient instanceof AihubmixAPIClient) {
|
||||||
|
// AihubmixAPIClient: 根据模型选择合适的子client
|
||||||
|
client = this.apiClient.getClientForModel(model)
|
||||||
|
if (client instanceof OpenAIResponseAPIClient) {
|
||||||
|
client = client.getClient(model) as BaseApiClient
|
||||||
|
}
|
||||||
|
} else if (this.apiClient instanceof NewAPIClient) {
|
||||||
|
client = this.apiClient.getClientForModel(model)
|
||||||
|
if (client instanceof OpenAIResponseAPIClient) {
|
||||||
|
client = client.getClient(model) as BaseApiClient
|
||||||
|
}
|
||||||
|
} else if (this.apiClient instanceof OpenAIResponseAPIClient) {
|
||||||
|
// OpenAIResponseAPIClient: 根据模型特征选择API类型
|
||||||
|
client = this.apiClient.getClient(model) as BaseApiClient
|
||||||
|
} else if (this.apiClient instanceof VertexAPIClient) {
|
||||||
|
client = this.apiClient.getClient(model) as BaseApiClient
|
||||||
|
} else {
|
||||||
|
// 其他client直接使用
|
||||||
|
client = this.apiClient
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2. 构建中间件链
|
||||||
|
const builder = CompletionsMiddlewareBuilder.withDefaults()
|
||||||
|
// images api
|
||||||
|
if (isDedicatedImageGenerationModel(model)) {
|
||||||
|
builder.clear()
|
||||||
|
builder
|
||||||
|
.add(MiddlewareRegistry[FinalChunkConsumerMiddlewareName])
|
||||||
|
.add(MiddlewareRegistry[ErrorHandlerMiddlewareName])
|
||||||
|
.add(MiddlewareRegistry[AbortHandlerMiddlewareName])
|
||||||
|
.add(MiddlewareRegistry[ImageGenerationMiddlewareName])
|
||||||
|
} else {
|
||||||
|
// Existing logic for other models
|
||||||
|
logger.silly('Builder Params', params)
|
||||||
|
// 使用兼容性类型检查,避免typescript类型收窄和装饰器模式的问题
|
||||||
|
const clientTypes = client.getClientCompatibilityType(model)
|
||||||
|
const isOpenAICompatible =
|
||||||
|
clientTypes.includes('OpenAIAPIClient') || clientTypes.includes('OpenAIResponseAPIClient')
|
||||||
|
if (!isOpenAICompatible) {
|
||||||
|
logger.silly('ThinkingTagExtractionMiddleware is removed')
|
||||||
|
builder.remove(ThinkingTagExtractionMiddlewareName)
|
||||||
|
}
|
||||||
|
|
||||||
|
const isAnthropicOrOpenAIResponseCompatible =
|
||||||
|
clientTypes.includes('AnthropicAPIClient') ||
|
||||||
|
clientTypes.includes('OpenAIResponseAPIClient') ||
|
||||||
|
clientTypes.includes('AnthropicVertexAPIClient')
|
||||||
|
if (!isAnthropicOrOpenAIResponseCompatible) {
|
||||||
|
logger.silly('RawStreamListenerMiddleware is removed')
|
||||||
|
builder.remove(RawStreamListenerMiddlewareName)
|
||||||
|
}
|
||||||
|
if (!params.enableWebSearch) {
|
||||||
|
logger.silly('WebSearchMiddleware is removed')
|
||||||
|
builder.remove(WebSearchMiddlewareName)
|
||||||
|
}
|
||||||
|
if (!params.mcpTools?.length) {
|
||||||
|
builder.remove(ToolUseExtractionMiddlewareName)
|
||||||
|
logger.silly('ToolUseExtractionMiddleware is removed')
|
||||||
|
builder.remove(McpToolChunkMiddlewareName)
|
||||||
|
logger.silly('McpToolChunkMiddleware is removed')
|
||||||
|
}
|
||||||
|
if (isSupportedToolUse(params.assistant) && isFunctionCallingModel(model)) {
|
||||||
|
builder.remove(ToolUseExtractionMiddlewareName)
|
||||||
|
logger.silly('ToolUseExtractionMiddleware is removed')
|
||||||
|
}
|
||||||
|
if (params.callType !== 'chat' && params.callType !== 'check' && params.callType !== 'translate') {
|
||||||
|
logger.silly('AbortHandlerMiddleware is removed')
|
||||||
|
builder.remove(AbortHandlerMiddlewareName)
|
||||||
|
}
|
||||||
|
if (params.callType === 'test') {
|
||||||
|
builder.remove(ErrorHandlerMiddlewareName)
|
||||||
|
logger.silly('ErrorHandlerMiddleware is removed')
|
||||||
|
builder.remove(FinalChunkConsumerMiddlewareName)
|
||||||
|
logger.silly('FinalChunkConsumerMiddleware is removed')
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const middlewares = builder.build()
|
||||||
|
logger.silly(
|
||||||
|
'middlewares',
|
||||||
|
middlewares.map((m) => m.name)
|
||||||
|
)
|
||||||
|
|
||||||
|
// 3. Create the wrapped SDK method with middlewares
|
||||||
|
const wrappedCompletionMethod = applyCompletionsMiddlewares(client, client.createCompletions, middlewares)
|
||||||
|
|
||||||
|
// 4. Execute the wrapped method with the original params
|
||||||
|
const result = wrappedCompletionMethod(params, options)
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
public async completionsForTrace(params: CompletionsParams, options?: RequestOptions): Promise<CompletionsResult> {
|
||||||
|
const traceName = params.assistant.model?.name
|
||||||
|
? `${params.assistant.model?.name}.${params.callType}`
|
||||||
|
: `LLM.${params.callType}`
|
||||||
|
|
||||||
|
const traceParams: StartSpanParams = {
|
||||||
|
name: traceName,
|
||||||
|
tag: 'LLM',
|
||||||
|
topicId: params.topicId || '',
|
||||||
|
modelName: params.assistant.model?.name
|
||||||
|
}
|
||||||
|
|
||||||
|
return await withSpanResult(this.completions.bind(this), traceParams, params, options)
|
||||||
|
}
|
||||||
|
|
||||||
|
public async models(): Promise<SdkModel[]> {
|
||||||
|
return this.apiClient.listModels()
|
||||||
|
}
|
||||||
|
|
||||||
|
public async getEmbeddingDimensions(model: Model): Promise<number> {
|
||||||
|
try {
|
||||||
|
// Use the SDK instance to test embedding capabilities
|
||||||
|
if (this.apiClient instanceof OpenAIResponseAPIClient && getProviderByModel(model).type === 'azure-openai') {
|
||||||
|
this.apiClient = this.apiClient.getClient(model) as BaseApiClient
|
||||||
|
}
|
||||||
|
const dimensions = await this.apiClient.getEmbeddingDimensions(model)
|
||||||
|
return dimensions
|
||||||
|
} catch (error) {
|
||||||
|
logger.error('Error getting embedding dimensions:', error as Error)
|
||||||
|
throw error
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public async generateImage(params: GenerateImageParams): Promise<string[]> {
|
||||||
|
if (this.apiClient instanceof AihubmixAPIClient) {
|
||||||
|
const client = this.apiClient.getClientForModel({ id: params.model } as Model)
|
||||||
|
return client.generateImage(params)
|
||||||
|
}
|
||||||
|
return this.apiClient.generateImage(params)
|
||||||
|
}
|
||||||
|
|
||||||
|
public getBaseURL(): string {
|
||||||
|
return this.apiClient.getBaseURL()
|
||||||
|
}
|
||||||
|
|
||||||
|
public getApiKey(): string {
|
||||||
|
return this.apiClient.getApiKey()
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,5 +1,5 @@
|
|||||||
import { loggerService } from '@logger'
|
import { loggerService } from '@logger'
|
||||||
import { MCPCallToolResponse, MCPTool, MCPToolResponse, Model, ToolCallResponse } from '@renderer/types'
|
import { MCPCallToolResponse, MCPTool, MCPToolResponse, Model } from '@renderer/types'
|
||||||
import { ChunkType, MCPToolCreatedChunk } from '@renderer/types/chunk'
|
import { ChunkType, MCPToolCreatedChunk } from '@renderer/types/chunk'
|
||||||
import { SdkMessageParam, SdkRawOutput, SdkToolCall } from '@renderer/types/sdk'
|
import { SdkMessageParam, SdkRawOutput, SdkToolCall } from '@renderer/types/sdk'
|
||||||
import {
|
import {
|
||||||
@@ -230,7 +230,7 @@ async function executeToolCalls(
|
|||||||
model: Model,
|
model: Model,
|
||||||
topicId?: string
|
topicId?: string
|
||||||
): Promise<{ toolResults: SdkMessageParam[]; confirmedToolCalls: SdkToolCall[] }> {
|
): Promise<{ toolResults: SdkMessageParam[]; confirmedToolCalls: SdkToolCall[] }> {
|
||||||
const mcpToolResponses: ToolCallResponse[] = toolCalls
|
const mcpToolResponses: MCPToolResponse[] = toolCalls
|
||||||
.map((toolCall) => {
|
.map((toolCall) => {
|
||||||
const mcpTool = ctx.apiClientInstance.convertSdkToolCallToMcp(toolCall, mcpTools)
|
const mcpTool = ctx.apiClientInstance.convertSdkToolCallToMcp(toolCall, mcpTools)
|
||||||
if (!mcpTool) {
|
if (!mcpTool) {
|
||||||
@@ -238,7 +238,7 @@ async function executeToolCalls(
|
|||||||
}
|
}
|
||||||
return ctx.apiClientInstance.convertSdkToolCallToMcpToolResponse(toolCall, mcpTool)
|
return ctx.apiClientInstance.convertSdkToolCallToMcpToolResponse(toolCall, mcpTool)
|
||||||
})
|
})
|
||||||
.filter((t): t is ToolCallResponse => typeof t !== 'undefined')
|
.filter((t): t is MCPToolResponse => typeof t !== 'undefined')
|
||||||
|
|
||||||
if (mcpToolResponses.length === 0) {
|
if (mcpToolResponses.length === 0) {
|
||||||
logger.warn(`No valid MCP tool responses to execute`)
|
logger.warn(`No valid MCP tool responses to execute`)
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
import { AnthropicAPIClient } from '@renderer/aiCore/clients/anthropic/AnthropicAPIClient'
|
import { AnthropicAPIClient } from '@renderer/aiCore/legacy/clients/anthropic/AnthropicAPIClient'
|
||||||
import { AnthropicSdkRawChunk, AnthropicSdkRawOutput } from '@renderer/types/sdk'
|
import { AnthropicSdkRawChunk, AnthropicSdkRawOutput } from '@renderer/types/sdk'
|
||||||
|
|
||||||
import { AnthropicStreamListener } from '../../clients/types'
|
import { AnthropicStreamListener } from '../../clients/types'
|
||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user