Compare commits
5 Commits
feat/mcp-u
...
copilot/fi
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
3093a9e5d0 | ||
|
|
3274723b1e | ||
|
|
5c724a03a6 | ||
|
|
a95e776699 | ||
|
|
be99f4df71 |
2
.github/workflows/auto-i18n.yml
vendored
2
.github/workflows/auto-i18n.yml
vendored
@@ -77,7 +77,7 @@ jobs:
|
||||
with:
|
||||
token: ${{ secrets.GITHUB_TOKEN }} # Use the built-in GITHUB_TOKEN for bot actions
|
||||
commit-message: "feat(bot): Weekly automated script run"
|
||||
title: "🤖 Weekly Auto I18N Sync: ${{ env.CURRENT_DATE }}"
|
||||
title: "🤖 Weekly Automated Update: ${{ env.CURRENT_DATE }}"
|
||||
body: |
|
||||
This PR includes changes generated by the weekly auto i18n.
|
||||
Review the changes before merging.
|
||||
|
||||
152
.yarn/patches/@ai-sdk-google-npm-2.0.36-6f3cc06026.patch
vendored
Normal file
152
.yarn/patches/@ai-sdk-google-npm-2.0.36-6f3cc06026.patch
vendored
Normal file
@@ -0,0 +1,152 @@
|
||||
diff --git a/dist/index.js b/dist/index.js
|
||||
index c2ef089c42e13a8ee4a833899a415564130e5d79..75efa7baafb0f019fb44dd50dec1641eee8879e7 100644
|
||||
--- a/dist/index.js
|
||||
+++ b/dist/index.js
|
||||
@@ -471,7 +471,7 @@ function convertToGoogleGenerativeAIMessages(prompt, options) {
|
||||
|
||||
// src/get-model-path.ts
|
||||
function getModelPath(modelId) {
|
||||
- return modelId.includes("/") ? modelId : `models/${modelId}`;
|
||||
+ return modelId.includes("models/") ? modelId : `models/${modelId}`;
|
||||
}
|
||||
|
||||
// src/google-generative-ai-options.ts
|
||||
diff --git a/dist/index.mjs b/dist/index.mjs
|
||||
index d75c0cc13c41192408c1f3f2d29d76a7bffa6268..ada730b8cb97d9b7d4cb32883a1d1ff416404d9b 100644
|
||||
--- a/dist/index.mjs
|
||||
+++ b/dist/index.mjs
|
||||
@@ -477,7 +477,7 @@ function convertToGoogleGenerativeAIMessages(prompt, options) {
|
||||
|
||||
// src/get-model-path.ts
|
||||
function getModelPath(modelId) {
|
||||
- return modelId.includes("/") ? modelId : `models/${modelId}`;
|
||||
+ return modelId.includes("models/") ? modelId : `models/${modelId}`;
|
||||
}
|
||||
|
||||
// src/google-generative-ai-options.ts
|
||||
diff --git a/dist/internal/index.js b/dist/internal/index.js
|
||||
index 277cac8dc734bea2fb4f3e9a225986b402b24f48..bb704cd79e602eb8b0cee1889e42497d59ccdb7a 100644
|
||||
--- a/dist/internal/index.js
|
||||
+++ b/dist/internal/index.js
|
||||
@@ -432,7 +432,15 @@ function prepareTools({
|
||||
var _a;
|
||||
tools = (tools == null ? void 0 : tools.length) ? tools : void 0;
|
||||
const toolWarnings = [];
|
||||
- const isGemini2 = modelId.includes("gemini-2");
|
||||
+ // These changes could be safely removed when @ai-sdk/google v3 released.
|
||||
+ const isLatest = (
|
||||
+ [
|
||||
+ 'gemini-flash-latest',
|
||||
+ 'gemini-flash-lite-latest',
|
||||
+ 'gemini-pro-latest',
|
||||
+ ]
|
||||
+ ).some(id => id === modelId);
|
||||
+ const isGemini2OrNewer = modelId.includes("gemini-2") || modelId.includes("gemini-3") || isLatest;
|
||||
const supportsDynamicRetrieval = modelId.includes("gemini-1.5-flash") && !modelId.includes("-8b");
|
||||
const supportsFileSearch = modelId.includes("gemini-2.5");
|
||||
if (tools == null) {
|
||||
@@ -458,7 +466,7 @@ function prepareTools({
|
||||
providerDefinedTools.forEach((tool) => {
|
||||
switch (tool.id) {
|
||||
case "google.google_search":
|
||||
- if (isGemini2) {
|
||||
+ if (isGemini2OrNewer) {
|
||||
googleTools2.push({ googleSearch: {} });
|
||||
} else if (supportsDynamicRetrieval) {
|
||||
googleTools2.push({
|
||||
@@ -474,7 +482,7 @@ function prepareTools({
|
||||
}
|
||||
break;
|
||||
case "google.url_context":
|
||||
- if (isGemini2) {
|
||||
+ if (isGemini2OrNewer) {
|
||||
googleTools2.push({ urlContext: {} });
|
||||
} else {
|
||||
toolWarnings.push({
|
||||
@@ -485,7 +493,7 @@ function prepareTools({
|
||||
}
|
||||
break;
|
||||
case "google.code_execution":
|
||||
- if (isGemini2) {
|
||||
+ if (isGemini2OrNewer) {
|
||||
googleTools2.push({ codeExecution: {} });
|
||||
} else {
|
||||
toolWarnings.push({
|
||||
@@ -507,7 +515,7 @@ function prepareTools({
|
||||
}
|
||||
break;
|
||||
case "google.vertex_rag_store":
|
||||
- if (isGemini2) {
|
||||
+ if (isGemini2OrNewer) {
|
||||
googleTools2.push({
|
||||
retrieval: {
|
||||
vertex_rag_store: {
|
||||
diff --git a/dist/internal/index.mjs b/dist/internal/index.mjs
|
||||
index 03b7cc591be9b58bcc2e775a96740d9f98862a10..347d2c12e1cee79f0f8bb258f3844fb0522a6485 100644
|
||||
--- a/dist/internal/index.mjs
|
||||
+++ b/dist/internal/index.mjs
|
||||
@@ -424,7 +424,15 @@ function prepareTools({
|
||||
var _a;
|
||||
tools = (tools == null ? void 0 : tools.length) ? tools : void 0;
|
||||
const toolWarnings = [];
|
||||
- const isGemini2 = modelId.includes("gemini-2");
|
||||
+ // These changes could be safely removed when @ai-sdk/google v3 released.
|
||||
+ const isLatest = (
|
||||
+ [
|
||||
+ 'gemini-flash-latest',
|
||||
+ 'gemini-flash-lite-latest',
|
||||
+ 'gemini-pro-latest',
|
||||
+ ]
|
||||
+ ).some(id => id === modelId);
|
||||
+ const isGemini2OrNewer = modelId.includes("gemini-2") || modelId.includes("gemini-3") || isLatest;
|
||||
const supportsDynamicRetrieval = modelId.includes("gemini-1.5-flash") && !modelId.includes("-8b");
|
||||
const supportsFileSearch = modelId.includes("gemini-2.5");
|
||||
if (tools == null) {
|
||||
@@ -450,7 +458,7 @@ function prepareTools({
|
||||
providerDefinedTools.forEach((tool) => {
|
||||
switch (tool.id) {
|
||||
case "google.google_search":
|
||||
- if (isGemini2) {
|
||||
+ if (isGemini2OrNewer) {
|
||||
googleTools2.push({ googleSearch: {} });
|
||||
} else if (supportsDynamicRetrieval) {
|
||||
googleTools2.push({
|
||||
@@ -466,7 +474,7 @@ function prepareTools({
|
||||
}
|
||||
break;
|
||||
case "google.url_context":
|
||||
- if (isGemini2) {
|
||||
+ if (isGemini2OrNewer) {
|
||||
googleTools2.push({ urlContext: {} });
|
||||
} else {
|
||||
toolWarnings.push({
|
||||
@@ -477,7 +485,7 @@ function prepareTools({
|
||||
}
|
||||
break;
|
||||
case "google.code_execution":
|
||||
- if (isGemini2) {
|
||||
+ if (isGemini2OrNewer) {
|
||||
googleTools2.push({ codeExecution: {} });
|
||||
} else {
|
||||
toolWarnings.push({
|
||||
@@ -499,7 +507,7 @@ function prepareTools({
|
||||
}
|
||||
break;
|
||||
case "google.vertex_rag_store":
|
||||
- if (isGemini2) {
|
||||
+ if (isGemini2OrNewer) {
|
||||
googleTools2.push({
|
||||
retrieval: {
|
||||
vertex_rag_store: {
|
||||
@@ -1434,9 +1442,7 @@ var googleTools = {
|
||||
vertexRagStore
|
||||
};
|
||||
export {
|
||||
- GoogleGenerativeAILanguageModel,
|
||||
getGroundingMetadataSchema,
|
||||
- getUrlContextMetadataSchema,
|
||||
- googleTools
|
||||
+ getUrlContextMetadataSchema, GoogleGenerativeAILanguageModel, googleTools
|
||||
};
|
||||
//# sourceMappingURL=index.mjs.map
|
||||
\ No newline at end of file
|
||||
@@ -1,26 +0,0 @@
|
||||
diff --git a/dist/index.js b/dist/index.js
|
||||
index dc7b74ba55337c491cdf1ab3e39ca68cc4187884..ace8c90591288e42c2957e93c9bf7984f1b22444 100644
|
||||
--- a/dist/index.js
|
||||
+++ b/dist/index.js
|
||||
@@ -472,7 +472,7 @@ function convertToGoogleGenerativeAIMessages(prompt, options) {
|
||||
|
||||
// src/get-model-path.ts
|
||||
function getModelPath(modelId) {
|
||||
- return modelId.includes("/") ? modelId : `models/${modelId}`;
|
||||
+ return modelId.includes("models/") ? modelId : `models/${modelId}`;
|
||||
}
|
||||
|
||||
// src/google-generative-ai-options.ts
|
||||
diff --git a/dist/index.mjs b/dist/index.mjs
|
||||
index 8390439c38cb7eaeb52080862cd6f4c58509e67c..a7647f2e11700dff7e1c8d4ae8f99d3637010733 100644
|
||||
--- a/dist/index.mjs
|
||||
+++ b/dist/index.mjs
|
||||
@@ -478,7 +478,7 @@ function convertToGoogleGenerativeAIMessages(prompt, options) {
|
||||
|
||||
// src/get-model-path.ts
|
||||
function getModelPath(modelId) {
|
||||
- return modelId.includes("/") ? modelId : `models/${modelId}`;
|
||||
+ return modelId.includes("models/") ? modelId : `models/${modelId}`;
|
||||
}
|
||||
|
||||
// src/google-generative-ai-options.ts
|
||||
131
.yarn/patches/@ai-sdk-huggingface-npm-0.0.8-d4d0aaac93.patch
vendored
Normal file
131
.yarn/patches/@ai-sdk-huggingface-npm-0.0.8-d4d0aaac93.patch
vendored
Normal file
@@ -0,0 +1,131 @@
|
||||
diff --git a/dist/index.mjs b/dist/index.mjs
|
||||
index b3f018730a93639aad7c203f15fb1aeb766c73f4..ade2a43d66e9184799d072153df61ef7be4ea110 100644
|
||||
--- a/dist/index.mjs
|
||||
+++ b/dist/index.mjs
|
||||
@@ -296,7 +296,14 @@ var HuggingFaceResponsesLanguageModel = class {
|
||||
metadata: huggingfaceOptions == null ? void 0 : huggingfaceOptions.metadata,
|
||||
instructions: huggingfaceOptions == null ? void 0 : huggingfaceOptions.instructions,
|
||||
...preparedTools && { tools: preparedTools },
|
||||
- ...preparedToolChoice && { tool_choice: preparedToolChoice }
|
||||
+ ...preparedToolChoice && { tool_choice: preparedToolChoice },
|
||||
+ ...(huggingfaceOptions?.reasoningEffort != null && {
|
||||
+ reasoning: {
|
||||
+ ...(huggingfaceOptions?.reasoningEffort != null && {
|
||||
+ effort: huggingfaceOptions.reasoningEffort,
|
||||
+ }),
|
||||
+ },
|
||||
+ }),
|
||||
};
|
||||
return { args: baseArgs, warnings };
|
||||
}
|
||||
@@ -365,6 +372,20 @@ var HuggingFaceResponsesLanguageModel = class {
|
||||
}
|
||||
break;
|
||||
}
|
||||
+ case 'reasoning': {
|
||||
+ for (const contentPart of part.content) {
|
||||
+ content.push({
|
||||
+ type: 'reasoning',
|
||||
+ text: contentPart.text,
|
||||
+ providerMetadata: {
|
||||
+ huggingface: {
|
||||
+ itemId: part.id,
|
||||
+ },
|
||||
+ },
|
||||
+ });
|
||||
+ }
|
||||
+ break;
|
||||
+ }
|
||||
case "mcp_call": {
|
||||
content.push({
|
||||
type: "tool-call",
|
||||
@@ -519,6 +540,11 @@ var HuggingFaceResponsesLanguageModel = class {
|
||||
id: value.item.call_id,
|
||||
toolName: value.item.name
|
||||
});
|
||||
+ } else if (value.item.type === 'reasoning') {
|
||||
+ controller.enqueue({
|
||||
+ type: 'reasoning-start',
|
||||
+ id: value.item.id,
|
||||
+ });
|
||||
}
|
||||
return;
|
||||
}
|
||||
@@ -570,6 +596,22 @@ var HuggingFaceResponsesLanguageModel = class {
|
||||
});
|
||||
return;
|
||||
}
|
||||
+ if (isReasoningDeltaChunk(value)) {
|
||||
+ controller.enqueue({
|
||||
+ type: 'reasoning-delta',
|
||||
+ id: value.item_id,
|
||||
+ delta: value.delta,
|
||||
+ });
|
||||
+ return;
|
||||
+ }
|
||||
+
|
||||
+ if (isReasoningEndChunk(value)) {
|
||||
+ controller.enqueue({
|
||||
+ type: 'reasoning-end',
|
||||
+ id: value.item_id,
|
||||
+ });
|
||||
+ return;
|
||||
+ }
|
||||
},
|
||||
flush(controller) {
|
||||
controller.enqueue({
|
||||
@@ -593,7 +635,8 @@ var HuggingFaceResponsesLanguageModel = class {
|
||||
var huggingfaceResponsesProviderOptionsSchema = z2.object({
|
||||
metadata: z2.record(z2.string(), z2.string()).optional(),
|
||||
instructions: z2.string().optional(),
|
||||
- strictJsonSchema: z2.boolean().optional()
|
||||
+ strictJsonSchema: z2.boolean().optional(),
|
||||
+ reasoningEffort: z2.string().optional(),
|
||||
});
|
||||
var huggingfaceResponsesResponseSchema = z2.object({
|
||||
id: z2.string(),
|
||||
@@ -727,12 +770,31 @@ var responseCreatedChunkSchema = z2.object({
|
||||
model: z2.string()
|
||||
})
|
||||
});
|
||||
+var reasoningTextDeltaChunkSchema = z2.object({
|
||||
+ type: z2.literal('response.reasoning_text.delta'),
|
||||
+ item_id: z2.string(),
|
||||
+ output_index: z2.number(),
|
||||
+ content_index: z2.number(),
|
||||
+ delta: z2.string(),
|
||||
+ sequence_number: z2.number(),
|
||||
+});
|
||||
+
|
||||
+var reasoningTextEndChunkSchema = z2.object({
|
||||
+ type: z2.literal('response.reasoning_text.done'),
|
||||
+ item_id: z2.string(),
|
||||
+ output_index: z2.number(),
|
||||
+ content_index: z2.number(),
|
||||
+ text: z2.string(),
|
||||
+ sequence_number: z2.number(),
|
||||
+});
|
||||
var huggingfaceResponsesChunkSchema = z2.union([
|
||||
responseOutputItemAddedSchema,
|
||||
responseOutputItemDoneSchema,
|
||||
textDeltaChunkSchema,
|
||||
responseCompletedChunkSchema,
|
||||
responseCreatedChunkSchema,
|
||||
+ reasoningTextDeltaChunkSchema,
|
||||
+ reasoningTextEndChunkSchema,
|
||||
z2.object({ type: z2.string() }).loose()
|
||||
// fallback for unknown chunks
|
||||
]);
|
||||
@@ -751,6 +813,12 @@ function isResponseCompletedChunk(chunk) {
|
||||
function isResponseCreatedChunk(chunk) {
|
||||
return chunk.type === "response.created";
|
||||
}
|
||||
+function isReasoningDeltaChunk(chunk) {
|
||||
+ return chunk.type === 'response.reasoning_text.delta';
|
||||
+}
|
||||
+function isReasoningEndChunk(chunk) {
|
||||
+ return chunk.type === 'response.reasoning_text.done';
|
||||
+}
|
||||
|
||||
// src/huggingface-provider.ts
|
||||
function createHuggingFace(options = {}) {
|
||||
@@ -1,140 +0,0 @@
|
||||
diff --git a/dist/index.js b/dist/index.js
|
||||
index 73045a7d38faafdc7f7d2cd79d7ff0e2b031056b..8d948c9ac4ea4b474db9ef3c5491961e7fcf9a07 100644
|
||||
--- a/dist/index.js
|
||||
+++ b/dist/index.js
|
||||
@@ -421,6 +421,17 @@ var OpenAICompatibleChatLanguageModel = class {
|
||||
text: reasoning
|
||||
});
|
||||
}
|
||||
+ if (choice.message.images) {
|
||||
+ for (const image of choice.message.images) {
|
||||
+ const match1 = image.image_url.url.match(/^data:([^;]+)/)
|
||||
+ const match2 = image.image_url.url.match(/^data:[^;]*;base64,(.+)$/);
|
||||
+ content.push({
|
||||
+ type: 'file',
|
||||
+ mediaType: match1 ? (match1[1] ?? 'image/jpeg') : 'image/jpeg',
|
||||
+ data: match2 ? match2[1] : image.image_url.url,
|
||||
+ });
|
||||
+ }
|
||||
+ }
|
||||
if (choice.message.tool_calls != null) {
|
||||
for (const toolCall of choice.message.tool_calls) {
|
||||
content.push({
|
||||
@@ -598,6 +609,17 @@ var OpenAICompatibleChatLanguageModel = class {
|
||||
delta: delta.content
|
||||
});
|
||||
}
|
||||
+ if (delta.images) {
|
||||
+ for (const image of delta.images) {
|
||||
+ const match1 = image.image_url.url.match(/^data:([^;]+)/)
|
||||
+ const match2 = image.image_url.url.match(/^data:[^;]*;base64,(.+)$/);
|
||||
+ controller.enqueue({
|
||||
+ type: 'file',
|
||||
+ mediaType: match1 ? (match1[1] ?? 'image/jpeg') : 'image/jpeg',
|
||||
+ data: match2 ? match2[1] : image.image_url.url,
|
||||
+ });
|
||||
+ }
|
||||
+ }
|
||||
if (delta.tool_calls != null) {
|
||||
for (const toolCallDelta of delta.tool_calls) {
|
||||
const index = toolCallDelta.index;
|
||||
@@ -765,6 +787,14 @@ var OpenAICompatibleChatResponseSchema = import_v43.z.object({
|
||||
arguments: import_v43.z.string()
|
||||
})
|
||||
})
|
||||
+ ).nullish(),
|
||||
+ images: import_v43.z.array(
|
||||
+ import_v43.z.object({
|
||||
+ type: import_v43.z.literal('image_url'),
|
||||
+ image_url: import_v43.z.object({
|
||||
+ url: import_v43.z.string(),
|
||||
+ })
|
||||
+ })
|
||||
).nullish()
|
||||
}),
|
||||
finish_reason: import_v43.z.string().nullish()
|
||||
@@ -795,6 +825,14 @@ var createOpenAICompatibleChatChunkSchema = (errorSchema) => import_v43.z.union(
|
||||
arguments: import_v43.z.string().nullish()
|
||||
})
|
||||
})
|
||||
+ ).nullish(),
|
||||
+ images: import_v43.z.array(
|
||||
+ import_v43.z.object({
|
||||
+ type: import_v43.z.literal('image_url'),
|
||||
+ image_url: import_v43.z.object({
|
||||
+ url: import_v43.z.string(),
|
||||
+ })
|
||||
+ })
|
||||
).nullish()
|
||||
}).nullish(),
|
||||
finish_reason: import_v43.z.string().nullish()
|
||||
diff --git a/dist/index.mjs b/dist/index.mjs
|
||||
index 1c2b9560bbfbfe10cb01af080aeeed4ff59db29c..2c8ddc4fc9bfc5e7e06cfca105d197a08864c427 100644
|
||||
--- a/dist/index.mjs
|
||||
+++ b/dist/index.mjs
|
||||
@@ -405,6 +405,17 @@ var OpenAICompatibleChatLanguageModel = class {
|
||||
text: reasoning
|
||||
});
|
||||
}
|
||||
+ if (choice.message.images) {
|
||||
+ for (const image of choice.message.images) {
|
||||
+ const match1 = image.image_url.url.match(/^data:([^;]+)/)
|
||||
+ const match2 = image.image_url.url.match(/^data:[^;]*;base64,(.+)$/);
|
||||
+ content.push({
|
||||
+ type: 'file',
|
||||
+ mediaType: match1 ? (match1[1] ?? 'image/jpeg') : 'image/jpeg',
|
||||
+ data: match2 ? match2[1] : image.image_url.url,
|
||||
+ });
|
||||
+ }
|
||||
+ }
|
||||
if (choice.message.tool_calls != null) {
|
||||
for (const toolCall of choice.message.tool_calls) {
|
||||
content.push({
|
||||
@@ -582,6 +593,17 @@ var OpenAICompatibleChatLanguageModel = class {
|
||||
delta: delta.content
|
||||
});
|
||||
}
|
||||
+ if (delta.images) {
|
||||
+ for (const image of delta.images) {
|
||||
+ const match1 = image.image_url.url.match(/^data:([^;]+)/)
|
||||
+ const match2 = image.image_url.url.match(/^data:[^;]*;base64,(.+)$/);
|
||||
+ controller.enqueue({
|
||||
+ type: 'file',
|
||||
+ mediaType: match1 ? (match1[1] ?? 'image/jpeg') : 'image/jpeg',
|
||||
+ data: match2 ? match2[1] : image.image_url.url,
|
||||
+ });
|
||||
+ }
|
||||
+ }
|
||||
if (delta.tool_calls != null) {
|
||||
for (const toolCallDelta of delta.tool_calls) {
|
||||
const index = toolCallDelta.index;
|
||||
@@ -749,6 +771,14 @@ var OpenAICompatibleChatResponseSchema = z3.object({
|
||||
arguments: z3.string()
|
||||
})
|
||||
})
|
||||
+ ).nullish(),
|
||||
+ images: z3.array(
|
||||
+ z3.object({
|
||||
+ type: z3.literal('image_url'),
|
||||
+ image_url: z3.object({
|
||||
+ url: z3.string(),
|
||||
+ })
|
||||
+ })
|
||||
).nullish()
|
||||
}),
|
||||
finish_reason: z3.string().nullish()
|
||||
@@ -779,6 +809,14 @@ var createOpenAICompatibleChatChunkSchema = (errorSchema) => z3.union([
|
||||
arguments: z3.string().nullish()
|
||||
})
|
||||
})
|
||||
+ ).nullish(),
|
||||
+ images: z3.array(
|
||||
+ z3.object({
|
||||
+ type: z3.literal('image_url'),
|
||||
+ image_url: z3.object({
|
||||
+ url: z3.string(),
|
||||
+ })
|
||||
+ })
|
||||
).nullish()
|
||||
}).nullish(),
|
||||
finish_reason: z3.string().nullish()
|
||||
@@ -1,5 +1,5 @@
|
||||
diff --git a/dist/index.js b/dist/index.js
|
||||
index 7481f3b3511078068d87d03855b568b20bb86971..8ac5ec28d2f7ad1b3b0d3f8da945c75674e59637 100644
|
||||
index 992c85ac6656e51c3471af741583533c5a7bf79f..83c05952a07aebb95fc6c62f9ddb8aa96b52ac0d 100644
|
||||
--- a/dist/index.js
|
||||
+++ b/dist/index.js
|
||||
@@ -274,6 +274,7 @@ var openaiChatResponseSchema = (0, import_provider_utils3.lazyValidator)(
|
||||
@@ -18,7 +18,7 @@ index 7481f3b3511078068d87d03855b568b20bb86971..8ac5ec28d2f7ad1b3b0d3f8da945c756
|
||||
tool_calls: import_v42.z.array(
|
||||
import_v42.z.object({
|
||||
index: import_v42.z.number(),
|
||||
@@ -795,6 +797,13 @@ var OpenAIChatLanguageModel = class {
|
||||
@@ -785,6 +787,13 @@ var OpenAIChatLanguageModel = class {
|
||||
if (text != null && text.length > 0) {
|
||||
content.push({ type: "text", text });
|
||||
}
|
||||
@@ -32,7 +32,7 @@ index 7481f3b3511078068d87d03855b568b20bb86971..8ac5ec28d2f7ad1b3b0d3f8da945c756
|
||||
for (const toolCall of (_a = choice.message.tool_calls) != null ? _a : []) {
|
||||
content.push({
|
||||
type: "tool-call",
|
||||
@@ -876,6 +885,7 @@ var OpenAIChatLanguageModel = class {
|
||||
@@ -866,6 +875,7 @@ var OpenAIChatLanguageModel = class {
|
||||
};
|
||||
let metadataExtracted = false;
|
||||
let isActiveText = false;
|
||||
@@ -40,7 +40,7 @@ index 7481f3b3511078068d87d03855b568b20bb86971..8ac5ec28d2f7ad1b3b0d3f8da945c756
|
||||
const providerMetadata = { openai: {} };
|
||||
return {
|
||||
stream: response.pipeThrough(
|
||||
@@ -933,6 +943,21 @@ var OpenAIChatLanguageModel = class {
|
||||
@@ -923,6 +933,21 @@ var OpenAIChatLanguageModel = class {
|
||||
return;
|
||||
}
|
||||
const delta = choice.delta;
|
||||
@@ -62,7 +62,7 @@ index 7481f3b3511078068d87d03855b568b20bb86971..8ac5ec28d2f7ad1b3b0d3f8da945c756
|
||||
if (delta.content != null) {
|
||||
if (!isActiveText) {
|
||||
controller.enqueue({ type: "text-start", id: "0" });
|
||||
@@ -1045,6 +1070,9 @@ var OpenAIChatLanguageModel = class {
|
||||
@@ -1035,6 +1060,9 @@ var OpenAIChatLanguageModel = class {
|
||||
}
|
||||
},
|
||||
flush(controller) {
|
||||
@@ -14,7 +14,7 @@
|
||||
}
|
||||
},
|
||||
"enabled": true,
|
||||
"includes": ["**/*.json", "!*.json", "!**/package.json", "!coverage/**"]
|
||||
"includes": ["**/*.json", "!*.json", "!**/package.json"]
|
||||
},
|
||||
"css": {
|
||||
"formatter": {
|
||||
@@ -23,7 +23,7 @@
|
||||
},
|
||||
"files": {
|
||||
"ignoreUnknown": false,
|
||||
"includes": ["**", "!**/.claude/**", "!**/.vscode/**"],
|
||||
"includes": ["**", "!**/.claude/**"],
|
||||
"maxSize": 2097152
|
||||
},
|
||||
"formatter": {
|
||||
|
||||
36
package.json
36
package.json
@@ -86,7 +86,6 @@
|
||||
"@napi-rs/system-ocr": "patch:@napi-rs/system-ocr@npm%3A1.0.2#~/.yarn/patches/@napi-rs-system-ocr-npm-1.0.2-59e7a78e8b.patch",
|
||||
"@paymoapp/electron-shutdown-handler": "^1.1.2",
|
||||
"@strongtz/win32-arm64-msvc": "^0.4.7",
|
||||
"emoji-picker-element-data": "^1",
|
||||
"express": "^5.1.0",
|
||||
"font-list": "^2.0.0",
|
||||
"graceful-fs": "^4.2.11",
|
||||
@@ -109,17 +108,16 @@
|
||||
"@agentic/exa": "^7.3.3",
|
||||
"@agentic/searxng": "^7.3.3",
|
||||
"@agentic/tavily": "^7.3.3",
|
||||
"@ai-sdk/amazon-bedrock": "^3.0.56",
|
||||
"@ai-sdk/anthropic": "^2.0.45",
|
||||
"@ai-sdk/amazon-bedrock": "^3.0.53",
|
||||
"@ai-sdk/anthropic": "^2.0.44",
|
||||
"@ai-sdk/cerebras": "^1.0.31",
|
||||
"@ai-sdk/gateway": "^2.0.13",
|
||||
"@ai-sdk/google": "patch:@ai-sdk/google@npm%3A2.0.40#~/.yarn/patches/@ai-sdk-google-npm-2.0.40-47e0eeee83.patch",
|
||||
"@ai-sdk/google-vertex": "^3.0.72",
|
||||
"@ai-sdk/huggingface": "^0.0.10",
|
||||
"@ai-sdk/mistral": "^2.0.24",
|
||||
"@ai-sdk/openai": "patch:@ai-sdk/openai@npm%3A2.0.71#~/.yarn/patches/@ai-sdk-openai-npm-2.0.71-a88ef00525.patch",
|
||||
"@ai-sdk/perplexity": "^2.0.20",
|
||||
"@ai-sdk/test-server": "^0.0.1",
|
||||
"@ai-sdk/gateway": "^2.0.9",
|
||||
"@ai-sdk/google": "patch:@ai-sdk/google@npm%3A2.0.36#~/.yarn/patches/@ai-sdk-google-npm-2.0.36-6f3cc06026.patch",
|
||||
"@ai-sdk/google-vertex": "^3.0.68",
|
||||
"@ai-sdk/huggingface": "patch:@ai-sdk/huggingface@npm%3A0.0.8#~/.yarn/patches/@ai-sdk-huggingface-npm-0.0.8-d4d0aaac93.patch",
|
||||
"@ai-sdk/mistral": "^2.0.23",
|
||||
"@ai-sdk/openai": "patch:@ai-sdk/openai@npm%3A2.0.64#~/.yarn/patches/@ai-sdk-openai-npm-2.0.64-48f99f5bf3.patch",
|
||||
"@ai-sdk/perplexity": "^2.0.17",
|
||||
"@ant-design/v5-patch-for-react-19": "^1.0.3",
|
||||
"@anthropic-ai/sdk": "^0.41.0",
|
||||
"@anthropic-ai/vertex-sdk": "patch:@anthropic-ai/vertex-sdk@npm%3A0.11.4#~/.yarn/patches/@anthropic-ai-vertex-sdk-npm-0.11.4-c19cb41edb.patch",
|
||||
@@ -160,12 +158,11 @@
|
||||
"@langchain/community": "^1.0.0",
|
||||
"@langchain/core": "patch:@langchain/core@npm%3A1.0.2#~/.yarn/patches/@langchain-core-npm-1.0.2-183ef83fe4.patch",
|
||||
"@langchain/openai": "patch:@langchain/openai@npm%3A1.0.0#~/.yarn/patches/@langchain-openai-npm-1.0.0-474d0ad9d4.patch",
|
||||
"@mcp-ui/client": "^5.14.1",
|
||||
"@mistralai/mistralai": "^1.7.5",
|
||||
"@modelcontextprotocol/sdk": "^1.17.5",
|
||||
"@mozilla/readability": "^0.6.0",
|
||||
"@notionhq/client": "^2.2.15",
|
||||
"@openrouter/ai-sdk-provider": "^1.2.5",
|
||||
"@openrouter/ai-sdk-provider": "^1.2.0",
|
||||
"@opentelemetry/api": "^1.9.0",
|
||||
"@opentelemetry/core": "2.0.0",
|
||||
"@opentelemetry/exporter-trace-otlp-http": "^0.200.0",
|
||||
@@ -218,8 +215,8 @@
|
||||
"@types/mime-types": "^3",
|
||||
"@types/node": "^22.17.1",
|
||||
"@types/pako": "^1.0.2",
|
||||
"@types/react": "^19.0.12",
|
||||
"@types/react-dom": "^19.0.4",
|
||||
"@types/react": "^19.2.6",
|
||||
"@types/react-dom": "^19.2.3",
|
||||
"@types/react-infinite-scroll-component": "^5.0.0",
|
||||
"@types/react-transition-group": "^4.4.12",
|
||||
"@types/react-window": "^1",
|
||||
@@ -241,7 +238,7 @@
|
||||
"@viz-js/lang-dot": "^1.0.5",
|
||||
"@viz-js/viz": "^3.14.0",
|
||||
"@xyflow/react": "^12.4.4",
|
||||
"ai": "^5.0.98",
|
||||
"ai": "^5.0.90",
|
||||
"antd": "patch:antd@npm%3A5.27.0#~/.yarn/patches/antd-npm-5.27.0-aa91c36546.patch",
|
||||
"archiver": "^7.0.1",
|
||||
"async-mutex": "^0.5.0",
|
||||
@@ -414,11 +411,8 @@
|
||||
"@langchain/openai@npm:^0.3.16": "patch:@langchain/openai@npm%3A1.0.0#~/.yarn/patches/@langchain-openai-npm-1.0.0-474d0ad9d4.patch",
|
||||
"@langchain/openai@npm:>=0.2.0 <0.7.0": "patch:@langchain/openai@npm%3A1.0.0#~/.yarn/patches/@langchain-openai-npm-1.0.0-474d0ad9d4.patch",
|
||||
"@ai-sdk/openai@npm:2.0.64": "patch:@ai-sdk/openai@npm%3A2.0.64#~/.yarn/patches/@ai-sdk-openai-npm-2.0.64-48f99f5bf3.patch",
|
||||
"@ai-sdk/openai@npm:^2.0.42": "patch:@ai-sdk/openai@npm%3A2.0.71#~/.yarn/patches/@ai-sdk-openai-npm-2.0.71-a88ef00525.patch",
|
||||
"@ai-sdk/google@npm:2.0.40": "patch:@ai-sdk/google@npm%3A2.0.40#~/.yarn/patches/@ai-sdk-google-npm-2.0.40-47e0eeee83.patch",
|
||||
"@ai-sdk/openai@npm:2.0.71": "patch:@ai-sdk/openai@npm%3A2.0.71#~/.yarn/patches/@ai-sdk-openai-npm-2.0.71-a88ef00525.patch",
|
||||
"@ai-sdk/openai-compatible@npm:1.0.27": "patch:@ai-sdk/openai-compatible@npm%3A1.0.27#~/.yarn/patches/@ai-sdk-openai-compatible-npm-1.0.27-06f74278cf.patch",
|
||||
"@ai-sdk/openai-compatible@npm:^1.0.19": "patch:@ai-sdk/openai-compatible@npm%3A1.0.27#~/.yarn/patches/@ai-sdk-openai-compatible-npm-1.0.27-06f74278cf.patch"
|
||||
"@ai-sdk/openai@npm:^2.0.42": "patch:@ai-sdk/openai@npm%3A2.0.64#~/.yarn/patches/@ai-sdk-openai-npm-2.0.64-48f99f5bf3.patch",
|
||||
"@ai-sdk/google@npm:2.0.36": "patch:@ai-sdk/google@npm%3A2.0.36#~/.yarn/patches/@ai-sdk-google-npm-2.0.36-6f3cc06026.patch"
|
||||
},
|
||||
"packageManager": "yarn@4.9.1",
|
||||
"lint-staged": {
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "@cherrystudio/ai-sdk-provider",
|
||||
"version": "0.1.3",
|
||||
"version": "0.1.2",
|
||||
"description": "Cherry Studio AI SDK provider bundle with CherryIN routing.",
|
||||
"keywords": [
|
||||
"ai-sdk",
|
||||
@@ -42,7 +42,7 @@
|
||||
},
|
||||
"dependencies": {
|
||||
"@ai-sdk/provider": "^2.0.0",
|
||||
"@ai-sdk/provider-utils": "^3.0.17"
|
||||
"@ai-sdk/provider-utils": "^3.0.12"
|
||||
},
|
||||
"devDependencies": {
|
||||
"tsdown": "^0.13.3",
|
||||
|
||||
@@ -67,10 +67,6 @@ export interface CherryInProviderSettings {
|
||||
* Optional static headers applied to every request.
|
||||
*/
|
||||
headers?: HeadersInput
|
||||
/**
|
||||
* Optional endpoint type to distinguish different endpoint behaviors.
|
||||
*/
|
||||
endpointType?: 'openai' | 'openai-response' | 'anthropic' | 'gemini' | 'image-generation' | 'jina-rerank'
|
||||
}
|
||||
|
||||
export interface CherryInProvider extends ProviderV2 {
|
||||
@@ -155,8 +151,7 @@ export const createCherryIn = (options: CherryInProviderSettings = {}): CherryIn
|
||||
baseURL = DEFAULT_CHERRYIN_BASE_URL,
|
||||
anthropicBaseURL = DEFAULT_CHERRYIN_ANTHROPIC_BASE_URL,
|
||||
geminiBaseURL = DEFAULT_CHERRYIN_GEMINI_BASE_URL,
|
||||
fetch,
|
||||
endpointType
|
||||
fetch
|
||||
} = options
|
||||
|
||||
const getJsonHeaders = createJsonHeadersGetter(options)
|
||||
@@ -210,7 +205,7 @@ export const createCherryIn = (options: CherryInProviderSettings = {}): CherryIn
|
||||
fetch
|
||||
})
|
||||
|
||||
const createChatModelByModelId = (modelId: string, settings: OpenAIProviderSettings = {}) => {
|
||||
const createChatModel = (modelId: string, settings: OpenAIProviderSettings = {}) => {
|
||||
if (isAnthropicModel(modelId)) {
|
||||
return createAnthropicModel(modelId)
|
||||
}
|
||||
@@ -228,29 +223,6 @@ export const createCherryIn = (options: CherryInProviderSettings = {}): CherryIn
|
||||
})
|
||||
}
|
||||
|
||||
const createChatModel = (modelId: string, settings: OpenAIProviderSettings = {}) => {
|
||||
if (!endpointType) return createChatModelByModelId(modelId, settings)
|
||||
switch (endpointType) {
|
||||
case 'anthropic':
|
||||
return createAnthropicModel(modelId)
|
||||
case 'gemini':
|
||||
return createGeminiModel(modelId)
|
||||
case 'openai':
|
||||
return createOpenAIChatModel(modelId)
|
||||
case 'openai-response':
|
||||
default:
|
||||
return new OpenAIResponsesLanguageModel(modelId, {
|
||||
provider: `${CHERRYIN_PROVIDER_NAME}.openai`,
|
||||
url,
|
||||
headers: () => ({
|
||||
...getJsonHeaders(),
|
||||
...settings.headers
|
||||
}),
|
||||
fetch
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
const createCompletionModel = (modelId: string, settings: OpenAIProviderSettings = {}) =>
|
||||
new OpenAICompletionLanguageModel(modelId, {
|
||||
provider: `${CHERRYIN_PROVIDER_NAME}.completion`,
|
||||
|
||||
@@ -35,17 +35,17 @@
|
||||
"peerDependencies": {
|
||||
"@ai-sdk/google": "^2.0.36",
|
||||
"@ai-sdk/openai": "^2.0.64",
|
||||
"@cherrystudio/ai-sdk-provider": "^0.1.3",
|
||||
"@cherrystudio/ai-sdk-provider": "^0.1.2",
|
||||
"ai": "^5.0.26"
|
||||
},
|
||||
"dependencies": {
|
||||
"@ai-sdk/anthropic": "^2.0.45",
|
||||
"@ai-sdk/azure": "^2.0.73",
|
||||
"@ai-sdk/deepseek": "^1.0.29",
|
||||
"@ai-sdk/openai-compatible": "patch:@ai-sdk/openai-compatible@npm%3A1.0.27#~/.yarn/patches/@ai-sdk-openai-compatible-npm-1.0.27-06f74278cf.patch",
|
||||
"@ai-sdk/anthropic": "^2.0.43",
|
||||
"@ai-sdk/azure": "^2.0.66",
|
||||
"@ai-sdk/deepseek": "^1.0.27",
|
||||
"@ai-sdk/openai-compatible": "^1.0.26",
|
||||
"@ai-sdk/provider": "^2.0.0",
|
||||
"@ai-sdk/provider-utils": "^3.0.17",
|
||||
"@ai-sdk/xai": "^2.0.34",
|
||||
"@ai-sdk/provider-utils": "^3.0.16",
|
||||
"@ai-sdk/xai": "^2.0.31",
|
||||
"zod": "^4.1.5"
|
||||
},
|
||||
"devDependencies": {
|
||||
|
||||
@@ -1,180 +0,0 @@
|
||||
/**
|
||||
* Mock Provider Instances
|
||||
* Provides mock implementations for all supported AI providers
|
||||
*/
|
||||
|
||||
import type { ImageModelV2, LanguageModelV2 } from '@ai-sdk/provider'
|
||||
import { vi } from 'vitest'
|
||||
|
||||
/**
|
||||
* Creates a mock language model with customizable behavior
|
||||
*/
|
||||
export function createMockLanguageModel(overrides?: Partial<LanguageModelV2>): LanguageModelV2 {
|
||||
return {
|
||||
specificationVersion: 'v1',
|
||||
provider: 'mock-provider',
|
||||
modelId: 'mock-model',
|
||||
defaultObjectGenerationMode: 'tool',
|
||||
|
||||
doGenerate: vi.fn().mockResolvedValue({
|
||||
text: 'Mock response text',
|
||||
finishReason: 'stop',
|
||||
usage: {
|
||||
promptTokens: 10,
|
||||
completionTokens: 20,
|
||||
totalTokens: 30
|
||||
},
|
||||
rawCall: { rawPrompt: null, rawSettings: {} },
|
||||
rawResponse: { headers: {} },
|
||||
warnings: []
|
||||
}),
|
||||
|
||||
doStream: vi.fn().mockReturnValue({
|
||||
stream: (async function* () {
|
||||
yield {
|
||||
type: 'text-delta',
|
||||
textDelta: 'Mock '
|
||||
}
|
||||
yield {
|
||||
type: 'text-delta',
|
||||
textDelta: 'streaming '
|
||||
}
|
||||
yield {
|
||||
type: 'text-delta',
|
||||
textDelta: 'response'
|
||||
}
|
||||
yield {
|
||||
type: 'finish',
|
||||
finishReason: 'stop',
|
||||
usage: {
|
||||
promptTokens: 10,
|
||||
completionTokens: 15,
|
||||
totalTokens: 25
|
||||
}
|
||||
}
|
||||
})(),
|
||||
rawCall: { rawPrompt: null, rawSettings: {} },
|
||||
rawResponse: { headers: {} },
|
||||
warnings: []
|
||||
}),
|
||||
|
||||
...overrides
|
||||
} as LanguageModelV2
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a mock image model with customizable behavior
|
||||
*/
|
||||
export function createMockImageModel(overrides?: Partial<ImageModelV2>): ImageModelV2 {
|
||||
return {
|
||||
specificationVersion: 'v2',
|
||||
provider: 'mock-provider',
|
||||
modelId: 'mock-image-model',
|
||||
|
||||
doGenerate: vi.fn().mockResolvedValue({
|
||||
images: [
|
||||
{
|
||||
base64: 'mock-base64-image-data',
|
||||
uint8Array: new Uint8Array([1, 2, 3, 4, 5]),
|
||||
mimeType: 'image/png'
|
||||
}
|
||||
],
|
||||
warnings: []
|
||||
}),
|
||||
|
||||
...overrides
|
||||
} as ImageModelV2
|
||||
}
|
||||
|
||||
/**
|
||||
* Mock provider configurations for testing
|
||||
*/
|
||||
export const mockProviderConfigs = {
|
||||
openai: {
|
||||
apiKey: 'sk-test-openai-key-123456789',
|
||||
baseURL: 'https://api.openai.com/v1',
|
||||
organization: 'test-org'
|
||||
},
|
||||
|
||||
anthropic: {
|
||||
apiKey: 'sk-ant-test-key-123456789',
|
||||
baseURL: 'https://api.anthropic.com'
|
||||
},
|
||||
|
||||
google: {
|
||||
apiKey: 'test-google-api-key-123456789',
|
||||
baseURL: 'https://generativelanguage.googleapis.com/v1'
|
||||
},
|
||||
|
||||
xai: {
|
||||
apiKey: 'xai-test-key-123456789',
|
||||
baseURL: 'https://api.x.ai/v1'
|
||||
},
|
||||
|
||||
azure: {
|
||||
apiKey: 'test-azure-key-123456789',
|
||||
resourceName: 'test-resource',
|
||||
deployment: 'test-deployment'
|
||||
},
|
||||
|
||||
deepseek: {
|
||||
apiKey: 'sk-test-deepseek-key-123456789',
|
||||
baseURL: 'https://api.deepseek.com/v1'
|
||||
},
|
||||
|
||||
openrouter: {
|
||||
apiKey: 'sk-or-test-key-123456789',
|
||||
baseURL: 'https://openrouter.ai/api/v1'
|
||||
},
|
||||
|
||||
huggingface: {
|
||||
apiKey: 'hf_test_key_123456789',
|
||||
baseURL: 'https://api-inference.huggingface.co'
|
||||
},
|
||||
|
||||
'openai-compatible': {
|
||||
apiKey: 'test-compatible-key-123456789',
|
||||
baseURL: 'https://api.example.com/v1',
|
||||
name: 'test-provider'
|
||||
},
|
||||
|
||||
'openai-chat': {
|
||||
apiKey: 'sk-test-chat-key-123456789',
|
||||
baseURL: 'https://api.openai.com/v1'
|
||||
}
|
||||
} as const
|
||||
|
||||
/**
|
||||
* Mock provider instances for testing
|
||||
*/
|
||||
export const mockProviderInstances = {
|
||||
openai: {
|
||||
name: 'openai-mock',
|
||||
languageModel: createMockLanguageModel({ provider: 'openai', modelId: 'gpt-4' }),
|
||||
imageModel: createMockImageModel({ provider: 'openai', modelId: 'dall-e-3' })
|
||||
},
|
||||
|
||||
anthropic: {
|
||||
name: 'anthropic-mock',
|
||||
languageModel: createMockLanguageModel({ provider: 'anthropic', modelId: 'claude-3-5-sonnet-20241022' })
|
||||
},
|
||||
|
||||
google: {
|
||||
name: 'google-mock',
|
||||
languageModel: createMockLanguageModel({ provider: 'google', modelId: 'gemini-2.0-flash-exp' }),
|
||||
imageModel: createMockImageModel({ provider: 'google', modelId: 'imagen-3.0-generate-001' })
|
||||
},
|
||||
|
||||
xai: {
|
||||
name: 'xai-mock',
|
||||
languageModel: createMockLanguageModel({ provider: 'xai', modelId: 'grok-2-latest' }),
|
||||
imageModel: createMockImageModel({ provider: 'xai', modelId: 'grok-2-image-latest' })
|
||||
},
|
||||
|
||||
deepseek: {
|
||||
name: 'deepseek-mock',
|
||||
languageModel: createMockLanguageModel({ provider: 'deepseek', modelId: 'deepseek-chat' })
|
||||
}
|
||||
}
|
||||
|
||||
export type ProviderId = keyof typeof mockProviderConfigs
|
||||
@@ -1,331 +0,0 @@
|
||||
/**
|
||||
* Mock Responses
|
||||
* Provides realistic mock responses for all provider types
|
||||
*/
|
||||
|
||||
import { jsonSchema, type ModelMessage, type Tool } from 'ai'
|
||||
|
||||
/**
|
||||
* Standard test messages for all scenarios
|
||||
*/
|
||||
export const testMessages = {
|
||||
simple: [{ role: 'user' as const, content: 'Hello, how are you?' }],
|
||||
|
||||
conversation: [
|
||||
{ role: 'user' as const, content: 'What is the capital of France?' },
|
||||
{ role: 'assistant' as const, content: 'The capital of France is Paris.' },
|
||||
{ role: 'user' as const, content: 'What is its population?' }
|
||||
],
|
||||
|
||||
withSystem: [
|
||||
{ role: 'system' as const, content: 'You are a helpful assistant that provides concise answers.' },
|
||||
{ role: 'user' as const, content: 'Explain quantum computing in one sentence.' }
|
||||
],
|
||||
|
||||
withImages: [
|
||||
{
|
||||
role: 'user' as const,
|
||||
content: [
|
||||
{ type: 'text' as const, text: 'What is in this image?' },
|
||||
{
|
||||
type: 'image' as const,
|
||||
image:
|
||||
'data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg=='
|
||||
}
|
||||
]
|
||||
}
|
||||
],
|
||||
|
||||
toolUse: [{ role: 'user' as const, content: 'What is the weather in San Francisco?' }],
|
||||
|
||||
multiTurn: [
|
||||
{ role: 'user' as const, content: 'Can you help me with a math problem?' },
|
||||
{ role: 'assistant' as const, content: 'Of course! What math problem would you like help with?' },
|
||||
{ role: 'user' as const, content: 'What is 15 * 23?' },
|
||||
{ role: 'assistant' as const, content: '15 * 23 = 345' },
|
||||
{ role: 'user' as const, content: 'Now divide that by 5' }
|
||||
]
|
||||
} satisfies Record<string, ModelMessage[]>
|
||||
|
||||
/**
|
||||
* Standard test tools for tool calling scenarios
|
||||
*/
|
||||
export const testTools: Record<string, Tool> = {
|
||||
getWeather: {
|
||||
description: 'Get the current weather in a given location',
|
||||
inputSchema: jsonSchema({
|
||||
type: 'object',
|
||||
properties: {
|
||||
location: {
|
||||
type: 'string',
|
||||
description: 'The city and state, e.g. San Francisco, CA'
|
||||
},
|
||||
unit: {
|
||||
type: 'string',
|
||||
enum: ['celsius', 'fahrenheit'],
|
||||
description: 'The temperature unit to use'
|
||||
}
|
||||
},
|
||||
required: ['location']
|
||||
}),
|
||||
execute: async ({ location, unit = 'fahrenheit' }) => {
|
||||
return {
|
||||
location,
|
||||
temperature: unit === 'celsius' ? 22 : 72,
|
||||
unit,
|
||||
condition: 'sunny'
|
||||
}
|
||||
}
|
||||
},
|
||||
|
||||
calculate: {
|
||||
description: 'Perform a mathematical calculation',
|
||||
inputSchema: jsonSchema({
|
||||
type: 'object',
|
||||
properties: {
|
||||
operation: {
|
||||
type: 'string',
|
||||
enum: ['add', 'subtract', 'multiply', 'divide'],
|
||||
description: 'The operation to perform'
|
||||
},
|
||||
a: {
|
||||
type: 'number',
|
||||
description: 'The first number'
|
||||
},
|
||||
b: {
|
||||
type: 'number',
|
||||
description: 'The second number'
|
||||
}
|
||||
},
|
||||
required: ['operation', 'a', 'b']
|
||||
}),
|
||||
execute: async ({ operation, a, b }) => {
|
||||
const operations = {
|
||||
add: (x: number, y: number) => x + y,
|
||||
subtract: (x: number, y: number) => x - y,
|
||||
multiply: (x: number, y: number) => x * y,
|
||||
divide: (x: number, y: number) => x / y
|
||||
}
|
||||
return { result: operations[operation as keyof typeof operations](a, b) }
|
||||
}
|
||||
},
|
||||
|
||||
searchDatabase: {
|
||||
description: 'Search for information in a database',
|
||||
inputSchema: jsonSchema({
|
||||
type: 'object',
|
||||
properties: {
|
||||
query: {
|
||||
type: 'string',
|
||||
description: 'The search query'
|
||||
},
|
||||
limit: {
|
||||
type: 'number',
|
||||
description: 'Maximum number of results to return',
|
||||
default: 10
|
||||
}
|
||||
},
|
||||
required: ['query']
|
||||
}),
|
||||
execute: async ({ query, limit = 10 }) => {
|
||||
return {
|
||||
results: [
|
||||
{ id: 1, title: `Result 1 for ${query}`, relevance: 0.95 },
|
||||
{ id: 2, title: `Result 2 for ${query}`, relevance: 0.87 }
|
||||
].slice(0, limit)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Mock streaming chunks for different providers
|
||||
*/
|
||||
export const mockStreamingChunks = {
|
||||
text: [
|
||||
{ type: 'text-delta' as const, textDelta: 'Hello' },
|
||||
{ type: 'text-delta' as const, textDelta: ', ' },
|
||||
{ type: 'text-delta' as const, textDelta: 'this ' },
|
||||
{ type: 'text-delta' as const, textDelta: 'is ' },
|
||||
{ type: 'text-delta' as const, textDelta: 'a ' },
|
||||
{ type: 'text-delta' as const, textDelta: 'test.' }
|
||||
],
|
||||
|
||||
withToolCall: [
|
||||
{ type: 'text-delta' as const, textDelta: 'Let me check the weather for you.' },
|
||||
{
|
||||
type: 'tool-call-delta' as const,
|
||||
toolCallType: 'function' as const,
|
||||
toolCallId: 'call_123',
|
||||
toolName: 'getWeather',
|
||||
argsTextDelta: '{"location":'
|
||||
},
|
||||
{
|
||||
type: 'tool-call-delta' as const,
|
||||
toolCallType: 'function' as const,
|
||||
toolCallId: 'call_123',
|
||||
toolName: 'getWeather',
|
||||
argsTextDelta: ' "San Francisco, CA"}'
|
||||
},
|
||||
{
|
||||
type: 'tool-call' as const,
|
||||
toolCallType: 'function' as const,
|
||||
toolCallId: 'call_123',
|
||||
toolName: 'getWeather',
|
||||
args: { location: 'San Francisco, CA' }
|
||||
}
|
||||
],
|
||||
|
||||
withFinish: [
|
||||
{ type: 'text-delta' as const, textDelta: 'Complete response.' },
|
||||
{
|
||||
type: 'finish' as const,
|
||||
finishReason: 'stop' as const,
|
||||
usage: {
|
||||
promptTokens: 10,
|
||||
completionTokens: 5,
|
||||
totalTokens: 15
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
/**
|
||||
* Mock complete responses for non-streaming scenarios
|
||||
*/
|
||||
export const mockCompleteResponses = {
|
||||
simple: {
|
||||
text: 'This is a simple response.',
|
||||
finishReason: 'stop' as const,
|
||||
usage: {
|
||||
promptTokens: 15,
|
||||
completionTokens: 8,
|
||||
totalTokens: 23
|
||||
}
|
||||
},
|
||||
|
||||
withToolCalls: {
|
||||
text: 'I will check the weather for you.',
|
||||
toolCalls: [
|
||||
{
|
||||
toolCallId: 'call_456',
|
||||
toolName: 'getWeather',
|
||||
args: { location: 'New York, NY', unit: 'celsius' }
|
||||
}
|
||||
],
|
||||
finishReason: 'tool-calls' as const,
|
||||
usage: {
|
||||
promptTokens: 25,
|
||||
completionTokens: 12,
|
||||
totalTokens: 37
|
||||
}
|
||||
},
|
||||
|
||||
withWarnings: {
|
||||
text: 'Response with warnings.',
|
||||
finishReason: 'stop' as const,
|
||||
usage: {
|
||||
promptTokens: 10,
|
||||
completionTokens: 5,
|
||||
totalTokens: 15
|
||||
},
|
||||
warnings: [
|
||||
{
|
||||
type: 'unsupported-setting' as const,
|
||||
message: 'Temperature parameter not supported for this model'
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Mock image generation responses
|
||||
*/
|
||||
export const mockImageResponses = {
|
||||
single: {
|
||||
image: {
|
||||
base64: 'iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==',
|
||||
uint8Array: new Uint8Array([137, 80, 78, 71, 13, 10, 26, 10, 0, 0, 0, 13, 73, 72, 68, 82]),
|
||||
mimeType: 'image/png' as const
|
||||
},
|
||||
warnings: []
|
||||
},
|
||||
|
||||
multiple: {
|
||||
images: [
|
||||
{
|
||||
base64: 'iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==',
|
||||
uint8Array: new Uint8Array([137, 80, 78, 71]),
|
||||
mimeType: 'image/png' as const
|
||||
},
|
||||
{
|
||||
base64: 'iVBORw0KGgoAAAANSUhEUgAAAAIAAAACCAYAAABytg0kAAAAEklEQVR42mNk+M9QzwAEjDAGACCKAgdZ9zImAAAAAElFTkSuQmCC',
|
||||
uint8Array: new Uint8Array([137, 80, 78, 71]),
|
||||
mimeType: 'image/png' as const
|
||||
}
|
||||
],
|
||||
warnings: []
|
||||
},
|
||||
|
||||
withProviderMetadata: {
|
||||
image: {
|
||||
base64: 'iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==',
|
||||
uint8Array: new Uint8Array([137, 80, 78, 71]),
|
||||
mimeType: 'image/png' as const
|
||||
},
|
||||
providerMetadata: {
|
||||
openai: {
|
||||
images: [
|
||||
{
|
||||
revisedPrompt: 'A detailed and enhanced version of the original prompt'
|
||||
}
|
||||
]
|
||||
}
|
||||
},
|
||||
warnings: []
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Mock error responses
|
||||
*/
|
||||
export const mockErrors = {
|
||||
invalidApiKey: {
|
||||
name: 'APIError',
|
||||
message: 'Invalid API key provided',
|
||||
statusCode: 401
|
||||
},
|
||||
|
||||
rateLimitExceeded: {
|
||||
name: 'RateLimitError',
|
||||
message: 'Rate limit exceeded. Please try again later.',
|
||||
statusCode: 429,
|
||||
headers: {
|
||||
'retry-after': '60'
|
||||
}
|
||||
},
|
||||
|
||||
modelNotFound: {
|
||||
name: 'ModelNotFoundError',
|
||||
message: 'The requested model was not found',
|
||||
statusCode: 404
|
||||
},
|
||||
|
||||
contextLengthExceeded: {
|
||||
name: 'ContextLengthError',
|
||||
message: "This model's maximum context length is 4096 tokens",
|
||||
statusCode: 400
|
||||
},
|
||||
|
||||
timeout: {
|
||||
name: 'TimeoutError',
|
||||
message: 'Request timed out after 30000ms',
|
||||
code: 'ETIMEDOUT'
|
||||
},
|
||||
|
||||
networkError: {
|
||||
name: 'NetworkError',
|
||||
message: 'Network connection failed',
|
||||
code: 'ECONNREFUSED'
|
||||
}
|
||||
}
|
||||
@@ -1,329 +0,0 @@
|
||||
/**
|
||||
* Provider-Specific Test Utilities
|
||||
* Helper functions for testing individual providers with all their parameters
|
||||
*/
|
||||
|
||||
import type { Tool } from 'ai'
|
||||
import { expect } from 'vitest'
|
||||
|
||||
/**
|
||||
* Provider parameter configurations for comprehensive testing
|
||||
*/
|
||||
export const providerParameterMatrix = {
|
||||
openai: {
|
||||
models: ['gpt-4', 'gpt-4-turbo', 'gpt-3.5-turbo', 'gpt-4o'],
|
||||
parameters: {
|
||||
temperature: [0, 0.5, 0.7, 1.0, 1.5, 2.0],
|
||||
maxTokens: [100, 500, 1000, 2000, 4000],
|
||||
topP: [0.1, 0.5, 0.9, 1.0],
|
||||
frequencyPenalty: [-2.0, -1.0, 0, 1.0, 2.0],
|
||||
presencePenalty: [-2.0, -1.0, 0, 1.0, 2.0],
|
||||
stop: [undefined, ['stop'], ['STOP', 'END']],
|
||||
seed: [undefined, 12345, 67890],
|
||||
responseFormat: [undefined, { type: 'json_object' as const }],
|
||||
user: [undefined, 'test-user-123']
|
||||
},
|
||||
toolChoice: ['auto', 'required', 'none', { type: 'function' as const, name: 'getWeather' }],
|
||||
parallelToolCalls: [true, false]
|
||||
},
|
||||
|
||||
anthropic: {
|
||||
models: ['claude-3-5-sonnet-20241022', 'claude-3-opus-20240229', 'claude-3-haiku-20240307'],
|
||||
parameters: {
|
||||
temperature: [0, 0.5, 1.0],
|
||||
maxTokens: [100, 1000, 4000, 8000],
|
||||
topP: [0.1, 0.5, 0.9, 1.0],
|
||||
topK: [undefined, 1, 5, 10, 40],
|
||||
stop: [undefined, ['Human:', 'Assistant:']],
|
||||
metadata: [undefined, { userId: 'test-123' }]
|
||||
},
|
||||
toolChoice: ['auto', 'any', { type: 'tool' as const, name: 'getWeather' }]
|
||||
},
|
||||
|
||||
google: {
|
||||
models: ['gemini-2.0-flash-exp', 'gemini-1.5-pro', 'gemini-1.5-flash'],
|
||||
parameters: {
|
||||
temperature: [0, 0.5, 0.9, 1.0],
|
||||
maxTokens: [100, 1000, 2000, 8000],
|
||||
topP: [0.1, 0.5, 0.95, 1.0],
|
||||
topK: [undefined, 1, 16, 40],
|
||||
stopSequences: [undefined, ['END'], ['STOP', 'TERMINATE']]
|
||||
},
|
||||
safetySettings: [
|
||||
undefined,
|
||||
[
|
||||
{ category: 'HARM_CATEGORY_HARASSMENT', threshold: 'BLOCK_MEDIUM_AND_ABOVE' },
|
||||
{ category: 'HARM_CATEGORY_HATE_SPEECH', threshold: 'BLOCK_ONLY_HIGH' }
|
||||
]
|
||||
]
|
||||
},
|
||||
|
||||
xai: {
|
||||
models: ['grok-2-latest', 'grok-2-1212'],
|
||||
parameters: {
|
||||
temperature: [0, 0.5, 1.0, 1.5],
|
||||
maxTokens: [100, 500, 2000, 4000],
|
||||
topP: [0.1, 0.5, 0.9, 1.0],
|
||||
stop: [undefined, ['STOP'], ['END', 'TERMINATE']],
|
||||
seed: [undefined, 12345]
|
||||
}
|
||||
},
|
||||
|
||||
deepseek: {
|
||||
models: ['deepseek-chat', 'deepseek-coder'],
|
||||
parameters: {
|
||||
temperature: [0, 0.5, 1.0],
|
||||
maxTokens: [100, 1000, 4000],
|
||||
topP: [0.1, 0.5, 0.95],
|
||||
frequencyPenalty: [0, 0.5, 1.0],
|
||||
presencePenalty: [0, 0.5, 1.0],
|
||||
stop: [undefined, ['```'], ['END']]
|
||||
}
|
||||
},
|
||||
|
||||
azure: {
|
||||
deployments: ['gpt-4-deployment', 'gpt-35-turbo-deployment'],
|
||||
parameters: {
|
||||
temperature: [0, 0.7, 1.0],
|
||||
maxTokens: [100, 1000, 2000],
|
||||
topP: [0.1, 0.5, 0.95],
|
||||
frequencyPenalty: [0, 1.0],
|
||||
presencePenalty: [0, 1.0],
|
||||
stop: [undefined, ['STOP']]
|
||||
}
|
||||
}
|
||||
} as const
|
||||
|
||||
/**
|
||||
* Creates test cases for all parameter combinations
|
||||
*/
|
||||
export function generateParameterTestCases<T extends Record<string, any[]>>(
|
||||
params: T,
|
||||
maxCombinations = 50
|
||||
): Array<Partial<{ [K in keyof T]: T[K][number] }>> {
|
||||
const keys = Object.keys(params) as Array<keyof T>
|
||||
const testCases: Array<Partial<{ [K in keyof T]: T[K][number] }>> = []
|
||||
|
||||
// Generate combinations using sampling strategy for large parameter spaces
|
||||
const totalCombinations = keys.reduce((acc, key) => acc * params[key].length, 1)
|
||||
|
||||
if (totalCombinations <= maxCombinations) {
|
||||
// Generate all combinations if total is small
|
||||
generateAllCombinations(params, keys, 0, {}, testCases)
|
||||
} else {
|
||||
// Sample diverse combinations if total is large
|
||||
generateSampledCombinations(params, keys, maxCombinations, testCases)
|
||||
}
|
||||
|
||||
return testCases
|
||||
}
|
||||
|
||||
function generateAllCombinations<T extends Record<string, any[]>>(
|
||||
params: T,
|
||||
keys: Array<keyof T>,
|
||||
index: number,
|
||||
current: Partial<{ [K in keyof T]: T[K][number] }>,
|
||||
results: Array<Partial<{ [K in keyof T]: T[K][number] }>>
|
||||
) {
|
||||
if (index === keys.length) {
|
||||
results.push({ ...current })
|
||||
return
|
||||
}
|
||||
|
||||
const key = keys[index]
|
||||
for (const value of params[key]) {
|
||||
generateAllCombinations(params, keys, index + 1, { ...current, [key]: value }, results)
|
||||
}
|
||||
}
|
||||
|
||||
function generateSampledCombinations<T extends Record<string, any[]>>(
|
||||
params: T,
|
||||
keys: Array<keyof T>,
|
||||
count: number,
|
||||
results: Array<Partial<{ [K in keyof T]: T[K][number] }>>
|
||||
) {
|
||||
// Generate edge cases first (min/max values)
|
||||
const edgeCase1: any = {}
|
||||
const edgeCase2: any = {}
|
||||
|
||||
for (const key of keys) {
|
||||
edgeCase1[key] = params[key][0]
|
||||
edgeCase2[key] = params[key][params[key].length - 1]
|
||||
}
|
||||
|
||||
results.push(edgeCase1, edgeCase2)
|
||||
|
||||
// Generate random combinations for the rest
|
||||
for (let i = results.length; i < count; i++) {
|
||||
const combination: any = {}
|
||||
for (const key of keys) {
|
||||
const values = params[key]
|
||||
combination[key] = values[Math.floor(Math.random() * values.length)]
|
||||
}
|
||||
results.push(combination)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Validates that all provider-specific parameters are correctly passed through
|
||||
*/
|
||||
export function validateProviderParams(providerId: string, actualParams: any, expectedParams: any): void {
|
||||
const requiredFields: Record<string, string[]> = {
|
||||
openai: ['model', 'messages'],
|
||||
anthropic: ['model', 'messages'],
|
||||
google: ['model', 'contents'],
|
||||
xai: ['model', 'messages'],
|
||||
deepseek: ['model', 'messages'],
|
||||
azure: ['messages']
|
||||
}
|
||||
|
||||
const fields = requiredFields[providerId] || ['model', 'messages']
|
||||
|
||||
for (const field of fields) {
|
||||
expect(actualParams).toHaveProperty(field)
|
||||
}
|
||||
|
||||
// Validate optional parameters if they were provided
|
||||
const optionalParams = ['temperature', 'max_tokens', 'top_p', 'stop', 'tools']
|
||||
|
||||
for (const param of optionalParams) {
|
||||
if (expectedParams[param] !== undefined) {
|
||||
expect(actualParams[param]).toEqual(expectedParams[param])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a comprehensive test suite for a provider
|
||||
*/
|
||||
// oxlint-disable-next-line no-unused-vars
|
||||
export function createProviderTestSuite(_providerId: string) {
|
||||
return {
|
||||
testBasicCompletion: async (executor: any, model: string) => {
|
||||
const result = await executor.generateText({
|
||||
model,
|
||||
messages: [{ role: 'user' as const, content: 'Hello' }]
|
||||
})
|
||||
|
||||
expect(result).toBeDefined()
|
||||
expect(result.text).toBeDefined()
|
||||
expect(typeof result.text).toBe('string')
|
||||
},
|
||||
|
||||
testStreaming: async (executor: any, model: string) => {
|
||||
const chunks: any[] = []
|
||||
const result = await executor.streamText({
|
||||
model,
|
||||
messages: [{ role: 'user' as const, content: 'Hello' }]
|
||||
})
|
||||
|
||||
for await (const chunk of result.textStream) {
|
||||
chunks.push(chunk)
|
||||
}
|
||||
|
||||
expect(chunks.length).toBeGreaterThan(0)
|
||||
},
|
||||
|
||||
testTemperature: async (executor: any, model: string, temperatures: number[]) => {
|
||||
for (const temperature of temperatures) {
|
||||
const result = await executor.generateText({
|
||||
model,
|
||||
messages: [{ role: 'user' as const, content: 'Hello' }],
|
||||
temperature
|
||||
})
|
||||
|
||||
expect(result).toBeDefined()
|
||||
}
|
||||
},
|
||||
|
||||
testMaxTokens: async (executor: any, model: string, maxTokensValues: number[]) => {
|
||||
for (const maxTokens of maxTokensValues) {
|
||||
const result = await executor.generateText({
|
||||
model,
|
||||
messages: [{ role: 'user' as const, content: 'Hello' }],
|
||||
maxTokens
|
||||
})
|
||||
|
||||
expect(result).toBeDefined()
|
||||
if (result.usage?.completionTokens) {
|
||||
expect(result.usage.completionTokens).toBeLessThanOrEqual(maxTokens)
|
||||
}
|
||||
}
|
||||
},
|
||||
|
||||
testToolCalling: async (executor: any, model: string, tools: Record<string, Tool>) => {
|
||||
const result = await executor.generateText({
|
||||
model,
|
||||
messages: [{ role: 'user' as const, content: 'What is the weather in SF?' }],
|
||||
tools
|
||||
})
|
||||
|
||||
expect(result).toBeDefined()
|
||||
},
|
||||
|
||||
testStopSequences: async (executor: any, model: string, stopSequences: string[][]) => {
|
||||
for (const stop of stopSequences) {
|
||||
const result = await executor.generateText({
|
||||
model,
|
||||
messages: [{ role: 'user' as const, content: 'Count to 10' }],
|
||||
stop
|
||||
})
|
||||
|
||||
expect(result).toBeDefined()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Generates test data for vision/multimodal testing
|
||||
*/
|
||||
export function createVisionTestData() {
|
||||
return {
|
||||
imageUrl: 'https://example.com/test-image.jpg',
|
||||
base64Image:
|
||||
'data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==',
|
||||
messages: [
|
||||
{
|
||||
role: 'user' as const,
|
||||
content: [
|
||||
{ type: 'text' as const, text: 'What is in this image?' },
|
||||
{
|
||||
type: 'image' as const,
|
||||
image:
|
||||
'data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg=='
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates mock responses for different finish reasons
|
||||
*/
|
||||
export function createFinishReasonMocks() {
|
||||
return {
|
||||
stop: {
|
||||
text: 'Complete response.',
|
||||
finishReason: 'stop' as const,
|
||||
usage: { promptTokens: 10, completionTokens: 5, totalTokens: 15 }
|
||||
},
|
||||
length: {
|
||||
text: 'Incomplete response due to',
|
||||
finishReason: 'length' as const,
|
||||
usage: { promptTokens: 10, completionTokens: 100, totalTokens: 110 }
|
||||
},
|
||||
'tool-calls': {
|
||||
text: 'Calling tools',
|
||||
finishReason: 'tool-calls' as const,
|
||||
toolCalls: [{ toolCallId: 'call_1', toolName: 'getWeather', args: { location: 'SF' } }],
|
||||
usage: { promptTokens: 10, completionTokens: 8, totalTokens: 18 }
|
||||
},
|
||||
'content-filter': {
|
||||
text: '',
|
||||
finishReason: 'content-filter' as const,
|
||||
usage: { promptTokens: 10, completionTokens: 0, totalTokens: 10 }
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,291 +0,0 @@
|
||||
/**
|
||||
* Test Utilities
|
||||
* Helper functions for testing AI Core functionality
|
||||
*/
|
||||
|
||||
import { expect, vi } from 'vitest'
|
||||
|
||||
import type { ProviderId } from '../fixtures/mock-providers'
|
||||
import { createMockImageModel, createMockLanguageModel, mockProviderConfigs } from '../fixtures/mock-providers'
|
||||
|
||||
/**
|
||||
* Creates a test provider with streaming support
|
||||
*/
|
||||
export function createTestStreamingProvider(chunks: any[]) {
|
||||
return createMockLanguageModel({
|
||||
doStream: vi.fn().mockReturnValue({
|
||||
stream: (async function* () {
|
||||
for (const chunk of chunks) {
|
||||
yield chunk
|
||||
}
|
||||
})(),
|
||||
rawCall: { rawPrompt: null, rawSettings: {} },
|
||||
rawResponse: { headers: {} },
|
||||
warnings: []
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a test provider that throws errors
|
||||
*/
|
||||
export function createErrorProvider(error: Error) {
|
||||
return createMockLanguageModel({
|
||||
doGenerate: vi.fn().mockRejectedValue(error),
|
||||
doStream: vi.fn().mockImplementation(() => {
|
||||
throw error
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
/**
|
||||
* Collects all chunks from a stream
|
||||
*/
|
||||
export async function collectStreamChunks<T>(stream: AsyncIterable<T>): Promise<T[]> {
|
||||
const chunks: T[] = []
|
||||
for await (const chunk of stream) {
|
||||
chunks.push(chunk)
|
||||
}
|
||||
return chunks
|
||||
}
|
||||
|
||||
/**
|
||||
* Waits for a specific number of milliseconds
|
||||
*/
|
||||
export function wait(ms: number): Promise<void> {
|
||||
return new Promise((resolve) => setTimeout(resolve, ms))
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a mock abort controller that aborts after a delay
|
||||
*/
|
||||
export function createDelayedAbortController(delayMs: number): AbortController {
|
||||
const controller = new AbortController()
|
||||
setTimeout(() => controller.abort(), delayMs)
|
||||
return controller
|
||||
}
|
||||
|
||||
/**
|
||||
* Asserts that a function throws an error with a specific message
|
||||
*/
|
||||
export async function expectError(fn: () => Promise<any>, expectedMessage?: string | RegExp): Promise<Error> {
|
||||
try {
|
||||
await fn()
|
||||
throw new Error('Expected function to throw an error, but it did not')
|
||||
} catch (error) {
|
||||
if (expectedMessage) {
|
||||
const message = (error as Error).message
|
||||
if (typeof expectedMessage === 'string') {
|
||||
if (!message.includes(expectedMessage)) {
|
||||
throw new Error(`Expected error message to include "${expectedMessage}", but got "${message}"`)
|
||||
}
|
||||
} else {
|
||||
if (!expectedMessage.test(message)) {
|
||||
throw new Error(`Expected error message to match ${expectedMessage}, but got "${message}"`)
|
||||
}
|
||||
}
|
||||
}
|
||||
return error as Error
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a spy function that tracks calls and arguments
|
||||
*/
|
||||
export function createSpy<T extends (...args: any[]) => any>() {
|
||||
const calls: Array<{ args: Parameters<T>; result?: ReturnType<T>; error?: Error }> = []
|
||||
|
||||
const spy = vi.fn((...args: Parameters<T>) => {
|
||||
try {
|
||||
const result = undefined as ReturnType<T>
|
||||
calls.push({ args, result })
|
||||
return result
|
||||
} catch (error) {
|
||||
calls.push({ args, error: error as Error })
|
||||
throw error
|
||||
}
|
||||
})
|
||||
|
||||
return {
|
||||
fn: spy,
|
||||
calls,
|
||||
getCalls: () => calls,
|
||||
getCallCount: () => calls.length,
|
||||
getLastCall: () => calls[calls.length - 1],
|
||||
reset: () => {
|
||||
calls.length = 0
|
||||
spy.mockClear()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Validates provider configuration
|
||||
*/
|
||||
export function validateProviderConfig(providerId: ProviderId) {
|
||||
const config = mockProviderConfigs[providerId]
|
||||
if (!config) {
|
||||
throw new Error(`No mock configuration found for provider: ${providerId}`)
|
||||
}
|
||||
|
||||
if (!config.apiKey) {
|
||||
throw new Error(`Provider ${providerId} is missing apiKey in mock config`)
|
||||
}
|
||||
|
||||
return config
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a test context with common setup
|
||||
*/
|
||||
export function createTestContext() {
|
||||
const mocks = {
|
||||
languageModel: createMockLanguageModel(),
|
||||
imageModel: createMockImageModel(),
|
||||
providers: new Map<string, any>()
|
||||
}
|
||||
|
||||
const cleanup = () => {
|
||||
mocks.providers.clear()
|
||||
vi.clearAllMocks()
|
||||
}
|
||||
|
||||
return {
|
||||
mocks,
|
||||
cleanup
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Measures execution time of an async function
|
||||
*/
|
||||
export async function measureTime<T>(fn: () => Promise<T>): Promise<{ result: T; duration: number }> {
|
||||
const start = Date.now()
|
||||
const result = await fn()
|
||||
const duration = Date.now() - start
|
||||
return { result, duration }
|
||||
}
|
||||
|
||||
/**
|
||||
* Retries a function until it succeeds or max attempts reached
|
||||
*/
|
||||
export async function retryUntilSuccess<T>(fn: () => Promise<T>, maxAttempts = 3, delayMs = 100): Promise<T> {
|
||||
let lastError: Error | undefined
|
||||
|
||||
for (let attempt = 1; attempt <= maxAttempts; attempt++) {
|
||||
try {
|
||||
return await fn()
|
||||
} catch (error) {
|
||||
lastError = error as Error
|
||||
if (attempt < maxAttempts) {
|
||||
await wait(delayMs)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
throw lastError || new Error('All retry attempts failed')
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a mock streaming response that emits chunks at intervals
|
||||
*/
|
||||
export function createTimedStream<T>(chunks: T[], intervalMs = 10) {
|
||||
return {
|
||||
async *[Symbol.asyncIterator]() {
|
||||
for (const chunk of chunks) {
|
||||
await wait(intervalMs)
|
||||
yield chunk
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Asserts that two objects are deeply equal, ignoring specified keys
|
||||
*/
|
||||
export function assertDeepEqualIgnoring<T extends Record<string, any>>(
|
||||
actual: T,
|
||||
expected: T,
|
||||
ignoreKeys: string[] = []
|
||||
): void {
|
||||
const filterKeys = (obj: T): Partial<T> => {
|
||||
const filtered = { ...obj }
|
||||
for (const key of ignoreKeys) {
|
||||
delete filtered[key]
|
||||
}
|
||||
return filtered
|
||||
}
|
||||
|
||||
const filteredActual = filterKeys(actual)
|
||||
const filteredExpected = filterKeys(expected)
|
||||
|
||||
expect(filteredActual).toEqual(filteredExpected)
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a provider mock that simulates rate limiting
|
||||
*/
|
||||
export function createRateLimitedProvider(limitPerSecond: number) {
|
||||
const calls: number[] = []
|
||||
|
||||
return createMockLanguageModel({
|
||||
doGenerate: vi.fn().mockImplementation(async () => {
|
||||
const now = Date.now()
|
||||
calls.push(now)
|
||||
|
||||
// Remove calls older than 1 second
|
||||
const recentCalls = calls.filter((time) => now - time < 1000)
|
||||
|
||||
if (recentCalls.length > limitPerSecond) {
|
||||
throw new Error('Rate limit exceeded')
|
||||
}
|
||||
|
||||
return {
|
||||
text: 'Rate limited response',
|
||||
finishReason: 'stop' as const,
|
||||
usage: { promptTokens: 10, completionTokens: 5, totalTokens: 15 },
|
||||
rawCall: { rawPrompt: null, rawSettings: {} },
|
||||
rawResponse: { headers: {} },
|
||||
warnings: []
|
||||
}
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
/**
|
||||
* Validates streaming response structure
|
||||
*/
|
||||
export function validateStreamChunk(chunk: any): void {
|
||||
expect(chunk).toBeDefined()
|
||||
expect(chunk).toHaveProperty('type')
|
||||
|
||||
if (chunk.type === 'text-delta') {
|
||||
expect(chunk).toHaveProperty('textDelta')
|
||||
expect(typeof chunk.textDelta).toBe('string')
|
||||
} else if (chunk.type === 'finish') {
|
||||
expect(chunk).toHaveProperty('finishReason')
|
||||
expect(chunk).toHaveProperty('usage')
|
||||
} else if (chunk.type === 'tool-call') {
|
||||
expect(chunk).toHaveProperty('toolCallId')
|
||||
expect(chunk).toHaveProperty('toolName')
|
||||
expect(chunk).toHaveProperty('args')
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates a test logger that captures log messages
|
||||
*/
|
||||
export function createTestLogger() {
|
||||
const logs: Array<{ level: string; message: string; meta?: any }> = []
|
||||
|
||||
return {
|
||||
info: (message: string, meta?: any) => logs.push({ level: 'info', message, meta }),
|
||||
warn: (message: string, meta?: any) => logs.push({ level: 'warn', message, meta }),
|
||||
error: (message: string, meta?: any) => logs.push({ level: 'error', message, meta }),
|
||||
debug: (message: string, meta?: any) => logs.push({ level: 'debug', message, meta }),
|
||||
getLogs: () => logs,
|
||||
clear: () => {
|
||||
logs.length = 0
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,12 +0,0 @@
|
||||
/**
|
||||
* Test Infrastructure Exports
|
||||
* Central export point for all test utilities, fixtures, and helpers
|
||||
*/
|
||||
|
||||
// Fixtures
|
||||
export * from './fixtures/mock-providers'
|
||||
export * from './fixtures/mock-responses'
|
||||
|
||||
// Helpers
|
||||
export * from './helpers/provider-test-utils'
|
||||
export * from './helpers/test-utils'
|
||||
@@ -1,499 +0,0 @@
|
||||
/**
|
||||
* RuntimeExecutor.generateText Comprehensive Tests
|
||||
* Tests non-streaming text generation across all providers with various parameters
|
||||
*/
|
||||
|
||||
import { generateText } from 'ai'
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import {
|
||||
createMockLanguageModel,
|
||||
mockCompleteResponses,
|
||||
mockProviderConfigs,
|
||||
testMessages,
|
||||
testTools
|
||||
} from '../../../__tests__'
|
||||
import type { AiPlugin } from '../../plugins'
|
||||
import { globalRegistryManagement } from '../../providers/RegistryManagement'
|
||||
import { RuntimeExecutor } from '../executor'
|
||||
|
||||
// Mock AI SDK
|
||||
vi.mock('ai', () => ({
|
||||
generateText: vi.fn()
|
||||
}))
|
||||
|
||||
vi.mock('../../providers/RegistryManagement', () => ({
|
||||
globalRegistryManagement: {
|
||||
languageModel: vi.fn()
|
||||
},
|
||||
DEFAULT_SEPARATOR: '|'
|
||||
}))
|
||||
|
||||
describe('RuntimeExecutor.generateText', () => {
|
||||
let executor: RuntimeExecutor<'openai'>
|
||||
let mockLanguageModel: any
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
|
||||
executor = RuntimeExecutor.create('openai', mockProviderConfigs.openai)
|
||||
|
||||
mockLanguageModel = createMockLanguageModel({
|
||||
provider: 'openai',
|
||||
modelId: 'gpt-4'
|
||||
})
|
||||
|
||||
vi.mocked(globalRegistryManagement.languageModel).mockReturnValue(mockLanguageModel)
|
||||
vi.mocked(generateText).mockResolvedValue(mockCompleteResponses.simple as any)
|
||||
})
|
||||
|
||||
describe('Basic Functionality', () => {
|
||||
it('should generate text with minimal parameters', async () => {
|
||||
const result = await executor.generateText({
|
||||
model: 'gpt-4',
|
||||
messages: testMessages.simple
|
||||
})
|
||||
|
||||
expect(generateText).toHaveBeenCalledWith({
|
||||
model: mockLanguageModel,
|
||||
messages: testMessages.simple
|
||||
})
|
||||
|
||||
expect(result.text).toBe('This is a simple response.')
|
||||
expect(result.finishReason).toBe('stop')
|
||||
expect(result.usage).toBeDefined()
|
||||
})
|
||||
|
||||
it('should generate with system messages', async () => {
|
||||
await executor.generateText({
|
||||
model: 'gpt-4',
|
||||
messages: testMessages.withSystem
|
||||
})
|
||||
|
||||
expect(generateText).toHaveBeenCalledWith({
|
||||
model: mockLanguageModel,
|
||||
messages: testMessages.withSystem
|
||||
})
|
||||
})
|
||||
|
||||
it('should generate with conversation history', async () => {
|
||||
await executor.generateText({
|
||||
model: 'gpt-4',
|
||||
messages: testMessages.conversation
|
||||
})
|
||||
|
||||
expect(generateText).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
messages: testMessages.conversation
|
||||
})
|
||||
)
|
||||
})
|
||||
})
|
||||
|
||||
describe('All Parameter Combinations', () => {
|
||||
it('should support all parameters together', async () => {
|
||||
await executor.generateText({
|
||||
model: 'gpt-4',
|
||||
messages: testMessages.simple,
|
||||
temperature: 0.7,
|
||||
maxOutputTokens: 500,
|
||||
topP: 0.9,
|
||||
frequencyPenalty: 0.5,
|
||||
presencePenalty: 0.3,
|
||||
stopSequences: ['STOP'],
|
||||
seed: 12345
|
||||
})
|
||||
|
||||
expect(generateText).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
temperature: 0.7,
|
||||
maxOutputTokens: 500,
|
||||
topP: 0.9,
|
||||
frequencyPenalty: 0.5,
|
||||
presencePenalty: 0.3,
|
||||
stopSequences: ['STOP'],
|
||||
seed: 12345
|
||||
})
|
||||
)
|
||||
})
|
||||
|
||||
it('should support partial parameters', async () => {
|
||||
await executor.generateText({
|
||||
model: 'gpt-4',
|
||||
messages: testMessages.simple,
|
||||
temperature: 0.5,
|
||||
maxOutputTokens: 100
|
||||
})
|
||||
|
||||
expect(generateText).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
temperature: 0.5,
|
||||
maxOutputTokens: 100
|
||||
})
|
||||
)
|
||||
})
|
||||
})
|
||||
|
||||
describe('Tool Calling', () => {
|
||||
beforeEach(() => {
|
||||
vi.mocked(generateText).mockResolvedValue(mockCompleteResponses.withToolCalls as any)
|
||||
})
|
||||
|
||||
it('should support tool calling', async () => {
|
||||
const result = await executor.generateText({
|
||||
model: 'gpt-4',
|
||||
messages: testMessages.toolUse,
|
||||
tools: testTools
|
||||
})
|
||||
|
||||
expect(generateText).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
tools: testTools
|
||||
})
|
||||
)
|
||||
|
||||
expect(result.toolCalls).toBeDefined()
|
||||
expect(result.toolCalls).toHaveLength(1)
|
||||
})
|
||||
|
||||
it('should support toolChoice auto', async () => {
|
||||
await executor.generateText({
|
||||
model: 'gpt-4',
|
||||
messages: testMessages.toolUse,
|
||||
tools: testTools,
|
||||
toolChoice: 'auto'
|
||||
})
|
||||
|
||||
expect(generateText).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
toolChoice: 'auto'
|
||||
})
|
||||
)
|
||||
})
|
||||
|
||||
it('should support toolChoice required', async () => {
|
||||
await executor.generateText({
|
||||
model: 'gpt-4',
|
||||
messages: testMessages.toolUse,
|
||||
tools: testTools,
|
||||
toolChoice: 'required'
|
||||
})
|
||||
|
||||
expect(generateText).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
toolChoice: 'required'
|
||||
})
|
||||
)
|
||||
})
|
||||
|
||||
it('should support toolChoice none', async () => {
|
||||
vi.mocked(generateText).mockResolvedValue(mockCompleteResponses.simple as any)
|
||||
|
||||
await executor.generateText({
|
||||
model: 'gpt-4',
|
||||
messages: testMessages.simple,
|
||||
tools: testTools,
|
||||
toolChoice: 'none'
|
||||
})
|
||||
|
||||
expect(generateText).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
toolChoice: 'none'
|
||||
})
|
||||
)
|
||||
})
|
||||
|
||||
it('should support specific tool selection', async () => {
|
||||
await executor.generateText({
|
||||
model: 'gpt-4',
|
||||
messages: testMessages.toolUse,
|
||||
tools: testTools,
|
||||
toolChoice: {
|
||||
type: 'tool',
|
||||
toolName: 'getWeather'
|
||||
}
|
||||
})
|
||||
|
||||
expect(generateText).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
toolChoice: {
|
||||
type: 'tool',
|
||||
toolName: 'getWeather'
|
||||
}
|
||||
})
|
||||
)
|
||||
})
|
||||
})
|
||||
|
||||
describe('Multiple Providers', () => {
|
||||
it('should work with Anthropic provider', async () => {
|
||||
const anthropicExecutor = RuntimeExecutor.create('anthropic', mockProviderConfigs.anthropic)
|
||||
|
||||
const anthropicModel = createMockLanguageModel({
|
||||
provider: 'anthropic',
|
||||
modelId: 'claude-3-5-sonnet-20241022'
|
||||
})
|
||||
|
||||
vi.mocked(globalRegistryManagement.languageModel).mockReturnValue(anthropicModel)
|
||||
|
||||
await anthropicExecutor.generateText({
|
||||
model: 'claude-3-5-sonnet-20241022',
|
||||
messages: testMessages.simple
|
||||
})
|
||||
|
||||
expect(globalRegistryManagement.languageModel).toHaveBeenCalledWith('anthropic|claude-3-5-sonnet-20241022')
|
||||
})
|
||||
|
||||
it('should work with Google provider', async () => {
|
||||
const googleExecutor = RuntimeExecutor.create('google', mockProviderConfigs.google)
|
||||
|
||||
const googleModel = createMockLanguageModel({
|
||||
provider: 'google',
|
||||
modelId: 'gemini-2.0-flash-exp'
|
||||
})
|
||||
|
||||
vi.mocked(globalRegistryManagement.languageModel).mockReturnValue(googleModel)
|
||||
|
||||
await googleExecutor.generateText({
|
||||
model: 'gemini-2.0-flash-exp',
|
||||
messages: testMessages.simple
|
||||
})
|
||||
|
||||
expect(globalRegistryManagement.languageModel).toHaveBeenCalledWith('google|gemini-2.0-flash-exp')
|
||||
})
|
||||
|
||||
it('should work with xAI provider', async () => {
|
||||
const xaiExecutor = RuntimeExecutor.create('xai', mockProviderConfigs.xai)
|
||||
|
||||
const xaiModel = createMockLanguageModel({
|
||||
provider: 'xai',
|
||||
modelId: 'grok-2-latest'
|
||||
})
|
||||
|
||||
vi.mocked(globalRegistryManagement.languageModel).mockReturnValue(xaiModel)
|
||||
|
||||
await xaiExecutor.generateText({
|
||||
model: 'grok-2-latest',
|
||||
messages: testMessages.simple
|
||||
})
|
||||
|
||||
expect(globalRegistryManagement.languageModel).toHaveBeenCalledWith('xai|grok-2-latest')
|
||||
})
|
||||
|
||||
it('should work with DeepSeek provider', async () => {
|
||||
const deepseekExecutor = RuntimeExecutor.create('deepseek', mockProviderConfigs.deepseek)
|
||||
|
||||
const deepseekModel = createMockLanguageModel({
|
||||
provider: 'deepseek',
|
||||
modelId: 'deepseek-chat'
|
||||
})
|
||||
|
||||
vi.mocked(globalRegistryManagement.languageModel).mockReturnValue(deepseekModel)
|
||||
|
||||
await deepseekExecutor.generateText({
|
||||
model: 'deepseek-chat',
|
||||
messages: testMessages.simple
|
||||
})
|
||||
|
||||
expect(globalRegistryManagement.languageModel).toHaveBeenCalledWith('deepseek|deepseek-chat')
|
||||
})
|
||||
})
|
||||
|
||||
describe('Plugin Integration', () => {
|
||||
it('should execute all plugin hooks', async () => {
|
||||
const pluginCalls: string[] = []
|
||||
|
||||
const testPlugin: AiPlugin = {
|
||||
name: 'test-plugin',
|
||||
onRequestStart: vi.fn(async () => {
|
||||
pluginCalls.push('onRequestStart')
|
||||
}),
|
||||
transformParams: vi.fn(async (params) => {
|
||||
pluginCalls.push('transformParams')
|
||||
return { ...params, temperature: 0.8 }
|
||||
}),
|
||||
transformResult: vi.fn(async (result) => {
|
||||
pluginCalls.push('transformResult')
|
||||
return { ...result, text: result.text + ' [modified]' }
|
||||
}),
|
||||
onRequestEnd: vi.fn(async () => {
|
||||
pluginCalls.push('onRequestEnd')
|
||||
})
|
||||
}
|
||||
|
||||
const executorWithPlugin = RuntimeExecutor.create('openai', mockProviderConfigs.openai, [testPlugin])
|
||||
|
||||
const result = await executorWithPlugin.generateText({
|
||||
model: 'gpt-4',
|
||||
messages: testMessages.simple
|
||||
})
|
||||
|
||||
expect(pluginCalls).toEqual(['onRequestStart', 'transformParams', 'transformResult', 'onRequestEnd'])
|
||||
|
||||
// Verify transformed parameters
|
||||
expect(generateText).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
temperature: 0.8
|
||||
})
|
||||
)
|
||||
|
||||
// Verify transformed result
|
||||
expect(result.text).toContain('[modified]')
|
||||
})
|
||||
|
||||
it('should handle multiple plugins in order', async () => {
|
||||
const pluginOrder: string[] = []
|
||||
|
||||
const plugin1: AiPlugin = {
|
||||
name: 'plugin-1',
|
||||
transformParams: vi.fn(async (params) => {
|
||||
pluginOrder.push('plugin-1')
|
||||
return { ...params, temperature: 0.5 }
|
||||
})
|
||||
}
|
||||
|
||||
const plugin2: AiPlugin = {
|
||||
name: 'plugin-2',
|
||||
transformParams: vi.fn(async (params) => {
|
||||
pluginOrder.push('plugin-2')
|
||||
return { ...params, maxTokens: 200 }
|
||||
})
|
||||
}
|
||||
|
||||
const executorWithPlugins = RuntimeExecutor.create('openai', mockProviderConfigs.openai, [plugin1, plugin2])
|
||||
|
||||
await executorWithPlugins.generateText({
|
||||
model: 'gpt-4',
|
||||
messages: testMessages.simple
|
||||
})
|
||||
|
||||
expect(pluginOrder).toEqual(['plugin-1', 'plugin-2'])
|
||||
|
||||
expect(generateText).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
temperature: 0.5,
|
||||
maxTokens: 200
|
||||
})
|
||||
)
|
||||
})
|
||||
})
|
||||
|
||||
describe('Error Handling', () => {
|
||||
it('should handle API errors', async () => {
|
||||
const error = new Error('API request failed')
|
||||
vi.mocked(generateText).mockRejectedValue(error)
|
||||
|
||||
await expect(
|
||||
executor.generateText({
|
||||
model: 'gpt-4',
|
||||
messages: testMessages.simple
|
||||
})
|
||||
).rejects.toThrow('API request failed')
|
||||
})
|
||||
|
||||
it('should execute onError plugin hook', async () => {
|
||||
const error = new Error('Generation failed')
|
||||
vi.mocked(generateText).mockRejectedValue(error)
|
||||
|
||||
const errorPlugin: AiPlugin = {
|
||||
name: 'error-handler',
|
||||
onError: vi.fn()
|
||||
}
|
||||
|
||||
const executorWithPlugin = RuntimeExecutor.create('openai', mockProviderConfigs.openai, [errorPlugin])
|
||||
|
||||
await expect(
|
||||
executorWithPlugin.generateText({
|
||||
model: 'gpt-4',
|
||||
messages: testMessages.simple
|
||||
})
|
||||
).rejects.toThrow('Generation failed')
|
||||
|
||||
expect(errorPlugin.onError).toHaveBeenCalledWith(
|
||||
error,
|
||||
expect.objectContaining({
|
||||
providerId: 'openai',
|
||||
modelId: 'gpt-4'
|
||||
})
|
||||
)
|
||||
})
|
||||
|
||||
it('should handle model not found error', async () => {
|
||||
const error = new Error('Model not found: invalid-model')
|
||||
vi.mocked(globalRegistryManagement.languageModel).mockImplementation(() => {
|
||||
throw error
|
||||
})
|
||||
|
||||
await expect(
|
||||
executor.generateText({
|
||||
model: 'invalid-model',
|
||||
messages: testMessages.simple
|
||||
})
|
||||
).rejects.toThrow('Model not found')
|
||||
})
|
||||
})
|
||||
|
||||
describe('Usage and Metadata', () => {
|
||||
it('should return usage information', async () => {
|
||||
const result = await executor.generateText({
|
||||
model: 'gpt-4',
|
||||
messages: testMessages.simple
|
||||
})
|
||||
|
||||
expect(result.usage).toBeDefined()
|
||||
expect(result.usage.inputTokens).toBe(15)
|
||||
expect(result.usage.outputTokens).toBe(8)
|
||||
expect(result.usage.totalTokens).toBe(23)
|
||||
})
|
||||
|
||||
it('should handle warnings', async () => {
|
||||
vi.mocked(generateText).mockResolvedValue(mockCompleteResponses.withWarnings as any)
|
||||
|
||||
const result = await executor.generateText({
|
||||
model: 'gpt-4',
|
||||
messages: testMessages.simple,
|
||||
temperature: 2.5 // Unsupported value
|
||||
})
|
||||
|
||||
expect(result.warnings).toBeDefined()
|
||||
expect(result.warnings).toHaveLength(1)
|
||||
expect(result.warnings![0].type).toBe('unsupported-setting')
|
||||
})
|
||||
})
|
||||
|
||||
describe('Abort Signal', () => {
|
||||
it('should support abort signal', async () => {
|
||||
const abortController = new AbortController()
|
||||
|
||||
await executor.generateText({
|
||||
model: 'gpt-4',
|
||||
messages: testMessages.simple,
|
||||
abortSignal: abortController.signal
|
||||
})
|
||||
|
||||
expect(generateText).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
abortSignal: abortController.signal
|
||||
})
|
||||
)
|
||||
})
|
||||
|
||||
it('should handle aborted request', async () => {
|
||||
const abortError = new Error('Request aborted')
|
||||
abortError.name = 'AbortError'
|
||||
|
||||
vi.mocked(generateText).mockRejectedValue(abortError)
|
||||
|
||||
const abortController = new AbortController()
|
||||
abortController.abort()
|
||||
|
||||
await expect(
|
||||
executor.generateText({
|
||||
model: 'gpt-4',
|
||||
messages: testMessages.simple,
|
||||
abortSignal: abortController.signal
|
||||
})
|
||||
).rejects.toThrow('Request aborted')
|
||||
})
|
||||
})
|
||||
})
|
||||
@@ -1,525 +0,0 @@
|
||||
/**
|
||||
* RuntimeExecutor.streamText Comprehensive Tests
|
||||
* Tests streaming text generation across all providers with various parameters
|
||||
*/
|
||||
|
||||
import { streamText } from 'ai'
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import { collectStreamChunks, createMockLanguageModel, mockProviderConfigs, testMessages } from '../../../__tests__'
|
||||
import type { AiPlugin } from '../../plugins'
|
||||
import { globalRegistryManagement } from '../../providers/RegistryManagement'
|
||||
import { RuntimeExecutor } from '../executor'
|
||||
|
||||
// Mock AI SDK
|
||||
vi.mock('ai', () => ({
|
||||
streamText: vi.fn()
|
||||
}))
|
||||
|
||||
vi.mock('../../providers/RegistryManagement', () => ({
|
||||
globalRegistryManagement: {
|
||||
languageModel: vi.fn()
|
||||
},
|
||||
DEFAULT_SEPARATOR: '|'
|
||||
}))
|
||||
|
||||
describe('RuntimeExecutor.streamText', () => {
|
||||
let executor: RuntimeExecutor<'openai'>
|
||||
let mockLanguageModel: any
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
|
||||
executor = RuntimeExecutor.create('openai', mockProviderConfigs.openai)
|
||||
|
||||
mockLanguageModel = createMockLanguageModel({
|
||||
provider: 'openai',
|
||||
modelId: 'gpt-4'
|
||||
})
|
||||
|
||||
vi.mocked(globalRegistryManagement.languageModel).mockReturnValue(mockLanguageModel)
|
||||
})
|
||||
|
||||
describe('Basic Functionality', () => {
|
||||
it('should stream text with minimal parameters', async () => {
|
||||
const mockStream = {
|
||||
textStream: (async function* () {
|
||||
yield 'Hello'
|
||||
yield ' '
|
||||
yield 'World'
|
||||
})(),
|
||||
fullStream: (async function* () {
|
||||
yield { type: 'text-delta', textDelta: 'Hello' }
|
||||
yield { type: 'text-delta', textDelta: ' ' }
|
||||
yield { type: 'text-delta', textDelta: 'World' }
|
||||
})(),
|
||||
usage: Promise.resolve({ promptTokens: 5, completionTokens: 3, totalTokens: 8 })
|
||||
}
|
||||
|
||||
vi.mocked(streamText).mockResolvedValue(mockStream as any)
|
||||
|
||||
const result = await executor.streamText({
|
||||
model: 'gpt-4',
|
||||
messages: testMessages.simple
|
||||
})
|
||||
|
||||
expect(streamText).toHaveBeenCalledWith({
|
||||
model: mockLanguageModel,
|
||||
messages: testMessages.simple
|
||||
})
|
||||
|
||||
const chunks = await collectStreamChunks(result.textStream)
|
||||
expect(chunks).toEqual(['Hello', ' ', 'World'])
|
||||
})
|
||||
|
||||
it('should stream with system messages', async () => {
|
||||
const mockStream = {
|
||||
textStream: (async function* () {
|
||||
yield 'Response'
|
||||
})(),
|
||||
fullStream: (async function* () {
|
||||
yield { type: 'text-delta', textDelta: 'Response' }
|
||||
})()
|
||||
}
|
||||
|
||||
vi.mocked(streamText).mockResolvedValue(mockStream as any)
|
||||
|
||||
await executor.streamText({
|
||||
model: 'gpt-4',
|
||||
messages: testMessages.withSystem
|
||||
})
|
||||
|
||||
expect(streamText).toHaveBeenCalledWith({
|
||||
model: mockLanguageModel,
|
||||
messages: testMessages.withSystem
|
||||
})
|
||||
})
|
||||
|
||||
it('should stream multi-turn conversations', async () => {
|
||||
const mockStream = {
|
||||
textStream: (async function* () {
|
||||
yield 'Multi-turn response'
|
||||
})(),
|
||||
fullStream: (async function* () {
|
||||
yield { type: 'text-delta', textDelta: 'Multi-turn response' }
|
||||
})()
|
||||
}
|
||||
|
||||
vi.mocked(streamText).mockResolvedValue(mockStream as any)
|
||||
|
||||
await executor.streamText({
|
||||
model: 'gpt-4',
|
||||
messages: testMessages.multiTurn
|
||||
})
|
||||
|
||||
expect(streamText).toHaveBeenCalled()
|
||||
expect(streamText).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
messages: testMessages.multiTurn
|
||||
})
|
||||
)
|
||||
})
|
||||
})
|
||||
|
||||
describe('Temperature Parameter', () => {
|
||||
const temperatures = [0, 0.3, 0.5, 0.7, 0.9, 1.0, 1.5, 2.0]
|
||||
|
||||
it.each(temperatures)('should support temperature=%s', async (temperature) => {
|
||||
const mockStream = {
|
||||
textStream: (async function* () {
|
||||
yield 'Response'
|
||||
})(),
|
||||
fullStream: (async function* () {
|
||||
yield { type: 'text-delta', textDelta: 'Response' }
|
||||
})()
|
||||
}
|
||||
|
||||
vi.mocked(streamText).mockResolvedValue(mockStream as any)
|
||||
|
||||
await executor.streamText({
|
||||
model: 'gpt-4',
|
||||
messages: testMessages.simple,
|
||||
temperature
|
||||
})
|
||||
|
||||
expect(streamText).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
temperature
|
||||
})
|
||||
)
|
||||
})
|
||||
})
|
||||
|
||||
describe('Max Tokens Parameter', () => {
|
||||
const maxTokensValues = [10, 50, 100, 500, 1000, 2000, 4000]
|
||||
|
||||
it.each(maxTokensValues)('should support maxTokens=%s', async (maxTokens) => {
|
||||
const mockStream = {
|
||||
textStream: (async function* () {
|
||||
yield 'Response'
|
||||
})(),
|
||||
fullStream: (async function* () {
|
||||
yield { type: 'text-delta', textDelta: 'Response' }
|
||||
})()
|
||||
}
|
||||
|
||||
vi.mocked(streamText).mockResolvedValue(mockStream as any)
|
||||
|
||||
await executor.streamText({
|
||||
model: 'gpt-4',
|
||||
messages: testMessages.simple,
|
||||
maxOutputTokens: maxTokens
|
||||
})
|
||||
|
||||
expect(streamText).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
maxTokens
|
||||
})
|
||||
)
|
||||
})
|
||||
})
|
||||
|
||||
describe('Top P Parameter', () => {
|
||||
const topPValues = [0.1, 0.3, 0.5, 0.7, 0.9, 0.95, 1.0]
|
||||
|
||||
it.each(topPValues)('should support topP=%s', async (topP) => {
|
||||
const mockStream = {
|
||||
textStream: (async function* () {
|
||||
yield 'Response'
|
||||
})(),
|
||||
fullStream: (async function* () {
|
||||
yield { type: 'text-delta', textDelta: 'Response' }
|
||||
})()
|
||||
}
|
||||
|
||||
vi.mocked(streamText).mockResolvedValue(mockStream as any)
|
||||
|
||||
await executor.streamText({
|
||||
model: 'gpt-4',
|
||||
messages: testMessages.simple,
|
||||
topP
|
||||
})
|
||||
|
||||
expect(streamText).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
topP
|
||||
})
|
||||
)
|
||||
})
|
||||
})
|
||||
|
||||
describe('Frequency and Presence Penalty', () => {
|
||||
it('should support frequency penalty', async () => {
|
||||
const penalties = [-2.0, -1.0, 0, 0.5, 1.0, 1.5, 2.0]
|
||||
|
||||
for (const frequencyPenalty of penalties) {
|
||||
vi.clearAllMocks()
|
||||
|
||||
const mockStream = {
|
||||
textStream: (async function* () {
|
||||
yield 'Response'
|
||||
})(),
|
||||
fullStream: (async function* () {
|
||||
yield { type: 'text-delta', textDelta: 'Response' }
|
||||
})()
|
||||
}
|
||||
|
||||
vi.mocked(streamText).mockResolvedValue(mockStream as any)
|
||||
|
||||
await executor.streamText({
|
||||
model: 'gpt-4',
|
||||
messages: testMessages.simple,
|
||||
frequencyPenalty
|
||||
})
|
||||
|
||||
expect(streamText).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
frequencyPenalty
|
||||
})
|
||||
)
|
||||
}
|
||||
})
|
||||
|
||||
it('should support presence penalty', async () => {
|
||||
const penalties = [-2.0, -1.0, 0, 0.5, 1.0, 1.5, 2.0]
|
||||
|
||||
for (const presencePenalty of penalties) {
|
||||
vi.clearAllMocks()
|
||||
|
||||
const mockStream = {
|
||||
textStream: (async function* () {
|
||||
yield 'Response'
|
||||
})(),
|
||||
fullStream: (async function* () {
|
||||
yield { type: 'text-delta', textDelta: 'Response' }
|
||||
})()
|
||||
}
|
||||
|
||||
vi.mocked(streamText).mockResolvedValue(mockStream as any)
|
||||
|
||||
await executor.streamText({
|
||||
model: 'gpt-4',
|
||||
messages: testMessages.simple,
|
||||
presencePenalty
|
||||
})
|
||||
|
||||
expect(streamText).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
presencePenalty
|
||||
})
|
||||
)
|
||||
}
|
||||
})
|
||||
|
||||
it('should support both penalties together', async () => {
|
||||
const mockStream = {
|
||||
textStream: (async function* () {
|
||||
yield 'Response'
|
||||
})(),
|
||||
fullStream: (async function* () {
|
||||
yield { type: 'text-delta', textDelta: 'Response' }
|
||||
})()
|
||||
}
|
||||
|
||||
vi.mocked(streamText).mockResolvedValue(mockStream as any)
|
||||
|
||||
await executor.streamText({
|
||||
model: 'gpt-4',
|
||||
messages: testMessages.simple,
|
||||
frequencyPenalty: 0.5,
|
||||
presencePenalty: 0.5
|
||||
})
|
||||
|
||||
expect(streamText).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
frequencyPenalty: 0.5,
|
||||
presencePenalty: 0.5
|
||||
})
|
||||
)
|
||||
})
|
||||
})
|
||||
|
||||
describe('Seed Parameter', () => {
|
||||
it('should support seed for deterministic output', async () => {
|
||||
const seeds = [0, 12345, 67890, 999999]
|
||||
|
||||
for (const seed of seeds) {
|
||||
vi.clearAllMocks()
|
||||
|
||||
const mockStream = {
|
||||
textStream: (async function* () {
|
||||
yield 'Response'
|
||||
})(),
|
||||
fullStream: (async function* () {
|
||||
yield { type: 'text-delta', textDelta: 'Response' }
|
||||
})()
|
||||
}
|
||||
|
||||
vi.mocked(streamText).mockResolvedValue(mockStream as any)
|
||||
|
||||
await executor.streamText({
|
||||
model: 'gpt-4',
|
||||
messages: testMessages.simple,
|
||||
seed
|
||||
})
|
||||
|
||||
expect(streamText).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
seed
|
||||
})
|
||||
)
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
describe('Abort Signal', () => {
|
||||
it('should support abort signal', async () => {
|
||||
const abortController = new AbortController()
|
||||
|
||||
const mockStream = {
|
||||
textStream: (async function* () {
|
||||
yield 'Response'
|
||||
})(),
|
||||
fullStream: (async function* () {
|
||||
yield { type: 'text-delta', textDelta: 'Response' }
|
||||
})()
|
||||
}
|
||||
|
||||
vi.mocked(streamText).mockResolvedValue(mockStream as any)
|
||||
|
||||
await executor.streamText({
|
||||
model: 'gpt-4',
|
||||
messages: testMessages.simple,
|
||||
abortSignal: abortController.signal
|
||||
})
|
||||
|
||||
expect(streamText).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
abortSignal: abortController.signal
|
||||
})
|
||||
)
|
||||
})
|
||||
|
||||
it('should handle abort during streaming', async () => {
|
||||
const abortController = new AbortController()
|
||||
|
||||
const mockStream = {
|
||||
textStream: (async function* () {
|
||||
yield 'Start'
|
||||
// Simulate abort
|
||||
abortController.abort()
|
||||
throw new Error('Aborted')
|
||||
})(),
|
||||
fullStream: (async function* () {
|
||||
yield { type: 'text-delta', textDelta: 'Start' }
|
||||
throw new Error('Aborted')
|
||||
})()
|
||||
}
|
||||
|
||||
vi.mocked(streamText).mockResolvedValue(mockStream as any)
|
||||
|
||||
const result = await executor.streamText({
|
||||
model: 'gpt-4',
|
||||
messages: testMessages.simple,
|
||||
abortSignal: abortController.signal
|
||||
})
|
||||
|
||||
await expect(async () => {
|
||||
// oxlint-disable-next-line no-unused-vars
|
||||
for await (const _chunk of result.textStream) {
|
||||
// Stream should be interrupted
|
||||
}
|
||||
}).rejects.toThrow('Aborted')
|
||||
})
|
||||
})
|
||||
|
||||
describe('Plugin Integration', () => {
|
||||
it('should execute plugins during streaming', async () => {
|
||||
const pluginCalls: string[] = []
|
||||
|
||||
const testPlugin: AiPlugin = {
|
||||
name: 'test-plugin',
|
||||
onRequestStart: vi.fn(async () => {
|
||||
pluginCalls.push('onRequestStart')
|
||||
}),
|
||||
transformParams: vi.fn(async (params) => {
|
||||
pluginCalls.push('transformParams')
|
||||
return { ...params, temperature: 0.5 }
|
||||
}),
|
||||
onRequestEnd: vi.fn(async () => {
|
||||
pluginCalls.push('onRequestEnd')
|
||||
})
|
||||
}
|
||||
|
||||
const executorWithPlugin = RuntimeExecutor.create('openai', mockProviderConfigs.openai, [testPlugin])
|
||||
|
||||
const mockStream = {
|
||||
textStream: (async function* () {
|
||||
yield 'Response'
|
||||
})(),
|
||||
fullStream: (async function* () {
|
||||
yield { type: 'text-delta', textDelta: 'Response' }
|
||||
})()
|
||||
}
|
||||
|
||||
vi.mocked(streamText).mockResolvedValue(mockStream as any)
|
||||
|
||||
const result = await executorWithPlugin.streamText({
|
||||
model: 'gpt-4',
|
||||
messages: testMessages.simple
|
||||
})
|
||||
|
||||
// Consume stream
|
||||
// oxlint-disable-next-line no-unused-vars
|
||||
for await (const _chunk of result.textStream) {
|
||||
// Stream chunks
|
||||
}
|
||||
|
||||
expect(pluginCalls).toContain('onRequestStart')
|
||||
expect(pluginCalls).toContain('transformParams')
|
||||
|
||||
// Verify transformed parameters were used
|
||||
expect(streamText).toHaveBeenCalledWith(
|
||||
expect.objectContaining({
|
||||
temperature: 0.5
|
||||
})
|
||||
)
|
||||
})
|
||||
})
|
||||
|
||||
describe('Full Stream with Finish Reason', () => {
|
||||
it('should provide finish reason in full stream', async () => {
|
||||
const mockStream = {
|
||||
textStream: (async function* () {
|
||||
yield 'Response'
|
||||
})(),
|
||||
fullStream: (async function* () {
|
||||
yield { type: 'text-delta', textDelta: 'Response' }
|
||||
yield {
|
||||
type: 'finish',
|
||||
finishReason: 'stop',
|
||||
usage: { promptTokens: 5, completionTokens: 3, totalTokens: 8 }
|
||||
}
|
||||
})()
|
||||
}
|
||||
|
||||
vi.mocked(streamText).mockResolvedValue(mockStream as any)
|
||||
|
||||
const result = await executor.streamText({
|
||||
model: 'gpt-4',
|
||||
messages: testMessages.simple
|
||||
})
|
||||
|
||||
const fullChunks = await collectStreamChunks(result.fullStream)
|
||||
|
||||
expect(fullChunks).toHaveLength(2)
|
||||
expect(fullChunks[0]).toEqual({ type: 'text-delta', textDelta: 'Response' })
|
||||
expect(fullChunks[1]).toEqual({
|
||||
type: 'finish',
|
||||
finishReason: 'stop',
|
||||
usage: { promptTokens: 5, completionTokens: 3, totalTokens: 8 }
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe('Error Handling', () => {
|
||||
it('should handle streaming errors', async () => {
|
||||
const error = new Error('Streaming failed')
|
||||
vi.mocked(streamText).mockRejectedValue(error)
|
||||
|
||||
await expect(
|
||||
executor.streamText({
|
||||
model: 'gpt-4',
|
||||
messages: testMessages.simple
|
||||
})
|
||||
).rejects.toThrow('Streaming failed')
|
||||
})
|
||||
|
||||
it('should execute onError plugin hook on failure', async () => {
|
||||
const error = new Error('Stream error')
|
||||
vi.mocked(streamText).mockRejectedValue(error)
|
||||
|
||||
const errorPlugin: AiPlugin = {
|
||||
name: 'error-handler',
|
||||
onError: vi.fn()
|
||||
}
|
||||
|
||||
const executorWithPlugin = RuntimeExecutor.create('openai', mockProviderConfigs.openai, [errorPlugin])
|
||||
|
||||
await expect(
|
||||
executorWithPlugin.streamText({
|
||||
model: 'gpt-4',
|
||||
messages: testMessages.simple
|
||||
})
|
||||
).rejects.toThrow('Stream error')
|
||||
|
||||
expect(errorPlugin.onError).toHaveBeenCalledWith(
|
||||
error,
|
||||
expect.objectContaining({
|
||||
providerId: 'openai',
|
||||
modelId: 'gpt-4'
|
||||
})
|
||||
)
|
||||
})
|
||||
})
|
||||
})
|
||||
@@ -196,6 +196,9 @@ export enum IpcChannel {
|
||||
File_ValidateNotesDirectory = 'file:validateNotesDirectory',
|
||||
File_StartWatcher = 'file:startWatcher',
|
||||
File_StopWatcher = 'file:stopWatcher',
|
||||
File_PauseWatcher = 'file:pauseWatcher',
|
||||
File_ResumeWatcher = 'file:resumeWatcher',
|
||||
File_BatchUploadMarkdown = 'file:batchUploadMarkdown',
|
||||
File_ShowInFolder = 'file:showInFolder',
|
||||
|
||||
// file service
|
||||
@@ -235,7 +238,6 @@ export enum IpcChannel {
|
||||
System_GetDeviceType = 'system:getDeviceType',
|
||||
System_GetHostname = 'system:getHostname',
|
||||
System_GetCpuName = 'system:getCpuName',
|
||||
System_CheckGitBash = 'system:checkGitBash',
|
||||
|
||||
// DevTools
|
||||
System_ToggleDevTools = 'system:toggleDevTools',
|
||||
|
||||
@@ -10,7 +10,7 @@ export type LoaderReturn = {
|
||||
messageSource?: 'preprocess' | 'embedding' | 'validation'
|
||||
}
|
||||
|
||||
export type FileChangeEventType = 'add' | 'change' | 'unlink' | 'addDir' | 'unlinkDir'
|
||||
export type FileChangeEventType = 'add' | 'change' | 'unlink' | 'addDir' | 'unlinkDir' | 'refresh'
|
||||
|
||||
export type FileChangeEvent = {
|
||||
eventType: FileChangeEventType
|
||||
|
||||
@@ -4,34 +4,3 @@ export const defaultAppHeaders = () => {
|
||||
'X-Title': 'Cherry Studio'
|
||||
}
|
||||
}
|
||||
|
||||
// Following two function are not being used for now.
|
||||
// I may use them in the future, so just keep them commented. - by eurfelux
|
||||
|
||||
/**
|
||||
* Converts an `undefined` value to `null`, otherwise returns the value as-is.
|
||||
* @param value - The value to check
|
||||
* @returns `null` if the input is `undefined`; otherwise the input value
|
||||
*/
|
||||
|
||||
// export function toNullIfUndefined<T>(value: T | undefined): T | null {
|
||||
// if (value === undefined) {
|
||||
// return null
|
||||
// } else {
|
||||
// return value
|
||||
// }
|
||||
// }
|
||||
|
||||
/**
|
||||
* Converts a `null` value to `undefined`, otherwise returns the value as-is.
|
||||
* @param value - The value to check
|
||||
* @returns `undefined` if the input is `null`; otherwise the input value
|
||||
*/
|
||||
|
||||
// export function toUndefinedIfNull<T>(value: T | null): T | undefined {
|
||||
// if (value === null) {
|
||||
// return undefined
|
||||
// } else {
|
||||
// return value
|
||||
// }
|
||||
// }
|
||||
|
||||
@@ -104,6 +104,12 @@ const router = express
|
||||
logger.warn('No models available from providers', { filter })
|
||||
}
|
||||
|
||||
logger.info('Models response ready', {
|
||||
filter,
|
||||
total: response.total,
|
||||
modelIds: response.data.map((m) => m.id)
|
||||
})
|
||||
|
||||
return res.json(response satisfies ApiModelsResponse)
|
||||
} catch (error: any) {
|
||||
logger.error('Error fetching models', { error })
|
||||
|
||||
@@ -3,6 +3,7 @@ import { createServer } from 'node:http'
|
||||
import { loggerService } from '@logger'
|
||||
import { IpcChannel } from '@shared/IpcChannel'
|
||||
|
||||
import { agentService } from '../services/agents'
|
||||
import { windowService } from '../services/WindowService'
|
||||
import { app } from './app'
|
||||
import { config } from './config'
|
||||
@@ -31,6 +32,11 @@ export class ApiServer {
|
||||
// Load config
|
||||
const { port, host } = await config.load()
|
||||
|
||||
// Initialize AgentService
|
||||
logger.info('Initializing AgentService')
|
||||
await agentService.initialize()
|
||||
logger.info('AgentService initialized')
|
||||
|
||||
// Create server with Express app
|
||||
this.server = createServer(app)
|
||||
this.applyServerTimeouts(this.server)
|
||||
|
||||
@@ -32,7 +32,7 @@ export class ModelsService {
|
||||
|
||||
for (const model of models) {
|
||||
const provider = providers.find((p) => p.id === model.provider)
|
||||
// logger.debug(`Processing model ${model.id}`)
|
||||
logger.debug(`Processing model ${model.id}`)
|
||||
if (!provider) {
|
||||
logger.debug(`Skipping model ${model.id} . Reason: Provider not found.`)
|
||||
continue
|
||||
|
||||
@@ -34,7 +34,6 @@ import { TrayService } from './services/TrayService'
|
||||
import { versionService } from './services/VersionService'
|
||||
import { windowService } from './services/WindowService'
|
||||
import { initWebviewHotkeys } from './services/WebviewService'
|
||||
import { runAsyncFunction } from './utils'
|
||||
|
||||
const logger = loggerService.withContext('MainEntry')
|
||||
|
||||
@@ -171,33 +170,39 @@ if (!app.requestSingleInstanceLock()) {
|
||||
//start selection assistant service
|
||||
initSelectionService()
|
||||
|
||||
runAsyncFunction(async () => {
|
||||
// Start API server if enabled or if agents exist
|
||||
try {
|
||||
const config = await apiServerService.getCurrentConfig()
|
||||
logger.info('API server config:', config)
|
||||
// Initialize Agent Service
|
||||
try {
|
||||
await agentService.initialize()
|
||||
logger.info('Agent service initialized successfully')
|
||||
} catch (error: any) {
|
||||
logger.error('Failed to initialize Agent service:', error)
|
||||
}
|
||||
|
||||
// Check if there are any agents
|
||||
let shouldStart = config.enabled
|
||||
if (!shouldStart) {
|
||||
try {
|
||||
const { total } = await agentService.listAgents({ limit: 1 })
|
||||
if (total > 0) {
|
||||
shouldStart = true
|
||||
logger.info(`Detected ${total} agent(s), auto-starting API server`)
|
||||
}
|
||||
} catch (error: any) {
|
||||
logger.warn('Failed to check agent count:', error)
|
||||
// Start API server if enabled or if agents exist
|
||||
try {
|
||||
const config = await apiServerService.getCurrentConfig()
|
||||
logger.info('API server config:', config)
|
||||
|
||||
// Check if there are any agents
|
||||
let shouldStart = config.enabled
|
||||
if (!shouldStart) {
|
||||
try {
|
||||
const { total } = await agentService.listAgents({ limit: 1 })
|
||||
if (total > 0) {
|
||||
shouldStart = true
|
||||
logger.info(`Detected ${total} agent(s), auto-starting API server`)
|
||||
}
|
||||
} catch (error: any) {
|
||||
logger.warn('Failed to check agent count:', error)
|
||||
}
|
||||
|
||||
if (shouldStart) {
|
||||
await apiServerService.start()
|
||||
}
|
||||
} catch (error: any) {
|
||||
logger.error('Failed to check/start API server:', error)
|
||||
}
|
||||
})
|
||||
|
||||
if (shouldStart) {
|
||||
await apiServerService.start()
|
||||
}
|
||||
} catch (error: any) {
|
||||
logger.error('Failed to check/start API server:', error)
|
||||
}
|
||||
})
|
||||
|
||||
registerProtocolClient(app)
|
||||
|
||||
@@ -493,44 +493,6 @@ export function registerIpc(mainWindow: BrowserWindow, app: Electron.App) {
|
||||
ipcMain.handle(IpcChannel.System_GetDeviceType, () => (isMac ? 'mac' : isWin ? 'windows' : 'linux'))
|
||||
ipcMain.handle(IpcChannel.System_GetHostname, () => require('os').hostname())
|
||||
ipcMain.handle(IpcChannel.System_GetCpuName, () => require('os').cpus()[0].model)
|
||||
ipcMain.handle(IpcChannel.System_CheckGitBash, () => {
|
||||
if (!isWin) {
|
||||
return true // Non-Windows systems don't need Git Bash
|
||||
}
|
||||
|
||||
try {
|
||||
// Check common Git Bash installation paths
|
||||
const commonPaths = [
|
||||
path.join(process.env.ProgramFiles || 'C:\\Program Files', 'Git', 'bin', 'bash.exe'),
|
||||
path.join(process.env['ProgramFiles(x86)'] || 'C:\\Program Files (x86)', 'Git', 'bin', 'bash.exe'),
|
||||
path.join(process.env.LOCALAPPDATA || '', 'Programs', 'Git', 'bin', 'bash.exe')
|
||||
]
|
||||
|
||||
// Check if any of the common paths exist
|
||||
for (const bashPath of commonPaths) {
|
||||
if (fs.existsSync(bashPath)) {
|
||||
logger.debug('Git Bash found', { path: bashPath })
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// Check if git is in PATH
|
||||
const { execSync } = require('child_process')
|
||||
try {
|
||||
execSync('git --version', { stdio: 'ignore' })
|
||||
logger.debug('Git found in PATH')
|
||||
return true
|
||||
} catch {
|
||||
// Git not in PATH
|
||||
}
|
||||
|
||||
logger.debug('Git Bash not found on Windows system')
|
||||
return false
|
||||
} catch (error) {
|
||||
logger.error('Error checking Git Bash', error as Error)
|
||||
return false
|
||||
}
|
||||
})
|
||||
ipcMain.handle(IpcChannel.System_ToggleDevTools, (e) => {
|
||||
const win = BrowserWindow.fromWebContents(e.sender)
|
||||
win && win.webContents.toggleDevTools()
|
||||
@@ -595,6 +557,9 @@ export function registerIpc(mainWindow: BrowserWindow, app: Electron.App) {
|
||||
ipcMain.handle(IpcChannel.File_ValidateNotesDirectory, fileManager.validateNotesDirectory.bind(fileManager))
|
||||
ipcMain.handle(IpcChannel.File_StartWatcher, fileManager.startFileWatcher.bind(fileManager))
|
||||
ipcMain.handle(IpcChannel.File_StopWatcher, fileManager.stopFileWatcher.bind(fileManager))
|
||||
ipcMain.handle(IpcChannel.File_PauseWatcher, fileManager.pauseFileWatcher.bind(fileManager))
|
||||
ipcMain.handle(IpcChannel.File_ResumeWatcher, fileManager.resumeFileWatcher.bind(fileManager))
|
||||
ipcMain.handle(IpcChannel.File_BatchUploadMarkdown, fileManager.batchUploadMarkdownFiles.bind(fileManager))
|
||||
ipcMain.handle(IpcChannel.File_ShowInFolder, fileManager.showInFolder.bind(fileManager))
|
||||
|
||||
// file service
|
||||
|
||||
@@ -8,7 +8,6 @@ import DiDiMcpServer from './didi-mcp'
|
||||
import DifyKnowledgeServer from './dify-knowledge'
|
||||
import FetchServer from './fetch'
|
||||
import FileSystemServer from './filesystem'
|
||||
import MCPUIDemoServer from './mcp-ui-demo'
|
||||
import MemoryServer from './memory'
|
||||
import PythonServer from './python'
|
||||
import ThinkingServer from './sequentialthinking'
|
||||
@@ -49,9 +48,6 @@ export function createInMemoryMCPServer(
|
||||
const apiKey = envs.DIDI_API_KEY
|
||||
return new DiDiMcpServer(apiKey).server
|
||||
}
|
||||
case BuiltinMCPServerNames.mcpUIDemo: {
|
||||
return new MCPUIDemoServer().server
|
||||
}
|
||||
default:
|
||||
throw new Error(`Unknown in-memory MCP server: ${name}`)
|
||||
}
|
||||
|
||||
@@ -1,433 +0,0 @@
|
||||
import { Server } from '@modelcontextprotocol/sdk/server/index.js'
|
||||
import { CallToolRequestSchema, ListToolsRequestSchema } from '@modelcontextprotocol/sdk/types.js'
|
||||
|
||||
const server = new Server(
|
||||
{
|
||||
name: 'mcp-ui-demo',
|
||||
version: '1.0.0'
|
||||
},
|
||||
{
|
||||
capabilities: {
|
||||
tools: {}
|
||||
}
|
||||
}
|
||||
)
|
||||
|
||||
// HTML templates for different UIs
|
||||
const getHelloWorldUI = () =>
|
||||
`
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<style>
|
||||
body {
|
||||
font-family: system-ui, -apple-system, sans-serif;
|
||||
padding: 20px;
|
||||
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
|
||||
color: white;
|
||||
min-height: 200px;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
margin: 0;
|
||||
}
|
||||
.container {
|
||||
text-align: center;
|
||||
}
|
||||
h1 {
|
||||
font-size: 2.5em;
|
||||
margin-bottom: 10px;
|
||||
text-shadow: 2px 2px 4px rgba(0,0,0,0.3);
|
||||
}
|
||||
p {
|
||||
font-size: 1.2em;
|
||||
opacity: 0.9;
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="container">
|
||||
<h1>🎉 Hello from MCP UI!</h1>
|
||||
<p>This is a simple MCP UI Resource rendered in Cherry Studio</p>
|
||||
</div>
|
||||
</body>
|
||||
</html>
|
||||
`.trim()
|
||||
|
||||
const getInteractiveUI = () =>
|
||||
`
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<style>
|
||||
body {
|
||||
font-family: system-ui, -apple-system, sans-serif;
|
||||
padding: 20px;
|
||||
background: #f5f5f5;
|
||||
margin: 0;
|
||||
}
|
||||
.container {
|
||||
max-width: 600px;
|
||||
margin: 0 auto;
|
||||
background: white;
|
||||
padding: 30px;
|
||||
border-radius: 12px;
|
||||
box-shadow: 0 2px 10px rgba(0,0,0,0.1);
|
||||
}
|
||||
h2 {
|
||||
color: #333;
|
||||
margin-bottom: 20px;
|
||||
}
|
||||
button {
|
||||
background: #667eea;
|
||||
color: white;
|
||||
border: none;
|
||||
padding: 12px 24px;
|
||||
border-radius: 6px;
|
||||
cursor: pointer;
|
||||
font-size: 16px;
|
||||
margin: 5px;
|
||||
transition: background 0.2s;
|
||||
}
|
||||
button:hover {
|
||||
background: #5568d3;
|
||||
}
|
||||
#output {
|
||||
margin-top: 20px;
|
||||
padding: 15px;
|
||||
background: #f8f9fa;
|
||||
border-radius: 6px;
|
||||
border: 1px solid #e0e0e0;
|
||||
min-height: 50px;
|
||||
}
|
||||
.info {
|
||||
color: #666;
|
||||
font-size: 14px;
|
||||
margin-top: 15px;
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="container">
|
||||
<h2>Interactive MCP UI Demo</h2>
|
||||
<p>Click the buttons to interact with MCP tools:</p>
|
||||
|
||||
<button onclick="callEchoTool()">Call Echo Tool</button>
|
||||
<button onclick="getTimestamp()">Get Timestamp</button>
|
||||
<button onclick="openLink()">Open External Link</button>
|
||||
|
||||
<div id="output"></div>
|
||||
|
||||
<div class="info">
|
||||
This UI can communicate with the host application through postMessage API.
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<script>
|
||||
function callEchoTool() {
|
||||
const output = document.getElementById('output');
|
||||
output.innerHTML = '<p style="color: #0066cc;">Calling echo tool...</p>';
|
||||
|
||||
window.parent.postMessage({
|
||||
type: 'tool',
|
||||
payload: {
|
||||
toolName: 'demo_echo',
|
||||
params: {
|
||||
message: 'Hello from MCP UI! Time: ' + new Date().toLocaleTimeString()
|
||||
}
|
||||
}
|
||||
}, '*');
|
||||
}
|
||||
|
||||
function getTimestamp() {
|
||||
const output = document.getElementById('output');
|
||||
const now = new Date();
|
||||
output.innerHTML = \`
|
||||
<p style="color: #00aa00;">
|
||||
<strong>Current Timestamp:</strong><br/>
|
||||
\${now.toLocaleString()}<br/>
|
||||
Unix: \${Math.floor(now.getTime() / 1000)}
|
||||
</p>
|
||||
\`;
|
||||
}
|
||||
|
||||
function openLink() {
|
||||
window.parent.postMessage({
|
||||
type: 'link',
|
||||
payload: {
|
||||
url: 'https://github.com/idosal/mcp-ui'
|
||||
}
|
||||
}, '*');
|
||||
}
|
||||
|
||||
// Listen for responses
|
||||
window.addEventListener('message', (event) => {
|
||||
if (event.data.type === 'ui-message-response') {
|
||||
const output = document.getElementById('output');
|
||||
const { response, error } = event.data.payload;
|
||||
|
||||
if (error) {
|
||||
output.innerHTML = \`<p style="color: #cc0000;">Error: \${error}</p>\`;
|
||||
} else {
|
||||
output.innerHTML = \`<p style="color: #00aa00;">Response: \${JSON.stringify(response, null, 2)}</p>\`;
|
||||
}
|
||||
}
|
||||
});
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
||||
`.trim()
|
||||
|
||||
const getFormUI = () =>
|
||||
`
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<style>
|
||||
body {
|
||||
font-family: system-ui, -apple-system, sans-serif;
|
||||
padding: 20px;
|
||||
background: #f5f5f5;
|
||||
margin: 0;
|
||||
}
|
||||
.container {
|
||||
max-width: 500px;
|
||||
margin: 0 auto;
|
||||
background: white;
|
||||
padding: 30px;
|
||||
border-radius: 12px;
|
||||
box-shadow: 0 2px 10px rgba(0,0,0,0.1);
|
||||
}
|
||||
h2 {
|
||||
color: #333;
|
||||
margin-bottom: 20px;
|
||||
}
|
||||
.form-group {
|
||||
margin-bottom: 15px;
|
||||
}
|
||||
label {
|
||||
display: block;
|
||||
margin-bottom: 5px;
|
||||
color: #555;
|
||||
font-weight: 500;
|
||||
}
|
||||
input, textarea {
|
||||
width: 100%;
|
||||
padding: 10px;
|
||||
border: 1px solid #ddd;
|
||||
border-radius: 6px;
|
||||
font-size: 14px;
|
||||
box-sizing: border-box;
|
||||
}
|
||||
textarea {
|
||||
min-height: 100px;
|
||||
resize: vertical;
|
||||
}
|
||||
button {
|
||||
background: #667eea;
|
||||
color: white;
|
||||
border: none;
|
||||
padding: 12px 24px;
|
||||
border-radius: 6px;
|
||||
cursor: pointer;
|
||||
font-size: 16px;
|
||||
width: 100%;
|
||||
margin-top: 10px;
|
||||
}
|
||||
button:hover {
|
||||
background: #5568d3;
|
||||
}
|
||||
#result {
|
||||
margin-top: 20px;
|
||||
padding: 15px;
|
||||
background: #f8f9fa;
|
||||
border-radius: 6px;
|
||||
display: none;
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="container">
|
||||
<h2>📝 Form UI Demo</h2>
|
||||
<form id="demoForm" onsubmit="handleSubmit(event)">
|
||||
<div class="form-group">
|
||||
<label for="name">Name:</label>
|
||||
<input type="text" id="name" name="name" required placeholder="Enter your name">
|
||||
</div>
|
||||
|
||||
<div class="form-group">
|
||||
<label for="email">Email:</label>
|
||||
<input type="email" id="email" name="email" required placeholder="your@email.com">
|
||||
</div>
|
||||
|
||||
<div class="form-group">
|
||||
<label for="message">Message:</label>
|
||||
<textarea id="message" name="message" required placeholder="Enter your message here..."></textarea>
|
||||
</div>
|
||||
|
||||
<button type="submit">Submit Form</button>
|
||||
</form>
|
||||
|
||||
<div id="result"></div>
|
||||
</div>
|
||||
|
||||
<script>
|
||||
function handleSubmit(event) {
|
||||
event.preventDefault();
|
||||
|
||||
const formData = new FormData(event.target);
|
||||
const data = Object.fromEntries(formData.entries());
|
||||
|
||||
const result = document.getElementById('result');
|
||||
result.style.display = 'block';
|
||||
result.innerHTML = '<p style="color: #0066cc;">Submitting form...</p>';
|
||||
|
||||
// Send form data to host
|
||||
window.parent.postMessage({
|
||||
type: 'notify',
|
||||
payload: {
|
||||
message: 'Form submitted with data: ' + JSON.stringify(data)
|
||||
}
|
||||
}, '*');
|
||||
|
||||
// Display result
|
||||
result.innerHTML = \`
|
||||
<p style="color: #00aa00;"><strong>Form Submitted!</strong></p>
|
||||
<pre style="background: white; padding: 10px; border-radius: 4px; overflow-x: auto;">\${JSON.stringify(data, null, 2)}</pre>
|
||||
\`;
|
||||
}
|
||||
</script>
|
||||
</body>
|
||||
</html>
|
||||
`.trim()
|
||||
|
||||
// List available tools
|
||||
server.setRequestHandler(ListToolsRequestSchema, async () => {
|
||||
return {
|
||||
tools: [
|
||||
{
|
||||
name: 'demo_echo',
|
||||
description: 'Echo back the message sent from UI',
|
||||
inputSchema: {
|
||||
type: 'object',
|
||||
properties: {
|
||||
message: {
|
||||
type: 'string',
|
||||
description: 'Message to echo back'
|
||||
}
|
||||
},
|
||||
required: ['message']
|
||||
}
|
||||
},
|
||||
{
|
||||
name: 'show_hello_ui',
|
||||
description: 'Display a simple hello world UI with gradient background',
|
||||
inputSchema: {
|
||||
type: 'object',
|
||||
properties: {}
|
||||
}
|
||||
},
|
||||
{
|
||||
name: 'show_interactive_ui',
|
||||
description:
|
||||
'Display an interactive UI demo with buttons for calling tools, getting timestamps, and opening links',
|
||||
inputSchema: {
|
||||
type: 'object',
|
||||
properties: {}
|
||||
}
|
||||
},
|
||||
{
|
||||
name: 'show_form_ui',
|
||||
description: 'Display a form UI demo with input fields for name, email, and message',
|
||||
inputSchema: {
|
||||
type: 'object',
|
||||
properties: {}
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
})
|
||||
|
||||
// Handle tool calls
|
||||
server.setRequestHandler(CallToolRequestSchema, async (request) => {
|
||||
const { name, arguments: args } = request.params
|
||||
|
||||
if (name === 'demo_echo') {
|
||||
return {
|
||||
content: [
|
||||
{
|
||||
type: 'text',
|
||||
text: JSON.stringify({
|
||||
echo: args?.message || 'No message provided',
|
||||
timestamp: new Date().toISOString()
|
||||
})
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
if (name === 'show_hello_ui') {
|
||||
return {
|
||||
content: [
|
||||
{
|
||||
type: 'text',
|
||||
text: JSON.stringify({
|
||||
type: 'resource',
|
||||
resource: {
|
||||
uri: 'ui://demo/hello',
|
||||
mimeType: 'text/html',
|
||||
text: getHelloWorldUI()
|
||||
}
|
||||
})
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
if (name === 'show_interactive_ui') {
|
||||
return {
|
||||
content: [
|
||||
{
|
||||
type: 'text',
|
||||
text: JSON.stringify({
|
||||
type: 'resource',
|
||||
resource: {
|
||||
uri: 'ui://demo/interactive',
|
||||
mimeType: 'text/html',
|
||||
text: getInteractiveUI()
|
||||
}
|
||||
})
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
if (name === 'show_form_ui') {
|
||||
return {
|
||||
content: [
|
||||
{
|
||||
type: 'text',
|
||||
text: JSON.stringify({
|
||||
type: 'resource',
|
||||
resource: {
|
||||
uri: 'ui://demo/form',
|
||||
mimeType: 'text/html',
|
||||
text: getFormUI()
|
||||
}
|
||||
})
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
throw new Error(`Unknown tool: ${name}`)
|
||||
})
|
||||
|
||||
class MCPUIDemoServer {
|
||||
public server: Server
|
||||
constructor() {
|
||||
this.server = server
|
||||
}
|
||||
}
|
||||
|
||||
export default MCPUIDemoServer
|
||||
@@ -1605,6 +1605,164 @@ class FileStorage {
|
||||
logger.error('Failed to show item in folder:', error as Error)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Batch upload markdown files from native File objects
|
||||
* This handles all I/O operations in the Main process to avoid blocking Renderer
|
||||
*/
|
||||
public batchUploadMarkdownFiles = async (
|
||||
_: Electron.IpcMainInvokeEvent,
|
||||
filePaths: string[],
|
||||
targetPath: string
|
||||
): Promise<{
|
||||
fileCount: number
|
||||
folderCount: number
|
||||
skippedFiles: number
|
||||
}> => {
|
||||
try {
|
||||
logger.info('Starting batch upload', { fileCount: filePaths.length, targetPath })
|
||||
|
||||
const basePath = path.resolve(targetPath)
|
||||
const MARKDOWN_EXTS = ['.md', '.markdown']
|
||||
|
||||
// Filter markdown files
|
||||
const markdownFiles = filePaths.filter((filePath) => {
|
||||
const ext = path.extname(filePath).toLowerCase()
|
||||
return MARKDOWN_EXTS.includes(ext)
|
||||
})
|
||||
|
||||
const skippedFiles = filePaths.length - markdownFiles.length
|
||||
|
||||
if (markdownFiles.length === 0) {
|
||||
return { fileCount: 0, folderCount: 0, skippedFiles }
|
||||
}
|
||||
|
||||
// Collect unique folders needed
|
||||
const foldersSet = new Set<string>()
|
||||
const fileOperations: Array<{ sourcePath: string; targetPath: string }> = []
|
||||
|
||||
for (const filePath of markdownFiles) {
|
||||
try {
|
||||
// Get relative path if file is from a directory upload
|
||||
const fileName = path.basename(filePath)
|
||||
const relativePath = path.dirname(filePath)
|
||||
|
||||
// Determine target directory structure
|
||||
let targetDir = basePath
|
||||
const folderParts: string[] = []
|
||||
|
||||
// Extract folder structure from file path for nested uploads
|
||||
// This is a simplified version - in real scenario we'd need the original directory structure
|
||||
if (relativePath && relativePath !== '.') {
|
||||
const parts = relativePath.split(path.sep)
|
||||
// Get the last few parts that represent the folder structure within upload
|
||||
const relevantParts = parts.slice(Math.max(0, parts.length - 3))
|
||||
folderParts.push(...relevantParts)
|
||||
}
|
||||
|
||||
// Build target directory path
|
||||
for (const part of folderParts) {
|
||||
targetDir = path.join(targetDir, part)
|
||||
foldersSet.add(targetDir)
|
||||
}
|
||||
|
||||
// Determine final file name
|
||||
const nameWithoutExt = fileName.endsWith('.md')
|
||||
? fileName.slice(0, -3)
|
||||
: fileName.endsWith('.markdown')
|
||||
? fileName.slice(0, -9)
|
||||
: fileName
|
||||
|
||||
const { safeName } = await this.fileNameGuard(_, targetDir, nameWithoutExt, true)
|
||||
const finalPath = path.join(targetDir, safeName + '.md')
|
||||
|
||||
fileOperations.push({ sourcePath: filePath, targetPath: finalPath })
|
||||
} catch (error) {
|
||||
logger.error('Failed to prepare file operation:', error as Error, { filePath })
|
||||
}
|
||||
}
|
||||
|
||||
// Create folders in order (shallow to deep)
|
||||
const sortedFolders = Array.from(foldersSet).sort((a, b) => a.length - b.length)
|
||||
for (const folder of sortedFolders) {
|
||||
try {
|
||||
if (!fs.existsSync(folder)) {
|
||||
await fs.promises.mkdir(folder, { recursive: true })
|
||||
}
|
||||
} catch (error) {
|
||||
logger.debug('Folder already exists or creation failed', { folder, error: (error as Error).message })
|
||||
}
|
||||
}
|
||||
|
||||
// Process files in batches
|
||||
const BATCH_SIZE = 10 // Higher batch size since we're in Main process
|
||||
let successCount = 0
|
||||
|
||||
for (let i = 0; i < fileOperations.length; i += BATCH_SIZE) {
|
||||
const batch = fileOperations.slice(i, i + BATCH_SIZE)
|
||||
|
||||
const results = await Promise.allSettled(
|
||||
batch.map(async (op) => {
|
||||
// Read from source and write to target in Main process
|
||||
const content = await fs.promises.readFile(op.sourcePath, 'utf-8')
|
||||
await fs.promises.writeFile(op.targetPath, content, 'utf-8')
|
||||
return true
|
||||
})
|
||||
)
|
||||
|
||||
results.forEach((result, index) => {
|
||||
if (result.status === 'fulfilled') {
|
||||
successCount++
|
||||
} else {
|
||||
logger.error('Failed to upload file:', result.reason, {
|
||||
file: batch[index].sourcePath
|
||||
})
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
logger.info('Batch upload completed', {
|
||||
successCount,
|
||||
folderCount: foldersSet.size,
|
||||
skippedFiles
|
||||
})
|
||||
|
||||
return {
|
||||
fileCount: successCount,
|
||||
folderCount: foldersSet.size,
|
||||
skippedFiles
|
||||
}
|
||||
} catch (error) {
|
||||
logger.error('Batch upload failed:', error as Error)
|
||||
throw error
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Pause file watcher to prevent events during batch operations
|
||||
*/
|
||||
public pauseFileWatcher = async (): Promise<void> => {
|
||||
if (this.watcher) {
|
||||
logger.debug('Pausing file watcher')
|
||||
// Chokidar doesn't have pause, so we temporarily set a flag
|
||||
// We'll handle this by clearing the debounce timer
|
||||
if (this.debounceTimer) {
|
||||
clearTimeout(this.debounceTimer)
|
||||
this.debounceTimer = undefined
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Resume file watcher and trigger a refresh
|
||||
*/
|
||||
public resumeFileWatcher = async (): Promise<void> => {
|
||||
if (this.watcher && this.currentWatchPath) {
|
||||
logger.debug('Resuming file watcher')
|
||||
// Send a synthetic refresh event to trigger tree reload
|
||||
this.notifyChange('refresh', this.currentWatchPath)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
export const fileStorage = new FileStorage()
|
||||
|
||||
@@ -1,13 +1,17 @@
|
||||
import { type Client, createClient } from '@libsql/client'
|
||||
import { loggerService } from '@logger'
|
||||
import { mcpApiService } from '@main/apiServer/services/mcp'
|
||||
import type { ModelValidationError } from '@main/apiServer/utils'
|
||||
import { validateModelId } from '@main/apiServer/utils'
|
||||
import type { AgentType, MCPTool, SlashCommand, Tool } from '@types'
|
||||
import { objectKeys } from '@types'
|
||||
import { drizzle, type LibSQLDatabase } from 'drizzle-orm/libsql'
|
||||
import fs from 'fs'
|
||||
import path from 'path'
|
||||
|
||||
import { DatabaseManager } from './database/DatabaseManager'
|
||||
import { MigrationService } from './database/MigrationService'
|
||||
import * as schema from './database/schema'
|
||||
import { dbPath } from './drizzle.config'
|
||||
import type { AgentModelField } from './errors'
|
||||
import { AgentModelValidationError } from './errors'
|
||||
import { builtinSlashCommands } from './services/claudecode/commands'
|
||||
@@ -16,16 +20,22 @@ import { builtinTools } from './services/claudecode/tools'
|
||||
const logger = loggerService.withContext('BaseService')
|
||||
|
||||
/**
|
||||
* Base service class providing shared utilities for all agent-related services.
|
||||
* Base service class providing shared database connection and utilities
|
||||
* for all agent-related services.
|
||||
*
|
||||
* Features:
|
||||
* - Database access through DatabaseManager singleton
|
||||
* - JSON field serialization/deserialization
|
||||
* - Path validation and creation
|
||||
* - Model validation
|
||||
* - MCP tools and slash commands listing
|
||||
* - Programmatic schema management (no CLI dependencies)
|
||||
* - Automatic table creation and migration
|
||||
* - Schema version tracking and compatibility checks
|
||||
* - Transaction-based operations for safety
|
||||
* - Development vs production mode handling
|
||||
* - Connection retry logic with exponential backoff
|
||||
*/
|
||||
export abstract class BaseService {
|
||||
protected static client: Client | null = null
|
||||
protected static db: LibSQLDatabase<typeof schema> | null = null
|
||||
protected static isInitialized = false
|
||||
protected static initializationPromise: Promise<void> | null = null
|
||||
protected jsonFields: string[] = [
|
||||
'tools',
|
||||
'mcps',
|
||||
@@ -35,6 +45,23 @@ export abstract class BaseService {
|
||||
'slash_commands'
|
||||
]
|
||||
|
||||
/**
|
||||
* Initialize database with retry logic and proper error handling
|
||||
*/
|
||||
protected static async initialize(): Promise<void> {
|
||||
// Return existing initialization if in progress
|
||||
if (BaseService.initializationPromise) {
|
||||
return BaseService.initializationPromise
|
||||
}
|
||||
|
||||
if (BaseService.isInitialized) {
|
||||
return
|
||||
}
|
||||
|
||||
BaseService.initializationPromise = BaseService.performInitialization()
|
||||
return BaseService.initializationPromise
|
||||
}
|
||||
|
||||
public async listMcpTools(agentType: AgentType, ids?: string[]): Promise<Tool[]> {
|
||||
const tools: Tool[] = []
|
||||
if (agentType === 'claude-code') {
|
||||
@@ -74,13 +101,78 @@ export abstract class BaseService {
|
||||
return []
|
||||
}
|
||||
|
||||
/**
|
||||
* Get database instance
|
||||
* Automatically waits for initialization to complete
|
||||
*/
|
||||
protected async getDatabase() {
|
||||
const dbManager = await DatabaseManager.getInstance()
|
||||
return dbManager.getDatabase()
|
||||
private static async performInitialization(): Promise<void> {
|
||||
const maxRetries = 3
|
||||
let lastError: Error
|
||||
|
||||
for (let attempt = 1; attempt <= maxRetries; attempt++) {
|
||||
try {
|
||||
logger.info(`Initializing Agent database at: ${dbPath} (attempt ${attempt}/${maxRetries})`)
|
||||
|
||||
// Ensure the database directory exists
|
||||
const dbDir = path.dirname(dbPath)
|
||||
if (!fs.existsSync(dbDir)) {
|
||||
logger.info(`Creating database directory: ${dbDir}`)
|
||||
fs.mkdirSync(dbDir, { recursive: true })
|
||||
}
|
||||
|
||||
BaseService.client = createClient({
|
||||
url: `file:${dbPath}`
|
||||
})
|
||||
|
||||
BaseService.db = drizzle(BaseService.client, { schema })
|
||||
|
||||
// Run database migrations
|
||||
const migrationService = new MigrationService(BaseService.db, BaseService.client)
|
||||
await migrationService.runMigrations()
|
||||
|
||||
BaseService.isInitialized = true
|
||||
logger.info('Agent database initialized successfully')
|
||||
return
|
||||
} catch (error) {
|
||||
lastError = error as Error
|
||||
logger.warn(`Database initialization attempt ${attempt} failed:`, lastError)
|
||||
|
||||
// Clean up on failure
|
||||
if (BaseService.client) {
|
||||
try {
|
||||
BaseService.client.close()
|
||||
} catch (closeError) {
|
||||
logger.warn('Failed to close client during cleanup:', closeError as Error)
|
||||
}
|
||||
}
|
||||
BaseService.client = null
|
||||
BaseService.db = null
|
||||
|
||||
// Wait before retrying (exponential backoff)
|
||||
if (attempt < maxRetries) {
|
||||
const delay = Math.pow(2, attempt) * 1000 // 2s, 4s, 8s
|
||||
logger.info(`Retrying in ${delay}ms...`)
|
||||
await new Promise((resolve) => setTimeout(resolve, delay))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// All retries failed
|
||||
BaseService.initializationPromise = null
|
||||
logger.error('Failed to initialize Agent database after all retries:', lastError!)
|
||||
throw lastError!
|
||||
}
|
||||
|
||||
protected ensureInitialized(): void {
|
||||
if (!BaseService.isInitialized || !BaseService.db || !BaseService.client) {
|
||||
throw new Error('Database not initialized. Call initialize() first.')
|
||||
}
|
||||
}
|
||||
|
||||
protected get database(): LibSQLDatabase<typeof schema> {
|
||||
this.ensureInitialized()
|
||||
return BaseService.db!
|
||||
}
|
||||
|
||||
protected get rawClient(): Client {
|
||||
this.ensureInitialized()
|
||||
return BaseService.client!
|
||||
}
|
||||
|
||||
protected serializeJsonFields(data: any): any {
|
||||
@@ -192,7 +284,7 @@ export abstract class BaseService {
|
||||
}
|
||||
|
||||
/**
|
||||
* Validate agent model configuration
|
||||
* Force re-initialization (for development/testing)
|
||||
*/
|
||||
protected async validateAgentModels(
|
||||
agentType: AgentType,
|
||||
@@ -233,4 +325,22 @@ export abstract class BaseService {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
static async reinitialize(): Promise<void> {
|
||||
BaseService.isInitialized = false
|
||||
BaseService.initializationPromise = null
|
||||
|
||||
if (BaseService.client) {
|
||||
try {
|
||||
BaseService.client.close()
|
||||
} catch (error) {
|
||||
logger.warn('Failed to close client during reinitialize:', error as Error)
|
||||
}
|
||||
}
|
||||
|
||||
BaseService.client = null
|
||||
BaseService.db = null
|
||||
|
||||
await BaseService.initialize()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,156 +0,0 @@
|
||||
import { type Client, createClient } from '@libsql/client'
|
||||
import { loggerService } from '@logger'
|
||||
import type { LibSQLDatabase } from 'drizzle-orm/libsql'
|
||||
import { drizzle } from 'drizzle-orm/libsql'
|
||||
import fs from 'fs'
|
||||
import path from 'path'
|
||||
|
||||
import { dbPath } from '../drizzle.config'
|
||||
import { MigrationService } from './MigrationService'
|
||||
import * as schema from './schema'
|
||||
|
||||
const logger = loggerService.withContext('DatabaseManager')
|
||||
|
||||
/**
|
||||
* Database initialization state
|
||||
*/
|
||||
enum InitState {
|
||||
INITIALIZING = 'initializing',
|
||||
INITIALIZED = 'initialized',
|
||||
FAILED = 'failed'
|
||||
}
|
||||
|
||||
/**
|
||||
* DatabaseManager - Singleton class for managing libsql database connections
|
||||
*
|
||||
* Responsibilities:
|
||||
* - Single source of truth for database connection
|
||||
* - Thread-safe initialization with state management
|
||||
* - Automatic migration handling
|
||||
* - Safe connection cleanup
|
||||
* - Error recovery and retry logic
|
||||
* - Windows platform compatibility fixes
|
||||
*/
|
||||
export class DatabaseManager {
|
||||
private static instance: DatabaseManager | null = null
|
||||
|
||||
private client: Client | null = null
|
||||
private db: LibSQLDatabase<typeof schema> | null = null
|
||||
private state: InitState = InitState.INITIALIZING
|
||||
|
||||
/**
|
||||
* Get the singleton instance (database initialization starts automatically)
|
||||
*/
|
||||
public static async getInstance(): Promise<DatabaseManager> {
|
||||
if (DatabaseManager.instance) {
|
||||
return DatabaseManager.instance
|
||||
}
|
||||
|
||||
const instance = new DatabaseManager()
|
||||
await instance.initialize()
|
||||
DatabaseManager.instance = instance
|
||||
|
||||
return instance
|
||||
}
|
||||
|
||||
/**
|
||||
* Perform the actual initialization
|
||||
*/
|
||||
public async initialize(): Promise<void> {
|
||||
if (this.state === InitState.INITIALIZED) {
|
||||
return
|
||||
}
|
||||
|
||||
try {
|
||||
logger.info(`Initializing database at: ${dbPath}`)
|
||||
|
||||
// Ensure database directory exists
|
||||
const dbDir = path.dirname(dbPath)
|
||||
if (!fs.existsSync(dbDir)) {
|
||||
logger.info(`Creating database directory: ${dbDir}`)
|
||||
fs.mkdirSync(dbDir, { recursive: true })
|
||||
}
|
||||
|
||||
// Check if database file is corrupted (Windows specific check)
|
||||
if (fs.existsSync(dbPath)) {
|
||||
const stats = fs.statSync(dbPath)
|
||||
if (stats.size === 0) {
|
||||
logger.warn('Database file is empty, removing corrupted file')
|
||||
fs.unlinkSync(dbPath)
|
||||
}
|
||||
}
|
||||
|
||||
// Create client with platform-specific options
|
||||
this.client = createClient({
|
||||
url: `file:${dbPath}`,
|
||||
// intMode: 'number' helps avoid some Windows compatibility issues
|
||||
intMode: 'number'
|
||||
})
|
||||
|
||||
// Create drizzle instance
|
||||
this.db = drizzle(this.client, { schema })
|
||||
|
||||
// Run migrations
|
||||
const migrationService = new MigrationService(this.db, this.client)
|
||||
await migrationService.runMigrations()
|
||||
|
||||
this.state = InitState.INITIALIZED
|
||||
logger.info('Database initialized successfully')
|
||||
} catch (error) {
|
||||
const err = error as Error
|
||||
logger.error('Database initialization failed:', {
|
||||
error: err.message,
|
||||
stack: err.stack
|
||||
})
|
||||
|
||||
// Clean up failed initialization
|
||||
this.cleanupFailedInit()
|
||||
|
||||
// Set failed state
|
||||
this.state = InitState.FAILED
|
||||
throw new Error(`Database initialization failed: ${err.message || 'Unknown error'}`)
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Clean up after failed initialization
|
||||
*/
|
||||
private cleanupFailedInit(): void {
|
||||
if (this.client) {
|
||||
try {
|
||||
// On Windows, closing a partially initialized client can crash
|
||||
// Wrap in try-catch and ignore errors during cleanup
|
||||
this.client.close()
|
||||
} catch (error) {
|
||||
logger.warn('Failed to close client during cleanup:', error as Error)
|
||||
}
|
||||
}
|
||||
this.client = null
|
||||
this.db = null
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the database instance
|
||||
* Automatically waits for initialization to complete
|
||||
* @throws Error if database initialization failed
|
||||
*/
|
||||
public getDatabase(): LibSQLDatabase<typeof schema> {
|
||||
return this.db!
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the raw client (for advanced operations)
|
||||
* Automatically waits for initialization to complete
|
||||
* @throws Error if database initialization failed
|
||||
*/
|
||||
public async getClient(): Promise<Client> {
|
||||
return this.client!
|
||||
}
|
||||
|
||||
/**
|
||||
* Check if database is initialized
|
||||
*/
|
||||
public isInitialized(): boolean {
|
||||
return this.state === InitState.INITIALIZED
|
||||
}
|
||||
}
|
||||
@@ -7,14 +7,8 @@
|
||||
* Schema evolution is handled by Drizzle Kit migrations.
|
||||
*/
|
||||
|
||||
// Database Manager (Singleton)
|
||||
export * from './DatabaseManager'
|
||||
|
||||
// Drizzle ORM schemas
|
||||
export * from './schema'
|
||||
|
||||
// Repository helpers
|
||||
export * from './sessionMessageRepository'
|
||||
|
||||
// Migration Service
|
||||
export * from './MigrationService'
|
||||
|
||||
@@ -15,16 +15,26 @@ import { sessionMessagesTable } from './schema'
|
||||
|
||||
const logger = loggerService.withContext('AgentMessageRepository')
|
||||
|
||||
type TxClient = any
|
||||
|
||||
export type PersistUserMessageParams = AgentMessageUserPersistPayload & {
|
||||
sessionId: string
|
||||
agentSessionId?: string
|
||||
tx?: TxClient
|
||||
}
|
||||
|
||||
export type PersistAssistantMessageParams = AgentMessageAssistantPersistPayload & {
|
||||
sessionId: string
|
||||
agentSessionId: string
|
||||
tx?: TxClient
|
||||
}
|
||||
|
||||
type PersistExchangeParams = AgentMessagePersistExchangePayload & {
|
||||
tx?: TxClient
|
||||
}
|
||||
|
||||
type PersistExchangeResult = AgentMessagePersistExchangeResult
|
||||
|
||||
class AgentMessageRepository extends BaseService {
|
||||
private static instance: AgentMessageRepository | null = null
|
||||
|
||||
@@ -77,13 +87,17 @@ class AgentMessageRepository extends BaseService {
|
||||
return deserialized
|
||||
}
|
||||
|
||||
private getWriter(tx?: TxClient): TxClient {
|
||||
return tx ?? this.database
|
||||
}
|
||||
|
||||
private async findExistingMessageRow(
|
||||
writer: TxClient,
|
||||
sessionId: string,
|
||||
role: string,
|
||||
messageId: string
|
||||
): Promise<SessionMessageRow | null> {
|
||||
const database = await this.getDatabase()
|
||||
const candidateRows: SessionMessageRow[] = await database
|
||||
const candidateRows: SessionMessageRow[] = await writer
|
||||
.select()
|
||||
.from(sessionMessagesTable)
|
||||
.where(and(eq(sessionMessagesTable.session_id, sessionId), eq(sessionMessagesTable.role, role)))
|
||||
@@ -108,7 +122,10 @@ class AgentMessageRepository extends BaseService {
|
||||
private async upsertMessage(
|
||||
params: PersistUserMessageParams | PersistAssistantMessageParams
|
||||
): Promise<AgentSessionMessageEntity> {
|
||||
const { sessionId, agentSessionId = '', payload, metadata, createdAt } = params
|
||||
await AgentMessageRepository.initialize()
|
||||
this.ensureInitialized()
|
||||
|
||||
const { sessionId, agentSessionId = '', payload, metadata, createdAt, tx } = params
|
||||
|
||||
if (!payload?.message?.role) {
|
||||
throw new Error('Message payload missing role')
|
||||
@@ -118,18 +135,18 @@ class AgentMessageRepository extends BaseService {
|
||||
throw new Error('Message payload missing id')
|
||||
}
|
||||
|
||||
const database = await this.getDatabase()
|
||||
const writer = this.getWriter(tx)
|
||||
const now = createdAt ?? payload.message.createdAt ?? new Date().toISOString()
|
||||
const serializedPayload = this.serializeMessage(payload)
|
||||
const serializedMetadata = this.serializeMetadata(metadata)
|
||||
|
||||
const existingRow = await this.findExistingMessageRow(sessionId, payload.message.role, payload.message.id)
|
||||
const existingRow = await this.findExistingMessageRow(writer, sessionId, payload.message.role, payload.message.id)
|
||||
|
||||
if (existingRow) {
|
||||
const metadataToPersist = serializedMetadata ?? existingRow.metadata ?? undefined
|
||||
const agentSessionToPersist = agentSessionId || existingRow.agent_session_id || ''
|
||||
|
||||
await database
|
||||
await writer
|
||||
.update(sessionMessagesTable)
|
||||
.set({
|
||||
content: serializedPayload,
|
||||
@@ -158,7 +175,7 @@ class AgentMessageRepository extends BaseService {
|
||||
updated_at: now
|
||||
}
|
||||
|
||||
const [saved] = await database.insert(sessionMessagesTable).values(insertData).returning()
|
||||
const [saved] = await writer.insert(sessionMessagesTable).values(insertData).returning()
|
||||
|
||||
return this.deserialize(saved)
|
||||
}
|
||||
@@ -171,38 +188,49 @@ class AgentMessageRepository extends BaseService {
|
||||
return this.upsertMessage(params)
|
||||
}
|
||||
|
||||
async persistExchange(params: AgentMessagePersistExchangePayload): Promise<AgentMessagePersistExchangeResult> {
|
||||
async persistExchange(params: PersistExchangeParams): Promise<PersistExchangeResult> {
|
||||
await AgentMessageRepository.initialize()
|
||||
this.ensureInitialized()
|
||||
|
||||
const { sessionId, agentSessionId, user, assistant } = params
|
||||
|
||||
const exchangeResult: AgentMessagePersistExchangeResult = {}
|
||||
const result = await this.database.transaction(async (tx) => {
|
||||
const exchangeResult: PersistExchangeResult = {}
|
||||
|
||||
if (user?.payload) {
|
||||
exchangeResult.userMessage = await this.persistUserMessage({
|
||||
sessionId,
|
||||
agentSessionId,
|
||||
payload: user.payload,
|
||||
metadata: user.metadata,
|
||||
createdAt: user.createdAt
|
||||
})
|
||||
}
|
||||
if (user?.payload) {
|
||||
exchangeResult.userMessage = await this.persistUserMessage({
|
||||
sessionId,
|
||||
agentSessionId,
|
||||
payload: user.payload,
|
||||
metadata: user.metadata,
|
||||
createdAt: user.createdAt,
|
||||
tx
|
||||
})
|
||||
}
|
||||
|
||||
if (assistant?.payload) {
|
||||
exchangeResult.assistantMessage = await this.persistAssistantMessage({
|
||||
sessionId,
|
||||
agentSessionId,
|
||||
payload: assistant.payload,
|
||||
metadata: assistant.metadata,
|
||||
createdAt: assistant.createdAt
|
||||
})
|
||||
}
|
||||
if (assistant?.payload) {
|
||||
exchangeResult.assistantMessage = await this.persistAssistantMessage({
|
||||
sessionId,
|
||||
agentSessionId,
|
||||
payload: assistant.payload,
|
||||
metadata: assistant.metadata,
|
||||
createdAt: assistant.createdAt,
|
||||
tx
|
||||
})
|
||||
}
|
||||
|
||||
return exchangeResult
|
||||
return exchangeResult
|
||||
})
|
||||
|
||||
return result
|
||||
}
|
||||
|
||||
async getSessionHistory(sessionId: string): Promise<AgentPersistedMessage[]> {
|
||||
await AgentMessageRepository.initialize()
|
||||
this.ensureInitialized()
|
||||
|
||||
try {
|
||||
const database = await this.getDatabase()
|
||||
const rows = await database
|
||||
const rows = await this.database
|
||||
.select()
|
||||
.from(sessionMessagesTable)
|
||||
.where(eq(sessionMessagesTable.session_id, sessionId))
|
||||
|
||||
@@ -32,8 +32,14 @@ export class AgentService extends BaseService {
|
||||
return AgentService.instance
|
||||
}
|
||||
|
||||
async initialize(): Promise<void> {
|
||||
await BaseService.initialize()
|
||||
}
|
||||
|
||||
// Agent Methods
|
||||
async createAgent(req: CreateAgentRequest): Promise<CreateAgentResponse> {
|
||||
this.ensureInitialized()
|
||||
|
||||
const id = `agent_${Date.now()}_${Math.random().toString(36).substring(2, 11)}`
|
||||
const now = new Date().toISOString()
|
||||
|
||||
@@ -69,9 +75,8 @@ export class AgentService extends BaseService {
|
||||
updated_at: now
|
||||
}
|
||||
|
||||
const database = await this.getDatabase()
|
||||
await database.insert(agentsTable).values(insertData)
|
||||
const result = await database.select().from(agentsTable).where(eq(agentsTable.id, id)).limit(1)
|
||||
await this.database.insert(agentsTable).values(insertData)
|
||||
const result = await this.database.select().from(agentsTable).where(eq(agentsTable.id, id)).limit(1)
|
||||
if (!result[0]) {
|
||||
throw new Error('Failed to create agent')
|
||||
}
|
||||
@@ -81,8 +86,9 @@ export class AgentService extends BaseService {
|
||||
}
|
||||
|
||||
async getAgent(id: string): Promise<GetAgentResponse | null> {
|
||||
const database = await this.getDatabase()
|
||||
const result = await database.select().from(agentsTable).where(eq(agentsTable.id, id)).limit(1)
|
||||
this.ensureInitialized()
|
||||
|
||||
const result = await this.database.select().from(agentsTable).where(eq(agentsTable.id, id)).limit(1)
|
||||
|
||||
if (!result[0]) {
|
||||
return null
|
||||
@@ -112,9 +118,9 @@ export class AgentService extends BaseService {
|
||||
}
|
||||
|
||||
async listAgents(options: ListOptions = {}): Promise<{ agents: AgentEntity[]; total: number }> {
|
||||
// Build query with pagination
|
||||
const database = await this.getDatabase()
|
||||
const totalResult = await database.select({ count: count() }).from(agentsTable)
|
||||
this.ensureInitialized() // Build query with pagination
|
||||
|
||||
const totalResult = await this.database.select({ count: count() }).from(agentsTable)
|
||||
|
||||
const sortBy = options.sortBy || 'created_at'
|
||||
const orderBy = options.orderBy || 'desc'
|
||||
@@ -122,7 +128,7 @@ export class AgentService extends BaseService {
|
||||
const sortField = agentsTable[sortBy]
|
||||
const orderFn = orderBy === 'asc' ? asc : desc
|
||||
|
||||
const baseQuery = database.select().from(agentsTable).orderBy(orderFn(sortField))
|
||||
const baseQuery = this.database.select().from(agentsTable).orderBy(orderFn(sortField))
|
||||
|
||||
const result =
|
||||
options.limit !== undefined
|
||||
@@ -145,6 +151,8 @@ export class AgentService extends BaseService {
|
||||
updates: UpdateAgentRequest,
|
||||
options: { replace?: boolean } = {}
|
||||
): Promise<UpdateAgentResponse | null> {
|
||||
this.ensureInitialized()
|
||||
|
||||
// Check if agent exists
|
||||
const existing = await this.getAgent(id)
|
||||
if (!existing) {
|
||||
@@ -187,21 +195,22 @@ export class AgentService extends BaseService {
|
||||
}
|
||||
}
|
||||
|
||||
const database = await this.getDatabase()
|
||||
await database.update(agentsTable).set(updateData).where(eq(agentsTable.id, id))
|
||||
await this.database.update(agentsTable).set(updateData).where(eq(agentsTable.id, id))
|
||||
return await this.getAgent(id)
|
||||
}
|
||||
|
||||
async deleteAgent(id: string): Promise<boolean> {
|
||||
const database = await this.getDatabase()
|
||||
const result = await database.delete(agentsTable).where(eq(agentsTable.id, id))
|
||||
this.ensureInitialized()
|
||||
|
||||
const result = await this.database.delete(agentsTable).where(eq(agentsTable.id, id))
|
||||
|
||||
return result.rowsAffected > 0
|
||||
}
|
||||
|
||||
async agentExists(id: string): Promise<boolean> {
|
||||
const database = await this.getDatabase()
|
||||
const result = await database
|
||||
this.ensureInitialized()
|
||||
|
||||
const result = await this.database
|
||||
.select({ id: agentsTable.id })
|
||||
.from(agentsTable)
|
||||
.where(eq(agentsTable.id, id))
|
||||
|
||||
@@ -104,9 +104,14 @@ export class SessionMessageService extends BaseService {
|
||||
return SessionMessageService.instance
|
||||
}
|
||||
|
||||
async initialize(): Promise<void> {
|
||||
await BaseService.initialize()
|
||||
}
|
||||
|
||||
async sessionMessageExists(id: number): Promise<boolean> {
|
||||
const database = await this.getDatabase()
|
||||
const result = await database
|
||||
this.ensureInitialized()
|
||||
|
||||
const result = await this.database
|
||||
.select({ id: sessionMessagesTable.id })
|
||||
.from(sessionMessagesTable)
|
||||
.where(eq(sessionMessagesTable.id, id))
|
||||
@@ -119,9 +124,10 @@ export class SessionMessageService extends BaseService {
|
||||
sessionId: string,
|
||||
options: ListOptions = {}
|
||||
): Promise<{ messages: AgentSessionMessageEntity[] }> {
|
||||
this.ensureInitialized()
|
||||
|
||||
// Get messages with pagination
|
||||
const database = await this.getDatabase()
|
||||
const baseQuery = database
|
||||
const baseQuery = this.database
|
||||
.select()
|
||||
.from(sessionMessagesTable)
|
||||
.where(eq(sessionMessagesTable.session_id, sessionId))
|
||||
@@ -140,8 +146,9 @@ export class SessionMessageService extends BaseService {
|
||||
}
|
||||
|
||||
async deleteSessionMessage(sessionId: string, messageId: number): Promise<boolean> {
|
||||
const database = await this.getDatabase()
|
||||
const result = await database
|
||||
this.ensureInitialized()
|
||||
|
||||
const result = await this.database
|
||||
.delete(sessionMessagesTable)
|
||||
.where(and(eq(sessionMessagesTable.id, messageId), eq(sessionMessagesTable.session_id, sessionId)))
|
||||
|
||||
@@ -153,6 +160,8 @@ export class SessionMessageService extends BaseService {
|
||||
messageData: CreateSessionMessageRequest,
|
||||
abortController: AbortController
|
||||
): Promise<SessionStreamResult> {
|
||||
this.ensureInitialized()
|
||||
|
||||
return await this.startSessionMessageStream(session, messageData, abortController)
|
||||
}
|
||||
|
||||
@@ -261,9 +270,10 @@ export class SessionMessageService extends BaseService {
|
||||
}
|
||||
|
||||
private async getLastAgentSessionId(sessionId: string): Promise<string> {
|
||||
this.ensureInitialized()
|
||||
|
||||
try {
|
||||
const database = await this.getDatabase()
|
||||
const result = await database
|
||||
const result = await this.database
|
||||
.select({ agent_session_id: sessionMessagesTable.agent_session_id })
|
||||
.from(sessionMessagesTable)
|
||||
.where(and(eq(sessionMessagesTable.session_id, sessionId), not(eq(sessionMessagesTable.agent_session_id, ''))))
|
||||
|
||||
@@ -30,6 +30,10 @@ export class SessionService extends BaseService {
|
||||
return SessionService.instance
|
||||
}
|
||||
|
||||
async initialize(): Promise<void> {
|
||||
await BaseService.initialize()
|
||||
}
|
||||
|
||||
/**
|
||||
* Override BaseService.listSlashCommands to merge builtin and plugin commands
|
||||
*/
|
||||
@@ -80,12 +84,13 @@ export class SessionService extends BaseService {
|
||||
agentId: string,
|
||||
req: Partial<CreateSessionRequest> = {}
|
||||
): Promise<GetAgentSessionResponse | null> {
|
||||
this.ensureInitialized()
|
||||
|
||||
// Validate agent exists - we'll need to import AgentService for this check
|
||||
// For now, we'll skip this validation to avoid circular dependencies
|
||||
// The database foreign key constraint will handle this
|
||||
|
||||
const database = await this.getDatabase()
|
||||
const agents = await database.select().from(agentsTable).where(eq(agentsTable.id, agentId)).limit(1)
|
||||
const agents = await this.database.select().from(agentsTable).where(eq(agentsTable.id, agentId)).limit(1)
|
||||
if (!agents[0]) {
|
||||
throw new Error('Agent not found')
|
||||
}
|
||||
@@ -130,10 +135,9 @@ export class SessionService extends BaseService {
|
||||
updated_at: now
|
||||
}
|
||||
|
||||
const db = await this.getDatabase()
|
||||
await db.insert(sessionsTable).values(insertData)
|
||||
await this.database.insert(sessionsTable).values(insertData)
|
||||
|
||||
const result = await db.select().from(sessionsTable).where(eq(sessionsTable.id, id)).limit(1)
|
||||
const result = await this.database.select().from(sessionsTable).where(eq(sessionsTable.id, id)).limit(1)
|
||||
|
||||
if (!result[0]) {
|
||||
throw new Error('Failed to create session')
|
||||
@@ -144,8 +148,9 @@ export class SessionService extends BaseService {
|
||||
}
|
||||
|
||||
async getSession(agentId: string, id: string): Promise<GetAgentSessionResponse | null> {
|
||||
const database = await this.getDatabase()
|
||||
const result = await database
|
||||
this.ensureInitialized()
|
||||
|
||||
const result = await this.database
|
||||
.select()
|
||||
.from(sessionsTable)
|
||||
.where(and(eq(sessionsTable.id, id), eq(sessionsTable.agent_id, agentId)))
|
||||
@@ -171,6 +176,8 @@ export class SessionService extends BaseService {
|
||||
agentId?: string,
|
||||
options: ListOptions = {}
|
||||
): Promise<{ sessions: AgentSessionEntity[]; total: number }> {
|
||||
this.ensureInitialized()
|
||||
|
||||
// Build where conditions
|
||||
const whereConditions: SQL[] = []
|
||||
if (agentId) {
|
||||
@@ -185,13 +192,16 @@ export class SessionService extends BaseService {
|
||||
: undefined
|
||||
|
||||
// Get total count
|
||||
const database = await this.getDatabase()
|
||||
const totalResult = await database.select({ count: count() }).from(sessionsTable).where(whereClause)
|
||||
const totalResult = await this.database.select({ count: count() }).from(sessionsTable).where(whereClause)
|
||||
|
||||
const total = totalResult[0].count
|
||||
|
||||
// Build list query with pagination - sort by updated_at descending (latest first)
|
||||
const baseQuery = database.select().from(sessionsTable).where(whereClause).orderBy(desc(sessionsTable.updated_at))
|
||||
const baseQuery = this.database
|
||||
.select()
|
||||
.from(sessionsTable)
|
||||
.where(whereClause)
|
||||
.orderBy(desc(sessionsTable.updated_at))
|
||||
|
||||
const result =
|
||||
options.limit !== undefined
|
||||
@@ -210,6 +220,8 @@ export class SessionService extends BaseService {
|
||||
id: string,
|
||||
updates: UpdateSessionRequest
|
||||
): Promise<UpdateSessionResponse | null> {
|
||||
this.ensureInitialized()
|
||||
|
||||
// Check if session exists
|
||||
const existing = await this.getSession(agentId, id)
|
||||
if (!existing) {
|
||||
@@ -250,15 +262,15 @@ export class SessionService extends BaseService {
|
||||
}
|
||||
}
|
||||
|
||||
const database = await this.getDatabase()
|
||||
await database.update(sessionsTable).set(updateData).where(eq(sessionsTable.id, id))
|
||||
await this.database.update(sessionsTable).set(updateData).where(eq(sessionsTable.id, id))
|
||||
|
||||
return await this.getSession(agentId, id)
|
||||
}
|
||||
|
||||
async deleteSession(agentId: string, id: string): Promise<boolean> {
|
||||
const database = await this.getDatabase()
|
||||
const result = await database
|
||||
this.ensureInitialized()
|
||||
|
||||
const result = await this.database
|
||||
.delete(sessionsTable)
|
||||
.where(and(eq(sessionsTable.id, id), eq(sessionsTable.agent_id, agentId)))
|
||||
|
||||
@@ -266,8 +278,9 @@ export class SessionService extends BaseService {
|
||||
}
|
||||
|
||||
async sessionExists(agentId: string, id: string): Promise<boolean> {
|
||||
const database = await this.getDatabase()
|
||||
const result = await database
|
||||
this.ensureInitialized()
|
||||
|
||||
const result = await this.database
|
||||
.select({ id: sessionsTable.id })
|
||||
.from(sessionsTable)
|
||||
.where(and(eq(sessionsTable.id, id), eq(sessionsTable.agent_id, agentId)))
|
||||
|
||||
@@ -21,11 +21,6 @@ describe('stripLocalCommandTags', () => {
|
||||
'<local-command-stdout>line1</local-command-stdout>\nkeep\n<local-command-stderr>Error</local-command-stderr>'
|
||||
expect(stripLocalCommandTags(input)).toBe('line1\nkeep\nError')
|
||||
})
|
||||
|
||||
it('if no tags present, returns original string', () => {
|
||||
const input = 'just some normal text'
|
||||
expect(stripLocalCommandTags(input)).toBe(input)
|
||||
})
|
||||
})
|
||||
|
||||
describe('Claude → AiSDK transform', () => {
|
||||
@@ -193,111 +188,6 @@ describe('Claude → AiSDK transform', () => {
|
||||
expect(toolResult.output).toBe('ok')
|
||||
})
|
||||
|
||||
it('handles tool calls without streaming events (no content_block_start/stop)', () => {
|
||||
const state = new ClaudeStreamState({ agentSessionId: '12344' })
|
||||
const parts: ReturnType<typeof transformSDKMessageToStreamParts>[number][] = []
|
||||
|
||||
const messages: SDKMessage[] = [
|
||||
{
|
||||
...baseStreamMetadata,
|
||||
type: 'assistant',
|
||||
uuid: uuid(20),
|
||||
message: {
|
||||
id: 'msg-tool-no-stream',
|
||||
type: 'message',
|
||||
role: 'assistant',
|
||||
model: 'claude-test',
|
||||
content: [
|
||||
{
|
||||
type: 'tool_use',
|
||||
id: 'tool-read',
|
||||
name: 'Read',
|
||||
input: { file_path: '/test.txt' }
|
||||
},
|
||||
{
|
||||
type: 'tool_use',
|
||||
id: 'tool-bash',
|
||||
name: 'Bash',
|
||||
input: { command: 'ls -la' }
|
||||
}
|
||||
],
|
||||
stop_reason: 'tool_use',
|
||||
stop_sequence: null,
|
||||
usage: {
|
||||
input_tokens: 10,
|
||||
output_tokens: 20
|
||||
}
|
||||
}
|
||||
} as unknown as SDKMessage,
|
||||
{
|
||||
...baseStreamMetadata,
|
||||
type: 'user',
|
||||
uuid: uuid(21),
|
||||
message: {
|
||||
role: 'user',
|
||||
content: [
|
||||
{
|
||||
type: 'tool_result',
|
||||
tool_use_id: 'tool-read',
|
||||
content: 'file contents',
|
||||
is_error: false
|
||||
}
|
||||
]
|
||||
}
|
||||
} as SDKMessage,
|
||||
{
|
||||
...baseStreamMetadata,
|
||||
type: 'user',
|
||||
uuid: uuid(22),
|
||||
message: {
|
||||
role: 'user',
|
||||
content: [
|
||||
{
|
||||
type: 'tool_result',
|
||||
tool_use_id: 'tool-bash',
|
||||
content: 'total 42\n...',
|
||||
is_error: false
|
||||
}
|
||||
]
|
||||
}
|
||||
} as SDKMessage
|
||||
]
|
||||
|
||||
for (const message of messages) {
|
||||
const transformed = transformSDKMessageToStreamParts(message, state)
|
||||
parts.push(...transformed)
|
||||
}
|
||||
|
||||
const types = parts.map((part) => part.type)
|
||||
expect(types).toEqual(['tool-call', 'tool-call', 'tool-result', 'tool-result'])
|
||||
|
||||
const toolCalls = parts.filter((part) => part.type === 'tool-call') as Extract<
|
||||
(typeof parts)[number],
|
||||
{ type: 'tool-call' }
|
||||
>[]
|
||||
expect(toolCalls).toHaveLength(2)
|
||||
expect(toolCalls[0].toolName).toBe('Read')
|
||||
expect(toolCalls[0].toolCallId).toBe('12344:tool-read')
|
||||
expect(toolCalls[1].toolName).toBe('Bash')
|
||||
expect(toolCalls[1].toolCallId).toBe('12344:tool-bash')
|
||||
|
||||
const toolResults = parts.filter((part) => part.type === 'tool-result') as Extract<
|
||||
(typeof parts)[number],
|
||||
{ type: 'tool-result' }
|
||||
>[]
|
||||
expect(toolResults).toHaveLength(2)
|
||||
// This is the key assertion - toolName should NOT be 'unknown'
|
||||
expect(toolResults[0].toolName).toBe('Read')
|
||||
expect(toolResults[0].toolCallId).toBe('12344:tool-read')
|
||||
expect(toolResults[0].input).toEqual({ file_path: '/test.txt' })
|
||||
expect(toolResults[0].output).toBe('file contents')
|
||||
|
||||
expect(toolResults[1].toolName).toBe('Bash')
|
||||
expect(toolResults[1].toolCallId).toBe('12344:tool-bash')
|
||||
expect(toolResults[1].input).toEqual({ command: 'ls -la' })
|
||||
expect(toolResults[1].output).toBe('total 42\n...')
|
||||
})
|
||||
|
||||
it('handles streaming text completion', () => {
|
||||
const state = new ClaudeStreamState({ agentSessionId: baseStreamMetadata.session_id })
|
||||
const parts: ReturnType<typeof transformSDKMessageToStreamParts>[number][] = []
|
||||
@@ -410,87 +300,4 @@ describe('Claude → AiSDK transform', () => {
|
||||
expect(finishStep.finishReason).toBe('stop')
|
||||
expect(finishStep.usage).toEqual({ inputTokens: 2, outputTokens: 4, totalTokens: 6 })
|
||||
})
|
||||
|
||||
it('emits fallback text when Claude sends a snapshot instead of deltas', () => {
|
||||
const state = new ClaudeStreamState({ agentSessionId: '12344' })
|
||||
const parts: ReturnType<typeof transformSDKMessageToStreamParts>[number][] = []
|
||||
|
||||
const messages: SDKMessage[] = [
|
||||
{
|
||||
...baseStreamMetadata,
|
||||
type: 'stream_event',
|
||||
uuid: uuid(30),
|
||||
event: {
|
||||
type: 'message_start',
|
||||
message: {
|
||||
id: 'msg-fallback',
|
||||
type: 'message',
|
||||
role: 'assistant',
|
||||
model: 'claude-test',
|
||||
content: [],
|
||||
stop_reason: null,
|
||||
stop_sequence: null,
|
||||
usage: {}
|
||||
}
|
||||
}
|
||||
} as unknown as SDKMessage,
|
||||
{
|
||||
...baseStreamMetadata,
|
||||
type: 'stream_event',
|
||||
uuid: uuid(31),
|
||||
event: {
|
||||
type: 'content_block_start',
|
||||
index: 0,
|
||||
content_block: {
|
||||
type: 'text',
|
||||
text: ''
|
||||
}
|
||||
}
|
||||
} as unknown as SDKMessage,
|
||||
{
|
||||
...baseStreamMetadata,
|
||||
type: 'assistant',
|
||||
uuid: uuid(32),
|
||||
message: {
|
||||
id: 'msg-fallback-content',
|
||||
type: 'message',
|
||||
role: 'assistant',
|
||||
model: 'claude-test',
|
||||
content: [
|
||||
{
|
||||
type: 'text',
|
||||
text: 'Final answer without streaming deltas.'
|
||||
}
|
||||
],
|
||||
stop_reason: 'end_turn',
|
||||
stop_sequence: null,
|
||||
usage: {
|
||||
input_tokens: 3,
|
||||
output_tokens: 7
|
||||
}
|
||||
}
|
||||
} as unknown as SDKMessage
|
||||
]
|
||||
|
||||
for (const message of messages) {
|
||||
const transformed = transformSDKMessageToStreamParts(message, state)
|
||||
parts.push(...transformed)
|
||||
}
|
||||
|
||||
const types = parts.map((part) => part.type)
|
||||
expect(types).toEqual(['start-step', 'text-start', 'text-delta', 'text-end', 'finish-step'])
|
||||
|
||||
const delta = parts.find((part) => part.type === 'text-delta') as Extract<
|
||||
(typeof parts)[number],
|
||||
{ type: 'text-delta' }
|
||||
>
|
||||
expect(delta.text).toBe('Final answer without streaming deltas.')
|
||||
|
||||
const finish = parts.find((part) => part.type === 'finish-step') as Extract<
|
||||
(typeof parts)[number],
|
||||
{ type: 'finish-step' }
|
||||
>
|
||||
expect(finish.usage).toEqual({ inputTokens: 3, outputTokens: 7, totalTokens: 10 })
|
||||
expect(finish.finishReason).toBe('stop')
|
||||
})
|
||||
})
|
||||
|
||||
@@ -153,20 +153,6 @@ export class ClaudeStreamState {
|
||||
return this.blocksByIndex.get(index)
|
||||
}
|
||||
|
||||
getFirstOpenTextBlock(): TextBlockState | undefined {
|
||||
const candidates: TextBlockState[] = []
|
||||
for (const block of this.blocksByIndex.values()) {
|
||||
if (block.kind === 'text') {
|
||||
candidates.push(block)
|
||||
}
|
||||
}
|
||||
if (candidates.length === 0) {
|
||||
return undefined
|
||||
}
|
||||
candidates.sort((a, b) => a.index - b.index)
|
||||
return candidates[0]
|
||||
}
|
||||
|
||||
getToolBlockById(toolCallId: string): ToolBlockState | undefined {
|
||||
const index = this.toolIndexByNamespacedId.get(toolCallId)
|
||||
if (index === undefined) return undefined
|
||||
@@ -231,10 +217,10 @@ export class ClaudeStreamState {
|
||||
* Persists the final input payload for a tool block once the provider signals
|
||||
* completion so that downstream tool results can reference the original call.
|
||||
*/
|
||||
completeToolBlock(toolCallId: string, toolName: string, input: unknown, providerMetadata?: ProviderMetadata): void {
|
||||
completeToolBlock(toolCallId: string, input: unknown, providerMetadata?: ProviderMetadata): void {
|
||||
const block = this.getToolBlockByRawId(toolCallId)
|
||||
this.registerToolCall(toolCallId, {
|
||||
toolName,
|
||||
toolName: block?.toolName ?? 'unknown',
|
||||
input,
|
||||
providerMetadata
|
||||
})
|
||||
|
||||
@@ -2,14 +2,7 @@
|
||||
import { EventEmitter } from 'node:events'
|
||||
import { createRequire } from 'node:module'
|
||||
|
||||
import type {
|
||||
CanUseTool,
|
||||
HookCallback,
|
||||
McpHttpServerConfig,
|
||||
Options,
|
||||
PreToolUseHookInput,
|
||||
SDKMessage
|
||||
} from '@anthropic-ai/claude-agent-sdk'
|
||||
import type { CanUseTool, McpHttpServerConfig, Options, SDKMessage } from '@anthropic-ai/claude-agent-sdk'
|
||||
import { query } from '@anthropic-ai/claude-agent-sdk'
|
||||
import { loggerService } from '@logger'
|
||||
import { config as apiConfigService } from '@main/apiServer/config'
|
||||
@@ -164,63 +157,6 @@ class ClaudeCodeService implements AgentServiceInterface {
|
||||
})
|
||||
}
|
||||
|
||||
const preToolUseHook: HookCallback = async (input, toolUseID, options) => {
|
||||
// Type guard to ensure we're handling PreToolUse event
|
||||
if (input.hook_event_name !== 'PreToolUse') {
|
||||
return {}
|
||||
}
|
||||
|
||||
const hookInput = input as PreToolUseHookInput
|
||||
const toolName = hookInput.tool_name
|
||||
|
||||
logger.debug('PreToolUse hook triggered', {
|
||||
session_id: hookInput.session_id,
|
||||
tool_name: hookInput.tool_name,
|
||||
tool_use_id: toolUseID,
|
||||
tool_input: hookInput.tool_input,
|
||||
cwd: hookInput.cwd,
|
||||
permission_mode: hookInput.permission_mode,
|
||||
autoAllowTools: autoAllowTools
|
||||
})
|
||||
|
||||
if (options?.signal?.aborted) {
|
||||
logger.debug('PreToolUse hook signal already aborted; skipping tool use', {
|
||||
tool_name: hookInput.tool_name
|
||||
})
|
||||
return {}
|
||||
}
|
||||
|
||||
// handle auto approved tools since it never triggers canUseTool
|
||||
const normalizedToolName = normalizeToolName(toolName)
|
||||
if (toolUseID) {
|
||||
const bypassAll = input.permission_mode === 'bypassPermissions'
|
||||
const autoAllowed = autoAllowTools.has(toolName) || autoAllowTools.has(normalizedToolName)
|
||||
if (bypassAll || autoAllowed) {
|
||||
const namespacedToolCallId = buildNamespacedToolCallId(session.id, toolUseID)
|
||||
logger.debug('handling auto approved tools', {
|
||||
toolName,
|
||||
normalizedToolName,
|
||||
namespacedToolCallId,
|
||||
permission_mode: input.permission_mode,
|
||||
autoAllowTools
|
||||
})
|
||||
const isRecord = (v: unknown): v is Record<string, unknown> => {
|
||||
return !!v && typeof v === 'object' && !Array.isArray(v)
|
||||
}
|
||||
const toolInput = isRecord(input.tool_input) ? input.tool_input : {}
|
||||
|
||||
await promptForToolApproval(toolName, toolInput, {
|
||||
...options,
|
||||
toolCallId: namespacedToolCallId,
|
||||
autoApprove: true
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Return to proceed without modification
|
||||
return {}
|
||||
}
|
||||
|
||||
// Build SDK options from parameters
|
||||
const options: Options = {
|
||||
abortController,
|
||||
@@ -244,14 +180,7 @@ class ClaudeCodeService implements AgentServiceInterface {
|
||||
permissionMode: session.configuration?.permission_mode,
|
||||
maxTurns: session.configuration?.max_turns,
|
||||
allowedTools: session.allowed_tools,
|
||||
canUseTool,
|
||||
hooks: {
|
||||
PreToolUse: [
|
||||
{
|
||||
hooks: [preToolUseHook]
|
||||
}
|
||||
]
|
||||
}
|
||||
canUseTool
|
||||
}
|
||||
|
||||
if (session.accessible_paths.length > 1) {
|
||||
@@ -485,6 +414,23 @@ class ClaudeCodeService implements AgentServiceInterface {
|
||||
}
|
||||
}
|
||||
|
||||
if (message.type === 'assistant' || message.type === 'user') {
|
||||
logger.silly('claude response', {
|
||||
message,
|
||||
content: JSON.stringify(message.message.content)
|
||||
})
|
||||
} else if (message.type === 'stream_event') {
|
||||
// logger.silly('Claude stream event', {
|
||||
// message,
|
||||
// event: JSON.stringify(message.event)
|
||||
// })
|
||||
} else {
|
||||
logger.silly('Claude response', {
|
||||
message,
|
||||
event: JSON.stringify(message)
|
||||
})
|
||||
}
|
||||
|
||||
const chunks = transformSDKMessageToStreamParts(message, streamState)
|
||||
for (const chunk of chunks) {
|
||||
stream.emit('data', {
|
||||
|
||||
@@ -31,7 +31,6 @@ type PendingPermissionRequest = {
|
||||
abortListener?: () => void
|
||||
originalInput: Record<string, unknown>
|
||||
toolName: string
|
||||
toolCallId?: string
|
||||
}
|
||||
|
||||
type RendererPermissionRequestPayload = {
|
||||
@@ -46,7 +45,6 @@ type RendererPermissionRequestPayload = {
|
||||
createdAt: number
|
||||
expiresAt: number
|
||||
suggestions: PermissionUpdate[]
|
||||
autoApprove?: boolean
|
||||
}
|
||||
|
||||
type RendererPermissionResultPayload = {
|
||||
@@ -54,7 +52,6 @@ type RendererPermissionResultPayload = {
|
||||
behavior: ToolPermissionBehavior
|
||||
message?: string
|
||||
reason: 'response' | 'timeout' | 'aborted' | 'no-window'
|
||||
toolCallId?: string
|
||||
}
|
||||
|
||||
const pendingRequests = new Map<string, PendingPermissionRequest>()
|
||||
@@ -148,8 +145,7 @@ const finalizeRequest = (
|
||||
requestId,
|
||||
behavior: update.behavior,
|
||||
message: update.behavior === 'deny' ? update.message : undefined,
|
||||
reason,
|
||||
toolCallId: pending.toolCallId
|
||||
reason
|
||||
}
|
||||
|
||||
const dispatched = broadcastToRenderer(IpcChannel.AgentToolPermission_Result, resultPayload)
|
||||
@@ -214,7 +210,6 @@ const ensureIpcHandlersRegistered = () => {
|
||||
type PromptForToolApprovalOptions = {
|
||||
signal: AbortSignal
|
||||
suggestions?: PermissionUpdate[]
|
||||
autoApprove?: boolean
|
||||
|
||||
// NOTICE: This ID is namespaced with session ID, not the raw SDK tool call ID.
|
||||
// Format: `${sessionId}:${rawToolCallId}`, e.g., `session_123:WebFetch_0`
|
||||
@@ -275,8 +270,7 @@ export async function promptForToolApproval(
|
||||
inputPreview,
|
||||
createdAt,
|
||||
expiresAt,
|
||||
suggestions: sanitizedSuggestions,
|
||||
autoApprove: options.autoApprove
|
||||
suggestions: sanitizedSuggestions
|
||||
}
|
||||
|
||||
const defaultDenyUpdate: PermissionResult = { behavior: 'deny', message: 'Tool request aborted before user decision' }
|
||||
@@ -305,8 +299,7 @@ export async function promptForToolApproval(
|
||||
timeout,
|
||||
originalInput: sanitizedInput,
|
||||
toolName,
|
||||
signal: options?.signal,
|
||||
toolCallId: options.toolCallId
|
||||
signal: options?.signal
|
||||
}
|
||||
|
||||
if (options?.signal) {
|
||||
|
||||
@@ -110,7 +110,7 @@ const sdkMessageToProviderMetadata = (message: SDKMessage): ProviderMetadata =>
|
||||
* blocks across calls so that incremental deltas can be correlated correctly.
|
||||
*/
|
||||
export function transformSDKMessageToStreamParts(sdkMessage: SDKMessage, state: ClaudeStreamState): AgentStreamPart[] {
|
||||
logger.silly('Transforming SDKMessage', { message: JSON.stringify(sdkMessage) })
|
||||
logger.silly('Transforming SDKMessage', { message: sdkMessage })
|
||||
switch (sdkMessage.type) {
|
||||
case 'assistant':
|
||||
return handleAssistantMessage(sdkMessage, state)
|
||||
@@ -186,13 +186,14 @@ function handleAssistantMessage(
|
||||
|
||||
for (const block of content) {
|
||||
switch (block.type) {
|
||||
case 'text': {
|
||||
const sanitizedText = stripLocalCommandTags(block.text)
|
||||
if (sanitizedText) {
|
||||
textBlocks.push(sanitizedText)
|
||||
case 'text':
|
||||
if (!isStreamingActive) {
|
||||
const sanitizedText = stripLocalCommandTags(block.text)
|
||||
if (sanitizedText) {
|
||||
textBlocks.push(sanitizedText)
|
||||
}
|
||||
}
|
||||
break
|
||||
}
|
||||
case 'tool_use':
|
||||
handleAssistantToolUse(block as ToolUseContent, providerMetadata, state, chunks)
|
||||
break
|
||||
@@ -202,16 +203,7 @@ function handleAssistantMessage(
|
||||
}
|
||||
}
|
||||
|
||||
if (textBlocks.length === 0) {
|
||||
return chunks
|
||||
}
|
||||
|
||||
const combinedText = textBlocks.join('')
|
||||
if (!combinedText) {
|
||||
return chunks
|
||||
}
|
||||
|
||||
if (!isStreamingActive) {
|
||||
if (!isStreamingActive && textBlocks.length > 0) {
|
||||
const id = message.uuid?.toString() || generateMessageId()
|
||||
state.beginStep()
|
||||
chunks.push({
|
||||
@@ -227,7 +219,7 @@ function handleAssistantMessage(
|
||||
chunks.push({
|
||||
type: 'text-delta',
|
||||
id,
|
||||
text: combinedText,
|
||||
text: textBlocks.join(''),
|
||||
providerMetadata
|
||||
})
|
||||
chunks.push({
|
||||
@@ -238,27 +230,7 @@ function handleAssistantMessage(
|
||||
return finalizeNonStreamingStep(message, state, chunks)
|
||||
}
|
||||
|
||||
const existingTextBlock = state.getFirstOpenTextBlock()
|
||||
const fallbackId = existingTextBlock?.id || message.uuid?.toString() || generateMessageId()
|
||||
if (!existingTextBlock) {
|
||||
chunks.push({
|
||||
type: 'text-start',
|
||||
id: fallbackId,
|
||||
providerMetadata
|
||||
})
|
||||
}
|
||||
chunks.push({
|
||||
type: 'text-delta',
|
||||
id: fallbackId,
|
||||
text: combinedText,
|
||||
providerMetadata
|
||||
})
|
||||
chunks.push({
|
||||
type: 'text-end',
|
||||
id: fallbackId,
|
||||
providerMetadata
|
||||
})
|
||||
return finalizeNonStreamingStep(message, state, chunks)
|
||||
return chunks
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -280,7 +252,7 @@ function handleAssistantToolUse(
|
||||
providerExecuted: true,
|
||||
providerMetadata
|
||||
})
|
||||
state.completeToolBlock(block.id, block.name, block.input, providerMetadata)
|
||||
state.completeToolBlock(block.id, block.input, providerMetadata)
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -487,9 +459,6 @@ function handleStreamEvent(
|
||||
}
|
||||
|
||||
case 'message_stop': {
|
||||
if (!state.hasActiveStep()) {
|
||||
break
|
||||
}
|
||||
const pending = state.getPendingUsage()
|
||||
chunks.push({
|
||||
type: 'finish-step',
|
||||
|
||||
@@ -122,8 +122,7 @@ const api = {
|
||||
system: {
|
||||
getDeviceType: () => ipcRenderer.invoke(IpcChannel.System_GetDeviceType),
|
||||
getHostname: () => ipcRenderer.invoke(IpcChannel.System_GetHostname),
|
||||
getCpuName: () => ipcRenderer.invoke(IpcChannel.System_GetCpuName),
|
||||
checkGitBash: (): Promise<boolean> => ipcRenderer.invoke(IpcChannel.System_CheckGitBash)
|
||||
getCpuName: () => ipcRenderer.invoke(IpcChannel.System_GetCpuName)
|
||||
},
|
||||
devTools: {
|
||||
toggle: () => ipcRenderer.invoke(IpcChannel.System_ToggleDevTools)
|
||||
@@ -221,6 +220,10 @@ const api = {
|
||||
startFileWatcher: (dirPath: string, config?: any) =>
|
||||
ipcRenderer.invoke(IpcChannel.File_StartWatcher, dirPath, config),
|
||||
stopFileWatcher: () => ipcRenderer.invoke(IpcChannel.File_StopWatcher),
|
||||
pauseFileWatcher: () => ipcRenderer.invoke(IpcChannel.File_PauseWatcher),
|
||||
resumeFileWatcher: () => ipcRenderer.invoke(IpcChannel.File_ResumeWatcher),
|
||||
batchUploadMarkdown: (filePaths: string[], targetPath: string) =>
|
||||
ipcRenderer.invoke(IpcChannel.File_BatchUploadMarkdown, filePaths, targetPath),
|
||||
onFileChange: (callback: (data: FileChangeEvent) => void) => {
|
||||
const listener = (_event: Electron.IpcRendererEvent, data: any) => {
|
||||
if (data && typeof data === 'object') {
|
||||
|
||||
@@ -32,7 +32,6 @@ import {
|
||||
prepareSpecialProviderConfig,
|
||||
providerToAiSdkConfig
|
||||
} from './provider/providerConfig'
|
||||
import type { AiSdkConfig } from './types'
|
||||
|
||||
const logger = loggerService.withContext('ModernAiProvider')
|
||||
|
||||
@@ -45,7 +44,7 @@ export type ModernAiProviderConfig = AiSdkMiddlewareConfig & {
|
||||
|
||||
export default class ModernAiProvider {
|
||||
private legacyProvider: LegacyAiProvider
|
||||
private config?: AiSdkConfig
|
||||
private config?: ReturnType<typeof providerToAiSdkConfig>
|
||||
private actualProvider: Provider
|
||||
private model?: Model
|
||||
private localProvider: Awaited<AiSdkProvider> | null = null
|
||||
@@ -90,11 +89,6 @@ export default class ModernAiProvider {
|
||||
// 每次请求时重新生成配置以确保API key轮换生效
|
||||
this.config = providerToAiSdkConfig(this.actualProvider, this.model)
|
||||
logger.debug('Generated provider config for completions', this.config)
|
||||
|
||||
// 检查 config 是否存在
|
||||
if (!this.config) {
|
||||
throw new Error('Provider config is undefined; cannot proceed with completions')
|
||||
}
|
||||
if (SUPPORTED_IMAGE_ENDPOINT_LIST.includes(this.config.options.endpoint)) {
|
||||
providerConfig.isImageGenerationEndpoint = true
|
||||
}
|
||||
@@ -155,8 +149,7 @@ export default class ModernAiProvider {
|
||||
params: StreamTextParams,
|
||||
config: ModernAiProviderConfig
|
||||
): Promise<CompletionsResult> {
|
||||
// ai-gateway不是image/generation 端点,所以就先不走legacy了
|
||||
if (config.isImageGenerationEndpoint && config.provider!.id !== SystemProviderIds['ai-gateway']) {
|
||||
if (config.isImageGenerationEndpoint) {
|
||||
// 使用 legacy 实现处理图像生成(支持图片编辑等高级功能)
|
||||
if (!config.uiMessages) {
|
||||
throw new Error('uiMessages is required for image generation endpoint')
|
||||
@@ -470,13 +463,8 @@ export default class ModernAiProvider {
|
||||
// 如果支持新的 AI SDK,使用现代化实现
|
||||
if (isModernSdkSupported(this.actualProvider)) {
|
||||
try {
|
||||
// 确保 config 已定义
|
||||
if (!this.config) {
|
||||
throw new Error('Provider config is undefined; cannot proceed with generateImage')
|
||||
}
|
||||
|
||||
// 确保本地provider已创建
|
||||
if (!this.localProvider && this.config) {
|
||||
if (!this.localProvider) {
|
||||
this.localProvider = await createAiSdkProvider(this.config)
|
||||
if (!this.localProvider) {
|
||||
throw new Error('Local provider not created')
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import { loggerService } from '@logger'
|
||||
import { isNewApiProvider } from '@renderer/config/providers'
|
||||
import type { Provider } from '@renderer/types'
|
||||
import { isNewApiProvider } from '@renderer/utils/provider'
|
||||
|
||||
import { AihubmixAPIClient } from './aihubmix/AihubmixAPIClient'
|
||||
import { AnthropicAPIClient } from './anthropic/AnthropicAPIClient'
|
||||
|
||||
@@ -7,6 +7,7 @@ import {
|
||||
isSupportFlexServiceTierModel
|
||||
} from '@renderer/config/models'
|
||||
import { REFERENCE_PROMPT } from '@renderer/config/prompts'
|
||||
import { isSupportServiceTierProvider } from '@renderer/config/providers'
|
||||
import { getLMStudioKeepAliveTime } from '@renderer/hooks/useLMStudio'
|
||||
import { getAssistantSettings } from '@renderer/services/AssistantService'
|
||||
import type {
|
||||
@@ -18,6 +19,7 @@ import type {
|
||||
MCPToolResponse,
|
||||
MemoryItem,
|
||||
Model,
|
||||
OpenAIVerbosity,
|
||||
Provider,
|
||||
ToolCallResponse,
|
||||
WebSearchProviderResponse,
|
||||
@@ -31,7 +33,6 @@ import {
|
||||
OpenAIServiceTiers,
|
||||
SystemProviderIds
|
||||
} from '@renderer/types'
|
||||
import type { OpenAIVerbosity } from '@renderer/types/aiCoreTypes'
|
||||
import type { Message } from '@renderer/types/newMessage'
|
||||
import type {
|
||||
RequestOptions,
|
||||
@@ -47,7 +48,6 @@ import type {
|
||||
import { isJSON, parseJSON } from '@renderer/utils'
|
||||
import { addAbortController, removeAbortController } from '@renderer/utils/abortController'
|
||||
import { findFileBlocks, getMainTextContent } from '@renderer/utils/messageUtils/find'
|
||||
import { isSupportServiceTierProvider } from '@renderer/utils/provider'
|
||||
import { defaultTimeout } from '@shared/config/constant'
|
||||
import { defaultAppHeaders } from '@shared/utils'
|
||||
import { isEmpty } from 'lodash'
|
||||
|
||||
@@ -58,27 +58,10 @@ vi.mock('../aws/AwsBedrockAPIClient', () => ({
|
||||
AwsBedrockAPIClient: vi.fn().mockImplementation(() => ({}))
|
||||
}))
|
||||
|
||||
vi.mock('@renderer/services/AssistantService.ts', () => ({
|
||||
getDefaultAssistant: () => {
|
||||
return {
|
||||
id: 'default',
|
||||
name: 'default',
|
||||
emoji: '😀',
|
||||
prompt: '',
|
||||
topics: [],
|
||||
messages: [],
|
||||
type: 'assistant',
|
||||
regularPhrases: [],
|
||||
settings: {}
|
||||
}
|
||||
}
|
||||
}))
|
||||
|
||||
// Mock the models config to prevent circular dependency issues
|
||||
vi.mock('@renderer/config/models', () => ({
|
||||
findTokenLimit: vi.fn(),
|
||||
isReasoningModel: vi.fn(),
|
||||
isOpenAILLMModel: vi.fn(),
|
||||
SYSTEM_MODELS: {
|
||||
silicon: [],
|
||||
defaultModel: []
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
import { GoogleGenAI } from '@google/genai'
|
||||
import { loggerService } from '@logger'
|
||||
import { createVertexProvider, isVertexAIConfigured } from '@renderer/hooks/useVertexAI'
|
||||
import { createVertexProvider, isVertexAIConfigured, isVertexProvider } from '@renderer/hooks/useVertexAI'
|
||||
import type { Model, Provider, VertexProvider } from '@renderer/types'
|
||||
import { isVertexProvider } from '@renderer/utils/provider'
|
||||
import { isEmpty } from 'lodash'
|
||||
|
||||
import { AnthropicVertexClient } from '../anthropic/AnthropicVertexClient'
|
||||
|
||||
@@ -10,6 +10,7 @@ import { DEFAULT_MAX_TOKENS } from '@renderer/config/constant'
|
||||
import {
|
||||
findTokenLimit,
|
||||
GEMINI_FLASH_MODEL_REGEX,
|
||||
getOpenAIWebSearchParams,
|
||||
getThinkModelType,
|
||||
isClaudeReasoningModel,
|
||||
isDeepSeekHybridInferenceModel,
|
||||
@@ -39,6 +40,12 @@ import {
|
||||
MODEL_SUPPORTED_REASONING_EFFORT,
|
||||
ZHIPU_RESULT_TOKENS
|
||||
} from '@renderer/config/models'
|
||||
import {
|
||||
isSupportArrayContentProvider,
|
||||
isSupportDeveloperRoleProvider,
|
||||
isSupportEnableThinkingProvider,
|
||||
isSupportStreamOptionsProvider
|
||||
} from '@renderer/config/providers'
|
||||
import { mapLanguageToQwenMTModel } from '@renderer/config/translate'
|
||||
import { processPostsuffixQwen3Model, processReqMessages } from '@renderer/services/ModelMessageService'
|
||||
import { estimateTextTokens } from '@renderer/services/TokenService'
|
||||
@@ -82,12 +89,6 @@ import {
|
||||
openAIToolsToMcpTool
|
||||
} from '@renderer/utils/mcp-tools'
|
||||
import { findFileBlocks, findImageBlocks } from '@renderer/utils/messageUtils/find'
|
||||
import {
|
||||
isSupportArrayContentProvider,
|
||||
isSupportDeveloperRoleProvider,
|
||||
isSupportEnableThinkingProvider,
|
||||
isSupportStreamOptionsProvider
|
||||
} from '@renderer/utils/provider'
|
||||
import { t } from 'i18next'
|
||||
|
||||
import type { GenericChunk } from '../../middleware/schemas'
|
||||
@@ -742,7 +743,7 @@ export class OpenAIAPIClient extends OpenAIBaseClient<
|
||||
: {}),
|
||||
...this.getProviderSpecificParameters(assistant, model),
|
||||
...reasoningEffort,
|
||||
// ...getOpenAIWebSearchParams(model, enableWebSearch),
|
||||
...getOpenAIWebSearchParams(model, enableWebSearch),
|
||||
// OpenRouter usage tracking
|
||||
...(this.provider.id === 'openrouter' ? { usage: { include: true } } : {}),
|
||||
...extra_body,
|
||||
|
||||
@@ -12,6 +12,7 @@ import {
|
||||
isSupportVerbosityModel,
|
||||
isVisionModel
|
||||
} from '@renderer/config/models'
|
||||
import { isSupportDeveloperRoleProvider } from '@renderer/config/providers'
|
||||
import { estimateTextTokens } from '@renderer/services/TokenService'
|
||||
import type {
|
||||
FileMetadata,
|
||||
@@ -42,7 +43,6 @@ import {
|
||||
openAIToolsToMcpTool
|
||||
} from '@renderer/utils/mcp-tools'
|
||||
import { findFileBlocks, findImageBlocks } from '@renderer/utils/messageUtils/find'
|
||||
import { isSupportDeveloperRoleProvider } from '@renderer/utils/provider'
|
||||
import { MB } from '@shared/config/constant'
|
||||
import { t } from 'i18next'
|
||||
import { isEmpty } from 'lodash'
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import { loggerService } from '@logger'
|
||||
import { isZhipuModel } from '@renderer/config/models'
|
||||
import { getStoreProviders } from '@renderer/hooks/useStore'
|
||||
import { getDefaultModel } from '@renderer/services/AssistantService'
|
||||
import type { Chunk } from '@renderer/types/chunk'
|
||||
|
||||
import type { CompletionsParams, CompletionsResult } from '../schemas'
|
||||
@@ -67,7 +66,7 @@ export const ErrorHandlerMiddleware =
|
||||
}
|
||||
|
||||
function handleError(error: any, params: CompletionsParams): any {
|
||||
if (isZhipuModel(params.assistant.model || getDefaultModel()) && error.status && !params.enableGenerateImage) {
|
||||
if (isZhipuModel(params.assistant.model) && error.status && !params.enableGenerateImage) {
|
||||
return handleZhipuError(error)
|
||||
}
|
||||
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
import type { WebSearchPluginConfig } from '@cherrystudio/ai-core/built-in/plugins'
|
||||
import { loggerService } from '@logger'
|
||||
import { isSupportedThinkingTokenQwenModel } from '@renderer/config/models'
|
||||
import { isSupportEnableThinkingProvider } from '@renderer/config/providers'
|
||||
import type { MCPTool } from '@renderer/types'
|
||||
import { type Assistant, type Message, type Model, type Provider, SystemProviderIds } from '@renderer/types'
|
||||
import { type Assistant, type Message, type Model, type Provider } from '@renderer/types'
|
||||
import type { Chunk } from '@renderer/types/chunk'
|
||||
import { isSupportEnableThinkingProvider } from '@renderer/utils/provider'
|
||||
import type { LanguageModelMiddleware } from 'ai'
|
||||
import { extractReasoningMiddleware, simulateStreamingMiddleware } from 'ai'
|
||||
import { isEmpty } from 'lodash'
|
||||
@@ -12,7 +12,6 @@ import { isEmpty } from 'lodash'
|
||||
import { isOpenRouterGeminiGenerateImageModel } from '../utils/image'
|
||||
import { noThinkMiddleware } from './noThinkMiddleware'
|
||||
import { openrouterGenerateImageMiddleware } from './openrouterGenerateImageMiddleware'
|
||||
import { openrouterReasoningMiddleware } from './openrouterReasoningMiddleware'
|
||||
import { qwenThinkingMiddleware } from './qwenThinkingMiddleware'
|
||||
import { toolChoiceMiddleware } from './toolChoiceMiddleware'
|
||||
|
||||
@@ -218,14 +217,6 @@ function addProviderSpecificMiddlewares(builder: AiSdkMiddlewareBuilder, config:
|
||||
middleware: noThinkMiddleware()
|
||||
})
|
||||
}
|
||||
|
||||
if (config.provider.id === SystemProviderIds.openrouter && config.enableReasoning) {
|
||||
builder.add({
|
||||
name: 'openrouter-reasoning-redaction',
|
||||
middleware: openrouterReasoningMiddleware()
|
||||
})
|
||||
logger.debug('Added OpenRouter reasoning redaction middleware')
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
@@ -1,50 +0,0 @@
|
||||
import type { LanguageModelV2StreamPart } from '@ai-sdk/provider'
|
||||
import type { LanguageModelMiddleware } from 'ai'
|
||||
|
||||
/**
|
||||
* https://openrouter.ai/docs/docs/best-practices/reasoning-tokens#example-preserving-reasoning-blocks-with-openrouter-and-claude
|
||||
*
|
||||
* @returns LanguageModelMiddleware - a middleware filter redacted block
|
||||
*/
|
||||
export function openrouterReasoningMiddleware(): LanguageModelMiddleware {
|
||||
const REDACTED_BLOCK = '[REDACTED]'
|
||||
return {
|
||||
middlewareVersion: 'v2',
|
||||
wrapGenerate: async ({ doGenerate }) => {
|
||||
const { content, ...rest } = await doGenerate()
|
||||
const modifiedContent = content.map((part) => {
|
||||
if (part.type === 'reasoning' && part.text.includes(REDACTED_BLOCK)) {
|
||||
return {
|
||||
...part,
|
||||
text: part.text.replace(REDACTED_BLOCK, '')
|
||||
}
|
||||
}
|
||||
return part
|
||||
})
|
||||
return { content: modifiedContent, ...rest }
|
||||
},
|
||||
wrapStream: async ({ doStream }) => {
|
||||
const { stream, ...rest } = await doStream()
|
||||
return {
|
||||
stream: stream.pipeThrough(
|
||||
new TransformStream<LanguageModelV2StreamPart, LanguageModelV2StreamPart>({
|
||||
transform(
|
||||
chunk: LanguageModelV2StreamPart,
|
||||
controller: TransformStreamDefaultController<LanguageModelV2StreamPart>
|
||||
) {
|
||||
if (chunk.type === 'reasoning-delta' && chunk.delta.includes(REDACTED_BLOCK)) {
|
||||
controller.enqueue({
|
||||
...chunk,
|
||||
delta: chunk.delta.replace(REDACTED_BLOCK, '')
|
||||
})
|
||||
} else {
|
||||
controller.enqueue(chunk)
|
||||
}
|
||||
}
|
||||
})
|
||||
),
|
||||
...rest
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,234 +0,0 @@
|
||||
import type { Message, Model } from '@renderer/types'
|
||||
import type { FileMetadata } from '@renderer/types/file'
|
||||
import { FileTypes } from '@renderer/types/file'
|
||||
import {
|
||||
AssistantMessageStatus,
|
||||
type FileMessageBlock,
|
||||
type ImageMessageBlock,
|
||||
MessageBlockStatus,
|
||||
MessageBlockType,
|
||||
type ThinkingMessageBlock,
|
||||
UserMessageStatus
|
||||
} from '@renderer/types/newMessage'
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
const { convertFileBlockToFilePartMock, convertFileBlockToTextPartMock } = vi.hoisted(() => ({
|
||||
convertFileBlockToFilePartMock: vi.fn(),
|
||||
convertFileBlockToTextPartMock: vi.fn()
|
||||
}))
|
||||
|
||||
vi.mock('../fileProcessor', () => ({
|
||||
convertFileBlockToFilePart: convertFileBlockToFilePartMock,
|
||||
convertFileBlockToTextPart: convertFileBlockToTextPartMock
|
||||
}))
|
||||
|
||||
const visionModelIds = new Set(['gpt-4o-mini', 'qwen-image-edit'])
|
||||
const imageEnhancementModelIds = new Set(['qwen-image-edit'])
|
||||
|
||||
vi.mock('@renderer/config/models', () => ({
|
||||
isVisionModel: (model: Model) => visionModelIds.has(model.id),
|
||||
isImageEnhancementModel: (model: Model) => imageEnhancementModelIds.has(model.id)
|
||||
}))
|
||||
|
||||
type MockableMessage = Message & {
|
||||
__mockContent?: string
|
||||
__mockFileBlocks?: FileMessageBlock[]
|
||||
__mockImageBlocks?: ImageMessageBlock[]
|
||||
__mockThinkingBlocks?: ThinkingMessageBlock[]
|
||||
}
|
||||
|
||||
vi.mock('@renderer/utils/messageUtils/find', () => ({
|
||||
getMainTextContent: (message: Message) => (message as MockableMessage).__mockContent ?? '',
|
||||
findFileBlocks: (message: Message) => (message as MockableMessage).__mockFileBlocks ?? [],
|
||||
findImageBlocks: (message: Message) => (message as MockableMessage).__mockImageBlocks ?? [],
|
||||
findThinkingBlocks: (message: Message) => (message as MockableMessage).__mockThinkingBlocks ?? []
|
||||
}))
|
||||
|
||||
import { convertMessagesToSdkMessages, convertMessageToSdkParam } from '../messageConverter'
|
||||
|
||||
let messageCounter = 0
|
||||
let blockCounter = 0
|
||||
|
||||
const createModel = (overrides: Partial<Model> = {}): Model => ({
|
||||
id: 'gpt-4o-mini',
|
||||
name: 'GPT-4o mini',
|
||||
provider: 'openai',
|
||||
group: 'openai',
|
||||
...overrides
|
||||
})
|
||||
|
||||
const createMessage = (role: Message['role']): MockableMessage =>
|
||||
({
|
||||
id: `message-${++messageCounter}`,
|
||||
role,
|
||||
assistantId: 'assistant-1',
|
||||
topicId: 'topic-1',
|
||||
createdAt: new Date(2024, 0, 1, 0, 0, messageCounter).toISOString(),
|
||||
status: role === 'assistant' ? AssistantMessageStatus.SUCCESS : UserMessageStatus.SUCCESS,
|
||||
blocks: []
|
||||
}) as MockableMessage
|
||||
|
||||
const createFileBlock = (
|
||||
messageId: string,
|
||||
overrides: Partial<Omit<FileMessageBlock, 'file' | 'messageId' | 'type'>> & { file?: Partial<FileMetadata> } = {}
|
||||
): FileMessageBlock => {
|
||||
const { file, ...blockOverrides } = overrides
|
||||
const timestamp = new Date(2024, 0, 1, 0, 0, ++blockCounter).toISOString()
|
||||
return {
|
||||
id: blockOverrides.id ?? `file-block-${blockCounter}`,
|
||||
messageId,
|
||||
type: MessageBlockType.FILE,
|
||||
createdAt: blockOverrides.createdAt ?? timestamp,
|
||||
status: blockOverrides.status ?? MessageBlockStatus.SUCCESS,
|
||||
file: {
|
||||
id: file?.id ?? `file-${blockCounter}`,
|
||||
name: file?.name ?? 'document.txt',
|
||||
origin_name: file?.origin_name ?? 'document.txt',
|
||||
path: file?.path ?? '/tmp/document.txt',
|
||||
size: file?.size ?? 1024,
|
||||
ext: file?.ext ?? '.txt',
|
||||
type: file?.type ?? FileTypes.TEXT,
|
||||
created_at: file?.created_at ?? timestamp,
|
||||
count: file?.count ?? 1,
|
||||
...file
|
||||
},
|
||||
...blockOverrides
|
||||
}
|
||||
}
|
||||
|
||||
const createImageBlock = (
|
||||
messageId: string,
|
||||
overrides: Partial<Omit<ImageMessageBlock, 'type' | 'messageId'>> = {}
|
||||
): ImageMessageBlock => ({
|
||||
id: overrides.id ?? `image-block-${++blockCounter}`,
|
||||
messageId,
|
||||
type: MessageBlockType.IMAGE,
|
||||
createdAt: overrides.createdAt ?? new Date(2024, 0, 1, 0, 0, blockCounter).toISOString(),
|
||||
status: overrides.status ?? MessageBlockStatus.SUCCESS,
|
||||
url: overrides.url ?? 'https://example.com/image.png',
|
||||
...overrides
|
||||
})
|
||||
|
||||
describe('messageConverter', () => {
|
||||
beforeEach(() => {
|
||||
convertFileBlockToFilePartMock.mockReset()
|
||||
convertFileBlockToTextPartMock.mockReset()
|
||||
convertFileBlockToFilePartMock.mockResolvedValue(null)
|
||||
convertFileBlockToTextPartMock.mockResolvedValue(null)
|
||||
messageCounter = 0
|
||||
blockCounter = 0
|
||||
})
|
||||
|
||||
describe('convertMessageToSdkParam', () => {
|
||||
it('includes text and image parts for user messages on vision models', async () => {
|
||||
const model = createModel()
|
||||
const message = createMessage('user')
|
||||
message.__mockContent = 'Describe this picture'
|
||||
message.__mockImageBlocks = [createImageBlock(message.id, { url: 'https://example.com/cat.png' })]
|
||||
|
||||
const result = await convertMessageToSdkParam(message, true, model)
|
||||
|
||||
expect(result).toEqual({
|
||||
role: 'user',
|
||||
content: [
|
||||
{ type: 'text', text: 'Describe this picture' },
|
||||
{ type: 'image', image: 'https://example.com/cat.png' }
|
||||
]
|
||||
})
|
||||
})
|
||||
|
||||
it('returns file instructions as a system message when native uploads succeed', async () => {
|
||||
const model = createModel()
|
||||
const message = createMessage('user')
|
||||
message.__mockContent = 'Summarize the PDF'
|
||||
message.__mockFileBlocks = [createFileBlock(message.id)]
|
||||
convertFileBlockToFilePartMock.mockResolvedValueOnce({
|
||||
type: 'file',
|
||||
filename: 'document.pdf',
|
||||
mediaType: 'application/pdf',
|
||||
data: 'fileid://remote-file'
|
||||
})
|
||||
|
||||
const result = await convertMessageToSdkParam(message, false, model)
|
||||
|
||||
expect(result).toEqual([
|
||||
{
|
||||
role: 'system',
|
||||
content: 'fileid://remote-file'
|
||||
},
|
||||
{
|
||||
role: 'user',
|
||||
content: [{ type: 'text', text: 'Summarize the PDF' }]
|
||||
}
|
||||
])
|
||||
})
|
||||
})
|
||||
|
||||
describe('convertMessagesToSdkMessages', () => {
|
||||
it('appends assistant images to the final user message for image enhancement models', async () => {
|
||||
const model = createModel({ id: 'qwen-image-edit', name: 'Qwen Image Edit', provider: 'qwen', group: 'qwen' })
|
||||
const initialUser = createMessage('user')
|
||||
initialUser.__mockContent = 'Start editing'
|
||||
|
||||
const assistant = createMessage('assistant')
|
||||
assistant.__mockContent = 'Here is the current preview'
|
||||
assistant.__mockImageBlocks = [createImageBlock(assistant.id, { url: 'https://example.com/preview.png' })]
|
||||
|
||||
const finalUser = createMessage('user')
|
||||
finalUser.__mockContent = 'Increase the brightness'
|
||||
|
||||
const result = await convertMessagesToSdkMessages([initialUser, assistant, finalUser], model)
|
||||
|
||||
expect(result).toEqual([
|
||||
{
|
||||
role: 'assistant',
|
||||
content: [{ type: 'text', text: 'Here is the current preview' }]
|
||||
},
|
||||
{
|
||||
role: 'user',
|
||||
content: [
|
||||
{ type: 'text', text: 'Increase the brightness' },
|
||||
{ type: 'image', image: 'https://example.com/preview.png' }
|
||||
]
|
||||
}
|
||||
])
|
||||
})
|
||||
|
||||
it('preserves preceding system instructions when building enhancement payloads', async () => {
|
||||
const model = createModel({ id: 'qwen-image-edit', name: 'Qwen Image Edit', provider: 'qwen', group: 'qwen' })
|
||||
const fileUser = createMessage('user')
|
||||
fileUser.__mockContent = 'Use this document as inspiration'
|
||||
fileUser.__mockFileBlocks = [createFileBlock(fileUser.id, { file: { ext: '.pdf', type: FileTypes.DOCUMENT } })]
|
||||
convertFileBlockToFilePartMock.mockResolvedValueOnce({
|
||||
type: 'file',
|
||||
filename: 'reference.pdf',
|
||||
mediaType: 'application/pdf',
|
||||
data: 'fileid://reference'
|
||||
})
|
||||
|
||||
const assistant = createMessage('assistant')
|
||||
assistant.__mockContent = 'Generated previews ready'
|
||||
assistant.__mockImageBlocks = [createImageBlock(assistant.id, { url: 'https://example.com/reference.png' })]
|
||||
|
||||
const finalUser = createMessage('user')
|
||||
finalUser.__mockContent = 'Apply the edits'
|
||||
|
||||
const result = await convertMessagesToSdkMessages([fileUser, assistant, finalUser], model)
|
||||
|
||||
expect(result).toEqual([
|
||||
{ role: 'system', content: 'fileid://reference' },
|
||||
{
|
||||
role: 'assistant',
|
||||
content: [{ type: 'text', text: 'Generated previews ready' }]
|
||||
},
|
||||
{
|
||||
role: 'user',
|
||||
content: [
|
||||
{ type: 'text', text: 'Apply the edits' },
|
||||
{ type: 'image', image: 'https://example.com/reference.png' }
|
||||
]
|
||||
}
|
||||
])
|
||||
})
|
||||
})
|
||||
})
|
||||
@@ -1,218 +0,0 @@
|
||||
import type { Assistant, AssistantSettings, Model, Topic } from '@renderer/types'
|
||||
import { TopicType } from '@renderer/types'
|
||||
import { defaultTimeout } from '@shared/config/constant'
|
||||
import { describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import { getTemperature, getTimeout, getTopP } from '../modelParameters'
|
||||
|
||||
vi.mock('@renderer/services/AssistantService', () => ({
|
||||
getAssistantSettings: (assistant: Assistant): AssistantSettings => ({
|
||||
contextCount: assistant.settings?.contextCount ?? 4096,
|
||||
temperature: assistant.settings?.temperature ?? 0.7,
|
||||
enableTemperature: assistant.settings?.enableTemperature ?? true,
|
||||
topP: assistant.settings?.topP ?? 1,
|
||||
enableTopP: assistant.settings?.enableTopP ?? false,
|
||||
enableMaxTokens: assistant.settings?.enableMaxTokens ?? false,
|
||||
maxTokens: assistant.settings?.maxTokens,
|
||||
streamOutput: assistant.settings?.streamOutput ?? true,
|
||||
toolUseMode: assistant.settings?.toolUseMode ?? 'prompt',
|
||||
defaultModel: assistant.defaultModel,
|
||||
customParameters: assistant.settings?.customParameters ?? [],
|
||||
reasoning_effort: assistant.settings?.reasoning_effort,
|
||||
reasoning_effort_cache: assistant.settings?.reasoning_effort_cache,
|
||||
qwenThinkMode: assistant.settings?.qwenThinkMode
|
||||
})
|
||||
}))
|
||||
|
||||
vi.mock('@renderer/hooks/useSettings', () => ({
|
||||
getStoreSetting: vi.fn(),
|
||||
useSettings: vi.fn(() => ({})),
|
||||
useNavbarPosition: vi.fn(() => ({ navbarPosition: 'left', isLeftNavbar: true, isTopNavbar: false }))
|
||||
}))
|
||||
|
||||
vi.mock('@renderer/hooks/useStore', () => ({
|
||||
getStoreProviders: vi.fn(() => [])
|
||||
}))
|
||||
|
||||
vi.mock('@renderer/store/settings', () => ({
|
||||
default: (state = { settings: {} }) => state
|
||||
}))
|
||||
|
||||
vi.mock('@renderer/store/assistants', () => ({
|
||||
default: (state = { assistants: [] }) => state
|
||||
}))
|
||||
|
||||
const createTopic = (assistantId: string): Topic => ({
|
||||
id: `topic-${assistantId}`,
|
||||
assistantId,
|
||||
name: 'topic',
|
||||
createdAt: new Date().toISOString(),
|
||||
updatedAt: new Date().toISOString(),
|
||||
messages: [],
|
||||
type: TopicType.Chat
|
||||
})
|
||||
|
||||
const createAssistant = (settings: Assistant['settings'] = {}): Assistant => {
|
||||
const assistantId = 'assistant-1'
|
||||
return {
|
||||
id: assistantId,
|
||||
name: 'Test Assistant',
|
||||
prompt: 'prompt',
|
||||
topics: [createTopic(assistantId)],
|
||||
type: 'assistant',
|
||||
settings
|
||||
}
|
||||
}
|
||||
|
||||
const createModel = (overrides: Partial<Model> = {}): Model => ({
|
||||
id: 'gpt-4o',
|
||||
provider: 'openai',
|
||||
name: 'GPT-4o',
|
||||
group: 'openai',
|
||||
...overrides
|
||||
})
|
||||
|
||||
describe('modelParameters', () => {
|
||||
describe('getTemperature', () => {
|
||||
it('returns undefined when reasoning effort is enabled for Claude models', () => {
|
||||
const assistant = createAssistant({ reasoning_effort: 'medium' })
|
||||
const model = createModel({ id: 'claude-opus-4', name: 'Claude Opus 4', provider: 'anthropic', group: 'claude' })
|
||||
|
||||
expect(getTemperature(assistant, model)).toBeUndefined()
|
||||
})
|
||||
|
||||
it('returns undefined for models without temperature/topP support', () => {
|
||||
const assistant = createAssistant({ enableTemperature: true })
|
||||
const model = createModel({ id: 'qwen-mt-large', name: 'Qwen MT', provider: 'qwen', group: 'qwen' })
|
||||
|
||||
expect(getTemperature(assistant, model)).toBeUndefined()
|
||||
})
|
||||
|
||||
it('returns undefined for Claude 4.5 reasoning models when only TopP is enabled', () => {
|
||||
const assistant = createAssistant({ enableTopP: true, enableTemperature: false })
|
||||
const model = createModel({
|
||||
id: 'claude-sonnet-4.5',
|
||||
name: 'Claude Sonnet 4.5',
|
||||
provider: 'anthropic',
|
||||
group: 'claude'
|
||||
})
|
||||
|
||||
expect(getTemperature(assistant, model)).toBeUndefined()
|
||||
})
|
||||
|
||||
it('returns configured temperature when enabled', () => {
|
||||
const assistant = createAssistant({ enableTemperature: true, temperature: 0.42 })
|
||||
const model = createModel({ id: 'gpt-4o', provider: 'openai', group: 'openai' })
|
||||
|
||||
expect(getTemperature(assistant, model)).toBe(0.42)
|
||||
})
|
||||
|
||||
it('returns undefined when temperature is disabled', () => {
|
||||
const assistant = createAssistant({ enableTemperature: false, temperature: 0.9 })
|
||||
const model = createModel({ id: 'gpt-4o', provider: 'openai', group: 'openai' })
|
||||
|
||||
expect(getTemperature(assistant, model)).toBeUndefined()
|
||||
})
|
||||
|
||||
it('clamps temperature to max 1.0 for Zhipu models', () => {
|
||||
const assistant = createAssistant({ enableTemperature: true, temperature: 2.0 })
|
||||
const model = createModel({ id: 'glm-4-plus', name: 'GLM-4 Plus', provider: 'zhipu', group: 'zhipu' })
|
||||
|
||||
expect(getTemperature(assistant, model)).toBe(1.0)
|
||||
})
|
||||
|
||||
it('clamps temperature to max 1.0 for Anthropic models', () => {
|
||||
const assistant = createAssistant({ enableTemperature: true, temperature: 1.5 })
|
||||
const model = createModel({
|
||||
id: 'claude-sonnet-3.5',
|
||||
name: 'Claude 3.5 Sonnet',
|
||||
provider: 'anthropic',
|
||||
group: 'claude'
|
||||
})
|
||||
|
||||
expect(getTemperature(assistant, model)).toBe(1.0)
|
||||
})
|
||||
|
||||
it('clamps temperature to max 1.0 for Moonshot models', () => {
|
||||
const assistant = createAssistant({ enableTemperature: true, temperature: 2.0 })
|
||||
const model = createModel({
|
||||
id: 'moonshot-v1-8k',
|
||||
name: 'Moonshot v1 8k',
|
||||
provider: 'moonshot',
|
||||
group: 'moonshot'
|
||||
})
|
||||
|
||||
expect(getTemperature(assistant, model)).toBe(1.0)
|
||||
})
|
||||
|
||||
it('does not clamp temperature for OpenAI models', () => {
|
||||
const assistant = createAssistant({ enableTemperature: true, temperature: 2.0 })
|
||||
const model = createModel({ id: 'gpt-4o', provider: 'openai', group: 'openai' })
|
||||
|
||||
expect(getTemperature(assistant, model)).toBe(2.0)
|
||||
})
|
||||
|
||||
it('does not clamp temperature when it is already within limits', () => {
|
||||
const assistant = createAssistant({ enableTemperature: true, temperature: 0.8 })
|
||||
const model = createModel({ id: 'glm-4-plus', name: 'GLM-4 Plus', provider: 'zhipu', group: 'zhipu' })
|
||||
|
||||
expect(getTemperature(assistant, model)).toBe(0.8)
|
||||
})
|
||||
})
|
||||
|
||||
describe('getTopP', () => {
|
||||
it('returns undefined when reasoning effort is enabled for Claude models', () => {
|
||||
const assistant = createAssistant({ reasoning_effort: 'high' })
|
||||
const model = createModel({ id: 'claude-opus-4', provider: 'anthropic', group: 'claude' })
|
||||
|
||||
expect(getTopP(assistant, model)).toBeUndefined()
|
||||
})
|
||||
|
||||
it('returns undefined for models without TopP support', () => {
|
||||
const assistant = createAssistant({ enableTopP: true })
|
||||
const model = createModel({ id: 'qwen-mt-small', name: 'Qwen MT', provider: 'qwen', group: 'qwen' })
|
||||
|
||||
expect(getTopP(assistant, model)).toBeUndefined()
|
||||
})
|
||||
|
||||
it('returns undefined for Claude 4.5 reasoning models when temperature is enabled', () => {
|
||||
const assistant = createAssistant({ enableTemperature: true })
|
||||
const model = createModel({
|
||||
id: 'claude-opus-4.5',
|
||||
name: 'Claude Opus 4.5',
|
||||
provider: 'anthropic',
|
||||
group: 'claude'
|
||||
})
|
||||
|
||||
expect(getTopP(assistant, model)).toBeUndefined()
|
||||
})
|
||||
|
||||
it('returns configured TopP when enabled', () => {
|
||||
const assistant = createAssistant({ enableTopP: true, topP: 0.73 })
|
||||
const model = createModel({ id: 'gpt-4o', provider: 'openai', group: 'openai' })
|
||||
|
||||
expect(getTopP(assistant, model)).toBe(0.73)
|
||||
})
|
||||
|
||||
it('returns undefined when TopP is disabled', () => {
|
||||
const assistant = createAssistant({ enableTopP: false, topP: 0.5 })
|
||||
const model = createModel({ id: 'gpt-4o', provider: 'openai', group: 'openai' })
|
||||
|
||||
expect(getTopP(assistant, model)).toBeUndefined()
|
||||
})
|
||||
})
|
||||
|
||||
describe('getTimeout', () => {
|
||||
it('uses an extended timeout for flex service tier models', () => {
|
||||
const model = createModel({ id: 'o3-pro', provider: 'openai', group: 'openai' })
|
||||
|
||||
expect(getTimeout(model)).toBe(15 * 1000 * 60)
|
||||
})
|
||||
|
||||
it('falls back to the default timeout otherwise', () => {
|
||||
const model = createModel({ id: 'gpt-4o', provider: 'openai', group: 'openai' })
|
||||
|
||||
expect(getTimeout(model)).toBe(defaultTimeout)
|
||||
})
|
||||
})
|
||||
})
|
||||
@@ -1,31 +1,13 @@
|
||||
import { isClaude4SeriesModel, isClaude45ReasoningModel } from '@renderer/config/models'
|
||||
import { getProviderByModel } from '@renderer/services/AssistantService'
|
||||
import { isClaude45ReasoningModel } from '@renderer/config/models'
|
||||
import type { Assistant, Model } from '@renderer/types'
|
||||
import { isToolUseModeFunction } from '@renderer/utils/assistant'
|
||||
import { isAwsBedrockProvider, isVertexProvider } from '@renderer/utils/provider'
|
||||
|
||||
// https://docs.claude.com/en/docs/build-with-claude/extended-thinking#interleaved-thinking
|
||||
const INTERLEAVED_THINKING_HEADER = 'interleaved-thinking-2025-05-14'
|
||||
// https://docs.claude.com/en/docs/build-with-claude/context-windows#1m-token-context-window
|
||||
const CONTEXT_100M_HEADER = 'context-1m-2025-08-07'
|
||||
// https://docs.cloud.google.com/vertex-ai/generative-ai/docs/partner-models/claude/web-search
|
||||
const WEBSEARCH_HEADER = 'web-search-2025-03-05'
|
||||
|
||||
export function addAnthropicHeaders(assistant: Assistant, model: Model): string[] {
|
||||
const anthropicHeaders: string[] = []
|
||||
const provider = getProviderByModel(model)
|
||||
if (
|
||||
isClaude45ReasoningModel(model) &&
|
||||
isToolUseModeFunction(assistant) &&
|
||||
!(isVertexProvider(provider) && isAwsBedrockProvider(provider))
|
||||
) {
|
||||
if (isClaude45ReasoningModel(model) && isToolUseModeFunction(assistant)) {
|
||||
anthropicHeaders.push(INTERLEAVED_THINKING_HEADER)
|
||||
}
|
||||
if (isClaude4SeriesModel(model)) {
|
||||
if (isVertexProvider(provider) && assistant.enableWebSearch) {
|
||||
anthropicHeaders.push(WEBSEARCH_HEADER)
|
||||
}
|
||||
anthropicHeaders.push(CONTEXT_100M_HEADER)
|
||||
}
|
||||
return anthropicHeaders
|
||||
}
|
||||
|
||||
@@ -85,6 +85,19 @@ export function supportsLargeFileUpload(model: Model): boolean {
|
||||
})
|
||||
}
|
||||
|
||||
/**
|
||||
* 检查模型是否支持TopP
|
||||
*/
|
||||
export function supportsTopP(model: Model): boolean {
|
||||
const provider = getProviderByModel(model)
|
||||
|
||||
if (provider?.type === 'anthropic' || model?.endpoint_type === 'anthropic') {
|
||||
return false
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取提供商特定的文件大小限制
|
||||
*/
|
||||
|
||||
@@ -3,27 +3,17 @@
|
||||
* 处理温度、TopP、超时等基础参数的获取逻辑
|
||||
*/
|
||||
|
||||
import { DEFAULT_MAX_TOKENS } from '@renderer/config/constant'
|
||||
import {
|
||||
isClaude45ReasoningModel,
|
||||
isClaudeReasoningModel,
|
||||
isMaxTemperatureOneModel,
|
||||
isNotSupportTemperatureAndTopP,
|
||||
isSupportedFlexServiceTier,
|
||||
isSupportedThinkingTokenClaudeModel
|
||||
isSupportedFlexServiceTier
|
||||
} from '@renderer/config/models'
|
||||
import { getAssistantSettings, getProviderByModel } from '@renderer/services/AssistantService'
|
||||
import { getAssistantSettings } from '@renderer/services/AssistantService'
|
||||
import type { Assistant, Model } from '@renderer/types'
|
||||
import { defaultTimeout } from '@shared/config/constant'
|
||||
|
||||
import { getAnthropicThinkingBudget } from '../utils/reasoning'
|
||||
|
||||
/**
|
||||
* Claude 4.5 推理模型:
|
||||
* - 只启用 temperature → 使用 temperature
|
||||
* - 只启用 top_p → 使用 top_p
|
||||
* - 同时启用 → temperature 生效,top_p 被忽略
|
||||
* - 都不启用 → 都不使用
|
||||
* 获取温度参数
|
||||
*/
|
||||
export function getTemperature(assistant: Assistant, model: Model): number | undefined {
|
||||
@@ -37,11 +27,7 @@ export function getTemperature(assistant: Assistant, model: Model): number | und
|
||||
return undefined
|
||||
}
|
||||
const assistantSettings = getAssistantSettings(assistant)
|
||||
let temperature = assistantSettings?.temperature
|
||||
if (temperature && isMaxTemperatureOneModel(model)) {
|
||||
temperature = Math.min(1, temperature)
|
||||
}
|
||||
return assistantSettings?.enableTemperature ? temperature : undefined
|
||||
return assistantSettings?.enableTemperature ? assistantSettings?.temperature : undefined
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -70,18 +56,3 @@ export function getTimeout(model: Model): number {
|
||||
}
|
||||
return defaultTimeout
|
||||
}
|
||||
|
||||
export function getMaxTokens(assistant: Assistant, model: Model): number | undefined {
|
||||
// NOTE: ai-sdk会把maxToken和budgetToken加起来
|
||||
let { maxTokens = DEFAULT_MAX_TOKENS } = getAssistantSettings(assistant)
|
||||
|
||||
const provider = getProviderByModel(model)
|
||||
if (isSupportedThinkingTokenClaudeModel(model) && ['anthropic', 'aws-bedrock'].includes(provider.type)) {
|
||||
const { reasoning_effort: reasoningEffort } = getAssistantSettings(assistant)
|
||||
const budget = getAnthropicThinkingBudget(maxTokens, reasoningEffort, model.id)
|
||||
if (budget) {
|
||||
maxTokens -= budget
|
||||
}
|
||||
}
|
||||
return maxTokens
|
||||
}
|
||||
|
||||
@@ -4,12 +4,11 @@
|
||||
*/
|
||||
|
||||
import { anthropic } from '@ai-sdk/anthropic'
|
||||
import { azure } from '@ai-sdk/azure'
|
||||
import { google } from '@ai-sdk/google'
|
||||
import { vertexAnthropic } from '@ai-sdk/google-vertex/anthropic/edge'
|
||||
import { vertex } from '@ai-sdk/google-vertex/edge'
|
||||
import { combineHeaders } from '@ai-sdk/provider-utils'
|
||||
import type { AnthropicSearchConfig, WebSearchPluginConfig } from '@cherrystudio/ai-core/built-in/plugins'
|
||||
import type { WebSearchPluginConfig } from '@cherrystudio/ai-core/built-in/plugins'
|
||||
import { isBaseProvider } from '@cherrystudio/ai-core/core/providers/schemas'
|
||||
import { loggerService } from '@logger'
|
||||
import {
|
||||
@@ -18,10 +17,13 @@ import {
|
||||
isOpenRouterBuiltInWebSearchModel,
|
||||
isReasoningModel,
|
||||
isSupportedReasoningEffortModel,
|
||||
isSupportedThinkingTokenClaudeModel,
|
||||
isSupportedThinkingTokenModel,
|
||||
isWebSearchModel
|
||||
} from '@renderer/config/models'
|
||||
import { getDefaultModel } from '@renderer/services/AssistantService'
|
||||
import { isAwsBedrockProvider } from '@renderer/config/providers'
|
||||
import { isVertexProvider } from '@renderer/hooks/useVertexAI'
|
||||
import { getAssistantSettings, getDefaultModel } from '@renderer/services/AssistantService'
|
||||
import store from '@renderer/store'
|
||||
import type { CherryWebSearchConfig } from '@renderer/store/websearch'
|
||||
import { type Assistant, type MCPTool, type Provider } from '@renderer/types'
|
||||
@@ -34,9 +36,11 @@ import { stepCountIs } from 'ai'
|
||||
import { getAiSdkProviderId } from '../provider/factory'
|
||||
import { setupToolsConfig } from '../utils/mcp'
|
||||
import { buildProviderOptions } from '../utils/options'
|
||||
import { getAnthropicThinkingBudget } from '../utils/reasoning'
|
||||
import { buildProviderBuiltinWebSearchConfig } from '../utils/websearch'
|
||||
import { addAnthropicHeaders } from './header'
|
||||
import { getMaxTokens, getTemperature, getTopP } from './modelParameters'
|
||||
import { supportsTopP } from './modelCapabilities'
|
||||
import { getTemperature, getTopP } from './modelParameters'
|
||||
|
||||
const logger = loggerService.withContext('parameterBuilder')
|
||||
|
||||
@@ -59,7 +63,7 @@ export async function buildStreamTextParams(
|
||||
timeout?: number
|
||||
headers?: Record<string, string>
|
||||
}
|
||||
}
|
||||
} = {}
|
||||
): Promise<{
|
||||
params: StreamTextParams
|
||||
modelId: string
|
||||
@@ -76,6 +80,8 @@ export async function buildStreamTextParams(
|
||||
const model = assistant.model || getDefaultModel()
|
||||
const aiSdkProviderId = getAiSdkProviderId(provider)
|
||||
|
||||
let { maxTokens } = getAssistantSettings(assistant)
|
||||
|
||||
// 这三个变量透传出来,交给下面启用插件/中间件
|
||||
// 也可以在外部构建好再传入buildStreamTextParams
|
||||
// FIXME: qwen3即使关闭思考仍然会导致enableReasoning的结果为true
|
||||
@@ -112,6 +118,16 @@ export async function buildStreamTextParams(
|
||||
enableGenerateImage
|
||||
})
|
||||
|
||||
// NOTE: ai-sdk会把maxToken和budgetToken加起来
|
||||
if (
|
||||
enableReasoning &&
|
||||
maxTokens !== undefined &&
|
||||
isSupportedThinkingTokenClaudeModel(model) &&
|
||||
(provider.type === 'anthropic' || provider.type === 'aws-bedrock')
|
||||
) {
|
||||
maxTokens -= getAnthropicThinkingBudget(assistant, model)
|
||||
}
|
||||
|
||||
let webSearchPluginConfig: WebSearchPluginConfig | undefined = undefined
|
||||
if (enableWebSearch) {
|
||||
if (isBaseProvider(aiSdkProviderId)) {
|
||||
@@ -128,17 +144,6 @@ export async function buildStreamTextParams(
|
||||
maxUses: webSearchConfig.maxResults,
|
||||
blockedDomains: blockedDomains.length > 0 ? blockedDomains : undefined
|
||||
}) as ProviderDefinedTool
|
||||
} else if (aiSdkProviderId === 'azure-responses') {
|
||||
tools.web_search_preview = azure.tools.webSearchPreview({
|
||||
searchContextSize: webSearchPluginConfig?.openai!.searchContextSize
|
||||
}) as ProviderDefinedTool
|
||||
} else if (aiSdkProviderId === 'azure-anthropic') {
|
||||
const blockedDomains = mapRegexToPatterns(webSearchConfig.excludeDomains)
|
||||
const anthropicSearchOptions: AnthropicSearchConfig = {
|
||||
maxUses: webSearchConfig.maxResults,
|
||||
blockedDomains: blockedDomains.length > 0 ? blockedDomains : undefined
|
||||
}
|
||||
tools.web_search = anthropic.tools.webSearch_20250305(anthropicSearchOptions) as ProviderDefinedTool
|
||||
}
|
||||
}
|
||||
|
||||
@@ -156,10 +161,9 @@ export async function buildStreamTextParams(
|
||||
tools.url_context = google.tools.urlContext({}) as ProviderDefinedTool
|
||||
break
|
||||
case 'anthropic':
|
||||
case 'azure-anthropic':
|
||||
case 'google-vertex-anthropic':
|
||||
tools.web_fetch = (
|
||||
['anthropic', 'azure-anthropic'].includes(aiSdkProviderId)
|
||||
aiSdkProviderId === 'anthropic'
|
||||
? anthropic.tools.webFetch_20250910({
|
||||
maxUses: webSearchConfig.maxResults,
|
||||
blockedDomains: blockedDomains.length > 0 ? blockedDomains : undefined
|
||||
@@ -175,7 +179,8 @@ export async function buildStreamTextParams(
|
||||
|
||||
let headers: Record<string, string | undefined> = options.requestOptions?.headers ?? {}
|
||||
|
||||
if (isAnthropicModel(model)) {
|
||||
// https://docs.claude.com/en/docs/build-with-claude/extended-thinking#interleaved-thinking
|
||||
if (!isVertexProvider(provider) && !isAwsBedrockProvider(provider) && isAnthropicModel(model)) {
|
||||
const newBetaHeaders = { 'anthropic-beta': addAnthropicHeaders(assistant, model).join(',') }
|
||||
headers = combineHeaders(headers, newBetaHeaders)
|
||||
}
|
||||
@@ -183,9 +188,8 @@ export async function buildStreamTextParams(
|
||||
// 构建基础参数
|
||||
const params: StreamTextParams = {
|
||||
messages: sdkMessages,
|
||||
maxOutputTokens: getMaxTokens(assistant, model),
|
||||
maxOutputTokens: maxTokens,
|
||||
temperature: getTemperature(assistant, model),
|
||||
topP: getTopP(assistant, model),
|
||||
abortSignal: options.requestOptions?.signal,
|
||||
headers,
|
||||
providerOptions,
|
||||
@@ -193,6 +197,10 @@ export async function buildStreamTextParams(
|
||||
maxRetries: 0
|
||||
}
|
||||
|
||||
if (supportsTopP(model)) {
|
||||
params.topP = getTopP(assistant, model)
|
||||
}
|
||||
|
||||
if (tools) {
|
||||
params.tools = tools
|
||||
}
|
||||
|
||||
@@ -23,26 +23,6 @@ vi.mock('@cherrystudio/ai-core', () => ({
|
||||
}
|
||||
}))
|
||||
|
||||
vi.mock('@renderer/services/AssistantService', () => ({
|
||||
getProviderByModel: vi.fn(),
|
||||
getAssistantSettings: vi.fn(),
|
||||
getDefaultAssistant: vi.fn().mockReturnValue({
|
||||
id: 'default',
|
||||
name: 'Default Assistant',
|
||||
prompt: '',
|
||||
settings: {}
|
||||
})
|
||||
}))
|
||||
|
||||
vi.mock('@renderer/store/settings', () => ({
|
||||
default: {},
|
||||
settingsSlice: {
|
||||
name: 'settings',
|
||||
reducer: vi.fn(),
|
||||
actions: {}
|
||||
}
|
||||
}))
|
||||
|
||||
// Mock the provider configs
|
||||
vi.mock('../providerConfigs', () => ({
|
||||
initializeNewProviders: vi.fn()
|
||||
|
||||
@@ -12,14 +12,7 @@ vi.mock('@renderer/services/LoggerService', () => ({
|
||||
}))
|
||||
|
||||
vi.mock('@renderer/services/AssistantService', () => ({
|
||||
getProviderByModel: vi.fn(),
|
||||
getAssistantSettings: vi.fn(),
|
||||
getDefaultAssistant: vi.fn().mockReturnValue({
|
||||
id: 'default',
|
||||
name: 'Default Assistant',
|
||||
prompt: '',
|
||||
settings: {}
|
||||
})
|
||||
getProviderByModel: vi.fn()
|
||||
}))
|
||||
|
||||
vi.mock('@renderer/store', () => ({
|
||||
@@ -41,7 +34,7 @@ vi.mock('@renderer/utils/api', () => ({
|
||||
}))
|
||||
}))
|
||||
|
||||
vi.mock('@renderer/utils/provider', async (importOriginal) => {
|
||||
vi.mock('@renderer/config/providers', async (importOriginal) => {
|
||||
const actual = (await importOriginal()) as any
|
||||
return {
|
||||
...actual,
|
||||
@@ -60,21 +53,10 @@ vi.mock('@renderer/hooks/useVertexAI', () => ({
|
||||
createVertexProvider: vi.fn()
|
||||
}))
|
||||
|
||||
vi.mock('@renderer/services/AssistantService', () => ({
|
||||
getProviderByModel: vi.fn(),
|
||||
getAssistantSettings: vi.fn(),
|
||||
getDefaultAssistant: vi.fn().mockReturnValue({
|
||||
id: 'default',
|
||||
name: 'Default Assistant',
|
||||
prompt: '',
|
||||
settings: {}
|
||||
})
|
||||
}))
|
||||
|
||||
import { isCherryAIProvider, isPerplexityProvider } from '@renderer/config/providers'
|
||||
import { getProviderByModel } from '@renderer/services/AssistantService'
|
||||
import type { Model, Provider } from '@renderer/types'
|
||||
import { formatApiHost } from '@renderer/utils/api'
|
||||
import { isCherryAIProvider, isPerplexityProvider } from '@renderer/utils/provider'
|
||||
|
||||
import { COPILOT_DEFAULT_HEADERS, COPILOT_EDITOR_VERSION, isCopilotResponsesModel } from '../constants'
|
||||
import { getActualProvider, providerToAiSdkConfig } from '../providerConfig'
|
||||
|
||||
@@ -1,22 +0,0 @@
|
||||
import type { Provider } from '@renderer/types'
|
||||
|
||||
import { provider2Provider, startsWith } from './helper'
|
||||
import type { RuleSet } from './types'
|
||||
|
||||
// https://platform.claude.com/docs/en/build-with-claude/claude-in-microsoft-foundry
|
||||
const AZURE_ANTHROPIC_RULES: RuleSet = {
|
||||
rules: [
|
||||
{
|
||||
match: startsWith('claude'),
|
||||
provider: (provider: Provider) => ({
|
||||
...provider,
|
||||
type: 'anthropic',
|
||||
apiHost: provider.apiHost + 'anthropic/v1',
|
||||
id: 'azure-anthropic'
|
||||
})
|
||||
}
|
||||
],
|
||||
fallbackRule: (provider: Provider) => provider
|
||||
}
|
||||
|
||||
export const azureAnthropicProviderCreator = provider2Provider.bind(null, AZURE_ANTHROPIC_RULES)
|
||||
@@ -2,10 +2,8 @@ import { hasProviderConfigByAlias, type ProviderId, resolveProviderConfigId } fr
|
||||
import { createProvider as createProviderCore } from '@cherrystudio/ai-core/provider'
|
||||
import { loggerService } from '@logger'
|
||||
import type { Provider } from '@renderer/types'
|
||||
import { isAzureOpenAIProvider, isAzureResponsesEndpoint } from '@renderer/utils/provider'
|
||||
import type { Provider as AiSdkProvider } from 'ai'
|
||||
|
||||
import type { AiSdkConfig } from '../types'
|
||||
import { initializeNewProviders } from './providerInitialization'
|
||||
|
||||
const logger = loggerService.withContext('ProviderFactory')
|
||||
@@ -57,12 +55,9 @@ function tryResolveProviderId(identifier: string): ProviderId | null {
|
||||
* 获取AI SDK Provider ID
|
||||
* 简化版:减少重复逻辑,利用通用解析函数
|
||||
*/
|
||||
export function getAiSdkProviderId(provider: Provider): string {
|
||||
export function getAiSdkProviderId(provider: Provider): ProviderId | 'openai-compatible' {
|
||||
// 1. 尝试解析provider.id
|
||||
const resolvedFromId = tryResolveProviderId(provider.id)
|
||||
if (isAzureOpenAIProvider(provider) && isAzureResponsesEndpoint(provider)) {
|
||||
return 'azure-responses'
|
||||
}
|
||||
if (resolvedFromId) {
|
||||
return resolvedFromId
|
||||
}
|
||||
@@ -78,11 +73,11 @@ export function getAiSdkProviderId(provider: Provider): string {
|
||||
if (provider.apiHost.includes('api.openai.com')) {
|
||||
return 'openai-chat'
|
||||
}
|
||||
// 3. 最后的fallback(使用provider本身的id)
|
||||
return provider.id
|
||||
// 3. 最后的fallback(通常会成为openai-compatible)
|
||||
return provider.id as ProviderId
|
||||
}
|
||||
|
||||
export async function createAiSdkProvider(config: AiSdkConfig): Promise<AiSdkProvider | null> {
|
||||
export async function createAiSdkProvider(config) {
|
||||
let localProvider: Awaited<AiSdkProvider> | null = null
|
||||
try {
|
||||
if (config.providerId === 'openai' && config.options?.mode === 'chat') {
|
||||
|
||||
@@ -1,5 +1,19 @@
|
||||
import { formatPrivateKey, hasProviderConfig, ProviderConfigFactory } from '@cherrystudio/ai-core/provider'
|
||||
import {
|
||||
formatPrivateKey,
|
||||
hasProviderConfig,
|
||||
ProviderConfigFactory,
|
||||
type ProviderId,
|
||||
type ProviderSettingsMap
|
||||
} from '@cherrystudio/ai-core/provider'
|
||||
import { isOpenAIChatCompletionOnlyModel } from '@renderer/config/models'
|
||||
import {
|
||||
isAnthropicProvider,
|
||||
isAzureOpenAIProvider,
|
||||
isCherryAIProvider,
|
||||
isGeminiProvider,
|
||||
isNewApiProvider,
|
||||
isPerplexityProvider
|
||||
} from '@renderer/config/providers'
|
||||
import {
|
||||
getAwsBedrockAccessKeyId,
|
||||
getAwsBedrockApiKey,
|
||||
@@ -7,25 +21,14 @@ import {
|
||||
getAwsBedrockRegion,
|
||||
getAwsBedrockSecretAccessKey
|
||||
} from '@renderer/hooks/useAwsBedrock'
|
||||
import { createVertexProvider, isVertexAIConfigured } from '@renderer/hooks/useVertexAI'
|
||||
import { createVertexProvider, isVertexAIConfigured, isVertexProvider } from '@renderer/hooks/useVertexAI'
|
||||
import { getProviderByModel } from '@renderer/services/AssistantService'
|
||||
import store from '@renderer/store'
|
||||
import { isSystemProvider, type Model, type Provider, SystemProviderIds } from '@renderer/types'
|
||||
import { formatApiHost, formatAzureOpenAIApiHost, formatVertexApiHost, routeToEndpoint } from '@renderer/utils/api'
|
||||
import {
|
||||
isAnthropicProvider,
|
||||
isAzureOpenAIProvider,
|
||||
isCherryAIProvider,
|
||||
isGeminiProvider,
|
||||
isNewApiProvider,
|
||||
isPerplexityProvider,
|
||||
isVertexProvider
|
||||
} from '@renderer/utils/provider'
|
||||
import { cloneDeep } from 'lodash'
|
||||
|
||||
import type { AiSdkConfig } from '../types'
|
||||
import { aihubmixProviderCreator, newApiResolverCreator, vertexAnthropicProviderCreator } from './config'
|
||||
import { azureAnthropicProviderCreator } from './config/azure-anthropic'
|
||||
import { COPILOT_DEFAULT_HEADERS } from './constants'
|
||||
import { getAiSdkProviderId } from './factory'
|
||||
|
||||
@@ -71,9 +74,6 @@ function handleSpecialProviders(model: Model, provider: Provider): Provider {
|
||||
return vertexAnthropicProviderCreator(model, provider)
|
||||
}
|
||||
}
|
||||
if (isAzureOpenAIProvider(provider)) {
|
||||
return azureAnthropicProviderCreator(model, provider)
|
||||
}
|
||||
return provider
|
||||
}
|
||||
|
||||
@@ -131,7 +131,13 @@ export function getActualProvider(model: Model): Provider {
|
||||
* 将 Provider 配置转换为新 AI SDK 格式
|
||||
* 简化版:利用新的别名映射系统
|
||||
*/
|
||||
export function providerToAiSdkConfig(actualProvider: Provider, model: Model): AiSdkConfig {
|
||||
export function providerToAiSdkConfig(
|
||||
actualProvider: Provider,
|
||||
model: Model
|
||||
): {
|
||||
providerId: ProviderId | 'openai-compatible'
|
||||
options: ProviderSettingsMap[keyof ProviderSettingsMap]
|
||||
} {
|
||||
const aiSdkProviderId = getAiSdkProviderId(actualProvider)
|
||||
|
||||
// 构建基础配置
|
||||
@@ -185,10 +191,13 @@ export function providerToAiSdkConfig(actualProvider: Provider, model: Model): A
|
||||
// azure
|
||||
// https://learn.microsoft.com/en-us/azure/ai-foundry/openai/latest
|
||||
// https://learn.microsoft.com/en-us/azure/ai-foundry/openai/how-to/responses?tabs=python-key#responses-api
|
||||
if (aiSdkProviderId === 'azure-responses') {
|
||||
extraOptions.mode = 'responses'
|
||||
} else if (aiSdkProviderId === 'azure') {
|
||||
extraOptions.mode = 'chat'
|
||||
if (aiSdkProviderId === 'azure' || actualProvider.type === 'azure-openai') {
|
||||
// extraOptions.apiVersion = actualProvider.apiVersion === 'preview' ? 'v1' : actualProvider.apiVersion 默认使用v1,不使用azure endpoint
|
||||
if (actualProvider.apiVersion === 'preview' || actualProvider.apiVersion === 'v1') {
|
||||
extraOptions.mode = 'responses'
|
||||
} else {
|
||||
extraOptions.mode = 'chat'
|
||||
}
|
||||
}
|
||||
|
||||
// bedrock
|
||||
@@ -218,17 +227,10 @@ export function providerToAiSdkConfig(actualProvider: Provider, model: Model): A
|
||||
baseConfig.baseURL += aiSdkProviderId === 'google-vertex' ? '/publishers/google' : '/publishers/anthropic/models'
|
||||
}
|
||||
|
||||
// cherryin
|
||||
if (aiSdkProviderId === 'cherryin') {
|
||||
if (model.endpoint_type) {
|
||||
extraOptions.endpointType = model.endpoint_type
|
||||
}
|
||||
}
|
||||
|
||||
if (hasProviderConfig(aiSdkProviderId) && aiSdkProviderId !== 'openai-compatible') {
|
||||
const options = ProviderConfigFactory.fromProvider(aiSdkProviderId, baseConfig, extraOptions)
|
||||
return {
|
||||
providerId: aiSdkProviderId,
|
||||
providerId: aiSdkProviderId as ProviderId,
|
||||
options
|
||||
}
|
||||
}
|
||||
|
||||
@@ -32,14 +32,6 @@ export const NEW_PROVIDER_CONFIGS: ProviderConfig[] = [
|
||||
supportsImageGeneration: true,
|
||||
aliases: ['vertexai-anthropic']
|
||||
},
|
||||
{
|
||||
id: 'azure-anthropic',
|
||||
name: 'Azure AI Anthropic',
|
||||
import: () => import('@ai-sdk/anthropic'),
|
||||
creatorFunctionName: 'createAnthropic',
|
||||
supportsImageGeneration: false,
|
||||
aliases: ['azure-anthropic']
|
||||
},
|
||||
{
|
||||
id: 'github-copilot-openai-compatible',
|
||||
name: 'GitHub Copilot OpenAI Compatible',
|
||||
|
||||
@@ -133,7 +133,7 @@ export class AiSdkSpanAdapter {
|
||||
|
||||
// 详细记录转换过程
|
||||
const operationId = attributes['ai.operationId']
|
||||
logger.debug('Converting AI SDK span to SpanEntity', {
|
||||
logger.info('Converting AI SDK span to SpanEntity', {
|
||||
spanName: spanName,
|
||||
operationId,
|
||||
spanTag,
|
||||
@@ -149,7 +149,7 @@ export class AiSdkSpanAdapter {
|
||||
})
|
||||
|
||||
if (tokenUsage) {
|
||||
logger.debug('Token usage data found', {
|
||||
logger.info('Token usage data found', {
|
||||
spanName: spanName,
|
||||
operationId,
|
||||
usage: tokenUsage,
|
||||
@@ -158,7 +158,7 @@ export class AiSdkSpanAdapter {
|
||||
}
|
||||
|
||||
if (inputs || outputs) {
|
||||
logger.debug('Input/Output data extracted', {
|
||||
logger.info('Input/Output data extracted', {
|
||||
spanName: spanName,
|
||||
operationId,
|
||||
hasInputs: !!inputs,
|
||||
@@ -170,7 +170,7 @@ export class AiSdkSpanAdapter {
|
||||
}
|
||||
|
||||
if (Object.keys(typeSpecificData).length > 0) {
|
||||
logger.debug('Type-specific data extracted', {
|
||||
logger.info('Type-specific data extracted', {
|
||||
spanName: spanName,
|
||||
operationId,
|
||||
typeSpecificKeys: Object.keys(typeSpecificData),
|
||||
@@ -204,7 +204,7 @@ export class AiSdkSpanAdapter {
|
||||
modelName: modelName || this.extractModelFromAttributes(attributes)
|
||||
}
|
||||
|
||||
logger.debug('AI SDK span successfully converted to SpanEntity', {
|
||||
logger.info('AI SDK span successfully converted to SpanEntity', {
|
||||
spanName: spanName,
|
||||
operationId,
|
||||
spanId: spanContext.spanId,
|
||||
|
||||
@@ -1,15 +0,0 @@
|
||||
/**
|
||||
* This type definition file is only for renderer.
|
||||
* It cannot be migrated to @renderer/types since files within it are actually being used by both main and renderer.
|
||||
* If we do that, main would throw an error because it cannot import a module which imports a type from a browser-enviroment-only package.
|
||||
* (ai-core package is set as browser-enviroment-only)
|
||||
*
|
||||
* TODO: We should separate them clearly. Keep renderer only types in renderer, and main only types in main, and shared types in shared.
|
||||
*/
|
||||
|
||||
import type { ProviderSettingsMap } from '@cherrystudio/ai-core/provider'
|
||||
|
||||
export type AiSdkConfig = {
|
||||
providerId: string
|
||||
options: ProviderSettingsMap[keyof ProviderSettingsMap]
|
||||
}
|
||||
@@ -1,121 +0,0 @@
|
||||
/**
|
||||
* image.ts Unit Tests
|
||||
* Tests for Gemini image generation utilities
|
||||
*/
|
||||
|
||||
import type { Model, Provider } from '@renderer/types'
|
||||
import { SystemProviderIds } from '@renderer/types'
|
||||
import { describe, expect, it } from 'vitest'
|
||||
|
||||
import { buildGeminiGenerateImageParams, isOpenRouterGeminiGenerateImageModel } from '../image'
|
||||
|
||||
describe('image utils', () => {
|
||||
describe('buildGeminiGenerateImageParams', () => {
|
||||
it('should return correct response modalities', () => {
|
||||
const result = buildGeminiGenerateImageParams()
|
||||
|
||||
expect(result).toEqual({
|
||||
responseModalities: ['TEXT', 'IMAGE']
|
||||
})
|
||||
})
|
||||
|
||||
it('should return an object with responseModalities property', () => {
|
||||
const result = buildGeminiGenerateImageParams()
|
||||
|
||||
expect(result).toHaveProperty('responseModalities')
|
||||
expect(Array.isArray(result.responseModalities)).toBe(true)
|
||||
expect(result.responseModalities).toHaveLength(2)
|
||||
})
|
||||
})
|
||||
|
||||
describe('isOpenRouterGeminiGenerateImageModel', () => {
|
||||
const mockOpenRouterProvider: Provider = {
|
||||
id: SystemProviderIds.openrouter,
|
||||
name: 'OpenRouter',
|
||||
apiKey: 'test-key',
|
||||
apiHost: 'https://openrouter.ai/api/v1',
|
||||
isSystem: true
|
||||
} as Provider
|
||||
|
||||
const mockOtherProvider: Provider = {
|
||||
id: SystemProviderIds.openai,
|
||||
name: 'OpenAI',
|
||||
apiKey: 'test-key',
|
||||
apiHost: 'https://api.openai.com/v1',
|
||||
isSystem: true
|
||||
} as Provider
|
||||
|
||||
it('should return true for OpenRouter Gemini 2.5 Flash Image model', () => {
|
||||
const model: Model = {
|
||||
id: 'google/gemini-2.5-flash-image-preview',
|
||||
name: 'Gemini 2.5 Flash Image',
|
||||
provider: SystemProviderIds.openrouter
|
||||
} as Model
|
||||
|
||||
const result = isOpenRouterGeminiGenerateImageModel(model, mockOpenRouterProvider)
|
||||
expect(result).toBe(true)
|
||||
})
|
||||
|
||||
it('should return false for non-Gemini model on OpenRouter', () => {
|
||||
const model: Model = {
|
||||
id: 'openai/gpt-4',
|
||||
name: 'GPT-4',
|
||||
provider: SystemProviderIds.openrouter
|
||||
} as Model
|
||||
|
||||
const result = isOpenRouterGeminiGenerateImageModel(model, mockOpenRouterProvider)
|
||||
expect(result).toBe(false)
|
||||
})
|
||||
|
||||
it('should return false for Gemini model on non-OpenRouter provider', () => {
|
||||
const model: Model = {
|
||||
id: 'gemini-2.5-flash-image-preview',
|
||||
name: 'Gemini 2.5 Flash Image',
|
||||
provider: SystemProviderIds.gemini
|
||||
} as Model
|
||||
|
||||
const result = isOpenRouterGeminiGenerateImageModel(model, mockOtherProvider)
|
||||
expect(result).toBe(false)
|
||||
})
|
||||
|
||||
it('should return false for Gemini model without image suffix', () => {
|
||||
const model: Model = {
|
||||
id: 'google/gemini-2.5-flash',
|
||||
name: 'Gemini 2.5 Flash',
|
||||
provider: SystemProviderIds.openrouter
|
||||
} as Model
|
||||
|
||||
const result = isOpenRouterGeminiGenerateImageModel(model, mockOpenRouterProvider)
|
||||
expect(result).toBe(false)
|
||||
})
|
||||
|
||||
it('should handle model ID with partial match', () => {
|
||||
const model: Model = {
|
||||
id: 'google/gemini-2.5-flash-image-generation',
|
||||
name: 'Gemini Image Gen',
|
||||
provider: SystemProviderIds.openrouter
|
||||
} as Model
|
||||
|
||||
const result = isOpenRouterGeminiGenerateImageModel(model, mockOpenRouterProvider)
|
||||
expect(result).toBe(true)
|
||||
})
|
||||
|
||||
it('should return false for custom provider', () => {
|
||||
const customProvider: Provider = {
|
||||
id: 'custom-provider-123',
|
||||
name: 'Custom Provider',
|
||||
apiKey: 'test-key',
|
||||
apiHost: 'https://custom.com'
|
||||
} as Provider
|
||||
|
||||
const model: Model = {
|
||||
id: 'gemini-2.5-flash-image-preview',
|
||||
name: 'Gemini 2.5 Flash Image',
|
||||
provider: 'custom-provider-123'
|
||||
} as Model
|
||||
|
||||
const result = isOpenRouterGeminiGenerateImageModel(model, customProvider)
|
||||
expect(result).toBe(false)
|
||||
})
|
||||
})
|
||||
})
|
||||
@@ -1,435 +0,0 @@
|
||||
/**
|
||||
* mcp.ts Unit Tests
|
||||
* Tests for MCP tools configuration and conversion utilities
|
||||
*/
|
||||
|
||||
import type { MCPTool } from '@renderer/types'
|
||||
import type { Tool } from 'ai'
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import { convertMcpToolsToAiSdkTools, setupToolsConfig } from '../mcp'
|
||||
|
||||
// Mock dependencies
|
||||
vi.mock('@logger', () => ({
|
||||
loggerService: {
|
||||
withContext: () => ({
|
||||
debug: vi.fn(),
|
||||
error: vi.fn(),
|
||||
warn: vi.fn(),
|
||||
info: vi.fn()
|
||||
})
|
||||
}
|
||||
}))
|
||||
|
||||
vi.mock('@renderer/utils/mcp-tools', () => ({
|
||||
getMcpServerByTool: vi.fn(() => ({ id: 'test-server', autoApprove: false })),
|
||||
isToolAutoApproved: vi.fn(() => false),
|
||||
callMCPTool: vi.fn(async () => ({
|
||||
content: [{ type: 'text', text: 'Tool executed successfully' }],
|
||||
isError: false
|
||||
}))
|
||||
}))
|
||||
|
||||
vi.mock('@renderer/utils/userConfirmation', () => ({
|
||||
requestToolConfirmation: vi.fn(async () => true)
|
||||
}))
|
||||
|
||||
describe('mcp utils', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
})
|
||||
|
||||
describe('setupToolsConfig', () => {
|
||||
it('should return undefined when no MCP tools provided', () => {
|
||||
const result = setupToolsConfig()
|
||||
expect(result).toBeUndefined()
|
||||
})
|
||||
|
||||
it('should return undefined when empty MCP tools array provided', () => {
|
||||
const result = setupToolsConfig([])
|
||||
expect(result).toBeUndefined()
|
||||
})
|
||||
|
||||
it('should convert MCP tools to AI SDK tools format', () => {
|
||||
const mcpTools: MCPTool[] = [
|
||||
{
|
||||
id: 'test-tool-1',
|
||||
serverId: 'test-server',
|
||||
serverName: 'test-server',
|
||||
name: 'test-tool',
|
||||
description: 'A test tool',
|
||||
type: 'mcp',
|
||||
inputSchema: {
|
||||
type: 'object',
|
||||
properties: {
|
||||
query: { type: 'string' }
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
|
||||
const result = setupToolsConfig(mcpTools)
|
||||
|
||||
expect(result).not.toBeUndefined()
|
||||
expect(Object.keys(result!)).toEqual(['test-tool'])
|
||||
expect(result!['test-tool']).toHaveProperty('description')
|
||||
expect(result!['test-tool']).toHaveProperty('inputSchema')
|
||||
expect(result!['test-tool']).toHaveProperty('execute')
|
||||
})
|
||||
|
||||
it('should handle multiple MCP tools', () => {
|
||||
const mcpTools: MCPTool[] = [
|
||||
{
|
||||
id: 'tool1-id',
|
||||
serverId: 'server1',
|
||||
serverName: 'server1',
|
||||
name: 'tool1',
|
||||
description: 'First tool',
|
||||
type: 'mcp',
|
||||
inputSchema: {
|
||||
type: 'object',
|
||||
properties: {}
|
||||
}
|
||||
},
|
||||
{
|
||||
id: 'tool2-id',
|
||||
serverId: 'server2',
|
||||
serverName: 'server2',
|
||||
name: 'tool2',
|
||||
description: 'Second tool',
|
||||
type: 'mcp',
|
||||
inputSchema: {
|
||||
type: 'object',
|
||||
properties: {}
|
||||
}
|
||||
}
|
||||
]
|
||||
|
||||
const result = setupToolsConfig(mcpTools)
|
||||
|
||||
expect(result).not.toBeUndefined()
|
||||
expect(Object.keys(result!)).toHaveLength(2)
|
||||
expect(Object.keys(result!)).toEqual(['tool1', 'tool2'])
|
||||
})
|
||||
})
|
||||
|
||||
describe('convertMcpToolsToAiSdkTools', () => {
|
||||
it('should convert single MCP tool to AI SDK tool', () => {
|
||||
const mcpTools: MCPTool[] = [
|
||||
{
|
||||
id: 'get-weather-id',
|
||||
serverId: 'weather-server',
|
||||
serverName: 'weather-server',
|
||||
name: 'get-weather',
|
||||
description: 'Get weather information',
|
||||
type: 'mcp',
|
||||
inputSchema: {
|
||||
type: 'object',
|
||||
properties: {
|
||||
location: { type: 'string' }
|
||||
},
|
||||
required: ['location']
|
||||
}
|
||||
}
|
||||
]
|
||||
|
||||
const result = convertMcpToolsToAiSdkTools(mcpTools)
|
||||
|
||||
expect(Object.keys(result)).toEqual(['get-weather'])
|
||||
|
||||
const tool = result['get-weather'] as Tool
|
||||
expect(tool.description).toBe('Get weather information')
|
||||
expect(tool.inputSchema).toBeDefined()
|
||||
expect(typeof tool.execute).toBe('function')
|
||||
})
|
||||
|
||||
it('should handle tool without description', () => {
|
||||
const mcpTools: MCPTool[] = [
|
||||
{
|
||||
id: 'no-desc-tool-id',
|
||||
serverId: 'test-server',
|
||||
serverName: 'test-server',
|
||||
name: 'no-desc-tool',
|
||||
type: 'mcp',
|
||||
inputSchema: {
|
||||
type: 'object',
|
||||
properties: {}
|
||||
}
|
||||
}
|
||||
]
|
||||
|
||||
const result = convertMcpToolsToAiSdkTools(mcpTools)
|
||||
|
||||
expect(Object.keys(result)).toEqual(['no-desc-tool'])
|
||||
const tool = result['no-desc-tool'] as Tool
|
||||
expect(tool.description).toBe('Tool from test-server')
|
||||
})
|
||||
|
||||
it('should convert empty tools array', () => {
|
||||
const result = convertMcpToolsToAiSdkTools([])
|
||||
expect(result).toEqual({})
|
||||
})
|
||||
|
||||
it('should handle complex input schemas', () => {
|
||||
const mcpTools: MCPTool[] = [
|
||||
{
|
||||
id: 'complex-tool-id',
|
||||
serverId: 'server',
|
||||
serverName: 'server',
|
||||
name: 'complex-tool',
|
||||
description: 'Tool with complex schema',
|
||||
type: 'mcp',
|
||||
inputSchema: {
|
||||
type: 'object',
|
||||
properties: {
|
||||
name: { type: 'string' },
|
||||
age: { type: 'number' },
|
||||
tags: {
|
||||
type: 'array',
|
||||
items: { type: 'string' }
|
||||
},
|
||||
metadata: {
|
||||
type: 'object',
|
||||
properties: {
|
||||
key: { type: 'string' }
|
||||
}
|
||||
}
|
||||
},
|
||||
required: ['name']
|
||||
}
|
||||
}
|
||||
]
|
||||
|
||||
const result = convertMcpToolsToAiSdkTools(mcpTools)
|
||||
|
||||
expect(Object.keys(result)).toEqual(['complex-tool'])
|
||||
const tool = result['complex-tool'] as Tool
|
||||
expect(tool.inputSchema).toBeDefined()
|
||||
expect(typeof tool.execute).toBe('function')
|
||||
})
|
||||
|
||||
it('should preserve tool names with special characters', () => {
|
||||
const mcpTools: MCPTool[] = [
|
||||
{
|
||||
id: 'special-tool-id',
|
||||
serverId: 'server',
|
||||
serverName: 'server',
|
||||
name: 'tool_with-special.chars',
|
||||
description: 'Special chars tool',
|
||||
type: 'mcp',
|
||||
inputSchema: {
|
||||
type: 'object',
|
||||
properties: {}
|
||||
}
|
||||
}
|
||||
]
|
||||
|
||||
const result = convertMcpToolsToAiSdkTools(mcpTools)
|
||||
expect(Object.keys(result)).toEqual(['tool_with-special.chars'])
|
||||
})
|
||||
|
||||
it('should handle multiple tools with different schemas', () => {
|
||||
const mcpTools: MCPTool[] = [
|
||||
{
|
||||
id: 'string-tool-id',
|
||||
serverId: 'server1',
|
||||
serverName: 'server1',
|
||||
name: 'string-tool',
|
||||
description: 'String tool',
|
||||
type: 'mcp',
|
||||
inputSchema: {
|
||||
type: 'object',
|
||||
properties: {
|
||||
input: { type: 'string' }
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
id: 'number-tool-id',
|
||||
serverId: 'server2',
|
||||
serverName: 'server2',
|
||||
name: 'number-tool',
|
||||
description: 'Number tool',
|
||||
type: 'mcp',
|
||||
inputSchema: {
|
||||
type: 'object',
|
||||
properties: {
|
||||
count: { type: 'number' }
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
id: 'boolean-tool-id',
|
||||
serverId: 'server3',
|
||||
serverName: 'server3',
|
||||
name: 'boolean-tool',
|
||||
description: 'Boolean tool',
|
||||
type: 'mcp',
|
||||
inputSchema: {
|
||||
type: 'object',
|
||||
properties: {
|
||||
enabled: { type: 'boolean' }
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
|
||||
const result = convertMcpToolsToAiSdkTools(mcpTools)
|
||||
|
||||
expect(Object.keys(result).sort()).toEqual(['boolean-tool', 'number-tool', 'string-tool'])
|
||||
expect(result['string-tool']).toBeDefined()
|
||||
expect(result['number-tool']).toBeDefined()
|
||||
expect(result['boolean-tool']).toBeDefined()
|
||||
})
|
||||
})
|
||||
|
||||
describe('tool execution', () => {
|
||||
it('should execute tool with user confirmation', async () => {
|
||||
const { callMCPTool } = await import('@renderer/utils/mcp-tools')
|
||||
const { requestToolConfirmation } = await import('@renderer/utils/userConfirmation')
|
||||
|
||||
vi.mocked(requestToolConfirmation).mockResolvedValue(true)
|
||||
vi.mocked(callMCPTool).mockResolvedValue({
|
||||
content: [{ type: 'text', text: 'Success' }],
|
||||
isError: false
|
||||
})
|
||||
|
||||
const mcpTools: MCPTool[] = [
|
||||
{
|
||||
id: 'test-exec-tool-id',
|
||||
serverId: 'test-server',
|
||||
serverName: 'test-server',
|
||||
name: 'test-exec-tool',
|
||||
description: 'Test execution tool',
|
||||
type: 'mcp',
|
||||
inputSchema: {
|
||||
type: 'object',
|
||||
properties: {}
|
||||
}
|
||||
}
|
||||
]
|
||||
|
||||
const tools = convertMcpToolsToAiSdkTools(mcpTools)
|
||||
const tool = tools['test-exec-tool'] as Tool
|
||||
const result = await tool.execute!({}, { messages: [], abortSignal: undefined, toolCallId: 'test-call-123' })
|
||||
|
||||
expect(requestToolConfirmation).toHaveBeenCalled()
|
||||
expect(callMCPTool).toHaveBeenCalled()
|
||||
expect(result).toEqual({
|
||||
content: [{ type: 'text', text: 'Success' }],
|
||||
isError: false
|
||||
})
|
||||
})
|
||||
|
||||
it('should handle user cancellation', async () => {
|
||||
const { requestToolConfirmation } = await import('@renderer/utils/userConfirmation')
|
||||
const { callMCPTool } = await import('@renderer/utils/mcp-tools')
|
||||
|
||||
vi.mocked(requestToolConfirmation).mockResolvedValue(false)
|
||||
|
||||
const mcpTools: MCPTool[] = [
|
||||
{
|
||||
id: 'cancelled-tool-id',
|
||||
serverId: 'test-server',
|
||||
serverName: 'test-server',
|
||||
name: 'cancelled-tool',
|
||||
description: 'Tool to cancel',
|
||||
type: 'mcp',
|
||||
inputSchema: {
|
||||
type: 'object',
|
||||
properties: {}
|
||||
}
|
||||
}
|
||||
]
|
||||
|
||||
const tools = convertMcpToolsToAiSdkTools(mcpTools)
|
||||
const tool = tools['cancelled-tool'] as Tool
|
||||
const result = await tool.execute!({}, { messages: [], abortSignal: undefined, toolCallId: 'cancel-call-123' })
|
||||
|
||||
expect(requestToolConfirmation).toHaveBeenCalled()
|
||||
expect(callMCPTool).not.toHaveBeenCalled()
|
||||
expect(result).toEqual({
|
||||
content: [
|
||||
{
|
||||
type: 'text',
|
||||
text: 'User declined to execute tool "cancelled-tool".'
|
||||
}
|
||||
],
|
||||
isError: false
|
||||
})
|
||||
})
|
||||
|
||||
it('should handle tool execution error', async () => {
|
||||
const { callMCPTool } = await import('@renderer/utils/mcp-tools')
|
||||
const { requestToolConfirmation } = await import('@renderer/utils/userConfirmation')
|
||||
|
||||
vi.mocked(requestToolConfirmation).mockResolvedValue(true)
|
||||
vi.mocked(callMCPTool).mockResolvedValue({
|
||||
content: [{ type: 'text', text: 'Error occurred' }],
|
||||
isError: true
|
||||
})
|
||||
|
||||
const mcpTools: MCPTool[] = [
|
||||
{
|
||||
id: 'error-tool-id',
|
||||
serverId: 'test-server',
|
||||
serverName: 'test-server',
|
||||
name: 'error-tool',
|
||||
description: 'Tool that errors',
|
||||
type: 'mcp',
|
||||
inputSchema: {
|
||||
type: 'object',
|
||||
properties: {}
|
||||
}
|
||||
}
|
||||
]
|
||||
|
||||
const tools = convertMcpToolsToAiSdkTools(mcpTools)
|
||||
const tool = tools['error-tool'] as Tool
|
||||
|
||||
await expect(
|
||||
tool.execute!({}, { messages: [], abortSignal: undefined, toolCallId: 'error-call-123' })
|
||||
).rejects.toEqual({
|
||||
content: [{ type: 'text', text: 'Error occurred' }],
|
||||
isError: true
|
||||
})
|
||||
})
|
||||
|
||||
it('should auto-approve when enabled', async () => {
|
||||
const { callMCPTool, isToolAutoApproved } = await import('@renderer/utils/mcp-tools')
|
||||
const { requestToolConfirmation } = await import('@renderer/utils/userConfirmation')
|
||||
|
||||
vi.mocked(isToolAutoApproved).mockReturnValue(true)
|
||||
vi.mocked(callMCPTool).mockResolvedValue({
|
||||
content: [{ type: 'text', text: 'Auto-approved success' }],
|
||||
isError: false
|
||||
})
|
||||
|
||||
const mcpTools: MCPTool[] = [
|
||||
{
|
||||
id: 'auto-approve-tool-id',
|
||||
serverId: 'test-server',
|
||||
serverName: 'test-server',
|
||||
name: 'auto-approve-tool',
|
||||
description: 'Auto-approved tool',
|
||||
type: 'mcp',
|
||||
inputSchema: {
|
||||
type: 'object',
|
||||
properties: {}
|
||||
}
|
||||
}
|
||||
]
|
||||
|
||||
const tools = convertMcpToolsToAiSdkTools(mcpTools)
|
||||
const tool = tools['auto-approve-tool'] as Tool
|
||||
const result = await tool.execute!({}, { messages: [], abortSignal: undefined, toolCallId: 'auto-call-123' })
|
||||
|
||||
expect(requestToolConfirmation).not.toHaveBeenCalled()
|
||||
expect(callMCPTool).toHaveBeenCalled()
|
||||
expect(result).toEqual({
|
||||
content: [{ type: 'text', text: 'Auto-approved success' }],
|
||||
isError: false
|
||||
})
|
||||
})
|
||||
})
|
||||
})
|
||||
@@ -1,545 +0,0 @@
|
||||
/**
|
||||
* options.ts Unit Tests
|
||||
* Tests for building provider-specific options
|
||||
*/
|
||||
|
||||
import type { Assistant, Model, Provider } from '@renderer/types'
|
||||
import { OpenAIServiceTiers, SystemProviderIds } from '@renderer/types'
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import { buildProviderOptions } from '../options'
|
||||
|
||||
// Mock dependencies
|
||||
vi.mock('@cherrystudio/ai-core/provider', async (importOriginal) => {
|
||||
const actual = (await importOriginal()) as object
|
||||
return {
|
||||
...actual,
|
||||
baseProviderIdSchema: {
|
||||
safeParse: vi.fn((id) => {
|
||||
const baseProviders = [
|
||||
'openai',
|
||||
'openai-chat',
|
||||
'azure',
|
||||
'azure-responses',
|
||||
'huggingface',
|
||||
'anthropic',
|
||||
'google',
|
||||
'xai',
|
||||
'deepseek',
|
||||
'openrouter',
|
||||
'openai-compatible'
|
||||
]
|
||||
if (baseProviders.includes(id)) {
|
||||
return { success: true, data: id }
|
||||
}
|
||||
return { success: false }
|
||||
})
|
||||
},
|
||||
customProviderIdSchema: {
|
||||
safeParse: vi.fn((id) => {
|
||||
const customProviders = ['google-vertex', 'google-vertex-anthropic', 'bedrock']
|
||||
if (customProviders.includes(id)) {
|
||||
return { success: true, data: id }
|
||||
}
|
||||
return { success: false, error: new Error('Invalid provider') }
|
||||
})
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
vi.mock('../provider/factory', () => ({
|
||||
getAiSdkProviderId: vi.fn((provider) => {
|
||||
// Simulate the provider ID mapping
|
||||
const mapping: Record<string, string> = {
|
||||
[SystemProviderIds.gemini]: 'google',
|
||||
[SystemProviderIds.openai]: 'openai',
|
||||
[SystemProviderIds.anthropic]: 'anthropic',
|
||||
[SystemProviderIds.grok]: 'xai',
|
||||
[SystemProviderIds.deepseek]: 'deepseek',
|
||||
[SystemProviderIds.openrouter]: 'openrouter'
|
||||
}
|
||||
return mapping[provider.id] || provider.id
|
||||
})
|
||||
}))
|
||||
|
||||
vi.mock('@renderer/config/models', async (importOriginal) => ({
|
||||
...(await importOriginal()),
|
||||
isOpenAIModel: vi.fn((model) => model.id.includes('gpt') || model.id.includes('o1')),
|
||||
isQwenMTModel: vi.fn(() => false),
|
||||
isSupportFlexServiceTierModel: vi.fn(() => true),
|
||||
isOpenAILLMModel: vi.fn(() => true),
|
||||
SYSTEM_MODELS: {
|
||||
defaultModel: [
|
||||
{ id: 'default-1', name: 'Default 1' },
|
||||
{ id: 'default-2', name: 'Default 2' },
|
||||
{ id: 'default-3', name: 'Default 3' }
|
||||
]
|
||||
}
|
||||
}))
|
||||
|
||||
vi.mock(import('@renderer/utils/provider'), async (importOriginal) => {
|
||||
return {
|
||||
...(await importOriginal()),
|
||||
isSupportServiceTierProvider: vi.fn((provider) => {
|
||||
return [SystemProviderIds.openai, SystemProviderIds.groq].includes(provider.id)
|
||||
})
|
||||
}
|
||||
})
|
||||
|
||||
vi.mock('@renderer/store/settings', () => ({
|
||||
default: (state = { settings: {} }) => state
|
||||
}))
|
||||
|
||||
vi.mock('@renderer/hooks/useSettings', () => ({
|
||||
getStoreSetting: vi.fn((key) => {
|
||||
if (key === 'openAI') {
|
||||
return { summaryText: 'off', verbosity: 'medium' } as any
|
||||
}
|
||||
return {}
|
||||
})
|
||||
}))
|
||||
|
||||
vi.mock('@renderer/services/AssistantService', () => ({
|
||||
getDefaultAssistant: vi.fn(() => ({
|
||||
id: 'default',
|
||||
name: 'Default Assistant',
|
||||
settings: {}
|
||||
})),
|
||||
getAssistantSettings: vi.fn(() => ({
|
||||
reasoning_effort: 'medium',
|
||||
maxTokens: 4096
|
||||
})),
|
||||
getProviderByModel: vi.fn((model: Model) => ({
|
||||
id: model.provider,
|
||||
name: 'Mock Provider'
|
||||
}))
|
||||
}))
|
||||
|
||||
vi.mock('../reasoning', () => ({
|
||||
getOpenAIReasoningParams: vi.fn(() => ({ reasoningEffort: 'medium' })),
|
||||
getAnthropicReasoningParams: vi.fn(() => ({
|
||||
thinking: { type: 'enabled', budgetTokens: 5000 }
|
||||
})),
|
||||
getGeminiReasoningParams: vi.fn(() => ({
|
||||
thinkingConfig: { include_thoughts: true }
|
||||
})),
|
||||
getXAIReasoningParams: vi.fn(() => ({ reasoningEffort: 'high' })),
|
||||
getBedrockReasoningParams: vi.fn(() => ({
|
||||
reasoningConfig: { type: 'enabled', budgetTokens: 5000 }
|
||||
})),
|
||||
getReasoningEffort: vi.fn(() => ({ reasoningEffort: 'medium' })),
|
||||
getCustomParameters: vi.fn(() => ({}))
|
||||
}))
|
||||
|
||||
vi.mock('../image', () => ({
|
||||
buildGeminiGenerateImageParams: vi.fn(() => ({
|
||||
responseModalities: ['TEXT', 'IMAGE']
|
||||
}))
|
||||
}))
|
||||
|
||||
vi.mock('../websearch', () => ({
|
||||
getWebSearchParams: vi.fn(() => ({ enable_search: true }))
|
||||
}))
|
||||
|
||||
const ensureWindowApi = () => {
|
||||
const globalWindow = window as any
|
||||
globalWindow.api = globalWindow.api || {}
|
||||
globalWindow.api.getAppInfo = globalWindow.api.getAppInfo || vi.fn(async () => ({ notesPath: '' }))
|
||||
}
|
||||
|
||||
ensureWindowApi()
|
||||
|
||||
describe('options utils', () => {
|
||||
const mockAssistant: Assistant = {
|
||||
id: 'test-assistant',
|
||||
name: 'Test Assistant',
|
||||
settings: {}
|
||||
} as Assistant
|
||||
|
||||
const mockModel: Model = {
|
||||
id: 'gpt-4',
|
||||
name: 'GPT-4',
|
||||
provider: SystemProviderIds.openai
|
||||
} as Model
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
})
|
||||
|
||||
describe('buildProviderOptions', () => {
|
||||
describe('OpenAI provider', () => {
|
||||
const openaiProvider: Provider = {
|
||||
id: SystemProviderIds.openai,
|
||||
name: 'OpenAI',
|
||||
type: 'openai-response',
|
||||
apiKey: 'test-key',
|
||||
apiHost: 'https://api.openai.com/v1',
|
||||
isSystem: true
|
||||
} as Provider
|
||||
|
||||
it('should build basic OpenAI options', () => {
|
||||
const result = buildProviderOptions(mockAssistant, mockModel, openaiProvider, {
|
||||
enableReasoning: false,
|
||||
enableWebSearch: false,
|
||||
enableGenerateImage: false
|
||||
})
|
||||
|
||||
expect(result).toHaveProperty('openai')
|
||||
expect(result.openai).toBeDefined()
|
||||
})
|
||||
|
||||
it('should include reasoning parameters when enabled', () => {
|
||||
const result = buildProviderOptions(mockAssistant, mockModel, openaiProvider, {
|
||||
enableReasoning: true,
|
||||
enableWebSearch: false,
|
||||
enableGenerateImage: false
|
||||
})
|
||||
|
||||
expect(result.openai).toHaveProperty('reasoningEffort')
|
||||
expect(result.openai.reasoningEffort).toBe('medium')
|
||||
})
|
||||
|
||||
it('should include service tier when supported', () => {
|
||||
const providerWithServiceTier: Provider = {
|
||||
...openaiProvider,
|
||||
serviceTier: OpenAIServiceTiers.auto
|
||||
}
|
||||
|
||||
const result = buildProviderOptions(mockAssistant, mockModel, providerWithServiceTier, {
|
||||
enableReasoning: false,
|
||||
enableWebSearch: false,
|
||||
enableGenerateImage: false
|
||||
})
|
||||
|
||||
expect(result.openai).toHaveProperty('serviceTier')
|
||||
expect(result.openai.serviceTier).toBe(OpenAIServiceTiers.auto)
|
||||
})
|
||||
})
|
||||
|
||||
describe('Anthropic provider', () => {
|
||||
const anthropicProvider: Provider = {
|
||||
id: SystemProviderIds.anthropic,
|
||||
name: 'Anthropic',
|
||||
type: 'anthropic',
|
||||
apiKey: 'test-key',
|
||||
apiHost: 'https://api.anthropic.com',
|
||||
isSystem: true
|
||||
} as Provider
|
||||
|
||||
const anthropicModel: Model = {
|
||||
id: 'claude-3-5-sonnet-20241022',
|
||||
name: 'Claude 3.5 Sonnet',
|
||||
provider: SystemProviderIds.anthropic
|
||||
} as Model
|
||||
|
||||
it('should build basic Anthropic options', () => {
|
||||
const result = buildProviderOptions(mockAssistant, anthropicModel, anthropicProvider, {
|
||||
enableReasoning: false,
|
||||
enableWebSearch: false,
|
||||
enableGenerateImage: false
|
||||
})
|
||||
|
||||
expect(result).toHaveProperty('anthropic')
|
||||
expect(result.anthropic).toBeDefined()
|
||||
})
|
||||
|
||||
it('should include reasoning parameters when enabled', () => {
|
||||
const result = buildProviderOptions(mockAssistant, anthropicModel, anthropicProvider, {
|
||||
enableReasoning: true,
|
||||
enableWebSearch: false,
|
||||
enableGenerateImage: false
|
||||
})
|
||||
|
||||
expect(result.anthropic).toHaveProperty('thinking')
|
||||
expect(result.anthropic.thinking).toEqual({
|
||||
type: 'enabled',
|
||||
budgetTokens: 5000
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe('Google provider', () => {
|
||||
const googleProvider: Provider = {
|
||||
id: SystemProviderIds.gemini,
|
||||
name: 'Google',
|
||||
type: 'gemini',
|
||||
apiKey: 'test-key',
|
||||
apiHost: 'https://generativelanguage.googleapis.com',
|
||||
isSystem: true,
|
||||
models: [{ id: 'gemini-2.0-flash-exp' }] as Model[]
|
||||
} as Provider
|
||||
|
||||
const googleModel: Model = {
|
||||
id: 'gemini-2.0-flash-exp',
|
||||
name: 'Gemini 2.0 Flash',
|
||||
provider: SystemProviderIds.gemini
|
||||
} as Model
|
||||
|
||||
it('should build basic Google options', () => {
|
||||
const result = buildProviderOptions(mockAssistant, googleModel, googleProvider, {
|
||||
enableReasoning: false,
|
||||
enableWebSearch: false,
|
||||
enableGenerateImage: false
|
||||
})
|
||||
|
||||
expect(result).toHaveProperty('google')
|
||||
expect(result.google).toBeDefined()
|
||||
})
|
||||
|
||||
it('should include reasoning parameters when enabled', () => {
|
||||
const result = buildProviderOptions(mockAssistant, googleModel, googleProvider, {
|
||||
enableReasoning: true,
|
||||
enableWebSearch: false,
|
||||
enableGenerateImage: false
|
||||
})
|
||||
|
||||
expect(result.google).toHaveProperty('thinkingConfig')
|
||||
expect(result.google.thinkingConfig).toEqual({
|
||||
include_thoughts: true
|
||||
})
|
||||
})
|
||||
|
||||
it('should include image generation parameters when enabled', () => {
|
||||
const result = buildProviderOptions(mockAssistant, googleModel, googleProvider, {
|
||||
enableReasoning: false,
|
||||
enableWebSearch: false,
|
||||
enableGenerateImage: true
|
||||
})
|
||||
|
||||
expect(result.google).toHaveProperty('responseModalities')
|
||||
expect(result.google.responseModalities).toEqual(['TEXT', 'IMAGE'])
|
||||
})
|
||||
})
|
||||
|
||||
describe('xAI provider', () => {
|
||||
const xaiProvider = {
|
||||
id: SystemProviderIds.grok,
|
||||
name: 'xAI',
|
||||
type: 'new-api',
|
||||
apiKey: 'test-key',
|
||||
apiHost: 'https://api.x.ai/v1',
|
||||
isSystem: true,
|
||||
models: [] as Model[]
|
||||
} as Provider
|
||||
|
||||
const xaiModel: Model = {
|
||||
id: 'grok-2-latest',
|
||||
name: 'Grok 2',
|
||||
provider: SystemProviderIds.grok
|
||||
} as Model
|
||||
|
||||
it('should build basic xAI options', () => {
|
||||
const result = buildProviderOptions(mockAssistant, xaiModel, xaiProvider, {
|
||||
enableReasoning: false,
|
||||
enableWebSearch: false,
|
||||
enableGenerateImage: false
|
||||
})
|
||||
|
||||
expect(result).toHaveProperty('xai')
|
||||
expect(result.xai).toBeDefined()
|
||||
})
|
||||
|
||||
it('should include reasoning parameters when enabled', () => {
|
||||
const result = buildProviderOptions(mockAssistant, xaiModel, xaiProvider, {
|
||||
enableReasoning: true,
|
||||
enableWebSearch: false,
|
||||
enableGenerateImage: false
|
||||
})
|
||||
|
||||
expect(result.xai).toHaveProperty('reasoningEffort')
|
||||
expect(result.xai.reasoningEffort).toBe('high')
|
||||
})
|
||||
})
|
||||
|
||||
describe('DeepSeek provider', () => {
|
||||
const deepseekProvider: Provider = {
|
||||
id: SystemProviderIds.deepseek,
|
||||
name: 'DeepSeek',
|
||||
type: 'openai',
|
||||
apiKey: 'test-key',
|
||||
apiHost: 'https://api.deepseek.com',
|
||||
isSystem: true
|
||||
} as Provider
|
||||
|
||||
const deepseekModel: Model = {
|
||||
id: 'deepseek-chat',
|
||||
name: 'DeepSeek Chat',
|
||||
provider: SystemProviderIds.deepseek
|
||||
} as Model
|
||||
|
||||
it('should build basic DeepSeek options', () => {
|
||||
const result = buildProviderOptions(mockAssistant, deepseekModel, deepseekProvider, {
|
||||
enableReasoning: false,
|
||||
enableWebSearch: false,
|
||||
enableGenerateImage: false
|
||||
})
|
||||
|
||||
expect(result).toHaveProperty('deepseek')
|
||||
expect(result.deepseek).toBeDefined()
|
||||
})
|
||||
})
|
||||
|
||||
describe('OpenRouter provider', () => {
|
||||
const openrouterProvider: Provider = {
|
||||
id: SystemProviderIds.openrouter,
|
||||
name: 'OpenRouter',
|
||||
type: 'openai',
|
||||
apiKey: 'test-key',
|
||||
apiHost: 'https://openrouter.ai/api/v1',
|
||||
isSystem: true
|
||||
} as Provider
|
||||
|
||||
const openrouterModel: Model = {
|
||||
id: 'openai/gpt-4',
|
||||
name: 'GPT-4',
|
||||
provider: SystemProviderIds.openrouter
|
||||
} as Model
|
||||
|
||||
it('should build basic OpenRouter options', () => {
|
||||
const result = buildProviderOptions(mockAssistant, openrouterModel, openrouterProvider, {
|
||||
enableReasoning: false,
|
||||
enableWebSearch: false,
|
||||
enableGenerateImage: false
|
||||
})
|
||||
|
||||
expect(result).toHaveProperty('openrouter')
|
||||
expect(result.openrouter).toBeDefined()
|
||||
})
|
||||
|
||||
it('should include web search parameters when enabled', () => {
|
||||
const result = buildProviderOptions(mockAssistant, openrouterModel, openrouterProvider, {
|
||||
enableReasoning: false,
|
||||
enableWebSearch: true,
|
||||
enableGenerateImage: false
|
||||
})
|
||||
|
||||
expect(result.openrouter).toHaveProperty('enable_search')
|
||||
})
|
||||
})
|
||||
|
||||
describe('Custom parameters', () => {
|
||||
it('should merge custom parameters', async () => {
|
||||
const { getCustomParameters } = await import('../reasoning')
|
||||
|
||||
vi.mocked(getCustomParameters).mockReturnValue({
|
||||
custom_param: 'custom_value',
|
||||
another_param: 123
|
||||
})
|
||||
|
||||
const result = buildProviderOptions(
|
||||
mockAssistant,
|
||||
mockModel,
|
||||
{
|
||||
id: SystemProviderIds.openai,
|
||||
name: 'OpenAI',
|
||||
type: 'openai',
|
||||
apiKey: 'test-key',
|
||||
apiHost: 'https://api.openai.com/v1'
|
||||
} as Provider,
|
||||
{
|
||||
enableReasoning: false,
|
||||
enableWebSearch: false,
|
||||
enableGenerateImage: false
|
||||
}
|
||||
)
|
||||
|
||||
expect(result.openai).toHaveProperty('custom_param')
|
||||
expect(result.openai.custom_param).toBe('custom_value')
|
||||
expect(result.openai).toHaveProperty('another_param')
|
||||
expect(result.openai.another_param).toBe(123)
|
||||
})
|
||||
})
|
||||
|
||||
describe('Multiple capabilities', () => {
|
||||
const googleProvider = {
|
||||
id: SystemProviderIds.gemini,
|
||||
name: 'Google',
|
||||
type: 'gemini',
|
||||
apiKey: 'test-key',
|
||||
apiHost: 'https://generativelanguage.googleapis.com',
|
||||
isSystem: true,
|
||||
models: [] as Model[]
|
||||
} as Provider
|
||||
|
||||
const googleModel: Model = {
|
||||
id: 'gemini-2.0-flash-exp',
|
||||
name: 'Gemini 2.0 Flash',
|
||||
provider: SystemProviderIds.gemini
|
||||
} as Model
|
||||
|
||||
it('should combine reasoning and image generation', () => {
|
||||
const result = buildProviderOptions(mockAssistant, googleModel, googleProvider, {
|
||||
enableReasoning: true,
|
||||
enableWebSearch: false,
|
||||
enableGenerateImage: true
|
||||
})
|
||||
|
||||
expect(result.google).toHaveProperty('thinkingConfig')
|
||||
expect(result.google).toHaveProperty('responseModalities')
|
||||
})
|
||||
|
||||
it('should handle all capabilities enabled', () => {
|
||||
const result = buildProviderOptions(mockAssistant, googleModel, googleProvider, {
|
||||
enableReasoning: true,
|
||||
enableWebSearch: true,
|
||||
enableGenerateImage: true
|
||||
})
|
||||
|
||||
expect(result.google).toBeDefined()
|
||||
expect(Object.keys(result.google).length).toBeGreaterThan(0)
|
||||
})
|
||||
})
|
||||
|
||||
describe('Vertex AI providers', () => {
|
||||
it('should map google-vertex to google', () => {
|
||||
const vertexProvider = {
|
||||
id: 'google-vertex',
|
||||
name: 'Vertex AI',
|
||||
type: 'vertexai',
|
||||
apiKey: 'test-key',
|
||||
apiHost: 'https://vertex-ai.googleapis.com',
|
||||
models: [] as Model[]
|
||||
} as Provider
|
||||
|
||||
const vertexModel: Model = {
|
||||
id: 'gemini-2.0-flash-exp',
|
||||
name: 'Gemini 2.0 Flash',
|
||||
provider: 'google-vertex'
|
||||
} as Model
|
||||
|
||||
const result = buildProviderOptions(mockAssistant, vertexModel, vertexProvider, {
|
||||
enableReasoning: false,
|
||||
enableWebSearch: false,
|
||||
enableGenerateImage: false
|
||||
})
|
||||
|
||||
expect(result).toHaveProperty('google')
|
||||
})
|
||||
|
||||
it('should map google-vertex-anthropic to anthropic', () => {
|
||||
const vertexAnthropicProvider = {
|
||||
id: 'google-vertex-anthropic',
|
||||
name: 'Vertex AI Anthropic',
|
||||
type: 'vertex-anthropic',
|
||||
apiKey: 'test-key',
|
||||
apiHost: 'https://vertex-ai.googleapis.com',
|
||||
models: [] as Model[]
|
||||
} as Provider
|
||||
|
||||
const vertexModel: Model = {
|
||||
id: 'claude-3-5-sonnet-20241022',
|
||||
name: 'Claude 3.5 Sonnet',
|
||||
provider: 'google-vertex-anthropic'
|
||||
} as Model
|
||||
|
||||
const result = buildProviderOptions(mockAssistant, vertexModel, vertexAnthropicProvider, {
|
||||
enableReasoning: false,
|
||||
enableWebSearch: false,
|
||||
enableGenerateImage: false
|
||||
})
|
||||
|
||||
expect(result).toHaveProperty('anthropic')
|
||||
})
|
||||
})
|
||||
})
|
||||
})
|
||||
@@ -1,967 +0,0 @@
|
||||
/**
|
||||
* reasoning.ts Unit Tests
|
||||
* Tests for reasoning parameter generation utilities
|
||||
*/
|
||||
|
||||
import { getStoreSetting } from '@renderer/hooks/useSettings'
|
||||
import type { SettingsState } from '@renderer/store/settings'
|
||||
import type { Assistant, Model, Provider } from '@renderer/types'
|
||||
import { SystemProviderIds } from '@renderer/types'
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import {
|
||||
getAnthropicReasoningParams,
|
||||
getBedrockReasoningParams,
|
||||
getCustomParameters,
|
||||
getGeminiReasoningParams,
|
||||
getOpenAIReasoningParams,
|
||||
getReasoningEffort,
|
||||
getXAIReasoningParams
|
||||
} from '../reasoning'
|
||||
|
||||
function defaultGetStoreSetting<K extends keyof SettingsState>(key: K): SettingsState[K] {
|
||||
if (key === 'openAI') {
|
||||
return {
|
||||
summaryText: 'auto',
|
||||
verbosity: 'medium'
|
||||
} as SettingsState[K]
|
||||
}
|
||||
return undefined as SettingsState[K]
|
||||
}
|
||||
|
||||
// Mock dependencies
|
||||
vi.mock('@logger', () => ({
|
||||
loggerService: {
|
||||
withContext: () => ({
|
||||
debug: vi.fn(),
|
||||
error: vi.fn(),
|
||||
warn: vi.fn(),
|
||||
info: vi.fn()
|
||||
})
|
||||
}
|
||||
}))
|
||||
|
||||
vi.mock('@renderer/store/settings', () => ({
|
||||
default: (state = { settings: {} }) => state
|
||||
}))
|
||||
|
||||
vi.mock('@renderer/store/llm', () => ({
|
||||
initialState: {},
|
||||
default: (state = { llm: {} }) => state
|
||||
}))
|
||||
|
||||
vi.mock('@renderer/config/constant', () => ({
|
||||
DEFAULT_MAX_TOKENS: 4096,
|
||||
isMac: false,
|
||||
isWin: false,
|
||||
TOKENFLUX_HOST: 'mock-host'
|
||||
}))
|
||||
|
||||
vi.mock('@renderer/utils/provider', () => ({
|
||||
isSupportEnableThinkingProvider: vi.fn((provider) => {
|
||||
return [SystemProviderIds.dashscope, SystemProviderIds.silicon].includes(provider.id)
|
||||
})
|
||||
}))
|
||||
|
||||
vi.mock('@renderer/config/models', async (importOriginal) => {
|
||||
const actual: any = await importOriginal()
|
||||
return {
|
||||
...actual,
|
||||
isReasoningModel: vi.fn(() => false),
|
||||
isOpenAIDeepResearchModel: vi.fn(() => false),
|
||||
isOpenAIModel: vi.fn(() => false),
|
||||
isSupportedReasoningEffortOpenAIModel: vi.fn(() => false),
|
||||
isSupportedThinkingTokenQwenModel: vi.fn(() => false),
|
||||
isQwenReasoningModel: vi.fn(() => false),
|
||||
isSupportedThinkingTokenClaudeModel: vi.fn(() => false),
|
||||
isSupportedThinkingTokenGeminiModel: vi.fn(() => false),
|
||||
isSupportedThinkingTokenDoubaoModel: vi.fn(() => false),
|
||||
isSupportedThinkingTokenZhipuModel: vi.fn(() => false),
|
||||
isSupportedReasoningEffortModel: vi.fn(() => false),
|
||||
isDeepSeekHybridInferenceModel: vi.fn(() => false),
|
||||
isSupportedReasoningEffortGrokModel: vi.fn(() => false),
|
||||
getThinkModelType: vi.fn(() => 'default'),
|
||||
isDoubaoSeedAfter251015: vi.fn(() => false),
|
||||
isDoubaoThinkingAutoModel: vi.fn(() => false),
|
||||
isGrok4FastReasoningModel: vi.fn(() => false),
|
||||
isGrokReasoningModel: vi.fn(() => false),
|
||||
isOpenAIReasoningModel: vi.fn(() => false),
|
||||
isQwenAlwaysThinkModel: vi.fn(() => false),
|
||||
isSupportedThinkingTokenHunyuanModel: vi.fn(() => false),
|
||||
isSupportedThinkingTokenModel: vi.fn(() => false),
|
||||
isGPT51SeriesModel: vi.fn(() => false)
|
||||
}
|
||||
})
|
||||
|
||||
vi.mock('@renderer/hooks/useSettings', () => ({
|
||||
getStoreSetting: vi.fn(defaultGetStoreSetting)
|
||||
}))
|
||||
|
||||
vi.mock('@renderer/services/AssistantService', () => ({
|
||||
getAssistantSettings: vi.fn((assistant) => ({
|
||||
maxTokens: assistant?.settings?.maxTokens || 4096,
|
||||
reasoning_effort: assistant?.settings?.reasoning_effort
|
||||
})),
|
||||
getProviderByModel: vi.fn((model) => ({
|
||||
id: model.provider,
|
||||
name: 'Test Provider'
|
||||
})),
|
||||
getDefaultAssistant: vi.fn(() => ({
|
||||
id: 'default',
|
||||
name: 'Default Assistant',
|
||||
settings: {}
|
||||
}))
|
||||
}))
|
||||
|
||||
const ensureWindowApi = () => {
|
||||
const globalWindow = window as any
|
||||
globalWindow.api = globalWindow.api || {}
|
||||
globalWindow.api.getAppInfo = globalWindow.api.getAppInfo || vi.fn(async () => ({ notesPath: '' }))
|
||||
}
|
||||
|
||||
ensureWindowApi()
|
||||
|
||||
describe('reasoning utils', () => {
|
||||
beforeEach(() => {
|
||||
vi.resetAllMocks()
|
||||
})
|
||||
|
||||
describe('getReasoningEffort', () => {
|
||||
it('should return empty object for non-reasoning model', async () => {
|
||||
const model: Model = {
|
||||
id: 'gpt-4',
|
||||
name: 'GPT-4',
|
||||
provider: SystemProviderIds.openai
|
||||
} as Model
|
||||
|
||||
const assistant: Assistant = {
|
||||
id: 'test',
|
||||
name: 'Test',
|
||||
settings: {}
|
||||
} as Assistant
|
||||
|
||||
const result = getReasoningEffort(assistant, model)
|
||||
expect(result).toEqual({})
|
||||
})
|
||||
|
||||
it('should disable reasoning for OpenRouter when no reasoning effort set', async () => {
|
||||
const { isReasoningModel } = await import('@renderer/config/models')
|
||||
|
||||
vi.mocked(isReasoningModel).mockReturnValue(true)
|
||||
|
||||
const model: Model = {
|
||||
id: 'anthropic/claude-sonnet-4',
|
||||
name: 'Claude Sonnet 4',
|
||||
provider: SystemProviderIds.openrouter
|
||||
} as Model
|
||||
|
||||
const assistant: Assistant = {
|
||||
id: 'test',
|
||||
name: 'Test',
|
||||
settings: {}
|
||||
} as Assistant
|
||||
|
||||
const result = getReasoningEffort(assistant, model)
|
||||
expect(result).toEqual({ reasoning: { enabled: false, exclude: true } })
|
||||
})
|
||||
|
||||
it('should handle Qwen models with enable_thinking', async () => {
|
||||
const { isReasoningModel, isSupportedThinkingTokenQwenModel, isQwenReasoningModel } = await import(
|
||||
'@renderer/config/models'
|
||||
)
|
||||
|
||||
vi.mocked(isReasoningModel).mockReturnValue(true)
|
||||
vi.mocked(isSupportedThinkingTokenQwenModel).mockReturnValue(true)
|
||||
vi.mocked(isQwenReasoningModel).mockReturnValue(true)
|
||||
|
||||
const model: Model = {
|
||||
id: 'qwen-plus',
|
||||
name: 'Qwen Plus',
|
||||
provider: SystemProviderIds.dashscope
|
||||
} as Model
|
||||
|
||||
const assistant: Assistant = {
|
||||
id: 'test',
|
||||
name: 'Test',
|
||||
settings: {
|
||||
reasoning_effort: 'medium'
|
||||
}
|
||||
} as Assistant
|
||||
|
||||
const result = getReasoningEffort(assistant, model)
|
||||
expect(result).toHaveProperty('enable_thinking')
|
||||
})
|
||||
|
||||
it('should handle Claude models with thinking config', async () => {
|
||||
const {
|
||||
isSupportedThinkingTokenClaudeModel,
|
||||
isReasoningModel,
|
||||
isQwenReasoningModel,
|
||||
isSupportedThinkingTokenGeminiModel,
|
||||
isSupportedThinkingTokenDoubaoModel,
|
||||
isSupportedThinkingTokenZhipuModel,
|
||||
isSupportedReasoningEffortModel
|
||||
} = await import('@renderer/config/models')
|
||||
|
||||
vi.mocked(isReasoningModel).mockReturnValue(true)
|
||||
vi.mocked(isSupportedThinkingTokenClaudeModel).mockReturnValue(true)
|
||||
vi.mocked(isQwenReasoningModel).mockReturnValue(false)
|
||||
vi.mocked(isSupportedThinkingTokenGeminiModel).mockReturnValue(false)
|
||||
vi.mocked(isSupportedThinkingTokenDoubaoModel).mockReturnValue(false)
|
||||
vi.mocked(isSupportedThinkingTokenZhipuModel).mockReturnValue(false)
|
||||
vi.mocked(isSupportedReasoningEffortModel).mockReturnValue(false)
|
||||
|
||||
const model: Model = {
|
||||
id: 'claude-3-7-sonnet',
|
||||
name: 'Claude 3.7 Sonnet',
|
||||
provider: SystemProviderIds.anthropic
|
||||
} as Model
|
||||
|
||||
const assistant: Assistant = {
|
||||
id: 'test',
|
||||
name: 'Test',
|
||||
settings: {
|
||||
reasoning_effort: 'high',
|
||||
maxTokens: 4096
|
||||
}
|
||||
} as Assistant
|
||||
|
||||
const result = getReasoningEffort(assistant, model)
|
||||
expect(result).toEqual({
|
||||
thinking: {
|
||||
type: 'enabled',
|
||||
budget_tokens: expect.any(Number)
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
it('should handle Gemini Flash models with thinking budget 0', async () => {
|
||||
const {
|
||||
isSupportedThinkingTokenGeminiModel,
|
||||
isReasoningModel,
|
||||
isQwenReasoningModel,
|
||||
isSupportedThinkingTokenClaudeModel,
|
||||
isSupportedThinkingTokenDoubaoModel,
|
||||
isSupportedThinkingTokenZhipuModel,
|
||||
isOpenAIDeepResearchModel,
|
||||
isSupportedThinkingTokenQwenModel,
|
||||
isSupportedThinkingTokenHunyuanModel,
|
||||
isDeepSeekHybridInferenceModel
|
||||
} = await import('@renderer/config/models')
|
||||
|
||||
vi.mocked(isReasoningModel).mockReturnValue(true)
|
||||
vi.mocked(isOpenAIDeepResearchModel).mockReturnValue(false)
|
||||
vi.mocked(isSupportedThinkingTokenGeminiModel).mockReturnValue(true)
|
||||
vi.mocked(isQwenReasoningModel).mockReturnValue(false)
|
||||
vi.mocked(isSupportedThinkingTokenClaudeModel).mockReturnValue(false)
|
||||
vi.mocked(isSupportedThinkingTokenDoubaoModel).mockReturnValue(false)
|
||||
vi.mocked(isSupportedThinkingTokenZhipuModel).mockReturnValue(false)
|
||||
vi.mocked(isSupportedThinkingTokenQwenModel).mockReturnValue(false)
|
||||
vi.mocked(isSupportedThinkingTokenHunyuanModel).mockReturnValue(false)
|
||||
vi.mocked(isDeepSeekHybridInferenceModel).mockReturnValue(false)
|
||||
|
||||
const model: Model = {
|
||||
id: 'gemini-2.5-flash',
|
||||
name: 'Gemini 2.5 Flash',
|
||||
provider: SystemProviderIds.openai
|
||||
} as Model
|
||||
|
||||
const assistant: Assistant = {
|
||||
id: 'test',
|
||||
name: 'Test',
|
||||
settings: {}
|
||||
} as Assistant
|
||||
|
||||
const result = getReasoningEffort(assistant, model)
|
||||
expect(result).toEqual({
|
||||
extra_body: {
|
||||
google: {
|
||||
thinking_config: {
|
||||
thinking_budget: 0
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
it('should handle GPT-5.1 reasoning model with effort levels', async () => {
|
||||
const {
|
||||
isReasoningModel,
|
||||
isOpenAIDeepResearchModel,
|
||||
isSupportedReasoningEffortModel,
|
||||
isGPT51SeriesModel,
|
||||
getThinkModelType
|
||||
} = await import('@renderer/config/models')
|
||||
|
||||
vi.mocked(isReasoningModel).mockReturnValue(true)
|
||||
vi.mocked(isOpenAIDeepResearchModel).mockReturnValue(false)
|
||||
vi.mocked(isSupportedReasoningEffortModel).mockReturnValue(true)
|
||||
vi.mocked(getThinkModelType).mockReturnValue('gpt5_1')
|
||||
vi.mocked(isGPT51SeriesModel).mockReturnValue(true)
|
||||
|
||||
const model: Model = {
|
||||
id: 'gpt-5.1',
|
||||
name: 'GPT-5.1',
|
||||
provider: SystemProviderIds.openai
|
||||
} as Model
|
||||
|
||||
const assistant: Assistant = {
|
||||
id: 'test',
|
||||
name: 'Test',
|
||||
settings: {
|
||||
reasoning_effort: 'none'
|
||||
}
|
||||
} as Assistant
|
||||
|
||||
const result = getReasoningEffort(assistant, model)
|
||||
expect(result).toEqual({
|
||||
reasoningEffort: 'none'
|
||||
})
|
||||
})
|
||||
|
||||
it('should handle DeepSeek hybrid inference models', async () => {
|
||||
const { isReasoningModel, isDeepSeekHybridInferenceModel } = await import('@renderer/config/models')
|
||||
|
||||
vi.mocked(isReasoningModel).mockReturnValue(true)
|
||||
vi.mocked(isDeepSeekHybridInferenceModel).mockReturnValue(true)
|
||||
|
||||
const model: Model = {
|
||||
id: 'deepseek-v3.1',
|
||||
name: 'DeepSeek V3.1',
|
||||
provider: SystemProviderIds.silicon
|
||||
} as Model
|
||||
|
||||
const assistant: Assistant = {
|
||||
id: 'test',
|
||||
name: 'Test',
|
||||
settings: {
|
||||
reasoning_effort: 'high'
|
||||
}
|
||||
} as Assistant
|
||||
|
||||
const result = getReasoningEffort(assistant, model)
|
||||
expect(result).toEqual({
|
||||
enable_thinking: true
|
||||
})
|
||||
})
|
||||
|
||||
it('should return medium effort for deep research models', async () => {
|
||||
const { isReasoningModel, isOpenAIDeepResearchModel } = await import('@renderer/config/models')
|
||||
|
||||
vi.mocked(isReasoningModel).mockReturnValue(true)
|
||||
vi.mocked(isOpenAIDeepResearchModel).mockReturnValue(true)
|
||||
|
||||
const model: Model = {
|
||||
id: 'o3-deep-research',
|
||||
provider: SystemProviderIds.openai
|
||||
} as Model
|
||||
|
||||
const assistant: Assistant = {
|
||||
id: 'test',
|
||||
name: 'Test',
|
||||
settings: {}
|
||||
} as Assistant
|
||||
|
||||
const result = getReasoningEffort(assistant, model)
|
||||
expect(result).toEqual({ reasoning_effort: 'medium' })
|
||||
})
|
||||
|
||||
it('should return empty for groq provider', async () => {
|
||||
const { getProviderByModel } = await import('@renderer/services/AssistantService')
|
||||
|
||||
vi.mocked(getProviderByModel).mockReturnValue({
|
||||
id: 'groq',
|
||||
name: 'Groq'
|
||||
} as Provider)
|
||||
|
||||
const model: Model = {
|
||||
id: 'groq-model',
|
||||
name: 'Groq Model',
|
||||
provider: 'groq'
|
||||
} as Model
|
||||
|
||||
const assistant: Assistant = {
|
||||
id: 'test',
|
||||
name: 'Test',
|
||||
settings: {}
|
||||
} as Assistant
|
||||
|
||||
const result = getReasoningEffort(assistant, model)
|
||||
expect(result).toEqual({})
|
||||
})
|
||||
})
|
||||
|
||||
describe('getOpenAIReasoningParams', () => {
|
||||
it('should return empty object for non-reasoning model', async () => {
|
||||
const model: Model = {
|
||||
id: 'gpt-4',
|
||||
name: 'GPT-4',
|
||||
provider: SystemProviderIds.openai
|
||||
} as Model
|
||||
|
||||
const assistant: Assistant = {
|
||||
id: 'test',
|
||||
name: 'Test',
|
||||
settings: {}
|
||||
} as Assistant
|
||||
|
||||
const result = getOpenAIReasoningParams(assistant, model)
|
||||
expect(result).toEqual({})
|
||||
})
|
||||
|
||||
it('should return empty when no reasoning effort set', async () => {
|
||||
const model: Model = {
|
||||
id: 'o1-preview',
|
||||
name: 'O1 Preview',
|
||||
provider: SystemProviderIds.openai
|
||||
} as Model
|
||||
|
||||
const assistant: Assistant = {
|
||||
id: 'test',
|
||||
name: 'Test',
|
||||
settings: {}
|
||||
} as Assistant
|
||||
|
||||
const result = getOpenAIReasoningParams(assistant, model)
|
||||
expect(result).toEqual({})
|
||||
})
|
||||
|
||||
it('should return reasoning effort for OpenAI models', async () => {
|
||||
const { isReasoningModel, isOpenAIModel, isSupportedReasoningEffortOpenAIModel } = await import(
|
||||
'@renderer/config/models'
|
||||
)
|
||||
|
||||
vi.mocked(isReasoningModel).mockReturnValue(true)
|
||||
vi.mocked(isOpenAIModel).mockReturnValue(true)
|
||||
vi.mocked(isSupportedReasoningEffortOpenAIModel).mockReturnValue(true)
|
||||
|
||||
const model: Model = {
|
||||
id: 'gpt-5.1',
|
||||
name: 'GPT 5.1',
|
||||
provider: SystemProviderIds.openai
|
||||
} as Model
|
||||
|
||||
const assistant: Assistant = {
|
||||
id: 'test',
|
||||
name: 'Test',
|
||||
settings: {
|
||||
reasoning_effort: 'high'
|
||||
}
|
||||
} as Assistant
|
||||
|
||||
const result = getOpenAIReasoningParams(assistant, model)
|
||||
expect(result).toEqual({
|
||||
reasoningEffort: 'high',
|
||||
reasoningSummary: 'auto'
|
||||
})
|
||||
})
|
||||
|
||||
it('should include reasoning summary when not o1-pro', async () => {
|
||||
const { isReasoningModel, isOpenAIModel, isSupportedReasoningEffortOpenAIModel } = await import(
|
||||
'@renderer/config/models'
|
||||
)
|
||||
|
||||
vi.mocked(isReasoningModel).mockReturnValue(true)
|
||||
vi.mocked(isOpenAIModel).mockReturnValue(true)
|
||||
vi.mocked(isSupportedReasoningEffortOpenAIModel).mockReturnValue(true)
|
||||
|
||||
const model: Model = {
|
||||
id: 'gpt-5',
|
||||
provider: SystemProviderIds.openai
|
||||
} as Model
|
||||
|
||||
const assistant: Assistant = {
|
||||
id: 'test',
|
||||
name: 'Test',
|
||||
settings: {
|
||||
reasoning_effort: 'medium'
|
||||
}
|
||||
} as Assistant
|
||||
|
||||
const result = getOpenAIReasoningParams(assistant, model)
|
||||
expect(result).toEqual({
|
||||
reasoningEffort: 'medium',
|
||||
reasoningSummary: 'auto'
|
||||
})
|
||||
})
|
||||
|
||||
it('should not include reasoning summary for o1-pro', async () => {
|
||||
const { isReasoningModel, isOpenAIDeepResearchModel, isSupportedReasoningEffortOpenAIModel } = await import(
|
||||
'@renderer/config/models'
|
||||
)
|
||||
|
||||
vi.mocked(isReasoningModel).mockReturnValue(true)
|
||||
vi.mocked(isOpenAIDeepResearchModel).mockReturnValue(false)
|
||||
vi.mocked(isSupportedReasoningEffortOpenAIModel).mockReturnValue(true)
|
||||
vi.mocked(getStoreSetting).mockReturnValue({ summaryText: 'off' } as any)
|
||||
|
||||
const model: Model = {
|
||||
id: 'o1-pro',
|
||||
name: 'O1 Pro',
|
||||
provider: SystemProviderIds.openai
|
||||
} as Model
|
||||
|
||||
const assistant: Assistant = {
|
||||
id: 'test',
|
||||
name: 'Test',
|
||||
settings: {
|
||||
reasoning_effort: 'high'
|
||||
}
|
||||
} as Assistant
|
||||
|
||||
const result = getOpenAIReasoningParams(assistant, model)
|
||||
expect(result).toEqual({
|
||||
reasoningEffort: 'high',
|
||||
reasoningSummary: undefined
|
||||
})
|
||||
})
|
||||
|
||||
it('should force medium effort for deep research models', async () => {
|
||||
const { isReasoningModel, isOpenAIModel, isOpenAIDeepResearchModel, isSupportedReasoningEffortOpenAIModel } =
|
||||
await import('@renderer/config/models')
|
||||
const { getStoreSetting } = await import('@renderer/hooks/useSettings')
|
||||
|
||||
vi.mocked(isReasoningModel).mockReturnValue(true)
|
||||
vi.mocked(isOpenAIModel).mockReturnValue(true)
|
||||
vi.mocked(isOpenAIDeepResearchModel).mockReturnValue(true)
|
||||
vi.mocked(isSupportedReasoningEffortOpenAIModel).mockReturnValue(true)
|
||||
vi.mocked(getStoreSetting).mockReturnValue({ summaryText: 'off' } as any)
|
||||
|
||||
const model: Model = {
|
||||
id: 'o3-deep-research',
|
||||
name: 'O3 Mini',
|
||||
provider: SystemProviderIds.openai
|
||||
} as Model
|
||||
|
||||
const assistant: Assistant = {
|
||||
id: 'test',
|
||||
name: 'Test',
|
||||
settings: {
|
||||
reasoning_effort: 'high'
|
||||
}
|
||||
} as Assistant
|
||||
|
||||
const result = getOpenAIReasoningParams(assistant, model)
|
||||
expect(result).toEqual({
|
||||
reasoningEffort: 'medium',
|
||||
reasoningSummary: 'off'
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe('getAnthropicReasoningParams', () => {
|
||||
it('should return empty for non-reasoning model', async () => {
|
||||
const { isReasoningModel } = await import('@renderer/config/models')
|
||||
|
||||
vi.mocked(isReasoningModel).mockReturnValue(false)
|
||||
|
||||
const model: Model = {
|
||||
id: 'claude-3-5-sonnet',
|
||||
name: 'Claude 3.5 Sonnet',
|
||||
provider: SystemProviderIds.anthropic
|
||||
} as Model
|
||||
|
||||
const assistant: Assistant = {
|
||||
id: 'test',
|
||||
name: 'Test',
|
||||
settings: {}
|
||||
} as Assistant
|
||||
|
||||
const result = getAnthropicReasoningParams(assistant, model)
|
||||
expect(result).toEqual({})
|
||||
})
|
||||
|
||||
it('should return disabled thinking when no reasoning effort', async () => {
|
||||
const { isReasoningModel, isSupportedThinkingTokenClaudeModel } = await import('@renderer/config/models')
|
||||
|
||||
vi.mocked(isReasoningModel).mockReturnValue(true)
|
||||
vi.mocked(isSupportedThinkingTokenClaudeModel).mockReturnValue(false)
|
||||
|
||||
const model: Model = {
|
||||
id: 'claude-3-7-sonnet',
|
||||
name: 'Claude 3.7 Sonnet',
|
||||
provider: SystemProviderIds.anthropic
|
||||
} as Model
|
||||
|
||||
const assistant: Assistant = {
|
||||
id: 'test',
|
||||
name: 'Test',
|
||||
settings: {}
|
||||
} as Assistant
|
||||
|
||||
const result = getAnthropicReasoningParams(assistant, model)
|
||||
expect(result).toEqual({
|
||||
thinking: {
|
||||
type: 'disabled'
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
it('should return enabled thinking with budget for Claude models', async () => {
|
||||
const { isReasoningModel, isSupportedThinkingTokenClaudeModel } = await import('@renderer/config/models')
|
||||
|
||||
vi.mocked(isReasoningModel).mockReturnValue(true)
|
||||
vi.mocked(isSupportedThinkingTokenClaudeModel).mockReturnValue(true)
|
||||
|
||||
const model: Model = {
|
||||
id: 'claude-3-7-sonnet',
|
||||
name: 'Claude 3.7 Sonnet',
|
||||
provider: SystemProviderIds.anthropic
|
||||
} as Model
|
||||
|
||||
const assistant: Assistant = {
|
||||
id: 'test',
|
||||
name: 'Test',
|
||||
settings: {
|
||||
reasoning_effort: 'medium',
|
||||
maxTokens: 4096
|
||||
}
|
||||
} as Assistant
|
||||
|
||||
const result = getAnthropicReasoningParams(assistant, model)
|
||||
expect(result).toEqual({
|
||||
thinking: {
|
||||
type: 'enabled',
|
||||
budgetTokens: 2048
|
||||
}
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe('getGeminiReasoningParams', () => {
|
||||
it('should return empty for non-reasoning model', async () => {
|
||||
const { isReasoningModel } = await import('@renderer/config/models')
|
||||
|
||||
vi.mocked(isReasoningModel).mockReturnValue(false)
|
||||
|
||||
const model: Model = {
|
||||
id: 'gemini-2.0-flash',
|
||||
name: 'Gemini 2.0 Flash',
|
||||
provider: SystemProviderIds.gemini
|
||||
} as Model
|
||||
|
||||
const assistant: Assistant = {
|
||||
id: 'test',
|
||||
name: 'Test',
|
||||
settings: {}
|
||||
} as Assistant
|
||||
|
||||
const result = getGeminiReasoningParams(assistant, model)
|
||||
expect(result).toEqual({})
|
||||
})
|
||||
|
||||
it('should disable thinking for Flash models without reasoning effort', async () => {
|
||||
const { isReasoningModel, isSupportedThinkingTokenGeminiModel } = await import('@renderer/config/models')
|
||||
|
||||
vi.mocked(isReasoningModel).mockReturnValue(true)
|
||||
vi.mocked(isSupportedThinkingTokenGeminiModel).mockReturnValue(true)
|
||||
|
||||
const model: Model = {
|
||||
id: 'gemini-2.5-flash',
|
||||
name: 'Gemini 2.5 Flash',
|
||||
provider: SystemProviderIds.gemini
|
||||
} as Model
|
||||
|
||||
const assistant: Assistant = {
|
||||
id: 'test',
|
||||
name: 'Test',
|
||||
settings: {}
|
||||
} as Assistant
|
||||
|
||||
const result = getGeminiReasoningParams(assistant, model)
|
||||
expect(result).toEqual({
|
||||
thinkingConfig: {
|
||||
includeThoughts: false,
|
||||
thinkingBudget: 0
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
it('should enable thinking with budget for reasoning effort', async () => {
|
||||
const { isReasoningModel, isSupportedThinkingTokenGeminiModel } = await import('@renderer/config/models')
|
||||
|
||||
vi.mocked(isReasoningModel).mockReturnValue(true)
|
||||
vi.mocked(isSupportedThinkingTokenGeminiModel).mockReturnValue(true)
|
||||
|
||||
const model: Model = {
|
||||
id: 'gemini-2.5-pro',
|
||||
name: 'Gemini 2.5 Pro',
|
||||
provider: SystemProviderIds.gemini
|
||||
} as Model
|
||||
|
||||
const assistant: Assistant = {
|
||||
id: 'test',
|
||||
name: 'Test',
|
||||
settings: {
|
||||
reasoning_effort: 'medium'
|
||||
}
|
||||
} as Assistant
|
||||
|
||||
const result = getGeminiReasoningParams(assistant, model)
|
||||
expect(result).toEqual({
|
||||
thinkingConfig: {
|
||||
thinkingBudget: 16448,
|
||||
includeThoughts: true
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
it('should enable thinking without budget for auto effort ratio > 1', async () => {
|
||||
const { isReasoningModel, isSupportedThinkingTokenGeminiModel } = await import('@renderer/config/models')
|
||||
|
||||
vi.mocked(isReasoningModel).mockReturnValue(true)
|
||||
vi.mocked(isSupportedThinkingTokenGeminiModel).mockReturnValue(true)
|
||||
|
||||
const model: Model = {
|
||||
id: 'gemini-2.5-pro',
|
||||
name: 'Gemini 2.5 Pro',
|
||||
provider: SystemProviderIds.gemini
|
||||
} as Model
|
||||
|
||||
const assistant: Assistant = {
|
||||
id: 'test',
|
||||
name: 'Test',
|
||||
settings: {
|
||||
reasoning_effort: 'auto'
|
||||
}
|
||||
} as Assistant
|
||||
|
||||
const result = getGeminiReasoningParams(assistant, model)
|
||||
expect(result).toEqual({
|
||||
thinkingConfig: {
|
||||
includeThoughts: true
|
||||
}
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe('getXAIReasoningParams', () => {
|
||||
it('should return empty for non-Grok model', async () => {
|
||||
const { isSupportedReasoningEffortGrokModel } = await import('@renderer/config/models')
|
||||
|
||||
vi.mocked(isSupportedReasoningEffortGrokModel).mockReturnValue(false)
|
||||
|
||||
const model: Model = {
|
||||
id: 'other-model',
|
||||
name: 'Other Model',
|
||||
provider: SystemProviderIds.grok
|
||||
} as Model
|
||||
|
||||
const assistant: Assistant = {
|
||||
id: 'test',
|
||||
name: 'Test',
|
||||
settings: {}
|
||||
} as Assistant
|
||||
|
||||
const result = getXAIReasoningParams(assistant, model)
|
||||
expect(result).toEqual({})
|
||||
})
|
||||
|
||||
it('should return empty when no reasoning effort', async () => {
|
||||
const { isSupportedReasoningEffortGrokModel } = await import('@renderer/config/models')
|
||||
|
||||
vi.mocked(isSupportedReasoningEffortGrokModel).mockReturnValue(true)
|
||||
|
||||
const model: Model = {
|
||||
id: 'grok-2',
|
||||
name: 'Grok 2',
|
||||
provider: SystemProviderIds.grok
|
||||
} as Model
|
||||
|
||||
const assistant: Assistant = {
|
||||
id: 'test',
|
||||
name: 'Test',
|
||||
settings: {}
|
||||
} as Assistant
|
||||
|
||||
const result = getXAIReasoningParams(assistant, model)
|
||||
expect(result).toEqual({})
|
||||
})
|
||||
|
||||
it('should return reasoning effort for Grok models', async () => {
|
||||
const { isSupportedReasoningEffortGrokModel } = await import('@renderer/config/models')
|
||||
|
||||
vi.mocked(isSupportedReasoningEffortGrokModel).mockReturnValue(true)
|
||||
|
||||
const model: Model = {
|
||||
id: 'grok-3',
|
||||
name: 'Grok 3',
|
||||
provider: SystemProviderIds.grok
|
||||
} as Model
|
||||
|
||||
const assistant: Assistant = {
|
||||
id: 'test',
|
||||
name: 'Test',
|
||||
settings: {
|
||||
reasoning_effort: 'high'
|
||||
}
|
||||
} as Assistant
|
||||
|
||||
const result = getXAIReasoningParams(assistant, model)
|
||||
expect(result).toHaveProperty('reasoningEffort')
|
||||
expect(result.reasoningEffort).toBe('high')
|
||||
})
|
||||
})
|
||||
|
||||
describe('getBedrockReasoningParams', () => {
|
||||
it('should return empty for non-reasoning model', async () => {
|
||||
const model: Model = {
|
||||
id: 'other-model',
|
||||
name: 'Other Model',
|
||||
provider: 'bedrock'
|
||||
} as Model
|
||||
|
||||
const assistant: Assistant = {
|
||||
id: 'test',
|
||||
name: 'Test',
|
||||
settings: {}
|
||||
} as Assistant
|
||||
|
||||
const result = getBedrockReasoningParams(assistant, model)
|
||||
expect(result).toEqual({})
|
||||
})
|
||||
|
||||
it('should return empty when no reasoning effort', async () => {
|
||||
const model: Model = {
|
||||
id: 'claude-3-7-sonnet',
|
||||
name: 'Claude 3.7 Sonnet',
|
||||
provider: 'bedrock'
|
||||
} as Model
|
||||
|
||||
const assistant: Assistant = {
|
||||
id: 'test',
|
||||
name: 'Test',
|
||||
settings: {}
|
||||
} as Assistant
|
||||
|
||||
const result = getBedrockReasoningParams(assistant, model)
|
||||
expect(result).toEqual({})
|
||||
})
|
||||
|
||||
it('should return reasoning config for Claude models on Bedrock', async () => {
|
||||
const { isReasoningModel, isSupportedThinkingTokenClaudeModel } = await import('@renderer/config/models')
|
||||
|
||||
vi.mocked(isReasoningModel).mockReturnValue(true)
|
||||
vi.mocked(isSupportedThinkingTokenClaudeModel).mockReturnValue(true)
|
||||
|
||||
const model: Model = {
|
||||
id: 'claude-3-7-sonnet',
|
||||
name: 'Claude 3.7 Sonnet',
|
||||
provider: 'bedrock'
|
||||
} as Model
|
||||
|
||||
const assistant: Assistant = {
|
||||
id: 'test',
|
||||
name: 'Test',
|
||||
settings: {
|
||||
reasoning_effort: 'medium',
|
||||
maxTokens: 4096
|
||||
}
|
||||
} as Assistant
|
||||
|
||||
const result = getBedrockReasoningParams(assistant, model)
|
||||
expect(result).toEqual({
|
||||
reasoningConfig: {
|
||||
type: 'enabled',
|
||||
budgetTokens: 2048
|
||||
}
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe('getCustomParameters', () => {
|
||||
it('should return empty object when no custom parameters', async () => {
|
||||
const assistant: Assistant = {
|
||||
id: 'test',
|
||||
name: 'Test',
|
||||
settings: {}
|
||||
} as Assistant
|
||||
|
||||
const result = getCustomParameters(assistant)
|
||||
expect(result).toEqual({})
|
||||
})
|
||||
|
||||
it('should return custom parameters as key-value pairs', async () => {
|
||||
const assistant: Assistant = {
|
||||
id: 'test',
|
||||
name: 'Test',
|
||||
settings: {
|
||||
customParameters: [
|
||||
{ name: 'param1', value: 'value1', type: 'string' },
|
||||
{ name: 'param2', value: 123, type: 'number' }
|
||||
]
|
||||
}
|
||||
} as Assistant
|
||||
|
||||
const result = getCustomParameters(assistant)
|
||||
expect(result).toEqual({
|
||||
param1: 'value1',
|
||||
param2: 123
|
||||
})
|
||||
})
|
||||
|
||||
it('should parse JSON type parameters', async () => {
|
||||
const assistant: Assistant = {
|
||||
id: 'test',
|
||||
name: 'Test',
|
||||
settings: {
|
||||
customParameters: [{ name: 'config', value: '{"key": "value"}', type: 'json' }]
|
||||
}
|
||||
} as Assistant
|
||||
|
||||
const result = getCustomParameters(assistant)
|
||||
expect(result).toEqual({
|
||||
config: { key: 'value' }
|
||||
})
|
||||
})
|
||||
|
||||
it('should handle invalid JSON gracefully', async () => {
|
||||
const assistant: Assistant = {
|
||||
id: 'test',
|
||||
name: 'Test',
|
||||
settings: {
|
||||
customParameters: [{ name: 'invalid', value: '{invalid json', type: 'json' }]
|
||||
}
|
||||
} as Assistant
|
||||
|
||||
const result = getCustomParameters(assistant)
|
||||
expect(result).toEqual({
|
||||
invalid: '{invalid json'
|
||||
})
|
||||
})
|
||||
|
||||
it('should handle undefined JSON value', async () => {
|
||||
const assistant: Assistant = {
|
||||
id: 'test',
|
||||
name: 'Test',
|
||||
settings: {
|
||||
customParameters: [{ name: 'undef', value: 'undefined', type: 'json' }]
|
||||
}
|
||||
} as Assistant
|
||||
|
||||
const result = getCustomParameters(assistant)
|
||||
expect(result).toEqual({
|
||||
undef: undefined
|
||||
})
|
||||
})
|
||||
|
||||
it('should skip parameters with empty names', async () => {
|
||||
const assistant: Assistant = {
|
||||
id: 'test',
|
||||
name: 'Test',
|
||||
settings: {
|
||||
customParameters: [
|
||||
{ name: '', value: 'value1', type: 'string' },
|
||||
{ name: ' ', value: 'value2', type: 'string' },
|
||||
{ name: 'valid', value: 'value3', type: 'string' }
|
||||
]
|
||||
}
|
||||
} as Assistant
|
||||
|
||||
const result = getCustomParameters(assistant)
|
||||
expect(result).toEqual({
|
||||
valid: 'value3'
|
||||
})
|
||||
})
|
||||
})
|
||||
})
|
||||
@@ -1,384 +0,0 @@
|
||||
/**
|
||||
* websearch.ts Unit Tests
|
||||
* Tests for web search parameters generation utilities
|
||||
*/
|
||||
|
||||
import type { CherryWebSearchConfig } from '@renderer/store/websearch'
|
||||
import type { Model } from '@renderer/types'
|
||||
import { describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import { buildProviderBuiltinWebSearchConfig, getWebSearchParams } from '../websearch'
|
||||
|
||||
// Mock dependencies
|
||||
vi.mock('@renderer/config/models', () => ({
|
||||
isOpenAIWebSearchChatCompletionOnlyModel: vi.fn((model) => model?.id?.includes('o1-pro') ?? false),
|
||||
isOpenAIDeepResearchModel: vi.fn((model) => model?.id?.includes('o3-mini') ?? false)
|
||||
}))
|
||||
|
||||
vi.mock('@renderer/utils/blacklistMatchPattern', () => ({
|
||||
mapRegexToPatterns: vi.fn((patterns) => patterns || [])
|
||||
}))
|
||||
|
||||
describe('websearch utils', () => {
|
||||
describe('getWebSearchParams', () => {
|
||||
it('should return enhancement params for hunyuan provider', () => {
|
||||
const model: Model = {
|
||||
id: 'hunyuan-model',
|
||||
name: 'Hunyuan Model',
|
||||
provider: 'hunyuan'
|
||||
} as Model
|
||||
|
||||
const result = getWebSearchParams(model)
|
||||
|
||||
expect(result).toEqual({
|
||||
enable_enhancement: true,
|
||||
citation: true,
|
||||
search_info: true
|
||||
})
|
||||
})
|
||||
|
||||
it('should return search params for dashscope provider', () => {
|
||||
const model: Model = {
|
||||
id: 'qwen-model',
|
||||
name: 'Qwen Model',
|
||||
provider: 'dashscope'
|
||||
} as Model
|
||||
|
||||
const result = getWebSearchParams(model)
|
||||
|
||||
expect(result).toEqual({
|
||||
enable_search: true,
|
||||
search_options: {
|
||||
forced_search: true
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
it('should return web_search_options for OpenAI web search models', () => {
|
||||
const model: Model = {
|
||||
id: 'o1-pro',
|
||||
name: 'O1 Pro',
|
||||
provider: 'openai'
|
||||
} as Model
|
||||
|
||||
const result = getWebSearchParams(model)
|
||||
|
||||
expect(result).toEqual({
|
||||
web_search_options: {}
|
||||
})
|
||||
})
|
||||
|
||||
it('should return empty object for other providers', () => {
|
||||
const model: Model = {
|
||||
id: 'gpt-4',
|
||||
name: 'GPT-4',
|
||||
provider: 'openai'
|
||||
} as Model
|
||||
|
||||
const result = getWebSearchParams(model)
|
||||
|
||||
expect(result).toEqual({})
|
||||
})
|
||||
|
||||
it('should return empty object for custom provider', () => {
|
||||
const model: Model = {
|
||||
id: 'custom-model',
|
||||
name: 'Custom Model',
|
||||
provider: 'custom-provider'
|
||||
} as Model
|
||||
|
||||
const result = getWebSearchParams(model)
|
||||
|
||||
expect(result).toEqual({})
|
||||
})
|
||||
})
|
||||
|
||||
describe('buildProviderBuiltinWebSearchConfig', () => {
|
||||
const defaultWebSearchConfig: CherryWebSearchConfig = {
|
||||
searchWithTime: true,
|
||||
maxResults: 50,
|
||||
excludeDomains: []
|
||||
}
|
||||
|
||||
describe('openai provider', () => {
|
||||
it('should return low search context size for low maxResults', () => {
|
||||
const config: CherryWebSearchConfig = {
|
||||
searchWithTime: true,
|
||||
maxResults: 20,
|
||||
excludeDomains: []
|
||||
}
|
||||
|
||||
const result = buildProviderBuiltinWebSearchConfig('openai', config)
|
||||
|
||||
expect(result).toEqual({
|
||||
openai: {
|
||||
searchContextSize: 'low'
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
it('should return medium search context size for medium maxResults', () => {
|
||||
const config: CherryWebSearchConfig = {
|
||||
searchWithTime: true,
|
||||
maxResults: 50,
|
||||
excludeDomains: []
|
||||
}
|
||||
|
||||
const result = buildProviderBuiltinWebSearchConfig('openai', config)
|
||||
|
||||
expect(result).toEqual({
|
||||
openai: {
|
||||
searchContextSize: 'medium'
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
it('should return high search context size for high maxResults', () => {
|
||||
const config: CherryWebSearchConfig = {
|
||||
searchWithTime: true,
|
||||
maxResults: 80,
|
||||
excludeDomains: []
|
||||
}
|
||||
|
||||
const result = buildProviderBuiltinWebSearchConfig('openai', config)
|
||||
|
||||
expect(result).toEqual({
|
||||
openai: {
|
||||
searchContextSize: 'high'
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
it('should use medium for deep research models regardless of maxResults', () => {
|
||||
const config: CherryWebSearchConfig = {
|
||||
searchWithTime: true,
|
||||
maxResults: 100,
|
||||
excludeDomains: []
|
||||
}
|
||||
|
||||
const model: Model = {
|
||||
id: 'o3-mini',
|
||||
name: 'O3 Mini',
|
||||
provider: 'openai'
|
||||
} as Model
|
||||
|
||||
const result = buildProviderBuiltinWebSearchConfig('openai', config, model)
|
||||
|
||||
expect(result).toEqual({
|
||||
openai: {
|
||||
searchContextSize: 'medium'
|
||||
}
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe('openai-chat provider', () => {
|
||||
it('should return correct search context size', () => {
|
||||
const config: CherryWebSearchConfig = {
|
||||
searchWithTime: true,
|
||||
maxResults: 50,
|
||||
excludeDomains: []
|
||||
}
|
||||
|
||||
const result = buildProviderBuiltinWebSearchConfig('openai-chat', config)
|
||||
|
||||
expect(result).toEqual({
|
||||
'openai-chat': {
|
||||
searchContextSize: 'medium'
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
it('should handle deep research models', () => {
|
||||
const config: CherryWebSearchConfig = {
|
||||
searchWithTime: true,
|
||||
maxResults: 100,
|
||||
excludeDomains: []
|
||||
}
|
||||
|
||||
const model: Model = {
|
||||
id: 'o3-mini',
|
||||
name: 'O3 Mini',
|
||||
provider: 'openai'
|
||||
} as Model
|
||||
|
||||
const result = buildProviderBuiltinWebSearchConfig('openai-chat', config, model)
|
||||
|
||||
expect(result).toEqual({
|
||||
'openai-chat': {
|
||||
searchContextSize: 'medium'
|
||||
}
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe('anthropic provider', () => {
|
||||
it('should return anthropic search options with maxUses', () => {
|
||||
const result = buildProviderBuiltinWebSearchConfig('anthropic', defaultWebSearchConfig)
|
||||
|
||||
expect(result).toEqual({
|
||||
anthropic: {
|
||||
maxUses: 50,
|
||||
blockedDomains: undefined
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
it('should include blockedDomains when excludeDomains provided', () => {
|
||||
const config: CherryWebSearchConfig = {
|
||||
searchWithTime: true,
|
||||
maxResults: 30,
|
||||
excludeDomains: ['example.com', 'test.com']
|
||||
}
|
||||
|
||||
const result = buildProviderBuiltinWebSearchConfig('anthropic', config)
|
||||
|
||||
expect(result).toEqual({
|
||||
anthropic: {
|
||||
maxUses: 30,
|
||||
blockedDomains: ['example.com', 'test.com']
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
it('should not include blockedDomains when empty', () => {
|
||||
const result = buildProviderBuiltinWebSearchConfig('anthropic', defaultWebSearchConfig)
|
||||
|
||||
expect(result).toEqual({
|
||||
anthropic: {
|
||||
maxUses: 50,
|
||||
blockedDomains: undefined
|
||||
}
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe('xai provider', () => {
|
||||
it('should return xai search options', () => {
|
||||
const result = buildProviderBuiltinWebSearchConfig('xai', defaultWebSearchConfig)
|
||||
|
||||
expect(result).toEqual({
|
||||
xai: {
|
||||
maxSearchResults: 50,
|
||||
returnCitations: true,
|
||||
sources: [{ type: 'web', excludedWebsites: [] }, { type: 'news' }, { type: 'x' }],
|
||||
mode: 'on'
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
it('should limit excluded websites to 5', () => {
|
||||
const config: CherryWebSearchConfig = {
|
||||
searchWithTime: true,
|
||||
maxResults: 40,
|
||||
excludeDomains: ['site1.com', 'site2.com', 'site3.com', 'site4.com', 'site5.com', 'site6.com', 'site7.com']
|
||||
}
|
||||
|
||||
const result = buildProviderBuiltinWebSearchConfig('xai', config)
|
||||
|
||||
expect(result?.xai?.sources).toBeDefined()
|
||||
const webSource = result?.xai?.sources?.[0]
|
||||
if (webSource && webSource.type === 'web') {
|
||||
expect(webSource.excludedWebsites).toHaveLength(5)
|
||||
}
|
||||
})
|
||||
|
||||
it('should include all sources types', () => {
|
||||
const result = buildProviderBuiltinWebSearchConfig('xai', defaultWebSearchConfig)
|
||||
|
||||
expect(result?.xai?.sources).toHaveLength(3)
|
||||
expect(result?.xai?.sources?.[0].type).toBe('web')
|
||||
expect(result?.xai?.sources?.[1].type).toBe('news')
|
||||
expect(result?.xai?.sources?.[2].type).toBe('x')
|
||||
})
|
||||
})
|
||||
|
||||
describe('openrouter provider', () => {
|
||||
it('should return openrouter plugins config', () => {
|
||||
const result = buildProviderBuiltinWebSearchConfig('openrouter', defaultWebSearchConfig)
|
||||
|
||||
expect(result).toEqual({
|
||||
openrouter: {
|
||||
plugins: [
|
||||
{
|
||||
id: 'web',
|
||||
max_results: 50
|
||||
}
|
||||
]
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
it('should respect custom maxResults', () => {
|
||||
const config: CherryWebSearchConfig = {
|
||||
searchWithTime: true,
|
||||
maxResults: 75,
|
||||
excludeDomains: []
|
||||
}
|
||||
|
||||
const result = buildProviderBuiltinWebSearchConfig('openrouter', config)
|
||||
|
||||
expect(result).toEqual({
|
||||
openrouter: {
|
||||
plugins: [
|
||||
{
|
||||
id: 'web',
|
||||
max_results: 75
|
||||
}
|
||||
]
|
||||
}
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe('unsupported provider', () => {
|
||||
it('should return empty object for unsupported provider', () => {
|
||||
const result = buildProviderBuiltinWebSearchConfig('unsupported' as any, defaultWebSearchConfig)
|
||||
|
||||
expect(result).toEqual({})
|
||||
})
|
||||
|
||||
it('should return empty object for google provider', () => {
|
||||
const result = buildProviderBuiltinWebSearchConfig('google', defaultWebSearchConfig)
|
||||
|
||||
expect(result).toEqual({})
|
||||
})
|
||||
})
|
||||
|
||||
describe('edge cases', () => {
|
||||
it('should handle maxResults at boundary values', () => {
|
||||
// Test boundary at 33 (low/medium)
|
||||
const config33: CherryWebSearchConfig = { searchWithTime: true, maxResults: 33, excludeDomains: [] }
|
||||
const result33 = buildProviderBuiltinWebSearchConfig('openai', config33)
|
||||
expect(result33?.openai?.searchContextSize).toBe('low')
|
||||
|
||||
// Test boundary at 34 (medium)
|
||||
const config34: CherryWebSearchConfig = { searchWithTime: true, maxResults: 34, excludeDomains: [] }
|
||||
const result34 = buildProviderBuiltinWebSearchConfig('openai', config34)
|
||||
expect(result34?.openai?.searchContextSize).toBe('medium')
|
||||
|
||||
// Test boundary at 66 (medium)
|
||||
const config66: CherryWebSearchConfig = { searchWithTime: true, maxResults: 66, excludeDomains: [] }
|
||||
const result66 = buildProviderBuiltinWebSearchConfig('openai', config66)
|
||||
expect(result66?.openai?.searchContextSize).toBe('medium')
|
||||
|
||||
// Test boundary at 67 (high)
|
||||
const config67: CherryWebSearchConfig = { searchWithTime: true, maxResults: 67, excludeDomains: [] }
|
||||
const result67 = buildProviderBuiltinWebSearchConfig('openai', config67)
|
||||
expect(result67?.openai?.searchContextSize).toBe('high')
|
||||
})
|
||||
|
||||
it('should handle zero maxResults', () => {
|
||||
const config: CherryWebSearchConfig = { searchWithTime: true, maxResults: 0, excludeDomains: [] }
|
||||
const result = buildProviderBuiltinWebSearchConfig('openai', config)
|
||||
expect(result?.openai?.searchContextSize).toBe('low')
|
||||
})
|
||||
|
||||
it('should handle very large maxResults', () => {
|
||||
const config: CherryWebSearchConfig = { searchWithTime: true, maxResults: 1000, excludeDomains: [] }
|
||||
const result = buildProviderBuiltinWebSearchConfig('openai', config)
|
||||
expect(result?.openai?.searchContextSize).toBe('high')
|
||||
})
|
||||
})
|
||||
})
|
||||
})
|
||||
@@ -1,8 +1,3 @@
|
||||
import type { BedrockProviderOptions } from '@ai-sdk/amazon-bedrock'
|
||||
import type { AnthropicProviderOptions } from '@ai-sdk/anthropic'
|
||||
import type { GoogleGenerativeAIProviderOptions } from '@ai-sdk/google'
|
||||
import type { OpenAIResponsesProviderOptions } from '@ai-sdk/openai'
|
||||
import type { XaiProviderOptions } from '@ai-sdk/xai'
|
||||
import { baseProviderIdSchema, customProviderIdSchema } from '@cherrystudio/ai-core/provider'
|
||||
import { loggerService } from '@logger'
|
||||
import {
|
||||
@@ -12,27 +7,17 @@ import {
|
||||
isSupportFlexServiceTierModel,
|
||||
isSupportVerbosityModel
|
||||
} from '@renderer/config/models'
|
||||
import { isSupportServiceTierProvider } from '@renderer/config/providers'
|
||||
import { mapLanguageToQwenMTModel } from '@renderer/config/translate'
|
||||
import { getStoreSetting } from '@renderer/hooks/useSettings'
|
||||
import type { Assistant, Model, Provider } from '@renderer/types'
|
||||
import {
|
||||
type Assistant,
|
||||
type GroqServiceTier,
|
||||
GroqServiceTiers,
|
||||
type GroqSystemProvider,
|
||||
isGroqServiceTier,
|
||||
isGroqSystemProvider,
|
||||
isOpenAIServiceTier,
|
||||
isTranslateAssistant,
|
||||
type Model,
|
||||
type NotGroqProvider,
|
||||
type OpenAIServiceTier,
|
||||
OpenAIServiceTiers,
|
||||
type Provider,
|
||||
type ServiceTier
|
||||
SystemProviderIds
|
||||
} from '@renderer/types'
|
||||
import type { OpenAIVerbosity } from '@renderer/types/aiCoreTypes'
|
||||
import { isSupportServiceTierProvider } from '@renderer/utils/provider'
|
||||
import type { JSONValue } from 'ai'
|
||||
import { t } from 'i18next'
|
||||
|
||||
import { getAiSdkProviderId } from '../provider/factory'
|
||||
@@ -50,31 +35,8 @@ import { getWebSearchParams } from './websearch'
|
||||
|
||||
const logger = loggerService.withContext('aiCore.utils.options')
|
||||
|
||||
function toOpenAIServiceTier(model: Model, serviceTier: ServiceTier): OpenAIServiceTier {
|
||||
if (
|
||||
!isOpenAIServiceTier(serviceTier) ||
|
||||
(serviceTier === OpenAIServiceTiers.flex && !isSupportFlexServiceTierModel(model))
|
||||
) {
|
||||
return undefined
|
||||
} else {
|
||||
return serviceTier
|
||||
}
|
||||
}
|
||||
|
||||
function toGroqServiceTier(model: Model, serviceTier: ServiceTier): GroqServiceTier {
|
||||
if (
|
||||
!isGroqServiceTier(serviceTier) ||
|
||||
(serviceTier === GroqServiceTiers.flex && !isSupportFlexServiceTierModel(model))
|
||||
) {
|
||||
return undefined
|
||||
} else {
|
||||
return serviceTier
|
||||
}
|
||||
}
|
||||
|
||||
function getServiceTier<T extends GroqSystemProvider>(model: Model, provider: T): GroqServiceTier
|
||||
function getServiceTier<T extends NotGroqProvider>(model: Model, provider: T): OpenAIServiceTier
|
||||
function getServiceTier<T extends Provider>(model: Model, provider: T): OpenAIServiceTier | GroqServiceTier {
|
||||
// copy from BaseApiClient.ts
|
||||
const getServiceTier = (model: Model, provider: Provider) => {
|
||||
const serviceTierSetting = provider.serviceTier
|
||||
|
||||
if (!isSupportServiceTierProvider(provider) || !isOpenAIModel(model) || !serviceTierSetting) {
|
||||
@@ -82,17 +44,24 @@ function getServiceTier<T extends Provider>(model: Model, provider: T): OpenAISe
|
||||
}
|
||||
|
||||
// 处理不同供应商需要 fallback 到默认值的情况
|
||||
if (isGroqSystemProvider(provider)) {
|
||||
return toGroqServiceTier(model, serviceTierSetting)
|
||||
if (provider.id === SystemProviderIds.groq) {
|
||||
if (
|
||||
!isGroqServiceTier(serviceTierSetting) ||
|
||||
(serviceTierSetting === GroqServiceTiers.flex && !isSupportFlexServiceTierModel(model))
|
||||
) {
|
||||
return undefined
|
||||
}
|
||||
} else {
|
||||
// 其他 OpenAI 供应商,假设他们的服务层级设置和 OpenAI 完全相同
|
||||
return toOpenAIServiceTier(model, serviceTierSetting)
|
||||
if (
|
||||
!isOpenAIServiceTier(serviceTierSetting) ||
|
||||
(serviceTierSetting === OpenAIServiceTiers.flex && !isSupportFlexServiceTierModel(model))
|
||||
) {
|
||||
return undefined
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
function getVerbosity(): OpenAIVerbosity {
|
||||
const openAI = getStoreSetting('openAI')
|
||||
return openAI.verbosity
|
||||
return serviceTierSetting
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -109,13 +78,13 @@ export function buildProviderOptions(
|
||||
enableWebSearch: boolean
|
||||
enableGenerateImage: boolean
|
||||
}
|
||||
): Record<string, Record<string, JSONValue>> {
|
||||
): Record<string, any> {
|
||||
logger.debug('buildProviderOptions', { assistant, model, actualProvider, capabilities })
|
||||
const rawProviderId = getAiSdkProviderId(actualProvider)
|
||||
// 构建 provider 特定的选项
|
||||
let providerSpecificOptions: Record<string, any> = {}
|
||||
const serviceTier = getServiceTier(model, actualProvider)
|
||||
const textVerbosity = getVerbosity()
|
||||
const serviceTierSetting = getServiceTier(model, actualProvider)
|
||||
providerSpecificOptions.serviceTier = serviceTierSetting
|
||||
// 根据 provider 类型分离构建逻辑
|
||||
const { data: baseProviderId, success } = baseProviderIdSchema.safeParse(rawProviderId)
|
||||
if (success) {
|
||||
@@ -125,14 +94,9 @@ export function buildProviderOptions(
|
||||
case 'openai-chat':
|
||||
case 'azure':
|
||||
case 'azure-responses':
|
||||
{
|
||||
const options: OpenAIResponsesProviderOptions = buildOpenAIProviderOptions(
|
||||
assistant,
|
||||
model,
|
||||
capabilities,
|
||||
serviceTier
|
||||
)
|
||||
providerSpecificOptions = options
|
||||
providerSpecificOptions = {
|
||||
...buildOpenAIProviderOptions(assistant, model, capabilities),
|
||||
serviceTier: serviceTierSetting
|
||||
}
|
||||
break
|
||||
case 'anthropic':
|
||||
@@ -152,19 +116,12 @@ export function buildProviderOptions(
|
||||
// 对于其他 provider,使用通用的构建逻辑
|
||||
providerSpecificOptions = {
|
||||
...buildGenericProviderOptions(assistant, model, capabilities),
|
||||
serviceTier,
|
||||
textVerbosity
|
||||
serviceTier: serviceTierSetting
|
||||
}
|
||||
break
|
||||
}
|
||||
case 'cherryin':
|
||||
providerSpecificOptions = buildCherryInProviderOptions(
|
||||
assistant,
|
||||
model,
|
||||
capabilities,
|
||||
actualProvider,
|
||||
serviceTier
|
||||
)
|
||||
providerSpecificOptions = buildCherryInProviderOptions(assistant, model, capabilities, actualProvider)
|
||||
break
|
||||
default:
|
||||
throw new Error(`Unsupported base provider ${baseProviderId}`)
|
||||
@@ -178,7 +135,6 @@ export function buildProviderOptions(
|
||||
case 'google-vertex':
|
||||
providerSpecificOptions = buildGeminiProviderOptions(assistant, model, capabilities)
|
||||
break
|
||||
case 'azure-anthropic':
|
||||
case 'google-vertex-anthropic':
|
||||
providerSpecificOptions = buildAnthropicProviderOptions(assistant, model, capabilities)
|
||||
break
|
||||
@@ -186,14 +142,13 @@ export function buildProviderOptions(
|
||||
providerSpecificOptions = buildBedrockProviderOptions(assistant, model, capabilities)
|
||||
break
|
||||
case 'huggingface':
|
||||
providerSpecificOptions = buildOpenAIProviderOptions(assistant, model, capabilities, serviceTier)
|
||||
providerSpecificOptions = buildOpenAIProviderOptions(assistant, model, capabilities)
|
||||
break
|
||||
default:
|
||||
// 对于其他 provider,使用通用的构建逻辑
|
||||
providerSpecificOptions = {
|
||||
...buildGenericProviderOptions(assistant, model, capabilities),
|
||||
serviceTier,
|
||||
textVerbosity
|
||||
serviceTier: serviceTierSetting
|
||||
}
|
||||
}
|
||||
} else {
|
||||
@@ -211,7 +166,6 @@ export function buildProviderOptions(
|
||||
{
|
||||
'google-vertex': 'google',
|
||||
'google-vertex-anthropic': 'anthropic',
|
||||
'azure-anthropic': 'anthropic',
|
||||
'ai-gateway': 'gateway'
|
||||
}[rawProviderId] || rawProviderId
|
||||
|
||||
@@ -235,11 +189,10 @@ function buildOpenAIProviderOptions(
|
||||
enableReasoning: boolean
|
||||
enableWebSearch: boolean
|
||||
enableGenerateImage: boolean
|
||||
},
|
||||
serviceTier: OpenAIServiceTier
|
||||
): OpenAIResponsesProviderOptions {
|
||||
}
|
||||
): Record<string, any> {
|
||||
const { enableReasoning } = capabilities
|
||||
let providerOptions: OpenAIResponsesProviderOptions = {}
|
||||
let providerOptions: Record<string, any> = {}
|
||||
// OpenAI 推理参数
|
||||
if (enableReasoning) {
|
||||
const reasoningParams = getOpenAIReasoningParams(assistant, model)
|
||||
@@ -250,8 +203,8 @@ function buildOpenAIProviderOptions(
|
||||
}
|
||||
|
||||
if (isSupportVerbosityModel(model)) {
|
||||
const openAI = getStoreSetting<'openAI'>('openAI')
|
||||
const userVerbosity = openAI?.verbosity
|
||||
const state = window.store?.getState()
|
||||
const userVerbosity = state?.settings?.openAI?.verbosity
|
||||
|
||||
if (userVerbosity && ['low', 'medium', 'high'].includes(userVerbosity)) {
|
||||
const supportedVerbosity = getModelSupportedVerbosity(model)
|
||||
@@ -265,11 +218,6 @@ function buildOpenAIProviderOptions(
|
||||
}
|
||||
}
|
||||
|
||||
providerOptions = {
|
||||
...providerOptions,
|
||||
serviceTier
|
||||
}
|
||||
|
||||
return providerOptions
|
||||
}
|
||||
|
||||
@@ -284,9 +232,9 @@ function buildAnthropicProviderOptions(
|
||||
enableWebSearch: boolean
|
||||
enableGenerateImage: boolean
|
||||
}
|
||||
): AnthropicProviderOptions {
|
||||
): Record<string, any> {
|
||||
const { enableReasoning } = capabilities
|
||||
let providerOptions: AnthropicProviderOptions = {}
|
||||
let providerOptions: Record<string, any> = {}
|
||||
|
||||
// Anthropic 推理参数
|
||||
if (enableReasoning) {
|
||||
@@ -311,9 +259,9 @@ function buildGeminiProviderOptions(
|
||||
enableWebSearch: boolean
|
||||
enableGenerateImage: boolean
|
||||
}
|
||||
): GoogleGenerativeAIProviderOptions {
|
||||
): Record<string, any> {
|
||||
const { enableReasoning, enableGenerateImage } = capabilities
|
||||
let providerOptions: GoogleGenerativeAIProviderOptions = {}
|
||||
let providerOptions: Record<string, any> = {}
|
||||
|
||||
// Gemini 推理参数
|
||||
if (enableReasoning) {
|
||||
@@ -342,7 +290,7 @@ function buildXAIProviderOptions(
|
||||
enableWebSearch: boolean
|
||||
enableGenerateImage: boolean
|
||||
}
|
||||
): XaiProviderOptions {
|
||||
): Record<string, any> {
|
||||
const { enableReasoning } = capabilities
|
||||
let providerOptions: Record<string, any> = {}
|
||||
|
||||
@@ -365,12 +313,16 @@ function buildCherryInProviderOptions(
|
||||
enableWebSearch: boolean
|
||||
enableGenerateImage: boolean
|
||||
},
|
||||
actualProvider: Provider,
|
||||
serviceTier: OpenAIServiceTier
|
||||
): OpenAIResponsesProviderOptions | AnthropicProviderOptions | GoogleGenerativeAIProviderOptions {
|
||||
actualProvider: Provider
|
||||
): Record<string, any> {
|
||||
const serviceTierSetting = getServiceTier(model, actualProvider)
|
||||
|
||||
switch (actualProvider.type) {
|
||||
case 'openai':
|
||||
return buildOpenAIProviderOptions(assistant, model, capabilities, serviceTier)
|
||||
return {
|
||||
...buildOpenAIProviderOptions(assistant, model, capabilities),
|
||||
serviceTier: serviceTierSetting
|
||||
}
|
||||
|
||||
case 'anthropic':
|
||||
return buildAnthropicProviderOptions(assistant, model, capabilities)
|
||||
@@ -392,9 +344,9 @@ function buildBedrockProviderOptions(
|
||||
enableWebSearch: boolean
|
||||
enableGenerateImage: boolean
|
||||
}
|
||||
): BedrockProviderOptions {
|
||||
): Record<string, any> {
|
||||
const { enableReasoning } = capabilities
|
||||
let providerOptions: BedrockProviderOptions = {}
|
||||
let providerOptions: Record<string, any> = {}
|
||||
|
||||
if (enableReasoning) {
|
||||
const reasoningParams = getBedrockReasoningParams(assistant, model)
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import type { BedrockProviderOptions } from '@ai-sdk/amazon-bedrock'
|
||||
import type { AnthropicProviderOptions } from '@ai-sdk/anthropic'
|
||||
import type { GoogleGenerativeAIProviderOptions } from '@ai-sdk/google'
|
||||
import type { OpenAIResponsesProviderOptions } from '@ai-sdk/openai'
|
||||
import type { XaiProviderOptions } from '@ai-sdk/xai'
|
||||
import { loggerService } from '@logger'
|
||||
import { DEFAULT_MAX_TOKENS } from '@renderer/config/constant'
|
||||
@@ -12,7 +11,6 @@ import {
|
||||
isDeepSeekHybridInferenceModel,
|
||||
isDoubaoSeedAfter251015,
|
||||
isDoubaoThinkingAutoModel,
|
||||
isGemini3Model,
|
||||
isGPT51SeriesModel,
|
||||
isGrok4FastReasoningModel,
|
||||
isGrokReasoningModel,
|
||||
@@ -34,13 +32,13 @@ import {
|
||||
isSupportedThinkingTokenZhipuModel,
|
||||
MODEL_SUPPORTED_REASONING_EFFORT
|
||||
} from '@renderer/config/models'
|
||||
import { isSupportEnableThinkingProvider } from '@renderer/config/providers'
|
||||
import { getStoreSetting } from '@renderer/hooks/useSettings'
|
||||
import { getAssistantSettings, getProviderByModel } from '@renderer/services/AssistantService'
|
||||
import type { Assistant, Model, ReasoningEffortOption } from '@renderer/types'
|
||||
import type { SettingsState } from '@renderer/store/settings'
|
||||
import type { Assistant, Model } from '@renderer/types'
|
||||
import { EFFORT_RATIO, isSystemProvider, SystemProviderIds } from '@renderer/types'
|
||||
import type { OpenAISummaryText } from '@renderer/types/aiCoreTypes'
|
||||
import type { ReasoningEffortOptionalParams } from '@renderer/types/sdk'
|
||||
import { isSupportEnableThinkingProvider } from '@renderer/utils/provider'
|
||||
import { toInteger } from 'lodash'
|
||||
|
||||
const logger = loggerService.withContext('reasoning')
|
||||
@@ -132,7 +130,7 @@ export function getReasoningEffort(assistant: Assistant, model: Model): Reasonin
|
||||
}
|
||||
|
||||
// Specially for GPT-5.1. Suppose this is a OpenAI Compatible provider
|
||||
if (isGPT51SeriesModel(model)) {
|
||||
if (isGPT51SeriesModel(model) && reasoningEffort === 'none') {
|
||||
return {
|
||||
reasoningEffort: 'none'
|
||||
}
|
||||
@@ -280,12 +278,6 @@ export function getReasoningEffort(assistant: Assistant, model: Model): Reasonin
|
||||
|
||||
// gemini series, openai compatible api
|
||||
if (isSupportedThinkingTokenGeminiModel(model)) {
|
||||
// https://ai.google.dev/gemini-api/docs/gemini-3?thinking=high#openai_compatibility
|
||||
if (isGemini3Model(model)) {
|
||||
return {
|
||||
reasoning_effort: reasoningEffort
|
||||
}
|
||||
}
|
||||
if (reasoningEffort === 'auto') {
|
||||
return {
|
||||
extra_body: {
|
||||
@@ -349,14 +341,10 @@ export function getReasoningEffort(assistant: Assistant, model: Model): Reasonin
|
||||
}
|
||||
|
||||
/**
|
||||
* Get OpenAI reasoning parameters
|
||||
* Extracted from OpenAIResponseAPIClient and OpenAIAPIClient logic
|
||||
* For official OpenAI provider only
|
||||
* 获取 OpenAI 推理参数
|
||||
* 从 OpenAIResponseAPIClient 和 OpenAIAPIClient 中提取的逻辑
|
||||
*/
|
||||
export function getOpenAIReasoningParams(
|
||||
assistant: Assistant,
|
||||
model: Model
|
||||
): Pick<OpenAIResponsesProviderOptions, 'reasoningEffort' | 'reasoningSummary'> {
|
||||
export function getOpenAIReasoningParams(assistant: Assistant, model: Model): Record<string, any> {
|
||||
if (!isReasoningModel(model)) {
|
||||
return {}
|
||||
}
|
||||
@@ -367,10 +355,6 @@ export function getOpenAIReasoningParams(
|
||||
return {}
|
||||
}
|
||||
|
||||
if (isOpenAIDeepResearchModel(model) || reasoningEffort === 'auto') {
|
||||
reasoningEffort = 'medium'
|
||||
}
|
||||
|
||||
// 非OpenAI模型,但是Provider类型是responses/azure openai的情况
|
||||
if (!isOpenAIModel(model)) {
|
||||
return {
|
||||
@@ -378,17 +362,21 @@ export function getOpenAIReasoningParams(
|
||||
}
|
||||
}
|
||||
|
||||
const openAI = getStoreSetting('openAI')
|
||||
const summaryText = openAI.summaryText
|
||||
const openAI = getStoreSetting('openAI') as SettingsState['openAI']
|
||||
const summaryText = openAI?.summaryText || 'off'
|
||||
|
||||
let reasoningSummary: OpenAISummaryText = undefined
|
||||
let reasoningSummary: string | undefined = undefined
|
||||
|
||||
if (model.id.includes('o1-pro')) {
|
||||
if (summaryText === 'off' || model.id.includes('o1-pro')) {
|
||||
reasoningSummary = undefined
|
||||
} else {
|
||||
reasoningSummary = summaryText
|
||||
}
|
||||
|
||||
if (isOpenAIDeepResearchModel(model)) {
|
||||
reasoningEffort = 'medium'
|
||||
}
|
||||
|
||||
// OpenAI 推理参数
|
||||
if (isSupportedReasoningEffortOpenAIModel(model)) {
|
||||
return {
|
||||
@@ -400,26 +388,19 @@ export function getOpenAIReasoningParams(
|
||||
return {}
|
||||
}
|
||||
|
||||
export function getAnthropicThinkingBudget(
|
||||
maxTokens: number | undefined,
|
||||
reasoningEffort: string | undefined,
|
||||
modelId: string
|
||||
): number | undefined {
|
||||
export function getAnthropicThinkingBudget(assistant: Assistant, model: Model): number {
|
||||
const { maxTokens, reasoning_effort: reasoningEffort } = getAssistantSettings(assistant)
|
||||
if (reasoningEffort === undefined || reasoningEffort === 'none') {
|
||||
return undefined
|
||||
return 0
|
||||
}
|
||||
const effortRatio = EFFORT_RATIO[reasoningEffort]
|
||||
|
||||
const tokenLimit = findTokenLimit(modelId)
|
||||
if (!tokenLimit) {
|
||||
return undefined
|
||||
}
|
||||
|
||||
const budgetTokens = Math.max(
|
||||
1024,
|
||||
Math.floor(
|
||||
Math.min(
|
||||
(tokenLimit.max - tokenLimit.min) * effortRatio + tokenLimit.min,
|
||||
(findTokenLimit(model.id)?.max! - findTokenLimit(model.id)?.min!) * effortRatio +
|
||||
findTokenLimit(model.id)?.min!,
|
||||
(maxTokens || DEFAULT_MAX_TOKENS) * effortRatio
|
||||
)
|
||||
)
|
||||
@@ -451,8 +432,7 @@ export function getAnthropicReasoningParams(
|
||||
|
||||
// Claude 推理参数
|
||||
if (isSupportedThinkingTokenClaudeModel(model)) {
|
||||
const { maxTokens } = getAssistantSettings(assistant)
|
||||
const budgetTokens = getAnthropicThinkingBudget(maxTokens, reasoningEffort, model.id)
|
||||
const budgetTokens = getAnthropicThinkingBudget(assistant, model)
|
||||
|
||||
return {
|
||||
thinking: {
|
||||
@@ -465,21 +445,6 @@ export function getAnthropicReasoningParams(
|
||||
return {}
|
||||
}
|
||||
|
||||
type GoogelThinkingLevel = NonNullable<GoogleGenerativeAIProviderOptions['thinkingConfig']>['thinkingLevel']
|
||||
|
||||
function mapToGeminiThinkingLevel(reasoningEffort: ReasoningEffortOption): GoogelThinkingLevel {
|
||||
switch (reasoningEffort) {
|
||||
case 'low':
|
||||
return 'low'
|
||||
case 'medium':
|
||||
return 'medium'
|
||||
case 'high':
|
||||
return 'high'
|
||||
default:
|
||||
return 'medium'
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取 Gemini 推理参数
|
||||
* 从 GeminiAPIClient 中提取的逻辑
|
||||
@@ -507,15 +472,6 @@ export function getGeminiReasoningParams(
|
||||
}
|
||||
}
|
||||
|
||||
// https://ai.google.dev/gemini-api/docs/gemini-3?thinking=high#new_api_features_in_gemini_3
|
||||
if (isGemini3Model(model)) {
|
||||
return {
|
||||
thinkingConfig: {
|
||||
thinkingLevel: mapToGeminiThinkingLevel(reasoningEffort)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const effortRatio = EFFORT_RATIO[reasoningEffort]
|
||||
|
||||
if (effortRatio > 1) {
|
||||
@@ -599,8 +555,7 @@ export function getBedrockReasoningParams(
|
||||
return {}
|
||||
}
|
||||
|
||||
const { maxTokens } = getAssistantSettings(assistant)
|
||||
const budgetTokens = getAnthropicThinkingBudget(maxTokens, reasoningEffort, model.id)
|
||||
const budgetTokens = getAnthropicThinkingBudget(assistant, model)
|
||||
return {
|
||||
reasoningConfig: {
|
||||
type: 'enabled',
|
||||
|
||||
@@ -47,7 +47,6 @@ export function buildProviderBuiltinWebSearchConfig(
|
||||
model?: Model
|
||||
): WebSearchPluginConfig | undefined {
|
||||
switch (providerId) {
|
||||
case 'azure-responses':
|
||||
case 'openai': {
|
||||
const searchContextSize = isOpenAIDeepResearchModel(model)
|
||||
? 'medium'
|
||||
|
||||
@@ -1,120 +1,35 @@
|
||||
import 'emoji-picker-element'
|
||||
|
||||
import TwemojiCountryFlagsWoff2 from '@renderer/assets/fonts/country-flag-fonts/TwemojiCountryFlags.woff2?url'
|
||||
import { useTheme } from '@renderer/context/ThemeProvider'
|
||||
import type { LanguageVarious } from '@renderer/types'
|
||||
import { polyfillCountryFlagEmojis } from 'country-flag-emoji-polyfill'
|
||||
// i18n translations from emoji-picker-element
|
||||
import de from 'emoji-picker-element/i18n/de'
|
||||
import en from 'emoji-picker-element/i18n/en'
|
||||
import es from 'emoji-picker-element/i18n/es'
|
||||
import fr from 'emoji-picker-element/i18n/fr'
|
||||
import ja from 'emoji-picker-element/i18n/ja'
|
||||
import pt_PT from 'emoji-picker-element/i18n/pt_PT'
|
||||
import ru_RU from 'emoji-picker-element/i18n/ru_RU'
|
||||
import zh_CN from 'emoji-picker-element/i18n/zh_CN'
|
||||
import type Picker from 'emoji-picker-element/picker'
|
||||
import type { EmojiClickEvent, NativeEmoji } from 'emoji-picker-element/shared'
|
||||
// Emoji data from emoji-picker-element-data (local, no CDN)
|
||||
// Using CLDR format for full multi-language search support (28 languages)
|
||||
import dataDE from 'emoji-picker-element-data/de/cldr/data.json?url'
|
||||
import dataEN from 'emoji-picker-element-data/en/cldr/data.json?url'
|
||||
import dataES from 'emoji-picker-element-data/es/cldr/data.json?url'
|
||||
import dataFR from 'emoji-picker-element-data/fr/cldr/data.json?url'
|
||||
import dataJA from 'emoji-picker-element-data/ja/cldr/data.json?url'
|
||||
import dataPT from 'emoji-picker-element-data/pt/cldr/data.json?url'
|
||||
import dataRU from 'emoji-picker-element-data/ru/cldr/data.json?url'
|
||||
import dataZH from 'emoji-picker-element-data/zh/cldr/data.json?url'
|
||||
import dataZH_HANT from 'emoji-picker-element-data/zh-hant/cldr/data.json?url'
|
||||
import type { FC } from 'react'
|
||||
import { useEffect, useRef } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
|
||||
interface Props {
|
||||
onEmojiClick: (emoji: string) => void
|
||||
}
|
||||
|
||||
// Mapping from app locale to emoji-picker-element i18n
|
||||
const i18nMap: Record<LanguageVarious, typeof en> = {
|
||||
'en-US': en,
|
||||
'zh-CN': zh_CN,
|
||||
'zh-TW': zh_CN, // Closest available
|
||||
'de-DE': de,
|
||||
'el-GR': en, // No Greek available, fallback to English
|
||||
'es-ES': es,
|
||||
'fr-FR': fr,
|
||||
'ja-JP': ja,
|
||||
'pt-PT': pt_PT,
|
||||
'ru-RU': ru_RU
|
||||
}
|
||||
|
||||
// Mapping from app locale to emoji data URL
|
||||
// Using CLDR format provides native language search support for all locales
|
||||
const dataSourceMap: Record<LanguageVarious, string> = {
|
||||
'en-US': dataEN,
|
||||
'zh-CN': dataZH,
|
||||
'zh-TW': dataZH_HANT,
|
||||
'de-DE': dataDE,
|
||||
'el-GR': dataEN, // No Greek CLDR available, fallback to English
|
||||
'es-ES': dataES,
|
||||
'fr-FR': dataFR,
|
||||
'ja-JP': dataJA,
|
||||
'pt-PT': dataPT,
|
||||
'ru-RU': dataRU
|
||||
}
|
||||
|
||||
// Mapping from app locale to emoji-picker-element locale string
|
||||
// Must match the data source locale for proper IndexedDB caching
|
||||
const localeMap: Record<LanguageVarious, string> = {
|
||||
'en-US': 'en',
|
||||
'zh-CN': 'zh',
|
||||
'zh-TW': 'zh-hant',
|
||||
'de-DE': 'de',
|
||||
'el-GR': 'en',
|
||||
'es-ES': 'es',
|
||||
'fr-FR': 'fr',
|
||||
'ja-JP': 'ja',
|
||||
'pt-PT': 'pt',
|
||||
'ru-RU': 'ru'
|
||||
}
|
||||
|
||||
const EmojiPicker: FC<Props> = ({ onEmojiClick }) => {
|
||||
const { theme } = useTheme()
|
||||
const { i18n } = useTranslation()
|
||||
const ref = useRef<Picker>(null)
|
||||
const currentLocale = i18n.language as LanguageVarious
|
||||
const ref = useRef<HTMLDivElement>(null)
|
||||
|
||||
useEffect(() => {
|
||||
polyfillCountryFlagEmojis('Twemoji Mozilla', TwemojiCountryFlagsWoff2)
|
||||
}, [])
|
||||
|
||||
// Configure picker with i18n and dataSource
|
||||
useEffect(() => {
|
||||
const picker = ref.current
|
||||
if (picker) {
|
||||
picker.i18n = i18nMap[currentLocale] || en
|
||||
picker.dataSource = dataSourceMap[currentLocale] || dataEN
|
||||
picker.locale = localeMap[currentLocale] || 'en'
|
||||
}
|
||||
}, [currentLocale])
|
||||
const refValue = ref.current
|
||||
|
||||
useEffect(() => {
|
||||
const picker = ref.current
|
||||
|
||||
if (picker) {
|
||||
const handleEmojiClick = (event: EmojiClickEvent) => {
|
||||
if (refValue) {
|
||||
const handleEmojiClick = (event: any) => {
|
||||
event.stopPropagation()
|
||||
const { detail } = event
|
||||
// Use detail.unicode (processed with skin tone) or fallback to emoji's unicode for native emoji
|
||||
const unicode = detail.unicode || ('unicode' in detail.emoji ? (detail.emoji as NativeEmoji).unicode : '')
|
||||
onEmojiClick(unicode)
|
||||
onEmojiClick(event.detail.unicode || event.detail.emoji.unicode)
|
||||
}
|
||||
// 添加事件监听器
|
||||
picker.addEventListener('emoji-click', handleEmojiClick)
|
||||
refValue.addEventListener('emoji-click', handleEmojiClick)
|
||||
|
||||
// 清理事件监听器
|
||||
return () => {
|
||||
picker.removeEventListener('emoji-click', handleEmojiClick)
|
||||
refValue.removeEventListener('emoji-click', handleEmojiClick)
|
||||
}
|
||||
}
|
||||
return
|
||||
|
||||
@@ -1,157 +0,0 @@
|
||||
import { loggerService } from '@logger'
|
||||
import type { UIActionResult } from '@mcp-ui/client'
|
||||
import { UIResourceRenderer } from '@mcp-ui/client'
|
||||
import type { EmbeddedResource } from '@modelcontextprotocol/sdk/types.js'
|
||||
import { isUIResource } from '@renderer/types'
|
||||
import type { FC } from 'react'
|
||||
import { useCallback, useState } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import styled from 'styled-components'
|
||||
|
||||
const logger = loggerService.withContext('MCPUIRenderer')
|
||||
|
||||
interface Props {
|
||||
resource: EmbeddedResource
|
||||
serverId?: string
|
||||
serverName?: string
|
||||
onToolCall?: (toolName: string, params: any) => Promise<any>
|
||||
}
|
||||
|
||||
const MCPUIRenderer: FC<Props> = ({ resource, onToolCall }) => {
|
||||
const { t } = useTranslation()
|
||||
const [error] = useState<string | null>(null)
|
||||
|
||||
const handleUIAction = useCallback(
|
||||
async (result: UIActionResult): Promise<any> => {
|
||||
logger.debug('UI Action received:', result)
|
||||
|
||||
try {
|
||||
switch (result.type) {
|
||||
case 'tool': {
|
||||
// Handle tool call from UI
|
||||
if (onToolCall) {
|
||||
const { toolName, params } = result.payload
|
||||
logger.info(`UI requesting tool call: ${toolName}`, { params })
|
||||
const response = await onToolCall(toolName, params)
|
||||
|
||||
// Check if the response contains a UIResource
|
||||
try {
|
||||
if (response && response.content && Array.isArray(response.content)) {
|
||||
const firstContent = response.content[0]
|
||||
if (firstContent && firstContent.type === 'text' && firstContent.text) {
|
||||
const parsedText = JSON.parse(firstContent.text)
|
||||
if (isUIResource(parsedText)) {
|
||||
// Return the UIResource directly for rendering in the iframe
|
||||
logger.info('Tool response contains UIResource:', { uri: parsedText.resource.uri })
|
||||
return { status: 'success', data: parsedText }
|
||||
}
|
||||
}
|
||||
}
|
||||
} catch (parseError) {
|
||||
// Not a UIResource, return the original response
|
||||
logger.debug('Tool response is not a UIResource')
|
||||
}
|
||||
|
||||
return { status: 'success', data: response }
|
||||
} else {
|
||||
logger.warn('Tool call requested but no handler provided')
|
||||
return { status: 'error', message: 'Tool call handler not available' }
|
||||
}
|
||||
}
|
||||
|
||||
case 'intent': {
|
||||
// Handle user intent
|
||||
logger.info('UI intent:', result.payload)
|
||||
window.toast.info(t('message.mcp.ui.intent_received'))
|
||||
return { status: 'acknowledged' }
|
||||
}
|
||||
|
||||
case 'notify': {
|
||||
// Handle notification from UI
|
||||
logger.info('UI notification:', result.payload)
|
||||
window.toast.info(result.payload.message || t('message.mcp.ui.notification'))
|
||||
return { status: 'acknowledged' }
|
||||
}
|
||||
|
||||
case 'prompt': {
|
||||
// Handle prompt request from UI
|
||||
logger.info('UI prompt request:', result.payload)
|
||||
// TODO: Integrate with prompt system
|
||||
return { status: 'error', message: 'Prompt execution not yet implemented' }
|
||||
}
|
||||
|
||||
case 'link': {
|
||||
// Handle navigation request
|
||||
const { url } = result.payload
|
||||
logger.info('UI navigation request:', { url })
|
||||
window.open(url, '_blank')
|
||||
return { status: 'acknowledged' }
|
||||
}
|
||||
|
||||
default:
|
||||
logger.warn('Unknown UI action type:', { result })
|
||||
return { status: 'error', message: 'Unknown action type' }
|
||||
}
|
||||
} catch (err) {
|
||||
logger.error('Error handling UI action:', err as Error)
|
||||
return {
|
||||
status: 'error',
|
||||
message: err instanceof Error ? err.message : 'Unknown error'
|
||||
}
|
||||
}
|
||||
},
|
||||
[onToolCall, t]
|
||||
)
|
||||
|
||||
if (error) {
|
||||
return (
|
||||
<ErrorContainer>
|
||||
<ErrorTitle>{t('message.mcp.ui.error')}</ErrorTitle>
|
||||
<ErrorMessage>{error}</ErrorMessage>
|
||||
</ErrorContainer>
|
||||
)
|
||||
}
|
||||
|
||||
return (
|
||||
<UIContainer>
|
||||
<UIResourceRenderer resource={resource} onUIAction={handleUIAction} />
|
||||
</UIContainer>
|
||||
)
|
||||
}
|
||||
|
||||
const UIContainer = styled.div`
|
||||
width: 100%;
|
||||
min-height: 400px;
|
||||
border-radius: 8px;
|
||||
overflow: hidden;
|
||||
background: var(--color-background);
|
||||
border: 1px solid var(--color-border);
|
||||
|
||||
iframe {
|
||||
width: 100%;
|
||||
border: none;
|
||||
min-height: 400px;
|
||||
height: 600px;
|
||||
}
|
||||
`
|
||||
|
||||
const ErrorContainer = styled.div`
|
||||
padding: 16px;
|
||||
border-radius: 8px;
|
||||
background: var(--color-error-bg, #fee);
|
||||
border: 1px solid var(--color-error-border, #fcc);
|
||||
color: var(--color-error-text, #c33);
|
||||
`
|
||||
|
||||
const ErrorTitle = styled.div`
|
||||
font-weight: 600;
|
||||
margin-bottom: 8px;
|
||||
font-size: 14px;
|
||||
`
|
||||
|
||||
const ErrorMessage = styled.div`
|
||||
font-size: 13px;
|
||||
opacity: 0.9;
|
||||
`
|
||||
|
||||
export default MCPUIRenderer
|
||||
@@ -1 +0,0 @@
|
||||
export { default as MCPUIRenderer } from './MCPUIRenderer'
|
||||
@@ -1,141 +0,0 @@
|
||||
import { importChatGPTConversations } from '@renderer/services/import'
|
||||
import { Alert, Modal, Progress, Space, Spin } from 'antd'
|
||||
import { useState } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
|
||||
import { TopView } from '../TopView'
|
||||
|
||||
interface PopupResult {
|
||||
success?: boolean
|
||||
}
|
||||
|
||||
interface Props {
|
||||
resolve: (data: PopupResult) => void
|
||||
}
|
||||
|
||||
const PopupContainer: React.FC<Props> = ({ resolve }) => {
|
||||
const [open, setOpen] = useState(true)
|
||||
const [selecting, setSelecting] = useState(false)
|
||||
const [importing, setImporting] = useState(false)
|
||||
const { t } = useTranslation()
|
||||
|
||||
const onOk = async () => {
|
||||
setSelecting(true)
|
||||
try {
|
||||
// Select ChatGPT JSON file
|
||||
const file = await window.api.file.open({
|
||||
filters: [{ name: 'ChatGPT Conversations', extensions: ['json'] }]
|
||||
})
|
||||
|
||||
setSelecting(false)
|
||||
|
||||
if (!file) {
|
||||
return
|
||||
}
|
||||
|
||||
setImporting(true)
|
||||
|
||||
// Parse file content
|
||||
const fileContent = typeof file.content === 'string' ? file.content : new TextDecoder().decode(file.content)
|
||||
|
||||
// Import conversations
|
||||
const result = await importChatGPTConversations(fileContent)
|
||||
|
||||
if (result.success) {
|
||||
window.toast.success(
|
||||
t('import.chatgpt.success', {
|
||||
topics: result.topicsCount,
|
||||
messages: result.messagesCount
|
||||
})
|
||||
)
|
||||
setOpen(false)
|
||||
} else {
|
||||
window.toast.error(result.error || t('import.chatgpt.error.unknown'))
|
||||
}
|
||||
} catch (error) {
|
||||
window.toast.error(t('import.chatgpt.error.unknown'))
|
||||
setOpen(false)
|
||||
} finally {
|
||||
setSelecting(false)
|
||||
setImporting(false)
|
||||
}
|
||||
}
|
||||
|
||||
const onCancel = () => {
|
||||
setOpen(false)
|
||||
}
|
||||
|
||||
const onClose = () => {
|
||||
resolve({})
|
||||
}
|
||||
|
||||
ImportPopup.hide = onCancel
|
||||
|
||||
return (
|
||||
<Modal
|
||||
title={t('import.chatgpt.title')}
|
||||
open={open}
|
||||
onOk={onOk}
|
||||
onCancel={onCancel}
|
||||
afterClose={onClose}
|
||||
okText={t('import.chatgpt.button')}
|
||||
okButtonProps={{ disabled: selecting || importing, loading: selecting }}
|
||||
cancelButtonProps={{ disabled: selecting || importing }}
|
||||
maskClosable={false}
|
||||
transitionName="animation-move-down"
|
||||
centered>
|
||||
{!selecting && !importing && (
|
||||
<Space direction="vertical" style={{ width: '100%' }}>
|
||||
<div>{t('import.chatgpt.description')}</div>
|
||||
<Alert
|
||||
message={t('import.chatgpt.help.title')}
|
||||
description={
|
||||
<div>
|
||||
<p>{t('import.chatgpt.help.step1')}</p>
|
||||
<p>{t('import.chatgpt.help.step2')}</p>
|
||||
<p>{t('import.chatgpt.help.step3')}</p>
|
||||
</div>
|
||||
}
|
||||
type="info"
|
||||
showIcon
|
||||
style={{ marginTop: 12 }}
|
||||
/>
|
||||
</Space>
|
||||
)}
|
||||
{selecting && (
|
||||
<div style={{ textAlign: 'center', padding: '40px 0' }}>
|
||||
<Spin size="large" />
|
||||
<div style={{ marginTop: 16 }}>{t('import.chatgpt.selecting')}</div>
|
||||
</div>
|
||||
)}
|
||||
{importing && (
|
||||
<div style={{ textAlign: 'center', padding: '20px 0' }}>
|
||||
<Progress percent={100} status="active" strokeColor="var(--color-primary)" showInfo={false} />
|
||||
<div style={{ marginTop: 16 }}>{t('import.chatgpt.importing')}</div>
|
||||
</div>
|
||||
)}
|
||||
</Modal>
|
||||
)
|
||||
}
|
||||
|
||||
const TopViewKey = 'ImportPopup'
|
||||
|
||||
export default class ImportPopup {
|
||||
static topviewId = 0
|
||||
static hide() {
|
||||
TopView.hide(TopViewKey)
|
||||
}
|
||||
static show() {
|
||||
return new Promise<PopupResult>((resolve) => {
|
||||
TopView.show(
|
||||
<PopupContainer
|
||||
resolve={(v) => {
|
||||
resolve(v)
|
||||
TopView.hide(TopViewKey)
|
||||
}}
|
||||
/>,
|
||||
TopViewKey
|
||||
)
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -15,7 +15,7 @@ import type {
|
||||
UpdateAgentForm
|
||||
} from '@renderer/types'
|
||||
import { AgentConfigurationSchema, isAgentType } from '@renderer/types'
|
||||
import { Alert, Button, Input, Modal, Select } from 'antd'
|
||||
import { Button, Input, Modal, Select } from 'antd'
|
||||
import { AlertTriangleIcon } from 'lucide-react'
|
||||
import type { ChangeEvent, FormEvent } from 'react'
|
||||
import { useCallback, useEffect, useMemo, useRef, useState } from 'react'
|
||||
@@ -58,7 +58,6 @@ const PopupContainer: React.FC<Props> = ({ agent, afterSubmit, resolve }) => {
|
||||
const isEditing = (agent?: AgentWithTools) => agent !== undefined
|
||||
|
||||
const [form, setForm] = useState<BaseAgentForm>(() => buildAgentForm(agent))
|
||||
const [hasGitBash, setHasGitBash] = useState<boolean>(true)
|
||||
|
||||
useEffect(() => {
|
||||
if (open) {
|
||||
@@ -66,30 +65,6 @@ const PopupContainer: React.FC<Props> = ({ agent, afterSubmit, resolve }) => {
|
||||
}
|
||||
}, [agent, open])
|
||||
|
||||
const checkGitBash = useCallback(
|
||||
async (showToast = false) => {
|
||||
try {
|
||||
const gitBashInstalled = await window.api.system.checkGitBash()
|
||||
setHasGitBash(gitBashInstalled)
|
||||
if (showToast) {
|
||||
if (gitBashInstalled) {
|
||||
window.toast.success(t('agent.gitBash.success', 'Git Bash detected successfully!'))
|
||||
} else {
|
||||
window.toast.error(t('agent.gitBash.notFound', 'Git Bash not found. Please install it first.'))
|
||||
}
|
||||
}
|
||||
} catch (error) {
|
||||
logger.error('Failed to check Git Bash:', error as Error)
|
||||
setHasGitBash(true) // Default to true on error to avoid false warnings
|
||||
}
|
||||
},
|
||||
[t]
|
||||
)
|
||||
|
||||
useEffect(() => {
|
||||
checkGitBash()
|
||||
}, [checkGitBash])
|
||||
|
||||
const selectedPermissionMode = form.configuration?.permission_mode ?? 'default'
|
||||
|
||||
const onPermissionModeChange = useCallback((value: PermissionMode) => {
|
||||
@@ -300,36 +275,6 @@ const PopupContainer: React.FC<Props> = ({ agent, afterSubmit, resolve }) => {
|
||||
footer={null}>
|
||||
<StyledForm onSubmit={onSubmit}>
|
||||
<FormContent>
|
||||
{!hasGitBash && (
|
||||
<Alert
|
||||
message={t('agent.gitBash.error.title', 'Git Bash Required')}
|
||||
description={
|
||||
<div>
|
||||
<div style={{ marginBottom: 8 }}>
|
||||
{t(
|
||||
'agent.gitBash.error.description',
|
||||
'Git Bash is required to run agents on Windows. The agent cannot function without it. Please install Git for Windows from'
|
||||
)}{' '}
|
||||
<a
|
||||
href="https://git-scm.com/download/win"
|
||||
onClick={(e) => {
|
||||
e.preventDefault()
|
||||
window.api.openWebsite('https://git-scm.com/download/win')
|
||||
}}
|
||||
style={{ textDecoration: 'underline' }}>
|
||||
git-scm.com
|
||||
</a>
|
||||
</div>
|
||||
<Button size="small" onClick={() => checkGitBash(true)}>
|
||||
{t('agent.gitBash.error.recheck', 'Recheck Git Bash Installation')}
|
||||
</Button>
|
||||
</div>
|
||||
}
|
||||
type="error"
|
||||
showIcon
|
||||
style={{ marginBottom: 16 }}
|
||||
/>
|
||||
)}
|
||||
<FormRow>
|
||||
<FormItem style={{ flex: 1 }}>
|
||||
<Label>
|
||||
@@ -432,7 +377,7 @@ const PopupContainer: React.FC<Props> = ({ agent, afterSubmit, resolve }) => {
|
||||
|
||||
<FormFooter>
|
||||
<Button onClick={onCancel}>{t('common.close')}</Button>
|
||||
<Button type="primary" htmlType="submit" loading={loadingRef.current} disabled={!hasGitBash}>
|
||||
<Button type="primary" htmlType="submit" loading={loadingRef.current}>
|
||||
{isEditing(agent) ? t('common.confirm') : t('common.add')}
|
||||
</Button>
|
||||
</FormFooter>
|
||||
|
||||
@@ -6,7 +6,7 @@ import { useEffect, useMemo, useRef, useState } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import styled, { css } from 'styled-components'
|
||||
|
||||
interface SelectorOption<V = string | number | undefined | null> {
|
||||
interface SelectorOption<V = string | number> {
|
||||
label: string | ReactNode
|
||||
value: V
|
||||
type?: 'group'
|
||||
@@ -14,7 +14,7 @@ interface SelectorOption<V = string | number | undefined | null> {
|
||||
disabled?: boolean
|
||||
}
|
||||
|
||||
interface BaseSelectorProps<V = string | number | undefined | null> {
|
||||
interface BaseSelectorProps<V = string | number> {
|
||||
options: SelectorOption<V>[]
|
||||
placeholder?: string
|
||||
placement?: 'topLeft' | 'topCenter' | 'topRight' | 'bottomLeft' | 'bottomCenter' | 'bottomRight' | 'top' | 'bottom'
|
||||
@@ -39,7 +39,7 @@ interface MultipleSelectorProps<V> extends BaseSelectorProps<V> {
|
||||
|
||||
export type SelectorProps<V> = SingleSelectorProps<V> | MultipleSelectorProps<V>
|
||||
|
||||
const Selector = <V extends string | number | undefined | null>({
|
||||
const Selector = <V extends string | number>({
|
||||
options,
|
||||
value,
|
||||
onChange = () => {},
|
||||
|
||||
@@ -140,11 +140,11 @@ describe('DynamicVirtualList', () => {
|
||||
// Should call isSticky function during rendering
|
||||
expect(isSticky).toHaveBeenCalled()
|
||||
|
||||
// Should apply sticky styles to sticky items
|
||||
// Sticky items within visible range should have proper z-index but may be absolute until scrolled
|
||||
const stickyItem = document.querySelector('[data-index="0"]') as HTMLElement
|
||||
expect(stickyItem).toBeInTheDocument()
|
||||
expect(stickyItem).toHaveStyle('position: sticky')
|
||||
expect(stickyItem).toHaveStyle('z-index: 1')
|
||||
// When sticky item is in visible range, it gets z-index but may not be sticky yet
|
||||
expect(stickyItem).toHaveStyle('z-index: 999')
|
||||
})
|
||||
|
||||
it('should apply absolute positioning to non-sticky items', () => {
|
||||
|
||||
@@ -24,7 +24,7 @@ exports[`DynamicVirtualList > basic rendering > snapshot test 1`] = `
|
||||
>
|
||||
<div
|
||||
data-index="0"
|
||||
style="position: absolute; top: 0px; left: 0px; transform: translateY(0px); width: 100%;"
|
||||
style="position: absolute; top: 0px; left: 0px; z-index: 0; pointer-events: auto; transform: translateY(0px); width: 100%;"
|
||||
>
|
||||
<div
|
||||
data-testid="item-0"
|
||||
@@ -34,7 +34,7 @@ exports[`DynamicVirtualList > basic rendering > snapshot test 1`] = `
|
||||
</div>
|
||||
<div
|
||||
data-index="1"
|
||||
style="position: absolute; top: 0px; left: 0px; transform: translateY(50px); width: 100%;"
|
||||
style="position: absolute; top: 0px; left: 0px; z-index: 0; pointer-events: auto; transform: translateY(50px); width: 100%;"
|
||||
>
|
||||
<div
|
||||
data-testid="item-1"
|
||||
@@ -44,7 +44,7 @@ exports[`DynamicVirtualList > basic rendering > snapshot test 1`] = `
|
||||
</div>
|
||||
<div
|
||||
data-index="2"
|
||||
style="position: absolute; top: 0px; left: 0px; transform: translateY(100px); width: 100%;"
|
||||
style="position: absolute; top: 0px; left: 0px; z-index: 0; pointer-events: auto; transform: translateY(100px); width: 100%;"
|
||||
>
|
||||
<div
|
||||
data-testid="item-2"
|
||||
|
||||
@@ -62,6 +62,12 @@ export interface DynamicVirtualListProps<T> extends InheritedVirtualizerOptions
|
||||
*/
|
||||
isSticky?: (index: number) => boolean
|
||||
|
||||
/**
|
||||
* Get the depth/level of an item for hierarchical sticky positioning
|
||||
* Used with isSticky to determine ancestor relationships
|
||||
*/
|
||||
getItemDepth?: (index: number) => number
|
||||
|
||||
/**
|
||||
* Range extractor function, cannot be used with isSticky
|
||||
*/
|
||||
@@ -101,6 +107,7 @@ function DynamicVirtualList<T>(props: DynamicVirtualListProps<T>) {
|
||||
size,
|
||||
estimateSize,
|
||||
isSticky,
|
||||
getItemDepth,
|
||||
rangeExtractor: customRangeExtractor,
|
||||
itemContainerStyle,
|
||||
scrollerStyle,
|
||||
@@ -115,7 +122,7 @@ function DynamicVirtualList<T>(props: DynamicVirtualListProps<T>) {
|
||||
const internalScrollerRef = useRef<HTMLDivElement>(null)
|
||||
const scrollerRef = internalScrollerRef
|
||||
|
||||
const activeStickyIndexRef = useRef(0)
|
||||
const activeStickyIndexesRef = useRef<number[]>([])
|
||||
|
||||
const stickyIndexes = useMemo(() => {
|
||||
if (!isSticky) return []
|
||||
@@ -124,21 +131,54 @@ function DynamicVirtualList<T>(props: DynamicVirtualListProps<T>) {
|
||||
|
||||
const internalStickyRangeExtractor = useCallback(
|
||||
(range: Range) => {
|
||||
// The active sticky index is the last one that is before or at the start of the visible range
|
||||
const newActiveStickyIndex =
|
||||
[...stickyIndexes].reverse().find((index) => range.startIndex >= index) ?? stickyIndexes[0] ?? 0
|
||||
const activeStickies: number[] = []
|
||||
|
||||
if (newActiveStickyIndex !== activeStickyIndexRef.current) {
|
||||
activeStickyIndexRef.current = newActiveStickyIndex
|
||||
if (getItemDepth) {
|
||||
// With depth information, we can build a proper ancestor chain
|
||||
// Find all sticky items before the visible range
|
||||
const stickiesBeforeRange = stickyIndexes.filter((index) => index < range.startIndex)
|
||||
|
||||
if (stickiesBeforeRange.length > 0) {
|
||||
// Find the depth of the first visible item (or last sticky before it)
|
||||
const firstVisibleIndex = range.startIndex
|
||||
const referenceDepth = getItemDepth(firstVisibleIndex)
|
||||
|
||||
// Build ancestor chain: include all sticky parents
|
||||
const ancestorChain: number[] = []
|
||||
let minDepth = referenceDepth
|
||||
|
||||
// Walk backwards from the last sticky before visible range
|
||||
for (let i = stickiesBeforeRange.length - 1; i >= 0; i--) {
|
||||
const stickyIndex = stickiesBeforeRange[i]
|
||||
const stickyDepth = getItemDepth(stickyIndex)
|
||||
|
||||
// Include this sticky if it's a parent (smaller depth) of our reference
|
||||
if (stickyDepth < minDepth) {
|
||||
ancestorChain.unshift(stickyIndex)
|
||||
minDepth = stickyDepth
|
||||
}
|
||||
}
|
||||
|
||||
activeStickies.push(...ancestorChain)
|
||||
}
|
||||
} else {
|
||||
// Fallback: without depth info, just use the last sticky before range
|
||||
const lastStickyBeforeRange = [...stickyIndexes].reverse().find((index) => index < range.startIndex)
|
||||
if (lastStickyBeforeRange !== undefined) {
|
||||
activeStickies.push(lastStickyBeforeRange)
|
||||
}
|
||||
}
|
||||
|
||||
// Merge the active sticky index and the default range extractor
|
||||
const next = new Set([activeStickyIndexRef.current, ...defaultRangeExtractor(range)])
|
||||
// Update the ref with current active stickies
|
||||
activeStickyIndexesRef.current = activeStickies
|
||||
|
||||
// Merge the active sticky indexes and the default range extractor
|
||||
const next = new Set([...activeStickyIndexesRef.current, ...defaultRangeExtractor(range)])
|
||||
|
||||
// Sort the set to maintain proper order
|
||||
return [...next].sort((a, b) => a - b)
|
||||
},
|
||||
[stickyIndexes]
|
||||
[stickyIndexes, getItemDepth]
|
||||
)
|
||||
|
||||
const rangeExtractor = customRangeExtractor ?? (isSticky ? internalStickyRangeExtractor : undefined)
|
||||
@@ -221,14 +261,47 @@ function DynamicVirtualList<T>(props: DynamicVirtualListProps<T>) {
|
||||
}}>
|
||||
{virtualItems.map((virtualItem) => {
|
||||
const isItemSticky = stickyIndexes.includes(virtualItem.index)
|
||||
const isItemActiveSticky = isItemSticky && activeStickyIndexRef.current === virtualItem.index
|
||||
const isItemActiveSticky = isItemSticky && activeStickyIndexesRef.current.includes(virtualItem.index)
|
||||
|
||||
// Calculate the sticky offset for multi-level sticky headers
|
||||
const activeStickyIndex = isItemActiveSticky ? activeStickyIndexesRef.current.indexOf(virtualItem.index) : -1
|
||||
|
||||
// Calculate cumulative offset based on actual sizes of previous sticky items
|
||||
let stickyOffset = 0
|
||||
if (activeStickyIndex >= 0) {
|
||||
for (let i = 0; i < activeStickyIndex; i++) {
|
||||
const prevStickyIndex = activeStickyIndexesRef.current[i]
|
||||
stickyOffset += estimateSize(prevStickyIndex)
|
||||
}
|
||||
}
|
||||
|
||||
// Check if this item is visually covered by sticky items
|
||||
// If covered, disable pointer events to prevent hover/click bleeding through
|
||||
const isCoveredBySticky = (() => {
|
||||
if (!activeStickyIndexesRef.current.length) return false
|
||||
if (isItemActiveSticky) return false // Sticky items themselves are not covered
|
||||
|
||||
// Calculate if this item's visual position is under any sticky header
|
||||
const itemVisualTop = virtualItem.start
|
||||
let totalStickyHeight = 0
|
||||
for (const stickyIdx of activeStickyIndexesRef.current) {
|
||||
totalStickyHeight += estimateSize(stickyIdx)
|
||||
}
|
||||
|
||||
// If item starts within the sticky area, it's covered
|
||||
return itemVisualTop < totalStickyHeight
|
||||
})()
|
||||
|
||||
const style: React.CSSProperties = {
|
||||
...itemContainerStyle,
|
||||
position: isItemActiveSticky ? 'sticky' : 'absolute',
|
||||
top: 0,
|
||||
top: isItemActiveSticky ? stickyOffset : 0,
|
||||
left: 0,
|
||||
zIndex: isItemSticky ? 1 : undefined,
|
||||
zIndex: isItemActiveSticky ? 1000 + (100 - activeStickyIndex) : isItemSticky ? 999 : 0,
|
||||
pointerEvents: isCoveredBySticky ? 'none' : 'auto',
|
||||
...(isItemActiveSticky && {
|
||||
backgroundColor: 'var(--color-background)'
|
||||
}),
|
||||
...(horizontal
|
||||
? {
|
||||
transform: isItemActiveSticky ? undefined : `translateX(${virtualItem.start}px)`,
|
||||
|
||||
@@ -1,55 +1,33 @@
|
||||
import {
|
||||
isImageEnhancementModel,
|
||||
isPureGenerateImageModel,
|
||||
isQwenReasoningModel,
|
||||
isSupportedThinkingTokenQwenModel,
|
||||
isVisionModel
|
||||
isVisionModel,
|
||||
isWebSearchModel
|
||||
} from '@renderer/config/models'
|
||||
import type { Model } from '@renderer/types'
|
||||
import { beforeEach, describe, expect, test, vi } from 'vitest'
|
||||
|
||||
vi.mock('@renderer/store/llm', () => ({
|
||||
initialState: {}
|
||||
}))
|
||||
|
||||
vi.mock('@renderer/store', () => ({
|
||||
default: {
|
||||
getState: () => ({
|
||||
llm: {
|
||||
settings: {}
|
||||
}
|
||||
})
|
||||
}
|
||||
}))
|
||||
|
||||
const getProviderByModelMock = vi.fn()
|
||||
const isEmbeddingModelMock = vi.fn()
|
||||
const isRerankModelMock = vi.fn()
|
||||
|
||||
vi.mock('@renderer/services/AssistantService', () => ({
|
||||
getProviderByModel: (...args: any[]) => getProviderByModelMock(...args),
|
||||
getAssistantSettings: vi.fn(),
|
||||
getDefaultAssistant: vi.fn().mockReturnValue({
|
||||
id: 'default',
|
||||
name: 'Default Assistant',
|
||||
prompt: '',
|
||||
settings: {}
|
||||
})
|
||||
}))
|
||||
|
||||
vi.mock('@renderer/config/models/embedding', () => ({
|
||||
isEmbeddingModel: (...args: any[]) => isEmbeddingModelMock(...args),
|
||||
isRerankModel: (...args: any[]) => isRerankModelMock(...args)
|
||||
}))
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
getProviderByModelMock.mockReturnValue({ type: 'openai-response' } as any)
|
||||
isEmbeddingModelMock.mockReturnValue(false)
|
||||
isRerankModelMock.mockReturnValue(false)
|
||||
})
|
||||
|
||||
// Suggested test cases
|
||||
describe('Qwen Model Detection', () => {
|
||||
beforeEach(() => {
|
||||
vi.mock('@renderer/store/llm', () => ({
|
||||
initialState: {}
|
||||
}))
|
||||
vi.mock('@renderer/services/AssistantService', () => ({
|
||||
getProviderByModel: vi.fn().mockReturnValue({ id: 'cherryai' })
|
||||
}))
|
||||
vi.mock('@renderer/store', () => ({
|
||||
default: {
|
||||
getState: () => ({
|
||||
llm: {
|
||||
settings: {}
|
||||
}
|
||||
})
|
||||
}
|
||||
}))
|
||||
})
|
||||
test('isQwenReasoningModel', () => {
|
||||
expect(isQwenReasoningModel({ id: 'qwen3-thinking' } as Model)).toBe(true)
|
||||
expect(isQwenReasoningModel({ id: 'qwen3-instruct' } as Model)).toBe(false)
|
||||
@@ -78,6 +56,14 @@ describe('Qwen Model Detection', () => {
|
||||
})
|
||||
|
||||
describe('Vision Model Detection', () => {
|
||||
beforeEach(() => {
|
||||
vi.mock('@renderer/store/llm', () => ({
|
||||
initialState: {}
|
||||
}))
|
||||
vi.mock('@renderer/services/AssistantService', () => ({
|
||||
getProviderByModel: vi.fn().mockReturnValue({ id: 'cherryai' })
|
||||
}))
|
||||
})
|
||||
test('isVisionModel', () => {
|
||||
expect(isVisionModel({ id: 'qwen-vl-max' } as Model)).toBe(true)
|
||||
expect(isVisionModel({ id: 'qwen-omni-turbo' } as Model)).toBe(true)
|
||||
@@ -89,4 +75,25 @@ describe('Vision Model Detection', () => {
|
||||
expect(isImageEnhancementModel({ id: 'qwen-image-edit' } as Model)).toBe(true)
|
||||
expect(isImageEnhancementModel({ id: 'grok-2-image-latest' } as Model)).toBe(true)
|
||||
})
|
||||
test('isPureGenerateImageModel', () => {
|
||||
expect(isPureGenerateImageModel({ id: 'gpt-image-1' } as Model)).toBe(true)
|
||||
expect(isPureGenerateImageModel({ id: 'gemini-2.5-flash-image-preview' } as Model)).toBe(true)
|
||||
expect(isPureGenerateImageModel({ id: 'gemini-2.0-flash-preview-image-generation' } as Model)).toBe(true)
|
||||
expect(isPureGenerateImageModel({ id: 'grok-2-image-latest' } as Model)).toBe(true)
|
||||
expect(isPureGenerateImageModel({ id: 'gpt-4o' } as Model)).toBe(false)
|
||||
})
|
||||
})
|
||||
|
||||
describe('Web Search Model Detection', () => {
|
||||
beforeEach(() => {
|
||||
vi.mock('@renderer/store/llm', () => ({
|
||||
initialState: {}
|
||||
}))
|
||||
vi.mock('@renderer/services/AssistantService', () => ({
|
||||
getProviderByModel: vi.fn().mockReturnValue({ id: 'cherryai' })
|
||||
}))
|
||||
})
|
||||
test('isWebSearchModel', () => {
|
||||
expect(isWebSearchModel({ id: 'grok-2-image-latest' } as Model)).toBe(false)
|
||||
})
|
||||
})
|
||||
520
src/renderer/src/config/__test__/reasoning.test.ts
Normal file
520
src/renderer/src/config/__test__/reasoning.test.ts
Normal file
@@ -0,0 +1,520 @@
|
||||
import { describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import {
|
||||
isDoubaoSeedAfter251015,
|
||||
isDoubaoThinkingAutoModel,
|
||||
isGeminiReasoningModel,
|
||||
isLingReasoningModel,
|
||||
isSupportedThinkingTokenGeminiModel
|
||||
} from '../models/reasoning'
|
||||
|
||||
vi.mock('@renderer/store', () => ({
|
||||
default: {
|
||||
getState: () => ({
|
||||
llm: {
|
||||
settings: {}
|
||||
}
|
||||
})
|
||||
}
|
||||
}))
|
||||
|
||||
// FIXME: Idk why it's imported. Maybe circular dependency somewhere
|
||||
vi.mock('@renderer/services/AssistantService.ts', () => ({
|
||||
getDefaultAssistant: () => {
|
||||
return {
|
||||
id: 'default',
|
||||
name: 'default',
|
||||
emoji: '😀',
|
||||
prompt: '',
|
||||
topics: [],
|
||||
messages: [],
|
||||
type: 'assistant',
|
||||
regularPhrases: [],
|
||||
settings: {}
|
||||
}
|
||||
}
|
||||
}))
|
||||
|
||||
describe('Doubao Models', () => {
|
||||
describe('isDoubaoThinkingAutoModel', () => {
|
||||
it('should return false for invalid models', () => {
|
||||
expect(
|
||||
isDoubaoThinkingAutoModel({
|
||||
id: 'doubao-seed-1-6-251015',
|
||||
name: 'doubao-seed-1-6-251015',
|
||||
provider: '',
|
||||
group: ''
|
||||
})
|
||||
).toBe(false)
|
||||
expect(
|
||||
isDoubaoThinkingAutoModel({
|
||||
id: 'doubao-seed-1-6-lite-251015',
|
||||
name: 'doubao-seed-1-6-lite-251015',
|
||||
provider: '',
|
||||
group: ''
|
||||
})
|
||||
).toBe(false)
|
||||
expect(
|
||||
isDoubaoThinkingAutoModel({
|
||||
id: 'doubao-seed-1-6-thinking-250715',
|
||||
name: 'doubao-seed-1-6-thinking-250715',
|
||||
provider: '',
|
||||
group: ''
|
||||
})
|
||||
).toBe(false)
|
||||
expect(
|
||||
isDoubaoThinkingAutoModel({
|
||||
id: 'doubao-seed-1-6-flash',
|
||||
name: 'doubao-seed-1-6-flash',
|
||||
provider: '',
|
||||
group: ''
|
||||
})
|
||||
).toBe(false)
|
||||
expect(
|
||||
isDoubaoThinkingAutoModel({
|
||||
id: 'doubao-seed-1-6-thinking',
|
||||
name: 'doubao-seed-1-6-thinking',
|
||||
provider: '',
|
||||
group: ''
|
||||
})
|
||||
).toBe(false)
|
||||
})
|
||||
|
||||
it('should return true for valid models', () => {
|
||||
expect(
|
||||
isDoubaoThinkingAutoModel({
|
||||
id: 'doubao-seed-1-6-250615',
|
||||
name: 'doubao-seed-1-6-250615',
|
||||
provider: '',
|
||||
group: ''
|
||||
})
|
||||
).toBe(true)
|
||||
expect(
|
||||
isDoubaoThinkingAutoModel({
|
||||
id: 'Doubao-Seed-1.6',
|
||||
name: 'Doubao-Seed-1.6',
|
||||
provider: '',
|
||||
group: ''
|
||||
})
|
||||
).toBe(true)
|
||||
expect(
|
||||
isDoubaoThinkingAutoModel({
|
||||
id: 'doubao-1-5-thinking-pro-m',
|
||||
name: 'doubao-1-5-thinking-pro-m',
|
||||
provider: '',
|
||||
group: ''
|
||||
})
|
||||
).toBe(true)
|
||||
expect(
|
||||
isDoubaoThinkingAutoModel({
|
||||
id: 'doubao-seed-1.6-lite',
|
||||
name: 'doubao-seed-1.6-lite',
|
||||
provider: '',
|
||||
group: ''
|
||||
})
|
||||
).toBe(true)
|
||||
expect(
|
||||
isDoubaoThinkingAutoModel({
|
||||
id: 'doubao-1-5-thinking-pro-m-12345',
|
||||
name: 'doubao-1-5-thinking-pro-m-12345',
|
||||
provider: '',
|
||||
group: ''
|
||||
})
|
||||
).toBe(true)
|
||||
})
|
||||
})
|
||||
|
||||
describe('isDoubaoSeedAfter251015', () => {
|
||||
it('should return true for models matching the pattern', () => {
|
||||
expect(
|
||||
isDoubaoSeedAfter251015({
|
||||
id: 'doubao-seed-1-6-251015',
|
||||
name: '',
|
||||
provider: '',
|
||||
group: ''
|
||||
})
|
||||
).toBe(true)
|
||||
expect(
|
||||
isDoubaoSeedAfter251015({
|
||||
id: 'doubao-seed-1-6-lite-251015',
|
||||
name: '',
|
||||
provider: '',
|
||||
group: ''
|
||||
})
|
||||
).toBe(true)
|
||||
})
|
||||
|
||||
it('should return false for models not matching the pattern', () => {
|
||||
expect(
|
||||
isDoubaoSeedAfter251015({
|
||||
id: 'doubao-seed-1-6-250615',
|
||||
name: '',
|
||||
provider: '',
|
||||
group: ''
|
||||
})
|
||||
).toBe(false)
|
||||
expect(
|
||||
isDoubaoSeedAfter251015({
|
||||
id: 'Doubao-Seed-1.6',
|
||||
name: '',
|
||||
provider: '',
|
||||
group: ''
|
||||
})
|
||||
).toBe(false)
|
||||
expect(
|
||||
isDoubaoSeedAfter251015({
|
||||
id: 'doubao-1-5-thinking-pro-m',
|
||||
name: '',
|
||||
provider: '',
|
||||
group: ''
|
||||
})
|
||||
).toBe(false)
|
||||
expect(
|
||||
isDoubaoSeedAfter251015({
|
||||
id: 'doubao-seed-1-6-lite-251016',
|
||||
name: '',
|
||||
provider: '',
|
||||
group: ''
|
||||
})
|
||||
).toBe(false)
|
||||
})
|
||||
})
|
||||
})
|
||||
describe('Ling Models', () => {
|
||||
describe('isLingReasoningModel', () => {
|
||||
it('should return false for ling variants', () => {
|
||||
expect(
|
||||
isLingReasoningModel({
|
||||
id: 'ling-1t',
|
||||
name: '',
|
||||
provider: '',
|
||||
group: ''
|
||||
})
|
||||
).toBe(false)
|
||||
expect(
|
||||
isLingReasoningModel({
|
||||
id: 'ling-flash-2.0',
|
||||
name: '',
|
||||
provider: '',
|
||||
group: ''
|
||||
})
|
||||
).toBe(false)
|
||||
expect(
|
||||
isLingReasoningModel({
|
||||
id: 'ling-mini-2.0',
|
||||
name: '',
|
||||
provider: '',
|
||||
group: ''
|
||||
})
|
||||
).toBe(false)
|
||||
})
|
||||
|
||||
it('should return true for ring variants', () => {
|
||||
expect(
|
||||
isLingReasoningModel({
|
||||
id: 'ring-1t',
|
||||
name: '',
|
||||
provider: '',
|
||||
group: ''
|
||||
})
|
||||
).toBe(true)
|
||||
expect(
|
||||
isLingReasoningModel({
|
||||
id: 'ring-flash-2.0',
|
||||
name: '',
|
||||
provider: '',
|
||||
group: ''
|
||||
})
|
||||
).toBe(true)
|
||||
expect(
|
||||
isLingReasoningModel({
|
||||
id: 'ring-mini-2.0',
|
||||
name: '',
|
||||
provider: '',
|
||||
group: ''
|
||||
})
|
||||
).toBe(true)
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe('Gemini Models', () => {
|
||||
describe('isSupportedThinkingTokenGeminiModel', () => {
|
||||
it('should return true for gemini 2.5 models', () => {
|
||||
expect(
|
||||
isSupportedThinkingTokenGeminiModel({
|
||||
id: 'gemini-2.5-flash',
|
||||
name: '',
|
||||
provider: '',
|
||||
group: ''
|
||||
})
|
||||
).toBe(true)
|
||||
expect(
|
||||
isSupportedThinkingTokenGeminiModel({
|
||||
id: 'gemini-2.5-pro',
|
||||
name: '',
|
||||
provider: '',
|
||||
group: ''
|
||||
})
|
||||
).toBe(true)
|
||||
expect(
|
||||
isSupportedThinkingTokenGeminiModel({
|
||||
id: 'gemini-2.5-flash-latest',
|
||||
name: '',
|
||||
provider: '',
|
||||
group: ''
|
||||
})
|
||||
).toBe(true)
|
||||
expect(
|
||||
isSupportedThinkingTokenGeminiModel({
|
||||
id: 'gemini-2.5-pro-latest',
|
||||
name: '',
|
||||
provider: '',
|
||||
group: ''
|
||||
})
|
||||
).toBe(true)
|
||||
})
|
||||
|
||||
it('should return true for gemini latest models', () => {
|
||||
expect(
|
||||
isSupportedThinkingTokenGeminiModel({
|
||||
id: 'gemini-flash-latest',
|
||||
name: '',
|
||||
provider: '',
|
||||
group: ''
|
||||
})
|
||||
).toBe(true)
|
||||
expect(
|
||||
isSupportedThinkingTokenGeminiModel({
|
||||
id: 'gemini-pro-latest',
|
||||
name: '',
|
||||
provider: '',
|
||||
group: ''
|
||||
})
|
||||
).toBe(true)
|
||||
expect(
|
||||
isSupportedThinkingTokenGeminiModel({
|
||||
id: 'gemini-flash-lite-latest',
|
||||
name: '',
|
||||
provider: '',
|
||||
group: ''
|
||||
})
|
||||
).toBe(true)
|
||||
})
|
||||
|
||||
it('should return true for gemini 3 models', () => {
|
||||
// Preview versions
|
||||
expect(
|
||||
isSupportedThinkingTokenGeminiModel({
|
||||
id: 'gemini-3-pro-preview',
|
||||
name: '',
|
||||
provider: '',
|
||||
group: ''
|
||||
})
|
||||
).toBe(true)
|
||||
expect(
|
||||
isSupportedThinkingTokenGeminiModel({
|
||||
id: 'google/gemini-3-pro-preview',
|
||||
name: '',
|
||||
provider: '',
|
||||
group: ''
|
||||
})
|
||||
).toBe(true)
|
||||
// Future stable versions
|
||||
expect(
|
||||
isSupportedThinkingTokenGeminiModel({
|
||||
id: 'gemini-3-flash',
|
||||
name: '',
|
||||
provider: '',
|
||||
group: ''
|
||||
})
|
||||
).toBe(true)
|
||||
expect(
|
||||
isSupportedThinkingTokenGeminiModel({
|
||||
id: 'gemini-3-pro',
|
||||
name: '',
|
||||
provider: '',
|
||||
group: ''
|
||||
})
|
||||
).toBe(true)
|
||||
expect(
|
||||
isSupportedThinkingTokenGeminiModel({
|
||||
id: 'google/gemini-3-flash',
|
||||
name: '',
|
||||
provider: '',
|
||||
group: ''
|
||||
})
|
||||
).toBe(true)
|
||||
expect(
|
||||
isSupportedThinkingTokenGeminiModel({
|
||||
id: 'google/gemini-3-pro',
|
||||
name: '',
|
||||
provider: '',
|
||||
group: ''
|
||||
})
|
||||
).toBe(true)
|
||||
})
|
||||
|
||||
it('should return false for image and tts models', () => {
|
||||
expect(
|
||||
isSupportedThinkingTokenGeminiModel({
|
||||
id: 'gemini-2.5-flash-image',
|
||||
name: '',
|
||||
provider: '',
|
||||
group: ''
|
||||
})
|
||||
).toBe(false)
|
||||
expect(
|
||||
isSupportedThinkingTokenGeminiModel({
|
||||
id: 'gemini-2.5-flash-preview-tts',
|
||||
name: '',
|
||||
provider: '',
|
||||
group: ''
|
||||
})
|
||||
).toBe(false)
|
||||
})
|
||||
|
||||
it('should return false for older gemini models', () => {
|
||||
expect(
|
||||
isSupportedThinkingTokenGeminiModel({
|
||||
id: 'gemini-1.5-flash',
|
||||
name: '',
|
||||
provider: '',
|
||||
group: ''
|
||||
})
|
||||
).toBe(false)
|
||||
expect(
|
||||
isSupportedThinkingTokenGeminiModel({
|
||||
id: 'gemini-1.5-pro',
|
||||
name: '',
|
||||
provider: '',
|
||||
group: ''
|
||||
})
|
||||
).toBe(false)
|
||||
expect(
|
||||
isSupportedThinkingTokenGeminiModel({
|
||||
id: 'gemini-1.0-pro',
|
||||
name: '',
|
||||
provider: '',
|
||||
group: ''
|
||||
})
|
||||
).toBe(false)
|
||||
})
|
||||
})
|
||||
|
||||
describe('isGeminiReasoningModel', () => {
|
||||
it('should return true for gemini thinking models', () => {
|
||||
expect(
|
||||
isGeminiReasoningModel({
|
||||
id: 'gemini-2.0-flash-thinking',
|
||||
name: '',
|
||||
provider: '',
|
||||
group: ''
|
||||
})
|
||||
).toBe(true)
|
||||
expect(
|
||||
isGeminiReasoningModel({
|
||||
id: 'gemini-thinking-exp',
|
||||
name: '',
|
||||
provider: '',
|
||||
group: ''
|
||||
})
|
||||
).toBe(true)
|
||||
})
|
||||
|
||||
it('should return true for supported thinking token gemini models', () => {
|
||||
expect(
|
||||
isGeminiReasoningModel({
|
||||
id: 'gemini-2.5-flash',
|
||||
name: '',
|
||||
provider: '',
|
||||
group: ''
|
||||
})
|
||||
).toBe(true)
|
||||
expect(
|
||||
isGeminiReasoningModel({
|
||||
id: 'gemini-2.5-pro',
|
||||
name: '',
|
||||
provider: '',
|
||||
group: ''
|
||||
})
|
||||
).toBe(true)
|
||||
})
|
||||
|
||||
it('should return true for gemini-3 models', () => {
|
||||
// Preview versions
|
||||
expect(
|
||||
isGeminiReasoningModel({
|
||||
id: 'gemini-3-pro-preview',
|
||||
name: '',
|
||||
provider: '',
|
||||
group: ''
|
||||
})
|
||||
).toBe(true)
|
||||
expect(
|
||||
isGeminiReasoningModel({
|
||||
id: 'google/gemini-3-pro-preview',
|
||||
name: '',
|
||||
provider: '',
|
||||
group: ''
|
||||
})
|
||||
).toBe(true)
|
||||
// Future stable versions
|
||||
expect(
|
||||
isGeminiReasoningModel({
|
||||
id: 'gemini-3-flash',
|
||||
name: '',
|
||||
provider: '',
|
||||
group: ''
|
||||
})
|
||||
).toBe(true)
|
||||
expect(
|
||||
isGeminiReasoningModel({
|
||||
id: 'gemini-3-pro',
|
||||
name: '',
|
||||
provider: '',
|
||||
group: ''
|
||||
})
|
||||
).toBe(true)
|
||||
expect(
|
||||
isGeminiReasoningModel({
|
||||
id: 'google/gemini-3-flash',
|
||||
name: '',
|
||||
provider: '',
|
||||
group: ''
|
||||
})
|
||||
).toBe(true)
|
||||
expect(
|
||||
isGeminiReasoningModel({
|
||||
id: 'google/gemini-3-pro',
|
||||
name: '',
|
||||
provider: '',
|
||||
group: ''
|
||||
})
|
||||
).toBe(true)
|
||||
})
|
||||
|
||||
it('should return false for older gemini models without thinking', () => {
|
||||
expect(
|
||||
isGeminiReasoningModel({
|
||||
id: 'gemini-1.5-flash',
|
||||
name: '',
|
||||
provider: '',
|
||||
group: ''
|
||||
})
|
||||
).toBe(false)
|
||||
expect(
|
||||
isGeminiReasoningModel({
|
||||
id: 'gemini-1.5-pro',
|
||||
name: '',
|
||||
provider: '',
|
||||
group: ''
|
||||
})
|
||||
).toBe(false)
|
||||
})
|
||||
|
||||
it('should return false for undefined model', () => {
|
||||
expect(isGeminiReasoningModel(undefined)).toBe(false)
|
||||
})
|
||||
})
|
||||
})
|
||||
167
src/renderer/src/config/__test__/vision.test.ts
Normal file
167
src/renderer/src/config/__test__/vision.test.ts
Normal file
@@ -0,0 +1,167 @@
|
||||
import { describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import { isVisionModel } from '../models/vision'
|
||||
|
||||
vi.mock('@renderer/store', () => ({
|
||||
default: {
|
||||
getState: () => ({
|
||||
llm: {
|
||||
settings: {}
|
||||
}
|
||||
})
|
||||
}
|
||||
}))
|
||||
|
||||
// FIXME: Idk why it's imported. Maybe circular dependency somewhere
|
||||
vi.mock('@renderer/services/AssistantService.ts', () => ({
|
||||
getDefaultAssistant: () => {
|
||||
return {
|
||||
id: 'default',
|
||||
name: 'default',
|
||||
emoji: '😀',
|
||||
prompt: '',
|
||||
topics: [],
|
||||
messages: [],
|
||||
type: 'assistant',
|
||||
regularPhrases: [],
|
||||
settings: {}
|
||||
}
|
||||
},
|
||||
getProviderByModel: () => null
|
||||
}))
|
||||
|
||||
describe('isVisionModel', () => {
|
||||
describe('Gemini Models', () => {
|
||||
it('should return true for gemini 1.5 models', () => {
|
||||
expect(
|
||||
isVisionModel({
|
||||
id: 'gemini-1.5-flash',
|
||||
name: '',
|
||||
provider: '',
|
||||
group: ''
|
||||
})
|
||||
).toBe(true)
|
||||
expect(
|
||||
isVisionModel({
|
||||
id: 'gemini-1.5-pro',
|
||||
name: '',
|
||||
provider: '',
|
||||
group: ''
|
||||
})
|
||||
).toBe(true)
|
||||
})
|
||||
|
||||
it('should return true for gemini 2.x models', () => {
|
||||
expect(
|
||||
isVisionModel({
|
||||
id: 'gemini-2.0-flash',
|
||||
name: '',
|
||||
provider: '',
|
||||
group: ''
|
||||
})
|
||||
).toBe(true)
|
||||
expect(
|
||||
isVisionModel({
|
||||
id: 'gemini-2.0-pro',
|
||||
name: '',
|
||||
provider: '',
|
||||
group: ''
|
||||
})
|
||||
).toBe(true)
|
||||
expect(
|
||||
isVisionModel({
|
||||
id: 'gemini-2.5-flash',
|
||||
name: '',
|
||||
provider: '',
|
||||
group: ''
|
||||
})
|
||||
).toBe(true)
|
||||
expect(
|
||||
isVisionModel({
|
||||
id: 'gemini-2.5-pro',
|
||||
name: '',
|
||||
provider: '',
|
||||
group: ''
|
||||
})
|
||||
).toBe(true)
|
||||
})
|
||||
|
||||
it('should return true for gemini latest models', () => {
|
||||
expect(
|
||||
isVisionModel({
|
||||
id: 'gemini-flash-latest',
|
||||
name: '',
|
||||
provider: '',
|
||||
group: ''
|
||||
})
|
||||
).toBe(true)
|
||||
expect(
|
||||
isVisionModel({
|
||||
id: 'gemini-pro-latest',
|
||||
name: '',
|
||||
provider: '',
|
||||
group: ''
|
||||
})
|
||||
).toBe(true)
|
||||
expect(
|
||||
isVisionModel({
|
||||
id: 'gemini-flash-lite-latest',
|
||||
name: '',
|
||||
provider: '',
|
||||
group: ''
|
||||
})
|
||||
).toBe(true)
|
||||
})
|
||||
|
||||
it('should return true for gemini 3 models', () => {
|
||||
// Preview versions
|
||||
expect(
|
||||
isVisionModel({
|
||||
id: 'gemini-3-pro-preview',
|
||||
name: '',
|
||||
provider: '',
|
||||
group: ''
|
||||
})
|
||||
).toBe(true)
|
||||
// Future stable versions
|
||||
expect(
|
||||
isVisionModel({
|
||||
id: 'gemini-3-flash',
|
||||
name: '',
|
||||
provider: '',
|
||||
group: ''
|
||||
})
|
||||
).toBe(true)
|
||||
expect(
|
||||
isVisionModel({
|
||||
id: 'gemini-3-pro',
|
||||
name: '',
|
||||
provider: '',
|
||||
group: ''
|
||||
})
|
||||
).toBe(true)
|
||||
})
|
||||
|
||||
it('should return true for gemini exp models', () => {
|
||||
expect(
|
||||
isVisionModel({
|
||||
id: 'gemini-exp-1206',
|
||||
name: '',
|
||||
provider: '',
|
||||
group: ''
|
||||
})
|
||||
).toBe(true)
|
||||
})
|
||||
|
||||
it('should return false for gemini 1.0 models', () => {
|
||||
expect(
|
||||
isVisionModel({
|
||||
id: 'gemini-1.0-pro',
|
||||
name: '',
|
||||
provider: '',
|
||||
group: ''
|
||||
})
|
||||
).toBe(false)
|
||||
})
|
||||
})
|
||||
})
|
||||
64
src/renderer/src/config/__test__/websearch.test.ts
Normal file
64
src/renderer/src/config/__test__/websearch.test.ts
Normal file
@@ -0,0 +1,64 @@
|
||||
import { describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import { GEMINI_SEARCH_REGEX } from '../models/websearch'
|
||||
|
||||
vi.mock('@renderer/store', () => ({
|
||||
default: {
|
||||
getState: () => ({
|
||||
llm: {
|
||||
settings: {}
|
||||
}
|
||||
})
|
||||
}
|
||||
}))
|
||||
|
||||
// FIXME: Idk why it's imported. Maybe circular dependency somewhere
|
||||
vi.mock('@renderer/services/AssistantService.ts', () => ({
|
||||
getDefaultAssistant: () => {
|
||||
return {
|
||||
id: 'default',
|
||||
name: 'default',
|
||||
emoji: '😀',
|
||||
prompt: '',
|
||||
topics: [],
|
||||
messages: [],
|
||||
type: 'assistant',
|
||||
regularPhrases: [],
|
||||
settings: {}
|
||||
}
|
||||
},
|
||||
getProviderByModel: () => null
|
||||
}))
|
||||
|
||||
describe('Gemini Search Models', () => {
|
||||
describe('GEMINI_SEARCH_REGEX', () => {
|
||||
it('should match gemini 2.x models', () => {
|
||||
expect(GEMINI_SEARCH_REGEX.test('gemini-2.0-flash')).toBe(true)
|
||||
expect(GEMINI_SEARCH_REGEX.test('gemini-2.0-pro')).toBe(true)
|
||||
expect(GEMINI_SEARCH_REGEX.test('gemini-2.5-flash')).toBe(true)
|
||||
expect(GEMINI_SEARCH_REGEX.test('gemini-2.5-pro')).toBe(true)
|
||||
expect(GEMINI_SEARCH_REGEX.test('gemini-2.5-flash-latest')).toBe(true)
|
||||
expect(GEMINI_SEARCH_REGEX.test('gemini-2.5-pro-latest')).toBe(true)
|
||||
})
|
||||
|
||||
it('should match gemini latest models', () => {
|
||||
expect(GEMINI_SEARCH_REGEX.test('gemini-flash-latest')).toBe(true)
|
||||
expect(GEMINI_SEARCH_REGEX.test('gemini-pro-latest')).toBe(true)
|
||||
expect(GEMINI_SEARCH_REGEX.test('gemini-flash-lite-latest')).toBe(true)
|
||||
})
|
||||
|
||||
it('should match gemini 3 models', () => {
|
||||
// Preview versions
|
||||
expect(GEMINI_SEARCH_REGEX.test('gemini-3-pro-preview')).toBe(true)
|
||||
// Future stable versions
|
||||
expect(GEMINI_SEARCH_REGEX.test('gemini-3-flash')).toBe(true)
|
||||
expect(GEMINI_SEARCH_REGEX.test('gemini-3-pro')).toBe(true)
|
||||
})
|
||||
|
||||
it('should not match older gemini models', () => {
|
||||
expect(GEMINI_SEARCH_REGEX.test('gemini-1.5-flash')).toBe(false)
|
||||
expect(GEMINI_SEARCH_REGEX.test('gemini-1.5-pro')).toBe(false)
|
||||
expect(GEMINI_SEARCH_REGEX.test('gemini-1.0-pro')).toBe(false)
|
||||
})
|
||||
})
|
||||
})
|
||||
@@ -1,101 +0,0 @@
|
||||
import type { Model } from '@renderer/types'
|
||||
import { describe, expect, it, vi } from 'vitest'
|
||||
|
||||
vi.mock('@renderer/hooks/useStore', () => ({
|
||||
getStoreProviders: vi.fn(() => [])
|
||||
}))
|
||||
|
||||
vi.mock('@renderer/store', () => ({
|
||||
__esModule: true,
|
||||
default: {
|
||||
getState: () => ({
|
||||
llm: { providers: [] },
|
||||
settings: {}
|
||||
})
|
||||
},
|
||||
useAppDispatch: vi.fn(),
|
||||
useAppSelector: vi.fn()
|
||||
}))
|
||||
|
||||
vi.mock('@renderer/store/settings', () => {
|
||||
const noop = vi.fn()
|
||||
return new Proxy(
|
||||
{},
|
||||
{
|
||||
get: (_target, prop) => {
|
||||
if (prop === 'initialState') {
|
||||
return {}
|
||||
}
|
||||
return noop
|
||||
}
|
||||
}
|
||||
)
|
||||
})
|
||||
|
||||
vi.mock('@renderer/hooks/useSettings', () => ({
|
||||
useSettings: vi.fn(() => ({})),
|
||||
useNavbarPosition: vi.fn(() => ({ navbarPosition: 'left' })),
|
||||
useMessageStyle: vi.fn(() => ({ isBubbleStyle: false })),
|
||||
getStoreSetting: vi.fn()
|
||||
}))
|
||||
|
||||
import { isEmbeddingModel, isRerankModel } from '../embedding'
|
||||
|
||||
const createModel = (overrides: Partial<Model> = {}): Model => ({
|
||||
id: 'test-model',
|
||||
name: 'Test Model',
|
||||
provider: 'openai',
|
||||
group: 'Test',
|
||||
...overrides
|
||||
})
|
||||
|
||||
describe('isEmbeddingModel', () => {
|
||||
it('returns true for ids that match the embedding regex', () => {
|
||||
expect(isEmbeddingModel(createModel({ id: 'Text-Embedding-3-Small' }))).toBe(true)
|
||||
})
|
||||
|
||||
it('returns false for rerank models even if they match embedding patterns', () => {
|
||||
const model = createModel({ id: 'rerank-qa', name: 'rerank-qa' })
|
||||
expect(isRerankModel(model)).toBe(true)
|
||||
expect(isEmbeddingModel(model)).toBe(false)
|
||||
})
|
||||
|
||||
it('honors user overrides for embedding capability', () => {
|
||||
const model = createModel({
|
||||
id: 'text-embedding-3-small',
|
||||
capabilities: [{ type: 'embedding', isUserSelected: false }]
|
||||
})
|
||||
expect(isEmbeddingModel(model)).toBe(false)
|
||||
})
|
||||
|
||||
it('uses the model name when provider is doubao', () => {
|
||||
const model = createModel({
|
||||
id: 'custom-id',
|
||||
name: 'BGE-Large-zh-v1.5',
|
||||
provider: 'doubao'
|
||||
})
|
||||
expect(isEmbeddingModel(model)).toBe(true)
|
||||
})
|
||||
|
||||
it('returns false for anthropic provider models', () => {
|
||||
const model = createModel({
|
||||
id: 'text-embedding-ada-002',
|
||||
provider: 'anthropic'
|
||||
})
|
||||
expect(isEmbeddingModel(model)).toBe(false)
|
||||
})
|
||||
})
|
||||
|
||||
describe('isRerankModel', () => {
|
||||
it('identifies ids that match rerank regex', () => {
|
||||
expect(isRerankModel(createModel({ id: 'jina-rerank-v2-base' }))).toBe(true)
|
||||
})
|
||||
|
||||
it('honors user overrides for rerank capability', () => {
|
||||
const model = createModel({
|
||||
id: 'jina-rerank-v2-base',
|
||||
capabilities: [{ type: 'rerank', isUserSelected: false }]
|
||||
})
|
||||
expect(isRerankModel(model)).toBe(false)
|
||||
})
|
||||
})
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,137 +0,0 @@
|
||||
import type { Model } from '@renderer/types'
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import { isEmbeddingModel, isRerankModel } from '../embedding'
|
||||
import { isDeepSeekHybridInferenceModel } from '../reasoning'
|
||||
import { isFunctionCallingModel } from '../tooluse'
|
||||
import { isPureGenerateImageModel, isTextToImageModel } from '../vision'
|
||||
|
||||
vi.mock('@renderer/hooks/useStore', () => ({
|
||||
getStoreProviders: vi.fn(() => [])
|
||||
}))
|
||||
|
||||
vi.mock('@renderer/store', () => ({
|
||||
__esModule: true,
|
||||
default: {
|
||||
getState: () => ({
|
||||
llm: { providers: [] },
|
||||
settings: {}
|
||||
})
|
||||
},
|
||||
useAppDispatch: vi.fn(),
|
||||
useAppSelector: vi.fn()
|
||||
}))
|
||||
|
||||
vi.mock('@renderer/store/settings', () => {
|
||||
const noop = vi.fn()
|
||||
return new Proxy(
|
||||
{},
|
||||
{
|
||||
get: (_target, prop) => {
|
||||
if (prop === 'initialState') {
|
||||
return {}
|
||||
}
|
||||
return noop
|
||||
}
|
||||
}
|
||||
)
|
||||
})
|
||||
|
||||
vi.mock('@renderer/hooks/useSettings', () => ({
|
||||
useSettings: vi.fn(() => ({})),
|
||||
useNavbarPosition: vi.fn(() => ({ navbarPosition: 'left' })),
|
||||
useMessageStyle: vi.fn(() => ({ isBubbleStyle: false })),
|
||||
getStoreSetting: vi.fn()
|
||||
}))
|
||||
|
||||
vi.mock('../embedding', () => ({
|
||||
isEmbeddingModel: vi.fn(),
|
||||
isRerankModel: vi.fn()
|
||||
}))
|
||||
|
||||
vi.mock('../vision', () => ({
|
||||
isPureGenerateImageModel: vi.fn(),
|
||||
isTextToImageModel: vi.fn()
|
||||
}))
|
||||
|
||||
vi.mock('../reasoning', () => ({
|
||||
isDeepSeekHybridInferenceModel: vi.fn()
|
||||
}))
|
||||
|
||||
const createModel = (overrides: Partial<Model> = {}): Model => ({
|
||||
id: 'gpt-4o',
|
||||
name: 'gpt-4o',
|
||||
provider: 'openai',
|
||||
group: 'OpenAI',
|
||||
...overrides
|
||||
})
|
||||
|
||||
const embeddingMock = vi.mocked(isEmbeddingModel)
|
||||
const rerankMock = vi.mocked(isRerankModel)
|
||||
const pureImageMock = vi.mocked(isPureGenerateImageModel)
|
||||
const textToImageMock = vi.mocked(isTextToImageModel)
|
||||
const deepSeekHybridMock = vi.mocked(isDeepSeekHybridInferenceModel)
|
||||
|
||||
describe('isFunctionCallingModel', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
embeddingMock.mockReturnValue(false)
|
||||
rerankMock.mockReturnValue(false)
|
||||
pureImageMock.mockReturnValue(false)
|
||||
textToImageMock.mockReturnValue(false)
|
||||
deepSeekHybridMock.mockReturnValue(false)
|
||||
})
|
||||
|
||||
it('returns false when the model is undefined', () => {
|
||||
expect(isFunctionCallingModel(undefined as unknown as Model)).toBe(false)
|
||||
})
|
||||
|
||||
it('returns false when model is classified as embedding/rerank/image', () => {
|
||||
embeddingMock.mockReturnValueOnce(true)
|
||||
expect(isFunctionCallingModel(createModel())).toBe(false)
|
||||
})
|
||||
|
||||
it('respect manual user overrides', () => {
|
||||
const model = createModel({
|
||||
capabilities: [{ type: 'function_calling', isUserSelected: false }]
|
||||
})
|
||||
expect(isFunctionCallingModel(model)).toBe(false)
|
||||
const enabled = createModel({
|
||||
capabilities: [{ type: 'function_calling', isUserSelected: true }]
|
||||
})
|
||||
expect(isFunctionCallingModel(enabled)).toBe(true)
|
||||
})
|
||||
|
||||
it('matches doubao models by name when regex applies', () => {
|
||||
const doubao = createModel({
|
||||
id: 'custom-model',
|
||||
name: 'Doubao-Seed-1.6-251015',
|
||||
provider: 'doubao'
|
||||
})
|
||||
expect(isFunctionCallingModel(doubao)).toBe(true)
|
||||
})
|
||||
|
||||
it('returns true for regex matches on standard providers', () => {
|
||||
expect(isFunctionCallingModel(createModel({ id: 'gpt-5' }))).toBe(true)
|
||||
})
|
||||
|
||||
it('excludes explicitly blocked ids', () => {
|
||||
expect(isFunctionCallingModel(createModel({ id: 'gemini-1.5-flash' }))).toBe(false)
|
||||
})
|
||||
|
||||
it('forces support for trusted providers', () => {
|
||||
for (const provider of ['deepseek', 'anthropic', 'kimi', 'moonshot']) {
|
||||
expect(isFunctionCallingModel(createModel({ provider }))).toBe(true)
|
||||
}
|
||||
})
|
||||
|
||||
it('returns true when identified as deepseek hybrid inference model', () => {
|
||||
deepSeekHybridMock.mockReturnValueOnce(true)
|
||||
expect(isFunctionCallingModel(createModel({ id: 'deepseek-v3-1', provider: 'custom' }))).toBe(true)
|
||||
})
|
||||
|
||||
it('returns false for deepseek hybrid models behind restricted system providers', () => {
|
||||
deepSeekHybridMock.mockReturnValueOnce(true)
|
||||
expect(isFunctionCallingModel(createModel({ id: 'deepseek-v3-1', provider: 'dashscope' }))).toBe(false)
|
||||
})
|
||||
})
|
||||
@@ -1,280 +0,0 @@
|
||||
import { isEmbeddingModel, isRerankModel } from '@renderer/config/models/embedding'
|
||||
import type { Model } from '@renderer/types'
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import {
|
||||
isGPT5ProModel,
|
||||
isGPT5SeriesModel,
|
||||
isGPT5SeriesReasoningModel,
|
||||
isGPT51SeriesModel,
|
||||
isOpenAIChatCompletionOnlyModel,
|
||||
isOpenAILLMModel,
|
||||
isOpenAIModel,
|
||||
isOpenAIOpenWeightModel,
|
||||
isOpenAIReasoningModel,
|
||||
isSupportVerbosityModel
|
||||
} from '../openai'
|
||||
import { isQwenMTModel } from '../qwen'
|
||||
import {
|
||||
agentModelFilter,
|
||||
getModelSupportedVerbosity,
|
||||
groupQwenModels,
|
||||
isAnthropicModel,
|
||||
isGeminiModel,
|
||||
isGemmaModel,
|
||||
isGenerateImageModels,
|
||||
isMaxTemperatureOneModel,
|
||||
isNotSupportedTextDelta,
|
||||
isNotSupportSystemMessageModel,
|
||||
isNotSupportTemperatureAndTopP,
|
||||
isSupportedFlexServiceTier,
|
||||
isSupportedModel,
|
||||
isSupportFlexServiceTierModel,
|
||||
isVisionModels,
|
||||
isZhipuModel
|
||||
} from '../utils'
|
||||
import { isGenerateImageModel, isTextToImageModel, isVisionModel } from '../vision'
|
||||
import { isOpenAIWebSearchChatCompletionOnlyModel } from '../websearch'
|
||||
|
||||
vi.mock('@renderer/hooks/useStore', () => ({
|
||||
getStoreProviders: vi.fn(() => [])
|
||||
}))
|
||||
|
||||
vi.mock('@renderer/store', () => ({
|
||||
__esModule: true,
|
||||
default: {
|
||||
getState: () => ({
|
||||
llm: { providers: [] },
|
||||
settings: {}
|
||||
})
|
||||
},
|
||||
useAppDispatch: vi.fn(),
|
||||
useAppSelector: vi.fn()
|
||||
}))
|
||||
|
||||
vi.mock('@renderer/store/settings', () => {
|
||||
const noop = vi.fn()
|
||||
return new Proxy(
|
||||
{},
|
||||
{
|
||||
get: (_target, prop) => {
|
||||
if (prop === 'initialState') {
|
||||
return {}
|
||||
}
|
||||
return noop
|
||||
}
|
||||
}
|
||||
)
|
||||
})
|
||||
|
||||
vi.mock('@renderer/hooks/useSettings', () => ({
|
||||
useSettings: vi.fn(() => ({})),
|
||||
useNavbarPosition: vi.fn(() => ({ navbarPosition: 'left' })),
|
||||
useMessageStyle: vi.fn(() => ({ isBubbleStyle: false })),
|
||||
getStoreSetting: vi.fn()
|
||||
}))
|
||||
|
||||
vi.mock('@renderer/config/models/embedding', () => ({
|
||||
isEmbeddingModel: vi.fn(),
|
||||
isRerankModel: vi.fn()
|
||||
}))
|
||||
|
||||
vi.mock('../vision', () => ({
|
||||
isGenerateImageModel: vi.fn(),
|
||||
isTextToImageModel: vi.fn(),
|
||||
isVisionModel: vi.fn()
|
||||
}))
|
||||
|
||||
vi.mock(import('../openai'), async (importOriginal) => {
|
||||
const actual = await importOriginal()
|
||||
return {
|
||||
...actual,
|
||||
isOpenAIReasoningModel: vi.fn()
|
||||
}
|
||||
})
|
||||
|
||||
vi.mock('../websearch', () => ({
|
||||
isOpenAIWebSearchChatCompletionOnlyModel: vi.fn()
|
||||
}))
|
||||
|
||||
const createModel = (overrides: Partial<Model> = {}): Model => ({
|
||||
id: 'gpt-4o',
|
||||
name: 'gpt-4o',
|
||||
provider: 'openai',
|
||||
group: 'OpenAI',
|
||||
...overrides
|
||||
})
|
||||
|
||||
const embeddingMock = vi.mocked(isEmbeddingModel)
|
||||
const rerankMock = vi.mocked(isRerankModel)
|
||||
const visionMock = vi.mocked(isVisionModel)
|
||||
const textToImageMock = vi.mocked(isTextToImageModel)
|
||||
const generateImageMock = vi.mocked(isGenerateImageModel)
|
||||
const reasoningMock = vi.mocked(isOpenAIReasoningModel)
|
||||
const openAIWebSearchOnlyMock = vi.mocked(isOpenAIWebSearchChatCompletionOnlyModel)
|
||||
|
||||
describe('model utils', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
embeddingMock.mockReturnValue(false)
|
||||
rerankMock.mockReturnValue(false)
|
||||
visionMock.mockReturnValue(true)
|
||||
textToImageMock.mockReturnValue(false)
|
||||
generateImageMock.mockReturnValue(true)
|
||||
reasoningMock.mockReturnValue(false)
|
||||
openAIWebSearchOnlyMock.mockReturnValue(false)
|
||||
})
|
||||
|
||||
it('detects OpenAI LLM models through reasoning and GPT prefix', () => {
|
||||
expect(isOpenAILLMModel(undefined as unknown as Model)).toBe(false)
|
||||
expect(isOpenAILLMModel(createModel({ id: 'gpt-4o-image' }))).toBe(false)
|
||||
|
||||
reasoningMock.mockReturnValueOnce(true)
|
||||
expect(isOpenAILLMModel(createModel({ id: 'o1-preview' }))).toBe(true)
|
||||
|
||||
expect(isOpenAILLMModel(createModel({ id: 'GPT-5-turbo' }))).toBe(true)
|
||||
})
|
||||
|
||||
it('detects OpenAI models via GPT prefix or reasoning support', () => {
|
||||
expect(isOpenAIModel(createModel({ id: 'gpt-4.1' }))).toBe(true)
|
||||
reasoningMock.mockReturnValueOnce(true)
|
||||
expect(isOpenAIModel(createModel({ id: 'o3' }))).toBe(true)
|
||||
})
|
||||
|
||||
it('evaluates support for flex service tier and alias helper', () => {
|
||||
expect(isSupportFlexServiceTierModel(createModel({ id: 'o3' }))).toBe(true)
|
||||
expect(isSupportFlexServiceTierModel(createModel({ id: 'o3-mini' }))).toBe(false)
|
||||
expect(isSupportFlexServiceTierModel(createModel({ id: 'o4-mini' }))).toBe(true)
|
||||
expect(isSupportFlexServiceTierModel(createModel({ id: 'gpt-5-preview' }))).toBe(true)
|
||||
expect(isSupportedFlexServiceTier(createModel({ id: 'gpt-4o' }))).toBe(false)
|
||||
})
|
||||
|
||||
it('detects verbosity support for GPT-5+ families', () => {
|
||||
expect(isSupportVerbosityModel(createModel({ id: 'gpt-5' }))).toBe(true)
|
||||
expect(isSupportVerbosityModel(createModel({ id: 'gpt-5-chat' }))).toBe(false)
|
||||
expect(isSupportVerbosityModel(createModel({ id: 'gpt-5.1-preview' }))).toBe(true)
|
||||
})
|
||||
|
||||
it('limits verbosity controls for GPT-5 Pro models', () => {
|
||||
const proModel = createModel({ id: 'gpt-5-pro' })
|
||||
const previewModel = createModel({ id: 'gpt-5-preview' })
|
||||
expect(getModelSupportedVerbosity(proModel)).toEqual([undefined, 'high'])
|
||||
expect(getModelSupportedVerbosity(previewModel)).toEqual([undefined, 'low', 'medium', 'high'])
|
||||
expect(isGPT5ProModel(proModel)).toBe(true)
|
||||
expect(isGPT5ProModel(previewModel)).toBe(false)
|
||||
})
|
||||
|
||||
it('identifies OpenAI chat-completion-only models', () => {
|
||||
expect(isOpenAIChatCompletionOnlyModel(createModel({ id: 'gpt-4o-search-preview' }))).toBe(true)
|
||||
expect(isOpenAIChatCompletionOnlyModel(createModel({ id: 'o1-mini' }))).toBe(true)
|
||||
expect(isOpenAIChatCompletionOnlyModel(createModel({ id: 'gpt-4o' }))).toBe(false)
|
||||
})
|
||||
|
||||
it('filters unsupported OpenAI catalog entries', () => {
|
||||
expect(isSupportedModel({ id: 'gpt-4', object: 'model' } as any)).toBe(true)
|
||||
expect(isSupportedModel({ id: 'tts-1', object: 'model' } as any)).toBe(false)
|
||||
})
|
||||
|
||||
it('calculates temperature/top-p support correctly', () => {
|
||||
const model = createModel({ id: 'o1' })
|
||||
reasoningMock.mockReturnValue(true)
|
||||
expect(isNotSupportTemperatureAndTopP(model)).toBe(true)
|
||||
|
||||
const openWeight = createModel({ id: 'gpt-oss-debug' })
|
||||
expect(isNotSupportTemperatureAndTopP(openWeight)).toBe(false)
|
||||
|
||||
const chatOnly = createModel({ id: 'o1-preview' })
|
||||
reasoningMock.mockReturnValue(false)
|
||||
expect(isNotSupportTemperatureAndTopP(chatOnly)).toBe(true)
|
||||
|
||||
const qwenMt = createModel({ id: 'qwen-mt-large', provider: 'aliyun' })
|
||||
expect(isNotSupportTemperatureAndTopP(qwenMt)).toBe(true)
|
||||
})
|
||||
|
||||
it('handles gemma and gemini detections plus zhipu tagging', () => {
|
||||
expect(isGemmaModel(createModel({ id: 'Gemma-3-27B' }))).toBe(true)
|
||||
expect(isGemmaModel(createModel({ group: 'Gemma' }))).toBe(true)
|
||||
expect(isGemmaModel(createModel({ id: 'gpt-4o' }))).toBe(false)
|
||||
|
||||
expect(isGeminiModel(createModel({ id: 'Gemini-2.0' }))).toBe(true)
|
||||
|
||||
expect(isZhipuModel(createModel({ provider: 'zhipu' }))).toBe(true)
|
||||
expect(isZhipuModel(createModel({ provider: 'openai' }))).toBe(false)
|
||||
})
|
||||
|
||||
it('groups qwen models by prefix', () => {
|
||||
const qwen = createModel({ id: 'Qwen-7B', provider: 'qwen', name: 'Qwen-7B' })
|
||||
const qwenOmni = createModel({ id: 'qwen2.5-omni', name: 'qwen2.5-omni' })
|
||||
const other = createModel({ id: 'deepseek-v3', group: 'DeepSeek' })
|
||||
|
||||
const grouped = groupQwenModels([qwen, qwenOmni, other])
|
||||
expect(Object.keys(grouped)).toContain('qwen-7b')
|
||||
expect(Object.keys(grouped)).toContain('qwen2.5')
|
||||
expect(grouped.DeepSeek).toContain(other)
|
||||
})
|
||||
|
||||
it('aggregates boolean helpers based on regex rules', () => {
|
||||
expect(isAnthropicModel(createModel({ id: 'claude-3.5' }))).toBe(true)
|
||||
expect(isQwenMTModel(createModel({ id: 'qwen-mt-large' }))).toBe(true)
|
||||
expect(isNotSupportedTextDelta(createModel({ id: 'qwen-mt-large' }))).toBe(true)
|
||||
expect(isNotSupportSystemMessageModel(createModel({ id: 'gemma-moe' }))).toBe(true)
|
||||
expect(isOpenAIOpenWeightModel(createModel({ id: 'gpt-oss-free' }))).toBe(true)
|
||||
})
|
||||
|
||||
it('evaluates GPT-5 family helpers', () => {
|
||||
expect(isGPT5SeriesModel(createModel({ id: 'gpt-5-preview' }))).toBe(true)
|
||||
expect(isGPT5SeriesModel(createModel({ id: 'gpt-5.1-preview' }))).toBe(false)
|
||||
expect(isGPT51SeriesModel(createModel({ id: 'gpt-5.1-mini' }))).toBe(true)
|
||||
expect(isGPT5SeriesReasoningModel(createModel({ id: 'gpt-5-prompt' }))).toBe(true)
|
||||
expect(isSupportVerbosityModel(createModel({ id: 'gpt-5-chat' }))).toBe(false)
|
||||
})
|
||||
|
||||
it('wraps generate/vision helpers that operate on arrays', () => {
|
||||
const models = [createModel({ id: 'gpt-4o' }), createModel({ id: 'gpt-4o-mini' })]
|
||||
expect(isVisionModels(models)).toBe(true)
|
||||
visionMock.mockReturnValueOnce(true).mockReturnValueOnce(false)
|
||||
expect(isVisionModels(models)).toBe(false)
|
||||
|
||||
expect(isGenerateImageModels(models)).toBe(true)
|
||||
generateImageMock.mockReturnValueOnce(true).mockReturnValueOnce(false)
|
||||
expect(isGenerateImageModels(models)).toBe(false)
|
||||
})
|
||||
|
||||
it('filters models for agent usage', () => {
|
||||
expect(agentModelFilter(createModel())).toBe(true)
|
||||
|
||||
embeddingMock.mockReturnValueOnce(true)
|
||||
expect(agentModelFilter(createModel({ id: 'text-embedding' }))).toBe(false)
|
||||
|
||||
embeddingMock.mockReturnValue(false)
|
||||
rerankMock.mockReturnValueOnce(true)
|
||||
expect(agentModelFilter(createModel({ id: 'rerank' }))).toBe(false)
|
||||
|
||||
rerankMock.mockReturnValue(false)
|
||||
textToImageMock.mockReturnValueOnce(true)
|
||||
expect(agentModelFilter(createModel({ id: 'gpt-image-1' }))).toBe(false)
|
||||
})
|
||||
|
||||
it('identifies models with maximum temperature of 1.0', () => {
|
||||
// Zhipu models should have max temperature of 1.0
|
||||
expect(isMaxTemperatureOneModel(createModel({ id: 'glm-4' }))).toBe(true)
|
||||
expect(isMaxTemperatureOneModel(createModel({ id: 'GLM-4-Plus' }))).toBe(true)
|
||||
expect(isMaxTemperatureOneModel(createModel({ id: 'glm-3-turbo' }))).toBe(true)
|
||||
|
||||
// Anthropic models should have max temperature of 1.0
|
||||
expect(isMaxTemperatureOneModel(createModel({ id: 'claude-3.5-sonnet' }))).toBe(true)
|
||||
expect(isMaxTemperatureOneModel(createModel({ id: 'Claude-3-opus' }))).toBe(true)
|
||||
expect(isMaxTemperatureOneModel(createModel({ id: 'claude-2.1' }))).toBe(true)
|
||||
|
||||
// Moonshot models should have max temperature of 1.0
|
||||
expect(isMaxTemperatureOneModel(createModel({ id: 'moonshot-1.0' }))).toBe(true)
|
||||
expect(isMaxTemperatureOneModel(createModel({ id: 'kimi-k2-thinking' }))).toBe(true)
|
||||
expect(isMaxTemperatureOneModel(createModel({ id: 'Moonshot-Pro' }))).toBe(true)
|
||||
|
||||
// Other models should return false
|
||||
expect(isMaxTemperatureOneModel(createModel({ id: 'gpt-4o' }))).toBe(false)
|
||||
expect(isMaxTemperatureOneModel(createModel({ id: 'gpt-4-turbo' }))).toBe(false)
|
||||
expect(isMaxTemperatureOneModel(createModel({ id: 'qwen-max' }))).toBe(false)
|
||||
expect(isMaxTemperatureOneModel(createModel({ id: 'gemini-pro' }))).toBe(false)
|
||||
})
|
||||
})
|
||||
@@ -1,311 +0,0 @@
|
||||
import { getProviderByModel } from '@renderer/services/AssistantService'
|
||||
import type { Model } from '@renderer/types'
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
import { isEmbeddingModel, isRerankModel } from '../embedding'
|
||||
import {
|
||||
isAutoEnableImageGenerationModel,
|
||||
isDedicatedImageGenerationModel,
|
||||
isGenerateImageModel,
|
||||
isImageEnhancementModel,
|
||||
isPureGenerateImageModel,
|
||||
isTextToImageModel,
|
||||
isVisionModel
|
||||
} from '../vision'
|
||||
|
||||
vi.mock('@renderer/hooks/useStore', () => ({
|
||||
getStoreProviders: vi.fn(() => [])
|
||||
}))
|
||||
|
||||
vi.mock('@renderer/store', () => ({
|
||||
__esModule: true,
|
||||
default: {
|
||||
getState: () => ({
|
||||
llm: { providers: [] },
|
||||
settings: {}
|
||||
})
|
||||
},
|
||||
useAppDispatch: vi.fn(),
|
||||
useAppSelector: vi.fn()
|
||||
}))
|
||||
|
||||
vi.mock('@renderer/store/settings', () => {
|
||||
const noop = vi.fn()
|
||||
return new Proxy(
|
||||
{},
|
||||
{
|
||||
get: (_target, prop) => {
|
||||
if (prop === 'initialState') {
|
||||
return {}
|
||||
}
|
||||
return noop
|
||||
}
|
||||
}
|
||||
)
|
||||
})
|
||||
|
||||
vi.mock('@renderer/hooks/useSettings', () => ({
|
||||
useSettings: vi.fn(() => ({})),
|
||||
useNavbarPosition: vi.fn(() => ({ navbarPosition: 'left' })),
|
||||
useMessageStyle: vi.fn(() => ({ isBubbleStyle: false })),
|
||||
getStoreSetting: vi.fn()
|
||||
}))
|
||||
|
||||
vi.mock('@renderer/services/AssistantService', () => ({
|
||||
getProviderByModel: vi.fn()
|
||||
}))
|
||||
|
||||
vi.mock('../embedding', () => ({
|
||||
isEmbeddingModel: vi.fn(),
|
||||
isRerankModel: vi.fn()
|
||||
}))
|
||||
|
||||
const createModel = (overrides: Partial<Model> = {}): Model => ({
|
||||
id: 'gpt-4o',
|
||||
name: 'gpt-4o',
|
||||
provider: 'openai',
|
||||
group: 'OpenAI',
|
||||
...overrides
|
||||
})
|
||||
|
||||
const providerMock = vi.mocked(getProviderByModel)
|
||||
const embeddingMock = vi.mocked(isEmbeddingModel)
|
||||
const rerankMock = vi.mocked(isRerankModel)
|
||||
|
||||
describe('vision helpers', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
providerMock.mockReturnValue({ type: 'openai-response' } as any)
|
||||
embeddingMock.mockReturnValue(false)
|
||||
rerankMock.mockReturnValue(false)
|
||||
})
|
||||
|
||||
describe('isGenerateImageModel', () => {
|
||||
it('returns false for embedding/rerank models or missing providers', () => {
|
||||
embeddingMock.mockReturnValueOnce(true)
|
||||
expect(isGenerateImageModel(createModel({ id: 'gpt-image-1' }))).toBe(false)
|
||||
|
||||
embeddingMock.mockReturnValue(false)
|
||||
rerankMock.mockReturnValueOnce(true)
|
||||
expect(isGenerateImageModel(createModel({ id: 'gpt-image-1' }))).toBe(false)
|
||||
|
||||
rerankMock.mockReturnValue(false)
|
||||
providerMock.mockReturnValueOnce(undefined as any)
|
||||
expect(isGenerateImageModel(createModel({ id: 'gpt-image-1' }))).toBe(false)
|
||||
})
|
||||
|
||||
it('detects OpenAI and third-party generative image models', () => {
|
||||
expect(isGenerateImageModel(createModel({ id: 'gpt-4o-mini' }))).toBe(true)
|
||||
|
||||
providerMock.mockReturnValue({ type: 'custom' } as any)
|
||||
expect(isGenerateImageModel(createModel({ id: 'gemini-2.5-flash-image' }))).toBe(true)
|
||||
})
|
||||
|
||||
it('returns false when openai-response model is not on allow list', () => {
|
||||
expect(isGenerateImageModel(createModel({ id: 'gpt-4.2-experimental' }))).toBe(false)
|
||||
})
|
||||
})
|
||||
|
||||
describe('isPureGenerateImageModel', () => {
|
||||
it('requires both generate and text-to-image support', () => {
|
||||
expect(isPureGenerateImageModel(createModel({ id: 'gpt-image-1' }))).toBe(true)
|
||||
expect(isPureGenerateImageModel(createModel({ id: 'gpt-4o' }))).toBe(false)
|
||||
expect(isPureGenerateImageModel(createModel({ id: 'gemini-2.5-flash-image-preview' }))).toBe(true)
|
||||
})
|
||||
})
|
||||
|
||||
describe('text-to-image helpers', () => {
|
||||
it('matches predefined keywords', () => {
|
||||
expect(isTextToImageModel(createModel({ id: 'midjourney-v6' }))).toBe(true)
|
||||
expect(isTextToImageModel(createModel({ id: 'gpt-4o' }))).toBe(false)
|
||||
})
|
||||
|
||||
it('detects models with restricted image size support and enhancement', () => {
|
||||
expect(isImageEnhancementModel(createModel({ id: 'qwen-image-edit' }))).toBe(true)
|
||||
expect(isImageEnhancementModel(createModel({ id: 'gpt-4o' }))).toBe(false)
|
||||
})
|
||||
|
||||
it('identifies dedicated and auto-enabled image generation models', () => {
|
||||
expect(isDedicatedImageGenerationModel(createModel({ id: 'grok-2-image-1212' }))).toBe(true)
|
||||
expect(isAutoEnableImageGenerationModel(createModel({ id: 'gemini-2.5-flash-image-ultra' }))).toBe(true)
|
||||
})
|
||||
|
||||
it('returns false when models are not in dedicated or auto-enable sets', () => {
|
||||
expect(isDedicatedImageGenerationModel(createModel({ id: 'gpt-4o' }))).toBe(false)
|
||||
expect(isAutoEnableImageGenerationModel(createModel({ id: 'gpt-4o' }))).toBe(false)
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
describe('isVisionModel', () => {
|
||||
it('returns false for embedding/rerank models and honors overrides', () => {
|
||||
embeddingMock.mockReturnValueOnce(true)
|
||||
expect(isVisionModel(createModel({ id: 'gpt-4o' }))).toBe(false)
|
||||
|
||||
embeddingMock.mockReturnValue(false)
|
||||
const disabled = createModel({
|
||||
id: 'gpt-4o',
|
||||
capabilities: [{ type: 'vision', isUserSelected: false }]
|
||||
})
|
||||
expect(isVisionModel(disabled)).toBe(false)
|
||||
|
||||
const forced = createModel({
|
||||
id: 'gpt-4o',
|
||||
capabilities: [{ type: 'vision', isUserSelected: true }]
|
||||
})
|
||||
expect(isVisionModel(forced)).toBe(true)
|
||||
})
|
||||
|
||||
it('matches doubao models by name and general regexes by id', () => {
|
||||
const doubao = createModel({
|
||||
id: 'custom-id',
|
||||
provider: 'doubao',
|
||||
name: 'Doubao-Seed-1-6-Lite-251015'
|
||||
})
|
||||
expect(isVisionModel(doubao)).toBe(true)
|
||||
|
||||
expect(isVisionModel(createModel({ id: 'gpt-4o-mini' }))).toBe(true)
|
||||
})
|
||||
|
||||
it('leverages image enhancement regex when standard vision regex does not match', () => {
|
||||
expect(isVisionModel(createModel({ id: 'qwen-image-edit' }))).toBe(true)
|
||||
})
|
||||
|
||||
it('returns false for doubao models that fail regex checks', () => {
|
||||
const doubao = createModel({ id: 'doubao-standard', provider: 'doubao', name: 'basic' })
|
||||
expect(isVisionModel(doubao)).toBe(false)
|
||||
})
|
||||
describe('Gemini Models', () => {
|
||||
it('should return true for gemini 1.5 models', () => {
|
||||
expect(
|
||||
isVisionModel({
|
||||
id: 'gemini-1.5-flash',
|
||||
name: '',
|
||||
provider: '',
|
||||
group: ''
|
||||
})
|
||||
).toBe(true)
|
||||
expect(
|
||||
isVisionModel({
|
||||
id: 'gemini-1.5-pro',
|
||||
name: '',
|
||||
provider: '',
|
||||
group: ''
|
||||
})
|
||||
).toBe(true)
|
||||
})
|
||||
|
||||
it('should return true for gemini 2.x models', () => {
|
||||
expect(
|
||||
isVisionModel({
|
||||
id: 'gemini-2.0-flash',
|
||||
name: '',
|
||||
provider: '',
|
||||
group: ''
|
||||
})
|
||||
).toBe(true)
|
||||
expect(
|
||||
isVisionModel({
|
||||
id: 'gemini-2.0-pro',
|
||||
name: '',
|
||||
provider: '',
|
||||
group: ''
|
||||
})
|
||||
).toBe(true)
|
||||
expect(
|
||||
isVisionModel({
|
||||
id: 'gemini-2.5-flash',
|
||||
name: '',
|
||||
provider: '',
|
||||
group: ''
|
||||
})
|
||||
).toBe(true)
|
||||
expect(
|
||||
isVisionModel({
|
||||
id: 'gemini-2.5-pro',
|
||||
name: '',
|
||||
provider: '',
|
||||
group: ''
|
||||
})
|
||||
).toBe(true)
|
||||
})
|
||||
|
||||
it('should return true for gemini latest models', () => {
|
||||
expect(
|
||||
isVisionModel({
|
||||
id: 'gemini-flash-latest',
|
||||
name: '',
|
||||
provider: '',
|
||||
group: ''
|
||||
})
|
||||
).toBe(true)
|
||||
expect(
|
||||
isVisionModel({
|
||||
id: 'gemini-pro-latest',
|
||||
name: '',
|
||||
provider: '',
|
||||
group: ''
|
||||
})
|
||||
).toBe(true)
|
||||
expect(
|
||||
isVisionModel({
|
||||
id: 'gemini-flash-lite-latest',
|
||||
name: '',
|
||||
provider: '',
|
||||
group: ''
|
||||
})
|
||||
).toBe(true)
|
||||
})
|
||||
|
||||
it('should return true for gemini 3 models', () => {
|
||||
// Preview versions
|
||||
expect(
|
||||
isVisionModel({
|
||||
id: 'gemini-3-pro-preview',
|
||||
name: '',
|
||||
provider: '',
|
||||
group: ''
|
||||
})
|
||||
).toBe(true)
|
||||
// Future stable versions
|
||||
expect(
|
||||
isVisionModel({
|
||||
id: 'gemini-3-flash',
|
||||
name: '',
|
||||
provider: '',
|
||||
group: ''
|
||||
})
|
||||
).toBe(true)
|
||||
expect(
|
||||
isVisionModel({
|
||||
id: 'gemini-3-pro',
|
||||
name: '',
|
||||
provider: '',
|
||||
group: ''
|
||||
})
|
||||
).toBe(true)
|
||||
})
|
||||
|
||||
it('should return true for gemini exp models', () => {
|
||||
expect(
|
||||
isVisionModel({
|
||||
id: 'gemini-exp-1206',
|
||||
name: '',
|
||||
provider: '',
|
||||
group: ''
|
||||
})
|
||||
).toBe(true)
|
||||
})
|
||||
|
||||
it('should return false for gemini 1.0 models', () => {
|
||||
expect(
|
||||
isVisionModel({
|
||||
id: 'gemini-1.0-pro',
|
||||
name: '',
|
||||
provider: '',
|
||||
group: ''
|
||||
})
|
||||
).toBe(false)
|
||||
})
|
||||
})
|
||||
})
|
||||
@@ -1,397 +0,0 @@
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
|
||||
const providerMock = vi.mocked(getProviderByModel)
|
||||
|
||||
vi.mock('@renderer/services/AssistantService', () => ({
|
||||
getProviderByModel: vi.fn(),
|
||||
getAssistantSettings: vi.fn(),
|
||||
getDefaultAssistant: vi.fn().mockReturnValue({
|
||||
id: 'default',
|
||||
name: 'Default Assistant',
|
||||
prompt: '',
|
||||
settings: {}
|
||||
})
|
||||
}))
|
||||
|
||||
const isEmbeddingModel = vi.hoisted(() => vi.fn())
|
||||
const isRerankModel = vi.hoisted(() => vi.fn())
|
||||
vi.mock('../embedding', () => ({
|
||||
isEmbeddingModel: (...args: any[]) => isEmbeddingModel(...args),
|
||||
isRerankModel: (...args: any[]) => isRerankModel(...args)
|
||||
}))
|
||||
|
||||
const isPureGenerateImageModel = vi.hoisted(() => vi.fn())
|
||||
const isTextToImageModel = vi.hoisted(() => vi.fn())
|
||||
const isGenerateImageModel = vi.hoisted(() => vi.fn())
|
||||
vi.mock('../vision', () => ({
|
||||
isPureGenerateImageModel: (...args: any[]) => isPureGenerateImageModel(...args),
|
||||
isTextToImageModel: (...args: any[]) => isTextToImageModel(...args),
|
||||
isGenerateImageModel: (...args: any[]) => isGenerateImageModel(...args),
|
||||
isModernGenerateImageModel: vi.fn()
|
||||
}))
|
||||
|
||||
const providerMocks = vi.hoisted(() => ({
|
||||
isGeminiProvider: vi.fn(),
|
||||
isNewApiProvider: vi.fn(),
|
||||
isOpenAICompatibleProvider: vi.fn(),
|
||||
isOpenAIProvider: vi.fn(),
|
||||
isVertexProvider: vi.fn(),
|
||||
isAwsBedrockProvider: vi.fn(),
|
||||
isAzureOpenAIProvider: vi.fn()
|
||||
}))
|
||||
|
||||
vi.mock('@renderer/utils/provider', () => providerMocks)
|
||||
|
||||
vi.mock('@renderer/hooks/useStore', () => ({
|
||||
getStoreProviders: vi.fn(() => [])
|
||||
}))
|
||||
|
||||
vi.mock('@renderer/store', () => ({
|
||||
__esModule: true,
|
||||
default: {
|
||||
getState: () => ({
|
||||
llm: { providers: [] },
|
||||
settings: {}
|
||||
})
|
||||
},
|
||||
useAppDispatch: vi.fn(),
|
||||
useAppSelector: vi.fn()
|
||||
}))
|
||||
|
||||
vi.mock('@renderer/store/settings', () => {
|
||||
const noop = vi.fn()
|
||||
return new Proxy(
|
||||
{},
|
||||
{
|
||||
get: (_target, prop) => {
|
||||
if (prop === 'initialState') {
|
||||
return {}
|
||||
}
|
||||
return noop
|
||||
}
|
||||
}
|
||||
)
|
||||
})
|
||||
|
||||
vi.mock('@renderer/hooks/useSettings', () => ({
|
||||
useSettings: vi.fn(() => ({})),
|
||||
useNavbarPosition: vi.fn(() => ({ navbarPosition: 'left' })),
|
||||
useMessageStyle: vi.fn(() => ({ isBubbleStyle: false })),
|
||||
getStoreSetting: vi.fn()
|
||||
}))
|
||||
|
||||
import { getProviderByModel } from '@renderer/services/AssistantService'
|
||||
import type { Model, Provider } from '@renderer/types'
|
||||
import { SystemProviderIds } from '@renderer/types'
|
||||
|
||||
import { isOpenAIDeepResearchModel } from '../openai'
|
||||
import {
|
||||
GEMINI_SEARCH_REGEX,
|
||||
isHunyuanSearchModel,
|
||||
isMandatoryWebSearchModel,
|
||||
isOpenAIWebSearchChatCompletionOnlyModel,
|
||||
isOpenAIWebSearchModel,
|
||||
isOpenRouterBuiltInWebSearchModel,
|
||||
isWebSearchModel
|
||||
} from '../websearch'
|
||||
|
||||
const createModel = (overrides: Partial<Model> = {}): Model => ({
|
||||
id: 'gpt-4o',
|
||||
name: 'gpt-4o',
|
||||
provider: 'openai',
|
||||
group: 'OpenAI',
|
||||
...overrides
|
||||
})
|
||||
|
||||
const createProvider = (overrides: Partial<Provider> = {}): Provider => ({
|
||||
id: 'openai',
|
||||
type: 'openai',
|
||||
name: 'OpenAI',
|
||||
apiKey: '',
|
||||
apiHost: '',
|
||||
models: [],
|
||||
...overrides
|
||||
})
|
||||
|
||||
const resetMocks = () => {
|
||||
providerMock.mockReturnValue(createProvider())
|
||||
isEmbeddingModel.mockReturnValue(false)
|
||||
isRerankModel.mockReturnValue(false)
|
||||
isPureGenerateImageModel.mockReturnValue(false)
|
||||
isTextToImageModel.mockReturnValue(false)
|
||||
providerMocks.isGeminiProvider.mockReturnValue(false)
|
||||
providerMocks.isNewApiProvider.mockReturnValue(false)
|
||||
providerMocks.isOpenAICompatibleProvider.mockReturnValue(false)
|
||||
providerMocks.isOpenAIProvider.mockReturnValue(false)
|
||||
}
|
||||
|
||||
describe('websearch helpers', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
resetMocks()
|
||||
})
|
||||
|
||||
describe('isOpenAIDeepResearchModel', () => {
|
||||
it('detects deep research ids for OpenAI only', () => {
|
||||
expect(isOpenAIDeepResearchModel(createModel({ id: 'openai/deep-research-preview' }))).toBe(true)
|
||||
expect(isOpenAIDeepResearchModel(createModel({ provider: 'openai', id: 'gpt-4o' }))).toBe(false)
|
||||
expect(isOpenAIDeepResearchModel(createModel({ provider: 'openrouter', id: 'deep-research' }))).toBe(false)
|
||||
})
|
||||
})
|
||||
|
||||
describe('isWebSearchModel', () => {
|
||||
it('returns false for embedding/rerank/image models', () => {
|
||||
isEmbeddingModel.mockReturnValueOnce(true)
|
||||
expect(isWebSearchModel(createModel())).toBe(false)
|
||||
|
||||
resetMocks()
|
||||
isRerankModel.mockReturnValueOnce(true)
|
||||
expect(isWebSearchModel(createModel())).toBe(false)
|
||||
|
||||
resetMocks()
|
||||
isTextToImageModel.mockReturnValueOnce(true)
|
||||
expect(isWebSearchModel(createModel())).toBe(false)
|
||||
})
|
||||
|
||||
it('honors user overrides', () => {
|
||||
const enabled = createModel({ capabilities: [{ type: 'web_search', isUserSelected: true }] })
|
||||
expect(isWebSearchModel(enabled)).toBe(true)
|
||||
|
||||
const disabled = createModel({ capabilities: [{ type: 'web_search', isUserSelected: false }] })
|
||||
expect(isWebSearchModel(disabled)).toBe(false)
|
||||
})
|
||||
|
||||
it('returns false when provider lookup fails', () => {
|
||||
providerMock.mockReturnValueOnce(undefined as any)
|
||||
expect(isWebSearchModel(createModel())).toBe(false)
|
||||
})
|
||||
|
||||
it('handles Anthropic providers on unsupported platforms', () => {
|
||||
providerMock.mockReturnValueOnce(createProvider({ id: SystemProviderIds['aws-bedrock'] }))
|
||||
const model = createModel({ id: 'claude-2-sonnet' })
|
||||
expect(isWebSearchModel(model)).toBe(false)
|
||||
})
|
||||
|
||||
it('returns true for first-party Anthropic provider', () => {
|
||||
providerMock.mockReturnValueOnce(createProvider({ id: 'anthropic' }))
|
||||
const model = createModel({ id: 'claude-3.5-sonnet-latest', provider: 'anthropic' })
|
||||
expect(isWebSearchModel(model)).toBe(true)
|
||||
})
|
||||
|
||||
it('detects OpenAI preview search models only when supported', () => {
|
||||
providerMocks.isOpenAIProvider.mockReturnValue(true)
|
||||
const model = createModel({ id: 'gpt-4o-search-preview' })
|
||||
expect(isWebSearchModel(model)).toBe(true)
|
||||
|
||||
const nonSearch = createModel({ id: 'gpt-4o-image' })
|
||||
expect(isWebSearchModel(nonSearch)).toBe(false)
|
||||
})
|
||||
|
||||
it('supports Perplexity sonar families including mandatory variants', () => {
|
||||
providerMock.mockReturnValueOnce(createProvider({ id: SystemProviderIds.perplexity }))
|
||||
expect(isWebSearchModel(createModel({ id: 'sonar-deep-research' }))).toBe(true)
|
||||
})
|
||||
|
||||
it('handles AIHubMix Gemini and OpenAI search models', () => {
|
||||
providerMock.mockReturnValueOnce(createProvider({ id: SystemProviderIds.aihubmix }))
|
||||
expect(isWebSearchModel(createModel({ id: 'gemini-2.5-pro-preview' }))).toBe(true)
|
||||
|
||||
providerMock.mockReturnValueOnce(createProvider({ id: SystemProviderIds.aihubmix }))
|
||||
const openaiSearch = createModel({ id: 'gpt-4o-search-preview' })
|
||||
expect(isWebSearchModel(openaiSearch)).toBe(true)
|
||||
})
|
||||
|
||||
it('supports OpenAI-compatible or new API providers for Gemini/OpenAI models', () => {
|
||||
const model = createModel({ id: 'gemini-2.5-flash-lite-latest' })
|
||||
providerMock.mockReturnValueOnce(createProvider({ id: 'custom' }))
|
||||
providerMocks.isOpenAICompatibleProvider.mockReturnValueOnce(true)
|
||||
expect(isWebSearchModel(model)).toBe(true)
|
||||
|
||||
resetMocks()
|
||||
providerMock.mockReturnValueOnce(createProvider({ id: 'custom' }))
|
||||
providerMocks.isNewApiProvider.mockReturnValueOnce(true)
|
||||
expect(isWebSearchModel(createModel({ id: 'gpt-4o-search-preview' }))).toBe(true)
|
||||
})
|
||||
|
||||
it('falls back to Gemini/Vertex provider regex matching', () => {
|
||||
providerMock.mockReturnValueOnce(createProvider({ id: SystemProviderIds.vertexai }))
|
||||
providerMocks.isGeminiProvider.mockReturnValueOnce(true)
|
||||
expect(isWebSearchModel(createModel({ id: 'gemini-2.0-flash-latest' }))).toBe(true)
|
||||
})
|
||||
|
||||
it('evaluates hunyuan/zhipu/dashscope/openrouter/grok providers', () => {
|
||||
providerMock.mockReturnValueOnce(createProvider({ id: 'hunyuan' }))
|
||||
expect(isWebSearchModel(createModel({ id: 'hunyuan-pro' }))).toBe(true)
|
||||
expect(isWebSearchModel(createModel({ id: 'hunyuan-lite', provider: 'hunyuan' }))).toBe(false)
|
||||
|
||||
providerMock.mockReturnValueOnce(createProvider({ id: 'zhipu' }))
|
||||
expect(isWebSearchModel(createModel({ id: 'glm-4-air' }))).toBe(true)
|
||||
|
||||
providerMock.mockReturnValueOnce(createProvider({ id: 'dashscope' }))
|
||||
expect(isWebSearchModel(createModel({ id: 'qwen-max-latest' }))).toBe(true)
|
||||
|
||||
providerMock.mockReturnValueOnce(createProvider({ id: 'openrouter' }))
|
||||
expect(isWebSearchModel(createModel())).toBe(true)
|
||||
|
||||
providerMock.mockReturnValueOnce(createProvider({ id: 'grok' }))
|
||||
expect(isWebSearchModel(createModel({ id: 'grok-2' }))).toBe(true)
|
||||
})
|
||||
})
|
||||
|
||||
describe('isMandatoryWebSearchModel', () => {
|
||||
it('requires sonar ids for perplexity/openrouter providers', () => {
|
||||
providerMock.mockReturnValueOnce(createProvider({ id: SystemProviderIds.perplexity }))
|
||||
expect(isMandatoryWebSearchModel(createModel({ id: 'sonar-pro' }))).toBe(true)
|
||||
|
||||
providerMock.mockReturnValueOnce(createProvider({ id: SystemProviderIds.openrouter }))
|
||||
expect(isMandatoryWebSearchModel(createModel({ id: 'sonar-reasoning' }))).toBe(true)
|
||||
|
||||
providerMock.mockReturnValueOnce(createProvider({ id: 'openai' }))
|
||||
expect(isMandatoryWebSearchModel(createModel({ id: 'sonar-pro' }))).toBe(false)
|
||||
})
|
||||
|
||||
it.each([
|
||||
['perplexity', 'non-sonar'],
|
||||
['openrouter', 'gpt-4o-search-preview']
|
||||
])('returns false for %s provider when id is %s', (providerId, modelId) => {
|
||||
providerMock.mockReturnValueOnce(createProvider({ id: providerId }))
|
||||
expect(isMandatoryWebSearchModel(createModel({ id: modelId }))).toBe(false)
|
||||
})
|
||||
})
|
||||
|
||||
describe('isOpenRouterBuiltInWebSearchModel', () => {
|
||||
it('checks for sonar ids or OpenAI chat-completion-only variants', () => {
|
||||
providerMock.mockReturnValueOnce(createProvider({ id: 'openrouter' }))
|
||||
expect(isOpenRouterBuiltInWebSearchModel(createModel({ id: 'sonar-reasoning' }))).toBe(true)
|
||||
|
||||
providerMock.mockReturnValueOnce(createProvider({ id: 'openrouter' }))
|
||||
expect(isOpenRouterBuiltInWebSearchModel(createModel({ id: 'gpt-4o-search-preview' }))).toBe(true)
|
||||
|
||||
providerMock.mockReturnValueOnce(createProvider({ id: 'custom' }))
|
||||
expect(isOpenRouterBuiltInWebSearchModel(createModel({ id: 'sonar-reasoning' }))).toBe(false)
|
||||
})
|
||||
})
|
||||
|
||||
describe('OpenAI web search helpers', () => {
|
||||
it('detects chat completion only variants and openai search ids', () => {
|
||||
expect(isOpenAIWebSearchChatCompletionOnlyModel(createModel({ id: 'gpt-4o-search-preview' }))).toBe(true)
|
||||
expect(isOpenAIWebSearchChatCompletionOnlyModel(createModel({ id: 'gpt-4o-mini-search-preview' }))).toBe(true)
|
||||
expect(isOpenAIWebSearchChatCompletionOnlyModel(createModel({ id: 'gpt-4o' }))).toBe(false)
|
||||
|
||||
expect(isOpenAIWebSearchModel(createModel({ id: 'gpt-4.1-turbo' }))).toBe(true)
|
||||
expect(isOpenAIWebSearchModel(createModel({ id: 'gpt-4o-image' }))).toBe(false)
|
||||
expect(isOpenAIWebSearchModel(createModel({ id: 'gpt-5.1-chat' }))).toBe(false)
|
||||
expect(isOpenAIWebSearchModel(createModel({ id: 'o3-mini' }))).toBe(true)
|
||||
})
|
||||
|
||||
it.each(['gpt-4.1-preview', 'gpt-4o-2024-05-13', 'o4-mini', 'gpt-5-explorer'])(
|
||||
'treats %s as an OpenAI web search model',
|
||||
(id) => {
|
||||
expect(isOpenAIWebSearchModel(createModel({ id }))).toBe(true)
|
||||
}
|
||||
)
|
||||
|
||||
it.each(['gpt-4o-image-preview', 'gpt-4.1-nano', 'gpt-5.1-chat', 'gpt-image-1'])(
|
||||
'excludes %s from OpenAI web search',
|
||||
(id) => {
|
||||
expect(isOpenAIWebSearchModel(createModel({ id }))).toBe(false)
|
||||
}
|
||||
)
|
||||
|
||||
it.each(['gpt-4o-search-preview', 'gpt-4o-mini-search-preview'])('flags %s as chat-completion-only', (id) => {
|
||||
expect(isOpenAIWebSearchChatCompletionOnlyModel(createModel({ id }))).toBe(true)
|
||||
})
|
||||
})
|
||||
|
||||
describe('isHunyuanSearchModel', () => {
|
||||
it('identifies hunyuan models except lite', () => {
|
||||
expect(isHunyuanSearchModel(createModel({ id: 'hunyuan-pro', provider: 'hunyuan' }))).toBe(true)
|
||||
expect(isHunyuanSearchModel(createModel({ id: 'hunyuan-lite', provider: 'hunyuan' }))).toBe(false)
|
||||
expect(isHunyuanSearchModel(createModel())).toBe(false)
|
||||
})
|
||||
|
||||
it.each(['hunyuan-standard', 'hunyuan-advanced'])('accepts %s', (suffix) => {
|
||||
expect(isHunyuanSearchModel(createModel({ id: suffix, provider: 'hunyuan' }))).toBe(true)
|
||||
})
|
||||
})
|
||||
|
||||
describe('provider-specific regex coverage', () => {
|
||||
it.each(['qwen-turbo', 'qwen-max-0919', 'qwen3-max', 'qwen-plus-2024', 'qwq-32b'])(
|
||||
'dashscope treats %s as searchable',
|
||||
(id) => {
|
||||
providerMock.mockReturnValue(createProvider({ id: 'dashscope' }))
|
||||
expect(isWebSearchModel(createModel({ id }))).toBe(true)
|
||||
}
|
||||
)
|
||||
|
||||
it.each(['qwen-1.5-chat', 'custom-model'])('dashscope ignores %s', (id) => {
|
||||
providerMock.mockReturnValue(createProvider({ id: 'dashscope' }))
|
||||
expect(isWebSearchModel(createModel({ id }))).toBe(false)
|
||||
})
|
||||
|
||||
it.each(['sonar', 'sonar-pro', 'sonar-reasoning-pro', 'sonar-deep-research'])(
|
||||
'perplexity provider supports %s',
|
||||
(id) => {
|
||||
providerMock.mockReturnValue(createProvider({ id: SystemProviderIds.perplexity }))
|
||||
expect(isWebSearchModel(createModel({ id }))).toBe(true)
|
||||
}
|
||||
)
|
||||
|
||||
it.each([
|
||||
'gemini-2.0-flash-latest',
|
||||
'gemini-2.5-flash-lite-latest',
|
||||
'gemini-flash-lite-latest',
|
||||
'gemini-pro-latest'
|
||||
])('Gemini provider supports %s', (id) => {
|
||||
providerMock.mockReturnValue(createProvider({ id: SystemProviderIds.vertexai }))
|
||||
providerMocks.isGeminiProvider.mockReturnValue(true)
|
||||
expect(isWebSearchModel(createModel({ id }))).toBe(true)
|
||||
})
|
||||
})
|
||||
|
||||
describe('Gemini Search Models', () => {
|
||||
describe('GEMINI_SEARCH_REGEX', () => {
|
||||
it('should match gemini 2.x models', () => {
|
||||
expect(GEMINI_SEARCH_REGEX.test('gemini-2.0-flash')).toBe(true)
|
||||
expect(GEMINI_SEARCH_REGEX.test('gemini-2.0-pro')).toBe(true)
|
||||
expect(GEMINI_SEARCH_REGEX.test('gemini-2.5-flash')).toBe(true)
|
||||
expect(GEMINI_SEARCH_REGEX.test('gemini-2.5-pro')).toBe(true)
|
||||
expect(GEMINI_SEARCH_REGEX.test('gemini-2.5-flash-latest')).toBe(true)
|
||||
expect(GEMINI_SEARCH_REGEX.test('gemini-2.5-pro-latest')).toBe(true)
|
||||
})
|
||||
|
||||
it('should match gemini latest models', () => {
|
||||
expect(GEMINI_SEARCH_REGEX.test('gemini-flash-latest')).toBe(true)
|
||||
expect(GEMINI_SEARCH_REGEX.test('gemini-pro-latest')).toBe(true)
|
||||
expect(GEMINI_SEARCH_REGEX.test('gemini-flash-lite-latest')).toBe(true)
|
||||
})
|
||||
|
||||
it('should match gemini 3 models', () => {
|
||||
// Preview versions
|
||||
expect(GEMINI_SEARCH_REGEX.test('gemini-3-pro-preview')).toBe(true)
|
||||
expect(GEMINI_SEARCH_REGEX.test('gemini-3-flash-preview')).toBe(true)
|
||||
expect(GEMINI_SEARCH_REGEX.test('gemini-3-pro-image-preview')).toBe(true)
|
||||
expect(GEMINI_SEARCH_REGEX.test('gemini-3-flash-image-preview')).toBe(true)
|
||||
// Future stable versions
|
||||
expect(GEMINI_SEARCH_REGEX.test('gemini-3-flash')).toBe(true)
|
||||
expect(GEMINI_SEARCH_REGEX.test('gemini-3-pro')).toBe(true)
|
||||
// Version with decimals
|
||||
expect(GEMINI_SEARCH_REGEX.test('gemini-3.0-flash')).toBe(true)
|
||||
expect(GEMINI_SEARCH_REGEX.test('gemini-3.0-pro')).toBe(true)
|
||||
expect(GEMINI_SEARCH_REGEX.test('gemini-3.5-flash-preview')).toBe(true)
|
||||
expect(GEMINI_SEARCH_REGEX.test('gemini-3.5-pro-image-preview')).toBe(true)
|
||||
})
|
||||
|
||||
it('should not match gemini 2.x image-preview models', () => {
|
||||
expect(GEMINI_SEARCH_REGEX.test('gemini-2.5-flash-image-preview')).toBe(false)
|
||||
expect(GEMINI_SEARCH_REGEX.test('gemini-2.0-pro-image-preview')).toBe(false)
|
||||
})
|
||||
|
||||
it('should not match older gemini models', () => {
|
||||
expect(GEMINI_SEARCH_REGEX.test('gemini-1.5-flash')).toBe(false)
|
||||
expect(GEMINI_SEARCH_REGEX.test('gemini-1.5-pro')).toBe(false)
|
||||
expect(GEMINI_SEARCH_REGEX.test('gemini-1.0-pro')).toBe(false)
|
||||
})
|
||||
})
|
||||
})
|
||||
})
|
||||
@@ -1,8 +1,6 @@
|
||||
export * from './default'
|
||||
export * from './embedding'
|
||||
export * from './logo'
|
||||
export * from './openai'
|
||||
export * from './qwen'
|
||||
export * from './reasoning'
|
||||
export * from './tooluse'
|
||||
export * from './utils'
|
||||
|
||||
@@ -1,107 +0,0 @@
|
||||
import type { Model } from '@renderer/types'
|
||||
import { getLowerBaseModelName } from '@renderer/utils'
|
||||
|
||||
export const OPENAI_NO_SUPPORT_DEV_ROLE_MODELS = ['o1-preview', 'o1-mini']
|
||||
|
||||
export function isOpenAILLMModel(model: Model): boolean {
|
||||
if (!model) {
|
||||
return false
|
||||
}
|
||||
const modelId = getLowerBaseModelName(model.id)
|
||||
|
||||
if (modelId.includes('gpt-4o-image')) {
|
||||
return false
|
||||
}
|
||||
if (isOpenAIReasoningModel(model)) {
|
||||
return true
|
||||
}
|
||||
if (modelId.includes('gpt')) {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
export function isOpenAIModel(model: Model): boolean {
|
||||
if (!model) {
|
||||
return false
|
||||
}
|
||||
const modelId = getLowerBaseModelName(model.id)
|
||||
|
||||
return modelId.includes('gpt') || isOpenAIReasoningModel(model)
|
||||
}
|
||||
|
||||
export const isGPT5ProModel = (model: Model) => {
|
||||
const modelId = getLowerBaseModelName(model.id)
|
||||
return modelId.includes('gpt-5-pro')
|
||||
}
|
||||
|
||||
export const isOpenAIOpenWeightModel = (model: Model) => {
|
||||
const modelId = getLowerBaseModelName(model.id)
|
||||
return modelId.includes('gpt-oss')
|
||||
}
|
||||
|
||||
export const isGPT5SeriesModel = (model: Model) => {
|
||||
const modelId = getLowerBaseModelName(model.id)
|
||||
return modelId.includes('gpt-5') && !modelId.includes('gpt-5.1')
|
||||
}
|
||||
|
||||
export const isGPT5SeriesReasoningModel = (model: Model) => {
|
||||
const modelId = getLowerBaseModelName(model.id)
|
||||
return isGPT5SeriesModel(model) && !modelId.includes('chat')
|
||||
}
|
||||
|
||||
export const isGPT51SeriesModel = (model: Model) => {
|
||||
const modelId = getLowerBaseModelName(model.id)
|
||||
return modelId.includes('gpt-5.1')
|
||||
}
|
||||
|
||||
export function isSupportVerbosityModel(model: Model): boolean {
|
||||
const modelId = getLowerBaseModelName(model.id)
|
||||
return (isGPT5SeriesModel(model) || isGPT51SeriesModel(model)) && !modelId.includes('chat')
|
||||
}
|
||||
|
||||
export function isOpenAIChatCompletionOnlyModel(model: Model): boolean {
|
||||
if (!model) {
|
||||
return false
|
||||
}
|
||||
|
||||
const modelId = getLowerBaseModelName(model.id)
|
||||
return (
|
||||
modelId.includes('gpt-4o-search-preview') ||
|
||||
modelId.includes('gpt-4o-mini-search-preview') ||
|
||||
modelId.includes('o1-mini') ||
|
||||
modelId.includes('o1-preview')
|
||||
)
|
||||
}
|
||||
|
||||
export function isOpenAIReasoningModel(model: Model): boolean {
|
||||
const modelId = getLowerBaseModelName(model.id, '/')
|
||||
return isSupportedReasoningEffortOpenAIModel(model) || modelId.includes('o1')
|
||||
}
|
||||
|
||||
export function isSupportedReasoningEffortOpenAIModel(model: Model): boolean {
|
||||
const modelId = getLowerBaseModelName(model.id)
|
||||
return (
|
||||
(modelId.includes('o1') && !(modelId.includes('o1-preview') || modelId.includes('o1-mini'))) ||
|
||||
modelId.includes('o3') ||
|
||||
modelId.includes('o4') ||
|
||||
modelId.includes('gpt-oss') ||
|
||||
((isGPT5SeriesModel(model) || isGPT51SeriesModel(model)) && !modelId.includes('chat'))
|
||||
)
|
||||
}
|
||||
|
||||
const OPENAI_DEEP_RESEARCH_MODEL_REGEX = /deep[-_]?research/
|
||||
|
||||
export function isOpenAIDeepResearchModel(model?: Model): boolean {
|
||||
if (!model) {
|
||||
return false
|
||||
}
|
||||
|
||||
const providerId = model.provider
|
||||
if (providerId !== 'openai' && providerId !== 'openai-chat') {
|
||||
return false
|
||||
}
|
||||
|
||||
const modelId = getLowerBaseModelName(model.id, '/')
|
||||
return OPENAI_DEEP_RESEARCH_MODEL_REGEX.test(modelId)
|
||||
}
|
||||
@@ -1,7 +0,0 @@
|
||||
import type { Model } from '@renderer/types'
|
||||
import { getLowerBaseModelName } from '@renderer/utils'
|
||||
|
||||
export const isQwenMTModel = (model: Model): boolean => {
|
||||
const modelId = getLowerBaseModelName(model.id)
|
||||
return modelId.includes('qwen-mt')
|
||||
}
|
||||
@@ -8,16 +8,9 @@ import type {
|
||||
import { getLowerBaseModelName, isUserSelectedModelType } from '@renderer/utils'
|
||||
|
||||
import { isEmbeddingModel, isRerankModel } from './embedding'
|
||||
import {
|
||||
isGPT5ProModel,
|
||||
isGPT5SeriesModel,
|
||||
isGPT51SeriesModel,
|
||||
isOpenAIDeepResearchModel,
|
||||
isOpenAIReasoningModel,
|
||||
isSupportedReasoningEffortOpenAIModel
|
||||
} from './openai'
|
||||
import { GEMINI_FLASH_MODEL_REGEX, isGemini3Model } from './utils'
|
||||
import { isGPT5ProModel, isGPT5SeriesModel, isGPT51SeriesModel } from './utils'
|
||||
import { isTextToImageModel } from './vision'
|
||||
import { GEMINI_FLASH_MODEL_REGEX, isOpenAIDeepResearchModel } from './websearch'
|
||||
|
||||
// Reasoning models
|
||||
export const REASONING_REGEX =
|
||||
@@ -37,7 +30,6 @@ export const MODEL_SUPPORTED_REASONING_EFFORT: ReasoningEffortConfig = {
|
||||
grok: ['low', 'high'] as const,
|
||||
grok4_fast: ['auto'] as const,
|
||||
gemini: ['low', 'medium', 'high', 'auto'] as const,
|
||||
gemini3: ['low', 'medium', 'high'] as const,
|
||||
gemini_pro: ['low', 'medium', 'high', 'auto'] as const,
|
||||
qwen: ['low', 'medium', 'high'] as const,
|
||||
qwen_thinking: ['low', 'medium', 'high'] as const,
|
||||
@@ -64,7 +56,6 @@ export const MODEL_SUPPORTED_OPTIONS: ThinkingOptionConfig = {
|
||||
grok4_fast: ['none', ...MODEL_SUPPORTED_REASONING_EFFORT.grok4_fast] as const,
|
||||
gemini: ['none', ...MODEL_SUPPORTED_REASONING_EFFORT.gemini] as const,
|
||||
gemini_pro: MODEL_SUPPORTED_REASONING_EFFORT.gemini_pro,
|
||||
gemini3: MODEL_SUPPORTED_REASONING_EFFORT.gemini3,
|
||||
qwen: ['none', ...MODEL_SUPPORTED_REASONING_EFFORT.qwen] as const,
|
||||
qwen_thinking: MODEL_SUPPORTED_REASONING_EFFORT.qwen_thinking,
|
||||
doubao: ['none', ...MODEL_SUPPORTED_REASONING_EFFORT.doubao] as const,
|
||||
@@ -115,9 +106,6 @@ const _getThinkModelType = (model: Model): ThinkingModelType => {
|
||||
} else {
|
||||
thinkingModelType = 'gemini_pro'
|
||||
}
|
||||
if (isGemini3Model(model)) {
|
||||
thinkingModelType = 'gemini3'
|
||||
}
|
||||
} else if (isSupportedReasoningEffortGrokModel(model)) thinkingModelType = 'grok'
|
||||
else if (isSupportedThinkingTokenQwenModel(model)) {
|
||||
if (isQwenAlwaysThinkModel(model)) {
|
||||
@@ -266,19 +254,11 @@ export function isGeminiReasoningModel(model?: Model): boolean {
|
||||
|
||||
// Gemini 支持思考模式的模型正则
|
||||
export const GEMINI_THINKING_MODEL_REGEX =
|
||||
/gemini-(?:2\.5.*(?:-latest)?|3(?:\.\d+)?-(?:flash|pro)(?:-preview)?|flash-latest|pro-latest|flash-lite-latest)(?:-[\w-]+)*$/i
|
||||
/gemini-(?:2\.5.*(?:-latest)?|3-(?:flash|pro)(?:-preview)?|flash-latest|pro-latest|flash-lite-latest)(?:-[\w-]+)*$/i
|
||||
|
||||
export const isSupportedThinkingTokenGeminiModel = (model: Model): boolean => {
|
||||
const modelId = getLowerBaseModelName(model.id, '/')
|
||||
if (GEMINI_THINKING_MODEL_REGEX.test(modelId)) {
|
||||
// gemini-3.x 的 image 模型支持思考模式
|
||||
if (isGemini3Model(model)) {
|
||||
if (modelId.includes('tts')) {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
// gemini-2.x 的 image/tts 模型不支持
|
||||
if (modelId.includes('image') || modelId.includes('tts')) {
|
||||
return false
|
||||
}
|
||||
@@ -402,12 +382,6 @@ export function isClaude45ReasoningModel(model: Model): boolean {
|
||||
return regex.test(modelId)
|
||||
}
|
||||
|
||||
export function isClaude4SeriesModel(model: Model): boolean {
|
||||
const modelId = getLowerBaseModelName(model.id, '/')
|
||||
const regex = /claude-(sonnet|opus|haiku)-4(?:[.-]\d+)?(?:-[\w-]+)?$/i
|
||||
return regex.test(modelId)
|
||||
}
|
||||
|
||||
export function isClaudeReasoningModel(model?: Model): boolean {
|
||||
if (!model) {
|
||||
return false
|
||||
@@ -555,6 +529,22 @@ export function isReasoningModel(model?: Model): boolean {
|
||||
return REASONING_REGEX.test(modelId) || false
|
||||
}
|
||||
|
||||
export function isOpenAIReasoningModel(model: Model): boolean {
|
||||
const modelId = getLowerBaseModelName(model.id, '/')
|
||||
return isSupportedReasoningEffortOpenAIModel(model) || modelId.includes('o1')
|
||||
}
|
||||
|
||||
export function isSupportedReasoningEffortOpenAIModel(model: Model): boolean {
|
||||
const modelId = getLowerBaseModelName(model.id)
|
||||
return (
|
||||
(modelId.includes('o1') && !(modelId.includes('o1-preview') || modelId.includes('o1-mini'))) ||
|
||||
modelId.includes('o3') ||
|
||||
modelId.includes('o4') ||
|
||||
modelId.includes('gpt-oss') ||
|
||||
((isGPT5SeriesModel(model) || isGPT51SeriesModel(model)) && !modelId.includes('chat'))
|
||||
)
|
||||
}
|
||||
|
||||
export const THINKING_TOKEN_MAP: Record<string, { min: number; max: number }> = {
|
||||
// Gemini models
|
||||
'gemini-2\\.5-flash-lite.*$': { min: 512, max: 24576 },
|
||||
|
||||
@@ -4,7 +4,7 @@ import { getLowerBaseModelName, isUserSelectedModelType } from '@renderer/utils'
|
||||
|
||||
import { isEmbeddingModel, isRerankModel } from './embedding'
|
||||
import { isDeepSeekHybridInferenceModel } from './reasoning'
|
||||
import { isTextToImageModel } from './vision'
|
||||
import { isPureGenerateImageModel, isTextToImageModel } from './vision'
|
||||
|
||||
// Tool calling models
|
||||
export const FUNCTION_CALLING_MODELS = [
|
||||
@@ -41,9 +41,7 @@ const FUNCTION_CALLING_EXCLUDED_MODELS = [
|
||||
'gemini-1(?:\\.[\\w-]+)?',
|
||||
'qwen-mt(?:-[\\w-]+)?',
|
||||
'gpt-5-chat(?:-[\\w-]+)?',
|
||||
'glm-4\\.5v',
|
||||
'gemini-2.5-flash-image(?:-[\\w-]+)?',
|
||||
'gemini-2.0-flash-preview-image-generation'
|
||||
'glm-4\\.5v'
|
||||
]
|
||||
|
||||
export const FUNCTION_CALLING_REGEX = new RegExp(
|
||||
@@ -52,7 +50,13 @@ export const FUNCTION_CALLING_REGEX = new RegExp(
|
||||
)
|
||||
|
||||
export function isFunctionCallingModel(model?: Model): boolean {
|
||||
if (!model || isEmbeddingModel(model) || isRerankModel(model) || isTextToImageModel(model)) {
|
||||
if (
|
||||
!model ||
|
||||
isEmbeddingModel(model) ||
|
||||
isRerankModel(model) ||
|
||||
isTextToImageModel(model) ||
|
||||
isPureGenerateImageModel(model)
|
||||
) {
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -62,6 +66,10 @@ export function isFunctionCallingModel(model?: Model): boolean {
|
||||
return isUserSelectedModelType(model, 'function_calling')!
|
||||
}
|
||||
|
||||
if (model.provider === 'qiniu') {
|
||||
return ['deepseek-v3-tool', 'deepseek-v3-0324', 'qwq-32b', 'qwen2.5-72b-instruct'].includes(modelId)
|
||||
}
|
||||
|
||||
if (model.provider === 'doubao' || modelId.includes('doubao')) {
|
||||
return FUNCTION_CALLING_REGEX.test(modelId) || FUNCTION_CALLING_REGEX.test(model.name)
|
||||
}
|
||||
|
||||
@@ -1,14 +1,43 @@
|
||||
import type OpenAI from '@cherrystudio/openai'
|
||||
import { isEmbeddingModel, isRerankModel } from '@renderer/config/models/embedding'
|
||||
import { type Model, SystemProviderIds } from '@renderer/types'
|
||||
import type { OpenAIVerbosity, ValidOpenAIVerbosity } from '@renderer/types/aiCoreTypes'
|
||||
import type { Model } from '@renderer/types'
|
||||
import { getLowerBaseModelName } from '@renderer/utils'
|
||||
|
||||
import { isOpenAIChatCompletionOnlyModel, isOpenAIOpenWeightModel, isOpenAIReasoningModel } from './openai'
|
||||
import { isQwenMTModel } from './qwen'
|
||||
import { WEB_SEARCH_PROMPT_FOR_OPENROUTER } from '../prompts'
|
||||
import { getWebSearchTools } from '../tools'
|
||||
import { isOpenAIReasoningModel } from './reasoning'
|
||||
import { isGenerateImageModel, isTextToImageModel, isVisionModel } from './vision'
|
||||
import { isOpenAIWebSearchChatCompletionOnlyModel } from './websearch'
|
||||
export const NOT_SUPPORTED_REGEX = /(?:^tts|whisper|speech)/i
|
||||
export const GEMINI_FLASH_MODEL_REGEX = new RegExp('gemini.*-flash.*$', 'i')
|
||||
|
||||
export const OPENAI_NO_SUPPORT_DEV_ROLE_MODELS = ['o1-preview', 'o1-mini']
|
||||
|
||||
export function isOpenAILLMModel(model: Model): boolean {
|
||||
if (!model) {
|
||||
return false
|
||||
}
|
||||
const modelId = getLowerBaseModelName(model.id)
|
||||
|
||||
if (modelId.includes('gpt-4o-image')) {
|
||||
return false
|
||||
}
|
||||
if (isOpenAIReasoningModel(model)) {
|
||||
return true
|
||||
}
|
||||
if (modelId.includes('gpt')) {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
export function isOpenAIModel(model: Model): boolean {
|
||||
if (!model) {
|
||||
return false
|
||||
}
|
||||
const modelId = getLowerBaseModelName(model.id)
|
||||
|
||||
return modelId.includes('gpt') || isOpenAIReasoningModel(model)
|
||||
}
|
||||
|
||||
export function isSupportFlexServiceTierModel(model: Model): boolean {
|
||||
if (!model) {
|
||||
@@ -23,6 +52,33 @@ export function isSupportedFlexServiceTier(model: Model): boolean {
|
||||
return isSupportFlexServiceTierModel(model)
|
||||
}
|
||||
|
||||
export function isSupportVerbosityModel(model: Model): boolean {
|
||||
const modelId = getLowerBaseModelName(model.id)
|
||||
return (isGPT5SeriesModel(model) || isGPT51SeriesModel(model)) && !modelId.includes('chat')
|
||||
}
|
||||
|
||||
export function isOpenAIChatCompletionOnlyModel(model: Model): boolean {
|
||||
if (!model) {
|
||||
return false
|
||||
}
|
||||
|
||||
const modelId = getLowerBaseModelName(model.id)
|
||||
return (
|
||||
modelId.includes('gpt-4o-search-preview') ||
|
||||
modelId.includes('gpt-4o-mini-search-preview') ||
|
||||
modelId.includes('o1-mini') ||
|
||||
modelId.includes('o1-preview')
|
||||
)
|
||||
}
|
||||
|
||||
export function isGrokModel(model?: Model): boolean {
|
||||
if (!model) {
|
||||
return false
|
||||
}
|
||||
const modelId = getLowerBaseModelName(model.id)
|
||||
return modelId.includes('grok')
|
||||
}
|
||||
|
||||
export function isSupportedModel(model: OpenAI.Models.Model): boolean {
|
||||
if (!model) {
|
||||
return false
|
||||
@@ -49,6 +105,53 @@ export function isNotSupportTemperatureAndTopP(model: Model): boolean {
|
||||
return false
|
||||
}
|
||||
|
||||
export function getOpenAIWebSearchParams(model: Model, isEnableWebSearch?: boolean): Record<string, any> {
|
||||
if (!isEnableWebSearch) {
|
||||
return {}
|
||||
}
|
||||
|
||||
const webSearchTools = getWebSearchTools(model)
|
||||
|
||||
if (model.provider === 'grok') {
|
||||
return {
|
||||
search_parameters: {
|
||||
mode: 'auto',
|
||||
return_citations: true,
|
||||
sources: [{ type: 'web' }, { type: 'x' }, { type: 'news' }]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (model.provider === 'hunyuan') {
|
||||
return { enable_enhancement: true, citation: true, search_info: true }
|
||||
}
|
||||
|
||||
if (model.provider === 'dashscope') {
|
||||
return {
|
||||
enable_search: true,
|
||||
search_options: {
|
||||
forced_search: true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (isOpenAIWebSearchChatCompletionOnlyModel(model)) {
|
||||
return {
|
||||
web_search_options: {}
|
||||
}
|
||||
}
|
||||
|
||||
if (model.provider === 'openrouter') {
|
||||
return {
|
||||
plugins: [{ id: 'web', search_prompts: WEB_SEARCH_PROMPT_FOR_OPENROUTER }]
|
||||
}
|
||||
}
|
||||
|
||||
return {
|
||||
tools: webSearchTools
|
||||
}
|
||||
}
|
||||
|
||||
export function isGemmaModel(model?: Model): boolean {
|
||||
if (!model) {
|
||||
return false
|
||||
@@ -58,14 +161,12 @@ export function isGemmaModel(model?: Model): boolean {
|
||||
return modelId.includes('gemma-') || model.group === 'Gemma'
|
||||
}
|
||||
|
||||
export function isZhipuModel(model: Model): boolean {
|
||||
const modelId = getLowerBaseModelName(model.id)
|
||||
return modelId.includes('glm') || model.provider === SystemProviderIds.zhipu
|
||||
}
|
||||
export function isZhipuModel(model?: Model): boolean {
|
||||
if (!model) {
|
||||
return false
|
||||
}
|
||||
|
||||
export function isMoonshotModel(model: Model): boolean {
|
||||
const modelId = getLowerBaseModelName(model.id)
|
||||
return ['moonshot', 'kimi'].some((m) => modelId.includes(m))
|
||||
return model.provider === 'zhipu'
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -111,6 +212,11 @@ export const isAnthropicModel = (model?: Model): boolean => {
|
||||
return modelId.startsWith('claude')
|
||||
}
|
||||
|
||||
export const isQwenMTModel = (model: Model): boolean => {
|
||||
const modelId = getLowerBaseModelName(model.id)
|
||||
return modelId.includes('qwen-mt')
|
||||
}
|
||||
|
||||
export const isNotSupportedTextDelta = (model: Model): boolean => {
|
||||
return isQwenMTModel(model)
|
||||
}
|
||||
@@ -119,22 +225,34 @@ export const isNotSupportSystemMessageModel = (model: Model): boolean => {
|
||||
return isQwenMTModel(model) || isGemmaModel(model)
|
||||
}
|
||||
|
||||
export const isGPT5SeriesModel = (model: Model) => {
|
||||
const modelId = getLowerBaseModelName(model.id)
|
||||
return modelId.includes('gpt-5') && !modelId.includes('gpt-5.1')
|
||||
}
|
||||
|
||||
export const isGPT5SeriesReasoningModel = (model: Model) => {
|
||||
const modelId = getLowerBaseModelName(model.id)
|
||||
return isGPT5SeriesModel(model) && !modelId.includes('chat')
|
||||
}
|
||||
|
||||
export const isGPT51SeriesModel = (model: Model) => {
|
||||
const modelId = getLowerBaseModelName(model.id)
|
||||
return modelId.includes('gpt-5.1')
|
||||
}
|
||||
|
||||
// GPT-5 verbosity configuration
|
||||
// gpt-5-pro only supports 'high', other GPT-5 models support all levels
|
||||
export const MODEL_SUPPORTED_VERBOSITY: Record<string, ValidOpenAIVerbosity[]> = {
|
||||
export const MODEL_SUPPORTED_VERBOSITY: Record<string, ('low' | 'medium' | 'high')[]> = {
|
||||
'gpt-5-pro': ['high'],
|
||||
default: ['low', 'medium', 'high']
|
||||
} as const
|
||||
}
|
||||
|
||||
export const getModelSupportedVerbosity = (model: Model): OpenAIVerbosity[] => {
|
||||
export const getModelSupportedVerbosity = (model: Model): ('low' | 'medium' | 'high')[] => {
|
||||
const modelId = getLowerBaseModelName(model.id)
|
||||
let supportedValues: ValidOpenAIVerbosity[]
|
||||
if (modelId.includes('gpt-5-pro')) {
|
||||
supportedValues = MODEL_SUPPORTED_VERBOSITY['gpt-5-pro']
|
||||
} else {
|
||||
supportedValues = MODEL_SUPPORTED_VERBOSITY.default
|
||||
return MODEL_SUPPORTED_VERBOSITY['gpt-5-pro']
|
||||
}
|
||||
return [undefined, ...supportedValues]
|
||||
return MODEL_SUPPORTED_VERBOSITY.default
|
||||
}
|
||||
|
||||
export const isGeminiModel = (model: Model) => {
|
||||
@@ -142,6 +260,11 @@ export const isGeminiModel = (model: Model) => {
|
||||
return modelId.includes('gemini')
|
||||
}
|
||||
|
||||
export const isOpenAIOpenWeightModel = (model: Model) => {
|
||||
const modelId = getLowerBaseModelName(model.id)
|
||||
return modelId.includes('gpt-oss')
|
||||
}
|
||||
|
||||
// zhipu 视觉推理模型用这组 special token 标记推理结果
|
||||
export const ZHIPU_RESULT_TOKENS = ['<|begin_of_box|>', '<|end_of_box|>'] as const
|
||||
|
||||
@@ -149,14 +272,7 @@ export const agentModelFilter = (model: Model): boolean => {
|
||||
return !isEmbeddingModel(model) && !isRerankModel(model) && !isTextToImageModel(model)
|
||||
}
|
||||
|
||||
export const isMaxTemperatureOneModel = (model: Model): boolean => {
|
||||
if (isZhipuModel(model) || isAnthropicModel(model) || isMoonshotModel(model)) {
|
||||
return true
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
export const isGemini3Model = (model: Model) => {
|
||||
export const isGPT5ProModel = (model: Model) => {
|
||||
const modelId = getLowerBaseModelName(model.id)
|
||||
return modelId.includes('gemini-3')
|
||||
return modelId.includes('gpt-5-pro')
|
||||
}
|
||||
|
||||
@@ -3,7 +3,6 @@ import type { Model } from '@renderer/types'
|
||||
import { getLowerBaseModelName, isUserSelectedModelType } from '@renderer/utils'
|
||||
|
||||
import { isEmbeddingModel, isRerankModel } from './embedding'
|
||||
import { isFunctionCallingModel } from './tooluse'
|
||||
|
||||
// Vision models
|
||||
const visionAllowedModels = [
|
||||
@@ -73,10 +72,12 @@ const VISION_REGEX = new RegExp(
|
||||
|
||||
// For middleware to identify models that must use the dedicated Image API
|
||||
const DEDICATED_IMAGE_MODELS = [
|
||||
'grok-2-image(?:-[\\w-]+)?',
|
||||
'dall-e(?:-[\\w-]+)?',
|
||||
'gpt-image-1(?:-[\\w-]+)?',
|
||||
'imagen(?:-[\\w-]+)?'
|
||||
'grok-2-image',
|
||||
'grok-2-image-1212',
|
||||
'grok-2-image-latest',
|
||||
'dall-e-3',
|
||||
'dall-e-2',
|
||||
'gpt-image-1'
|
||||
]
|
||||
|
||||
const IMAGE_ENHANCEMENT_MODELS = [
|
||||
@@ -84,22 +85,13 @@ const IMAGE_ENHANCEMENT_MODELS = [
|
||||
'qwen-image-edit',
|
||||
'gpt-image-1',
|
||||
'gemini-2.5-flash-image(?:-[\\w-]+)?',
|
||||
'gemini-2.0-flash-preview-image-generation',
|
||||
'gemini-3(?:\\.\\d+)?-pro-image(?:-[\\w-]+)?'
|
||||
'gemini-2.0-flash-preview-image-generation'
|
||||
]
|
||||
|
||||
const IMAGE_ENHANCEMENT_MODELS_REGEX = new RegExp(IMAGE_ENHANCEMENT_MODELS.join('|'), 'i')
|
||||
|
||||
const DEDICATED_IMAGE_MODELS_REGEX = new RegExp(DEDICATED_IMAGE_MODELS.join('|'), 'i')
|
||||
|
||||
// Models that should auto-enable image generation button when selected
|
||||
const AUTO_ENABLE_IMAGE_MODELS = [
|
||||
'gemini-2.5-flash-image(?:-[\\w-]+)?',
|
||||
'gemini-3(?:\\.\\d+)?-pro-image(?:-[\\w-]+)?',
|
||||
...DEDICATED_IMAGE_MODELS
|
||||
]
|
||||
|
||||
const AUTO_ENABLE_IMAGE_MODELS_REGEX = new RegExp(AUTO_ENABLE_IMAGE_MODELS.join('|'), 'i')
|
||||
const AUTO_ENABLE_IMAGE_MODELS = ['gemini-2.5-flash-image', ...DEDICATED_IMAGE_MODELS]
|
||||
|
||||
const OPENAI_TOOL_USE_IMAGE_GENERATION_MODELS = [
|
||||
'o3',
|
||||
@@ -113,34 +105,26 @@ const OPENAI_TOOL_USE_IMAGE_GENERATION_MODELS = [
|
||||
|
||||
const OPENAI_IMAGE_GENERATION_MODELS = [...OPENAI_TOOL_USE_IMAGE_GENERATION_MODELS, 'gpt-image-1']
|
||||
|
||||
const MODERN_IMAGE_MODELS = ['gemini-3(?:\\.\\d+)?-pro-image(?:-[\\w-]+)?']
|
||||
|
||||
const GENERATE_IMAGE_MODELS = [
|
||||
'gemini-2.0-flash-exp(?:-[\\w-]+)?',
|
||||
'gemini-2.5-flash-image(?:-[\\w-]+)?',
|
||||
'gemini-2.0-flash-exp',
|
||||
'gemini-2.0-flash-exp-image-generation',
|
||||
'gemini-2.0-flash-preview-image-generation',
|
||||
...MODERN_IMAGE_MODELS,
|
||||
'gemini-2.5-flash-image',
|
||||
...DEDICATED_IMAGE_MODELS
|
||||
]
|
||||
|
||||
const OPENAI_IMAGE_GENERATION_MODELS_REGEX = new RegExp(OPENAI_IMAGE_GENERATION_MODELS.join('|'), 'i')
|
||||
|
||||
const GENERATE_IMAGE_MODELS_REGEX = new RegExp(GENERATE_IMAGE_MODELS.join('|'), 'i')
|
||||
|
||||
const MODERN_GENERATE_IMAGE_MODELS_REGEX = new RegExp(MODERN_IMAGE_MODELS.join('|'), 'i')
|
||||
|
||||
export const isDedicatedImageGenerationModel = (model: Model): boolean => {
|
||||
if (!model) return false
|
||||
|
||||
const modelId = getLowerBaseModelName(model.id)
|
||||
return DEDICATED_IMAGE_MODELS_REGEX.test(modelId)
|
||||
return DEDICATED_IMAGE_MODELS.some((m) => modelId.includes(m))
|
||||
}
|
||||
|
||||
export const isAutoEnableImageGenerationModel = (model: Model): boolean => {
|
||||
if (!model) return false
|
||||
|
||||
const modelId = getLowerBaseModelName(model.id)
|
||||
return AUTO_ENABLE_IMAGE_MODELS_REGEX.test(modelId)
|
||||
return AUTO_ENABLE_IMAGE_MODELS.some((m) => modelId.includes(m))
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -162,44 +146,48 @@ export function isGenerateImageModel(model: Model): boolean {
|
||||
const modelId = getLowerBaseModelName(model.id, '/')
|
||||
|
||||
if (provider.type === 'openai-response') {
|
||||
return OPENAI_IMAGE_GENERATION_MODELS_REGEX.test(modelId) || GENERATE_IMAGE_MODELS_REGEX.test(modelId)
|
||||
return (
|
||||
OPENAI_IMAGE_GENERATION_MODELS.some((imageModel) => modelId.includes(imageModel)) ||
|
||||
GENERATE_IMAGE_MODELS.some((imageModel) => modelId.includes(imageModel))
|
||||
)
|
||||
}
|
||||
|
||||
return GENERATE_IMAGE_MODELS_REGEX.test(modelId)
|
||||
return GENERATE_IMAGE_MODELS.some((imageModel) => modelId.includes(imageModel))
|
||||
}
|
||||
|
||||
// TODO: refine the regex
|
||||
/**
|
||||
* 判断模型是否支持纯图片生成(不支持通过工具调用)
|
||||
* @param model
|
||||
* @returns
|
||||
*/
|
||||
export function isPureGenerateImageModel(model: Model): boolean {
|
||||
if (!isGenerateImageModel(model) && !isTextToImageModel(model)) {
|
||||
return false
|
||||
}
|
||||
|
||||
if (isFunctionCallingModel(model)) {
|
||||
if (!isGenerateImageModel(model) || !isTextToImageModel(model)) {
|
||||
return false
|
||||
}
|
||||
|
||||
const modelId = getLowerBaseModelName(model.id)
|
||||
if (GENERATE_IMAGE_MODELS_REGEX.test(modelId) && !MODERN_GENERATE_IMAGE_MODELS_REGEX.test(modelId)) {
|
||||
return true
|
||||
}
|
||||
|
||||
return !OPENAI_TOOL_USE_IMAGE_GENERATION_MODELS.some((m) => modelId.includes(m))
|
||||
return !OPENAI_TOOL_USE_IMAGE_GENERATION_MODELS.some((imageModel) => modelId.includes(imageModel))
|
||||
}
|
||||
|
||||
// TODO: refine the regex
|
||||
// Text to image models
|
||||
const TEXT_TO_IMAGE_REGEX = /flux|diffusion|stabilityai|sd-|dall|cogview|janus|midjourney|mj-|imagen|gpt-image/i
|
||||
const TEXT_TO_IMAGE_REGEX = /flux|diffusion|stabilityai|sd-|dall|cogview|janus|midjourney|mj-|image|gpt-image/i
|
||||
|
||||
export function isTextToImageModel(model: Model): boolean {
|
||||
const modelId = getLowerBaseModelName(model.id)
|
||||
return TEXT_TO_IMAGE_REGEX.test(modelId)
|
||||
}
|
||||
|
||||
// It's not used now
|
||||
// export function isNotSupportedImageSizeModel(model?: Model): boolean {
|
||||
// if (!model) {
|
||||
// return false
|
||||
// }
|
||||
|
||||
// const baseName = getLowerBaseModelName(model.id, '/')
|
||||
|
||||
// return baseName.includes('grok-2-image')
|
||||
// }
|
||||
|
||||
/**
|
||||
* 判断模型是否支持图片增强(包括编辑、增强、修复等)
|
||||
* @param model
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user