Compare commits

...

56 Commits

Author SHA1 Message Date
suyao
22ed2b605e feat: Add MCP UI Demo server and integrate UI rendering capabilities 2025-11-24 12:41:03 +08:00
SuYao
2c3338939e feat: update Google and OpenAI SDKs with new features and fixes (#11395)
* feat: update Google and OpenAI SDKs with new features and fixes

- Updated Google SDK to ensure model paths are correctly formatted.
- Enhanced OpenAI SDK to include support for image URLs in chat responses.
- Added reasoning content handling in OpenAI chat responses and chunks.
- Introduced Azure Anthropic provider configuration for Claude integration.

* fix: azure error

* fix: lint

* fix: test

* fix: test

* fix type

* fix comment

* fix: redundant

* chore resolution

* fix: test

* fix: comment

* fix: comment

* fix

* feat: 添加 OpenRouter 推理中间件以支持内容过滤
2025-11-23 23:18:57 +08:00
槑囿脑袋
64ca3802a4 feat: support gemini 3 pro image preview (#11416)
feat: support gemini 3 pro preview
2025-11-23 21:40:22 +08:00
Phantom
fa361126b8 refactor: aisdk config (#11402)
* refactor: improve model filtering with todo for robust conversion

* refactor(aiCore): add AiSdkConfig type and update provider config handling

- Introduce new AiSdkConfig type in aiCoreTypes for better type safety
- Update provider factory and config to use AiSdkConfig consistently
- Simplify getAiSdkProviderId return type to string
- Add config validation in ModernAiProvider

* refactor(aiCore): move ai core types to dedicated module

Consolidate AI core type definitions into a dedicated module under aiCore/types. This improves code organization by keeping related types together and removes circular dependencies between modules. The change includes:
- Moving AiSdkConfig to aiCore/types
- Updating all imports to reference the new location
- Removing duplicate type definitions

* refactor(provider): add return type to createAiSdkProvider function
2025-11-23 21:12:57 +08:00
SuYao
49903a1567 Test/ai-core (#11307)
* test: 1

* test: 2

* test: 3

* format

* chore: move provider from config to utils

* fix: 4

* test: 5

* chore: redundant logic

* test: add reasoning model tests and improve provider options typings

* chore: format

* test 6

* chore: format

* test: 7

* test: 8

* fix: test

* fix: format and typecheck

* fix error

* test: isClaude4SeriesModel

* fix: test

* fix: test

---------

Co-authored-by: defi-failure <159208748+defi-failure@users.noreply.github.com>
2025-11-23 17:33:27 +08:00
Phantom
086b16a59c ci: update PR title in auto-i18n workflow to be more specific (#11406) 2025-11-23 11:48:44 +08:00
github-actions[bot]
e2562d8224 🤖 Weekly Automated Update: Nov 23, 2025 (#11412)
feat(bot): Weekly automated script run

Co-authored-by: EurFelux <59059173+EurFelux@users.noreply.github.com>
2025-11-23 11:47:54 +08:00
Phantom
c9be949853 fix: adjacent user messages appear when assistant message contains error only (#11390)
* feat(messages): add filter for error-only messages and their related pairs

Add new filter function to remove assistant messages containing only error blocks along with their associated user messages, identified by askId. This improves conversation quality by cleaning up error-only responses.

* refactor(ConversationService): improve message filtering pipeline readability

Break down complex message filtering chain into clearly labeled steps
Add comments explaining each filtering step's purpose
Maintain same functionality while improving code maintainability

* test(messageUtils): add test cases for message filter utilities

* docs(messageUtils): correct jsdoc for filterUsefulMessages

* refactor(ConversationService): extract message filtering logic into pipeline method

Move message filtering steps into a dedicated static method to improve testability and maintainability. Add comprehensive tests to verify pipeline behavior.

* refactor(ConversationService): add logging and improve message filtering readability

Add logger service to track message pipeline output
Split filterUserRoleStartMessages into separate variable for better debugging
2025-11-22 23:00:13 +08:00
defi-failure
ebfb1c5abf fix: add missing execution state for approved tool permissions (#11394) 2025-11-22 21:45:42 +08:00
SuYao
c1f1d7996d test: add thinking budget token test (#11305)
* refactor: add thinking budget token test

* fix comment
2025-11-22 21:43:57 +08:00
Phantom
0a72c613af fix(openai): apply verbosity setting with type safety improvements (#10964)
* refactor(types): consolidate OpenAI types and improve type safety

- Move OpenAI-related types to aiCoreTypes.ts
- Rename FetchChatCompletionOptions to FetchChatCompletionRequestOptions
- Add proper type definitions for service tiers and verbosity
- Improve type guards for service tier checks

* refactor(api): rename options parameter to requestOptions for consistency

Update parameter name across multiple files to use requestOptions instead of options for better clarity and consistency in API calls

* refactor(aiCore): simplify OpenAI summary text handling and improve type safety

- Remove 'off' option from OpenAISummaryText type and use null instead
- Add migration to convert 'off' values to null
- Add utility function to convert undefined to null
- Update Selector component to handle null/undefined values
- Improve type safety in provider options and reasoning params

* fix(i18n): Auto update translations for PR #10964

* feat(utils): add notNull function to convert null to undefined

* refactor(utils): move defined and notNull functions to shared package

Consolidate utility functions into shared package to improve code organization and reuse

* Revert "fix(i18n): Auto update translations for PR #10964"

This reverts commit 68bd7eaac5.

* feat(i18n): add "off" translation and remove "performance" tier

Add "off" translation for multiple languages and remove "performance" service tier option from translations

* Apply suggestion from @EurFelux

* docs(types): clarify handling of undefined and null values

Add comments to explain that undefined is treated as default and null as explicitly off in OpenAIVerbosity and OpenAIServiceTier types. Also update type safety for OpenAIServiceTiers record.

* fix(migration): update migration version from 167 to 171 for removed type

* chore: update store version to 172

* fix(migrate): update migration version number from 171 to 172

* fix(i18n): Auto update translations for PR #10964

* refactor(types): improve type safety for verbosity handling

add NotUndefined and NotNull utility types to better handle null/undefined cases
clarify verbosity types in aiCoreTypes and update related utility functions

* refactor(types): replace null with undefined for verbosity values

Standardize on undefined instead of null for verbosity values to align with OpenAI API docs and improve type consistency

* refactor(aiCore): update OpenAI provider options type import and usage

* fix(openai): change summaryText default from null to 'auto'

Update OpenAI settings to use 'auto' as default summaryText value instead of null for consistency with API behavior. Remove 'off' option and add 'concise' option while maintaining type safety.

* refactor(OpenAISettingsGroup): extract service tier options type for better maintainability

* refactor(types): make SystemProviderIdTypeMap internal type

* docs(provider): clarify OpenAIServiceTier behavior for undefined vs null

Explain that undefined and null values for serviceTier should be treated differently since they affect whether the field appears in the response

* refactor(utils): rename utility functions for clarity

Rename `defined` to `toNullIfUndefined` and `notNull` to `toUndefinedIfNull` to better reflect their functionality

* refactor(aiCore): extract service tier logic and improve type safety

Extract service tier validation logic into separate functions for better reusability
Add proper type annotations for provider options
Pass service tier parameter through provider option builders

* refactor(utils): comment out unused utility functions

Keep commented utility functions for potential future use while cleaning up current codebase

* fix(migration): update migration version number from 172 to 177

* docs(aiCoreTypes): clarify parameter passing behavior in OpenAI API

Update comments to consistently use 'undefined' instead of 'null' when describing parameter passing behavior in OpenAI API requests, as they share the same meaning in this context

---------

Co-authored-by: GitHub Action <action@github.com>
2025-11-22 21:41:12 +08:00
SuYao
a1ac3207f1 fix/anthropic-vertex (#11397)
* 100m

* feat: add web search header for Claude 4 series models

* fix: typo

* fix: identify model

---------

Co-authored-by: defi-failure <159208748+defi-failure@users.noreply.github.com>
2025-11-22 20:56:05 +08:00
Caelan
f98a063a8f Fix the issue where base64 images cannot be saved (#11398) 2025-11-22 20:20:02 +08:00
亢奋猫
1cb2af57ae refactor: optimize DatabaseManager and fix libsql crash issues (#11392)
* refactor: optimize DatabaseManager and fix libsql crash issues

Major improvements:
- Created DatabaseManager singleton to centralize database connection management
- Auto-initialize database in constructor (no manual initialization needed)
- Removed all manual initialize() and ensureInitialized() calls (47 occurrences)
- Simplified initialization logic (removed retry loops that could cause crashes)
- Removed unused close() and reinitialize() methods
- Reduced code from ~270 lines to 172 lines (-36%)

Key changes:
1. DatabaseManager.ts (new file):
   - Singleton pattern with auto-initialization
   - State management (INITIALIZING, INITIALIZED, FAILED)
   - Windows compatibility fixes (empty file detection, intMode: 'number')
   - Simplified waitForInitialization() logic

2. BaseService.ts:
   - Removed static initialize() and ensureInitialized() methods
   - Simplified database/rawClient getters to use DatabaseManager

3. Service classes (AgentService, SessionService, SessionMessageService):
   - Removed all initialize() methods
   - Removed all ensureInitialized() calls
   - Services now work out of the box

4. Main entry points (index.ts, server.ts):
   - Removed explicit database initialization calls
   - Database initializes automatically on first access

Benefits:
- Fixes Windows libsql crashes by removing dangerous retry logic
- Simpler API - no need to remember to call initialize()
- Better separation of concerns
- Cleaner codebase with 36% less code

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>

* fix: wait for database initialization on app startup

Issue: "Database is still initializing" error on startup
Root cause: Synchronous database getter was called before async initialization completed

Solution:
- Explicitly wait for database initialization in main index.ts
- Import DatabaseManager and call getDatabase() to ensure initialization is complete
- This guarantees database is ready before any service methods are called

Changes:
- src/main/index.ts: Added explicit database initialization wait before API server check

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>

* refactor: use static import for getDatabaseManager

- Move import to top of file for better code organization
- Remove unnecessary dynamic import

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>

* refactor: streamline database access in service classes

- Replaced direct database access with asynchronous calls to getDatabase() in various service classes (AgentService, SessionService, SessionMessageService).
- Updated the main index.ts to utilize runAsyncFunction for API server initialization, ensuring proper handling of asynchronous database access.
- Improved code organization and readability by consolidating database access logic.

This change enhances the reliability of database interactions across the application and ensures that services are correctly initialized before use.

* refactor: remove redundant logging in ApiServer initialization

- Removed the logging statement for 'AgentService ready' during server initialization.
- This change streamlines the startup process by eliminating unnecessary log entries.

This update contributes to cleaner logs and improved readability during server startup.

* refactor: change getDatabase method to synchronous return type

- Updated the getDatabase method in DatabaseManager to return a synchronous LibSQLDatabase instance instead of a Promise.
- This change simplifies the database access pattern, aligning with the current initialization logic.

This refactor enhances code clarity and reduces unnecessary asynchronous handling in the database access layer.

* refactor: simplify sessionMessageRepository by removing transaction handling

- Removed transaction handling parameters from message persistence methods in sessionMessageRepository.
- Updated database access to use a direct call to getDatabase() instead of passing a transaction client.
- Streamlined the upsertMessage and persistExchange methods for improved clarity and reduced complexity.

This refactor enhances code readability and simplifies the database interaction logic.

---------

Co-authored-by: Claude <noreply@anthropic.com>
2025-11-22 09:12:11 +08:00
fullex
62309ae1bf fix: prevent EventEmitter memory leak in useApiServer hook (#11385)
Implement single instance IPC subscription pattern to resolve MaxListenersExceededWarning. Previously, each component using useApiServer would register a separate 'api-server:ready' listener, and React strict mode double rendering would quickly exceed the 10 listener limit.

Changes:
- Add module-level subscription manager with onReadyCallbacks Set
- Ensure only one IPC listener is registered regardless of component count
- Use useRef to maintain stable callback references
- Properly cleanup subscriptions when all components unmount

This maintains existing behavior while keeping listener count constant at 1.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-authored-by: Claude <noreply@anthropic.com>
2025-11-21 21:42:34 +08:00
defi-failure
c48f222cdb feat: add endpoint type support for cherryin provider (#11367)
* feat: add endpoint type support for cherryin provider

* chore: bump @cherrystudio/ai-sdk-provider version to 0.1.1

* chore: bump ai-sdk-provider version to 0.1.3
2025-11-21 21:42:08 +08:00
亢奋猫
cea0058f87 refactor: simplify knowledge base creation modal (#11371)
* test(knowledge): fix tests for knowledge base form modal refactoring

Update all test files to match the new vertical layout structure with button-based advanced settings toggle. Remove obsolete tests for deleted features.

Changes:
- Rewrite KnowledgeBaseFormModal.test.tsx for new button-toggle structure
- Remove tests for preprocess and rerank features from GeneralSettingsPanel
- Update AdvancedSettingsPanel tests with required props
- Update all snapshots to reflect new component structure
- Format test files according to biome rules

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>

* test(knowledge): simplify KnowledgeBaseFormModal button tests

Simplify button interaction tests to avoid text matching issues. Focus on testing behavior rather than implementation details.

Changes:
- Simplify advanced settings toggle test
- Simplify footer buttons test to check button count instead of text content
- Remove fragile text-based button selection

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>

---------

Co-authored-by: Claude <noreply@anthropic.com>
2025-11-21 21:34:34 +08:00
beyondkmp
852192dce6 feat: add Git Bash detection and requirement check for Windows agents (#11388)
* feat: add Git Bash detection and requirement check for Windows agents

- Add System_CheckGitBash IPC channel for detecting Git Bash installation
- Implement detection logic checking common installation paths and PATH environment
- Display non-closable error alert in AgentModal when Git Bash is not found
- Disable agent creation/edit button until Git Bash is installed
- Add recheck functionality to verify installation without restarting app

Git Bash is required for agents to function properly on Windows systems.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>

* i18n: add Git Bash requirement translations for agent modal

- Add English translations for Git Bash detection warnings
- Add Simplified Chinese (zh-cn) translations
- Add Traditional Chinese (zh-tw) translations

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>

* format code

---------

Co-authored-by: Claude <noreply@anthropic.com>
2025-11-21 21:32:53 +08:00
Pleasure1234
eee49d1580 feat: add ChatGPT conversation import feature (#11272)
* feat: add ChatGPT conversation import feature

Introduces a new import workflow for ChatGPT conversations, including UI components, service logic, and i18n support for English, Simplified Chinese, and Traditional Chinese. Adds an import menu to data settings, a popup for file selection and progress, and a service to parse and store imported conversations as topics and messages.

* fix: ci failure

* refactor: import service and add modular importers

Refactored the import service to support a modular importer architecture. Moved ChatGPT import logic to a dedicated importer class and directory. Updated UI components and i18n descriptions for clarity. Removed unused Redux selector in ImportMenuSettings. This change enables easier addition of new importers in the future.

* Apply suggestion from @Copilot

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* fix: improve ChatGPT import UX and set model for assistant

Added a loading state and spinner for file selection in the ChatGPT import popup, with new translations for the 'selecting' state in en-us, zh-cn, and zh-tw locales. Also, set the model property for imported assistant messages to display the GPT-5 logo.

---------

Co-authored-by: SuYao <sy20010504@gmail.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2025-11-21 14:58:47 +08:00
SuYao
dcdd1bf852 refactor: replace renderToolContent function with ToolContent component for improved readability (#11300)
* refactor: replace renderToolContent function with ToolContent component for improved readability

* fix

* fix test
2025-11-21 09:55:46 +08:00
beyondkmp
a12b6bfeca feat: enable native language emoji search with CLDR data format (#11381)
* feat: add i18n support and local data to emoji picker

- Add emoji-picker-element-data package for offline-first emoji data
- Implement i18n translations for emoji picker UI (de, en, es, fr, ja, pt, ru, zh)
- Switch from CDN to local emoji data to improve performance and reliability
- Add locale mapping to match app language with emoji picker data
- Move emoji-picker-element import to EmojiPicker component for better encapsulation
- Use proper TypeScript types instead of 'any' for type safety

This improves user experience by providing localized emoji picker interface
and eliminating dependency on external CDN, ensuring the picker works offline.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>

* feat: enable native language emoji search with CLDR data format

Switch from emojibase to CLDR format for emoji-picker-element data to support full multi-language search functionality. Users can now search for emojis in their native language (e.g., German users can search "Herz" for ❤️, Spanish users can search "corazón"). Also improves type safety by using the LanguageVarious type for locale mappings.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>

---------

Co-authored-by: Claude <noreply@anthropic.com>
2025-11-20 19:23:27 +08:00
亢奋猫
0f1a487bb0 refactor: simplify agent creation form (#11369)
* refactor(AgentModal): simplify agent type handling and update default values

- Removed unused agent type options and related logic.
- Updated default agent name from 'Claude Code' to 'Agent'.
- Adjusted padding in button styles and textarea rows for better UI consistency.
- Cleaned up unnecessary imports and code comments for improved readability.

* refactor(AgentSettings): clean up and enhance name setting component

- Removed unused imports and commented-out code in AgentModal and EssentialSettings.
- Updated NameSetting to include an emoji avatar picker for enhanced user experience.
- Simplified the logic for updating the agent's name and avatar.
- Improved overall readability and maintainability of the code.
2025-11-20 10:42:49 +08:00
亢奋猫
2df8bb58df fix: remove light background from MCP NpxUv install alerts (#11372)
- Remove 'banner' prop from Alert components in InstallNpxUv
- Set SettingContainer background to 'inherit' in MCP settings
- Fixes the light background color issue in NpxUv interface

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-authored-by: Claude <noreply@anthropic.com>
2025-11-20 10:41:41 +08:00
defi-failure
62976f6fe0 refactor: namespace tool call ids with session id to prevent conflicts (#11319) 2025-11-20 10:35:11 +08:00
MyPrototypeWhat
77529b3cd3 chore: update ai-core release scripts and bump version to 1.0.7 (#11370)
* chore: update ai-core release scripts and bump version to 1.0.7

* chore: update ai-sdk-provider release script to include build step and enhance type exports in webSearchPlugin and providers

* chore: bump @cherrystudio/ai-core version to 1.0.8 and update dependencies in package.json and yarn.lock

* chore: bump @cherrystudio/ai-core version to 1.0.9 and @cherrystudio/ai-sdk-provider version to 0.1.2 in package.json and yarn.lock

---------

Co-authored-by: suyao <sy20010504@gmail.com>
2025-11-19 20:44:22 +08:00
SuYao
c8e9a10190 bump ai core version (#11363)
* bump ai core version

* chore

* chore: add patch for @ai-sdk/openai and update peer dependencies in aiCore

* chore: update installation instructions in README to include @ai-sdk/google and @ai-sdk/openai

* chore: bump @cherrystudio/ai-core version to 1.0.6 in package.json and yarn.lock

---------

Co-authored-by: MyPrototypeWhat <daoquqiexing@gmail.com>
2025-11-19 18:13:33 +08:00
scientia
0e011ff35f fix: fix api-host for vercel ai-gateway provider (#11321)
Co-authored-by: scientia <wangdenghui@xiaomi.com>
2025-11-19 17:11:17 +08:00
MyPrototypeWhat
40a64a7c92 feat(options): enhance provider key handling for cherryin in buildPro… (#11361)
feat(options): enhance provider key handling for cherryin in buildProviderOptions function
2025-11-19 16:25:29 +08:00
Phantom
dc9503ef8b feat: support gemini 3 (#11356)
* feat(reasoning): add support for gemini-3-pro-preview model

Update regex pattern to include gemini-3-pro-preview as a supported thinking model
Add tests for new gemini-3 model support and edge cases

* fix(reasoning): update gemini model regex to include stable versions

Add support for stable versions of gemini-3-flash and gemini-3-pro in the model regex pattern. Update tests to verify both preview and stable versions are correctly identified.

* feat(providers): add vertexai provider check function

Add isVertexAiProvider function to consistently check for vertexai provider type and use it in websearch model detection

* feat(websearch): update gemini search regex to include v3 models

Add support for gemini 3.x models in the search regex pattern, including preview versions

* feat(vision): add support for gemini-3 models and add tests

Add regex pattern for gemini-3 models in visionAllowedModels
Create comprehensive test suite for isVisionModel function

* refactor(vision): make vision-related model constants private

Remove unused isNotSupportedImageSizeModel function and change exports to const declarations for internal use only

* chore(deps): update @ai-sdk/google to v2.0.36 and related dependencies

update @ai-sdk/google dependency from v2.0.31 to v2.0.36 to include fixes for model path handling and tool support for newer Gemini models

* chore: remove outdated @ai-sdk-google patch file

* chore: remove outdated @ai-sdk/google patch dependency
2025-11-19 14:05:14 +08:00
beyondkmp
f2c8484c48 feat: enable local crash mini dump file (#11348)
* feat: enabel loca crash mini file dump

* update version
2025-11-18 18:27:57 +08:00
kangfenmao
a9c9224835 fix(migrate): update anthropicApiHost for qiniu and longcat providers in migration to version 176
- Added anthropicApiHost configuration for qiniu and longcat providers during state migration.
- Incremented version number in persistedReducer to 176.
- Ensured proper handling of reasoning_effort settings during migration.
2025-11-18 11:05:46 +08:00
caoli5288
43223fd1f5 feat(config): add anthropicApiHost for qiniu and longcat providers (#11335) 2025-11-18 10:10:59 +08:00
Phantom
4bac843b37 fix(InputbarCore): prevent message send when cannotSend is true (#11337)
Add cannotSend check to prevent message sending when conditions aren't met
2025-11-18 10:08:54 +08:00
Phantom
34723934f4 fix: use function as default tool use mode (#11338)
* refactor(assistant): change default tool use mode to function and use default settings

Simplify reset logic by using DEFAULT_ASSISTANT_SETTINGS object instead of hardcoded values

* fix(ApiService): safely fallback to prompt tool use for unsupported models

Add check for function calling model support before using tool use mode to prevent errors with unsupported models.
2025-11-17 23:28:43 +08:00
defi-failure
096c36caf8 fix: improve todo tool status icon visibility and colors (#11323) 2025-11-17 14:01:27 +08:00
beyondkmp
139950e193 fix(i18n): add input placeholder translations for multiple languages (#11320)
feat(i18n): add input placeholder translations for multiple languages

- Introduced a new placeholder for the input field in various language files, providing guidance on message entry and command selection.
- Updated English, Chinese (Simplified and Traditional), German, Greek, Spanish, French, Japanese, Portuguese, and Russian translations to include the new input placeholder text.
- Adjusted the reference in the AgentSessionInputbar component to use the new translation key for consistency.
2025-11-17 11:51:04 +08:00
SuYao
31eec403f7 fix: url context and web search capability (#11306)
* fix: enhance support for interleaved thinking and model compatibility

* fix: type
2025-11-17 10:53:47 +08:00
槑囿脑袋
7fd4837a47 fix: mineru validate pdf error and 403 error (#11312)
* fix: validate pdf error

* fix: net fetch error

* fix: mineru 403 error

* chore: change comment to english

* fix: format
2025-11-16 16:02:15 +00:00
Carlton
90b0c8b4a6 fix: resolve "no such file" error when processing non-English filenames in open-mineru (#11315) 2025-11-16 22:10:43 +08:00
github-actions[bot]
556353e910 docs: Weekly Automated Update: Nov 16, 2025 (#11308)
feat(bot): Weekly automated script run

Co-authored-by: EurFelux <59059173+EurFelux@users.noreply.github.com>
2025-11-16 10:57:32 +08:00
Copilot
11fb730b4d fix: add verbosity parameter support for GPT-5 models across legacy and modern AI SDK (#11281)
* Initial plan

* feat: add verbosity parameter support for GPT-5 models in OpenAIAPIClient

Co-authored-by: DeJeune <67425183+DeJeune@users.noreply.github.com>

* fix: ensure gpt-5-pro always uses 'high' verbosity

Co-authored-by: DeJeune <67425183+DeJeune@users.noreply.github.com>

* refactor: move verbosity configuration to config/models as suggested

Co-authored-by: DeJeune <67425183+DeJeune@users.noreply.github.com>

* refactor: encapsulate verbosity logic in getVerbosity method

Co-authored-by: DeJeune <67425183+DeJeune@users.noreply.github.com>

* feat: add support for verbosity and reasoning options for GPT-5 Pro and GPT-5.1 models

* fix comment

* build: add @ai-sdk/google dependency

Add the @ai-sdk/google package to support Google AI SDK integration

* build: add @ai-sdk/anthropic dependency

* refactor(aiCore): update reasoning params handling for AI providers

- Add type imports for provider options
- Handle 'none' reasoning effort consistently across providers
- Improve type safety by using Pick with provider options
- Standardize disabled reasoning config for all providers

* fix: adjust none effort ratio from 0 to 0.01

Prevent potential division by zero errors by ensuring none effort ratio has a small positive value

* feat(reasoning): add support for GPT-5.1 series models

Handle 'none' reasoning effort for GPT-5.1 models and add model type check

* Update src/renderer/src/aiCore/utils/reasoning.ts

---------

Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: DeJeune <67425183+DeJeune@users.noreply.github.com>
Co-authored-by: suyao <sy20010504@gmail.com>
Co-authored-by: icarus <eurfelux@gmail.com>
2025-11-16 10:22:14 +08:00
Phantom
2511113b62 feat: support gpt-5.1 (#11294)
* build: update @cherrystudio/openai dependency from v6.5.0 to v6.9.0

* refactor(reasoning): replace 'off' with 'none' for reasoning effort option

Update reasoning effort option from 'off' to 'none' across multiple files for consistency
Add support for gpt5_1 model with reasoning effort options

* fix(openai): handle apply_patch_call and apply_patch_call_output in response conversion

Filter and properly handle apply_patch_call and apply_patch_call_output types in OpenAI response conversion. Ensure undefined/null values are handled appropriately and log warnings for missing required fields.

* feat(models): add gpt-5.1 model logo and configuration

* fix(providers): include cherryin in url context provider check

Add SystemProviderIds.cherryin to the list of providers that support URL context to ensure proper functionality

* feat(models): add logo images for gpt-5.1 model variants

* feat(model): add support for GPT-5.1 series models

- Add new model type check for GPT-5.1 series
- Update reasoning effort and verbosity checks to include GPT-5.1
- Add logging to provider options builder

* feat(models): add gpt5_1_codex model support

Add new model type 'gpt5_1_codex' to ThinkModelTypes and configure its reasoning effort levels
Update model type detection logic to handle gpt5_1_codex variant
2025-11-15 19:09:43 +08:00
beyondkmp
a29b2bb3d6 chore: update @opeoginni/github-copilot-openai-compatible to support gpt5.1 (#11299)
* chore: update @opeoginni/github-copilot-openai-compatible to version 0.1.21

- Updated package version in package.json and yarn.lock.
- Refactored OpenAIBaseClient to enhance getBaseURL method and improve header management for SDK instances.

* format
2025-11-15 19:07:16 +08:00
beyondkmp
d2be450906 fix: update gitcode update config url (#11298)
* fix: update gitcode update config url

* update version

---------

Co-authored-by: Payne Fu <payne@Paynes-MacBook-Pro.local>
2025-11-15 10:01:33 +08:00
kangfenmao
9c020f0d56 docs: update release notes for v1.7.0-rc.1
Add comprehensive release notes highlighting:
- AI Agent system as the major new feature
- New AI providers support (Hugging Face, Mistral, Perplexity, SophNet)
- Knowledge base enhancements (OpenMinerU, full-text search)
- Image & OCR improvements (Intel OVMS, OpenVINO NPU)
- MCP management interface redesign with dual-column layout
- German language support
- Electron 38.7.0 upgrade and system improvements
- Important bug fixes
2025-11-14 20:04:16 +08:00
fullex
e033eb5b5c Add CODEOWNER for app-upgrade-config.json 2025-11-14 19:02:03 +08:00
beyondkmp
073d43c7cb chore: rename cs-releases to x-files/app-upgrade-config (#11290)
rename cs-releases to x-files/app-upgrade-config
2025-11-14 18:53:11 +08:00
kangfenmao
fa7646e18f feat: enhance DynamicVirtualList with header and className props
- Added `header` prop to display content above the list.
- Introduced `className` prop for additional styling on the container.
- Updated `Sessions` component to utilize `StyledVirtualList` with new props for improved layout and functionality.
2025-11-14 18:29:33 +08:00
beyondkmp
038d30831c ♻️ refactor: implement config-based update system with version compatibility control (#11147)
* ♻️ refactor: implement config-based update system with version compatibility control

Replace GitHub API-based update discovery with JSON config file system. Support
version gating (users below v1.7 must upgrade to v1.7.0 before v2.0). Auto-select
GitHub/GitCode config source based on IP location. Simplify fallback logic.

Changes:
- Add update-config.json with version compatibility rules
- Implement _fetchUpdateConfig() and _findCompatibleChannel()
- Remove legacy _getReleaseVersionFromGithub() and GitHub API dependency
- Refactor _setFeedUrl() with simplified fallback to default feed URLs
- Add design documentation in docs/UPDATE_CONFIG_DESIGN.md

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>

* fix(i18n): Auto update translations for PR #11147

* format code

* 🔧 chore: update config for v1.7.5 → v2.0.0 → v2.1.6 upgrade path

Update version configuration to support multi-step upgrade path:
- v1.6.x users → v1.7.5 (last v1.x release)
- v1.7.x users → v2.0.0 (v2.x intermediate version)
- v2.0.0+ users → v2.1.6 (current latest)

Changes:
- Update 1.7.0 → 1.7.5 with fixed feedUrl
- Set 2.0.0 as intermediate version with fixed feedUrl
- Add 2.1.6 as current latest pointing to releases/latest

This ensures users upgrade through required intermediate versions
before jumping to major releases.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>

* 🔧 chore: refactor update config with constants and adjust versions

Refactor update configuration system and adjust to actual versions:

- Add UpdateConfigUrl enum in constant.ts for centralized config URLs
- Point to test server (birdcat.top) for development testing
- Update AppUpdater.ts to use UpdateConfigUrl constants
- Adjust update-config.json to actual v1.6.7 with rc/beta channels
- Remove v2.1.6 entry (not yet released)
- Set package version to 1.6.5 for testing upgrade path
- Add update-config.example.json for reference

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>

* update version

*  test: add comprehensive unit tests for AppUpdater config system

Add extensive test coverage for new config-based update system including:
- Config fetching with IP-based source selection (GitHub/GitCode)
- Channel compatibility matching with version constraints
- Smart fallback from rc/beta to latest when appropriate
- Multi-step upgrade path validation (1.6.3 → 1.6.7 → 2.0.0)
- Error handling for network and HTTP failures

Test Coverage:
- _fetchUpdateConfig: 4 tests (GitHub/GitCode selection, error handling)
- _findCompatibleChannel: 9 tests (channel matching, version comparison)
- Upgrade Path: 3 tests (version gating scenarios)
- Total: 30 tests, 100% passing

Also optimize _findCompatibleChannel logic with better variable naming
and log messages.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>

*  test: add complete multi-step upgrade path tests (1.6.3 → 1.7.5 → 2.0.0 → 2.1.6)

Add comprehensive test suite for complete upgrade journey including:
- Individual step validation (1.6.3→1.7.5, 1.7.5→2.0.0, 2.0.0→2.1.6)
- Full multi-step upgrade simulation with version progression
- Version gating enforcement (block skipping intermediate versions)
- Verification that 1.6.3 cannot directly upgrade to 2.0.0 or 2.1.6
- Verification that 1.7.5 cannot skip 2.0.0 to reach 2.1.6

Test Coverage:
- 6 new tests for complete upgrade path scenarios
- Total: 36 tests, 100% passing

This ensures the version compatibility system correctly enforces
intermediate version upgrades for major releases.

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>

* 📝 docs: reorganize update config documentation with English translation

Move update configuration design document to docs/technical/ directory
and add English translation for international contributors.

Changes:
- Move docs/UPDATE_CONFIG_DESIGN.md → docs/technical/app-update-config-zh.md
- Add docs/technical/app-update-config-en.md (English translation)
- Organize technical documentation in dedicated directory

Documentation covers:
- Config-based update system design and rationale
- JSON schema with version compatibility control
- Multi-step upgrade path examples (1.6.3 → 1.7.5 → 2.0.0 → 2.1.6)
- TypeScript type definitions and matching algorithms
- GitHub/GitCode source selection for different regions

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>

* format code

*  test: add tests for latest channel self-comparison prevention

Add tests to verify the optimization that prevents comparing latest
channel with itself when latest is requested, and ensures rc/beta
channels are returned when they are newer than latest.

New tests:
- should not compare latest with itself when requesting latest channel
- should return rc when rc version > latest version
- should return beta when beta version > latest version

These tests ensure the requestedChannel !== UpgradeChannel.LATEST
check works correctly and users get the right channel based on
version comparisons.

Test Coverage: 39 tests, 100% passing

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>

* update github/gitcode

* format code

* update rc version

* ♻️ refactor: merge update configs into single multi-mirror file

- Merge app-upgrade-config-github.json and app-upgrade-config-gitcode.json into single app-upgrade-config.json
- Add UpdateMirror enum for type-safe mirror selection
- Optimize _fetchUpdateConfig to receive mirror parameter, eliminating duplicate IP country checks
- Update ChannelConfig interface to use Record<UpdateMirror, string> for feedUrls
- Rename documentation files from app-update-config-* to app-upgrade-config-*
- Update docs with new multi-mirror configuration structure

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>

*  test: update AppUpdater tests for multi-mirror configuration

- Add UpdateMirror enum import
- Update _fetchUpdateConfig tests to accept mirror parameter
- Convert all feedUrl to feedUrls structure in test mocks
- Update test expectations to match new ChannelConfig interface
- All 39 tests passing

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>

* format code

* delete files

* 📝 docs: add UpdateMirror enum to type definitions

- Add UpdateMirror enum definition in both EN and ZH docs
- Update ChannelConfig to use Record<UpdateMirror, string>
- Add comments showing equivalent structure for clarity

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>

* 🐛 fix: return actual channel from _findCompatibleChannel

Fix channel mismatch issue where requesting rc/beta but getting latest:
- Change _findCompatibleChannel return type to include actual channel
- Return { config, channel } instead of just config
- Update _setFeedUrl to use actualChannel instead of requestedChannel
- Update all test expectations to match new return structure
- Add channel assertions to key tests

This ensures autoUpdater.channel matches the actual feed URL being used.

Fixes issue where:
- User requests 'rc' channel
- latest >= rc, so latest config is returned
- But channel was set to 'rc' with latest URL 
- Now channel is correctly set to 'latest' 

All 39 tests passing 

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>

* update version

* udpate version

* update config

* add no cache header

* update files

* 🤖 chore: automate app upgrade config updates

* format code

* update workflow

* update get method

* docs: document upgrade workflow automation

---------

Co-authored-by: Claude <noreply@anthropic.com>
Co-authored-by: GitHub Action <action@github.com>
2025-11-14 17:49:40 +08:00
defi-failure
68ee5164f0 fix: session list can't scroll (#11285) 2025-11-14 17:10:13 +08:00
beyondkmp
a1a3b9bd96 fix: can hide when close the app to tray (#11282)
* fix: can hide when close the app to tray

* format code

* udpate version
2025-11-14 16:52:09 +08:00
SuYao
4e699c48bc fix: update Azure OpenAI API version references to v1 in configuration and translations (#10799)
* fix: update Azure OpenAI API version references to v1 in configuration and translations

* fix: support Azure OpenAI API v1 in client compatibility check

* fix: lint/ format
2025-11-14 13:10:13 +08:00
Zhaokun
75fcf8fbb5 fix: notes content search next scroll (#10908)
* fix: topic branch incomplete copy - split ID mapping into two passes

Fix the bug where topic branching would not copy all message relationships completely.The issue was that askId mapping lookup happened in the same loop as ID generation, causing later messages' askIds to fail mapping when they referenced messages that hadn't been processed yet.

Solution: Split into two passes:
 1. First pass: Generate new IDs for all messages and build complete mapping
 2. Second pass: Clone messages and blocks using the complete ID mapping

This ensures all message relationships (especially assistant message askId references)are properly maintained in the new topic.

* fix(notes): 保持 Ctrl+F ‘下一个’在编辑器容器内滚动,避免索引提前回到第一条

- 使用传入的滚动容器计算相对偏移并 target.scrollTo 居中
- 容器不可滚动时回退到 scrollIntoView,兼容其他页面
- 将 target 纳入依赖,确保引用最新容器

受影响文件:
- src/renderer/src/components/ContentSearch.tsx:165

* fix(search): improve notes content search next-scroll behavior

* Update dom.ts

---------

Co-authored-by: Pleasurecruise <3196812536@qq.com>
2025-11-14 11:51:18 +08:00
Phantom
35aa9d7355 fix: Incorrect navigation when creating new message with @ (#10930)
* fix(message): Incorrect navigation when creating new message with @

Update variable name from newAssistantStub to newAssistantMessageStub for clarity
Add dispatch calls to update message folding state
Remove unused message length tracking effect in MessageGroup

Fixes #10928

* refactor(MessageGroup): remove unused prevMessageLengthRef variable
2025-11-14 11:45:10 +08:00
Pleasure1234
b08aecb22b fix: enable numeric sorting for note names (#11261)
Updated the sorting logic in getSorter to use the 'numeric' option in localeCompare for all name-based sorts. This ensures that note names containing numbers are sorted in a more natural, human-friendly order.
2025-11-14 11:37:19 +08:00
Phantom
45fc6c2afd fix: minimax new api host & anthropic api support (#11269)
* feat(models): add MiniMax M2 models to default configuration

* fix(config): update minimax api host and add anthropic host

Update the API endpoint for MiniMax provider and add a new endpoint for Anthropic integration

* feat: add minimax to ANTHROPIC_COMPATIBLE_PROVIDER_IDS

* docs(ProviderSetting): add todo comment for reset button

* fix(store): update minimax provider config in migration 174

Add anthropicApiHost to minimax provider configuration during state migration

* fix(store): revert version and remove unused migration

Remove migration for version 175 and revert persisted reducer version to 174
2025-11-14 10:55:41 +08:00
226 changed files with 17690 additions and 3867 deletions

1
.github/CODEOWNERS vendored
View File

@@ -3,3 +3,4 @@
/src/main/services/ConfigManager.ts @0xfullex
/packages/shared/IpcChannel.ts @0xfullex
/src/main/ipc.ts @0xfullex
/app-upgrade-config.json @kangfenmao

View File

@@ -77,7 +77,7 @@ jobs:
with:
token: ${{ secrets.GITHUB_TOKEN }} # Use the built-in GITHUB_TOKEN for bot actions
commit-message: "feat(bot): Weekly automated script run"
title: "🤖 Weekly Automated Update: ${{ env.CURRENT_DATE }}"
title: "🤖 Weekly Auto I18N Sync: ${{ env.CURRENT_DATE }}"
body: |
This PR includes changes generated by the weekly auto i18n.
Review the changes before merging.

View File

@@ -0,0 +1,212 @@
name: Update App Upgrade Config
on:
release:
types:
- released
- prereleased
workflow_dispatch:
inputs:
tag:
description: "Release tag (e.g., v1.2.3)"
required: true
type: string
is_prerelease:
description: "Mark the tag as a prerelease when running manually"
required: false
default: false
type: boolean
permissions:
contents: write
pull-requests: write
jobs:
propose-update:
runs-on: ubuntu-latest
if: github.event_name == 'workflow_dispatch' || (github.event_name == 'release' && github.event.release.draft == false)
steps:
- name: Check if should proceed
id: check
run: |
EVENT="${{ github.event_name }}"
if [ "$EVENT" = "workflow_dispatch" ]; then
TAG="${{ github.event.inputs.tag }}"
else
TAG="${{ github.event.release.tag_name }}"
fi
latest_tag=$(
curl -L \
-H "Accept: application/vnd.github+json" \
-H "Authorization: Bearer ${{ github.token }}" \
-H "X-GitHub-Api-Version: 2022-11-28" \
https://api.github.com/repos/${{ github.repository }}/releases/latest \
| jq -r '.tag_name'
)
if [ "$EVENT" = "workflow_dispatch" ]; then
MANUAL_IS_PRERELEASE="${{ github.event.inputs.is_prerelease }}"
if [ -z "$MANUAL_IS_PRERELEASE" ]; then
MANUAL_IS_PRERELEASE="false"
fi
if [ "$MANUAL_IS_PRERELEASE" = "true" ]; then
if ! echo "$TAG" | grep -E '(-beta([.-][0-9]+)?|-rc([.-][0-9]+)?)' >/dev/null; then
echo "Manual prerelease flag set but tag $TAG lacks beta/rc suffix. Skipping." >&2
echo "should_run=false" >> "$GITHUB_OUTPUT"
echo "is_prerelease=false" >> "$GITHUB_OUTPUT"
echo "latest_tag=$latest_tag" >> "$GITHUB_OUTPUT"
exit 0
fi
fi
echo "should_run=true" >> "$GITHUB_OUTPUT"
echo "is_prerelease=$MANUAL_IS_PRERELEASE" >> "$GITHUB_OUTPUT"
echo "latest_tag=$latest_tag" >> "$GITHUB_OUTPUT"
exit 0
fi
IS_PRERELEASE="${{ github.event.release.prerelease }}"
if [ "$IS_PRERELEASE" = "true" ]; then
if ! echo "$TAG" | grep -E '(-beta([.-][0-9]+)?|-rc([.-][0-9]+)?)' >/dev/null; then
echo "Release marked as prerelease but tag $TAG lacks beta/rc suffix. Skipping." >&2
echo "should_run=false" >> "$GITHUB_OUTPUT"
echo "is_prerelease=false" >> "$GITHUB_OUTPUT"
echo "latest_tag=$latest_tag" >> "$GITHUB_OUTPUT"
exit 0
fi
echo "should_run=true" >> "$GITHUB_OUTPUT"
echo "is_prerelease=true" >> "$GITHUB_OUTPUT"
echo "latest_tag=$latest_tag" >> "$GITHUB_OUTPUT"
echo "Release is prerelease, proceeding"
exit 0
fi
if [[ "${latest_tag}" == "$TAG" ]]; then
echo "should_run=true" >> "$GITHUB_OUTPUT"
echo "is_prerelease=false" >> "$GITHUB_OUTPUT"
echo "latest_tag=$latest_tag" >> "$GITHUB_OUTPUT"
echo "Release is latest, proceeding"
else
echo "should_run=false" >> "$GITHUB_OUTPUT"
echo "is_prerelease=false" >> "$GITHUB_OUTPUT"
echo "latest_tag=$latest_tag" >> "$GITHUB_OUTPUT"
echo "Release is neither prerelease nor latest, skipping"
fi
- name: Prepare metadata
id: meta
if: steps.check.outputs.should_run == 'true'
run: |
EVENT="${{ github.event_name }}"
LATEST_TAG="${{ steps.check.outputs.latest_tag }}"
if [ "$EVENT" = "release" ]; then
TAG="${{ github.event.release.tag_name }}"
PRE="${{ github.event.release.prerelease }}"
if [ -n "$LATEST_TAG" ] && [ "$LATEST_TAG" = "$TAG" ]; then
LATEST="true"
else
LATEST="false"
fi
TRIGGER="release"
else
TAG="${{ github.event.inputs.tag }}"
PRE="${{ github.event.inputs.is_prerelease }}"
if [ -z "$PRE" ]; then
PRE="false"
fi
if [ -n "$LATEST_TAG" ] && [ "$LATEST_TAG" = "$TAG" ] && [ "$PRE" != "true" ]; then
LATEST="true"
else
LATEST="false"
fi
TRIGGER="manual"
fi
SAFE_TAG=$(echo "$TAG" | sed 's/[^A-Za-z0-9._-]/-/g')
echo "tag=$TAG" >> "$GITHUB_OUTPUT"
echo "safe_tag=$SAFE_TAG" >> "$GITHUB_OUTPUT"
echo "prerelease=$PRE" >> "$GITHUB_OUTPUT"
echo "latest=$LATEST" >> "$GITHUB_OUTPUT"
echo "trigger=$TRIGGER" >> "$GITHUB_OUTPUT"
- name: Checkout default branch
if: steps.check.outputs.should_run == 'true'
uses: actions/checkout@v5
with:
ref: ${{ github.event.repository.default_branch }}
path: main
fetch-depth: 0
- name: Checkout x-files/app-upgrade-config branch
if: steps.check.outputs.should_run == 'true'
uses: actions/checkout@v5
with:
ref: x-files/app-upgrade-config
path: cs
fetch-depth: 0
- name: Setup Node.js
if: steps.check.outputs.should_run == 'true'
uses: actions/setup-node@v4
with:
node-version: 22
- name: Enable Corepack
if: steps.check.outputs.should_run == 'true'
run: corepack enable && corepack prepare yarn@4.9.1 --activate
- name: Install dependencies
if: steps.check.outputs.should_run == 'true'
working-directory: main
run: yarn install --immutable
- name: Update upgrade config
if: steps.check.outputs.should_run == 'true'
working-directory: main
env:
RELEASE_TAG: ${{ steps.meta.outputs.tag }}
IS_PRERELEASE: ${{ steps.check.outputs.is_prerelease }}
run: |
yarn tsx scripts/update-app-upgrade-config.ts \
--tag "$RELEASE_TAG" \
--config ../cs/app-upgrade-config.json \
--is-prerelease "$IS_PRERELEASE"
- name: Detect changes
if: steps.check.outputs.should_run == 'true'
id: diff
working-directory: cs
run: |
if git diff --quiet -- app-upgrade-config.json; then
echo "changed=false" >> "$GITHUB_OUTPUT"
else
echo "changed=true" >> "$GITHUB_OUTPUT"
fi
- name: Create pull request
if: steps.check.outputs.should_run == 'true' && steps.diff.outputs.changed == 'true'
uses: peter-evans/create-pull-request@v7
with:
path: cs
base: x-files/app-upgrade-config
branch: chore/update-app-upgrade-config/${{ steps.meta.outputs.safe_tag }}
commit-message: "🤖 chore: sync app-upgrade-config for ${{ steps.meta.outputs.tag }}"
title: "chore: update app-upgrade-config for ${{ steps.meta.outputs.tag }}"
body: |
Automated update triggered by `${{ steps.meta.outputs.trigger }}`.
- Source tag: `${{ steps.meta.outputs.tag }}`
- Pre-release: `${{ steps.meta.outputs.prerelease }}`
- Latest: `${{ steps.meta.outputs.latest }}`
- Workflow run: https://github.com/${{ github.repository }}/actions/runs/${{ github.run_id }}
labels: |
automation
app-upgrade
- name: No changes detected
if: steps.check.outputs.should_run == 'true' && steps.diff.outputs.changed != 'true'
run: echo "No updates required for x-files/app-upgrade-config/app-upgrade-config.json"

View File

@@ -1,8 +1,8 @@
diff --git a/dist/index.js b/dist/index.js
index ff305b112779b718f21a636a27b1196125a332d9..cf32ff5086d4d9e56f8fe90c98724559083bafc3 100644
index dc7b74ba55337c491cdf1ab3e39ca68cc4187884..ace8c90591288e42c2957e93c9bf7984f1b22444 100644
--- a/dist/index.js
+++ b/dist/index.js
@@ -471,7 +471,7 @@ function convertToGoogleGenerativeAIMessages(prompt, options) {
@@ -472,7 +472,7 @@ function convertToGoogleGenerativeAIMessages(prompt, options) {
// src/get-model-path.ts
function getModelPath(modelId) {
@@ -12,10 +12,10 @@ index ff305b112779b718f21a636a27b1196125a332d9..cf32ff5086d4d9e56f8fe90c98724559
// src/google-generative-ai-options.ts
diff --git a/dist/index.mjs b/dist/index.mjs
index 57659290f1cec74878a385626ad75b2a4d5cd3fc..d04e5927ec3725b6ffdb80868bfa1b5a48849537 100644
index 8390439c38cb7eaeb52080862cd6f4c58509e67c..a7647f2e11700dff7e1c8d4ae8f99d3637010733 100644
--- a/dist/index.mjs
+++ b/dist/index.mjs
@@ -477,7 +477,7 @@ function convertToGoogleGenerativeAIMessages(prompt, options) {
@@ -478,7 +478,7 @@ function convertToGoogleGenerativeAIMessages(prompt, options) {
// src/get-model-path.ts
function getModelPath(modelId) {

View File

@@ -1,131 +0,0 @@
diff --git a/dist/index.mjs b/dist/index.mjs
index b3f018730a93639aad7c203f15fb1aeb766c73f4..ade2a43d66e9184799d072153df61ef7be4ea110 100644
--- a/dist/index.mjs
+++ b/dist/index.mjs
@@ -296,7 +296,14 @@ var HuggingFaceResponsesLanguageModel = class {
metadata: huggingfaceOptions == null ? void 0 : huggingfaceOptions.metadata,
instructions: huggingfaceOptions == null ? void 0 : huggingfaceOptions.instructions,
...preparedTools && { tools: preparedTools },
- ...preparedToolChoice && { tool_choice: preparedToolChoice }
+ ...preparedToolChoice && { tool_choice: preparedToolChoice },
+ ...(huggingfaceOptions?.reasoningEffort != null && {
+ reasoning: {
+ ...(huggingfaceOptions?.reasoningEffort != null && {
+ effort: huggingfaceOptions.reasoningEffort,
+ }),
+ },
+ }),
};
return { args: baseArgs, warnings };
}
@@ -365,6 +372,20 @@ var HuggingFaceResponsesLanguageModel = class {
}
break;
}
+ case 'reasoning': {
+ for (const contentPart of part.content) {
+ content.push({
+ type: 'reasoning',
+ text: contentPart.text,
+ providerMetadata: {
+ huggingface: {
+ itemId: part.id,
+ },
+ },
+ });
+ }
+ break;
+ }
case "mcp_call": {
content.push({
type: "tool-call",
@@ -519,6 +540,11 @@ var HuggingFaceResponsesLanguageModel = class {
id: value.item.call_id,
toolName: value.item.name
});
+ } else if (value.item.type === 'reasoning') {
+ controller.enqueue({
+ type: 'reasoning-start',
+ id: value.item.id,
+ });
}
return;
}
@@ -570,6 +596,22 @@ var HuggingFaceResponsesLanguageModel = class {
});
return;
}
+ if (isReasoningDeltaChunk(value)) {
+ controller.enqueue({
+ type: 'reasoning-delta',
+ id: value.item_id,
+ delta: value.delta,
+ });
+ return;
+ }
+
+ if (isReasoningEndChunk(value)) {
+ controller.enqueue({
+ type: 'reasoning-end',
+ id: value.item_id,
+ });
+ return;
+ }
},
flush(controller) {
controller.enqueue({
@@ -593,7 +635,8 @@ var HuggingFaceResponsesLanguageModel = class {
var huggingfaceResponsesProviderOptionsSchema = z2.object({
metadata: z2.record(z2.string(), z2.string()).optional(),
instructions: z2.string().optional(),
- strictJsonSchema: z2.boolean().optional()
+ strictJsonSchema: z2.boolean().optional(),
+ reasoningEffort: z2.string().optional(),
});
var huggingfaceResponsesResponseSchema = z2.object({
id: z2.string(),
@@ -727,12 +770,31 @@ var responseCreatedChunkSchema = z2.object({
model: z2.string()
})
});
+var reasoningTextDeltaChunkSchema = z2.object({
+ type: z2.literal('response.reasoning_text.delta'),
+ item_id: z2.string(),
+ output_index: z2.number(),
+ content_index: z2.number(),
+ delta: z2.string(),
+ sequence_number: z2.number(),
+});
+
+var reasoningTextEndChunkSchema = z2.object({
+ type: z2.literal('response.reasoning_text.done'),
+ item_id: z2.string(),
+ output_index: z2.number(),
+ content_index: z2.number(),
+ text: z2.string(),
+ sequence_number: z2.number(),
+});
var huggingfaceResponsesChunkSchema = z2.union([
responseOutputItemAddedSchema,
responseOutputItemDoneSchema,
textDeltaChunkSchema,
responseCompletedChunkSchema,
responseCreatedChunkSchema,
+ reasoningTextDeltaChunkSchema,
+ reasoningTextEndChunkSchema,
z2.object({ type: z2.string() }).loose()
// fallback for unknown chunks
]);
@@ -751,6 +813,12 @@ function isResponseCompletedChunk(chunk) {
function isResponseCreatedChunk(chunk) {
return chunk.type === "response.created";
}
+function isReasoningDeltaChunk(chunk) {
+ return chunk.type === 'response.reasoning_text.delta';
+}
+function isReasoningEndChunk(chunk) {
+ return chunk.type === 'response.reasoning_text.done';
+}
// src/huggingface-provider.ts
function createHuggingFace(options = {}) {

View File

@@ -0,0 +1,140 @@
diff --git a/dist/index.js b/dist/index.js
index 73045a7d38faafdc7f7d2cd79d7ff0e2b031056b..8d948c9ac4ea4b474db9ef3c5491961e7fcf9a07 100644
--- a/dist/index.js
+++ b/dist/index.js
@@ -421,6 +421,17 @@ var OpenAICompatibleChatLanguageModel = class {
text: reasoning
});
}
+ if (choice.message.images) {
+ for (const image of choice.message.images) {
+ const match1 = image.image_url.url.match(/^data:([^;]+)/)
+ const match2 = image.image_url.url.match(/^data:[^;]*;base64,(.+)$/);
+ content.push({
+ type: 'file',
+ mediaType: match1 ? (match1[1] ?? 'image/jpeg') : 'image/jpeg',
+ data: match2 ? match2[1] : image.image_url.url,
+ });
+ }
+ }
if (choice.message.tool_calls != null) {
for (const toolCall of choice.message.tool_calls) {
content.push({
@@ -598,6 +609,17 @@ var OpenAICompatibleChatLanguageModel = class {
delta: delta.content
});
}
+ if (delta.images) {
+ for (const image of delta.images) {
+ const match1 = image.image_url.url.match(/^data:([^;]+)/)
+ const match2 = image.image_url.url.match(/^data:[^;]*;base64,(.+)$/);
+ controller.enqueue({
+ type: 'file',
+ mediaType: match1 ? (match1[1] ?? 'image/jpeg') : 'image/jpeg',
+ data: match2 ? match2[1] : image.image_url.url,
+ });
+ }
+ }
if (delta.tool_calls != null) {
for (const toolCallDelta of delta.tool_calls) {
const index = toolCallDelta.index;
@@ -765,6 +787,14 @@ var OpenAICompatibleChatResponseSchema = import_v43.z.object({
arguments: import_v43.z.string()
})
})
+ ).nullish(),
+ images: import_v43.z.array(
+ import_v43.z.object({
+ type: import_v43.z.literal('image_url'),
+ image_url: import_v43.z.object({
+ url: import_v43.z.string(),
+ })
+ })
).nullish()
}),
finish_reason: import_v43.z.string().nullish()
@@ -795,6 +825,14 @@ var createOpenAICompatibleChatChunkSchema = (errorSchema) => import_v43.z.union(
arguments: import_v43.z.string().nullish()
})
})
+ ).nullish(),
+ images: import_v43.z.array(
+ import_v43.z.object({
+ type: import_v43.z.literal('image_url'),
+ image_url: import_v43.z.object({
+ url: import_v43.z.string(),
+ })
+ })
).nullish()
}).nullish(),
finish_reason: import_v43.z.string().nullish()
diff --git a/dist/index.mjs b/dist/index.mjs
index 1c2b9560bbfbfe10cb01af080aeeed4ff59db29c..2c8ddc4fc9bfc5e7e06cfca105d197a08864c427 100644
--- a/dist/index.mjs
+++ b/dist/index.mjs
@@ -405,6 +405,17 @@ var OpenAICompatibleChatLanguageModel = class {
text: reasoning
});
}
+ if (choice.message.images) {
+ for (const image of choice.message.images) {
+ const match1 = image.image_url.url.match(/^data:([^;]+)/)
+ const match2 = image.image_url.url.match(/^data:[^;]*;base64,(.+)$/);
+ content.push({
+ type: 'file',
+ mediaType: match1 ? (match1[1] ?? 'image/jpeg') : 'image/jpeg',
+ data: match2 ? match2[1] : image.image_url.url,
+ });
+ }
+ }
if (choice.message.tool_calls != null) {
for (const toolCall of choice.message.tool_calls) {
content.push({
@@ -582,6 +593,17 @@ var OpenAICompatibleChatLanguageModel = class {
delta: delta.content
});
}
+ if (delta.images) {
+ for (const image of delta.images) {
+ const match1 = image.image_url.url.match(/^data:([^;]+)/)
+ const match2 = image.image_url.url.match(/^data:[^;]*;base64,(.+)$/);
+ controller.enqueue({
+ type: 'file',
+ mediaType: match1 ? (match1[1] ?? 'image/jpeg') : 'image/jpeg',
+ data: match2 ? match2[1] : image.image_url.url,
+ });
+ }
+ }
if (delta.tool_calls != null) {
for (const toolCallDelta of delta.tool_calls) {
const index = toolCallDelta.index;
@@ -749,6 +771,14 @@ var OpenAICompatibleChatResponseSchema = z3.object({
arguments: z3.string()
})
})
+ ).nullish(),
+ images: z3.array(
+ z3.object({
+ type: z3.literal('image_url'),
+ image_url: z3.object({
+ url: z3.string(),
+ })
+ })
).nullish()
}),
finish_reason: z3.string().nullish()
@@ -779,6 +809,14 @@ var createOpenAICompatibleChatChunkSchema = (errorSchema) => z3.union([
arguments: z3.string().nullish()
})
})
+ ).nullish(),
+ images: z3.array(
+ z3.object({
+ type: z3.literal('image_url'),
+ image_url: z3.object({
+ url: z3.string(),
+ })
+ })
).nullish()
}).nullish(),
finish_reason: z3.string().nullish()

View File

@@ -1,5 +1,5 @@
diff --git a/dist/index.js b/dist/index.js
index 992c85ac6656e51c3471af741583533c5a7bf79f..83c05952a07aebb95fc6c62f9ddb8aa96b52ac0d 100644
index 7481f3b3511078068d87d03855b568b20bb86971..8ac5ec28d2f7ad1b3b0d3f8da945c75674e59637 100644
--- a/dist/index.js
+++ b/dist/index.js
@@ -274,6 +274,7 @@ var openaiChatResponseSchema = (0, import_provider_utils3.lazyValidator)(
@@ -18,7 +18,7 @@ index 992c85ac6656e51c3471af741583533c5a7bf79f..83c05952a07aebb95fc6c62f9ddb8aa9
tool_calls: import_v42.z.array(
import_v42.z.object({
index: import_v42.z.number(),
@@ -785,6 +787,13 @@ var OpenAIChatLanguageModel = class {
@@ -795,6 +797,13 @@ var OpenAIChatLanguageModel = class {
if (text != null && text.length > 0) {
content.push({ type: "text", text });
}
@@ -32,7 +32,7 @@ index 992c85ac6656e51c3471af741583533c5a7bf79f..83c05952a07aebb95fc6c62f9ddb8aa9
for (const toolCall of (_a = choice.message.tool_calls) != null ? _a : []) {
content.push({
type: "tool-call",
@@ -866,6 +875,7 @@ var OpenAIChatLanguageModel = class {
@@ -876,6 +885,7 @@ var OpenAIChatLanguageModel = class {
};
let metadataExtracted = false;
let isActiveText = false;
@@ -40,7 +40,7 @@ index 992c85ac6656e51c3471af741583533c5a7bf79f..83c05952a07aebb95fc6c62f9ddb8aa9
const providerMetadata = { openai: {} };
return {
stream: response.pipeThrough(
@@ -923,6 +933,21 @@ var OpenAIChatLanguageModel = class {
@@ -933,6 +943,21 @@ var OpenAIChatLanguageModel = class {
return;
}
const delta = choice.delta;
@@ -62,7 +62,7 @@ index 992c85ac6656e51c3471af741583533c5a7bf79f..83c05952a07aebb95fc6c62f9ddb8aa9
if (delta.content != null) {
if (!isActiveText) {
controller.enqueue({ type: "text-start", id: "0" });
@@ -1035,6 +1060,9 @@ var OpenAIChatLanguageModel = class {
@@ -1045,6 +1070,9 @@ var OpenAIChatLanguageModel = class {
}
},
flush(controller) {

View File

@@ -1,276 +0,0 @@
diff --git a/out/macPackager.js b/out/macPackager.js
index 852f6c4d16f86a7bb8a78bf1ed5a14647a279aa1..60e7f5f16a844541eb1909b215fcda1811e924b8 100644
--- a/out/macPackager.js
+++ b/out/macPackager.js
@@ -423,7 +423,7 @@ class MacPackager extends platformPackager_1.PlatformPackager {
}
appPlist.CFBundleName = appInfo.productName;
appPlist.CFBundleDisplayName = appInfo.productName;
- const minimumSystemVersion = this.platformSpecificBuildOptions.minimumSystemVersion;
+ const minimumSystemVersion = this.platformSpecificBuildOptions.LSMinimumSystemVersion;
if (minimumSystemVersion != null) {
appPlist.LSMinimumSystemVersion = minimumSystemVersion;
}
diff --git a/out/publish/updateInfoBuilder.js b/out/publish/updateInfoBuilder.js
index 7924c5b47d01f8dfccccb8f46658015fa66da1f7..1a1588923c3939ae1297b87931ba83f0ebc052d8 100644
--- a/out/publish/updateInfoBuilder.js
+++ b/out/publish/updateInfoBuilder.js
@@ -133,6 +133,7 @@ async function createUpdateInfo(version, event, releaseInfo) {
const customUpdateInfo = event.updateInfo;
const url = path.basename(event.file);
const sha512 = (customUpdateInfo == null ? null : customUpdateInfo.sha512) || (await (0, hash_1.hashFile)(event.file));
+ const minimumSystemVersion = customUpdateInfo == null ? null : customUpdateInfo.minimumSystemVersion;
const files = [{ url, sha512 }];
const result = {
// @ts-ignore
@@ -143,9 +144,13 @@ async function createUpdateInfo(version, event, releaseInfo) {
path: url /* backward compatibility, electron-updater 1.x - electron-updater 2.15.0 */,
// @ts-ignore
sha512 /* backward compatibility, electron-updater 1.x - electron-updater 2.15.0 */,
+ minimumSystemVersion,
...releaseInfo,
};
if (customUpdateInfo != null) {
+ if (customUpdateInfo.minimumSystemVersion) {
+ delete customUpdateInfo.minimumSystemVersion;
+ }
// file info or nsis web installer packages info
Object.assign("sha512" in customUpdateInfo ? files[0] : result, customUpdateInfo);
}
diff --git a/out/targets/ArchiveTarget.js b/out/targets/ArchiveTarget.js
index e1f52a5fa86fff6643b2e57eaf2af318d541f865..47cc347f154a24b365e70ae5e1f6d309f3582ed0 100644
--- a/out/targets/ArchiveTarget.js
+++ b/out/targets/ArchiveTarget.js
@@ -69,6 +69,9 @@ class ArchiveTarget extends core_1.Target {
}
}
}
+ if (updateInfo != null && this.packager.platformSpecificBuildOptions.minimumSystemVersion) {
+ updateInfo.minimumSystemVersion = this.packager.platformSpecificBuildOptions.minimumSystemVersion;
+ }
await packager.info.emitArtifactBuildCompleted({
updateInfo,
file: artifactPath,
diff --git a/out/targets/nsis/NsisTarget.js b/out/targets/nsis/NsisTarget.js
index e8bd7bb46c8a54b3f55cf3a853ef924195271e01..f956e9f3fe9eb903c78aef3502553b01de4b89b1 100644
--- a/out/targets/nsis/NsisTarget.js
+++ b/out/targets/nsis/NsisTarget.js
@@ -305,6 +305,9 @@ class NsisTarget extends core_1.Target {
if (updateInfo != null && isPerMachine && (oneClick || options.packElevateHelper)) {
updateInfo.isAdminRightsRequired = true;
}
+ if (updateInfo != null && this.packager.platformSpecificBuildOptions.minimumSystemVersion) {
+ updateInfo.minimumSystemVersion = this.packager.platformSpecificBuildOptions.minimumSystemVersion;
+ }
await packager.info.emitArtifactBuildCompleted({
file: installerPath,
updateInfo,
diff --git a/out/util/yarn.js b/out/util/yarn.js
index 1ee20f8b252a8f28d0c7b103789cf0a9a427aec1..c2878ec54d57da50bf14225e0c70c9c88664eb8a 100644
--- a/out/util/yarn.js
+++ b/out/util/yarn.js
@@ -140,6 +140,7 @@ async function rebuild(config, { appDir, projectDir }, options) {
arch,
platform,
buildFromSource,
+ ignoreModules: config.excludeReBuildModules || undefined,
projectRootPath: projectDir,
mode: config.nativeRebuilder || "sequential",
disablePreGypCopy: true,
diff --git a/scheme.json b/scheme.json
index 433e2efc9cef156ff5444f0c4520362ed2ef9ea7..0167441bf928a92f59b5dbe70b2317a74dda74c9 100644
--- a/scheme.json
+++ b/scheme.json
@@ -1825,6 +1825,20 @@
"string"
]
},
+ "excludeReBuildModules": {
+ "anyOf": [
+ {
+ "items": {
+ "type": "string"
+ },
+ "type": "array"
+ },
+ {
+ "type": "null"
+ }
+ ],
+ "description": "The modules to exclude from the rebuild."
+ },
"executableArgs": {
"anyOf": [
{
@@ -1975,6 +1989,13 @@
],
"description": "The mime types in addition to specified in the file associations. Use it if you don't want to register a new mime type, but reuse existing."
},
+ "minimumSystemVersion": {
+ "description": "The minimum os kernel version required to install the application.",
+ "type": [
+ "null",
+ "string"
+ ]
+ },
"packageCategory": {
"description": "backward compatibility + to allow specify fpm-only category for all possible fpm targets in one place",
"type": [
@@ -2327,6 +2348,13 @@
"MacConfiguration": {
"additionalProperties": false,
"properties": {
+ "LSMinimumSystemVersion": {
+ "description": "The minimum version of macOS required for the app to run. Corresponds to `LSMinimumSystemVersion`.",
+ "type": [
+ "null",
+ "string"
+ ]
+ },
"additionalArguments": {
"anyOf": [
{
@@ -2527,6 +2555,20 @@
"string"
]
},
+ "excludeReBuildModules": {
+ "anyOf": [
+ {
+ "items": {
+ "type": "string"
+ },
+ "type": "array"
+ },
+ {
+ "type": "null"
+ }
+ ],
+ "description": "The modules to exclude from the rebuild."
+ },
"executableName": {
"description": "The executable name. Defaults to `productName`.",
"type": [
@@ -2737,7 +2779,7 @@
"type": "boolean"
},
"minimumSystemVersion": {
- "description": "The minimum version of macOS required for the app to run. Corresponds to `LSMinimumSystemVersion`.",
+ "description": "The minimum os kernel version required to install the application.",
"type": [
"null",
"string"
@@ -2959,6 +3001,13 @@
"MasConfiguration": {
"additionalProperties": false,
"properties": {
+ "LSMinimumSystemVersion": {
+ "description": "The minimum version of macOS required for the app to run. Corresponds to `LSMinimumSystemVersion`.",
+ "type": [
+ "null",
+ "string"
+ ]
+ },
"additionalArguments": {
"anyOf": [
{
@@ -3159,6 +3208,20 @@
"string"
]
},
+ "excludeReBuildModules": {
+ "anyOf": [
+ {
+ "items": {
+ "type": "string"
+ },
+ "type": "array"
+ },
+ {
+ "type": "null"
+ }
+ ],
+ "description": "The modules to exclude from the rebuild."
+ },
"executableName": {
"description": "The executable name. Defaults to `productName`.",
"type": [
@@ -3369,7 +3432,7 @@
"type": "boolean"
},
"minimumSystemVersion": {
- "description": "The minimum version of macOS required for the app to run. Corresponds to `LSMinimumSystemVersion`.",
+ "description": "The minimum os kernel version required to install the application.",
"type": [
"null",
"string"
@@ -6381,6 +6444,20 @@
"string"
]
},
+ "excludeReBuildModules": {
+ "anyOf": [
+ {
+ "items": {
+ "type": "string"
+ },
+ "type": "array"
+ },
+ {
+ "type": "null"
+ }
+ ],
+ "description": "The modules to exclude from the rebuild."
+ },
"executableName": {
"description": "The executable name. Defaults to `productName`.",
"type": [
@@ -6507,6 +6584,13 @@
"string"
]
},
+ "minimumSystemVersion": {
+ "description": "The minimum os kernel version required to install the application.",
+ "type": [
+ "null",
+ "string"
+ ]
+ },
"protocols": {
"anyOf": [
{
@@ -7153,6 +7237,20 @@
"string"
]
},
+ "excludeReBuildModules": {
+ "anyOf": [
+ {
+ "items": {
+ "type": "string"
+ },
+ "type": "array"
+ },
+ {
+ "type": "null"
+ }
+ ],
+ "description": "The modules to exclude from the rebuild."
+ },
"executableName": {
"description": "The executable name. Defaults to `productName`.",
"type": [
@@ -7376,6 +7474,13 @@
],
"description": "MAS (Mac Application Store) development options (`mas-dev` target)."
},
+ "minimumSystemVersion": {
+ "description": "The minimum os kernel version required to install the application.",
+ "type": [
+ "null",
+ "string"
+ ]
+ },
"msi": {
"anyOf": [
{

View File

@@ -0,0 +1,14 @@
diff --git a/out/util.js b/out/util.js
index 9294ffd6ca8f02c2e0f90c663e7e9cdc02c1ac37..f52107493e2995320ee4efd0eb2a8c9bf03291a2 100644
--- a/out/util.js
+++ b/out/util.js
@@ -23,7 +23,8 @@ function newUrlFromBase(pathname, baseUrl, addRandomQueryToAvoidCaching = false)
result.search = search;
}
else if (addRandomQueryToAvoidCaching) {
- result.search = `noCache=${Date.now().toString(32)}`;
+ // use no cache header instead
+ // result.search = `noCache=${Date.now().toString(32)}`;
}
return result;
}

49
app-upgrade-config.json Normal file
View File

@@ -0,0 +1,49 @@
{
"lastUpdated": "2025-11-10T08:14:28Z",
"versions": {
"1.6.7": {
"metadata": {
"segmentId": "legacy-v1",
"segmentType": "legacy"
},
"minCompatibleVersion": "1.0.0",
"description": "Last stable v1.7.x release - required intermediate version for users below v1.7",
"channels": {
"latest": {
"version": "1.6.7",
"feedUrls": {
"github": "https://github.com/CherryHQ/cherry-studio/releases/download/v1.6.7",
"gitcode": "https://releases.cherry-ai.com"
}
},
"rc": {
"version": "1.6.0-rc.5",
"feedUrls": {
"github": "https://github.com/CherryHQ/cherry-studio/releases/download/v1.6.0-rc.5",
"gitcode": "https://github.com/CherryHQ/cherry-studio/releases/download/v1.6.0-rc.5"
}
},
"beta": {
"version": "1.7.0-beta.3",
"feedUrls": {
"github": "https://github.com/CherryHQ/cherry-studio/releases/download/v1.7.0-beta.3",
"gitcode": "https://github.com/CherryHQ/cherry-studio/releases/download/v1.7.0-beta.3"
}
}
}
},
"2.0.0": {
"metadata": {
"segmentId": "gateway-v2",
"segmentType": "breaking"
},
"minCompatibleVersion": "1.7.0",
"description": "Major release v2.0 - required intermediate version for v2.x upgrades",
"channels": {
"latest": null,
"rc": null,
"beta": null
}
}
}
}

View File

@@ -14,7 +14,7 @@
}
},
"enabled": true,
"includes": ["**/*.json", "!*.json", "!**/package.json"]
"includes": ["**/*.json", "!*.json", "!**/package.json", "!coverage/**"]
},
"css": {
"formatter": {
@@ -23,7 +23,7 @@
},
"files": {
"ignoreUnknown": false,
"includes": ["**", "!**/.claude/**"],
"includes": ["**", "!**/.claude/**", "!**/.vscode/**"],
"maxSize": 2097152
},
"formatter": {

View File

@@ -0,0 +1,81 @@
{
"segments": [
{
"id": "legacy-v1",
"type": "legacy",
"match": {
"range": ">=1.0.0 <2.0.0"
},
"minCompatibleVersion": "1.0.0",
"description": "Last stable v1.7.x release - required intermediate version for users below v1.7",
"channelTemplates": {
"latest": {
"feedTemplates": {
"github": "https://github.com/CherryHQ/cherry-studio/releases/download/{{tag}}",
"gitcode": "https://releases.cherry-ai.com"
}
},
"rc": {
"feedTemplates": {
"github": "https://github.com/CherryHQ/cherry-studio/releases/download/{{tag}}",
"gitcode": "https://github.com/CherryHQ/cherry-studio/releases/download/{{tag}}"
}
},
"beta": {
"feedTemplates": {
"github": "https://github.com/CherryHQ/cherry-studio/releases/download/{{tag}}",
"gitcode": "https://github.com/CherryHQ/cherry-studio/releases/download/{{tag}}"
}
}
}
},
{
"id": "gateway-v2",
"type": "breaking",
"match": {
"exact": ["2.0.0"]
},
"lockedVersion": "2.0.0",
"minCompatibleVersion": "1.7.0",
"description": "Major release v2.0 - required intermediate version for v2.x upgrades",
"channelTemplates": {
"latest": {
"feedTemplates": {
"github": "https://github.com/CherryHQ/cherry-studio/releases/download/{{tag}}",
"gitcode": "https://gitcode.com/CherryHQ/cherry-studio/releases/download/{{tag}}"
}
}
}
},
{
"id": "current-v2",
"type": "latest",
"match": {
"range": ">=2.0.0 <3.0.0",
"excludeExact": ["2.0.0"]
},
"minCompatibleVersion": "2.0.0",
"description": "Current latest v2.x release",
"channelTemplates": {
"latest": {
"feedTemplates": {
"github": "https://github.com/CherryHQ/cherry-studio/releases/download/{{tag}}",
"gitcode": "https://gitcode.com/CherryHQ/cherry-studio/releases/download/{{tag}}"
}
},
"rc": {
"feedTemplates": {
"github": "https://github.com/CherryHQ/cherry-studio/releases/download/{{tag}}",
"gitcode": "https://gitcode.com/CherryHQ/cherry-studio/releases/download/{{tag}}"
}
},
"beta": {
"feedTemplates": {
"github": "https://github.com/CherryHQ/cherry-studio/releases/download/{{tag}}",
"gitcode": "https://gitcode.com/CherryHQ/cherry-studio/releases/download/{{tag}}"
}
}
}
}
]
}

View File

@@ -0,0 +1,430 @@
# Update Configuration System Design Document
## Background
Currently, AppUpdater directly queries the GitHub API to retrieve beta and rc update information. To support users in China, we need to fetch a static JSON configuration file from GitHub/GitCode based on IP geolocation, which contains update URLs for all channels.
## Design Goals
1. Support different configuration sources based on IP geolocation (GitHub/GitCode)
2. Support version compatibility control (e.g., users below v1.x must upgrade to v1.7.0 before upgrading to v2.0)
3. Easy to extend, supporting future multi-major-version upgrade paths (v1.6 → v1.7 → v2.0 → v2.8 → v3.0)
4. Maintain compatibility with existing electron-updater mechanism
## Current Version Strategy
- **v1.7.x** is the last version of the 1.x series
- Users **below v1.7.0** must first upgrade to v1.7.0 (or higher 1.7.x version)
- Users **v1.7.0 and above** can directly upgrade to v2.x.x
## Automation Workflow
The `x-files/app-upgrade-config/app-upgrade-config.json` file is synchronized by the [`Update App Upgrade Config`](../../.github/workflows/update-app-upgrade-config.yml) workflow. The workflow runs the [`scripts/update-app-upgrade-config.ts`](../../scripts/update-app-upgrade-config.ts) helper so that every release tag automatically updates the JSON in `x-files/app-upgrade-config`.
### Trigger Conditions
- **Release events (`release: released/prereleased`)**
- Draft releases are ignored.
- When GitHub marks the release as _prerelease_, the tag must include `-beta`/`-rc` (with optional numeric suffix). Otherwise the workflow exits early.
- When GitHub marks the release as stable, the tag must match the latest release returned by the GitHub API. This prevents out-of-order updates when publishing historical tags.
- If the guard clauses pass, the version is tagged as `latest` or `beta/rc` based on its semantic suffix and propagated to the script through the `IS_PRERELEASE` flag.
- **Manual dispatch (`workflow_dispatch`)**
- Required input: `tag` (e.g., `v2.0.1`). Optional input: `is_prerelease` (defaults to `false`).
- When `is_prerelease=true`, the tag must carry a beta/rc suffix, mirroring the automatic validation.
- Manual runs still download the latest release metadata so that the workflow knows whether the tag represents the newest stable version (for documentation inside the PR body).
### Workflow Steps
1. **Guard + metadata preparation** the `Check if should proceed` and `Prepare metadata` steps compute the target tag, prerelease flag, whether the tag is the newest release, and a `safe_tag` slug used for branch names. When any rule fails, the workflow stops without touching the config.
2. **Checkout source branches** the default branch is checked out into `main/`, while the long-lived `x-files/app-upgrade-config` branch lives in `cs/`. All modifications happen in the latter directory.
3. **Install toolchain** Node.js 22, Corepack, and frozen Yarn dependencies are installed inside `main/`.
4. **Run the update script** `yarn tsx scripts/update-app-upgrade-config.ts --tag <tag> --config ../cs/app-upgrade-config.json --is-prerelease <flag>` updates the JSON in-place.
- The script normalizes the tag (e.g., strips `v` prefix), detects the release channel (`latest`, `rc`, `beta`), and loads segment rules from `config/app-upgrade-segments.json`.
- It validates that prerelease flags and semantic suffixes agree, enforces locked segments, builds mirror feed URLs, and performs release-availability checks (GitHub HEAD request for every channel; GitCode GET for latest channels, falling back to `https://releases.cherry-ai.com` when gitcode is delayed).
- After updating the relevant channel entry, the script rewrites the config with semver-sort order and a new `lastUpdated` timestamp.
5. **Detect changes + create PR** if `cs/app-upgrade-config.json` changed, the workflow opens a PR `chore/update-app-upgrade-config/<safe_tag>` against `x-files/app-upgrade-config` with a commit message `🤖 chore: sync app-upgrade-config for <tag>`. Otherwise it logs that no update is required.
### Manual Trigger Guide
1. Open the Cherry Studio repository on GitHub → **Actions** tab → select **Update App Upgrade Config**.
2. Click **Run workflow**, choose the default branch (usually `main`), and fill in the `tag` input (e.g., `v2.1.0`).
3. Toggle `is_prerelease` only when the tag carries a prerelease suffix (`-beta`, `-rc`). Leave it unchecked for stable releases.
4. Start the run and wait for it to finish. Check the generated PR in the `x-files/app-upgrade-config` branch, verify the diff in `app-upgrade-config.json`, and merge once validated.
## JSON Configuration File Format
### File Location
- **GitHub**: `https://raw.githubusercontent.com/CherryHQ/cherry-studio/refs/heads/x-files/app-upgrade-config/app-upgrade-config.json`
- **GitCode**: `https://gitcode.com/CherryHQ/cherry-studio/raw/x-files/app-upgrade-config/app-upgrade-config.json`
**Note**: Both mirrors provide the same configuration file hosted on the `x-files/app-upgrade-config` branch. The client automatically selects the optimal mirror based on IP geolocation.
### Configuration Structure (Current Implementation)
```json
{
"lastUpdated": "2025-01-05T00:00:00Z",
"versions": {
"1.6.7": {
"minCompatibleVersion": "1.0.0",
"description": "Last stable v1.7.x release - required intermediate version for users below v1.7",
"channels": {
"latest": {
"version": "1.6.7",
"feedUrls": {
"github": "https://github.com/CherryHQ/cherry-studio/releases/download/v1.6.7",
"gitcode": "https://gitcode.com/CherryHQ/cherry-studio/releases/download/v1.6.7"
}
},
"rc": {
"version": "1.6.0-rc.5",
"feedUrls": {
"github": "https://github.com/CherryHQ/cherry-studio/releases/download/v1.6.0-rc.5",
"gitcode": "https://github.com/CherryHQ/cherry-studio/releases/download/v1.6.0-rc.5"
}
},
"beta": {
"version": "1.6.7-beta.3",
"feedUrls": {
"github": "https://github.com/CherryHQ/cherry-studio/releases/download/v1.7.0-beta.3",
"gitcode": "https://github.com/CherryHQ/cherry-studio/releases/download/v1.7.0-beta.3"
}
}
}
},
"2.0.0": {
"minCompatibleVersion": "1.7.0",
"description": "Major release v2.0 - required intermediate version for v2.x upgrades",
"channels": {
"latest": null,
"rc": null,
"beta": null
}
}
}
}
```
### Future Extension Example
When releasing v3.0, if users need to first upgrade to v2.8, you can add:
```json
{
"2.8.0": {
"minCompatibleVersion": "2.0.0",
"description": "Stable v2.8 - required for v3 upgrade",
"channels": {
"latest": {
"version": "2.8.0",
"feedUrls": {
"github": "https://github.com/CherryHQ/cherry-studio/releases/download/v2.8.0",
"gitcode": "https://gitcode.com/CherryHQ/cherry-studio/releases/download/v2.8.0"
}
},
"rc": null,
"beta": null
}
},
"3.0.0": {
"minCompatibleVersion": "2.8.0",
"description": "Major release v3.0",
"channels": {
"latest": {
"version": "3.0.0",
"feedUrls": {
"github": "https://github.com/CherryHQ/cherry-studio/releases/latest",
"gitcode": "https://gitcode.com/CherryHQ/cherry-studio/releases/latest"
}
},
"rc": {
"version": "3.0.0-rc.1",
"feedUrls": {
"github": "https://github.com/CherryHQ/cherry-studio/releases/download/v3.0.0-rc.1",
"gitcode": "https://gitcode.com/CherryHQ/cherry-studio/releases/download/v3.0.0-rc.1"
}
},
"beta": null
}
}
}
```
### Field Descriptions
- `lastUpdated`: Last update time of the configuration file (ISO 8601 format)
- `versions`: Version configuration object, key is the version number, sorted by semantic versioning
- `minCompatibleVersion`: Minimum compatible version that can upgrade to this version
- `description`: Version description
- `channels`: Update channel configuration
- `latest`: Stable release channel
- `rc`: Release Candidate channel
- `beta`: Beta testing channel
- Each channel contains:
- `version`: Version number for this channel
- `feedUrls`: Multi-mirror URL configuration
- `github`: electron-updater feed URL for GitHub mirror
- `gitcode`: electron-updater feed URL for GitCode mirror
- `metadata`: Stable mapping info for automation
- `segmentId`: ID from `config/app-upgrade-segments.json`
- `segmentType`: Optional flag (`legacy` | `breaking` | `latest`) for documentation/debugging
## TypeScript Type Definitions
```typescript
// Mirror enum
enum UpdateMirror {
GITHUB = 'github',
GITCODE = 'gitcode'
}
interface UpdateConfig {
lastUpdated: string
versions: {
[versionKey: string]: VersionConfig
}
}
interface VersionConfig {
minCompatibleVersion: string
description: string
channels: {
latest: ChannelConfig | null
rc: ChannelConfig | null
beta: ChannelConfig | null
}
metadata?: {
segmentId: string
segmentType?: 'legacy' | 'breaking' | 'latest'
}
}
interface ChannelConfig {
version: string
feedUrls: Record<UpdateMirror, string>
// Equivalent to:
// feedUrls: {
// github: string
// gitcode: string
// }
}
```
## Segment Metadata & Breaking Markers
- **Segment definitions** now live in `config/app-upgrade-segments.json`. Each segment describes a semantic-version range (or exact matches) plus metadata such as `segmentId`, `segmentType`, `minCompatibleVersion`, and per-channel feed URL templates.
- Each entry under `versions` carries a `metadata.segmentId`. This acts as the stable key that scripts use to decide which slot to update, even if the actual semantic version string changes.
- Mark major upgrade gateways (e.g., `2.0.0`) by giving the related segment a `segmentType: "breaking"` and (optionally) `lockedVersion`. This prevents automation from accidentally moving that entry when other 2.x builds ship.
- Adding another breaking hop (e.g., `3.0.0`) only requires defining a new segment in the JSON file; the automation will pick it up on the next run.
## Automation Workflow
Starting from this change, `.github/workflows/update-app-upgrade-config.yml` listens to GitHub release events (published + prerelease). The workflow:
1. Checks out the default branch (for scripts) and the `x-files/app-upgrade-config` branch (where the config is hosted).
2. Runs `yarn tsx scripts/update-app-upgrade-config.ts --tag <tag> --config ../cs/app-upgrade-config.json` to regenerate the config directly inside the `x-files/app-upgrade-config` working tree.
3. If the file changed, it opens a PR against `x-files/app-upgrade-config` via `peter-evans/create-pull-request`, with the generated diff limited to `app-upgrade-config.json`.
You can run the same script locally via `yarn update:upgrade-config --tag v2.1.6 --config ../cs/app-upgrade-config.json` (add `--dry-run` to preview) to reproduce or debug whatever the workflow does. Passing `--skip-release-checks` along with `--dry-run` lets you bypass the release-page existence check (useful when the GitHub/GitCode pages arent published yet). Running without `--config` continues to update the copy in your current working directory (main branch) for documentation purposes.
## Version Matching Logic
### Algorithm Flow
1. Get user's current version (`currentVersion`) and requested channel (`requestedChannel`)
2. Get all version numbers from configuration file, sort in descending order by semantic versioning
3. Iterate through the sorted version list:
- Check if `currentVersion >= minCompatibleVersion`
- Check if the requested `channel` exists and is not `null`
- If conditions are met, return the channel configuration
4. If no matching version is found, return `null`
### Pseudocode Implementation
```typescript
function findCompatibleVersion(
currentVersion: string,
requestedChannel: UpgradeChannel,
config: UpdateConfig
): ChannelConfig | null {
// Get all version numbers and sort in descending order
const versions = Object.keys(config.versions).sort(semver.rcompare)
for (const versionKey of versions) {
const versionConfig = config.versions[versionKey]
const channelConfig = versionConfig.channels[requestedChannel]
// Check version compatibility and channel availability
if (
semver.gte(currentVersion, versionConfig.minCompatibleVersion) &&
channelConfig !== null
) {
return channelConfig
}
}
return null // No compatible version found
}
```
## Upgrade Path Examples
### Scenario 1: v1.6.5 User Upgrade (Below 1.7)
- **Current Version**: 1.6.5
- **Requested Channel**: latest
- **Match Result**: 1.7.0
- **Reason**: 1.6.5 >= 0.0.0 (satisfies 1.7.0's minCompatibleVersion), but doesn't satisfy 2.0.0's minCompatibleVersion (1.7.0)
- **Action**: Prompt user to upgrade to 1.7.0, which is the required intermediate version for v2.x upgrade
### Scenario 2: v1.6.5 User Requests rc/beta
- **Current Version**: 1.6.5
- **Requested Channel**: rc or beta
- **Match Result**: 1.7.0 (latest)
- **Reason**: 1.7.0 version doesn't provide rc/beta channels (values are null)
- **Action**: Upgrade to 1.7.0 stable version
### Scenario 3: v1.7.0 User Upgrades to Latest
- **Current Version**: 1.7.0
- **Requested Channel**: latest
- **Match Result**: 2.0.0
- **Reason**: 1.7.0 >= 1.7.0 (satisfies 2.0.0's minCompatibleVersion)
- **Action**: Directly upgrade to 2.0.0 (current latest stable version)
### Scenario 4: v1.7.2 User Upgrades to RC Version
- **Current Version**: 1.7.2
- **Requested Channel**: rc
- **Match Result**: 2.0.0-rc.1
- **Reason**: 1.7.2 >= 1.7.0 (satisfies 2.0.0's minCompatibleVersion), and rc channel exists
- **Action**: Upgrade to 2.0.0-rc.1
### Scenario 5: v1.7.0 User Upgrades to Beta Version
- **Current Version**: 1.7.0
- **Requested Channel**: beta
- **Match Result**: 2.0.0-beta.1
- **Reason**: 1.7.0 >= 1.7.0, and beta channel exists
- **Action**: Upgrade to 2.0.0-beta.1
### Scenario 6: v2.5.0 User Upgrade (Future)
Assuming v2.8.0 and v3.0.0 configurations have been added:
- **Current Version**: 2.5.0
- **Requested Channel**: latest
- **Match Result**: 2.8.0
- **Reason**: 2.5.0 >= 2.0.0 (satisfies 2.8.0's minCompatibleVersion), but doesn't satisfy 3.0.0's requirement
- **Action**: Prompt user to upgrade to 2.8.0, which is the required intermediate version for v3.x upgrade
## Code Changes
### Main Modifications
1. **New Methods**
- `_fetchUpdateConfig(ipCountry: string): Promise<UpdateConfig | null>` - Fetch configuration file based on IP
- `_findCompatibleChannel(currentVersion: string, channel: UpgradeChannel, config: UpdateConfig): ChannelConfig | null` - Find compatible channel configuration
2. **Modified Methods**
- `_getReleaseVersionFromGithub()` → Remove or refactor to `_getChannelFeedUrl()`
- `_setFeedUrl()` - Use new configuration system to replace existing logic
3. **New Type Definitions**
- `UpdateConfig`
- `VersionConfig`
- `ChannelConfig`
### Mirror Selection Logic
The client automatically selects the optimal mirror based on IP geolocation:
```typescript
private async _setFeedUrl() {
const currentVersion = app.getVersion()
const testPlan = configManager.getTestPlan()
const requestedChannel = testPlan ? this._getTestChannel() : UpgradeChannel.LATEST
// Determine mirror based on IP country
const ipCountry = await getIpCountry()
const mirror = ipCountry.toLowerCase() === 'cn' ? 'gitcode' : 'github'
// Fetch update config
const config = await this._fetchUpdateConfig(mirror)
if (config) {
const channelConfig = this._findCompatibleChannel(currentVersion, requestedChannel, config)
if (channelConfig) {
// Select feed URL from the corresponding mirror
const feedUrl = channelConfig.feedUrls[mirror]
this._setChannel(requestedChannel, feedUrl)
return
}
}
// Fallback logic
const defaultFeedUrl = mirror === 'gitcode'
? FeedUrl.PRODUCTION
: FeedUrl.GITHUB_LATEST
this._setChannel(UpgradeChannel.LATEST, defaultFeedUrl)
}
private async _fetchUpdateConfig(mirror: 'github' | 'gitcode'): Promise<UpdateConfig | null> {
const configUrl = mirror === 'gitcode'
? UpdateConfigUrl.GITCODE
: UpdateConfigUrl.GITHUB
try {
const response = await net.fetch(configUrl, {
headers: {
'User-Agent': generateUserAgent(),
'Accept': 'application/json',
'X-Client-Id': configManager.getClientId()
}
})
return await response.json() as UpdateConfig
} catch (error) {
logger.error('Failed to fetch update config:', error)
return null
}
}
```
## Fallback and Error Handling Strategy
1. **Configuration file fetch failure**: Log error, return current version, don't offer updates
2. **No matching version**: Notify user that current version doesn't support automatic upgrade
3. **Network exception**: Cache last successfully fetched configuration (optional)
## GitHub Release Requirements
To support intermediate version upgrades, the following files need to be retained:
- **v1.7.0 release** and its latest*.yml files (as upgrade target for users below v1.7)
- Future intermediate versions (e.g., v2.8.0) need to retain corresponding release and latest*.yml files
- Complete installation packages for each version
### Currently Required Releases
| Version | Purpose | Must Retain |
|---------|---------|-------------|
| v1.7.0 | Upgrade target for users below 1.7 | ✅ Yes |
| v2.0.0-rc.1 | RC testing channel | ❌ Optional |
| v2.0.0-beta.1 | Beta testing channel | ❌ Optional |
| latest | Latest stable version (automatic) | ✅ Yes |
## Advantages
1. **Flexibility**: Supports arbitrarily complex upgrade paths
2. **Extensibility**: Adding new versions only requires adding new entries to the configuration file
3. **Maintainability**: Configuration is separated from code, allowing upgrade strategy adjustments without releasing new versions
4. **Multi-source support**: Automatically selects optimal configuration source based on geolocation
5. **Version control**: Enforces intermediate version upgrades, ensuring data migration and compatibility
## Future Extensions
- Support more granular version range control (e.g., `>=1.5.0 <1.8.0`)
- Support multi-step upgrade path hints (e.g., notify user needs 1.5 → 1.8 → 2.0)
- Support A/B testing and gradual rollout
- Support local caching and expiration strategy for configuration files

View File

@@ -0,0 +1,430 @@
# 更新配置系统设计文档
## 背景
当前 AppUpdater 直接请求 GitHub API 获取 beta 和 rc 的更新信息。为了支持国内用户,需要根据 IP 地理位置,分别从 GitHub/GitCode 获取一个固定的 JSON 配置文件,该文件包含所有渠道的更新地址。
## 设计目标
1. 支持根据 IP 地理位置选择不同的配置源GitHub/GitCode
2. 支持版本兼容性控制(如 v1.x 以下必须先升级到 v1.7.0 才能升级到 v2.0
3. 易于扩展支持未来多个主版本的升级路径v1.6 → v1.7 → v2.0 → v2.8 → v3.0
4. 保持与现有 electron-updater 机制的兼容性
## 当前版本策略
- **v1.7.x** 是 1.x 系列的最后版本
- **v1.7.0 以下**的用户必须先升级到 v1.7.0(或更高的 1.7.x 版本)
- **v1.7.0 及以上**的用户可以直接升级到 v2.x.x
## 自动化工作流
`x-files/app-upgrade-config/app-upgrade-config.json` 由 [`Update App Upgrade Config`](../../.github/workflows/update-app-upgrade-config.yml) workflow 自动同步。工作流会调用 [`scripts/update-app-upgrade-config.ts`](../../scripts/update-app-upgrade-config.ts) 脚本,根据指定 tag 更新 `x-files/app-upgrade-config` 分支上的配置文件。
### 触发条件
- **Release 事件(`release: released/prereleased`**
- Draft release 会被忽略。
- 当 GitHub 将 release 标记为 *prerelease*tag 必须包含 `-beta`/`-rc`(可带序号),否则直接跳过。
- 当 release 标记为稳定版时tag 必须与 GitHub API 返回的最新稳定版本一致,防止发布历史 tag 时意外挂起工作流。
- 满足上述条件后,工作流会根据语义化版本判断渠道(`latest`/`beta`/`rc`),并通过 `IS_PRERELEASE` 传递给脚本。
- **手动触发(`workflow_dispatch`**
- 必填:`tag`(例:`v2.0.1`);选填:`is_prerelease`(默认 `false`)。
-`is_prerelease=true` 时,同样要求 tag 带有 beta/rc 后缀。
- 手动运行仍会请求 GitHub 最新 release 信息,用于在 PR 说明中标注该 tag 是否是最新稳定版。
### 工作流步骤
1. **检查与元数据准备**`Check if should proceed``Prepare metadata` 步骤会计算 tag、prerelease 标志、是否最新版本以及用于分支名的 `safe_tag`。若任意校验失败,工作流立即退出。
2. **检出分支**:默认分支被检出到 `main/`,长期维护的 `x-files/app-upgrade-config` 分支则在 `cs/` 中,所有改动都发生在 `cs/`
3. **安装工具链**:安装 Node.js 22、启用 Corepack并在 `main/` 目录执行 `yarn install --immutable`
4. **运行更新脚本**:执行 `yarn tsx scripts/update-app-upgrade-config.ts --tag <tag> --config ../cs/app-upgrade-config.json --is-prerelease <flag>`
- 脚本会标准化 tag去掉 `v` 前缀等)、识别渠道、加载 `config/app-upgrade-segments.json` 中的分段规则。
- 校验 prerelease 标志与语义后缀是否匹配、强制锁定的 segment 是否满足、生成镜像的下载地址,并检查 release 是否已经在 GitHub/GitCode 可用latest 渠道在 GitCode 不可用时会回退到 `https://releases.cherry-ai.com`)。
- 更新对应的渠道配置后,脚本会按 semver 排序写回 JSON并刷新 `lastUpdated`
5. **检测变更并创建 PR**:若 `cs/app-upgrade-config.json` 有变更,则创建 `chore/update-app-upgrade-config/<safe_tag>` 分支,提交信息为 `🤖 chore: sync app-upgrade-config for <tag>`,并向 `x-files/app-upgrade-config` 提 PR无变更则输出提示。
### 手动触发指南
1. 进入 Cherry Studio 仓库的 GitHub **Actions** 页面,选择 **Update App Upgrade Config** 工作流。
2. 点击 **Run workflow**,保持默认分支(通常为 `main`),填写 `tag`(如 `v2.1.0`)。
3. 只有在 tag 带 `-beta`/`-rc` 后缀时才勾选 `is_prerelease`,稳定版保持默认。
4. 启动运行并等待完成,随后到 `x-files/app-upgrade-config` 分支的 PR 查看 `app-upgrade-config.json` 的变更并在验证后合并。
## JSON 配置文件格式
### 文件位置
- **GitHub**: `https://raw.githubusercontent.com/CherryHQ/cherry-studio/refs/heads/x-files/app-upgrade-config/app-upgrade-config.json`
- **GitCode**: `https://gitcode.com/CherryHQ/cherry-studio/raw/x-files/app-upgrade-config/app-upgrade-config.json`
**说明**:两个镜像源提供相同的配置文件,统一托管在 `x-files/app-upgrade-config` 分支上。客户端根据 IP 地理位置自动选择最优镜像源。
### 配置结构(当前实际配置)
```json
{
"lastUpdated": "2025-01-05T00:00:00Z",
"versions": {
"1.6.7": {
"minCompatibleVersion": "1.0.0",
"description": "Last stable v1.7.x release - required intermediate version for users below v1.7",
"channels": {
"latest": {
"version": "1.6.7",
"feedUrls": {
"github": "https://github.com/CherryHQ/cherry-studio/releases/download/v1.6.7",
"gitcode": "https://gitcode.com/CherryHQ/cherry-studio/releases/download/v1.6.7"
}
},
"rc": {
"version": "1.6.0-rc.5",
"feedUrls": {
"github": "https://github.com/CherryHQ/cherry-studio/releases/download/v1.6.0-rc.5",
"gitcode": "https://github.com/CherryHQ/cherry-studio/releases/download/v1.6.0-rc.5"
}
},
"beta": {
"version": "1.6.7-beta.3",
"feedUrls": {
"github": "https://github.com/CherryHQ/cherry-studio/releases/download/v1.7.0-beta.3",
"gitcode": "https://github.com/CherryHQ/cherry-studio/releases/download/v1.7.0-beta.3"
}
}
}
},
"2.0.0": {
"minCompatibleVersion": "1.7.0",
"description": "Major release v2.0 - required intermediate version for v2.x upgrades",
"channels": {
"latest": null,
"rc": null,
"beta": null
}
}
}
}
```
### 未来扩展示例
当需要发布 v3.0 时,如果需要强制用户先升级到 v2.8,可以添加:
```json
{
"2.8.0": {
"minCompatibleVersion": "2.0.0",
"description": "Stable v2.8 - required for v3 upgrade",
"channels": {
"latest": {
"version": "2.8.0",
"feedUrls": {
"github": "https://github.com/CherryHQ/cherry-studio/releases/download/v2.8.0",
"gitcode": "https://gitcode.com/CherryHQ/cherry-studio/releases/download/v2.8.0"
}
},
"rc": null,
"beta": null
}
},
"3.0.0": {
"minCompatibleVersion": "2.8.0",
"description": "Major release v3.0",
"channels": {
"latest": {
"version": "3.0.0",
"feedUrls": {
"github": "https://github.com/CherryHQ/cherry-studio/releases/latest",
"gitcode": "https://gitcode.com/CherryHQ/cherry-studio/releases/latest"
}
},
"rc": {
"version": "3.0.0-rc.1",
"feedUrls": {
"github": "https://github.com/CherryHQ/cherry-studio/releases/download/v3.0.0-rc.1",
"gitcode": "https://gitcode.com/CherryHQ/cherry-studio/releases/download/v3.0.0-rc.1"
}
},
"beta": null
}
}
}
```
### 字段说明
- `lastUpdated`: 配置文件最后更新时间ISO 8601 格式)
- `versions`: 版本配置对象key 为版本号,按语义化版本排序
- `minCompatibleVersion`: 可以升级到此版本的最低兼容版本
- `description`: 版本描述
- `channels`: 更新渠道配置
- `latest`: 稳定版渠道
- `rc`: Release Candidate 渠道
- `beta`: Beta 测试渠道
- 每个渠道包含:
- `version`: 该渠道的版本号
- `feedUrls`: 多镜像源 URL 配置
- `github`: GitHub 镜像源的 electron-updater feed URL
- `gitcode`: GitCode 镜像源的 electron-updater feed URL
- `metadata`: 自动化匹配所需的稳定标识
- `segmentId`: 来自 `config/app-upgrade-segments.json` 的段位 ID
- `segmentType`: 可选字段(`legacy` | `breaking` | `latest`),便于文档/调试
## TypeScript 类型定义
```typescript
// 镜像源枚举
enum UpdateMirror {
GITHUB = 'github',
GITCODE = 'gitcode'
}
interface UpdateConfig {
lastUpdated: string
versions: {
[versionKey: string]: VersionConfig
}
}
interface VersionConfig {
minCompatibleVersion: string
description: string
channels: {
latest: ChannelConfig | null
rc: ChannelConfig | null
beta: ChannelConfig | null
}
metadata?: {
segmentId: string
segmentType?: 'legacy' | 'breaking' | 'latest'
}
}
interface ChannelConfig {
version: string
feedUrls: Record<UpdateMirror, string>
// 等同于:
// feedUrls: {
// github: string
// gitcode: string
// }
}
```
## 段位元数据Break Change 标记)
- 所有段位定义(如 `legacy-v1``gateway-v2` 等)集中在 `config/app-upgrade-segments.json`,用于描述匹配范围、`segmentId``segmentType`、默认 `minCompatibleVersion/description` 以及各渠道的 URL 模板。
- `versions` 下的每个节点都会带上 `metadata.segmentId`。自动脚本始终依据该 ID 来定位并更新条目,即便 key 从 `2.1.5` 切换到 `2.1.6` 也不会错位。
- 如果某段需要锁死在特定版本(例如 `2.0.0` 的 break change可在段定义中设置 `segmentType: "breaking"` 并提供 `lockedVersion`,脚本在遇到不匹配的 tag 时会短路报错,保证升级路径安全。
- 面对未来新的断层(例如 `3.0.0`),只需要在段定义里新增一段,自动化即可识别并更新。
## 自动化工作流
`.github/workflows/update-app-upgrade-config.yml` 会在 GitHub Release包含正常发布与 Pre Release触发
1. 同时 Checkout 仓库默认分支(用于脚本)和 `x-files/app-upgrade-config` 分支(真实托管配置的分支)。
2. 在默认分支目录执行 `yarn tsx scripts/update-app-upgrade-config.ts --tag <tag> --config ../cs/app-upgrade-config.json`,直接重写 `x-files/app-upgrade-config` 分支里的配置文件。
3. 如果 `app-upgrade-config.json` 有变化,则通过 `peter-evans/create-pull-request` 自动创建一个指向 `x-files/app-upgrade-config` 的 PRDiff 仅包含该文件。
如需本地调试,可执行 `yarn update:upgrade-config --tag v2.1.6 --config ../cs/app-upgrade-config.json`(加 `--dry-run` 仅打印结果)来复现 CI 行为。若需要暂时跳过 GitHub/GitCode Release 页面是否就绪的校验,可在 `--dry-run` 的同时附加 `--skip-release-checks`。不加 `--config` 时默认更新当前工作目录(通常是 main 分支)下的副本,方便文档/审查。
## 版本匹配逻辑
### 算法流程
1. 获取用户当前版本(`currentVersion`)和请求的渠道(`requestedChannel`
2. 获取配置文件中所有版本号,按语义化版本从大到小排序
3. 遍历排序后的版本列表:
- 检查 `currentVersion >= minCompatibleVersion`
- 检查请求的 `channel` 是否存在且不为 `null`
- 如果满足条件,返回该渠道配置
4. 如果没有找到匹配版本,返回 `null`
### 伪代码实现
```typescript
function findCompatibleVersion(
currentVersion: string,
requestedChannel: UpgradeChannel,
config: UpdateConfig
): ChannelConfig | null {
// 获取所有版本号并从大到小排序
const versions = Object.keys(config.versions).sort(semver.rcompare)
for (const versionKey of versions) {
const versionConfig = config.versions[versionKey]
const channelConfig = versionConfig.channels[requestedChannel]
// 检查版本兼容性和渠道可用性
if (
semver.gte(currentVersion, versionConfig.minCompatibleVersion) &&
channelConfig !== null
) {
return channelConfig
}
}
return null // 没有找到兼容版本
}
```
## 升级路径示例
### 场景 1: v1.6.5 用户升级(低于 1.7
- **当前版本**: 1.6.5
- **请求渠道**: latest
- **匹配结果**: 1.7.0
- **原因**: 1.6.5 >= 0.0.0(满足 1.7.0 的 minCompatibleVersion但不满足 2.0.0 的 minCompatibleVersion (1.7.0)
- **操作**: 提示用户升级到 1.7.0,这是升级到 v2.x 的必要中间版本
### 场景 2: v1.6.5 用户请求 rc/beta
- **当前版本**: 1.6.5
- **请求渠道**: rc 或 beta
- **匹配结果**: 1.7.0 (latest)
- **原因**: 1.7.0 版本不提供 rc/beta 渠道(值为 null
- **操作**: 升级到 1.7.0 稳定版
### 场景 3: v1.7.0 用户升级到最新版
- **当前版本**: 1.7.0
- **请求渠道**: latest
- **匹配结果**: 2.0.0
- **原因**: 1.7.0 >= 1.7.0(满足 2.0.0 的 minCompatibleVersion
- **操作**: 直接升级到 2.0.0(当前最新稳定版)
### 场景 4: v1.7.2 用户升级到 RC 版本
- **当前版本**: 1.7.2
- **请求渠道**: rc
- **匹配结果**: 2.0.0-rc.1
- **原因**: 1.7.2 >= 1.7.0(满足 2.0.0 的 minCompatibleVersion且 rc 渠道存在
- **操作**: 升级到 2.0.0-rc.1
### 场景 5: v1.7.0 用户升级到 Beta 版本
- **当前版本**: 1.7.0
- **请求渠道**: beta
- **匹配结果**: 2.0.0-beta.1
- **原因**: 1.7.0 >= 1.7.0,且 beta 渠道存在
- **操作**: 升级到 2.0.0-beta.1
### 场景 6: v2.5.0 用户升级(未来)
假设已添加 v2.8.0 和 v3.0.0 配置:
- **当前版本**: 2.5.0
- **请求渠道**: latest
- **匹配结果**: 2.8.0
- **原因**: 2.5.0 >= 2.0.0(满足 2.8.0 的 minCompatibleVersion但不满足 3.0.0 的要求
- **操作**: 提示用户升级到 2.8.0,这是升级到 v3.x 的必要中间版本
## 代码改动计划
### 主要修改
1. **新增方法**
- `_fetchUpdateConfig(ipCountry: string): Promise<UpdateConfig | null>` - 根据 IP 获取配置文件
- `_findCompatibleChannel(currentVersion: string, channel: UpgradeChannel, config: UpdateConfig): ChannelConfig | null` - 查找兼容的渠道配置
2. **修改方法**
- `_getReleaseVersionFromGithub()` → 移除或重构为 `_getChannelFeedUrl()`
- `_setFeedUrl()` - 使用新的配置系统替代现有逻辑
3. **新增类型定义**
- `UpdateConfig`
- `VersionConfig`
- `ChannelConfig`
### 镜像源选择逻辑
客户端根据 IP 地理位置自动选择最优镜像源:
```typescript
private async _setFeedUrl() {
const currentVersion = app.getVersion()
const testPlan = configManager.getTestPlan()
const requestedChannel = testPlan ? this._getTestChannel() : UpgradeChannel.LATEST
// 根据 IP 国家确定镜像源
const ipCountry = await getIpCountry()
const mirror = ipCountry.toLowerCase() === 'cn' ? 'gitcode' : 'github'
// 获取更新配置
const config = await this._fetchUpdateConfig(mirror)
if (config) {
const channelConfig = this._findCompatibleChannel(currentVersion, requestedChannel, config)
if (channelConfig) {
// 从配置中选择对应镜像源的 URL
const feedUrl = channelConfig.feedUrls[mirror]
this._setChannel(requestedChannel, feedUrl)
return
}
}
// Fallback 逻辑
const defaultFeedUrl = mirror === 'gitcode'
? FeedUrl.PRODUCTION
: FeedUrl.GITHUB_LATEST
this._setChannel(UpgradeChannel.LATEST, defaultFeedUrl)
}
private async _fetchUpdateConfig(mirror: 'github' | 'gitcode'): Promise<UpdateConfig | null> {
const configUrl = mirror === 'gitcode'
? UpdateConfigUrl.GITCODE
: UpdateConfigUrl.GITHUB
try {
const response = await net.fetch(configUrl, {
headers: {
'User-Agent': generateUserAgent(),
'Accept': 'application/json',
'X-Client-Id': configManager.getClientId()
}
})
return await response.json() as UpdateConfig
} catch (error) {
logger.error('Failed to fetch update config:', error)
return null
}
}
```
## 降级和容错策略
1. **配置文件获取失败**: 记录错误日志,返回当前版本,不提供更新
2. **没有匹配的版本**: 提示用户当前版本不支持自动升级
3. **网络异常**: 缓存上次成功获取的配置(可选)
## GitHub Release 要求
为支持中间版本升级,需要保留以下文件:
- **v1.7.0 release** 及其 latest*.yml 文件(作为 v1.7 以下用户的升级目标)
- 未来如需强制中间版本(如 v2.8.0),需要保留对应的 release 和 latest*.yml 文件
- 各版本的完整安装包
### 当前需要的 Release
| 版本 | 用途 | 必须保留 |
|------|------|---------|
| v1.7.0 | 1.7 以下用户的升级目标 | ✅ 是 |
| v2.0.0-rc.1 | RC 测试渠道 | ❌ 可选 |
| v2.0.0-beta.1 | Beta 测试渠道 | ❌ 可选 |
| latest | 最新稳定版(自动) | ✅ 是 |
## 优势
1. **灵活性**: 支持任意复杂的升级路径
2. **可扩展性**: 新增版本只需在配置文件中添加新条目
3. **可维护性**: 配置与代码分离,无需发版即可调整升级策略
4. **多源支持**: 自动根据地理位置选择最优配置源
5. **版本控制**: 强制中间版本升级,确保数据迁移和兼容性
## 未来扩展
- 支持更细粒度的版本范围控制(如 `>=1.5.0 <1.8.0`
- 支持多步升级路径提示(如提示用户需要 1.5 → 1.8 → 2.0
- 支持 A/B 测试和灰度发布
- 支持配置文件的本地缓存和过期策略

View File

@@ -97,7 +97,6 @@ mac:
entitlementsInherit: build/entitlements.mac.plist
notarize: false
artifactName: ${productName}-${version}-${arch}.${ext}
minimumSystemVersion: "20.1.0" # 最低支持 macOS 11.0
extendInfo:
- NSCameraUsageDescription: Application requests access to the device's camera.
- NSMicrophoneUsageDescription: Application requests access to the device's microphone.
@@ -135,42 +134,58 @@ artifactBuildCompleted: scripts/artifact-build-completed.js
releaseInfo:
releaseNotes: |
<!--LANG:en-->
What's New in v1.7.0-beta.6
What's New in v1.7.0-rc.1
New Features:
- Enhanced Input Bar: Completely redesigned input bar with improved responsiveness and functionality
- Better File Handling: Improved drag-and-drop and paste support for images and documents
- Smart Tool Suggestions: Enhanced quick panel with better item selection and keyboard shortcuts
🎉 MAJOR NEW FEATURE: AI Agents
- Create and manage custom AI agents with specialized tools and permissions
- Dedicated agent sessions with persistent SQLite storage, separate from regular chats
- Real-time tool approval system - review and approve agent actions dynamically
- MCP (Model Context Protocol) integration for connecting external tools
- Slash commands support for quick agent interactions
- OpenAI-compatible REST API for agent access
Improvements:
- Smoother Input Experience: Better auto-resizing and text handling in chat input
- Enhanced AI Performance: Improved connection stability and response speed
- More Reliable File Uploads: Better support for various file types and upload scenarios
- Cleaner Interface: Optimized UI elements for better visual consistency
✨ New Features:
- AI Providers: Added support for Hugging Face, Mistral, Perplexity, and SophNet
- Knowledge Base: OpenMinerU document preprocessor, full-text search in notes, enhanced tool selection
- Image & OCR: Intel OVMS painting provider and Intel OpenVINO (NPU) OCR support
- MCP Management: Redesigned interface with dual-column layout for easier management
- Languages: Added German language support
Bug Fixes:
- Fixed image selection issue when adding custom AI providers
- Fixed file upload problems with certain API configurations
- Fixed input bar responsiveness issues
- Fixed quick panel not working properly in some situations
⚡ Improvements:
- Upgraded to Electron 38.7.0
- Enhanced system shutdown handling and automatic update checks
- Improved proxy bypass rules
🐛 Important Bug Fixes:
- Fixed streaming response issues across multiple AI providers
- Fixed session list scrolling problems
- Fixed knowledge base deletion errors
<!--LANG:zh-CN-->
v1.7.0-beta.6 新特性
v1.7.0-rc.1 新特性
新功能:
- 增强输入栏:完全重新设计的输入栏,响应更灵敏,功能更强大
- 更好的文件处理:改进的拖拽和粘贴功能,支持图片和文档
- 智能工具建议:增强的快速面板,更好的项目选择和键盘快捷键
🎉 重大更新AI Agent 智能体系统
- 创建和管理专属 AI Agent配置专用工具和权限
- 独立的 Agent 会话,使用 SQLite 持久化存储,与普通聊天分离
- 实时工具审批系统 - 动态审查和批准 Agent 操作
- MCP模型上下文协议集成连接外部工具
- 支持斜杠命令快速交互
- 兼容 OpenAI 的 REST API 访问
改进:
- 更流畅的输入体验:聊天输入框的自动调整和文本处理更佳
- 增强 AI 性能:改进连接稳定性和响应速度
- 更可靠的文件上传:更好地支持各种文件类型和上传场景
- 更简洁的界面:优化 UI 元素,视觉一致性更好
✨ 新功能:
- AI 提供商:新增 Hugging Face、Mistral、Perplexity 和 SophNet 支持
- 知识库OpenMinerU 文档预处理器、笔记全文搜索、增强的工具选择
- 图像与 OCRIntel OVMS 绘图提供商和 Intel OpenVINO (NPU) OCR 支持
- MCP 管理:重构管理界面,采用双列布局,更加方便管理
- 语言:新增德语支持
问题修复:
- 修复添加自定义 AI 提供商时的图片选择问题
- 修复某些 API 配置下的文件上传问题
- 修复输入栏响应性问题
- 修复快速面板在某些情况下无法正常工作的问题
⚡ 改进:
- 升级到 Electron 38.7.0
- 增强的系统关机处理和自动更新检查
- 改进的代理绕过规则
🐛 重要修复:
- 修复多个 AI 提供商的流式响应问题
- 修复会话列表滚动问题
- 修复知识库删除错误
<!--LANG:END-->

View File

@@ -1,6 +1,6 @@
{
"name": "CherryStudio",
"version": "1.7.0-beta.3",
"version": "1.7.0-rc.1",
"private": true,
"description": "A powerful AI assistant for producer.",
"main": "./out/main/index.js",
@@ -58,6 +58,7 @@
"update:i18n": "dotenv -e .env -- tsx scripts/update-i18n.ts",
"auto:i18n": "dotenv -e .env -- tsx scripts/auto-translate-i18n.ts",
"update:languages": "tsx scripts/update-languages.ts",
"update:upgrade-config": "tsx scripts/update-app-upgrade-config.ts",
"test": "vitest run --silent",
"test:main": "vitest run --project main",
"test:renderer": "vitest run --project renderer",
@@ -73,9 +74,10 @@
"format:check": "biome format && biome lint",
"prepare": "git config blame.ignoreRevsFile .git-blame-ignore-revs && husky",
"claude": "dotenv -e .env -- claude",
"release:aicore:alpha": "yarn workspace @cherrystudio/ai-core version prerelease --immediate && yarn workspace @cherrystudio/ai-core npm publish --tag alpha --access public",
"release:aicore:beta": "yarn workspace @cherrystudio/ai-core version prerelease --immediate && yarn workspace @cherrystudio/ai-core npm publish --tag beta --access public",
"release:aicore": "yarn workspace @cherrystudio/ai-core version patch --immediate && yarn workspace @cherrystudio/ai-core npm publish --access public"
"release:aicore:alpha": "yarn workspace @cherrystudio/ai-core version prerelease --preid alpha --immediate && yarn workspace @cherrystudio/ai-core build && yarn workspace @cherrystudio/ai-core npm publish --tag alpha --access public",
"release:aicore:beta": "yarn workspace @cherrystudio/ai-core version prerelease --preid beta --immediate && yarn workspace @cherrystudio/ai-core build && yarn workspace @cherrystudio/ai-core npm publish --tag beta --access public",
"release:aicore": "yarn workspace @cherrystudio/ai-core version patch --immediate && yarn workspace @cherrystudio/ai-core build && yarn workspace @cherrystudio/ai-core npm publish --access public",
"release:ai-sdk-provider": "yarn workspace @cherrystudio/ai-sdk-provider version patch --immediate && yarn workspace @cherrystudio/ai-sdk-provider build && yarn workspace @cherrystudio/ai-sdk-provider npm publish --access public"
},
"dependencies": {
"@anthropic-ai/claude-agent-sdk": "patch:@anthropic-ai/claude-agent-sdk@npm%3A0.1.30#~/.yarn/patches/@anthropic-ai-claude-agent-sdk-npm-0.1.30-b50a299674.patch",
@@ -84,6 +86,7 @@
"@napi-rs/system-ocr": "patch:@napi-rs/system-ocr@npm%3A1.0.2#~/.yarn/patches/@napi-rs-system-ocr-npm-1.0.2-59e7a78e8b.patch",
"@paymoapp/electron-shutdown-handler": "^1.1.2",
"@strongtz/win32-arm64-msvc": "^0.4.7",
"emoji-picker-element-data": "^1",
"express": "^5.1.0",
"font-list": "^2.0.0",
"graceful-fs": "^4.2.11",
@@ -106,13 +109,17 @@
"@agentic/exa": "^7.3.3",
"@agentic/searxng": "^7.3.3",
"@agentic/tavily": "^7.3.3",
"@ai-sdk/amazon-bedrock": "^3.0.53",
"@ai-sdk/amazon-bedrock": "^3.0.56",
"@ai-sdk/anthropic": "^2.0.45",
"@ai-sdk/cerebras": "^1.0.31",
"@ai-sdk/gateway": "^2.0.9",
"@ai-sdk/google-vertex": "^3.0.62",
"@ai-sdk/huggingface": "patch:@ai-sdk/huggingface@npm%3A0.0.8#~/.yarn/patches/@ai-sdk-huggingface-npm-0.0.8-d4d0aaac93.patch",
"@ai-sdk/mistral": "^2.0.23",
"@ai-sdk/perplexity": "^2.0.17",
"@ai-sdk/gateway": "^2.0.13",
"@ai-sdk/google": "patch:@ai-sdk/google@npm%3A2.0.40#~/.yarn/patches/@ai-sdk-google-npm-2.0.40-47e0eeee83.patch",
"@ai-sdk/google-vertex": "^3.0.72",
"@ai-sdk/huggingface": "^0.0.10",
"@ai-sdk/mistral": "^2.0.24",
"@ai-sdk/openai": "patch:@ai-sdk/openai@npm%3A2.0.71#~/.yarn/patches/@ai-sdk-openai-npm-2.0.71-a88ef00525.patch",
"@ai-sdk/perplexity": "^2.0.20",
"@ai-sdk/test-server": "^0.0.1",
"@ant-design/v5-patch-for-react-19": "^1.0.3",
"@anthropic-ai/sdk": "^0.41.0",
"@anthropic-ai/vertex-sdk": "patch:@anthropic-ai/vertex-sdk@npm%3A0.11.4#~/.yarn/patches/@anthropic-ai-vertex-sdk-npm-0.11.4-c19cb41edb.patch",
@@ -120,7 +127,7 @@
"@aws-sdk/client-bedrock-runtime": "^3.910.0",
"@aws-sdk/client-s3": "^3.910.0",
"@biomejs/biome": "2.2.4",
"@cherrystudio/ai-core": "workspace:^1.0.0-alpha.18",
"@cherrystudio/ai-core": "workspace:^1.0.9",
"@cherrystudio/embedjs": "^0.1.31",
"@cherrystudio/embedjs-libsql": "^0.1.31",
"@cherrystudio/embedjs-loader-csv": "^0.1.31",
@@ -134,7 +141,7 @@
"@cherrystudio/embedjs-ollama": "^0.1.31",
"@cherrystudio/embedjs-openai": "^0.1.31",
"@cherrystudio/extension-table-plus": "workspace:^",
"@cherrystudio/openai": "^6.5.0",
"@cherrystudio/openai": "^6.9.0",
"@dnd-kit/core": "^6.3.1",
"@dnd-kit/modifiers": "^9.0.0",
"@dnd-kit/sortable": "^10.0.0",
@@ -153,18 +160,19 @@
"@langchain/community": "^1.0.0",
"@langchain/core": "patch:@langchain/core@npm%3A1.0.2#~/.yarn/patches/@langchain-core-npm-1.0.2-183ef83fe4.patch",
"@langchain/openai": "patch:@langchain/openai@npm%3A1.0.0#~/.yarn/patches/@langchain-openai-npm-1.0.0-474d0ad9d4.patch",
"@mcp-ui/client": "^5.14.1",
"@mistralai/mistralai": "^1.7.5",
"@modelcontextprotocol/sdk": "^1.17.5",
"@mozilla/readability": "^0.6.0",
"@notionhq/client": "^2.2.15",
"@openrouter/ai-sdk-provider": "^1.2.0",
"@openrouter/ai-sdk-provider": "^1.2.5",
"@opentelemetry/api": "^1.9.0",
"@opentelemetry/core": "2.0.0",
"@opentelemetry/exporter-trace-otlp-http": "^0.200.0",
"@opentelemetry/sdk-trace-base": "^2.0.0",
"@opentelemetry/sdk-trace-node": "^2.0.0",
"@opentelemetry/sdk-trace-web": "^2.0.0",
"@opeoginni/github-copilot-openai-compatible": "0.1.19",
"@opeoginni/github-copilot-openai-compatible": "0.1.21",
"@playwright/test": "^1.52.0",
"@radix-ui/react-context-menu": "^2.2.16",
"@reduxjs/toolkit": "^2.2.5",
@@ -233,7 +241,7 @@
"@viz-js/lang-dot": "^1.0.5",
"@viz-js/viz": "^3.14.0",
"@xyflow/react": "^12.4.4",
"ai": "^5.0.90",
"ai": "^5.0.98",
"antd": "patch:antd@npm%3A5.27.0#~/.yarn/patches/antd-npm-5.27.0-aa91c36546.patch",
"archiver": "^7.0.1",
"async-mutex": "^0.5.0",
@@ -259,12 +267,12 @@
"dotenv-cli": "^7.4.2",
"drizzle-kit": "^0.31.4",
"drizzle-orm": "^0.44.5",
"electron": "38.4.0",
"electron-builder": "26.0.15",
"electron": "38.7.0",
"electron-builder": "26.1.0",
"electron-devtools-installer": "^3.2.0",
"electron-reload": "^2.0.0-alpha.1",
"electron-store": "^8.2.0",
"electron-updater": "6.6.4",
"electron-updater": "patch:electron-updater@npm%3A6.7.0#~/.yarn/patches/electron-updater-npm-6.7.0-47b11bb0d4.patch",
"electron-vite": "4.0.1",
"electron-window-state": "^5.0.3",
"emittery": "^1.0.3",
@@ -381,13 +389,11 @@
"@codemirror/lint": "6.8.5",
"@codemirror/view": "6.38.1",
"@langchain/core@npm:^0.3.26": "patch:@langchain/core@npm%3A1.0.2#~/.yarn/patches/@langchain-core-npm-1.0.2-183ef83fe4.patch",
"app-builder-lib@npm:26.0.13": "patch:app-builder-lib@npm%3A26.0.13#~/.yarn/patches/app-builder-lib-npm-26.0.13-a064c9e1d0.patch",
"app-builder-lib@npm:26.0.15": "patch:app-builder-lib@npm%3A26.0.15#~/.yarn/patches/app-builder-lib-npm-26.0.15-360e5b0476.patch",
"atomically@npm:^1.7.0": "patch:atomically@npm%3A1.7.0#~/.yarn/patches/atomically-npm-1.7.0-e742e5293b.patch",
"esbuild": "^0.25.0",
"file-stream-rotator@npm:^0.6.1": "patch:file-stream-rotator@npm%3A0.6.1#~/.yarn/patches/file-stream-rotator-npm-0.6.1-eab45fb13d.patch",
"libsql@npm:^0.4.4": "patch:libsql@npm%3A0.4.7#~/.yarn/patches/libsql-npm-0.4.7-444e260fb1.patch",
"node-abi": "4.12.0",
"node-abi": "4.24.0",
"openai@npm:^4.77.0": "npm:@cherrystudio/openai@6.5.0",
"openai@npm:^4.87.3": "npm:@cherrystudio/openai@6.5.0",
"pdf-parse@npm:1.1.1": "patch:pdf-parse@npm%3A1.1.1#~/.yarn/patches/pdf-parse-npm-1.1.1-04a6109b2a.patch",
@@ -408,8 +414,11 @@
"@langchain/openai@npm:^0.3.16": "patch:@langchain/openai@npm%3A1.0.0#~/.yarn/patches/@langchain-openai-npm-1.0.0-474d0ad9d4.patch",
"@langchain/openai@npm:>=0.2.0 <0.7.0": "patch:@langchain/openai@npm%3A1.0.0#~/.yarn/patches/@langchain-openai-npm-1.0.0-474d0ad9d4.patch",
"@ai-sdk/openai@npm:2.0.64": "patch:@ai-sdk/openai@npm%3A2.0.64#~/.yarn/patches/@ai-sdk-openai-npm-2.0.64-48f99f5bf3.patch",
"@ai-sdk/openai@npm:^2.0.42": "patch:@ai-sdk/openai@npm%3A2.0.64#~/.yarn/patches/@ai-sdk-openai-npm-2.0.64-48f99f5bf3.patch",
"@ai-sdk/google@npm:2.0.31": "patch:@ai-sdk/google@npm%3A2.0.31#~/.yarn/patches/@ai-sdk-google-npm-2.0.31-b0de047210.patch"
"@ai-sdk/openai@npm:^2.0.42": "patch:@ai-sdk/openai@npm%3A2.0.71#~/.yarn/patches/@ai-sdk-openai-npm-2.0.71-a88ef00525.patch",
"@ai-sdk/google@npm:2.0.40": "patch:@ai-sdk/google@npm%3A2.0.40#~/.yarn/patches/@ai-sdk-google-npm-2.0.40-47e0eeee83.patch",
"@ai-sdk/openai@npm:2.0.71": "patch:@ai-sdk/openai@npm%3A2.0.71#~/.yarn/patches/@ai-sdk-openai-npm-2.0.71-a88ef00525.patch",
"@ai-sdk/openai-compatible@npm:1.0.27": "patch:@ai-sdk/openai-compatible@npm%3A1.0.27#~/.yarn/patches/@ai-sdk-openai-compatible-npm-1.0.27-06f74278cf.patch",
"@ai-sdk/openai-compatible@npm:^1.0.19": "patch:@ai-sdk/openai-compatible@npm%3A1.0.27#~/.yarn/patches/@ai-sdk-openai-compatible-npm-1.0.27-06f74278cf.patch"
},
"packageManager": "yarn@4.9.1",
"lint-staged": {

View File

@@ -1,6 +1,6 @@
{
"name": "@cherrystudio/ai-sdk-provider",
"version": "0.1.0",
"version": "0.1.3",
"description": "Cherry Studio AI SDK provider bundle with CherryIN routing.",
"keywords": [
"ai-sdk",
@@ -42,7 +42,7 @@
},
"dependencies": {
"@ai-sdk/provider": "^2.0.0",
"@ai-sdk/provider-utils": "^3.0.12"
"@ai-sdk/provider-utils": "^3.0.17"
},
"devDependencies": {
"tsdown": "^0.13.3",

View File

@@ -67,6 +67,10 @@ export interface CherryInProviderSettings {
* Optional static headers applied to every request.
*/
headers?: HeadersInput
/**
* Optional endpoint type to distinguish different endpoint behaviors.
*/
endpointType?: 'openai' | 'openai-response' | 'anthropic' | 'gemini' | 'image-generation' | 'jina-rerank'
}
export interface CherryInProvider extends ProviderV2 {
@@ -151,7 +155,8 @@ export const createCherryIn = (options: CherryInProviderSettings = {}): CherryIn
baseURL = DEFAULT_CHERRYIN_BASE_URL,
anthropicBaseURL = DEFAULT_CHERRYIN_ANTHROPIC_BASE_URL,
geminiBaseURL = DEFAULT_CHERRYIN_GEMINI_BASE_URL,
fetch
fetch,
endpointType
} = options
const getJsonHeaders = createJsonHeadersGetter(options)
@@ -205,7 +210,7 @@ export const createCherryIn = (options: CherryInProviderSettings = {}): CherryIn
fetch
})
const createChatModel = (modelId: string, settings: OpenAIProviderSettings = {}) => {
const createChatModelByModelId = (modelId: string, settings: OpenAIProviderSettings = {}) => {
if (isAnthropicModel(modelId)) {
return createAnthropicModel(modelId)
}
@@ -223,6 +228,29 @@ export const createCherryIn = (options: CherryInProviderSettings = {}): CherryIn
})
}
const createChatModel = (modelId: string, settings: OpenAIProviderSettings = {}) => {
if (!endpointType) return createChatModelByModelId(modelId, settings)
switch (endpointType) {
case 'anthropic':
return createAnthropicModel(modelId)
case 'gemini':
return createGeminiModel(modelId)
case 'openai':
return createOpenAIChatModel(modelId)
case 'openai-response':
default:
return new OpenAIResponsesLanguageModel(modelId, {
provider: `${CHERRYIN_PROVIDER_NAME}.openai`,
url,
headers: () => ({
...getJsonHeaders(),
...settings.headers
}),
fetch
})
}
}
const createCompletionModel = (modelId: string, settings: OpenAIProviderSettings = {}) =>
new OpenAICompletionLanguageModel(modelId, {
provider: `${CHERRYIN_PROVIDER_NAME}.completion`,

View File

@@ -71,7 +71,7 @@ Cherry Studio AI Core 是一个基于 Vercel AI SDK 的统一 AI Provider 接口
## 安装
```bash
npm install @cherrystudio/ai-core ai
npm install @cherrystudio/ai-core ai @ai-sdk/google @ai-sdk/openai
```
### React Native

View File

@@ -1,6 +1,6 @@
{
"name": "@cherrystudio/ai-core",
"version": "1.0.1",
"version": "1.0.9",
"description": "Cherry Studio AI Core - Unified AI Provider Interface Based on Vercel AI SDK",
"main": "dist/index.js",
"module": "dist/index.mjs",
@@ -33,19 +33,19 @@
},
"homepage": "https://github.com/CherryHQ/cherry-studio#readme",
"peerDependencies": {
"@ai-sdk/google": "^2.0.36",
"@ai-sdk/openai": "^2.0.64",
"@cherrystudio/ai-sdk-provider": "^0.1.3",
"ai": "^5.0.26"
},
"dependencies": {
"@ai-sdk/anthropic": "^2.0.43",
"@ai-sdk/azure": "^2.0.66",
"@ai-sdk/deepseek": "^1.0.27",
"@ai-sdk/google": "patch:@ai-sdk/google@npm%3A2.0.31#~/.yarn/patches/@ai-sdk-google-npm-2.0.31-b0de047210.patch",
"@ai-sdk/openai": "patch:@ai-sdk/openai@npm%3A2.0.64#~/.yarn/patches/@ai-sdk-openai-npm-2.0.64-48f99f5bf3.patch",
"@ai-sdk/openai-compatible": "^1.0.26",
"@ai-sdk/anthropic": "^2.0.45",
"@ai-sdk/azure": "^2.0.73",
"@ai-sdk/deepseek": "^1.0.29",
"@ai-sdk/openai-compatible": "patch:@ai-sdk/openai-compatible@npm%3A1.0.27#~/.yarn/patches/@ai-sdk-openai-compatible-npm-1.0.27-06f74278cf.patch",
"@ai-sdk/provider": "^2.0.0",
"@ai-sdk/provider-utils": "^3.0.16",
"@ai-sdk/xai": "^2.0.31",
"@cherrystudio/ai-sdk-provider": "workspace:*",
"@ai-sdk/provider-utils": "^3.0.17",
"@ai-sdk/xai": "^2.0.34",
"zod": "^4.1.5"
},
"devDependencies": {

View File

@@ -0,0 +1,180 @@
/**
* Mock Provider Instances
* Provides mock implementations for all supported AI providers
*/
import type { ImageModelV2, LanguageModelV2 } from '@ai-sdk/provider'
import { vi } from 'vitest'
/**
* Creates a mock language model with customizable behavior
*/
export function createMockLanguageModel(overrides?: Partial<LanguageModelV2>): LanguageModelV2 {
return {
specificationVersion: 'v1',
provider: 'mock-provider',
modelId: 'mock-model',
defaultObjectGenerationMode: 'tool',
doGenerate: vi.fn().mockResolvedValue({
text: 'Mock response text',
finishReason: 'stop',
usage: {
promptTokens: 10,
completionTokens: 20,
totalTokens: 30
},
rawCall: { rawPrompt: null, rawSettings: {} },
rawResponse: { headers: {} },
warnings: []
}),
doStream: vi.fn().mockReturnValue({
stream: (async function* () {
yield {
type: 'text-delta',
textDelta: 'Mock '
}
yield {
type: 'text-delta',
textDelta: 'streaming '
}
yield {
type: 'text-delta',
textDelta: 'response'
}
yield {
type: 'finish',
finishReason: 'stop',
usage: {
promptTokens: 10,
completionTokens: 15,
totalTokens: 25
}
}
})(),
rawCall: { rawPrompt: null, rawSettings: {} },
rawResponse: { headers: {} },
warnings: []
}),
...overrides
} as LanguageModelV2
}
/**
* Creates a mock image model with customizable behavior
*/
export function createMockImageModel(overrides?: Partial<ImageModelV2>): ImageModelV2 {
return {
specificationVersion: 'v2',
provider: 'mock-provider',
modelId: 'mock-image-model',
doGenerate: vi.fn().mockResolvedValue({
images: [
{
base64: 'mock-base64-image-data',
uint8Array: new Uint8Array([1, 2, 3, 4, 5]),
mimeType: 'image/png'
}
],
warnings: []
}),
...overrides
} as ImageModelV2
}
/**
* Mock provider configurations for testing
*/
export const mockProviderConfigs = {
openai: {
apiKey: 'sk-test-openai-key-123456789',
baseURL: 'https://api.openai.com/v1',
organization: 'test-org'
},
anthropic: {
apiKey: 'sk-ant-test-key-123456789',
baseURL: 'https://api.anthropic.com'
},
google: {
apiKey: 'test-google-api-key-123456789',
baseURL: 'https://generativelanguage.googleapis.com/v1'
},
xai: {
apiKey: 'xai-test-key-123456789',
baseURL: 'https://api.x.ai/v1'
},
azure: {
apiKey: 'test-azure-key-123456789',
resourceName: 'test-resource',
deployment: 'test-deployment'
},
deepseek: {
apiKey: 'sk-test-deepseek-key-123456789',
baseURL: 'https://api.deepseek.com/v1'
},
openrouter: {
apiKey: 'sk-or-test-key-123456789',
baseURL: 'https://openrouter.ai/api/v1'
},
huggingface: {
apiKey: 'hf_test_key_123456789',
baseURL: 'https://api-inference.huggingface.co'
},
'openai-compatible': {
apiKey: 'test-compatible-key-123456789',
baseURL: 'https://api.example.com/v1',
name: 'test-provider'
},
'openai-chat': {
apiKey: 'sk-test-chat-key-123456789',
baseURL: 'https://api.openai.com/v1'
}
} as const
/**
* Mock provider instances for testing
*/
export const mockProviderInstances = {
openai: {
name: 'openai-mock',
languageModel: createMockLanguageModel({ provider: 'openai', modelId: 'gpt-4' }),
imageModel: createMockImageModel({ provider: 'openai', modelId: 'dall-e-3' })
},
anthropic: {
name: 'anthropic-mock',
languageModel: createMockLanguageModel({ provider: 'anthropic', modelId: 'claude-3-5-sonnet-20241022' })
},
google: {
name: 'google-mock',
languageModel: createMockLanguageModel({ provider: 'google', modelId: 'gemini-2.0-flash-exp' }),
imageModel: createMockImageModel({ provider: 'google', modelId: 'imagen-3.0-generate-001' })
},
xai: {
name: 'xai-mock',
languageModel: createMockLanguageModel({ provider: 'xai', modelId: 'grok-2-latest' }),
imageModel: createMockImageModel({ provider: 'xai', modelId: 'grok-2-image-latest' })
},
deepseek: {
name: 'deepseek-mock',
languageModel: createMockLanguageModel({ provider: 'deepseek', modelId: 'deepseek-chat' })
}
}
export type ProviderId = keyof typeof mockProviderConfigs

View File

@@ -0,0 +1,331 @@
/**
* Mock Responses
* Provides realistic mock responses for all provider types
*/
import { jsonSchema, type ModelMessage, type Tool } from 'ai'
/**
* Standard test messages for all scenarios
*/
export const testMessages = {
simple: [{ role: 'user' as const, content: 'Hello, how are you?' }],
conversation: [
{ role: 'user' as const, content: 'What is the capital of France?' },
{ role: 'assistant' as const, content: 'The capital of France is Paris.' },
{ role: 'user' as const, content: 'What is its population?' }
],
withSystem: [
{ role: 'system' as const, content: 'You are a helpful assistant that provides concise answers.' },
{ role: 'user' as const, content: 'Explain quantum computing in one sentence.' }
],
withImages: [
{
role: 'user' as const,
content: [
{ type: 'text' as const, text: 'What is in this image?' },
{
type: 'image' as const,
image:
''
}
]
}
],
toolUse: [{ role: 'user' as const, content: 'What is the weather in San Francisco?' }],
multiTurn: [
{ role: 'user' as const, content: 'Can you help me with a math problem?' },
{ role: 'assistant' as const, content: 'Of course! What math problem would you like help with?' },
{ role: 'user' as const, content: 'What is 15 * 23?' },
{ role: 'assistant' as const, content: '15 * 23 = 345' },
{ role: 'user' as const, content: 'Now divide that by 5' }
]
} satisfies Record<string, ModelMessage[]>
/**
* Standard test tools for tool calling scenarios
*/
export const testTools: Record<string, Tool> = {
getWeather: {
description: 'Get the current weather in a given location',
inputSchema: jsonSchema({
type: 'object',
properties: {
location: {
type: 'string',
description: 'The city and state, e.g. San Francisco, CA'
},
unit: {
type: 'string',
enum: ['celsius', 'fahrenheit'],
description: 'The temperature unit to use'
}
},
required: ['location']
}),
execute: async ({ location, unit = 'fahrenheit' }) => {
return {
location,
temperature: unit === 'celsius' ? 22 : 72,
unit,
condition: 'sunny'
}
}
},
calculate: {
description: 'Perform a mathematical calculation',
inputSchema: jsonSchema({
type: 'object',
properties: {
operation: {
type: 'string',
enum: ['add', 'subtract', 'multiply', 'divide'],
description: 'The operation to perform'
},
a: {
type: 'number',
description: 'The first number'
},
b: {
type: 'number',
description: 'The second number'
}
},
required: ['operation', 'a', 'b']
}),
execute: async ({ operation, a, b }) => {
const operations = {
add: (x: number, y: number) => x + y,
subtract: (x: number, y: number) => x - y,
multiply: (x: number, y: number) => x * y,
divide: (x: number, y: number) => x / y
}
return { result: operations[operation as keyof typeof operations](a, b) }
}
},
searchDatabase: {
description: 'Search for information in a database',
inputSchema: jsonSchema({
type: 'object',
properties: {
query: {
type: 'string',
description: 'The search query'
},
limit: {
type: 'number',
description: 'Maximum number of results to return',
default: 10
}
},
required: ['query']
}),
execute: async ({ query, limit = 10 }) => {
return {
results: [
{ id: 1, title: `Result 1 for ${query}`, relevance: 0.95 },
{ id: 2, title: `Result 2 for ${query}`, relevance: 0.87 }
].slice(0, limit)
}
}
}
}
/**
* Mock streaming chunks for different providers
*/
export const mockStreamingChunks = {
text: [
{ type: 'text-delta' as const, textDelta: 'Hello' },
{ type: 'text-delta' as const, textDelta: ', ' },
{ type: 'text-delta' as const, textDelta: 'this ' },
{ type: 'text-delta' as const, textDelta: 'is ' },
{ type: 'text-delta' as const, textDelta: 'a ' },
{ type: 'text-delta' as const, textDelta: 'test.' }
],
withToolCall: [
{ type: 'text-delta' as const, textDelta: 'Let me check the weather for you.' },
{
type: 'tool-call-delta' as const,
toolCallType: 'function' as const,
toolCallId: 'call_123',
toolName: 'getWeather',
argsTextDelta: '{"location":'
},
{
type: 'tool-call-delta' as const,
toolCallType: 'function' as const,
toolCallId: 'call_123',
toolName: 'getWeather',
argsTextDelta: ' "San Francisco, CA"}'
},
{
type: 'tool-call' as const,
toolCallType: 'function' as const,
toolCallId: 'call_123',
toolName: 'getWeather',
args: { location: 'San Francisco, CA' }
}
],
withFinish: [
{ type: 'text-delta' as const, textDelta: 'Complete response.' },
{
type: 'finish' as const,
finishReason: 'stop' as const,
usage: {
promptTokens: 10,
completionTokens: 5,
totalTokens: 15
}
}
]
}
/**
* Mock complete responses for non-streaming scenarios
*/
export const mockCompleteResponses = {
simple: {
text: 'This is a simple response.',
finishReason: 'stop' as const,
usage: {
promptTokens: 15,
completionTokens: 8,
totalTokens: 23
}
},
withToolCalls: {
text: 'I will check the weather for you.',
toolCalls: [
{
toolCallId: 'call_456',
toolName: 'getWeather',
args: { location: 'New York, NY', unit: 'celsius' }
}
],
finishReason: 'tool-calls' as const,
usage: {
promptTokens: 25,
completionTokens: 12,
totalTokens: 37
}
},
withWarnings: {
text: 'Response with warnings.',
finishReason: 'stop' as const,
usage: {
promptTokens: 10,
completionTokens: 5,
totalTokens: 15
},
warnings: [
{
type: 'unsupported-setting' as const,
message: 'Temperature parameter not supported for this model'
}
]
}
}
/**
* Mock image generation responses
*/
export const mockImageResponses = {
single: {
image: {
base64: 'iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==',
uint8Array: new Uint8Array([137, 80, 78, 71, 13, 10, 26, 10, 0, 0, 0, 13, 73, 72, 68, 82]),
mimeType: 'image/png' as const
},
warnings: []
},
multiple: {
images: [
{
base64: 'iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==',
uint8Array: new Uint8Array([137, 80, 78, 71]),
mimeType: 'image/png' as const
},
{
base64: 'iVBORw0KGgoAAAANSUhEUgAAAAIAAAACCAYAAABytg0kAAAAEklEQVR42mNk+M9QzwAEjDAGACCKAgdZ9zImAAAAAElFTkSuQmCC',
uint8Array: new Uint8Array([137, 80, 78, 71]),
mimeType: 'image/png' as const
}
],
warnings: []
},
withProviderMetadata: {
image: {
base64: 'iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mNk+M9QDwADhgGAWjR9awAAAABJRU5ErkJggg==',
uint8Array: new Uint8Array([137, 80, 78, 71]),
mimeType: 'image/png' as const
},
providerMetadata: {
openai: {
images: [
{
revisedPrompt: 'A detailed and enhanced version of the original prompt'
}
]
}
},
warnings: []
}
}
/**
* Mock error responses
*/
export const mockErrors = {
invalidApiKey: {
name: 'APIError',
message: 'Invalid API key provided',
statusCode: 401
},
rateLimitExceeded: {
name: 'RateLimitError',
message: 'Rate limit exceeded. Please try again later.',
statusCode: 429,
headers: {
'retry-after': '60'
}
},
modelNotFound: {
name: 'ModelNotFoundError',
message: 'The requested model was not found',
statusCode: 404
},
contextLengthExceeded: {
name: 'ContextLengthError',
message: "This model's maximum context length is 4096 tokens",
statusCode: 400
},
timeout: {
name: 'TimeoutError',
message: 'Request timed out after 30000ms',
code: 'ETIMEDOUT'
},
networkError: {
name: 'NetworkError',
message: 'Network connection failed',
code: 'ECONNREFUSED'
}
}

View File

@@ -0,0 +1,329 @@
/**
* Provider-Specific Test Utilities
* Helper functions for testing individual providers with all their parameters
*/
import type { Tool } from 'ai'
import { expect } from 'vitest'
/**
* Provider parameter configurations for comprehensive testing
*/
export const providerParameterMatrix = {
openai: {
models: ['gpt-4', 'gpt-4-turbo', 'gpt-3.5-turbo', 'gpt-4o'],
parameters: {
temperature: [0, 0.5, 0.7, 1.0, 1.5, 2.0],
maxTokens: [100, 500, 1000, 2000, 4000],
topP: [0.1, 0.5, 0.9, 1.0],
frequencyPenalty: [-2.0, -1.0, 0, 1.0, 2.0],
presencePenalty: [-2.0, -1.0, 0, 1.0, 2.0],
stop: [undefined, ['stop'], ['STOP', 'END']],
seed: [undefined, 12345, 67890],
responseFormat: [undefined, { type: 'json_object' as const }],
user: [undefined, 'test-user-123']
},
toolChoice: ['auto', 'required', 'none', { type: 'function' as const, name: 'getWeather' }],
parallelToolCalls: [true, false]
},
anthropic: {
models: ['claude-3-5-sonnet-20241022', 'claude-3-opus-20240229', 'claude-3-haiku-20240307'],
parameters: {
temperature: [0, 0.5, 1.0],
maxTokens: [100, 1000, 4000, 8000],
topP: [0.1, 0.5, 0.9, 1.0],
topK: [undefined, 1, 5, 10, 40],
stop: [undefined, ['Human:', 'Assistant:']],
metadata: [undefined, { userId: 'test-123' }]
},
toolChoice: ['auto', 'any', { type: 'tool' as const, name: 'getWeather' }]
},
google: {
models: ['gemini-2.0-flash-exp', 'gemini-1.5-pro', 'gemini-1.5-flash'],
parameters: {
temperature: [0, 0.5, 0.9, 1.0],
maxTokens: [100, 1000, 2000, 8000],
topP: [0.1, 0.5, 0.95, 1.0],
topK: [undefined, 1, 16, 40],
stopSequences: [undefined, ['END'], ['STOP', 'TERMINATE']]
},
safetySettings: [
undefined,
[
{ category: 'HARM_CATEGORY_HARASSMENT', threshold: 'BLOCK_MEDIUM_AND_ABOVE' },
{ category: 'HARM_CATEGORY_HATE_SPEECH', threshold: 'BLOCK_ONLY_HIGH' }
]
]
},
xai: {
models: ['grok-2-latest', 'grok-2-1212'],
parameters: {
temperature: [0, 0.5, 1.0, 1.5],
maxTokens: [100, 500, 2000, 4000],
topP: [0.1, 0.5, 0.9, 1.0],
stop: [undefined, ['STOP'], ['END', 'TERMINATE']],
seed: [undefined, 12345]
}
},
deepseek: {
models: ['deepseek-chat', 'deepseek-coder'],
parameters: {
temperature: [0, 0.5, 1.0],
maxTokens: [100, 1000, 4000],
topP: [0.1, 0.5, 0.95],
frequencyPenalty: [0, 0.5, 1.0],
presencePenalty: [0, 0.5, 1.0],
stop: [undefined, ['```'], ['END']]
}
},
azure: {
deployments: ['gpt-4-deployment', 'gpt-35-turbo-deployment'],
parameters: {
temperature: [0, 0.7, 1.0],
maxTokens: [100, 1000, 2000],
topP: [0.1, 0.5, 0.95],
frequencyPenalty: [0, 1.0],
presencePenalty: [0, 1.0],
stop: [undefined, ['STOP']]
}
}
} as const
/**
* Creates test cases for all parameter combinations
*/
export function generateParameterTestCases<T extends Record<string, any[]>>(
params: T,
maxCombinations = 50
): Array<Partial<{ [K in keyof T]: T[K][number] }>> {
const keys = Object.keys(params) as Array<keyof T>
const testCases: Array<Partial<{ [K in keyof T]: T[K][number] }>> = []
// Generate combinations using sampling strategy for large parameter spaces
const totalCombinations = keys.reduce((acc, key) => acc * params[key].length, 1)
if (totalCombinations <= maxCombinations) {
// Generate all combinations if total is small
generateAllCombinations(params, keys, 0, {}, testCases)
} else {
// Sample diverse combinations if total is large
generateSampledCombinations(params, keys, maxCombinations, testCases)
}
return testCases
}
function generateAllCombinations<T extends Record<string, any[]>>(
params: T,
keys: Array<keyof T>,
index: number,
current: Partial<{ [K in keyof T]: T[K][number] }>,
results: Array<Partial<{ [K in keyof T]: T[K][number] }>>
) {
if (index === keys.length) {
results.push({ ...current })
return
}
const key = keys[index]
for (const value of params[key]) {
generateAllCombinations(params, keys, index + 1, { ...current, [key]: value }, results)
}
}
function generateSampledCombinations<T extends Record<string, any[]>>(
params: T,
keys: Array<keyof T>,
count: number,
results: Array<Partial<{ [K in keyof T]: T[K][number] }>>
) {
// Generate edge cases first (min/max values)
const edgeCase1: any = {}
const edgeCase2: any = {}
for (const key of keys) {
edgeCase1[key] = params[key][0]
edgeCase2[key] = params[key][params[key].length - 1]
}
results.push(edgeCase1, edgeCase2)
// Generate random combinations for the rest
for (let i = results.length; i < count; i++) {
const combination: any = {}
for (const key of keys) {
const values = params[key]
combination[key] = values[Math.floor(Math.random() * values.length)]
}
results.push(combination)
}
}
/**
* Validates that all provider-specific parameters are correctly passed through
*/
export function validateProviderParams(providerId: string, actualParams: any, expectedParams: any): void {
const requiredFields: Record<string, string[]> = {
openai: ['model', 'messages'],
anthropic: ['model', 'messages'],
google: ['model', 'contents'],
xai: ['model', 'messages'],
deepseek: ['model', 'messages'],
azure: ['messages']
}
const fields = requiredFields[providerId] || ['model', 'messages']
for (const field of fields) {
expect(actualParams).toHaveProperty(field)
}
// Validate optional parameters if they were provided
const optionalParams = ['temperature', 'max_tokens', 'top_p', 'stop', 'tools']
for (const param of optionalParams) {
if (expectedParams[param] !== undefined) {
expect(actualParams[param]).toEqual(expectedParams[param])
}
}
}
/**
* Creates a comprehensive test suite for a provider
*/
// oxlint-disable-next-line no-unused-vars
export function createProviderTestSuite(_providerId: string) {
return {
testBasicCompletion: async (executor: any, model: string) => {
const result = await executor.generateText({
model,
messages: [{ role: 'user' as const, content: 'Hello' }]
})
expect(result).toBeDefined()
expect(result.text).toBeDefined()
expect(typeof result.text).toBe('string')
},
testStreaming: async (executor: any, model: string) => {
const chunks: any[] = []
const result = await executor.streamText({
model,
messages: [{ role: 'user' as const, content: 'Hello' }]
})
for await (const chunk of result.textStream) {
chunks.push(chunk)
}
expect(chunks.length).toBeGreaterThan(0)
},
testTemperature: async (executor: any, model: string, temperatures: number[]) => {
for (const temperature of temperatures) {
const result = await executor.generateText({
model,
messages: [{ role: 'user' as const, content: 'Hello' }],
temperature
})
expect(result).toBeDefined()
}
},
testMaxTokens: async (executor: any, model: string, maxTokensValues: number[]) => {
for (const maxTokens of maxTokensValues) {
const result = await executor.generateText({
model,
messages: [{ role: 'user' as const, content: 'Hello' }],
maxTokens
})
expect(result).toBeDefined()
if (result.usage?.completionTokens) {
expect(result.usage.completionTokens).toBeLessThanOrEqual(maxTokens)
}
}
},
testToolCalling: async (executor: any, model: string, tools: Record<string, Tool>) => {
const result = await executor.generateText({
model,
messages: [{ role: 'user' as const, content: 'What is the weather in SF?' }],
tools
})
expect(result).toBeDefined()
},
testStopSequences: async (executor: any, model: string, stopSequences: string[][]) => {
for (const stop of stopSequences) {
const result = await executor.generateText({
model,
messages: [{ role: 'user' as const, content: 'Count to 10' }],
stop
})
expect(result).toBeDefined()
}
}
}
}
/**
* Generates test data for vision/multimodal testing
*/
export function createVisionTestData() {
return {
imageUrl: 'https://example.com/test-image.jpg',
base64Image:
'',
messages: [
{
role: 'user' as const,
content: [
{ type: 'text' as const, text: 'What is in this image?' },
{
type: 'image' as const,
image:
''
}
]
}
]
}
}
/**
* Creates mock responses for different finish reasons
*/
export function createFinishReasonMocks() {
return {
stop: {
text: 'Complete response.',
finishReason: 'stop' as const,
usage: { promptTokens: 10, completionTokens: 5, totalTokens: 15 }
},
length: {
text: 'Incomplete response due to',
finishReason: 'length' as const,
usage: { promptTokens: 10, completionTokens: 100, totalTokens: 110 }
},
'tool-calls': {
text: 'Calling tools',
finishReason: 'tool-calls' as const,
toolCalls: [{ toolCallId: 'call_1', toolName: 'getWeather', args: { location: 'SF' } }],
usage: { promptTokens: 10, completionTokens: 8, totalTokens: 18 }
},
'content-filter': {
text: '',
finishReason: 'content-filter' as const,
usage: { promptTokens: 10, completionTokens: 0, totalTokens: 10 }
}
}
}

View File

@@ -0,0 +1,291 @@
/**
* Test Utilities
* Helper functions for testing AI Core functionality
*/
import { expect, vi } from 'vitest'
import type { ProviderId } from '../fixtures/mock-providers'
import { createMockImageModel, createMockLanguageModel, mockProviderConfigs } from '../fixtures/mock-providers'
/**
* Creates a test provider with streaming support
*/
export function createTestStreamingProvider(chunks: any[]) {
return createMockLanguageModel({
doStream: vi.fn().mockReturnValue({
stream: (async function* () {
for (const chunk of chunks) {
yield chunk
}
})(),
rawCall: { rawPrompt: null, rawSettings: {} },
rawResponse: { headers: {} },
warnings: []
})
})
}
/**
* Creates a test provider that throws errors
*/
export function createErrorProvider(error: Error) {
return createMockLanguageModel({
doGenerate: vi.fn().mockRejectedValue(error),
doStream: vi.fn().mockImplementation(() => {
throw error
})
})
}
/**
* Collects all chunks from a stream
*/
export async function collectStreamChunks<T>(stream: AsyncIterable<T>): Promise<T[]> {
const chunks: T[] = []
for await (const chunk of stream) {
chunks.push(chunk)
}
return chunks
}
/**
* Waits for a specific number of milliseconds
*/
export function wait(ms: number): Promise<void> {
return new Promise((resolve) => setTimeout(resolve, ms))
}
/**
* Creates a mock abort controller that aborts after a delay
*/
export function createDelayedAbortController(delayMs: number): AbortController {
const controller = new AbortController()
setTimeout(() => controller.abort(), delayMs)
return controller
}
/**
* Asserts that a function throws an error with a specific message
*/
export async function expectError(fn: () => Promise<any>, expectedMessage?: string | RegExp): Promise<Error> {
try {
await fn()
throw new Error('Expected function to throw an error, but it did not')
} catch (error) {
if (expectedMessage) {
const message = (error as Error).message
if (typeof expectedMessage === 'string') {
if (!message.includes(expectedMessage)) {
throw new Error(`Expected error message to include "${expectedMessage}", but got "${message}"`)
}
} else {
if (!expectedMessage.test(message)) {
throw new Error(`Expected error message to match ${expectedMessage}, but got "${message}"`)
}
}
}
return error as Error
}
}
/**
* Creates a spy function that tracks calls and arguments
*/
export function createSpy<T extends (...args: any[]) => any>() {
const calls: Array<{ args: Parameters<T>; result?: ReturnType<T>; error?: Error }> = []
const spy = vi.fn((...args: Parameters<T>) => {
try {
const result = undefined as ReturnType<T>
calls.push({ args, result })
return result
} catch (error) {
calls.push({ args, error: error as Error })
throw error
}
})
return {
fn: spy,
calls,
getCalls: () => calls,
getCallCount: () => calls.length,
getLastCall: () => calls[calls.length - 1],
reset: () => {
calls.length = 0
spy.mockClear()
}
}
}
/**
* Validates provider configuration
*/
export function validateProviderConfig(providerId: ProviderId) {
const config = mockProviderConfigs[providerId]
if (!config) {
throw new Error(`No mock configuration found for provider: ${providerId}`)
}
if (!config.apiKey) {
throw new Error(`Provider ${providerId} is missing apiKey in mock config`)
}
return config
}
/**
* Creates a test context with common setup
*/
export function createTestContext() {
const mocks = {
languageModel: createMockLanguageModel(),
imageModel: createMockImageModel(),
providers: new Map<string, any>()
}
const cleanup = () => {
mocks.providers.clear()
vi.clearAllMocks()
}
return {
mocks,
cleanup
}
}
/**
* Measures execution time of an async function
*/
export async function measureTime<T>(fn: () => Promise<T>): Promise<{ result: T; duration: number }> {
const start = Date.now()
const result = await fn()
const duration = Date.now() - start
return { result, duration }
}
/**
* Retries a function until it succeeds or max attempts reached
*/
export async function retryUntilSuccess<T>(fn: () => Promise<T>, maxAttempts = 3, delayMs = 100): Promise<T> {
let lastError: Error | undefined
for (let attempt = 1; attempt <= maxAttempts; attempt++) {
try {
return await fn()
} catch (error) {
lastError = error as Error
if (attempt < maxAttempts) {
await wait(delayMs)
}
}
}
throw lastError || new Error('All retry attempts failed')
}
/**
* Creates a mock streaming response that emits chunks at intervals
*/
export function createTimedStream<T>(chunks: T[], intervalMs = 10) {
return {
async *[Symbol.asyncIterator]() {
for (const chunk of chunks) {
await wait(intervalMs)
yield chunk
}
}
}
}
/**
* Asserts that two objects are deeply equal, ignoring specified keys
*/
export function assertDeepEqualIgnoring<T extends Record<string, any>>(
actual: T,
expected: T,
ignoreKeys: string[] = []
): void {
const filterKeys = (obj: T): Partial<T> => {
const filtered = { ...obj }
for (const key of ignoreKeys) {
delete filtered[key]
}
return filtered
}
const filteredActual = filterKeys(actual)
const filteredExpected = filterKeys(expected)
expect(filteredActual).toEqual(filteredExpected)
}
/**
* Creates a provider mock that simulates rate limiting
*/
export function createRateLimitedProvider(limitPerSecond: number) {
const calls: number[] = []
return createMockLanguageModel({
doGenerate: vi.fn().mockImplementation(async () => {
const now = Date.now()
calls.push(now)
// Remove calls older than 1 second
const recentCalls = calls.filter((time) => now - time < 1000)
if (recentCalls.length > limitPerSecond) {
throw new Error('Rate limit exceeded')
}
return {
text: 'Rate limited response',
finishReason: 'stop' as const,
usage: { promptTokens: 10, completionTokens: 5, totalTokens: 15 },
rawCall: { rawPrompt: null, rawSettings: {} },
rawResponse: { headers: {} },
warnings: []
}
})
})
}
/**
* Validates streaming response structure
*/
export function validateStreamChunk(chunk: any): void {
expect(chunk).toBeDefined()
expect(chunk).toHaveProperty('type')
if (chunk.type === 'text-delta') {
expect(chunk).toHaveProperty('textDelta')
expect(typeof chunk.textDelta).toBe('string')
} else if (chunk.type === 'finish') {
expect(chunk).toHaveProperty('finishReason')
expect(chunk).toHaveProperty('usage')
} else if (chunk.type === 'tool-call') {
expect(chunk).toHaveProperty('toolCallId')
expect(chunk).toHaveProperty('toolName')
expect(chunk).toHaveProperty('args')
}
}
/**
* Creates a test logger that captures log messages
*/
export function createTestLogger() {
const logs: Array<{ level: string; message: string; meta?: any }> = []
return {
info: (message: string, meta?: any) => logs.push({ level: 'info', message, meta }),
warn: (message: string, meta?: any) => logs.push({ level: 'warn', message, meta }),
error: (message: string, meta?: any) => logs.push({ level: 'error', message, meta }),
debug: (message: string, meta?: any) => logs.push({ level: 'debug', message, meta }),
getLogs: () => logs,
clear: () => {
logs.length = 0
}
}
}

View File

@@ -0,0 +1,12 @@
/**
* Test Infrastructure Exports
* Central export point for all test utilities, fixtures, and helpers
*/
// Fixtures
export * from './fixtures/mock-providers'
export * from './fixtures/mock-responses'
// Helpers
export * from './helpers/provider-test-utils'
export * from './helpers/test-utils'

View File

@@ -4,12 +4,7 @@
*/
export const BUILT_IN_PLUGIN_PREFIX = 'built-in:'
export { googleToolsPlugin } from './googleToolsPlugin'
export { createLoggingPlugin } from './logging'
export { createPromptToolUsePlugin } from './toolUsePlugin/promptToolUsePlugin'
export type {
PromptToolUseConfig,
ToolUseRequestContext,
ToolUseResult
} from './toolUsePlugin/type'
export { webSearchPlugin, type WebSearchPluginConfig } from './webSearchPlugin'
export * from './googleToolsPlugin'
export * from './toolUsePlugin/promptToolUsePlugin'
export * from './toolUsePlugin/type'
export * from './webSearchPlugin'

View File

@@ -32,7 +32,7 @@ export const webSearchPlugin = (config: WebSearchPluginConfig = DEFAULT_WEB_SEAR
})
// 导出类型定义供开发者使用
export type { WebSearchPluginConfig, WebSearchToolOutputSchema } from './helper'
export * from './helper'
// 默认导出
export default webSearchPlugin

View File

@@ -44,7 +44,7 @@ export {
// ==================== 基础数据和类型 ====================
// 基础Provider数据源
export { baseProviderIds, baseProviders } from './schemas'
export { baseProviderIds, baseProviders, isBaseProvider } from './schemas'
// 类型定义和Schema
export type {

View File

@@ -7,7 +7,6 @@ import { createAzure } from '@ai-sdk/azure'
import { type AzureOpenAIProviderSettings } from '@ai-sdk/azure'
import { createDeepSeek } from '@ai-sdk/deepseek'
import { createGoogleGenerativeAI } from '@ai-sdk/google'
import { createHuggingFace } from '@ai-sdk/huggingface'
import { createOpenAI, type OpenAIProviderSettings } from '@ai-sdk/openai'
import { createOpenAICompatible } from '@ai-sdk/openai-compatible'
import type { LanguageModelV2 } from '@ai-sdk/provider'
@@ -33,8 +32,7 @@ export const baseProviderIds = [
'deepseek',
'openrouter',
'cherryin',
'cherryin-chat',
'huggingface'
'cherryin-chat'
] as const
/**
@@ -158,12 +156,6 @@ export const baseProviders = [
})
},
supportsImageGeneration: true
},
{
id: 'huggingface',
name: 'HuggingFace',
creator: createHuggingFace,
supportsImageGeneration: true
}
] as const satisfies BaseProvider[]

View File

@@ -0,0 +1,499 @@
/**
* RuntimeExecutor.generateText Comprehensive Tests
* Tests non-streaming text generation across all providers with various parameters
*/
import { generateText } from 'ai'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import {
createMockLanguageModel,
mockCompleteResponses,
mockProviderConfigs,
testMessages,
testTools
} from '../../../__tests__'
import type { AiPlugin } from '../../plugins'
import { globalRegistryManagement } from '../../providers/RegistryManagement'
import { RuntimeExecutor } from '../executor'
// Mock AI SDK
vi.mock('ai', () => ({
generateText: vi.fn()
}))
vi.mock('../../providers/RegistryManagement', () => ({
globalRegistryManagement: {
languageModel: vi.fn()
},
DEFAULT_SEPARATOR: '|'
}))
describe('RuntimeExecutor.generateText', () => {
let executor: RuntimeExecutor<'openai'>
let mockLanguageModel: any
beforeEach(() => {
vi.clearAllMocks()
executor = RuntimeExecutor.create('openai', mockProviderConfigs.openai)
mockLanguageModel = createMockLanguageModel({
provider: 'openai',
modelId: 'gpt-4'
})
vi.mocked(globalRegistryManagement.languageModel).mockReturnValue(mockLanguageModel)
vi.mocked(generateText).mockResolvedValue(mockCompleteResponses.simple as any)
})
describe('Basic Functionality', () => {
it('should generate text with minimal parameters', async () => {
const result = await executor.generateText({
model: 'gpt-4',
messages: testMessages.simple
})
expect(generateText).toHaveBeenCalledWith({
model: mockLanguageModel,
messages: testMessages.simple
})
expect(result.text).toBe('This is a simple response.')
expect(result.finishReason).toBe('stop')
expect(result.usage).toBeDefined()
})
it('should generate with system messages', async () => {
await executor.generateText({
model: 'gpt-4',
messages: testMessages.withSystem
})
expect(generateText).toHaveBeenCalledWith({
model: mockLanguageModel,
messages: testMessages.withSystem
})
})
it('should generate with conversation history', async () => {
await executor.generateText({
model: 'gpt-4',
messages: testMessages.conversation
})
expect(generateText).toHaveBeenCalledWith(
expect.objectContaining({
messages: testMessages.conversation
})
)
})
})
describe('All Parameter Combinations', () => {
it('should support all parameters together', async () => {
await executor.generateText({
model: 'gpt-4',
messages: testMessages.simple,
temperature: 0.7,
maxOutputTokens: 500,
topP: 0.9,
frequencyPenalty: 0.5,
presencePenalty: 0.3,
stopSequences: ['STOP'],
seed: 12345
})
expect(generateText).toHaveBeenCalledWith(
expect.objectContaining({
temperature: 0.7,
maxOutputTokens: 500,
topP: 0.9,
frequencyPenalty: 0.5,
presencePenalty: 0.3,
stopSequences: ['STOP'],
seed: 12345
})
)
})
it('should support partial parameters', async () => {
await executor.generateText({
model: 'gpt-4',
messages: testMessages.simple,
temperature: 0.5,
maxOutputTokens: 100
})
expect(generateText).toHaveBeenCalledWith(
expect.objectContaining({
temperature: 0.5,
maxOutputTokens: 100
})
)
})
})
describe('Tool Calling', () => {
beforeEach(() => {
vi.mocked(generateText).mockResolvedValue(mockCompleteResponses.withToolCalls as any)
})
it('should support tool calling', async () => {
const result = await executor.generateText({
model: 'gpt-4',
messages: testMessages.toolUse,
tools: testTools
})
expect(generateText).toHaveBeenCalledWith(
expect.objectContaining({
tools: testTools
})
)
expect(result.toolCalls).toBeDefined()
expect(result.toolCalls).toHaveLength(1)
})
it('should support toolChoice auto', async () => {
await executor.generateText({
model: 'gpt-4',
messages: testMessages.toolUse,
tools: testTools,
toolChoice: 'auto'
})
expect(generateText).toHaveBeenCalledWith(
expect.objectContaining({
toolChoice: 'auto'
})
)
})
it('should support toolChoice required', async () => {
await executor.generateText({
model: 'gpt-4',
messages: testMessages.toolUse,
tools: testTools,
toolChoice: 'required'
})
expect(generateText).toHaveBeenCalledWith(
expect.objectContaining({
toolChoice: 'required'
})
)
})
it('should support toolChoice none', async () => {
vi.mocked(generateText).mockResolvedValue(mockCompleteResponses.simple as any)
await executor.generateText({
model: 'gpt-4',
messages: testMessages.simple,
tools: testTools,
toolChoice: 'none'
})
expect(generateText).toHaveBeenCalledWith(
expect.objectContaining({
toolChoice: 'none'
})
)
})
it('should support specific tool selection', async () => {
await executor.generateText({
model: 'gpt-4',
messages: testMessages.toolUse,
tools: testTools,
toolChoice: {
type: 'tool',
toolName: 'getWeather'
}
})
expect(generateText).toHaveBeenCalledWith(
expect.objectContaining({
toolChoice: {
type: 'tool',
toolName: 'getWeather'
}
})
)
})
})
describe('Multiple Providers', () => {
it('should work with Anthropic provider', async () => {
const anthropicExecutor = RuntimeExecutor.create('anthropic', mockProviderConfigs.anthropic)
const anthropicModel = createMockLanguageModel({
provider: 'anthropic',
modelId: 'claude-3-5-sonnet-20241022'
})
vi.mocked(globalRegistryManagement.languageModel).mockReturnValue(anthropicModel)
await anthropicExecutor.generateText({
model: 'claude-3-5-sonnet-20241022',
messages: testMessages.simple
})
expect(globalRegistryManagement.languageModel).toHaveBeenCalledWith('anthropic|claude-3-5-sonnet-20241022')
})
it('should work with Google provider', async () => {
const googleExecutor = RuntimeExecutor.create('google', mockProviderConfigs.google)
const googleModel = createMockLanguageModel({
provider: 'google',
modelId: 'gemini-2.0-flash-exp'
})
vi.mocked(globalRegistryManagement.languageModel).mockReturnValue(googleModel)
await googleExecutor.generateText({
model: 'gemini-2.0-flash-exp',
messages: testMessages.simple
})
expect(globalRegistryManagement.languageModel).toHaveBeenCalledWith('google|gemini-2.0-flash-exp')
})
it('should work with xAI provider', async () => {
const xaiExecutor = RuntimeExecutor.create('xai', mockProviderConfigs.xai)
const xaiModel = createMockLanguageModel({
provider: 'xai',
modelId: 'grok-2-latest'
})
vi.mocked(globalRegistryManagement.languageModel).mockReturnValue(xaiModel)
await xaiExecutor.generateText({
model: 'grok-2-latest',
messages: testMessages.simple
})
expect(globalRegistryManagement.languageModel).toHaveBeenCalledWith('xai|grok-2-latest')
})
it('should work with DeepSeek provider', async () => {
const deepseekExecutor = RuntimeExecutor.create('deepseek', mockProviderConfigs.deepseek)
const deepseekModel = createMockLanguageModel({
provider: 'deepseek',
modelId: 'deepseek-chat'
})
vi.mocked(globalRegistryManagement.languageModel).mockReturnValue(deepseekModel)
await deepseekExecutor.generateText({
model: 'deepseek-chat',
messages: testMessages.simple
})
expect(globalRegistryManagement.languageModel).toHaveBeenCalledWith('deepseek|deepseek-chat')
})
})
describe('Plugin Integration', () => {
it('should execute all plugin hooks', async () => {
const pluginCalls: string[] = []
const testPlugin: AiPlugin = {
name: 'test-plugin',
onRequestStart: vi.fn(async () => {
pluginCalls.push('onRequestStart')
}),
transformParams: vi.fn(async (params) => {
pluginCalls.push('transformParams')
return { ...params, temperature: 0.8 }
}),
transformResult: vi.fn(async (result) => {
pluginCalls.push('transformResult')
return { ...result, text: result.text + ' [modified]' }
}),
onRequestEnd: vi.fn(async () => {
pluginCalls.push('onRequestEnd')
})
}
const executorWithPlugin = RuntimeExecutor.create('openai', mockProviderConfigs.openai, [testPlugin])
const result = await executorWithPlugin.generateText({
model: 'gpt-4',
messages: testMessages.simple
})
expect(pluginCalls).toEqual(['onRequestStart', 'transformParams', 'transformResult', 'onRequestEnd'])
// Verify transformed parameters
expect(generateText).toHaveBeenCalledWith(
expect.objectContaining({
temperature: 0.8
})
)
// Verify transformed result
expect(result.text).toContain('[modified]')
})
it('should handle multiple plugins in order', async () => {
const pluginOrder: string[] = []
const plugin1: AiPlugin = {
name: 'plugin-1',
transformParams: vi.fn(async (params) => {
pluginOrder.push('plugin-1')
return { ...params, temperature: 0.5 }
})
}
const plugin2: AiPlugin = {
name: 'plugin-2',
transformParams: vi.fn(async (params) => {
pluginOrder.push('plugin-2')
return { ...params, maxTokens: 200 }
})
}
const executorWithPlugins = RuntimeExecutor.create('openai', mockProviderConfigs.openai, [plugin1, plugin2])
await executorWithPlugins.generateText({
model: 'gpt-4',
messages: testMessages.simple
})
expect(pluginOrder).toEqual(['plugin-1', 'plugin-2'])
expect(generateText).toHaveBeenCalledWith(
expect.objectContaining({
temperature: 0.5,
maxTokens: 200
})
)
})
})
describe('Error Handling', () => {
it('should handle API errors', async () => {
const error = new Error('API request failed')
vi.mocked(generateText).mockRejectedValue(error)
await expect(
executor.generateText({
model: 'gpt-4',
messages: testMessages.simple
})
).rejects.toThrow('API request failed')
})
it('should execute onError plugin hook', async () => {
const error = new Error('Generation failed')
vi.mocked(generateText).mockRejectedValue(error)
const errorPlugin: AiPlugin = {
name: 'error-handler',
onError: vi.fn()
}
const executorWithPlugin = RuntimeExecutor.create('openai', mockProviderConfigs.openai, [errorPlugin])
await expect(
executorWithPlugin.generateText({
model: 'gpt-4',
messages: testMessages.simple
})
).rejects.toThrow('Generation failed')
expect(errorPlugin.onError).toHaveBeenCalledWith(
error,
expect.objectContaining({
providerId: 'openai',
modelId: 'gpt-4'
})
)
})
it('should handle model not found error', async () => {
const error = new Error('Model not found: invalid-model')
vi.mocked(globalRegistryManagement.languageModel).mockImplementation(() => {
throw error
})
await expect(
executor.generateText({
model: 'invalid-model',
messages: testMessages.simple
})
).rejects.toThrow('Model not found')
})
})
describe('Usage and Metadata', () => {
it('should return usage information', async () => {
const result = await executor.generateText({
model: 'gpt-4',
messages: testMessages.simple
})
expect(result.usage).toBeDefined()
expect(result.usage.inputTokens).toBe(15)
expect(result.usage.outputTokens).toBe(8)
expect(result.usage.totalTokens).toBe(23)
})
it('should handle warnings', async () => {
vi.mocked(generateText).mockResolvedValue(mockCompleteResponses.withWarnings as any)
const result = await executor.generateText({
model: 'gpt-4',
messages: testMessages.simple,
temperature: 2.5 // Unsupported value
})
expect(result.warnings).toBeDefined()
expect(result.warnings).toHaveLength(1)
expect(result.warnings![0].type).toBe('unsupported-setting')
})
})
describe('Abort Signal', () => {
it('should support abort signal', async () => {
const abortController = new AbortController()
await executor.generateText({
model: 'gpt-4',
messages: testMessages.simple,
abortSignal: abortController.signal
})
expect(generateText).toHaveBeenCalledWith(
expect.objectContaining({
abortSignal: abortController.signal
})
)
})
it('should handle aborted request', async () => {
const abortError = new Error('Request aborted')
abortError.name = 'AbortError'
vi.mocked(generateText).mockRejectedValue(abortError)
const abortController = new AbortController()
abortController.abort()
await expect(
executor.generateText({
model: 'gpt-4',
messages: testMessages.simple,
abortSignal: abortController.signal
})
).rejects.toThrow('Request aborted')
})
})
})

View File

@@ -0,0 +1,525 @@
/**
* RuntimeExecutor.streamText Comprehensive Tests
* Tests streaming text generation across all providers with various parameters
*/
import { streamText } from 'ai'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import { collectStreamChunks, createMockLanguageModel, mockProviderConfigs, testMessages } from '../../../__tests__'
import type { AiPlugin } from '../../plugins'
import { globalRegistryManagement } from '../../providers/RegistryManagement'
import { RuntimeExecutor } from '../executor'
// Mock AI SDK
vi.mock('ai', () => ({
streamText: vi.fn()
}))
vi.mock('../../providers/RegistryManagement', () => ({
globalRegistryManagement: {
languageModel: vi.fn()
},
DEFAULT_SEPARATOR: '|'
}))
describe('RuntimeExecutor.streamText', () => {
let executor: RuntimeExecutor<'openai'>
let mockLanguageModel: any
beforeEach(() => {
vi.clearAllMocks()
executor = RuntimeExecutor.create('openai', mockProviderConfigs.openai)
mockLanguageModel = createMockLanguageModel({
provider: 'openai',
modelId: 'gpt-4'
})
vi.mocked(globalRegistryManagement.languageModel).mockReturnValue(mockLanguageModel)
})
describe('Basic Functionality', () => {
it('should stream text with minimal parameters', async () => {
const mockStream = {
textStream: (async function* () {
yield 'Hello'
yield ' '
yield 'World'
})(),
fullStream: (async function* () {
yield { type: 'text-delta', textDelta: 'Hello' }
yield { type: 'text-delta', textDelta: ' ' }
yield { type: 'text-delta', textDelta: 'World' }
})(),
usage: Promise.resolve({ promptTokens: 5, completionTokens: 3, totalTokens: 8 })
}
vi.mocked(streamText).mockResolvedValue(mockStream as any)
const result = await executor.streamText({
model: 'gpt-4',
messages: testMessages.simple
})
expect(streamText).toHaveBeenCalledWith({
model: mockLanguageModel,
messages: testMessages.simple
})
const chunks = await collectStreamChunks(result.textStream)
expect(chunks).toEqual(['Hello', ' ', 'World'])
})
it('should stream with system messages', async () => {
const mockStream = {
textStream: (async function* () {
yield 'Response'
})(),
fullStream: (async function* () {
yield { type: 'text-delta', textDelta: 'Response' }
})()
}
vi.mocked(streamText).mockResolvedValue(mockStream as any)
await executor.streamText({
model: 'gpt-4',
messages: testMessages.withSystem
})
expect(streamText).toHaveBeenCalledWith({
model: mockLanguageModel,
messages: testMessages.withSystem
})
})
it('should stream multi-turn conversations', async () => {
const mockStream = {
textStream: (async function* () {
yield 'Multi-turn response'
})(),
fullStream: (async function* () {
yield { type: 'text-delta', textDelta: 'Multi-turn response' }
})()
}
vi.mocked(streamText).mockResolvedValue(mockStream as any)
await executor.streamText({
model: 'gpt-4',
messages: testMessages.multiTurn
})
expect(streamText).toHaveBeenCalled()
expect(streamText).toHaveBeenCalledWith(
expect.objectContaining({
messages: testMessages.multiTurn
})
)
})
})
describe('Temperature Parameter', () => {
const temperatures = [0, 0.3, 0.5, 0.7, 0.9, 1.0, 1.5, 2.0]
it.each(temperatures)('should support temperature=%s', async (temperature) => {
const mockStream = {
textStream: (async function* () {
yield 'Response'
})(),
fullStream: (async function* () {
yield { type: 'text-delta', textDelta: 'Response' }
})()
}
vi.mocked(streamText).mockResolvedValue(mockStream as any)
await executor.streamText({
model: 'gpt-4',
messages: testMessages.simple,
temperature
})
expect(streamText).toHaveBeenCalledWith(
expect.objectContaining({
temperature
})
)
})
})
describe('Max Tokens Parameter', () => {
const maxTokensValues = [10, 50, 100, 500, 1000, 2000, 4000]
it.each(maxTokensValues)('should support maxTokens=%s', async (maxTokens) => {
const mockStream = {
textStream: (async function* () {
yield 'Response'
})(),
fullStream: (async function* () {
yield { type: 'text-delta', textDelta: 'Response' }
})()
}
vi.mocked(streamText).mockResolvedValue(mockStream as any)
await executor.streamText({
model: 'gpt-4',
messages: testMessages.simple,
maxOutputTokens: maxTokens
})
expect(streamText).toHaveBeenCalledWith(
expect.objectContaining({
maxTokens
})
)
})
})
describe('Top P Parameter', () => {
const topPValues = [0.1, 0.3, 0.5, 0.7, 0.9, 0.95, 1.0]
it.each(topPValues)('should support topP=%s', async (topP) => {
const mockStream = {
textStream: (async function* () {
yield 'Response'
})(),
fullStream: (async function* () {
yield { type: 'text-delta', textDelta: 'Response' }
})()
}
vi.mocked(streamText).mockResolvedValue(mockStream as any)
await executor.streamText({
model: 'gpt-4',
messages: testMessages.simple,
topP
})
expect(streamText).toHaveBeenCalledWith(
expect.objectContaining({
topP
})
)
})
})
describe('Frequency and Presence Penalty', () => {
it('should support frequency penalty', async () => {
const penalties = [-2.0, -1.0, 0, 0.5, 1.0, 1.5, 2.0]
for (const frequencyPenalty of penalties) {
vi.clearAllMocks()
const mockStream = {
textStream: (async function* () {
yield 'Response'
})(),
fullStream: (async function* () {
yield { type: 'text-delta', textDelta: 'Response' }
})()
}
vi.mocked(streamText).mockResolvedValue(mockStream as any)
await executor.streamText({
model: 'gpt-4',
messages: testMessages.simple,
frequencyPenalty
})
expect(streamText).toHaveBeenCalledWith(
expect.objectContaining({
frequencyPenalty
})
)
}
})
it('should support presence penalty', async () => {
const penalties = [-2.0, -1.0, 0, 0.5, 1.0, 1.5, 2.0]
for (const presencePenalty of penalties) {
vi.clearAllMocks()
const mockStream = {
textStream: (async function* () {
yield 'Response'
})(),
fullStream: (async function* () {
yield { type: 'text-delta', textDelta: 'Response' }
})()
}
vi.mocked(streamText).mockResolvedValue(mockStream as any)
await executor.streamText({
model: 'gpt-4',
messages: testMessages.simple,
presencePenalty
})
expect(streamText).toHaveBeenCalledWith(
expect.objectContaining({
presencePenalty
})
)
}
})
it('should support both penalties together', async () => {
const mockStream = {
textStream: (async function* () {
yield 'Response'
})(),
fullStream: (async function* () {
yield { type: 'text-delta', textDelta: 'Response' }
})()
}
vi.mocked(streamText).mockResolvedValue(mockStream as any)
await executor.streamText({
model: 'gpt-4',
messages: testMessages.simple,
frequencyPenalty: 0.5,
presencePenalty: 0.5
})
expect(streamText).toHaveBeenCalledWith(
expect.objectContaining({
frequencyPenalty: 0.5,
presencePenalty: 0.5
})
)
})
})
describe('Seed Parameter', () => {
it('should support seed for deterministic output', async () => {
const seeds = [0, 12345, 67890, 999999]
for (const seed of seeds) {
vi.clearAllMocks()
const mockStream = {
textStream: (async function* () {
yield 'Response'
})(),
fullStream: (async function* () {
yield { type: 'text-delta', textDelta: 'Response' }
})()
}
vi.mocked(streamText).mockResolvedValue(mockStream as any)
await executor.streamText({
model: 'gpt-4',
messages: testMessages.simple,
seed
})
expect(streamText).toHaveBeenCalledWith(
expect.objectContaining({
seed
})
)
}
})
})
describe('Abort Signal', () => {
it('should support abort signal', async () => {
const abortController = new AbortController()
const mockStream = {
textStream: (async function* () {
yield 'Response'
})(),
fullStream: (async function* () {
yield { type: 'text-delta', textDelta: 'Response' }
})()
}
vi.mocked(streamText).mockResolvedValue(mockStream as any)
await executor.streamText({
model: 'gpt-4',
messages: testMessages.simple,
abortSignal: abortController.signal
})
expect(streamText).toHaveBeenCalledWith(
expect.objectContaining({
abortSignal: abortController.signal
})
)
})
it('should handle abort during streaming', async () => {
const abortController = new AbortController()
const mockStream = {
textStream: (async function* () {
yield 'Start'
// Simulate abort
abortController.abort()
throw new Error('Aborted')
})(),
fullStream: (async function* () {
yield { type: 'text-delta', textDelta: 'Start' }
throw new Error('Aborted')
})()
}
vi.mocked(streamText).mockResolvedValue(mockStream as any)
const result = await executor.streamText({
model: 'gpt-4',
messages: testMessages.simple,
abortSignal: abortController.signal
})
await expect(async () => {
// oxlint-disable-next-line no-unused-vars
for await (const _chunk of result.textStream) {
// Stream should be interrupted
}
}).rejects.toThrow('Aborted')
})
})
describe('Plugin Integration', () => {
it('should execute plugins during streaming', async () => {
const pluginCalls: string[] = []
const testPlugin: AiPlugin = {
name: 'test-plugin',
onRequestStart: vi.fn(async () => {
pluginCalls.push('onRequestStart')
}),
transformParams: vi.fn(async (params) => {
pluginCalls.push('transformParams')
return { ...params, temperature: 0.5 }
}),
onRequestEnd: vi.fn(async () => {
pluginCalls.push('onRequestEnd')
})
}
const executorWithPlugin = RuntimeExecutor.create('openai', mockProviderConfigs.openai, [testPlugin])
const mockStream = {
textStream: (async function* () {
yield 'Response'
})(),
fullStream: (async function* () {
yield { type: 'text-delta', textDelta: 'Response' }
})()
}
vi.mocked(streamText).mockResolvedValue(mockStream as any)
const result = await executorWithPlugin.streamText({
model: 'gpt-4',
messages: testMessages.simple
})
// Consume stream
// oxlint-disable-next-line no-unused-vars
for await (const _chunk of result.textStream) {
// Stream chunks
}
expect(pluginCalls).toContain('onRequestStart')
expect(pluginCalls).toContain('transformParams')
// Verify transformed parameters were used
expect(streamText).toHaveBeenCalledWith(
expect.objectContaining({
temperature: 0.5
})
)
})
})
describe('Full Stream with Finish Reason', () => {
it('should provide finish reason in full stream', async () => {
const mockStream = {
textStream: (async function* () {
yield 'Response'
})(),
fullStream: (async function* () {
yield { type: 'text-delta', textDelta: 'Response' }
yield {
type: 'finish',
finishReason: 'stop',
usage: { promptTokens: 5, completionTokens: 3, totalTokens: 8 }
}
})()
}
vi.mocked(streamText).mockResolvedValue(mockStream as any)
const result = await executor.streamText({
model: 'gpt-4',
messages: testMessages.simple
})
const fullChunks = await collectStreamChunks(result.fullStream)
expect(fullChunks).toHaveLength(2)
expect(fullChunks[0]).toEqual({ type: 'text-delta', textDelta: 'Response' })
expect(fullChunks[1]).toEqual({
type: 'finish',
finishReason: 'stop',
usage: { promptTokens: 5, completionTokens: 3, totalTokens: 8 }
})
})
})
describe('Error Handling', () => {
it('should handle streaming errors', async () => {
const error = new Error('Streaming failed')
vi.mocked(streamText).mockRejectedValue(error)
await expect(
executor.streamText({
model: 'gpt-4',
messages: testMessages.simple
})
).rejects.toThrow('Streaming failed')
})
it('should execute onError plugin hook on failure', async () => {
const error = new Error('Stream error')
vi.mocked(streamText).mockRejectedValue(error)
const errorPlugin: AiPlugin = {
name: 'error-handler',
onError: vi.fn()
}
const executorWithPlugin = RuntimeExecutor.create('openai', mockProviderConfigs.openai, [errorPlugin])
await expect(
executorWithPlugin.streamText({
model: 'gpt-4',
messages: testMessages.simple
})
).rejects.toThrow('Stream error')
expect(errorPlugin.onError).toHaveBeenCalledWith(
error,
expect.objectContaining({
providerId: 'openai',
modelId: 'gpt-4'
})
)
})
})
})

View File

@@ -41,6 +41,7 @@ export enum IpcChannel {
App_SetFullScreen = 'app:set-full-screen',
App_IsFullScreen = 'app:is-full-screen',
App_GetSystemFonts = 'app:get-system-fonts',
APP_CrashRenderProcess = 'app:crash-render-process',
App_MacIsProcessTrusted = 'app:mac-is-process-trusted',
App_MacRequestProcessTrust = 'app:mac-request-process-trust',
@@ -234,6 +235,7 @@ export enum IpcChannel {
System_GetDeviceType = 'system:getDeviceType',
System_GetHostname = 'system:getHostname',
System_GetCpuName = 'system:getCpuName',
System_CheckGitBash = 'system:checkGitBash',
// DevTools
System_ToggleDevTools = 'system:toggleDevTools',

View File

@@ -197,12 +197,22 @@ export enum FeedUrl {
GITHUB_LATEST = 'https://github.com/CherryHQ/cherry-studio/releases/latest/download'
}
export enum UpdateConfigUrl {
GITHUB = 'https://raw.githubusercontent.com/CherryHQ/cherry-studio/refs/heads/x-files/app-upgrade-config/app-upgrade-config.json',
GITCODE = 'https://raw.gitcode.com/CherryHQ/cherry-studio/raw/x-files%2Fapp-upgrade-config/app-upgrade-config.json'
}
export enum UpgradeChannel {
LATEST = 'latest', // 最新稳定版本
RC = 'rc', // 公测版本
BETA = 'beta' // 预览版本
}
export enum UpdateMirror {
GITHUB = 'github',
GITCODE = 'gitcode'
}
export const defaultTimeout = 10 * 1000 * 60
export const occupiedDirs = ['logs', 'Network', 'Partitions/webview/Network']

View File

@@ -4,3 +4,34 @@ export const defaultAppHeaders = () => {
'X-Title': 'Cherry Studio'
}
}
// Following two function are not being used for now.
// I may use them in the future, so just keep them commented. - by eurfelux
/**
* Converts an `undefined` value to `null`, otherwise returns the value as-is.
* @param value - The value to check
* @returns `null` if the input is `undefined`; otherwise the input value
*/
// export function toNullIfUndefined<T>(value: T | undefined): T | null {
// if (value === undefined) {
// return null
// } else {
// return value
// }
// }
/**
* Converts a `null` value to `undefined`, otherwise returns the value as-is.
* @param value - The value to check
* @returns `undefined` if the input is `null`; otherwise the input value
*/
// export function toUndefinedIfNull<T>(value: T | null): T | undefined {
// if (value === null) {
// return undefined
// } else {
// return value
// }
// }

View File

@@ -0,0 +1,532 @@
import fs from 'fs/promises'
import path from 'path'
import semver from 'semver'
type UpgradeChannel = 'latest' | 'rc' | 'beta'
type UpdateMirror = 'github' | 'gitcode'
const CHANNELS: UpgradeChannel[] = ['latest', 'rc', 'beta']
const MIRRORS: UpdateMirror[] = ['github', 'gitcode']
const GITHUB_REPO = 'CherryHQ/cherry-studio'
const GITCODE_REPO = 'CherryHQ/cherry-studio'
const DEFAULT_FEED_TEMPLATES: Record<UpdateMirror, string> = {
github: `https://github.com/${GITHUB_REPO}/releases/download/{{tag}}`,
gitcode: `https://gitcode.com/${GITCODE_REPO}/releases/download/{{tag}}`
}
const GITCODE_LATEST_FALLBACK = 'https://releases.cherry-ai.com'
interface CliOptions {
tag?: string
configPath?: string
segmentsPath?: string
dryRun?: boolean
skipReleaseChecks?: boolean
isPrerelease?: boolean
}
interface ChannelTemplateConfig {
feedTemplates?: Partial<Record<UpdateMirror, string>>
}
interface SegmentMatchRule {
range?: string
exact?: string[]
excludeExact?: string[]
}
interface SegmentDefinition {
id: string
type: 'legacy' | 'breaking' | 'latest'
match: SegmentMatchRule
lockedVersion?: string
minCompatibleVersion: string
description: string
channelTemplates?: Partial<Record<UpgradeChannel, ChannelTemplateConfig>>
}
interface SegmentMetadataFile {
segments: SegmentDefinition[]
}
interface ChannelConfig {
version: string
feedUrls: Record<UpdateMirror, string>
}
interface VersionMetadata {
segmentId: string
segmentType?: string
}
interface VersionEntry {
metadata?: VersionMetadata
minCompatibleVersion: string
description: string
channels: Record<UpgradeChannel, ChannelConfig | null>
}
interface UpgradeConfigFile {
lastUpdated: string
versions: Record<string, VersionEntry>
}
interface ReleaseInfo {
tag: string
version: string
channel: UpgradeChannel
}
interface UpdateVersionsResult {
versions: Record<string, VersionEntry>
updated: boolean
}
const ROOT_DIR = path.resolve(__dirname, '..')
const DEFAULT_CONFIG_PATH = path.join(ROOT_DIR, 'app-upgrade-config.json')
const DEFAULT_SEGMENTS_PATH = path.join(ROOT_DIR, 'config/app-upgrade-segments.json')
async function main() {
const options = parseArgs()
const releaseTag = resolveTag(options)
const normalizedVersion = normalizeVersion(releaseTag)
const releaseChannel = detectChannel(normalizedVersion)
if (!releaseChannel) {
console.warn(`[update-app-upgrade-config] Tag ${normalizedVersion} does not map to beta/rc/latest. Skipping.`)
return
}
// Validate version format matches prerelease status
if (options.isPrerelease !== undefined) {
const hasPrereleaseSuffix = releaseChannel === 'beta' || releaseChannel === 'rc'
if (options.isPrerelease && !hasPrereleaseSuffix) {
console.warn(
`[update-app-upgrade-config] ⚠️ Release marked as prerelease but version ${normalizedVersion} has no beta/rc suffix. Skipping.`
)
return
}
if (!options.isPrerelease && hasPrereleaseSuffix) {
console.warn(
`[update-app-upgrade-config] ⚠️ Release marked as latest but version ${normalizedVersion} has prerelease suffix (${releaseChannel}). Skipping.`
)
return
}
}
const [config, segmentFile] = await Promise.all([
readJson<UpgradeConfigFile>(options.configPath ?? DEFAULT_CONFIG_PATH),
readJson<SegmentMetadataFile>(options.segmentsPath ?? DEFAULT_SEGMENTS_PATH)
])
const segment = pickSegment(segmentFile.segments, normalizedVersion)
if (!segment) {
throw new Error(`Unable to find upgrade segment for version ${normalizedVersion}`)
}
if (segment.lockedVersion && segment.lockedVersion !== normalizedVersion) {
throw new Error(`Segment ${segment.id} is locked to ${segment.lockedVersion}, but received ${normalizedVersion}`)
}
const releaseInfo: ReleaseInfo = {
tag: formatTag(releaseTag),
version: normalizedVersion,
channel: releaseChannel
}
const { versions: updatedVersions, updated } = await updateVersions(
config.versions,
segment,
releaseInfo,
Boolean(options.skipReleaseChecks)
)
if (!updated) {
throw new Error(
`[update-app-upgrade-config] Feed URLs are not ready for ${releaseInfo.version} (${releaseInfo.channel}). Try again after the release mirrors finish syncing.`
)
}
const updatedConfig: UpgradeConfigFile = {
...config,
lastUpdated: new Date().toISOString(),
versions: updatedVersions
}
const output = JSON.stringify(updatedConfig, null, 2) + '\n'
if (options.dryRun) {
console.log('Dry run enabled. Generated configuration:\n')
console.log(output)
return
}
await fs.writeFile(options.configPath ?? DEFAULT_CONFIG_PATH, output, 'utf-8')
console.log(
`✅ Updated ${path.relative(process.cwd(), options.configPath ?? DEFAULT_CONFIG_PATH)} for ${segment.id} (${releaseInfo.channel}) -> ${releaseInfo.version}`
)
}
function parseArgs(): CliOptions {
const args = process.argv.slice(2)
const options: CliOptions = {}
for (let i = 0; i < args.length; i += 1) {
const arg = args[i]
if (arg === '--tag') {
options.tag = args[i + 1]
i += 1
} else if (arg === '--config') {
options.configPath = args[i + 1]
i += 1
} else if (arg === '--segments') {
options.segmentsPath = args[i + 1]
i += 1
} else if (arg === '--dry-run') {
options.dryRun = true
} else if (arg === '--skip-release-checks') {
options.skipReleaseChecks = true
} else if (arg === '--is-prerelease') {
options.isPrerelease = args[i + 1] === 'true'
i += 1
} else if (arg === '--help') {
printHelp()
process.exit(0)
} else {
console.warn(`Ignoring unknown argument "${arg}"`)
}
}
if (options.skipReleaseChecks && !options.dryRun) {
throw new Error('--skip-release-checks can only be used together with --dry-run')
}
return options
}
function printHelp() {
console.log(`Usage: tsx scripts/update-app-upgrade-config.ts [options]
Options:
--tag <tag> Release tag (e.g. v2.1.6). Falls back to GITHUB_REF_NAME/RELEASE_TAG.
--config <path> Path to app-upgrade-config.json.
--segments <path> Path to app-upgrade-segments.json.
--is-prerelease <true|false> Whether this is a prerelease (validates version format).
--dry-run Print the result without writing to disk.
--skip-release-checks Skip release page availability checks (only valid with --dry-run).
--help Show this help message.`)
}
function resolveTag(options: CliOptions): string {
const envTag = process.env.RELEASE_TAG ?? process.env.GITHUB_REF_NAME ?? process.env.TAG_NAME
const tag = options.tag ?? envTag
if (!tag) {
throw new Error('A release tag is required. Pass --tag or set RELEASE_TAG/GITHUB_REF_NAME.')
}
return tag
}
function normalizeVersion(tag: string): string {
const cleaned = semver.clean(tag, { loose: true })
if (!cleaned) {
throw new Error(`Tag "${tag}" is not a valid semantic version`)
}
const valid = semver.valid(cleaned, { loose: true })
if (!valid) {
throw new Error(`Unable to normalize tag "${tag}" to a valid semantic version`)
}
return valid
}
function detectChannel(version: string): UpgradeChannel | null {
const parsed = semver.parse(version, { loose: true, includePrerelease: true })
if (!parsed) {
return null
}
if (parsed.prerelease.length === 0) {
return 'latest'
}
const label = String(parsed.prerelease[0]).toLowerCase()
if (label === 'beta') {
return 'beta'
}
if (label === 'rc') {
return 'rc'
}
return null
}
async function readJson<T>(filePath: string): Promise<T> {
const absolute = path.isAbsolute(filePath) ? filePath : path.resolve(filePath)
const data = await fs.readFile(absolute, 'utf-8')
return JSON.parse(data) as T
}
function pickSegment(segments: SegmentDefinition[], version: string): SegmentDefinition | null {
for (const segment of segments) {
if (matchesSegment(segment.match, version)) {
return segment
}
}
return null
}
function matchesSegment(matchRule: SegmentMatchRule, version: string): boolean {
if (matchRule.exact && matchRule.exact.includes(version)) {
return true
}
if (matchRule.excludeExact && matchRule.excludeExact.includes(version)) {
return false
}
if (matchRule.range && !semver.satisfies(version, matchRule.range, { includePrerelease: true })) {
return false
}
if (matchRule.exact) {
return matchRule.exact.includes(version)
}
return Boolean(matchRule.range)
}
function formatTag(tag: string): string {
if (tag.startsWith('refs/tags/')) {
return tag.replace('refs/tags/', '')
}
return tag
}
async function updateVersions(
versions: Record<string, VersionEntry>,
segment: SegmentDefinition,
releaseInfo: ReleaseInfo,
skipReleaseValidation: boolean
): Promise<UpdateVersionsResult> {
const versionsCopy: Record<string, VersionEntry> = { ...versions }
const existingKey = findVersionKeyBySegment(versionsCopy, segment.id)
const targetKey = resolveVersionKey(existingKey, segment, releaseInfo)
const shouldRename = existingKey && existingKey !== targetKey
let entry: VersionEntry
if (existingKey) {
entry = { ...versionsCopy[existingKey], channels: { ...versionsCopy[existingKey].channels } }
} else {
entry = createEmptyVersionEntry()
}
entry.channels = ensureChannelSlots(entry.channels)
const channelUpdated = await applyChannelUpdate(entry, segment, releaseInfo, skipReleaseValidation)
if (!channelUpdated) {
return { versions, updated: false }
}
if (shouldRename && existingKey) {
delete versionsCopy[existingKey]
}
entry.metadata = {
segmentId: segment.id,
segmentType: segment.type
}
entry.minCompatibleVersion = segment.minCompatibleVersion
entry.description = segment.description
versionsCopy[targetKey] = entry
return {
versions: sortVersionMap(versionsCopy),
updated: true
}
}
function findVersionKeyBySegment(versions: Record<string, VersionEntry>, segmentId: string): string | null {
for (const [key, value] of Object.entries(versions)) {
if (value.metadata?.segmentId === segmentId) {
return key
}
}
return null
}
function resolveVersionKey(existingKey: string | null, segment: SegmentDefinition, releaseInfo: ReleaseInfo): string {
if (segment.lockedVersion) {
return segment.lockedVersion
}
if (releaseInfo.channel === 'latest') {
return releaseInfo.version
}
if (existingKey) {
return existingKey
}
const baseVersion = getBaseVersion(releaseInfo.version)
return baseVersion ?? releaseInfo.version
}
function getBaseVersion(version: string): string | null {
const parsed = semver.parse(version, { loose: true, includePrerelease: true })
if (!parsed) {
return null
}
return `${parsed.major}.${parsed.minor}.${parsed.patch}`
}
function createEmptyVersionEntry(): VersionEntry {
return {
minCompatibleVersion: '',
description: '',
channels: {
latest: null,
rc: null,
beta: null
}
}
}
function ensureChannelSlots(
channels: Record<UpgradeChannel, ChannelConfig | null>
): Record<UpgradeChannel, ChannelConfig | null> {
return CHANNELS.reduce(
(acc, channel) => {
acc[channel] = channels[channel] ?? null
return acc
},
{} as Record<UpgradeChannel, ChannelConfig | null>
)
}
async function applyChannelUpdate(
entry: VersionEntry,
segment: SegmentDefinition,
releaseInfo: ReleaseInfo,
skipReleaseValidation: boolean
): Promise<boolean> {
if (!CHANNELS.includes(releaseInfo.channel)) {
throw new Error(`Unsupported channel "${releaseInfo.channel}"`)
}
const feedUrls = buildFeedUrls(segment, releaseInfo)
if (skipReleaseValidation) {
console.warn(
`[update-app-upgrade-config] Skipping release availability validation for ${releaseInfo.version} (${releaseInfo.channel}).`
)
} else {
const availability = await ensureReleaseAvailability(releaseInfo)
if (!availability.github) {
return false
}
if (releaseInfo.channel === 'latest' && !availability.gitcode) {
console.warn(
`[update-app-upgrade-config] gitcode release page not ready for ${releaseInfo.tag}. Falling back to ${GITCODE_LATEST_FALLBACK}.`
)
feedUrls.gitcode = GITCODE_LATEST_FALLBACK
}
}
entry.channels[releaseInfo.channel] = {
version: releaseInfo.version,
feedUrls
}
return true
}
function buildFeedUrls(segment: SegmentDefinition, releaseInfo: ReleaseInfo): Record<UpdateMirror, string> {
return MIRRORS.reduce(
(acc, mirror) => {
const template = resolveFeedTemplate(segment, releaseInfo, mirror)
acc[mirror] = applyTemplate(template, releaseInfo)
return acc
},
{} as Record<UpdateMirror, string>
)
}
function resolveFeedTemplate(segment: SegmentDefinition, releaseInfo: ReleaseInfo, mirror: UpdateMirror): string {
if (mirror === 'gitcode' && releaseInfo.channel !== 'latest') {
return segment.channelTemplates?.[releaseInfo.channel]?.feedTemplates?.github ?? DEFAULT_FEED_TEMPLATES.github
}
return segment.channelTemplates?.[releaseInfo.channel]?.feedTemplates?.[mirror] ?? DEFAULT_FEED_TEMPLATES[mirror]
}
function applyTemplate(template: string, releaseInfo: ReleaseInfo): string {
return template.replace(/{{\s*tag\s*}}/gi, releaseInfo.tag).replace(/{{\s*version\s*}}/gi, releaseInfo.version)
}
function sortVersionMap(versions: Record<string, VersionEntry>): Record<string, VersionEntry> {
const sorted = Object.entries(versions).sort(([a], [b]) => semver.rcompare(a, b))
return sorted.reduce(
(acc, [version, entry]) => {
acc[version] = entry
return acc
},
{} as Record<string, VersionEntry>
)
}
interface ReleaseAvailability {
github: boolean
gitcode: boolean
}
async function ensureReleaseAvailability(releaseInfo: ReleaseInfo): Promise<ReleaseAvailability> {
const mirrorsToCheck: UpdateMirror[] = releaseInfo.channel === 'latest' ? MIRRORS : ['github']
const availability: ReleaseAvailability = {
github: false,
gitcode: releaseInfo.channel === 'latest' ? false : true
}
for (const mirror of mirrorsToCheck) {
const url = getReleasePageUrl(mirror, releaseInfo.tag)
try {
const response = await fetch(url, {
method: mirror === 'github' ? 'HEAD' : 'GET',
redirect: 'follow'
})
if (response.ok) {
availability[mirror] = true
} else {
console.warn(
`[update-app-upgrade-config] ${mirror} release not available for ${releaseInfo.tag} (status ${response.status}, ${url}).`
)
availability[mirror] = false
}
} catch (error) {
console.warn(
`[update-app-upgrade-config] Failed to verify ${mirror} release page for ${releaseInfo.tag} (${url}). Continuing.`,
error
)
availability[mirror] = false
}
}
return availability
}
function getReleasePageUrl(mirror: UpdateMirror, tag: string): string {
if (mirror === 'github') {
return `https://github.com/${GITHUB_REPO}/releases/tag/${encodeURIComponent(tag)}`
}
// Use latest.yml download URL for GitCode to check if release exists
// Note: GitCode returns 401 for HEAD requests, so we use GET in ensureReleaseAvailability
return `https://gitcode.com/${GITCODE_REPO}/releases/download/${encodeURIComponent(tag)}/latest.yml`
}
main().catch((error) => {
console.error('❌ Failed to update app-upgrade-config:', error)
process.exit(1)
})

View File

@@ -104,12 +104,6 @@ const router = express
logger.warn('No models available from providers', { filter })
}
logger.info('Models response ready', {
filter,
total: response.total,
modelIds: response.data.map((m) => m.id)
})
return res.json(response satisfies ApiModelsResponse)
} catch (error: any) {
logger.error('Error fetching models', { error })

View File

@@ -3,7 +3,6 @@ import { createServer } from 'node:http'
import { loggerService } from '@logger'
import { IpcChannel } from '@shared/IpcChannel'
import { agentService } from '../services/agents'
import { windowService } from '../services/WindowService'
import { app } from './app'
import { config } from './config'
@@ -32,11 +31,6 @@ export class ApiServer {
// Load config
const { port, host } = await config.load()
// Initialize AgentService
logger.info('Initializing AgentService')
await agentService.initialize()
logger.info('AgentService initialized')
// Create server with Express app
this.server = createServer(app)
this.applyServerTimeouts(this.server)

View File

@@ -32,7 +32,7 @@ export class ModelsService {
for (const model of models) {
const provider = providers.find((p) => p.id === model.provider)
logger.debug(`Processing model ${model.id}`)
// logger.debug(`Processing model ${model.id}`)
if (!provider) {
logger.debug(`Skipping model ${model.id} . Reason: Provider not found.`)
continue

View File

@@ -8,7 +8,7 @@ import '@main/config'
import { loggerService } from '@logger'
import { electronApp, optimizer } from '@electron-toolkit/utils'
import { replaceDevtoolsFont } from '@main/utils/windowUtil'
import { app } from 'electron'
import { app, crashReporter } from 'electron'
import installExtension, { REACT_DEVELOPER_TOOLS, REDUX_DEVTOOLS } from 'electron-devtools-installer'
import { isDev, isLinux, isWin } from './constant'
@@ -34,9 +34,18 @@ import { TrayService } from './services/TrayService'
import { versionService } from './services/VersionService'
import { windowService } from './services/WindowService'
import { initWebviewHotkeys } from './services/WebviewService'
import { runAsyncFunction } from './utils'
const logger = loggerService.withContext('MainEntry')
// enable local crash reports
crashReporter.start({
companyName: 'CherryHQ',
productName: 'CherryStudio',
submitURL: '',
uploadToServer: false
})
/**
* Disable hardware acceleration if setting is enabled
*/
@@ -162,39 +171,33 @@ if (!app.requestSingleInstanceLock()) {
//start selection assistant service
initSelectionService()
// Initialize Agent Service
try {
await agentService.initialize()
logger.info('Agent service initialized successfully')
} catch (error: any) {
logger.error('Failed to initialize Agent service:', error)
}
runAsyncFunction(async () => {
// Start API server if enabled or if agents exist
try {
const config = await apiServerService.getCurrentConfig()
logger.info('API server config:', config)
// Start API server if enabled or if agents exist
try {
const config = await apiServerService.getCurrentConfig()
logger.info('API server config:', config)
// Check if there are any agents
let shouldStart = config.enabled
if (!shouldStart) {
try {
const { total } = await agentService.listAgents({ limit: 1 })
if (total > 0) {
shouldStart = true
logger.info(`Detected ${total} agent(s), auto-starting API server`)
// Check if there are any agents
let shouldStart = config.enabled
if (!shouldStart) {
try {
const { total } = await agentService.listAgents({ limit: 1 })
if (total > 0) {
shouldStart = true
logger.info(`Detected ${total} agent(s), auto-starting API server`)
}
} catch (error: any) {
logger.warn('Failed to check agent count:', error)
}
} catch (error: any) {
logger.warn('Failed to check agent count:', error)
}
}
if (shouldStart) {
await apiServerService.start()
if (shouldStart) {
await apiServerService.start()
}
} catch (error: any) {
logger.error('Failed to check/start API server:', error)
}
} catch (error: any) {
logger.error('Failed to check/start API server:', error)
}
})
})
registerProtocolClient(app)

View File

@@ -493,6 +493,44 @@ export function registerIpc(mainWindow: BrowserWindow, app: Electron.App) {
ipcMain.handle(IpcChannel.System_GetDeviceType, () => (isMac ? 'mac' : isWin ? 'windows' : 'linux'))
ipcMain.handle(IpcChannel.System_GetHostname, () => require('os').hostname())
ipcMain.handle(IpcChannel.System_GetCpuName, () => require('os').cpus()[0].model)
ipcMain.handle(IpcChannel.System_CheckGitBash, () => {
if (!isWin) {
return true // Non-Windows systems don't need Git Bash
}
try {
// Check common Git Bash installation paths
const commonPaths = [
path.join(process.env.ProgramFiles || 'C:\\Program Files', 'Git', 'bin', 'bash.exe'),
path.join(process.env['ProgramFiles(x86)'] || 'C:\\Program Files (x86)', 'Git', 'bin', 'bash.exe'),
path.join(process.env.LOCALAPPDATA || '', 'Programs', 'Git', 'bin', 'bash.exe')
]
// Check if any of the common paths exist
for (const bashPath of commonPaths) {
if (fs.existsSync(bashPath)) {
logger.debug('Git Bash found', { path: bashPath })
return true
}
}
// Check if git is in PATH
const { execSync } = require('child_process')
try {
execSync('git --version', { stdio: 'ignore' })
logger.debug('Git found in PATH')
return true
} catch {
// Git not in PATH
}
logger.debug('Git Bash not found on Windows system')
return false
} catch (error) {
logger.error('Error checking Git Bash', error as Error)
return false
}
})
ipcMain.handle(IpcChannel.System_ToggleDevTools, (e) => {
const win = BrowserWindow.fromWebContents(e.sender)
win && win.webContents.toggleDevTools()
@@ -1038,4 +1076,8 @@ export function registerIpc(mainWindow: BrowserWindow, app: Electron.App) {
ipcMain.handle(IpcChannel.WebSocket_Status, WebSocketService.getStatus)
ipcMain.handle(IpcChannel.WebSocket_SendFile, WebSocketService.sendFile)
ipcMain.handle(IpcChannel.WebSocket_GetAllCandidates, WebSocketService.getAllCandidates)
ipcMain.handle(IpcChannel.APP_CrashRenderProcess, () => {
mainWindow.webContents.forcefullyCrashRenderer()
})
}

View File

@@ -21,6 +21,7 @@ type ApiResponse<T> = {
type BatchUploadResponse = {
batch_id: string
file_urls: string[]
headers?: Record<string, string>[]
}
type ExtractProgress = {
@@ -55,7 +56,7 @@ type QuotaResponse = {
export default class MineruPreprocessProvider extends BasePreprocessProvider {
constructor(provider: PreprocessProvider, userId?: string) {
super(provider, userId)
// todo免费期结束后删除
// TODO: remove after free period ends
this.provider.apiKey = this.provider.apiKey || import.meta.env.MAIN_VITE_MINERU_API_KEY
}
@@ -68,21 +69,21 @@ export default class MineruPreprocessProvider extends BasePreprocessProvider {
logger.info(`MinerU preprocess processing started: ${filePath}`)
await this.validateFile(filePath)
// 1. 获取上传URL并上传文件
// 1. Get upload URL and upload file
const batchId = await this.uploadFile(file)
logger.info(`MinerU file upload completed: batch_id=${batchId}`)
// 2. 等待处理完成并获取结果
// 2. Wait for completion and fetch results
const extractResult = await this.waitForCompletion(sourceId, batchId, file.origin_name)
logger.info(`MinerU processing completed for batch: ${batchId}`)
// 3. 下载并解压文件
// 3. Download and extract output
const { path: outputPath } = await this.downloadAndExtractFile(extractResult.full_zip_url!, file)
// 4. check quota
const quota = await this.checkQuota()
// 5. 创建处理后的文件信息
// 5. Create processed file metadata
return {
processedFile: this.createProcessedFileInfo(file, outputPath),
quota
@@ -115,23 +116,48 @@ export default class MineruPreprocessProvider extends BasePreprocessProvider {
}
private async validateFile(filePath: string): Promise<void> {
// Phase 1: check file size (without loading into memory)
logger.info(`Validating PDF file: ${filePath}`)
const stats = await fs.promises.stat(filePath)
const fileSizeBytes = stats.size
// Ensure file size is under 200MB
if (fileSizeBytes >= 200 * 1024 * 1024) {
const fileSizeMB = Math.round(fileSizeBytes / (1024 * 1024))
throw new Error(`PDF file size (${fileSizeMB}MB) exceeds the limit of 200MB`)
}
// Phase 2: check page count (requires reading file with error handling)
const pdfBuffer = await fs.promises.readFile(filePath)
const doc = await this.readPdf(pdfBuffer)
try {
const doc = await this.readPdf(pdfBuffer)
// 文件页数小于600页
if (doc.numPages >= 600) {
throw new Error(`PDF page count (${doc.numPages}) exceeds the limit of 600 pages`)
}
// 文件大小小于200MB
if (pdfBuffer.length >= 200 * 1024 * 1024) {
const fileSizeMB = Math.round(pdfBuffer.length / (1024 * 1024))
throw new Error(`PDF file size (${fileSizeMB}MB) exceeds the limit of 200MB`)
// Ensure page count is under 600 pages
if (doc.numPages >= 600) {
throw new Error(`PDF page count (${doc.numPages}) exceeds the limit of 600 pages`)
}
logger.info(`PDF validation passed: ${doc.numPages} pages, ${Math.round(fileSizeBytes / (1024 * 1024))}MB`)
} catch (error: any) {
// If the page limit is exceeded, rethrow immediately
if (error.message.includes('exceeds the limit')) {
throw error
}
// If PDF parsing fails, log a detailed warning but continue processing
logger.warn(
`Failed to parse PDF structure (file may be corrupted or use non-standard format). ` +
`Skipping page count validation. Will attempt to process with MinerU API. ` +
`Error details: ${error.message}. ` +
`Suggestion: If processing fails, try repairing the PDF using tools like Adobe Acrobat or online PDF repair services.`
)
// Do not throw; continue processing
}
}
private createProcessedFileInfo(file: FileMetadata, outputPath: string): FileMetadata {
// 查找解压后的主要文件
// Locate the main extracted file
let finalPath = ''
let finalName = file.origin_name.replace('.pdf', '.md')
@@ -143,14 +169,14 @@ export default class MineruPreprocessProvider extends BasePreprocessProvider {
const originalMdPath = path.join(outputPath, mdFile)
const newMdPath = path.join(outputPath, finalName)
// 重命名文件为原始文件名
// Rename the file to match the original name
try {
fs.renameSync(originalMdPath, newMdPath)
finalPath = newMdPath
logger.info(`Renamed markdown file from ${mdFile} to ${finalName}`)
} catch (renameError) {
logger.warn(`Failed to rename file ${mdFile} to ${finalName}: ${renameError}`)
// 如果重命名失败,使用原文件
// If renaming fails, fall back to the original file
finalPath = originalMdPath
finalName = mdFile
}
@@ -178,7 +204,7 @@ export default class MineruPreprocessProvider extends BasePreprocessProvider {
logger.info(`Downloading MinerU result to: ${zipPath}`)
try {
// 下载ZIP文件
// Download the ZIP file
const response = await net.fetch(zipUrl, { method: 'GET' })
if (!response.ok) {
throw new Error(`HTTP ${response.status}: ${response.statusText}`)
@@ -187,17 +213,17 @@ export default class MineruPreprocessProvider extends BasePreprocessProvider {
fs.writeFileSync(zipPath, Buffer.from(arrayBuffer))
logger.info(`Downloaded ZIP file: ${zipPath}`)
// 确保提取目录存在
// Ensure the extraction directory exists
if (!fs.existsSync(extractPath)) {
fs.mkdirSync(extractPath, { recursive: true })
}
// 解压文件
// Extract the ZIP contents
const zip = new AdmZip(zipPath)
zip.extractAllTo(extractPath, true)
logger.info(`Extracted files to: ${extractPath}`)
// 删除临时ZIP文件
// Remove the temporary ZIP file
fs.unlinkSync(zipPath)
return { path: extractPath }
@@ -209,11 +235,11 @@ export default class MineruPreprocessProvider extends BasePreprocessProvider {
private async uploadFile(file: FileMetadata): Promise<string> {
try {
// 步骤1: 获取上传URL
const { batchId, fileUrls } = await this.getBatchUploadUrls(file)
// 步骤2: 上传文件到获取的URL
// Step 1: obtain the upload URL
const { batchId, fileUrls, uploadHeaders } = await this.getBatchUploadUrls(file)
// Step 2: upload the file to the obtained URL
const filePath = fileStorage.getFilePathById(file)
await this.putFileToUrl(filePath, fileUrls[0])
await this.putFileToUrl(filePath, fileUrls[0], file.origin_name, uploadHeaders?.[0])
logger.info(`File uploaded successfully: ${filePath}`, { batchId, fileUrls })
return batchId
@@ -223,7 +249,9 @@ export default class MineruPreprocessProvider extends BasePreprocessProvider {
}
}
private async getBatchUploadUrls(file: FileMetadata): Promise<{ batchId: string; fileUrls: string[] }> {
private async getBatchUploadUrls(
file: FileMetadata
): Promise<{ batchId: string; fileUrls: string[]; uploadHeaders?: Record<string, string>[] }> {
const endpoint = `${this.provider.apiHost}/api/v4/file-urls/batch`
const payload = {
@@ -254,10 +282,11 @@ export default class MineruPreprocessProvider extends BasePreprocessProvider {
if (response.ok) {
const data: ApiResponse<BatchUploadResponse> = await response.json()
if (data.code === 0 && data.data) {
const { batch_id, file_urls } = data.data
const { batch_id, file_urls, headers: uploadHeaders } = data.data
return {
batchId: batch_id,
fileUrls: file_urls
fileUrls: file_urls,
uploadHeaders
}
} else {
throw new Error(`API returned error: ${data.msg || JSON.stringify(data)}`)
@@ -271,18 +300,28 @@ export default class MineruPreprocessProvider extends BasePreprocessProvider {
}
}
private async putFileToUrl(filePath: string, uploadUrl: string): Promise<void> {
private async putFileToUrl(
filePath: string,
uploadUrl: string,
fileName?: string,
headers?: Record<string, string>
): Promise<void> {
try {
const fileBuffer = await fs.promises.readFile(filePath)
const fileSize = fileBuffer.byteLength
const displayName = fileName ?? path.basename(filePath)
logger.info(`Uploading file to MinerU OSS: ${displayName} (${fileSize} bytes)`)
// https://mineru.net/apiManage/docs
const response = await net.fetch(uploadUrl, {
method: 'PUT',
body: fileBuffer
headers,
body: new Uint8Array(fileBuffer)
})
if (!response.ok) {
// 克隆 response 以避免消费 body stream
// Clone the response to avoid consuming the body stream
const responseClone = response.clone()
try {
@@ -353,20 +392,20 @@ export default class MineruPreprocessProvider extends BasePreprocessProvider {
try {
const result = await this.getExtractResults(batchId)
// 查找对应文件的处理结果
// Find the corresponding file result
const fileResult = result.extract_result.find((item) => item.file_name === fileName)
if (!fileResult) {
throw new Error(`File ${fileName} not found in batch results`)
}
// 检查处理状态
// Check the processing state
if (fileResult.state === 'done' && fileResult.full_zip_url) {
logger.info(`Processing completed for file: ${fileName}`)
return fileResult
} else if (fileResult.state === 'failed') {
throw new Error(`Processing failed for file: ${fileName}, error: ${fileResult.err_msg}`)
} else if (fileResult.state === 'running') {
// 发送进度更新
// Send progress updates
if (fileResult.extract_progress) {
const progress = Math.round(
(fileResult.extract_progress.extracted_pages / fileResult.extract_progress.total_pages) * 100
@@ -374,7 +413,7 @@ export default class MineruPreprocessProvider extends BasePreprocessProvider {
await this.sendPreprocessProgress(sourceId, progress)
logger.info(`File ${fileName} processing progress: ${progress}%`)
} else {
// 如果没有具体进度信息,发送一个通用进度
// If no detailed progress information is available, send a generic update
await this.sendPreprocessProgress(sourceId, 50)
logger.info(`File ${fileName} is still processing...`)
}

View File

@@ -53,18 +53,43 @@ export default class OpenMineruPreprocessProvider extends BasePreprocessProvider
}
private async validateFile(filePath: string): Promise<void> {
// 第一阶段:检查文件大小(无需读取文件到内存)
logger.info(`Validating PDF file: ${filePath}`)
const stats = await fs.promises.stat(filePath)
const fileSizeBytes = stats.size
// File size must be less than 200MB
if (fileSizeBytes >= 200 * 1024 * 1024) {
const fileSizeMB = Math.round(fileSizeBytes / (1024 * 1024))
throw new Error(`PDF file size (${fileSizeMB}MB) exceeds the limit of 200MB`)
}
// 第二阶段:检查页数(需要读取文件,带错误处理)
const pdfBuffer = await fs.promises.readFile(filePath)
const doc = await this.readPdf(pdfBuffer)
try {
const doc = await this.readPdf(pdfBuffer)
// File page count must be less than 600 pages
if (doc.numPages >= 600) {
throw new Error(`PDF page count (${doc.numPages}) exceeds the limit of 600 pages`)
}
// File size must be less than 200MB
if (pdfBuffer.length >= 200 * 1024 * 1024) {
const fileSizeMB = Math.round(pdfBuffer.length / (1024 * 1024))
throw new Error(`PDF file size (${fileSizeMB}MB) exceeds the limit of 200MB`)
// File page count must be less than 600 pages
if (doc.numPages >= 600) {
throw new Error(`PDF page count (${doc.numPages}) exceeds the limit of 600 pages`)
}
logger.info(`PDF validation passed: ${doc.numPages} pages, ${Math.round(fileSizeBytes / (1024 * 1024))}MB`)
} catch (error: any) {
// 如果是页数超限错误,直接抛出
if (error.message.includes('exceeds the limit')) {
throw error
}
// PDF 解析失败,记录详细警告但允许继续处理
logger.warn(
`Failed to parse PDF structure (file may be corrupted or use non-standard format). ` +
`Skipping page count validation. Will attempt to process with MinerU API. ` +
`Error details: ${error.message}. ` +
`Suggestion: If processing fails, try repairing the PDF using tools like Adobe Acrobat or online PDF repair services.`
)
// 不抛出错误,允许继续处理
}
}
@@ -72,8 +97,8 @@ export default class OpenMineruPreprocessProvider extends BasePreprocessProvider
// Find the main file after extraction
let finalPath = ''
let finalName = file.origin_name.replace('.pdf', '.md')
// Find the corresponding folder by file name
outputPath = path.join(outputPath, `${file.origin_name.replace('.pdf', '')}`)
// Find the corresponding folder by file id
outputPath = path.join(outputPath, file.id)
try {
const files = fs.readdirSync(outputPath)
@@ -125,7 +150,7 @@ export default class OpenMineruPreprocessProvider extends BasePreprocessProvider
formData.append('return_md', 'true')
formData.append('response_format_zip', 'true')
formData.append('files', fileBuffer, {
filename: file.origin_name
filename: file.name
})
while (retries < maxRetries) {
@@ -139,7 +164,7 @@ export default class OpenMineruPreprocessProvider extends BasePreprocessProvider
...(this.provider.apiKey ? { Authorization: `Bearer ${this.provider.apiKey}` } : {}),
...formData.getHeaders()
},
body: formData.getBuffer()
body: new Uint8Array(formData.getBuffer())
})
if (!response.ok) {

View File

@@ -8,6 +8,7 @@ import DiDiMcpServer from './didi-mcp'
import DifyKnowledgeServer from './dify-knowledge'
import FetchServer from './fetch'
import FileSystemServer from './filesystem'
import MCPUIDemoServer from './mcp-ui-demo'
import MemoryServer from './memory'
import PythonServer from './python'
import ThinkingServer from './sequentialthinking'
@@ -48,6 +49,9 @@ export function createInMemoryMCPServer(
const apiKey = envs.DIDI_API_KEY
return new DiDiMcpServer(apiKey).server
}
case BuiltinMCPServerNames.mcpUIDemo: {
return new MCPUIDemoServer().server
}
default:
throw new Error(`Unknown in-memory MCP server: ${name}`)
}

View File

@@ -0,0 +1,433 @@
import { Server } from '@modelcontextprotocol/sdk/server/index.js'
import { CallToolRequestSchema, ListToolsRequestSchema } from '@modelcontextprotocol/sdk/types.js'
const server = new Server(
{
name: 'mcp-ui-demo',
version: '1.0.0'
},
{
capabilities: {
tools: {}
}
}
)
// HTML templates for different UIs
const getHelloWorldUI = () =>
`
<!DOCTYPE html>
<html>
<head>
<style>
body {
font-family: system-ui, -apple-system, sans-serif;
padding: 20px;
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
color: white;
min-height: 200px;
display: flex;
align-items: center;
justify-content: center;
margin: 0;
}
.container {
text-align: center;
}
h1 {
font-size: 2.5em;
margin-bottom: 10px;
text-shadow: 2px 2px 4px rgba(0,0,0,0.3);
}
p {
font-size: 1.2em;
opacity: 0.9;
}
</style>
</head>
<body>
<div class="container">
<h1>🎉 Hello from MCP UI!</h1>
<p>This is a simple MCP UI Resource rendered in Cherry Studio</p>
</div>
</body>
</html>
`.trim()
const getInteractiveUI = () =>
`
<!DOCTYPE html>
<html>
<head>
<style>
body {
font-family: system-ui, -apple-system, sans-serif;
padding: 20px;
background: #f5f5f5;
margin: 0;
}
.container {
max-width: 600px;
margin: 0 auto;
background: white;
padding: 30px;
border-radius: 12px;
box-shadow: 0 2px 10px rgba(0,0,0,0.1);
}
h2 {
color: #333;
margin-bottom: 20px;
}
button {
background: #667eea;
color: white;
border: none;
padding: 12px 24px;
border-radius: 6px;
cursor: pointer;
font-size: 16px;
margin: 5px;
transition: background 0.2s;
}
button:hover {
background: #5568d3;
}
#output {
margin-top: 20px;
padding: 15px;
background: #f8f9fa;
border-radius: 6px;
border: 1px solid #e0e0e0;
min-height: 50px;
}
.info {
color: #666;
font-size: 14px;
margin-top: 15px;
}
</style>
</head>
<body>
<div class="container">
<h2>Interactive MCP UI Demo</h2>
<p>Click the buttons to interact with MCP tools:</p>
<button onclick="callEchoTool()">Call Echo Tool</button>
<button onclick="getTimestamp()">Get Timestamp</button>
<button onclick="openLink()">Open External Link</button>
<div id="output"></div>
<div class="info">
This UI can communicate with the host application through postMessage API.
</div>
</div>
<script>
function callEchoTool() {
const output = document.getElementById('output');
output.innerHTML = '<p style="color: #0066cc;">Calling echo tool...</p>';
window.parent.postMessage({
type: 'tool',
payload: {
toolName: 'demo_echo',
params: {
message: 'Hello from MCP UI! Time: ' + new Date().toLocaleTimeString()
}
}
}, '*');
}
function getTimestamp() {
const output = document.getElementById('output');
const now = new Date();
output.innerHTML = \`
<p style="color: #00aa00;">
<strong>Current Timestamp:</strong><br/>
\${now.toLocaleString()}<br/>
Unix: \${Math.floor(now.getTime() / 1000)}
</p>
\`;
}
function openLink() {
window.parent.postMessage({
type: 'link',
payload: {
url: 'https://github.com/idosal/mcp-ui'
}
}, '*');
}
// Listen for responses
window.addEventListener('message', (event) => {
if (event.data.type === 'ui-message-response') {
const output = document.getElementById('output');
const { response, error } = event.data.payload;
if (error) {
output.innerHTML = \`<p style="color: #cc0000;">Error: \${error}</p>\`;
} else {
output.innerHTML = \`<p style="color: #00aa00;">Response: \${JSON.stringify(response, null, 2)}</p>\`;
}
}
});
</script>
</body>
</html>
`.trim()
const getFormUI = () =>
`
<!DOCTYPE html>
<html>
<head>
<style>
body {
font-family: system-ui, -apple-system, sans-serif;
padding: 20px;
background: #f5f5f5;
margin: 0;
}
.container {
max-width: 500px;
margin: 0 auto;
background: white;
padding: 30px;
border-radius: 12px;
box-shadow: 0 2px 10px rgba(0,0,0,0.1);
}
h2 {
color: #333;
margin-bottom: 20px;
}
.form-group {
margin-bottom: 15px;
}
label {
display: block;
margin-bottom: 5px;
color: #555;
font-weight: 500;
}
input, textarea {
width: 100%;
padding: 10px;
border: 1px solid #ddd;
border-radius: 6px;
font-size: 14px;
box-sizing: border-box;
}
textarea {
min-height: 100px;
resize: vertical;
}
button {
background: #667eea;
color: white;
border: none;
padding: 12px 24px;
border-radius: 6px;
cursor: pointer;
font-size: 16px;
width: 100%;
margin-top: 10px;
}
button:hover {
background: #5568d3;
}
#result {
margin-top: 20px;
padding: 15px;
background: #f8f9fa;
border-radius: 6px;
display: none;
}
</style>
</head>
<body>
<div class="container">
<h2>📝 Form UI Demo</h2>
<form id="demoForm" onsubmit="handleSubmit(event)">
<div class="form-group">
<label for="name">Name:</label>
<input type="text" id="name" name="name" required placeholder="Enter your name">
</div>
<div class="form-group">
<label for="email">Email:</label>
<input type="email" id="email" name="email" required placeholder="your@email.com">
</div>
<div class="form-group">
<label for="message">Message:</label>
<textarea id="message" name="message" required placeholder="Enter your message here..."></textarea>
</div>
<button type="submit">Submit Form</button>
</form>
<div id="result"></div>
</div>
<script>
function handleSubmit(event) {
event.preventDefault();
const formData = new FormData(event.target);
const data = Object.fromEntries(formData.entries());
const result = document.getElementById('result');
result.style.display = 'block';
result.innerHTML = '<p style="color: #0066cc;">Submitting form...</p>';
// Send form data to host
window.parent.postMessage({
type: 'notify',
payload: {
message: 'Form submitted with data: ' + JSON.stringify(data)
}
}, '*');
// Display result
result.innerHTML = \`
<p style="color: #00aa00;"><strong>Form Submitted!</strong></p>
<pre style="background: white; padding: 10px; border-radius: 4px; overflow-x: auto;">\${JSON.stringify(data, null, 2)}</pre>
\`;
}
</script>
</body>
</html>
`.trim()
// List available tools
server.setRequestHandler(ListToolsRequestSchema, async () => {
return {
tools: [
{
name: 'demo_echo',
description: 'Echo back the message sent from UI',
inputSchema: {
type: 'object',
properties: {
message: {
type: 'string',
description: 'Message to echo back'
}
},
required: ['message']
}
},
{
name: 'show_hello_ui',
description: 'Display a simple hello world UI with gradient background',
inputSchema: {
type: 'object',
properties: {}
}
},
{
name: 'show_interactive_ui',
description:
'Display an interactive UI demo with buttons for calling tools, getting timestamps, and opening links',
inputSchema: {
type: 'object',
properties: {}
}
},
{
name: 'show_form_ui',
description: 'Display a form UI demo with input fields for name, email, and message',
inputSchema: {
type: 'object',
properties: {}
}
}
]
}
})
// Handle tool calls
server.setRequestHandler(CallToolRequestSchema, async (request) => {
const { name, arguments: args } = request.params
if (name === 'demo_echo') {
return {
content: [
{
type: 'text',
text: JSON.stringify({
echo: args?.message || 'No message provided',
timestamp: new Date().toISOString()
})
}
]
}
}
if (name === 'show_hello_ui') {
return {
content: [
{
type: 'text',
text: JSON.stringify({
type: 'resource',
resource: {
uri: 'ui://demo/hello',
mimeType: 'text/html',
text: getHelloWorldUI()
}
})
}
]
}
}
if (name === 'show_interactive_ui') {
return {
content: [
{
type: 'text',
text: JSON.stringify({
type: 'resource',
resource: {
uri: 'ui://demo/interactive',
mimeType: 'text/html',
text: getInteractiveUI()
}
})
}
]
}
}
if (name === 'show_form_ui') {
return {
content: [
{
type: 'text',
text: JSON.stringify({
type: 'resource',
resource: {
uri: 'ui://demo/form',
mimeType: 'text/html',
text: getFormUI()
}
})
}
]
}
}
throw new Error(`Unknown tool: ${name}`)
})
class MCPUIDemoServer {
public server: Server
constructor() {
this.server = server
}
}
export default MCPUIDemoServer

View File

@@ -2,7 +2,7 @@ import { loggerService } from '@logger'
import { isWin } from '@main/constant'
import { getIpCountry } from '@main/utils/ipService'
import { generateUserAgent } from '@main/utils/systemInfo'
import { FeedUrl, UpgradeChannel } from '@shared/config/constant'
import { FeedUrl, UpdateConfigUrl, UpdateMirror, UpgradeChannel } from '@shared/config/constant'
import { IpcChannel } from '@shared/IpcChannel'
import type { UpdateInfo } from 'builder-util-runtime'
import { CancellationToken } from 'builder-util-runtime'
@@ -22,7 +22,29 @@ const LANG_MARKERS = {
EN_START: '<!--LANG:en-->',
ZH_CN_START: '<!--LANG:zh-CN-->',
END: '<!--LANG:END-->'
} as const
}
interface UpdateConfig {
lastUpdated: string
versions: {
[versionKey: string]: VersionConfig
}
}
interface VersionConfig {
minCompatibleVersion: string
description: string
channels: {
latest: ChannelConfig | null
rc: ChannelConfig | null
beta: ChannelConfig | null
}
}
interface ChannelConfig {
version: string
feedUrls: Record<UpdateMirror, string>
}
export default class AppUpdater {
autoUpdater: _AppUpdater = autoUpdater
@@ -37,7 +59,9 @@ export default class AppUpdater {
autoUpdater.requestHeaders = {
...autoUpdater.requestHeaders,
'User-Agent': generateUserAgent(),
'X-Client-Id': configManager.getClientId()
'X-Client-Id': configManager.getClientId(),
// no-cache
'Cache-Control': 'no-cache'
}
autoUpdater.on('error', (error) => {
@@ -75,61 +99,6 @@ export default class AppUpdater {
this.autoUpdater = autoUpdater
}
private async _getReleaseVersionFromGithub(channel: UpgradeChannel) {
const headers = {
Accept: 'application/vnd.github+json',
'X-GitHub-Api-Version': '2022-11-28',
'Accept-Language': 'en-US,en;q=0.9'
}
try {
logger.info(`get release version from github: ${channel}`)
const responses = await net.fetch('https://api.github.com/repos/CherryHQ/cherry-studio/releases?per_page=8', {
headers
})
const data = (await responses.json()) as GithubReleaseInfo[]
let mightHaveLatest = false
const release: GithubReleaseInfo | undefined = data.find((item: GithubReleaseInfo) => {
if (!item.draft && !item.prerelease) {
mightHaveLatest = true
}
return item.prerelease && item.tag_name.includes(`-${channel}.`)
})
if (!release) {
return null
}
// if the release version is the same as the current version, return null
if (release.tag_name === app.getVersion()) {
return null
}
if (mightHaveLatest) {
logger.info(`might have latest release, get latest release`)
const latestReleaseResponse = await net.fetch(
'https://api.github.com/repos/CherryHQ/cherry-studio/releases/latest',
{
headers
}
)
const latestRelease = (await latestReleaseResponse.json()) as GithubReleaseInfo
if (semver.gt(latestRelease.tag_name, release.tag_name)) {
logger.info(
`latest release version is ${latestRelease.tag_name}, prerelease version is ${release.tag_name}, return null`
)
return null
}
}
logger.info(`release url is ${release.tag_name}, set channel to ${channel}`)
return `https://github.com/CherryHQ/cherry-studio/releases/download/${release.tag_name}`
} catch (error) {
logger.error('Failed to get latest not draft version from github:', error as Error)
return null
}
}
public setAutoUpdate(isActive: boolean) {
autoUpdater.autoDownload = isActive
autoUpdater.autoInstallOnAppQuit = isActive
@@ -161,6 +130,88 @@ export default class AppUpdater {
return UpgradeChannel.LATEST
}
/**
* Fetch update configuration from GitHub or GitCode based on mirror
* @param mirror - Mirror to fetch config from
* @returns UpdateConfig object or null if fetch fails
*/
private async _fetchUpdateConfig(mirror: UpdateMirror): Promise<UpdateConfig | null> {
const configUrl = mirror === UpdateMirror.GITCODE ? UpdateConfigUrl.GITCODE : UpdateConfigUrl.GITHUB
try {
logger.info(`Fetching update config from ${configUrl} (mirror: ${mirror})`)
const response = await net.fetch(configUrl, {
headers: {
'User-Agent': generateUserAgent(),
Accept: 'application/json',
'X-Client-Id': configManager.getClientId(),
// no-cache
'Cache-Control': 'no-cache'
}
})
if (!response.ok) {
throw new Error(`HTTP error! status: ${response.status}`)
}
const config = (await response.json()) as UpdateConfig
logger.info(`Update config fetched successfully, last updated: ${config.lastUpdated}`)
return config
} catch (error) {
logger.error('Failed to fetch update config:', error as Error)
return null
}
}
/**
* Find compatible channel configuration based on current version
* @param currentVersion - Current app version
* @param requestedChannel - Requested upgrade channel (latest/rc/beta)
* @param config - Update configuration object
* @returns Object containing ChannelConfig and actual channel if found, null otherwise
*/
private _findCompatibleChannel(
currentVersion: string,
requestedChannel: UpgradeChannel,
config: UpdateConfig
): { config: ChannelConfig; channel: UpgradeChannel } | null {
// Get all version keys and sort descending (newest first)
const versionKeys = Object.keys(config.versions).sort(semver.rcompare)
logger.info(
`Finding compatible channel for version ${currentVersion}, requested channel: ${requestedChannel}, available versions: ${versionKeys.join(', ')}`
)
for (const versionKey of versionKeys) {
const versionConfig = config.versions[versionKey]
const channelConfig = versionConfig.channels[requestedChannel]
const latestChannelConfig = versionConfig.channels[UpgradeChannel.LATEST]
// Check version compatibility and channel availability
if (semver.gte(currentVersion, versionConfig.minCompatibleVersion) && channelConfig !== null) {
logger.info(
`Found compatible version: ${versionKey} (minCompatibleVersion: ${versionConfig.minCompatibleVersion}), version: ${channelConfig.version}`
)
if (
requestedChannel !== UpgradeChannel.LATEST &&
latestChannelConfig &&
semver.gte(latestChannelConfig.version, channelConfig.version)
) {
logger.info(
`latest channel version is greater than the requested channel version: ${latestChannelConfig.version} > ${channelConfig.version}, using latest instead`
)
return { config: latestChannelConfig, channel: UpgradeChannel.LATEST }
}
return { config: channelConfig, channel: requestedChannel }
}
}
logger.warn(`No compatible channel found for version ${currentVersion} and channel ${requestedChannel}`)
return null
}
private _setChannel(channel: UpgradeChannel, feedUrl: string) {
this.autoUpdater.channel = channel
this.autoUpdater.setFeedURL(feedUrl)
@@ -172,33 +223,42 @@ export default class AppUpdater {
}
private async _setFeedUrl() {
const currentVersion = app.getVersion()
const testPlan = configManager.getTestPlan()
if (testPlan) {
const channel = this._getTestChannel()
const requestedChannel = testPlan ? this._getTestChannel() : UpgradeChannel.LATEST
if (channel === UpgradeChannel.LATEST) {
this._setChannel(UpgradeChannel.LATEST, FeedUrl.GITHUB_LATEST)
return
}
const releaseUrl = await this._getReleaseVersionFromGithub(channel)
if (releaseUrl) {
logger.info(`release url is ${releaseUrl}, set channel to ${channel}`)
this._setChannel(channel, releaseUrl)
return
}
// if no prerelease url, use github latest to get release
this._setChannel(UpgradeChannel.LATEST, FeedUrl.GITHUB_LATEST)
return
}
this._setChannel(UpgradeChannel.LATEST, FeedUrl.PRODUCTION)
// Determine mirror based on IP country
const ipCountry = await getIpCountry()
logger.info(`ipCountry is ${ipCountry}, set channel to ${UpgradeChannel.LATEST}`)
if (ipCountry.toLowerCase() !== 'cn') {
this._setChannel(UpgradeChannel.LATEST, FeedUrl.GITHUB_LATEST)
const mirror = ipCountry.toLowerCase() === 'cn' ? UpdateMirror.GITCODE : UpdateMirror.GITHUB
logger.info(
`Setting feed URL for version ${currentVersion}, testPlan: ${testPlan}, requested channel: ${requestedChannel}, mirror: ${mirror} (IP country: ${ipCountry})`
)
// Try to fetch update config from remote
const config = await this._fetchUpdateConfig(mirror)
if (config) {
// Use new config-based system
const result = this._findCompatibleChannel(currentVersion, requestedChannel, config)
if (result) {
const { config: channelConfig, channel: actualChannel } = result
const feedUrl = channelConfig.feedUrls[mirror]
logger.info(
`Using config-based feed URL: ${feedUrl} for channel ${actualChannel} (requested: ${requestedChannel}, mirror: ${mirror})`
)
this._setChannel(actualChannel, feedUrl)
return
}
}
logger.info('Failed to fetch update config, falling back to default feed URL')
// Fallback: use default feed URL based on mirror
const defaultFeedUrl = mirror === UpdateMirror.GITCODE ? FeedUrl.PRODUCTION : FeedUrl.GITHUB_LATEST
logger.info(`Using fallback feed URL: ${defaultFeedUrl}`)
this._setChannel(UpgradeChannel.LATEST, defaultFeedUrl)
}
public cancelDownload() {
@@ -320,8 +380,3 @@ export default class AppUpdater {
return processedInfo
}
}
interface GithubReleaseInfo {
draft: boolean
prerelease: boolean
tag_name: string
}

View File

@@ -375,13 +375,16 @@ export class WindowService {
mainWindow.hide()
// TODO: don't hide dock icon when close to tray
// will cause the cmd+h behavior not working
// after the electron fix the bug, we can restore this code
// //for mac users, should hide dock icon if close to tray
// if (isMac && isTrayOnClose) {
// app.dock?.hide()
// }
//for mac users, should hide dock icon if close to tray
if (isMac && isTrayOnClose) {
app.dock?.hide()
mainWindow.once('show', () => {
//restore the window can hide by cmd+h when the window is shown again
// https://github.com/electron/electron/pull/47970
app.dock?.show()
})
}
})
mainWindow.on('closed', () => {

View File

@@ -85,6 +85,9 @@ vi.mock('electron-updater', () => ({
}))
// Import after mocks
import { UpdateMirror } from '@shared/config/constant'
import { app, net } from 'electron'
import AppUpdater from '../AppUpdater'
import { configManager } from '../ConfigManager'
@@ -274,4 +277,711 @@ describe('AppUpdater', () => {
expect(result.releaseNotes).toBeNull()
})
})
describe('_fetchUpdateConfig', () => {
const mockConfig = {
lastUpdated: '2025-01-05T00:00:00Z',
versions: {
'1.6.7': {
minCompatibleVersion: '1.0.0',
description: 'Test version',
channels: {
latest: {
version: '1.6.7',
feedUrls: {
github: 'https://github.com/test/v1.6.7',
gitcode: 'https://gitcode.com/test/v1.6.7'
}
},
rc: null,
beta: null
}
}
}
}
it('should fetch config from GitHub mirror', async () => {
vi.mocked(net.fetch).mockResolvedValue({
ok: true,
json: async () => mockConfig
} as any)
const result = await (appUpdater as any)._fetchUpdateConfig(UpdateMirror.GITHUB)
expect(result).toEqual(mockConfig)
expect(net.fetch).toHaveBeenCalledWith(expect.stringContaining('github'), expect.any(Object))
})
it('should fetch config from GitCode mirror', async () => {
vi.mocked(net.fetch).mockResolvedValue({
ok: true,
json: async () => mockConfig
} as any)
const result = await (appUpdater as any)._fetchUpdateConfig(UpdateMirror.GITCODE)
expect(result).toEqual(mockConfig)
// GitCode URL may vary, just check that fetch was called
expect(net.fetch).toHaveBeenCalledWith(expect.any(String), expect.any(Object))
})
it('should return null on HTTP error', async () => {
vi.mocked(net.fetch).mockResolvedValue({
ok: false,
status: 404
} as any)
const result = await (appUpdater as any)._fetchUpdateConfig(UpdateMirror.GITHUB)
expect(result).toBeNull()
})
it('should return null on network error', async () => {
vi.mocked(net.fetch).mockRejectedValue(new Error('Network error'))
const result = await (appUpdater as any)._fetchUpdateConfig(UpdateMirror.GITHUB)
expect(result).toBeNull()
})
})
describe('_findCompatibleChannel', () => {
const mockConfig = {
lastUpdated: '2025-01-05T00:00:00Z',
versions: {
'1.6.7': {
minCompatibleVersion: '1.0.0',
description: 'v1.6.7',
channels: {
latest: {
version: '1.6.7',
feedUrls: {
github: 'https://github.com/test/v1.6.7',
gitcode: 'https://gitcode.com/test/v1.6.7'
}
},
rc: {
version: '1.7.0-rc.1',
feedUrls: {
github: 'https://github.com/test/v1.7.0-rc.1',
gitcode: 'https://gitcode.com/test/v1.7.0-rc.1'
}
},
beta: {
version: '1.7.0-beta.3',
feedUrls: {
github: 'https://github.com/test/v1.7.0-beta.3',
gitcode: 'https://gitcode.com/test/v1.7.0-beta.3'
}
}
}
},
'2.0.0': {
minCompatibleVersion: '1.7.0',
description: 'v2.0.0',
channels: {
latest: null,
rc: null,
beta: null
}
}
}
}
it('should find compatible latest channel', () => {
vi.mocked(app.getVersion).mockReturnValue('1.5.0')
const result = (appUpdater as any)._findCompatibleChannel('1.5.0', 'latest', mockConfig)
expect(result?.config).toEqual({
version: '1.6.7',
feedUrls: {
github: 'https://github.com/test/v1.6.7',
gitcode: 'https://gitcode.com/test/v1.6.7'
}
})
expect(result?.channel).toBe('latest')
})
it('should find compatible rc channel', () => {
vi.mocked(app.getVersion).mockReturnValue('1.5.0')
const result = (appUpdater as any)._findCompatibleChannel('1.5.0', 'rc', mockConfig)
expect(result?.config).toEqual({
version: '1.7.0-rc.1',
feedUrls: {
github: 'https://github.com/test/v1.7.0-rc.1',
gitcode: 'https://gitcode.com/test/v1.7.0-rc.1'
}
})
expect(result?.channel).toBe('rc')
})
it('should find compatible beta channel', () => {
vi.mocked(app.getVersion).mockReturnValue('1.5.0')
const result = (appUpdater as any)._findCompatibleChannel('1.5.0', 'beta', mockConfig)
expect(result?.config).toEqual({
version: '1.7.0-beta.3',
feedUrls: {
github: 'https://github.com/test/v1.7.0-beta.3',
gitcode: 'https://gitcode.com/test/v1.7.0-beta.3'
}
})
expect(result?.channel).toBe('beta')
})
it('should return latest when latest version >= rc version', () => {
const configWithNewerLatest = {
lastUpdated: '2025-01-05T00:00:00Z',
versions: {
'1.7.0': {
minCompatibleVersion: '1.0.0',
description: 'v1.7.0',
channels: {
latest: {
version: '1.7.0',
feedUrls: {
github: 'https://github.com/test/v1.7.0',
gitcode: 'https://gitcode.com/test/v1.7.0'
}
},
rc: {
version: '1.7.0-rc.1',
feedUrls: {
github: 'https://github.com/test/v1.7.0-rc.1',
gitcode: 'https://gitcode.com/test/v1.7.0-rc.1'
}
},
beta: null
}
}
}
}
const result = (appUpdater as any)._findCompatibleChannel('1.6.0', 'rc', configWithNewerLatest)
// Should return latest instead of rc because 1.7.0 >= 1.7.0-rc.1
expect(result?.config).toEqual({
version: '1.7.0',
feedUrls: {
github: 'https://github.com/test/v1.7.0',
gitcode: 'https://gitcode.com/test/v1.7.0'
}
})
expect(result?.channel).toBe('latest') // ✅ 返回 latest 频道
})
it('should return latest when latest version >= beta version', () => {
const configWithNewerLatest = {
lastUpdated: '2025-01-05T00:00:00Z',
versions: {
'1.7.0': {
minCompatibleVersion: '1.0.0',
description: 'v1.7.0',
channels: {
latest: {
version: '1.7.0',
feedUrls: {
github: 'https://github.com/test/v1.7.0',
gitcode: 'https://gitcode.com/test/v1.7.0'
}
},
rc: null,
beta: {
version: '1.6.8-beta.1',
feedUrls: {
github: 'https://github.com/test/v1.6.8-beta.1',
gitcode: 'https://gitcode.com/test/v1.6.8-beta.1'
}
}
}
}
}
}
const result = (appUpdater as any)._findCompatibleChannel('1.6.0', 'beta', configWithNewerLatest)
// Should return latest instead of beta because 1.7.0 >= 1.6.8-beta.1
expect(result?.config).toEqual({
version: '1.7.0',
feedUrls: {
github: 'https://github.com/test/v1.7.0',
gitcode: 'https://gitcode.com/test/v1.7.0'
}
})
})
it('should not compare latest with itself when requesting latest channel', () => {
const config = {
lastUpdated: '2025-01-05T00:00:00Z',
versions: {
'1.7.0': {
minCompatibleVersion: '1.0.0',
description: 'v1.7.0',
channels: {
latest: {
version: '1.7.0',
feedUrls: {
github: 'https://github.com/test/v1.7.0',
gitcode: 'https://gitcode.com/test/v1.7.0'
}
},
rc: {
version: '1.7.0-rc.1',
feedUrls: {
github: 'https://github.com/test/v1.7.0-rc.1',
gitcode: 'https://gitcode.com/test/v1.7.0-rc.1'
}
},
beta: null
}
}
}
}
const result = (appUpdater as any)._findCompatibleChannel('1.6.0', 'latest', config)
// Should return latest directly without comparing with itself
expect(result?.config).toEqual({
version: '1.7.0',
feedUrls: {
github: 'https://github.com/test/v1.7.0',
gitcode: 'https://gitcode.com/test/v1.7.0'
}
})
})
it('should return rc when rc version > latest version', () => {
const configWithNewerRc = {
lastUpdated: '2025-01-05T00:00:00Z',
versions: {
'1.7.0': {
minCompatibleVersion: '1.0.0',
description: 'v1.7.0',
channels: {
latest: {
version: '1.6.7',
feedUrls: {
github: 'https://github.com/test/v1.6.7',
gitcode: 'https://gitcode.com/test/v1.6.7'
}
},
rc: {
version: '1.7.0-rc.1',
feedUrls: {
github: 'https://github.com/test/v1.7.0-rc.1',
gitcode: 'https://gitcode.com/test/v1.7.0-rc.1'
}
},
beta: null
}
}
}
}
const result = (appUpdater as any)._findCompatibleChannel('1.6.0', 'rc', configWithNewerRc)
// Should return rc because 1.7.0-rc.1 > 1.6.7
expect(result?.config).toEqual({
version: '1.7.0-rc.1',
feedUrls: {
github: 'https://github.com/test/v1.7.0-rc.1',
gitcode: 'https://gitcode.com/test/v1.7.0-rc.1'
}
})
})
it('should return beta when beta version > latest version', () => {
const configWithNewerBeta = {
lastUpdated: '2025-01-05T00:00:00Z',
versions: {
'1.7.0': {
minCompatibleVersion: '1.0.0',
description: 'v1.7.0',
channels: {
latest: {
version: '1.6.7',
feedUrls: {
github: 'https://github.com/test/v1.6.7',
gitcode: 'https://gitcode.com/test/v1.6.7'
}
},
rc: null,
beta: {
version: '1.7.0-beta.5',
feedUrls: {
github: 'https://github.com/test/v1.7.0-beta.5',
gitcode: 'https://gitcode.com/test/v1.7.0-beta.5'
}
}
}
}
}
}
const result = (appUpdater as any)._findCompatibleChannel('1.6.0', 'beta', configWithNewerBeta)
// Should return beta because 1.7.0-beta.5 > 1.6.7
expect(result?.config).toEqual({
version: '1.7.0-beta.5',
feedUrls: {
github: 'https://github.com/test/v1.7.0-beta.5',
gitcode: 'https://gitcode.com/test/v1.7.0-beta.5'
}
})
})
it('should return lower version when higher version has no compatible channel', () => {
vi.mocked(app.getVersion).mockReturnValue('1.8.0')
const result = (appUpdater as any)._findCompatibleChannel('1.8.0', 'latest', mockConfig)
// 1.8.0 >= 1.7.0 but 2.0.0 has no latest channel, so return 1.6.7
expect(result?.config).toEqual({
version: '1.6.7',
feedUrls: {
github: 'https://github.com/test/v1.6.7',
gitcode: 'https://gitcode.com/test/v1.6.7'
}
})
})
it('should return null when current version does not meet minCompatibleVersion', () => {
vi.mocked(app.getVersion).mockReturnValue('0.9.0')
const result = (appUpdater as any)._findCompatibleChannel('0.9.0', 'latest', mockConfig)
// 0.9.0 < 1.0.0 (minCompatibleVersion)
expect(result).toBeNull()
})
it('should return lower version rc when higher version has no rc channel', () => {
const result = (appUpdater as any)._findCompatibleChannel('1.8.0', 'rc', mockConfig)
// 1.8.0 >= 1.7.0 but 2.0.0 has no rc channel, so return 1.6.7 rc
expect(result?.config).toEqual({
version: '1.7.0-rc.1',
feedUrls: {
github: 'https://github.com/test/v1.7.0-rc.1',
gitcode: 'https://gitcode.com/test/v1.7.0-rc.1'
}
})
})
it('should return null when no version has the requested channel', () => {
const configWithoutRc = {
lastUpdated: '2025-01-05T00:00:00Z',
versions: {
'1.6.7': {
minCompatibleVersion: '1.0.0',
description: 'v1.6.7',
channels: {
latest: {
version: '1.6.7',
feedUrls: {
github: 'https://github.com/test/v1.6.7',
gitcode: 'https://gitcode.com/test/v1.6.7'
}
},
rc: null,
beta: null
}
}
}
}
const result = (appUpdater as any)._findCompatibleChannel('1.5.0', 'rc', configWithoutRc)
expect(result).toBeNull()
})
})
describe('Upgrade Path', () => {
const fullConfig = {
lastUpdated: '2025-01-05T00:00:00Z',
versions: {
'1.6.7': {
minCompatibleVersion: '1.0.0',
description: 'Last v1.x',
channels: {
latest: {
version: '1.6.7',
feedUrls: {
github: 'https://github.com/test/v1.6.7',
gitcode: 'https://gitcode.com/test/v1.6.7'
}
},
rc: {
version: '1.7.0-rc.1',
feedUrls: {
github: 'https://github.com/test/v1.7.0-rc.1',
gitcode: 'https://gitcode.com/test/v1.7.0-rc.1'
}
},
beta: {
version: '1.7.0-beta.3',
feedUrls: {
github: 'https://github.com/test/v1.7.0-beta.3',
gitcode: 'https://gitcode.com/test/v1.7.0-beta.3'
}
}
}
},
'2.0.0': {
minCompatibleVersion: '1.7.0',
description: 'First v2.x',
channels: {
latest: null,
rc: null,
beta: null
}
}
}
}
it('should upgrade from 1.6.3 to 1.6.7', () => {
const result = (appUpdater as any)._findCompatibleChannel('1.6.3', 'latest', fullConfig)
expect(result?.config).toEqual({
version: '1.6.7',
feedUrls: {
github: 'https://github.com/test/v1.6.7',
gitcode: 'https://gitcode.com/test/v1.6.7'
}
})
})
it('should block upgrade from 1.6.7 to 2.0.0 (minCompatibleVersion not met)', () => {
const result = (appUpdater as any)._findCompatibleChannel('1.6.7', 'latest', fullConfig)
// Should return 1.6.7, not 2.0.0, because 1.6.7 < 1.7.0 (minCompatibleVersion of 2.0.0)
expect(result?.config).toEqual({
version: '1.6.7',
feedUrls: {
github: 'https://github.com/test/v1.6.7',
gitcode: 'https://gitcode.com/test/v1.6.7'
}
})
})
it('should allow upgrade from 1.7.0 to 2.0.0', () => {
const configWith2x = {
...fullConfig,
versions: {
...fullConfig.versions,
'2.0.0': {
minCompatibleVersion: '1.7.0',
description: 'First v2.x',
channels: {
latest: {
version: '2.0.0',
feedUrls: {
github: 'https://github.com/test/v2.0.0',
gitcode: 'https://gitcode.com/test/v2.0.0'
}
},
rc: null,
beta: null
}
}
}
}
const result = (appUpdater as any)._findCompatibleChannel('1.7.0', 'latest', configWith2x)
expect(result?.config).toEqual({
version: '2.0.0',
feedUrls: {
github: 'https://github.com/test/v2.0.0',
gitcode: 'https://gitcode.com/test/v2.0.0'
}
})
})
})
describe('Complete Multi-Step Upgrade Path', () => {
const fullUpgradeConfig = {
lastUpdated: '2025-01-05T00:00:00Z',
versions: {
'1.7.5': {
minCompatibleVersion: '1.0.0',
description: 'Last v1.x stable',
channels: {
latest: {
version: '1.7.5',
feedUrls: {
github: 'https://github.com/test/v1.7.5',
gitcode: 'https://gitcode.com/test/v1.7.5'
}
},
rc: null,
beta: null
}
},
'2.0.0': {
minCompatibleVersion: '1.7.0',
description: 'First v2.x - intermediate version',
channels: {
latest: {
version: '2.0.0',
feedUrls: {
github: 'https://github.com/test/v2.0.0',
gitcode: 'https://gitcode.com/test/v2.0.0'
}
},
rc: null,
beta: null
}
},
'2.1.6': {
minCompatibleVersion: '2.0.0',
description: 'Current v2.x stable',
channels: {
latest: {
version: '2.1.6',
feedUrls: {
github: 'https://github.com/test/latest',
gitcode: 'https://gitcode.com/test/latest'
}
},
rc: null,
beta: null
}
}
}
}
it('should upgrade from 1.6.3 to 1.7.5 (step 1)', () => {
const result = (appUpdater as any)._findCompatibleChannel('1.6.3', 'latest', fullUpgradeConfig)
expect(result?.config).toEqual({
version: '1.7.5',
feedUrls: {
github: 'https://github.com/test/v1.7.5',
gitcode: 'https://gitcode.com/test/v1.7.5'
}
})
})
it('should upgrade from 1.7.5 to 2.0.0 (step 2)', () => {
const result = (appUpdater as any)._findCompatibleChannel('1.7.5', 'latest', fullUpgradeConfig)
expect(result?.config).toEqual({
version: '2.0.0',
feedUrls: {
github: 'https://github.com/test/v2.0.0',
gitcode: 'https://gitcode.com/test/v2.0.0'
}
})
})
it('should upgrade from 2.0.0 to 2.1.6 (step 3)', () => {
const result = (appUpdater as any)._findCompatibleChannel('2.0.0', 'latest', fullUpgradeConfig)
expect(result?.config).toEqual({
version: '2.1.6',
feedUrls: {
github: 'https://github.com/test/latest',
gitcode: 'https://gitcode.com/test/latest'
}
})
})
it('should complete full upgrade path: 1.6.3 -> 1.7.5 -> 2.0.0 -> 2.1.6', () => {
// Step 1: 1.6.3 -> 1.7.5
let currentVersion = '1.6.3'
let result = (appUpdater as any)._findCompatibleChannel(currentVersion, 'latest', fullUpgradeConfig)
expect(result?.config.version).toBe('1.7.5')
// Step 2: 1.7.5 -> 2.0.0
currentVersion = result?.config.version!
result = (appUpdater as any)._findCompatibleChannel(currentVersion, 'latest', fullUpgradeConfig)
expect(result?.config.version).toBe('2.0.0')
// Step 3: 2.0.0 -> 2.1.6
currentVersion = result?.config.version!
result = (appUpdater as any)._findCompatibleChannel(currentVersion, 'latest', fullUpgradeConfig)
expect(result?.config.version).toBe('2.1.6')
// Final: 2.1.6 is the latest, no more upgrades
currentVersion = result?.config.version!
result = (appUpdater as any)._findCompatibleChannel(currentVersion, 'latest', fullUpgradeConfig)
expect(result?.config.version).toBe('2.1.6')
})
it('should block direct upgrade from 1.6.3 to 2.0.0 (skip intermediate)', () => {
const result = (appUpdater as any)._findCompatibleChannel('1.6.3', 'latest', fullUpgradeConfig)
// Should return 1.7.5, not 2.0.0, because 1.6.3 < 1.7.0 (minCompatibleVersion of 2.0.0)
expect(result?.config.version).toBe('1.7.5')
expect(result?.config.version).not.toBe('2.0.0')
})
it('should block direct upgrade from 1.7.5 to 2.1.6 (skip intermediate)', () => {
const result = (appUpdater as any)._findCompatibleChannel('1.7.5', 'latest', fullUpgradeConfig)
// Should return 2.0.0, not 2.1.6, because 1.7.5 < 2.0.0 (minCompatibleVersion of 2.1.6)
expect(result?.config.version).toBe('2.0.0')
expect(result?.config.version).not.toBe('2.1.6')
})
})
})

View File

@@ -1,17 +1,13 @@
import { type Client, createClient } from '@libsql/client'
import { loggerService } from '@logger'
import { mcpApiService } from '@main/apiServer/services/mcp'
import type { ModelValidationError } from '@main/apiServer/utils'
import { validateModelId } from '@main/apiServer/utils'
import type { AgentType, MCPTool, SlashCommand, Tool } from '@types'
import { objectKeys } from '@types'
import { drizzle, type LibSQLDatabase } from 'drizzle-orm/libsql'
import fs from 'fs'
import path from 'path'
import { MigrationService } from './database/MigrationService'
import * as schema from './database/schema'
import { dbPath } from './drizzle.config'
import { DatabaseManager } from './database/DatabaseManager'
import type { AgentModelField } from './errors'
import { AgentModelValidationError } from './errors'
import { builtinSlashCommands } from './services/claudecode/commands'
@@ -20,22 +16,16 @@ import { builtinTools } from './services/claudecode/tools'
const logger = loggerService.withContext('BaseService')
/**
* Base service class providing shared database connection and utilities
* for all agent-related services.
* Base service class providing shared utilities for all agent-related services.
*
* Features:
* - Programmatic schema management (no CLI dependencies)
* - Automatic table creation and migration
* - Schema version tracking and compatibility checks
* - Transaction-based operations for safety
* - Development vs production mode handling
* - Connection retry logic with exponential backoff
* - Database access through DatabaseManager singleton
* - JSON field serialization/deserialization
* - Path validation and creation
* - Model validation
* - MCP tools and slash commands listing
*/
export abstract class BaseService {
protected static client: Client | null = null
protected static db: LibSQLDatabase<typeof schema> | null = null
protected static isInitialized = false
protected static initializationPromise: Promise<void> | null = null
protected jsonFields: string[] = [
'tools',
'mcps',
@@ -45,23 +35,6 @@ export abstract class BaseService {
'slash_commands'
]
/**
* Initialize database with retry logic and proper error handling
*/
protected static async initialize(): Promise<void> {
// Return existing initialization if in progress
if (BaseService.initializationPromise) {
return BaseService.initializationPromise
}
if (BaseService.isInitialized) {
return
}
BaseService.initializationPromise = BaseService.performInitialization()
return BaseService.initializationPromise
}
public async listMcpTools(agentType: AgentType, ids?: string[]): Promise<Tool[]> {
const tools: Tool[] = []
if (agentType === 'claude-code') {
@@ -101,78 +74,13 @@ export abstract class BaseService {
return []
}
private static async performInitialization(): Promise<void> {
const maxRetries = 3
let lastError: Error
for (let attempt = 1; attempt <= maxRetries; attempt++) {
try {
logger.info(`Initializing Agent database at: ${dbPath} (attempt ${attempt}/${maxRetries})`)
// Ensure the database directory exists
const dbDir = path.dirname(dbPath)
if (!fs.existsSync(dbDir)) {
logger.info(`Creating database directory: ${dbDir}`)
fs.mkdirSync(dbDir, { recursive: true })
}
BaseService.client = createClient({
url: `file:${dbPath}`
})
BaseService.db = drizzle(BaseService.client, { schema })
// Run database migrations
const migrationService = new MigrationService(BaseService.db, BaseService.client)
await migrationService.runMigrations()
BaseService.isInitialized = true
logger.info('Agent database initialized successfully')
return
} catch (error) {
lastError = error as Error
logger.warn(`Database initialization attempt ${attempt} failed:`, lastError)
// Clean up on failure
if (BaseService.client) {
try {
BaseService.client.close()
} catch (closeError) {
logger.warn('Failed to close client during cleanup:', closeError as Error)
}
}
BaseService.client = null
BaseService.db = null
// Wait before retrying (exponential backoff)
if (attempt < maxRetries) {
const delay = Math.pow(2, attempt) * 1000 // 2s, 4s, 8s
logger.info(`Retrying in ${delay}ms...`)
await new Promise((resolve) => setTimeout(resolve, delay))
}
}
}
// All retries failed
BaseService.initializationPromise = null
logger.error('Failed to initialize Agent database after all retries:', lastError!)
throw lastError!
}
protected ensureInitialized(): void {
if (!BaseService.isInitialized || !BaseService.db || !BaseService.client) {
throw new Error('Database not initialized. Call initialize() first.')
}
}
protected get database(): LibSQLDatabase<typeof schema> {
this.ensureInitialized()
return BaseService.db!
}
protected get rawClient(): Client {
this.ensureInitialized()
return BaseService.client!
/**
* Get database instance
* Automatically waits for initialization to complete
*/
protected async getDatabase() {
const dbManager = await DatabaseManager.getInstance()
return dbManager.getDatabase()
}
protected serializeJsonFields(data: any): any {
@@ -284,7 +192,7 @@ export abstract class BaseService {
}
/**
* Force re-initialization (for development/testing)
* Validate agent model configuration
*/
protected async validateAgentModels(
agentType: AgentType,
@@ -325,22 +233,4 @@ export abstract class BaseService {
}
}
}
static async reinitialize(): Promise<void> {
BaseService.isInitialized = false
BaseService.initializationPromise = null
if (BaseService.client) {
try {
BaseService.client.close()
} catch (error) {
logger.warn('Failed to close client during reinitialize:', error as Error)
}
}
BaseService.client = null
BaseService.db = null
await BaseService.initialize()
}
}

View File

@@ -0,0 +1,156 @@
import { type Client, createClient } from '@libsql/client'
import { loggerService } from '@logger'
import type { LibSQLDatabase } from 'drizzle-orm/libsql'
import { drizzle } from 'drizzle-orm/libsql'
import fs from 'fs'
import path from 'path'
import { dbPath } from '../drizzle.config'
import { MigrationService } from './MigrationService'
import * as schema from './schema'
const logger = loggerService.withContext('DatabaseManager')
/**
* Database initialization state
*/
enum InitState {
INITIALIZING = 'initializing',
INITIALIZED = 'initialized',
FAILED = 'failed'
}
/**
* DatabaseManager - Singleton class for managing libsql database connections
*
* Responsibilities:
* - Single source of truth for database connection
* - Thread-safe initialization with state management
* - Automatic migration handling
* - Safe connection cleanup
* - Error recovery and retry logic
* - Windows platform compatibility fixes
*/
export class DatabaseManager {
private static instance: DatabaseManager | null = null
private client: Client | null = null
private db: LibSQLDatabase<typeof schema> | null = null
private state: InitState = InitState.INITIALIZING
/**
* Get the singleton instance (database initialization starts automatically)
*/
public static async getInstance(): Promise<DatabaseManager> {
if (DatabaseManager.instance) {
return DatabaseManager.instance
}
const instance = new DatabaseManager()
await instance.initialize()
DatabaseManager.instance = instance
return instance
}
/**
* Perform the actual initialization
*/
public async initialize(): Promise<void> {
if (this.state === InitState.INITIALIZED) {
return
}
try {
logger.info(`Initializing database at: ${dbPath}`)
// Ensure database directory exists
const dbDir = path.dirname(dbPath)
if (!fs.existsSync(dbDir)) {
logger.info(`Creating database directory: ${dbDir}`)
fs.mkdirSync(dbDir, { recursive: true })
}
// Check if database file is corrupted (Windows specific check)
if (fs.existsSync(dbPath)) {
const stats = fs.statSync(dbPath)
if (stats.size === 0) {
logger.warn('Database file is empty, removing corrupted file')
fs.unlinkSync(dbPath)
}
}
// Create client with platform-specific options
this.client = createClient({
url: `file:${dbPath}`,
// intMode: 'number' helps avoid some Windows compatibility issues
intMode: 'number'
})
// Create drizzle instance
this.db = drizzle(this.client, { schema })
// Run migrations
const migrationService = new MigrationService(this.db, this.client)
await migrationService.runMigrations()
this.state = InitState.INITIALIZED
logger.info('Database initialized successfully')
} catch (error) {
const err = error as Error
logger.error('Database initialization failed:', {
error: err.message,
stack: err.stack
})
// Clean up failed initialization
this.cleanupFailedInit()
// Set failed state
this.state = InitState.FAILED
throw new Error(`Database initialization failed: ${err.message || 'Unknown error'}`)
}
}
/**
* Clean up after failed initialization
*/
private cleanupFailedInit(): void {
if (this.client) {
try {
// On Windows, closing a partially initialized client can crash
// Wrap in try-catch and ignore errors during cleanup
this.client.close()
} catch (error) {
logger.warn('Failed to close client during cleanup:', error as Error)
}
}
this.client = null
this.db = null
}
/**
* Get the database instance
* Automatically waits for initialization to complete
* @throws Error if database initialization failed
*/
public getDatabase(): LibSQLDatabase<typeof schema> {
return this.db!
}
/**
* Get the raw client (for advanced operations)
* Automatically waits for initialization to complete
* @throws Error if database initialization failed
*/
public async getClient(): Promise<Client> {
return this.client!
}
/**
* Check if database is initialized
*/
public isInitialized(): boolean {
return this.state === InitState.INITIALIZED
}
}

View File

@@ -7,8 +7,14 @@
* Schema evolution is handled by Drizzle Kit migrations.
*/
// Database Manager (Singleton)
export * from './DatabaseManager'
// Drizzle ORM schemas
export * from './schema'
// Repository helpers
export * from './sessionMessageRepository'
// Migration Service
export * from './MigrationService'

View File

@@ -15,26 +15,16 @@ import { sessionMessagesTable } from './schema'
const logger = loggerService.withContext('AgentMessageRepository')
type TxClient = any
export type PersistUserMessageParams = AgentMessageUserPersistPayload & {
sessionId: string
agentSessionId?: string
tx?: TxClient
}
export type PersistAssistantMessageParams = AgentMessageAssistantPersistPayload & {
sessionId: string
agentSessionId: string
tx?: TxClient
}
type PersistExchangeParams = AgentMessagePersistExchangePayload & {
tx?: TxClient
}
type PersistExchangeResult = AgentMessagePersistExchangeResult
class AgentMessageRepository extends BaseService {
private static instance: AgentMessageRepository | null = null
@@ -87,17 +77,13 @@ class AgentMessageRepository extends BaseService {
return deserialized
}
private getWriter(tx?: TxClient): TxClient {
return tx ?? this.database
}
private async findExistingMessageRow(
writer: TxClient,
sessionId: string,
role: string,
messageId: string
): Promise<SessionMessageRow | null> {
const candidateRows: SessionMessageRow[] = await writer
const database = await this.getDatabase()
const candidateRows: SessionMessageRow[] = await database
.select()
.from(sessionMessagesTable)
.where(and(eq(sessionMessagesTable.session_id, sessionId), eq(sessionMessagesTable.role, role)))
@@ -122,10 +108,7 @@ class AgentMessageRepository extends BaseService {
private async upsertMessage(
params: PersistUserMessageParams | PersistAssistantMessageParams
): Promise<AgentSessionMessageEntity> {
await AgentMessageRepository.initialize()
this.ensureInitialized()
const { sessionId, agentSessionId = '', payload, metadata, createdAt, tx } = params
const { sessionId, agentSessionId = '', payload, metadata, createdAt } = params
if (!payload?.message?.role) {
throw new Error('Message payload missing role')
@@ -135,18 +118,18 @@ class AgentMessageRepository extends BaseService {
throw new Error('Message payload missing id')
}
const writer = this.getWriter(tx)
const database = await this.getDatabase()
const now = createdAt ?? payload.message.createdAt ?? new Date().toISOString()
const serializedPayload = this.serializeMessage(payload)
const serializedMetadata = this.serializeMetadata(metadata)
const existingRow = await this.findExistingMessageRow(writer, sessionId, payload.message.role, payload.message.id)
const existingRow = await this.findExistingMessageRow(sessionId, payload.message.role, payload.message.id)
if (existingRow) {
const metadataToPersist = serializedMetadata ?? existingRow.metadata ?? undefined
const agentSessionToPersist = agentSessionId || existingRow.agent_session_id || ''
await writer
await database
.update(sessionMessagesTable)
.set({
content: serializedPayload,
@@ -175,7 +158,7 @@ class AgentMessageRepository extends BaseService {
updated_at: now
}
const [saved] = await writer.insert(sessionMessagesTable).values(insertData).returning()
const [saved] = await database.insert(sessionMessagesTable).values(insertData).returning()
return this.deserialize(saved)
}
@@ -188,49 +171,38 @@ class AgentMessageRepository extends BaseService {
return this.upsertMessage(params)
}
async persistExchange(params: PersistExchangeParams): Promise<PersistExchangeResult> {
await AgentMessageRepository.initialize()
this.ensureInitialized()
async persistExchange(params: AgentMessagePersistExchangePayload): Promise<AgentMessagePersistExchangeResult> {
const { sessionId, agentSessionId, user, assistant } = params
const result = await this.database.transaction(async (tx) => {
const exchangeResult: PersistExchangeResult = {}
const exchangeResult: AgentMessagePersistExchangeResult = {}
if (user?.payload) {
exchangeResult.userMessage = await this.persistUserMessage({
sessionId,
agentSessionId,
payload: user.payload,
metadata: user.metadata,
createdAt: user.createdAt,
tx
})
}
if (user?.payload) {
exchangeResult.userMessage = await this.persistUserMessage({
sessionId,
agentSessionId,
payload: user.payload,
metadata: user.metadata,
createdAt: user.createdAt
})
}
if (assistant?.payload) {
exchangeResult.assistantMessage = await this.persistAssistantMessage({
sessionId,
agentSessionId,
payload: assistant.payload,
metadata: assistant.metadata,
createdAt: assistant.createdAt,
tx
})
}
if (assistant?.payload) {
exchangeResult.assistantMessage = await this.persistAssistantMessage({
sessionId,
agentSessionId,
payload: assistant.payload,
metadata: assistant.metadata,
createdAt: assistant.createdAt
})
}
return exchangeResult
})
return result
return exchangeResult
}
async getSessionHistory(sessionId: string): Promise<AgentPersistedMessage[]> {
await AgentMessageRepository.initialize()
this.ensureInitialized()
try {
const rows = await this.database
const database = await this.getDatabase()
const rows = await database
.select()
.from(sessionMessagesTable)
.where(eq(sessionMessagesTable.session_id, sessionId))

View File

@@ -32,14 +32,8 @@ export class AgentService extends BaseService {
return AgentService.instance
}
async initialize(): Promise<void> {
await BaseService.initialize()
}
// Agent Methods
async createAgent(req: CreateAgentRequest): Promise<CreateAgentResponse> {
this.ensureInitialized()
const id = `agent_${Date.now()}_${Math.random().toString(36).substring(2, 11)}`
const now = new Date().toISOString()
@@ -75,8 +69,9 @@ export class AgentService extends BaseService {
updated_at: now
}
await this.database.insert(agentsTable).values(insertData)
const result = await this.database.select().from(agentsTable).where(eq(agentsTable.id, id)).limit(1)
const database = await this.getDatabase()
await database.insert(agentsTable).values(insertData)
const result = await database.select().from(agentsTable).where(eq(agentsTable.id, id)).limit(1)
if (!result[0]) {
throw new Error('Failed to create agent')
}
@@ -86,9 +81,8 @@ export class AgentService extends BaseService {
}
async getAgent(id: string): Promise<GetAgentResponse | null> {
this.ensureInitialized()
const result = await this.database.select().from(agentsTable).where(eq(agentsTable.id, id)).limit(1)
const database = await this.getDatabase()
const result = await database.select().from(agentsTable).where(eq(agentsTable.id, id)).limit(1)
if (!result[0]) {
return null
@@ -118,9 +112,9 @@ export class AgentService extends BaseService {
}
async listAgents(options: ListOptions = {}): Promise<{ agents: AgentEntity[]; total: number }> {
this.ensureInitialized() // Build query with pagination
const totalResult = await this.database.select({ count: count() }).from(agentsTable)
// Build query with pagination
const database = await this.getDatabase()
const totalResult = await database.select({ count: count() }).from(agentsTable)
const sortBy = options.sortBy || 'created_at'
const orderBy = options.orderBy || 'desc'
@@ -128,7 +122,7 @@ export class AgentService extends BaseService {
const sortField = agentsTable[sortBy]
const orderFn = orderBy === 'asc' ? asc : desc
const baseQuery = this.database.select().from(agentsTable).orderBy(orderFn(sortField))
const baseQuery = database.select().from(agentsTable).orderBy(orderFn(sortField))
const result =
options.limit !== undefined
@@ -151,8 +145,6 @@ export class AgentService extends BaseService {
updates: UpdateAgentRequest,
options: { replace?: boolean } = {}
): Promise<UpdateAgentResponse | null> {
this.ensureInitialized()
// Check if agent exists
const existing = await this.getAgent(id)
if (!existing) {
@@ -195,22 +187,21 @@ export class AgentService extends BaseService {
}
}
await this.database.update(agentsTable).set(updateData).where(eq(agentsTable.id, id))
const database = await this.getDatabase()
await database.update(agentsTable).set(updateData).where(eq(agentsTable.id, id))
return await this.getAgent(id)
}
async deleteAgent(id: string): Promise<boolean> {
this.ensureInitialized()
const result = await this.database.delete(agentsTable).where(eq(agentsTable.id, id))
const database = await this.getDatabase()
const result = await database.delete(agentsTable).where(eq(agentsTable.id, id))
return result.rowsAffected > 0
}
async agentExists(id: string): Promise<boolean> {
this.ensureInitialized()
const result = await this.database
const database = await this.getDatabase()
const result = await database
.select({ id: agentsTable.id })
.from(agentsTable)
.where(eq(agentsTable.id, id))

View File

@@ -104,14 +104,9 @@ export class SessionMessageService extends BaseService {
return SessionMessageService.instance
}
async initialize(): Promise<void> {
await BaseService.initialize()
}
async sessionMessageExists(id: number): Promise<boolean> {
this.ensureInitialized()
const result = await this.database
const database = await this.getDatabase()
const result = await database
.select({ id: sessionMessagesTable.id })
.from(sessionMessagesTable)
.where(eq(sessionMessagesTable.id, id))
@@ -124,10 +119,9 @@ export class SessionMessageService extends BaseService {
sessionId: string,
options: ListOptions = {}
): Promise<{ messages: AgentSessionMessageEntity[] }> {
this.ensureInitialized()
// Get messages with pagination
const baseQuery = this.database
const database = await this.getDatabase()
const baseQuery = database
.select()
.from(sessionMessagesTable)
.where(eq(sessionMessagesTable.session_id, sessionId))
@@ -146,9 +140,8 @@ export class SessionMessageService extends BaseService {
}
async deleteSessionMessage(sessionId: string, messageId: number): Promise<boolean> {
this.ensureInitialized()
const result = await this.database
const database = await this.getDatabase()
const result = await database
.delete(sessionMessagesTable)
.where(and(eq(sessionMessagesTable.id, messageId), eq(sessionMessagesTable.session_id, sessionId)))
@@ -160,8 +153,6 @@ export class SessionMessageService extends BaseService {
messageData: CreateSessionMessageRequest,
abortController: AbortController
): Promise<SessionStreamResult> {
this.ensureInitialized()
return await this.startSessionMessageStream(session, messageData, abortController)
}
@@ -270,10 +261,9 @@ export class SessionMessageService extends BaseService {
}
private async getLastAgentSessionId(sessionId: string): Promise<string> {
this.ensureInitialized()
try {
const result = await this.database
const database = await this.getDatabase()
const result = await database
.select({ agent_session_id: sessionMessagesTable.agent_session_id })
.from(sessionMessagesTable)
.where(and(eq(sessionMessagesTable.session_id, sessionId), not(eq(sessionMessagesTable.agent_session_id, ''))))

View File

@@ -30,10 +30,6 @@ export class SessionService extends BaseService {
return SessionService.instance
}
async initialize(): Promise<void> {
await BaseService.initialize()
}
/**
* Override BaseService.listSlashCommands to merge builtin and plugin commands
*/
@@ -84,13 +80,12 @@ export class SessionService extends BaseService {
agentId: string,
req: Partial<CreateSessionRequest> = {}
): Promise<GetAgentSessionResponse | null> {
this.ensureInitialized()
// Validate agent exists - we'll need to import AgentService for this check
// For now, we'll skip this validation to avoid circular dependencies
// The database foreign key constraint will handle this
const agents = await this.database.select().from(agentsTable).where(eq(agentsTable.id, agentId)).limit(1)
const database = await this.getDatabase()
const agents = await database.select().from(agentsTable).where(eq(agentsTable.id, agentId)).limit(1)
if (!agents[0]) {
throw new Error('Agent not found')
}
@@ -135,9 +130,10 @@ export class SessionService extends BaseService {
updated_at: now
}
await this.database.insert(sessionsTable).values(insertData)
const db = await this.getDatabase()
await db.insert(sessionsTable).values(insertData)
const result = await this.database.select().from(sessionsTable).where(eq(sessionsTable.id, id)).limit(1)
const result = await db.select().from(sessionsTable).where(eq(sessionsTable.id, id)).limit(1)
if (!result[0]) {
throw new Error('Failed to create session')
@@ -148,9 +144,8 @@ export class SessionService extends BaseService {
}
async getSession(agentId: string, id: string): Promise<GetAgentSessionResponse | null> {
this.ensureInitialized()
const result = await this.database
const database = await this.getDatabase()
const result = await database
.select()
.from(sessionsTable)
.where(and(eq(sessionsTable.id, id), eq(sessionsTable.agent_id, agentId)))
@@ -176,8 +171,6 @@ export class SessionService extends BaseService {
agentId?: string,
options: ListOptions = {}
): Promise<{ sessions: AgentSessionEntity[]; total: number }> {
this.ensureInitialized()
// Build where conditions
const whereConditions: SQL[] = []
if (agentId) {
@@ -192,16 +185,13 @@ export class SessionService extends BaseService {
: undefined
// Get total count
const totalResult = await this.database.select({ count: count() }).from(sessionsTable).where(whereClause)
const database = await this.getDatabase()
const totalResult = await database.select({ count: count() }).from(sessionsTable).where(whereClause)
const total = totalResult[0].count
// Build list query with pagination - sort by updated_at descending (latest first)
const baseQuery = this.database
.select()
.from(sessionsTable)
.where(whereClause)
.orderBy(desc(sessionsTable.updated_at))
const baseQuery = database.select().from(sessionsTable).where(whereClause).orderBy(desc(sessionsTable.updated_at))
const result =
options.limit !== undefined
@@ -220,8 +210,6 @@ export class SessionService extends BaseService {
id: string,
updates: UpdateSessionRequest
): Promise<UpdateSessionResponse | null> {
this.ensureInitialized()
// Check if session exists
const existing = await this.getSession(agentId, id)
if (!existing) {
@@ -262,15 +250,15 @@ export class SessionService extends BaseService {
}
}
await this.database.update(sessionsTable).set(updateData).where(eq(sessionsTable.id, id))
const database = await this.getDatabase()
await database.update(sessionsTable).set(updateData).where(eq(sessionsTable.id, id))
return await this.getSession(agentId, id)
}
async deleteSession(agentId: string, id: string): Promise<boolean> {
this.ensureInitialized()
const result = await this.database
const database = await this.getDatabase()
const result = await database
.delete(sessionsTable)
.where(and(eq(sessionsTable.id, id), eq(sessionsTable.agent_id, agentId)))
@@ -278,9 +266,8 @@ export class SessionService extends BaseService {
}
async sessionExists(agentId: string, id: string): Promise<boolean> {
this.ensureInitialized()
const result = await this.database
const database = await this.getDatabase()
const result = await database
.select({ id: sessionsTable.id })
.from(sessionsTable)
.where(and(eq(sessionsTable.id, id), eq(sessionsTable.agent_id, agentId)))

View File

@@ -21,11 +21,16 @@ describe('stripLocalCommandTags', () => {
'<local-command-stdout>line1</local-command-stdout>\nkeep\n<local-command-stderr>Error</local-command-stderr>'
expect(stripLocalCommandTags(input)).toBe('line1\nkeep\nError')
})
it('if no tags present, returns original string', () => {
const input = 'just some normal text'
expect(stripLocalCommandTags(input)).toBe(input)
})
})
describe('Claude → AiSDK transform', () => {
it('handles tool call streaming lifecycle', () => {
const state = new ClaudeStreamState()
const state = new ClaudeStreamState({ agentSessionId: baseStreamMetadata.session_id })
const parts: ReturnType<typeof transformSDKMessageToStreamParts>[number][] = []
const messages: SDKMessage[] = [
@@ -182,14 +187,119 @@ describe('Claude → AiSDK transform', () => {
(typeof parts)[number],
{ type: 'tool-result' }
>
expect(toolResult.toolCallId).toBe('tool-1')
expect(toolResult.toolCallId).toBe('session-123:tool-1')
expect(toolResult.toolName).toBe('Bash')
expect(toolResult.input).toEqual({ command: 'ls' })
expect(toolResult.output).toBe('ok')
})
it('handles tool calls without streaming events (no content_block_start/stop)', () => {
const state = new ClaudeStreamState({ agentSessionId: '12344' })
const parts: ReturnType<typeof transformSDKMessageToStreamParts>[number][] = []
const messages: SDKMessage[] = [
{
...baseStreamMetadata,
type: 'assistant',
uuid: uuid(20),
message: {
id: 'msg-tool-no-stream',
type: 'message',
role: 'assistant',
model: 'claude-test',
content: [
{
type: 'tool_use',
id: 'tool-read',
name: 'Read',
input: { file_path: '/test.txt' }
},
{
type: 'tool_use',
id: 'tool-bash',
name: 'Bash',
input: { command: 'ls -la' }
}
],
stop_reason: 'tool_use',
stop_sequence: null,
usage: {
input_tokens: 10,
output_tokens: 20
}
}
} as unknown as SDKMessage,
{
...baseStreamMetadata,
type: 'user',
uuid: uuid(21),
message: {
role: 'user',
content: [
{
type: 'tool_result',
tool_use_id: 'tool-read',
content: 'file contents',
is_error: false
}
]
}
} as SDKMessage,
{
...baseStreamMetadata,
type: 'user',
uuid: uuid(22),
message: {
role: 'user',
content: [
{
type: 'tool_result',
tool_use_id: 'tool-bash',
content: 'total 42\n...',
is_error: false
}
]
}
} as SDKMessage
]
for (const message of messages) {
const transformed = transformSDKMessageToStreamParts(message, state)
parts.push(...transformed)
}
const types = parts.map((part) => part.type)
expect(types).toEqual(['tool-call', 'tool-call', 'tool-result', 'tool-result'])
const toolCalls = parts.filter((part) => part.type === 'tool-call') as Extract<
(typeof parts)[number],
{ type: 'tool-call' }
>[]
expect(toolCalls).toHaveLength(2)
expect(toolCalls[0].toolName).toBe('Read')
expect(toolCalls[0].toolCallId).toBe('12344:tool-read')
expect(toolCalls[1].toolName).toBe('Bash')
expect(toolCalls[1].toolCallId).toBe('12344:tool-bash')
const toolResults = parts.filter((part) => part.type === 'tool-result') as Extract<
(typeof parts)[number],
{ type: 'tool-result' }
>[]
expect(toolResults).toHaveLength(2)
// This is the key assertion - toolName should NOT be 'unknown'
expect(toolResults[0].toolName).toBe('Read')
expect(toolResults[0].toolCallId).toBe('12344:tool-read')
expect(toolResults[0].input).toEqual({ file_path: '/test.txt' })
expect(toolResults[0].output).toBe('file contents')
expect(toolResults[1].toolName).toBe('Bash')
expect(toolResults[1].toolCallId).toBe('12344:tool-bash')
expect(toolResults[1].input).toEqual({ command: 'ls -la' })
expect(toolResults[1].output).toBe('total 42\n...')
})
it('handles streaming text completion', () => {
const state = new ClaudeStreamState()
const state = new ClaudeStreamState({ agentSessionId: baseStreamMetadata.session_id })
const parts: ReturnType<typeof transformSDKMessageToStreamParts>[number][] = []
const messages: SDKMessage[] = [
@@ -300,4 +410,87 @@ describe('Claude → AiSDK transform', () => {
expect(finishStep.finishReason).toBe('stop')
expect(finishStep.usage).toEqual({ inputTokens: 2, outputTokens: 4, totalTokens: 6 })
})
it('emits fallback text when Claude sends a snapshot instead of deltas', () => {
const state = new ClaudeStreamState({ agentSessionId: '12344' })
const parts: ReturnType<typeof transformSDKMessageToStreamParts>[number][] = []
const messages: SDKMessage[] = [
{
...baseStreamMetadata,
type: 'stream_event',
uuid: uuid(30),
event: {
type: 'message_start',
message: {
id: 'msg-fallback',
type: 'message',
role: 'assistant',
model: 'claude-test',
content: [],
stop_reason: null,
stop_sequence: null,
usage: {}
}
}
} as unknown as SDKMessage,
{
...baseStreamMetadata,
type: 'stream_event',
uuid: uuid(31),
event: {
type: 'content_block_start',
index: 0,
content_block: {
type: 'text',
text: ''
}
}
} as unknown as SDKMessage,
{
...baseStreamMetadata,
type: 'assistant',
uuid: uuid(32),
message: {
id: 'msg-fallback-content',
type: 'message',
role: 'assistant',
model: 'claude-test',
content: [
{
type: 'text',
text: 'Final answer without streaming deltas.'
}
],
stop_reason: 'end_turn',
stop_sequence: null,
usage: {
input_tokens: 3,
output_tokens: 7
}
}
} as unknown as SDKMessage
]
for (const message of messages) {
const transformed = transformSDKMessageToStreamParts(message, state)
parts.push(...transformed)
}
const types = parts.map((part) => part.type)
expect(types).toEqual(['start-step', 'text-start', 'text-delta', 'text-end', 'finish-step'])
const delta = parts.find((part) => part.type === 'text-delta') as Extract<
(typeof parts)[number],
{ type: 'text-delta' }
>
expect(delta.text).toBe('Final answer without streaming deltas.')
const finish = parts.find((part) => part.type === 'finish-step') as Extract<
(typeof parts)[number],
{ type: 'finish-step' }
>
expect(finish.usage).toEqual({ inputTokens: 3, outputTokens: 7, totalTokens: 10 })
expect(finish.finishReason).toBe('stop')
})
})

View File

@@ -10,8 +10,21 @@
* Every Claude turn gets its own instance. `resetStep` should be invoked once the finish event has
* been emitted to avoid leaking state into the next turn.
*/
import { loggerService } from '@logger'
import type { FinishReason, LanguageModelUsage, ProviderMetadata } from 'ai'
/**
* Builds a namespaced tool call ID by combining session ID with raw tool call ID.
* This ensures tool calls from different sessions don't conflict even if they have
* the same raw ID from the SDK.
*
* @param sessionId - The agent session ID
* @param rawToolCallId - The raw tool call ID from SDK (e.g., "WebFetch_0")
*/
export function buildNamespacedToolCallId(sessionId: string, rawToolCallId: string): string {
return `${sessionId}:${rawToolCallId}`
}
/**
* Shared fields for every block that Claude can stream (text, reasoning, tool).
*/
@@ -34,6 +47,7 @@ type ReasoningBlockState = BaseBlockState & {
type ToolBlockState = BaseBlockState & {
kind: 'tool'
toolCallId: string
rawToolCallId: string
toolName: string
inputBuffer: string
providerMetadata?: ProviderMetadata
@@ -48,12 +62,17 @@ type PendingUsageState = {
}
type PendingToolCall = {
rawToolCallId: string
toolCallId: string
toolName: string
input: unknown
providerMetadata?: ProviderMetadata
}
type ClaudeStreamStateOptions = {
agentSessionId: string
}
/**
* Tracks the lifecycle of Claude streaming blocks (text, thinking, tool calls)
* across individual websocket events. The transformer relies on this class to
@@ -61,12 +80,20 @@ type PendingToolCall = {
* usage/finish metadata once Anthropic closes a message.
*/
export class ClaudeStreamState {
private logger
private readonly agentSessionId: string
private blocksByIndex = new Map<number, BlockState>()
private toolIndexById = new Map<string, number>()
private toolIndexByNamespacedId = new Map<string, number>()
private pendingUsage: PendingUsageState = {}
private pendingToolCalls = new Map<string, PendingToolCall>()
private stepActive = false
constructor(options: ClaudeStreamStateOptions) {
this.logger = loggerService.withContext('ClaudeStreamState')
this.agentSessionId = options.agentSessionId
this.logger.silly('ClaudeStreamState', options)
}
/** Marks the beginning of a new AiSDK step. */
beginStep(): void {
this.stepActive = true
@@ -104,19 +131,21 @@ export class ClaudeStreamState {
/** Caches tool metadata so subsequent input deltas and results can find it. */
openToolBlock(
index: number,
params: { toolCallId: string; toolName: string; providerMetadata?: ProviderMetadata }
params: { rawToolCallId: string; toolName: string; providerMetadata?: ProviderMetadata }
): ToolBlockState {
const toolCallId = buildNamespacedToolCallId(this.agentSessionId, params.rawToolCallId)
const block: ToolBlockState = {
kind: 'tool',
id: params.toolCallId,
id: toolCallId,
index,
toolCallId: params.toolCallId,
toolCallId,
rawToolCallId: params.rawToolCallId,
toolName: params.toolName,
inputBuffer: '',
providerMetadata: params.providerMetadata
}
this.blocksByIndex.set(index, block)
this.toolIndexById.set(params.toolCallId, index)
this.toolIndexByNamespacedId.set(toolCallId, index)
return block
}
@@ -124,14 +153,32 @@ export class ClaudeStreamState {
return this.blocksByIndex.get(index)
}
getFirstOpenTextBlock(): TextBlockState | undefined {
const candidates: TextBlockState[] = []
for (const block of this.blocksByIndex.values()) {
if (block.kind === 'text') {
candidates.push(block)
}
}
if (candidates.length === 0) {
return undefined
}
candidates.sort((a, b) => a.index - b.index)
return candidates[0]
}
getToolBlockById(toolCallId: string): ToolBlockState | undefined {
const index = this.toolIndexById.get(toolCallId)
const index = this.toolIndexByNamespacedId.get(toolCallId)
if (index === undefined) return undefined
const block = this.blocksByIndex.get(index)
if (!block || block.kind !== 'tool') return undefined
return block
}
getToolBlockByRawId(rawToolCallId: string): ToolBlockState | undefined {
return this.getToolBlockById(buildNamespacedToolCallId(this.agentSessionId, rawToolCallId))
}
/** Appends streamed text to a text block, returning the updated state when present. */
appendTextDelta(index: number, text: string): TextBlockState | undefined {
const block = this.blocksByIndex.get(index)
@@ -158,10 +205,12 @@ export class ClaudeStreamState {
/** Records a tool call to be consumed once its result arrives from the user. */
registerToolCall(
toolCallId: string,
rawToolCallId: string,
payload: { toolName: string; input: unknown; providerMetadata?: ProviderMetadata }
): void {
this.pendingToolCalls.set(toolCallId, {
const toolCallId = buildNamespacedToolCallId(this.agentSessionId, rawToolCallId)
this.pendingToolCalls.set(rawToolCallId, {
rawToolCallId,
toolCallId,
toolName: payload.toolName,
input: payload.input,
@@ -170,10 +219,10 @@ export class ClaudeStreamState {
}
/** Retrieves and clears the buffered tool call metadata for the given id. */
consumePendingToolCall(toolCallId: string): PendingToolCall | undefined {
const entry = this.pendingToolCalls.get(toolCallId)
consumePendingToolCall(rawToolCallId: string): PendingToolCall | undefined {
const entry = this.pendingToolCalls.get(rawToolCallId)
if (entry) {
this.pendingToolCalls.delete(toolCallId)
this.pendingToolCalls.delete(rawToolCallId)
}
return entry
}
@@ -182,13 +231,13 @@ export class ClaudeStreamState {
* Persists the final input payload for a tool block once the provider signals
* completion so that downstream tool results can reference the original call.
*/
completeToolBlock(toolCallId: string, input: unknown, providerMetadata?: ProviderMetadata): void {
completeToolBlock(toolCallId: string, toolName: string, input: unknown, providerMetadata?: ProviderMetadata): void {
const block = this.getToolBlockByRawId(toolCallId)
this.registerToolCall(toolCallId, {
toolName: this.getToolBlockById(toolCallId)?.toolName ?? 'unknown',
toolName,
input,
providerMetadata
})
const block = this.getToolBlockById(toolCallId)
if (block) {
block.resolvedInput = input
}
@@ -200,7 +249,7 @@ export class ClaudeStreamState {
if (!block) return undefined
this.blocksByIndex.delete(index)
if (block.kind === 'tool') {
this.toolIndexById.delete(block.toolCallId)
this.toolIndexByNamespacedId.delete(block.toolCallId)
}
return block
}
@@ -227,7 +276,7 @@ export class ClaudeStreamState {
/** Drops cached block metadata for the currently active message. */
resetBlocks(): void {
this.blocksByIndex.clear()
this.toolIndexById.clear()
this.toolIndexByNamespacedId.clear()
}
/** Resets the entire step lifecycle after emitting a terminal frame. */
@@ -236,6 +285,10 @@ export class ClaudeStreamState {
this.resetPendingUsage()
this.stepActive = false
}
getNamespacedToolCallId(rawToolCallId: string): string {
return buildNamespacedToolCallId(this.agentSessionId, rawToolCallId)
}
}
export type { PendingToolCall }

View File

@@ -2,7 +2,14 @@
import { EventEmitter } from 'node:events'
import { createRequire } from 'node:module'
import type { CanUseTool, McpHttpServerConfig, Options, SDKMessage } from '@anthropic-ai/claude-agent-sdk'
import type {
CanUseTool,
HookCallback,
McpHttpServerConfig,
Options,
PreToolUseHookInput,
SDKMessage
} from '@anthropic-ai/claude-agent-sdk'
import { query } from '@anthropic-ai/claude-agent-sdk'
import { loggerService } from '@logger'
import { config as apiConfigService } from '@main/apiServer/config'
@@ -13,6 +20,7 @@ import { app } from 'electron'
import type { GetAgentSessionResponse } from '../..'
import type { AgentServiceInterface, AgentStream, AgentStreamEvent } from '../../interfaces/AgentStreamInterface'
import { sessionService } from '../SessionService'
import { buildNamespacedToolCallId } from './claude-stream-state'
import { promptForToolApproval } from './tool-permissions'
import { ClaudeStreamState, transformSDKMessageToStreamParts } from './transform'
@@ -150,7 +158,67 @@ class ClaudeCodeService implements AgentServiceInterface {
return { behavior: 'allow', updatedInput: input }
}
return promptForToolApproval(toolName, input, options)
return promptForToolApproval(toolName, input, {
...options,
toolCallId: buildNamespacedToolCallId(session.id, options.toolUseID)
})
}
const preToolUseHook: HookCallback = async (input, toolUseID, options) => {
// Type guard to ensure we're handling PreToolUse event
if (input.hook_event_name !== 'PreToolUse') {
return {}
}
const hookInput = input as PreToolUseHookInput
const toolName = hookInput.tool_name
logger.debug('PreToolUse hook triggered', {
session_id: hookInput.session_id,
tool_name: hookInput.tool_name,
tool_use_id: toolUseID,
tool_input: hookInput.tool_input,
cwd: hookInput.cwd,
permission_mode: hookInput.permission_mode,
autoAllowTools: autoAllowTools
})
if (options?.signal?.aborted) {
logger.debug('PreToolUse hook signal already aborted; skipping tool use', {
tool_name: hookInput.tool_name
})
return {}
}
// handle auto approved tools since it never triggers canUseTool
const normalizedToolName = normalizeToolName(toolName)
if (toolUseID) {
const bypassAll = input.permission_mode === 'bypassPermissions'
const autoAllowed = autoAllowTools.has(toolName) || autoAllowTools.has(normalizedToolName)
if (bypassAll || autoAllowed) {
const namespacedToolCallId = buildNamespacedToolCallId(session.id, toolUseID)
logger.debug('handling auto approved tools', {
toolName,
normalizedToolName,
namespacedToolCallId,
permission_mode: input.permission_mode,
autoAllowTools
})
const isRecord = (v: unknown): v is Record<string, unknown> => {
return !!v && typeof v === 'object' && !Array.isArray(v)
}
const toolInput = isRecord(input.tool_input) ? input.tool_input : {}
await promptForToolApproval(toolName, toolInput, {
...options,
toolCallId: namespacedToolCallId,
autoApprove: true
})
}
}
// Return to proceed without modification
return {}
}
// Build SDK options from parameters
@@ -176,7 +244,14 @@ class ClaudeCodeService implements AgentServiceInterface {
permissionMode: session.configuration?.permission_mode,
maxTurns: session.configuration?.max_turns,
allowedTools: session.allowed_tools,
canUseTool
canUseTool,
hooks: {
PreToolUse: [
{
hooks: [preToolUseHook]
}
]
}
}
if (session.accessible_paths.length > 1) {
@@ -346,7 +421,7 @@ class ClaudeCodeService implements AgentServiceInterface {
const jsonOutput: SDKMessage[] = []
let hasCompleted = false
const startTime = Date.now()
const streamState = new ClaudeStreamState()
const streamState = new ClaudeStreamState({ agentSessionId: sessionId })
try {
for await (const message of query({ prompt: promptStream, options })) {
@@ -410,23 +485,6 @@ class ClaudeCodeService implements AgentServiceInterface {
}
}
if (message.type === 'assistant' || message.type === 'user') {
logger.silly('claude response', {
message,
content: JSON.stringify(message.message.content)
})
} else if (message.type === 'stream_event') {
// logger.silly('Claude stream event', {
// message,
// event: JSON.stringify(message.event)
// })
} else {
logger.silly('Claude response', {
message,
event: JSON.stringify(message)
})
}
const chunks = transformSDKMessageToStreamParts(message, streamState)
for (const chunk of chunks) {
stream.emit('data', {

View File

@@ -31,12 +31,14 @@ type PendingPermissionRequest = {
abortListener?: () => void
originalInput: Record<string, unknown>
toolName: string
toolCallId?: string
}
type RendererPermissionRequestPayload = {
requestId: string
toolName: string
toolId: string
toolCallId: string
description?: string
requiresPermissions: boolean
input: Record<string, unknown>
@@ -44,6 +46,7 @@ type RendererPermissionRequestPayload = {
createdAt: number
expiresAt: number
suggestions: PermissionUpdate[]
autoApprove?: boolean
}
type RendererPermissionResultPayload = {
@@ -51,6 +54,7 @@ type RendererPermissionResultPayload = {
behavior: ToolPermissionBehavior
message?: string
reason: 'response' | 'timeout' | 'aborted' | 'no-window'
toolCallId?: string
}
const pendingRequests = new Map<string, PendingPermissionRequest>()
@@ -144,7 +148,8 @@ const finalizeRequest = (
requestId,
behavior: update.behavior,
message: update.behavior === 'deny' ? update.message : undefined,
reason
reason,
toolCallId: pending.toolCallId
}
const dispatched = broadcastToRenderer(IpcChannel.AgentToolPermission_Result, resultPayload)
@@ -206,10 +211,20 @@ const ensureIpcHandlersRegistered = () => {
})
}
type PromptForToolApprovalOptions = {
signal: AbortSignal
suggestions?: PermissionUpdate[]
autoApprove?: boolean
// NOTICE: This ID is namespaced with session ID, not the raw SDK tool call ID.
// Format: `${sessionId}:${rawToolCallId}`, e.g., `session_123:WebFetch_0`
toolCallId: string
}
export async function promptForToolApproval(
toolName: string,
input: Record<string, unknown>,
options?: { signal: AbortSignal; suggestions?: PermissionUpdate[] }
options: PromptForToolApprovalOptions
): Promise<PermissionResult> {
if (shouldAutoApproveTools) {
logger.debug('promptForToolApproval auto-approving tool for test', {
@@ -245,6 +260,7 @@ export async function promptForToolApproval(
logger.info('Requesting user approval for tool usage', {
requestId,
toolName,
toolCallId: options.toolCallId,
description: toolMetadata?.description
})
@@ -252,13 +268,15 @@ export async function promptForToolApproval(
requestId,
toolName,
toolId: toolMetadata?.id ?? toolName,
toolCallId: options.toolCallId,
description: toolMetadata?.description,
requiresPermissions: toolMetadata?.requirePermissions ?? false,
input: sanitizedInput,
inputPreview,
createdAt,
expiresAt,
suggestions: sanitizedSuggestions
suggestions: sanitizedSuggestions,
autoApprove: options.autoApprove
}
const defaultDenyUpdate: PermissionResult = { behavior: 'deny', message: 'Tool request aborted before user decision' }
@@ -266,6 +284,7 @@ export async function promptForToolApproval(
logger.debug('Registering tool permission request', {
requestId,
toolName,
toolCallId: options.toolCallId,
requiresPermissions: requestPayload.requiresPermissions,
timeoutMs: TOOL_APPROVAL_TIMEOUT_MS,
suggestionCount: sanitizedSuggestions.length
@@ -273,7 +292,11 @@ export async function promptForToolApproval(
return new Promise<PermissionResult>((resolve) => {
const timeout = setTimeout(() => {
logger.info('User tool permission request timed out', { requestId, toolName })
logger.info('User tool permission request timed out', {
requestId,
toolName,
toolCallId: options.toolCallId
})
finalizeRequest(requestId, { behavior: 'deny', message: 'Timed out waiting for approval' }, 'timeout')
}, TOOL_APPROVAL_TIMEOUT_MS)
@@ -282,12 +305,17 @@ export async function promptForToolApproval(
timeout,
originalInput: sanitizedInput,
toolName,
signal: options?.signal
signal: options?.signal,
toolCallId: options.toolCallId
}
if (options?.signal) {
const abortListener = () => {
logger.info('Tool permission request aborted before user responded', { requestId, toolName })
logger.info('Tool permission request aborted before user responded', {
requestId,
toolName,
toolCallId: options.toolCallId
})
finalizeRequest(requestId, defaultDenyUpdate, 'aborted')
}

View File

@@ -110,7 +110,7 @@ const sdkMessageToProviderMetadata = (message: SDKMessage): ProviderMetadata =>
* blocks across calls so that incremental deltas can be correlated correctly.
*/
export function transformSDKMessageToStreamParts(sdkMessage: SDKMessage, state: ClaudeStreamState): AgentStreamPart[] {
logger.silly('Transforming SDKMessage', { message: sdkMessage })
logger.silly('Transforming SDKMessage', { message: JSON.stringify(sdkMessage) })
switch (sdkMessage.type) {
case 'assistant':
return handleAssistantMessage(sdkMessage, state)
@@ -186,14 +186,13 @@ function handleAssistantMessage(
for (const block of content) {
switch (block.type) {
case 'text':
if (!isStreamingActive) {
const sanitizedText = stripLocalCommandTags(block.text)
if (sanitizedText) {
textBlocks.push(sanitizedText)
}
case 'text': {
const sanitizedText = stripLocalCommandTags(block.text)
if (sanitizedText) {
textBlocks.push(sanitizedText)
}
break
}
case 'tool_use':
handleAssistantToolUse(block as ToolUseContent, providerMetadata, state, chunks)
break
@@ -203,7 +202,16 @@ function handleAssistantMessage(
}
}
if (!isStreamingActive && textBlocks.length > 0) {
if (textBlocks.length === 0) {
return chunks
}
const combinedText = textBlocks.join('')
if (!combinedText) {
return chunks
}
if (!isStreamingActive) {
const id = message.uuid?.toString() || generateMessageId()
state.beginStep()
chunks.push({
@@ -219,7 +227,7 @@ function handleAssistantMessage(
chunks.push({
type: 'text-delta',
id,
text: textBlocks.join(''),
text: combinedText,
providerMetadata
})
chunks.push({
@@ -230,7 +238,27 @@ function handleAssistantMessage(
return finalizeNonStreamingStep(message, state, chunks)
}
return chunks
const existingTextBlock = state.getFirstOpenTextBlock()
const fallbackId = existingTextBlock?.id || message.uuid?.toString() || generateMessageId()
if (!existingTextBlock) {
chunks.push({
type: 'text-start',
id: fallbackId,
providerMetadata
})
}
chunks.push({
type: 'text-delta',
id: fallbackId,
text: combinedText,
providerMetadata
})
chunks.push({
type: 'text-end',
id: fallbackId,
providerMetadata
})
return finalizeNonStreamingStep(message, state, chunks)
}
/**
@@ -243,15 +271,16 @@ function handleAssistantToolUse(
state: ClaudeStreamState,
chunks: AgentStreamPart[]
): void {
const toolCallId = state.getNamespacedToolCallId(block.id)
chunks.push({
type: 'tool-call',
toolCallId: block.id,
toolCallId,
toolName: block.name,
input: block.input,
providerExecuted: true,
providerMetadata
})
state.completeToolBlock(block.id, block.input, providerMetadata)
state.completeToolBlock(block.id, block.name, block.input, providerMetadata)
}
/**
@@ -331,10 +360,11 @@ function handleUserMessage(
if (block.type === 'tool_result') {
const toolResult = block as ToolResultContent
const pendingCall = state.consumePendingToolCall(toolResult.tool_use_id)
const toolCallId = pendingCall?.toolCallId ?? state.getNamespacedToolCallId(toolResult.tool_use_id)
if (toolResult.is_error) {
chunks.push({
type: 'tool-error',
toolCallId: toolResult.tool_use_id,
toolCallId,
toolName: pendingCall?.toolName ?? 'unknown',
input: pendingCall?.input,
error: toolResult.content,
@@ -343,7 +373,7 @@ function handleUserMessage(
} else {
chunks.push({
type: 'tool-result',
toolCallId: toolResult.tool_use_id,
toolCallId,
toolName: pendingCall?.toolName ?? 'unknown',
input: pendingCall?.input,
output: toolResult.content,
@@ -457,6 +487,9 @@ function handleStreamEvent(
}
case 'message_stop': {
if (!state.hasActiveStep()) {
break
}
const pending = state.getPendingUsage()
chunks.push({
type: 'finish-step',
@@ -514,7 +547,7 @@ function handleContentBlockStart(
}
case 'tool_use': {
const block = state.openToolBlock(index, {
toolCallId: contentBlock.id,
rawToolCallId: contentBlock.id,
toolName: contentBlock.name,
providerMetadata
})

View File

@@ -111,6 +111,7 @@ const api = {
setFullScreen: (value: boolean): Promise<void> => ipcRenderer.invoke(IpcChannel.App_SetFullScreen, value),
isFullScreen: (): Promise<boolean> => ipcRenderer.invoke(IpcChannel.App_IsFullScreen),
getSystemFonts: (): Promise<string[]> => ipcRenderer.invoke(IpcChannel.App_GetSystemFonts),
mockCrashRenderProcess: () => ipcRenderer.invoke(IpcChannel.APP_CrashRenderProcess),
mac: {
isProcessTrusted: (): Promise<boolean> => ipcRenderer.invoke(IpcChannel.App_MacIsProcessTrusted),
requestProcessTrust: (): Promise<boolean> => ipcRenderer.invoke(IpcChannel.App_MacRequestProcessTrust)
@@ -121,7 +122,8 @@ const api = {
system: {
getDeviceType: () => ipcRenderer.invoke(IpcChannel.System_GetDeviceType),
getHostname: () => ipcRenderer.invoke(IpcChannel.System_GetHostname),
getCpuName: () => ipcRenderer.invoke(IpcChannel.System_GetCpuName)
getCpuName: () => ipcRenderer.invoke(IpcChannel.System_GetCpuName),
checkGitBash: (): Promise<boolean> => ipcRenderer.invoke(IpcChannel.System_CheckGitBash)
},
devTools: {
toggle: () => ipcRenderer.invoke(IpcChannel.System_ToggleDevTools)

View File

@@ -32,6 +32,7 @@ import {
prepareSpecialProviderConfig,
providerToAiSdkConfig
} from './provider/providerConfig'
import type { AiSdkConfig } from './types'
const logger = loggerService.withContext('ModernAiProvider')
@@ -44,7 +45,7 @@ export type ModernAiProviderConfig = AiSdkMiddlewareConfig & {
export default class ModernAiProvider {
private legacyProvider: LegacyAiProvider
private config?: ReturnType<typeof providerToAiSdkConfig>
private config?: AiSdkConfig
private actualProvider: Provider
private model?: Model
private localProvider: Awaited<AiSdkProvider> | null = null
@@ -89,6 +90,11 @@ export default class ModernAiProvider {
// 每次请求时重新生成配置以确保API key轮换生效
this.config = providerToAiSdkConfig(this.actualProvider, this.model)
logger.debug('Generated provider config for completions', this.config)
// 检查 config 是否存在
if (!this.config) {
throw new Error('Provider config is undefined; cannot proceed with completions')
}
if (SUPPORTED_IMAGE_ENDPOINT_LIST.includes(this.config.options.endpoint)) {
providerConfig.isImageGenerationEndpoint = true
}
@@ -149,7 +155,8 @@ export default class ModernAiProvider {
params: StreamTextParams,
config: ModernAiProviderConfig
): Promise<CompletionsResult> {
if (config.isImageGenerationEndpoint) {
// ai-gateway不是image/generation 端点所以就先不走legacy了
if (config.isImageGenerationEndpoint && config.provider!.id !== SystemProviderIds['ai-gateway']) {
// 使用 legacy 实现处理图像生成(支持图片编辑等高级功能)
if (!config.uiMessages) {
throw new Error('uiMessages is required for image generation endpoint')
@@ -463,8 +470,13 @@ export default class ModernAiProvider {
// 如果支持新的 AI SDK使用现代化实现
if (isModernSdkSupported(this.actualProvider)) {
try {
// 确保 config 已定义
if (!this.config) {
throw new Error('Provider config is undefined; cannot proceed with generateImage')
}
// 确保本地provider已创建
if (!this.localProvider) {
if (!this.localProvider && this.config) {
this.localProvider = await createAiSdkProvider(this.config)
if (!this.localProvider) {
throw new Error('Local provider not created')

View File

@@ -1,6 +1,6 @@
import { loggerService } from '@logger'
import { isNewApiProvider } from '@renderer/config/providers'
import type { Provider } from '@renderer/types'
import { isNewApiProvider } from '@renderer/utils/provider'
import { AihubmixAPIClient } from './aihubmix/AihubmixAPIClient'
import { AnthropicAPIClient } from './anthropic/AnthropicAPIClient'

View File

@@ -1,12 +1,12 @@
import { loggerService } from '@logger'
import {
getModelSupportedVerbosity,
isFunctionCallingModel,
isNotSupportTemperatureAndTopP,
isOpenAIModel,
isSupportFlexServiceTierModel
} from '@renderer/config/models'
import { REFERENCE_PROMPT } from '@renderer/config/prompts'
import { isSupportServiceTierProvider } from '@renderer/config/providers'
import { getLMStudioKeepAliveTime } from '@renderer/hooks/useLMStudio'
import { getAssistantSettings } from '@renderer/services/AssistantService'
import type {
@@ -18,7 +18,6 @@ import type {
MCPToolResponse,
MemoryItem,
Model,
OpenAIVerbosity,
Provider,
ToolCallResponse,
WebSearchProviderResponse,
@@ -32,6 +31,7 @@ import {
OpenAIServiceTiers,
SystemProviderIds
} from '@renderer/types'
import type { OpenAIVerbosity } from '@renderer/types/aiCoreTypes'
import type { Message } from '@renderer/types/newMessage'
import type {
RequestOptions,
@@ -47,6 +47,7 @@ import type {
import { isJSON, parseJSON } from '@renderer/utils'
import { addAbortController, removeAbortController } from '@renderer/utils/abortController'
import { findFileBlocks, getMainTextContent } from '@renderer/utils/messageUtils/find'
import { isSupportServiceTierProvider } from '@renderer/utils/provider'
import { defaultTimeout } from '@shared/config/constant'
import { defaultAppHeaders } from '@shared/utils'
import { isEmpty } from 'lodash'
@@ -242,12 +243,18 @@ export abstract class BaseApiClient<
return serviceTierSetting
}
protected getVerbosity(): OpenAIVerbosity {
protected getVerbosity(model?: Model): OpenAIVerbosity {
try {
const state = window.store?.getState()
const verbosity = state?.settings?.openAI?.verbosity
if (verbosity && ['low', 'medium', 'high'].includes(verbosity)) {
// If model is provided, check if the verbosity is supported by the model
if (model) {
const supportedVerbosity = getModelSupportedVerbosity(model)
// Use user's verbosity if supported, otherwise use the first supported option
return supportedVerbosity.includes(verbosity) ? verbosity : supportedVerbosity[0]
}
return verbosity
}
} catch (error) {

View File

@@ -58,10 +58,27 @@ vi.mock('../aws/AwsBedrockAPIClient', () => ({
AwsBedrockAPIClient: vi.fn().mockImplementation(() => ({}))
}))
vi.mock('@renderer/services/AssistantService.ts', () => ({
getDefaultAssistant: () => {
return {
id: 'default',
name: 'default',
emoji: '😀',
prompt: '',
topics: [],
messages: [],
type: 'assistant',
regularPhrases: [],
settings: {}
}
}
}))
// Mock the models config to prevent circular dependency issues
vi.mock('@renderer/config/models', () => ({
findTokenLimit: vi.fn(),
isReasoningModel: vi.fn(),
isOpenAILLMModel: vi.fn(),
SYSTEM_MODELS: {
silicon: [],
defaultModel: []

View File

@@ -1,7 +1,8 @@
import { GoogleGenAI } from '@google/genai'
import { loggerService } from '@logger'
import { createVertexProvider, isVertexAIConfigured, isVertexProvider } from '@renderer/hooks/useVertexAI'
import { createVertexProvider, isVertexAIConfigured } from '@renderer/hooks/useVertexAI'
import type { Model, Provider, VertexProvider } from '@renderer/types'
import { isVertexProvider } from '@renderer/utils/provider'
import { isEmpty } from 'lodash'
import { AnthropicVertexClient } from '../anthropic/AnthropicVertexClient'

View File

@@ -10,7 +10,6 @@ import { DEFAULT_MAX_TOKENS } from '@renderer/config/constant'
import {
findTokenLimit,
GEMINI_FLASH_MODEL_REGEX,
getOpenAIWebSearchParams,
getThinkModelType,
isClaudeReasoningModel,
isDeepSeekHybridInferenceModel,
@@ -35,16 +34,11 @@ import {
isSupportedThinkingTokenModel,
isSupportedThinkingTokenQwenModel,
isSupportedThinkingTokenZhipuModel,
isSupportVerbosityModel,
isVisionModel,
MODEL_SUPPORTED_REASONING_EFFORT,
ZHIPU_RESULT_TOKENS
} from '@renderer/config/models'
import {
isSupportArrayContentProvider,
isSupportDeveloperRoleProvider,
isSupportEnableThinkingProvider,
isSupportStreamOptionsProvider
} from '@renderer/config/providers'
import { mapLanguageToQwenMTModel } from '@renderer/config/translate'
import { processPostsuffixQwen3Model, processReqMessages } from '@renderer/services/ModelMessageService'
import { estimateTextTokens } from '@renderer/services/TokenService'
@@ -88,6 +82,12 @@ import {
openAIToolsToMcpTool
} from '@renderer/utils/mcp-tools'
import { findFileBlocks, findImageBlocks } from '@renderer/utils/messageUtils/find'
import {
isSupportArrayContentProvider,
isSupportDeveloperRoleProvider,
isSupportEnableThinkingProvider,
isSupportStreamOptionsProvider
} from '@renderer/utils/provider'
import { t } from 'i18next'
import type { GenericChunk } from '../../middleware/schemas'
@@ -733,9 +733,16 @@ export class OpenAIAPIClient extends OpenAIBaseClient<
...modalities,
// groq 有不同的 service tier 配置,不符合 openai 接口类型
service_tier: this.getServiceTier(model) as OpenAIServiceTier,
...(isSupportVerbosityModel(model)
? {
text: {
verbosity: this.getVerbosity(model)
}
}
: {}),
...this.getProviderSpecificParameters(assistant, model),
...reasoningEffort,
...getOpenAIWebSearchParams(model, enableWebSearch),
// ...getOpenAIWebSearchParams(model, enableWebSearch),
// OpenRouter usage tracking
...(this.provider.id === 'openrouter' ? { usage: { include: true } } : {}),
...extra_body,

View File

@@ -48,9 +48,8 @@ export abstract class OpenAIBaseClient<
}
// 仅适用于openai
override getBaseURL(): string {
const host = this.provider.apiHost
return formatApiHost(host)
override getBaseURL(isSupportedAPIVerion: boolean = true): string {
return formatApiHost(this.provider.apiHost, isSupportedAPIVerion)
}
override async generateImage({
@@ -144,6 +143,11 @@ export abstract class OpenAIBaseClient<
}
let apiKeyForSdkInstance = this.apiKey
let baseURLForSdkInstance = this.getBaseURL()
let headersForSdkInstance = {
...this.defaultHeaders(),
...this.provider.extra_headers
}
if (this.provider.id === 'copilot') {
const defaultHeaders = store.getState().copilot.defaultHeaders
@@ -151,6 +155,11 @@ export abstract class OpenAIBaseClient<
// this.provider.apiKey不允许修改
// this.provider.apiKey = token
apiKeyForSdkInstance = token
baseURLForSdkInstance = this.getBaseURL(false)
headersForSdkInstance = {
...headersForSdkInstance,
...COPILOT_DEFAULT_HEADERS
}
}
if (this.provider.id === 'azure-openai' || this.provider.type === 'azure-openai') {
@@ -164,12 +173,8 @@ export abstract class OpenAIBaseClient<
this.sdkInstance = new OpenAI({
dangerouslyAllowBrowser: true,
apiKey: apiKeyForSdkInstance,
baseURL: this.getBaseURL(),
defaultHeaders: {
...this.defaultHeaders(),
...this.provider.extra_headers,
...(this.provider.id === 'copilot' ? COPILOT_DEFAULT_HEADERS : {})
}
baseURL: baseURLForSdkInstance,
defaultHeaders: headersForSdkInstance
}) as TSdkInstance
}
return this.sdkInstance

View File

@@ -12,7 +12,6 @@ import {
isSupportVerbosityModel,
isVisionModel
} from '@renderer/config/models'
import { isSupportDeveloperRoleProvider } from '@renderer/config/providers'
import { estimateTextTokens } from '@renderer/services/TokenService'
import type {
FileMetadata,
@@ -43,6 +42,7 @@ import {
openAIToolsToMcpTool
} from '@renderer/utils/mcp-tools'
import { findFileBlocks, findImageBlocks } from '@renderer/utils/messageUtils/find'
import { isSupportDeveloperRoleProvider } from '@renderer/utils/provider'
import { MB } from '@shared/config/constant'
import { t } from 'i18next'
import { isEmpty } from 'lodash'
@@ -90,7 +90,7 @@ export class OpenAIResponseAPIClient extends OpenAIBaseClient<
if (isOpenAILLMModel(model) && !isOpenAIChatCompletionOnlyModel(model)) {
if (this.provider.id === 'azure-openai' || this.provider.type === 'azure-openai') {
this.provider = { ...this.provider, apiHost: this.formatApiHost() }
if (this.provider.apiVersion === 'preview') {
if (this.provider.apiVersion === 'preview' || this.provider.apiVersion === 'v1') {
return this
} else {
return this.client
@@ -297,7 +297,31 @@ export class OpenAIResponseAPIClient extends OpenAIBaseClient<
private convertResponseToMessageContent(response: OpenAI.Responses.Response): ResponseInput {
const content: OpenAI.Responses.ResponseInput = []
content.push(...response.output)
response.output.forEach((item) => {
if (item.type !== 'apply_patch_call' && item.type !== 'apply_patch_call_output') {
content.push(item)
} else if (item.type === 'apply_patch_call') {
if (item.operation !== undefined) {
const applyPatchToolCall: OpenAI.Responses.ResponseInputItem.ApplyPatchCall = {
...item,
operation: item.operation
}
content.push(applyPatchToolCall)
} else {
logger.warn('Undefined tool call operation for ApplyPatchToolCall.')
}
} else if (item.type === 'apply_patch_call_output') {
if (item.output !== undefined) {
const applyPatchToolCallOutput: OpenAI.Responses.ResponseInputItem.ApplyPatchCallOutput = {
...item,
output: item.output === null ? undefined : item.output
}
content.push(applyPatchToolCallOutput)
} else {
logger.warn('Undefined tool call operation for ApplyPatchToolCall.')
}
}
})
return content
}
@@ -496,7 +520,7 @@ export class OpenAIResponseAPIClient extends OpenAIBaseClient<
...(isSupportVerbosityModel(model)
? {
text: {
verbosity: this.getVerbosity()
verbosity: this.getVerbosity(model)
}
}
: {}),

View File

@@ -1,6 +1,7 @@
import { loggerService } from '@logger'
import { isZhipuModel } from '@renderer/config/models'
import { getStoreProviders } from '@renderer/hooks/useStore'
import { getDefaultModel } from '@renderer/services/AssistantService'
import type { Chunk } from '@renderer/types/chunk'
import type { CompletionsParams, CompletionsResult } from '../schemas'
@@ -66,7 +67,7 @@ export const ErrorHandlerMiddleware =
}
function handleError(error: any, params: CompletionsParams): any {
if (isZhipuModel(params.assistant.model) && error.status && !params.enableGenerateImage) {
if (isZhipuModel(params.assistant.model || getDefaultModel()) && error.status && !params.enableGenerateImage) {
return handleZhipuError(error)
}

View File

@@ -1,10 +1,10 @@
import type { WebSearchPluginConfig } from '@cherrystudio/ai-core/built-in/plugins'
import { loggerService } from '@logger'
import { isSupportedThinkingTokenQwenModel } from '@renderer/config/models'
import { isSupportEnableThinkingProvider } from '@renderer/config/providers'
import type { MCPTool } from '@renderer/types'
import { type Assistant, type Message, type Model, type Provider } from '@renderer/types'
import { type Assistant, type Message, type Model, type Provider, SystemProviderIds } from '@renderer/types'
import type { Chunk } from '@renderer/types/chunk'
import { isSupportEnableThinkingProvider } from '@renderer/utils/provider'
import type { LanguageModelMiddleware } from 'ai'
import { extractReasoningMiddleware, simulateStreamingMiddleware } from 'ai'
import { isEmpty } from 'lodash'
@@ -12,6 +12,7 @@ import { isEmpty } from 'lodash'
import { isOpenRouterGeminiGenerateImageModel } from '../utils/image'
import { noThinkMiddleware } from './noThinkMiddleware'
import { openrouterGenerateImageMiddleware } from './openrouterGenerateImageMiddleware'
import { openrouterReasoningMiddleware } from './openrouterReasoningMiddleware'
import { qwenThinkingMiddleware } from './qwenThinkingMiddleware'
import { toolChoiceMiddleware } from './toolChoiceMiddleware'
@@ -217,6 +218,14 @@ function addProviderSpecificMiddlewares(builder: AiSdkMiddlewareBuilder, config:
middleware: noThinkMiddleware()
})
}
if (config.provider.id === SystemProviderIds.openrouter && config.enableReasoning) {
builder.add({
name: 'openrouter-reasoning-redaction',
middleware: openrouterReasoningMiddleware()
})
logger.debug('Added OpenRouter reasoning redaction middleware')
}
}
/**

View File

@@ -0,0 +1,50 @@
import type { LanguageModelV2StreamPart } from '@ai-sdk/provider'
import type { LanguageModelMiddleware } from 'ai'
/**
* https://openrouter.ai/docs/docs/best-practices/reasoning-tokens#example-preserving-reasoning-blocks-with-openrouter-and-claude
*
* @returns LanguageModelMiddleware - a middleware filter redacted block
*/
export function openrouterReasoningMiddleware(): LanguageModelMiddleware {
const REDACTED_BLOCK = '[REDACTED]'
return {
middlewareVersion: 'v2',
wrapGenerate: async ({ doGenerate }) => {
const { content, ...rest } = await doGenerate()
const modifiedContent = content.map((part) => {
if (part.type === 'reasoning' && part.text.includes(REDACTED_BLOCK)) {
return {
...part,
text: part.text.replace(REDACTED_BLOCK, '')
}
}
return part
})
return { content: modifiedContent, ...rest }
},
wrapStream: async ({ doStream }) => {
const { stream, ...rest } = await doStream()
return {
stream: stream.pipeThrough(
new TransformStream<LanguageModelV2StreamPart, LanguageModelV2StreamPart>({
transform(
chunk: LanguageModelV2StreamPart,
controller: TransformStreamDefaultController<LanguageModelV2StreamPart>
) {
if (chunk.type === 'reasoning-delta' && chunk.delta.includes(REDACTED_BLOCK)) {
controller.enqueue({
...chunk,
delta: chunk.delta.replace(REDACTED_BLOCK, '')
})
} else {
controller.enqueue(chunk)
}
}
})
),
...rest
}
}
}
}

View File

@@ -0,0 +1,234 @@
import type { Message, Model } from '@renderer/types'
import type { FileMetadata } from '@renderer/types/file'
import { FileTypes } from '@renderer/types/file'
import {
AssistantMessageStatus,
type FileMessageBlock,
type ImageMessageBlock,
MessageBlockStatus,
MessageBlockType,
type ThinkingMessageBlock,
UserMessageStatus
} from '@renderer/types/newMessage'
import { beforeEach, describe, expect, it, vi } from 'vitest'
const { convertFileBlockToFilePartMock, convertFileBlockToTextPartMock } = vi.hoisted(() => ({
convertFileBlockToFilePartMock: vi.fn(),
convertFileBlockToTextPartMock: vi.fn()
}))
vi.mock('../fileProcessor', () => ({
convertFileBlockToFilePart: convertFileBlockToFilePartMock,
convertFileBlockToTextPart: convertFileBlockToTextPartMock
}))
const visionModelIds = new Set(['gpt-4o-mini', 'qwen-image-edit'])
const imageEnhancementModelIds = new Set(['qwen-image-edit'])
vi.mock('@renderer/config/models', () => ({
isVisionModel: (model: Model) => visionModelIds.has(model.id),
isImageEnhancementModel: (model: Model) => imageEnhancementModelIds.has(model.id)
}))
type MockableMessage = Message & {
__mockContent?: string
__mockFileBlocks?: FileMessageBlock[]
__mockImageBlocks?: ImageMessageBlock[]
__mockThinkingBlocks?: ThinkingMessageBlock[]
}
vi.mock('@renderer/utils/messageUtils/find', () => ({
getMainTextContent: (message: Message) => (message as MockableMessage).__mockContent ?? '',
findFileBlocks: (message: Message) => (message as MockableMessage).__mockFileBlocks ?? [],
findImageBlocks: (message: Message) => (message as MockableMessage).__mockImageBlocks ?? [],
findThinkingBlocks: (message: Message) => (message as MockableMessage).__mockThinkingBlocks ?? []
}))
import { convertMessagesToSdkMessages, convertMessageToSdkParam } from '../messageConverter'
let messageCounter = 0
let blockCounter = 0
const createModel = (overrides: Partial<Model> = {}): Model => ({
id: 'gpt-4o-mini',
name: 'GPT-4o mini',
provider: 'openai',
group: 'openai',
...overrides
})
const createMessage = (role: Message['role']): MockableMessage =>
({
id: `message-${++messageCounter}`,
role,
assistantId: 'assistant-1',
topicId: 'topic-1',
createdAt: new Date(2024, 0, 1, 0, 0, messageCounter).toISOString(),
status: role === 'assistant' ? AssistantMessageStatus.SUCCESS : UserMessageStatus.SUCCESS,
blocks: []
}) as MockableMessage
const createFileBlock = (
messageId: string,
overrides: Partial<Omit<FileMessageBlock, 'file' | 'messageId' | 'type'>> & { file?: Partial<FileMetadata> } = {}
): FileMessageBlock => {
const { file, ...blockOverrides } = overrides
const timestamp = new Date(2024, 0, 1, 0, 0, ++blockCounter).toISOString()
return {
id: blockOverrides.id ?? `file-block-${blockCounter}`,
messageId,
type: MessageBlockType.FILE,
createdAt: blockOverrides.createdAt ?? timestamp,
status: blockOverrides.status ?? MessageBlockStatus.SUCCESS,
file: {
id: file?.id ?? `file-${blockCounter}`,
name: file?.name ?? 'document.txt',
origin_name: file?.origin_name ?? 'document.txt',
path: file?.path ?? '/tmp/document.txt',
size: file?.size ?? 1024,
ext: file?.ext ?? '.txt',
type: file?.type ?? FileTypes.TEXT,
created_at: file?.created_at ?? timestamp,
count: file?.count ?? 1,
...file
},
...blockOverrides
}
}
const createImageBlock = (
messageId: string,
overrides: Partial<Omit<ImageMessageBlock, 'type' | 'messageId'>> = {}
): ImageMessageBlock => ({
id: overrides.id ?? `image-block-${++blockCounter}`,
messageId,
type: MessageBlockType.IMAGE,
createdAt: overrides.createdAt ?? new Date(2024, 0, 1, 0, 0, blockCounter).toISOString(),
status: overrides.status ?? MessageBlockStatus.SUCCESS,
url: overrides.url ?? 'https://example.com/image.png',
...overrides
})
describe('messageConverter', () => {
beforeEach(() => {
convertFileBlockToFilePartMock.mockReset()
convertFileBlockToTextPartMock.mockReset()
convertFileBlockToFilePartMock.mockResolvedValue(null)
convertFileBlockToTextPartMock.mockResolvedValue(null)
messageCounter = 0
blockCounter = 0
})
describe('convertMessageToSdkParam', () => {
it('includes text and image parts for user messages on vision models', async () => {
const model = createModel()
const message = createMessage('user')
message.__mockContent = 'Describe this picture'
message.__mockImageBlocks = [createImageBlock(message.id, { url: 'https://example.com/cat.png' })]
const result = await convertMessageToSdkParam(message, true, model)
expect(result).toEqual({
role: 'user',
content: [
{ type: 'text', text: 'Describe this picture' },
{ type: 'image', image: 'https://example.com/cat.png' }
]
})
})
it('returns file instructions as a system message when native uploads succeed', async () => {
const model = createModel()
const message = createMessage('user')
message.__mockContent = 'Summarize the PDF'
message.__mockFileBlocks = [createFileBlock(message.id)]
convertFileBlockToFilePartMock.mockResolvedValueOnce({
type: 'file',
filename: 'document.pdf',
mediaType: 'application/pdf',
data: 'fileid://remote-file'
})
const result = await convertMessageToSdkParam(message, false, model)
expect(result).toEqual([
{
role: 'system',
content: 'fileid://remote-file'
},
{
role: 'user',
content: [{ type: 'text', text: 'Summarize the PDF' }]
}
])
})
})
describe('convertMessagesToSdkMessages', () => {
it('appends assistant images to the final user message for image enhancement models', async () => {
const model = createModel({ id: 'qwen-image-edit', name: 'Qwen Image Edit', provider: 'qwen', group: 'qwen' })
const initialUser = createMessage('user')
initialUser.__mockContent = 'Start editing'
const assistant = createMessage('assistant')
assistant.__mockContent = 'Here is the current preview'
assistant.__mockImageBlocks = [createImageBlock(assistant.id, { url: 'https://example.com/preview.png' })]
const finalUser = createMessage('user')
finalUser.__mockContent = 'Increase the brightness'
const result = await convertMessagesToSdkMessages([initialUser, assistant, finalUser], model)
expect(result).toEqual([
{
role: 'assistant',
content: [{ type: 'text', text: 'Here is the current preview' }]
},
{
role: 'user',
content: [
{ type: 'text', text: 'Increase the brightness' },
{ type: 'image', image: 'https://example.com/preview.png' }
]
}
])
})
it('preserves preceding system instructions when building enhancement payloads', async () => {
const model = createModel({ id: 'qwen-image-edit', name: 'Qwen Image Edit', provider: 'qwen', group: 'qwen' })
const fileUser = createMessage('user')
fileUser.__mockContent = 'Use this document as inspiration'
fileUser.__mockFileBlocks = [createFileBlock(fileUser.id, { file: { ext: '.pdf', type: FileTypes.DOCUMENT } })]
convertFileBlockToFilePartMock.mockResolvedValueOnce({
type: 'file',
filename: 'reference.pdf',
mediaType: 'application/pdf',
data: 'fileid://reference'
})
const assistant = createMessage('assistant')
assistant.__mockContent = 'Generated previews ready'
assistant.__mockImageBlocks = [createImageBlock(assistant.id, { url: 'https://example.com/reference.png' })]
const finalUser = createMessage('user')
finalUser.__mockContent = 'Apply the edits'
const result = await convertMessagesToSdkMessages([fileUser, assistant, finalUser], model)
expect(result).toEqual([
{ role: 'system', content: 'fileid://reference' },
{
role: 'assistant',
content: [{ type: 'text', text: 'Generated previews ready' }]
},
{
role: 'user',
content: [
{ type: 'text', text: 'Apply the edits' },
{ type: 'image', image: 'https://example.com/reference.png' }
]
}
])
})
})
})

View File

@@ -0,0 +1,218 @@
import type { Assistant, AssistantSettings, Model, Topic } from '@renderer/types'
import { TopicType } from '@renderer/types'
import { defaultTimeout } from '@shared/config/constant'
import { describe, expect, it, vi } from 'vitest'
import { getTemperature, getTimeout, getTopP } from '../modelParameters'
vi.mock('@renderer/services/AssistantService', () => ({
getAssistantSettings: (assistant: Assistant): AssistantSettings => ({
contextCount: assistant.settings?.contextCount ?? 4096,
temperature: assistant.settings?.temperature ?? 0.7,
enableTemperature: assistant.settings?.enableTemperature ?? true,
topP: assistant.settings?.topP ?? 1,
enableTopP: assistant.settings?.enableTopP ?? false,
enableMaxTokens: assistant.settings?.enableMaxTokens ?? false,
maxTokens: assistant.settings?.maxTokens,
streamOutput: assistant.settings?.streamOutput ?? true,
toolUseMode: assistant.settings?.toolUseMode ?? 'prompt',
defaultModel: assistant.defaultModel,
customParameters: assistant.settings?.customParameters ?? [],
reasoning_effort: assistant.settings?.reasoning_effort,
reasoning_effort_cache: assistant.settings?.reasoning_effort_cache,
qwenThinkMode: assistant.settings?.qwenThinkMode
})
}))
vi.mock('@renderer/hooks/useSettings', () => ({
getStoreSetting: vi.fn(),
useSettings: vi.fn(() => ({})),
useNavbarPosition: vi.fn(() => ({ navbarPosition: 'left', isLeftNavbar: true, isTopNavbar: false }))
}))
vi.mock('@renderer/hooks/useStore', () => ({
getStoreProviders: vi.fn(() => [])
}))
vi.mock('@renderer/store/settings', () => ({
default: (state = { settings: {} }) => state
}))
vi.mock('@renderer/store/assistants', () => ({
default: (state = { assistants: [] }) => state
}))
const createTopic = (assistantId: string): Topic => ({
id: `topic-${assistantId}`,
assistantId,
name: 'topic',
createdAt: new Date().toISOString(),
updatedAt: new Date().toISOString(),
messages: [],
type: TopicType.Chat
})
const createAssistant = (settings: Assistant['settings'] = {}): Assistant => {
const assistantId = 'assistant-1'
return {
id: assistantId,
name: 'Test Assistant',
prompt: 'prompt',
topics: [createTopic(assistantId)],
type: 'assistant',
settings
}
}
const createModel = (overrides: Partial<Model> = {}): Model => ({
id: 'gpt-4o',
provider: 'openai',
name: 'GPT-4o',
group: 'openai',
...overrides
})
describe('modelParameters', () => {
describe('getTemperature', () => {
it('returns undefined when reasoning effort is enabled for Claude models', () => {
const assistant = createAssistant({ reasoning_effort: 'medium' })
const model = createModel({ id: 'claude-opus-4', name: 'Claude Opus 4', provider: 'anthropic', group: 'claude' })
expect(getTemperature(assistant, model)).toBeUndefined()
})
it('returns undefined for models without temperature/topP support', () => {
const assistant = createAssistant({ enableTemperature: true })
const model = createModel({ id: 'qwen-mt-large', name: 'Qwen MT', provider: 'qwen', group: 'qwen' })
expect(getTemperature(assistant, model)).toBeUndefined()
})
it('returns undefined for Claude 4.5 reasoning models when only TopP is enabled', () => {
const assistant = createAssistant({ enableTopP: true, enableTemperature: false })
const model = createModel({
id: 'claude-sonnet-4.5',
name: 'Claude Sonnet 4.5',
provider: 'anthropic',
group: 'claude'
})
expect(getTemperature(assistant, model)).toBeUndefined()
})
it('returns configured temperature when enabled', () => {
const assistant = createAssistant({ enableTemperature: true, temperature: 0.42 })
const model = createModel({ id: 'gpt-4o', provider: 'openai', group: 'openai' })
expect(getTemperature(assistant, model)).toBe(0.42)
})
it('returns undefined when temperature is disabled', () => {
const assistant = createAssistant({ enableTemperature: false, temperature: 0.9 })
const model = createModel({ id: 'gpt-4o', provider: 'openai', group: 'openai' })
expect(getTemperature(assistant, model)).toBeUndefined()
})
it('clamps temperature to max 1.0 for Zhipu models', () => {
const assistant = createAssistant({ enableTemperature: true, temperature: 2.0 })
const model = createModel({ id: 'glm-4-plus', name: 'GLM-4 Plus', provider: 'zhipu', group: 'zhipu' })
expect(getTemperature(assistant, model)).toBe(1.0)
})
it('clamps temperature to max 1.0 for Anthropic models', () => {
const assistant = createAssistant({ enableTemperature: true, temperature: 1.5 })
const model = createModel({
id: 'claude-sonnet-3.5',
name: 'Claude 3.5 Sonnet',
provider: 'anthropic',
group: 'claude'
})
expect(getTemperature(assistant, model)).toBe(1.0)
})
it('clamps temperature to max 1.0 for Moonshot models', () => {
const assistant = createAssistant({ enableTemperature: true, temperature: 2.0 })
const model = createModel({
id: 'moonshot-v1-8k',
name: 'Moonshot v1 8k',
provider: 'moonshot',
group: 'moonshot'
})
expect(getTemperature(assistant, model)).toBe(1.0)
})
it('does not clamp temperature for OpenAI models', () => {
const assistant = createAssistant({ enableTemperature: true, temperature: 2.0 })
const model = createModel({ id: 'gpt-4o', provider: 'openai', group: 'openai' })
expect(getTemperature(assistant, model)).toBe(2.0)
})
it('does not clamp temperature when it is already within limits', () => {
const assistant = createAssistant({ enableTemperature: true, temperature: 0.8 })
const model = createModel({ id: 'glm-4-plus', name: 'GLM-4 Plus', provider: 'zhipu', group: 'zhipu' })
expect(getTemperature(assistant, model)).toBe(0.8)
})
})
describe('getTopP', () => {
it('returns undefined when reasoning effort is enabled for Claude models', () => {
const assistant = createAssistant({ reasoning_effort: 'high' })
const model = createModel({ id: 'claude-opus-4', provider: 'anthropic', group: 'claude' })
expect(getTopP(assistant, model)).toBeUndefined()
})
it('returns undefined for models without TopP support', () => {
const assistant = createAssistant({ enableTopP: true })
const model = createModel({ id: 'qwen-mt-small', name: 'Qwen MT', provider: 'qwen', group: 'qwen' })
expect(getTopP(assistant, model)).toBeUndefined()
})
it('returns undefined for Claude 4.5 reasoning models when temperature is enabled', () => {
const assistant = createAssistant({ enableTemperature: true })
const model = createModel({
id: 'claude-opus-4.5',
name: 'Claude Opus 4.5',
provider: 'anthropic',
group: 'claude'
})
expect(getTopP(assistant, model)).toBeUndefined()
})
it('returns configured TopP when enabled', () => {
const assistant = createAssistant({ enableTopP: true, topP: 0.73 })
const model = createModel({ id: 'gpt-4o', provider: 'openai', group: 'openai' })
expect(getTopP(assistant, model)).toBe(0.73)
})
it('returns undefined when TopP is disabled', () => {
const assistant = createAssistant({ enableTopP: false, topP: 0.5 })
const model = createModel({ id: 'gpt-4o', provider: 'openai', group: 'openai' })
expect(getTopP(assistant, model)).toBeUndefined()
})
})
describe('getTimeout', () => {
it('uses an extended timeout for flex service tier models', () => {
const model = createModel({ id: 'o3-pro', provider: 'openai', group: 'openai' })
expect(getTimeout(model)).toBe(15 * 1000 * 60)
})
it('falls back to the default timeout otherwise', () => {
const model = createModel({ id: 'gpt-4o', provider: 'openai', group: 'openai' })
expect(getTimeout(model)).toBe(defaultTimeout)
})
})
})

View File

@@ -0,0 +1,31 @@
import { isClaude4SeriesModel, isClaude45ReasoningModel } from '@renderer/config/models'
import { getProviderByModel } from '@renderer/services/AssistantService'
import type { Assistant, Model } from '@renderer/types'
import { isToolUseModeFunction } from '@renderer/utils/assistant'
import { isAwsBedrockProvider, isVertexProvider } from '@renderer/utils/provider'
// https://docs.claude.com/en/docs/build-with-claude/extended-thinking#interleaved-thinking
const INTERLEAVED_THINKING_HEADER = 'interleaved-thinking-2025-05-14'
// https://docs.claude.com/en/docs/build-with-claude/context-windows#1m-token-context-window
const CONTEXT_100M_HEADER = 'context-1m-2025-08-07'
// https://docs.cloud.google.com/vertex-ai/generative-ai/docs/partner-models/claude/web-search
const WEBSEARCH_HEADER = 'web-search-2025-03-05'
export function addAnthropicHeaders(assistant: Assistant, model: Model): string[] {
const anthropicHeaders: string[] = []
const provider = getProviderByModel(model)
if (
isClaude45ReasoningModel(model) &&
isToolUseModeFunction(assistant) &&
!(isVertexProvider(provider) && isAwsBedrockProvider(provider))
) {
anthropicHeaders.push(INTERLEAVED_THINKING_HEADER)
}
if (isClaude4SeriesModel(model)) {
if (isVertexProvider(provider) && assistant.enableWebSearch) {
anthropicHeaders.push(WEBSEARCH_HEADER)
}
anthropicHeaders.push(CONTEXT_100M_HEADER)
}
return anthropicHeaders
}

View File

@@ -85,19 +85,6 @@ export function supportsLargeFileUpload(model: Model): boolean {
})
}
/**
* 检查模型是否支持TopP
*/
export function supportsTopP(model: Model): boolean {
const provider = getProviderByModel(model)
if (provider?.type === 'anthropic' || model?.endpoint_type === 'anthropic') {
return false
}
return true
}
/**
* 获取提供商特定的文件大小限制
*/

View File

@@ -3,17 +3,27 @@
* 处理温度、TopP、超时等基础参数的获取逻辑
*/
import { DEFAULT_MAX_TOKENS } from '@renderer/config/constant'
import {
isClaude45ReasoningModel,
isClaudeReasoningModel,
isMaxTemperatureOneModel,
isNotSupportTemperatureAndTopP,
isSupportedFlexServiceTier
isSupportedFlexServiceTier,
isSupportedThinkingTokenClaudeModel
} from '@renderer/config/models'
import { getAssistantSettings } from '@renderer/services/AssistantService'
import { getAssistantSettings, getProviderByModel } from '@renderer/services/AssistantService'
import type { Assistant, Model } from '@renderer/types'
import { defaultTimeout } from '@shared/config/constant'
import { getAnthropicThinkingBudget } from '../utils/reasoning'
/**
* Claude 4.5 推理模型:
* - 只启用 temperature → 使用 temperature
* - 只启用 top_p → 使用 top_p
* - 同时启用 → temperature 生效,top_p 被忽略
* - 都不启用 → 都不使用
* 获取温度参数
*/
export function getTemperature(assistant: Assistant, model: Model): number | undefined {
@@ -27,7 +37,11 @@ export function getTemperature(assistant: Assistant, model: Model): number | und
return undefined
}
const assistantSettings = getAssistantSettings(assistant)
return assistantSettings?.enableTemperature ? assistantSettings?.temperature : undefined
let temperature = assistantSettings?.temperature
if (temperature && isMaxTemperatureOneModel(model)) {
temperature = Math.min(1, temperature)
}
return assistantSettings?.enableTemperature ? temperature : undefined
}
/**
@@ -56,3 +70,18 @@ export function getTimeout(model: Model): number {
}
return defaultTimeout
}
export function getMaxTokens(assistant: Assistant, model: Model): number | undefined {
// NOTE: ai-sdk会把maxToken和budgetToken加起来
let { maxTokens = DEFAULT_MAX_TOKENS } = getAssistantSettings(assistant)
const provider = getProviderByModel(model)
if (isSupportedThinkingTokenClaudeModel(model) && ['anthropic', 'aws-bedrock'].includes(provider.type)) {
const { reasoning_effort: reasoningEffort } = getAssistantSettings(assistant)
const budget = getAnthropicThinkingBudget(maxTokens, reasoningEffort, model.id)
if (budget) {
maxTokens -= budget
}
}
return maxTokens
}

View File

@@ -4,22 +4,24 @@
*/
import { anthropic } from '@ai-sdk/anthropic'
import { azure } from '@ai-sdk/azure'
import { google } from '@ai-sdk/google'
import { vertexAnthropic } from '@ai-sdk/google-vertex/anthropic/edge'
import { vertex } from '@ai-sdk/google-vertex/edge'
import type { WebSearchPluginConfig } from '@cherrystudio/ai-core/built-in/plugins'
import { combineHeaders } from '@ai-sdk/provider-utils'
import type { AnthropicSearchConfig, WebSearchPluginConfig } from '@cherrystudio/ai-core/built-in/plugins'
import { isBaseProvider } from '@cherrystudio/ai-core/core/providers/schemas'
import { loggerService } from '@logger'
import {
isAnthropicModel,
isGenerateImageModel,
isOpenRouterBuiltInWebSearchModel,
isReasoningModel,
isSupportedReasoningEffortModel,
isSupportedThinkingTokenClaudeModel,
isSupportedThinkingTokenModel,
isWebSearchModel
} from '@renderer/config/models'
import { getAssistantSettings, getDefaultModel } from '@renderer/services/AssistantService'
import { getDefaultModel } from '@renderer/services/AssistantService'
import store from '@renderer/store'
import type { CherryWebSearchConfig } from '@renderer/store/websearch'
import { type Assistant, type MCPTool, type Provider } from '@renderer/types'
@@ -32,10 +34,9 @@ import { stepCountIs } from 'ai'
import { getAiSdkProviderId } from '../provider/factory'
import { setupToolsConfig } from '../utils/mcp'
import { buildProviderOptions } from '../utils/options'
import { getAnthropicThinkingBudget } from '../utils/reasoning'
import { buildProviderBuiltinWebSearchConfig } from '../utils/websearch'
import { supportsTopP } from './modelCapabilities'
import { getTemperature, getTopP } from './modelParameters'
import { addAnthropicHeaders } from './header'
import { getMaxTokens, getTemperature, getTopP } from './modelParameters'
const logger = loggerService.withContext('parameterBuilder')
@@ -58,7 +59,7 @@ export async function buildStreamTextParams(
timeout?: number
headers?: Record<string, string>
}
} = {}
}
): Promise<{
params: StreamTextParams
modelId: string
@@ -75,8 +76,6 @@ export async function buildStreamTextParams(
const model = assistant.model || getDefaultModel()
const aiSdkProviderId = getAiSdkProviderId(provider)
let { maxTokens } = getAssistantSettings(assistant)
// 这三个变量透传出来,交给下面启用插件/中间件
// 也可以在外部构建好再传入buildStreamTextParams
// FIXME: qwen3即使关闭思考仍然会导致enableReasoning的结果为true
@@ -113,16 +112,6 @@ export async function buildStreamTextParams(
enableGenerateImage
})
// NOTE: ai-sdk会把maxToken和budgetToken加起来
if (
enableReasoning &&
maxTokens !== undefined &&
isSupportedThinkingTokenClaudeModel(model) &&
(provider.type === 'anthropic' || provider.type === 'aws-bedrock')
) {
maxTokens -= getAnthropicThinkingBudget(assistant, model)
}
let webSearchPluginConfig: WebSearchPluginConfig | undefined = undefined
if (enableWebSearch) {
if (isBaseProvider(aiSdkProviderId)) {
@@ -139,6 +128,17 @@ export async function buildStreamTextParams(
maxUses: webSearchConfig.maxResults,
blockedDomains: blockedDomains.length > 0 ? blockedDomains : undefined
}) as ProviderDefinedTool
} else if (aiSdkProviderId === 'azure-responses') {
tools.web_search_preview = azure.tools.webSearchPreview({
searchContextSize: webSearchPluginConfig?.openai!.searchContextSize
}) as ProviderDefinedTool
} else if (aiSdkProviderId === 'azure-anthropic') {
const blockedDomains = mapRegexToPatterns(webSearchConfig.excludeDomains)
const anthropicSearchOptions: AnthropicSearchConfig = {
maxUses: webSearchConfig.maxResults,
blockedDomains: blockedDomains.length > 0 ? blockedDomains : undefined
}
tools.web_search = anthropic.tools.webSearch_20250305(anthropicSearchOptions) as ProviderDefinedTool
}
}
@@ -156,9 +156,10 @@ export async function buildStreamTextParams(
tools.url_context = google.tools.urlContext({}) as ProviderDefinedTool
break
case 'anthropic':
case 'azure-anthropic':
case 'google-vertex-anthropic':
tools.web_fetch = (
aiSdkProviderId === 'anthropic'
['anthropic', 'azure-anthropic'].includes(aiSdkProviderId)
? anthropic.tools.webFetch_20250910({
maxUses: webSearchConfig.maxResults,
blockedDomains: blockedDomains.length > 0 ? blockedDomains : undefined
@@ -172,22 +173,26 @@ export async function buildStreamTextParams(
}
}
let headers: Record<string, string | undefined> = options.requestOptions?.headers ?? {}
if (isAnthropicModel(model)) {
const newBetaHeaders = { 'anthropic-beta': addAnthropicHeaders(assistant, model).join(',') }
headers = combineHeaders(headers, newBetaHeaders)
}
// 构建基础参数
const params: StreamTextParams = {
messages: sdkMessages,
maxOutputTokens: maxTokens,
maxOutputTokens: getMaxTokens(assistant, model),
temperature: getTemperature(assistant, model),
topP: getTopP(assistant, model),
abortSignal: options.requestOptions?.signal,
headers: options.requestOptions?.headers,
headers,
providerOptions,
stopWhen: stepCountIs(20),
maxRetries: 0
}
if (supportsTopP(model)) {
params.topP = getTopP(assistant, model)
}
if (tools) {
params.tools = tools
}

View File

@@ -23,6 +23,26 @@ vi.mock('@cherrystudio/ai-core', () => ({
}
}))
vi.mock('@renderer/services/AssistantService', () => ({
getProviderByModel: vi.fn(),
getAssistantSettings: vi.fn(),
getDefaultAssistant: vi.fn().mockReturnValue({
id: 'default',
name: 'Default Assistant',
prompt: '',
settings: {}
})
}))
vi.mock('@renderer/store/settings', () => ({
default: {},
settingsSlice: {
name: 'settings',
reducer: vi.fn(),
actions: {}
}
}))
// Mock the provider configs
vi.mock('../providerConfigs', () => ({
initializeNewProviders: vi.fn()

View File

@@ -12,7 +12,14 @@ vi.mock('@renderer/services/LoggerService', () => ({
}))
vi.mock('@renderer/services/AssistantService', () => ({
getProviderByModel: vi.fn()
getProviderByModel: vi.fn(),
getAssistantSettings: vi.fn(),
getDefaultAssistant: vi.fn().mockReturnValue({
id: 'default',
name: 'Default Assistant',
prompt: '',
settings: {}
})
}))
vi.mock('@renderer/store', () => ({
@@ -34,7 +41,7 @@ vi.mock('@renderer/utils/api', () => ({
}))
}))
vi.mock('@renderer/config/providers', async (importOriginal) => {
vi.mock('@renderer/utils/provider', async (importOriginal) => {
const actual = (await importOriginal()) as any
return {
...actual,
@@ -53,10 +60,21 @@ vi.mock('@renderer/hooks/useVertexAI', () => ({
createVertexProvider: vi.fn()
}))
import { isCherryAIProvider, isPerplexityProvider } from '@renderer/config/providers'
vi.mock('@renderer/services/AssistantService', () => ({
getProviderByModel: vi.fn(),
getAssistantSettings: vi.fn(),
getDefaultAssistant: vi.fn().mockReturnValue({
id: 'default',
name: 'Default Assistant',
prompt: '',
settings: {}
})
}))
import { getProviderByModel } from '@renderer/services/AssistantService'
import type { Model, Provider } from '@renderer/types'
import { formatApiHost } from '@renderer/utils/api'
import { isCherryAIProvider, isPerplexityProvider } from '@renderer/utils/provider'
import { COPILOT_DEFAULT_HEADERS, COPILOT_EDITOR_VERSION, isCopilotResponsesModel } from '../constants'
import { getActualProvider, providerToAiSdkConfig } from '../providerConfig'

View File

@@ -0,0 +1,22 @@
import type { Provider } from '@renderer/types'
import { provider2Provider, startsWith } from './helper'
import type { RuleSet } from './types'
// https://platform.claude.com/docs/en/build-with-claude/claude-in-microsoft-foundry
const AZURE_ANTHROPIC_RULES: RuleSet = {
rules: [
{
match: startsWith('claude'),
provider: (provider: Provider) => ({
...provider,
type: 'anthropic',
apiHost: provider.apiHost + 'anthropic/v1',
id: 'azure-anthropic'
})
}
],
fallbackRule: (provider: Provider) => provider
}
export const azureAnthropicProviderCreator = provider2Provider.bind(null, AZURE_ANTHROPIC_RULES)

View File

@@ -2,8 +2,10 @@ import { hasProviderConfigByAlias, type ProviderId, resolveProviderConfigId } fr
import { createProvider as createProviderCore } from '@cherrystudio/ai-core/provider'
import { loggerService } from '@logger'
import type { Provider } from '@renderer/types'
import { isAzureOpenAIProvider, isAzureResponsesEndpoint } from '@renderer/utils/provider'
import type { Provider as AiSdkProvider } from 'ai'
import type { AiSdkConfig } from '../types'
import { initializeNewProviders } from './providerInitialization'
const logger = loggerService.withContext('ProviderFactory')
@@ -55,9 +57,12 @@ function tryResolveProviderId(identifier: string): ProviderId | null {
* 获取AI SDK Provider ID
* 简化版:减少重复逻辑,利用通用解析函数
*/
export function getAiSdkProviderId(provider: Provider): ProviderId | 'openai-compatible' {
export function getAiSdkProviderId(provider: Provider): string {
// 1. 尝试解析provider.id
const resolvedFromId = tryResolveProviderId(provider.id)
if (isAzureOpenAIProvider(provider) && isAzureResponsesEndpoint(provider)) {
return 'azure-responses'
}
if (resolvedFromId) {
return resolvedFromId
}
@@ -73,11 +78,11 @@ export function getAiSdkProviderId(provider: Provider): ProviderId | 'openai-com
if (provider.apiHost.includes('api.openai.com')) {
return 'openai-chat'
}
// 3. 最后的fallback通常会成为openai-compatible
return provider.id as ProviderId
// 3. 最后的fallback使用provider本身的id
return provider.id
}
export async function createAiSdkProvider(config) {
export async function createAiSdkProvider(config: AiSdkConfig): Promise<AiSdkProvider | null> {
let localProvider: Awaited<AiSdkProvider> | null = null
try {
if (config.providerId === 'openai' && config.options?.mode === 'chat') {

View File

@@ -1,19 +1,5 @@
import {
formatPrivateKey,
hasProviderConfig,
ProviderConfigFactory,
type ProviderId,
type ProviderSettingsMap
} from '@cherrystudio/ai-core/provider'
import { formatPrivateKey, hasProviderConfig, ProviderConfigFactory } from '@cherrystudio/ai-core/provider'
import { isOpenAIChatCompletionOnlyModel } from '@renderer/config/models'
import {
isAnthropicProvider,
isAzureOpenAIProvider,
isCherryAIProvider,
isGeminiProvider,
isNewApiProvider,
isPerplexityProvider
} from '@renderer/config/providers'
import {
getAwsBedrockAccessKeyId,
getAwsBedrockApiKey,
@@ -21,14 +7,25 @@ import {
getAwsBedrockRegion,
getAwsBedrockSecretAccessKey
} from '@renderer/hooks/useAwsBedrock'
import { createVertexProvider, isVertexAIConfigured, isVertexProvider } from '@renderer/hooks/useVertexAI'
import { createVertexProvider, isVertexAIConfigured } from '@renderer/hooks/useVertexAI'
import { getProviderByModel } from '@renderer/services/AssistantService'
import store from '@renderer/store'
import { isSystemProvider, type Model, type Provider, SystemProviderIds } from '@renderer/types'
import { formatApiHost, formatAzureOpenAIApiHost, formatVertexApiHost, routeToEndpoint } from '@renderer/utils/api'
import {
isAnthropicProvider,
isAzureOpenAIProvider,
isCherryAIProvider,
isGeminiProvider,
isNewApiProvider,
isPerplexityProvider,
isVertexProvider
} from '@renderer/utils/provider'
import { cloneDeep } from 'lodash'
import type { AiSdkConfig } from '../types'
import { aihubmixProviderCreator, newApiResolverCreator, vertexAnthropicProviderCreator } from './config'
import { azureAnthropicProviderCreator } from './config/azure-anthropic'
import { COPILOT_DEFAULT_HEADERS } from './constants'
import { getAiSdkProviderId } from './factory'
@@ -74,6 +71,9 @@ function handleSpecialProviders(model: Model, provider: Provider): Provider {
return vertexAnthropicProviderCreator(model, provider)
}
}
if (isAzureOpenAIProvider(provider)) {
return azureAnthropicProviderCreator(model, provider)
}
return provider
}
@@ -131,13 +131,7 @@ export function getActualProvider(model: Model): Provider {
* 将 Provider 配置转换为新 AI SDK 格式
* 简化版:利用新的别名映射系统
*/
export function providerToAiSdkConfig(
actualProvider: Provider,
model: Model
): {
providerId: ProviderId | 'openai-compatible'
options: ProviderSettingsMap[keyof ProviderSettingsMap]
} {
export function providerToAiSdkConfig(actualProvider: Provider, model: Model): AiSdkConfig {
const aiSdkProviderId = getAiSdkProviderId(actualProvider)
// 构建基础配置
@@ -189,13 +183,12 @@ export function providerToAiSdkConfig(
}
}
// azure
if (aiSdkProviderId === 'azure' || actualProvider.type === 'azure-openai') {
// extraOptions.apiVersion = actualProvider.apiVersion 默认使用v1不使用azure endpoint
if (actualProvider.apiVersion === 'preview') {
extraOptions.mode = 'responses'
} else {
extraOptions.mode = 'chat'
}
// https://learn.microsoft.com/en-us/azure/ai-foundry/openai/latest
// https://learn.microsoft.com/en-us/azure/ai-foundry/openai/how-to/responses?tabs=python-key#responses-api
if (aiSdkProviderId === 'azure-responses') {
extraOptions.mode = 'responses'
} else if (aiSdkProviderId === 'azure') {
extraOptions.mode = 'chat'
}
// bedrock
@@ -225,10 +218,17 @@ export function providerToAiSdkConfig(
baseConfig.baseURL += aiSdkProviderId === 'google-vertex' ? '/publishers/google' : '/publishers/anthropic/models'
}
// cherryin
if (aiSdkProviderId === 'cherryin') {
if (model.endpoint_type) {
extraOptions.endpointType = model.endpoint_type
}
}
if (hasProviderConfig(aiSdkProviderId) && aiSdkProviderId !== 'openai-compatible') {
const options = ProviderConfigFactory.fromProvider(aiSdkProviderId, baseConfig, extraOptions)
return {
providerId: aiSdkProviderId as ProviderId,
providerId: aiSdkProviderId,
options
}
}

View File

@@ -32,6 +32,14 @@ export const NEW_PROVIDER_CONFIGS: ProviderConfig[] = [
supportsImageGeneration: true,
aliases: ['vertexai-anthropic']
},
{
id: 'azure-anthropic',
name: 'Azure AI Anthropic',
import: () => import('@ai-sdk/anthropic'),
creatorFunctionName: 'createAnthropic',
supportsImageGeneration: false,
aliases: ['azure-anthropic']
},
{
id: 'github-copilot-openai-compatible',
name: 'GitHub Copilot OpenAI Compatible',

View File

@@ -133,7 +133,7 @@ export class AiSdkSpanAdapter {
// 详细记录转换过程
const operationId = attributes['ai.operationId']
logger.info('Converting AI SDK span to SpanEntity', {
logger.debug('Converting AI SDK span to SpanEntity', {
spanName: spanName,
operationId,
spanTag,
@@ -149,7 +149,7 @@ export class AiSdkSpanAdapter {
})
if (tokenUsage) {
logger.info('Token usage data found', {
logger.debug('Token usage data found', {
spanName: spanName,
operationId,
usage: tokenUsage,
@@ -158,7 +158,7 @@ export class AiSdkSpanAdapter {
}
if (inputs || outputs) {
logger.info('Input/Output data extracted', {
logger.debug('Input/Output data extracted', {
spanName: spanName,
operationId,
hasInputs: !!inputs,
@@ -170,7 +170,7 @@ export class AiSdkSpanAdapter {
}
if (Object.keys(typeSpecificData).length > 0) {
logger.info('Type-specific data extracted', {
logger.debug('Type-specific data extracted', {
spanName: spanName,
operationId,
typeSpecificKeys: Object.keys(typeSpecificData),
@@ -204,7 +204,7 @@ export class AiSdkSpanAdapter {
modelName: modelName || this.extractModelFromAttributes(attributes)
}
logger.info('AI SDK span successfully converted to SpanEntity', {
logger.debug('AI SDK span successfully converted to SpanEntity', {
spanName: spanName,
operationId,
spanId: spanContext.spanId,

View File

@@ -0,0 +1,15 @@
/**
* This type definition file is only for renderer.
* It cannot be migrated to @renderer/types since files within it are actually being used by both main and renderer.
* If we do that, main would throw an error because it cannot import a module which imports a type from a browser-enviroment-only package.
* (ai-core package is set as browser-enviroment-only)
*
* TODO: We should separate them clearly. Keep renderer only types in renderer, and main only types in main, and shared types in shared.
*/
import type { ProviderSettingsMap } from '@cherrystudio/ai-core/provider'
export type AiSdkConfig = {
providerId: string
options: ProviderSettingsMap[keyof ProviderSettingsMap]
}

View File

@@ -0,0 +1,121 @@
/**
* image.ts Unit Tests
* Tests for Gemini image generation utilities
*/
import type { Model, Provider } from '@renderer/types'
import { SystemProviderIds } from '@renderer/types'
import { describe, expect, it } from 'vitest'
import { buildGeminiGenerateImageParams, isOpenRouterGeminiGenerateImageModel } from '../image'
describe('image utils', () => {
describe('buildGeminiGenerateImageParams', () => {
it('should return correct response modalities', () => {
const result = buildGeminiGenerateImageParams()
expect(result).toEqual({
responseModalities: ['TEXT', 'IMAGE']
})
})
it('should return an object with responseModalities property', () => {
const result = buildGeminiGenerateImageParams()
expect(result).toHaveProperty('responseModalities')
expect(Array.isArray(result.responseModalities)).toBe(true)
expect(result.responseModalities).toHaveLength(2)
})
})
describe('isOpenRouterGeminiGenerateImageModel', () => {
const mockOpenRouterProvider: Provider = {
id: SystemProviderIds.openrouter,
name: 'OpenRouter',
apiKey: 'test-key',
apiHost: 'https://openrouter.ai/api/v1',
isSystem: true
} as Provider
const mockOtherProvider: Provider = {
id: SystemProviderIds.openai,
name: 'OpenAI',
apiKey: 'test-key',
apiHost: 'https://api.openai.com/v1',
isSystem: true
} as Provider
it('should return true for OpenRouter Gemini 2.5 Flash Image model', () => {
const model: Model = {
id: 'google/gemini-2.5-flash-image-preview',
name: 'Gemini 2.5 Flash Image',
provider: SystemProviderIds.openrouter
} as Model
const result = isOpenRouterGeminiGenerateImageModel(model, mockOpenRouterProvider)
expect(result).toBe(true)
})
it('should return false for non-Gemini model on OpenRouter', () => {
const model: Model = {
id: 'openai/gpt-4',
name: 'GPT-4',
provider: SystemProviderIds.openrouter
} as Model
const result = isOpenRouterGeminiGenerateImageModel(model, mockOpenRouterProvider)
expect(result).toBe(false)
})
it('should return false for Gemini model on non-OpenRouter provider', () => {
const model: Model = {
id: 'gemini-2.5-flash-image-preview',
name: 'Gemini 2.5 Flash Image',
provider: SystemProviderIds.gemini
} as Model
const result = isOpenRouterGeminiGenerateImageModel(model, mockOtherProvider)
expect(result).toBe(false)
})
it('should return false for Gemini model without image suffix', () => {
const model: Model = {
id: 'google/gemini-2.5-flash',
name: 'Gemini 2.5 Flash',
provider: SystemProviderIds.openrouter
} as Model
const result = isOpenRouterGeminiGenerateImageModel(model, mockOpenRouterProvider)
expect(result).toBe(false)
})
it('should handle model ID with partial match', () => {
const model: Model = {
id: 'google/gemini-2.5-flash-image-generation',
name: 'Gemini Image Gen',
provider: SystemProviderIds.openrouter
} as Model
const result = isOpenRouterGeminiGenerateImageModel(model, mockOpenRouterProvider)
expect(result).toBe(true)
})
it('should return false for custom provider', () => {
const customProvider: Provider = {
id: 'custom-provider-123',
name: 'Custom Provider',
apiKey: 'test-key',
apiHost: 'https://custom.com'
} as Provider
const model: Model = {
id: 'gemini-2.5-flash-image-preview',
name: 'Gemini 2.5 Flash Image',
provider: 'custom-provider-123'
} as Model
const result = isOpenRouterGeminiGenerateImageModel(model, customProvider)
expect(result).toBe(false)
})
})
})

View File

@@ -0,0 +1,435 @@
/**
* mcp.ts Unit Tests
* Tests for MCP tools configuration and conversion utilities
*/
import type { MCPTool } from '@renderer/types'
import type { Tool } from 'ai'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import { convertMcpToolsToAiSdkTools, setupToolsConfig } from '../mcp'
// Mock dependencies
vi.mock('@logger', () => ({
loggerService: {
withContext: () => ({
debug: vi.fn(),
error: vi.fn(),
warn: vi.fn(),
info: vi.fn()
})
}
}))
vi.mock('@renderer/utils/mcp-tools', () => ({
getMcpServerByTool: vi.fn(() => ({ id: 'test-server', autoApprove: false })),
isToolAutoApproved: vi.fn(() => false),
callMCPTool: vi.fn(async () => ({
content: [{ type: 'text', text: 'Tool executed successfully' }],
isError: false
}))
}))
vi.mock('@renderer/utils/userConfirmation', () => ({
requestToolConfirmation: vi.fn(async () => true)
}))
describe('mcp utils', () => {
beforeEach(() => {
vi.clearAllMocks()
})
describe('setupToolsConfig', () => {
it('should return undefined when no MCP tools provided', () => {
const result = setupToolsConfig()
expect(result).toBeUndefined()
})
it('should return undefined when empty MCP tools array provided', () => {
const result = setupToolsConfig([])
expect(result).toBeUndefined()
})
it('should convert MCP tools to AI SDK tools format', () => {
const mcpTools: MCPTool[] = [
{
id: 'test-tool-1',
serverId: 'test-server',
serverName: 'test-server',
name: 'test-tool',
description: 'A test tool',
type: 'mcp',
inputSchema: {
type: 'object',
properties: {
query: { type: 'string' }
}
}
}
]
const result = setupToolsConfig(mcpTools)
expect(result).not.toBeUndefined()
expect(Object.keys(result!)).toEqual(['test-tool'])
expect(result!['test-tool']).toHaveProperty('description')
expect(result!['test-tool']).toHaveProperty('inputSchema')
expect(result!['test-tool']).toHaveProperty('execute')
})
it('should handle multiple MCP tools', () => {
const mcpTools: MCPTool[] = [
{
id: 'tool1-id',
serverId: 'server1',
serverName: 'server1',
name: 'tool1',
description: 'First tool',
type: 'mcp',
inputSchema: {
type: 'object',
properties: {}
}
},
{
id: 'tool2-id',
serverId: 'server2',
serverName: 'server2',
name: 'tool2',
description: 'Second tool',
type: 'mcp',
inputSchema: {
type: 'object',
properties: {}
}
}
]
const result = setupToolsConfig(mcpTools)
expect(result).not.toBeUndefined()
expect(Object.keys(result!)).toHaveLength(2)
expect(Object.keys(result!)).toEqual(['tool1', 'tool2'])
})
})
describe('convertMcpToolsToAiSdkTools', () => {
it('should convert single MCP tool to AI SDK tool', () => {
const mcpTools: MCPTool[] = [
{
id: 'get-weather-id',
serverId: 'weather-server',
serverName: 'weather-server',
name: 'get-weather',
description: 'Get weather information',
type: 'mcp',
inputSchema: {
type: 'object',
properties: {
location: { type: 'string' }
},
required: ['location']
}
}
]
const result = convertMcpToolsToAiSdkTools(mcpTools)
expect(Object.keys(result)).toEqual(['get-weather'])
const tool = result['get-weather'] as Tool
expect(tool.description).toBe('Get weather information')
expect(tool.inputSchema).toBeDefined()
expect(typeof tool.execute).toBe('function')
})
it('should handle tool without description', () => {
const mcpTools: MCPTool[] = [
{
id: 'no-desc-tool-id',
serverId: 'test-server',
serverName: 'test-server',
name: 'no-desc-tool',
type: 'mcp',
inputSchema: {
type: 'object',
properties: {}
}
}
]
const result = convertMcpToolsToAiSdkTools(mcpTools)
expect(Object.keys(result)).toEqual(['no-desc-tool'])
const tool = result['no-desc-tool'] as Tool
expect(tool.description).toBe('Tool from test-server')
})
it('should convert empty tools array', () => {
const result = convertMcpToolsToAiSdkTools([])
expect(result).toEqual({})
})
it('should handle complex input schemas', () => {
const mcpTools: MCPTool[] = [
{
id: 'complex-tool-id',
serverId: 'server',
serverName: 'server',
name: 'complex-tool',
description: 'Tool with complex schema',
type: 'mcp',
inputSchema: {
type: 'object',
properties: {
name: { type: 'string' },
age: { type: 'number' },
tags: {
type: 'array',
items: { type: 'string' }
},
metadata: {
type: 'object',
properties: {
key: { type: 'string' }
}
}
},
required: ['name']
}
}
]
const result = convertMcpToolsToAiSdkTools(mcpTools)
expect(Object.keys(result)).toEqual(['complex-tool'])
const tool = result['complex-tool'] as Tool
expect(tool.inputSchema).toBeDefined()
expect(typeof tool.execute).toBe('function')
})
it('should preserve tool names with special characters', () => {
const mcpTools: MCPTool[] = [
{
id: 'special-tool-id',
serverId: 'server',
serverName: 'server',
name: 'tool_with-special.chars',
description: 'Special chars tool',
type: 'mcp',
inputSchema: {
type: 'object',
properties: {}
}
}
]
const result = convertMcpToolsToAiSdkTools(mcpTools)
expect(Object.keys(result)).toEqual(['tool_with-special.chars'])
})
it('should handle multiple tools with different schemas', () => {
const mcpTools: MCPTool[] = [
{
id: 'string-tool-id',
serverId: 'server1',
serverName: 'server1',
name: 'string-tool',
description: 'String tool',
type: 'mcp',
inputSchema: {
type: 'object',
properties: {
input: { type: 'string' }
}
}
},
{
id: 'number-tool-id',
serverId: 'server2',
serverName: 'server2',
name: 'number-tool',
description: 'Number tool',
type: 'mcp',
inputSchema: {
type: 'object',
properties: {
count: { type: 'number' }
}
}
},
{
id: 'boolean-tool-id',
serverId: 'server3',
serverName: 'server3',
name: 'boolean-tool',
description: 'Boolean tool',
type: 'mcp',
inputSchema: {
type: 'object',
properties: {
enabled: { type: 'boolean' }
}
}
}
]
const result = convertMcpToolsToAiSdkTools(mcpTools)
expect(Object.keys(result).sort()).toEqual(['boolean-tool', 'number-tool', 'string-tool'])
expect(result['string-tool']).toBeDefined()
expect(result['number-tool']).toBeDefined()
expect(result['boolean-tool']).toBeDefined()
})
})
describe('tool execution', () => {
it('should execute tool with user confirmation', async () => {
const { callMCPTool } = await import('@renderer/utils/mcp-tools')
const { requestToolConfirmation } = await import('@renderer/utils/userConfirmation')
vi.mocked(requestToolConfirmation).mockResolvedValue(true)
vi.mocked(callMCPTool).mockResolvedValue({
content: [{ type: 'text', text: 'Success' }],
isError: false
})
const mcpTools: MCPTool[] = [
{
id: 'test-exec-tool-id',
serverId: 'test-server',
serverName: 'test-server',
name: 'test-exec-tool',
description: 'Test execution tool',
type: 'mcp',
inputSchema: {
type: 'object',
properties: {}
}
}
]
const tools = convertMcpToolsToAiSdkTools(mcpTools)
const tool = tools['test-exec-tool'] as Tool
const result = await tool.execute!({}, { messages: [], abortSignal: undefined, toolCallId: 'test-call-123' })
expect(requestToolConfirmation).toHaveBeenCalled()
expect(callMCPTool).toHaveBeenCalled()
expect(result).toEqual({
content: [{ type: 'text', text: 'Success' }],
isError: false
})
})
it('should handle user cancellation', async () => {
const { requestToolConfirmation } = await import('@renderer/utils/userConfirmation')
const { callMCPTool } = await import('@renderer/utils/mcp-tools')
vi.mocked(requestToolConfirmation).mockResolvedValue(false)
const mcpTools: MCPTool[] = [
{
id: 'cancelled-tool-id',
serverId: 'test-server',
serverName: 'test-server',
name: 'cancelled-tool',
description: 'Tool to cancel',
type: 'mcp',
inputSchema: {
type: 'object',
properties: {}
}
}
]
const tools = convertMcpToolsToAiSdkTools(mcpTools)
const tool = tools['cancelled-tool'] as Tool
const result = await tool.execute!({}, { messages: [], abortSignal: undefined, toolCallId: 'cancel-call-123' })
expect(requestToolConfirmation).toHaveBeenCalled()
expect(callMCPTool).not.toHaveBeenCalled()
expect(result).toEqual({
content: [
{
type: 'text',
text: 'User declined to execute tool "cancelled-tool".'
}
],
isError: false
})
})
it('should handle tool execution error', async () => {
const { callMCPTool } = await import('@renderer/utils/mcp-tools')
const { requestToolConfirmation } = await import('@renderer/utils/userConfirmation')
vi.mocked(requestToolConfirmation).mockResolvedValue(true)
vi.mocked(callMCPTool).mockResolvedValue({
content: [{ type: 'text', text: 'Error occurred' }],
isError: true
})
const mcpTools: MCPTool[] = [
{
id: 'error-tool-id',
serverId: 'test-server',
serverName: 'test-server',
name: 'error-tool',
description: 'Tool that errors',
type: 'mcp',
inputSchema: {
type: 'object',
properties: {}
}
}
]
const tools = convertMcpToolsToAiSdkTools(mcpTools)
const tool = tools['error-tool'] as Tool
await expect(
tool.execute!({}, { messages: [], abortSignal: undefined, toolCallId: 'error-call-123' })
).rejects.toEqual({
content: [{ type: 'text', text: 'Error occurred' }],
isError: true
})
})
it('should auto-approve when enabled', async () => {
const { callMCPTool, isToolAutoApproved } = await import('@renderer/utils/mcp-tools')
const { requestToolConfirmation } = await import('@renderer/utils/userConfirmation')
vi.mocked(isToolAutoApproved).mockReturnValue(true)
vi.mocked(callMCPTool).mockResolvedValue({
content: [{ type: 'text', text: 'Auto-approved success' }],
isError: false
})
const mcpTools: MCPTool[] = [
{
id: 'auto-approve-tool-id',
serverId: 'test-server',
serverName: 'test-server',
name: 'auto-approve-tool',
description: 'Auto-approved tool',
type: 'mcp',
inputSchema: {
type: 'object',
properties: {}
}
}
]
const tools = convertMcpToolsToAiSdkTools(mcpTools)
const tool = tools['auto-approve-tool'] as Tool
const result = await tool.execute!({}, { messages: [], abortSignal: undefined, toolCallId: 'auto-call-123' })
expect(requestToolConfirmation).not.toHaveBeenCalled()
expect(callMCPTool).toHaveBeenCalled()
expect(result).toEqual({
content: [{ type: 'text', text: 'Auto-approved success' }],
isError: false
})
})
})
})

View File

@@ -0,0 +1,545 @@
/**
* options.ts Unit Tests
* Tests for building provider-specific options
*/
import type { Assistant, Model, Provider } from '@renderer/types'
import { OpenAIServiceTiers, SystemProviderIds } from '@renderer/types'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import { buildProviderOptions } from '../options'
// Mock dependencies
vi.mock('@cherrystudio/ai-core/provider', async (importOriginal) => {
const actual = (await importOriginal()) as object
return {
...actual,
baseProviderIdSchema: {
safeParse: vi.fn((id) => {
const baseProviders = [
'openai',
'openai-chat',
'azure',
'azure-responses',
'huggingface',
'anthropic',
'google',
'xai',
'deepseek',
'openrouter',
'openai-compatible'
]
if (baseProviders.includes(id)) {
return { success: true, data: id }
}
return { success: false }
})
},
customProviderIdSchema: {
safeParse: vi.fn((id) => {
const customProviders = ['google-vertex', 'google-vertex-anthropic', 'bedrock']
if (customProviders.includes(id)) {
return { success: true, data: id }
}
return { success: false, error: new Error('Invalid provider') }
})
}
}
})
vi.mock('../provider/factory', () => ({
getAiSdkProviderId: vi.fn((provider) => {
// Simulate the provider ID mapping
const mapping: Record<string, string> = {
[SystemProviderIds.gemini]: 'google',
[SystemProviderIds.openai]: 'openai',
[SystemProviderIds.anthropic]: 'anthropic',
[SystemProviderIds.grok]: 'xai',
[SystemProviderIds.deepseek]: 'deepseek',
[SystemProviderIds.openrouter]: 'openrouter'
}
return mapping[provider.id] || provider.id
})
}))
vi.mock('@renderer/config/models', async (importOriginal) => ({
...(await importOriginal()),
isOpenAIModel: vi.fn((model) => model.id.includes('gpt') || model.id.includes('o1')),
isQwenMTModel: vi.fn(() => false),
isSupportFlexServiceTierModel: vi.fn(() => true),
isOpenAILLMModel: vi.fn(() => true),
SYSTEM_MODELS: {
defaultModel: [
{ id: 'default-1', name: 'Default 1' },
{ id: 'default-2', name: 'Default 2' },
{ id: 'default-3', name: 'Default 3' }
]
}
}))
vi.mock(import('@renderer/utils/provider'), async (importOriginal) => {
return {
...(await importOriginal()),
isSupportServiceTierProvider: vi.fn((provider) => {
return [SystemProviderIds.openai, SystemProviderIds.groq].includes(provider.id)
})
}
})
vi.mock('@renderer/store/settings', () => ({
default: (state = { settings: {} }) => state
}))
vi.mock('@renderer/hooks/useSettings', () => ({
getStoreSetting: vi.fn((key) => {
if (key === 'openAI') {
return { summaryText: 'off', verbosity: 'medium' } as any
}
return {}
})
}))
vi.mock('@renderer/services/AssistantService', () => ({
getDefaultAssistant: vi.fn(() => ({
id: 'default',
name: 'Default Assistant',
settings: {}
})),
getAssistantSettings: vi.fn(() => ({
reasoning_effort: 'medium',
maxTokens: 4096
})),
getProviderByModel: vi.fn((model: Model) => ({
id: model.provider,
name: 'Mock Provider'
}))
}))
vi.mock('../reasoning', () => ({
getOpenAIReasoningParams: vi.fn(() => ({ reasoningEffort: 'medium' })),
getAnthropicReasoningParams: vi.fn(() => ({
thinking: { type: 'enabled', budgetTokens: 5000 }
})),
getGeminiReasoningParams: vi.fn(() => ({
thinkingConfig: { include_thoughts: true }
})),
getXAIReasoningParams: vi.fn(() => ({ reasoningEffort: 'high' })),
getBedrockReasoningParams: vi.fn(() => ({
reasoningConfig: { type: 'enabled', budgetTokens: 5000 }
})),
getReasoningEffort: vi.fn(() => ({ reasoningEffort: 'medium' })),
getCustomParameters: vi.fn(() => ({}))
}))
vi.mock('../image', () => ({
buildGeminiGenerateImageParams: vi.fn(() => ({
responseModalities: ['TEXT', 'IMAGE']
}))
}))
vi.mock('../websearch', () => ({
getWebSearchParams: vi.fn(() => ({ enable_search: true }))
}))
const ensureWindowApi = () => {
const globalWindow = window as any
globalWindow.api = globalWindow.api || {}
globalWindow.api.getAppInfo = globalWindow.api.getAppInfo || vi.fn(async () => ({ notesPath: '' }))
}
ensureWindowApi()
describe('options utils', () => {
const mockAssistant: Assistant = {
id: 'test-assistant',
name: 'Test Assistant',
settings: {}
} as Assistant
const mockModel: Model = {
id: 'gpt-4',
name: 'GPT-4',
provider: SystemProviderIds.openai
} as Model
beforeEach(() => {
vi.clearAllMocks()
})
describe('buildProviderOptions', () => {
describe('OpenAI provider', () => {
const openaiProvider: Provider = {
id: SystemProviderIds.openai,
name: 'OpenAI',
type: 'openai-response',
apiKey: 'test-key',
apiHost: 'https://api.openai.com/v1',
isSystem: true
} as Provider
it('should build basic OpenAI options', () => {
const result = buildProviderOptions(mockAssistant, mockModel, openaiProvider, {
enableReasoning: false,
enableWebSearch: false,
enableGenerateImage: false
})
expect(result).toHaveProperty('openai')
expect(result.openai).toBeDefined()
})
it('should include reasoning parameters when enabled', () => {
const result = buildProviderOptions(mockAssistant, mockModel, openaiProvider, {
enableReasoning: true,
enableWebSearch: false,
enableGenerateImage: false
})
expect(result.openai).toHaveProperty('reasoningEffort')
expect(result.openai.reasoningEffort).toBe('medium')
})
it('should include service tier when supported', () => {
const providerWithServiceTier: Provider = {
...openaiProvider,
serviceTier: OpenAIServiceTiers.auto
}
const result = buildProviderOptions(mockAssistant, mockModel, providerWithServiceTier, {
enableReasoning: false,
enableWebSearch: false,
enableGenerateImage: false
})
expect(result.openai).toHaveProperty('serviceTier')
expect(result.openai.serviceTier).toBe(OpenAIServiceTiers.auto)
})
})
describe('Anthropic provider', () => {
const anthropicProvider: Provider = {
id: SystemProviderIds.anthropic,
name: 'Anthropic',
type: 'anthropic',
apiKey: 'test-key',
apiHost: 'https://api.anthropic.com',
isSystem: true
} as Provider
const anthropicModel: Model = {
id: 'claude-3-5-sonnet-20241022',
name: 'Claude 3.5 Sonnet',
provider: SystemProviderIds.anthropic
} as Model
it('should build basic Anthropic options', () => {
const result = buildProviderOptions(mockAssistant, anthropicModel, anthropicProvider, {
enableReasoning: false,
enableWebSearch: false,
enableGenerateImage: false
})
expect(result).toHaveProperty('anthropic')
expect(result.anthropic).toBeDefined()
})
it('should include reasoning parameters when enabled', () => {
const result = buildProviderOptions(mockAssistant, anthropicModel, anthropicProvider, {
enableReasoning: true,
enableWebSearch: false,
enableGenerateImage: false
})
expect(result.anthropic).toHaveProperty('thinking')
expect(result.anthropic.thinking).toEqual({
type: 'enabled',
budgetTokens: 5000
})
})
})
describe('Google provider', () => {
const googleProvider: Provider = {
id: SystemProviderIds.gemini,
name: 'Google',
type: 'gemini',
apiKey: 'test-key',
apiHost: 'https://generativelanguage.googleapis.com',
isSystem: true,
models: [{ id: 'gemini-2.0-flash-exp' }] as Model[]
} as Provider
const googleModel: Model = {
id: 'gemini-2.0-flash-exp',
name: 'Gemini 2.0 Flash',
provider: SystemProviderIds.gemini
} as Model
it('should build basic Google options', () => {
const result = buildProviderOptions(mockAssistant, googleModel, googleProvider, {
enableReasoning: false,
enableWebSearch: false,
enableGenerateImage: false
})
expect(result).toHaveProperty('google')
expect(result.google).toBeDefined()
})
it('should include reasoning parameters when enabled', () => {
const result = buildProviderOptions(mockAssistant, googleModel, googleProvider, {
enableReasoning: true,
enableWebSearch: false,
enableGenerateImage: false
})
expect(result.google).toHaveProperty('thinkingConfig')
expect(result.google.thinkingConfig).toEqual({
include_thoughts: true
})
})
it('should include image generation parameters when enabled', () => {
const result = buildProviderOptions(mockAssistant, googleModel, googleProvider, {
enableReasoning: false,
enableWebSearch: false,
enableGenerateImage: true
})
expect(result.google).toHaveProperty('responseModalities')
expect(result.google.responseModalities).toEqual(['TEXT', 'IMAGE'])
})
})
describe('xAI provider', () => {
const xaiProvider = {
id: SystemProviderIds.grok,
name: 'xAI',
type: 'new-api',
apiKey: 'test-key',
apiHost: 'https://api.x.ai/v1',
isSystem: true,
models: [] as Model[]
} as Provider
const xaiModel: Model = {
id: 'grok-2-latest',
name: 'Grok 2',
provider: SystemProviderIds.grok
} as Model
it('should build basic xAI options', () => {
const result = buildProviderOptions(mockAssistant, xaiModel, xaiProvider, {
enableReasoning: false,
enableWebSearch: false,
enableGenerateImage: false
})
expect(result).toHaveProperty('xai')
expect(result.xai).toBeDefined()
})
it('should include reasoning parameters when enabled', () => {
const result = buildProviderOptions(mockAssistant, xaiModel, xaiProvider, {
enableReasoning: true,
enableWebSearch: false,
enableGenerateImage: false
})
expect(result.xai).toHaveProperty('reasoningEffort')
expect(result.xai.reasoningEffort).toBe('high')
})
})
describe('DeepSeek provider', () => {
const deepseekProvider: Provider = {
id: SystemProviderIds.deepseek,
name: 'DeepSeek',
type: 'openai',
apiKey: 'test-key',
apiHost: 'https://api.deepseek.com',
isSystem: true
} as Provider
const deepseekModel: Model = {
id: 'deepseek-chat',
name: 'DeepSeek Chat',
provider: SystemProviderIds.deepseek
} as Model
it('should build basic DeepSeek options', () => {
const result = buildProviderOptions(mockAssistant, deepseekModel, deepseekProvider, {
enableReasoning: false,
enableWebSearch: false,
enableGenerateImage: false
})
expect(result).toHaveProperty('deepseek')
expect(result.deepseek).toBeDefined()
})
})
describe('OpenRouter provider', () => {
const openrouterProvider: Provider = {
id: SystemProviderIds.openrouter,
name: 'OpenRouter',
type: 'openai',
apiKey: 'test-key',
apiHost: 'https://openrouter.ai/api/v1',
isSystem: true
} as Provider
const openrouterModel: Model = {
id: 'openai/gpt-4',
name: 'GPT-4',
provider: SystemProviderIds.openrouter
} as Model
it('should build basic OpenRouter options', () => {
const result = buildProviderOptions(mockAssistant, openrouterModel, openrouterProvider, {
enableReasoning: false,
enableWebSearch: false,
enableGenerateImage: false
})
expect(result).toHaveProperty('openrouter')
expect(result.openrouter).toBeDefined()
})
it('should include web search parameters when enabled', () => {
const result = buildProviderOptions(mockAssistant, openrouterModel, openrouterProvider, {
enableReasoning: false,
enableWebSearch: true,
enableGenerateImage: false
})
expect(result.openrouter).toHaveProperty('enable_search')
})
})
describe('Custom parameters', () => {
it('should merge custom parameters', async () => {
const { getCustomParameters } = await import('../reasoning')
vi.mocked(getCustomParameters).mockReturnValue({
custom_param: 'custom_value',
another_param: 123
})
const result = buildProviderOptions(
mockAssistant,
mockModel,
{
id: SystemProviderIds.openai,
name: 'OpenAI',
type: 'openai',
apiKey: 'test-key',
apiHost: 'https://api.openai.com/v1'
} as Provider,
{
enableReasoning: false,
enableWebSearch: false,
enableGenerateImage: false
}
)
expect(result.openai).toHaveProperty('custom_param')
expect(result.openai.custom_param).toBe('custom_value')
expect(result.openai).toHaveProperty('another_param')
expect(result.openai.another_param).toBe(123)
})
})
describe('Multiple capabilities', () => {
const googleProvider = {
id: SystemProviderIds.gemini,
name: 'Google',
type: 'gemini',
apiKey: 'test-key',
apiHost: 'https://generativelanguage.googleapis.com',
isSystem: true,
models: [] as Model[]
} as Provider
const googleModel: Model = {
id: 'gemini-2.0-flash-exp',
name: 'Gemini 2.0 Flash',
provider: SystemProviderIds.gemini
} as Model
it('should combine reasoning and image generation', () => {
const result = buildProviderOptions(mockAssistant, googleModel, googleProvider, {
enableReasoning: true,
enableWebSearch: false,
enableGenerateImage: true
})
expect(result.google).toHaveProperty('thinkingConfig')
expect(result.google).toHaveProperty('responseModalities')
})
it('should handle all capabilities enabled', () => {
const result = buildProviderOptions(mockAssistant, googleModel, googleProvider, {
enableReasoning: true,
enableWebSearch: true,
enableGenerateImage: true
})
expect(result.google).toBeDefined()
expect(Object.keys(result.google).length).toBeGreaterThan(0)
})
})
describe('Vertex AI providers', () => {
it('should map google-vertex to google', () => {
const vertexProvider = {
id: 'google-vertex',
name: 'Vertex AI',
type: 'vertexai',
apiKey: 'test-key',
apiHost: 'https://vertex-ai.googleapis.com',
models: [] as Model[]
} as Provider
const vertexModel: Model = {
id: 'gemini-2.0-flash-exp',
name: 'Gemini 2.0 Flash',
provider: 'google-vertex'
} as Model
const result = buildProviderOptions(mockAssistant, vertexModel, vertexProvider, {
enableReasoning: false,
enableWebSearch: false,
enableGenerateImage: false
})
expect(result).toHaveProperty('google')
})
it('should map google-vertex-anthropic to anthropic', () => {
const vertexAnthropicProvider = {
id: 'google-vertex-anthropic',
name: 'Vertex AI Anthropic',
type: 'vertex-anthropic',
apiKey: 'test-key',
apiHost: 'https://vertex-ai.googleapis.com',
models: [] as Model[]
} as Provider
const vertexModel: Model = {
id: 'claude-3-5-sonnet-20241022',
name: 'Claude 3.5 Sonnet',
provider: 'google-vertex-anthropic'
} as Model
const result = buildProviderOptions(mockAssistant, vertexModel, vertexAnthropicProvider, {
enableReasoning: false,
enableWebSearch: false,
enableGenerateImage: false
})
expect(result).toHaveProperty('anthropic')
})
})
})
})

View File

@@ -0,0 +1,967 @@
/**
* reasoning.ts Unit Tests
* Tests for reasoning parameter generation utilities
*/
import { getStoreSetting } from '@renderer/hooks/useSettings'
import type { SettingsState } from '@renderer/store/settings'
import type { Assistant, Model, Provider } from '@renderer/types'
import { SystemProviderIds } from '@renderer/types'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import {
getAnthropicReasoningParams,
getBedrockReasoningParams,
getCustomParameters,
getGeminiReasoningParams,
getOpenAIReasoningParams,
getReasoningEffort,
getXAIReasoningParams
} from '../reasoning'
function defaultGetStoreSetting<K extends keyof SettingsState>(key: K): SettingsState[K] {
if (key === 'openAI') {
return {
summaryText: 'auto',
verbosity: 'medium'
} as SettingsState[K]
}
return undefined as SettingsState[K]
}
// Mock dependencies
vi.mock('@logger', () => ({
loggerService: {
withContext: () => ({
debug: vi.fn(),
error: vi.fn(),
warn: vi.fn(),
info: vi.fn()
})
}
}))
vi.mock('@renderer/store/settings', () => ({
default: (state = { settings: {} }) => state
}))
vi.mock('@renderer/store/llm', () => ({
initialState: {},
default: (state = { llm: {} }) => state
}))
vi.mock('@renderer/config/constant', () => ({
DEFAULT_MAX_TOKENS: 4096,
isMac: false,
isWin: false,
TOKENFLUX_HOST: 'mock-host'
}))
vi.mock('@renderer/utils/provider', () => ({
isSupportEnableThinkingProvider: vi.fn((provider) => {
return [SystemProviderIds.dashscope, SystemProviderIds.silicon].includes(provider.id)
})
}))
vi.mock('@renderer/config/models', async (importOriginal) => {
const actual: any = await importOriginal()
return {
...actual,
isReasoningModel: vi.fn(() => false),
isOpenAIDeepResearchModel: vi.fn(() => false),
isOpenAIModel: vi.fn(() => false),
isSupportedReasoningEffortOpenAIModel: vi.fn(() => false),
isSupportedThinkingTokenQwenModel: vi.fn(() => false),
isQwenReasoningModel: vi.fn(() => false),
isSupportedThinkingTokenClaudeModel: vi.fn(() => false),
isSupportedThinkingTokenGeminiModel: vi.fn(() => false),
isSupportedThinkingTokenDoubaoModel: vi.fn(() => false),
isSupportedThinkingTokenZhipuModel: vi.fn(() => false),
isSupportedReasoningEffortModel: vi.fn(() => false),
isDeepSeekHybridInferenceModel: vi.fn(() => false),
isSupportedReasoningEffortGrokModel: vi.fn(() => false),
getThinkModelType: vi.fn(() => 'default'),
isDoubaoSeedAfter251015: vi.fn(() => false),
isDoubaoThinkingAutoModel: vi.fn(() => false),
isGrok4FastReasoningModel: vi.fn(() => false),
isGrokReasoningModel: vi.fn(() => false),
isOpenAIReasoningModel: vi.fn(() => false),
isQwenAlwaysThinkModel: vi.fn(() => false),
isSupportedThinkingTokenHunyuanModel: vi.fn(() => false),
isSupportedThinkingTokenModel: vi.fn(() => false),
isGPT51SeriesModel: vi.fn(() => false)
}
})
vi.mock('@renderer/hooks/useSettings', () => ({
getStoreSetting: vi.fn(defaultGetStoreSetting)
}))
vi.mock('@renderer/services/AssistantService', () => ({
getAssistantSettings: vi.fn((assistant) => ({
maxTokens: assistant?.settings?.maxTokens || 4096,
reasoning_effort: assistant?.settings?.reasoning_effort
})),
getProviderByModel: vi.fn((model) => ({
id: model.provider,
name: 'Test Provider'
})),
getDefaultAssistant: vi.fn(() => ({
id: 'default',
name: 'Default Assistant',
settings: {}
}))
}))
const ensureWindowApi = () => {
const globalWindow = window as any
globalWindow.api = globalWindow.api || {}
globalWindow.api.getAppInfo = globalWindow.api.getAppInfo || vi.fn(async () => ({ notesPath: '' }))
}
ensureWindowApi()
describe('reasoning utils', () => {
beforeEach(() => {
vi.resetAllMocks()
})
describe('getReasoningEffort', () => {
it('should return empty object for non-reasoning model', async () => {
const model: Model = {
id: 'gpt-4',
name: 'GPT-4',
provider: SystemProviderIds.openai
} as Model
const assistant: Assistant = {
id: 'test',
name: 'Test',
settings: {}
} as Assistant
const result = getReasoningEffort(assistant, model)
expect(result).toEqual({})
})
it('should disable reasoning for OpenRouter when no reasoning effort set', async () => {
const { isReasoningModel } = await import('@renderer/config/models')
vi.mocked(isReasoningModel).mockReturnValue(true)
const model: Model = {
id: 'anthropic/claude-sonnet-4',
name: 'Claude Sonnet 4',
provider: SystemProviderIds.openrouter
} as Model
const assistant: Assistant = {
id: 'test',
name: 'Test',
settings: {}
} as Assistant
const result = getReasoningEffort(assistant, model)
expect(result).toEqual({ reasoning: { enabled: false, exclude: true } })
})
it('should handle Qwen models with enable_thinking', async () => {
const { isReasoningModel, isSupportedThinkingTokenQwenModel, isQwenReasoningModel } = await import(
'@renderer/config/models'
)
vi.mocked(isReasoningModel).mockReturnValue(true)
vi.mocked(isSupportedThinkingTokenQwenModel).mockReturnValue(true)
vi.mocked(isQwenReasoningModel).mockReturnValue(true)
const model: Model = {
id: 'qwen-plus',
name: 'Qwen Plus',
provider: SystemProviderIds.dashscope
} as Model
const assistant: Assistant = {
id: 'test',
name: 'Test',
settings: {
reasoning_effort: 'medium'
}
} as Assistant
const result = getReasoningEffort(assistant, model)
expect(result).toHaveProperty('enable_thinking')
})
it('should handle Claude models with thinking config', async () => {
const {
isSupportedThinkingTokenClaudeModel,
isReasoningModel,
isQwenReasoningModel,
isSupportedThinkingTokenGeminiModel,
isSupportedThinkingTokenDoubaoModel,
isSupportedThinkingTokenZhipuModel,
isSupportedReasoningEffortModel
} = await import('@renderer/config/models')
vi.mocked(isReasoningModel).mockReturnValue(true)
vi.mocked(isSupportedThinkingTokenClaudeModel).mockReturnValue(true)
vi.mocked(isQwenReasoningModel).mockReturnValue(false)
vi.mocked(isSupportedThinkingTokenGeminiModel).mockReturnValue(false)
vi.mocked(isSupportedThinkingTokenDoubaoModel).mockReturnValue(false)
vi.mocked(isSupportedThinkingTokenZhipuModel).mockReturnValue(false)
vi.mocked(isSupportedReasoningEffortModel).mockReturnValue(false)
const model: Model = {
id: 'claude-3-7-sonnet',
name: 'Claude 3.7 Sonnet',
provider: SystemProviderIds.anthropic
} as Model
const assistant: Assistant = {
id: 'test',
name: 'Test',
settings: {
reasoning_effort: 'high',
maxTokens: 4096
}
} as Assistant
const result = getReasoningEffort(assistant, model)
expect(result).toEqual({
thinking: {
type: 'enabled',
budget_tokens: expect.any(Number)
}
})
})
it('should handle Gemini Flash models with thinking budget 0', async () => {
const {
isSupportedThinkingTokenGeminiModel,
isReasoningModel,
isQwenReasoningModel,
isSupportedThinkingTokenClaudeModel,
isSupportedThinkingTokenDoubaoModel,
isSupportedThinkingTokenZhipuModel,
isOpenAIDeepResearchModel,
isSupportedThinkingTokenQwenModel,
isSupportedThinkingTokenHunyuanModel,
isDeepSeekHybridInferenceModel
} = await import('@renderer/config/models')
vi.mocked(isReasoningModel).mockReturnValue(true)
vi.mocked(isOpenAIDeepResearchModel).mockReturnValue(false)
vi.mocked(isSupportedThinkingTokenGeminiModel).mockReturnValue(true)
vi.mocked(isQwenReasoningModel).mockReturnValue(false)
vi.mocked(isSupportedThinkingTokenClaudeModel).mockReturnValue(false)
vi.mocked(isSupportedThinkingTokenDoubaoModel).mockReturnValue(false)
vi.mocked(isSupportedThinkingTokenZhipuModel).mockReturnValue(false)
vi.mocked(isSupportedThinkingTokenQwenModel).mockReturnValue(false)
vi.mocked(isSupportedThinkingTokenHunyuanModel).mockReturnValue(false)
vi.mocked(isDeepSeekHybridInferenceModel).mockReturnValue(false)
const model: Model = {
id: 'gemini-2.5-flash',
name: 'Gemini 2.5 Flash',
provider: SystemProviderIds.openai
} as Model
const assistant: Assistant = {
id: 'test',
name: 'Test',
settings: {}
} as Assistant
const result = getReasoningEffort(assistant, model)
expect(result).toEqual({
extra_body: {
google: {
thinking_config: {
thinking_budget: 0
}
}
}
})
})
it('should handle GPT-5.1 reasoning model with effort levels', async () => {
const {
isReasoningModel,
isOpenAIDeepResearchModel,
isSupportedReasoningEffortModel,
isGPT51SeriesModel,
getThinkModelType
} = await import('@renderer/config/models')
vi.mocked(isReasoningModel).mockReturnValue(true)
vi.mocked(isOpenAIDeepResearchModel).mockReturnValue(false)
vi.mocked(isSupportedReasoningEffortModel).mockReturnValue(true)
vi.mocked(getThinkModelType).mockReturnValue('gpt5_1')
vi.mocked(isGPT51SeriesModel).mockReturnValue(true)
const model: Model = {
id: 'gpt-5.1',
name: 'GPT-5.1',
provider: SystemProviderIds.openai
} as Model
const assistant: Assistant = {
id: 'test',
name: 'Test',
settings: {
reasoning_effort: 'none'
}
} as Assistant
const result = getReasoningEffort(assistant, model)
expect(result).toEqual({
reasoningEffort: 'none'
})
})
it('should handle DeepSeek hybrid inference models', async () => {
const { isReasoningModel, isDeepSeekHybridInferenceModel } = await import('@renderer/config/models')
vi.mocked(isReasoningModel).mockReturnValue(true)
vi.mocked(isDeepSeekHybridInferenceModel).mockReturnValue(true)
const model: Model = {
id: 'deepseek-v3.1',
name: 'DeepSeek V3.1',
provider: SystemProviderIds.silicon
} as Model
const assistant: Assistant = {
id: 'test',
name: 'Test',
settings: {
reasoning_effort: 'high'
}
} as Assistant
const result = getReasoningEffort(assistant, model)
expect(result).toEqual({
enable_thinking: true
})
})
it('should return medium effort for deep research models', async () => {
const { isReasoningModel, isOpenAIDeepResearchModel } = await import('@renderer/config/models')
vi.mocked(isReasoningModel).mockReturnValue(true)
vi.mocked(isOpenAIDeepResearchModel).mockReturnValue(true)
const model: Model = {
id: 'o3-deep-research',
provider: SystemProviderIds.openai
} as Model
const assistant: Assistant = {
id: 'test',
name: 'Test',
settings: {}
} as Assistant
const result = getReasoningEffort(assistant, model)
expect(result).toEqual({ reasoning_effort: 'medium' })
})
it('should return empty for groq provider', async () => {
const { getProviderByModel } = await import('@renderer/services/AssistantService')
vi.mocked(getProviderByModel).mockReturnValue({
id: 'groq',
name: 'Groq'
} as Provider)
const model: Model = {
id: 'groq-model',
name: 'Groq Model',
provider: 'groq'
} as Model
const assistant: Assistant = {
id: 'test',
name: 'Test',
settings: {}
} as Assistant
const result = getReasoningEffort(assistant, model)
expect(result).toEqual({})
})
})
describe('getOpenAIReasoningParams', () => {
it('should return empty object for non-reasoning model', async () => {
const model: Model = {
id: 'gpt-4',
name: 'GPT-4',
provider: SystemProviderIds.openai
} as Model
const assistant: Assistant = {
id: 'test',
name: 'Test',
settings: {}
} as Assistant
const result = getOpenAIReasoningParams(assistant, model)
expect(result).toEqual({})
})
it('should return empty when no reasoning effort set', async () => {
const model: Model = {
id: 'o1-preview',
name: 'O1 Preview',
provider: SystemProviderIds.openai
} as Model
const assistant: Assistant = {
id: 'test',
name: 'Test',
settings: {}
} as Assistant
const result = getOpenAIReasoningParams(assistant, model)
expect(result).toEqual({})
})
it('should return reasoning effort for OpenAI models', async () => {
const { isReasoningModel, isOpenAIModel, isSupportedReasoningEffortOpenAIModel } = await import(
'@renderer/config/models'
)
vi.mocked(isReasoningModel).mockReturnValue(true)
vi.mocked(isOpenAIModel).mockReturnValue(true)
vi.mocked(isSupportedReasoningEffortOpenAIModel).mockReturnValue(true)
const model: Model = {
id: 'gpt-5.1',
name: 'GPT 5.1',
provider: SystemProviderIds.openai
} as Model
const assistant: Assistant = {
id: 'test',
name: 'Test',
settings: {
reasoning_effort: 'high'
}
} as Assistant
const result = getOpenAIReasoningParams(assistant, model)
expect(result).toEqual({
reasoningEffort: 'high',
reasoningSummary: 'auto'
})
})
it('should include reasoning summary when not o1-pro', async () => {
const { isReasoningModel, isOpenAIModel, isSupportedReasoningEffortOpenAIModel } = await import(
'@renderer/config/models'
)
vi.mocked(isReasoningModel).mockReturnValue(true)
vi.mocked(isOpenAIModel).mockReturnValue(true)
vi.mocked(isSupportedReasoningEffortOpenAIModel).mockReturnValue(true)
const model: Model = {
id: 'gpt-5',
provider: SystemProviderIds.openai
} as Model
const assistant: Assistant = {
id: 'test',
name: 'Test',
settings: {
reasoning_effort: 'medium'
}
} as Assistant
const result = getOpenAIReasoningParams(assistant, model)
expect(result).toEqual({
reasoningEffort: 'medium',
reasoningSummary: 'auto'
})
})
it('should not include reasoning summary for o1-pro', async () => {
const { isReasoningModel, isOpenAIDeepResearchModel, isSupportedReasoningEffortOpenAIModel } = await import(
'@renderer/config/models'
)
vi.mocked(isReasoningModel).mockReturnValue(true)
vi.mocked(isOpenAIDeepResearchModel).mockReturnValue(false)
vi.mocked(isSupportedReasoningEffortOpenAIModel).mockReturnValue(true)
vi.mocked(getStoreSetting).mockReturnValue({ summaryText: 'off' } as any)
const model: Model = {
id: 'o1-pro',
name: 'O1 Pro',
provider: SystemProviderIds.openai
} as Model
const assistant: Assistant = {
id: 'test',
name: 'Test',
settings: {
reasoning_effort: 'high'
}
} as Assistant
const result = getOpenAIReasoningParams(assistant, model)
expect(result).toEqual({
reasoningEffort: 'high',
reasoningSummary: undefined
})
})
it('should force medium effort for deep research models', async () => {
const { isReasoningModel, isOpenAIModel, isOpenAIDeepResearchModel, isSupportedReasoningEffortOpenAIModel } =
await import('@renderer/config/models')
const { getStoreSetting } = await import('@renderer/hooks/useSettings')
vi.mocked(isReasoningModel).mockReturnValue(true)
vi.mocked(isOpenAIModel).mockReturnValue(true)
vi.mocked(isOpenAIDeepResearchModel).mockReturnValue(true)
vi.mocked(isSupportedReasoningEffortOpenAIModel).mockReturnValue(true)
vi.mocked(getStoreSetting).mockReturnValue({ summaryText: 'off' } as any)
const model: Model = {
id: 'o3-deep-research',
name: 'O3 Mini',
provider: SystemProviderIds.openai
} as Model
const assistant: Assistant = {
id: 'test',
name: 'Test',
settings: {
reasoning_effort: 'high'
}
} as Assistant
const result = getOpenAIReasoningParams(assistant, model)
expect(result).toEqual({
reasoningEffort: 'medium',
reasoningSummary: 'off'
})
})
})
describe('getAnthropicReasoningParams', () => {
it('should return empty for non-reasoning model', async () => {
const { isReasoningModel } = await import('@renderer/config/models')
vi.mocked(isReasoningModel).mockReturnValue(false)
const model: Model = {
id: 'claude-3-5-sonnet',
name: 'Claude 3.5 Sonnet',
provider: SystemProviderIds.anthropic
} as Model
const assistant: Assistant = {
id: 'test',
name: 'Test',
settings: {}
} as Assistant
const result = getAnthropicReasoningParams(assistant, model)
expect(result).toEqual({})
})
it('should return disabled thinking when no reasoning effort', async () => {
const { isReasoningModel, isSupportedThinkingTokenClaudeModel } = await import('@renderer/config/models')
vi.mocked(isReasoningModel).mockReturnValue(true)
vi.mocked(isSupportedThinkingTokenClaudeModel).mockReturnValue(false)
const model: Model = {
id: 'claude-3-7-sonnet',
name: 'Claude 3.7 Sonnet',
provider: SystemProviderIds.anthropic
} as Model
const assistant: Assistant = {
id: 'test',
name: 'Test',
settings: {}
} as Assistant
const result = getAnthropicReasoningParams(assistant, model)
expect(result).toEqual({
thinking: {
type: 'disabled'
}
})
})
it('should return enabled thinking with budget for Claude models', async () => {
const { isReasoningModel, isSupportedThinkingTokenClaudeModel } = await import('@renderer/config/models')
vi.mocked(isReasoningModel).mockReturnValue(true)
vi.mocked(isSupportedThinkingTokenClaudeModel).mockReturnValue(true)
const model: Model = {
id: 'claude-3-7-sonnet',
name: 'Claude 3.7 Sonnet',
provider: SystemProviderIds.anthropic
} as Model
const assistant: Assistant = {
id: 'test',
name: 'Test',
settings: {
reasoning_effort: 'medium',
maxTokens: 4096
}
} as Assistant
const result = getAnthropicReasoningParams(assistant, model)
expect(result).toEqual({
thinking: {
type: 'enabled',
budgetTokens: 2048
}
})
})
})
describe('getGeminiReasoningParams', () => {
it('should return empty for non-reasoning model', async () => {
const { isReasoningModel } = await import('@renderer/config/models')
vi.mocked(isReasoningModel).mockReturnValue(false)
const model: Model = {
id: 'gemini-2.0-flash',
name: 'Gemini 2.0 Flash',
provider: SystemProviderIds.gemini
} as Model
const assistant: Assistant = {
id: 'test',
name: 'Test',
settings: {}
} as Assistant
const result = getGeminiReasoningParams(assistant, model)
expect(result).toEqual({})
})
it('should disable thinking for Flash models without reasoning effort', async () => {
const { isReasoningModel, isSupportedThinkingTokenGeminiModel } = await import('@renderer/config/models')
vi.mocked(isReasoningModel).mockReturnValue(true)
vi.mocked(isSupportedThinkingTokenGeminiModel).mockReturnValue(true)
const model: Model = {
id: 'gemini-2.5-flash',
name: 'Gemini 2.5 Flash',
provider: SystemProviderIds.gemini
} as Model
const assistant: Assistant = {
id: 'test',
name: 'Test',
settings: {}
} as Assistant
const result = getGeminiReasoningParams(assistant, model)
expect(result).toEqual({
thinkingConfig: {
includeThoughts: false,
thinkingBudget: 0
}
})
})
it('should enable thinking with budget for reasoning effort', async () => {
const { isReasoningModel, isSupportedThinkingTokenGeminiModel } = await import('@renderer/config/models')
vi.mocked(isReasoningModel).mockReturnValue(true)
vi.mocked(isSupportedThinkingTokenGeminiModel).mockReturnValue(true)
const model: Model = {
id: 'gemini-2.5-pro',
name: 'Gemini 2.5 Pro',
provider: SystemProviderIds.gemini
} as Model
const assistant: Assistant = {
id: 'test',
name: 'Test',
settings: {
reasoning_effort: 'medium'
}
} as Assistant
const result = getGeminiReasoningParams(assistant, model)
expect(result).toEqual({
thinkingConfig: {
thinkingBudget: 16448,
includeThoughts: true
}
})
})
it('should enable thinking without budget for auto effort ratio > 1', async () => {
const { isReasoningModel, isSupportedThinkingTokenGeminiModel } = await import('@renderer/config/models')
vi.mocked(isReasoningModel).mockReturnValue(true)
vi.mocked(isSupportedThinkingTokenGeminiModel).mockReturnValue(true)
const model: Model = {
id: 'gemini-2.5-pro',
name: 'Gemini 2.5 Pro',
provider: SystemProviderIds.gemini
} as Model
const assistant: Assistant = {
id: 'test',
name: 'Test',
settings: {
reasoning_effort: 'auto'
}
} as Assistant
const result = getGeminiReasoningParams(assistant, model)
expect(result).toEqual({
thinkingConfig: {
includeThoughts: true
}
})
})
})
describe('getXAIReasoningParams', () => {
it('should return empty for non-Grok model', async () => {
const { isSupportedReasoningEffortGrokModel } = await import('@renderer/config/models')
vi.mocked(isSupportedReasoningEffortGrokModel).mockReturnValue(false)
const model: Model = {
id: 'other-model',
name: 'Other Model',
provider: SystemProviderIds.grok
} as Model
const assistant: Assistant = {
id: 'test',
name: 'Test',
settings: {}
} as Assistant
const result = getXAIReasoningParams(assistant, model)
expect(result).toEqual({})
})
it('should return empty when no reasoning effort', async () => {
const { isSupportedReasoningEffortGrokModel } = await import('@renderer/config/models')
vi.mocked(isSupportedReasoningEffortGrokModel).mockReturnValue(true)
const model: Model = {
id: 'grok-2',
name: 'Grok 2',
provider: SystemProviderIds.grok
} as Model
const assistant: Assistant = {
id: 'test',
name: 'Test',
settings: {}
} as Assistant
const result = getXAIReasoningParams(assistant, model)
expect(result).toEqual({})
})
it('should return reasoning effort for Grok models', async () => {
const { isSupportedReasoningEffortGrokModel } = await import('@renderer/config/models')
vi.mocked(isSupportedReasoningEffortGrokModel).mockReturnValue(true)
const model: Model = {
id: 'grok-3',
name: 'Grok 3',
provider: SystemProviderIds.grok
} as Model
const assistant: Assistant = {
id: 'test',
name: 'Test',
settings: {
reasoning_effort: 'high'
}
} as Assistant
const result = getXAIReasoningParams(assistant, model)
expect(result).toHaveProperty('reasoningEffort')
expect(result.reasoningEffort).toBe('high')
})
})
describe('getBedrockReasoningParams', () => {
it('should return empty for non-reasoning model', async () => {
const model: Model = {
id: 'other-model',
name: 'Other Model',
provider: 'bedrock'
} as Model
const assistant: Assistant = {
id: 'test',
name: 'Test',
settings: {}
} as Assistant
const result = getBedrockReasoningParams(assistant, model)
expect(result).toEqual({})
})
it('should return empty when no reasoning effort', async () => {
const model: Model = {
id: 'claude-3-7-sonnet',
name: 'Claude 3.7 Sonnet',
provider: 'bedrock'
} as Model
const assistant: Assistant = {
id: 'test',
name: 'Test',
settings: {}
} as Assistant
const result = getBedrockReasoningParams(assistant, model)
expect(result).toEqual({})
})
it('should return reasoning config for Claude models on Bedrock', async () => {
const { isReasoningModel, isSupportedThinkingTokenClaudeModel } = await import('@renderer/config/models')
vi.mocked(isReasoningModel).mockReturnValue(true)
vi.mocked(isSupportedThinkingTokenClaudeModel).mockReturnValue(true)
const model: Model = {
id: 'claude-3-7-sonnet',
name: 'Claude 3.7 Sonnet',
provider: 'bedrock'
} as Model
const assistant: Assistant = {
id: 'test',
name: 'Test',
settings: {
reasoning_effort: 'medium',
maxTokens: 4096
}
} as Assistant
const result = getBedrockReasoningParams(assistant, model)
expect(result).toEqual({
reasoningConfig: {
type: 'enabled',
budgetTokens: 2048
}
})
})
})
describe('getCustomParameters', () => {
it('should return empty object when no custom parameters', async () => {
const assistant: Assistant = {
id: 'test',
name: 'Test',
settings: {}
} as Assistant
const result = getCustomParameters(assistant)
expect(result).toEqual({})
})
it('should return custom parameters as key-value pairs', async () => {
const assistant: Assistant = {
id: 'test',
name: 'Test',
settings: {
customParameters: [
{ name: 'param1', value: 'value1', type: 'string' },
{ name: 'param2', value: 123, type: 'number' }
]
}
} as Assistant
const result = getCustomParameters(assistant)
expect(result).toEqual({
param1: 'value1',
param2: 123
})
})
it('should parse JSON type parameters', async () => {
const assistant: Assistant = {
id: 'test',
name: 'Test',
settings: {
customParameters: [{ name: 'config', value: '{"key": "value"}', type: 'json' }]
}
} as Assistant
const result = getCustomParameters(assistant)
expect(result).toEqual({
config: { key: 'value' }
})
})
it('should handle invalid JSON gracefully', async () => {
const assistant: Assistant = {
id: 'test',
name: 'Test',
settings: {
customParameters: [{ name: 'invalid', value: '{invalid json', type: 'json' }]
}
} as Assistant
const result = getCustomParameters(assistant)
expect(result).toEqual({
invalid: '{invalid json'
})
})
it('should handle undefined JSON value', async () => {
const assistant: Assistant = {
id: 'test',
name: 'Test',
settings: {
customParameters: [{ name: 'undef', value: 'undefined', type: 'json' }]
}
} as Assistant
const result = getCustomParameters(assistant)
expect(result).toEqual({
undef: undefined
})
})
it('should skip parameters with empty names', async () => {
const assistant: Assistant = {
id: 'test',
name: 'Test',
settings: {
customParameters: [
{ name: '', value: 'value1', type: 'string' },
{ name: ' ', value: 'value2', type: 'string' },
{ name: 'valid', value: 'value3', type: 'string' }
]
}
} as Assistant
const result = getCustomParameters(assistant)
expect(result).toEqual({
valid: 'value3'
})
})
})
})

View File

@@ -0,0 +1,384 @@
/**
* websearch.ts Unit Tests
* Tests for web search parameters generation utilities
*/
import type { CherryWebSearchConfig } from '@renderer/store/websearch'
import type { Model } from '@renderer/types'
import { describe, expect, it, vi } from 'vitest'
import { buildProviderBuiltinWebSearchConfig, getWebSearchParams } from '../websearch'
// Mock dependencies
vi.mock('@renderer/config/models', () => ({
isOpenAIWebSearchChatCompletionOnlyModel: vi.fn((model) => model?.id?.includes('o1-pro') ?? false),
isOpenAIDeepResearchModel: vi.fn((model) => model?.id?.includes('o3-mini') ?? false)
}))
vi.mock('@renderer/utils/blacklistMatchPattern', () => ({
mapRegexToPatterns: vi.fn((patterns) => patterns || [])
}))
describe('websearch utils', () => {
describe('getWebSearchParams', () => {
it('should return enhancement params for hunyuan provider', () => {
const model: Model = {
id: 'hunyuan-model',
name: 'Hunyuan Model',
provider: 'hunyuan'
} as Model
const result = getWebSearchParams(model)
expect(result).toEqual({
enable_enhancement: true,
citation: true,
search_info: true
})
})
it('should return search params for dashscope provider', () => {
const model: Model = {
id: 'qwen-model',
name: 'Qwen Model',
provider: 'dashscope'
} as Model
const result = getWebSearchParams(model)
expect(result).toEqual({
enable_search: true,
search_options: {
forced_search: true
}
})
})
it('should return web_search_options for OpenAI web search models', () => {
const model: Model = {
id: 'o1-pro',
name: 'O1 Pro',
provider: 'openai'
} as Model
const result = getWebSearchParams(model)
expect(result).toEqual({
web_search_options: {}
})
})
it('should return empty object for other providers', () => {
const model: Model = {
id: 'gpt-4',
name: 'GPT-4',
provider: 'openai'
} as Model
const result = getWebSearchParams(model)
expect(result).toEqual({})
})
it('should return empty object for custom provider', () => {
const model: Model = {
id: 'custom-model',
name: 'Custom Model',
provider: 'custom-provider'
} as Model
const result = getWebSearchParams(model)
expect(result).toEqual({})
})
})
describe('buildProviderBuiltinWebSearchConfig', () => {
const defaultWebSearchConfig: CherryWebSearchConfig = {
searchWithTime: true,
maxResults: 50,
excludeDomains: []
}
describe('openai provider', () => {
it('should return low search context size for low maxResults', () => {
const config: CherryWebSearchConfig = {
searchWithTime: true,
maxResults: 20,
excludeDomains: []
}
const result = buildProviderBuiltinWebSearchConfig('openai', config)
expect(result).toEqual({
openai: {
searchContextSize: 'low'
}
})
})
it('should return medium search context size for medium maxResults', () => {
const config: CherryWebSearchConfig = {
searchWithTime: true,
maxResults: 50,
excludeDomains: []
}
const result = buildProviderBuiltinWebSearchConfig('openai', config)
expect(result).toEqual({
openai: {
searchContextSize: 'medium'
}
})
})
it('should return high search context size for high maxResults', () => {
const config: CherryWebSearchConfig = {
searchWithTime: true,
maxResults: 80,
excludeDomains: []
}
const result = buildProviderBuiltinWebSearchConfig('openai', config)
expect(result).toEqual({
openai: {
searchContextSize: 'high'
}
})
})
it('should use medium for deep research models regardless of maxResults', () => {
const config: CherryWebSearchConfig = {
searchWithTime: true,
maxResults: 100,
excludeDomains: []
}
const model: Model = {
id: 'o3-mini',
name: 'O3 Mini',
provider: 'openai'
} as Model
const result = buildProviderBuiltinWebSearchConfig('openai', config, model)
expect(result).toEqual({
openai: {
searchContextSize: 'medium'
}
})
})
})
describe('openai-chat provider', () => {
it('should return correct search context size', () => {
const config: CherryWebSearchConfig = {
searchWithTime: true,
maxResults: 50,
excludeDomains: []
}
const result = buildProviderBuiltinWebSearchConfig('openai-chat', config)
expect(result).toEqual({
'openai-chat': {
searchContextSize: 'medium'
}
})
})
it('should handle deep research models', () => {
const config: CherryWebSearchConfig = {
searchWithTime: true,
maxResults: 100,
excludeDomains: []
}
const model: Model = {
id: 'o3-mini',
name: 'O3 Mini',
provider: 'openai'
} as Model
const result = buildProviderBuiltinWebSearchConfig('openai-chat', config, model)
expect(result).toEqual({
'openai-chat': {
searchContextSize: 'medium'
}
})
})
})
describe('anthropic provider', () => {
it('should return anthropic search options with maxUses', () => {
const result = buildProviderBuiltinWebSearchConfig('anthropic', defaultWebSearchConfig)
expect(result).toEqual({
anthropic: {
maxUses: 50,
blockedDomains: undefined
}
})
})
it('should include blockedDomains when excludeDomains provided', () => {
const config: CherryWebSearchConfig = {
searchWithTime: true,
maxResults: 30,
excludeDomains: ['example.com', 'test.com']
}
const result = buildProviderBuiltinWebSearchConfig('anthropic', config)
expect(result).toEqual({
anthropic: {
maxUses: 30,
blockedDomains: ['example.com', 'test.com']
}
})
})
it('should not include blockedDomains when empty', () => {
const result = buildProviderBuiltinWebSearchConfig('anthropic', defaultWebSearchConfig)
expect(result).toEqual({
anthropic: {
maxUses: 50,
blockedDomains: undefined
}
})
})
})
describe('xai provider', () => {
it('should return xai search options', () => {
const result = buildProviderBuiltinWebSearchConfig('xai', defaultWebSearchConfig)
expect(result).toEqual({
xai: {
maxSearchResults: 50,
returnCitations: true,
sources: [{ type: 'web', excludedWebsites: [] }, { type: 'news' }, { type: 'x' }],
mode: 'on'
}
})
})
it('should limit excluded websites to 5', () => {
const config: CherryWebSearchConfig = {
searchWithTime: true,
maxResults: 40,
excludeDomains: ['site1.com', 'site2.com', 'site3.com', 'site4.com', 'site5.com', 'site6.com', 'site7.com']
}
const result = buildProviderBuiltinWebSearchConfig('xai', config)
expect(result?.xai?.sources).toBeDefined()
const webSource = result?.xai?.sources?.[0]
if (webSource && webSource.type === 'web') {
expect(webSource.excludedWebsites).toHaveLength(5)
}
})
it('should include all sources types', () => {
const result = buildProviderBuiltinWebSearchConfig('xai', defaultWebSearchConfig)
expect(result?.xai?.sources).toHaveLength(3)
expect(result?.xai?.sources?.[0].type).toBe('web')
expect(result?.xai?.sources?.[1].type).toBe('news')
expect(result?.xai?.sources?.[2].type).toBe('x')
})
})
describe('openrouter provider', () => {
it('should return openrouter plugins config', () => {
const result = buildProviderBuiltinWebSearchConfig('openrouter', defaultWebSearchConfig)
expect(result).toEqual({
openrouter: {
plugins: [
{
id: 'web',
max_results: 50
}
]
}
})
})
it('should respect custom maxResults', () => {
const config: CherryWebSearchConfig = {
searchWithTime: true,
maxResults: 75,
excludeDomains: []
}
const result = buildProviderBuiltinWebSearchConfig('openrouter', config)
expect(result).toEqual({
openrouter: {
plugins: [
{
id: 'web',
max_results: 75
}
]
}
})
})
})
describe('unsupported provider', () => {
it('should return empty object for unsupported provider', () => {
const result = buildProviderBuiltinWebSearchConfig('unsupported' as any, defaultWebSearchConfig)
expect(result).toEqual({})
})
it('should return empty object for google provider', () => {
const result = buildProviderBuiltinWebSearchConfig('google', defaultWebSearchConfig)
expect(result).toEqual({})
})
})
describe('edge cases', () => {
it('should handle maxResults at boundary values', () => {
// Test boundary at 33 (low/medium)
const config33: CherryWebSearchConfig = { searchWithTime: true, maxResults: 33, excludeDomains: [] }
const result33 = buildProviderBuiltinWebSearchConfig('openai', config33)
expect(result33?.openai?.searchContextSize).toBe('low')
// Test boundary at 34 (medium)
const config34: CherryWebSearchConfig = { searchWithTime: true, maxResults: 34, excludeDomains: [] }
const result34 = buildProviderBuiltinWebSearchConfig('openai', config34)
expect(result34?.openai?.searchContextSize).toBe('medium')
// Test boundary at 66 (medium)
const config66: CherryWebSearchConfig = { searchWithTime: true, maxResults: 66, excludeDomains: [] }
const result66 = buildProviderBuiltinWebSearchConfig('openai', config66)
expect(result66?.openai?.searchContextSize).toBe('medium')
// Test boundary at 67 (high)
const config67: CherryWebSearchConfig = { searchWithTime: true, maxResults: 67, excludeDomains: [] }
const result67 = buildProviderBuiltinWebSearchConfig('openai', config67)
expect(result67?.openai?.searchContextSize).toBe('high')
})
it('should handle zero maxResults', () => {
const config: CherryWebSearchConfig = { searchWithTime: true, maxResults: 0, excludeDomains: [] }
const result = buildProviderBuiltinWebSearchConfig('openai', config)
expect(result?.openai?.searchContextSize).toBe('low')
})
it('should handle very large maxResults', () => {
const config: CherryWebSearchConfig = { searchWithTime: true, maxResults: 1000, excludeDomains: [] }
const result = buildProviderBuiltinWebSearchConfig('openai', config)
expect(result?.openai?.searchContextSize).toBe('high')
})
})
})
})

View File

@@ -1,16 +1,38 @@
import type { BedrockProviderOptions } from '@ai-sdk/amazon-bedrock'
import type { AnthropicProviderOptions } from '@ai-sdk/anthropic'
import type { GoogleGenerativeAIProviderOptions } from '@ai-sdk/google'
import type { OpenAIResponsesProviderOptions } from '@ai-sdk/openai'
import type { XaiProviderOptions } from '@ai-sdk/xai'
import { baseProviderIdSchema, customProviderIdSchema } from '@cherrystudio/ai-core/provider'
import { isOpenAIModel, isQwenMTModel, isSupportFlexServiceTierModel } from '@renderer/config/models'
import { isSupportServiceTierProvider } from '@renderer/config/providers'
import { mapLanguageToQwenMTModel } from '@renderer/config/translate'
import type { Assistant, Model, Provider } from '@renderer/types'
import { loggerService } from '@logger'
import {
getModelSupportedVerbosity,
isOpenAIModel,
isQwenMTModel,
isSupportFlexServiceTierModel,
isSupportVerbosityModel
} from '@renderer/config/models'
import { mapLanguageToQwenMTModel } from '@renderer/config/translate'
import { getStoreSetting } from '@renderer/hooks/useSettings'
import {
type Assistant,
type GroqServiceTier,
GroqServiceTiers,
type GroqSystemProvider,
isGroqServiceTier,
isGroqSystemProvider,
isOpenAIServiceTier,
isTranslateAssistant,
type Model,
type NotGroqProvider,
type OpenAIServiceTier,
OpenAIServiceTiers,
SystemProviderIds
type Provider,
type ServiceTier
} from '@renderer/types'
import type { OpenAIVerbosity } from '@renderer/types/aiCoreTypes'
import { isSupportServiceTierProvider } from '@renderer/utils/provider'
import type { JSONValue } from 'ai'
import { t } from 'i18next'
import { getAiSdkProviderId } from '../provider/factory'
@@ -26,8 +48,33 @@ import {
} from './reasoning'
import { getWebSearchParams } from './websearch'
// copy from BaseApiClient.ts
const getServiceTier = (model: Model, provider: Provider) => {
const logger = loggerService.withContext('aiCore.utils.options')
function toOpenAIServiceTier(model: Model, serviceTier: ServiceTier): OpenAIServiceTier {
if (
!isOpenAIServiceTier(serviceTier) ||
(serviceTier === OpenAIServiceTiers.flex && !isSupportFlexServiceTierModel(model))
) {
return undefined
} else {
return serviceTier
}
}
function toGroqServiceTier(model: Model, serviceTier: ServiceTier): GroqServiceTier {
if (
!isGroqServiceTier(serviceTier) ||
(serviceTier === GroqServiceTiers.flex && !isSupportFlexServiceTierModel(model))
) {
return undefined
} else {
return serviceTier
}
}
function getServiceTier<T extends GroqSystemProvider>(model: Model, provider: T): GroqServiceTier
function getServiceTier<T extends NotGroqProvider>(model: Model, provider: T): OpenAIServiceTier
function getServiceTier<T extends Provider>(model: Model, provider: T): OpenAIServiceTier | GroqServiceTier {
const serviceTierSetting = provider.serviceTier
if (!isSupportServiceTierProvider(provider) || !isOpenAIModel(model) || !serviceTierSetting) {
@@ -35,24 +82,17 @@ const getServiceTier = (model: Model, provider: Provider) => {
}
// 处理不同供应商需要 fallback 到默认值的情况
if (provider.id === SystemProviderIds.groq) {
if (
!isGroqServiceTier(serviceTierSetting) ||
(serviceTierSetting === GroqServiceTiers.flex && !isSupportFlexServiceTierModel(model))
) {
return undefined
}
if (isGroqSystemProvider(provider)) {
return toGroqServiceTier(model, serviceTierSetting)
} else {
// 其他 OpenAI 供应商,假设他们的服务层级设置和 OpenAI 完全相同
if (
!isOpenAIServiceTier(serviceTierSetting) ||
(serviceTierSetting === OpenAIServiceTiers.flex && !isSupportFlexServiceTierModel(model))
) {
return undefined
}
return toOpenAIServiceTier(model, serviceTierSetting)
}
}
return serviceTierSetting
function getVerbosity(): OpenAIVerbosity {
const openAI = getStoreSetting('openAI')
return openAI.verbosity
}
/**
@@ -69,12 +109,13 @@ export function buildProviderOptions(
enableWebSearch: boolean
enableGenerateImage: boolean
}
): Record<string, any> {
): Record<string, Record<string, JSONValue>> {
logger.debug('buildProviderOptions', { assistant, model, actualProvider, capabilities })
const rawProviderId = getAiSdkProviderId(actualProvider)
// 构建 provider 特定的选项
let providerSpecificOptions: Record<string, any> = {}
const serviceTierSetting = getServiceTier(model, actualProvider)
providerSpecificOptions.serviceTier = serviceTierSetting
const serviceTier = getServiceTier(model, actualProvider)
const textVerbosity = getVerbosity()
// 根据 provider 类型分离构建逻辑
const { data: baseProviderId, success } = baseProviderIdSchema.safeParse(rawProviderId)
if (success) {
@@ -84,14 +125,16 @@ export function buildProviderOptions(
case 'openai-chat':
case 'azure':
case 'azure-responses':
providerSpecificOptions = {
...buildOpenAIProviderOptions(assistant, model, capabilities),
serviceTier: serviceTierSetting
{
const options: OpenAIResponsesProviderOptions = buildOpenAIProviderOptions(
assistant,
model,
capabilities,
serviceTier
)
providerSpecificOptions = options
}
break
case 'huggingface':
providerSpecificOptions = buildOpenAIProviderOptions(assistant, model, capabilities)
break
case 'anthropic':
providerSpecificOptions = buildAnthropicProviderOptions(assistant, model, capabilities)
break
@@ -109,12 +152,19 @@ export function buildProviderOptions(
// 对于其他 provider使用通用的构建逻辑
providerSpecificOptions = {
...buildGenericProviderOptions(assistant, model, capabilities),
serviceTier: serviceTierSetting
serviceTier,
textVerbosity
}
break
}
case 'cherryin':
providerSpecificOptions = buildCherryInProviderOptions(assistant, model, capabilities, actualProvider)
providerSpecificOptions = buildCherryInProviderOptions(
assistant,
model,
capabilities,
actualProvider,
serviceTier
)
break
default:
throw new Error(`Unsupported base provider ${baseProviderId}`)
@@ -128,17 +178,22 @@ export function buildProviderOptions(
case 'google-vertex':
providerSpecificOptions = buildGeminiProviderOptions(assistant, model, capabilities)
break
case 'azure-anthropic':
case 'google-vertex-anthropic':
providerSpecificOptions = buildAnthropicProviderOptions(assistant, model, capabilities)
break
case 'bedrock':
providerSpecificOptions = buildBedrockProviderOptions(assistant, model, capabilities)
break
case 'huggingface':
providerSpecificOptions = buildOpenAIProviderOptions(assistant, model, capabilities, serviceTier)
break
default:
// 对于其他 provider使用通用的构建逻辑
providerSpecificOptions = {
...buildGenericProviderOptions(assistant, model, capabilities),
serviceTier: serviceTierSetting
serviceTier,
textVerbosity
}
}
} else {
@@ -152,13 +207,18 @@ export function buildProviderOptions(
...getCustomParameters(assistant)
}
const rawProviderKey =
let rawProviderKey =
{
'google-vertex': 'google',
'google-vertex-anthropic': 'anthropic',
'azure-anthropic': 'anthropic',
'ai-gateway': 'gateway'
}[rawProviderId] || rawProviderId
if (rawProviderKey === 'cherryin') {
rawProviderKey = { gemini: 'google' }[actualProvider.type] || actualProvider.type
}
// 返回 AI Core SDK 要求的格式:{ 'providerId': providerOptions }
return {
[rawProviderKey]: providerSpecificOptions
@@ -175,10 +235,11 @@ function buildOpenAIProviderOptions(
enableReasoning: boolean
enableWebSearch: boolean
enableGenerateImage: boolean
}
): Record<string, any> {
},
serviceTier: OpenAIServiceTier
): OpenAIResponsesProviderOptions {
const { enableReasoning } = capabilities
let providerOptions: Record<string, any> = {}
let providerOptions: OpenAIResponsesProviderOptions = {}
// OpenAI 推理参数
if (enableReasoning) {
const reasoningParams = getOpenAIReasoningParams(assistant, model)
@@ -187,6 +248,28 @@ function buildOpenAIProviderOptions(
...reasoningParams
}
}
if (isSupportVerbosityModel(model)) {
const openAI = getStoreSetting<'openAI'>('openAI')
const userVerbosity = openAI?.verbosity
if (userVerbosity && ['low', 'medium', 'high'].includes(userVerbosity)) {
const supportedVerbosity = getModelSupportedVerbosity(model)
// Use user's verbosity if supported, otherwise use the first supported option
const verbosity = supportedVerbosity.includes(userVerbosity) ? userVerbosity : supportedVerbosity[0]
providerOptions = {
...providerOptions,
textVerbosity: verbosity
}
}
}
providerOptions = {
...providerOptions,
serviceTier
}
return providerOptions
}
@@ -201,9 +284,9 @@ function buildAnthropicProviderOptions(
enableWebSearch: boolean
enableGenerateImage: boolean
}
): Record<string, any> {
): AnthropicProviderOptions {
const { enableReasoning } = capabilities
let providerOptions: Record<string, any> = {}
let providerOptions: AnthropicProviderOptions = {}
// Anthropic 推理参数
if (enableReasoning) {
@@ -228,9 +311,9 @@ function buildGeminiProviderOptions(
enableWebSearch: boolean
enableGenerateImage: boolean
}
): Record<string, any> {
): GoogleGenerativeAIProviderOptions {
const { enableReasoning, enableGenerateImage } = capabilities
let providerOptions: Record<string, any> = {}
let providerOptions: GoogleGenerativeAIProviderOptions = {}
// Gemini 推理参数
if (enableReasoning) {
@@ -259,7 +342,7 @@ function buildXAIProviderOptions(
enableWebSearch: boolean
enableGenerateImage: boolean
}
): Record<string, any> {
): XaiProviderOptions {
const { enableReasoning } = capabilities
let providerOptions: Record<string, any> = {}
@@ -282,16 +365,12 @@ function buildCherryInProviderOptions(
enableWebSearch: boolean
enableGenerateImage: boolean
},
actualProvider: Provider
): Record<string, any> {
const serviceTierSetting = getServiceTier(model, actualProvider)
actualProvider: Provider,
serviceTier: OpenAIServiceTier
): OpenAIResponsesProviderOptions | AnthropicProviderOptions | GoogleGenerativeAIProviderOptions {
switch (actualProvider.type) {
case 'openai':
return {
...buildOpenAIProviderOptions(assistant, model, capabilities),
serviceTier: serviceTierSetting
}
return buildOpenAIProviderOptions(assistant, model, capabilities, serviceTier)
case 'anthropic':
return buildAnthropicProviderOptions(assistant, model, capabilities)
@@ -313,9 +392,9 @@ function buildBedrockProviderOptions(
enableWebSearch: boolean
enableGenerateImage: boolean
}
): Record<string, any> {
): BedrockProviderOptions {
const { enableReasoning } = capabilities
let providerOptions: Record<string, any> = {}
let providerOptions: BedrockProviderOptions = {}
if (enableReasoning) {
const reasoningParams = getBedrockReasoningParams(assistant, model)

View File

@@ -1,3 +1,8 @@
import type { BedrockProviderOptions } from '@ai-sdk/amazon-bedrock'
import type { AnthropicProviderOptions } from '@ai-sdk/anthropic'
import type { GoogleGenerativeAIProviderOptions } from '@ai-sdk/google'
import type { OpenAIResponsesProviderOptions } from '@ai-sdk/openai'
import type { XaiProviderOptions } from '@ai-sdk/xai'
import { loggerService } from '@logger'
import { DEFAULT_MAX_TOKENS } from '@renderer/config/constant'
import {
@@ -7,6 +12,8 @@ import {
isDeepSeekHybridInferenceModel,
isDoubaoSeedAfter251015,
isDoubaoThinkingAutoModel,
isGemini3Model,
isGPT51SeriesModel,
isGrok4FastReasoningModel,
isGrokReasoningModel,
isOpenAIDeepResearchModel,
@@ -27,13 +34,13 @@ import {
isSupportedThinkingTokenZhipuModel,
MODEL_SUPPORTED_REASONING_EFFORT
} from '@renderer/config/models'
import { isSupportEnableThinkingProvider } from '@renderer/config/providers'
import { getStoreSetting } from '@renderer/hooks/useSettings'
import { getAssistantSettings, getProviderByModel } from '@renderer/services/AssistantService'
import type { SettingsState } from '@renderer/store/settings'
import type { Assistant, Model } from '@renderer/types'
import type { Assistant, Model, ReasoningEffortOption } from '@renderer/types'
import { EFFORT_RATIO, isSystemProvider, SystemProviderIds } from '@renderer/types'
import type { OpenAISummaryText } from '@renderer/types/aiCoreTypes'
import type { ReasoningEffortOptionalParams } from '@renderer/types/sdk'
import { isSupportEnableThinkingProvider } from '@renderer/utils/provider'
import { toInteger } from 'lodash'
const logger = loggerService.withContext('reasoning')
@@ -56,13 +63,20 @@ export function getReasoningEffort(assistant: Assistant, model: Model): Reasonin
}
const reasoningEffort = assistant?.settings?.reasoning_effort
if (!reasoningEffort) {
// Handle undefined and 'none' reasoningEffort.
// TODO: They should be separated.
if (!reasoningEffort || reasoningEffort === 'none') {
// openrouter: use reasoning
if (model.provider === SystemProviderIds.openrouter) {
// Don't disable reasoning for Gemini models that support thinking tokens
if (isSupportedThinkingTokenGeminiModel(model) && !GEMINI_FLASH_MODEL_REGEX.test(model.id)) {
return {}
}
// 'none' is not an available value for effort for now.
// I think they should resolve this issue soon, so I'll just go ahead and use this value.
if (isGPT51SeriesModel(model) && reasoningEffort === 'none') {
return { reasoning: { effort: 'none' } }
}
// Don't disable reasoning for models that require it
if (
isGrokReasoningModel(model) ||
@@ -117,6 +131,13 @@ export function getReasoningEffort(assistant: Assistant, model: Model): Reasonin
return { thinking: { type: 'disabled' } }
}
// Specially for GPT-5.1. Suppose this is a OpenAI Compatible provider
if (isGPT51SeriesModel(model)) {
return {
reasoningEffort: 'none'
}
}
return {}
}
@@ -259,6 +280,12 @@ export function getReasoningEffort(assistant: Assistant, model: Model): Reasonin
// gemini series, openai compatible api
if (isSupportedThinkingTokenGeminiModel(model)) {
// https://ai.google.dev/gemini-api/docs/gemini-3?thinking=high#openai_compatibility
if (isGemini3Model(model)) {
return {
reasoning_effort: reasoningEffort
}
}
if (reasoningEffort === 'auto') {
return {
extra_body: {
@@ -322,10 +349,14 @@ export function getReasoningEffort(assistant: Assistant, model: Model): Reasonin
}
/**
* 获取 OpenAI 推理参数
* OpenAIResponseAPIClient OpenAIAPIClient 中提取的逻辑
* Get OpenAI reasoning parameters
* Extracted from OpenAIResponseAPIClient and OpenAIAPIClient logic
* For official OpenAI provider only
*/
export function getOpenAIReasoningParams(assistant: Assistant, model: Model): Record<string, any> {
export function getOpenAIReasoningParams(
assistant: Assistant,
model: Model
): Pick<OpenAIResponsesProviderOptions, 'reasoningEffort' | 'reasoningSummary'> {
if (!isReasoningModel(model)) {
return {}
}
@@ -336,6 +367,10 @@ export function getOpenAIReasoningParams(assistant: Assistant, model: Model): Re
return {}
}
if (isOpenAIDeepResearchModel(model) || reasoningEffort === 'auto') {
reasoningEffort = 'medium'
}
// 非OpenAI模型但是Provider类型是responses/azure openai的情况
if (!isOpenAIModel(model)) {
return {
@@ -343,21 +378,17 @@ export function getOpenAIReasoningParams(assistant: Assistant, model: Model): Re
}
}
const openAI = getStoreSetting('openAI') as SettingsState['openAI']
const summaryText = openAI?.summaryText || 'off'
const openAI = getStoreSetting('openAI')
const summaryText = openAI.summaryText
let reasoningSummary: string | undefined = undefined
let reasoningSummary: OpenAISummaryText = undefined
if (summaryText === 'off' || model.id.includes('o1-pro')) {
if (model.id.includes('o1-pro')) {
reasoningSummary = undefined
} else {
reasoningSummary = summaryText
}
if (isOpenAIDeepResearchModel(model)) {
reasoningEffort = 'medium'
}
// OpenAI 推理参数
if (isSupportedReasoningEffortOpenAIModel(model)) {
return {
@@ -369,19 +400,26 @@ export function getOpenAIReasoningParams(assistant: Assistant, model: Model): Re
return {}
}
export function getAnthropicThinkingBudget(assistant: Assistant, model: Model): number {
const { maxTokens, reasoning_effort: reasoningEffort } = getAssistantSettings(assistant)
if (reasoningEffort === undefined) {
return 0
export function getAnthropicThinkingBudget(
maxTokens: number | undefined,
reasoningEffort: string | undefined,
modelId: string
): number | undefined {
if (reasoningEffort === undefined || reasoningEffort === 'none') {
return undefined
}
const effortRatio = EFFORT_RATIO[reasoningEffort]
const tokenLimit = findTokenLimit(modelId)
if (!tokenLimit) {
return undefined
}
const budgetTokens = Math.max(
1024,
Math.floor(
Math.min(
(findTokenLimit(model.id)?.max! - findTokenLimit(model.id)?.min!) * effortRatio +
findTokenLimit(model.id)?.min!,
(tokenLimit.max - tokenLimit.min) * effortRatio + tokenLimit.min,
(maxTokens || DEFAULT_MAX_TOKENS) * effortRatio
)
)
@@ -393,14 +431,17 @@ export function getAnthropicThinkingBudget(assistant: Assistant, model: Model):
* 获取 Anthropic 推理参数
* 从 AnthropicAPIClient 中提取的逻辑
*/
export function getAnthropicReasoningParams(assistant: Assistant, model: Model): Record<string, any> {
export function getAnthropicReasoningParams(
assistant: Assistant,
model: Model
): Pick<AnthropicProviderOptions, 'thinking'> {
if (!isReasoningModel(model)) {
return {}
}
const reasoningEffort = assistant?.settings?.reasoning_effort
if (reasoningEffort === undefined) {
if (reasoningEffort === undefined || reasoningEffort === 'none') {
return {
thinking: {
type: 'disabled'
@@ -410,7 +451,8 @@ export function getAnthropicReasoningParams(assistant: Assistant, model: Model):
// Claude 推理参数
if (isSupportedThinkingTokenClaudeModel(model)) {
const budgetTokens = getAnthropicThinkingBudget(assistant, model)
const { maxTokens } = getAssistantSettings(assistant)
const budgetTokens = getAnthropicThinkingBudget(maxTokens, reasoningEffort, model.id)
return {
thinking: {
@@ -423,13 +465,31 @@ export function getAnthropicReasoningParams(assistant: Assistant, model: Model):
return {}
}
type GoogelThinkingLevel = NonNullable<GoogleGenerativeAIProviderOptions['thinkingConfig']>['thinkingLevel']
function mapToGeminiThinkingLevel(reasoningEffort: ReasoningEffortOption): GoogelThinkingLevel {
switch (reasoningEffort) {
case 'low':
return 'low'
case 'medium':
return 'medium'
case 'high':
return 'high'
default:
return 'medium'
}
}
/**
* 获取 Gemini 推理参数
* 从 GeminiAPIClient 中提取的逻辑
* 注意Gemini/GCP 端点所使用的 thinkingBudget 等参数应该按照驼峰命名法传递
* 而在 Google 官方提供的 OpenAI 兼容端点中则使用蛇形命名法 thinking_budget
*/
export function getGeminiReasoningParams(assistant: Assistant, model: Model): Record<string, any> {
export function getGeminiReasoningParams(
assistant: Assistant,
model: Model
): Pick<GoogleGenerativeAIProviderOptions, 'thinkingConfig'> {
if (!isReasoningModel(model)) {
return {}
}
@@ -438,7 +498,7 @@ export function getGeminiReasoningParams(assistant: Assistant, model: Model): Re
// Gemini 推理参数
if (isSupportedThinkingTokenGeminiModel(model)) {
if (reasoningEffort === undefined) {
if (reasoningEffort === undefined || reasoningEffort === 'none') {
return {
thinkingConfig: {
includeThoughts: false,
@@ -447,6 +507,15 @@ export function getGeminiReasoningParams(assistant: Assistant, model: Model): Re
}
}
// https://ai.google.dev/gemini-api/docs/gemini-3?thinking=high#new_api_features_in_gemini_3
if (isGemini3Model(model)) {
return {
thinkingConfig: {
thinkingLevel: mapToGeminiThinkingLevel(reasoningEffort)
}
}
}
const effortRatio = EFFORT_RATIO[reasoningEffort]
if (effortRatio > 1) {
@@ -478,27 +547,35 @@ export function getGeminiReasoningParams(assistant: Assistant, model: Model): Re
* @param model - The model being used
* @returns XAI-specific reasoning parameters
*/
export function getXAIReasoningParams(assistant: Assistant, model: Model): Record<string, any> {
export function getXAIReasoningParams(assistant: Assistant, model: Model): Pick<XaiProviderOptions, 'reasoningEffort'> {
if (!isSupportedReasoningEffortGrokModel(model)) {
return {}
}
const { reasoning_effort: reasoningEffort } = getAssistantSettings(assistant)
if (!reasoningEffort) {
if (!reasoningEffort || reasoningEffort === 'none') {
return {}
}
// For XAI provider Grok models, use reasoningEffort parameter directly
return {
reasoningEffort
switch (reasoningEffort) {
case 'auto':
case 'minimal':
case 'medium':
return { reasoningEffort: 'low' }
case 'low':
case 'high':
return { reasoningEffort }
}
}
/**
* Get Bedrock reasoning parameters
*/
export function getBedrockReasoningParams(assistant: Assistant, model: Model): Record<string, any> {
export function getBedrockReasoningParams(
assistant: Assistant,
model: Model
): Pick<BedrockProviderOptions, 'reasoningConfig'> {
if (!isReasoningModel(model)) {
return {}
}
@@ -509,12 +586,21 @@ export function getBedrockReasoningParams(assistant: Assistant, model: Model): R
return {}
}
if (reasoningEffort === 'none') {
return {
reasoningConfig: {
type: 'disabled'
}
}
}
// Only apply thinking budget for Claude reasoning models
if (!isSupportedThinkingTokenClaudeModel(model)) {
return {}
}
const budgetTokens = getAnthropicThinkingBudget(assistant, model)
const { maxTokens } = getAssistantSettings(assistant)
const budgetTokens = getAnthropicThinkingBudget(maxTokens, reasoningEffort, model.id)
return {
reasoningConfig: {
type: 'enabled',

View File

@@ -47,6 +47,7 @@ export function buildProviderBuiltinWebSearchConfig(
model?: Model
): WebSearchPluginConfig | undefined {
switch (providerId) {
case 'azure-responses':
case 'openai': {
const searchContextSize = isOpenAIDeepResearchModel(model)
? 'medium'

Binary file not shown.

After

Width:  |  Height:  |  Size: 19 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 21 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 20 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 18 KiB

View File

@@ -1,5 +1,6 @@
import { ActionIconButton } from '@renderer/components/Buttons'
import NarrowLayout from '@renderer/pages/home/Messages/NarrowLayout'
import { scrollElementIntoView } from '@renderer/utils'
import { Tooltip } from 'antd'
import { debounce } from 'lodash'
import { CaseSensitive, ChevronDown, ChevronUp, User, WholeWord, X } from 'lucide-react'
@@ -181,17 +182,14 @@ export const ContentSearch = React.forwardRef<ContentSearchRef, Props>(
// 3. 将当前项滚动到视图中
// 获取第一个文本节点的父元素来进行滚动
const parentElement = currentMatchRange.startContainer.parentElement
if (shouldScroll) {
parentElement?.scrollIntoView({
behavior: 'smooth',
block: 'center',
inline: 'nearest'
})
if (shouldScroll && parentElement) {
// 优先在指定的滚动容器内滚动,避免滚动整个页面导致索引错乱/看起来"跳到第一条"
scrollElementIntoView(parentElement, target)
}
}
}
},
[allRanges, currentIndex]
[allRanges, currentIndex, target]
)
const search = useCallback(

View File

@@ -1,35 +1,120 @@
import 'emoji-picker-element'
import TwemojiCountryFlagsWoff2 from '@renderer/assets/fonts/country-flag-fonts/TwemojiCountryFlags.woff2?url'
import { useTheme } from '@renderer/context/ThemeProvider'
import type { LanguageVarious } from '@renderer/types'
import { polyfillCountryFlagEmojis } from 'country-flag-emoji-polyfill'
// i18n translations from emoji-picker-element
import de from 'emoji-picker-element/i18n/de'
import en from 'emoji-picker-element/i18n/en'
import es from 'emoji-picker-element/i18n/es'
import fr from 'emoji-picker-element/i18n/fr'
import ja from 'emoji-picker-element/i18n/ja'
import pt_PT from 'emoji-picker-element/i18n/pt_PT'
import ru_RU from 'emoji-picker-element/i18n/ru_RU'
import zh_CN from 'emoji-picker-element/i18n/zh_CN'
import type Picker from 'emoji-picker-element/picker'
import type { EmojiClickEvent, NativeEmoji } from 'emoji-picker-element/shared'
// Emoji data from emoji-picker-element-data (local, no CDN)
// Using CLDR format for full multi-language search support (28 languages)
import dataDE from 'emoji-picker-element-data/de/cldr/data.json?url'
import dataEN from 'emoji-picker-element-data/en/cldr/data.json?url'
import dataES from 'emoji-picker-element-data/es/cldr/data.json?url'
import dataFR from 'emoji-picker-element-data/fr/cldr/data.json?url'
import dataJA from 'emoji-picker-element-data/ja/cldr/data.json?url'
import dataPT from 'emoji-picker-element-data/pt/cldr/data.json?url'
import dataRU from 'emoji-picker-element-data/ru/cldr/data.json?url'
import dataZH from 'emoji-picker-element-data/zh/cldr/data.json?url'
import dataZH_HANT from 'emoji-picker-element-data/zh-hant/cldr/data.json?url'
import type { FC } from 'react'
import { useEffect, useRef } from 'react'
import { useTranslation } from 'react-i18next'
interface Props {
onEmojiClick: (emoji: string) => void
}
// Mapping from app locale to emoji-picker-element i18n
const i18nMap: Record<LanguageVarious, typeof en> = {
'en-US': en,
'zh-CN': zh_CN,
'zh-TW': zh_CN, // Closest available
'de-DE': de,
'el-GR': en, // No Greek available, fallback to English
'es-ES': es,
'fr-FR': fr,
'ja-JP': ja,
'pt-PT': pt_PT,
'ru-RU': ru_RU
}
// Mapping from app locale to emoji data URL
// Using CLDR format provides native language search support for all locales
const dataSourceMap: Record<LanguageVarious, string> = {
'en-US': dataEN,
'zh-CN': dataZH,
'zh-TW': dataZH_HANT,
'de-DE': dataDE,
'el-GR': dataEN, // No Greek CLDR available, fallback to English
'es-ES': dataES,
'fr-FR': dataFR,
'ja-JP': dataJA,
'pt-PT': dataPT,
'ru-RU': dataRU
}
// Mapping from app locale to emoji-picker-element locale string
// Must match the data source locale for proper IndexedDB caching
const localeMap: Record<LanguageVarious, string> = {
'en-US': 'en',
'zh-CN': 'zh',
'zh-TW': 'zh-hant',
'de-DE': 'de',
'el-GR': 'en',
'es-ES': 'es',
'fr-FR': 'fr',
'ja-JP': 'ja',
'pt-PT': 'pt',
'ru-RU': 'ru'
}
const EmojiPicker: FC<Props> = ({ onEmojiClick }) => {
const { theme } = useTheme()
const ref = useRef<HTMLDivElement>(null)
const { i18n } = useTranslation()
const ref = useRef<Picker>(null)
const currentLocale = i18n.language as LanguageVarious
useEffect(() => {
polyfillCountryFlagEmojis('Twemoji Mozilla', TwemojiCountryFlagsWoff2)
}, [])
// Configure picker with i18n and dataSource
useEffect(() => {
const refValue = ref.current
const picker = ref.current
if (picker) {
picker.i18n = i18nMap[currentLocale] || en
picker.dataSource = dataSourceMap[currentLocale] || dataEN
picker.locale = localeMap[currentLocale] || 'en'
}
}, [currentLocale])
if (refValue) {
const handleEmojiClick = (event: any) => {
useEffect(() => {
const picker = ref.current
if (picker) {
const handleEmojiClick = (event: EmojiClickEvent) => {
event.stopPropagation()
onEmojiClick(event.detail.unicode || event.detail.emoji.unicode)
const { detail } = event
// Use detail.unicode (processed with skin tone) or fallback to emoji's unicode for native emoji
const unicode = detail.unicode || ('unicode' in detail.emoji ? (detail.emoji as NativeEmoji).unicode : '')
onEmojiClick(unicode)
}
// 添加事件监听器
refValue.addEventListener('emoji-click', handleEmojiClick)
picker.addEventListener('emoji-click', handleEmojiClick)
// 清理事件监听器
return () => {
refValue.removeEventListener('emoji-click', handleEmojiClick)
picker.removeEventListener('emoji-click', handleEmojiClick)
}
}
return

View File

@@ -0,0 +1,157 @@
import { loggerService } from '@logger'
import type { UIActionResult } from '@mcp-ui/client'
import { UIResourceRenderer } from '@mcp-ui/client'
import type { EmbeddedResource } from '@modelcontextprotocol/sdk/types.js'
import { isUIResource } from '@renderer/types'
import type { FC } from 'react'
import { useCallback, useState } from 'react'
import { useTranslation } from 'react-i18next'
import styled from 'styled-components'
const logger = loggerService.withContext('MCPUIRenderer')
interface Props {
resource: EmbeddedResource
serverId?: string
serverName?: string
onToolCall?: (toolName: string, params: any) => Promise<any>
}
const MCPUIRenderer: FC<Props> = ({ resource, onToolCall }) => {
const { t } = useTranslation()
const [error] = useState<string | null>(null)
const handleUIAction = useCallback(
async (result: UIActionResult): Promise<any> => {
logger.debug('UI Action received:', result)
try {
switch (result.type) {
case 'tool': {
// Handle tool call from UI
if (onToolCall) {
const { toolName, params } = result.payload
logger.info(`UI requesting tool call: ${toolName}`, { params })
const response = await onToolCall(toolName, params)
// Check if the response contains a UIResource
try {
if (response && response.content && Array.isArray(response.content)) {
const firstContent = response.content[0]
if (firstContent && firstContent.type === 'text' && firstContent.text) {
const parsedText = JSON.parse(firstContent.text)
if (isUIResource(parsedText)) {
// Return the UIResource directly for rendering in the iframe
logger.info('Tool response contains UIResource:', { uri: parsedText.resource.uri })
return { status: 'success', data: parsedText }
}
}
}
} catch (parseError) {
// Not a UIResource, return the original response
logger.debug('Tool response is not a UIResource')
}
return { status: 'success', data: response }
} else {
logger.warn('Tool call requested but no handler provided')
return { status: 'error', message: 'Tool call handler not available' }
}
}
case 'intent': {
// Handle user intent
logger.info('UI intent:', result.payload)
window.toast.info(t('message.mcp.ui.intent_received'))
return { status: 'acknowledged' }
}
case 'notify': {
// Handle notification from UI
logger.info('UI notification:', result.payload)
window.toast.info(result.payload.message || t('message.mcp.ui.notification'))
return { status: 'acknowledged' }
}
case 'prompt': {
// Handle prompt request from UI
logger.info('UI prompt request:', result.payload)
// TODO: Integrate with prompt system
return { status: 'error', message: 'Prompt execution not yet implemented' }
}
case 'link': {
// Handle navigation request
const { url } = result.payload
logger.info('UI navigation request:', { url })
window.open(url, '_blank')
return { status: 'acknowledged' }
}
default:
logger.warn('Unknown UI action type:', { result })
return { status: 'error', message: 'Unknown action type' }
}
} catch (err) {
logger.error('Error handling UI action:', err as Error)
return {
status: 'error',
message: err instanceof Error ? err.message : 'Unknown error'
}
}
},
[onToolCall, t]
)
if (error) {
return (
<ErrorContainer>
<ErrorTitle>{t('message.mcp.ui.error')}</ErrorTitle>
<ErrorMessage>{error}</ErrorMessage>
</ErrorContainer>
)
}
return (
<UIContainer>
<UIResourceRenderer resource={resource} onUIAction={handleUIAction} />
</UIContainer>
)
}
const UIContainer = styled.div`
width: 100%;
min-height: 400px;
border-radius: 8px;
overflow: hidden;
background: var(--color-background);
border: 1px solid var(--color-border);
iframe {
width: 100%;
border: none;
min-height: 400px;
height: 600px;
}
`
const ErrorContainer = styled.div`
padding: 16px;
border-radius: 8px;
background: var(--color-error-bg, #fee);
border: 1px solid var(--color-error-border, #fcc);
color: var(--color-error-text, #c33);
`
const ErrorTitle = styled.div`
font-weight: 600;
margin-bottom: 8px;
font-size: 14px;
`
const ErrorMessage = styled.div`
font-size: 13px;
opacity: 0.9;
`
export default MCPUIRenderer

Some files were not shown because too many files have changed in this diff Show More