Compare commits

...

26 Commits

Author SHA1 Message Date
Vaayne
bf30d91cd1 🐛 fix: detect and handle migration tag mismatches
When a migration version exists but with a different tag (e.g., local
migration replaced by upstream), the system now:

- Detects the tag mismatch instead of silently skipping
- Logs a warning about the mismatch
- Re-applies the correct migration
- Uses upsert to update the migration record

This fixes issues where schema changes are skipped because a different
migration with the same version number was previously applied.
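
A minimal sketch of the flow described above, in TypeScript; the record shape and store API are assumptions, not the project's actual code:

```ts
// Hypothetical sketch of the tag-mismatch handling; names and schema are assumed.
interface MigrationRecord {
  version: number
  tag: string
}

interface MigrationStore {
  get(version: number): Promise<MigrationRecord | undefined>
  upsert(record: MigrationRecord): Promise<void> // insert or update by version
  exec(sql: string): Promise<void>
}

async function applyMigration(store: MigrationStore, incoming: MigrationRecord & { sql: string }): Promise<void> {
  const existing = await store.get(incoming.version)
  if (existing?.tag === incoming.tag) return // same version and tag: already applied
  if (existing) {
    // Same version, different tag: a local migration was replaced by an upstream one.
    console.warn(`migration ${incoming.version}: tag mismatch (${existing.tag} -> ${incoming.tag}), re-applying`)
  }
  await store.exec(incoming.sql) // (re-)apply the correct migration
  await store.upsert({ version: incoming.version, tag: incoming.tag }) // upsert the record
}
```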
2025-11-25 15:57:43 +08:00
SuYao
dc8df98929 fix: websearch button condition (#11440)
fix: button
2025-11-25 13:24:37 +08:00
fullex
0004a8cafe fix: respect enableMaxTokens setting when maxTokens is not configured (#11438)
* fix: respect enableMaxTokens setting when maxTokens is not configured

When enableMaxTokens is disabled, getMaxTokens() should return undefined
to let the API use its own default value, instead of forcing 4096 tokens.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>

* fix(modelParameters): handle max tokens when feature is disabled

Check if max tokens feature is enabled before returning undefined to ensure proper API behavior
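
A sketch of the corrected behavior; getMaxTokens is named in the commit, while the settings shape and default constant are assumptions:

```ts
const DEFAULT_MAX_TOKENS = 4096 // the value previously forced even when the feature was off

function getMaxTokens(settings: { enableMaxTokens?: boolean; maxTokens?: number }): number | undefined {
  // Feature disabled: return undefined so the API falls back to its own default.
  if (!settings.enableMaxTokens) return undefined
  return settings.maxTokens ?? DEFAULT_MAX_TOKENS
}
```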

---------

Co-authored-by: Claude <noreply@anthropic.com>
Co-authored-by: icarus <eurfelux@gmail.com>
2025-11-25 11:12:50 +08:00
defi-failure
1992363580 chore: bump version to 1.7.0-rc.2 (#11429) 2025-11-24 14:46:10 +08:00
defi-failure
c901771480 chore: update release notes for v1.7.0-rc.2 (#11426) 2025-11-24 11:30:40 +08:00
SuYao
475f718efb fix: improve error handling and display in AiSdkToChunkAdapter (#11423)
* fix: improve error handling and display in AiSdkToChunkAdapter

* fix: test
2025-11-24 10:57:51 +08:00
SuYao
2c3338939e feat: update Google and OpenAI SDKs with new features and fixes (#11395)
* feat: update Google and OpenAI SDKs with new features and fixes

- Updated Google SDK to ensure model paths are correctly formatted.
- Enhanced OpenAI SDK to include support for image URLs in chat responses.
- Added reasoning content handling in OpenAI chat responses and chunks.
- Introduced Azure Anthropic provider configuration for Claude integration.

* fix: azure error

* fix: lint

* fix: test

* fix: test

* fix type

* fix comment

* fix: redundant

* chore resolution

* fix: test

* fix: comment

* fix: comment

* fix

* feat: add OpenRouter reasoning middleware to support content filtering
2025-11-23 23:18:57 +08:00
槑囿脑袋
64ca3802a4 feat: support gemini 3 pro image preview (#11416)
feat: support gemini 3 pro preview
2025-11-23 21:40:22 +08:00
Phantom
fa361126b8 refactor: aisdk config (#11402)
* refactor: improve model filtering with todo for robust conversion

* refactor(aiCore): add AiSdkConfig type and update provider config handling

- Introduce new AiSdkConfig type in aiCoreTypes for better type safety
- Update provider factory and config to use AiSdkConfig consistently
- Simplify getAiSdkProviderId return type to string
- Add config validation in ModernAiProvider

* refactor(aiCore): move ai core types to dedicated module

Consolidate AI core type definitions into a dedicated module under aiCore/types. This improves code organization by keeping related types together and removes circular dependencies between modules. The change includes:
- Moving AiSdkConfig to aiCore/types
- Updating all imports to reference the new location
- Removing duplicate type definitions

* refactor(provider): add return type to createAiSdkProvider function
2025-11-23 21:12:57 +08:00
SuYao
49903a1567 Test/ai-core (#11307)
* test: 1

* test: 2

* test: 3

* format

* chore: move provider from config to utils

* fix: 4

* test: 5

* chore: redundant logic

* test: add reasoning model tests and improve provider options typings

* chore: format

* test 6

* chore: format

* test: 7

* test: 8

* fix: test

* fix: format and typecheck

* fix error

* test: isClaude4SeriesModel

* fix: test

* fix: test

---------

Co-authored-by: defi-failure <159208748+defi-failure@users.noreply.github.com>
2025-11-23 17:33:27 +08:00
Phantom
086b16a59c ci: update PR title in auto-i18n workflow to be more specific (#11406) 2025-11-23 11:48:44 +08:00
github-actions[bot]
e2562d8224 🤖 Weekly Automated Update: Nov 23, 2025 (#11412)
feat(bot): Weekly automated script run

Co-authored-by: EurFelux <59059173+EurFelux@users.noreply.github.com>
2025-11-23 11:47:54 +08:00
Phantom
c9be949853 fix: adjacent user messages appear when assistant message contains error only (#11390)
* feat(messages): add filter for error-only messages and their related pairs

Add new filter function to remove assistant messages containing only error blocks along with their associated user messages, identified by askId. This improves conversation quality by cleaning up error-only responses.
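
A hypothetical sketch of the filter; the real Message type in the codebase differs:

```ts
interface Message {
  id: string
  role: 'user' | 'assistant'
  askId?: string // on assistant messages: id of the user message being answered
  blocks: { type: string }[]
}

function filterErrorOnlyPairs(messages: Message[]): Message[] {
  // Collect ids of user messages whose assistant reply consists solely of error blocks
  const badAskIds = new Set(
    messages
      .filter((m) => m.role === 'assistant' && m.blocks.length > 0 && m.blocks.every((b) => b.type === 'error'))
      .map((m) => m.askId)
      .filter((id): id is string => id !== undefined)
  )
  // Drop both the error-only assistant messages and their originating user messages
  return messages.filter((m) =>
    m.role === 'assistant' ? !(m.askId && badAskIds.has(m.askId)) : !badAskIds.has(m.id)
  )
}
```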

* refactor(ConversationService): improve message filtering pipeline readability

Break down complex message filtering chain into clearly labeled steps
Add comments explaining each filtering step's purpose
Maintain same functionality while improving code maintainability

* test(messageUtils): add test cases for message filter utilities

* docs(messageUtils): correct jsdoc for filterUsefulMessages

* refactor(ConversationService): extract message filtering logic into pipeline method

Move message filtering steps into a dedicated static method to improve testability and maintainability. Add comprehensive tests to verify pipeline behavior.

* refactor(ConversationService): add logging and improve message filtering readability

Add logger service to track message pipeline output
Split filterUserRoleStartMessages into separate variable for better debugging
2025-11-22 23:00:13 +08:00
defi-failure
ebfb1c5abf fix: add missing execution state for approved tool permissions (#11394) 2025-11-22 21:45:42 +08:00
SuYao
c1f1d7996d test: add thinking budget token test (#11305)
* refactor: add thinking budget token test

* fix comment
2025-11-22 21:43:57 +08:00
Phantom
0a72c613af fix(openai): apply verbosity setting with type safety improvements (#10964)
* refactor(types): consolidate OpenAI types and improve type safety

- Move OpenAI-related types to aiCoreTypes.ts
- Rename FetchChatCompletionOptions to FetchChatCompletionRequestOptions
- Add proper type definitions for service tiers and verbosity
- Improve type guards for service tier checks

* refactor(api): rename options parameter to requestOptions for consistency

Update parameter name across multiple files to use requestOptions instead of options for better clarity and consistency in API calls

* refactor(aiCore): simplify OpenAI summary text handling and improve type safety

- Remove 'off' option from OpenAISummaryText type and use null instead
- Add migration to convert 'off' values to null
- Add utility function to convert undefined to null
- Update Selector component to handle null/undefined values
- Improve type safety in provider options and reasoning params

* fix(i18n): Auto update translations for PR #10964

* feat(utils): add notNull function to convert null to undefined

* refactor(utils): move defined and notNull functions to shared package

Consolidate utility functions into shared package to improve code organization and reuse

* Revert "fix(i18n): Auto update translations for PR #10964"

This reverts commit 68bd7eaac5.

* feat(i18n): add "off" translation and remove "performance" tier

Add "off" translation for multiple languages and remove "performance" service tier option from translations

* Apply suggestion from @EurFelux

* docs(types): clarify handling of undefined and null values

Add comments to explain that undefined is treated as default and null as explicitly off in OpenAIVerbosity and OpenAIServiceTier types. Also update type safety for OpenAIServiceTiers record.

* fix(migration): update migration version from 167 to 171 for removed type

* chore: update store version to 172

* fix(migrate): update migration version number from 171 to 172

* fix(i18n): Auto update translations for PR #10964

* refactor(types): improve type safety for verbosity handling

add NotUndefined and NotNull utility types to better handle null/undefined cases
clarify verbosity types in aiCoreTypes and update related utility functions

* refactor(types): replace null with undefined for verbosity values

Standardize on undefined instead of null for verbosity values to align with OpenAI API docs and improve type consistency

* refactor(aiCore): update OpenAI provider options type import and usage

* fix(openai): change summaryText default from null to 'auto'

Update OpenAI settings to use 'auto' as default summaryText value instead of null for consistency with API behavior. Remove 'off' option and add 'concise' option while maintaining type safety.

* refactor(OpenAISettingsGroup): extract service tier options type for better maintainability

* refactor(types): make SystemProviderIdTypeMap internal type

* docs(provider): clarify OpenAIServiceTier behavior for undefined vs null

Explain that undefined and null values for serviceTier should be treated differently since they affect whether the field appears in the response

* refactor(utils): rename utility functions for clarity

Rename `defined` to `toNullIfUndefined` and `notNull` to `toUndefinedIfNull` to better reflect their functionality
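
The behavior the new names describe, as a sketch inferred from the commit messages (signatures assumed):

```ts
// In this PR's convention, undefined omits a field from an OpenAI request,
// while null sends it as explicitly off.
export function toNullIfUndefined<T>(value: T | undefined): T | null {
  return value === undefined ? null : value
}

export function toUndefinedIfNull<T>(value: T | null): T | undefined {
  return value === null ? undefined : value
}
```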

* refactor(aiCore): extract service tier logic and improve type safety

Extract service tier validation logic into separate functions for better reusability
Add proper type annotations for provider options
Pass service tier parameter through provider option builders

* refactor(utils): comment out unused utility functions

Keep commented utility functions for potential future use while cleaning up current codebase

* fix(migration): update migration version number from 172 to 177

* docs(aiCoreTypes): clarify parameter passing behavior in OpenAI API

Update comments to consistently use 'undefined' instead of 'null' when describing parameter passing behavior in OpenAI API requests, as they share the same meaning in this context

---------

Co-authored-by: GitHub Action <action@github.com>
2025-11-22 21:41:12 +08:00
SuYao
a1ac3207f1 fix/anthropic-vertex (#11397)
* 100m

* feat: add web search header for Claude 4 series models

* fix: typo

* fix: identify model

---------

Co-authored-by: defi-failure <159208748+defi-failure@users.noreply.github.com>
2025-11-22 20:56:05 +08:00
Caelan
f98a063a8f Fix the issue where base64 images cannot be saved (#11398) 2025-11-22 20:20:02 +08:00
亢奋猫
1cb2af57ae refactor: optimize DatabaseManager and fix libsql crash issues (#11392)
* refactor: optimize DatabaseManager and fix libsql crash issues

Major improvements:
- Created DatabaseManager singleton to centralize database connection management
- Auto-initialize database in constructor (no manual initialization needed)
- Removed all manual initialize() and ensureInitialized() calls (47 occurrences)
- Simplified initialization logic (removed retry loops that could cause crashes)
- Removed unused close() and reinitialize() methods
- Reduced code from ~270 lines to 172 lines (-36%)

Key changes:
1. DatabaseManager.ts (new file):
   - Singleton pattern with auto-initialization
   - State management (INITIALIZING, INITIALIZED, FAILED)
   - Windows compatibility fixes (empty file detection, intMode: 'number')
   - Simplified waitForInitialization() logic

2. BaseService.ts:
   - Removed static initialize() and ensureInitialized() methods
   - Simplified database/rawClient getters to use DatabaseManager

3. Service classes (AgentService, SessionService, SessionMessageService):
   - Removed all initialize() methods
   - Removed all ensureInitialized() calls
   - Services now work out of the box

4. Main entry points (index.ts, server.ts):
   - Removed explicit database initialization calls
   - Database initializes automatically on first access

Benefits:
- Fixes Windows libsql crashes by removing dangerous retry logic
- Simpler API - no need to remember to call initialize()
- Better separation of concerns
- Cleaner codebase with 36% less code

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
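
A condensed sketch of the singleton described above; the state names come from the commit message, everything else (types, helper) is assumed:

```ts
type DbState = 'INITIALIZING' | 'INITIALIZED' | 'FAILED'

// Hypothetical stand-in for creating the libsql client and running migrations once.
async function openDatabase(): Promise<unknown> {
  return {}
}

class DatabaseManager {
  private static instance: DatabaseManager | null = null
  private state: DbState = 'INITIALIZING'
  private readonly ready: Promise<void>
  private db: unknown

  private constructor() {
    // Auto-initialize in the constructor: callers never invoke initialize() manually.
    this.ready = this.init()
  }

  static getInstance(): DatabaseManager {
    return (DatabaseManager.instance ??= new DatabaseManager())
  }

  private async init(): Promise<void> {
    try {
      this.db = await openDatabase()
      this.state = 'INITIALIZED'
    } catch (error) {
      this.state = 'FAILED' // no retry loop: retrying is what the commit blames for the crashes
      throw error
    }
  }

  async waitForInitialization(): Promise<void> {
    await this.ready
  }

  getDatabase(): unknown {
    if (this.state !== 'INITIALIZED') throw new Error('Database is still initializing')
    return this.db
  }
}
```

The follow-up commits below then await initialization once at startup (the "Database is still initializing" fix) instead of sprinkling ensureInitialized() calls through the services.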

* fix: wait for database initialization on app startup

Issue: "Database is still initializing" error on startup
Root cause: Synchronous database getter was called before async initialization completed

Solution:
- Explicitly wait for database initialization in main index.ts
- Import DatabaseManager and call getDatabase() to ensure initialization is complete
- This guarantees database is ready before any service methods are called

Changes:
- src/main/index.ts: Added explicit database initialization wait before API server check

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>

* refactor: use static import for getDatabaseManager

- Move import to top of file for better code organization
- Remove unnecessary dynamic import

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>

* refactor: streamline database access in service classes

- Replaced direct database access with asynchronous calls to getDatabase() in various service classes (AgentService, SessionService, SessionMessageService).
- Updated the main index.ts to utilize runAsyncFunction for API server initialization, ensuring proper handling of asynchronous database access.
- Improved code organization and readability by consolidating database access logic.

This change enhances the reliability of database interactions across the application and ensures that services are correctly initialized before use.

* refactor: remove redundant logging in ApiServer initialization

- Removed the logging statement for 'AgentService ready' during server initialization.
- This change streamlines the startup process by eliminating unnecessary log entries.

This update contributes to cleaner logs and improved readability during server startup.

* refactor: change getDatabase method to synchronous return type

- Updated the getDatabase method in DatabaseManager to return a synchronous LibSQLDatabase instance instead of a Promise.
- This change simplifies the database access pattern, aligning with the current initialization logic.

This refactor enhances code clarity and reduces unnecessary asynchronous handling in the database access layer.

* refactor: simplify sessionMessageRepository by removing transaction handling

- Removed transaction handling parameters from message persistence methods in sessionMessageRepository.
- Updated database access to use a direct call to getDatabase() instead of passing a transaction client.
- Streamlined the upsertMessage and persistExchange methods for improved clarity and reduced complexity.

This refactor enhances code readability and simplifies the database interaction logic.

---------

Co-authored-by: Claude <noreply@anthropic.com>
2025-11-22 09:12:11 +08:00
fullex
62309ae1bf fix: prevent EventEmitter memory leak in useApiServer hook (#11385)
Implement single instance IPC subscription pattern to resolve MaxListenersExceededWarning. Previously, each component using useApiServer would register a separate 'api-server:ready' listener, and React strict mode double rendering would quickly exceed the 10 listener limit.

Changes:
- Add module-level subscription manager with onReadyCallbacks Set
- Ensure only one IPC listener is registered regardless of component count
- Use useRef to maintain stable callback references
- Properly cleanup subscriptions when all components unmount

This maintains existing behavior while keeping listener count constant at 1.
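
A sketch of the pattern; the channel name comes from the commit, while the preload bridge shape is an assumption:

```ts
// Assumed preload API: `on` registers a handler and returns an unsubscribe function.
declare const ipcBridge: {
  on(channel: string, handler: () => void): () => void
}

type Listener = () => void
const onReadyCallbacks = new Set<Listener>()
let unsubscribeIpc: (() => void) | null = null

export function subscribeApiServerReady(callback: Listener): () => void {
  onReadyCallbacks.add(callback)
  if (!unsubscribeIpc) {
    // Exactly one IPC listener, shared by every hook instance.
    unsubscribeIpc = ipcBridge.on('api-server:ready', () => {
      for (const cb of onReadyCallbacks) cb()
    })
  }
  return () => {
    onReadyCallbacks.delete(callback)
    if (onReadyCallbacks.size === 0 && unsubscribeIpc) {
      unsubscribeIpc() // last subscriber gone: drop the IPC listener too
      unsubscribeIpc = null
    }
  }
}
```

Each component registers through this module, so the listener count stays at one no matter how many times Strict Mode remounts the hook.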

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-authored-by: Claude <noreply@anthropic.com>
2025-11-21 21:42:34 +08:00
defi-failure
c48f222cdb feat: add endpoint type support for cherryin provider (#11367)
* feat: add endpoint type support for cherryin provider

* chore: bump @cherrystudio/ai-sdk-provider version to 0.1.1

* chore: bump ai-sdk-provider version to 0.1.3
2025-11-21 21:42:08 +08:00
亢奋猫
cea0058f87 refactor: simplify knowledge base creation modal (#11371)
* test(knowledge): fix tests for knowledge base form modal refactoring

Update all test files to match the new vertical layout structure with button-based advanced settings toggle. Remove obsolete tests for deleted features.

Changes:
- Rewrite KnowledgeBaseFormModal.test.tsx for new button-toggle structure
- Remove tests for preprocess and rerank features from GeneralSettingsPanel
- Update AdvancedSettingsPanel tests with required props
- Update all snapshots to reflect new component structure
- Format test files according to biome rules

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>

* test(knowledge): simplify KnowledgeBaseFormModal button tests

Simplify button interaction tests to avoid text matching issues. Focus on testing behavior rather than implementation details.

Changes:
- Simplify advanced settings toggle test
- Simplify footer buttons test to check button count instead of text content
- Remove fragile text-based button selection

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>

---------

Co-authored-by: Claude <noreply@anthropic.com>
2025-11-21 21:34:34 +08:00
beyondkmp
852192dce6 feat: add Git Bash detection and requirement check for Windows agents (#11388)
* feat: add Git Bash detection and requirement check for Windows agents

- Add System_CheckGitBash IPC channel for detecting Git Bash installation
- Implement detection logic checking common installation paths and PATH environment
- Display non-closable error alert in AgentModal when Git Bash is not found
- Disable agent creation/edit button until Git Bash is installed
- Add recheck functionality to verify installation without restarting app

Git Bash is required for agents to function properly on Windows systems.
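
A sketch of the detection logic; the exact paths the real implementation checks are assumptions:

```ts
import { existsSync } from 'node:fs'
import { delimiter, join } from 'node:path'

const COMMON_GIT_BASH_PATHS = [
  'C:\\Program Files\\Git\\bin\\bash.exe',
  'C:\\Program Files (x86)\\Git\\bin\\bash.exe'
]

export function isGitBashInstalled(): boolean {
  // 1. Probe the common installation locations.
  if (COMMON_GIT_BASH_PATHS.some((p) => existsSync(p))) return true
  // 2. Fall back to scanning the PATH environment for a bash.exe.
  return (process.env.PATH ?? '')
    .split(delimiter)
    .some((dir) => existsSync(join(dir, 'bash.exe')))
}
```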

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>

* i18n: add Git Bash requirement translations for agent modal

- Add English translations for Git Bash detection warnings
- Add Simplified Chinese (zh-cn) translations
- Add Traditional Chinese (zh-tw) translations

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>

* format code

---------

Co-authored-by: Claude <noreply@anthropic.com>
2025-11-21 21:32:53 +08:00
Pleasure1234
eee49d1580 feat: add ChatGPT conversation import feature (#11272)
* feat: add ChatGPT conversation import feature

Introduces a new import workflow for ChatGPT conversations, including UI components, service logic, and i18n support for English, Simplified Chinese, and Traditional Chinese. Adds an import menu to data settings, a popup for file selection and progress, and a service to parse and store imported conversations as topics and messages.

* fix: ci failure

* refactor: import service and add modular importers

Refactored the import service to support a modular importer architecture. Moved ChatGPT import logic to a dedicated importer class and directory. Updated UI components and i18n descriptions for clarity. Removed unused Redux selector in ImportMenuSettings. This change enables easier addition of new importers in the future.
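
The described architecture suggests an interface along these lines (hypothetical shape, not the actual code):

```ts
// Each source format gets its own importer; the import service just dispatches.
interface ImportResult {
  topics: number
  messages: number
}

interface ConversationImporter {
  readonly id: string // e.g. 'chatgpt'
  canImport(fileName: string): boolean
  import(filePath: string, onProgress?: (done: number, total: number) => void): Promise<ImportResult>
}
```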

* Apply suggestion from @Copilot

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* fix: improve ChatGPT import UX and set model for assistant

Added a loading state and spinner for file selection in the ChatGPT import popup, with new translations for the 'selecting' state in en-us, zh-cn, and zh-tw locales. Also, set the model property for imported assistant messages to display the GPT-5 logo.

---------

Co-authored-by: SuYao <sy20010504@gmail.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2025-11-21 14:58:47 +08:00
SuYao
dcdd1bf852 refactor: replace renderToolContent function with ToolContent component for improved readability (#11300)
* refactor: replace renderToolContent function with ToolContent component for improved readability

* fix

* fix test
2025-11-21 09:55:46 +08:00
beyondkmp
a12b6bfeca feat: enable native language emoji search with CLDR data format (#11381)
* feat: add i18n support and local data to emoji picker

- Add emoji-picker-element-data package for offline-first emoji data
- Implement i18n translations for emoji picker UI (de, en, es, fr, ja, pt, ru, zh)
- Switch from CDN to local emoji data to improve performance and reliability
- Add locale mapping to match app language with emoji picker data
- Move emoji-picker-element import to EmojiPicker component for better encapsulation
- Use proper TypeScript types instead of 'any' for type safety

This improves user experience by providing localized emoji picker interface
and eliminating dependency on external CDN, ensuring the picker works offline.
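
A sketch of the wiring; Picker's dataSource and locale options are part of emoji-picker-element's documented API, but the locale mapping and data path below are assumptions:

```ts
import { Picker } from 'emoji-picker-element'

// Map the app language to one of the bundled data locales listed above.
const EMOJI_DATA_LOCALES = ['de', 'en', 'es', 'fr', 'ja', 'pt', 'ru', 'zh'] as const

function toEmojiLocale(appLanguage: string): string {
  const base = appLanguage.toLowerCase().split('-')[0]
  return (EMOJI_DATA_LOCALES as readonly string[]).includes(base) ? base : 'en'
}

const locale = toEmojiLocale('zh-CN') // -> 'zh'
const picker = new Picker({
  // Local data from emoji-picker-element-data instead of the default CDN;
  // the served path is hypothetical.
  dataSource: `/emoji-data/${locale}/cldr/data.json`,
  locale
})
document.body.appendChild(picker)
```

The follow-up commit switches that data from emojibase to CLDR format precisely so `locale` also drives search keywords, not just the picker's UI strings.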

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>

* feat: enable native language emoji search with CLDR data format

Switch from emojibase to CLDR format for emoji-picker-element data to support full multi-language search functionality. Users can now search for emojis in their native language (e.g., German users can search "Herz" for ❤️, Spanish users can search "corazón"). Also improves type safety by using the LanguageVarious type for locale mappings.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>

---------

Co-authored-by: Claude <noreply@anthropic.com>
2025-11-20 19:23:27 +08:00
180 changed files with 13214 additions and 3328 deletions

View File

@@ -77,7 +77,7 @@ jobs:
with:
token: ${{ secrets.GITHUB_TOKEN }} # Use the built-in GITHUB_TOKEN for bot actions
commit-message: "feat(bot): Weekly automated script run"
title: "🤖 Weekly Automated Update: ${{ env.CURRENT_DATE }}"
title: "🤖 Weekly Auto I18N Sync: ${{ env.CURRENT_DATE }}"
body: |
This PR includes changes generated by the weekly auto i18n.
Review the changes before merging.

View File

@@ -1,152 +0,0 @@
diff --git a/dist/index.js b/dist/index.js
index c2ef089c42e13a8ee4a833899a415564130e5d79..75efa7baafb0f019fb44dd50dec1641eee8879e7 100644
--- a/dist/index.js
+++ b/dist/index.js
@@ -471,7 +471,7 @@ function convertToGoogleGenerativeAIMessages(prompt, options) {
// src/get-model-path.ts
function getModelPath(modelId) {
- return modelId.includes("/") ? modelId : `models/${modelId}`;
+ return modelId.includes("models/") ? modelId : `models/${modelId}`;
}
// src/google-generative-ai-options.ts
diff --git a/dist/index.mjs b/dist/index.mjs
index d75c0cc13c41192408c1f3f2d29d76a7bffa6268..ada730b8cb97d9b7d4cb32883a1d1ff416404d9b 100644
--- a/dist/index.mjs
+++ b/dist/index.mjs
@@ -477,7 +477,7 @@ function convertToGoogleGenerativeAIMessages(prompt, options) {
// src/get-model-path.ts
function getModelPath(modelId) {
- return modelId.includes("/") ? modelId : `models/${modelId}`;
+ return modelId.includes("models/") ? modelId : `models/${modelId}`;
}
// src/google-generative-ai-options.ts
diff --git a/dist/internal/index.js b/dist/internal/index.js
index 277cac8dc734bea2fb4f3e9a225986b402b24f48..bb704cd79e602eb8b0cee1889e42497d59ccdb7a 100644
--- a/dist/internal/index.js
+++ b/dist/internal/index.js
@@ -432,7 +432,15 @@ function prepareTools({
var _a;
tools = (tools == null ? void 0 : tools.length) ? tools : void 0;
const toolWarnings = [];
- const isGemini2 = modelId.includes("gemini-2");
+ // These changes could be safely removed when @ai-sdk/google v3 released.
+ const isLatest = (
+ [
+ 'gemini-flash-latest',
+ 'gemini-flash-lite-latest',
+ 'gemini-pro-latest',
+ ]
+ ).some(id => id === modelId);
+ const isGemini2OrNewer = modelId.includes("gemini-2") || modelId.includes("gemini-3") || isLatest;
const supportsDynamicRetrieval = modelId.includes("gemini-1.5-flash") && !modelId.includes("-8b");
const supportsFileSearch = modelId.includes("gemini-2.5");
if (tools == null) {
@@ -458,7 +466,7 @@ function prepareTools({
providerDefinedTools.forEach((tool) => {
switch (tool.id) {
case "google.google_search":
- if (isGemini2) {
+ if (isGemini2OrNewer) {
googleTools2.push({ googleSearch: {} });
} else if (supportsDynamicRetrieval) {
googleTools2.push({
@@ -474,7 +482,7 @@ function prepareTools({
}
break;
case "google.url_context":
- if (isGemini2) {
+ if (isGemini2OrNewer) {
googleTools2.push({ urlContext: {} });
} else {
toolWarnings.push({
@@ -485,7 +493,7 @@ function prepareTools({
}
break;
case "google.code_execution":
- if (isGemini2) {
+ if (isGemini2OrNewer) {
googleTools2.push({ codeExecution: {} });
} else {
toolWarnings.push({
@@ -507,7 +515,7 @@ function prepareTools({
}
break;
case "google.vertex_rag_store":
- if (isGemini2) {
+ if (isGemini2OrNewer) {
googleTools2.push({
retrieval: {
vertex_rag_store: {
diff --git a/dist/internal/index.mjs b/dist/internal/index.mjs
index 03b7cc591be9b58bcc2e775a96740d9f98862a10..347d2c12e1cee79f0f8bb258f3844fb0522a6485 100644
--- a/dist/internal/index.mjs
+++ b/dist/internal/index.mjs
@@ -424,7 +424,15 @@ function prepareTools({
var _a;
tools = (tools == null ? void 0 : tools.length) ? tools : void 0;
const toolWarnings = [];
- const isGemini2 = modelId.includes("gemini-2");
+ // These changes could be safely removed when @ai-sdk/google v3 released.
+ const isLatest = (
+ [
+ 'gemini-flash-latest',
+ 'gemini-flash-lite-latest',
+ 'gemini-pro-latest',
+ ]
+ ).some(id => id === modelId);
+ const isGemini2OrNewer = modelId.includes("gemini-2") || modelId.includes("gemini-3") || isLatest;
const supportsDynamicRetrieval = modelId.includes("gemini-1.5-flash") && !modelId.includes("-8b");
const supportsFileSearch = modelId.includes("gemini-2.5");
if (tools == null) {
@@ -450,7 +458,7 @@ function prepareTools({
providerDefinedTools.forEach((tool) => {
switch (tool.id) {
case "google.google_search":
- if (isGemini2) {
+ if (isGemini2OrNewer) {
googleTools2.push({ googleSearch: {} });
} else if (supportsDynamicRetrieval) {
googleTools2.push({
@@ -466,7 +474,7 @@ function prepareTools({
}
break;
case "google.url_context":
- if (isGemini2) {
+ if (isGemini2OrNewer) {
googleTools2.push({ urlContext: {} });
} else {
toolWarnings.push({
@@ -477,7 +485,7 @@ function prepareTools({
}
break;
case "google.code_execution":
- if (isGemini2) {
+ if (isGemini2OrNewer) {
googleTools2.push({ codeExecution: {} });
} else {
toolWarnings.push({
@@ -499,7 +507,7 @@ function prepareTools({
}
break;
case "google.vertex_rag_store":
- if (isGemini2) {
+ if (isGemini2OrNewer) {
googleTools2.push({
retrieval: {
vertex_rag_store: {
@@ -1434,9 +1442,7 @@ var googleTools = {
vertexRagStore
};
export {
- GoogleGenerativeAILanguageModel,
getGroundingMetadataSchema,
- getUrlContextMetadataSchema,
- googleTools
+ getUrlContextMetadataSchema, GoogleGenerativeAILanguageModel, googleTools
};
//# sourceMappingURL=index.mjs.map
\ No newline at end of file

View File

@@ -0,0 +1,26 @@
diff --git a/dist/index.js b/dist/index.js
index dc7b74ba55337c491cdf1ab3e39ca68cc4187884..ace8c90591288e42c2957e93c9bf7984f1b22444 100644
--- a/dist/index.js
+++ b/dist/index.js
@@ -472,7 +472,7 @@ function convertToGoogleGenerativeAIMessages(prompt, options) {
// src/get-model-path.ts
function getModelPath(modelId) {
- return modelId.includes("/") ? modelId : `models/${modelId}`;
+ return modelId.includes("models/") ? modelId : `models/${modelId}`;
}
// src/google-generative-ai-options.ts
diff --git a/dist/index.mjs b/dist/index.mjs
index 8390439c38cb7eaeb52080862cd6f4c58509e67c..a7647f2e11700dff7e1c8d4ae8f99d3637010733 100644
--- a/dist/index.mjs
+++ b/dist/index.mjs
@@ -478,7 +478,7 @@ function convertToGoogleGenerativeAIMessages(prompt, options) {
// src/get-model-path.ts
function getModelPath(modelId) {
- return modelId.includes("/") ? modelId : `models/${modelId}`;
+ return modelId.includes("models/") ? modelId : `models/${modelId}`;
}
// src/google-generative-ai-options.ts

View File

@@ -1,131 +0,0 @@
diff --git a/dist/index.mjs b/dist/index.mjs
index b3f018730a93639aad7c203f15fb1aeb766c73f4..ade2a43d66e9184799d072153df61ef7be4ea110 100644
--- a/dist/index.mjs
+++ b/dist/index.mjs
@@ -296,7 +296,14 @@ var HuggingFaceResponsesLanguageModel = class {
metadata: huggingfaceOptions == null ? void 0 : huggingfaceOptions.metadata,
instructions: huggingfaceOptions == null ? void 0 : huggingfaceOptions.instructions,
...preparedTools && { tools: preparedTools },
- ...preparedToolChoice && { tool_choice: preparedToolChoice }
+ ...preparedToolChoice && { tool_choice: preparedToolChoice },
+ ...(huggingfaceOptions?.reasoningEffort != null && {
+ reasoning: {
+ ...(huggingfaceOptions?.reasoningEffort != null && {
+ effort: huggingfaceOptions.reasoningEffort,
+ }),
+ },
+ }),
};
return { args: baseArgs, warnings };
}
@@ -365,6 +372,20 @@ var HuggingFaceResponsesLanguageModel = class {
}
break;
}
+ case 'reasoning': {
+ for (const contentPart of part.content) {
+ content.push({
+ type: 'reasoning',
+ text: contentPart.text,
+ providerMetadata: {
+ huggingface: {
+ itemId: part.id,
+ },
+ },
+ });
+ }
+ break;
+ }
case "mcp_call": {
content.push({
type: "tool-call",
@@ -519,6 +540,11 @@ var HuggingFaceResponsesLanguageModel = class {
id: value.item.call_id,
toolName: value.item.name
});
+ } else if (value.item.type === 'reasoning') {
+ controller.enqueue({
+ type: 'reasoning-start',
+ id: value.item.id,
+ });
}
return;
}
@@ -570,6 +596,22 @@ var HuggingFaceResponsesLanguageModel = class {
});
return;
}
+ if (isReasoningDeltaChunk(value)) {
+ controller.enqueue({
+ type: 'reasoning-delta',
+ id: value.item_id,
+ delta: value.delta,
+ });
+ return;
+ }
+
+ if (isReasoningEndChunk(value)) {
+ controller.enqueue({
+ type: 'reasoning-end',
+ id: value.item_id,
+ });
+ return;
+ }
},
flush(controller) {
controller.enqueue({
@@ -593,7 +635,8 @@ var HuggingFaceResponsesLanguageModel = class {
var huggingfaceResponsesProviderOptionsSchema = z2.object({
metadata: z2.record(z2.string(), z2.string()).optional(),
instructions: z2.string().optional(),
- strictJsonSchema: z2.boolean().optional()
+ strictJsonSchema: z2.boolean().optional(),
+ reasoningEffort: z2.string().optional(),
});
var huggingfaceResponsesResponseSchema = z2.object({
id: z2.string(),
@@ -727,12 +770,31 @@ var responseCreatedChunkSchema = z2.object({
model: z2.string()
})
});
+var reasoningTextDeltaChunkSchema = z2.object({
+ type: z2.literal('response.reasoning_text.delta'),
+ item_id: z2.string(),
+ output_index: z2.number(),
+ content_index: z2.number(),
+ delta: z2.string(),
+ sequence_number: z2.number(),
+});
+
+var reasoningTextEndChunkSchema = z2.object({
+ type: z2.literal('response.reasoning_text.done'),
+ item_id: z2.string(),
+ output_index: z2.number(),
+ content_index: z2.number(),
+ text: z2.string(),
+ sequence_number: z2.number(),
+});
var huggingfaceResponsesChunkSchema = z2.union([
responseOutputItemAddedSchema,
responseOutputItemDoneSchema,
textDeltaChunkSchema,
responseCompletedChunkSchema,
responseCreatedChunkSchema,
+ reasoningTextDeltaChunkSchema,
+ reasoningTextEndChunkSchema,
z2.object({ type: z2.string() }).loose()
// fallback for unknown chunks
]);
@@ -751,6 +813,12 @@ function isResponseCompletedChunk(chunk) {
function isResponseCreatedChunk(chunk) {
return chunk.type === "response.created";
}
+function isReasoningDeltaChunk(chunk) {
+ return chunk.type === 'response.reasoning_text.delta';
+}
+function isReasoningEndChunk(chunk) {
+ return chunk.type === 'response.reasoning_text.done';
+}
// src/huggingface-provider.ts
function createHuggingFace(options = {}) {

View File

@@ -0,0 +1,140 @@
diff --git a/dist/index.js b/dist/index.js
index 73045a7d38faafdc7f7d2cd79d7ff0e2b031056b..8d948c9ac4ea4b474db9ef3c5491961e7fcf9a07 100644
--- a/dist/index.js
+++ b/dist/index.js
@@ -421,6 +421,17 @@ var OpenAICompatibleChatLanguageModel = class {
text: reasoning
});
}
+ if (choice.message.images) {
+ for (const image of choice.message.images) {
+ const match1 = image.image_url.url.match(/^data:([^;]+)/)
+ const match2 = image.image_url.url.match(/^data:[^;]*;base64,(.+)$/);
+ content.push({
+ type: 'file',
+ mediaType: match1 ? (match1[1] ?? 'image/jpeg') : 'image/jpeg',
+ data: match2 ? match2[1] : image.image_url.url,
+ });
+ }
+ }
if (choice.message.tool_calls != null) {
for (const toolCall of choice.message.tool_calls) {
content.push({
@@ -598,6 +609,17 @@ var OpenAICompatibleChatLanguageModel = class {
delta: delta.content
});
}
+ if (delta.images) {
+ for (const image of delta.images) {
+ const match1 = image.image_url.url.match(/^data:([^;]+)/)
+ const match2 = image.image_url.url.match(/^data:[^;]*;base64,(.+)$/);
+ controller.enqueue({
+ type: 'file',
+ mediaType: match1 ? (match1[1] ?? 'image/jpeg') : 'image/jpeg',
+ data: match2 ? match2[1] : image.image_url.url,
+ });
+ }
+ }
if (delta.tool_calls != null) {
for (const toolCallDelta of delta.tool_calls) {
const index = toolCallDelta.index;
@@ -765,6 +787,14 @@ var OpenAICompatibleChatResponseSchema = import_v43.z.object({
arguments: import_v43.z.string()
})
})
+ ).nullish(),
+ images: import_v43.z.array(
+ import_v43.z.object({
+ type: import_v43.z.literal('image_url'),
+ image_url: import_v43.z.object({
+ url: import_v43.z.string(),
+ })
+ })
).nullish()
}),
finish_reason: import_v43.z.string().nullish()
@@ -795,6 +825,14 @@ var createOpenAICompatibleChatChunkSchema = (errorSchema) => import_v43.z.union(
arguments: import_v43.z.string().nullish()
})
})
+ ).nullish(),
+ images: import_v43.z.array(
+ import_v43.z.object({
+ type: import_v43.z.literal('image_url'),
+ image_url: import_v43.z.object({
+ url: import_v43.z.string(),
+ })
+ })
).nullish()
}).nullish(),
finish_reason: import_v43.z.string().nullish()
diff --git a/dist/index.mjs b/dist/index.mjs
index 1c2b9560bbfbfe10cb01af080aeeed4ff59db29c..2c8ddc4fc9bfc5e7e06cfca105d197a08864c427 100644
--- a/dist/index.mjs
+++ b/dist/index.mjs
@@ -405,6 +405,17 @@ var OpenAICompatibleChatLanguageModel = class {
text: reasoning
});
}
+ if (choice.message.images) {
+ for (const image of choice.message.images) {
+ const match1 = image.image_url.url.match(/^data:([^;]+)/)
+ const match2 = image.image_url.url.match(/^data:[^;]*;base64,(.+)$/);
+ content.push({
+ type: 'file',
+ mediaType: match1 ? (match1[1] ?? 'image/jpeg') : 'image/jpeg',
+ data: match2 ? match2[1] : image.image_url.url,
+ });
+ }
+ }
if (choice.message.tool_calls != null) {
for (const toolCall of choice.message.tool_calls) {
content.push({
@@ -582,6 +593,17 @@ var OpenAICompatibleChatLanguageModel = class {
delta: delta.content
});
}
+ if (delta.images) {
+ for (const image of delta.images) {
+ const match1 = image.image_url.url.match(/^data:([^;]+)/)
+ const match2 = image.image_url.url.match(/^data:[^;]*;base64,(.+)$/);
+ controller.enqueue({
+ type: 'file',
+ mediaType: match1 ? (match1[1] ?? 'image/jpeg') : 'image/jpeg',
+ data: match2 ? match2[1] : image.image_url.url,
+ });
+ }
+ }
if (delta.tool_calls != null) {
for (const toolCallDelta of delta.tool_calls) {
const index = toolCallDelta.index;
@@ -749,6 +771,14 @@ var OpenAICompatibleChatResponseSchema = z3.object({
arguments: z3.string()
})
})
+ ).nullish(),
+ images: z3.array(
+ z3.object({
+ type: z3.literal('image_url'),
+ image_url: z3.object({
+ url: z3.string(),
+ })
+ })
).nullish()
}),
finish_reason: z3.string().nullish()
@@ -779,6 +809,14 @@ var createOpenAICompatibleChatChunkSchema = (errorSchema) => z3.union([
arguments: z3.string().nullish()
})
})
+ ).nullish(),
+ images: z3.array(
+ z3.object({
+ type: z3.literal('image_url'),
+ image_url: z3.object({
+ url: z3.string(),
+ })
+ })
).nullish()
}).nullish(),
finish_reason: z3.string().nullish()

View File

@@ -1,5 +1,5 @@
diff --git a/dist/index.js b/dist/index.js
- index 992c85ac6656e51c3471af741583533c5a7bf79f..83c05952a07aebb95fc6c62f9ddb8aa96b52ac0d 100644
+ index 7481f3b3511078068d87d03855b568b20bb86971..8ac5ec28d2f7ad1b3b0d3f8da945c75674e59637 100644
--- a/dist/index.js
+++ b/dist/index.js
@@ -274,6 +274,7 @@ var openaiChatResponseSchema = (0, import_provider_utils3.lazyValidator)(
@@ -18,7 +18,7 @@ index 992c85ac6656e51c3471af741583533c5a7bf79f..83c05952a07aebb95fc6c62f9ddb8aa9
tool_calls: import_v42.z.array(
import_v42.z.object({
index: import_v42.z.number(),
- @@ -785,6 +787,13 @@ var OpenAIChatLanguageModel = class {
+ @@ -795,6 +797,13 @@ var OpenAIChatLanguageModel = class {
if (text != null && text.length > 0) {
content.push({ type: "text", text });
}
@@ -32,7 +32,7 @@ index 992c85ac6656e51c3471af741583533c5a7bf79f..83c05952a07aebb95fc6c62f9ddb8aa9
for (const toolCall of (_a = choice.message.tool_calls) != null ? _a : []) {
content.push({
type: "tool-call",
- @@ -866,6 +875,7 @@ var OpenAIChatLanguageModel = class {
+ @@ -876,6 +885,7 @@ var OpenAIChatLanguageModel = class {
};
let metadataExtracted = false;
let isActiveText = false;
@@ -40,7 +40,7 @@ index 992c85ac6656e51c3471af741583533c5a7bf79f..83c05952a07aebb95fc6c62f9ddb8aa9
const providerMetadata = { openai: {} };
return {
stream: response.pipeThrough(
- @@ -923,6 +933,21 @@ var OpenAIChatLanguageModel = class {
+ @@ -933,6 +943,21 @@ var OpenAIChatLanguageModel = class {
return;
}
const delta = choice.delta;
@@ -62,7 +62,7 @@ index 992c85ac6656e51c3471af741583533c5a7bf79f..83c05952a07aebb95fc6c62f9ddb8aa9
if (delta.content != null) {
if (!isActiveText) {
controller.enqueue({ type: "text-start", id: "0" });
- @@ -1035,6 +1060,9 @@ var OpenAIChatLanguageModel = class {
+ @@ -1045,6 +1070,9 @@ var OpenAIChatLanguageModel = class {
}
},
flush(controller) {

View File

@@ -14,7 +14,7 @@
}
},
"enabled": true,
"includes": ["**/*.json", "!*.json", "!**/package.json"]
"includes": ["**/*.json", "!*.json", "!**/package.json", "!coverage/**"]
},
"css": {
"formatter": {
@@ -23,7 +23,7 @@
},
"files": {
"ignoreUnknown": false,
"includes": ["**", "!**/.claude/**"],
"includes": ["**", "!**/.claude/**", "!**/.vscode/**"],
"maxSize": 2097152
},
"formatter": {

View File

@@ -134,58 +134,66 @@ artifactBuildCompleted: scripts/artifact-build-completed.js
releaseInfo:
releaseNotes: |
<!--LANG:en-->
- What's New in v1.7.0-rc.1
- 🎉 MAJOR NEW FEATURE: AI Agents
- - Create and manage custom AI agents with specialized tools and permissions
- - Dedicated agent sessions with persistent SQLite storage, separate from regular chats
- - Real-time tool approval system - review and approve agent actions dynamically
- - MCP (Model Context Protocol) integration for connecting external tools
- - Slash commands support for quick agent interactions
- - OpenAI-compatible REST API for agent access
+ What's New in v1.7.0-rc.2
+ ✨ New Features:
+ - AI Providers: Added support for Hugging Face, Mistral, Perplexity, and SophNet
+ - Knowledge Base: OpenMinerU document preprocessor, full-text search in notes, enhanced tool selection
+ - Image & OCR: Intel OVMS painting provider and Intel OpenVINO (NPU) OCR support
+ - MCP Management: Redesigned interface with dual-column layout for easier management
+ - Languages: Added German language support
+ ⚡ Improvements:
+ - Upgraded to Electron 38.7.0
+ - Enhanced system shutdown handling and automatic update checks
+ - Improved proxy bypass rules
+ - AI Models: Added support for Gemini 3, Gemini 3 Pro with image preview, and GPT-5.1
+ - Import: ChatGPT conversation import feature
+ - Agent: Git Bash detection and requirement check for Windows agents
+ - Search: Native language emoji search with CLDR data format
+ - Provider: Endpoint type support for cherryin provider
+ - Debug: Local crash mini dump file for better diagnostics
+ 🐛 Important Bug Fixes:
+ - Fixed streaming response issues across multiple AI providers
+ - Fixed session list scrolling problems
+ - Fixed knowledge base deletion errors
+ - Error Handling: Improved error display in AiSdkToChunkAdapter
+ - Database: Optimized DatabaseManager and fixed libsql crash issues
+ - Memory: Fixed EventEmitter memory leak in useApiServer hook
+ - Messages: Fixed adjacent user messages appearing when assistant message contains error only
+ - Tools: Fixed missing execution state for approved tool permissions
+ - File Processing: Fixed "no such file" error for non-English filenames in open-mineru
+ - PDF: Fixed mineru PDF validation and 403 errors
+ - Images: Fixed base64 image save issues
+ - Search: Fixed URL context and web search capability
+ - Models: Added verbosity parameter support for GPT-5 models
+ - UI: Improved todo tool status icon visibility and colors
+ - Providers: Fixed api-host for vercel ai-gateway and gitcode update config
+ ⚡ Improvements:
+ - SDK: Updated Google and OpenAI SDKs with new features
+ - UI: Simplified knowledge base creation modal and agent creation form
+ - Tools: Replaced renderToolContent function with ToolContent component
+ - Architecture: Namespace tool call IDs with session ID to prevent conflicts
+ - Config: AI SDK configuration refactoring
<!--LANG:zh-CN-->
- v1.7.0-rc.1 新特性
- 🎉 重大更新:AI Agent 智能体系统
- - 创建和管理专属 AI Agent,配置专用工具和权限
- - 独立的 Agent 会话,使用 SQLite 持久化存储,与普通聊天分离
- - 实时工具审批系统 - 动态审查和批准 Agent 操作
- - MCP(模型上下文协议)集成,连接外部工具
- - 支持斜杠命令快速交互
- - 兼容 OpenAI 的 REST API 访问
+ v1.7.0-rc.2 新特性
+ ✨ 新功能:
+ - AI 提供商:新增 Hugging Face、Mistral、Perplexity 和 SophNet 支持
+ - 知识库:OpenMinerU 文档预处理器、笔记全文搜索、增强的工具选择
+ - 图像与 OCR:Intel OVMS 绘图提供商和 Intel OpenVINO (NPU) OCR 支持
+ - MCP 管理:重构管理界面,采用双列布局,更加方便管理
+ - 语言:新增德语支持
+ ⚡ 改进:
+ - 升级到 Electron 38.7.0
+ - 增强的系统关机处理和自动更新检查
+ - 改进的代理绕过规则
+ - AI 模型:新增 Gemini 3、Gemini 3 Pro 图像预览支持,以及 GPT-5.1
+ - 导入:ChatGPT 对话导入功能
+ - Agent:Windows Agent 的 Git Bash 检测和要求检查
+ - 搜索:支持本地语言 emoji 搜索(CLDR 数据格式)
+ - 提供商:cherryin provider 的端点类型支持
+ - 调试:启用本地崩溃 mini dump 文件,方便诊断
+ 🐛 重要修复:
+ - 修复多个 AI 提供商的流式响应问题
+ - 修复会话列表滚动问题
+ - 修复知识库删除错误
+ - 错误处理:改进 AiSdkToChunkAdapter 的错误显示
+ - 数据库:优化 DatabaseManager 并修复 libsql 崩溃问题
+ - 内存:修复 useApiServer hook 中的 EventEmitter 内存泄漏
+ - 消息:修复当助手消息仅包含错误时相邻用户消息出现的问题
+ - 工具:修复批准工具权限缺少执行状态的问题
+ - 文件处理:修复 open-mineru 处理非英文文件名时的"无此文件"错误
+ - PDF:修复 mineru PDF 验证和 403 错误
+ - 图片:修复 base64 图片保存问题
+ - 搜索:修复 URL 上下文和网络搜索功能
+ - 模型:为 GPT-5 模型添加 verbosity 参数支持
+ - UI:改进 todo 工具状态图标可见性和颜色
+ - 提供商:修复 vercel ai-gateway 和 gitcode 更新配置的 api-host
+ ⚡ 改进:
+ - SDK:更新 Google 和 OpenAI SDK,新增功能和修复
+ - UI:简化知识库创建模态框和 agent 创建表单
+ - 工具:用 ToolContent 组件替换 renderToolContent 函数,提升可读性
+ - 架构:用会话 ID 命名工具调用 ID 以防止冲突
+ - 配置:AI SDK 配置重构
<!--LANG:END-->

View File

@@ -1,6 +1,6 @@
{
"name": "CherryStudio",
"version": "1.7.0-rc.1",
"version": "1.7.0-rc.2",
"private": true,
"description": "A powerful AI assistant for producer.",
"main": "./out/main/index.js",
@@ -86,6 +86,7 @@
"@napi-rs/system-ocr": "patch:@napi-rs/system-ocr@npm%3A1.0.2#~/.yarn/patches/@napi-rs-system-ocr-npm-1.0.2-59e7a78e8b.patch",
"@paymoapp/electron-shutdown-handler": "^1.1.2",
"@strongtz/win32-arm64-msvc": "^0.4.7",
"emoji-picker-element-data": "^1",
"express": "^5.1.0",
"font-list": "^2.0.0",
"graceful-fs": "^4.2.11",
@@ -108,16 +109,17 @@
"@agentic/exa": "^7.3.3",
"@agentic/searxng": "^7.3.3",
"@agentic/tavily": "^7.3.3",
"@ai-sdk/amazon-bedrock": "^3.0.53",
"@ai-sdk/anthropic": "^2.0.44",
"@ai-sdk/amazon-bedrock": "^3.0.56",
"@ai-sdk/anthropic": "^2.0.45",
"@ai-sdk/cerebras": "^1.0.31",
"@ai-sdk/gateway": "^2.0.9",
"@ai-sdk/google": "patch:@ai-sdk/google@npm%3A2.0.36#~/.yarn/patches/@ai-sdk-google-npm-2.0.36-6f3cc06026.patch",
"@ai-sdk/google-vertex": "^3.0.68",
"@ai-sdk/huggingface": "patch:@ai-sdk/huggingface@npm%3A0.0.8#~/.yarn/patches/@ai-sdk-huggingface-npm-0.0.8-d4d0aaac93.patch",
"@ai-sdk/mistral": "^2.0.23",
"@ai-sdk/openai": "patch:@ai-sdk/openai@npm%3A2.0.64#~/.yarn/patches/@ai-sdk-openai-npm-2.0.64-48f99f5bf3.patch",
"@ai-sdk/perplexity": "^2.0.17",
"@ai-sdk/gateway": "^2.0.13",
"@ai-sdk/google": "patch:@ai-sdk/google@npm%3A2.0.40#~/.yarn/patches/@ai-sdk-google-npm-2.0.40-47e0eeee83.patch",
"@ai-sdk/google-vertex": "^3.0.72",
"@ai-sdk/huggingface": "^0.0.10",
"@ai-sdk/mistral": "^2.0.24",
"@ai-sdk/openai": "patch:@ai-sdk/openai@npm%3A2.0.71#~/.yarn/patches/@ai-sdk-openai-npm-2.0.71-a88ef00525.patch",
"@ai-sdk/perplexity": "^2.0.20",
"@ai-sdk/test-server": "^0.0.1",
"@ant-design/v5-patch-for-react-19": "^1.0.3",
"@anthropic-ai/sdk": "^0.41.0",
"@anthropic-ai/vertex-sdk": "patch:@anthropic-ai/vertex-sdk@npm%3A0.11.4#~/.yarn/patches/@anthropic-ai-vertex-sdk-npm-0.11.4-c19cb41edb.patch",
@@ -162,7 +164,7 @@
"@modelcontextprotocol/sdk": "^1.17.5",
"@mozilla/readability": "^0.6.0",
"@notionhq/client": "^2.2.15",
"@openrouter/ai-sdk-provider": "^1.2.0",
"@openrouter/ai-sdk-provider": "^1.2.5",
"@opentelemetry/api": "^1.9.0",
"@opentelemetry/core": "2.0.0",
"@opentelemetry/exporter-trace-otlp-http": "^0.200.0",
@@ -238,7 +240,7 @@
"@viz-js/lang-dot": "^1.0.5",
"@viz-js/viz": "^3.14.0",
"@xyflow/react": "^12.4.4",
"ai": "^5.0.90",
"ai": "^5.0.98",
"antd": "patch:antd@npm%3A5.27.0#~/.yarn/patches/antd-npm-5.27.0-aa91c36546.patch",
"archiver": "^7.0.1",
"async-mutex": "^0.5.0",
@@ -411,8 +413,11 @@
"@langchain/openai@npm:^0.3.16": "patch:@langchain/openai@npm%3A1.0.0#~/.yarn/patches/@langchain-openai-npm-1.0.0-474d0ad9d4.patch",
"@langchain/openai@npm:>=0.2.0 <0.7.0": "patch:@langchain/openai@npm%3A1.0.0#~/.yarn/patches/@langchain-openai-npm-1.0.0-474d0ad9d4.patch",
"@ai-sdk/openai@npm:2.0.64": "patch:@ai-sdk/openai@npm%3A2.0.64#~/.yarn/patches/@ai-sdk-openai-npm-2.0.64-48f99f5bf3.patch",
"@ai-sdk/openai@npm:^2.0.42": "patch:@ai-sdk/openai@npm%3A2.0.64#~/.yarn/patches/@ai-sdk-openai-npm-2.0.64-48f99f5bf3.patch",
"@ai-sdk/google@npm:2.0.36": "patch:@ai-sdk/google@npm%3A2.0.36#~/.yarn/patches/@ai-sdk-google-npm-2.0.36-6f3cc06026.patch"
"@ai-sdk/openai@npm:^2.0.42": "patch:@ai-sdk/openai@npm%3A2.0.71#~/.yarn/patches/@ai-sdk-openai-npm-2.0.71-a88ef00525.patch",
"@ai-sdk/google@npm:2.0.40": "patch:@ai-sdk/google@npm%3A2.0.40#~/.yarn/patches/@ai-sdk-google-npm-2.0.40-47e0eeee83.patch",
"@ai-sdk/openai@npm:2.0.71": "patch:@ai-sdk/openai@npm%3A2.0.71#~/.yarn/patches/@ai-sdk-openai-npm-2.0.71-a88ef00525.patch",
"@ai-sdk/openai-compatible@npm:1.0.27": "patch:@ai-sdk/openai-compatible@npm%3A1.0.27#~/.yarn/patches/@ai-sdk-openai-compatible-npm-1.0.27-06f74278cf.patch",
"@ai-sdk/openai-compatible@npm:^1.0.19": "patch:@ai-sdk/openai-compatible@npm%3A1.0.27#~/.yarn/patches/@ai-sdk-openai-compatible-npm-1.0.27-06f74278cf.patch"
},
"packageManager": "yarn@4.9.1",
"lint-staged": {

View File

@@ -1,6 +1,6 @@
{
"name": "@cherrystudio/ai-sdk-provider",
"version": "0.1.2",
"version": "0.1.3",
"description": "Cherry Studio AI SDK provider bundle with CherryIN routing.",
"keywords": [
"ai-sdk",
@@ -42,7 +42,7 @@
},
"dependencies": {
"@ai-sdk/provider": "^2.0.0",
"@ai-sdk/provider-utils": "^3.0.12"
"@ai-sdk/provider-utils": "^3.0.17"
},
"devDependencies": {
"tsdown": "^0.13.3",

View File

@@ -67,6 +67,10 @@ export interface CherryInProviderSettings {
* Optional static headers applied to every request.
*/
headers?: HeadersInput
+ /**
+ * Optional endpoint type to distinguish different endpoint behaviors.
+ */
+ endpointType?: 'openai' | 'openai-response' | 'anthropic' | 'gemini' | 'image-generation' | 'jina-rerank'
}
export interface CherryInProvider extends ProviderV2 {
@@ -151,7 +155,8 @@ export const createCherryIn = (options: CherryInProviderSettings = {}): CherryIn
baseURL = DEFAULT_CHERRYIN_BASE_URL,
anthropicBaseURL = DEFAULT_CHERRYIN_ANTHROPIC_BASE_URL,
geminiBaseURL = DEFAULT_CHERRYIN_GEMINI_BASE_URL,
- fetch
+ fetch,
+ endpointType
} = options
const getJsonHeaders = createJsonHeadersGetter(options)
@@ -205,7 +210,7 @@ export const createCherryIn = (options: CherryInProviderSettings = {}): CherryIn
fetch
})
- const createChatModel = (modelId: string, settings: OpenAIProviderSettings = {}) => {
+ const createChatModelByModelId = (modelId: string, settings: OpenAIProviderSettings = {}) => {
if (isAnthropicModel(modelId)) {
return createAnthropicModel(modelId)
}
@@ -223,6 +228,29 @@ export const createCherryIn = (options: CherryInProviderSettings = {}): CherryIn
})
}
+ const createChatModel = (modelId: string, settings: OpenAIProviderSettings = {}) => {
+ if (!endpointType) return createChatModelByModelId(modelId, settings)
+ switch (endpointType) {
+ case 'anthropic':
+ return createAnthropicModel(modelId)
+ case 'gemini':
+ return createGeminiModel(modelId)
+ case 'openai':
+ return createOpenAIChatModel(modelId)
+ case 'openai-response':
+ default:
+ return new OpenAIResponsesLanguageModel(modelId, {
+ provider: `${CHERRYIN_PROVIDER_NAME}.openai`,
+ url,
+ headers: () => ({
+ ...getJsonHeaders(),
+ ...settings.headers
+ }),
+ fetch
+ })
+ }
+ }
const createCompletionModel = (modelId: string, settings: OpenAIProviderSettings = {}) =>
new OpenAICompletionLanguageModel(modelId, {
provider: `${CHERRYIN_PROVIDER_NAME}.completion`,

View File

@@ -35,17 +35,17 @@
"peerDependencies": {
"@ai-sdk/google": "^2.0.36",
"@ai-sdk/openai": "^2.0.64",
"@cherrystudio/ai-sdk-provider": "^0.1.2",
"@cherrystudio/ai-sdk-provider": "^0.1.3",
"ai": "^5.0.26"
},
"dependencies": {
"@ai-sdk/anthropic": "^2.0.43",
"@ai-sdk/azure": "^2.0.66",
"@ai-sdk/deepseek": "^1.0.27",
"@ai-sdk/openai-compatible": "^1.0.26",
"@ai-sdk/anthropic": "^2.0.45",
"@ai-sdk/azure": "^2.0.73",
"@ai-sdk/deepseek": "^1.0.29",
"@ai-sdk/openai-compatible": "patch:@ai-sdk/openai-compatible@npm%3A1.0.27#~/.yarn/patches/@ai-sdk-openai-compatible-npm-1.0.27-06f74278cf.patch",
"@ai-sdk/provider": "^2.0.0",
"@ai-sdk/provider-utils": "^3.0.16",
"@ai-sdk/xai": "^2.0.31",
"@ai-sdk/provider-utils": "^3.0.17",
"@ai-sdk/xai": "^2.0.34",
"zod": "^4.1.5"
},
"devDependencies": {

View File

@@ -0,0 +1,180 @@
/**
* Mock Provider Instances
* Provides mock implementations for all supported AI providers
*/
import type { ImageModelV2, LanguageModelV2 } from '@ai-sdk/provider'
import { vi } from 'vitest'
/**
* Creates a mock language model with customizable behavior
*/
export function createMockLanguageModel(overrides?: Partial<LanguageModelV2>): LanguageModelV2 {
return {
specificationVersion: 'v1',
provider: 'mock-provider',
modelId: 'mock-model',
defaultObjectGenerationMode: 'tool',
doGenerate: vi.fn().mockResolvedValue({
text: 'Mock response text',
finishReason: 'stop',
usage: {
promptTokens: 10,
completionTokens: 20,
totalTokens: 30
},
rawCall: { rawPrompt: null, rawSettings: {} },
rawResponse: { headers: {} },
warnings: []
}),
doStream: vi.fn().mockReturnValue({
stream: (async function* () {
yield {
type: 'text-delta',
textDelta: 'Mock '
}
yield {
type: 'text-delta',
textDelta: 'streaming '
}
yield {
type: 'text-delta',
textDelta: 'response'
}
yield {
type: 'finish',
finishReason: 'stop',
usage: {
promptTokens: 10,
completionTokens: 15,
totalTokens: 25
}
}
})(),
rawCall: { rawPrompt: null, rawSettings: {} },
rawResponse: { headers: {} },
warnings: []
}),
...overrides
} as LanguageModelV2
}
/**
* Creates a mock image model with customizable behavior
*/
export function createMockImageModel(overrides?: Partial<ImageModelV2>): ImageModelV2 {
return {
specificationVersion: 'v2',
provider: 'mock-provider',
modelId: 'mock-image-model',
doGenerate: vi.fn().mockResolvedValue({
images: [
{
base64: 'mock-base64-image-data',
uint8Array: new Uint8Array([1, 2, 3, 4, 5]),
mimeType: 'image/png'
}
],
warnings: []
}),
...overrides
} as ImageModelV2
}
/**
* Mock provider configurations for testing
*/
export const mockProviderConfigs = {
openai: {
apiKey: 'sk-test-openai-key-123456789',
baseURL: 'https://api.openai.com/v1',
organization: 'test-org'
},
anthropic: {
apiKey: 'sk-ant-test-key-123456789',
baseURL: 'https://api.anthropic.com'
},
google: {
apiKey: 'test-google-api-key-123456789',
baseURL: 'https://generativelanguage.googleapis.com/v1'
},
xai: {
apiKey: 'xai-test-key-123456789',
baseURL: 'https://api.x.ai/v1'
},
azure: {
apiKey: 'test-azure-key-123456789',
resourceName: 'test-resource',
deployment: 'test-deployment'
},
deepseek: {
apiKey: 'sk-test-deepseek-key-123456789',
baseURL: 'https://api.deepseek.com/v1'
},
openrouter: {
apiKey: 'sk-or-test-key-123456789',
baseURL: 'https://openrouter.ai/api/v1'
},
huggingface: {
apiKey: 'hf_test_key_123456789',
baseURL: 'https://api-inference.huggingface.co'
},
'openai-compatible': {
apiKey: 'test-compatible-key-123456789',
baseURL: 'https://api.example.com/v1',
name: 'test-provider'
},
'openai-chat': {
apiKey: 'sk-test-chat-key-123456789',
baseURL: 'https://api.openai.com/v1'
}
} as const
/**
* Mock provider instances for testing
*/
export const mockProviderInstances = {
openai: {
name: 'openai-mock',
languageModel: createMockLanguageModel({ provider: 'openai', modelId: 'gpt-4' }),
imageModel: createMockImageModel({ provider: 'openai', modelId: 'dall-e-3' })
},
anthropic: {
name: 'anthropic-mock',
languageModel: createMockLanguageModel({ provider: 'anthropic', modelId: 'claude-3-5-sonnet-20241022' })
},
google: {
name: 'google-mock',
languageModel: createMockLanguageModel({ provider: 'google', modelId: 'gemini-2.0-flash-exp' }),
imageModel: createMockImageModel({ provider: 'google', modelId: 'imagen-3.0-generate-001' })
},
xai: {
name: 'xai-mock',
languageModel: createMockLanguageModel({ provider: 'xai', modelId: 'grok-2-latest' }),
imageModel: createMockImageModel({ provider: 'xai', modelId: 'grok-2-image-latest' })
},
deepseek: {
name: 'deepseek-mock',
languageModel: createMockLanguageModel({ provider: 'deepseek', modelId: 'deepseek-chat' })
}
}
export type ProviderId = keyof typeof mockProviderConfigs

View File

@@ -0,0 +1,331 @@
/**
* Mock Responses
* Provides realistic mock responses for all provider types
*/
import { jsonSchema, type ModelMessage, type Tool } from 'ai'
/**
* Standard test messages for all scenarios
*/
export const testMessages = {
simple: [{ role: 'user' as const, content: 'Hello, how are you?' }],
conversation: [
{ role: 'user' as const, content: 'What is the capital of France?' },
{ role: 'assistant' as const, content: 'The capital of France is Paris.' },
{ role: 'user' as const, content: 'What is its population?' }
],
withSystem: [
{ role: 'system' as const, content: 'You are a helpful assistant that provides concise answers.' },
{ role: 'user' as const, content: 'Explain quantum computing in one sentence.' }
],
withImages: [
{
role: 'user' as const,
content: [
{ type: 'text' as const, text: 'What is in this image?' },
{
type: 'image' as const,
image:
'data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg=='
}
]
}
],
toolUse: [{ role: 'user' as const, content: 'What is the weather in San Francisco?' }],
multiTurn: [
{ role: 'user' as const, content: 'Can you help me with a math problem?' },
{ role: 'assistant' as const, content: 'Of course! What math problem would you like help with?' },
{ role: 'user' as const, content: 'What is 15 * 23?' },
{ role: 'assistant' as const, content: '15 * 23 = 345' },
{ role: 'user' as const, content: 'Now divide that by 5' }
]
} satisfies Record<string, ModelMessage[]>
/**
* Standard test tools for tool calling scenarios
*/
export const testTools: Record<string, Tool> = {
getWeather: {
description: 'Get the current weather in a given location',
inputSchema: jsonSchema({
type: 'object',
properties: {
location: {
type: 'string',
description: 'The city and state, e.g. San Francisco, CA'
},
unit: {
type: 'string',
enum: ['celsius', 'fahrenheit'],
description: 'The temperature unit to use'
}
},
required: ['location']
}),
execute: async ({ location, unit = 'fahrenheit' }) => {
return {
location,
temperature: unit === 'celsius' ? 22 : 72,
unit,
condition: 'sunny'
}
}
},
calculate: {
description: 'Perform a mathematical calculation',
inputSchema: jsonSchema({
type: 'object',
properties: {
operation: {
type: 'string',
enum: ['add', 'subtract', 'multiply', 'divide'],
description: 'The operation to perform'
},
a: {
type: 'number',
description: 'The first number'
},
b: {
type: 'number',
description: 'The second number'
}
},
required: ['operation', 'a', 'b']
}),
execute: async ({ operation, a, b }) => {
const operations = {
add: (x: number, y: number) => x + y,
subtract: (x: number, y: number) => x - y,
multiply: (x: number, y: number) => x * y,
divide: (x: number, y: number) => x / y
}
return { result: operations[operation as keyof typeof operations](a, b) }
}
},
searchDatabase: {
description: 'Search for information in a database',
inputSchema: jsonSchema({
type: 'object',
properties: {
query: {
type: 'string',
description: 'The search query'
},
limit: {
type: 'number',
description: 'Maximum number of results to return',
default: 10
}
},
required: ['query']
}),
execute: async ({ query, limit = 10 }) => {
return {
results: [
{ id: 1, title: `Result 1 for ${query}`, relevance: 0.95 },
{ id: 2, title: `Result 2 for ${query}`, relevance: 0.87 }
].slice(0, limit)
}
}
}
}
/**
* Mock streaming chunks for different providers
*/
export const mockStreamingChunks = {
text: [
{ type: 'text-delta' as const, textDelta: 'Hello' },
{ type: 'text-delta' as const, textDelta: ', ' },
{ type: 'text-delta' as const, textDelta: 'this ' },
{ type: 'text-delta' as const, textDelta: 'is ' },
{ type: 'text-delta' as const, textDelta: 'a ' },
{ type: 'text-delta' as const, textDelta: 'test.' }
],
withToolCall: [
{ type: 'text-delta' as const, textDelta: 'Let me check the weather for you.' },
{
type: 'tool-call-delta' as const,
toolCallType: 'function' as const,
toolCallId: 'call_123',
toolName: 'getWeather',
argsTextDelta: '{"location":'
},
{
type: 'tool-call-delta' as const,
toolCallType: 'function' as const,
toolCallId: 'call_123',
toolName: 'getWeather',
argsTextDelta: ' "San Francisco, CA"}'
},
{
type: 'tool-call' as const,
toolCallType: 'function' as const,
toolCallId: 'call_123',
toolName: 'getWeather',
args: { location: 'San Francisco, CA' }
}
],
withFinish: [
{ type: 'text-delta' as const, textDelta: 'Complete response.' },
{
type: 'finish' as const,
finishReason: 'stop' as const,
usage: {
promptTokens: 10,
completionTokens: 5,
totalTokens: 15
}
}
]
}
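/**
 * Illustrative usage (a sketch): these chunk arrays are meant to drive mocked
 * doStream implementations, e.g. via the createTestStreamingProvider helper
 * defined in the test utilities:
 *
 *   const model = createTestStreamingProvider(mockStreamingChunks.withFinish)
 *   // model.doStream(...) now yields 'Complete response.' followed by a finish chunk
 */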
/**
* Mock complete responses for non-streaming scenarios
*/
export const mockCompleteResponses = {
simple: {
text: 'This is a simple response.',
finishReason: 'stop' as const,
usage: {
promptTokens: 15,
completionTokens: 8,
totalTokens: 23
}
},
withToolCalls: {
text: 'I will check the weather for you.',
toolCalls: [
{
toolCallId: 'call_456',
toolName: 'getWeather',
args: { location: 'New York, NY', unit: 'celsius' }
}
],
finishReason: 'tool-calls' as const,
usage: {
promptTokens: 25,
completionTokens: 12,
totalTokens: 37
}
},
withWarnings: {
text: 'Response with warnings.',
finishReason: 'stop' as const,
usage: {
promptTokens: 10,
completionTokens: 5,
totalTokens: 15
},
warnings: [
{
type: 'unsupported-setting' as const,
message: 'Temperature parameter not supported for this model'
}
]
}
}
/**
* Mock image generation responses
*/
export const mockImageResponses = {
single: {
image: {
base64: 'iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==',
uint8Array: new Uint8Array([137, 80, 78, 71, 13, 10, 26, 10, 0, 0, 0, 13, 73, 72, 68, 82]),
mimeType: 'image/png' as const
},
warnings: []
},
multiple: {
images: [
{
base64: 'iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==',
uint8Array: new Uint8Array([137, 80, 78, 71]),
mimeType: 'image/png' as const
},
{
base64: 'iVBORw0KGgoAAAANSUhEUgAAAAIAAAACCAYAAABytg0kAAAAEklEQVR42mNk+M9QzwAEjDAGACCKAgdZ9zImAAAAAElFTkSuQmCC',
uint8Array: new Uint8Array([137, 80, 78, 71]),
mimeType: 'image/png' as const
}
],
warnings: []
},
withProviderMetadata: {
image: {
base64: 'iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==',
uint8Array: new Uint8Array([137, 80, 78, 71]),
mimeType: 'image/png' as const
},
providerMetadata: {
openai: {
images: [
{
revisedPrompt: 'A detailed and enhanced version of the original prompt'
}
]
}
},
warnings: []
}
}
/**
* Mock error responses
*/
export const mockErrors = {
invalidApiKey: {
name: 'APIError',
message: 'Invalid API key provided',
statusCode: 401
},
rateLimitExceeded: {
name: 'RateLimitError',
message: 'Rate limit exceeded. Please try again later.',
statusCode: 429,
headers: {
'retry-after': '60'
}
},
modelNotFound: {
name: 'ModelNotFoundError',
message: 'The requested model was not found',
statusCode: 404
},
contextLengthExceeded: {
name: 'ContextLengthError',
message: "This model's maximum context length is 4096 tokens",
statusCode: 400
},
timeout: {
name: 'TimeoutError',
message: 'Request timed out after 30000ms',
code: 'ETIMEDOUT'
},
networkError: {
name: 'NetworkError',
message: 'Network connection failed',
code: 'ECONNREFUSED'
}
}
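/**
 * Illustrative usage (a sketch): the error fixtures pair with createErrorProvider
 * from the test utilities. Copying the fixture fields onto an Error keeps
 * statusCode/code available to assertions:
 *
 *   const error = Object.assign(new Error(mockErrors.rateLimitExceeded.message), mockErrors.rateLimitExceeded)
 *   const provider = createErrorProvider(error)
 */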

View File

@@ -0,0 +1,329 @@
/**
* Provider-Specific Test Utilities
* Helper functions for testing individual providers with all their parameters
*/
import type { Tool } from 'ai'
import { expect } from 'vitest'
/**
* Provider parameter configurations for comprehensive testing
*/
export const providerParameterMatrix = {
openai: {
models: ['gpt-4', 'gpt-4-turbo', 'gpt-3.5-turbo', 'gpt-4o'],
parameters: {
temperature: [0, 0.5, 0.7, 1.0, 1.5, 2.0],
maxTokens: [100, 500, 1000, 2000, 4000],
topP: [0.1, 0.5, 0.9, 1.0],
frequencyPenalty: [-2.0, -1.0, 0, 1.0, 2.0],
presencePenalty: [-2.0, -1.0, 0, 1.0, 2.0],
stop: [undefined, ['stop'], ['STOP', 'END']],
seed: [undefined, 12345, 67890],
responseFormat: [undefined, { type: 'json_object' as const }],
user: [undefined, 'test-user-123']
},
toolChoice: ['auto', 'required', 'none', { type: 'function' as const, name: 'getWeather' }],
parallelToolCalls: [true, false]
},
anthropic: {
models: ['claude-3-5-sonnet-20241022', 'claude-3-opus-20240229', 'claude-3-haiku-20240307'],
parameters: {
temperature: [0, 0.5, 1.0],
maxTokens: [100, 1000, 4000, 8000],
topP: [0.1, 0.5, 0.9, 1.0],
topK: [undefined, 1, 5, 10, 40],
stop: [undefined, ['Human:', 'Assistant:']],
metadata: [undefined, { userId: 'test-123' }]
},
toolChoice: ['auto', 'any', { type: 'tool' as const, name: 'getWeather' }]
},
google: {
models: ['gemini-2.0-flash-exp', 'gemini-1.5-pro', 'gemini-1.5-flash'],
parameters: {
temperature: [0, 0.5, 0.9, 1.0],
maxTokens: [100, 1000, 2000, 8000],
topP: [0.1, 0.5, 0.95, 1.0],
topK: [undefined, 1, 16, 40],
stopSequences: [undefined, ['END'], ['STOP', 'TERMINATE']]
},
safetySettings: [
undefined,
[
{ category: 'HARM_CATEGORY_HARASSMENT', threshold: 'BLOCK_MEDIUM_AND_ABOVE' },
{ category: 'HARM_CATEGORY_HATE_SPEECH', threshold: 'BLOCK_ONLY_HIGH' }
]
]
},
xai: {
models: ['grok-2-latest', 'grok-2-1212'],
parameters: {
temperature: [0, 0.5, 1.0, 1.5],
maxTokens: [100, 500, 2000, 4000],
topP: [0.1, 0.5, 0.9, 1.0],
stop: [undefined, ['STOP'], ['END', 'TERMINATE']],
seed: [undefined, 12345]
}
},
deepseek: {
models: ['deepseek-chat', 'deepseek-coder'],
parameters: {
temperature: [0, 0.5, 1.0],
maxTokens: [100, 1000, 4000],
topP: [0.1, 0.5, 0.95],
frequencyPenalty: [0, 0.5, 1.0],
presencePenalty: [0, 0.5, 1.0],
stop: [undefined, ['```'], ['END']]
}
},
azure: {
deployments: ['gpt-4-deployment', 'gpt-35-turbo-deployment'],
parameters: {
temperature: [0, 0.7, 1.0],
maxTokens: [100, 1000, 2000],
topP: [0.1, 0.5, 0.95],
frequencyPenalty: [0, 1.0],
presencePenalty: [0, 1.0],
stop: [undefined, ['STOP']]
}
}
} as const
/**
* Creates test cases for all parameter combinations
*/
export function generateParameterTestCases<T extends Record<string, any[]>>(
params: T,
maxCombinations = 50
): Array<Partial<{ [K in keyof T]: T[K][number] }>> {
const keys = Object.keys(params) as Array<keyof T>
const testCases: Array<Partial<{ [K in keyof T]: T[K][number] }>> = []
// Generate combinations using sampling strategy for large parameter spaces
const totalCombinations = keys.reduce((acc, key) => acc * params[key].length, 1)
if (totalCombinations <= maxCombinations) {
// Generate all combinations if total is small
generateAllCombinations(params, keys, 0, {}, testCases)
} else {
// Sample diverse combinations if total is large
generateSampledCombinations(params, keys, maxCombinations, testCases)
}
return testCases
}
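/**
 * Illustrative usage (a sketch): feeding the matrix above into this generator.
 * Note the matrix is declared `as const`, so a cast (or a mutable copy) is needed
 * to satisfy the Record<string, any[]> constraint:
 *
 *   const cases = generateParameterTestCases(providerParameterMatrix.openai.parameters as any, 30)
 *   // -> at most 30 { temperature, maxTokens, topP, ... } combinations
 *   //    (edge cases are included first when sampling)
 */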
function generateAllCombinations<T extends Record<string, any[]>>(
params: T,
keys: Array<keyof T>,
index: number,
current: Partial<{ [K in keyof T]: T[K][number] }>,
results: Array<Partial<{ [K in keyof T]: T[K][number] }>>
) {
if (index === keys.length) {
results.push({ ...current })
return
}
const key = keys[index]
for (const value of params[key]) {
generateAllCombinations(params, keys, index + 1, { ...current, [key]: value }, results)
}
}
function generateSampledCombinations<T extends Record<string, any[]>>(
params: T,
keys: Array<keyof T>,
count: number,
results: Array<Partial<{ [K in keyof T]: T[K][number] }>>
) {
// Generate edge cases first (min/max values)
const edgeCase1: any = {}
const edgeCase2: any = {}
for (const key of keys) {
edgeCase1[key] = params[key][0]
edgeCase2[key] = params[key][params[key].length - 1]
}
results.push(edgeCase1, edgeCase2)
// Generate random combinations for the rest
for (let i = results.length; i < count; i++) {
const combination: any = {}
for (const key of keys) {
const values = params[key]
combination[key] = values[Math.floor(Math.random() * values.length)]
}
results.push(combination)
}
}
/**
* Validates that all provider-specific parameters are correctly passed through
*/
export function validateProviderParams(providerId: string, actualParams: any, expectedParams: any): void {
const requiredFields: Record<string, string[]> = {
openai: ['model', 'messages'],
anthropic: ['model', 'messages'],
google: ['model', 'contents'],
xai: ['model', 'messages'],
deepseek: ['model', 'messages'],
azure: ['messages']
}
const fields = requiredFields[providerId] || ['model', 'messages']
for (const field of fields) {
expect(actualParams).toHaveProperty(field)
}
// Validate optional parameters if they were provided
const optionalParams = ['temperature', 'max_tokens', 'top_p', 'stop', 'tools']
for (const param of optionalParams) {
if (expectedParams[param] !== undefined) {
expect(actualParams[param]).toEqual(expectedParams[param])
}
}
}
/**
* Creates a comprehensive test suite for a provider
*/
// oxlint-disable-next-line no-unused-vars
export function createProviderTestSuite(_providerId: string) {
return {
testBasicCompletion: async (executor: any, model: string) => {
const result = await executor.generateText({
model,
messages: [{ role: 'user' as const, content: 'Hello' }]
})
expect(result).toBeDefined()
expect(result.text).toBeDefined()
expect(typeof result.text).toBe('string')
},
testStreaming: async (executor: any, model: string) => {
const chunks: any[] = []
const result = await executor.streamText({
model,
messages: [{ role: 'user' as const, content: 'Hello' }]
})
for await (const chunk of result.textStream) {
chunks.push(chunk)
}
expect(chunks.length).toBeGreaterThan(0)
},
testTemperature: async (executor: any, model: string, temperatures: number[]) => {
for (const temperature of temperatures) {
const result = await executor.generateText({
model,
messages: [{ role: 'user' as const, content: 'Hello' }],
temperature
})
expect(result).toBeDefined()
}
},
testMaxTokens: async (executor: any, model: string, maxTokensValues: number[]) => {
for (const maxTokens of maxTokensValues) {
const result = await executor.generateText({
model,
messages: [{ role: 'user' as const, content: 'Hello' }],
maxTokens
})
expect(result).toBeDefined()
if (result.usage?.completionTokens) {
expect(result.usage.completionTokens).toBeLessThanOrEqual(maxTokens)
}
}
},
testToolCalling: async (executor: any, model: string, tools: Record<string, Tool>) => {
const result = await executor.generateText({
model,
messages: [{ role: 'user' as const, content: 'What is the weather in SF?' }],
tools
})
expect(result).toBeDefined()
},
testStopSequences: async (executor: any, model: string, stopSequences: string[][]) => {
for (const stop of stopSequences) {
const result = await executor.generateText({
model,
messages: [{ role: 'user' as const, content: 'Count to 10' }],
stop
})
expect(result).toBeDefined()
}
}
}
}
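/**
 * Illustrative usage (a sketch; `executor` stands for any object exposing
 * generateText/streamText, as in the suite methods above):
 *
 *   const suite = createProviderTestSuite('openai')
 *   await suite.testBasicCompletion(executor, 'gpt-4')
 *   await suite.testTemperature(executor, 'gpt-4', [0, 0.7, 1.0])
 */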
/**
* Generates test data for vision/multimodal testing
*/
export function createVisionTestData() {
return {
imageUrl: 'https://example.com/test-image.jpg',
base64Image:
'data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==',
messages: [
{
role: 'user' as const,
content: [
{ type: 'text' as const, text: 'What is in this image?' },
{
type: 'image' as const,
image:
'data:image/png;base64,iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg=='
}
]
}
]
}
}
/**
* Creates mock responses for different finish reasons
*/
export function createFinishReasonMocks() {
return {
stop: {
text: 'Complete response.',
finishReason: 'stop' as const,
usage: { promptTokens: 10, completionTokens: 5, totalTokens: 15 }
},
length: {
text: 'Incomplete response due to',
finishReason: 'length' as const,
usage: { promptTokens: 10, completionTokens: 100, totalTokens: 110 }
},
'tool-calls': {
text: 'Calling tools',
finishReason: 'tool-calls' as const,
toolCalls: [{ toolCallId: 'call_1', toolName: 'getWeather', args: { location: 'SF' } }],
usage: { promptTokens: 10, completionTokens: 8, totalTokens: 18 }
},
'content-filter': {
text: '',
finishReason: 'content-filter' as const,
usage: { promptTokens: 10, completionTokens: 0, totalTokens: 10 }
}
}
}

View File

@@ -0,0 +1,291 @@
/**
* Test Utilities
* Helper functions for testing AI Core functionality
*/
import { expect, vi } from 'vitest'
import type { ProviderId } from '../fixtures/mock-providers'
import { createMockImageModel, createMockLanguageModel, mockProviderConfigs } from '../fixtures/mock-providers'
/**
* Creates a test provider with streaming support
*/
export function createTestStreamingProvider(chunks: any[]) {
return createMockLanguageModel({
doStream: vi.fn().mockReturnValue({
stream: (async function* () {
for (const chunk of chunks) {
yield chunk
}
})(),
rawCall: { rawPrompt: null, rawSettings: {} },
rawResponse: { headers: {} },
warnings: []
})
})
}
/**
* Creates a test provider that throws errors
*/
export function createErrorProvider(error: Error) {
return createMockLanguageModel({
doGenerate: vi.fn().mockRejectedValue(error),
doStream: vi.fn().mockImplementation(() => {
throw error
})
})
}
/**
* Collects all chunks from a stream
*/
export async function collectStreamChunks<T>(stream: AsyncIterable<T>): Promise<T[]> {
const chunks: T[] = []
for await (const chunk of stream) {
chunks.push(chunk)
}
return chunks
}
/**
* Waits for a specific number of milliseconds
*/
export function wait(ms: number): Promise<void> {
return new Promise((resolve) => setTimeout(resolve, ms))
}
/**
* Creates a mock abort controller that aborts after a delay
*/
export function createDelayedAbortController(delayMs: number): AbortController {
const controller = new AbortController()
setTimeout(() => controller.abort(), delayMs)
return controller
}
/**
* Asserts that a function throws an error with a specific message
*/
export async function expectError(fn: () => Promise<any>, expectedMessage?: string | RegExp): Promise<Error> {
try {
await fn()
throw new Error('Expected function to throw an error, but it did not')
} catch (error) {
if (expectedMessage) {
const message = (error as Error).message
if (typeof expectedMessage === 'string') {
if (!message.includes(expectedMessage)) {
throw new Error(`Expected error message to include "${expectedMessage}", but got "${message}"`)
}
} else {
if (!expectedMessage.test(message)) {
throw new Error(`Expected error message to match ${expectedMessage}, but got "${message}"`)
}
}
}
return error as Error
}
}
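/**
 * Illustrative usage (a sketch):
 *
 *   const err = await expectError(() => executor.generateText({ model: 'missing-model', messages: [] }), /not found/i)
 */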
/**
* Creates a spy function that tracks calls and arguments
*/
export function createSpy<T extends (...args: any[]) => any>(impl?: T) {
const calls: Array<{ args: Parameters<T>; result?: ReturnType<T>; error?: Error }> = []
const spy = vi.fn((...args: Parameters<T>) => {
try {
// Invoke the optional implementation so that both results and thrown errors are recorded
const result = (impl ? impl(...args) : undefined) as ReturnType<T>
calls.push({ args, result })
return result
} catch (error) {
calls.push({ args, error: error as Error })
throw error
}
})
return {
fn: spy,
calls,
getCalls: () => calls,
getCallCount: () => calls.length,
getLastCall: () => calls[calls.length - 1],
reset: () => {
calls.length = 0
spy.mockClear()
}
}
}
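/**
 * Illustrative usage (a sketch, given the optional implementation argument):
 *
 *   const spy = createSpy<(x: number) => number>(((x) => x * 2) as (x: number) => number)
 *   spy.fn(21) // -> 42, recorded in spy.calls
 *   spy.getCallCount() // -> 1
 */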
/**
* Validates provider configuration
*/
export function validateProviderConfig(providerId: ProviderId) {
const config = mockProviderConfigs[providerId]
if (!config) {
throw new Error(`No mock configuration found for provider: ${providerId}`)
}
if (!config.apiKey) {
throw new Error(`Provider ${providerId} is missing apiKey in mock config`)
}
return config
}
/**
* Creates a test context with common setup
*/
export function createTestContext() {
const mocks = {
languageModel: createMockLanguageModel(),
imageModel: createMockImageModel(),
providers: new Map<string, any>()
}
const cleanup = () => {
mocks.providers.clear()
vi.clearAllMocks()
}
return {
mocks,
cleanup
}
}
/**
* Measures execution time of an async function
*/
export async function measureTime<T>(fn: () => Promise<T>): Promise<{ result: T; duration: number }> {
const start = Date.now()
const result = await fn()
const duration = Date.now() - start
return { result, duration }
}
/**
* Retries a function until it succeeds or max attempts reached
*/
export async function retryUntilSuccess<T>(fn: () => Promise<T>, maxAttempts = 3, delayMs = 100): Promise<T> {
let lastError: Error | undefined
for (let attempt = 1; attempt <= maxAttempts; attempt++) {
try {
return await fn()
} catch (error) {
lastError = error as Error
if (attempt < maxAttempts) {
await wait(delayMs)
}
}
}
throw lastError || new Error('All retry attempts failed')
}
/**
* Creates a mock streaming response that emits chunks at intervals
*/
export function createTimedStream<T>(chunks: T[], intervalMs = 10) {
return {
async *[Symbol.asyncIterator]() {
for (const chunk of chunks) {
await wait(intervalMs)
yield chunk
}
}
}
}
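/**
 * Illustrative usage (a sketch): pairs naturally with collectStreamChunks above.
 *
 *   const chunks = await collectStreamChunks(createTimedStream(['a', 'b', 'c'], 5))
 *   // -> ['a', 'b', 'c'], emitted roughly 5ms apart
 */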
/**
* Asserts that two objects are deeply equal, ignoring specified keys
*/
export function assertDeepEqualIgnoring<T extends Record<string, any>>(
actual: T,
expected: T,
ignoreKeys: string[] = []
): void {
const filterKeys = (obj: T): Partial<T> => {
const filtered = { ...obj }
for (const key of ignoreKeys) {
delete filtered[key]
}
return filtered
}
const filteredActual = filterKeys(actual)
const filteredExpected = filterKeys(expected)
expect(filteredActual).toEqual(filteredExpected)
}
/**
* Creates a provider mock that simulates rate limiting
*/
export function createRateLimitedProvider(limitPerSecond: number) {
const calls: number[] = []
return createMockLanguageModel({
doGenerate: vi.fn().mockImplementation(async () => {
const now = Date.now()
calls.push(now)
// Remove calls older than 1 second
const recentCalls = calls.filter((time) => now - time < 1000)
if (recentCalls.length > limitPerSecond) {
throw new Error('Rate limit exceeded')
}
return {
text: 'Rate limited response',
finishReason: 'stop' as const,
usage: { promptTokens: 10, completionTokens: 5, totalTokens: 15 },
rawCall: { rawPrompt: null, rawSettings: {} },
rawResponse: { headers: {} },
warnings: []
}
})
})
}
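/**
 * Illustrative usage (a sketch): a model that tolerates at most 2 calls per second.
 *
 *   const model = createRateLimitedProvider(2)
 *   // the third doGenerate() call within one second throws 'Rate limit exceeded'
 */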
/**
* Validates streaming response structure
*/
export function validateStreamChunk(chunk: any): void {
expect(chunk).toBeDefined()
expect(chunk).toHaveProperty('type')
if (chunk.type === 'text-delta') {
expect(chunk).toHaveProperty('textDelta')
expect(typeof chunk.textDelta).toBe('string')
} else if (chunk.type === 'finish') {
expect(chunk).toHaveProperty('finishReason')
expect(chunk).toHaveProperty('usage')
} else if (chunk.type === 'tool-call') {
expect(chunk).toHaveProperty('toolCallId')
expect(chunk).toHaveProperty('toolName')
expect(chunk).toHaveProperty('args')
}
}
/**
* Creates a test logger that captures log messages
*/
export function createTestLogger() {
const logs: Array<{ level: string; message: string; meta?: any }> = []
return {
info: (message: string, meta?: any) => logs.push({ level: 'info', message, meta }),
warn: (message: string, meta?: any) => logs.push({ level: 'warn', message, meta }),
error: (message: string, meta?: any) => logs.push({ level: 'error', message, meta }),
debug: (message: string, meta?: any) => logs.push({ level: 'debug', message, meta }),
getLogs: () => logs,
clear: () => {
logs.length = 0
}
}
}

View File

@@ -0,0 +1,12 @@
/**
* Test Infrastructure Exports
* Central export point for all test utilities, fixtures, and helpers
*/
// Fixtures
export * from './fixtures/mock-providers'
export * from './fixtures/mock-responses'
// Helpers
export * from './helpers/provider-test-utils'
export * from './helpers/test-utils'

View File

@@ -0,0 +1,499 @@
/**
* RuntimeExecutor.generateText Comprehensive Tests
* Tests non-streaming text generation across all providers with various parameters
*/
import { generateText } from 'ai'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import {
createMockLanguageModel,
mockCompleteResponses,
mockProviderConfigs,
testMessages,
testTools
} from '../../../__tests__'
import type { AiPlugin } from '../../plugins'
import { globalRegistryManagement } from '../../providers/RegistryManagement'
import { RuntimeExecutor } from '../executor'
// Mock AI SDK
vi.mock('ai', () => ({
generateText: vi.fn()
}))
vi.mock('../../providers/RegistryManagement', () => ({
globalRegistryManagement: {
languageModel: vi.fn()
},
DEFAULT_SEPARATOR: '|'
}))
describe('RuntimeExecutor.generateText', () => {
let executor: RuntimeExecutor<'openai'>
let mockLanguageModel: any
beforeEach(() => {
vi.clearAllMocks()
executor = RuntimeExecutor.create('openai', mockProviderConfigs.openai)
mockLanguageModel = createMockLanguageModel({
provider: 'openai',
modelId: 'gpt-4'
})
vi.mocked(globalRegistryManagement.languageModel).mockReturnValue(mockLanguageModel)
vi.mocked(generateText).mockResolvedValue(mockCompleteResponses.simple as any)
})
describe('Basic Functionality', () => {
it('should generate text with minimal parameters', async () => {
const result = await executor.generateText({
model: 'gpt-4',
messages: testMessages.simple
})
expect(generateText).toHaveBeenCalledWith({
model: mockLanguageModel,
messages: testMessages.simple
})
expect(result.text).toBe('This is a simple response.')
expect(result.finishReason).toBe('stop')
expect(result.usage).toBeDefined()
})
it('should generate with system messages', async () => {
await executor.generateText({
model: 'gpt-4',
messages: testMessages.withSystem
})
expect(generateText).toHaveBeenCalledWith({
model: mockLanguageModel,
messages: testMessages.withSystem
})
})
it('should generate with conversation history', async () => {
await executor.generateText({
model: 'gpt-4',
messages: testMessages.conversation
})
expect(generateText).toHaveBeenCalledWith(
expect.objectContaining({
messages: testMessages.conversation
})
)
})
})
describe('All Parameter Combinations', () => {
it('should support all parameters together', async () => {
await executor.generateText({
model: 'gpt-4',
messages: testMessages.simple,
temperature: 0.7,
maxOutputTokens: 500,
topP: 0.9,
frequencyPenalty: 0.5,
presencePenalty: 0.3,
stopSequences: ['STOP'],
seed: 12345
})
expect(generateText).toHaveBeenCalledWith(
expect.objectContaining({
temperature: 0.7,
maxOutputTokens: 500,
topP: 0.9,
frequencyPenalty: 0.5,
presencePenalty: 0.3,
stopSequences: ['STOP'],
seed: 12345
})
)
})
it('should support partial parameters', async () => {
await executor.generateText({
model: 'gpt-4',
messages: testMessages.simple,
temperature: 0.5,
maxOutputTokens: 100
})
expect(generateText).toHaveBeenCalledWith(
expect.objectContaining({
temperature: 0.5,
maxOutputTokens: 100
})
)
})
})
describe('Tool Calling', () => {
beforeEach(() => {
vi.mocked(generateText).mockResolvedValue(mockCompleteResponses.withToolCalls as any)
})
it('should support tool calling', async () => {
const result = await executor.generateText({
model: 'gpt-4',
messages: testMessages.toolUse,
tools: testTools
})
expect(generateText).toHaveBeenCalledWith(
expect.objectContaining({
tools: testTools
})
)
expect(result.toolCalls).toBeDefined()
expect(result.toolCalls).toHaveLength(1)
})
it('should support toolChoice auto', async () => {
await executor.generateText({
model: 'gpt-4',
messages: testMessages.toolUse,
tools: testTools,
toolChoice: 'auto'
})
expect(generateText).toHaveBeenCalledWith(
expect.objectContaining({
toolChoice: 'auto'
})
)
})
it('should support toolChoice required', async () => {
await executor.generateText({
model: 'gpt-4',
messages: testMessages.toolUse,
tools: testTools,
toolChoice: 'required'
})
expect(generateText).toHaveBeenCalledWith(
expect.objectContaining({
toolChoice: 'required'
})
)
})
it('should support toolChoice none', async () => {
vi.mocked(generateText).mockResolvedValue(mockCompleteResponses.simple as any)
await executor.generateText({
model: 'gpt-4',
messages: testMessages.simple,
tools: testTools,
toolChoice: 'none'
})
expect(generateText).toHaveBeenCalledWith(
expect.objectContaining({
toolChoice: 'none'
})
)
})
it('should support specific tool selection', async () => {
await executor.generateText({
model: 'gpt-4',
messages: testMessages.toolUse,
tools: testTools,
toolChoice: {
type: 'tool',
toolName: 'getWeather'
}
})
expect(generateText).toHaveBeenCalledWith(
expect.objectContaining({
toolChoice: {
type: 'tool',
toolName: 'getWeather'
}
})
)
})
})
describe('Multiple Providers', () => {
it('should work with Anthropic provider', async () => {
const anthropicExecutor = RuntimeExecutor.create('anthropic', mockProviderConfigs.anthropic)
const anthropicModel = createMockLanguageModel({
provider: 'anthropic',
modelId: 'claude-3-5-sonnet-20241022'
})
vi.mocked(globalRegistryManagement.languageModel).mockReturnValue(anthropicModel)
await anthropicExecutor.generateText({
model: 'claude-3-5-sonnet-20241022',
messages: testMessages.simple
})
expect(globalRegistryManagement.languageModel).toHaveBeenCalledWith('anthropic|claude-3-5-sonnet-20241022')
})
it('should work with Google provider', async () => {
const googleExecutor = RuntimeExecutor.create('google', mockProviderConfigs.google)
const googleModel = createMockLanguageModel({
provider: 'google',
modelId: 'gemini-2.0-flash-exp'
})
vi.mocked(globalRegistryManagement.languageModel).mockReturnValue(googleModel)
await googleExecutor.generateText({
model: 'gemini-2.0-flash-exp',
messages: testMessages.simple
})
expect(globalRegistryManagement.languageModel).toHaveBeenCalledWith('google|gemini-2.0-flash-exp')
})
it('should work with xAI provider', async () => {
const xaiExecutor = RuntimeExecutor.create('xai', mockProviderConfigs.xai)
const xaiModel = createMockLanguageModel({
provider: 'xai',
modelId: 'grok-2-latest'
})
vi.mocked(globalRegistryManagement.languageModel).mockReturnValue(xaiModel)
await xaiExecutor.generateText({
model: 'grok-2-latest',
messages: testMessages.simple
})
expect(globalRegistryManagement.languageModel).toHaveBeenCalledWith('xai|grok-2-latest')
})
it('should work with DeepSeek provider', async () => {
const deepseekExecutor = RuntimeExecutor.create('deepseek', mockProviderConfigs.deepseek)
const deepseekModel = createMockLanguageModel({
provider: 'deepseek',
modelId: 'deepseek-chat'
})
vi.mocked(globalRegistryManagement.languageModel).mockReturnValue(deepseekModel)
await deepseekExecutor.generateText({
model: 'deepseek-chat',
messages: testMessages.simple
})
expect(globalRegistryManagement.languageModel).toHaveBeenCalledWith('deepseek|deepseek-chat')
})
})
describe('Plugin Integration', () => {
it('should execute all plugin hooks', async () => {
const pluginCalls: string[] = []
const testPlugin: AiPlugin = {
name: 'test-plugin',
onRequestStart: vi.fn(async () => {
pluginCalls.push('onRequestStart')
}),
transformParams: vi.fn(async (params) => {
pluginCalls.push('transformParams')
return { ...params, temperature: 0.8 }
}),
transformResult: vi.fn(async (result) => {
pluginCalls.push('transformResult')
return { ...result, text: result.text + ' [modified]' }
}),
onRequestEnd: vi.fn(async () => {
pluginCalls.push('onRequestEnd')
})
}
const executorWithPlugin = RuntimeExecutor.create('openai', mockProviderConfigs.openai, [testPlugin])
const result = await executorWithPlugin.generateText({
model: 'gpt-4',
messages: testMessages.simple
})
expect(pluginCalls).toEqual(['onRequestStart', 'transformParams', 'transformResult', 'onRequestEnd'])
// Verify transformed parameters
expect(generateText).toHaveBeenCalledWith(
expect.objectContaining({
temperature: 0.8
})
)
// Verify transformed result
expect(result.text).toContain('[modified]')
})
it('should handle multiple plugins in order', async () => {
const pluginOrder: string[] = []
const plugin1: AiPlugin = {
name: 'plugin-1',
transformParams: vi.fn(async (params) => {
pluginOrder.push('plugin-1')
return { ...params, temperature: 0.5 }
})
}
const plugin2: AiPlugin = {
name: 'plugin-2',
transformParams: vi.fn(async (params) => {
pluginOrder.push('plugin-2')
return { ...params, maxTokens: 200 }
})
}
const executorWithPlugins = RuntimeExecutor.create('openai', mockProviderConfigs.openai, [plugin1, plugin2])
await executorWithPlugins.generateText({
model: 'gpt-4',
messages: testMessages.simple
})
expect(pluginOrder).toEqual(['plugin-1', 'plugin-2'])
expect(generateText).toHaveBeenCalledWith(
expect.objectContaining({
temperature: 0.5,
maxTokens: 200
})
)
})
})
describe('Error Handling', () => {
it('should handle API errors', async () => {
const error = new Error('API request failed')
vi.mocked(generateText).mockRejectedValue(error)
await expect(
executor.generateText({
model: 'gpt-4',
messages: testMessages.simple
})
).rejects.toThrow('API request failed')
})
it('should execute onError plugin hook', async () => {
const error = new Error('Generation failed')
vi.mocked(generateText).mockRejectedValue(error)
const errorPlugin: AiPlugin = {
name: 'error-handler',
onError: vi.fn()
}
const executorWithPlugin = RuntimeExecutor.create('openai', mockProviderConfigs.openai, [errorPlugin])
await expect(
executorWithPlugin.generateText({
model: 'gpt-4',
messages: testMessages.simple
})
).rejects.toThrow('Generation failed')
expect(errorPlugin.onError).toHaveBeenCalledWith(
error,
expect.objectContaining({
providerId: 'openai',
modelId: 'gpt-4'
})
)
})
it('should handle model not found error', async () => {
const error = new Error('Model not found: invalid-model')
vi.mocked(globalRegistryManagement.languageModel).mockImplementation(() => {
throw error
})
await expect(
executor.generateText({
model: 'invalid-model',
messages: testMessages.simple
})
).rejects.toThrow('Model not found')
})
})
describe('Usage and Metadata', () => {
it('should return usage information', async () => {
const result = await executor.generateText({
model: 'gpt-4',
messages: testMessages.simple
})
expect(result.usage).toBeDefined()
expect(result.usage.promptTokens).toBe(15)
expect(result.usage.completionTokens).toBe(8)
expect(result.usage.totalTokens).toBe(23)
})
it('should handle warnings', async () => {
vi.mocked(generateText).mockResolvedValue(mockCompleteResponses.withWarnings as any)
const result = await executor.generateText({
model: 'gpt-4',
messages: testMessages.simple,
temperature: 2.5 // Unsupported value
})
expect(result.warnings).toBeDefined()
expect(result.warnings).toHaveLength(1)
expect(result.warnings![0].type).toBe('unsupported-setting')
})
})
describe('Abort Signal', () => {
it('should support abort signal', async () => {
const abortController = new AbortController()
await executor.generateText({
model: 'gpt-4',
messages: testMessages.simple,
abortSignal: abortController.signal
})
expect(generateText).toHaveBeenCalledWith(
expect.objectContaining({
abortSignal: abortController.signal
})
)
})
it('should handle aborted request', async () => {
const abortError = new Error('Request aborted')
abortError.name = 'AbortError'
vi.mocked(generateText).mockRejectedValue(abortError)
const abortController = new AbortController()
abortController.abort()
await expect(
executor.generateText({
model: 'gpt-4',
messages: testMessages.simple,
abortSignal: abortController.signal
})
).rejects.toThrow('Request aborted')
})
})
})

View File

@@ -0,0 +1,525 @@
/**
* RuntimeExecutor.streamText Comprehensive Tests
* Tests streaming text generation across all providers with various parameters
*/
import { streamText } from 'ai'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import { collectStreamChunks, createMockLanguageModel, mockProviderConfigs, testMessages } from '../../../__tests__'
import type { AiPlugin } from '../../plugins'
import { globalRegistryManagement } from '../../providers/RegistryManagement'
import { RuntimeExecutor } from '../executor'
// Mock AI SDK
vi.mock('ai', () => ({
streamText: vi.fn()
}))
vi.mock('../../providers/RegistryManagement', () => ({
globalRegistryManagement: {
languageModel: vi.fn()
},
DEFAULT_SEPARATOR: '|'
}))
describe('RuntimeExecutor.streamText', () => {
let executor: RuntimeExecutor<'openai'>
let mockLanguageModel: any
beforeEach(() => {
vi.clearAllMocks()
executor = RuntimeExecutor.create('openai', mockProviderConfigs.openai)
mockLanguageModel = createMockLanguageModel({
provider: 'openai',
modelId: 'gpt-4'
})
vi.mocked(globalRegistryManagement.languageModel).mockReturnValue(mockLanguageModel)
})
describe('Basic Functionality', () => {
it('should stream text with minimal parameters', async () => {
const mockStream = {
textStream: (async function* () {
yield 'Hello'
yield ' '
yield 'World'
})(),
fullStream: (async function* () {
yield { type: 'text-delta', textDelta: 'Hello' }
yield { type: 'text-delta', textDelta: ' ' }
yield { type: 'text-delta', textDelta: 'World' }
})(),
usage: Promise.resolve({ promptTokens: 5, completionTokens: 3, totalTokens: 8 })
}
vi.mocked(streamText).mockResolvedValue(mockStream as any)
const result = await executor.streamText({
model: 'gpt-4',
messages: testMessages.simple
})
expect(streamText).toHaveBeenCalledWith({
model: mockLanguageModel,
messages: testMessages.simple
})
const chunks = await collectStreamChunks(result.textStream)
expect(chunks).toEqual(['Hello', ' ', 'World'])
})
it('should stream with system messages', async () => {
const mockStream = {
textStream: (async function* () {
yield 'Response'
})(),
fullStream: (async function* () {
yield { type: 'text-delta', textDelta: 'Response' }
})()
}
vi.mocked(streamText).mockResolvedValue(mockStream as any)
await executor.streamText({
model: 'gpt-4',
messages: testMessages.withSystem
})
expect(streamText).toHaveBeenCalledWith({
model: mockLanguageModel,
messages: testMessages.withSystem
})
})
it('should stream multi-turn conversations', async () => {
const mockStream = {
textStream: (async function* () {
yield 'Multi-turn response'
})(),
fullStream: (async function* () {
yield { type: 'text-delta', textDelta: 'Multi-turn response' }
})()
}
vi.mocked(streamText).mockResolvedValue(mockStream as any)
await executor.streamText({
model: 'gpt-4',
messages: testMessages.multiTurn
})
expect(streamText).toHaveBeenCalled()
expect(streamText).toHaveBeenCalledWith(
expect.objectContaining({
messages: testMessages.multiTurn
})
)
})
})
describe('Temperature Parameter', () => {
const temperatures = [0, 0.3, 0.5, 0.7, 0.9, 1.0, 1.5, 2.0]
it.each(temperatures)('should support temperature=%s', async (temperature) => {
const mockStream = {
textStream: (async function* () {
yield 'Response'
})(),
fullStream: (async function* () {
yield { type: 'text-delta', textDelta: 'Response' }
})()
}
vi.mocked(streamText).mockResolvedValue(mockStream as any)
await executor.streamText({
model: 'gpt-4',
messages: testMessages.simple,
temperature
})
expect(streamText).toHaveBeenCalledWith(
expect.objectContaining({
temperature
})
)
})
})
describe('Max Tokens Parameter', () => {
const maxTokensValues = [10, 50, 100, 500, 1000, 2000, 4000]
it.each(maxTokensValues)('should support maxTokens=%s', async (maxTokens) => {
const mockStream = {
textStream: (async function* () {
yield 'Response'
})(),
fullStream: (async function* () {
yield { type: 'text-delta', textDelta: 'Response' }
})()
}
vi.mocked(streamText).mockResolvedValue(mockStream as any)
await executor.streamText({
model: 'gpt-4',
messages: testMessages.simple,
maxOutputTokens: maxTokens
})
expect(streamText).toHaveBeenCalledWith(
expect.objectContaining({
maxOutputTokens: maxTokens
})
)
})
})
describe('Top P Parameter', () => {
const topPValues = [0.1, 0.3, 0.5, 0.7, 0.9, 0.95, 1.0]
it.each(topPValues)('should support topP=%s', async (topP) => {
const mockStream = {
textStream: (async function* () {
yield 'Response'
})(),
fullStream: (async function* () {
yield { type: 'text-delta', textDelta: 'Response' }
})()
}
vi.mocked(streamText).mockResolvedValue(mockStream as any)
await executor.streamText({
model: 'gpt-4',
messages: testMessages.simple,
topP
})
expect(streamText).toHaveBeenCalledWith(
expect.objectContaining({
topP
})
)
})
})
describe('Frequency and Presence Penalty', () => {
it('should support frequency penalty', async () => {
const penalties = [-2.0, -1.0, 0, 0.5, 1.0, 1.5, 2.0]
for (const frequencyPenalty of penalties) {
vi.clearAllMocks()
const mockStream = {
textStream: (async function* () {
yield 'Response'
})(),
fullStream: (async function* () {
yield { type: 'text-delta', textDelta: 'Response' }
})()
}
vi.mocked(streamText).mockResolvedValue(mockStream as any)
await executor.streamText({
model: 'gpt-4',
messages: testMessages.simple,
frequencyPenalty
})
expect(streamText).toHaveBeenCalledWith(
expect.objectContaining({
frequencyPenalty
})
)
}
})
it('should support presence penalty', async () => {
const penalties = [-2.0, -1.0, 0, 0.5, 1.0, 1.5, 2.0]
for (const presencePenalty of penalties) {
vi.clearAllMocks()
const mockStream = {
textStream: (async function* () {
yield 'Response'
})(),
fullStream: (async function* () {
yield { type: 'text-delta', textDelta: 'Response' }
})()
}
vi.mocked(streamText).mockResolvedValue(mockStream as any)
await executor.streamText({
model: 'gpt-4',
messages: testMessages.simple,
presencePenalty
})
expect(streamText).toHaveBeenCalledWith(
expect.objectContaining({
presencePenalty
})
)
}
})
it('should support both penalties together', async () => {
const mockStream = {
textStream: (async function* () {
yield 'Response'
})(),
fullStream: (async function* () {
yield { type: 'text-delta', textDelta: 'Response' }
})()
}
vi.mocked(streamText).mockResolvedValue(mockStream as any)
await executor.streamText({
model: 'gpt-4',
messages: testMessages.simple,
frequencyPenalty: 0.5,
presencePenalty: 0.5
})
expect(streamText).toHaveBeenCalledWith(
expect.objectContaining({
frequencyPenalty: 0.5,
presencePenalty: 0.5
})
)
})
})
describe('Seed Parameter', () => {
it('should support seed for deterministic output', async () => {
const seeds = [0, 12345, 67890, 999999]
for (const seed of seeds) {
vi.clearAllMocks()
const mockStream = {
textStream: (async function* () {
yield 'Response'
})(),
fullStream: (async function* () {
yield { type: 'text-delta', textDelta: 'Response' }
})()
}
vi.mocked(streamText).mockResolvedValue(mockStream as any)
await executor.streamText({
model: 'gpt-4',
messages: testMessages.simple,
seed
})
expect(streamText).toHaveBeenCalledWith(
expect.objectContaining({
seed
})
)
}
})
})
describe('Abort Signal', () => {
it('should support abort signal', async () => {
const abortController = new AbortController()
const mockStream = {
textStream: (async function* () {
yield 'Response'
})(),
fullStream: (async function* () {
yield { type: 'text-delta', textDelta: 'Response' }
})()
}
vi.mocked(streamText).mockResolvedValue(mockStream as any)
await executor.streamText({
model: 'gpt-4',
messages: testMessages.simple,
abortSignal: abortController.signal
})
expect(streamText).toHaveBeenCalledWith(
expect.objectContaining({
abortSignal: abortController.signal
})
)
})
it('should handle abort during streaming', async () => {
const abortController = new AbortController()
const mockStream = {
textStream: (async function* () {
yield 'Start'
// Simulate abort
abortController.abort()
throw new Error('Aborted')
})(),
fullStream: (async function* () {
yield { type: 'text-delta', textDelta: 'Start' }
throw new Error('Aborted')
})()
}
vi.mocked(streamText).mockResolvedValue(mockStream as any)
const result = await executor.streamText({
model: 'gpt-4',
messages: testMessages.simple,
abortSignal: abortController.signal
})
await expect(async () => {
// oxlint-disable-next-line no-unused-vars
for await (const _chunk of result.textStream) {
// Stream should be interrupted
}
}).rejects.toThrow('Aborted')
})
})
describe('Plugin Integration', () => {
it('should execute plugins during streaming', async () => {
const pluginCalls: string[] = []
const testPlugin: AiPlugin = {
name: 'test-plugin',
onRequestStart: vi.fn(async () => {
pluginCalls.push('onRequestStart')
}),
transformParams: vi.fn(async (params) => {
pluginCalls.push('transformParams')
return { ...params, temperature: 0.5 }
}),
onRequestEnd: vi.fn(async () => {
pluginCalls.push('onRequestEnd')
})
}
const executorWithPlugin = RuntimeExecutor.create('openai', mockProviderConfigs.openai, [testPlugin])
const mockStream = {
textStream: (async function* () {
yield 'Response'
})(),
fullStream: (async function* () {
yield { type: 'text-delta', textDelta: 'Response' }
})()
}
vi.mocked(streamText).mockResolvedValue(mockStream as any)
const result = await executorWithPlugin.streamText({
model: 'gpt-4',
messages: testMessages.simple
})
// Consume stream
// oxlint-disable-next-line no-unused-vars
for await (const _chunk of result.textStream) {
// Stream chunks
}
expect(pluginCalls).toContain('onRequestStart')
expect(pluginCalls).toContain('transformParams')
// Verify transformed parameters were used
expect(streamText).toHaveBeenCalledWith(
expect.objectContaining({
temperature: 0.5
})
)
})
})
describe('Full Stream with Finish Reason', () => {
it('should provide finish reason in full stream', async () => {
const mockStream = {
textStream: (async function* () {
yield 'Response'
})(),
fullStream: (async function* () {
yield { type: 'text-delta', textDelta: 'Response' }
yield {
type: 'finish',
finishReason: 'stop',
usage: { promptTokens: 5, completionTokens: 3, totalTokens: 8 }
}
})()
}
vi.mocked(streamText).mockResolvedValue(mockStream as any)
const result = await executor.streamText({
model: 'gpt-4',
messages: testMessages.simple
})
const fullChunks = await collectStreamChunks(result.fullStream)
expect(fullChunks).toHaveLength(2)
expect(fullChunks[0]).toEqual({ type: 'text-delta', textDelta: 'Response' })
expect(fullChunks[1]).toEqual({
type: 'finish',
finishReason: 'stop',
usage: { promptTokens: 5, completionTokens: 3, totalTokens: 8 }
})
})
})
describe('Error Handling', () => {
it('should handle streaming errors', async () => {
const error = new Error('Streaming failed')
vi.mocked(streamText).mockRejectedValue(error)
await expect(
executor.streamText({
model: 'gpt-4',
messages: testMessages.simple
})
).rejects.toThrow('Streaming failed')
})
it('should execute onError plugin hook on failure', async () => {
const error = new Error('Stream error')
vi.mocked(streamText).mockRejectedValue(error)
const errorPlugin: AiPlugin = {
name: 'error-handler',
onError: vi.fn()
}
const executorWithPlugin = RuntimeExecutor.create('openai', mockProviderConfigs.openai, [errorPlugin])
await expect(
executorWithPlugin.streamText({
model: 'gpt-4',
messages: testMessages.simple
})
).rejects.toThrow('Stream error')
expect(errorPlugin.onError).toHaveBeenCalledWith(
error,
expect.objectContaining({
providerId: 'openai',
modelId: 'gpt-4'
})
)
})
})
})

View File

@@ -235,6 +235,7 @@ export enum IpcChannel {
System_GetDeviceType = 'system:getDeviceType',
System_GetHostname = 'system:getHostname',
System_GetCpuName = 'system:getCpuName',
System_CheckGitBash = 'system:checkGitBash',
// DevTools
System_ToggleDevTools = 'system:toggleDevTools',

View File

@@ -4,3 +4,34 @@ export const defaultAppHeaders = () => {
'X-Title': 'Cherry Studio'
}
}
// The following two functions are not used for now.
// I may use them in the future, so I'm keeping them commented out. - by eurfelux
/**
* Converts an `undefined` value to `null`, otherwise returns the value as-is.
* @param value - The value to check
* @returns `null` if the input is `undefined`; otherwise the input value
*/
// export function toNullIfUndefined<T>(value: T | undefined): T | null {
// if (value === undefined) {
// return null
// } else {
// return value
// }
// }
/**
* Converts a `null` value to `undefined`, otherwise returns the value as-is.
* @param value - The value to check
* @returns `undefined` if the input is `null`; otherwise the input value
*/
// export function toUndefinedIfNull<T>(value: T | null): T | undefined {
// if (value === null) {
// return undefined
// } else {
// return value
// }
// }

View File

@@ -104,12 +104,6 @@ const router = express
logger.warn('No models available from providers', { filter })
}
logger.info('Models response ready', {
filter,
total: response.total,
modelIds: response.data.map((m) => m.id)
})
return res.json(response satisfies ApiModelsResponse)
} catch (error: any) {
logger.error('Error fetching models', { error })

View File

@@ -3,7 +3,6 @@ import { createServer } from 'node:http'
import { loggerService } from '@logger'
import { IpcChannel } from '@shared/IpcChannel'
import { agentService } from '../services/agents'
import { windowService } from '../services/WindowService'
import { app } from './app'
import { config } from './config'
@@ -32,11 +31,6 @@ export class ApiServer {
// Load config
const { port, host } = await config.load()
// Initialize AgentService
logger.info('Initializing AgentService')
await agentService.initialize()
logger.info('AgentService initialized')
// Create server with Express app
this.server = createServer(app)
this.applyServerTimeouts(this.server)

View File

@@ -32,7 +32,7 @@ export class ModelsService {
for (const model of models) {
const provider = providers.find((p) => p.id === model.provider)
logger.debug(`Processing model ${model.id}`)
// logger.debug(`Processing model ${model.id}`)
if (!provider) {
logger.debug(`Skipping model ${model.id}. Reason: Provider not found.`)
continue

View File

@@ -34,6 +34,7 @@ import { TrayService } from './services/TrayService'
import { versionService } from './services/VersionService'
import { windowService } from './services/WindowService'
import { initWebviewHotkeys } from './services/WebviewService'
import { runAsyncFunction } from './utils'
const logger = loggerService.withContext('MainEntry')
@@ -170,39 +171,33 @@ if (!app.requestSingleInstanceLock()) {
//start selection assistant service
initSelectionService()
// Initialize Agent Service
try {
await agentService.initialize()
logger.info('Agent service initialized successfully')
} catch (error: any) {
logger.error('Failed to initialize Agent service:', error)
}
runAsyncFunction(async () => {
// Start API server if enabled or if agents exist
try {
const config = await apiServerService.getCurrentConfig()
logger.info('API server config:', config)
// Check if there are any agents
let shouldStart = config.enabled
if (!shouldStart) {
try {
const { total } = await agentService.listAgents({ limit: 1 })
if (total > 0) {
shouldStart = true
logger.info(`Detected ${total} agent(s), auto-starting API server`)
}
} catch (error: any) {
logger.warn('Failed to check agent count:', error)
}
}
if (shouldStart) {
await apiServerService.start()
}
} catch (error: any) {
logger.error('Failed to check/start API server:', error)
}
})
})
registerProtocolClient(app)

View File

@@ -493,6 +493,44 @@ export function registerIpc(mainWindow: BrowserWindow, app: Electron.App) {
ipcMain.handle(IpcChannel.System_GetDeviceType, () => (isMac ? 'mac' : isWin ? 'windows' : 'linux'))
ipcMain.handle(IpcChannel.System_GetHostname, () => require('os').hostname())
ipcMain.handle(IpcChannel.System_GetCpuName, () => require('os').cpus()[0].model)
ipcMain.handle(IpcChannel.System_CheckGitBash, () => {
if (!isWin) {
return true // Non-Windows systems don't need Git Bash
}
try {
// Check common Git Bash installation paths
const commonPaths = [
path.join(process.env.ProgramFiles || 'C:\\Program Files', 'Git', 'bin', 'bash.exe'),
path.join(process.env['ProgramFiles(x86)'] || 'C:\\Program Files (x86)', 'Git', 'bin', 'bash.exe'),
path.join(process.env.LOCALAPPDATA || '', 'Programs', 'Git', 'bin', 'bash.exe')
]
// Check if any of the common paths exist
for (const bashPath of commonPaths) {
if (fs.existsSync(bashPath)) {
logger.debug('Git Bash found', { path: bashPath })
return true
}
}
// Fall back to checking PATH; a standard Git for Windows install ships bash.exe alongside git
const { execSync } = require('child_process')
try {
execSync('git --version', { stdio: 'ignore' })
logger.debug('Git found in PATH')
return true
} catch {
// Git not in PATH
}
logger.debug('Git Bash not found on Windows system')
return false
} catch (error) {
logger.error('Error checking Git Bash', error as Error)
return false
}
})
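// Illustrative renderer-side usage (a sketch; the preload wiring is assumed):
//   const hasGitBash = await ipcRenderer.invoke(IpcChannel.System_CheckGitBash)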
ipcMain.handle(IpcChannel.System_ToggleDevTools, (e) => {
const win = BrowserWindow.fromWebContents(e.sender)
win && win.webContents.toggleDevTools()

View File

@@ -1,17 +1,13 @@
import { type Client, createClient } from '@libsql/client'
import { loggerService } from '@logger'
import { mcpApiService } from '@main/apiServer/services/mcp'
import type { ModelValidationError } from '@main/apiServer/utils'
import { validateModelId } from '@main/apiServer/utils'
import type { AgentType, MCPTool, SlashCommand, Tool } from '@types'
import { objectKeys } from '@types'
import { drizzle, type LibSQLDatabase } from 'drizzle-orm/libsql'
import fs from 'fs'
import path from 'path'
import { MigrationService } from './database/MigrationService'
import * as schema from './database/schema'
import { dbPath } from './drizzle.config'
import { DatabaseManager } from './database/DatabaseManager'
import type { AgentModelField } from './errors'
import { AgentModelValidationError } from './errors'
import { builtinSlashCommands } from './services/claudecode/commands'
@@ -20,22 +16,16 @@ import { builtinTools } from './services/claudecode/tools'
const logger = loggerService.withContext('BaseService')
/**
* Base service class providing shared database connection and utilities
* for all agent-related services.
* Base service class providing shared utilities for all agent-related services.
*
* Features:
* - Programmatic schema management (no CLI dependencies)
* - Automatic table creation and migration
* - Schema version tracking and compatibility checks
* - Transaction-based operations for safety
* - Development vs production mode handling
* - Connection retry logic with exponential backoff
* - Database access through DatabaseManager singleton
* - JSON field serialization/deserialization
* - Path validation and creation
* - Model validation
* - MCP tools and slash commands listing
*/
export abstract class BaseService {
protected static client: Client | null = null
protected static db: LibSQLDatabase<typeof schema> | null = null
protected static isInitialized = false
protected static initializationPromise: Promise<void> | null = null
protected jsonFields: string[] = [
'tools',
'mcps',
@@ -45,23 +35,6 @@ export abstract class BaseService {
'slash_commands'
]
/**
* Initialize database with retry logic and proper error handling
*/
protected static async initialize(): Promise<void> {
// Return existing initialization if in progress
if (BaseService.initializationPromise) {
return BaseService.initializationPromise
}
if (BaseService.isInitialized) {
return
}
BaseService.initializationPromise = BaseService.performInitialization()
return BaseService.initializationPromise
}
public async listMcpTools(agentType: AgentType, ids?: string[]): Promise<Tool[]> {
const tools: Tool[] = []
if (agentType === 'claude-code') {
@@ -101,78 +74,13 @@ export abstract class BaseService {
return []
}
private static async performInitialization(): Promise<void> {
const maxRetries = 3
let lastError: Error
for (let attempt = 1; attempt <= maxRetries; attempt++) {
try {
logger.info(`Initializing Agent database at: ${dbPath} (attempt ${attempt}/${maxRetries})`)
// Ensure the database directory exists
const dbDir = path.dirname(dbPath)
if (!fs.existsSync(dbDir)) {
logger.info(`Creating database directory: ${dbDir}`)
fs.mkdirSync(dbDir, { recursive: true })
}
BaseService.client = createClient({
url: `file:${dbPath}`
})
BaseService.db = drizzle(BaseService.client, { schema })
// Run database migrations
const migrationService = new MigrationService(BaseService.db, BaseService.client)
await migrationService.runMigrations()
BaseService.isInitialized = true
logger.info('Agent database initialized successfully')
return
} catch (error) {
lastError = error as Error
logger.warn(`Database initialization attempt ${attempt} failed:`, lastError)
// Clean up on failure
if (BaseService.client) {
try {
BaseService.client.close()
} catch (closeError) {
logger.warn('Failed to close client during cleanup:', closeError as Error)
}
}
BaseService.client = null
BaseService.db = null
// Wait before retrying (exponential backoff)
if (attempt < maxRetries) {
const delay = Math.pow(2, attempt) * 1000 // 2s, 4s, 8s
logger.info(`Retrying in ${delay}ms...`)
await new Promise((resolve) => setTimeout(resolve, delay))
}
}
}
// All retries failed
BaseService.initializationPromise = null
logger.error('Failed to initialize Agent database after all retries:', lastError!)
throw lastError!
}
protected ensureInitialized(): void {
if (!BaseService.isInitialized || !BaseService.db || !BaseService.client) {
throw new Error('Database not initialized. Call initialize() first.')
}
}
protected get database(): LibSQLDatabase<typeof schema> {
this.ensureInitialized()
return BaseService.db!
}
protected get rawClient(): Client {
this.ensureInitialized()
return BaseService.client!
/**
* Get database instance
* Automatically waits for initialization to complete
*/
protected async getDatabase() {
const dbManager = await DatabaseManager.getInstance()
return dbManager.getDatabase()
}
protected serializeJsonFields(data: any): any {
@@ -284,7 +192,7 @@ export abstract class BaseService {
}
/**
* Force re-initialization (for development/testing)
* Validate agent model configuration
*/
protected async validateAgentModels(
agentType: AgentType,
@@ -325,22 +233,4 @@ export abstract class BaseService {
}
}
}
static async reinitialize(): Promise<void> {
BaseService.isInitialized = false
BaseService.initializationPromise = null
if (BaseService.client) {
try {
BaseService.client.close()
} catch (error) {
logger.warn('Failed to close client during reinitialize:', error as Error)
}
}
BaseService.client = null
BaseService.db = null
await BaseService.initialize()
}
}

View File

@@ -0,0 +1,156 @@
import { type Client, createClient } from '@libsql/client'
import { loggerService } from '@logger'
import type { LibSQLDatabase } from 'drizzle-orm/libsql'
import { drizzle } from 'drizzle-orm/libsql'
import fs from 'fs'
import path from 'path'
import { dbPath } from '../drizzle.config'
import { MigrationService } from './MigrationService'
import * as schema from './schema'
const logger = loggerService.withContext('DatabaseManager')
/**
* Database initialization state
*/
enum InitState {
INITIALIZING = 'initializing',
INITIALIZED = 'initialized',
FAILED = 'failed'
}
/**
* DatabaseManager - Singleton class for managing libsql database connections
*
* Responsibilities:
* - Single source of truth for database connection
* - Thread-safe initialization with state management
* - Automatic migration handling
* - Safe connection cleanup after failed initialization
* - Failed initialization can be retried by a later getInstance() call
* - Windows platform compatibility fixes
*/
export class DatabaseManager {
private static instance: DatabaseManager | null = null
private static initPromise: Promise<DatabaseManager> | null = null
private client: Client | null = null
private db: LibSQLDatabase<typeof schema> | null = null
private state: InitState = InitState.INITIALIZING
/**
 * Get the singleton instance (database initialization starts automatically).
 * Concurrent callers share a single in-flight initialization instead of racing
 * to create separate instances.
 */
public static async getInstance(): Promise<DatabaseManager> {
if (DatabaseManager.instance) {
return DatabaseManager.instance
}
if (!DatabaseManager.initPromise) {
DatabaseManager.initPromise = (async () => {
const instance = new DatabaseManager()
try {
await instance.initialize()
} catch (error) {
// Reset so that a later getInstance() call can retry initialization
DatabaseManager.initPromise = null
throw error
}
DatabaseManager.instance = instance
return instance
})()
}
return DatabaseManager.initPromise
}
/**
* Perform the actual initialization
*/
public async initialize(): Promise<void> {
if (this.state === InitState.INITIALIZED) {
return
}
try {
logger.info(`Initializing database at: ${dbPath}`)
// Ensure database directory exists
const dbDir = path.dirname(dbPath)
if (!fs.existsSync(dbDir)) {
logger.info(`Creating database directory: ${dbDir}`)
fs.mkdirSync(dbDir, { recursive: true })
}
// Treat a zero-byte database file as corrupted and remove it (seen on Windows after interrupted creation)
if (fs.existsSync(dbPath)) {
const stats = fs.statSync(dbPath)
if (stats.size === 0) {
logger.warn('Database file is empty, removing corrupted file')
fs.unlinkSync(dbPath)
}
}
// Create client with platform-specific options
this.client = createClient({
url: `file:${dbPath}`,
// intMode: 'number' helps avoid some Windows compatibility issues
intMode: 'number'
})
// Create drizzle instance
this.db = drizzle(this.client, { schema })
// Run migrations
const migrationService = new MigrationService(this.db, this.client)
await migrationService.runMigrations()
this.state = InitState.INITIALIZED
logger.info('Database initialized successfully')
} catch (error) {
const err = error as Error
logger.error('Database initialization failed:', {
error: err.message,
stack: err.stack
})
// Clean up failed initialization
this.cleanupFailedInit()
// Set failed state
this.state = InitState.FAILED
throw new Error(`Database initialization failed: ${err.message || 'Unknown error'}`)
}
}
/**
* Clean up after failed initialization
*/
private cleanupFailedInit(): void {
if (this.client) {
try {
// On Windows, closing a partially initialized client can crash
// Wrap in try-catch and ignore errors during cleanup
this.client.close()
} catch (error) {
logger.warn('Failed to close client during cleanup:', error as Error)
}
}
this.client = null
this.db = null
}
/**
* Get the database instance
* Safe to use once getInstance() has resolved
* @throws Error if the database has not been initialized
*/
public getDatabase(): LibSQLDatabase<typeof schema> {
if (!this.db) {
throw new Error('Database has not been initialized')
}
return this.db
}
/**
* Get the raw client (for advanced operations)
* Safe to use once getInstance() has resolved
* @throws Error if the client has not been initialized
*/
public async getClient(): Promise<Client> {
if (!this.client) {
throw new Error('Database client has not been initialized')
}
return this.client
}
/**
* Check if database is initialized
*/
public isInitialized(): boolean {
return this.state === InitState.INITIALIZED
}
}
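For orientation, a minimal call-site sketch of the singleton flow above (the import paths and the `agentsTable` query are illustrative assumptions, not code from this PR):

import { DatabaseManager } from './DatabaseManager' // path assumed
import { agentsTable } from './schema' // table name taken from the services below

async function listAgentRows() {
  // getInstance() resolves only after initialize() has finished, so
  // migrations are done before getDatabase() is called here.
  const manager = await DatabaseManager.getInstance()
  const db = manager.getDatabase()
  return db.select().from(agentsTable)
}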

View File

@@ -52,7 +52,7 @@ export class MigrationService {
// Get applied migrations
const appliedMigrations = hasMigrationsTable ? await this.getAppliedMigrations() : []
const appliedVersions = new Set(appliedMigrations.map((m) => Number(m.version)))
const appliedByVersion = new Map(appliedMigrations.map((m) => [Number(m.version), m.tag]))
const latestAppliedVersion = appliedMigrations.reduce(
(max, migration) => Math.max(max, Number(migration.version)),
@@ -62,10 +62,35 @@ export class MigrationService {
logger.info(`Latest applied migration: v${latestAppliedVersion}, latest available: v${latestJournalVersion}`)
// Find pending migrations (compare journal idx with stored version, which is the same value)
const pendingMigrations = journal.entries
.filter((entry) => !appliedVersions.has(entry.idx))
.sort((a, b) => a.idx - b.idx)
// Find pending migrations and tag mismatches
const pendingMigrations: MigrationJournal['entries'] = []
const mismatchedMigrations: Array<{ entry: MigrationJournal['entries'][0]; appliedTag: string }> = []
for (const entry of journal.entries) {
const appliedTag = appliedByVersion.get(entry.idx)
if (appliedTag === undefined) {
// Migration not applied yet
pendingMigrations.push(entry)
} else if (appliedTag !== entry.tag) {
// Migration tag mismatch - different migration was applied for this version
mismatchedMigrations.push({ entry, appliedTag })
}
}
// Handle tag mismatches by re-applying the correct migration
if (mismatchedMigrations.length > 0) {
for (const { entry, appliedTag } of mismatchedMigrations) {
logger.warn(
`Migration tag mismatch for version ${entry.idx}: ` +
`applied "${appliedTag}", expected "${entry.tag}". Will re-apply correct migration.`
)
// Add to pending so it gets executed
pendingMigrations.push(entry)
}
}
// Sort pending migrations by idx
pendingMigrations.sort((a, b) => a.idx - b.idx)
if (pendingMigrations.length === 0) {
logger.info('Database is up to date')
@@ -149,7 +174,14 @@ export class MigrationService {
throw new Error('Migrations table missing after executing migration; cannot record progress')
}
await this.db.insert(migrations).values(newMigration)
// Use upsert to handle tag mismatch case (same version, different tag)
await this.db
.insert(migrations)
.values(newMigration)
.onConflictDoUpdate({
target: migrations.version,
set: { tag: migration.tag, executedAt: Date.now() }
})
const executionTime = Date.now() - startTime
logger.info(`Migration ${migration.tag} completed in ${executionTime}ms`)
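To make the mismatch handling above concrete, a condensed sketch of the classification rule in isolation (the `JournalEntry`/`AppliedRecord` shapes are simplified stand-ins for the real types):

type JournalEntry = { idx: number; tag: string }
type AppliedRecord = { version: number; tag: string }

// A journal entry is (re-)applied when no record exists for its version,
// or when the recorded tag differs from the journal tag.
function pendingEntries(journal: JournalEntry[], applied: AppliedRecord[]): JournalEntry[] {
  const appliedByVersion = new Map(applied.map((m) => [m.version, m.tag]))
  return journal
    .filter((entry) => appliedByVersion.get(entry.idx) !== entry.tag)
    .sort((a, b) => a.idx - b.idx)
}

Note that `appliedByVersion.get(entry.idx) !== entry.tag` covers both branches: `undefined !== tag` for never-applied versions and a differing tag for mismatches, which is exactly why recording the re-applied migration needs the upsert on `migrations.version` rather than a plain insert.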

View File

@@ -7,8 +7,14 @@
* Schema evolution is handled by Drizzle Kit migrations.
*/
// Database Manager (Singleton)
export * from './DatabaseManager'
// Drizzle ORM schemas
export * from './schema'
// Repository helpers
export * from './sessionMessageRepository'
// Migration Service
export * from './MigrationService'

View File

@@ -15,26 +15,16 @@ import { sessionMessagesTable } from './schema'
const logger = loggerService.withContext('AgentMessageRepository')
type TxClient = any
export type PersistUserMessageParams = AgentMessageUserPersistPayload & {
sessionId: string
agentSessionId?: string
tx?: TxClient
}
export type PersistAssistantMessageParams = AgentMessageAssistantPersistPayload & {
sessionId: string
agentSessionId: string
tx?: TxClient
}
type PersistExchangeParams = AgentMessagePersistExchangePayload & {
tx?: TxClient
}
type PersistExchangeResult = AgentMessagePersistExchangeResult
class AgentMessageRepository extends BaseService {
private static instance: AgentMessageRepository | null = null
@@ -87,17 +77,13 @@ class AgentMessageRepository extends BaseService {
return deserialized
}
private getWriter(tx?: TxClient): TxClient {
return tx ?? this.database
}
private async findExistingMessageRow(
writer: TxClient,
sessionId: string,
role: string,
messageId: string
): Promise<SessionMessageRow | null> {
const candidateRows: SessionMessageRow[] = await writer
const database = await this.getDatabase()
const candidateRows: SessionMessageRow[] = await database
.select()
.from(sessionMessagesTable)
.where(and(eq(sessionMessagesTable.session_id, sessionId), eq(sessionMessagesTable.role, role)))
@@ -122,10 +108,7 @@ class AgentMessageRepository extends BaseService {
private async upsertMessage(
params: PersistUserMessageParams | PersistAssistantMessageParams
): Promise<AgentSessionMessageEntity> {
await AgentMessageRepository.initialize()
this.ensureInitialized()
const { sessionId, agentSessionId = '', payload, metadata, createdAt, tx } = params
const { sessionId, agentSessionId = '', payload, metadata, createdAt } = params
if (!payload?.message?.role) {
throw new Error('Message payload missing role')
@@ -135,18 +118,18 @@ class AgentMessageRepository extends BaseService {
throw new Error('Message payload missing id')
}
const writer = this.getWriter(tx)
const database = await this.getDatabase()
const now = createdAt ?? payload.message.createdAt ?? new Date().toISOString()
const serializedPayload = this.serializeMessage(payload)
const serializedMetadata = this.serializeMetadata(metadata)
const existingRow = await this.findExistingMessageRow(writer, sessionId, payload.message.role, payload.message.id)
const existingRow = await this.findExistingMessageRow(sessionId, payload.message.role, payload.message.id)
if (existingRow) {
const metadataToPersist = serializedMetadata ?? existingRow.metadata ?? undefined
const agentSessionToPersist = agentSessionId || existingRow.agent_session_id || ''
await writer
await database
.update(sessionMessagesTable)
.set({
content: serializedPayload,
@@ -175,7 +158,7 @@ class AgentMessageRepository extends BaseService {
updated_at: now
}
const [saved] = await writer.insert(sessionMessagesTable).values(insertData).returning()
const [saved] = await database.insert(sessionMessagesTable).values(insertData).returning()
return this.deserialize(saved)
}
@@ -188,49 +171,38 @@ class AgentMessageRepository extends BaseService {
return this.upsertMessage(params)
}
async persistExchange(params: PersistExchangeParams): Promise<PersistExchangeResult> {
await AgentMessageRepository.initialize()
this.ensureInitialized()
async persistExchange(params: AgentMessagePersistExchangePayload): Promise<AgentMessagePersistExchangeResult> {
const { sessionId, agentSessionId, user, assistant } = params
const result = await this.database.transaction(async (tx) => {
const exchangeResult: PersistExchangeResult = {}
const exchangeResult: AgentMessagePersistExchangeResult = {}
if (user?.payload) {
exchangeResult.userMessage = await this.persistUserMessage({
sessionId,
agentSessionId,
payload: user.payload,
metadata: user.metadata,
createdAt: user.createdAt,
tx
})
}
if (user?.payload) {
exchangeResult.userMessage = await this.persistUserMessage({
sessionId,
agentSessionId,
payload: user.payload,
metadata: user.metadata,
createdAt: user.createdAt
})
}
if (assistant?.payload) {
exchangeResult.assistantMessage = await this.persistAssistantMessage({
sessionId,
agentSessionId,
payload: assistant.payload,
metadata: assistant.metadata,
createdAt: assistant.createdAt,
tx
})
}
if (assistant?.payload) {
exchangeResult.assistantMessage = await this.persistAssistantMessage({
sessionId,
agentSessionId,
payload: assistant.payload,
metadata: assistant.metadata,
createdAt: assistant.createdAt
})
}
return exchangeResult
})
return result
return exchangeResult
}
async getSessionHistory(sessionId: string): Promise<AgentPersistedMessage[]> {
await AgentMessageRepository.initialize()
this.ensureInitialized()
try {
const rows = await this.database
const database = await this.getDatabase()
const rows = await database
.select()
.from(sessionMessagesTable)
.where(eq(sessionMessagesTable.session_id, sessionId))

View File

@@ -32,14 +32,8 @@ export class AgentService extends BaseService {
return AgentService.instance
}
async initialize(): Promise<void> {
await BaseService.initialize()
}
// Agent Methods
async createAgent(req: CreateAgentRequest): Promise<CreateAgentResponse> {
this.ensureInitialized()
const id = `agent_${Date.now()}_${Math.random().toString(36).substring(2, 11)}`
const now = new Date().toISOString()
@@ -75,8 +69,9 @@ export class AgentService extends BaseService {
updated_at: now
}
await this.database.insert(agentsTable).values(insertData)
const result = await this.database.select().from(agentsTable).where(eq(agentsTable.id, id)).limit(1)
const database = await this.getDatabase()
await database.insert(agentsTable).values(insertData)
const result = await database.select().from(agentsTable).where(eq(agentsTable.id, id)).limit(1)
if (!result[0]) {
throw new Error('Failed to create agent')
}
@@ -86,9 +81,8 @@ export class AgentService extends BaseService {
}
async getAgent(id: string): Promise<GetAgentResponse | null> {
this.ensureInitialized()
const result = await this.database.select().from(agentsTable).where(eq(agentsTable.id, id)).limit(1)
const database = await this.getDatabase()
const result = await database.select().from(agentsTable).where(eq(agentsTable.id, id)).limit(1)
if (!result[0]) {
return null
@@ -118,9 +112,9 @@ export class AgentService extends BaseService {
}
async listAgents(options: ListOptions = {}): Promise<{ agents: AgentEntity[]; total: number }> {
this.ensureInitialized() // Build query with pagination
const totalResult = await this.database.select({ count: count() }).from(agentsTable)
// Build query with pagination
const database = await this.getDatabase()
const totalResult = await database.select({ count: count() }).from(agentsTable)
const sortBy = options.sortBy || 'created_at'
const orderBy = options.orderBy || 'desc'
@@ -128,7 +122,7 @@ export class AgentService extends BaseService {
const sortField = agentsTable[sortBy]
const orderFn = orderBy === 'asc' ? asc : desc
const baseQuery = this.database.select().from(agentsTable).orderBy(orderFn(sortField))
const baseQuery = database.select().from(agentsTable).orderBy(orderFn(sortField))
const result =
options.limit !== undefined
@@ -151,8 +145,6 @@ export class AgentService extends BaseService {
updates: UpdateAgentRequest,
options: { replace?: boolean } = {}
): Promise<UpdateAgentResponse | null> {
this.ensureInitialized()
// Check if agent exists
const existing = await this.getAgent(id)
if (!existing) {
@@ -195,22 +187,21 @@ export class AgentService extends BaseService {
}
}
await this.database.update(agentsTable).set(updateData).where(eq(agentsTable.id, id))
const database = await this.getDatabase()
await database.update(agentsTable).set(updateData).where(eq(agentsTable.id, id))
return await this.getAgent(id)
}
async deleteAgent(id: string): Promise<boolean> {
this.ensureInitialized()
const result = await this.database.delete(agentsTable).where(eq(agentsTable.id, id))
const database = await this.getDatabase()
const result = await database.delete(agentsTable).where(eq(agentsTable.id, id))
return result.rowsAffected > 0
}
async agentExists(id: string): Promise<boolean> {
this.ensureInitialized()
const result = await this.database
const database = await this.getDatabase()
const result = await database
.select({ id: agentsTable.id })
.from(agentsTable)
.where(eq(agentsTable.id, id))

View File

@@ -104,14 +104,9 @@ export class SessionMessageService extends BaseService {
return SessionMessageService.instance
}
async initialize(): Promise<void> {
await BaseService.initialize()
}
async sessionMessageExists(id: number): Promise<boolean> {
this.ensureInitialized()
const result = await this.database
const database = await this.getDatabase()
const result = await database
.select({ id: sessionMessagesTable.id })
.from(sessionMessagesTable)
.where(eq(sessionMessagesTable.id, id))
@@ -124,10 +119,9 @@ export class SessionMessageService extends BaseService {
sessionId: string,
options: ListOptions = {}
): Promise<{ messages: AgentSessionMessageEntity[] }> {
this.ensureInitialized()
// Get messages with pagination
const baseQuery = this.database
const database = await this.getDatabase()
const baseQuery = database
.select()
.from(sessionMessagesTable)
.where(eq(sessionMessagesTable.session_id, sessionId))
@@ -146,9 +140,8 @@ export class SessionMessageService extends BaseService {
}
async deleteSessionMessage(sessionId: string, messageId: number): Promise<boolean> {
this.ensureInitialized()
const result = await this.database
const database = await this.getDatabase()
const result = await database
.delete(sessionMessagesTable)
.where(and(eq(sessionMessagesTable.id, messageId), eq(sessionMessagesTable.session_id, sessionId)))
@@ -160,8 +153,6 @@ export class SessionMessageService extends BaseService {
messageData: CreateSessionMessageRequest,
abortController: AbortController
): Promise<SessionStreamResult> {
this.ensureInitialized()
return await this.startSessionMessageStream(session, messageData, abortController)
}
@@ -270,10 +261,9 @@ export class SessionMessageService extends BaseService {
}
private async getLastAgentSessionId(sessionId: string): Promise<string> {
this.ensureInitialized()
try {
const result = await this.database
const database = await this.getDatabase()
const result = await database
.select({ agent_session_id: sessionMessagesTable.agent_session_id })
.from(sessionMessagesTable)
.where(and(eq(sessionMessagesTable.session_id, sessionId), not(eq(sessionMessagesTable.agent_session_id, ''))))

View File

@@ -30,10 +30,6 @@ export class SessionService extends BaseService {
return SessionService.instance
}
async initialize(): Promise<void> {
await BaseService.initialize()
}
/**
* Override BaseService.listSlashCommands to merge builtin and plugin commands
*/
@@ -84,13 +80,12 @@ export class SessionService extends BaseService {
agentId: string,
req: Partial<CreateSessionRequest> = {}
): Promise<GetAgentSessionResponse | null> {
this.ensureInitialized()
// Validate that the agent exists by querying the agents table directly,
// rather than importing AgentService (which would create a circular dependency)
const agents = await this.database.select().from(agentsTable).where(eq(agentsTable.id, agentId)).limit(1)
const database = await this.getDatabase()
const agents = await database.select().from(agentsTable).where(eq(agentsTable.id, agentId)).limit(1)
if (!agents[0]) {
throw new Error('Agent not found')
}
@@ -135,9 +130,10 @@ export class SessionService extends BaseService {
updated_at: now
}
await this.database.insert(sessionsTable).values(insertData)
const db = await this.getDatabase()
await db.insert(sessionsTable).values(insertData)
const result = await this.database.select().from(sessionsTable).where(eq(sessionsTable.id, id)).limit(1)
const result = await db.select().from(sessionsTable).where(eq(sessionsTable.id, id)).limit(1)
if (!result[0]) {
throw new Error('Failed to create session')
@@ -148,9 +144,8 @@ export class SessionService extends BaseService {
}
async getSession(agentId: string, id: string): Promise<GetAgentSessionResponse | null> {
this.ensureInitialized()
const result = await this.database
const database = await this.getDatabase()
const result = await database
.select()
.from(sessionsTable)
.where(and(eq(sessionsTable.id, id), eq(sessionsTable.agent_id, agentId)))
@@ -176,8 +171,6 @@ export class SessionService extends BaseService {
agentId?: string,
options: ListOptions = {}
): Promise<{ sessions: AgentSessionEntity[]; total: number }> {
this.ensureInitialized()
// Build where conditions
const whereConditions: SQL[] = []
if (agentId) {
@@ -192,16 +185,13 @@ export class SessionService extends BaseService {
: undefined
// Get total count
const totalResult = await this.database.select({ count: count() }).from(sessionsTable).where(whereClause)
const database = await this.getDatabase()
const totalResult = await database.select({ count: count() }).from(sessionsTable).where(whereClause)
const total = totalResult[0].count
// Build list query with pagination - sort by updated_at descending (latest first)
const baseQuery = this.database
.select()
.from(sessionsTable)
.where(whereClause)
.orderBy(desc(sessionsTable.updated_at))
const baseQuery = database.select().from(sessionsTable).where(whereClause).orderBy(desc(sessionsTable.updated_at))
const result =
options.limit !== undefined
@@ -220,8 +210,6 @@ export class SessionService extends BaseService {
id: string,
updates: UpdateSessionRequest
): Promise<UpdateSessionResponse | null> {
this.ensureInitialized()
// Check if session exists
const existing = await this.getSession(agentId, id)
if (!existing) {
@@ -262,15 +250,15 @@ export class SessionService extends BaseService {
}
}
await this.database.update(sessionsTable).set(updateData).where(eq(sessionsTable.id, id))
const database = await this.getDatabase()
await database.update(sessionsTable).set(updateData).where(eq(sessionsTable.id, id))
return await this.getSession(agentId, id)
}
async deleteSession(agentId: string, id: string): Promise<boolean> {
this.ensureInitialized()
const result = await this.database
const database = await this.getDatabase()
const result = await database
.delete(sessionsTable)
.where(and(eq(sessionsTable.id, id), eq(sessionsTable.agent_id, agentId)))
@@ -278,9 +266,8 @@ export class SessionService extends BaseService {
}
async sessionExists(agentId: string, id: string): Promise<boolean> {
this.ensureInitialized()
const result = await this.database
const database = await this.getDatabase()
const result = await database
.select({ id: sessionsTable.id })
.from(sessionsTable)
.where(and(eq(sessionsTable.id, id), eq(sessionsTable.agent_id, agentId)))

View File

@@ -21,6 +21,11 @@ describe('stripLocalCommandTags', () => {
'<local-command-stdout>line1</local-command-stdout>\nkeep\n<local-command-stderr>Error</local-command-stderr>'
expect(stripLocalCommandTags(input)).toBe('line1\nkeep\nError')
})
it('if no tags present, returns original string', () => {
const input = 'just some normal text'
expect(stripLocalCommandTags(input)).toBe(input)
})
})
describe('Claude → AiSDK transform', () => {
@@ -188,6 +193,111 @@ describe('Claude → AiSDK transform', () => {
expect(toolResult.output).toBe('ok')
})
it('handles tool calls without streaming events (no content_block_start/stop)', () => {
const state = new ClaudeStreamState({ agentSessionId: '12344' })
const parts: ReturnType<typeof transformSDKMessageToStreamParts>[number][] = []
const messages: SDKMessage[] = [
{
...baseStreamMetadata,
type: 'assistant',
uuid: uuid(20),
message: {
id: 'msg-tool-no-stream',
type: 'message',
role: 'assistant',
model: 'claude-test',
content: [
{
type: 'tool_use',
id: 'tool-read',
name: 'Read',
input: { file_path: '/test.txt' }
},
{
type: 'tool_use',
id: 'tool-bash',
name: 'Bash',
input: { command: 'ls -la' }
}
],
stop_reason: 'tool_use',
stop_sequence: null,
usage: {
input_tokens: 10,
output_tokens: 20
}
}
} as unknown as SDKMessage,
{
...baseStreamMetadata,
type: 'user',
uuid: uuid(21),
message: {
role: 'user',
content: [
{
type: 'tool_result',
tool_use_id: 'tool-read',
content: 'file contents',
is_error: false
}
]
}
} as SDKMessage,
{
...baseStreamMetadata,
type: 'user',
uuid: uuid(22),
message: {
role: 'user',
content: [
{
type: 'tool_result',
tool_use_id: 'tool-bash',
content: 'total 42\n...',
is_error: false
}
]
}
} as SDKMessage
]
for (const message of messages) {
const transformed = transformSDKMessageToStreamParts(message, state)
parts.push(...transformed)
}
const types = parts.map((part) => part.type)
expect(types).toEqual(['tool-call', 'tool-call', 'tool-result', 'tool-result'])
const toolCalls = parts.filter((part) => part.type === 'tool-call') as Extract<
(typeof parts)[number],
{ type: 'tool-call' }
>[]
expect(toolCalls).toHaveLength(2)
expect(toolCalls[0].toolName).toBe('Read')
expect(toolCalls[0].toolCallId).toBe('12344:tool-read')
expect(toolCalls[1].toolName).toBe('Bash')
expect(toolCalls[1].toolCallId).toBe('12344:tool-bash')
const toolResults = parts.filter((part) => part.type === 'tool-result') as Extract<
(typeof parts)[number],
{ type: 'tool-result' }
>[]
expect(toolResults).toHaveLength(2)
// This is the key assertion - toolName should NOT be 'unknown'
expect(toolResults[0].toolName).toBe('Read')
expect(toolResults[0].toolCallId).toBe('12344:tool-read')
expect(toolResults[0].input).toEqual({ file_path: '/test.txt' })
expect(toolResults[0].output).toBe('file contents')
expect(toolResults[1].toolName).toBe('Bash')
expect(toolResults[1].toolCallId).toBe('12344:tool-bash')
expect(toolResults[1].input).toEqual({ command: 'ls -la' })
expect(toolResults[1].output).toBe('total 42\n...')
})
it('handles streaming text completion', () => {
const state = new ClaudeStreamState({ agentSessionId: baseStreamMetadata.session_id })
const parts: ReturnType<typeof transformSDKMessageToStreamParts>[number][] = []
@@ -300,4 +410,87 @@ describe('Claude → AiSDK transform', () => {
expect(finishStep.finishReason).toBe('stop')
expect(finishStep.usage).toEqual({ inputTokens: 2, outputTokens: 4, totalTokens: 6 })
})
it('emits fallback text when Claude sends a snapshot instead of deltas', () => {
const state = new ClaudeStreamState({ agentSessionId: '12344' })
const parts: ReturnType<typeof transformSDKMessageToStreamParts>[number][] = []
const messages: SDKMessage[] = [
{
...baseStreamMetadata,
type: 'stream_event',
uuid: uuid(30),
event: {
type: 'message_start',
message: {
id: 'msg-fallback',
type: 'message',
role: 'assistant',
model: 'claude-test',
content: [],
stop_reason: null,
stop_sequence: null,
usage: {}
}
}
} as unknown as SDKMessage,
{
...baseStreamMetadata,
type: 'stream_event',
uuid: uuid(31),
event: {
type: 'content_block_start',
index: 0,
content_block: {
type: 'text',
text: ''
}
}
} as unknown as SDKMessage,
{
...baseStreamMetadata,
type: 'assistant',
uuid: uuid(32),
message: {
id: 'msg-fallback-content',
type: 'message',
role: 'assistant',
model: 'claude-test',
content: [
{
type: 'text',
text: 'Final answer without streaming deltas.'
}
],
stop_reason: 'end_turn',
stop_sequence: null,
usage: {
input_tokens: 3,
output_tokens: 7
}
}
} as unknown as SDKMessage
]
for (const message of messages) {
const transformed = transformSDKMessageToStreamParts(message, state)
parts.push(...transformed)
}
const types = parts.map((part) => part.type)
expect(types).toEqual(['start-step', 'text-start', 'text-delta', 'text-end', 'finish-step'])
const delta = parts.find((part) => part.type === 'text-delta') as Extract<
(typeof parts)[number],
{ type: 'text-delta' }
>
expect(delta.text).toBe('Final answer without streaming deltas.')
const finish = parts.find((part) => part.type === 'finish-step') as Extract<
(typeof parts)[number],
{ type: 'finish-step' }
>
expect(finish.usage).toEqual({ inputTokens: 3, outputTokens: 7, totalTokens: 10 })
expect(finish.finishReason).toBe('stop')
})
})

View File

@@ -153,6 +153,20 @@ export class ClaudeStreamState {
return this.blocksByIndex.get(index)
}
getFirstOpenTextBlock(): TextBlockState | undefined {
const candidates: TextBlockState[] = []
for (const block of this.blocksByIndex.values()) {
if (block.kind === 'text') {
candidates.push(block)
}
}
if (candidates.length === 0) {
return undefined
}
candidates.sort((a, b) => a.index - b.index)
return candidates[0]
}
getToolBlockById(toolCallId: string): ToolBlockState | undefined {
const index = this.toolIndexByNamespacedId.get(toolCallId)
if (index === undefined) return undefined
@@ -217,10 +231,10 @@ export class ClaudeStreamState {
* Persists the final input payload for a tool block once the provider signals
* completion so that downstream tool results can reference the original call.
*/
completeToolBlock(toolCallId: string, input: unknown, providerMetadata?: ProviderMetadata): void {
completeToolBlock(toolCallId: string, toolName: string, input: unknown, providerMetadata?: ProviderMetadata): void {
const block = this.getToolBlockByRawId(toolCallId)
this.registerToolCall(toolCallId, {
toolName: block?.toolName ?? 'unknown',
toolName,
input,
providerMetadata
})
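The new `toolName` parameter exists because, when the provider skips streaming events entirely, no tool block was ever opened, so the old `block?.toolName ?? 'unknown'` lookup had nothing to resolve. A short sketch of that non-streaming path (import path assumed; the constructor shape follows the tests earlier in this diff):

import { ClaudeStreamState } from './claudeStreamState' // path assumed

const state = new ClaudeStreamState({ agentSessionId: 'session-1' })
// No content_block_start arrived for this id, so no block is open; passing
// the name explicitly still registers it, and a later tool_result for
// 'tool-read' resolves to toolName 'Read' instead of 'unknown'.
state.completeToolBlock('tool-read', 'Read', { file_path: '/test.txt' })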

View File

@@ -2,7 +2,14 @@
import { EventEmitter } from 'node:events'
import { createRequire } from 'node:module'
import type { CanUseTool, McpHttpServerConfig, Options, SDKMessage } from '@anthropic-ai/claude-agent-sdk'
import type {
CanUseTool,
HookCallback,
McpHttpServerConfig,
Options,
PreToolUseHookInput,
SDKMessage
} from '@anthropic-ai/claude-agent-sdk'
import { query } from '@anthropic-ai/claude-agent-sdk'
import { loggerService } from '@logger'
import { config as apiConfigService } from '@main/apiServer/config'
@@ -157,6 +164,63 @@ class ClaudeCodeService implements AgentServiceInterface {
})
}
const preToolUseHook: HookCallback = async (input, toolUseID, options) => {
// Type guard to ensure we're handling PreToolUse event
if (input.hook_event_name !== 'PreToolUse') {
return {}
}
const hookInput = input as PreToolUseHookInput
const toolName = hookInput.tool_name
logger.debug('PreToolUse hook triggered', {
session_id: hookInput.session_id,
tool_name: hookInput.tool_name,
tool_use_id: toolUseID,
tool_input: hookInput.tool_input,
cwd: hookInput.cwd,
permission_mode: hookInput.permission_mode,
autoAllowTools: autoAllowTools
})
if (options?.signal?.aborted) {
logger.debug('PreToolUse hook signal already aborted; skipping tool use', {
tool_name: hookInput.tool_name
})
return {}
}
// Handle auto-approved tools here, since they never trigger canUseTool
const normalizedToolName = normalizeToolName(toolName)
if (toolUseID) {
const bypassAll = input.permission_mode === 'bypassPermissions'
const autoAllowed = autoAllowTools.has(toolName) || autoAllowTools.has(normalizedToolName)
if (bypassAll || autoAllowed) {
const namespacedToolCallId = buildNamespacedToolCallId(session.id, toolUseID)
logger.debug('handling auto approved tools', {
toolName,
normalizedToolName,
namespacedToolCallId,
permission_mode: input.permission_mode,
autoAllowTools
})
const isRecord = (v: unknown): v is Record<string, unknown> => {
return !!v && typeof v === 'object' && !Array.isArray(v)
}
const toolInput = isRecord(input.tool_input) ? input.tool_input : {}
await promptForToolApproval(toolName, toolInput, {
...options,
toolCallId: namespacedToolCallId,
autoApprove: true
})
}
}
// Return to proceed without modification
return {}
}
// Build SDK options from parameters
const options: Options = {
abortController,
@@ -180,7 +244,14 @@ class ClaudeCodeService implements AgentServiceInterface {
permissionMode: session.configuration?.permission_mode,
maxTurns: session.configuration?.max_turns,
allowedTools: session.allowed_tools,
canUseTool
canUseTool,
hooks: {
PreToolUse: [
{
hooks: [preToolUseHook]
}
]
}
}
if (session.accessible_paths.length > 1) {
@@ -414,23 +485,6 @@ class ClaudeCodeService implements AgentServiceInterface {
}
}
if (message.type === 'assistant' || message.type === 'user') {
logger.silly('claude response', {
message,
content: JSON.stringify(message.message.content)
})
} else if (message.type === 'stream_event') {
// logger.silly('Claude stream event', {
// message,
// event: JSON.stringify(message.event)
// })
} else {
logger.silly('Claude response', {
message,
event: JSON.stringify(message)
})
}
const chunks = transformSDKMessageToStreamParts(message, streamState)
for (const chunk of chunks) {
stream.emit('data', {

View File

@@ -31,6 +31,7 @@ type PendingPermissionRequest = {
abortListener?: () => void
originalInput: Record<string, unknown>
toolName: string
toolCallId?: string
}
type RendererPermissionRequestPayload = {
@@ -45,6 +46,7 @@ type RendererPermissionRequestPayload = {
createdAt: number
expiresAt: number
suggestions: PermissionUpdate[]
autoApprove?: boolean
}
type RendererPermissionResultPayload = {
@@ -52,6 +54,7 @@ type RendererPermissionResultPayload = {
behavior: ToolPermissionBehavior
message?: string
reason: 'response' | 'timeout' | 'aborted' | 'no-window'
toolCallId?: string
}
const pendingRequests = new Map<string, PendingPermissionRequest>()
@@ -145,7 +148,8 @@ const finalizeRequest = (
requestId,
behavior: update.behavior,
message: update.behavior === 'deny' ? update.message : undefined,
reason
reason,
toolCallId: pending.toolCallId
}
const dispatched = broadcastToRenderer(IpcChannel.AgentToolPermission_Result, resultPayload)
@@ -210,6 +214,7 @@ const ensureIpcHandlersRegistered = () => {
type PromptForToolApprovalOptions = {
signal: AbortSignal
suggestions?: PermissionUpdate[]
autoApprove?: boolean
// NOTICE: This ID is namespaced with session ID, not the raw SDK tool call ID.
// Format: `${sessionId}:${rawToolCallId}`, e.g., `session_123:WebFetch_0`
@@ -270,7 +275,8 @@ export async function promptForToolApproval(
inputPreview,
createdAt,
expiresAt,
suggestions: sanitizedSuggestions
suggestions: sanitizedSuggestions,
autoApprove: options.autoApprove
}
const defaultDenyUpdate: PermissionResult = { behavior: 'deny', message: 'Tool request aborted before user decision' }
@@ -299,7 +305,8 @@ export async function promptForToolApproval(
timeout,
originalInput: sanitizedInput,
toolName,
signal: options?.signal
signal: options?.signal,
toolCallId: options.toolCallId
}
if (options?.signal) {

View File

@@ -110,7 +110,7 @@ const sdkMessageToProviderMetadata = (message: SDKMessage): ProviderMetadata =>
* blocks across calls so that incremental deltas can be correlated correctly.
*/
export function transformSDKMessageToStreamParts(sdkMessage: SDKMessage, state: ClaudeStreamState): AgentStreamPart[] {
logger.silly('Transforming SDKMessage', { message: sdkMessage })
logger.silly('Transforming SDKMessage', { message: JSON.stringify(sdkMessage) })
switch (sdkMessage.type) {
case 'assistant':
return handleAssistantMessage(sdkMessage, state)
@@ -186,14 +186,13 @@ function handleAssistantMessage(
for (const block of content) {
switch (block.type) {
case 'text':
if (!isStreamingActive) {
const sanitizedText = stripLocalCommandTags(block.text)
if (sanitizedText) {
textBlocks.push(sanitizedText)
}
case 'text': {
const sanitizedText = stripLocalCommandTags(block.text)
if (sanitizedText) {
textBlocks.push(sanitizedText)
}
break
}
case 'tool_use':
handleAssistantToolUse(block as ToolUseContent, providerMetadata, state, chunks)
break
@@ -203,7 +202,16 @@ function handleAssistantMessage(
}
}
if (!isStreamingActive && textBlocks.length > 0) {
if (textBlocks.length === 0) {
return chunks
}
const combinedText = textBlocks.join('')
if (!combinedText) {
return chunks
}
if (!isStreamingActive) {
const id = message.uuid?.toString() || generateMessageId()
state.beginStep()
chunks.push({
@@ -219,7 +227,7 @@ function handleAssistantMessage(
chunks.push({
type: 'text-delta',
id,
text: textBlocks.join(''),
text: combinedText,
providerMetadata
})
chunks.push({
@@ -230,7 +238,27 @@ function handleAssistantMessage(
return finalizeNonStreamingStep(message, state, chunks)
}
return chunks
const existingTextBlock = state.getFirstOpenTextBlock()
const fallbackId = existingTextBlock?.id || message.uuid?.toString() || generateMessageId()
if (!existingTextBlock) {
chunks.push({
type: 'text-start',
id: fallbackId,
providerMetadata
})
}
chunks.push({
type: 'text-delta',
id: fallbackId,
text: combinedText,
providerMetadata
})
chunks.push({
type: 'text-end',
id: fallbackId,
providerMetadata
})
return finalizeNonStreamingStep(message, state, chunks)
}
/**
@@ -252,7 +280,7 @@ function handleAssistantToolUse(
providerExecuted: true,
providerMetadata
})
state.completeToolBlock(block.id, block.input, providerMetadata)
state.completeToolBlock(block.id, block.name, block.input, providerMetadata)
}
/**
@@ -459,6 +487,9 @@ function handleStreamEvent(
}
case 'message_stop': {
if (!state.hasActiveStep()) {
break
}
const pending = state.getPendingUsage()
chunks.push({
type: 'finish-step',

View File

@@ -122,7 +122,8 @@ const api = {
system: {
getDeviceType: () => ipcRenderer.invoke(IpcChannel.System_GetDeviceType),
getHostname: () => ipcRenderer.invoke(IpcChannel.System_GetHostname),
getCpuName: () => ipcRenderer.invoke(IpcChannel.System_GetCpuName)
getCpuName: () => ipcRenderer.invoke(IpcChannel.System_GetCpuName),
checkGitBash: (): Promise<boolean> => ipcRenderer.invoke(IpcChannel.System_CheckGitBash)
},
devTools: {
toggle: () => ipcRenderer.invoke(IpcChannel.System_ToggleDevTools)

View File

@@ -386,14 +386,13 @@ export class AiSdkToChunkAdapter {
case 'error':
this.onChunk({
type: ChunkType.ERROR,
error:
chunk.error instanceof AISDKError
? chunk.error
: new ProviderSpecificError({
message: formatErrorMessage(chunk.error),
provider: 'unknown',
cause: chunk.error
})
error: AISDKError.isInstance(chunk.error)
? chunk.error
: new ProviderSpecificError({
message: formatErrorMessage(chunk.error),
provider: 'unknown',
cause: chunk.error
})
})
break
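The switch from `instanceof` to `AISDKError.isInstance` matters when the bundle ends up with more than one copy of the SDK: each copy carries its own `AISDKError` class identity, so `instanceof` can reject a genuine SDK error. A minimal sketch (the duplicated-dependency scenario is hypothetical):

import { AISDKError } from 'ai'

// isInstance checks a marker on the error object rather than the prototype
// chain, so it still matches errors created by a duplicated SDK copy.
function isSdkError(err: unknown): err is AISDKError {
  return AISDKError.isInstance(err)
}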

View File

@@ -32,6 +32,7 @@ import {
prepareSpecialProviderConfig,
providerToAiSdkConfig
} from './provider/providerConfig'
import type { AiSdkConfig } from './types'
const logger = loggerService.withContext('ModernAiProvider')
@@ -44,7 +45,7 @@ export type ModernAiProviderConfig = AiSdkMiddlewareConfig & {
export default class ModernAiProvider {
private legacyProvider: LegacyAiProvider
private config?: ReturnType<typeof providerToAiSdkConfig>
private config?: AiSdkConfig
private actualProvider: Provider
private model?: Model
private localProvider: Awaited<AiSdkProvider> | null = null
@@ -89,6 +90,11 @@ export default class ModernAiProvider {
// Regenerate the provider config on every request so that API key rotation takes effect
this.config = providerToAiSdkConfig(this.actualProvider, this.model)
logger.debug('Generated provider config for completions', this.config)
// Guard against a missing config
if (!this.config) {
throw new Error('Provider config is undefined; cannot proceed with completions')
}
if (SUPPORTED_IMAGE_ENDPOINT_LIST.includes(this.config.options.endpoint)) {
providerConfig.isImageGenerationEndpoint = true
}
@@ -149,7 +155,8 @@ export default class ModernAiProvider {
params: StreamTextParams,
config: ModernAiProviderConfig
): Promise<CompletionsResult> {
if (config.isImageGenerationEndpoint) {
// ai-gateway is not an image/generation endpoint, so skip the legacy path for it for now
if (config.isImageGenerationEndpoint && config.provider!.id !== SystemProviderIds['ai-gateway']) {
// Use the legacy implementation for image generation (supports image editing and other advanced features)
if (!config.uiMessages) {
throw new Error('uiMessages is required for image generation endpoint')
@@ -463,8 +470,13 @@ export default class ModernAiProvider {
// If the modern AI SDK is supported, use the modern implementation
if (isModernSdkSupported(this.actualProvider)) {
try {
// Ensure config is defined
if (!this.config) {
throw new Error('Provider config is undefined; cannot proceed with generateImage')
}
// Ensure the local provider has been created
if (!this.localProvider) {
if (!this.localProvider && this.config) {
this.localProvider = await createAiSdkProvider(this.config)
if (!this.localProvider) {
throw new Error('Local provider not created')

View File

@@ -1,6 +1,6 @@
import { loggerService } from '@logger'
import { isNewApiProvider } from '@renderer/config/providers'
import type { Provider } from '@renderer/types'
import { isNewApiProvider } from '@renderer/utils/provider'
import { AihubmixAPIClient } from './aihubmix/AihubmixAPIClient'
import { AnthropicAPIClient } from './anthropic/AnthropicAPIClient'

View File

@@ -7,7 +7,6 @@ import {
isSupportFlexServiceTierModel
} from '@renderer/config/models'
import { REFERENCE_PROMPT } from '@renderer/config/prompts'
import { isSupportServiceTierProvider } from '@renderer/config/providers'
import { getLMStudioKeepAliveTime } from '@renderer/hooks/useLMStudio'
import { getAssistantSettings } from '@renderer/services/AssistantService'
import type {
@@ -19,7 +18,6 @@ import type {
MCPToolResponse,
MemoryItem,
Model,
OpenAIVerbosity,
Provider,
ToolCallResponse,
WebSearchProviderResponse,
@@ -33,6 +31,7 @@ import {
OpenAIServiceTiers,
SystemProviderIds
} from '@renderer/types'
import type { OpenAIVerbosity } from '@renderer/types/aiCoreTypes'
import type { Message } from '@renderer/types/newMessage'
import type {
RequestOptions,
@@ -48,6 +47,7 @@ import type {
import { isJSON, parseJSON } from '@renderer/utils'
import { addAbortController, removeAbortController } from '@renderer/utils/abortController'
import { findFileBlocks, getMainTextContent } from '@renderer/utils/messageUtils/find'
import { isSupportServiceTierProvider } from '@renderer/utils/provider'
import { defaultTimeout } from '@shared/config/constant'
import { defaultAppHeaders } from '@shared/utils'
import { isEmpty } from 'lodash'

View File

@@ -58,10 +58,27 @@ vi.mock('../aws/AwsBedrockAPIClient', () => ({
AwsBedrockAPIClient: vi.fn().mockImplementation(() => ({}))
}))
vi.mock('@renderer/services/AssistantService.ts', () => ({
getDefaultAssistant: () => {
return {
id: 'default',
name: 'default',
emoji: '😀',
prompt: '',
topics: [],
messages: [],
type: 'assistant',
regularPhrases: [],
settings: {}
}
}
}))
// Mock the models config to prevent circular dependency issues
vi.mock('@renderer/config/models', () => ({
findTokenLimit: vi.fn(),
isReasoningModel: vi.fn(),
isOpenAILLMModel: vi.fn(),
SYSTEM_MODELS: {
silicon: [],
defaultModel: []

View File

@@ -1,7 +1,8 @@
import { GoogleGenAI } from '@google/genai'
import { loggerService } from '@logger'
import { createVertexProvider, isVertexAIConfigured, isVertexProvider } from '@renderer/hooks/useVertexAI'
import { createVertexProvider, isVertexAIConfigured } from '@renderer/hooks/useVertexAI'
import type { Model, Provider, VertexProvider } from '@renderer/types'
import { isVertexProvider } from '@renderer/utils/provider'
import { isEmpty } from 'lodash'
import { AnthropicVertexClient } from '../anthropic/AnthropicVertexClient'

View File

@@ -10,7 +10,6 @@ import { DEFAULT_MAX_TOKENS } from '@renderer/config/constant'
import {
findTokenLimit,
GEMINI_FLASH_MODEL_REGEX,
getOpenAIWebSearchParams,
getThinkModelType,
isClaudeReasoningModel,
isDeepSeekHybridInferenceModel,
@@ -40,12 +39,6 @@ import {
MODEL_SUPPORTED_REASONING_EFFORT,
ZHIPU_RESULT_TOKENS
} from '@renderer/config/models'
import {
isSupportArrayContentProvider,
isSupportDeveloperRoleProvider,
isSupportEnableThinkingProvider,
isSupportStreamOptionsProvider
} from '@renderer/config/providers'
import { mapLanguageToQwenMTModel } from '@renderer/config/translate'
import { processPostsuffixQwen3Model, processReqMessages } from '@renderer/services/ModelMessageService'
import { estimateTextTokens } from '@renderer/services/TokenService'
@@ -89,6 +82,12 @@ import {
openAIToolsToMcpTool
} from '@renderer/utils/mcp-tools'
import { findFileBlocks, findImageBlocks } from '@renderer/utils/messageUtils/find'
import {
isSupportArrayContentProvider,
isSupportDeveloperRoleProvider,
isSupportEnableThinkingProvider,
isSupportStreamOptionsProvider
} from '@renderer/utils/provider'
import { t } from 'i18next'
import type { GenericChunk } from '../../middleware/schemas'
@@ -743,7 +742,7 @@ export class OpenAIAPIClient extends OpenAIBaseClient<
: {}),
...this.getProviderSpecificParameters(assistant, model),
...reasoningEffort,
...getOpenAIWebSearchParams(model, enableWebSearch),
// ...getOpenAIWebSearchParams(model, enableWebSearch),
// OpenRouter usage tracking
...(this.provider.id === 'openrouter' ? { usage: { include: true } } : {}),
...extra_body,

View File

@@ -12,7 +12,6 @@ import {
isSupportVerbosityModel,
isVisionModel
} from '@renderer/config/models'
import { isSupportDeveloperRoleProvider } from '@renderer/config/providers'
import { estimateTextTokens } from '@renderer/services/TokenService'
import type {
FileMetadata,
@@ -43,6 +42,7 @@ import {
openAIToolsToMcpTool
} from '@renderer/utils/mcp-tools'
import { findFileBlocks, findImageBlocks } from '@renderer/utils/messageUtils/find'
import { isSupportDeveloperRoleProvider } from '@renderer/utils/provider'
import { MB } from '@shared/config/constant'
import { t } from 'i18next'
import { isEmpty } from 'lodash'

View File

@@ -1,6 +1,7 @@
import { loggerService } from '@logger'
import { isZhipuModel } from '@renderer/config/models'
import { getStoreProviders } from '@renderer/hooks/useStore'
import { getDefaultModel } from '@renderer/services/AssistantService'
import type { Chunk } from '@renderer/types/chunk'
import type { CompletionsParams, CompletionsResult } from '../schemas'
@@ -66,7 +67,7 @@ export const ErrorHandlerMiddleware =
}
function handleError(error: any, params: CompletionsParams): any {
if (isZhipuModel(params.assistant.model) && error.status && !params.enableGenerateImage) {
if (isZhipuModel(params.assistant.model || getDefaultModel()) && error.status && !params.enableGenerateImage) {
return handleZhipuError(error)
}

View File

@@ -1,10 +1,10 @@
import type { WebSearchPluginConfig } from '@cherrystudio/ai-core/built-in/plugins'
import { loggerService } from '@logger'
import { isSupportedThinkingTokenQwenModel } from '@renderer/config/models'
import { isSupportEnableThinkingProvider } from '@renderer/config/providers'
import type { MCPTool } from '@renderer/types'
import { type Assistant, type Message, type Model, type Provider } from '@renderer/types'
import { type Assistant, type Message, type Model, type Provider, SystemProviderIds } from '@renderer/types'
import type { Chunk } from '@renderer/types/chunk'
import { isSupportEnableThinkingProvider } from '@renderer/utils/provider'
import type { LanguageModelMiddleware } from 'ai'
import { extractReasoningMiddleware, simulateStreamingMiddleware } from 'ai'
import { isEmpty } from 'lodash'
@@ -12,6 +12,7 @@ import { isEmpty } from 'lodash'
import { isOpenRouterGeminiGenerateImageModel } from '../utils/image'
import { noThinkMiddleware } from './noThinkMiddleware'
import { openrouterGenerateImageMiddleware } from './openrouterGenerateImageMiddleware'
import { openrouterReasoningMiddleware } from './openrouterReasoningMiddleware'
import { qwenThinkingMiddleware } from './qwenThinkingMiddleware'
import { toolChoiceMiddleware } from './toolChoiceMiddleware'
@@ -217,6 +218,14 @@ function addProviderSpecificMiddlewares(builder: AiSdkMiddlewareBuilder, config:
middleware: noThinkMiddleware()
})
}
if (config.provider.id === SystemProviderIds.openrouter && config.enableReasoning) {
builder.add({
name: 'openrouter-reasoning-redaction',
middleware: openrouterReasoningMiddleware()
})
logger.debug('Added OpenRouter reasoning redaction middleware')
}
}
/**

View File

@@ -0,0 +1,50 @@
import type { LanguageModelV2StreamPart } from '@ai-sdk/provider'
import type { LanguageModelMiddleware } from 'ai'
/**
* https://openrouter.ai/docs/docs/best-practices/reasoning-tokens#example-preserving-reasoning-blocks-with-openrouter-and-claude
*
* @returns LanguageModelMiddleware - a middleware that strips '[REDACTED]' placeholders from reasoning blocks
*/
export function openrouterReasoningMiddleware(): LanguageModelMiddleware {
const REDACTED_BLOCK = '[REDACTED]'
return {
middlewareVersion: 'v2',
wrapGenerate: async ({ doGenerate }) => {
const { content, ...rest } = await doGenerate()
const modifiedContent = content.map((part) => {
if (part.type === 'reasoning' && part.text.includes(REDACTED_BLOCK)) {
return {
...part,
text: part.text.replaceAll(REDACTED_BLOCK, '') // replaceAll: the placeholder may occur more than once
}
}
return part
})
return { content: modifiedContent, ...rest }
},
wrapStream: async ({ doStream }) => {
const { stream, ...rest } = await doStream()
return {
stream: stream.pipeThrough(
new TransformStream<LanguageModelV2StreamPart, LanguageModelV2StreamPart>({
transform(
chunk: LanguageModelV2StreamPart,
controller: TransformStreamDefaultController<LanguageModelV2StreamPart>
) {
if (chunk.type === 'reasoning-delta' && chunk.delta.includes(REDACTED_BLOCK)) {
controller.enqueue({
...chunk,
delta: chunk.delta.replaceAll(REDACTED_BLOCK, '')
})
} else {
controller.enqueue(chunk)
}
}
})
),
...rest
}
}
}
}
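A usage sketch for wiring this middleware manually, outside the builder shown earlier (the OpenRouter provider package and the model id are assumptions for illustration):

import { wrapLanguageModel } from 'ai'
import { createOpenRouter } from '@openrouter/ai-sdk-provider' // provider package assumed

import { openrouterReasoningMiddleware } from './openrouterReasoningMiddleware'

const openrouter = createOpenRouter({ apiKey: process.env.OPENROUTER_API_KEY! })

// Reasoning text reaching the UI no longer contains '[REDACTED]' placeholders.
const model = wrapLanguageModel({
  model: openrouter('anthropic/claude-sonnet-4'), // model id illustrative
  middleware: openrouterReasoningMiddleware()
})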

View File

@@ -0,0 +1,234 @@
import type { Message, Model } from '@renderer/types'
import type { FileMetadata } from '@renderer/types/file'
import { FileTypes } from '@renderer/types/file'
import {
AssistantMessageStatus,
type FileMessageBlock,
type ImageMessageBlock,
MessageBlockStatus,
MessageBlockType,
type ThinkingMessageBlock,
UserMessageStatus
} from '@renderer/types/newMessage'
import { beforeEach, describe, expect, it, vi } from 'vitest'
const { convertFileBlockToFilePartMock, convertFileBlockToTextPartMock } = vi.hoisted(() => ({
convertFileBlockToFilePartMock: vi.fn(),
convertFileBlockToTextPartMock: vi.fn()
}))
vi.mock('../fileProcessor', () => ({
convertFileBlockToFilePart: convertFileBlockToFilePartMock,
convertFileBlockToTextPart: convertFileBlockToTextPartMock
}))
const visionModelIds = new Set(['gpt-4o-mini', 'qwen-image-edit'])
const imageEnhancementModelIds = new Set(['qwen-image-edit'])
vi.mock('@renderer/config/models', () => ({
isVisionModel: (model: Model) => visionModelIds.has(model.id),
isImageEnhancementModel: (model: Model) => imageEnhancementModelIds.has(model.id)
}))
type MockableMessage = Message & {
__mockContent?: string
__mockFileBlocks?: FileMessageBlock[]
__mockImageBlocks?: ImageMessageBlock[]
__mockThinkingBlocks?: ThinkingMessageBlock[]
}
vi.mock('@renderer/utils/messageUtils/find', () => ({
getMainTextContent: (message: Message) => (message as MockableMessage).__mockContent ?? '',
findFileBlocks: (message: Message) => (message as MockableMessage).__mockFileBlocks ?? [],
findImageBlocks: (message: Message) => (message as MockableMessage).__mockImageBlocks ?? [],
findThinkingBlocks: (message: Message) => (message as MockableMessage).__mockThinkingBlocks ?? []
}))
import { convertMessagesToSdkMessages, convertMessageToSdkParam } from '../messageConverter'
let messageCounter = 0
let blockCounter = 0
const createModel = (overrides: Partial<Model> = {}): Model => ({
id: 'gpt-4o-mini',
name: 'GPT-4o mini',
provider: 'openai',
group: 'openai',
...overrides
})
const createMessage = (role: Message['role']): MockableMessage =>
({
id: `message-${++messageCounter}`,
role,
assistantId: 'assistant-1',
topicId: 'topic-1',
createdAt: new Date(2024, 0, 1, 0, 0, messageCounter).toISOString(),
status: role === 'assistant' ? AssistantMessageStatus.SUCCESS : UserMessageStatus.SUCCESS,
blocks: []
}) as MockableMessage
const createFileBlock = (
messageId: string,
overrides: Partial<Omit<FileMessageBlock, 'file' | 'messageId' | 'type'>> & { file?: Partial<FileMetadata> } = {}
): FileMessageBlock => {
const { file, ...blockOverrides } = overrides
const timestamp = new Date(2024, 0, 1, 0, 0, ++blockCounter).toISOString()
return {
id: blockOverrides.id ?? `file-block-${blockCounter}`,
messageId,
type: MessageBlockType.FILE,
createdAt: blockOverrides.createdAt ?? timestamp,
status: blockOverrides.status ?? MessageBlockStatus.SUCCESS,
file: {
id: file?.id ?? `file-${blockCounter}`,
name: file?.name ?? 'document.txt',
origin_name: file?.origin_name ?? 'document.txt',
path: file?.path ?? '/tmp/document.txt',
size: file?.size ?? 1024,
ext: file?.ext ?? '.txt',
type: file?.type ?? FileTypes.TEXT,
created_at: file?.created_at ?? timestamp,
count: file?.count ?? 1,
...file
},
...blockOverrides
}
}
const createImageBlock = (
messageId: string,
overrides: Partial<Omit<ImageMessageBlock, 'type' | 'messageId'>> = {}
): ImageMessageBlock => ({
id: overrides.id ?? `image-block-${++blockCounter}`,
messageId,
type: MessageBlockType.IMAGE,
createdAt: overrides.createdAt ?? new Date(2024, 0, 1, 0, 0, blockCounter).toISOString(),
status: overrides.status ?? MessageBlockStatus.SUCCESS,
url: overrides.url ?? 'https://example.com/image.png',
...overrides
})
describe('messageConverter', () => {
beforeEach(() => {
convertFileBlockToFilePartMock.mockReset()
convertFileBlockToTextPartMock.mockReset()
convertFileBlockToFilePartMock.mockResolvedValue(null)
convertFileBlockToTextPartMock.mockResolvedValue(null)
messageCounter = 0
blockCounter = 0
})
describe('convertMessageToSdkParam', () => {
it('includes text and image parts for user messages on vision models', async () => {
const model = createModel()
const message = createMessage('user')
message.__mockContent = 'Describe this picture'
message.__mockImageBlocks = [createImageBlock(message.id, { url: 'https://example.com/cat.png' })]
const result = await convertMessageToSdkParam(message, true, model)
expect(result).toEqual({
role: 'user',
content: [
{ type: 'text', text: 'Describe this picture' },
{ type: 'image', image: 'https://example.com/cat.png' }
]
})
})
it('returns file instructions as a system message when native uploads succeed', async () => {
const model = createModel()
const message = createMessage('user')
message.__mockContent = 'Summarize the PDF'
message.__mockFileBlocks = [createFileBlock(message.id)]
convertFileBlockToFilePartMock.mockResolvedValueOnce({
type: 'file',
filename: 'document.pdf',
mediaType: 'application/pdf',
data: 'fileid://remote-file'
})
const result = await convertMessageToSdkParam(message, false, model)
expect(result).toEqual([
{
role: 'system',
content: 'fileid://remote-file'
},
{
role: 'user',
content: [{ type: 'text', text: 'Summarize the PDF' }]
}
])
})
})
describe('convertMessagesToSdkMessages', () => {
it('appends assistant images to the final user message for image enhancement models', async () => {
const model = createModel({ id: 'qwen-image-edit', name: 'Qwen Image Edit', provider: 'qwen', group: 'qwen' })
const initialUser = createMessage('user')
initialUser.__mockContent = 'Start editing'
const assistant = createMessage('assistant')
assistant.__mockContent = 'Here is the current preview'
assistant.__mockImageBlocks = [createImageBlock(assistant.id, { url: 'https://example.com/preview.png' })]
const finalUser = createMessage('user')
finalUser.__mockContent = 'Increase the brightness'
const result = await convertMessagesToSdkMessages([initialUser, assistant, finalUser], model)
expect(result).toEqual([
{
role: 'assistant',
content: [{ type: 'text', text: 'Here is the current preview' }]
},
{
role: 'user',
content: [
{ type: 'text', text: 'Increase the brightness' },
{ type: 'image', image: 'https://example.com/preview.png' }
]
}
])
})
it('preserves preceding system instructions when building enhancement payloads', async () => {
const model = createModel({ id: 'qwen-image-edit', name: 'Qwen Image Edit', provider: 'qwen', group: 'qwen' })
const fileUser = createMessage('user')
fileUser.__mockContent = 'Use this document as inspiration'
fileUser.__mockFileBlocks = [createFileBlock(fileUser.id, { file: { ext: '.pdf', type: FileTypes.DOCUMENT } })]
convertFileBlockToFilePartMock.mockResolvedValueOnce({
type: 'file',
filename: 'reference.pdf',
mediaType: 'application/pdf',
data: 'fileid://reference'
})
const assistant = createMessage('assistant')
assistant.__mockContent = 'Generated previews ready'
assistant.__mockImageBlocks = [createImageBlock(assistant.id, { url: 'https://example.com/reference.png' })]
const finalUser = createMessage('user')
finalUser.__mockContent = 'Apply the edits'
const result = await convertMessagesToSdkMessages([fileUser, assistant, finalUser], model)
expect(result).toEqual([
{ role: 'system', content: 'fileid://reference' },
{
role: 'assistant',
content: [{ type: 'text', text: 'Generated previews ready' }]
},
{
role: 'user',
content: [
{ type: 'text', text: 'Apply the edits' },
{ type: 'image', image: 'https://example.com/reference.png' }
]
}
])
})
})
})

View File

@@ -0,0 +1,218 @@
import type { Assistant, AssistantSettings, Model, Topic } from '@renderer/types'
import { TopicType } from '@renderer/types'
import { defaultTimeout } from '@shared/config/constant'
import { describe, expect, it, vi } from 'vitest'
import { getTemperature, getTimeout, getTopP } from '../modelParameters'
vi.mock('@renderer/services/AssistantService', () => ({
getAssistantSettings: (assistant: Assistant): AssistantSettings => ({
contextCount: assistant.settings?.contextCount ?? 4096,
temperature: assistant.settings?.temperature ?? 0.7,
enableTemperature: assistant.settings?.enableTemperature ?? true,
topP: assistant.settings?.topP ?? 1,
enableTopP: assistant.settings?.enableTopP ?? false,
enableMaxTokens: assistant.settings?.enableMaxTokens ?? false,
maxTokens: assistant.settings?.maxTokens,
streamOutput: assistant.settings?.streamOutput ?? true,
toolUseMode: assistant.settings?.toolUseMode ?? 'prompt',
defaultModel: assistant.defaultModel,
customParameters: assistant.settings?.customParameters ?? [],
reasoning_effort: assistant.settings?.reasoning_effort,
reasoning_effort_cache: assistant.settings?.reasoning_effort_cache,
qwenThinkMode: assistant.settings?.qwenThinkMode
})
}))
vi.mock('@renderer/hooks/useSettings', () => ({
getStoreSetting: vi.fn(),
useSettings: vi.fn(() => ({})),
useNavbarPosition: vi.fn(() => ({ navbarPosition: 'left', isLeftNavbar: true, isTopNavbar: false }))
}))
vi.mock('@renderer/hooks/useStore', () => ({
getStoreProviders: vi.fn(() => [])
}))
vi.mock('@renderer/store/settings', () => ({
default: (state = { settings: {} }) => state
}))
vi.mock('@renderer/store/assistants', () => ({
default: (state = { assistants: [] }) => state
}))
const createTopic = (assistantId: string): Topic => ({
id: `topic-${assistantId}`,
assistantId,
name: 'topic',
createdAt: new Date().toISOString(),
updatedAt: new Date().toISOString(),
messages: [],
type: TopicType.Chat
})
const createAssistant = (settings: Assistant['settings'] = {}): Assistant => {
const assistantId = 'assistant-1'
return {
id: assistantId,
name: 'Test Assistant',
prompt: 'prompt',
topics: [createTopic(assistantId)],
type: 'assistant',
settings
}
}
const createModel = (overrides: Partial<Model> = {}): Model => ({
id: 'gpt-4o',
provider: 'openai',
name: 'GPT-4o',
group: 'openai',
...overrides
})
describe('modelParameters', () => {
describe('getTemperature', () => {
it('returns undefined when reasoning effort is enabled for Claude models', () => {
const assistant = createAssistant({ reasoning_effort: 'medium' })
const model = createModel({ id: 'claude-opus-4', name: 'Claude Opus 4', provider: 'anthropic', group: 'claude' })
expect(getTemperature(assistant, model)).toBeUndefined()
})
it('returns undefined for models without temperature/topP support', () => {
const assistant = createAssistant({ enableTemperature: true })
const model = createModel({ id: 'qwen-mt-large', name: 'Qwen MT', provider: 'qwen', group: 'qwen' })
expect(getTemperature(assistant, model)).toBeUndefined()
})
it('returns undefined for Claude 4.5 reasoning models when only TopP is enabled', () => {
const assistant = createAssistant({ enableTopP: true, enableTemperature: false })
const model = createModel({
id: 'claude-sonnet-4.5',
name: 'Claude Sonnet 4.5',
provider: 'anthropic',
group: 'claude'
})
expect(getTemperature(assistant, model)).toBeUndefined()
})
it('returns configured temperature when enabled', () => {
const assistant = createAssistant({ enableTemperature: true, temperature: 0.42 })
const model = createModel({ id: 'gpt-4o', provider: 'openai', group: 'openai' })
expect(getTemperature(assistant, model)).toBe(0.42)
})
it('returns undefined when temperature is disabled', () => {
const assistant = createAssistant({ enableTemperature: false, temperature: 0.9 })
const model = createModel({ id: 'gpt-4o', provider: 'openai', group: 'openai' })
expect(getTemperature(assistant, model)).toBeUndefined()
})
it('clamps temperature to max 1.0 for Zhipu models', () => {
const assistant = createAssistant({ enableTemperature: true, temperature: 2.0 })
const model = createModel({ id: 'glm-4-plus', name: 'GLM-4 Plus', provider: 'zhipu', group: 'zhipu' })
expect(getTemperature(assistant, model)).toBe(1.0)
})
it('clamps temperature to max 1.0 for Anthropic models', () => {
const assistant = createAssistant({ enableTemperature: true, temperature: 1.5 })
const model = createModel({
id: 'claude-sonnet-3.5',
name: 'Claude 3.5 Sonnet',
provider: 'anthropic',
group: 'claude'
})
expect(getTemperature(assistant, model)).toBe(1.0)
})
it('clamps temperature to max 1.0 for Moonshot models', () => {
const assistant = createAssistant({ enableTemperature: true, temperature: 2.0 })
const model = createModel({
id: 'moonshot-v1-8k',
name: 'Moonshot v1 8k',
provider: 'moonshot',
group: 'moonshot'
})
expect(getTemperature(assistant, model)).toBe(1.0)
})
it('does not clamp temperature for OpenAI models', () => {
const assistant = createAssistant({ enableTemperature: true, temperature: 2.0 })
const model = createModel({ id: 'gpt-4o', provider: 'openai', group: 'openai' })
expect(getTemperature(assistant, model)).toBe(2.0)
})
it('does not clamp temperature when it is already within limits', () => {
const assistant = createAssistant({ enableTemperature: true, temperature: 0.8 })
const model = createModel({ id: 'glm-4-plus', name: 'GLM-4 Plus', provider: 'zhipu', group: 'zhipu' })
expect(getTemperature(assistant, model)).toBe(0.8)
})
})
describe('getTopP', () => {
it('returns undefined when reasoning effort is enabled for Claude models', () => {
const assistant = createAssistant({ reasoning_effort: 'high' })
const model = createModel({ id: 'claude-opus-4', provider: 'anthropic', group: 'claude' })
expect(getTopP(assistant, model)).toBeUndefined()
})
it('returns undefined for models without TopP support', () => {
const assistant = createAssistant({ enableTopP: true })
const model = createModel({ id: 'qwen-mt-small', name: 'Qwen MT', provider: 'qwen', group: 'qwen' })
expect(getTopP(assistant, model)).toBeUndefined()
})
it('returns undefined for Claude 4.5 reasoning models when temperature is enabled', () => {
const assistant = createAssistant({ enableTemperature: true })
const model = createModel({
id: 'claude-opus-4.5',
name: 'Claude Opus 4.5',
provider: 'anthropic',
group: 'claude'
})
expect(getTopP(assistant, model)).toBeUndefined()
})
it('returns configured TopP when enabled', () => {
const assistant = createAssistant({ enableTopP: true, topP: 0.73 })
const model = createModel({ id: 'gpt-4o', provider: 'openai', group: 'openai' })
expect(getTopP(assistant, model)).toBe(0.73)
})
it('returns undefined when TopP is disabled', () => {
const assistant = createAssistant({ enableTopP: false, topP: 0.5 })
const model = createModel({ id: 'gpt-4o', provider: 'openai', group: 'openai' })
expect(getTopP(assistant, model)).toBeUndefined()
})
})
describe('getTimeout', () => {
it('uses an extended timeout for flex service tier models', () => {
const model = createModel({ id: 'o3-pro', provider: 'openai', group: 'openai' })
expect(getTimeout(model)).toBe(15 * 1000 * 60)
})
it('falls back to the default timeout otherwise', () => {
const model = createModel({ id: 'gpt-4o', provider: 'openai', group: 'openai' })
expect(getTimeout(model)).toBe(defaultTimeout)
})
})
})

View File

@@ -1,13 +1,31 @@
import { isClaude45ReasoningModel } from '@renderer/config/models'
import { isClaude4SeriesModel, isClaude45ReasoningModel } from '@renderer/config/models'
import { getProviderByModel } from '@renderer/services/AssistantService'
import type { Assistant, Model } from '@renderer/types'
import { isToolUseModeFunction } from '@renderer/utils/assistant'
import { isAwsBedrockProvider, isVertexProvider } from '@renderer/utils/provider'
// https://docs.claude.com/en/docs/build-with-claude/extended-thinking#interleaved-thinking
const INTERLEAVED_THINKING_HEADER = 'interleaved-thinking-2025-05-14'
// https://docs.claude.com/en/docs/build-with-claude/context-windows#1m-token-context-window
const CONTEXT_100M_HEADER = 'context-1m-2025-08-07'
// https://docs.cloud.google.com/vertex-ai/generative-ai/docs/partner-models/claude/web-search
const WEBSEARCH_HEADER = 'web-search-2025-03-05'
export function addAnthropicHeaders(assistant: Assistant, model: Model): string[] {
const anthropicHeaders: string[] = []
if (isClaude45ReasoningModel(model) && isToolUseModeFunction(assistant)) {
const provider = getProviderByModel(model)
if (
isClaude45ReasoningModel(model) &&
isToolUseModeFunction(assistant) &&
!(isVertexProvider(provider) || isAwsBedrockProvider(provider))
) {
anthropicHeaders.push(INTERLEAVED_THINKING_HEADER)
}
if (isClaude4SeriesModel(model)) {
if (isVertexProvider(provider) && assistant.enableWebSearch) {
anthropicHeaders.push(WEBSEARCH_HEADER)
}
anthropicHeaders.push(CONTEXT_100M_HEADER)
}
return anthropicHeaders
}
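// Usage sketch, assuming a Claude 4.5 reasoning model in function tool-use mode on a
// first-party Anthropic provider: the caller joins the returned strings into a single
// `anthropic-beta` request header, as parameterBuilder does further down.
//   const betaHeaders = { 'anthropic-beta': addAnthropicHeaders(assistant, model).join(',') }
//   // → 'interleaved-thinking-2025-05-14,context-1m-2025-08-07'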

View File

@@ -85,19 +85,6 @@ export function supportsLargeFileUpload(model: Model): boolean {
})
}
/**
* Check whether the model supports TopP
*/
export function supportsTopP(model: Model): boolean {
const provider = getProviderByModel(model)
if (provider?.type === 'anthropic' || model?.endpoint_type === 'anthropic') {
return false
}
return true
}
/**
* Get the provider-specific file size limit
*/

View File

@@ -6,14 +6,23 @@
import {
isClaude45ReasoningModel,
isClaudeReasoningModel,
isMaxTemperatureOneModel,
isNotSupportTemperatureAndTopP,
isSupportedFlexServiceTier
isSupportedFlexServiceTier,
isSupportedThinkingTokenClaudeModel
} from '@renderer/config/models'
import { getAssistantSettings } from '@renderer/services/AssistantService'
import { getAssistantSettings, getProviderByModel } from '@renderer/services/AssistantService'
import type { Assistant, Model } from '@renderer/types'
import { defaultTimeout } from '@shared/config/constant'
import { getAnthropicThinkingBudget } from '../utils/reasoning'
/**
* Claude 4.5 reasoning models:
* - only temperature enabled → temperature is used
* - only top_p enabled → top_p is used
* - both enabled → temperature takes effect and top_p is ignored
* - neither enabled → neither is sent
* Get the temperature parameter
*/
export function getTemperature(assistant: Assistant, model: Model): number | undefined {
@@ -27,7 +36,11 @@ export function getTemperature(assistant: Assistant, model: Model): number | undefined
return undefined
}
const assistantSettings = getAssistantSettings(assistant)
return assistantSettings?.enableTemperature ? assistantSettings?.temperature : undefined
let temperature = assistantSettings?.temperature
if (temperature && isMaxTemperatureOneModel(model)) {
temperature = Math.min(1, temperature)
}
return assistantSettings?.enableTemperature ? temperature : undefined
}
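// Illustration of the rules above, with assumed assistant settings (mirroring the unit tests):
//   getTemperature({ ...assistant, settings: { enableTemperature: true, temperature: 1.5 } }, claudeModel)
//   // → 1.0 (clamped for max-temperature-one models)
//   getTemperature({ ...assistant, settings: { enableTemperature: false, temperature: 1.5 } }, claudeModel)
//   // → undefined (temperature disabled)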
/**
@@ -56,3 +69,26 @@ export function getTimeout(model: Model): number {
}
return defaultTimeout
}
export function getMaxTokens(assistant: Assistant, model: Model): number | undefined {
// NOTE: the ai-sdk adds maxTokens and budgetTokens together
const assistantSettings = getAssistantSettings(assistant)
const enabledMaxTokens = assistantSettings.enableMaxTokens ?? false
let maxTokens = assistantSettings.maxTokens
// If user hasn't enabled enableMaxTokens, return undefined to let the API use its default value.
// Note: Anthropic API requires max_tokens, but that's handled by the Anthropic client with a fallback.
if (!enabledMaxTokens || maxTokens === undefined) {
return undefined
}
const provider = getProviderByModel(model)
if (isSupportedThinkingTokenClaudeModel(model) && ['anthropic', 'aws-bedrock'].includes(provider.type)) {
const { reasoning_effort: reasoningEffort } = assistantSettings
const budget = getAnthropicThinkingBudget(maxTokens, reasoningEffort, model.id)
if (budget) {
maxTokens -= budget
}
}
return maxTokens
}
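// Worked example with assumed numbers: if the user caps maxTokens at 8192 and
// getAnthropicThinkingBudget returns 4096 for the chosen reasoning effort, the value
// returned here is 8192 - 4096 = 4096, so the request stays within the user's cap
// after the ai-sdk adds the thinking budget back on top.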

View File

@@ -4,11 +4,12 @@
*/
import { anthropic } from '@ai-sdk/anthropic'
import { azure } from '@ai-sdk/azure'
import { google } from '@ai-sdk/google'
import { vertexAnthropic } from '@ai-sdk/google-vertex/anthropic/edge'
import { vertex } from '@ai-sdk/google-vertex/edge'
import { combineHeaders } from '@ai-sdk/provider-utils'
import type { WebSearchPluginConfig } from '@cherrystudio/ai-core/built-in/plugins'
import type { AnthropicSearchConfig, WebSearchPluginConfig } from '@cherrystudio/ai-core/built-in/plugins'
import { isBaseProvider } from '@cherrystudio/ai-core/core/providers/schemas'
import { loggerService } from '@logger'
import {
@@ -17,13 +18,10 @@ import {
isOpenRouterBuiltInWebSearchModel,
isReasoningModel,
isSupportedReasoningEffortModel,
isSupportedThinkingTokenClaudeModel,
isSupportedThinkingTokenModel,
isWebSearchModel
} from '@renderer/config/models'
import { isAwsBedrockProvider } from '@renderer/config/providers'
import { isVertexProvider } from '@renderer/hooks/useVertexAI'
import { getAssistantSettings, getDefaultModel } from '@renderer/services/AssistantService'
import { getDefaultModel } from '@renderer/services/AssistantService'
import store from '@renderer/store'
import type { CherryWebSearchConfig } from '@renderer/store/websearch'
import { type Assistant, type MCPTool, type Provider } from '@renderer/types'
@@ -36,11 +34,9 @@ import { stepCountIs } from 'ai'
import { getAiSdkProviderId } from '../provider/factory'
import { setupToolsConfig } from '../utils/mcp'
import { buildProviderOptions } from '../utils/options'
import { getAnthropicThinkingBudget } from '../utils/reasoning'
import { buildProviderBuiltinWebSearchConfig } from '../utils/websearch'
import { addAnthropicHeaders } from './header'
import { supportsTopP } from './modelCapabilities'
import { getTemperature, getTopP } from './modelParameters'
import { getMaxTokens, getTemperature, getTopP } from './modelParameters'
const logger = loggerService.withContext('parameterBuilder')
@@ -63,7 +59,7 @@ export async function buildStreamTextParams(
timeout?: number
headers?: Record<string, string>
}
} = {}
}
): Promise<{
params: StreamTextParams
modelId: string
@@ -80,8 +76,6 @@ export async function buildStreamTextParams(
const model = assistant.model || getDefaultModel()
const aiSdkProviderId = getAiSdkProviderId(provider)
let { maxTokens } = getAssistantSettings(assistant)
// These three variables are passed through to enable the plugins/middleware below
// They can also be built externally and then passed into buildStreamTextParams
// FIXME: for qwen3, enableReasoning still evaluates to true even when thinking is turned off
@@ -118,16 +112,6 @@ export async function buildStreamTextParams(
enableGenerateImage
})
// NOTE: the ai-sdk adds maxTokens and budgetTokens together
if (
enableReasoning &&
maxTokens !== undefined &&
isSupportedThinkingTokenClaudeModel(model) &&
(provider.type === 'anthropic' || provider.type === 'aws-bedrock')
) {
maxTokens -= getAnthropicThinkingBudget(assistant, model)
}
let webSearchPluginConfig: WebSearchPluginConfig | undefined = undefined
if (enableWebSearch) {
if (isBaseProvider(aiSdkProviderId)) {
@@ -144,6 +128,17 @@ export async function buildStreamTextParams(
maxUses: webSearchConfig.maxResults,
blockedDomains: blockedDomains.length > 0 ? blockedDomains : undefined
}) as ProviderDefinedTool
} else if (aiSdkProviderId === 'azure-responses') {
tools.web_search_preview = azure.tools.webSearchPreview({
searchContextSize: webSearchPluginConfig?.openai!.searchContextSize
}) as ProviderDefinedTool
} else if (aiSdkProviderId === 'azure-anthropic') {
const blockedDomains = mapRegexToPatterns(webSearchConfig.excludeDomains)
const anthropicSearchOptions: AnthropicSearchConfig = {
maxUses: webSearchConfig.maxResults,
blockedDomains: blockedDomains.length > 0 ? blockedDomains : undefined
}
tools.web_search = anthropic.tools.webSearch_20250305(anthropicSearchOptions) as ProviderDefinedTool
}
}
@@ -161,9 +156,10 @@ export async function buildStreamTextParams(
tools.url_context = google.tools.urlContext({}) as ProviderDefinedTool
break
case 'anthropic':
case 'azure-anthropic':
case 'google-vertex-anthropic':
tools.web_fetch = (
aiSdkProviderId === 'anthropic'
['anthropic', 'azure-anthropic'].includes(aiSdkProviderId)
? anthropic.tools.webFetch_20250910({
maxUses: webSearchConfig.maxResults,
blockedDomains: blockedDomains.length > 0 ? blockedDomains : undefined
@@ -179,8 +175,7 @@ export async function buildStreamTextParams(
let headers: Record<string, string | undefined> = options.requestOptions?.headers ?? {}
// https://docs.claude.com/en/docs/build-with-claude/extended-thinking#interleaved-thinking
if (!isVertexProvider(provider) && !isAwsBedrockProvider(provider) && isAnthropicModel(model)) {
if (isAnthropicModel(model)) {
const newBetaHeaders = { 'anthropic-beta': addAnthropicHeaders(assistant, model).join(',') }
headers = combineHeaders(headers, newBetaHeaders)
}
@@ -188,8 +183,9 @@ export async function buildStreamTextParams(
// Build the base parameters
const params: StreamTextParams = {
messages: sdkMessages,
maxOutputTokens: maxTokens,
maxOutputTokens: getMaxTokens(assistant, model),
temperature: getTemperature(assistant, model),
topP: getTopP(assistant, model),
abortSignal: options.requestOptions?.signal,
headers,
providerOptions,
@@ -197,10 +193,6 @@ export async function buildStreamTextParams(
maxRetries: 0
}
if (supportsTopP(model)) {
params.topP = getTopP(assistant, model)
}
if (tools) {
params.tools = tools
}

View File

@@ -23,6 +23,26 @@ vi.mock('@cherrystudio/ai-core', () => ({
}
}))
vi.mock('@renderer/services/AssistantService', () => ({
getProviderByModel: vi.fn(),
getAssistantSettings: vi.fn(),
getDefaultAssistant: vi.fn().mockReturnValue({
id: 'default',
name: 'Default Assistant',
prompt: '',
settings: {}
})
}))
vi.mock('@renderer/store/settings', () => ({
default: {},
settingsSlice: {
name: 'settings',
reducer: vi.fn(),
actions: {}
}
}))
// Mock the provider configs
vi.mock('../providerConfigs', () => ({
initializeNewProviders: vi.fn()

View File

@@ -12,7 +12,14 @@ vi.mock('@renderer/services/LoggerService', () => ({
}))
vi.mock('@renderer/services/AssistantService', () => ({
getProviderByModel: vi.fn()
getProviderByModel: vi.fn(),
getAssistantSettings: vi.fn(),
getDefaultAssistant: vi.fn().mockReturnValue({
id: 'default',
name: 'Default Assistant',
prompt: '',
settings: {}
})
}))
vi.mock('@renderer/store', () => ({
@@ -34,7 +41,7 @@ vi.mock('@renderer/utils/api', () => ({
}))
}))
vi.mock('@renderer/config/providers', async (importOriginal) => {
vi.mock('@renderer/utils/provider', async (importOriginal) => {
const actual = (await importOriginal()) as any
return {
...actual,
@@ -53,10 +60,21 @@ vi.mock('@renderer/hooks/useVertexAI', () => ({
createVertexProvider: vi.fn()
}))
import { isCherryAIProvider, isPerplexityProvider } from '@renderer/config/providers'
vi.mock('@renderer/services/AssistantService', () => ({
getProviderByModel: vi.fn(),
getAssistantSettings: vi.fn(),
getDefaultAssistant: vi.fn().mockReturnValue({
id: 'default',
name: 'Default Assistant',
prompt: '',
settings: {}
})
}))
import { getProviderByModel } from '@renderer/services/AssistantService'
import type { Model, Provider } from '@renderer/types'
import { formatApiHost } from '@renderer/utils/api'
import { isCherryAIProvider, isPerplexityProvider } from '@renderer/utils/provider'
import { COPILOT_DEFAULT_HEADERS, COPILOT_EDITOR_VERSION, isCopilotResponsesModel } from '../constants'
import { getActualProvider, providerToAiSdkConfig } from '../providerConfig'

View File

@@ -0,0 +1,22 @@
import type { Provider } from '@renderer/types'
import { provider2Provider, startsWith } from './helper'
import type { RuleSet } from './types'
// https://platform.claude.com/docs/en/build-with-claude/claude-in-microsoft-foundry
const AZURE_ANTHROPIC_RULES: RuleSet = {
rules: [
{
match: startsWith('claude'),
provider: (provider: Provider) => ({
...provider,
type: 'anthropic',
apiHost: provider.apiHost + 'anthropic/v1',
id: 'azure-anthropic'
})
}
],
fallbackRule: (provider: Provider) => provider
}
export const azureAnthropicProviderCreator = provider2Provider.bind(null, AZURE_ANTHROPIC_RULES)
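// Sketch of what the rule produces, for a hypothetical Azure provider:
//   input:            { id: 'my-azure', type: 'azure-openai', apiHost: 'https://foo.ai.azure.com/' }
//   'claude-*' model: { id: 'azure-anthropic', type: 'anthropic', apiHost: 'https://foo.ai.azure.com/anthropic/v1' }
//   any other model:  unchanged (fallbackRule returns the provider as-is)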

View File

@@ -2,8 +2,10 @@ import { hasProviderConfigByAlias, type ProviderId, resolveProviderConfigId } fr
import { createProvider as createProviderCore } from '@cherrystudio/ai-core/provider'
import { loggerService } from '@logger'
import type { Provider } from '@renderer/types'
import { isAzureOpenAIProvider, isAzureResponsesEndpoint } from '@renderer/utils/provider'
import type { Provider as AiSdkProvider } from 'ai'
import type { AiSdkConfig } from '../types'
import { initializeNewProviders } from './providerInitialization'
const logger = loggerService.withContext('ProviderFactory')
@@ -55,9 +57,12 @@ function tryResolveProviderId(identifier: string): ProviderId | null {
* Get the AI SDK provider ID
* Simplified version: reduces duplicated logic by reusing the generic resolution helpers
*/
export function getAiSdkProviderId(provider: Provider): ProviderId | 'openai-compatible' {
export function getAiSdkProviderId(provider: Provider): string {
// 1. Try to resolve provider.id
const resolvedFromId = tryResolveProviderId(provider.id)
if (isAzureOpenAIProvider(provider) && isAzureResponsesEndpoint(provider)) {
return 'azure-responses'
}
if (resolvedFromId) {
return resolvedFromId
}
@@ -73,11 +78,11 @@ export function getAiSdkProviderId(provider: Provider): ProviderId | 'openai-com
if (provider.apiHost.includes('api.openai.com')) {
return 'openai-chat'
}
// 3. The final fallback usually ends up as openai-compatible
return provider.id as ProviderId
// 3. The final fallback uses the provider's own id
return provider.id
}
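// Resolution order, illustrated with assumed providers:
//   Azure OpenAI provider on the Responses endpoint   → 'azure-responses'
//   id or type that resolves via tryResolveProviderId → that ProviderId
//   apiHost containing 'api.openai.com'               → 'openai-chat'
//   anything else                                     → provider.id, handled downstream as openai-compatible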
export async function createAiSdkProvider(config) {
export async function createAiSdkProvider(config: AiSdkConfig): Promise<AiSdkProvider | null> {
let localProvider: Awaited<AiSdkProvider> | null = null
try {
if (config.providerId === 'openai' && config.options?.mode === 'chat') {

View File

@@ -1,19 +1,5 @@
import {
formatPrivateKey,
hasProviderConfig,
ProviderConfigFactory,
type ProviderId,
type ProviderSettingsMap
} from '@cherrystudio/ai-core/provider'
import { formatPrivateKey, hasProviderConfig, ProviderConfigFactory } from '@cherrystudio/ai-core/provider'
import { isOpenAIChatCompletionOnlyModel } from '@renderer/config/models'
import {
isAnthropicProvider,
isAzureOpenAIProvider,
isCherryAIProvider,
isGeminiProvider,
isNewApiProvider,
isPerplexityProvider
} from '@renderer/config/providers'
import {
getAwsBedrockAccessKeyId,
getAwsBedrockApiKey,
@@ -21,14 +7,25 @@ import {
getAwsBedrockRegion,
getAwsBedrockSecretAccessKey
} from '@renderer/hooks/useAwsBedrock'
import { createVertexProvider, isVertexAIConfigured, isVertexProvider } from '@renderer/hooks/useVertexAI'
import { createVertexProvider, isVertexAIConfigured } from '@renderer/hooks/useVertexAI'
import { getProviderByModel } from '@renderer/services/AssistantService'
import store from '@renderer/store'
import { isSystemProvider, type Model, type Provider, SystemProviderIds } from '@renderer/types'
import { formatApiHost, formatAzureOpenAIApiHost, formatVertexApiHost, routeToEndpoint } from '@renderer/utils/api'
import {
isAnthropicProvider,
isAzureOpenAIProvider,
isCherryAIProvider,
isGeminiProvider,
isNewApiProvider,
isPerplexityProvider,
isVertexProvider
} from '@renderer/utils/provider'
import { cloneDeep } from 'lodash'
import type { AiSdkConfig } from '../types'
import { aihubmixProviderCreator, newApiResolverCreator, vertexAnthropicProviderCreator } from './config'
import { azureAnthropicProviderCreator } from './config/azure-anthropic'
import { COPILOT_DEFAULT_HEADERS } from './constants'
import { getAiSdkProviderId } from './factory'
@@ -74,6 +71,9 @@ function handleSpecialProviders(model: Model, provider: Provider): Provider {
return vertexAnthropicProviderCreator(model, provider)
}
}
if (isAzureOpenAIProvider(provider)) {
return azureAnthropicProviderCreator(model, provider)
}
return provider
}
@@ -131,13 +131,7 @@ export function getActualProvider(model: Model): Provider {
* Convert a Provider config into the new AI SDK format
* Simplified version: leverages the new alias mapping system
*/
export function providerToAiSdkConfig(
actualProvider: Provider,
model: Model
): {
providerId: ProviderId | 'openai-compatible'
options: ProviderSettingsMap[keyof ProviderSettingsMap]
} {
export function providerToAiSdkConfig(actualProvider: Provider, model: Model): AiSdkConfig {
const aiSdkProviderId = getAiSdkProviderId(actualProvider)
// Build the base config
@@ -191,13 +185,10 @@ export function providerToAiSdkConfig(
// azure
// https://learn.microsoft.com/en-us/azure/ai-foundry/openai/latest
// https://learn.microsoft.com/en-us/azure/ai-foundry/openai/how-to/responses?tabs=python-key#responses-api
if (aiSdkProviderId === 'azure' || actualProvider.type === 'azure-openai') {
// extraOptions.apiVersion = actualProvider.apiVersion === 'preview' ? 'v1' : actualProvider.apiVersion (use v1 by default rather than the azure endpoint)
if (actualProvider.apiVersion === 'preview' || actualProvider.apiVersion === 'v1') {
extraOptions.mode = 'responses'
} else {
extraOptions.mode = 'chat'
}
if (aiSdkProviderId === 'azure-responses') {
extraOptions.mode = 'responses'
} else if (aiSdkProviderId === 'azure') {
extraOptions.mode = 'chat'
}
// bedrock
@@ -227,10 +218,17 @@ export function providerToAiSdkConfig(
baseConfig.baseURL += aiSdkProviderId === 'google-vertex' ? '/publishers/google' : '/publishers/anthropic/models'
}
// cherryin
if (aiSdkProviderId === 'cherryin') {
if (model.endpoint_type) {
extraOptions.endpointType = model.endpoint_type
}
}
if (hasProviderConfig(aiSdkProviderId) && aiSdkProviderId !== 'openai-compatible') {
const options = ProviderConfigFactory.fromProvider(aiSdkProviderId, baseConfig, extraOptions)
return {
providerId: aiSdkProviderId as ProviderId,
providerId: aiSdkProviderId,
options
}
}

View File

@@ -32,6 +32,14 @@ export const NEW_PROVIDER_CONFIGS: ProviderConfig[] = [
supportsImageGeneration: true,
aliases: ['vertexai-anthropic']
},
{
id: 'azure-anthropic',
name: 'Azure AI Anthropic',
import: () => import('@ai-sdk/anthropic'),
creatorFunctionName: 'createAnthropic',
supportsImageGeneration: false,
aliases: ['azure-anthropic']
},
{
id: 'github-copilot-openai-compatible',
name: 'GitHub Copilot OpenAI Compatible',

View File

@@ -133,7 +133,7 @@ export class AiSdkSpanAdapter {
// Log the conversion process in detail
const operationId = attributes['ai.operationId']
logger.info('Converting AI SDK span to SpanEntity', {
logger.debug('Converting AI SDK span to SpanEntity', {
spanName: spanName,
operationId,
spanTag,
@@ -149,7 +149,7 @@ export class AiSdkSpanAdapter {
})
if (tokenUsage) {
logger.info('Token usage data found', {
logger.debug('Token usage data found', {
spanName: spanName,
operationId,
usage: tokenUsage,
@@ -158,7 +158,7 @@ export class AiSdkSpanAdapter {
}
if (inputs || outputs) {
logger.info('Input/Output data extracted', {
logger.debug('Input/Output data extracted', {
spanName: spanName,
operationId,
hasInputs: !!inputs,
@@ -170,7 +170,7 @@ export class AiSdkSpanAdapter {
}
if (Object.keys(typeSpecificData).length > 0) {
logger.info('Type-specific data extracted', {
logger.debug('Type-specific data extracted', {
spanName: spanName,
operationId,
typeSpecificKeys: Object.keys(typeSpecificData),
@@ -204,7 +204,7 @@ export class AiSdkSpanAdapter {
modelName: modelName || this.extractModelFromAttributes(attributes)
}
logger.info('AI SDK span successfully converted to SpanEntity', {
logger.debug('AI SDK span successfully converted to SpanEntity', {
spanName: spanName,
operationId,
spanId: spanContext.spanId,

View File

@@ -0,0 +1,15 @@
/**
* This type definition file is only for renderer.
* It cannot be moved into @renderer/types because files there are used by both the main and renderer processes.
* If we did that, main would throw an error: it cannot import a module that imports a type from a browser-environment-only package.
* (the ai-core package is set as browser-environment-only)
*
* TODO: We should separate them clearly. Keep renderer only types in renderer, and main only types in main, and shared types in shared.
*/
import type { ProviderSettingsMap } from '@cherrystudio/ai-core/provider'
export type AiSdkConfig = {
providerId: string
options: ProviderSettingsMap[keyof ProviderSettingsMap]
}
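// A conforming value, sketched with assumed fields from providerToAiSdkConfig:
//   const config: AiSdkConfig = {
//     providerId: 'anthropic',
//     options: { apiKey: 'sk-...', baseURL: 'https://api.anthropic.com/v1' }
//   }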

View File

@@ -0,0 +1,121 @@
/**
* image.ts Unit Tests
* Tests for Gemini image generation utilities
*/
import type { Model, Provider } from '@renderer/types'
import { SystemProviderIds } from '@renderer/types'
import { describe, expect, it } from 'vitest'
import { buildGeminiGenerateImageParams, isOpenRouterGeminiGenerateImageModel } from '../image'
describe('image utils', () => {
describe('buildGeminiGenerateImageParams', () => {
it('should return correct response modalities', () => {
const result = buildGeminiGenerateImageParams()
expect(result).toEqual({
responseModalities: ['TEXT', 'IMAGE']
})
})
it('should return an object with responseModalities property', () => {
const result = buildGeminiGenerateImageParams()
expect(result).toHaveProperty('responseModalities')
expect(Array.isArray(result.responseModalities)).toBe(true)
expect(result.responseModalities).toHaveLength(2)
})
})
describe('isOpenRouterGeminiGenerateImageModel', () => {
const mockOpenRouterProvider: Provider = {
id: SystemProviderIds.openrouter,
name: 'OpenRouter',
apiKey: 'test-key',
apiHost: 'https://openrouter.ai/api/v1',
isSystem: true
} as Provider
const mockOtherProvider: Provider = {
id: SystemProviderIds.openai,
name: 'OpenAI',
apiKey: 'test-key',
apiHost: 'https://api.openai.com/v1',
isSystem: true
} as Provider
it('should return true for OpenRouter Gemini 2.5 Flash Image model', () => {
const model: Model = {
id: 'google/gemini-2.5-flash-image-preview',
name: 'Gemini 2.5 Flash Image',
provider: SystemProviderIds.openrouter
} as Model
const result = isOpenRouterGeminiGenerateImageModel(model, mockOpenRouterProvider)
expect(result).toBe(true)
})
it('should return false for non-Gemini model on OpenRouter', () => {
const model: Model = {
id: 'openai/gpt-4',
name: 'GPT-4',
provider: SystemProviderIds.openrouter
} as Model
const result = isOpenRouterGeminiGenerateImageModel(model, mockOpenRouterProvider)
expect(result).toBe(false)
})
it('should return false for Gemini model on non-OpenRouter provider', () => {
const model: Model = {
id: 'gemini-2.5-flash-image-preview',
name: 'Gemini 2.5 Flash Image',
provider: SystemProviderIds.gemini
} as Model
const result = isOpenRouterGeminiGenerateImageModel(model, mockOtherProvider)
expect(result).toBe(false)
})
it('should return false for Gemini model without image suffix', () => {
const model: Model = {
id: 'google/gemini-2.5-flash',
name: 'Gemini 2.5 Flash',
provider: SystemProviderIds.openrouter
} as Model
const result = isOpenRouterGeminiGenerateImageModel(model, mockOpenRouterProvider)
expect(result).toBe(false)
})
it('should handle model ID with partial match', () => {
const model: Model = {
id: 'google/gemini-2.5-flash-image-generation',
name: 'Gemini Image Gen',
provider: SystemProviderIds.openrouter
} as Model
const result = isOpenRouterGeminiGenerateImageModel(model, mockOpenRouterProvider)
expect(result).toBe(true)
})
it('should return false for custom provider', () => {
const customProvider: Provider = {
id: 'custom-provider-123',
name: 'Custom Provider',
apiKey: 'test-key',
apiHost: 'https://custom.com'
} as Provider
const model: Model = {
id: 'gemini-2.5-flash-image-preview',
name: 'Gemini 2.5 Flash Image',
provider: 'custom-provider-123'
} as Model
const result = isOpenRouterGeminiGenerateImageModel(model, customProvider)
expect(result).toBe(false)
})
})
})

View File

@@ -0,0 +1,435 @@
/**
* mcp.ts Unit Tests
* Tests for MCP tools configuration and conversion utilities
*/
import type { MCPTool } from '@renderer/types'
import type { Tool } from 'ai'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import { convertMcpToolsToAiSdkTools, setupToolsConfig } from '../mcp'
// Mock dependencies
vi.mock('@logger', () => ({
loggerService: {
withContext: () => ({
debug: vi.fn(),
error: vi.fn(),
warn: vi.fn(),
info: vi.fn()
})
}
}))
vi.mock('@renderer/utils/mcp-tools', () => ({
getMcpServerByTool: vi.fn(() => ({ id: 'test-server', autoApprove: false })),
isToolAutoApproved: vi.fn(() => false),
callMCPTool: vi.fn(async () => ({
content: [{ type: 'text', text: 'Tool executed successfully' }],
isError: false
}))
}))
vi.mock('@renderer/utils/userConfirmation', () => ({
requestToolConfirmation: vi.fn(async () => true)
}))
describe('mcp utils', () => {
beforeEach(() => {
vi.clearAllMocks()
})
describe('setupToolsConfig', () => {
it('should return undefined when no MCP tools provided', () => {
const result = setupToolsConfig()
expect(result).toBeUndefined()
})
it('should return undefined when empty MCP tools array provided', () => {
const result = setupToolsConfig([])
expect(result).toBeUndefined()
})
it('should convert MCP tools to AI SDK tools format', () => {
const mcpTools: MCPTool[] = [
{
id: 'test-tool-1',
serverId: 'test-server',
serverName: 'test-server',
name: 'test-tool',
description: 'A test tool',
type: 'mcp',
inputSchema: {
type: 'object',
properties: {
query: { type: 'string' }
}
}
}
]
const result = setupToolsConfig(mcpTools)
expect(result).not.toBeUndefined()
expect(Object.keys(result!)).toEqual(['test-tool'])
expect(result!['test-tool']).toHaveProperty('description')
expect(result!['test-tool']).toHaveProperty('inputSchema')
expect(result!['test-tool']).toHaveProperty('execute')
})
it('should handle multiple MCP tools', () => {
const mcpTools: MCPTool[] = [
{
id: 'tool1-id',
serverId: 'server1',
serverName: 'server1',
name: 'tool1',
description: 'First tool',
type: 'mcp',
inputSchema: {
type: 'object',
properties: {}
}
},
{
id: 'tool2-id',
serverId: 'server2',
serverName: 'server2',
name: 'tool2',
description: 'Second tool',
type: 'mcp',
inputSchema: {
type: 'object',
properties: {}
}
}
]
const result = setupToolsConfig(mcpTools)
expect(result).not.toBeUndefined()
expect(Object.keys(result!)).toHaveLength(2)
expect(Object.keys(result!)).toEqual(['tool1', 'tool2'])
})
})
describe('convertMcpToolsToAiSdkTools', () => {
it('should convert single MCP tool to AI SDK tool', () => {
const mcpTools: MCPTool[] = [
{
id: 'get-weather-id',
serverId: 'weather-server',
serverName: 'weather-server',
name: 'get-weather',
description: 'Get weather information',
type: 'mcp',
inputSchema: {
type: 'object',
properties: {
location: { type: 'string' }
},
required: ['location']
}
}
]
const result = convertMcpToolsToAiSdkTools(mcpTools)
expect(Object.keys(result)).toEqual(['get-weather'])
const tool = result['get-weather'] as Tool
expect(tool.description).toBe('Get weather information')
expect(tool.inputSchema).toBeDefined()
expect(typeof tool.execute).toBe('function')
})
it('should handle tool without description', () => {
const mcpTools: MCPTool[] = [
{
id: 'no-desc-tool-id',
serverId: 'test-server',
serverName: 'test-server',
name: 'no-desc-tool',
type: 'mcp',
inputSchema: {
type: 'object',
properties: {}
}
}
]
const result = convertMcpToolsToAiSdkTools(mcpTools)
expect(Object.keys(result)).toEqual(['no-desc-tool'])
const tool = result['no-desc-tool'] as Tool
expect(tool.description).toBe('Tool from test-server')
})
it('should convert empty tools array', () => {
const result = convertMcpToolsToAiSdkTools([])
expect(result).toEqual({})
})
it('should handle complex input schemas', () => {
const mcpTools: MCPTool[] = [
{
id: 'complex-tool-id',
serverId: 'server',
serverName: 'server',
name: 'complex-tool',
description: 'Tool with complex schema',
type: 'mcp',
inputSchema: {
type: 'object',
properties: {
name: { type: 'string' },
age: { type: 'number' },
tags: {
type: 'array',
items: { type: 'string' }
},
metadata: {
type: 'object',
properties: {
key: { type: 'string' }
}
}
},
required: ['name']
}
}
]
const result = convertMcpToolsToAiSdkTools(mcpTools)
expect(Object.keys(result)).toEqual(['complex-tool'])
const tool = result['complex-tool'] as Tool
expect(tool.inputSchema).toBeDefined()
expect(typeof tool.execute).toBe('function')
})
it('should preserve tool names with special characters', () => {
const mcpTools: MCPTool[] = [
{
id: 'special-tool-id',
serverId: 'server',
serverName: 'server',
name: 'tool_with-special.chars',
description: 'Special chars tool',
type: 'mcp',
inputSchema: {
type: 'object',
properties: {}
}
}
]
const result = convertMcpToolsToAiSdkTools(mcpTools)
expect(Object.keys(result)).toEqual(['tool_with-special.chars'])
})
it('should handle multiple tools with different schemas', () => {
const mcpTools: MCPTool[] = [
{
id: 'string-tool-id',
serverId: 'server1',
serverName: 'server1',
name: 'string-tool',
description: 'String tool',
type: 'mcp',
inputSchema: {
type: 'object',
properties: {
input: { type: 'string' }
}
}
},
{
id: 'number-tool-id',
serverId: 'server2',
serverName: 'server2',
name: 'number-tool',
description: 'Number tool',
type: 'mcp',
inputSchema: {
type: 'object',
properties: {
count: { type: 'number' }
}
}
},
{
id: 'boolean-tool-id',
serverId: 'server3',
serverName: 'server3',
name: 'boolean-tool',
description: 'Boolean tool',
type: 'mcp',
inputSchema: {
type: 'object',
properties: {
enabled: { type: 'boolean' }
}
}
}
]
const result = convertMcpToolsToAiSdkTools(mcpTools)
expect(Object.keys(result).sort()).toEqual(['boolean-tool', 'number-tool', 'string-tool'])
expect(result['string-tool']).toBeDefined()
expect(result['number-tool']).toBeDefined()
expect(result['boolean-tool']).toBeDefined()
})
})
describe('tool execution', () => {
it('should execute tool with user confirmation', async () => {
const { callMCPTool } = await import('@renderer/utils/mcp-tools')
const { requestToolConfirmation } = await import('@renderer/utils/userConfirmation')
vi.mocked(requestToolConfirmation).mockResolvedValue(true)
vi.mocked(callMCPTool).mockResolvedValue({
content: [{ type: 'text', text: 'Success' }],
isError: false
})
const mcpTools: MCPTool[] = [
{
id: 'test-exec-tool-id',
serverId: 'test-server',
serverName: 'test-server',
name: 'test-exec-tool',
description: 'Test execution tool',
type: 'mcp',
inputSchema: {
type: 'object',
properties: {}
}
}
]
const tools = convertMcpToolsToAiSdkTools(mcpTools)
const tool = tools['test-exec-tool'] as Tool
const result = await tool.execute!({}, { messages: [], abortSignal: undefined, toolCallId: 'test-call-123' })
expect(requestToolConfirmation).toHaveBeenCalled()
expect(callMCPTool).toHaveBeenCalled()
expect(result).toEqual({
content: [{ type: 'text', text: 'Success' }],
isError: false
})
})
it('should handle user cancellation', async () => {
const { requestToolConfirmation } = await import('@renderer/utils/userConfirmation')
const { callMCPTool } = await import('@renderer/utils/mcp-tools')
vi.mocked(requestToolConfirmation).mockResolvedValue(false)
const mcpTools: MCPTool[] = [
{
id: 'cancelled-tool-id',
serverId: 'test-server',
serverName: 'test-server',
name: 'cancelled-tool',
description: 'Tool to cancel',
type: 'mcp',
inputSchema: {
type: 'object',
properties: {}
}
}
]
const tools = convertMcpToolsToAiSdkTools(mcpTools)
const tool = tools['cancelled-tool'] as Tool
const result = await tool.execute!({}, { messages: [], abortSignal: undefined, toolCallId: 'cancel-call-123' })
expect(requestToolConfirmation).toHaveBeenCalled()
expect(callMCPTool).not.toHaveBeenCalled()
expect(result).toEqual({
content: [
{
type: 'text',
text: 'User declined to execute tool "cancelled-tool".'
}
],
isError: false
})
})
it('should handle tool execution error', async () => {
const { callMCPTool } = await import('@renderer/utils/mcp-tools')
const { requestToolConfirmation } = await import('@renderer/utils/userConfirmation')
vi.mocked(requestToolConfirmation).mockResolvedValue(true)
vi.mocked(callMCPTool).mockResolvedValue({
content: [{ type: 'text', text: 'Error occurred' }],
isError: true
})
const mcpTools: MCPTool[] = [
{
id: 'error-tool-id',
serverId: 'test-server',
serverName: 'test-server',
name: 'error-tool',
description: 'Tool that errors',
type: 'mcp',
inputSchema: {
type: 'object',
properties: {}
}
}
]
const tools = convertMcpToolsToAiSdkTools(mcpTools)
const tool = tools['error-tool'] as Tool
await expect(
tool.execute!({}, { messages: [], abortSignal: undefined, toolCallId: 'error-call-123' })
).rejects.toEqual({
content: [{ type: 'text', text: 'Error occurred' }],
isError: true
})
})
it('should auto-approve when enabled', async () => {
const { callMCPTool, isToolAutoApproved } = await import('@renderer/utils/mcp-tools')
const { requestToolConfirmation } = await import('@renderer/utils/userConfirmation')
vi.mocked(isToolAutoApproved).mockReturnValue(true)
vi.mocked(callMCPTool).mockResolvedValue({
content: [{ type: 'text', text: 'Auto-approved success' }],
isError: false
})
const mcpTools: MCPTool[] = [
{
id: 'auto-approve-tool-id',
serverId: 'test-server',
serverName: 'test-server',
name: 'auto-approve-tool',
description: 'Auto-approved tool',
type: 'mcp',
inputSchema: {
type: 'object',
properties: {}
}
}
]
const tools = convertMcpToolsToAiSdkTools(mcpTools)
const tool = tools['auto-approve-tool'] as Tool
const result = await tool.execute!({}, { messages: [], abortSignal: undefined, toolCallId: 'auto-call-123' })
expect(requestToolConfirmation).not.toHaveBeenCalled()
expect(callMCPTool).toHaveBeenCalled()
expect(result).toEqual({
content: [{ type: 'text', text: 'Auto-approved success' }],
isError: false
})
})
})
})

View File

@@ -0,0 +1,545 @@
/**
* options.ts Unit Tests
* Tests for building provider-specific options
*/
import type { Assistant, Model, Provider } from '@renderer/types'
import { OpenAIServiceTiers, SystemProviderIds } from '@renderer/types'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import { buildProviderOptions } from '../options'
// Mock dependencies
vi.mock('@cherrystudio/ai-core/provider', async (importOriginal) => {
const actual = (await importOriginal()) as object
return {
...actual,
baseProviderIdSchema: {
safeParse: vi.fn((id) => {
const baseProviders = [
'openai',
'openai-chat',
'azure',
'azure-responses',
'huggingface',
'anthropic',
'google',
'xai',
'deepseek',
'openrouter',
'openai-compatible'
]
if (baseProviders.includes(id)) {
return { success: true, data: id }
}
return { success: false }
})
},
customProviderIdSchema: {
safeParse: vi.fn((id) => {
const customProviders = ['google-vertex', 'google-vertex-anthropic', 'bedrock']
if (customProviders.includes(id)) {
return { success: true, data: id }
}
return { success: false, error: new Error('Invalid provider') }
})
}
}
})
vi.mock('../provider/factory', () => ({
getAiSdkProviderId: vi.fn((provider) => {
// Simulate the provider ID mapping
const mapping: Record<string, string> = {
[SystemProviderIds.gemini]: 'google',
[SystemProviderIds.openai]: 'openai',
[SystemProviderIds.anthropic]: 'anthropic',
[SystemProviderIds.grok]: 'xai',
[SystemProviderIds.deepseek]: 'deepseek',
[SystemProviderIds.openrouter]: 'openrouter'
}
return mapping[provider.id] || provider.id
})
}))
vi.mock('@renderer/config/models', async (importOriginal) => ({
...(await importOriginal()),
isOpenAIModel: vi.fn((model) => model.id.includes('gpt') || model.id.includes('o1')),
isQwenMTModel: vi.fn(() => false),
isSupportFlexServiceTierModel: vi.fn(() => true),
isOpenAILLMModel: vi.fn(() => true),
SYSTEM_MODELS: {
defaultModel: [
{ id: 'default-1', name: 'Default 1' },
{ id: 'default-2', name: 'Default 2' },
{ id: 'default-3', name: 'Default 3' }
]
}
}))
vi.mock(import('@renderer/utils/provider'), async (importOriginal) => {
return {
...(await importOriginal()),
isSupportServiceTierProvider: vi.fn((provider) => {
return [SystemProviderIds.openai, SystemProviderIds.groq].includes(provider.id)
})
}
})
vi.mock('@renderer/store/settings', () => ({
default: (state = { settings: {} }) => state
}))
vi.mock('@renderer/hooks/useSettings', () => ({
getStoreSetting: vi.fn((key) => {
if (key === 'openAI') {
return { summaryText: 'off', verbosity: 'medium' } as any
}
return {}
})
}))
vi.mock('@renderer/services/AssistantService', () => ({
getDefaultAssistant: vi.fn(() => ({
id: 'default',
name: 'Default Assistant',
settings: {}
})),
getAssistantSettings: vi.fn(() => ({
reasoning_effort: 'medium',
maxTokens: 4096
})),
getProviderByModel: vi.fn((model: Model) => ({
id: model.provider,
name: 'Mock Provider'
}))
}))
vi.mock('../reasoning', () => ({
getOpenAIReasoningParams: vi.fn(() => ({ reasoningEffort: 'medium' })),
getAnthropicReasoningParams: vi.fn(() => ({
thinking: { type: 'enabled', budgetTokens: 5000 }
})),
getGeminiReasoningParams: vi.fn(() => ({
thinkingConfig: { include_thoughts: true }
})),
getXAIReasoningParams: vi.fn(() => ({ reasoningEffort: 'high' })),
getBedrockReasoningParams: vi.fn(() => ({
reasoningConfig: { type: 'enabled', budgetTokens: 5000 }
})),
getReasoningEffort: vi.fn(() => ({ reasoningEffort: 'medium' })),
getCustomParameters: vi.fn(() => ({}))
}))
vi.mock('../image', () => ({
buildGeminiGenerateImageParams: vi.fn(() => ({
responseModalities: ['TEXT', 'IMAGE']
}))
}))
vi.mock('../websearch', () => ({
getWebSearchParams: vi.fn(() => ({ enable_search: true }))
}))
const ensureWindowApi = () => {
const globalWindow = window as any
globalWindow.api = globalWindow.api || {}
globalWindow.api.getAppInfo = globalWindow.api.getAppInfo || vi.fn(async () => ({ notesPath: '' }))
}
ensureWindowApi()
describe('options utils', () => {
const mockAssistant: Assistant = {
id: 'test-assistant',
name: 'Test Assistant',
settings: {}
} as Assistant
const mockModel: Model = {
id: 'gpt-4',
name: 'GPT-4',
provider: SystemProviderIds.openai
} as Model
beforeEach(() => {
vi.clearAllMocks()
})
describe('buildProviderOptions', () => {
describe('OpenAI provider', () => {
const openaiProvider: Provider = {
id: SystemProviderIds.openai,
name: 'OpenAI',
type: 'openai-response',
apiKey: 'test-key',
apiHost: 'https://api.openai.com/v1',
isSystem: true
} as Provider
it('should build basic OpenAI options', () => {
const result = buildProviderOptions(mockAssistant, mockModel, openaiProvider, {
enableReasoning: false,
enableWebSearch: false,
enableGenerateImage: false
})
expect(result).toHaveProperty('openai')
expect(result.openai).toBeDefined()
})
it('should include reasoning parameters when enabled', () => {
const result = buildProviderOptions(mockAssistant, mockModel, openaiProvider, {
enableReasoning: true,
enableWebSearch: false,
enableGenerateImage: false
})
expect(result.openai).toHaveProperty('reasoningEffort')
expect(result.openai.reasoningEffort).toBe('medium')
})
it('should include service tier when supported', () => {
const providerWithServiceTier: Provider = {
...openaiProvider,
serviceTier: OpenAIServiceTiers.auto
}
const result = buildProviderOptions(mockAssistant, mockModel, providerWithServiceTier, {
enableReasoning: false,
enableWebSearch: false,
enableGenerateImage: false
})
expect(result.openai).toHaveProperty('serviceTier')
expect(result.openai.serviceTier).toBe(OpenAIServiceTiers.auto)
})
})
describe('Anthropic provider', () => {
const anthropicProvider: Provider = {
id: SystemProviderIds.anthropic,
name: 'Anthropic',
type: 'anthropic',
apiKey: 'test-key',
apiHost: 'https://api.anthropic.com',
isSystem: true
} as Provider
const anthropicModel: Model = {
id: 'claude-3-5-sonnet-20241022',
name: 'Claude 3.5 Sonnet',
provider: SystemProviderIds.anthropic
} as Model
it('should build basic Anthropic options', () => {
const result = buildProviderOptions(mockAssistant, anthropicModel, anthropicProvider, {
enableReasoning: false,
enableWebSearch: false,
enableGenerateImage: false
})
expect(result).toHaveProperty('anthropic')
expect(result.anthropic).toBeDefined()
})
it('should include reasoning parameters when enabled', () => {
const result = buildProviderOptions(mockAssistant, anthropicModel, anthropicProvider, {
enableReasoning: true,
enableWebSearch: false,
enableGenerateImage: false
})
expect(result.anthropic).toHaveProperty('thinking')
expect(result.anthropic.thinking).toEqual({
type: 'enabled',
budgetTokens: 5000
})
})
})
describe('Google provider', () => {
const googleProvider: Provider = {
id: SystemProviderIds.gemini,
name: 'Google',
type: 'gemini',
apiKey: 'test-key',
apiHost: 'https://generativelanguage.googleapis.com',
isSystem: true,
models: [{ id: 'gemini-2.0-flash-exp' }] as Model[]
} as Provider
const googleModel: Model = {
id: 'gemini-2.0-flash-exp',
name: 'Gemini 2.0 Flash',
provider: SystemProviderIds.gemini
} as Model
it('should build basic Google options', () => {
const result = buildProviderOptions(mockAssistant, googleModel, googleProvider, {
enableReasoning: false,
enableWebSearch: false,
enableGenerateImage: false
})
expect(result).toHaveProperty('google')
expect(result.google).toBeDefined()
})
it('should include reasoning parameters when enabled', () => {
const result = buildProviderOptions(mockAssistant, googleModel, googleProvider, {
enableReasoning: true,
enableWebSearch: false,
enableGenerateImage: false
})
expect(result.google).toHaveProperty('thinkingConfig')
expect(result.google.thinkingConfig).toEqual({
include_thoughts: true
})
})
it('should include image generation parameters when enabled', () => {
const result = buildProviderOptions(mockAssistant, googleModel, googleProvider, {
enableReasoning: false,
enableWebSearch: false,
enableGenerateImage: true
})
expect(result.google).toHaveProperty('responseModalities')
expect(result.google.responseModalities).toEqual(['TEXT', 'IMAGE'])
})
})
describe('xAI provider', () => {
const xaiProvider = {
id: SystemProviderIds.grok,
name: 'xAI',
type: 'new-api',
apiKey: 'test-key',
apiHost: 'https://api.x.ai/v1',
isSystem: true,
models: [] as Model[]
} as Provider
const xaiModel: Model = {
id: 'grok-2-latest',
name: 'Grok 2',
provider: SystemProviderIds.grok
} as Model
it('should build basic xAI options', () => {
const result = buildProviderOptions(mockAssistant, xaiModel, xaiProvider, {
enableReasoning: false,
enableWebSearch: false,
enableGenerateImage: false
})
expect(result).toHaveProperty('xai')
expect(result.xai).toBeDefined()
})
it('should include reasoning parameters when enabled', () => {
const result = buildProviderOptions(mockAssistant, xaiModel, xaiProvider, {
enableReasoning: true,
enableWebSearch: false,
enableGenerateImage: false
})
expect(result.xai).toHaveProperty('reasoningEffort')
expect(result.xai.reasoningEffort).toBe('high')
})
})
describe('DeepSeek provider', () => {
const deepseekProvider: Provider = {
id: SystemProviderIds.deepseek,
name: 'DeepSeek',
type: 'openai',
apiKey: 'test-key',
apiHost: 'https://api.deepseek.com',
isSystem: true
} as Provider
const deepseekModel: Model = {
id: 'deepseek-chat',
name: 'DeepSeek Chat',
provider: SystemProviderIds.deepseek
} as Model
it('should build basic DeepSeek options', () => {
const result = buildProviderOptions(mockAssistant, deepseekModel, deepseekProvider, {
enableReasoning: false,
enableWebSearch: false,
enableGenerateImage: false
})
expect(result).toHaveProperty('deepseek')
expect(result.deepseek).toBeDefined()
})
})
describe('OpenRouter provider', () => {
const openrouterProvider: Provider = {
id: SystemProviderIds.openrouter,
name: 'OpenRouter',
type: 'openai',
apiKey: 'test-key',
apiHost: 'https://openrouter.ai/api/v1',
isSystem: true
} as Provider
const openrouterModel: Model = {
id: 'openai/gpt-4',
name: 'GPT-4',
provider: SystemProviderIds.openrouter
} as Model
it('should build basic OpenRouter options', () => {
const result = buildProviderOptions(mockAssistant, openrouterModel, openrouterProvider, {
enableReasoning: false,
enableWebSearch: false,
enableGenerateImage: false
})
expect(result).toHaveProperty('openrouter')
expect(result.openrouter).toBeDefined()
})
it('should include web search parameters when enabled', () => {
const result = buildProviderOptions(mockAssistant, openrouterModel, openrouterProvider, {
enableReasoning: false,
enableWebSearch: true,
enableGenerateImage: false
})
expect(result.openrouter).toHaveProperty('enable_search')
})
})
describe('Custom parameters', () => {
it('should merge custom parameters', async () => {
const { getCustomParameters } = await import('../reasoning')
vi.mocked(getCustomParameters).mockReturnValue({
custom_param: 'custom_value',
another_param: 123
})
const result = buildProviderOptions(
mockAssistant,
mockModel,
{
id: SystemProviderIds.openai,
name: 'OpenAI',
type: 'openai',
apiKey: 'test-key',
apiHost: 'https://api.openai.com/v1'
} as Provider,
{
enableReasoning: false,
enableWebSearch: false,
enableGenerateImage: false
}
)
expect(result.openai).toHaveProperty('custom_param')
expect(result.openai.custom_param).toBe('custom_value')
expect(result.openai).toHaveProperty('another_param')
expect(result.openai.another_param).toBe(123)
})
})
describe('Multiple capabilities', () => {
const googleProvider = {
id: SystemProviderIds.gemini,
name: 'Google',
type: 'gemini',
apiKey: 'test-key',
apiHost: 'https://generativelanguage.googleapis.com',
isSystem: true,
models: [] as Model[]
} as Provider
const googleModel: Model = {
id: 'gemini-2.0-flash-exp',
name: 'Gemini 2.0 Flash',
provider: SystemProviderIds.gemini
} as Model
it('should combine reasoning and image generation', () => {
const result = buildProviderOptions(mockAssistant, googleModel, googleProvider, {
enableReasoning: true,
enableWebSearch: false,
enableGenerateImage: true
})
expect(result.google).toHaveProperty('thinkingConfig')
expect(result.google).toHaveProperty('responseModalities')
})
it('should handle all capabilities enabled', () => {
const result = buildProviderOptions(mockAssistant, googleModel, googleProvider, {
enableReasoning: true,
enableWebSearch: true,
enableGenerateImage: true
})
expect(result.google).toBeDefined()
expect(Object.keys(result.google).length).toBeGreaterThan(0)
})
})
describe('Vertex AI providers', () => {
it('should map google-vertex to google', () => {
const vertexProvider = {
id: 'google-vertex',
name: 'Vertex AI',
type: 'vertexai',
apiKey: 'test-key',
apiHost: 'https://vertex-ai.googleapis.com',
models: [] as Model[]
} as Provider
const vertexModel: Model = {
id: 'gemini-2.0-flash-exp',
name: 'Gemini 2.0 Flash',
provider: 'google-vertex'
} as Model
const result = buildProviderOptions(mockAssistant, vertexModel, vertexProvider, {
enableReasoning: false,
enableWebSearch: false,
enableGenerateImage: false
})
expect(result).toHaveProperty('google')
})
it('should map google-vertex-anthropic to anthropic', () => {
const vertexAnthropicProvider = {
id: 'google-vertex-anthropic',
name: 'Vertex AI Anthropic',
type: 'vertex-anthropic',
apiKey: 'test-key',
apiHost: 'https://vertex-ai.googleapis.com',
models: [] as Model[]
} as Provider
const vertexModel: Model = {
id: 'claude-3-5-sonnet-20241022',
name: 'Claude 3.5 Sonnet',
provider: 'google-vertex-anthropic'
} as Model
const result = buildProviderOptions(mockAssistant, vertexModel, vertexAnthropicProvider, {
enableReasoning: false,
enableWebSearch: false,
enableGenerateImage: false
})
expect(result).toHaveProperty('anthropic')
})
})
})
})

View File

@@ -0,0 +1,967 @@
/**
* reasoning.ts Unit Tests
* Tests for reasoning parameter generation utilities
*/
import { getStoreSetting } from '@renderer/hooks/useSettings'
import type { SettingsState } from '@renderer/store/settings'
import type { Assistant, Model, Provider } from '@renderer/types'
import { SystemProviderIds } from '@renderer/types'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import {
getAnthropicReasoningParams,
getBedrockReasoningParams,
getCustomParameters,
getGeminiReasoningParams,
getOpenAIReasoningParams,
getReasoningEffort,
getXAIReasoningParams
} from '../reasoning'
function defaultGetStoreSetting<K extends keyof SettingsState>(key: K): SettingsState[K] {
if (key === 'openAI') {
return {
summaryText: 'auto',
verbosity: 'medium'
} as SettingsState[K]
}
return undefined as SettingsState[K]
}
// Mock dependencies
vi.mock('@logger', () => ({
loggerService: {
withContext: () => ({
debug: vi.fn(),
error: vi.fn(),
warn: vi.fn(),
info: vi.fn()
})
}
}))
vi.mock('@renderer/store/settings', () => ({
default: (state = { settings: {} }) => state
}))
vi.mock('@renderer/store/llm', () => ({
initialState: {},
default: (state = { llm: {} }) => state
}))
vi.mock('@renderer/config/constant', () => ({
DEFAULT_MAX_TOKENS: 4096,
isMac: false,
isWin: false,
TOKENFLUX_HOST: 'mock-host'
}))
vi.mock('@renderer/utils/provider', () => ({
isSupportEnableThinkingProvider: vi.fn((provider) => {
return [SystemProviderIds.dashscope, SystemProviderIds.silicon].includes(provider.id)
})
}))
vi.mock('@renderer/config/models', async (importOriginal) => {
const actual: any = await importOriginal()
return {
...actual,
isReasoningModel: vi.fn(() => false),
isOpenAIDeepResearchModel: vi.fn(() => false),
isOpenAIModel: vi.fn(() => false),
isSupportedReasoningEffortOpenAIModel: vi.fn(() => false),
isSupportedThinkingTokenQwenModel: vi.fn(() => false),
isQwenReasoningModel: vi.fn(() => false),
isSupportedThinkingTokenClaudeModel: vi.fn(() => false),
isSupportedThinkingTokenGeminiModel: vi.fn(() => false),
isSupportedThinkingTokenDoubaoModel: vi.fn(() => false),
isSupportedThinkingTokenZhipuModel: vi.fn(() => false),
isSupportedReasoningEffortModel: vi.fn(() => false),
isDeepSeekHybridInferenceModel: vi.fn(() => false),
isSupportedReasoningEffortGrokModel: vi.fn(() => false),
getThinkModelType: vi.fn(() => 'default'),
isDoubaoSeedAfter251015: vi.fn(() => false),
isDoubaoThinkingAutoModel: vi.fn(() => false),
isGrok4FastReasoningModel: vi.fn(() => false),
isGrokReasoningModel: vi.fn(() => false),
isOpenAIReasoningModel: vi.fn(() => false),
isQwenAlwaysThinkModel: vi.fn(() => false),
isSupportedThinkingTokenHunyuanModel: vi.fn(() => false),
isSupportedThinkingTokenModel: vi.fn(() => false),
isGPT51SeriesModel: vi.fn(() => false)
}
})
vi.mock('@renderer/hooks/useSettings', () => ({
getStoreSetting: vi.fn(defaultGetStoreSetting)
}))
vi.mock('@renderer/services/AssistantService', () => ({
getAssistantSettings: vi.fn((assistant) => ({
maxTokens: assistant?.settings?.maxTokens || 4096,
reasoning_effort: assistant?.settings?.reasoning_effort
})),
getProviderByModel: vi.fn((model) => ({
id: model.provider,
name: 'Test Provider'
})),
getDefaultAssistant: vi.fn(() => ({
id: 'default',
name: 'Default Assistant',
settings: {}
}))
}))
const ensureWindowApi = () => {
const globalWindow = window as any
globalWindow.api = globalWindow.api || {}
globalWindow.api.getAppInfo = globalWindow.api.getAppInfo || vi.fn(async () => ({ notesPath: '' }))
}
ensureWindowApi()
describe('reasoning utils', () => {
beforeEach(() => {
vi.resetAllMocks()
})
describe('getReasoningEffort', () => {
it('should return empty object for non-reasoning model', async () => {
const model: Model = {
id: 'gpt-4',
name: 'GPT-4',
provider: SystemProviderIds.openai
} as Model
const assistant: Assistant = {
id: 'test',
name: 'Test',
settings: {}
} as Assistant
const result = getReasoningEffort(assistant, model)
expect(result).toEqual({})
})
it('should disable reasoning for OpenRouter when no reasoning effort set', async () => {
const { isReasoningModel } = await import('@renderer/config/models')
vi.mocked(isReasoningModel).mockReturnValue(true)
const model: Model = {
id: 'anthropic/claude-sonnet-4',
name: 'Claude Sonnet 4',
provider: SystemProviderIds.openrouter
} as Model
const assistant: Assistant = {
id: 'test',
name: 'Test',
settings: {}
} as Assistant
const result = getReasoningEffort(assistant, model)
expect(result).toEqual({ reasoning: { enabled: false, exclude: true } })
})
it('should handle Qwen models with enable_thinking', async () => {
const { isReasoningModel, isSupportedThinkingTokenQwenModel, isQwenReasoningModel } = await import(
'@renderer/config/models'
)
vi.mocked(isReasoningModel).mockReturnValue(true)
vi.mocked(isSupportedThinkingTokenQwenModel).mockReturnValue(true)
vi.mocked(isQwenReasoningModel).mockReturnValue(true)
const model: Model = {
id: 'qwen-plus',
name: 'Qwen Plus',
provider: SystemProviderIds.dashscope
} as Model
const assistant: Assistant = {
id: 'test',
name: 'Test',
settings: {
reasoning_effort: 'medium'
}
} as Assistant
const result = getReasoningEffort(assistant, model)
expect(result).toHaveProperty('enable_thinking')
})
it('should handle Claude models with thinking config', async () => {
const {
isSupportedThinkingTokenClaudeModel,
isReasoningModel,
isQwenReasoningModel,
isSupportedThinkingTokenGeminiModel,
isSupportedThinkingTokenDoubaoModel,
isSupportedThinkingTokenZhipuModel,
isSupportedReasoningEffortModel
} = await import('@renderer/config/models')
vi.mocked(isReasoningModel).mockReturnValue(true)
vi.mocked(isSupportedThinkingTokenClaudeModel).mockReturnValue(true)
vi.mocked(isQwenReasoningModel).mockReturnValue(false)
vi.mocked(isSupportedThinkingTokenGeminiModel).mockReturnValue(false)
vi.mocked(isSupportedThinkingTokenDoubaoModel).mockReturnValue(false)
vi.mocked(isSupportedThinkingTokenZhipuModel).mockReturnValue(false)
vi.mocked(isSupportedReasoningEffortModel).mockReturnValue(false)
const model: Model = {
id: 'claude-3-7-sonnet',
name: 'Claude 3.7 Sonnet',
provider: SystemProviderIds.anthropic
} as Model
const assistant: Assistant = {
id: 'test',
name: 'Test',
settings: {
reasoning_effort: 'high',
maxTokens: 4096
}
} as Assistant
const result = getReasoningEffort(assistant, model)
expect(result).toEqual({
thinking: {
type: 'enabled',
budget_tokens: expect.any(Number)
}
})
})
it('should handle Gemini Flash models with thinking budget 0', async () => {
const {
isSupportedThinkingTokenGeminiModel,
isReasoningModel,
isQwenReasoningModel,
isSupportedThinkingTokenClaudeModel,
isSupportedThinkingTokenDoubaoModel,
isSupportedThinkingTokenZhipuModel,
isOpenAIDeepResearchModel,
isSupportedThinkingTokenQwenModel,
isSupportedThinkingTokenHunyuanModel,
isDeepSeekHybridInferenceModel
} = await import('@renderer/config/models')
vi.mocked(isReasoningModel).mockReturnValue(true)
vi.mocked(isOpenAIDeepResearchModel).mockReturnValue(false)
vi.mocked(isSupportedThinkingTokenGeminiModel).mockReturnValue(true)
vi.mocked(isQwenReasoningModel).mockReturnValue(false)
vi.mocked(isSupportedThinkingTokenClaudeModel).mockReturnValue(false)
vi.mocked(isSupportedThinkingTokenDoubaoModel).mockReturnValue(false)
vi.mocked(isSupportedThinkingTokenZhipuModel).mockReturnValue(false)
vi.mocked(isSupportedThinkingTokenQwenModel).mockReturnValue(false)
vi.mocked(isSupportedThinkingTokenHunyuanModel).mockReturnValue(false)
vi.mocked(isDeepSeekHybridInferenceModel).mockReturnValue(false)
const model: Model = {
id: 'gemini-2.5-flash',
name: 'Gemini 2.5 Flash',
provider: SystemProviderIds.openai
} as Model
const assistant: Assistant = {
id: 'test',
name: 'Test',
settings: {}
} as Assistant
const result = getReasoningEffort(assistant, model)
expect(result).toEqual({
extra_body: {
google: {
thinking_config: {
thinking_budget: 0
}
}
}
})
})
it('should handle GPT-5.1 reasoning model with effort levels', async () => {
const {
isReasoningModel,
isOpenAIDeepResearchModel,
isSupportedReasoningEffortModel,
isGPT51SeriesModel,
getThinkModelType
} = await import('@renderer/config/models')
vi.mocked(isReasoningModel).mockReturnValue(true)
vi.mocked(isOpenAIDeepResearchModel).mockReturnValue(false)
vi.mocked(isSupportedReasoningEffortModel).mockReturnValue(true)
vi.mocked(getThinkModelType).mockReturnValue('gpt5_1')
vi.mocked(isGPT51SeriesModel).mockReturnValue(true)
const model: Model = {
id: 'gpt-5.1',
name: 'GPT-5.1',
provider: SystemProviderIds.openai
} as Model
const assistant: Assistant = {
id: 'test',
name: 'Test',
settings: {
reasoning_effort: 'none'
}
} as Assistant
const result = getReasoningEffort(assistant, model)
expect(result).toEqual({
reasoningEffort: 'none'
})
})
it('should handle DeepSeek hybrid inference models', async () => {
const { isReasoningModel, isDeepSeekHybridInferenceModel } = await import('@renderer/config/models')
vi.mocked(isReasoningModel).mockReturnValue(true)
vi.mocked(isDeepSeekHybridInferenceModel).mockReturnValue(true)
const model: Model = {
id: 'deepseek-v3.1',
name: 'DeepSeek V3.1',
provider: SystemProviderIds.silicon
} as Model
const assistant: Assistant = {
id: 'test',
name: 'Test',
settings: {
reasoning_effort: 'high'
}
} as Assistant
const result = getReasoningEffort(assistant, model)
expect(result).toEqual({
enable_thinking: true
})
})
it('should return medium effort for deep research models', async () => {
const { isReasoningModel, isOpenAIDeepResearchModel } = await import('@renderer/config/models')
vi.mocked(isReasoningModel).mockReturnValue(true)
vi.mocked(isOpenAIDeepResearchModel).mockReturnValue(true)
const model: Model = {
id: 'o3-deep-research',
provider: SystemProviderIds.openai
} as Model
const assistant: Assistant = {
id: 'test',
name: 'Test',
settings: {}
} as Assistant
const result = getReasoningEffort(assistant, model)
expect(result).toEqual({ reasoning_effort: 'medium' })
})
it('should return empty for groq provider', async () => {
const { getProviderByModel } = await import('@renderer/services/AssistantService')
vi.mocked(getProviderByModel).mockReturnValue({
id: 'groq',
name: 'Groq'
} as Provider)
const model: Model = {
id: 'groq-model',
name: 'Groq Model',
provider: 'groq'
} as Model
const assistant: Assistant = {
id: 'test',
name: 'Test',
settings: {}
} as Assistant
const result = getReasoningEffort(assistant, model)
expect(result).toEqual({})
})
})
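// Summary of the shapes asserted above, grouped by model family. Inferred from
// these tests alone, not from the implementation:
//   OpenRouter, reasoning off   -> { reasoning: { enabled: false, exclude: true } }
//   Qwen (dashscope)            -> { enable_thinking: ... }
//   Claude thinking             -> { thinking: { type: 'enabled', budget_tokens } }
//   Gemini Flash, no effort set -> { extra_body: { google: { thinking_config: { thinking_budget: 0 } } } }
//   GPT-5.1, effort 'none'      -> { reasoningEffort: 'none' }
//   DeepSeek hybrid             -> { enable_thinking: true }
//   OpenAI deep research        -> { reasoning_effort: 'medium' }
//   Groq                        -> {}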
describe('getOpenAIReasoningParams', () => {
it('should return empty object for non-reasoning model', async () => {
const model: Model = {
id: 'gpt-4',
name: 'GPT-4',
provider: SystemProviderIds.openai
} as Model
const assistant: Assistant = {
id: 'test',
name: 'Test',
settings: {}
} as Assistant
const result = getOpenAIReasoningParams(assistant, model)
expect(result).toEqual({})
})
it('should return empty when no reasoning effort set', async () => {
const model: Model = {
id: 'o1-preview',
name: 'O1 Preview',
provider: SystemProviderIds.openai
} as Model
const assistant: Assistant = {
id: 'test',
name: 'Test',
settings: {}
} as Assistant
const result = getOpenAIReasoningParams(assistant, model)
expect(result).toEqual({})
})
it('should return reasoning effort for OpenAI models', async () => {
const { isReasoningModel, isOpenAIModel, isSupportedReasoningEffortOpenAIModel } = await import(
'@renderer/config/models'
)
vi.mocked(isReasoningModel).mockReturnValue(true)
vi.mocked(isOpenAIModel).mockReturnValue(true)
vi.mocked(isSupportedReasoningEffortOpenAIModel).mockReturnValue(true)
const model: Model = {
id: 'gpt-5.1',
name: 'GPT 5.1',
provider: SystemProviderIds.openai
} as Model
const assistant: Assistant = {
id: 'test',
name: 'Test',
settings: {
reasoning_effort: 'high'
}
} as Assistant
const result = getOpenAIReasoningParams(assistant, model)
expect(result).toEqual({
reasoningEffort: 'high',
reasoningSummary: 'auto'
})
})
it('should include reasoning summary when not o1-pro', async () => {
const { isReasoningModel, isOpenAIModel, isSupportedReasoningEffortOpenAIModel } = await import(
'@renderer/config/models'
)
vi.mocked(isReasoningModel).mockReturnValue(true)
vi.mocked(isOpenAIModel).mockReturnValue(true)
vi.mocked(isSupportedReasoningEffortOpenAIModel).mockReturnValue(true)
const model: Model = {
id: 'gpt-5',
provider: SystemProviderIds.openai
} as Model
const assistant: Assistant = {
id: 'test',
name: 'Test',
settings: {
reasoning_effort: 'medium'
}
} as Assistant
const result = getOpenAIReasoningParams(assistant, model)
expect(result).toEqual({
reasoningEffort: 'medium',
reasoningSummary: 'auto'
})
})
it('should not include reasoning summary for o1-pro', async () => {
const { isReasoningModel, isOpenAIDeepResearchModel, isSupportedReasoningEffortOpenAIModel } = await import(
'@renderer/config/models'
)
vi.mocked(isReasoningModel).mockReturnValue(true)
vi.mocked(isOpenAIDeepResearchModel).mockReturnValue(false)
vi.mocked(isSupportedReasoningEffortOpenAIModel).mockReturnValue(true)
vi.mocked(getStoreSetting).mockReturnValue({ summaryText: 'off' } as any)
const model: Model = {
id: 'o1-pro',
name: 'O1 Pro',
provider: SystemProviderIds.openai
} as Model
const assistant: Assistant = {
id: 'test',
name: 'Test',
settings: {
reasoning_effort: 'high'
}
} as Assistant
const result = getOpenAIReasoningParams(assistant, model)
expect(result).toEqual({
reasoningEffort: 'high',
reasoningSummary: undefined
})
})
it('should force medium effort for deep research models', async () => {
const { isReasoningModel, isOpenAIModel, isOpenAIDeepResearchModel, isSupportedReasoningEffortOpenAIModel } =
await import('@renderer/config/models')
const { getStoreSetting } = await import('@renderer/hooks/useSettings')
vi.mocked(isReasoningModel).mockReturnValue(true)
vi.mocked(isOpenAIModel).mockReturnValue(true)
vi.mocked(isOpenAIDeepResearchModel).mockReturnValue(true)
vi.mocked(isSupportedReasoningEffortOpenAIModel).mockReturnValue(true)
vi.mocked(getStoreSetting).mockReturnValue({ summaryText: 'off' } as any)
const model: Model = {
id: 'o3-deep-research',
name: 'O3 Deep Research',
provider: SystemProviderIds.openai
} as Model
const assistant: Assistant = {
id: 'test',
name: 'Test',
settings: {
reasoning_effort: 'high'
}
} as Assistant
const result = getOpenAIReasoningParams(assistant, model)
expect(result).toEqual({
reasoningEffort: 'medium',
reasoningSummary: 'off'
})
})
})
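// Behavior pinned down by the assertions above: non-reasoning models and unset
// reasoning effort both yield {}; deep-research models force reasoningEffort to
// 'medium'; reasoningSummary comes from the 'openAI' store setting, except for
// o1-pro, which never receives a summary.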
describe('getAnthropicReasoningParams', () => {
it('should return empty for non-reasoning model', async () => {
const { isReasoningModel } = await import('@renderer/config/models')
vi.mocked(isReasoningModel).mockReturnValue(false)
const model: Model = {
id: 'claude-3-5-sonnet',
name: 'Claude 3.5 Sonnet',
provider: SystemProviderIds.anthropic
} as Model
const assistant: Assistant = {
id: 'test',
name: 'Test',
settings: {}
} as Assistant
const result = getAnthropicReasoningParams(assistant, model)
expect(result).toEqual({})
})
it('should return disabled thinking when no reasoning effort', async () => {
const { isReasoningModel, isSupportedThinkingTokenClaudeModel } = await import('@renderer/config/models')
vi.mocked(isReasoningModel).mockReturnValue(true)
vi.mocked(isSupportedThinkingTokenClaudeModel).mockReturnValue(false)
const model: Model = {
id: 'claude-3-7-sonnet',
name: 'Claude 3.7 Sonnet',
provider: SystemProviderIds.anthropic
} as Model
const assistant: Assistant = {
id: 'test',
name: 'Test',
settings: {}
} as Assistant
const result = getAnthropicReasoningParams(assistant, model)
expect(result).toEqual({
thinking: {
type: 'disabled'
}
})
})
it('should return enabled thinking with budget for Claude models', async () => {
const { isReasoningModel, isSupportedThinkingTokenClaudeModel } = await import('@renderer/config/models')
vi.mocked(isReasoningModel).mockReturnValue(true)
vi.mocked(isSupportedThinkingTokenClaudeModel).mockReturnValue(true)
const model: Model = {
id: 'claude-3-7-sonnet',
name: 'Claude 3.7 Sonnet',
provider: SystemProviderIds.anthropic
} as Model
const assistant: Assistant = {
id: 'test',
name: 'Test',
settings: {
reasoning_effort: 'medium',
maxTokens: 4096
}
} as Assistant
const result = getAnthropicReasoningParams(assistant, model)
expect(result).toEqual({
thinking: {
type: 'enabled',
budgetTokens: 2048
}
})
})
})
describe('getGeminiReasoningParams', () => {
it('should return empty for non-reasoning model', async () => {
const { isReasoningModel } = await import('@renderer/config/models')
vi.mocked(isReasoningModel).mockReturnValue(false)
const model: Model = {
id: 'gemini-2.0-flash',
name: 'Gemini 2.0 Flash',
provider: SystemProviderIds.gemini
} as Model
const assistant: Assistant = {
id: 'test',
name: 'Test',
settings: {}
} as Assistant
const result = getGeminiReasoningParams(assistant, model)
expect(result).toEqual({})
})
it('should disable thinking for Flash models without reasoning effort', async () => {
const { isReasoningModel, isSupportedThinkingTokenGeminiModel } = await import('@renderer/config/models')
vi.mocked(isReasoningModel).mockReturnValue(true)
vi.mocked(isSupportedThinkingTokenGeminiModel).mockReturnValue(true)
const model: Model = {
id: 'gemini-2.5-flash',
name: 'Gemini 2.5 Flash',
provider: SystemProviderIds.gemini
} as Model
const assistant: Assistant = {
id: 'test',
name: 'Test',
settings: {}
} as Assistant
const result = getGeminiReasoningParams(assistant, model)
expect(result).toEqual({
thinkingConfig: {
includeThoughts: false,
thinkingBudget: 0
}
})
})
it('should enable thinking with budget for reasoning effort', async () => {
const { isReasoningModel, isSupportedThinkingTokenGeminiModel } = await import('@renderer/config/models')
vi.mocked(isReasoningModel).mockReturnValue(true)
vi.mocked(isSupportedThinkingTokenGeminiModel).mockReturnValue(true)
const model: Model = {
id: 'gemini-2.5-pro',
name: 'Gemini 2.5 Pro',
provider: SystemProviderIds.gemini
} as Model
const assistant: Assistant = {
id: 'test',
name: 'Test',
settings: {
reasoning_effort: 'medium'
}
} as Assistant
const result = getGeminiReasoningParams(assistant, model)
expect(result).toEqual({
thinkingConfig: {
thinkingBudget: 16448,
includeThoughts: true
}
})
})
it('should enable thinking without budget for auto effort ratio > 1', async () => {
const { isReasoningModel, isSupportedThinkingTokenGeminiModel } = await import('@renderer/config/models')
vi.mocked(isReasoningModel).mockReturnValue(true)
vi.mocked(isSupportedThinkingTokenGeminiModel).mockReturnValue(true)
const model: Model = {
id: 'gemini-2.5-pro',
name: 'Gemini 2.5 Pro',
provider: SystemProviderIds.gemini
} as Model
const assistant: Assistant = {
id: 'test',
name: 'Test',
settings: {
reasoning_effort: 'auto'
}
} as Assistant
const result = getGeminiReasoningParams(assistant, model)
expect(result).toEqual({
thinkingConfig: {
includeThoughts: true
}
})
})
})
describe('getXAIReasoningParams', () => {
it('should return empty for non-Grok model', async () => {
const { isSupportedReasoningEffortGrokModel } = await import('@renderer/config/models')
vi.mocked(isSupportedReasoningEffortGrokModel).mockReturnValue(false)
const model: Model = {
id: 'other-model',
name: 'Other Model',
provider: SystemProviderIds.grok
} as Model
const assistant: Assistant = {
id: 'test',
name: 'Test',
settings: {}
} as Assistant
const result = getXAIReasoningParams(assistant, model)
expect(result).toEqual({})
})
it('should return empty when no reasoning effort', async () => {
const { isSupportedReasoningEffortGrokModel } = await import('@renderer/config/models')
vi.mocked(isSupportedReasoningEffortGrokModel).mockReturnValue(true)
const model: Model = {
id: 'grok-2',
name: 'Grok 2',
provider: SystemProviderIds.grok
} as Model
const assistant: Assistant = {
id: 'test',
name: 'Test',
settings: {}
} as Assistant
const result = getXAIReasoningParams(assistant, model)
expect(result).toEqual({})
})
it('should return reasoning effort for Grok models', async () => {
const { isSupportedReasoningEffortGrokModel } = await import('@renderer/config/models')
vi.mocked(isSupportedReasoningEffortGrokModel).mockReturnValue(true)
const model: Model = {
id: 'grok-3',
name: 'Grok 3',
provider: SystemProviderIds.grok
} as Model
const assistant: Assistant = {
id: 'test',
name: 'Test',
settings: {
reasoning_effort: 'high'
}
} as Assistant
const result = getXAIReasoningParams(assistant, model)
expect(result).toHaveProperty('reasoningEffort')
expect(result.reasoningEffort).toBe('high')
})
})
describe('getBedrockReasoningParams', () => {
it('should return empty for non-reasoning model', async () => {
const model: Model = {
id: 'other-model',
name: 'Other Model',
provider: 'bedrock'
} as Model
const assistant: Assistant = {
id: 'test',
name: 'Test',
settings: {}
} as Assistant
const result = getBedrockReasoningParams(assistant, model)
expect(result).toEqual({})
})
it('should return empty when no reasoning effort', async () => {
const model: Model = {
id: 'claude-3-7-sonnet',
name: 'Claude 3.7 Sonnet',
provider: 'bedrock'
} as Model
const assistant: Assistant = {
id: 'test',
name: 'Test',
settings: {}
} as Assistant
const result = getBedrockReasoningParams(assistant, model)
expect(result).toEqual({})
})
it('should return reasoning config for Claude models on Bedrock', async () => {
const { isReasoningModel, isSupportedThinkingTokenClaudeModel } = await import('@renderer/config/models')
vi.mocked(isReasoningModel).mockReturnValue(true)
vi.mocked(isSupportedThinkingTokenClaudeModel).mockReturnValue(true)
const model: Model = {
id: 'claude-3-7-sonnet',
name: 'Claude 3.7 Sonnet',
provider: 'bedrock'
} as Model
const assistant: Assistant = {
id: 'test',
name: 'Test',
settings: {
reasoning_effort: 'medium',
maxTokens: 4096
}
} as Assistant
const result = getBedrockReasoningParams(assistant, model)
expect(result).toEqual({
reasoningConfig: {
type: 'enabled',
budgetTokens: 2048
}
})
})
})
describe('getCustomParameters', () => {
it('should return empty object when no custom parameters', async () => {
const assistant: Assistant = {
id: 'test',
name: 'Test',
settings: {}
} as Assistant
const result = getCustomParameters(assistant)
expect(result).toEqual({})
})
it('should return custom parameters as key-value pairs', async () => {
const assistant: Assistant = {
id: 'test',
name: 'Test',
settings: {
customParameters: [
{ name: 'param1', value: 'value1', type: 'string' },
{ name: 'param2', value: 123, type: 'number' }
]
}
} as Assistant
const result = getCustomParameters(assistant)
expect(result).toEqual({
param1: 'value1',
param2: 123
})
})
it('should parse JSON type parameters', async () => {
const assistant: Assistant = {
id: 'test',
name: 'Test',
settings: {
customParameters: [{ name: 'config', value: '{"key": "value"}', type: 'json' }]
}
} as Assistant
const result = getCustomParameters(assistant)
expect(result).toEqual({
config: { key: 'value' }
})
})
it('should handle invalid JSON gracefully', async () => {
const assistant: Assistant = {
id: 'test',
name: 'Test',
settings: {
customParameters: [{ name: 'invalid', value: '{invalid json', type: 'json' }]
}
} as Assistant
const result = getCustomParameters(assistant)
expect(result).toEqual({
invalid: '{invalid json'
})
})
it('should handle undefined JSON value', async () => {
const assistant: Assistant = {
id: 'test',
name: 'Test',
settings: {
customParameters: [{ name: 'undef', value: 'undefined', type: 'json' }]
}
} as Assistant
const result = getCustomParameters(assistant)
expect(result).toEqual({
undef: undefined
})
})
it('should skip parameters with empty names', async () => {
const assistant: Assistant = {
id: 'test',
name: 'Test',
settings: {
customParameters: [
{ name: '', value: 'value1', type: 'string' },
{ name: ' ', value: 'value2', type: 'string' },
{ name: 'valid', value: 'value3', type: 'string' }
]
}
} as Assistant
const result = getCustomParameters(assistant)
expect(result).toEqual({
valid: 'value3'
})
})
})
})
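// A minimal sketch consistent with the getCustomParameters assertions above.
// This is an illustration inferred from the tests, not the shipped
// implementation; the type and function names here are hypothetical.
type CustomParameterSketch = { name: string; value: unknown; type: 'string' | 'number' | 'boolean' | 'json' }
function getCustomParametersSketch(params: CustomParameterSketch[]): Record<string, unknown> {
  const result: Record<string, unknown> = {}
  for (const { name, value, type } of params) {
    if (!name.trim()) continue // empty or whitespace-only names are skipped
    if (type === 'json' && typeof value === 'string') {
      if (value === 'undefined') {
        result[name] = undefined // the literal string 'undefined' maps to undefined
      } else {
        try {
          result[name] = JSON.parse(value) // valid JSON is parsed into a value
        } catch {
          result[name] = value // invalid JSON falls back to the raw string
        }
      }
    } else {
      result[name] = value // string/number values pass through unchanged
    }
  }
  return result
}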

View File

@@ -0,0 +1,384 @@
/**
* websearch.ts Unit Tests
* Tests for web search parameter generation utilities
*/
import type { CherryWebSearchConfig } from '@renderer/store/websearch'
import type { Model } from '@renderer/types'
import { describe, expect, it, vi } from 'vitest'
import { buildProviderBuiltinWebSearchConfig, getWebSearchParams } from '../websearch'
// Mock dependencies
vi.mock('@renderer/config/models', () => ({
isOpenAIWebSearchChatCompletionOnlyModel: vi.fn((model) => model?.id?.includes('o1-pro') ?? false),
isOpenAIDeepResearchModel: vi.fn((model) => model?.id?.includes('o3-mini') ?? false)
}))
vi.mock('@renderer/utils/blacklistMatchPattern', () => ({
mapRegexToPatterns: vi.fn((patterns) => patterns || [])
}))
describe('websearch utils', () => {
describe('getWebSearchParams', () => {
it('should return enhancement params for hunyuan provider', () => {
const model: Model = {
id: 'hunyuan-model',
name: 'Hunyuan Model',
provider: 'hunyuan'
} as Model
const result = getWebSearchParams(model)
expect(result).toEqual({
enable_enhancement: true,
citation: true,
search_info: true
})
})
it('should return search params for dashscope provider', () => {
const model: Model = {
id: 'qwen-model',
name: 'Qwen Model',
provider: 'dashscope'
} as Model
const result = getWebSearchParams(model)
expect(result).toEqual({
enable_search: true,
search_options: {
forced_search: true
}
})
})
it('should return web_search_options for OpenAI web search models', () => {
const model: Model = {
id: 'o1-pro',
name: 'O1 Pro',
provider: 'openai'
} as Model
const result = getWebSearchParams(model)
expect(result).toEqual({
web_search_options: {}
})
})
it('should return empty object for other providers', () => {
const model: Model = {
id: 'gpt-4',
name: 'GPT-4',
provider: 'openai'
} as Model
const result = getWebSearchParams(model)
expect(result).toEqual({})
})
it('should return empty object for custom provider', () => {
const model: Model = {
id: 'custom-model',
name: 'Custom Model',
provider: 'custom-provider'
} as Model
const result = getWebSearchParams(model)
expect(result).toEqual({})
})
})
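// A minimal sketch of the per-provider branching these tests describe, inferred
// from the assertions above rather than the source. The chat-completion-only
// web search check is taken as an injected predicate for illustration.
function getWebSearchParamsSketch(model: Model, isChatCompletionOnlyWebSearch: (m: Model) => boolean) {
  switch (model.provider) {
    case 'hunyuan':
      return { enable_enhancement: true, citation: true, search_info: true }
    case 'dashscope':
      return { enable_search: true, search_options: { forced_search: true } }
    default:
      return isChatCompletionOnlyWebSearch(model) ? { web_search_options: {} } : {}
  }
}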
describe('buildProviderBuiltinWebSearchConfig', () => {
const defaultWebSearchConfig: CherryWebSearchConfig = {
searchWithTime: true,
maxResults: 50,
excludeDomains: []
}
describe('openai provider', () => {
it('should return low search context size for low maxResults', () => {
const config: CherryWebSearchConfig = {
searchWithTime: true,
maxResults: 20,
excludeDomains: []
}
const result = buildProviderBuiltinWebSearchConfig('openai', config)
expect(result).toEqual({
openai: {
searchContextSize: 'low'
}
})
})
it('should return medium search context size for medium maxResults', () => {
const config: CherryWebSearchConfig = {
searchWithTime: true,
maxResults: 50,
excludeDomains: []
}
const result = buildProviderBuiltinWebSearchConfig('openai', config)
expect(result).toEqual({
openai: {
searchContextSize: 'medium'
}
})
})
it('should return high search context size for high maxResults', () => {
const config: CherryWebSearchConfig = {
searchWithTime: true,
maxResults: 80,
excludeDomains: []
}
const result = buildProviderBuiltinWebSearchConfig('openai', config)
expect(result).toEqual({
openai: {
searchContextSize: 'high'
}
})
})
it('should use medium for deep research models regardless of maxResults', () => {
const config: CherryWebSearchConfig = {
searchWithTime: true,
maxResults: 100,
excludeDomains: []
}
const model: Model = {
id: 'o3-mini',
name: 'O3 Mini',
provider: 'openai'
} as Model
const result = buildProviderBuiltinWebSearchConfig('openai', config, model)
expect(result).toEqual({
openai: {
searchContextSize: 'medium'
}
})
})
})
describe('openai-chat provider', () => {
it('should return correct search context size', () => {
const config: CherryWebSearchConfig = {
searchWithTime: true,
maxResults: 50,
excludeDomains: []
}
const result = buildProviderBuiltinWebSearchConfig('openai-chat', config)
expect(result).toEqual({
'openai-chat': {
searchContextSize: 'medium'
}
})
})
it('should handle deep research models', () => {
const config: CherryWebSearchConfig = {
searchWithTime: true,
maxResults: 100,
excludeDomains: []
}
const model: Model = {
id: 'o3-mini',
name: 'O3 Mini',
provider: 'openai'
} as Model
const result = buildProviderBuiltinWebSearchConfig('openai-chat', config, model)
expect(result).toEqual({
'openai-chat': {
searchContextSize: 'medium'
}
})
})
})
describe('anthropic provider', () => {
it('should return anthropic search options with maxUses', () => {
const result = buildProviderBuiltinWebSearchConfig('anthropic', defaultWebSearchConfig)
expect(result).toEqual({
anthropic: {
maxUses: 50,
blockedDomains: undefined
}
})
})
it('should include blockedDomains when excludeDomains provided', () => {
const config: CherryWebSearchConfig = {
searchWithTime: true,
maxResults: 30,
excludeDomains: ['example.com', 'test.com']
}
const result = buildProviderBuiltinWebSearchConfig('anthropic', config)
expect(result).toEqual({
anthropic: {
maxUses: 30,
blockedDomains: ['example.com', 'test.com']
}
})
})
it('should not include blockedDomains when empty', () => {
const result = buildProviderBuiltinWebSearchConfig('anthropic', defaultWebSearchConfig)
expect(result).toEqual({
anthropic: {
maxUses: 50,
blockedDomains: undefined
}
})
})
})
describe('xai provider', () => {
it('should return xai search options', () => {
const result = buildProviderBuiltinWebSearchConfig('xai', defaultWebSearchConfig)
expect(result).toEqual({
xai: {
maxSearchResults: 50,
returnCitations: true,
sources: [{ type: 'web', excludedWebsites: [] }, { type: 'news' }, { type: 'x' }],
mode: 'on'
}
})
})
it('should limit excluded websites to 5', () => {
const config: CherryWebSearchConfig = {
searchWithTime: true,
maxResults: 40,
excludeDomains: ['site1.com', 'site2.com', 'site3.com', 'site4.com', 'site5.com', 'site6.com', 'site7.com']
}
const result = buildProviderBuiltinWebSearchConfig('xai', config)
expect(result?.xai?.sources).toBeDefined()
const webSource = result?.xai?.sources?.[0]
if (webSource && webSource.type === 'web') {
expect(webSource.excludedWebsites).toHaveLength(5)
}
})
it('should include all sources types', () => {
const result = buildProviderBuiltinWebSearchConfig('xai', defaultWebSearchConfig)
expect(result?.xai?.sources).toHaveLength(3)
expect(result?.xai?.sources?.[0].type).toBe('web')
expect(result?.xai?.sources?.[1].type).toBe('news')
expect(result?.xai?.sources?.[2].type).toBe('x')
})
})
describe('openrouter provider', () => {
it('should return openrouter plugins config', () => {
const result = buildProviderBuiltinWebSearchConfig('openrouter', defaultWebSearchConfig)
expect(result).toEqual({
openrouter: {
plugins: [
{
id: 'web',
max_results: 50
}
]
}
})
})
it('should respect custom maxResults', () => {
const config: CherryWebSearchConfig = {
searchWithTime: true,
maxResults: 75,
excludeDomains: []
}
const result = buildProviderBuiltinWebSearchConfig('openrouter', config)
expect(result).toEqual({
openrouter: {
plugins: [
{
id: 'web',
max_results: 75
}
]
}
})
})
})
describe('unsupported provider', () => {
it('should return empty object for unsupported provider', () => {
const result = buildProviderBuiltinWebSearchConfig('unsupported' as any, defaultWebSearchConfig)
expect(result).toEqual({})
})
it('should return empty object for google provider', () => {
const result = buildProviderBuiltinWebSearchConfig('google', defaultWebSearchConfig)
expect(result).toEqual({})
})
})
describe('edge cases', () => {
it('should handle maxResults at boundary values', () => {
// Test boundary at 33 (low/medium)
const config33: CherryWebSearchConfig = { searchWithTime: true, maxResults: 33, excludeDomains: [] }
const result33 = buildProviderBuiltinWebSearchConfig('openai', config33)
expect(result33?.openai?.searchContextSize).toBe('low')
// Test boundary at 34 (medium)
const config34: CherryWebSearchConfig = { searchWithTime: true, maxResults: 34, excludeDomains: [] }
const result34 = buildProviderBuiltinWebSearchConfig('openai', config34)
expect(result34?.openai?.searchContextSize).toBe('medium')
// Test boundary at 66 (medium)
const config66: CherryWebSearchConfig = { searchWithTime: true, maxResults: 66, excludeDomains: [] }
const result66 = buildProviderBuiltinWebSearchConfig('openai', config66)
expect(result66?.openai?.searchContextSize).toBe('medium')
// Test boundary at 67 (high)
const config67: CherryWebSearchConfig = { searchWithTime: true, maxResults: 67, excludeDomains: [] }
const result67 = buildProviderBuiltinWebSearchConfig('openai', config67)
expect(result67?.openai?.searchContextSize).toBe('high')
})
it('should handle zero maxResults', () => {
const config: CherryWebSearchConfig = { searchWithTime: true, maxResults: 0, excludeDomains: [] }
const result = buildProviderBuiltinWebSearchConfig('openai', config)
expect(result?.openai?.searchContextSize).toBe('low')
})
it('should handle very large maxResults', () => {
const config: CherryWebSearchConfig = { searchWithTime: true, maxResults: 1000, excludeDomains: [] }
const result = buildProviderBuiltinWebSearchConfig('openai', config)
expect(result?.openai?.searchContextSize).toBe('high')
})
})
})
})
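// The boundary tests above imply a three-way split on maxResults. A hedged
// sketch of that mapping (thresholds inferred from the 33/34 and 66/67
// assertions; the helper name is hypothetical):
function toSearchContextSize(maxResults: number): 'low' | 'medium' | 'high' {
  if (maxResults <= 33) return 'low'
  if (maxResults <= 66) return 'medium'
  return 'high'
}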

View File

@@ -1,3 +1,8 @@
import type { BedrockProviderOptions } from '@ai-sdk/amazon-bedrock'
import type { AnthropicProviderOptions } from '@ai-sdk/anthropic'
import type { GoogleGenerativeAIProviderOptions } from '@ai-sdk/google'
import type { OpenAIResponsesProviderOptions } from '@ai-sdk/openai'
import type { XaiProviderOptions } from '@ai-sdk/xai'
import { baseProviderIdSchema, customProviderIdSchema } from '@cherrystudio/ai-core/provider'
import { loggerService } from '@logger'
import {
@@ -7,17 +12,27 @@ import {
isSupportFlexServiceTierModel,
isSupportVerbosityModel
} from '@renderer/config/models'
import { isSupportServiceTierProvider } from '@renderer/config/providers'
import { mapLanguageToQwenMTModel } from '@renderer/config/translate'
import type { Assistant, Model, Provider } from '@renderer/types'
import { getStoreSetting } from '@renderer/hooks/useSettings'
import {
type Assistant,
type GroqServiceTier,
GroqServiceTiers,
type GroqSystemProvider,
isGroqServiceTier,
isGroqSystemProvider,
isOpenAIServiceTier,
isTranslateAssistant,
type Model,
type NotGroqProvider,
type OpenAIServiceTier,
OpenAIServiceTiers,
SystemProviderIds
type Provider,
type ServiceTier
} from '@renderer/types'
import type { OpenAIVerbosity } from '@renderer/types/aiCoreTypes'
import { isSupportServiceTierProvider } from '@renderer/utils/provider'
import type { JSONValue } from 'ai'
import { t } from 'i18next'
import { getAiSdkProviderId } from '../provider/factory'
@@ -35,8 +50,31 @@ import { getWebSearchParams } from './websearch'
const logger = loggerService.withContext('aiCore.utils.options')
// Copied from BaseApiClient.ts
const getServiceTier = (model: Model, provider: Provider) => {
function toOpenAIServiceTier(model: Model, serviceTier: ServiceTier): OpenAIServiceTier {
if (
!isOpenAIServiceTier(serviceTier) ||
(serviceTier === OpenAIServiceTiers.flex && !isSupportFlexServiceTierModel(model))
) {
return undefined
} else {
return serviceTier
}
}
function toGroqServiceTier(model: Model, serviceTier: ServiceTier): GroqServiceTier {
if (
!isGroqServiceTier(serviceTier) ||
(serviceTier === GroqServiceTiers.flex && !isSupportFlexServiceTierModel(model))
) {
return undefined
} else {
return serviceTier
}
}
function getServiceTier<T extends GroqSystemProvider>(model: Model, provider: T): GroqServiceTier
function getServiceTier<T extends NotGroqProvider>(model: Model, provider: T): OpenAIServiceTier
function getServiceTier<T extends Provider>(model: Model, provider: T): OpenAIServiceTier | GroqServiceTier {
const serviceTierSetting = provider.serviceTier
if (!isSupportServiceTierProvider(provider) || !isOpenAIModel(model) || !serviceTierSetting) {
@@ -44,24 +82,17 @@ const getServiceTier = (model: Model, provider: Provider) => {
}
// Handle cases where different providers need to fall back to the default value
if (provider.id === SystemProviderIds.groq) {
if (
!isGroqServiceTier(serviceTierSetting) ||
(serviceTierSetting === GroqServiceTiers.flex && !isSupportFlexServiceTierModel(model))
) {
return undefined
}
if (isGroqSystemProvider(provider)) {
return toGroqServiceTier(model, serviceTierSetting)
} else {
// Other OpenAI providers: assume their service tier settings are identical to OpenAI's
if (
!isOpenAIServiceTier(serviceTierSetting) ||
(serviceTierSetting === OpenAIServiceTiers.flex && !isSupportFlexServiceTierModel(model))
) {
return undefined
}
return toOpenAIServiceTier(model, serviceTierSetting)
}
}
return serviceTierSetting
function getVerbosity(): OpenAIVerbosity {
const openAI = getStoreSetting('openAI')
return openAI.verbosity
}
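// Hypothetical call sites for the overloaded getServiceTier above: the overload
// selected by the provider's static type narrows the returned tier type, so no
// cast is needed. The declared values are assumptions for illustration only.
//
//   declare const model: Model
//   declare const groq: GroqSystemProvider
//   declare const other: NotGroqProvider
//   const groqTier: GroqServiceTier = getServiceTier(model, groq)
//   const openaiTier: OpenAIServiceTier = getServiceTier(model, other)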
/**
@@ -78,13 +109,13 @@ export function buildProviderOptions(
enableWebSearch: boolean
enableGenerateImage: boolean
}
): Record<string, any> {
): Record<string, Record<string, JSONValue>> {
logger.debug('buildProviderOptions', { assistant, model, actualProvider, capabilities })
const rawProviderId = getAiSdkProviderId(actualProvider)
// Build provider-specific options
let providerSpecificOptions: Record<string, any> = {}
const serviceTierSetting = getServiceTier(model, actualProvider)
providerSpecificOptions.serviceTier = serviceTierSetting
const serviceTier = getServiceTier(model, actualProvider)
const textVerbosity = getVerbosity()
// Split the build logic by provider type
const { data: baseProviderId, success } = baseProviderIdSchema.safeParse(rawProviderId)
if (success) {
@@ -94,9 +125,14 @@ export function buildProviderOptions(
case 'openai-chat':
case 'azure':
case 'azure-responses':
providerSpecificOptions = {
...buildOpenAIProviderOptions(assistant, model, capabilities),
serviceTier: serviceTierSetting
{
const options: OpenAIResponsesProviderOptions = buildOpenAIProviderOptions(
assistant,
model,
capabilities,
serviceTier
)
providerSpecificOptions = options
}
break
case 'anthropic':
@@ -116,12 +152,19 @@ export function buildProviderOptions(
// For other providers, use the generic build logic
providerSpecificOptions = {
...buildGenericProviderOptions(assistant, model, capabilities),
serviceTier: serviceTierSetting
serviceTier,
textVerbosity
}
break
}
case 'cherryin':
providerSpecificOptions = buildCherryInProviderOptions(assistant, model, capabilities, actualProvider)
providerSpecificOptions = buildCherryInProviderOptions(
assistant,
model,
capabilities,
actualProvider,
serviceTier
)
break
default:
throw new Error(`Unsupported base provider ${baseProviderId}`)
@@ -135,6 +178,7 @@ export function buildProviderOptions(
case 'google-vertex':
providerSpecificOptions = buildGeminiProviderOptions(assistant, model, capabilities)
break
case 'azure-anthropic':
case 'google-vertex-anthropic':
providerSpecificOptions = buildAnthropicProviderOptions(assistant, model, capabilities)
break
@@ -142,13 +186,14 @@ export function buildProviderOptions(
providerSpecificOptions = buildBedrockProviderOptions(assistant, model, capabilities)
break
case 'huggingface':
providerSpecificOptions = buildOpenAIProviderOptions(assistant, model, capabilities)
providerSpecificOptions = buildOpenAIProviderOptions(assistant, model, capabilities, serviceTier)
break
default:
// For other providers, use the generic build logic
providerSpecificOptions = {
...buildGenericProviderOptions(assistant, model, capabilities),
serviceTier: serviceTierSetting
serviceTier,
textVerbosity
}
}
} else {
@@ -166,6 +211,7 @@ export function buildProviderOptions(
{
'google-vertex': 'google',
'google-vertex-anthropic': 'anthropic',
'azure-anthropic': 'anthropic',
'ai-gateway': 'gateway'
}[rawProviderId] || rawProviderId
@@ -189,10 +235,11 @@ function buildOpenAIProviderOptions(
enableReasoning: boolean
enableWebSearch: boolean
enableGenerateImage: boolean
}
): Record<string, any> {
},
serviceTier: OpenAIServiceTier
): OpenAIResponsesProviderOptions {
const { enableReasoning } = capabilities
let providerOptions: Record<string, any> = {}
let providerOptions: OpenAIResponsesProviderOptions = {}
// OpenAI reasoning parameters
if (enableReasoning) {
const reasoningParams = getOpenAIReasoningParams(assistant, model)
@@ -203,8 +250,8 @@ function buildOpenAIProviderOptions(
}
if (isSupportVerbosityModel(model)) {
const state = window.store?.getState()
const userVerbosity = state?.settings?.openAI?.verbosity
const openAI = getStoreSetting<'openAI'>('openAI')
const userVerbosity = openAI?.verbosity
if (userVerbosity && ['low', 'medium', 'high'].includes(userVerbosity)) {
const supportedVerbosity = getModelSupportedVerbosity(model)
@@ -218,6 +265,11 @@ function buildOpenAIProviderOptions(
}
}
providerOptions = {
...providerOptions,
serviceTier
}
return providerOptions
}
@@ -232,9 +284,9 @@ function buildAnthropicProviderOptions(
enableWebSearch: boolean
enableGenerateImage: boolean
}
): Record<string, any> {
): AnthropicProviderOptions {
const { enableReasoning } = capabilities
let providerOptions: Record<string, any> = {}
let providerOptions: AnthropicProviderOptions = {}
// Anthropic reasoning parameters
if (enableReasoning) {
@@ -259,9 +311,9 @@ function buildGeminiProviderOptions(
enableWebSearch: boolean
enableGenerateImage: boolean
}
): Record<string, any> {
): GoogleGenerativeAIProviderOptions {
const { enableReasoning, enableGenerateImage } = capabilities
let providerOptions: Record<string, any> = {}
let providerOptions: GoogleGenerativeAIProviderOptions = {}
// Gemini reasoning parameters
if (enableReasoning) {
@@ -290,7 +342,7 @@ function buildXAIProviderOptions(
enableWebSearch: boolean
enableGenerateImage: boolean
}
): Record<string, any> {
): XaiProviderOptions {
const { enableReasoning } = capabilities
let providerOptions: Record<string, any> = {}
@@ -313,16 +365,12 @@ function buildCherryInProviderOptions(
enableWebSearch: boolean
enableGenerateImage: boolean
},
actualProvider: Provider
): Record<string, any> {
const serviceTierSetting = getServiceTier(model, actualProvider)
actualProvider: Provider,
serviceTier: OpenAIServiceTier
): OpenAIResponsesProviderOptions | AnthropicProviderOptions | GoogleGenerativeAIProviderOptions {
switch (actualProvider.type) {
case 'openai':
return {
...buildOpenAIProviderOptions(assistant, model, capabilities),
serviceTier: serviceTierSetting
}
return buildOpenAIProviderOptions(assistant, model, capabilities, serviceTier)
case 'anthropic':
return buildAnthropicProviderOptions(assistant, model, capabilities)
@@ -344,9 +392,9 @@ function buildBedrockProviderOptions(
enableWebSearch: boolean
enableGenerateImage: boolean
}
): Record<string, any> {
): BedrockProviderOptions {
const { enableReasoning } = capabilities
let providerOptions: Record<string, any> = {}
let providerOptions: BedrockProviderOptions = {}
if (enableReasoning) {
const reasoningParams = getBedrockReasoningParams(assistant, model)

View File

@@ -1,6 +1,7 @@
import type { BedrockProviderOptions } from '@ai-sdk/amazon-bedrock'
import type { AnthropicProviderOptions } from '@ai-sdk/anthropic'
import type { GoogleGenerativeAIProviderOptions } from '@ai-sdk/google'
import type { OpenAIResponsesProviderOptions } from '@ai-sdk/openai'
import type { XaiProviderOptions } from '@ai-sdk/xai'
import { loggerService } from '@logger'
import { DEFAULT_MAX_TOKENS } from '@renderer/config/constant'
@@ -11,6 +12,7 @@ import {
isDeepSeekHybridInferenceModel,
isDoubaoSeedAfter251015,
isDoubaoThinkingAutoModel,
isGemini3Model,
isGPT51SeriesModel,
isGrok4FastReasoningModel,
isGrokReasoningModel,
@@ -32,13 +34,13 @@ import {
isSupportedThinkingTokenZhipuModel,
MODEL_SUPPORTED_REASONING_EFFORT
} from '@renderer/config/models'
import { isSupportEnableThinkingProvider } from '@renderer/config/providers'
import { getStoreSetting } from '@renderer/hooks/useSettings'
import { getAssistantSettings, getProviderByModel } from '@renderer/services/AssistantService'
import type { SettingsState } from '@renderer/store/settings'
import type { Assistant, Model } from '@renderer/types'
import type { Assistant, Model, ReasoningEffortOption } from '@renderer/types'
import { EFFORT_RATIO, isSystemProvider, SystemProviderIds } from '@renderer/types'
import type { OpenAISummaryText } from '@renderer/types/aiCoreTypes'
import type { ReasoningEffortOptionalParams } from '@renderer/types/sdk'
import { isSupportEnableThinkingProvider } from '@renderer/utils/provider'
import { toInteger } from 'lodash'
const logger = loggerService.withContext('reasoning')
@@ -130,7 +132,7 @@ export function getReasoningEffort(assistant: Assistant, model: Model): Reasonin
}
// Special case for GPT-5.1. Assume this is an OpenAI-compatible provider
if (isGPT51SeriesModel(model) && reasoningEffort === 'none') {
if (isGPT51SeriesModel(model)) {
return {
reasoningEffort: 'none'
}
@@ -278,6 +280,12 @@ export function getReasoningEffort(assistant: Assistant, model: Model): Reasonin
// Gemini series, OpenAI-compatible API
if (isSupportedThinkingTokenGeminiModel(model)) {
// https://ai.google.dev/gemini-api/docs/gemini-3?thinking=high#openai_compatibility
if (isGemini3Model(model)) {
return {
reasoning_effort: reasoningEffort
}
}
if (reasoningEffort === 'auto') {
return {
extra_body: {
@@ -341,10 +349,14 @@ export function getReasoningEffort(assistant: Assistant, model: Model): Reasonin
}
/**
* Get OpenAI reasoning parameters
* Logic extracted from OpenAIResponseAPIClient and OpenAIAPIClient
* Get OpenAI reasoning parameters
* Extracted from OpenAIResponseAPIClient and OpenAIAPIClient logic
* For official OpenAI provider only
*/
export function getOpenAIReasoningParams(assistant: Assistant, model: Model): Record<string, any> {
export function getOpenAIReasoningParams(
assistant: Assistant,
model: Model
): Pick<OpenAIResponsesProviderOptions, 'reasoningEffort' | 'reasoningSummary'> {
if (!isReasoningModel(model)) {
return {}
}
@@ -355,6 +367,10 @@ export function getOpenAIReasoningParams(assistant: Assistant, model: Model): Re
return {}
}
if (isOpenAIDeepResearchModel(model) || reasoningEffort === 'auto') {
reasoningEffort = 'medium'
}
// Case: non-OpenAI model, but the provider type is responses/azure openai
if (!isOpenAIModel(model)) {
return {
@@ -362,21 +378,17 @@ export function getOpenAIReasoningParams(assistant: Assistant, model: Model): Re
}
}
const openAI = getStoreSetting('openAI') as SettingsState['openAI']
const summaryText = openAI?.summaryText || 'off'
const openAI = getStoreSetting('openAI')
const summaryText = openAI.summaryText
let reasoningSummary: string | undefined = undefined
let reasoningSummary: OpenAISummaryText = undefined
if (summaryText === 'off' || model.id.includes('o1-pro')) {
if (model.id.includes('o1-pro')) {
reasoningSummary = undefined
} else {
reasoningSummary = summaryText
}
if (isOpenAIDeepResearchModel(model)) {
reasoningEffort = 'medium'
}
// OpenAI reasoning parameters
if (isSupportedReasoningEffortOpenAIModel(model)) {
return {
@@ -388,19 +400,26 @@ export function getOpenAIReasoningParams(assistant: Assistant, model: Model): Re
return {}
}
export function getAnthropicThinkingBudget(assistant: Assistant, model: Model): number {
const { maxTokens, reasoning_effort: reasoningEffort } = getAssistantSettings(assistant)
export function getAnthropicThinkingBudget(
maxTokens: number | undefined,
reasoningEffort: string | undefined,
modelId: string
): number | undefined {
if (reasoningEffort === undefined || reasoningEffort === 'none') {
return 0
return undefined
}
const effortRatio = EFFORT_RATIO[reasoningEffort]
const tokenLimit = findTokenLimit(modelId)
if (!tokenLimit) {
return undefined
}
const budgetTokens = Math.max(
1024,
Math.floor(
Math.min(
(findTokenLimit(model.id)?.max! - findTokenLimit(model.id)?.min!) * effortRatio +
findTokenLimit(model.id)?.min!,
(tokenLimit.max - tokenLimit.min) * effortRatio + tokenLimit.min,
(maxTokens || DEFAULT_MAX_TOKENS) * effortRatio
)
)
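// Worked example, consistent with the budgetTokens: 2048 assertions in the unit
// tests earlier in this compare (assumption: EFFORT_RATIO.medium === 0.5):
//   maxTokens = 4096, effortRatio = 0.5
//   settings cap: 4096 * 0.5 = 2048
//   limit curve:  (tokenLimit.max - tokenLimit.min) * 0.5 + tokenLimit.min   (assumed >= 2048 here)
//   budgetTokens = Math.max(1024, Math.floor(Math.min(curve, 2048))) = 2048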
@@ -432,7 +451,8 @@ export function getAnthropicReasoningParams(
// Claude reasoning parameters
if (isSupportedThinkingTokenClaudeModel(model)) {
const budgetTokens = getAnthropicThinkingBudget(assistant, model)
const { maxTokens } = getAssistantSettings(assistant)
const budgetTokens = getAnthropicThinkingBudget(maxTokens, reasoningEffort, model.id)
return {
thinking: {
@@ -445,6 +465,21 @@ export function getAnthropicReasoningParams(
return {}
}
type GoogleThinkingLevel = NonNullable<GoogleGenerativeAIProviderOptions['thinkingConfig']>['thinkingLevel']
function mapToGeminiThinkingLevel(reasoningEffort: ReasoningEffortOption): GoogleThinkingLevel {
switch (reasoningEffort) {
case 'low':
return 'low'
case 'medium':
return 'medium'
case 'high':
return 'high'
default:
return 'medium'
}
}
/**
* Get Gemini reasoning parameters
* Logic extracted from GeminiAPIClient
@@ -472,6 +507,15 @@ export function getGeminiReasoningParams(
}
}
// https://ai.google.dev/gemini-api/docs/gemini-3?thinking=high#new_api_features_in_gemini_3
if (isGemini3Model(model)) {
return {
thinkingConfig: {
thinkingLevel: mapToGeminiThinkingLevel(reasoningEffort)
}
}
}
const effortRatio = EFFORT_RATIO[reasoningEffort]
if (effortRatio > 1) {
@@ -555,7 +599,8 @@ export function getBedrockReasoningParams(
return {}
}
const budgetTokens = getAnthropicThinkingBudget(assistant, model)
const { maxTokens } = getAssistantSettings(assistant)
const budgetTokens = getAnthropicThinkingBudget(maxTokens, reasoningEffort, model.id)
return {
reasoningConfig: {
type: 'enabled',

View File

@@ -47,6 +47,7 @@ export function buildProviderBuiltinWebSearchConfig(
model?: Model
): WebSearchPluginConfig | undefined {
switch (providerId) {
case 'azure-responses':
case 'openai': {
const searchContextSize = isOpenAIDeepResearchModel(model)
? 'medium'

View File

@@ -1,35 +1,120 @@
import 'emoji-picker-element'
import TwemojiCountryFlagsWoff2 from '@renderer/assets/fonts/country-flag-fonts/TwemojiCountryFlags.woff2?url'
import { useTheme } from '@renderer/context/ThemeProvider'
import type { LanguageVarious } from '@renderer/types'
import { polyfillCountryFlagEmojis } from 'country-flag-emoji-polyfill'
// i18n translations from emoji-picker-element
import de from 'emoji-picker-element/i18n/de'
import en from 'emoji-picker-element/i18n/en'
import es from 'emoji-picker-element/i18n/es'
import fr from 'emoji-picker-element/i18n/fr'
import ja from 'emoji-picker-element/i18n/ja'
import pt_PT from 'emoji-picker-element/i18n/pt_PT'
import ru_RU from 'emoji-picker-element/i18n/ru_RU'
import zh_CN from 'emoji-picker-element/i18n/zh_CN'
import type Picker from 'emoji-picker-element/picker'
import type { EmojiClickEvent, NativeEmoji } from 'emoji-picker-element/shared'
// Emoji data from emoji-picker-element-data (local, no CDN)
// Using CLDR format for full multi-language search support (28 languages)
import dataDE from 'emoji-picker-element-data/de/cldr/data.json?url'
import dataEN from 'emoji-picker-element-data/en/cldr/data.json?url'
import dataES from 'emoji-picker-element-data/es/cldr/data.json?url'
import dataFR from 'emoji-picker-element-data/fr/cldr/data.json?url'
import dataJA from 'emoji-picker-element-data/ja/cldr/data.json?url'
import dataPT from 'emoji-picker-element-data/pt/cldr/data.json?url'
import dataRU from 'emoji-picker-element-data/ru/cldr/data.json?url'
import dataZH from 'emoji-picker-element-data/zh/cldr/data.json?url'
import dataZH_HANT from 'emoji-picker-element-data/zh-hant/cldr/data.json?url'
import type { FC } from 'react'
import { useEffect, useRef } from 'react'
import { useTranslation } from 'react-i18next'
interface Props {
onEmojiClick: (emoji: string) => void
}
// Mapping from app locale to emoji-picker-element i18n
const i18nMap: Record<LanguageVarious, typeof en> = {
'en-US': en,
'zh-CN': zh_CN,
'zh-TW': zh_CN, // Closest available
'de-DE': de,
'el-GR': en, // No Greek available, fallback to English
'es-ES': es,
'fr-FR': fr,
'ja-JP': ja,
'pt-PT': pt_PT,
'ru-RU': ru_RU
}
// Mapping from app locale to emoji data URL
// Using CLDR format provides native language search support for all locales
const dataSourceMap: Record<LanguageVarious, string> = {
'en-US': dataEN,
'zh-CN': dataZH,
'zh-TW': dataZH_HANT,
'de-DE': dataDE,
'el-GR': dataEN, // No Greek CLDR available, fallback to English
'es-ES': dataES,
'fr-FR': dataFR,
'ja-JP': dataJA,
'pt-PT': dataPT,
'ru-RU': dataRU
}
// Mapping from app locale to emoji-picker-element locale string
// Must match the data source locale for proper IndexedDB caching
const localeMap: Record<LanguageVarious, string> = {
'en-US': 'en',
'zh-CN': 'zh',
'zh-TW': 'zh-hant',
'de-DE': 'de',
'el-GR': 'en',
'es-ES': 'es',
'fr-FR': 'fr',
'ja-JP': 'ja',
'pt-PT': 'pt',
'ru-RU': 'ru'
}
const EmojiPicker: FC<Props> = ({ onEmojiClick }) => {
const { theme } = useTheme()
const ref = useRef<HTMLDivElement>(null)
const { i18n } = useTranslation()
const ref = useRef<Picker>(null)
const currentLocale = i18n.language as LanguageVarious
useEffect(() => {
polyfillCountryFlagEmojis('Twemoji Mozilla', TwemojiCountryFlagsWoff2)
}, [])
// Configure picker with i18n and dataSource
useEffect(() => {
const refValue = ref.current
const picker = ref.current
if (picker) {
picker.i18n = i18nMap[currentLocale] || en
picker.dataSource = dataSourceMap[currentLocale] || dataEN
picker.locale = localeMap[currentLocale] || 'en'
}
}, [currentLocale])
if (refValue) {
const handleEmojiClick = (event: any) => {
useEffect(() => {
const picker = ref.current
if (picker) {
const handleEmojiClick = (event: EmojiClickEvent) => {
event.stopPropagation()
onEmojiClick(event.detail.unicode || event.detail.emoji.unicode)
const { detail } = event
// Use detail.unicode (processed with skin tone) or fall back to the emoji's unicode for native emoji
const unicode = detail.unicode || ('unicode' in detail.emoji ? (detail.emoji as NativeEmoji).unicode : '')
onEmojiClick(unicode)
}
// Add the event listener
refValue.addEventListener('emoji-click', handleEmojiClick)
picker.addEventListener('emoji-click', handleEmojiClick)
// Clean up the event listener
return () => {
refValue.removeEventListener('emoji-click', handleEmojiClick)
picker.removeEventListener('emoji-click', handleEmojiClick)
}
}
return

View File

@@ -0,0 +1,141 @@
import { importChatGPTConversations } from '@renderer/services/import'
import { Alert, Modal, Progress, Space, Spin } from 'antd'
import { useState } from 'react'
import { useTranslation } from 'react-i18next'
import { TopView } from '../TopView'
interface PopupResult {
success?: boolean
}
interface Props {
resolve: (data: PopupResult) => void
}
const PopupContainer: React.FC<Props> = ({ resolve }) => {
const [open, setOpen] = useState(true)
const [selecting, setSelecting] = useState(false)
const [importing, setImporting] = useState(false)
const { t } = useTranslation()
const onOk = async () => {
setSelecting(true)
try {
// Select ChatGPT JSON file
const file = await window.api.file.open({
filters: [{ name: 'ChatGPT Conversations', extensions: ['json'] }]
})
setSelecting(false)
if (!file) {
return
}
setImporting(true)
// Parse file content
const fileContent = typeof file.content === 'string' ? file.content : new TextDecoder().decode(file.content)
// Import conversations
const result = await importChatGPTConversations(fileContent)
if (result.success) {
window.toast.success(
t('import.chatgpt.success', {
topics: result.topicsCount,
messages: result.messagesCount
})
)
setOpen(false)
} else {
window.toast.error(result.error || t('import.chatgpt.error.unknown'))
}
} catch (error) {
window.toast.error(t('import.chatgpt.error.unknown'))
setOpen(false)
} finally {
setSelecting(false)
setImporting(false)
}
}
const onCancel = () => {
setOpen(false)
}
const onClose = () => {
resolve({})
}
ImportPopup.hide = onCancel
return (
<Modal
title={t('import.chatgpt.title')}
open={open}
onOk={onOk}
onCancel={onCancel}
afterClose={onClose}
okText={t('import.chatgpt.button')}
okButtonProps={{ disabled: selecting || importing, loading: selecting }}
cancelButtonProps={{ disabled: selecting || importing }}
maskClosable={false}
transitionName="animation-move-down"
centered>
{!selecting && !importing && (
<Space direction="vertical" style={{ width: '100%' }}>
<div>{t('import.chatgpt.description')}</div>
<Alert
message={t('import.chatgpt.help.title')}
description={
<div>
<p>{t('import.chatgpt.help.step1')}</p>
<p>{t('import.chatgpt.help.step2')}</p>
<p>{t('import.chatgpt.help.step3')}</p>
</div>
}
type="info"
showIcon
style={{ marginTop: 12 }}
/>
</Space>
)}
{selecting && (
<div style={{ textAlign: 'center', padding: '40px 0' }}>
<Spin size="large" />
<div style={{ marginTop: 16 }}>{t('import.chatgpt.selecting')}</div>
</div>
)}
{importing && (
<div style={{ textAlign: 'center', padding: '20px 0' }}>
<Progress percent={100} status="active" strokeColor="var(--color-primary)" showInfo={false} />
<div style={{ marginTop: 16 }}>{t('import.chatgpt.importing')}</div>
</div>
)}
</Modal>
)
}
const TopViewKey = 'ImportPopup'
export default class ImportPopup {
static topviewId = 0
static hide() {
TopView.hide(TopViewKey)
}
static show() {
return new Promise<PopupResult>((resolve) => {
TopView.show(
<PopupContainer
resolve={(v) => {
resolve(v)
TopView.hide(TopViewKey)
}}
/>,
TopViewKey
)
})
}
}
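// Hypothetical call site for ImportPopup: open the dialog imperatively via
// TopView and await the outcome. The handler around it is an assumption for
// illustration, not part of this change.
//
//   const result = await ImportPopup.show()
//   if (result.success) {
//     // conversations were imported; refresh the topic list
//   }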

View File

@@ -15,7 +15,7 @@ import type {
UpdateAgentForm
} from '@renderer/types'
import { AgentConfigurationSchema, isAgentType } from '@renderer/types'
import { Button, Input, Modal, Select } from 'antd'
import { Alert, Button, Input, Modal, Select } from 'antd'
import { AlertTriangleIcon } from 'lucide-react'
import type { ChangeEvent, FormEvent } from 'react'
import { useCallback, useEffect, useMemo, useRef, useState } from 'react'
@@ -58,6 +58,7 @@ const PopupContainer: React.FC<Props> = ({ agent, afterSubmit, resolve }) => {
const isEditing = (agent?: AgentWithTools) => agent !== undefined
const [form, setForm] = useState<BaseAgentForm>(() => buildAgentForm(agent))
const [hasGitBash, setHasGitBash] = useState<boolean>(true)
useEffect(() => {
if (open) {
@@ -65,6 +66,30 @@ const PopupContainer: React.FC<Props> = ({ agent, afterSubmit, resolve }) => {
}
}, [agent, open])
const checkGitBash = useCallback(
async (showToast = false) => {
try {
const gitBashInstalled = await window.api.system.checkGitBash()
setHasGitBash(gitBashInstalled)
if (showToast) {
if (gitBashInstalled) {
window.toast.success(t('agent.gitBash.success', 'Git Bash detected successfully!'))
} else {
window.toast.error(t('agent.gitBash.notFound', 'Git Bash not found. Please install it first.'))
}
}
} catch (error) {
logger.error('Failed to check Git Bash:', error as Error)
setHasGitBash(true) // Default to true on error to avoid false warnings
}
},
[t]
)
useEffect(() => {
checkGitBash()
}, [checkGitBash])
const selectedPermissionMode = form.configuration?.permission_mode ?? 'default'
const onPermissionModeChange = useCallback((value: PermissionMode) => {
@@ -275,6 +300,36 @@ const PopupContainer: React.FC<Props> = ({ agent, afterSubmit, resolve }) => {
footer={null}>
<StyledForm onSubmit={onSubmit}>
<FormContent>
{!hasGitBash && (
<Alert
message={t('agent.gitBash.error.title', 'Git Bash Required')}
description={
<div>
<div style={{ marginBottom: 8 }}>
{t(
'agent.gitBash.error.description',
'Git Bash is required to run agents on Windows. The agent cannot function without it. Please install Git for Windows from'
)}{' '}
<a
href="https://git-scm.com/download/win"
onClick={(e) => {
e.preventDefault()
window.api.openWebsite('https://git-scm.com/download/win')
}}
style={{ textDecoration: 'underline' }}>
git-scm.com
</a>
</div>
<Button size="small" onClick={() => checkGitBash(true)}>
{t('agent.gitBash.error.recheck', 'Recheck Git Bash Installation')}
</Button>
</div>
}
type="error"
showIcon
style={{ marginBottom: 16 }}
/>
)}
<FormRow>
<FormItem style={{ flex: 1 }}>
<Label>
@@ -377,7 +432,7 @@ const PopupContainer: React.FC<Props> = ({ agent, afterSubmit, resolve }) => {
<FormFooter>
<Button onClick={onCancel}>{t('common.close')}</Button>
<Button type="primary" htmlType="submit" loading={loadingRef.current}>
<Button type="primary" htmlType="submit" loading={loadingRef.current} disabled={!hasGitBash}>
{isEditing(agent) ? t('common.confirm') : t('common.add')}
</Button>
</FormFooter>

View File

@@ -6,7 +6,7 @@ import { useEffect, useMemo, useRef, useState } from 'react'
import { useTranslation } from 'react-i18next'
import styled, { css } from 'styled-components'
interface SelectorOption<V = string | number> {
interface SelectorOption<V = string | number | undefined | null> {
label: string | ReactNode
value: V
type?: 'group'
@@ -14,7 +14,7 @@ interface SelectorOption<V = string | number> {
disabled?: boolean
}
interface BaseSelectorProps<V = string | number> {
interface BaseSelectorProps<V = string | number | undefined | null> {
options: SelectorOption<V>[]
placeholder?: string
placement?: 'topLeft' | 'topCenter' | 'topRight' | 'bottomLeft' | 'bottomCenter' | 'bottomRight' | 'top' | 'bottom'
@@ -39,7 +39,7 @@ interface MultipleSelectorProps<V> extends BaseSelectorProps<V> {
export type SelectorProps<V> = SingleSelectorProps<V> | MultipleSelectorProps<V>
const Selector = <V extends string | number>({
const Selector = <V extends string | number | undefined | null>({
options,
value,
onChange = () => {},

View File

@@ -1,520 +0,0 @@
import { describe, expect, it, vi } from 'vitest'
import {
isDoubaoSeedAfter251015,
isDoubaoThinkingAutoModel,
isGeminiReasoningModel,
isLingReasoningModel,
isSupportedThinkingTokenGeminiModel
} from '../models/reasoning'
vi.mock('@renderer/store', () => ({
default: {
getState: () => ({
llm: {
settings: {}
}
})
}
}))
// FIXME: Unclear why this is imported. Possibly a circular dependency somewhere
vi.mock('@renderer/services/AssistantService.ts', () => ({
getDefaultAssistant: () => {
return {
id: 'default',
name: 'default',
emoji: '😀',
prompt: '',
topics: [],
messages: [],
type: 'assistant',
regularPhrases: [],
settings: {}
}
}
}))
describe('Doubao Models', () => {
describe('isDoubaoThinkingAutoModel', () => {
it('should return false for invalid models', () => {
expect(
isDoubaoThinkingAutoModel({
id: 'doubao-seed-1-6-251015',
name: 'doubao-seed-1-6-251015',
provider: '',
group: ''
})
).toBe(false)
expect(
isDoubaoThinkingAutoModel({
id: 'doubao-seed-1-6-lite-251015',
name: 'doubao-seed-1-6-lite-251015',
provider: '',
group: ''
})
).toBe(false)
expect(
isDoubaoThinkingAutoModel({
id: 'doubao-seed-1-6-thinking-250715',
name: 'doubao-seed-1-6-thinking-250715',
provider: '',
group: ''
})
).toBe(false)
expect(
isDoubaoThinkingAutoModel({
id: 'doubao-seed-1-6-flash',
name: 'doubao-seed-1-6-flash',
provider: '',
group: ''
})
).toBe(false)
expect(
isDoubaoThinkingAutoModel({
id: 'doubao-seed-1-6-thinking',
name: 'doubao-seed-1-6-thinking',
provider: '',
group: ''
})
).toBe(false)
})
it('should return true for valid models', () => {
expect(
isDoubaoThinkingAutoModel({
id: 'doubao-seed-1-6-250615',
name: 'doubao-seed-1-6-250615',
provider: '',
group: ''
})
).toBe(true)
expect(
isDoubaoThinkingAutoModel({
id: 'Doubao-Seed-1.6',
name: 'Doubao-Seed-1.6',
provider: '',
group: ''
})
).toBe(true)
expect(
isDoubaoThinkingAutoModel({
id: 'doubao-1-5-thinking-pro-m',
name: 'doubao-1-5-thinking-pro-m',
provider: '',
group: ''
})
).toBe(true)
expect(
isDoubaoThinkingAutoModel({
id: 'doubao-seed-1.6-lite',
name: 'doubao-seed-1.6-lite',
provider: '',
group: ''
})
).toBe(true)
expect(
isDoubaoThinkingAutoModel({
id: 'doubao-1-5-thinking-pro-m-12345',
name: 'doubao-1-5-thinking-pro-m-12345',
provider: '',
group: ''
})
).toBe(true)
})
})
describe('isDoubaoSeedAfter251015', () => {
it('should return true for models matching the pattern', () => {
expect(
isDoubaoSeedAfter251015({
id: 'doubao-seed-1-6-251015',
name: '',
provider: '',
group: ''
})
).toBe(true)
expect(
isDoubaoSeedAfter251015({
id: 'doubao-seed-1-6-lite-251015',
name: '',
provider: '',
group: ''
})
).toBe(true)
})
it('should return false for models not matching the pattern', () => {
expect(
isDoubaoSeedAfter251015({
id: 'doubao-seed-1-6-250615',
name: '',
provider: '',
group: ''
})
).toBe(false)
expect(
isDoubaoSeedAfter251015({
id: 'Doubao-Seed-1.6',
name: '',
provider: '',
group: ''
})
).toBe(false)
expect(
isDoubaoSeedAfter251015({
id: 'doubao-1-5-thinking-pro-m',
name: '',
provider: '',
group: ''
})
).toBe(false)
expect(
isDoubaoSeedAfter251015({
id: 'doubao-seed-1-6-lite-251016',
name: '',
provider: '',
group: ''
})
).toBe(false)
})
})
})
describe('Ling Models', () => {
describe('isLingReasoningModel', () => {
it('should return false for ling variants', () => {
expect(
isLingReasoningModel({
id: 'ling-1t',
name: '',
provider: '',
group: ''
})
).toBe(false)
expect(
isLingReasoningModel({
id: 'ling-flash-2.0',
name: '',
provider: '',
group: ''
})
).toBe(false)
expect(
isLingReasoningModel({
id: 'ling-mini-2.0',
name: '',
provider: '',
group: ''
})
).toBe(false)
})
it('should return true for ring variants', () => {
expect(
isLingReasoningModel({
id: 'ring-1t',
name: '',
provider: '',
group: ''
})
).toBe(true)
expect(
isLingReasoningModel({
id: 'ring-flash-2.0',
name: '',
provider: '',
group: ''
})
).toBe(true)
expect(
isLingReasoningModel({
id: 'ring-mini-2.0',
name: '',
provider: '',
group: ''
})
).toBe(true)
})
})
})
describe('Gemini Models', () => {
describe('isSupportedThinkingTokenGeminiModel', () => {
it('should return true for gemini 2.5 models', () => {
expect(
isSupportedThinkingTokenGeminiModel({
id: 'gemini-2.5-flash',
name: '',
provider: '',
group: ''
})
).toBe(true)
expect(
isSupportedThinkingTokenGeminiModel({
id: 'gemini-2.5-pro',
name: '',
provider: '',
group: ''
})
).toBe(true)
expect(
isSupportedThinkingTokenGeminiModel({
id: 'gemini-2.5-flash-latest',
name: '',
provider: '',
group: ''
})
).toBe(true)
expect(
isSupportedThinkingTokenGeminiModel({
id: 'gemini-2.5-pro-latest',
name: '',
provider: '',
group: ''
})
).toBe(true)
})
it('should return true for gemini latest models', () => {
expect(
isSupportedThinkingTokenGeminiModel({
id: 'gemini-flash-latest',
name: '',
provider: '',
group: ''
})
).toBe(true)
expect(
isSupportedThinkingTokenGeminiModel({
id: 'gemini-pro-latest',
name: '',
provider: '',
group: ''
})
).toBe(true)
expect(
isSupportedThinkingTokenGeminiModel({
id: 'gemini-flash-lite-latest',
name: '',
provider: '',
group: ''
})
).toBe(true)
})
it('should return true for gemini 3 models', () => {
// Preview versions
expect(
isSupportedThinkingTokenGeminiModel({
id: 'gemini-3-pro-preview',
name: '',
provider: '',
group: ''
})
).toBe(true)
expect(
isSupportedThinkingTokenGeminiModel({
id: 'google/gemini-3-pro-preview',
name: '',
provider: '',
group: ''
})
).toBe(true)
// Future stable versions
expect(
isSupportedThinkingTokenGeminiModel({
id: 'gemini-3-flash',
name: '',
provider: '',
group: ''
})
).toBe(true)
expect(
isSupportedThinkingTokenGeminiModel({
id: 'gemini-3-pro',
name: '',
provider: '',
group: ''
})
).toBe(true)
expect(
isSupportedThinkingTokenGeminiModel({
id: 'google/gemini-3-flash',
name: '',
provider: '',
group: ''
})
).toBe(true)
expect(
isSupportedThinkingTokenGeminiModel({
id: 'google/gemini-3-pro',
name: '',
provider: '',
group: ''
})
).toBe(true)
})
it('should return false for image and tts models', () => {
expect(
isSupportedThinkingTokenGeminiModel({
id: 'gemini-2.5-flash-image',
name: '',
provider: '',
group: ''
})
).toBe(false)
expect(
isSupportedThinkingTokenGeminiModel({
id: 'gemini-2.5-flash-preview-tts',
name: '',
provider: '',
group: ''
})
).toBe(false)
})
it('should return false for older gemini models', () => {
expect(
isSupportedThinkingTokenGeminiModel({
id: 'gemini-1.5-flash',
name: '',
provider: '',
group: ''
})
).toBe(false)
expect(
isSupportedThinkingTokenGeminiModel({
id: 'gemini-1.5-pro',
name: '',
provider: '',
group: ''
})
).toBe(false)
expect(
isSupportedThinkingTokenGeminiModel({
id: 'gemini-1.0-pro',
name: '',
provider: '',
group: ''
})
).toBe(false)
})
})
describe('isGeminiReasoningModel', () => {
it('should return true for gemini thinking models', () => {
expect(
isGeminiReasoningModel({
id: 'gemini-2.0-flash-thinking',
name: '',
provider: '',
group: ''
})
).toBe(true)
expect(
isGeminiReasoningModel({
id: 'gemini-thinking-exp',
name: '',
provider: '',
group: ''
})
).toBe(true)
})
it('should return true for supported thinking token gemini models', () => {
expect(
isGeminiReasoningModel({
id: 'gemini-2.5-flash',
name: '',
provider: '',
group: ''
})
).toBe(true)
expect(
isGeminiReasoningModel({
id: 'gemini-2.5-pro',
name: '',
provider: '',
group: ''
})
).toBe(true)
})
it('should return true for gemini-3 models', () => {
// Preview versions
expect(
isGeminiReasoningModel({
id: 'gemini-3-pro-preview',
name: '',
provider: '',
group: ''
})
).toBe(true)
expect(
isGeminiReasoningModel({
id: 'google/gemini-3-pro-preview',
name: '',
provider: '',
group: ''
})
).toBe(true)
// Future stable versions
expect(
isGeminiReasoningModel({
id: 'gemini-3-flash',
name: '',
provider: '',
group: ''
})
).toBe(true)
expect(
isGeminiReasoningModel({
id: 'gemini-3-pro',
name: '',
provider: '',
group: ''
})
).toBe(true)
expect(
isGeminiReasoningModel({
id: 'google/gemini-3-flash',
name: '',
provider: '',
group: ''
})
).toBe(true)
expect(
isGeminiReasoningModel({
id: 'google/gemini-3-pro',
name: '',
provider: '',
group: ''
})
).toBe(true)
})
it('should return false for older gemini models without thinking', () => {
expect(
isGeminiReasoningModel({
id: 'gemini-1.5-flash',
name: '',
provider: '',
group: ''
})
).toBe(false)
expect(
isGeminiReasoningModel({
id: 'gemini-1.5-pro',
name: '',
provider: '',
group: ''
})
).toBe(false)
})
it('should return false for undefined model', () => {
expect(isGeminiReasoningModel(undefined)).toBe(false)
})
})
})
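Taken together, the expectations above pin these matchers down quite tightly. The snippet below is a hypothetical reconstruction from the test cases alone; the shipped implementations in the reasoning config may be written more broadly:
// Hypothetical sketches inferred from the suite above, not the shipped code.
// Despite the "after" in its name, the suite rejects -251016 and accepts only
// the exact -251015 snapshot, so an exact-date match reproduces every case.
const DOUBAO_SEED_251015_SKETCH = /^doubao-seed-1-6(?:-lite)?-251015$/i
const isDoubaoSeedAfter251015Sketch = (model: Model): boolean =>
  DOUBAO_SEED_251015_SKETCH.test(model.id)
// Ling reasoning support keys on the "ring" prefix; plain "ling" ids are chat-only.
const isLingReasoningModelSketch = (model: Model): boolean => /^ring-/i.test(model.id)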

View File

@@ -1,167 +0,0 @@
import { describe, expect, it, vi } from 'vitest'
import { isVisionModel } from '../models/vision'
vi.mock('@renderer/store', () => ({
default: {
getState: () => ({
llm: {
settings: {}
}
})
}
}))
// FIXME: Unclear why this is imported. Possibly a circular dependency somewhere
vi.mock('@renderer/services/AssistantService.ts', () => ({
getDefaultAssistant: () => {
return {
id: 'default',
name: 'default',
emoji: '😀',
prompt: '',
topics: [],
messages: [],
type: 'assistant',
regularPhrases: [],
settings: {}
}
},
getProviderByModel: () => null
}))
describe('isVisionModel', () => {
describe('Gemini Models', () => {
it('should return true for gemini 1.5 models', () => {
expect(
isVisionModel({
id: 'gemini-1.5-flash',
name: '',
provider: '',
group: ''
})
).toBe(true)
expect(
isVisionModel({
id: 'gemini-1.5-pro',
name: '',
provider: '',
group: ''
})
).toBe(true)
})
it('should return true for gemini 2.x models', () => {
expect(
isVisionModel({
id: 'gemini-2.0-flash',
name: '',
provider: '',
group: ''
})
).toBe(true)
expect(
isVisionModel({
id: 'gemini-2.0-pro',
name: '',
provider: '',
group: ''
})
).toBe(true)
expect(
isVisionModel({
id: 'gemini-2.5-flash',
name: '',
provider: '',
group: ''
})
).toBe(true)
expect(
isVisionModel({
id: 'gemini-2.5-pro',
name: '',
provider: '',
group: ''
})
).toBe(true)
})
it('should return true for gemini latest models', () => {
expect(
isVisionModel({
id: 'gemini-flash-latest',
name: '',
provider: '',
group: ''
})
).toBe(true)
expect(
isVisionModel({
id: 'gemini-pro-latest',
name: '',
provider: '',
group: ''
})
).toBe(true)
expect(
isVisionModel({
id: 'gemini-flash-lite-latest',
name: '',
provider: '',
group: ''
})
).toBe(true)
})
it('should return true for gemini 3 models', () => {
// Preview versions
expect(
isVisionModel({
id: 'gemini-3-pro-preview',
name: '',
provider: '',
group: ''
})
).toBe(true)
// Future stable versions
expect(
isVisionModel({
id: 'gemini-3-flash',
name: '',
provider: '',
group: ''
})
).toBe(true)
expect(
isVisionModel({
id: 'gemini-3-pro',
name: '',
provider: '',
group: ''
})
).toBe(true)
})
it('should return true for gemini exp models', () => {
expect(
isVisionModel({
id: 'gemini-exp-1206',
name: '',
provider: '',
group: ''
})
).toBe(true)
})
it('should return false for gemini 1.0 models', () => {
expect(
isVisionModel({
id: 'gemini-1.0-pro',
name: '',
provider: '',
group: ''
})
).toBe(false)
})
})
})

View File

@@ -1,64 +0,0 @@
import { describe, expect, it, vi } from 'vitest'
import { GEMINI_SEARCH_REGEX } from '../models/websearch'
vi.mock('@renderer/store', () => ({
default: {
getState: () => ({
llm: {
settings: {}
}
})
}
}))
// FIXME: Unclear why this is imported. Possibly a circular dependency somewhere
vi.mock('@renderer/services/AssistantService.ts', () => ({
getDefaultAssistant: () => {
return {
id: 'default',
name: 'default',
emoji: '😀',
prompt: '',
topics: [],
messages: [],
type: 'assistant',
regularPhrases: [],
settings: {}
}
},
getProviderByModel: () => null
}))
describe('Gemini Search Models', () => {
describe('GEMINI_SEARCH_REGEX', () => {
it('should match gemini 2.x models', () => {
expect(GEMINI_SEARCH_REGEX.test('gemini-2.0-flash')).toBe(true)
expect(GEMINI_SEARCH_REGEX.test('gemini-2.0-pro')).toBe(true)
expect(GEMINI_SEARCH_REGEX.test('gemini-2.5-flash')).toBe(true)
expect(GEMINI_SEARCH_REGEX.test('gemini-2.5-pro')).toBe(true)
expect(GEMINI_SEARCH_REGEX.test('gemini-2.5-flash-latest')).toBe(true)
expect(GEMINI_SEARCH_REGEX.test('gemini-2.5-pro-latest')).toBe(true)
})
it('should match gemini latest models', () => {
expect(GEMINI_SEARCH_REGEX.test('gemini-flash-latest')).toBe(true)
expect(GEMINI_SEARCH_REGEX.test('gemini-pro-latest')).toBe(true)
expect(GEMINI_SEARCH_REGEX.test('gemini-flash-lite-latest')).toBe(true)
})
it('should match gemini 3 models', () => {
// Preview versions
expect(GEMINI_SEARCH_REGEX.test('gemini-3-pro-preview')).toBe(true)
// Future stable versions
expect(GEMINI_SEARCH_REGEX.test('gemini-3-flash')).toBe(true)
expect(GEMINI_SEARCH_REGEX.test('gemini-3-pro')).toBe(true)
})
it('should not match older gemini models', () => {
expect(GEMINI_SEARCH_REGEX.test('gemini-1.5-flash')).toBe(false)
expect(GEMINI_SEARCH_REGEX.test('gemini-1.5-pro')).toBe(false)
expect(GEMINI_SEARCH_REGEX.test('gemini-1.0-pro')).toBe(false)
})
})
})

View File

@@ -0,0 +1,101 @@
import type { Model } from '@renderer/types'
import { describe, expect, it, vi } from 'vitest'
vi.mock('@renderer/hooks/useStore', () => ({
getStoreProviders: vi.fn(() => [])
}))
vi.mock('@renderer/store', () => ({
__esModule: true,
default: {
getState: () => ({
llm: { providers: [] },
settings: {}
})
},
useAppDispatch: vi.fn(),
useAppSelector: vi.fn()
}))
vi.mock('@renderer/store/settings', () => {
const noop = vi.fn()
return new Proxy(
{},
{
get: (_target, prop) => {
if (prop === 'initialState') {
return {}
}
return noop
}
}
)
})
vi.mock('@renderer/hooks/useSettings', () => ({
useSettings: vi.fn(() => ({})),
useNavbarPosition: vi.fn(() => ({ navbarPosition: 'left' })),
useMessageStyle: vi.fn(() => ({ isBubbleStyle: false })),
getStoreSetting: vi.fn()
}))
import { isEmbeddingModel, isRerankModel } from '../embedding'
const createModel = (overrides: Partial<Model> = {}): Model => ({
id: 'test-model',
name: 'Test Model',
provider: 'openai',
group: 'Test',
...overrides
})
describe('isEmbeddingModel', () => {
it('returns true for ids that match the embedding regex', () => {
expect(isEmbeddingModel(createModel({ id: 'Text-Embedding-3-Small' }))).toBe(true)
})
it('returns false for rerank models even if they match embedding patterns', () => {
const model = createModel({ id: 'rerank-qa', name: 'rerank-qa' })
expect(isRerankModel(model)).toBe(true)
expect(isEmbeddingModel(model)).toBe(false)
})
it('honors user overrides for embedding capability', () => {
const model = createModel({
id: 'text-embedding-3-small',
capabilities: [{ type: 'embedding', isUserSelected: false }]
})
expect(isEmbeddingModel(model)).toBe(false)
})
it('uses the model name when provider is doubao', () => {
const model = createModel({
id: 'custom-id',
name: 'BGE-Large-zh-v1.5',
provider: 'doubao'
})
expect(isEmbeddingModel(model)).toBe(true)
})
it('returns false for anthropic provider models', () => {
const model = createModel({
id: 'text-embedding-ada-002',
provider: 'anthropic'
})
expect(isEmbeddingModel(model)).toBe(false)
})
})
describe('isRerankModel', () => {
it('identifies ids that match rerank regex', () => {
expect(isRerankModel(createModel({ id: 'jina-rerank-v2-base' }))).toBe(true)
})
it('honors user overrides for rerank capability', () => {
const model = createModel({
id: 'jina-rerank-v2-base',
capabilities: [{ type: 'rerank', isUserSelected: false }]
})
expect(isRerankModel(model)).toBe(false)
})
})
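The helpers exercised here compose in a definite order: a rerank classification beats an embedding match, an explicit user capability override is final, and doubao models are matched on name rather than id. A condensed sketch consistent with all five embedding cases (the real regexes in ../embedding are far wider than the placeholder pattern below):
const isEmbeddingModelSketch = (model: Model): boolean => {
  if (model.provider === 'anthropic') return false // anthropic ships no embedding models
  if (isRerankModel(model)) return false // rerank classification wins
  const override = model.capabilities?.find((c) => c.type === 'embedding')
  if (override) return !!override.isUserSelected // user override is final
  const haystack = model.provider === 'doubao' ? model.name : model.id
  return /embedding|bge-/i.test(haystack) // placeholder pattern, assumed
}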

View File

@@ -1,33 +1,55 @@
import {
isImageEnhancementModel,
isPureGenerateImageModel,
isQwenReasoningModel,
isSupportedThinkingTokenQwenModel,
isVisionModel,
isWebSearchModel
isVisionModel
} from '@renderer/config/models'
import type { Model } from '@renderer/types'
import { beforeEach, describe, expect, test, vi } from 'vitest'
vi.mock('@renderer/store/llm', () => ({
initialState: {}
}))
vi.mock('@renderer/store', () => ({
default: {
getState: () => ({
llm: {
settings: {}
}
})
}
}))
const getProviderByModelMock = vi.fn()
const isEmbeddingModelMock = vi.fn()
const isRerankModelMock = vi.fn()
vi.mock('@renderer/services/AssistantService', () => ({
getProviderByModel: (...args: any[]) => getProviderByModelMock(...args),
getAssistantSettings: vi.fn(),
getDefaultAssistant: vi.fn().mockReturnValue({
id: 'default',
name: 'Default Assistant',
prompt: '',
settings: {}
})
}))
vi.mock('@renderer/config/models/embedding', () => ({
isEmbeddingModel: (...args: any[]) => isEmbeddingModelMock(...args),
isRerankModel: (...args: any[]) => isRerankModelMock(...args)
}))
beforeEach(() => {
vi.clearAllMocks()
getProviderByModelMock.mockReturnValue({ type: 'openai-response' } as any)
isEmbeddingModelMock.mockReturnValue(false)
isRerankModelMock.mockReturnValue(false)
})
// Suggested test cases
describe('Qwen Model Detection', () => {
beforeEach(() => {
vi.mock('@renderer/store/llm', () => ({
initialState: {}
}))
vi.mock('@renderer/services/AssistantService', () => ({
getProviderByModel: vi.fn().mockReturnValue({ id: 'cherryai' })
}))
vi.mock('@renderer/store', () => ({
default: {
getState: () => ({
llm: {
settings: {}
}
})
}
}))
})
test('isQwenReasoningModel', () => {
expect(isQwenReasoningModel({ id: 'qwen3-thinking' } as Model)).toBe(true)
expect(isQwenReasoningModel({ id: 'qwen3-instruct' } as Model)).toBe(false)
@@ -56,14 +78,6 @@ describe('Qwen Model Detection', () => {
})
describe('Vision Model Detection', () => {
beforeEach(() => {
vi.mock('@renderer/store/llm', () => ({
initialState: {}
}))
vi.mock('@renderer/services/AssistantService', () => ({
getProviderByModel: vi.fn().mockReturnValue({ id: 'cherryai' })
}))
})
test('isVisionModel', () => {
expect(isVisionModel({ id: 'qwen-vl-max' } as Model)).toBe(true)
expect(isVisionModel({ id: 'qwen-omni-turbo' } as Model)).toBe(true)
@@ -75,25 +89,4 @@ describe('Vision Model Detection', () => {
expect(isImageEnhancementModel({ id: 'qwen-image-edit' } as Model)).toBe(true)
expect(isImageEnhancementModel({ id: 'grok-2-image-latest' } as Model)).toBe(true)
})
test('isPureGenerateImageModel', () => {
expect(isPureGenerateImageModel({ id: 'gpt-image-1' } as Model)).toBe(true)
expect(isPureGenerateImageModel({ id: 'gemini-2.5-flash-image-preview' } as Model)).toBe(true)
expect(isPureGenerateImageModel({ id: 'gemini-2.0-flash-preview-image-generation' } as Model)).toBe(true)
expect(isPureGenerateImageModel({ id: 'grok-2-image-latest' } as Model)).toBe(true)
expect(isPureGenerateImageModel({ id: 'gpt-4o' } as Model)).toBe(false)
})
})
describe('Web Search Model Detection', () => {
beforeEach(() => {
vi.mock('@renderer/store/llm', () => ({
initialState: {}
}))
vi.mock('@renderer/services/AssistantService', () => ({
getProviderByModel: vi.fn().mockReturnValue({ id: 'cherryai' })
}))
})
test('isWebSearchModel', () => {
expect(isWebSearchModel({ id: 'grok-2-image-latest' } as Model)).toBe(false)
})
})
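This hunk works because vitest hoists every vi.mock call to the top of the module, so the per-describe registrations inside beforeEach were effectively no-ops; they are replaced by module-scope factories that delegate to mock functions, plus one shared beforeEach that resets them. An equivalent formulation, used by the websearch suite later in this diff, wraps the mock in vi.hoisted so the hoisting is explicit:
const isEmbeddingModelHoisted = vi.hoisted(() => vi.fn())
vi.mock('../embedding', () => ({
  // The factory is evaluated lazily on first import, so it can safely
  // close over the hoisted mock fn declared above.
  isEmbeddingModel: (...args: any[]) => isEmbeddingModelHoisted(...args)
}))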

File diff suppressed because it is too large

View File

@@ -0,0 +1,137 @@
import type { Model } from '@renderer/types'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import { isEmbeddingModel, isRerankModel } from '../embedding'
import { isDeepSeekHybridInferenceModel } from '../reasoning'
import { isFunctionCallingModel } from '../tooluse'
import { isPureGenerateImageModel, isTextToImageModel } from '../vision'
vi.mock('@renderer/hooks/useStore', () => ({
getStoreProviders: vi.fn(() => [])
}))
vi.mock('@renderer/store', () => ({
__esModule: true,
default: {
getState: () => ({
llm: { providers: [] },
settings: {}
})
},
useAppDispatch: vi.fn(),
useAppSelector: vi.fn()
}))
vi.mock('@renderer/store/settings', () => {
const noop = vi.fn()
return new Proxy(
{},
{
get: (_target, prop) => {
if (prop === 'initialState') {
return {}
}
return noop
}
}
)
})
vi.mock('@renderer/hooks/useSettings', () => ({
useSettings: vi.fn(() => ({})),
useNavbarPosition: vi.fn(() => ({ navbarPosition: 'left' })),
useMessageStyle: vi.fn(() => ({ isBubbleStyle: false })),
getStoreSetting: vi.fn()
}))
vi.mock('../embedding', () => ({
isEmbeddingModel: vi.fn(),
isRerankModel: vi.fn()
}))
vi.mock('../vision', () => ({
isPureGenerateImageModel: vi.fn(),
isTextToImageModel: vi.fn()
}))
vi.mock('../reasoning', () => ({
isDeepSeekHybridInferenceModel: vi.fn()
}))
const createModel = (overrides: Partial<Model> = {}): Model => ({
id: 'gpt-4o',
name: 'gpt-4o',
provider: 'openai',
group: 'OpenAI',
...overrides
})
const embeddingMock = vi.mocked(isEmbeddingModel)
const rerankMock = vi.mocked(isRerankModel)
const pureImageMock = vi.mocked(isPureGenerateImageModel)
const textToImageMock = vi.mocked(isTextToImageModel)
const deepSeekHybridMock = vi.mocked(isDeepSeekHybridInferenceModel)
describe('isFunctionCallingModel', () => {
beforeEach(() => {
vi.clearAllMocks()
embeddingMock.mockReturnValue(false)
rerankMock.mockReturnValue(false)
pureImageMock.mockReturnValue(false)
textToImageMock.mockReturnValue(false)
deepSeekHybridMock.mockReturnValue(false)
})
it('returns false when the model is undefined', () => {
expect(isFunctionCallingModel(undefined as unknown as Model)).toBe(false)
})
it('returns false when model is classified as embedding/rerank/image', () => {
embeddingMock.mockReturnValueOnce(true)
expect(isFunctionCallingModel(createModel())).toBe(false)
})
it('respects manual user overrides', () => {
const model = createModel({
capabilities: [{ type: 'function_calling', isUserSelected: false }]
})
expect(isFunctionCallingModel(model)).toBe(false)
const enabled = createModel({
capabilities: [{ type: 'function_calling', isUserSelected: true }]
})
expect(isFunctionCallingModel(enabled)).toBe(true)
})
it('matches doubao models by name when regex applies', () => {
const doubao = createModel({
id: 'custom-model',
name: 'Doubao-Seed-1.6-251015',
provider: 'doubao'
})
expect(isFunctionCallingModel(doubao)).toBe(true)
})
it('returns true for regex matches on standard providers', () => {
expect(isFunctionCallingModel(createModel({ id: 'gpt-5' }))).toBe(true)
})
it('excludes explicitly blocked ids', () => {
expect(isFunctionCallingModel(createModel({ id: 'gemini-1.5-flash' }))).toBe(false)
})
it('forces support for trusted providers', () => {
for (const provider of ['deepseek', 'anthropic', 'kimi', 'moonshot']) {
expect(isFunctionCallingModel(createModel({ provider }))).toBe(true)
}
})
it('returns true when identified as deepseek hybrid inference model', () => {
deepSeekHybridMock.mockReturnValueOnce(true)
expect(isFunctionCallingModel(createModel({ id: 'deepseek-v3-1', provider: 'custom' }))).toBe(true)
})
it('returns false for deepseek hybrid models behind restricted system providers', () => {
deepSeekHybridMock.mockReturnValueOnce(true)
expect(isFunctionCallingModel(createModel({ id: 'deepseek-v3-1', provider: 'dashscope' }))).toBe(false)
})
})
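Condensing the branches this suite pins down (the relative order of the trusted-provider shortcut and the regex path is not observable from these cases, so read this as a sketch rather than the shipped control flow — the real function appears in the tooluse.ts hunk later in this diff):
const isFunctionCallingModelSketch = (model?: Model): boolean => {
  if (!model || isEmbeddingModel(model) || isRerankModel(model) || isTextToImageModel(model)) return false
  const override = model.capabilities?.find((c) => c.type === 'function_calling')
  if (override) return !!override.isUserSelected // manual override is final
  if (['deepseek', 'anthropic', 'kimi', 'moonshot'].includes(model.provider)) return true
  // ...doubao name matching, the DeepSeek hybrid check (gated off for restricted
  // system providers such as dashscope), and the id regex with its exclusion list
  return false // placeholder for the remaining branches
}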

View File

@@ -0,0 +1,280 @@
import { isEmbeddingModel, isRerankModel } from '@renderer/config/models/embedding'
import type { Model } from '@renderer/types'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import {
isGPT5ProModel,
isGPT5SeriesModel,
isGPT5SeriesReasoningModel,
isGPT51SeriesModel,
isOpenAIChatCompletionOnlyModel,
isOpenAILLMModel,
isOpenAIModel,
isOpenAIOpenWeightModel,
isOpenAIReasoningModel,
isSupportVerbosityModel
} from '../openai'
import { isQwenMTModel } from '../qwen'
import {
agentModelFilter,
getModelSupportedVerbosity,
groupQwenModels,
isAnthropicModel,
isGeminiModel,
isGemmaModel,
isGenerateImageModels,
isMaxTemperatureOneModel,
isNotSupportedTextDelta,
isNotSupportSystemMessageModel,
isNotSupportTemperatureAndTopP,
isSupportedFlexServiceTier,
isSupportedModel,
isSupportFlexServiceTierModel,
isVisionModels,
isZhipuModel
} from '../utils'
import { isGenerateImageModel, isTextToImageModel, isVisionModel } from '../vision'
import { isOpenAIWebSearchChatCompletionOnlyModel } from '../websearch'
vi.mock('@renderer/hooks/useStore', () => ({
getStoreProviders: vi.fn(() => [])
}))
vi.mock('@renderer/store', () => ({
__esModule: true,
default: {
getState: () => ({
llm: { providers: [] },
settings: {}
})
},
useAppDispatch: vi.fn(),
useAppSelector: vi.fn()
}))
vi.mock('@renderer/store/settings', () => {
const noop = vi.fn()
return new Proxy(
{},
{
get: (_target, prop) => {
if (prop === 'initialState') {
return {}
}
return noop
}
}
)
})
vi.mock('@renderer/hooks/useSettings', () => ({
useSettings: vi.fn(() => ({})),
useNavbarPosition: vi.fn(() => ({ navbarPosition: 'left' })),
useMessageStyle: vi.fn(() => ({ isBubbleStyle: false })),
getStoreSetting: vi.fn()
}))
vi.mock('@renderer/config/models/embedding', () => ({
isEmbeddingModel: vi.fn(),
isRerankModel: vi.fn()
}))
vi.mock('../vision', () => ({
isGenerateImageModel: vi.fn(),
isTextToImageModel: vi.fn(),
isVisionModel: vi.fn()
}))
vi.mock(import('../openai'), async (importOriginal) => {
const actual = await importOriginal()
return {
...actual,
isOpenAIReasoningModel: vi.fn()
}
})
vi.mock('../websearch', () => ({
isOpenAIWebSearchChatCompletionOnlyModel: vi.fn()
}))
const createModel = (overrides: Partial<Model> = {}): Model => ({
id: 'gpt-4o',
name: 'gpt-4o',
provider: 'openai',
group: 'OpenAI',
...overrides
})
const embeddingMock = vi.mocked(isEmbeddingModel)
const rerankMock = vi.mocked(isRerankModel)
const visionMock = vi.mocked(isVisionModel)
const textToImageMock = vi.mocked(isTextToImageModel)
const generateImageMock = vi.mocked(isGenerateImageModel)
const reasoningMock = vi.mocked(isOpenAIReasoningModel)
const openAIWebSearchOnlyMock = vi.mocked(isOpenAIWebSearchChatCompletionOnlyModel)
describe('model utils', () => {
beforeEach(() => {
vi.clearAllMocks()
embeddingMock.mockReturnValue(false)
rerankMock.mockReturnValue(false)
visionMock.mockReturnValue(true)
textToImageMock.mockReturnValue(false)
generateImageMock.mockReturnValue(true)
reasoningMock.mockReturnValue(false)
openAIWebSearchOnlyMock.mockReturnValue(false)
})
it('detects OpenAI LLM models through reasoning and GPT prefix', () => {
expect(isOpenAILLMModel(undefined as unknown as Model)).toBe(false)
expect(isOpenAILLMModel(createModel({ id: 'gpt-4o-image' }))).toBe(false)
reasoningMock.mockReturnValueOnce(true)
expect(isOpenAILLMModel(createModel({ id: 'o1-preview' }))).toBe(true)
expect(isOpenAILLMModel(createModel({ id: 'GPT-5-turbo' }))).toBe(true)
})
it('detects OpenAI models via GPT prefix or reasoning support', () => {
expect(isOpenAIModel(createModel({ id: 'gpt-4.1' }))).toBe(true)
reasoningMock.mockReturnValueOnce(true)
expect(isOpenAIModel(createModel({ id: 'o3' }))).toBe(true)
})
it('evaluates support for flex service tier and alias helper', () => {
expect(isSupportFlexServiceTierModel(createModel({ id: 'o3' }))).toBe(true)
expect(isSupportFlexServiceTierModel(createModel({ id: 'o3-mini' }))).toBe(false)
expect(isSupportFlexServiceTierModel(createModel({ id: 'o4-mini' }))).toBe(true)
expect(isSupportFlexServiceTierModel(createModel({ id: 'gpt-5-preview' }))).toBe(true)
expect(isSupportedFlexServiceTier(createModel({ id: 'gpt-4o' }))).toBe(false)
})
it('detects verbosity support for GPT-5+ families', () => {
expect(isSupportVerbosityModel(createModel({ id: 'gpt-5' }))).toBe(true)
expect(isSupportVerbosityModel(createModel({ id: 'gpt-5-chat' }))).toBe(false)
expect(isSupportVerbosityModel(createModel({ id: 'gpt-5.1-preview' }))).toBe(true)
})
it('limits verbosity controls for GPT-5 Pro models', () => {
const proModel = createModel({ id: 'gpt-5-pro' })
const previewModel = createModel({ id: 'gpt-5-preview' })
expect(getModelSupportedVerbosity(proModel)).toEqual([undefined, 'high'])
expect(getModelSupportedVerbosity(previewModel)).toEqual([undefined, 'low', 'medium', 'high'])
expect(isGPT5ProModel(proModel)).toBe(true)
expect(isGPT5ProModel(previewModel)).toBe(false)
})
it('identifies OpenAI chat-completion-only models', () => {
expect(isOpenAIChatCompletionOnlyModel(createModel({ id: 'gpt-4o-search-preview' }))).toBe(true)
expect(isOpenAIChatCompletionOnlyModel(createModel({ id: 'o1-mini' }))).toBe(true)
expect(isOpenAIChatCompletionOnlyModel(createModel({ id: 'gpt-4o' }))).toBe(false)
})
it('filters unsupported OpenAI catalog entries', () => {
expect(isSupportedModel({ id: 'gpt-4', object: 'model' } as any)).toBe(true)
expect(isSupportedModel({ id: 'tts-1', object: 'model' } as any)).toBe(false)
})
it('calculates temperature/top-p support correctly', () => {
const model = createModel({ id: 'o1' })
reasoningMock.mockReturnValue(true)
expect(isNotSupportTemperatureAndTopP(model)).toBe(true)
const openWeight = createModel({ id: 'gpt-oss-debug' })
expect(isNotSupportTemperatureAndTopP(openWeight)).toBe(false)
const chatOnly = createModel({ id: 'o1-preview' })
reasoningMock.mockReturnValue(false)
expect(isNotSupportTemperatureAndTopP(chatOnly)).toBe(true)
const qwenMt = createModel({ id: 'qwen-mt-large', provider: 'aliyun' })
expect(isNotSupportTemperatureAndTopP(qwenMt)).toBe(true)
})
it('handles gemma and gemini detections plus zhipu tagging', () => {
expect(isGemmaModel(createModel({ id: 'Gemma-3-27B' }))).toBe(true)
expect(isGemmaModel(createModel({ group: 'Gemma' }))).toBe(true)
expect(isGemmaModel(createModel({ id: 'gpt-4o' }))).toBe(false)
expect(isGeminiModel(createModel({ id: 'Gemini-2.0' }))).toBe(true)
expect(isZhipuModel(createModel({ provider: 'zhipu' }))).toBe(true)
expect(isZhipuModel(createModel({ provider: 'openai' }))).toBe(false)
})
it('groups qwen models by prefix', () => {
const qwen = createModel({ id: 'Qwen-7B', provider: 'qwen', name: 'Qwen-7B' })
const qwenOmni = createModel({ id: 'qwen2.5-omni', name: 'qwen2.5-omni' })
const other = createModel({ id: 'deepseek-v3', group: 'DeepSeek' })
const grouped = groupQwenModels([qwen, qwenOmni, other])
expect(Object.keys(grouped)).toContain('qwen-7b')
expect(Object.keys(grouped)).toContain('qwen2.5')
expect(grouped.DeepSeek).toContain(other)
})
it('aggregates boolean helpers based on regex rules', () => {
expect(isAnthropicModel(createModel({ id: 'claude-3.5' }))).toBe(true)
expect(isQwenMTModel(createModel({ id: 'qwen-mt-large' }))).toBe(true)
expect(isNotSupportedTextDelta(createModel({ id: 'qwen-mt-large' }))).toBe(true)
expect(isNotSupportSystemMessageModel(createModel({ id: 'gemma-moe' }))).toBe(true)
expect(isOpenAIOpenWeightModel(createModel({ id: 'gpt-oss-free' }))).toBe(true)
})
it('evaluates GPT-5 family helpers', () => {
expect(isGPT5SeriesModel(createModel({ id: 'gpt-5-preview' }))).toBe(true)
expect(isGPT5SeriesModel(createModel({ id: 'gpt-5.1-preview' }))).toBe(false)
expect(isGPT51SeriesModel(createModel({ id: 'gpt-5.1-mini' }))).toBe(true)
expect(isGPT5SeriesReasoningModel(createModel({ id: 'gpt-5-prompt' }))).toBe(true)
expect(isSupportVerbosityModel(createModel({ id: 'gpt-5-chat' }))).toBe(false)
})
it('wraps generate/vision helpers that operate on arrays', () => {
const models = [createModel({ id: 'gpt-4o' }), createModel({ id: 'gpt-4o-mini' })]
expect(isVisionModels(models)).toBe(true)
visionMock.mockReturnValueOnce(true).mockReturnValueOnce(false)
expect(isVisionModels(models)).toBe(false)
expect(isGenerateImageModels(models)).toBe(true)
generateImageMock.mockReturnValueOnce(true).mockReturnValueOnce(false)
expect(isGenerateImageModels(models)).toBe(false)
})
it('filters models for agent usage', () => {
expect(agentModelFilter(createModel())).toBe(true)
embeddingMock.mockReturnValueOnce(true)
expect(agentModelFilter(createModel({ id: 'text-embedding' }))).toBe(false)
embeddingMock.mockReturnValue(false)
rerankMock.mockReturnValueOnce(true)
expect(agentModelFilter(createModel({ id: 'rerank' }))).toBe(false)
rerankMock.mockReturnValue(false)
textToImageMock.mockReturnValueOnce(true)
expect(agentModelFilter(createModel({ id: 'gpt-image-1' }))).toBe(false)
})
it('identifies models with maximum temperature of 1.0', () => {
// Zhipu models should have max temperature of 1.0
expect(isMaxTemperatureOneModel(createModel({ id: 'glm-4' }))).toBe(true)
expect(isMaxTemperatureOneModel(createModel({ id: 'GLM-4-Plus' }))).toBe(true)
expect(isMaxTemperatureOneModel(createModel({ id: 'glm-3-turbo' }))).toBe(true)
// Anthropic models should have max temperature of 1.0
expect(isMaxTemperatureOneModel(createModel({ id: 'claude-3.5-sonnet' }))).toBe(true)
expect(isMaxTemperatureOneModel(createModel({ id: 'Claude-3-opus' }))).toBe(true)
expect(isMaxTemperatureOneModel(createModel({ id: 'claude-2.1' }))).toBe(true)
// Moonshot models should have max temperature of 1.0
expect(isMaxTemperatureOneModel(createModel({ id: 'moonshot-1.0' }))).toBe(true)
expect(isMaxTemperatureOneModel(createModel({ id: 'kimi-k2-thinking' }))).toBe(true)
expect(isMaxTemperatureOneModel(createModel({ id: 'Moonshot-Pro' }))).toBe(true)
// Other models should return false
expect(isMaxTemperatureOneModel(createModel({ id: 'gpt-4o' }))).toBe(false)
expect(isMaxTemperatureOneModel(createModel({ id: 'gpt-4-turbo' }))).toBe(false)
expect(isMaxTemperatureOneModel(createModel({ id: 'qwen-max' }))).toBe(false)
expect(isMaxTemperatureOneModel(createModel({ id: 'gemini-pro' }))).toBe(false)
})
})
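A single pattern reproduces every temperature expectation above: Zhipu (glm-), Anthropic (claude), and Moonshot (moonshot, kimi) ids cap out at 1.0, while the OpenAI, Qwen, and Gemini ids do not. A hypothetical matcher consistent with the suite — the shipped isMaxTemperatureOneModel may be factored differently:
const MAX_TEMPERATURE_ONE_SKETCH = /glm-|claude|moonshot|kimi/i
const isMaxTemperatureOneModelSketch = (model: Model): boolean =>
  MAX_TEMPERATURE_ONE_SKETCH.test(model.id)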

View File

@@ -0,0 +1,311 @@
import { getProviderByModel } from '@renderer/services/AssistantService'
import type { Model } from '@renderer/types'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import { isEmbeddingModel, isRerankModel } from '../embedding'
import {
isAutoEnableImageGenerationModel,
isDedicatedImageGenerationModel,
isGenerateImageModel,
isImageEnhancementModel,
isPureGenerateImageModel,
isTextToImageModel,
isVisionModel
} from '../vision'
vi.mock('@renderer/hooks/useStore', () => ({
getStoreProviders: vi.fn(() => [])
}))
vi.mock('@renderer/store', () => ({
__esModule: true,
default: {
getState: () => ({
llm: { providers: [] },
settings: {}
})
},
useAppDispatch: vi.fn(),
useAppSelector: vi.fn()
}))
vi.mock('@renderer/store/settings', () => {
const noop = vi.fn()
return new Proxy(
{},
{
get: (_target, prop) => {
if (prop === 'initialState') {
return {}
}
return noop
}
}
)
})
vi.mock('@renderer/hooks/useSettings', () => ({
useSettings: vi.fn(() => ({})),
useNavbarPosition: vi.fn(() => ({ navbarPosition: 'left' })),
useMessageStyle: vi.fn(() => ({ isBubbleStyle: false })),
getStoreSetting: vi.fn()
}))
vi.mock('@renderer/services/AssistantService', () => ({
getProviderByModel: vi.fn()
}))
vi.mock('../embedding', () => ({
isEmbeddingModel: vi.fn(),
isRerankModel: vi.fn()
}))
const createModel = (overrides: Partial<Model> = {}): Model => ({
id: 'gpt-4o',
name: 'gpt-4o',
provider: 'openai',
group: 'OpenAI',
...overrides
})
const providerMock = vi.mocked(getProviderByModel)
const embeddingMock = vi.mocked(isEmbeddingModel)
const rerankMock = vi.mocked(isRerankModel)
describe('vision helpers', () => {
beforeEach(() => {
vi.clearAllMocks()
providerMock.mockReturnValue({ type: 'openai-response' } as any)
embeddingMock.mockReturnValue(false)
rerankMock.mockReturnValue(false)
})
describe('isGenerateImageModel', () => {
it('returns false for embedding/rerank models or missing providers', () => {
embeddingMock.mockReturnValueOnce(true)
expect(isGenerateImageModel(createModel({ id: 'gpt-image-1' }))).toBe(false)
embeddingMock.mockReturnValue(false)
rerankMock.mockReturnValueOnce(true)
expect(isGenerateImageModel(createModel({ id: 'gpt-image-1' }))).toBe(false)
rerankMock.mockReturnValue(false)
providerMock.mockReturnValueOnce(undefined as any)
expect(isGenerateImageModel(createModel({ id: 'gpt-image-1' }))).toBe(false)
})
it('detects OpenAI and third-party generative image models', () => {
expect(isGenerateImageModel(createModel({ id: 'gpt-4o-mini' }))).toBe(true)
providerMock.mockReturnValue({ type: 'custom' } as any)
expect(isGenerateImageModel(createModel({ id: 'gemini-2.5-flash-image' }))).toBe(true)
})
it('returns false when openai-response model is not on allow list', () => {
expect(isGenerateImageModel(createModel({ id: 'gpt-4.2-experimental' }))).toBe(false)
})
})
describe('isPureGenerateImageModel', () => {
it('requires both generate and text-to-image support', () => {
expect(isPureGenerateImageModel(createModel({ id: 'gpt-image-1' }))).toBe(true)
expect(isPureGenerateImageModel(createModel({ id: 'gpt-4o' }))).toBe(false)
expect(isPureGenerateImageModel(createModel({ id: 'gemini-2.5-flash-image-preview' }))).toBe(true)
})
})
describe('text-to-image helpers', () => {
it('matches predefined keywords', () => {
expect(isTextToImageModel(createModel({ id: 'midjourney-v6' }))).toBe(true)
expect(isTextToImageModel(createModel({ id: 'gpt-4o' }))).toBe(false)
})
it('detects models with restricted image size support and enhancement', () => {
expect(isImageEnhancementModel(createModel({ id: 'qwen-image-edit' }))).toBe(true)
expect(isImageEnhancementModel(createModel({ id: 'gpt-4o' }))).toBe(false)
})
it('identifies dedicated and auto-enabled image generation models', () => {
expect(isDedicatedImageGenerationModel(createModel({ id: 'grok-2-image-1212' }))).toBe(true)
expect(isAutoEnableImageGenerationModel(createModel({ id: 'gemini-2.5-flash-image-ultra' }))).toBe(true)
})
it('returns false when models are not in dedicated or auto-enable sets', () => {
expect(isDedicatedImageGenerationModel(createModel({ id: 'gpt-4o' }))).toBe(false)
expect(isAutoEnableImageGenerationModel(createModel({ id: 'gpt-4o' }))).toBe(false)
})
})
})
describe('isVisionModel', () => {
it('returns false for embedding/rerank models and honors overrides', () => {
embeddingMock.mockReturnValueOnce(true)
expect(isVisionModel(createModel({ id: 'gpt-4o' }))).toBe(false)
embeddingMock.mockReturnValue(false)
const disabled = createModel({
id: 'gpt-4o',
capabilities: [{ type: 'vision', isUserSelected: false }]
})
expect(isVisionModel(disabled)).toBe(false)
const forced = createModel({
id: 'gpt-4o',
capabilities: [{ type: 'vision', isUserSelected: true }]
})
expect(isVisionModel(forced)).toBe(true)
})
it('matches doubao models by name and general regexes by id', () => {
const doubao = createModel({
id: 'custom-id',
provider: 'doubao',
name: 'Doubao-Seed-1-6-Lite-251015'
})
expect(isVisionModel(doubao)).toBe(true)
expect(isVisionModel(createModel({ id: 'gpt-4o-mini' }))).toBe(true)
})
it('leverages image enhancement regex when standard vision regex does not match', () => {
expect(isVisionModel(createModel({ id: 'qwen-image-edit' }))).toBe(true)
})
it('returns false for doubao models that fail regex checks', () => {
const doubao = createModel({ id: 'doubao-standard', provider: 'doubao', name: 'basic' })
expect(isVisionModel(doubao)).toBe(false)
})
describe('Gemini Models', () => {
it('should return true for gemini 1.5 models', () => {
expect(
isVisionModel({
id: 'gemini-1.5-flash',
name: '',
provider: '',
group: ''
})
).toBe(true)
expect(
isVisionModel({
id: 'gemini-1.5-pro',
name: '',
provider: '',
group: ''
})
).toBe(true)
})
it('should return true for gemini 2.x models', () => {
expect(
isVisionModel({
id: 'gemini-2.0-flash',
name: '',
provider: '',
group: ''
})
).toBe(true)
expect(
isVisionModel({
id: 'gemini-2.0-pro',
name: '',
provider: '',
group: ''
})
).toBe(true)
expect(
isVisionModel({
id: 'gemini-2.5-flash',
name: '',
provider: '',
group: ''
})
).toBe(true)
expect(
isVisionModel({
id: 'gemini-2.5-pro',
name: '',
provider: '',
group: ''
})
).toBe(true)
})
it('should return true for gemini latest models', () => {
expect(
isVisionModel({
id: 'gemini-flash-latest',
name: '',
provider: '',
group: ''
})
).toBe(true)
expect(
isVisionModel({
id: 'gemini-pro-latest',
name: '',
provider: '',
group: ''
})
).toBe(true)
expect(
isVisionModel({
id: 'gemini-flash-lite-latest',
name: '',
provider: '',
group: ''
})
).toBe(true)
})
it('should return true for gemini 3 models', () => {
// Preview versions
expect(
isVisionModel({
id: 'gemini-3-pro-preview',
name: '',
provider: '',
group: ''
})
).toBe(true)
// Future stable versions
expect(
isVisionModel({
id: 'gemini-3-flash',
name: '',
provider: '',
group: ''
})
).toBe(true)
expect(
isVisionModel({
id: 'gemini-3-pro',
name: '',
provider: '',
group: ''
})
).toBe(true)
})
it('should return true for gemini exp models', () => {
expect(
isVisionModel({
id: 'gemini-exp-1206',
name: '',
provider: '',
group: ''
})
).toBe(true)
})
it('should return false for gemini 1.0 models', () => {
expect(
isVisionModel({
id: 'gemini-1.0-pro',
name: '',
provider: '',
group: ''
})
).toBe(false)
})
})
})
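Two details are worth calling out in this suite: a capability override short-circuits everything else, and doubao models are matched on model.name, so a custom opaque id still resolves. In sketch form, with VISION_REGEX standing in for a pattern this diff never shows:
const isVisionModelSketch = (model: Model): boolean => {
  if (isEmbeddingModel(model) || isRerankModel(model)) return false
  const override = model.capabilities?.find((c) => c.type === 'vision')
  if (override) return !!override.isUserSelected
  const haystack = model.provider === 'doubao' ? model.name : model.id
  // VISION_REGEX is assumed; per the qwen-image-edit case it is ORed with
  // the image-enhancement pattern as a fallback.
  return VISION_REGEX.test(haystack)
}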

View File

@@ -0,0 +1,397 @@
import { beforeEach, describe, expect, it, vi } from 'vitest'
const providerMock = vi.mocked(getProviderByModel)
vi.mock('@renderer/services/AssistantService', () => ({
getProviderByModel: vi.fn(),
getAssistantSettings: vi.fn(),
getDefaultAssistant: vi.fn().mockReturnValue({
id: 'default',
name: 'Default Assistant',
prompt: '',
settings: {}
})
}))
const isEmbeddingModel = vi.hoisted(() => vi.fn())
const isRerankModel = vi.hoisted(() => vi.fn())
vi.mock('../embedding', () => ({
isEmbeddingModel: (...args: any[]) => isEmbeddingModel(...args),
isRerankModel: (...args: any[]) => isRerankModel(...args)
}))
const isPureGenerateImageModel = vi.hoisted(() => vi.fn())
const isTextToImageModel = vi.hoisted(() => vi.fn())
const isGenerateImageModel = vi.hoisted(() => vi.fn())
vi.mock('../vision', () => ({
isPureGenerateImageModel: (...args: any[]) => isPureGenerateImageModel(...args),
isTextToImageModel: (...args: any[]) => isTextToImageModel(...args),
isGenerateImageModel: (...args: any[]) => isGenerateImageModel(...args),
isModernGenerateImageModel: vi.fn()
}))
const providerMocks = vi.hoisted(() => ({
isGeminiProvider: vi.fn(),
isNewApiProvider: vi.fn(),
isOpenAICompatibleProvider: vi.fn(),
isOpenAIProvider: vi.fn(),
isVertexProvider: vi.fn(),
isAwsBedrockProvider: vi.fn(),
isAzureOpenAIProvider: vi.fn()
}))
vi.mock('@renderer/utils/provider', () => providerMocks)
vi.mock('@renderer/hooks/useStore', () => ({
getStoreProviders: vi.fn(() => [])
}))
vi.mock('@renderer/store', () => ({
__esModule: true,
default: {
getState: () => ({
llm: { providers: [] },
settings: {}
})
},
useAppDispatch: vi.fn(),
useAppSelector: vi.fn()
}))
vi.mock('@renderer/store/settings', () => {
const noop = vi.fn()
return new Proxy(
{},
{
get: (_target, prop) => {
if (prop === 'initialState') {
return {}
}
return noop
}
}
)
})
vi.mock('@renderer/hooks/useSettings', () => ({
useSettings: vi.fn(() => ({})),
useNavbarPosition: vi.fn(() => ({ navbarPosition: 'left' })),
useMessageStyle: vi.fn(() => ({ isBubbleStyle: false })),
getStoreSetting: vi.fn()
}))
import { getProviderByModel } from '@renderer/services/AssistantService'
import type { Model, Provider } from '@renderer/types'
import { SystemProviderIds } from '@renderer/types'
import { isOpenAIDeepResearchModel } from '../openai'
import {
GEMINI_SEARCH_REGEX,
isHunyuanSearchModel,
isMandatoryWebSearchModel,
isOpenAIWebSearchChatCompletionOnlyModel,
isOpenAIWebSearchModel,
isOpenRouterBuiltInWebSearchModel,
isWebSearchModel
} from '../websearch'
const createModel = (overrides: Partial<Model> = {}): Model => ({
id: 'gpt-4o',
name: 'gpt-4o',
provider: 'openai',
group: 'OpenAI',
...overrides
})
const createProvider = (overrides: Partial<Provider> = {}): Provider => ({
id: 'openai',
type: 'openai',
name: 'OpenAI',
apiKey: '',
apiHost: '',
models: [],
...overrides
})
const resetMocks = () => {
providerMock.mockReturnValue(createProvider())
isEmbeddingModel.mockReturnValue(false)
isRerankModel.mockReturnValue(false)
isPureGenerateImageModel.mockReturnValue(false)
isTextToImageModel.mockReturnValue(false)
providerMocks.isGeminiProvider.mockReturnValue(false)
providerMocks.isNewApiProvider.mockReturnValue(false)
providerMocks.isOpenAICompatibleProvider.mockReturnValue(false)
providerMocks.isOpenAIProvider.mockReturnValue(false)
}
describe('websearch helpers', () => {
beforeEach(() => {
vi.clearAllMocks()
resetMocks()
})
describe('isOpenAIDeepResearchModel', () => {
it('detects deep research ids for OpenAI only', () => {
expect(isOpenAIDeepResearchModel(createModel({ id: 'openai/deep-research-preview' }))).toBe(true)
expect(isOpenAIDeepResearchModel(createModel({ provider: 'openai', id: 'gpt-4o' }))).toBe(false)
expect(isOpenAIDeepResearchModel(createModel({ provider: 'openrouter', id: 'deep-research' }))).toBe(false)
})
})
describe('isWebSearchModel', () => {
it('returns false for embedding/rerank/image models', () => {
isEmbeddingModel.mockReturnValueOnce(true)
expect(isWebSearchModel(createModel())).toBe(false)
resetMocks()
isRerankModel.mockReturnValueOnce(true)
expect(isWebSearchModel(createModel())).toBe(false)
resetMocks()
isTextToImageModel.mockReturnValueOnce(true)
expect(isWebSearchModel(createModel())).toBe(false)
})
it('honors user overrides', () => {
const enabled = createModel({ capabilities: [{ type: 'web_search', isUserSelected: true }] })
expect(isWebSearchModel(enabled)).toBe(true)
const disabled = createModel({ capabilities: [{ type: 'web_search', isUserSelected: false }] })
expect(isWebSearchModel(disabled)).toBe(false)
})
it('returns false when provider lookup fails', () => {
providerMock.mockReturnValueOnce(undefined as any)
expect(isWebSearchModel(createModel())).toBe(false)
})
it('handles Anthropic providers on unsupported platforms', () => {
providerMock.mockReturnValueOnce(createProvider({ id: SystemProviderIds['aws-bedrock'] }))
const model = createModel({ id: 'claude-2-sonnet' })
expect(isWebSearchModel(model)).toBe(false)
})
it('returns true for first-party Anthropic provider', () => {
providerMock.mockReturnValueOnce(createProvider({ id: 'anthropic' }))
const model = createModel({ id: 'claude-3.5-sonnet-latest', provider: 'anthropic' })
expect(isWebSearchModel(model)).toBe(true)
})
it('detects OpenAI preview search models only when supported', () => {
providerMocks.isOpenAIProvider.mockReturnValue(true)
const model = createModel({ id: 'gpt-4o-search-preview' })
expect(isWebSearchModel(model)).toBe(true)
const nonSearch = createModel({ id: 'gpt-4o-image' })
expect(isWebSearchModel(nonSearch)).toBe(false)
})
it('supports Perplexity sonar families including mandatory variants', () => {
providerMock.mockReturnValueOnce(createProvider({ id: SystemProviderIds.perplexity }))
expect(isWebSearchModel(createModel({ id: 'sonar-deep-research' }))).toBe(true)
})
it('handles AIHubMix Gemini and OpenAI search models', () => {
providerMock.mockReturnValueOnce(createProvider({ id: SystemProviderIds.aihubmix }))
expect(isWebSearchModel(createModel({ id: 'gemini-2.5-pro-preview' }))).toBe(true)
providerMock.mockReturnValueOnce(createProvider({ id: SystemProviderIds.aihubmix }))
const openaiSearch = createModel({ id: 'gpt-4o-search-preview' })
expect(isWebSearchModel(openaiSearch)).toBe(true)
})
it('supports OpenAI-compatible or new API providers for Gemini/OpenAI models', () => {
const model = createModel({ id: 'gemini-2.5-flash-lite-latest' })
providerMock.mockReturnValueOnce(createProvider({ id: 'custom' }))
providerMocks.isOpenAICompatibleProvider.mockReturnValueOnce(true)
expect(isWebSearchModel(model)).toBe(true)
resetMocks()
providerMock.mockReturnValueOnce(createProvider({ id: 'custom' }))
providerMocks.isNewApiProvider.mockReturnValueOnce(true)
expect(isWebSearchModel(createModel({ id: 'gpt-4o-search-preview' }))).toBe(true)
})
it('falls back to Gemini/Vertex provider regex matching', () => {
providerMock.mockReturnValueOnce(createProvider({ id: SystemProviderIds.vertexai }))
providerMocks.isGeminiProvider.mockReturnValueOnce(true)
expect(isWebSearchModel(createModel({ id: 'gemini-2.0-flash-latest' }))).toBe(true)
})
it('evaluates hunyuan/zhipu/dashscope/openrouter/grok providers', () => {
providerMock.mockReturnValueOnce(createProvider({ id: 'hunyuan' }))
expect(isWebSearchModel(createModel({ id: 'hunyuan-pro' }))).toBe(true)
expect(isWebSearchModel(createModel({ id: 'hunyuan-lite', provider: 'hunyuan' }))).toBe(false)
providerMock.mockReturnValueOnce(createProvider({ id: 'zhipu' }))
expect(isWebSearchModel(createModel({ id: 'glm-4-air' }))).toBe(true)
providerMock.mockReturnValueOnce(createProvider({ id: 'dashscope' }))
expect(isWebSearchModel(createModel({ id: 'qwen-max-latest' }))).toBe(true)
providerMock.mockReturnValueOnce(createProvider({ id: 'openrouter' }))
expect(isWebSearchModel(createModel())).toBe(true)
providerMock.mockReturnValueOnce(createProvider({ id: 'grok' }))
expect(isWebSearchModel(createModel({ id: 'grok-2' }))).toBe(true)
})
})
describe('isMandatoryWebSearchModel', () => {
it('requires sonar ids for perplexity/openrouter providers', () => {
providerMock.mockReturnValueOnce(createProvider({ id: SystemProviderIds.perplexity }))
expect(isMandatoryWebSearchModel(createModel({ id: 'sonar-pro' }))).toBe(true)
providerMock.mockReturnValueOnce(createProvider({ id: SystemProviderIds.openrouter }))
expect(isMandatoryWebSearchModel(createModel({ id: 'sonar-reasoning' }))).toBe(true)
providerMock.mockReturnValueOnce(createProvider({ id: 'openai' }))
expect(isMandatoryWebSearchModel(createModel({ id: 'sonar-pro' }))).toBe(false)
})
it.each([
['perplexity', 'non-sonar'],
['openrouter', 'gpt-4o-search-preview']
])('returns false for %s provider when id is %s', (providerId, modelId) => {
providerMock.mockReturnValueOnce(createProvider({ id: providerId }))
expect(isMandatoryWebSearchModel(createModel({ id: modelId }))).toBe(false)
})
})
describe('isOpenRouterBuiltInWebSearchModel', () => {
it('checks for sonar ids or OpenAI chat-completion-only variants', () => {
providerMock.mockReturnValueOnce(createProvider({ id: 'openrouter' }))
expect(isOpenRouterBuiltInWebSearchModel(createModel({ id: 'sonar-reasoning' }))).toBe(true)
providerMock.mockReturnValueOnce(createProvider({ id: 'openrouter' }))
expect(isOpenRouterBuiltInWebSearchModel(createModel({ id: 'gpt-4o-search-preview' }))).toBe(true)
providerMock.mockReturnValueOnce(createProvider({ id: 'custom' }))
expect(isOpenRouterBuiltInWebSearchModel(createModel({ id: 'sonar-reasoning' }))).toBe(false)
})
})
describe('OpenAI web search helpers', () => {
it('detects chat completion only variants and openai search ids', () => {
expect(isOpenAIWebSearchChatCompletionOnlyModel(createModel({ id: 'gpt-4o-search-preview' }))).toBe(true)
expect(isOpenAIWebSearchChatCompletionOnlyModel(createModel({ id: 'gpt-4o-mini-search-preview' }))).toBe(true)
expect(isOpenAIWebSearchChatCompletionOnlyModel(createModel({ id: 'gpt-4o' }))).toBe(false)
expect(isOpenAIWebSearchModel(createModel({ id: 'gpt-4.1-turbo' }))).toBe(true)
expect(isOpenAIWebSearchModel(createModel({ id: 'gpt-4o-image' }))).toBe(false)
expect(isOpenAIWebSearchModel(createModel({ id: 'gpt-5.1-chat' }))).toBe(false)
expect(isOpenAIWebSearchModel(createModel({ id: 'o3-mini' }))).toBe(true)
})
it.each(['gpt-4.1-preview', 'gpt-4o-2024-05-13', 'o4-mini', 'gpt-5-explorer'])(
'treats %s as an OpenAI web search model',
(id) => {
expect(isOpenAIWebSearchModel(createModel({ id }))).toBe(true)
}
)
it.each(['gpt-4o-image-preview', 'gpt-4.1-nano', 'gpt-5.1-chat', 'gpt-image-1'])(
'excludes %s from OpenAI web search',
(id) => {
expect(isOpenAIWebSearchModel(createModel({ id }))).toBe(false)
}
)
it.each(['gpt-4o-search-preview', 'gpt-4o-mini-search-preview'])('flags %s as chat-completion-only', (id) => {
expect(isOpenAIWebSearchChatCompletionOnlyModel(createModel({ id }))).toBe(true)
})
})
describe('isHunyuanSearchModel', () => {
it('identifies hunyuan models except lite', () => {
expect(isHunyuanSearchModel(createModel({ id: 'hunyuan-pro', provider: 'hunyuan' }))).toBe(true)
expect(isHunyuanSearchModel(createModel({ id: 'hunyuan-lite', provider: 'hunyuan' }))).toBe(false)
expect(isHunyuanSearchModel(createModel())).toBe(false)
})
it.each(['hunyuan-standard', 'hunyuan-advanced'])('accepts %s', (suffix) => {
expect(isHunyuanSearchModel(createModel({ id: suffix, provider: 'hunyuan' }))).toBe(true)
})
})
describe('provider-specific regex coverage', () => {
it.each(['qwen-turbo', 'qwen-max-0919', 'qwen3-max', 'qwen-plus-2024', 'qwq-32b'])(
'dashscope treats %s as searchable',
(id) => {
providerMock.mockReturnValue(createProvider({ id: 'dashscope' }))
expect(isWebSearchModel(createModel({ id }))).toBe(true)
}
)
it.each(['qwen-1.5-chat', 'custom-model'])('dashscope ignores %s', (id) => {
providerMock.mockReturnValue(createProvider({ id: 'dashscope' }))
expect(isWebSearchModel(createModel({ id }))).toBe(false)
})
it.each(['sonar', 'sonar-pro', 'sonar-reasoning-pro', 'sonar-deep-research'])(
'perplexity provider supports %s',
(id) => {
providerMock.mockReturnValue(createProvider({ id: SystemProviderIds.perplexity }))
expect(isWebSearchModel(createModel({ id }))).toBe(true)
}
)
it.each([
'gemini-2.0-flash-latest',
'gemini-2.5-flash-lite-latest',
'gemini-flash-lite-latest',
'gemini-pro-latest'
])('Gemini provider supports %s', (id) => {
providerMock.mockReturnValue(createProvider({ id: SystemProviderIds.vertexai }))
providerMocks.isGeminiProvider.mockReturnValue(true)
expect(isWebSearchModel(createModel({ id }))).toBe(true)
})
})
describe('Gemini Search Models', () => {
describe('GEMINI_SEARCH_REGEX', () => {
it('should match gemini 2.x models', () => {
expect(GEMINI_SEARCH_REGEX.test('gemini-2.0-flash')).toBe(true)
expect(GEMINI_SEARCH_REGEX.test('gemini-2.0-pro')).toBe(true)
expect(GEMINI_SEARCH_REGEX.test('gemini-2.5-flash')).toBe(true)
expect(GEMINI_SEARCH_REGEX.test('gemini-2.5-pro')).toBe(true)
expect(GEMINI_SEARCH_REGEX.test('gemini-2.5-flash-latest')).toBe(true)
expect(GEMINI_SEARCH_REGEX.test('gemini-2.5-pro-latest')).toBe(true)
})
it('should match gemini latest models', () => {
expect(GEMINI_SEARCH_REGEX.test('gemini-flash-latest')).toBe(true)
expect(GEMINI_SEARCH_REGEX.test('gemini-pro-latest')).toBe(true)
expect(GEMINI_SEARCH_REGEX.test('gemini-flash-lite-latest')).toBe(true)
})
it('should match gemini 3 models', () => {
// Preview versions
expect(GEMINI_SEARCH_REGEX.test('gemini-3-pro-preview')).toBe(true)
expect(GEMINI_SEARCH_REGEX.test('gemini-3-flash-preview')).toBe(true)
expect(GEMINI_SEARCH_REGEX.test('gemini-3-pro-image-preview')).toBe(true)
expect(GEMINI_SEARCH_REGEX.test('gemini-3-flash-image-preview')).toBe(true)
// Future stable versions
expect(GEMINI_SEARCH_REGEX.test('gemini-3-flash')).toBe(true)
expect(GEMINI_SEARCH_REGEX.test('gemini-3-pro')).toBe(true)
// Version with decimals
expect(GEMINI_SEARCH_REGEX.test('gemini-3.0-flash')).toBe(true)
expect(GEMINI_SEARCH_REGEX.test('gemini-3.0-pro')).toBe(true)
expect(GEMINI_SEARCH_REGEX.test('gemini-3.5-flash-preview')).toBe(true)
expect(GEMINI_SEARCH_REGEX.test('gemini-3.5-pro-image-preview')).toBe(true)
})
it('should not match gemini 2.x image-preview models', () => {
expect(GEMINI_SEARCH_REGEX.test('gemini-2.5-flash-image-preview')).toBe(false)
expect(GEMINI_SEARCH_REGEX.test('gemini-2.0-pro-image-preview')).toBe(false)
})
it('should not match older gemini models', () => {
expect(GEMINI_SEARCH_REGEX.test('gemini-1.5-flash')).toBe(false)
expect(GEMINI_SEARCH_REGEX.test('gemini-1.5-pro')).toBe(false)
expect(GEMINI_SEARCH_REGEX.test('gemini-1.0-pro')).toBe(false)
})
})
})
})
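The GEMINI_SEARCH_REGEX block closing this file is the strictest specification in the diff: 2.x flash/pro ids (including -latest suffixes) but not their image-preview variants, the -latest alias ids, and every Gemini 3 id including decimal versions and image previews. One regex consistent with all of the cases above — the shipped constant may well be factored differently:
const GEMINI_SEARCH_SKETCH =
  /gemini-(?:2\.\d+-(?:flash|pro)(?!-image)(?:-[\w-]+)?|3(?:\.\d+)?-(?:flash|pro)(?:-[\w-]+)?|(?:flash|pro)(?:-lite)?-latest)/i
GEMINI_SEARCH_SKETCH.test('gemini-2.5-flash-image-preview') // false: 2.x image models excluded
GEMINI_SEARCH_SKETCH.test('gemini-3-pro-image-preview') // true: 3.x image models allowed
GEMINI_SEARCH_SKETCH.test('gemini-3.5-flash-preview') // true: decimal 3.x versions pass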

View File

@@ -1,6 +1,8 @@
export * from './default'
export * from './embedding'
export * from './logo'
export * from './openai'
export * from './qwen'
export * from './reasoning'
export * from './tooluse'
export * from './utils'

View File

@@ -0,0 +1,107 @@
import type { Model } from '@renderer/types'
import { getLowerBaseModelName } from '@renderer/utils'
export const OPENAI_NO_SUPPORT_DEV_ROLE_MODELS = ['o1-preview', 'o1-mini']
export function isOpenAILLMModel(model: Model): boolean {
if (!model) {
return false
}
const modelId = getLowerBaseModelName(model.id)
if (modelId.includes('gpt-4o-image')) {
return false
}
if (isOpenAIReasoningModel(model)) {
return true
}
if (modelId.includes('gpt')) {
return true
}
return false
}
export function isOpenAIModel(model: Model): boolean {
if (!model) {
return false
}
const modelId = getLowerBaseModelName(model.id)
return modelId.includes('gpt') || isOpenAIReasoningModel(model)
}
export const isGPT5ProModel = (model: Model) => {
const modelId = getLowerBaseModelName(model.id)
return modelId.includes('gpt-5-pro')
}
export const isOpenAIOpenWeightModel = (model: Model) => {
const modelId = getLowerBaseModelName(model.id)
return modelId.includes('gpt-oss')
}
export const isGPT5SeriesModel = (model: Model) => {
const modelId = getLowerBaseModelName(model.id)
return modelId.includes('gpt-5') && !modelId.includes('gpt-5.1')
}
export const isGPT5SeriesReasoningModel = (model: Model) => {
const modelId = getLowerBaseModelName(model.id)
return isGPT5SeriesModel(model) && !modelId.includes('chat')
}
export const isGPT51SeriesModel = (model: Model) => {
const modelId = getLowerBaseModelName(model.id)
return modelId.includes('gpt-5.1')
}
export function isSupportVerbosityModel(model: Model): boolean {
const modelId = getLowerBaseModelName(model.id)
return (isGPT5SeriesModel(model) || isGPT51SeriesModel(model)) && !modelId.includes('chat')
}
export function isOpenAIChatCompletionOnlyModel(model: Model): boolean {
if (!model) {
return false
}
const modelId = getLowerBaseModelName(model.id)
return (
modelId.includes('gpt-4o-search-preview') ||
modelId.includes('gpt-4o-mini-search-preview') ||
modelId.includes('o1-mini') ||
modelId.includes('o1-preview')
)
}
export function isOpenAIReasoningModel(model: Model): boolean {
const modelId = getLowerBaseModelName(model.id, '/')
return isSupportedReasoningEffortOpenAIModel(model) || modelId.includes('o1')
}
export function isSupportedReasoningEffortOpenAIModel(model: Model): boolean {
const modelId = getLowerBaseModelName(model.id)
return (
(modelId.includes('o1') && !(modelId.includes('o1-preview') || modelId.includes('o1-mini'))) ||
modelId.includes('o3') ||
modelId.includes('o4') ||
modelId.includes('gpt-oss') ||
((isGPT5SeriesModel(model) || isGPT51SeriesModel(model)) && !modelId.includes('chat'))
)
}
const OPENAI_DEEP_RESEARCH_MODEL_REGEX = /deep[-_]?research/
export function isOpenAIDeepResearchModel(model?: Model): boolean {
if (!model) {
return false
}
const providerId = model.provider
if (providerId !== 'openai' && providerId !== 'openai-chat') {
return false
}
const modelId = getLowerBaseModelName(model.id, '/')
return OPENAI_DEEP_RESEARCH_MODEL_REGEX.test(modelId)
}
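A few concrete reads on the helpers above (ids are illustrative, not from the codebase):
isGPT5SeriesModel({ id: 'gpt-5-preview' } as Model) // true
isGPT5SeriesModel({ id: 'gpt-5.1-preview' } as Model) // false: 5.1 is deliberately its own series
isGPT51SeriesModel({ id: 'gpt-5.1-mini' } as Model) // true
isSupportVerbosityModel({ id: 'gpt-5-chat' } as Model) // false: chat variants opt out of verbosity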

View File

@@ -0,0 +1,7 @@
import type { Model } from '@renderer/types'
import { getLowerBaseModelName } from '@renderer/utils'
export const isQwenMTModel = (model: Model): boolean => {
const modelId = getLowerBaseModelName(model.id)
return modelId.includes('qwen-mt')
}

View File

@@ -8,9 +8,16 @@ import type {
import { getLowerBaseModelName, isUserSelectedModelType } from '@renderer/utils'
import { isEmbeddingModel, isRerankModel } from './embedding'
import { isGPT5ProModel, isGPT5SeriesModel, isGPT51SeriesModel } from './utils'
import {
isGPT5ProModel,
isGPT5SeriesModel,
isGPT51SeriesModel,
isOpenAIDeepResearchModel,
isOpenAIReasoningModel,
isSupportedReasoningEffortOpenAIModel
} from './openai'
import { GEMINI_FLASH_MODEL_REGEX, isGemini3Model } from './utils'
import { isTextToImageModel } from './vision'
import { GEMINI_FLASH_MODEL_REGEX, isOpenAIDeepResearchModel } from './websearch'
// Reasoning models
export const REASONING_REGEX =
@@ -30,6 +37,7 @@ export const MODEL_SUPPORTED_REASONING_EFFORT: ReasoningEffortConfig = {
grok: ['low', 'high'] as const,
grok4_fast: ['auto'] as const,
gemini: ['low', 'medium', 'high', 'auto'] as const,
gemini3: ['low', 'medium', 'high'] as const,
gemini_pro: ['low', 'medium', 'high', 'auto'] as const,
qwen: ['low', 'medium', 'high'] as const,
qwen_thinking: ['low', 'medium', 'high'] as const,
@@ -56,6 +64,7 @@ export const MODEL_SUPPORTED_OPTIONS: ThinkingOptionConfig = {
grok4_fast: ['none', ...MODEL_SUPPORTED_REASONING_EFFORT.grok4_fast] as const,
gemini: ['none', ...MODEL_SUPPORTED_REASONING_EFFORT.gemini] as const,
gemini_pro: MODEL_SUPPORTED_REASONING_EFFORT.gemini_pro,
gemini3: MODEL_SUPPORTED_REASONING_EFFORT.gemini3,
qwen: ['none', ...MODEL_SUPPORTED_REASONING_EFFORT.qwen] as const,
qwen_thinking: MODEL_SUPPORTED_REASONING_EFFORT.qwen_thinking,
doubao: ['none', ...MODEL_SUPPORTED_REASONING_EFFORT.doubao] as const,
@@ -106,6 +115,9 @@ const _getThinkModelType = (model: Model): ThinkingModelType => {
} else {
thinkingModelType = 'gemini_pro'
}
if (isGemini3Model(model)) {
thinkingModelType = 'gemini3'
}
} else if (isSupportedReasoningEffortGrokModel(model)) thinkingModelType = 'grok'
else if (isSupportedThinkingTokenQwenModel(model)) {
if (isQwenAlwaysThinkModel(model)) {
@@ -254,11 +266,19 @@ export function isGeminiReasoningModel(model?: Model): boolean {
// Regex for Gemini models that support thinking mode
export const GEMINI_THINKING_MODEL_REGEX =
/gemini-(?:2\.5.*(?:-latest)?|3-(?:flash|pro)(?:-preview)?|flash-latest|pro-latest|flash-lite-latest)(?:-[\w-]+)*$/i
/gemini-(?:2\.5.*(?:-latest)?|3(?:\.\d+)?-(?:flash|pro)(?:-preview)?|flash-latest|pro-latest|flash-lite-latest)(?:-[\w-]+)*$/i
export const isSupportedThinkingTokenGeminiModel = (model: Model): boolean => {
const modelId = getLowerBaseModelName(model.id, '/')
if (GEMINI_THINKING_MODEL_REGEX.test(modelId)) {
// image models in the gemini-3.x family do support thinking mode
if (isGemini3Model(model)) {
if (modelId.includes('tts')) {
return false
}
return true
}
// image/tts models in the gemini-2.x family do not
if (modelId.includes('image') || modelId.includes('tts')) {
return false
}
@@ -382,6 +402,12 @@ export function isClaude45ReasoningModel(model: Model): boolean {
return regex.test(modelId)
}
export function isClaude4SeriesModel(model: Model): boolean {
const modelId = getLowerBaseModelName(model.id, '/')
const regex = /claude-(sonnet|opus|haiku)-4(?:[.-]\d+)?(?:-[\w-]+)?$/i
return regex.test(modelId)
}
export function isClaudeReasoningModel(model?: Model): boolean {
if (!model) {
return false
@@ -529,22 +555,6 @@ export function isReasoningModel(model?: Model): boolean {
return REASONING_REGEX.test(modelId) || false
}
export function isOpenAIReasoningModel(model: Model): boolean {
const modelId = getLowerBaseModelName(model.id, '/')
return isSupportedReasoningEffortOpenAIModel(model) || modelId.includes('o1')
}
export function isSupportedReasoningEffortOpenAIModel(model: Model): boolean {
const modelId = getLowerBaseModelName(model.id)
return (
(modelId.includes('o1') && !(modelId.includes('o1-preview') || modelId.includes('o1-mini'))) ||
modelId.includes('o3') ||
modelId.includes('o4') ||
modelId.includes('gpt-oss') ||
((isGPT5SeriesModel(model) || isGPT51SeriesModel(model)) && !modelId.includes('chat'))
)
}
export const THINKING_TOKEN_MAP: Record<string, { min: number; max: number }> = {
// Gemini models
'gemini-2\\.5-flash-lite.*$': { min: 512, max: 24576 },

View File

@@ -4,7 +4,7 @@ import { getLowerBaseModelName, isUserSelectedModelType } from '@renderer/utils'
import { isEmbeddingModel, isRerankModel } from './embedding'
import { isDeepSeekHybridInferenceModel } from './reasoning'
import { isPureGenerateImageModel, isTextToImageModel } from './vision'
import { isTextToImageModel } from './vision'
// Tool calling models
export const FUNCTION_CALLING_MODELS = [
@@ -41,7 +41,9 @@ const FUNCTION_CALLING_EXCLUDED_MODELS = [
'gemini-1(?:\\.[\\w-]+)?',
'qwen-mt(?:-[\\w-]+)?',
'gpt-5-chat(?:-[\\w-]+)?',
'glm-4\\.5v'
'glm-4\\.5v',
'gemini-2.5-flash-image(?:-[\\w-]+)?',
'gemini-2.0-flash-preview-image-generation'
]
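For clarity, the list entries are regex fragments rather than literal ids, so the two new Gemini entries also cover suffixed variants. A sketch (the anchoring used when the exclusion regex is actually built is not shown in this hunk, so the `^...$` here is an assumption):

const FUNCTION_CALLING_EXCLUDED_MODELS = [
  'gemini-2.5-flash-image(?:-[\\w-]+)?',
  'gemini-2.0-flash-preview-image-generation'
]
// Assumed anchoring, for illustration only.
const excluded = new RegExp(`^(?:${FUNCTION_CALLING_EXCLUDED_MODELS.join('|')})$`, 'i')

console.log(excluded.test('gemini-2.5-flash-image-preview')) // true
console.log(excluded.test('gemini-2.5-flash')) // false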
export const FUNCTION_CALLING_REGEX = new RegExp(
@@ -50,13 +52,7 @@ export const FUNCTION_CALLING_REGEX = new RegExp(
)
export function isFunctionCallingModel(model?: Model): boolean {
if (
!model ||
isEmbeddingModel(model) ||
isRerankModel(model) ||
isTextToImageModel(model) ||
isPureGenerateImageModel(model)
) {
if (!model || isEmbeddingModel(model) || isRerankModel(model) || isTextToImageModel(model)) {
return false
}
@@ -66,10 +62,6 @@ export function isFunctionCallingModel(model?: Model): boolean {
return isUserSelectedModelType(model, 'function_calling')!
}
if (model.provider === 'qiniu') {
return ['deepseek-v3-tool', 'deepseek-v3-0324', 'qwq-32b', 'qwen2.5-72b-instruct'].includes(modelId)
}
if (model.provider === 'doubao' || modelId.includes('doubao')) {
return FUNCTION_CALLING_REGEX.test(modelId) || FUNCTION_CALLING_REGEX.test(model.name)
}

View File

@@ -1,43 +1,14 @@
import type OpenAI from '@cherrystudio/openai'
import { isEmbeddingModel, isRerankModel } from '@renderer/config/models/embedding'
import type { Model } from '@renderer/types'
import { type Model, SystemProviderIds } from '@renderer/types'
import type { OpenAIVerbosity, ValidOpenAIVerbosity } from '@renderer/types/aiCoreTypes'
import { getLowerBaseModelName } from '@renderer/utils'
import { WEB_SEARCH_PROMPT_FOR_OPENROUTER } from '../prompts'
import { getWebSearchTools } from '../tools'
import { isOpenAIReasoningModel } from './reasoning'
import { isOpenAIChatCompletionOnlyModel, isOpenAIOpenWeightModel, isOpenAIReasoningModel } from './openai'
import { isQwenMTModel } from './qwen'
import { isGenerateImageModel, isTextToImageModel, isVisionModel } from './vision'
import { isOpenAIWebSearchChatCompletionOnlyModel } from './websearch'
export const NOT_SUPPORTED_REGEX = /(?:^tts|whisper|speech)/i
export const OPENAI_NO_SUPPORT_DEV_ROLE_MODELS = ['o1-preview', 'o1-mini']
export function isOpenAILLMModel(model: Model): boolean {
if (!model) {
return false
}
const modelId = getLowerBaseModelName(model.id)
if (modelId.includes('gpt-4o-image')) {
return false
}
if (isOpenAIReasoningModel(model)) {
return true
}
if (modelId.includes('gpt')) {
return true
}
return false
}
export function isOpenAIModel(model: Model): boolean {
if (!model) {
return false
}
const modelId = getLowerBaseModelName(model.id)
return modelId.includes('gpt') || isOpenAIReasoningModel(model)
}
export const GEMINI_FLASH_MODEL_REGEX = new RegExp('gemini.*-flash.*$', 'i')
export function isSupportFlexServiceTierModel(model: Model): boolean {
if (!model) {
@@ -52,33 +23,6 @@ export function isSupportedFlexServiceTier(model: Model): boolean {
return isSupportFlexServiceTierModel(model)
}
export function isSupportVerbosityModel(model: Model): boolean {
const modelId = getLowerBaseModelName(model.id)
return (isGPT5SeriesModel(model) || isGPT51SeriesModel(model)) && !modelId.includes('chat')
}
export function isOpenAIChatCompletionOnlyModel(model: Model): boolean {
if (!model) {
return false
}
const modelId = getLowerBaseModelName(model.id)
return (
modelId.includes('gpt-4o-search-preview') ||
modelId.includes('gpt-4o-mini-search-preview') ||
modelId.includes('o1-mini') ||
modelId.includes('o1-preview')
)
}
export function isGrokModel(model?: Model): boolean {
if (!model) {
return false
}
const modelId = getLowerBaseModelName(model.id)
return modelId.includes('grok')
}
export function isSupportedModel(model: OpenAI.Models.Model): boolean {
if (!model) {
return false
@@ -105,53 +49,6 @@ export function isNotSupportTemperatureAndTopP(model: Model): boolean {
return false
}
export function getOpenAIWebSearchParams(model: Model, isEnableWebSearch?: boolean): Record<string, any> {
if (!isEnableWebSearch) {
return {}
}
const webSearchTools = getWebSearchTools(model)
if (model.provider === 'grok') {
return {
search_parameters: {
mode: 'auto',
return_citations: true,
sources: [{ type: 'web' }, { type: 'x' }, { type: 'news' }]
}
}
}
if (model.provider === 'hunyuan') {
return { enable_enhancement: true, citation: true, search_info: true }
}
if (model.provider === 'dashscope') {
return {
enable_search: true,
search_options: {
forced_search: true
}
}
}
if (isOpenAIWebSearchChatCompletionOnlyModel(model)) {
return {
web_search_options: {}
}
}
if (model.provider === 'openrouter') {
return {
plugins: [{ id: 'web', search_prompts: WEB_SEARCH_PROMPT_FOR_OPENROUTER }]
}
}
return {
tools: webSearchTools
}
}
export function isGemmaModel(model?: Model): boolean {
if (!model) {
return false
@@ -161,12 +58,14 @@ export function isGemmaModel(model?: Model): boolean {
return modelId.includes('gemma-') || model.group === 'Gemma'
}
export function isZhipuModel(model?: Model): boolean {
if (!model) {
return false
}
export function isZhipuModel(model: Model): boolean {
const modelId = getLowerBaseModelName(model.id)
return modelId.includes('glm') || model.provider === SystemProviderIds.zhipu
}
return model.provider === 'zhipu'
export function isMoonshotModel(model: Model): boolean {
const modelId = getLowerBaseModelName(model.id)
return ['moonshot', 'kimi'].some((m) => modelId.includes(m))
}
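A minimal sketch of the reworked checks (Model shape simplified, illustrative ids, and SystemProviderIds.zhipu assumed to equal 'zhipu'):

type Model = { id: string; provider: string }
const isZhipuModel = (m: Model) => m.id.toLowerCase().includes('glm') || m.provider === 'zhipu'
const isMoonshotModel = (m: Model) => ['moonshot', 'kimi'].some((s) => m.id.toLowerCase().includes(s))

console.log(isZhipuModel({ id: 'glm-4.6', provider: 'openrouter' })) // true, now matched by id
console.log(isMoonshotModel({ id: 'kimi-k2-instruct', provider: 'moonshot' })) // true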
/**
@@ -212,11 +111,6 @@ export const isAnthropicModel = (model?: Model): boolean => {
return modelId.startsWith('claude')
}
export const isQwenMTModel = (model: Model): boolean => {
const modelId = getLowerBaseModelName(model.id)
return modelId.includes('qwen-mt')
}
export const isNotSupportedTextDelta = (model: Model): boolean => {
return isQwenMTModel(model)
}
@@ -225,34 +119,22 @@ export const isNotSupportSystemMessageModel = (model: Model): boolean => {
return isQwenMTModel(model) || isGemmaModel(model)
}
export const isGPT5SeriesModel = (model: Model) => {
const modelId = getLowerBaseModelName(model.id)
return modelId.includes('gpt-5') && !modelId.includes('gpt-5.1')
}
export const isGPT5SeriesReasoningModel = (model: Model) => {
const modelId = getLowerBaseModelName(model.id)
return isGPT5SeriesModel(model) && !modelId.includes('chat')
}
export const isGPT51SeriesModel = (model: Model) => {
const modelId = getLowerBaseModelName(model.id)
return modelId.includes('gpt-5.1')
}
// GPT-5 verbosity configuration
// gpt-5-pro only supports 'high', other GPT-5 models support all levels
export const MODEL_SUPPORTED_VERBOSITY: Record<string, ('low' | 'medium' | 'high')[]> = {
export const MODEL_SUPPORTED_VERBOSITY: Record<string, ValidOpenAIVerbosity[]> = {
'gpt-5-pro': ['high'],
default: ['low', 'medium', 'high']
}
} as const
export const getModelSupportedVerbosity = (model: Model): ('low' | 'medium' | 'high')[] => {
export const getModelSupportedVerbosity = (model: Model): OpenAIVerbosity[] => {
const modelId = getLowerBaseModelName(model.id)
let supportedValues: ValidOpenAIVerbosity[]
if (modelId.includes('gpt-5-pro')) {
return MODEL_SUPPORTED_VERBOSITY['gpt-5-pro']
supportedValues = MODEL_SUPPORTED_VERBOSITY['gpt-5-pro']
} else {
supportedValues = MODEL_SUPPORTED_VERBOSITY.default
}
return MODEL_SUPPORTED_VERBOSITY.default
return [undefined, ...supportedValues]
}
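The return type change is the interesting part: the list now leads with undefined, presumably so callers can offer an "unset" choice that falls through to the API default. A sketch of the resulting values:

type ValidVerbosity = 'low' | 'medium' | 'high'
type Verbosity = ValidVerbosity | undefined

const forGpt5Pro: Verbosity[] = [undefined, 'high']
const forOthers: Verbosity[] = [undefined, 'low', 'medium', 'high']

console.log(forGpt5Pro.length, forOthers.length) // 2 4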
export const isGeminiModel = (model: Model) => {
@@ -260,11 +142,6 @@ export const isGeminiModel = (model: Model) => {
return modelId.includes('gemini')
}
export const isOpenAIOpenWeightModel = (model: Model) => {
const modelId = getLowerBaseModelName(model.id)
return modelId.includes('gpt-oss')
}
// Zhipu vision reasoning models use this set of special tokens to mark reasoning results
export const ZHIPU_RESULT_TOKENS = ['<|begin_of_box|>', '<|end_of_box|>'] as const
@@ -272,7 +149,14 @@ export const agentModelFilter = (model: Model): boolean => {
return !isEmbeddingModel(model) && !isRerankModel(model) && !isTextToImageModel(model)
}
export const isGPT5ProModel = (model: Model) => {
const modelId = getLowerBaseModelName(model.id)
return modelId.includes('gpt-5-pro')
export const isMaxTemperatureOneModel = (model: Model): boolean => {
if (isZhipuModel(model) || isAnthropicModel(model) || isMoonshotModel(model)) {
return true
}
return false
}
export const isGemini3Model = (model: Model) => {
const modelId = getLowerBaseModelName(model.id)
return modelId.includes('gemini-3')
}
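isMaxTemperatureOneModel presumably feeds a temperature clamp somewhere upstream; a hypothetical sketch of that use (the clamp itself is not part of this diff):

// Hypothetical caller, for illustration only.
const clampTemperature = (temperature: number, maxIsOne: boolean): number =>
  maxIsOne ? Math.min(temperature, 1) : temperature

console.log(clampTemperature(1.3, true)) // 1
console.log(clampTemperature(1.3, false)) // 1.3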

View File

@@ -3,6 +3,7 @@ import type { Model } from '@renderer/types'
import { getLowerBaseModelName, isUserSelectedModelType } from '@renderer/utils'
import { isEmbeddingModel, isRerankModel } from './embedding'
import { isFunctionCallingModel } from './tooluse'
// Vision models
const visionAllowedModels = [
@@ -72,12 +73,10 @@ const VISION_REGEX = new RegExp(
// For middleware to identify models that must use the dedicated Image API
const DEDICATED_IMAGE_MODELS = [
'grok-2-image',
'grok-2-image-1212',
'grok-2-image-latest',
'dall-e-3',
'dall-e-2',
'gpt-image-1'
'grok-2-image(?:-[\\w-]+)?',
'dall-e(?:-[\\w-]+)?',
'gpt-image-1(?:-[\\w-]+)?',
'imagen(?:-[\\w-]+)?'
]
const IMAGE_ENHANCEMENT_MODELS = [
@@ -85,13 +84,22 @@ const IMAGE_ENHANCEMENT_MODELS = [
'qwen-image-edit',
'gpt-image-1',
'gemini-2.5-flash-image(?:-[\\w-]+)?',
'gemini-2.0-flash-preview-image-generation'
'gemini-2.0-flash-preview-image-generation',
'gemini-3(?:\\.\\d+)?-pro-image(?:-[\\w-]+)?'
]
const IMAGE_ENHANCEMENT_MODELS_REGEX = new RegExp(IMAGE_ENHANCEMENT_MODELS.join('|'), 'i')
const DEDICATED_IMAGE_MODELS_REGEX = new RegExp(DEDICATED_IMAGE_MODELS.join('|'), 'i')
// Models that should auto-enable the image generation button when selected
const AUTO_ENABLE_IMAGE_MODELS = ['gemini-2.5-flash-image', ...DEDICATED_IMAGE_MODELS]
const AUTO_ENABLE_IMAGE_MODELS = [
'gemini-2.5-flash-image(?:-[\\w-]+)?',
'gemini-3(?:\\.\\d+)?-pro-image(?:-[\\w-]+)?',
...DEDICATED_IMAGE_MODELS
]
const AUTO_ENABLE_IMAGE_MODELS_REGEX = new RegExp(AUTO_ENABLE_IMAGE_MODELS.join('|'), 'i')
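The switch from substring checks to compiled regexes is what lets the tightened entries keep matching suffixed ids (a sketch with illustrative ids):

const DEDICATED_IMAGE_MODELS = [
  'grok-2-image(?:-[\\w-]+)?',
  'dall-e(?:-[\\w-]+)?',
  'gpt-image-1(?:-[\\w-]+)?',
  'imagen(?:-[\\w-]+)?'
]
const DEDICATED_IMAGE_MODELS_REGEX = new RegExp(DEDICATED_IMAGE_MODELS.join('|'), 'i')

console.log(DEDICATED_IMAGE_MODELS_REGEX.test('grok-2-image-latest')) // true
console.log(DEDICATED_IMAGE_MODELS_REGEX.test('imagen-3.0-generate-002')) // true
console.log(DEDICATED_IMAGE_MODELS_REGEX.test('grok-4')) // false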
const OPENAI_TOOL_USE_IMAGE_GENERATION_MODELS = [
'o3',
@@ -105,26 +113,34 @@ const OPENAI_TOOL_USE_IMAGE_GENERATION_MODELS = [
const OPENAI_IMAGE_GENERATION_MODELS = [...OPENAI_TOOL_USE_IMAGE_GENERATION_MODELS, 'gpt-image-1']
const MODERN_IMAGE_MODELS = ['gemini-3(?:\\.\\d+)?-pro-image(?:-[\\w-]+)?']
const GENERATE_IMAGE_MODELS = [
'gemini-2.0-flash-exp',
'gemini-2.0-flash-exp-image-generation',
'gemini-2.0-flash-exp(?:-[\\w-]+)?',
'gemini-2.5-flash-image(?:-[\\w-]+)?',
'gemini-2.0-flash-preview-image-generation',
'gemini-2.5-flash-image',
...MODERN_IMAGE_MODELS,
...DEDICATED_IMAGE_MODELS
]
const OPENAI_IMAGE_GENERATION_MODELS_REGEX = new RegExp(OPENAI_IMAGE_GENERATION_MODELS.join('|'), 'i')
const GENERATE_IMAGE_MODELS_REGEX = new RegExp(GENERATE_IMAGE_MODELS.join('|'), 'i')
const MODERN_GENERATE_IMAGE_MODELS_REGEX = new RegExp(MODERN_IMAGE_MODELS.join('|'), 'i')
export const isDedicatedImageGenerationModel = (model: Model): boolean => {
if (!model) return false
const modelId = getLowerBaseModelName(model.id)
return DEDICATED_IMAGE_MODELS.some((m) => modelId.includes(m))
return DEDICATED_IMAGE_MODELS_REGEX.test(modelId)
}
export const isAutoEnableImageGenerationModel = (model: Model): boolean => {
if (!model) return false
const modelId = getLowerBaseModelName(model.id)
return AUTO_ENABLE_IMAGE_MODELS.some((m) => modelId.includes(m))
return AUTO_ENABLE_IMAGE_MODELS_REGEX.test(modelId)
}
/**
@@ -146,48 +162,44 @@ export function isGenerateImageModel(model: Model): boolean {
const modelId = getLowerBaseModelName(model.id, '/')
if (provider.type === 'openai-response') {
return (
OPENAI_IMAGE_GENERATION_MODELS.some((imageModel) => modelId.includes(imageModel)) ||
GENERATE_IMAGE_MODELS.some((imageModel) => modelId.includes(imageModel))
)
return OPENAI_IMAGE_GENERATION_MODELS_REGEX.test(modelId) || GENERATE_IMAGE_MODELS_REGEX.test(modelId)
}
return GENERATE_IMAGE_MODELS.some((imageModel) => modelId.includes(imageModel))
return GENERATE_IMAGE_MODELS_REGEX.test(modelId)
}
// TODO: refine the regex
/**
 * Determines whether the model only supports pure image generation (cannot generate via tool calls)
* @param model
* @returns
*/
export function isPureGenerateImageModel(model: Model): boolean {
if (!isGenerateImageModel(model) || !isTextToImageModel(model)) {
if (!isGenerateImageModel(model) && !isTextToImageModel(model)) {
return false
}
if (isFunctionCallingModel(model)) {
return false
}
const modelId = getLowerBaseModelName(model.id)
return !OPENAI_TOOL_USE_IMAGE_GENERATION_MODELS.some((imageModel) => modelId.includes(imageModel))
if (GENERATE_IMAGE_MODELS_REGEX.test(modelId) && !MODERN_GENERATE_IMAGE_MODELS_REGEX.test(modelId)) {
return true
}
return !OPENAI_TOOL_USE_IMAGE_GENERATION_MODELS.some((m) => modelId.includes(m))
}
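The guard change at the top of isPureGenerateImageModel is a real behavior fix: with `||`, a model had to pass both checks to get past the guard; with `&&`, it is rejected only when it passes neither. A truth-table sketch:

// old: !gen || !t2i  -> proceed only when gen && t2i
// new: !gen && !t2i  -> proceed when gen || t2i
const oldBailsOut = (gen: boolean, t2i: boolean) => !gen || !t2i
const newBailsOut = (gen: boolean, t2i: boolean) => !gen && !t2i

console.log(oldBailsOut(true, false), newBailsOut(true, false)) // true false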
// TODO: refine the regex
// Text to image models
const TEXT_TO_IMAGE_REGEX = /flux|diffusion|stabilityai|sd-|dall|cogview|janus|midjourney|mj-|image|gpt-image/i
const TEXT_TO_IMAGE_REGEX = /flux|diffusion|stabilityai|sd-|dall|cogview|janus|midjourney|mj-|imagen|gpt-image/i
export function isTextToImageModel(model: Model): boolean {
const modelId = getLowerBaseModelName(model.id)
return TEXT_TO_IMAGE_REGEX.test(modelId)
}
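Swapping the overly broad 'image' token for 'imagen' in TEXT_TO_IMAGE_REGEX stops ids that merely contain "image" (such as Gemini's chat-capable image models) from being classified as text-to-image-only; a sketch comparing the two:

const OLD_REGEX = /flux|diffusion|stabilityai|sd-|dall|cogview|janus|midjourney|mj-|image|gpt-image/i
const NEW_REGEX = /flux|diffusion|stabilityai|sd-|dall|cogview|janus|midjourney|mj-|imagen|gpt-image/i

console.log(OLD_REGEX.test('gemini-2.5-flash-image')) // true (misclassified)
console.log(NEW_REGEX.test('gemini-2.5-flash-image')) // false
console.log(NEW_REGEX.test('imagen-3.0-generate-002')) // true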
// It's not used now
// export function isNotSupportedImageSizeModel(model?: Model): boolean {
// if (!model) {
// return false
// }
// const baseName = getLowerBaseModelName(model.id, '/')
// return baseName.includes('grok-2-image')
// }
/**
 * Determines whether the model supports image enhancement (editing, enhancement, restoration, etc.)
* @param model

View File

@@ -2,27 +2,29 @@ import { getProviderByModel } from '@renderer/services/AssistantService'
import type { Model } from '@renderer/types'
import { SystemProviderIds } from '@renderer/types'
import { getLowerBaseModelName, isUserSelectedModelType } from '@renderer/utils'
import {
isAzureOpenAIProvider,
isGeminiProvider,
isNewApiProvider,
isOpenAICompatibleProvider,
isOpenAIProvider,
isVertexAiProvider
} from '../providers'
import { isEmbeddingModel, isRerankModel } from './embedding'
import { isAnthropicModel } from './utils'
import { isPureGenerateImageModel, isTextToImageModel } from './vision'
isVertexProvider
} from '@renderer/utils/provider'
export const CLAUDE_SUPPORTED_WEBSEARCH_REGEX = new RegExp(
export { GEMINI_FLASH_MODEL_REGEX } from './utils'
import { isEmbeddingModel, isRerankModel } from './embedding'
import { isClaude4SeriesModel } from './reasoning'
import { isAnthropicModel } from './utils'
import { isTextToImageModel } from './vision'
const CLAUDE_SUPPORTED_WEBSEARCH_REGEX = new RegExp(
`\\b(?:claude-3(-|\\.)(7|5)-sonnet(?:-[\\w-]+)|claude-3(-|\\.)5-haiku(?:-[\\w-]+)|claude-(haiku|sonnet|opus)-4(?:-[\\w-]+)?)\\b`,
'i'
)
export const GEMINI_FLASH_MODEL_REGEX = new RegExp('gemini.*-flash.*$')
export const GEMINI_SEARCH_REGEX = new RegExp(
'gemini-(?:2.*(?:-latest)?|3-(?:flash|pro)(?:-preview)?|flash-latest|pro-latest|flash-lite-latest)(?:-[\\w-]+)*$',
'gemini-(?:2(?!.*-image-preview).*(?:-latest)?|3(?:\\.\\d+)?-(?:flash|pro)(?:-(?:image-)?preview)?|flash-latest|pro-latest|flash-lite-latest)(?:-[\\w-]+)*$',
'i'
)
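The new negative lookahead and the optional image- segment split 2.x and 3.x image models cleanly (a sketch with two probes):

const GEMINI_SEARCH_REGEX = new RegExp(
  'gemini-(?:2(?!.*-image-preview).*(?:-latest)?|3(?:\\.\\d+)?-(?:flash|pro)(?:-(?:image-)?preview)?|flash-latest|pro-latest|flash-lite-latest)(?:-[\\w-]+)*$',
  'i'
)

console.log(GEMINI_SEARCH_REGEX.test('gemini-2.5-flash-image-preview')) // false (excluded by lookahead)
console.log(GEMINI_SEARCH_REGEX.test('gemini-3-pro-image-preview')) // true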
@@ -34,30 +36,8 @@ export const PERPLEXITY_SEARCH_MODELS = [
'sonar-deep-research'
]
const OPENAI_DEEP_RESEARCH_MODEL_REGEX = /deep[-_]?research/
export function isOpenAIDeepResearchModel(model?: Model): boolean {
if (!model) {
return false
}
const providerId = model.provider
if (providerId !== 'openai' && providerId !== 'openai-chat') {
return false
}
const modelId = getLowerBaseModelName(model.id, '/')
return OPENAI_DEEP_RESEARCH_MODEL_REGEX.test(modelId)
}
export function isWebSearchModel(model: Model): boolean {
if (
!model ||
isEmbeddingModel(model) ||
isRerankModel(model) ||
isTextToImageModel(model) ||
isPureGenerateImageModel(model)
) {
if (!model || isEmbeddingModel(model) || isRerankModel(model) || isTextToImageModel(model)) {
return false
}
@@ -73,16 +53,17 @@ export function isWebSearchModel(model: Model): boolean {
const modelId = getLowerBaseModelName(model.id, '/')
// Not supported on Bedrock or Vertex
if (
isAnthropicModel(model) &&
!(provider.id === SystemProviderIds['aws-bedrock'] || provider.id === SystemProviderIds.vertexai)
) {
// Not supported on Bedrock; supported on Azure
if (isAnthropicModel(model) && !(provider.id === SystemProviderIds['aws-bedrock'])) {
if (isVertexProvider(provider)) {
return isClaude4SeriesModel(model)
}
return CLAUDE_SUPPORTED_WEBSEARCH_REGEX.test(modelId)
}
// TODO: this logic needs rework once other providers adopt the Responses endpoint
if (isOpenAIProvider(provider)) {
// Azure now supports websearch as well
if (isOpenAIProvider(provider) || isAzureOpenAIProvider(provider)) {
if (isOpenAIWebSearchModel(model)) {
return true
}
@@ -113,7 +94,7 @@ export function isWebSearchModel(model: Model): boolean {
}
}
if (isGeminiProvider(provider) || isVertexAiProvider(provider)) {
if (isGeminiProvider(provider) || isVertexProvider(provider)) {
return GEMINI_SEARCH_REGEX.test(modelId)
}

View File

@@ -59,15 +59,8 @@ import VoyageAIProviderLogo from '@renderer/assets/images/providers/voyageai.png
import XirangProviderLogo from '@renderer/assets/images/providers/xirang.png'
import ZeroOneProviderLogo from '@renderer/assets/images/providers/zero-one.png'
import ZhipuProviderLogo from '@renderer/assets/images/providers/zhipu.png'
import type {
AtLeast,
AzureOpenAIProvider,
Provider,
ProviderType,
SystemProvider,
SystemProviderId
} from '@renderer/types'
import { isSystemProvider, OpenAIServiceTiers, SystemProviderIds } from '@renderer/types'
import type { AtLeast, SystemProvider, SystemProviderId } from '@renderer/types'
import { OpenAIServiceTiers } from '@renderer/types'
import { TOKENFLUX_HOST } from './constant'
import { glm45FlashModel, qwen38bModel, SYSTEM_MODELS } from './models'
@@ -1441,153 +1434,3 @@ export const PROVIDER_URLS: Record<SystemProviderId, ProviderUrls> = {
}
}
}
const NOT_SUPPORT_ARRAY_CONTENT_PROVIDERS = [
'deepseek',
'baichuan',
'minimax',
'xirang',
'poe',
'cephalon'
] as const satisfies SystemProviderId[]
/**
 * Determines whether the provider supports array-typed message content. Only for OpenAI Chat Completions API.
*/
export const isSupportArrayContentProvider = (provider: Provider) => {
return (
provider.apiOptions?.isNotSupportArrayContent !== true &&
!NOT_SUPPORT_ARRAY_CONTENT_PROVIDERS.some((pid) => pid === provider.id)
)
}
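All of these capability helpers follow the same shape: a per-provider apiOptions override combined with a hard-coded deny (or allow) list. A trimmed sketch of the pattern:

type Provider = { id: string; apiOptions?: { isNotSupportArrayContent?: boolean } }
const DENY_LIST = ['deepseek', 'baichuan', 'minimax', 'xirang', 'poe', 'cephalon']

const isSupportArrayContentProvider = (provider: Provider): boolean =>
  provider.apiOptions?.isNotSupportArrayContent !== true && !DENY_LIST.includes(provider.id)

console.log(isSupportArrayContentProvider({ id: 'poe' })) // false
console.log(isSupportArrayContentProvider({ id: 'openai' })) // true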
const NOT_SUPPORT_DEVELOPER_ROLE_PROVIDERS = ['poe', 'qiniu'] as const satisfies SystemProviderId[]
/**
 * Determines whether the provider supports 'developer' as a message role. Only for OpenAI API.
*/
export const isSupportDeveloperRoleProvider = (provider: Provider) => {
return (
provider.apiOptions?.isSupportDeveloperRole === true ||
(isSystemProvider(provider) && !NOT_SUPPORT_DEVELOPER_ROLE_PROVIDERS.some((pid) => pid === provider.id))
)
}
const NOT_SUPPORT_STREAM_OPTIONS_PROVIDERS = ['mistral'] as const satisfies SystemProviderId[]
/**
 * Determines whether the provider supports the stream_options parameter. Only for OpenAI API.
*/
export const isSupportStreamOptionsProvider = (provider: Provider) => {
return (
provider.apiOptions?.isNotSupportStreamOptions !== true &&
!NOT_SUPPORT_STREAM_OPTIONS_PROVIDERS.some((pid) => pid === provider.id)
)
}
const NOT_SUPPORT_QWEN3_ENABLE_THINKING_PROVIDER = [
'ollama',
'lmstudio',
'nvidia'
] as const satisfies SystemProviderId[]
/**
 * Determines whether the provider supports the enable_thinking parameter for controlling thinking on models such as Qwen3. Only for OpenAI Chat Completions API.
*/
export const isSupportEnableThinkingProvider = (provider: Provider) => {
return (
provider.apiOptions?.isNotSupportEnableThinking !== true &&
!NOT_SUPPORT_QWEN3_ENABLE_THINKING_PROVIDER.some((pid) => pid === provider.id)
)
}
const NOT_SUPPORT_SERVICE_TIER_PROVIDERS = ['github', 'copilot', 'cerebras'] as const satisfies SystemProviderId[]
/**
 * Determines whether the provider supports the service_tier setting. Only for OpenAI API.
*/
export const isSupportServiceTierProvider = (provider: Provider) => {
return (
provider.apiOptions?.isSupportServiceTier === true ||
(isSystemProvider(provider) && !NOT_SUPPORT_SERVICE_TIER_PROVIDERS.some((pid) => pid === provider.id))
)
}
const SUPPORT_URL_CONTEXT_PROVIDER_TYPES = [
'gemini',
'vertexai',
'anthropic',
'new-api'
] as const satisfies ProviderType[]
export const isSupportUrlContextProvider = (provider: Provider) => {
return (
SUPPORT_URL_CONTEXT_PROVIDER_TYPES.some((type) => type === provider.type) ||
provider.id === SystemProviderIds.cherryin
)
}
const SUPPORT_GEMINI_NATIVE_WEB_SEARCH_PROVIDERS = ['gemini', 'vertexai'] as const satisfies SystemProviderId[]
/** Determines whether the provider uses Gemini's native search tools. Currently assumes only the official APIs use the native tools */
export const isGeminiWebSearchProvider = (provider: Provider) => {
return SUPPORT_GEMINI_NATIVE_WEB_SEARCH_PROVIDERS.some((id) => id === provider.id)
}
export const isNewApiProvider = (provider: Provider) => {
return ['new-api', 'cherryin'].includes(provider.id) || provider.type === 'new-api'
}
export function isCherryAIProvider(provider: Provider): boolean {
return provider.id === 'cherryai'
}
export function isPerplexityProvider(provider: Provider): boolean {
return provider.id === 'perplexity'
}
/**
 * Determines whether the provider is OpenAI-compatible
 * @param {Provider} provider The provider object
 * @returns {boolean} Whether the provider is OpenAI-compatible
*/
export function isOpenAICompatibleProvider(provider: Provider): boolean {
return ['openai', 'new-api', 'mistral'].includes(provider.type)
}
export function isAzureOpenAIProvider(provider: Provider): provider is AzureOpenAIProvider {
return provider.type === 'azure-openai'
}
export function isOpenAIProvider(provider: Provider): boolean {
return provider.type === 'openai-response'
}
export function isAnthropicProvider(provider: Provider): boolean {
return provider.type === 'anthropic'
}
export function isGeminiProvider(provider: Provider): boolean {
return provider.type === 'gemini'
}
export function isVertexAiProvider(provider: Provider): boolean {
return provider.type === 'vertexai'
}
export function isAIGatewayProvider(provider: Provider): boolean {
return provider.type === 'ai-gateway'
}
export function isAwsBedrockProvider(provider: Provider): boolean {
return provider.type === 'aws-bedrock'
}
const NOT_SUPPORT_API_VERSION_PROVIDERS = ['github', 'copilot', 'perplexity'] as const satisfies SystemProviderId[]
export const isSupportAPIVersionProvider = (provider: Provider) => {
if (isSystemProvider(provider)) {
return !NOT_SUPPORT_API_VERSION_PROVIDERS.some((pid) => pid === provider.id)
}
return provider.apiOptions?.isNotSupportAPIVersion !== false
}

View File

@@ -1,56 +0,0 @@
import type { ChatCompletionTool } from '@cherrystudio/openai/resources'
import type { Model } from '@renderer/types'
import { WEB_SEARCH_PROMPT_FOR_ZHIPU } from './prompts'
export function getWebSearchTools(model: Model): ChatCompletionTool[] {
if (model?.provider === 'zhipu') {
if (model.id === 'glm-4-alltools') {
return [
{
type: 'web_browser',
web_browser: {
browser: 'auto'
}
} as unknown as ChatCompletionTool
]
}
return [
{
type: 'web_search',
web_search: {
enable: true,
search_result: true,
search_prompt: WEB_SEARCH_PROMPT_FOR_ZHIPU
}
} as unknown as ChatCompletionTool
]
}
if (model?.id.includes('gemini')) {
return [
{
type: 'function',
function: {
name: 'googleSearch'
}
}
]
}
return []
}
export function getUrlContextTools(model: Model): ChatCompletionTool[] {
if (model.id.includes('gemini')) {
return [
{
type: 'function',
function: {
name: 'urlContext'
}
}
]
}
return []
}

View File

@@ -1,11 +1,31 @@
import { loggerService } from '@logger'
import { useAppDispatch, useAppSelector } from '@renderer/store'
import { setApiServerEnabled as setApiServerEnabledAction } from '@renderer/store/settings'
import { useCallback, useEffect, useState } from 'react'
import { useCallback, useEffect, useRef, useState } from 'react'
import { useTranslation } from 'react-i18next'
const logger = loggerService.withContext('useApiServer')
// Module-level single instance subscription to prevent EventEmitter memory leak
// Only one IPC listener will be registered regardless of how many components use this hook
const onReadyCallbacks = new Set<() => void>()
let removeIpcListener: (() => void) | null = null
const ensureIpcSubscribed = () => {
if (!removeIpcListener) {
removeIpcListener = window.api.apiServer.onReady(() => {
onReadyCallbacks.forEach((cb) => cb())
})
}
}
const cleanupIpcIfEmpty = () => {
if (onReadyCallbacks.size === 0 && removeIpcListener) {
removeIpcListener()
removeIpcListener = null
}
}
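For comparison, the same share-one-listener idea extracted into a generic helper (hypothetical names; not part of the diff):

// Generic sketch: many subscribers, at most one underlying IPC listener.
const callbacks = new Set<() => void>()
let unsubscribe: (() => void) | null = null

function subscribeShared(cb: () => void, attach: (fire: () => void) => () => void): () => void {
  if (!unsubscribe) unsubscribe = attach(() => callbacks.forEach((c) => c()))
  callbacks.add(cb)
  return () => {
    callbacks.delete(cb)
    if (callbacks.size === 0 && unsubscribe) {
      unsubscribe()
      unsubscribe = null
    }
  }
}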
export const useApiServer = () => {
const { t } = useTranslation()
// FIXME: We currently store two copies of the config data in both the renderer and the main processes,
@@ -102,15 +122,28 @@ export const useApiServer = () => {
checkApiServerStatus()
}, [checkApiServerStatus])
// Listen for API server ready event
// Use ref to keep the latest checkApiServerStatus without causing re-subscription
const checkStatusRef = useRef(checkApiServerStatus)
useEffect(() => {
const cleanup = window.api.apiServer.onReady(() => {
logger.info('API server ready event received, checking status')
checkApiServerStatus()
})
checkStatusRef.current = checkApiServerStatus
})
return cleanup
}, [checkApiServerStatus])
// Create stable callback for the single instance subscription
const handleReady = useCallback(() => {
logger.info('API server ready event received, checking status')
checkStatusRef.current()
}, [])
// Listen for API server ready event using single instance subscription
useEffect(() => {
ensureIpcSubscribed()
onReadyCallbacks.add(handleReady)
return () => {
onReadyCallbacks.delete(handleReady)
cleanupIpcIfEmpty()
}
}, [handleReady])
return {
apiServerConfig,

View File

@@ -175,14 +175,46 @@ export function useAppInit() {
useEffect(() => {
if (!window.electron?.ipcRenderer) return
const requestListener = (_event: Electron.IpcRendererEvent, payload: ToolPermissionRequestPayload) => {
const requestListener = async (_event: Electron.IpcRendererEvent, payload: ToolPermissionRequestPayload) => {
logger.debug('Renderer received tool permission request', {
requestId: payload.requestId,
toolName: payload.toolName,
expiresAt: payload.expiresAt,
suggestionCount: payload.suggestions.length
suggestionCount: payload.suggestions.length,
autoApprove: payload.autoApprove
})
dispatch(toolPermissionsActions.requestReceived(payload))
// Auto-approve if requested
if (payload.autoApprove) {
logger.debug('Auto-approving tool permission request', {
requestId: payload.requestId,
toolName: payload.toolName
})
dispatch(toolPermissionsActions.submissionSent({ requestId: payload.requestId, behavior: 'allow' }))
try {
const response = await window.api.agentTools.respondToPermission({
requestId: payload.requestId,
behavior: 'allow',
updatedInput: payload.input,
updatedPermissions: payload.suggestions
})
if (!response?.success) {
throw new Error('Auto-approval response rejected by main process')
}
logger.debug('Auto-approval acknowledged by main process', {
requestId: payload.requestId,
toolName: payload.toolName
})
} catch (error) {
logger.error('Failed to send auto-approval response', error as Error)
dispatch(toolPermissionsActions.submissionFailed({ requestId: payload.requestId }))
}
}
}
const resultListener = (_event: Electron.IpcRendererEvent, payload: ToolPermissionResultPayload) => {

View File

@@ -38,13 +38,6 @@ export function getVertexAIServiceAccount() {
return store.getState().llm.settings.vertexai.serviceAccount
}
/**
 * Type guard: checks whether a Provider is a VertexProvider
*/
export function isVertexProvider(provider: Provider): provider is VertexProvider {
return provider.type === 'vertexai'
}
/**
 * Creates a VertexProvider object, merging in the standalone configuration
 * @param baseProvider The base provider configuration

Some files were not shown because too many files have changed in this diff.