Compare commits

Comparing branches feat/custo ... fix/quick- — 268 commits.
.github/workflows/nightly-build.yml (2 changes)

@@ -53,7 +53,7 @@ jobs:
       - name: Check out Git repository
         uses: actions/checkout@v4
         with:
-          ref: develop
+          ref: main

       - name: Install Node.js
         uses: actions/setup-node@v4
.github/workflows/pr-ci.yml (2 changes)

@@ -44,4 +44,4 @@ jobs:
         run: yarn build:check

       - name: Lint Check
-        run: yarn lint
+        run: yarn test:lint
.github/workflows/release.yml (39 changes)

@@ -27,7 +27,7 @@ jobs:
       - name: Check out Git repository
         uses: actions/checkout@v4
         with:
-          ref: main
+          fetch-depth: 0

       - name: Get release tag
         id: get-tag
@@ -113,5 +113,40 @@ jobs:
           allowUpdates: true
           makeLatest: false
           tag: ${{ steps.get-tag.outputs.tag }}
-          artifacts: 'dist/*.exe,dist/*.zip,dist/*.dmg,dist/*.AppImage,dist/*.snap,dist/*.deb,dist/*.rpm,dist/*.tar.gz,dist/latest*.yml,dist/*.blockmap'
+          artifacts: 'dist/*.exe,dist/*.zip,dist/*.dmg,dist/*.AppImage,dist/*.snap,dist/*.deb,dist/*.rpm,dist/*.tar.gz,dist/latest*.yml,dist/rc*.yml,dist/*.blockmap'
           token: ${{ secrets.GITHUB_TOKEN }}
 
+  dispatch-docs-update:
+    needs: release
+    if: success() && github.repository == 'CherryHQ/cherry-studio' # make sure all builds succeeded and we are running in the main repository
+    runs-on: ubuntu-latest
+    steps:
+      - name: Get release tag
+        id: get-tag
+        shell: bash
+        run: |
+          if [ "${{ github.event_name }}" = "workflow_dispatch" ]; then
+            echo "tag=${{ github.event.inputs.tag }}" >> $GITHUB_OUTPUT
+          else
+            echo "tag=${GITHUB_REF#refs/tags/}" >> $GITHUB_OUTPUT
+          fi
+
+      - name: Check if tag is pre-release
+        id: check-tag
+        shell: bash
+        run: |
+          TAG="${{ steps.get-tag.outputs.tag }}"
+          if [[ "$TAG" == *"rc"* || "$TAG" == *"pre-release"* ]]; then
+            echo "is_pre_release=true" >> $GITHUB_OUTPUT
+          else
+            echo "is_pre_release=false" >> $GITHUB_OUTPUT
+          fi
+
+      - name: Dispatch update-download-version workflow to cherry-studio-docs
+        if: steps.check-tag.outputs.is_pre_release == 'false'
+        uses: peter-evans/repository-dispatch@v3
+        with:
+          token: ${{ secrets.REPO_DISPATCH_TOKEN }}
+          repository: CherryHQ/cherry-studio-docs
+          event-type: update-download-version
+          client-payload: '{"version": "${{ steps.get-tag.outputs.tag }}"}'
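The new dispatch-docs-update job only sends a repository_dispatch event; the documentation repository has to have a workflow listening for it. As a rough sketch of the receiving side — assuming a hypothetical workflow in CherryHQ/cherry-studio-docs, which is not part of this compare — it could look like:

```yaml
# Hypothetical listener in cherry-studio-docs (assumed, not shown in this compare).
# It is triggered by the update-download-version event dispatched from release.yml
# and reads the released version from the client payload.
name: Update download version
on:
  repository_dispatch:
    types: [update-download-version]

jobs:
  update-version:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4
      - name: Use dispatched version
        run: echo "Release version: ${{ github.event.client_payload.version }}"
```

The `types` filter has to match the `event-type` sent by peter-evans/repository-dispatch, and the data sent as `client-payload` is exposed to the receiving workflow as `github.event.client_payload`.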
.gitignore (9 changes)

@@ -45,10 +45,15 @@ stats.html
 local
 .aider*
 .cursorrules
-.cursor/rules
+.cursor/*
 
-# test
+# vitest
 coverage
 .vitest-cache
 vitest.config.*.timestamp-*
 
+# playwright
+playwright-report
+test-results
+
 YOUR_MEMORY_FILE_PATH
.vscode/launch.json (1 change)

@@ -7,7 +7,6 @@
       "request": "launch",
       "cwd": "${workspaceRoot}",
       "runtimeExecutable": "${workspaceRoot}/node_modules/.bin/electron-vite",
-      "runtimeVersion": "20",
       "windows": {
         "runtimeExecutable": "${workspaceRoot}/node_modules/.bin/electron-vite.cmd"
       },
.yarn/patches/@google-genai-npm-1.0.1-e26f0f9af7.patch (new vendored file, 6471 lines; contents collapsed in this view)
.yarn/patches/@langchain-core-npm-0.3.44-41d5c3cb0a.patch (new vendored file, 71 lines)

The new patch contains:

diff --git a/dist/utils/tiktoken.cjs b/dist/utils/tiktoken.cjs
index 973b0d0e75aeaf8de579419af31b879b32975413..f23c7caa8b9dc8bd404132725346a4786f6b278b 100644
--- a/dist/utils/tiktoken.cjs
+++ b/dist/utils/tiktoken.cjs
@@ -1,25 +1,14 @@
 "use strict";
 Object.defineProperty(exports, "__esModule", { value: true });
 exports.encodingForModel = exports.getEncoding = void 0;
-const lite_1 = require("js-tiktoken/lite");
 const async_caller_js_1 = require("./async_caller.cjs");
 const cache = {};
 const caller = /* #__PURE__ */ new async_caller_js_1.AsyncCaller({});
 async function getEncoding(encoding) {
-    if (!(encoding in cache)) {
-        cache[encoding] = caller
-            .fetch(`https://tiktoken.pages.dev/js/${encoding}.json`)
-            .then((res) => res.json())
-            .then((data) => new lite_1.Tiktoken(data))
-            .catch((e) => {
-            delete cache[encoding];
-            throw e;
-        });
-    }
-    return await cache[encoding];
+    throw new Error("TikToken Not implemented");
 }
 exports.getEncoding = getEncoding;
 async function encodingForModel(model) {
-    return getEncoding((0, lite_1.getEncodingNameForModel)(model));
+    throw new Error("TikToken Not implemented");
 }
 exports.encodingForModel = encodingForModel;
diff --git a/dist/utils/tiktoken.js b/dist/utils/tiktoken.js
index 8e41ee6f00f2f9c7fa2c59fa2b2f4297634b97aa..aa5f314a6349ad0d1c5aea8631a56aad099176e0 100644
--- a/dist/utils/tiktoken.js
+++ b/dist/utils/tiktoken.js
@@ -1,20 +1,9 @@
-import { Tiktoken, getEncodingNameForModel, } from "js-tiktoken/lite";
 import { AsyncCaller } from "./async_caller.js";
 const cache = {};
 const caller = /* #__PURE__ */ new AsyncCaller({});
 export async function getEncoding(encoding) {
-    if (!(encoding in cache)) {
-        cache[encoding] = caller
-            .fetch(`https://tiktoken.pages.dev/js/${encoding}.json`)
-            .then((res) => res.json())
-            .then((data) => new Tiktoken(data))
-            .catch((e) => {
-            delete cache[encoding];
-            throw e;
-        });
-    }
-    return await cache[encoding];
+    throw new Error("TikToken Not implemented");
 }
 export async function encodingForModel(model) {
-    return getEncoding(getEncodingNameForModel(model));
+    throw new Error("TikToken Not implemented");
 }
diff --git a/package.json b/package.json
index 36072aecf700fca1bc49832a19be832eca726103..90b8922fba1c3d1b26f78477c891b07816d6238a 100644
--- a/package.json
+++ b/package.json
@@ -37,7 +37,6 @@
     "ansi-styles": "^5.0.0",
     "camelcase": "6",
     "decamelize": "1.2.0",
-    "js-tiktoken": "^1.0.12",
     "langsmith": ">=0.2.8 <0.4.0",
     "mustache": "^4.2.0",
     "p-queue": "^6.6.2",
.yarn/patches/ — modified electron-builder patch (the file-header line for this entry was not captured in this view; the hunks below edit an existing vendored patch of electron-builder internals, out/util/yarn.js and scheme.json).

@@ -65,11 +65,44 @@ index e8bd7bb46c8a54b3f55cf3a853ef924195271e01..f956e9f3fe9eb903c78aef3502553b01
     await packager.info.emitArtifactBuildCompleted({
         file: installerPath,
         updateInfo,
+diff --git a/out/util/yarn.js b/out/util/yarn.js
+index 1ee20f8b252a8f28d0c7b103789cf0a9a427aec1..c2878ec54d57da50bf14225e0c70c9c88664eb8a 100644
+--- a/out/util/yarn.js
++++ b/out/util/yarn.js
+@@ -140,6 +140,7 @@ async function rebuild(config, { appDir, projectDir }, options) {
+     arch,
+     platform,
+     buildFromSource,
++    ignoreModules: config.excludeReBuildModules || undefined,
+     projectRootPath: projectDir,
+     mode: config.nativeRebuilder || "sequential",
+     disablePreGypCopy: true,
 diff --git a/scheme.json b/scheme.json
-index 433e2efc9cef156ff5444f0c4520362ed2ef9ea7..a89c7a9b0b608fef67902c49106a43ebd0fa8b61 100644
+index 433e2efc9cef156ff5444f0c4520362ed2ef9ea7..0167441bf928a92f59b5dbe70b2317a74dda74c9 100644
 --- a/scheme.json
 +++ b/scheme.json
-@@ -1975,6 +1975,13 @@
+@@ -1825,6 +1825,20 @@
+         "string"
+       ]
+     },
++    "excludeReBuildModules": {
++      "anyOf": [
++        {
++          "items": {
++            "type": "string"
++          },
++          "type": "array"
++        },
++        {
++          "type": "null"
++        }
++      ],
++      "description": "The modules to exclude from the rebuild."
++    },
+     "executableArgs": {
+       "anyOf": [
+         {
+@@ -1975,6 +1989,13 @@

The remaining hunks of this entry make the same kind of change to further parts of scheme.json inside the existing patch — renumbering its inner hunk headers and inserting the identical "excludeReBuildModules" property block (shown above) before "executableName" in additional configuration schemas:

@@ -83,7 +116,7 @@  — inner hunk @@ -2327,6 +2334,13 @@ becomes @@ -2327,6 +2348,13 @@ (around "packageCategory")
@@ -97,7 +130,28 @@ — inner hunk @@ -2737,7 +2751,7 @@ becomes @@ -2527,6 +2555,20 @@ and inserts the block; a renumbered @@ -2737,7 +2779,7 @@ header is added
@@ -106,7 +160,7 @@ — inner hunk @@ -2959,6 +2973,13 @@ becomes @@ -2959,6 +3001,13 @@ (around "MacConfiguration"/"minimumSystemVersion")
@@ -120,7 +174,28 @@ — inner hunk @@ -3369,7 +3390,7 @@ becomes @@ -3159,6 +3208,20 @@ and inserts the block; @@ -3369,7 +3432,7 @@ is added (around "MasConfiguration")
@@ -129,7 +204,28 @@ — inner hunk @@ -6507,6 +6528,13 @@ becomes @@ -6381,6 +6444,20 @@ and inserts the block; @@ -6507,6 +6584,13 @@ is added
@@ -143,7 +239,28 @@ — inner hunk @@ -7376,6 +7404,13 @@ becomes @@ -7153,6 +7237,20 @@ and inserts the block; @@ -7376,6 +7474,13 @@ is added (around "protocols" and the MAS development options description)
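The patched rebuild step above passes a new `excludeReBuildModules` configuration value through to the native rebuilder as `ignoreModules`, and the scheme.json additions make that key a valid (array-of-string or null) electron-builder option. A minimal sketch of how it might be set — the file name and module name below are placeholders, not taken from this compare:

```yaml
# electron-builder.yml — hypothetical usage of the excludeReBuildModules option added by the patch.
appId: com.example.app            # placeholder application id
excludeReBuildModules:
  - some-native-module            # assumed module name; modules listed here are skipped by the rebuild
nativeRebuilder: sequential       # existing option also read by the patched rebuild() call
```

With this in place, the patched `rebuild()` would call the rebuilder with `ignoreModules: ['some-native-module']`, so that module is not rebuilt for Electron.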
.yarn/patches/openai-npm-4.96.0-0665b05cb9.patch (deleted vendored file, 85 lines)

The removed patch contained:

diff --git a/core.js b/core.js
index 862d66101f441fb4f47dfc8cff5e2d39e1f5a11e..6464bebbf696c39d35f0368f061ea4236225c162 100644
--- a/core.js
+++ b/core.js
@@ -159,7 +159,7 @@ class APIClient {
             Accept: 'application/json',
             'Content-Type': 'application/json',
             'User-Agent': this.getUserAgent(),
-            ...getPlatformHeaders(),
+            // ...getPlatformHeaders(),
             ...this.authHeaders(opts),
         };
     }
diff --git a/core.mjs b/core.mjs
index 05dbc6cfde51589a2b100d4e4b5b3c1a33b32b89..789fbb4985eb952a0349b779fa83b1a068af6e7e 100644
--- a/core.mjs
+++ b/core.mjs
@@ -152,7 +152,7 @@ export class APIClient {
     (the same getPlatformHeaders() comment-out as in core.js)
diff --git a/error.mjs b/error.mjs
index 7d19f5578040afa004bc887aab1725e8703d2bac..59ec725b6142299a62798ac4bdedb63ba7d9932c 100644
--- a/error.mjs
+++ b/error.mjs
@@ -36,7 +36,7 @@ export class APIError extends OpenAIError {
         if (!status || !headers) {
             return new APIConnectionError({ message, cause: castToError(errorResponse) });
         }
-        const error = errorResponse?.['error'];
+        const error = errorResponse?.['error'] || errorResponse;
         if (status === 400) {
             return new BadRequestError(status, error, message, headers);
         }
diff --git a/resources/embeddings.js b/resources/embeddings.js
index aae578404cb2d09a39ac33fc416f1c215c45eecd..25c54b05bdae64d5c3b36fbb30dc7c8221b14034 100644
--- a/resources/embeddings.js
+++ b/resources/embeddings.js
@@ -36,6 +36,9 @@ class Embeddings extends resource_1.APIResource {
         // No encoding_format specified, defaulting to base64 for performance reasons
         // See https://github.com/openai/openai-node/pull/1312
         let encoding_format = hasUserProvidedEncodingFormat ? body.encoding_format : 'base64';
+        if (body.model.includes('jina')) {
+            encoding_format = undefined;
+        }
         if (hasUserProvidedEncodingFormat) {
             Core.debug('Request', 'User defined encoding_format:', body.encoding_format);
         }
@@ -47,7 +50,7 @@ class Embeddings extends resource_1.APIResource {
             ...options,
         });
         // if the user specified an encoding_format, return the response as-is
-        if (hasUserProvidedEncodingFormat) {
+        if (hasUserProvidedEncodingFormat || body.model.includes('jina')) {
             return response;
         }
         // in this stage, we are sure the user did not specify an encoding_format
diff --git a/resources/embeddings.mjs b/resources/embeddings.mjs
index 0df3c6cc79a520e54acb4c2b5f77c43b774035ff..aa488b8a11b2c413c0a663d9a6059d286d7b5faf 100644
--- a/resources/embeddings.mjs
+++ b/resources/embeddings.mjs
@@ -10,6 +10,9 @@ export class Embeddings extends APIResource {
@@ -21,7 +24,7 @@ export class Embeddings extends APIResource {
     (the same two jina-related hunks as in resources/embeddings.js)
.yarn/patches/openai-npm-5.1.0-0e7b3ccb07.patch (new vendored file, 279 lines)

The new patch re-applies the same three modifications against the openai 5.x package layout:

diff --git a/client.js b/client.js
index 33b4ff6309d5f29187dab4e285d07dac20340bab..8f568637ee9e4677585931fb0284c8165a933f69 100644
--- a/client.js
+++ b/client.js
@@ -433,7 +433,7 @@ class OpenAI {
             'User-Agent': this.getUserAgent(),
             'X-Stainless-Retry-Count': String(retryCount),
             ...(options.timeout ? { 'X-Stainless-Timeout': String(Math.trunc(options.timeout / 1000)) } : {}),
-            ...(0, detect_platform_1.getPlatformHeaders)(),
+            // ...(0, detect_platform_1.getPlatformHeaders)(),
             'OpenAI-Organization': this.organization,
             'OpenAI-Project': this.project,
         },
diff --git a/client.mjs b/client.mjs
index c34c18213073540ebb296ea540b1d1ad39527906..1ce1a98256d7e90e26ca963582f235b23e996e73 100644
@@ -430,7 +430,7 @@ export class OpenAI {
     (the same change for the ESM build, commenting out `...getPlatformHeaders(),`)
diff --git a/core/error.js b/core/error.js
index a12d9d9ccd242050161adeb0f82e1b98d9e78e20..fe3a5462480558bc426deea147f864f12b36f9bd 100644
--- a/core/error.js
+++ b/core/error.js
@@ -40,7 +40,7 @@ class APIError extends OpenAIError {
         if (!status || !headers) {
             return new APIConnectionError({ message, cause: (0, errors_1.castToError)(errorResponse) });
         }
-        const error = errorResponse?.['error'];
+        const error = errorResponse?.['error'] || errorResponse;
         if (status === 400) {
             return new BadRequestError(status, error, message, headers);
         }
diff --git a/core/error.mjs b/core/error.mjs
index 83cefbaffeb8c657536347322d8de9516af479a2..63334b7972ec04882aa4a0800c1ead5982345045 100644
@@ -36,7 +36,7 @@ export class APIError extends OpenAIError {
     (the same errorResponse fallback for the ESM build)

diff --git a/resources/embeddings.js b/resources/embeddings.js  (@@ -5,52 +5,64 @@)
diff --git a/resources/embeddings.mjs b/resources/embeddings.mjs (@@ -2,51 +2,61 @@)

These two hunks remove and re-emit the whole Embeddings.create() method (reformatted) with two functional changes, mirroring the deleted 4.96.0 patch:

+      if (body.model.includes("jina")) {
+        encoding_format = undefined;
+      }
 ...
+      if (hasUserProvidedEncodingFormat || body.model.includes("jina")) {
         return response;
 ...
+      if (response && response.data && typeof response.data[0]?.embedding === 'string') {

so the base64 encoding_format default is skipped for models whose name contains "jina", and base64 decoding is only attempted when the returned embedding is actually a string.
README.md (144 changes)

@@ -3,10 +3,42 @@
     <img src="https://github.com/CherryHQ/cherry-studio/blob/main/build/icon.png?raw=true" width="150" height="150" alt="banner" /><br>
   </a>
 </h1>
-<p align="center">English | <a href="./docs/README.zh.md">中文</a> | <a href="./docs/README.ja.md">日本語</a><br></p>
+<p align="center">English | <a href="./docs/README.zh.md">中文</a> | <a href="./docs/README.ja.md">日本語</a> | <a href="https://cherry-ai.com">Official Site</a> | <a href="https://docs.cherry-ai.com/cherry-studio-wen-dang/en-us">Documents</a> | <a href="./docs/dev.md">Development</a> | <a href="https://github.com/CherryHQ/cherry-studio/issues">Feedback</a><br></p>
+
+<!-- Header badges -->
+
 <div align="center">
+
+[![][deepwiki-shield]][deepwiki-link]
+[![][twitter-shield]][twitter-link]
+[![][discord-shield]][discord-link]
+[![][telegram-shield]][telegram-link]
+
+</div>
+
+<!-- Project stats badges -->
+
+<div align="center">
+
+[![][github-stars-shield]][github-stars-link]
+[![][github-forks-shield]][github-forks-link]
+[![][github-release-shield]][github-release-link]
+[![][github-contributors-shield]][github-contributors-link]
+
+</div>
+
+<div align="center">
+
+[![][license-shield]][license-link]
+[![][commercial-shield]][commercial-link]
+[![][sponsor-shield]][sponsor-link]
+
+</div>
+
+<div align="center">
+  <a href="https://hellogithub.com/repository/1605492e1e2a4df3be07abfa4578dd37" target="_blank"><img src="https://api.hellogithub.com/v1/widgets/recommend.svg?rid=1605492e1e2a4df3be07abfa4578dd37" alt="Featured|HelloGitHub" style="width: 200px; height: 43px;" width="200" height="43" /></a>
   <a href="https://trendshift.io/repositories/11772" target="_blank"><img src="https://trendshift.io/api/badge/repositories/11772" alt="kangfenmao%2Fcherry-studio | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
-  <a href="https://www.producthunt.com/posts/cherry-studio?embed=true&utm_source=badge-featured&utm_medium=badge&utm_souce=badge-cherry-studio" target="_blank"><img src="https://api.producthunt.com/widgets/embed-image/v1/featured.svg?post_id=496640&theme=light" alt="Cherry Studio - AI Chatbots, AI Desktop Client | Product Hunt" style="width: 250px; height: 54px;" width="250" height="54" /></a>
+  <a href="https://www.producthunt.com/posts/cherry-studio?embed=true&utm_source=badge-featured&utm_medium=badge&utm_souce=badge-cherry-studio" target="_blank"><img src="https://api.producthunt.com/widgets/embed-image/v1/featured.svg?post_id=496640&theme=light" alt="Cherry Studio - AI Chatbots, AI Desktop Client | Product Hunt" style="width: 200px; height: 43px;" width="200" height="43" /></a>
 </div>
 
 # 🍒 Cherry Studio
@@ -17,10 +49,6 @@
 
 ❤️ Like Cherry Studio? Give it a star 🌟 or [Sponsor](docs/sponsor.md) to support the development!
 
-# 📖 Guide
-
-<https://docs.cherry-ai.com>
-
 # 🌠 Screenshot
 
 ![](https://github.com/user-attachments/assets/...)
@@ -67,20 +95,42 @@
 - 📝 Complete Markdown Rendering
 - 🤲 Easy Content Sharing
 
-# 📝 TODO
-
-- [x] Quick popup (read clipboard, quick question, explain, translate, summarize)
-- [x] Comparison of multi-model answers
-- [x] Support login using SSO provided by service providers
-- [x] All models support networking
-- [x] Launch of the first official version
-- [x] Bug fixes and improvements (In progress...)
-- [ ] Plugin functionality (JavaScript)
-- [ ] Browser extension (highlight text to translate, summarize, add to knowledge base)
-- [ ] iOS & Android client
-- [ ] AI notes
-- [ ] Voice input and output (AI call)
-- [ ] Data backup supports custom backup content
+# 📝 Roadmap
+
+We're actively working on the following features and improvements:
+
+1. 🎯 **Core Features**
+
+   - Selection Assistant - Smart content selection enhancement
+   - Deep Research - Advanced research capabilities
+   - Memory System - Global context awareness
+   - Document Preprocessing - Improved document handling
+   - MCP Marketplace - Model Context Protocol ecosystem
+
+2. 🗂 **Knowledge Management**
+
+   - Notes and Collections
+   - Dynamic Canvas visualization
+   - OCR capabilities
+   - TTS (Text-to-Speech) support
+
+3. 📱 **Platform Support**
+
+   - HarmonyOS Edition (PC)
+   - Android App (Phase 1)
+   - iOS App (Phase 1)
+   - Multi-Window support
+   - Window Pinning functionality
+
+4. 🔌 **Advanced Features**
+
+   - Plugin System
+   - ASR (Automatic Speech Recognition)
+   - Assistant and Topic Interaction Refactoring
+
+Track our progress and contribute on our [project board](https://github.com/orgs/CherryHQ/projects/7).
+
+Want to influence our roadmap? Join our [GitHub Discussions](https://github.com/CherryHQ/cherry-studio/discussions) to share your ideas and feedback!
 
 # 🌈 Theme
@@ -92,14 +142,6 @@
 
 Welcome PR for more themes
 
-# 🖥️ Develop
-
-Refer to the [development documentation](docs/dev.md)
-
-Refer to the [Architecture overview documentation](https://deepwiki.com/CherryHQ/cherry-studio)
-
-Refer to the [Branching Strategy](docs/branching-strategy-en.md) for contribution guidelines
-
 # 🤝 Contributing
 
 We welcome contributions to Cherry Studio! Here are some ways you can contribute:
@@ -112,6 +154,8 @@
 6. **Community Engagement**: Join discussions and help users.
 7. **Promote Usage**: Spread the word about Cherry Studio.
 
+Refer to the [Branching Strategy](docs/branching-strategy-en.md) for contribution guidelines
+
 ## Getting Started
 
 1. **Fork the Repository**: Fork and clone it to your local machine.
@@ -136,22 +180,34 @@
   </a>
   <br /><br />
 
-# 🌐 Community
-
-[Telegram](https://t.me/CherryStudioAI) | [Email](mailto:support@cherry-ai.com) | [Twitter](https://x.com/kangfenmao)
-
-# ☕ Sponsor
-
-[Buy Me a Coffee](docs/sponsor.md)
-
-# 📃 License
-
-[LICENSE](./LICENSE)
-
-# ✉️ Contact
-
-<yinsenho@cherry-ai.com>
-
 # ⭐️ Star History
 
-[](https://star-history.com/#kangfenmao/cherry-studio&Timeline)
+[](https://star-history.com/#CherryHQ/cherry-studio&Timeline)

The hunk then appends the "<!-- Links & Images -->" reference definitions that back the badges added at the top of the file: [deepwiki-shield]/[deepwiki-link], [twitter-shield]/[twitter-link] (https://twitter.com/CherryStudioHQ), [discord-shield]/[discord-link] (https://discord.gg/wez8HtpxqQ), [telegram-shield]/[telegram-link] (https://t.me/CherryStudioAI), the GitHub stars/forks/release/contributors shields and links, and [license-shield]/[license-link] (AGPLv3), [commercial-shield]/[commercial-link] (mailto:license@cherry-ai.com), [sponsor-shield]/[sponsor-link] (docs/sponsor.md).
(The next entry's file-header line was not captured; its content is the Japanese README, apparently docs/README.ja.md. The changes mirror the README.md changes above, localized to Japanese.)

@@ -1,15 +1,46 @@ — the banner image line gains a trailing <br>; the language navigation line adds 公式サイト (https://cherry-ai.com), ドキュメント (https://docs.cherry-ai.com/cherry-studio-wen-dang/ja), 開発 (./dev.md) and フィードバック (issues) links; the same badge <div> blocks (deepwiki/Twitter/Discord/Telegram, GitHub stats, license/commercial/sponsor) and the HelloGitHub widget are added, and the Product Hunt widget is resized from 250×54 to 200×43.
@@ -20,10 +51,6 @@ — the "# 📖 ガイド" section (https://docs.cherry-ai.com) is removed.
@@ -70,20 +97,42 @@ — "# 📝 TODO" and its checklist are replaced by "# 📝 開発計画" with the same four-part roadmap (コア機能, ナレッジ管理, プラットフォーム対応, 高度な機能) plus links to the project board and GitHub Discussions.
@@ -95,14 +144,6 @@ — the "# 🖥️ 開発" section (dev.md, DeepWiki architecture overview, branching strategy) is removed.
@@ -115,6 +156,8 @@ — the branching-strategy note is added to the "# 🤝 貢献" section instead.
@@ -139,22 +182,34 @@ — the コミュニティ / スポンサー / ライセンス / お問い合わせ sections are removed, the Star History link switches from kangfenmao/cherry-studio to CherryHQ/cherry-studio, and the badge reference definitions (with Japanese-labelled commercial and sponsor shields) are appended.
@@ -1,14 +1,46 @@
|
|||||||
<h1 align="center">
|
<h1 align="center">
|
||||||
<a href="https://github.com/CherryHQ/cherry-studio/releases">
|
<a href="https://github.com/CherryHQ/cherry-studio/releases">
|
||||||
<img src="https://github.com/CherryHQ/cherry-studio/blob/main/build/icon.png?raw=true" width="150" height="150" alt="banner" />
|
<img src="https://github.com/CherryHQ/cherry-studio/blob/main/build/icon.png?raw=true" width="150" height="150" alt="banner" /><br>
|
||||||
</a>
|
</a>
|
||||||
</h1>
|
</h1>
|
||||||
<p align="center">
|
<p align="center">
|
||||||
<a href="https://github.com/CherryHQ/cherry-studio">English</a> | 中文 | <a href="./README.ja.md">日本語</a><br>
|
<a href="https://github.com/CherryHQ/cherry-studio">English</a> | 中文 | <a href="./README.ja.md">日本語</a> | <a href="https://cherry-ai.com">官方网站</a> | <a href="https://docs.cherry-ai.com/cherry-studio-wen-dang/zh-cn">文档</a> | <a href="./dev.md">开发</a> | <a href="https://github.com/CherryHQ/cherry-studio/issues">反馈</a><br>
|
||||||
</p>
|
</p>
|
||||||
|
|
||||||
|
<!-- 题头徽章组合 -->
|
||||||
|
|
||||||
<div align="center">
|
<div align="center">
|
||||||
|
|
||||||
|
[![][deepwiki-shield]][deepwiki-link]
|
||||||
|
[![][twitter-shield]][twitter-link]
|
||||||
|
[![][discord-shield]][discord-link]
|
||||||
|
[![][telegram-shield]][telegram-link]
|
||||||
|
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<!-- 项目统计徽章 -->
|
||||||
|
|
||||||
|
<div align="center">
|
||||||
|
|
||||||
|
[![][github-stars-shield]][github-stars-link]
|
||||||
|
[![][github-forks-shield]][github-forks-link]
|
||||||
|
[![][github-release-shield]][github-release-link]
|
||||||
|
[![][github-contributors-shield]][github-contributors-link]
|
||||||
|
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div align="center">
|
||||||
|
|
||||||
|
[![][license-shield]][license-link]
|
||||||
|
[![][commercial-shield]][commercial-link]
|
||||||
|
[![][sponsor-shield]][sponsor-link]
|
||||||
|
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div align="center">
|
||||||
|
<a href="https://hellogithub.com/repository/1605492e1e2a4df3be07abfa4578dd37" target="_blank"><img src="https://api.hellogithub.com/v1/widgets/recommend.svg?rid=1605492e1e2a4df3be07abfa4578dd37" alt="Featured|HelloGitHub" style="width: 200px; height: 43px;" width="200" height="43" /></a>
|
||||||
<a href="https://trendshift.io/repositories/11772" target="_blank"><img src="https://trendshift.io/api/badge/repositories/11772" alt="kangfenmao%2Fcherry-studio | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
|
<a href="https://trendshift.io/repositories/11772" target="_blank"><img src="https://trendshift.io/api/badge/repositories/11772" alt="kangfenmao%2Fcherry-studio | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
|
||||||
<a href="https://www.producthunt.com/posts/cherry-studio?embed=true&utm_source=badge-featured&utm_medium=badge&utm_souce=badge-cherry-studio" target="_blank"><img src="https://api.producthunt.com/widgets/embed-image/v1/featured.svg?post_id=496640&theme=light" alt="Cherry Studio - AI Chatbots, AI Desktop Client | Product Hunt" style="width: 250px; height: 54px;" width="250" height="54" /></a>
|
<a href="https://www.producthunt.com/posts/cherry-studio?embed=true&utm_source=badge-featured&utm_medium=badge&utm_souce=badge-cherry-studio" target="_blank"><img src="https://api.producthunt.com/widgets/embed-image/v1/featured.svg?post_id=496640&theme=light" alt="Cherry Studio - AI Chatbots, AI Desktop Client | Product Hunt" style="width: 200px; height: 43px;" width="200" height="43" /></a>
|
||||||
</div>
|
</div>
|
||||||
|
|
||||||
# 🍒 Cherry Studio
|
# 🍒 Cherry Studio
|
||||||
@@ -77,20 +109,42 @@ https://docs.cherry-ai.com
|
|||||||
- 📝 完整的 Markdown 渲染
|
- 📝 完整的 Markdown 渲染
|
||||||
- 🤲 便捷的内容分享功能
|
- 🤲 便捷的内容分享功能
|
||||||
|
|
||||||
# 📝 待办事项
|
# 📝 开发计划
|
||||||
|
|
||||||
- [x] 快捷弹窗(读取剪贴板、快速提问、解释、翻译、总结)
|
我们正在积极开发以下功能和改进:
|
||||||
- [x] 多模型回答对比
|
|
||||||
- [x] 支持使用服务供应商提供的 SSO 进行登录
|
1. 🎯 **核心功能**
|
||||||
- [x] 所有模型支持联网
|
|
||||||
- [x] 推出第一个正式版
|
- 选择助手 - 智能内容选择增强
|
||||||
- [x] 错误修复和改进(开发中...)
|
- 深度研究 - 高级研究能力
|
||||||
- [ ] 插件功能(JavaScript)
|
- 全局记忆 - 全局上下文感知
|
||||||
- [ ] 浏览器插件(划词翻译、总结、新增至知识库)
|
- 文档预处理 - 改进文档处理能力
|
||||||
- [ ] iOS & Android 客户端
|
- MCP 市场 - 模型上下文协议生态系统
|
||||||
- [ ] AI 笔记
|
|
||||||
- [ ] 语音输入输出(AI 通话)
|
2. 🗂 **知识管理**
|
||||||
- [ ] 数据备份支持自定义备份内容
|
|
||||||
|
- 笔记与收藏功能
|
||||||
|
- 动态画布可视化
|
||||||
|
- OCR 光学字符识别
|
||||||
|
- TTS 文本转语音支持
|
||||||
|
|
||||||
|
3. 📱 **平台支持**
|
||||||
|
|
||||||
|
- 鸿蒙版本 (PC)
|
||||||
|
- Android 应用(第一期)
|
||||||
|
- iOS 应用(第一期)
|
||||||
|
- 多窗口支持
|
||||||
|
- 窗口置顶功能
|
||||||
|
|
||||||
|
4. 🔌 **高级特性**
|
||||||
|
|
||||||
|
- 插件系统
|
||||||
|
- ASR 语音识别
|
||||||
|
- 助手与话题交互重构
|
||||||
|
|
||||||
|
在我们的[项目面板](https://github.com/orgs/CherryHQ/projects/7)上跟踪进展并参与贡献。
|
||||||
|
|
||||||
|
想要影响开发计划?欢迎加入我们的 [GitHub 讨论区](https://github.com/CherryHQ/cherry-studio/discussions) 分享您的想法和反馈!
|
||||||
|
|
||||||
# 🌈 主题
|
# 🌈 主题
|
||||||
|
|
||||||
@@ -102,14 +156,6 @@ https://docs.cherry-ai.com
|
|||||||
|
|
||||||
欢迎 PR 更多主题
|
欢迎 PR 更多主题
|
||||||
|
|
||||||
# 🖥️ 开发
|
|
||||||
|
|
||||||
参考[开发文档](dev.md)
|
|
||||||
|
|
||||||
参考[架构概览文档](https://deepwiki.com/CherryHQ/cherry-studio)
|
|
||||||
|
|
||||||
参考[分支策略](branching-strategy-zh.md)了解贡献指南
|
|
||||||
|
|
||||||
# 🤝 贡献
|
# 🤝 贡献
|
||||||
|
|
||||||
我们欢迎对 Cherry Studio 的贡献!您可以通过以下方式贡献:
|
我们欢迎对 Cherry Studio 的贡献!您可以通过以下方式贡献:
|
||||||
@@ -122,6 +168,8 @@ https://docs.cherry-ai.com
|
|||||||
6. **社区参与**:加入讨论并帮助用户
|
6. **社区参与**:加入讨论并帮助用户
|
||||||
7. **推广使用**:宣传 Cherry Studio
|
7. **推广使用**:宣传 Cherry Studio
|
||||||
|
|
||||||
|
参考[分支策略](branching-strategy-zh.md)了解贡献指南
|
||||||
|
|
||||||
## 入门
|
## 入门
|
||||||
|
|
||||||
1. **Fork 仓库**:Fork 并克隆到您的本地机器
|
1. **Fork 仓库**:Fork 并克隆到您的本地机器
|
||||||
@@ -146,22 +194,34 @@ https://docs.cherry-ai.com
|
|||||||
</a>
|
</a>
|
||||||
<br /><br />
|
<br /><br />
|
||||||
|
|
||||||
# 🌐 社区
|
|
||||||
|
|
||||||
[Telegram](https://t.me/CherryStudioAI) | [Email](mailto:support@cherry-ai.com) | [Twitter](https://x.com/kangfenmao)
|
|
||||||
|
|
||||||
# ☕ 赞助
|
|
||||||
|
|
||||||
[赞助开发者](sponsor.md)
|
|
||||||
|
|
||||||
# 📃 许可证
|
|
||||||
|
|
||||||
[LICENSE](../LICENSE)
|
|
||||||
|
|
||||||
# ✉️ 联系我们
|
|
||||||
|
|
||||||
yinsenho@cherry-ai.com
|
|
||||||
|
|
||||||
# ⭐️ Star 记录
|
# ⭐️ Star 记录
|
||||||
|
|
||||||
[](https://star-history.com/#kangfenmao/cherry-studio&Timeline)
|
[](https://star-history.com/#CherryHQ/cherry-studio&Timeline)
|
||||||
|
|
||||||
|
<!-- Links & Images -->
|
||||||
|
[deepwiki-shield]: https://img.shields.io/badge/Deepwiki-CherryHQ-0088CC?style=plastic
|
||||||
|
[deepwiki-link]: https://deepwiki.com/CherryHQ/cherry-studio
|
||||||
|
[twitter-shield]: https://img.shields.io/badge/Twitter-CherryStudioApp-0088CC?style=plastic&logo=x
|
||||||
|
[twitter-link]: https://twitter.com/CherryStudioHQ
|
||||||
|
[discord-shield]: https://img.shields.io/badge/Discord-@CherryStudio-0088CC?style=plastic&logo=discord
|
||||||
|
[discord-link]: https://discord.gg/wez8HtpxqQ
|
||||||
|
[telegram-shield]: https://img.shields.io/badge/Telegram-@CherryStudioAI-0088CC?style=plastic&logo=telegram
|
||||||
|
[telegram-link]: https://t.me/CherryStudioAI
|
||||||
|
|
||||||
|
<!-- 项目统计徽章 -->
|
||||||
|
[github-stars-shield]: https://img.shields.io/github/stars/CherryHQ/cherry-studio?style=social
|
||||||
|
[github-stars-link]: https://github.com/CherryHQ/cherry-studio/stargazers
|
||||||
|
[github-forks-shield]: https://img.shields.io/github/forks/CherryHQ/cherry-studio?style=social
|
||||||
|
[github-forks-link]: https://github.com/CherryHQ/cherry-studio/network
|
||||||
|
[github-release-shield]: https://img.shields.io/github/v/release/CherryHQ/cherry-studio
|
||||||
|
[github-release-link]: https://github.com/CherryHQ/cherry-studio/releases
|
||||||
|
[github-contributors-shield]: https://img.shields.io/github/contributors/CherryHQ/cherry-studio
|
||||||
|
[github-contributors-link]: https://github.com/CherryHQ/cherry-studio/graphs/contributors
|
||||||
|
|
||||||
|
<!-- 许可和赞助徽章 -->
|
||||||
|
[license-shield]: https://img.shields.io/badge/License-AGPLv3-important.svg?style=plastic&logo=gnu
|
||||||
|
[license-link]: https://www.gnu.org/licenses/agpl-3.0
|
||||||
|
[commercial-shield]: https://img.shields.io/badge/商用授权-联系-white.svg?style=plastic&logoColor=white&logo=telegram&color=blue
|
||||||
|
[commercial-link]: mailto:license@cherry-ai.com?subject=商业授权咨询
|
||||||
|
[sponsor-shield]: https://img.shields.io/badge/赞助支持-FF6699.svg?style=plastic&logo=githubsponsors&logoColor=white
|
||||||
|
[sponsor-link]: https://github.com/CherryHQ/cherry-studio/blob/main/docs/sponsor.md
|
||||||
|
|||||||
@@ -37,6 +37,14 @@ yarn install
|
|||||||
yarn dev
|
yarn dev
|
||||||
```
|
```
|
||||||
|
|
||||||
### Debug

```bash
yarn debug
```

Then open chrome://inspect in the browser

### Test
|
### Test
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
|
|||||||
214
docs/technical/how-to-write-middlewares.md
Normal file
@@ -0,0 +1,214 @@
# How to Write Middleware for AI Providers

This document guides developers through creating and integrating custom middleware for our AI Provider framework. Middleware offers a powerful and flexible way to enhance, modify, or observe Provider method calls, for example logging, caching, request/response transformation, and error handling.

## Architecture Overview

Our middleware architecture borrows Redux's three-stage design and combines it with a JavaScript Proxy to apply middleware to Provider methods dynamically (a small sketch follows the list below).

- **Proxy**: intercepts calls to Provider methods and routes them into the middleware chain.
- **Middleware chain**: a series of middleware functions executed in order. Each middleware can process the request/response and then hand control to the next middleware in the chain, or in some cases terminate the chain early.
- **Context**: an object passed between middlewares that carries information about the current call (such as the method name, the original arguments, the Provider instance, and any custom data set by middleware).
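To make the Proxy wiring concrete, here is a minimal sketch of how a provider could be wrapped. It is illustrative only: `wrapProviderWithMiddlewares`, the `Handler` alias, and the partial `api` object are assumptions for this sketch, not the framework's actual helpers.

```typescript
// Sketch only: shows the Proxy + chain idea described above, not the real implementation.
type Handler = (context: any, params: any) => Promise<any>
type Middleware = (api: any) => (next: Handler) => Handler

function wrapProviderWithMiddlewares<T extends object>(
  provider: T,
  middlewaresByMethod: Record<string, Middleware[]>
): T {
  return new Proxy(provider, {
    get(target, prop, receiver) {
      const original = Reflect.get(target, prop, receiver)
      const middlewares = middlewaresByMethod[String(prop)]
      if (typeof original !== 'function' || !middlewares || middlewares.length === 0) {
        return original
      }
      return (...args: any[]) => {
        // The context is shared by every middleware in the chain for this call
        const context = { methodName: String(prop), originalArgs: args, _providerInstance: target }
        const api = {
          getContext: () => context,
          getOriginalArgs: () => args,
          getProviderInstance: () => target
        }
        // Innermost handler: the real provider method
        const core: Handler = (_ctx, params) => Promise.resolve(original.call(target, params))
        // Compose right to left so the first middleware in the array runs first
        const chain = middlewares.map((mw) => mw(api)).reduceRight((next, layer) => layer(next), core)
        return chain(context, args[0])
      }
    }
  })
}
```

A wrapper along these lines is what lets the registration config described later attach per-method middleware without touching the Provider classes themselves.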

## Middleware Types

Two kinds of middleware are currently supported. They share a similar structure but target different scenarios:

1. **`CompletionsMiddleware`**: designed specifically for the `completions` method. This is the most commonly used type, because it allows fine-grained control over the model's core chat/text-generation flow.
2. **`ProviderMethodMiddleware`**: a generic middleware that can be applied to any other method on a Provider (for example `translate`, `summarize`, and so on, provided those methods are also wrapped by the middleware system); a rough shape is sketched after this list.
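For orientation, the generic variant can be pictured as the same three-stage shape as the `CompletionsMiddleware` signature shown in the next section, just without the completions-specific types. This is a sketch; the actual `ProviderMethodMiddleware` definition in the framework's type file may be parameterized differently.

```typescript
// Sketch only: a generic three-stage middleware shape, by analogy with CompletionsMiddleware below.
import { MiddlewareAPI } from './AiProviderMiddlewareTypes' // assumed path, as in the examples below

export type GenericMethodMiddleware<TContext, TArgs extends unknown[]> = (
  api: MiddlewareAPI<TContext, TArgs>
) => (
  next: (context: TContext, ...args: TArgs) => Promise<any>
) => (context: TContext, ...args: TArgs) => Promise<any>
```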

## Writing a `CompletionsMiddleware`

The basic signature of a `CompletionsMiddleware` (as a TypeScript type) looks like this:

```typescript
import { AiProviderMiddlewareCompletionsContext, CompletionsParams, MiddlewareAPI } from './AiProviderMiddlewareTypes' // assumed path to the type definitions

export type CompletionsMiddleware = (
  api: MiddlewareAPI<AiProviderMiddlewareCompletionsContext, [CompletionsParams]>
) => (
  next: (context: AiProviderMiddlewareCompletionsContext, params: CompletionsParams) => Promise<any> // next returns Promise<any>: the raw SDK response or the result of the downstream middleware
) => (context: AiProviderMiddlewareCompletionsContext, params: CompletionsParams) => Promise<void> // the innermost function usually returns Promise<void>, because results are delivered through onChunk or context side effects
```

Let's break this three-stage structure down:

1. **First layer `(api) => { ... }`**:

   - Receives an `api` object.
   - The `api` object provides the following methods:
     - `api.getContext()`: returns the context object for the current call (`AiProviderMiddlewareCompletionsContext`).
     - `api.getOriginalArgs()`: returns the original argument array passed to `completions` (i.e. `[CompletionsParams]`).
     - `api.getProviderId()`: returns the ID of the current Provider.
     - `api.getProviderInstance()`: returns the original Provider instance.
   - This layer is typically used for one-off setup or to obtain required services/configuration. It returns the second-layer function.

2. **Second layer `(next) => { ... }`**:

   - Receives a `next` function.
   - `next` represents the next link in the middleware chain. Calling `next(context, params)` hands control to the next middleware or, if the current middleware is the last one in the chain, invokes the core Provider logic (for example, the actual SDK call).
   - `next` receives the current `context` and `params` (which may already have been modified by upstream middleware).
   - **Importantly**, the return type of `next` is usually `Promise<any>`. For the `completions` method, if `next` ends up calling the actual SDK it returns the raw SDK response (for example, OpenAI's stream object or a JSON object); you need to handle that response.
   - This layer returns the third (and innermost) function.

3. **Third layer `(context, params) => { ... }`**:

   - This is where the middleware's main logic runs.
   - It receives the current `context` (`AiProviderMiddlewareCompletionsContext`) and `params` (`CompletionsParams`).
   - Inside this function you can:
     - **Before calling `next`**:
       - Read or modify `params`, for example to add default parameters or convert message formats.
       - Read or modify `context`, for example to record a timestamp for later latency calculation.
       - Perform checks and, if they fail, return or throw without calling `next` (for example, on parameter validation failure).
     - **Call `await next(context, params)`**:
       - This is the key step that hands control downstream.
       - The return value of `next` is the raw SDK response or the downstream middleware's result; handle it as appropriate (for example, start consuming it if it is a stream).
     - **After calling `next`**:
       - Process the result of `next`. For example, if `next` returned a stream, iterate over it here and emit chunks through `context.onChunk`.
       - Act on changes to `context` or on the result of `next`, for example computing the total elapsed time or writing logs.
       - Modify the final result (although for `completions` the result is usually emitted as a side effect through `onChunk`).

### Example: a Simple Logging Middleware

```typescript
import {
  AiProviderMiddlewareCompletionsContext,
  CompletionsMiddleware,
  CompletionsParams,
  MiddlewareAPI
} from './AiProviderMiddlewareTypes' // adjust the path to your project layout
import { ChunkType } from '@renderer/types' // adjust the path to your project layout

export const createSimpleLoggingMiddleware = (): CompletionsMiddleware => {
  return (api: MiddlewareAPI<AiProviderMiddlewareCompletionsContext, [CompletionsParams]>) => {
    // console.log(`[LoggingMiddleware] Initialized for provider: ${api.getProviderId()}`);

    return (next: (context: AiProviderMiddlewareCompletionsContext, params: CompletionsParams) => Promise<any>) => {
      return async (context: AiProviderMiddlewareCompletionsContext, params: CompletionsParams): Promise<void> => {
        const startTime = Date.now()
        // Take onChunk from the context (it originally comes from params.onChunk)
        const onChunk = context.onChunk

        console.log(
          `[LoggingMiddleware] Request for ${context.methodName} with params:`,
          params.messages?.[params.messages.length - 1]?.content
        )

        try {
          // Call the next middleware or the core logic.
          // `rawSdkResponse` is the raw response from downstream (e.g. an OpenAIStream or a ChatCompletion object).
          const rawSdkResponse = await next(context, params)

          // This simple example does not process rawSdkResponse; it assumes a downstream middleware
          // (such as a StreamingResponseHandler) handles it and emits data through onChunk.
          // If this logging middleware runs after the StreamingResponseHandler, the stream has already been handled.
          // If it runs before, it would have to process rawSdkResponse itself or ensure something downstream does.

          const duration = Date.now() - startTime
          console.log(`[LoggingMiddleware] Request for ${context.methodName} completed in ${duration}ms.`)

          // Assume downstream has already emitted all data through onChunk.
          // If this middleware were the end of the chain and had to guarantee that BLOCK_COMPLETE is sent,
          // it would need more elaborate logic to track when all data has been emitted.
        } catch (error) {
          const duration = Date.now() - startTime
          console.error(`[LoggingMiddleware] Request for ${context.methodName} failed after ${duration}ms:`, error)

          // If onChunk is available, try to emit an error chunk
          if (onChunk) {
            onChunk({
              type: ChunkType.ERROR,
              error: { message: (error as Error).message, name: (error as Error).name, stack: (error as Error).stack }
            })
            // Consider whether a BLOCK_COMPLETE chunk is also needed to terminate the stream
            onChunk({ type: ChunkType.BLOCK_COMPLETE, response: {} })
          }
          throw error // rethrow so an upstream middleware or a global error handler can catch it
        }
      }
    }
  }
}
```

### Why `AiProviderMiddlewareCompletionsContext` Matters

`AiProviderMiddlewareCompletionsContext` is the core vehicle for passing state and data between middlewares. It typically contains:

- `methodName`: the name of the method being called (always `'completions'`).
- `originalArgs`: the original argument array passed to `completions`.
- `providerId`: the Provider's ID.
- `_providerInstance`: the Provider instance.
- `onChunk`: the callback passed in through the original `CompletionsParams`, used to stream data chunks. **All middleware should emit data through `context.onChunk`.**
- `messages`, `model`, `assistant`, `mcpTools`: commonly used fields extracted from the original `CompletionsParams` for convenient access.
- **Custom fields**: middleware may add custom fields to the context for later middleware to use. For example, a caching middleware might set `context.cacheHit = true` (see the cache sketch after this section).

**Key point**: when you modify `params` or `context` inside a middleware, those modifications propagate to downstream middleware (as long as they happen before `next` is called).
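To make the custom-field idea concrete, here is a rough sketch of a caching middleware that marks a cache hit on the context and ends the chain early. The `lookupCache` helper, the `cacheHit` field, and the way the cached answer would be emitted are assumptions for illustration, not part of the framework.

```typescript
// Sketch only: demonstrates a custom context field plus early termination on a cache hit.
// `lookupCache` is a hypothetical helper; plug in your own cache lookup.
import {
  AiProviderMiddlewareCompletionsContext,
  CompletionsMiddleware,
  CompletionsParams
} from './AiProviderMiddlewareTypes'
import { ChunkType } from '@renderer/types'

export const createCacheMiddleware = (
  lookupCache: (params: CompletionsParams) => Promise<string | undefined>
): CompletionsMiddleware => {
  return () => (next) => async (context: AiProviderMiddlewareCompletionsContext, params: CompletionsParams) => {
    // Widen the context type locally to carry our custom field
    const ctx = context as AiProviderMiddlewareCompletionsContext & { cacheHit?: boolean }
    const cached = await lookupCache(params)

    if (cached !== undefined) {
      // Record the hit on the context so downstream logic and logs can see it
      ctx.cacheHit = true
      // A real implementation would emit the cached answer through context.onChunk here,
      // using whatever chunk types the pipeline expects for text content (omitted in this sketch),
      // and then close the block:
      context.onChunk?.({ type: ChunkType.BLOCK_COMPLETE, response: {} })
      return // early termination: next() is never called, so no SDK request is made
    }

    ctx.cacheHit = false
    await next(context, params)
    // After the downstream chain finishes, a real implementation would capture the streamed
    // text (for example, accumulated on the context by another middleware) and store it in the cache.
  }
}
```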

### Middleware Ordering

The order in which middleware executes matters a great deal. Middlewares run in the order they are listed in the `AiProviderMiddlewareConfig` arrays.

- The request passes through the first middleware first, then the second, and so on.
- The response (or the result of the `next` call) "bubbles" back up in the reverse order.

For example, if the chain is `[AuthMiddleware, CacheMiddleware, LoggingMiddleware]` (see the tracing sketch after this list):

1. `AuthMiddleware` runs its "before `next`" logic first.
2. Then `CacheMiddleware` runs its "before `next`" logic.
3. Then `LoggingMiddleware` runs its "before `next`" logic.
4. The core SDK call (the end of the chain) executes.
5. `LoggingMiddleware` receives the result first and runs its "after `next`" logic.
6. Then `CacheMiddleware` receives the result (with any context changes made by `LoggingMiddleware`) and runs its "after `next`" logic (for example, storing the result).
7. Finally, `AuthMiddleware` receives the result and runs its "after `next`" logic.
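A quick way to see this ordering in practice is a set of trivial tracing middlewares (a sketch, reusing the `CompletionsMiddleware` type from above):

```typescript
// Sketch only: no-op middlewares that log when they enter and leave,
// to visualize the "onion" ordering described above.
import { CompletionsMiddleware } from './AiProviderMiddlewareTypes'

export const makeTracingMiddleware = (name: string): CompletionsMiddleware => {
  return () => (next) => async (context, params) => {
    console.log(`${name}: before next`) // runs in array order: Auth -> Cache -> Logging
    await next(context, params)
    console.log(`${name}: after next`) // runs in reverse order: Logging -> Cache -> Auth
  }
}

// With completions: [makeTracingMiddleware('Auth'), makeTracingMiddleware('Cache'), makeTracingMiddleware('Logging')]
// a single request prints:
//   Auth: before next
//   Cache: before next
//   Logging: before next
//   (core SDK call)
//   Logging: after next
//   Cache: after next
//   Auth: after next
```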

### Registering Middleware

Middleware is registered in `src/renderer/src/providers/middleware/register.ts` (or a similar configuration file).

```typescript
// register.ts
import { AiProviderMiddlewareConfig } from './AiProviderMiddlewareTypes'
import { createSimpleLoggingMiddleware } from './common/SimpleLoggingMiddleware' // assuming you created this file
import { createCompletionsLoggingMiddleware } from './common/CompletionsLoggingMiddleware' // existing middleware

const middlewareConfig: AiProviderMiddlewareConfig = {
  completions: [
    createSimpleLoggingMiddleware(), // your newly added middleware
    createCompletionsLoggingMiddleware() // the existing logging middleware
    // ... other completions middleware
  ],
  methods: {
    // translate: [createGenericLoggingMiddleware()],
    // ... middleware for other methods
  }
}

export default middlewareConfig
```

### Best Practices

1. **Single responsibility**: each middleware should focus on one specific concern (for example logging, caching, or transforming a particular piece of data).
2. **Avoid side effects (as far as possible)**: apart from the explicit side effects that go through `context` or `onChunk`, avoid mutating global state or producing other hidden side effects.
3. **Error handling**:
   - Use `try...catch` inside the middleware to handle errors that may occur.
   - Decide whether to handle the error yourself (for example, by emitting an error chunk through `onChunk`) or to rethrow it for upstream handling.
   - If you rethrow, make sure the error object carries enough information.
4. **Performance**: middleware adds overhead to request handling. Avoid very expensive synchronous work inside middleware; make IO-heavy operations asynchronous.
5. **Configurability**: make a middleware's behaviour adjustable through parameters or configuration. For example, a logging middleware can accept a log-level parameter (see the sketch after this list).
6. **Context management**:
   - Add data to the `context` carefully; avoid polluting it or attaching overly large objects.
   - Be explicit about the purpose and lifetime of any field you add to the `context`.
7. **Calling `next`**:
   - Unless you have a good reason to terminate the request early (for example, a cache hit or an authorization failure), **always make sure you call `await next(context, params)`**; otherwise downstream middleware and the core logic will never run.
   - Understand what `next` returns and handle it correctly, especially when it is a stream: you are responsible for consuming it or passing it on to a component/middleware that can.
8. **Clear naming**: give your middlewares and the functions they create descriptive names.
9. **Documentation and comments**: add comments to complex middleware logic explaining how it works and why it exists.
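As one way to follow the configurability advice above, a logging middleware might take a log-level option. A minimal sketch follows; the `LoggingOptions` shape and the level names are assumptions for illustration.

```typescript
// Sketch only: a configurable variant of the logging middleware from earlier.
import { CompletionsMiddleware } from './AiProviderMiddlewareTypes'

interface LoggingOptions {
  level: 'silent' | 'basic' | 'verbose'
}

export const createConfigurableLoggingMiddleware = (
  options: LoggingOptions = { level: 'basic' }
): CompletionsMiddleware => {
  return (api) => (next) => async (context, params) => {
    const start = Date.now()
    if (options.level !== 'silent') {
      console.log(`[${api.getProviderId()}] ${context.methodName} started`)
    }
    try {
      await next(context, params)
    } finally {
      // 'verbose' additionally reports timing once the downstream chain has finished
      if (options.level === 'verbose') {
        console.log(`[${api.getProviderId()}] ${context.methodName} finished in ${Date.now() - start}ms`)
      }
    }
  }
}
```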

### Debugging Tips

- Use `console.log` or a debugger at key points in your middleware to inspect `params`, the state of `context`, and the return value of `next`.
- Temporarily simplify the middleware chain to just the middleware you are debugging plus the simplest possible core logic, to isolate the problem.
- Write unit tests to verify each middleware's behaviour in isolation.

By following these guidelines you should be able to build powerful and maintainable middleware for our system. If you have any questions or need further help, please ask the team.
@@ -12,30 +12,43 @@ electronLanguages:
|
|||||||
directories:
|
directories:
|
||||||
buildResources: build
|
buildResources: build
|
||||||
files:
|
files:
|
||||||
- '!{.vscode,.yarn,.github}'
|
- '**/*'
|
||||||
- '!electron.vite.config.{js,ts,mjs,cjs}'
|
- '!**/{.vscode,.yarn,.yarn-lock,.github,.cursorrules,.prettierrc}'
|
||||||
- '!{.eslintignore,.eslintrc.cjs,.prettierignore,.prettierrc.yaml,dev-app-update.yml,CHANGELOG.md,README.md}'
|
- '!electron.vite.config.{js,ts,mjs,cjs}}'
|
||||||
- '!{.env,.env.*,.npmrc,pnpm-lock.yaml}'
|
- '!**/{.eslintignore,.eslintrc.js,.eslintrc.json,.eslintcache,root.eslint.config.js,eslint.config.js,.eslintrc.cjs,.prettierignore,.prettierrc.yaml,eslint.config.mjs,dev-app-update.yml,CHANGELOG.md,README.md}'
|
||||||
- '!{tsconfig.json,tsconfig.node.json,tsconfig.web.json}'
|
- '!**/{.env,.env.*,.npmrc,pnpm-lock.yaml}'
|
||||||
|
- '!**/{tsconfig.json,tsconfig.tsbuildinfo,tsconfig.node.json,tsconfig.web.json}'
|
||||||
|
- '!**/{.editorconfig,.jekyll-metadata}'
|
||||||
- '!src'
|
- '!src'
|
||||||
- '!scripts'
|
- '!scripts'
|
||||||
- '!local'
|
- '!local'
|
||||||
- '!docs'
|
- '!docs'
|
||||||
- '!packages'
|
- '!packages'
|
||||||
|
- '!.swc'
|
||||||
|
- '!.bin'
|
||||||
|
- '!._*'
|
||||||
|
- '!*.log'
|
||||||
- '!stats.html'
|
- '!stats.html'
|
||||||
- '!*.md'
|
- '!*.md'
|
||||||
|
- '!**/*.{iml,o,hprof,orig,pyc,pyo,rbc,swp,csproj,sln,xproj}'
|
||||||
- '!**/*.{map,ts,tsx,jsx,less,scss,sass,css.d.ts,d.cts,d.mts,md,markdown,yaml,yml}'
|
- '!**/*.{map,ts,tsx,jsx,less,scss,sass,css.d.ts,d.cts,d.mts,md,markdown,yaml,yml}'
|
||||||
- '!**/{test,tests,__tests__,coverage}/**'
|
- '!**/{test,tests,__tests__,powered-test,coverage}/**'
|
||||||
|
- '!**/{example,examples}/**'
|
||||||
- '!**/*.{spec,test}.{js,jsx,ts,tsx}'
|
- '!**/*.{spec,test}.{js,jsx,ts,tsx}'
|
||||||
- '!**/*.min.*.map'
|
- '!**/*.min.*.map'
|
||||||
- '!**/*.d.ts'
|
- '!**/*.d.ts'
|
||||||
- '!**/{.DS_Store,Thumbs.db}'
|
- '!**/dist/es6/**'
|
||||||
- '!**/{LICENSE,LICENSE.txt,LICENSE-MIT.txt,*.LICENSE.txt,NOTICE.txt,README.md,CHANGELOG.md}'
|
- '!**/dist/demo/**'
|
||||||
|
- '!**/amd/**'
|
||||||
|
- '!**/{.DS_Store,Thumbs.db,thumbs.db,__pycache__}'
|
||||||
|
- '!**/{LICENSE,license,LICENSE.*,*.LICENSE.txt,NOTICE.txt,README.md,readme.md,CHANGELOG.md}'
|
||||||
- '!node_modules/rollup-plugin-visualizer'
|
- '!node_modules/rollup-plugin-visualizer'
|
||||||
- '!node_modules/js-tiktoken'
|
- '!node_modules/js-tiktoken'
|
||||||
- '!node_modules/@tavily/core/node_modules/js-tiktoken'
|
- '!node_modules/@tavily/core/node_modules/js-tiktoken'
|
||||||
- '!node_modules/pdf-parse/lib/pdf.js/{v1.9.426,v1.10.88,v2.0.550}'
|
- '!node_modules/pdf-parse/lib/pdf.js/{v1.9.426,v1.10.88,v2.0.550}'
|
||||||
- '!node_modules/mammoth/{mammoth.browser.js,mammoth.browser.min.js}'
|
- '!node_modules/mammoth/{mammoth.browser.js,mammoth.browser.min.js}'
|
||||||
|
- '!node_modules/selection-hook/prebuilds/**/*' # we rebuild .node, don't use prebuilds
|
||||||
|
- '!**/*.{h,iobj,ipdb,tlog,recipe,vcxproj,vcxproj.filters}' # filter .node build files
|
||||||
asarUnpack:
|
asarUnpack:
|
||||||
- resources/**
|
- resources/**
|
||||||
- '**/*.{metal,exp,lib}'
|
- '**/*.{metal,exp,lib}'
|
||||||
@@ -94,14 +107,11 @@ afterSign: scripts/notarize.js
|
|||||||
artifactBuildCompleted: scripts/artifact-build-completed.js
|
artifactBuildCompleted: scripts/artifact-build-completed.js
|
||||||
releaseInfo:
|
releaseInfo:
|
||||||
releaseNotes: |
|
releaseNotes: |
|
||||||
⚠️ 注意:升级前请备份数据,否则将无法降级
|
划词助手:支持文本选择快捷键、开关快捷键、思考块支持和引用功能
|
||||||
增加 TokenFlux 服务商
|
复制功能:新增纯文本复制(去除Markdown格式符号)
|
||||||
增加 Claude 4 模型支持
|
知识库:支持设置向量维度,修复Ollama分数错误和维度编辑问题
|
||||||
Grok 模型增加联网能力
|
多语言:增加模型名称多语言提示和翻译源语言手动选择
|
||||||
小程序支持前进和后退
|
文件管理:修复主题/消息删除时文件未清理问题,优化文件选择流程
|
||||||
修复 Windows 用户 MCP 无法启动问题
|
模型:修复Gemini模型推理预算、Voyage AI嵌入问题和DeepSeek翻译模型更新
|
||||||
修复无法搜索历史消息问题
|
图像功能:统一图片查看器,支持Base64图片渲染,修复图片预览相关问题
|
||||||
修复 MCP 代理问题
|
UI:实现标签折叠/拖拽排序,修复气泡溢出,增加引文索引显示
|
||||||
修复精简备份恢复覆盖文件问题
|
|
||||||
修复@模型回复插入位置错误问题
|
|
||||||
修复搜索小程序崩溃问题
|
|
||||||
|
|||||||
@@ -9,25 +9,7 @@ const visualizerPlugin = (type: 'renderer' | 'main') => {
|
|||||||
|
|
||||||
export default defineConfig({
|
export default defineConfig({
|
||||||
main: {
|
main: {
|
||||||
plugins: [
|
plugins: [externalizeDepsPlugin(), ...visualizerPlugin('main')],
|
||||||
externalizeDepsPlugin({
|
|
||||||
exclude: [
|
|
||||||
'@cherrystudio/embedjs',
|
|
||||||
'@cherrystudio/embedjs-openai',
|
|
||||||
'@cherrystudio/embedjs-loader-web',
|
|
||||||
'@cherrystudio/embedjs-loader-markdown',
|
|
||||||
'@cherrystudio/embedjs-loader-msoffice',
|
|
||||||
'@cherrystudio/embedjs-loader-xml',
|
|
||||||
'@cherrystudio/embedjs-loader-pdf',
|
|
||||||
'@cherrystudio/embedjs-loader-sitemap',
|
|
||||||
'@cherrystudio/embedjs-libsql',
|
|
||||||
'@cherrystudio/embedjs-loader-image',
|
|
||||||
'p-queue',
|
|
||||||
'webdav'
|
|
||||||
]
|
|
||||||
}),
|
|
||||||
...visualizerPlugin('main')
|
|
||||||
],
|
|
||||||
resolve: {
|
resolve: {
|
||||||
alias: {
|
alias: {
|
||||||
'@main': resolve('src/main'),
|
'@main': resolve('src/main'),
|
||||||
@@ -37,8 +19,18 @@ export default defineConfig({
|
|||||||
},
|
},
|
||||||
build: {
|
build: {
|
||||||
rollupOptions: {
|
rollupOptions: {
|
||||||
external: ['@libsql/client']
|
external: ['@libsql/client', 'bufferutil', 'utf-8-validate'],
|
||||||
}
|
output: {
|
||||||
|
// 彻底禁用代码分割 - 返回 null 强制单文件打包
|
||||||
|
manualChunks: undefined,
|
||||||
|
// 内联所有动态导入,这是关键配置
|
||||||
|
inlineDynamicImports: true
|
||||||
|
}
|
||||||
|
},
|
||||||
|
sourcemap: process.env.NODE_ENV === 'development'
|
||||||
|
},
|
||||||
|
optimizeDeps: {
|
||||||
|
noDiscovery: process.env.NODE_ENV === 'development'
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
preload: {
|
preload: {
|
||||||
@@ -47,6 +39,9 @@ export default defineConfig({
|
|||||||
alias: {
|
alias: {
|
||||||
'@shared': resolve('packages/shared')
|
'@shared': resolve('packages/shared')
|
||||||
}
|
}
|
||||||
|
},
|
||||||
|
build: {
|
||||||
|
sourcemap: process.env.NODE_ENV === 'development'
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
renderer: {
|
renderer: {
|
||||||
@@ -82,7 +77,9 @@ export default defineConfig({
|
|||||||
rollupOptions: {
|
rollupOptions: {
|
||||||
input: {
|
input: {
|
||||||
index: resolve(__dirname, 'src/renderer/index.html'),
|
index: resolve(__dirname, 'src/renderer/index.html'),
|
||||||
miniWindow: resolve(__dirname, 'src/renderer/miniWindow.html')
|
miniWindow: resolve(__dirname, 'src/renderer/miniWindow.html'),
|
||||||
|
selectionToolbar: resolve(__dirname, 'src/renderer/selectionToolbar.html'),
|
||||||
|
selectionAction: resolve(__dirname, 'src/renderer/selectionAction.html')
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
133
package.json
@@ -1,6 +1,6 @@
|
|||||||
{
|
{
|
||||||
"name": "CherryStudio",
|
"name": "CherryStudio",
|
||||||
"version": "1.3.11",
|
"version": "1.4.2",
|
||||||
"private": true,
|
"private": true,
|
||||||
"description": "A powerful AI assistant for producer.",
|
"description": "A powerful AI assistant for producer.",
|
||||||
"main": "./out/main/index.js",
|
"main": "./out/main/index.js",
|
||||||
@@ -20,8 +20,9 @@
|
|||||||
"scripts": {
|
"scripts": {
|
||||||
"start": "electron-vite preview",
|
"start": "electron-vite preview",
|
||||||
"dev": "electron-vite dev",
|
"dev": "electron-vite dev",
|
||||||
|
"debug": "electron-vite -- --inspect --sourcemap --remote-debugging-port=9222",
|
||||||
"build": "npm run typecheck && electron-vite build",
|
"build": "npm run typecheck && electron-vite build",
|
||||||
"build:check": "yarn test && yarn typecheck && yarn check:i18n",
|
"build:check": "yarn typecheck && yarn check:i18n && yarn test",
|
||||||
"build:unpack": "dotenv npm run build && electron-builder --dir",
|
"build:unpack": "dotenv npm run build && electron-builder --dir",
|
||||||
"build:win": "dotenv npm run build && electron-builder --win --x64 --arm64",
|
"build:win": "dotenv npm run build && electron-builder --win --x64 --arm64",
|
||||||
"build:win:x64": "dotenv npm run build && electron-builder --win --x64",
|
"build:win:x64": "dotenv npm run build && electron-builder --win --x64",
|
||||||
@@ -37,26 +38,41 @@
|
|||||||
"publish": "yarn build:check && yarn release patch push",
|
"publish": "yarn build:check && yarn release patch push",
|
||||||
"pulish:artifacts": "cd packages/artifacts && npm publish && cd -",
|
"pulish:artifacts": "cd packages/artifacts && npm publish && cd -",
|
||||||
"generate:agents": "yarn workspace @cherry-studio/database agents",
|
"generate:agents": "yarn workspace @cherry-studio/database agents",
|
||||||
"generate:icons": "electron-icon-builder --input=./build/logo.png --output=build",
|
|
||||||
"analyze:renderer": "VISUALIZER_RENDERER=true yarn build",
|
"analyze:renderer": "VISUALIZER_RENDERER=true yarn build",
|
||||||
"analyze:main": "VISUALIZER_MAIN=true yarn build",
|
"analyze:main": "VISUALIZER_MAIN=true yarn build",
|
||||||
"typecheck": "npm run typecheck:node && npm run typecheck:web",
|
"typecheck": "npm run typecheck:node && npm run typecheck:web",
|
||||||
"typecheck:node": "tsc --noEmit -p tsconfig.node.json --composite false",
|
"typecheck:node": "tsc --noEmit -p tsconfig.node.json --composite false",
|
||||||
"typecheck:web": "tsc --noEmit -p tsconfig.web.json --composite false",
|
"typecheck:web": "tsc --noEmit -p tsconfig.web.json --composite false",
|
||||||
"check:i18n": "node scripts/check-i18n.js",
|
"check:i18n": "node scripts/check-i18n.js",
|
||||||
"test": "yarn test:renderer",
|
"test": "vitest run --silent",
|
||||||
"test:coverage": "yarn test:renderer:coverage",
|
"test:main": "vitest run --project main",
|
||||||
"test:node": "npx -y tsx --test src/**/*.test.ts",
|
"test:renderer": "vitest run --project renderer",
|
||||||
"test:renderer": "vitest run",
|
"test:update": "yarn test:renderer --update",
|
||||||
"test:renderer:ui": "vitest --ui",
|
"test:coverage": "vitest run --coverage --silent",
|
||||||
"test:renderer:coverage": "vitest run --coverage",
|
"test:ui": "vitest --ui",
|
||||||
|
"test:watch": "vitest",
|
||||||
|
"test:e2e": "yarn playwright test",
|
||||||
"test:lint": "eslint . --ext .js,.jsx,.cjs,.mjs,.ts,.tsx,.cts,.mts",
|
"test:lint": "eslint . --ext .js,.jsx,.cjs,.mjs,.ts,.tsx,.cts,.mts",
|
||||||
"format": "prettier --write .",
|
"format": "prettier --write .",
|
||||||
"lint": "eslint . --ext .js,.jsx,.cjs,.mjs,.ts,.tsx,.cts,.mts --fix",
|
"lint": "eslint . --ext .js,.jsx,.cjs,.mjs,.ts,.tsx,.cts,.mts --fix",
|
||||||
"postinstall": "electron-builder install-app-deps",
|
|
||||||
"prepare": "husky"
|
"prepare": "husky"
|
||||||
},
|
},
|
||||||
"dependencies": {
|
"dependencies": {
|
||||||
|
"@libsql/client": "0.14.0",
|
||||||
|
"@libsql/win32-x64-msvc": "^0.4.7",
|
||||||
|
"@strongtz/win32-arm64-msvc": "^0.4.7",
|
||||||
|
"jsdom": "26.1.0",
|
||||||
|
"notion-helper": "^1.3.22",
|
||||||
|
"os-proxy-config": "^1.1.2",
|
||||||
|
"selection-hook": "^0.9.23",
|
||||||
|
"turndown": "7.2.0"
|
||||||
|
},
|
||||||
|
"devDependencies": {
|
||||||
|
"@agentic/exa": "^7.3.3",
|
||||||
|
"@agentic/searxng": "^7.3.3",
|
||||||
|
"@agentic/tavily": "^7.3.3",
|
||||||
|
"@ant-design/v5-patch-for-react-19": "^1.0.3",
|
||||||
|
"@anthropic-ai/sdk": "^0.41.0",
|
||||||
"@cherrystudio/embedjs": "^0.1.31",
|
"@cherrystudio/embedjs": "^0.1.31",
|
||||||
"@cherrystudio/embedjs-libsql": "^0.1.31",
|
"@cherrystudio/embedjs-libsql": "^0.1.31",
|
||||||
"@cherrystudio/embedjs-loader-csv": "^0.1.31",
|
"@cherrystudio/embedjs-loader-csv": "^0.1.31",
|
||||||
@@ -67,63 +83,33 @@
|
|||||||
"@cherrystudio/embedjs-loader-sitemap": "^0.1.31",
|
"@cherrystudio/embedjs-loader-sitemap": "^0.1.31",
|
||||||
"@cherrystudio/embedjs-loader-web": "^0.1.31",
|
"@cherrystudio/embedjs-loader-web": "^0.1.31",
|
||||||
"@cherrystudio/embedjs-loader-xml": "^0.1.31",
|
"@cherrystudio/embedjs-loader-xml": "^0.1.31",
|
||||||
|
"@cherrystudio/embedjs-ollama": "^0.1.31",
|
||||||
"@cherrystudio/embedjs-openai": "^0.1.31",
|
"@cherrystudio/embedjs-openai": "^0.1.31",
|
||||||
"@electron-toolkit/utils": "^3.0.0",
|
|
||||||
"@electron/notarize": "^2.5.0",
|
|
||||||
"@langchain/community": "^0.3.36",
|
|
||||||
"@strongtz/win32-arm64-msvc": "^0.4.7",
|
|
||||||
"@tanstack/react-query": "^5.27.0",
|
|
||||||
"@types/react-infinite-scroll-component": "^5.0.0",
|
|
||||||
"archiver": "^7.0.1",
|
|
||||||
"async-mutex": "^0.5.0",
|
|
||||||
"color": "^5.0.0",
|
|
||||||
"diff": "^7.0.0",
|
|
||||||
"docx": "^9.0.2",
|
|
||||||
"electron-log": "^5.1.5",
|
|
||||||
"electron-store": "^8.2.0",
|
|
||||||
"electron-updater": "6.6.4",
|
|
||||||
"electron-window-state": "^5.0.3",
|
|
||||||
"epub": "patch:epub@npm%3A1.3.0#~/.yarn/patches/epub-npm-1.3.0-8325494ffe.patch",
|
|
||||||
"fast-diff": "^1.3.0",
|
|
||||||
"fast-xml-parser": "^5.2.0",
|
|
||||||
"fetch-socks": "^1.3.2",
|
|
||||||
"fs-extra": "^11.2.0",
|
|
||||||
"got-scraping": "^4.1.1",
|
|
||||||
"jsdom": "^26.0.0",
|
|
||||||
"markdown-it": "^14.1.0",
|
|
||||||
"node-stream-zip": "^1.15.0",
|
|
||||||
"officeparser": "^4.1.1",
|
|
||||||
"os-proxy-config": "^1.1.2",
|
|
||||||
"proxy-agent": "^6.5.0",
|
|
||||||
"tar": "^7.4.3",
|
|
||||||
"turndown": "^7.2.0",
|
|
||||||
"turndown-plugin-gfm": "^1.0.2",
|
|
||||||
"webdav": "^5.8.0",
|
|
||||||
"ws": "^8.18.1",
|
|
||||||
"zipread": "^1.3.3"
|
|
||||||
},
|
|
||||||
"devDependencies": {
|
|
||||||
"@agentic/exa": "^7.3.3",
|
|
||||||
"@agentic/searxng": "^7.3.3",
|
|
||||||
"@agentic/tavily": "^7.3.3",
|
|
||||||
"@ant-design/v5-patch-for-react-19": "^1.0.3",
|
|
||||||
"@anthropic-ai/sdk": "^0.41.0",
|
|
||||||
"@electron-toolkit/eslint-config-prettier": "^3.0.0",
|
"@electron-toolkit/eslint-config-prettier": "^3.0.0",
|
||||||
"@electron-toolkit/eslint-config-ts": "^3.0.0",
|
"@electron-toolkit/eslint-config-ts": "^3.0.0",
|
||||||
"@electron-toolkit/preload": "^3.0.0",
|
"@electron-toolkit/preload": "^3.0.0",
|
||||||
"@electron-toolkit/tsconfig": "^1.0.1",
|
"@electron-toolkit/tsconfig": "^1.0.1",
|
||||||
|
"@electron-toolkit/utils": "^3.0.0",
|
||||||
|
"@electron/notarize": "^2.5.0",
|
||||||
"@emotion/is-prop-valid": "^1.3.1",
|
"@emotion/is-prop-valid": "^1.3.1",
|
||||||
"@eslint-react/eslint-plugin": "^1.36.1",
|
"@eslint-react/eslint-plugin": "^1.36.1",
|
||||||
"@eslint/js": "^9.22.0",
|
"@eslint/js": "^9.22.0",
|
||||||
"@google/genai": "^0.13.0",
|
"@google/genai": "patch:@google/genai@npm%3A1.0.1#~/.yarn/patches/@google-genai-npm-1.0.1-e26f0f9af7.patch",
|
||||||
"@hello-pangea/dnd": "^16.6.0",
|
"@hello-pangea/dnd": "^16.6.0",
|
||||||
"@kangfenmao/keyv-storage": "^0.1.0",
|
"@kangfenmao/keyv-storage": "^0.1.0",
|
||||||
|
"@langchain/community": "^0.3.36",
|
||||||
|
"@langchain/ollama": "^0.2.1",
|
||||||
"@modelcontextprotocol/sdk": "^1.11.4",
|
"@modelcontextprotocol/sdk": "^1.11.4",
|
||||||
"@mozilla/readability": "^0.6.0",
|
"@mozilla/readability": "^0.6.0",
|
||||||
"@notionhq/client": "^2.2.15",
|
"@notionhq/client": "^2.2.15",
|
||||||
|
"@playwright/test": "^1.52.0",
|
||||||
"@reduxjs/toolkit": "^2.2.5",
|
"@reduxjs/toolkit": "^2.2.5",
|
||||||
"@shikijs/markdown-it": "^3.4.2",
|
"@shikijs/markdown-it": "^3.4.2",
|
||||||
"@swc/plugin-styled-components": "^7.1.5",
|
"@swc/plugin-styled-components": "^7.1.5",
|
||||||
|
"@tanstack/react-query": "^5.27.0",
|
||||||
|
"@testing-library/dom": "^10.4.0",
|
||||||
|
"@testing-library/jest-dom": "^6.6.3",
|
||||||
|
"@testing-library/react": "^16.3.0",
|
||||||
"@tryfabric/martian": "^1.2.4",
|
"@tryfabric/martian": "^1.2.4",
|
||||||
"@types/diff": "^7",
|
"@types/diff": "^7",
|
||||||
"@types/fs-extra": "^11",
|
"@types/fs-extra": "^11",
|
||||||
@@ -137,46 +123,67 @@
|
|||||||
"@types/react-infinite-scroll-component": "^5.0.0",
|
"@types/react-infinite-scroll-component": "^5.0.0",
|
||||||
"@types/react-window": "^1",
|
"@types/react-window": "^1",
|
||||||
"@types/tinycolor2": "^1",
|
"@types/tinycolor2": "^1",
|
||||||
"@types/ws": "^8",
|
|
||||||
"@uiw/codemirror-extensions-langs": "^4.23.12",
|
"@uiw/codemirror-extensions-langs": "^4.23.12",
|
||||||
"@uiw/codemirror-themes-all": "^4.23.12",
|
"@uiw/codemirror-themes-all": "^4.23.12",
|
||||||
"@uiw/react-codemirror": "^4.23.12",
|
"@uiw/react-codemirror": "^4.23.12",
|
||||||
"@vitejs/plugin-react-swc": "^3.9.0",
|
"@vitejs/plugin-react-swc": "^3.9.0",
|
||||||
"@vitest/ui": "^3.1.1",
|
"@vitest/browser": "^3.1.4",
|
||||||
"@vitest/web-worker": "^3.1.3",
|
"@vitest/coverage-v8": "^3.1.4",
|
||||||
|
"@vitest/ui": "^3.1.4",
|
||||||
|
"@vitest/web-worker": "^3.1.4",
|
||||||
"@xyflow/react": "^12.4.4",
|
"@xyflow/react": "^12.4.4",
|
||||||
"antd": "^5.22.5",
|
"antd": "^5.22.5",
|
||||||
|
"archiver": "^7.0.1",
|
||||||
|
"async-mutex": "^0.5.0",
|
||||||
"axios": "^1.7.3",
|
"axios": "^1.7.3",
|
||||||
"browser-image-compression": "^2.0.2",
|
"browser-image-compression": "^2.0.2",
|
||||||
|
"color": "^5.0.0",
|
||||||
"dayjs": "^1.11.11",
|
"dayjs": "^1.11.11",
|
||||||
"dexie": "^4.0.8",
|
"dexie": "^4.0.8",
|
||||||
"dexie-react-hooks": "^1.1.7",
|
"dexie-react-hooks": "^1.1.7",
|
||||||
|
"diff": "^7.0.0",
|
||||||
|
"docx": "^9.0.2",
|
||||||
"dotenv-cli": "^7.4.2",
|
"dotenv-cli": "^7.4.2",
|
||||||
"electron": "35.4.0",
|
"electron": "35.4.0",
|
||||||
"electron-builder": "26.0.15",
|
"electron-builder": "26.0.15",
|
||||||
"electron-devtools-installer": "^3.2.0",
|
"electron-devtools-installer": "^3.2.0",
|
||||||
"electron-icon-builder": "^2.0.1",
|
"electron-log": "^5.1.5",
|
||||||
|
"electron-store": "^8.2.0",
|
||||||
|
"electron-updater": "6.6.4",
|
||||||
"electron-vite": "^3.1.0",
|
"electron-vite": "^3.1.0",
|
||||||
|
"electron-window-state": "^5.0.3",
|
||||||
"emittery": "^1.0.3",
|
"emittery": "^1.0.3",
|
||||||
"emoji-picker-element": "^1.22.1",
|
"emoji-picker-element": "^1.22.1",
|
||||||
|
"epub": "patch:epub@npm%3A1.3.0#~/.yarn/patches/epub-npm-1.3.0-8325494ffe.patch",
|
||||||
"eslint": "^9.22.0",
|
"eslint": "^9.22.0",
|
||||||
"eslint-plugin-react-hooks": "^5.2.0",
|
"eslint-plugin-react-hooks": "^5.2.0",
|
||||||
"eslint-plugin-simple-import-sort": "^12.1.1",
|
"eslint-plugin-simple-import-sort": "^12.1.1",
|
||||||
"eslint-plugin-unused-imports": "^4.1.4",
|
"eslint-plugin-unused-imports": "^4.1.4",
|
||||||
|
"fast-diff": "^1.3.0",
|
||||||
|
"fast-xml-parser": "^5.2.0",
|
||||||
|
"franc-min": "^6.2.0",
|
||||||
|
"fs-extra": "^11.2.0",
|
||||||
|
"google-auth-library": "^9.15.1",
|
||||||
"html-to-image": "^1.11.13",
|
"html-to-image": "^1.11.13",
|
||||||
"husky": "^9.1.7",
|
"husky": "^9.1.7",
|
||||||
"i18next": "^23.11.5",
|
"i18next": "^23.11.5",
|
||||||
|
"jest-styled-components": "^7.2.0",
|
||||||
"lint-staged": "^15.5.0",
|
"lint-staged": "^15.5.0",
|
||||||
"lodash": "^4.17.21",
|
"lodash": "^4.17.21",
|
||||||
"lru-cache": "^11.1.0",
|
"lru-cache": "^11.1.0",
|
||||||
"lucide-react": "^0.487.0",
|
"lucide-react": "^0.487.0",
|
||||||
|
"markdown-it": "^14.1.0",
|
||||||
"mermaid": "^11.6.0",
|
"mermaid": "^11.6.0",
|
||||||
"mime": "^4.0.4",
|
"mime": "^4.0.4",
|
||||||
"motion": "^12.10.5",
|
"motion": "^12.10.5",
|
||||||
|
"node-stream-zip": "^1.15.0",
|
||||||
"npx-scope-finder": "^1.2.0",
|
"npx-scope-finder": "^1.2.0",
|
||||||
"openai": "patch:openai@npm%3A4.96.0#~/.yarn/patches/openai-npm-4.96.0-0665b05cb9.patch",
|
"officeparser": "^4.1.1",
|
||||||
|
"openai": "patch:openai@npm%3A5.1.0#~/.yarn/patches/openai-npm-5.1.0-0e7b3ccb07.patch",
|
||||||
"p-queue": "^8.1.0",
|
"p-queue": "^8.1.0",
|
||||||
|
"playwright": "^1.52.0",
|
||||||
"prettier": "^3.5.3",
|
"prettier": "^3.5.3",
|
||||||
|
"proxy-agent": "^6.5.0",
|
||||||
"rc-virtual-list": "^3.18.6",
|
"rc-virtual-list": "^3.18.6",
|
||||||
"react": "^19.0.0",
|
"react": "^19.0.0",
|
||||||
"react-dom": "^19.0.0",
|
"react-dom": "^19.0.0",
|
||||||
@@ -197,29 +204,33 @@
|
|||||||
"remark-cjk-friendly": "^1.1.0",
|
"remark-cjk-friendly": "^1.1.0",
|
||||||
"remark-gfm": "^4.0.0",
|
"remark-gfm": "^4.0.0",
|
||||||
"remark-math": "^6.0.0",
|
"remark-math": "^6.0.0",
|
||||||
|
"remove-markdown": "^0.6.2",
|
||||||
"rollup-plugin-visualizer": "^5.12.0",
|
"rollup-plugin-visualizer": "^5.12.0",
|
||||||
"sass": "^1.88.0",
|
"sass": "^1.88.0",
|
||||||
"shiki": "^3.4.2",
|
"shiki": "^3.4.2",
|
||||||
"string-width": "^7.2.0",
|
"string-width": "^7.2.0",
|
||||||
"styled-components": "^6.1.11",
|
"styled-components": "^6.1.11",
|
||||||
|
"tar": "^7.4.3",
|
||||||
"tiny-pinyin": "^1.3.2",
|
"tiny-pinyin": "^1.3.2",
|
||||||
"tokenx": "^0.4.1",
|
"tokenx": "^0.4.1",
|
||||||
"typescript": "^5.6.2",
|
"typescript": "^5.6.2",
|
||||||
"uuid": "^10.0.0",
|
"uuid": "^10.0.0",
|
||||||
"vite": "6.2.6",
|
"vite": "6.2.6",
|
||||||
"vitest": "^3.1.1"
|
"vitest": "^3.1.4",
|
||||||
|
"webdav": "^5.8.0",
|
||||||
|
"zipread": "^1.3.3"
|
||||||
},
|
},
|
||||||
"resolutions": {
|
"resolutions": {
|
||||||
"pdf-parse@npm:1.1.1": "patch:pdf-parse@npm%3A1.1.1#~/.yarn/patches/pdf-parse-npm-1.1.1-04a6109b2a.patch",
|
"pdf-parse@npm:1.1.1": "patch:pdf-parse@npm%3A1.1.1#~/.yarn/patches/pdf-parse-npm-1.1.1-04a6109b2a.patch",
|
||||||
"@langchain/openai@npm:^0.3.16": "patch:@langchain/openai@npm%3A0.3.16#~/.yarn/patches/@langchain-openai-npm-0.3.16-e525b59526.patch",
|
"@langchain/openai@npm:^0.3.16": "patch:@langchain/openai@npm%3A0.3.16#~/.yarn/patches/@langchain-openai-npm-0.3.16-e525b59526.patch",
|
||||||
"@langchain/openai@npm:>=0.1.0 <0.4.0": "patch:@langchain/openai@npm%3A0.3.16#~/.yarn/patches/@langchain-openai-npm-0.3.16-e525b59526.patch",
|
"@langchain/openai@npm:>=0.1.0 <0.4.0": "patch:@langchain/openai@npm%3A0.3.16#~/.yarn/patches/@langchain-openai-npm-0.3.16-e525b59526.patch",
|
||||||
"node-gyp": "^9.1.0",
|
|
||||||
"libsql@npm:^0.4.4": "patch:libsql@npm%3A0.4.7#~/.yarn/patches/libsql-npm-0.4.7-444e260fb1.patch",
|
"libsql@npm:^0.4.4": "patch:libsql@npm%3A0.4.7#~/.yarn/patches/libsql-npm-0.4.7-444e260fb1.patch",
|
||||||
"openai@npm:^4.77.0": "patch:openai@npm%3A4.96.0#~/.yarn/patches/openai-npm-4.96.0-0665b05cb9.patch",
|
"openai@npm:^4.77.0": "patch:openai@npm%3A5.1.0#~/.yarn/patches/openai-npm-5.1.0-0e7b3ccb07.patch",
|
||||||
"pkce-challenge@npm:^4.1.0": "patch:pkce-challenge@npm%3A4.1.0#~/.yarn/patches/pkce-challenge-npm-4.1.0-fbc51695a3.patch",
|
"pkce-challenge@npm:^4.1.0": "patch:pkce-challenge@npm%3A4.1.0#~/.yarn/patches/pkce-challenge-npm-4.1.0-fbc51695a3.patch",
|
||||||
"app-builder-lib@npm:26.0.13": "patch:app-builder-lib@npm%3A26.0.13#~/.yarn/patches/app-builder-lib-npm-26.0.13-a064c9e1d0.patch",
|
"app-builder-lib@npm:26.0.13": "patch:app-builder-lib@npm%3A26.0.13#~/.yarn/patches/app-builder-lib-npm-26.0.13-a064c9e1d0.patch",
|
||||||
"openai@npm:^4.87.3": "patch:openai@npm%3A4.96.0#~/.yarn/patches/openai-npm-4.96.0-0665b05cb9.patch",
|
"openai@npm:^4.87.3": "patch:openai@npm%3A5.1.0#~/.yarn/patches/openai-npm-5.1.0-0e7b3ccb07.patch",
|
||||||
"app-builder-lib@npm:26.0.15": "patch:app-builder-lib@npm%3A26.0.15#~/.yarn/patches/app-builder-lib-npm-26.0.15-360e5b0476.patch"
|
"app-builder-lib@npm:26.0.15": "patch:app-builder-lib@npm%3A26.0.15#~/.yarn/patches/app-builder-lib-npm-26.0.15-360e5b0476.patch",
|
||||||
|
"@langchain/core@npm:^0.3.26": "patch:@langchain/core@npm%3A0.3.44#~/.yarn/patches/@langchain-core-npm-0.3.44-41d5c3cb0a.patch"
|
||||||
},
|
},
|
||||||
"packageManager": "yarn@4.9.1",
|
"packageManager": "yarn@4.9.1",
|
||||||
"lint-staged": {
|
"lint-staged": {
|
||||||
|
|||||||
@@ -11,9 +11,9 @@ export enum IpcChannel {
|
|||||||
App_SetLaunchToTray = 'app:set-launch-to-tray',
|
App_SetLaunchToTray = 'app:set-launch-to-tray',
|
||||||
App_SetTray = 'app:set-tray',
|
App_SetTray = 'app:set-tray',
|
||||||
App_SetTrayOnClose = 'app:set-tray-on-close',
|
App_SetTrayOnClose = 'app:set-tray-on-close',
|
||||||
App_RestartTray = 'app:restart-tray',
|
|
||||||
App_SetTheme = 'app:set-theme',
|
App_SetTheme = 'app:set-theme',
|
||||||
App_SetAutoUpdate = 'app:set-auto-update',
|
App_SetAutoUpdate = 'app:set-auto-update',
|
||||||
|
App_SetFeedUrl = 'app:set-feed-url',
|
||||||
App_HandleZoomFactor = 'app:handle-zoom-factor',
|
App_HandleZoomFactor = 'app:handle-zoom-factor',
|
||||||
|
|
||||||
App_IsBinaryExist = 'app:is-binary-exist',
|
App_IsBinaryExist = 'app:is-binary-exist',
|
||||||
@@ -21,6 +21,8 @@ export enum IpcChannel {
|
|||||||
App_InstallUvBinary = 'app:install-uv-binary',
|
App_InstallUvBinary = 'app:install-uv-binary',
|
||||||
App_InstallBunBinary = 'app:install-bun-binary',
|
App_InstallBunBinary = 'app:install-bun-binary',
|
||||||
|
|
||||||
|
App_QuoteToMain = 'app:quote-to-main',
|
||||||
|
|
||||||
Notification_Send = 'notification:send',
|
Notification_Send = 'notification:send',
|
||||||
Notification_OnClick = 'notification:on-click',
|
Notification_OnClick = 'notification:on-click',
|
||||||
|
|
||||||
@@ -84,6 +86,10 @@ export enum IpcChannel {
|
|||||||
Gemini_ListFiles = 'gemini:list-files',
|
Gemini_ListFiles = 'gemini:list-files',
|
||||||
Gemini_DeleteFile = 'gemini:delete-file',
|
Gemini_DeleteFile = 'gemini:delete-file',
|
||||||
|
|
||||||
|
// VertexAI
|
||||||
|
VertexAI_GetAuthHeaders = 'vertexai:get-auth-headers',
|
||||||
|
VertexAI_ClearAuthCache = 'vertexai:clear-auth-cache',
|
||||||
|
|
||||||
Windows_ResetMinimumSize = 'window:reset-minimum-size',
|
Windows_ResetMinimumSize = 'window:reset-minimum-size',
|
||||||
Windows_SetMinimumSize = 'window:set-minimum-size',
|
Windows_SetMinimumSize = 'window:set-minimum-size',
|
||||||
|
|
||||||
@@ -111,10 +117,12 @@ export enum IpcChannel {
|
|||||||
File_WriteWithId = 'file:writeWithId',
|
File_WriteWithId = 'file:writeWithId',
|
||||||
File_SaveImage = 'file:saveImage',
|
File_SaveImage = 'file:saveImage',
|
||||||
File_Base64Image = 'file:base64Image',
|
File_Base64Image = 'file:base64Image',
|
||||||
|
File_SaveBase64Image = 'file:saveBase64Image',
|
||||||
File_Download = 'file:download',
|
File_Download = 'file:download',
|
||||||
File_Copy = 'file:copy',
|
File_Copy = 'file:copy',
|
||||||
File_BinaryImage = 'file:binaryImage',
|
File_BinaryImage = 'file:binaryImage',
|
||||||
File_Base64File = 'file:base64File',
|
File_Base64File = 'file:base64File',
|
||||||
|
File_GetPdfInfo = 'file:getPdfInfo',
|
||||||
Fs_Read = 'fs:read',
|
Fs_Read = 'fs:read',
|
||||||
|
|
||||||
Export_Word = 'export:word',
|
Export_Word = 'export:word',
|
||||||
@@ -144,7 +152,7 @@ export enum IpcChannel {
|
|||||||
|
|
||||||
// events
|
// events
|
||||||
BackupProgress = 'backup-progress',
|
BackupProgress = 'backup-progress',
|
||||||
ThemeChange = 'theme:change',
|
ThemeUpdated = 'theme:updated',
|
||||||
UpdateDownloadedCancelled = 'update-downloaded-cancelled',
|
UpdateDownloadedCancelled = 'update-downloaded-cancelled',
|
||||||
RestoreProgress = 'restore-progress',
|
RestoreProgress = 'restore-progress',
|
||||||
UpdateError = 'update-error',
|
UpdateError = 'update-error',
|
||||||
@@ -176,5 +184,23 @@ export enum IpcChannel {
|
|||||||
StoreSync_BroadcastSync = 'store-sync:broadcast-sync',
|
StoreSync_BroadcastSync = 'store-sync:broadcast-sync',
|
||||||
|
|
||||||
// Provider
|
// Provider
|
||||||
Provider_AddKey = 'provider:add-key'
|
Provider_AddKey = 'provider:add-key',
|
||||||
|
|
||||||
|
//Selection Assistant
|
||||||
|
Selection_TextSelected = 'selection:text-selected',
|
||||||
|
Selection_ToolbarHide = 'selection:toolbar-hide',
|
||||||
|
Selection_ToolbarVisibilityChange = 'selection:toolbar-visibility-change',
|
||||||
|
Selection_ToolbarDetermineSize = 'selection:toolbar-determine-size',
|
||||||
|
Selection_WriteToClipboard = 'selection:write-to-clipboard',
|
||||||
|
Selection_SetEnabled = 'selection:set-enabled',
|
||||||
|
Selection_SetTriggerMode = 'selection:set-trigger-mode',
|
||||||
|
Selection_SetFilterMode = 'selection:set-filter-mode',
|
||||||
|
Selection_SetFilterList = 'selection:set-filter-list',
|
||||||
|
Selection_SetFollowToolbar = 'selection:set-follow-toolbar',
|
||||||
|
Selection_SetRemeberWinSize = 'selection:set-remeber-win-size',
|
||||||
|
Selection_ActionWindowClose = 'selection:action-window-close',
|
||||||
|
Selection_ActionWindowMinimize = 'selection:action-window-minimize',
|
||||||
|
Selection_ActionWindowPin = 'selection:action-window-pin',
|
||||||
|
Selection_ProcessAction = 'selection:process-action',
|
||||||
|
Selection_UpdateActionData = 'selection:update-action-data'
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,135 +4,368 @@ export const audioExts = ['.mp3', '.wav', '.ogg', '.flac', '.aac']
|
|||||||
export const documentExts = ['.pdf', '.docx', '.pptx', '.xlsx', '.odt', '.odp', '.ods']
|
export const documentExts = ['.pdf', '.docx', '.pptx', '.xlsx', '.odt', '.odp', '.ods']
|
||||||
export const thirdPartyApplicationExts = ['.draftsExport']
|
export const thirdPartyApplicationExts = ['.draftsExport']
|
||||||
export const bookExts = ['.epub']
|
export const bookExts = ['.epub']
|
||||||
export const textExts = [
|
const textExtsByCategory = new Map([
|
||||||
'.txt', // 普通文本文件
|
[
|
||||||
'.md', // Markdown 文件
|
'language',
|
||||||
'.mdx', // Markdown 文件
|
[
|
||||||
'.html', // HTML 文件
|
'.js',
|
||||||
'.htm', // HTML 文件的另一种扩展名
|
'.mjs',
|
||||||
'.xml', // XML 文件
|
'.cjs',
|
||||||
'.json', // JSON 文件
|
'.ts',
|
||||||
'.yaml', // YAML 文件
|
'.jsx',
|
||||||
'.yml', // YAML 文件的另一种扩展名
|
'.tsx', // JavaScript/TypeScript
|
||||||
'.csv', // 逗号分隔值文件
|
'.py', // Python
|
||||||
'.tsv', // 制表符分隔值文件
|
'.java', // Java
|
||||||
'.ini', // 配置文件
|
'.cs', // C#
|
||||||
'.log', // 日志文件
|
'.cpp',
|
||||||
'.rtf', // 富文本格式文件
|
'.c',
|
||||||
'.org', // org-mode 文件
|
'.h',
|
||||||
'.wiki', // VimWiki 文件
|
'.hpp',
|
||||||
'.tex', // LaTeX 文件
|
'.cc',
|
||||||
'.bib', // BibTeX 文件
|
'.cxx',
|
||||||
'.srt', // 字幕文件
|
'.cppm',
|
||||||
'.xhtml', // XHTML 文件
|
'.ipp',
|
||||||
'.nfo', // 信息文件(主要用于场景发布)
|
'.ixx', // C/C++
|
||||||
'.conf', // 配置文件
|
'.php', // PHP
|
||||||
'.config', // 配置文件
|
'.rb', // Ruby
|
||||||
'.env', // 环境变量文件
|
'.pl', // Perl
|
||||||
'.rst', // reStructuredText 文件
|
'.go', // Go
|
||||||
'.php', // PHP 脚本文件,包含嵌入的 HTML
|
'.rs', // Rust
|
||||||
'.js', // JavaScript 文件(部分是文本,部分可能包含代码)
|
'.swift', // Swift
|
||||||
'.ts', // TypeScript 文件
|
'.kt',
|
||||||
'.jsp', // JavaServer Pages 文件
|
'.kts', // Kotlin
|
||||||
'.aspx', // ASP.NET 文件
|
'.scala', // Scala
|
||||||
'.bat', // Windows 批处理文件
|
'.lua', // Lua
|
||||||
'.sh', // Unix/Linux Shell 脚本文件
|
'.groovy', // Groovy
|
||||||
'.py', // Python 脚本文件
|
'.dart', // Dart
|
||||||
'.ipynb', // Jupyter 笔记本格式
|
'.hs', // Haskell
|
||||||
'.rb', // Ruby 脚本文件
|
'.clj',
|
||||||
'.pl', // Perl 脚本文件
|
'.cljs', // Clojure
|
||||||
'.sql', // SQL 脚本文件
|
'.elm', // Elm
|
||||||
'.css', // Cascading Style Sheets 文件
|
'.erl', // Erlang
|
||||||
'.less', // Less CSS 预处理器文件
|
'.ex',
|
||||||
'.scss', // Sass CSS 预处理器文件
|
'.exs', // Elixir
|
||||||
'.sass', // Sass 文件
|
'.ml',
|
||||||
'.styl', // Stylus CSS 预处理器文件
|
'.mli', // OCaml
|
||||||
'.coffee', // CoffeeScript 文件
|
'.fs', // F#
|
||||||
'.ino', // Arduino 代码文件
|
'.r',
|
||||||
'.asm', // Assembly 语言文件
|
'.R', // R
|
||||||
'.go', // Go 语言文件
|
'.sol', // Solidity
|
||||||
'.scala', // Scala 语言文件
|
'.awk', // AWK
|
||||||
'.swift', // Swift 语言文件
|
'.cob', // COBOL
|
||||||
'.kt', // Kotlin 语言文件
|
'.asm',
|
||||||
'.rs', // Rust 语言文件
|
'.s', // Assembly
|
||||||
'.lua', // Lua 语言文件
|
'.lisp',
|
||||||
'.groovy', // Groovy 语言文件
|
'.lsp', // Lisp
|
||||||
'.dart', // Dart 语言文件
|
'.coffee', // CoffeeScript
|
||||||
'.hs', // Haskell 语言文件
|
'.ino', // Arduino
|
||||||
'.clj', // Clojure 语言文件
|
'.jl', // Julia
|
||||||
'.cljs', // ClojureScript 语言文件
|
'.nim', // Nim
|
||||||
'.elm', // Elm 语言文件
|
'.zig', // Zig
|
||||||
'.erl', // Erlang 语言文件
|
'.d', // D语言
|
||||||
'.ex', // Elixir 语言文件
|
'.pas', // Pascal
|
||||||
'.exs', // Elixir 脚本文件
|
'.vb', // Visual Basic
|
||||||
'.pug', // Pug (formerly Jade) 模板文件
|
'.rkt', // Racket
|
||||||
'.haml', // Haml 模板文件
|
'.scm', // Scheme
|
||||||
'.slim', // Slim 模板文件
|
'.hx', // Haxe
|
||||||
'.tpl', // 模板文件(通用)
|
'.as', // ActionScript
|
||||||
'.ejs', // Embedded JavaScript 模板文件
|
'.pde', // Processing
|
||||||
'.hbs', // Handlebars 模板文件
|
'.f90',
|
||||||
'.mustache', // Mustache 模板文件
|
'.f',
|
||||||
'.jade', // Jade 模板文件 (已重命名为 Pug)
|
'.f03',
|
||||||
'.twig', // Twig 模板文件
|
'.for',
|
||||||
'.blade', // Blade 模板文件 (Laravel)
|
'.f95', // Fortran
|
||||||
'.vue', // Vue.js 单文件组件
|
'.adb',
|
||||||
'.jsx', // React JSX 文件
|
'.ads', // Ada
|
||||||
'.tsx', // React TSX 文件
|
-      '.graphql', // GraphQL query language files
-      '.gql', // GraphQL query language files
-      '.proto', // Protocol Buffers files
-      '.thrift', // Thrift files
-      '.toml', // TOML config files
-      '.edn', // Clojure data files
-      '.cake', // CakePHP config files
-      '.ctp', // CakePHP view files
-      '.cfm', // ColdFusion markup files
-      '.cfc', // ColdFusion component files
-      '.m', // Objective-C or MATLAB source files
-      '.mm', // Objective-C++ source files
-      '.gradle', // Gradle build files
-      '.groovy', // Gradle build files
-      '.kts', // Kotlin Script files
-      '.java', // Java source files
-      '.cs', // C# source files
-      '.cpp', // C++ source files
-      '.c', // C source files
-      '.h', // C/C++ header files
-      '.hpp', // C++ header files
-      '.cc', // C++ source files
-      '.cxx', // C++ source files
-      '.cppm', // C++20 module interface files
-      '.ipp', // template implementation files
-      '.ixx', // C++20 module implementation files
-      '.f90', // Fortran 90 source files
-      '.f', // Fortran fixed-form source files
-      '.f03', // Fortran 2003+ source files
-      '.ahk', // AutoHotkey files
-      '.tcl', // Tcl scripts
-      '.do', // Questa/ModelSim Tcl scripts
-      '.v', // Verilog source files
-      '.sv', // SystemVerilog source files
-      '.svh', // SystemVerilog header files
-      '.vhd', // VHDL source files
-      '.vhdl', // VHDL source files
-      '.lef', // Library Exchange Format
-      '.def', // Design Exchange Format
-      '.edif', // Electronic Design Interchange Format
-      '.sdf', // Standard Delay Format
-      '.sdc', // Synopsys Design Constraints
-      '.xdc', // Xilinx Design Constraints
-      '.rpt', // report files
-      '.lisp', // Lisp scripts
-      '.il', // Cadence SKILL scripts
-      '.ils', // Cadence SKILL++ scripts
-      '.sp', // SPICE netlist files
-      '.spi', // SPICE netlist files
-      '.cir', // SPICE netlist files
-      '.net', // SPICE netlist files
-      '.scs', // Spectre netlist files
-      '.asc', // LTspice schematic netlist files
-      '.tf' // Technology File
-    ]
+      '.pro', // Prolog
+      '.m', '.mm', // Objective-C/MATLAB
+      '.rpy', // Ren'Py
+      '.ets', // OpenHarmony
+      '.uniswap', // DeFi
+      '.vy', // Vyper
+      '.shader', '.glsl', '.frag', '.vert',
+      '.gd' // Godot
+    ]
+  ],
+  [
+    'script',
+    [
+      '.sh', // Shell
+      '.bat', '.cmd', // Windows batch
+      '.ps1', // PowerShell
+      '.tcl', '.do', // Tcl
+      '.ahk', // AutoHotkey
+      '.zsh', // Zsh
+      '.fish', // Fish shell
+      '.csh', // C shell
+      '.vbs', // VBScript
+      '.applescript', // AppleScript
+      '.au3', // AutoIt
+      '.bash', '.nu'
+    ]
+  ],
+  [
+    'style',
+    ['.css', '.less', '.scss', '.sass', '.styl', '.pcss', '.postcss']
+  ],
+  [
+    'template',
+    [
+      '.vue', // Vue.js
+      '.pug', '.jade', // Pug/Jade
+      '.haml', // Haml
+      '.slim', // Slim
+      '.tpl', // generic templates
+      '.ejs', // EJS
+      '.hbs', // Handlebars
+      '.mustache', // Mustache
+      '.twig', // Twig
+      '.blade', // Blade (Laravel)
+      '.liquid', // Liquid
+      '.jinja', '.jinja2', '.j2', // Jinja
+      '.erb', // ERB
+      '.vm', // Velocity
+      '.ftl', // FreeMarker
+      '.svelte', // Svelte
+      '.astro' // Astro
+    ]
+  ],
+  [
+    'config',
+    [
+      '.ini', // INI config
+      '.conf', '.config', // generic config
+      '.env', // environment variables
+      '.toml', // TOML
+      '.cfg', // generic config
+      '.properties', // Java properties
+      '.desktop', // Linux desktop entries
+      '.service', // systemd services
+      '.rc', '.bashrc', '.zshrc', // shell config
+      '.fishrc', // Fish shell config
+      '.vimrc', // Vim config
+      '.htaccess', // Apache config
+      '.robots', // robots.txt
+      '.editorconfig', // EditorConfig
+      '.eslintrc', // ESLint
+      '.prettierrc', // Prettier
+      '.babelrc', // Babel
+      '.npmrc', // npm
+      '.dockerignore', // Docker ignore
+      '.npmignore', '.yarnrc', '.prettierignore', '.eslintignore', '.browserslistrc',
+      '.json5', '.tfvars'
+    ]
+  ],
+  [
+    'document',
+    [
+      '.txt', '.text', // plain text
+      '.md', '.mdx', // Markdown
+      '.html', '.htm', '.xhtml', // HTML
+      '.xml', // XML
+      '.org', // Org-mode
+      '.wiki', // Wiki
+      '.tex', '.bib', // LaTeX
+      '.rst', // reStructuredText
+      '.rtf', // rich text
+      '.nfo', // info files
+      '.adoc', '.asciidoc', // AsciiDoc
+      '.pod', // Perl documentation
+      '.1', '.2', '.3', '.4', '.5', '.6', '.7', '.8', '.9', '.man', // man pages
+      '.texi', '.texinfo', // Texinfo
+      '.readme', '.me', // README
+      '.changelog', // changelogs
+      '.license', // licenses
+      '.authors', // author files
+      '.po', '.pot'
+    ]
+  ],
+  [
+    'data',
+    [
+      '.json', // JSON
+      '.jsonc', // JSON with comments
+      '.yaml', '.yml', // YAML
+      '.csv', '.tsv', // delimited values
+      '.edn', // Clojure data
+      '.jsonl', '.ndjson', // newline-delimited JSON
+      '.geojson', // GeoJSON
+      '.gpx', // GPS Exchange
+      '.kml', // Keyhole Markup
+      '.rss', '.atom', // feed formats
+      '.vcf', // vCard
+      '.ics', // iCalendar
+      '.ldif', // LDAP data interchange
+      '.pbtxt', '.map'
+    ]
+  ],
+  [
+    'build',
+    [
+      '.gradle', // Gradle
+      '.make', '.mk', // Make
+      '.cmake', // CMake
+      '.sbt', // SBT
+      '.rake', // Rake
+      '.spec', // RPM spec
+      '.pom',
+      '.build', // Meson
+      '.bazel' // Bazel
+    ]
+  ],
+  [
+    'database',
+    [
+      '.sql', // SQL
+      '.ddl', '.dml', // DDL/DML
+      '.plsql', // PL/SQL
+      '.psql', // PostgreSQL
+      '.cypher', // Cypher
+      '.sparql' // SPARQL
+    ]
+  ],
+  [
+    'web',
+    [
+      '.graphql', '.gql', // GraphQL
+      '.proto', // Protocol Buffers
+      '.thrift', // Thrift
+      '.wsdl', // WSDL
+      '.raml', // RAML
+      '.swagger', '.openapi' // API docs
+    ]
+  ],
+  [
+    'version',
+    [
+      '.gitignore', '.gitattributes', '.gitconfig', // Git
+      '.hgignore', // Mercurial ignore
+      '.bzrignore', // Bazaar ignore
+      '.svnignore', // SVN ignore
+      '.githistory' // Git history
+    ]
+  ],
+  [
+    'subtitle',
+    ['.srt', '.sub', '.ass'] // subtitle formats
+  ],
+  [
+    'log',
+    ['.log', '.rpt'] // logs and reports (.out removed, it is usually a binary executable)
+  ],
+  [
+    'eda',
+    [
+      '.v', '.sv', '.svh', // Verilog/SystemVerilog
+      '.vhd', '.vhdl', // VHDL
+      '.lef', '.def', // LEF/DEF
+      '.edif', // EDIF
+      '.sdf', // SDF
+      '.sdc', '.xdc', // constraint files
+      '.sp', '.spi', '.cir', '.net', // SPICE
+      '.scs', // Spectre
+      '.asc', // LTspice
+      '.tf', // Technology File
+      '.il', '.ils' // SKILL
+    ]
+  ],
+  [
+    'game',
+    [
+      '.mtl', // Material Template Library
+      '.x3d', // X3D files
+      '.gltf', // glTF JSON
+      '.prefab', // Unity prefabs (YAML)
+      '.meta' // Unity metadata files (YAML)
+    ]
+  ],
+  [
+    'other',
+    [
+      '.mcfunction', // Minecraft functions
+      '.jsp', // JSP
+      '.aspx', // ASP.NET
+      '.ipynb', // Jupyter Notebook
+      '.cake', '.ctp', // CakePHP
+      '.cfm', '.cfc' // ColdFusion
+    ]
+  ]
+])
+
+export const textExts = Array.from(textExtsByCategory.values()).flat()
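For orientation, downstream code can treat membership in the flattened textExts list as "safe to open as text". A minimal sketch, assuming textExts is exported from the shared constant module shown above; the helper name isTextFile is hypothetical and not part of this change:

import path from 'node:path'
import { textExts } from '@shared/config/constant'

// Hypothetical helper: treat a file as text if its extension is in the flattened list.
export function isTextFile(filePath: string): boolean {
  return textExts.includes(path.extname(filePath).toLowerCase())
}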
 export const ZOOM_LEVELS = [0.25, 0.33, 0.5, 0.67, 0.75, 0.8, 0.9, 1, 1.1, 1.25, 1.5, 1.75, 2, 2.5, 3, 4, 5]

@@ -170,3 +403,9 @@ export const KB = 1024
 export const MB = 1024 * KB
 export const GB = 1024 * MB
 export const defaultLanguage = 'en-US'
+
+export enum FeedUrl {
+  PRODUCTION = 'https://releases.cherry-ai.com',
+  EARLY_ACCESS = 'https://github.com/CherryHQ/cherry-studio/releases/latest/download'
+}
+
+export const defaultTimeout = 5 * 1000 * 60
playwright.config.ts (new file, 42 lines)
@@ -0,0 +1,42 @@
+import { defineConfig, devices } from '@playwright/test'
+
+/**
+ * See https://playwright.dev/docs/test-configuration.
+ */
+export default defineConfig({
+  // Look for test files, relative to this configuration file.
+  testDir: './tests/e2e',
+  /* Run tests in files in parallel */
+  fullyParallel: true,
+  /* Fail the build on CI if you accidentally left test.only in the source code. */
+  forbidOnly: !!process.env.CI,
+  /* Retry on CI only */
+  retries: process.env.CI ? 2 : 0,
+  /* Opt out of parallel tests on CI. */
+  workers: process.env.CI ? 1 : undefined,
+  /* Reporter to use. See https://playwright.dev/docs/test-reporters */
+  reporter: 'html',
+  /* Shared settings for all the projects below. See https://playwright.dev/docs/api/class-testoptions. */
+  use: {
+    /* Base URL to use in actions like `await page.goto('/')`. */
+    // baseURL: 'http://localhost:3000',
+
+    /* Collect trace when retrying the failed test. See https://playwright.dev/docs/trace-viewer */
+    trace: 'on-first-retry'
+  },
+
+  /* Configure projects for major browsers */
+  projects: [
+    {
+      name: 'chromium',
+      use: { ...devices['Desktop Chrome'] }
+    }
+  ]
+
+  /* Run your local dev server before starting the tests */
+  // webServer: {
+  //   command: 'npm run start',
+  //   url: 'http://localhost:3000',
+  //   reuseExistingServer: !process.env.CI,
+  // },
+})
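With testDir pointed at ./tests/e2e, any *.spec.ts file under that folder is picked up. A minimal placeholder spec, assuming a hypothetical file tests/e2e/smoke.spec.ts (not part of this change):

import { expect, test } from '@playwright/test'

// Trivial smoke test to verify the Playwright toolchain runs;
// real tests would launch and drive the app.
test('playwright is wired up', async () => {
  expect(1 + 1).toBe(2)
})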
@@ -36,6 +36,11 @@ exports.default = async function (context) {
       keepPackageNodeFiles(node_modules_path, '@libsql', ['win32-x64-msvc'])
     }
   }
+
+  if (platform === 'windows') {
+    fs.rmSync(path.join(context.appOutDir, 'LICENSE.electron.txt'), { force: true })
+    fs.rmSync(path.join(context.appOutDir, 'LICENSES.chromium.html'), { force: true })
+  }
 }

 /**
src/main/configs/SelectionConfig.ts (new file, 58 lines)
@@ -0,0 +1,58 @@
+interface IFilterList {
+  WINDOWS: string[]
+  MAC?: string[]
+}
+
+interface IFinetunedList {
+  EXCLUDE_CLIPBOARD_CURSOR_DETECT: IFilterList
+  INCLUDE_CLIPBOARD_DELAY_READ: IFilterList
+}
+
+/*************************************************************************
+ * Note: Do not modify this configuration unless you fully understand its meaning, implications, and intended behavior.
+ * -----------------------------------------------------------------------
+ * A predefined application filter list to include commonly used software
+ * that does not require text selection but may conflict with it, and disable them in advance.
+ * Only available in the selected mode.
+ *
+ * Specification: entries must be all lowercase and must match the actual running program name exactly
+ *************************************************************************/
+export const SELECTION_PREDEFINED_BLACKLIST: IFilterList = {
+  WINDOWS: [
+    'explorer.exe',
+    // Screenshot
+    'snipaste.exe',
+    'pixpin.exe',
+    'sharex.exe',
+    // Office
+    'excel.exe',
+    'powerpnt.exe',
+    // Image Editor
+    'photoshop.exe',
+    'illustrator.exe',
+    // Video Editor
+    'adobe premiere pro.exe',
+    'afterfx.exe',
+    // Audio Editor
+    'adobe audition.exe',
+    // 3D Editor
+    'blender.exe',
+    '3dsmax.exe',
+    'maya.exe',
+    // CAD
+    'acad.exe',
+    'sldworks.exe',
+    // Remote Desktop
+    'mstsc.exe'
+  ]
+}
+
+export const SELECTION_FINETUNED_LIST: IFinetunedList = {
+  EXCLUDE_CLIPBOARD_CURSOR_DETECT: {
+    WINDOWS: ['acrobat.exe', 'wps.exe', 'cajviewer.exe']
+  },
+  INCLUDE_CLIPBOARD_DELAY_READ: {
+    WINDOWS: ['acrobat.exe', 'wps.exe', 'cajviewer.exe', 'foxitphantom.exe']
+  }
+}
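A sketch of how a consumer might use these lists; the helpers and the '@main/configs/SelectionConfig' import path below are assumptions for illustration only, and merely follow the lowercase, exact-process-name convention the header comment describes:

import { SELECTION_FINETUNED_LIST, SELECTION_PREDEFINED_BLACKLIST } from '@main/configs/SelectionConfig'

// Hypothetical helpers, not part of this change.
export const isPredefinedBlacklisted = (processName: string): boolean =>
  SELECTION_PREDEFINED_BLACKLIST.WINDOWS.includes(processName.toLowerCase())

export const needsClipboardDelayRead = (processName: string): boolean =>
  SELECTION_FINETUNED_LIST.INCLUDE_CLIPBOARD_DELAY_READ.WINDOWS.includes(processName.toLowerCase())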
@@ -5,8 +5,15 @@ import EmbeddingsFactory from './EmbeddingsFactory'

 export default class Embeddings {
   private sdk: BaseEmbeddings
-  constructor({ model, apiKey, apiVersion, baseURL, dimensions }: KnowledgeBaseParams) {
-    this.sdk = EmbeddingsFactory.create({ model, apiKey, apiVersion, baseURL, dimensions } as KnowledgeBaseParams)
+  constructor({ model, provider, apiKey, apiVersion, baseURL, dimensions }: KnowledgeBaseParams) {
+    this.sdk = EmbeddingsFactory.create({
+      model,
+      provider,
+      apiKey,
+      apiVersion,
+      baseURL,
+      dimensions
+    } as KnowledgeBaseParams)
   }
   public async init(): Promise<void> {
     return this.sdk.init()
@@ -1,20 +1,49 @@
 import type { BaseEmbeddings } from '@cherrystudio/embedjs-interfaces'
+import { OllamaEmbeddings } from '@cherrystudio/embedjs-ollama'
 import { OpenAiEmbeddings } from '@cherrystudio/embedjs-openai'
 import { AzureOpenAiEmbeddings } from '@cherrystudio/embedjs-openai/src/azure-openai-embeddings'
 import { getInstanceName } from '@main/utils'
 import { KnowledgeBaseParams } from '@types'

-import VoyageEmbeddings from './VoyageEmbeddings'
+import { SUPPORTED_DIM_MODELS as VOYAGE_SUPPORTED_DIM_MODELS, VoyageEmbeddings } from './VoyageEmbeddings'

 export default class EmbeddingsFactory {
-  static create({ model, apiKey, apiVersion, baseURL, dimensions }: KnowledgeBaseParams): BaseEmbeddings {
+  static create({ model, provider, apiKey, apiVersion, baseURL, dimensions }: KnowledgeBaseParams): BaseEmbeddings {
     const batchSize = 10
-    if (model.includes('voyage')) {
-      return new VoyageEmbeddings({
-        modelName: model,
-        apiKey,
-        outputDimension: dimensions,
-        batchSize: 8
-      })
-    }
+    if (provider === 'voyageai') {
+      if (VOYAGE_SUPPORTED_DIM_MODELS.includes(model)) {
+        return new VoyageEmbeddings({
+          modelName: model,
+          apiKey,
+          outputDimension: dimensions,
+          batchSize: 8
+        })
+      } else {
+        return new VoyageEmbeddings({
+          modelName: model,
+          apiKey,
+          batchSize: 8
+        })
+      }
+    }
+    if (provider === 'ollama') {
+      if (baseURL.includes('v1/')) {
+        return new OllamaEmbeddings({
+          model: model,
+          baseUrl: baseURL.replace('v1/', ''),
+          requestOptions: {
+            // @ts-ignore expected
+            'encoding-format': 'float'
+          }
+        })
+      }
+      return new OllamaEmbeddings({
+        model: model,
+        baseUrl: baseURL,
+        requestOptions: {
+          // @ts-ignore expected
+          'encoding-format': 'float'
+        }
+      })
+    }
     if (apiVersion !== undefined) {

@@ -23,14 +52,14 @@ export default class EmbeddingsFactory {
         azureOpenAIApiVersion: apiVersion,
         azureOpenAIApiDeploymentName: model,
         azureOpenAIApiInstanceName: getInstanceName(baseURL),
-        // dimensions,
+        dimensions,
         batchSize
       })
     }
     return new OpenAiEmbeddings({
       model,
       apiKey,
-      // dimensions,
+      dimensions,
       batchSize,
       configuration: { baseURL }
     })
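A usage sketch of the new provider-driven branches; the parameter values here are illustrative assumptions, not taken from the repo:

// Voyage: 'voyage-3-large' is in SUPPORTED_DIM_MODELS, so outputDimension is honoured.
const voyage = EmbeddingsFactory.create({
  model: 'voyage-3-large',
  provider: 'voyageai',
  apiKey: 'pa-...',
  baseURL: 'https://api.voyageai.com/v1',
  dimensions: 1024
} as KnowledgeBaseParams)

// Ollama: a trailing 'v1/' in the base URL is stripped before the request is made.
const ollama = EmbeddingsFactory.create({
  model: 'nomic-embed-text',
  provider: 'ollama',
  apiKey: '',
  baseURL: 'http://localhost:11434/v1/',
  dimensions: 768
} as KnowledgeBaseParams)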
@@ -1,16 +1,20 @@
 import { BaseEmbeddings } from '@cherrystudio/embedjs-interfaces'
 import { VoyageEmbeddings as _VoyageEmbeddings } from '@langchain/community/embeddings/voyage'

-export default class VoyageEmbeddings extends BaseEmbeddings {
+/**
+ * Models that support configuring the embedding dimension
+ */
+export const SUPPORTED_DIM_MODELS = ['voyage-3-large', 'voyage-3.5', 'voyage-3.5-lite', 'voyage-code-3']
+export class VoyageEmbeddings extends BaseEmbeddings {
   private model: _VoyageEmbeddings
   constructor(private readonly configuration?: ConstructorParameters<typeof _VoyageEmbeddings>[0]) {
     super()
     if (!this.configuration) this.configuration = {}
     if (!this.configuration.modelName) this.configuration.modelName = 'voyage-3'
-    if (!this.configuration.outputDimension) {
-      throw new Error('You need to pass in the optional dimensions parameter for this model')
+    if (!SUPPORTED_DIM_MODELS.includes(this.configuration.modelName) && this.configuration.outputDimension) {
+      throw new Error(`VoyageEmbeddings only supports ${SUPPORTED_DIM_MODELS.join(', ')}`)
     }
+
     this.model = new _VoyageEmbeddings(this.configuration)
   }
   override async getDimensions(): Promise<number> {
@@ -6,7 +6,7 @@ import { app } from 'electron'
 import installExtension, { REACT_DEVELOPER_TOOLS, REDUX_DEVTOOLS } from 'electron-devtools-installer'
 import Logger from 'electron-log'

-import { isDev } from './constant'
+import { isDev, isWin } from './constant'
 import { registerIpc } from './ipc'
 import { configManager } from './services/ConfigManager'
 import mcpService from './services/MCPService'

@@ -16,6 +16,7 @@ import {
   registerProtocolClient,
   setupAppImageDeepLink
 } from './services/ProtocolClient'
+import selectionService, { initSelectionService } from './services/SelectionService'
 import { registerShortcuts } from './services/ShortcutService'
 import { TrayService } from './services/TrayService'
 import { windowService } from './services/WindowService'

@@ -23,6 +24,36 @@ import { setUserDataDir } from './utils/file'

 Logger.initialize()

+/**
+ * Disable chromium's window animations.
+ * The main purpose is to avoid the transparent window flashing when it is shown
+ * (especially on Windows for the SelectionAssistant Toolbar).
+ * Known issue: https://github.com/electron/electron/issues/12130#issuecomment-627198990
+ */
+if (isWin) {
+  app.commandLine.appendSwitch('wm-window-animations-disabled')
+}
+
+// Enable features for unresponsive renderer js call stacks
+app.commandLine.appendSwitch('enable-features', 'DocumentPolicyIncludeJSCallStacksInCrashReports')
+app.on('web-contents-created', (_, webContents) => {
+  webContents.session.webRequest.onHeadersReceived((details, callback) => {
+    callback({
+      responseHeaders: {
+        ...details.responseHeaders,
+        'Document-Policy': ['include-js-call-stacks-in-crash-reports']
+      }
+    })
+  })
+
+  webContents.on('unresponsive', async () => {
+    // Interrupt execution and collect call stack from unresponsive renderer
+    Logger.error('Renderer unresponsive start')
+    const callStack = await webContents.mainFrame.collectJavaScriptCallStack()
+    Logger.error('Renderer unresponsive js call stack\n', callStack)
+  })
+})
+
 // in production mode, handle uncaught exception and unhandled rejection globally
 if (!isDev) {
   // handle uncaught exception

@@ -84,6 +115,9 @@ if (!app.requestSingleInstanceLock()) {
       .then((name) => console.log(`Added Extension: ${name}`))
       .catch((err) => console.log('An error occurred: ', err))
   }
+
+  // start selection assistant service
+  initSelectionService()
 })

 registerProtocolClient(app)

@@ -110,6 +144,11 @@ if (!app.requestSingleInstanceLock()) {

 app.on('before-quit', () => {
   app.isQuitting = true
+
+  // quit selection service
+  if (selectionService) {
+    selectionService.quit()
+  }
 })

 app.on('will-quit', async () => {
@@ -4,13 +4,13 @@ import { arch } from 'node:os'
 import { isMac, isWin } from '@main/constant'
 import { getBinaryPath, isBinaryExists, runInstallScript } from '@main/utils/process'
 import { handleZoomFactor } from '@main/utils/zoom'
+import { FeedUrl } from '@shared/config/constant'
 import { IpcChannel } from '@shared/IpcChannel'
 import { Shortcut, ThemeMode } from '@types'
-import { BrowserWindow, ipcMain, nativeTheme, session, shell } from 'electron'
+import { BrowserWindow, ipcMain, session, shell } from 'electron'
 import log from 'electron-log'
 import { Notification } from 'src/renderer/src/types/notification'

-import { titleBarOverlayDark, titleBarOverlayLight } from './config'
 import AppUpdater from './services/AppUpdater'
 import BackupManager from './services/BackupManager'
 import { configManager } from './services/ConfigManager'

@@ -18,7 +18,6 @@ import CopilotService from './services/CopilotService'
 import { ExportService } from './services/ExportService'
 import FileService from './services/FileService'
 import FileStorage from './services/FileStorage'
-import { GeminiService } from './services/GeminiService'
 import KnowledgeService from './services/KnowledgeService'
 import mcpService from './services/MCPService'
 import NotificationService from './services/NotificationService'

@@ -26,9 +25,11 @@ import * as NutstoreService from './services/NutstoreService'
 import ObsidianVaultService from './services/ObsidianVaultService'
 import { ProxyConfig, proxyManager } from './services/ProxyManager'
 import { searchService } from './services/SearchService'
+import { SelectionService } from './services/SelectionService'
 import { registerShortcuts, unregisterAllShortcuts } from './services/ShortcutService'
 import storeSyncService from './services/StoreSyncService'
-import { TrayService } from './services/TrayService'
+import { themeService } from './services/ThemeService'
+import VertexAIService from './services/VertexAIService'
 import { setOpenLinkExternal } from './services/WebviewService'
 import { windowService } from './services/WindowService'
 import { calculateDirectorySize, getResourcePath } from './utils'

@@ -40,6 +41,7 @@ const fileManager = new FileStorage()
 const backupManager = new BackupManager()
 const exportService = new ExportService(fileManager)
 const obsidianVaultService = new ObsidianVaultService()
+const vertexAIService = VertexAIService.getInstance()

 export function registerIpc(mainWindow: BrowserWindow, app: Electron.App) {
   const appUpdater = new AppUpdater(mainWindow)

@@ -113,10 +115,12 @@ export function registerIpc(mainWindow: BrowserWindow, app: Electron.App) {
     configManager.setAutoUpdate(isActive)
   })

-  ipcMain.handle(IpcChannel.App_RestartTray, () => TrayService.getInstance().restartTray())
+  ipcMain.handle(IpcChannel.App_SetFeedUrl, (_, feedUrl: FeedUrl) => {
+    appUpdater.setFeedUrl(feedUrl)
+  })

-  ipcMain.handle(IpcChannel.Config_Set, (_, key: string, value: any) => {
-    configManager.set(key, value)
+  ipcMain.handle(IpcChannel.Config_Set, (_, key: string, value: any, isNotify: boolean = false) => {
+    configManager.set(key, value, isNotify)
   })

   ipcMain.handle(IpcChannel.Config_Get, (_, key: string) => {

@@ -125,34 +129,7 @@ export function registerIpc(mainWindow: BrowserWindow, app: Electron.App) {

   // theme
   ipcMain.handle(IpcChannel.App_SetTheme, (_, theme: ThemeMode) => {
-    const updateTitleBarOverlay = () => {
-      if (!mainWindow?.setTitleBarOverlay) return
-      const isDark = nativeTheme.shouldUseDarkColors
-      mainWindow.setTitleBarOverlay(isDark ? titleBarOverlayDark : titleBarOverlayLight)
-    }
-
-    const broadcastThemeChange = () => {
-      const isDark = nativeTheme.shouldUseDarkColors
-      const effectiveTheme = isDark ? ThemeMode.dark : ThemeMode.light
-      BrowserWindow.getAllWindows().forEach((win) => win.webContents.send(IpcChannel.ThemeChange, effectiveTheme))
-    }
-
-    const notifyThemeChange = () => {
-      updateTitleBarOverlay()
-      broadcastThemeChange()
-    }
-
-    if (theme === ThemeMode.auto) {
-      nativeTheme.themeSource = 'system'
-      nativeTheme.on('updated', notifyThemeChange)
-    } else {
-      nativeTheme.themeSource = theme
-      nativeTheme.off('updated', notifyThemeChange)
-    }
-
-    updateTitleBarOverlay()
-    configManager.setTheme(theme)
-    notifyThemeChange()
+    themeService.setTheme(theme)
   })

   ipcMain.handle(IpcChannel.App_HandleZoomFactor, (_, delta: number, reset: boolean = false) => {

@@ -200,7 +177,7 @@ export function registerIpc(mainWindow: BrowserWindow, app: Electron.App) {

   // check for update
   ipcMain.handle(IpcChannel.App_CheckForUpdate, async () => {
-    await appUpdater.checkForUpdates()
+    return await appUpdater.checkForUpdates()
   })

   // notification

@@ -249,7 +226,9 @@ export function registerIpc(mainWindow: BrowserWindow, app: Electron.App) {
   ipcMain.handle(IpcChannel.File_WriteWithId, fileManager.writeFileWithId)
   ipcMain.handle(IpcChannel.File_SaveImage, fileManager.saveImage)
   ipcMain.handle(IpcChannel.File_Base64Image, fileManager.base64Image)
+  ipcMain.handle(IpcChannel.File_SaveBase64Image, fileManager.saveBase64Image)
   ipcMain.handle(IpcChannel.File_Base64File, fileManager.base64File)
+  ipcMain.handle(IpcChannel.File_GetPdfInfo, fileManager.pdfPageCount)
   ipcMain.handle(IpcChannel.File_Download, fileManager.downloadFile)
   ipcMain.handle(IpcChannel.File_Copy, fileManager.copyFile)
   ipcMain.handle(IpcChannel.File_BinaryImage, fileManager.binaryImage)

@@ -297,12 +276,14 @@ export function registerIpc(mainWindow: BrowserWindow, app: Electron.App) {
     }
   })

-  // gemini
-  ipcMain.handle(IpcChannel.Gemini_UploadFile, GeminiService.uploadFile)
-  ipcMain.handle(IpcChannel.Gemini_Base64File, GeminiService.base64File)
-  ipcMain.handle(IpcChannel.Gemini_RetrieveFile, GeminiService.retrieveFile)
-  ipcMain.handle(IpcChannel.Gemini_ListFiles, GeminiService.listFiles)
-  ipcMain.handle(IpcChannel.Gemini_DeleteFile, GeminiService.deleteFile)
+  // VertexAI
+  ipcMain.handle(IpcChannel.VertexAI_GetAuthHeaders, async (_, params) => {
+    return vertexAIService.getAuthHeaders(params)
+  })
+
+  ipcMain.handle(IpcChannel.VertexAI_ClearAuthCache, async (_, projectId: string, clientEmail?: string) => {
+    vertexAIService.clearAuthCache(projectId, clientEmail)
+  })

   // mini window
   ipcMain.handle(IpcChannel.MiniWindow_Show, () => windowService.showMiniWindow())

@@ -379,4 +360,9 @@ export function registerIpc(mainWindow: BrowserWindow, app: Electron.App) {

   // store sync
   storeSyncService.registerIpcHandler()
+
+  // selection assistant
+  SelectionService.registerIpcHandler()
+
+  ipcMain.handle(IpcChannel.App_QuoteToMain, (_, text: string) => windowService.quoteToMainWindow(text))
 }
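For orientation, a renderer-side sketch of the new and changed channels. The real renderer goes through the preload bridge rather than calling ipcRenderer directly, so treat the call sites below as illustrative assumptions only:

import { ipcRenderer } from 'electron'
import { FeedUrl } from '@shared/config/constant'
import { IpcChannel } from '@shared/IpcChannel'

async function demo() {
  // Config_Set now takes an optional isNotify flag that makes the main
  // process notify its ConfigManager subscribers.
  await ipcRenderer.invoke(IpcChannel.Config_Set, 'language', 'en-US', true)

  // Switch the update feed, then check for updates; the result is now returned to the caller.
  await ipcRenderer.invoke(IpcChannel.App_SetFeedUrl, FeedUrl.EARLY_ACCESS)
  const update = await ipcRenderer.invoke(IpcChannel.App_CheckForUpdate)
  return update
}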
@@ -21,10 +21,13 @@ export default abstract class BaseReranker {
       return 'https://dashscope.aliyuncs.com/api/v1/services/rerank/text-rerank/text-rerank'
     }

-    let baseURL = this.base?.rerankBaseURL?.endsWith('/')
-      ? this.base.rerankBaseURL.slice(0, -1)
-      : this.base.rerankBaseURL
-    // must include /v1, otherwise the request 404s
+    let baseURL = this.base.rerankBaseURL
+
+    if (baseURL && baseURL.endsWith('/')) {
+      // a trailing `/` forces rerankBaseURL to be used as-is
+      return `${baseURL}rerank`
+    }
+
     if (baseURL && !baseURL.endsWith('/v1')) {
       baseURL = `${baseURL}/v1`
     }

@@ -58,6 +61,12 @@ export default abstract class BaseReranker {
           top_n: topN
         }
       }
+    } else if (provider?.includes('tei')) {
+      return {
+        query,
+        texts: documents,
+        return_text: true
+      }
     } else {
       return {
         model: this.base.rerankModel,

@@ -77,6 +86,13 @@ export default abstract class BaseReranker {
       return data.output.results
     } else if (provider === 'voyageai') {
       return data.data
+    } else if (provider === 'mis-tei') {
+      return data.map((item: any) => {
+        return {
+          index: item.index,
+          relevance_score: item.score
+        }
+      })
     } else {
       return data.results
     }
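To make the new branches concrete, a sketch of the request body built for a TEI-style provider and of the normalization applied to a 'mis-tei' response; the sample values are made up:

// Request body when provider?.includes('tei'):
const teiBody = {
  query: 'which document is about reranking?',
  texts: ['doc A text', 'doc B text'],
  return_text: true
}

// A 'mis-tei' response such as:
const raw = [
  { index: 1, score: 0.91 },
  { index: 0, score: 0.12 }
]
// ...is mapped to the shape the rest of the pipeline expects:
const normalized = raw.map((item) => ({ index: item.index, relevance_score: item.score }))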
@@ -1,9 +1,12 @@
 import { isWin } from '@main/constant'
+import { locales } from '@main/utils/locales'
+import { FeedUrl } from '@shared/config/constant'
 import { IpcChannel } from '@shared/IpcChannel'
 import { UpdateInfo } from 'builder-util-runtime'
 import { app, BrowserWindow, dialog } from 'electron'
 import logger from 'electron-log'
-import { AppUpdater as _AppUpdater, autoUpdater } from 'electron-updater'
+import { AppUpdater as _AppUpdater, autoUpdater, NsisUpdater } from 'electron-updater'
+import path from 'path'

 import icon from '../../../build/icon.png?asset'
 import { configManager } from './ConfigManager'

@@ -19,6 +22,7 @@ export default class AppUpdater {
     autoUpdater.forceDevUpdateConfig = !app.isPackaged
     autoUpdater.autoDownload = configManager.getAutoUpdate()
     autoUpdater.autoInstallOnAppQuit = configManager.getAutoUpdate()
+    autoUpdater.setFeedURL(configManager.getFeedUrl())

     // detect download errors
     autoUpdater.on('error', (error) => {

@@ -53,14 +57,47 @@ export default class AppUpdater {
       logger.info('下载完成', releaseInfo)
     })

+    if (isWin) {
+      ;(autoUpdater as NsisUpdater).installDirectory = path.dirname(app.getPath('exe'))
+    }
+
     this.autoUpdater = autoUpdater
   }

+  private async _getIpCountry() {
+    try {
+      // add timeout using AbortController
+      const controller = new AbortController()
+      const timeoutId = setTimeout(() => controller.abort(), 5000)
+
+      const ipinfo = await fetch('https://ipinfo.io/json', {
+        signal: controller.signal,
+        headers: {
+          'User-Agent':
+            'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/124.0.0.0 Safari/537.36',
+          'Accept-Language': 'en-US,en;q=0.9'
+        }
+      })
+
+      clearTimeout(timeoutId)
+      const data = await ipinfo.json()
+      return data.country || 'CN'
+    } catch (error) {
+      logger.error('Failed to get ipinfo:', error)
+      return 'CN'
+    }
+  }
+
   public setAutoUpdate(isActive: boolean) {
     autoUpdater.autoDownload = isActive
     autoUpdater.autoInstallOnAppQuit = isActive
   }

+  public setFeedUrl(feedUrl: FeedUrl) {
+    autoUpdater.setFeedURL(feedUrl)
+    configManager.setFeedUrl(feedUrl)
+  }
+
   public async checkForUpdates() {
     if (isWin && 'PORTABLE_EXECUTABLE_DIR' in process.env) {
       return {

@@ -69,6 +106,12 @@ export default class AppUpdater {
       }
     }

+    const ipCountry = await this._getIpCountry()
+    logger.info('ipCountry', ipCountry)
+    if (ipCountry !== 'CN') {
+      this.autoUpdater.setFeedURL(FeedUrl.EARLY_ACCESS)
+    }
+
     try {
       const update = await this.autoUpdater.checkForUpdates()
       if (update?.isUpdateAvailable && !this.autoUpdater.autoDownload) {

@@ -94,15 +137,22 @@ export default class AppUpdater {
     if (!this.releaseInfo) {
       return
     }
+    const locale = locales[configManager.getLanguage()]
+    const { update: updateLocale } = locale.translation
+
+    let detail = this.formatReleaseNotes(this.releaseInfo.releaseNotes)
+    if (detail === '') {
+      detail = updateLocale.noReleaseNotes
+    }
+
     dialog
       .showMessageBox({
         type: 'info',
-        title: '安装更新',
+        title: updateLocale.title,
         icon,
-        message: `新版本 ${this.releaseInfo.version} 已准备就绪`,
-        detail: this.formatReleaseNotes(this.releaseInfo.releaseNotes),
-        buttons: ['稍后安装', '立即安装'],
+        message: updateLocale.message.replace('{{version}}', this.releaseInfo.version),
+        detail,
+        buttons: [updateLocale.later, updateLocale.install],
         defaultId: 1,
         cancelId: 0
       })

@@ -118,7 +168,7 @@ export default class AppUpdater {

   private formatReleaseNotes(releaseNotes: string | ReleaseNoteInfo[] | null | undefined): string {
     if (!releaseNotes) {
-      return '暂无更新说明'
+      return ''
     }

     if (typeof releaseNotes === 'string') {
@@ -7,7 +7,7 @@ import Logger from 'electron-log'
 import * as fs from 'fs-extra'
 import StreamZip from 'node-stream-zip'
 import * as path from 'path'
-import { createClient, CreateDirectoryOptions, FileStat } from 'webdav'
+import { CreateDirectoryOptions, FileStat } from 'webdav'

 import WebDav from './WebDav'
 import { windowService } from './WindowService'

@@ -295,10 +295,12 @@ class BackupManager {
   async backupToWebdav(_: Electron.IpcMainInvokeEvent, data: string, webdavConfig: WebDavConfig) {
     const filename = webdavConfig.fileName || 'cherry-studio.backup.zip'
     const backupedFilePath = await this.backup(_, filename, data, undefined, webdavConfig.skipBackupFile)
+    const contentLength = (await fs.stat(backupedFilePath)).size
     const webdavClient = new WebDav(webdavConfig)
     try {
       const result = await webdavClient.putFileContents(filename, fs.createReadStream(backupedFilePath), {
-        overwrite: true
+        overwrite: true,
+        contentLength
       })
       // delete the local backup file after a successful upload
       await fs.remove(backupedFilePath)

@@ -340,12 +342,8 @@ class BackupManager {

   listWebdavFiles = async (_: Electron.IpcMainInvokeEvent, config: WebDavConfig) => {
     try {
-      const client = createClient(config.webdavHost, {
-        username: config.webdavUser,
-        password: config.webdavPass
-      })
-
-      const response = await client.getDirectoryContents(config.webdavPath)
+      const client = new WebDav(config)
+      const response = await client.getDirectoryContents()
       const files = Array.isArray(response) ? response : response.data

       return files
@@ -1,11 +1,11 @@
-import { defaultLanguage, ZOOM_SHORTCUTS } from '@shared/config/constant'
+import { defaultLanguage, FeedUrl, ZOOM_SHORTCUTS } from '@shared/config/constant'
 import { LanguageVarious, Shortcut, ThemeMode } from '@types'
 import { app } from 'electron'
 import Store from 'electron-store'

 import { locales } from '../utils/locales'

-enum ConfigKeys {
+export enum ConfigKeys {
   Language = 'language',
   Theme = 'theme',
   LaunchToTray = 'launchToTray',

@@ -16,7 +16,14 @@ enum ConfigKeys {
   ClickTrayToShowQuickAssistant = 'clickTrayToShowQuickAssistant',
   EnableQuickAssistant = 'enableQuickAssistant',
   AutoUpdate = 'autoUpdate',
-  EnableDataCollection = 'enableDataCollection'
+  FeedUrl = 'feedUrl',
+  EnableDataCollection = 'enableDataCollection',
+  SelectionAssistantEnabled = 'selectionAssistantEnabled',
+  SelectionAssistantTriggerMode = 'selectionAssistantTriggerMode',
+  SelectionAssistantFollowToolbar = 'selectionAssistantFollowToolbar',
+  SelectionAssistantRemeberWinSize = 'selectionAssistantRemeberWinSize',
+  SelectionAssistantFilterMode = 'selectionAssistantFilterMode',
+  SelectionAssistantFilterList = 'selectionAssistantFilterList'
 }

 export class ConfigManager {

@@ -32,12 +39,12 @@ export class ConfigManager {
     return this.get(ConfigKeys.Language, locale) as LanguageVarious
   }

-  setLanguage(theme: LanguageVarious) {
-    this.set(ConfigKeys.Language, theme)
+  setLanguage(lang: LanguageVarious) {
+    this.setAndNotify(ConfigKeys.Language, lang)
   }

   getTheme(): ThemeMode {
-    return this.get(ConfigKeys.Theme, ThemeMode.auto)
+    return this.get(ConfigKeys.Theme, ThemeMode.system)
   }

   setTheme(theme: ThemeMode) {

@@ -57,8 +64,7 @@ export class ConfigManager {
   }

   setTray(value: boolean) {
-    this.set(ConfigKeys.Tray, value)
-    this.notifySubscribers(ConfigKeys.Tray, value)
+    this.setAndNotify(ConfigKeys.Tray, value)
   }

   getTrayOnClose(): boolean {

@@ -74,8 +80,7 @@ export class ConfigManager {
   }

   setZoomFactor(factor: number) {
-    this.set(ConfigKeys.ZoomFactor, factor)
-    this.notifySubscribers(ConfigKeys.ZoomFactor, factor)
+    this.setAndNotify(ConfigKeys.ZoomFactor, factor)
   }

   subscribe<T>(key: string, callback: (newValue: T) => void) {

@@ -107,11 +112,10 @@ export class ConfigManager {
   }

   setShortcuts(shortcuts: Shortcut[]) {
-    this.set(
+    this.setAndNotify(
       ConfigKeys.Shortcuts,
       shortcuts.filter((shortcut) => shortcut.system)
     )
-    this.notifySubscribers(ConfigKeys.Shortcuts, shortcuts)
   }

   getClickTrayToShowQuickAssistant(): boolean {

@@ -127,7 +131,7 @@ export class ConfigManager {
   }

   setEnableQuickAssistant(value: boolean) {
-    this.set(ConfigKeys.EnableQuickAssistant, value)
+    this.setAndNotify(ConfigKeys.EnableQuickAssistant, value)
   }

   getAutoUpdate(): boolean {

@@ -138,6 +142,14 @@ export class ConfigManager {
     this.set(ConfigKeys.AutoUpdate, value)
   }

+  getFeedUrl(): string {
+    return this.get<string>(ConfigKeys.FeedUrl, FeedUrl.PRODUCTION)
+  }
+
+  setFeedUrl(value: FeedUrl) {
+    this.set(ConfigKeys.FeedUrl, value)
+  }
+
   getEnableDataCollection(): boolean {
     return this.get<boolean>(ConfigKeys.EnableDataCollection, true)
   }

@@ -146,8 +158,64 @@ export class ConfigManager {
     this.set(ConfigKeys.EnableDataCollection, value)
   }

-  set(key: string, value: unknown) {
+  // Selection Assistant: whether the selection assistant is enabled
+  getSelectionAssistantEnabled(): boolean {
+    return this.get<boolean>(ConfigKeys.SelectionAssistantEnabled, false)
+  }
+
+  setSelectionAssistantEnabled(value: boolean) {
+    this.setAndNotify(ConfigKeys.SelectionAssistantEnabled, value)
+  }
+
+  // Selection Assistant: trigger mode (selected, ctrlkey)
+  getSelectionAssistantTriggerMode(): string {
+    return this.get<string>(ConfigKeys.SelectionAssistantTriggerMode, 'selected')
+  }
+
+  setSelectionAssistantTriggerMode(value: string) {
+    this.setAndNotify(ConfigKeys.SelectionAssistantTriggerMode, value)
+  }
+
+  // Selection Assistant: whether the action window position follows the toolbar
+  getSelectionAssistantFollowToolbar(): boolean {
+    return this.get<boolean>(ConfigKeys.SelectionAssistantFollowToolbar, true)
+  }
+
+  setSelectionAssistantFollowToolbar(value: boolean) {
+    this.setAndNotify(ConfigKeys.SelectionAssistantFollowToolbar, value)
+  }
+
+  getSelectionAssistantRemeberWinSize(): boolean {
+    return this.get<boolean>(ConfigKeys.SelectionAssistantRemeberWinSize, false)
+  }
+
+  setSelectionAssistantRemeberWinSize(value: boolean) {
+    this.setAndNotify(ConfigKeys.SelectionAssistantRemeberWinSize, value)
+  }
+
+  getSelectionAssistantFilterMode(): string {
+    return this.get<string>(ConfigKeys.SelectionAssistantFilterMode, 'default')
+  }
+
+  setSelectionAssistantFilterMode(value: string) {
+    this.setAndNotify(ConfigKeys.SelectionAssistantFilterMode, value)
+  }
+
+  getSelectionAssistantFilterList(): string[] {
+    return this.get<string[]>(ConfigKeys.SelectionAssistantFilterList, [])
+  }
+
+  setSelectionAssistantFilterList(value: string[]) {
+    this.setAndNotify(ConfigKeys.SelectionAssistantFilterList, value)
+  }
+
+  setAndNotify(key: string, value: unknown) {
+    this.set(key, value, true)
+  }
+
+  set(key: string, value: unknown, isNotify: boolean = false) {
     this.store.set(key, value)
+    isNotify && this.notifySubscribers(key, value)
   }

   get<T>(key: string, defaultValue?: T) {
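A small main-process sketch of how setAndNotify and subscribe fit together; the subscriber wiring below is illustrative and not code from this change:

import { configManager } from './services/ConfigManager'

// React whenever the selection assistant toggle changes; setAndNotify-based
// setters (and Config_Set invoked with isNotify = true) fire this callback.
configManager.subscribe<boolean>('selectionAssistantEnabled', (enabled) => {
  console.log('selection assistant enabled ->', enabled)
})

configManager.setSelectionAssistantEnabled(true) // persists and notifies subscribers
configManager.set('selectionAssistantEnabled', false) // persists silently (isNotify defaults to false)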
@@ -47,6 +47,8 @@ export class ExportService {
     let linkText = ''
     let linkUrl = ''
     let insideLink = false
+    let boldStack = 0 // track nested bold markers
+    let italicStack = 0 // track nested italic markers

     for (let i = 0; i < tokens.length; i++) {
       const token = tokens[i]

@@ -82,17 +84,37 @@ export class ExportService {
             insideLink = false
           }
           break
+        case 'strong_open':
+          boldStack++
+          break
+        case 'strong_close':
+          boldStack--
+          break
+        case 'em_open':
+          italicStack++
+          break
+        case 'em_close':
+          italicStack--
+          break
         case 'text':
-          runs.push(new TextRun({ text: token.content, bold: isHeaderRow }))
-          break
-        case 'strong':
-          runs.push(new TextRun({ text: token.content, bold: true }))
-          break
-        case 'em':
-          runs.push(new TextRun({ text: token.content, italics: true }))
+          runs.push(
+            new TextRun({
+              text: token.content,
+              bold: isHeaderRow || boldStack > 0,
+              italics: italicStack > 0
+            })
+          )
           break
         case 'code_inline':
-          runs.push(new TextRun({ text: token.content, font: 'Consolas', size: 20 }))
+          runs.push(
+            new TextRun({
+              text: token.content,
+              font: 'Consolas',
+              size: 20,
+              bold: isHeaderRow || boldStack > 0,
+              italics: italicStack > 0
+            })
+          )
           break
       }
     }
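Why the stacks matter, on a made-up input: nested emphasis arrives as open/close tokens wrapped around plain text tokens, so a per-token flag cannot express combined styling.

// For the cell content "**bold with *italic inside***", the inline tokens are roughly:
//   strong_open, text("bold with "), em_open, text("italic inside"), em_close, strong_close
// With the stacks, the first text run gets { bold: true } and the second
// { bold: true, italics: true }, which the old 'strong'/'em' cases could not express.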
@@ -1,7 +1,9 @@
-import fs from 'node:fs'
+import fs from 'fs/promises'

 export default class FileService {
-  public static async readFile(_: Electron.IpcMainInvokeEvent, path: string) {
-    return fs.readFileSync(path, 'utf8')
+  public static async readFile(_: Electron.IpcMainInvokeEvent, pathOrUrl: string, encoding?: BufferEncoding) {
+    const path = pathOrUrl.startsWith('file://') ? new URL(pathOrUrl) : pathOrUrl
+    if (encoding) return fs.readFile(path, { encoding })
+    return fs.readFile(path)
   }
 }
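A quick sketch of the two call shapes the new signature allows; the first argument is the IPC event and is unused by the method, so it is stubbed here, and the file paths are placeholders:

async function demo() {
  // Returns a string because an encoding is given; accepts a file:// URL.
  const text = await FileService.readFile(undefined as any, 'file:///tmp/notes.txt', 'utf8')

  // Returns a Buffer when no encoding is passed.
  const bytes = await FileService.readFile(undefined as any, '/tmp/image.png')
  return { text, bytes }
}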
@@ -15,6 +15,7 @@ import * as fs from 'fs'
 import { writeFileSync } from 'fs'
 import { readFile } from 'fs/promises'
 import officeParser from 'officeparser'
+import { getDocument } from 'officeparser/pdfjs-dist-build/pdf.js'
 import * as path from 'path'
 import { chdir } from 'process'
 import { v4 as uuidv4 } from 'uuid'

@@ -268,6 +269,51 @@ class FileStorage {
     }
   }

+  public saveBase64Image = async (_: Electron.IpcMainInvokeEvent, base64Data: string): Promise<FileType> => {
+    try {
+      if (!base64Data) {
+        throw new Error('Base64 data is required')
+      }
+
+      // strip the base64 data-URL header if present
+      const base64String = base64Data.replace(/^data:.*;base64,/, '')
+      const buffer = Buffer.from(base64String, 'base64')
+      const uuid = uuidv4()
+      const ext = '.png'
+      const destPath = path.join(this.storageDir, uuid + ext)
+
+      logger.info('[FileStorage] Saving base64 image:', {
+        storageDir: this.storageDir,
+        destPath,
+        bufferSize: buffer.length
+      })
+
+      // make sure the storage directory exists
+      if (!fs.existsSync(this.storageDir)) {
+        fs.mkdirSync(this.storageDir, { recursive: true })
+      }
+
+      await fs.promises.writeFile(destPath, buffer)
+
+      const fileMetadata: FileType = {
+        id: uuid,
+        origin_name: uuid + ext,
+        name: uuid + ext,
+        path: destPath,
+        created_at: new Date().toISOString(),
+        size: buffer.length,
+        ext: ext.slice(1),
+        type: getFileType(ext),
+        count: 1
+      }
+
+      return fileMetadata
+    } catch (error) {
+      logger.error('[FileStorage] Failed to save base64 image:', error)
+      throw error
+    }
+  }
+
   public base64File = async (_: Electron.IpcMainInvokeEvent, id: string): Promise<{ data: string; mime: string }> => {
     const filePath = path.join(this.storageDir, id)
     const buffer = await fs.promises.readFile(filePath)

@@ -276,6 +322,16 @@ class FileStorage {
     return { data: base64, mime }
   }

+  public pdfPageCount = async (_: Electron.IpcMainInvokeEvent, id: string): Promise<number> => {
+    const filePath = path.join(this.storageDir, id)
+    const buffer = await fs.promises.readFile(filePath)
+
+    const doc = await getDocument({ data: buffer }).promise
+    const pages = doc.numPages
+    await doc.destroy()
+    return pages
+  }
+
   public binaryImage = async (_: Electron.IpcMainInvokeEvent, id: string): Promise<{ data: Buffer; mime: string }> => {
     const filePath = path.join(this.storageDir, id)
     const data = await fs.promises.readFile(filePath)
src/main/services/GeminiService.ts (deleted)
@@ -1,79 +0,0 @@
-import { File, FileState, GoogleGenAI, Pager } from '@google/genai'
-import { FileType } from '@types'
-import fs from 'fs'
-
-import { CacheService } from './CacheService'
-
-export class GeminiService {
-  private static readonly FILE_LIST_CACHE_KEY = 'gemini_file_list'
-  private static readonly CACHE_DURATION = 3000
-
-  static async uploadFile(
-    _: Electron.IpcMainInvokeEvent,
-    file: FileType,
-    { apiKey, baseURL }: { apiKey: string; baseURL: string }
-  ): Promise<File> {
-    const sdk = new GoogleGenAI({
-      vertexai: false,
-      apiKey,
-      httpOptions: {
-        baseUrl: baseURL
-      }
-    })
-
-    return await sdk.files.upload({
-      file: file.path,
-      config: {
-        mimeType: 'application/pdf',
-        name: file.id,
-        displayName: file.origin_name
-      }
-    })
-  }
-
-  static async base64File(_: Electron.IpcMainInvokeEvent, file: FileType) {
-    return {
-      data: Buffer.from(fs.readFileSync(file.path)).toString('base64'),
-      mimeType: 'application/pdf'
-    }
-  }
-
-  static async retrieveFile(_: Electron.IpcMainInvokeEvent, file: FileType, apiKey: string): Promise<File | undefined> {
-    const sdk = new GoogleGenAI({ vertexai: false, apiKey })
-    const cachedResponse = CacheService.get<any>(GeminiService.FILE_LIST_CACHE_KEY)
-    if (cachedResponse) {
-      return GeminiService.processResponse(cachedResponse, file)
-    }
-
-    const response = await sdk.files.list()
-    CacheService.set(GeminiService.FILE_LIST_CACHE_KEY, response, GeminiService.CACHE_DURATION)
-
-    return GeminiService.processResponse(response, file)
-  }
-
-  private static async processResponse(response: Pager<File>, file: FileType) {
-    for await (const f of response) {
-      if (f.state === FileState.ACTIVE) {
-        if (f.displayName === file.origin_name && Number(f.sizeBytes) === file.size) {
-          return f
-        }
-      }
-    }
-
-    return undefined
-  }
-
-  static async listFiles(_: Electron.IpcMainInvokeEvent, apiKey: string): Promise<File[]> {
-    const sdk = new GoogleGenAI({ vertexai: false, apiKey })
-    const files: File[] = []
-    for await (const f of await sdk.files.list()) {
-      files.push(f)
-    }
-    return files
-  }
-
-  static async deleteFile(_: Electron.IpcMainInvokeEvent, fileId: string, apiKey: string) {
-    const sdk = new GoogleGenAI({ vertexai: false, apiKey })
-    await sdk.files.delete({ name: fileId })
-  }
-}
@@ -110,13 +110,21 @@ class KnowledgeService {
   private getRagApplication = async ({
     id,
     model,
+    provider,
     apiKey,
     apiVersion,
     baseURL,
     dimensions
   }: KnowledgeBaseParams): Promise<RAGApplication> => {
     let ragApplication: RAGApplication
-    const embeddings = new Embeddings({ model, apiKey, apiVersion, baseURL, dimensions } as KnowledgeBaseParams)
+    const embeddings = new Embeddings({
+      model,
+      provider,
+      apiKey,
+      apiVersion,
+      baseURL,
+      dimensions
+    } as KnowledgeBaseParams)
     try {
       ragApplication = await new RAGApplicationBuilder()
         .setModel('NO_MODEL')
@@ -91,7 +91,7 @@ class McpService {
     return JSON.stringify({
       baseUrl: server.baseUrl,
       command: server.command,
-      args: server.args,
+      args: Array.isArray(server.args) ? server.args : [],
       registryUrl: server.registryUrl,
      env: server.env,
       id: server.id
@@ -245,7 +245,7 @@ class McpService {
     const loginShellEnv = await this.getLoginShellEnv()

     // Bun does not support proxy https://github.com/oven-sh/bun/issues/16812
-    if (cmd.endsWith('bun')) {
+    if (cmd.includes('bun')) {
       this.removeProxyEnv(loginShellEnv)
     }

@@ -567,12 +567,11 @@ class McpService {
     try {
       const result = await client.listResources()
       const resources = result.resources || []
-      const serverResources = (Array.isArray(resources) ? resources : []).map((resource: any) => ({
+      return (Array.isArray(resources) ? resources : []).map((resource: any) => ({
         ...resource,
         serverId: server.id,
         serverName: server.name
       }))
-      return serverResources
     } catch (error: any) {
       // -32601 is the code for the method not found
       if (error?.code !== -32601) {
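The Array.isArray guard above normalizes a possibly undefined server.args before the config is serialized; a small sketch of the difference it makes, using a hypothetical McpServerLike shape in place of the real MCPServer type:

// Hypothetical minimal shape for illustration; the real MCPServer type lives in @types.
interface McpServerLike {
  id: string
  command?: string
  args?: string[]
}

function serializeServer(server: McpServerLike): string {
  // JSON.stringify drops keys whose value is undefined, so an absent args field would
  // silently disappear from the serialized config; normalizing to [] keeps the shape stable.
  return JSON.stringify({
    command: server.command,
    args: Array.isArray(server.args) ? server.args : [],
    id: server.id
  })
}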
1324  src/main/services/SelectionService.ts  Normal file
@@ -4,10 +4,16 @@ import { BrowserWindow, globalShortcut } from 'electron'
 import Logger from 'electron-log'

 import { configManager } from './ConfigManager'
+import selectionService from './SelectionService'
 import { windowService } from './WindowService'

 let showAppAccelerator: string | null = null
 let showMiniWindowAccelerator: string | null = null
+let selectionAssistantToggleAccelerator: string | null = null
+let selectionAssistantSelectTextAccelerator: string | null = null
+
+// indicate if the shortcuts are registered on app boot time
+let isRegisterOnBoot = true
+
 // store the focus and blur handlers for each window to unregister them later
 const windowOnHandlers = new Map<BrowserWindow, { onFocusHandler: () => void; onBlurHandler: () => void }>()
@@ -28,6 +34,18 @@ function getShortcutHandler(shortcut: Shortcut) {
       return () => {
         windowService.toggleMiniWindow()
       }
+    case 'selection_assistant_toggle':
+      return () => {
+        if (selectionService) {
+          selectionService.toggleEnabled()
+        }
+      }
+    case 'selection_assistant_select_text':
+      return () => {
+        if (selectionService) {
+          selectionService.processSelectTextByShortcut()
+        }
+      }
     default:
       return null
   }
@@ -37,9 +55,8 @@ function formatShortcutKey(shortcut: string[]): string {
   return shortcut.join('+')
 }

-const convertShortcutRecordedByKeyboardEventKeyValueToElectronGlobalShortcutFormat = (
-  shortcut: string | string[]
-): string => {
+// convert the shortcut recorded by keyboard event key value to electron global shortcut format
+const convertShortcutFormat = (shortcut: string | string[]): string => {
   const accelerator = (() => {
     if (Array.isArray(shortcut)) {
       return shortcut
@@ -93,11 +110,14 @@ const convertShortcutRecordedByKeyboardEventKeyValueToElectronGlobalShortcutForm
 }

 export function registerShortcuts(window: BrowserWindow) {
-  window.once('ready-to-show', () => {
-    if (configManager.getLaunchToTray()) {
-      registerOnlyUniversalShortcuts()
-    }
-  })
+  if (isRegisterOnBoot) {
+    window.once('ready-to-show', () => {
+      if (configManager.getLaunchToTray()) {
+        registerOnlyUniversalShortcuts()
+      }
+    })
+    isRegisterOnBoot = false
+  }

   // only for clearer code
   const registerOnlyUniversalShortcuts = () => {
@@ -124,7 +144,12 @@ export function registerShortcuts(window: BrowserWindow) {
       }

       // only register universal shortcuts when needed
-      if (onlyUniversalShortcuts && !['show_app', 'mini_window'].includes(shortcut.key)) {
+      if (
+        onlyUniversalShortcuts &&
+        !['show_app', 'mini_window', 'selection_assistant_toggle', 'selection_assistant_select_text'].includes(
+          shortcut.key
+        )
+      ) {
         return
       }
@@ -146,6 +171,14 @@ export function registerShortcuts(window: BrowserWindow) {
           showMiniWindowAccelerator = formatShortcutKey(shortcut.shortcut)
           break

+        case 'selection_assistant_toggle':
+          selectionAssistantToggleAccelerator = formatShortcutKey(shortcut.shortcut)
+          break
+
+        case 'selection_assistant_select_text':
+          selectionAssistantSelectTextAccelerator = formatShortcutKey(shortcut.shortcut)
+          break
+
         // the following ZOOMs will register shortcuts separately, so will return
         case 'zoom_in':
           globalShortcut.register('CommandOrControl+=', () => handler(window))
@@ -162,9 +195,7 @@ export function registerShortcuts(window: BrowserWindow) {
         return
       }

-      const accelerator = convertShortcutRecordedByKeyboardEventKeyValueToElectronGlobalShortcutFormat(
-        shortcut.shortcut
-      )
+      const accelerator = convertShortcutFormat(shortcut.shortcut)

       globalShortcut.register(accelerator, () => handler(window))
     } catch (error) {
@@ -181,15 +212,25 @@ export function registerShortcuts(window: BrowserWindow) {

     if (showAppAccelerator) {
       const handler = getShortcutHandler({ key: 'show_app' } as Shortcut)
-      const accelerator =
-        convertShortcutRecordedByKeyboardEventKeyValueToElectronGlobalShortcutFormat(showAppAccelerator)
+      const accelerator = convertShortcutFormat(showAppAccelerator)
       handler && globalShortcut.register(accelerator, () => handler(window))
     }

     if (showMiniWindowAccelerator) {
       const handler = getShortcutHandler({ key: 'mini_window' } as Shortcut)
-      const accelerator =
-        convertShortcutRecordedByKeyboardEventKeyValueToElectronGlobalShortcutFormat(showMiniWindowAccelerator)
+      const accelerator = convertShortcutFormat(showMiniWindowAccelerator)
       handler && globalShortcut.register(accelerator, () => handler(window))
     }
+
+    if (selectionAssistantToggleAccelerator) {
+      const handler = getShortcutHandler({ key: 'selection_assistant_toggle' } as Shortcut)
+      const accelerator = convertShortcutFormat(selectionAssistantToggleAccelerator)
+      handler && globalShortcut.register(accelerator, () => handler(window))
+    }
+
+    if (selectionAssistantSelectTextAccelerator) {
+      const handler = getShortcutHandler({ key: 'selection_assistant_select_text' } as Shortcut)
+      const accelerator = convertShortcutFormat(selectionAssistantSelectTextAccelerator)
+      handler && globalShortcut.register(accelerator, () => handler(window))
+    }
   } catch (error) {
@@ -217,6 +258,8 @@ export function unregisterAllShortcuts() {
   try {
     showAppAccelerator = null
     showMiniWindowAccelerator = null
+    selectionAssistantToggleAccelerator = null
+    selectionAssistantSelectTextAccelerator = null
     windowOnHandlers.forEach((handlers, window) => {
       window.off('focus', handlers.onFocusHandler)
       window.off('blur', handlers.onBlurHandler)
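A sketch of how an accelerator produced by convertShortcutFormat is consumed, using only Electron's documented globalShortcut API; the helper below is illustrative and not part of this diff:

import { globalShortcut } from 'electron'

// Illustrative helper: register a handler for an accelerator such as 'CommandOrControl+Shift+A'.
function registerAccelerator(accelerator: string, onTrigger: () => void): boolean {
  if (globalShortcut.isRegistered(accelerator)) {
    return false // already registered by this app
  }
  return globalShortcut.register(accelerator, onTrigger)
}

// Teardown mirrors unregisterAllShortcuts() above: globalShortcut.unregisterAll()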
@@ -49,6 +49,23 @@ export class StoreSyncService {
     this.windowIds = this.windowIds.filter((id) => id !== windowId)
   }

+  /**
+   * Sync an action to all renderer windows
+   * @param type Action type, like 'settings/setTray'
+   * @param payload Action payload
+   *
+   * NOTICE: DO NOT use directly in ConfigManager, may cause infinite sync loop
+   */
+  public syncToRenderer(type: string, payload: any): void {
+    const action: StoreSyncAction = {
+      type,
+      payload
+    }
+
+    // -1 means the action is from the main process, will be broadcast to all windows
+    this.broadcastToOtherWindows(-1, action)
+  }
+
   /**
    * Register IPC handlers for store sync communication
    * Handles window subscription, unsubscription and action broadcasting
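A hedged usage sketch of the new syncToRenderer helper from elsewhere in the main process; the exported instance name and import path are assumptions, and the action type follows the 'settings/setTray' example in the doc comment:

// Sketch, assuming the service exports a singleton named storeSyncService (not shown in this hunk).
import { storeSyncService } from './StoreSyncService'

function broadcastTraySetting(enabled: boolean) {
  // Broadcasts a Redux-style action to every renderer window; -1 marks it as main-process originated.
  storeSyncService.syncToRenderer('settings/setTray', enabled)
}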
48  src/main/services/ThemeService.ts  Normal file
@@ -0,0 +1,48 @@
+import { IpcChannel } from '@shared/IpcChannel'
+import { ThemeMode } from '@types'
+import { BrowserWindow, nativeTheme } from 'electron'
+
+import { titleBarOverlayDark, titleBarOverlayLight } from '../config'
+import { configManager } from './ConfigManager'
+
+class ThemeService {
+  private theme: ThemeMode = ThemeMode.system
+  constructor() {
+    this.theme = configManager.getTheme()
+
+    if (this.theme === ThemeMode.dark || this.theme === ThemeMode.light || this.theme === ThemeMode.system) {
+      nativeTheme.themeSource = this.theme
+    } else {
+      // backward compatibility with older versions
+      configManager.setTheme(ThemeMode.system)
+      nativeTheme.themeSource = ThemeMode.system
+    }
+    nativeTheme.on('updated', this.themeUpdatadHandler.bind(this))
+  }
+
+  themeUpdatadHandler() {
+    BrowserWindow.getAllWindows().forEach((win) => {
+      if (win && !win.isDestroyed() && win.setTitleBarOverlay) {
+        try {
+          win.setTitleBarOverlay(nativeTheme.shouldUseDarkColors ? titleBarOverlayDark : titleBarOverlayLight)
+        } catch (error) {
+          // don't throw if setTitleBarOverlay fails,
+          // because it may be called on windows that have their own title bar
+        }
+      }
+      win.webContents.send(IpcChannel.ThemeUpdated, nativeTheme.shouldUseDarkColors ? ThemeMode.dark : ThemeMode.light)
+    })
+  }
+
+  setTheme(theme: ThemeMode) {
+    if (theme === this.theme) {
+      return
+    }
+
+    this.theme = theme
+    nativeTheme.themeSource = theme
+    configManager.setTheme(theme)
+  }
+}
+
+export const themeService = new ThemeService()
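A sketch of driving the new singleton from an IPC handler. themeService.setTheme and IpcChannel.App_SetTheme both appear in this diff, but the exact place this handler is registered is an assumption:

import { IpcChannel } from '@shared/IpcChannel'
import { ThemeMode } from '@types'
import { ipcMain } from 'electron'

import { themeService } from './ThemeService'

// Assumed wiring: the renderer calls window.api.setTheme(theme), which invokes this channel.
ipcMain.handle(IpcChannel.App_SetTheme, (_event, theme: ThemeMode) => {
  // Updates nativeTheme.themeSource, persists the choice via configManager, and the
  // nativeTheme 'updated' listener then pushes ThemeUpdated to every window.
  themeService.setTheme(theme)
})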
@@ -5,16 +5,17 @@ import { app, Menu, MenuItemConstructorOptions, nativeImage, nativeTheme, Tray }
 import icon from '../../../build/tray_icon.png?asset'
 import iconDark from '../../../build/tray_icon_dark.png?asset'
 import iconLight from '../../../build/tray_icon_light.png?asset'
-import { configManager } from './ConfigManager'
+import { ConfigKeys, configManager } from './ConfigManager'
 import { windowService } from './WindowService'

 export class TrayService {
   private static instance: TrayService
   private tray: Tray | null = null
+  private contextMenu: Menu | null = null

   constructor() {
+    this.watchConfigChanges()
     this.updateTray()
-    this.watchTrayChanges()
     TrayService.instance = this
   }

@@ -43,6 +44,30 @@ export class TrayService {

     this.tray = tray

+    this.updateContextMenu()
+
+    if (process.platform === 'linux') {
+      this.tray.setContextMenu(this.contextMenu)
+    }
+
+    this.tray.setToolTip('Cherry Studio')
+
+    this.tray.on('right-click', () => {
+      if (this.contextMenu) {
+        this.tray?.popUpContextMenu(this.contextMenu)
+      }
+    })
+
+    this.tray.on('click', () => {
+      if (configManager.getEnableQuickAssistant() && configManager.getClickTrayToShowQuickAssistant()) {
+        windowService.showMiniWindow()
+      } else {
+        windowService.showMainWindow()
+      }
+    })
+  }
+
+  private updateContextMenu() {
     const locale = locales[configManager.getLanguage()]
     const { tray: trayLocale } = locale.translation

@@ -64,25 +89,7 @@ export class TrayService {
       }
     ].filter(Boolean) as MenuItemConstructorOptions[]

-    const contextMenu = Menu.buildFromTemplate(template)
-
-    if (process.platform === 'linux') {
-      this.tray.setContextMenu(contextMenu)
-    }
-
-    this.tray.setToolTip('Cherry Studio')
-
-    this.tray.on('right-click', () => {
-      this.tray?.popUpContextMenu(contextMenu)
-    })
-
-    this.tray.on('click', () => {
-      if (enableQuickAssistant && configManager.getClickTrayToShowQuickAssistant()) {
-        windowService.showMiniWindow()
-      } else {
-        windowService.showMainWindow()
-      }
-    })
+    this.contextMenu = Menu.buildFromTemplate(template)
   }

   private updateTray() {
@@ -94,13 +101,6 @@ export class TrayService {
     }
   }

-  public restartTray() {
-    if (configManager.getTray()) {
-      this.destroyTray()
-      this.createTray()
-    }
-  }
-
   private destroyTray() {
     if (this.tray) {
       this.tray.destroy()
@@ -108,8 +108,16 @@ export class TrayService {
     }
   }

-  private watchTrayChanges() {
-    configManager.subscribe<boolean>('tray', () => this.updateTray())
+  private watchConfigChanges() {
+    configManager.subscribe(ConfigKeys.Tray, () => this.updateTray())
+
+    configManager.subscribe(ConfigKeys.Language, () => {
+      this.updateContextMenu()
+    })
+
+    configManager.subscribe(ConfigKeys.EnableQuickAssistant, () => {
+      this.updateContextMenu()
+    })
   }

   private quit() {
142  src/main/services/VertexAIService.ts  Normal file
@@ -0,0 +1,142 @@
+import { GoogleAuth } from 'google-auth-library'
+
+interface ServiceAccountCredentials {
+  privateKey: string
+  clientEmail: string
+}
+
+interface VertexAIAuthParams {
+  projectId: string
+  serviceAccount?: ServiceAccountCredentials
+}
+
+const REQUIRED_VERTEX_AI_SCOPE = 'https://www.googleapis.com/auth/cloud-platform'
+
+class VertexAIService {
+  private static instance: VertexAIService
+  private authClients: Map<string, GoogleAuth> = new Map()
+
+  static getInstance(): VertexAIService {
+    if (!VertexAIService.instance) {
+      VertexAIService.instance = new VertexAIService()
+    }
+    return VertexAIService.instance
+  }
+
+  /**
+   * Format the private key, making sure it has the correct PEM header and footer
+   */
+  private formatPrivateKey(privateKey: string): string {
+    if (!privateKey || typeof privateKey !== 'string') {
+      throw new Error('Private key must be a non-empty string')
+    }
+
+    // handle escaped newlines coming from JSON strings
+    let key = privateKey.replace(/\\n/g, '\n')
+
+    // if the key is already in proper PEM format, return it as is
+    if (key.includes('-----BEGIN PRIVATE KEY-----') && key.includes('-----END PRIVATE KEY-----')) {
+      return key
+    }
+
+    // strip all newlines and whitespace before re-formatting
+    key = key.replace(/\s+/g, '')
+
+    // remove any existing header and footer
+    key = key.replace(/-----BEGIN[^-]*-----/g, '')
+    key = key.replace(/-----END[^-]*-----/g, '')
+
+    // make sure the key is not empty
+    if (!key) {
+      throw new Error('Private key is empty after formatting')
+    }
+
+    // add the PEM header and footer, wrapping the body at 64 characters per line
+    const formattedKey = key.match(/.{1,64}/g)?.join('\n') || key
+
+    return `-----BEGIN PRIVATE KEY-----\n${formattedKey}\n-----END PRIVATE KEY-----`
+  }
+
+  /**
+   * Get auth headers for Vertex AI requests
+   */
+  async getAuthHeaders(params: VertexAIAuthParams): Promise<Record<string, string>> {
+    const { projectId, serviceAccount } = params
+
+    if (!serviceAccount?.privateKey || !serviceAccount?.clientEmail) {
+      throw new Error('Service account credentials are required')
+    }
+
+    // build the cache key
+    const cacheKey = `${projectId}-${serviceAccount.clientEmail}`
+
+    // check whether a client instance already exists
+    let auth = this.authClients.get(cacheKey)
+
+    if (!auth) {
+      try {
+        // format the private key
+        const formattedPrivateKey = this.formatPrivateKey(serviceAccount.privateKey)
+
+        // create a new auth client
+        auth = new GoogleAuth({
+          credentials: {
+            private_key: formattedPrivateKey,
+            client_email: serviceAccount.clientEmail
+          },
+          projectId,
+          scopes: [REQUIRED_VERTEX_AI_SCOPE]
+        })
+
+        this.authClients.set(cacheKey, auth)
+      } catch (formatError: any) {
+        throw new Error(`Invalid private key format: ${formatError.message}`)
+      }
+    }
+
+    try {
+      // get the auth headers
+      const authHeaders = await auth.getRequestHeaders()
+
+      // convert them to a plain object
+      const headers: Record<string, string> = {}
+      for (const [key, value] of Object.entries(authHeaders)) {
+        if (typeof value === 'string') {
+          headers[key] = value
+        }
+      }
+
+      return headers
+    } catch (error: any) {
+      // if authentication fails, drop the cached client
+      this.authClients.delete(cacheKey)
+      throw new Error(`Failed to authenticate with service account: ${error.message}`)
    }
+  }
+
+  /**
+   * Clear the cached auth clients for a given project
+   */
+  clearAuthCache(projectId: string, clientEmail?: string): void {
+    if (clientEmail) {
+      const cacheKey = `${projectId}-${clientEmail}`
+      this.authClients.delete(cacheKey)
+    } else {
+      // clear all cached clients for this project
+      for (const [key] of this.authClients) {
+        if (key.startsWith(`${projectId}-`)) {
+          this.authClients.delete(key)
+        }
+      }
+    }
+  }
+
+  /**
+   * Clear all cached auth clients
+   */
+  clearAllAuthCache(): void {
+    this.authClients.clear()
+  }
+}
+
+export default VertexAIService
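A hedged usage sketch: fetch auth headers from the singleton and attach them to a Vertex AI REST call. The endpoint URL, region and model below are placeholders for illustration, not values taken from this diff:

import VertexAIService from './VertexAIService'

async function callVertex(projectId: string, privateKey: string, clientEmail: string) {
  const headers = await VertexAIService.getInstance().getAuthHeaders({
    projectId,
    serviceAccount: { privateKey, clientEmail }
  })

  // Placeholder endpoint; region and model are illustrative only.
  const url = `https://us-central1-aiplatform.googleapis.com/v1/projects/${projectId}/locations/us-central1/publishers/google/models/gemini-pro:generateContent`

  const res = await fetch(url, {
    method: 'POST',
    headers: { ...headers, 'Content-Type': 'application/json' },
    body: JSON.stringify({ contents: [{ role: 'user', parts: [{ text: 'hello' }] }] })
  })
  return res.json()
}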
@@ -1,5 +1,7 @@
 import { WebDavConfig } from '@types'
 import Logger from 'electron-log'
+import https from 'https'
+import path from 'path'
 import Stream from 'stream'
 import {
   BufferLike,
@@ -14,13 +16,14 @@ export default class WebDav {
   private webdavPath: string

   constructor(params: WebDavConfig) {
-    this.webdavPath = params.webdavPath
+    this.webdavPath = params.webdavPath || '/'

     this.instance = createClient(params.webdavHost, {
       username: params.webdavUser,
       password: params.webdavPass,
       maxBodyLength: Infinity,
-      maxContentLength: Infinity
+      maxContentLength: Infinity,
+      httpsAgent: new https.Agent({ rejectUnauthorized: false })
     })

     this.putFileContents = this.putFileContents.bind(this)
@@ -49,7 +52,7 @@ export default class WebDav {
       throw error
     }

-    const remoteFilePath = `${this.webdavPath}/${filename}`
+    const remoteFilePath = path.posix.join(this.webdavPath, filename)

     try {
       return await this.instance.putFileContents(remoteFilePath, data, options)
@@ -64,7 +67,7 @@ export default class WebDav {
       throw new Error('WebDAV client not initialized')
     }

-    const remoteFilePath = `${this.webdavPath}/${filename}`
+    const remoteFilePath = path.posix.join(this.webdavPath, filename)

     try {
       return await this.instance.getFileContents(remoteFilePath, options)
@@ -74,6 +77,19 @@ export default class WebDav {
     }
   }

+  public getDirectoryContents = async () => {
+    if (!this.instance) {
+      throw new Error('WebDAV client not initialized')
+    }
+
+    try {
+      return await this.instance.getDirectoryContents(this.webdavPath)
+    } catch (error) {
+      Logger.error('[WebDAV] Error getting directory contents on WebDAV:', error)
+      throw error
+    }
+  }
+
   public checkConnection = async () => {
     if (!this.instance) {
       throw new Error('WebDAV client not initialized')
@@ -105,7 +121,7 @@ export default class WebDav {
       throw new Error('WebDAV client not initialized')
     }

-    const remoteFilePath = `${this.webdavPath}/${filename}`
+    const remoteFilePath = path.posix.join(this.webdavPath, filename)

     try {
       return await this.instance.deleteFile(remoteFilePath)
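A usage sketch of the updated client. The WebDavConfig field names come from this diff; the import path, the omitted config fields and the boolean reading of checkConnection are assumptions. path.posix.join keeps remote paths forward-slashed even on Windows, and the relaxed httpsAgent accepts self-signed certificates:

import WebDav from './WebDavService' // assumed module path for the class in this hunk

async function listRemoteBackups() {
  const client = new WebDav({
    webdavHost: 'https://dav.example.com',
    webdavUser: 'user',
    webdavPass: 'secret',
    webdavPath: '/cherry-studio' // falls back to '/' when empty
  } as any) // cast because only the fields used here are filled in

  if (await client.checkConnection()) {
    // new in this diff: enumerate everything under webdavPath
    return await client.getDirectoryContents()
  }
  return []
}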
@@ -1,8 +1,10 @@
+// just import the themeService to ensure the theme is initialized
+import './ThemeService'
+
 import { is } from '@electron-toolkit/utils'
 import { isDev, isLinux, isMac, isWin } from '@main/constant'
 import { getFilesDir } from '@main/utils/file'
 import { IpcChannel } from '@shared/IpcChannel'
-import { ThemeMode } from '@types'
 import { app, BrowserWindow, nativeTheme, shell } from 'electron'
 import Logger from 'electron-log'
 import windowStateKeeper from 'electron-window-state'
@@ -45,13 +47,6 @@ export class WindowService {
       maximize: false
     })

-    const theme = configManager.getTheme()
-
-    if (theme === ThemeMode.auto) {
-      nativeTheme.themeSource = 'system'
-    } else {
-      nativeTheme.themeSource = theme
-    }
-
     this.mainWindow = new BrowserWindow({
       x: mainWindowState.x,
       y: mainWindowState.y,
@@ -121,12 +116,6 @@ export class WindowService {
         app.exit(1)
       }
     })
-
-    mainWindow.webContents.on('unresponsive', () => {
-      // After upgrading to Electron 34 we can get the exact JS stack trace; for now just log it for monitoring
-      // https://www.electronjs.org/blog/electron-34-0#unresponsive-renderer-javascript-call-stacks
-      Logger.error('Renderer process unresponsive')
-    })
   }

   private setupMaximize(mainWindow: BrowserWindow, isMaximized: boolean) {
@@ -549,6 +538,25 @@ export class WindowService {
   public setPinMiniWindow(isPinned) {
     this.isPinnedMiniWindow = isPinned
   }
+
+  /**
+   * Quote text into the main window
+   * @param text the raw (unformatted) text
+   */
+  public quoteToMainWindow(text: string): void {
+    try {
+      this.showMainWindow()
+
+      const mainWindow = this.getMainWindow()
+      if (mainWindow && !mainWindow.isDestroyed()) {
+        setTimeout(() => {
+          mainWindow.webContents.send(IpcChannel.App_QuoteToMain, text)
+        }, 100)
+      }
+    } catch (error) {
+      Logger.error('Failed to quote to main window:', error as Error)
+    }
+  }
 }

 export const windowService = WindowService.getInstance()
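quoteToMainWindow ends with an App_QuoteToMain message to the main window's webContents. A renderer-side sketch of receiving it, assuming the @electron-toolkit preload exposes ipcRenderer as window.electron.ipcRenderer:

import { IpcChannel } from '@shared/IpcChannel'

// Renderer sketch: listen for text quoted from the selection assistant or mini window.
window.electron.ipcRenderer.on(IpcChannel.App_QuoteToMain, (_event, text: string) => {
  // e.g. append the quoted text to the current chat input
  console.log('quoted text:', text)
})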
71  src/main/utils/__tests__/aes.test.ts  Normal file
@@ -0,0 +1,71 @@
+import { describe, expect, it } from 'vitest'
+
+import { decrypt, encrypt } from '../aes'
+
+const key = '12345678901234567890123456789012' // 32 bytes
+const iv = '1234567890abcdef1234567890abcdef' // 32 hex chars; a real IV should be 16 bytes of hex
+
+function getIv16() {
+  // take the first 16 bytes as hex
+  return iv.slice(0, 32)
+}
+
+describe('aes utils', () => {
+  it('should encrypt and decrypt normal string', () => {
+    const text = 'hello world'
+    const { iv: outIv, encryptedData } = encrypt(text, key, getIv16())
+    expect(typeof encryptedData).toBe('string')
+    expect(outIv).toBe(getIv16())
+    const decrypted = decrypt(encryptedData, getIv16(), key)
+    expect(decrypted).toBe(text)
+  })
+
+  it('should support unicode and special chars', () => {
+    const text = '你好,世界!🌟🚀'
+    const { encryptedData } = encrypt(text, key, getIv16())
+    const decrypted = decrypt(encryptedData, getIv16(), key)
+    expect(decrypted).toBe(text)
+  })
+
+  it('should handle empty string', () => {
+    const text = ''
+    const { encryptedData } = encrypt(text, key, getIv16())
+    const decrypted = decrypt(encryptedData, getIv16(), key)
+    expect(decrypted).toBe(text)
+  })
+
+  it('should encrypt and decrypt long string', () => {
+    const text = 'a'.repeat(100_000)
+    const { encryptedData } = encrypt(text, key, getIv16())
+    const decrypted = decrypt(encryptedData, getIv16(), key)
+    expect(decrypted).toBe(text)
+  })
+
+  it('should throw error for wrong key', () => {
+    const text = 'test'
+    const { encryptedData } = encrypt(text, key, getIv16())
+    expect(() => decrypt(encryptedData, getIv16(), 'wrongkeywrongkeywrongkeywrongkey')).toThrow()
+  })
+
+  it('should throw error for wrong iv', () => {
+    const text = 'test'
+    const { encryptedData } = encrypt(text, key, getIv16())
+    expect(() => decrypt(encryptedData, 'abcdefabcdefabcdefabcdefabcdefab', key)).toThrow()
+  })
+
+  it('should throw error for invalid key/iv length', () => {
+    expect(() => encrypt('test', 'shortkey', getIv16())).toThrow()
+    expect(() => encrypt('test', key, 'shortiv')).toThrow()
+  })
+
+  it('should throw error for invalid encrypted data', () => {
+    expect(() => decrypt('nothexdata', getIv16(), key)).toThrow()
+  })
+
+  it('should throw error for non-string input', () => {
+    // @ts-expect-error purposely pass wrong type to test error branch
+    expect(() => encrypt(null, key, getIv16())).toThrow()
+    // @ts-expect-error purposely pass wrong type to test error branch
+    expect(() => decrypt(null, getIv16(), key)).toThrow()
+  })
+})
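The aes module under test is not part of this diff. From the contract the tests exercise (32-character key, 32-hex-character IV, a { iv, encryptedData } result in hex), a plausible AES-256-CBC implementation would look like the sketch below; this is an illustration of the expected interface, not the repository's actual aes.ts:

import crypto from 'node:crypto'

// Sketch only: mirrors the shape exercised by aes.test.ts.
export function encrypt(text: string, secretKey: string, iv: string) {
  const cipher = crypto.createCipheriv('aes-256-cbc', Buffer.from(secretKey, 'utf-8'), Buffer.from(iv, 'hex'))
  const encryptedData = cipher.update(text, 'utf8', 'hex') + cipher.final('hex')
  return { iv, encryptedData }
}

export function decrypt(encryptedData: string, iv: string, secretKey: string) {
  const decipher = crypto.createDecipheriv('aes-256-cbc', Buffer.from(secretKey, 'utf-8'), Buffer.from(iv, 'hex'))
  return decipher.update(encryptedData, 'hex', 'utf8') + decipher.final('utf8')
}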
243  src/main/utils/__tests__/file.test.ts  Normal file
@@ -0,0 +1,243 @@
+import * as fs from 'node:fs'
+import os from 'node:os'
+import path from 'node:path'
+
+import { FileTypes } from '@types'
+import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'
+
+import { getAllFiles, getAppConfigDir, getConfigDir, getFilesDir, getFileType, getTempDir } from '../file'
+
+// Mock dependencies
+vi.mock('node:fs')
+vi.mock('node:os')
+vi.mock('node:path')
+vi.mock('uuid', () => ({
+  v4: () => 'mock-uuid'
+}))
+vi.mock('electron', () => ({
+  app: {
+    getPath: vi.fn((key) => {
+      if (key === 'temp') return '/mock/temp'
+      if (key === 'userData') return '/mock/userData'
+      return '/mock/unknown'
+    })
+  }
+}))
+
+describe('file', () => {
+  beforeEach(() => {
+    vi.clearAllMocks()
+
+    // Mock path.extname
+    vi.mocked(path.extname).mockImplementation((file) => {
+      const parts = file.split('.')
+      return parts.length > 1 ? `.${parts[parts.length - 1]}` : ''
+    })
+
+    // Mock path.basename
+    vi.mocked(path.basename).mockImplementation((file) => {
+      const parts = file.split('/')
+      return parts[parts.length - 1]
+    })
+
+    // Mock path.join
+    vi.mocked(path.join).mockImplementation((...args) => args.join('/'))
+
+    // Mock os.homedir
+    vi.mocked(os.homedir).mockReturnValue('/mock/home')
+  })
+
+  afterEach(() => {
+    vi.resetAllMocks()
+  })
+
+  describe('getFileType', () => {
+    it('should return IMAGE for image extensions', () => {
+      expect(getFileType('.jpg')).toBe(FileTypes.IMAGE)
+      expect(getFileType('.jpeg')).toBe(FileTypes.IMAGE)
+      expect(getFileType('.png')).toBe(FileTypes.IMAGE)
+      expect(getFileType('.gif')).toBe(FileTypes.IMAGE)
+      expect(getFileType('.webp')).toBe(FileTypes.IMAGE)
+      expect(getFileType('.bmp')).toBe(FileTypes.IMAGE)
+    })
+
+    it('should return VIDEO for video extensions', () => {
+      expect(getFileType('.mp4')).toBe(FileTypes.VIDEO)
+      expect(getFileType('.avi')).toBe(FileTypes.VIDEO)
+      expect(getFileType('.mov')).toBe(FileTypes.VIDEO)
+      expect(getFileType('.mkv')).toBe(FileTypes.VIDEO)
+      expect(getFileType('.flv')).toBe(FileTypes.VIDEO)
+    })
+
+    it('should return AUDIO for audio extensions', () => {
+      expect(getFileType('.mp3')).toBe(FileTypes.AUDIO)
+      expect(getFileType('.wav')).toBe(FileTypes.AUDIO)
+      expect(getFileType('.ogg')).toBe(FileTypes.AUDIO)
+      expect(getFileType('.flac')).toBe(FileTypes.AUDIO)
+      expect(getFileType('.aac')).toBe(FileTypes.AUDIO)
+    })
+
+    it('should return TEXT for text extensions', () => {
+      expect(getFileType('.txt')).toBe(FileTypes.TEXT)
+      expect(getFileType('.md')).toBe(FileTypes.TEXT)
+      expect(getFileType('.html')).toBe(FileTypes.TEXT)
+      expect(getFileType('.json')).toBe(FileTypes.TEXT)
+      expect(getFileType('.js')).toBe(FileTypes.TEXT)
+      expect(getFileType('.ts')).toBe(FileTypes.TEXT)
+      expect(getFileType('.css')).toBe(FileTypes.TEXT)
+      expect(getFileType('.java')).toBe(FileTypes.TEXT)
+      expect(getFileType('.py')).toBe(FileTypes.TEXT)
+    })
+
+    it('should return DOCUMENT for document extensions', () => {
+      expect(getFileType('.pdf')).toBe(FileTypes.DOCUMENT)
+      expect(getFileType('.pptx')).toBe(FileTypes.DOCUMENT)
+      expect(getFileType('.docx')).toBe(FileTypes.DOCUMENT)
+      expect(getFileType('.xlsx')).toBe(FileTypes.DOCUMENT)
+      expect(getFileType('.odt')).toBe(FileTypes.DOCUMENT)
+    })
+
+    it('should return OTHER for unknown extensions', () => {
+      expect(getFileType('.unknown')).toBe(FileTypes.OTHER)
+      expect(getFileType('')).toBe(FileTypes.OTHER)
+      expect(getFileType('.')).toBe(FileTypes.OTHER)
+      expect(getFileType('...')).toBe(FileTypes.OTHER)
+      expect(getFileType('.123')).toBe(FileTypes.OTHER)
+    })
+
+    it('should handle case-insensitive extensions', () => {
+      expect(getFileType('.JPG')).toBe(FileTypes.IMAGE)
+      expect(getFileType('.PDF')).toBe(FileTypes.DOCUMENT)
+      expect(getFileType('.Mp3')).toBe(FileTypes.AUDIO)
+      expect(getFileType('.HtMl')).toBe(FileTypes.TEXT)
+      expect(getFileType('.Xlsx')).toBe(FileTypes.DOCUMENT)
+    })
+
+    it('should handle extensions without leading dot', () => {
+      expect(getFileType('jpg')).toBe(FileTypes.OTHER)
+      expect(getFileType('pdf')).toBe(FileTypes.OTHER)
+      expect(getFileType('mp3')).toBe(FileTypes.OTHER)
+    })
+
+    it('should handle extreme cases', () => {
+      expect(getFileType('.averylongfileextensionname')).toBe(FileTypes.OTHER)
+      expect(getFileType('.tar.gz')).toBe(FileTypes.OTHER)
+      expect(getFileType('.文件')).toBe(FileTypes.OTHER)
+      expect(getFileType('.файл')).toBe(FileTypes.OTHER)
+    })
+  })
+
+  describe('getAllFiles', () => {
+    it('should return all valid files recursively', () => {
+      // Mock file system
+      // @ts-ignore - override type for testing
+      vi.spyOn(fs, 'readdirSync').mockImplementation((dirPath) => {
+        if (dirPath === '/test') {
+          return ['file1.txt', 'file2.pdf', 'subdir']
+        } else if (dirPath === '/test/subdir') {
+          return ['file3.md', 'file4.docx']
+        }
+        return []
+      })
+
+      vi.mocked(fs.statSync).mockImplementation((filePath) => {
+        const isDir = String(filePath).endsWith('subdir')
+        return {
+          isDirectory: () => isDir,
+          size: 1024
+        } as fs.Stats
+      })
+
+      const result = getAllFiles('/test')
+
+      expect(result).toHaveLength(4)
+      expect(result[0].id).toBe('mock-uuid')
+      expect(result[0].name).toBe('file1.txt')
+      expect(result[0].type).toBe(FileTypes.TEXT)
+      expect(result[1].name).toBe('file2.pdf')
+      expect(result[1].type).toBe(FileTypes.DOCUMENT)
+    })
+
+    it('should skip hidden files', () => {
+      // @ts-ignore - override type for testing
+      vi.spyOn(fs, 'readdirSync').mockReturnValue(['.hidden', 'visible.txt'])
+      vi.mocked(fs.statSync).mockReturnValue({
+        isDirectory: () => false,
+        size: 1024
+      } as fs.Stats)
+
+      const result = getAllFiles('/test')
+
+      expect(result).toHaveLength(1)
+      expect(result[0].name).toBe('visible.txt')
+    })
+
+    it('should skip unsupported file types', () => {
+      // @ts-ignore - override type for testing
+      vi.spyOn(fs, 'readdirSync').mockReturnValue(['image.jpg', 'video.mp4', 'audio.mp3', 'document.pdf'])
+      vi.mocked(fs.statSync).mockReturnValue({
+        isDirectory: () => false,
+        size: 1024
+      } as fs.Stats)
+
+      const result = getAllFiles('/test')
+
+      // Should only include document.pdf as the others are excluded types
+      expect(result).toHaveLength(1)
+      expect(result[0].name).toBe('document.pdf')
+      expect(result[0].type).toBe(FileTypes.DOCUMENT)
+    })
+
+    it('should return empty array for empty directory', () => {
+      // @ts-ignore - override type for testing
+      vi.spyOn(fs, 'readdirSync').mockReturnValue([])
+
+      const result = getAllFiles('/empty')
+
+      expect(result).toHaveLength(0)
+    })
+
+    it('should handle file system errors', () => {
+      // @ts-ignore - override type for testing
+      vi.spyOn(fs, 'readdirSync').mockImplementation(() => {
+        throw new Error('Directory not found')
+      })
+
+      // Since the function doesn't have error handling, we expect it to propagate
+      expect(() => getAllFiles('/nonexistent')).toThrow('Directory not found')
+    })
+  })
+
+  describe('getTempDir', () => {
+    it('should return correct temp directory path', () => {
+      const tempDir = getTempDir()
+      expect(tempDir).toBe('/mock/temp/CherryStudio')
+    })
+  })
+
+  describe('getFilesDir', () => {
+    it('should return correct files directory path', () => {
+      const filesDir = getFilesDir()
+      expect(filesDir).toBe('/mock/userData/Data/Files')
+    })
+  })
+
+  describe('getConfigDir', () => {
+    it('should return correct config directory path', () => {
+      const configDir = getConfigDir()
+      expect(configDir).toBe('/mock/home/.cherrystudio/config')
+    })
+  })
+
+  describe('getAppConfigDir', () => {
+    it('should return correct app config directory path', () => {
+      const appConfigDir = getAppConfigDir('test-app')
+      expect(appConfigDir).toBe('/mock/home/.cherrystudio/config/test-app')
+    })
+
+    it('should handle empty app name', () => {
+      const appConfigDir = getAppConfigDir('')
+      expect(appConfigDir).toBe('/mock/home/.cherrystudio/config/')
+    })
+  })
+})
61  src/main/utils/__tests__/zip.test.ts  Normal file
@@ -0,0 +1,61 @@
+import { describe, expect, it } from 'vitest'
+
+import { compress, decompress } from '../zip'
+
+const jsonStr = JSON.stringify({ foo: 'bar', num: 42, arr: [1, 2, 3] })
+
+// helper: generate a large string
+function makeLargeString(size: number) {
+  return 'a'.repeat(size)
+}
+
+describe('zip', () => {
+  describe('compress & decompress', () => {
+    it('should compress and decompress a normal JSON string', async () => {
+      const compressed = await compress(jsonStr)
+      expect(compressed).toBeInstanceOf(Buffer)
+
+      const decompressed = await decompress(compressed)
+      expect(decompressed).toBe(jsonStr)
+    })
+
+    it('should handle empty string', async () => {
+      const compressed = await compress('')
+      expect(compressed).toBeInstanceOf(Buffer)
+      const decompressed = await decompress(compressed)
+      expect(decompressed).toBe('')
+    })
+
+    it('should handle large string', async () => {
+      const largeStr = makeLargeString(100_000)
+      const compressed = await compress(largeStr)
+      expect(compressed).toBeInstanceOf(Buffer)
+      expect(compressed.length).toBeLessThan(largeStr.length)
+      const decompressed = await decompress(compressed)
+      expect(decompressed).toBe(largeStr)
+    })
+
+    it('should throw error when decompressing invalid buffer', async () => {
+      const invalidBuffer = Buffer.from('not a valid gzip', 'utf-8')
+      await expect(decompress(invalidBuffer)).rejects.toThrow()
+    })
+
+    it('should throw error when compress input is not string', async () => {
+      // @ts-expect-error purposely pass wrong type to test error branch
+      await expect(compress(null)).rejects.toThrow()
+      // @ts-expect-error purposely pass wrong type to test error branch
+      await expect(compress(undefined)).rejects.toThrow()
+      // @ts-expect-error purposely pass wrong type to test error branch
+      await expect(compress(123)).rejects.toThrow()
+    })
+
+    it('should throw error when decompress input is not buffer', async () => {
+      // @ts-expect-error purposely pass wrong type to test error branch
+      await expect(decompress(null)).rejects.toThrow()
+      // @ts-expect-error purposely pass wrong type to test error branch
+      await expect(decompress(undefined)).rejects.toThrow()
+      // @ts-expect-error purposely pass wrong type to test error branch
+      await expect(decompress('string')).rejects.toThrow()
+    })
+  })
+})
@@ -9,10 +9,10 @@ const gunzipPromise = util.promisify(zlib.gunzip)

 /**
  * Compress a string
+ * @param {string} str the JSON string to compress
  * @returns {Promise<Buffer>} the compressed Buffer
- * @param str
  */
-export async function compress(str) {
+export async function compress(str: string): Promise<Buffer> {
   try {
     const buffer = Buffer.from(str, 'utf-8')
     return await gzipPromise(buffer)
@@ -27,7 +27,7 @@ export async function compress(str) {
  * @param {Buffer} compressedBuffer - the compressed Buffer
  * @returns {Promise<string>} the decompressed JSON string
  */
-export async function decompress(compressedBuffer) {
+export async function decompress(compressedBuffer: Buffer): Promise<string> {
   try {
     const buffer = await gunzipPromise(compressedBuffer)
     return buffer.toString('utf-8')
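A round-trip usage sketch for the now-typed helpers (import path illustrative):

import { compress, decompress } from './zip'

async function roundTrip() {
  const payload = JSON.stringify({ topics: [], messages: [] })
  const packed = await compress(payload) // gzip Buffer
  const unpacked = await decompress(packed) // original string
  console.assert(unpacked === payload)
}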
|||||||
@@ -1,11 +1,14 @@
|
|||||||
import type { ExtractChunkData } from '@cherrystudio/embedjs-interfaces'
|
import type { ExtractChunkData } from '@cherrystudio/embedjs-interfaces'
|
||||||
import { electronAPI } from '@electron-toolkit/preload'
|
import { electronAPI } from '@electron-toolkit/preload'
|
||||||
|
import { FeedUrl } from '@shared/config/constant'
|
||||||
import { IpcChannel } from '@shared/IpcChannel'
|
import { IpcChannel } from '@shared/IpcChannel'
|
||||||
import { FileType, KnowledgeBaseParams, KnowledgeItem, MCPServer, Shortcut, WebDavConfig } from '@types'
|
import { FileType, KnowledgeBaseParams, KnowledgeItem, MCPServer, Shortcut, ThemeMode, WebDavConfig } from '@types'
|
||||||
import { contextBridge, ipcRenderer, OpenDialogOptions, shell, webUtils } from 'electron'
|
import { contextBridge, ipcRenderer, OpenDialogOptions, shell, webUtils } from 'electron'
|
||||||
import { Notification } from 'src/renderer/src/types/notification'
|
import { Notification } from 'src/renderer/src/types/notification'
|
||||||
import { CreateDirectoryOptions } from 'webdav'
|
import { CreateDirectoryOptions } from 'webdav'
|
||||||
|
|
||||||
|
import type { ActionItem } from '../renderer/src/types/selectionTypes'
|
||||||
|
|
||||||
// Custom APIs for renderer
|
// Custom APIs for renderer
|
||||||
const api = {
|
const api = {
|
||||||
getAppInfo: () => ipcRenderer.invoke(IpcChannel.App_Info),
|
getAppInfo: () => ipcRenderer.invoke(IpcChannel.App_Info),
|
||||||
@@ -18,8 +21,8 @@ const api = {
|
|||||||
setLaunchToTray: (isActive: boolean) => ipcRenderer.invoke(IpcChannel.App_SetLaunchToTray, isActive),
|
setLaunchToTray: (isActive: boolean) => ipcRenderer.invoke(IpcChannel.App_SetLaunchToTray, isActive),
|
||||||
setTray: (isActive: boolean) => ipcRenderer.invoke(IpcChannel.App_SetTray, isActive),
|
setTray: (isActive: boolean) => ipcRenderer.invoke(IpcChannel.App_SetTray, isActive),
|
||||||
setTrayOnClose: (isActive: boolean) => ipcRenderer.invoke(IpcChannel.App_SetTrayOnClose, isActive),
|
setTrayOnClose: (isActive: boolean) => ipcRenderer.invoke(IpcChannel.App_SetTrayOnClose, isActive),
|
||||||
restartTray: () => ipcRenderer.invoke(IpcChannel.App_RestartTray),
|
setFeedUrl: (feedUrl: FeedUrl) => ipcRenderer.invoke(IpcChannel.App_SetFeedUrl, feedUrl),
|
||||||
setTheme: (theme: 'light' | 'dark' | 'auto') => ipcRenderer.invoke(IpcChannel.App_SetTheme, theme),
|
setTheme: (theme: ThemeMode) => ipcRenderer.invoke(IpcChannel.App_SetTheme, theme),
|
||||||
handleZoomFactor: (delta: number, reset: boolean = false) =>
|
handleZoomFactor: (delta: number, reset: boolean = false) =>
|
||||||
ipcRenderer.invoke(IpcChannel.App_HandleZoomFactor, delta, reset),
|
ipcRenderer.invoke(IpcChannel.App_HandleZoomFactor, delta, reset),
|
||||||
setAutoUpdate: (isActive: boolean) => ipcRenderer.invoke(IpcChannel.App_SetAutoUpdate, isActive),
|
setAutoUpdate: (isActive: boolean) => ipcRenderer.invoke(IpcChannel.App_SetAutoUpdate, isActive),
|
||||||
@@ -74,14 +77,17 @@ const api = {
|
|||||||
selectFolder: () => ipcRenderer.invoke(IpcChannel.File_SelectFolder),
|
selectFolder: () => ipcRenderer.invoke(IpcChannel.File_SelectFolder),
|
||||||
saveImage: (name: string, data: string) => ipcRenderer.invoke(IpcChannel.File_SaveImage, name, data),
|
saveImage: (name: string, data: string) => ipcRenderer.invoke(IpcChannel.File_SaveImage, name, data),
|
||||||
base64Image: (fileId: string) => ipcRenderer.invoke(IpcChannel.File_Base64Image, fileId),
|
base64Image: (fileId: string) => ipcRenderer.invoke(IpcChannel.File_Base64Image, fileId),
|
||||||
-  download: (url: string, isUseContentType?: boolean) => ipcRenderer.invoke(IpcChannel.File_Download, url, isUseContentType),
+  download: (url: string, isUseContentType?: boolean) =>
+    ipcRenderer.invoke(IpcChannel.File_Download, url, isUseContentType),
   saveBase64Image: (data: string) => ipcRenderer.invoke(IpcChannel.File_SaveBase64Image, data),
   copy: (fileId: string, destPath: string) => ipcRenderer.invoke(IpcChannel.File_Copy, fileId, destPath),
   binaryImage: (fileId: string) => ipcRenderer.invoke(IpcChannel.File_BinaryImage, fileId),
   base64File: (fileId: string) => ipcRenderer.invoke(IpcChannel.File_Base64File, fileId),
+  pdfInfo: (fileId: string) => ipcRenderer.invoke(IpcChannel.File_GetPdfInfo, fileId),
   getPathForFile: (file: File) => webUtils.getPathForFile(file)
 },
 fs: {
-  read: (path: string) => ipcRenderer.invoke(IpcChannel.Fs_Read, path)
+  read: (pathOrUrl: string, encoding?: BufferEncoding) => ipcRenderer.invoke(IpcChannel.Fs_Read, pathOrUrl, encoding)
 },
 export: {
   toWord: (markdown: string, fileName: string) => ipcRenderer.invoke(IpcChannel.Export_Word, markdown, fileName)
@@ -123,8 +129,16 @@ const api = {
   listFiles: (apiKey: string) => ipcRenderer.invoke(IpcChannel.Gemini_ListFiles, apiKey),
   deleteFile: (fileId: string, apiKey: string) => ipcRenderer.invoke(IpcChannel.Gemini_DeleteFile, fileId, apiKey)
 },
+vertexAI: {
+  getAuthHeaders: (params: { projectId: string; serviceAccount?: { privateKey: string; clientEmail: string } }) =>
+    ipcRenderer.invoke(IpcChannel.VertexAI_GetAuthHeaders, params),
+  clearAuthCache: (projectId: string, clientEmail?: string) =>
+    ipcRenderer.invoke(IpcChannel.VertexAI_ClearAuthCache, projectId, clientEmail)
+},
 config: {
-  set: (key: string, value: any) => ipcRenderer.invoke(IpcChannel.Config_Set, key, value),
+  set: (key: string, value: any, isNotify: boolean = false) =>
+    ipcRenderer.invoke(IpcChannel.Config_Set, key, value, isNotify),
   get: (key: string) => ipcRenderer.invoke(IpcChannel.Config_Get, key)
 },
 miniWindow: {
@@ -204,7 +218,26 @@ const api = {
   subscribe: () => ipcRenderer.invoke(IpcChannel.StoreSync_Subscribe),
   unsubscribe: () => ipcRenderer.invoke(IpcChannel.StoreSync_Unsubscribe),
   onUpdate: (action: any) => ipcRenderer.invoke(IpcChannel.StoreSync_OnUpdate, action)
-}
+},
+selection: {
+  hideToolbar: () => ipcRenderer.invoke(IpcChannel.Selection_ToolbarHide),
+  writeToClipboard: (text: string) => ipcRenderer.invoke(IpcChannel.Selection_WriteToClipboard, text),
+  determineToolbarSize: (width: number, height: number) =>
+    ipcRenderer.invoke(IpcChannel.Selection_ToolbarDetermineSize, width, height),
+  setEnabled: (enabled: boolean) => ipcRenderer.invoke(IpcChannel.Selection_SetEnabled, enabled),
+  setTriggerMode: (triggerMode: string) => ipcRenderer.invoke(IpcChannel.Selection_SetTriggerMode, triggerMode),
+  setFollowToolbar: (isFollowToolbar: boolean) =>
+    ipcRenderer.invoke(IpcChannel.Selection_SetFollowToolbar, isFollowToolbar),
+  setRemeberWinSize: (isRemeberWinSize: boolean) =>
+    ipcRenderer.invoke(IpcChannel.Selection_SetRemeberWinSize, isRemeberWinSize),
+  setFilterMode: (filterMode: string) => ipcRenderer.invoke(IpcChannel.Selection_SetFilterMode, filterMode),
+  setFilterList: (filterList: string[]) => ipcRenderer.invoke(IpcChannel.Selection_SetFilterList, filterList),
+  processAction: (actionItem: ActionItem) => ipcRenderer.invoke(IpcChannel.Selection_ProcessAction, actionItem),
+  closeActionWindow: () => ipcRenderer.invoke(IpcChannel.Selection_ActionWindowClose),
+  minimizeActionWindow: () => ipcRenderer.invoke(IpcChannel.Selection_ActionWindowMinimize),
+  pinActionWindow: (isPinned: boolean) => ipcRenderer.invoke(IpcChannel.Selection_ActionWindowPin, isPinned)
+},
+quoteToMainWindow: (text: string) => ipcRenderer.invoke(IpcChannel.App_QuoteToMain, text)
 }

 // Use `contextBridge` APIs to expose Electron APIs to
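For orientation, a minimal renderer-side sketch of how the bridge above might be consumed once exposed (assuming the object is published as `window.api`, as the test mocks further down suggest; the argument values are placeholders):

```ts
// Sketch only: the `window.api` exposure and the exact payload shapes are assumptions.
async function selectionDemo() {
  // Drive the new selection assistant channels.
  await window.api.selection.setEnabled(true)
  await window.api.selection.setTriggerMode('selected')
  await window.api.quoteToMainWindow('quoted text for the main window')

  // Fetch Vertex AI auth headers through the new vertexAI namespace.
  const headers = await window.api.vertexAI.getAuthHeaders({ projectId: 'demo-project' })
  console.log(headers)
}
```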
@@ -1,49 +0,0 @@
import { vi } from 'vitest'

vi.mock('electron-log/renderer', () => {
  return {
    default: {
      info: console.log,
      error: console.error,
      warn: console.warn,
      debug: console.debug,
      verbose: console.log,
      silly: console.log,
      log: console.log,
      transports: {
        console: {
          level: 'info'
        }
      }
    }
  }
})

vi.stubGlobal('window', {
  electron: {
    ipcRenderer: {
      on: vi.fn(), // Mocking ipcRenderer.on
      send: vi.fn() // Mocking ipcRenderer.send
    }
  },
  api: {
    file: {
      read: vi.fn().mockResolvedValue('[]'), // Mock file.read to return an empty array (you can customize this)
      writeWithId: vi.fn().mockResolvedValue(undefined) // Mock file.writeWithId to do nothing
    }
  }
})

vi.mock('axios', () => ({
  default: {
    get: vi.fn().mockResolvedValue({ data: {} }), // Mocking axios GET request
    post: vi.fn().mockResolvedValue({ data: {} }) // Mocking axios POST request
    // You can add other axios methods like put, delete etc. as needed
  }
}))

vi.stubGlobal('window', {
  ...global.window, // Copy other global properties
  addEventListener: vi.fn(), // Mock addEventListener
  removeEventListener: vi.fn() // You can also mock removeEventListener if needed
})
41
src/renderer/selectionAction.html
Normal file
@@ -0,0 +1,41 @@
<!doctype html>
<html lang="zh-CN">

  <head>
    <meta charset="UTF-8" />
    <meta name="viewport" content="initial-scale=1, width=device-width" />
    <meta http-equiv="Content-Security-Policy"
      content="default-src 'self'; connect-src blob: *; script-src 'self' 'unsafe-eval' *; worker-src 'self' blob:; style-src 'self' 'unsafe-inline' *; font-src 'self' data: *; img-src 'self' data: file: * blob:; frame-src * file:" />
    <title>Cherry Studio Selection Assistant</title>
  </head>

  <body>
    <div id="root"></div>
    <script type="module" src="/src/windows/selection/action/entryPoint.tsx"></script>
    <style>
      html {
        margin: 0;
        padding: 0;
        box-sizing: border-box;
      }

      body {
        width: 100vw;
        height: 100vh;
        margin: 0;
        padding: 0;
        box-sizing: border-box;
      }

      #root {
        margin: 0;
        padding: 0;
        width: 100%;
        height: 100%;
        box-sizing: border-box;
      }
    </style>
  </body>

</html>
43
src/renderer/selectionToolbar.html
Normal file
@@ -0,0 +1,43 @@
<!doctype html>
<html lang="zh-CN">

  <head>
    <meta charset="UTF-8" />
    <meta name="viewport" content="initial-scale=1, width=device-width" />
    <meta http-equiv="Content-Security-Policy"
      content="default-src 'self'; connect-src blob: *; script-src 'self' 'unsafe-eval' *; worker-src 'self' blob:; style-src 'self' 'unsafe-inline' *; font-src 'self' data: *; img-src 'self' data: file: * blob:; frame-src * file:" />
    <title>Cherry Studio Selection Toolbar</title>
  </head>

  <body>
    <div id="root"></div>
    <script type="module" src="/src/windows/selection/toolbar/entryPoint.tsx"></script>
    <style>
      html {
        margin: 0;
      }

      body {
        margin: 0;
        padding: 0;
        overflow: hidden;
        width: 100vw;
        height: 100vh;

        -webkit-user-select: none;
        -moz-user-select: none;
        -ms-user-select: none;
        user-select: none;
      }

      #root {
        margin: 0;
        padding: 0;
        width: max-content !important;
        height: fit-content !important;
      }
    </style>
  </body>

</html>
223
src/renderer/src/aiCore/AI_CORE_DESIGN.md
Normal file
@@ -0,0 +1,223 @@
# Cherry Studio AI Provider Technical Architecture (New Proposal)

## 1. Core Design Philosophy and Goals

This architecture restructures Cherry Studio's AI Provider layer (now called `aiCore`) to achieve the following goals:

- **Clear responsibilities**: clearly divide the responsibilities of each component and reduce coupling.
- **High reuse**: maximize reuse of business logic and shared processing logic, and reduce duplicated code.
- **Easy to extend**: make it quick and convenient to integrate new AI Providers (LLM vendors) and add new AI features (such as translation, summarization, image generation).
- **Easy to maintain**: simplify the complexity of individual components and improve code readability and maintainability.
- **Standardization**: unify the internal data flow and interfaces, simplifying the handling of differences between Providers.

The core idea is to cleanly separate a pure **SDK adaptation layer (`XxxApiClient`)**, a **shared logic and intelligent parsing layer (middleware)**, and a **unified business entry layer (`AiCoreService`)**.

## 2. Core Components in Detail

### 2.1. `aiCore` (formerly the `AiProvider` folder)

This is the core module of the entire AI feature set.

#### 2.1.1. `XxxApiClient` (e.g. `aiCore/clients/openai/OpenAIApiClient.ts`)

- **Responsibility**: acts as a pure adaptation layer for a specific AI Provider SDK.
- **Parameter adaptation**: converts the application's unified `CoreRequest` object (see below) into the request parameter format required by the specific SDK.
- **Basic response conversion**: converts the raw data chunks returned by the SDK (`RawSdkChunk`, e.g. `OpenAI.Chat.Completions.ChatCompletionChunk`) into a set of the most basic, direct application-level `Chunk` objects (defined in `src/renderer/src/types/chunk.ts`).
  - For example: the SDK's `delta.content` -> `TextDeltaChunk`; the SDK's `delta.reasoning_content` -> `ThinkingDeltaChunk`; the SDK's `delta.tool_calls` -> `RawToolCallChunk` (carrying the raw tool-call data).
- **Key point**: `XxxApiClient` does **not** handle complex structures embedded in text content, such as `<think>` or `<tool_use>` tags.
- **Characteristics**: extremely lightweight, little code, easy to implement and maintain new Provider adaptations.

#### 2.1.2. `ApiClient.ts` (or the core interface of `BaseApiClient.ts`)

- Defines the interface that every `XxxApiClient` must implement, such as:
  - `getSdkInstance(): Promise<TSdkInstance> | TSdkInstance`
  - `getRequestTransformer(): RequestTransformer<TSdkParams>`
  - `getResponseChunkTransformer(): ResponseChunkTransformer<TRawChunk, TResponseContext>`
  - other optional, Provider-specific helper methods such as tool-call conversion (a minimal interface sketch follows).
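A minimal TypeScript sketch of the interface shape described above; the real definition lives in `aiCore/clients` (see the `BaseApiClient` diff further down), and the transformer generics are simplified here:

```ts
// Sketch only: named ApiClientSketch to make clear it is illustrative, not the project's type.
interface ApiClientSketch<TSdkInstance, TSdkParams, TRawChunk> {
  getSdkInstance(): Promise<TSdkInstance> | TSdkInstance
  getRequestTransformer(): RequestTransformer<TSdkParams>
  getResponseChunkTransformer(): ResponseChunkTransformer<TRawChunk>
}
```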
#### 2.1.3. `ApiClientFactory.ts`

- Dynamically creates and returns the appropriate `XxxApiClient` instance based on the Provider configuration.

#### 2.1.4. `AiCoreService.ts` (`aiCore/index.ts`)

- **Responsibility**: the unified entry point for all AI-related business features.
- Provides high-level, application-facing interfaces, for example:
  - `executeCompletions(params: CompletionsParams): Promise<AggregatedCompletionsResult>`
  - `translateText(params: TranslateParams): Promise<AggregatedTranslateResult>`
  - `summarizeText(params: SummarizeParams): Promise<AggregatedSummarizeResult>`
  - possibly `generateImage(prompt: string): Promise<ImageResult>` and others in the future.
- **Returns a `Promise`**: each service method returns a `Promise` that resolves, once the whole (possibly streaming) operation has finished, with an object containing all aggregated results (full text, tool-call details, final `usage`/`metrics`, etc.).
- **Supports streaming callbacks**: the service method parameters (e.g. `CompletionsParams`) still include an `onChunk` callback used to push `Chunk` data to the caller in real time for streaming UI updates.
- **Encapsulates task-specific prompt engineering**:
  - For example, `translateText` internally builds a `CoreRequest` containing the specific translation instructions.
- **Orchestrates and invokes the middleware chain**: through an internal `MiddlewareBuilder` instance (see `middleware/BUILDER_USAGE.md`), it dynamically builds and organizes the appropriate middleware sequence for the business method and parameters being called, and then executes it through composition functions such as `applyCompletionsMiddlewares`.
  - Obtains the `ApiClient` instance and injects it into the `Context` at the top of the middleware chain.
  - **Passes the `Promise`'s `resolve` and `reject` functions into the middleware chain** (via the `Context`) so that `FinalChunkConsumerAndNotifierMiddleware` can settle that `Promise` when the operation completes or an error occurs.
- **Advantages**:
  - Business logic (prompt construction and flow control for translation, summarization, etc.) is implemented once and works with every underlying Provider connected through an `ApiClient`.
  - **Supports external orchestration**: callers can `await` a service method to get the final aggregated result and feed it into the next step, making multi-step workflows easy.
  - **Supports internal composition**: the service itself can also `await` other atomic service methods to build more complex composite features (a usage sketch follows).
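A minimal sketch of how a caller might combine the streaming callback with the awaited aggregate result; parameter and result field names here are assumptions based on the method list above, not the final API surface:

```ts
// Illustrative only: aiCoreService, renderPartialChunk and result.text are assumed names.
const result = await aiCoreService.executeCompletions({
  messages,
  assistant,
  onChunk: (chunk) => {
    // Streaming UI update while the operation is still running.
    renderPartialChunk(chunk) // hypothetical UI helper
  }
})

// External orchestration: the aggregated result feeds the next step.
const translated = await aiCoreService.translateText({
  content: result.text, // assumed field on AggregatedCompletionsResult
  assistant
})
```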
#### 2.1.5. `coreRequestTypes.ts` (or `types.ts`)

- Defines the core, Provider-agnostic internal request structures, for example:
  - `CoreCompletionsRequest`: contains the normalized message list, model configuration, tool list, max token count, whether to stream output, and so on.
  - `CoreTranslateRequest`, `CoreSummarizeRequest`, etc. (if their structure differs significantly from `CoreCompletionsRequest`; otherwise it can be reused with a task-type marker).

### 2.2. `middleware`

The middleware layer handles the shared logic and specific features of the request and response streams. Its design and usage follow the conventions defined in `middleware/BUILDER_USAGE.md`.

**Core components include:**

- **`MiddlewareBuilder`**: a generic class with a fluent API for dynamically building middleware chains. Starting from a base chain, it supports conditionally adding, inserting, replacing, or removing middlewares.
- **`applyCompletionsMiddlewares`**: receives the chain built by `MiddlewareBuilder` and executes it in order, specifically for the Completions flow.
- **`MiddlewareRegistry`**: a registry that centrally manages all available middlewares and provides a unified access interface.
- **Various independent middleware modules** (stored in the `common/`, `core/`, and `feat/` subdirectories); a builder usage sketch follows this list.
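A minimal sketch of the builder flow described above; apart from `CompletionsMiddlewareBuilder.withDefaults()` (mentioned in section 3), the method names and the composer signature are assumptions:

```ts
// Sketch only: builder/composer signatures are assumptions, not the project's final API.
const builder = CompletionsMiddlewareBuilder.withDefaults()

if (!coreRequest.enableWebSearch) {
  builder.remove('WebSearchMiddleware') // conditionally drop a feature middleware
}

const result = await applyCompletionsMiddlewares(apiClient, builder.build(), coreRequest)
```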
#### 2.2.1. `middlewareTypes.ts`

- Defines the core middleware types, such as `AiProviderMiddlewareContext` (extended to include `_apiClientInstance` and `_coreRequest`), `MiddlewareAPI`, `CompletionsMiddleware`, etc.

#### 2.2.2. Core middlewares (`middleware/core/`)

- **`TransformCoreToSdkParamsMiddleware.ts`**: calls `ApiClient.getRequestTransformer()` to convert the `CoreRequest` into SDK-specific parameters and stores them in the context.
- **`RequestExecutionMiddleware.ts`**: calls `ApiClient.getSdkInstance()` to obtain the SDK instance, executes the actual API call with the transformed parameters, and returns the raw SDK stream.
- **`StreamAdapterMiddleware.ts`**: adapts raw SDK streams of various shapes (e.g. async iterators) into a uniform `ReadableStream<RawSdkChunk>`.
  - **`RawSdkChunk`**: the raw chunk format a specific AI provider's SDK returns in streaming responses, before any application-level normalization (e.g. OpenAI's `ChatCompletionChunk`, or parts of Gemini's `GenerateContentResponse`).
- **`RawSdkChunkToAppChunkMiddleware.ts`**: (new) consumes the `ReadableStream<RawSdkChunk>`, calls `ApiClient.getResponseChunkTransformer()` on each `RawSdkChunk` to convert it into one or more basic application-level `Chunk` objects, and outputs a `ReadableStream<Chunk>`.

#### 2.2.3. Feature middlewares (`middleware/feat/`)

These middlewares consume the relatively normalized `Chunk` stream produced by `ResponseTransformMiddleware` and handle the more complex logic.

- **`ThinkingTagExtractionMiddleware.ts`**: inspects `TextDeltaChunk`s, parses any embedded `<think>...</think>` tags in the text, and emits `ThinkingDeltaChunk` and `ThinkingCompleteChunk`.
- **`ToolUseExtractionMiddleware.ts`**: inspects `TextDeltaChunk`s, parses any embedded `<tool_use>...</tool_use>` tags, and emits the tool-call-related Chunks. If the `ApiClient` emitted native tool-call data, this middleware is also responsible for converting it into the standard format.

#### 2.2.4. Core processing middlewares (`middleware/core/`)

- **`TransformCoreToSdkParamsMiddleware.ts`**: calls `ApiClient.getRequestTransformer()` to convert the `CoreRequest` into SDK-specific parameters and stores them in the context.
- **`SdkCallMiddleware.ts`**: calls `ApiClient.getSdkInstance()` to obtain the SDK instance, executes the actual API call with the transformed parameters, and returns the raw SDK stream.
- **`StreamAdapterMiddleware.ts`**: adapts raw SDK streams of various shapes into a standard stream format.
- **`ResponseTransformMiddleware.ts`**: converts the raw SDK response into the application's standard `Chunk` objects.
- **`TextChunkMiddleware.ts`**: handles the text-related Chunk stream.
- **`ThinkChunkMiddleware.ts`**: handles the thinking-related Chunk stream.
- **`McpToolChunkMiddleware.ts`**: handles the tool-call-related Chunk stream.
- **`WebSearchMiddleware.ts`**: handles web-search-related logic.

#### 2.2.5. Common middlewares (`middleware/common/`)

- **`LoggingMiddleware.ts`**: request and response logging.
- **`AbortHandlerMiddleware.ts`**: handles request abortion.
- **`FinalChunkConsumerMiddleware.ts`**: consumes the final `Chunk` stream and notifies the application layer of real-time data via the `context.onChunk` callback.
  - **Accumulates data**: during streaming, accumulates key data such as text fragments, tool-call information, and `usage`/`metrics`.
  - **Settles the `Promise`**: when the input stream ends, completes the overall processing flow with the accumulated aggregate result.
  - Emits a completion signal containing the final accumulated information when the stream ends (a middleware shape sketch follows).
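To illustrate what such a middleware might look like, a shape sketch; the actual `CompletionsMiddleware` signature is defined in `middlewareTypes.ts`, and this Redux-style `(api) => (next) => (params)` form plus the context field access are assumptions:

```ts
// Sketch only: signature, getContext() and modelId are assumptions for illustration.
const LoggingMiddlewareSketch: CompletionsMiddleware = (api) => (next) => async (params) => {
  const ctx = api.getContext()
  console.log('[completions] request for model', ctx._coreRequest?.modelId)
  const result = await next(params) // hand off to the rest of the chain
  console.log('[completions] finished')
  return result
}
```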
### 2.3. `types/chunk.ts`

- Defines the application-wide unified `Chunk` type and all of its variants. This includes the basic types (e.g. `TextDeltaChunk`, `ThinkingDeltaChunk`), the SDK-native pass-through types (e.g. `RawToolCallChunk`, `RawFinishChunk`, intermediate products of the `ApiClient` conversion), and the feature-level types (e.g. `McpToolCallRequestChunk`, `WebSearchCompleteChunk`), as sketched below.
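A sketch of the discriminated union described above; the `type` string values and payload fields are illustrative only, the real definitions live in `src/renderer/src/types/chunk.ts`:

```ts
// Sketch only: discriminant values and fields are assumptions.
type ChunkSketch =
  | { type: 'text.delta'; text: string }                 // TextDeltaChunk
  | { type: 'thinking.delta'; text: string }             // ThinkingDeltaChunk
  | { type: 'raw.tool_call'; toolCall: unknown }         // RawToolCallChunk (SDK pass-through)
  | { type: 'mcp.tool_call_request'; name: string }      // McpToolCallRequestChunk
  | { type: 'web_search.complete'; results: unknown[] }  // WebSearchCompleteChunk
```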
## 3. Core Execution Flow (using `AiCoreService.executeCompletions` as an example)

```markdown
**Application layer (e.g. a UI component)**
||
\\/
**`AiProvider.completions` (`aiCore/index.ts`)**
(1. prepare ApiClient instance. 2. use `CompletionsMiddlewareBuilder.withDefaults()` to build middleware chain. 3. call `applyCompletionsMiddlewares`)
||
\\/
**`applyCompletionsMiddlewares` (`middleware/composer.ts`)**
(receives the built chain, the ApiClient instance and the raw SDK method, then executes the middlewares in order)
||
\\/
**[ Pre-processing middlewares ]**
(e.g. `FinalChunkConsumerMiddleware`, `TransformCoreToSdkParamsMiddleware`, `AbortHandlerMiddleware`)
|| (SDK request parameters are prepared in the Context)
\\/
**[ Processing middlewares ]**
(e.g. `McpToolChunkMiddleware`, `WebSearchMiddleware`, `TextChunkMiddleware`, `ThinkingTagExtractionMiddleware`)
|| (handle the various features and Chunk types)
\\/
**[ SDK-call middlewares ]**
(e.g. `ResponseTransformMiddleware`, `StreamAdapterMiddleware`, `SdkCallMiddleware`)
|| (output: the normalized application-level Chunk stream)
\\/
**`FinalChunkConsumerMiddleware` (core)**
(consumes the final `Chunk` stream, notifies the application layer via the `context.onChunk` callback, and completes processing when the stream ends)
||
\\/
**`AiProvider.completions` returns `Promise<CompletionsResult>`**
```
## 4. Suggested File/Directory Structure

```
src/renderer/src/
└── aiCore/
    ├── clients/
    │   ├── openai/
    │   ├── gemini/
    │   ├── anthropic/
    │   ├── BaseApiClient.ts
    │   ├── ApiClientFactory.ts
    │   ├── AihubmixAPIClient.ts
    │   ├── index.ts
    │   └── types.ts
    ├── middleware/
    │   ├── common/
    │   ├── core/
    │   ├── feat/
    │   ├── builder.ts
    │   ├── composer.ts
    │   ├── index.ts
    │   ├── register.ts
    │   ├── schemas.ts
    │   ├── types.ts
    │   └── utils.ts
    ├── types/
    │   ├── chunk.ts
    │   └── ...
    └── index.ts
```
## 5. Migration and Implementation Suggestions

- **Small steps, iterate gradually**: finish refactoring the core flow first (e.g. `completions`), then gradually migrate the other features (`translate`, etc.) and the other Providers.
- **Define the core types first**: the `CoreRequest`, `Chunk`, and `ApiClient` interfaces are the cornerstones of the whole architecture.
- **Slim down the `ApiClient`s**: strip the complex logic out of the existing `XxxProvider`s into the new middlewares or into `AiCoreService`.
- **Strengthen the middlewares**: let the middlewares take on more of the parsing and feature handling.
- **Write unit and integration tests**: ensure the correctness of every component and of the overall flow.

This architecture aims to provide a more robust, more flexible, and more maintainable AI core to support Cherry Studio's future development.

## 6. Migration Strategy and Implementation Suggestions

This section is distilled from the earlier `migrate.md` document and adjusted for the latest architecture discussions.

**Recap of the target architecture's core components:**

Consistent with the core components described in Section 2: mainly `XxxApiClient`, `AiCoreService`, the middleware chain, the `CoreRequest` types, and the standardized `Chunk` types.

**Migration steps:**

**Phase 0: Preparation and type definitions**

1. **Define the core data structures (TypeScript types):**
   - `CoreCompletionsRequest` (type): defines the application's unified internal conversation request structure.
   - `Chunk` (type - review and extend the existing `src/renderer/src/types/chunk.ts` as needed): defines all possible generic Chunk types.
   - Define similar `CoreXxxRequest` types for the other APIs (translation, summarization).
2. **Define the `ApiClient` interface:** pin down the core methods such as `getRequestTransformer`, `getResponseChunkTransformer`, and `getSdkInstance`.
3. **Adjust `AiProviderMiddlewareContext`:**
   - Make sure it includes `_apiClientInstance: ApiClient<any,any,any>`.
   - Make sure it includes `_coreRequest: CoreRequestType`.
   - Consider adding `resolvePromise: (value: AggregatedResultType) => void` and `rejectPromise: (reason?: any) => void` for `AiCoreService`'s Promise return.

**Phase 1: Implement the first `ApiClient` (using `OpenAIApiClient` as an example)**

1. **Create the `OpenAIApiClient` class:** implement the `ApiClient` interface.
2. **Migrate the SDK instance and configuration.**
3. **Implement `getRequestTransformer()`:** convert `CoreCompletionsRequest` into OpenAI SDK parameters.
4. **Implement `getResponseChunkTransformer()`:** convert `OpenAI.Chat.Completions.ChatCompletionChunk` into basic `
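To make Phase 1, step 4 concrete, a hedged sketch of what such a chunk conversion could look like; the returned chunk shapes and type strings are assumptions based on the sections above, not the project's final `ResponseChunkTransformer` signature:

```ts
// Sketch only: converts one OpenAI streaming chunk into basic application-level chunks.
import type OpenAI from 'openai'

function transformOpenAIChunkSketch(raw: OpenAI.Chat.Completions.ChatCompletionChunk) {
  const delta = raw.choices[0]?.delta
  const chunks: Array<{ type: string; [key: string]: unknown }> = []

  if (delta?.content) {
    chunks.push({ type: 'text.delta', text: delta.content }) // -> TextDeltaChunk
  }
  if ((delta as any)?.reasoning_content) {
    // reasoning_content is a provider extension, not part of OpenAI's published types.
    chunks.push({ type: 'thinking.delta', text: (delta as any).reasoning_content }) // -> ThinkingDeltaChunk
  }
  if (delta?.tool_calls?.length) {
    chunks.push({ type: 'raw.tool_call', toolCalls: delta.tool_calls }) // -> RawToolCallChunk
  }
  return chunks
}
```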
207
src/renderer/src/aiCore/clients/AihubmixAPIClient.ts
Normal file
@@ -0,0 +1,207 @@
import { isOpenAILLMModel } from '@renderer/config/models'
import {
  GenerateImageParams,
  MCPCallToolResponse,
  MCPTool,
  MCPToolResponse,
  Model,
  Provider,
  ToolCallResponse
} from '@renderer/types'
import {
  RequestOptions,
  SdkInstance,
  SdkMessageParam,
  SdkModel,
  SdkParams,
  SdkRawChunk,
  SdkRawOutput,
  SdkTool,
  SdkToolCall
} from '@renderer/types/sdk'

import { AnthropicAPIClient } from './anthropic/AnthropicAPIClient'
import { BaseApiClient } from './BaseApiClient'
import { GeminiAPIClient } from './gemini/GeminiAPIClient'
import { OpenAIAPIClient } from './openai/OpenAIApiClient'
import { OpenAIResponseAPIClient } from './openai/OpenAIResponseAPIClient'
import { RequestTransformer, ResponseChunkTransformer } from './types'

/**
 * AihubmixAPIClient - automatically selects the appropriate ApiClient based on the model type
 * Implemented with the decorator pattern; model routing happens at the ApiClient level
 */
export class AihubmixAPIClient extends BaseApiClient {
  // Use a union type instead of any to keep type safety
  private clients: Map<string, AnthropicAPIClient | GeminiAPIClient | OpenAIResponseAPIClient | OpenAIAPIClient> =
    new Map()
  private defaultClient: OpenAIAPIClient
  private currentClient: BaseApiClient

  constructor(provider: Provider) {
    super(provider)

    // Initialize the individual clients - now with type safety
    const claudeClient = new AnthropicAPIClient(provider)
    const geminiClient = new GeminiAPIClient({ ...provider, apiHost: 'https://aihubmix.com/gemini' })
    const openaiClient = new OpenAIResponseAPIClient(provider)
    const defaultClient = new OpenAIAPIClient(provider)

    this.clients.set('claude', claudeClient)
    this.clients.set('gemini', geminiClient)
    this.clients.set('openai', openaiClient)
    this.clients.set('default', defaultClient)

    // Set the default client
    this.defaultClient = defaultClient
    this.currentClient = this.defaultClient as BaseApiClient
  }

  /**
   * Type guard: make sure the client is an instance of BaseApiClient
   */
  private isValidClient(client: unknown): client is BaseApiClient {
    return (
      client !== null &&
      client !== undefined &&
      typeof client === 'object' &&
      'createCompletions' in client &&
      'getRequestTransformer' in client &&
      'getResponseChunkTransformer' in client
    )
  }

  /**
   * Pick the appropriate client for the given model
   */
  private getClient(model: Model): BaseApiClient {
    const id = model.id.toLowerCase()

    // ids starting with "claude"
    if (id.startsWith('claude')) {
      const client = this.clients.get('claude')
      if (!client || !this.isValidClient(client)) {
        throw new Error('Claude client not properly initialized')
      }
      return client
    }

    // ids starting with "gemini" or "imagen" that do not end with -nothink or -search
    if ((id.startsWith('gemini') || id.startsWith('imagen')) && !id.endsWith('-nothink') && !id.endsWith('-search')) {
      const client = this.clients.get('gemini')
      if (!client || !this.isValidClient(client)) {
        throw new Error('Gemini client not properly initialized')
      }
      return client
    }

    // OpenAI-family models
    if (isOpenAILLMModel(model)) {
      const client = this.clients.get('openai')
      if (!client || !this.isValidClient(client)) {
        throw new Error('OpenAI client not properly initialized')
      }
      return client
    }

    return this.defaultClient as BaseApiClient
  }

  /**
   * Select the appropriate client for the model and delegate calls to it
   */
  public getClientForModel(model: Model): BaseApiClient {
    this.currentClient = this.getClient(model)
    return this.currentClient
  }

  // ============ BaseApiClient abstract method implementations ============

  async createCompletions(payload: SdkParams, options?: RequestOptions): Promise<SdkRawOutput> {
    // Try to extract the model from the payload to pick the client
    const modelId = this.extractModelFromPayload(payload)
    if (modelId) {
      const modelObj = { id: modelId } as Model
      const targetClient = this.getClient(modelObj)
      return targetClient.createCompletions(payload, options)
    }

    // If the model cannot be extracted from the payload, use the currently selected client
    return this.currentClient.createCompletions(payload, options)
  }

  /**
   * Extract the model ID from the SDK payload
   */
  private extractModelFromPayload(payload: SdkParams): string | null {
    // Different SDKs may use different field names
    if ('model' in payload && typeof payload.model === 'string') {
      return payload.model
    }
    return null
  }

  async generateImage(params: GenerateImageParams): Promise<string[]> {
    return this.currentClient.generateImage(params)
  }

  async getEmbeddingDimensions(model?: Model): Promise<number> {
    const client = model ? this.getClient(model) : this.currentClient
    return client.getEmbeddingDimensions(model)
  }

  async listModels(): Promise<SdkModel[]> {
    // Could aggregate the models of all clients; for now use the default client
    return this.defaultClient.listModels()
  }

  async getSdkInstance(): Promise<SdkInstance> {
    return this.currentClient.getSdkInstance()
  }

  getRequestTransformer(): RequestTransformer<SdkParams, SdkMessageParam> {
    return this.currentClient.getRequestTransformer()
  }

  getResponseChunkTransformer(): ResponseChunkTransformer<SdkRawChunk> {
    return this.currentClient.getResponseChunkTransformer()
  }

  convertMcpToolsToSdkTools(mcpTools: MCPTool[]): SdkTool[] {
    return this.currentClient.convertMcpToolsToSdkTools(mcpTools)
  }

  convertSdkToolCallToMcp(toolCall: SdkToolCall, mcpTools: MCPTool[]): MCPTool | undefined {
    return this.currentClient.convertSdkToolCallToMcp(toolCall, mcpTools)
  }

  convertSdkToolCallToMcpToolResponse(toolCall: SdkToolCall, mcpTool: MCPTool): ToolCallResponse {
    return this.currentClient.convertSdkToolCallToMcpToolResponse(toolCall, mcpTool)
  }

  buildSdkMessages(
    currentReqMessages: SdkMessageParam[],
    output: SdkRawOutput | string,
    toolResults: SdkMessageParam[],
    toolCalls?: SdkToolCall[]
  ): SdkMessageParam[] {
    return this.currentClient.buildSdkMessages(currentReqMessages, output, toolResults, toolCalls)
  }

  convertMcpToolResponseToSdkMessageParam(
    mcpToolResponse: MCPToolResponse,
    resp: MCPCallToolResponse,
    model: Model
  ): SdkMessageParam | undefined {
    const client = this.getClient(model)
    return client.convertMcpToolResponseToSdkMessageParam(mcpToolResponse, resp, model)
  }

  extractMessagesFromSdkPayload(sdkPayload: SdkParams): SdkMessageParam[] {
    return this.currentClient.extractMessagesFromSdkPayload(sdkPayload)
  }

  estimateMessageTokens(message: SdkMessageParam): number {
    return this.currentClient.estimateMessageTokens(message)
  }
}
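A short usage sketch of the routing behavior above; the model id is an example only and the surrounding wiring is simplified:

```ts
// Sketch: route per model id, then delegate the actual call to the selected client.
const aihubmix = new AihubmixAPIClient(provider)
const client = aihubmix.getClientForModel({ id: 'claude-3-5-sonnet' } as Model) // -> AnthropicAPIClient
await client.createCompletions(payload) // payload is built upstream by the request transformer
```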
66
src/renderer/src/aiCore/clients/ApiClientFactory.ts
Normal file
@@ -0,0 +1,66 @@
import { Provider } from '@renderer/types'

import { AihubmixAPIClient } from './AihubmixAPIClient'
import { AnthropicAPIClient } from './anthropic/AnthropicAPIClient'
import { BaseApiClient } from './BaseApiClient'
import { GeminiAPIClient } from './gemini/GeminiAPIClient'
import { VertexAPIClient } from './gemini/VertexAPIClient'
import { OpenAIAPIClient } from './openai/OpenAIApiClient'
import { OpenAIResponseAPIClient } from './openai/OpenAIResponseAPIClient'

/**
 * Factory for creating ApiClient instances based on provider configuration
 */
export class ApiClientFactory {
  /**
   * Create an ApiClient instance for the given provider
   */
  static create(provider: Provider): BaseApiClient {
    console.log(`[ApiClientFactory] Creating ApiClient for provider:`, {
      id: provider.id,
      type: provider.type
    })

    let instance: BaseApiClient

    // First check the special provider ids
    if (provider.id === 'aihubmix') {
      console.log(`[ApiClientFactory] Creating AihubmixAPIClient for provider: ${provider.id}`)
      instance = new AihubmixAPIClient(provider) as BaseApiClient
      return instance
    }

    // Then check the standard provider types
    switch (provider.type) {
      case 'openai':
      case 'azure-openai':
        console.log(`[ApiClientFactory] Creating OpenAIApiClient for provider: ${provider.id}`)
        instance = new OpenAIAPIClient(provider) as BaseApiClient
        break
      case 'openai-response':
        instance = new OpenAIResponseAPIClient(provider) as BaseApiClient
        break
      case 'gemini':
        instance = new GeminiAPIClient(provider) as BaseApiClient
        break
      case 'vertexai':
        instance = new VertexAPIClient(provider) as BaseApiClient
        break
      case 'anthropic':
        instance = new AnthropicAPIClient(provider) as BaseApiClient
        break
      default:
        console.log(`[ApiClientFactory] Using default OpenAIApiClient for provider: ${provider.id}`)
        instance = new OpenAIAPIClient(provider) as BaseApiClient
        break
    }

    return instance
  }
}

export function isOpenAIProvider(provider: Provider) {
  return !['anthropic', 'gemini'].includes(provider.type)
}
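An illustrative call site; the actual wiring inside `AiCoreService` is not shown in this hunk:

```ts
// Sketch: obtain the client for the configured provider and hand it to the middleware context.
const apiClient = ApiClientFactory.create(provider) // provider comes from the user's settings
const sdk = await apiClient.getSdkInstance()
```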
@@ -1,40 +1,69 @@
|
|||||||
import Logger from '@renderer/config/logger'
|
import {
|
||||||
import { isFunctionCallingModel, isNotSupportTemperatureAndTopP } from '@renderer/config/models'
|
isFunctionCallingModel,
|
||||||
|
isNotSupportTemperatureAndTopP,
|
||||||
|
isOpenAIModel,
|
||||||
|
isSupportedFlexServiceTier
|
||||||
|
} from '@renderer/config/models'
|
||||||
import { REFERENCE_PROMPT } from '@renderer/config/prompts'
|
import { REFERENCE_PROMPT } from '@renderer/config/prompts'
|
||||||
import { getLMStudioKeepAliveTime } from '@renderer/hooks/useLMStudio'
|
import { getLMStudioKeepAliveTime } from '@renderer/hooks/useLMStudio'
|
||||||
import type {
|
import { getStoreSetting } from '@renderer/hooks/useSettings'
|
||||||
|
import { SettingsState } from '@renderer/store/settings'
|
||||||
|
import {
|
||||||
Assistant,
|
Assistant,
|
||||||
|
FileTypes,
|
||||||
GenerateImageParams,
|
GenerateImageParams,
|
||||||
KnowledgeReference,
|
KnowledgeReference,
|
||||||
MCPCallToolResponse,
|
MCPCallToolResponse,
|
||||||
MCPTool,
|
MCPTool,
|
||||||
MCPToolResponse,
|
MCPToolResponse,
|
||||||
Model,
|
Model,
|
||||||
|
OpenAIServiceTier,
|
||||||
Provider,
|
Provider,
|
||||||
Suggestion,
|
ToolCallResponse,
|
||||||
WebSearchProviderResponse,
|
WebSearchProviderResponse,
|
||||||
WebSearchResponse
|
WebSearchResponse
|
||||||
} from '@renderer/types'
|
} from '@renderer/types'
|
||||||
import { ChunkType } from '@renderer/types/chunk'
|
import { Message } from '@renderer/types/newMessage'
|
||||||
import type { Message } from '@renderer/types/newMessage'
|
import {
|
||||||
import { delay, isJSON, parseJSON } from '@renderer/utils'
|
RequestOptions,
|
||||||
|
SdkInstance,
|
||||||
|
SdkMessageParam,
|
||||||
|
SdkModel,
|
||||||
|
SdkParams,
|
||||||
|
SdkRawChunk,
|
||||||
|
SdkRawOutput,
|
||||||
|
SdkTool,
|
||||||
|
SdkToolCall
|
||||||
|
} from '@renderer/types/sdk'
|
||||||
|
import { isJSON, parseJSON } from '@renderer/utils'
|
||||||
import { addAbortController, removeAbortController } from '@renderer/utils/abortController'
|
import { addAbortController, removeAbortController } from '@renderer/utils/abortController'
|
||||||
import { formatApiHost } from '@renderer/utils/api'
|
import { findFileBlocks, getMainTextContent } from '@renderer/utils/messageUtils/find'
|
||||||
import { getMainTextContent } from '@renderer/utils/messageUtils/find'
|
import { defaultTimeout } from '@shared/config/constant'
|
||||||
|
import Logger from 'electron-log/renderer'
|
||||||
import { isEmpty } from 'lodash'
|
import { isEmpty } from 'lodash'
|
||||||
import type OpenAI from 'openai'
|
|
||||||
|
|
||||||
import type { CompletionsParams } from '.'
|
import { ApiClient, RawStreamListener, RequestTransformer, ResponseChunkTransformer } from './types'
|
||||||
|
|
||||||
export default abstract class BaseProvider {
|
/**
|
||||||
// Threshold for determining whether to use system prompt for tools
|
* Abstract base class for API clients.
|
||||||
|
* Provides common functionality and structure for specific client implementations.
|
||||||
|
*/
|
||||||
|
export abstract class BaseApiClient<
|
||||||
|
TSdkInstance extends SdkInstance = SdkInstance,
|
||||||
|
TSdkParams extends SdkParams = SdkParams,
|
||||||
|
TRawOutput extends SdkRawOutput = SdkRawOutput,
|
||||||
|
TRawChunk extends SdkRawChunk = SdkRawChunk,
|
||||||
|
TMessageParam extends SdkMessageParam = SdkMessageParam,
|
||||||
|
TToolCall extends SdkToolCall = SdkToolCall,
|
||||||
|
TSdkSpecificTool extends SdkTool = SdkTool
|
||||||
|
> implements ApiClient<TSdkInstance, TSdkParams, TRawOutput, TRawChunk, TMessageParam, TToolCall, TSdkSpecificTool>
|
||||||
|
{
|
||||||
private static readonly SYSTEM_PROMPT_THRESHOLD: number = 128
|
private static readonly SYSTEM_PROMPT_THRESHOLD: number = 128
|
||||||
|
public provider: Provider
|
||||||
protected provider: Provider
|
|
||||||
protected host: string
|
protected host: string
|
||||||
protected apiKey: string
|
protected apiKey: string
|
||||||
|
protected sdkInstance?: TSdkInstance
|
||||||
protected useSystemPromptForTools: boolean = true
|
public useSystemPromptForTools: boolean = true
|
||||||
|
|
||||||
constructor(provider: Provider) {
|
constructor(provider: Provider) {
|
||||||
this.provider = provider
|
this.provider = provider
|
||||||
@@ -42,31 +71,81 @@ export default abstract class BaseProvider {
|
|||||||
this.apiKey = this.getApiKey()
|
this.apiKey = this.getApiKey()
|
||||||
}
|
}
|
||||||
|
|
||||||
abstract completions({ messages, assistant, onChunk, onFilterMessages }: CompletionsParams): Promise<void>
|
// // The core completions method - in the middleware architecture this is usually just a placeholder
|
||||||
abstract translate(
|
// abstract completions(params: CompletionsParams, internal?: ProcessingState): Promise<CompletionsResult>
|
||||||
content: string,
|
|
||||||
assistant: Assistant,
|
/**
|
||||||
onResponse?: (text: string, isComplete: boolean) => void
|
* Core API endpoints
|
||||||
): Promise<string>
|
**/
|
||||||
abstract summaries(messages: Message[], assistant: Assistant): Promise<string>
|
|
||||||
abstract summaryForSearch(messages: Message[], assistant: Assistant): Promise<string | null>
|
abstract createCompletions(payload: TSdkParams, options?: RequestOptions): Promise<TRawOutput>
|
||||||
abstract suggestions(messages: Message[], assistant: Assistant): Promise<Suggestion[]>
|
|
||||||
abstract generateText({ prompt, content }: { prompt: string; content: string }): Promise<string>
|
abstract generateImage(generateImageParams: GenerateImageParams): Promise<string[]>
|
||||||
abstract check(model: Model, stream: boolean): Promise<{ valid: boolean; error: Error | null }>
|
|
||||||
abstract models(): Promise<OpenAI.Models.Model[]>
|
abstract getEmbeddingDimensions(model?: Model): Promise<number>
|
||||||
abstract generateImage(params: GenerateImageParams): Promise<string[]>
|
|
||||||
abstract generateImageByChat({ messages, assistant, onChunk, onFilterMessages }: CompletionsParams): Promise<void>
|
abstract listModels(): Promise<SdkModel[]>
|
||||||
abstract getEmbeddingDimensions(model: Model): Promise<number>
|
|
||||||
public abstract convertMcpTools<T>(mcpTools: MCPTool[]): T[]
|
abstract getSdkInstance(): Promise<TSdkInstance> | TSdkInstance
|
||||||
public abstract mcpToolCallResponseToMessage(
|
|
||||||
|
/**
|
||||||
|
* Middleware hooks
|
||||||
|
**/
|
||||||
|
|
||||||
|
// Used in CoreRequestToSdkParamsMiddleware
|
||||||
|
abstract getRequestTransformer(): RequestTransformer<TSdkParams, TMessageParam>
|
||||||
|
// Used in RawSdkChunkToGenericChunkMiddleware
|
||||||
|
abstract getResponseChunkTransformer(): ResponseChunkTransformer<TRawChunk>
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Tool conversion
|
||||||
|
**/
|
||||||
|
|
||||||
|
// Optional tool conversion methods - implement if needed by the specific provider
|
||||||
|
abstract convertMcpToolsToSdkTools(mcpTools: MCPTool[]): TSdkSpecificTool[]
|
||||||
|
|
||||||
|
abstract convertSdkToolCallToMcp(toolCall: TToolCall, mcpTools: MCPTool[]): MCPTool | undefined
|
||||||
|
|
||||||
|
abstract convertSdkToolCallToMcpToolResponse(toolCall: TToolCall, mcpTool: MCPTool): ToolCallResponse
|
||||||
|
|
||||||
|
abstract buildSdkMessages(
|
||||||
|
currentReqMessages: TMessageParam[],
|
||||||
|
output: TRawOutput | string,
|
||||||
|
toolResults: TMessageParam[],
|
||||||
|
toolCalls?: TToolCall[]
|
||||||
|
): TMessageParam[]
|
||||||
|
|
||||||
|
abstract estimateMessageTokens(message: TMessageParam): number
|
||||||
|
|
||||||
|
abstract convertMcpToolResponseToSdkMessageParam(
|
||||||
mcpToolResponse: MCPToolResponse,
|
mcpToolResponse: MCPToolResponse,
|
||||||
resp: MCPCallToolResponse,
|
resp: MCPCallToolResponse,
|
||||||
model: Model
|
model: Model
|
||||||
): any
|
): TMessageParam | undefined
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Extract the message array from the SDK payload (for type-safe access in middlewares)
|
||||||
|
* Different providers may use different field names (e.g. messages, history, etc.)
|
||||||
|
*/
|
||||||
|
abstract extractMessagesFromSdkPayload(sdkPayload: TSdkParams): TMessageParam[]
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Attach a raw stream listener
|
||||||
|
*/
|
||||||
|
public attachRawStreamListener<TListener extends RawStreamListener<TRawChunk>>(
|
||||||
|
rawOutput: TRawOutput,
|
||||||
|
// eslint-disable-next-line @typescript-eslint/no-unused-vars
|
||||||
|
_listener: TListener
|
||||||
|
): TRawOutput {
|
||||||
|
return rawOutput
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Common helpers
|
||||||
|
**/
|
||||||
|
|
||||||
public getBaseURL(): string {
|
public getBaseURL(): string {
|
||||||
const host = this.provider.apiHost
|
return this.provider.apiHost
|
||||||
return formatApiHost(host)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public getApiKey() {
|
public getApiKey() {
|
||||||
@@ -111,14 +190,32 @@ export default abstract class BaseProvider {
|
|||||||
return isNotSupportTemperatureAndTopP(model) ? undefined : assistant.settings?.topP
|
return isNotSupportTemperatureAndTopP(model) ? undefined : assistant.settings?.topP
|
||||||
}
|
}
|
||||||
|
|
||||||
public async fakeCompletions({ onChunk }: CompletionsParams) {
|
protected getServiceTier(model: Model) {
|
||||||
for (let i = 0; i < 100; i++) {
|
if (!isOpenAIModel(model) || model.provider === 'github' || model.provider === 'copilot') {
|
||||||
await delay(0.01)
|
return undefined
|
||||||
onChunk({
|
|
||||||
response: { text: i + '\n', usage: { completion_tokens: 0, prompt_tokens: 0, total_tokens: 0 } },
|
|
||||||
type: ChunkType.BLOCK_COMPLETE
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const openAI = getStoreSetting('openAI') as SettingsState['openAI']
|
||||||
|
let serviceTier = 'auto' as OpenAIServiceTier
|
||||||
|
|
||||||
|
if (openAI && openAI?.serviceTier === 'flex') {
|
||||||
|
if (isSupportedFlexServiceTier(model)) {
|
||||||
|
serviceTier = 'flex'
|
||||||
|
} else {
|
||||||
|
serviceTier = 'auto'
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
serviceTier = openAI.serviceTier
|
||||||
|
}
|
||||||
|
|
||||||
|
return serviceTier
|
||||||
|
}
|
||||||
|
|
||||||
|
protected getTimeout(model: Model) {
|
||||||
|
if (isSupportedFlexServiceTier(model)) {
|
||||||
|
return 15 * 1000 * 60
|
||||||
|
}
|
||||||
|
return defaultTimeout
|
||||||
}
|
}
|
||||||
|
|
||||||
public async getMessageContent(message: Message): Promise<string> {
|
public async getMessageContent(message: Message): Promise<string> {
|
||||||
@@ -148,6 +245,36 @@ export default abstract class BaseProvider {
|
|||||||
return content
|
return content
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Extract the file content from the message
|
||||||
|
* @param message - The message
|
||||||
|
* @returns The file content
|
||||||
|
*/
|
||||||
|
protected async extractFileContent(message: Message) {
|
||||||
|
const fileBlocks = findFileBlocks(message)
|
||||||
|
if (fileBlocks.length > 0) {
|
||||||
|
const textFileBlocks = fileBlocks.filter(
|
||||||
|
(fb) => fb.file && [FileTypes.TEXT, FileTypes.DOCUMENT].includes(fb.file.type)
|
||||||
|
)
|
||||||
|
|
||||||
|
if (textFileBlocks.length > 0) {
|
||||||
|
let text = ''
|
||||||
|
const divider = '\n\n---\n\n'
|
||||||
|
|
||||||
|
for (const fileBlock of textFileBlocks) {
|
||||||
|
const file = fileBlock.file
|
||||||
|
const fileContent = (await window.api.file.read(file.id + file.ext)).trim()
|
||||||
|
const fileNameRow = 'file: ' + file.origin_name + '\n\n'
|
||||||
|
text = text + fileNameRow + fileContent + divider
|
||||||
|
}
|
||||||
|
|
||||||
|
return text
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return ''
|
||||||
|
}
|
||||||
|
|
||||||
private async getWebSearchReferencesFromCache(message: Message) {
|
private async getWebSearchReferencesFromCache(message: Message) {
|
||||||
const content = getMainTextContent(message)
|
const content = getMainTextContent(message)
|
||||||
if (isEmpty(content)) {
|
if (isEmpty(content)) {
|
||||||
@@ -209,7 +336,7 @@ export default abstract class BaseProvider {
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
protected createAbortController(messageId?: string, isAddEventListener?: boolean) {
|
public createAbortController(messageId?: string, isAddEventListener?: boolean) {
|
||||||
const abortController = new AbortController()
|
const abortController = new AbortController()
|
||||||
const abortFn = () => abortController.abort()
|
const abortFn = () => abortController.abort()
|
||||||
|
|
||||||
@@ -255,11 +382,11 @@ export default abstract class BaseProvider {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Setup tools configuration based on provided parameters
|
// Setup tools configuration based on provided parameters
|
||||||
protected setupToolsConfig<T>(params: { mcpTools?: MCPTool[]; model: Model; enableToolUse?: boolean }): {
|
public setupToolsConfig(params: { mcpTools?: MCPTool[]; model: Model; enableToolUse?: boolean }): {
|
||||||
tools: T[]
|
tools: TSdkSpecificTool[]
|
||||||
} {
|
} {
|
||||||
const { mcpTools, model, enableToolUse } = params
|
const { mcpTools, model, enableToolUse } = params
|
||||||
let tools: T[] = []
|
let tools: TSdkSpecificTool[] = []
|
||||||
|
|
||||||
// If there are no tools, return an empty array
|
// If there are no tools, return an empty array
|
||||||
if (!mcpTools?.length) {
|
if (!mcpTools?.length) {
|
||||||
@@ -267,14 +394,14 @@ export default abstract class BaseProvider {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// If the number of tools exceeds the threshold, use the system prompt
|
// If the number of tools exceeds the threshold, use the system prompt
|
||||||
if (mcpTools.length > BaseProvider.SYSTEM_PROMPT_THRESHOLD) {
|
if (mcpTools.length > BaseApiClient.SYSTEM_PROMPT_THRESHOLD) {
|
||||||
this.useSystemPromptForTools = true
|
this.useSystemPromptForTools = true
|
||||||
return { tools }
|
return { tools }
|
||||||
}
|
}
|
||||||
|
|
||||||
// If the model supports function calling and tool usage is enabled
|
// If the model supports function calling and tool usage is enabled
|
||||||
if (isFunctionCallingModel(model) && enableToolUse) {
|
if (isFunctionCallingModel(model) && enableToolUse) {
|
||||||
tools = this.convertMcpTools<T>(mcpTools)
|
tools = this.convertMcpToolsToSdkTools(mcpTools)
|
||||||
this.useSystemPromptForTools = false
|
this.useSystemPromptForTools = false
|
||||||
}
|
}
|
||||||
|
|
||||||
714
src/renderer/src/aiCore/clients/anthropic/AnthropicAPIClient.ts
Normal file
@@ -0,0 +1,714 @@
|
|||||||
|
import Anthropic from '@anthropic-ai/sdk'
|
||||||
|
import {
|
||||||
|
Base64ImageSource,
|
||||||
|
ImageBlockParam,
|
||||||
|
MessageParam,
|
||||||
|
TextBlockParam,
|
||||||
|
ToolResultBlockParam,
|
||||||
|
ToolUseBlock,
|
||||||
|
WebSearchTool20250305
|
||||||
|
} from '@anthropic-ai/sdk/resources'
|
||||||
|
import {
|
||||||
|
ContentBlock,
|
||||||
|
ContentBlockParam,
|
||||||
|
MessageCreateParams,
|
||||||
|
MessageCreateParamsBase,
|
||||||
|
RedactedThinkingBlockParam,
|
||||||
|
ServerToolUseBlockParam,
|
||||||
|
ThinkingBlockParam,
|
||||||
|
ThinkingConfigParam,
|
||||||
|
ToolUnion,
|
||||||
|
ToolUseBlockParam,
|
||||||
|
WebSearchResultBlock,
|
||||||
|
WebSearchToolResultBlockParam,
|
||||||
|
WebSearchToolResultError
|
||||||
|
} from '@anthropic-ai/sdk/resources/messages'
|
||||||
|
import { MessageStream } from '@anthropic-ai/sdk/resources/messages/messages'
|
||||||
|
import { GenericChunk } from '@renderer/aiCore/middleware/schemas'
|
||||||
|
import { DEFAULT_MAX_TOKENS } from '@renderer/config/constant'
|
||||||
|
import Logger from '@renderer/config/logger'
|
||||||
|
import { findTokenLimit, isClaudeReasoningModel, isReasoningModel, isWebSearchModel } from '@renderer/config/models'
|
||||||
|
import { getAssistantSettings } from '@renderer/services/AssistantService'
|
||||||
|
import FileManager from '@renderer/services/FileManager'
|
||||||
|
import { estimateTextTokens } from '@renderer/services/TokenService'
|
||||||
|
import {
|
||||||
|
Assistant,
|
||||||
|
EFFORT_RATIO,
|
||||||
|
FileTypes,
|
||||||
|
MCPCallToolResponse,
|
||||||
|
MCPTool,
|
||||||
|
MCPToolResponse,
|
||||||
|
Model,
|
||||||
|
Provider,
|
||||||
|
ToolCallResponse,
|
||||||
|
WebSearchSource
|
||||||
|
} from '@renderer/types'
|
||||||
|
import {
|
||||||
|
ChunkType,
|
||||||
|
ErrorChunk,
|
||||||
|
LLMWebSearchCompleteChunk,
|
||||||
|
LLMWebSearchInProgressChunk,
|
||||||
|
MCPToolCreatedChunk,
|
||||||
|
TextDeltaChunk,
|
||||||
|
ThinkingDeltaChunk
|
||||||
|
} from '@renderer/types/chunk'
|
||||||
|
import type { Message } from '@renderer/types/newMessage'
|
||||||
|
import {
|
||||||
|
AnthropicSdkMessageParam,
|
||||||
|
AnthropicSdkParams,
|
||||||
|
AnthropicSdkRawChunk,
|
||||||
|
AnthropicSdkRawOutput
|
||||||
|
} from '@renderer/types/sdk'
|
||||||
|
import { addImageFileToContents } from '@renderer/utils/formats'
|
||||||
|
import {
|
||||||
|
anthropicToolUseToMcpTool,
|
||||||
|
isEnabledToolUse,
|
||||||
|
mcpToolCallResponseToAnthropicMessage,
|
||||||
|
mcpToolsToAnthropicTools
|
||||||
|
} from '@renderer/utils/mcp-tools'
|
||||||
|
import { findFileBlocks, findImageBlocks, getMainTextContent } from '@renderer/utils/messageUtils/find'
|
||||||
|
import { buildSystemPrompt } from '@renderer/utils/prompt'
|
||||||
|
|
||||||
|
import { BaseApiClient } from '../BaseApiClient'
|
||||||
|
import { AnthropicStreamListener, RawStreamListener, RequestTransformer, ResponseChunkTransformer } from '../types'
|
||||||
|
|
||||||
|
export class AnthropicAPIClient extends BaseApiClient<
|
||||||
|
Anthropic,
|
||||||
|
AnthropicSdkParams,
|
||||||
|
AnthropicSdkRawOutput,
|
||||||
|
AnthropicSdkRawChunk,
|
||||||
|
AnthropicSdkMessageParam,
|
||||||
|
ToolUseBlock,
|
||||||
|
ToolUnion
|
||||||
|
> {
|
||||||
|
constructor(provider: Provider) {
|
||||||
|
super(provider)
|
||||||
|
}
|
||||||
|
|
||||||
|
async getSdkInstance(): Promise<Anthropic> {
|
||||||
|
if (this.sdkInstance) {
|
||||||
|
return this.sdkInstance
|
||||||
|
}
|
||||||
|
this.sdkInstance = new Anthropic({
|
||||||
|
apiKey: this.getApiKey(),
|
||||||
|
baseURL: this.getBaseURL(),
|
||||||
|
dangerouslyAllowBrowser: true,
|
||||||
|
defaultHeaders: {
|
||||||
|
'anthropic-beta': 'output-128k-2025-02-19'
|
||||||
|
}
|
||||||
|
})
|
||||||
|
return this.sdkInstance
|
||||||
|
}
|
||||||
|
|
||||||
|
override async createCompletions(
|
||||||
|
payload: AnthropicSdkParams,
|
||||||
|
options?: Anthropic.RequestOptions
|
||||||
|
): Promise<AnthropicSdkRawOutput> {
|
||||||
|
const sdk = await this.getSdkInstance()
|
||||||
|
if (payload.stream) {
|
||||||
|
return sdk.messages.stream(payload, options)
|
||||||
|
}
|
||||||
|
return await sdk.messages.create(payload, options)
|
||||||
|
}
|
||||||
|
|
||||||
|
// @ts-ignore not provided by the SDK
|
||||||
|
// eslint-disable-next-line @typescript-eslint/no-unused-vars
|
||||||
|
override async generateImage(generateImageParams: GenerateImageParams): Promise<string[]> {
|
||||||
|
return []
|
||||||
|
}
|
||||||
|
|
||||||
|
override async listModels(): Promise<Anthropic.ModelInfo[]> {
|
||||||
|
const sdk = await this.getSdkInstance()
|
||||||
|
const response = await sdk.models.list()
|
||||||
|
return response.data
|
||||||
|
}
|
||||||
|
|
||||||
|
// @ts-ignore not provided by the SDK
|
||||||
|
override async getEmbeddingDimensions(): Promise<number> {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
override getTemperature(assistant: Assistant, model: Model): number | undefined {
|
||||||
|
if (assistant.settings?.reasoning_effort && isClaudeReasoningModel(model)) {
|
||||||
|
return undefined
|
||||||
|
}
|
||||||
|
return assistant.settings?.temperature
|
||||||
|
}
|
||||||
|
|
||||||
|
override getTopP(assistant: Assistant, model: Model): number | undefined {
|
||||||
|
if (assistant.settings?.reasoning_effort && isClaudeReasoningModel(model)) {
|
||||||
|
return undefined
|
||||||
|
}
|
||||||
|
return assistant.settings?.topP
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get the reasoning effort
|
||||||
|
* @param assistant - The assistant
|
||||||
|
* @param model - The model
|
||||||
|
* @returns The reasoning effort
|
||||||
|
*/
|
||||||
|
private getBudgetToken(assistant: Assistant, model: Model): ThinkingConfigParam | undefined {
|
||||||
|
if (!isReasoningModel(model)) {
|
||||||
|
return undefined
|
||||||
|
}
|
||||||
|
const { maxTokens } = getAssistantSettings(assistant)
|
||||||
|
|
||||||
|
const reasoningEffort = assistant?.settings?.reasoning_effort
|
||||||
|
|
||||||
|
if (reasoningEffort === undefined) {
|
||||||
|
return {
|
||||||
|
type: 'disabled'
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const effortRatio = EFFORT_RATIO[reasoningEffort]
|
||||||
|
|
||||||
|
const budgetTokens = Math.max(
|
||||||
|
1024,
|
||||||
|
Math.floor(
|
||||||
|
Math.min(
|
||||||
|
(findTokenLimit(model.id)?.max! - findTokenLimit(model.id)?.min!) * effortRatio +
|
||||||
|
findTokenLimit(model.id)?.min!,
|
||||||
|
(maxTokens || DEFAULT_MAX_TOKENS) * effortRatio
|
||||||
|
)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
return {
|
||||||
|
type: 'enabled',
|
||||||
|
budget_tokens: budgetTokens
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get the message parameter
|
||||||
|
* @param message - The message
|
||||||
|
* @param model - The model
|
||||||
|
* @returns The message parameter
|
||||||
|
*/
|
||||||
|
public async convertMessageToSdkParam(message: Message): Promise<AnthropicSdkMessageParam> {
|
||||||
|
const parts: MessageParam['content'] = [
|
||||||
|
{
|
||||||
|
type: 'text',
|
||||||
|
text: getMainTextContent(message)
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
// Get and process image blocks
|
||||||
|
const imageBlocks = findImageBlocks(message)
|
||||||
|
for (const imageBlock of imageBlocks) {
|
||||||
|
if (imageBlock.file) {
|
||||||
|
// Handle uploaded file
|
||||||
|
const file = imageBlock.file
|
||||||
|
const base64Data = await window.api.file.base64Image(file.id + file.ext)
|
||||||
|
parts.push({
|
||||||
|
type: 'image',
|
||||||
|
source: {
|
||||||
|
data: base64Data.base64,
|
||||||
|
media_type: base64Data.mime.replace('jpg', 'jpeg') as any,
|
||||||
|
type: 'base64'
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Get and process file blocks
|
||||||
|
const fileBlocks = findFileBlocks(message)
|
||||||
|
for (const fileBlock of fileBlocks) {
|
||||||
|
const { file } = fileBlock
|
||||||
|
if ([FileTypes.TEXT, FileTypes.DOCUMENT].includes(file.type)) {
|
||||||
|
if (file.ext === '.pdf' && file.size < 32 * 1024 * 1024) {
|
||||||
|
const base64Data = await FileManager.readBase64File(file)
|
||||||
|
parts.push({
|
||||||
|
type: 'document',
|
||||||
|
source: {
|
||||||
|
type: 'base64',
|
||||||
|
media_type: 'application/pdf',
|
||||||
|
data: base64Data
|
||||||
|
}
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
const fileContent = await (await window.api.file.read(file.id + file.ext)).trim()
|
||||||
|
parts.push({
|
||||||
|
type: 'text',
|
||||||
|
text: file.origin_name + '\n' + fileContent
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return {
|
||||||
|
role: message.role === 'system' ? 'user' : message.role,
|
||||||
|
content: parts
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public convertMcpToolsToSdkTools(mcpTools: MCPTool[]): ToolUnion[] {
|
||||||
|
return mcpToolsToAnthropicTools(mcpTools)
|
||||||
|
}
|
||||||
|
|
||||||
|
public convertMcpToolResponseToSdkMessageParam(
|
||||||
|
mcpToolResponse: MCPToolResponse,
|
||||||
|
resp: MCPCallToolResponse,
|
||||||
|
model: Model
|
||||||
|
): AnthropicSdkMessageParam | undefined {
|
||||||
|
if ('toolUseId' in mcpToolResponse && mcpToolResponse.toolUseId) {
|
||||||
|
return mcpToolCallResponseToAnthropicMessage(mcpToolResponse, resp, model)
|
||||||
|
} else if ('toolCallId' in mcpToolResponse) {
|
||||||
|
return {
|
||||||
|
role: 'user',
|
||||||
|
content: [
|
||||||
|
{
|
||||||
|
type: 'tool_result',
|
||||||
|
tool_use_id: mcpToolResponse.toolCallId!,
|
||||||
|
content: resp.content
|
||||||
|
.map((item) => {
|
||||||
|
if (item.type === 'text') {
|
||||||
|
return {
|
||||||
|
type: 'text',
|
||||||
|
text: item.text || ''
|
||||||
|
} satisfies TextBlockParam
|
||||||
|
}
|
||||||
|
if (item.type === 'image') {
|
||||||
|
return {
|
||||||
|
type: 'image',
|
||||||
|
source: {
|
||||||
|
data: item.data || '',
|
||||||
|
media_type: (item.mimeType || 'image/png') as Base64ImageSource['media_type'],
|
||||||
|
type: 'base64'
|
||||||
|
}
|
||||||
|
} satisfies ImageBlockParam
|
||||||
|
}
|
||||||
|
return
|
||||||
|
})
|
||||||
|
.filter((n) => typeof n !== 'undefined'),
|
||||||
|
is_error: resp.isError
|
||||||
|
} satisfies ToolResultBlockParam
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Implementing abstract methods from BaseApiClient
|
||||||
|
convertSdkToolCallToMcp(toolCall: ToolUseBlock, mcpTools: MCPTool[]): MCPTool | undefined {
|
||||||
|
// Based on anthropicToolUseToMcpTool logic in AnthropicProvider
|
||||||
|
// This might need adjustment based on how tool calls are specifically handled in the new structure
|
||||||
|
const mcpTool = anthropicToolUseToMcpTool(mcpTools, toolCall)
|
||||||
|
return mcpTool
|
||||||
|
}
|
||||||
|
|
||||||
|
convertSdkToolCallToMcpToolResponse(toolCall: ToolUseBlock, mcpTool: MCPTool): ToolCallResponse {
|
||||||
|
return {
|
||||||
|
id: toolCall.id,
|
||||||
|
toolCallId: toolCall.id,
|
||||||
|
tool: mcpTool,
|
||||||
|
arguments: toolCall.input as Record<string, unknown>,
|
||||||
|
status: 'pending'
|
||||||
|
} as ToolCallResponse
|
||||||
|
}
|
||||||
|
|
||||||
|
override buildSdkMessages(
|
||||||
|
currentReqMessages: AnthropicSdkMessageParam[],
|
||||||
|
output: Anthropic.Message,
|
||||||
|
toolResults: AnthropicSdkMessageParam[]
|
||||||
|
): AnthropicSdkMessageParam[] {
|
||||||
|
const assistantMessage: AnthropicSdkMessageParam = {
|
||||||
|
role: output.role,
|
||||||
|
content: convertContentBlocksToParams(output.content)
|
||||||
|
}
|
||||||
|
|
||||||
|
const newMessages: AnthropicSdkMessageParam[] = [...currentReqMessages, assistantMessage]
|
||||||
|
if (toolResults && toolResults.length > 0) {
|
||||||
|
newMessages.push(...toolResults)
|
||||||
|
}
|
||||||
|
return newMessages
|
||||||
|
}
|
||||||
|
|
||||||
|
override estimateMessageTokens(message: AnthropicSdkMessageParam): number {
|
||||||
|
if (typeof message.content === 'string') {
|
||||||
|
return estimateTextTokens(message.content)
|
||||||
|
}
|
||||||
|
return message.content
|
||||||
|
.map((content) => {
|
||||||
|
switch (content.type) {
|
||||||
|
case 'text':
|
||||||
|
return estimateTextTokens(content.text)
|
||||||
|
case 'image':
|
||||||
|
if (content.source.type === 'base64') {
|
||||||
|
return estimateTextTokens(content.source.data)
|
||||||
|
} else {
|
||||||
|
return estimateTextTokens(content.source.url)
|
||||||
|
}
|
||||||
|
case 'tool_use':
|
||||||
|
return estimateTextTokens(JSON.stringify(content.input))
|
||||||
|
case 'tool_result':
|
||||||
|
return estimateTextTokens(JSON.stringify(content.content))
|
||||||
|
default:
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.reduce((acc, curr) => acc + curr, 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
public buildAssistantMessage(message: Anthropic.Message): AnthropicSdkMessageParam {
|
||||||
|
const messageParam: AnthropicSdkMessageParam = {
|
||||||
|
role: message.role,
|
||||||
|
content: convertContentBlocksToParams(message.content)
|
||||||
|
}
|
||||||
|
return messageParam
|
||||||
|
}
|
||||||
|
|
||||||
|
public extractMessagesFromSdkPayload(sdkPayload: AnthropicSdkParams): AnthropicSdkMessageParam[] {
|
||||||
|
return sdkPayload.messages || []
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Anthropic-specific raw stream listener
|
||||||
|
* Handles the specific events of the MessageStream object
|
||||||
|
*/
|
||||||
|
override attachRawStreamListener(
|
||||||
|
rawOutput: AnthropicSdkRawOutput,
|
||||||
|
listener: RawStreamListener<AnthropicSdkRawChunk>
|
||||||
|
): AnthropicSdkRawOutput {
|
||||||
|
console.log(`[AnthropicApiClient] Attaching stream listener to raw output`)
|
||||||
|
|
||||||
|
// Check whether this is a MessageStream
|
||||||
|
if (rawOutput instanceof MessageStream) {
|
||||||
|
console.log(`[AnthropicApiClient] Detected an Anthropic MessageStream, attaching dedicated listeners`)
|
||||||
|
|
||||||
|
if (listener.onStart) {
|
||||||
|
listener.onStart()
|
||||||
|
}
|
||||||
|
|
||||||
|
if (listener.onChunk) {
|
||||||
|
rawOutput.on('streamEvent', (event: AnthropicSdkRawChunk) => {
|
||||||
|
listener.onChunk!(event)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// 专用的Anthropic事件处理
|
||||||
|
const anthropicListener = listener as AnthropicStreamListener
|
||||||
|
|
||||||
|
if (anthropicListener.onContentBlock) {
|
||||||
|
rawOutput.on('contentBlock', anthropicListener.onContentBlock)
|
||||||
|
}
|
||||||
|
|
||||||
|
if (anthropicListener.onMessage) {
|
||||||
|
rawOutput.on('finalMessage', anthropicListener.onMessage)
|
||||||
|
}
|
||||||
|
|
||||||
|
if (listener.onEnd) {
|
||||||
|
rawOutput.on('end', () => {
|
||||||
|
listener.onEnd!()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
if (listener.onError) {
|
||||||
|
rawOutput.on('error', (error: Error) => {
|
||||||
|
listener.onError!(error)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return rawOutput
|
||||||
|
}
|
||||||
|
|
||||||
|
// 对于非MessageStream响应
|
||||||
|
return rawOutput
|
||||||
|
}
|
||||||
|
|
||||||
|
private async getWebSearchParams(model: Model): Promise<WebSearchTool20250305 | undefined> {
|
||||||
|
if (!isWebSearchModel(model)) {
|
||||||
|
return undefined
|
||||||
|
}
|
||||||
|
return {
|
||||||
|
type: 'web_search_20250305',
|
||||||
|
name: 'web_search',
|
||||||
|
max_uses: 5
|
||||||
|
} as WebSearchTool20250305
|
||||||
|
}
|
||||||
|
|
||||||
|
getRequestTransformer(): RequestTransformer<AnthropicSdkParams, AnthropicSdkMessageParam> {
|
||||||
|
return {
|
||||||
|
transform: async (
|
||||||
|
coreRequest,
|
||||||
|
assistant,
|
||||||
|
model,
|
||||||
|
isRecursiveCall,
|
||||||
|
recursiveSdkMessages
|
||||||
|
): Promise<{
|
||||||
|
payload: AnthropicSdkParams
|
||||||
|
messages: AnthropicSdkMessageParam[]
|
||||||
|
metadata: Record<string, any>
|
||||||
|
}> => {
|
||||||
|
const { messages, mcpTools, maxTokens, streamOutput, enableWebSearch } = coreRequest
|
||||||
|
// 1. 处理系统消息
|
||||||
|
let systemPrompt = assistant.prompt
|
||||||
|
|
||||||
|
// 2. 设置工具
|
||||||
|
const { tools } = this.setupToolsConfig({
|
||||||
|
mcpTools: mcpTools,
|
||||||
|
model,
|
||||||
|
enableToolUse: isEnabledToolUse(assistant)
|
||||||
|
})
|
||||||
|
|
||||||
|
if (this.useSystemPromptForTools) {
|
||||||
|
systemPrompt = await buildSystemPrompt(systemPrompt, mcpTools, assistant)
|
||||||
|
}
|
||||||
|
|
||||||
|
const systemMessage: TextBlockParam | undefined = systemPrompt
|
||||||
|
? { type: 'text', text: systemPrompt }
|
||||||
|
: undefined
|
||||||
|
|
||||||
|
// 3. 处理用户消息
|
||||||
|
const sdkMessages: AnthropicSdkMessageParam[] = []
|
||||||
|
if (typeof messages === 'string') {
|
||||||
|
sdkMessages.push({ role: 'user', content: messages })
|
||||||
|
} else {
|
||||||
|
const processedMessages = addImageFileToContents(messages)
|
||||||
|
for (const message of processedMessages) {
|
||||||
|
sdkMessages.push(await this.convertMessageToSdkParam(message))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (enableWebSearch) {
|
||||||
|
const webSearchTool = await this.getWebSearchParams(model)
|
||||||
|
if (webSearchTool) {
|
||||||
|
tools.push(webSearchTool)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const commonParams: MessageCreateParamsBase = {
|
||||||
|
model: model.id,
|
||||||
|
messages:
|
||||||
|
isRecursiveCall && recursiveSdkMessages && recursiveSdkMessages.length > 0
|
||||||
|
? recursiveSdkMessages
|
||||||
|
: sdkMessages,
|
||||||
|
max_tokens: maxTokens || DEFAULT_MAX_TOKENS,
|
||||||
|
temperature: this.getTemperature(assistant, model),
|
||||||
|
top_p: this.getTopP(assistant, model),
|
||||||
|
system: systemMessage ? [systemMessage] : undefined,
|
||||||
|
thinking: this.getBudgetToken(assistant, model),
|
||||||
|
tools: tools.length > 0 ? tools : undefined,
|
||||||
|
...this.getCustomParameters(assistant)
|
||||||
|
}
|
||||||
|
|
||||||
|
const finalParams: MessageCreateParams = streamOutput
|
||||||
|
? {
|
||||||
|
...commonParams,
|
||||||
|
stream: true
|
||||||
|
}
|
||||||
|
: {
|
||||||
|
...commonParams,
|
||||||
|
stream: false
|
||||||
|
}
|
||||||
|
|
||||||
|
const timeout = this.getTimeout(model)
|
||||||
|
return { payload: finalParams, messages: sdkMessages, metadata: { timeout } }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
getResponseChunkTransformer(): ResponseChunkTransformer<AnthropicSdkRawChunk> {
|
||||||
|
return () => {
|
||||||
|
let accumulatedJson = ''
|
||||||
|
const toolCalls: Record<number, ToolUseBlock> = {}
|
||||||
|
|
||||||
|
return {
|
||||||
|
async transform(rawChunk: AnthropicSdkRawChunk, controller: TransformStreamDefaultController<GenericChunk>) {
|
||||||
|
switch (rawChunk.type) {
|
||||||
|
case 'message': {
|
||||||
|
for (const content of rawChunk.content) {
|
||||||
|
switch (content.type) {
|
||||||
|
case 'text': {
|
||||||
|
controller.enqueue({
|
||||||
|
type: ChunkType.TEXT_DELTA,
|
||||||
|
text: content.text
|
||||||
|
} as TextDeltaChunk)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
case 'tool_use': {
|
||||||
|
toolCalls[0] = content
|
||||||
|
break
|
||||||
|
}
|
||||||
|
case 'thinking': {
|
||||||
|
controller.enqueue({
|
||||||
|
type: ChunkType.THINKING_DELTA,
|
||||||
|
text: content.thinking
|
||||||
|
} as ThinkingDeltaChunk)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
case 'web_search_tool_result': {
|
||||||
|
controller.enqueue({
|
||||||
|
type: ChunkType.LLM_WEB_SEARCH_COMPLETE,
|
||||||
|
llm_web_search: {
|
||||||
|
results: content.content,
|
||||||
|
source: WebSearchSource.ANTHROPIC
|
||||||
|
}
|
||||||
|
} as LLMWebSearchCompleteChunk)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
break
|
||||||
|
}
|
||||||
|
case 'content_block_start': {
|
||||||
|
const contentBlock = rawChunk.content_block
|
||||||
|
switch (contentBlock.type) {
|
||||||
|
case 'server_tool_use': {
|
||||||
|
if (contentBlock.name === 'web_search') {
|
||||||
|
controller.enqueue({
|
||||||
|
type: ChunkType.LLM_WEB_SEARCH_IN_PROGRESS
|
||||||
|
} as LLMWebSearchInProgressChunk)
|
||||||
|
}
|
||||||
|
break
|
||||||
|
}
|
||||||
|
case 'web_search_tool_result': {
|
||||||
|
if (
|
||||||
|
contentBlock.content &&
|
||||||
|
(contentBlock.content as WebSearchToolResultError).type === 'web_search_tool_result_error'
|
||||||
|
) {
|
||||||
|
controller.enqueue({
|
||||||
|
type: ChunkType.ERROR,
|
||||||
|
error: {
|
||||||
|
code: (contentBlock.content as WebSearchToolResultError).error_code,
|
||||||
|
message: (contentBlock.content as WebSearchToolResultError).error_code
|
||||||
|
}
|
||||||
|
} as ErrorChunk)
|
||||||
|
} else {
|
||||||
|
controller.enqueue({
|
||||||
|
type: ChunkType.LLM_WEB_SEARCH_COMPLETE,
|
||||||
|
llm_web_search: {
|
||||||
|
results: contentBlock.content as Array<WebSearchResultBlock>,
|
||||||
|
source: WebSearchSource.ANTHROPIC
|
||||||
|
}
|
||||||
|
} as LLMWebSearchCompleteChunk)
|
||||||
|
}
|
||||||
|
break
|
||||||
|
}
|
||||||
|
case 'tool_use': {
|
||||||
|
toolCalls[rawChunk.index] = contentBlock
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
break
|
||||||
|
}
|
||||||
|
case 'content_block_delta': {
|
||||||
|
const messageDelta = rawChunk.delta
|
||||||
|
switch (messageDelta.type) {
|
||||||
|
case 'text_delta': {
|
||||||
|
if (messageDelta.text) {
|
||||||
|
controller.enqueue({
|
||||||
|
type: ChunkType.TEXT_DELTA,
|
||||||
|
text: messageDelta.text
|
||||||
|
} as TextDeltaChunk)
|
||||||
|
}
|
||||||
|
break
|
||||||
|
}
|
||||||
|
case 'thinking_delta': {
|
||||||
|
if (messageDelta.thinking) {
|
||||||
|
controller.enqueue({
|
||||||
|
type: ChunkType.THINKING_DELTA,
|
||||||
|
text: messageDelta.thinking
|
||||||
|
} as ThinkingDeltaChunk)
|
||||||
|
}
|
||||||
|
break
|
||||||
|
}
|
||||||
|
case 'input_json_delta': {
|
||||||
|
if (messageDelta.partial_json) {
|
||||||
|
accumulatedJson += messageDelta.partial_json
|
||||||
|
}
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
break
|
||||||
|
}
|
||||||
|
case 'content_block_stop': {
|
||||||
|
const toolCall = toolCalls[rawChunk.index]
|
||||||
|
if (toolCall) {
|
||||||
|
try {
|
||||||
|
toolCall.input = JSON.parse(accumulatedJson)
|
||||||
|
Logger.debug(`Tool call id: ${toolCall.id}, accumulated json: ${accumulatedJson}`)
|
||||||
|
controller.enqueue({
|
||||||
|
type: ChunkType.MCP_TOOL_CREATED,
|
||||||
|
tool_calls: [toolCall]
|
||||||
|
} as MCPToolCreatedChunk)
|
||||||
|
} catch (error) {
|
||||||
|
Logger.error(`Error parsing tool call input: ${error}`)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
break
|
||||||
|
}
|
||||||
|
case 'message_delta': {
|
||||||
|
controller.enqueue({
|
||||||
|
type: ChunkType.LLM_RESPONSE_COMPLETE,
|
||||||
|
response: {
|
||||||
|
usage: {
|
||||||
|
prompt_tokens: rawChunk.usage.input_tokens || 0,
|
||||||
|
completion_tokens: rawChunk.usage.output_tokens || 0,
|
||||||
|
total_tokens: (rawChunk.usage.input_tokens || 0) + (rawChunk.usage.output_tokens || 0)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 将 ContentBlock 数组转换为 ContentBlockParam 数组
|
||||||
|
* 去除服务器生成的额外字段,只保留发送给API所需的字段
|
||||||
|
*/
|
||||||
|
function convertContentBlocksToParams(contentBlocks: ContentBlock[]): ContentBlockParam[] {
|
||||||
|
return contentBlocks.map((block): ContentBlockParam => {
|
||||||
|
switch (block.type) {
|
||||||
|
case 'text':
|
||||||
|
// TextBlock -> TextBlockParam,去除 citations 等服务器字段
|
||||||
|
return {
|
||||||
|
type: 'text',
|
||||||
|
text: block.text
|
||||||
|
} satisfies TextBlockParam
|
||||||
|
case 'tool_use':
|
||||||
|
// ToolUseBlock -> ToolUseBlockParam
|
||||||
|
return {
|
||||||
|
type: 'tool_use',
|
||||||
|
id: block.id,
|
||||||
|
name: block.name,
|
||||||
|
input: block.input
|
||||||
|
} satisfies ToolUseBlockParam
|
||||||
|
case 'thinking':
|
||||||
|
// ThinkingBlock -> ThinkingBlockParam
|
||||||
|
return {
|
||||||
|
type: 'thinking',
|
||||||
|
thinking: block.thinking,
|
||||||
|
signature: block.signature
|
||||||
|
} satisfies ThinkingBlockParam
|
||||||
|
case 'redacted_thinking':
|
||||||
|
// RedactedThinkingBlock -> RedactedThinkingBlockParam
|
||||||
|
return {
|
||||||
|
type: 'redacted_thinking',
|
||||||
|
data: block.data
|
||||||
|
} satisfies RedactedThinkingBlockParam
|
||||||
|
case 'server_tool_use':
|
||||||
|
// ServerToolUseBlock -> ServerToolUseBlockParam
|
||||||
|
return {
|
||||||
|
type: 'server_tool_use',
|
||||||
|
id: block.id,
|
||||||
|
name: block.name,
|
||||||
|
input: block.input
|
||||||
|
} satisfies ServerToolUseBlockParam
|
||||||
|
case 'web_search_tool_result':
|
||||||
|
// WebSearchToolResultBlock -> WebSearchToolResultBlockParam
|
||||||
|
return {
|
||||||
|
type: 'web_search_tool_result',
|
||||||
|
tool_use_id: block.tool_use_id,
|
||||||
|
content: block.content
|
||||||
|
} satisfies WebSearchToolResultBlockParam
|
||||||
|
default:
|
||||||
|
return block as ContentBlockParam
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
797
src/renderer/src/aiCore/clients/gemini/GeminiAPIClient.ts
Normal file
@@ -0,0 +1,797 @@
|
|||||||
|
import {
|
||||||
|
Content,
|
||||||
|
File,
|
||||||
|
FileState,
|
||||||
|
FunctionCall,
|
||||||
|
GenerateContentConfig,
|
||||||
|
GenerateImagesConfig,
|
||||||
|
GoogleGenAI,
|
||||||
|
HarmBlockThreshold,
|
||||||
|
HarmCategory,
|
||||||
|
Modality,
|
||||||
|
Model as GeminiModel,
|
||||||
|
Pager,
|
||||||
|
Part,
|
||||||
|
SafetySetting,
|
||||||
|
SendMessageParameters,
|
||||||
|
ThinkingConfig,
|
||||||
|
Tool
|
||||||
|
} from '@google/genai'
|
||||||
|
import { nanoid } from '@reduxjs/toolkit'
|
||||||
|
import { GenericChunk } from '@renderer/aiCore/middleware/schemas'
|
||||||
|
import {
|
||||||
|
findTokenLimit,
|
||||||
|
GEMINI_FLASH_MODEL_REGEX,
|
||||||
|
isGeminiReasoningModel,
|
||||||
|
isGemmaModel,
|
||||||
|
isVisionModel
|
||||||
|
} from '@renderer/config/models'
|
||||||
|
import { CacheService } from '@renderer/services/CacheService'
|
||||||
|
import { estimateTextTokens } from '@renderer/services/TokenService'
|
||||||
|
import {
|
||||||
|
Assistant,
|
||||||
|
EFFORT_RATIO,
|
||||||
|
FileType,
|
||||||
|
FileTypes,
|
||||||
|
GenerateImageParams,
|
||||||
|
MCPCallToolResponse,
|
||||||
|
MCPTool,
|
||||||
|
MCPToolResponse,
|
||||||
|
Model,
|
||||||
|
Provider,
|
||||||
|
ToolCallResponse,
|
||||||
|
WebSearchSource
|
||||||
|
} from '@renderer/types'
|
||||||
|
import { ChunkType, LLMWebSearchCompleteChunk } from '@renderer/types/chunk'
|
||||||
|
import { Message } from '@renderer/types/newMessage'
|
||||||
|
import {
|
||||||
|
GeminiOptions,
|
||||||
|
GeminiSdkMessageParam,
|
||||||
|
GeminiSdkParams,
|
||||||
|
GeminiSdkRawChunk,
|
||||||
|
GeminiSdkRawOutput,
|
||||||
|
GeminiSdkToolCall
|
||||||
|
} from '@renderer/types/sdk'
|
||||||
|
import {
|
||||||
|
geminiFunctionCallToMcpTool,
|
||||||
|
isEnabledToolUse,
|
||||||
|
mcpToolCallResponseToGeminiMessage,
|
||||||
|
mcpToolsToGeminiTools
|
||||||
|
} from '@renderer/utils/mcp-tools'
|
||||||
|
import { findFileBlocks, findImageBlocks, getMainTextContent } from '@renderer/utils/messageUtils/find'
|
||||||
|
import { buildSystemPrompt } from '@renderer/utils/prompt'
|
||||||
|
import { MB } from '@shared/config/constant'
|
||||||
|
|
||||||
|
import { BaseApiClient } from '../BaseApiClient'
|
||||||
|
import { RequestTransformer, ResponseChunkTransformer } from '../types'
|
||||||
|
|
||||||
|
export class GeminiAPIClient extends BaseApiClient<
|
||||||
|
GoogleGenAI,
|
||||||
|
GeminiSdkParams,
|
||||||
|
GeminiSdkRawOutput,
|
||||||
|
GeminiSdkRawChunk,
|
||||||
|
GeminiSdkMessageParam,
|
||||||
|
GeminiSdkToolCall,
|
||||||
|
Tool
|
||||||
|
> {
|
||||||
|
constructor(provider: Provider) {
|
||||||
|
super(provider)
|
||||||
|
}
|
||||||
|
|
||||||
|
override async createCompletions(payload: GeminiSdkParams, options?: GeminiOptions): Promise<GeminiSdkRawOutput> {
|
||||||
|
const sdk = await this.getSdkInstance()
|
||||||
|
const { model, history, ...rest } = payload
|
||||||
|
const realPayload: Omit<GeminiSdkParams, 'model'> = {
|
||||||
|
...rest,
|
||||||
|
config: {
|
||||||
|
...rest.config,
|
||||||
|
abortSignal: options?.abortSignal,
|
||||||
|
httpOptions: {
|
||||||
|
...rest.config?.httpOptions,
|
||||||
|
timeout: options?.timeout
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} satisfies SendMessageParameters
|
||||||
|
|
||||||
|
const streamOutput = options?.streamOutput
|
||||||
|
|
||||||
|
const chat = sdk.chats.create({
|
||||||
|
model: model,
|
||||||
|
history: history
|
||||||
|
})
|
||||||
|
|
||||||
|
if (streamOutput) {
|
||||||
|
const stream = chat.sendMessageStream(realPayload)
|
||||||
|
return stream
|
||||||
|
} else {
|
||||||
|
const response = await chat.sendMessage(realPayload)
|
||||||
|
return response
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
override async generateImage(generateImageParams: GenerateImageParams): Promise<string[]> {
|
||||||
|
const sdk = await this.getSdkInstance()
|
||||||
|
try {
|
||||||
|
const { model, prompt, imageSize, batchSize, signal } = generateImageParams
|
||||||
|
const config: GenerateImagesConfig = {
|
||||||
|
numberOfImages: batchSize,
|
||||||
|
aspectRatio: imageSize,
|
||||||
|
abortSignal: signal,
|
||||||
|
httpOptions: {
|
||||||
|
timeout: 5 * 60 * 1000
|
||||||
|
}
|
||||||
|
}
|
||||||
|
const response = await sdk.models.generateImages({
|
||||||
|
model: model,
|
||||||
|
prompt,
|
||||||
|
config
|
||||||
|
})
|
||||||
|
|
||||||
|
if (!response.generatedImages || response.generatedImages.length === 0) {
|
||||||
|
return []
|
||||||
|
}
|
||||||
|
|
||||||
|
const images = response.generatedImages
|
||||||
|
.filter((image) => image.image?.imageBytes)
|
||||||
|
.map((image) => {
|
||||||
|
const dataPrefix = `data:${image.image?.mimeType || 'image/png'};base64,`
|
||||||
|
return dataPrefix + image.image?.imageBytes
|
||||||
|
})
|
||||||
|
// console.log(response?.generatedImages?.[0]?.image?.imageBytes);
|
||||||
|
return images
|
||||||
|
} catch (error) {
|
||||||
|
console.error('[generateImage] error:', error)
|
||||||
|
throw error
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
override async getEmbeddingDimensions(model: Model): Promise<number> {
|
||||||
|
const sdk = await this.getSdkInstance()
|
||||||
|
try {
|
||||||
|
const data = await sdk.models.embedContent({
|
||||||
|
model: model.id,
|
||||||
|
contents: [{ role: 'user', parts: [{ text: 'hi' }] }]
|
||||||
|
})
|
||||||
|
return data.embeddings?.[0]?.values?.length || 0
|
||||||
|
} catch (e) {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
override async listModels(): Promise<GeminiModel[]> {
|
||||||
|
const sdk = await this.getSdkInstance()
|
||||||
|
const response = await sdk.models.list()
|
||||||
|
const models: GeminiModel[] = []
|
||||||
|
for await (const model of response) {
|
||||||
|
models.push(model)
|
||||||
|
}
|
||||||
|
return models
|
||||||
|
}
|
||||||
|
|
||||||
|
override async getSdkInstance() {
|
||||||
|
if (this.sdkInstance) {
|
||||||
|
return this.sdkInstance
|
||||||
|
}
|
||||||
|
|
||||||
|
this.sdkInstance = new GoogleGenAI({
|
||||||
|
vertexai: false,
|
||||||
|
apiKey: this.apiKey,
|
||||||
|
apiVersion: this.getApiVersion(),
|
||||||
|
httpOptions: {
|
||||||
|
baseUrl: this.getBaseURL(),
|
||||||
|
apiVersion: this.getApiVersion()
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
return this.sdkInstance
|
||||||
|
}
|
||||||
|
|
||||||
|
protected getApiVersion(): string {
|
||||||
|
if (this.provider.isVertex) {
|
||||||
|
return 'v1'
|
||||||
|
}
|
||||||
|
return 'v1beta'
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Handle a PDF file
|
||||||
|
* @param file - The file
|
||||||
|
* @returns The part
|
||||||
|
*/
|
||||||
|
private async handlePdfFile(file: FileType): Promise<Part> {
|
||||||
|
const smallFileSize = 20 * MB
|
||||||
|
const isSmallFile = file.size < smallFileSize
|
||||||
|
|
||||||
|
if (isSmallFile) {
|
||||||
|
const { data, mimeType } = await this.base64File(file)
|
||||||
|
return {
|
||||||
|
inlineData: {
|
||||||
|
data,
|
||||||
|
mimeType
|
||||||
|
} as Part['inlineData']
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Retrieve file from Gemini uploaded files
|
||||||
|
const fileMetadata: File | undefined = await this.retrieveFile(file)
|
||||||
|
|
||||||
|
if (fileMetadata) {
|
||||||
|
return {
|
||||||
|
fileData: {
|
||||||
|
fileUri: fileMetadata.uri,
|
||||||
|
mimeType: fileMetadata.mimeType
|
||||||
|
} as Part['fileData']
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// If file is not found, upload it to Gemini
|
||||||
|
const result = await this.uploadFile(file)
|
||||||
|
|
||||||
|
return {
|
||||||
|
fileData: {
|
||||||
|
fileUri: result.uri,
|
||||||
|
mimeType: result.mimeType
|
||||||
|
} as Part['fileData']
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get the message contents
|
||||||
|
* @param message - The message
|
||||||
|
* @returns The message contents
|
||||||
|
*/
|
||||||
|
private async convertMessageToSdkParam(message: Message): Promise<Content> {
|
||||||
|
const role = message.role === 'user' ? 'user' : 'model'
|
||||||
|
const parts: Part[] = [{ text: await this.getMessageContent(message) }]
|
||||||
|
// Add any generated images from previous responses
|
||||||
|
const imageBlocks = findImageBlocks(message)
|
||||||
|
for (const imageBlock of imageBlocks) {
|
||||||
|
if (
|
||||||
|
imageBlock.metadata?.generateImageResponse?.images &&
|
||||||
|
imageBlock.metadata.generateImageResponse.images.length > 0
|
||||||
|
) {
|
||||||
|
for (const imageUrl of imageBlock.metadata.generateImageResponse.images) {
|
||||||
|
if (imageUrl && imageUrl.startsWith('data:')) {
|
||||||
|
// Extract base64 data and mime type from the data URL
|
||||||
|
const matches = imageUrl.match(/^data:(.+);base64,(.*)$/)
|
||||||
|
if (matches && matches.length === 3) {
|
||||||
|
const mimeType = matches[1]
|
||||||
|
const base64Data = matches[2]
|
||||||
|
parts.push({
|
||||||
|
inlineData: {
|
||||||
|
data: base64Data,
|
||||||
|
mimeType: mimeType
|
||||||
|
} as Part['inlineData']
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
const file = imageBlock.file
|
||||||
|
if (file) {
|
||||||
|
const base64Data = await window.api.file.base64Image(file.id + file.ext)
|
||||||
|
parts.push({
|
||||||
|
inlineData: {
|
||||||
|
data: base64Data.base64,
|
||||||
|
mimeType: base64Data.mime
|
||||||
|
} as Part['inlineData']
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const fileBlocks = findFileBlocks(message)
|
||||||
|
for (const fileBlock of fileBlocks) {
|
||||||
|
const file = fileBlock.file
|
||||||
|
if (file.type === FileTypes.IMAGE) {
|
||||||
|
const base64Data = await window.api.file.base64Image(file.id + file.ext)
|
||||||
|
parts.push({
|
||||||
|
inlineData: {
|
||||||
|
data: base64Data.base64,
|
||||||
|
mimeType: base64Data.mime
|
||||||
|
} as Part['inlineData']
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
if (file.ext === '.pdf') {
|
||||||
|
parts.push(await this.handlePdfFile(file))
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if ([FileTypes.TEXT, FileTypes.DOCUMENT].includes(file.type)) {
|
||||||
|
const fileContent = await (await window.api.file.read(file.id + file.ext)).trim()
|
||||||
|
parts.push({
|
||||||
|
text: file.origin_name + '\n' + fileContent
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return {
|
||||||
|
role,
|
||||||
|
parts: parts
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// @ts-ignore unused
|
||||||
|
private async getImageFileContents(message: Message): Promise<Content> {
|
||||||
|
const role = message.role === 'user' ? 'user' : 'model'
|
||||||
|
const content = getMainTextContent(message)
|
||||||
|
const parts: Part[] = [{ text: content }]
|
||||||
|
const imageBlocks = findImageBlocks(message)
|
||||||
|
for (const imageBlock of imageBlocks) {
|
||||||
|
if (
|
||||||
|
imageBlock.metadata?.generateImageResponse?.images &&
|
||||||
|
imageBlock.metadata.generateImageResponse.images.length > 0
|
||||||
|
) {
|
||||||
|
for (const imageUrl of imageBlock.metadata.generateImageResponse.images) {
|
||||||
|
if (imageUrl && imageUrl.startsWith('data:')) {
|
||||||
|
// Extract base64 data and mime type from the data URL
|
||||||
|
const matches = imageUrl.match(/^data:(.+);base64,(.*)$/)
|
||||||
|
if (matches && matches.length === 3) {
|
||||||
|
const mimeType = matches[1]
|
||||||
|
const base64Data = matches[2]
|
||||||
|
parts.push({
|
||||||
|
inlineData: {
|
||||||
|
data: base64Data,
|
||||||
|
mimeType: mimeType
|
||||||
|
} as Part['inlineData']
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
const file = imageBlock.file
|
||||||
|
if (file) {
|
||||||
|
const base64Data = await window.api.file.base64Image(file.id + file.ext)
|
||||||
|
parts.push({
|
||||||
|
inlineData: {
|
||||||
|
data: base64Data.base64,
|
||||||
|
mimeType: base64Data.mime
|
||||||
|
} as Part['inlineData']
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return {
|
||||||
|
role,
|
||||||
|
parts: parts
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get the safety settings
|
||||||
|
* @returns The safety settings
|
||||||
|
*/
|
||||||
|
private getSafetySettings(): SafetySetting[] {
|
||||||
|
const safetyThreshold = 'OFF' as HarmBlockThreshold
|
||||||
|
|
||||||
|
return [
|
||||||
|
{
|
||||||
|
category: HarmCategory.HARM_CATEGORY_HATE_SPEECH,
|
||||||
|
threshold: safetyThreshold
|
||||||
|
},
|
||||||
|
{
|
||||||
|
category: HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
|
||||||
|
threshold: safetyThreshold
|
||||||
|
},
|
||||||
|
{
|
||||||
|
category: HarmCategory.HARM_CATEGORY_HARASSMENT,
|
||||||
|
threshold: safetyThreshold
|
||||||
|
},
|
||||||
|
{
|
||||||
|
category: HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
|
||||||
|
threshold: safetyThreshold
|
||||||
|
},
|
||||||
|
{
|
||||||
|
category: HarmCategory.HARM_CATEGORY_CIVIC_INTEGRITY,
|
||||||
|
threshold: HarmBlockThreshold.BLOCK_NONE
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get the reasoning effort for the assistant
|
||||||
|
* @param assistant - The assistant
|
||||||
|
* @param model - The model
|
||||||
|
* @returns The reasoning effort
|
||||||
|
*/
|
||||||
|
private getBudgetToken(assistant: Assistant, model: Model) {
|
||||||
|
if (isGeminiReasoningModel(model)) {
|
||||||
|
const reasoningEffort = assistant?.settings?.reasoning_effort
|
||||||
|
|
||||||
|
// 如果thinking_budget是undefined,不思考
|
||||||
|
if (reasoningEffort === undefined) {
|
||||||
|
return {
|
||||||
|
thinkingConfig: {
|
||||||
|
includeThoughts: false,
|
||||||
|
...(GEMINI_FLASH_MODEL_REGEX.test(model.id) ? { thinkingBudget: 0 } : {})
|
||||||
|
} as ThinkingConfig
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const effortRatio = EFFORT_RATIO[reasoningEffort]
|
||||||
|
|
||||||
|
if (effortRatio > 1) {
|
||||||
|
return {
|
||||||
|
thinkingConfig: {
|
||||||
|
includeThoughts: true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const { max } = findTokenLimit(model.id) || { max: 0 }
|
||||||
|
const budget = Math.floor(max * effortRatio)
|
||||||
|
|
||||||
|
return {
|
||||||
|
thinkingConfig: {
|
||||||
|
...(budget > 0 ? { thinkingBudget: budget } : {}),
|
||||||
|
includeThoughts: true
|
||||||
|
} as ThinkingConfig
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return {}
|
||||||
|
}
|
||||||
|
|
||||||
|
private getGenerateImageParameter(): Partial<GenerateContentConfig> {
|
||||||
|
return {
|
||||||
|
systemInstruction: undefined,
|
||||||
|
responseModalities: [Modality.TEXT, Modality.IMAGE],
|
||||||
|
responseMimeType: 'text/plain'
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
getRequestTransformer(): RequestTransformer<GeminiSdkParams, GeminiSdkMessageParam> {
|
||||||
|
return {
|
||||||
|
transform: async (
|
||||||
|
coreRequest,
|
||||||
|
assistant,
|
||||||
|
model,
|
||||||
|
isRecursiveCall,
|
||||||
|
recursiveSdkMessages
|
||||||
|
): Promise<{
|
||||||
|
payload: GeminiSdkParams
|
||||||
|
messages: GeminiSdkMessageParam[]
|
||||||
|
metadata: Record<string, any>
|
||||||
|
}> => {
|
||||||
|
const { messages, mcpTools, maxTokens, enableWebSearch, enableGenerateImage } = coreRequest
|
||||||
|
// 1. 处理系统消息
|
||||||
|
let systemInstruction = assistant.prompt
|
||||||
|
|
||||||
|
// 2. 设置工具
|
||||||
|
const { tools } = this.setupToolsConfig({
|
||||||
|
mcpTools,
|
||||||
|
model,
|
||||||
|
enableToolUse: isEnabledToolUse(assistant)
|
||||||
|
})
|
||||||
|
|
||||||
|
if (this.useSystemPromptForTools) {
|
||||||
|
systemInstruction = await buildSystemPrompt(assistant.prompt || '', mcpTools, assistant)
|
||||||
|
}
|
||||||
|
|
||||||
|
let messageContents: Content
|
||||||
|
const history: Content[] = []
|
||||||
|
// 3. 处理用户消息
|
||||||
|
if (typeof messages === 'string') {
|
||||||
|
messageContents = {
|
||||||
|
role: 'user',
|
||||||
|
parts: [{ text: messages }]
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
const userLastMessage = messages.pop()!
|
||||||
|
messageContents = await this.convertMessageToSdkParam(userLastMessage)
|
||||||
|
for (const message of messages) {
|
||||||
|
history.push(await this.convertMessageToSdkParam(message))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (enableWebSearch) {
|
||||||
|
tools.push({
|
||||||
|
googleSearch: {}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
if (isGemmaModel(model) && assistant.prompt) {
|
||||||
|
const isFirstMessage = history.length === 0
|
||||||
|
if (isFirstMessage && messageContents) {
|
||||||
|
const systemMessage = [
|
||||||
|
{
|
||||||
|
text:
|
||||||
|
'<start_of_turn>user\n' +
|
||||||
|
systemInstruction +
|
||||||
|
'<end_of_turn>\n' +
|
||||||
|
'<start_of_turn>user\n' +
|
||||||
|
(messageContents?.parts?.[0] as Part).text +
|
||||||
|
'<end_of_turn>'
|
||||||
|
}
|
||||||
|
] as Part[]
|
||||||
|
if (messageContents && messageContents.parts) {
|
||||||
|
messageContents.parts[0] = systemMessage[0]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const newHistory =
|
||||||
|
isRecursiveCall && recursiveSdkMessages && recursiveSdkMessages.length > 0
|
||||||
|
? recursiveSdkMessages.slice(0, recursiveSdkMessages.length - 1)
|
||||||
|
: history
|
||||||
|
|
||||||
|
const newMessageContents =
|
||||||
|
isRecursiveCall && recursiveSdkMessages && recursiveSdkMessages.length > 0
|
||||||
|
? {
|
||||||
|
...messageContents,
|
||||||
|
parts: [
|
||||||
|
...(messageContents.parts || []),
|
||||||
|
...(recursiveSdkMessages[recursiveSdkMessages.length - 1].parts || [])
|
||||||
|
]
|
||||||
|
}
|
||||||
|
: messageContents
|
||||||
|
|
||||||
|
const generateContentConfig: GenerateContentConfig = {
|
||||||
|
safetySettings: this.getSafetySettings(),
|
||||||
|
systemInstruction: isGemmaModel(model) ? undefined : systemInstruction,
|
||||||
|
temperature: this.getTemperature(assistant, model),
|
||||||
|
topP: this.getTopP(assistant, model),
|
||||||
|
maxOutputTokens: maxTokens,
|
||||||
|
tools: tools,
|
||||||
|
...(enableGenerateImage ? this.getGenerateImageParameter() : {}),
|
||||||
|
...this.getBudgetToken(assistant, model),
|
||||||
|
...this.getCustomParameters(assistant)
|
||||||
|
}
|
||||||
|
|
||||||
|
const param: GeminiSdkParams = {
|
||||||
|
model: model.id,
|
||||||
|
config: generateContentConfig,
|
||||||
|
history: newHistory,
|
||||||
|
message: newMessageContents.parts!
|
||||||
|
}
|
||||||
|
|
||||||
|
return {
|
||||||
|
payload: param,
|
||||||
|
messages: [messageContents],
|
||||||
|
metadata: {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
getResponseChunkTransformer(): ResponseChunkTransformer<GeminiSdkRawChunk> {
|
||||||
|
return () => ({
|
||||||
|
async transform(chunk: GeminiSdkRawChunk, controller: TransformStreamDefaultController<GenericChunk>) {
|
||||||
|
let toolCalls: FunctionCall[] = []
|
||||||
|
if (chunk.candidates && chunk.candidates.length > 0) {
|
||||||
|
for (const candidate of chunk.candidates) {
|
||||||
|
if (candidate.content) {
|
||||||
|
candidate.content.parts?.forEach((part) => {
|
||||||
|
const text = part.text || ''
|
||||||
|
if (part.thought) {
|
||||||
|
controller.enqueue({
|
||||||
|
type: ChunkType.THINKING_DELTA,
|
||||||
|
text: text
|
||||||
|
})
|
||||||
|
} else if (part.text) {
|
||||||
|
controller.enqueue({
|
||||||
|
type: ChunkType.TEXT_DELTA,
|
||||||
|
text: text
|
||||||
|
})
|
||||||
|
} else if (part.inlineData) {
|
||||||
|
controller.enqueue({
|
||||||
|
type: ChunkType.IMAGE_COMPLETE,
|
||||||
|
image: {
|
||||||
|
type: 'base64',
|
||||||
|
images: [
|
||||||
|
part.inlineData?.data?.startsWith('data:')
|
||||||
|
? part.inlineData?.data
|
||||||
|
: `data:${part.inlineData?.mimeType || 'image/png'};base64,${part.inlineData?.data}`
|
||||||
|
]
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
if (candidate.finishReason) {
|
||||||
|
if (candidate.groundingMetadata) {
|
||||||
|
controller.enqueue({
|
||||||
|
type: ChunkType.LLM_WEB_SEARCH_COMPLETE,
|
||||||
|
llm_web_search: {
|
||||||
|
results: candidate.groundingMetadata,
|
||||||
|
source: WebSearchSource.GEMINI
|
||||||
|
}
|
||||||
|
} as LLMWebSearchCompleteChunk)
|
||||||
|
}
|
||||||
|
if (chunk.functionCalls) {
|
||||||
|
toolCalls = toolCalls.concat(chunk.functionCalls)
|
||||||
|
}
|
||||||
|
controller.enqueue({
|
||||||
|
type: ChunkType.LLM_RESPONSE_COMPLETE,
|
||||||
|
response: {
|
||||||
|
usage: {
|
||||||
|
prompt_tokens: chunk.usageMetadata?.promptTokenCount || 0,
|
||||||
|
completion_tokens:
|
||||||
|
(chunk.usageMetadata?.totalTokenCount || 0) - (chunk.usageMetadata?.promptTokenCount || 0),
|
||||||
|
total_tokens: chunk.usageMetadata?.totalTokenCount || 0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (toolCalls.length > 0) {
|
||||||
|
controller.enqueue({
|
||||||
|
type: ChunkType.MCP_TOOL_CREATED,
|
||||||
|
tool_calls: toolCalls
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
public convertMcpToolsToSdkTools(mcpTools: MCPTool[]): Tool[] {
|
||||||
|
return mcpToolsToGeminiTools(mcpTools)
|
||||||
|
}
|
||||||
|
|
||||||
|
public convertSdkToolCallToMcp(toolCall: GeminiSdkToolCall, mcpTools: MCPTool[]): MCPTool | undefined {
|
||||||
|
return geminiFunctionCallToMcpTool(mcpTools, toolCall)
|
||||||
|
}
|
||||||
|
|
||||||
|
public convertSdkToolCallToMcpToolResponse(toolCall: GeminiSdkToolCall, mcpTool: MCPTool): ToolCallResponse {
|
||||||
|
const parsedArgs = (() => {
|
||||||
|
try {
|
||||||
|
return typeof toolCall.args === 'string' ? JSON.parse(toolCall.args) : toolCall.args
|
||||||
|
} catch {
|
||||||
|
return toolCall.args
|
||||||
|
}
|
||||||
|
})()
|
||||||
|
|
||||||
|
return {
|
||||||
|
id: toolCall.id || nanoid(),
|
||||||
|
toolCallId: toolCall.id,
|
||||||
|
tool: mcpTool,
|
||||||
|
arguments: parsedArgs,
|
||||||
|
status: 'pending'
|
||||||
|
} as ToolCallResponse
|
||||||
|
}
|
||||||
|
|
||||||
|
public convertMcpToolResponseToSdkMessageParam(
|
||||||
|
mcpToolResponse: MCPToolResponse,
|
||||||
|
resp: MCPCallToolResponse,
|
||||||
|
model: Model
|
||||||
|
): GeminiSdkMessageParam | undefined {
|
||||||
|
if ('toolUseId' in mcpToolResponse && mcpToolResponse.toolUseId) {
|
||||||
|
return mcpToolCallResponseToGeminiMessage(mcpToolResponse, resp, isVisionModel(model))
|
||||||
|
} else if ('toolCallId' in mcpToolResponse) {
|
||||||
|
return {
|
||||||
|
role: 'user',
|
||||||
|
parts: [
|
||||||
|
{
|
||||||
|
functionResponse: {
|
||||||
|
id: mcpToolResponse.toolCallId,
|
||||||
|
name: mcpToolResponse.tool.id,
|
||||||
|
response: {
|
||||||
|
output: !resp.isError ? resp.content : undefined,
|
||||||
|
error: resp.isError ? resp.content : undefined
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
} satisfies Content
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
public buildSdkMessages(
|
||||||
|
currentReqMessages: Content[],
|
||||||
|
output: string,
|
||||||
|
toolResults: Content[],
|
||||||
|
toolCalls: FunctionCall[]
|
||||||
|
): Content[] {
|
||||||
|
const parts: Part[] = []
|
||||||
|
if (output) {
|
||||||
|
parts.push({
|
||||||
|
text: output
|
||||||
|
})
|
||||||
|
}
|
||||||
|
toolCalls.forEach((toolCall) => {
|
||||||
|
parts.push({
|
||||||
|
functionCall: toolCall
|
||||||
|
})
|
||||||
|
})
|
||||||
|
parts.push(
|
||||||
|
...toolResults
|
||||||
|
.map((ts) => ts.parts)
|
||||||
|
.flat()
|
||||||
|
.filter((p) => p !== undefined)
|
||||||
|
)
|
||||||
|
|
||||||
|
const userMessage: Content = {
|
||||||
|
role: 'user',
|
||||||
|
parts: parts
|
||||||
|
}
|
||||||
|
|
||||||
|
return [...currentReqMessages, userMessage]
|
||||||
|
}
|
||||||
|
|
||||||
|
override estimateMessageTokens(message: GeminiSdkMessageParam): number {
|
||||||
|
return (
|
||||||
|
message.parts?.reduce((acc, part) => {
|
||||||
|
if (part.text) {
|
||||||
|
return acc + estimateTextTokens(part.text)
|
||||||
|
}
|
||||||
|
if (part.functionCall) {
|
||||||
|
return acc + estimateTextTokens(JSON.stringify(part.functionCall))
|
||||||
|
}
|
||||||
|
if (part.functionResponse) {
|
||||||
|
return acc + estimateTextTokens(JSON.stringify(part.functionResponse.response))
|
||||||
|
}
|
||||||
|
if (part.inlineData) {
|
||||||
|
return acc + estimateTextTokens(part.inlineData.data || '')
|
||||||
|
}
|
||||||
|
if (part.fileData) {
|
||||||
|
return acc + estimateTextTokens(part.fileData.fileUri || '')
|
||||||
|
}
|
||||||
|
return acc
|
||||||
|
}, 0) || 0
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
public extractMessagesFromSdkPayload(sdkPayload: GeminiSdkParams): GeminiSdkMessageParam[] {
|
||||||
|
return sdkPayload.history || []
|
||||||
|
}
|
||||||
|
|
||||||
|
private async uploadFile(file: FileType): Promise<File> {
|
||||||
|
return await this.sdkInstance!.files.upload({
|
||||||
|
file: file.path,
|
||||||
|
config: {
|
||||||
|
mimeType: 'application/pdf',
|
||||||
|
name: file.id,
|
||||||
|
displayName: file.origin_name
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
private async base64File(file: FileType) {
|
||||||
|
const { data } = await window.api.file.base64File(file.id + file.ext)
|
||||||
|
return {
|
||||||
|
data,
|
||||||
|
mimeType: 'application/pdf'
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private async retrieveFile(file: FileType): Promise<File | undefined> {
|
||||||
|
const cachedResponse = CacheService.get<any>('gemini_file_list')
|
||||||
|
|
||||||
|
if (cachedResponse) {
|
||||||
|
return this.processResponse(cachedResponse, file)
|
||||||
|
}
|
||||||
|
|
||||||
|
const response = await this.sdkInstance!.files.list()
|
||||||
|
CacheService.set('gemini_file_list', response, 3000)
|
||||||
|
|
||||||
|
return this.processResponse(response, file)
|
||||||
|
}
|
||||||
|
|
||||||
|
private async processResponse(response: Pager<File>, file: FileType) {
|
||||||
|
for await (const f of response) {
|
||||||
|
if (f.state === FileState.ACTIVE) {
|
||||||
|
if (f.displayName === file.origin_name && Number(f.sizeBytes) === file.size) {
|
||||||
|
return f
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return undefined
|
||||||
|
}
|
||||||
|
|
||||||
|
// @ts-ignore unused
|
||||||
|
private async listFiles(): Promise<File[]> {
|
||||||
|
const files: File[] = []
|
||||||
|
for await (const f of await this.sdkInstance!.files.list()) {
|
||||||
|
files.push(f)
|
||||||
|
}
|
||||||
|
return files
|
||||||
|
}
|
||||||
|
|
||||||
|
// @ts-ignore unused
|
||||||
|
private async deleteFile(fileId: string) {
|
||||||
|
await this.sdkInstance!.files.delete({ name: fileId })
|
||||||
|
}
|
||||||
|
}
|
||||||
95
src/renderer/src/aiCore/clients/gemini/VertexAPIClient.ts
Normal file
@@ -0,0 +1,95 @@
|
|||||||
|
import { GoogleGenAI } from '@google/genai'
|
||||||
|
import { getVertexAILocation, getVertexAIProjectId, getVertexAIServiceAccount } from '@renderer/hooks/useVertexAI'
|
||||||
|
import { Provider } from '@renderer/types'
|
||||||
|
|
||||||
|
import { GeminiAPIClient } from './GeminiAPIClient'
|
||||||
|
|
||||||
|
export class VertexAPIClient extends GeminiAPIClient {
|
||||||
|
private authHeaders?: Record<string, string>
|
||||||
|
private authHeadersExpiry?: number
|
||||||
|
|
||||||
|
constructor(provider: Provider) {
|
||||||
|
super(provider)
|
||||||
|
}
|
||||||
|
|
||||||
|
override async getSdkInstance() {
|
||||||
|
if (this.sdkInstance) {
|
||||||
|
return this.sdkInstance
|
||||||
|
}
|
||||||
|
|
||||||
|
const serviceAccount = getVertexAIServiceAccount()
|
||||||
|
const projectId = getVertexAIProjectId()
|
||||||
|
const location = getVertexAILocation()
|
||||||
|
|
||||||
|
if (!serviceAccount.privateKey || !serviceAccount.clientEmail || !projectId || !location) {
|
||||||
|
throw new Error('Vertex AI settings are not configured')
|
||||||
|
}
|
||||||
|
|
||||||
|
const authHeaders = await this.getServiceAccountAuthHeaders()
|
||||||
|
|
||||||
|
this.sdkInstance = new GoogleGenAI({
|
||||||
|
vertexai: true,
|
||||||
|
project: projectId,
|
||||||
|
location: location,
|
||||||
|
httpOptions: {
|
||||||
|
apiVersion: this.getApiVersion(),
|
||||||
|
headers: authHeaders
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
return this.sdkInstance
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 获取认证头,如果配置了 service account 则从主进程获取
|
||||||
|
*/
|
||||||
|
private async getServiceAccountAuthHeaders(): Promise<Record<string, string> | undefined> {
|
||||||
|
const serviceAccount = getVertexAIServiceAccount()
|
||||||
|
const projectId = getVertexAIProjectId()
|
||||||
|
|
||||||
|
// 检查是否配置了 service account
|
||||||
|
if (!serviceAccount.privateKey || !serviceAccount.clientEmail || !projectId) {
|
||||||
|
return undefined
|
||||||
|
}
|
||||||
|
|
||||||
|
// 检查是否已有有效的认证头(提前 5 分钟过期)
|
||||||
|
const now = Date.now()
|
||||||
|
if (this.authHeaders && this.authHeadersExpiry && this.authHeadersExpiry - now > 5 * 60 * 1000) {
|
||||||
|
return this.authHeaders
|
||||||
|
}
|
||||||
|
|
||||||
|
try {
|
||||||
|
// 从主进程获取认证头
|
||||||
|
this.authHeaders = await window.api.vertexAI.getAuthHeaders({
|
||||||
|
projectId,
|
||||||
|
serviceAccount: {
|
||||||
|
privateKey: serviceAccount.privateKey,
|
||||||
|
clientEmail: serviceAccount.clientEmail
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
// 设置过期时间(通常认证头有效期为 1 小时)
|
||||||
|
this.authHeadersExpiry = now + 60 * 60 * 1000
|
||||||
|
|
||||||
|
return this.authHeaders
|
||||||
|
} catch (error: any) {
|
||||||
|
console.error('Failed to get auth headers:', error)
|
||||||
|
throw new Error(`Service Account authentication failed: ${error.message}`)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 清理认证缓存并重新初始化
|
||||||
|
*/
|
||||||
|
clearAuthCache(): void {
|
||||||
|
this.authHeaders = undefined
|
||||||
|
this.authHeadersExpiry = undefined
|
||||||
|
|
||||||
|
const serviceAccount = getVertexAIServiceAccount()
|
||||||
|
const projectId = getVertexAIProjectId()
|
||||||
|
|
||||||
|
if (projectId && serviceAccount.clientEmail) {
|
||||||
|
window.api.vertexAI.clearAuthCache(projectId, serviceAccount.clientEmail)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
6
src/renderer/src/aiCore/clients/index.ts
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
export * from './ApiClientFactory'
|
||||||
|
export * from './BaseApiClient'
|
||||||
|
export * from './types'
|
||||||
|
|
||||||
|
// Export specific clients from subdirectories
|
||||||
|
export * from './openai/OpenAIApiClient'
|
||||||
682
src/renderer/src/aiCore/clients/openai/OpenAIApiClient.ts
Normal file
@@ -0,0 +1,682 @@
|
|||||||
|
import { DEFAULT_MAX_TOKENS } from '@renderer/config/constant'
|
||||||
|
import Logger from '@renderer/config/logger'
|
||||||
|
import {
|
||||||
|
findTokenLimit,
|
||||||
|
GEMINI_FLASH_MODEL_REGEX,
|
||||||
|
getOpenAIWebSearchParams,
|
||||||
|
isDoubaoThinkingAutoModel,
|
||||||
|
isReasoningModel,
|
||||||
|
isSupportedReasoningEffortGrokModel,
|
||||||
|
isSupportedReasoningEffortModel,
|
||||||
|
isSupportedReasoningEffortOpenAIModel,
|
||||||
|
isSupportedThinkingTokenClaudeModel,
|
||||||
|
isSupportedThinkingTokenDoubaoModel,
|
||||||
|
isSupportedThinkingTokenGeminiModel,
|
||||||
|
isSupportedThinkingTokenModel,
|
||||||
|
isSupportedThinkingTokenQwenModel,
|
||||||
|
isVisionModel
|
||||||
|
} from '@renderer/config/models'
|
||||||
|
import { processPostsuffixQwen3Model, processReqMessages } from '@renderer/services/ModelMessageService'
|
||||||
|
import { estimateTextTokens } from '@renderer/services/TokenService'
|
||||||
|
// For Copilot token
|
||||||
|
import {
|
||||||
|
Assistant,
|
||||||
|
EFFORT_RATIO,
|
||||||
|
FileTypes,
|
||||||
|
MCPCallToolResponse,
|
||||||
|
MCPTool,
|
||||||
|
MCPToolResponse,
|
||||||
|
Model,
|
||||||
|
Provider,
|
||||||
|
ToolCallResponse,
|
||||||
|
WebSearchSource
|
||||||
|
} from '@renderer/types'
|
||||||
|
import { ChunkType } from '@renderer/types/chunk'
|
||||||
|
import { Message } from '@renderer/types/newMessage'
|
||||||
|
import {
|
||||||
|
OpenAISdkMessageParam,
|
||||||
|
OpenAISdkParams,
|
||||||
|
OpenAISdkRawChunk,
|
||||||
|
OpenAISdkRawContentSource,
|
||||||
|
OpenAISdkRawOutput,
|
||||||
|
ReasoningEffortOptionalParams
|
||||||
|
} from '@renderer/types/sdk'
|
||||||
|
import { addImageFileToContents } from '@renderer/utils/formats'
|
||||||
|
import {
|
||||||
|
isEnabledToolUse,
|
||||||
|
mcpToolCallResponseToOpenAICompatibleMessage,
|
||||||
|
mcpToolsToOpenAIChatTools,
|
||||||
|
openAIToolsToMcpTool
|
||||||
|
} from '@renderer/utils/mcp-tools'
|
||||||
|
import { findFileBlocks, findImageBlocks } from '@renderer/utils/messageUtils/find'
|
||||||
|
import { buildSystemPrompt } from '@renderer/utils/prompt'
|
||||||
|
import OpenAI, { AzureOpenAI } from 'openai'
|
||||||
|
import { ChatCompletionContentPart, ChatCompletionContentPartRefusal, ChatCompletionTool } from 'openai/resources'
|
||||||
|
|
||||||
|
import { GenericChunk } from '../../middleware/schemas'
|
||||||
|
import { RequestTransformer, ResponseChunkTransformer, ResponseChunkTransformerContext } from '../types'
|
||||||
|
import { OpenAIBaseClient } from './OpenAIBaseClient'
|
||||||
|
|
||||||
|
export class OpenAIAPIClient extends OpenAIBaseClient<
|
||||||
|
OpenAI | AzureOpenAI,
|
||||||
|
OpenAISdkParams,
|
||||||
|
OpenAISdkRawOutput,
|
||||||
|
OpenAISdkRawChunk,
|
||||||
|
OpenAISdkMessageParam,
|
||||||
|
OpenAI.Chat.Completions.ChatCompletionMessageToolCall,
|
||||||
|
ChatCompletionTool
|
||||||
|
> {
|
||||||
|
constructor(provider: Provider) {
|
||||||
|
super(provider)
|
||||||
|
}
|
||||||
|
|
||||||
|
override async createCompletions(
|
||||||
|
payload: OpenAISdkParams,
|
||||||
|
options?: OpenAI.RequestOptions
|
||||||
|
): Promise<OpenAISdkRawOutput> {
|
||||||
|
const sdk = await this.getSdkInstance()
|
||||||
|
// @ts-ignore - SDK参数可能有额外的字段
|
||||||
|
return await sdk.chat.completions.create(payload, options)
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get the reasoning effort for the assistant
|
||||||
|
* @param assistant - The assistant
|
||||||
|
* @param model - The model
|
||||||
|
* @returns The reasoning effort
|
||||||
|
*/
|
||||||
|
// Method for reasoning effort, moved from OpenAIProvider
|
||||||
|
override getReasoningEffort(assistant: Assistant, model: Model): ReasoningEffortOptionalParams {
|
||||||
|
if (this.provider.id === 'groq') {
|
||||||
|
return {}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!isReasoningModel(model)) {
|
||||||
|
return {}
|
||||||
|
}
|
||||||
|
const reasoningEffort = assistant?.settings?.reasoning_effort
|
||||||
|
|
||||||
|
// Doubao 思考模式支持
|
||||||
|
if (isSupportedThinkingTokenDoubaoModel(model)) {
|
||||||
|
// reasoningEffort 为空,默认开启 enabled
|
||||||
|
if (!reasoningEffort) {
|
||||||
|
return { thinking: { type: 'disabled' } }
|
||||||
|
}
|
||||||
|
if (reasoningEffort === 'high') {
|
||||||
|
return { thinking: { type: 'enabled' } }
|
||||||
|
}
|
||||||
|
if (reasoningEffort === 'auto' && isDoubaoThinkingAutoModel(model)) {
|
||||||
|
return { thinking: { type: 'auto' } }
|
||||||
|
}
|
||||||
|
// 其他情况不带 thinking 字段
|
||||||
|
return {}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!reasoningEffort) {
|
||||||
|
if (isSupportedThinkingTokenQwenModel(model)) {
|
||||||
|
return { enable_thinking: false }
|
||||||
|
}
|
||||||
|
|
||||||
|
if (isSupportedThinkingTokenClaudeModel(model)) {
|
||||||
|
return {}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (isSupportedThinkingTokenGeminiModel(model)) {
|
||||||
|
// openrouter没有提供一个不推理的选项,先隐藏
|
||||||
|
if (this.provider.id === 'openrouter') {
|
||||||
|
return { reasoning: { max_tokens: 0, exclude: true } }
|
||||||
|
}
|
||||||
|
if (GEMINI_FLASH_MODEL_REGEX.test(model.id)) {
|
||||||
|
return { reasoning_effort: 'none' }
|
||||||
|
}
|
||||||
|
return {}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (isSupportedThinkingTokenDoubaoModel(model)) {
|
||||||
|
return { thinking: { type: 'disabled' } }
|
||||||
|
}
|
||||||
|
|
||||||
|
return {}
|
||||||
|
}
|
||||||
|
const effortRatio = EFFORT_RATIO[reasoningEffort]
|
||||||
|
const budgetTokens = Math.floor(
|
||||||
|
(findTokenLimit(model.id)?.max! - findTokenLimit(model.id)?.min!) * effortRatio + findTokenLimit(model.id)?.min!
|
||||||
|
)
|
||||||
|
|
||||||
|
// OpenRouter models
|
||||||
|
if (model.provider === 'openrouter') {
|
||||||
|
if (isSupportedReasoningEffortModel(model) || isSupportedThinkingTokenModel(model)) {
|
||||||
|
return {
|
||||||
|
reasoning: {
|
||||||
|
effort: reasoningEffort === 'auto' ? 'medium' : reasoningEffort
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Qwen models
|
||||||
|
if (isSupportedThinkingTokenQwenModel(model)) {
|
||||||
|
return {
|
||||||
|
enable_thinking: true,
|
||||||
|
thinking_budget: budgetTokens
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Grok models
|
||||||
|
if (isSupportedReasoningEffortGrokModel(model)) {
|
||||||
|
return {
|
||||||
|
reasoning_effort: reasoningEffort
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// OpenAI models
|
||||||
|
if (isSupportedReasoningEffortOpenAIModel(model) || isSupportedThinkingTokenGeminiModel(model)) {
|
||||||
|
return {
|
||||||
|
reasoning_effort: reasoningEffort
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Claude models
|
||||||
|
if (isSupportedThinkingTokenClaudeModel(model)) {
|
||||||
|
const maxTokens = assistant.settings?.maxTokens
|
||||||
|
return {
|
||||||
|
thinking: {
|
||||||
|
type: 'enabled',
|
||||||
|
budget_tokens: Math.floor(
|
||||||
|
Math.max(1024, Math.min(budgetTokens, (maxTokens || DEFAULT_MAX_TOKENS) * effortRatio))
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Doubao models
|
||||||
|
if (isSupportedThinkingTokenDoubaoModel(model)) {
|
||||||
|
if (assistant.settings?.reasoning_effort === 'high') {
|
||||||
|
return {
|
||||||
|
thinking: {
|
||||||
|
type: 'enabled'
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Default case: no special thinking settings
|
||||||
|
return {}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Check if the provider does not support files
|
||||||
|
* @returns True if the provider does not support files, false otherwise
|
||||||
|
*/
|
||||||
|
private get isNotSupportFiles() {
|
||||||
|
if (this.provider?.isNotSupportArrayContent) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
const providers = ['deepseek', 'baichuan', 'minimax', 'xirang']
|
||||||
|
|
||||||
|
return providers.includes(this.provider.id)
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get the message parameter
|
||||||
|
* @param message - The message
|
||||||
|
* @param model - The model
|
||||||
|
* @returns The message parameter
|
||||||
|
*/
|
||||||
|
public async convertMessageToSdkParam(message: Message, model: Model): Promise<OpenAISdkMessageParam> {
|
||||||
|
const isVision = isVisionModel(model)
|
||||||
|
const content = await this.getMessageContent(message)
|
||||||
|
const fileBlocks = findFileBlocks(message)
|
||||||
|
const imageBlocks = findImageBlocks(message)
|
||||||
|
|
||||||
|
if (fileBlocks.length === 0 && imageBlocks.length === 0) {
|
||||||
|
return {
|
||||||
|
role: message.role === 'system' ? 'user' : message.role,
|
||||||
|
content
|
||||||
|
} as OpenAISdkMessageParam
|
||||||
|
}
|
||||||
|
|
||||||
|
// If the model does not support files, extract the file content
|
||||||
|
if (this.isNotSupportFiles) {
|
||||||
|
const fileContent = await this.extractFileContent(message)
|
||||||
|
|
||||||
|
return {
|
||||||
|
role: message.role === 'system' ? 'user' : message.role,
|
||||||
|
content: content + '\n\n---\n\n' + fileContent
|
||||||
|
} as OpenAISdkMessageParam
|
||||||
|
}
|
||||||
|
|
||||||
|
// If the model supports files, add the file content to the message
|
||||||
|
const parts: ChatCompletionContentPart[] = []
|
||||||
|
|
||||||
|
if (content) {
|
||||||
|
parts.push({ type: 'text', text: content })
|
||||||
|
}
|
||||||
|
|
||||||
|
for (const imageBlock of imageBlocks) {
|
||||||
|
if (isVision) {
|
||||||
|
if (imageBlock.file) {
|
||||||
|
const image = await window.api.file.base64Image(imageBlock.file.id + imageBlock.file.ext)
|
||||||
|
parts.push({ type: 'image_url', image_url: { url: image.data } })
|
||||||
|
} else if (imageBlock.url && imageBlock.url.startsWith('data:')) {
|
||||||
|
parts.push({ type: 'image_url', image_url: { url: imageBlock.url } })
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for (const fileBlock of fileBlocks) {
|
||||||
|
const file = fileBlock.file
|
||||||
|
if (!file) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if ([FileTypes.TEXT, FileTypes.DOCUMENT].includes(file.type)) {
|
||||||
|
const fileContent = await (await window.api.file.read(file.id + file.ext)).trim()
|
||||||
|
parts.push({
|
||||||
|
type: 'text',
|
||||||
|
text: file.origin_name + '\n' + fileContent
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return {
|
||||||
|
role: message.role === 'system' ? 'user' : message.role,
|
||||||
|
content: parts
|
||||||
|
} as OpenAISdkMessageParam
|
||||||
|
}
|
||||||
|
|
||||||
|
public convertMcpToolsToSdkTools(mcpTools: MCPTool[]): ChatCompletionTool[] {
|
||||||
|
return mcpToolsToOpenAIChatTools(mcpTools)
|
||||||
|
}
|
||||||
|
|
||||||
|
public convertSdkToolCallToMcp(
|
||||||
|
toolCall: OpenAI.Chat.Completions.ChatCompletionMessageToolCall,
|
||||||
|
mcpTools: MCPTool[]
|
||||||
|
): MCPTool | undefined {
|
||||||
|
return openAIToolsToMcpTool(mcpTools, toolCall)
|
||||||
|
}
|
||||||
|
|
||||||
|
public convertSdkToolCallToMcpToolResponse(
|
||||||
|
toolCall: OpenAI.Chat.Completions.ChatCompletionMessageToolCall,
|
||||||
|
mcpTool: MCPTool
|
||||||
|
): ToolCallResponse {
|
||||||
|
let parsedArgs: any
|
||||||
|
try {
|
||||||
|
parsedArgs = JSON.parse(toolCall.function.arguments)
|
||||||
|
} catch {
|
||||||
|
parsedArgs = toolCall.function.arguments
|
||||||
|
}
|
||||||
|
return {
|
||||||
|
      id: toolCall.id,
      toolCallId: toolCall.id,
      tool: mcpTool,
      arguments: parsedArgs,
      status: 'pending'
    } as ToolCallResponse
  }

  public convertMcpToolResponseToSdkMessageParam(
    mcpToolResponse: MCPToolResponse,
    resp: MCPCallToolResponse,
    model: Model
  ): OpenAISdkMessageParam | undefined {
    if ('toolUseId' in mcpToolResponse && mcpToolResponse.toolUseId) {
      // This case is for Anthropic/Claude-like tool usage; OpenAI uses tool_call_id.
      // For OpenAI we primarily expect toolCallId. This may need adjustment if provider concepts are mixed.
      return mcpToolCallResponseToOpenAICompatibleMessage(mcpToolResponse, resp, isVisionModel(model))
    } else if ('toolCallId' in mcpToolResponse && mcpToolResponse.toolCallId) {
      return {
        role: 'tool',
        tool_call_id: mcpToolResponse.toolCallId,
        content: JSON.stringify(resp.content)
      } as OpenAI.Chat.Completions.ChatCompletionToolMessageParam
    }
    return undefined
  }

  public buildSdkMessages(
    currentReqMessages: OpenAISdkMessageParam[],
    output: string,
    toolResults: OpenAISdkMessageParam[],
    toolCalls: OpenAI.Chat.Completions.ChatCompletionMessageToolCall[]
  ): OpenAISdkMessageParam[] {
    const assistantMessage: OpenAISdkMessageParam = {
      role: 'assistant',
      content: output,
      tool_calls: toolCalls.length > 0 ? toolCalls : undefined
    }
    const newReqMessages = [...currentReqMessages, assistantMessage, ...toolResults]
    return newReqMessages
  }

  override estimateMessageTokens(message: OpenAISdkMessageParam): number {
    let sum = 0
    if (typeof message.content === 'string') {
      sum += estimateTextTokens(message.content)
    } else if (Array.isArray(message.content)) {
      sum += (message.content || [])
        .map((part: ChatCompletionContentPart | ChatCompletionContentPartRefusal) => {
          switch (part.type) {
            case 'text':
              return estimateTextTokens(part.text)
            case 'image_url':
              return estimateTextTokens(part.image_url.url)
            case 'input_audio':
              return estimateTextTokens(part.input_audio.data)
            case 'file':
              return estimateTextTokens(part.file.file_data || '')
            default:
              return 0
          }
        })
        .reduce((acc, curr) => acc + curr, 0)
    }
    if ('tool_calls' in message && message.tool_calls) {
      sum += message.tool_calls.reduce((acc, toolCall) => {
        return acc + estimateTextTokens(JSON.stringify(toolCall.function.arguments))
      }, 0)
    }
    return sum
  }

  public extractMessagesFromSdkPayload(sdkPayload: OpenAISdkParams): OpenAISdkMessageParam[] {
    return sdkPayload.messages || []
  }

  getRequestTransformer(): RequestTransformer<OpenAISdkParams, OpenAISdkMessageParam> {
    return {
      transform: async (
        coreRequest,
        assistant,
        model,
        isRecursiveCall,
        recursiveSdkMessages
      ): Promise<{
        payload: OpenAISdkParams
        messages: OpenAISdkMessageParam[]
        metadata: Record<string, any>
      }> => {
        const { messages, mcpTools, maxTokens, streamOutput, enableWebSearch } = coreRequest
        // 1. Handle the system message
        let systemMessage = { role: 'system', content: assistant.prompt || '' }

        if (isSupportedReasoningEffortOpenAIModel(model)) {
          systemMessage = {
            role: 'developer',
            content: `Formatting re-enabled${systemMessage ? '\n' + systemMessage.content : ''}`
          }
        }

        if (model.id.includes('o1-mini') || model.id.includes('o1-preview')) {
          systemMessage.role = 'assistant'
        }

        // 2. Set up tools (must run before the this.useSystemPromptForTools check below)
        const { tools } = this.setupToolsConfig({
          mcpTools: mcpTools,
          model,
          enableToolUse: isEnabledToolUse(assistant)
        })

        if (this.useSystemPromptForTools) {
          systemMessage.content = await buildSystemPrompt(systemMessage.content || '', mcpTools, assistant)
        }

        // 3. Handle user messages
        const userMessages: OpenAISdkMessageParam[] = []
        if (typeof messages === 'string') {
          userMessages.push({ role: 'user', content: messages })
        } else {
          const processedMessages = addImageFileToContents(messages)
          for (const message of processedMessages) {
            userMessages.push(await this.convertMessageToSdkParam(message, model))
          }
        }

        const lastUserMsg = userMessages.findLast((m) => m.role === 'user')
        if (lastUserMsg && isSupportedThinkingTokenQwenModel(model)) {
          const postsuffix = '/no_think'
          const qwenThinkModeEnabled = assistant.settings?.qwenThinkMode === true
          const currentContent = lastUserMsg.content

          lastUserMsg.content = processPostsuffixQwen3Model(currentContent, postsuffix, qwenThinkModeEnabled) as any
        }

        // 4. Build the final request messages
        let reqMessages: OpenAISdkMessageParam[]
        if (!systemMessage.content) {
          reqMessages = [...userMessages]
        } else {
          reqMessages = [systemMessage, ...userMessages].filter(Boolean) as OpenAISdkMessageParam[]
        }

        reqMessages = processReqMessages(model, reqMessages)

        // 5. Create the common parameters
        const commonParams = {
          model: model.id,
          messages:
            isRecursiveCall && recursiveSdkMessages && recursiveSdkMessages.length > 0
              ? recursiveSdkMessages
              : reqMessages,
          temperature: this.getTemperature(assistant, model),
          top_p: this.getTopP(assistant, model),
          max_tokens: maxTokens,
          tools: tools.length > 0 ? tools : undefined,
          service_tier: this.getServiceTier(model),
          ...this.getProviderSpecificParameters(assistant, model),
          ...this.getReasoningEffort(assistant, model),
          ...getOpenAIWebSearchParams(model, enableWebSearch),
          ...this.getCustomParameters(assistant)
        }

        // Create the appropriate parameters object based on whether streaming is enabled
        const sdkParams: OpenAISdkParams = streamOutput
          ? {
              ...commonParams,
              stream: true
            }
          : {
              ...commonParams,
              stream: false
            }

        const timeout = this.getTimeout(model)

        return { payload: sdkParams, messages: reqMessages, metadata: { timeout } }
      }
    }
  }

  // Used by RawSdkChunkToGenericChunkMiddleware
  getResponseChunkTransformer = (): ResponseChunkTransformer<OpenAISdkRawChunk> => {
    let hasBeenCollectedWebSearch = false
    const collectWebSearchData = (
      chunk: OpenAISdkRawChunk,
      contentSource: OpenAISdkRawContentSource,
      context: ResponseChunkTransformerContext
    ) => {
      if (hasBeenCollectedWebSearch) {
        return
      }
      // OpenAI annotations
      // @ts-ignore - annotations may not be in standard type definitions
      const annotations = contentSource.annotations || chunk.annotations
      if (annotations && annotations.length > 0 && annotations[0].type === 'url_citation') {
        hasBeenCollectedWebSearch = true
        return {
          results: annotations,
          source: WebSearchSource.OPENAI
        }
      }

      // Grok citations
      // @ts-ignore - citations may not be in standard type definitions
      if (context.provider?.id === 'grok' && chunk.citations) {
        hasBeenCollectedWebSearch = true
        return {
          // @ts-ignore - citations may not be in standard type definitions
          results: chunk.citations,
          source: WebSearchSource.GROK
        }
      }

      // Perplexity citations
      // @ts-ignore - citations may not be in standard type definitions
      if (context.provider?.id === 'perplexity' && chunk.citations && chunk.citations.length > 0) {
        hasBeenCollectedWebSearch = true
        return {
          // @ts-ignore - citations may not be in standard type definitions
          results: chunk.citations,
          source: WebSearchSource.PERPLEXITY
        }
      }

      // OpenRouter citations
      // @ts-ignore - citations may not be in standard type definitions
      if (context.provider?.id === 'openrouter' && chunk.citations && chunk.citations.length > 0) {
        hasBeenCollectedWebSearch = true
        return {
          // @ts-ignore - citations may not be in standard type definitions
          results: chunk.citations,
          source: WebSearchSource.OPENROUTER
        }
      }

      // Zhipu web search
      // @ts-ignore - web_search may not be in standard type definitions
      if (context.provider?.id === 'zhipu' && chunk.web_search) {
        hasBeenCollectedWebSearch = true
        return {
          // @ts-ignore - web_search may not be in standard type definitions
          results: chunk.web_search,
          source: WebSearchSource.ZHIPU
        }
      }

      // Hunyuan web search
      // @ts-ignore - search_info may not be in standard type definitions
      if (context.provider?.id === 'hunyuan' && chunk.search_info?.search_results) {
        hasBeenCollectedWebSearch = true
        return {
          // @ts-ignore - search_info may not be in standard type definitions
          results: chunk.search_info.search_results,
          source: WebSearchSource.HUNYUAN
        }
      }

      // TODO: move this into AnthropicApiClient
      // // Other providers...
      // // @ts-ignore - web_search may not be in standard type definitions
      // if (chunk.web_search) {
      //   const sourceMap: Record<string, string> = {
      //     openai: 'openai',
      //     anthropic: 'anthropic',
      //     qwenlm: 'qwen'
      //   }
      //   const source = sourceMap[context.provider?.id] || 'openai_response'
      //   return {
      //     results: chunk.web_search,
      //     source: source as const
      //   }
      // }

      return null
    }
    const toolCalls: OpenAI.Chat.Completions.ChatCompletionMessageToolCall[] = []
    return (context: ResponseChunkTransformerContext) => ({
      async transform(chunk: OpenAISdkRawChunk, controller: TransformStreamDefaultController<GenericChunk>) {
        // Process the chunk
        if ('choices' in chunk && chunk.choices && chunk.choices.length > 0) {
          const choice = chunk.choices[0]

          if (!choice) return

          // For streaming responses use delta; for non-streaming responses use message
          const contentSource: OpenAISdkRawContentSource | null =
            'delta' in choice ? choice.delta : 'message' in choice ? choice.message : null

          if (!contentSource) return

          const webSearchData = collectWebSearchData(chunk, contentSource, context)
          if (webSearchData) {
            controller.enqueue({
              type: ChunkType.LLM_WEB_SEARCH_COMPLETE,
              llm_web_search: webSearchData
            })
          }

          // Handle reasoning content (e.g. from OpenRouter DeepSeek-R1)
          // @ts-ignore - reasoning_content is not in standard OpenAI types but some providers use it
          const reasoningText = contentSource.reasoning_content || contentSource.reasoning
          if (reasoningText) {
            controller.enqueue({
              type: ChunkType.THINKING_DELTA,
              text: reasoningText
            })
          }

          // Handle text content
          if (contentSource.content) {
            controller.enqueue({
              type: ChunkType.TEXT_DELTA,
              text: contentSource.content
            })
          }

          // Handle tool calls
          if (contentSource.tool_calls) {
            for (const toolCall of contentSource.tool_calls) {
              if ('index' in toolCall) {
                const { id, index, function: fun } = toolCall
                if (fun?.name) {
                  toolCalls[index] = {
                    id: id || '',
                    function: {
                      name: fun.name,
                      arguments: fun.arguments || ''
                    },
                    type: 'function'
                  }
                } else if (fun?.arguments) {
                  toolCalls[index].function.arguments += fun.arguments
                }
              } else {
                toolCalls.push(toolCall)
              }
            }
          }

          // Handle finish_reason and emit the end-of-stream signal
          if ('finish_reason' in choice && choice.finish_reason) {
            Logger.debug(`[OpenAIApiClient] Stream finished with reason: ${choice.finish_reason}`)
            if (toolCalls.length > 0) {
              controller.enqueue({
                type: ChunkType.MCP_TOOL_CREATED,
                tool_calls: toolCalls
              })
            }
            const webSearchData = collectWebSearchData(chunk, contentSource, context)
            if (webSearchData) {
              controller.enqueue({
                type: ChunkType.LLM_WEB_SEARCH_COMPLETE,
                llm_web_search: webSearchData
              })
            }
            controller.enqueue({
              type: ChunkType.LLM_RESPONSE_COMPLETE,
              response: {
                usage: {
                  prompt_tokens: chunk.usage?.prompt_tokens || 0,
                  completion_tokens: chunk.usage?.completion_tokens || 0,
                  total_tokens: (chunk.usage?.prompt_tokens || 0) + (chunk.usage?.completion_tokens || 0)
                }
              }
            })
          }
        }
      }
    })
  }
}
src/renderer/src/aiCore/clients/openai/OpenAIBaseClient.ts (new file, 258 lines)
@@ -0,0 +1,258 @@
import {
  isClaudeReasoningModel,
  isNotSupportTemperatureAndTopP,
  isOpenAIReasoningModel,
  isSupportedModel,
  isSupportedReasoningEffortOpenAIModel
} from '@renderer/config/models'
import { getStoreSetting } from '@renderer/hooks/useSettings'
import { getAssistantSettings } from '@renderer/services/AssistantService'
import store from '@renderer/store'
import { SettingsState } from '@renderer/store/settings'
import { Assistant, GenerateImageParams, Model, Provider } from '@renderer/types'
import {
  OpenAIResponseSdkMessageParam,
  OpenAIResponseSdkParams,
  OpenAIResponseSdkRawChunk,
  OpenAIResponseSdkRawOutput,
  OpenAIResponseSdkTool,
  OpenAIResponseSdkToolCall,
  OpenAISdkMessageParam,
  OpenAISdkParams,
  OpenAISdkRawChunk,
  OpenAISdkRawOutput,
  ReasoningEffortOptionalParams
} from '@renderer/types/sdk'
import { formatApiHost } from '@renderer/utils/api'
import OpenAI, { AzureOpenAI } from 'openai'

import { BaseApiClient } from '../BaseApiClient'

/**
 * Abstract OpenAI base client containing the functionality shared by the two OpenAI clients
 */
export abstract class OpenAIBaseClient<
  TSdkInstance extends OpenAI | AzureOpenAI,
  TSdkParams extends OpenAISdkParams | OpenAIResponseSdkParams,
  TRawOutput extends OpenAISdkRawOutput | OpenAIResponseSdkRawOutput,
  TRawChunk extends OpenAISdkRawChunk | OpenAIResponseSdkRawChunk,
  TMessageParam extends OpenAISdkMessageParam | OpenAIResponseSdkMessageParam,
  TToolCall extends OpenAI.Chat.Completions.ChatCompletionMessageToolCall | OpenAIResponseSdkToolCall,
  TSdkSpecificTool extends OpenAI.Chat.Completions.ChatCompletionTool | OpenAIResponseSdkTool
> extends BaseApiClient<TSdkInstance, TSdkParams, TRawOutput, TRawChunk, TMessageParam, TToolCall, TSdkSpecificTool> {
  constructor(provider: Provider) {
    super(provider)
  }

  // Only applies to OpenAI
  override getBaseURL(): string {
    const host = this.provider.apiHost
    return formatApiHost(host)
  }

  override async generateImage({
    model,
    prompt,
    negativePrompt,
    imageSize,
    batchSize,
    seed,
    numInferenceSteps,
    guidanceScale,
    signal,
    promptEnhancement
  }: GenerateImageParams): Promise<string[]> {
    const sdk = await this.getSdkInstance()
    const response = (await sdk.request({
      method: 'post',
      path: '/images/generations',
      signal,
      body: {
        model,
        prompt,
        negative_prompt: negativePrompt,
        image_size: imageSize,
        batch_size: batchSize,
        seed: seed ? parseInt(seed) : undefined,
        num_inference_steps: numInferenceSteps,
        guidance_scale: guidanceScale,
        prompt_enhancement: promptEnhancement
      }
    })) as { data: Array<{ url: string }> }

    return response.data.map((item) => item.url)
  }

  override async getEmbeddingDimensions(model: Model): Promise<number> {
    const sdk = await this.getSdkInstance()
    try {
      const data = await sdk.embeddings.create({
        model: model.id,
        input: model?.provider === 'baidu-cloud' ? ['hi'] : 'hi',
        encoding_format: 'float'
      })
      return data.data[0].embedding.length
    } catch (e) {
      return 0
    }
  }

  override async listModels(): Promise<OpenAI.Models.Model[]> {
    try {
      const sdk = await this.getSdkInstance()
      const response = await sdk.models.list()
      if (this.provider.id === 'github') {
        // @ts-ignore key is not typed
        return response?.body
          .map((model) => ({
            id: model.name,
            description: model.summary,
            object: 'model',
            owned_by: model.publisher
          }))
          .filter(isSupportedModel)
      }
      if (this.provider.id === 'together') {
        // @ts-ignore key is not typed
        return response?.body.map((model) => ({
          id: model.id,
          description: model.display_name,
          object: 'model',
          owned_by: model.organization
        }))
      }
      const models = response.data || []
      models.forEach((model) => {
        model.id = model.id.trim()
      })

      return models.filter(isSupportedModel)
    } catch (error) {
      console.error('Error listing models:', error)
      return []
    }
  }

  override async getSdkInstance() {
    if (this.sdkInstance) {
      return this.sdkInstance
    }

    let apiKeyForSdkInstance = this.provider.apiKey

    if (this.provider.id === 'copilot') {
      const defaultHeaders = store.getState().copilot.defaultHeaders
      const { token } = await window.api.copilot.getToken(defaultHeaders)
      // this.provider.apiKey must not be modified
      // this.provider.apiKey = token
      apiKeyForSdkInstance = token
    }

    if (this.provider.id === 'azure-openai' || this.provider.type === 'azure-openai') {
      this.sdkInstance = new AzureOpenAI({
        dangerouslyAllowBrowser: true,
        apiKey: apiKeyForSdkInstance,
        apiVersion: this.provider.apiVersion,
        endpoint: this.provider.apiHost
      }) as TSdkInstance
    } else {
      this.sdkInstance = new OpenAI({
        dangerouslyAllowBrowser: true,
        apiKey: apiKeyForSdkInstance,
        baseURL: this.getBaseURL(),
        defaultHeaders: {
          ...this.defaultHeaders(),
          ...(this.provider.id === 'copilot' ? { 'editor-version': 'vscode/1.97.2' } : {}),
          ...(this.provider.id === 'copilot' ? { 'copilot-vision-request': 'true' } : {})
        }
      }) as TSdkInstance
    }
    return this.sdkInstance
  }

  override getTemperature(assistant: Assistant, model: Model): number | undefined {
    if (
      isNotSupportTemperatureAndTopP(model) ||
      (assistant.settings?.reasoning_effort && isClaudeReasoningModel(model))
    ) {
      return undefined
    }
    return assistant.settings?.temperature
  }

  override getTopP(assistant: Assistant, model: Model): number | undefined {
    if (
      isNotSupportTemperatureAndTopP(model) ||
      (assistant.settings?.reasoning_effort && isClaudeReasoningModel(model))
    ) {
      return undefined
    }
    return assistant.settings?.topP
  }

  /**
   * Get the provider specific parameters for the assistant
   * @param assistant - The assistant
   * @param model - The model
   * @returns The provider specific parameters
   */
  protected getProviderSpecificParameters(assistant: Assistant, model: Model) {
    const { maxTokens } = getAssistantSettings(assistant)

    if (this.provider.id === 'openrouter') {
      if (model.id.includes('deepseek-r1')) {
        return {
          include_reasoning: true
        }
      }
    }

    if (isOpenAIReasoningModel(model)) {
      return {
        max_tokens: undefined,
        max_completion_tokens: maxTokens
      }
    }

    return {}
  }

  /**
   * Get the reasoning effort for the assistant
   * @param assistant - The assistant
   * @param model - The model
   * @returns The reasoning effort
   */
  protected getReasoningEffort(assistant: Assistant, model: Model): ReasoningEffortOptionalParams {
    if (!isSupportedReasoningEffortOpenAIModel(model)) {
      return {}
    }

    const openAI = getStoreSetting('openAI') as SettingsState['openAI']
    const summaryText = openAI?.summaryText || 'off'

    let summary: string | undefined = undefined

    if (summaryText === 'off' || model.id.includes('o1-pro')) {
      summary = undefined
    } else {
      summary = summaryText
    }

    const reasoningEffort = assistant?.settings?.reasoning_effort
    if (!reasoningEffort) {
      return {}
    }

    if (isSupportedReasoningEffortOpenAIModel(model)) {
      return {
        reasoning: {
          effort: reasoningEffort as OpenAI.ReasoningEffort,
          summary: summary
        } as OpenAI.Reasoning
      }
    }

    return {}
  }
}
@@ -0,0 +1,559 @@
import { GenericChunk } from '@renderer/aiCore/middleware/schemas'
import {
  isOpenAIChatCompletionOnlyModel,
  isSupportedReasoningEffortOpenAIModel,
  isVisionModel
} from '@renderer/config/models'
import { estimateTextTokens } from '@renderer/services/TokenService'
import {
  FileType,
  FileTypes,
  MCPCallToolResponse,
  MCPTool,
  MCPToolResponse,
  Model,
  Provider,
  ToolCallResponse,
  WebSearchSource
} from '@renderer/types'
import { ChunkType } from '@renderer/types/chunk'
import { Message } from '@renderer/types/newMessage'
import {
  OpenAIResponseSdkMessageParam,
  OpenAIResponseSdkParams,
  OpenAIResponseSdkRawChunk,
  OpenAIResponseSdkRawOutput,
  OpenAIResponseSdkTool,
  OpenAIResponseSdkToolCall
} from '@renderer/types/sdk'
import { addImageFileToContents } from '@renderer/utils/formats'
import {
  isEnabledToolUse,
  mcpToolCallResponseToOpenAIMessage,
  mcpToolsToOpenAIResponseTools,
  openAIToolsToMcpTool
} from '@renderer/utils/mcp-tools'
import { findFileBlocks, findImageBlocks } from '@renderer/utils/messageUtils/find'
import { buildSystemPrompt } from '@renderer/utils/prompt'
import { MB } from '@shared/config/constant'
import { isEmpty } from 'lodash'
import OpenAI from 'openai'

import { RequestTransformer, ResponseChunkTransformer } from '../types'
import { OpenAIAPIClient } from './OpenAIApiClient'
import { OpenAIBaseClient } from './OpenAIBaseClient'

export class OpenAIResponseAPIClient extends OpenAIBaseClient<
  OpenAI,
  OpenAIResponseSdkParams,
  OpenAIResponseSdkRawOutput,
  OpenAIResponseSdkRawChunk,
  OpenAIResponseSdkMessageParam,
  OpenAIResponseSdkToolCall,
  OpenAIResponseSdkTool
> {
  private client: OpenAIAPIClient
  constructor(provider: Provider) {
    super(provider)
    this.client = new OpenAIAPIClient(provider)
  }

  /**
   * Select the appropriate client based on the model's characteristics
   */
  public getClient(model: Model) {
    if (isOpenAIChatCompletionOnlyModel(model)) {
      return this.client
    } else {
      return this
    }
  }

  override async getSdkInstance() {
    if (this.sdkInstance) {
      return this.sdkInstance
    }

    return new OpenAI({
      dangerouslyAllowBrowser: true,
      apiKey: this.provider.apiKey,
      baseURL: this.getBaseURL(),
      defaultHeaders: {
        ...this.defaultHeaders()
      }
    })
  }

  override async createCompletions(
    payload: OpenAIResponseSdkParams,
    options?: OpenAI.RequestOptions
  ): Promise<OpenAIResponseSdkRawOutput> {
    const sdk = await this.getSdkInstance()
    return await sdk.responses.create(payload, options)
  }

  private async handlePdfFile(file: FileType): Promise<OpenAI.Responses.ResponseInputFile | undefined> {
    if (file.size > 32 * MB) return undefined
    try {
      const pageCount = await window.api.file.pdfInfo(file.id + file.ext)
      if (pageCount > 100) return undefined
    } catch {
      return undefined
    }

    const { data } = await window.api.file.base64File(file.id + file.ext)
    return {
      type: 'input_file',
      filename: file.origin_name,
      file_data: `data:application/pdf;base64,${data}`
    } as OpenAI.Responses.ResponseInputFile
  }

  public async convertMessageToSdkParam(message: Message, model: Model): Promise<OpenAIResponseSdkMessageParam> {
    const isVision = isVisionModel(model)
    const content = await this.getMessageContent(message)
    const fileBlocks = findFileBlocks(message)
    const imageBlocks = findImageBlocks(message)

    if (fileBlocks.length === 0 && imageBlocks.length === 0) {
      if (message.role === 'assistant') {
        return {
          role: 'assistant',
          content: content
        }
      } else {
        return {
          role: message.role === 'system' ? 'user' : message.role,
          content: content ? [{ type: 'input_text', text: content }] : []
        } as OpenAI.Responses.EasyInputMessage
      }
    }

    const parts: OpenAI.Responses.ResponseInputContent[] = []
    if (content) {
      parts.push({
        type: 'input_text',
        text: content
      })
    }

    for (const imageBlock of imageBlocks) {
      if (isVision) {
        if (imageBlock.file) {
          const image = await window.api.file.base64Image(imageBlock.file.id + imageBlock.file.ext)
          parts.push({
            detail: 'auto',
            type: 'input_image',
            image_url: image.data as string
          })
        } else if (imageBlock.url && imageBlock.url.startsWith('data:')) {
          parts.push({
            detail: 'auto',
            type: 'input_image',
            image_url: imageBlock.url
          })
        }
      }
    }

    for (const fileBlock of fileBlocks) {
      const file = fileBlock.file
      if (!file) continue

      if (isVision && file.ext === '.pdf') {
        const pdfPart = await this.handlePdfFile(file)
        if (pdfPart) {
          parts.push(pdfPart)
          continue
        }
      }

      if ([FileTypes.TEXT, FileTypes.DOCUMENT].includes(file.type)) {
        const fileContent = (await window.api.file.read(file.id + file.ext)).trim()
        parts.push({
          type: 'input_text',
          text: file.origin_name + '\n' + fileContent
        })
      }
    }

    return {
      role: message.role === 'system' ? 'user' : message.role,
      content: parts
    }
  }

  public convertMcpToolsToSdkTools(mcpTools: MCPTool[]): OpenAI.Responses.Tool[] {
    return mcpToolsToOpenAIResponseTools(mcpTools)
  }

  public convertSdkToolCallToMcp(toolCall: OpenAIResponseSdkToolCall, mcpTools: MCPTool[]): MCPTool | undefined {
    return openAIToolsToMcpTool(mcpTools, toolCall)
  }

  public convertSdkToolCallToMcpToolResponse(toolCall: OpenAIResponseSdkToolCall, mcpTool: MCPTool): ToolCallResponse {
    const parsedArgs = (() => {
      try {
        return JSON.parse(toolCall.arguments)
      } catch {
        return toolCall.arguments
      }
    })()

    return {
      id: toolCall.call_id,
      toolCallId: toolCall.call_id,
      tool: mcpTool,
      arguments: parsedArgs,
      status: 'pending'
    }
  }

  public convertMcpToolResponseToSdkMessageParam(
    mcpToolResponse: MCPToolResponse,
    resp: MCPCallToolResponse,
    model: Model
  ): OpenAIResponseSdkMessageParam | undefined {
    if ('toolUseId' in mcpToolResponse && mcpToolResponse.toolUseId) {
      return mcpToolCallResponseToOpenAIMessage(mcpToolResponse, resp, isVisionModel(model))
    } else if ('toolCallId' in mcpToolResponse && mcpToolResponse.toolCallId) {
      return {
        type: 'function_call_output',
        call_id: mcpToolResponse.toolCallId,
        output: JSON.stringify(resp.content)
      }
    }
    return
  }

  public buildSdkMessages(
    currentReqMessages: OpenAIResponseSdkMessageParam[],
    output: string,
    toolResults: OpenAIResponseSdkMessageParam[],
    toolCalls: OpenAIResponseSdkToolCall[]
  ): OpenAIResponseSdkMessageParam[] {
    const assistantMessage: OpenAIResponseSdkMessageParam = {
      role: 'assistant',
      content: [{ type: 'input_text', text: output }]
    }
    const newReqMessages = [...currentReqMessages, assistantMessage, ...(toolCalls || []), ...(toolResults || [])]
    return newReqMessages
  }

  override estimateMessageTokens(message: OpenAIResponseSdkMessageParam): number {
    let sum = 0
    if ('content' in message) {
      if (typeof message.content === 'string') {
        sum += estimateTextTokens(message.content)
      } else if (Array.isArray(message.content)) {
        for (const part of message.content) {
          switch (part.type) {
            case 'input_text':
              sum += estimateTextTokens(part.text)
              break
            case 'input_image':
              sum += estimateTextTokens(part.image_url || '')
              break
            default:
              break
          }
        }
      }
    }
    switch (message.type) {
      case 'function_call_output':
        sum += estimateTextTokens(message.output)
        break
      case 'function_call':
        sum += estimateTextTokens(message.arguments)
        break
      default:
        break
    }
    return sum
  }

  public extractMessagesFromSdkPayload(sdkPayload: OpenAIResponseSdkParams): OpenAIResponseSdkMessageParam[] {
    if (typeof sdkPayload.input === 'string') {
      return [{ role: 'user', content: sdkPayload.input }]
    }
    return sdkPayload.input
  }

  getRequestTransformer(): RequestTransformer<OpenAIResponseSdkParams, OpenAIResponseSdkMessageParam> {
    return {
      transform: async (
        coreRequest,
        assistant,
        model,
        isRecursiveCall,
        recursiveSdkMessages
      ): Promise<{
        payload: OpenAIResponseSdkParams
        messages: OpenAIResponseSdkMessageParam[]
        metadata: Record<string, any>
      }> => {
        const { messages, mcpTools, maxTokens, streamOutput, enableWebSearch, enableGenerateImage } = coreRequest
        // 1. Handle the system message
        const systemMessage: OpenAI.Responses.EasyInputMessage = {
          role: 'system',
          content: []
        }

        const systemMessageContent: OpenAI.Responses.ResponseInputMessageContentList = []
        const systemMessageInput: OpenAI.Responses.ResponseInputText = {
          text: assistant.prompt || '',
          type: 'input_text'
        }
        if (isSupportedReasoningEffortOpenAIModel(model)) {
          systemMessage.role = 'developer'
        }

        // 2. Set up tools
        let tools: OpenAI.Responses.Tool[] = []
        const { tools: extraTools } = this.setupToolsConfig({
          mcpTools: mcpTools,
          model,
          enableToolUse: isEnabledToolUse(assistant)
        })

        if (this.useSystemPromptForTools) {
          systemMessageInput.text = await buildSystemPrompt(systemMessageInput.text || '', mcpTools, assistant)
        }
        systemMessageContent.push(systemMessageInput)
        systemMessage.content = systemMessageContent

        // 3. Handle user messages
        let userMessage: OpenAI.Responses.ResponseInputItem[] = []
        if (typeof messages === 'string') {
          userMessage.push({ role: 'user', content: messages })
        } else {
          const processedMessages = addImageFileToContents(messages)
          for (const message of processedMessages) {
            userMessage.push(await this.convertMessageToSdkParam(message, model))
          }
        }
        // FIXME: it would be better to use previous_response_id for this (or store the image_generation_call id in the database)
        if (enableGenerateImage) {
          const finalAssistantMessage = userMessage.findLast(
            (m) => (m as OpenAI.Responses.EasyInputMessage).role === 'assistant'
          ) as OpenAI.Responses.EasyInputMessage
          const finalUserMessage = userMessage.pop() as OpenAI.Responses.EasyInputMessage
          if (
            finalAssistantMessage &&
            Array.isArray(finalAssistantMessage.content) &&
            finalUserMessage &&
            Array.isArray(finalUserMessage.content)
          ) {
            finalAssistantMessage.content = [...finalAssistantMessage.content, ...finalUserMessage.content]
          }
          // The previous assistant message content (including images and files) is intentionally re-sent as a user message
          userMessage = [{ ...finalAssistantMessage, role: 'user' } as OpenAI.Responses.EasyInputMessage]
        }

        // 4. Build the final request messages
        let reqMessages: OpenAI.Responses.ResponseInput
        if (!systemMessage.content) {
          reqMessages = [...userMessage]
        } else {
          reqMessages = [systemMessage, ...userMessage].filter(Boolean) as OpenAI.Responses.EasyInputMessage[]
        }

        if (enableWebSearch) {
          tools.push({
            type: 'web_search_preview'
          })
        }

        if (enableGenerateImage) {
          tools.push({
            type: 'image_generation',
            partial_images: streamOutput ? 2 : undefined
          })
        }

        const toolChoices: OpenAI.Responses.ToolChoiceTypes = {
          type: 'web_search_preview'
        }

        tools = tools.concat(extraTools)
        const commonParams = {
          model: model.id,
          input:
            isRecursiveCall && recursiveSdkMessages && recursiveSdkMessages.length > 0
              ? recursiveSdkMessages
              : reqMessages,
          temperature: this.getTemperature(assistant, model),
          top_p: this.getTopP(assistant, model),
          max_output_tokens: maxTokens,
          stream: streamOutput,
          tools: !isEmpty(tools) ? tools : undefined,
          tool_choice: enableWebSearch ? toolChoices : undefined,
          service_tier: this.getServiceTier(model),
          ...(this.getReasoningEffort(assistant, model) as OpenAI.Reasoning),
          ...this.getCustomParameters(assistant)
        }
        const sdkParams: OpenAIResponseSdkParams = streamOutput
          ? {
              ...commonParams,
              stream: true
            }
          : {
              ...commonParams,
              stream: false
            }
        const timeout = this.getTimeout(model)
        return { payload: sdkParams, messages: reqMessages, metadata: { timeout } }
      }
    }
  }

  getResponseChunkTransformer(): ResponseChunkTransformer<OpenAIResponseSdkRawChunk> {
    const toolCalls: OpenAIResponseSdkToolCall[] = []
    const outputItems: OpenAI.Responses.ResponseOutputItem[] = []
    return () => ({
      async transform(chunk: OpenAIResponseSdkRawChunk, controller: TransformStreamDefaultController<GenericChunk>) {
        // Process the chunk
        if ('output' in chunk) {
          for (const output of chunk.output) {
            switch (output.type) {
              case 'message':
                if (output.content[0].type === 'output_text') {
                  controller.enqueue({
                    type: ChunkType.TEXT_DELTA,
                    text: output.content[0].text
                  })
                  if (output.content[0].annotations && output.content[0].annotations.length > 0) {
                    controller.enqueue({
                      type: ChunkType.LLM_WEB_SEARCH_COMPLETE,
                      llm_web_search: {
                        source: WebSearchSource.OPENAI_RESPONSE,
                        results: output.content[0].annotations
                      }
                    })
                  }
                }
                break
              case 'reasoning':
                controller.enqueue({
                  type: ChunkType.THINKING_DELTA,
                  text: output.summary.map((s) => s.text).join('\n')
                })
                break
              case 'function_call':
                toolCalls.push(output)
                break
              case 'image_generation_call':
                controller.enqueue({
                  type: ChunkType.IMAGE_CREATED
                })
                controller.enqueue({
                  type: ChunkType.IMAGE_COMPLETE,
                  image: {
                    type: 'base64',
                    images: [`data:image/png;base64,${output.result}`]
                  }
                })
            }
          }
        } else {
          switch (chunk.type) {
            case 'response.output_item.added':
              if (chunk.item.type === 'function_call') {
                outputItems.push(chunk.item)
              }
              break
            case 'response.reasoning_summary_text.delta':
              controller.enqueue({
                type: ChunkType.THINKING_DELTA,
                text: chunk.delta
              })
              break
            case 'response.image_generation_call.generating':
              controller.enqueue({
                type: ChunkType.IMAGE_CREATED
              })
              break
            case 'response.image_generation_call.partial_image':
              controller.enqueue({
                type: ChunkType.IMAGE_DELTA,
                image: {
                  type: 'base64',
                  images: [`data:image/png;base64,${chunk.partial_image_b64}`]
                }
              })
              break
            case 'response.image_generation_call.completed':
              controller.enqueue({
                type: ChunkType.IMAGE_COMPLETE
              })
              break
            case 'response.output_text.delta': {
              controller.enqueue({
                type: ChunkType.TEXT_DELTA,
                text: chunk.delta
              })
              break
            }
            case 'response.function_call_arguments.done': {
              const outputItem: OpenAI.Responses.ResponseOutputItem | undefined = outputItems.find(
                (item) => item.id === chunk.item_id
              )
              if (outputItem) {
                if (outputItem.type === 'function_call') {
                  toolCalls.push({
                    ...outputItem,
                    arguments: chunk.arguments
                  })
                }
              }
              break
            }
            case 'response.content_part.done': {
              if (chunk.part.type === 'output_text' && chunk.part.annotations && chunk.part.annotations.length > 0) {
                controller.enqueue({
                  type: ChunkType.LLM_WEB_SEARCH_COMPLETE,
                  llm_web_search: {
                    source: WebSearchSource.OPENAI_RESPONSE,
                    results: chunk.part.annotations
                  }
                })
              }
              if (toolCalls.length > 0) {
                controller.enqueue({
                  type: ChunkType.MCP_TOOL_CREATED,
                  tool_calls: toolCalls
                })
              }
              break
            }
            case 'response.completed': {
              const completion_tokens = chunk.response.usage?.output_tokens || 0
              const total_tokens = chunk.response.usage?.total_tokens || 0
              controller.enqueue({
                type: ChunkType.LLM_RESPONSE_COMPLETE,
                response: {
                  usage: {
                    prompt_tokens: chunk.response.usage?.input_tokens || 0,
                    completion_tokens: completion_tokens,
                    total_tokens: total_tokens
                  }
                }
              })
              break
            }
            case 'error': {
              controller.enqueue({
                type: ChunkType.ERROR,
                error: {
                  message: chunk.message,
                  code: chunk.code
                }
              })
              break
            }
          }
        }
      }
    })
  }
}
src/renderer/src/aiCore/clients/types.ts (new file, 129 lines)
@@ -0,0 +1,129 @@
import Anthropic from '@anthropic-ai/sdk'
import { Assistant, MCPTool, MCPToolResponse, Model, Provider, ToolCallResponse } from '@renderer/types'
import {
  AnthropicSdkRawChunk,
  OpenAISdkRawChunk,
  SdkMessageParam,
  SdkParams,
  SdkRawChunk,
  SdkRawOutput,
  SdkTool,
  SdkToolCall
} from '@renderer/types/sdk'
import OpenAI from 'openai'

import { CompletionsParams, GenericChunk } from '../middleware/schemas'

/**
 * Raw stream listener interface
 */
export interface RawStreamListener<TRawChunk = SdkRawChunk> {
  onChunk?: (chunk: TRawChunk) => void
  onStart?: () => void
  onEnd?: () => void
  onError?: (error: Error) => void
}

/**
 * OpenAI-specific stream listener
 */
export interface OpenAIStreamListener extends RawStreamListener<OpenAISdkRawChunk> {
  onChoice?: (choice: OpenAI.Chat.Completions.ChatCompletionChunk.Choice) => void
  onFinishReason?: (reason: string) => void
}

/**
 * Anthropic-specific stream listener
 */
export interface AnthropicStreamListener<TChunk extends AnthropicSdkRawChunk = AnthropicSdkRawChunk>
  extends RawStreamListener<TChunk> {
  onContentBlock?: (contentBlock: Anthropic.Messages.ContentBlock) => void
  onMessage?: (message: Anthropic.Messages.Message) => void
}

/**
 * Request transformer interface
 */
export interface RequestTransformer<
  TSdkParams extends SdkParams = SdkParams,
  TMessageParam extends SdkMessageParam = SdkMessageParam
> {
  transform(
    completionsParams: CompletionsParams,
    assistant: Assistant,
    model: Model,
    isRecursiveCall?: boolean,
    recursiveSdkMessages?: TMessageParam[]
  ): Promise<{
    payload: TSdkParams
    messages: TMessageParam[]
    metadata?: Record<string, any>
  }>
}

/**
 * Response chunk transformer interface
 */
export type ResponseChunkTransformer<TRawChunk extends SdkRawChunk = SdkRawChunk, TContext = any> = (
  context?: TContext
) => Transformer<TRawChunk, GenericChunk>

export interface ResponseChunkTransformerContext {
  isStreaming: boolean
  isEnabledToolCalling: boolean
  isEnabledWebSearch: boolean
  isEnabledReasoning: boolean
  mcpTools: MCPTool[]
  provider: Provider
}

/**
 * API client interface
 */
export interface ApiClient<
  TSdkInstance = any,
  TSdkParams extends SdkParams = SdkParams,
  TRawOutput extends SdkRawOutput = SdkRawOutput,
  TRawChunk extends SdkRawChunk = SdkRawChunk,
  TMessageParam extends SdkMessageParam = SdkMessageParam,
  TToolCall extends SdkToolCall = SdkToolCall,
  TSdkSpecificTool extends SdkTool = SdkTool
> {
  provider: Provider

  // Core method - in the middleware architecture this may be just a placeholder;
  // the actual SDK call is handled by SdkCallMiddleware
  // completions(params: CompletionsParams): Promise<CompletionsResult>

  createCompletions(payload: TSdkParams): Promise<TRawOutput>

  // SDK-related methods
  getSdkInstance(): Promise<TSdkInstance> | TSdkInstance
  getRequestTransformer(): RequestTransformer<TSdkParams, TMessageParam>
  getResponseChunkTransformer(): ResponseChunkTransformer<TRawChunk>

  // Raw stream listening
  attachRawStreamListener?(rawOutput: TRawOutput, listener: RawStreamListener<TRawChunk>): TRawOutput

  // Tool conversion methods (kept optional since not every provider supports tools)
  convertMcpToolsToSdkTools(mcpTools: MCPTool[]): TSdkSpecificTool[]
  convertMcpToolResponseToSdkMessageParam?(
    mcpToolResponse: MCPToolResponse,
    resp: any,
    model: Model
  ): TMessageParam | undefined
  convertSdkToolCallToMcp?(toolCall: TToolCall, mcpTools: MCPTool[]): MCPTool | undefined
  convertSdkToolCallToMcpToolResponse(toolCall: TToolCall, mcpTool: MCPTool): ToolCallResponse

  // Build the SDK-specific message list used for the recursive call after tool execution
  buildSdkMessages(
    currentReqMessages: TMessageParam[],
    output: TRawOutput | string,
    toolResults: TMessageParam[],
    toolCalls?: TToolCall[]
  ): TMessageParam[]

  // Extract the message array from the SDK payload (for type-safe access in middlewares)
  extractMessagesFromSdkPayload(sdkPayload: TSdkParams): TMessageParam[]
}
src/renderer/src/aiCore/index.ts (new file, 130 lines)
@@ -0,0 +1,130 @@
import { ApiClientFactory } from '@renderer/aiCore/clients/ApiClientFactory'
import { BaseApiClient } from '@renderer/aiCore/clients/BaseApiClient'
import { isDedicatedImageGenerationModel, isFunctionCallingModel } from '@renderer/config/models'
import type { GenerateImageParams, Model, Provider } from '@renderer/types'
import { RequestOptions, SdkModel } from '@renderer/types/sdk'
import { isEnabledToolUse } from '@renderer/utils/mcp-tools'

import { OpenAIAPIClient } from './clients'
import { AihubmixAPIClient } from './clients/AihubmixAPIClient'
import { AnthropicAPIClient } from './clients/anthropic/AnthropicAPIClient'
import { OpenAIResponseAPIClient } from './clients/openai/OpenAIResponseAPIClient'
import { CompletionsMiddlewareBuilder } from './middleware/builder'
import { MIDDLEWARE_NAME as AbortHandlerMiddlewareName } from './middleware/common/AbortHandlerMiddleware'
import { MIDDLEWARE_NAME as FinalChunkConsumerMiddlewareName } from './middleware/common/FinalChunkConsumerMiddleware'
import { applyCompletionsMiddlewares } from './middleware/composer'
import { MIDDLEWARE_NAME as McpToolChunkMiddlewareName } from './middleware/core/McpToolChunkMiddleware'
import { MIDDLEWARE_NAME as RawStreamListenerMiddlewareName } from './middleware/core/RawStreamListenerMiddleware'
import { MIDDLEWARE_NAME as ThinkChunkMiddlewareName } from './middleware/core/ThinkChunkMiddleware'
import { MIDDLEWARE_NAME as WebSearchMiddlewareName } from './middleware/core/WebSearchMiddleware'
import { MIDDLEWARE_NAME as ImageGenerationMiddlewareName } from './middleware/feat/ImageGenerationMiddleware'
import { MIDDLEWARE_NAME as ThinkingTagExtractionMiddlewareName } from './middleware/feat/ThinkingTagExtractionMiddleware'
import { MIDDLEWARE_NAME as ToolUseExtractionMiddlewareName } from './middleware/feat/ToolUseExtractionMiddleware'
import { MiddlewareRegistry } from './middleware/register'
import { CompletionsParams, CompletionsResult } from './middleware/schemas'

export default class AiProvider {
  private apiClient: BaseApiClient

  constructor(provider: Provider) {
    // Use the new ApiClientFactory to get a BaseApiClient instance
    this.apiClient = ApiClientFactory.create(provider)
  }

  public async completions(params: CompletionsParams, options?: RequestOptions): Promise<CompletionsResult> {
    // 1. Pick the correct client for the model
    const model = params.assistant.model
    if (!model) {
      return Promise.reject(new Error('Model is required'))
    }

    // Choose the handling strategy based on the client type
    let client: BaseApiClient

    if (this.apiClient instanceof AihubmixAPIClient) {
      // AihubmixAPIClient: pick the appropriate sub-client for the model
      client = this.apiClient.getClientForModel(model)
      if (client instanceof OpenAIResponseAPIClient) {
        client = client.getClient(model) as BaseApiClient
      }
    } else if (this.apiClient instanceof OpenAIResponseAPIClient) {
      // OpenAIResponseAPIClient: pick the API type based on the model's characteristics
      client = this.apiClient.getClient(model) as BaseApiClient
    } else {
      // Other clients are used directly
      client = this.apiClient
    }

    // 2. Build the middleware chain
    const builder = CompletionsMiddlewareBuilder.withDefaults()
    // images api
    if (isDedicatedImageGenerationModel(model)) {
      builder.clear()
      builder
        .add(MiddlewareRegistry[FinalChunkConsumerMiddlewareName])
        .add(MiddlewareRegistry[AbortHandlerMiddlewareName])
        .add(MiddlewareRegistry[ImageGenerationMiddlewareName])
    } else {
      // Existing logic for other models
      if (!params.enableReasoning) {
        builder.remove(ThinkingTagExtractionMiddlewareName)
        builder.remove(ThinkChunkMiddlewareName)
      }
      // Note: checking `client` here would trigger TypeScript type narrowing
      if (!(this.apiClient instanceof OpenAIAPIClient)) {
        builder.remove(ThinkingTagExtractionMiddlewareName)
      }
      if (!(this.apiClient instanceof AnthropicAPIClient)) {
        builder.remove(RawStreamListenerMiddlewareName)
      }
      if (!params.enableWebSearch) {
        builder.remove(WebSearchMiddlewareName)
      }
      if (!params.mcpTools?.length) {
        builder.remove(ToolUseExtractionMiddlewareName)
        builder.remove(McpToolChunkMiddlewareName)
      }
      if (isEnabledToolUse(params.assistant) && isFunctionCallingModel(model)) {
        builder.remove(ToolUseExtractionMiddlewareName)
      }
      if (params.callType !== 'chat') {
        builder.remove(AbortHandlerMiddlewareName)
      }
    }

    const middlewares = builder.build()

    // 3. Create the wrapped SDK method with middlewares
    const wrappedCompletionMethod = applyCompletionsMiddlewares(client, client.createCompletions, middlewares)

    // 4. Execute the wrapped method with the original params
    return wrappedCompletionMethod(params, options)
  }

  public async models(): Promise<SdkModel[]> {
    return this.apiClient.listModels()
  }

  public async getEmbeddingDimensions(model: Model): Promise<number> {
    try {
      // Use the SDK instance to test embedding capabilities
      const dimensions = await this.apiClient.getEmbeddingDimensions(model)
      return dimensions
    } catch (error) {
      console.error('Error getting embedding dimensions:', error)
      return 0
    }
  }

  public async generateImage(params: GenerateImageParams): Promise<string[]> {
    return this.apiClient.generateImage(params)
  }

  public getBaseURL(): string {
    return this.apiClient.getBaseURL()
  }

  public getApiKey(): string {
    return this.apiClient.getApiKey()
  }
}
src/renderer/src/aiCore/middleware/BUILDER_USAGE.md (new file, 182 lines)
@@ -0,0 +1,182 @@
# MiddlewareBuilder Usage Guide

`MiddlewareBuilder` is a tool for dynamically building and managing middleware chains, offering flexible middleware organization and configuration.

## Key Features

### 1. Unified middleware naming

Every middleware is identified by an exported `MIDDLEWARE_NAME` constant:

```typescript
// Example middleware file
export const MIDDLEWARE_NAME = 'SdkCallMiddleware'
export const SdkCallMiddleware: CompletionsMiddleware = ...
```

### 2. The NamedMiddleware interface

Middlewares use the unified `NamedMiddleware` interface:

```typescript
interface NamedMiddleware<TMiddleware = any> {
  name: string
  middleware: TMiddleware
}
```

### 3. The middleware registry

All available middlewares are managed centrally through `MiddlewareRegistry`:

```typescript
import { MiddlewareRegistry } from './register'

// Look up a middleware by name
const sdkCallMiddleware = MiddlewareRegistry['SdkCallMiddleware']
```

## Basic Usage

### 1. Using the default middleware chain

```typescript
import { CompletionsMiddlewareBuilder } from './builder'

const builder = CompletionsMiddlewareBuilder.withDefaults()
const middlewares = builder.build()
```

### 2. Custom middleware chains

```typescript
import { createCompletionsBuilder, MiddlewareRegistry } from './builder'

const builder = createCompletionsBuilder([
  MiddlewareRegistry['AbortHandlerMiddleware'],
  MiddlewareRegistry['TextChunkMiddleware']
])

const middlewares = builder.build()
```

### 3. Adjusting the chain dynamically

```typescript
const builder = CompletionsMiddlewareBuilder.withDefaults()

// Add, remove, or replace middlewares based on conditions
if (needsLogging) {
  builder.prepend(MiddlewareRegistry['GenericLoggingMiddleware'])
}

if (disableTools) {
  builder.remove('McpToolChunkMiddleware')
}

if (customThinking) {
  builder.replace('ThinkingTagExtractionMiddleware', customThinkingMiddleware)
}

const middlewares = builder.build()
```

### 4. Chained operations

```typescript
const middlewares = CompletionsMiddlewareBuilder.withDefaults()
  .add(MiddlewareRegistry['CustomMiddleware'])
  .insertBefore('SdkCallMiddleware', MiddlewareRegistry['SecurityCheckMiddleware'])
  .remove('WebSearchMiddleware')
  .build()
```

## API Reference

### CompletionsMiddlewareBuilder

**Static methods:**

- `static withDefaults()`: create a builder pre-populated with the default middleware chain

**Instance methods:**

- `add(middleware: NamedMiddleware)`: append a middleware to the end of the chain
- `prepend(middleware: NamedMiddleware)`: add a middleware at the start of the chain
- `insertAfter(targetName: string, middleware: NamedMiddleware)`: insert after the named middleware
- `insertBefore(targetName: string, middleware: NamedMiddleware)`: insert before the named middleware
- `replace(targetName: string, middleware: NamedMiddleware)`: replace the named middleware
- `remove(targetName: string)`: remove the named middleware
- `has(name: string)`: check whether the chain contains the named middleware
- `build()`: build the final middleware array
- `getChain()`: get the current chain (including name information)
- `clear()`: clear the middleware chain
- `execute(context, params, middlewareExecutor)`: directly execute the built middleware chain

### Factory functions

- `createCompletionsBuilder(baseChain?)`: create a Completions middleware builder
- `createMethodBuilder(baseChain?)`: create a generic method middleware builder
- `addMiddlewareName(middleware, name)`: helper that attaches a name property to a middleware

### Middleware registry

- `MiddlewareRegistry`: central access point for all registered middlewares
- `getMiddleware(name)`: get a middleware by name
- `getRegisteredMiddlewareNames()`: get all registered middleware names
- `DefaultCompletionsNamedMiddlewares`: the default Completions middleware chain (in NamedMiddleware format)

## Type Safety

The builder comes with full TypeScript type support:

- `CompletionsMiddlewareBuilder` is dedicated to the `CompletionsMiddleware` type
- `MethodMiddlewareBuilder` is for the generic `MethodMiddleware` type
- All middleware operations are based on the `NamedMiddleware<TMiddleware>` interface

## Default Middleware Chain

The default execution order of the Completions middlewares:

1. `FinalChunkConsumerMiddleware` - final consumer
2. `TransformCoreToSdkParamsMiddleware` - parameter transformation
3. `AbortHandlerMiddleware` - abort handling
4. `McpToolChunkMiddleware` - tool handling
5. `WebSearchMiddleware` - web search handling
6. `TextChunkMiddleware` - text handling
7. `ThinkingTagExtractionMiddleware` - thinking tag extraction
8. `ThinkChunkMiddleware` - thinking handling
9. `ResponseTransformMiddleware` - response transformation
10. `StreamAdapterMiddleware` - stream adapter
11. `SdkCallMiddleware` - SDK call
|
||||||
|
## 在 AiProvider 中的使用
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
export default class AiProvider {
|
||||||
|
public async completions(params: CompletionsParams): Promise<CompletionsResult> {
|
||||||
|
// 1. 构建中间件链
|
||||||
|
const builder = CompletionsMiddlewareBuilder.withDefaults()
|
||||||
|
|
||||||
|
// 2. 根据参数动态调整
|
||||||
|
if (params.enableCustomFeature) {
|
||||||
|
builder.insertAfter('StreamAdapterMiddleware', customFeatureMiddleware)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 3. 应用中间件
|
||||||
|
const middlewares = builder.build()
|
||||||
|
const wrappedMethod = applyCompletionsMiddlewares(this.apiClient, this.apiClient.createCompletions, middlewares)
|
||||||
|
|
||||||
|
return wrappedMethod(params)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## 注意事项
|
||||||
|
|
||||||
|
1. **类型兼容性**:`MethodMiddleware` 和 `CompletionsMiddleware` 不兼容,需要使用对应的构建器
|
||||||
|
2. **中间件名称**:所有中间件必须导出 `MIDDLEWARE_NAME` 常量用于标识
|
||||||
|
3. **注册表管理**:新增中间件需要在 `register.ts` 中注册
|
||||||
|
4. **默认链**:默认链通过 `DefaultCompletionsNamedMiddlewares` 提供,支持延迟加载避免循环依赖
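
A minimal sketch of what notes 2 and 3 imply for a new middleware. The module path and the exact registration shape inside `register.ts` are assumptions based on the `NamedMiddleware` format described above.

```typescript
// MyCustomMiddleware.ts — every middleware module exports an identifying constant (note 2)
export const MIDDLEWARE_NAME = 'MyCustomMiddleware'

export const MyCustomMiddleware: CompletionsMiddleware = () => (next) => async (ctx, params) => {
  // ...pre-processing before the rest of the chain runs...
  return next(ctx, params)
}

// register.ts — assumed registration shape (note 3), following the NamedMiddleware format
import { MIDDLEWARE_NAME as MY_CUSTOM_NAME, MyCustomMiddleware } from './common/MyCustomMiddleware'

export const MiddlewareRegistry = {
  // ...existing entries...
  [MY_CUSTOM_NAME]: { name: MY_CUSTOM_NAME, middleware: MyCustomMiddleware }
}
```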
|
||||||
|
|
||||||
|
这种设计使得中间件链的构建既灵活又类型安全,同时保持了简洁的 API 接口。
|
||||||
175
src/renderer/src/aiCore/middleware/MIDDLEWARE_SPECIFICATION.md
Normal file
@@ -0,0 +1,175 @@

# Cherry Studio Middleware Specification

This document defines the design, implementation, and usage conventions for middlewares in the Cherry Studio `aiCore` module. The goal is a middleware system that is flexible, maintainable, and easy to extend.

## 1. Core Concepts

### 1.1. Middleware

A middleware is a function or object that runs at a specific stage of the AI request pipeline. It can read and modify the request context (`AiProviderMiddlewareContext`) and the request parameters (`Params`), and it decides whether to pass the request on to the next middleware or to terminate the flow.

Each middleware should focus on a single cross-cutting concern, such as logging, error handling, stream adaptation, or feature parsing.

### 1.2. `AiProviderMiddlewareContext` (the context object)

This object is passed along the entire middleware chain and carries the following core fields (a minimal interface sketch follows the list):

- `_apiClientInstance: ApiClient<any,any,any>`: the currently selected, instantiated AI provider client.
- `_coreRequest: CoreRequestType`: the normalized internal core request object.
- `resolvePromise: (value: AggregatedResultType) => void`: resolves the Promise returned by `AiCoreService` when the whole operation completes successfully.
- `rejectPromise: (reason?: any) => void`: rejects the Promise returned by `AiCoreService` when an error occurs.
- `onChunk?: (chunk: Chunk) => void`: the streaming chunk callback supplied by the application layer.
- `abortController?: AbortController`: the controller used to abort the request.
- Any other per-request dynamic data that middlewares may read or write.
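
A minimal TypeScript sketch of this shape, using only the fields listed above; the referenced types (`ApiClient`, `CoreRequestType`, `AggregatedResultType`, `Chunk`) are assumed to be imported from elsewhere in the module.

```typescript
// Field names taken from the list above; the index signature stands in for the
// per-request dynamic data that individual middlewares may read or write.
interface AiProviderMiddlewareContext {
  _apiClientInstance: ApiClient<any, any, any>
  _coreRequest: CoreRequestType
  resolvePromise: (value: AggregatedResultType) => void
  rejectPromise: (reason?: any) => void
  onChunk?: (chunk: Chunk) => void
  abortController?: AbortController
  [key: string]: unknown
}
```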

### 1.3. `MiddlewareName`

To make dynamic operations on the chain (insert, replace, remove) convenient, every significant middleware that other logic may refer to should have a unique, recognizable name. A TypeScript `enum` is the recommended way to define these names:

```typescript
// example
export enum MiddlewareName {
  LOGGING_START = 'LoggingStartMiddleware',
  LOGGING_END = 'LoggingEndMiddleware',
  ERROR_HANDLING = 'ErrorHandlingMiddleware',
  ABORT_HANDLER = 'AbortHandlerMiddleware',
  // Core Flow
  TRANSFORM_CORE_TO_SDK_PARAMS = 'TransformCoreToSdkParamsMiddleware',
  REQUEST_EXECUTION = 'RequestExecutionMiddleware',
  STREAM_ADAPTER = 'StreamAdapterMiddleware',
  RAW_SDK_CHUNK_TO_APP_CHUNK = 'RawSdkChunkToAppChunkMiddleware',
  // Features
  THINKING_TAG_EXTRACTION = 'ThinkingTagExtractionMiddleware',
  TOOL_USE_TAG_EXTRACTION = 'ToolUseTagExtractionMiddleware',
  MCP_TOOL_HANDLER = 'McpToolHandlerMiddleware',
  // Finalization
  FINAL_CHUNK_CONSUMER = 'FinalChunkConsumerAndNotifierMiddleware'
  // Add more as needed
}
```

Middleware instances need some way of exposing their `MiddlewareName`, for example through a `name` property.

### 1.4. Middleware execution structure

We use a flexible execution structure. A middleware is typically a function that receives the `Context`, the `Params`, and a `next` function (used to invoke the next middleware in the chain).

```typescript
// Simplified middleware function signature
type MiddlewareFunction = (
  context: AiProviderMiddlewareContext,
  params: any, // e.g., CompletionsParams
  next: () => Promise<void> // next usually returns a Promise to support async work
) => Promise<void> // the middleware itself may also return a Promise

// Or the more classic Koa/Express style (three-stage curried form)
// type MiddlewareFactory = (api?: MiddlewareApi) =>
//   (nextMiddleware: (ctx: AiProviderMiddlewareContext, params: any) => Promise<void>) =>
//     (context: AiProviderMiddlewareContext, params: any) => Promise<void>;
// The current design leans toward the simplified MiddlewareFunction above,
// with the MiddlewareExecutor responsible for orchestrating next.
```

The `MiddlewareExecutor` (or `applyMiddlewares`) is responsible for managing the `next` calls.
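
A minimal logging middleware written against the simplified `MiddlewareFunction` signature above, purely as an illustration of the shape:

```typescript
const loggingStartFn: MiddlewareFunction = async (_context, params, next) => {
  const startedAt = Date.now()
  console.log('[LoggingStartMiddleware] started, params:', params)

  // Hand control to the rest of the chain
  await next()

  console.log(`[LoggingStartMiddleware] finished in ${Date.now() - startedAt}ms`)
}
```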

## 2. `MiddlewareBuilder` (the generic middleware builder)

To build and manage middleware chains dynamically, we introduce a generic `MiddlewareBuilder` class.

### 2.1. Design philosophy

`MiddlewareBuilder` offers a fluent API for building middleware chains declaratively. It lets you start from a base chain and then add, insert, replace, or remove middlewares according to specific conditions.

### 2.2. API overview

```typescript
class MiddlewareBuilder {
  constructor(baseChain?: Middleware[])

  add(middleware: Middleware): this
  prepend(middleware: Middleware): this
  insertAfter(targetName: MiddlewareName, middlewareToInsert: Middleware): this
  insertBefore(targetName: MiddlewareName, middlewareToInsert: Middleware): this
  replace(targetName: MiddlewareName, newMiddleware: Middleware): this
  remove(targetName: MiddlewareName): this

  build(): Middleware[] // returns the built middleware array

  // Optional: execute the chain directly
  execute(
    context: AiProviderMiddlewareContext,
    params: any,
    middlewareExecutor: (chain: Middleware[], context: AiProviderMiddlewareContext, params: any) => void
  ): void
}
```

### 2.3. Usage example

```typescript
// 1. Define some middleware instances (assuming they carry a .name property)
const loggingStart = { name: MiddlewareName.LOGGING_START, fn: loggingStartFn }
const requestExec = { name: MiddlewareName.REQUEST_EXECUTION, fn: requestExecFn }
const streamAdapter = { name: MiddlewareName.STREAM_ADAPTER, fn: streamAdapterFn }
const customFeature = { name: MiddlewareName.CUSTOM_FEATURE, fn: customFeatureFn } // hypothetical custom middleware

// 2. Define a base chain (optional)
const BASE_CHAIN: Middleware[] = [loggingStart, requestExec, streamAdapter]

// 3. Use the MiddlewareBuilder
const builder = new MiddlewareBuilder(BASE_CHAIN)

if (params.needsCustomFeature) {
  builder.insertAfter(MiddlewareName.STREAM_ADAPTER, customFeature)
}

if (params.isHighSecurityContext) {
  builder.insertBefore(MiddlewareName.REQUEST_EXECUTION, highSecurityCheckMiddleware)
}

if (params.overrideLogging) {
  builder.replace(MiddlewareName.LOGGING_START, newSpecialLoggingMiddleware)
}

// 4. Get the final chain
const finalChain = builder.build()

// 5. Execute (via an external executor)
// middlewareExecutor(finalChain, context, params);
// or builder.execute(context, params, middlewareExecutor);
```

## 3. `MiddlewareExecutor` / `applyMiddlewares` (the middleware executor)

This is the component that receives the middleware chain produced by the `MiddlewareBuilder` and actually executes it; a minimal sketch follows the responsibilities list.

### 3.1. Responsibilities

- Receives the `Middleware[]`, the `AiProviderMiddlewareContext`, and the `Params`.
- Iterates over the middlewares in order.
- Supplies each middleware with the correct `next` function which, when called, runs the next middleware in the chain.
- Handles the Promises produced while running the middlewares (when they are asynchronous).
- Provides basic error capture (detailed error handling is the job of an `ErrorHandlingMiddleware` inside the chain).
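
A minimal executor sketch for the simplified `MiddlewareFunction` signature from section 1.4. It illustrates the responsibilities above and is not the project's actual implementation (which lives in `composer.ts`).

```typescript
async function applyMiddlewares(
  chain: MiddlewareFunction[],
  context: AiProviderMiddlewareContext,
  params: any
): Promise<void> {
  // dispatch(i) runs the i-th middleware and hands it a `next` that runs i + 1
  const dispatch = async (index: number): Promise<void> => {
    if (index >= chain.length) return // end of the chain
    await chain[index](context, params, () => dispatch(index + 1))
  }

  try {
    await dispatch(0)
  } catch (error) {
    // Basic capture only; detailed handling belongs to an ErrorHandlingMiddleware in the chain
    context.rejectPromise(error)
    throw error
  }
}
```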

## 4. Usage in `AiCoreService`

Each core business method in `AiCoreService` (such as `executeCompletions`) is responsible for the following steps (a condensed sketch follows the list):

1. Preparing the base data: instantiating the `ApiClient` and converting the `Params` into a `CoreRequest`.
2. Instantiating a `MiddlewareBuilder`, possibly seeded with a base middleware chain specific to that business method.
3. Calling the `MiddlewareBuilder` methods to adjust the chain dynamically based on conditions in `Params` and `CoreRequest`.
4. Calling `MiddlewareBuilder.build()` to obtain the final middleware chain.
5. Creating the complete `AiProviderMiddlewareContext` (including `resolvePromise`, `rejectPromise`, etc.).
6. Calling the `MiddlewareExecutor` (or `applyMiddlewares`) to run the built chain.
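
A condensed sketch of steps 1–6 inside a hypothetical `executeCompletions`. Names such as `createApiClient`, `buildCoreRequest`, `BASE_COMPLETIONS_CHAIN`, `webSearchMiddleware`, and `params.enableWebSearch` are placeholders, not real helpers from the codebase.

```typescript
class AiCoreService {
  async executeCompletions(params: CompletionsParams): Promise<AggregatedResultType> {
    // 1. Prepare base data (placeholder helpers)
    const apiClient = createApiClient(params)
    const coreRequest = buildCoreRequest(params)

    // 2–4. Build the chain, adjusting it to this request
    const builder = new MiddlewareBuilder(BASE_COMPLETIONS_CHAIN)
    if (params.enableWebSearch) {
      builder.insertBefore(MiddlewareName.STREAM_ADAPTER, webSearchMiddleware)
    }
    const chain = builder.build()

    // 5–6. Create the full context and hand it to the executor
    return new Promise((resolve, reject) => {
      const context: AiProviderMiddlewareContext = {
        _apiClientInstance: apiClient,
        _coreRequest: coreRequest,
        resolvePromise: resolve,
        rejectPromise: reject,
        onChunk: params.onChunk,
        abortController: new AbortController()
      }
      applyMiddlewares(chain, context, params)
    })
  }
}
```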

## 5. Composite Features

For composite features (for example "Completions then Translate") — a sketch follows this list:

- Building one single, monolithic `MiddlewareBuilder` that handles the whole composite flow is not recommended.
- Instead, add a new method to `AiCoreService` that `await`s the underlying atomic `AiCoreService` methods in sequence (for example, first `await this.executeCompletions(...)`, then feed its result into `await this.translateText(...)`).
- Each atomic method internally uses its own `MiddlewareBuilder` instance to build and execute the middleware chain for its own stage.
- This approach maximizes reuse and keeps each part's responsibilities clear.
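
A minimal sketch of the recommended composition pattern. It assumes `translateText` is another atomic `AiCoreService` method and that the completions result exposes a `getText()` helper, as the `CompletionsResult` elsewhere in this change does.

```typescript
class AiCoreService {
  // ...executeCompletions and translateText are the existing atomic methods...

  async completionsThenTranslate(params: CompletionsParams, targetLanguage: string): Promise<string> {
    // Each atomic call builds and runs its own middleware chain internally
    const completion = await this.executeCompletions(params)
    return this.translateText(completion.getText(), targetLanguage)
  }
}
```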

## 6. Middleware Naming and Discovery

Giving each middleware a unique `MiddlewareName` is essential for the `MiddlewareBuilder` operations `insertAfter`, `insertBefore`, `replace`, and `remove`. Make sure middleware instances expose their name in some way (for example, through a `name` property).

241
src/renderer/src/aiCore/middleware/builder.ts
Normal file
@@ -0,0 +1,241 @@
|
|||||||
|
import { DefaultCompletionsNamedMiddlewares } from './register'
|
||||||
|
import { BaseContext, CompletionsMiddleware, MethodMiddleware } from './types'
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 带有名称标识的中间件接口
|
||||||
|
*/
|
||||||
|
export interface NamedMiddleware<TMiddleware = any> {
|
||||||
|
name: string
|
||||||
|
middleware: TMiddleware
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 中间件执行器函数类型
|
||||||
|
*/
|
||||||
|
export type MiddlewareExecutor<TContext extends BaseContext = BaseContext> = (
|
||||||
|
chain: any[],
|
||||||
|
context: TContext,
|
||||||
|
params: any
|
||||||
|
) => Promise<any>
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 通用中间件构建器类
|
||||||
|
* 提供流式 API 用于动态构建和管理中间件链
|
||||||
|
*
|
||||||
|
* 注意:所有中间件都通过 MiddlewareRegistry 管理,使用 NamedMiddleware 格式
|
||||||
|
*/
|
||||||
|
export class MiddlewareBuilder<TMiddleware = any> {
|
||||||
|
private middlewares: NamedMiddleware<TMiddleware>[]
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 构造函数
|
||||||
|
* @param baseChain - 可选的基础中间件链(NamedMiddleware 格式)
|
||||||
|
*/
|
||||||
|
constructor(baseChain?: NamedMiddleware<TMiddleware>[]) {
|
||||||
|
this.middlewares = baseChain ? [...baseChain] : []
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 在链的末尾添加中间件
|
||||||
|
* @param middleware - 要添加的具名中间件
|
||||||
|
* @returns this,支持链式调用
|
||||||
|
*/
|
||||||
|
add(middleware: NamedMiddleware<TMiddleware>): this {
|
||||||
|
this.middlewares.push(middleware)
|
||||||
|
return this
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 在链的开头添加中间件
|
||||||
|
* @param middleware - 要添加的具名中间件
|
||||||
|
* @returns this,支持链式调用
|
||||||
|
*/
|
||||||
|
prepend(middleware: NamedMiddleware<TMiddleware>): this {
|
||||||
|
this.middlewares.unshift(middleware)
|
||||||
|
return this
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 在指定中间件之后插入新中间件
|
||||||
|
* @param targetName - 目标中间件名称
|
||||||
|
* @param middlewareToInsert - 要插入的具名中间件
|
||||||
|
* @returns this,支持链式调用
|
||||||
|
*/
|
||||||
|
insertAfter(targetName: string, middlewareToInsert: NamedMiddleware<TMiddleware>): this {
|
||||||
|
const index = this.findMiddlewareIndex(targetName)
|
||||||
|
if (index !== -1) {
|
||||||
|
this.middlewares.splice(index + 1, 0, middlewareToInsert)
|
||||||
|
} else {
|
||||||
|
console.warn(`MiddlewareBuilder: 未找到名为 '${targetName}' 的中间件,无法插入`)
|
||||||
|
}
|
||||||
|
return this
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 在指定中间件之前插入新中间件
|
||||||
|
* @param targetName - 目标中间件名称
|
||||||
|
* @param middlewareToInsert - 要插入的具名中间件
|
||||||
|
* @returns this,支持链式调用
|
||||||
|
*/
|
||||||
|
insertBefore(targetName: string, middlewareToInsert: NamedMiddleware<TMiddleware>): this {
|
||||||
|
const index = this.findMiddlewareIndex(targetName)
|
||||||
|
if (index !== -1) {
|
||||||
|
this.middlewares.splice(index, 0, middlewareToInsert)
|
||||||
|
} else {
|
||||||
|
console.warn(`MiddlewareBuilder: 未找到名为 '${targetName}' 的中间件,无法插入`)
|
||||||
|
}
|
||||||
|
return this
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 替换指定的中间件
|
||||||
|
* @param targetName - 要替换的中间件名称
|
||||||
|
* @param newMiddleware - 新的具名中间件
|
||||||
|
* @returns this,支持链式调用
|
||||||
|
*/
|
||||||
|
replace(targetName: string, newMiddleware: NamedMiddleware<TMiddleware>): this {
|
||||||
|
const index = this.findMiddlewareIndex(targetName)
|
||||||
|
if (index !== -1) {
|
||||||
|
this.middlewares[index] = newMiddleware
|
||||||
|
} else {
|
||||||
|
console.warn(`MiddlewareBuilder: 未找到名为 '${targetName}' 的中间件,无法替换`)
|
||||||
|
}
|
||||||
|
return this
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 移除指定的中间件
|
||||||
|
* @param targetName - 要移除的中间件名称
|
||||||
|
* @returns this,支持链式调用
|
||||||
|
*/
|
||||||
|
remove(targetName: string): this {
|
||||||
|
const index = this.findMiddlewareIndex(targetName)
|
||||||
|
if (index !== -1) {
|
||||||
|
this.middlewares.splice(index, 1)
|
||||||
|
}
|
||||||
|
return this
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 构建最终的中间件数组
|
||||||
|
* @returns 构建好的中间件数组
|
||||||
|
*/
|
||||||
|
build(): TMiddleware[] {
|
||||||
|
return this.middlewares.map((item) => item.middleware)
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 获取当前中间件链的副本(包含名称信息)
|
||||||
|
* @returns 当前中间件链的副本
|
||||||
|
*/
|
||||||
|
getChain(): NamedMiddleware<TMiddleware>[] {
|
||||||
|
return [...this.middlewares]
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 检查是否包含指定名称的中间件
|
||||||
|
* @param name - 中间件名称
|
||||||
|
* @returns 是否包含该中间件
|
||||||
|
*/
|
||||||
|
has(name: string): boolean {
|
||||||
|
return this.findMiddlewareIndex(name) !== -1
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 获取中间件链的长度
|
||||||
|
* @returns 中间件数量
|
||||||
|
*/
|
||||||
|
get length(): number {
|
||||||
|
return this.middlewares.length
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 清空中间件链
|
||||||
|
* @returns this,支持链式调用
|
||||||
|
*/
|
||||||
|
clear(): this {
|
||||||
|
this.middlewares = []
|
||||||
|
return this
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 直接执行构建好的中间件链
|
||||||
|
* @param context - 中间件上下文
|
||||||
|
* @param params - 参数
|
||||||
|
* @param middlewareExecutor - 中间件执行器
|
||||||
|
* @returns 执行结果
|
||||||
|
*/
|
||||||
|
execute<TContext extends BaseContext>(
|
||||||
|
context: TContext,
|
||||||
|
params: any,
|
||||||
|
middlewareExecutor: MiddlewareExecutor<TContext>
|
||||||
|
): Promise<any> {
|
||||||
|
const chain = this.build()
|
||||||
|
return middlewareExecutor(chain, context, params)
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 查找中间件在链中的索引
|
||||||
|
* @param name - 中间件名称
|
||||||
|
* @returns 索引,如果未找到返回 -1
|
||||||
|
*/
|
||||||
|
private findMiddlewareIndex(name: string): number {
|
||||||
|
return this.middlewares.findIndex((item) => item.name === name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Completions 中间件构建器
|
||||||
|
*/
|
||||||
|
export class CompletionsMiddlewareBuilder extends MiddlewareBuilder<CompletionsMiddleware> {
|
||||||
|
constructor(baseChain?: NamedMiddleware<CompletionsMiddleware>[]) {
|
||||||
|
super(baseChain)
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 使用默认的 Completions 中间件链
|
||||||
|
* @returns CompletionsMiddlewareBuilder 实例
|
||||||
|
*/
|
||||||
|
static withDefaults(): CompletionsMiddlewareBuilder {
|
||||||
|
return new CompletionsMiddlewareBuilder(DefaultCompletionsNamedMiddlewares)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 通用方法中间件构建器
|
||||||
|
*/
|
||||||
|
export class MethodMiddlewareBuilder extends MiddlewareBuilder<MethodMiddleware> {
|
||||||
|
constructor(baseChain?: NamedMiddleware<MethodMiddleware>[]) {
|
||||||
|
super(baseChain)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 便捷的工厂函数
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 创建 Completions 中间件构建器
|
||||||
|
* @param baseChain - 可选的基础链
|
||||||
|
* @returns Completions 中间件构建器实例
|
||||||
|
*/
|
||||||
|
export function createCompletionsBuilder(
|
||||||
|
baseChain?: NamedMiddleware<CompletionsMiddleware>[]
|
||||||
|
): CompletionsMiddlewareBuilder {
|
||||||
|
return new CompletionsMiddlewareBuilder(baseChain)
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 创建通用方法中间件构建器
|
||||||
|
* @param baseChain - 可选的基础链
|
||||||
|
* @returns 通用方法中间件构建器实例
|
||||||
|
*/
|
||||||
|
export function createMethodBuilder(baseChain?: NamedMiddleware<MethodMiddleware>[]): MethodMiddlewareBuilder {
|
||||||
|
return new MethodMiddlewareBuilder(baseChain)
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 为中间件添加名称属性的辅助函数
|
||||||
|
* 可以用于给现有的中间件添加名称属性
|
||||||
|
*/
|
||||||
|
export function addMiddlewareName<T extends object>(middleware: T, name: string): T & { MIDDLEWARE_NAME: string } {
|
||||||
|
return Object.assign(middleware, { MIDDLEWARE_NAME: name })
|
||||||
|
}
|
||||||
@@ -0,0 +1,106 @@
|
|||||||
|
import { Chunk, ChunkType, ErrorChunk } from '@renderer/types/chunk'
|
||||||
|
import { addAbortController, removeAbortController } from '@renderer/utils/abortController'
|
||||||
|
|
||||||
|
import { CompletionsParams, CompletionsResult } from '../schemas'
|
||||||
|
import type { CompletionsContext, CompletionsMiddleware } from '../types'
|
||||||
|
|
||||||
|
export const MIDDLEWARE_NAME = 'AbortHandlerMiddleware'
|
||||||
|
|
||||||
|
export const AbortHandlerMiddleware: CompletionsMiddleware =
|
||||||
|
() =>
|
||||||
|
(next) =>
|
||||||
|
async (ctx: CompletionsContext, params: CompletionsParams): Promise<CompletionsResult> => {
|
||||||
|
const isRecursiveCall = ctx._internal?.toolProcessingState?.isRecursiveCall || false
|
||||||
|
|
||||||
|
// 在递归调用中,跳过 AbortController 的创建,直接使用已有的
|
||||||
|
if (isRecursiveCall) {
|
||||||
|
const result = await next(ctx, params)
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
// 获取当前消息的ID用于abort管理
|
||||||
|
// 优先使用处理过的消息,如果没有则使用原始消息
|
||||||
|
let messageId: string | undefined
|
||||||
|
|
||||||
|
if (typeof params.messages === 'string') {
|
||||||
|
messageId = `message-${Date.now()}-${Math.random().toString(36).substring(2, 9)}`
|
||||||
|
} else {
|
||||||
|
const processedMessages = params.messages
|
||||||
|
const lastUserMessage = processedMessages.findLast((m) => m.role === 'user')
|
||||||
|
messageId = lastUserMessage?.id
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!messageId) {
|
||||||
|
console.warn(`[${MIDDLEWARE_NAME}] No messageId found, abort functionality will not be available.`)
|
||||||
|
return next(ctx, params)
|
||||||
|
}
|
||||||
|
|
||||||
|
const abortController = new AbortController()
|
||||||
|
const abortFn = (): void => abortController.abort()
|
||||||
|
|
||||||
|
addAbortController(messageId, abortFn)
|
||||||
|
|
||||||
|
let abortSignal: AbortSignal | null = abortController.signal
|
||||||
|
|
||||||
|
const cleanup = (): void => {
|
||||||
|
removeAbortController(messageId as string, abortFn)
|
||||||
|
if (ctx._internal?.flowControl) {
|
||||||
|
ctx._internal.flowControl.abortController = undefined
|
||||||
|
ctx._internal.flowControl.abortSignal = undefined
|
||||||
|
ctx._internal.flowControl.cleanup = undefined
|
||||||
|
}
|
||||||
|
abortSignal = null
|
||||||
|
}
|
||||||
|
|
||||||
|
// 将controller添加到_internal中的flowControl状态
|
||||||
|
if (!ctx._internal.flowControl) {
|
||||||
|
ctx._internal.flowControl = {}
|
||||||
|
}
|
||||||
|
ctx._internal.flowControl.abortController = abortController
|
||||||
|
ctx._internal.flowControl.abortSignal = abortSignal
|
||||||
|
ctx._internal.flowControl.cleanup = cleanup
|
||||||
|
|
||||||
|
const result = await next(ctx, params)
|
||||||
|
|
||||||
|
const error = new DOMException('Request was aborted', 'AbortError')
|
||||||
|
|
||||||
|
const streamWithAbortHandler = (result.stream as ReadableStream<Chunk>).pipeThrough(
|
||||||
|
new TransformStream<Chunk, Chunk | ErrorChunk>({
|
||||||
|
transform(chunk, controller) {
|
||||||
|
// 检查 abort 状态
|
||||||
|
if (abortSignal?.aborted) {
|
||||||
|
// 转换为 ErrorChunk
|
||||||
|
const errorChunk: ErrorChunk = {
|
||||||
|
type: ChunkType.ERROR,
|
||||||
|
error
|
||||||
|
}
|
||||||
|
|
||||||
|
controller.enqueue(errorChunk)
|
||||||
|
cleanup()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 正常传递 chunk
|
||||||
|
controller.enqueue(chunk)
|
||||||
|
},
|
||||||
|
|
||||||
|
flush(controller) {
|
||||||
|
// 在流结束时再次检查 abort 状态
|
||||||
|
if (abortSignal?.aborted) {
|
||||||
|
const errorChunk: ErrorChunk = {
|
||||||
|
type: ChunkType.ERROR,
|
||||||
|
error
|
||||||
|
}
|
||||||
|
controller.enqueue(errorChunk)
|
||||||
|
}
|
||||||
|
// 在流完全处理完成后清理 AbortController
|
||||||
|
cleanup()
|
||||||
|
}
|
||||||
|
})
|
||||||
|
)
|
||||||
|
|
||||||
|
return {
|
||||||
|
...result,
|
||||||
|
stream: streamWithAbortHandler
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,60 @@
|
|||||||
|
import { Chunk } from '@renderer/types/chunk'
|
||||||
|
import { isAbortError } from '@renderer/utils/error'
|
||||||
|
|
||||||
|
import { CompletionsResult } from '../schemas'
|
||||||
|
import { CompletionsContext } from '../types'
|
||||||
|
import { createErrorChunk } from '../utils'
|
||||||
|
|
||||||
|
export const MIDDLEWARE_NAME = 'ErrorHandlerMiddleware'
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 创建一个错误处理中间件。
|
||||||
|
*
|
||||||
|
* 这是一个高阶函数,它接收配置并返回一个标准的中间件。
|
||||||
|
* 它的主要职责是捕获下游中间件或API调用中发生的任何错误。
|
||||||
|
*
|
||||||
|
* @param config - 中间件的配置。
|
||||||
|
* @returns 一个配置好的CompletionsMiddleware。
|
||||||
|
*/
|
||||||
|
export const ErrorHandlerMiddleware =
|
||||||
|
() =>
|
||||||
|
(next) =>
|
||||||
|
async (ctx: CompletionsContext, params): Promise<CompletionsResult> => {
|
||||||
|
const { shouldThrow } = params
|
||||||
|
|
||||||
|
try {
|
||||||
|
// 尝试执行下一个中间件
|
||||||
|
return await next(ctx, params)
|
||||||
|
} catch (error: any) {
|
||||||
|
let errorStream: ReadableStream<Chunk> | undefined
|
||||||
|
// 有些sdk的abort error 是直接抛出的
|
||||||
|
if (!isAbortError(error)) {
|
||||||
|
// 1. 使用通用的工具函数将错误解析为标准格式
|
||||||
|
const errorChunk = createErrorChunk(error)
|
||||||
|
// 2. 调用从外部传入的 onError 回调
|
||||||
|
if (params.onError) {
|
||||||
|
params.onError(error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 3. 根据配置决定是重新抛出错误,还是将其作为流的一部分向下传递
|
||||||
|
if (shouldThrow) {
|
||||||
|
throw error
|
||||||
|
}
|
||||||
|
|
||||||
|
// 如果不抛出,则创建一个只包含该错误块的流并向下传递
|
||||||
|
errorStream = new ReadableStream<Chunk>({
|
||||||
|
start(controller) {
|
||||||
|
controller.enqueue(errorChunk)
|
||||||
|
controller.close()
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return {
|
||||||
|
rawOutput: undefined,
|
||||||
|
stream: errorStream, // 将包含错误的流传递下去
|
||||||
|
controller: undefined,
|
||||||
|
getText: () => '' // 错误情况下没有文本结果
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,183 @@
|
|||||||
|
import Logger from '@renderer/config/logger'
|
||||||
|
import { Usage } from '@renderer/types'
|
||||||
|
import type { Chunk } from '@renderer/types/chunk'
|
||||||
|
import { ChunkType } from '@renderer/types/chunk'
|
||||||
|
|
||||||
|
import { CompletionsParams, CompletionsResult, GenericChunk } from '../schemas'
|
||||||
|
import { CompletionsContext, CompletionsMiddleware } from '../types'
|
||||||
|
|
||||||
|
export const MIDDLEWARE_NAME = 'FinalChunkConsumerAndNotifierMiddleware'
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 最终Chunk消费和通知中间件
|
||||||
|
*
|
||||||
|
* 职责:
|
||||||
|
* 1. 消费所有GenericChunk流中的chunks并转发给onChunk回调
|
||||||
|
* 2. 累加usage/metrics数据(从原始SDK chunks或GenericChunk中提取)
|
||||||
|
* 3. 在检测到LLM_RESPONSE_COMPLETE时发送包含累计数据的BLOCK_COMPLETE
|
||||||
|
* 4. 处理MCP工具调用的多轮请求中的数据累加
|
||||||
|
*/
|
||||||
|
const FinalChunkConsumerMiddleware: CompletionsMiddleware =
|
||||||
|
() =>
|
||||||
|
(next) =>
|
||||||
|
async (ctx: CompletionsContext, params: CompletionsParams): Promise<CompletionsResult> => {
|
||||||
|
const isRecursiveCall =
|
||||||
|
params._internal?.toolProcessingState?.isRecursiveCall ||
|
||||||
|
ctx._internal?.toolProcessingState?.isRecursiveCall ||
|
||||||
|
false
|
||||||
|
|
||||||
|
// 初始化累计数据(只在顶层调用时初始化)
|
||||||
|
if (!isRecursiveCall) {
|
||||||
|
if (!ctx._internal.customState) {
|
||||||
|
ctx._internal.customState = {}
|
||||||
|
}
|
||||||
|
ctx._internal.observer = {
|
||||||
|
usage: {
|
||||||
|
prompt_tokens: 0,
|
||||||
|
completion_tokens: 0,
|
||||||
|
total_tokens: 0
|
||||||
|
},
|
||||||
|
metrics: {
|
||||||
|
completion_tokens: 0,
|
||||||
|
time_completion_millsec: 0,
|
||||||
|
time_first_token_millsec: 0,
|
||||||
|
time_thinking_millsec: 0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// 初始化文本累积器
|
||||||
|
ctx._internal.customState.accumulatedText = ''
|
||||||
|
ctx._internal.customState.startTimestamp = Date.now()
|
||||||
|
}
|
||||||
|
|
||||||
|
// 调用下游中间件
|
||||||
|
const result = await next(ctx, params)
|
||||||
|
|
||||||
|
// 响应后处理:处理GenericChunk流式响应
|
||||||
|
if (result.stream) {
|
||||||
|
const resultFromUpstream = result.stream
|
||||||
|
|
||||||
|
if (resultFromUpstream && resultFromUpstream instanceof ReadableStream) {
|
||||||
|
const reader = resultFromUpstream.getReader()
|
||||||
|
|
||||||
|
try {
|
||||||
|
while (true) {
|
||||||
|
const { done, value: chunk } = await reader.read()
|
||||||
|
if (done) {
|
||||||
|
Logger.debug(`[${MIDDLEWARE_NAME}] Input stream finished.`)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
if (chunk) {
|
||||||
|
const genericChunk = chunk as GenericChunk
|
||||||
|
// 提取并累加usage/metrics数据
|
||||||
|
extractAndAccumulateUsageMetrics(ctx, genericChunk)
|
||||||
|
|
||||||
|
const shouldSkipChunk =
|
||||||
|
isRecursiveCall &&
|
||||||
|
(genericChunk.type === ChunkType.BLOCK_COMPLETE ||
|
||||||
|
genericChunk.type === ChunkType.LLM_RESPONSE_COMPLETE)
|
||||||
|
|
||||||
|
if (!shouldSkipChunk) params.onChunk?.(genericChunk)
|
||||||
|
} else {
|
||||||
|
Logger.warn(`[${MIDDLEWARE_NAME}] Received undefined chunk before stream was done.`)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} catch (error) {
|
||||||
|
Logger.error(`[${MIDDLEWARE_NAME}] Error consuming stream:`, error)
|
||||||
|
throw error
|
||||||
|
} finally {
|
||||||
|
if (params.onChunk && !isRecursiveCall) {
|
||||||
|
params.onChunk({
|
||||||
|
type: ChunkType.BLOCK_COMPLETE,
|
||||||
|
response: {
|
||||||
|
usage: ctx._internal.observer?.usage ? { ...ctx._internal.observer.usage } : undefined,
|
||||||
|
metrics: ctx._internal.observer?.metrics ? { ...ctx._internal.observer.metrics } : undefined
|
||||||
|
}
|
||||||
|
} as Chunk)
|
||||||
|
if (ctx._internal.toolProcessingState) {
|
||||||
|
ctx._internal.toolProcessingState = {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 为流式输出添加getText方法
|
||||||
|
const modifiedResult = {
|
||||||
|
...result,
|
||||||
|
stream: new ReadableStream<GenericChunk>({
|
||||||
|
start(controller) {
|
||||||
|
controller.close()
|
||||||
|
}
|
||||||
|
}),
|
||||||
|
getText: () => {
|
||||||
|
return ctx._internal.customState?.accumulatedText || ''
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return modifiedResult
|
||||||
|
} else {
|
||||||
|
Logger.debug(`[${MIDDLEWARE_NAME}] No GenericChunk stream to process.`)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 从GenericChunk或原始SDK chunks中提取usage/metrics数据并累加
|
||||||
|
*/
|
||||||
|
function extractAndAccumulateUsageMetrics(ctx: CompletionsContext, chunk: GenericChunk): void {
|
||||||
|
if (!ctx._internal.observer?.usage || !ctx._internal.observer?.metrics) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
try {
|
||||||
|
if (ctx._internal.customState && !ctx._internal.customState?.firstTokenTimestamp) {
|
||||||
|
ctx._internal.customState.firstTokenTimestamp = Date.now()
|
||||||
|
Logger.debug(`[${MIDDLEWARE_NAME}] First token timestamp: ${ctx._internal.customState.firstTokenTimestamp}`)
|
||||||
|
}
|
||||||
|
if (chunk.type === ChunkType.LLM_RESPONSE_COMPLETE) {
|
||||||
|
Logger.debug(`[${MIDDLEWARE_NAME}] LLM_RESPONSE_COMPLETE chunk received:`, ctx._internal)
|
||||||
|
// 从LLM_RESPONSE_COMPLETE chunk中提取usage数据
|
||||||
|
if (chunk.response?.usage) {
|
||||||
|
accumulateUsage(ctx._internal.observer.usage, chunk.response.usage)
|
||||||
|
}
|
||||||
|
|
||||||
|
if (ctx._internal.customState && ctx._internal.customState?.firstTokenTimestamp) {
|
||||||
|
ctx._internal.observer.metrics.time_first_token_millsec =
|
||||||
|
ctx._internal.customState.firstTokenTimestamp - ctx._internal.customState.startTimestamp
|
||||||
|
ctx._internal.observer.metrics.time_completion_millsec +=
|
||||||
|
Date.now() - ctx._internal.customState.firstTokenTimestamp
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// 也可以从其他chunk类型中提取metrics数据
|
||||||
|
if (chunk.type === ChunkType.THINKING_COMPLETE && chunk.thinking_millsec && ctx._internal.observer?.metrics) {
|
||||||
|
ctx._internal.observer.metrics.time_thinking_millsec = Math.max(
|
||||||
|
ctx._internal.observer.metrics.time_thinking_millsec || 0,
|
||||||
|
chunk.thinking_millsec
|
||||||
|
)
|
||||||
|
}
|
||||||
|
} catch (error) {
|
||||||
|
console.error(`[${MIDDLEWARE_NAME}] Error extracting usage/metrics from chunk:`, error)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 累加usage数据
|
||||||
|
*/
|
||||||
|
function accumulateUsage(accumulated: Usage, newUsage: Usage): void {
|
||||||
|
if (newUsage.prompt_tokens !== undefined) {
|
||||||
|
accumulated.prompt_tokens += newUsage.prompt_tokens
|
||||||
|
}
|
||||||
|
if (newUsage.completion_tokens !== undefined) {
|
||||||
|
accumulated.completion_tokens += newUsage.completion_tokens
|
||||||
|
}
|
||||||
|
if (newUsage.total_tokens !== undefined) {
|
||||||
|
accumulated.total_tokens += newUsage.total_tokens
|
||||||
|
}
|
||||||
|
if (newUsage.thoughts_tokens !== undefined) {
|
||||||
|
accumulated.thoughts_tokens = (accumulated.thoughts_tokens || 0) + newUsage.thoughts_tokens
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
export default FinalChunkConsumerMiddleware
|
||||||
@@ -0,0 +1,64 @@
|
|||||||
|
import { BaseContext, MethodMiddleware, MiddlewareAPI } from '../types'
|
||||||
|
|
||||||
|
export const MIDDLEWARE_NAME = 'GenericLoggingMiddleware'
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Helper function to safely stringify arguments for logging, handling circular references and large objects.
|
||||||
|
* 安全地字符串化日志参数的辅助函数,处理循环引用和大型对象。
|
||||||
|
* @param args - The arguments array to stringify. 要字符串化的参数数组。
|
||||||
|
* @returns A string representation of the arguments. 参数的字符串表示形式。
|
||||||
|
*/
|
||||||
|
const stringifyArgsForLogging = (args: any[]): string => {
|
||||||
|
try {
|
||||||
|
return args
|
||||||
|
.map((arg) => {
|
||||||
|
if (typeof arg === 'function') return '[Function]'
|
||||||
|
if (typeof arg === 'object' && arg !== null && arg.constructor === Object && Object.keys(arg).length > 20) {
|
||||||
|
return '[Object with >20 keys]'
|
||||||
|
}
|
||||||
|
// Truncate long strings to avoid flooding logs 截断长字符串以避免日志泛滥
|
||||||
|
const stringifiedArg = JSON.stringify(arg, null, 2)
|
||||||
|
return stringifiedArg && stringifiedArg.length > 200 ? stringifiedArg.substring(0, 200) + '...' : stringifiedArg
|
||||||
|
})
|
||||||
|
.join(', ')
|
||||||
|
} catch (e) {
|
||||||
|
return '[Error serializing arguments]' // Handle potential errors during stringification 处理字符串化期间的潜在错误
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Generic logging middleware for provider methods.
|
||||||
|
* 为提供者方法创建一个通用的日志中间件。
|
||||||
|
* This middleware logs the initiation, success/failure, and duration of a method call.
|
||||||
|
* 此中间件记录方法调用的启动、成功/失败以及持续时间。
|
||||||
|
*/
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Creates a generic logging middleware for provider methods.
|
||||||
|
* 为提供者方法创建一个通用的日志中间件。
|
||||||
|
* @returns A `MethodMiddleware` instance. 一个 `MethodMiddleware` 实例。
|
||||||
|
*/
|
||||||
|
export const createGenericLoggingMiddleware: () => MethodMiddleware = () => {
|
||||||
|
const middlewareName = 'GenericLoggingMiddleware'
|
||||||
|
// eslint-disable-next-line @typescript-eslint/no-unused-vars
|
||||||
|
return (_: MiddlewareAPI<BaseContext, any[]>) => (next) => async (ctx, args) => {
|
||||||
|
const methodName = ctx.methodName
|
||||||
|
const logPrefix = `[${middlewareName} (${methodName})]`
|
||||||
|
console.log(`${logPrefix} Initiating. Args:`, stringifyArgsForLogging(args))
|
||||||
|
const startTime = Date.now()
|
||||||
|
try {
|
||||||
|
const result = await next(ctx, args)
|
||||||
|
const duration = Date.now() - startTime
|
||||||
|
// Log successful completion of the method call with duration. /
|
||||||
|
// 记录方法调用成功完成及其持续时间。
|
||||||
|
console.log(`${logPrefix} Successful. Duration: ${duration}ms`)
|
||||||
|
return result
|
||||||
|
} catch (error) {
|
||||||
|
const duration = Date.now() - startTime
|
||||||
|
// Log failure of the method call with duration and error information. /
|
||||||
|
// 记录方法调用失败及其持续时间和错误信息。
|
||||||
|
console.error(`${logPrefix} Failed. Duration: ${duration}ms`, error)
|
||||||
|
throw error // Re-throw the error to be handled by subsequent layers or the caller / 重新抛出错误,由后续层或调用者处理
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
285
src/renderer/src/aiCore/middleware/composer.ts
Normal file
@@ -0,0 +1,285 @@
|
|||||||
|
import {
|
||||||
|
RequestOptions,
|
||||||
|
SdkInstance,
|
||||||
|
SdkMessageParam,
|
||||||
|
SdkParams,
|
||||||
|
SdkRawChunk,
|
||||||
|
SdkRawOutput,
|
||||||
|
SdkTool,
|
||||||
|
SdkToolCall
|
||||||
|
} from '@renderer/types/sdk'
|
||||||
|
|
||||||
|
import { BaseApiClient } from '../clients'
|
||||||
|
import { CompletionsParams, CompletionsResult } from './schemas'
|
||||||
|
import {
|
||||||
|
BaseContext,
|
||||||
|
CompletionsContext,
|
||||||
|
CompletionsMiddleware,
|
||||||
|
MethodMiddleware,
|
||||||
|
MIDDLEWARE_CONTEXT_SYMBOL,
|
||||||
|
MiddlewareAPI
|
||||||
|
} from './types'
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Creates the initial context for a method call, populating method-specific fields. /
|
||||||
|
* 为方法调用创建初始上下文,并填充特定于该方法的字段。
|
||||||
|
* @param methodName - The name of the method being called. / 被调用的方法名。
|
||||||
|
* @param originalCallArgs - The actual arguments array from the proxy/method call. / 代理/方法调用的实际参数数组。
|
||||||
|
* @param providerId - The ID of the provider, if available. / 提供者的ID(如果可用)。
|
||||||
|
* @param providerInstance - The instance of the provider. / 提供者实例。
|
||||||
|
* @param specificContextFactory - An optional factory function to create a specific context type from the base context and original call arguments. / 一个可选的工厂函数,用于从基础上下文和原始调用参数创建特定的上下文类型。
|
||||||
|
* @returns The created context object. / 创建的上下文对象。
|
||||||
|
*/
|
||||||
|
function createInitialCallContext<TContext extends BaseContext, TCallArgs extends unknown[]>(
|
||||||
|
methodName: string,
|
||||||
|
originalCallArgs: TCallArgs, // Renamed from originalArgs to avoid confusion with context.originalArgs
|
||||||
|
// Factory to create specific context from base and the *original call arguments array*
|
||||||
|
specificContextFactory?: (base: BaseContext, callArgs: TCallArgs) => TContext
|
||||||
|
): TContext {
|
||||||
|
const baseContext: BaseContext = {
|
||||||
|
[MIDDLEWARE_CONTEXT_SYMBOL]: true,
|
||||||
|
methodName,
|
||||||
|
originalArgs: originalCallArgs // Store the full original arguments array in the context
|
||||||
|
}
|
||||||
|
|
||||||
|
if (specificContextFactory) {
|
||||||
|
return specificContextFactory(baseContext, originalCallArgs)
|
||||||
|
}
|
||||||
|
return baseContext as TContext // Fallback to base context if no specific factory
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Composes an array of functions from right to left. /
|
||||||
|
* 从右到左组合一个函数数组。
|
||||||
|
* `compose(f, g, h)` is `(...args) => f(g(h(...args)))`. /
|
||||||
|
* `compose(f, g, h)` 等同于 `(...args) => f(g(h(...args)))`。
|
||||||
|
* Each function in funcs is expected to take the result of the next function
|
||||||
|
* (or the initial value for the rightmost function) as its argument. /
|
||||||
|
* `funcs` 中的每个函数都期望接收下一个函数的结果(或最右侧函数的初始值)作为其参数。
|
||||||
|
* @param funcs - Array of functions to compose. / 要组合的函数数组。
|
||||||
|
* @returns The composed function. / 组合后的函数。
|
||||||
|
*/
|
||||||
|
function compose(...funcs: Array<(...args: any[]) => any>): (...args: any[]) => any {
|
||||||
|
if (funcs.length === 0) {
|
||||||
|
// If no functions to compose, return a function that returns its first argument, or undefined if no args. /
|
||||||
|
// 如果没有要组合的函数,则返回一个函数,该函数返回其第一个参数,如果没有参数则返回undefined。
|
||||||
|
return (...args: any[]) => (args.length > 0 ? args[0] : undefined)
|
||||||
|
}
|
||||||
|
if (funcs.length === 1) {
|
||||||
|
return funcs[0]
|
||||||
|
}
|
||||||
|
return funcs.reduce(
|
||||||
|
(a, b) =>
|
||||||
|
(...args: any[]) =>
|
||||||
|
a(b(...args))
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Applies an array of Redux-style middlewares to a generic provider method. /
|
||||||
|
* 将一组Redux风格的中间件应用于一个通用的提供者方法。
|
||||||
|
* This version keeps arguments as an array throughout the middleware chain. /
|
||||||
|
* 此版本在整个中间件链中将参数保持为数组形式。
|
||||||
|
* @param originalProviderInstance - The original provider instance. / 原始提供者实例。
|
||||||
|
* @param methodName - The name of the method to be enhanced. / 需要增强的方法名。
|
||||||
|
* @param originalMethod - The original method to be wrapped. / 需要包装的原始方法。
|
||||||
|
* @param middlewares - An array of `ProviderMethodMiddleware` to apply. / 要应用的 `ProviderMethodMiddleware` 数组。
|
||||||
|
* @param specificContextFactory - An optional factory to create a specific context for this method. / 可选的工厂函数,用于为此方法创建特定的上下文。
|
||||||
|
* @returns An enhanced method with the middlewares applied. / 应用了中间件的增强方法。
|
||||||
|
*/
|
||||||
|
export function applyMethodMiddlewares<
|
||||||
|
TArgs extends unknown[] = unknown[], // Original method's arguments array type / 原始方法的参数数组类型
|
||||||
|
TResult = unknown,
|
||||||
|
TContext extends BaseContext = BaseContext
|
||||||
|
>(
|
||||||
|
methodName: string,
|
||||||
|
originalMethod: (...args: TArgs) => Promise<TResult>,
|
||||||
|
middlewares: MethodMiddleware[], // Expects generic middlewares / 期望通用中间件
|
||||||
|
specificContextFactory?: (base: BaseContext, callArgs: TArgs) => TContext
|
||||||
|
): (...args: TArgs) => Promise<TResult> {
|
||||||
|
// Returns a function matching the original method signature. /
|
||||||
|
// 返回一个与原始方法签名匹配的函数。
|
||||||
|
return async function enhancedMethod(...methodCallArgs: TArgs): Promise<TResult> {
|
||||||
|
const ctx = createInitialCallContext<TContext, TArgs>(
|
||||||
|
methodName,
|
||||||
|
methodCallArgs, // Pass the actual call arguments array / 传递实际的调用参数数组
|
||||||
|
specificContextFactory
|
||||||
|
)
|
||||||
|
|
||||||
|
const api: MiddlewareAPI<TContext, TArgs> = {
|
||||||
|
getContext: () => ctx,
|
||||||
|
getOriginalArgs: () => methodCallArgs // API provides the original arguments array / API提供原始参数数组
|
||||||
|
}
|
||||||
|
|
||||||
|
// `finalDispatch` is the function that will ultimately call the original provider method. /
|
||||||
|
// `finalDispatch` 是最终将调用原始提供者方法的函数。
|
||||||
|
// It receives the current context and arguments, which may have been transformed by middlewares. /
|
||||||
|
// 它接收当前的上下文和参数,这些参数可能已被中间件转换。
|
||||||
|
const finalDispatch = async (
|
||||||
|
_: TContext,
|
||||||
|
currentArgs: TArgs // Generic final dispatch expects args array / 通用finalDispatch期望参数数组
|
||||||
|
): Promise<TResult> => {
|
||||||
|
return originalMethod(...currentArgs)
|
||||||
|
}
|
||||||
|
|
||||||
|
const chain = middlewares.map((middleware) => middleware(api)) // Cast API if TContext/TArgs mismatch general ProviderMethodMiddleware / 如果TContext/TArgs与通用的ProviderMethodMiddleware不匹配,则转换API
|
||||||
|
const composedMiddlewareLogic = compose(...chain)
|
||||||
|
const enhancedDispatch = composedMiddlewareLogic(finalDispatch)
|
||||||
|
|
||||||
|
return enhancedDispatch(ctx, methodCallArgs) // Pass context and original args array / 传递上下文和原始参数数组
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Applies an array of `CompletionsMiddleware` to the `completions` method. /
|
||||||
|
* 将一组 `CompletionsMiddleware` 应用于 `completions` 方法。
|
||||||
|
* This version adapts for `CompletionsMiddleware` expecting a single `params` object. /
|
||||||
|
* 此版本适配了期望单个 `params` 对象的 `CompletionsMiddleware`。
|
||||||
|
* @param originalProviderInstance - The original provider instance. / 原始提供者实例。
|
||||||
|
* @param originalCompletionsMethod - The original SDK `createCompletions` method. / 原始的 SDK `createCompletions` 方法。
|
||||||
|
* @param middlewares - An array of `CompletionsMiddleware` to apply. / 要应用的 `CompletionsMiddleware` 数组。
|
||||||
|
* @returns An enhanced `completions` method with the middlewares applied. / 应用了中间件的增强版 `completions` 方法。
|
||||||
|
*/
|
||||||
|
export function applyCompletionsMiddlewares<
|
||||||
|
TSdkInstance extends SdkInstance = SdkInstance,
|
||||||
|
TSdkParams extends SdkParams = SdkParams,
|
||||||
|
TRawOutput extends SdkRawOutput = SdkRawOutput,
|
||||||
|
TRawChunk extends SdkRawChunk = SdkRawChunk,
|
||||||
|
TMessageParam extends SdkMessageParam = SdkMessageParam,
|
||||||
|
TToolCall extends SdkToolCall = SdkToolCall,
|
||||||
|
TSdkSpecificTool extends SdkTool = SdkTool
|
||||||
|
>(
|
||||||
|
originalApiClientInstance: BaseApiClient<
|
||||||
|
TSdkInstance,
|
||||||
|
TSdkParams,
|
||||||
|
TRawOutput,
|
||||||
|
TRawChunk,
|
||||||
|
TMessageParam,
|
||||||
|
TToolCall,
|
||||||
|
TSdkSpecificTool
|
||||||
|
>,
|
||||||
|
originalCompletionsMethod: (payload: TSdkParams, options?: RequestOptions) => Promise<TRawOutput>,
|
||||||
|
middlewares: CompletionsMiddleware<
|
||||||
|
TSdkParams,
|
||||||
|
TMessageParam,
|
||||||
|
TToolCall,
|
||||||
|
TSdkInstance,
|
||||||
|
TRawOutput,
|
||||||
|
TRawChunk,
|
||||||
|
TSdkSpecificTool
|
||||||
|
>[]
|
||||||
|
): (params: CompletionsParams, options?: RequestOptions) => Promise<CompletionsResult> {
|
||||||
|
// Returns a function matching the original method signature. /
|
||||||
|
// 返回一个与原始方法签名匹配的函数。
|
||||||
|
|
||||||
|
const methodName = 'completions'
|
||||||
|
|
||||||
|
// Factory to create AiProviderMiddlewareCompletionsContext. /
|
||||||
|
// 用于创建 AiProviderMiddlewareCompletionsContext 的工厂函数。
|
||||||
|
const completionsContextFactory = (
|
||||||
|
base: BaseContext,
|
||||||
|
callArgs: [CompletionsParams]
|
||||||
|
): CompletionsContext<
|
||||||
|
TSdkParams,
|
||||||
|
TMessageParam,
|
||||||
|
TToolCall,
|
||||||
|
TSdkInstance,
|
||||||
|
TRawOutput,
|
||||||
|
TRawChunk,
|
||||||
|
TSdkSpecificTool
|
||||||
|
> => {
|
||||||
|
return {
|
||||||
|
...base,
|
||||||
|
methodName,
|
||||||
|
apiClientInstance: originalApiClientInstance,
|
||||||
|
originalArgs: callArgs,
|
||||||
|
_internal: {
|
||||||
|
toolProcessingState: {
|
||||||
|
recursionDepth: 0,
|
||||||
|
isRecursiveCall: false
|
||||||
|
},
|
||||||
|
observer: {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return async function enhancedCompletionsMethod(
|
||||||
|
params: CompletionsParams,
|
||||||
|
options?: RequestOptions
|
||||||
|
): Promise<CompletionsResult> {
|
||||||
|
// `originalCallArgs` for context creation is `[params]`. /
|
||||||
|
// 用于上下文创建的 `originalCallArgs` 是 `[params]`。
|
||||||
|
const originalCallArgs: [CompletionsParams] = [params]
|
||||||
|
const baseContext: BaseContext = {
|
||||||
|
[MIDDLEWARE_CONTEXT_SYMBOL]: true,
|
||||||
|
methodName,
|
||||||
|
originalArgs: originalCallArgs
|
||||||
|
}
|
||||||
|
const ctx = completionsContextFactory(baseContext, originalCallArgs)
|
||||||
|
|
||||||
|
const api: MiddlewareAPI<
|
||||||
|
CompletionsContext<TSdkParams, TMessageParam, TToolCall, TSdkInstance, TRawOutput, TRawChunk, TSdkSpecificTool>,
|
||||||
|
[CompletionsParams]
|
||||||
|
> = {
|
||||||
|
getContext: () => ctx,
|
||||||
|
getOriginalArgs: () => originalCallArgs // API provides [CompletionsParams] / API提供 `[CompletionsParams]`
|
||||||
|
}
|
||||||
|
|
||||||
|
// `finalDispatch` for CompletionsMiddleware: expects (context, params) not (context, args_array). /
|
||||||
|
// `CompletionsMiddleware` 的 `finalDispatch`:期望 (context, params) 而不是 (context, args_array)。
|
||||||
|
const finalDispatch = async (
|
||||||
|
context: CompletionsContext<
|
||||||
|
TSdkParams,
|
||||||
|
TMessageParam,
|
||||||
|
TToolCall,
|
||||||
|
TSdkInstance,
|
||||||
|
TRawOutput,
|
||||||
|
TRawChunk,
|
||||||
|
TSdkSpecificTool
|
||||||
|
> // Context passed through / 上下文透传
|
||||||
|
// _currentParams: CompletionsParams // Directly takes params / 直接接收参数 (unused but required for middleware signature)
|
||||||
|
): Promise<CompletionsResult> => {
|
||||||
|
// At this point, middleware should have transformed CompletionsParams to SDK params
|
||||||
|
// and stored them in context. If no transformation happened, we need to handle it.
|
||||||
|
// 此时,中间件应该已经将 CompletionsParams 转换为 SDK 参数并存储在上下文中。
|
||||||
|
// 如果没有进行转换,我们需要处理它。
|
||||||
|
|
||||||
|
const sdkPayload = context._internal?.sdkPayload
|
||||||
|
if (!sdkPayload) {
|
||||||
|
throw new Error('SDK payload not found in context. Middleware chain should have transformed parameters.')
|
||||||
|
}
|
||||||
|
|
||||||
|
const abortSignal = context._internal.flowControl?.abortSignal
|
||||||
|
const timeout = context._internal.customState?.sdkMetadata?.timeout
|
||||||
|
|
||||||
|
// Call the original SDK method with transformed parameters
|
||||||
|
// 使用转换后的参数调用原始 SDK 方法
|
||||||
|
const rawOutput = await originalCompletionsMethod.call(originalApiClientInstance, sdkPayload, {
|
||||||
|
...options,
|
||||||
|
signal: abortSignal,
|
||||||
|
timeout
|
||||||
|
})
|
||||||
|
|
||||||
|
// Return result wrapped in CompletionsResult format
|
||||||
|
// 以 CompletionsResult 格式返回包装的结果
|
||||||
|
return {
|
||||||
|
rawOutput
|
||||||
|
} as CompletionsResult
|
||||||
|
}
|
||||||
|
|
||||||
|
const chain = middlewares.map((middleware) => middleware(api))
|
||||||
|
const composedMiddlewareLogic = compose(...chain)
|
||||||
|
|
||||||
|
// `enhancedDispatch` has the signature `(context, params) => Promise<CompletionsResult>`. /
|
||||||
|
// `enhancedDispatch` 的签名为 `(context, params) => Promise<CompletionsResult>`。
|
||||||
|
const enhancedDispatch = composedMiddlewareLogic(finalDispatch)
|
||||||
|
|
||||||
|
// 将 enhancedDispatch 保存到 context 中,供中间件进行递归调用
|
||||||
|
// 这样可以避免重复执行整个中间件链
|
||||||
|
ctx._internal.enhancedDispatch = enhancedDispatch
|
||||||
|
|
||||||
|
// Execute with context and the single params object. /
|
||||||
|
// 使用上下文和单个参数对象执行。
|
||||||
|
return enhancedDispatch(ctx, params)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,306 @@
|
|||||||
|
import Logger from '@renderer/config/logger'
|
||||||
|
import { MCPTool, MCPToolResponse, Model, ToolCallResponse } from '@renderer/types'
|
||||||
|
import { ChunkType, MCPToolCreatedChunk } from '@renderer/types/chunk'
|
||||||
|
import { SdkMessageParam, SdkRawOutput, SdkToolCall } from '@renderer/types/sdk'
|
||||||
|
import { parseAndCallTools } from '@renderer/utils/mcp-tools'
|
||||||
|
|
||||||
|
import { CompletionsParams, CompletionsResult, GenericChunk } from '../schemas'
|
||||||
|
import { CompletionsContext, CompletionsMiddleware } from '../types'
|
||||||
|
|
||||||
|
export const MIDDLEWARE_NAME = 'McpToolChunkMiddleware'
|
||||||
|
const MAX_TOOL_RECURSION_DEPTH = 20 // 防止无限递归
|
||||||
|
|
||||||
|
/**
|
||||||
|
* MCP工具处理中间件
|
||||||
|
*
|
||||||
|
* 职责:
|
||||||
|
* 1. 检测并拦截MCP工具进展chunk(Function Call方式和Tool Use方式)
|
||||||
|
* 2. 执行工具调用
|
||||||
|
* 3. 递归处理工具结果
|
||||||
|
* 4. 管理工具调用状态和递归深度
|
||||||
|
*/
|
||||||
|
export const McpToolChunkMiddleware: CompletionsMiddleware =
|
||||||
|
() =>
|
||||||
|
(next) =>
|
||||||
|
async (ctx: CompletionsContext, params: CompletionsParams): Promise<CompletionsResult> => {
|
||||||
|
const mcpTools = params.mcpTools || []
|
||||||
|
|
||||||
|
// 如果没有工具,直接调用下一个中间件
|
||||||
|
if (!mcpTools || mcpTools.length === 0) {
|
||||||
|
return next(ctx, params)
|
||||||
|
}
|
||||||
|
|
||||||
|
const executeWithToolHandling = async (currentParams: CompletionsParams, depth = 0): Promise<CompletionsResult> => {
|
||||||
|
if (depth >= MAX_TOOL_RECURSION_DEPTH) {
|
||||||
|
Logger.error(`🔧 [${MIDDLEWARE_NAME}] Maximum recursion depth ${MAX_TOOL_RECURSION_DEPTH} exceeded`)
|
||||||
|
throw new Error(`Maximum tool recursion depth ${MAX_TOOL_RECURSION_DEPTH} exceeded`)
|
||||||
|
}
|
||||||
|
|
||||||
|
let result: CompletionsResult
|
||||||
|
|
||||||
|
if (depth === 0) {
|
||||||
|
result = await next(ctx, currentParams)
|
||||||
|
} else {
|
||||||
|
const enhancedCompletions = ctx._internal.enhancedDispatch
|
||||||
|
if (!enhancedCompletions) {
|
||||||
|
Logger.error(`🔧 [${MIDDLEWARE_NAME}] Enhanced completions method not found, cannot perform recursive call`)
|
||||||
|
throw new Error('Enhanced completions method not found')
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx._internal.toolProcessingState!.isRecursiveCall = true
|
||||||
|
ctx._internal.toolProcessingState!.recursionDepth = depth
|
||||||
|
|
||||||
|
result = await enhancedCompletions(ctx, currentParams)
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!result.stream) {
|
||||||
|
Logger.error(`🔧 [${MIDDLEWARE_NAME}] No stream returned from enhanced completions`)
|
||||||
|
throw new Error('No stream returned from enhanced completions')
|
||||||
|
}
|
||||||
|
|
||||||
|
const resultFromUpstream = result.stream as ReadableStream<GenericChunk>
|
||||||
|
const toolHandlingStream = resultFromUpstream.pipeThrough(
|
||||||
|
createToolHandlingTransform(ctx, currentParams, mcpTools, depth, executeWithToolHandling)
|
||||||
|
)
|
||||||
|
|
||||||
|
return {
|
||||||
|
...result,
|
||||||
|
stream: toolHandlingStream
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return executeWithToolHandling(params, 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 创建工具处理的 TransformStream
|
||||||
|
*/
|
||||||
|
function createToolHandlingTransform(
|
||||||
|
ctx: CompletionsContext,
|
||||||
|
currentParams: CompletionsParams,
|
||||||
|
mcpTools: MCPTool[],
|
||||||
|
depth: number,
|
||||||
|
executeWithToolHandling: (params: CompletionsParams, depth: number) => Promise<CompletionsResult>
|
||||||
|
): TransformStream<GenericChunk, GenericChunk> {
|
||||||
|
const toolCalls: SdkToolCall[] = []
|
||||||
|
const toolUseResponses: MCPToolResponse[] = []
|
||||||
|
const allToolResponses: MCPToolResponse[] = [] // 统一的工具响应状态管理数组
|
||||||
|
let hasToolCalls = false
|
||||||
|
let hasToolUseResponses = false
|
||||||
|
let streamEnded = false
|
||||||
|
|
||||||
|
return new TransformStream({
|
||||||
|
async transform(chunk: GenericChunk, controller) {
|
||||||
|
try {
|
||||||
|
// 处理MCP工具进展chunk
|
||||||
|
if (chunk.type === ChunkType.MCP_TOOL_CREATED) {
|
||||||
|
const createdChunk = chunk as MCPToolCreatedChunk
|
||||||
|
|
||||||
|
// 1. 处理Function Call方式的工具调用
|
||||||
|
if (createdChunk.tool_calls && createdChunk.tool_calls.length > 0) {
|
||||||
|
toolCalls.push(...createdChunk.tool_calls)
|
||||||
|
hasToolCalls = true
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2. 处理Tool Use方式的工具调用
|
||||||
|
if (createdChunk.tool_use_responses && createdChunk.tool_use_responses.length > 0) {
|
||||||
|
toolUseResponses.push(...createdChunk.tool_use_responses)
|
||||||
|
hasToolUseResponses = true
|
||||||
|
}
|
||||||
|
|
||||||
|
// 不转发MCP工具进展chunks,避免重复处理
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 转发其他所有chunk
|
||||||
|
controller.enqueue(chunk)
|
||||||
|
} catch (error) {
|
||||||
|
console.error(`🔧 [${MIDDLEWARE_NAME}] Error processing chunk:`, error)
|
||||||
|
controller.error(error)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
|
||||||
|
async flush(controller) {
|
||||||
|
const shouldExecuteToolCalls = hasToolCalls && toolCalls.length > 0
|
||||||
|
const shouldExecuteToolUseResponses = hasToolUseResponses && toolUseResponses.length > 0
|
||||||
|
|
||||||
|
if (!streamEnded && (shouldExecuteToolCalls || shouldExecuteToolUseResponses)) {
|
||||||
|
streamEnded = true
|
||||||
|
|
||||||
|
try {
|
||||||
|
let toolResult: SdkMessageParam[] = []
|
||||||
|
|
||||||
|
if (shouldExecuteToolCalls) {
|
||||||
|
toolResult = await executeToolCalls(
|
||||||
|
ctx,
|
||||||
|
toolCalls,
|
||||||
|
mcpTools,
|
||||||
|
allToolResponses,
|
||||||
|
currentParams.onChunk,
|
||||||
|
currentParams.assistant.model!
|
||||||
|
)
|
||||||
|
} else if (shouldExecuteToolUseResponses) {
|
||||||
|
toolResult = await executeToolUseResponses(
|
||||||
|
ctx,
|
||||||
|
toolUseResponses,
|
||||||
|
mcpTools,
|
||||||
|
allToolResponses,
|
||||||
|
currentParams.onChunk,
|
||||||
|
currentParams.assistant.model!
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
if (toolResult.length > 0) {
|
||||||
|
const output = ctx._internal.toolProcessingState?.output
|
||||||
|
|
||||||
|
const newParams = buildParamsWithToolResults(ctx, currentParams, output!, toolResult, toolCalls)
|
||||||
|
await executeWithToolHandling(newParams, depth + 1)
|
||||||
|
}
|
||||||
|
} catch (error) {
|
||||||
|
console.error(`🔧 [${MIDDLEWARE_NAME}] Error in tool processing:`, error)
|
||||||
|
controller.error(error)
|
||||||
|
} finally {
|
||||||
|
hasToolCalls = false
|
||||||
|
hasToolUseResponses = false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 执行工具调用(Function Call 方式)
|
||||||
|
*/
|
||||||
|
async function executeToolCalls(
|
||||||
|
ctx: CompletionsContext,
|
||||||
|
toolCalls: SdkToolCall[],
|
||||||
|
mcpTools: MCPTool[],
|
||||||
|
allToolResponses: MCPToolResponse[],
|
||||||
|
onChunk: CompletionsParams['onChunk'],
|
||||||
|
model: Model
|
||||||
|
): Promise<SdkMessageParam[]> {
|
||||||
|
// 转换为MCPToolResponse格式
|
||||||
|
const mcpToolResponses: ToolCallResponse[] = toolCalls
|
||||||
|
.map((toolCall) => {
|
||||||
|
const mcpTool = ctx.apiClientInstance.convertSdkToolCallToMcp(toolCall, mcpTools)
|
||||||
|
if (!mcpTool) {
|
||||||
|
return undefined
|
||||||
|
}
|
||||||
|
return ctx.apiClientInstance.convertSdkToolCallToMcpToolResponse(toolCall, mcpTool)
|
||||||
|
})
|
||||||
|
.filter((t): t is ToolCallResponse => typeof t !== 'undefined')
|
||||||
|
|
||||||
|
if (mcpToolResponses.length === 0) {
|
||||||
|
console.warn(`🔧 [${MIDDLEWARE_NAME}] No valid MCP tool responses to execute`)
|
||||||
|
return []
|
||||||
|
}
|
||||||
|
|
||||||
|
// 使用现有的parseAndCallTools函数执行工具
|
||||||
|
const toolResults = await parseAndCallTools(
|
||||||
|
mcpToolResponses,
|
||||||
|
allToolResponses,
|
||||||
|
onChunk,
|
||||||
|
(mcpToolResponse, resp, model) => {
|
||||||
|
return ctx.apiClientInstance.convertMcpToolResponseToSdkMessageParam(mcpToolResponse, resp, model)
|
||||||
|
},
|
||||||
|
model,
|
||||||
|
mcpTools
|
||||||
|
)
|
||||||
|
|
||||||
|
return toolResults
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 执行工具使用响应(Tool Use Response 方式)
|
||||||
|
* 处理已经解析好的 ToolUseResponse[],不需要重新解析字符串
|
||||||
|
*/
|
||||||
|
async function executeToolUseResponses(
|
||||||
|
ctx: CompletionsContext,
|
||||||
|
toolUseResponses: MCPToolResponse[],
|
||||||
|
mcpTools: MCPTool[],
|
||||||
|
allToolResponses: MCPToolResponse[],
|
||||||
|
onChunk: CompletionsParams['onChunk'],
|
||||||
|
model: Model
|
||||||
|
): Promise<SdkMessageParam[]> {
|
||||||
|
// 直接使用parseAndCallTools函数处理已经解析好的ToolUseResponse
|
||||||
|
const toolResults = await parseAndCallTools(
|
||||||
|
toolUseResponses,
|
||||||
|
allToolResponses,
|
||||||
|
onChunk,
|
||||||
|
(mcpToolResponse, resp, model) => {
|
||||||
|
return ctx.apiClientInstance.convertMcpToolResponseToSdkMessageParam(mcpToolResponse, resp, model)
|
||||||
|
},
|
||||||
|
model,
|
||||||
|
mcpTools
|
||||||
|
)
|
||||||
|
|
||||||
|
return toolResults
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* 构建包含工具结果的新参数
|
||||||
|
*/
|
||||||
|
function buildParamsWithToolResults(
|
||||||
|
ctx: CompletionsContext,
|
||||||
|
currentParams: CompletionsParams,
|
||||||
|
output: SdkRawOutput | string,
|
||||||
|
toolResults: SdkMessageParam[],
|
||||||
|
toolCalls: SdkToolCall[]
|
||||||
|
): CompletionsParams {
|
||||||
|
// 获取当前已经转换好的reqMessages,如果没有则使用原始messages
|
||||||
|
const currentReqMessages = getCurrentReqMessages(ctx)
|
||||||
|
|
||||||
|
const apiClient = ctx.apiClientInstance
|
||||||
|
|
||||||
|
// 从回复中构建助手消息
|
||||||
|
const newReqMessages = apiClient.buildSdkMessages(currentReqMessages, output, toolResults, toolCalls)
|
||||||
|
|
||||||
|
// Estimate the token cost of the newly added messages and add it to usage
|
||||||
|
if (ctx._internal.observer?.usage && newReqMessages.length > currentReqMessages.length) {
|
||||||
|
try {
|
||||||
|
const newMessages = newReqMessages.slice(currentReqMessages.length)
|
||||||
|
const additionalTokens = newMessages.reduce((acc, message) => {
|
||||||
|
return acc + ctx.apiClientInstance.estimateMessageTokens(message)
|
||||||
|
}, 0)
|
||||||
|
|
||||||
|
if (additionalTokens > 0) {
|
||||||
|
ctx._internal.observer.usage.prompt_tokens += additionalTokens
|
||||||
|
ctx._internal.observer.usage.total_tokens += additionalTokens
|
||||||
|
}
|
||||||
|
} catch (error) {
|
||||||
|
Logger.error(`🔧 [${MIDDLEWARE_NAME}] Error estimating token usage for new messages:`, error)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update the recursion state
|
||||||
|
if (!ctx._internal.toolProcessingState) {
|
||||||
|
ctx._internal.toolProcessingState = {}
|
||||||
|
}
|
||||||
|
ctx._internal.toolProcessingState.isRecursiveCall = true
|
||||||
|
ctx._internal.toolProcessingState.recursionDepth = (ctx._internal.toolProcessingState?.recursionDepth || 0) + 1
|
||||||
|
|
||||||
|
return {
|
||||||
|
...currentParams,
|
||||||
|
_internal: {
|
||||||
|
...ctx._internal,
|
||||||
|
sdkPayload: ctx._internal.sdkPayload,
|
||||||
|
newReqMessages: newReqMessages
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
* Type-safe access to the current request messages.
* Uses the abstraction exposed by the API client so the middleware stays provider-agnostic.
*/
|
||||||
|
function getCurrentReqMessages(ctx: CompletionsContext): SdkMessageParam[] {
|
||||||
|
const sdkPayload = ctx._internal.sdkPayload
|
||||||
|
if (!sdkPayload) {
|
||||||
|
return []
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract the messages via the API client abstraction to stay provider-agnostic
|
||||||
|
return ctx.apiClientInstance.extractMessagesFromSdkPayload(sdkPayload)
|
||||||
|
}
|
||||||
|
|
||||||
|
export default McpToolChunkMiddleware
|
||||||
@@ -0,0 +1,48 @@
|
|||||||
|
import { AnthropicAPIClient } from '@renderer/aiCore/clients/anthropic/AnthropicAPIClient'
|
||||||
|
import { AnthropicSdkRawChunk, AnthropicSdkRawOutput } from '@renderer/types/sdk'
|
||||||
|
|
||||||
|
import { AnthropicStreamListener } from '../../clients/types'
|
||||||
|
import { CompletionsParams, CompletionsResult } from '../schemas'
|
||||||
|
import { CompletionsContext, CompletionsMiddleware } from '../types'
|
||||||
|
|
||||||
|
export const MIDDLEWARE_NAME = 'RawStreamListenerMiddleware'
|
||||||
|
|
||||||
|
export const RawStreamListenerMiddleware: CompletionsMiddleware =
|
||||||
|
() =>
|
||||||
|
(next) =>
|
||||||
|
async (ctx: CompletionsContext, params: CompletionsParams): Promise<CompletionsResult> => {
|
||||||
|
const result = await next(ctx, params)
|
||||||
|
|
||||||
|
// This is where the rawest stream returned by the SDK can be observed
|
||||||
|
if (result.rawOutput) {
|
||||||
|
console.log(`[${MIDDLEWARE_NAME}] Raw SDK output detected, attaching raw stream listener`)
|
||||||
|
|
||||||
|
const providerType = ctx.apiClientInstance.provider.type
|
||||||
|
// TODO: move this down into AnthropicAPIClient later
|
||||||
|
if (providerType === 'anthropic') {
|
||||||
|
const anthropicListener: AnthropicStreamListener<AnthropicSdkRawChunk> = {
|
||||||
|
onMessage: (message) => {
|
||||||
|
if (ctx._internal?.toolProcessingState) {
|
||||||
|
ctx._internal.toolProcessingState.output = message
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// onContentBlock: (contentBlock) => {
|
||||||
|
// console.log(`[${MIDDLEWARE_NAME}] 📝 Anthropic content block:`, contentBlock.type)
|
||||||
|
// }
|
||||||
|
}
|
||||||
|
|
||||||
|
const specificApiClient = ctx.apiClientInstance as AnthropicAPIClient
|
||||||
|
|
||||||
|
const monitoredOutput = specificApiClient.attachRawStreamListener(
|
||||||
|
result.rawOutput as AnthropicSdkRawOutput,
|
||||||
|
anthropicListener
|
||||||
|
)
|
||||||
|
return {
|
||||||
|
...result,
|
||||||
|
rawOutput: monitoredOutput
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return result
|
||||||
|
}
|
||||||
@@ -0,0 +1,85 @@
|
|||||||
|
import Logger from '@renderer/config/logger'
|
||||||
|
import { SdkRawChunk } from '@renderer/types/sdk'
|
||||||
|
|
||||||
|
import { ResponseChunkTransformerContext } from '../../clients/types'
|
||||||
|
import { CompletionsParams, CompletionsResult, GenericChunk } from '../schemas'
|
||||||
|
import { CompletionsContext, CompletionsMiddleware } from '../types'
|
||||||
|
|
||||||
|
export const MIDDLEWARE_NAME = 'ResponseTransformMiddleware'
|
||||||
|
|
||||||
|
/**
* Response transform middleware
*
* Responsibilities:
* 1. Detect a ReadableStream response stream
* 2. Use the ApiClient's getResponseChunkTransformer() to convert raw SDK chunks into the generic format
* 3. Put the transformed ReadableStream back on the result for downstream middlewares
*
* Note: this middleware should run after StreamAdapterMiddleware
*/
|
||||||
|
export const ResponseTransformMiddleware: CompletionsMiddleware =
|
||||||
|
() =>
|
||||||
|
(next) =>
|
||||||
|
async (ctx: CompletionsContext, params: CompletionsParams): Promise<CompletionsResult> => {
|
||||||
|
// 调用下游中间件
|
||||||
|
const result = await next(ctx, params)
|
||||||
|
|
||||||
|
// Post-processing: transform the raw SDK response chunks
|
||||||
|
if (result.stream) {
|
||||||
|
const adaptedStream = result.stream
|
||||||
|
|
||||||
|
// 处理ReadableStream类型的流
|
||||||
|
if (adaptedStream instanceof ReadableStream) {
|
||||||
|
const apiClient = ctx.apiClientInstance
|
||||||
|
if (!apiClient) {
|
||||||
|
console.error(`[${MIDDLEWARE_NAME}] ApiClient instance not found in context`)
|
||||||
|
throw new Error('ApiClient instance not found in context')
|
||||||
|
}
|
||||||
|
|
||||||
|
// 获取响应转换器
|
||||||
|
const responseChunkTransformer = apiClient.getResponseChunkTransformer?.()
|
||||||
|
if (!responseChunkTransformer) {
|
||||||
|
Logger.warn(`[${MIDDLEWARE_NAME}] No ResponseChunkTransformer available, skipping transformation`)
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
const assistant = params.assistant
|
||||||
|
const model = assistant?.model
|
||||||
|
|
||||||
|
if (!assistant || !model) {
|
||||||
|
console.error(`[${MIDDLEWARE_NAME}] Assistant or Model not found for transformation`)
|
||||||
|
throw new Error('Assistant or Model not found for transformation')
|
||||||
|
}
|
||||||
|
|
||||||
|
const transformerContext: ResponseChunkTransformerContext = {
|
||||||
|
isStreaming: params.streamOutput || false,
|
||||||
|
isEnabledToolCalling: (params.mcpTools && params.mcpTools.length > 0) || false,
|
||||||
|
isEnabledWebSearch: params.enableWebSearch || false,
|
||||||
|
isEnabledReasoning: params.enableReasoning || false,
|
||||||
|
mcpTools: params.mcpTools || [],
|
||||||
|
provider: ctx.apiClientInstance?.provider
|
||||||
|
}
|
||||||
|
|
||||||
|
console.log(`[${MIDDLEWARE_NAME}] Transforming raw SDK chunks with context:`, transformerContext)
|
||||||
|
|
||||||
|
try {
|
||||||
|
// 创建转换后的流
|
||||||
|
const genericChunkTransformStream = (adaptedStream as ReadableStream<SdkRawChunk>).pipeThrough<GenericChunk>(
|
||||||
|
new TransformStream<SdkRawChunk, GenericChunk>(responseChunkTransformer(transformerContext))
|
||||||
|
)
|
||||||
|
|
||||||
|
// Put the transformed ReadableStream on the result for downstream middlewares
|
||||||
|
return {
|
||||||
|
...result,
|
||||||
|
stream: genericChunkTransformStream
|
||||||
|
}
|
||||||
|
} catch (error) {
|
||||||
|
Logger.error(`[${MIDDLEWARE_NAME}] Error during chunk transformation:`, error)
|
||||||
|
throw error
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// No stream, or not a ReadableStream: return the original result
|
||||||
|
return result
|
||||||
|
}
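For orientation, the transformer consumed above is whatever getResponseChunkTransformer() returns on the concrete ApiClient; its exact contract lives in the client layer and is not part of this diff. Below is a minimal sketch of a compatible factory, assuming it takes the transformer context and returns a WHATWG Transformer; the chunk shape and names are illustrative only, not the project's actual implementation.

// Hypothetical sketch: a factory usable as `new TransformStream<SdkRawChunk, GenericChunk>(factory(context))`.
import { ChunkType } from '@renderer/types/chunk'

const exampleChunkTransformer =
  (_context: ResponseChunkTransformerContext): Transformer<SdkRawChunk, GenericChunk> => ({
    transform(rawChunk, controller) {
      // Real clients branch on the provider-specific chunk shape; this only illustrates the mapping idea.
      const text = (rawChunk as any)?.choices?.[0]?.delta?.content
      if (typeof text === 'string' && text.length > 0) {
        controller.enqueue({ type: ChunkType.TEXT_DELTA, text } as GenericChunk)
      }
    }
  })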
|
||||||
@@ -0,0 +1,57 @@
|
|||||||
|
import { SdkRawChunk } from '@renderer/types/sdk'
|
||||||
|
import { asyncGeneratorToReadableStream, createSingleChunkReadableStream } from '@renderer/utils/stream'
|
||||||
|
|
||||||
|
import { CompletionsParams, CompletionsResult } from '../schemas'
|
||||||
|
import { CompletionsContext, CompletionsMiddleware } from '../types'
|
||||||
|
import { isAsyncIterable } from '../utils'
|
||||||
|
|
||||||
|
export const MIDDLEWARE_NAME = 'StreamAdapterMiddleware'
|
||||||
|
|
||||||
|
/**
* Stream adapter middleware
*
* Responsibilities:
* 1. Detect the raw SDK output on the result, whether an AsyncIterable or an existing ReadableStream
* 2. Convert an AsyncIterable into a WHATWG ReadableStream
* 3. Update the stream on the response result
*
* Note: if ResponseTransformMiddleware has already processed the stream, the transformed stream takes precedence
*/
|
||||||
|
export const StreamAdapterMiddleware: CompletionsMiddleware =
|
||||||
|
() =>
|
||||||
|
(next) =>
|
||||||
|
async (ctx: CompletionsContext, params: CompletionsParams): Promise<CompletionsResult> => {
|
||||||
|
// TODO: this is the point closest to the actual API request, so calling next here effectively starts the request.
// Whether that belongs here is debatable, since this middleware's responsibility is stream adaptation.
// Call the downstream middleware
|
||||||
|
const result = await next(ctx, params)
|
||||||
|
|
||||||
|
if (
|
||||||
|
result.rawOutput &&
|
||||||
|
!(result.rawOutput instanceof ReadableStream) &&
|
||||||
|
isAsyncIterable<SdkRawChunk>(result.rawOutput)
|
||||||
|
) {
|
||||||
|
const whatwgReadableStream: ReadableStream<SdkRawChunk> = asyncGeneratorToReadableStream<SdkRawChunk>(
|
||||||
|
result.rawOutput
|
||||||
|
)
|
||||||
|
return {
|
||||||
|
...result,
|
||||||
|
stream: whatwgReadableStream
|
||||||
|
}
|
||||||
|
} else if (result.rawOutput && result.rawOutput instanceof ReadableStream) {
|
||||||
|
return {
|
||||||
|
...result,
|
||||||
|
stream: result.rawOutput
|
||||||
|
}
|
||||||
|
} else if (result.rawOutput) {
|
||||||
|
// Non-streaming output: force it into a readable stream
|
||||||
|
const whatwgReadableStream: ReadableStream<SdkRawChunk> = createSingleChunkReadableStream<SdkRawChunk>(
|
||||||
|
result.rawOutput
|
||||||
|
)
|
||||||
|
return {
|
||||||
|
...result,
|
||||||
|
stream: whatwgReadableStream
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
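The helpers asyncGeneratorToReadableStream and createSingleChunkReadableStream come from @renderer/utils/stream and are not shown in this diff. As a rough illustration of the adaptation this middleware relies on, wrapping an AsyncIterable in a WHATWG ReadableStream generally looks like the sketch below (not the project's actual implementation):

// Illustrative only: adapt any AsyncIterable into a ReadableStream.
function toReadableStream<T>(iterable: AsyncIterable<T>): ReadableStream<T> {
  const iterator = iterable[Symbol.asyncIterator]()
  return new ReadableStream<T>({
    async pull(controller) {
      const { value, done } = await iterator.next()
      if (done) {
        controller.close()
      } else {
        controller.enqueue(value)
      }
    },
    async cancel(reason) {
      // Propagate cancellation to the source iterator if it supports it.
      await iterator.return?.(reason)
    }
  })
}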
|
||||||
@@ -0,0 +1,99 @@
|
|||||||
|
import Logger from '@renderer/config/logger'
|
||||||
|
import { ChunkType, TextDeltaChunk } from '@renderer/types/chunk'
|
||||||
|
|
||||||
|
import { CompletionsParams, CompletionsResult, GenericChunk } from '../schemas'
|
||||||
|
import { CompletionsContext, CompletionsMiddleware } from '../types'
|
||||||
|
|
||||||
|
export const MIDDLEWARE_NAME = 'TextChunkMiddleware'
|
||||||
|
|
||||||
|
/**
* Text chunk middleware
*
* Responsibilities:
* 1. Accumulate text content (TEXT_DELTA)
* 2. Apply smart link conversion to the text
* 3. Emit the TEXT_COMPLETE event
* 4. Stash web search results for final link completion
* 5. Handle the onResponse callback, sending incremental updates and the final full text
*/
|
||||||
|
export const TextChunkMiddleware: CompletionsMiddleware =
|
||||||
|
() =>
|
||||||
|
(next) =>
|
||||||
|
async (ctx: CompletionsContext, params: CompletionsParams): Promise<CompletionsResult> => {
|
||||||
|
// 调用下游中间件
|
||||||
|
const result = await next(ctx, params)
|
||||||
|
|
||||||
|
// Post-processing: transform the text content of the streaming response
|
||||||
|
if (result.stream) {
|
||||||
|
const resultFromUpstream = result.stream as ReadableStream<GenericChunk>
|
||||||
|
|
||||||
|
if (resultFromUpstream && resultFromUpstream instanceof ReadableStream) {
|
||||||
|
const assistant = params.assistant
|
||||||
|
const model = params.assistant?.model
|
||||||
|
|
||||||
|
if (!assistant || !model) {
|
||||||
|
Logger.warn(`[${MIDDLEWARE_NAME}] Missing assistant or model information, skipping text processing`)
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
// 用于跨chunk的状态管理
|
||||||
|
let accumulatedTextContent = ''
|
||||||
|
let hasEnqueue = false
|
||||||
|
const enhancedTextStream = resultFromUpstream.pipeThrough(
|
||||||
|
new TransformStream<GenericChunk, GenericChunk>({
|
||||||
|
transform(chunk: GenericChunk, controller) {
|
||||||
|
if (chunk.type === ChunkType.TEXT_DELTA) {
|
||||||
|
const textChunk = chunk as TextDeltaChunk
|
||||||
|
accumulatedTextContent += textChunk.text
|
||||||
|
|
||||||
|
// Handle the onResponse callback - send an incremental text update
|
||||||
|
if (params.onResponse) {
|
||||||
|
params.onResponse(accumulatedTextContent, false)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 创建新的chunk,包含处理后的文本
|
||||||
|
controller.enqueue(chunk)
|
||||||
|
} else if (accumulatedTextContent) {
|
||||||
|
if (chunk.type !== ChunkType.LLM_RESPONSE_COMPLETE) {
|
||||||
|
controller.enqueue(chunk)
|
||||||
|
hasEnqueue = true
|
||||||
|
}
|
||||||
|
const finalText = accumulatedTextContent
|
||||||
|
ctx._internal.customState!.accumulatedText = finalText
|
||||||
|
if (ctx._internal.toolProcessingState && !ctx._internal.toolProcessingState?.output) {
|
||||||
|
ctx._internal.toolProcessingState.output = finalText
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle the onResponse callback - send the final full text
|
||||||
|
if (params.onResponse) {
|
||||||
|
params.onResponse(finalText, true)
|
||||||
|
}
|
||||||
|
|
||||||
|
controller.enqueue({
|
||||||
|
type: ChunkType.TEXT_COMPLETE,
|
||||||
|
text: finalText
|
||||||
|
})
|
||||||
|
accumulatedTextContent = ''
|
||||||
|
if (!hasEnqueue) {
|
||||||
|
controller.enqueue(chunk)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// 其他类型的chunk直接传递
|
||||||
|
controller.enqueue(chunk)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
)
|
||||||
|
|
||||||
|
// 更新响应结果
|
||||||
|
return {
|
||||||
|
...result,
|
||||||
|
stream: enhancedTextStream
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
Logger.warn(`[${MIDDLEWARE_NAME}] No stream to process or not a ReadableStream. Returning original result.`)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return result
|
||||||
|
}
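As the transform above shows, onResponse receives the accumulated text with isComplete === false on every delta, then is called one final time with the complete text and isComplete === true. A minimal consumer of that contract (illustrative only):

const onResponse = (text: string, isComplete: boolean) => {
  if (isComplete) {
    console.log('final text:', text)
  } else {
    // `text` is the full accumulated content so far, not just the latest delta.
    console.log('partial text length:', text.length)
  }
}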
|
||||||
101
src/renderer/src/aiCore/middleware/core/ThinkChunkMiddleware.ts
Normal file
@@ -0,0 +1,101 @@
|
|||||||
|
import Logger from '@renderer/config/logger'
|
||||||
|
import { ChunkType, ThinkingCompleteChunk, ThinkingDeltaChunk } from '@renderer/types/chunk'
|
||||||
|
|
||||||
|
import { CompletionsParams, CompletionsResult, GenericChunk } from '../schemas'
|
||||||
|
import { CompletionsContext, CompletionsMiddleware } from '../types'
|
||||||
|
|
||||||
|
export const MIDDLEWARE_NAME = 'ThinkChunkMiddleware'
|
||||||
|
|
||||||
|
/**
* Middleware that handles thinking (reasoning) content
*
* Note: since v2, end-of-stream semantics are determined at the ApiClient layer.
* This middleware is now responsible for:
* 1. Accumulating thinking content (THINKING_DELTA), including the reasoning field of raw SDK chunks
* 2. Computing an accurate thinking duration
* 3. Emitting a THINKING_COMPLETE event when the thinking content ends
*/
|
||||||
|
export const ThinkChunkMiddleware: CompletionsMiddleware =
|
||||||
|
() =>
|
||||||
|
(next) =>
|
||||||
|
async (ctx: CompletionsContext, params: CompletionsParams): Promise<CompletionsResult> => {
|
||||||
|
// 调用下游中间件
|
||||||
|
const result = await next(ctx, params)
|
||||||
|
|
||||||
|
// Post-processing: handle thinking content
|
||||||
|
if (result.stream) {
|
||||||
|
const resultFromUpstream = result.stream as ReadableStream<GenericChunk>
|
||||||
|
|
||||||
|
// 检查是否启用reasoning
|
||||||
|
const enableReasoning = params.enableReasoning || false
|
||||||
|
if (!enableReasoning) {
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
// 检查是否有流需要处理
|
||||||
|
if (resultFromUpstream && resultFromUpstream instanceof ReadableStream) {
|
||||||
|
// thinking 处理状态
|
||||||
|
let accumulatedThinkingContent = ''
|
||||||
|
let hasThinkingContent = false
|
||||||
|
let thinkingStartTime = 0
|
||||||
|
|
||||||
|
const processedStream = resultFromUpstream.pipeThrough(
|
||||||
|
new TransformStream<GenericChunk, GenericChunk>({
|
||||||
|
transform(chunk: GenericChunk, controller) {
|
||||||
|
if (chunk.type === ChunkType.THINKING_DELTA) {
|
||||||
|
const thinkingChunk = chunk as ThinkingDeltaChunk
|
||||||
|
|
||||||
|
// Record the start time when the first piece of thinking content arrives
|
||||||
|
if (!hasThinkingContent) {
|
||||||
|
hasThinkingContent = true
|
||||||
|
thinkingStartTime = Date.now()
|
||||||
|
}
|
||||||
|
|
||||||
|
accumulatedThinkingContent += thinkingChunk.text
|
||||||
|
|
||||||
|
// 更新思考时间并传递
|
||||||
|
const enhancedChunk: ThinkingDeltaChunk = {
|
||||||
|
...thinkingChunk,
|
||||||
|
thinking_millsec: thinkingStartTime > 0 ? Date.now() - thinkingStartTime : 0
|
||||||
|
}
|
||||||
|
controller.enqueue(enhancedChunk)
|
||||||
|
} else if (hasThinkingContent && thinkingStartTime > 0) {
|
||||||
|
// On any non-THINKING_DELTA chunk, emit THINKING_COMPLETE if thinking content has accumulated
|
||||||
|
const thinkingCompleteChunk: ThinkingCompleteChunk = {
|
||||||
|
type: ChunkType.THINKING_COMPLETE,
|
||||||
|
text: accumulatedThinkingContent,
|
||||||
|
thinking_millsec: thinkingStartTime > 0 ? Date.now() - thinkingStartTime : 0
|
||||||
|
}
|
||||||
|
controller.enqueue(thinkingCompleteChunk)
|
||||||
|
hasThinkingContent = false
|
||||||
|
accumulatedThinkingContent = ''
|
||||||
|
thinkingStartTime = 0
|
||||||
|
|
||||||
|
// 继续传递当前chunk
|
||||||
|
controller.enqueue(chunk)
|
||||||
|
} else {
|
||||||
|
// 其他情况直接传递
|
||||||
|
controller.enqueue(chunk)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
)
|
||||||
|
|
||||||
|
// 更新响应结果
|
||||||
|
return {
|
||||||
|
...result,
|
||||||
|
stream: processedStream
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
Logger.warn(`[${MIDDLEWARE_NAME}] No generic chunk stream to process or not a ReadableStream.`)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return result
|
||||||
|
}
|
||||||
@@ -0,0 +1,83 @@
|
|||||||
|
import Logger from '@renderer/config/logger'
|
||||||
|
import { ChunkType } from '@renderer/types/chunk'
|
||||||
|
|
||||||
|
import { CompletionsParams, CompletionsResult } from '../schemas'
|
||||||
|
import { CompletionsContext, CompletionsMiddleware } from '../types'
|
||||||
|
|
||||||
|
export const MIDDLEWARE_NAME = 'TransformCoreToSdkParamsMiddleware'
|
||||||
|
|
||||||
|
/**
* Middleware that converts the core completions request into SDK-specific parameters,
* using the requestTransformer of the ApiClient instance on the context.
*/
|
||||||
|
export const TransformCoreToSdkParamsMiddleware: CompletionsMiddleware =
|
||||||
|
() =>
|
||||||
|
(next) =>
|
||||||
|
async (ctx: CompletionsContext, params: CompletionsParams): Promise<CompletionsResult> => {
|
||||||
|
Logger.debug(`🔄 [${MIDDLEWARE_NAME}] Starting core to SDK params transformation:`, ctx)
|
||||||
|
|
||||||
|
const internal = ctx._internal
|
||||||
|
|
||||||
|
// 🔧 Detect recursive calls: check whether params carries pre-built SDK messages
|
||||||
|
const isRecursiveCall = internal?.toolProcessingState?.isRecursiveCall || false
|
||||||
|
const newSdkMessages = params._internal?.newReqMessages
|
||||||
|
|
||||||
|
const apiClient = ctx.apiClientInstance
|
||||||
|
|
||||||
|
if (!apiClient) {
|
||||||
|
Logger.error(`🔄 [${MIDDLEWARE_NAME}] ApiClient instance not found in context.`)
|
||||||
|
throw new Error('ApiClient instance not found in context')
|
||||||
|
}
|
||||||
|
|
||||||
|
// 检查是否有requestTransformer方法
|
||||||
|
const requestTransformer = apiClient.getRequestTransformer()
|
||||||
|
if (!requestTransformer) {
|
||||||
|
Logger.warn(
|
||||||
|
`🔄 [${MIDDLEWARE_NAME}] ApiClient did not provide a request transformer, skipping transformation`
|
||||||
|
)
|
||||||
|
const result = await next(ctx, params)
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
// Ensure assistant and model are available; the transformer needs both
|
||||||
|
const assistant = params.assistant
|
||||||
|
const model = params.assistant.model
|
||||||
|
|
||||||
|
if (!assistant || !model) {
|
||||||
|
console.error(`🔄 [${MIDDLEWARE_NAME}] Assistant or Model not found for transformation.`)
|
||||||
|
throw new Error('Assistant or Model not found for transformation')
|
||||||
|
}
|
||||||
|
|
||||||
|
try {
|
||||||
|
const transformResult = await requestTransformer.transform(
|
||||||
|
params,
|
||||||
|
assistant,
|
||||||
|
model,
|
||||||
|
isRecursiveCall,
|
||||||
|
newSdkMessages
|
||||||
|
)
|
||||||
|
|
||||||
|
const { payload: sdkPayload, metadata } = transformResult
|
||||||
|
|
||||||
|
// Store the SDK-specific payload and metadata on the state for downstream middlewares
|
||||||
|
ctx._internal.sdkPayload = sdkPayload
|
||||||
|
|
||||||
|
if (metadata) {
|
||||||
|
ctx._internal.customState = {
|
||||||
|
...ctx._internal.customState,
|
||||||
|
sdkMetadata: metadata
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (params.enableGenerateImage) {
|
||||||
|
params.onChunk?.({
|
||||||
|
type: ChunkType.IMAGE_CREATED
|
||||||
|
})
|
||||||
|
}
|
||||||
|
return next(ctx, params)
|
||||||
|
} catch (error) {
|
||||||
|
Logger.error(`🔄 [${MIDDLEWARE_NAME}] Error during request transformation:`, error)
|
||||||
|
// Let the error propagate; specific error handling could be added here
|
||||||
|
throw error
|
||||||
|
}
|
||||||
|
}
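The concrete request transformer type is defined in the client layer and is not shown in this diff; judging from the call above, transform receives the params, assistant, model, the recursion flag and optional pre-built SDK messages, and resolves to { payload, metadata }. A hedged sketch of an object matching that contract (every name below is an assumption drawn from the call site, not the real transformer):

const exampleRequestTransformer = {
  async transform(
    params: CompletionsParams,
    _assistant: unknown,
    model: { id: string },
    isRecursiveCall: boolean,
    newSdkMessages?: unknown[]
  ) {
    // A real transformer maps core params and messages into the provider SDK's payload shape.
    const payload = {
      model: model.id,
      stream: params.streamOutput,
      messages: newSdkMessages ?? []
    }
    return { payload, metadata: { isRecursiveCall } }
  }
}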
|
||||||
@@ -0,0 +1,76 @@
|
|||||||
|
import { ChunkType } from '@renderer/types/chunk'
|
||||||
|
import { smartLinkConverter } from '@renderer/utils/linkConverter'
|
||||||
|
|
||||||
|
import { CompletionsParams, CompletionsResult, GenericChunk } from '../schemas'
|
||||||
|
import { CompletionsContext, CompletionsMiddleware } from '../types'
|
||||||
|
|
||||||
|
export const MIDDLEWARE_NAME = 'WebSearchMiddleware'
|
||||||
|
|
||||||
|
/**
* Web search middleware - operates on the GenericChunk stream
*
* Responsibilities:
* 1. Observe and record web search events
* 2. A place to add post-processing of web search results
* 3. Maintain web-search-related state
*
* Note: recognizing and emitting web search results is already handled by the ApiClient's response transformer
*/
|
||||||
|
export const WebSearchMiddleware: CompletionsMiddleware =
|
||||||
|
() =>
|
||||||
|
(next) =>
|
||||||
|
async (ctx: CompletionsContext, params: CompletionsParams): Promise<CompletionsResult> => {
|
||||||
|
ctx._internal.webSearchState = {
|
||||||
|
results: undefined
|
||||||
|
}
|
||||||
|
// 调用下游中间件
|
||||||
|
const result = await next(ctx, params)
|
||||||
|
|
||||||
|
const model = params.assistant?.model!
|
||||||
|
let isFirstChunk = true
|
||||||
|
|
||||||
|
// Post-processing: record web search events
|
||||||
|
if (result.stream) {
|
||||||
|
const resultFromUpstream = result.stream
|
||||||
|
|
||||||
|
if (resultFromUpstream && resultFromUpstream instanceof ReadableStream) {
|
||||||
|
// Web搜索状态跟踪
|
||||||
|
const enhancedStream = (resultFromUpstream as ReadableStream<GenericChunk>).pipeThrough(
|
||||||
|
new TransformStream<GenericChunk, GenericChunk>({
|
||||||
|
transform(chunk: GenericChunk, controller) {
|
||||||
|
if (chunk.type === ChunkType.TEXT_DELTA) {
|
||||||
|
const providerType = model.provider || 'openai'
|
||||||
|
// Convert links using the currently available web search results
|
||||||
|
const text = chunk.text
|
||||||
|
const processedText = smartLinkConverter(text, providerType, isFirstChunk)
|
||||||
|
if (isFirstChunk) {
|
||||||
|
isFirstChunk = false
|
||||||
|
}
|
||||||
|
controller.enqueue({
|
||||||
|
...chunk,
|
||||||
|
text: processedText
|
||||||
|
})
|
||||||
|
} else if (chunk.type === ChunkType.LLM_WEB_SEARCH_COMPLETE) {
|
||||||
|
// Stash the web search results for link completion
|
||||||
|
ctx._internal.webSearchState!.results = chunk.llm_web_search
|
||||||
|
|
||||||
|
// 将Web搜索完成事件继续传递下去
|
||||||
|
controller.enqueue(chunk)
|
||||||
|
} else {
|
||||||
|
controller.enqueue(chunk)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
)
|
||||||
|
|
||||||
|
return {
|
||||||
|
...result,
|
||||||
|
stream: enhancedStream
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
console.log(`[${MIDDLEWARE_NAME}] No stream to process or not a ReadableStream.`)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return result
|
||||||
|
}
|
||||||
@@ -0,0 +1,142 @@
|
|||||||
|
import { BaseApiClient } from '@renderer/aiCore/clients/BaseApiClient'
|
||||||
|
import { isDedicatedImageGenerationModel } from '@renderer/config/models'
|
||||||
|
import { ChunkType } from '@renderer/types/chunk'
|
||||||
|
import { findImageBlocks, getMainTextContent } from '@renderer/utils/messageUtils/find'
|
||||||
|
import OpenAI from 'openai'
|
||||||
|
import { toFile } from 'openai/uploads'
|
||||||
|
|
||||||
|
import { CompletionsParams, CompletionsResult, GenericChunk } from '../schemas'
|
||||||
|
import { CompletionsContext, CompletionsMiddleware } from '../types'
|
||||||
|
|
||||||
|
export const MIDDLEWARE_NAME = 'ImageGenerationMiddleware'
|
||||||
|
|
||||||
|
export const ImageGenerationMiddleware: CompletionsMiddleware =
|
||||||
|
() =>
|
||||||
|
(next) =>
|
||||||
|
async (context: CompletionsContext, params: CompletionsParams): Promise<CompletionsResult> => {
|
||||||
|
const { assistant, messages } = params
|
||||||
|
const client = context.apiClientInstance as BaseApiClient<OpenAI>
|
||||||
|
const signal = context._internal?.flowControl?.abortSignal
|
||||||
|
|
||||||
|
if (!assistant.model || !isDedicatedImageGenerationModel(assistant.model) || typeof messages === 'string') {
|
||||||
|
return next(context, params)
|
||||||
|
}
|
||||||
|
|
||||||
|
const stream = new ReadableStream<GenericChunk>({
|
||||||
|
async start(controller) {
|
||||||
|
const enqueue = (chunk: GenericChunk) => controller.enqueue(chunk)
|
||||||
|
|
||||||
|
try {
|
||||||
|
if (!assistant.model) {
|
||||||
|
throw new Error('Assistant model is not defined.')
|
||||||
|
}
|
||||||
|
|
||||||
|
const sdk = await client.getSdkInstance()
|
||||||
|
const lastUserMessage = messages.findLast((m) => m.role === 'user')
|
||||||
|
const lastAssistantMessage = messages.findLast((m) => m.role === 'assistant')
|
||||||
|
|
||||||
|
if (!lastUserMessage) {
|
||||||
|
throw new Error('No user message found for image generation.')
|
||||||
|
}
|
||||||
|
|
||||||
|
const prompt = getMainTextContent(lastUserMessage)
|
||||||
|
let imageFiles: Blob[] = []
|
||||||
|
|
||||||
|
// Collect images from user message
|
||||||
|
const userImageBlocks = findImageBlocks(lastUserMessage)
|
||||||
|
const userImages = await Promise.all(
|
||||||
|
userImageBlocks.map(async (block) => {
|
||||||
|
if (!block.file) return null
|
||||||
|
const binaryData: Uint8Array = await window.api.file.binaryImage(block.file.id)
|
||||||
|
const mimeType = `${block.file.type}/${block.file.ext.slice(1)}`
|
||||||
|
return await toFile(new Blob([binaryData]), block.file.origin_name || 'image.png', { type: mimeType })
|
||||||
|
})
|
||||||
|
)
|
||||||
|
imageFiles = imageFiles.concat(userImages.filter(Boolean) as Blob[])
|
||||||
|
|
||||||
|
// Collect images from last assistant message
|
||||||
|
if (lastAssistantMessage) {
|
||||||
|
const assistantImageBlocks = findImageBlocks(lastAssistantMessage)
|
||||||
|
const assistantImages = await Promise.all(
|
||||||
|
assistantImageBlocks.map(async (block) => {
|
||||||
|
const b64 = block.url?.replace(/^data:image\/\w+;base64,/, '')
|
||||||
|
if (!b64) return null
|
||||||
|
const binary = atob(b64)
|
||||||
|
const bytes = new Uint8Array(binary.length)
|
||||||
|
for (let i = 0; i < binary.length; i++) bytes[i] = binary.charCodeAt(i)
|
||||||
|
return await toFile(new Blob([bytes]), 'assistant_image.png', { type: 'image/png' })
|
||||||
|
})
|
||||||
|
)
|
||||||
|
imageFiles = imageFiles.concat(assistantImages.filter(Boolean) as Blob[])
|
||||||
|
}
|
||||||
|
|
||||||
|
enqueue({ type: ChunkType.IMAGE_CREATED })
|
||||||
|
|
||||||
|
const startTime = Date.now()
|
||||||
|
let response: OpenAI.Images.ImagesResponse
|
||||||
|
|
||||||
|
const options = { signal, timeout: 300_000 }
|
||||||
|
|
||||||
|
if (imageFiles.length > 0) {
|
||||||
|
response = await sdk.images.edit(
|
||||||
|
{
|
||||||
|
model: assistant.model.id,
|
||||||
|
image: imageFiles,
|
||||||
|
prompt: prompt || ''
|
||||||
|
},
|
||||||
|
options
|
||||||
|
)
|
||||||
|
} else {
|
||||||
|
response = await sdk.images.generate(
|
||||||
|
{
|
||||||
|
model: assistant.model.id,
|
||||||
|
prompt: prompt || '',
|
||||||
|
response_format: assistant.model.id.includes('gpt-image-1') ? undefined : 'b64_json'
|
||||||
|
},
|
||||||
|
options
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
let imageType: 'url' | 'base64' = 'base64'
|
||||||
|
const imageList =
|
||||||
|
response.data?.reduce((acc: string[], image) => {
|
||||||
|
if (image.url) {
|
||||||
|
acc.push(image.url)
|
||||||
|
imageType = 'url'
|
||||||
|
} else if (image.b64_json) {
|
||||||
|
acc.push(`data:image/png;base64,${image.b64_json}`)
|
||||||
|
}
|
||||||
|
return acc
|
||||||
|
}, []) || []
|
||||||
|
|
||||||
|
enqueue({
|
||||||
|
type: ChunkType.IMAGE_COMPLETE,
|
||||||
|
image: { type: imageType, images: imageList }
|
||||||
|
})
|
||||||
|
|
||||||
|
const usage = (response as any).usage || { prompt_tokens: 0, completion_tokens: 0, total_tokens: 0 }
|
||||||
|
|
||||||
|
enqueue({
|
||||||
|
type: ChunkType.LLM_RESPONSE_COMPLETE,
|
||||||
|
response: {
|
||||||
|
usage,
|
||||||
|
metrics: {
|
||||||
|
completion_tokens: usage.completion_tokens,
|
||||||
|
time_first_token_millsec: 0,
|
||||||
|
time_completion_millsec: Date.now() - startTime
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
} catch (error: any) {
|
||||||
|
enqueue({ type: ChunkType.ERROR, error })
|
||||||
|
} finally {
|
||||||
|
controller.close()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
return {
|
||||||
|
stream,
|
||||||
|
getText: () => ''
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,136 @@
|
|||||||
|
import { Model } from '@renderer/types'
|
||||||
|
import { ChunkType, TextDeltaChunk, ThinkingCompleteChunk, ThinkingDeltaChunk } from '@renderer/types/chunk'
|
||||||
|
import { TagConfig, TagExtractor } from '@renderer/utils/tagExtraction'
|
||||||
|
import Logger from 'electron-log/renderer'
|
||||||
|
|
||||||
|
import { CompletionsParams, CompletionsResult, GenericChunk } from '../schemas'
|
||||||
|
import { CompletionsContext, CompletionsMiddleware } from '../types'
|
||||||
|
|
||||||
|
export const MIDDLEWARE_NAME = 'ThinkingTagExtractionMiddleware'
|
||||||
|
|
||||||
|
// Thinking-tag configurations for different models
|
||||||
|
const reasoningTags: TagConfig[] = [
|
||||||
|
{ openingTag: '<think>', closingTag: '</think>', separator: '\n' },
|
||||||
|
{ openingTag: '###Thinking', closingTag: '###Response', separator: '\n' }
|
||||||
|
]
|
||||||
|
|
||||||
|
const getAppropriateTag = (model?: Model): TagConfig => {
|
||||||
|
if (model?.id?.includes('qwen3')) return reasoningTags[0]
|
||||||
|
// 可以在这里添加更多模型特定的标签配置
|
||||||
|
return reasoningTags[0] // Default to the <think> tag
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
* Middleware that extracts thinking tags from the text stream
*
* It handles thinking-tag content embedded in the text (e.g. <think>...</think>),
* mainly for providers such as OpenAI-compatible ones that emit such tags.
*
* Responsibilities:
* 1. Extract thinking-tag content from the text stream
* 2. Convert tag content into THINKING_DELTA chunks
* 3. Emit content outside the tags as normal text
* 4. Handle the thinking-tag formats of different models
* 5. Emit a THINKING_COMPLETE event when the thinking content ends
*/
|
||||||
|
export const ThinkingTagExtractionMiddleware: CompletionsMiddleware =
|
||||||
|
() =>
|
||||||
|
(next) =>
|
||||||
|
async (context: CompletionsContext, params: CompletionsParams): Promise<CompletionsResult> => {
|
||||||
|
// 调用下游中间件
|
||||||
|
const result = await next(context, params)
|
||||||
|
|
||||||
|
// 响应后处理:处理思考标签提取
|
||||||
|
if (result.stream) {
|
||||||
|
const resultFromUpstream = result.stream as ReadableStream<GenericChunk>
|
||||||
|
|
||||||
|
// 检查是否有流需要处理
|
||||||
|
if (resultFromUpstream && resultFromUpstream instanceof ReadableStream) {
|
||||||
|
// 获取当前模型的思考标签配置
|
||||||
|
const model = params.assistant?.model
|
||||||
|
const reasoningTag = getAppropriateTag(model)
|
||||||
|
|
||||||
|
// 创建标签提取器
|
||||||
|
const tagExtractor = new TagExtractor(reasoningTag)
|
||||||
|
|
||||||
|
// thinking 处理状态
|
||||||
|
let hasThinkingContent = false
|
||||||
|
let thinkingStartTime = 0
|
||||||
|
|
||||||
|
const processedStream = resultFromUpstream.pipeThrough(
|
||||||
|
new TransformStream<GenericChunk, GenericChunk>({
|
||||||
|
transform(chunk: GenericChunk, controller) {
|
||||||
|
if (chunk.type === ChunkType.TEXT_DELTA) {
|
||||||
|
const textChunk = chunk as TextDeltaChunk
|
||||||
|
|
||||||
|
// 使用 TagExtractor 处理文本
|
||||||
|
const extractionResults = tagExtractor.processText(textChunk.text)
|
||||||
|
|
||||||
|
for (const extractionResult of extractionResults) {
|
||||||
|
if (extractionResult.complete && extractionResult.tagContentExtracted) {
|
||||||
|
// 生成 THINKING_COMPLETE 事件
|
||||||
|
const thinkingCompleteChunk: ThinkingCompleteChunk = {
|
||||||
|
type: ChunkType.THINKING_COMPLETE,
|
||||||
|
text: extractionResult.tagContentExtracted,
|
||||||
|
thinking_millsec: thinkingStartTime > 0 ? Date.now() - thinkingStartTime : 0
|
||||||
|
}
|
||||||
|
controller.enqueue(thinkingCompleteChunk)
|
||||||
|
|
||||||
|
// 重置思考状态
|
||||||
|
hasThinkingContent = false
|
||||||
|
thinkingStartTime = 0
|
||||||
|
} else if (extractionResult.content.length > 0) {
|
||||||
|
if (extractionResult.isTagContent) {
|
||||||
|
// 第一次接收到思考内容时记录开始时间
|
||||||
|
if (!hasThinkingContent) {
|
||||||
|
hasThinkingContent = true
|
||||||
|
thinkingStartTime = Date.now()
|
||||||
|
}
|
||||||
|
|
||||||
|
const thinkingDeltaChunk: ThinkingDeltaChunk = {
|
||||||
|
type: ChunkType.THINKING_DELTA,
|
||||||
|
text: extractionResult.content,
|
||||||
|
thinking_millsec: thinkingStartTime > 0 ? Date.now() - thinkingStartTime : 0
|
||||||
|
}
|
||||||
|
controller.enqueue(thinkingDeltaChunk)
|
||||||
|
} else {
|
||||||
|
// Send the cleaned text content
|
||||||
|
const cleanTextChunk: TextDeltaChunk = {
|
||||||
|
...textChunk,
|
||||||
|
text: extractionResult.content
|
||||||
|
}
|
||||||
|
controller.enqueue(cleanTextChunk)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// 其他类型的chunk直接传递(包括 THINKING_DELTA, THINKING_COMPLETE 等)
|
||||||
|
controller.enqueue(chunk)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
flush(controller) {
|
||||||
|
// Flush any remaining thinking content
|
||||||
|
const finalResult = tagExtractor.finalize()
|
||||||
|
if (finalResult?.tagContentExtracted) {
|
||||||
|
const thinkingCompleteChunk: ThinkingCompleteChunk = {
|
||||||
|
type: ChunkType.THINKING_COMPLETE,
|
||||||
|
text: finalResult.tagContentExtracted,
|
||||||
|
thinking_millsec: thinkingStartTime > 0 ? Date.now() - thinkingStartTime : 0
|
||||||
|
}
|
||||||
|
controller.enqueue(thinkingCompleteChunk)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
)
|
||||||
|
|
||||||
|
// 更新响应结果
|
||||||
|
return {
|
||||||
|
...result,
|
||||||
|
stream: processedStream
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
Logger.warn(`[${MIDDLEWARE_NAME}] No generic chunk stream to process or not a ReadableStream.`)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
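Based only on the TagExtractor API visible in this diff (processText returning results with content, isTagContent, complete and tagContentExtracted, plus finalize), the extraction loop used above can be exercised in isolation roughly as follows; this is illustrative, and the exact splitting behaviour is up to the real utility in @renderer/utils/tagExtraction:

const extractor = new TagExtractor({ openingTag: '<think>', closingTag: '</think>', separator: '\n' })

for (const delta of ['Hello <think>step one', ' and step two</think> world']) {
  for (const piece of extractor.processText(delta)) {
    if (piece.complete && piece.tagContentExtracted) {
      console.log('thinking complete:', piece.tagContentExtracted)
    } else if (piece.isTagContent) {
      console.log('thinking delta:', piece.content)
    } else if (piece.content) {
      console.log('plain text:', piece.content)
    }
  }
}

// finalize() flushes thinking content from a tag that was never closed.
const leftover = extractor.finalize()
if (leftover?.tagContentExtracted) {
  console.log('unterminated thinking:', leftover.tagContentExtracted)
}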
|
||||||
@@ -0,0 +1,124 @@
|
|||||||
|
import { MCPTool } from '@renderer/types'
|
||||||
|
import { ChunkType, MCPToolCreatedChunk, TextDeltaChunk } from '@renderer/types/chunk'
|
||||||
|
import { parseToolUse } from '@renderer/utils/mcp-tools'
|
||||||
|
import { TagConfig, TagExtractor } from '@renderer/utils/tagExtraction'
|
||||||
|
|
||||||
|
import { CompletionsParams, CompletionsResult, GenericChunk } from '../schemas'
|
||||||
|
import { CompletionsContext, CompletionsMiddleware } from '../types'
|
||||||
|
|
||||||
|
export const MIDDLEWARE_NAME = 'ToolUseExtractionMiddleware'
|
||||||
|
|
||||||
|
// Tool-use tag configuration
|
||||||
|
const TOOL_USE_TAG_CONFIG: TagConfig = {
|
||||||
|
openingTag: '<tool_use>',
|
||||||
|
closingTag: '</tool_use>',
|
||||||
|
separator: '\n'
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
* Tool-use extraction middleware
*
* Responsibilities:
* 1. Detect and extract <tool_use></tool_use> tags from the text stream
* 2. Parse the tool call information and convert it into ToolUseResponse format
* 3. Emit MCP_TOOL_CREATED chunks for McpToolChunkMiddleware to handle
* 4. Clean the text stream: strip the tool-use tags while keeping normal text
*
* Note: this middleware only extracts and converts; the actual tool calls are executed by McpToolChunkMiddleware
*/
|
||||||
|
export const ToolUseExtractionMiddleware: CompletionsMiddleware =
|
||||||
|
() =>
|
||||||
|
(next) =>
|
||||||
|
async (ctx: CompletionsContext, params: CompletionsParams): Promise<CompletionsResult> => {
|
||||||
|
const mcpTools = params.mcpTools || []
|
||||||
|
|
||||||
|
// No tools available: go straight to the next middleware
|
||||||
|
if (!mcpTools || mcpTools.length === 0) return next(ctx, params)
|
||||||
|
|
||||||
|
// 调用下游中间件
|
||||||
|
const result = await next(ctx, params)
|
||||||
|
|
||||||
|
// 响应后处理:处理工具使用标签提取
|
||||||
|
if (result.stream) {
|
||||||
|
const resultFromUpstream = result.stream as ReadableStream<GenericChunk>
|
||||||
|
|
||||||
|
const processedStream = resultFromUpstream.pipeThrough(createToolUseExtractionTransform(ctx, mcpTools))
|
||||||
|
|
||||||
|
return {
|
||||||
|
...result,
|
||||||
|
stream: processedStream
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
* Create the TransformStream that extracts tool-use tags.
*/
|
||||||
|
function createToolUseExtractionTransform(
|
||||||
|
_ctx: CompletionsContext,
|
||||||
|
mcpTools: MCPTool[]
|
||||||
|
): TransformStream<GenericChunk, GenericChunk> {
|
||||||
|
const tagExtractor = new TagExtractor(TOOL_USE_TAG_CONFIG)
|
||||||
|
|
||||||
|
return new TransformStream({
|
||||||
|
async transform(chunk: GenericChunk, controller) {
|
||||||
|
try {
|
||||||
|
// 处理文本内容,检测工具使用标签
|
||||||
|
if (chunk.type === ChunkType.TEXT_DELTA) {
|
||||||
|
const textChunk = chunk as TextDeltaChunk
|
||||||
|
const extractionResults = tagExtractor.processText(textChunk.text)
|
||||||
|
|
||||||
|
for (const result of extractionResults) {
|
||||||
|
if (result.complete && result.tagContentExtracted) {
|
||||||
|
// A complete tool-use block was extracted; parse it into ToolUseResponse format
|
||||||
|
const toolUseResponses = parseToolUse(result.tagContentExtracted, mcpTools)
|
||||||
|
|
||||||
|
if (toolUseResponses.length > 0) {
|
||||||
|
// 生成 MCP_TOOL_CREATED chunk,复用现有的处理流程
|
||||||
|
const mcpToolCreatedChunk: MCPToolCreatedChunk = {
|
||||||
|
type: ChunkType.MCP_TOOL_CREATED,
|
||||||
|
tool_use_responses: toolUseResponses
|
||||||
|
}
|
||||||
|
controller.enqueue(mcpToolCreatedChunk)
|
||||||
|
}
|
||||||
|
} else if (!result.isTagContent && result.content) {
|
||||||
|
// 发送标签外的正常文本内容
|
||||||
|
const cleanTextChunk: TextDeltaChunk = {
|
||||||
|
...textChunk,
|
||||||
|
text: result.content
|
||||||
|
}
|
||||||
|
controller.enqueue(cleanTextChunk)
|
||||||
|
}
|
||||||
|
// Note: tag content is not forwarded as TEXT_DELTA, to avoid showing it twice
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// 转发其他所有chunk
|
||||||
|
controller.enqueue(chunk)
|
||||||
|
} catch (error) {
|
||||||
|
console.error(`🔧 [${MIDDLEWARE_NAME}] Error processing chunk:`, error)
|
||||||
|
controller.error(error)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
|
||||||
|
async flush(controller) {
|
||||||
|
// 检查是否有未完成的标签内容
|
||||||
|
const finalResult = tagExtractor.finalize()
|
||||||
|
if (finalResult && finalResult.tagContentExtracted) {
|
||||||
|
const toolUseResponses = parseToolUse(finalResult.tagContentExtracted, mcpTools)
|
||||||
|
if (toolUseResponses.length > 0) {
|
||||||
|
const mcpToolCreatedChunk: MCPToolCreatedChunk = {
|
||||||
|
type: ChunkType.MCP_TOOL_CREATED,
|
||||||
|
tool_use_responses: toolUseResponses
|
||||||
|
}
|
||||||
|
controller.enqueue(mcpToolCreatedChunk)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
export default ToolUseExtractionMiddleware
|
||||||
88
src/renderer/src/aiCore/middleware/index.ts
Normal file
@@ -0,0 +1,88 @@
|
|||||||
|
import { CompletionsMiddleware, MethodMiddleware } from './types'
|
||||||
|
|
||||||
|
// /**
|
||||||
|
// * Wraps a provider instance with middlewares.
|
||||||
|
// */
|
||||||
|
// export function wrapProviderWithMiddleware(
|
||||||
|
// apiClientInstance: BaseApiClient,
|
||||||
|
// middlewareConfig: MiddlewareConfig
|
||||||
|
// ): BaseApiClient {
|
||||||
|
// console.log(`[wrapProviderWithMiddleware] Wrapping provider: ${apiClientInstance.provider?.id}`)
|
||||||
|
// console.log(`[wrapProviderWithMiddleware] Middleware config:`, {
|
||||||
|
// completions: middlewareConfig.completions?.length || 0,
|
||||||
|
// methods: Object.keys(middlewareConfig.methods || {}).length
|
||||||
|
// })
|
||||||
|
|
||||||
|
// // Cache for already wrapped methods to avoid re-wrapping on every access.
|
||||||
|
// const wrappedMethodsCache = new Map<string, (...args: any[]) => Promise<any>>()
|
||||||
|
|
||||||
|
// const proxy = new Proxy(apiClientInstance, {
|
||||||
|
// get(target, propKey, receiver) {
|
||||||
|
// const methodName = typeof propKey === 'string' ? propKey : undefined
|
||||||
|
|
||||||
|
// if (!methodName) {
|
||||||
|
// return Reflect.get(target, propKey, receiver)
|
||||||
|
// }
|
||||||
|
|
||||||
|
// if (wrappedMethodsCache.has(methodName)) {
|
||||||
|
// console.log(`[wrapProviderWithMiddleware] Using cached wrapped method: ${methodName}`)
|
||||||
|
// return wrappedMethodsCache.get(methodName)
|
||||||
|
// }
|
||||||
|
|
||||||
|
// const originalMethod = Reflect.get(target, propKey, receiver)
|
||||||
|
|
||||||
|
// // If the property is not a function, return it directly.
|
||||||
|
// if (typeof originalMethod !== 'function') {
|
||||||
|
// return originalMethod
|
||||||
|
// }
|
||||||
|
|
||||||
|
// let wrappedMethod: ((...args: any[]) => Promise<any>) | undefined
|
||||||
|
|
||||||
|
// // Handle completions method
|
||||||
|
// if (methodName === 'completions' && middlewareConfig.completions?.length) {
|
||||||
|
// console.log(
|
||||||
|
// `[wrapProviderWithMiddleware] Wrapping completions method with ${middlewareConfig.completions.length} middlewares`
|
||||||
|
// )
|
||||||
|
// const completionsOriginalMethod = originalMethod as (params: CompletionsParams) => Promise<any>
|
||||||
|
// wrappedMethod = applyCompletionsMiddlewares(target, completionsOriginalMethod, middlewareConfig.completions)
|
||||||
|
// }
|
||||||
|
// // Handle other methods
|
||||||
|
// else {
|
||||||
|
// const methodMiddlewares = middlewareConfig.methods?.[methodName]
|
||||||
|
// if (methodMiddlewares?.length) {
|
||||||
|
// console.log(
|
||||||
|
// `[wrapProviderWithMiddleware] Wrapping method ${methodName} with ${methodMiddlewares.length} middlewares`
|
||||||
|
// )
|
||||||
|
// const genericOriginalMethod = originalMethod as (...args: any[]) => Promise<any>
|
||||||
|
// wrappedMethod = applyMethodMiddlewares(target, methodName, genericOriginalMethod, methodMiddlewares)
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
|
||||||
|
// if (wrappedMethod) {
|
||||||
|
// console.log(`[wrapProviderWithMiddleware] Successfully wrapped method: ${methodName}`)
|
||||||
|
// wrappedMethodsCache.set(methodName, wrappedMethod)
|
||||||
|
// return wrappedMethod
|
||||||
|
// }
|
||||||
|
|
||||||
|
// // If no middlewares are configured for this method, return the original method bound to the target. /
|
||||||
|
// // 如果没有为此方法配置中间件,则返回绑定到目标的原始方法。
|
||||||
|
// console.log(`[wrapProviderWithMiddleware] No middlewares for method ${methodName}, returning original`)
|
||||||
|
// return originalMethod.bind(target)
|
||||||
|
// }
|
||||||
|
// })
|
||||||
|
// return proxy as BaseApiClient
|
||||||
|
// }
|
||||||
|
|
||||||
|
// Export types for external use
|
||||||
|
export type { CompletionsMiddleware, MethodMiddleware }
|
||||||
|
|
||||||
|
// Export MiddlewareBuilder related types and classes
|
||||||
|
export {
|
||||||
|
CompletionsMiddlewareBuilder,
|
||||||
|
createCompletionsBuilder,
|
||||||
|
createMethodBuilder,
|
||||||
|
MethodMiddlewareBuilder,
|
||||||
|
MiddlewareBuilder,
|
||||||
|
type MiddlewareExecutor,
|
||||||
|
type NamedMiddleware
|
||||||
|
} from './builder'
|
||||||
149
src/renderer/src/aiCore/middleware/register.ts
Normal file
@@ -0,0 +1,149 @@
|
|||||||
|
import * as AbortHandlerModule from './common/AbortHandlerMiddleware'
|
||||||
|
import * as ErrorHandlerModule from './common/ErrorHandlerMiddleware'
|
||||||
|
import * as FinalChunkConsumerModule from './common/FinalChunkConsumerMiddleware'
|
||||||
|
import * as LoggingModule from './common/LoggingMiddleware'
|
||||||
|
import * as McpToolChunkModule from './core/McpToolChunkMiddleware'
|
||||||
|
import * as RawStreamListenerModule from './core/RawStreamListenerMiddleware'
|
||||||
|
import * as ResponseTransformModule from './core/ResponseTransformMiddleware'
|
||||||
|
// import * as SdkCallModule from './core/SdkCallMiddleware'
|
||||||
|
import * as StreamAdapterModule from './core/StreamAdapterMiddleware'
|
||||||
|
import * as TextChunkModule from './core/TextChunkMiddleware'
|
||||||
|
import * as ThinkChunkModule from './core/ThinkChunkMiddleware'
|
||||||
|
import * as TransformCoreToSdkParamsModule from './core/TransformCoreToSdkParamsMiddleware'
|
||||||
|
import * as WebSearchModule from './core/WebSearchMiddleware'
|
||||||
|
import * as ImageGenerationModule from './feat/ImageGenerationMiddleware'
|
||||||
|
import * as ThinkingTagExtractionModule from './feat/ThinkingTagExtractionMiddleware'
|
||||||
|
import * as ToolUseExtractionMiddleware from './feat/ToolUseExtractionMiddleware'
|
||||||
|
|
||||||
|
/**
* Middleware registry - central access to all available middlewares.
* Note: middleware files that do not yet export MIDDLEWARE_NAME will produce linter errors; this is expected.
*/
|
||||||
|
export const MiddlewareRegistry = {
|
||||||
|
[ErrorHandlerModule.MIDDLEWARE_NAME]: {
|
||||||
|
name: ErrorHandlerModule.MIDDLEWARE_NAME,
|
||||||
|
middleware: ErrorHandlerModule.ErrorHandlerMiddleware
|
||||||
|
},
|
||||||
|
// Common middlewares
|
||||||
|
[AbortHandlerModule.MIDDLEWARE_NAME]: {
|
||||||
|
name: AbortHandlerModule.MIDDLEWARE_NAME,
|
||||||
|
middleware: AbortHandlerModule.AbortHandlerMiddleware
|
||||||
|
},
|
||||||
|
[FinalChunkConsumerModule.MIDDLEWARE_NAME]: {
|
||||||
|
name: FinalChunkConsumerModule.MIDDLEWARE_NAME,
|
||||||
|
middleware: FinalChunkConsumerModule.default
|
||||||
|
},
|
||||||
|
|
||||||
|
// Core pipeline middlewares
|
||||||
|
[TransformCoreToSdkParamsModule.MIDDLEWARE_NAME]: {
|
||||||
|
name: TransformCoreToSdkParamsModule.MIDDLEWARE_NAME,
|
||||||
|
middleware: TransformCoreToSdkParamsModule.TransformCoreToSdkParamsMiddleware
|
||||||
|
},
|
||||||
|
// [SdkCallModule.MIDDLEWARE_NAME]: {
|
||||||
|
// name: SdkCallModule.MIDDLEWARE_NAME,
|
||||||
|
// middleware: SdkCallModule.SdkCallMiddleware
|
||||||
|
// },
|
||||||
|
[StreamAdapterModule.MIDDLEWARE_NAME]: {
|
||||||
|
name: StreamAdapterModule.MIDDLEWARE_NAME,
|
||||||
|
middleware: StreamAdapterModule.StreamAdapterMiddleware
|
||||||
|
},
|
||||||
|
[RawStreamListenerModule.MIDDLEWARE_NAME]: {
|
||||||
|
name: RawStreamListenerModule.MIDDLEWARE_NAME,
|
||||||
|
middleware: RawStreamListenerModule.RawStreamListenerMiddleware
|
||||||
|
},
|
||||||
|
[ResponseTransformModule.MIDDLEWARE_NAME]: {
|
||||||
|
name: ResponseTransformModule.MIDDLEWARE_NAME,
|
||||||
|
middleware: ResponseTransformModule.ResponseTransformMiddleware
|
||||||
|
},
|
||||||
|
|
||||||
|
// Feature middlewares
|
||||||
|
[ThinkingTagExtractionModule.MIDDLEWARE_NAME]: {
|
||||||
|
name: ThinkingTagExtractionModule.MIDDLEWARE_NAME,
|
||||||
|
middleware: ThinkingTagExtractionModule.ThinkingTagExtractionMiddleware
|
||||||
|
},
|
||||||
|
[ToolUseExtractionMiddleware.MIDDLEWARE_NAME]: {
|
||||||
|
name: ToolUseExtractionMiddleware.MIDDLEWARE_NAME,
|
||||||
|
middleware: ToolUseExtractionMiddleware.ToolUseExtractionMiddleware
|
||||||
|
},
|
||||||
|
[ThinkChunkModule.MIDDLEWARE_NAME]: {
|
||||||
|
name: ThinkChunkModule.MIDDLEWARE_NAME,
|
||||||
|
middleware: ThinkChunkModule.ThinkChunkMiddleware
|
||||||
|
},
|
||||||
|
[McpToolChunkModule.MIDDLEWARE_NAME]: {
|
||||||
|
name: McpToolChunkModule.MIDDLEWARE_NAME,
|
||||||
|
middleware: McpToolChunkModule.McpToolChunkMiddleware
|
||||||
|
},
|
||||||
|
[WebSearchModule.MIDDLEWARE_NAME]: {
|
||||||
|
name: WebSearchModule.MIDDLEWARE_NAME,
|
||||||
|
middleware: WebSearchModule.WebSearchMiddleware
|
||||||
|
},
|
||||||
|
[TextChunkModule.MIDDLEWARE_NAME]: {
|
||||||
|
name: TextChunkModule.MIDDLEWARE_NAME,
|
||||||
|
middleware: TextChunkModule.TextChunkMiddleware
|
||||||
|
},
|
||||||
|
[ImageGenerationModule.MIDDLEWARE_NAME]: {
|
||||||
|
name: ImageGenerationModule.MIDDLEWARE_NAME,
|
||||||
|
middleware: ImageGenerationModule.ImageGenerationMiddleware
|
||||||
|
}
|
||||||
|
} as const
|
||||||
|
|
||||||
|
/**
* Get a middleware by name.
* @param name - middleware name
* @returns the corresponding registry entry
*/
|
||||||
|
export function getMiddleware(name: string) {
|
||||||
|
return MiddlewareRegistry[name]
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
* Get the names of all registered middlewares.
* @returns list of middleware names
*/
|
||||||
|
export function getRegisteredMiddlewareNames(): string[] {
|
||||||
|
return Object.keys(MiddlewareRegistry)
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
* Default completions middleware configuration, in NamedMiddleware format for use with the MiddlewareBuilder.
*/
|
||||||
|
export const DefaultCompletionsNamedMiddlewares = [
MiddlewareRegistry[FinalChunkConsumerModule.MIDDLEWARE_NAME], // final chunk consumer
MiddlewareRegistry[ErrorHandlerModule.MIDDLEWARE_NAME], // error handling
MiddlewareRegistry[TransformCoreToSdkParamsModule.MIDDLEWARE_NAME], // parameter transformation
MiddlewareRegistry[AbortHandlerModule.MIDDLEWARE_NAME], // abort handling
MiddlewareRegistry[McpToolChunkModule.MIDDLEWARE_NAME], // tool handling
MiddlewareRegistry[TextChunkModule.MIDDLEWARE_NAME], // text handling
MiddlewareRegistry[WebSearchModule.MIDDLEWARE_NAME], // web search handling
MiddlewareRegistry[ToolUseExtractionMiddleware.MIDDLEWARE_NAME], // tool-use extraction
MiddlewareRegistry[ThinkingTagExtractionModule.MIDDLEWARE_NAME], // thinking tag extraction (provider specific)
MiddlewareRegistry[ThinkChunkModule.MIDDLEWARE_NAME], // thinking handling (generic SDK)
MiddlewareRegistry[ResponseTransformModule.MIDDLEWARE_NAME], // response transformation
MiddlewareRegistry[StreamAdapterModule.MIDDLEWARE_NAME], // stream adapter
MiddlewareRegistry[RawStreamListenerModule.MIDDLEWARE_NAME] // raw stream listener
]
|
||||||
|
|
||||||
|
/**
* Default middlewares for generic methods, e.g. translate and summaries.
*/
|
||||||
|
export const DefaultMethodMiddlewares = {
|
||||||
|
translate: [LoggingModule.createGenericLoggingMiddleware()],
|
||||||
|
summaries: [LoggingModule.createGenericLoggingMiddleware()]
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
* Export all middleware modules for external use.
*/
|
||||||
|
export {
|
||||||
|
AbortHandlerModule,
|
||||||
|
FinalChunkConsumerModule,
|
||||||
|
LoggingModule,
|
||||||
|
McpToolChunkModule,
|
||||||
|
ResponseTransformModule,
|
||||||
|
StreamAdapterModule,
|
||||||
|
TextChunkModule,
|
||||||
|
ThinkChunkModule,
|
||||||
|
ThinkingTagExtractionModule,
|
||||||
|
TransformCoreToSdkParamsModule,
|
||||||
|
WebSearchModule
|
||||||
|
}
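A small usage example of the helpers defined in this file; the builder APIs re-exported from './builder' are not shown in this diff, so this sticks to the registry itself:

import { DefaultCompletionsNamedMiddlewares, getMiddleware, getRegisteredMiddlewareNames } from './register'

// Enumerate everything that is registered.
for (const name of getRegisteredMiddlewareNames()) {
  console.log('registered middleware:', name)
}

// Look one middleware up by its exported MIDDLEWARE_NAME.
const textChunk = getMiddleware('TextChunkMiddleware')
console.log(textChunk?.name) // 'TextChunkMiddleware'

// The default completions chain is just an ordered list of registry entries.
console.log(DefaultCompletionsNamedMiddlewares.map((entry) => entry.name))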
|
||||||
77
src/renderer/src/aiCore/middleware/schemas.ts
Normal file
@@ -0,0 +1,77 @@
|
|||||||
|
import { Assistant, MCPTool } from '@renderer/types'
|
||||||
|
import { Chunk } from '@renderer/types/chunk'
|
||||||
|
import { Message } from '@renderer/types/newMessage'
|
||||||
|
import { SdkRawChunk, SdkRawOutput } from '@renderer/types/sdk'
|
||||||
|
|
||||||
|
import { ProcessingState } from './types'
|
||||||
|
|
||||||
|
// ============================================================================
|
||||||
|
// Core Request Types
|
||||||
|
// ============================================================================
|
||||||
|
|
||||||
|
/**
* The normalized internal core-request structure used for unified handling across all AI providers.
* This is the standard format produced by application-level parameter conversion.
*/
|
||||||
|
export interface CompletionsParams {
|
||||||
|
/**
* The business scenario of the call, used by middlewares to decide whether they should run.
* 'chat': main conversation flow
* 'translate': translation
* 'summary': summarization
* 'search': search summarization
* 'generate': generation
* 'check': API check
*/
|
||||||
|
callType?: 'chat' | 'translate' | 'summary' | 'search' | 'generate' | 'check'
|
||||||
|
|
||||||
|
// 基础对话数据
|
||||||
|
messages: Message[] | string // union type makes it easy to check for emptiness
|
||||||
|
|
||||||
|
assistant: Assistant // the assistant is the basic unit
|
||||||
|
// model: Model
|
||||||
|
|
||||||
|
onChunk?: (chunk: Chunk) => void
|
||||||
|
onResponse?: (text: string, isComplete: boolean) => void
|
||||||
|
|
||||||
|
// 错误相关
|
||||||
|
onError?: (error: Error) => void
|
||||||
|
shouldThrow?: boolean
|
||||||
|
|
||||||
|
// 工具相关
|
||||||
|
mcpTools?: MCPTool[]
|
||||||
|
|
||||||
|
// 生成参数
|
||||||
|
temperature?: number
|
||||||
|
topP?: number
|
||||||
|
maxTokens?: number
|
||||||
|
|
||||||
|
// 功能开关
|
||||||
|
streamOutput: boolean
|
||||||
|
enableWebSearch?: boolean
|
||||||
|
enableReasoning?: boolean
|
||||||
|
enableGenerateImage?: boolean
|
||||||
|
|
||||||
|
// 上下文控制
|
||||||
|
contextCount?: number
|
||||||
|
|
||||||
|
_internal?: ProcessingState
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface CompletionsResult {
|
||||||
|
rawOutput?: SdkRawOutput
|
||||||
|
stream?: ReadableStream<SdkRawChunk> | ReadableStream<Chunk> | AsyncIterable<Chunk>
|
||||||
|
controller?: AbortController
|
||||||
|
|
||||||
|
getText: () => string
|
||||||
|
}
|
||||||
|
|
||||||
|
// ============================================================================
|
||||||
|
// Generic Chunk Types
|
||||||
|
// ============================================================================
|
||||||
|
|
||||||
|
/**
* Generic chunk type.
* Reuses the existing Chunk type; this is the normalized chunk format every AI provider should emit.
*/
|
||||||
|
export type GenericChunk = Chunk
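For illustration, a minimal CompletionsParams for the main chat flow could be assembled as follows; the assistant and messages come from application state, and everything else here is an assumption made only for the example:

import type { Assistant } from '@renderer/types'
import type { Message } from '@renderer/types/newMessage'
import type { CompletionsParams } from './schemas'

declare const assistant: Assistant // provided by application state, including assistant.model
declare const messages: Message[] // messages of the current topic

const params: CompletionsParams = {
  callType: 'chat',
  messages,
  assistant,
  streamOutput: true,
  mcpTools: [],
  onChunk: (chunk) => console.log('chunk:', chunk.type),
  onResponse: (text, isComplete) => {
    if (isComplete) console.log('final text length:', text.length)
  }
}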
|
||||||
166
src/renderer/src/aiCore/middleware/types.ts
Normal file
@@ -0,0 +1,166 @@
|
|||||||
|
import { MCPToolResponse, Metrics, Usage, WebSearchResponse } from '@renderer/types'
|
||||||
|
import { Chunk, ErrorChunk } from '@renderer/types/chunk'
|
||||||
|
import {
|
||||||
|
SdkInstance,
|
||||||
|
SdkMessageParam,
|
||||||
|
SdkParams,
|
||||||
|
SdkRawChunk,
|
||||||
|
SdkRawOutput,
|
||||||
|
SdkTool,
|
||||||
|
SdkToolCall
|
||||||
|
} from '@renderer/types/sdk'
|
||||||
|
|
||||||
|
import { BaseApiClient } from '../clients'
|
||||||
|
import { CompletionsParams, CompletionsResult } from './schemas'
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Symbol to uniquely identify middleware context objects.
|
||||||
|
*/
|
||||||
|
export const MIDDLEWARE_CONTEXT_SYMBOL = Symbol.for('AiProviderMiddlewareContext')
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Defines the structure for the onChunk callback function.
|
||||||
|
*/
|
||||||
|
export type OnChunkFunction = (chunk: Chunk | ErrorChunk) => void
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Base context that carries information about the current method call.
|
||||||
|
*/
|
||||||
|
export interface BaseContext {
|
||||||
|
[MIDDLEWARE_CONTEXT_SYMBOL]: true
|
||||||
|
methodName: string
|
||||||
|
originalArgs: Readonly<any[]>
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Processing state shared between middlewares.
|
||||||
|
*/
|
||||||
|
export interface ProcessingState<
|
||||||
|
TParams extends SdkParams = SdkParams,
|
||||||
|
TMessageParam extends SdkMessageParam = SdkMessageParam,
|
||||||
|
TToolCall extends SdkToolCall = SdkToolCall
|
||||||
|
> {
|
||||||
|
sdkPayload?: TParams
|
||||||
|
newReqMessages?: TMessageParam[]
|
||||||
|
observer?: {
|
||||||
|
usage?: Usage
|
||||||
|
metrics?: Metrics
|
||||||
|
}
|
||||||
|
toolProcessingState?: {
|
||||||
|
pendingToolCalls?: Array<TToolCall>
|
||||||
|
executingToolCalls?: Array<{
|
||||||
|
sdkToolCall: TToolCall
|
||||||
|
mcpToolResponse: MCPToolResponse
|
||||||
|
}>
|
||||||
|
output?: SdkRawOutput | string
|
||||||
|
isRecursiveCall?: boolean
|
||||||
|
recursionDepth?: number
|
||||||
|
}
|
||||||
|
webSearchState?: {
|
||||||
|
results?: WebSearchResponse
|
||||||
|
}
|
||||||
|
flowControl?: {
|
||||||
|
abortController?: AbortController
|
||||||
|
abortSignal?: AbortSignal
|
||||||
|
cleanup?: () => void
|
||||||
|
}
|
||||||
|
enhancedDispatch?: (context: CompletionsContext, params: CompletionsParams) => Promise<CompletionsResult>
|
||||||
|
customState?: Record<string, any>
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Extended context for completions method.
|
||||||
|
*/
|
||||||
|
export interface CompletionsContext<
|
||||||
|
TSdkParams extends SdkParams = SdkParams,
|
||||||
|
TSdkMessageParam extends SdkMessageParam = SdkMessageParam,
|
||||||
|
TSdkToolCall extends SdkToolCall = SdkToolCall,
|
||||||
|
TSdkInstance extends SdkInstance = SdkInstance,
|
||||||
|
TRawOutput extends SdkRawOutput = SdkRawOutput,
|
||||||
|
TRawChunk extends SdkRawChunk = SdkRawChunk,
|
||||||
|
TSdkSpecificTool extends SdkTool = SdkTool
|
||||||
|
> extends BaseContext {
|
||||||
|
readonly methodName: 'completions' // the method name is fixed to 'completions'
|
||||||
|
|
||||||
|
apiClientInstance: BaseApiClient<
|
||||||
|
TSdkInstance,
|
||||||
|
TSdkParams,
|
||||||
|
TRawOutput,
|
||||||
|
TRawChunk,
|
||||||
|
TSdkMessageParam,
|
||||||
|
TSdkToolCall,
|
||||||
|
TSdkSpecificTool
|
||||||
|
>
|
||||||
|
|
||||||
|
// --- Mutable internal state for the duration of the middleware chain ---
|
||||||
|
_internal: ProcessingState<TSdkParams, TSdkMessageParam, TSdkToolCall> // holds all mutable processing state
|
||||||
|
}
|
||||||
|
|
||||||
|
export interface MiddlewareAPI<Ctx extends BaseContext = BaseContext, Args extends any[] = any[]> {
|
||||||
|
getContext: () => Ctx // Function to get the current context / 获取当前上下文的函数
|
||||||
|
getOriginalArgs: () => Args // Function to get the original arguments of the method call / 获取方法调用原始参数的函数
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
 * Base middleware type.
 */
export type Middleware<TContext extends BaseContext> = (
  api: MiddlewareAPI<TContext>
) => (
  next: (context: TContext, args: any[]) => Promise<unknown>
) => (context: TContext, args: any[]) => Promise<unknown>

export type MethodMiddleware = Middleware<BaseContext>
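As a hedged illustration of the curried shape above, a hypothetical pass-through middleware that logs the method name before delegating; the logging behaviour is chosen only for the example and is not defined by this diff:

const loggingMiddleware: MethodMiddleware = (api) => (next) => async (context, args) => {
  // Read the current context via the MiddlewareAPI and report which method is running.
  console.log(`[middleware] calling ${api.getContext().methodName} with ${args.length} arg(s)`)
  // Delegate to the next middleware (or the underlying method) unchanged.
  return next(context, args)
}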
/**
 * Completions middleware type.
 */
export type CompletionsMiddleware<
  TSdkParams extends SdkParams = SdkParams,
  TSdkMessageParam extends SdkMessageParam = SdkMessageParam,
  TSdkToolCall extends SdkToolCall = SdkToolCall,
  TSdkInstance extends SdkInstance = SdkInstance,
  TRawOutput extends SdkRawOutput = SdkRawOutput,
  TRawChunk extends SdkRawChunk = SdkRawChunk,
  TSdkSpecificTool extends SdkTool = SdkTool
> = (
  api: MiddlewareAPI<
    CompletionsContext<
      TSdkParams,
      TSdkMessageParam,
      TSdkToolCall,
      TSdkInstance,
      TRawOutput,
      TRawChunk,
      TSdkSpecificTool
    >,
    [CompletionsParams]
  >
) => (
  next: (
    context: CompletionsContext<
      TSdkParams,
      TSdkMessageParam,
      TSdkToolCall,
      TSdkInstance,
      TRawOutput,
      TRawChunk,
      TSdkSpecificTool
    >,
    params: CompletionsParams
  ) => Promise<CompletionsResult>
) => (
  context: CompletionsContext<
    TSdkParams,
    TSdkMessageParam,
    TSdkToolCall,
    TSdkInstance,
    TRawOutput,
    TRawChunk,
    TSdkSpecificTool
  >,
  params: CompletionsParams
) => Promise<CompletionsResult>

// Re-export for convenience
export type { Chunk as OnChunkArg } from '@renderer/types/chunk'
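For illustration, a hedged sketch of a CompletionsMiddleware that records a recursion depth in the shared ProcessingState before delegating; the bookkeeping shown is an assumption for the example, not behaviour defined by this diff:

const depthTrackingMiddleware: CompletionsMiddleware = () => (next) => async (ctx, params) => {
  // Ensure the tool-processing state exists, then bump the recursion depth counter (illustrative only).
  ctx._internal.toolProcessingState ??= {}
  ctx._internal.toolProcessingState.recursionDepth = (ctx._internal.toolProcessingState.recursionDepth ?? 0) + 1
  // Hand off to the rest of the chain unchanged.
  return next(ctx, params)
}

Middlewares that need scratch space of their own could instead use the customState bag, which exists precisely so that ad-hoc keys do not collide with the typed fields above.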
src/renderer/src/aiCore/middleware/utils.ts (new file, 57 lines)
@@ -0,0 +1,57 @@
import { ChunkType, ErrorChunk } from '@renderer/types/chunk'

/**
 * Creates an ErrorChunk object with a standardized structure.
 * @param error The error object or message.
 * @param chunkType The type of chunk, defaults to ChunkType.ERROR.
 * @returns An ErrorChunk object.
 */
export function createErrorChunk(error: any, chunkType: ChunkType = ChunkType.ERROR): ErrorChunk {
  let errorDetails: Record<string, any> = {}

  if (error instanceof Error) {
    errorDetails = {
      message: error.message,
      name: error.name,
      stack: error.stack
    }
  } else if (typeof error === 'string') {
    errorDetails = { message: error }
  } else if (typeof error === 'object' && error !== null) {
    errorDetails = Object.getOwnPropertyNames(error).reduce(
      (acc, key) => {
        acc[key] = error[key]
        return acc
      },
      {} as Record<string, any>
    )
    if (!errorDetails.message && error.toString && typeof error.toString === 'function') {
      const errMsg = error.toString()
      if (errMsg !== '[object Object]') {
        errorDetails.message = errMsg
      }
    }
  }

  return {
    type: chunkType,
    error: errorDetails
  } as ErrorChunk
}
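A hedged usage sketch: the thrown error and the consuming callback below are assumptions for the example, reusing the OnChunkFunction type from the middleware types above:

const onChunk: OnChunkFunction = (chunk) => console.log(chunk)

try {
  JSON.parse('not valid json') // deliberately throws a SyntaxError
} catch (err) {
  // Normalize whatever was thrown into an ErrorChunk and hand it to the callback.
  onChunk(createErrorChunk(err))
}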
// Helper to capitalize method names for hook construction
export function capitalize(str: string): string {
  if (!str) return ''
  return str.charAt(0).toUpperCase() + str.slice(1)
}
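For example, capitalize can turn a method name into a hook-style suffix; the hook name shown here is purely illustrative and not defined by this diff:

const methodName = 'completions'
const hookName = `on${capitalize(methodName)}Start` // 'onCompletionsStart' (hypothetical hook name)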
/**
 * Checks whether an object implements the AsyncIterable interface.
 */
export function isAsyncIterable<T = unknown>(obj: unknown): obj is AsyncIterable<T> {
  return (
    obj !== null &&
    typeof obj === 'object' &&
    typeof (obj as Record<symbol, unknown>)[Symbol.asyncIterator] === 'function'
  )
}
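A hedged sketch of how a caller might use this guard to branch between streaming and non-streaming handling of an SDK response; the function and variable names are assumptions for the example:

async function consumeSdkOutput(output: unknown): Promise<void> {
  if (isAsyncIterable<string>(output)) {
    // Streaming case: iterate chunks as they arrive.
    for await (const piece of output) {
      console.log('chunk:', piece)
    }
  } else {
    // Non-streaming case: treat the output as a single value.
    console.log('complete response:', output)
  }
}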
BIN  src/renderer/src/assets/images/models/gpt_image_1.png  (new file)  After: 20 KiB
BIN  src/renderer/src/assets/images/providers/302ai.webp  (new file)  After: 6.9 KiB
BIN  src/renderer/src/assets/images/providers/cephalon.jpeg  (new file)  After: 5.9 KiB
BIN  src/renderer/src/assets/images/providers/dmxapi-logo-dark.webp  (new file)  After: 10 KiB
BIN  (modified image, path not captured)  Before: 5.4 KiB  After: 12 KiB
BIN  src/renderer/src/assets/images/providers/lanyun.png  (new file)  After: 16 KiB
BIN  (new image, path not captured)  After: 5.6 KiB
BIN  src/renderer/src/assets/images/providers/nomic.png  (new file)  After: 2.3 KiB
BIN  (modified image, path not captured)  Before: 1.9 KiB  After: 1.5 KiB