Compare commits
56 Commits
fix/openmi
...
feat/mcp-u
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
22ed2b605e | ||
|
|
2c3338939e | ||
|
|
64ca3802a4 | ||
|
|
fa361126b8 | ||
|
|
49903a1567 | ||
|
|
086b16a59c | ||
|
|
e2562d8224 | ||
|
|
c9be949853 | ||
|
|
ebfb1c5abf | ||
|
|
c1f1d7996d | ||
|
|
0a72c613af | ||
|
|
a1ac3207f1 | ||
|
|
f98a063a8f | ||
|
|
1cb2af57ae | ||
|
|
62309ae1bf | ||
|
|
c48f222cdb | ||
|
|
cea0058f87 | ||
|
|
852192dce6 | ||
|
|
eee49d1580 | ||
|
|
dcdd1bf852 | ||
|
|
a12b6bfeca | ||
|
|
0f1a487bb0 | ||
|
|
2df8bb58df | ||
|
|
62976f6fe0 | ||
|
|
77529b3cd3 | ||
|
|
c8e9a10190 | ||
|
|
0e011ff35f | ||
|
|
40a64a7c92 | ||
|
|
dc9503ef8b | ||
|
|
f2c8484c48 | ||
|
|
a9c9224835 | ||
|
|
43223fd1f5 | ||
|
|
4bac843b37 | ||
|
|
34723934f4 | ||
|
|
096c36caf8 | ||
|
|
139950e193 | ||
|
|
31eec403f7 | ||
|
|
7fd4837a47 | ||
|
|
90b0c8b4a6 | ||
|
|
556353e910 | ||
|
|
11fb730b4d | ||
|
|
2511113b62 | ||
|
|
a29b2bb3d6 | ||
|
|
d2be450906 | ||
|
|
9c020f0d56 | ||
|
|
e033eb5b5c | ||
|
|
073d43c7cb | ||
|
|
fa7646e18f | ||
|
|
038d30831c | ||
|
|
68ee5164f0 | ||
|
|
a1a3b9bd96 | ||
|
|
4e699c48bc | ||
|
|
75fcf8fbb5 | ||
|
|
35aa9d7355 | ||
|
|
b08aecb22b | ||
|
|
45fc6c2afd |
1
.github/CODEOWNERS
vendored
1
.github/CODEOWNERS
vendored
@@ -3,3 +3,4 @@
|
|||||||
/src/main/services/ConfigManager.ts @0xfullex
|
/src/main/services/ConfigManager.ts @0xfullex
|
||||||
/packages/shared/IpcChannel.ts @0xfullex
|
/packages/shared/IpcChannel.ts @0xfullex
|
||||||
/src/main/ipc.ts @0xfullex
|
/src/main/ipc.ts @0xfullex
|
||||||
|
/app-upgrade-config.json @kangfenmao
|
||||||
|
|||||||
2
.github/workflows/auto-i18n.yml
vendored
2
.github/workflows/auto-i18n.yml
vendored
@@ -77,7 +77,7 @@ jobs:
|
|||||||
with:
|
with:
|
||||||
token: ${{ secrets.GITHUB_TOKEN }} # Use the built-in GITHUB_TOKEN for bot actions
|
token: ${{ secrets.GITHUB_TOKEN }} # Use the built-in GITHUB_TOKEN for bot actions
|
||||||
commit-message: "feat(bot): Weekly automated script run"
|
commit-message: "feat(bot): Weekly automated script run"
|
||||||
title: "🤖 Weekly Automated Update: ${{ env.CURRENT_DATE }}"
|
title: "🤖 Weekly Auto I18N Sync: ${{ env.CURRENT_DATE }}"
|
||||||
body: |
|
body: |
|
||||||
This PR includes changes generated by the weekly auto i18n.
|
This PR includes changes generated by the weekly auto i18n.
|
||||||
Review the changes before merging.
|
Review the changes before merging.
|
||||||
|
|||||||
212
.github/workflows/update-app-upgrade-config.yml
vendored
Normal file
212
.github/workflows/update-app-upgrade-config.yml
vendored
Normal file
@@ -0,0 +1,212 @@
|
|||||||
|
name: Update App Upgrade Config
|
||||||
|
|
||||||
|
on:
|
||||||
|
release:
|
||||||
|
types:
|
||||||
|
- released
|
||||||
|
- prereleased
|
||||||
|
workflow_dispatch:
|
||||||
|
inputs:
|
||||||
|
tag:
|
||||||
|
description: "Release tag (e.g., v1.2.3)"
|
||||||
|
required: true
|
||||||
|
type: string
|
||||||
|
is_prerelease:
|
||||||
|
description: "Mark the tag as a prerelease when running manually"
|
||||||
|
required: false
|
||||||
|
default: false
|
||||||
|
type: boolean
|
||||||
|
|
||||||
|
permissions:
|
||||||
|
contents: write
|
||||||
|
pull-requests: write
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
propose-update:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
if: github.event_name == 'workflow_dispatch' || (github.event_name == 'release' && github.event.release.draft == false)
|
||||||
|
|
||||||
|
steps:
|
||||||
|
- name: Check if should proceed
|
||||||
|
id: check
|
||||||
|
run: |
|
||||||
|
EVENT="${{ github.event_name }}"
|
||||||
|
|
||||||
|
if [ "$EVENT" = "workflow_dispatch" ]; then
|
||||||
|
TAG="${{ github.event.inputs.tag }}"
|
||||||
|
else
|
||||||
|
TAG="${{ github.event.release.tag_name }}"
|
||||||
|
fi
|
||||||
|
|
||||||
|
latest_tag=$(
|
||||||
|
curl -L \
|
||||||
|
-H "Accept: application/vnd.github+json" \
|
||||||
|
-H "Authorization: Bearer ${{ github.token }}" \
|
||||||
|
-H "X-GitHub-Api-Version: 2022-11-28" \
|
||||||
|
https://api.github.com/repos/${{ github.repository }}/releases/latest \
|
||||||
|
| jq -r '.tag_name'
|
||||||
|
)
|
||||||
|
|
||||||
|
if [ "$EVENT" = "workflow_dispatch" ]; then
|
||||||
|
MANUAL_IS_PRERELEASE="${{ github.event.inputs.is_prerelease }}"
|
||||||
|
if [ -z "$MANUAL_IS_PRERELEASE" ]; then
|
||||||
|
MANUAL_IS_PRERELEASE="false"
|
||||||
|
fi
|
||||||
|
if [ "$MANUAL_IS_PRERELEASE" = "true" ]; then
|
||||||
|
if ! echo "$TAG" | grep -E '(-beta([.-][0-9]+)?|-rc([.-][0-9]+)?)' >/dev/null; then
|
||||||
|
echo "Manual prerelease flag set but tag $TAG lacks beta/rc suffix. Skipping." >&2
|
||||||
|
echo "should_run=false" >> "$GITHUB_OUTPUT"
|
||||||
|
echo "is_prerelease=false" >> "$GITHUB_OUTPUT"
|
||||||
|
echo "latest_tag=$latest_tag" >> "$GITHUB_OUTPUT"
|
||||||
|
exit 0
|
||||||
|
fi
|
||||||
|
fi
|
||||||
|
echo "should_run=true" >> "$GITHUB_OUTPUT"
|
||||||
|
echo "is_prerelease=$MANUAL_IS_PRERELEASE" >> "$GITHUB_OUTPUT"
|
||||||
|
echo "latest_tag=$latest_tag" >> "$GITHUB_OUTPUT"
|
||||||
|
exit 0
|
||||||
|
fi
|
||||||
|
|
||||||
|
IS_PRERELEASE="${{ github.event.release.prerelease }}"
|
||||||
|
|
||||||
|
if [ "$IS_PRERELEASE" = "true" ]; then
|
||||||
|
if ! echo "$TAG" | grep -E '(-beta([.-][0-9]+)?|-rc([.-][0-9]+)?)' >/dev/null; then
|
||||||
|
echo "Release marked as prerelease but tag $TAG lacks beta/rc suffix. Skipping." >&2
|
||||||
|
echo "should_run=false" >> "$GITHUB_OUTPUT"
|
||||||
|
echo "is_prerelease=false" >> "$GITHUB_OUTPUT"
|
||||||
|
echo "latest_tag=$latest_tag" >> "$GITHUB_OUTPUT"
|
||||||
|
exit 0
|
||||||
|
fi
|
||||||
|
echo "should_run=true" >> "$GITHUB_OUTPUT"
|
||||||
|
echo "is_prerelease=true" >> "$GITHUB_OUTPUT"
|
||||||
|
echo "latest_tag=$latest_tag" >> "$GITHUB_OUTPUT"
|
||||||
|
echo "Release is prerelease, proceeding"
|
||||||
|
exit 0
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [[ "${latest_tag}" == "$TAG" ]]; then
|
||||||
|
echo "should_run=true" >> "$GITHUB_OUTPUT"
|
||||||
|
echo "is_prerelease=false" >> "$GITHUB_OUTPUT"
|
||||||
|
echo "latest_tag=$latest_tag" >> "$GITHUB_OUTPUT"
|
||||||
|
echo "Release is latest, proceeding"
|
||||||
|
else
|
||||||
|
echo "should_run=false" >> "$GITHUB_OUTPUT"
|
||||||
|
echo "is_prerelease=false" >> "$GITHUB_OUTPUT"
|
||||||
|
echo "latest_tag=$latest_tag" >> "$GITHUB_OUTPUT"
|
||||||
|
echo "Release is neither prerelease nor latest, skipping"
|
||||||
|
fi
|
||||||
|
|
||||||
|
- name: Prepare metadata
|
||||||
|
id: meta
|
||||||
|
if: steps.check.outputs.should_run == 'true'
|
||||||
|
run: |
|
||||||
|
EVENT="${{ github.event_name }}"
|
||||||
|
LATEST_TAG="${{ steps.check.outputs.latest_tag }}"
|
||||||
|
if [ "$EVENT" = "release" ]; then
|
||||||
|
TAG="${{ github.event.release.tag_name }}"
|
||||||
|
PRE="${{ github.event.release.prerelease }}"
|
||||||
|
|
||||||
|
if [ -n "$LATEST_TAG" ] && [ "$LATEST_TAG" = "$TAG" ]; then
|
||||||
|
LATEST="true"
|
||||||
|
else
|
||||||
|
LATEST="false"
|
||||||
|
fi
|
||||||
|
TRIGGER="release"
|
||||||
|
else
|
||||||
|
TAG="${{ github.event.inputs.tag }}"
|
||||||
|
PRE="${{ github.event.inputs.is_prerelease }}"
|
||||||
|
if [ -z "$PRE" ]; then
|
||||||
|
PRE="false"
|
||||||
|
fi
|
||||||
|
if [ -n "$LATEST_TAG" ] && [ "$LATEST_TAG" = "$TAG" ] && [ "$PRE" != "true" ]; then
|
||||||
|
LATEST="true"
|
||||||
|
else
|
||||||
|
LATEST="false"
|
||||||
|
fi
|
||||||
|
TRIGGER="manual"
|
||||||
|
fi
|
||||||
|
|
||||||
|
SAFE_TAG=$(echo "$TAG" | sed 's/[^A-Za-z0-9._-]/-/g')
|
||||||
|
echo "tag=$TAG" >> "$GITHUB_OUTPUT"
|
||||||
|
echo "safe_tag=$SAFE_TAG" >> "$GITHUB_OUTPUT"
|
||||||
|
echo "prerelease=$PRE" >> "$GITHUB_OUTPUT"
|
||||||
|
echo "latest=$LATEST" >> "$GITHUB_OUTPUT"
|
||||||
|
echo "trigger=$TRIGGER" >> "$GITHUB_OUTPUT"
|
||||||
|
|
||||||
|
- name: Checkout default branch
|
||||||
|
if: steps.check.outputs.should_run == 'true'
|
||||||
|
uses: actions/checkout@v5
|
||||||
|
with:
|
||||||
|
ref: ${{ github.event.repository.default_branch }}
|
||||||
|
path: main
|
||||||
|
fetch-depth: 0
|
||||||
|
|
||||||
|
- name: Checkout x-files/app-upgrade-config branch
|
||||||
|
if: steps.check.outputs.should_run == 'true'
|
||||||
|
uses: actions/checkout@v5
|
||||||
|
with:
|
||||||
|
ref: x-files/app-upgrade-config
|
||||||
|
path: cs
|
||||||
|
fetch-depth: 0
|
||||||
|
|
||||||
|
- name: Setup Node.js
|
||||||
|
if: steps.check.outputs.should_run == 'true'
|
||||||
|
uses: actions/setup-node@v4
|
||||||
|
with:
|
||||||
|
node-version: 22
|
||||||
|
|
||||||
|
- name: Enable Corepack
|
||||||
|
if: steps.check.outputs.should_run == 'true'
|
||||||
|
run: corepack enable && corepack prepare yarn@4.9.1 --activate
|
||||||
|
|
||||||
|
- name: Install dependencies
|
||||||
|
if: steps.check.outputs.should_run == 'true'
|
||||||
|
working-directory: main
|
||||||
|
run: yarn install --immutable
|
||||||
|
|
||||||
|
- name: Update upgrade config
|
||||||
|
if: steps.check.outputs.should_run == 'true'
|
||||||
|
working-directory: main
|
||||||
|
env:
|
||||||
|
RELEASE_TAG: ${{ steps.meta.outputs.tag }}
|
||||||
|
IS_PRERELEASE: ${{ steps.check.outputs.is_prerelease }}
|
||||||
|
run: |
|
||||||
|
yarn tsx scripts/update-app-upgrade-config.ts \
|
||||||
|
--tag "$RELEASE_TAG" \
|
||||||
|
--config ../cs/app-upgrade-config.json \
|
||||||
|
--is-prerelease "$IS_PRERELEASE"
|
||||||
|
|
||||||
|
- name: Detect changes
|
||||||
|
if: steps.check.outputs.should_run == 'true'
|
||||||
|
id: diff
|
||||||
|
working-directory: cs
|
||||||
|
run: |
|
||||||
|
if git diff --quiet -- app-upgrade-config.json; then
|
||||||
|
echo "changed=false" >> "$GITHUB_OUTPUT"
|
||||||
|
else
|
||||||
|
echo "changed=true" >> "$GITHUB_OUTPUT"
|
||||||
|
fi
|
||||||
|
|
||||||
|
- name: Create pull request
|
||||||
|
if: steps.check.outputs.should_run == 'true' && steps.diff.outputs.changed == 'true'
|
||||||
|
uses: peter-evans/create-pull-request@v7
|
||||||
|
with:
|
||||||
|
path: cs
|
||||||
|
base: x-files/app-upgrade-config
|
||||||
|
branch: chore/update-app-upgrade-config/${{ steps.meta.outputs.safe_tag }}
|
||||||
|
commit-message: "🤖 chore: sync app-upgrade-config for ${{ steps.meta.outputs.tag }}"
|
||||||
|
title: "chore: update app-upgrade-config for ${{ steps.meta.outputs.tag }}"
|
||||||
|
body: |
|
||||||
|
Automated update triggered by `${{ steps.meta.outputs.trigger }}`.
|
||||||
|
|
||||||
|
- Source tag: `${{ steps.meta.outputs.tag }}`
|
||||||
|
- Pre-release: `${{ steps.meta.outputs.prerelease }}`
|
||||||
|
- Latest: `${{ steps.meta.outputs.latest }}`
|
||||||
|
- Workflow run: https://github.com/${{ github.repository }}/actions/runs/${{ github.run_id }}
|
||||||
|
labels: |
|
||||||
|
automation
|
||||||
|
app-upgrade
|
||||||
|
|
||||||
|
- name: No changes detected
|
||||||
|
if: steps.check.outputs.should_run == 'true' && steps.diff.outputs.changed != 'true'
|
||||||
|
run: echo "No updates required for x-files/app-upgrade-config/app-upgrade-config.json"
|
||||||
@@ -1,8 +1,8 @@
|
|||||||
diff --git a/dist/index.js b/dist/index.js
|
diff --git a/dist/index.js b/dist/index.js
|
||||||
index ff305b112779b718f21a636a27b1196125a332d9..cf32ff5086d4d9e56f8fe90c98724559083bafc3 100644
|
index dc7b74ba55337c491cdf1ab3e39ca68cc4187884..ace8c90591288e42c2957e93c9bf7984f1b22444 100644
|
||||||
--- a/dist/index.js
|
--- a/dist/index.js
|
||||||
+++ b/dist/index.js
|
+++ b/dist/index.js
|
||||||
@@ -471,7 +471,7 @@ function convertToGoogleGenerativeAIMessages(prompt, options) {
|
@@ -472,7 +472,7 @@ function convertToGoogleGenerativeAIMessages(prompt, options) {
|
||||||
|
|
||||||
// src/get-model-path.ts
|
// src/get-model-path.ts
|
||||||
function getModelPath(modelId) {
|
function getModelPath(modelId) {
|
||||||
@@ -12,10 +12,10 @@ index ff305b112779b718f21a636a27b1196125a332d9..cf32ff5086d4d9e56f8fe90c98724559
|
|||||||
|
|
||||||
// src/google-generative-ai-options.ts
|
// src/google-generative-ai-options.ts
|
||||||
diff --git a/dist/index.mjs b/dist/index.mjs
|
diff --git a/dist/index.mjs b/dist/index.mjs
|
||||||
index 57659290f1cec74878a385626ad75b2a4d5cd3fc..d04e5927ec3725b6ffdb80868bfa1b5a48849537 100644
|
index 8390439c38cb7eaeb52080862cd6f4c58509e67c..a7647f2e11700dff7e1c8d4ae8f99d3637010733 100644
|
||||||
--- a/dist/index.mjs
|
--- a/dist/index.mjs
|
||||||
+++ b/dist/index.mjs
|
+++ b/dist/index.mjs
|
||||||
@@ -477,7 +477,7 @@ function convertToGoogleGenerativeAIMessages(prompt, options) {
|
@@ -478,7 +478,7 @@ function convertToGoogleGenerativeAIMessages(prompt, options) {
|
||||||
|
|
||||||
// src/get-model-path.ts
|
// src/get-model-path.ts
|
||||||
function getModelPath(modelId) {
|
function getModelPath(modelId) {
|
||||||
@@ -1,131 +0,0 @@
|
|||||||
diff --git a/dist/index.mjs b/dist/index.mjs
|
|
||||||
index b3f018730a93639aad7c203f15fb1aeb766c73f4..ade2a43d66e9184799d072153df61ef7be4ea110 100644
|
|
||||||
--- a/dist/index.mjs
|
|
||||||
+++ b/dist/index.mjs
|
|
||||||
@@ -296,7 +296,14 @@ var HuggingFaceResponsesLanguageModel = class {
|
|
||||||
metadata: huggingfaceOptions == null ? void 0 : huggingfaceOptions.metadata,
|
|
||||||
instructions: huggingfaceOptions == null ? void 0 : huggingfaceOptions.instructions,
|
|
||||||
...preparedTools && { tools: preparedTools },
|
|
||||||
- ...preparedToolChoice && { tool_choice: preparedToolChoice }
|
|
||||||
+ ...preparedToolChoice && { tool_choice: preparedToolChoice },
|
|
||||||
+ ...(huggingfaceOptions?.reasoningEffort != null && {
|
|
||||||
+ reasoning: {
|
|
||||||
+ ...(huggingfaceOptions?.reasoningEffort != null && {
|
|
||||||
+ effort: huggingfaceOptions.reasoningEffort,
|
|
||||||
+ }),
|
|
||||||
+ },
|
|
||||||
+ }),
|
|
||||||
};
|
|
||||||
return { args: baseArgs, warnings };
|
|
||||||
}
|
|
||||||
@@ -365,6 +372,20 @@ var HuggingFaceResponsesLanguageModel = class {
|
|
||||||
}
|
|
||||||
break;
|
|
||||||
}
|
|
||||||
+ case 'reasoning': {
|
|
||||||
+ for (const contentPart of part.content) {
|
|
||||||
+ content.push({
|
|
||||||
+ type: 'reasoning',
|
|
||||||
+ text: contentPart.text,
|
|
||||||
+ providerMetadata: {
|
|
||||||
+ huggingface: {
|
|
||||||
+ itemId: part.id,
|
|
||||||
+ },
|
|
||||||
+ },
|
|
||||||
+ });
|
|
||||||
+ }
|
|
||||||
+ break;
|
|
||||||
+ }
|
|
||||||
case "mcp_call": {
|
|
||||||
content.push({
|
|
||||||
type: "tool-call",
|
|
||||||
@@ -519,6 +540,11 @@ var HuggingFaceResponsesLanguageModel = class {
|
|
||||||
id: value.item.call_id,
|
|
||||||
toolName: value.item.name
|
|
||||||
});
|
|
||||||
+ } else if (value.item.type === 'reasoning') {
|
|
||||||
+ controller.enqueue({
|
|
||||||
+ type: 'reasoning-start',
|
|
||||||
+ id: value.item.id,
|
|
||||||
+ });
|
|
||||||
}
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
@@ -570,6 +596,22 @@ var HuggingFaceResponsesLanguageModel = class {
|
|
||||||
});
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
+ if (isReasoningDeltaChunk(value)) {
|
|
||||||
+ controller.enqueue({
|
|
||||||
+ type: 'reasoning-delta',
|
|
||||||
+ id: value.item_id,
|
|
||||||
+ delta: value.delta,
|
|
||||||
+ });
|
|
||||||
+ return;
|
|
||||||
+ }
|
|
||||||
+
|
|
||||||
+ if (isReasoningEndChunk(value)) {
|
|
||||||
+ controller.enqueue({
|
|
||||||
+ type: 'reasoning-end',
|
|
||||||
+ id: value.item_id,
|
|
||||||
+ });
|
|
||||||
+ return;
|
|
||||||
+ }
|
|
||||||
},
|
|
||||||
flush(controller) {
|
|
||||||
controller.enqueue({
|
|
||||||
@@ -593,7 +635,8 @@ var HuggingFaceResponsesLanguageModel = class {
|
|
||||||
var huggingfaceResponsesProviderOptionsSchema = z2.object({
|
|
||||||
metadata: z2.record(z2.string(), z2.string()).optional(),
|
|
||||||
instructions: z2.string().optional(),
|
|
||||||
- strictJsonSchema: z2.boolean().optional()
|
|
||||||
+ strictJsonSchema: z2.boolean().optional(),
|
|
||||||
+ reasoningEffort: z2.string().optional(),
|
|
||||||
});
|
|
||||||
var huggingfaceResponsesResponseSchema = z2.object({
|
|
||||||
id: z2.string(),
|
|
||||||
@@ -727,12 +770,31 @@ var responseCreatedChunkSchema = z2.object({
|
|
||||||
model: z2.string()
|
|
||||||
})
|
|
||||||
});
|
|
||||||
+var reasoningTextDeltaChunkSchema = z2.object({
|
|
||||||
+ type: z2.literal('response.reasoning_text.delta'),
|
|
||||||
+ item_id: z2.string(),
|
|
||||||
+ output_index: z2.number(),
|
|
||||||
+ content_index: z2.number(),
|
|
||||||
+ delta: z2.string(),
|
|
||||||
+ sequence_number: z2.number(),
|
|
||||||
+});
|
|
||||||
+
|
|
||||||
+var reasoningTextEndChunkSchema = z2.object({
|
|
||||||
+ type: z2.literal('response.reasoning_text.done'),
|
|
||||||
+ item_id: z2.string(),
|
|
||||||
+ output_index: z2.number(),
|
|
||||||
+ content_index: z2.number(),
|
|
||||||
+ text: z2.string(),
|
|
||||||
+ sequence_number: z2.number(),
|
|
||||||
+});
|
|
||||||
var huggingfaceResponsesChunkSchema = z2.union([
|
|
||||||
responseOutputItemAddedSchema,
|
|
||||||
responseOutputItemDoneSchema,
|
|
||||||
textDeltaChunkSchema,
|
|
||||||
responseCompletedChunkSchema,
|
|
||||||
responseCreatedChunkSchema,
|
|
||||||
+ reasoningTextDeltaChunkSchema,
|
|
||||||
+ reasoningTextEndChunkSchema,
|
|
||||||
z2.object({ type: z2.string() }).loose()
|
|
||||||
// fallback for unknown chunks
|
|
||||||
]);
|
|
||||||
@@ -751,6 +813,12 @@ function isResponseCompletedChunk(chunk) {
|
|
||||||
function isResponseCreatedChunk(chunk) {
|
|
||||||
return chunk.type === "response.created";
|
|
||||||
}
|
|
||||||
+function isReasoningDeltaChunk(chunk) {
|
|
||||||
+ return chunk.type === 'response.reasoning_text.delta';
|
|
||||||
+}
|
|
||||||
+function isReasoningEndChunk(chunk) {
|
|
||||||
+ return chunk.type === 'response.reasoning_text.done';
|
|
||||||
+}
|
|
||||||
|
|
||||||
// src/huggingface-provider.ts
|
|
||||||
function createHuggingFace(options = {}) {
|
|
||||||
140
.yarn/patches/@ai-sdk-openai-compatible-npm-1.0.27-06f74278cf.patch
vendored
Normal file
140
.yarn/patches/@ai-sdk-openai-compatible-npm-1.0.27-06f74278cf.patch
vendored
Normal file
@@ -0,0 +1,140 @@
|
|||||||
|
diff --git a/dist/index.js b/dist/index.js
|
||||||
|
index 73045a7d38faafdc7f7d2cd79d7ff0e2b031056b..8d948c9ac4ea4b474db9ef3c5491961e7fcf9a07 100644
|
||||||
|
--- a/dist/index.js
|
||||||
|
+++ b/dist/index.js
|
||||||
|
@@ -421,6 +421,17 @@ var OpenAICompatibleChatLanguageModel = class {
|
||||||
|
text: reasoning
|
||||||
|
});
|
||||||
|
}
|
||||||
|
+ if (choice.message.images) {
|
||||||
|
+ for (const image of choice.message.images) {
|
||||||
|
+ const match1 = image.image_url.url.match(/^data:([^;]+)/)
|
||||||
|
+ const match2 = image.image_url.url.match(/^data:[^;]*;base64,(.+)$/);
|
||||||
|
+ content.push({
|
||||||
|
+ type: 'file',
|
||||||
|
+ mediaType: match1 ? (match1[1] ?? 'image/jpeg') : 'image/jpeg',
|
||||||
|
+ data: match2 ? match2[1] : image.image_url.url,
|
||||||
|
+ });
|
||||||
|
+ }
|
||||||
|
+ }
|
||||||
|
if (choice.message.tool_calls != null) {
|
||||||
|
for (const toolCall of choice.message.tool_calls) {
|
||||||
|
content.push({
|
||||||
|
@@ -598,6 +609,17 @@ var OpenAICompatibleChatLanguageModel = class {
|
||||||
|
delta: delta.content
|
||||||
|
});
|
||||||
|
}
|
||||||
|
+ if (delta.images) {
|
||||||
|
+ for (const image of delta.images) {
|
||||||
|
+ const match1 = image.image_url.url.match(/^data:([^;]+)/)
|
||||||
|
+ const match2 = image.image_url.url.match(/^data:[^;]*;base64,(.+)$/);
|
||||||
|
+ controller.enqueue({
|
||||||
|
+ type: 'file',
|
||||||
|
+ mediaType: match1 ? (match1[1] ?? 'image/jpeg') : 'image/jpeg',
|
||||||
|
+ data: match2 ? match2[1] : image.image_url.url,
|
||||||
|
+ });
|
||||||
|
+ }
|
||||||
|
+ }
|
||||||
|
if (delta.tool_calls != null) {
|
||||||
|
for (const toolCallDelta of delta.tool_calls) {
|
||||||
|
const index = toolCallDelta.index;
|
||||||
|
@@ -765,6 +787,14 @@ var OpenAICompatibleChatResponseSchema = import_v43.z.object({
|
||||||
|
arguments: import_v43.z.string()
|
||||||
|
})
|
||||||
|
})
|
||||||
|
+ ).nullish(),
|
||||||
|
+ images: import_v43.z.array(
|
||||||
|
+ import_v43.z.object({
|
||||||
|
+ type: import_v43.z.literal('image_url'),
|
||||||
|
+ image_url: import_v43.z.object({
|
||||||
|
+ url: import_v43.z.string(),
|
||||||
|
+ })
|
||||||
|
+ })
|
||||||
|
).nullish()
|
||||||
|
}),
|
||||||
|
finish_reason: import_v43.z.string().nullish()
|
||||||
|
@@ -795,6 +825,14 @@ var createOpenAICompatibleChatChunkSchema = (errorSchema) => import_v43.z.union(
|
||||||
|
arguments: import_v43.z.string().nullish()
|
||||||
|
})
|
||||||
|
})
|
||||||
|
+ ).nullish(),
|
||||||
|
+ images: import_v43.z.array(
|
||||||
|
+ import_v43.z.object({
|
||||||
|
+ type: import_v43.z.literal('image_url'),
|
||||||
|
+ image_url: import_v43.z.object({
|
||||||
|
+ url: import_v43.z.string(),
|
||||||
|
+ })
|
||||||
|
+ })
|
||||||
|
).nullish()
|
||||||
|
}).nullish(),
|
||||||
|
finish_reason: import_v43.z.string().nullish()
|
||||||
|
diff --git a/dist/index.mjs b/dist/index.mjs
|
||||||
|
index 1c2b9560bbfbfe10cb01af080aeeed4ff59db29c..2c8ddc4fc9bfc5e7e06cfca105d197a08864c427 100644
|
||||||
|
--- a/dist/index.mjs
|
||||||
|
+++ b/dist/index.mjs
|
||||||
|
@@ -405,6 +405,17 @@ var OpenAICompatibleChatLanguageModel = class {
|
||||||
|
text: reasoning
|
||||||
|
});
|
||||||
|
}
|
||||||
|
+ if (choice.message.images) {
|
||||||
|
+ for (const image of choice.message.images) {
|
||||||
|
+ const match1 = image.image_url.url.match(/^data:([^;]+)/)
|
||||||
|
+ const match2 = image.image_url.url.match(/^data:[^;]*;base64,(.+)$/);
|
||||||
|
+ content.push({
|
||||||
|
+ type: 'file',
|
||||||
|
+ mediaType: match1 ? (match1[1] ?? 'image/jpeg') : 'image/jpeg',
|
||||||
|
+ data: match2 ? match2[1] : image.image_url.url,
|
||||||
|
+ });
|
||||||
|
+ }
|
||||||
|
+ }
|
||||||
|
if (choice.message.tool_calls != null) {
|
||||||
|
for (const toolCall of choice.message.tool_calls) {
|
||||||
|
content.push({
|
||||||
|
@@ -582,6 +593,17 @@ var OpenAICompatibleChatLanguageModel = class {
|
||||||
|
delta: delta.content
|
||||||
|
});
|
||||||
|
}
|
||||||
|
+ if (delta.images) {
|
||||||
|
+ for (const image of delta.images) {
|
||||||
|
+ const match1 = image.image_url.url.match(/^data:([^;]+)/)
|
||||||
|
+ const match2 = image.image_url.url.match(/^data:[^;]*;base64,(.+)$/);
|
||||||
|
+ controller.enqueue({
|
||||||
|
+ type: 'file',
|
||||||
|
+ mediaType: match1 ? (match1[1] ?? 'image/jpeg') : 'image/jpeg',
|
||||||
|
+ data: match2 ? match2[1] : image.image_url.url,
|
||||||
|
+ });
|
||||||
|
+ }
|
||||||
|
+ }
|
||||||
|
if (delta.tool_calls != null) {
|
||||||
|
for (const toolCallDelta of delta.tool_calls) {
|
||||||
|
const index = toolCallDelta.index;
|
||||||
|
@@ -749,6 +771,14 @@ var OpenAICompatibleChatResponseSchema = z3.object({
|
||||||
|
arguments: z3.string()
|
||||||
|
})
|
||||||
|
})
|
||||||
|
+ ).nullish(),
|
||||||
|
+ images: z3.array(
|
||||||
|
+ z3.object({
|
||||||
|
+ type: z3.literal('image_url'),
|
||||||
|
+ image_url: z3.object({
|
||||||
|
+ url: z3.string(),
|
||||||
|
+ })
|
||||||
|
+ })
|
||||||
|
).nullish()
|
||||||
|
}),
|
||||||
|
finish_reason: z3.string().nullish()
|
||||||
|
@@ -779,6 +809,14 @@ var createOpenAICompatibleChatChunkSchema = (errorSchema) => z3.union([
|
||||||
|
arguments: z3.string().nullish()
|
||||||
|
})
|
||||||
|
})
|
||||||
|
+ ).nullish(),
|
||||||
|
+ images: z3.array(
|
||||||
|
+ z3.object({
|
||||||
|
+ type: z3.literal('image_url'),
|
||||||
|
+ image_url: z3.object({
|
||||||
|
+ url: z3.string(),
|
||||||
|
+ })
|
||||||
|
+ })
|
||||||
|
).nullish()
|
||||||
|
}).nullish(),
|
||||||
|
finish_reason: z3.string().nullish()
|
||||||
@@ -1,5 +1,5 @@
|
|||||||
diff --git a/dist/index.js b/dist/index.js
|
diff --git a/dist/index.js b/dist/index.js
|
||||||
index 992c85ac6656e51c3471af741583533c5a7bf79f..83c05952a07aebb95fc6c62f9ddb8aa96b52ac0d 100644
|
index 7481f3b3511078068d87d03855b568b20bb86971..8ac5ec28d2f7ad1b3b0d3f8da945c75674e59637 100644
|
||||||
--- a/dist/index.js
|
--- a/dist/index.js
|
||||||
+++ b/dist/index.js
|
+++ b/dist/index.js
|
||||||
@@ -274,6 +274,7 @@ var openaiChatResponseSchema = (0, import_provider_utils3.lazyValidator)(
|
@@ -274,6 +274,7 @@ var openaiChatResponseSchema = (0, import_provider_utils3.lazyValidator)(
|
||||||
@@ -18,7 +18,7 @@ index 992c85ac6656e51c3471af741583533c5a7bf79f..83c05952a07aebb95fc6c62f9ddb8aa9
|
|||||||
tool_calls: import_v42.z.array(
|
tool_calls: import_v42.z.array(
|
||||||
import_v42.z.object({
|
import_v42.z.object({
|
||||||
index: import_v42.z.number(),
|
index: import_v42.z.number(),
|
||||||
@@ -785,6 +787,13 @@ var OpenAIChatLanguageModel = class {
|
@@ -795,6 +797,13 @@ var OpenAIChatLanguageModel = class {
|
||||||
if (text != null && text.length > 0) {
|
if (text != null && text.length > 0) {
|
||||||
content.push({ type: "text", text });
|
content.push({ type: "text", text });
|
||||||
}
|
}
|
||||||
@@ -32,7 +32,7 @@ index 992c85ac6656e51c3471af741583533c5a7bf79f..83c05952a07aebb95fc6c62f9ddb8aa9
|
|||||||
for (const toolCall of (_a = choice.message.tool_calls) != null ? _a : []) {
|
for (const toolCall of (_a = choice.message.tool_calls) != null ? _a : []) {
|
||||||
content.push({
|
content.push({
|
||||||
type: "tool-call",
|
type: "tool-call",
|
||||||
@@ -866,6 +875,7 @@ var OpenAIChatLanguageModel = class {
|
@@ -876,6 +885,7 @@ var OpenAIChatLanguageModel = class {
|
||||||
};
|
};
|
||||||
let metadataExtracted = false;
|
let metadataExtracted = false;
|
||||||
let isActiveText = false;
|
let isActiveText = false;
|
||||||
@@ -40,7 +40,7 @@ index 992c85ac6656e51c3471af741583533c5a7bf79f..83c05952a07aebb95fc6c62f9ddb8aa9
|
|||||||
const providerMetadata = { openai: {} };
|
const providerMetadata = { openai: {} };
|
||||||
return {
|
return {
|
||||||
stream: response.pipeThrough(
|
stream: response.pipeThrough(
|
||||||
@@ -923,6 +933,21 @@ var OpenAIChatLanguageModel = class {
|
@@ -933,6 +943,21 @@ var OpenAIChatLanguageModel = class {
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
const delta = choice.delta;
|
const delta = choice.delta;
|
||||||
@@ -62,7 +62,7 @@ index 992c85ac6656e51c3471af741583533c5a7bf79f..83c05952a07aebb95fc6c62f9ddb8aa9
|
|||||||
if (delta.content != null) {
|
if (delta.content != null) {
|
||||||
if (!isActiveText) {
|
if (!isActiveText) {
|
||||||
controller.enqueue({ type: "text-start", id: "0" });
|
controller.enqueue({ type: "text-start", id: "0" });
|
||||||
@@ -1035,6 +1060,9 @@ var OpenAIChatLanguageModel = class {
|
@@ -1045,6 +1070,9 @@ var OpenAIChatLanguageModel = class {
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
flush(controller) {
|
flush(controller) {
|
||||||
@@ -1,276 +0,0 @@
|
|||||||
diff --git a/out/macPackager.js b/out/macPackager.js
|
|
||||||
index 852f6c4d16f86a7bb8a78bf1ed5a14647a279aa1..60e7f5f16a844541eb1909b215fcda1811e924b8 100644
|
|
||||||
--- a/out/macPackager.js
|
|
||||||
+++ b/out/macPackager.js
|
|
||||||
@@ -423,7 +423,7 @@ class MacPackager extends platformPackager_1.PlatformPackager {
|
|
||||||
}
|
|
||||||
appPlist.CFBundleName = appInfo.productName;
|
|
||||||
appPlist.CFBundleDisplayName = appInfo.productName;
|
|
||||||
- const minimumSystemVersion = this.platformSpecificBuildOptions.minimumSystemVersion;
|
|
||||||
+ const minimumSystemVersion = this.platformSpecificBuildOptions.LSMinimumSystemVersion;
|
|
||||||
if (minimumSystemVersion != null) {
|
|
||||||
appPlist.LSMinimumSystemVersion = minimumSystemVersion;
|
|
||||||
}
|
|
||||||
diff --git a/out/publish/updateInfoBuilder.js b/out/publish/updateInfoBuilder.js
|
|
||||||
index 7924c5b47d01f8dfccccb8f46658015fa66da1f7..1a1588923c3939ae1297b87931ba83f0ebc052d8 100644
|
|
||||||
--- a/out/publish/updateInfoBuilder.js
|
|
||||||
+++ b/out/publish/updateInfoBuilder.js
|
|
||||||
@@ -133,6 +133,7 @@ async function createUpdateInfo(version, event, releaseInfo) {
|
|
||||||
const customUpdateInfo = event.updateInfo;
|
|
||||||
const url = path.basename(event.file);
|
|
||||||
const sha512 = (customUpdateInfo == null ? null : customUpdateInfo.sha512) || (await (0, hash_1.hashFile)(event.file));
|
|
||||||
+ const minimumSystemVersion = customUpdateInfo == null ? null : customUpdateInfo.minimumSystemVersion;
|
|
||||||
const files = [{ url, sha512 }];
|
|
||||||
const result = {
|
|
||||||
// @ts-ignore
|
|
||||||
@@ -143,9 +144,13 @@ async function createUpdateInfo(version, event, releaseInfo) {
|
|
||||||
path: url /* backward compatibility, electron-updater 1.x - electron-updater 2.15.0 */,
|
|
||||||
// @ts-ignore
|
|
||||||
sha512 /* backward compatibility, electron-updater 1.x - electron-updater 2.15.0 */,
|
|
||||||
+ minimumSystemVersion,
|
|
||||||
...releaseInfo,
|
|
||||||
};
|
|
||||||
if (customUpdateInfo != null) {
|
|
||||||
+ if (customUpdateInfo.minimumSystemVersion) {
|
|
||||||
+ delete customUpdateInfo.minimumSystemVersion;
|
|
||||||
+ }
|
|
||||||
// file info or nsis web installer packages info
|
|
||||||
Object.assign("sha512" in customUpdateInfo ? files[0] : result, customUpdateInfo);
|
|
||||||
}
|
|
||||||
diff --git a/out/targets/ArchiveTarget.js b/out/targets/ArchiveTarget.js
|
|
||||||
index e1f52a5fa86fff6643b2e57eaf2af318d541f865..47cc347f154a24b365e70ae5e1f6d309f3582ed0 100644
|
|
||||||
--- a/out/targets/ArchiveTarget.js
|
|
||||||
+++ b/out/targets/ArchiveTarget.js
|
|
||||||
@@ -69,6 +69,9 @@ class ArchiveTarget extends core_1.Target {
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
+ if (updateInfo != null && this.packager.platformSpecificBuildOptions.minimumSystemVersion) {
|
|
||||||
+ updateInfo.minimumSystemVersion = this.packager.platformSpecificBuildOptions.minimumSystemVersion;
|
|
||||||
+ }
|
|
||||||
await packager.info.emitArtifactBuildCompleted({
|
|
||||||
updateInfo,
|
|
||||||
file: artifactPath,
|
|
||||||
diff --git a/out/targets/nsis/NsisTarget.js b/out/targets/nsis/NsisTarget.js
|
|
||||||
index e8bd7bb46c8a54b3f55cf3a853ef924195271e01..f956e9f3fe9eb903c78aef3502553b01de4b89b1 100644
|
|
||||||
--- a/out/targets/nsis/NsisTarget.js
|
|
||||||
+++ b/out/targets/nsis/NsisTarget.js
|
|
||||||
@@ -305,6 +305,9 @@ class NsisTarget extends core_1.Target {
|
|
||||||
if (updateInfo != null && isPerMachine && (oneClick || options.packElevateHelper)) {
|
|
||||||
updateInfo.isAdminRightsRequired = true;
|
|
||||||
}
|
|
||||||
+ if (updateInfo != null && this.packager.platformSpecificBuildOptions.minimumSystemVersion) {
|
|
||||||
+ updateInfo.minimumSystemVersion = this.packager.platformSpecificBuildOptions.minimumSystemVersion;
|
|
||||||
+ }
|
|
||||||
await packager.info.emitArtifactBuildCompleted({
|
|
||||||
file: installerPath,
|
|
||||||
updateInfo,
|
|
||||||
diff --git a/out/util/yarn.js b/out/util/yarn.js
|
|
||||||
index 1ee20f8b252a8f28d0c7b103789cf0a9a427aec1..c2878ec54d57da50bf14225e0c70c9c88664eb8a 100644
|
|
||||||
--- a/out/util/yarn.js
|
|
||||||
+++ b/out/util/yarn.js
|
|
||||||
@@ -140,6 +140,7 @@ async function rebuild(config, { appDir, projectDir }, options) {
|
|
||||||
arch,
|
|
||||||
platform,
|
|
||||||
buildFromSource,
|
|
||||||
+ ignoreModules: config.excludeReBuildModules || undefined,
|
|
||||||
projectRootPath: projectDir,
|
|
||||||
mode: config.nativeRebuilder || "sequential",
|
|
||||||
disablePreGypCopy: true,
|
|
||||||
diff --git a/scheme.json b/scheme.json
|
|
||||||
index 433e2efc9cef156ff5444f0c4520362ed2ef9ea7..0167441bf928a92f59b5dbe70b2317a74dda74c9 100644
|
|
||||||
--- a/scheme.json
|
|
||||||
+++ b/scheme.json
|
|
||||||
@@ -1825,6 +1825,20 @@
|
|
||||||
"string"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
+ "excludeReBuildModules": {
|
|
||||||
+ "anyOf": [
|
|
||||||
+ {
|
|
||||||
+ "items": {
|
|
||||||
+ "type": "string"
|
|
||||||
+ },
|
|
||||||
+ "type": "array"
|
|
||||||
+ },
|
|
||||||
+ {
|
|
||||||
+ "type": "null"
|
|
||||||
+ }
|
|
||||||
+ ],
|
|
||||||
+ "description": "The modules to exclude from the rebuild."
|
|
||||||
+ },
|
|
||||||
"executableArgs": {
|
|
||||||
"anyOf": [
|
|
||||||
{
|
|
||||||
@@ -1975,6 +1989,13 @@
|
|
||||||
],
|
|
||||||
"description": "The mime types in addition to specified in the file associations. Use it if you don't want to register a new mime type, but reuse existing."
|
|
||||||
},
|
|
||||||
+ "minimumSystemVersion": {
|
|
||||||
+ "description": "The minimum os kernel version required to install the application.",
|
|
||||||
+ "type": [
|
|
||||||
+ "null",
|
|
||||||
+ "string"
|
|
||||||
+ ]
|
|
||||||
+ },
|
|
||||||
"packageCategory": {
|
|
||||||
"description": "backward compatibility + to allow specify fpm-only category for all possible fpm targets in one place",
|
|
||||||
"type": [
|
|
||||||
@@ -2327,6 +2348,13 @@
|
|
||||||
"MacConfiguration": {
|
|
||||||
"additionalProperties": false,
|
|
||||||
"properties": {
|
|
||||||
+ "LSMinimumSystemVersion": {
|
|
||||||
+ "description": "The minimum version of macOS required for the app to run. Corresponds to `LSMinimumSystemVersion`.",
|
|
||||||
+ "type": [
|
|
||||||
+ "null",
|
|
||||||
+ "string"
|
|
||||||
+ ]
|
|
||||||
+ },
|
|
||||||
"additionalArguments": {
|
|
||||||
"anyOf": [
|
|
||||||
{
|
|
||||||
@@ -2527,6 +2555,20 @@
|
|
||||||
"string"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
+ "excludeReBuildModules": {
|
|
||||||
+ "anyOf": [
|
|
||||||
+ {
|
|
||||||
+ "items": {
|
|
||||||
+ "type": "string"
|
|
||||||
+ },
|
|
||||||
+ "type": "array"
|
|
||||||
+ },
|
|
||||||
+ {
|
|
||||||
+ "type": "null"
|
|
||||||
+ }
|
|
||||||
+ ],
|
|
||||||
+ "description": "The modules to exclude from the rebuild."
|
|
||||||
+ },
|
|
||||||
"executableName": {
|
|
||||||
"description": "The executable name. Defaults to `productName`.",
|
|
||||||
"type": [
|
|
||||||
@@ -2737,7 +2779,7 @@
|
|
||||||
"type": "boolean"
|
|
||||||
},
|
|
||||||
"minimumSystemVersion": {
|
|
||||||
- "description": "The minimum version of macOS required for the app to run. Corresponds to `LSMinimumSystemVersion`.",
|
|
||||||
+ "description": "The minimum os kernel version required to install the application.",
|
|
||||||
"type": [
|
|
||||||
"null",
|
|
||||||
"string"
|
|
||||||
@@ -2959,6 +3001,13 @@
|
|
||||||
"MasConfiguration": {
|
|
||||||
"additionalProperties": false,
|
|
||||||
"properties": {
|
|
||||||
+ "LSMinimumSystemVersion": {
|
|
||||||
+ "description": "The minimum version of macOS required for the app to run. Corresponds to `LSMinimumSystemVersion`.",
|
|
||||||
+ "type": [
|
|
||||||
+ "null",
|
|
||||||
+ "string"
|
|
||||||
+ ]
|
|
||||||
+ },
|
|
||||||
"additionalArguments": {
|
|
||||||
"anyOf": [
|
|
||||||
{
|
|
||||||
@@ -3159,6 +3208,20 @@
|
|
||||||
"string"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
+ "excludeReBuildModules": {
|
|
||||||
+ "anyOf": [
|
|
||||||
+ {
|
|
||||||
+ "items": {
|
|
||||||
+ "type": "string"
|
|
||||||
+ },
|
|
||||||
+ "type": "array"
|
|
||||||
+ },
|
|
||||||
+ {
|
|
||||||
+ "type": "null"
|
|
||||||
+ }
|
|
||||||
+ ],
|
|
||||||
+ "description": "The modules to exclude from the rebuild."
|
|
||||||
+ },
|
|
||||||
"executableName": {
|
|
||||||
"description": "The executable name. Defaults to `productName`.",
|
|
||||||
"type": [
|
|
||||||
@@ -3369,7 +3432,7 @@
|
|
||||||
"type": "boolean"
|
|
||||||
},
|
|
||||||
"minimumSystemVersion": {
|
|
||||||
- "description": "The minimum version of macOS required for the app to run. Corresponds to `LSMinimumSystemVersion`.",
|
|
||||||
+ "description": "The minimum os kernel version required to install the application.",
|
|
||||||
"type": [
|
|
||||||
"null",
|
|
||||||
"string"
|
|
||||||
@@ -6381,6 +6444,20 @@
|
|
||||||
"string"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
+ "excludeReBuildModules": {
|
|
||||||
+ "anyOf": [
|
|
||||||
+ {
|
|
||||||
+ "items": {
|
|
||||||
+ "type": "string"
|
|
||||||
+ },
|
|
||||||
+ "type": "array"
|
|
||||||
+ },
|
|
||||||
+ {
|
|
||||||
+ "type": "null"
|
|
||||||
+ }
|
|
||||||
+ ],
|
|
||||||
+ "description": "The modules to exclude from the rebuild."
|
|
||||||
+ },
|
|
||||||
"executableName": {
|
|
||||||
"description": "The executable name. Defaults to `productName`.",
|
|
||||||
"type": [
|
|
||||||
@@ -6507,6 +6584,13 @@
|
|
||||||
"string"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
+ "minimumSystemVersion": {
|
|
||||||
+ "description": "The minimum os kernel version required to install the application.",
|
|
||||||
+ "type": [
|
|
||||||
+ "null",
|
|
||||||
+ "string"
|
|
||||||
+ ]
|
|
||||||
+ },
|
|
||||||
"protocols": {
|
|
||||||
"anyOf": [
|
|
||||||
{
|
|
||||||
@@ -7153,6 +7237,20 @@
|
|
||||||
"string"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
+ "excludeReBuildModules": {
|
|
||||||
+ "anyOf": [
|
|
||||||
+ {
|
|
||||||
+ "items": {
|
|
||||||
+ "type": "string"
|
|
||||||
+ },
|
|
||||||
+ "type": "array"
|
|
||||||
+ },
|
|
||||||
+ {
|
|
||||||
+ "type": "null"
|
|
||||||
+ }
|
|
||||||
+ ],
|
|
||||||
+ "description": "The modules to exclude from the rebuild."
|
|
||||||
+ },
|
|
||||||
"executableName": {
|
|
||||||
"description": "The executable name. Defaults to `productName`.",
|
|
||||||
"type": [
|
|
||||||
@@ -7376,6 +7474,13 @@
|
|
||||||
],
|
|
||||||
"description": "MAS (Mac Application Store) development options (`mas-dev` target)."
|
|
||||||
},
|
|
||||||
+ "minimumSystemVersion": {
|
|
||||||
+ "description": "The minimum os kernel version required to install the application.",
|
|
||||||
+ "type": [
|
|
||||||
+ "null",
|
|
||||||
+ "string"
|
|
||||||
+ ]
|
|
||||||
+ },
|
|
||||||
"msi": {
|
|
||||||
"anyOf": [
|
|
||||||
{
|
|
||||||
14
.yarn/patches/electron-updater-npm-6.7.0-47b11bb0d4.patch
vendored
Normal file
14
.yarn/patches/electron-updater-npm-6.7.0-47b11bb0d4.patch
vendored
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
diff --git a/out/util.js b/out/util.js
|
||||||
|
index 9294ffd6ca8f02c2e0f90c663e7e9cdc02c1ac37..f52107493e2995320ee4efd0eb2a8c9bf03291a2 100644
|
||||||
|
--- a/out/util.js
|
||||||
|
+++ b/out/util.js
|
||||||
|
@@ -23,7 +23,8 @@ function newUrlFromBase(pathname, baseUrl, addRandomQueryToAvoidCaching = false)
|
||||||
|
result.search = search;
|
||||||
|
}
|
||||||
|
else if (addRandomQueryToAvoidCaching) {
|
||||||
|
- result.search = `noCache=${Date.now().toString(32)}`;
|
||||||
|
+ // use no cache header instead
|
||||||
|
+ // result.search = `noCache=${Date.now().toString(32)}`;
|
||||||
|
}
|
||||||
|
return result;
|
||||||
|
}
|
||||||
49
app-upgrade-config.json
Normal file
49
app-upgrade-config.json
Normal file
@@ -0,0 +1,49 @@
|
|||||||
|
{
|
||||||
|
"lastUpdated": "2025-11-10T08:14:28Z",
|
||||||
|
"versions": {
|
||||||
|
"1.6.7": {
|
||||||
|
"metadata": {
|
||||||
|
"segmentId": "legacy-v1",
|
||||||
|
"segmentType": "legacy"
|
||||||
|
},
|
||||||
|
"minCompatibleVersion": "1.0.0",
|
||||||
|
"description": "Last stable v1.7.x release - required intermediate version for users below v1.7",
|
||||||
|
"channels": {
|
||||||
|
"latest": {
|
||||||
|
"version": "1.6.7",
|
||||||
|
"feedUrls": {
|
||||||
|
"github": "https://github.com/CherryHQ/cherry-studio/releases/download/v1.6.7",
|
||||||
|
"gitcode": "https://releases.cherry-ai.com"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"rc": {
|
||||||
|
"version": "1.6.0-rc.5",
|
||||||
|
"feedUrls": {
|
||||||
|
"github": "https://github.com/CherryHQ/cherry-studio/releases/download/v1.6.0-rc.5",
|
||||||
|
"gitcode": "https://github.com/CherryHQ/cherry-studio/releases/download/v1.6.0-rc.5"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"beta": {
|
||||||
|
"version": "1.7.0-beta.3",
|
||||||
|
"feedUrls": {
|
||||||
|
"github": "https://github.com/CherryHQ/cherry-studio/releases/download/v1.7.0-beta.3",
|
||||||
|
"gitcode": "https://github.com/CherryHQ/cherry-studio/releases/download/v1.7.0-beta.3"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"2.0.0": {
|
||||||
|
"metadata": {
|
||||||
|
"segmentId": "gateway-v2",
|
||||||
|
"segmentType": "breaking"
|
||||||
|
},
|
||||||
|
"minCompatibleVersion": "1.7.0",
|
||||||
|
"description": "Major release v2.0 - required intermediate version for v2.x upgrades",
|
||||||
|
"channels": {
|
||||||
|
"latest": null,
|
||||||
|
"rc": null,
|
||||||
|
"beta": null
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -14,7 +14,7 @@
|
|||||||
}
|
}
|
||||||
},
|
},
|
||||||
"enabled": true,
|
"enabled": true,
|
||||||
"includes": ["**/*.json", "!*.json", "!**/package.json"]
|
"includes": ["**/*.json", "!*.json", "!**/package.json", "!coverage/**"]
|
||||||
},
|
},
|
||||||
"css": {
|
"css": {
|
||||||
"formatter": {
|
"formatter": {
|
||||||
@@ -23,7 +23,7 @@
|
|||||||
},
|
},
|
||||||
"files": {
|
"files": {
|
||||||
"ignoreUnknown": false,
|
"ignoreUnknown": false,
|
||||||
"includes": ["**", "!**/.claude/**"],
|
"includes": ["**", "!**/.claude/**", "!**/.vscode/**"],
|
||||||
"maxSize": 2097152
|
"maxSize": 2097152
|
||||||
},
|
},
|
||||||
"formatter": {
|
"formatter": {
|
||||||
|
|||||||
81
config/app-upgrade-segments.json
Normal file
81
config/app-upgrade-segments.json
Normal file
@@ -0,0 +1,81 @@
|
|||||||
|
{
|
||||||
|
"segments": [
|
||||||
|
{
|
||||||
|
"id": "legacy-v1",
|
||||||
|
"type": "legacy",
|
||||||
|
"match": {
|
||||||
|
"range": ">=1.0.0 <2.0.0"
|
||||||
|
},
|
||||||
|
"minCompatibleVersion": "1.0.0",
|
||||||
|
"description": "Last stable v1.7.x release - required intermediate version for users below v1.7",
|
||||||
|
"channelTemplates": {
|
||||||
|
"latest": {
|
||||||
|
"feedTemplates": {
|
||||||
|
"github": "https://github.com/CherryHQ/cherry-studio/releases/download/{{tag}}",
|
||||||
|
"gitcode": "https://releases.cherry-ai.com"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"rc": {
|
||||||
|
"feedTemplates": {
|
||||||
|
"github": "https://github.com/CherryHQ/cherry-studio/releases/download/{{tag}}",
|
||||||
|
"gitcode": "https://github.com/CherryHQ/cherry-studio/releases/download/{{tag}}"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"beta": {
|
||||||
|
"feedTemplates": {
|
||||||
|
"github": "https://github.com/CherryHQ/cherry-studio/releases/download/{{tag}}",
|
||||||
|
"gitcode": "https://github.com/CherryHQ/cherry-studio/releases/download/{{tag}}"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "gateway-v2",
|
||||||
|
"type": "breaking",
|
||||||
|
"match": {
|
||||||
|
"exact": ["2.0.0"]
|
||||||
|
},
|
||||||
|
"lockedVersion": "2.0.0",
|
||||||
|
"minCompatibleVersion": "1.7.0",
|
||||||
|
"description": "Major release v2.0 - required intermediate version for v2.x upgrades",
|
||||||
|
"channelTemplates": {
|
||||||
|
"latest": {
|
||||||
|
"feedTemplates": {
|
||||||
|
"github": "https://github.com/CherryHQ/cherry-studio/releases/download/{{tag}}",
|
||||||
|
"gitcode": "https://gitcode.com/CherryHQ/cherry-studio/releases/download/{{tag}}"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "current-v2",
|
||||||
|
"type": "latest",
|
||||||
|
"match": {
|
||||||
|
"range": ">=2.0.0 <3.0.0",
|
||||||
|
"excludeExact": ["2.0.0"]
|
||||||
|
},
|
||||||
|
"minCompatibleVersion": "2.0.0",
|
||||||
|
"description": "Current latest v2.x release",
|
||||||
|
"channelTemplates": {
|
||||||
|
"latest": {
|
||||||
|
"feedTemplates": {
|
||||||
|
"github": "https://github.com/CherryHQ/cherry-studio/releases/download/{{tag}}",
|
||||||
|
"gitcode": "https://gitcode.com/CherryHQ/cherry-studio/releases/download/{{tag}}"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"rc": {
|
||||||
|
"feedTemplates": {
|
||||||
|
"github": "https://github.com/CherryHQ/cherry-studio/releases/download/{{tag}}",
|
||||||
|
"gitcode": "https://gitcode.com/CherryHQ/cherry-studio/releases/download/{{tag}}"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"beta": {
|
||||||
|
"feedTemplates": {
|
||||||
|
"github": "https://github.com/CherryHQ/cherry-studio/releases/download/{{tag}}",
|
||||||
|
"gitcode": "https://gitcode.com/CherryHQ/cherry-studio/releases/download/{{tag}}"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
430
docs/technical/app-upgrade-config-en.md
Normal file
430
docs/technical/app-upgrade-config-en.md
Normal file
@@ -0,0 +1,430 @@
|
|||||||
|
# Update Configuration System Design Document
|
||||||
|
|
||||||
|
## Background
|
||||||
|
|
||||||
|
Currently, AppUpdater directly queries the GitHub API to retrieve beta and rc update information. To support users in China, we need to fetch a static JSON configuration file from GitHub/GitCode based on IP geolocation, which contains update URLs for all channels.
|
||||||
|
|
||||||
|
## Design Goals
|
||||||
|
|
||||||
|
1. Support different configuration sources based on IP geolocation (GitHub/GitCode)
|
||||||
|
2. Support version compatibility control (e.g., users below v1.x must upgrade to v1.7.0 before upgrading to v2.0)
|
||||||
|
3. Easy to extend, supporting future multi-major-version upgrade paths (v1.6 → v1.7 → v2.0 → v2.8 → v3.0)
|
||||||
|
4. Maintain compatibility with existing electron-updater mechanism
|
||||||
|
|
||||||
|
## Current Version Strategy
|
||||||
|
|
||||||
|
- **v1.7.x** is the last version of the 1.x series
|
||||||
|
- Users **below v1.7.0** must first upgrade to v1.7.0 (or higher 1.7.x version)
|
||||||
|
- Users **v1.7.0 and above** can directly upgrade to v2.x.x
|
||||||
|
|
||||||
|
## Automation Workflow
|
||||||
|
|
||||||
|
The `x-files/app-upgrade-config/app-upgrade-config.json` file is synchronized by the [`Update App Upgrade Config`](../../.github/workflows/update-app-upgrade-config.yml) workflow. The workflow runs the [`scripts/update-app-upgrade-config.ts`](../../scripts/update-app-upgrade-config.ts) helper so that every release tag automatically updates the JSON in `x-files/app-upgrade-config`.
|
||||||
|
|
||||||
|
### Trigger Conditions
|
||||||
|
|
||||||
|
- **Release events (`release: released/prereleased`)**
|
||||||
|
- Draft releases are ignored.
|
||||||
|
- When GitHub marks the release as _prerelease_, the tag must include `-beta`/`-rc` (with optional numeric suffix). Otherwise the workflow exits early.
|
||||||
|
- When GitHub marks the release as stable, the tag must match the latest release returned by the GitHub API. This prevents out-of-order updates when publishing historical tags.
|
||||||
|
- If the guard clauses pass, the version is tagged as `latest` or `beta/rc` based on its semantic suffix and propagated to the script through the `IS_PRERELEASE` flag.
|
||||||
|
- **Manual dispatch (`workflow_dispatch`)**
|
||||||
|
- Required input: `tag` (e.g., `v2.0.1`). Optional input: `is_prerelease` (defaults to `false`).
|
||||||
|
- When `is_prerelease=true`, the tag must carry a beta/rc suffix, mirroring the automatic validation.
|
||||||
|
- Manual runs still download the latest release metadata so that the workflow knows whether the tag represents the newest stable version (for documentation inside the PR body).
|
||||||
|
|
||||||
|
### Workflow Steps
|
||||||
|
|
||||||
|
1. **Guard + metadata preparation** – the `Check if should proceed` and `Prepare metadata` steps compute the target tag, prerelease flag, whether the tag is the newest release, and a `safe_tag` slug used for branch names. When any rule fails, the workflow stops without touching the config.
|
||||||
|
2. **Checkout source branches** – the default branch is checked out into `main/`, while the long-lived `x-files/app-upgrade-config` branch lives in `cs/`. All modifications happen in the latter directory.
|
||||||
|
3. **Install toolchain** – Node.js 22, Corepack, and frozen Yarn dependencies are installed inside `main/`.
|
||||||
|
4. **Run the update script** – `yarn tsx scripts/update-app-upgrade-config.ts --tag <tag> --config ../cs/app-upgrade-config.json --is-prerelease <flag>` updates the JSON in-place.
|
||||||
|
- The script normalizes the tag (e.g., strips `v` prefix), detects the release channel (`latest`, `rc`, `beta`), and loads segment rules from `config/app-upgrade-segments.json`.
|
||||||
|
- It validates that prerelease flags and semantic suffixes agree, enforces locked segments, builds mirror feed URLs, and performs release-availability checks (GitHub HEAD request for every channel; GitCode GET for latest channels, falling back to `https://releases.cherry-ai.com` when gitcode is delayed).
|
||||||
|
- After updating the relevant channel entry, the script rewrites the config with semver-sort order and a new `lastUpdated` timestamp.
|
||||||
|
5. **Detect changes + create PR** – if `cs/app-upgrade-config.json` changed, the workflow opens a PR `chore/update-app-upgrade-config/<safe_tag>` against `x-files/app-upgrade-config` with a commit message `🤖 chore: sync app-upgrade-config for <tag>`. Otherwise it logs that no update is required.
|
||||||
|
|
||||||
|
### Manual Trigger Guide
|
||||||
|
|
||||||
|
1. Open the Cherry Studio repository on GitHub → **Actions** tab → select **Update App Upgrade Config**.
|
||||||
|
2. Click **Run workflow**, choose the default branch (usually `main`), and fill in the `tag` input (e.g., `v2.1.0`).
|
||||||
|
3. Toggle `is_prerelease` only when the tag carries a prerelease suffix (`-beta`, `-rc`). Leave it unchecked for stable releases.
|
||||||
|
4. Start the run and wait for it to finish. Check the generated PR in the `x-files/app-upgrade-config` branch, verify the diff in `app-upgrade-config.json`, and merge once validated.
|
||||||
|
|
||||||
|
## JSON Configuration File Format
|
||||||
|
|
||||||
|
### File Location
|
||||||
|
|
||||||
|
- **GitHub**: `https://raw.githubusercontent.com/CherryHQ/cherry-studio/refs/heads/x-files/app-upgrade-config/app-upgrade-config.json`
|
||||||
|
- **GitCode**: `https://gitcode.com/CherryHQ/cherry-studio/raw/x-files/app-upgrade-config/app-upgrade-config.json`
|
||||||
|
|
||||||
|
**Note**: Both mirrors provide the same configuration file hosted on the `x-files/app-upgrade-config` branch. The client automatically selects the optimal mirror based on IP geolocation.
|
||||||
|
|
||||||
|
### Configuration Structure (Current Implementation)
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"lastUpdated": "2025-01-05T00:00:00Z",
|
||||||
|
"versions": {
|
||||||
|
"1.6.7": {
|
||||||
|
"minCompatibleVersion": "1.0.0",
|
||||||
|
"description": "Last stable v1.7.x release - required intermediate version for users below v1.7",
|
||||||
|
"channels": {
|
||||||
|
"latest": {
|
||||||
|
"version": "1.6.7",
|
||||||
|
"feedUrls": {
|
||||||
|
"github": "https://github.com/CherryHQ/cherry-studio/releases/download/v1.6.7",
|
||||||
|
"gitcode": "https://gitcode.com/CherryHQ/cherry-studio/releases/download/v1.6.7"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"rc": {
|
||||||
|
"version": "1.6.0-rc.5",
|
||||||
|
"feedUrls": {
|
||||||
|
"github": "https://github.com/CherryHQ/cherry-studio/releases/download/v1.6.0-rc.5",
|
||||||
|
"gitcode": "https://github.com/CherryHQ/cherry-studio/releases/download/v1.6.0-rc.5"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"beta": {
|
||||||
|
"version": "1.6.7-beta.3",
|
||||||
|
"feedUrls": {
|
||||||
|
"github": "https://github.com/CherryHQ/cherry-studio/releases/download/v1.7.0-beta.3",
|
||||||
|
"gitcode": "https://github.com/CherryHQ/cherry-studio/releases/download/v1.7.0-beta.3"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"2.0.0": {
|
||||||
|
"minCompatibleVersion": "1.7.0",
|
||||||
|
"description": "Major release v2.0 - required intermediate version for v2.x upgrades",
|
||||||
|
"channels": {
|
||||||
|
"latest": null,
|
||||||
|
"rc": null,
|
||||||
|
"beta": null
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### Future Extension Example
|
||||||
|
|
||||||
|
When releasing v3.0, if users need to first upgrade to v2.8, you can add:
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"2.8.0": {
|
||||||
|
"minCompatibleVersion": "2.0.0",
|
||||||
|
"description": "Stable v2.8 - required for v3 upgrade",
|
||||||
|
"channels": {
|
||||||
|
"latest": {
|
||||||
|
"version": "2.8.0",
|
||||||
|
"feedUrls": {
|
||||||
|
"github": "https://github.com/CherryHQ/cherry-studio/releases/download/v2.8.0",
|
||||||
|
"gitcode": "https://gitcode.com/CherryHQ/cherry-studio/releases/download/v2.8.0"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"rc": null,
|
||||||
|
"beta": null
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"3.0.0": {
|
||||||
|
"minCompatibleVersion": "2.8.0",
|
||||||
|
"description": "Major release v3.0",
|
||||||
|
"channels": {
|
||||||
|
"latest": {
|
||||||
|
"version": "3.0.0",
|
||||||
|
"feedUrls": {
|
||||||
|
"github": "https://github.com/CherryHQ/cherry-studio/releases/latest",
|
||||||
|
"gitcode": "https://gitcode.com/CherryHQ/cherry-studio/releases/latest"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"rc": {
|
||||||
|
"version": "3.0.0-rc.1",
|
||||||
|
"feedUrls": {
|
||||||
|
"github": "https://github.com/CherryHQ/cherry-studio/releases/download/v3.0.0-rc.1",
|
||||||
|
"gitcode": "https://gitcode.com/CherryHQ/cherry-studio/releases/download/v3.0.0-rc.1"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"beta": null
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### Field Descriptions
|
||||||
|
|
||||||
|
- `lastUpdated`: Last update time of the configuration file (ISO 8601 format)
|
||||||
|
- `versions`: Version configuration object, key is the version number, sorted by semantic versioning
|
||||||
|
- `minCompatibleVersion`: Minimum compatible version that can upgrade to this version
|
||||||
|
- `description`: Version description
|
||||||
|
- `channels`: Update channel configuration
|
||||||
|
- `latest`: Stable release channel
|
||||||
|
- `rc`: Release Candidate channel
|
||||||
|
- `beta`: Beta testing channel
|
||||||
|
- Each channel contains:
|
||||||
|
- `version`: Version number for this channel
|
||||||
|
- `feedUrls`: Multi-mirror URL configuration
|
||||||
|
- `github`: electron-updater feed URL for GitHub mirror
|
||||||
|
- `gitcode`: electron-updater feed URL for GitCode mirror
|
||||||
|
- `metadata`: Stable mapping info for automation
|
||||||
|
- `segmentId`: ID from `config/app-upgrade-segments.json`
|
||||||
|
- `segmentType`: Optional flag (`legacy` | `breaking` | `latest`) for documentation/debugging
|
||||||
|
|
||||||
|
## TypeScript Type Definitions
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
// Mirror enum
|
||||||
|
enum UpdateMirror {
|
||||||
|
GITHUB = 'github',
|
||||||
|
GITCODE = 'gitcode'
|
||||||
|
}
|
||||||
|
|
||||||
|
interface UpdateConfig {
|
||||||
|
lastUpdated: string
|
||||||
|
versions: {
|
||||||
|
[versionKey: string]: VersionConfig
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
interface VersionConfig {
|
||||||
|
minCompatibleVersion: string
|
||||||
|
description: string
|
||||||
|
channels: {
|
||||||
|
latest: ChannelConfig | null
|
||||||
|
rc: ChannelConfig | null
|
||||||
|
beta: ChannelConfig | null
|
||||||
|
}
|
||||||
|
metadata?: {
|
||||||
|
segmentId: string
|
||||||
|
segmentType?: 'legacy' | 'breaking' | 'latest'
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
interface ChannelConfig {
|
||||||
|
version: string
|
||||||
|
feedUrls: Record<UpdateMirror, string>
|
||||||
|
// Equivalent to:
|
||||||
|
// feedUrls: {
|
||||||
|
// github: string
|
||||||
|
// gitcode: string
|
||||||
|
// }
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## Segment Metadata & Breaking Markers
|
||||||
|
|
||||||
|
- **Segment definitions** now live in `config/app-upgrade-segments.json`. Each segment describes a semantic-version range (or exact matches) plus metadata such as `segmentId`, `segmentType`, `minCompatibleVersion`, and per-channel feed URL templates.
|
||||||
|
- Each entry under `versions` carries a `metadata.segmentId`. This acts as the stable key that scripts use to decide which slot to update, even if the actual semantic version string changes.
|
||||||
|
- Mark major upgrade gateways (e.g., `2.0.0`) by giving the related segment a `segmentType: "breaking"` and (optionally) `lockedVersion`. This prevents automation from accidentally moving that entry when other 2.x builds ship.
|
||||||
|
- Adding another breaking hop (e.g., `3.0.0`) only requires defining a new segment in the JSON file; the automation will pick it up on the next run.
|
||||||
|
|
||||||
|
## Automation Workflow
|
||||||
|
|
||||||
|
Starting from this change, `.github/workflows/update-app-upgrade-config.yml` listens to GitHub release events (published + prerelease). The workflow:
|
||||||
|
|
||||||
|
1. Checks out the default branch (for scripts) and the `x-files/app-upgrade-config` branch (where the config is hosted).
|
||||||
|
2. Runs `yarn tsx scripts/update-app-upgrade-config.ts --tag <tag> --config ../cs/app-upgrade-config.json` to regenerate the config directly inside the `x-files/app-upgrade-config` working tree.
|
||||||
|
3. If the file changed, it opens a PR against `x-files/app-upgrade-config` via `peter-evans/create-pull-request`, with the generated diff limited to `app-upgrade-config.json`.
|
||||||
|
|
||||||
|
You can run the same script locally via `yarn update:upgrade-config --tag v2.1.6 --config ../cs/app-upgrade-config.json` (add `--dry-run` to preview) to reproduce or debug whatever the workflow does. Passing `--skip-release-checks` along with `--dry-run` lets you bypass the release-page existence check (useful when the GitHub/GitCode pages aren’t published yet). Running without `--config` continues to update the copy in your current working directory (main branch) for documentation purposes.
|
||||||
|
|
||||||
|
## Version Matching Logic
|
||||||
|
|
||||||
|
### Algorithm Flow
|
||||||
|
|
||||||
|
1. Get user's current version (`currentVersion`) and requested channel (`requestedChannel`)
|
||||||
|
2. Get all version numbers from configuration file, sort in descending order by semantic versioning
|
||||||
|
3. Iterate through the sorted version list:
|
||||||
|
- Check if `currentVersion >= minCompatibleVersion`
|
||||||
|
- Check if the requested `channel` exists and is not `null`
|
||||||
|
- If conditions are met, return the channel configuration
|
||||||
|
4. If no matching version is found, return `null`
|
||||||
|
|
||||||
|
### Pseudocode Implementation
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
function findCompatibleVersion(
|
||||||
|
currentVersion: string,
|
||||||
|
requestedChannel: UpgradeChannel,
|
||||||
|
config: UpdateConfig
|
||||||
|
): ChannelConfig | null {
|
||||||
|
// Get all version numbers and sort in descending order
|
||||||
|
const versions = Object.keys(config.versions).sort(semver.rcompare)
|
||||||
|
|
||||||
|
for (const versionKey of versions) {
|
||||||
|
const versionConfig = config.versions[versionKey]
|
||||||
|
const channelConfig = versionConfig.channels[requestedChannel]
|
||||||
|
|
||||||
|
// Check version compatibility and channel availability
|
||||||
|
if (
|
||||||
|
semver.gte(currentVersion, versionConfig.minCompatibleVersion) &&
|
||||||
|
channelConfig !== null
|
||||||
|
) {
|
||||||
|
return channelConfig
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return null // No compatible version found
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## Upgrade Path Examples
|
||||||
|
|
||||||
|
### Scenario 1: v1.6.5 User Upgrade (Below 1.7)
|
||||||
|
|
||||||
|
- **Current Version**: 1.6.5
|
||||||
|
- **Requested Channel**: latest
|
||||||
|
- **Match Result**: 1.7.0
|
||||||
|
- **Reason**: 1.6.5 >= 0.0.0 (satisfies 1.7.0's minCompatibleVersion), but doesn't satisfy 2.0.0's minCompatibleVersion (1.7.0)
|
||||||
|
- **Action**: Prompt user to upgrade to 1.7.0, which is the required intermediate version for v2.x upgrade
|
||||||
|
|
||||||
|
### Scenario 2: v1.6.5 User Requests rc/beta
|
||||||
|
|
||||||
|
- **Current Version**: 1.6.5
|
||||||
|
- **Requested Channel**: rc or beta
|
||||||
|
- **Match Result**: 1.7.0 (latest)
|
||||||
|
- **Reason**: 1.7.0 version doesn't provide rc/beta channels (values are null)
|
||||||
|
- **Action**: Upgrade to 1.7.0 stable version
|
||||||
|
|
||||||
|
### Scenario 3: v1.7.0 User Upgrades to Latest
|
||||||
|
|
||||||
|
- **Current Version**: 1.7.0
|
||||||
|
- **Requested Channel**: latest
|
||||||
|
- **Match Result**: 2.0.0
|
||||||
|
- **Reason**: 1.7.0 >= 1.7.0 (satisfies 2.0.0's minCompatibleVersion)
|
||||||
|
- **Action**: Directly upgrade to 2.0.0 (current latest stable version)
|
||||||
|
|
||||||
|
### Scenario 4: v1.7.2 User Upgrades to RC Version
|
||||||
|
|
||||||
|
- **Current Version**: 1.7.2
|
||||||
|
- **Requested Channel**: rc
|
||||||
|
- **Match Result**: 2.0.0-rc.1
|
||||||
|
- **Reason**: 1.7.2 >= 1.7.0 (satisfies 2.0.0's minCompatibleVersion), and rc channel exists
|
||||||
|
- **Action**: Upgrade to 2.0.0-rc.1
|
||||||
|
|
||||||
|
### Scenario 5: v1.7.0 User Upgrades to Beta Version
|
||||||
|
|
||||||
|
- **Current Version**: 1.7.0
|
||||||
|
- **Requested Channel**: beta
|
||||||
|
- **Match Result**: 2.0.0-beta.1
|
||||||
|
- **Reason**: 1.7.0 >= 1.7.0, and beta channel exists
|
||||||
|
- **Action**: Upgrade to 2.0.0-beta.1
|
||||||
|
|
||||||
|
### Scenario 6: v2.5.0 User Upgrade (Future)
|
||||||
|
|
||||||
|
Assuming v2.8.0 and v3.0.0 configurations have been added:
|
||||||
|
- **Current Version**: 2.5.0
|
||||||
|
- **Requested Channel**: latest
|
||||||
|
- **Match Result**: 2.8.0
|
||||||
|
- **Reason**: 2.5.0 >= 2.0.0 (satisfies 2.8.0's minCompatibleVersion), but doesn't satisfy 3.0.0's requirement
|
||||||
|
- **Action**: Prompt user to upgrade to 2.8.0, which is the required intermediate version for v3.x upgrade
|
||||||
|
|
||||||
|
## Code Changes
|
||||||
|
|
||||||
|
### Main Modifications
|
||||||
|
|
||||||
|
1. **New Methods**
|
||||||
|
- `_fetchUpdateConfig(ipCountry: string): Promise<UpdateConfig | null>` - Fetch configuration file based on IP
|
||||||
|
- `_findCompatibleChannel(currentVersion: string, channel: UpgradeChannel, config: UpdateConfig): ChannelConfig | null` - Find compatible channel configuration
|
||||||
|
|
||||||
|
2. **Modified Methods**
|
||||||
|
- `_getReleaseVersionFromGithub()` → Remove or refactor to `_getChannelFeedUrl()`
|
||||||
|
- `_setFeedUrl()` - Use new configuration system to replace existing logic
|
||||||
|
|
||||||
|
3. **New Type Definitions**
|
||||||
|
- `UpdateConfig`
|
||||||
|
- `VersionConfig`
|
||||||
|
- `ChannelConfig`
|
||||||
|
|
||||||
|
### Mirror Selection Logic
|
||||||
|
|
||||||
|
The client automatically selects the optimal mirror based on IP geolocation:
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
private async _setFeedUrl() {
|
||||||
|
const currentVersion = app.getVersion()
|
||||||
|
const testPlan = configManager.getTestPlan()
|
||||||
|
const requestedChannel = testPlan ? this._getTestChannel() : UpgradeChannel.LATEST
|
||||||
|
|
||||||
|
// Determine mirror based on IP country
|
||||||
|
const ipCountry = await getIpCountry()
|
||||||
|
const mirror = ipCountry.toLowerCase() === 'cn' ? 'gitcode' : 'github'
|
||||||
|
|
||||||
|
// Fetch update config
|
||||||
|
const config = await this._fetchUpdateConfig(mirror)
|
||||||
|
|
||||||
|
if (config) {
|
||||||
|
const channelConfig = this._findCompatibleChannel(currentVersion, requestedChannel, config)
|
||||||
|
if (channelConfig) {
|
||||||
|
// Select feed URL from the corresponding mirror
|
||||||
|
const feedUrl = channelConfig.feedUrls[mirror]
|
||||||
|
this._setChannel(requestedChannel, feedUrl)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fallback logic
|
||||||
|
const defaultFeedUrl = mirror === 'gitcode'
|
||||||
|
? FeedUrl.PRODUCTION
|
||||||
|
: FeedUrl.GITHUB_LATEST
|
||||||
|
this._setChannel(UpgradeChannel.LATEST, defaultFeedUrl)
|
||||||
|
}
|
||||||
|
|
||||||
|
private async _fetchUpdateConfig(mirror: 'github' | 'gitcode'): Promise<UpdateConfig | null> {
|
||||||
|
const configUrl = mirror === 'gitcode'
|
||||||
|
? UpdateConfigUrl.GITCODE
|
||||||
|
: UpdateConfigUrl.GITHUB
|
||||||
|
|
||||||
|
try {
|
||||||
|
const response = await net.fetch(configUrl, {
|
||||||
|
headers: {
|
||||||
|
'User-Agent': generateUserAgent(),
|
||||||
|
'Accept': 'application/json',
|
||||||
|
'X-Client-Id': configManager.getClientId()
|
||||||
|
}
|
||||||
|
})
|
||||||
|
return await response.json() as UpdateConfig
|
||||||
|
} catch (error) {
|
||||||
|
logger.error('Failed to fetch update config:', error)
|
||||||
|
return null
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## Fallback and Error Handling Strategy
|
||||||
|
|
||||||
|
1. **Configuration file fetch failure**: Log error, return current version, don't offer updates
|
||||||
|
2. **No matching version**: Notify user that current version doesn't support automatic upgrade
|
||||||
|
3. **Network exception**: Cache last successfully fetched configuration (optional)
|
||||||
|
|
||||||
|
## GitHub Release Requirements
|
||||||
|
|
||||||
|
To support intermediate version upgrades, the following files need to be retained:
|
||||||
|
|
||||||
|
- **v1.7.0 release** and its latest*.yml files (as upgrade target for users below v1.7)
|
||||||
|
- Future intermediate versions (e.g., v2.8.0) need to retain corresponding release and latest*.yml files
|
||||||
|
- Complete installation packages for each version
|
||||||
|
|
||||||
|
### Currently Required Releases
|
||||||
|
|
||||||
|
| Version | Purpose | Must Retain |
|
||||||
|
|---------|---------|-------------|
|
||||||
|
| v1.7.0 | Upgrade target for users below 1.7 | ✅ Yes |
|
||||||
|
| v2.0.0-rc.1 | RC testing channel | ❌ Optional |
|
||||||
|
| v2.0.0-beta.1 | Beta testing channel | ❌ Optional |
|
||||||
|
| latest | Latest stable version (automatic) | ✅ Yes |
|
||||||
|
|
||||||
|
## Advantages
|
||||||
|
|
||||||
|
1. **Flexibility**: Supports arbitrarily complex upgrade paths
|
||||||
|
2. **Extensibility**: Adding new versions only requires adding new entries to the configuration file
|
||||||
|
3. **Maintainability**: Configuration is separated from code, allowing upgrade strategy adjustments without releasing new versions
|
||||||
|
4. **Multi-source support**: Automatically selects optimal configuration source based on geolocation
|
||||||
|
5. **Version control**: Enforces intermediate version upgrades, ensuring data migration and compatibility
|
||||||
|
|
||||||
|
## Future Extensions
|
||||||
|
|
||||||
|
- Support more granular version range control (e.g., `>=1.5.0 <1.8.0`)
|
||||||
|
- Support multi-step upgrade path hints (e.g., notify user needs 1.5 → 1.8 → 2.0)
|
||||||
|
- Support A/B testing and gradual rollout
|
||||||
|
- Support local caching and expiration strategy for configuration files
|
||||||
430
docs/technical/app-upgrade-config-zh.md
Normal file
430
docs/technical/app-upgrade-config-zh.md
Normal file
@@ -0,0 +1,430 @@
|
|||||||
|
# 更新配置系统设计文档
|
||||||
|
|
||||||
|
## 背景
|
||||||
|
|
||||||
|
当前 AppUpdater 直接请求 GitHub API 获取 beta 和 rc 的更新信息。为了支持国内用户,需要根据 IP 地理位置,分别从 GitHub/GitCode 获取一个固定的 JSON 配置文件,该文件包含所有渠道的更新地址。
|
||||||
|
|
||||||
|
## 设计目标
|
||||||
|
|
||||||
|
1. 支持根据 IP 地理位置选择不同的配置源(GitHub/GitCode)
|
||||||
|
2. 支持版本兼容性控制(如 v1.x 以下必须先升级到 v1.7.0 才能升级到 v2.0)
|
||||||
|
3. 易于扩展,支持未来多个主版本的升级路径(v1.6 → v1.7 → v2.0 → v2.8 → v3.0)
|
||||||
|
4. 保持与现有 electron-updater 机制的兼容性
|
||||||
|
|
||||||
|
## 当前版本策略
|
||||||
|
|
||||||
|
- **v1.7.x** 是 1.x 系列的最后版本
|
||||||
|
- **v1.7.0 以下**的用户必须先升级到 v1.7.0(或更高的 1.7.x 版本)
|
||||||
|
- **v1.7.0 及以上**的用户可以直接升级到 v2.x.x
|
||||||
|
|
||||||
|
## 自动化工作流
|
||||||
|
|
||||||
|
`x-files/app-upgrade-config/app-upgrade-config.json` 由 [`Update App Upgrade Config`](../../.github/workflows/update-app-upgrade-config.yml) workflow 自动同步。工作流会调用 [`scripts/update-app-upgrade-config.ts`](../../scripts/update-app-upgrade-config.ts) 脚本,根据指定 tag 更新 `x-files/app-upgrade-config` 分支上的配置文件。
|
||||||
|
|
||||||
|
### 触发条件
|
||||||
|
|
||||||
|
- **Release 事件(`release: released/prereleased`)**
|
||||||
|
- Draft release 会被忽略。
|
||||||
|
- 当 GitHub 将 release 标记为 *prerelease* 时,tag 必须包含 `-beta`/`-rc`(可带序号),否则直接跳过。
|
||||||
|
- 当 release 标记为稳定版时,tag 必须与 GitHub API 返回的最新稳定版本一致,防止发布历史 tag 时意外挂起工作流。
|
||||||
|
- 满足上述条件后,工作流会根据语义化版本判断渠道(`latest`/`beta`/`rc`),并通过 `IS_PRERELEASE` 传递给脚本。
|
||||||
|
- **手动触发(`workflow_dispatch`)**
|
||||||
|
- 必填:`tag`(例:`v2.0.1`);选填:`is_prerelease`(默认 `false`)。
|
||||||
|
- 当 `is_prerelease=true` 时,同样要求 tag 带有 beta/rc 后缀。
|
||||||
|
- 手动运行仍会请求 GitHub 最新 release 信息,用于在 PR 说明中标注该 tag 是否是最新稳定版。
|
||||||
|
|
||||||
|
### 工作流步骤
|
||||||
|
|
||||||
|
1. **检查与元数据准备**:`Check if should proceed` 和 `Prepare metadata` 步骤会计算 tag、prerelease 标志、是否最新版本以及用于分支名的 `safe_tag`。若任意校验失败,工作流立即退出。
|
||||||
|
2. **检出分支**:默认分支被检出到 `main/`,长期维护的 `x-files/app-upgrade-config` 分支则在 `cs/` 中,所有改动都发生在 `cs/`。
|
||||||
|
3. **安装工具链**:安装 Node.js 22、启用 Corepack,并在 `main/` 目录执行 `yarn install --immutable`。
|
||||||
|
4. **运行更新脚本**:执行 `yarn tsx scripts/update-app-upgrade-config.ts --tag <tag> --config ../cs/app-upgrade-config.json --is-prerelease <flag>`。
|
||||||
|
- 脚本会标准化 tag(去掉 `v` 前缀等)、识别渠道、加载 `config/app-upgrade-segments.json` 中的分段规则。
|
||||||
|
- 校验 prerelease 标志与语义后缀是否匹配、强制锁定的 segment 是否满足、生成镜像的下载地址,并检查 release 是否已经在 GitHub/GitCode 可用(latest 渠道在 GitCode 不可用时会回退到 `https://releases.cherry-ai.com`)。
|
||||||
|
- 更新对应的渠道配置后,脚本会按 semver 排序写回 JSON,并刷新 `lastUpdated`。
|
||||||
|
5. **检测变更并创建 PR**:若 `cs/app-upgrade-config.json` 有变更,则创建 `chore/update-app-upgrade-config/<safe_tag>` 分支,提交信息为 `🤖 chore: sync app-upgrade-config for <tag>`,并向 `x-files/app-upgrade-config` 提 PR;无变更则输出提示。
|
||||||
|
|
||||||
|
### 手动触发指南
|
||||||
|
|
||||||
|
1. 进入 Cherry Studio 仓库的 GitHub **Actions** 页面,选择 **Update App Upgrade Config** 工作流。
|
||||||
|
2. 点击 **Run workflow**,保持默认分支(通常为 `main`),填写 `tag`(如 `v2.1.0`)。
|
||||||
|
3. 只有在 tag 带 `-beta`/`-rc` 后缀时才勾选 `is_prerelease`,稳定版保持默认。
|
||||||
|
4. 启动运行并等待完成,随后到 `x-files/app-upgrade-config` 分支的 PR 查看 `app-upgrade-config.json` 的变更并在验证后合并。
|
||||||
|
|
||||||
|
## JSON 配置文件格式
|
||||||
|
|
||||||
|
### 文件位置
|
||||||
|
|
||||||
|
- **GitHub**: `https://raw.githubusercontent.com/CherryHQ/cherry-studio/refs/heads/x-files/app-upgrade-config/app-upgrade-config.json`
|
||||||
|
- **GitCode**: `https://gitcode.com/CherryHQ/cherry-studio/raw/x-files/app-upgrade-config/app-upgrade-config.json`
|
||||||
|
|
||||||
|
**说明**:两个镜像源提供相同的配置文件,统一托管在 `x-files/app-upgrade-config` 分支上。客户端根据 IP 地理位置自动选择最优镜像源。
|
||||||
|
|
||||||
|
### 配置结构(当前实际配置)
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"lastUpdated": "2025-01-05T00:00:00Z",
|
||||||
|
"versions": {
|
||||||
|
"1.6.7": {
|
||||||
|
"minCompatibleVersion": "1.0.0",
|
||||||
|
"description": "Last stable v1.7.x release - required intermediate version for users below v1.7",
|
||||||
|
"channels": {
|
||||||
|
"latest": {
|
||||||
|
"version": "1.6.7",
|
||||||
|
"feedUrls": {
|
||||||
|
"github": "https://github.com/CherryHQ/cherry-studio/releases/download/v1.6.7",
|
||||||
|
"gitcode": "https://gitcode.com/CherryHQ/cherry-studio/releases/download/v1.6.7"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"rc": {
|
||||||
|
"version": "1.6.0-rc.5",
|
||||||
|
"feedUrls": {
|
||||||
|
"github": "https://github.com/CherryHQ/cherry-studio/releases/download/v1.6.0-rc.5",
|
||||||
|
"gitcode": "https://github.com/CherryHQ/cherry-studio/releases/download/v1.6.0-rc.5"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"beta": {
|
||||||
|
"version": "1.6.7-beta.3",
|
||||||
|
"feedUrls": {
|
||||||
|
"github": "https://github.com/CherryHQ/cherry-studio/releases/download/v1.7.0-beta.3",
|
||||||
|
"gitcode": "https://github.com/CherryHQ/cherry-studio/releases/download/v1.7.0-beta.3"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"2.0.0": {
|
||||||
|
"minCompatibleVersion": "1.7.0",
|
||||||
|
"description": "Major release v2.0 - required intermediate version for v2.x upgrades",
|
||||||
|
"channels": {
|
||||||
|
"latest": null,
|
||||||
|
"rc": null,
|
||||||
|
"beta": null
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### 未来扩展示例
|
||||||
|
|
||||||
|
当需要发布 v3.0 时,如果需要强制用户先升级到 v2.8,可以添加:
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"2.8.0": {
|
||||||
|
"minCompatibleVersion": "2.0.0",
|
||||||
|
"description": "Stable v2.8 - required for v3 upgrade",
|
||||||
|
"channels": {
|
||||||
|
"latest": {
|
||||||
|
"version": "2.8.0",
|
||||||
|
"feedUrls": {
|
||||||
|
"github": "https://github.com/CherryHQ/cherry-studio/releases/download/v2.8.0",
|
||||||
|
"gitcode": "https://gitcode.com/CherryHQ/cherry-studio/releases/download/v2.8.0"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"rc": null,
|
||||||
|
"beta": null
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"3.0.0": {
|
||||||
|
"minCompatibleVersion": "2.8.0",
|
||||||
|
"description": "Major release v3.0",
|
||||||
|
"channels": {
|
||||||
|
"latest": {
|
||||||
|
"version": "3.0.0",
|
||||||
|
"feedUrls": {
|
||||||
|
"github": "https://github.com/CherryHQ/cherry-studio/releases/latest",
|
||||||
|
"gitcode": "https://gitcode.com/CherryHQ/cherry-studio/releases/latest"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"rc": {
|
||||||
|
"version": "3.0.0-rc.1",
|
||||||
|
"feedUrls": {
|
||||||
|
"github": "https://github.com/CherryHQ/cherry-studio/releases/download/v3.0.0-rc.1",
|
||||||
|
"gitcode": "https://gitcode.com/CherryHQ/cherry-studio/releases/download/v3.0.0-rc.1"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"beta": null
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### 字段说明
|
||||||
|
|
||||||
|
- `lastUpdated`: 配置文件最后更新时间(ISO 8601 格式)
|
||||||
|
- `versions`: 版本配置对象,key 为版本号,按语义化版本排序
|
||||||
|
- `minCompatibleVersion`: 可以升级到此版本的最低兼容版本
|
||||||
|
- `description`: 版本描述
|
||||||
|
- `channels`: 更新渠道配置
|
||||||
|
- `latest`: 稳定版渠道
|
||||||
|
- `rc`: Release Candidate 渠道
|
||||||
|
- `beta`: Beta 测试渠道
|
||||||
|
- 每个渠道包含:
|
||||||
|
- `version`: 该渠道的版本号
|
||||||
|
- `feedUrls`: 多镜像源 URL 配置
|
||||||
|
- `github`: GitHub 镜像源的 electron-updater feed URL
|
||||||
|
- `gitcode`: GitCode 镜像源的 electron-updater feed URL
|
||||||
|
- `metadata`: 自动化匹配所需的稳定标识
|
||||||
|
- `segmentId`: 来自 `config/app-upgrade-segments.json` 的段位 ID
|
||||||
|
- `segmentType`: 可选字段(`legacy` | `breaking` | `latest`),便于文档/调试
|
||||||
|
|
||||||
|
## TypeScript 类型定义
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
// 镜像源枚举
|
||||||
|
enum UpdateMirror {
|
||||||
|
GITHUB = 'github',
|
||||||
|
GITCODE = 'gitcode'
|
||||||
|
}
|
||||||
|
|
||||||
|
interface UpdateConfig {
|
||||||
|
lastUpdated: string
|
||||||
|
versions: {
|
||||||
|
[versionKey: string]: VersionConfig
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
interface VersionConfig {
|
||||||
|
minCompatibleVersion: string
|
||||||
|
description: string
|
||||||
|
channels: {
|
||||||
|
latest: ChannelConfig | null
|
||||||
|
rc: ChannelConfig | null
|
||||||
|
beta: ChannelConfig | null
|
||||||
|
}
|
||||||
|
metadata?: {
|
||||||
|
segmentId: string
|
||||||
|
segmentType?: 'legacy' | 'breaking' | 'latest'
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
interface ChannelConfig {
|
||||||
|
version: string
|
||||||
|
feedUrls: Record<UpdateMirror, string>
|
||||||
|
// 等同于:
|
||||||
|
// feedUrls: {
|
||||||
|
// github: string
|
||||||
|
// gitcode: string
|
||||||
|
// }
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## 段位元数据(Break Change 标记)
|
||||||
|
|
||||||
|
- 所有段位定义(如 `legacy-v1`、`gateway-v2` 等)集中在 `config/app-upgrade-segments.json`,用于描述匹配范围、`segmentId`、`segmentType`、默认 `minCompatibleVersion/description` 以及各渠道的 URL 模板。
|
||||||
|
- `versions` 下的每个节点都会带上 `metadata.segmentId`。自动脚本始终依据该 ID 来定位并更新条目,即便 key 从 `2.1.5` 切换到 `2.1.6` 也不会错位。
|
||||||
|
- 如果某段需要锁死在特定版本(例如 `2.0.0` 的 break change),可在段定义中设置 `segmentType: "breaking"` 并提供 `lockedVersion`,脚本在遇到不匹配的 tag 时会短路报错,保证升级路径安全。
|
||||||
|
- 面对未来新的断层(例如 `3.0.0`),只需要在段定义里新增一段,自动化即可识别并更新。
|
||||||
|
|
||||||
|
## 自动化工作流
|
||||||
|
|
||||||
|
`.github/workflows/update-app-upgrade-config.yml` 会在 GitHub Release(包含正常发布与 Pre Release)触发:
|
||||||
|
|
||||||
|
1. 同时 Checkout 仓库默认分支(用于脚本)和 `x-files/app-upgrade-config` 分支(真实托管配置的分支)。
|
||||||
|
2. 在默认分支目录执行 `yarn tsx scripts/update-app-upgrade-config.ts --tag <tag> --config ../cs/app-upgrade-config.json`,直接重写 `x-files/app-upgrade-config` 分支里的配置文件。
|
||||||
|
3. 如果 `app-upgrade-config.json` 有变化,则通过 `peter-evans/create-pull-request` 自动创建一个指向 `x-files/app-upgrade-config` 的 PR,Diff 仅包含该文件。
|
||||||
|
|
||||||
|
如需本地调试,可执行 `yarn update:upgrade-config --tag v2.1.6 --config ../cs/app-upgrade-config.json`(加 `--dry-run` 仅打印结果)来复现 CI 行为。若需要暂时跳过 GitHub/GitCode Release 页面是否就绪的校验,可在 `--dry-run` 的同时附加 `--skip-release-checks`。不加 `--config` 时默认更新当前工作目录(通常是 main 分支)下的副本,方便文档/审查。
|
||||||
|
|
||||||
|
## 版本匹配逻辑
|
||||||
|
|
||||||
|
### 算法流程
|
||||||
|
|
||||||
|
1. 获取用户当前版本(`currentVersion`)和请求的渠道(`requestedChannel`)
|
||||||
|
2. 获取配置文件中所有版本号,按语义化版本从大到小排序
|
||||||
|
3. 遍历排序后的版本列表:
|
||||||
|
- 检查 `currentVersion >= minCompatibleVersion`
|
||||||
|
- 检查请求的 `channel` 是否存在且不为 `null`
|
||||||
|
- 如果满足条件,返回该渠道配置
|
||||||
|
4. 如果没有找到匹配版本,返回 `null`
|
||||||
|
|
||||||
|
### 伪代码实现
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
function findCompatibleVersion(
|
||||||
|
currentVersion: string,
|
||||||
|
requestedChannel: UpgradeChannel,
|
||||||
|
config: UpdateConfig
|
||||||
|
): ChannelConfig | null {
|
||||||
|
// 获取所有版本号并从大到小排序
|
||||||
|
const versions = Object.keys(config.versions).sort(semver.rcompare)
|
||||||
|
|
||||||
|
for (const versionKey of versions) {
|
||||||
|
const versionConfig = config.versions[versionKey]
|
||||||
|
const channelConfig = versionConfig.channels[requestedChannel]
|
||||||
|
|
||||||
|
// 检查版本兼容性和渠道可用性
|
||||||
|
if (
|
||||||
|
semver.gte(currentVersion, versionConfig.minCompatibleVersion) &&
|
||||||
|
channelConfig !== null
|
||||||
|
) {
|
||||||
|
return channelConfig
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return null // 没有找到兼容版本
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## 升级路径示例
|
||||||
|
|
||||||
|
### 场景 1: v1.6.5 用户升级(低于 1.7)
|
||||||
|
|
||||||
|
- **当前版本**: 1.6.5
|
||||||
|
- **请求渠道**: latest
|
||||||
|
- **匹配结果**: 1.7.0
|
||||||
|
- **原因**: 1.6.5 >= 0.0.0(满足 1.7.0 的 minCompatibleVersion),但不满足 2.0.0 的 minCompatibleVersion (1.7.0)
|
||||||
|
- **操作**: 提示用户升级到 1.7.0,这是升级到 v2.x 的必要中间版本
|
||||||
|
|
||||||
|
### 场景 2: v1.6.5 用户请求 rc/beta
|
||||||
|
|
||||||
|
- **当前版本**: 1.6.5
|
||||||
|
- **请求渠道**: rc 或 beta
|
||||||
|
- **匹配结果**: 1.7.0 (latest)
|
||||||
|
- **原因**: 1.7.0 版本不提供 rc/beta 渠道(值为 null)
|
||||||
|
- **操作**: 升级到 1.7.0 稳定版
|
||||||
|
|
||||||
|
### 场景 3: v1.7.0 用户升级到最新版
|
||||||
|
|
||||||
|
- **当前版本**: 1.7.0
|
||||||
|
- **请求渠道**: latest
|
||||||
|
- **匹配结果**: 2.0.0
|
||||||
|
- **原因**: 1.7.0 >= 1.7.0(满足 2.0.0 的 minCompatibleVersion)
|
||||||
|
- **操作**: 直接升级到 2.0.0(当前最新稳定版)
|
||||||
|
|
||||||
|
### 场景 4: v1.7.2 用户升级到 RC 版本
|
||||||
|
|
||||||
|
- **当前版本**: 1.7.2
|
||||||
|
- **请求渠道**: rc
|
||||||
|
- **匹配结果**: 2.0.0-rc.1
|
||||||
|
- **原因**: 1.7.2 >= 1.7.0(满足 2.0.0 的 minCompatibleVersion),且 rc 渠道存在
|
||||||
|
- **操作**: 升级到 2.0.0-rc.1
|
||||||
|
|
||||||
|
### 场景 5: v1.7.0 用户升级到 Beta 版本
|
||||||
|
|
||||||
|
- **当前版本**: 1.7.0
|
||||||
|
- **请求渠道**: beta
|
||||||
|
- **匹配结果**: 2.0.0-beta.1
|
||||||
|
- **原因**: 1.7.0 >= 1.7.0,且 beta 渠道存在
|
||||||
|
- **操作**: 升级到 2.0.0-beta.1
|
||||||
|
|
||||||
|
### 场景 6: v2.5.0 用户升级(未来)
|
||||||
|
|
||||||
|
假设已添加 v2.8.0 和 v3.0.0 配置:
|
||||||
|
- **当前版本**: 2.5.0
|
||||||
|
- **请求渠道**: latest
|
||||||
|
- **匹配结果**: 2.8.0
|
||||||
|
- **原因**: 2.5.0 >= 2.0.0(满足 2.8.0 的 minCompatibleVersion),但不满足 3.0.0 的要求
|
||||||
|
- **操作**: 提示用户升级到 2.8.0,这是升级到 v3.x 的必要中间版本
|
||||||
|
|
||||||
|
## 代码改动计划
|
||||||
|
|
||||||
|
### 主要修改
|
||||||
|
|
||||||
|
1. **新增方法**
|
||||||
|
- `_fetchUpdateConfig(ipCountry: string): Promise<UpdateConfig | null>` - 根据 IP 获取配置文件
|
||||||
|
- `_findCompatibleChannel(currentVersion: string, channel: UpgradeChannel, config: UpdateConfig): ChannelConfig | null` - 查找兼容的渠道配置
|
||||||
|
|
||||||
|
2. **修改方法**
|
||||||
|
- `_getReleaseVersionFromGithub()` → 移除或重构为 `_getChannelFeedUrl()`
|
||||||
|
- `_setFeedUrl()` - 使用新的配置系统替代现有逻辑
|
||||||
|
|
||||||
|
3. **新增类型定义**
|
||||||
|
- `UpdateConfig`
|
||||||
|
- `VersionConfig`
|
||||||
|
- `ChannelConfig`
|
||||||
|
|
||||||
|
### 镜像源选择逻辑
|
||||||
|
|
||||||
|
客户端根据 IP 地理位置自动选择最优镜像源:
|
||||||
|
|
||||||
|
```typescript
|
||||||
|
private async _setFeedUrl() {
|
||||||
|
const currentVersion = app.getVersion()
|
||||||
|
const testPlan = configManager.getTestPlan()
|
||||||
|
const requestedChannel = testPlan ? this._getTestChannel() : UpgradeChannel.LATEST
|
||||||
|
|
||||||
|
// 根据 IP 国家确定镜像源
|
||||||
|
const ipCountry = await getIpCountry()
|
||||||
|
const mirror = ipCountry.toLowerCase() === 'cn' ? 'gitcode' : 'github'
|
||||||
|
|
||||||
|
// 获取更新配置
|
||||||
|
const config = await this._fetchUpdateConfig(mirror)
|
||||||
|
|
||||||
|
if (config) {
|
||||||
|
const channelConfig = this._findCompatibleChannel(currentVersion, requestedChannel, config)
|
||||||
|
if (channelConfig) {
|
||||||
|
// 从配置中选择对应镜像源的 URL
|
||||||
|
const feedUrl = channelConfig.feedUrls[mirror]
|
||||||
|
this._setChannel(requestedChannel, feedUrl)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fallback 逻辑
|
||||||
|
const defaultFeedUrl = mirror === 'gitcode'
|
||||||
|
? FeedUrl.PRODUCTION
|
||||||
|
: FeedUrl.GITHUB_LATEST
|
||||||
|
this._setChannel(UpgradeChannel.LATEST, defaultFeedUrl)
|
||||||
|
}
|
||||||
|
|
||||||
|
private async _fetchUpdateConfig(mirror: 'github' | 'gitcode'): Promise<UpdateConfig | null> {
|
||||||
|
const configUrl = mirror === 'gitcode'
|
||||||
|
? UpdateConfigUrl.GITCODE
|
||||||
|
: UpdateConfigUrl.GITHUB
|
||||||
|
|
||||||
|
try {
|
||||||
|
const response = await net.fetch(configUrl, {
|
||||||
|
headers: {
|
||||||
|
'User-Agent': generateUserAgent(),
|
||||||
|
'Accept': 'application/json',
|
||||||
|
'X-Client-Id': configManager.getClientId()
|
||||||
|
}
|
||||||
|
})
|
||||||
|
return await response.json() as UpdateConfig
|
||||||
|
} catch (error) {
|
||||||
|
logger.error('Failed to fetch update config:', error)
|
||||||
|
return null
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## 降级和容错策略
|
||||||
|
|
||||||
|
1. **配置文件获取失败**: 记录错误日志,返回当前版本,不提供更新
|
||||||
|
2. **没有匹配的版本**: 提示用户当前版本不支持自动升级
|
||||||
|
3. **网络异常**: 缓存上次成功获取的配置(可选)
|
||||||
|
|
||||||
|
## GitHub Release 要求
|
||||||
|
|
||||||
|
为支持中间版本升级,需要保留以下文件:
|
||||||
|
|
||||||
|
- **v1.7.0 release** 及其 latest*.yml 文件(作为 v1.7 以下用户的升级目标)
|
||||||
|
- 未来如需强制中间版本(如 v2.8.0),需要保留对应的 release 和 latest*.yml 文件
|
||||||
|
- 各版本的完整安装包
|
||||||
|
|
||||||
|
### 当前需要的 Release
|
||||||
|
|
||||||
|
| 版本 | 用途 | 必须保留 |
|
||||||
|
|------|------|---------|
|
||||||
|
| v1.7.0 | 1.7 以下用户的升级目标 | ✅ 是 |
|
||||||
|
| v2.0.0-rc.1 | RC 测试渠道 | ❌ 可选 |
|
||||||
|
| v2.0.0-beta.1 | Beta 测试渠道 | ❌ 可选 |
|
||||||
|
| latest | 最新稳定版(自动) | ✅ 是 |
|
||||||
|
|
||||||
|
## 优势
|
||||||
|
|
||||||
|
1. **灵活性**: 支持任意复杂的升级路径
|
||||||
|
2. **可扩展性**: 新增版本只需在配置文件中添加新条目
|
||||||
|
3. **可维护性**: 配置与代码分离,无需发版即可调整升级策略
|
||||||
|
4. **多源支持**: 自动根据地理位置选择最优配置源
|
||||||
|
5. **版本控制**: 强制中间版本升级,确保数据迁移和兼容性
|
||||||
|
|
||||||
|
## 未来扩展
|
||||||
|
|
||||||
|
- 支持更细粒度的版本范围控制(如 `>=1.5.0 <1.8.0`)
|
||||||
|
- 支持多步升级路径提示(如提示用户需要 1.5 → 1.8 → 2.0)
|
||||||
|
- 支持 A/B 测试和灰度发布
|
||||||
|
- 支持配置文件的本地缓存和过期策略
|
||||||
@@ -97,7 +97,6 @@ mac:
|
|||||||
entitlementsInherit: build/entitlements.mac.plist
|
entitlementsInherit: build/entitlements.mac.plist
|
||||||
notarize: false
|
notarize: false
|
||||||
artifactName: ${productName}-${version}-${arch}.${ext}
|
artifactName: ${productName}-${version}-${arch}.${ext}
|
||||||
minimumSystemVersion: "20.1.0" # 最低支持 macOS 11.0
|
|
||||||
extendInfo:
|
extendInfo:
|
||||||
- NSCameraUsageDescription: Application requests access to the device's camera.
|
- NSCameraUsageDescription: Application requests access to the device's camera.
|
||||||
- NSMicrophoneUsageDescription: Application requests access to the device's microphone.
|
- NSMicrophoneUsageDescription: Application requests access to the device's microphone.
|
||||||
@@ -135,42 +134,58 @@ artifactBuildCompleted: scripts/artifact-build-completed.js
|
|||||||
releaseInfo:
|
releaseInfo:
|
||||||
releaseNotes: |
|
releaseNotes: |
|
||||||
<!--LANG:en-->
|
<!--LANG:en-->
|
||||||
What's New in v1.7.0-beta.6
|
What's New in v1.7.0-rc.1
|
||||||
|
|
||||||
New Features:
|
🎉 MAJOR NEW FEATURE: AI Agents
|
||||||
- Enhanced Input Bar: Completely redesigned input bar with improved responsiveness and functionality
|
- Create and manage custom AI agents with specialized tools and permissions
|
||||||
- Better File Handling: Improved drag-and-drop and paste support for images and documents
|
- Dedicated agent sessions with persistent SQLite storage, separate from regular chats
|
||||||
- Smart Tool Suggestions: Enhanced quick panel with better item selection and keyboard shortcuts
|
- Real-time tool approval system - review and approve agent actions dynamically
|
||||||
|
- MCP (Model Context Protocol) integration for connecting external tools
|
||||||
|
- Slash commands support for quick agent interactions
|
||||||
|
- OpenAI-compatible REST API for agent access
|
||||||
|
|
||||||
Improvements:
|
✨ New Features:
|
||||||
- Smoother Input Experience: Better auto-resizing and text handling in chat input
|
- AI Providers: Added support for Hugging Face, Mistral, Perplexity, and SophNet
|
||||||
- Enhanced AI Performance: Improved connection stability and response speed
|
- Knowledge Base: OpenMinerU document preprocessor, full-text search in notes, enhanced tool selection
|
||||||
- More Reliable File Uploads: Better support for various file types and upload scenarios
|
- Image & OCR: Intel OVMS painting provider and Intel OpenVINO (NPU) OCR support
|
||||||
- Cleaner Interface: Optimized UI elements for better visual consistency
|
- MCP Management: Redesigned interface with dual-column layout for easier management
|
||||||
|
- Languages: Added German language support
|
||||||
|
|
||||||
Bug Fixes:
|
⚡ Improvements:
|
||||||
- Fixed image selection issue when adding custom AI providers
|
- Upgraded to Electron 38.7.0
|
||||||
- Fixed file upload problems with certain API configurations
|
- Enhanced system shutdown handling and automatic update checks
|
||||||
- Fixed input bar responsiveness issues
|
- Improved proxy bypass rules
|
||||||
- Fixed quick panel not working properly in some situations
|
|
||||||
|
🐛 Important Bug Fixes:
|
||||||
|
- Fixed streaming response issues across multiple AI providers
|
||||||
|
- Fixed session list scrolling problems
|
||||||
|
- Fixed knowledge base deletion errors
|
||||||
|
|
||||||
<!--LANG:zh-CN-->
|
<!--LANG:zh-CN-->
|
||||||
v1.7.0-beta.6 新特性
|
v1.7.0-rc.1 新特性
|
||||||
|
|
||||||
新功能:
|
🎉 重大更新:AI Agent 智能体系统
|
||||||
- 增强输入栏:完全重新设计的输入栏,响应更灵敏,功能更强大
|
- 创建和管理专属 AI Agent,配置专用工具和权限
|
||||||
- 更好的文件处理:改进的拖拽和粘贴功能,支持图片和文档
|
- 独立的 Agent 会话,使用 SQLite 持久化存储,与普通聊天分离
|
||||||
- 智能工具建议:增强的快速面板,更好的项目选择和键盘快捷键
|
- 实时工具审批系统 - 动态审查和批准 Agent 操作
|
||||||
|
- MCP(模型上下文协议)集成,连接外部工具
|
||||||
|
- 支持斜杠命令快速交互
|
||||||
|
- 兼容 OpenAI 的 REST API 访问
|
||||||
|
|
||||||
改进:
|
✨ 新功能:
|
||||||
- 更流畅的输入体验:聊天输入框的自动调整和文本处理更佳
|
- AI 提供商:新增 Hugging Face、Mistral、Perplexity 和 SophNet 支持
|
||||||
- 增强 AI 性能:改进连接稳定性和响应速度
|
- 知识库:OpenMinerU 文档预处理器、笔记全文搜索、增强的工具选择
|
||||||
- 更可靠的文件上传:更好地支持各种文件类型和上传场景
|
- 图像与 OCR:Intel OVMS 绘图提供商和 Intel OpenVINO (NPU) OCR 支持
|
||||||
- 更简洁的界面:优化 UI 元素,视觉一致性更好
|
- MCP 管理:重构管理界面,采用双列布局,更加方便管理
|
||||||
|
- 语言:新增德语支持
|
||||||
|
|
||||||
问题修复:
|
⚡ 改进:
|
||||||
- 修复添加自定义 AI 提供商时的图片选择问题
|
- 升级到 Electron 38.7.0
|
||||||
- 修复某些 API 配置下的文件上传问题
|
- 增强的系统关机处理和自动更新检查
|
||||||
- 修复输入栏响应性问题
|
- 改进的代理绕过规则
|
||||||
- 修复快速面板在某些情况下无法正常工作的问题
|
|
||||||
|
🐛 重要修复:
|
||||||
|
- 修复多个 AI 提供商的流式响应问题
|
||||||
|
- 修复会话列表滚动问题
|
||||||
|
- 修复知识库删除错误
|
||||||
<!--LANG:END-->
|
<!--LANG:END-->
|
||||||
|
|||||||
55
package.json
55
package.json
@@ -1,6 +1,6 @@
|
|||||||
{
|
{
|
||||||
"name": "CherryStudio",
|
"name": "CherryStudio",
|
||||||
"version": "1.7.0-beta.3",
|
"version": "1.7.0-rc.1",
|
||||||
"private": true,
|
"private": true,
|
||||||
"description": "A powerful AI assistant for producer.",
|
"description": "A powerful AI assistant for producer.",
|
||||||
"main": "./out/main/index.js",
|
"main": "./out/main/index.js",
|
||||||
@@ -58,6 +58,7 @@
|
|||||||
"update:i18n": "dotenv -e .env -- tsx scripts/update-i18n.ts",
|
"update:i18n": "dotenv -e .env -- tsx scripts/update-i18n.ts",
|
||||||
"auto:i18n": "dotenv -e .env -- tsx scripts/auto-translate-i18n.ts",
|
"auto:i18n": "dotenv -e .env -- tsx scripts/auto-translate-i18n.ts",
|
||||||
"update:languages": "tsx scripts/update-languages.ts",
|
"update:languages": "tsx scripts/update-languages.ts",
|
||||||
|
"update:upgrade-config": "tsx scripts/update-app-upgrade-config.ts",
|
||||||
"test": "vitest run --silent",
|
"test": "vitest run --silent",
|
||||||
"test:main": "vitest run --project main",
|
"test:main": "vitest run --project main",
|
||||||
"test:renderer": "vitest run --project renderer",
|
"test:renderer": "vitest run --project renderer",
|
||||||
@@ -73,9 +74,10 @@
|
|||||||
"format:check": "biome format && biome lint",
|
"format:check": "biome format && biome lint",
|
||||||
"prepare": "git config blame.ignoreRevsFile .git-blame-ignore-revs && husky",
|
"prepare": "git config blame.ignoreRevsFile .git-blame-ignore-revs && husky",
|
||||||
"claude": "dotenv -e .env -- claude",
|
"claude": "dotenv -e .env -- claude",
|
||||||
"release:aicore:alpha": "yarn workspace @cherrystudio/ai-core version prerelease --immediate && yarn workspace @cherrystudio/ai-core npm publish --tag alpha --access public",
|
"release:aicore:alpha": "yarn workspace @cherrystudio/ai-core version prerelease --preid alpha --immediate && yarn workspace @cherrystudio/ai-core build && yarn workspace @cherrystudio/ai-core npm publish --tag alpha --access public",
|
||||||
"release:aicore:beta": "yarn workspace @cherrystudio/ai-core version prerelease --immediate && yarn workspace @cherrystudio/ai-core npm publish --tag beta --access public",
|
"release:aicore:beta": "yarn workspace @cherrystudio/ai-core version prerelease --preid beta --immediate && yarn workspace @cherrystudio/ai-core build && yarn workspace @cherrystudio/ai-core npm publish --tag beta --access public",
|
||||||
"release:aicore": "yarn workspace @cherrystudio/ai-core version patch --immediate && yarn workspace @cherrystudio/ai-core npm publish --access public"
|
"release:aicore": "yarn workspace @cherrystudio/ai-core version patch --immediate && yarn workspace @cherrystudio/ai-core build && yarn workspace @cherrystudio/ai-core npm publish --access public",
|
||||||
|
"release:ai-sdk-provider": "yarn workspace @cherrystudio/ai-sdk-provider version patch --immediate && yarn workspace @cherrystudio/ai-sdk-provider build && yarn workspace @cherrystudio/ai-sdk-provider npm publish --access public"
|
||||||
},
|
},
|
||||||
"dependencies": {
|
"dependencies": {
|
||||||
"@anthropic-ai/claude-agent-sdk": "patch:@anthropic-ai/claude-agent-sdk@npm%3A0.1.30#~/.yarn/patches/@anthropic-ai-claude-agent-sdk-npm-0.1.30-b50a299674.patch",
|
"@anthropic-ai/claude-agent-sdk": "patch:@anthropic-ai/claude-agent-sdk@npm%3A0.1.30#~/.yarn/patches/@anthropic-ai-claude-agent-sdk-npm-0.1.30-b50a299674.patch",
|
||||||
@@ -84,6 +86,7 @@
|
|||||||
"@napi-rs/system-ocr": "patch:@napi-rs/system-ocr@npm%3A1.0.2#~/.yarn/patches/@napi-rs-system-ocr-npm-1.0.2-59e7a78e8b.patch",
|
"@napi-rs/system-ocr": "patch:@napi-rs/system-ocr@npm%3A1.0.2#~/.yarn/patches/@napi-rs-system-ocr-npm-1.0.2-59e7a78e8b.patch",
|
||||||
"@paymoapp/electron-shutdown-handler": "^1.1.2",
|
"@paymoapp/electron-shutdown-handler": "^1.1.2",
|
||||||
"@strongtz/win32-arm64-msvc": "^0.4.7",
|
"@strongtz/win32-arm64-msvc": "^0.4.7",
|
||||||
|
"emoji-picker-element-data": "^1",
|
||||||
"express": "^5.1.0",
|
"express": "^5.1.0",
|
||||||
"font-list": "^2.0.0",
|
"font-list": "^2.0.0",
|
||||||
"graceful-fs": "^4.2.11",
|
"graceful-fs": "^4.2.11",
|
||||||
@@ -106,13 +109,17 @@
|
|||||||
"@agentic/exa": "^7.3.3",
|
"@agentic/exa": "^7.3.3",
|
||||||
"@agentic/searxng": "^7.3.3",
|
"@agentic/searxng": "^7.3.3",
|
||||||
"@agentic/tavily": "^7.3.3",
|
"@agentic/tavily": "^7.3.3",
|
||||||
"@ai-sdk/amazon-bedrock": "^3.0.53",
|
"@ai-sdk/amazon-bedrock": "^3.0.56",
|
||||||
|
"@ai-sdk/anthropic": "^2.0.45",
|
||||||
"@ai-sdk/cerebras": "^1.0.31",
|
"@ai-sdk/cerebras": "^1.0.31",
|
||||||
"@ai-sdk/gateway": "^2.0.9",
|
"@ai-sdk/gateway": "^2.0.13",
|
||||||
"@ai-sdk/google-vertex": "^3.0.62",
|
"@ai-sdk/google": "patch:@ai-sdk/google@npm%3A2.0.40#~/.yarn/patches/@ai-sdk-google-npm-2.0.40-47e0eeee83.patch",
|
||||||
"@ai-sdk/huggingface": "patch:@ai-sdk/huggingface@npm%3A0.0.8#~/.yarn/patches/@ai-sdk-huggingface-npm-0.0.8-d4d0aaac93.patch",
|
"@ai-sdk/google-vertex": "^3.0.72",
|
||||||
"@ai-sdk/mistral": "^2.0.23",
|
"@ai-sdk/huggingface": "^0.0.10",
|
||||||
"@ai-sdk/perplexity": "^2.0.17",
|
"@ai-sdk/mistral": "^2.0.24",
|
||||||
|
"@ai-sdk/openai": "patch:@ai-sdk/openai@npm%3A2.0.71#~/.yarn/patches/@ai-sdk-openai-npm-2.0.71-a88ef00525.patch",
|
||||||
|
"@ai-sdk/perplexity": "^2.0.20",
|
||||||
|
"@ai-sdk/test-server": "^0.0.1",
|
||||||
"@ant-design/v5-patch-for-react-19": "^1.0.3",
|
"@ant-design/v5-patch-for-react-19": "^1.0.3",
|
||||||
"@anthropic-ai/sdk": "^0.41.0",
|
"@anthropic-ai/sdk": "^0.41.0",
|
||||||
"@anthropic-ai/vertex-sdk": "patch:@anthropic-ai/vertex-sdk@npm%3A0.11.4#~/.yarn/patches/@anthropic-ai-vertex-sdk-npm-0.11.4-c19cb41edb.patch",
|
"@anthropic-ai/vertex-sdk": "patch:@anthropic-ai/vertex-sdk@npm%3A0.11.4#~/.yarn/patches/@anthropic-ai-vertex-sdk-npm-0.11.4-c19cb41edb.patch",
|
||||||
@@ -120,7 +127,7 @@
|
|||||||
"@aws-sdk/client-bedrock-runtime": "^3.910.0",
|
"@aws-sdk/client-bedrock-runtime": "^3.910.0",
|
||||||
"@aws-sdk/client-s3": "^3.910.0",
|
"@aws-sdk/client-s3": "^3.910.0",
|
||||||
"@biomejs/biome": "2.2.4",
|
"@biomejs/biome": "2.2.4",
|
||||||
"@cherrystudio/ai-core": "workspace:^1.0.0-alpha.18",
|
"@cherrystudio/ai-core": "workspace:^1.0.9",
|
||||||
"@cherrystudio/embedjs": "^0.1.31",
|
"@cherrystudio/embedjs": "^0.1.31",
|
||||||
"@cherrystudio/embedjs-libsql": "^0.1.31",
|
"@cherrystudio/embedjs-libsql": "^0.1.31",
|
||||||
"@cherrystudio/embedjs-loader-csv": "^0.1.31",
|
"@cherrystudio/embedjs-loader-csv": "^0.1.31",
|
||||||
@@ -134,7 +141,7 @@
|
|||||||
"@cherrystudio/embedjs-ollama": "^0.1.31",
|
"@cherrystudio/embedjs-ollama": "^0.1.31",
|
||||||
"@cherrystudio/embedjs-openai": "^0.1.31",
|
"@cherrystudio/embedjs-openai": "^0.1.31",
|
||||||
"@cherrystudio/extension-table-plus": "workspace:^",
|
"@cherrystudio/extension-table-plus": "workspace:^",
|
||||||
"@cherrystudio/openai": "^6.5.0",
|
"@cherrystudio/openai": "^6.9.0",
|
||||||
"@dnd-kit/core": "^6.3.1",
|
"@dnd-kit/core": "^6.3.1",
|
||||||
"@dnd-kit/modifiers": "^9.0.0",
|
"@dnd-kit/modifiers": "^9.0.0",
|
||||||
"@dnd-kit/sortable": "^10.0.0",
|
"@dnd-kit/sortable": "^10.0.0",
|
||||||
@@ -153,18 +160,19 @@
|
|||||||
"@langchain/community": "^1.0.0",
|
"@langchain/community": "^1.0.0",
|
||||||
"@langchain/core": "patch:@langchain/core@npm%3A1.0.2#~/.yarn/patches/@langchain-core-npm-1.0.2-183ef83fe4.patch",
|
"@langchain/core": "patch:@langchain/core@npm%3A1.0.2#~/.yarn/patches/@langchain-core-npm-1.0.2-183ef83fe4.patch",
|
||||||
"@langchain/openai": "patch:@langchain/openai@npm%3A1.0.0#~/.yarn/patches/@langchain-openai-npm-1.0.0-474d0ad9d4.patch",
|
"@langchain/openai": "patch:@langchain/openai@npm%3A1.0.0#~/.yarn/patches/@langchain-openai-npm-1.0.0-474d0ad9d4.patch",
|
||||||
|
"@mcp-ui/client": "^5.14.1",
|
||||||
"@mistralai/mistralai": "^1.7.5",
|
"@mistralai/mistralai": "^1.7.5",
|
||||||
"@modelcontextprotocol/sdk": "^1.17.5",
|
"@modelcontextprotocol/sdk": "^1.17.5",
|
||||||
"@mozilla/readability": "^0.6.0",
|
"@mozilla/readability": "^0.6.0",
|
||||||
"@notionhq/client": "^2.2.15",
|
"@notionhq/client": "^2.2.15",
|
||||||
"@openrouter/ai-sdk-provider": "^1.2.0",
|
"@openrouter/ai-sdk-provider": "^1.2.5",
|
||||||
"@opentelemetry/api": "^1.9.0",
|
"@opentelemetry/api": "^1.9.0",
|
||||||
"@opentelemetry/core": "2.0.0",
|
"@opentelemetry/core": "2.0.0",
|
||||||
"@opentelemetry/exporter-trace-otlp-http": "^0.200.0",
|
"@opentelemetry/exporter-trace-otlp-http": "^0.200.0",
|
||||||
"@opentelemetry/sdk-trace-base": "^2.0.0",
|
"@opentelemetry/sdk-trace-base": "^2.0.0",
|
||||||
"@opentelemetry/sdk-trace-node": "^2.0.0",
|
"@opentelemetry/sdk-trace-node": "^2.0.0",
|
||||||
"@opentelemetry/sdk-trace-web": "^2.0.0",
|
"@opentelemetry/sdk-trace-web": "^2.0.0",
|
||||||
"@opeoginni/github-copilot-openai-compatible": "0.1.19",
|
"@opeoginni/github-copilot-openai-compatible": "0.1.21",
|
||||||
"@playwright/test": "^1.52.0",
|
"@playwright/test": "^1.52.0",
|
||||||
"@radix-ui/react-context-menu": "^2.2.16",
|
"@radix-ui/react-context-menu": "^2.2.16",
|
||||||
"@reduxjs/toolkit": "^2.2.5",
|
"@reduxjs/toolkit": "^2.2.5",
|
||||||
@@ -233,7 +241,7 @@
|
|||||||
"@viz-js/lang-dot": "^1.0.5",
|
"@viz-js/lang-dot": "^1.0.5",
|
||||||
"@viz-js/viz": "^3.14.0",
|
"@viz-js/viz": "^3.14.0",
|
||||||
"@xyflow/react": "^12.4.4",
|
"@xyflow/react": "^12.4.4",
|
||||||
"ai": "^5.0.90",
|
"ai": "^5.0.98",
|
||||||
"antd": "patch:antd@npm%3A5.27.0#~/.yarn/patches/antd-npm-5.27.0-aa91c36546.patch",
|
"antd": "patch:antd@npm%3A5.27.0#~/.yarn/patches/antd-npm-5.27.0-aa91c36546.patch",
|
||||||
"archiver": "^7.0.1",
|
"archiver": "^7.0.1",
|
||||||
"async-mutex": "^0.5.0",
|
"async-mutex": "^0.5.0",
|
||||||
@@ -259,12 +267,12 @@
|
|||||||
"dotenv-cli": "^7.4.2",
|
"dotenv-cli": "^7.4.2",
|
||||||
"drizzle-kit": "^0.31.4",
|
"drizzle-kit": "^0.31.4",
|
||||||
"drizzle-orm": "^0.44.5",
|
"drizzle-orm": "^0.44.5",
|
||||||
"electron": "38.4.0",
|
"electron": "38.7.0",
|
||||||
"electron-builder": "26.0.15",
|
"electron-builder": "26.1.0",
|
||||||
"electron-devtools-installer": "^3.2.0",
|
"electron-devtools-installer": "^3.2.0",
|
||||||
"electron-reload": "^2.0.0-alpha.1",
|
"electron-reload": "^2.0.0-alpha.1",
|
||||||
"electron-store": "^8.2.0",
|
"electron-store": "^8.2.0",
|
||||||
"electron-updater": "6.6.4",
|
"electron-updater": "patch:electron-updater@npm%3A6.7.0#~/.yarn/patches/electron-updater-npm-6.7.0-47b11bb0d4.patch",
|
||||||
"electron-vite": "4.0.1",
|
"electron-vite": "4.0.1",
|
||||||
"electron-window-state": "^5.0.3",
|
"electron-window-state": "^5.0.3",
|
||||||
"emittery": "^1.0.3",
|
"emittery": "^1.0.3",
|
||||||
@@ -381,13 +389,11 @@
|
|||||||
"@codemirror/lint": "6.8.5",
|
"@codemirror/lint": "6.8.5",
|
||||||
"@codemirror/view": "6.38.1",
|
"@codemirror/view": "6.38.1",
|
||||||
"@langchain/core@npm:^0.3.26": "patch:@langchain/core@npm%3A1.0.2#~/.yarn/patches/@langchain-core-npm-1.0.2-183ef83fe4.patch",
|
"@langchain/core@npm:^0.3.26": "patch:@langchain/core@npm%3A1.0.2#~/.yarn/patches/@langchain-core-npm-1.0.2-183ef83fe4.patch",
|
||||||
"app-builder-lib@npm:26.0.13": "patch:app-builder-lib@npm%3A26.0.13#~/.yarn/patches/app-builder-lib-npm-26.0.13-a064c9e1d0.patch",
|
|
||||||
"app-builder-lib@npm:26.0.15": "patch:app-builder-lib@npm%3A26.0.15#~/.yarn/patches/app-builder-lib-npm-26.0.15-360e5b0476.patch",
|
|
||||||
"atomically@npm:^1.7.0": "patch:atomically@npm%3A1.7.0#~/.yarn/patches/atomically-npm-1.7.0-e742e5293b.patch",
|
"atomically@npm:^1.7.0": "patch:atomically@npm%3A1.7.0#~/.yarn/patches/atomically-npm-1.7.0-e742e5293b.patch",
|
||||||
"esbuild": "^0.25.0",
|
"esbuild": "^0.25.0",
|
||||||
"file-stream-rotator@npm:^0.6.1": "patch:file-stream-rotator@npm%3A0.6.1#~/.yarn/patches/file-stream-rotator-npm-0.6.1-eab45fb13d.patch",
|
"file-stream-rotator@npm:^0.6.1": "patch:file-stream-rotator@npm%3A0.6.1#~/.yarn/patches/file-stream-rotator-npm-0.6.1-eab45fb13d.patch",
|
||||||
"libsql@npm:^0.4.4": "patch:libsql@npm%3A0.4.7#~/.yarn/patches/libsql-npm-0.4.7-444e260fb1.patch",
|
"libsql@npm:^0.4.4": "patch:libsql@npm%3A0.4.7#~/.yarn/patches/libsql-npm-0.4.7-444e260fb1.patch",
|
||||||
"node-abi": "4.12.0",
|
"node-abi": "4.24.0",
|
||||||
"openai@npm:^4.77.0": "npm:@cherrystudio/openai@6.5.0",
|
"openai@npm:^4.77.0": "npm:@cherrystudio/openai@6.5.0",
|
||||||
"openai@npm:^4.87.3": "npm:@cherrystudio/openai@6.5.0",
|
"openai@npm:^4.87.3": "npm:@cherrystudio/openai@6.5.0",
|
||||||
"pdf-parse@npm:1.1.1": "patch:pdf-parse@npm%3A1.1.1#~/.yarn/patches/pdf-parse-npm-1.1.1-04a6109b2a.patch",
|
"pdf-parse@npm:1.1.1": "patch:pdf-parse@npm%3A1.1.1#~/.yarn/patches/pdf-parse-npm-1.1.1-04a6109b2a.patch",
|
||||||
@@ -408,8 +414,11 @@
|
|||||||
"@langchain/openai@npm:^0.3.16": "patch:@langchain/openai@npm%3A1.0.0#~/.yarn/patches/@langchain-openai-npm-1.0.0-474d0ad9d4.patch",
|
"@langchain/openai@npm:^0.3.16": "patch:@langchain/openai@npm%3A1.0.0#~/.yarn/patches/@langchain-openai-npm-1.0.0-474d0ad9d4.patch",
|
||||||
"@langchain/openai@npm:>=0.2.0 <0.7.0": "patch:@langchain/openai@npm%3A1.0.0#~/.yarn/patches/@langchain-openai-npm-1.0.0-474d0ad9d4.patch",
|
"@langchain/openai@npm:>=0.2.0 <0.7.0": "patch:@langchain/openai@npm%3A1.0.0#~/.yarn/patches/@langchain-openai-npm-1.0.0-474d0ad9d4.patch",
|
||||||
"@ai-sdk/openai@npm:2.0.64": "patch:@ai-sdk/openai@npm%3A2.0.64#~/.yarn/patches/@ai-sdk-openai-npm-2.0.64-48f99f5bf3.patch",
|
"@ai-sdk/openai@npm:2.0.64": "patch:@ai-sdk/openai@npm%3A2.0.64#~/.yarn/patches/@ai-sdk-openai-npm-2.0.64-48f99f5bf3.patch",
|
||||||
"@ai-sdk/openai@npm:^2.0.42": "patch:@ai-sdk/openai@npm%3A2.0.64#~/.yarn/patches/@ai-sdk-openai-npm-2.0.64-48f99f5bf3.patch",
|
"@ai-sdk/openai@npm:^2.0.42": "patch:@ai-sdk/openai@npm%3A2.0.71#~/.yarn/patches/@ai-sdk-openai-npm-2.0.71-a88ef00525.patch",
|
||||||
"@ai-sdk/google@npm:2.0.31": "patch:@ai-sdk/google@npm%3A2.0.31#~/.yarn/patches/@ai-sdk-google-npm-2.0.31-b0de047210.patch"
|
"@ai-sdk/google@npm:2.0.40": "patch:@ai-sdk/google@npm%3A2.0.40#~/.yarn/patches/@ai-sdk-google-npm-2.0.40-47e0eeee83.patch",
|
||||||
|
"@ai-sdk/openai@npm:2.0.71": "patch:@ai-sdk/openai@npm%3A2.0.71#~/.yarn/patches/@ai-sdk-openai-npm-2.0.71-a88ef00525.patch",
|
||||||
|
"@ai-sdk/openai-compatible@npm:1.0.27": "patch:@ai-sdk/openai-compatible@npm%3A1.0.27#~/.yarn/patches/@ai-sdk-openai-compatible-npm-1.0.27-06f74278cf.patch",
|
||||||
|
"@ai-sdk/openai-compatible@npm:^1.0.19": "patch:@ai-sdk/openai-compatible@npm%3A1.0.27#~/.yarn/patches/@ai-sdk-openai-compatible-npm-1.0.27-06f74278cf.patch"
|
||||||
},
|
},
|
||||||
"packageManager": "yarn@4.9.1",
|
"packageManager": "yarn@4.9.1",
|
||||||
"lint-staged": {
|
"lint-staged": {
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
{
|
{
|
||||||
"name": "@cherrystudio/ai-sdk-provider",
|
"name": "@cherrystudio/ai-sdk-provider",
|
||||||
"version": "0.1.0",
|
"version": "0.1.3",
|
||||||
"description": "Cherry Studio AI SDK provider bundle with CherryIN routing.",
|
"description": "Cherry Studio AI SDK provider bundle with CherryIN routing.",
|
||||||
"keywords": [
|
"keywords": [
|
||||||
"ai-sdk",
|
"ai-sdk",
|
||||||
@@ -42,7 +42,7 @@
|
|||||||
},
|
},
|
||||||
"dependencies": {
|
"dependencies": {
|
||||||
"@ai-sdk/provider": "^2.0.0",
|
"@ai-sdk/provider": "^2.0.0",
|
||||||
"@ai-sdk/provider-utils": "^3.0.12"
|
"@ai-sdk/provider-utils": "^3.0.17"
|
||||||
},
|
},
|
||||||
"devDependencies": {
|
"devDependencies": {
|
||||||
"tsdown": "^0.13.3",
|
"tsdown": "^0.13.3",
|
||||||
|
|||||||
@@ -67,6 +67,10 @@ export interface CherryInProviderSettings {
|
|||||||
* Optional static headers applied to every request.
|
* Optional static headers applied to every request.
|
||||||
*/
|
*/
|
||||||
headers?: HeadersInput
|
headers?: HeadersInput
|
||||||
|
/**
|
||||||
|
* Optional endpoint type to distinguish different endpoint behaviors.
|
||||||
|
*/
|
||||||
|
endpointType?: 'openai' | 'openai-response' | 'anthropic' | 'gemini' | 'image-generation' | 'jina-rerank'
|
||||||
}
|
}
|
||||||
|
|
||||||
export interface CherryInProvider extends ProviderV2 {
|
export interface CherryInProvider extends ProviderV2 {
|
||||||
@@ -151,7 +155,8 @@ export const createCherryIn = (options: CherryInProviderSettings = {}): CherryIn
|
|||||||
baseURL = DEFAULT_CHERRYIN_BASE_URL,
|
baseURL = DEFAULT_CHERRYIN_BASE_URL,
|
||||||
anthropicBaseURL = DEFAULT_CHERRYIN_ANTHROPIC_BASE_URL,
|
anthropicBaseURL = DEFAULT_CHERRYIN_ANTHROPIC_BASE_URL,
|
||||||
geminiBaseURL = DEFAULT_CHERRYIN_GEMINI_BASE_URL,
|
geminiBaseURL = DEFAULT_CHERRYIN_GEMINI_BASE_URL,
|
||||||
fetch
|
fetch,
|
||||||
|
endpointType
|
||||||
} = options
|
} = options
|
||||||
|
|
||||||
const getJsonHeaders = createJsonHeadersGetter(options)
|
const getJsonHeaders = createJsonHeadersGetter(options)
|
||||||
@@ -205,7 +210,7 @@ export const createCherryIn = (options: CherryInProviderSettings = {}): CherryIn
|
|||||||
fetch
|
fetch
|
||||||
})
|
})
|
||||||
|
|
||||||
const createChatModel = (modelId: string, settings: OpenAIProviderSettings = {}) => {
|
const createChatModelByModelId = (modelId: string, settings: OpenAIProviderSettings = {}) => {
|
||||||
if (isAnthropicModel(modelId)) {
|
if (isAnthropicModel(modelId)) {
|
||||||
return createAnthropicModel(modelId)
|
return createAnthropicModel(modelId)
|
||||||
}
|
}
|
||||||
@@ -223,6 +228,29 @@ export const createCherryIn = (options: CherryInProviderSettings = {}): CherryIn
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const createChatModel = (modelId: string, settings: OpenAIProviderSettings = {}) => {
|
||||||
|
if (!endpointType) return createChatModelByModelId(modelId, settings)
|
||||||
|
switch (endpointType) {
|
||||||
|
case 'anthropic':
|
||||||
|
return createAnthropicModel(modelId)
|
||||||
|
case 'gemini':
|
||||||
|
return createGeminiModel(modelId)
|
||||||
|
case 'openai':
|
||||||
|
return createOpenAIChatModel(modelId)
|
||||||
|
case 'openai-response':
|
||||||
|
default:
|
||||||
|
return new OpenAIResponsesLanguageModel(modelId, {
|
||||||
|
provider: `${CHERRYIN_PROVIDER_NAME}.openai`,
|
||||||
|
url,
|
||||||
|
headers: () => ({
|
||||||
|
...getJsonHeaders(),
|
||||||
|
...settings.headers
|
||||||
|
}),
|
||||||
|
fetch
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
const createCompletionModel = (modelId: string, settings: OpenAIProviderSettings = {}) =>
|
const createCompletionModel = (modelId: string, settings: OpenAIProviderSettings = {}) =>
|
||||||
new OpenAICompletionLanguageModel(modelId, {
|
new OpenAICompletionLanguageModel(modelId, {
|
||||||
provider: `${CHERRYIN_PROVIDER_NAME}.completion`,
|
provider: `${CHERRYIN_PROVIDER_NAME}.completion`,
|
||||||
|
|||||||
@@ -71,7 +71,7 @@ Cherry Studio AI Core 是一个基于 Vercel AI SDK 的统一 AI Provider 接口
|
|||||||
## 安装
|
## 安装
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
npm install @cherrystudio/ai-core ai
|
npm install @cherrystudio/ai-core ai @ai-sdk/google @ai-sdk/openai
|
||||||
```
|
```
|
||||||
|
|
||||||
### React Native
|
### React Native
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
{
|
{
|
||||||
"name": "@cherrystudio/ai-core",
|
"name": "@cherrystudio/ai-core",
|
||||||
"version": "1.0.1",
|
"version": "1.0.9",
|
||||||
"description": "Cherry Studio AI Core - Unified AI Provider Interface Based on Vercel AI SDK",
|
"description": "Cherry Studio AI Core - Unified AI Provider Interface Based on Vercel AI SDK",
|
||||||
"main": "dist/index.js",
|
"main": "dist/index.js",
|
||||||
"module": "dist/index.mjs",
|
"module": "dist/index.mjs",
|
||||||
@@ -33,19 +33,19 @@
|
|||||||
},
|
},
|
||||||
"homepage": "https://github.com/CherryHQ/cherry-studio#readme",
|
"homepage": "https://github.com/CherryHQ/cherry-studio#readme",
|
||||||
"peerDependencies": {
|
"peerDependencies": {
|
||||||
|
"@ai-sdk/google": "^2.0.36",
|
||||||
|
"@ai-sdk/openai": "^2.0.64",
|
||||||
|
"@cherrystudio/ai-sdk-provider": "^0.1.3",
|
||||||
"ai": "^5.0.26"
|
"ai": "^5.0.26"
|
||||||
},
|
},
|
||||||
"dependencies": {
|
"dependencies": {
|
||||||
"@ai-sdk/anthropic": "^2.0.43",
|
"@ai-sdk/anthropic": "^2.0.45",
|
||||||
"@ai-sdk/azure": "^2.0.66",
|
"@ai-sdk/azure": "^2.0.73",
|
||||||
"@ai-sdk/deepseek": "^1.0.27",
|
"@ai-sdk/deepseek": "^1.0.29",
|
||||||
"@ai-sdk/google": "patch:@ai-sdk/google@npm%3A2.0.31#~/.yarn/patches/@ai-sdk-google-npm-2.0.31-b0de047210.patch",
|
"@ai-sdk/openai-compatible": "patch:@ai-sdk/openai-compatible@npm%3A1.0.27#~/.yarn/patches/@ai-sdk-openai-compatible-npm-1.0.27-06f74278cf.patch",
|
||||||
"@ai-sdk/openai": "patch:@ai-sdk/openai@npm%3A2.0.64#~/.yarn/patches/@ai-sdk-openai-npm-2.0.64-48f99f5bf3.patch",
|
|
||||||
"@ai-sdk/openai-compatible": "^1.0.26",
|
|
||||||
"@ai-sdk/provider": "^2.0.0",
|
"@ai-sdk/provider": "^2.0.0",
|
||||||
"@ai-sdk/provider-utils": "^3.0.16",
|
"@ai-sdk/provider-utils": "^3.0.17",
|
||||||
"@ai-sdk/xai": "^2.0.31",
|
"@ai-sdk/xai": "^2.0.34",
|
||||||
"@cherrystudio/ai-sdk-provider": "workspace:*",
|
|
||||||
"zod": "^4.1.5"
|
"zod": "^4.1.5"
|
||||||
},
|
},
|
||||||
"devDependencies": {
|
"devDependencies": {
|
||||||
|
|||||||
180
packages/aiCore/src/__tests__/fixtures/mock-providers.ts
Normal file
180
packages/aiCore/src/__tests__/fixtures/mock-providers.ts
Normal file
@@ -0,0 +1,180 @@
|
|||||||
|
/**
|
||||||
|
* Mock Provider Instances
|
||||||
|
* Provides mock implementations for all supported AI providers
|
||||||
|
*/
|
||||||
|
|
||||||
|
import type { ImageModelV2, LanguageModelV2 } from '@ai-sdk/provider'
|
||||||
|
import { vi } from 'vitest'
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Creates a mock language model with customizable behavior
|
||||||
|
*/
|
||||||
|
export function createMockLanguageModel(overrides?: Partial<LanguageModelV2>): LanguageModelV2 {
|
||||||
|
return {
|
||||||
|
specificationVersion: 'v1',
|
||||||
|
provider: 'mock-provider',
|
||||||
|
modelId: 'mock-model',
|
||||||
|
defaultObjectGenerationMode: 'tool',
|
||||||
|
|
||||||
|
doGenerate: vi.fn().mockResolvedValue({
|
||||||
|
text: 'Mock response text',
|
||||||
|
finishReason: 'stop',
|
||||||
|
usage: {
|
||||||
|
promptTokens: 10,
|
||||||
|
completionTokens: 20,
|
||||||
|
totalTokens: 30
|
||||||
|
},
|
||||||
|
rawCall: { rawPrompt: null, rawSettings: {} },
|
||||||
|
rawResponse: { headers: {} },
|
||||||
|
warnings: []
|
||||||
|
}),
|
||||||
|
|
||||||
|
doStream: vi.fn().mockReturnValue({
|
||||||
|
stream: (async function* () {
|
||||||
|
yield {
|
||||||
|
type: 'text-delta',
|
||||||
|
textDelta: 'Mock '
|
||||||
|
}
|
||||||
|
yield {
|
||||||
|
type: 'text-delta',
|
||||||
|
textDelta: 'streaming '
|
||||||
|
}
|
||||||
|
yield {
|
||||||
|
type: 'text-delta',
|
||||||
|
textDelta: 'response'
|
||||||
|
}
|
||||||
|
yield {
|
||||||
|
type: 'finish',
|
||||||
|
finishReason: 'stop',
|
||||||
|
usage: {
|
||||||
|
promptTokens: 10,
|
||||||
|
completionTokens: 15,
|
||||||
|
totalTokens: 25
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})(),
|
||||||
|
rawCall: { rawPrompt: null, rawSettings: {} },
|
||||||
|
rawResponse: { headers: {} },
|
||||||
|
warnings: []
|
||||||
|
}),
|
||||||
|
|
||||||
|
...overrides
|
||||||
|
} as LanguageModelV2
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Creates a mock image model with customizable behavior
|
||||||
|
*/
|
||||||
|
export function createMockImageModel(overrides?: Partial<ImageModelV2>): ImageModelV2 {
|
||||||
|
return {
|
||||||
|
specificationVersion: 'v2',
|
||||||
|
provider: 'mock-provider',
|
||||||
|
modelId: 'mock-image-model',
|
||||||
|
|
||||||
|
doGenerate: vi.fn().mockResolvedValue({
|
||||||
|
images: [
|
||||||
|
{
|
||||||
|
base64: 'mock-base64-image-data',
|
||||||
|
uint8Array: new Uint8Array([1, 2, 3, 4, 5]),
|
||||||
|
mimeType: 'image/png'
|
||||||
|
}
|
||||||
|
],
|
||||||
|
warnings: []
|
||||||
|
}),
|
||||||
|
|
||||||
|
...overrides
|
||||||
|
} as ImageModelV2
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Mock provider configurations for testing
|
||||||
|
*/
|
||||||
|
export const mockProviderConfigs = {
|
||||||
|
openai: {
|
||||||
|
apiKey: 'sk-test-openai-key-123456789',
|
||||||
|
baseURL: 'https://api.openai.com/v1',
|
||||||
|
organization: 'test-org'
|
||||||
|
},
|
||||||
|
|
||||||
|
anthropic: {
|
||||||
|
apiKey: 'sk-ant-test-key-123456789',
|
||||||
|
baseURL: 'https://api.anthropic.com'
|
||||||
|
},
|
||||||
|
|
||||||
|
google: {
|
||||||
|
apiKey: 'test-google-api-key-123456789',
|
||||||
|
baseURL: 'https://generativelanguage.googleapis.com/v1'
|
||||||
|
},
|
||||||
|
|
||||||
|
xai: {
|
||||||
|
apiKey: 'xai-test-key-123456789',
|
||||||
|
baseURL: 'https://api.x.ai/v1'
|
||||||
|
},
|
||||||
|
|
||||||
|
azure: {
|
||||||
|
apiKey: 'test-azure-key-123456789',
|
||||||
|
resourceName: 'test-resource',
|
||||||
|
deployment: 'test-deployment'
|
||||||
|
},
|
||||||
|
|
||||||
|
deepseek: {
|
||||||
|
apiKey: 'sk-test-deepseek-key-123456789',
|
||||||
|
baseURL: 'https://api.deepseek.com/v1'
|
||||||
|
},
|
||||||
|
|
||||||
|
openrouter: {
|
||||||
|
apiKey: 'sk-or-test-key-123456789',
|
||||||
|
baseURL: 'https://openrouter.ai/api/v1'
|
||||||
|
},
|
||||||
|
|
||||||
|
huggingface: {
|
||||||
|
apiKey: 'hf_test_key_123456789',
|
||||||
|
baseURL: 'https://api-inference.huggingface.co'
|
||||||
|
},
|
||||||
|
|
||||||
|
'openai-compatible': {
|
||||||
|
apiKey: 'test-compatible-key-123456789',
|
||||||
|
baseURL: 'https://api.example.com/v1',
|
||||||
|
name: 'test-provider'
|
||||||
|
},
|
||||||
|
|
||||||
|
'openai-chat': {
|
||||||
|
apiKey: 'sk-test-chat-key-123456789',
|
||||||
|
baseURL: 'https://api.openai.com/v1'
|
||||||
|
}
|
||||||
|
} as const
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Mock provider instances for testing
|
||||||
|
*/
|
||||||
|
export const mockProviderInstances = {
|
||||||
|
openai: {
|
||||||
|
name: 'openai-mock',
|
||||||
|
languageModel: createMockLanguageModel({ provider: 'openai', modelId: 'gpt-4' }),
|
||||||
|
imageModel: createMockImageModel({ provider: 'openai', modelId: 'dall-e-3' })
|
||||||
|
},
|
||||||
|
|
||||||
|
anthropic: {
|
||||||
|
name: 'anthropic-mock',
|
||||||
|
languageModel: createMockLanguageModel({ provider: 'anthropic', modelId: 'claude-3-5-sonnet-20241022' })
|
||||||
|
},
|
||||||
|
|
||||||
|
google: {
|
||||||
|
name: 'google-mock',
|
||||||
|
languageModel: createMockLanguageModel({ provider: 'google', modelId: 'gemini-2.0-flash-exp' }),
|
||||||
|
imageModel: createMockImageModel({ provider: 'google', modelId: 'imagen-3.0-generate-001' })
|
||||||
|
},
|
||||||
|
|
||||||
|
xai: {
|
||||||
|
name: 'xai-mock',
|
||||||
|
languageModel: createMockLanguageModel({ provider: 'xai', modelId: 'grok-2-latest' }),
|
||||||
|
imageModel: createMockImageModel({ provider: 'xai', modelId: 'grok-2-image-latest' })
|
||||||
|
},
|
||||||
|
|
||||||
|
deepseek: {
|
||||||
|
name: 'deepseek-mock',
|
||||||
|
languageModel: createMockLanguageModel({ provider: 'deepseek', modelId: 'deepseek-chat' })
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
export type ProviderId = keyof typeof mockProviderConfigs
|
||||||
331
packages/aiCore/src/__tests__/fixtures/mock-responses.ts
Normal file
331
packages/aiCore/src/__tests__/fixtures/mock-responses.ts
Normal file
@@ -0,0 +1,331 @@
|
|||||||
|
/**
|
||||||
|
* Mock Responses
|
||||||
|
* Provides realistic mock responses for all provider types
|
||||||
|
*/
|
||||||
|
|
||||||
|
import { jsonSchema, type ModelMessage, type Tool } from 'ai'
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Standard test messages for all scenarios
|
||||||
|
*/
|
||||||
|
export const testMessages = {
|
||||||
|
simple: [{ role: 'user' as const, content: 'Hello, how are you?' }],
|
||||||
|
|
||||||
|
conversation: [
|
||||||
|
{ role: 'user' as const, content: 'What is the capital of France?' },
|
||||||
|
{ role: 'assistant' as const, content: 'The capital of France is Paris.' },
|
||||||
|
{ role: 'user' as const, content: 'What is its population?' }
|
||||||
|
],
|
||||||
|
|
||||||
|
withSystem: [
|
||||||
|
{ role: 'system' as const, content: 'You are a helpful assistant that provides concise answers.' },
|
||||||
|
{ role: 'user' as const, content: 'Explain quantum computing in one sentence.' }
|
||||||
|
],
|
||||||
|
|
||||||
|
withImages: [
|
||||||
|
{
|
||||||
|
role: 'user' as const,
|
||||||
|
content: [
|
||||||
|
{ type: 'text' as const, text: 'What is in this image?' },
|
||||||
|
{
|
||||||
|
type: 'image' as const,
|
||||||
|
image:
|
||||||
|
''
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
|
||||||
|
toolUse: [{ role: 'user' as const, content: 'What is the weather in San Francisco?' }],
|
||||||
|
|
||||||
|
multiTurn: [
|
||||||
|
{ role: 'user' as const, content: 'Can you help me with a math problem?' },
|
||||||
|
{ role: 'assistant' as const, content: 'Of course! What math problem would you like help with?' },
|
||||||
|
{ role: 'user' as const, content: 'What is 15 * 23?' },
|
||||||
|
{ role: 'assistant' as const, content: '15 * 23 = 345' },
|
||||||
|
{ role: 'user' as const, content: 'Now divide that by 5' }
|
||||||
|
]
|
||||||
|
} satisfies Record<string, ModelMessage[]>
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Standard test tools for tool calling scenarios
|
||||||
|
*/
|
||||||
|
export const testTools: Record<string, Tool> = {
|
||||||
|
getWeather: {
|
||||||
|
description: 'Get the current weather in a given location',
|
||||||
|
inputSchema: jsonSchema({
|
||||||
|
type: 'object',
|
||||||
|
properties: {
|
||||||
|
location: {
|
||||||
|
type: 'string',
|
||||||
|
description: 'The city and state, e.g. San Francisco, CA'
|
||||||
|
},
|
||||||
|
unit: {
|
||||||
|
type: 'string',
|
||||||
|
enum: ['celsius', 'fahrenheit'],
|
||||||
|
description: 'The temperature unit to use'
|
||||||
|
}
|
||||||
|
},
|
||||||
|
required: ['location']
|
||||||
|
}),
|
||||||
|
execute: async ({ location, unit = 'fahrenheit' }) => {
|
||||||
|
return {
|
||||||
|
location,
|
||||||
|
temperature: unit === 'celsius' ? 22 : 72,
|
||||||
|
unit,
|
||||||
|
condition: 'sunny'
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
|
||||||
|
calculate: {
|
||||||
|
description: 'Perform a mathematical calculation',
|
||||||
|
inputSchema: jsonSchema({
|
||||||
|
type: 'object',
|
||||||
|
properties: {
|
||||||
|
operation: {
|
||||||
|
type: 'string',
|
||||||
|
enum: ['add', 'subtract', 'multiply', 'divide'],
|
||||||
|
description: 'The operation to perform'
|
||||||
|
},
|
||||||
|
a: {
|
||||||
|
type: 'number',
|
||||||
|
description: 'The first number'
|
||||||
|
},
|
||||||
|
b: {
|
||||||
|
type: 'number',
|
||||||
|
description: 'The second number'
|
||||||
|
}
|
||||||
|
},
|
||||||
|
required: ['operation', 'a', 'b']
|
||||||
|
}),
|
||||||
|
execute: async ({ operation, a, b }) => {
|
||||||
|
const operations = {
|
||||||
|
add: (x: number, y: number) => x + y,
|
||||||
|
subtract: (x: number, y: number) => x - y,
|
||||||
|
multiply: (x: number, y: number) => x * y,
|
||||||
|
divide: (x: number, y: number) => x / y
|
||||||
|
}
|
||||||
|
return { result: operations[operation as keyof typeof operations](a, b) }
|
||||||
|
}
|
||||||
|
},
|
||||||
|
|
||||||
|
searchDatabase: {
|
||||||
|
description: 'Search for information in a database',
|
||||||
|
inputSchema: jsonSchema({
|
||||||
|
type: 'object',
|
||||||
|
properties: {
|
||||||
|
query: {
|
||||||
|
type: 'string',
|
||||||
|
description: 'The search query'
|
||||||
|
},
|
||||||
|
limit: {
|
||||||
|
type: 'number',
|
||||||
|
description: 'Maximum number of results to return',
|
||||||
|
default: 10
|
||||||
|
}
|
||||||
|
},
|
||||||
|
required: ['query']
|
||||||
|
}),
|
||||||
|
execute: async ({ query, limit = 10 }) => {
|
||||||
|
return {
|
||||||
|
results: [
|
||||||
|
{ id: 1, title: `Result 1 for ${query}`, relevance: 0.95 },
|
||||||
|
{ id: 2, title: `Result 2 for ${query}`, relevance: 0.87 }
|
||||||
|
].slice(0, limit)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Mock streaming chunks for different providers
|
||||||
|
*/
|
||||||
|
export const mockStreamingChunks = {
|
||||||
|
text: [
|
||||||
|
{ type: 'text-delta' as const, textDelta: 'Hello' },
|
||||||
|
{ type: 'text-delta' as const, textDelta: ', ' },
|
||||||
|
{ type: 'text-delta' as const, textDelta: 'this ' },
|
||||||
|
{ type: 'text-delta' as const, textDelta: 'is ' },
|
||||||
|
{ type: 'text-delta' as const, textDelta: 'a ' },
|
||||||
|
{ type: 'text-delta' as const, textDelta: 'test.' }
|
||||||
|
],
|
||||||
|
|
||||||
|
withToolCall: [
|
||||||
|
{ type: 'text-delta' as const, textDelta: 'Let me check the weather for you.' },
|
||||||
|
{
|
||||||
|
type: 'tool-call-delta' as const,
|
||||||
|
toolCallType: 'function' as const,
|
||||||
|
toolCallId: 'call_123',
|
||||||
|
toolName: 'getWeather',
|
||||||
|
argsTextDelta: '{"location":'
|
||||||
|
},
|
||||||
|
{
|
||||||
|
type: 'tool-call-delta' as const,
|
||||||
|
toolCallType: 'function' as const,
|
||||||
|
toolCallId: 'call_123',
|
||||||
|
toolName: 'getWeather',
|
||||||
|
argsTextDelta: ' "San Francisco, CA"}'
|
||||||
|
},
|
||||||
|
{
|
||||||
|
type: 'tool-call' as const,
|
||||||
|
toolCallType: 'function' as const,
|
||||||
|
toolCallId: 'call_123',
|
||||||
|
toolName: 'getWeather',
|
||||||
|
args: { location: 'San Francisco, CA' }
|
||||||
|
}
|
||||||
|
],
|
||||||
|
|
||||||
|
withFinish: [
|
||||||
|
{ type: 'text-delta' as const, textDelta: 'Complete response.' },
|
||||||
|
{
|
||||||
|
type: 'finish' as const,
|
||||||
|
finishReason: 'stop' as const,
|
||||||
|
usage: {
|
||||||
|
promptTokens: 10,
|
||||||
|
completionTokens: 5,
|
||||||
|
totalTokens: 15
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Mock complete responses for non-streaming scenarios
|
||||||
|
*/
|
||||||
|
export const mockCompleteResponses = {
|
||||||
|
simple: {
|
||||||
|
text: 'This is a simple response.',
|
||||||
|
finishReason: 'stop' as const,
|
||||||
|
usage: {
|
||||||
|
promptTokens: 15,
|
||||||
|
completionTokens: 8,
|
||||||
|
totalTokens: 23
|
||||||
|
}
|
||||||
|
},
|
||||||
|
|
||||||
|
withToolCalls: {
|
||||||
|
text: 'I will check the weather for you.',
|
||||||
|
toolCalls: [
|
||||||
|
{
|
||||||
|
toolCallId: 'call_456',
|
||||||
|
toolName: 'getWeather',
|
||||||
|
args: { location: 'New York, NY', unit: 'celsius' }
|
||||||
|
}
|
||||||
|
],
|
||||||
|
finishReason: 'tool-calls' as const,
|
||||||
|
usage: {
|
||||||
|
promptTokens: 25,
|
||||||
|
completionTokens: 12,
|
||||||
|
totalTokens: 37
|
||||||
|
}
|
||||||
|
},
|
||||||
|
|
||||||
|
withWarnings: {
|
||||||
|
text: 'Response with warnings.',
|
||||||
|
finishReason: 'stop' as const,
|
||||||
|
usage: {
|
||||||
|
promptTokens: 10,
|
||||||
|
completionTokens: 5,
|
||||||
|
totalTokens: 15
|
||||||
|
},
|
||||||
|
warnings: [
|
||||||
|
{
|
||||||
|
type: 'unsupported-setting' as const,
|
||||||
|
message: 'Temperature parameter not supported for this model'
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Mock image generation responses
|
||||||
|
*/
|
||||||
|
export const mockImageResponses = {
|
||||||
|
single: {
|
||||||
|
image: {
|
||||||
|
base64: 'iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==',
|
||||||
|
uint8Array: new Uint8Array([137, 80, 78, 71, 13, 10, 26, 10, 0, 0, 0, 13, 73, 72, 68, 82]),
|
||||||
|
mimeType: 'image/png' as const
|
||||||
|
},
|
||||||
|
warnings: []
|
||||||
|
},
|
||||||
|
|
||||||
|
multiple: {
|
||||||
|
images: [
|
||||||
|
{
|
||||||
|
base64: 'iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==',
|
||||||
|
uint8Array: new Uint8Array([137, 80, 78, 71]),
|
||||||
|
mimeType: 'image/png' as const
|
||||||
|
},
|
||||||
|
{
|
||||||
|
base64: 'iVBORw0KGgoAAAANSUhEUgAAAAIAAAACCAYAAABytg0kAAAAEklEQVR42mNk+M9QzwAEjDAGACCKAgdZ9zImAAAAAElFTkSuQmCC',
|
||||||
|
uint8Array: new Uint8Array([137, 80, 78, 71]),
|
||||||
|
mimeType: 'image/png' as const
|
||||||
|
}
|
||||||
|
],
|
||||||
|
warnings: []
|
||||||
|
},
|
||||||
|
|
||||||
|
withProviderMetadata: {
|
||||||
|
image: {
|
||||||
|
base64: 'iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==',
|
||||||
|
uint8Array: new Uint8Array([137, 80, 78, 71]),
|
||||||
|
mimeType: 'image/png' as const
|
||||||
|
},
|
||||||
|
providerMetadata: {
|
||||||
|
openai: {
|
||||||
|
images: [
|
||||||
|
{
|
||||||
|
revisedPrompt: 'A detailed and enhanced version of the original prompt'
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
},
|
||||||
|
warnings: []
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Mock error responses
|
||||||
|
*/
|
||||||
|
export const mockErrors = {
|
||||||
|
invalidApiKey: {
|
||||||
|
name: 'APIError',
|
||||||
|
message: 'Invalid API key provided',
|
||||||
|
statusCode: 401
|
||||||
|
},
|
||||||
|
|
||||||
|
rateLimitExceeded: {
|
||||||
|
name: 'RateLimitError',
|
||||||
|
message: 'Rate limit exceeded. Please try again later.',
|
||||||
|
statusCode: 429,
|
||||||
|
headers: {
|
||||||
|
'retry-after': '60'
|
||||||
|
}
|
||||||
|
},
|
||||||
|
|
||||||
|
modelNotFound: {
|
||||||
|
name: 'ModelNotFoundError',
|
||||||
|
message: 'The requested model was not found',
|
||||||
|
statusCode: 404
|
||||||
|
},
|
||||||
|
|
||||||
|
contextLengthExceeded: {
|
||||||
|
name: 'ContextLengthError',
|
||||||
|
message: "This model's maximum context length is 4096 tokens",
|
||||||
|
statusCode: 400
|
||||||
|
},
|
||||||
|
|
||||||
|
timeout: {
|
||||||
|
name: 'TimeoutError',
|
||||||
|
message: 'Request timed out after 30000ms',
|
||||||
|
code: 'ETIMEDOUT'
|
||||||
|
},
|
||||||
|
|
||||||
|
networkError: {
|
||||||
|
name: 'NetworkError',
|
||||||
|
message: 'Network connection failed',
|
||||||
|
code: 'ECONNREFUSED'
|
||||||
|
}
|
||||||
|
}
|
||||||
329
packages/aiCore/src/__tests__/helpers/provider-test-utils.ts
Normal file
329
packages/aiCore/src/__tests__/helpers/provider-test-utils.ts
Normal file
@@ -0,0 +1,329 @@
|
|||||||
|
/**
|
||||||
|
* Provider-Specific Test Utilities
|
||||||
|
* Helper functions for testing individual providers with all their parameters
|
||||||
|
*/
|
||||||
|
|
||||||
|
import type { Tool } from 'ai'
|
||||||
|
import { expect } from 'vitest'
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Provider parameter configurations for comprehensive testing
|
||||||
|
*/
|
||||||
|
export const providerParameterMatrix = {
|
||||||
|
openai: {
|
||||||
|
models: ['gpt-4', 'gpt-4-turbo', 'gpt-3.5-turbo', 'gpt-4o'],
|
||||||
|
parameters: {
|
||||||
|
temperature: [0, 0.5, 0.7, 1.0, 1.5, 2.0],
|
||||||
|
maxTokens: [100, 500, 1000, 2000, 4000],
|
||||||
|
topP: [0.1, 0.5, 0.9, 1.0],
|
||||||
|
frequencyPenalty: [-2.0, -1.0, 0, 1.0, 2.0],
|
||||||
|
presencePenalty: [-2.0, -1.0, 0, 1.0, 2.0],
|
||||||
|
stop: [undefined, ['stop'], ['STOP', 'END']],
|
||||||
|
seed: [undefined, 12345, 67890],
|
||||||
|
responseFormat: [undefined, { type: 'json_object' as const }],
|
||||||
|
user: [undefined, 'test-user-123']
|
||||||
|
},
|
||||||
|
toolChoice: ['auto', 'required', 'none', { type: 'function' as const, name: 'getWeather' }],
|
||||||
|
parallelToolCalls: [true, false]
|
||||||
|
},
|
||||||
|
|
||||||
|
anthropic: {
|
||||||
|
models: ['claude-3-5-sonnet-20241022', 'claude-3-opus-20240229', 'claude-3-haiku-20240307'],
|
||||||
|
parameters: {
|
||||||
|
temperature: [0, 0.5, 1.0],
|
||||||
|
maxTokens: [100, 1000, 4000, 8000],
|
||||||
|
topP: [0.1, 0.5, 0.9, 1.0],
|
||||||
|
topK: [undefined, 1, 5, 10, 40],
|
||||||
|
stop: [undefined, ['Human:', 'Assistant:']],
|
||||||
|
metadata: [undefined, { userId: 'test-123' }]
|
||||||
|
},
|
||||||
|
toolChoice: ['auto', 'any', { type: 'tool' as const, name: 'getWeather' }]
|
||||||
|
},
|
||||||
|
|
||||||
|
google: {
|
||||||
|
models: ['gemini-2.0-flash-exp', 'gemini-1.5-pro', 'gemini-1.5-flash'],
|
||||||
|
parameters: {
|
||||||
|
temperature: [0, 0.5, 0.9, 1.0],
|
||||||
|
maxTokens: [100, 1000, 2000, 8000],
|
||||||
|
topP: [0.1, 0.5, 0.95, 1.0],
|
||||||
|
topK: [undefined, 1, 16, 40],
|
||||||
|
stopSequences: [undefined, ['END'], ['STOP', 'TERMINATE']]
|
||||||
|
},
|
||||||
|
safetySettings: [
|
||||||
|
undefined,
|
||||||
|
[
|
||||||
|
{ category: 'HARM_CATEGORY_HARASSMENT', threshold: 'BLOCK_MEDIUM_AND_ABOVE' },
|
||||||
|
{ category: 'HARM_CATEGORY_HATE_SPEECH', threshold: 'BLOCK_ONLY_HIGH' }
|
||||||
|
]
|
||||||
|
]
|
||||||
|
},
|
||||||
|
|
||||||
|
xai: {
|
||||||
|
models: ['grok-2-latest', 'grok-2-1212'],
|
||||||
|
parameters: {
|
||||||
|
temperature: [0, 0.5, 1.0, 1.5],
|
||||||
|
maxTokens: [100, 500, 2000, 4000],
|
||||||
|
topP: [0.1, 0.5, 0.9, 1.0],
|
||||||
|
stop: [undefined, ['STOP'], ['END', 'TERMINATE']],
|
||||||
|
seed: [undefined, 12345]
|
||||||
|
}
|
||||||
|
},
|
||||||
|
|
||||||
|
deepseek: {
|
||||||
|
models: ['deepseek-chat', 'deepseek-coder'],
|
||||||
|
parameters: {
|
||||||
|
temperature: [0, 0.5, 1.0],
|
||||||
|
maxTokens: [100, 1000, 4000],
|
||||||
|
topP: [0.1, 0.5, 0.95],
|
||||||
|
frequencyPenalty: [0, 0.5, 1.0],
|
||||||
|
presencePenalty: [0, 0.5, 1.0],
|
||||||
|
stop: [undefined, ['```'], ['END']]
|
||||||
|
}
|
||||||
|
},
|
||||||
|
|
||||||
|
azure: {
|
||||||
|
deployments: ['gpt-4-deployment', 'gpt-35-turbo-deployment'],
|
||||||
|
parameters: {
|
||||||
|
temperature: [0, 0.7, 1.0],
|
||||||
|
maxTokens: [100, 1000, 2000],
|
||||||
|
topP: [0.1, 0.5, 0.95],
|
||||||
|
frequencyPenalty: [0, 1.0],
|
||||||
|
presencePenalty: [0, 1.0],
|
||||||
|
stop: [undefined, ['STOP']]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} as const
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Creates test cases for all parameter combinations
|
||||||
|
*/
|
||||||
|
export function generateParameterTestCases<T extends Record<string, any[]>>(
|
||||||
|
params: T,
|
||||||
|
maxCombinations = 50
|
||||||
|
): Array<Partial<{ [K in keyof T]: T[K][number] }>> {
|
||||||
|
const keys = Object.keys(params) as Array<keyof T>
|
||||||
|
const testCases: Array<Partial<{ [K in keyof T]: T[K][number] }>> = []
|
||||||
|
|
||||||
|
// Generate combinations using sampling strategy for large parameter spaces
|
||||||
|
const totalCombinations = keys.reduce((acc, key) => acc * params[key].length, 1)
|
||||||
|
|
||||||
|
if (totalCombinations <= maxCombinations) {
|
||||||
|
// Generate all combinations if total is small
|
||||||
|
generateAllCombinations(params, keys, 0, {}, testCases)
|
||||||
|
} else {
|
||||||
|
// Sample diverse combinations if total is large
|
||||||
|
generateSampledCombinations(params, keys, maxCombinations, testCases)
|
||||||
|
}
|
||||||
|
|
||||||
|
return testCases
|
||||||
|
}
|
||||||
|
|
||||||
|
function generateAllCombinations<T extends Record<string, any[]>>(
|
||||||
|
params: T,
|
||||||
|
keys: Array<keyof T>,
|
||||||
|
index: number,
|
||||||
|
current: Partial<{ [K in keyof T]: T[K][number] }>,
|
||||||
|
results: Array<Partial<{ [K in keyof T]: T[K][number] }>>
|
||||||
|
) {
|
||||||
|
if (index === keys.length) {
|
||||||
|
results.push({ ...current })
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
const key = keys[index]
|
||||||
|
for (const value of params[key]) {
|
||||||
|
generateAllCombinations(params, keys, index + 1, { ...current, [key]: value }, results)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
function generateSampledCombinations<T extends Record<string, any[]>>(
|
||||||
|
params: T,
|
||||||
|
keys: Array<keyof T>,
|
||||||
|
count: number,
|
||||||
|
results: Array<Partial<{ [K in keyof T]: T[K][number] }>>
|
||||||
|
) {
|
||||||
|
// Generate edge cases first (min/max values)
|
||||||
|
const edgeCase1: any = {}
|
||||||
|
const edgeCase2: any = {}
|
||||||
|
|
||||||
|
for (const key of keys) {
|
||||||
|
edgeCase1[key] = params[key][0]
|
||||||
|
edgeCase2[key] = params[key][params[key].length - 1]
|
||||||
|
}
|
||||||
|
|
||||||
|
results.push(edgeCase1, edgeCase2)
|
||||||
|
|
||||||
|
// Generate random combinations for the rest
|
||||||
|
for (let i = results.length; i < count; i++) {
|
||||||
|
const combination: any = {}
|
||||||
|
for (const key of keys) {
|
||||||
|
const values = params[key]
|
||||||
|
combination[key] = values[Math.floor(Math.random() * values.length)]
|
||||||
|
}
|
||||||
|
results.push(combination)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Validates that all provider-specific parameters are correctly passed through
|
||||||
|
*/
|
||||||
|
export function validateProviderParams(providerId: string, actualParams: any, expectedParams: any): void {
|
||||||
|
const requiredFields: Record<string, string[]> = {
|
||||||
|
openai: ['model', 'messages'],
|
||||||
|
anthropic: ['model', 'messages'],
|
||||||
|
google: ['model', 'contents'],
|
||||||
|
xai: ['model', 'messages'],
|
||||||
|
deepseek: ['model', 'messages'],
|
||||||
|
azure: ['messages']
|
||||||
|
}
|
||||||
|
|
||||||
|
const fields = requiredFields[providerId] || ['model', 'messages']
|
||||||
|
|
||||||
|
for (const field of fields) {
|
||||||
|
expect(actualParams).toHaveProperty(field)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate optional parameters if they were provided
|
||||||
|
const optionalParams = ['temperature', 'max_tokens', 'top_p', 'stop', 'tools']
|
||||||
|
|
||||||
|
for (const param of optionalParams) {
|
||||||
|
if (expectedParams[param] !== undefined) {
|
||||||
|
expect(actualParams[param]).toEqual(expectedParams[param])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Creates a comprehensive test suite for a provider
|
||||||
|
*/
|
||||||
|
// oxlint-disable-next-line no-unused-vars
|
||||||
|
export function createProviderTestSuite(_providerId: string) {
|
||||||
|
return {
|
||||||
|
testBasicCompletion: async (executor: any, model: string) => {
|
||||||
|
const result = await executor.generateText({
|
||||||
|
model,
|
||||||
|
messages: [{ role: 'user' as const, content: 'Hello' }]
|
||||||
|
})
|
||||||
|
|
||||||
|
expect(result).toBeDefined()
|
||||||
|
expect(result.text).toBeDefined()
|
||||||
|
expect(typeof result.text).toBe('string')
|
||||||
|
},
|
||||||
|
|
||||||
|
testStreaming: async (executor: any, model: string) => {
|
||||||
|
const chunks: any[] = []
|
||||||
|
const result = await executor.streamText({
|
||||||
|
model,
|
||||||
|
messages: [{ role: 'user' as const, content: 'Hello' }]
|
||||||
|
})
|
||||||
|
|
||||||
|
for await (const chunk of result.textStream) {
|
||||||
|
chunks.push(chunk)
|
||||||
|
}
|
||||||
|
|
||||||
|
expect(chunks.length).toBeGreaterThan(0)
|
||||||
|
},
|
||||||
|
|
||||||
|
testTemperature: async (executor: any, model: string, temperatures: number[]) => {
|
||||||
|
for (const temperature of temperatures) {
|
||||||
|
const result = await executor.generateText({
|
||||||
|
model,
|
||||||
|
messages: [{ role: 'user' as const, content: 'Hello' }],
|
||||||
|
temperature
|
||||||
|
})
|
||||||
|
|
||||||
|
expect(result).toBeDefined()
|
||||||
|
}
|
||||||
|
},
|
||||||
|
|
||||||
|
testMaxTokens: async (executor: any, model: string, maxTokensValues: number[]) => {
|
||||||
|
for (const maxTokens of maxTokensValues) {
|
||||||
|
const result = await executor.generateText({
|
||||||
|
model,
|
||||||
|
messages: [{ role: 'user' as const, content: 'Hello' }],
|
||||||
|
maxTokens
|
||||||
|
})
|
||||||
|
|
||||||
|
expect(result).toBeDefined()
|
||||||
|
if (result.usage?.completionTokens) {
|
||||||
|
expect(result.usage.completionTokens).toBeLessThanOrEqual(maxTokens)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
|
||||||
|
testToolCalling: async (executor: any, model: string, tools: Record<string, Tool>) => {
|
||||||
|
const result = await executor.generateText({
|
||||||
|
model,
|
||||||
|
messages: [{ role: 'user' as const, content: 'What is the weather in SF?' }],
|
||||||
|
tools
|
||||||
|
})
|
||||||
|
|
||||||
|
expect(result).toBeDefined()
|
||||||
|
},
|
||||||
|
|
||||||
|
testStopSequences: async (executor: any, model: string, stopSequences: string[][]) => {
|
||||||
|
for (const stop of stopSequences) {
|
||||||
|
const result = await executor.generateText({
|
||||||
|
model,
|
||||||
|
messages: [{ role: 'user' as const, content: 'Count to 10' }],
|
||||||
|
stop
|
||||||
|
})
|
||||||
|
|
||||||
|
expect(result).toBeDefined()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Generates test data for vision/multimodal testing
|
||||||
|
*/
|
||||||
|
export function createVisionTestData() {
|
||||||
|
return {
|
||||||
|
imageUrl: 'https://example.com/test-image.jpg',
|
||||||
|
base64Image:
|
||||||
|
'',
|
||||||
|
messages: [
|
||||||
|
{
|
||||||
|
role: 'user' as const,
|
||||||
|
content: [
|
||||||
|
{ type: 'text' as const, text: 'What is in this image?' },
|
||||||
|
{
|
||||||
|
type: 'image' as const,
|
||||||
|
image:
|
||||||
|
''
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Creates mock responses for different finish reasons
|
||||||
|
*/
|
||||||
|
export function createFinishReasonMocks() {
|
||||||
|
return {
|
||||||
|
stop: {
|
||||||
|
text: 'Complete response.',
|
||||||
|
finishReason: 'stop' as const,
|
||||||
|
usage: { promptTokens: 10, completionTokens: 5, totalTokens: 15 }
|
||||||
|
},
|
||||||
|
length: {
|
||||||
|
text: 'Incomplete response due to',
|
||||||
|
finishReason: 'length' as const,
|
||||||
|
usage: { promptTokens: 10, completionTokens: 100, totalTokens: 110 }
|
||||||
|
},
|
||||||
|
'tool-calls': {
|
||||||
|
text: 'Calling tools',
|
||||||
|
finishReason: 'tool-calls' as const,
|
||||||
|
toolCalls: [{ toolCallId: 'call_1', toolName: 'getWeather', args: { location: 'SF' } }],
|
||||||
|
usage: { promptTokens: 10, completionTokens: 8, totalTokens: 18 }
|
||||||
|
},
|
||||||
|
'content-filter': {
|
||||||
|
text: '',
|
||||||
|
finishReason: 'content-filter' as const,
|
||||||
|
usage: { promptTokens: 10, completionTokens: 0, totalTokens: 10 }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
291
packages/aiCore/src/__tests__/helpers/test-utils.ts
Normal file
291
packages/aiCore/src/__tests__/helpers/test-utils.ts
Normal file
@@ -0,0 +1,291 @@
|
|||||||
|
/**
|
||||||
|
* Test Utilities
|
||||||
|
* Helper functions for testing AI Core functionality
|
||||||
|
*/
|
||||||
|
|
||||||
|
import { expect, vi } from 'vitest'
|
||||||
|
|
||||||
|
import type { ProviderId } from '../fixtures/mock-providers'
|
||||||
|
import { createMockImageModel, createMockLanguageModel, mockProviderConfigs } from '../fixtures/mock-providers'
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Creates a test provider with streaming support
|
||||||
|
*/
|
||||||
|
export function createTestStreamingProvider(chunks: any[]) {
|
||||||
|
return createMockLanguageModel({
|
||||||
|
doStream: vi.fn().mockReturnValue({
|
||||||
|
stream: (async function* () {
|
||||||
|
for (const chunk of chunks) {
|
||||||
|
yield chunk
|
||||||
|
}
|
||||||
|
})(),
|
||||||
|
rawCall: { rawPrompt: null, rawSettings: {} },
|
||||||
|
rawResponse: { headers: {} },
|
||||||
|
warnings: []
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Creates a test provider that throws errors
|
||||||
|
*/
|
||||||
|
export function createErrorProvider(error: Error) {
|
||||||
|
return createMockLanguageModel({
|
||||||
|
doGenerate: vi.fn().mockRejectedValue(error),
|
||||||
|
doStream: vi.fn().mockImplementation(() => {
|
||||||
|
throw error
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Collects all chunks from a stream
|
||||||
|
*/
|
||||||
|
export async function collectStreamChunks<T>(stream: AsyncIterable<T>): Promise<T[]> {
|
||||||
|
const chunks: T[] = []
|
||||||
|
for await (const chunk of stream) {
|
||||||
|
chunks.push(chunk)
|
||||||
|
}
|
||||||
|
return chunks
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Waits for a specific number of milliseconds
|
||||||
|
*/
|
||||||
|
export function wait(ms: number): Promise<void> {
|
||||||
|
return new Promise((resolve) => setTimeout(resolve, ms))
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Creates a mock abort controller that aborts after a delay
|
||||||
|
*/
|
||||||
|
export function createDelayedAbortController(delayMs: number): AbortController {
|
||||||
|
const controller = new AbortController()
|
||||||
|
setTimeout(() => controller.abort(), delayMs)
|
||||||
|
return controller
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Asserts that a function throws an error with a specific message
|
||||||
|
*/
|
||||||
|
export async function expectError(fn: () => Promise<any>, expectedMessage?: string | RegExp): Promise<Error> {
|
||||||
|
try {
|
||||||
|
await fn()
|
||||||
|
throw new Error('Expected function to throw an error, but it did not')
|
||||||
|
} catch (error) {
|
||||||
|
if (expectedMessage) {
|
||||||
|
const message = (error as Error).message
|
||||||
|
if (typeof expectedMessage === 'string') {
|
||||||
|
if (!message.includes(expectedMessage)) {
|
||||||
|
throw new Error(`Expected error message to include "${expectedMessage}", but got "${message}"`)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if (!expectedMessage.test(message)) {
|
||||||
|
throw new Error(`Expected error message to match ${expectedMessage}, but got "${message}"`)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return error as Error
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Creates a spy function that tracks calls and arguments
|
||||||
|
*/
|
||||||
|
export function createSpy<T extends (...args: any[]) => any>() {
|
||||||
|
const calls: Array<{ args: Parameters<T>; result?: ReturnType<T>; error?: Error }> = []
|
||||||
|
|
||||||
|
const spy = vi.fn((...args: Parameters<T>) => {
|
||||||
|
try {
|
||||||
|
const result = undefined as ReturnType<T>
|
||||||
|
calls.push({ args, result })
|
||||||
|
return result
|
||||||
|
} catch (error) {
|
||||||
|
calls.push({ args, error: error as Error })
|
||||||
|
throw error
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
return {
|
||||||
|
fn: spy,
|
||||||
|
calls,
|
||||||
|
getCalls: () => calls,
|
||||||
|
getCallCount: () => calls.length,
|
||||||
|
getLastCall: () => calls[calls.length - 1],
|
||||||
|
reset: () => {
|
||||||
|
calls.length = 0
|
||||||
|
spy.mockClear()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Validates provider configuration
|
||||||
|
*/
|
||||||
|
export function validateProviderConfig(providerId: ProviderId) {
|
||||||
|
const config = mockProviderConfigs[providerId]
|
||||||
|
if (!config) {
|
||||||
|
throw new Error(`No mock configuration found for provider: ${providerId}`)
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!config.apiKey) {
|
||||||
|
throw new Error(`Provider ${providerId} is missing apiKey in mock config`)
|
||||||
|
}
|
||||||
|
|
||||||
|
return config
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Creates a test context with common setup
|
||||||
|
*/
|
||||||
|
export function createTestContext() {
|
||||||
|
const mocks = {
|
||||||
|
languageModel: createMockLanguageModel(),
|
||||||
|
imageModel: createMockImageModel(),
|
||||||
|
providers: new Map<string, any>()
|
||||||
|
}
|
||||||
|
|
||||||
|
const cleanup = () => {
|
||||||
|
mocks.providers.clear()
|
||||||
|
vi.clearAllMocks()
|
||||||
|
}
|
||||||
|
|
||||||
|
return {
|
||||||
|
mocks,
|
||||||
|
cleanup
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Measures execution time of an async function
|
||||||
|
*/
|
||||||
|
export async function measureTime<T>(fn: () => Promise<T>): Promise<{ result: T; duration: number }> {
|
||||||
|
const start = Date.now()
|
||||||
|
const result = await fn()
|
||||||
|
const duration = Date.now() - start
|
||||||
|
return { result, duration }
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Retries a function until it succeeds or max attempts reached
|
||||||
|
*/
|
||||||
|
export async function retryUntilSuccess<T>(fn: () => Promise<T>, maxAttempts = 3, delayMs = 100): Promise<T> {
|
||||||
|
let lastError: Error | undefined
|
||||||
|
|
||||||
|
for (let attempt = 1; attempt <= maxAttempts; attempt++) {
|
||||||
|
try {
|
||||||
|
return await fn()
|
||||||
|
} catch (error) {
|
||||||
|
lastError = error as Error
|
||||||
|
if (attempt < maxAttempts) {
|
||||||
|
await wait(delayMs)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
throw lastError || new Error('All retry attempts failed')
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Creates a mock streaming response that emits chunks at intervals
|
||||||
|
*/
|
||||||
|
export function createTimedStream<T>(chunks: T[], intervalMs = 10) {
|
||||||
|
return {
|
||||||
|
async *[Symbol.asyncIterator]() {
|
||||||
|
for (const chunk of chunks) {
|
||||||
|
await wait(intervalMs)
|
||||||
|
yield chunk
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Asserts that two objects are deeply equal, ignoring specified keys
|
||||||
|
*/
|
||||||
|
export function assertDeepEqualIgnoring<T extends Record<string, any>>(
|
||||||
|
actual: T,
|
||||||
|
expected: T,
|
||||||
|
ignoreKeys: string[] = []
|
||||||
|
): void {
|
||||||
|
const filterKeys = (obj: T): Partial<T> => {
|
||||||
|
const filtered = { ...obj }
|
||||||
|
for (const key of ignoreKeys) {
|
||||||
|
delete filtered[key]
|
||||||
|
}
|
||||||
|
return filtered
|
||||||
|
}
|
||||||
|
|
||||||
|
const filteredActual = filterKeys(actual)
|
||||||
|
const filteredExpected = filterKeys(expected)
|
||||||
|
|
||||||
|
expect(filteredActual).toEqual(filteredExpected)
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Creates a provider mock that simulates rate limiting
|
||||||
|
*/
|
||||||
|
export function createRateLimitedProvider(limitPerSecond: number) {
|
||||||
|
const calls: number[] = []
|
||||||
|
|
||||||
|
return createMockLanguageModel({
|
||||||
|
doGenerate: vi.fn().mockImplementation(async () => {
|
||||||
|
const now = Date.now()
|
||||||
|
calls.push(now)
|
||||||
|
|
||||||
|
// Remove calls older than 1 second
|
||||||
|
const recentCalls = calls.filter((time) => now - time < 1000)
|
||||||
|
|
||||||
|
if (recentCalls.length > limitPerSecond) {
|
||||||
|
throw new Error('Rate limit exceeded')
|
||||||
|
}
|
||||||
|
|
||||||
|
return {
|
||||||
|
text: 'Rate limited response',
|
||||||
|
finishReason: 'stop' as const,
|
||||||
|
usage: { promptTokens: 10, completionTokens: 5, totalTokens: 15 },
|
||||||
|
rawCall: { rawPrompt: null, rawSettings: {} },
|
||||||
|
rawResponse: { headers: {} },
|
||||||
|
warnings: []
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Validates streaming response structure
|
||||||
|
*/
|
||||||
|
export function validateStreamChunk(chunk: any): void {
|
||||||
|
expect(chunk).toBeDefined()
|
||||||
|
expect(chunk).toHaveProperty('type')
|
||||||
|
|
||||||
|
if (chunk.type === 'text-delta') {
|
||||||
|
expect(chunk).toHaveProperty('textDelta')
|
||||||
|
expect(typeof chunk.textDelta).toBe('string')
|
||||||
|
} else if (chunk.type === 'finish') {
|
||||||
|
expect(chunk).toHaveProperty('finishReason')
|
||||||
|
expect(chunk).toHaveProperty('usage')
|
||||||
|
} else if (chunk.type === 'tool-call') {
|
||||||
|
expect(chunk).toHaveProperty('toolCallId')
|
||||||
|
expect(chunk).toHaveProperty('toolName')
|
||||||
|
expect(chunk).toHaveProperty('args')
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Creates a test logger that captures log messages
|
||||||
|
*/
|
||||||
|
export function createTestLogger() {
|
||||||
|
const logs: Array<{ level: string; message: string; meta?: any }> = []
|
||||||
|
|
||||||
|
return {
|
||||||
|
info: (message: string, meta?: any) => logs.push({ level: 'info', message, meta }),
|
||||||
|
warn: (message: string, meta?: any) => logs.push({ level: 'warn', message, meta }),
|
||||||
|
error: (message: string, meta?: any) => logs.push({ level: 'error', message, meta }),
|
||||||
|
debug: (message: string, meta?: any) => logs.push({ level: 'debug', message, meta }),
|
||||||
|
getLogs: () => logs,
|
||||||
|
clear: () => {
|
||||||
|
logs.length = 0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
12
packages/aiCore/src/__tests__/index.ts
Normal file
12
packages/aiCore/src/__tests__/index.ts
Normal file
@@ -0,0 +1,12 @@
|
|||||||
|
/**
|
||||||
|
* Test Infrastructure Exports
|
||||||
|
* Central export point for all test utilities, fixtures, and helpers
|
||||||
|
*/
|
||||||
|
|
||||||
|
// Fixtures
|
||||||
|
export * from './fixtures/mock-providers'
|
||||||
|
export * from './fixtures/mock-responses'
|
||||||
|
|
||||||
|
// Helpers
|
||||||
|
export * from './helpers/provider-test-utils'
|
||||||
|
export * from './helpers/test-utils'
|
||||||
@@ -4,12 +4,7 @@
|
|||||||
*/
|
*/
|
||||||
export const BUILT_IN_PLUGIN_PREFIX = 'built-in:'
|
export const BUILT_IN_PLUGIN_PREFIX = 'built-in:'
|
||||||
|
|
||||||
export { googleToolsPlugin } from './googleToolsPlugin'
|
export * from './googleToolsPlugin'
|
||||||
export { createLoggingPlugin } from './logging'
|
export * from './toolUsePlugin/promptToolUsePlugin'
|
||||||
export { createPromptToolUsePlugin } from './toolUsePlugin/promptToolUsePlugin'
|
export * from './toolUsePlugin/type'
|
||||||
export type {
|
export * from './webSearchPlugin'
|
||||||
PromptToolUseConfig,
|
|
||||||
ToolUseRequestContext,
|
|
||||||
ToolUseResult
|
|
||||||
} from './toolUsePlugin/type'
|
|
||||||
export { webSearchPlugin, type WebSearchPluginConfig } from './webSearchPlugin'
|
|
||||||
|
|||||||
@@ -32,7 +32,7 @@ export const webSearchPlugin = (config: WebSearchPluginConfig = DEFAULT_WEB_SEAR
|
|||||||
})
|
})
|
||||||
|
|
||||||
// 导出类型定义供开发者使用
|
// 导出类型定义供开发者使用
|
||||||
export type { WebSearchPluginConfig, WebSearchToolOutputSchema } from './helper'
|
export * from './helper'
|
||||||
|
|
||||||
// 默认导出
|
// 默认导出
|
||||||
export default webSearchPlugin
|
export default webSearchPlugin
|
||||||
|
|||||||
@@ -44,7 +44,7 @@ export {
|
|||||||
// ==================== 基础数据和类型 ====================
|
// ==================== 基础数据和类型 ====================
|
||||||
|
|
||||||
// 基础Provider数据源
|
// 基础Provider数据源
|
||||||
export { baseProviderIds, baseProviders } from './schemas'
|
export { baseProviderIds, baseProviders, isBaseProvider } from './schemas'
|
||||||
|
|
||||||
// 类型定义和Schema
|
// 类型定义和Schema
|
||||||
export type {
|
export type {
|
||||||
|
|||||||
@@ -7,7 +7,6 @@ import { createAzure } from '@ai-sdk/azure'
|
|||||||
import { type AzureOpenAIProviderSettings } from '@ai-sdk/azure'
|
import { type AzureOpenAIProviderSettings } from '@ai-sdk/azure'
|
||||||
import { createDeepSeek } from '@ai-sdk/deepseek'
|
import { createDeepSeek } from '@ai-sdk/deepseek'
|
||||||
import { createGoogleGenerativeAI } from '@ai-sdk/google'
|
import { createGoogleGenerativeAI } from '@ai-sdk/google'
|
||||||
import { createHuggingFace } from '@ai-sdk/huggingface'
|
|
||||||
import { createOpenAI, type OpenAIProviderSettings } from '@ai-sdk/openai'
|
import { createOpenAI, type OpenAIProviderSettings } from '@ai-sdk/openai'
|
||||||
import { createOpenAICompatible } from '@ai-sdk/openai-compatible'
|
import { createOpenAICompatible } from '@ai-sdk/openai-compatible'
|
||||||
import type { LanguageModelV2 } from '@ai-sdk/provider'
|
import type { LanguageModelV2 } from '@ai-sdk/provider'
|
||||||
@@ -33,8 +32,7 @@ export const baseProviderIds = [
|
|||||||
'deepseek',
|
'deepseek',
|
||||||
'openrouter',
|
'openrouter',
|
||||||
'cherryin',
|
'cherryin',
|
||||||
'cherryin-chat',
|
'cherryin-chat'
|
||||||
'huggingface'
|
|
||||||
] as const
|
] as const
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@@ -158,12 +156,6 @@ export const baseProviders = [
|
|||||||
})
|
})
|
||||||
},
|
},
|
||||||
supportsImageGeneration: true
|
supportsImageGeneration: true
|
||||||
},
|
|
||||||
{
|
|
||||||
id: 'huggingface',
|
|
||||||
name: 'HuggingFace',
|
|
||||||
creator: createHuggingFace,
|
|
||||||
supportsImageGeneration: true
|
|
||||||
}
|
}
|
||||||
] as const satisfies BaseProvider[]
|
] as const satisfies BaseProvider[]
|
||||||
|
|
||||||
|
|||||||
499
packages/aiCore/src/core/runtime/__tests__/generateText.test.ts
Normal file
499
packages/aiCore/src/core/runtime/__tests__/generateText.test.ts
Normal file
@@ -0,0 +1,499 @@
|
|||||||
|
/**
|
||||||
|
* RuntimeExecutor.generateText Comprehensive Tests
|
||||||
|
* Tests non-streaming text generation across all providers with various parameters
|
||||||
|
*/
|
||||||
|
|
||||||
|
import { generateText } from 'ai'
|
||||||
|
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||||
|
|
||||||
|
import {
|
||||||
|
createMockLanguageModel,
|
||||||
|
mockCompleteResponses,
|
||||||
|
mockProviderConfigs,
|
||||||
|
testMessages,
|
||||||
|
testTools
|
||||||
|
} from '../../../__tests__'
|
||||||
|
import type { AiPlugin } from '../../plugins'
|
||||||
|
import { globalRegistryManagement } from '../../providers/RegistryManagement'
|
||||||
|
import { RuntimeExecutor } from '../executor'
|
||||||
|
|
||||||
|
// Mock AI SDK
|
||||||
|
vi.mock('ai', () => ({
|
||||||
|
generateText: vi.fn()
|
||||||
|
}))
|
||||||
|
|
||||||
|
vi.mock('../../providers/RegistryManagement', () => ({
|
||||||
|
globalRegistryManagement: {
|
||||||
|
languageModel: vi.fn()
|
||||||
|
},
|
||||||
|
DEFAULT_SEPARATOR: '|'
|
||||||
|
}))
|
||||||
|
|
||||||
|
describe('RuntimeExecutor.generateText', () => {
|
||||||
|
let executor: RuntimeExecutor<'openai'>
|
||||||
|
let mockLanguageModel: any
|
||||||
|
|
||||||
|
beforeEach(() => {
|
||||||
|
vi.clearAllMocks()
|
||||||
|
|
||||||
|
executor = RuntimeExecutor.create('openai', mockProviderConfigs.openai)
|
||||||
|
|
||||||
|
mockLanguageModel = createMockLanguageModel({
|
||||||
|
provider: 'openai',
|
||||||
|
modelId: 'gpt-4'
|
||||||
|
})
|
||||||
|
|
||||||
|
vi.mocked(globalRegistryManagement.languageModel).mockReturnValue(mockLanguageModel)
|
||||||
|
vi.mocked(generateText).mockResolvedValue(mockCompleteResponses.simple as any)
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('Basic Functionality', () => {
|
||||||
|
it('should generate text with minimal parameters', async () => {
|
||||||
|
const result = await executor.generateText({
|
||||||
|
model: 'gpt-4',
|
||||||
|
messages: testMessages.simple
|
||||||
|
})
|
||||||
|
|
||||||
|
expect(generateText).toHaveBeenCalledWith({
|
||||||
|
model: mockLanguageModel,
|
||||||
|
messages: testMessages.simple
|
||||||
|
})
|
||||||
|
|
||||||
|
expect(result.text).toBe('This is a simple response.')
|
||||||
|
expect(result.finishReason).toBe('stop')
|
||||||
|
expect(result.usage).toBeDefined()
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should generate with system messages', async () => {
|
||||||
|
await executor.generateText({
|
||||||
|
model: 'gpt-4',
|
||||||
|
messages: testMessages.withSystem
|
||||||
|
})
|
||||||
|
|
||||||
|
expect(generateText).toHaveBeenCalledWith({
|
||||||
|
model: mockLanguageModel,
|
||||||
|
messages: testMessages.withSystem
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should generate with conversation history', async () => {
|
||||||
|
await executor.generateText({
|
||||||
|
model: 'gpt-4',
|
||||||
|
messages: testMessages.conversation
|
||||||
|
})
|
||||||
|
|
||||||
|
expect(generateText).toHaveBeenCalledWith(
|
||||||
|
expect.objectContaining({
|
||||||
|
messages: testMessages.conversation
|
||||||
|
})
|
||||||
|
)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('All Parameter Combinations', () => {
|
||||||
|
it('should support all parameters together', async () => {
|
||||||
|
await executor.generateText({
|
||||||
|
model: 'gpt-4',
|
||||||
|
messages: testMessages.simple,
|
||||||
|
temperature: 0.7,
|
||||||
|
maxOutputTokens: 500,
|
||||||
|
topP: 0.9,
|
||||||
|
frequencyPenalty: 0.5,
|
||||||
|
presencePenalty: 0.3,
|
||||||
|
stopSequences: ['STOP'],
|
||||||
|
seed: 12345
|
||||||
|
})
|
||||||
|
|
||||||
|
expect(generateText).toHaveBeenCalledWith(
|
||||||
|
expect.objectContaining({
|
||||||
|
temperature: 0.7,
|
||||||
|
maxOutputTokens: 500,
|
||||||
|
topP: 0.9,
|
||||||
|
frequencyPenalty: 0.5,
|
||||||
|
presencePenalty: 0.3,
|
||||||
|
stopSequences: ['STOP'],
|
||||||
|
seed: 12345
|
||||||
|
})
|
||||||
|
)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should support partial parameters', async () => {
|
||||||
|
await executor.generateText({
|
||||||
|
model: 'gpt-4',
|
||||||
|
messages: testMessages.simple,
|
||||||
|
temperature: 0.5,
|
||||||
|
maxOutputTokens: 100
|
||||||
|
})
|
||||||
|
|
||||||
|
expect(generateText).toHaveBeenCalledWith(
|
||||||
|
expect.objectContaining({
|
||||||
|
temperature: 0.5,
|
||||||
|
maxOutputTokens: 100
|
||||||
|
})
|
||||||
|
)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('Tool Calling', () => {
|
||||||
|
beforeEach(() => {
|
||||||
|
vi.mocked(generateText).mockResolvedValue(mockCompleteResponses.withToolCalls as any)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should support tool calling', async () => {
|
||||||
|
const result = await executor.generateText({
|
||||||
|
model: 'gpt-4',
|
||||||
|
messages: testMessages.toolUse,
|
||||||
|
tools: testTools
|
||||||
|
})
|
||||||
|
|
||||||
|
expect(generateText).toHaveBeenCalledWith(
|
||||||
|
expect.objectContaining({
|
||||||
|
tools: testTools
|
||||||
|
})
|
||||||
|
)
|
||||||
|
|
||||||
|
expect(result.toolCalls).toBeDefined()
|
||||||
|
expect(result.toolCalls).toHaveLength(1)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should support toolChoice auto', async () => {
|
||||||
|
await executor.generateText({
|
||||||
|
model: 'gpt-4',
|
||||||
|
messages: testMessages.toolUse,
|
||||||
|
tools: testTools,
|
||||||
|
toolChoice: 'auto'
|
||||||
|
})
|
||||||
|
|
||||||
|
expect(generateText).toHaveBeenCalledWith(
|
||||||
|
expect.objectContaining({
|
||||||
|
toolChoice: 'auto'
|
||||||
|
})
|
||||||
|
)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should support toolChoice required', async () => {
|
||||||
|
await executor.generateText({
|
||||||
|
model: 'gpt-4',
|
||||||
|
messages: testMessages.toolUse,
|
||||||
|
tools: testTools,
|
||||||
|
toolChoice: 'required'
|
||||||
|
})
|
||||||
|
|
||||||
|
expect(generateText).toHaveBeenCalledWith(
|
||||||
|
expect.objectContaining({
|
||||||
|
toolChoice: 'required'
|
||||||
|
})
|
||||||
|
)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should support toolChoice none', async () => {
|
||||||
|
vi.mocked(generateText).mockResolvedValue(mockCompleteResponses.simple as any)
|
||||||
|
|
||||||
|
await executor.generateText({
|
||||||
|
model: 'gpt-4',
|
||||||
|
messages: testMessages.simple,
|
||||||
|
tools: testTools,
|
||||||
|
toolChoice: 'none'
|
||||||
|
})
|
||||||
|
|
||||||
|
expect(generateText).toHaveBeenCalledWith(
|
||||||
|
expect.objectContaining({
|
||||||
|
toolChoice: 'none'
|
||||||
|
})
|
||||||
|
)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should support specific tool selection', async () => {
|
||||||
|
await executor.generateText({
|
||||||
|
model: 'gpt-4',
|
||||||
|
messages: testMessages.toolUse,
|
||||||
|
tools: testTools,
|
||||||
|
toolChoice: {
|
||||||
|
type: 'tool',
|
||||||
|
toolName: 'getWeather'
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
expect(generateText).toHaveBeenCalledWith(
|
||||||
|
expect.objectContaining({
|
||||||
|
toolChoice: {
|
||||||
|
type: 'tool',
|
||||||
|
toolName: 'getWeather'
|
||||||
|
}
|
||||||
|
})
|
||||||
|
)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('Multiple Providers', () => {
|
||||||
|
it('should work with Anthropic provider', async () => {
|
||||||
|
const anthropicExecutor = RuntimeExecutor.create('anthropic', mockProviderConfigs.anthropic)
|
||||||
|
|
||||||
|
const anthropicModel = createMockLanguageModel({
|
||||||
|
provider: 'anthropic',
|
||||||
|
modelId: 'claude-3-5-sonnet-20241022'
|
||||||
|
})
|
||||||
|
|
||||||
|
vi.mocked(globalRegistryManagement.languageModel).mockReturnValue(anthropicModel)
|
||||||
|
|
||||||
|
await anthropicExecutor.generateText({
|
||||||
|
model: 'claude-3-5-sonnet-20241022',
|
||||||
|
messages: testMessages.simple
|
||||||
|
})
|
||||||
|
|
||||||
|
expect(globalRegistryManagement.languageModel).toHaveBeenCalledWith('anthropic|claude-3-5-sonnet-20241022')
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should work with Google provider', async () => {
|
||||||
|
const googleExecutor = RuntimeExecutor.create('google', mockProviderConfigs.google)
|
||||||
|
|
||||||
|
const googleModel = createMockLanguageModel({
|
||||||
|
provider: 'google',
|
||||||
|
modelId: 'gemini-2.0-flash-exp'
|
||||||
|
})
|
||||||
|
|
||||||
|
vi.mocked(globalRegistryManagement.languageModel).mockReturnValue(googleModel)
|
||||||
|
|
||||||
|
await googleExecutor.generateText({
|
||||||
|
model: 'gemini-2.0-flash-exp',
|
||||||
|
messages: testMessages.simple
|
||||||
|
})
|
||||||
|
|
||||||
|
expect(globalRegistryManagement.languageModel).toHaveBeenCalledWith('google|gemini-2.0-flash-exp')
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should work with xAI provider', async () => {
|
||||||
|
const xaiExecutor = RuntimeExecutor.create('xai', mockProviderConfigs.xai)
|
||||||
|
|
||||||
|
const xaiModel = createMockLanguageModel({
|
||||||
|
provider: 'xai',
|
||||||
|
modelId: 'grok-2-latest'
|
||||||
|
})
|
||||||
|
|
||||||
|
vi.mocked(globalRegistryManagement.languageModel).mockReturnValue(xaiModel)
|
||||||
|
|
||||||
|
await xaiExecutor.generateText({
|
||||||
|
model: 'grok-2-latest',
|
||||||
|
messages: testMessages.simple
|
||||||
|
})
|
||||||
|
|
||||||
|
expect(globalRegistryManagement.languageModel).toHaveBeenCalledWith('xai|grok-2-latest')
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should work with DeepSeek provider', async () => {
|
||||||
|
const deepseekExecutor = RuntimeExecutor.create('deepseek', mockProviderConfigs.deepseek)
|
||||||
|
|
||||||
|
const deepseekModel = createMockLanguageModel({
|
||||||
|
provider: 'deepseek',
|
||||||
|
modelId: 'deepseek-chat'
|
||||||
|
})
|
||||||
|
|
||||||
|
vi.mocked(globalRegistryManagement.languageModel).mockReturnValue(deepseekModel)
|
||||||
|
|
||||||
|
await deepseekExecutor.generateText({
|
||||||
|
model: 'deepseek-chat',
|
||||||
|
messages: testMessages.simple
|
||||||
|
})
|
||||||
|
|
||||||
|
expect(globalRegistryManagement.languageModel).toHaveBeenCalledWith('deepseek|deepseek-chat')
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('Plugin Integration', () => {
|
||||||
|
it('should execute all plugin hooks', async () => {
|
||||||
|
const pluginCalls: string[] = []
|
||||||
|
|
||||||
|
const testPlugin: AiPlugin = {
|
||||||
|
name: 'test-plugin',
|
||||||
|
onRequestStart: vi.fn(async () => {
|
||||||
|
pluginCalls.push('onRequestStart')
|
||||||
|
}),
|
||||||
|
transformParams: vi.fn(async (params) => {
|
||||||
|
pluginCalls.push('transformParams')
|
||||||
|
return { ...params, temperature: 0.8 }
|
||||||
|
}),
|
||||||
|
transformResult: vi.fn(async (result) => {
|
||||||
|
pluginCalls.push('transformResult')
|
||||||
|
return { ...result, text: result.text + ' [modified]' }
|
||||||
|
}),
|
||||||
|
onRequestEnd: vi.fn(async () => {
|
||||||
|
pluginCalls.push('onRequestEnd')
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
const executorWithPlugin = RuntimeExecutor.create('openai', mockProviderConfigs.openai, [testPlugin])
|
||||||
|
|
||||||
|
const result = await executorWithPlugin.generateText({
|
||||||
|
model: 'gpt-4',
|
||||||
|
messages: testMessages.simple
|
||||||
|
})
|
||||||
|
|
||||||
|
expect(pluginCalls).toEqual(['onRequestStart', 'transformParams', 'transformResult', 'onRequestEnd'])
|
||||||
|
|
||||||
|
// Verify transformed parameters
|
||||||
|
expect(generateText).toHaveBeenCalledWith(
|
||||||
|
expect.objectContaining({
|
||||||
|
temperature: 0.8
|
||||||
|
})
|
||||||
|
)
|
||||||
|
|
||||||
|
// Verify transformed result
|
||||||
|
expect(result.text).toContain('[modified]')
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should handle multiple plugins in order', async () => {
|
||||||
|
const pluginOrder: string[] = []
|
||||||
|
|
||||||
|
const plugin1: AiPlugin = {
|
||||||
|
name: 'plugin-1',
|
||||||
|
transformParams: vi.fn(async (params) => {
|
||||||
|
pluginOrder.push('plugin-1')
|
||||||
|
return { ...params, temperature: 0.5 }
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
const plugin2: AiPlugin = {
|
||||||
|
name: 'plugin-2',
|
||||||
|
transformParams: vi.fn(async (params) => {
|
||||||
|
pluginOrder.push('plugin-2')
|
||||||
|
return { ...params, maxTokens: 200 }
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
const executorWithPlugins = RuntimeExecutor.create('openai', mockProviderConfigs.openai, [plugin1, plugin2])
|
||||||
|
|
||||||
|
await executorWithPlugins.generateText({
|
||||||
|
model: 'gpt-4',
|
||||||
|
messages: testMessages.simple
|
||||||
|
})
|
||||||
|
|
||||||
|
expect(pluginOrder).toEqual(['plugin-1', 'plugin-2'])
|
||||||
|
|
||||||
|
expect(generateText).toHaveBeenCalledWith(
|
||||||
|
expect.objectContaining({
|
||||||
|
temperature: 0.5,
|
||||||
|
maxTokens: 200
|
||||||
|
})
|
||||||
|
)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('Error Handling', () => {
|
||||||
|
it('should handle API errors', async () => {
|
||||||
|
const error = new Error('API request failed')
|
||||||
|
vi.mocked(generateText).mockRejectedValue(error)
|
||||||
|
|
||||||
|
await expect(
|
||||||
|
executor.generateText({
|
||||||
|
model: 'gpt-4',
|
||||||
|
messages: testMessages.simple
|
||||||
|
})
|
||||||
|
).rejects.toThrow('API request failed')
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should execute onError plugin hook', async () => {
|
||||||
|
const error = new Error('Generation failed')
|
||||||
|
vi.mocked(generateText).mockRejectedValue(error)
|
||||||
|
|
||||||
|
const errorPlugin: AiPlugin = {
|
||||||
|
name: 'error-handler',
|
||||||
|
onError: vi.fn()
|
||||||
|
}
|
||||||
|
|
||||||
|
const executorWithPlugin = RuntimeExecutor.create('openai', mockProviderConfigs.openai, [errorPlugin])
|
||||||
|
|
||||||
|
await expect(
|
||||||
|
executorWithPlugin.generateText({
|
||||||
|
model: 'gpt-4',
|
||||||
|
messages: testMessages.simple
|
||||||
|
})
|
||||||
|
).rejects.toThrow('Generation failed')
|
||||||
|
|
||||||
|
expect(errorPlugin.onError).toHaveBeenCalledWith(
|
||||||
|
error,
|
||||||
|
expect.objectContaining({
|
||||||
|
providerId: 'openai',
|
||||||
|
modelId: 'gpt-4'
|
||||||
|
})
|
||||||
|
)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should handle model not found error', async () => {
|
||||||
|
const error = new Error('Model not found: invalid-model')
|
||||||
|
vi.mocked(globalRegistryManagement.languageModel).mockImplementation(() => {
|
||||||
|
throw error
|
||||||
|
})
|
||||||
|
|
||||||
|
await expect(
|
||||||
|
executor.generateText({
|
||||||
|
model: 'invalid-model',
|
||||||
|
messages: testMessages.simple
|
||||||
|
})
|
||||||
|
).rejects.toThrow('Model not found')
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('Usage and Metadata', () => {
|
||||||
|
it('should return usage information', async () => {
|
||||||
|
const result = await executor.generateText({
|
||||||
|
model: 'gpt-4',
|
||||||
|
messages: testMessages.simple
|
||||||
|
})
|
||||||
|
|
||||||
|
expect(result.usage).toBeDefined()
|
||||||
|
expect(result.usage.inputTokens).toBe(15)
|
||||||
|
expect(result.usage.outputTokens).toBe(8)
|
||||||
|
expect(result.usage.totalTokens).toBe(23)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should handle warnings', async () => {
|
||||||
|
vi.mocked(generateText).mockResolvedValue(mockCompleteResponses.withWarnings as any)
|
||||||
|
|
||||||
|
const result = await executor.generateText({
|
||||||
|
model: 'gpt-4',
|
||||||
|
messages: testMessages.simple,
|
||||||
|
temperature: 2.5 // Unsupported value
|
||||||
|
})
|
||||||
|
|
||||||
|
expect(result.warnings).toBeDefined()
|
||||||
|
expect(result.warnings).toHaveLength(1)
|
||||||
|
expect(result.warnings![0].type).toBe('unsupported-setting')
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('Abort Signal', () => {
|
||||||
|
it('should support abort signal', async () => {
|
||||||
|
const abortController = new AbortController()
|
||||||
|
|
||||||
|
await executor.generateText({
|
||||||
|
model: 'gpt-4',
|
||||||
|
messages: testMessages.simple,
|
||||||
|
abortSignal: abortController.signal
|
||||||
|
})
|
||||||
|
|
||||||
|
expect(generateText).toHaveBeenCalledWith(
|
||||||
|
expect.objectContaining({
|
||||||
|
abortSignal: abortController.signal
|
||||||
|
})
|
||||||
|
)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should handle aborted request', async () => {
|
||||||
|
const abortError = new Error('Request aborted')
|
||||||
|
abortError.name = 'AbortError'
|
||||||
|
|
||||||
|
vi.mocked(generateText).mockRejectedValue(abortError)
|
||||||
|
|
||||||
|
const abortController = new AbortController()
|
||||||
|
abortController.abort()
|
||||||
|
|
||||||
|
await expect(
|
||||||
|
executor.generateText({
|
||||||
|
model: 'gpt-4',
|
||||||
|
messages: testMessages.simple,
|
||||||
|
abortSignal: abortController.signal
|
||||||
|
})
|
||||||
|
).rejects.toThrow('Request aborted')
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
525
packages/aiCore/src/core/runtime/__tests__/streamText.test.ts
Normal file
525
packages/aiCore/src/core/runtime/__tests__/streamText.test.ts
Normal file
@@ -0,0 +1,525 @@
|
|||||||
|
/**
|
||||||
|
* RuntimeExecutor.streamText Comprehensive Tests
|
||||||
|
* Tests streaming text generation across all providers with various parameters
|
||||||
|
*/
|
||||||
|
|
||||||
|
import { streamText } from 'ai'
|
||||||
|
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||||
|
|
||||||
|
import { collectStreamChunks, createMockLanguageModel, mockProviderConfigs, testMessages } from '../../../__tests__'
|
||||||
|
import type { AiPlugin } from '../../plugins'
|
||||||
|
import { globalRegistryManagement } from '../../providers/RegistryManagement'
|
||||||
|
import { RuntimeExecutor } from '../executor'
|
||||||
|
|
||||||
|
// Mock AI SDK
|
||||||
|
vi.mock('ai', () => ({
|
||||||
|
streamText: vi.fn()
|
||||||
|
}))
|
||||||
|
|
||||||
|
vi.mock('../../providers/RegistryManagement', () => ({
|
||||||
|
globalRegistryManagement: {
|
||||||
|
languageModel: vi.fn()
|
||||||
|
},
|
||||||
|
DEFAULT_SEPARATOR: '|'
|
||||||
|
}))
|
||||||
|
|
||||||
|
describe('RuntimeExecutor.streamText', () => {
|
||||||
|
let executor: RuntimeExecutor<'openai'>
|
||||||
|
let mockLanguageModel: any
|
||||||
|
|
||||||
|
beforeEach(() => {
|
||||||
|
vi.clearAllMocks()
|
||||||
|
|
||||||
|
executor = RuntimeExecutor.create('openai', mockProviderConfigs.openai)
|
||||||
|
|
||||||
|
mockLanguageModel = createMockLanguageModel({
|
||||||
|
provider: 'openai',
|
||||||
|
modelId: 'gpt-4'
|
||||||
|
})
|
||||||
|
|
||||||
|
vi.mocked(globalRegistryManagement.languageModel).mockReturnValue(mockLanguageModel)
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('Basic Functionality', () => {
|
||||||
|
it('should stream text with minimal parameters', async () => {
|
||||||
|
const mockStream = {
|
||||||
|
textStream: (async function* () {
|
||||||
|
yield 'Hello'
|
||||||
|
yield ' '
|
||||||
|
yield 'World'
|
||||||
|
})(),
|
||||||
|
fullStream: (async function* () {
|
||||||
|
yield { type: 'text-delta', textDelta: 'Hello' }
|
||||||
|
yield { type: 'text-delta', textDelta: ' ' }
|
||||||
|
yield { type: 'text-delta', textDelta: 'World' }
|
||||||
|
})(),
|
||||||
|
usage: Promise.resolve({ promptTokens: 5, completionTokens: 3, totalTokens: 8 })
|
||||||
|
}
|
||||||
|
|
||||||
|
vi.mocked(streamText).mockResolvedValue(mockStream as any)
|
||||||
|
|
||||||
|
const result = await executor.streamText({
|
||||||
|
model: 'gpt-4',
|
||||||
|
messages: testMessages.simple
|
||||||
|
})
|
||||||
|
|
||||||
|
expect(streamText).toHaveBeenCalledWith({
|
||||||
|
model: mockLanguageModel,
|
||||||
|
messages: testMessages.simple
|
||||||
|
})
|
||||||
|
|
||||||
|
const chunks = await collectStreamChunks(result.textStream)
|
||||||
|
expect(chunks).toEqual(['Hello', ' ', 'World'])
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should stream with system messages', async () => {
|
||||||
|
const mockStream = {
|
||||||
|
textStream: (async function* () {
|
||||||
|
yield 'Response'
|
||||||
|
})(),
|
||||||
|
fullStream: (async function* () {
|
||||||
|
yield { type: 'text-delta', textDelta: 'Response' }
|
||||||
|
})()
|
||||||
|
}
|
||||||
|
|
||||||
|
vi.mocked(streamText).mockResolvedValue(mockStream as any)
|
||||||
|
|
||||||
|
await executor.streamText({
|
||||||
|
model: 'gpt-4',
|
||||||
|
messages: testMessages.withSystem
|
||||||
|
})
|
||||||
|
|
||||||
|
expect(streamText).toHaveBeenCalledWith({
|
||||||
|
model: mockLanguageModel,
|
||||||
|
messages: testMessages.withSystem
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should stream multi-turn conversations', async () => {
|
||||||
|
const mockStream = {
|
||||||
|
textStream: (async function* () {
|
||||||
|
yield 'Multi-turn response'
|
||||||
|
})(),
|
||||||
|
fullStream: (async function* () {
|
||||||
|
yield { type: 'text-delta', textDelta: 'Multi-turn response' }
|
||||||
|
})()
|
||||||
|
}
|
||||||
|
|
||||||
|
vi.mocked(streamText).mockResolvedValue(mockStream as any)
|
||||||
|
|
||||||
|
await executor.streamText({
|
||||||
|
model: 'gpt-4',
|
||||||
|
messages: testMessages.multiTurn
|
||||||
|
})
|
||||||
|
|
||||||
|
expect(streamText).toHaveBeenCalled()
|
||||||
|
expect(streamText).toHaveBeenCalledWith(
|
||||||
|
expect.objectContaining({
|
||||||
|
messages: testMessages.multiTurn
|
||||||
|
})
|
||||||
|
)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('Temperature Parameter', () => {
|
||||||
|
const temperatures = [0, 0.3, 0.5, 0.7, 0.9, 1.0, 1.5, 2.0]
|
||||||
|
|
||||||
|
it.each(temperatures)('should support temperature=%s', async (temperature) => {
|
||||||
|
const mockStream = {
|
||||||
|
textStream: (async function* () {
|
||||||
|
yield 'Response'
|
||||||
|
})(),
|
||||||
|
fullStream: (async function* () {
|
||||||
|
yield { type: 'text-delta', textDelta: 'Response' }
|
||||||
|
})()
|
||||||
|
}
|
||||||
|
|
||||||
|
vi.mocked(streamText).mockResolvedValue(mockStream as any)
|
||||||
|
|
||||||
|
await executor.streamText({
|
||||||
|
model: 'gpt-4',
|
||||||
|
messages: testMessages.simple,
|
||||||
|
temperature
|
||||||
|
})
|
||||||
|
|
||||||
|
expect(streamText).toHaveBeenCalledWith(
|
||||||
|
expect.objectContaining({
|
||||||
|
temperature
|
||||||
|
})
|
||||||
|
)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('Max Tokens Parameter', () => {
|
||||||
|
const maxTokensValues = [10, 50, 100, 500, 1000, 2000, 4000]
|
||||||
|
|
||||||
|
it.each(maxTokensValues)('should support maxTokens=%s', async (maxTokens) => {
|
||||||
|
const mockStream = {
|
||||||
|
textStream: (async function* () {
|
||||||
|
yield 'Response'
|
||||||
|
})(),
|
||||||
|
fullStream: (async function* () {
|
||||||
|
yield { type: 'text-delta', textDelta: 'Response' }
|
||||||
|
})()
|
||||||
|
}
|
||||||
|
|
||||||
|
vi.mocked(streamText).mockResolvedValue(mockStream as any)
|
||||||
|
|
||||||
|
await executor.streamText({
|
||||||
|
model: 'gpt-4',
|
||||||
|
messages: testMessages.simple,
|
||||||
|
maxOutputTokens: maxTokens
|
||||||
|
})
|
||||||
|
|
||||||
|
expect(streamText).toHaveBeenCalledWith(
|
||||||
|
expect.objectContaining({
|
||||||
|
maxTokens
|
||||||
|
})
|
||||||
|
)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('Top P Parameter', () => {
|
||||||
|
const topPValues = [0.1, 0.3, 0.5, 0.7, 0.9, 0.95, 1.0]
|
||||||
|
|
||||||
|
it.each(topPValues)('should support topP=%s', async (topP) => {
|
||||||
|
const mockStream = {
|
||||||
|
textStream: (async function* () {
|
||||||
|
yield 'Response'
|
||||||
|
})(),
|
||||||
|
fullStream: (async function* () {
|
||||||
|
yield { type: 'text-delta', textDelta: 'Response' }
|
||||||
|
})()
|
||||||
|
}
|
||||||
|
|
||||||
|
vi.mocked(streamText).mockResolvedValue(mockStream as any)
|
||||||
|
|
||||||
|
await executor.streamText({
|
||||||
|
model: 'gpt-4',
|
||||||
|
messages: testMessages.simple,
|
||||||
|
topP
|
||||||
|
})
|
||||||
|
|
||||||
|
expect(streamText).toHaveBeenCalledWith(
|
||||||
|
expect.objectContaining({
|
||||||
|
topP
|
||||||
|
})
|
||||||
|
)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('Frequency and Presence Penalty', () => {
|
||||||
|
it('should support frequency penalty', async () => {
|
||||||
|
const penalties = [-2.0, -1.0, 0, 0.5, 1.0, 1.5, 2.0]
|
||||||
|
|
||||||
|
for (const frequencyPenalty of penalties) {
|
||||||
|
vi.clearAllMocks()
|
||||||
|
|
||||||
|
const mockStream = {
|
||||||
|
textStream: (async function* () {
|
||||||
|
yield 'Response'
|
||||||
|
})(),
|
||||||
|
fullStream: (async function* () {
|
||||||
|
yield { type: 'text-delta', textDelta: 'Response' }
|
||||||
|
})()
|
||||||
|
}
|
||||||
|
|
||||||
|
vi.mocked(streamText).mockResolvedValue(mockStream as any)
|
||||||
|
|
||||||
|
await executor.streamText({
|
||||||
|
model: 'gpt-4',
|
||||||
|
messages: testMessages.simple,
|
||||||
|
frequencyPenalty
|
||||||
|
})
|
||||||
|
|
||||||
|
expect(streamText).toHaveBeenCalledWith(
|
||||||
|
expect.objectContaining({
|
||||||
|
frequencyPenalty
|
||||||
|
})
|
||||||
|
)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should support presence penalty', async () => {
|
||||||
|
const penalties = [-2.0, -1.0, 0, 0.5, 1.0, 1.5, 2.0]
|
||||||
|
|
||||||
|
for (const presencePenalty of penalties) {
|
||||||
|
vi.clearAllMocks()
|
||||||
|
|
||||||
|
const mockStream = {
|
||||||
|
textStream: (async function* () {
|
||||||
|
yield 'Response'
|
||||||
|
})(),
|
||||||
|
fullStream: (async function* () {
|
||||||
|
yield { type: 'text-delta', textDelta: 'Response' }
|
||||||
|
})()
|
||||||
|
}
|
||||||
|
|
||||||
|
vi.mocked(streamText).mockResolvedValue(mockStream as any)
|
||||||
|
|
||||||
|
await executor.streamText({
|
||||||
|
model: 'gpt-4',
|
||||||
|
messages: testMessages.simple,
|
||||||
|
presencePenalty
|
||||||
|
})
|
||||||
|
|
||||||
|
expect(streamText).toHaveBeenCalledWith(
|
||||||
|
expect.objectContaining({
|
||||||
|
presencePenalty
|
||||||
|
})
|
||||||
|
)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should support both penalties together', async () => {
|
||||||
|
const mockStream = {
|
||||||
|
textStream: (async function* () {
|
||||||
|
yield 'Response'
|
||||||
|
})(),
|
||||||
|
fullStream: (async function* () {
|
||||||
|
yield { type: 'text-delta', textDelta: 'Response' }
|
||||||
|
})()
|
||||||
|
}
|
||||||
|
|
||||||
|
vi.mocked(streamText).mockResolvedValue(mockStream as any)
|
||||||
|
|
||||||
|
await executor.streamText({
|
||||||
|
model: 'gpt-4',
|
||||||
|
messages: testMessages.simple,
|
||||||
|
frequencyPenalty: 0.5,
|
||||||
|
presencePenalty: 0.5
|
||||||
|
})
|
||||||
|
|
||||||
|
expect(streamText).toHaveBeenCalledWith(
|
||||||
|
expect.objectContaining({
|
||||||
|
frequencyPenalty: 0.5,
|
||||||
|
presencePenalty: 0.5
|
||||||
|
})
|
||||||
|
)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('Seed Parameter', () => {
|
||||||
|
it('should support seed for deterministic output', async () => {
|
||||||
|
const seeds = [0, 12345, 67890, 999999]
|
||||||
|
|
||||||
|
for (const seed of seeds) {
|
||||||
|
vi.clearAllMocks()
|
||||||
|
|
||||||
|
const mockStream = {
|
||||||
|
textStream: (async function* () {
|
||||||
|
yield 'Response'
|
||||||
|
})(),
|
||||||
|
fullStream: (async function* () {
|
||||||
|
yield { type: 'text-delta', textDelta: 'Response' }
|
||||||
|
})()
|
||||||
|
}
|
||||||
|
|
||||||
|
vi.mocked(streamText).mockResolvedValue(mockStream as any)
|
||||||
|
|
||||||
|
await executor.streamText({
|
||||||
|
model: 'gpt-4',
|
||||||
|
messages: testMessages.simple,
|
||||||
|
seed
|
||||||
|
})
|
||||||
|
|
||||||
|
expect(streamText).toHaveBeenCalledWith(
|
||||||
|
expect.objectContaining({
|
||||||
|
seed
|
||||||
|
})
|
||||||
|
)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('Abort Signal', () => {
|
||||||
|
it('should support abort signal', async () => {
|
||||||
|
const abortController = new AbortController()
|
||||||
|
|
||||||
|
const mockStream = {
|
||||||
|
textStream: (async function* () {
|
||||||
|
yield 'Response'
|
||||||
|
})(),
|
||||||
|
fullStream: (async function* () {
|
||||||
|
yield { type: 'text-delta', textDelta: 'Response' }
|
||||||
|
})()
|
||||||
|
}
|
||||||
|
|
||||||
|
vi.mocked(streamText).mockResolvedValue(mockStream as any)
|
||||||
|
|
||||||
|
await executor.streamText({
|
||||||
|
model: 'gpt-4',
|
||||||
|
messages: testMessages.simple,
|
||||||
|
abortSignal: abortController.signal
|
||||||
|
})
|
||||||
|
|
||||||
|
expect(streamText).toHaveBeenCalledWith(
|
||||||
|
expect.objectContaining({
|
||||||
|
abortSignal: abortController.signal
|
||||||
|
})
|
||||||
|
)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should handle abort during streaming', async () => {
|
||||||
|
const abortController = new AbortController()
|
||||||
|
|
||||||
|
const mockStream = {
|
||||||
|
textStream: (async function* () {
|
||||||
|
yield 'Start'
|
||||||
|
// Simulate abort
|
||||||
|
abortController.abort()
|
||||||
|
throw new Error('Aborted')
|
||||||
|
})(),
|
||||||
|
fullStream: (async function* () {
|
||||||
|
yield { type: 'text-delta', textDelta: 'Start' }
|
||||||
|
throw new Error('Aborted')
|
||||||
|
})()
|
||||||
|
}
|
||||||
|
|
||||||
|
vi.mocked(streamText).mockResolvedValue(mockStream as any)
|
||||||
|
|
||||||
|
const result = await executor.streamText({
|
||||||
|
model: 'gpt-4',
|
||||||
|
messages: testMessages.simple,
|
||||||
|
abortSignal: abortController.signal
|
||||||
|
})
|
||||||
|
|
||||||
|
await expect(async () => {
|
||||||
|
// oxlint-disable-next-line no-unused-vars
|
||||||
|
for await (const _chunk of result.textStream) {
|
||||||
|
// Stream should be interrupted
|
||||||
|
}
|
||||||
|
}).rejects.toThrow('Aborted')
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('Plugin Integration', () => {
|
||||||
|
it('should execute plugins during streaming', async () => {
|
||||||
|
const pluginCalls: string[] = []
|
||||||
|
|
||||||
|
const testPlugin: AiPlugin = {
|
||||||
|
name: 'test-plugin',
|
||||||
|
onRequestStart: vi.fn(async () => {
|
||||||
|
pluginCalls.push('onRequestStart')
|
||||||
|
}),
|
||||||
|
transformParams: vi.fn(async (params) => {
|
||||||
|
pluginCalls.push('transformParams')
|
||||||
|
return { ...params, temperature: 0.5 }
|
||||||
|
}),
|
||||||
|
onRequestEnd: vi.fn(async () => {
|
||||||
|
pluginCalls.push('onRequestEnd')
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
const executorWithPlugin = RuntimeExecutor.create('openai', mockProviderConfigs.openai, [testPlugin])
|
||||||
|
|
||||||
|
const mockStream = {
|
||||||
|
textStream: (async function* () {
|
||||||
|
yield 'Response'
|
||||||
|
})(),
|
||||||
|
fullStream: (async function* () {
|
||||||
|
yield { type: 'text-delta', textDelta: 'Response' }
|
||||||
|
})()
|
||||||
|
}
|
||||||
|
|
||||||
|
vi.mocked(streamText).mockResolvedValue(mockStream as any)
|
||||||
|
|
||||||
|
const result = await executorWithPlugin.streamText({
|
||||||
|
model: 'gpt-4',
|
||||||
|
messages: testMessages.simple
|
||||||
|
})
|
||||||
|
|
||||||
|
// Consume stream
|
||||||
|
// oxlint-disable-next-line no-unused-vars
|
||||||
|
for await (const _chunk of result.textStream) {
|
||||||
|
// Stream chunks
|
||||||
|
}
|
||||||
|
|
||||||
|
expect(pluginCalls).toContain('onRequestStart')
|
||||||
|
expect(pluginCalls).toContain('transformParams')
|
||||||
|
|
||||||
|
// Verify transformed parameters were used
|
||||||
|
expect(streamText).toHaveBeenCalledWith(
|
||||||
|
expect.objectContaining({
|
||||||
|
temperature: 0.5
|
||||||
|
})
|
||||||
|
)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('Full Stream with Finish Reason', () => {
|
||||||
|
it('should provide finish reason in full stream', async () => {
|
||||||
|
const mockStream = {
|
||||||
|
textStream: (async function* () {
|
||||||
|
yield 'Response'
|
||||||
|
})(),
|
||||||
|
fullStream: (async function* () {
|
||||||
|
yield { type: 'text-delta', textDelta: 'Response' }
|
||||||
|
yield {
|
||||||
|
type: 'finish',
|
||||||
|
finishReason: 'stop',
|
||||||
|
usage: { promptTokens: 5, completionTokens: 3, totalTokens: 8 }
|
||||||
|
}
|
||||||
|
})()
|
||||||
|
}
|
||||||
|
|
||||||
|
vi.mocked(streamText).mockResolvedValue(mockStream as any)
|
||||||
|
|
||||||
|
const result = await executor.streamText({
|
||||||
|
model: 'gpt-4',
|
||||||
|
messages: testMessages.simple
|
||||||
|
})
|
||||||
|
|
||||||
|
const fullChunks = await collectStreamChunks(result.fullStream)
|
||||||
|
|
||||||
|
expect(fullChunks).toHaveLength(2)
|
||||||
|
expect(fullChunks[0]).toEqual({ type: 'text-delta', textDelta: 'Response' })
|
||||||
|
expect(fullChunks[1]).toEqual({
|
||||||
|
type: 'finish',
|
||||||
|
finishReason: 'stop',
|
||||||
|
usage: { promptTokens: 5, completionTokens: 3, totalTokens: 8 }
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('Error Handling', () => {
|
||||||
|
it('should handle streaming errors', async () => {
|
||||||
|
const error = new Error('Streaming failed')
|
||||||
|
vi.mocked(streamText).mockRejectedValue(error)
|
||||||
|
|
||||||
|
await expect(
|
||||||
|
executor.streamText({
|
||||||
|
model: 'gpt-4',
|
||||||
|
messages: testMessages.simple
|
||||||
|
})
|
||||||
|
).rejects.toThrow('Streaming failed')
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should execute onError plugin hook on failure', async () => {
|
||||||
|
const error = new Error('Stream error')
|
||||||
|
vi.mocked(streamText).mockRejectedValue(error)
|
||||||
|
|
||||||
|
const errorPlugin: AiPlugin = {
|
||||||
|
name: 'error-handler',
|
||||||
|
onError: vi.fn()
|
||||||
|
}
|
||||||
|
|
||||||
|
const executorWithPlugin = RuntimeExecutor.create('openai', mockProviderConfigs.openai, [errorPlugin])
|
||||||
|
|
||||||
|
await expect(
|
||||||
|
executorWithPlugin.streamText({
|
||||||
|
model: 'gpt-4',
|
||||||
|
messages: testMessages.simple
|
||||||
|
})
|
||||||
|
).rejects.toThrow('Stream error')
|
||||||
|
|
||||||
|
expect(errorPlugin.onError).toHaveBeenCalledWith(
|
||||||
|
error,
|
||||||
|
expect.objectContaining({
|
||||||
|
providerId: 'openai',
|
||||||
|
modelId: 'gpt-4'
|
||||||
|
})
|
||||||
|
)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
@@ -41,6 +41,7 @@ export enum IpcChannel {
|
|||||||
App_SetFullScreen = 'app:set-full-screen',
|
App_SetFullScreen = 'app:set-full-screen',
|
||||||
App_IsFullScreen = 'app:is-full-screen',
|
App_IsFullScreen = 'app:is-full-screen',
|
||||||
App_GetSystemFonts = 'app:get-system-fonts',
|
App_GetSystemFonts = 'app:get-system-fonts',
|
||||||
|
APP_CrashRenderProcess = 'app:crash-render-process',
|
||||||
|
|
||||||
App_MacIsProcessTrusted = 'app:mac-is-process-trusted',
|
App_MacIsProcessTrusted = 'app:mac-is-process-trusted',
|
||||||
App_MacRequestProcessTrust = 'app:mac-request-process-trust',
|
App_MacRequestProcessTrust = 'app:mac-request-process-trust',
|
||||||
@@ -234,6 +235,7 @@ export enum IpcChannel {
|
|||||||
System_GetDeviceType = 'system:getDeviceType',
|
System_GetDeviceType = 'system:getDeviceType',
|
||||||
System_GetHostname = 'system:getHostname',
|
System_GetHostname = 'system:getHostname',
|
||||||
System_GetCpuName = 'system:getCpuName',
|
System_GetCpuName = 'system:getCpuName',
|
||||||
|
System_CheckGitBash = 'system:checkGitBash',
|
||||||
|
|
||||||
// DevTools
|
// DevTools
|
||||||
System_ToggleDevTools = 'system:toggleDevTools',
|
System_ToggleDevTools = 'system:toggleDevTools',
|
||||||
|
|||||||
@@ -197,12 +197,22 @@ export enum FeedUrl {
|
|||||||
GITHUB_LATEST = 'https://github.com/CherryHQ/cherry-studio/releases/latest/download'
|
GITHUB_LATEST = 'https://github.com/CherryHQ/cherry-studio/releases/latest/download'
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export enum UpdateConfigUrl {
|
||||||
|
GITHUB = 'https://raw.githubusercontent.com/CherryHQ/cherry-studio/refs/heads/x-files/app-upgrade-config/app-upgrade-config.json',
|
||||||
|
GITCODE = 'https://raw.gitcode.com/CherryHQ/cherry-studio/raw/x-files%2Fapp-upgrade-config/app-upgrade-config.json'
|
||||||
|
}
|
||||||
|
|
||||||
export enum UpgradeChannel {
|
export enum UpgradeChannel {
|
||||||
LATEST = 'latest', // 最新稳定版本
|
LATEST = 'latest', // 最新稳定版本
|
||||||
RC = 'rc', // 公测版本
|
RC = 'rc', // 公测版本
|
||||||
BETA = 'beta' // 预览版本
|
BETA = 'beta' // 预览版本
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export enum UpdateMirror {
|
||||||
|
GITHUB = 'github',
|
||||||
|
GITCODE = 'gitcode'
|
||||||
|
}
|
||||||
|
|
||||||
export const defaultTimeout = 10 * 1000 * 60
|
export const defaultTimeout = 10 * 1000 * 60
|
||||||
|
|
||||||
export const occupiedDirs = ['logs', 'Network', 'Partitions/webview/Network']
|
export const occupiedDirs = ['logs', 'Network', 'Partitions/webview/Network']
|
||||||
|
|||||||
@@ -4,3 +4,34 @@ export const defaultAppHeaders = () => {
|
|||||||
'X-Title': 'Cherry Studio'
|
'X-Title': 'Cherry Studio'
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Following two function are not being used for now.
|
||||||
|
// I may use them in the future, so just keep them commented. - by eurfelux
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Converts an `undefined` value to `null`, otherwise returns the value as-is.
|
||||||
|
* @param value - The value to check
|
||||||
|
* @returns `null` if the input is `undefined`; otherwise the input value
|
||||||
|
*/
|
||||||
|
|
||||||
|
// export function toNullIfUndefined<T>(value: T | undefined): T | null {
|
||||||
|
// if (value === undefined) {
|
||||||
|
// return null
|
||||||
|
// } else {
|
||||||
|
// return value
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Converts a `null` value to `undefined`, otherwise returns the value as-is.
|
||||||
|
* @param value - The value to check
|
||||||
|
* @returns `undefined` if the input is `null`; otherwise the input value
|
||||||
|
*/
|
||||||
|
|
||||||
|
// export function toUndefinedIfNull<T>(value: T | null): T | undefined {
|
||||||
|
// if (value === null) {
|
||||||
|
// return undefined
|
||||||
|
// } else {
|
||||||
|
// return value
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
|||||||
532
scripts/update-app-upgrade-config.ts
Normal file
532
scripts/update-app-upgrade-config.ts
Normal file
@@ -0,0 +1,532 @@
|
|||||||
|
import fs from 'fs/promises'
|
||||||
|
import path from 'path'
|
||||||
|
import semver from 'semver'
|
||||||
|
|
||||||
|
type UpgradeChannel = 'latest' | 'rc' | 'beta'
|
||||||
|
type UpdateMirror = 'github' | 'gitcode'
|
||||||
|
|
||||||
|
const CHANNELS: UpgradeChannel[] = ['latest', 'rc', 'beta']
|
||||||
|
const MIRRORS: UpdateMirror[] = ['github', 'gitcode']
|
||||||
|
const GITHUB_REPO = 'CherryHQ/cherry-studio'
|
||||||
|
const GITCODE_REPO = 'CherryHQ/cherry-studio'
|
||||||
|
const DEFAULT_FEED_TEMPLATES: Record<UpdateMirror, string> = {
|
||||||
|
github: `https://github.com/${GITHUB_REPO}/releases/download/{{tag}}`,
|
||||||
|
gitcode: `https://gitcode.com/${GITCODE_REPO}/releases/download/{{tag}}`
|
||||||
|
}
|
||||||
|
const GITCODE_LATEST_FALLBACK = 'https://releases.cherry-ai.com'
|
||||||
|
|
||||||
|
interface CliOptions {
|
||||||
|
tag?: string
|
||||||
|
configPath?: string
|
||||||
|
segmentsPath?: string
|
||||||
|
dryRun?: boolean
|
||||||
|
skipReleaseChecks?: boolean
|
||||||
|
isPrerelease?: boolean
|
||||||
|
}
|
||||||
|
|
||||||
|
interface ChannelTemplateConfig {
|
||||||
|
feedTemplates?: Partial<Record<UpdateMirror, string>>
|
||||||
|
}
|
||||||
|
|
||||||
|
interface SegmentMatchRule {
|
||||||
|
range?: string
|
||||||
|
exact?: string[]
|
||||||
|
excludeExact?: string[]
|
||||||
|
}
|
||||||
|
|
||||||
|
interface SegmentDefinition {
|
||||||
|
id: string
|
||||||
|
type: 'legacy' | 'breaking' | 'latest'
|
||||||
|
match: SegmentMatchRule
|
||||||
|
lockedVersion?: string
|
||||||
|
minCompatibleVersion: string
|
||||||
|
description: string
|
||||||
|
channelTemplates?: Partial<Record<UpgradeChannel, ChannelTemplateConfig>>
|
||||||
|
}
|
||||||
|
|
||||||
|
interface SegmentMetadataFile {
|
||||||
|
segments: SegmentDefinition[]
|
||||||
|
}
|
||||||
|
|
||||||
|
interface ChannelConfig {
|
||||||
|
version: string
|
||||||
|
feedUrls: Record<UpdateMirror, string>
|
||||||
|
}
|
||||||
|
|
||||||
|
interface VersionMetadata {
|
||||||
|
segmentId: string
|
||||||
|
segmentType?: string
|
||||||
|
}
|
||||||
|
|
||||||
|
interface VersionEntry {
|
||||||
|
metadata?: VersionMetadata
|
||||||
|
minCompatibleVersion: string
|
||||||
|
description: string
|
||||||
|
channels: Record<UpgradeChannel, ChannelConfig | null>
|
||||||
|
}
|
||||||
|
|
||||||
|
interface UpgradeConfigFile {
|
||||||
|
lastUpdated: string
|
||||||
|
versions: Record<string, VersionEntry>
|
||||||
|
}
|
||||||
|
|
||||||
|
interface ReleaseInfo {
|
||||||
|
tag: string
|
||||||
|
version: string
|
||||||
|
channel: UpgradeChannel
|
||||||
|
}
|
||||||
|
|
||||||
|
interface UpdateVersionsResult {
|
||||||
|
versions: Record<string, VersionEntry>
|
||||||
|
updated: boolean
|
||||||
|
}
|
||||||
|
|
||||||
|
const ROOT_DIR = path.resolve(__dirname, '..')
|
||||||
|
const DEFAULT_CONFIG_PATH = path.join(ROOT_DIR, 'app-upgrade-config.json')
|
||||||
|
const DEFAULT_SEGMENTS_PATH = path.join(ROOT_DIR, 'config/app-upgrade-segments.json')
|
||||||
|
|
||||||
|
async function main() {
|
||||||
|
const options = parseArgs()
|
||||||
|
const releaseTag = resolveTag(options)
|
||||||
|
const normalizedVersion = normalizeVersion(releaseTag)
|
||||||
|
const releaseChannel = detectChannel(normalizedVersion)
|
||||||
|
if (!releaseChannel) {
|
||||||
|
console.warn(`[update-app-upgrade-config] Tag ${normalizedVersion} does not map to beta/rc/latest. Skipping.`)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate version format matches prerelease status
|
||||||
|
if (options.isPrerelease !== undefined) {
|
||||||
|
const hasPrereleaseSuffix = releaseChannel === 'beta' || releaseChannel === 'rc'
|
||||||
|
|
||||||
|
if (options.isPrerelease && !hasPrereleaseSuffix) {
|
||||||
|
console.warn(
|
||||||
|
`[update-app-upgrade-config] ⚠️ Release marked as prerelease but version ${normalizedVersion} has no beta/rc suffix. Skipping.`
|
||||||
|
)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!options.isPrerelease && hasPrereleaseSuffix) {
|
||||||
|
console.warn(
|
||||||
|
`[update-app-upgrade-config] ⚠️ Release marked as latest but version ${normalizedVersion} has prerelease suffix (${releaseChannel}). Skipping.`
|
||||||
|
)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const [config, segmentFile] = await Promise.all([
|
||||||
|
readJson<UpgradeConfigFile>(options.configPath ?? DEFAULT_CONFIG_PATH),
|
||||||
|
readJson<SegmentMetadataFile>(options.segmentsPath ?? DEFAULT_SEGMENTS_PATH)
|
||||||
|
])
|
||||||
|
|
||||||
|
const segment = pickSegment(segmentFile.segments, normalizedVersion)
|
||||||
|
if (!segment) {
|
||||||
|
throw new Error(`Unable to find upgrade segment for version ${normalizedVersion}`)
|
||||||
|
}
|
||||||
|
|
||||||
|
if (segment.lockedVersion && segment.lockedVersion !== normalizedVersion) {
|
||||||
|
throw new Error(`Segment ${segment.id} is locked to ${segment.lockedVersion}, but received ${normalizedVersion}`)
|
||||||
|
}
|
||||||
|
|
||||||
|
const releaseInfo: ReleaseInfo = {
|
||||||
|
tag: formatTag(releaseTag),
|
||||||
|
version: normalizedVersion,
|
||||||
|
channel: releaseChannel
|
||||||
|
}
|
||||||
|
|
||||||
|
const { versions: updatedVersions, updated } = await updateVersions(
|
||||||
|
config.versions,
|
||||||
|
segment,
|
||||||
|
releaseInfo,
|
||||||
|
Boolean(options.skipReleaseChecks)
|
||||||
|
)
|
||||||
|
|
||||||
|
if (!updated) {
|
||||||
|
throw new Error(
|
||||||
|
`[update-app-upgrade-config] Feed URLs are not ready for ${releaseInfo.version} (${releaseInfo.channel}). Try again after the release mirrors finish syncing.`
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
const updatedConfig: UpgradeConfigFile = {
|
||||||
|
...config,
|
||||||
|
lastUpdated: new Date().toISOString(),
|
||||||
|
versions: updatedVersions
|
||||||
|
}
|
||||||
|
|
||||||
|
const output = JSON.stringify(updatedConfig, null, 2) + '\n'
|
||||||
|
|
||||||
|
if (options.dryRun) {
|
||||||
|
console.log('Dry run enabled. Generated configuration:\n')
|
||||||
|
console.log(output)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
await fs.writeFile(options.configPath ?? DEFAULT_CONFIG_PATH, output, 'utf-8')
|
||||||
|
console.log(
|
||||||
|
`✅ Updated ${path.relative(process.cwd(), options.configPath ?? DEFAULT_CONFIG_PATH)} for ${segment.id} (${releaseInfo.channel}) -> ${releaseInfo.version}`
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
function parseArgs(): CliOptions {
|
||||||
|
const args = process.argv.slice(2)
|
||||||
|
const options: CliOptions = {}
|
||||||
|
|
||||||
|
for (let i = 0; i < args.length; i += 1) {
|
||||||
|
const arg = args[i]
|
||||||
|
if (arg === '--tag') {
|
||||||
|
options.tag = args[i + 1]
|
||||||
|
i += 1
|
||||||
|
} else if (arg === '--config') {
|
||||||
|
options.configPath = args[i + 1]
|
||||||
|
i += 1
|
||||||
|
} else if (arg === '--segments') {
|
||||||
|
options.segmentsPath = args[i + 1]
|
||||||
|
i += 1
|
||||||
|
} else if (arg === '--dry-run') {
|
||||||
|
options.dryRun = true
|
||||||
|
} else if (arg === '--skip-release-checks') {
|
||||||
|
options.skipReleaseChecks = true
|
||||||
|
} else if (arg === '--is-prerelease') {
|
||||||
|
options.isPrerelease = args[i + 1] === 'true'
|
||||||
|
i += 1
|
||||||
|
} else if (arg === '--help') {
|
||||||
|
printHelp()
|
||||||
|
process.exit(0)
|
||||||
|
} else {
|
||||||
|
console.warn(`Ignoring unknown argument "${arg}"`)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (options.skipReleaseChecks && !options.dryRun) {
|
||||||
|
throw new Error('--skip-release-checks can only be used together with --dry-run')
|
||||||
|
}
|
||||||
|
|
||||||
|
return options
|
||||||
|
}
|
||||||
|
|
||||||
|
function printHelp() {
|
||||||
|
console.log(`Usage: tsx scripts/update-app-upgrade-config.ts [options]
|
||||||
|
|
||||||
|
Options:
|
||||||
|
--tag <tag> Release tag (e.g. v2.1.6). Falls back to GITHUB_REF_NAME/RELEASE_TAG.
|
||||||
|
--config <path> Path to app-upgrade-config.json.
|
||||||
|
--segments <path> Path to app-upgrade-segments.json.
|
||||||
|
--is-prerelease <true|false> Whether this is a prerelease (validates version format).
|
||||||
|
--dry-run Print the result without writing to disk.
|
||||||
|
--skip-release-checks Skip release page availability checks (only valid with --dry-run).
|
||||||
|
--help Show this help message.`)
|
||||||
|
}
|
||||||
|
|
||||||
|
function resolveTag(options: CliOptions): string {
|
||||||
|
const envTag = process.env.RELEASE_TAG ?? process.env.GITHUB_REF_NAME ?? process.env.TAG_NAME
|
||||||
|
const tag = options.tag ?? envTag
|
||||||
|
|
||||||
|
if (!tag) {
|
||||||
|
throw new Error('A release tag is required. Pass --tag or set RELEASE_TAG/GITHUB_REF_NAME.')
|
||||||
|
}
|
||||||
|
|
||||||
|
return tag
|
||||||
|
}
|
||||||
|
|
||||||
|
function normalizeVersion(tag: string): string {
|
||||||
|
const cleaned = semver.clean(tag, { loose: true })
|
||||||
|
if (!cleaned) {
|
||||||
|
throw new Error(`Tag "${tag}" is not a valid semantic version`)
|
||||||
|
}
|
||||||
|
|
||||||
|
const valid = semver.valid(cleaned, { loose: true })
|
||||||
|
if (!valid) {
|
||||||
|
throw new Error(`Unable to normalize tag "${tag}" to a valid semantic version`)
|
||||||
|
}
|
||||||
|
|
||||||
|
return valid
|
||||||
|
}
|
||||||
|
|
||||||
|
function detectChannel(version: string): UpgradeChannel | null {
|
||||||
|
const parsed = semver.parse(version, { loose: true, includePrerelease: true })
|
||||||
|
if (!parsed) {
|
||||||
|
return null
|
||||||
|
}
|
||||||
|
|
||||||
|
if (parsed.prerelease.length === 0) {
|
||||||
|
return 'latest'
|
||||||
|
}
|
||||||
|
|
||||||
|
const label = String(parsed.prerelease[0]).toLowerCase()
|
||||||
|
if (label === 'beta') {
|
||||||
|
return 'beta'
|
||||||
|
}
|
||||||
|
if (label === 'rc') {
|
||||||
|
return 'rc'
|
||||||
|
}
|
||||||
|
|
||||||
|
return null
|
||||||
|
}
|
||||||
|
|
||||||
|
async function readJson<T>(filePath: string): Promise<T> {
|
||||||
|
const absolute = path.isAbsolute(filePath) ? filePath : path.resolve(filePath)
|
||||||
|
const data = await fs.readFile(absolute, 'utf-8')
|
||||||
|
return JSON.parse(data) as T
|
||||||
|
}
|
||||||
|
|
||||||
|
function pickSegment(segments: SegmentDefinition[], version: string): SegmentDefinition | null {
|
||||||
|
for (const segment of segments) {
|
||||||
|
if (matchesSegment(segment.match, version)) {
|
||||||
|
return segment
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return null
|
||||||
|
}
|
||||||
|
|
||||||
|
function matchesSegment(matchRule: SegmentMatchRule, version: string): boolean {
|
||||||
|
if (matchRule.exact && matchRule.exact.includes(version)) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
if (matchRule.excludeExact && matchRule.excludeExact.includes(version)) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if (matchRule.range && !semver.satisfies(version, matchRule.range, { includePrerelease: true })) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if (matchRule.exact) {
|
||||||
|
return matchRule.exact.includes(version)
|
||||||
|
}
|
||||||
|
|
||||||
|
return Boolean(matchRule.range)
|
||||||
|
}
|
||||||
|
|
||||||
|
function formatTag(tag: string): string {
|
||||||
|
if (tag.startsWith('refs/tags/')) {
|
||||||
|
return tag.replace('refs/tags/', '')
|
||||||
|
}
|
||||||
|
return tag
|
||||||
|
}
|
||||||
|
|
||||||
|
async function updateVersions(
|
||||||
|
versions: Record<string, VersionEntry>,
|
||||||
|
segment: SegmentDefinition,
|
||||||
|
releaseInfo: ReleaseInfo,
|
||||||
|
skipReleaseValidation: boolean
|
||||||
|
): Promise<UpdateVersionsResult> {
|
||||||
|
const versionsCopy: Record<string, VersionEntry> = { ...versions }
|
||||||
|
const existingKey = findVersionKeyBySegment(versionsCopy, segment.id)
|
||||||
|
const targetKey = resolveVersionKey(existingKey, segment, releaseInfo)
|
||||||
|
const shouldRename = existingKey && existingKey !== targetKey
|
||||||
|
|
||||||
|
let entry: VersionEntry
|
||||||
|
if (existingKey) {
|
||||||
|
entry = { ...versionsCopy[existingKey], channels: { ...versionsCopy[existingKey].channels } }
|
||||||
|
} else {
|
||||||
|
entry = createEmptyVersionEntry()
|
||||||
|
}
|
||||||
|
|
||||||
|
entry.channels = ensureChannelSlots(entry.channels)
|
||||||
|
|
||||||
|
const channelUpdated = await applyChannelUpdate(entry, segment, releaseInfo, skipReleaseValidation)
|
||||||
|
if (!channelUpdated) {
|
||||||
|
return { versions, updated: false }
|
||||||
|
}
|
||||||
|
|
||||||
|
if (shouldRename && existingKey) {
|
||||||
|
delete versionsCopy[existingKey]
|
||||||
|
}
|
||||||
|
|
||||||
|
entry.metadata = {
|
||||||
|
segmentId: segment.id,
|
||||||
|
segmentType: segment.type
|
||||||
|
}
|
||||||
|
entry.minCompatibleVersion = segment.minCompatibleVersion
|
||||||
|
entry.description = segment.description
|
||||||
|
|
||||||
|
versionsCopy[targetKey] = entry
|
||||||
|
return {
|
||||||
|
versions: sortVersionMap(versionsCopy),
|
||||||
|
updated: true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
function findVersionKeyBySegment(versions: Record<string, VersionEntry>, segmentId: string): string | null {
|
||||||
|
for (const [key, value] of Object.entries(versions)) {
|
||||||
|
if (value.metadata?.segmentId === segmentId) {
|
||||||
|
return key
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return null
|
||||||
|
}
|
||||||
|
|
||||||
|
function resolveVersionKey(existingKey: string | null, segment: SegmentDefinition, releaseInfo: ReleaseInfo): string {
|
||||||
|
if (segment.lockedVersion) {
|
||||||
|
return segment.lockedVersion
|
||||||
|
}
|
||||||
|
|
||||||
|
if (releaseInfo.channel === 'latest') {
|
||||||
|
return releaseInfo.version
|
||||||
|
}
|
||||||
|
|
||||||
|
if (existingKey) {
|
||||||
|
return existingKey
|
||||||
|
}
|
||||||
|
|
||||||
|
const baseVersion = getBaseVersion(releaseInfo.version)
|
||||||
|
return baseVersion ?? releaseInfo.version
|
||||||
|
}
|
||||||
|
|
||||||
|
function getBaseVersion(version: string): string | null {
|
||||||
|
const parsed = semver.parse(version, { loose: true, includePrerelease: true })
|
||||||
|
if (!parsed) {
|
||||||
|
return null
|
||||||
|
}
|
||||||
|
return `${parsed.major}.${parsed.minor}.${parsed.patch}`
|
||||||
|
}
|
||||||
|
|
||||||
|
function createEmptyVersionEntry(): VersionEntry {
|
||||||
|
return {
|
||||||
|
minCompatibleVersion: '',
|
||||||
|
description: '',
|
||||||
|
channels: {
|
||||||
|
latest: null,
|
||||||
|
rc: null,
|
||||||
|
beta: null
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
function ensureChannelSlots(
|
||||||
|
channels: Record<UpgradeChannel, ChannelConfig | null>
|
||||||
|
): Record<UpgradeChannel, ChannelConfig | null> {
|
||||||
|
return CHANNELS.reduce(
|
||||||
|
(acc, channel) => {
|
||||||
|
acc[channel] = channels[channel] ?? null
|
||||||
|
return acc
|
||||||
|
},
|
||||||
|
{} as Record<UpgradeChannel, ChannelConfig | null>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
async function applyChannelUpdate(
|
||||||
|
entry: VersionEntry,
|
||||||
|
segment: SegmentDefinition,
|
||||||
|
releaseInfo: ReleaseInfo,
|
||||||
|
skipReleaseValidation: boolean
|
||||||
|
): Promise<boolean> {
|
||||||
|
if (!CHANNELS.includes(releaseInfo.channel)) {
|
||||||
|
throw new Error(`Unsupported channel "${releaseInfo.channel}"`)
|
||||||
|
}
|
||||||
|
|
||||||
|
const feedUrls = buildFeedUrls(segment, releaseInfo)
|
||||||
|
|
||||||
|
if (skipReleaseValidation) {
|
||||||
|
console.warn(
|
||||||
|
`[update-app-upgrade-config] Skipping release availability validation for ${releaseInfo.version} (${releaseInfo.channel}).`
|
||||||
|
)
|
||||||
|
} else {
|
||||||
|
const availability = await ensureReleaseAvailability(releaseInfo)
|
||||||
|
if (!availability.github) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if (releaseInfo.channel === 'latest' && !availability.gitcode) {
|
||||||
|
console.warn(
|
||||||
|
`[update-app-upgrade-config] gitcode release page not ready for ${releaseInfo.tag}. Falling back to ${GITCODE_LATEST_FALLBACK}.`
|
||||||
|
)
|
||||||
|
feedUrls.gitcode = GITCODE_LATEST_FALLBACK
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
entry.channels[releaseInfo.channel] = {
|
||||||
|
version: releaseInfo.version,
|
||||||
|
feedUrls
|
||||||
|
}
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
function buildFeedUrls(segment: SegmentDefinition, releaseInfo: ReleaseInfo): Record<UpdateMirror, string> {
|
||||||
|
return MIRRORS.reduce(
|
||||||
|
(acc, mirror) => {
|
||||||
|
const template = resolveFeedTemplate(segment, releaseInfo, mirror)
|
||||||
|
acc[mirror] = applyTemplate(template, releaseInfo)
|
||||||
|
return acc
|
||||||
|
},
|
||||||
|
{} as Record<UpdateMirror, string>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
function resolveFeedTemplate(segment: SegmentDefinition, releaseInfo: ReleaseInfo, mirror: UpdateMirror): string {
|
||||||
|
if (mirror === 'gitcode' && releaseInfo.channel !== 'latest') {
|
||||||
|
return segment.channelTemplates?.[releaseInfo.channel]?.feedTemplates?.github ?? DEFAULT_FEED_TEMPLATES.github
|
||||||
|
}
|
||||||
|
|
||||||
|
return segment.channelTemplates?.[releaseInfo.channel]?.feedTemplates?.[mirror] ?? DEFAULT_FEED_TEMPLATES[mirror]
|
||||||
|
}
|
||||||
|
|
||||||
|
function applyTemplate(template: string, releaseInfo: ReleaseInfo): string {
|
||||||
|
return template.replace(/{{\s*tag\s*}}/gi, releaseInfo.tag).replace(/{{\s*version\s*}}/gi, releaseInfo.version)
|
||||||
|
}
|
||||||
|
|
||||||
|
function sortVersionMap(versions: Record<string, VersionEntry>): Record<string, VersionEntry> {
|
||||||
|
const sorted = Object.entries(versions).sort(([a], [b]) => semver.rcompare(a, b))
|
||||||
|
return sorted.reduce(
|
||||||
|
(acc, [version, entry]) => {
|
||||||
|
acc[version] = entry
|
||||||
|
return acc
|
||||||
|
},
|
||||||
|
{} as Record<string, VersionEntry>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
interface ReleaseAvailability {
|
||||||
|
github: boolean
|
||||||
|
gitcode: boolean
|
||||||
|
}
|
||||||
|
|
||||||
|
async function ensureReleaseAvailability(releaseInfo: ReleaseInfo): Promise<ReleaseAvailability> {
|
||||||
|
const mirrorsToCheck: UpdateMirror[] = releaseInfo.channel === 'latest' ? MIRRORS : ['github']
|
||||||
|
const availability: ReleaseAvailability = {
|
||||||
|
github: false,
|
||||||
|
gitcode: releaseInfo.channel === 'latest' ? false : true
|
||||||
|
}
|
||||||
|
|
||||||
|
for (const mirror of mirrorsToCheck) {
|
||||||
|
const url = getReleasePageUrl(mirror, releaseInfo.tag)
|
||||||
|
try {
|
||||||
|
const response = await fetch(url, {
|
||||||
|
method: mirror === 'github' ? 'HEAD' : 'GET',
|
||||||
|
redirect: 'follow'
|
||||||
|
})
|
||||||
|
|
||||||
|
if (response.ok) {
|
||||||
|
availability[mirror] = true
|
||||||
|
} else {
|
||||||
|
console.warn(
|
||||||
|
`[update-app-upgrade-config] ${mirror} release not available for ${releaseInfo.tag} (status ${response.status}, ${url}).`
|
||||||
|
)
|
||||||
|
availability[mirror] = false
|
||||||
|
}
|
||||||
|
} catch (error) {
|
||||||
|
console.warn(
|
||||||
|
`[update-app-upgrade-config] Failed to verify ${mirror} release page for ${releaseInfo.tag} (${url}). Continuing.`,
|
||||||
|
error
|
||||||
|
)
|
||||||
|
availability[mirror] = false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return availability
|
||||||
|
}
|
||||||
|
|
||||||
|
function getReleasePageUrl(mirror: UpdateMirror, tag: string): string {
|
||||||
|
if (mirror === 'github') {
|
||||||
|
return `https://github.com/${GITHUB_REPO}/releases/tag/${encodeURIComponent(tag)}`
|
||||||
|
}
|
||||||
|
// Use latest.yml download URL for GitCode to check if release exists
|
||||||
|
// Note: GitCode returns 401 for HEAD requests, so we use GET in ensureReleaseAvailability
|
||||||
|
return `https://gitcode.com/${GITCODE_REPO}/releases/download/${encodeURIComponent(tag)}/latest.yml`
|
||||||
|
}
|
||||||
|
|
||||||
|
main().catch((error) => {
|
||||||
|
console.error('❌ Failed to update app-upgrade-config:', error)
|
||||||
|
process.exit(1)
|
||||||
|
})
|
||||||
@@ -104,12 +104,6 @@ const router = express
|
|||||||
logger.warn('No models available from providers', { filter })
|
logger.warn('No models available from providers', { filter })
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.info('Models response ready', {
|
|
||||||
filter,
|
|
||||||
total: response.total,
|
|
||||||
modelIds: response.data.map((m) => m.id)
|
|
||||||
})
|
|
||||||
|
|
||||||
return res.json(response satisfies ApiModelsResponse)
|
return res.json(response satisfies ApiModelsResponse)
|
||||||
} catch (error: any) {
|
} catch (error: any) {
|
||||||
logger.error('Error fetching models', { error })
|
logger.error('Error fetching models', { error })
|
||||||
|
|||||||
@@ -3,7 +3,6 @@ import { createServer } from 'node:http'
|
|||||||
import { loggerService } from '@logger'
|
import { loggerService } from '@logger'
|
||||||
import { IpcChannel } from '@shared/IpcChannel'
|
import { IpcChannel } from '@shared/IpcChannel'
|
||||||
|
|
||||||
import { agentService } from '../services/agents'
|
|
||||||
import { windowService } from '../services/WindowService'
|
import { windowService } from '../services/WindowService'
|
||||||
import { app } from './app'
|
import { app } from './app'
|
||||||
import { config } from './config'
|
import { config } from './config'
|
||||||
@@ -32,11 +31,6 @@ export class ApiServer {
|
|||||||
// Load config
|
// Load config
|
||||||
const { port, host } = await config.load()
|
const { port, host } = await config.load()
|
||||||
|
|
||||||
// Initialize AgentService
|
|
||||||
logger.info('Initializing AgentService')
|
|
||||||
await agentService.initialize()
|
|
||||||
logger.info('AgentService initialized')
|
|
||||||
|
|
||||||
// Create server with Express app
|
// Create server with Express app
|
||||||
this.server = createServer(app)
|
this.server = createServer(app)
|
||||||
this.applyServerTimeouts(this.server)
|
this.applyServerTimeouts(this.server)
|
||||||
|
|||||||
@@ -32,7 +32,7 @@ export class ModelsService {
|
|||||||
|
|
||||||
for (const model of models) {
|
for (const model of models) {
|
||||||
const provider = providers.find((p) => p.id === model.provider)
|
const provider = providers.find((p) => p.id === model.provider)
|
||||||
logger.debug(`Processing model ${model.id}`)
|
// logger.debug(`Processing model ${model.id}`)
|
||||||
if (!provider) {
|
if (!provider) {
|
||||||
logger.debug(`Skipping model ${model.id} . Reason: Provider not found.`)
|
logger.debug(`Skipping model ${model.id} . Reason: Provider not found.`)
|
||||||
continue
|
continue
|
||||||
|
|||||||
@@ -8,7 +8,7 @@ import '@main/config'
|
|||||||
import { loggerService } from '@logger'
|
import { loggerService } from '@logger'
|
||||||
import { electronApp, optimizer } from '@electron-toolkit/utils'
|
import { electronApp, optimizer } from '@electron-toolkit/utils'
|
||||||
import { replaceDevtoolsFont } from '@main/utils/windowUtil'
|
import { replaceDevtoolsFont } from '@main/utils/windowUtil'
|
||||||
import { app } from 'electron'
|
import { app, crashReporter } from 'electron'
|
||||||
import installExtension, { REACT_DEVELOPER_TOOLS, REDUX_DEVTOOLS } from 'electron-devtools-installer'
|
import installExtension, { REACT_DEVELOPER_TOOLS, REDUX_DEVTOOLS } from 'electron-devtools-installer'
|
||||||
import { isDev, isLinux, isWin } from './constant'
|
import { isDev, isLinux, isWin } from './constant'
|
||||||
|
|
||||||
@@ -34,9 +34,18 @@ import { TrayService } from './services/TrayService'
|
|||||||
import { versionService } from './services/VersionService'
|
import { versionService } from './services/VersionService'
|
||||||
import { windowService } from './services/WindowService'
|
import { windowService } from './services/WindowService'
|
||||||
import { initWebviewHotkeys } from './services/WebviewService'
|
import { initWebviewHotkeys } from './services/WebviewService'
|
||||||
|
import { runAsyncFunction } from './utils'
|
||||||
|
|
||||||
const logger = loggerService.withContext('MainEntry')
|
const logger = loggerService.withContext('MainEntry')
|
||||||
|
|
||||||
|
// enable local crash reports
|
||||||
|
crashReporter.start({
|
||||||
|
companyName: 'CherryHQ',
|
||||||
|
productName: 'CherryStudio',
|
||||||
|
submitURL: '',
|
||||||
|
uploadToServer: false
|
||||||
|
})
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Disable hardware acceleration if setting is enabled
|
* Disable hardware acceleration if setting is enabled
|
||||||
*/
|
*/
|
||||||
@@ -162,39 +171,33 @@ if (!app.requestSingleInstanceLock()) {
|
|||||||
//start selection assistant service
|
//start selection assistant service
|
||||||
initSelectionService()
|
initSelectionService()
|
||||||
|
|
||||||
// Initialize Agent Service
|
runAsyncFunction(async () => {
|
||||||
try {
|
// Start API server if enabled or if agents exist
|
||||||
await agentService.initialize()
|
try {
|
||||||
logger.info('Agent service initialized successfully')
|
const config = await apiServerService.getCurrentConfig()
|
||||||
} catch (error: any) {
|
logger.info('API server config:', config)
|
||||||
logger.error('Failed to initialize Agent service:', error)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Start API server if enabled or if agents exist
|
// Check if there are any agents
|
||||||
try {
|
let shouldStart = config.enabled
|
||||||
const config = await apiServerService.getCurrentConfig()
|
if (!shouldStart) {
|
||||||
logger.info('API server config:', config)
|
try {
|
||||||
|
const { total } = await agentService.listAgents({ limit: 1 })
|
||||||
// Check if there are any agents
|
if (total > 0) {
|
||||||
let shouldStart = config.enabled
|
shouldStart = true
|
||||||
if (!shouldStart) {
|
logger.info(`Detected ${total} agent(s), auto-starting API server`)
|
||||||
try {
|
}
|
||||||
const { total } = await agentService.listAgents({ limit: 1 })
|
} catch (error: any) {
|
||||||
if (total > 0) {
|
logger.warn('Failed to check agent count:', error)
|
||||||
shouldStart = true
|
|
||||||
logger.info(`Detected ${total} agent(s), auto-starting API server`)
|
|
||||||
}
|
}
|
||||||
} catch (error: any) {
|
|
||||||
logger.warn('Failed to check agent count:', error)
|
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
if (shouldStart) {
|
if (shouldStart) {
|
||||||
await apiServerService.start()
|
await apiServerService.start()
|
||||||
|
}
|
||||||
|
} catch (error: any) {
|
||||||
|
logger.error('Failed to check/start API server:', error)
|
||||||
}
|
}
|
||||||
} catch (error: any) {
|
})
|
||||||
logger.error('Failed to check/start API server:', error)
|
|
||||||
}
|
|
||||||
})
|
})
|
||||||
|
|
||||||
registerProtocolClient(app)
|
registerProtocolClient(app)
|
||||||
|
|||||||
@@ -493,6 +493,44 @@ export function registerIpc(mainWindow: BrowserWindow, app: Electron.App) {
|
|||||||
ipcMain.handle(IpcChannel.System_GetDeviceType, () => (isMac ? 'mac' : isWin ? 'windows' : 'linux'))
|
ipcMain.handle(IpcChannel.System_GetDeviceType, () => (isMac ? 'mac' : isWin ? 'windows' : 'linux'))
|
||||||
ipcMain.handle(IpcChannel.System_GetHostname, () => require('os').hostname())
|
ipcMain.handle(IpcChannel.System_GetHostname, () => require('os').hostname())
|
||||||
ipcMain.handle(IpcChannel.System_GetCpuName, () => require('os').cpus()[0].model)
|
ipcMain.handle(IpcChannel.System_GetCpuName, () => require('os').cpus()[0].model)
|
||||||
|
ipcMain.handle(IpcChannel.System_CheckGitBash, () => {
|
||||||
|
if (!isWin) {
|
||||||
|
return true // Non-Windows systems don't need Git Bash
|
||||||
|
}
|
||||||
|
|
||||||
|
try {
|
||||||
|
// Check common Git Bash installation paths
|
||||||
|
const commonPaths = [
|
||||||
|
path.join(process.env.ProgramFiles || 'C:\\Program Files', 'Git', 'bin', 'bash.exe'),
|
||||||
|
path.join(process.env['ProgramFiles(x86)'] || 'C:\\Program Files (x86)', 'Git', 'bin', 'bash.exe'),
|
||||||
|
path.join(process.env.LOCALAPPDATA || '', 'Programs', 'Git', 'bin', 'bash.exe')
|
||||||
|
]
|
||||||
|
|
||||||
|
// Check if any of the common paths exist
|
||||||
|
for (const bashPath of commonPaths) {
|
||||||
|
if (fs.existsSync(bashPath)) {
|
||||||
|
logger.debug('Git Bash found', { path: bashPath })
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if git is in PATH
|
||||||
|
const { execSync } = require('child_process')
|
||||||
|
try {
|
||||||
|
execSync('git --version', { stdio: 'ignore' })
|
||||||
|
logger.debug('Git found in PATH')
|
||||||
|
return true
|
||||||
|
} catch {
|
||||||
|
// Git not in PATH
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.debug('Git Bash not found on Windows system')
|
||||||
|
return false
|
||||||
|
} catch (error) {
|
||||||
|
logger.error('Error checking Git Bash', error as Error)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
})
|
||||||
ipcMain.handle(IpcChannel.System_ToggleDevTools, (e) => {
|
ipcMain.handle(IpcChannel.System_ToggleDevTools, (e) => {
|
||||||
const win = BrowserWindow.fromWebContents(e.sender)
|
const win = BrowserWindow.fromWebContents(e.sender)
|
||||||
win && win.webContents.toggleDevTools()
|
win && win.webContents.toggleDevTools()
|
||||||
@@ -1038,4 +1076,8 @@ export function registerIpc(mainWindow: BrowserWindow, app: Electron.App) {
|
|||||||
ipcMain.handle(IpcChannel.WebSocket_Status, WebSocketService.getStatus)
|
ipcMain.handle(IpcChannel.WebSocket_Status, WebSocketService.getStatus)
|
||||||
ipcMain.handle(IpcChannel.WebSocket_SendFile, WebSocketService.sendFile)
|
ipcMain.handle(IpcChannel.WebSocket_SendFile, WebSocketService.sendFile)
|
||||||
ipcMain.handle(IpcChannel.WebSocket_GetAllCandidates, WebSocketService.getAllCandidates)
|
ipcMain.handle(IpcChannel.WebSocket_GetAllCandidates, WebSocketService.getAllCandidates)
|
||||||
|
|
||||||
|
ipcMain.handle(IpcChannel.APP_CrashRenderProcess, () => {
|
||||||
|
mainWindow.webContents.forcefullyCrashRenderer()
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -21,6 +21,7 @@ type ApiResponse<T> = {
|
|||||||
type BatchUploadResponse = {
|
type BatchUploadResponse = {
|
||||||
batch_id: string
|
batch_id: string
|
||||||
file_urls: string[]
|
file_urls: string[]
|
||||||
|
headers?: Record<string, string>[]
|
||||||
}
|
}
|
||||||
|
|
||||||
type ExtractProgress = {
|
type ExtractProgress = {
|
||||||
@@ -55,7 +56,7 @@ type QuotaResponse = {
|
|||||||
export default class MineruPreprocessProvider extends BasePreprocessProvider {
|
export default class MineruPreprocessProvider extends BasePreprocessProvider {
|
||||||
constructor(provider: PreprocessProvider, userId?: string) {
|
constructor(provider: PreprocessProvider, userId?: string) {
|
||||||
super(provider, userId)
|
super(provider, userId)
|
||||||
// todo:免费期结束后删除
|
// TODO: remove after free period ends
|
||||||
this.provider.apiKey = this.provider.apiKey || import.meta.env.MAIN_VITE_MINERU_API_KEY
|
this.provider.apiKey = this.provider.apiKey || import.meta.env.MAIN_VITE_MINERU_API_KEY
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -68,21 +69,21 @@ export default class MineruPreprocessProvider extends BasePreprocessProvider {
|
|||||||
logger.info(`MinerU preprocess processing started: ${filePath}`)
|
logger.info(`MinerU preprocess processing started: ${filePath}`)
|
||||||
await this.validateFile(filePath)
|
await this.validateFile(filePath)
|
||||||
|
|
||||||
// 1. 获取上传URL并上传文件
|
// 1. Get upload URL and upload file
|
||||||
const batchId = await this.uploadFile(file)
|
const batchId = await this.uploadFile(file)
|
||||||
logger.info(`MinerU file upload completed: batch_id=${batchId}`)
|
logger.info(`MinerU file upload completed: batch_id=${batchId}`)
|
||||||
|
|
||||||
// 2. 等待处理完成并获取结果
|
// 2. Wait for completion and fetch results
|
||||||
const extractResult = await this.waitForCompletion(sourceId, batchId, file.origin_name)
|
const extractResult = await this.waitForCompletion(sourceId, batchId, file.origin_name)
|
||||||
logger.info(`MinerU processing completed for batch: ${batchId}`)
|
logger.info(`MinerU processing completed for batch: ${batchId}`)
|
||||||
|
|
||||||
// 3. 下载并解压文件
|
// 3. Download and extract output
|
||||||
const { path: outputPath } = await this.downloadAndExtractFile(extractResult.full_zip_url!, file)
|
const { path: outputPath } = await this.downloadAndExtractFile(extractResult.full_zip_url!, file)
|
||||||
|
|
||||||
// 4. check quota
|
// 4. check quota
|
||||||
const quota = await this.checkQuota()
|
const quota = await this.checkQuota()
|
||||||
|
|
||||||
// 5. 创建处理后的文件信息
|
// 5. Create processed file metadata
|
||||||
return {
|
return {
|
||||||
processedFile: this.createProcessedFileInfo(file, outputPath),
|
processedFile: this.createProcessedFileInfo(file, outputPath),
|
||||||
quota
|
quota
|
||||||
@@ -115,23 +116,48 @@ export default class MineruPreprocessProvider extends BasePreprocessProvider {
|
|||||||
}
|
}
|
||||||
|
|
||||||
private async validateFile(filePath: string): Promise<void> {
|
private async validateFile(filePath: string): Promise<void> {
|
||||||
|
// Phase 1: check file size (without loading into memory)
|
||||||
|
logger.info(`Validating PDF file: ${filePath}`)
|
||||||
|
const stats = await fs.promises.stat(filePath)
|
||||||
|
const fileSizeBytes = stats.size
|
||||||
|
|
||||||
|
// Ensure file size is under 200MB
|
||||||
|
if (fileSizeBytes >= 200 * 1024 * 1024) {
|
||||||
|
const fileSizeMB = Math.round(fileSizeBytes / (1024 * 1024))
|
||||||
|
throw new Error(`PDF file size (${fileSizeMB}MB) exceeds the limit of 200MB`)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Phase 2: check page count (requires reading file with error handling)
|
||||||
const pdfBuffer = await fs.promises.readFile(filePath)
|
const pdfBuffer = await fs.promises.readFile(filePath)
|
||||||
|
|
||||||
const doc = await this.readPdf(pdfBuffer)
|
try {
|
||||||
|
const doc = await this.readPdf(pdfBuffer)
|
||||||
|
|
||||||
// 文件页数小于600页
|
// Ensure page count is under 600 pages
|
||||||
if (doc.numPages >= 600) {
|
if (doc.numPages >= 600) {
|
||||||
throw new Error(`PDF page count (${doc.numPages}) exceeds the limit of 600 pages`)
|
throw new Error(`PDF page count (${doc.numPages}) exceeds the limit of 600 pages`)
|
||||||
}
|
}
|
||||||
// 文件大小小于200MB
|
|
||||||
if (pdfBuffer.length >= 200 * 1024 * 1024) {
|
logger.info(`PDF validation passed: ${doc.numPages} pages, ${Math.round(fileSizeBytes / (1024 * 1024))}MB`)
|
||||||
const fileSizeMB = Math.round(pdfBuffer.length / (1024 * 1024))
|
} catch (error: any) {
|
||||||
throw new Error(`PDF file size (${fileSizeMB}MB) exceeds the limit of 200MB`)
|
// If the page limit is exceeded, rethrow immediately
|
||||||
|
if (error.message.includes('exceeds the limit')) {
|
||||||
|
throw error
|
||||||
|
}
|
||||||
|
|
||||||
|
// If PDF parsing fails, log a detailed warning but continue processing
|
||||||
|
logger.warn(
|
||||||
|
`Failed to parse PDF structure (file may be corrupted or use non-standard format). ` +
|
||||||
|
`Skipping page count validation. Will attempt to process with MinerU API. ` +
|
||||||
|
`Error details: ${error.message}. ` +
|
||||||
|
`Suggestion: If processing fails, try repairing the PDF using tools like Adobe Acrobat or online PDF repair services.`
|
||||||
|
)
|
||||||
|
// Do not throw; continue processing
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private createProcessedFileInfo(file: FileMetadata, outputPath: string): FileMetadata {
|
private createProcessedFileInfo(file: FileMetadata, outputPath: string): FileMetadata {
|
||||||
// 查找解压后的主要文件
|
// Locate the main extracted file
|
||||||
let finalPath = ''
|
let finalPath = ''
|
||||||
let finalName = file.origin_name.replace('.pdf', '.md')
|
let finalName = file.origin_name.replace('.pdf', '.md')
|
||||||
|
|
||||||
@@ -143,14 +169,14 @@ export default class MineruPreprocessProvider extends BasePreprocessProvider {
|
|||||||
const originalMdPath = path.join(outputPath, mdFile)
|
const originalMdPath = path.join(outputPath, mdFile)
|
||||||
const newMdPath = path.join(outputPath, finalName)
|
const newMdPath = path.join(outputPath, finalName)
|
||||||
|
|
||||||
// 重命名文件为原始文件名
|
// Rename the file to match the original name
|
||||||
try {
|
try {
|
||||||
fs.renameSync(originalMdPath, newMdPath)
|
fs.renameSync(originalMdPath, newMdPath)
|
||||||
finalPath = newMdPath
|
finalPath = newMdPath
|
||||||
logger.info(`Renamed markdown file from ${mdFile} to ${finalName}`)
|
logger.info(`Renamed markdown file from ${mdFile} to ${finalName}`)
|
||||||
} catch (renameError) {
|
} catch (renameError) {
|
||||||
logger.warn(`Failed to rename file ${mdFile} to ${finalName}: ${renameError}`)
|
logger.warn(`Failed to rename file ${mdFile} to ${finalName}: ${renameError}`)
|
||||||
// 如果重命名失败,使用原文件
|
// If renaming fails, fall back to the original file
|
||||||
finalPath = originalMdPath
|
finalPath = originalMdPath
|
||||||
finalName = mdFile
|
finalName = mdFile
|
||||||
}
|
}
|
||||||
@@ -178,7 +204,7 @@ export default class MineruPreprocessProvider extends BasePreprocessProvider {
|
|||||||
logger.info(`Downloading MinerU result to: ${zipPath}`)
|
logger.info(`Downloading MinerU result to: ${zipPath}`)
|
||||||
|
|
||||||
try {
|
try {
|
||||||
// 下载ZIP文件
|
// Download the ZIP file
|
||||||
const response = await net.fetch(zipUrl, { method: 'GET' })
|
const response = await net.fetch(zipUrl, { method: 'GET' })
|
||||||
if (!response.ok) {
|
if (!response.ok) {
|
||||||
throw new Error(`HTTP ${response.status}: ${response.statusText}`)
|
throw new Error(`HTTP ${response.status}: ${response.statusText}`)
|
||||||
@@ -187,17 +213,17 @@ export default class MineruPreprocessProvider extends BasePreprocessProvider {
|
|||||||
fs.writeFileSync(zipPath, Buffer.from(arrayBuffer))
|
fs.writeFileSync(zipPath, Buffer.from(arrayBuffer))
|
||||||
logger.info(`Downloaded ZIP file: ${zipPath}`)
|
logger.info(`Downloaded ZIP file: ${zipPath}`)
|
||||||
|
|
||||||
// 确保提取目录存在
|
// Ensure the extraction directory exists
|
||||||
if (!fs.existsSync(extractPath)) {
|
if (!fs.existsSync(extractPath)) {
|
||||||
fs.mkdirSync(extractPath, { recursive: true })
|
fs.mkdirSync(extractPath, { recursive: true })
|
||||||
}
|
}
|
||||||
|
|
||||||
// 解压文件
|
// Extract the ZIP contents
|
||||||
const zip = new AdmZip(zipPath)
|
const zip = new AdmZip(zipPath)
|
||||||
zip.extractAllTo(extractPath, true)
|
zip.extractAllTo(extractPath, true)
|
||||||
logger.info(`Extracted files to: ${extractPath}`)
|
logger.info(`Extracted files to: ${extractPath}`)
|
||||||
|
|
||||||
// 删除临时ZIP文件
|
// Remove the temporary ZIP file
|
||||||
fs.unlinkSync(zipPath)
|
fs.unlinkSync(zipPath)
|
||||||
|
|
||||||
return { path: extractPath }
|
return { path: extractPath }
|
||||||
@@ -209,11 +235,11 @@ export default class MineruPreprocessProvider extends BasePreprocessProvider {
|
|||||||
|
|
||||||
private async uploadFile(file: FileMetadata): Promise<string> {
|
private async uploadFile(file: FileMetadata): Promise<string> {
|
||||||
try {
|
try {
|
||||||
// 步骤1: 获取上传URL
|
// Step 1: obtain the upload URL
|
||||||
const { batchId, fileUrls } = await this.getBatchUploadUrls(file)
|
const { batchId, fileUrls, uploadHeaders } = await this.getBatchUploadUrls(file)
|
||||||
// 步骤2: 上传文件到获取的URL
|
// Step 2: upload the file to the obtained URL
|
||||||
const filePath = fileStorage.getFilePathById(file)
|
const filePath = fileStorage.getFilePathById(file)
|
||||||
await this.putFileToUrl(filePath, fileUrls[0])
|
await this.putFileToUrl(filePath, fileUrls[0], file.origin_name, uploadHeaders?.[0])
|
||||||
logger.info(`File uploaded successfully: ${filePath}`, { batchId, fileUrls })
|
logger.info(`File uploaded successfully: ${filePath}`, { batchId, fileUrls })
|
||||||
|
|
||||||
return batchId
|
return batchId
|
||||||
@@ -223,7 +249,9 @@ export default class MineruPreprocessProvider extends BasePreprocessProvider {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private async getBatchUploadUrls(file: FileMetadata): Promise<{ batchId: string; fileUrls: string[] }> {
|
private async getBatchUploadUrls(
|
||||||
|
file: FileMetadata
|
||||||
|
): Promise<{ batchId: string; fileUrls: string[]; uploadHeaders?: Record<string, string>[] }> {
|
||||||
const endpoint = `${this.provider.apiHost}/api/v4/file-urls/batch`
|
const endpoint = `${this.provider.apiHost}/api/v4/file-urls/batch`
|
||||||
|
|
||||||
const payload = {
|
const payload = {
|
||||||
@@ -254,10 +282,11 @@ export default class MineruPreprocessProvider extends BasePreprocessProvider {
|
|||||||
if (response.ok) {
|
if (response.ok) {
|
||||||
const data: ApiResponse<BatchUploadResponse> = await response.json()
|
const data: ApiResponse<BatchUploadResponse> = await response.json()
|
||||||
if (data.code === 0 && data.data) {
|
if (data.code === 0 && data.data) {
|
||||||
const { batch_id, file_urls } = data.data
|
const { batch_id, file_urls, headers: uploadHeaders } = data.data
|
||||||
return {
|
return {
|
||||||
batchId: batch_id,
|
batchId: batch_id,
|
||||||
fileUrls: file_urls
|
fileUrls: file_urls,
|
||||||
|
uploadHeaders
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
throw new Error(`API returned error: ${data.msg || JSON.stringify(data)}`)
|
throw new Error(`API returned error: ${data.msg || JSON.stringify(data)}`)
|
||||||
@@ -271,18 +300,28 @@ export default class MineruPreprocessProvider extends BasePreprocessProvider {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private async putFileToUrl(filePath: string, uploadUrl: string): Promise<void> {
|
private async putFileToUrl(
|
||||||
|
filePath: string,
|
||||||
|
uploadUrl: string,
|
||||||
|
fileName?: string,
|
||||||
|
headers?: Record<string, string>
|
||||||
|
): Promise<void> {
|
||||||
try {
|
try {
|
||||||
const fileBuffer = await fs.promises.readFile(filePath)
|
const fileBuffer = await fs.promises.readFile(filePath)
|
||||||
|
const fileSize = fileBuffer.byteLength
|
||||||
|
const displayName = fileName ?? path.basename(filePath)
|
||||||
|
|
||||||
|
logger.info(`Uploading file to MinerU OSS: ${displayName} (${fileSize} bytes)`)
|
||||||
|
|
||||||
// https://mineru.net/apiManage/docs
|
// https://mineru.net/apiManage/docs
|
||||||
const response = await net.fetch(uploadUrl, {
|
const response = await net.fetch(uploadUrl, {
|
||||||
method: 'PUT',
|
method: 'PUT',
|
||||||
body: fileBuffer
|
headers,
|
||||||
|
body: new Uint8Array(fileBuffer)
|
||||||
})
|
})
|
||||||
|
|
||||||
if (!response.ok) {
|
if (!response.ok) {
|
||||||
// 克隆 response 以避免消费 body stream
|
// Clone the response to avoid consuming the body stream
|
||||||
const responseClone = response.clone()
|
const responseClone = response.clone()
|
||||||
|
|
||||||
try {
|
try {
|
||||||
@@ -353,20 +392,20 @@ export default class MineruPreprocessProvider extends BasePreprocessProvider {
|
|||||||
try {
|
try {
|
||||||
const result = await this.getExtractResults(batchId)
|
const result = await this.getExtractResults(batchId)
|
||||||
|
|
||||||
// 查找对应文件的处理结果
|
// Find the corresponding file result
|
||||||
const fileResult = result.extract_result.find((item) => item.file_name === fileName)
|
const fileResult = result.extract_result.find((item) => item.file_name === fileName)
|
||||||
if (!fileResult) {
|
if (!fileResult) {
|
||||||
throw new Error(`File ${fileName} not found in batch results`)
|
throw new Error(`File ${fileName} not found in batch results`)
|
||||||
}
|
}
|
||||||
|
|
||||||
// 检查处理状态
|
// Check the processing state
|
||||||
if (fileResult.state === 'done' && fileResult.full_zip_url) {
|
if (fileResult.state === 'done' && fileResult.full_zip_url) {
|
||||||
logger.info(`Processing completed for file: ${fileName}`)
|
logger.info(`Processing completed for file: ${fileName}`)
|
||||||
return fileResult
|
return fileResult
|
||||||
} else if (fileResult.state === 'failed') {
|
} else if (fileResult.state === 'failed') {
|
||||||
throw new Error(`Processing failed for file: ${fileName}, error: ${fileResult.err_msg}`)
|
throw new Error(`Processing failed for file: ${fileName}, error: ${fileResult.err_msg}`)
|
||||||
} else if (fileResult.state === 'running') {
|
} else if (fileResult.state === 'running') {
|
||||||
// 发送进度更新
|
// Send progress updates
|
||||||
if (fileResult.extract_progress) {
|
if (fileResult.extract_progress) {
|
||||||
const progress = Math.round(
|
const progress = Math.round(
|
||||||
(fileResult.extract_progress.extracted_pages / fileResult.extract_progress.total_pages) * 100
|
(fileResult.extract_progress.extracted_pages / fileResult.extract_progress.total_pages) * 100
|
||||||
@@ -374,7 +413,7 @@ export default class MineruPreprocessProvider extends BasePreprocessProvider {
|
|||||||
await this.sendPreprocessProgress(sourceId, progress)
|
await this.sendPreprocessProgress(sourceId, progress)
|
||||||
logger.info(`File ${fileName} processing progress: ${progress}%`)
|
logger.info(`File ${fileName} processing progress: ${progress}%`)
|
||||||
} else {
|
} else {
|
||||||
// 如果没有具体进度信息,发送一个通用进度
|
// If no detailed progress information is available, send a generic update
|
||||||
await this.sendPreprocessProgress(sourceId, 50)
|
await this.sendPreprocessProgress(sourceId, 50)
|
||||||
logger.info(`File ${fileName} is still processing...`)
|
logger.info(`File ${fileName} is still processing...`)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -53,18 +53,43 @@ export default class OpenMineruPreprocessProvider extends BasePreprocessProvider
|
|||||||
}
|
}
|
||||||
|
|
||||||
private async validateFile(filePath: string): Promise<void> {
|
private async validateFile(filePath: string): Promise<void> {
|
||||||
|
// 第一阶段:检查文件大小(无需读取文件到内存)
|
||||||
|
logger.info(`Validating PDF file: ${filePath}`)
|
||||||
|
const stats = await fs.promises.stat(filePath)
|
||||||
|
const fileSizeBytes = stats.size
|
||||||
|
|
||||||
|
// File size must be less than 200MB
|
||||||
|
if (fileSizeBytes >= 200 * 1024 * 1024) {
|
||||||
|
const fileSizeMB = Math.round(fileSizeBytes / (1024 * 1024))
|
||||||
|
throw new Error(`PDF file size (${fileSizeMB}MB) exceeds the limit of 200MB`)
|
||||||
|
}
|
||||||
|
|
||||||
|
// 第二阶段:检查页数(需要读取文件,带错误处理)
|
||||||
const pdfBuffer = await fs.promises.readFile(filePath)
|
const pdfBuffer = await fs.promises.readFile(filePath)
|
||||||
|
|
||||||
const doc = await this.readPdf(pdfBuffer)
|
try {
|
||||||
|
const doc = await this.readPdf(pdfBuffer)
|
||||||
|
|
||||||
// File page count must be less than 600 pages
|
// File page count must be less than 600 pages
|
||||||
if (doc.numPages >= 600) {
|
if (doc.numPages >= 600) {
|
||||||
throw new Error(`PDF page count (${doc.numPages}) exceeds the limit of 600 pages`)
|
throw new Error(`PDF page count (${doc.numPages}) exceeds the limit of 600 pages`)
|
||||||
}
|
}
|
||||||
// File size must be less than 200MB
|
|
||||||
if (pdfBuffer.length >= 200 * 1024 * 1024) {
|
logger.info(`PDF validation passed: ${doc.numPages} pages, ${Math.round(fileSizeBytes / (1024 * 1024))}MB`)
|
||||||
const fileSizeMB = Math.round(pdfBuffer.length / (1024 * 1024))
|
} catch (error: any) {
|
||||||
throw new Error(`PDF file size (${fileSizeMB}MB) exceeds the limit of 200MB`)
|
// 如果是页数超限错误,直接抛出
|
||||||
|
if (error.message.includes('exceeds the limit')) {
|
||||||
|
throw error
|
||||||
|
}
|
||||||
|
|
||||||
|
// PDF 解析失败,记录详细警告但允许继续处理
|
||||||
|
logger.warn(
|
||||||
|
`Failed to parse PDF structure (file may be corrupted or use non-standard format). ` +
|
||||||
|
`Skipping page count validation. Will attempt to process with MinerU API. ` +
|
||||||
|
`Error details: ${error.message}. ` +
|
||||||
|
`Suggestion: If processing fails, try repairing the PDF using tools like Adobe Acrobat or online PDF repair services.`
|
||||||
|
)
|
||||||
|
// 不抛出错误,允许继续处理
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -72,8 +97,8 @@ export default class OpenMineruPreprocessProvider extends BasePreprocessProvider
|
|||||||
// Find the main file after extraction
|
// Find the main file after extraction
|
||||||
let finalPath = ''
|
let finalPath = ''
|
||||||
let finalName = file.origin_name.replace('.pdf', '.md')
|
let finalName = file.origin_name.replace('.pdf', '.md')
|
||||||
// Find the corresponding folder by file name
|
// Find the corresponding folder by file id
|
||||||
outputPath = path.join(outputPath, `${file.origin_name.replace('.pdf', '')}`)
|
outputPath = path.join(outputPath, file.id)
|
||||||
try {
|
try {
|
||||||
const files = fs.readdirSync(outputPath)
|
const files = fs.readdirSync(outputPath)
|
||||||
|
|
||||||
@@ -125,7 +150,7 @@ export default class OpenMineruPreprocessProvider extends BasePreprocessProvider
|
|||||||
formData.append('return_md', 'true')
|
formData.append('return_md', 'true')
|
||||||
formData.append('response_format_zip', 'true')
|
formData.append('response_format_zip', 'true')
|
||||||
formData.append('files', fileBuffer, {
|
formData.append('files', fileBuffer, {
|
||||||
filename: file.origin_name
|
filename: file.name
|
||||||
})
|
})
|
||||||
|
|
||||||
while (retries < maxRetries) {
|
while (retries < maxRetries) {
|
||||||
@@ -139,7 +164,7 @@ export default class OpenMineruPreprocessProvider extends BasePreprocessProvider
|
|||||||
...(this.provider.apiKey ? { Authorization: `Bearer ${this.provider.apiKey}` } : {}),
|
...(this.provider.apiKey ? { Authorization: `Bearer ${this.provider.apiKey}` } : {}),
|
||||||
...formData.getHeaders()
|
...formData.getHeaders()
|
||||||
},
|
},
|
||||||
body: formData.getBuffer()
|
body: new Uint8Array(formData.getBuffer())
|
||||||
})
|
})
|
||||||
|
|
||||||
if (!response.ok) {
|
if (!response.ok) {
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ import DiDiMcpServer from './didi-mcp'
|
|||||||
import DifyKnowledgeServer from './dify-knowledge'
|
import DifyKnowledgeServer from './dify-knowledge'
|
||||||
import FetchServer from './fetch'
|
import FetchServer from './fetch'
|
||||||
import FileSystemServer from './filesystem'
|
import FileSystemServer from './filesystem'
|
||||||
|
import MCPUIDemoServer from './mcp-ui-demo'
|
||||||
import MemoryServer from './memory'
|
import MemoryServer from './memory'
|
||||||
import PythonServer from './python'
|
import PythonServer from './python'
|
||||||
import ThinkingServer from './sequentialthinking'
|
import ThinkingServer from './sequentialthinking'
|
||||||
@@ -48,6 +49,9 @@ export function createInMemoryMCPServer(
|
|||||||
const apiKey = envs.DIDI_API_KEY
|
const apiKey = envs.DIDI_API_KEY
|
||||||
return new DiDiMcpServer(apiKey).server
|
return new DiDiMcpServer(apiKey).server
|
||||||
}
|
}
|
||||||
|
case BuiltinMCPServerNames.mcpUIDemo: {
|
||||||
|
return new MCPUIDemoServer().server
|
||||||
|
}
|
||||||
default:
|
default:
|
||||||
throw new Error(`Unknown in-memory MCP server: ${name}`)
|
throw new Error(`Unknown in-memory MCP server: ${name}`)
|
||||||
}
|
}
|
||||||
|
|||||||
433
src/main/mcpServers/mcp-ui-demo.ts
Normal file
433
src/main/mcpServers/mcp-ui-demo.ts
Normal file
@@ -0,0 +1,433 @@
|
|||||||
|
import { Server } from '@modelcontextprotocol/sdk/server/index.js'
|
||||||
|
import { CallToolRequestSchema, ListToolsRequestSchema } from '@modelcontextprotocol/sdk/types.js'
|
||||||
|
|
||||||
|
const server = new Server(
|
||||||
|
{
|
||||||
|
name: 'mcp-ui-demo',
|
||||||
|
version: '1.0.0'
|
||||||
|
},
|
||||||
|
{
|
||||||
|
capabilities: {
|
||||||
|
tools: {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
// HTML templates for different UIs
|
||||||
|
const getHelloWorldUI = () =>
|
||||||
|
`
|
||||||
|
<!DOCTYPE html>
|
||||||
|
<html>
|
||||||
|
<head>
|
||||||
|
<style>
|
||||||
|
body {
|
||||||
|
font-family: system-ui, -apple-system, sans-serif;
|
||||||
|
padding: 20px;
|
||||||
|
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
|
||||||
|
color: white;
|
||||||
|
min-height: 200px;
|
||||||
|
display: flex;
|
||||||
|
align-items: center;
|
||||||
|
justify-content: center;
|
||||||
|
margin: 0;
|
||||||
|
}
|
||||||
|
.container {
|
||||||
|
text-align: center;
|
||||||
|
}
|
||||||
|
h1 {
|
||||||
|
font-size: 2.5em;
|
||||||
|
margin-bottom: 10px;
|
||||||
|
text-shadow: 2px 2px 4px rgba(0,0,0,0.3);
|
||||||
|
}
|
||||||
|
p {
|
||||||
|
font-size: 1.2em;
|
||||||
|
opacity: 0.9;
|
||||||
|
}
|
||||||
|
</style>
|
||||||
|
</head>
|
||||||
|
<body>
|
||||||
|
<div class="container">
|
||||||
|
<h1>🎉 Hello from MCP UI!</h1>
|
||||||
|
<p>This is a simple MCP UI Resource rendered in Cherry Studio</p>
|
||||||
|
</div>
|
||||||
|
</body>
|
||||||
|
</html>
|
||||||
|
`.trim()
|
||||||
|
|
||||||
|
const getInteractiveUI = () =>
|
||||||
|
`
|
||||||
|
<!DOCTYPE html>
|
||||||
|
<html>
|
||||||
|
<head>
|
||||||
|
<style>
|
||||||
|
body {
|
||||||
|
font-family: system-ui, -apple-system, sans-serif;
|
||||||
|
padding: 20px;
|
||||||
|
background: #f5f5f5;
|
||||||
|
margin: 0;
|
||||||
|
}
|
||||||
|
.container {
|
||||||
|
max-width: 600px;
|
||||||
|
margin: 0 auto;
|
||||||
|
background: white;
|
||||||
|
padding: 30px;
|
||||||
|
border-radius: 12px;
|
||||||
|
box-shadow: 0 2px 10px rgba(0,0,0,0.1);
|
||||||
|
}
|
||||||
|
h2 {
|
||||||
|
color: #333;
|
||||||
|
margin-bottom: 20px;
|
||||||
|
}
|
||||||
|
button {
|
||||||
|
background: #667eea;
|
||||||
|
color: white;
|
||||||
|
border: none;
|
||||||
|
padding: 12px 24px;
|
||||||
|
border-radius: 6px;
|
||||||
|
cursor: pointer;
|
||||||
|
font-size: 16px;
|
||||||
|
margin: 5px;
|
||||||
|
transition: background 0.2s;
|
||||||
|
}
|
||||||
|
button:hover {
|
||||||
|
background: #5568d3;
|
||||||
|
}
|
||||||
|
#output {
|
||||||
|
margin-top: 20px;
|
||||||
|
padding: 15px;
|
||||||
|
background: #f8f9fa;
|
||||||
|
border-radius: 6px;
|
||||||
|
border: 1px solid #e0e0e0;
|
||||||
|
min-height: 50px;
|
||||||
|
}
|
||||||
|
.info {
|
||||||
|
color: #666;
|
||||||
|
font-size: 14px;
|
||||||
|
margin-top: 15px;
|
||||||
|
}
|
||||||
|
</style>
|
||||||
|
</head>
|
||||||
|
<body>
|
||||||
|
<div class="container">
|
||||||
|
<h2>Interactive MCP UI Demo</h2>
|
||||||
|
<p>Click the buttons to interact with MCP tools:</p>
|
||||||
|
|
||||||
|
<button onclick="callEchoTool()">Call Echo Tool</button>
|
||||||
|
<button onclick="getTimestamp()">Get Timestamp</button>
|
||||||
|
<button onclick="openLink()">Open External Link</button>
|
||||||
|
|
||||||
|
<div id="output"></div>
|
||||||
|
|
||||||
|
<div class="info">
|
||||||
|
This UI can communicate with the host application through postMessage API.
|
||||||
|
</div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<script>
|
||||||
|
function callEchoTool() {
|
||||||
|
const output = document.getElementById('output');
|
||||||
|
output.innerHTML = '<p style="color: #0066cc;">Calling echo tool...</p>';
|
||||||
|
|
||||||
|
window.parent.postMessage({
|
||||||
|
type: 'tool',
|
||||||
|
payload: {
|
||||||
|
toolName: 'demo_echo',
|
||||||
|
params: {
|
||||||
|
message: 'Hello from MCP UI! Time: ' + new Date().toLocaleTimeString()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}, '*');
|
||||||
|
}
|
||||||
|
|
||||||
|
function getTimestamp() {
|
||||||
|
const output = document.getElementById('output');
|
||||||
|
const now = new Date();
|
||||||
|
output.innerHTML = \`
|
||||||
|
<p style="color: #00aa00;">
|
||||||
|
<strong>Current Timestamp:</strong><br/>
|
||||||
|
\${now.toLocaleString()}<br/>
|
||||||
|
Unix: \${Math.floor(now.getTime() / 1000)}
|
||||||
|
</p>
|
||||||
|
\`;
|
||||||
|
}
|
||||||
|
|
||||||
|
function openLink() {
|
||||||
|
window.parent.postMessage({
|
||||||
|
type: 'link',
|
||||||
|
payload: {
|
||||||
|
url: 'https://github.com/idosal/mcp-ui'
|
||||||
|
}
|
||||||
|
}, '*');
|
||||||
|
}
|
||||||
|
|
||||||
|
// Listen for responses
|
||||||
|
window.addEventListener('message', (event) => {
|
||||||
|
if (event.data.type === 'ui-message-response') {
|
||||||
|
const output = document.getElementById('output');
|
||||||
|
const { response, error } = event.data.payload;
|
||||||
|
|
||||||
|
if (error) {
|
||||||
|
output.innerHTML = \`<p style="color: #cc0000;">Error: \${error}</p>\`;
|
||||||
|
} else {
|
||||||
|
output.innerHTML = \`<p style="color: #00aa00;">Response: \${JSON.stringify(response, null, 2)}</p>\`;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
</script>
|
||||||
|
</body>
|
||||||
|
</html>
|
||||||
|
`.trim()
|
||||||
|
|
||||||
|
const getFormUI = () =>
|
||||||
|
`
|
||||||
|
<!DOCTYPE html>
|
||||||
|
<html>
|
||||||
|
<head>
|
||||||
|
<style>
|
||||||
|
body {
|
||||||
|
font-family: system-ui, -apple-system, sans-serif;
|
||||||
|
padding: 20px;
|
||||||
|
background: #f5f5f5;
|
||||||
|
margin: 0;
|
||||||
|
}
|
||||||
|
.container {
|
||||||
|
max-width: 500px;
|
||||||
|
margin: 0 auto;
|
||||||
|
background: white;
|
||||||
|
padding: 30px;
|
||||||
|
border-radius: 12px;
|
||||||
|
box-shadow: 0 2px 10px rgba(0,0,0,0.1);
|
||||||
|
}
|
||||||
|
h2 {
|
||||||
|
color: #333;
|
||||||
|
margin-bottom: 20px;
|
||||||
|
}
|
||||||
|
.form-group {
|
||||||
|
margin-bottom: 15px;
|
||||||
|
}
|
||||||
|
label {
|
||||||
|
display: block;
|
||||||
|
margin-bottom: 5px;
|
||||||
|
color: #555;
|
||||||
|
font-weight: 500;
|
||||||
|
}
|
||||||
|
input, textarea {
|
||||||
|
width: 100%;
|
||||||
|
padding: 10px;
|
||||||
|
border: 1px solid #ddd;
|
||||||
|
border-radius: 6px;
|
||||||
|
font-size: 14px;
|
||||||
|
box-sizing: border-box;
|
||||||
|
}
|
||||||
|
textarea {
|
||||||
|
min-height: 100px;
|
||||||
|
resize: vertical;
|
||||||
|
}
|
||||||
|
button {
|
||||||
|
background: #667eea;
|
||||||
|
color: white;
|
||||||
|
border: none;
|
||||||
|
padding: 12px 24px;
|
||||||
|
border-radius: 6px;
|
||||||
|
cursor: pointer;
|
||||||
|
font-size: 16px;
|
||||||
|
width: 100%;
|
||||||
|
margin-top: 10px;
|
||||||
|
}
|
||||||
|
button:hover {
|
||||||
|
background: #5568d3;
|
||||||
|
}
|
||||||
|
#result {
|
||||||
|
margin-top: 20px;
|
||||||
|
padding: 15px;
|
||||||
|
background: #f8f9fa;
|
||||||
|
border-radius: 6px;
|
||||||
|
display: none;
|
||||||
|
}
|
||||||
|
</style>
|
||||||
|
</head>
|
||||||
|
<body>
|
||||||
|
<div class="container">
|
||||||
|
<h2>📝 Form UI Demo</h2>
|
||||||
|
<form id="demoForm" onsubmit="handleSubmit(event)">
|
||||||
|
<div class="form-group">
|
||||||
|
<label for="name">Name:</label>
|
||||||
|
<input type="text" id="name" name="name" required placeholder="Enter your name">
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div class="form-group">
|
||||||
|
<label for="email">Email:</label>
|
||||||
|
<input type="email" id="email" name="email" required placeholder="your@email.com">
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<div class="form-group">
|
||||||
|
<label for="message">Message:</label>
|
||||||
|
<textarea id="message" name="message" required placeholder="Enter your message here..."></textarea>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<button type="submit">Submit Form</button>
|
||||||
|
</form>
|
||||||
|
|
||||||
|
<div id="result"></div>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<script>
|
||||||
|
function handleSubmit(event) {
|
||||||
|
event.preventDefault();
|
||||||
|
|
||||||
|
const formData = new FormData(event.target);
|
||||||
|
const data = Object.fromEntries(formData.entries());
|
||||||
|
|
||||||
|
const result = document.getElementById('result');
|
||||||
|
result.style.display = 'block';
|
||||||
|
result.innerHTML = '<p style="color: #0066cc;">Submitting form...</p>';
|
||||||
|
|
||||||
|
// Send form data to host
|
||||||
|
window.parent.postMessage({
|
||||||
|
type: 'notify',
|
||||||
|
payload: {
|
||||||
|
message: 'Form submitted with data: ' + JSON.stringify(data)
|
||||||
|
}
|
||||||
|
}, '*');
|
||||||
|
|
||||||
|
// Display result
|
||||||
|
result.innerHTML = \`
|
||||||
|
<p style="color: #00aa00;"><strong>Form Submitted!</strong></p>
|
||||||
|
<pre style="background: white; padding: 10px; border-radius: 4px; overflow-x: auto;">\${JSON.stringify(data, null, 2)}</pre>
|
||||||
|
\`;
|
||||||
|
}
|
||||||
|
</script>
|
||||||
|
</body>
|
||||||
|
</html>
|
||||||
|
`.trim()
|
||||||
|
|
||||||
|
// List available tools
|
||||||
|
server.setRequestHandler(ListToolsRequestSchema, async () => {
|
||||||
|
return {
|
||||||
|
tools: [
|
||||||
|
{
|
||||||
|
name: 'demo_echo',
|
||||||
|
description: 'Echo back the message sent from UI',
|
||||||
|
inputSchema: {
|
||||||
|
type: 'object',
|
||||||
|
properties: {
|
||||||
|
message: {
|
||||||
|
type: 'string',
|
||||||
|
description: 'Message to echo back'
|
||||||
|
}
|
||||||
|
},
|
||||||
|
required: ['message']
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: 'show_hello_ui',
|
||||||
|
description: 'Display a simple hello world UI with gradient background',
|
||||||
|
inputSchema: {
|
||||||
|
type: 'object',
|
||||||
|
properties: {}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: 'show_interactive_ui',
|
||||||
|
description:
|
||||||
|
'Display an interactive UI demo with buttons for calling tools, getting timestamps, and opening links',
|
||||||
|
inputSchema: {
|
||||||
|
type: 'object',
|
||||||
|
properties: {}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: 'show_form_ui',
|
||||||
|
description: 'Display a form UI demo with input fields for name, email, and message',
|
||||||
|
inputSchema: {
|
||||||
|
type: 'object',
|
||||||
|
properties: {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
// Handle tool calls
|
||||||
|
server.setRequestHandler(CallToolRequestSchema, async (request) => {
|
||||||
|
const { name, arguments: args } = request.params
|
||||||
|
|
||||||
|
if (name === 'demo_echo') {
|
||||||
|
return {
|
||||||
|
content: [
|
||||||
|
{
|
||||||
|
type: 'text',
|
||||||
|
text: JSON.stringify({
|
||||||
|
echo: args?.message || 'No message provided',
|
||||||
|
timestamp: new Date().toISOString()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (name === 'show_hello_ui') {
|
||||||
|
return {
|
||||||
|
content: [
|
||||||
|
{
|
||||||
|
type: 'text',
|
||||||
|
text: JSON.stringify({
|
||||||
|
type: 'resource',
|
||||||
|
resource: {
|
||||||
|
uri: 'ui://demo/hello',
|
||||||
|
mimeType: 'text/html',
|
||||||
|
text: getHelloWorldUI()
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (name === 'show_interactive_ui') {
|
||||||
|
return {
|
||||||
|
content: [
|
||||||
|
{
|
||||||
|
type: 'text',
|
||||||
|
text: JSON.stringify({
|
||||||
|
type: 'resource',
|
||||||
|
resource: {
|
||||||
|
uri: 'ui://demo/interactive',
|
||||||
|
mimeType: 'text/html',
|
||||||
|
text: getInteractiveUI()
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (name === 'show_form_ui') {
|
||||||
|
return {
|
||||||
|
content: [
|
||||||
|
{
|
||||||
|
type: 'text',
|
||||||
|
text: JSON.stringify({
|
||||||
|
type: 'resource',
|
||||||
|
resource: {
|
||||||
|
uri: 'ui://demo/form',
|
||||||
|
mimeType: 'text/html',
|
||||||
|
text: getFormUI()
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
throw new Error(`Unknown tool: ${name}`)
|
||||||
|
})
|
||||||
|
|
||||||
|
class MCPUIDemoServer {
|
||||||
|
public server: Server
|
||||||
|
constructor() {
|
||||||
|
this.server = server
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
export default MCPUIDemoServer
|
||||||
@@ -2,7 +2,7 @@ import { loggerService } from '@logger'
|
|||||||
import { isWin } from '@main/constant'
|
import { isWin } from '@main/constant'
|
||||||
import { getIpCountry } from '@main/utils/ipService'
|
import { getIpCountry } from '@main/utils/ipService'
|
||||||
import { generateUserAgent } from '@main/utils/systemInfo'
|
import { generateUserAgent } from '@main/utils/systemInfo'
|
||||||
import { FeedUrl, UpgradeChannel } from '@shared/config/constant'
|
import { FeedUrl, UpdateConfigUrl, UpdateMirror, UpgradeChannel } from '@shared/config/constant'
|
||||||
import { IpcChannel } from '@shared/IpcChannel'
|
import { IpcChannel } from '@shared/IpcChannel'
|
||||||
import type { UpdateInfo } from 'builder-util-runtime'
|
import type { UpdateInfo } from 'builder-util-runtime'
|
||||||
import { CancellationToken } from 'builder-util-runtime'
|
import { CancellationToken } from 'builder-util-runtime'
|
||||||
@@ -22,7 +22,29 @@ const LANG_MARKERS = {
|
|||||||
EN_START: '<!--LANG:en-->',
|
EN_START: '<!--LANG:en-->',
|
||||||
ZH_CN_START: '<!--LANG:zh-CN-->',
|
ZH_CN_START: '<!--LANG:zh-CN-->',
|
||||||
END: '<!--LANG:END-->'
|
END: '<!--LANG:END-->'
|
||||||
} as const
|
}
|
||||||
|
|
||||||
|
interface UpdateConfig {
|
||||||
|
lastUpdated: string
|
||||||
|
versions: {
|
||||||
|
[versionKey: string]: VersionConfig
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
interface VersionConfig {
|
||||||
|
minCompatibleVersion: string
|
||||||
|
description: string
|
||||||
|
channels: {
|
||||||
|
latest: ChannelConfig | null
|
||||||
|
rc: ChannelConfig | null
|
||||||
|
beta: ChannelConfig | null
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
interface ChannelConfig {
|
||||||
|
version: string
|
||||||
|
feedUrls: Record<UpdateMirror, string>
|
||||||
|
}
|
||||||
|
|
||||||
export default class AppUpdater {
|
export default class AppUpdater {
|
||||||
autoUpdater: _AppUpdater = autoUpdater
|
autoUpdater: _AppUpdater = autoUpdater
|
||||||
@@ -37,7 +59,9 @@ export default class AppUpdater {
|
|||||||
autoUpdater.requestHeaders = {
|
autoUpdater.requestHeaders = {
|
||||||
...autoUpdater.requestHeaders,
|
...autoUpdater.requestHeaders,
|
||||||
'User-Agent': generateUserAgent(),
|
'User-Agent': generateUserAgent(),
|
||||||
'X-Client-Id': configManager.getClientId()
|
'X-Client-Id': configManager.getClientId(),
|
||||||
|
// no-cache
|
||||||
|
'Cache-Control': 'no-cache'
|
||||||
}
|
}
|
||||||
|
|
||||||
autoUpdater.on('error', (error) => {
|
autoUpdater.on('error', (error) => {
|
||||||
@@ -75,61 +99,6 @@ export default class AppUpdater {
|
|||||||
this.autoUpdater = autoUpdater
|
this.autoUpdater = autoUpdater
|
||||||
}
|
}
|
||||||
|
|
||||||
private async _getReleaseVersionFromGithub(channel: UpgradeChannel) {
|
|
||||||
const headers = {
|
|
||||||
Accept: 'application/vnd.github+json',
|
|
||||||
'X-GitHub-Api-Version': '2022-11-28',
|
|
||||||
'Accept-Language': 'en-US,en;q=0.9'
|
|
||||||
}
|
|
||||||
try {
|
|
||||||
logger.info(`get release version from github: ${channel}`)
|
|
||||||
const responses = await net.fetch('https://api.github.com/repos/CherryHQ/cherry-studio/releases?per_page=8', {
|
|
||||||
headers
|
|
||||||
})
|
|
||||||
const data = (await responses.json()) as GithubReleaseInfo[]
|
|
||||||
let mightHaveLatest = false
|
|
||||||
const release: GithubReleaseInfo | undefined = data.find((item: GithubReleaseInfo) => {
|
|
||||||
if (!item.draft && !item.prerelease) {
|
|
||||||
mightHaveLatest = true
|
|
||||||
}
|
|
||||||
|
|
||||||
return item.prerelease && item.tag_name.includes(`-${channel}.`)
|
|
||||||
})
|
|
||||||
|
|
||||||
if (!release) {
|
|
||||||
return null
|
|
||||||
}
|
|
||||||
|
|
||||||
// if the release version is the same as the current version, return null
|
|
||||||
if (release.tag_name === app.getVersion()) {
|
|
||||||
return null
|
|
||||||
}
|
|
||||||
|
|
||||||
if (mightHaveLatest) {
|
|
||||||
logger.info(`might have latest release, get latest release`)
|
|
||||||
const latestReleaseResponse = await net.fetch(
|
|
||||||
'https://api.github.com/repos/CherryHQ/cherry-studio/releases/latest',
|
|
||||||
{
|
|
||||||
headers
|
|
||||||
}
|
|
||||||
)
|
|
||||||
const latestRelease = (await latestReleaseResponse.json()) as GithubReleaseInfo
|
|
||||||
if (semver.gt(latestRelease.tag_name, release.tag_name)) {
|
|
||||||
logger.info(
|
|
||||||
`latest release version is ${latestRelease.tag_name}, prerelease version is ${release.tag_name}, return null`
|
|
||||||
)
|
|
||||||
return null
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
logger.info(`release url is ${release.tag_name}, set channel to ${channel}`)
|
|
||||||
return `https://github.com/CherryHQ/cherry-studio/releases/download/${release.tag_name}`
|
|
||||||
} catch (error) {
|
|
||||||
logger.error('Failed to get latest not draft version from github:', error as Error)
|
|
||||||
return null
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
public setAutoUpdate(isActive: boolean) {
|
public setAutoUpdate(isActive: boolean) {
|
||||||
autoUpdater.autoDownload = isActive
|
autoUpdater.autoDownload = isActive
|
||||||
autoUpdater.autoInstallOnAppQuit = isActive
|
autoUpdater.autoInstallOnAppQuit = isActive
|
||||||
@@ -161,6 +130,88 @@ export default class AppUpdater {
|
|||||||
return UpgradeChannel.LATEST
|
return UpgradeChannel.LATEST
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Fetch update configuration from GitHub or GitCode based on mirror
|
||||||
|
* @param mirror - Mirror to fetch config from
|
||||||
|
* @returns UpdateConfig object or null if fetch fails
|
||||||
|
*/
|
||||||
|
private async _fetchUpdateConfig(mirror: UpdateMirror): Promise<UpdateConfig | null> {
|
||||||
|
const configUrl = mirror === UpdateMirror.GITCODE ? UpdateConfigUrl.GITCODE : UpdateConfigUrl.GITHUB
|
||||||
|
|
||||||
|
try {
|
||||||
|
logger.info(`Fetching update config from ${configUrl} (mirror: ${mirror})`)
|
||||||
|
const response = await net.fetch(configUrl, {
|
||||||
|
headers: {
|
||||||
|
'User-Agent': generateUserAgent(),
|
||||||
|
Accept: 'application/json',
|
||||||
|
'X-Client-Id': configManager.getClientId(),
|
||||||
|
// no-cache
|
||||||
|
'Cache-Control': 'no-cache'
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
if (!response.ok) {
|
||||||
|
throw new Error(`HTTP error! status: ${response.status}`)
|
||||||
|
}
|
||||||
|
|
||||||
|
const config = (await response.json()) as UpdateConfig
|
||||||
|
logger.info(`Update config fetched successfully, last updated: ${config.lastUpdated}`)
|
||||||
|
return config
|
||||||
|
} catch (error) {
|
||||||
|
logger.error('Failed to fetch update config:', error as Error)
|
||||||
|
return null
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Find compatible channel configuration based on current version
|
||||||
|
* @param currentVersion - Current app version
|
||||||
|
* @param requestedChannel - Requested upgrade channel (latest/rc/beta)
|
||||||
|
* @param config - Update configuration object
|
||||||
|
* @returns Object containing ChannelConfig and actual channel if found, null otherwise
|
||||||
|
*/
|
||||||
|
private _findCompatibleChannel(
|
||||||
|
currentVersion: string,
|
||||||
|
requestedChannel: UpgradeChannel,
|
||||||
|
config: UpdateConfig
|
||||||
|
): { config: ChannelConfig; channel: UpgradeChannel } | null {
|
||||||
|
// Get all version keys and sort descending (newest first)
|
||||||
|
const versionKeys = Object.keys(config.versions).sort(semver.rcompare)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
`Finding compatible channel for version ${currentVersion}, requested channel: ${requestedChannel}, available versions: ${versionKeys.join(', ')}`
|
||||||
|
)
|
||||||
|
|
||||||
|
for (const versionKey of versionKeys) {
|
||||||
|
const versionConfig = config.versions[versionKey]
|
||||||
|
const channelConfig = versionConfig.channels[requestedChannel]
|
||||||
|
const latestChannelConfig = versionConfig.channels[UpgradeChannel.LATEST]
|
||||||
|
|
||||||
|
// Check version compatibility and channel availability
|
||||||
|
if (semver.gte(currentVersion, versionConfig.minCompatibleVersion) && channelConfig !== null) {
|
||||||
|
logger.info(
|
||||||
|
`Found compatible version: ${versionKey} (minCompatibleVersion: ${versionConfig.minCompatibleVersion}), version: ${channelConfig.version}`
|
||||||
|
)
|
||||||
|
|
||||||
|
if (
|
||||||
|
requestedChannel !== UpgradeChannel.LATEST &&
|
||||||
|
latestChannelConfig &&
|
||||||
|
semver.gte(latestChannelConfig.version, channelConfig.version)
|
||||||
|
) {
|
||||||
|
logger.info(
|
||||||
|
`latest channel version is greater than the requested channel version: ${latestChannelConfig.version} > ${channelConfig.version}, using latest instead`
|
||||||
|
)
|
||||||
|
return { config: latestChannelConfig, channel: UpgradeChannel.LATEST }
|
||||||
|
}
|
||||||
|
|
||||||
|
return { config: channelConfig, channel: requestedChannel }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.warn(`No compatible channel found for version ${currentVersion} and channel ${requestedChannel}`)
|
||||||
|
return null
|
||||||
|
}
|
||||||
|
|
||||||
private _setChannel(channel: UpgradeChannel, feedUrl: string) {
|
private _setChannel(channel: UpgradeChannel, feedUrl: string) {
|
||||||
this.autoUpdater.channel = channel
|
this.autoUpdater.channel = channel
|
||||||
this.autoUpdater.setFeedURL(feedUrl)
|
this.autoUpdater.setFeedURL(feedUrl)
|
||||||
@@ -172,33 +223,42 @@ export default class AppUpdater {
|
|||||||
}
|
}
|
||||||
|
|
||||||
private async _setFeedUrl() {
|
private async _setFeedUrl() {
|
||||||
|
const currentVersion = app.getVersion()
|
||||||
const testPlan = configManager.getTestPlan()
|
const testPlan = configManager.getTestPlan()
|
||||||
if (testPlan) {
|
const requestedChannel = testPlan ? this._getTestChannel() : UpgradeChannel.LATEST
|
||||||
const channel = this._getTestChannel()
|
|
||||||
|
|
||||||
if (channel === UpgradeChannel.LATEST) {
|
// Determine mirror based on IP country
|
||||||
this._setChannel(UpgradeChannel.LATEST, FeedUrl.GITHUB_LATEST)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
const releaseUrl = await this._getReleaseVersionFromGithub(channel)
|
|
||||||
if (releaseUrl) {
|
|
||||||
logger.info(`release url is ${releaseUrl}, set channel to ${channel}`)
|
|
||||||
this._setChannel(channel, releaseUrl)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// if no prerelease url, use github latest to get release
|
|
||||||
this._setChannel(UpgradeChannel.LATEST, FeedUrl.GITHUB_LATEST)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
this._setChannel(UpgradeChannel.LATEST, FeedUrl.PRODUCTION)
|
|
||||||
const ipCountry = await getIpCountry()
|
const ipCountry = await getIpCountry()
|
||||||
logger.info(`ipCountry is ${ipCountry}, set channel to ${UpgradeChannel.LATEST}`)
|
const mirror = ipCountry.toLowerCase() === 'cn' ? UpdateMirror.GITCODE : UpdateMirror.GITHUB
|
||||||
if (ipCountry.toLowerCase() !== 'cn') {
|
|
||||||
this._setChannel(UpgradeChannel.LATEST, FeedUrl.GITHUB_LATEST)
|
logger.info(
|
||||||
|
`Setting feed URL for version ${currentVersion}, testPlan: ${testPlan}, requested channel: ${requestedChannel}, mirror: ${mirror} (IP country: ${ipCountry})`
|
||||||
|
)
|
||||||
|
|
||||||
|
// Try to fetch update config from remote
|
||||||
|
const config = await this._fetchUpdateConfig(mirror)
|
||||||
|
|
||||||
|
if (config) {
|
||||||
|
// Use new config-based system
|
||||||
|
const result = this._findCompatibleChannel(currentVersion, requestedChannel, config)
|
||||||
|
|
||||||
|
if (result) {
|
||||||
|
const { config: channelConfig, channel: actualChannel } = result
|
||||||
|
const feedUrl = channelConfig.feedUrls[mirror]
|
||||||
|
logger.info(
|
||||||
|
`Using config-based feed URL: ${feedUrl} for channel ${actualChannel} (requested: ${requestedChannel}, mirror: ${mirror})`
|
||||||
|
)
|
||||||
|
this._setChannel(actualChannel, feedUrl)
|
||||||
|
return
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
logger.info('Failed to fetch update config, falling back to default feed URL')
|
||||||
|
// Fallback: use default feed URL based on mirror
|
||||||
|
const defaultFeedUrl = mirror === UpdateMirror.GITCODE ? FeedUrl.PRODUCTION : FeedUrl.GITHUB_LATEST
|
||||||
|
|
||||||
|
logger.info(`Using fallback feed URL: ${defaultFeedUrl}`)
|
||||||
|
this._setChannel(UpgradeChannel.LATEST, defaultFeedUrl)
|
||||||
}
|
}
|
||||||
|
|
||||||
public cancelDownload() {
|
public cancelDownload() {
|
||||||
@@ -320,8 +380,3 @@ export default class AppUpdater {
|
|||||||
return processedInfo
|
return processedInfo
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
interface GithubReleaseInfo {
|
|
||||||
draft: boolean
|
|
||||||
prerelease: boolean
|
|
||||||
tag_name: string
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -375,13 +375,16 @@ export class WindowService {
|
|||||||
|
|
||||||
mainWindow.hide()
|
mainWindow.hide()
|
||||||
|
|
||||||
// TODO: don't hide dock icon when close to tray
|
//for mac users, should hide dock icon if close to tray
|
||||||
// will cause the cmd+h behavior not working
|
if (isMac && isTrayOnClose) {
|
||||||
// after the electron fix the bug, we can restore this code
|
app.dock?.hide()
|
||||||
// //for mac users, should hide dock icon if close to tray
|
|
||||||
// if (isMac && isTrayOnClose) {
|
mainWindow.once('show', () => {
|
||||||
// app.dock?.hide()
|
//restore the window can hide by cmd+h when the window is shown again
|
||||||
// }
|
// https://github.com/electron/electron/pull/47970
|
||||||
|
app.dock?.show()
|
||||||
|
})
|
||||||
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
mainWindow.on('closed', () => {
|
mainWindow.on('closed', () => {
|
||||||
|
|||||||
@@ -85,6 +85,9 @@ vi.mock('electron-updater', () => ({
|
|||||||
}))
|
}))
|
||||||
|
|
||||||
// Import after mocks
|
// Import after mocks
|
||||||
|
import { UpdateMirror } from '@shared/config/constant'
|
||||||
|
import { app, net } from 'electron'
|
||||||
|
|
||||||
import AppUpdater from '../AppUpdater'
|
import AppUpdater from '../AppUpdater'
|
||||||
import { configManager } from '../ConfigManager'
|
import { configManager } from '../ConfigManager'
|
||||||
|
|
||||||
@@ -274,4 +277,711 @@ describe('AppUpdater', () => {
|
|||||||
expect(result.releaseNotes).toBeNull()
|
expect(result.releaseNotes).toBeNull()
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
|
describe('_fetchUpdateConfig', () => {
|
||||||
|
const mockConfig = {
|
||||||
|
lastUpdated: '2025-01-05T00:00:00Z',
|
||||||
|
versions: {
|
||||||
|
'1.6.7': {
|
||||||
|
minCompatibleVersion: '1.0.0',
|
||||||
|
description: 'Test version',
|
||||||
|
channels: {
|
||||||
|
latest: {
|
||||||
|
version: '1.6.7',
|
||||||
|
feedUrls: {
|
||||||
|
github: 'https://github.com/test/v1.6.7',
|
||||||
|
gitcode: 'https://gitcode.com/test/v1.6.7'
|
||||||
|
}
|
||||||
|
},
|
||||||
|
rc: null,
|
||||||
|
beta: null
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
it('should fetch config from GitHub mirror', async () => {
|
||||||
|
vi.mocked(net.fetch).mockResolvedValue({
|
||||||
|
ok: true,
|
||||||
|
json: async () => mockConfig
|
||||||
|
} as any)
|
||||||
|
|
||||||
|
const result = await (appUpdater as any)._fetchUpdateConfig(UpdateMirror.GITHUB)
|
||||||
|
|
||||||
|
expect(result).toEqual(mockConfig)
|
||||||
|
expect(net.fetch).toHaveBeenCalledWith(expect.stringContaining('github'), expect.any(Object))
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should fetch config from GitCode mirror', async () => {
|
||||||
|
vi.mocked(net.fetch).mockResolvedValue({
|
||||||
|
ok: true,
|
||||||
|
json: async () => mockConfig
|
||||||
|
} as any)
|
||||||
|
|
||||||
|
const result = await (appUpdater as any)._fetchUpdateConfig(UpdateMirror.GITCODE)
|
||||||
|
|
||||||
|
expect(result).toEqual(mockConfig)
|
||||||
|
// GitCode URL may vary, just check that fetch was called
|
||||||
|
expect(net.fetch).toHaveBeenCalledWith(expect.any(String), expect.any(Object))
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should return null on HTTP error', async () => {
|
||||||
|
vi.mocked(net.fetch).mockResolvedValue({
|
||||||
|
ok: false,
|
||||||
|
status: 404
|
||||||
|
} as any)
|
||||||
|
|
||||||
|
const result = await (appUpdater as any)._fetchUpdateConfig(UpdateMirror.GITHUB)
|
||||||
|
|
||||||
|
expect(result).toBeNull()
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should return null on network error', async () => {
|
||||||
|
vi.mocked(net.fetch).mockRejectedValue(new Error('Network error'))
|
||||||
|
|
||||||
|
const result = await (appUpdater as any)._fetchUpdateConfig(UpdateMirror.GITHUB)
|
||||||
|
|
||||||
|
expect(result).toBeNull()
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('_findCompatibleChannel', () => {
|
||||||
|
const mockConfig = {
|
||||||
|
lastUpdated: '2025-01-05T00:00:00Z',
|
||||||
|
versions: {
|
||||||
|
'1.6.7': {
|
||||||
|
minCompatibleVersion: '1.0.0',
|
||||||
|
description: 'v1.6.7',
|
||||||
|
channels: {
|
||||||
|
latest: {
|
||||||
|
version: '1.6.7',
|
||||||
|
feedUrls: {
|
||||||
|
github: 'https://github.com/test/v1.6.7',
|
||||||
|
gitcode: 'https://gitcode.com/test/v1.6.7'
|
||||||
|
}
|
||||||
|
},
|
||||||
|
rc: {
|
||||||
|
version: '1.7.0-rc.1',
|
||||||
|
feedUrls: {
|
||||||
|
github: 'https://github.com/test/v1.7.0-rc.1',
|
||||||
|
gitcode: 'https://gitcode.com/test/v1.7.0-rc.1'
|
||||||
|
}
|
||||||
|
},
|
||||||
|
beta: {
|
||||||
|
version: '1.7.0-beta.3',
|
||||||
|
feedUrls: {
|
||||||
|
github: 'https://github.com/test/v1.7.0-beta.3',
|
||||||
|
gitcode: 'https://gitcode.com/test/v1.7.0-beta.3'
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
'2.0.0': {
|
||||||
|
minCompatibleVersion: '1.7.0',
|
||||||
|
description: 'v2.0.0',
|
||||||
|
channels: {
|
||||||
|
latest: null,
|
||||||
|
rc: null,
|
||||||
|
beta: null
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
it('should find compatible latest channel', () => {
|
||||||
|
vi.mocked(app.getVersion).mockReturnValue('1.5.0')
|
||||||
|
|
||||||
|
const result = (appUpdater as any)._findCompatibleChannel('1.5.0', 'latest', mockConfig)
|
||||||
|
|
||||||
|
expect(result?.config).toEqual({
|
||||||
|
version: '1.6.7',
|
||||||
|
feedUrls: {
|
||||||
|
github: 'https://github.com/test/v1.6.7',
|
||||||
|
gitcode: 'https://gitcode.com/test/v1.6.7'
|
||||||
|
}
|
||||||
|
})
|
||||||
|
expect(result?.channel).toBe('latest')
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should find compatible rc channel', () => {
|
||||||
|
vi.mocked(app.getVersion).mockReturnValue('1.5.0')
|
||||||
|
|
||||||
|
const result = (appUpdater as any)._findCompatibleChannel('1.5.0', 'rc', mockConfig)
|
||||||
|
|
||||||
|
expect(result?.config).toEqual({
|
||||||
|
version: '1.7.0-rc.1',
|
||||||
|
feedUrls: {
|
||||||
|
github: 'https://github.com/test/v1.7.0-rc.1',
|
||||||
|
gitcode: 'https://gitcode.com/test/v1.7.0-rc.1'
|
||||||
|
}
|
||||||
|
})
|
||||||
|
expect(result?.channel).toBe('rc')
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should find compatible beta channel', () => {
|
||||||
|
vi.mocked(app.getVersion).mockReturnValue('1.5.0')
|
||||||
|
|
||||||
|
const result = (appUpdater as any)._findCompatibleChannel('1.5.0', 'beta', mockConfig)
|
||||||
|
|
||||||
|
expect(result?.config).toEqual({
|
||||||
|
version: '1.7.0-beta.3',
|
||||||
|
feedUrls: {
|
||||||
|
github: 'https://github.com/test/v1.7.0-beta.3',
|
||||||
|
gitcode: 'https://gitcode.com/test/v1.7.0-beta.3'
|
||||||
|
}
|
||||||
|
})
|
||||||
|
expect(result?.channel).toBe('beta')
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should return latest when latest version >= rc version', () => {
|
||||||
|
const configWithNewerLatest = {
|
||||||
|
lastUpdated: '2025-01-05T00:00:00Z',
|
||||||
|
versions: {
|
||||||
|
'1.7.0': {
|
||||||
|
minCompatibleVersion: '1.0.0',
|
||||||
|
description: 'v1.7.0',
|
||||||
|
channels: {
|
||||||
|
latest: {
|
||||||
|
version: '1.7.0',
|
||||||
|
feedUrls: {
|
||||||
|
github: 'https://github.com/test/v1.7.0',
|
||||||
|
gitcode: 'https://gitcode.com/test/v1.7.0'
|
||||||
|
}
|
||||||
|
},
|
||||||
|
rc: {
|
||||||
|
version: '1.7.0-rc.1',
|
||||||
|
feedUrls: {
|
||||||
|
github: 'https://github.com/test/v1.7.0-rc.1',
|
||||||
|
gitcode: 'https://gitcode.com/test/v1.7.0-rc.1'
|
||||||
|
}
|
||||||
|
},
|
||||||
|
beta: null
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const result = (appUpdater as any)._findCompatibleChannel('1.6.0', 'rc', configWithNewerLatest)
|
||||||
|
|
||||||
|
// Should return latest instead of rc because 1.7.0 >= 1.7.0-rc.1
|
||||||
|
expect(result?.config).toEqual({
|
||||||
|
version: '1.7.0',
|
||||||
|
feedUrls: {
|
||||||
|
github: 'https://github.com/test/v1.7.0',
|
||||||
|
gitcode: 'https://gitcode.com/test/v1.7.0'
|
||||||
|
}
|
||||||
|
})
|
||||||
|
expect(result?.channel).toBe('latest') // ✅ 返回 latest 频道
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should return latest when latest version >= beta version', () => {
|
||||||
|
const configWithNewerLatest = {
|
||||||
|
lastUpdated: '2025-01-05T00:00:00Z',
|
||||||
|
versions: {
|
||||||
|
'1.7.0': {
|
||||||
|
minCompatibleVersion: '1.0.0',
|
||||||
|
description: 'v1.7.0',
|
||||||
|
channels: {
|
||||||
|
latest: {
|
||||||
|
version: '1.7.0',
|
||||||
|
|
||||||
|
feedUrls: {
|
||||||
|
github: 'https://github.com/test/v1.7.0',
|
||||||
|
|
||||||
|
gitcode: 'https://gitcode.com/test/v1.7.0'
|
||||||
|
}
|
||||||
|
},
|
||||||
|
rc: null,
|
||||||
|
beta: {
|
||||||
|
version: '1.6.8-beta.1',
|
||||||
|
|
||||||
|
feedUrls: {
|
||||||
|
github: 'https://github.com/test/v1.6.8-beta.1',
|
||||||
|
|
||||||
|
gitcode: 'https://gitcode.com/test/v1.6.8-beta.1'
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const result = (appUpdater as any)._findCompatibleChannel('1.6.0', 'beta', configWithNewerLatest)
|
||||||
|
|
||||||
|
// Should return latest instead of beta because 1.7.0 >= 1.6.8-beta.1
|
||||||
|
expect(result?.config).toEqual({
|
||||||
|
version: '1.7.0',
|
||||||
|
|
||||||
|
feedUrls: {
|
||||||
|
github: 'https://github.com/test/v1.7.0',
|
||||||
|
|
||||||
|
gitcode: 'https://gitcode.com/test/v1.7.0'
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should not compare latest with itself when requesting latest channel', () => {
|
||||||
|
const config = {
|
||||||
|
lastUpdated: '2025-01-05T00:00:00Z',
|
||||||
|
versions: {
|
||||||
|
'1.7.0': {
|
||||||
|
minCompatibleVersion: '1.0.0',
|
||||||
|
description: 'v1.7.0',
|
||||||
|
channels: {
|
||||||
|
latest: {
|
||||||
|
version: '1.7.0',
|
||||||
|
|
||||||
|
feedUrls: {
|
||||||
|
github: 'https://github.com/test/v1.7.0',
|
||||||
|
|
||||||
|
gitcode: 'https://gitcode.com/test/v1.7.0'
|
||||||
|
}
|
||||||
|
},
|
||||||
|
rc: {
|
||||||
|
version: '1.7.0-rc.1',
|
||||||
|
|
||||||
|
feedUrls: {
|
||||||
|
github: 'https://github.com/test/v1.7.0-rc.1',
|
||||||
|
|
||||||
|
gitcode: 'https://gitcode.com/test/v1.7.0-rc.1'
|
||||||
|
}
|
||||||
|
},
|
||||||
|
beta: null
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const result = (appUpdater as any)._findCompatibleChannel('1.6.0', 'latest', config)
|
||||||
|
|
||||||
|
// Should return latest directly without comparing with itself
|
||||||
|
expect(result?.config).toEqual({
|
||||||
|
version: '1.7.0',
|
||||||
|
|
||||||
|
feedUrls: {
|
||||||
|
github: 'https://github.com/test/v1.7.0',
|
||||||
|
|
||||||
|
gitcode: 'https://gitcode.com/test/v1.7.0'
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should return rc when rc version > latest version', () => {
|
||||||
|
const configWithNewerRc = {
|
||||||
|
lastUpdated: '2025-01-05T00:00:00Z',
|
||||||
|
versions: {
|
||||||
|
'1.7.0': {
|
||||||
|
minCompatibleVersion: '1.0.0',
|
||||||
|
description: 'v1.7.0',
|
||||||
|
channels: {
|
||||||
|
latest: {
|
||||||
|
version: '1.6.7',
|
||||||
|
|
||||||
|
feedUrls: {
|
||||||
|
github: 'https://github.com/test/v1.6.7',
|
||||||
|
|
||||||
|
gitcode: 'https://gitcode.com/test/v1.6.7'
|
||||||
|
}
|
||||||
|
},
|
||||||
|
rc: {
|
||||||
|
version: '1.7.0-rc.1',
|
||||||
|
|
||||||
|
feedUrls: {
|
||||||
|
github: 'https://github.com/test/v1.7.0-rc.1',
|
||||||
|
|
||||||
|
gitcode: 'https://gitcode.com/test/v1.7.0-rc.1'
|
||||||
|
}
|
||||||
|
},
|
||||||
|
beta: null
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const result = (appUpdater as any)._findCompatibleChannel('1.6.0', 'rc', configWithNewerRc)
|
||||||
|
|
||||||
|
// Should return rc because 1.7.0-rc.1 > 1.6.7
|
||||||
|
expect(result?.config).toEqual({
|
||||||
|
version: '1.7.0-rc.1',
|
||||||
|
|
||||||
|
feedUrls: {
|
||||||
|
github: 'https://github.com/test/v1.7.0-rc.1',
|
||||||
|
|
||||||
|
gitcode: 'https://gitcode.com/test/v1.7.0-rc.1'
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should return beta when beta version > latest version', () => {
|
||||||
|
const configWithNewerBeta = {
|
||||||
|
lastUpdated: '2025-01-05T00:00:00Z',
|
||||||
|
versions: {
|
||||||
|
'1.7.0': {
|
||||||
|
minCompatibleVersion: '1.0.0',
|
||||||
|
description: 'v1.7.0',
|
||||||
|
channels: {
|
||||||
|
latest: {
|
||||||
|
version: '1.6.7',
|
||||||
|
|
||||||
|
feedUrls: {
|
||||||
|
github: 'https://github.com/test/v1.6.7',
|
||||||
|
|
||||||
|
gitcode: 'https://gitcode.com/test/v1.6.7'
|
||||||
|
}
|
||||||
|
},
|
||||||
|
rc: null,
|
||||||
|
beta: {
|
||||||
|
version: '1.7.0-beta.5',
|
||||||
|
|
||||||
|
feedUrls: {
|
||||||
|
github: 'https://github.com/test/v1.7.0-beta.5',
|
||||||
|
|
||||||
|
gitcode: 'https://gitcode.com/test/v1.7.0-beta.5'
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const result = (appUpdater as any)._findCompatibleChannel('1.6.0', 'beta', configWithNewerBeta)
|
||||||
|
|
||||||
|
// Should return beta because 1.7.0-beta.5 > 1.6.7
|
||||||
|
expect(result?.config).toEqual({
|
||||||
|
version: '1.7.0-beta.5',
|
||||||
|
|
||||||
|
feedUrls: {
|
||||||
|
github: 'https://github.com/test/v1.7.0-beta.5',
|
||||||
|
|
||||||
|
gitcode: 'https://gitcode.com/test/v1.7.0-beta.5'
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should return lower version when higher version has no compatible channel', () => {
|
||||||
|
vi.mocked(app.getVersion).mockReturnValue('1.8.0')
|
||||||
|
|
||||||
|
const result = (appUpdater as any)._findCompatibleChannel('1.8.0', 'latest', mockConfig)
|
||||||
|
|
||||||
|
// 1.8.0 >= 1.7.0 but 2.0.0 has no latest channel, so return 1.6.7
|
||||||
|
expect(result?.config).toEqual({
|
||||||
|
version: '1.6.7',
|
||||||
|
|
||||||
|
feedUrls: {
|
||||||
|
github: 'https://github.com/test/v1.6.7',
|
||||||
|
|
||||||
|
gitcode: 'https://gitcode.com/test/v1.6.7'
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should return null when current version does not meet minCompatibleVersion', () => {
|
||||||
|
vi.mocked(app.getVersion).mockReturnValue('0.9.0')
|
||||||
|
|
||||||
|
const result = (appUpdater as any)._findCompatibleChannel('0.9.0', 'latest', mockConfig)
|
||||||
|
|
||||||
|
// 0.9.0 < 1.0.0 (minCompatibleVersion)
|
||||||
|
expect(result).toBeNull()
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should return lower version rc when higher version has no rc channel', () => {
|
||||||
|
const result = (appUpdater as any)._findCompatibleChannel('1.8.0', 'rc', mockConfig)
|
||||||
|
|
||||||
|
// 1.8.0 >= 1.7.0 but 2.0.0 has no rc channel, so return 1.6.7 rc
|
||||||
|
expect(result?.config).toEqual({
|
||||||
|
version: '1.7.0-rc.1',
|
||||||
|
|
||||||
|
feedUrls: {
|
||||||
|
github: 'https://github.com/test/v1.7.0-rc.1',
|
||||||
|
|
||||||
|
gitcode: 'https://gitcode.com/test/v1.7.0-rc.1'
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should return null when no version has the requested channel', () => {
|
||||||
|
const configWithoutRc = {
|
||||||
|
lastUpdated: '2025-01-05T00:00:00Z',
|
||||||
|
versions: {
|
||||||
|
'1.6.7': {
|
||||||
|
minCompatibleVersion: '1.0.0',
|
||||||
|
description: 'v1.6.7',
|
||||||
|
channels: {
|
||||||
|
latest: {
|
||||||
|
version: '1.6.7',
|
||||||
|
|
||||||
|
feedUrls: {
|
||||||
|
github: 'https://github.com/test/v1.6.7',
|
||||||
|
|
||||||
|
gitcode: 'https://gitcode.com/test/v1.6.7'
|
||||||
|
}
|
||||||
|
},
|
||||||
|
rc: null,
|
||||||
|
beta: null
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const result = (appUpdater as any)._findCompatibleChannel('1.5.0', 'rc', configWithoutRc)
|
||||||
|
|
||||||
|
expect(result).toBeNull()
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('Upgrade Path', () => {
|
||||||
|
const fullConfig = {
|
||||||
|
lastUpdated: '2025-01-05T00:00:00Z',
|
||||||
|
versions: {
|
||||||
|
'1.6.7': {
|
||||||
|
minCompatibleVersion: '1.0.0',
|
||||||
|
description: 'Last v1.x',
|
||||||
|
channels: {
|
||||||
|
latest: {
|
||||||
|
version: '1.6.7',
|
||||||
|
|
||||||
|
feedUrls: {
|
||||||
|
github: 'https://github.com/test/v1.6.7',
|
||||||
|
|
||||||
|
gitcode: 'https://gitcode.com/test/v1.6.7'
|
||||||
|
}
|
||||||
|
},
|
||||||
|
rc: {
|
||||||
|
version: '1.7.0-rc.1',
|
||||||
|
|
||||||
|
feedUrls: {
|
||||||
|
github: 'https://github.com/test/v1.7.0-rc.1',
|
||||||
|
|
||||||
|
gitcode: 'https://gitcode.com/test/v1.7.0-rc.1'
|
||||||
|
}
|
||||||
|
},
|
||||||
|
beta: {
|
||||||
|
version: '1.7.0-beta.3',
|
||||||
|
|
||||||
|
feedUrls: {
|
||||||
|
github: 'https://github.com/test/v1.7.0-beta.3',
|
||||||
|
|
||||||
|
gitcode: 'https://gitcode.com/test/v1.7.0-beta.3'
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
'2.0.0': {
|
||||||
|
minCompatibleVersion: '1.7.0',
|
||||||
|
description: 'First v2.x',
|
||||||
|
channels: {
|
||||||
|
latest: null,
|
||||||
|
rc: null,
|
||||||
|
beta: null
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
it('should upgrade from 1.6.3 to 1.6.7', () => {
|
||||||
|
const result = (appUpdater as any)._findCompatibleChannel('1.6.3', 'latest', fullConfig)
|
||||||
|
|
||||||
|
expect(result?.config).toEqual({
|
||||||
|
version: '1.6.7',
|
||||||
|
|
||||||
|
feedUrls: {
|
||||||
|
github: 'https://github.com/test/v1.6.7',
|
||||||
|
|
||||||
|
gitcode: 'https://gitcode.com/test/v1.6.7'
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should block upgrade from 1.6.7 to 2.0.0 (minCompatibleVersion not met)', () => {
|
||||||
|
const result = (appUpdater as any)._findCompatibleChannel('1.6.7', 'latest', fullConfig)
|
||||||
|
|
||||||
|
// Should return 1.6.7, not 2.0.0, because 1.6.7 < 1.7.0 (minCompatibleVersion of 2.0.0)
|
||||||
|
expect(result?.config).toEqual({
|
||||||
|
version: '1.6.7',
|
||||||
|
|
||||||
|
feedUrls: {
|
||||||
|
github: 'https://github.com/test/v1.6.7',
|
||||||
|
|
||||||
|
gitcode: 'https://gitcode.com/test/v1.6.7'
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should allow upgrade from 1.7.0 to 2.0.0', () => {
|
||||||
|
const configWith2x = {
|
||||||
|
...fullConfig,
|
||||||
|
versions: {
|
||||||
|
...fullConfig.versions,
|
||||||
|
'2.0.0': {
|
||||||
|
minCompatibleVersion: '1.7.0',
|
||||||
|
description: 'First v2.x',
|
||||||
|
channels: {
|
||||||
|
latest: {
|
||||||
|
version: '2.0.0',
|
||||||
|
|
||||||
|
feedUrls: {
|
||||||
|
github: 'https://github.com/test/v2.0.0',
|
||||||
|
|
||||||
|
gitcode: 'https://gitcode.com/test/v2.0.0'
|
||||||
|
}
|
||||||
|
},
|
||||||
|
rc: null,
|
||||||
|
beta: null
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const result = (appUpdater as any)._findCompatibleChannel('1.7.0', 'latest', configWith2x)
|
||||||
|
|
||||||
|
expect(result?.config).toEqual({
|
||||||
|
version: '2.0.0',
|
||||||
|
|
||||||
|
feedUrls: {
|
||||||
|
github: 'https://github.com/test/v2.0.0',
|
||||||
|
|
||||||
|
gitcode: 'https://gitcode.com/test/v2.0.0'
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('Complete Multi-Step Upgrade Path', () => {
|
||||||
|
const fullUpgradeConfig = {
|
||||||
|
lastUpdated: '2025-01-05T00:00:00Z',
|
||||||
|
versions: {
|
||||||
|
'1.7.5': {
|
||||||
|
minCompatibleVersion: '1.0.0',
|
||||||
|
description: 'Last v1.x stable',
|
||||||
|
channels: {
|
||||||
|
latest: {
|
||||||
|
version: '1.7.5',
|
||||||
|
|
||||||
|
feedUrls: {
|
||||||
|
github: 'https://github.com/test/v1.7.5',
|
||||||
|
|
||||||
|
gitcode: 'https://gitcode.com/test/v1.7.5'
|
||||||
|
}
|
||||||
|
},
|
||||||
|
rc: null,
|
||||||
|
beta: null
|
||||||
|
}
|
||||||
|
},
|
||||||
|
'2.0.0': {
|
||||||
|
minCompatibleVersion: '1.7.0',
|
||||||
|
description: 'First v2.x - intermediate version',
|
||||||
|
channels: {
|
||||||
|
latest: {
|
||||||
|
version: '2.0.0',
|
||||||
|
|
||||||
|
feedUrls: {
|
||||||
|
github: 'https://github.com/test/v2.0.0',
|
||||||
|
|
||||||
|
gitcode: 'https://gitcode.com/test/v2.0.0'
|
||||||
|
}
|
||||||
|
},
|
||||||
|
rc: null,
|
||||||
|
beta: null
|
||||||
|
}
|
||||||
|
},
|
||||||
|
'2.1.6': {
|
||||||
|
minCompatibleVersion: '2.0.0',
|
||||||
|
description: 'Current v2.x stable',
|
||||||
|
channels: {
|
||||||
|
latest: {
|
||||||
|
version: '2.1.6',
|
||||||
|
|
||||||
|
feedUrls: {
|
||||||
|
github: 'https://github.com/test/latest',
|
||||||
|
|
||||||
|
gitcode: 'https://gitcode.com/test/latest'
|
||||||
|
}
|
||||||
|
},
|
||||||
|
rc: null,
|
||||||
|
beta: null
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
it('should upgrade from 1.6.3 to 1.7.5 (step 1)', () => {
|
||||||
|
const result = (appUpdater as any)._findCompatibleChannel('1.6.3', 'latest', fullUpgradeConfig)
|
||||||
|
|
||||||
|
expect(result?.config).toEqual({
|
||||||
|
version: '1.7.5',
|
||||||
|
|
||||||
|
feedUrls: {
|
||||||
|
github: 'https://github.com/test/v1.7.5',
|
||||||
|
|
||||||
|
gitcode: 'https://gitcode.com/test/v1.7.5'
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should upgrade from 1.7.5 to 2.0.0 (step 2)', () => {
|
||||||
|
const result = (appUpdater as any)._findCompatibleChannel('1.7.5', 'latest', fullUpgradeConfig)
|
||||||
|
|
||||||
|
expect(result?.config).toEqual({
|
||||||
|
version: '2.0.0',
|
||||||
|
|
||||||
|
feedUrls: {
|
||||||
|
github: 'https://github.com/test/v2.0.0',
|
||||||
|
|
||||||
|
gitcode: 'https://gitcode.com/test/v2.0.0'
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should upgrade from 2.0.0 to 2.1.6 (step 3)', () => {
|
||||||
|
const result = (appUpdater as any)._findCompatibleChannel('2.0.0', 'latest', fullUpgradeConfig)
|
||||||
|
|
||||||
|
expect(result?.config).toEqual({
|
||||||
|
version: '2.1.6',
|
||||||
|
|
||||||
|
feedUrls: {
|
||||||
|
github: 'https://github.com/test/latest',
|
||||||
|
|
||||||
|
gitcode: 'https://gitcode.com/test/latest'
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should complete full upgrade path: 1.6.3 -> 1.7.5 -> 2.0.0 -> 2.1.6', () => {
|
||||||
|
// Step 1: 1.6.3 -> 1.7.5
|
||||||
|
let currentVersion = '1.6.3'
|
||||||
|
let result = (appUpdater as any)._findCompatibleChannel(currentVersion, 'latest', fullUpgradeConfig)
|
||||||
|
expect(result?.config.version).toBe('1.7.5')
|
||||||
|
|
||||||
|
// Step 2: 1.7.5 -> 2.0.0
|
||||||
|
currentVersion = result?.config.version!
|
||||||
|
result = (appUpdater as any)._findCompatibleChannel(currentVersion, 'latest', fullUpgradeConfig)
|
||||||
|
expect(result?.config.version).toBe('2.0.0')
|
||||||
|
|
||||||
|
// Step 3: 2.0.0 -> 2.1.6
|
||||||
|
currentVersion = result?.config.version!
|
||||||
|
result = (appUpdater as any)._findCompatibleChannel(currentVersion, 'latest', fullUpgradeConfig)
|
||||||
|
expect(result?.config.version).toBe('2.1.6')
|
||||||
|
|
||||||
|
// Final: 2.1.6 is the latest, no more upgrades
|
||||||
|
currentVersion = result?.config.version!
|
||||||
|
result = (appUpdater as any)._findCompatibleChannel(currentVersion, 'latest', fullUpgradeConfig)
|
||||||
|
expect(result?.config.version).toBe('2.1.6')
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should block direct upgrade from 1.6.3 to 2.0.0 (skip intermediate)', () => {
|
||||||
|
const result = (appUpdater as any)._findCompatibleChannel('1.6.3', 'latest', fullUpgradeConfig)
|
||||||
|
|
||||||
|
// Should return 1.7.5, not 2.0.0, because 1.6.3 < 1.7.0 (minCompatibleVersion of 2.0.0)
|
||||||
|
expect(result?.config.version).toBe('1.7.5')
|
||||||
|
expect(result?.config.version).not.toBe('2.0.0')
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should block direct upgrade from 1.7.5 to 2.1.6 (skip intermediate)', () => {
|
||||||
|
const result = (appUpdater as any)._findCompatibleChannel('1.7.5', 'latest', fullUpgradeConfig)
|
||||||
|
|
||||||
|
// Should return 2.0.0, not 2.1.6, because 1.7.5 < 2.0.0 (minCompatibleVersion of 2.1.6)
|
||||||
|
expect(result?.config.version).toBe('2.0.0')
|
||||||
|
expect(result?.config.version).not.toBe('2.1.6')
|
||||||
|
})
|
||||||
|
})
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -1,17 +1,13 @@
|
|||||||
import { type Client, createClient } from '@libsql/client'
|
|
||||||
import { loggerService } from '@logger'
|
import { loggerService } from '@logger'
|
||||||
import { mcpApiService } from '@main/apiServer/services/mcp'
|
import { mcpApiService } from '@main/apiServer/services/mcp'
|
||||||
import type { ModelValidationError } from '@main/apiServer/utils'
|
import type { ModelValidationError } from '@main/apiServer/utils'
|
||||||
import { validateModelId } from '@main/apiServer/utils'
|
import { validateModelId } from '@main/apiServer/utils'
|
||||||
import type { AgentType, MCPTool, SlashCommand, Tool } from '@types'
|
import type { AgentType, MCPTool, SlashCommand, Tool } from '@types'
|
||||||
import { objectKeys } from '@types'
|
import { objectKeys } from '@types'
|
||||||
import { drizzle, type LibSQLDatabase } from 'drizzle-orm/libsql'
|
|
||||||
import fs from 'fs'
|
import fs from 'fs'
|
||||||
import path from 'path'
|
import path from 'path'
|
||||||
|
|
||||||
import { MigrationService } from './database/MigrationService'
|
import { DatabaseManager } from './database/DatabaseManager'
|
||||||
import * as schema from './database/schema'
|
|
||||||
import { dbPath } from './drizzle.config'
|
|
||||||
import type { AgentModelField } from './errors'
|
import type { AgentModelField } from './errors'
|
||||||
import { AgentModelValidationError } from './errors'
|
import { AgentModelValidationError } from './errors'
|
||||||
import { builtinSlashCommands } from './services/claudecode/commands'
|
import { builtinSlashCommands } from './services/claudecode/commands'
|
||||||
@@ -20,22 +16,16 @@ import { builtinTools } from './services/claudecode/tools'
|
|||||||
const logger = loggerService.withContext('BaseService')
|
const logger = loggerService.withContext('BaseService')
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Base service class providing shared database connection and utilities
|
* Base service class providing shared utilities for all agent-related services.
|
||||||
* for all agent-related services.
|
|
||||||
*
|
*
|
||||||
* Features:
|
* Features:
|
||||||
* - Programmatic schema management (no CLI dependencies)
|
* - Database access through DatabaseManager singleton
|
||||||
* - Automatic table creation and migration
|
* - JSON field serialization/deserialization
|
||||||
* - Schema version tracking and compatibility checks
|
* - Path validation and creation
|
||||||
* - Transaction-based operations for safety
|
* - Model validation
|
||||||
* - Development vs production mode handling
|
* - MCP tools and slash commands listing
|
||||||
* - Connection retry logic with exponential backoff
|
|
||||||
*/
|
*/
|
||||||
export abstract class BaseService {
|
export abstract class BaseService {
|
||||||
protected static client: Client | null = null
|
|
||||||
protected static db: LibSQLDatabase<typeof schema> | null = null
|
|
||||||
protected static isInitialized = false
|
|
||||||
protected static initializationPromise: Promise<void> | null = null
|
|
||||||
protected jsonFields: string[] = [
|
protected jsonFields: string[] = [
|
||||||
'tools',
|
'tools',
|
||||||
'mcps',
|
'mcps',
|
||||||
@@ -45,23 +35,6 @@ export abstract class BaseService {
|
|||||||
'slash_commands'
|
'slash_commands'
|
||||||
]
|
]
|
||||||
|
|
||||||
/**
|
|
||||||
* Initialize database with retry logic and proper error handling
|
|
||||||
*/
|
|
||||||
protected static async initialize(): Promise<void> {
|
|
||||||
// Return existing initialization if in progress
|
|
||||||
if (BaseService.initializationPromise) {
|
|
||||||
return BaseService.initializationPromise
|
|
||||||
}
|
|
||||||
|
|
||||||
if (BaseService.isInitialized) {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
BaseService.initializationPromise = BaseService.performInitialization()
|
|
||||||
return BaseService.initializationPromise
|
|
||||||
}
|
|
||||||
|
|
||||||
public async listMcpTools(agentType: AgentType, ids?: string[]): Promise<Tool[]> {
|
public async listMcpTools(agentType: AgentType, ids?: string[]): Promise<Tool[]> {
|
||||||
const tools: Tool[] = []
|
const tools: Tool[] = []
|
||||||
if (agentType === 'claude-code') {
|
if (agentType === 'claude-code') {
|
||||||
@@ -101,78 +74,13 @@ export abstract class BaseService {
|
|||||||
return []
|
return []
|
||||||
}
|
}
|
||||||
|
|
||||||
private static async performInitialization(): Promise<void> {
|
/**
|
||||||
const maxRetries = 3
|
* Get database instance
|
||||||
let lastError: Error
|
* Automatically waits for initialization to complete
|
||||||
|
*/
|
||||||
for (let attempt = 1; attempt <= maxRetries; attempt++) {
|
protected async getDatabase() {
|
||||||
try {
|
const dbManager = await DatabaseManager.getInstance()
|
||||||
logger.info(`Initializing Agent database at: ${dbPath} (attempt ${attempt}/${maxRetries})`)
|
return dbManager.getDatabase()
|
||||||
|
|
||||||
// Ensure the database directory exists
|
|
||||||
const dbDir = path.dirname(dbPath)
|
|
||||||
if (!fs.existsSync(dbDir)) {
|
|
||||||
logger.info(`Creating database directory: ${dbDir}`)
|
|
||||||
fs.mkdirSync(dbDir, { recursive: true })
|
|
||||||
}
|
|
||||||
|
|
||||||
BaseService.client = createClient({
|
|
||||||
url: `file:${dbPath}`
|
|
||||||
})
|
|
||||||
|
|
||||||
BaseService.db = drizzle(BaseService.client, { schema })
|
|
||||||
|
|
||||||
// Run database migrations
|
|
||||||
const migrationService = new MigrationService(BaseService.db, BaseService.client)
|
|
||||||
await migrationService.runMigrations()
|
|
||||||
|
|
||||||
BaseService.isInitialized = true
|
|
||||||
logger.info('Agent database initialized successfully')
|
|
||||||
return
|
|
||||||
} catch (error) {
|
|
||||||
lastError = error as Error
|
|
||||||
logger.warn(`Database initialization attempt ${attempt} failed:`, lastError)
|
|
||||||
|
|
||||||
// Clean up on failure
|
|
||||||
if (BaseService.client) {
|
|
||||||
try {
|
|
||||||
BaseService.client.close()
|
|
||||||
} catch (closeError) {
|
|
||||||
logger.warn('Failed to close client during cleanup:', closeError as Error)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
BaseService.client = null
|
|
||||||
BaseService.db = null
|
|
||||||
|
|
||||||
// Wait before retrying (exponential backoff)
|
|
||||||
if (attempt < maxRetries) {
|
|
||||||
const delay = Math.pow(2, attempt) * 1000 // 2s, 4s, 8s
|
|
||||||
logger.info(`Retrying in ${delay}ms...`)
|
|
||||||
await new Promise((resolve) => setTimeout(resolve, delay))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// All retries failed
|
|
||||||
BaseService.initializationPromise = null
|
|
||||||
logger.error('Failed to initialize Agent database after all retries:', lastError!)
|
|
||||||
throw lastError!
|
|
||||||
}
|
|
||||||
|
|
||||||
protected ensureInitialized(): void {
|
|
||||||
if (!BaseService.isInitialized || !BaseService.db || !BaseService.client) {
|
|
||||||
throw new Error('Database not initialized. Call initialize() first.')
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
protected get database(): LibSQLDatabase<typeof schema> {
|
|
||||||
this.ensureInitialized()
|
|
||||||
return BaseService.db!
|
|
||||||
}
|
|
||||||
|
|
||||||
protected get rawClient(): Client {
|
|
||||||
this.ensureInitialized()
|
|
||||||
return BaseService.client!
|
|
||||||
}
|
}
|
||||||
|
|
||||||
protected serializeJsonFields(data: any): any {
|
protected serializeJsonFields(data: any): any {
|
||||||
@@ -284,7 +192,7 @@ export abstract class BaseService {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Force re-initialization (for development/testing)
|
* Validate agent model configuration
|
||||||
*/
|
*/
|
||||||
protected async validateAgentModels(
|
protected async validateAgentModels(
|
||||||
agentType: AgentType,
|
agentType: AgentType,
|
||||||
@@ -325,22 +233,4 @@ export abstract class BaseService {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
static async reinitialize(): Promise<void> {
|
|
||||||
BaseService.isInitialized = false
|
|
||||||
BaseService.initializationPromise = null
|
|
||||||
|
|
||||||
if (BaseService.client) {
|
|
||||||
try {
|
|
||||||
BaseService.client.close()
|
|
||||||
} catch (error) {
|
|
||||||
logger.warn('Failed to close client during reinitialize:', error as Error)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
BaseService.client = null
|
|
||||||
BaseService.db = null
|
|
||||||
|
|
||||||
await BaseService.initialize()
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|||||||
156
src/main/services/agents/database/DatabaseManager.ts
Normal file
156
src/main/services/agents/database/DatabaseManager.ts
Normal file
@@ -0,0 +1,156 @@
|
|||||||
|
import { type Client, createClient } from '@libsql/client'
|
||||||
|
import { loggerService } from '@logger'
|
||||||
|
import type { LibSQLDatabase } from 'drizzle-orm/libsql'
|
||||||
|
import { drizzle } from 'drizzle-orm/libsql'
|
||||||
|
import fs from 'fs'
|
||||||
|
import path from 'path'
|
||||||
|
|
||||||
|
import { dbPath } from '../drizzle.config'
|
||||||
|
import { MigrationService } from './MigrationService'
|
||||||
|
import * as schema from './schema'
|
||||||
|
|
||||||
|
const logger = loggerService.withContext('DatabaseManager')
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Database initialization state
|
||||||
|
*/
|
||||||
|
enum InitState {
|
||||||
|
INITIALIZING = 'initializing',
|
||||||
|
INITIALIZED = 'initialized',
|
||||||
|
FAILED = 'failed'
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* DatabaseManager - Singleton class for managing libsql database connections
|
||||||
|
*
|
||||||
|
* Responsibilities:
|
||||||
|
* - Single source of truth for database connection
|
||||||
|
* - Thread-safe initialization with state management
|
||||||
|
* - Automatic migration handling
|
||||||
|
* - Safe connection cleanup
|
||||||
|
* - Error recovery and retry logic
|
||||||
|
* - Windows platform compatibility fixes
|
||||||
|
*/
|
||||||
|
export class DatabaseManager {
|
||||||
|
private static instance: DatabaseManager | null = null
|
||||||
|
|
||||||
|
private client: Client | null = null
|
||||||
|
private db: LibSQLDatabase<typeof schema> | null = null
|
||||||
|
private state: InitState = InitState.INITIALIZING
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get the singleton instance (database initialization starts automatically)
|
||||||
|
*/
|
||||||
|
public static async getInstance(): Promise<DatabaseManager> {
|
||||||
|
if (DatabaseManager.instance) {
|
||||||
|
return DatabaseManager.instance
|
||||||
|
}
|
||||||
|
|
||||||
|
const instance = new DatabaseManager()
|
||||||
|
await instance.initialize()
|
||||||
|
DatabaseManager.instance = instance
|
||||||
|
|
||||||
|
return instance
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Perform the actual initialization
|
||||||
|
*/
|
||||||
|
public async initialize(): Promise<void> {
|
||||||
|
if (this.state === InitState.INITIALIZED) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
try {
|
||||||
|
logger.info(`Initializing database at: ${dbPath}`)
|
||||||
|
|
||||||
|
// Ensure database directory exists
|
||||||
|
const dbDir = path.dirname(dbPath)
|
||||||
|
if (!fs.existsSync(dbDir)) {
|
||||||
|
logger.info(`Creating database directory: ${dbDir}`)
|
||||||
|
fs.mkdirSync(dbDir, { recursive: true })
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if database file is corrupted (Windows specific check)
|
||||||
|
if (fs.existsSync(dbPath)) {
|
||||||
|
const stats = fs.statSync(dbPath)
|
||||||
|
if (stats.size === 0) {
|
||||||
|
logger.warn('Database file is empty, removing corrupted file')
|
||||||
|
fs.unlinkSync(dbPath)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create client with platform-specific options
|
||||||
|
this.client = createClient({
|
||||||
|
url: `file:${dbPath}`,
|
||||||
|
// intMode: 'number' helps avoid some Windows compatibility issues
|
||||||
|
intMode: 'number'
|
||||||
|
})
|
||||||
|
|
||||||
|
// Create drizzle instance
|
||||||
|
this.db = drizzle(this.client, { schema })
|
||||||
|
|
||||||
|
// Run migrations
|
||||||
|
const migrationService = new MigrationService(this.db, this.client)
|
||||||
|
await migrationService.runMigrations()
|
||||||
|
|
||||||
|
this.state = InitState.INITIALIZED
|
||||||
|
logger.info('Database initialized successfully')
|
||||||
|
} catch (error) {
|
||||||
|
const err = error as Error
|
||||||
|
logger.error('Database initialization failed:', {
|
||||||
|
error: err.message,
|
||||||
|
stack: err.stack
|
||||||
|
})
|
||||||
|
|
||||||
|
// Clean up failed initialization
|
||||||
|
this.cleanupFailedInit()
|
||||||
|
|
||||||
|
// Set failed state
|
||||||
|
this.state = InitState.FAILED
|
||||||
|
throw new Error(`Database initialization failed: ${err.message || 'Unknown error'}`)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Clean up after failed initialization
|
||||||
|
*/
|
||||||
|
private cleanupFailedInit(): void {
|
||||||
|
if (this.client) {
|
||||||
|
try {
|
||||||
|
// On Windows, closing a partially initialized client can crash
|
||||||
|
// Wrap in try-catch and ignore errors during cleanup
|
||||||
|
this.client.close()
|
||||||
|
} catch (error) {
|
||||||
|
logger.warn('Failed to close client during cleanup:', error as Error)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
this.client = null
|
||||||
|
this.db = null
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get the database instance
|
||||||
|
* Automatically waits for initialization to complete
|
||||||
|
* @throws Error if database initialization failed
|
||||||
|
*/
|
||||||
|
public getDatabase(): LibSQLDatabase<typeof schema> {
|
||||||
|
return this.db!
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get the raw client (for advanced operations)
|
||||||
|
* Automatically waits for initialization to complete
|
||||||
|
* @throws Error if database initialization failed
|
||||||
|
*/
|
||||||
|
public async getClient(): Promise<Client> {
|
||||||
|
return this.client!
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Check if database is initialized
|
||||||
|
*/
|
||||||
|
public isInitialized(): boolean {
|
||||||
|
return this.state === InitState.INITIALIZED
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -7,8 +7,14 @@
|
|||||||
* Schema evolution is handled by Drizzle Kit migrations.
|
* Schema evolution is handled by Drizzle Kit migrations.
|
||||||
*/
|
*/
|
||||||
|
|
||||||
|
// Database Manager (Singleton)
|
||||||
|
export * from './DatabaseManager'
|
||||||
|
|
||||||
// Drizzle ORM schemas
|
// Drizzle ORM schemas
|
||||||
export * from './schema'
|
export * from './schema'
|
||||||
|
|
||||||
// Repository helpers
|
// Repository helpers
|
||||||
export * from './sessionMessageRepository'
|
export * from './sessionMessageRepository'
|
||||||
|
|
||||||
|
// Migration Service
|
||||||
|
export * from './MigrationService'
|
||||||
|
|||||||
@@ -15,26 +15,16 @@ import { sessionMessagesTable } from './schema'
|
|||||||
|
|
||||||
const logger = loggerService.withContext('AgentMessageRepository')
|
const logger = loggerService.withContext('AgentMessageRepository')
|
||||||
|
|
||||||
type TxClient = any
|
|
||||||
|
|
||||||
export type PersistUserMessageParams = AgentMessageUserPersistPayload & {
|
export type PersistUserMessageParams = AgentMessageUserPersistPayload & {
|
||||||
sessionId: string
|
sessionId: string
|
||||||
agentSessionId?: string
|
agentSessionId?: string
|
||||||
tx?: TxClient
|
|
||||||
}
|
}
|
||||||
|
|
||||||
export type PersistAssistantMessageParams = AgentMessageAssistantPersistPayload & {
|
export type PersistAssistantMessageParams = AgentMessageAssistantPersistPayload & {
|
||||||
sessionId: string
|
sessionId: string
|
||||||
agentSessionId: string
|
agentSessionId: string
|
||||||
tx?: TxClient
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type PersistExchangeParams = AgentMessagePersistExchangePayload & {
|
|
||||||
tx?: TxClient
|
|
||||||
}
|
|
||||||
|
|
||||||
type PersistExchangeResult = AgentMessagePersistExchangeResult
|
|
||||||
|
|
||||||
class AgentMessageRepository extends BaseService {
|
class AgentMessageRepository extends BaseService {
|
||||||
private static instance: AgentMessageRepository | null = null
|
private static instance: AgentMessageRepository | null = null
|
||||||
|
|
||||||
@@ -87,17 +77,13 @@ class AgentMessageRepository extends BaseService {
|
|||||||
return deserialized
|
return deserialized
|
||||||
}
|
}
|
||||||
|
|
||||||
private getWriter(tx?: TxClient): TxClient {
|
|
||||||
return tx ?? this.database
|
|
||||||
}
|
|
||||||
|
|
||||||
private async findExistingMessageRow(
|
private async findExistingMessageRow(
|
||||||
writer: TxClient,
|
|
||||||
sessionId: string,
|
sessionId: string,
|
||||||
role: string,
|
role: string,
|
||||||
messageId: string
|
messageId: string
|
||||||
): Promise<SessionMessageRow | null> {
|
): Promise<SessionMessageRow | null> {
|
||||||
const candidateRows: SessionMessageRow[] = await writer
|
const database = await this.getDatabase()
|
||||||
|
const candidateRows: SessionMessageRow[] = await database
|
||||||
.select()
|
.select()
|
||||||
.from(sessionMessagesTable)
|
.from(sessionMessagesTable)
|
||||||
.where(and(eq(sessionMessagesTable.session_id, sessionId), eq(sessionMessagesTable.role, role)))
|
.where(and(eq(sessionMessagesTable.session_id, sessionId), eq(sessionMessagesTable.role, role)))
|
||||||
@@ -122,10 +108,7 @@ class AgentMessageRepository extends BaseService {
|
|||||||
private async upsertMessage(
|
private async upsertMessage(
|
||||||
params: PersistUserMessageParams | PersistAssistantMessageParams
|
params: PersistUserMessageParams | PersistAssistantMessageParams
|
||||||
): Promise<AgentSessionMessageEntity> {
|
): Promise<AgentSessionMessageEntity> {
|
||||||
await AgentMessageRepository.initialize()
|
const { sessionId, agentSessionId = '', payload, metadata, createdAt } = params
|
||||||
this.ensureInitialized()
|
|
||||||
|
|
||||||
const { sessionId, agentSessionId = '', payload, metadata, createdAt, tx } = params
|
|
||||||
|
|
||||||
if (!payload?.message?.role) {
|
if (!payload?.message?.role) {
|
||||||
throw new Error('Message payload missing role')
|
throw new Error('Message payload missing role')
|
||||||
@@ -135,18 +118,18 @@ class AgentMessageRepository extends BaseService {
|
|||||||
throw new Error('Message payload missing id')
|
throw new Error('Message payload missing id')
|
||||||
}
|
}
|
||||||
|
|
||||||
const writer = this.getWriter(tx)
|
const database = await this.getDatabase()
|
||||||
const now = createdAt ?? payload.message.createdAt ?? new Date().toISOString()
|
const now = createdAt ?? payload.message.createdAt ?? new Date().toISOString()
|
||||||
const serializedPayload = this.serializeMessage(payload)
|
const serializedPayload = this.serializeMessage(payload)
|
||||||
const serializedMetadata = this.serializeMetadata(metadata)
|
const serializedMetadata = this.serializeMetadata(metadata)
|
||||||
|
|
||||||
const existingRow = await this.findExistingMessageRow(writer, sessionId, payload.message.role, payload.message.id)
|
const existingRow = await this.findExistingMessageRow(sessionId, payload.message.role, payload.message.id)
|
||||||
|
|
||||||
if (existingRow) {
|
if (existingRow) {
|
||||||
const metadataToPersist = serializedMetadata ?? existingRow.metadata ?? undefined
|
const metadataToPersist = serializedMetadata ?? existingRow.metadata ?? undefined
|
||||||
const agentSessionToPersist = agentSessionId || existingRow.agent_session_id || ''
|
const agentSessionToPersist = agentSessionId || existingRow.agent_session_id || ''
|
||||||
|
|
||||||
await writer
|
await database
|
||||||
.update(sessionMessagesTable)
|
.update(sessionMessagesTable)
|
||||||
.set({
|
.set({
|
||||||
content: serializedPayload,
|
content: serializedPayload,
|
||||||
@@ -175,7 +158,7 @@ class AgentMessageRepository extends BaseService {
|
|||||||
updated_at: now
|
updated_at: now
|
||||||
}
|
}
|
||||||
|
|
||||||
const [saved] = await writer.insert(sessionMessagesTable).values(insertData).returning()
|
const [saved] = await database.insert(sessionMessagesTable).values(insertData).returning()
|
||||||
|
|
||||||
return this.deserialize(saved)
|
return this.deserialize(saved)
|
||||||
}
|
}
|
||||||
@@ -188,49 +171,38 @@ class AgentMessageRepository extends BaseService {
|
|||||||
return this.upsertMessage(params)
|
return this.upsertMessage(params)
|
||||||
}
|
}
|
||||||
|
|
||||||
async persistExchange(params: PersistExchangeParams): Promise<PersistExchangeResult> {
|
async persistExchange(params: AgentMessagePersistExchangePayload): Promise<AgentMessagePersistExchangeResult> {
|
||||||
await AgentMessageRepository.initialize()
|
|
||||||
this.ensureInitialized()
|
|
||||||
|
|
||||||
const { sessionId, agentSessionId, user, assistant } = params
|
const { sessionId, agentSessionId, user, assistant } = params
|
||||||
|
|
||||||
const result = await this.database.transaction(async (tx) => {
|
const exchangeResult: AgentMessagePersistExchangeResult = {}
|
||||||
const exchangeResult: PersistExchangeResult = {}
|
|
||||||
|
|
||||||
if (user?.payload) {
|
if (user?.payload) {
|
||||||
exchangeResult.userMessage = await this.persistUserMessage({
|
exchangeResult.userMessage = await this.persistUserMessage({
|
||||||
sessionId,
|
sessionId,
|
||||||
agentSessionId,
|
agentSessionId,
|
||||||
payload: user.payload,
|
payload: user.payload,
|
||||||
metadata: user.metadata,
|
metadata: user.metadata,
|
||||||
createdAt: user.createdAt,
|
createdAt: user.createdAt
|
||||||
tx
|
})
|
||||||
})
|
}
|
||||||
}
|
|
||||||
|
|
||||||
if (assistant?.payload) {
|
if (assistant?.payload) {
|
||||||
exchangeResult.assistantMessage = await this.persistAssistantMessage({
|
exchangeResult.assistantMessage = await this.persistAssistantMessage({
|
||||||
sessionId,
|
sessionId,
|
||||||
agentSessionId,
|
agentSessionId,
|
||||||
payload: assistant.payload,
|
payload: assistant.payload,
|
||||||
metadata: assistant.metadata,
|
metadata: assistant.metadata,
|
||||||
createdAt: assistant.createdAt,
|
createdAt: assistant.createdAt
|
||||||
tx
|
})
|
||||||
})
|
}
|
||||||
}
|
|
||||||
|
|
||||||
return exchangeResult
|
return exchangeResult
|
||||||
})
|
|
||||||
|
|
||||||
return result
|
|
||||||
}
|
}
|
||||||
|
|
||||||
async getSessionHistory(sessionId: string): Promise<AgentPersistedMessage[]> {
|
async getSessionHistory(sessionId: string): Promise<AgentPersistedMessage[]> {
|
||||||
await AgentMessageRepository.initialize()
|
|
||||||
this.ensureInitialized()
|
|
||||||
|
|
||||||
try {
|
try {
|
||||||
const rows = await this.database
|
const database = await this.getDatabase()
|
||||||
|
const rows = await database
|
||||||
.select()
|
.select()
|
||||||
.from(sessionMessagesTable)
|
.from(sessionMessagesTable)
|
||||||
.where(eq(sessionMessagesTable.session_id, sessionId))
|
.where(eq(sessionMessagesTable.session_id, sessionId))
|
||||||
|
|||||||
@@ -32,14 +32,8 @@ export class AgentService extends BaseService {
|
|||||||
return AgentService.instance
|
return AgentService.instance
|
||||||
}
|
}
|
||||||
|
|
||||||
async initialize(): Promise<void> {
|
|
||||||
await BaseService.initialize()
|
|
||||||
}
|
|
||||||
|
|
||||||
// Agent Methods
|
// Agent Methods
|
||||||
async createAgent(req: CreateAgentRequest): Promise<CreateAgentResponse> {
|
async createAgent(req: CreateAgentRequest): Promise<CreateAgentResponse> {
|
||||||
this.ensureInitialized()
|
|
||||||
|
|
||||||
const id = `agent_${Date.now()}_${Math.random().toString(36).substring(2, 11)}`
|
const id = `agent_${Date.now()}_${Math.random().toString(36).substring(2, 11)}`
|
||||||
const now = new Date().toISOString()
|
const now = new Date().toISOString()
|
||||||
|
|
||||||
@@ -75,8 +69,9 @@ export class AgentService extends BaseService {
|
|||||||
updated_at: now
|
updated_at: now
|
||||||
}
|
}
|
||||||
|
|
||||||
await this.database.insert(agentsTable).values(insertData)
|
const database = await this.getDatabase()
|
||||||
const result = await this.database.select().from(agentsTable).where(eq(agentsTable.id, id)).limit(1)
|
await database.insert(agentsTable).values(insertData)
|
||||||
|
const result = await database.select().from(agentsTable).where(eq(agentsTable.id, id)).limit(1)
|
||||||
if (!result[0]) {
|
if (!result[0]) {
|
||||||
throw new Error('Failed to create agent')
|
throw new Error('Failed to create agent')
|
||||||
}
|
}
|
||||||
@@ -86,9 +81,8 @@ export class AgentService extends BaseService {
|
|||||||
}
|
}
|
||||||
|
|
||||||
async getAgent(id: string): Promise<GetAgentResponse | null> {
|
async getAgent(id: string): Promise<GetAgentResponse | null> {
|
||||||
this.ensureInitialized()
|
const database = await this.getDatabase()
|
||||||
|
const result = await database.select().from(agentsTable).where(eq(agentsTable.id, id)).limit(1)
|
||||||
const result = await this.database.select().from(agentsTable).where(eq(agentsTable.id, id)).limit(1)
|
|
||||||
|
|
||||||
if (!result[0]) {
|
if (!result[0]) {
|
||||||
return null
|
return null
|
||||||
@@ -118,9 +112,9 @@ export class AgentService extends BaseService {
|
|||||||
}
|
}
|
||||||
|
|
||||||
async listAgents(options: ListOptions = {}): Promise<{ agents: AgentEntity[]; total: number }> {
|
async listAgents(options: ListOptions = {}): Promise<{ agents: AgentEntity[]; total: number }> {
|
||||||
this.ensureInitialized() // Build query with pagination
|
// Build query with pagination
|
||||||
|
const database = await this.getDatabase()
|
||||||
const totalResult = await this.database.select({ count: count() }).from(agentsTable)
|
const totalResult = await database.select({ count: count() }).from(agentsTable)
|
||||||
|
|
||||||
const sortBy = options.sortBy || 'created_at'
|
const sortBy = options.sortBy || 'created_at'
|
||||||
const orderBy = options.orderBy || 'desc'
|
const orderBy = options.orderBy || 'desc'
|
||||||
@@ -128,7 +122,7 @@ export class AgentService extends BaseService {
|
|||||||
const sortField = agentsTable[sortBy]
|
const sortField = agentsTable[sortBy]
|
||||||
const orderFn = orderBy === 'asc' ? asc : desc
|
const orderFn = orderBy === 'asc' ? asc : desc
|
||||||
|
|
||||||
const baseQuery = this.database.select().from(agentsTable).orderBy(orderFn(sortField))
|
const baseQuery = database.select().from(agentsTable).orderBy(orderFn(sortField))
|
||||||
|
|
||||||
const result =
|
const result =
|
||||||
options.limit !== undefined
|
options.limit !== undefined
|
||||||
@@ -151,8 +145,6 @@ export class AgentService extends BaseService {
|
|||||||
updates: UpdateAgentRequest,
|
updates: UpdateAgentRequest,
|
||||||
options: { replace?: boolean } = {}
|
options: { replace?: boolean } = {}
|
||||||
): Promise<UpdateAgentResponse | null> {
|
): Promise<UpdateAgentResponse | null> {
|
||||||
this.ensureInitialized()
|
|
||||||
|
|
||||||
// Check if agent exists
|
// Check if agent exists
|
||||||
const existing = await this.getAgent(id)
|
const existing = await this.getAgent(id)
|
||||||
if (!existing) {
|
if (!existing) {
|
||||||
@@ -195,22 +187,21 @@ export class AgentService extends BaseService {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
await this.database.update(agentsTable).set(updateData).where(eq(agentsTable.id, id))
|
const database = await this.getDatabase()
|
||||||
|
await database.update(agentsTable).set(updateData).where(eq(agentsTable.id, id))
|
||||||
return await this.getAgent(id)
|
return await this.getAgent(id)
|
||||||
}
|
}
|
||||||
|
|
||||||
async deleteAgent(id: string): Promise<boolean> {
|
async deleteAgent(id: string): Promise<boolean> {
|
||||||
this.ensureInitialized()
|
const database = await this.getDatabase()
|
||||||
|
const result = await database.delete(agentsTable).where(eq(agentsTable.id, id))
|
||||||
const result = await this.database.delete(agentsTable).where(eq(agentsTable.id, id))
|
|
||||||
|
|
||||||
return result.rowsAffected > 0
|
return result.rowsAffected > 0
|
||||||
}
|
}
|
||||||
|
|
||||||
async agentExists(id: string): Promise<boolean> {
|
async agentExists(id: string): Promise<boolean> {
|
||||||
this.ensureInitialized()
|
const database = await this.getDatabase()
|
||||||
|
const result = await database
|
||||||
const result = await this.database
|
|
||||||
.select({ id: agentsTable.id })
|
.select({ id: agentsTable.id })
|
||||||
.from(agentsTable)
|
.from(agentsTable)
|
||||||
.where(eq(agentsTable.id, id))
|
.where(eq(agentsTable.id, id))
|
||||||
|
|||||||
@@ -104,14 +104,9 @@ export class SessionMessageService extends BaseService {
|
|||||||
return SessionMessageService.instance
|
return SessionMessageService.instance
|
||||||
}
|
}
|
||||||
|
|
||||||
async initialize(): Promise<void> {
|
|
||||||
await BaseService.initialize()
|
|
||||||
}
|
|
||||||
|
|
||||||
async sessionMessageExists(id: number): Promise<boolean> {
|
async sessionMessageExists(id: number): Promise<boolean> {
|
||||||
this.ensureInitialized()
|
const database = await this.getDatabase()
|
||||||
|
const result = await database
|
||||||
const result = await this.database
|
|
||||||
.select({ id: sessionMessagesTable.id })
|
.select({ id: sessionMessagesTable.id })
|
||||||
.from(sessionMessagesTable)
|
.from(sessionMessagesTable)
|
||||||
.where(eq(sessionMessagesTable.id, id))
|
.where(eq(sessionMessagesTable.id, id))
|
||||||
@@ -124,10 +119,9 @@ export class SessionMessageService extends BaseService {
|
|||||||
sessionId: string,
|
sessionId: string,
|
||||||
options: ListOptions = {}
|
options: ListOptions = {}
|
||||||
): Promise<{ messages: AgentSessionMessageEntity[] }> {
|
): Promise<{ messages: AgentSessionMessageEntity[] }> {
|
||||||
this.ensureInitialized()
|
|
||||||
|
|
||||||
// Get messages with pagination
|
// Get messages with pagination
|
||||||
const baseQuery = this.database
|
const database = await this.getDatabase()
|
||||||
|
const baseQuery = database
|
||||||
.select()
|
.select()
|
||||||
.from(sessionMessagesTable)
|
.from(sessionMessagesTable)
|
||||||
.where(eq(sessionMessagesTable.session_id, sessionId))
|
.where(eq(sessionMessagesTable.session_id, sessionId))
|
||||||
@@ -146,9 +140,8 @@ export class SessionMessageService extends BaseService {
|
|||||||
}
|
}
|
||||||
|
|
||||||
async deleteSessionMessage(sessionId: string, messageId: number): Promise<boolean> {
|
async deleteSessionMessage(sessionId: string, messageId: number): Promise<boolean> {
|
||||||
this.ensureInitialized()
|
const database = await this.getDatabase()
|
||||||
|
const result = await database
|
||||||
const result = await this.database
|
|
||||||
.delete(sessionMessagesTable)
|
.delete(sessionMessagesTable)
|
||||||
.where(and(eq(sessionMessagesTable.id, messageId), eq(sessionMessagesTable.session_id, sessionId)))
|
.where(and(eq(sessionMessagesTable.id, messageId), eq(sessionMessagesTable.session_id, sessionId)))
|
||||||
|
|
||||||
@@ -160,8 +153,6 @@ export class SessionMessageService extends BaseService {
|
|||||||
messageData: CreateSessionMessageRequest,
|
messageData: CreateSessionMessageRequest,
|
||||||
abortController: AbortController
|
abortController: AbortController
|
||||||
): Promise<SessionStreamResult> {
|
): Promise<SessionStreamResult> {
|
||||||
this.ensureInitialized()
|
|
||||||
|
|
||||||
return await this.startSessionMessageStream(session, messageData, abortController)
|
return await this.startSessionMessageStream(session, messageData, abortController)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -270,10 +261,9 @@ export class SessionMessageService extends BaseService {
|
|||||||
}
|
}
|
||||||
|
|
||||||
private async getLastAgentSessionId(sessionId: string): Promise<string> {
|
private async getLastAgentSessionId(sessionId: string): Promise<string> {
|
||||||
this.ensureInitialized()
|
|
||||||
|
|
||||||
try {
|
try {
|
||||||
const result = await this.database
|
const database = await this.getDatabase()
|
||||||
|
const result = await database
|
||||||
.select({ agent_session_id: sessionMessagesTable.agent_session_id })
|
.select({ agent_session_id: sessionMessagesTable.agent_session_id })
|
||||||
.from(sessionMessagesTable)
|
.from(sessionMessagesTable)
|
||||||
.where(and(eq(sessionMessagesTable.session_id, sessionId), not(eq(sessionMessagesTable.agent_session_id, ''))))
|
.where(and(eq(sessionMessagesTable.session_id, sessionId), not(eq(sessionMessagesTable.agent_session_id, ''))))
|
||||||
|
|||||||
@@ -30,10 +30,6 @@ export class SessionService extends BaseService {
|
|||||||
return SessionService.instance
|
return SessionService.instance
|
||||||
}
|
}
|
||||||
|
|
||||||
async initialize(): Promise<void> {
|
|
||||||
await BaseService.initialize()
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Override BaseService.listSlashCommands to merge builtin and plugin commands
|
* Override BaseService.listSlashCommands to merge builtin and plugin commands
|
||||||
*/
|
*/
|
||||||
@@ -84,13 +80,12 @@ export class SessionService extends BaseService {
|
|||||||
agentId: string,
|
agentId: string,
|
||||||
req: Partial<CreateSessionRequest> = {}
|
req: Partial<CreateSessionRequest> = {}
|
||||||
): Promise<GetAgentSessionResponse | null> {
|
): Promise<GetAgentSessionResponse | null> {
|
||||||
this.ensureInitialized()
|
|
||||||
|
|
||||||
// Validate agent exists - we'll need to import AgentService for this check
|
// Validate agent exists - we'll need to import AgentService for this check
|
||||||
// For now, we'll skip this validation to avoid circular dependencies
|
// For now, we'll skip this validation to avoid circular dependencies
|
||||||
// The database foreign key constraint will handle this
|
// The database foreign key constraint will handle this
|
||||||
|
|
||||||
const agents = await this.database.select().from(agentsTable).where(eq(agentsTable.id, agentId)).limit(1)
|
const database = await this.getDatabase()
|
||||||
|
const agents = await database.select().from(agentsTable).where(eq(agentsTable.id, agentId)).limit(1)
|
||||||
if (!agents[0]) {
|
if (!agents[0]) {
|
||||||
throw new Error('Agent not found')
|
throw new Error('Agent not found')
|
||||||
}
|
}
|
||||||
@@ -135,9 +130,10 @@ export class SessionService extends BaseService {
|
|||||||
updated_at: now
|
updated_at: now
|
||||||
}
|
}
|
||||||
|
|
||||||
await this.database.insert(sessionsTable).values(insertData)
|
const db = await this.getDatabase()
|
||||||
|
await db.insert(sessionsTable).values(insertData)
|
||||||
|
|
||||||
const result = await this.database.select().from(sessionsTable).where(eq(sessionsTable.id, id)).limit(1)
|
const result = await db.select().from(sessionsTable).where(eq(sessionsTable.id, id)).limit(1)
|
||||||
|
|
||||||
if (!result[0]) {
|
if (!result[0]) {
|
||||||
throw new Error('Failed to create session')
|
throw new Error('Failed to create session')
|
||||||
@@ -148,9 +144,8 @@ export class SessionService extends BaseService {
|
|||||||
}
|
}
|
||||||
|
|
||||||
async getSession(agentId: string, id: string): Promise<GetAgentSessionResponse | null> {
|
async getSession(agentId: string, id: string): Promise<GetAgentSessionResponse | null> {
|
||||||
this.ensureInitialized()
|
const database = await this.getDatabase()
|
||||||
|
const result = await database
|
||||||
const result = await this.database
|
|
||||||
.select()
|
.select()
|
||||||
.from(sessionsTable)
|
.from(sessionsTable)
|
||||||
.where(and(eq(sessionsTable.id, id), eq(sessionsTable.agent_id, agentId)))
|
.where(and(eq(sessionsTable.id, id), eq(sessionsTable.agent_id, agentId)))
|
||||||
@@ -176,8 +171,6 @@ export class SessionService extends BaseService {
|
|||||||
agentId?: string,
|
agentId?: string,
|
||||||
options: ListOptions = {}
|
options: ListOptions = {}
|
||||||
): Promise<{ sessions: AgentSessionEntity[]; total: number }> {
|
): Promise<{ sessions: AgentSessionEntity[]; total: number }> {
|
||||||
this.ensureInitialized()
|
|
||||||
|
|
||||||
// Build where conditions
|
// Build where conditions
|
||||||
const whereConditions: SQL[] = []
|
const whereConditions: SQL[] = []
|
||||||
if (agentId) {
|
if (agentId) {
|
||||||
@@ -192,16 +185,13 @@ export class SessionService extends BaseService {
|
|||||||
: undefined
|
: undefined
|
||||||
|
|
||||||
// Get total count
|
// Get total count
|
||||||
const totalResult = await this.database.select({ count: count() }).from(sessionsTable).where(whereClause)
|
const database = await this.getDatabase()
|
||||||
|
const totalResult = await database.select({ count: count() }).from(sessionsTable).where(whereClause)
|
||||||
|
|
||||||
const total = totalResult[0].count
|
const total = totalResult[0].count
|
||||||
|
|
||||||
// Build list query with pagination - sort by updated_at descending (latest first)
|
// Build list query with pagination - sort by updated_at descending (latest first)
|
||||||
const baseQuery = this.database
|
const baseQuery = database.select().from(sessionsTable).where(whereClause).orderBy(desc(sessionsTable.updated_at))
|
||||||
.select()
|
|
||||||
.from(sessionsTable)
|
|
||||||
.where(whereClause)
|
|
||||||
.orderBy(desc(sessionsTable.updated_at))
|
|
||||||
|
|
||||||
const result =
|
const result =
|
||||||
options.limit !== undefined
|
options.limit !== undefined
|
||||||
@@ -220,8 +210,6 @@ export class SessionService extends BaseService {
|
|||||||
id: string,
|
id: string,
|
||||||
updates: UpdateSessionRequest
|
updates: UpdateSessionRequest
|
||||||
): Promise<UpdateSessionResponse | null> {
|
): Promise<UpdateSessionResponse | null> {
|
||||||
this.ensureInitialized()
|
|
||||||
|
|
||||||
// Check if session exists
|
// Check if session exists
|
||||||
const existing = await this.getSession(agentId, id)
|
const existing = await this.getSession(agentId, id)
|
||||||
if (!existing) {
|
if (!existing) {
|
||||||
@@ -262,15 +250,15 @@ export class SessionService extends BaseService {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
await this.database.update(sessionsTable).set(updateData).where(eq(sessionsTable.id, id))
|
const database = await this.getDatabase()
|
||||||
|
await database.update(sessionsTable).set(updateData).where(eq(sessionsTable.id, id))
|
||||||
|
|
||||||
return await this.getSession(agentId, id)
|
return await this.getSession(agentId, id)
|
||||||
}
|
}
|
||||||
|
|
||||||
async deleteSession(agentId: string, id: string): Promise<boolean> {
|
async deleteSession(agentId: string, id: string): Promise<boolean> {
|
||||||
this.ensureInitialized()
|
const database = await this.getDatabase()
|
||||||
|
const result = await database
|
||||||
const result = await this.database
|
|
||||||
.delete(sessionsTable)
|
.delete(sessionsTable)
|
||||||
.where(and(eq(sessionsTable.id, id), eq(sessionsTable.agent_id, agentId)))
|
.where(and(eq(sessionsTable.id, id), eq(sessionsTable.agent_id, agentId)))
|
||||||
|
|
||||||
@@ -278,9 +266,8 @@ export class SessionService extends BaseService {
|
|||||||
}
|
}
|
||||||
|
|
||||||
async sessionExists(agentId: string, id: string): Promise<boolean> {
|
async sessionExists(agentId: string, id: string): Promise<boolean> {
|
||||||
this.ensureInitialized()
|
const database = await this.getDatabase()
|
||||||
|
const result = await database
|
||||||
const result = await this.database
|
|
||||||
.select({ id: sessionsTable.id })
|
.select({ id: sessionsTable.id })
|
||||||
.from(sessionsTable)
|
.from(sessionsTable)
|
||||||
.where(and(eq(sessionsTable.id, id), eq(sessionsTable.agent_id, agentId)))
|
.where(and(eq(sessionsTable.id, id), eq(sessionsTable.agent_id, agentId)))
|
||||||
|
|||||||
@@ -21,11 +21,16 @@ describe('stripLocalCommandTags', () => {
|
|||||||
'<local-command-stdout>line1</local-command-stdout>\nkeep\n<local-command-stderr>Error</local-command-stderr>'
|
'<local-command-stdout>line1</local-command-stdout>\nkeep\n<local-command-stderr>Error</local-command-stderr>'
|
||||||
expect(stripLocalCommandTags(input)).toBe('line1\nkeep\nError')
|
expect(stripLocalCommandTags(input)).toBe('line1\nkeep\nError')
|
||||||
})
|
})
|
||||||
|
|
||||||
|
it('if no tags present, returns original string', () => {
|
||||||
|
const input = 'just some normal text'
|
||||||
|
expect(stripLocalCommandTags(input)).toBe(input)
|
||||||
|
})
|
||||||
})
|
})
|
||||||
|
|
||||||
describe('Claude → AiSDK transform', () => {
|
describe('Claude → AiSDK transform', () => {
|
||||||
it('handles tool call streaming lifecycle', () => {
|
it('handles tool call streaming lifecycle', () => {
|
||||||
const state = new ClaudeStreamState()
|
const state = new ClaudeStreamState({ agentSessionId: baseStreamMetadata.session_id })
|
||||||
const parts: ReturnType<typeof transformSDKMessageToStreamParts>[number][] = []
|
const parts: ReturnType<typeof transformSDKMessageToStreamParts>[number][] = []
|
||||||
|
|
||||||
const messages: SDKMessage[] = [
|
const messages: SDKMessage[] = [
|
||||||
@@ -182,14 +187,119 @@ describe('Claude → AiSDK transform', () => {
|
|||||||
(typeof parts)[number],
|
(typeof parts)[number],
|
||||||
{ type: 'tool-result' }
|
{ type: 'tool-result' }
|
||||||
>
|
>
|
||||||
expect(toolResult.toolCallId).toBe('tool-1')
|
expect(toolResult.toolCallId).toBe('session-123:tool-1')
|
||||||
expect(toolResult.toolName).toBe('Bash')
|
expect(toolResult.toolName).toBe('Bash')
|
||||||
expect(toolResult.input).toEqual({ command: 'ls' })
|
expect(toolResult.input).toEqual({ command: 'ls' })
|
||||||
expect(toolResult.output).toBe('ok')
|
expect(toolResult.output).toBe('ok')
|
||||||
})
|
})
|
||||||
|
|
||||||
|
it('handles tool calls without streaming events (no content_block_start/stop)', () => {
|
||||||
|
const state = new ClaudeStreamState({ agentSessionId: '12344' })
|
||||||
|
const parts: ReturnType<typeof transformSDKMessageToStreamParts>[number][] = []
|
||||||
|
|
||||||
|
const messages: SDKMessage[] = [
|
||||||
|
{
|
||||||
|
...baseStreamMetadata,
|
||||||
|
type: 'assistant',
|
||||||
|
uuid: uuid(20),
|
||||||
|
message: {
|
||||||
|
id: 'msg-tool-no-stream',
|
||||||
|
type: 'message',
|
||||||
|
role: 'assistant',
|
||||||
|
model: 'claude-test',
|
||||||
|
content: [
|
||||||
|
{
|
||||||
|
type: 'tool_use',
|
||||||
|
id: 'tool-read',
|
||||||
|
name: 'Read',
|
||||||
|
input: { file_path: '/test.txt' }
|
||||||
|
},
|
||||||
|
{
|
||||||
|
type: 'tool_use',
|
||||||
|
id: 'tool-bash',
|
||||||
|
name: 'Bash',
|
||||||
|
input: { command: 'ls -la' }
|
||||||
|
}
|
||||||
|
],
|
||||||
|
stop_reason: 'tool_use',
|
||||||
|
stop_sequence: null,
|
||||||
|
usage: {
|
||||||
|
input_tokens: 10,
|
||||||
|
output_tokens: 20
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} as unknown as SDKMessage,
|
||||||
|
{
|
||||||
|
...baseStreamMetadata,
|
||||||
|
type: 'user',
|
||||||
|
uuid: uuid(21),
|
||||||
|
message: {
|
||||||
|
role: 'user',
|
||||||
|
content: [
|
||||||
|
{
|
||||||
|
type: 'tool_result',
|
||||||
|
tool_use_id: 'tool-read',
|
||||||
|
content: 'file contents',
|
||||||
|
is_error: false
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
} as SDKMessage,
|
||||||
|
{
|
||||||
|
...baseStreamMetadata,
|
||||||
|
type: 'user',
|
||||||
|
uuid: uuid(22),
|
||||||
|
message: {
|
||||||
|
role: 'user',
|
||||||
|
content: [
|
||||||
|
{
|
||||||
|
type: 'tool_result',
|
||||||
|
tool_use_id: 'tool-bash',
|
||||||
|
content: 'total 42\n...',
|
||||||
|
is_error: false
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
} as SDKMessage
|
||||||
|
]
|
||||||
|
|
||||||
|
for (const message of messages) {
|
||||||
|
const transformed = transformSDKMessageToStreamParts(message, state)
|
||||||
|
parts.push(...transformed)
|
||||||
|
}
|
||||||
|
|
||||||
|
const types = parts.map((part) => part.type)
|
||||||
|
expect(types).toEqual(['tool-call', 'tool-call', 'tool-result', 'tool-result'])
|
||||||
|
|
||||||
|
const toolCalls = parts.filter((part) => part.type === 'tool-call') as Extract<
|
||||||
|
(typeof parts)[number],
|
||||||
|
{ type: 'tool-call' }
|
||||||
|
>[]
|
||||||
|
expect(toolCalls).toHaveLength(2)
|
||||||
|
expect(toolCalls[0].toolName).toBe('Read')
|
||||||
|
expect(toolCalls[0].toolCallId).toBe('12344:tool-read')
|
||||||
|
expect(toolCalls[1].toolName).toBe('Bash')
|
||||||
|
expect(toolCalls[1].toolCallId).toBe('12344:tool-bash')
|
||||||
|
|
||||||
|
const toolResults = parts.filter((part) => part.type === 'tool-result') as Extract<
|
||||||
|
(typeof parts)[number],
|
||||||
|
{ type: 'tool-result' }
|
||||||
|
>[]
|
||||||
|
expect(toolResults).toHaveLength(2)
|
||||||
|
// This is the key assertion - toolName should NOT be 'unknown'
|
||||||
|
expect(toolResults[0].toolName).toBe('Read')
|
||||||
|
expect(toolResults[0].toolCallId).toBe('12344:tool-read')
|
||||||
|
expect(toolResults[0].input).toEqual({ file_path: '/test.txt' })
|
||||||
|
expect(toolResults[0].output).toBe('file contents')
|
||||||
|
|
||||||
|
expect(toolResults[1].toolName).toBe('Bash')
|
||||||
|
expect(toolResults[1].toolCallId).toBe('12344:tool-bash')
|
||||||
|
expect(toolResults[1].input).toEqual({ command: 'ls -la' })
|
||||||
|
expect(toolResults[1].output).toBe('total 42\n...')
|
||||||
|
})
|
||||||
|
|
||||||
it('handles streaming text completion', () => {
|
it('handles streaming text completion', () => {
|
||||||
const state = new ClaudeStreamState()
|
const state = new ClaudeStreamState({ agentSessionId: baseStreamMetadata.session_id })
|
||||||
const parts: ReturnType<typeof transformSDKMessageToStreamParts>[number][] = []
|
const parts: ReturnType<typeof transformSDKMessageToStreamParts>[number][] = []
|
||||||
|
|
||||||
const messages: SDKMessage[] = [
|
const messages: SDKMessage[] = [
|
||||||
@@ -300,4 +410,87 @@ describe('Claude → AiSDK transform', () => {
|
|||||||
expect(finishStep.finishReason).toBe('stop')
|
expect(finishStep.finishReason).toBe('stop')
|
||||||
expect(finishStep.usage).toEqual({ inputTokens: 2, outputTokens: 4, totalTokens: 6 })
|
expect(finishStep.usage).toEqual({ inputTokens: 2, outputTokens: 4, totalTokens: 6 })
|
||||||
})
|
})
|
||||||
|
|
||||||
|
it('emits fallback text when Claude sends a snapshot instead of deltas', () => {
|
||||||
|
const state = new ClaudeStreamState({ agentSessionId: '12344' })
|
||||||
|
const parts: ReturnType<typeof transformSDKMessageToStreamParts>[number][] = []
|
||||||
|
|
||||||
|
const messages: SDKMessage[] = [
|
||||||
|
{
|
||||||
|
...baseStreamMetadata,
|
||||||
|
type: 'stream_event',
|
||||||
|
uuid: uuid(30),
|
||||||
|
event: {
|
||||||
|
type: 'message_start',
|
||||||
|
message: {
|
||||||
|
id: 'msg-fallback',
|
||||||
|
type: 'message',
|
||||||
|
role: 'assistant',
|
||||||
|
model: 'claude-test',
|
||||||
|
content: [],
|
||||||
|
stop_reason: null,
|
||||||
|
stop_sequence: null,
|
||||||
|
usage: {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} as unknown as SDKMessage,
|
||||||
|
{
|
||||||
|
...baseStreamMetadata,
|
||||||
|
type: 'stream_event',
|
||||||
|
uuid: uuid(31),
|
||||||
|
event: {
|
||||||
|
type: 'content_block_start',
|
||||||
|
index: 0,
|
||||||
|
content_block: {
|
||||||
|
type: 'text',
|
||||||
|
text: ''
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} as unknown as SDKMessage,
|
||||||
|
{
|
||||||
|
...baseStreamMetadata,
|
||||||
|
type: 'assistant',
|
||||||
|
uuid: uuid(32),
|
||||||
|
message: {
|
||||||
|
id: 'msg-fallback-content',
|
||||||
|
type: 'message',
|
||||||
|
role: 'assistant',
|
||||||
|
model: 'claude-test',
|
||||||
|
content: [
|
||||||
|
{
|
||||||
|
type: 'text',
|
||||||
|
text: 'Final answer without streaming deltas.'
|
||||||
|
}
|
||||||
|
],
|
||||||
|
stop_reason: 'end_turn',
|
||||||
|
stop_sequence: null,
|
||||||
|
usage: {
|
||||||
|
input_tokens: 3,
|
||||||
|
output_tokens: 7
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} as unknown as SDKMessage
|
||||||
|
]
|
||||||
|
|
||||||
|
for (const message of messages) {
|
||||||
|
const transformed = transformSDKMessageToStreamParts(message, state)
|
||||||
|
parts.push(...transformed)
|
||||||
|
}
|
||||||
|
|
||||||
|
const types = parts.map((part) => part.type)
|
||||||
|
expect(types).toEqual(['start-step', 'text-start', 'text-delta', 'text-end', 'finish-step'])
|
||||||
|
|
||||||
|
const delta = parts.find((part) => part.type === 'text-delta') as Extract<
|
||||||
|
(typeof parts)[number],
|
||||||
|
{ type: 'text-delta' }
|
||||||
|
>
|
||||||
|
expect(delta.text).toBe('Final answer without streaming deltas.')
|
||||||
|
|
||||||
|
const finish = parts.find((part) => part.type === 'finish-step') as Extract<
|
||||||
|
(typeof parts)[number],
|
||||||
|
{ type: 'finish-step' }
|
||||||
|
>
|
||||||
|
expect(finish.usage).toEqual({ inputTokens: 3, outputTokens: 7, totalTokens: 10 })
|
||||||
|
expect(finish.finishReason).toBe('stop')
|
||||||
|
})
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -10,8 +10,21 @@
|
|||||||
* Every Claude turn gets its own instance. `resetStep` should be invoked once the finish event has
|
* Every Claude turn gets its own instance. `resetStep` should be invoked once the finish event has
|
||||||
* been emitted to avoid leaking state into the next turn.
|
* been emitted to avoid leaking state into the next turn.
|
||||||
*/
|
*/
|
||||||
|
import { loggerService } from '@logger'
|
||||||
import type { FinishReason, LanguageModelUsage, ProviderMetadata } from 'ai'
|
import type { FinishReason, LanguageModelUsage, ProviderMetadata } from 'ai'
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Builds a namespaced tool call ID by combining session ID with raw tool call ID.
|
||||||
|
* This ensures tool calls from different sessions don't conflict even if they have
|
||||||
|
* the same raw ID from the SDK.
|
||||||
|
*
|
||||||
|
* @param sessionId - The agent session ID
|
||||||
|
* @param rawToolCallId - The raw tool call ID from SDK (e.g., "WebFetch_0")
|
||||||
|
*/
|
||||||
|
export function buildNamespacedToolCallId(sessionId: string, rawToolCallId: string): string {
|
||||||
|
return `${sessionId}:${rawToolCallId}`
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Shared fields for every block that Claude can stream (text, reasoning, tool).
|
* Shared fields for every block that Claude can stream (text, reasoning, tool).
|
||||||
*/
|
*/
|
||||||
@@ -34,6 +47,7 @@ type ReasoningBlockState = BaseBlockState & {
|
|||||||
type ToolBlockState = BaseBlockState & {
|
type ToolBlockState = BaseBlockState & {
|
||||||
kind: 'tool'
|
kind: 'tool'
|
||||||
toolCallId: string
|
toolCallId: string
|
||||||
|
rawToolCallId: string
|
||||||
toolName: string
|
toolName: string
|
||||||
inputBuffer: string
|
inputBuffer: string
|
||||||
providerMetadata?: ProviderMetadata
|
providerMetadata?: ProviderMetadata
|
||||||
@@ -48,12 +62,17 @@ type PendingUsageState = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type PendingToolCall = {
|
type PendingToolCall = {
|
||||||
|
rawToolCallId: string
|
||||||
toolCallId: string
|
toolCallId: string
|
||||||
toolName: string
|
toolName: string
|
||||||
input: unknown
|
input: unknown
|
||||||
providerMetadata?: ProviderMetadata
|
providerMetadata?: ProviderMetadata
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type ClaudeStreamStateOptions = {
|
||||||
|
agentSessionId: string
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Tracks the lifecycle of Claude streaming blocks (text, thinking, tool calls)
|
* Tracks the lifecycle of Claude streaming blocks (text, thinking, tool calls)
|
||||||
* across individual websocket events. The transformer relies on this class to
|
* across individual websocket events. The transformer relies on this class to
|
||||||
@@ -61,12 +80,20 @@ type PendingToolCall = {
|
|||||||
* usage/finish metadata once Anthropic closes a message.
|
* usage/finish metadata once Anthropic closes a message.
|
||||||
*/
|
*/
|
||||||
export class ClaudeStreamState {
|
export class ClaudeStreamState {
|
||||||
|
private logger
|
||||||
|
private readonly agentSessionId: string
|
||||||
private blocksByIndex = new Map<number, BlockState>()
|
private blocksByIndex = new Map<number, BlockState>()
|
||||||
private toolIndexById = new Map<string, number>()
|
private toolIndexByNamespacedId = new Map<string, number>()
|
||||||
private pendingUsage: PendingUsageState = {}
|
private pendingUsage: PendingUsageState = {}
|
||||||
private pendingToolCalls = new Map<string, PendingToolCall>()
|
private pendingToolCalls = new Map<string, PendingToolCall>()
|
||||||
private stepActive = false
|
private stepActive = false
|
||||||
|
|
||||||
|
constructor(options: ClaudeStreamStateOptions) {
|
||||||
|
this.logger = loggerService.withContext('ClaudeStreamState')
|
||||||
|
this.agentSessionId = options.agentSessionId
|
||||||
|
this.logger.silly('ClaudeStreamState', options)
|
||||||
|
}
|
||||||
|
|
||||||
/** Marks the beginning of a new AiSDK step. */
|
/** Marks the beginning of a new AiSDK step. */
|
||||||
beginStep(): void {
|
beginStep(): void {
|
||||||
this.stepActive = true
|
this.stepActive = true
|
||||||
@@ -104,19 +131,21 @@ export class ClaudeStreamState {
|
|||||||
/** Caches tool metadata so subsequent input deltas and results can find it. */
|
/** Caches tool metadata so subsequent input deltas and results can find it. */
|
||||||
openToolBlock(
|
openToolBlock(
|
||||||
index: number,
|
index: number,
|
||||||
params: { toolCallId: string; toolName: string; providerMetadata?: ProviderMetadata }
|
params: { rawToolCallId: string; toolName: string; providerMetadata?: ProviderMetadata }
|
||||||
): ToolBlockState {
|
): ToolBlockState {
|
||||||
|
const toolCallId = buildNamespacedToolCallId(this.agentSessionId, params.rawToolCallId)
|
||||||
const block: ToolBlockState = {
|
const block: ToolBlockState = {
|
||||||
kind: 'tool',
|
kind: 'tool',
|
||||||
id: params.toolCallId,
|
id: toolCallId,
|
||||||
index,
|
index,
|
||||||
toolCallId: params.toolCallId,
|
toolCallId,
|
||||||
|
rawToolCallId: params.rawToolCallId,
|
||||||
toolName: params.toolName,
|
toolName: params.toolName,
|
||||||
inputBuffer: '',
|
inputBuffer: '',
|
||||||
providerMetadata: params.providerMetadata
|
providerMetadata: params.providerMetadata
|
||||||
}
|
}
|
||||||
this.blocksByIndex.set(index, block)
|
this.blocksByIndex.set(index, block)
|
||||||
this.toolIndexById.set(params.toolCallId, index)
|
this.toolIndexByNamespacedId.set(toolCallId, index)
|
||||||
return block
|
return block
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -124,14 +153,32 @@ export class ClaudeStreamState {
|
|||||||
return this.blocksByIndex.get(index)
|
return this.blocksByIndex.get(index)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
getFirstOpenTextBlock(): TextBlockState | undefined {
|
||||||
|
const candidates: TextBlockState[] = []
|
||||||
|
for (const block of this.blocksByIndex.values()) {
|
||||||
|
if (block.kind === 'text') {
|
||||||
|
candidates.push(block)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if (candidates.length === 0) {
|
||||||
|
return undefined
|
||||||
|
}
|
||||||
|
candidates.sort((a, b) => a.index - b.index)
|
||||||
|
return candidates[0]
|
||||||
|
}
|
||||||
|
|
||||||
getToolBlockById(toolCallId: string): ToolBlockState | undefined {
|
getToolBlockById(toolCallId: string): ToolBlockState | undefined {
|
||||||
const index = this.toolIndexById.get(toolCallId)
|
const index = this.toolIndexByNamespacedId.get(toolCallId)
|
||||||
if (index === undefined) return undefined
|
if (index === undefined) return undefined
|
||||||
const block = this.blocksByIndex.get(index)
|
const block = this.blocksByIndex.get(index)
|
||||||
if (!block || block.kind !== 'tool') return undefined
|
if (!block || block.kind !== 'tool') return undefined
|
||||||
return block
|
return block
|
||||||
}
|
}
|
||||||
|
|
||||||
|
getToolBlockByRawId(rawToolCallId: string): ToolBlockState | undefined {
|
||||||
|
return this.getToolBlockById(buildNamespacedToolCallId(this.agentSessionId, rawToolCallId))
|
||||||
|
}
|
||||||
|
|
||||||
/** Appends streamed text to a text block, returning the updated state when present. */
|
/** Appends streamed text to a text block, returning the updated state when present. */
|
||||||
appendTextDelta(index: number, text: string): TextBlockState | undefined {
|
appendTextDelta(index: number, text: string): TextBlockState | undefined {
|
||||||
const block = this.blocksByIndex.get(index)
|
const block = this.blocksByIndex.get(index)
|
||||||
@@ -158,10 +205,12 @@ export class ClaudeStreamState {
|
|||||||
|
|
||||||
/** Records a tool call to be consumed once its result arrives from the user. */
|
/** Records a tool call to be consumed once its result arrives from the user. */
|
||||||
registerToolCall(
|
registerToolCall(
|
||||||
toolCallId: string,
|
rawToolCallId: string,
|
||||||
payload: { toolName: string; input: unknown; providerMetadata?: ProviderMetadata }
|
payload: { toolName: string; input: unknown; providerMetadata?: ProviderMetadata }
|
||||||
): void {
|
): void {
|
||||||
this.pendingToolCalls.set(toolCallId, {
|
const toolCallId = buildNamespacedToolCallId(this.agentSessionId, rawToolCallId)
|
||||||
|
this.pendingToolCalls.set(rawToolCallId, {
|
||||||
|
rawToolCallId,
|
||||||
toolCallId,
|
toolCallId,
|
||||||
toolName: payload.toolName,
|
toolName: payload.toolName,
|
||||||
input: payload.input,
|
input: payload.input,
|
||||||
@@ -170,10 +219,10 @@ export class ClaudeStreamState {
|
|||||||
}
|
}
|
||||||
|
|
||||||
/** Retrieves and clears the buffered tool call metadata for the given id. */
|
/** Retrieves and clears the buffered tool call metadata for the given id. */
|
||||||
consumePendingToolCall(toolCallId: string): PendingToolCall | undefined {
|
consumePendingToolCall(rawToolCallId: string): PendingToolCall | undefined {
|
||||||
const entry = this.pendingToolCalls.get(toolCallId)
|
const entry = this.pendingToolCalls.get(rawToolCallId)
|
||||||
if (entry) {
|
if (entry) {
|
||||||
this.pendingToolCalls.delete(toolCallId)
|
this.pendingToolCalls.delete(rawToolCallId)
|
||||||
}
|
}
|
||||||
return entry
|
return entry
|
||||||
}
|
}
|
||||||
@@ -182,13 +231,13 @@ export class ClaudeStreamState {
|
|||||||
* Persists the final input payload for a tool block once the provider signals
|
* Persists the final input payload for a tool block once the provider signals
|
||||||
* completion so that downstream tool results can reference the original call.
|
* completion so that downstream tool results can reference the original call.
|
||||||
*/
|
*/
|
||||||
completeToolBlock(toolCallId: string, input: unknown, providerMetadata?: ProviderMetadata): void {
|
completeToolBlock(toolCallId: string, toolName: string, input: unknown, providerMetadata?: ProviderMetadata): void {
|
||||||
|
const block = this.getToolBlockByRawId(toolCallId)
|
||||||
this.registerToolCall(toolCallId, {
|
this.registerToolCall(toolCallId, {
|
||||||
toolName: this.getToolBlockById(toolCallId)?.toolName ?? 'unknown',
|
toolName,
|
||||||
input,
|
input,
|
||||||
providerMetadata
|
providerMetadata
|
||||||
})
|
})
|
||||||
const block = this.getToolBlockById(toolCallId)
|
|
||||||
if (block) {
|
if (block) {
|
||||||
block.resolvedInput = input
|
block.resolvedInput = input
|
||||||
}
|
}
|
||||||
@@ -200,7 +249,7 @@ export class ClaudeStreamState {
|
|||||||
if (!block) return undefined
|
if (!block) return undefined
|
||||||
this.blocksByIndex.delete(index)
|
this.blocksByIndex.delete(index)
|
||||||
if (block.kind === 'tool') {
|
if (block.kind === 'tool') {
|
||||||
this.toolIndexById.delete(block.toolCallId)
|
this.toolIndexByNamespacedId.delete(block.toolCallId)
|
||||||
}
|
}
|
||||||
return block
|
return block
|
||||||
}
|
}
|
||||||
@@ -227,7 +276,7 @@ export class ClaudeStreamState {
|
|||||||
/** Drops cached block metadata for the currently active message. */
|
/** Drops cached block metadata for the currently active message. */
|
||||||
resetBlocks(): void {
|
resetBlocks(): void {
|
||||||
this.blocksByIndex.clear()
|
this.blocksByIndex.clear()
|
||||||
this.toolIndexById.clear()
|
this.toolIndexByNamespacedId.clear()
|
||||||
}
|
}
|
||||||
|
|
||||||
/** Resets the entire step lifecycle after emitting a terminal frame. */
|
/** Resets the entire step lifecycle after emitting a terminal frame. */
|
||||||
@@ -236,6 +285,10 @@ export class ClaudeStreamState {
|
|||||||
this.resetPendingUsage()
|
this.resetPendingUsage()
|
||||||
this.stepActive = false
|
this.stepActive = false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
getNamespacedToolCallId(rawToolCallId: string): string {
|
||||||
|
return buildNamespacedToolCallId(this.agentSessionId, rawToolCallId)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
export type { PendingToolCall }
|
export type { PendingToolCall }
|
||||||
|
|||||||
@@ -2,7 +2,14 @@
|
|||||||
import { EventEmitter } from 'node:events'
|
import { EventEmitter } from 'node:events'
|
||||||
import { createRequire } from 'node:module'
|
import { createRequire } from 'node:module'
|
||||||
|
|
||||||
import type { CanUseTool, McpHttpServerConfig, Options, SDKMessage } from '@anthropic-ai/claude-agent-sdk'
|
import type {
|
||||||
|
CanUseTool,
|
||||||
|
HookCallback,
|
||||||
|
McpHttpServerConfig,
|
||||||
|
Options,
|
||||||
|
PreToolUseHookInput,
|
||||||
|
SDKMessage
|
||||||
|
} from '@anthropic-ai/claude-agent-sdk'
|
||||||
import { query } from '@anthropic-ai/claude-agent-sdk'
|
import { query } from '@anthropic-ai/claude-agent-sdk'
|
||||||
import { loggerService } from '@logger'
|
import { loggerService } from '@logger'
|
||||||
import { config as apiConfigService } from '@main/apiServer/config'
|
import { config as apiConfigService } from '@main/apiServer/config'
|
||||||
@@ -13,6 +20,7 @@ import { app } from 'electron'
|
|||||||
import type { GetAgentSessionResponse } from '../..'
|
import type { GetAgentSessionResponse } from '../..'
|
||||||
import type { AgentServiceInterface, AgentStream, AgentStreamEvent } from '../../interfaces/AgentStreamInterface'
|
import type { AgentServiceInterface, AgentStream, AgentStreamEvent } from '../../interfaces/AgentStreamInterface'
|
||||||
import { sessionService } from '../SessionService'
|
import { sessionService } from '../SessionService'
|
||||||
|
import { buildNamespacedToolCallId } from './claude-stream-state'
|
||||||
import { promptForToolApproval } from './tool-permissions'
|
import { promptForToolApproval } from './tool-permissions'
|
||||||
import { ClaudeStreamState, transformSDKMessageToStreamParts } from './transform'
|
import { ClaudeStreamState, transformSDKMessageToStreamParts } from './transform'
|
||||||
|
|
||||||
@@ -150,7 +158,67 @@ class ClaudeCodeService implements AgentServiceInterface {
|
|||||||
return { behavior: 'allow', updatedInput: input }
|
return { behavior: 'allow', updatedInput: input }
|
||||||
}
|
}
|
||||||
|
|
||||||
return promptForToolApproval(toolName, input, options)
|
return promptForToolApproval(toolName, input, {
|
||||||
|
...options,
|
||||||
|
toolCallId: buildNamespacedToolCallId(session.id, options.toolUseID)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
const preToolUseHook: HookCallback = async (input, toolUseID, options) => {
|
||||||
|
// Type guard to ensure we're handling PreToolUse event
|
||||||
|
if (input.hook_event_name !== 'PreToolUse') {
|
||||||
|
return {}
|
||||||
|
}
|
||||||
|
|
||||||
|
const hookInput = input as PreToolUseHookInput
|
||||||
|
const toolName = hookInput.tool_name
|
||||||
|
|
||||||
|
logger.debug('PreToolUse hook triggered', {
|
||||||
|
session_id: hookInput.session_id,
|
||||||
|
tool_name: hookInput.tool_name,
|
||||||
|
tool_use_id: toolUseID,
|
||||||
|
tool_input: hookInput.tool_input,
|
||||||
|
cwd: hookInput.cwd,
|
||||||
|
permission_mode: hookInput.permission_mode,
|
||||||
|
autoAllowTools: autoAllowTools
|
||||||
|
})
|
||||||
|
|
||||||
|
if (options?.signal?.aborted) {
|
||||||
|
logger.debug('PreToolUse hook signal already aborted; skipping tool use', {
|
||||||
|
tool_name: hookInput.tool_name
|
||||||
|
})
|
||||||
|
return {}
|
||||||
|
}
|
||||||
|
|
||||||
|
// handle auto approved tools since it never triggers canUseTool
|
||||||
|
const normalizedToolName = normalizeToolName(toolName)
|
||||||
|
if (toolUseID) {
|
||||||
|
const bypassAll = input.permission_mode === 'bypassPermissions'
|
||||||
|
const autoAllowed = autoAllowTools.has(toolName) || autoAllowTools.has(normalizedToolName)
|
||||||
|
if (bypassAll || autoAllowed) {
|
||||||
|
const namespacedToolCallId = buildNamespacedToolCallId(session.id, toolUseID)
|
||||||
|
logger.debug('handling auto approved tools', {
|
||||||
|
toolName,
|
||||||
|
normalizedToolName,
|
||||||
|
namespacedToolCallId,
|
||||||
|
permission_mode: input.permission_mode,
|
||||||
|
autoAllowTools
|
||||||
|
})
|
||||||
|
const isRecord = (v: unknown): v is Record<string, unknown> => {
|
||||||
|
return !!v && typeof v === 'object' && !Array.isArray(v)
|
||||||
|
}
|
||||||
|
const toolInput = isRecord(input.tool_input) ? input.tool_input : {}
|
||||||
|
|
||||||
|
await promptForToolApproval(toolName, toolInput, {
|
||||||
|
...options,
|
||||||
|
toolCallId: namespacedToolCallId,
|
||||||
|
autoApprove: true
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Return to proceed without modification
|
||||||
|
return {}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Build SDK options from parameters
|
// Build SDK options from parameters
|
||||||
@@ -176,7 +244,14 @@ class ClaudeCodeService implements AgentServiceInterface {
|
|||||||
permissionMode: session.configuration?.permission_mode,
|
permissionMode: session.configuration?.permission_mode,
|
||||||
maxTurns: session.configuration?.max_turns,
|
maxTurns: session.configuration?.max_turns,
|
||||||
allowedTools: session.allowed_tools,
|
allowedTools: session.allowed_tools,
|
||||||
canUseTool
|
canUseTool,
|
||||||
|
hooks: {
|
||||||
|
PreToolUse: [
|
||||||
|
{
|
||||||
|
hooks: [preToolUseHook]
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (session.accessible_paths.length > 1) {
|
if (session.accessible_paths.length > 1) {
|
||||||
@@ -346,7 +421,7 @@ class ClaudeCodeService implements AgentServiceInterface {
|
|||||||
const jsonOutput: SDKMessage[] = []
|
const jsonOutput: SDKMessage[] = []
|
||||||
let hasCompleted = false
|
let hasCompleted = false
|
||||||
const startTime = Date.now()
|
const startTime = Date.now()
|
||||||
const streamState = new ClaudeStreamState()
|
const streamState = new ClaudeStreamState({ agentSessionId: sessionId })
|
||||||
|
|
||||||
try {
|
try {
|
||||||
for await (const message of query({ prompt: promptStream, options })) {
|
for await (const message of query({ prompt: promptStream, options })) {
|
||||||
@@ -410,23 +485,6 @@ class ClaudeCodeService implements AgentServiceInterface {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (message.type === 'assistant' || message.type === 'user') {
|
|
||||||
logger.silly('claude response', {
|
|
||||||
message,
|
|
||||||
content: JSON.stringify(message.message.content)
|
|
||||||
})
|
|
||||||
} else if (message.type === 'stream_event') {
|
|
||||||
// logger.silly('Claude stream event', {
|
|
||||||
// message,
|
|
||||||
// event: JSON.stringify(message.event)
|
|
||||||
// })
|
|
||||||
} else {
|
|
||||||
logger.silly('Claude response', {
|
|
||||||
message,
|
|
||||||
event: JSON.stringify(message)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
const chunks = transformSDKMessageToStreamParts(message, streamState)
|
const chunks = transformSDKMessageToStreamParts(message, streamState)
|
||||||
for (const chunk of chunks) {
|
for (const chunk of chunks) {
|
||||||
stream.emit('data', {
|
stream.emit('data', {
|
||||||
|
|||||||
@@ -31,12 +31,14 @@ type PendingPermissionRequest = {
|
|||||||
abortListener?: () => void
|
abortListener?: () => void
|
||||||
originalInput: Record<string, unknown>
|
originalInput: Record<string, unknown>
|
||||||
toolName: string
|
toolName: string
|
||||||
|
toolCallId?: string
|
||||||
}
|
}
|
||||||
|
|
||||||
type RendererPermissionRequestPayload = {
|
type RendererPermissionRequestPayload = {
|
||||||
requestId: string
|
requestId: string
|
||||||
toolName: string
|
toolName: string
|
||||||
toolId: string
|
toolId: string
|
||||||
|
toolCallId: string
|
||||||
description?: string
|
description?: string
|
||||||
requiresPermissions: boolean
|
requiresPermissions: boolean
|
||||||
input: Record<string, unknown>
|
input: Record<string, unknown>
|
||||||
@@ -44,6 +46,7 @@ type RendererPermissionRequestPayload = {
|
|||||||
createdAt: number
|
createdAt: number
|
||||||
expiresAt: number
|
expiresAt: number
|
||||||
suggestions: PermissionUpdate[]
|
suggestions: PermissionUpdate[]
|
||||||
|
autoApprove?: boolean
|
||||||
}
|
}
|
||||||
|
|
||||||
type RendererPermissionResultPayload = {
|
type RendererPermissionResultPayload = {
|
||||||
@@ -51,6 +54,7 @@ type RendererPermissionResultPayload = {
|
|||||||
behavior: ToolPermissionBehavior
|
behavior: ToolPermissionBehavior
|
||||||
message?: string
|
message?: string
|
||||||
reason: 'response' | 'timeout' | 'aborted' | 'no-window'
|
reason: 'response' | 'timeout' | 'aborted' | 'no-window'
|
||||||
|
toolCallId?: string
|
||||||
}
|
}
|
||||||
|
|
||||||
const pendingRequests = new Map<string, PendingPermissionRequest>()
|
const pendingRequests = new Map<string, PendingPermissionRequest>()
|
||||||
@@ -144,7 +148,8 @@ const finalizeRequest = (
|
|||||||
requestId,
|
requestId,
|
||||||
behavior: update.behavior,
|
behavior: update.behavior,
|
||||||
message: update.behavior === 'deny' ? update.message : undefined,
|
message: update.behavior === 'deny' ? update.message : undefined,
|
||||||
reason
|
reason,
|
||||||
|
toolCallId: pending.toolCallId
|
||||||
}
|
}
|
||||||
|
|
||||||
const dispatched = broadcastToRenderer(IpcChannel.AgentToolPermission_Result, resultPayload)
|
const dispatched = broadcastToRenderer(IpcChannel.AgentToolPermission_Result, resultPayload)
|
||||||
@@ -206,10 +211,20 @@ const ensureIpcHandlersRegistered = () => {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type PromptForToolApprovalOptions = {
|
||||||
|
signal: AbortSignal
|
||||||
|
suggestions?: PermissionUpdate[]
|
||||||
|
autoApprove?: boolean
|
||||||
|
|
||||||
|
// NOTICE: This ID is namespaced with session ID, not the raw SDK tool call ID.
|
||||||
|
// Format: `${sessionId}:${rawToolCallId}`, e.g., `session_123:WebFetch_0`
|
||||||
|
toolCallId: string
|
||||||
|
}
|
||||||
|
|
||||||
export async function promptForToolApproval(
|
export async function promptForToolApproval(
|
||||||
toolName: string,
|
toolName: string,
|
||||||
input: Record<string, unknown>,
|
input: Record<string, unknown>,
|
||||||
options?: { signal: AbortSignal; suggestions?: PermissionUpdate[] }
|
options: PromptForToolApprovalOptions
|
||||||
): Promise<PermissionResult> {
|
): Promise<PermissionResult> {
|
||||||
if (shouldAutoApproveTools) {
|
if (shouldAutoApproveTools) {
|
||||||
logger.debug('promptForToolApproval auto-approving tool for test', {
|
logger.debug('promptForToolApproval auto-approving tool for test', {
|
||||||
@@ -245,6 +260,7 @@ export async function promptForToolApproval(
|
|||||||
logger.info('Requesting user approval for tool usage', {
|
logger.info('Requesting user approval for tool usage', {
|
||||||
requestId,
|
requestId,
|
||||||
toolName,
|
toolName,
|
||||||
|
toolCallId: options.toolCallId,
|
||||||
description: toolMetadata?.description
|
description: toolMetadata?.description
|
||||||
})
|
})
|
||||||
|
|
||||||
@@ -252,13 +268,15 @@ export async function promptForToolApproval(
|
|||||||
requestId,
|
requestId,
|
||||||
toolName,
|
toolName,
|
||||||
toolId: toolMetadata?.id ?? toolName,
|
toolId: toolMetadata?.id ?? toolName,
|
||||||
|
toolCallId: options.toolCallId,
|
||||||
description: toolMetadata?.description,
|
description: toolMetadata?.description,
|
||||||
requiresPermissions: toolMetadata?.requirePermissions ?? false,
|
requiresPermissions: toolMetadata?.requirePermissions ?? false,
|
||||||
input: sanitizedInput,
|
input: sanitizedInput,
|
||||||
inputPreview,
|
inputPreview,
|
||||||
createdAt,
|
createdAt,
|
||||||
expiresAt,
|
expiresAt,
|
||||||
suggestions: sanitizedSuggestions
|
suggestions: sanitizedSuggestions,
|
||||||
|
autoApprove: options.autoApprove
|
||||||
}
|
}
|
||||||
|
|
||||||
const defaultDenyUpdate: PermissionResult = { behavior: 'deny', message: 'Tool request aborted before user decision' }
|
const defaultDenyUpdate: PermissionResult = { behavior: 'deny', message: 'Tool request aborted before user decision' }
|
||||||
@@ -266,6 +284,7 @@ export async function promptForToolApproval(
|
|||||||
logger.debug('Registering tool permission request', {
|
logger.debug('Registering tool permission request', {
|
||||||
requestId,
|
requestId,
|
||||||
toolName,
|
toolName,
|
||||||
|
toolCallId: options.toolCallId,
|
||||||
requiresPermissions: requestPayload.requiresPermissions,
|
requiresPermissions: requestPayload.requiresPermissions,
|
||||||
timeoutMs: TOOL_APPROVAL_TIMEOUT_MS,
|
timeoutMs: TOOL_APPROVAL_TIMEOUT_MS,
|
||||||
suggestionCount: sanitizedSuggestions.length
|
suggestionCount: sanitizedSuggestions.length
|
||||||
@@ -273,7 +292,11 @@ export async function promptForToolApproval(
|
|||||||
|
|
||||||
return new Promise<PermissionResult>((resolve) => {
|
return new Promise<PermissionResult>((resolve) => {
|
||||||
const timeout = setTimeout(() => {
|
const timeout = setTimeout(() => {
|
||||||
logger.info('User tool permission request timed out', { requestId, toolName })
|
logger.info('User tool permission request timed out', {
|
||||||
|
requestId,
|
||||||
|
toolName,
|
||||||
|
toolCallId: options.toolCallId
|
||||||
|
})
|
||||||
finalizeRequest(requestId, { behavior: 'deny', message: 'Timed out waiting for approval' }, 'timeout')
|
finalizeRequest(requestId, { behavior: 'deny', message: 'Timed out waiting for approval' }, 'timeout')
|
||||||
}, TOOL_APPROVAL_TIMEOUT_MS)
|
}, TOOL_APPROVAL_TIMEOUT_MS)
|
||||||
|
|
||||||
@@ -282,12 +305,17 @@ export async function promptForToolApproval(
|
|||||||
timeout,
|
timeout,
|
||||||
originalInput: sanitizedInput,
|
originalInput: sanitizedInput,
|
||||||
toolName,
|
toolName,
|
||||||
signal: options?.signal
|
signal: options?.signal,
|
||||||
|
toolCallId: options.toolCallId
|
||||||
}
|
}
|
||||||
|
|
||||||
if (options?.signal) {
|
if (options?.signal) {
|
||||||
const abortListener = () => {
|
const abortListener = () => {
|
||||||
logger.info('Tool permission request aborted before user responded', { requestId, toolName })
|
logger.info('Tool permission request aborted before user responded', {
|
||||||
|
requestId,
|
||||||
|
toolName,
|
||||||
|
toolCallId: options.toolCallId
|
||||||
|
})
|
||||||
finalizeRequest(requestId, defaultDenyUpdate, 'aborted')
|
finalizeRequest(requestId, defaultDenyUpdate, 'aborted')
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -110,7 +110,7 @@ const sdkMessageToProviderMetadata = (message: SDKMessage): ProviderMetadata =>
|
|||||||
* blocks across calls so that incremental deltas can be correlated correctly.
|
* blocks across calls so that incremental deltas can be correlated correctly.
|
||||||
*/
|
*/
|
||||||
export function transformSDKMessageToStreamParts(sdkMessage: SDKMessage, state: ClaudeStreamState): AgentStreamPart[] {
|
export function transformSDKMessageToStreamParts(sdkMessage: SDKMessage, state: ClaudeStreamState): AgentStreamPart[] {
|
||||||
logger.silly('Transforming SDKMessage', { message: sdkMessage })
|
logger.silly('Transforming SDKMessage', { message: JSON.stringify(sdkMessage) })
|
||||||
switch (sdkMessage.type) {
|
switch (sdkMessage.type) {
|
||||||
case 'assistant':
|
case 'assistant':
|
||||||
return handleAssistantMessage(sdkMessage, state)
|
return handleAssistantMessage(sdkMessage, state)
|
||||||
@@ -186,14 +186,13 @@ function handleAssistantMessage(
|
|||||||
|
|
||||||
for (const block of content) {
|
for (const block of content) {
|
||||||
switch (block.type) {
|
switch (block.type) {
|
||||||
case 'text':
|
case 'text': {
|
||||||
if (!isStreamingActive) {
|
const sanitizedText = stripLocalCommandTags(block.text)
|
||||||
const sanitizedText = stripLocalCommandTags(block.text)
|
if (sanitizedText) {
|
||||||
if (sanitizedText) {
|
textBlocks.push(sanitizedText)
|
||||||
textBlocks.push(sanitizedText)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
break
|
break
|
||||||
|
}
|
||||||
case 'tool_use':
|
case 'tool_use':
|
||||||
handleAssistantToolUse(block as ToolUseContent, providerMetadata, state, chunks)
|
handleAssistantToolUse(block as ToolUseContent, providerMetadata, state, chunks)
|
||||||
break
|
break
|
||||||
@@ -203,7 +202,16 @@ function handleAssistantMessage(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (!isStreamingActive && textBlocks.length > 0) {
|
if (textBlocks.length === 0) {
|
||||||
|
return chunks
|
||||||
|
}
|
||||||
|
|
||||||
|
const combinedText = textBlocks.join('')
|
||||||
|
if (!combinedText) {
|
||||||
|
return chunks
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!isStreamingActive) {
|
||||||
const id = message.uuid?.toString() || generateMessageId()
|
const id = message.uuid?.toString() || generateMessageId()
|
||||||
state.beginStep()
|
state.beginStep()
|
||||||
chunks.push({
|
chunks.push({
|
||||||
@@ -219,7 +227,7 @@ function handleAssistantMessage(
|
|||||||
chunks.push({
|
chunks.push({
|
||||||
type: 'text-delta',
|
type: 'text-delta',
|
||||||
id,
|
id,
|
||||||
text: textBlocks.join(''),
|
text: combinedText,
|
||||||
providerMetadata
|
providerMetadata
|
||||||
})
|
})
|
||||||
chunks.push({
|
chunks.push({
|
||||||
@@ -230,7 +238,27 @@ function handleAssistantMessage(
|
|||||||
return finalizeNonStreamingStep(message, state, chunks)
|
return finalizeNonStreamingStep(message, state, chunks)
|
||||||
}
|
}
|
||||||
|
|
||||||
return chunks
|
const existingTextBlock = state.getFirstOpenTextBlock()
|
||||||
|
const fallbackId = existingTextBlock?.id || message.uuid?.toString() || generateMessageId()
|
||||||
|
if (!existingTextBlock) {
|
||||||
|
chunks.push({
|
||||||
|
type: 'text-start',
|
||||||
|
id: fallbackId,
|
||||||
|
providerMetadata
|
||||||
|
})
|
||||||
|
}
|
||||||
|
chunks.push({
|
||||||
|
type: 'text-delta',
|
||||||
|
id: fallbackId,
|
||||||
|
text: combinedText,
|
||||||
|
providerMetadata
|
||||||
|
})
|
||||||
|
chunks.push({
|
||||||
|
type: 'text-end',
|
||||||
|
id: fallbackId,
|
||||||
|
providerMetadata
|
||||||
|
})
|
||||||
|
return finalizeNonStreamingStep(message, state, chunks)
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@@ -243,15 +271,16 @@ function handleAssistantToolUse(
|
|||||||
state: ClaudeStreamState,
|
state: ClaudeStreamState,
|
||||||
chunks: AgentStreamPart[]
|
chunks: AgentStreamPart[]
|
||||||
): void {
|
): void {
|
||||||
|
const toolCallId = state.getNamespacedToolCallId(block.id)
|
||||||
chunks.push({
|
chunks.push({
|
||||||
type: 'tool-call',
|
type: 'tool-call',
|
||||||
toolCallId: block.id,
|
toolCallId,
|
||||||
toolName: block.name,
|
toolName: block.name,
|
||||||
input: block.input,
|
input: block.input,
|
||||||
providerExecuted: true,
|
providerExecuted: true,
|
||||||
providerMetadata
|
providerMetadata
|
||||||
})
|
})
|
||||||
state.completeToolBlock(block.id, block.input, providerMetadata)
|
state.completeToolBlock(block.id, block.name, block.input, providerMetadata)
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@@ -331,10 +360,11 @@ function handleUserMessage(
|
|||||||
if (block.type === 'tool_result') {
|
if (block.type === 'tool_result') {
|
||||||
const toolResult = block as ToolResultContent
|
const toolResult = block as ToolResultContent
|
||||||
const pendingCall = state.consumePendingToolCall(toolResult.tool_use_id)
|
const pendingCall = state.consumePendingToolCall(toolResult.tool_use_id)
|
||||||
|
const toolCallId = pendingCall?.toolCallId ?? state.getNamespacedToolCallId(toolResult.tool_use_id)
|
||||||
if (toolResult.is_error) {
|
if (toolResult.is_error) {
|
||||||
chunks.push({
|
chunks.push({
|
||||||
type: 'tool-error',
|
type: 'tool-error',
|
||||||
toolCallId: toolResult.tool_use_id,
|
toolCallId,
|
||||||
toolName: pendingCall?.toolName ?? 'unknown',
|
toolName: pendingCall?.toolName ?? 'unknown',
|
||||||
input: pendingCall?.input,
|
input: pendingCall?.input,
|
||||||
error: toolResult.content,
|
error: toolResult.content,
|
||||||
@@ -343,7 +373,7 @@ function handleUserMessage(
|
|||||||
} else {
|
} else {
|
||||||
chunks.push({
|
chunks.push({
|
||||||
type: 'tool-result',
|
type: 'tool-result',
|
||||||
toolCallId: toolResult.tool_use_id,
|
toolCallId,
|
||||||
toolName: pendingCall?.toolName ?? 'unknown',
|
toolName: pendingCall?.toolName ?? 'unknown',
|
||||||
input: pendingCall?.input,
|
input: pendingCall?.input,
|
||||||
output: toolResult.content,
|
output: toolResult.content,
|
||||||
@@ -457,6 +487,9 @@ function handleStreamEvent(
|
|||||||
}
|
}
|
||||||
|
|
||||||
case 'message_stop': {
|
case 'message_stop': {
|
||||||
|
if (!state.hasActiveStep()) {
|
||||||
|
break
|
||||||
|
}
|
||||||
const pending = state.getPendingUsage()
|
const pending = state.getPendingUsage()
|
||||||
chunks.push({
|
chunks.push({
|
||||||
type: 'finish-step',
|
type: 'finish-step',
|
||||||
@@ -514,7 +547,7 @@ function handleContentBlockStart(
|
|||||||
}
|
}
|
||||||
case 'tool_use': {
|
case 'tool_use': {
|
||||||
const block = state.openToolBlock(index, {
|
const block = state.openToolBlock(index, {
|
||||||
toolCallId: contentBlock.id,
|
rawToolCallId: contentBlock.id,
|
||||||
toolName: contentBlock.name,
|
toolName: contentBlock.name,
|
||||||
providerMetadata
|
providerMetadata
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -111,6 +111,7 @@ const api = {
|
|||||||
setFullScreen: (value: boolean): Promise<void> => ipcRenderer.invoke(IpcChannel.App_SetFullScreen, value),
|
setFullScreen: (value: boolean): Promise<void> => ipcRenderer.invoke(IpcChannel.App_SetFullScreen, value),
|
||||||
isFullScreen: (): Promise<boolean> => ipcRenderer.invoke(IpcChannel.App_IsFullScreen),
|
isFullScreen: (): Promise<boolean> => ipcRenderer.invoke(IpcChannel.App_IsFullScreen),
|
||||||
getSystemFonts: (): Promise<string[]> => ipcRenderer.invoke(IpcChannel.App_GetSystemFonts),
|
getSystemFonts: (): Promise<string[]> => ipcRenderer.invoke(IpcChannel.App_GetSystemFonts),
|
||||||
|
mockCrashRenderProcess: () => ipcRenderer.invoke(IpcChannel.APP_CrashRenderProcess),
|
||||||
mac: {
|
mac: {
|
||||||
isProcessTrusted: (): Promise<boolean> => ipcRenderer.invoke(IpcChannel.App_MacIsProcessTrusted),
|
isProcessTrusted: (): Promise<boolean> => ipcRenderer.invoke(IpcChannel.App_MacIsProcessTrusted),
|
||||||
requestProcessTrust: (): Promise<boolean> => ipcRenderer.invoke(IpcChannel.App_MacRequestProcessTrust)
|
requestProcessTrust: (): Promise<boolean> => ipcRenderer.invoke(IpcChannel.App_MacRequestProcessTrust)
|
||||||
@@ -121,7 +122,8 @@ const api = {
|
|||||||
system: {
|
system: {
|
||||||
getDeviceType: () => ipcRenderer.invoke(IpcChannel.System_GetDeviceType),
|
getDeviceType: () => ipcRenderer.invoke(IpcChannel.System_GetDeviceType),
|
||||||
getHostname: () => ipcRenderer.invoke(IpcChannel.System_GetHostname),
|
getHostname: () => ipcRenderer.invoke(IpcChannel.System_GetHostname),
|
||||||
getCpuName: () => ipcRenderer.invoke(IpcChannel.System_GetCpuName)
|
getCpuName: () => ipcRenderer.invoke(IpcChannel.System_GetCpuName),
|
||||||
|
checkGitBash: (): Promise<boolean> => ipcRenderer.invoke(IpcChannel.System_CheckGitBash)
|
||||||
},
|
},
|
||||||
devTools: {
|
devTools: {
|
||||||
toggle: () => ipcRenderer.invoke(IpcChannel.System_ToggleDevTools)
|
toggle: () => ipcRenderer.invoke(IpcChannel.System_ToggleDevTools)
|
||||||
|
|||||||
@@ -32,6 +32,7 @@ import {
|
|||||||
prepareSpecialProviderConfig,
|
prepareSpecialProviderConfig,
|
||||||
providerToAiSdkConfig
|
providerToAiSdkConfig
|
||||||
} from './provider/providerConfig'
|
} from './provider/providerConfig'
|
||||||
|
import type { AiSdkConfig } from './types'
|
||||||
|
|
||||||
const logger = loggerService.withContext('ModernAiProvider')
|
const logger = loggerService.withContext('ModernAiProvider')
|
||||||
|
|
||||||
@@ -44,7 +45,7 @@ export type ModernAiProviderConfig = AiSdkMiddlewareConfig & {
|
|||||||
|
|
||||||
export default class ModernAiProvider {
|
export default class ModernAiProvider {
|
||||||
private legacyProvider: LegacyAiProvider
|
private legacyProvider: LegacyAiProvider
|
||||||
private config?: ReturnType<typeof providerToAiSdkConfig>
|
private config?: AiSdkConfig
|
||||||
private actualProvider: Provider
|
private actualProvider: Provider
|
||||||
private model?: Model
|
private model?: Model
|
||||||
private localProvider: Awaited<AiSdkProvider> | null = null
|
private localProvider: Awaited<AiSdkProvider> | null = null
|
||||||
@@ -89,6 +90,11 @@ export default class ModernAiProvider {
|
|||||||
// 每次请求时重新生成配置以确保API key轮换生效
|
// 每次请求时重新生成配置以确保API key轮换生效
|
||||||
this.config = providerToAiSdkConfig(this.actualProvider, this.model)
|
this.config = providerToAiSdkConfig(this.actualProvider, this.model)
|
||||||
logger.debug('Generated provider config for completions', this.config)
|
logger.debug('Generated provider config for completions', this.config)
|
||||||
|
|
||||||
|
// 检查 config 是否存在
|
||||||
|
if (!this.config) {
|
||||||
|
throw new Error('Provider config is undefined; cannot proceed with completions')
|
||||||
|
}
|
||||||
if (SUPPORTED_IMAGE_ENDPOINT_LIST.includes(this.config.options.endpoint)) {
|
if (SUPPORTED_IMAGE_ENDPOINT_LIST.includes(this.config.options.endpoint)) {
|
||||||
providerConfig.isImageGenerationEndpoint = true
|
providerConfig.isImageGenerationEndpoint = true
|
||||||
}
|
}
|
||||||
@@ -149,7 +155,8 @@ export default class ModernAiProvider {
|
|||||||
params: StreamTextParams,
|
params: StreamTextParams,
|
||||||
config: ModernAiProviderConfig
|
config: ModernAiProviderConfig
|
||||||
): Promise<CompletionsResult> {
|
): Promise<CompletionsResult> {
|
||||||
if (config.isImageGenerationEndpoint) {
|
// ai-gateway不是image/generation 端点,所以就先不走legacy了
|
||||||
|
if (config.isImageGenerationEndpoint && config.provider!.id !== SystemProviderIds['ai-gateway']) {
|
||||||
// 使用 legacy 实现处理图像生成(支持图片编辑等高级功能)
|
// 使用 legacy 实现处理图像生成(支持图片编辑等高级功能)
|
||||||
if (!config.uiMessages) {
|
if (!config.uiMessages) {
|
||||||
throw new Error('uiMessages is required for image generation endpoint')
|
throw new Error('uiMessages is required for image generation endpoint')
|
||||||
@@ -463,8 +470,13 @@ export default class ModernAiProvider {
|
|||||||
// 如果支持新的 AI SDK,使用现代化实现
|
// 如果支持新的 AI SDK,使用现代化实现
|
||||||
if (isModernSdkSupported(this.actualProvider)) {
|
if (isModernSdkSupported(this.actualProvider)) {
|
||||||
try {
|
try {
|
||||||
|
// 确保 config 已定义
|
||||||
|
if (!this.config) {
|
||||||
|
throw new Error('Provider config is undefined; cannot proceed with generateImage')
|
||||||
|
}
|
||||||
|
|
||||||
// 确保本地provider已创建
|
// 确保本地provider已创建
|
||||||
if (!this.localProvider) {
|
if (!this.localProvider && this.config) {
|
||||||
this.localProvider = await createAiSdkProvider(this.config)
|
this.localProvider = await createAiSdkProvider(this.config)
|
||||||
if (!this.localProvider) {
|
if (!this.localProvider) {
|
||||||
throw new Error('Local provider not created')
|
throw new Error('Local provider not created')
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
import { loggerService } from '@logger'
|
import { loggerService } from '@logger'
|
||||||
import { isNewApiProvider } from '@renderer/config/providers'
|
|
||||||
import type { Provider } from '@renderer/types'
|
import type { Provider } from '@renderer/types'
|
||||||
|
import { isNewApiProvider } from '@renderer/utils/provider'
|
||||||
|
|
||||||
import { AihubmixAPIClient } from './aihubmix/AihubmixAPIClient'
|
import { AihubmixAPIClient } from './aihubmix/AihubmixAPIClient'
|
||||||
import { AnthropicAPIClient } from './anthropic/AnthropicAPIClient'
|
import { AnthropicAPIClient } from './anthropic/AnthropicAPIClient'
|
||||||
|
|||||||
@@ -1,12 +1,12 @@
|
|||||||
import { loggerService } from '@logger'
|
import { loggerService } from '@logger'
|
||||||
import {
|
import {
|
||||||
|
getModelSupportedVerbosity,
|
||||||
isFunctionCallingModel,
|
isFunctionCallingModel,
|
||||||
isNotSupportTemperatureAndTopP,
|
isNotSupportTemperatureAndTopP,
|
||||||
isOpenAIModel,
|
isOpenAIModel,
|
||||||
isSupportFlexServiceTierModel
|
isSupportFlexServiceTierModel
|
||||||
} from '@renderer/config/models'
|
} from '@renderer/config/models'
|
||||||
import { REFERENCE_PROMPT } from '@renderer/config/prompts'
|
import { REFERENCE_PROMPT } from '@renderer/config/prompts'
|
||||||
import { isSupportServiceTierProvider } from '@renderer/config/providers'
|
|
||||||
import { getLMStudioKeepAliveTime } from '@renderer/hooks/useLMStudio'
|
import { getLMStudioKeepAliveTime } from '@renderer/hooks/useLMStudio'
|
||||||
import { getAssistantSettings } from '@renderer/services/AssistantService'
|
import { getAssistantSettings } from '@renderer/services/AssistantService'
|
||||||
import type {
|
import type {
|
||||||
@@ -18,7 +18,6 @@ import type {
|
|||||||
MCPToolResponse,
|
MCPToolResponse,
|
||||||
MemoryItem,
|
MemoryItem,
|
||||||
Model,
|
Model,
|
||||||
OpenAIVerbosity,
|
|
||||||
Provider,
|
Provider,
|
||||||
ToolCallResponse,
|
ToolCallResponse,
|
||||||
WebSearchProviderResponse,
|
WebSearchProviderResponse,
|
||||||
@@ -32,6 +31,7 @@ import {
|
|||||||
OpenAIServiceTiers,
|
OpenAIServiceTiers,
|
||||||
SystemProviderIds
|
SystemProviderIds
|
||||||
} from '@renderer/types'
|
} from '@renderer/types'
|
||||||
|
import type { OpenAIVerbosity } from '@renderer/types/aiCoreTypes'
|
||||||
import type { Message } from '@renderer/types/newMessage'
|
import type { Message } from '@renderer/types/newMessage'
|
||||||
import type {
|
import type {
|
||||||
RequestOptions,
|
RequestOptions,
|
||||||
@@ -47,6 +47,7 @@ import type {
|
|||||||
import { isJSON, parseJSON } from '@renderer/utils'
|
import { isJSON, parseJSON } from '@renderer/utils'
|
||||||
import { addAbortController, removeAbortController } from '@renderer/utils/abortController'
|
import { addAbortController, removeAbortController } from '@renderer/utils/abortController'
|
||||||
import { findFileBlocks, getMainTextContent } from '@renderer/utils/messageUtils/find'
|
import { findFileBlocks, getMainTextContent } from '@renderer/utils/messageUtils/find'
|
||||||
|
import { isSupportServiceTierProvider } from '@renderer/utils/provider'
|
||||||
import { defaultTimeout } from '@shared/config/constant'
|
import { defaultTimeout } from '@shared/config/constant'
|
||||||
import { defaultAppHeaders } from '@shared/utils'
|
import { defaultAppHeaders } from '@shared/utils'
|
||||||
import { isEmpty } from 'lodash'
|
import { isEmpty } from 'lodash'
|
||||||
@@ -242,12 +243,18 @@ export abstract class BaseApiClient<
|
|||||||
return serviceTierSetting
|
return serviceTierSetting
|
||||||
}
|
}
|
||||||
|
|
||||||
protected getVerbosity(): OpenAIVerbosity {
|
protected getVerbosity(model?: Model): OpenAIVerbosity {
|
||||||
try {
|
try {
|
||||||
const state = window.store?.getState()
|
const state = window.store?.getState()
|
||||||
const verbosity = state?.settings?.openAI?.verbosity
|
const verbosity = state?.settings?.openAI?.verbosity
|
||||||
|
|
||||||
if (verbosity && ['low', 'medium', 'high'].includes(verbosity)) {
|
if (verbosity && ['low', 'medium', 'high'].includes(verbosity)) {
|
||||||
|
// If model is provided, check if the verbosity is supported by the model
|
||||||
|
if (model) {
|
||||||
|
const supportedVerbosity = getModelSupportedVerbosity(model)
|
||||||
|
// Use user's verbosity if supported, otherwise use the first supported option
|
||||||
|
return supportedVerbosity.includes(verbosity) ? verbosity : supportedVerbosity[0]
|
||||||
|
}
|
||||||
return verbosity
|
return verbosity
|
||||||
}
|
}
|
||||||
} catch (error) {
|
} catch (error) {
|
||||||
|
|||||||
@@ -58,10 +58,27 @@ vi.mock('../aws/AwsBedrockAPIClient', () => ({
|
|||||||
AwsBedrockAPIClient: vi.fn().mockImplementation(() => ({}))
|
AwsBedrockAPIClient: vi.fn().mockImplementation(() => ({}))
|
||||||
}))
|
}))
|
||||||
|
|
||||||
|
vi.mock('@renderer/services/AssistantService.ts', () => ({
|
||||||
|
getDefaultAssistant: () => {
|
||||||
|
return {
|
||||||
|
id: 'default',
|
||||||
|
name: 'default',
|
||||||
|
emoji: '😀',
|
||||||
|
prompt: '',
|
||||||
|
topics: [],
|
||||||
|
messages: [],
|
||||||
|
type: 'assistant',
|
||||||
|
regularPhrases: [],
|
||||||
|
settings: {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
|
||||||
// Mock the models config to prevent circular dependency issues
|
// Mock the models config to prevent circular dependency issues
|
||||||
vi.mock('@renderer/config/models', () => ({
|
vi.mock('@renderer/config/models', () => ({
|
||||||
findTokenLimit: vi.fn(),
|
findTokenLimit: vi.fn(),
|
||||||
isReasoningModel: vi.fn(),
|
isReasoningModel: vi.fn(),
|
||||||
|
isOpenAILLMModel: vi.fn(),
|
||||||
SYSTEM_MODELS: {
|
SYSTEM_MODELS: {
|
||||||
silicon: [],
|
silicon: [],
|
||||||
defaultModel: []
|
defaultModel: []
|
||||||
|
|||||||
@@ -1,7 +1,8 @@
|
|||||||
import { GoogleGenAI } from '@google/genai'
|
import { GoogleGenAI } from '@google/genai'
|
||||||
import { loggerService } from '@logger'
|
import { loggerService } from '@logger'
|
||||||
import { createVertexProvider, isVertexAIConfigured, isVertexProvider } from '@renderer/hooks/useVertexAI'
|
import { createVertexProvider, isVertexAIConfigured } from '@renderer/hooks/useVertexAI'
|
||||||
import type { Model, Provider, VertexProvider } from '@renderer/types'
|
import type { Model, Provider, VertexProvider } from '@renderer/types'
|
||||||
|
import { isVertexProvider } from '@renderer/utils/provider'
|
||||||
import { isEmpty } from 'lodash'
|
import { isEmpty } from 'lodash'
|
||||||
|
|
||||||
import { AnthropicVertexClient } from '../anthropic/AnthropicVertexClient'
|
import { AnthropicVertexClient } from '../anthropic/AnthropicVertexClient'
|
||||||
|
|||||||
@@ -10,7 +10,6 @@ import { DEFAULT_MAX_TOKENS } from '@renderer/config/constant'
|
|||||||
import {
|
import {
|
||||||
findTokenLimit,
|
findTokenLimit,
|
||||||
GEMINI_FLASH_MODEL_REGEX,
|
GEMINI_FLASH_MODEL_REGEX,
|
||||||
getOpenAIWebSearchParams,
|
|
||||||
getThinkModelType,
|
getThinkModelType,
|
||||||
isClaudeReasoningModel,
|
isClaudeReasoningModel,
|
||||||
isDeepSeekHybridInferenceModel,
|
isDeepSeekHybridInferenceModel,
|
||||||
@@ -35,16 +34,11 @@ import {
|
|||||||
isSupportedThinkingTokenModel,
|
isSupportedThinkingTokenModel,
|
||||||
isSupportedThinkingTokenQwenModel,
|
isSupportedThinkingTokenQwenModel,
|
||||||
isSupportedThinkingTokenZhipuModel,
|
isSupportedThinkingTokenZhipuModel,
|
||||||
|
isSupportVerbosityModel,
|
||||||
isVisionModel,
|
isVisionModel,
|
||||||
MODEL_SUPPORTED_REASONING_EFFORT,
|
MODEL_SUPPORTED_REASONING_EFFORT,
|
||||||
ZHIPU_RESULT_TOKENS
|
ZHIPU_RESULT_TOKENS
|
||||||
} from '@renderer/config/models'
|
} from '@renderer/config/models'
|
||||||
import {
|
|
||||||
isSupportArrayContentProvider,
|
|
||||||
isSupportDeveloperRoleProvider,
|
|
||||||
isSupportEnableThinkingProvider,
|
|
||||||
isSupportStreamOptionsProvider
|
|
||||||
} from '@renderer/config/providers'
|
|
||||||
import { mapLanguageToQwenMTModel } from '@renderer/config/translate'
|
import { mapLanguageToQwenMTModel } from '@renderer/config/translate'
|
||||||
import { processPostsuffixQwen3Model, processReqMessages } from '@renderer/services/ModelMessageService'
|
import { processPostsuffixQwen3Model, processReqMessages } from '@renderer/services/ModelMessageService'
|
||||||
import { estimateTextTokens } from '@renderer/services/TokenService'
|
import { estimateTextTokens } from '@renderer/services/TokenService'
|
||||||
@@ -88,6 +82,12 @@ import {
|
|||||||
openAIToolsToMcpTool
|
openAIToolsToMcpTool
|
||||||
} from '@renderer/utils/mcp-tools'
|
} from '@renderer/utils/mcp-tools'
|
||||||
import { findFileBlocks, findImageBlocks } from '@renderer/utils/messageUtils/find'
|
import { findFileBlocks, findImageBlocks } from '@renderer/utils/messageUtils/find'
|
||||||
|
import {
|
||||||
|
isSupportArrayContentProvider,
|
||||||
|
isSupportDeveloperRoleProvider,
|
||||||
|
isSupportEnableThinkingProvider,
|
||||||
|
isSupportStreamOptionsProvider
|
||||||
|
} from '@renderer/utils/provider'
|
||||||
import { t } from 'i18next'
|
import { t } from 'i18next'
|
||||||
|
|
||||||
import type { GenericChunk } from '../../middleware/schemas'
|
import type { GenericChunk } from '../../middleware/schemas'
|
||||||
@@ -733,9 +733,16 @@ export class OpenAIAPIClient extends OpenAIBaseClient<
|
|||||||
...modalities,
|
...modalities,
|
||||||
// groq 有不同的 service tier 配置,不符合 openai 接口类型
|
// groq 有不同的 service tier 配置,不符合 openai 接口类型
|
||||||
service_tier: this.getServiceTier(model) as OpenAIServiceTier,
|
service_tier: this.getServiceTier(model) as OpenAIServiceTier,
|
||||||
|
...(isSupportVerbosityModel(model)
|
||||||
|
? {
|
||||||
|
text: {
|
||||||
|
verbosity: this.getVerbosity(model)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
: {}),
|
||||||
...this.getProviderSpecificParameters(assistant, model),
|
...this.getProviderSpecificParameters(assistant, model),
|
||||||
...reasoningEffort,
|
...reasoningEffort,
|
||||||
...getOpenAIWebSearchParams(model, enableWebSearch),
|
// ...getOpenAIWebSearchParams(model, enableWebSearch),
|
||||||
// OpenRouter usage tracking
|
// OpenRouter usage tracking
|
||||||
...(this.provider.id === 'openrouter' ? { usage: { include: true } } : {}),
|
...(this.provider.id === 'openrouter' ? { usage: { include: true } } : {}),
|
||||||
...extra_body,
|
...extra_body,
|
||||||
|
|||||||
@@ -48,9 +48,8 @@ export abstract class OpenAIBaseClient<
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 仅适用于openai
|
// 仅适用于openai
|
||||||
override getBaseURL(): string {
|
override getBaseURL(isSupportedAPIVerion: boolean = true): string {
|
||||||
const host = this.provider.apiHost
|
return formatApiHost(this.provider.apiHost, isSupportedAPIVerion)
|
||||||
return formatApiHost(host)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
override async generateImage({
|
override async generateImage({
|
||||||
@@ -144,6 +143,11 @@ export abstract class OpenAIBaseClient<
|
|||||||
}
|
}
|
||||||
|
|
||||||
let apiKeyForSdkInstance = this.apiKey
|
let apiKeyForSdkInstance = this.apiKey
|
||||||
|
let baseURLForSdkInstance = this.getBaseURL()
|
||||||
|
let headersForSdkInstance = {
|
||||||
|
...this.defaultHeaders(),
|
||||||
|
...this.provider.extra_headers
|
||||||
|
}
|
||||||
|
|
||||||
if (this.provider.id === 'copilot') {
|
if (this.provider.id === 'copilot') {
|
||||||
const defaultHeaders = store.getState().copilot.defaultHeaders
|
const defaultHeaders = store.getState().copilot.defaultHeaders
|
||||||
@@ -151,6 +155,11 @@ export abstract class OpenAIBaseClient<
|
|||||||
// this.provider.apiKey不允许修改
|
// this.provider.apiKey不允许修改
|
||||||
// this.provider.apiKey = token
|
// this.provider.apiKey = token
|
||||||
apiKeyForSdkInstance = token
|
apiKeyForSdkInstance = token
|
||||||
|
baseURLForSdkInstance = this.getBaseURL(false)
|
||||||
|
headersForSdkInstance = {
|
||||||
|
...headersForSdkInstance,
|
||||||
|
...COPILOT_DEFAULT_HEADERS
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if (this.provider.id === 'azure-openai' || this.provider.type === 'azure-openai') {
|
if (this.provider.id === 'azure-openai' || this.provider.type === 'azure-openai') {
|
||||||
@@ -164,12 +173,8 @@ export abstract class OpenAIBaseClient<
|
|||||||
this.sdkInstance = new OpenAI({
|
this.sdkInstance = new OpenAI({
|
||||||
dangerouslyAllowBrowser: true,
|
dangerouslyAllowBrowser: true,
|
||||||
apiKey: apiKeyForSdkInstance,
|
apiKey: apiKeyForSdkInstance,
|
||||||
baseURL: this.getBaseURL(),
|
baseURL: baseURLForSdkInstance,
|
||||||
defaultHeaders: {
|
defaultHeaders: headersForSdkInstance
|
||||||
...this.defaultHeaders(),
|
|
||||||
...this.provider.extra_headers,
|
|
||||||
...(this.provider.id === 'copilot' ? COPILOT_DEFAULT_HEADERS : {})
|
|
||||||
}
|
|
||||||
}) as TSdkInstance
|
}) as TSdkInstance
|
||||||
}
|
}
|
||||||
return this.sdkInstance
|
return this.sdkInstance
|
||||||
|
|||||||
@@ -12,7 +12,6 @@ import {
|
|||||||
isSupportVerbosityModel,
|
isSupportVerbosityModel,
|
||||||
isVisionModel
|
isVisionModel
|
||||||
} from '@renderer/config/models'
|
} from '@renderer/config/models'
|
||||||
import { isSupportDeveloperRoleProvider } from '@renderer/config/providers'
|
|
||||||
import { estimateTextTokens } from '@renderer/services/TokenService'
|
import { estimateTextTokens } from '@renderer/services/TokenService'
|
||||||
import type {
|
import type {
|
||||||
FileMetadata,
|
FileMetadata,
|
||||||
@@ -43,6 +42,7 @@ import {
|
|||||||
openAIToolsToMcpTool
|
openAIToolsToMcpTool
|
||||||
} from '@renderer/utils/mcp-tools'
|
} from '@renderer/utils/mcp-tools'
|
||||||
import { findFileBlocks, findImageBlocks } from '@renderer/utils/messageUtils/find'
|
import { findFileBlocks, findImageBlocks } from '@renderer/utils/messageUtils/find'
|
||||||
|
import { isSupportDeveloperRoleProvider } from '@renderer/utils/provider'
|
||||||
import { MB } from '@shared/config/constant'
|
import { MB } from '@shared/config/constant'
|
||||||
import { t } from 'i18next'
|
import { t } from 'i18next'
|
||||||
import { isEmpty } from 'lodash'
|
import { isEmpty } from 'lodash'
|
||||||
@@ -90,7 +90,7 @@ export class OpenAIResponseAPIClient extends OpenAIBaseClient<
|
|||||||
if (isOpenAILLMModel(model) && !isOpenAIChatCompletionOnlyModel(model)) {
|
if (isOpenAILLMModel(model) && !isOpenAIChatCompletionOnlyModel(model)) {
|
||||||
if (this.provider.id === 'azure-openai' || this.provider.type === 'azure-openai') {
|
if (this.provider.id === 'azure-openai' || this.provider.type === 'azure-openai') {
|
||||||
this.provider = { ...this.provider, apiHost: this.formatApiHost() }
|
this.provider = { ...this.provider, apiHost: this.formatApiHost() }
|
||||||
if (this.provider.apiVersion === 'preview') {
|
if (this.provider.apiVersion === 'preview' || this.provider.apiVersion === 'v1') {
|
||||||
return this
|
return this
|
||||||
} else {
|
} else {
|
||||||
return this.client
|
return this.client
|
||||||
@@ -297,7 +297,31 @@ export class OpenAIResponseAPIClient extends OpenAIBaseClient<
|
|||||||
|
|
||||||
private convertResponseToMessageContent(response: OpenAI.Responses.Response): ResponseInput {
|
private convertResponseToMessageContent(response: OpenAI.Responses.Response): ResponseInput {
|
||||||
const content: OpenAI.Responses.ResponseInput = []
|
const content: OpenAI.Responses.ResponseInput = []
|
||||||
content.push(...response.output)
|
response.output.forEach((item) => {
|
||||||
|
if (item.type !== 'apply_patch_call' && item.type !== 'apply_patch_call_output') {
|
||||||
|
content.push(item)
|
||||||
|
} else if (item.type === 'apply_patch_call') {
|
||||||
|
if (item.operation !== undefined) {
|
||||||
|
const applyPatchToolCall: OpenAI.Responses.ResponseInputItem.ApplyPatchCall = {
|
||||||
|
...item,
|
||||||
|
operation: item.operation
|
||||||
|
}
|
||||||
|
content.push(applyPatchToolCall)
|
||||||
|
} else {
|
||||||
|
logger.warn('Undefined tool call operation for ApplyPatchToolCall.')
|
||||||
|
}
|
||||||
|
} else if (item.type === 'apply_patch_call_output') {
|
||||||
|
if (item.output !== undefined) {
|
||||||
|
const applyPatchToolCallOutput: OpenAI.Responses.ResponseInputItem.ApplyPatchCallOutput = {
|
||||||
|
...item,
|
||||||
|
output: item.output === null ? undefined : item.output
|
||||||
|
}
|
||||||
|
content.push(applyPatchToolCallOutput)
|
||||||
|
} else {
|
||||||
|
logger.warn('Undefined tool call operation for ApplyPatchToolCall.')
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
return content
|
return content
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -496,7 +520,7 @@ export class OpenAIResponseAPIClient extends OpenAIBaseClient<
|
|||||||
...(isSupportVerbosityModel(model)
|
...(isSupportVerbosityModel(model)
|
||||||
? {
|
? {
|
||||||
text: {
|
text: {
|
||||||
verbosity: this.getVerbosity()
|
verbosity: this.getVerbosity(model)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
: {}),
|
: {}),
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
import { loggerService } from '@logger'
|
import { loggerService } from '@logger'
|
||||||
import { isZhipuModel } from '@renderer/config/models'
|
import { isZhipuModel } from '@renderer/config/models'
|
||||||
import { getStoreProviders } from '@renderer/hooks/useStore'
|
import { getStoreProviders } from '@renderer/hooks/useStore'
|
||||||
|
import { getDefaultModel } from '@renderer/services/AssistantService'
|
||||||
import type { Chunk } from '@renderer/types/chunk'
|
import type { Chunk } from '@renderer/types/chunk'
|
||||||
|
|
||||||
import type { CompletionsParams, CompletionsResult } from '../schemas'
|
import type { CompletionsParams, CompletionsResult } from '../schemas'
|
||||||
@@ -66,7 +67,7 @@ export const ErrorHandlerMiddleware =
|
|||||||
}
|
}
|
||||||
|
|
||||||
function handleError(error: any, params: CompletionsParams): any {
|
function handleError(error: any, params: CompletionsParams): any {
|
||||||
if (isZhipuModel(params.assistant.model) && error.status && !params.enableGenerateImage) {
|
if (isZhipuModel(params.assistant.model || getDefaultModel()) && error.status && !params.enableGenerateImage) {
|
||||||
return handleZhipuError(error)
|
return handleZhipuError(error)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -1,10 +1,10 @@
|
|||||||
import type { WebSearchPluginConfig } from '@cherrystudio/ai-core/built-in/plugins'
|
import type { WebSearchPluginConfig } from '@cherrystudio/ai-core/built-in/plugins'
|
||||||
import { loggerService } from '@logger'
|
import { loggerService } from '@logger'
|
||||||
import { isSupportedThinkingTokenQwenModel } from '@renderer/config/models'
|
import { isSupportedThinkingTokenQwenModel } from '@renderer/config/models'
|
||||||
import { isSupportEnableThinkingProvider } from '@renderer/config/providers'
|
|
||||||
import type { MCPTool } from '@renderer/types'
|
import type { MCPTool } from '@renderer/types'
|
||||||
import { type Assistant, type Message, type Model, type Provider } from '@renderer/types'
|
import { type Assistant, type Message, type Model, type Provider, SystemProviderIds } from '@renderer/types'
|
||||||
import type { Chunk } from '@renderer/types/chunk'
|
import type { Chunk } from '@renderer/types/chunk'
|
||||||
|
import { isSupportEnableThinkingProvider } from '@renderer/utils/provider'
|
||||||
import type { LanguageModelMiddleware } from 'ai'
|
import type { LanguageModelMiddleware } from 'ai'
|
||||||
import { extractReasoningMiddleware, simulateStreamingMiddleware } from 'ai'
|
import { extractReasoningMiddleware, simulateStreamingMiddleware } from 'ai'
|
||||||
import { isEmpty } from 'lodash'
|
import { isEmpty } from 'lodash'
|
||||||
@@ -12,6 +12,7 @@ import { isEmpty } from 'lodash'
|
|||||||
import { isOpenRouterGeminiGenerateImageModel } from '../utils/image'
|
import { isOpenRouterGeminiGenerateImageModel } from '../utils/image'
|
||||||
import { noThinkMiddleware } from './noThinkMiddleware'
|
import { noThinkMiddleware } from './noThinkMiddleware'
|
||||||
import { openrouterGenerateImageMiddleware } from './openrouterGenerateImageMiddleware'
|
import { openrouterGenerateImageMiddleware } from './openrouterGenerateImageMiddleware'
|
||||||
|
import { openrouterReasoningMiddleware } from './openrouterReasoningMiddleware'
|
||||||
import { qwenThinkingMiddleware } from './qwenThinkingMiddleware'
|
import { qwenThinkingMiddleware } from './qwenThinkingMiddleware'
|
||||||
import { toolChoiceMiddleware } from './toolChoiceMiddleware'
|
import { toolChoiceMiddleware } from './toolChoiceMiddleware'
|
||||||
|
|
||||||
@@ -217,6 +218,14 @@ function addProviderSpecificMiddlewares(builder: AiSdkMiddlewareBuilder, config:
|
|||||||
middleware: noThinkMiddleware()
|
middleware: noThinkMiddleware()
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (config.provider.id === SystemProviderIds.openrouter && config.enableReasoning) {
|
||||||
|
builder.add({
|
||||||
|
name: 'openrouter-reasoning-redaction',
|
||||||
|
middleware: openrouterReasoningMiddleware()
|
||||||
|
})
|
||||||
|
logger.debug('Added OpenRouter reasoning redaction middleware')
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|||||||
@@ -0,0 +1,50 @@
|
|||||||
|
import type { LanguageModelV2StreamPart } from '@ai-sdk/provider'
|
||||||
|
import type { LanguageModelMiddleware } from 'ai'
|
||||||
|
|
||||||
|
/**
|
||||||
|
* https://openrouter.ai/docs/docs/best-practices/reasoning-tokens#example-preserving-reasoning-blocks-with-openrouter-and-claude
|
||||||
|
*
|
||||||
|
* @returns LanguageModelMiddleware - a middleware filter redacted block
|
||||||
|
*/
|
||||||
|
export function openrouterReasoningMiddleware(): LanguageModelMiddleware {
|
||||||
|
const REDACTED_BLOCK = '[REDACTED]'
|
||||||
|
return {
|
||||||
|
middlewareVersion: 'v2',
|
||||||
|
wrapGenerate: async ({ doGenerate }) => {
|
||||||
|
const { content, ...rest } = await doGenerate()
|
||||||
|
const modifiedContent = content.map((part) => {
|
||||||
|
if (part.type === 'reasoning' && part.text.includes(REDACTED_BLOCK)) {
|
||||||
|
return {
|
||||||
|
...part,
|
||||||
|
text: part.text.replace(REDACTED_BLOCK, '')
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return part
|
||||||
|
})
|
||||||
|
return { content: modifiedContent, ...rest }
|
||||||
|
},
|
||||||
|
wrapStream: async ({ doStream }) => {
|
||||||
|
const { stream, ...rest } = await doStream()
|
||||||
|
return {
|
||||||
|
stream: stream.pipeThrough(
|
||||||
|
new TransformStream<LanguageModelV2StreamPart, LanguageModelV2StreamPart>({
|
||||||
|
transform(
|
||||||
|
chunk: LanguageModelV2StreamPart,
|
||||||
|
controller: TransformStreamDefaultController<LanguageModelV2StreamPart>
|
||||||
|
) {
|
||||||
|
if (chunk.type === 'reasoning-delta' && chunk.delta.includes(REDACTED_BLOCK)) {
|
||||||
|
controller.enqueue({
|
||||||
|
...chunk,
|
||||||
|
delta: chunk.delta.replace(REDACTED_BLOCK, '')
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
controller.enqueue(chunk)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
),
|
||||||
|
...rest
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -0,0 +1,234 @@
|
|||||||
|
import type { Message, Model } from '@renderer/types'
|
||||||
|
import type { FileMetadata } from '@renderer/types/file'
|
||||||
|
import { FileTypes } from '@renderer/types/file'
|
||||||
|
import {
|
||||||
|
AssistantMessageStatus,
|
||||||
|
type FileMessageBlock,
|
||||||
|
type ImageMessageBlock,
|
||||||
|
MessageBlockStatus,
|
||||||
|
MessageBlockType,
|
||||||
|
type ThinkingMessageBlock,
|
||||||
|
UserMessageStatus
|
||||||
|
} from '@renderer/types/newMessage'
|
||||||
|
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||||
|
|
||||||
|
const { convertFileBlockToFilePartMock, convertFileBlockToTextPartMock } = vi.hoisted(() => ({
|
||||||
|
convertFileBlockToFilePartMock: vi.fn(),
|
||||||
|
convertFileBlockToTextPartMock: vi.fn()
|
||||||
|
}))
|
||||||
|
|
||||||
|
vi.mock('../fileProcessor', () => ({
|
||||||
|
convertFileBlockToFilePart: convertFileBlockToFilePartMock,
|
||||||
|
convertFileBlockToTextPart: convertFileBlockToTextPartMock
|
||||||
|
}))
|
||||||
|
|
||||||
|
const visionModelIds = new Set(['gpt-4o-mini', 'qwen-image-edit'])
|
||||||
|
const imageEnhancementModelIds = new Set(['qwen-image-edit'])
|
||||||
|
|
||||||
|
vi.mock('@renderer/config/models', () => ({
|
||||||
|
isVisionModel: (model: Model) => visionModelIds.has(model.id),
|
||||||
|
isImageEnhancementModel: (model: Model) => imageEnhancementModelIds.has(model.id)
|
||||||
|
}))
|
||||||
|
|
||||||
|
type MockableMessage = Message & {
|
||||||
|
__mockContent?: string
|
||||||
|
__mockFileBlocks?: FileMessageBlock[]
|
||||||
|
__mockImageBlocks?: ImageMessageBlock[]
|
||||||
|
__mockThinkingBlocks?: ThinkingMessageBlock[]
|
||||||
|
}
|
||||||
|
|
||||||
|
vi.mock('@renderer/utils/messageUtils/find', () => ({
|
||||||
|
getMainTextContent: (message: Message) => (message as MockableMessage).__mockContent ?? '',
|
||||||
|
findFileBlocks: (message: Message) => (message as MockableMessage).__mockFileBlocks ?? [],
|
||||||
|
findImageBlocks: (message: Message) => (message as MockableMessage).__mockImageBlocks ?? [],
|
||||||
|
findThinkingBlocks: (message: Message) => (message as MockableMessage).__mockThinkingBlocks ?? []
|
||||||
|
}))
|
||||||
|
|
||||||
|
import { convertMessagesToSdkMessages, convertMessageToSdkParam } from '../messageConverter'
|
||||||
|
|
||||||
|
let messageCounter = 0
|
||||||
|
let blockCounter = 0
|
||||||
|
|
||||||
|
const createModel = (overrides: Partial<Model> = {}): Model => ({
|
||||||
|
id: 'gpt-4o-mini',
|
||||||
|
name: 'GPT-4o mini',
|
||||||
|
provider: 'openai',
|
||||||
|
group: 'openai',
|
||||||
|
...overrides
|
||||||
|
})
|
||||||
|
|
||||||
|
const createMessage = (role: Message['role']): MockableMessage =>
|
||||||
|
({
|
||||||
|
id: `message-${++messageCounter}`,
|
||||||
|
role,
|
||||||
|
assistantId: 'assistant-1',
|
||||||
|
topicId: 'topic-1',
|
||||||
|
createdAt: new Date(2024, 0, 1, 0, 0, messageCounter).toISOString(),
|
||||||
|
status: role === 'assistant' ? AssistantMessageStatus.SUCCESS : UserMessageStatus.SUCCESS,
|
||||||
|
blocks: []
|
||||||
|
}) as MockableMessage
|
||||||
|
|
||||||
|
const createFileBlock = (
|
||||||
|
messageId: string,
|
||||||
|
overrides: Partial<Omit<FileMessageBlock, 'file' | 'messageId' | 'type'>> & { file?: Partial<FileMetadata> } = {}
|
||||||
|
): FileMessageBlock => {
|
||||||
|
const { file, ...blockOverrides } = overrides
|
||||||
|
const timestamp = new Date(2024, 0, 1, 0, 0, ++blockCounter).toISOString()
|
||||||
|
return {
|
||||||
|
id: blockOverrides.id ?? `file-block-${blockCounter}`,
|
||||||
|
messageId,
|
||||||
|
type: MessageBlockType.FILE,
|
||||||
|
createdAt: blockOverrides.createdAt ?? timestamp,
|
||||||
|
status: blockOverrides.status ?? MessageBlockStatus.SUCCESS,
|
||||||
|
file: {
|
||||||
|
id: file?.id ?? `file-${blockCounter}`,
|
||||||
|
name: file?.name ?? 'document.txt',
|
||||||
|
origin_name: file?.origin_name ?? 'document.txt',
|
||||||
|
path: file?.path ?? '/tmp/document.txt',
|
||||||
|
size: file?.size ?? 1024,
|
||||||
|
ext: file?.ext ?? '.txt',
|
||||||
|
type: file?.type ?? FileTypes.TEXT,
|
||||||
|
created_at: file?.created_at ?? timestamp,
|
||||||
|
count: file?.count ?? 1,
|
||||||
|
...file
|
||||||
|
},
|
||||||
|
...blockOverrides
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const createImageBlock = (
|
||||||
|
messageId: string,
|
||||||
|
overrides: Partial<Omit<ImageMessageBlock, 'type' | 'messageId'>> = {}
|
||||||
|
): ImageMessageBlock => ({
|
||||||
|
id: overrides.id ?? `image-block-${++blockCounter}`,
|
||||||
|
messageId,
|
||||||
|
type: MessageBlockType.IMAGE,
|
||||||
|
createdAt: overrides.createdAt ?? new Date(2024, 0, 1, 0, 0, blockCounter).toISOString(),
|
||||||
|
status: overrides.status ?? MessageBlockStatus.SUCCESS,
|
||||||
|
url: overrides.url ?? 'https://example.com/image.png',
|
||||||
|
...overrides
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('messageConverter', () => {
|
||||||
|
beforeEach(() => {
|
||||||
|
convertFileBlockToFilePartMock.mockReset()
|
||||||
|
convertFileBlockToTextPartMock.mockReset()
|
||||||
|
convertFileBlockToFilePartMock.mockResolvedValue(null)
|
||||||
|
convertFileBlockToTextPartMock.mockResolvedValue(null)
|
||||||
|
messageCounter = 0
|
||||||
|
blockCounter = 0
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('convertMessageToSdkParam', () => {
|
||||||
|
it('includes text and image parts for user messages on vision models', async () => {
|
||||||
|
const model = createModel()
|
||||||
|
const message = createMessage('user')
|
||||||
|
message.__mockContent = 'Describe this picture'
|
||||||
|
message.__mockImageBlocks = [createImageBlock(message.id, { url: 'https://example.com/cat.png' })]
|
||||||
|
|
||||||
|
const result = await convertMessageToSdkParam(message, true, model)
|
||||||
|
|
||||||
|
expect(result).toEqual({
|
||||||
|
role: 'user',
|
||||||
|
content: [
|
||||||
|
{ type: 'text', text: 'Describe this picture' },
|
||||||
|
{ type: 'image', image: 'https://example.com/cat.png' }
|
||||||
|
]
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
it('returns file instructions as a system message when native uploads succeed', async () => {
|
||||||
|
const model = createModel()
|
||||||
|
const message = createMessage('user')
|
||||||
|
message.__mockContent = 'Summarize the PDF'
|
||||||
|
message.__mockFileBlocks = [createFileBlock(message.id)]
|
||||||
|
convertFileBlockToFilePartMock.mockResolvedValueOnce({
|
||||||
|
type: 'file',
|
||||||
|
filename: 'document.pdf',
|
||||||
|
mediaType: 'application/pdf',
|
||||||
|
data: 'fileid://remote-file'
|
||||||
|
})
|
||||||
|
|
||||||
|
const result = await convertMessageToSdkParam(message, false, model)
|
||||||
|
|
||||||
|
expect(result).toEqual([
|
||||||
|
{
|
||||||
|
role: 'system',
|
||||||
|
content: 'fileid://remote-file'
|
||||||
|
},
|
||||||
|
{
|
||||||
|
role: 'user',
|
||||||
|
content: [{ type: 'text', text: 'Summarize the PDF' }]
|
||||||
|
}
|
||||||
|
])
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('convertMessagesToSdkMessages', () => {
|
||||||
|
it('appends assistant images to the final user message for image enhancement models', async () => {
|
||||||
|
const model = createModel({ id: 'qwen-image-edit', name: 'Qwen Image Edit', provider: 'qwen', group: 'qwen' })
|
||||||
|
const initialUser = createMessage('user')
|
||||||
|
initialUser.__mockContent = 'Start editing'
|
||||||
|
|
||||||
|
const assistant = createMessage('assistant')
|
||||||
|
assistant.__mockContent = 'Here is the current preview'
|
||||||
|
assistant.__mockImageBlocks = [createImageBlock(assistant.id, { url: 'https://example.com/preview.png' })]
|
||||||
|
|
||||||
|
const finalUser = createMessage('user')
|
||||||
|
finalUser.__mockContent = 'Increase the brightness'
|
||||||
|
|
||||||
|
const result = await convertMessagesToSdkMessages([initialUser, assistant, finalUser], model)
|
||||||
|
|
||||||
|
expect(result).toEqual([
|
||||||
|
{
|
||||||
|
role: 'assistant',
|
||||||
|
content: [{ type: 'text', text: 'Here is the current preview' }]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
role: 'user',
|
||||||
|
content: [
|
||||||
|
{ type: 'text', text: 'Increase the brightness' },
|
||||||
|
{ type: 'image', image: 'https://example.com/preview.png' }
|
||||||
|
]
|
||||||
|
}
|
||||||
|
])
|
||||||
|
})
|
||||||
|
|
||||||
|
it('preserves preceding system instructions when building enhancement payloads', async () => {
|
||||||
|
const model = createModel({ id: 'qwen-image-edit', name: 'Qwen Image Edit', provider: 'qwen', group: 'qwen' })
|
||||||
|
const fileUser = createMessage('user')
|
||||||
|
fileUser.__mockContent = 'Use this document as inspiration'
|
||||||
|
fileUser.__mockFileBlocks = [createFileBlock(fileUser.id, { file: { ext: '.pdf', type: FileTypes.DOCUMENT } })]
|
||||||
|
convertFileBlockToFilePartMock.mockResolvedValueOnce({
|
||||||
|
type: 'file',
|
||||||
|
filename: 'reference.pdf',
|
||||||
|
mediaType: 'application/pdf',
|
||||||
|
data: 'fileid://reference'
|
||||||
|
})
|
||||||
|
|
||||||
|
const assistant = createMessage('assistant')
|
||||||
|
assistant.__mockContent = 'Generated previews ready'
|
||||||
|
assistant.__mockImageBlocks = [createImageBlock(assistant.id, { url: 'https://example.com/reference.png' })]
|
||||||
|
|
||||||
|
const finalUser = createMessage('user')
|
||||||
|
finalUser.__mockContent = 'Apply the edits'
|
||||||
|
|
||||||
|
const result = await convertMessagesToSdkMessages([fileUser, assistant, finalUser], model)
|
||||||
|
|
||||||
|
expect(result).toEqual([
|
||||||
|
{ role: 'system', content: 'fileid://reference' },
|
||||||
|
{
|
||||||
|
role: 'assistant',
|
||||||
|
content: [{ type: 'text', text: 'Generated previews ready' }]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
role: 'user',
|
||||||
|
content: [
|
||||||
|
{ type: 'text', text: 'Apply the edits' },
|
||||||
|
{ type: 'image', image: 'https://example.com/reference.png' }
|
||||||
|
]
|
||||||
|
}
|
||||||
|
])
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
@@ -0,0 +1,218 @@
|
|||||||
|
import type { Assistant, AssistantSettings, Model, Topic } from '@renderer/types'
|
||||||
|
import { TopicType } from '@renderer/types'
|
||||||
|
import { defaultTimeout } from '@shared/config/constant'
|
||||||
|
import { describe, expect, it, vi } from 'vitest'
|
||||||
|
|
||||||
|
import { getTemperature, getTimeout, getTopP } from '../modelParameters'
|
||||||
|
|
||||||
|
vi.mock('@renderer/services/AssistantService', () => ({
|
||||||
|
getAssistantSettings: (assistant: Assistant): AssistantSettings => ({
|
||||||
|
contextCount: assistant.settings?.contextCount ?? 4096,
|
||||||
|
temperature: assistant.settings?.temperature ?? 0.7,
|
||||||
|
enableTemperature: assistant.settings?.enableTemperature ?? true,
|
||||||
|
topP: assistant.settings?.topP ?? 1,
|
||||||
|
enableTopP: assistant.settings?.enableTopP ?? false,
|
||||||
|
enableMaxTokens: assistant.settings?.enableMaxTokens ?? false,
|
||||||
|
maxTokens: assistant.settings?.maxTokens,
|
||||||
|
streamOutput: assistant.settings?.streamOutput ?? true,
|
||||||
|
toolUseMode: assistant.settings?.toolUseMode ?? 'prompt',
|
||||||
|
defaultModel: assistant.defaultModel,
|
||||||
|
customParameters: assistant.settings?.customParameters ?? [],
|
||||||
|
reasoning_effort: assistant.settings?.reasoning_effort,
|
||||||
|
reasoning_effort_cache: assistant.settings?.reasoning_effort_cache,
|
||||||
|
qwenThinkMode: assistant.settings?.qwenThinkMode
|
||||||
|
})
|
||||||
|
}))
|
||||||
|
|
||||||
|
vi.mock('@renderer/hooks/useSettings', () => ({
|
||||||
|
getStoreSetting: vi.fn(),
|
||||||
|
useSettings: vi.fn(() => ({})),
|
||||||
|
useNavbarPosition: vi.fn(() => ({ navbarPosition: 'left', isLeftNavbar: true, isTopNavbar: false }))
|
||||||
|
}))
|
||||||
|
|
||||||
|
vi.mock('@renderer/hooks/useStore', () => ({
|
||||||
|
getStoreProviders: vi.fn(() => [])
|
||||||
|
}))
|
||||||
|
|
||||||
|
vi.mock('@renderer/store/settings', () => ({
|
||||||
|
default: (state = { settings: {} }) => state
|
||||||
|
}))
|
||||||
|
|
||||||
|
vi.mock('@renderer/store/assistants', () => ({
|
||||||
|
default: (state = { assistants: [] }) => state
|
||||||
|
}))
|
||||||
|
|
||||||
|
const createTopic = (assistantId: string): Topic => ({
|
||||||
|
id: `topic-${assistantId}`,
|
||||||
|
assistantId,
|
||||||
|
name: 'topic',
|
||||||
|
createdAt: new Date().toISOString(),
|
||||||
|
updatedAt: new Date().toISOString(),
|
||||||
|
messages: [],
|
||||||
|
type: TopicType.Chat
|
||||||
|
})
|
||||||
|
|
||||||
|
const createAssistant = (settings: Assistant['settings'] = {}): Assistant => {
|
||||||
|
const assistantId = 'assistant-1'
|
||||||
|
return {
|
||||||
|
id: assistantId,
|
||||||
|
name: 'Test Assistant',
|
||||||
|
prompt: 'prompt',
|
||||||
|
topics: [createTopic(assistantId)],
|
||||||
|
type: 'assistant',
|
||||||
|
settings
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const createModel = (overrides: Partial<Model> = {}): Model => ({
|
||||||
|
id: 'gpt-4o',
|
||||||
|
provider: 'openai',
|
||||||
|
name: 'GPT-4o',
|
||||||
|
group: 'openai',
|
||||||
|
...overrides
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('modelParameters', () => {
|
||||||
|
describe('getTemperature', () => {
|
||||||
|
it('returns undefined when reasoning effort is enabled for Claude models', () => {
|
||||||
|
const assistant = createAssistant({ reasoning_effort: 'medium' })
|
||||||
|
const model = createModel({ id: 'claude-opus-4', name: 'Claude Opus 4', provider: 'anthropic', group: 'claude' })
|
||||||
|
|
||||||
|
expect(getTemperature(assistant, model)).toBeUndefined()
|
||||||
|
})
|
||||||
|
|
||||||
|
it('returns undefined for models without temperature/topP support', () => {
|
||||||
|
const assistant = createAssistant({ enableTemperature: true })
|
||||||
|
const model = createModel({ id: 'qwen-mt-large', name: 'Qwen MT', provider: 'qwen', group: 'qwen' })
|
||||||
|
|
||||||
|
expect(getTemperature(assistant, model)).toBeUndefined()
|
||||||
|
})
|
||||||
|
|
||||||
|
it('returns undefined for Claude 4.5 reasoning models when only TopP is enabled', () => {
|
||||||
|
const assistant = createAssistant({ enableTopP: true, enableTemperature: false })
|
||||||
|
const model = createModel({
|
||||||
|
id: 'claude-sonnet-4.5',
|
||||||
|
name: 'Claude Sonnet 4.5',
|
||||||
|
provider: 'anthropic',
|
||||||
|
group: 'claude'
|
||||||
|
})
|
||||||
|
|
||||||
|
expect(getTemperature(assistant, model)).toBeUndefined()
|
||||||
|
})
|
||||||
|
|
||||||
|
it('returns configured temperature when enabled', () => {
|
||||||
|
const assistant = createAssistant({ enableTemperature: true, temperature: 0.42 })
|
||||||
|
const model = createModel({ id: 'gpt-4o', provider: 'openai', group: 'openai' })
|
||||||
|
|
||||||
|
expect(getTemperature(assistant, model)).toBe(0.42)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('returns undefined when temperature is disabled', () => {
|
||||||
|
const assistant = createAssistant({ enableTemperature: false, temperature: 0.9 })
|
||||||
|
const model = createModel({ id: 'gpt-4o', provider: 'openai', group: 'openai' })
|
||||||
|
|
||||||
|
expect(getTemperature(assistant, model)).toBeUndefined()
|
||||||
|
})
|
||||||
|
|
||||||
|
it('clamps temperature to max 1.0 for Zhipu models', () => {
|
||||||
|
const assistant = createAssistant({ enableTemperature: true, temperature: 2.0 })
|
||||||
|
const model = createModel({ id: 'glm-4-plus', name: 'GLM-4 Plus', provider: 'zhipu', group: 'zhipu' })
|
||||||
|
|
||||||
|
expect(getTemperature(assistant, model)).toBe(1.0)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('clamps temperature to max 1.0 for Anthropic models', () => {
|
||||||
|
const assistant = createAssistant({ enableTemperature: true, temperature: 1.5 })
|
||||||
|
const model = createModel({
|
||||||
|
id: 'claude-sonnet-3.5',
|
||||||
|
name: 'Claude 3.5 Sonnet',
|
||||||
|
provider: 'anthropic',
|
||||||
|
group: 'claude'
|
||||||
|
})
|
||||||
|
|
||||||
|
expect(getTemperature(assistant, model)).toBe(1.0)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('clamps temperature to max 1.0 for Moonshot models', () => {
|
||||||
|
const assistant = createAssistant({ enableTemperature: true, temperature: 2.0 })
|
||||||
|
const model = createModel({
|
||||||
|
id: 'moonshot-v1-8k',
|
||||||
|
name: 'Moonshot v1 8k',
|
||||||
|
provider: 'moonshot',
|
||||||
|
group: 'moonshot'
|
||||||
|
})
|
||||||
|
|
||||||
|
expect(getTemperature(assistant, model)).toBe(1.0)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('does not clamp temperature for OpenAI models', () => {
|
||||||
|
const assistant = createAssistant({ enableTemperature: true, temperature: 2.0 })
|
||||||
|
const model = createModel({ id: 'gpt-4o', provider: 'openai', group: 'openai' })
|
||||||
|
|
||||||
|
expect(getTemperature(assistant, model)).toBe(2.0)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('does not clamp temperature when it is already within limits', () => {
|
||||||
|
const assistant = createAssistant({ enableTemperature: true, temperature: 0.8 })
|
||||||
|
const model = createModel({ id: 'glm-4-plus', name: 'GLM-4 Plus', provider: 'zhipu', group: 'zhipu' })
|
||||||
|
|
||||||
|
expect(getTemperature(assistant, model)).toBe(0.8)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('getTopP', () => {
|
||||||
|
it('returns undefined when reasoning effort is enabled for Claude models', () => {
|
||||||
|
const assistant = createAssistant({ reasoning_effort: 'high' })
|
||||||
|
const model = createModel({ id: 'claude-opus-4', provider: 'anthropic', group: 'claude' })
|
||||||
|
|
||||||
|
expect(getTopP(assistant, model)).toBeUndefined()
|
||||||
|
})
|
||||||
|
|
||||||
|
it('returns undefined for models without TopP support', () => {
|
||||||
|
const assistant = createAssistant({ enableTopP: true })
|
||||||
|
const model = createModel({ id: 'qwen-mt-small', name: 'Qwen MT', provider: 'qwen', group: 'qwen' })
|
||||||
|
|
||||||
|
expect(getTopP(assistant, model)).toBeUndefined()
|
||||||
|
})
|
||||||
|
|
||||||
|
it('returns undefined for Claude 4.5 reasoning models when temperature is enabled', () => {
|
||||||
|
const assistant = createAssistant({ enableTemperature: true })
|
||||||
|
const model = createModel({
|
||||||
|
id: 'claude-opus-4.5',
|
||||||
|
name: 'Claude Opus 4.5',
|
||||||
|
provider: 'anthropic',
|
||||||
|
group: 'claude'
|
||||||
|
})
|
||||||
|
|
||||||
|
expect(getTopP(assistant, model)).toBeUndefined()
|
||||||
|
})
|
||||||
|
|
||||||
|
it('returns configured TopP when enabled', () => {
|
||||||
|
const assistant = createAssistant({ enableTopP: true, topP: 0.73 })
|
||||||
|
const model = createModel({ id: 'gpt-4o', provider: 'openai', group: 'openai' })
|
||||||
|
|
||||||
|
expect(getTopP(assistant, model)).toBe(0.73)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('returns undefined when TopP is disabled', () => {
|
||||||
|
const assistant = createAssistant({ enableTopP: false, topP: 0.5 })
|
||||||
|
const model = createModel({ id: 'gpt-4o', provider: 'openai', group: 'openai' })
|
||||||
|
|
||||||
|
expect(getTopP(assistant, model)).toBeUndefined()
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('getTimeout', () => {
|
||||||
|
it('uses an extended timeout for flex service tier models', () => {
|
||||||
|
const model = createModel({ id: 'o3-pro', provider: 'openai', group: 'openai' })
|
||||||
|
|
||||||
|
expect(getTimeout(model)).toBe(15 * 1000 * 60)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('falls back to the default timeout otherwise', () => {
|
||||||
|
const model = createModel({ id: 'gpt-4o', provider: 'openai', group: 'openai' })
|
||||||
|
|
||||||
|
expect(getTimeout(model)).toBe(defaultTimeout)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
31
src/renderer/src/aiCore/prepareParams/header.ts
Normal file
31
src/renderer/src/aiCore/prepareParams/header.ts
Normal file
@@ -0,0 +1,31 @@
|
|||||||
|
import { isClaude4SeriesModel, isClaude45ReasoningModel } from '@renderer/config/models'
|
||||||
|
import { getProviderByModel } from '@renderer/services/AssistantService'
|
||||||
|
import type { Assistant, Model } from '@renderer/types'
|
||||||
|
import { isToolUseModeFunction } from '@renderer/utils/assistant'
|
||||||
|
import { isAwsBedrockProvider, isVertexProvider } from '@renderer/utils/provider'
|
||||||
|
|
||||||
|
// https://docs.claude.com/en/docs/build-with-claude/extended-thinking#interleaved-thinking
|
||||||
|
const INTERLEAVED_THINKING_HEADER = 'interleaved-thinking-2025-05-14'
|
||||||
|
// https://docs.claude.com/en/docs/build-with-claude/context-windows#1m-token-context-window
|
||||||
|
const CONTEXT_100M_HEADER = 'context-1m-2025-08-07'
|
||||||
|
// https://docs.cloud.google.com/vertex-ai/generative-ai/docs/partner-models/claude/web-search
|
||||||
|
const WEBSEARCH_HEADER = 'web-search-2025-03-05'
|
||||||
|
|
||||||
|
export function addAnthropicHeaders(assistant: Assistant, model: Model): string[] {
|
||||||
|
const anthropicHeaders: string[] = []
|
||||||
|
const provider = getProviderByModel(model)
|
||||||
|
if (
|
||||||
|
isClaude45ReasoningModel(model) &&
|
||||||
|
isToolUseModeFunction(assistant) &&
|
||||||
|
!(isVertexProvider(provider) && isAwsBedrockProvider(provider))
|
||||||
|
) {
|
||||||
|
anthropicHeaders.push(INTERLEAVED_THINKING_HEADER)
|
||||||
|
}
|
||||||
|
if (isClaude4SeriesModel(model)) {
|
||||||
|
if (isVertexProvider(provider) && assistant.enableWebSearch) {
|
||||||
|
anthropicHeaders.push(WEBSEARCH_HEADER)
|
||||||
|
}
|
||||||
|
anthropicHeaders.push(CONTEXT_100M_HEADER)
|
||||||
|
}
|
||||||
|
return anthropicHeaders
|
||||||
|
}
|
||||||
@@ -85,19 +85,6 @@ export function supportsLargeFileUpload(model: Model): boolean {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* 检查模型是否支持TopP
|
|
||||||
*/
|
|
||||||
export function supportsTopP(model: Model): boolean {
|
|
||||||
const provider = getProviderByModel(model)
|
|
||||||
|
|
||||||
if (provider?.type === 'anthropic' || model?.endpoint_type === 'anthropic') {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 获取提供商特定的文件大小限制
|
* 获取提供商特定的文件大小限制
|
||||||
*/
|
*/
|
||||||
|
|||||||
@@ -3,17 +3,27 @@
|
|||||||
* 处理温度、TopP、超时等基础参数的获取逻辑
|
* 处理温度、TopP、超时等基础参数的获取逻辑
|
||||||
*/
|
*/
|
||||||
|
|
||||||
|
import { DEFAULT_MAX_TOKENS } from '@renderer/config/constant'
|
||||||
import {
|
import {
|
||||||
isClaude45ReasoningModel,
|
isClaude45ReasoningModel,
|
||||||
isClaudeReasoningModel,
|
isClaudeReasoningModel,
|
||||||
|
isMaxTemperatureOneModel,
|
||||||
isNotSupportTemperatureAndTopP,
|
isNotSupportTemperatureAndTopP,
|
||||||
isSupportedFlexServiceTier
|
isSupportedFlexServiceTier,
|
||||||
|
isSupportedThinkingTokenClaudeModel
|
||||||
} from '@renderer/config/models'
|
} from '@renderer/config/models'
|
||||||
import { getAssistantSettings } from '@renderer/services/AssistantService'
|
import { getAssistantSettings, getProviderByModel } from '@renderer/services/AssistantService'
|
||||||
import type { Assistant, Model } from '@renderer/types'
|
import type { Assistant, Model } from '@renderer/types'
|
||||||
import { defaultTimeout } from '@shared/config/constant'
|
import { defaultTimeout } from '@shared/config/constant'
|
||||||
|
|
||||||
|
import { getAnthropicThinkingBudget } from '../utils/reasoning'
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
* Claude 4.5 推理模型:
|
||||||
|
* - 只启用 temperature → 使用 temperature
|
||||||
|
* - 只启用 top_p → 使用 top_p
|
||||||
|
* - 同时启用 → temperature 生效,top_p 被忽略
|
||||||
|
* - 都不启用 → 都不使用
|
||||||
* 获取温度参数
|
* 获取温度参数
|
||||||
*/
|
*/
|
||||||
export function getTemperature(assistant: Assistant, model: Model): number | undefined {
|
export function getTemperature(assistant: Assistant, model: Model): number | undefined {
|
||||||
@@ -27,7 +37,11 @@ export function getTemperature(assistant: Assistant, model: Model): number | und
|
|||||||
return undefined
|
return undefined
|
||||||
}
|
}
|
||||||
const assistantSettings = getAssistantSettings(assistant)
|
const assistantSettings = getAssistantSettings(assistant)
|
||||||
return assistantSettings?.enableTemperature ? assistantSettings?.temperature : undefined
|
let temperature = assistantSettings?.temperature
|
||||||
|
if (temperature && isMaxTemperatureOneModel(model)) {
|
||||||
|
temperature = Math.min(1, temperature)
|
||||||
|
}
|
||||||
|
return assistantSettings?.enableTemperature ? temperature : undefined
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@@ -56,3 +70,18 @@ export function getTimeout(model: Model): number {
|
|||||||
}
|
}
|
||||||
return defaultTimeout
|
return defaultTimeout
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export function getMaxTokens(assistant: Assistant, model: Model): number | undefined {
|
||||||
|
// NOTE: ai-sdk会把maxToken和budgetToken加起来
|
||||||
|
let { maxTokens = DEFAULT_MAX_TOKENS } = getAssistantSettings(assistant)
|
||||||
|
|
||||||
|
const provider = getProviderByModel(model)
|
||||||
|
if (isSupportedThinkingTokenClaudeModel(model) && ['anthropic', 'aws-bedrock'].includes(provider.type)) {
|
||||||
|
const { reasoning_effort: reasoningEffort } = getAssistantSettings(assistant)
|
||||||
|
const budget = getAnthropicThinkingBudget(maxTokens, reasoningEffort, model.id)
|
||||||
|
if (budget) {
|
||||||
|
maxTokens -= budget
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return maxTokens
|
||||||
|
}
|
||||||
|
|||||||
@@ -4,22 +4,24 @@
|
|||||||
*/
|
*/
|
||||||
|
|
||||||
import { anthropic } from '@ai-sdk/anthropic'
|
import { anthropic } from '@ai-sdk/anthropic'
|
||||||
|
import { azure } from '@ai-sdk/azure'
|
||||||
import { google } from '@ai-sdk/google'
|
import { google } from '@ai-sdk/google'
|
||||||
import { vertexAnthropic } from '@ai-sdk/google-vertex/anthropic/edge'
|
import { vertexAnthropic } from '@ai-sdk/google-vertex/anthropic/edge'
|
||||||
import { vertex } from '@ai-sdk/google-vertex/edge'
|
import { vertex } from '@ai-sdk/google-vertex/edge'
|
||||||
import type { WebSearchPluginConfig } from '@cherrystudio/ai-core/built-in/plugins'
|
import { combineHeaders } from '@ai-sdk/provider-utils'
|
||||||
|
import type { AnthropicSearchConfig, WebSearchPluginConfig } from '@cherrystudio/ai-core/built-in/plugins'
|
||||||
import { isBaseProvider } from '@cherrystudio/ai-core/core/providers/schemas'
|
import { isBaseProvider } from '@cherrystudio/ai-core/core/providers/schemas'
|
||||||
import { loggerService } from '@logger'
|
import { loggerService } from '@logger'
|
||||||
import {
|
import {
|
||||||
|
isAnthropicModel,
|
||||||
isGenerateImageModel,
|
isGenerateImageModel,
|
||||||
isOpenRouterBuiltInWebSearchModel,
|
isOpenRouterBuiltInWebSearchModel,
|
||||||
isReasoningModel,
|
isReasoningModel,
|
||||||
isSupportedReasoningEffortModel,
|
isSupportedReasoningEffortModel,
|
||||||
isSupportedThinkingTokenClaudeModel,
|
|
||||||
isSupportedThinkingTokenModel,
|
isSupportedThinkingTokenModel,
|
||||||
isWebSearchModel
|
isWebSearchModel
|
||||||
} from '@renderer/config/models'
|
} from '@renderer/config/models'
|
||||||
import { getAssistantSettings, getDefaultModel } from '@renderer/services/AssistantService'
|
import { getDefaultModel } from '@renderer/services/AssistantService'
|
||||||
import store from '@renderer/store'
|
import store from '@renderer/store'
|
||||||
import type { CherryWebSearchConfig } from '@renderer/store/websearch'
|
import type { CherryWebSearchConfig } from '@renderer/store/websearch'
|
||||||
import { type Assistant, type MCPTool, type Provider } from '@renderer/types'
|
import { type Assistant, type MCPTool, type Provider } from '@renderer/types'
|
||||||
@@ -32,10 +34,9 @@ import { stepCountIs } from 'ai'
|
|||||||
import { getAiSdkProviderId } from '../provider/factory'
|
import { getAiSdkProviderId } from '../provider/factory'
|
||||||
import { setupToolsConfig } from '../utils/mcp'
|
import { setupToolsConfig } from '../utils/mcp'
|
||||||
import { buildProviderOptions } from '../utils/options'
|
import { buildProviderOptions } from '../utils/options'
|
||||||
import { getAnthropicThinkingBudget } from '../utils/reasoning'
|
|
||||||
import { buildProviderBuiltinWebSearchConfig } from '../utils/websearch'
|
import { buildProviderBuiltinWebSearchConfig } from '../utils/websearch'
|
||||||
import { supportsTopP } from './modelCapabilities'
|
import { addAnthropicHeaders } from './header'
|
||||||
import { getTemperature, getTopP } from './modelParameters'
|
import { getMaxTokens, getTemperature, getTopP } from './modelParameters'
|
||||||
|
|
||||||
const logger = loggerService.withContext('parameterBuilder')
|
const logger = loggerService.withContext('parameterBuilder')
|
||||||
|
|
||||||
@@ -58,7 +59,7 @@ export async function buildStreamTextParams(
|
|||||||
timeout?: number
|
timeout?: number
|
||||||
headers?: Record<string, string>
|
headers?: Record<string, string>
|
||||||
}
|
}
|
||||||
} = {}
|
}
|
||||||
): Promise<{
|
): Promise<{
|
||||||
params: StreamTextParams
|
params: StreamTextParams
|
||||||
modelId: string
|
modelId: string
|
||||||
@@ -75,8 +76,6 @@ export async function buildStreamTextParams(
|
|||||||
const model = assistant.model || getDefaultModel()
|
const model = assistant.model || getDefaultModel()
|
||||||
const aiSdkProviderId = getAiSdkProviderId(provider)
|
const aiSdkProviderId = getAiSdkProviderId(provider)
|
||||||
|
|
||||||
let { maxTokens } = getAssistantSettings(assistant)
|
|
||||||
|
|
||||||
// 这三个变量透传出来,交给下面启用插件/中间件
|
// 这三个变量透传出来,交给下面启用插件/中间件
|
||||||
// 也可以在外部构建好再传入buildStreamTextParams
|
// 也可以在外部构建好再传入buildStreamTextParams
|
||||||
// FIXME: qwen3即使关闭思考仍然会导致enableReasoning的结果为true
|
// FIXME: qwen3即使关闭思考仍然会导致enableReasoning的结果为true
|
||||||
@@ -113,16 +112,6 @@ export async function buildStreamTextParams(
|
|||||||
enableGenerateImage
|
enableGenerateImage
|
||||||
})
|
})
|
||||||
|
|
||||||
// NOTE: ai-sdk会把maxToken和budgetToken加起来
|
|
||||||
if (
|
|
||||||
enableReasoning &&
|
|
||||||
maxTokens !== undefined &&
|
|
||||||
isSupportedThinkingTokenClaudeModel(model) &&
|
|
||||||
(provider.type === 'anthropic' || provider.type === 'aws-bedrock')
|
|
||||||
) {
|
|
||||||
maxTokens -= getAnthropicThinkingBudget(assistant, model)
|
|
||||||
}
|
|
||||||
|
|
||||||
let webSearchPluginConfig: WebSearchPluginConfig | undefined = undefined
|
let webSearchPluginConfig: WebSearchPluginConfig | undefined = undefined
|
||||||
if (enableWebSearch) {
|
if (enableWebSearch) {
|
||||||
if (isBaseProvider(aiSdkProviderId)) {
|
if (isBaseProvider(aiSdkProviderId)) {
|
||||||
@@ -139,6 +128,17 @@ export async function buildStreamTextParams(
|
|||||||
maxUses: webSearchConfig.maxResults,
|
maxUses: webSearchConfig.maxResults,
|
||||||
blockedDomains: blockedDomains.length > 0 ? blockedDomains : undefined
|
blockedDomains: blockedDomains.length > 0 ? blockedDomains : undefined
|
||||||
}) as ProviderDefinedTool
|
}) as ProviderDefinedTool
|
||||||
|
} else if (aiSdkProviderId === 'azure-responses') {
|
||||||
|
tools.web_search_preview = azure.tools.webSearchPreview({
|
||||||
|
searchContextSize: webSearchPluginConfig?.openai!.searchContextSize
|
||||||
|
}) as ProviderDefinedTool
|
||||||
|
} else if (aiSdkProviderId === 'azure-anthropic') {
|
||||||
|
const blockedDomains = mapRegexToPatterns(webSearchConfig.excludeDomains)
|
||||||
|
const anthropicSearchOptions: AnthropicSearchConfig = {
|
||||||
|
maxUses: webSearchConfig.maxResults,
|
||||||
|
blockedDomains: blockedDomains.length > 0 ? blockedDomains : undefined
|
||||||
|
}
|
||||||
|
tools.web_search = anthropic.tools.webSearch_20250305(anthropicSearchOptions) as ProviderDefinedTool
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -156,9 +156,10 @@ export async function buildStreamTextParams(
|
|||||||
tools.url_context = google.tools.urlContext({}) as ProviderDefinedTool
|
tools.url_context = google.tools.urlContext({}) as ProviderDefinedTool
|
||||||
break
|
break
|
||||||
case 'anthropic':
|
case 'anthropic':
|
||||||
|
case 'azure-anthropic':
|
||||||
case 'google-vertex-anthropic':
|
case 'google-vertex-anthropic':
|
||||||
tools.web_fetch = (
|
tools.web_fetch = (
|
||||||
aiSdkProviderId === 'anthropic'
|
['anthropic', 'azure-anthropic'].includes(aiSdkProviderId)
|
||||||
? anthropic.tools.webFetch_20250910({
|
? anthropic.tools.webFetch_20250910({
|
||||||
maxUses: webSearchConfig.maxResults,
|
maxUses: webSearchConfig.maxResults,
|
||||||
blockedDomains: blockedDomains.length > 0 ? blockedDomains : undefined
|
blockedDomains: blockedDomains.length > 0 ? blockedDomains : undefined
|
||||||
@@ -172,22 +173,26 @@ export async function buildStreamTextParams(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
let headers: Record<string, string | undefined> = options.requestOptions?.headers ?? {}
|
||||||
|
|
||||||
|
if (isAnthropicModel(model)) {
|
||||||
|
const newBetaHeaders = { 'anthropic-beta': addAnthropicHeaders(assistant, model).join(',') }
|
||||||
|
headers = combineHeaders(headers, newBetaHeaders)
|
||||||
|
}
|
||||||
|
|
||||||
// 构建基础参数
|
// 构建基础参数
|
||||||
const params: StreamTextParams = {
|
const params: StreamTextParams = {
|
||||||
messages: sdkMessages,
|
messages: sdkMessages,
|
||||||
maxOutputTokens: maxTokens,
|
maxOutputTokens: getMaxTokens(assistant, model),
|
||||||
temperature: getTemperature(assistant, model),
|
temperature: getTemperature(assistant, model),
|
||||||
|
topP: getTopP(assistant, model),
|
||||||
abortSignal: options.requestOptions?.signal,
|
abortSignal: options.requestOptions?.signal,
|
||||||
headers: options.requestOptions?.headers,
|
headers,
|
||||||
providerOptions,
|
providerOptions,
|
||||||
stopWhen: stepCountIs(20),
|
stopWhen: stepCountIs(20),
|
||||||
maxRetries: 0
|
maxRetries: 0
|
||||||
}
|
}
|
||||||
|
|
||||||
if (supportsTopP(model)) {
|
|
||||||
params.topP = getTopP(assistant, model)
|
|
||||||
}
|
|
||||||
|
|
||||||
if (tools) {
|
if (tools) {
|
||||||
params.tools = tools
|
params.tools = tools
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -23,6 +23,26 @@ vi.mock('@cherrystudio/ai-core', () => ({
|
|||||||
}
|
}
|
||||||
}))
|
}))
|
||||||
|
|
||||||
|
vi.mock('@renderer/services/AssistantService', () => ({
|
||||||
|
getProviderByModel: vi.fn(),
|
||||||
|
getAssistantSettings: vi.fn(),
|
||||||
|
getDefaultAssistant: vi.fn().mockReturnValue({
|
||||||
|
id: 'default',
|
||||||
|
name: 'Default Assistant',
|
||||||
|
prompt: '',
|
||||||
|
settings: {}
|
||||||
|
})
|
||||||
|
}))
|
||||||
|
|
||||||
|
vi.mock('@renderer/store/settings', () => ({
|
||||||
|
default: {},
|
||||||
|
settingsSlice: {
|
||||||
|
name: 'settings',
|
||||||
|
reducer: vi.fn(),
|
||||||
|
actions: {}
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
|
||||||
// Mock the provider configs
|
// Mock the provider configs
|
||||||
vi.mock('../providerConfigs', () => ({
|
vi.mock('../providerConfigs', () => ({
|
||||||
initializeNewProviders: vi.fn()
|
initializeNewProviders: vi.fn()
|
||||||
|
|||||||
@@ -12,7 +12,14 @@ vi.mock('@renderer/services/LoggerService', () => ({
|
|||||||
}))
|
}))
|
||||||
|
|
||||||
vi.mock('@renderer/services/AssistantService', () => ({
|
vi.mock('@renderer/services/AssistantService', () => ({
|
||||||
getProviderByModel: vi.fn()
|
getProviderByModel: vi.fn(),
|
||||||
|
getAssistantSettings: vi.fn(),
|
||||||
|
getDefaultAssistant: vi.fn().mockReturnValue({
|
||||||
|
id: 'default',
|
||||||
|
name: 'Default Assistant',
|
||||||
|
prompt: '',
|
||||||
|
settings: {}
|
||||||
|
})
|
||||||
}))
|
}))
|
||||||
|
|
||||||
vi.mock('@renderer/store', () => ({
|
vi.mock('@renderer/store', () => ({
|
||||||
@@ -34,7 +41,7 @@ vi.mock('@renderer/utils/api', () => ({
|
|||||||
}))
|
}))
|
||||||
}))
|
}))
|
||||||
|
|
||||||
vi.mock('@renderer/config/providers', async (importOriginal) => {
|
vi.mock('@renderer/utils/provider', async (importOriginal) => {
|
||||||
const actual = (await importOriginal()) as any
|
const actual = (await importOriginal()) as any
|
||||||
return {
|
return {
|
||||||
...actual,
|
...actual,
|
||||||
@@ -53,10 +60,21 @@ vi.mock('@renderer/hooks/useVertexAI', () => ({
|
|||||||
createVertexProvider: vi.fn()
|
createVertexProvider: vi.fn()
|
||||||
}))
|
}))
|
||||||
|
|
||||||
import { isCherryAIProvider, isPerplexityProvider } from '@renderer/config/providers'
|
vi.mock('@renderer/services/AssistantService', () => ({
|
||||||
|
getProviderByModel: vi.fn(),
|
||||||
|
getAssistantSettings: vi.fn(),
|
||||||
|
getDefaultAssistant: vi.fn().mockReturnValue({
|
||||||
|
id: 'default',
|
||||||
|
name: 'Default Assistant',
|
||||||
|
prompt: '',
|
||||||
|
settings: {}
|
||||||
|
})
|
||||||
|
}))
|
||||||
|
|
||||||
import { getProviderByModel } from '@renderer/services/AssistantService'
|
import { getProviderByModel } from '@renderer/services/AssistantService'
|
||||||
import type { Model, Provider } from '@renderer/types'
|
import type { Model, Provider } from '@renderer/types'
|
||||||
import { formatApiHost } from '@renderer/utils/api'
|
import { formatApiHost } from '@renderer/utils/api'
|
||||||
|
import { isCherryAIProvider, isPerplexityProvider } from '@renderer/utils/provider'
|
||||||
|
|
||||||
import { COPILOT_DEFAULT_HEADERS, COPILOT_EDITOR_VERSION, isCopilotResponsesModel } from '../constants'
|
import { COPILOT_DEFAULT_HEADERS, COPILOT_EDITOR_VERSION, isCopilotResponsesModel } from '../constants'
|
||||||
import { getActualProvider, providerToAiSdkConfig } from '../providerConfig'
|
import { getActualProvider, providerToAiSdkConfig } from '../providerConfig'
|
||||||
|
|||||||
22
src/renderer/src/aiCore/provider/config/azure-anthropic.ts
Normal file
22
src/renderer/src/aiCore/provider/config/azure-anthropic.ts
Normal file
@@ -0,0 +1,22 @@
|
|||||||
|
import type { Provider } from '@renderer/types'
|
||||||
|
|
||||||
|
import { provider2Provider, startsWith } from './helper'
|
||||||
|
import type { RuleSet } from './types'
|
||||||
|
|
||||||
|
// https://platform.claude.com/docs/en/build-with-claude/claude-in-microsoft-foundry
|
||||||
|
const AZURE_ANTHROPIC_RULES: RuleSet = {
|
||||||
|
rules: [
|
||||||
|
{
|
||||||
|
match: startsWith('claude'),
|
||||||
|
provider: (provider: Provider) => ({
|
||||||
|
...provider,
|
||||||
|
type: 'anthropic',
|
||||||
|
apiHost: provider.apiHost + 'anthropic/v1',
|
||||||
|
id: 'azure-anthropic'
|
||||||
|
})
|
||||||
|
}
|
||||||
|
],
|
||||||
|
fallbackRule: (provider: Provider) => provider
|
||||||
|
}
|
||||||
|
|
||||||
|
export const azureAnthropicProviderCreator = provider2Provider.bind(null, AZURE_ANTHROPIC_RULES)
|
||||||
@@ -2,8 +2,10 @@ import { hasProviderConfigByAlias, type ProviderId, resolveProviderConfigId } fr
|
|||||||
import { createProvider as createProviderCore } from '@cherrystudio/ai-core/provider'
|
import { createProvider as createProviderCore } from '@cherrystudio/ai-core/provider'
|
||||||
import { loggerService } from '@logger'
|
import { loggerService } from '@logger'
|
||||||
import type { Provider } from '@renderer/types'
|
import type { Provider } from '@renderer/types'
|
||||||
|
import { isAzureOpenAIProvider, isAzureResponsesEndpoint } from '@renderer/utils/provider'
|
||||||
import type { Provider as AiSdkProvider } from 'ai'
|
import type { Provider as AiSdkProvider } from 'ai'
|
||||||
|
|
||||||
|
import type { AiSdkConfig } from '../types'
|
||||||
import { initializeNewProviders } from './providerInitialization'
|
import { initializeNewProviders } from './providerInitialization'
|
||||||
|
|
||||||
const logger = loggerService.withContext('ProviderFactory')
|
const logger = loggerService.withContext('ProviderFactory')
|
||||||
@@ -55,9 +57,12 @@ function tryResolveProviderId(identifier: string): ProviderId | null {
|
|||||||
* 获取AI SDK Provider ID
|
* 获取AI SDK Provider ID
|
||||||
* 简化版:减少重复逻辑,利用通用解析函数
|
* 简化版:减少重复逻辑,利用通用解析函数
|
||||||
*/
|
*/
|
||||||
export function getAiSdkProviderId(provider: Provider): ProviderId | 'openai-compatible' {
|
export function getAiSdkProviderId(provider: Provider): string {
|
||||||
// 1. 尝试解析provider.id
|
// 1. 尝试解析provider.id
|
||||||
const resolvedFromId = tryResolveProviderId(provider.id)
|
const resolvedFromId = tryResolveProviderId(provider.id)
|
||||||
|
if (isAzureOpenAIProvider(provider) && isAzureResponsesEndpoint(provider)) {
|
||||||
|
return 'azure-responses'
|
||||||
|
}
|
||||||
if (resolvedFromId) {
|
if (resolvedFromId) {
|
||||||
return resolvedFromId
|
return resolvedFromId
|
||||||
}
|
}
|
||||||
@@ -73,11 +78,11 @@ export function getAiSdkProviderId(provider: Provider): ProviderId | 'openai-com
|
|||||||
if (provider.apiHost.includes('api.openai.com')) {
|
if (provider.apiHost.includes('api.openai.com')) {
|
||||||
return 'openai-chat'
|
return 'openai-chat'
|
||||||
}
|
}
|
||||||
// 3. 最后的fallback(通常会成为openai-compatible)
|
// 3. 最后的fallback(使用provider本身的id)
|
||||||
return provider.id as ProviderId
|
return provider.id
|
||||||
}
|
}
|
||||||
|
|
||||||
export async function createAiSdkProvider(config) {
|
export async function createAiSdkProvider(config: AiSdkConfig): Promise<AiSdkProvider | null> {
|
||||||
let localProvider: Awaited<AiSdkProvider> | null = null
|
let localProvider: Awaited<AiSdkProvider> | null = null
|
||||||
try {
|
try {
|
||||||
if (config.providerId === 'openai' && config.options?.mode === 'chat') {
|
if (config.providerId === 'openai' && config.options?.mode === 'chat') {
|
||||||
|
|||||||
@@ -1,19 +1,5 @@
|
|||||||
import {
|
import { formatPrivateKey, hasProviderConfig, ProviderConfigFactory } from '@cherrystudio/ai-core/provider'
|
||||||
formatPrivateKey,
|
|
||||||
hasProviderConfig,
|
|
||||||
ProviderConfigFactory,
|
|
||||||
type ProviderId,
|
|
||||||
type ProviderSettingsMap
|
|
||||||
} from '@cherrystudio/ai-core/provider'
|
|
||||||
import { isOpenAIChatCompletionOnlyModel } from '@renderer/config/models'
|
import { isOpenAIChatCompletionOnlyModel } from '@renderer/config/models'
|
||||||
import {
|
|
||||||
isAnthropicProvider,
|
|
||||||
isAzureOpenAIProvider,
|
|
||||||
isCherryAIProvider,
|
|
||||||
isGeminiProvider,
|
|
||||||
isNewApiProvider,
|
|
||||||
isPerplexityProvider
|
|
||||||
} from '@renderer/config/providers'
|
|
||||||
import {
|
import {
|
||||||
getAwsBedrockAccessKeyId,
|
getAwsBedrockAccessKeyId,
|
||||||
getAwsBedrockApiKey,
|
getAwsBedrockApiKey,
|
||||||
@@ -21,14 +7,25 @@ import {
|
|||||||
getAwsBedrockRegion,
|
getAwsBedrockRegion,
|
||||||
getAwsBedrockSecretAccessKey
|
getAwsBedrockSecretAccessKey
|
||||||
} from '@renderer/hooks/useAwsBedrock'
|
} from '@renderer/hooks/useAwsBedrock'
|
||||||
import { createVertexProvider, isVertexAIConfigured, isVertexProvider } from '@renderer/hooks/useVertexAI'
|
import { createVertexProvider, isVertexAIConfigured } from '@renderer/hooks/useVertexAI'
|
||||||
import { getProviderByModel } from '@renderer/services/AssistantService'
|
import { getProviderByModel } from '@renderer/services/AssistantService'
|
||||||
import store from '@renderer/store'
|
import store from '@renderer/store'
|
||||||
import { isSystemProvider, type Model, type Provider, SystemProviderIds } from '@renderer/types'
|
import { isSystemProvider, type Model, type Provider, SystemProviderIds } from '@renderer/types'
|
||||||
import { formatApiHost, formatAzureOpenAIApiHost, formatVertexApiHost, routeToEndpoint } from '@renderer/utils/api'
|
import { formatApiHost, formatAzureOpenAIApiHost, formatVertexApiHost, routeToEndpoint } from '@renderer/utils/api'
|
||||||
|
import {
|
||||||
|
isAnthropicProvider,
|
||||||
|
isAzureOpenAIProvider,
|
||||||
|
isCherryAIProvider,
|
||||||
|
isGeminiProvider,
|
||||||
|
isNewApiProvider,
|
||||||
|
isPerplexityProvider,
|
||||||
|
isVertexProvider
|
||||||
|
} from '@renderer/utils/provider'
|
||||||
import { cloneDeep } from 'lodash'
|
import { cloneDeep } from 'lodash'
|
||||||
|
|
||||||
|
import type { AiSdkConfig } from '../types'
|
||||||
import { aihubmixProviderCreator, newApiResolverCreator, vertexAnthropicProviderCreator } from './config'
|
import { aihubmixProviderCreator, newApiResolverCreator, vertexAnthropicProviderCreator } from './config'
|
||||||
|
import { azureAnthropicProviderCreator } from './config/azure-anthropic'
|
||||||
import { COPILOT_DEFAULT_HEADERS } from './constants'
|
import { COPILOT_DEFAULT_HEADERS } from './constants'
|
||||||
import { getAiSdkProviderId } from './factory'
|
import { getAiSdkProviderId } from './factory'
|
||||||
|
|
||||||
@@ -74,6 +71,9 @@ function handleSpecialProviders(model: Model, provider: Provider): Provider {
|
|||||||
return vertexAnthropicProviderCreator(model, provider)
|
return vertexAnthropicProviderCreator(model, provider)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if (isAzureOpenAIProvider(provider)) {
|
||||||
|
return azureAnthropicProviderCreator(model, provider)
|
||||||
|
}
|
||||||
return provider
|
return provider
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -131,13 +131,7 @@ export function getActualProvider(model: Model): Provider {
|
|||||||
* 将 Provider 配置转换为新 AI SDK 格式
|
* 将 Provider 配置转换为新 AI SDK 格式
|
||||||
* 简化版:利用新的别名映射系统
|
* 简化版:利用新的别名映射系统
|
||||||
*/
|
*/
|
||||||
export function providerToAiSdkConfig(
|
export function providerToAiSdkConfig(actualProvider: Provider, model: Model): AiSdkConfig {
|
||||||
actualProvider: Provider,
|
|
||||||
model: Model
|
|
||||||
): {
|
|
||||||
providerId: ProviderId | 'openai-compatible'
|
|
||||||
options: ProviderSettingsMap[keyof ProviderSettingsMap]
|
|
||||||
} {
|
|
||||||
const aiSdkProviderId = getAiSdkProviderId(actualProvider)
|
const aiSdkProviderId = getAiSdkProviderId(actualProvider)
|
||||||
|
|
||||||
// 构建基础配置
|
// 构建基础配置
|
||||||
@@ -189,13 +183,12 @@ export function providerToAiSdkConfig(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
// azure
|
// azure
|
||||||
if (aiSdkProviderId === 'azure' || actualProvider.type === 'azure-openai') {
|
// https://learn.microsoft.com/en-us/azure/ai-foundry/openai/latest
|
||||||
// extraOptions.apiVersion = actualProvider.apiVersion 默认使用v1,不使用azure endpoint
|
// https://learn.microsoft.com/en-us/azure/ai-foundry/openai/how-to/responses?tabs=python-key#responses-api
|
||||||
if (actualProvider.apiVersion === 'preview') {
|
if (aiSdkProviderId === 'azure-responses') {
|
||||||
extraOptions.mode = 'responses'
|
extraOptions.mode = 'responses'
|
||||||
} else {
|
} else if (aiSdkProviderId === 'azure') {
|
||||||
extraOptions.mode = 'chat'
|
extraOptions.mode = 'chat'
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// bedrock
|
// bedrock
|
||||||
@@ -225,10 +218,17 @@ export function providerToAiSdkConfig(
|
|||||||
baseConfig.baseURL += aiSdkProviderId === 'google-vertex' ? '/publishers/google' : '/publishers/anthropic/models'
|
baseConfig.baseURL += aiSdkProviderId === 'google-vertex' ? '/publishers/google' : '/publishers/anthropic/models'
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// cherryin
|
||||||
|
if (aiSdkProviderId === 'cherryin') {
|
||||||
|
if (model.endpoint_type) {
|
||||||
|
extraOptions.endpointType = model.endpoint_type
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if (hasProviderConfig(aiSdkProviderId) && aiSdkProviderId !== 'openai-compatible') {
|
if (hasProviderConfig(aiSdkProviderId) && aiSdkProviderId !== 'openai-compatible') {
|
||||||
const options = ProviderConfigFactory.fromProvider(aiSdkProviderId, baseConfig, extraOptions)
|
const options = ProviderConfigFactory.fromProvider(aiSdkProviderId, baseConfig, extraOptions)
|
||||||
return {
|
return {
|
||||||
providerId: aiSdkProviderId as ProviderId,
|
providerId: aiSdkProviderId,
|
||||||
options
|
options
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -32,6 +32,14 @@ export const NEW_PROVIDER_CONFIGS: ProviderConfig[] = [
|
|||||||
supportsImageGeneration: true,
|
supportsImageGeneration: true,
|
||||||
aliases: ['vertexai-anthropic']
|
aliases: ['vertexai-anthropic']
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
id: 'azure-anthropic',
|
||||||
|
name: 'Azure AI Anthropic',
|
||||||
|
import: () => import('@ai-sdk/anthropic'),
|
||||||
|
creatorFunctionName: 'createAnthropic',
|
||||||
|
supportsImageGeneration: false,
|
||||||
|
aliases: ['azure-anthropic']
|
||||||
|
},
|
||||||
{
|
{
|
||||||
id: 'github-copilot-openai-compatible',
|
id: 'github-copilot-openai-compatible',
|
||||||
name: 'GitHub Copilot OpenAI Compatible',
|
name: 'GitHub Copilot OpenAI Compatible',
|
||||||
|
|||||||
@@ -133,7 +133,7 @@ export class AiSdkSpanAdapter {
|
|||||||
|
|
||||||
// 详细记录转换过程
|
// 详细记录转换过程
|
||||||
const operationId = attributes['ai.operationId']
|
const operationId = attributes['ai.operationId']
|
||||||
logger.info('Converting AI SDK span to SpanEntity', {
|
logger.debug('Converting AI SDK span to SpanEntity', {
|
||||||
spanName: spanName,
|
spanName: spanName,
|
||||||
operationId,
|
operationId,
|
||||||
spanTag,
|
spanTag,
|
||||||
@@ -149,7 +149,7 @@ export class AiSdkSpanAdapter {
|
|||||||
})
|
})
|
||||||
|
|
||||||
if (tokenUsage) {
|
if (tokenUsage) {
|
||||||
logger.info('Token usage data found', {
|
logger.debug('Token usage data found', {
|
||||||
spanName: spanName,
|
spanName: spanName,
|
||||||
operationId,
|
operationId,
|
||||||
usage: tokenUsage,
|
usage: tokenUsage,
|
||||||
@@ -158,7 +158,7 @@ export class AiSdkSpanAdapter {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (inputs || outputs) {
|
if (inputs || outputs) {
|
||||||
logger.info('Input/Output data extracted', {
|
logger.debug('Input/Output data extracted', {
|
||||||
spanName: spanName,
|
spanName: spanName,
|
||||||
operationId,
|
operationId,
|
||||||
hasInputs: !!inputs,
|
hasInputs: !!inputs,
|
||||||
@@ -170,7 +170,7 @@ export class AiSdkSpanAdapter {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if (Object.keys(typeSpecificData).length > 0) {
|
if (Object.keys(typeSpecificData).length > 0) {
|
||||||
logger.info('Type-specific data extracted', {
|
logger.debug('Type-specific data extracted', {
|
||||||
spanName: spanName,
|
spanName: spanName,
|
||||||
operationId,
|
operationId,
|
||||||
typeSpecificKeys: Object.keys(typeSpecificData),
|
typeSpecificKeys: Object.keys(typeSpecificData),
|
||||||
@@ -204,7 +204,7 @@ export class AiSdkSpanAdapter {
|
|||||||
modelName: modelName || this.extractModelFromAttributes(attributes)
|
modelName: modelName || this.extractModelFromAttributes(attributes)
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.info('AI SDK span successfully converted to SpanEntity', {
|
logger.debug('AI SDK span successfully converted to SpanEntity', {
|
||||||
spanName: spanName,
|
spanName: spanName,
|
||||||
operationId,
|
operationId,
|
||||||
spanId: spanContext.spanId,
|
spanId: spanContext.spanId,
|
||||||
|
|||||||
15
src/renderer/src/aiCore/types/index.ts
Normal file
15
src/renderer/src/aiCore/types/index.ts
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
/**
|
||||||
|
* This type definition file is only for renderer.
|
||||||
|
* It cannot be migrated to @renderer/types since files within it are actually being used by both main and renderer.
|
||||||
|
* If we do that, main would throw an error because it cannot import a module which imports a type from a browser-enviroment-only package.
|
||||||
|
* (ai-core package is set as browser-enviroment-only)
|
||||||
|
*
|
||||||
|
* TODO: We should separate them clearly. Keep renderer only types in renderer, and main only types in main, and shared types in shared.
|
||||||
|
*/
|
||||||
|
|
||||||
|
import type { ProviderSettingsMap } from '@cherrystudio/ai-core/provider'
|
||||||
|
|
||||||
|
export type AiSdkConfig = {
|
||||||
|
providerId: string
|
||||||
|
options: ProviderSettingsMap[keyof ProviderSettingsMap]
|
||||||
|
}
|
||||||
121
src/renderer/src/aiCore/utils/__tests__/image.test.ts
Normal file
121
src/renderer/src/aiCore/utils/__tests__/image.test.ts
Normal file
@@ -0,0 +1,121 @@
|
|||||||
|
/**
|
||||||
|
* image.ts Unit Tests
|
||||||
|
* Tests for Gemini image generation utilities
|
||||||
|
*/
|
||||||
|
|
||||||
|
import type { Model, Provider } from '@renderer/types'
|
||||||
|
import { SystemProviderIds } from '@renderer/types'
|
||||||
|
import { describe, expect, it } from 'vitest'
|
||||||
|
|
||||||
|
import { buildGeminiGenerateImageParams, isOpenRouterGeminiGenerateImageModel } from '../image'
|
||||||
|
|
||||||
|
describe('image utils', () => {
|
||||||
|
describe('buildGeminiGenerateImageParams', () => {
|
||||||
|
it('should return correct response modalities', () => {
|
||||||
|
const result = buildGeminiGenerateImageParams()
|
||||||
|
|
||||||
|
expect(result).toEqual({
|
||||||
|
responseModalities: ['TEXT', 'IMAGE']
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should return an object with responseModalities property', () => {
|
||||||
|
const result = buildGeminiGenerateImageParams()
|
||||||
|
|
||||||
|
expect(result).toHaveProperty('responseModalities')
|
||||||
|
expect(Array.isArray(result.responseModalities)).toBe(true)
|
||||||
|
expect(result.responseModalities).toHaveLength(2)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('isOpenRouterGeminiGenerateImageModel', () => {
|
||||||
|
const mockOpenRouterProvider: Provider = {
|
||||||
|
id: SystemProviderIds.openrouter,
|
||||||
|
name: 'OpenRouter',
|
||||||
|
apiKey: 'test-key',
|
||||||
|
apiHost: 'https://openrouter.ai/api/v1',
|
||||||
|
isSystem: true
|
||||||
|
} as Provider
|
||||||
|
|
||||||
|
const mockOtherProvider: Provider = {
|
||||||
|
id: SystemProviderIds.openai,
|
||||||
|
name: 'OpenAI',
|
||||||
|
apiKey: 'test-key',
|
||||||
|
apiHost: 'https://api.openai.com/v1',
|
||||||
|
isSystem: true
|
||||||
|
} as Provider
|
||||||
|
|
||||||
|
it('should return true for OpenRouter Gemini 2.5 Flash Image model', () => {
|
||||||
|
const model: Model = {
|
||||||
|
id: 'google/gemini-2.5-flash-image-preview',
|
||||||
|
name: 'Gemini 2.5 Flash Image',
|
||||||
|
provider: SystemProviderIds.openrouter
|
||||||
|
} as Model
|
||||||
|
|
||||||
|
const result = isOpenRouterGeminiGenerateImageModel(model, mockOpenRouterProvider)
|
||||||
|
expect(result).toBe(true)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should return false for non-Gemini model on OpenRouter', () => {
|
||||||
|
const model: Model = {
|
||||||
|
id: 'openai/gpt-4',
|
||||||
|
name: 'GPT-4',
|
||||||
|
provider: SystemProviderIds.openrouter
|
||||||
|
} as Model
|
||||||
|
|
||||||
|
const result = isOpenRouterGeminiGenerateImageModel(model, mockOpenRouterProvider)
|
||||||
|
expect(result).toBe(false)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should return false for Gemini model on non-OpenRouter provider', () => {
|
||||||
|
const model: Model = {
|
||||||
|
id: 'gemini-2.5-flash-image-preview',
|
||||||
|
name: 'Gemini 2.5 Flash Image',
|
||||||
|
provider: SystemProviderIds.gemini
|
||||||
|
} as Model
|
||||||
|
|
||||||
|
const result = isOpenRouterGeminiGenerateImageModel(model, mockOtherProvider)
|
||||||
|
expect(result).toBe(false)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should return false for Gemini model without image suffix', () => {
|
||||||
|
const model: Model = {
|
||||||
|
id: 'google/gemini-2.5-flash',
|
||||||
|
name: 'Gemini 2.5 Flash',
|
||||||
|
provider: SystemProviderIds.openrouter
|
||||||
|
} as Model
|
||||||
|
|
||||||
|
const result = isOpenRouterGeminiGenerateImageModel(model, mockOpenRouterProvider)
|
||||||
|
expect(result).toBe(false)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should handle model ID with partial match', () => {
|
||||||
|
const model: Model = {
|
||||||
|
id: 'google/gemini-2.5-flash-image-generation',
|
||||||
|
name: 'Gemini Image Gen',
|
||||||
|
provider: SystemProviderIds.openrouter
|
||||||
|
} as Model
|
||||||
|
|
||||||
|
const result = isOpenRouterGeminiGenerateImageModel(model, mockOpenRouterProvider)
|
||||||
|
expect(result).toBe(true)
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should return false for custom provider', () => {
|
||||||
|
const customProvider: Provider = {
|
||||||
|
id: 'custom-provider-123',
|
||||||
|
name: 'Custom Provider',
|
||||||
|
apiKey: 'test-key',
|
||||||
|
apiHost: 'https://custom.com'
|
||||||
|
} as Provider
|
||||||
|
|
||||||
|
const model: Model = {
|
||||||
|
id: 'gemini-2.5-flash-image-preview',
|
||||||
|
name: 'Gemini 2.5 Flash Image',
|
||||||
|
provider: 'custom-provider-123'
|
||||||
|
} as Model
|
||||||
|
|
||||||
|
const result = isOpenRouterGeminiGenerateImageModel(model, customProvider)
|
||||||
|
expect(result).toBe(false)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
435
src/renderer/src/aiCore/utils/__tests__/mcp.test.ts
Normal file
435
src/renderer/src/aiCore/utils/__tests__/mcp.test.ts
Normal file
@@ -0,0 +1,435 @@
|
|||||||
|
/**
|
||||||
|
* mcp.ts Unit Tests
|
||||||
|
* Tests for MCP tools configuration and conversion utilities
|
||||||
|
*/
|
||||||
|
|
||||||
|
import type { MCPTool } from '@renderer/types'
|
||||||
|
import type { Tool } from 'ai'
|
||||||
|
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||||
|
|
||||||
|
import { convertMcpToolsToAiSdkTools, setupToolsConfig } from '../mcp'
|
||||||
|
|
||||||
|
// Mock dependencies
|
||||||
|
vi.mock('@logger', () => ({
|
||||||
|
loggerService: {
|
||||||
|
withContext: () => ({
|
||||||
|
debug: vi.fn(),
|
||||||
|
error: vi.fn(),
|
||||||
|
warn: vi.fn(),
|
||||||
|
info: vi.fn()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
|
||||||
|
vi.mock('@renderer/utils/mcp-tools', () => ({
|
||||||
|
getMcpServerByTool: vi.fn(() => ({ id: 'test-server', autoApprove: false })),
|
||||||
|
isToolAutoApproved: vi.fn(() => false),
|
||||||
|
callMCPTool: vi.fn(async () => ({
|
||||||
|
content: [{ type: 'text', text: 'Tool executed successfully' }],
|
||||||
|
isError: false
|
||||||
|
}))
|
||||||
|
}))
|
||||||
|
|
||||||
|
vi.mock('@renderer/utils/userConfirmation', () => ({
|
||||||
|
requestToolConfirmation: vi.fn(async () => true)
|
||||||
|
}))
|
||||||
|
|
||||||
|
describe('mcp utils', () => {
|
||||||
|
beforeEach(() => {
|
||||||
|
vi.clearAllMocks()
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('setupToolsConfig', () => {
|
||||||
|
it('should return undefined when no MCP tools provided', () => {
|
||||||
|
const result = setupToolsConfig()
|
||||||
|
expect(result).toBeUndefined()
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should return undefined when empty MCP tools array provided', () => {
|
||||||
|
const result = setupToolsConfig([])
|
||||||
|
expect(result).toBeUndefined()
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should convert MCP tools to AI SDK tools format', () => {
|
||||||
|
const mcpTools: MCPTool[] = [
|
||||||
|
{
|
||||||
|
id: 'test-tool-1',
|
||||||
|
serverId: 'test-server',
|
||||||
|
serverName: 'test-server',
|
||||||
|
name: 'test-tool',
|
||||||
|
description: 'A test tool',
|
||||||
|
type: 'mcp',
|
||||||
|
inputSchema: {
|
||||||
|
type: 'object',
|
||||||
|
properties: {
|
||||||
|
query: { type: 'string' }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
const result = setupToolsConfig(mcpTools)
|
||||||
|
|
||||||
|
expect(result).not.toBeUndefined()
|
||||||
|
expect(Object.keys(result!)).toEqual(['test-tool'])
|
||||||
|
expect(result!['test-tool']).toHaveProperty('description')
|
||||||
|
expect(result!['test-tool']).toHaveProperty('inputSchema')
|
||||||
|
expect(result!['test-tool']).toHaveProperty('execute')
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should handle multiple MCP tools', () => {
|
||||||
|
const mcpTools: MCPTool[] = [
|
||||||
|
{
|
||||||
|
id: 'tool1-id',
|
||||||
|
serverId: 'server1',
|
||||||
|
serverName: 'server1',
|
||||||
|
name: 'tool1',
|
||||||
|
description: 'First tool',
|
||||||
|
type: 'mcp',
|
||||||
|
inputSchema: {
|
||||||
|
type: 'object',
|
||||||
|
properties: {}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
id: 'tool2-id',
|
||||||
|
serverId: 'server2',
|
||||||
|
serverName: 'server2',
|
||||||
|
name: 'tool2',
|
||||||
|
description: 'Second tool',
|
||||||
|
type: 'mcp',
|
||||||
|
inputSchema: {
|
||||||
|
type: 'object',
|
||||||
|
properties: {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
const result = setupToolsConfig(mcpTools)
|
||||||
|
|
||||||
|
expect(result).not.toBeUndefined()
|
||||||
|
expect(Object.keys(result!)).toHaveLength(2)
|
||||||
|
expect(Object.keys(result!)).toEqual(['tool1', 'tool2'])
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('convertMcpToolsToAiSdkTools', () => {
|
||||||
|
it('should convert single MCP tool to AI SDK tool', () => {
|
||||||
|
const mcpTools: MCPTool[] = [
|
||||||
|
{
|
||||||
|
id: 'get-weather-id',
|
||||||
|
serverId: 'weather-server',
|
||||||
|
serverName: 'weather-server',
|
||||||
|
name: 'get-weather',
|
||||||
|
description: 'Get weather information',
|
||||||
|
type: 'mcp',
|
||||||
|
inputSchema: {
|
||||||
|
type: 'object',
|
||||||
|
properties: {
|
||||||
|
location: { type: 'string' }
|
||||||
|
},
|
||||||
|
required: ['location']
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
const result = convertMcpToolsToAiSdkTools(mcpTools)
|
||||||
|
|
||||||
|
expect(Object.keys(result)).toEqual(['get-weather'])
|
||||||
|
|
||||||
|
const tool = result['get-weather'] as Tool
|
||||||
|
expect(tool.description).toBe('Get weather information')
|
||||||
|
expect(tool.inputSchema).toBeDefined()
|
||||||
|
expect(typeof tool.execute).toBe('function')
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should handle tool without description', () => {
|
||||||
|
const mcpTools: MCPTool[] = [
|
||||||
|
{
|
||||||
|
id: 'no-desc-tool-id',
|
||||||
|
serverId: 'test-server',
|
||||||
|
serverName: 'test-server',
|
||||||
|
name: 'no-desc-tool',
|
||||||
|
type: 'mcp',
|
||||||
|
inputSchema: {
|
||||||
|
type: 'object',
|
||||||
|
properties: {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
const result = convertMcpToolsToAiSdkTools(mcpTools)
|
||||||
|
|
||||||
|
expect(Object.keys(result)).toEqual(['no-desc-tool'])
|
||||||
|
const tool = result['no-desc-tool'] as Tool
|
||||||
|
expect(tool.description).toBe('Tool from test-server')
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should convert empty tools array', () => {
|
||||||
|
const result = convertMcpToolsToAiSdkTools([])
|
||||||
|
expect(result).toEqual({})
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should handle complex input schemas', () => {
|
||||||
|
const mcpTools: MCPTool[] = [
|
||||||
|
{
|
||||||
|
id: 'complex-tool-id',
|
||||||
|
serverId: 'server',
|
||||||
|
serverName: 'server',
|
||||||
|
name: 'complex-tool',
|
||||||
|
description: 'Tool with complex schema',
|
||||||
|
type: 'mcp',
|
||||||
|
inputSchema: {
|
||||||
|
type: 'object',
|
||||||
|
properties: {
|
||||||
|
name: { type: 'string' },
|
||||||
|
age: { type: 'number' },
|
||||||
|
tags: {
|
||||||
|
type: 'array',
|
||||||
|
items: { type: 'string' }
|
||||||
|
},
|
||||||
|
metadata: {
|
||||||
|
type: 'object',
|
||||||
|
properties: {
|
||||||
|
key: { type: 'string' }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
required: ['name']
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
const result = convertMcpToolsToAiSdkTools(mcpTools)
|
||||||
|
|
||||||
|
expect(Object.keys(result)).toEqual(['complex-tool'])
|
||||||
|
const tool = result['complex-tool'] as Tool
|
||||||
|
expect(tool.inputSchema).toBeDefined()
|
||||||
|
expect(typeof tool.execute).toBe('function')
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should preserve tool names with special characters', () => {
|
||||||
|
const mcpTools: MCPTool[] = [
|
||||||
|
{
|
||||||
|
id: 'special-tool-id',
|
||||||
|
serverId: 'server',
|
||||||
|
serverName: 'server',
|
||||||
|
name: 'tool_with-special.chars',
|
||||||
|
description: 'Special chars tool',
|
||||||
|
type: 'mcp',
|
||||||
|
inputSchema: {
|
||||||
|
type: 'object',
|
||||||
|
properties: {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
const result = convertMcpToolsToAiSdkTools(mcpTools)
|
||||||
|
expect(Object.keys(result)).toEqual(['tool_with-special.chars'])
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should handle multiple tools with different schemas', () => {
|
||||||
|
const mcpTools: MCPTool[] = [
|
||||||
|
{
|
||||||
|
id: 'string-tool-id',
|
||||||
|
serverId: 'server1',
|
||||||
|
serverName: 'server1',
|
||||||
|
name: 'string-tool',
|
||||||
|
description: 'String tool',
|
||||||
|
type: 'mcp',
|
||||||
|
inputSchema: {
|
||||||
|
type: 'object',
|
||||||
|
properties: {
|
||||||
|
input: { type: 'string' }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
id: 'number-tool-id',
|
||||||
|
serverId: 'server2',
|
||||||
|
serverName: 'server2',
|
||||||
|
name: 'number-tool',
|
||||||
|
description: 'Number tool',
|
||||||
|
type: 'mcp',
|
||||||
|
inputSchema: {
|
||||||
|
type: 'object',
|
||||||
|
properties: {
|
||||||
|
count: { type: 'number' }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
id: 'boolean-tool-id',
|
||||||
|
serverId: 'server3',
|
||||||
|
serverName: 'server3',
|
||||||
|
name: 'boolean-tool',
|
||||||
|
description: 'Boolean tool',
|
||||||
|
type: 'mcp',
|
||||||
|
inputSchema: {
|
||||||
|
type: 'object',
|
||||||
|
properties: {
|
||||||
|
enabled: { type: 'boolean' }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
const result = convertMcpToolsToAiSdkTools(mcpTools)
|
||||||
|
|
||||||
|
expect(Object.keys(result).sort()).toEqual(['boolean-tool', 'number-tool', 'string-tool'])
|
||||||
|
expect(result['string-tool']).toBeDefined()
|
||||||
|
expect(result['number-tool']).toBeDefined()
|
||||||
|
expect(result['boolean-tool']).toBeDefined()
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('tool execution', () => {
|
||||||
|
it('should execute tool with user confirmation', async () => {
|
||||||
|
const { callMCPTool } = await import('@renderer/utils/mcp-tools')
|
||||||
|
const { requestToolConfirmation } = await import('@renderer/utils/userConfirmation')
|
||||||
|
|
||||||
|
vi.mocked(requestToolConfirmation).mockResolvedValue(true)
|
||||||
|
vi.mocked(callMCPTool).mockResolvedValue({
|
||||||
|
content: [{ type: 'text', text: 'Success' }],
|
||||||
|
isError: false
|
||||||
|
})
|
||||||
|
|
||||||
|
const mcpTools: MCPTool[] = [
|
||||||
|
{
|
||||||
|
id: 'test-exec-tool-id',
|
||||||
|
serverId: 'test-server',
|
||||||
|
serverName: 'test-server',
|
||||||
|
name: 'test-exec-tool',
|
||||||
|
description: 'Test execution tool',
|
||||||
|
type: 'mcp',
|
||||||
|
inputSchema: {
|
||||||
|
type: 'object',
|
||||||
|
properties: {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
const tools = convertMcpToolsToAiSdkTools(mcpTools)
|
||||||
|
const tool = tools['test-exec-tool'] as Tool
|
||||||
|
const result = await tool.execute!({}, { messages: [], abortSignal: undefined, toolCallId: 'test-call-123' })
|
||||||
|
|
||||||
|
expect(requestToolConfirmation).toHaveBeenCalled()
|
||||||
|
expect(callMCPTool).toHaveBeenCalled()
|
||||||
|
expect(result).toEqual({
|
||||||
|
content: [{ type: 'text', text: 'Success' }],
|
||||||
|
isError: false
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should handle user cancellation', async () => {
|
||||||
|
const { requestToolConfirmation } = await import('@renderer/utils/userConfirmation')
|
||||||
|
const { callMCPTool } = await import('@renderer/utils/mcp-tools')
|
||||||
|
|
||||||
|
vi.mocked(requestToolConfirmation).mockResolvedValue(false)
|
||||||
|
|
||||||
|
const mcpTools: MCPTool[] = [
|
||||||
|
{
|
||||||
|
id: 'cancelled-tool-id',
|
||||||
|
serverId: 'test-server',
|
||||||
|
serverName: 'test-server',
|
||||||
|
name: 'cancelled-tool',
|
||||||
|
description: 'Tool to cancel',
|
||||||
|
type: 'mcp',
|
||||||
|
inputSchema: {
|
||||||
|
type: 'object',
|
||||||
|
properties: {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
const tools = convertMcpToolsToAiSdkTools(mcpTools)
|
||||||
|
const tool = tools['cancelled-tool'] as Tool
|
||||||
|
const result = await tool.execute!({}, { messages: [], abortSignal: undefined, toolCallId: 'cancel-call-123' })
|
||||||
|
|
||||||
|
expect(requestToolConfirmation).toHaveBeenCalled()
|
||||||
|
expect(callMCPTool).not.toHaveBeenCalled()
|
||||||
|
expect(result).toEqual({
|
||||||
|
content: [
|
||||||
|
{
|
||||||
|
type: 'text',
|
||||||
|
text: 'User declined to execute tool "cancelled-tool".'
|
||||||
|
}
|
||||||
|
],
|
||||||
|
isError: false
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should handle tool execution error', async () => {
|
||||||
|
const { callMCPTool } = await import('@renderer/utils/mcp-tools')
|
||||||
|
const { requestToolConfirmation } = await import('@renderer/utils/userConfirmation')
|
||||||
|
|
||||||
|
vi.mocked(requestToolConfirmation).mockResolvedValue(true)
|
||||||
|
vi.mocked(callMCPTool).mockResolvedValue({
|
||||||
|
content: [{ type: 'text', text: 'Error occurred' }],
|
||||||
|
isError: true
|
||||||
|
})
|
||||||
|
|
||||||
|
const mcpTools: MCPTool[] = [
|
||||||
|
{
|
||||||
|
id: 'error-tool-id',
|
||||||
|
serverId: 'test-server',
|
||||||
|
serverName: 'test-server',
|
||||||
|
name: 'error-tool',
|
||||||
|
description: 'Tool that errors',
|
||||||
|
type: 'mcp',
|
||||||
|
inputSchema: {
|
||||||
|
type: 'object',
|
||||||
|
properties: {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
const tools = convertMcpToolsToAiSdkTools(mcpTools)
|
||||||
|
const tool = tools['error-tool'] as Tool
|
||||||
|
|
||||||
|
await expect(
|
||||||
|
tool.execute!({}, { messages: [], abortSignal: undefined, toolCallId: 'error-call-123' })
|
||||||
|
).rejects.toEqual({
|
||||||
|
content: [{ type: 'text', text: 'Error occurred' }],
|
||||||
|
isError: true
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should auto-approve when enabled', async () => {
|
||||||
|
const { callMCPTool, isToolAutoApproved } = await import('@renderer/utils/mcp-tools')
|
||||||
|
const { requestToolConfirmation } = await import('@renderer/utils/userConfirmation')
|
||||||
|
|
||||||
|
vi.mocked(isToolAutoApproved).mockReturnValue(true)
|
||||||
|
vi.mocked(callMCPTool).mockResolvedValue({
|
||||||
|
content: [{ type: 'text', text: 'Auto-approved success' }],
|
||||||
|
isError: false
|
||||||
|
})
|
||||||
|
|
||||||
|
const mcpTools: MCPTool[] = [
|
||||||
|
{
|
||||||
|
id: 'auto-approve-tool-id',
|
||||||
|
serverId: 'test-server',
|
||||||
|
serverName: 'test-server',
|
||||||
|
name: 'auto-approve-tool',
|
||||||
|
description: 'Auto-approved tool',
|
||||||
|
type: 'mcp',
|
||||||
|
inputSchema: {
|
||||||
|
type: 'object',
|
||||||
|
properties: {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
const tools = convertMcpToolsToAiSdkTools(mcpTools)
|
||||||
|
const tool = tools['auto-approve-tool'] as Tool
|
||||||
|
const result = await tool.execute!({}, { messages: [], abortSignal: undefined, toolCallId: 'auto-call-123' })
|
||||||
|
|
||||||
|
expect(requestToolConfirmation).not.toHaveBeenCalled()
|
||||||
|
expect(callMCPTool).toHaveBeenCalled()
|
||||||
|
expect(result).toEqual({
|
||||||
|
content: [{ type: 'text', text: 'Auto-approved success' }],
|
||||||
|
isError: false
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
545
src/renderer/src/aiCore/utils/__tests__/options.test.ts
Normal file
545
src/renderer/src/aiCore/utils/__tests__/options.test.ts
Normal file
@@ -0,0 +1,545 @@
|
|||||||
|
/**
|
||||||
|
* options.ts Unit Tests
|
||||||
|
* Tests for building provider-specific options
|
||||||
|
*/
|
||||||
|
|
||||||
|
import type { Assistant, Model, Provider } from '@renderer/types'
|
||||||
|
import { OpenAIServiceTiers, SystemProviderIds } from '@renderer/types'
|
||||||
|
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||||
|
|
||||||
|
import { buildProviderOptions } from '../options'
|
||||||
|
|
||||||
|
// Mock dependencies
|
||||||
|
vi.mock('@cherrystudio/ai-core/provider', async (importOriginal) => {
|
||||||
|
const actual = (await importOriginal()) as object
|
||||||
|
return {
|
||||||
|
...actual,
|
||||||
|
baseProviderIdSchema: {
|
||||||
|
safeParse: vi.fn((id) => {
|
||||||
|
const baseProviders = [
|
||||||
|
'openai',
|
||||||
|
'openai-chat',
|
||||||
|
'azure',
|
||||||
|
'azure-responses',
|
||||||
|
'huggingface',
|
||||||
|
'anthropic',
|
||||||
|
'google',
|
||||||
|
'xai',
|
||||||
|
'deepseek',
|
||||||
|
'openrouter',
|
||||||
|
'openai-compatible'
|
||||||
|
]
|
||||||
|
if (baseProviders.includes(id)) {
|
||||||
|
return { success: true, data: id }
|
||||||
|
}
|
||||||
|
return { success: false }
|
||||||
|
})
|
||||||
|
},
|
||||||
|
customProviderIdSchema: {
|
||||||
|
safeParse: vi.fn((id) => {
|
||||||
|
const customProviders = ['google-vertex', 'google-vertex-anthropic', 'bedrock']
|
||||||
|
if (customProviders.includes(id)) {
|
||||||
|
return { success: true, data: id }
|
||||||
|
}
|
||||||
|
return { success: false, error: new Error('Invalid provider') }
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
vi.mock('../provider/factory', () => ({
|
||||||
|
getAiSdkProviderId: vi.fn((provider) => {
|
||||||
|
// Simulate the provider ID mapping
|
||||||
|
const mapping: Record<string, string> = {
|
||||||
|
[SystemProviderIds.gemini]: 'google',
|
||||||
|
[SystemProviderIds.openai]: 'openai',
|
||||||
|
[SystemProviderIds.anthropic]: 'anthropic',
|
||||||
|
[SystemProviderIds.grok]: 'xai',
|
||||||
|
[SystemProviderIds.deepseek]: 'deepseek',
|
||||||
|
[SystemProviderIds.openrouter]: 'openrouter'
|
||||||
|
}
|
||||||
|
return mapping[provider.id] || provider.id
|
||||||
|
})
|
||||||
|
}))
|
||||||
|
|
||||||
|
vi.mock('@renderer/config/models', async (importOriginal) => ({
|
||||||
|
...(await importOriginal()),
|
||||||
|
isOpenAIModel: vi.fn((model) => model.id.includes('gpt') || model.id.includes('o1')),
|
||||||
|
isQwenMTModel: vi.fn(() => false),
|
||||||
|
isSupportFlexServiceTierModel: vi.fn(() => true),
|
||||||
|
isOpenAILLMModel: vi.fn(() => true),
|
||||||
|
SYSTEM_MODELS: {
|
||||||
|
defaultModel: [
|
||||||
|
{ id: 'default-1', name: 'Default 1' },
|
||||||
|
{ id: 'default-2', name: 'Default 2' },
|
||||||
|
{ id: 'default-3', name: 'Default 3' }
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
|
||||||
|
vi.mock(import('@renderer/utils/provider'), async (importOriginal) => {
|
||||||
|
return {
|
||||||
|
...(await importOriginal()),
|
||||||
|
isSupportServiceTierProvider: vi.fn((provider) => {
|
||||||
|
return [SystemProviderIds.openai, SystemProviderIds.groq].includes(provider.id)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
vi.mock('@renderer/store/settings', () => ({
|
||||||
|
default: (state = { settings: {} }) => state
|
||||||
|
}))
|
||||||
|
|
||||||
|
vi.mock('@renderer/hooks/useSettings', () => ({
|
||||||
|
getStoreSetting: vi.fn((key) => {
|
||||||
|
if (key === 'openAI') {
|
||||||
|
return { summaryText: 'off', verbosity: 'medium' } as any
|
||||||
|
}
|
||||||
|
return {}
|
||||||
|
})
|
||||||
|
}))
|
||||||
|
|
||||||
|
vi.mock('@renderer/services/AssistantService', () => ({
|
||||||
|
getDefaultAssistant: vi.fn(() => ({
|
||||||
|
id: 'default',
|
||||||
|
name: 'Default Assistant',
|
||||||
|
settings: {}
|
||||||
|
})),
|
||||||
|
getAssistantSettings: vi.fn(() => ({
|
||||||
|
reasoning_effort: 'medium',
|
||||||
|
maxTokens: 4096
|
||||||
|
})),
|
||||||
|
getProviderByModel: vi.fn((model: Model) => ({
|
||||||
|
id: model.provider,
|
||||||
|
name: 'Mock Provider'
|
||||||
|
}))
|
||||||
|
}))
|
||||||
|
|
||||||
|
vi.mock('../reasoning', () => ({
|
||||||
|
getOpenAIReasoningParams: vi.fn(() => ({ reasoningEffort: 'medium' })),
|
||||||
|
getAnthropicReasoningParams: vi.fn(() => ({
|
||||||
|
thinking: { type: 'enabled', budgetTokens: 5000 }
|
||||||
|
})),
|
||||||
|
getGeminiReasoningParams: vi.fn(() => ({
|
||||||
|
thinkingConfig: { include_thoughts: true }
|
||||||
|
})),
|
||||||
|
getXAIReasoningParams: vi.fn(() => ({ reasoningEffort: 'high' })),
|
||||||
|
getBedrockReasoningParams: vi.fn(() => ({
|
||||||
|
reasoningConfig: { type: 'enabled', budgetTokens: 5000 }
|
||||||
|
})),
|
||||||
|
getReasoningEffort: vi.fn(() => ({ reasoningEffort: 'medium' })),
|
||||||
|
getCustomParameters: vi.fn(() => ({}))
|
||||||
|
}))
|
||||||
|
|
||||||
|
vi.mock('../image', () => ({
|
||||||
|
buildGeminiGenerateImageParams: vi.fn(() => ({
|
||||||
|
responseModalities: ['TEXT', 'IMAGE']
|
||||||
|
}))
|
||||||
|
}))
|
||||||
|
|
||||||
|
vi.mock('../websearch', () => ({
|
||||||
|
getWebSearchParams: vi.fn(() => ({ enable_search: true }))
|
||||||
|
}))
|
||||||
|
|
||||||
|
const ensureWindowApi = () => {
|
||||||
|
const globalWindow = window as any
|
||||||
|
globalWindow.api = globalWindow.api || {}
|
||||||
|
globalWindow.api.getAppInfo = globalWindow.api.getAppInfo || vi.fn(async () => ({ notesPath: '' }))
|
||||||
|
}
|
||||||
|
|
||||||
|
ensureWindowApi()
|
||||||
|
|
||||||
|
describe('options utils', () => {
|
||||||
|
const mockAssistant: Assistant = {
|
||||||
|
id: 'test-assistant',
|
||||||
|
name: 'Test Assistant',
|
||||||
|
settings: {}
|
||||||
|
} as Assistant
|
||||||
|
|
||||||
|
const mockModel: Model = {
|
||||||
|
id: 'gpt-4',
|
||||||
|
name: 'GPT-4',
|
||||||
|
provider: SystemProviderIds.openai
|
||||||
|
} as Model
|
||||||
|
|
||||||
|
beforeEach(() => {
|
||||||
|
vi.clearAllMocks()
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('buildProviderOptions', () => {
|
||||||
|
describe('OpenAI provider', () => {
|
||||||
|
const openaiProvider: Provider = {
|
||||||
|
id: SystemProviderIds.openai,
|
||||||
|
name: 'OpenAI',
|
||||||
|
type: 'openai-response',
|
||||||
|
apiKey: 'test-key',
|
||||||
|
apiHost: 'https://api.openai.com/v1',
|
||||||
|
isSystem: true
|
||||||
|
} as Provider
|
||||||
|
|
||||||
|
it('should build basic OpenAI options', () => {
|
||||||
|
const result = buildProviderOptions(mockAssistant, mockModel, openaiProvider, {
|
||||||
|
enableReasoning: false,
|
||||||
|
enableWebSearch: false,
|
||||||
|
enableGenerateImage: false
|
||||||
|
})
|
||||||
|
|
||||||
|
expect(result).toHaveProperty('openai')
|
||||||
|
expect(result.openai).toBeDefined()
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should include reasoning parameters when enabled', () => {
|
||||||
|
const result = buildProviderOptions(mockAssistant, mockModel, openaiProvider, {
|
||||||
|
enableReasoning: true,
|
||||||
|
enableWebSearch: false,
|
||||||
|
enableGenerateImage: false
|
||||||
|
})
|
||||||
|
|
||||||
|
expect(result.openai).toHaveProperty('reasoningEffort')
|
||||||
|
expect(result.openai.reasoningEffort).toBe('medium')
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should include service tier when supported', () => {
|
||||||
|
const providerWithServiceTier: Provider = {
|
||||||
|
...openaiProvider,
|
||||||
|
serviceTier: OpenAIServiceTiers.auto
|
||||||
|
}
|
||||||
|
|
||||||
|
const result = buildProviderOptions(mockAssistant, mockModel, providerWithServiceTier, {
|
||||||
|
enableReasoning: false,
|
||||||
|
enableWebSearch: false,
|
||||||
|
enableGenerateImage: false
|
||||||
|
})
|
||||||
|
|
||||||
|
expect(result.openai).toHaveProperty('serviceTier')
|
||||||
|
expect(result.openai.serviceTier).toBe(OpenAIServiceTiers.auto)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('Anthropic provider', () => {
|
||||||
|
const anthropicProvider: Provider = {
|
||||||
|
id: SystemProviderIds.anthropic,
|
||||||
|
name: 'Anthropic',
|
||||||
|
type: 'anthropic',
|
||||||
|
apiKey: 'test-key',
|
||||||
|
apiHost: 'https://api.anthropic.com',
|
||||||
|
isSystem: true
|
||||||
|
} as Provider
|
||||||
|
|
||||||
|
const anthropicModel: Model = {
|
||||||
|
id: 'claude-3-5-sonnet-20241022',
|
||||||
|
name: 'Claude 3.5 Sonnet',
|
||||||
|
provider: SystemProviderIds.anthropic
|
||||||
|
} as Model
|
||||||
|
|
||||||
|
it('should build basic Anthropic options', () => {
|
||||||
|
const result = buildProviderOptions(mockAssistant, anthropicModel, anthropicProvider, {
|
||||||
|
enableReasoning: false,
|
||||||
|
enableWebSearch: false,
|
||||||
|
enableGenerateImage: false
|
||||||
|
})
|
||||||
|
|
||||||
|
expect(result).toHaveProperty('anthropic')
|
||||||
|
expect(result.anthropic).toBeDefined()
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should include reasoning parameters when enabled', () => {
|
||||||
|
const result = buildProviderOptions(mockAssistant, anthropicModel, anthropicProvider, {
|
||||||
|
enableReasoning: true,
|
||||||
|
enableWebSearch: false,
|
||||||
|
enableGenerateImage: false
|
||||||
|
})
|
||||||
|
|
||||||
|
expect(result.anthropic).toHaveProperty('thinking')
|
||||||
|
expect(result.anthropic.thinking).toEqual({
|
||||||
|
type: 'enabled',
|
||||||
|
budgetTokens: 5000
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('Google provider', () => {
|
||||||
|
const googleProvider: Provider = {
|
||||||
|
id: SystemProviderIds.gemini,
|
||||||
|
name: 'Google',
|
||||||
|
type: 'gemini',
|
||||||
|
apiKey: 'test-key',
|
||||||
|
apiHost: 'https://generativelanguage.googleapis.com',
|
||||||
|
isSystem: true,
|
||||||
|
models: [{ id: 'gemini-2.0-flash-exp' }] as Model[]
|
||||||
|
} as Provider
|
||||||
|
|
||||||
|
const googleModel: Model = {
|
||||||
|
id: 'gemini-2.0-flash-exp',
|
||||||
|
name: 'Gemini 2.0 Flash',
|
||||||
|
provider: SystemProviderIds.gemini
|
||||||
|
} as Model
|
||||||
|
|
||||||
|
it('should build basic Google options', () => {
|
||||||
|
const result = buildProviderOptions(mockAssistant, googleModel, googleProvider, {
|
||||||
|
enableReasoning: false,
|
||||||
|
enableWebSearch: false,
|
||||||
|
enableGenerateImage: false
|
||||||
|
})
|
||||||
|
|
||||||
|
expect(result).toHaveProperty('google')
|
||||||
|
expect(result.google).toBeDefined()
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should include reasoning parameters when enabled', () => {
|
||||||
|
const result = buildProviderOptions(mockAssistant, googleModel, googleProvider, {
|
||||||
|
enableReasoning: true,
|
||||||
|
enableWebSearch: false,
|
||||||
|
enableGenerateImage: false
|
||||||
|
})
|
||||||
|
|
||||||
|
expect(result.google).toHaveProperty('thinkingConfig')
|
||||||
|
expect(result.google.thinkingConfig).toEqual({
|
||||||
|
include_thoughts: true
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should include image generation parameters when enabled', () => {
|
||||||
|
const result = buildProviderOptions(mockAssistant, googleModel, googleProvider, {
|
||||||
|
enableReasoning: false,
|
||||||
|
enableWebSearch: false,
|
||||||
|
enableGenerateImage: true
|
||||||
|
})
|
||||||
|
|
||||||
|
expect(result.google).toHaveProperty('responseModalities')
|
||||||
|
expect(result.google.responseModalities).toEqual(['TEXT', 'IMAGE'])
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('xAI provider', () => {
|
||||||
|
const xaiProvider = {
|
||||||
|
id: SystemProviderIds.grok,
|
||||||
|
name: 'xAI',
|
||||||
|
type: 'new-api',
|
||||||
|
apiKey: 'test-key',
|
||||||
|
apiHost: 'https://api.x.ai/v1',
|
||||||
|
isSystem: true,
|
||||||
|
models: [] as Model[]
|
||||||
|
} as Provider
|
||||||
|
|
||||||
|
const xaiModel: Model = {
|
||||||
|
id: 'grok-2-latest',
|
||||||
|
name: 'Grok 2',
|
||||||
|
provider: SystemProviderIds.grok
|
||||||
|
} as Model
|
||||||
|
|
||||||
|
it('should build basic xAI options', () => {
|
||||||
|
const result = buildProviderOptions(mockAssistant, xaiModel, xaiProvider, {
|
||||||
|
enableReasoning: false,
|
||||||
|
enableWebSearch: false,
|
||||||
|
enableGenerateImage: false
|
||||||
|
})
|
||||||
|
|
||||||
|
expect(result).toHaveProperty('xai')
|
||||||
|
expect(result.xai).toBeDefined()
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should include reasoning parameters when enabled', () => {
|
||||||
|
const result = buildProviderOptions(mockAssistant, xaiModel, xaiProvider, {
|
||||||
|
enableReasoning: true,
|
||||||
|
enableWebSearch: false,
|
||||||
|
enableGenerateImage: false
|
||||||
|
})
|
||||||
|
|
||||||
|
expect(result.xai).toHaveProperty('reasoningEffort')
|
||||||
|
expect(result.xai.reasoningEffort).toBe('high')
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('DeepSeek provider', () => {
|
||||||
|
const deepseekProvider: Provider = {
|
||||||
|
id: SystemProviderIds.deepseek,
|
||||||
|
name: 'DeepSeek',
|
||||||
|
type: 'openai',
|
||||||
|
apiKey: 'test-key',
|
||||||
|
apiHost: 'https://api.deepseek.com',
|
||||||
|
isSystem: true
|
||||||
|
} as Provider
|
||||||
|
|
||||||
|
const deepseekModel: Model = {
|
||||||
|
id: 'deepseek-chat',
|
||||||
|
name: 'DeepSeek Chat',
|
||||||
|
provider: SystemProviderIds.deepseek
|
||||||
|
} as Model
|
||||||
|
|
||||||
|
it('should build basic DeepSeek options', () => {
|
||||||
|
const result = buildProviderOptions(mockAssistant, deepseekModel, deepseekProvider, {
|
||||||
|
enableReasoning: false,
|
||||||
|
enableWebSearch: false,
|
||||||
|
enableGenerateImage: false
|
||||||
|
})
|
||||||
|
|
||||||
|
expect(result).toHaveProperty('deepseek')
|
||||||
|
expect(result.deepseek).toBeDefined()
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('OpenRouter provider', () => {
|
||||||
|
const openrouterProvider: Provider = {
|
||||||
|
id: SystemProviderIds.openrouter,
|
||||||
|
name: 'OpenRouter',
|
||||||
|
type: 'openai',
|
||||||
|
apiKey: 'test-key',
|
||||||
|
apiHost: 'https://openrouter.ai/api/v1',
|
||||||
|
isSystem: true
|
||||||
|
} as Provider
|
||||||
|
|
||||||
|
const openrouterModel: Model = {
|
||||||
|
id: 'openai/gpt-4',
|
||||||
|
name: 'GPT-4',
|
||||||
|
provider: SystemProviderIds.openrouter
|
||||||
|
} as Model
|
||||||
|
|
||||||
|
it('should build basic OpenRouter options', () => {
|
||||||
|
const result = buildProviderOptions(mockAssistant, openrouterModel, openrouterProvider, {
|
||||||
|
enableReasoning: false,
|
||||||
|
enableWebSearch: false,
|
||||||
|
enableGenerateImage: false
|
||||||
|
})
|
||||||
|
|
||||||
|
expect(result).toHaveProperty('openrouter')
|
||||||
|
expect(result.openrouter).toBeDefined()
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should include web search parameters when enabled', () => {
|
||||||
|
const result = buildProviderOptions(mockAssistant, openrouterModel, openrouterProvider, {
|
||||||
|
enableReasoning: false,
|
||||||
|
enableWebSearch: true,
|
||||||
|
enableGenerateImage: false
|
||||||
|
})
|
||||||
|
|
||||||
|
expect(result.openrouter).toHaveProperty('enable_search')
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('Custom parameters', () => {
|
||||||
|
it('should merge custom parameters', async () => {
|
||||||
|
const { getCustomParameters } = await import('../reasoning')
|
||||||
|
|
||||||
|
vi.mocked(getCustomParameters).mockReturnValue({
|
||||||
|
custom_param: 'custom_value',
|
||||||
|
another_param: 123
|
||||||
|
})
|
||||||
|
|
||||||
|
const result = buildProviderOptions(
|
||||||
|
mockAssistant,
|
||||||
|
mockModel,
|
||||||
|
{
|
||||||
|
id: SystemProviderIds.openai,
|
||||||
|
name: 'OpenAI',
|
||||||
|
type: 'openai',
|
||||||
|
apiKey: 'test-key',
|
||||||
|
apiHost: 'https://api.openai.com/v1'
|
||||||
|
} as Provider,
|
||||||
|
{
|
||||||
|
enableReasoning: false,
|
||||||
|
enableWebSearch: false,
|
||||||
|
enableGenerateImage: false
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
expect(result.openai).toHaveProperty('custom_param')
|
||||||
|
expect(result.openai.custom_param).toBe('custom_value')
|
||||||
|
expect(result.openai).toHaveProperty('another_param')
|
||||||
|
expect(result.openai.another_param).toBe(123)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('Multiple capabilities', () => {
|
||||||
|
const googleProvider = {
|
||||||
|
id: SystemProviderIds.gemini,
|
||||||
|
name: 'Google',
|
||||||
|
type: 'gemini',
|
||||||
|
apiKey: 'test-key',
|
||||||
|
apiHost: 'https://generativelanguage.googleapis.com',
|
||||||
|
isSystem: true,
|
||||||
|
models: [] as Model[]
|
||||||
|
} as Provider
|
||||||
|
|
||||||
|
const googleModel: Model = {
|
||||||
|
id: 'gemini-2.0-flash-exp',
|
||||||
|
name: 'Gemini 2.0 Flash',
|
||||||
|
provider: SystemProviderIds.gemini
|
||||||
|
} as Model
|
||||||
|
|
||||||
|
it('should combine reasoning and image generation', () => {
|
||||||
|
const result = buildProviderOptions(mockAssistant, googleModel, googleProvider, {
|
||||||
|
enableReasoning: true,
|
||||||
|
enableWebSearch: false,
|
||||||
|
enableGenerateImage: true
|
||||||
|
})
|
||||||
|
|
||||||
|
expect(result.google).toHaveProperty('thinkingConfig')
|
||||||
|
expect(result.google).toHaveProperty('responseModalities')
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should handle all capabilities enabled', () => {
|
||||||
|
const result = buildProviderOptions(mockAssistant, googleModel, googleProvider, {
|
||||||
|
enableReasoning: true,
|
||||||
|
enableWebSearch: true,
|
||||||
|
enableGenerateImage: true
|
||||||
|
})
|
||||||
|
|
||||||
|
expect(result.google).toBeDefined()
|
||||||
|
expect(Object.keys(result.google).length).toBeGreaterThan(0)
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('Vertex AI providers', () => {
|
||||||
|
it('should map google-vertex to google', () => {
|
||||||
|
const vertexProvider = {
|
||||||
|
id: 'google-vertex',
|
||||||
|
name: 'Vertex AI',
|
||||||
|
type: 'vertexai',
|
||||||
|
apiKey: 'test-key',
|
||||||
|
apiHost: 'https://vertex-ai.googleapis.com',
|
||||||
|
models: [] as Model[]
|
||||||
|
} as Provider
|
||||||
|
|
||||||
|
const vertexModel: Model = {
|
||||||
|
id: 'gemini-2.0-flash-exp',
|
||||||
|
name: 'Gemini 2.0 Flash',
|
||||||
|
provider: 'google-vertex'
|
||||||
|
} as Model
|
||||||
|
|
||||||
|
const result = buildProviderOptions(mockAssistant, vertexModel, vertexProvider, {
|
||||||
|
enableReasoning: false,
|
||||||
|
enableWebSearch: false,
|
||||||
|
enableGenerateImage: false
|
||||||
|
})
|
||||||
|
|
||||||
|
expect(result).toHaveProperty('google')
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should map google-vertex-anthropic to anthropic', () => {
|
||||||
|
const vertexAnthropicProvider = {
|
||||||
|
id: 'google-vertex-anthropic',
|
||||||
|
name: 'Vertex AI Anthropic',
|
||||||
|
type: 'vertex-anthropic',
|
||||||
|
apiKey: 'test-key',
|
||||||
|
apiHost: 'https://vertex-ai.googleapis.com',
|
||||||
|
models: [] as Model[]
|
||||||
|
} as Provider
|
||||||
|
|
||||||
|
const vertexModel: Model = {
|
||||||
|
id: 'claude-3-5-sonnet-20241022',
|
||||||
|
name: 'Claude 3.5 Sonnet',
|
||||||
|
provider: 'google-vertex-anthropic'
|
||||||
|
} as Model
|
||||||
|
|
||||||
|
const result = buildProviderOptions(mockAssistant, vertexModel, vertexAnthropicProvider, {
|
||||||
|
enableReasoning: false,
|
||||||
|
enableWebSearch: false,
|
||||||
|
enableGenerateImage: false
|
||||||
|
})
|
||||||
|
|
||||||
|
expect(result).toHaveProperty('anthropic')
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
967
src/renderer/src/aiCore/utils/__tests__/reasoning.test.ts
Normal file
967
src/renderer/src/aiCore/utils/__tests__/reasoning.test.ts
Normal file
@@ -0,0 +1,967 @@
|
|||||||
|
/**
|
||||||
|
* reasoning.ts Unit Tests
|
||||||
|
* Tests for reasoning parameter generation utilities
|
||||||
|
*/
|
||||||
|
|
||||||
|
import { getStoreSetting } from '@renderer/hooks/useSettings'
|
||||||
|
import type { SettingsState } from '@renderer/store/settings'
|
||||||
|
import type { Assistant, Model, Provider } from '@renderer/types'
|
||||||
|
import { SystemProviderIds } from '@renderer/types'
|
||||||
|
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||||
|
|
||||||
|
import {
|
||||||
|
getAnthropicReasoningParams,
|
||||||
|
getBedrockReasoningParams,
|
||||||
|
getCustomParameters,
|
||||||
|
getGeminiReasoningParams,
|
||||||
|
getOpenAIReasoningParams,
|
||||||
|
getReasoningEffort,
|
||||||
|
getXAIReasoningParams
|
||||||
|
} from '../reasoning'
|
||||||
|
|
||||||
|
function defaultGetStoreSetting<K extends keyof SettingsState>(key: K): SettingsState[K] {
|
||||||
|
if (key === 'openAI') {
|
||||||
|
return {
|
||||||
|
summaryText: 'auto',
|
||||||
|
verbosity: 'medium'
|
||||||
|
} as SettingsState[K]
|
||||||
|
}
|
||||||
|
return undefined as SettingsState[K]
|
||||||
|
}
|
||||||
|
|
||||||
|
// Mock dependencies
|
||||||
|
vi.mock('@logger', () => ({
|
||||||
|
loggerService: {
|
||||||
|
withContext: () => ({
|
||||||
|
debug: vi.fn(),
|
||||||
|
error: vi.fn(),
|
||||||
|
warn: vi.fn(),
|
||||||
|
info: vi.fn()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
|
||||||
|
vi.mock('@renderer/store/settings', () => ({
|
||||||
|
default: (state = { settings: {} }) => state
|
||||||
|
}))
|
||||||
|
|
||||||
|
vi.mock('@renderer/store/llm', () => ({
|
||||||
|
initialState: {},
|
||||||
|
default: (state = { llm: {} }) => state
|
||||||
|
}))
|
||||||
|
|
||||||
|
vi.mock('@renderer/config/constant', () => ({
|
||||||
|
DEFAULT_MAX_TOKENS: 4096,
|
||||||
|
isMac: false,
|
||||||
|
isWin: false,
|
||||||
|
TOKENFLUX_HOST: 'mock-host'
|
||||||
|
}))
|
||||||
|
|
||||||
|
vi.mock('@renderer/utils/provider', () => ({
|
||||||
|
isSupportEnableThinkingProvider: vi.fn((provider) => {
|
||||||
|
return [SystemProviderIds.dashscope, SystemProviderIds.silicon].includes(provider.id)
|
||||||
|
})
|
||||||
|
}))
|
||||||
|
|
||||||
|
vi.mock('@renderer/config/models', async (importOriginal) => {
|
||||||
|
const actual: any = await importOriginal()
|
||||||
|
return {
|
||||||
|
...actual,
|
||||||
|
isReasoningModel: vi.fn(() => false),
|
||||||
|
isOpenAIDeepResearchModel: vi.fn(() => false),
|
||||||
|
isOpenAIModel: vi.fn(() => false),
|
||||||
|
isSupportedReasoningEffortOpenAIModel: vi.fn(() => false),
|
||||||
|
isSupportedThinkingTokenQwenModel: vi.fn(() => false),
|
||||||
|
isQwenReasoningModel: vi.fn(() => false),
|
||||||
|
isSupportedThinkingTokenClaudeModel: vi.fn(() => false),
|
||||||
|
isSupportedThinkingTokenGeminiModel: vi.fn(() => false),
|
||||||
|
isSupportedThinkingTokenDoubaoModel: vi.fn(() => false),
|
||||||
|
isSupportedThinkingTokenZhipuModel: vi.fn(() => false),
|
||||||
|
isSupportedReasoningEffortModel: vi.fn(() => false),
|
||||||
|
isDeepSeekHybridInferenceModel: vi.fn(() => false),
|
||||||
|
isSupportedReasoningEffortGrokModel: vi.fn(() => false),
|
||||||
|
getThinkModelType: vi.fn(() => 'default'),
|
||||||
|
isDoubaoSeedAfter251015: vi.fn(() => false),
|
||||||
|
isDoubaoThinkingAutoModel: vi.fn(() => false),
|
||||||
|
isGrok4FastReasoningModel: vi.fn(() => false),
|
||||||
|
isGrokReasoningModel: vi.fn(() => false),
|
||||||
|
isOpenAIReasoningModel: vi.fn(() => false),
|
||||||
|
isQwenAlwaysThinkModel: vi.fn(() => false),
|
||||||
|
isSupportedThinkingTokenHunyuanModel: vi.fn(() => false),
|
||||||
|
isSupportedThinkingTokenModel: vi.fn(() => false),
|
||||||
|
isGPT51SeriesModel: vi.fn(() => false)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
vi.mock('@renderer/hooks/useSettings', () => ({
|
||||||
|
getStoreSetting: vi.fn(defaultGetStoreSetting)
|
||||||
|
}))
|
||||||
|
|
||||||
|
vi.mock('@renderer/services/AssistantService', () => ({
|
||||||
|
getAssistantSettings: vi.fn((assistant) => ({
|
||||||
|
maxTokens: assistant?.settings?.maxTokens || 4096,
|
||||||
|
reasoning_effort: assistant?.settings?.reasoning_effort
|
||||||
|
})),
|
||||||
|
getProviderByModel: vi.fn((model) => ({
|
||||||
|
id: model.provider,
|
||||||
|
name: 'Test Provider'
|
||||||
|
})),
|
||||||
|
getDefaultAssistant: vi.fn(() => ({
|
||||||
|
id: 'default',
|
||||||
|
name: 'Default Assistant',
|
||||||
|
settings: {}
|
||||||
|
}))
|
||||||
|
}))
|
||||||
|
|
||||||
|
const ensureWindowApi = () => {
|
||||||
|
const globalWindow = window as any
|
||||||
|
globalWindow.api = globalWindow.api || {}
|
||||||
|
globalWindow.api.getAppInfo = globalWindow.api.getAppInfo || vi.fn(async () => ({ notesPath: '' }))
|
||||||
|
}
|
||||||
|
|
||||||
|
ensureWindowApi()
|
||||||
|
|
||||||
|
describe('reasoning utils', () => {
|
||||||
|
beforeEach(() => {
|
||||||
|
vi.resetAllMocks()
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('getReasoningEffort', () => {
|
||||||
|
it('should return empty object for non-reasoning model', async () => {
|
||||||
|
const model: Model = {
|
||||||
|
id: 'gpt-4',
|
||||||
|
name: 'GPT-4',
|
||||||
|
provider: SystemProviderIds.openai
|
||||||
|
} as Model
|
||||||
|
|
||||||
|
const assistant: Assistant = {
|
||||||
|
id: 'test',
|
||||||
|
name: 'Test',
|
||||||
|
settings: {}
|
||||||
|
} as Assistant
|
||||||
|
|
||||||
|
const result = getReasoningEffort(assistant, model)
|
||||||
|
expect(result).toEqual({})
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should disable reasoning for OpenRouter when no reasoning effort set', async () => {
|
||||||
|
const { isReasoningModel } = await import('@renderer/config/models')
|
||||||
|
|
||||||
|
vi.mocked(isReasoningModel).mockReturnValue(true)
|
||||||
|
|
||||||
|
const model: Model = {
|
||||||
|
id: 'anthropic/claude-sonnet-4',
|
||||||
|
name: 'Claude Sonnet 4',
|
||||||
|
provider: SystemProviderIds.openrouter
|
||||||
|
} as Model
|
||||||
|
|
||||||
|
const assistant: Assistant = {
|
||||||
|
id: 'test',
|
||||||
|
name: 'Test',
|
||||||
|
settings: {}
|
||||||
|
} as Assistant
|
||||||
|
|
||||||
|
const result = getReasoningEffort(assistant, model)
|
||||||
|
expect(result).toEqual({ reasoning: { enabled: false, exclude: true } })
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should handle Qwen models with enable_thinking', async () => {
|
||||||
|
const { isReasoningModel, isSupportedThinkingTokenQwenModel, isQwenReasoningModel } = await import(
|
||||||
|
'@renderer/config/models'
|
||||||
|
)
|
||||||
|
|
||||||
|
vi.mocked(isReasoningModel).mockReturnValue(true)
|
||||||
|
vi.mocked(isSupportedThinkingTokenQwenModel).mockReturnValue(true)
|
||||||
|
vi.mocked(isQwenReasoningModel).mockReturnValue(true)
|
||||||
|
|
||||||
|
const model: Model = {
|
||||||
|
id: 'qwen-plus',
|
||||||
|
name: 'Qwen Plus',
|
||||||
|
provider: SystemProviderIds.dashscope
|
||||||
|
} as Model
|
||||||
|
|
||||||
|
const assistant: Assistant = {
|
||||||
|
id: 'test',
|
||||||
|
name: 'Test',
|
||||||
|
settings: {
|
||||||
|
reasoning_effort: 'medium'
|
||||||
|
}
|
||||||
|
} as Assistant
|
||||||
|
|
||||||
|
const result = getReasoningEffort(assistant, model)
|
||||||
|
expect(result).toHaveProperty('enable_thinking')
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should handle Claude models with thinking config', async () => {
|
||||||
|
const {
|
||||||
|
isSupportedThinkingTokenClaudeModel,
|
||||||
|
isReasoningModel,
|
||||||
|
isQwenReasoningModel,
|
||||||
|
isSupportedThinkingTokenGeminiModel,
|
||||||
|
isSupportedThinkingTokenDoubaoModel,
|
||||||
|
isSupportedThinkingTokenZhipuModel,
|
||||||
|
isSupportedReasoningEffortModel
|
||||||
|
} = await import('@renderer/config/models')
|
||||||
|
|
||||||
|
vi.mocked(isReasoningModel).mockReturnValue(true)
|
||||||
|
vi.mocked(isSupportedThinkingTokenClaudeModel).mockReturnValue(true)
|
||||||
|
vi.mocked(isQwenReasoningModel).mockReturnValue(false)
|
||||||
|
vi.mocked(isSupportedThinkingTokenGeminiModel).mockReturnValue(false)
|
||||||
|
vi.mocked(isSupportedThinkingTokenDoubaoModel).mockReturnValue(false)
|
||||||
|
vi.mocked(isSupportedThinkingTokenZhipuModel).mockReturnValue(false)
|
||||||
|
vi.mocked(isSupportedReasoningEffortModel).mockReturnValue(false)
|
||||||
|
|
||||||
|
const model: Model = {
|
||||||
|
id: 'claude-3-7-sonnet',
|
||||||
|
name: 'Claude 3.7 Sonnet',
|
||||||
|
provider: SystemProviderIds.anthropic
|
||||||
|
} as Model
|
||||||
|
|
||||||
|
const assistant: Assistant = {
|
||||||
|
id: 'test',
|
||||||
|
name: 'Test',
|
||||||
|
settings: {
|
||||||
|
reasoning_effort: 'high',
|
||||||
|
maxTokens: 4096
|
||||||
|
}
|
||||||
|
} as Assistant
|
||||||
|
|
||||||
|
const result = getReasoningEffort(assistant, model)
|
||||||
|
expect(result).toEqual({
|
||||||
|
thinking: {
|
||||||
|
type: 'enabled',
|
||||||
|
budget_tokens: expect.any(Number)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should handle Gemini Flash models with thinking budget 0', async () => {
|
||||||
|
const {
|
||||||
|
isSupportedThinkingTokenGeminiModel,
|
||||||
|
isReasoningModel,
|
||||||
|
isQwenReasoningModel,
|
||||||
|
isSupportedThinkingTokenClaudeModel,
|
||||||
|
isSupportedThinkingTokenDoubaoModel,
|
||||||
|
isSupportedThinkingTokenZhipuModel,
|
||||||
|
isOpenAIDeepResearchModel,
|
||||||
|
isSupportedThinkingTokenQwenModel,
|
||||||
|
isSupportedThinkingTokenHunyuanModel,
|
||||||
|
isDeepSeekHybridInferenceModel
|
||||||
|
} = await import('@renderer/config/models')
|
||||||
|
|
||||||
|
vi.mocked(isReasoningModel).mockReturnValue(true)
|
||||||
|
vi.mocked(isOpenAIDeepResearchModel).mockReturnValue(false)
|
||||||
|
vi.mocked(isSupportedThinkingTokenGeminiModel).mockReturnValue(true)
|
||||||
|
vi.mocked(isQwenReasoningModel).mockReturnValue(false)
|
||||||
|
vi.mocked(isSupportedThinkingTokenClaudeModel).mockReturnValue(false)
|
||||||
|
vi.mocked(isSupportedThinkingTokenDoubaoModel).mockReturnValue(false)
|
||||||
|
vi.mocked(isSupportedThinkingTokenZhipuModel).mockReturnValue(false)
|
||||||
|
vi.mocked(isSupportedThinkingTokenQwenModel).mockReturnValue(false)
|
||||||
|
vi.mocked(isSupportedThinkingTokenHunyuanModel).mockReturnValue(false)
|
||||||
|
vi.mocked(isDeepSeekHybridInferenceModel).mockReturnValue(false)
|
||||||
|
|
||||||
|
const model: Model = {
|
||||||
|
id: 'gemini-2.5-flash',
|
||||||
|
name: 'Gemini 2.5 Flash',
|
||||||
|
provider: SystemProviderIds.openai
|
||||||
|
} as Model
|
||||||
|
|
||||||
|
const assistant: Assistant = {
|
||||||
|
id: 'test',
|
||||||
|
name: 'Test',
|
||||||
|
settings: {}
|
||||||
|
} as Assistant
|
||||||
|
|
||||||
|
const result = getReasoningEffort(assistant, model)
|
||||||
|
expect(result).toEqual({
|
||||||
|
extra_body: {
|
||||||
|
google: {
|
||||||
|
thinking_config: {
|
||||||
|
thinking_budget: 0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should handle GPT-5.1 reasoning model with effort levels', async () => {
|
||||||
|
const {
|
||||||
|
isReasoningModel,
|
||||||
|
isOpenAIDeepResearchModel,
|
||||||
|
isSupportedReasoningEffortModel,
|
||||||
|
isGPT51SeriesModel,
|
||||||
|
getThinkModelType
|
||||||
|
} = await import('@renderer/config/models')
|
||||||
|
|
||||||
|
vi.mocked(isReasoningModel).mockReturnValue(true)
|
||||||
|
vi.mocked(isOpenAIDeepResearchModel).mockReturnValue(false)
|
||||||
|
vi.mocked(isSupportedReasoningEffortModel).mockReturnValue(true)
|
||||||
|
vi.mocked(getThinkModelType).mockReturnValue('gpt5_1')
|
||||||
|
vi.mocked(isGPT51SeriesModel).mockReturnValue(true)
|
||||||
|
|
||||||
|
const model: Model = {
|
||||||
|
id: 'gpt-5.1',
|
||||||
|
name: 'GPT-5.1',
|
||||||
|
provider: SystemProviderIds.openai
|
||||||
|
} as Model
|
||||||
|
|
||||||
|
const assistant: Assistant = {
|
||||||
|
id: 'test',
|
||||||
|
name: 'Test',
|
||||||
|
settings: {
|
||||||
|
reasoning_effort: 'none'
|
||||||
|
}
|
||||||
|
} as Assistant
|
||||||
|
|
||||||
|
const result = getReasoningEffort(assistant, model)
|
||||||
|
expect(result).toEqual({
|
||||||
|
reasoningEffort: 'none'
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should handle DeepSeek hybrid inference models', async () => {
|
||||||
|
const { isReasoningModel, isDeepSeekHybridInferenceModel } = await import('@renderer/config/models')
|
||||||
|
|
||||||
|
vi.mocked(isReasoningModel).mockReturnValue(true)
|
||||||
|
vi.mocked(isDeepSeekHybridInferenceModel).mockReturnValue(true)
|
||||||
|
|
||||||
|
const model: Model = {
|
||||||
|
id: 'deepseek-v3.1',
|
||||||
|
name: 'DeepSeek V3.1',
|
||||||
|
provider: SystemProviderIds.silicon
|
||||||
|
} as Model
|
||||||
|
|
||||||
|
const assistant: Assistant = {
|
||||||
|
id: 'test',
|
||||||
|
name: 'Test',
|
||||||
|
settings: {
|
||||||
|
reasoning_effort: 'high'
|
||||||
|
}
|
||||||
|
} as Assistant
|
||||||
|
|
||||||
|
const result = getReasoningEffort(assistant, model)
|
||||||
|
expect(result).toEqual({
|
||||||
|
enable_thinking: true
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should return medium effort for deep research models', async () => {
|
||||||
|
const { isReasoningModel, isOpenAIDeepResearchModel } = await import('@renderer/config/models')
|
||||||
|
|
||||||
|
vi.mocked(isReasoningModel).mockReturnValue(true)
|
||||||
|
vi.mocked(isOpenAIDeepResearchModel).mockReturnValue(true)
|
||||||
|
|
||||||
|
const model: Model = {
|
||||||
|
id: 'o3-deep-research',
|
||||||
|
provider: SystemProviderIds.openai
|
||||||
|
} as Model
|
||||||
|
|
||||||
|
const assistant: Assistant = {
|
||||||
|
id: 'test',
|
||||||
|
name: 'Test',
|
||||||
|
settings: {}
|
||||||
|
} as Assistant
|
||||||
|
|
||||||
|
const result = getReasoningEffort(assistant, model)
|
||||||
|
expect(result).toEqual({ reasoning_effort: 'medium' })
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should return empty for groq provider', async () => {
|
||||||
|
const { getProviderByModel } = await import('@renderer/services/AssistantService')
|
||||||
|
|
||||||
|
vi.mocked(getProviderByModel).mockReturnValue({
|
||||||
|
id: 'groq',
|
||||||
|
name: 'Groq'
|
||||||
|
} as Provider)
|
||||||
|
|
||||||
|
const model: Model = {
|
||||||
|
id: 'groq-model',
|
||||||
|
name: 'Groq Model',
|
||||||
|
provider: 'groq'
|
||||||
|
} as Model
|
||||||
|
|
||||||
|
const assistant: Assistant = {
|
||||||
|
id: 'test',
|
||||||
|
name: 'Test',
|
||||||
|
settings: {}
|
||||||
|
} as Assistant
|
||||||
|
|
||||||
|
const result = getReasoningEffort(assistant, model)
|
||||||
|
expect(result).toEqual({})
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('getOpenAIReasoningParams', () => {
|
||||||
|
it('should return empty object for non-reasoning model', async () => {
|
||||||
|
const model: Model = {
|
||||||
|
id: 'gpt-4',
|
||||||
|
name: 'GPT-4',
|
||||||
|
provider: SystemProviderIds.openai
|
||||||
|
} as Model
|
||||||
|
|
||||||
|
const assistant: Assistant = {
|
||||||
|
id: 'test',
|
||||||
|
name: 'Test',
|
||||||
|
settings: {}
|
||||||
|
} as Assistant
|
||||||
|
|
||||||
|
const result = getOpenAIReasoningParams(assistant, model)
|
||||||
|
expect(result).toEqual({})
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should return empty when no reasoning effort set', async () => {
|
||||||
|
const model: Model = {
|
||||||
|
id: 'o1-preview',
|
||||||
|
name: 'O1 Preview',
|
||||||
|
provider: SystemProviderIds.openai
|
||||||
|
} as Model
|
||||||
|
|
||||||
|
const assistant: Assistant = {
|
||||||
|
id: 'test',
|
||||||
|
name: 'Test',
|
||||||
|
settings: {}
|
||||||
|
} as Assistant
|
||||||
|
|
||||||
|
const result = getOpenAIReasoningParams(assistant, model)
|
||||||
|
expect(result).toEqual({})
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should return reasoning effort for OpenAI models', async () => {
|
||||||
|
const { isReasoningModel, isOpenAIModel, isSupportedReasoningEffortOpenAIModel } = await import(
|
||||||
|
'@renderer/config/models'
|
||||||
|
)
|
||||||
|
|
||||||
|
vi.mocked(isReasoningModel).mockReturnValue(true)
|
||||||
|
vi.mocked(isOpenAIModel).mockReturnValue(true)
|
||||||
|
vi.mocked(isSupportedReasoningEffortOpenAIModel).mockReturnValue(true)
|
||||||
|
|
||||||
|
const model: Model = {
|
||||||
|
id: 'gpt-5.1',
|
||||||
|
name: 'GPT 5.1',
|
||||||
|
provider: SystemProviderIds.openai
|
||||||
|
} as Model
|
||||||
|
|
||||||
|
const assistant: Assistant = {
|
||||||
|
id: 'test',
|
||||||
|
name: 'Test',
|
||||||
|
settings: {
|
||||||
|
reasoning_effort: 'high'
|
||||||
|
}
|
||||||
|
} as Assistant
|
||||||
|
|
||||||
|
const result = getOpenAIReasoningParams(assistant, model)
|
||||||
|
expect(result).toEqual({
|
||||||
|
reasoningEffort: 'high',
|
||||||
|
reasoningSummary: 'auto'
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should include reasoning summary when not o1-pro', async () => {
|
||||||
|
const { isReasoningModel, isOpenAIModel, isSupportedReasoningEffortOpenAIModel } = await import(
|
||||||
|
'@renderer/config/models'
|
||||||
|
)
|
||||||
|
|
||||||
|
vi.mocked(isReasoningModel).mockReturnValue(true)
|
||||||
|
vi.mocked(isOpenAIModel).mockReturnValue(true)
|
||||||
|
vi.mocked(isSupportedReasoningEffortOpenAIModel).mockReturnValue(true)
|
||||||
|
|
||||||
|
const model: Model = {
|
||||||
|
id: 'gpt-5',
|
||||||
|
provider: SystemProviderIds.openai
|
||||||
|
} as Model
|
||||||
|
|
||||||
|
const assistant: Assistant = {
|
||||||
|
id: 'test',
|
||||||
|
name: 'Test',
|
||||||
|
settings: {
|
||||||
|
reasoning_effort: 'medium'
|
||||||
|
}
|
||||||
|
} as Assistant
|
||||||
|
|
||||||
|
const result = getOpenAIReasoningParams(assistant, model)
|
||||||
|
expect(result).toEqual({
|
||||||
|
reasoningEffort: 'medium',
|
||||||
|
reasoningSummary: 'auto'
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should not include reasoning summary for o1-pro', async () => {
|
||||||
|
const { isReasoningModel, isOpenAIDeepResearchModel, isSupportedReasoningEffortOpenAIModel } = await import(
|
||||||
|
'@renderer/config/models'
|
||||||
|
)
|
||||||
|
|
||||||
|
vi.mocked(isReasoningModel).mockReturnValue(true)
|
||||||
|
vi.mocked(isOpenAIDeepResearchModel).mockReturnValue(false)
|
||||||
|
vi.mocked(isSupportedReasoningEffortOpenAIModel).mockReturnValue(true)
|
||||||
|
vi.mocked(getStoreSetting).mockReturnValue({ summaryText: 'off' } as any)
|
||||||
|
|
||||||
|
const model: Model = {
|
||||||
|
id: 'o1-pro',
|
||||||
|
name: 'O1 Pro',
|
||||||
|
provider: SystemProviderIds.openai
|
||||||
|
} as Model
|
||||||
|
|
||||||
|
const assistant: Assistant = {
|
||||||
|
id: 'test',
|
||||||
|
name: 'Test',
|
||||||
|
settings: {
|
||||||
|
reasoning_effort: 'high'
|
||||||
|
}
|
||||||
|
} as Assistant
|
||||||
|
|
||||||
|
const result = getOpenAIReasoningParams(assistant, model)
|
||||||
|
expect(result).toEqual({
|
||||||
|
reasoningEffort: 'high',
|
||||||
|
reasoningSummary: undefined
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should force medium effort for deep research models', async () => {
|
||||||
|
const { isReasoningModel, isOpenAIModel, isOpenAIDeepResearchModel, isSupportedReasoningEffortOpenAIModel } =
|
||||||
|
await import('@renderer/config/models')
|
||||||
|
const { getStoreSetting } = await import('@renderer/hooks/useSettings')
|
||||||
|
|
||||||
|
vi.mocked(isReasoningModel).mockReturnValue(true)
|
||||||
|
vi.mocked(isOpenAIModel).mockReturnValue(true)
|
||||||
|
vi.mocked(isOpenAIDeepResearchModel).mockReturnValue(true)
|
||||||
|
vi.mocked(isSupportedReasoningEffortOpenAIModel).mockReturnValue(true)
|
||||||
|
vi.mocked(getStoreSetting).mockReturnValue({ summaryText: 'off' } as any)
|
||||||
|
|
||||||
|
const model: Model = {
|
||||||
|
id: 'o3-deep-research',
|
||||||
|
name: 'O3 Mini',
|
||||||
|
provider: SystemProviderIds.openai
|
||||||
|
} as Model
|
||||||
|
|
||||||
|
const assistant: Assistant = {
|
||||||
|
id: 'test',
|
||||||
|
name: 'Test',
|
||||||
|
settings: {
|
||||||
|
reasoning_effort: 'high'
|
||||||
|
}
|
||||||
|
} as Assistant
|
||||||
|
|
||||||
|
const result = getOpenAIReasoningParams(assistant, model)
|
||||||
|
expect(result).toEqual({
|
||||||
|
reasoningEffort: 'medium',
|
||||||
|
reasoningSummary: 'off'
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('getAnthropicReasoningParams', () => {
|
||||||
|
it('should return empty for non-reasoning model', async () => {
|
||||||
|
const { isReasoningModel } = await import('@renderer/config/models')
|
||||||
|
|
||||||
|
vi.mocked(isReasoningModel).mockReturnValue(false)
|
||||||
|
|
||||||
|
const model: Model = {
|
||||||
|
id: 'claude-3-5-sonnet',
|
||||||
|
name: 'Claude 3.5 Sonnet',
|
||||||
|
provider: SystemProviderIds.anthropic
|
||||||
|
} as Model
|
||||||
|
|
||||||
|
const assistant: Assistant = {
|
||||||
|
id: 'test',
|
||||||
|
name: 'Test',
|
||||||
|
settings: {}
|
||||||
|
} as Assistant
|
||||||
|
|
||||||
|
const result = getAnthropicReasoningParams(assistant, model)
|
||||||
|
expect(result).toEqual({})
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should return disabled thinking when no reasoning effort', async () => {
|
||||||
|
const { isReasoningModel, isSupportedThinkingTokenClaudeModel } = await import('@renderer/config/models')
|
||||||
|
|
||||||
|
vi.mocked(isReasoningModel).mockReturnValue(true)
|
||||||
|
vi.mocked(isSupportedThinkingTokenClaudeModel).mockReturnValue(false)
|
||||||
|
|
||||||
|
const model: Model = {
|
||||||
|
id: 'claude-3-7-sonnet',
|
||||||
|
name: 'Claude 3.7 Sonnet',
|
||||||
|
provider: SystemProviderIds.anthropic
|
||||||
|
} as Model
|
||||||
|
|
||||||
|
const assistant: Assistant = {
|
||||||
|
id: 'test',
|
||||||
|
name: 'Test',
|
||||||
|
settings: {}
|
||||||
|
} as Assistant
|
||||||
|
|
||||||
|
const result = getAnthropicReasoningParams(assistant, model)
|
||||||
|
expect(result).toEqual({
|
||||||
|
thinking: {
|
||||||
|
type: 'disabled'
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should return enabled thinking with budget for Claude models', async () => {
|
||||||
|
const { isReasoningModel, isSupportedThinkingTokenClaudeModel } = await import('@renderer/config/models')
|
||||||
|
|
||||||
|
vi.mocked(isReasoningModel).mockReturnValue(true)
|
||||||
|
vi.mocked(isSupportedThinkingTokenClaudeModel).mockReturnValue(true)
|
||||||
|
|
||||||
|
const model: Model = {
|
||||||
|
id: 'claude-3-7-sonnet',
|
||||||
|
name: 'Claude 3.7 Sonnet',
|
||||||
|
provider: SystemProviderIds.anthropic
|
||||||
|
} as Model
|
||||||
|
|
||||||
|
const assistant: Assistant = {
|
||||||
|
id: 'test',
|
||||||
|
name: 'Test',
|
||||||
|
settings: {
|
||||||
|
reasoning_effort: 'medium',
|
||||||
|
maxTokens: 4096
|
||||||
|
}
|
||||||
|
} as Assistant
|
||||||
|
|
||||||
|
const result = getAnthropicReasoningParams(assistant, model)
|
||||||
|
expect(result).toEqual({
|
||||||
|
thinking: {
|
||||||
|
type: 'enabled',
|
||||||
|
budgetTokens: 2048
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('getGeminiReasoningParams', () => {
|
||||||
|
it('should return empty for non-reasoning model', async () => {
|
||||||
|
const { isReasoningModel } = await import('@renderer/config/models')
|
||||||
|
|
||||||
|
vi.mocked(isReasoningModel).mockReturnValue(false)
|
||||||
|
|
||||||
|
const model: Model = {
|
||||||
|
id: 'gemini-2.0-flash',
|
||||||
|
name: 'Gemini 2.0 Flash',
|
||||||
|
provider: SystemProviderIds.gemini
|
||||||
|
} as Model
|
||||||
|
|
||||||
|
const assistant: Assistant = {
|
||||||
|
id: 'test',
|
||||||
|
name: 'Test',
|
||||||
|
settings: {}
|
||||||
|
} as Assistant
|
||||||
|
|
||||||
|
const result = getGeminiReasoningParams(assistant, model)
|
||||||
|
expect(result).toEqual({})
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should disable thinking for Flash models without reasoning effort', async () => {
|
||||||
|
const { isReasoningModel, isSupportedThinkingTokenGeminiModel } = await import('@renderer/config/models')
|
||||||
|
|
||||||
|
vi.mocked(isReasoningModel).mockReturnValue(true)
|
||||||
|
vi.mocked(isSupportedThinkingTokenGeminiModel).mockReturnValue(true)
|
||||||
|
|
||||||
|
const model: Model = {
|
||||||
|
id: 'gemini-2.5-flash',
|
||||||
|
name: 'Gemini 2.5 Flash',
|
||||||
|
provider: SystemProviderIds.gemini
|
||||||
|
} as Model
|
||||||
|
|
||||||
|
const assistant: Assistant = {
|
||||||
|
id: 'test',
|
||||||
|
name: 'Test',
|
||||||
|
settings: {}
|
||||||
|
} as Assistant
|
||||||
|
|
||||||
|
const result = getGeminiReasoningParams(assistant, model)
|
||||||
|
expect(result).toEqual({
|
||||||
|
thinkingConfig: {
|
||||||
|
includeThoughts: false,
|
||||||
|
thinkingBudget: 0
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should enable thinking with budget for reasoning effort', async () => {
|
||||||
|
const { isReasoningModel, isSupportedThinkingTokenGeminiModel } = await import('@renderer/config/models')
|
||||||
|
|
||||||
|
vi.mocked(isReasoningModel).mockReturnValue(true)
|
||||||
|
vi.mocked(isSupportedThinkingTokenGeminiModel).mockReturnValue(true)
|
||||||
|
|
||||||
|
const model: Model = {
|
||||||
|
id: 'gemini-2.5-pro',
|
||||||
|
name: 'Gemini 2.5 Pro',
|
||||||
|
provider: SystemProviderIds.gemini
|
||||||
|
} as Model
|
||||||
|
|
||||||
|
const assistant: Assistant = {
|
||||||
|
id: 'test',
|
||||||
|
name: 'Test',
|
||||||
|
settings: {
|
||||||
|
reasoning_effort: 'medium'
|
||||||
|
}
|
||||||
|
} as Assistant
|
||||||
|
|
||||||
|
const result = getGeminiReasoningParams(assistant, model)
|
||||||
|
expect(result).toEqual({
|
||||||
|
thinkingConfig: {
|
||||||
|
thinkingBudget: 16448,
|
||||||
|
includeThoughts: true
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should enable thinking without budget for auto effort ratio > 1', async () => {
|
||||||
|
const { isReasoningModel, isSupportedThinkingTokenGeminiModel } = await import('@renderer/config/models')
|
||||||
|
|
||||||
|
vi.mocked(isReasoningModel).mockReturnValue(true)
|
||||||
|
vi.mocked(isSupportedThinkingTokenGeminiModel).mockReturnValue(true)
|
||||||
|
|
||||||
|
const model: Model = {
|
||||||
|
id: 'gemini-2.5-pro',
|
||||||
|
name: 'Gemini 2.5 Pro',
|
||||||
|
provider: SystemProviderIds.gemini
|
||||||
|
} as Model
|
||||||
|
|
||||||
|
const assistant: Assistant = {
|
||||||
|
id: 'test',
|
||||||
|
name: 'Test',
|
||||||
|
settings: {
|
||||||
|
reasoning_effort: 'auto'
|
||||||
|
}
|
||||||
|
} as Assistant
|
||||||
|
|
||||||
|
const result = getGeminiReasoningParams(assistant, model)
|
||||||
|
expect(result).toEqual({
|
||||||
|
thinkingConfig: {
|
||||||
|
includeThoughts: true
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('getXAIReasoningParams', () => {
|
||||||
|
it('should return empty for non-Grok model', async () => {
|
||||||
|
const { isSupportedReasoningEffortGrokModel } = await import('@renderer/config/models')
|
||||||
|
|
||||||
|
vi.mocked(isSupportedReasoningEffortGrokModel).mockReturnValue(false)
|
||||||
|
|
||||||
|
const model: Model = {
|
||||||
|
id: 'other-model',
|
||||||
|
name: 'Other Model',
|
||||||
|
provider: SystemProviderIds.grok
|
||||||
|
} as Model
|
||||||
|
|
||||||
|
const assistant: Assistant = {
|
||||||
|
id: 'test',
|
||||||
|
name: 'Test',
|
||||||
|
settings: {}
|
||||||
|
} as Assistant
|
||||||
|
|
||||||
|
const result = getXAIReasoningParams(assistant, model)
|
||||||
|
expect(result).toEqual({})
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should return empty when no reasoning effort', async () => {
|
||||||
|
const { isSupportedReasoningEffortGrokModel } = await import('@renderer/config/models')
|
||||||
|
|
||||||
|
vi.mocked(isSupportedReasoningEffortGrokModel).mockReturnValue(true)
|
||||||
|
|
||||||
|
const model: Model = {
|
||||||
|
id: 'grok-2',
|
||||||
|
name: 'Grok 2',
|
||||||
|
provider: SystemProviderIds.grok
|
||||||
|
} as Model
|
||||||
|
|
||||||
|
const assistant: Assistant = {
|
||||||
|
id: 'test',
|
||||||
|
name: 'Test',
|
||||||
|
settings: {}
|
||||||
|
} as Assistant
|
||||||
|
|
||||||
|
const result = getXAIReasoningParams(assistant, model)
|
||||||
|
expect(result).toEqual({})
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should return reasoning effort for Grok models', async () => {
|
||||||
|
const { isSupportedReasoningEffortGrokModel } = await import('@renderer/config/models')
|
||||||
|
|
||||||
|
vi.mocked(isSupportedReasoningEffortGrokModel).mockReturnValue(true)
|
||||||
|
|
||||||
|
const model: Model = {
|
||||||
|
id: 'grok-3',
|
||||||
|
name: 'Grok 3',
|
||||||
|
provider: SystemProviderIds.grok
|
||||||
|
} as Model
|
||||||
|
|
||||||
|
const assistant: Assistant = {
|
||||||
|
id: 'test',
|
||||||
|
name: 'Test',
|
||||||
|
settings: {
|
||||||
|
reasoning_effort: 'high'
|
||||||
|
}
|
||||||
|
} as Assistant
|
||||||
|
|
||||||
|
const result = getXAIReasoningParams(assistant, model)
|
||||||
|
expect(result).toHaveProperty('reasoningEffort')
|
||||||
|
expect(result.reasoningEffort).toBe('high')
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('getBedrockReasoningParams', () => {
|
||||||
|
it('should return empty for non-reasoning model', async () => {
|
||||||
|
const model: Model = {
|
||||||
|
id: 'other-model',
|
||||||
|
name: 'Other Model',
|
||||||
|
provider: 'bedrock'
|
||||||
|
} as Model
|
||||||
|
|
||||||
|
const assistant: Assistant = {
|
||||||
|
id: 'test',
|
||||||
|
name: 'Test',
|
||||||
|
settings: {}
|
||||||
|
} as Assistant
|
||||||
|
|
||||||
|
const result = getBedrockReasoningParams(assistant, model)
|
||||||
|
expect(result).toEqual({})
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should return empty when no reasoning effort', async () => {
|
||||||
|
const model: Model = {
|
||||||
|
id: 'claude-3-7-sonnet',
|
||||||
|
name: 'Claude 3.7 Sonnet',
|
||||||
|
provider: 'bedrock'
|
||||||
|
} as Model
|
||||||
|
|
||||||
|
const assistant: Assistant = {
|
||||||
|
id: 'test',
|
||||||
|
name: 'Test',
|
||||||
|
settings: {}
|
||||||
|
} as Assistant
|
||||||
|
|
||||||
|
const result = getBedrockReasoningParams(assistant, model)
|
||||||
|
expect(result).toEqual({})
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should return reasoning config for Claude models on Bedrock', async () => {
|
||||||
|
const { isReasoningModel, isSupportedThinkingTokenClaudeModel } = await import('@renderer/config/models')
|
||||||
|
|
||||||
|
vi.mocked(isReasoningModel).mockReturnValue(true)
|
||||||
|
vi.mocked(isSupportedThinkingTokenClaudeModel).mockReturnValue(true)
|
||||||
|
|
||||||
|
const model: Model = {
|
||||||
|
id: 'claude-3-7-sonnet',
|
||||||
|
name: 'Claude 3.7 Sonnet',
|
||||||
|
provider: 'bedrock'
|
||||||
|
} as Model
|
||||||
|
|
||||||
|
const assistant: Assistant = {
|
||||||
|
id: 'test',
|
||||||
|
name: 'Test',
|
||||||
|
settings: {
|
||||||
|
reasoning_effort: 'medium',
|
||||||
|
maxTokens: 4096
|
||||||
|
}
|
||||||
|
} as Assistant
|
||||||
|
|
||||||
|
const result = getBedrockReasoningParams(assistant, model)
|
||||||
|
expect(result).toEqual({
|
||||||
|
reasoningConfig: {
|
||||||
|
type: 'enabled',
|
||||||
|
budgetTokens: 2048
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('getCustomParameters', () => {
|
||||||
|
it('should return empty object when no custom parameters', async () => {
|
||||||
|
const assistant: Assistant = {
|
||||||
|
id: 'test',
|
||||||
|
name: 'Test',
|
||||||
|
settings: {}
|
||||||
|
} as Assistant
|
||||||
|
|
||||||
|
const result = getCustomParameters(assistant)
|
||||||
|
expect(result).toEqual({})
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should return custom parameters as key-value pairs', async () => {
|
||||||
|
const assistant: Assistant = {
|
||||||
|
id: 'test',
|
||||||
|
name: 'Test',
|
||||||
|
settings: {
|
||||||
|
customParameters: [
|
||||||
|
{ name: 'param1', value: 'value1', type: 'string' },
|
||||||
|
{ name: 'param2', value: 123, type: 'number' }
|
||||||
|
]
|
||||||
|
}
|
||||||
|
} as Assistant
|
||||||
|
|
||||||
|
const result = getCustomParameters(assistant)
|
||||||
|
expect(result).toEqual({
|
||||||
|
param1: 'value1',
|
||||||
|
param2: 123
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should parse JSON type parameters', async () => {
|
||||||
|
const assistant: Assistant = {
|
||||||
|
id: 'test',
|
||||||
|
name: 'Test',
|
||||||
|
settings: {
|
||||||
|
customParameters: [{ name: 'config', value: '{"key": "value"}', type: 'json' }]
|
||||||
|
}
|
||||||
|
} as Assistant
|
||||||
|
|
||||||
|
const result = getCustomParameters(assistant)
|
||||||
|
expect(result).toEqual({
|
||||||
|
config: { key: 'value' }
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should handle invalid JSON gracefully', async () => {
|
||||||
|
const assistant: Assistant = {
|
||||||
|
id: 'test',
|
||||||
|
name: 'Test',
|
||||||
|
settings: {
|
||||||
|
customParameters: [{ name: 'invalid', value: '{invalid json', type: 'json' }]
|
||||||
|
}
|
||||||
|
} as Assistant
|
||||||
|
|
||||||
|
const result = getCustomParameters(assistant)
|
||||||
|
expect(result).toEqual({
|
||||||
|
invalid: '{invalid json'
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should handle undefined JSON value', async () => {
|
||||||
|
const assistant: Assistant = {
|
||||||
|
id: 'test',
|
||||||
|
name: 'Test',
|
||||||
|
settings: {
|
||||||
|
customParameters: [{ name: 'undef', value: 'undefined', type: 'json' }]
|
||||||
|
}
|
||||||
|
} as Assistant
|
||||||
|
|
||||||
|
const result = getCustomParameters(assistant)
|
||||||
|
expect(result).toEqual({
|
||||||
|
undef: undefined
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should skip parameters with empty names', async () => {
|
||||||
|
const assistant: Assistant = {
|
||||||
|
id: 'test',
|
||||||
|
name: 'Test',
|
||||||
|
settings: {
|
||||||
|
customParameters: [
|
||||||
|
{ name: '', value: 'value1', type: 'string' },
|
||||||
|
{ name: ' ', value: 'value2', type: 'string' },
|
||||||
|
{ name: 'valid', value: 'value3', type: 'string' }
|
||||||
|
]
|
||||||
|
}
|
||||||
|
} as Assistant
|
||||||
|
|
||||||
|
const result = getCustomParameters(assistant)
|
||||||
|
expect(result).toEqual({
|
||||||
|
valid: 'value3'
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
384
src/renderer/src/aiCore/utils/__tests__/websearch.test.ts
Normal file
384
src/renderer/src/aiCore/utils/__tests__/websearch.test.ts
Normal file
@@ -0,0 +1,384 @@
|
|||||||
|
/**
|
||||||
|
* websearch.ts Unit Tests
|
||||||
|
* Tests for web search parameters generation utilities
|
||||||
|
*/
|
||||||
|
|
||||||
|
import type { CherryWebSearchConfig } from '@renderer/store/websearch'
|
||||||
|
import type { Model } from '@renderer/types'
|
||||||
|
import { describe, expect, it, vi } from 'vitest'
|
||||||
|
|
||||||
|
import { buildProviderBuiltinWebSearchConfig, getWebSearchParams } from '../websearch'
|
||||||
|
|
||||||
|
// Mock dependencies
|
||||||
|
vi.mock('@renderer/config/models', () => ({
|
||||||
|
isOpenAIWebSearchChatCompletionOnlyModel: vi.fn((model) => model?.id?.includes('o1-pro') ?? false),
|
||||||
|
isOpenAIDeepResearchModel: vi.fn((model) => model?.id?.includes('o3-mini') ?? false)
|
||||||
|
}))
|
||||||
|
|
||||||
|
vi.mock('@renderer/utils/blacklistMatchPattern', () => ({
|
||||||
|
mapRegexToPatterns: vi.fn((patterns) => patterns || [])
|
||||||
|
}))
|
||||||
|
|
||||||
|
describe('websearch utils', () => {
|
||||||
|
describe('getWebSearchParams', () => {
|
||||||
|
it('should return enhancement params for hunyuan provider', () => {
|
||||||
|
const model: Model = {
|
||||||
|
id: 'hunyuan-model',
|
||||||
|
name: 'Hunyuan Model',
|
||||||
|
provider: 'hunyuan'
|
||||||
|
} as Model
|
||||||
|
|
||||||
|
const result = getWebSearchParams(model)
|
||||||
|
|
||||||
|
expect(result).toEqual({
|
||||||
|
enable_enhancement: true,
|
||||||
|
citation: true,
|
||||||
|
search_info: true
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should return search params for dashscope provider', () => {
|
||||||
|
const model: Model = {
|
||||||
|
id: 'qwen-model',
|
||||||
|
name: 'Qwen Model',
|
||||||
|
provider: 'dashscope'
|
||||||
|
} as Model
|
||||||
|
|
||||||
|
const result = getWebSearchParams(model)
|
||||||
|
|
||||||
|
expect(result).toEqual({
|
||||||
|
enable_search: true,
|
||||||
|
search_options: {
|
||||||
|
forced_search: true
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should return web_search_options for OpenAI web search models', () => {
|
||||||
|
const model: Model = {
|
||||||
|
id: 'o1-pro',
|
||||||
|
name: 'O1 Pro',
|
||||||
|
provider: 'openai'
|
||||||
|
} as Model
|
||||||
|
|
||||||
|
const result = getWebSearchParams(model)
|
||||||
|
|
||||||
|
expect(result).toEqual({
|
||||||
|
web_search_options: {}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should return empty object for other providers', () => {
|
||||||
|
const model: Model = {
|
||||||
|
id: 'gpt-4',
|
||||||
|
name: 'GPT-4',
|
||||||
|
provider: 'openai'
|
||||||
|
} as Model
|
||||||
|
|
||||||
|
const result = getWebSearchParams(model)
|
||||||
|
|
||||||
|
expect(result).toEqual({})
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should return empty object for custom provider', () => {
|
||||||
|
const model: Model = {
|
||||||
|
id: 'custom-model',
|
||||||
|
name: 'Custom Model',
|
||||||
|
provider: 'custom-provider'
|
||||||
|
} as Model
|
||||||
|
|
||||||
|
const result = getWebSearchParams(model)
|
||||||
|
|
||||||
|
expect(result).toEqual({})
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('buildProviderBuiltinWebSearchConfig', () => {
|
||||||
|
const defaultWebSearchConfig: CherryWebSearchConfig = {
|
||||||
|
searchWithTime: true,
|
||||||
|
maxResults: 50,
|
||||||
|
excludeDomains: []
|
||||||
|
}
|
||||||
|
|
||||||
|
describe('openai provider', () => {
|
||||||
|
it('should return low search context size for low maxResults', () => {
|
||||||
|
const config: CherryWebSearchConfig = {
|
||||||
|
searchWithTime: true,
|
||||||
|
maxResults: 20,
|
||||||
|
excludeDomains: []
|
||||||
|
}
|
||||||
|
|
||||||
|
const result = buildProviderBuiltinWebSearchConfig('openai', config)
|
||||||
|
|
||||||
|
expect(result).toEqual({
|
||||||
|
openai: {
|
||||||
|
searchContextSize: 'low'
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should return medium search context size for medium maxResults', () => {
|
||||||
|
const config: CherryWebSearchConfig = {
|
||||||
|
searchWithTime: true,
|
||||||
|
maxResults: 50,
|
||||||
|
excludeDomains: []
|
||||||
|
}
|
||||||
|
|
||||||
|
const result = buildProviderBuiltinWebSearchConfig('openai', config)
|
||||||
|
|
||||||
|
expect(result).toEqual({
|
||||||
|
openai: {
|
||||||
|
searchContextSize: 'medium'
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should return high search context size for high maxResults', () => {
|
||||||
|
const config: CherryWebSearchConfig = {
|
||||||
|
searchWithTime: true,
|
||||||
|
maxResults: 80,
|
||||||
|
excludeDomains: []
|
||||||
|
}
|
||||||
|
|
||||||
|
const result = buildProviderBuiltinWebSearchConfig('openai', config)
|
||||||
|
|
||||||
|
expect(result).toEqual({
|
||||||
|
openai: {
|
||||||
|
searchContextSize: 'high'
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should use medium for deep research models regardless of maxResults', () => {
|
||||||
|
const config: CherryWebSearchConfig = {
|
||||||
|
searchWithTime: true,
|
||||||
|
maxResults: 100,
|
||||||
|
excludeDomains: []
|
||||||
|
}
|
||||||
|
|
||||||
|
const model: Model = {
|
||||||
|
id: 'o3-mini',
|
||||||
|
name: 'O3 Mini',
|
||||||
|
provider: 'openai'
|
||||||
|
} as Model
|
||||||
|
|
||||||
|
const result = buildProviderBuiltinWebSearchConfig('openai', config, model)
|
||||||
|
|
||||||
|
expect(result).toEqual({
|
||||||
|
openai: {
|
||||||
|
searchContextSize: 'medium'
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('openai-chat provider', () => {
|
||||||
|
it('should return correct search context size', () => {
|
||||||
|
const config: CherryWebSearchConfig = {
|
||||||
|
searchWithTime: true,
|
||||||
|
maxResults: 50,
|
||||||
|
excludeDomains: []
|
||||||
|
}
|
||||||
|
|
||||||
|
const result = buildProviderBuiltinWebSearchConfig('openai-chat', config)
|
||||||
|
|
||||||
|
expect(result).toEqual({
|
||||||
|
'openai-chat': {
|
||||||
|
searchContextSize: 'medium'
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should handle deep research models', () => {
|
||||||
|
const config: CherryWebSearchConfig = {
|
||||||
|
searchWithTime: true,
|
||||||
|
maxResults: 100,
|
||||||
|
excludeDomains: []
|
||||||
|
}
|
||||||
|
|
||||||
|
const model: Model = {
|
||||||
|
id: 'o3-mini',
|
||||||
|
name: 'O3 Mini',
|
||||||
|
provider: 'openai'
|
||||||
|
} as Model
|
||||||
|
|
||||||
|
const result = buildProviderBuiltinWebSearchConfig('openai-chat', config, model)
|
||||||
|
|
||||||
|
expect(result).toEqual({
|
||||||
|
'openai-chat': {
|
||||||
|
searchContextSize: 'medium'
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('anthropic provider', () => {
|
||||||
|
it('should return anthropic search options with maxUses', () => {
|
||||||
|
const result = buildProviderBuiltinWebSearchConfig('anthropic', defaultWebSearchConfig)
|
||||||
|
|
||||||
|
expect(result).toEqual({
|
||||||
|
anthropic: {
|
||||||
|
maxUses: 50,
|
||||||
|
blockedDomains: undefined
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should include blockedDomains when excludeDomains provided', () => {
|
||||||
|
const config: CherryWebSearchConfig = {
|
||||||
|
searchWithTime: true,
|
||||||
|
maxResults: 30,
|
||||||
|
excludeDomains: ['example.com', 'test.com']
|
||||||
|
}
|
||||||
|
|
||||||
|
const result = buildProviderBuiltinWebSearchConfig('anthropic', config)
|
||||||
|
|
||||||
|
expect(result).toEqual({
|
||||||
|
anthropic: {
|
||||||
|
maxUses: 30,
|
||||||
|
blockedDomains: ['example.com', 'test.com']
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should not include blockedDomains when empty', () => {
|
||||||
|
const result = buildProviderBuiltinWebSearchConfig('anthropic', defaultWebSearchConfig)
|
||||||
|
|
||||||
|
expect(result).toEqual({
|
||||||
|
anthropic: {
|
||||||
|
maxUses: 50,
|
||||||
|
blockedDomains: undefined
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('xai provider', () => {
|
||||||
|
it('should return xai search options', () => {
|
||||||
|
const result = buildProviderBuiltinWebSearchConfig('xai', defaultWebSearchConfig)
|
||||||
|
|
||||||
|
expect(result).toEqual({
|
||||||
|
xai: {
|
||||||
|
maxSearchResults: 50,
|
||||||
|
returnCitations: true,
|
||||||
|
sources: [{ type: 'web', excludedWebsites: [] }, { type: 'news' }, { type: 'x' }],
|
||||||
|
mode: 'on'
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should limit excluded websites to 5', () => {
|
||||||
|
const config: CherryWebSearchConfig = {
|
||||||
|
searchWithTime: true,
|
||||||
|
maxResults: 40,
|
||||||
|
excludeDomains: ['site1.com', 'site2.com', 'site3.com', 'site4.com', 'site5.com', 'site6.com', 'site7.com']
|
||||||
|
}
|
||||||
|
|
||||||
|
const result = buildProviderBuiltinWebSearchConfig('xai', config)
|
||||||
|
|
||||||
|
expect(result?.xai?.sources).toBeDefined()
|
||||||
|
const webSource = result?.xai?.sources?.[0]
|
||||||
|
if (webSource && webSource.type === 'web') {
|
||||||
|
expect(webSource.excludedWebsites).toHaveLength(5)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should include all sources types', () => {
|
||||||
|
const result = buildProviderBuiltinWebSearchConfig('xai', defaultWebSearchConfig)
|
||||||
|
|
||||||
|
expect(result?.xai?.sources).toHaveLength(3)
|
||||||
|
expect(result?.xai?.sources?.[0].type).toBe('web')
|
||||||
|
expect(result?.xai?.sources?.[1].type).toBe('news')
|
||||||
|
expect(result?.xai?.sources?.[2].type).toBe('x')
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('openrouter provider', () => {
|
||||||
|
it('should return openrouter plugins config', () => {
|
||||||
|
const result = buildProviderBuiltinWebSearchConfig('openrouter', defaultWebSearchConfig)
|
||||||
|
|
||||||
|
expect(result).toEqual({
|
||||||
|
openrouter: {
|
||||||
|
plugins: [
|
||||||
|
{
|
||||||
|
id: 'web',
|
||||||
|
max_results: 50
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should respect custom maxResults', () => {
|
||||||
|
const config: CherryWebSearchConfig = {
|
||||||
|
searchWithTime: true,
|
||||||
|
maxResults: 75,
|
||||||
|
excludeDomains: []
|
||||||
|
}
|
||||||
|
|
||||||
|
const result = buildProviderBuiltinWebSearchConfig('openrouter', config)
|
||||||
|
|
||||||
|
expect(result).toEqual({
|
||||||
|
openrouter: {
|
||||||
|
plugins: [
|
||||||
|
{
|
||||||
|
id: 'web',
|
||||||
|
max_results: 75
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('unsupported provider', () => {
|
||||||
|
it('should return empty object for unsupported provider', () => {
|
||||||
|
const result = buildProviderBuiltinWebSearchConfig('unsupported' as any, defaultWebSearchConfig)
|
||||||
|
|
||||||
|
expect(result).toEqual({})
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should return empty object for google provider', () => {
|
||||||
|
const result = buildProviderBuiltinWebSearchConfig('google', defaultWebSearchConfig)
|
||||||
|
|
||||||
|
expect(result).toEqual({})
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
describe('edge cases', () => {
|
||||||
|
it('should handle maxResults at boundary values', () => {
|
||||||
|
// Test boundary at 33 (low/medium)
|
||||||
|
const config33: CherryWebSearchConfig = { searchWithTime: true, maxResults: 33, excludeDomains: [] }
|
||||||
|
const result33 = buildProviderBuiltinWebSearchConfig('openai', config33)
|
||||||
|
expect(result33?.openai?.searchContextSize).toBe('low')
|
||||||
|
|
||||||
|
// Test boundary at 34 (medium)
|
||||||
|
const config34: CherryWebSearchConfig = { searchWithTime: true, maxResults: 34, excludeDomains: [] }
|
||||||
|
const result34 = buildProviderBuiltinWebSearchConfig('openai', config34)
|
||||||
|
expect(result34?.openai?.searchContextSize).toBe('medium')
|
||||||
|
|
||||||
|
// Test boundary at 66 (medium)
|
||||||
|
const config66: CherryWebSearchConfig = { searchWithTime: true, maxResults: 66, excludeDomains: [] }
|
||||||
|
const result66 = buildProviderBuiltinWebSearchConfig('openai', config66)
|
||||||
|
expect(result66?.openai?.searchContextSize).toBe('medium')
|
||||||
|
|
||||||
|
// Test boundary at 67 (high)
|
||||||
|
const config67: CherryWebSearchConfig = { searchWithTime: true, maxResults: 67, excludeDomains: [] }
|
||||||
|
const result67 = buildProviderBuiltinWebSearchConfig('openai', config67)
|
||||||
|
expect(result67?.openai?.searchContextSize).toBe('high')
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should handle zero maxResults', () => {
|
||||||
|
const config: CherryWebSearchConfig = { searchWithTime: true, maxResults: 0, excludeDomains: [] }
|
||||||
|
const result = buildProviderBuiltinWebSearchConfig('openai', config)
|
||||||
|
expect(result?.openai?.searchContextSize).toBe('low')
|
||||||
|
})
|
||||||
|
|
||||||
|
it('should handle very large maxResults', () => {
|
||||||
|
const config: CherryWebSearchConfig = { searchWithTime: true, maxResults: 1000, excludeDomains: [] }
|
||||||
|
const result = buildProviderBuiltinWebSearchConfig('openai', config)
|
||||||
|
expect(result?.openai?.searchContextSize).toBe('high')
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
|
})
|
||||||
@@ -1,16 +1,38 @@
|
|||||||
|
import type { BedrockProviderOptions } from '@ai-sdk/amazon-bedrock'
|
||||||
|
import type { AnthropicProviderOptions } from '@ai-sdk/anthropic'
|
||||||
|
import type { GoogleGenerativeAIProviderOptions } from '@ai-sdk/google'
|
||||||
|
import type { OpenAIResponsesProviderOptions } from '@ai-sdk/openai'
|
||||||
|
import type { XaiProviderOptions } from '@ai-sdk/xai'
|
||||||
import { baseProviderIdSchema, customProviderIdSchema } from '@cherrystudio/ai-core/provider'
|
import { baseProviderIdSchema, customProviderIdSchema } from '@cherrystudio/ai-core/provider'
|
||||||
import { isOpenAIModel, isQwenMTModel, isSupportFlexServiceTierModel } from '@renderer/config/models'
|
import { loggerService } from '@logger'
|
||||||
import { isSupportServiceTierProvider } from '@renderer/config/providers'
|
|
||||||
import { mapLanguageToQwenMTModel } from '@renderer/config/translate'
|
|
||||||
import type { Assistant, Model, Provider } from '@renderer/types'
|
|
||||||
import {
|
import {
|
||||||
|
getModelSupportedVerbosity,
|
||||||
|
isOpenAIModel,
|
||||||
|
isQwenMTModel,
|
||||||
|
isSupportFlexServiceTierModel,
|
||||||
|
isSupportVerbosityModel
|
||||||
|
} from '@renderer/config/models'
|
||||||
|
import { mapLanguageToQwenMTModel } from '@renderer/config/translate'
|
||||||
|
import { getStoreSetting } from '@renderer/hooks/useSettings'
|
||||||
|
import {
|
||||||
|
type Assistant,
|
||||||
|
type GroqServiceTier,
|
||||||
GroqServiceTiers,
|
GroqServiceTiers,
|
||||||
|
type GroqSystemProvider,
|
||||||
isGroqServiceTier,
|
isGroqServiceTier,
|
||||||
|
isGroqSystemProvider,
|
||||||
isOpenAIServiceTier,
|
isOpenAIServiceTier,
|
||||||
isTranslateAssistant,
|
isTranslateAssistant,
|
||||||
|
type Model,
|
||||||
|
type NotGroqProvider,
|
||||||
|
type OpenAIServiceTier,
|
||||||
OpenAIServiceTiers,
|
OpenAIServiceTiers,
|
||||||
SystemProviderIds
|
type Provider,
|
||||||
|
type ServiceTier
|
||||||
} from '@renderer/types'
|
} from '@renderer/types'
|
||||||
|
import type { OpenAIVerbosity } from '@renderer/types/aiCoreTypes'
|
||||||
|
import { isSupportServiceTierProvider } from '@renderer/utils/provider'
|
||||||
|
import type { JSONValue } from 'ai'
|
||||||
import { t } from 'i18next'
|
import { t } from 'i18next'
|
||||||
|
|
||||||
import { getAiSdkProviderId } from '../provider/factory'
|
import { getAiSdkProviderId } from '../provider/factory'
|
||||||
@@ -26,8 +48,33 @@ import {
|
|||||||
} from './reasoning'
|
} from './reasoning'
|
||||||
import { getWebSearchParams } from './websearch'
|
import { getWebSearchParams } from './websearch'
|
||||||
|
|
||||||
// copy from BaseApiClient.ts
|
const logger = loggerService.withContext('aiCore.utils.options')
|
||||||
const getServiceTier = (model: Model, provider: Provider) => {
|
|
||||||
|
function toOpenAIServiceTier(model: Model, serviceTier: ServiceTier): OpenAIServiceTier {
|
||||||
|
if (
|
||||||
|
!isOpenAIServiceTier(serviceTier) ||
|
||||||
|
(serviceTier === OpenAIServiceTiers.flex && !isSupportFlexServiceTierModel(model))
|
||||||
|
) {
|
||||||
|
return undefined
|
||||||
|
} else {
|
||||||
|
return serviceTier
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
function toGroqServiceTier(model: Model, serviceTier: ServiceTier): GroqServiceTier {
|
||||||
|
if (
|
||||||
|
!isGroqServiceTier(serviceTier) ||
|
||||||
|
(serviceTier === GroqServiceTiers.flex && !isSupportFlexServiceTierModel(model))
|
||||||
|
) {
|
||||||
|
return undefined
|
||||||
|
} else {
|
||||||
|
return serviceTier
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
function getServiceTier<T extends GroqSystemProvider>(model: Model, provider: T): GroqServiceTier
|
||||||
|
function getServiceTier<T extends NotGroqProvider>(model: Model, provider: T): OpenAIServiceTier
|
||||||
|
function getServiceTier<T extends Provider>(model: Model, provider: T): OpenAIServiceTier | GroqServiceTier {
|
||||||
const serviceTierSetting = provider.serviceTier
|
const serviceTierSetting = provider.serviceTier
|
||||||
|
|
||||||
if (!isSupportServiceTierProvider(provider) || !isOpenAIModel(model) || !serviceTierSetting) {
|
if (!isSupportServiceTierProvider(provider) || !isOpenAIModel(model) || !serviceTierSetting) {
|
||||||
@@ -35,24 +82,17 @@ const getServiceTier = (model: Model, provider: Provider) => {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// 处理不同供应商需要 fallback 到默认值的情况
|
// 处理不同供应商需要 fallback 到默认值的情况
|
||||||
if (provider.id === SystemProviderIds.groq) {
|
if (isGroqSystemProvider(provider)) {
|
||||||
if (
|
return toGroqServiceTier(model, serviceTierSetting)
|
||||||
!isGroqServiceTier(serviceTierSetting) ||
|
|
||||||
(serviceTierSetting === GroqServiceTiers.flex && !isSupportFlexServiceTierModel(model))
|
|
||||||
) {
|
|
||||||
return undefined
|
|
||||||
}
|
|
||||||
} else {
|
} else {
|
||||||
// 其他 OpenAI 供应商,假设他们的服务层级设置和 OpenAI 完全相同
|
// 其他 OpenAI 供应商,假设他们的服务层级设置和 OpenAI 完全相同
|
||||||
if (
|
return toOpenAIServiceTier(model, serviceTierSetting)
|
||||||
!isOpenAIServiceTier(serviceTierSetting) ||
|
|
||||||
(serviceTierSetting === OpenAIServiceTiers.flex && !isSupportFlexServiceTierModel(model))
|
|
||||||
) {
|
|
||||||
return undefined
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return serviceTierSetting
|
function getVerbosity(): OpenAIVerbosity {
|
||||||
|
const openAI = getStoreSetting('openAI')
|
||||||
|
return openAI.verbosity
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
@@ -69,12 +109,13 @@ export function buildProviderOptions(
|
|||||||
enableWebSearch: boolean
|
enableWebSearch: boolean
|
||||||
enableGenerateImage: boolean
|
enableGenerateImage: boolean
|
||||||
}
|
}
|
||||||
): Record<string, any> {
|
): Record<string, Record<string, JSONValue>> {
|
||||||
|
logger.debug('buildProviderOptions', { assistant, model, actualProvider, capabilities })
|
||||||
const rawProviderId = getAiSdkProviderId(actualProvider)
|
const rawProviderId = getAiSdkProviderId(actualProvider)
|
||||||
// 构建 provider 特定的选项
|
// 构建 provider 特定的选项
|
||||||
let providerSpecificOptions: Record<string, any> = {}
|
let providerSpecificOptions: Record<string, any> = {}
|
||||||
const serviceTierSetting = getServiceTier(model, actualProvider)
|
const serviceTier = getServiceTier(model, actualProvider)
|
||||||
providerSpecificOptions.serviceTier = serviceTierSetting
|
const textVerbosity = getVerbosity()
|
||||||
// 根据 provider 类型分离构建逻辑
|
// 根据 provider 类型分离构建逻辑
|
||||||
const { data: baseProviderId, success } = baseProviderIdSchema.safeParse(rawProviderId)
|
const { data: baseProviderId, success } = baseProviderIdSchema.safeParse(rawProviderId)
|
||||||
if (success) {
|
if (success) {
|
||||||
@@ -84,14 +125,16 @@ export function buildProviderOptions(
|
|||||||
case 'openai-chat':
|
case 'openai-chat':
|
||||||
case 'azure':
|
case 'azure':
|
||||||
case 'azure-responses':
|
case 'azure-responses':
|
||||||
providerSpecificOptions = {
|
{
|
||||||
...buildOpenAIProviderOptions(assistant, model, capabilities),
|
const options: OpenAIResponsesProviderOptions = buildOpenAIProviderOptions(
|
||||||
serviceTier: serviceTierSetting
|
assistant,
|
||||||
|
model,
|
||||||
|
capabilities,
|
||||||
|
serviceTier
|
||||||
|
)
|
||||||
|
providerSpecificOptions = options
|
||||||
}
|
}
|
||||||
break
|
break
|
||||||
case 'huggingface':
|
|
||||||
providerSpecificOptions = buildOpenAIProviderOptions(assistant, model, capabilities)
|
|
||||||
break
|
|
||||||
case 'anthropic':
|
case 'anthropic':
|
||||||
providerSpecificOptions = buildAnthropicProviderOptions(assistant, model, capabilities)
|
providerSpecificOptions = buildAnthropicProviderOptions(assistant, model, capabilities)
|
||||||
break
|
break
|
||||||
@@ -109,12 +152,19 @@ export function buildProviderOptions(
|
|||||||
// 对于其他 provider,使用通用的构建逻辑
|
// 对于其他 provider,使用通用的构建逻辑
|
||||||
providerSpecificOptions = {
|
providerSpecificOptions = {
|
||||||
...buildGenericProviderOptions(assistant, model, capabilities),
|
...buildGenericProviderOptions(assistant, model, capabilities),
|
||||||
serviceTier: serviceTierSetting
|
serviceTier,
|
||||||
|
textVerbosity
|
||||||
}
|
}
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
case 'cherryin':
|
case 'cherryin':
|
||||||
providerSpecificOptions = buildCherryInProviderOptions(assistant, model, capabilities, actualProvider)
|
providerSpecificOptions = buildCherryInProviderOptions(
|
||||||
|
assistant,
|
||||||
|
model,
|
||||||
|
capabilities,
|
||||||
|
actualProvider,
|
||||||
|
serviceTier
|
||||||
|
)
|
||||||
break
|
break
|
||||||
default:
|
default:
|
||||||
throw new Error(`Unsupported base provider ${baseProviderId}`)
|
throw new Error(`Unsupported base provider ${baseProviderId}`)
|
||||||
@@ -128,17 +178,22 @@ export function buildProviderOptions(
|
|||||||
case 'google-vertex':
|
case 'google-vertex':
|
||||||
providerSpecificOptions = buildGeminiProviderOptions(assistant, model, capabilities)
|
providerSpecificOptions = buildGeminiProviderOptions(assistant, model, capabilities)
|
||||||
break
|
break
|
||||||
|
case 'azure-anthropic':
|
||||||
case 'google-vertex-anthropic':
|
case 'google-vertex-anthropic':
|
||||||
providerSpecificOptions = buildAnthropicProviderOptions(assistant, model, capabilities)
|
providerSpecificOptions = buildAnthropicProviderOptions(assistant, model, capabilities)
|
||||||
break
|
break
|
||||||
case 'bedrock':
|
case 'bedrock':
|
||||||
providerSpecificOptions = buildBedrockProviderOptions(assistant, model, capabilities)
|
providerSpecificOptions = buildBedrockProviderOptions(assistant, model, capabilities)
|
||||||
break
|
break
|
||||||
|
case 'huggingface':
|
||||||
|
providerSpecificOptions = buildOpenAIProviderOptions(assistant, model, capabilities, serviceTier)
|
||||||
|
break
|
||||||
default:
|
default:
|
||||||
// 对于其他 provider,使用通用的构建逻辑
|
// 对于其他 provider,使用通用的构建逻辑
|
||||||
providerSpecificOptions = {
|
providerSpecificOptions = {
|
||||||
...buildGenericProviderOptions(assistant, model, capabilities),
|
...buildGenericProviderOptions(assistant, model, capabilities),
|
||||||
serviceTier: serviceTierSetting
|
serviceTier,
|
||||||
|
textVerbosity
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
@@ -152,13 +207,18 @@ export function buildProviderOptions(
|
|||||||
...getCustomParameters(assistant)
|
...getCustomParameters(assistant)
|
||||||
}
|
}
|
||||||
|
|
||||||
const rawProviderKey =
|
let rawProviderKey =
|
||||||
{
|
{
|
||||||
'google-vertex': 'google',
|
'google-vertex': 'google',
|
||||||
'google-vertex-anthropic': 'anthropic',
|
'google-vertex-anthropic': 'anthropic',
|
||||||
|
'azure-anthropic': 'anthropic',
|
||||||
'ai-gateway': 'gateway'
|
'ai-gateway': 'gateway'
|
||||||
}[rawProviderId] || rawProviderId
|
}[rawProviderId] || rawProviderId
|
||||||
|
|
||||||
|
if (rawProviderKey === 'cherryin') {
|
||||||
|
rawProviderKey = { gemini: 'google' }[actualProvider.type] || actualProvider.type
|
||||||
|
}
|
||||||
|
|
||||||
// 返回 AI Core SDK 要求的格式:{ 'providerId': providerOptions }
|
// 返回 AI Core SDK 要求的格式:{ 'providerId': providerOptions }
|
||||||
return {
|
return {
|
||||||
[rawProviderKey]: providerSpecificOptions
|
[rawProviderKey]: providerSpecificOptions
|
||||||
@@ -175,10 +235,11 @@ function buildOpenAIProviderOptions(
|
|||||||
enableReasoning: boolean
|
enableReasoning: boolean
|
||||||
enableWebSearch: boolean
|
enableWebSearch: boolean
|
||||||
enableGenerateImage: boolean
|
enableGenerateImage: boolean
|
||||||
}
|
},
|
||||||
): Record<string, any> {
|
serviceTier: OpenAIServiceTier
|
||||||
|
): OpenAIResponsesProviderOptions {
|
||||||
const { enableReasoning } = capabilities
|
const { enableReasoning } = capabilities
|
||||||
let providerOptions: Record<string, any> = {}
|
let providerOptions: OpenAIResponsesProviderOptions = {}
|
||||||
// OpenAI 推理参数
|
// OpenAI 推理参数
|
||||||
if (enableReasoning) {
|
if (enableReasoning) {
|
||||||
const reasoningParams = getOpenAIReasoningParams(assistant, model)
|
const reasoningParams = getOpenAIReasoningParams(assistant, model)
|
||||||
@@ -187,6 +248,28 @@ function buildOpenAIProviderOptions(
|
|||||||
...reasoningParams
|
...reasoningParams
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (isSupportVerbosityModel(model)) {
|
||||||
|
const openAI = getStoreSetting<'openAI'>('openAI')
|
||||||
|
const userVerbosity = openAI?.verbosity
|
||||||
|
|
||||||
|
if (userVerbosity && ['low', 'medium', 'high'].includes(userVerbosity)) {
|
||||||
|
const supportedVerbosity = getModelSupportedVerbosity(model)
|
||||||
|
// Use user's verbosity if supported, otherwise use the first supported option
|
||||||
|
const verbosity = supportedVerbosity.includes(userVerbosity) ? userVerbosity : supportedVerbosity[0]
|
||||||
|
|
||||||
|
providerOptions = {
|
||||||
|
...providerOptions,
|
||||||
|
textVerbosity: verbosity
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
providerOptions = {
|
||||||
|
...providerOptions,
|
||||||
|
serviceTier
|
||||||
|
}
|
||||||
|
|
||||||
return providerOptions
|
return providerOptions
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -201,9 +284,9 @@ function buildAnthropicProviderOptions(
|
|||||||
enableWebSearch: boolean
|
enableWebSearch: boolean
|
||||||
enableGenerateImage: boolean
|
enableGenerateImage: boolean
|
||||||
}
|
}
|
||||||
): Record<string, any> {
|
): AnthropicProviderOptions {
|
||||||
const { enableReasoning } = capabilities
|
const { enableReasoning } = capabilities
|
||||||
let providerOptions: Record<string, any> = {}
|
let providerOptions: AnthropicProviderOptions = {}
|
||||||
|
|
||||||
// Anthropic 推理参数
|
// Anthropic 推理参数
|
||||||
if (enableReasoning) {
|
if (enableReasoning) {
|
||||||
@@ -228,9 +311,9 @@ function buildGeminiProviderOptions(
|
|||||||
enableWebSearch: boolean
|
enableWebSearch: boolean
|
||||||
enableGenerateImage: boolean
|
enableGenerateImage: boolean
|
||||||
}
|
}
|
||||||
): Record<string, any> {
|
): GoogleGenerativeAIProviderOptions {
|
||||||
const { enableReasoning, enableGenerateImage } = capabilities
|
const { enableReasoning, enableGenerateImage } = capabilities
|
||||||
let providerOptions: Record<string, any> = {}
|
let providerOptions: GoogleGenerativeAIProviderOptions = {}
|
||||||
|
|
||||||
// Gemini 推理参数
|
// Gemini 推理参数
|
||||||
if (enableReasoning) {
|
if (enableReasoning) {
|
||||||
@@ -259,7 +342,7 @@ function buildXAIProviderOptions(
|
|||||||
enableWebSearch: boolean
|
enableWebSearch: boolean
|
||||||
enableGenerateImage: boolean
|
enableGenerateImage: boolean
|
||||||
}
|
}
|
||||||
): Record<string, any> {
|
): XaiProviderOptions {
|
||||||
const { enableReasoning } = capabilities
|
const { enableReasoning } = capabilities
|
||||||
let providerOptions: Record<string, any> = {}
|
let providerOptions: Record<string, any> = {}
|
||||||
|
|
||||||
@@ -282,16 +365,12 @@ function buildCherryInProviderOptions(
|
|||||||
enableWebSearch: boolean
|
enableWebSearch: boolean
|
||||||
enableGenerateImage: boolean
|
enableGenerateImage: boolean
|
||||||
},
|
},
|
||||||
actualProvider: Provider
|
actualProvider: Provider,
|
||||||
): Record<string, any> {
|
serviceTier: OpenAIServiceTier
|
||||||
const serviceTierSetting = getServiceTier(model, actualProvider)
|
): OpenAIResponsesProviderOptions | AnthropicProviderOptions | GoogleGenerativeAIProviderOptions {
|
||||||
|
|
||||||
switch (actualProvider.type) {
|
switch (actualProvider.type) {
|
||||||
case 'openai':
|
case 'openai':
|
||||||
return {
|
return buildOpenAIProviderOptions(assistant, model, capabilities, serviceTier)
|
||||||
...buildOpenAIProviderOptions(assistant, model, capabilities),
|
|
||||||
serviceTier: serviceTierSetting
|
|
||||||
}
|
|
||||||
|
|
||||||
case 'anthropic':
|
case 'anthropic':
|
||||||
return buildAnthropicProviderOptions(assistant, model, capabilities)
|
return buildAnthropicProviderOptions(assistant, model, capabilities)
|
||||||
@@ -313,9 +392,9 @@ function buildBedrockProviderOptions(
|
|||||||
enableWebSearch: boolean
|
enableWebSearch: boolean
|
||||||
enableGenerateImage: boolean
|
enableGenerateImage: boolean
|
||||||
}
|
}
|
||||||
): Record<string, any> {
|
): BedrockProviderOptions {
|
||||||
const { enableReasoning } = capabilities
|
const { enableReasoning } = capabilities
|
||||||
let providerOptions: Record<string, any> = {}
|
let providerOptions: BedrockProviderOptions = {}
|
||||||
|
|
||||||
if (enableReasoning) {
|
if (enableReasoning) {
|
||||||
const reasoningParams = getBedrockReasoningParams(assistant, model)
|
const reasoningParams = getBedrockReasoningParams(assistant, model)
|
||||||
|
|||||||
@@ -1,3 +1,8 @@
|
|||||||
|
import type { BedrockProviderOptions } from '@ai-sdk/amazon-bedrock'
|
||||||
|
import type { AnthropicProviderOptions } from '@ai-sdk/anthropic'
|
||||||
|
import type { GoogleGenerativeAIProviderOptions } from '@ai-sdk/google'
|
||||||
|
import type { OpenAIResponsesProviderOptions } from '@ai-sdk/openai'
|
||||||
|
import type { XaiProviderOptions } from '@ai-sdk/xai'
|
||||||
import { loggerService } from '@logger'
|
import { loggerService } from '@logger'
|
||||||
import { DEFAULT_MAX_TOKENS } from '@renderer/config/constant'
|
import { DEFAULT_MAX_TOKENS } from '@renderer/config/constant'
|
||||||
import {
|
import {
|
||||||
@@ -7,6 +12,8 @@ import {
|
|||||||
isDeepSeekHybridInferenceModel,
|
isDeepSeekHybridInferenceModel,
|
||||||
isDoubaoSeedAfter251015,
|
isDoubaoSeedAfter251015,
|
||||||
isDoubaoThinkingAutoModel,
|
isDoubaoThinkingAutoModel,
|
||||||
|
isGemini3Model,
|
||||||
|
isGPT51SeriesModel,
|
||||||
isGrok4FastReasoningModel,
|
isGrok4FastReasoningModel,
|
||||||
isGrokReasoningModel,
|
isGrokReasoningModel,
|
||||||
isOpenAIDeepResearchModel,
|
isOpenAIDeepResearchModel,
|
||||||
@@ -27,13 +34,13 @@ import {
|
|||||||
isSupportedThinkingTokenZhipuModel,
|
isSupportedThinkingTokenZhipuModel,
|
||||||
MODEL_SUPPORTED_REASONING_EFFORT
|
MODEL_SUPPORTED_REASONING_EFFORT
|
||||||
} from '@renderer/config/models'
|
} from '@renderer/config/models'
|
||||||
import { isSupportEnableThinkingProvider } from '@renderer/config/providers'
|
|
||||||
import { getStoreSetting } from '@renderer/hooks/useSettings'
|
import { getStoreSetting } from '@renderer/hooks/useSettings'
|
||||||
import { getAssistantSettings, getProviderByModel } from '@renderer/services/AssistantService'
|
import { getAssistantSettings, getProviderByModel } from '@renderer/services/AssistantService'
|
||||||
import type { SettingsState } from '@renderer/store/settings'
|
import type { Assistant, Model, ReasoningEffortOption } from '@renderer/types'
|
||||||
import type { Assistant, Model } from '@renderer/types'
|
|
||||||
import { EFFORT_RATIO, isSystemProvider, SystemProviderIds } from '@renderer/types'
|
import { EFFORT_RATIO, isSystemProvider, SystemProviderIds } from '@renderer/types'
|
||||||
|
import type { OpenAISummaryText } from '@renderer/types/aiCoreTypes'
|
||||||
import type { ReasoningEffortOptionalParams } from '@renderer/types/sdk'
|
import type { ReasoningEffortOptionalParams } from '@renderer/types/sdk'
|
||||||
|
import { isSupportEnableThinkingProvider } from '@renderer/utils/provider'
|
||||||
import { toInteger } from 'lodash'
|
import { toInteger } from 'lodash'
|
||||||
|
|
||||||
const logger = loggerService.withContext('reasoning')
|
const logger = loggerService.withContext('reasoning')
|
||||||
@@ -56,13 +63,20 @@ export function getReasoningEffort(assistant: Assistant, model: Model): Reasonin
|
|||||||
}
|
}
|
||||||
const reasoningEffort = assistant?.settings?.reasoning_effort
|
const reasoningEffort = assistant?.settings?.reasoning_effort
|
||||||
|
|
||||||
if (!reasoningEffort) {
|
// Handle undefined and 'none' reasoningEffort.
|
||||||
|
// TODO: They should be separated.
|
||||||
|
if (!reasoningEffort || reasoningEffort === 'none') {
|
||||||
// openrouter: use reasoning
|
// openrouter: use reasoning
|
||||||
if (model.provider === SystemProviderIds.openrouter) {
|
if (model.provider === SystemProviderIds.openrouter) {
|
||||||
// Don't disable reasoning for Gemini models that support thinking tokens
|
// Don't disable reasoning for Gemini models that support thinking tokens
|
||||||
if (isSupportedThinkingTokenGeminiModel(model) && !GEMINI_FLASH_MODEL_REGEX.test(model.id)) {
|
if (isSupportedThinkingTokenGeminiModel(model) && !GEMINI_FLASH_MODEL_REGEX.test(model.id)) {
|
||||||
return {}
|
return {}
|
||||||
}
|
}
|
||||||
|
// 'none' is not an available value for effort for now.
|
||||||
|
// I think they should resolve this issue soon, so I'll just go ahead and use this value.
|
||||||
|
if (isGPT51SeriesModel(model) && reasoningEffort === 'none') {
|
||||||
|
return { reasoning: { effort: 'none' } }
|
||||||
|
}
|
||||||
// Don't disable reasoning for models that require it
|
// Don't disable reasoning for models that require it
|
||||||
if (
|
if (
|
||||||
isGrokReasoningModel(model) ||
|
isGrokReasoningModel(model) ||
|
||||||
@@ -117,6 +131,13 @@ export function getReasoningEffort(assistant: Assistant, model: Model): Reasonin
|
|||||||
return { thinking: { type: 'disabled' } }
|
return { thinking: { type: 'disabled' } }
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Specially for GPT-5.1. Suppose this is a OpenAI Compatible provider
|
||||||
|
if (isGPT51SeriesModel(model)) {
|
||||||
|
return {
|
||||||
|
reasoningEffort: 'none'
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return {}
|
return {}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -259,6 +280,12 @@ export function getReasoningEffort(assistant: Assistant, model: Model): Reasonin
|
|||||||
|
|
||||||
// gemini series, openai compatible api
|
// gemini series, openai compatible api
|
||||||
if (isSupportedThinkingTokenGeminiModel(model)) {
|
if (isSupportedThinkingTokenGeminiModel(model)) {
|
||||||
|
// https://ai.google.dev/gemini-api/docs/gemini-3?thinking=high#openai_compatibility
|
||||||
|
if (isGemini3Model(model)) {
|
||||||
|
return {
|
||||||
|
reasoning_effort: reasoningEffort
|
||||||
|
}
|
||||||
|
}
|
||||||
if (reasoningEffort === 'auto') {
|
if (reasoningEffort === 'auto') {
|
||||||
return {
|
return {
|
||||||
extra_body: {
|
extra_body: {
|
||||||
@@ -322,10 +349,14 @@ export function getReasoningEffort(assistant: Assistant, model: Model): Reasonin
|
|||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 获取 OpenAI 推理参数
|
* Get OpenAI reasoning parameters
|
||||||
* 从 OpenAIResponseAPIClient 和 OpenAIAPIClient 中提取的逻辑
|
* Extracted from OpenAIResponseAPIClient and OpenAIAPIClient logic
|
||||||
|
* For official OpenAI provider only
|
||||||
*/
|
*/
|
||||||
export function getOpenAIReasoningParams(assistant: Assistant, model: Model): Record<string, any> {
|
export function getOpenAIReasoningParams(
|
||||||
|
assistant: Assistant,
|
||||||
|
model: Model
|
||||||
|
): Pick<OpenAIResponsesProviderOptions, 'reasoningEffort' | 'reasoningSummary'> {
|
||||||
if (!isReasoningModel(model)) {
|
if (!isReasoningModel(model)) {
|
||||||
return {}
|
return {}
|
||||||
}
|
}
|
||||||
@@ -336,6 +367,10 @@ export function getOpenAIReasoningParams(assistant: Assistant, model: Model): Re
|
|||||||
return {}
|
return {}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (isOpenAIDeepResearchModel(model) || reasoningEffort === 'auto') {
|
||||||
|
reasoningEffort = 'medium'
|
||||||
|
}
|
||||||
|
|
||||||
// 非OpenAI模型,但是Provider类型是responses/azure openai的情况
|
// 非OpenAI模型,但是Provider类型是responses/azure openai的情况
|
||||||
if (!isOpenAIModel(model)) {
|
if (!isOpenAIModel(model)) {
|
||||||
return {
|
return {
|
||||||
@@ -343,21 +378,17 @@ export function getOpenAIReasoningParams(assistant: Assistant, model: Model): Re
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
const openAI = getStoreSetting('openAI') as SettingsState['openAI']
|
const openAI = getStoreSetting('openAI')
|
||||||
const summaryText = openAI?.summaryText || 'off'
|
const summaryText = openAI.summaryText
|
||||||
|
|
||||||
let reasoningSummary: string | undefined = undefined
|
let reasoningSummary: OpenAISummaryText = undefined
|
||||||
|
|
||||||
if (summaryText === 'off' || model.id.includes('o1-pro')) {
|
if (model.id.includes('o1-pro')) {
|
||||||
reasoningSummary = undefined
|
reasoningSummary = undefined
|
||||||
} else {
|
} else {
|
||||||
reasoningSummary = summaryText
|
reasoningSummary = summaryText
|
||||||
}
|
}
|
||||||
|
|
||||||
if (isOpenAIDeepResearchModel(model)) {
|
|
||||||
reasoningEffort = 'medium'
|
|
||||||
}
|
|
||||||
|
|
||||||
// OpenAI 推理参数
|
// OpenAI 推理参数
|
||||||
if (isSupportedReasoningEffortOpenAIModel(model)) {
|
if (isSupportedReasoningEffortOpenAIModel(model)) {
|
||||||
return {
|
return {
|
||||||
@@ -369,19 +400,26 @@ export function getOpenAIReasoningParams(assistant: Assistant, model: Model): Re
|
|||||||
return {}
|
return {}
|
||||||
}
|
}
|
||||||
|
|
||||||
export function getAnthropicThinkingBudget(assistant: Assistant, model: Model): number {
|
export function getAnthropicThinkingBudget(
|
||||||
const { maxTokens, reasoning_effort: reasoningEffort } = getAssistantSettings(assistant)
|
maxTokens: number | undefined,
|
||||||
if (reasoningEffort === undefined) {
|
reasoningEffort: string | undefined,
|
||||||
return 0
|
modelId: string
|
||||||
|
): number | undefined {
|
||||||
|
if (reasoningEffort === undefined || reasoningEffort === 'none') {
|
||||||
|
return undefined
|
||||||
}
|
}
|
||||||
const effortRatio = EFFORT_RATIO[reasoningEffort]
|
const effortRatio = EFFORT_RATIO[reasoningEffort]
|
||||||
|
|
||||||
|
const tokenLimit = findTokenLimit(modelId)
|
||||||
|
if (!tokenLimit) {
|
||||||
|
return undefined
|
||||||
|
}
|
||||||
|
|
||||||
const budgetTokens = Math.max(
|
const budgetTokens = Math.max(
|
||||||
1024,
|
1024,
|
||||||
Math.floor(
|
Math.floor(
|
||||||
Math.min(
|
Math.min(
|
||||||
(findTokenLimit(model.id)?.max! - findTokenLimit(model.id)?.min!) * effortRatio +
|
(tokenLimit.max - tokenLimit.min) * effortRatio + tokenLimit.min,
|
||||||
findTokenLimit(model.id)?.min!,
|
|
||||||
(maxTokens || DEFAULT_MAX_TOKENS) * effortRatio
|
(maxTokens || DEFAULT_MAX_TOKENS) * effortRatio
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
@@ -393,14 +431,17 @@ export function getAnthropicThinkingBudget(assistant: Assistant, model: Model):
|
|||||||
* 获取 Anthropic 推理参数
|
* 获取 Anthropic 推理参数
|
||||||
* 从 AnthropicAPIClient 中提取的逻辑
|
* 从 AnthropicAPIClient 中提取的逻辑
|
||||||
*/
|
*/
|
||||||
export function getAnthropicReasoningParams(assistant: Assistant, model: Model): Record<string, any> {
|
export function getAnthropicReasoningParams(
|
||||||
|
assistant: Assistant,
|
||||||
|
model: Model
|
||||||
|
): Pick<AnthropicProviderOptions, 'thinking'> {
|
||||||
if (!isReasoningModel(model)) {
|
if (!isReasoningModel(model)) {
|
||||||
return {}
|
return {}
|
||||||
}
|
}
|
||||||
|
|
||||||
const reasoningEffort = assistant?.settings?.reasoning_effort
|
const reasoningEffort = assistant?.settings?.reasoning_effort
|
||||||
|
|
||||||
if (reasoningEffort === undefined) {
|
if (reasoningEffort === undefined || reasoningEffort === 'none') {
|
||||||
return {
|
return {
|
||||||
thinking: {
|
thinking: {
|
||||||
type: 'disabled'
|
type: 'disabled'
|
||||||
@@ -410,7 +451,8 @@ export function getAnthropicReasoningParams(assistant: Assistant, model: Model):
|
|||||||
|
|
||||||
// Claude 推理参数
|
// Claude 推理参数
|
||||||
if (isSupportedThinkingTokenClaudeModel(model)) {
|
if (isSupportedThinkingTokenClaudeModel(model)) {
|
||||||
const budgetTokens = getAnthropicThinkingBudget(assistant, model)
|
const { maxTokens } = getAssistantSettings(assistant)
|
||||||
|
const budgetTokens = getAnthropicThinkingBudget(maxTokens, reasoningEffort, model.id)
|
||||||
|
|
||||||
return {
|
return {
|
||||||
thinking: {
|
thinking: {
|
||||||
@@ -423,13 +465,31 @@ export function getAnthropicReasoningParams(assistant: Assistant, model: Model):
|
|||||||
return {}
|
return {}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type GoogelThinkingLevel = NonNullable<GoogleGenerativeAIProviderOptions['thinkingConfig']>['thinkingLevel']
|
||||||
|
|
||||||
|
function mapToGeminiThinkingLevel(reasoningEffort: ReasoningEffortOption): GoogelThinkingLevel {
|
||||||
|
switch (reasoningEffort) {
|
||||||
|
case 'low':
|
||||||
|
return 'low'
|
||||||
|
case 'medium':
|
||||||
|
return 'medium'
|
||||||
|
case 'high':
|
||||||
|
return 'high'
|
||||||
|
default:
|
||||||
|
return 'medium'
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* 获取 Gemini 推理参数
|
* 获取 Gemini 推理参数
|
||||||
* 从 GeminiAPIClient 中提取的逻辑
|
* 从 GeminiAPIClient 中提取的逻辑
|
||||||
* 注意:Gemini/GCP 端点所使用的 thinkingBudget 等参数应该按照驼峰命名法传递
|
* 注意:Gemini/GCP 端点所使用的 thinkingBudget 等参数应该按照驼峰命名法传递
|
||||||
* 而在 Google 官方提供的 OpenAI 兼容端点中则使用蛇形命名法 thinking_budget
|
* 而在 Google 官方提供的 OpenAI 兼容端点中则使用蛇形命名法 thinking_budget
|
||||||
*/
|
*/
|
||||||
export function getGeminiReasoningParams(assistant: Assistant, model: Model): Record<string, any> {
|
export function getGeminiReasoningParams(
|
||||||
|
assistant: Assistant,
|
||||||
|
model: Model
|
||||||
|
): Pick<GoogleGenerativeAIProviderOptions, 'thinkingConfig'> {
|
||||||
if (!isReasoningModel(model)) {
|
if (!isReasoningModel(model)) {
|
||||||
return {}
|
return {}
|
||||||
}
|
}
|
||||||
@@ -438,7 +498,7 @@ export function getGeminiReasoningParams(assistant: Assistant, model: Model): Re
|
|||||||
|
|
||||||
// Gemini 推理参数
|
// Gemini 推理参数
|
||||||
if (isSupportedThinkingTokenGeminiModel(model)) {
|
if (isSupportedThinkingTokenGeminiModel(model)) {
|
||||||
if (reasoningEffort === undefined) {
|
if (reasoningEffort === undefined || reasoningEffort === 'none') {
|
||||||
return {
|
return {
|
||||||
thinkingConfig: {
|
thinkingConfig: {
|
||||||
includeThoughts: false,
|
includeThoughts: false,
|
||||||
@@ -447,6 +507,15 @@ export function getGeminiReasoningParams(assistant: Assistant, model: Model): Re
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// https://ai.google.dev/gemini-api/docs/gemini-3?thinking=high#new_api_features_in_gemini_3
|
||||||
|
if (isGemini3Model(model)) {
|
||||||
|
return {
|
||||||
|
thinkingConfig: {
|
||||||
|
thinkingLevel: mapToGeminiThinkingLevel(reasoningEffort)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
const effortRatio = EFFORT_RATIO[reasoningEffort]
|
const effortRatio = EFFORT_RATIO[reasoningEffort]
|
||||||
|
|
||||||
if (effortRatio > 1) {
|
if (effortRatio > 1) {
|
||||||
@@ -478,27 +547,35 @@ export function getGeminiReasoningParams(assistant: Assistant, model: Model): Re
|
|||||||
* @param model - The model being used
|
* @param model - The model being used
|
||||||
* @returns XAI-specific reasoning parameters
|
* @returns XAI-specific reasoning parameters
|
||||||
*/
|
*/
|
||||||
export function getXAIReasoningParams(assistant: Assistant, model: Model): Record<string, any> {
|
export function getXAIReasoningParams(assistant: Assistant, model: Model): Pick<XaiProviderOptions, 'reasoningEffort'> {
|
||||||
if (!isSupportedReasoningEffortGrokModel(model)) {
|
if (!isSupportedReasoningEffortGrokModel(model)) {
|
||||||
return {}
|
return {}
|
||||||
}
|
}
|
||||||
|
|
||||||
const { reasoning_effort: reasoningEffort } = getAssistantSettings(assistant)
|
const { reasoning_effort: reasoningEffort } = getAssistantSettings(assistant)
|
||||||
|
|
||||||
if (!reasoningEffort) {
|
if (!reasoningEffort || reasoningEffort === 'none') {
|
||||||
return {}
|
return {}
|
||||||
}
|
}
|
||||||
|
|
||||||
// For XAI provider Grok models, use reasoningEffort parameter directly
|
switch (reasoningEffort) {
|
||||||
return {
|
case 'auto':
|
||||||
reasoningEffort
|
case 'minimal':
|
||||||
|
case 'medium':
|
||||||
|
return { reasoningEffort: 'low' }
|
||||||
|
case 'low':
|
||||||
|
case 'high':
|
||||||
|
return { reasoningEffort }
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Get Bedrock reasoning parameters
|
* Get Bedrock reasoning parameters
|
||||||
*/
|
*/
|
||||||
export function getBedrockReasoningParams(assistant: Assistant, model: Model): Record<string, any> {
|
export function getBedrockReasoningParams(
|
||||||
|
assistant: Assistant,
|
||||||
|
model: Model
|
||||||
|
): Pick<BedrockProviderOptions, 'reasoningConfig'> {
|
||||||
if (!isReasoningModel(model)) {
|
if (!isReasoningModel(model)) {
|
||||||
return {}
|
return {}
|
||||||
}
|
}
|
||||||
@@ -509,12 +586,21 @@ export function getBedrockReasoningParams(assistant: Assistant, model: Model): R
|
|||||||
return {}
|
return {}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (reasoningEffort === 'none') {
|
||||||
|
return {
|
||||||
|
reasoningConfig: {
|
||||||
|
type: 'disabled'
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Only apply thinking budget for Claude reasoning models
|
// Only apply thinking budget for Claude reasoning models
|
||||||
if (!isSupportedThinkingTokenClaudeModel(model)) {
|
if (!isSupportedThinkingTokenClaudeModel(model)) {
|
||||||
return {}
|
return {}
|
||||||
}
|
}
|
||||||
|
|
||||||
const budgetTokens = getAnthropicThinkingBudget(assistant, model)
|
const { maxTokens } = getAssistantSettings(assistant)
|
||||||
|
const budgetTokens = getAnthropicThinkingBudget(maxTokens, reasoningEffort, model.id)
|
||||||
return {
|
return {
|
||||||
reasoningConfig: {
|
reasoningConfig: {
|
||||||
type: 'enabled',
|
type: 'enabled',
|
||||||
|
|||||||
@@ -47,6 +47,7 @@ export function buildProviderBuiltinWebSearchConfig(
|
|||||||
model?: Model
|
model?: Model
|
||||||
): WebSearchPluginConfig | undefined {
|
): WebSearchPluginConfig | undefined {
|
||||||
switch (providerId) {
|
switch (providerId) {
|
||||||
|
case 'azure-responses':
|
||||||
case 'openai': {
|
case 'openai': {
|
||||||
const searchContextSize = isOpenAIDeepResearchModel(model)
|
const searchContextSize = isOpenAIDeepResearchModel(model)
|
||||||
? 'medium'
|
? 'medium'
|
||||||
|
|||||||
BIN
src/renderer/src/assets/images/models/gpt-5.1-chat.png
Normal file
BIN
src/renderer/src/assets/images/models/gpt-5.1-chat.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 19 KiB |
BIN
src/renderer/src/assets/images/models/gpt-5.1-codex-mini.png
Normal file
BIN
src/renderer/src/assets/images/models/gpt-5.1-codex-mini.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 21 KiB |
BIN
src/renderer/src/assets/images/models/gpt-5.1-codex.png
Normal file
BIN
src/renderer/src/assets/images/models/gpt-5.1-codex.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 20 KiB |
BIN
src/renderer/src/assets/images/models/gpt-5.1.png
Normal file
BIN
src/renderer/src/assets/images/models/gpt-5.1.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 18 KiB |
@@ -1,5 +1,6 @@
|
|||||||
import { ActionIconButton } from '@renderer/components/Buttons'
|
import { ActionIconButton } from '@renderer/components/Buttons'
|
||||||
import NarrowLayout from '@renderer/pages/home/Messages/NarrowLayout'
|
import NarrowLayout from '@renderer/pages/home/Messages/NarrowLayout'
|
||||||
|
import { scrollElementIntoView } from '@renderer/utils'
|
||||||
import { Tooltip } from 'antd'
|
import { Tooltip } from 'antd'
|
||||||
import { debounce } from 'lodash'
|
import { debounce } from 'lodash'
|
||||||
import { CaseSensitive, ChevronDown, ChevronUp, User, WholeWord, X } from 'lucide-react'
|
import { CaseSensitive, ChevronDown, ChevronUp, User, WholeWord, X } from 'lucide-react'
|
||||||
@@ -181,17 +182,14 @@ export const ContentSearch = React.forwardRef<ContentSearchRef, Props>(
|
|||||||
// 3. 将当前项滚动到视图中
|
// 3. 将当前项滚动到视图中
|
||||||
// 获取第一个文本节点的父元素来进行滚动
|
// 获取第一个文本节点的父元素来进行滚动
|
||||||
const parentElement = currentMatchRange.startContainer.parentElement
|
const parentElement = currentMatchRange.startContainer.parentElement
|
||||||
if (shouldScroll) {
|
if (shouldScroll && parentElement) {
|
||||||
parentElement?.scrollIntoView({
|
// 优先在指定的滚动容器内滚动,避免滚动整个页面导致索引错乱/看起来"跳到第一条"
|
||||||
behavior: 'smooth',
|
scrollElementIntoView(parentElement, target)
|
||||||
block: 'center',
|
|
||||||
inline: 'nearest'
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
[allRanges, currentIndex]
|
[allRanges, currentIndex, target]
|
||||||
)
|
)
|
||||||
|
|
||||||
const search = useCallback(
|
const search = useCallback(
|
||||||
|
|||||||
@@ -1,35 +1,120 @@
|
|||||||
|
import 'emoji-picker-element'
|
||||||
|
|
||||||
import TwemojiCountryFlagsWoff2 from '@renderer/assets/fonts/country-flag-fonts/TwemojiCountryFlags.woff2?url'
|
import TwemojiCountryFlagsWoff2 from '@renderer/assets/fonts/country-flag-fonts/TwemojiCountryFlags.woff2?url'
|
||||||
import { useTheme } from '@renderer/context/ThemeProvider'
|
import { useTheme } from '@renderer/context/ThemeProvider'
|
||||||
|
import type { LanguageVarious } from '@renderer/types'
|
||||||
import { polyfillCountryFlagEmojis } from 'country-flag-emoji-polyfill'
|
import { polyfillCountryFlagEmojis } from 'country-flag-emoji-polyfill'
|
||||||
|
// i18n translations from emoji-picker-element
|
||||||
|
import de from 'emoji-picker-element/i18n/de'
|
||||||
|
import en from 'emoji-picker-element/i18n/en'
|
||||||
|
import es from 'emoji-picker-element/i18n/es'
|
||||||
|
import fr from 'emoji-picker-element/i18n/fr'
|
||||||
|
import ja from 'emoji-picker-element/i18n/ja'
|
||||||
|
import pt_PT from 'emoji-picker-element/i18n/pt_PT'
|
||||||
|
import ru_RU from 'emoji-picker-element/i18n/ru_RU'
|
||||||
|
import zh_CN from 'emoji-picker-element/i18n/zh_CN'
|
||||||
|
import type Picker from 'emoji-picker-element/picker'
|
||||||
|
import type { EmojiClickEvent, NativeEmoji } from 'emoji-picker-element/shared'
|
||||||
|
// Emoji data from emoji-picker-element-data (local, no CDN)
|
||||||
|
// Using CLDR format for full multi-language search support (28 languages)
|
||||||
|
import dataDE from 'emoji-picker-element-data/de/cldr/data.json?url'
|
||||||
|
import dataEN from 'emoji-picker-element-data/en/cldr/data.json?url'
|
||||||
|
import dataES from 'emoji-picker-element-data/es/cldr/data.json?url'
|
||||||
|
import dataFR from 'emoji-picker-element-data/fr/cldr/data.json?url'
|
||||||
|
import dataJA from 'emoji-picker-element-data/ja/cldr/data.json?url'
|
||||||
|
import dataPT from 'emoji-picker-element-data/pt/cldr/data.json?url'
|
||||||
|
import dataRU from 'emoji-picker-element-data/ru/cldr/data.json?url'
|
||||||
|
import dataZH from 'emoji-picker-element-data/zh/cldr/data.json?url'
|
||||||
|
import dataZH_HANT from 'emoji-picker-element-data/zh-hant/cldr/data.json?url'
|
||||||
import type { FC } from 'react'
|
import type { FC } from 'react'
|
||||||
import { useEffect, useRef } from 'react'
|
import { useEffect, useRef } from 'react'
|
||||||
|
import { useTranslation } from 'react-i18next'
|
||||||
|
|
||||||
interface Props {
|
interface Props {
|
||||||
onEmojiClick: (emoji: string) => void
|
onEmojiClick: (emoji: string) => void
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Mapping from app locale to emoji-picker-element i18n
|
||||||
|
const i18nMap: Record<LanguageVarious, typeof en> = {
|
||||||
|
'en-US': en,
|
||||||
|
'zh-CN': zh_CN,
|
||||||
|
'zh-TW': zh_CN, // Closest available
|
||||||
|
'de-DE': de,
|
||||||
|
'el-GR': en, // No Greek available, fallback to English
|
||||||
|
'es-ES': es,
|
||||||
|
'fr-FR': fr,
|
||||||
|
'ja-JP': ja,
|
||||||
|
'pt-PT': pt_PT,
|
||||||
|
'ru-RU': ru_RU
|
||||||
|
}
|
||||||
|
|
||||||
|
// Mapping from app locale to emoji data URL
|
||||||
|
// Using CLDR format provides native language search support for all locales
|
||||||
|
const dataSourceMap: Record<LanguageVarious, string> = {
|
||||||
|
'en-US': dataEN,
|
||||||
|
'zh-CN': dataZH,
|
||||||
|
'zh-TW': dataZH_HANT,
|
||||||
|
'de-DE': dataDE,
|
||||||
|
'el-GR': dataEN, // No Greek CLDR available, fallback to English
|
||||||
|
'es-ES': dataES,
|
||||||
|
'fr-FR': dataFR,
|
||||||
|
'ja-JP': dataJA,
|
||||||
|
'pt-PT': dataPT,
|
||||||
|
'ru-RU': dataRU
|
||||||
|
}
|
||||||
|
|
||||||
|
// Mapping from app locale to emoji-picker-element locale string
|
||||||
|
// Must match the data source locale for proper IndexedDB caching
|
||||||
|
const localeMap: Record<LanguageVarious, string> = {
|
||||||
|
'en-US': 'en',
|
||||||
|
'zh-CN': 'zh',
|
||||||
|
'zh-TW': 'zh-hant',
|
||||||
|
'de-DE': 'de',
|
||||||
|
'el-GR': 'en',
|
||||||
|
'es-ES': 'es',
|
||||||
|
'fr-FR': 'fr',
|
||||||
|
'ja-JP': 'ja',
|
||||||
|
'pt-PT': 'pt',
|
||||||
|
'ru-RU': 'ru'
|
||||||
|
}
|
||||||
|
|
||||||
const EmojiPicker: FC<Props> = ({ onEmojiClick }) => {
|
const EmojiPicker: FC<Props> = ({ onEmojiClick }) => {
|
||||||
const { theme } = useTheme()
|
const { theme } = useTheme()
|
||||||
const ref = useRef<HTMLDivElement>(null)
|
const { i18n } = useTranslation()
|
||||||
|
const ref = useRef<Picker>(null)
|
||||||
|
const currentLocale = i18n.language as LanguageVarious
|
||||||
|
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
polyfillCountryFlagEmojis('Twemoji Mozilla', TwemojiCountryFlagsWoff2)
|
polyfillCountryFlagEmojis('Twemoji Mozilla', TwemojiCountryFlagsWoff2)
|
||||||
}, [])
|
}, [])
|
||||||
|
|
||||||
|
// Configure picker with i18n and dataSource
|
||||||
useEffect(() => {
|
useEffect(() => {
|
||||||
const refValue = ref.current
|
const picker = ref.current
|
||||||
|
if (picker) {
|
||||||
|
picker.i18n = i18nMap[currentLocale] || en
|
||||||
|
picker.dataSource = dataSourceMap[currentLocale] || dataEN
|
||||||
|
picker.locale = localeMap[currentLocale] || 'en'
|
||||||
|
}
|
||||||
|
}, [currentLocale])
|
||||||
|
|
||||||
if (refValue) {
|
useEffect(() => {
|
||||||
const handleEmojiClick = (event: any) => {
|
const picker = ref.current
|
||||||
|
|
||||||
|
if (picker) {
|
||||||
|
const handleEmojiClick = (event: EmojiClickEvent) => {
|
||||||
event.stopPropagation()
|
event.stopPropagation()
|
||||||
onEmojiClick(event.detail.unicode || event.detail.emoji.unicode)
|
const { detail } = event
|
||||||
|
// Use detail.unicode (processed with skin tone) or fallback to emoji's unicode for native emoji
|
||||||
|
const unicode = detail.unicode || ('unicode' in detail.emoji ? (detail.emoji as NativeEmoji).unicode : '')
|
||||||
|
onEmojiClick(unicode)
|
||||||
}
|
}
|
||||||
// 添加事件监听器
|
// 添加事件监听器
|
||||||
refValue.addEventListener('emoji-click', handleEmojiClick)
|
picker.addEventListener('emoji-click', handleEmojiClick)
|
||||||
|
|
||||||
// 清理事件监听器
|
// 清理事件监听器
|
||||||
return () => {
|
return () => {
|
||||||
refValue.removeEventListener('emoji-click', handleEmojiClick)
|
picker.removeEventListener('emoji-click', handleEmojiClick)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
|
|||||||
157
src/renderer/src/components/MCPUIRenderer/MCPUIRenderer.tsx
Normal file
157
src/renderer/src/components/MCPUIRenderer/MCPUIRenderer.tsx
Normal file
@@ -0,0 +1,157 @@
|
|||||||
|
import { loggerService } from '@logger'
|
||||||
|
import type { UIActionResult } from '@mcp-ui/client'
|
||||||
|
import { UIResourceRenderer } from '@mcp-ui/client'
|
||||||
|
import type { EmbeddedResource } from '@modelcontextprotocol/sdk/types.js'
|
||||||
|
import { isUIResource } from '@renderer/types'
|
||||||
|
import type { FC } from 'react'
|
||||||
|
import { useCallback, useState } from 'react'
|
||||||
|
import { useTranslation } from 'react-i18next'
|
||||||
|
import styled from 'styled-components'
|
||||||
|
|
||||||
|
const logger = loggerService.withContext('MCPUIRenderer')
|
||||||
|
|
||||||
|
interface Props {
|
||||||
|
resource: EmbeddedResource
|
||||||
|
serverId?: string
|
||||||
|
serverName?: string
|
||||||
|
onToolCall?: (toolName: string, params: any) => Promise<any>
|
||||||
|
}
|
||||||
|
|
||||||
|
const MCPUIRenderer: FC<Props> = ({ resource, onToolCall }) => {
|
||||||
|
const { t } = useTranslation()
|
||||||
|
const [error] = useState<string | null>(null)
|
||||||
|
|
||||||
|
const handleUIAction = useCallback(
|
||||||
|
async (result: UIActionResult): Promise<any> => {
|
||||||
|
logger.debug('UI Action received:', result)
|
||||||
|
|
||||||
|
try {
|
||||||
|
switch (result.type) {
|
||||||
|
case 'tool': {
|
||||||
|
// Handle tool call from UI
|
||||||
|
if (onToolCall) {
|
||||||
|
const { toolName, params } = result.payload
|
||||||
|
logger.info(`UI requesting tool call: ${toolName}`, { params })
|
||||||
|
const response = await onToolCall(toolName, params)
|
||||||
|
|
||||||
|
// Check if the response contains a UIResource
|
||||||
|
try {
|
||||||
|
if (response && response.content && Array.isArray(response.content)) {
|
||||||
|
const firstContent = response.content[0]
|
||||||
|
if (firstContent && firstContent.type === 'text' && firstContent.text) {
|
||||||
|
const parsedText = JSON.parse(firstContent.text)
|
||||||
|
if (isUIResource(parsedText)) {
|
||||||
|
// Return the UIResource directly for rendering in the iframe
|
||||||
|
logger.info('Tool response contains UIResource:', { uri: parsedText.resource.uri })
|
||||||
|
return { status: 'success', data: parsedText }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} catch (parseError) {
|
||||||
|
// Not a UIResource, return the original response
|
||||||
|
logger.debug('Tool response is not a UIResource')
|
||||||
|
}
|
||||||
|
|
||||||
|
return { status: 'success', data: response }
|
||||||
|
} else {
|
||||||
|
logger.warn('Tool call requested but no handler provided')
|
||||||
|
return { status: 'error', message: 'Tool call handler not available' }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
case 'intent': {
|
||||||
|
// Handle user intent
|
||||||
|
logger.info('UI intent:', result.payload)
|
||||||
|
window.toast.info(t('message.mcp.ui.intent_received'))
|
||||||
|
return { status: 'acknowledged' }
|
||||||
|
}
|
||||||
|
|
||||||
|
case 'notify': {
|
||||||
|
// Handle notification from UI
|
||||||
|
logger.info('UI notification:', result.payload)
|
||||||
|
window.toast.info(result.payload.message || t('message.mcp.ui.notification'))
|
||||||
|
return { status: 'acknowledged' }
|
||||||
|
}
|
||||||
|
|
||||||
|
case 'prompt': {
|
||||||
|
// Handle prompt request from UI
|
||||||
|
logger.info('UI prompt request:', result.payload)
|
||||||
|
// TODO: Integrate with prompt system
|
||||||
|
return { status: 'error', message: 'Prompt execution not yet implemented' }
|
||||||
|
}
|
||||||
|
|
||||||
|
case 'link': {
|
||||||
|
// Handle navigation request
|
||||||
|
const { url } = result.payload
|
||||||
|
logger.info('UI navigation request:', { url })
|
||||||
|
window.open(url, '_blank')
|
||||||
|
return { status: 'acknowledged' }
|
||||||
|
}
|
||||||
|
|
||||||
|
default:
|
||||||
|
logger.warn('Unknown UI action type:', { result })
|
||||||
|
return { status: 'error', message: 'Unknown action type' }
|
||||||
|
}
|
||||||
|
} catch (err) {
|
||||||
|
logger.error('Error handling UI action:', err as Error)
|
||||||
|
return {
|
||||||
|
status: 'error',
|
||||||
|
message: err instanceof Error ? err.message : 'Unknown error'
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
[onToolCall, t]
|
||||||
|
)
|
||||||
|
|
||||||
|
if (error) {
|
||||||
|
return (
|
||||||
|
<ErrorContainer>
|
||||||
|
<ErrorTitle>{t('message.mcp.ui.error')}</ErrorTitle>
|
||||||
|
<ErrorMessage>{error}</ErrorMessage>
|
||||||
|
</ErrorContainer>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
return (
|
||||||
|
<UIContainer>
|
||||||
|
<UIResourceRenderer resource={resource} onUIAction={handleUIAction} />
|
||||||
|
</UIContainer>
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
const UIContainer = styled.div`
|
||||||
|
width: 100%;
|
||||||
|
min-height: 400px;
|
||||||
|
border-radius: 8px;
|
||||||
|
overflow: hidden;
|
||||||
|
background: var(--color-background);
|
||||||
|
border: 1px solid var(--color-border);
|
||||||
|
|
||||||
|
iframe {
|
||||||
|
width: 100%;
|
||||||
|
border: none;
|
||||||
|
min-height: 400px;
|
||||||
|
height: 600px;
|
||||||
|
}
|
||||||
|
`
|
||||||
|
|
||||||
|
const ErrorContainer = styled.div`
|
||||||
|
padding: 16px;
|
||||||
|
border-radius: 8px;
|
||||||
|
background: var(--color-error-bg, #fee);
|
||||||
|
border: 1px solid var(--color-error-border, #fcc);
|
||||||
|
color: var(--color-error-text, #c33);
|
||||||
|
`
|
||||||
|
|
||||||
|
const ErrorTitle = styled.div`
|
||||||
|
font-weight: 600;
|
||||||
|
margin-bottom: 8px;
|
||||||
|
font-size: 14px;
|
||||||
|
`
|
||||||
|
|
||||||
|
const ErrorMessage = styled.div`
|
||||||
|
font-size: 13px;
|
||||||
|
opacity: 0.9;
|
||||||
|
`
|
||||||
|
|
||||||
|
export default MCPUIRenderer
|
||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user