Compare commits

...

174 Commits

Author SHA1 Message Date
MyPrototypeWhat
70c6478278 chore(package): bump version to 1.6.0-beta.2 2025-08-29 16:01:45 +08:00
MyPrototypeWhat
0944faf8e5 refactor(aiCore): clean up KnowledgeSearchTool and searchOrchestrationPlugin
- Commented out unused parameters and code in the KnowledgeSearchTool to improve clarity and maintainability.
- Removed the tool choice assignment in searchOrchestrationPlugin to streamline the logic.
- Updated instructions handling in KnowledgeSearchTool to eliminate unnecessary fields.
2025-08-29 16:01:11 +08:00
MyPrototypeWhat
6a1918deef feat(releaseNotes): update release notes with new features and improvements
- Added a modal for detailed error information with multi-language support.
- Enhanced AI Core with version upgrade and improved parameter handling.
- Refactored error handling for better type safety and performance.
- Removed deprecated code and improved provider initialization logic.
2025-08-29 15:23:58 +08:00
MyPrototypeWhat
c9d0265872 refactor(aiCore): enhance temperature and TopP parameter handling
- Updated `getTemperature` and `getTopP` functions to incorporate reasoning effort checks for Claude models.
- Refactored logic to ensure temperature and TopP settings are only returned when applicable.
- Improved clarity and maintainability of parameter retrieval functions.
2025-08-29 15:06:34 +08:00
MyPrototypeWhat
2ca5116769 refactor(aiCore): streamline provider options and enhance OpenAI handling
- Simplified the OpenAI mode handling in the provider configuration.
- Added service tier settings to provider-specific options for better configuration management.
- Refactored the `buildOpenAIProviderOptions` function to remove redundant parameters and improve clarity.
2025-08-29 14:20:03 +08:00
MyPrototypeWhat
efeada281a feat(aiCore): introduce provider configuration enhancements and initialization
- Added a new provider configuration module to handle special provider logic and formatting.
- Implemented asynchronous preparation of special provider configurations in the ModernAiProvider class.
- Refactored provider initialization logic to support dynamic registration of new AI providers.
- Updated utility functions to streamline provider option building and improve compatibility with new provider configurations.
2025-08-29 13:34:15 +08:00
MyPrototypeWhat
49cd9d6723 chore(aiCore): update version to 1.0.0-alpha.11 and refactor model resolution logic
- Bumped the version of the ai-core package to 1.0.0-alpha.11.
- Removed the `isOpenAIChatCompletionOnlyModel` utility function to simplify model resolution.
- Adjusted the `providerToAiSdkConfig` function to accept a model parameter for improved configuration handling.
- Updated the `ModernAiProvider` class to utilize the new model parameter in its configuration.
- Cleaned up deprecated code related to search keyword extraction and reasoning parameters.
2025-08-29 12:20:22 +08:00
MyPrototypeWhat
1735a9efb6 feat(ErrorBlock): add centered modal display for error details
- Updated the ErrorBlock component to include a centered modal for displaying error details, enhancing the user interface and accessibility of error information.
2025-08-29 11:30:16 +08:00
MyPrototypeWhat
1b86997f14 refactor(aiCore): enhance completions methods with developer mode support
- Introduced a check for developer mode in the completions methods to enable tracing capabilities when a topic ID is provided.
- Updated the method signatures and internal calls to streamline the handling of completions with and without tracing.
- Improved code organization by making the completionsForTrace method private and renaming it for clarity.
2025-08-29 11:14:38 +08:00
suyao
cf777ba62b feat(inputbar): enhance MCP tools visibility with prompt tool support
- Updated the Inputbar component to include the `isPromptToolUse` function, allowing for better visibility of MCP tools based on the assistant's capabilities.
- This change improves user experience by expanding the conditions under which MCP tools are displayed.
2025-08-29 07:02:00 +08:00
suyao
4918628131 feat(i18n): add error detail translations and enhance error handling UI
- Added new translation keys for error details in multiple languages, including 'detail', 'details', 'message', 'requestBody', 'requestUrl', 'stack', and 'status'.
- Updated the ErrorBlock component to display a modal with detailed error information, allowing users to view and copy error details easily.
- Improved the user experience by providing a clear and accessible way to understand error messages and their context.
2025-08-29 06:55:25 +08:00
suyao
fee6ad58d1 refactor(errorHandling): improve error serialization and update error handling in callbacks
- Updated the error handling in the `onError` callback to use `AISDKError` type for better type safety.
- Introduced a new `serializeError` function to standardize error serialization.
- Modified the `ErrorBlock` component to directly access the error message.
- Removed deprecated error formatting functions to streamline the error utility module.
2025-08-29 06:12:16 +08:00
suyao
a30df46c40 feat(aihubmix): add 'type' property to provider configuration for Gemini integration 2025-08-29 04:20:32 +08:00
suyao
3004f84be3 refactor(modelResolver): replace ':' with '|' as the default separator for model IDs
Updated the ModelResolver and related components to use '|' as the default separator instead of ':'. This change improves compatibility and resolves potential conflicts with model ID suffixes. Adjusted model resolution logic accordingly to ensure consistent behavior across the application.
2025-08-29 04:12:48 +08:00
MyPrototypeWhat
9551c49452 feat(release): update version to 1.6.0-beta.1 and enhance release notes with new features, improvements, and bug fixes
- Integrated a new AI SDK architecture for better performance
- Added OCR functionality for image text recognition
- Introduced a code tools page with environment variable settings
- Enhanced the MCP server list with search capabilities
- Improved SVG preview and HTML content rendering
- Fixed multiple issues including document preprocessing failures and path handling on Windows
- Optimized performance and memory usage across various components
2025-08-28 18:25:45 +08:00
MyPrototypeWhat
e312c84a0e chore(config): add new aliases for ai-core in Vite and TypeScript configuration
Updated the Vite and TypeScript configuration files to include new path aliases for the ai-core package, enhancing module resolution for core providers and built-in plugins. This change improves the organization and accessibility of the ai-core components within the project.
2025-08-28 18:12:46 +08:00
MyPrototypeWhat
3d0fb97475 chore(dependencies): update ai and related packages to version 5.0.26 and 1.0.15
Updated the 'ai' package to version 5.0.26 and '@ai-sdk/gateway' to version 1.0.15. Also, updated '@ai-sdk/provider-utils' to version 3.0.7 and 'eventsource-parser' to version 3.0.5. Adjusted type definitions in aiCore for better type safety in plugin parameters and results.
2025-08-28 16:11:08 +08:00
MyPrototypeWhat
d10ba04047 refactor(aiCore): streamline type exports and enhance provider registration
Removed unused type exports from the aiCore module and consolidated type definitions for better clarity. Updated provider registration tests to reflect new configurations and improved error handling for non-existent providers. Enhanced the overall structure of the provider management system, ensuring better type safety and consistency across the codebase.
2025-08-28 15:28:44 +08:00
icarus
4b7023f855 fix(i18n): 更新多语言文件中 websearch.fetch_complete 的翻译格式
统一将“已完成 X 次搜索”改为“X 个搜索结果”格式,并添加 avatar.builtin 字段翻译
2025-08-28 15:11:52 +08:00
icarus
5f096ecf8c refactor(logging): 将console替换为logger以统一日志管理 2025-08-28 14:52:59 +08:00
suyao
a12d627b65 feat(types): add VertexProvider type for Google Cloud integration
Introduced a new VertexProvider type that includes properties for Google credentials and project details, enhancing type safety and support for Google Cloud functionalities.
2025-08-28 12:22:14 +08:00
MyPrototypeWhat
7bda658022 Merge remote-tracking branch 'origin/main' into feat/aisdk-package 2025-08-28 12:03:41 +08:00
MyPrototypeWhat
bfcb215c16 fix(aiCore): update tool call status and enhance execution flow
- Changed tool call status from 'invoking' to 'pending' for better clarity in execution state.
- Updated the tool execution logic to include user confirmation for non-auto-approved tools, improving user interaction.
- Refactored the handling of experimental context in the tool execution parameters to support chunk streaming.
- Commented out unused tool input event cases in AiSdkToChunkAdapter for cleaner code.
2025-08-27 19:37:10 +08:00
MyPrototypeWhat
1b3fcb2e55 chore(aiCore): bump version to 1.0.0-alpha.10 in package.json 2025-08-27 16:23:08 +08:00
MyPrototypeWhat
9c01e24317 feat(aiCore): enhance provider management and registration system
- Added support for a new provider configuration structure in package.json, enabling better integration of provider types.
- Updated tsdown.config.ts to include new entry points for provider modules, improving build organization.
- Refactored index.ts to streamline exports and enhance type handling for provider-related functionalities.
- Simplified provider initialization and registration processes, allowing for more flexible provider management.
- Improved type definitions and removed deprecated methods to enhance code clarity and maintainability.
2025-08-27 16:17:57 +08:00
MyPrototypeWhat
2ce9314a10 refactor(aiCore): improve type handling and response structures
- Updated AiSdkToChunkAdapter to refine web search result handling.
- Modified McpToolChunkMiddleware to ensure consistent type usage for tool responses.
- Enhanced type definitions in chunk.ts and index.ts for better clarity and type safety.
- Adjusted MessageWebSearch styles for improved UI consistency.
- Refactored parseToolUse function to align with updated MCPTool response structures.
2025-08-27 11:23:30 +08:00
MyPrototypeWhat
0c7e221b4e feat(aiCore): add MemorySearchTool and WebSearchTool components
- Introduced MessageMemorySearch and MessageWebSearch components for handling memory and web search tool responses.
- Updated MemorySearchTool and WebSearchTool to improve response handling and integrate with the new components.
- Removed unused console logs and streamlined code for better readability and maintainability.
- Added new dependencies in package.json for enhanced functionality.
2025-08-26 17:59:52 +08:00
MyPrototypeWhat
82d4637c9d chore(aiCore): bump version to 1.0.0-alpha.9 in package.json 2025-08-26 16:19:41 +08:00
MyPrototypeWhat
84eef25ff9 feat(aiCore): enhance dynamic provider registration and refactor HubProvider
- Introduced dynamic provider registration functionality, allowing for flexible management of providers through a new registry system.
- Refactored HubProvider to streamline model resolution and improve error handling for unsupported models.
- Added utility functions for managing dynamic providers, including registration, cleanup, and alias resolution.
- Updated index exports to include new dynamic provider APIs, enhancing overall usability and integration.
- Removed outdated provider files and simplified the provider management structure for better maintainability.
2025-08-26 16:17:01 +08:00
lizhixuan
53dcda6942 feat(aiCore): introduce Hub Provider and enhance provider management
- Added a new example file demonstrating the usage of the Hub Provider for routing to multiple underlying providers.
- Implemented the Hub Provider to support model ID parsing and routing based on a specified format.
- Refactored provider management by introducing a Registry Management class for better organization and retrieval of provider instances.
- Updated the Provider Initializer to streamline the initialization and registration of providers, enhancing overall flexibility and usability.
- Removed outdated files related to provider creation and dynamic registration to simplify the codebase.
2025-08-26 00:31:41 +08:00
MyPrototypeWhat
ead0e22c60 [WIP]refactor(aiCore): restructure models and introduce ModelResolver
- Removed outdated ConfigManager and factory files to streamline model management.
- Added ModelResolver for improved model ID resolution, supporting both traditional and namespaced formats.
- Introduced DynamicProviderRegistry for dynamic provider management, enhancing flexibility in model handling.
- Updated index exports to reflect the new structure and maintain compatibility with existing functionality.
2025-08-25 19:46:51 +08:00
MyPrototypeWhat
417f90df3b feat(dependencies): update @ai-sdk/openai and @ai-sdk/provider-utils versions
- Upgraded `@ai-sdk/openai` to version 2.0.19 in `yarn.lock` and `package.json` for improved functionality and compatibility.
- Updated `@ai-sdk/provider-utils` to version 3.0.5, enhancing dependency management.
- Added `TypedToolError` type export in `index.ts` for better error handling.
- Removed unnecessary console logs in `webSearchPlugin` for cleaner code.
- Refactored type handling in `createProvider` to ensure proper type assertions.
- Enforced `topicId` as a required field in the `ModernAiProvider` configuration for stricter validation.
2025-08-25 16:04:50 +08:00
suyao
65c15c6d87 feat(aiCore): update ai-sdk-provider and enhance message conversion logic
- Upgraded `@openrouter/ai-sdk-provider` to version ^1.1.2 in package.json and yarn.lock for improved functionality.
- Enhanced `convertMessageToSdkParam` and related functions to support additional model parameters, improving message conversion for various AI models.
- Integrated logging for error handling in file processing functions to aid in debugging and user feedback.
- Added support for native PDF input handling based on model capabilities, enhancing file processing features.
2025-08-25 14:40:48 +08:00
MyPrototypeWhat
ca4e7e3d2b feat(tools): refactor MemorySearchTool and WebSearchTool for improved response handling
- Updated MemorySearchTool to utilize aiSdk for better integration and removed unused imports.
- Refactored WebSearchTool to streamline search results handling, changing from an array to a structured object for clarity.
- Adjusted MessageTool and MessageWebSearchTool components to reflect changes in tool response structure.
- Enhanced error handling and logging in tool callbacks for improved debugging and user feedback.
2025-08-22 19:35:09 +08:00
MyPrototypeWhat
d34b640807 feat(aiCore): enhance tool response handling and type definitions
- Updated the ToolCallChunkHandler to support both MCPTool and NormalToolResponse types, improving flexibility in tool response management.
- Refactored type definitions for MCPToolResponse and introduced NormalToolResponse to better differentiate between tool response types.
- Enhanced logging in MCP utility functions for improved error tracking and debugging.
- Cleaned up type imports and ensured consistent handling of tool responses across various chunks.
2025-08-21 16:30:30 +08:00
MyPrototypeWhat
aa9ed3b9c8 feat(dependencies): update ai-sdk packages and improve type safety
- Upgraded multiple `@ai-sdk` packages in `yarn.lock` and `package.json` to their latest versions for enhanced functionality and compatibility.
- Improved type safety in `searchOrchestrationPlugin` by adding optional chaining to handle potential undefined values in knowledge bases.
- Cleaned up dependency declarations to use caret (^) for versioning, ensuring compatibility with future updates.
2025-08-19 16:07:29 +08:00
MyPrototypeWhat
d4da7d817d feat(dependencies): update ai-sdk packages and improve logging
- Upgraded `@ai-sdk/gateway` to version 1.0.8 and `@ai-sdk/provider-utils` to version 3.0.4 in yarn.lock for enhanced functionality.
- Updated `ai` dependency in package.json to version ^5.0.16 for better compatibility.
- Added logging functionality in `AiSdkToChunkAdapter` to track chunk types and improve debugging.
- Refactored plugin imports to streamline code and enhance readability.
- Removed unnecessary console logs in `searchOrchestrationPlugin` to clean up the codebase.
2025-08-19 16:03:51 +08:00
MyPrototypeWhat
179b7af9bd feat(toolUsePlugin): refactor tool execution and event management
- Extracted `StreamEventManager` and `ToolExecutor` classes from `promptToolUsePlugin.ts` to improve code organization and reduce complexity.
- Enhanced tool execution logic with better error handling and event management.
- Updated the `createPromptToolUsePlugin` function to utilize the new classes for cleaner implementation.
- Improved recursive call handling and result formatting for tool executions.
- Streamlined the overall flow of tool calls and event emissions within the plugin.
2025-08-19 14:28:04 +08:00
MyPrototypeWhat
5d0ab0a9a1 feat(aiCore): update vitest version and enhance provider validation
- Upgraded `vitest` dependency to version 3.2.4 in package.json and yarn.lock for improved testing capabilities.
- Removed console error logging in provider validation functions to streamline error handling.
- Added comprehensive tests for the AiProviderRegistry functionality, ensuring robust provider management and dynamic registration.
- Introduced new test cases for provider schemas to validate configurations and IDs.
- Deleted outdated registry test file to maintain a clean test suite.
2025-08-19 11:13:03 +08:00
lizhixuan
d93a36e5c9 feat(toolUsePlugin): enhance tool parsing and extraction functionality
- Updated the `defaultParseToolUse` function to return both parsed results and remaining content, improving usability.
- Introduced a new `TagExtractor` class for flexible tag extraction, supporting various tag formats.
- Modified type definitions to reflect changes in parsing function signatures.
- Enhanced handling of tool call events in the `ToolCallChunkHandler` for better integration with the new parsing logic.
- Added `isBuiltIn` property to the `MCPTool` interface for clearer tool categorization.
2025-08-19 00:44:30 +08:00
MyPrototypeWhat
c9c0616c91 feat(provider): enhance provider registration and validation system
- Introduced a new Zod-based schema for provider validation, improving type safety and consistency.
- Added support for dynamic provider IDs and enhanced the provider registration process.
- Updated the AiProviderRegistry to utilize the new validation functions, ensuring robust provider management.
- Added tests for the provider registry to validate dynamic imports and functionality.
- Updated yarn.lock to reflect the latest dependency versions.
2025-08-18 19:41:43 +08:00
one
356443babf fix: remove default renderer from MessageTool 2025-08-17 17:00:43 +08:00
one
b2c512082f fix: missing dependencies 2025-08-17 16:41:50 +08:00
one
0273c58050 fix: file name 2025-08-17 16:14:36 +08:00
one
239c849890 Merge branch 'main' into feat/aisdk-package 2025-08-17 16:11:18 +08:00
MyPrototypeWhat
bbc472c169 refactor(types): modify ProviderId type definition for improved flexibility
- Updated ProviderId type from an intersection of keyof ExtensibleProviderSettingsMap and string to a union, allowing for greater compatibility with dynamic provider settings.
2025-08-15 18:11:43 +08:00
MyPrototypeWhat
b099c9b0b3 refactor(types): update ProviderId type definition for better compatibility
- Changed ProviderId type from a union of keyof ExtensibleProviderSettingsMap and string to an intersection, enhancing type safety.
- Renamed appendTrace method to appendMessageTrace in SpanManagerService for clarity and consistency.
- Updated references to appendTrace in useMessageOperations and ApiService to use the new method name.
- Added a new appendTrace method in SpanManagerService to bind existing traces, improving trace management.
- Adjusted topicId handling in fetchMessagesSummary to default to an empty string for better consistency.
2025-08-15 18:04:24 +08:00
MyPrototypeWhat
628919b562 chore: update dependencies and remove unused patches
- Updated various package versions in yarn.lock for improved compatibility and performance.
- Removed obsolete patches for antd and openai, streamlining the dependency management.
- Adjusted icon imports in Dropdown and useIcons to utilize Lucide icons for better visual consistency.
2025-08-15 11:47:24 +08:00
MyPrototypeWhat
d05ed94702 Merge branch 'main' into feat/aisdk-package 2025-08-15 11:11:20 +08:00
MyPrototypeWhat
02f8e7a857 feat(aiCore): add enableUrlContext capability and new support function
- Enhanced the buildStreamTextParams function to include enableUrlContext in the capabilities object, improving the parameter set for AI interactions.
- Introduced a new isSupportedFlexServiceTier function to streamline model support checks, enhancing code clarity and maintainability.
2025-08-14 19:31:17 +08:00
MyPrototypeWhat
6c093f72d8 feat(Dropdown): replace RightOutlined with ChevronRight icon and update useIcons to use ChevronDown
- Introduced a patch to replace the RightOutlined icon with ChevronRight in the Dropdown component for improved visual consistency.
- Updated the useIcons hook to utilize ChevronDown instead of DownOutlined, enhancing the icon set with Lucide icons.
- Adjusted icon properties for better customization and styling options.
2025-08-14 19:15:49 +08:00
MyPrototypeWhat
0bb1001d40 Merge remote-tracking branch 'origin/main' into feat/aisdk-package 2025-08-14 18:59:19 +08:00
MyPrototypeWhat
376020b23c refactor(aiCore): update ModernAiProvider constructor and clean up unused code
- Modified the constructor of ModernAiProvider to accept an optional provider parameter, enhancing flexibility in provider selection.
- Removed deprecated and unused functions related to search keyword extraction and search summary fetching, streamlining the codebase.
- Updated import statements and adjusted related logic to reflect the removal of obsolete functions, improving maintainability.
2025-08-14 18:06:11 +08:00
MyPrototypeWhat
3630133efd fix: update MessageKnowledgeSearch to use knowledgeReferences
- Modified MessageKnowledgeSearch component to display additional context from toolInput.
- Updated the fetch complete message to reflect the count of knowledgeReferences instead of toolOutput.
- Adjusted the mapping of results to iterate over knowledgeReferences for rendering.
2025-08-14 16:27:23 +08:00
MyPrototypeWhat
cb55f7a69b feat(aiCore): enhance AI SDK with tracing and telemetry support
- Integrated tracing capabilities into the ModernAiProvider, allowing for better tracking of AI completions and image generation processes.
- Added a new TelemetryPlugin to inject telemetry data into AI SDK requests, ensuring compatibility with existing tracing systems.
- Updated middleware and plugin configurations to support topic-based tracing, improving the overall observability of AI interactions.
- Introduced comprehensive logging throughout the AI SDK processes to facilitate debugging and performance monitoring.
- Added unit tests for new functionalities to ensure reliability and maintainability.
2025-08-14 16:17:41 +08:00
MyPrototypeWhat
ff7ad52ad5 feat(tests): add unit tests for utility functions in utils.test.ts
- Implemented tests for `createErrorChunk`, `capitalize`, and `isAsyncIterable` functions.
- Ensured comprehensive coverage for various input scenarios, including error handling and edge cases.
2025-08-08 15:20:02 +08:00
suyao
bf02afa841 chore(package.json): bump version to 1.0.0-alpha.7 2025-08-06 17:07:32 +08:00
suyao
e8b059c4db chore(yarn.lock): remove deprecated provider entries and clean up dependencies 2025-08-06 17:03:17 +08:00
suyao
abfec7a228 Merge branch 'feat/aisdk-package' of https://github.com/CherryHQ/cherry-studio into feat/aisdk-package 2025-08-06 17:01:57 +08:00
MyPrototypeWhat
eeafb99059 refactor: restructure aiCore for improved modularity and legacy support
- Introduced a new `index_new.ts` file to facilitate the modern AI provider while maintaining backward compatibility with the legacy `index.ts`.
- Created a `legacy` directory to house existing clients and middleware, ensuring a clear separation from new implementations.
- Updated import paths across various modules to reflect the new structure, enhancing code organization and maintainability.
- Added comprehensive middleware and utility functions to support the new architecture, improving overall functionality and extensibility.
- Enhanced plugin management with a dedicated `PluginBuilder` for better integration and configuration of AI plugins.
2025-08-05 19:42:57 +08:00
suyao
1ea8266280 fix: migrate to v5-patch2 2025-08-03 23:09:19 +08:00
suyao
def685921c Merge remote-tracking branch 'origin/main' into feat/aisdk-package 2025-08-02 21:02:23 +08:00
suyao
a8dbae1715 refactor: migrate to v5 patch-1 2025-08-02 19:52:17 +08:00
suyao
71959f577d refactor: enhance image generation handling and tool integration
- Updated image generation logic to support new model types and improved size handling.
- Refactored middleware configuration to better manage tool usage and reasoning capabilities.
- Introduced new utility functions for checking model compatibility with image generation.
- Enhanced the integration of plugins for improved functionality during image generation processes.
- Removed deprecated knowledge search tool to streamline the codebase.
2025-08-01 19:00:24 +08:00
suyao
ecc08bd3f7 feat: integrate image generation capabilities and enhance testing framework
- Added support for image generation in the `RuntimeExecutor` with a new `generateImage` method.
- Updated `aiCore` package to include `vitest` for testing, with new test scripts added.
- Enhanced type definitions to accommodate image model handling in plugins.
- Introduced new methods for resolving and executing image generation with plugins.
- Updated package dependencies in `package.json` to include `vitest` and ensure compatibility with new features.
2025-08-01 10:45:31 +08:00
MyPrototypeWhat
7216e9943c refactor: streamline async function syntax and enhance plugin event handling
- Simplified async function syntax in `RuntimeExecutor` and `PluginEngine` for improved readability.
- Updated `AiSdkToChunkAdapter` to refine condition checks for Google metadata.
- Enhanced `searchOrchestrationPlugin` to log conversation messages and improve memory storage logic.
- Improved memory processing by ensuring fallback for existing memories.
- Added new citation block handling in `toolCallbacks` for better integration with web search results.
2025-07-29 19:26:29 +08:00
MyPrototypeWhat
a05d7cbe2d refactor: enhance search orchestration and web search tool integration
- Updated `searchOrchestrationPlugin` to improve handling of assistant configurations and prevent concurrent analysis.
- Refactored `webSearchTool` to utilize pre-extracted keywords for more efficient web searches.
- Introduced a new `MessageKnowledgeSearch` component for displaying knowledge search results.
- Cleaned up commented-out code and improved type safety across various components.
- Enhanced the integration of web search results in the UI for better user experience.
2025-07-29 12:16:06 +08:00
lizhixuan
0310648445 feat: implement knowledge search tool and enhance search orchestration logic
- Added a new `knowledgeSearchTool` to facilitate knowledge base searches based on user queries and intent analysis.
- Refactored `analyzeSearchIntent` to simplify message context construction and improve prompt formatting.
- Introduced a flag to prevent concurrent analysis processes in `searchOrchestrationPlugin`.
- Updated tool configuration logic to conditionally add the knowledge search tool based on the presence of knowledge bases and user settings.
- Cleaned up commented-out code for better readability and maintainability.
2025-07-24 00:11:57 +08:00
MyPrototypeWhat
33db455e32 refactor: consolidate queue utility imports in messageThunk.ts
- Combined separate imports of `getTopicQueue` and `waitForTopicQueue` from the queue utility into a single import statement for improved code clarity and organization.
2025-07-23 15:01:48 +08:00
lizhixuan
e690da840c chore: bump @cherrystudio/ai-core version to 1.0.0-alpha.6 and refactor web search tool
- Updated version in package.json to 1.0.0-alpha.6.
- Simplified response structure in ToolCallChunkHandler by removing unnecessary nesting.
- Refactored input schema for web search tool to enhance type safety and clarity.
- Cleaned up commented-out code in MessageTool for improved readability.
2025-07-22 21:58:22 +08:00
lizhixuan
eca9442907 refactor: update message handling in searchOrchestrationPlugin for improved type safety
- Replaced `Message` type with `ModelMessage` in various functions to enhance type consistency.
- Refactored `getMessageContent` function to utilize the new `ModelMessage` type for better content extraction.
- Updated `storeConversationMemory` and `analyzeSearchIntent` functions to align with the new type definitions, ensuring clearer memory storage and intent analysis processes.
2025-07-22 21:58:12 +08:00
lizhixuan
4b62384fc5 <type>: <subject>
<body>
<footer>
用來簡要描述影響本次變動,概述即可
2025-07-22 18:52:39 +08:00
lizhixuan
addd5ffdfa feat: enhance ToolCallChunkHandler with detailed chunk handling and remove unused plugins
- Updated `handleToolCallCreated` method to support additional chunk types with optional provider metadata.
- Removed deprecated `smoothReasoningPlugin` and `textPlugin` files to clean up the codebase.
- Cleaned up unused type imports in `tool.ts` for improved clarity and maintainability.
2025-07-21 23:39:46 +08:00
MyPrototypeWhat
fcc8836c95 feat: update OpenAI provider integration and enhance type definitions
- Bumped version of `@ai-sdk/openai-compatible` to 1.0.0-beta.8 in package.json.
- Introduced a new provider configuration for 'OpenAI Responses' in AiProviderRegistry, allowing for more flexible response handling.
- Updated type definitions to include 'openai-responses' in ProviderSettingsMap for improved type safety.
- Refactored getModelToProviderId function to return a more specific ProviderId type.
2025-07-21 14:43:54 +08:00
suyao
61e3309cd2 fix: conditionally enable reasoning middleware for OpenAI and Azure providers
- Added a check to enable the 'thinking-tag-extraction' middleware only if reasoning is enabled in the configuration for OpenAI and Azure providers.
- Commented out the provider type check in `getAiSdkProviderId` to prevent issues with retrieving provider options.
2025-07-21 14:20:33 +08:00
MyPrototypeWhat
786bc8dca9 feat: enhance web search tool functionality and type definitions
- Introduced new `WebSearchToolOutputSchema` type to standardize output from web search tools.
- Updated `webSearchTool` and `webSearchToolWithExtraction` to utilize Zod for input and output schema validation.
- Refactored tool execution logic to improve error handling and response formatting.
- Cleaned up unused type imports and comments for better code clarity.
2025-07-18 19:33:54 +08:00
MyPrototypeWhat
c3a6456499 docs: update AI SDK architecture and README for enhanced clarity and new features
- Revised AI SDK architecture diagram to reflect changes in component relationships, replacing PluginEngine with RuntimeExecutor.
- Updated README to highlight core features, including a refined plugin system, improved architecture design, and new built-in plugins.
- Added detailed examples for using built-in plugins and creating custom plugins, enhancing documentation for better usability.
- Included future version roadmap and related resources for user reference.
2025-07-18 17:20:55 +08:00
MyPrototypeWhat
ef6be4a6f9 chore: bump @cherrystudio/ai-core version to 1.0.0-alpha.5
- Updated version in package.json to 1.0.0-alpha.5.
- Enhanced provider configuration validation in createProvider function for improved error handling.
2025-07-18 16:30:49 +08:00
MyPrototypeWhat
69e87ce21a refactor: streamline AI provider registration by replacing dynamic imports with direct creator functions
- Updated the AiProviderRegistry to use direct references to creator functions for each AI provider, improving clarity and performance.
- Removed dynamic import statements for providers, simplifying the registration process and enhancing maintainability.
2025-07-18 16:26:52 +08:00
MyPrototypeWhat
608943bdbc chore: update @cherrystudio/ai-core version to 1.0.0-alpha.4 and clean up dependencies
- Bumped version in package.json to 1.0.0-alpha.4.
- Removed deprecated dependencies from package.json and yarn.lock for improved clarity.
- Updated README to reflect changes in supported providers and installation instructions.
- Refactored provider registration and usage examples for better clarity and usability.
2025-07-18 15:58:43 +08:00
MyPrototypeWhat
1248e3c49a refactor: reorganize provider and model exports for improved structure
- Updated exports in index.ts and related files to streamline provider and model management.
- Introduced a new ModelCreator module for better encapsulation of model creation logic.
- Refactored type imports to enhance clarity and maintainability across the codebase.
- Removed deprecated provider configurations and cleaned up unused code for better performance.
2025-07-18 15:35:44 +08:00
MyPrototypeWhat
c3ad18b77e chore: bump @cherrystudio/ai-core version to 1.0.0-alpha.2 and update exports
- Updated version in package.json to 1.0.0-alpha.2.
- Added new path mapping for @cherrystudio/ai-core in tsconfig.web.json.
- Refactored export paths in tsdown.config.ts and index.ts for consistency.
- Cleaned up type exports in index.ts and types.ts for better organization.
2025-07-18 11:39:27 +08:00
suyao
0bc5e3d24d chore: bump @cherrystudio/ai-core version to 1.0.0-alpha.1 2025-07-18 11:03:32 +08:00
suyao
36e20d545b feat: add React Native support to aiCore package
Add React Native compatibility configuration to package.json, including the
react-native field and updated exports mapping. Include documentation for
React Native usage with metro.config.js setup instructions.
2025-07-18 11:03:09 +08:00
lizhixuan
45405213fc feat: enhance AI core functionality and introduce new tool components
- Updated README to reflect the addition of a powerful plugin system and built-in web search capabilities.
- Refactored tool call handling in `ToolCallChunkHandler` to improve state management and response formatting.
- Introduced new components `MessageMcpTool`, `MessageTool`, and `MessageTools` for better handling of tool responses and user interactions.
- Updated type definitions to support new tool response structures and improved overall code organization.
- Enhanced spinner component to accept React nodes for more flexible content rendering.
2025-07-18 00:37:28 +08:00
suyao
b83837708b chore(aiCore/version): update version to 1.0.0-alpha.0 2025-07-17 21:10:59 +08:00
suyao
4732c8f1bd chore: update package.json and add tsdown configuration for build process
- Changed the main and types entries in package.json to point to the dist directory for better output management.
- Added a new tsdown.config.ts file to define the build configuration, specifying entry points, output directory, and formats for the project.
2025-07-17 21:00:32 +08:00
suyao
ef8cf65ece chore: remove deprecated patches for @ai-sdk/google-vertex and @ai-sdk/openai-compatible
- Deleted outdated patch files for @ai-sdk/google-vertex and @ai-sdk/openai-compatible from the project.
- Updated package.json to reflect the removal of these patches, streamlining dependency management.
2025-07-17 20:44:29 +08:00
suyao
e3c5c87e1b chore: add repository metadata and homepage to package.json
- Included repository URL, bugs URL, and homepage in package.json for better project visibility and issue tracking.
- This update enhances the package's metadata, making it easier for users to find relevant resources and report issues.
2025-07-17 20:39:39 +08:00
MyPrototypeWhat
e7d5626055 refactor: enhance provider settings and update web search plugin configuration
- Updated providerSettings to allow optional 'mode' parameter for various providers, enhancing flexibility in model configuration.
- Refactored web search plugin to integrate Google search capabilities and streamline provider options handling.
- Removed deprecated code and improved type definitions for better clarity and maintainability.
- Added console logging for debugging purposes in the provider configuration process.
2025-07-17 18:12:26 +08:00
MyPrototypeWhat
650650a68f refactor: reorganize AiSdkToChunkAdapter and enhance tool call handling
- Moved AiSdkToChunkAdapter to a new directory structure for better organization.
- Implemented detailed handling for tool call events in ToolCallChunkHandler, including creation, updates, and completions.
- Added a new method to handle tool call creation and improved state management for active tool calls.
- Updated StreamProcessingService to support new chunk types and callbacks for block creation.
- Enhanced type definitions and added comments for clarity in the new chunk handling logic.
2025-07-17 16:30:26 +08:00
MyPrototypeWhat
f38e4a87b8 chore: update package dependencies and improve AI SDK chunk handling
- Bumped versions of several dependencies in package.json, including `@swc/plugin-styled-components` to 8.0.4 and `@vitejs/plugin-react-swc` to 3.10.2.
- Enhanced `AiSdkToChunkAdapter` to streamline chunk processing, including better handling of text and reasoning events.
- Added console logging for debugging in `BlockManager` and `messageThunk` to track state changes and callback executions.
- Updated integration tests to reflect changes in message structure and types.
2025-07-17 13:49:06 +08:00
MyPrototypeWhat
a356492d6f Merge remote-tracking branch 'origin/main' into feat/aisdk-package 2025-07-17 11:59:50 +08:00
suyao
8863e10df1 fix: update provider identification logic in aiCore
- Refactored the provider identification in `index_new.ts` to use `actualProvider.type` instead of `actualProvider.id` for better clarity and accuracy in determining OpenAI response modes.
- Removed redundant type checks in `factory.ts` to streamline the provider ID retrieval process.
2025-07-17 03:21:52 +08:00
suyao
42bfa281a7 chore: update dependencies and versions in package.json and yarn.lock
- Upgraded various SDK packages to their latest beta versions for improved functionality and compatibility.
- Updated `@ai-sdk/provider-utils` to version 3.0.0-beta.3.
- Adjusted dependencies in `package.json` to reflect the latest versions, including `@ai-sdk/amazon-bedrock`, `@ai-sdk/anthropic`, `@ai-sdk/azure`, and others.
- Removed outdated versions from `yarn.lock` and ensured consistency across the project.
2025-07-17 03:10:48 +08:00
suyao
e7b4f1f934 feat: add type property to server tools in MCPService
- Enhanced the server tool structure by adding a `type` property set to 'mcp' for better identification and handling of tools within the MCPService.
2025-07-15 23:50:53 +08:00
suyao
0456094512 feat: enhance web search functionality and tool integration
- Introduced `extractSearchKeywords` function to facilitate keyword extraction from user messages for web searches.
- Updated `webSearchTool` to streamline the execution of web searches without requiring a request ID.
- Enhanced `WebSearchService` methods to be static for improved accessibility and clarity.
- Modified `ApiService` to pass `webSearchProviderId` for better integration with the web search functionality.
- Improved `ToolCallChunkHandler` to handle built-in tools more effectively.
2025-07-15 23:39:49 +08:00
suyao
da455997ad feat: integrate web search tool and enhance tool handling
- Added `webSearchTool` to facilitate web search functionality within the SDK.
- Updated `AiSdkToChunkAdapter` to utilize `BaseTool` for improved type handling.
- Refactored `transformParameters` to support `webSearchProviderId` for enhanced web search integration.
- Introduced new `BaseTool` type structure to unify tool definitions across the codebase.
- Adjusted imports and type definitions to align with the new tool handling logic.
2025-07-15 22:47:43 +08:00
lizhixuan
0c4e8228af feat: enhance AiSdkToChunkAdapter for web search results handling
- Updated `AiSdkToChunkAdapter` to include `webSearchResults` in the final output structure for improved web search integration.
- Modified `convertAndEmitChunk` method to handle `finish-step` events, differentiating between Google and other web search results.
- Adjusted the handling of `source` events to accumulate web search results for better processing.
- Enhanced citation formatting in `messageBlock.ts` to support new web search result structures.
2025-07-12 21:11:17 +08:00
lizhixuan
16e0154200 feat: enhance provider settings and model configuration
- Updated `ModelConfig` to include a `mode` property for better differentiation between 'chat' and 'responses'.
- Modified `createBaseModel` to conditionally set the provider based on the new `mode` property in `providerSettings`.
- Refactored `RuntimeExecutor` to utilize the updated `ModelConfig` for improved type safety and clarity in provider settings.
- Adjusted imports in `executor.ts` and `types.ts` to align with the new model configuration structure.
2025-07-12 11:31:06 +08:00
MyPrototypeWhat
3ab904e789 feat: enhance web search plugin and tool handling
- Refactored `helper.ts` to export new types `AnthropicSearchInput` and `AnthropicSearchOutput` for better integration with the web search plugin.
- Updated `index.ts` to include the new types in the exports for improved type safety.
- Modified `AiSdkToChunkAdapter.ts` to handle tool calls more flexibly by introducing a `GenericProviderTool` type, allowing for better differentiation between MCP tools and provider-executed tools.
- Adjusted `handleTooCallChunk.ts` to accommodate the new tool type structure, enhancing the handling of tool call responses.
- Updated type definitions in `index.ts` to reflect changes in tool handling logic.
2025-07-11 19:15:21 +08:00
MyPrototypeWhat
42c7ebd193 feat: enhance model handling and provider integration
- Updated `createBaseModel` to differentiate between OpenAI chat and response models.
- Introduced new utility functions for model identification: `isOpenAIReasoningModel`, `isOpenAILLMModel`, and `getModelToProviderId`.
- Improved `transformParameters` to conditionally set the system prompt based on the assistant's prompt.
- Refactored `getAiSdkProviderIdForAihubmix` to simplify provider identification logic.
- Enhanced `getAiSdkProviderId` to support provider type checks.
2025-07-11 16:45:54 +08:00
suyao
a0623f2187 chore: update ai package version to 5.0.0-beta.9 in package.json and yarn.lock 2025-07-10 02:57:50 +08:00
MyPrototypeWhat
4bfff85dc8 feat: enhance web search plugin configuration
- Added `sources` array to the default web search configuration, allowing for multiple source types including 'web', 'x', and 'news'.
- This update improves the flexibility and functionality of the web search plugin.
2025-07-08 15:24:51 +08:00
suyao
8317ad55e7 Merge branch 'feat/aisdk-package' of https://github.com/CherryHQ/cherry-studio into feat/aisdk-package 2025-07-08 13:40:39 +08:00
suyao
b67cd9d145 fix: azure-openai provider 2025-07-08 13:38:14 +08:00
MyPrototypeWhat
234514d736 refactor: improve web search plugin and middleware integration
- Cleaned up the web search plugin code by commenting out unused sections for clarity.
- Enhanced middleware handling for the OpenAI provider by wrapping the logic in a block for better readability.
- Removed redundant imports from various files to streamline the codebase.
- Added `enableWebSearch` parameter to the fetchChatCompletion function for improved functionality.
2025-07-08 13:15:41 +08:00
suyao
450d6228d4 feat: aihubmix support 2025-07-08 03:47:25 +08:00
lizhixuan
3c955e69f1 feat: conditionally enable web search plugin based on configuration
- Updated the logic to add the `webSearchPlugin` only if `middlewareConfig.enableWebSearch` is true.
- Added comments to clarify the use of default search parameters and configuration options.
2025-07-07 23:33:22 +08:00
lizhixuan
4573e3f48f feat: add XAI provider options and enhance web search plugin
- Introduced `createXaiOptions` function for XAI provider configuration.
- Added `XaiProviderOptions` type and validation schema in `xai.ts`.
- Updated `ProviderOptionsMap` to include XAI options.
- Enhanced `webSearchPlugin` to support XAI-specific search parameters.
- Refactored helper functions to integrate new XAI options into provider configurations.
2025-07-07 23:28:49 +08:00
suyao
56c5e5a80f fix: format apihost 2025-07-07 21:45:18 +08:00
MyPrototypeWhat
bb520910bc refactor: update type exports and enhance web search functionality
- Added `ReasoningPart`, `FilePart`, and `ImagePart` to type exports in `index.ts`.
- Refactored `transformParameters.ts` to include `enableWebSearch` option and integrate web search tools.
- Introduced new utility `getWebSearchTools` in `websearch.ts` to manage web search tool configurations based on model type.
- Commented out deprecated code in `smoothReasoningPlugin.ts` and `textPlugin.ts` for potential removal.
2025-07-07 19:34:22 +08:00
suyao
342c5ab82c Merge branch 'feat/aisdk-package' of https://github.com/CherryHQ/cherry-studio into feat/aisdk-package 2025-07-07 18:43:27 +08:00
suyao
fce8f2411c fix: openai-gemini support 2025-07-07 18:42:31 +08:00
MyPrototypeWhat
0a908a334b refactor: enhance model configuration and plugin execution
- Simplified the `createModel` function to directly accept the `ModelConfig` object, improving clarity.
- Updated `createBaseModel` to include `extraModelConfig` for extended configuration options.
- Introduced `executeConfigureContext` method in `PluginManager` to handle context configuration for plugins.
- Adjusted type definitions in `types.ts` to ensure consistency with the new configuration structure.
- Refactored plugin execution methods in `PluginEngine` to utilize the resolved model directly, enhancing the flow of data through the plugin system.
2025-07-07 18:33:51 +08:00
suyao
c72156b2da feat: support image 2025-07-07 14:27:03 +08:00
suyao
9e252d7eb0 fix(provider): config error patch-1 2025-07-07 04:46:36 +08:00
suyao
4b0d8d7e65 fix(provider): config error 2025-07-07 04:33:37 +08:00
suyao
448b5b5c9e refactor: migrate to v5 patch-2 2025-07-07 03:58:10 +08:00
suyao
f20d964be3 Merge branch 'feat/aisdk-package' of https://github.com/CherryHQ/cherry-studio into feat/aisdk-package 2025-07-07 02:09:01 +08:00
lizhixuan
c92475b6bf refactor: streamline model configuration and factory functions
- Updated the `createModel` function to accept a simplified `ModelConfig` interface, enhancing clarity and usability.
- Refactored `createBaseModel` to destructure parameters for better readability and maintainability.
- Removed the `ModelCreator.ts` file as its functionality has been integrated into the factory functions.
- Adjusted type definitions in `types.ts` to reflect changes in model configuration structure, ensuring consistency across the codebase.
2025-07-07 00:34:32 +08:00
suyao
89cbf80008 fix: unexpected chunk 2025-07-06 23:37:42 +08:00
suyao
3e5969b97c refactor: migrate to v5 patch-1 2025-07-06 04:25:11 +08:00
suyao
cd42410d70 chore: migrate to v5 2025-07-05 13:28:19 +08:00
MyPrototypeWhat
547e5785c0 feat: add web search plugin for enhanced AI provider capabilities
- Introduced a new `webSearchPlugin` to provide unified web search functionality across multiple AI providers.
- Added helper functions for adapting web search parameters for OpenAI, Gemini, and Anthropic providers.
- Updated the built-in plugin index to export the new web search plugin and its configuration type.
- Created a new `helper.ts` file to encapsulate web search adaptation logic and support checks for provider compatibility.
2025-07-04 19:35:37 +08:00
MyPrototypeWhat
13162edcb2 refactor: remove providerParams utility module 2025-07-04 13:53:27 +08:00
MyPrototypeWhat
ac15930692 feat: enhance OpenAI model handling with utility function
- Introduced `isOpenAIChatCompletionOnlyModel` utility function to determine if a model ID corresponds to OpenAI's chat completion-only models.
- Updated `createBaseModel` function to utilize the new utility for improved handling of OpenAI provider responses in strict mode.
- Refactored reasoning parameters in `getOpenAIReasoningParams` for consistency and clarity.
2025-07-02 19:31:33 +08:00
MyPrototypeWhat
ff3b1fc38f feat: enhance OpenAI provider handling and add providerParams utility module
- Updated the `createBaseModel` function to handle OpenAI provider responses in strict mode.
- Modified `providerToAiSdkConfig` to include specific options for OpenAI when in strict mode.
- Introduced a new utility module `providerParams.ts` for managing provider-specific parameters, including OpenAI, Anthropic, and Gemini configurations.
- Added functions to retrieve service tiers, specific parameters, and reasoning efforts for various providers, improving overall provider management.
2025-07-02 16:43:06 +08:00
MyPrototypeWhat
b660e9d524 feat: implement useSmoothStream hook for dynamic text rendering
- Added a new custom hook `useSmoothStream` to manage smooth text streaming with adjustable delays.
- Integrated the `useSmoothStream` hook into the `Markdown` component to enhance content display during streaming.
- Improved state management for displayed content and stream completion status in the `Markdown` component.
2025-07-01 17:21:57 +08:00
MyPrototypeWhat
182ab6092c refactor: update reasoning plugins and enhance performance
- Replaced `smoothReasoningPlugin` with `reasoningTimePlugin` to improve reasoning time tracking.
- Commented out the unused `textPlugin` in the plugin list for better clarity.
- Adjusted delay settings in both `smoothReasoningPlugin` and `textPlugin` for optimized processing.
- Enhanced logging in reasoning plugins for better debugging and performance insights.
2025-07-01 15:28:06 +08:00
MyPrototypeWhat
cf5ed8e858 refactor: streamline reasoning plugins and remove unused components
- Removed the `reasoningTimePlugin` and `mcpPromptPlugin` to simplify the plugin architecture.
- Updated the `smoothReasoningPlugin` to enhance its functionality and reduce delay in processing.
- Adjusted the `textPlugin` to align with the new delay settings for smoother output.
- Modified the `ModernAiProvider` to utilize the updated `smoothReasoningPlugin` without the removed plugins.
2025-06-30 18:34:08 +08:00
suyao
007de81928 chore: update OpenRouter provider to version 0.7.2 and add support functions
- Updated the OpenRouter provider dependency in `package.json` and `yarn.lock` to version 0.7.2.
- Added a new function `createOpenRouterOptions` in `factory.ts` for creating OpenRouter provider options.
- Updated type definitions in `types.ts` and `registry.ts` to include OpenRouter provider settings, enhancing provider management.
2025-06-29 21:29:57 +08:00
suyao
6c87b42607 refactor: remove OpenRouter provider support and streamline reasoning logic
- Commented out the OpenRouter provider in `registry.ts` and related configurations due to excessive bugs.
- Simplified reasoning logic in `transformParameters.ts` and `options.ts` by removing unnecessary checks for `enableReasoning`.
- Enhanced logging in `transformParameters.ts` to provide better insights into reasoning capabilities.
- Updated `getReasoningEffort` to handle cases where reasoning effort is not defined, improving model compatibility.
2025-06-29 15:16:47 +08:00
suyao
592a7ddc3f Merge branch 'main' into feat/aisdk-package 2025-06-29 03:57:28 +08:00
suyao
60cb198f44 refactor: simplify provider validation and enhance plugin configuration
- Commented out the provider support check in `RuntimeExecutor` to streamline initialization.
- Updated `providerToAiSdkConfig` to utilize `AiCore.isSupported` for improved provider validation.
- Enhanced middleware configuration in `ModernAiProvider` to ensure tools are only added when enabled and available.
- Added comments in `transformParameters` for clarity on parameter handling and plugin activation.
2025-06-29 03:55:29 +08:00
suyao
54c36040af feat: extend buildStreamTextParams to include capabilities for enhanced AI functionality
- Updated the return type of `buildStreamTextParams` to include `capabilities` for reasoning, web search, and image generation.
- Modified `fetchChatCompletion` to utilize the new capabilities structure, improving middleware configuration based on model capabilities.
2025-06-29 02:59:38 +08:00
MyPrototypeWhat
ef616e1c3b fix: update reasoningTimePlugin and smoothReasoningPlugin for improved performance tracking
- Changed the invocation of `reasoningTimePlugin` to a direct reference in `ModernAiProvider`.
- Initialized `thinkingStartTime` with `performance.now()` in `reasoningTimePlugin` for accurate timing.
- Removed `thinking_millsec` from the enqueued chunks in `smoothReasoningPlugin` to streamline data handling.
- Added console logging for performance tracking in `reasoningTimePlugin` to aid in debugging.
2025-06-27 19:24:23 +08:00
MyPrototypeWhat
dc106a8af7 refactor: streamline error handling and logging in ModernAiProvider
- Commented out the try-catch block in the `ModernAiProvider` class to simplify the code structure.
- Enhanced readability by removing unnecessary error logging while maintaining the core functionality of the AI processing flow.
- Updated `messageThunk` to incorporate an abort controller for improved request management during message processing.
2025-06-27 17:08:22 +08:00
MyPrototypeWhat
1bcc716eaf refactor: rename and restructure message handling in Conversation and Orchestrate services
- Renamed `prepareMessagesForLlm` to `prepareMessagesForModel` in `ConversationService` for clarity.
- Updated `OrchestrationService` to use the new method name and introduced a new function `transformMessagesAndFetch` for improved message processing.
- Adjusted imports in `messageThunk` to reflect the changes in the orchestration service, enhancing code readability and maintainability.
2025-06-27 16:38:32 +08:00
MyPrototypeWhat
30a288ce5d feat: introduce MCP Prompt Plugin and refactor built-in plugin structure
- Added `mcpPromptPlugin.ts` to encapsulate MCP Prompt functionality, providing a structured approach for tool calls within prompts.
- Updated `index.ts` to reference the new `mcpPromptPlugin`, enhancing modularity and clarity in the built-in plugins.
- Removed the outdated `example-plugins.ts` file to streamline the plugin directory and focus on essential components.
2025-06-27 15:45:56 +08:00
suyao
cbbaa3127c Merge branch 'feat/aisdk-package' of https://github.com/CherryHQ/cherry-studio into feat/aisdk-package 2025-06-27 15:11:04 +08:00
suyao
f61da8c2d6 feat: enhance ModernAiProvider with new reasoning plugins and dynamic middleware construction
- Introduced `reasoningTimePlugin` and `smoothReasoningPlugin` to improve reasoning content handling and processing.
- Refactored `ModernAiProvider` to dynamically build plugin arrays based on middleware configuration, enhancing flexibility.
- Removed the obsolete `ThinkingTimeMiddleware` to streamline middleware management.
- Updated `buildAiSdkMiddlewares` to reflect changes in middleware handling and improve clarity in the configuration process.
- Enhanced logging for better visibility into plugin and middleware configurations during execution.
2025-06-27 15:10:47 +08:00
MyPrototypeWhat
d9eb9e86fe refactor: disable console logging in MCP Prompt plugin for cleaner output
- Commented out console log statements in the `createMCPPromptPlugin` to reduce noise during execution.
- Maintained the structure and functionality of the plugin while improving readability and performance.
2025-06-27 15:03:58 +08:00
suyao
87f803b0d3 feat: update package dependencies and introduce new patches for AI SDK tools
- Added patches for `@ai-sdk/google-vertex` and `@ai-sdk/openai-compatible` to enhance functionality and fix issues.
- Updated `package.json` to reflect new dependency versions and patch paths.
- Refactored `transformParameters` and `ApiService` to support new tool configurations and improve parameter handling.
- Introduced utility functions for setting up tools and managing options, enhancing the overall integration of tools within the AI SDK.
2025-06-27 13:21:33 +08:00
lizhixuan
c934b45c09 feat: enhance MCP Prompt plugin with recursive call support and context handling
- Updated `AiRequestContext` to enforce `recursiveCall` and added `isRecursiveCall` for better state management.
- Modified `createContext` to initialize `recursiveCall` with a placeholder function.
- Enhanced `MCPPromptPlugin` to utilize a custom `createSystemMessage` function for improved message handling during recursive calls.
- Refactored `PluginEngine` to manage recursive call states, ensuring proper execution flow and context integrity.
2025-06-26 23:48:06 +08:00
lizhixuan
ba121d04b4 <type>: <subject>
<body>
<footer>
用來簡要描述本次變動的影響,概述即可
2025-06-26 21:33:05 +08:00
MyPrototypeWhat
9293f26612 feat: enhance MCP Prompt plugin and recursive call capabilities
- Updated `tsconfig.web.json` to support wildcard imports for `@cherrystudio/ai-core`.
- Enhanced `package.json` to include type definitions and imports for built-in plugins.
- Introduced recursive call functionality in `PluginManager` and `PluginEngine`, allowing for improved handling of tool interactions.
- Added `MCPPromptPlugin` to facilitate tool calls within prompts, enabling recursive processing of tool results.
- Refactored `transformStream` methods across plugins to accommodate new parameters and improve type safety.
2025-06-26 19:42:04 +08:00
lizhixuan
8b67a45804 refactor: update RuntimeExecutor and introduce MCP Prompt Plugin
- Changed `pluginClient` to `pluginEngine` in `RuntimeExecutor` for clarity and consistency.
- Updated method calls in `RuntimeExecutor` to use the new `pluginEngine`.
- Enhanced `AiSdkMiddlewareBuilder` to include `mcpTools` in the middleware configuration.
- Added `MCPPromptPlugin` to support tool calls within prompts, enabling recursive processing and improved handling of tool interactions.
- Updated `ApiService` to pass `mcpTools` during chat completion requests, enhancing integration with the new plugin system.
2025-06-26 00:10:39 +08:00
MyPrototypeWhat
f23a026a28 feat: enhance plugin system with new reasoning and text plugins
- Introduced `reasonPlugin` and `textPlugin` to improve chunk processing and handling of reasoning content.
- Updated `transformStream` method signatures for better type safety and usability.
- Enhanced `ThinkingTimeMiddleware` to accurately track thinking time using `performance.now()`.
- Refactored `ThinkingBlock` component to utilize block thinking time directly, improving performance and clarity.
- Added logging for middleware builder to assist in debugging and monitoring middleware configurations.
2025-06-25 19:00:54 +08:00
MyPrototypeWhat
e4c0ea035f feat: enhance AI Core runtime with advanced model handling and middleware support
- Introduced new high-level APIs for model creation and configuration, improving usability for advanced users.
- Enhanced the RuntimeExecutor to support both direct model usage and model ID resolution, allowing for more flexible execution options.
- Updated existing methods to accept middleware configurations, streamlining the integration of custom processing logic.
- Refactored the plugin system to better accommodate middleware, enhancing the overall extensibility of the AI Core.
- Improved documentation to reflect the new capabilities and usage patterns for the runtime APIs.
2025-06-25 17:25:45 +08:00
lizhixuan
7d8ed3a737 refactor: simplify AI Core architecture and enhance runtime execution
- Restructured the AI Core documentation to reflect a simplified two-layer architecture, focusing on clear responsibilities between models and runtime layers.
- Removed the orchestration layer and consolidated its functionality into the runtime layer, streamlining the API for users.
- Introduced a new runtime executor for managing plugin-enhanced AI calls, improving the handling of execution and middleware.
- Updated the core modules to enhance type safety and usability, including comprehensive type definitions for model creation and execution configurations.
- Removed obsolete files and refactored existing code to improve organization and maintainability across the SDK.
2025-06-23 23:58:05 +08:00
MyPrototypeWhat
2a588fdab2 refactor: restructure AI Core architecture and enhance client functionality
- Updated the AI Core documentation to reflect the new architecture and design principles, emphasizing modularity and type safety.
- Refactored the client structure by removing obsolete files and consolidating client creation logic into a more streamlined format.
- Introduced a new core module for managing execution and middleware, improving the overall organization of the codebase.
- Enhanced the orchestration layer to provide a clearer API for users, integrating the creation and execution processes more effectively.
- Added comprehensive type definitions and utility functions for better type safety and usability across the SDK.
2025-06-23 19:51:40 +08:00
suyao
f08c444ffb feat: enhance provider ID resolution in AI SDK
- Updated getAiSdkProviderId function to include mapping for provider types, improving compatibility with third-party SDKs.
- Refined return logic to ensure correct provider ID resolution, enhancing overall functionality and support for various providers.
2025-06-21 23:46:06 +08:00
lizhixuan
f6c3794ac9 feat: enhance AI SDK chunk handling and tool call processing
- Introduced ToolCallChunkHandler for managing tool call events and results, improving the handling of tool interactions.
- Updated AiSdkToChunkAdapter to utilize the new handler, streamlining the processing of tool call chunks.
- Refactored transformParameters to support dynamic tool integration and improved parameter handling.
- Adjusted provider mapping in factory.ts to include new provider types, enhancing compatibility with various AI services.
- Removed obsolete cherryStudioTransformPlugin to clean up the codebase and focus on more relevant functionality.
2025-06-21 23:26:52 +08:00
suyao
ebe85ba24a fix: enhance anthropic provider configuration and middleware handling
- Updated providerToAiSdkConfig to support both OpenAI and Anthropic providers, improving flexibility in API host formatting.
- Refactored thinkingTimeMiddleware to ensure all chunks are correctly enqueued, enhancing middleware functionality.
- Corrected parameter naming in getAnthropicReasoningParams for consistency and clarity in configuration.
2025-06-21 22:38:54 +08:00
suyao
09080f0755 feat: add OpenAI Compatible provider and enhance provider configuration
- Introduced a new OpenAI Compatible provider to the AiProviderRegistry, allowing for integration with the @ai-sdk/openai-compatible package.
- Updated provider configuration logic to support the new provider, including adjustments to API host formatting and options management.
- Refactored middleware to streamline handling of OpenAI-specific configurations.
2025-06-21 22:19:10 +08:00
suyao
e421b81fca feat: add patch for Google Vertex AI and enhance private key handling
- Introduced a patch for the @ai-sdk/google-vertex package to improve URL handling based on region.
- Added a new utility function to format private keys, ensuring correct PEM structure and validation.
- Updated the ProviderConfigBuilder to utilize the new private key formatting function for Google credentials.
- Created a pnpm workspace configuration to manage patched dependencies effectively.
2025-06-21 20:31:24 +08:00
suyao
2f58b3360e feat: enhance provider options and examples for AI SDK
- Introduced new utility functions for creating and merging provider options, improving type safety and usability.
- Added comprehensive examples for OpenAI, Anthropic, Google, and generic provider options to demonstrate usage.
- Refactored existing code to streamline provider configuration and enhance clarity in the options management.
- Updated the PluginEnabledAiClient to simplify the handling of model parameters and improve overall functionality.
2025-06-21 16:48:16 +08:00
suyao
f934b479b2 feat: enhance Vertex AI provider integration and configuration
- Added support for Google Vertex AI credentials in the provider configuration.
- Refactored the VertexAPIClient to handle both standard and VertexProvider types.
- Implemented utility functions to check Vertex AI configuration completeness and create VertexProvider instances.
- Updated provider mapping in index_new.ts to ensure proper handling of Vertex AI settings.
2025-06-21 14:08:35 +08:00
suyao
8ca6341609 feat: add openai-compatible provider and enhance provider configuration
- Introduced the @ai-sdk/openai-compatible package to support compatibility with OpenAI.
- Added a new ProviderConfigFactory and ProviderConfigBuilder for streamlined provider configuration.
- Updated the provider registry to include the new Google Vertex AI import path.
- Enhanced the index.ts to export new provider configuration utilities for better type safety and usability.
- Refactored ApiService and middleware to integrate the new provider configurations effectively.
2025-06-21 12:48:53 +08:00
MyPrototypeWhat
c99a2fedb7 feat: enhance AI SDK documentation and client functionality
- Added detailed usage examples for the native provider registry in the README.md, demonstrating how to create and utilize custom provider registries.
- Updated ApiClientFactory to enforce type safety for model instances.
- Refactored PluginEnabledAiClient methods to support both built-in logic and custom registry usage for text and object generation, improving flexibility and usability.
2025-06-20 20:28:44 +08:00
MyPrototypeWhat
456e6c068e refactor: update ApiClientFactory and index_new for improved type handling and provider mapping
- Changed the type of options in ClientConfig to 'any' for flexibility.
- Overloaded createImageClient method to support different provider settings.
- Added vertexai mapping to the provider type mapping in index_new.ts for enhanced compatibility.
2025-06-20 20:25:19 +08:00
MyPrototypeWhat
f206d4ec4c fix: refine experimental_transform handling and improve chunking logic
- Updated PluginEnabledAiClient to streamline the handling of experimental_transform parameters.
- Adjusted ModernAiProvider's smoothStream configuration for better chunking of text, enhancing processing efficiency.
- Re-enabled block updates in messageThunk for improved state management.
2025-06-20 20:14:37 +08:00
MyPrototypeWhat
1af8be8768 feat: add Cherry Studio transformation and settings plugins
- Introduced cherryStudioTransformPlugin for converting Cherry Studio messages to AI SDK format, enhancing compatibility.
- Added cherryStudioSettingsPlugin to manage Assistant settings like temperature and TopP.
- Implemented createCherryStudioContext function for preparing context metadata for Cherry Studio calls.
2025-06-20 20:14:10 +08:00
suyao
e70174817e refactor: update AiSdkToChunkAdapter and middleware for improved chunk handling
- Modified convertAndEmitChunk method to handle new chunk types and streamline processing.
- Adjusted thinkingTimeMiddleware to remove unnecessary parameters and enhance clarity.
- Updated middleware integration in AiSdkMiddlewareBuilder for better middleware management.
2025-06-20 20:12:44 +08:00
MyPrototypeWhat
c5cb443de0 feat: enhance AI SDK documentation and client functionality
- Added detailed usage examples for the native provider registry in the README.md, demonstrating how to create and utilize custom provider registries.
- Updated ApiClientFactory to enforce type safety for model instances.
- Refactored PluginEnabledAiClient methods to support both built-in logic and custom registry usage for text and object generation, improving flexibility and usability.
2025-06-20 20:12:44 +08:00
suyao
9318d9ffeb feat: enhance AI core functionality with smoothStream integration
- Added smoothStream to the middleware exports in index.ts for improved streaming capabilities.
- Updated PluginEnabledAiClient to conditionally apply middlewares, removing the default simulateStreamingMiddleware.
- Modified ModernAiProvider to utilize smoothStream in streamText, enhancing text processing with configurable chunking and delay options.
2025-06-20 20:12:44 +08:00
suyao
3771b24b52 feat: enhance AI SDK middleware integration and support
- Added AiSdkMiddlewareBuilder for dynamic middleware construction based on various conditions.
- Updated ModernAiProvider to utilize new middleware configuration, improving flexibility in handling completions.
- Refactored ApiService to pass middleware configuration during AI completions, enabling better control over processing.
- Introduced new README documentation for the middleware builder, outlining usage and supported conditions.
2025-06-20 20:08:10 +08:00
suyao
1bccfd3170 feat: 完成api层,业务逻辑层,编排层的分离
feat: 为插件系统实现中间件
feat: 实现自定义的思考中间件

- Updated package.json and related files to reflect the correct naming convention for the @cherrystudio/ai-core package.
- Adjusted import paths in various files to ensure consistency with the new package name.
- Enhanced type resolution in tsconfig.web.json to align with the updated package structure.
2025-06-20 20:01:13 +08:00
MyPrototypeWhat
43d55b7e45 feat: integrate @cherry-studio/ai-core and enhance AI SDK support
- Added @cherry-studio/ai-core as a workspace dependency in package.json for improved modularity.
- Updated tsconfig to include paths for the new AI core package, enhancing type resolution.
- Refactored aiCore package to use source files directly, improving build efficiency.
- Introduced a new AiSdkToChunkAdapter for converting AI SDK streams to Cherry Studio chunk format.
- Implemented a modernized AI provider interface in index_new.ts, allowing fallback to legacy implementations.
- Enhanced parameter transformation logic for better integration with AI SDK features.
- Updated ApiService to utilize the new AI provider, streamlining chat completion requests.
2025-06-20 19:59:36 +08:00
MyPrototypeWhat
1c5a30cf49 feat: enhance AI Core with new client and plugin system features
- Introduced `PluginEnabledAiClient` for a more flexible client interface with integrated plugin support.
- Updated `ApiClientFactory` and `UniversalAiSdkClient` to utilize new provider settings for improved type safety.
- Added a comprehensive plugin management system, allowing for dynamic plugin registration and execution.
- Enhanced the provider registry to include new AI providers and updated existing provider settings.
- Removed deprecated files and streamlined the codebase for better maintainability and clarity.
- Updated documentation to reflect new features and usage examples for the plugin system.
2025-06-20 19:57:55 +08:00
suyao
2df1cddb43 feat: enhance AI Core with image generation capabilities
- Introduced `createImageClient` method in `ApiClientFactory` to support image generation for various providers.
- Updated `UniversalAiSdkClient` to include `generateImage` method, allowing image generation through the unified client interface.
- Refactored client creation functions to utilize the new `ProviderOptions` type for improved type safety.
- Enhanced the provider registry to indicate which providers support image generation, streamlining client creation and usage.
- Updated type definitions in `types.ts` to reflect changes in client options and middleware support.
2025-06-20 19:55:42 +08:00
MyPrototypeWhat
ed2363e561 feat: enhance AI Core with plugin system and middleware support
- Introduced a plugin system in the AI Core package, allowing for flexible request handling and middleware integration.
- Added support for various hook types: First, Sequential, Parallel, and Stream, enabling developers to customize request processing.
- Implemented a PluginManager for managing and executing plugins, enhancing extensibility and modularity.
- Updated architecture documentation to reflect new plugin capabilities and usage examples.
- Included new middleware types and examples to demonstrate the plugin system's functionality.

This update aims to improve the developer experience by providing a robust framework for extending AI Core's capabilities.
2025-06-20 19:49:24 +08:00
MyPrototypeWhat
a27d1bf506 feat: introduce Cherry Studio AI Core package with unified AI provider interface
- Added a new package `@cherry-studio/ai-core` that provides a unified interface for various AI providers based on the Vercel AI SDK.
- Implemented core components including `ApiClientFactory`, `UniversalAiSdkClient`, and a provider registry for dynamic imports.
- Included TypeScript support and a lightweight design for improved developer experience.
- Documented architecture and usage examples in `AI_SDK_ARCHITECTURE.md` and `README.md`.
- Updated `package.json` to include dependencies for supported AI providers.

This package aims to streamline the integration of multiple AI providers while ensuring type safety and modularity.
2025-06-20 19:48:56 +08:00
201 changed files with 14877 additions and 1322 deletions

View File

@@ -7,3 +7,4 @@ tsconfig.*.json
CHANGELOG*.md
agents.json
src/renderer/src/integration/nutstore/sso/lib
AGENT.md

View File

@@ -116,11 +116,26 @@ afterSign: scripts/notarize.js
artifactBuildCompleted: scripts/artifact-build-completed.js
releaseInfo:
releaseNotes: |
输入框快捷菜单增加清除按钮
侧边栏增加代码工具入口,代码工具增加环境变量设置
小程序增加多语言显示
优化 MCP 服务器列表
新增 Web 搜索图标
优化 SVG 预览,优化 HTML 内容样式
修复知识库文档预处理失败问题
稳定性改进和错误修复
🎉 新增功能:
- 新增错误详情模态框,提供完整的错误信息展示和复制功能
- 新增错误详情的多语言支持(英语、日语、俄语、中文简繁体)
🔧 优化改进:
- 升级 AI Core 到 v1.0.0-alpha.11,重构模型解析逻辑
- 增强温度和 TopP 参数处理,特别针对 Claude 推理努力模型优化
- 改进提供商配置管理,简化 OpenAI 模式处理和服务层级设置
- 优化 MCP 工具可见性,增强提示工具支持
- 重构错误序列化机制,提升类型安全性
- 优化补全方法,支持开发者模式下的追踪功能
- 改进提供商初始化逻辑,支持动态注册新的 AI 提供商
🐛 问题修复:
- 修复错误处理回调中的类型安全问题,使用 AISDKError 类型
- 修复提供商初始化和配置相关问题
- 移除过时的模型解析函数,清理废弃代码
- 修复 Gemini 集成中的提供商配置缺失问题
⚡ 性能提升:
- 提升模型参数处理效率,优化温度和 TopP 计算逻辑
- 优化提供商配置加载和初始化性能
- 改进错误处理性能,减少不必要的错误格式化开销

View File

@@ -81,7 +81,10 @@ export default defineConfig({
'@shared': resolve('packages/shared'),
'@logger': resolve('src/renderer/src/services/LoggerService'),
'@mcp-trace/trace-core': resolve('packages/mcp-trace/trace-core'),
'@mcp-trace/trace-web': resolve('packages/mcp-trace/trace-web')
'@mcp-trace/trace-web': resolve('packages/mcp-trace/trace-web'),
'@cherrystudio/ai-core/provider': resolve('packages/aiCore/src/core/providers'),
'@cherrystudio/ai-core/built-in/plugins': resolve('packages/aiCore/src/core/plugins/built-in'),
'@cherrystudio/ai-core': resolve('packages/aiCore/src')
}
},
optimizeDeps: {

View File

@@ -1,6 +1,6 @@
{
"name": "CherryStudio",
"version": "1.5.7-rc.2",
"version": "1.6.0-beta.2",
"private": true,
"description": "A powerful AI assistant for producer.",
"main": "./out/main/index.js",
@@ -87,12 +87,16 @@
"@agentic/exa": "^7.3.3",
"@agentic/searxng": "^7.3.3",
"@agentic/tavily": "^7.3.3",
"@ai-sdk/amazon-bedrock": "^3.0.0",
"@ai-sdk/google-vertex": "^3.0.0",
"@ai-sdk/mistral": "^2.0.0",
"@ant-design/v5-patch-for-react-19": "^1.0.3",
"@anthropic-ai/sdk": "^0.41.0",
"@anthropic-ai/vertex-sdk": "patch:@anthropic-ai/vertex-sdk@npm%3A0.11.4#~/.yarn/patches/@anthropic-ai-vertex-sdk-npm-0.11.4-c19cb41edb.patch",
"@aws-sdk/client-bedrock": "^3.840.0",
"@aws-sdk/client-bedrock-runtime": "^3.840.0",
"@aws-sdk/client-s3": "^3.840.0",
"@cherrystudio/ai-core": "workspace:*",
"@cherrystudio/embedjs": "^0.1.31",
"@cherrystudio/embedjs-libsql": "^0.1.31",
"@cherrystudio/embedjs-loader-csv": "^0.1.31",
@@ -127,6 +131,7 @@
"@modelcontextprotocol/sdk": "^1.17.0",
"@mozilla/readability": "^0.6.0",
"@notionhq/client": "^2.2.15",
"@openrouter/ai-sdk-provider": "^1.1.2",
"@opentelemetry/api": "^1.9.0",
"@opentelemetry/core": "2.0.0",
"@opentelemetry/exporter-trace-otlp-http": "^0.200.0",
@@ -136,7 +141,7 @@
"@playwright/test": "^1.52.0",
"@reduxjs/toolkit": "^2.2.5",
"@shikijs/markdown-it": "^3.9.1",
"@swc/plugin-styled-components": "^7.1.5",
"@swc/plugin-styled-components": "^8.0.4",
"@tanstack/react-query": "^5.27.0",
"@tanstack/react-virtual": "^3.13.12",
"@testing-library/dom": "^10.4.0",
@@ -169,6 +174,7 @@
"@viz-js/lang-dot": "^1.0.5",
"@viz-js/viz": "^3.14.0",
"@xyflow/react": "^12.4.4",
"ai": "^5.0.26",
"antd": "patch:antd@npm%3A5.27.0#~/.yarn/patches/antd-npm-5.27.0-aa91c36546.patch",
"archiver": "^7.0.1",
"async-mutex": "^0.5.0",
@@ -307,7 +313,7 @@
"prettier --write",
"eslint --fix"
],
"*.{json,md,yml,yaml,css,scss,html}": [
"*.{json,yml,yaml,css,scss,html}": [
"prettier --write"
]
}

View File

@@ -0,0 +1,514 @@
# AI Core 基于 Vercel AI SDK 的技术架构
## 1. 架构设计理念
### 1.1 设计目标
- **简化分层**:`models`(模型层)→ `runtime`(运行时层),清晰的职责分离
- **统一接口**:使用 Vercel AI SDK 统一不同 AI Provider 的接口差异
- **动态导入**:通过动态导入实现按需加载,减少打包体积
- **最小包装**:直接使用 AI SDK 的类型和接口,避免重复定义
- **插件系统**:基于钩子的通用插件架构,支持请求全生命周期扩展
- **类型安全**:利用 TypeScript 和 AI SDK 的类型系统确保类型安全
- **轻量级**:专注核心功能,保持包的轻量和高效
- **包级独立**:作为独立包管理,便于复用和维护
- **Agent就绪**:为将来集成 OpenAI Agents SDK 预留扩展空间
### 1.2 核心优势
- **标准化**:AI SDK 提供统一的模型接口,减少适配工作
- **简化设计**:函数式 API,避免过度抽象
- **更好的开发体验**:完整的 TypeScript 支持和丰富的生态系统
- **性能优化**:AI SDK 内置优化和最佳实践
- **模块化设计**:独立包结构,支持跨项目复用
- **可扩展插件**:通用的流转换和参数处理插件系统
- **面向未来**:为 OpenAI Agents SDK 集成做好准备
## 2. 整体架构图
```mermaid
graph TD
subgraph "用户应用 (如 Cherry Studio)"
UI["用户界面"]
Components["应用组件"]
end
subgraph "packages/aiCore (AI Core 包)"
subgraph "Runtime Layer (运行时层)"
RuntimeExecutor["RuntimeExecutor (运行时执行器)"]
PluginEngine["PluginEngine (插件引擎)"]
RuntimeAPI["Runtime API (便捷函数)"]
end
subgraph "Models Layer (模型层)"
ModelFactory["createModel() (模型工厂)"]
ProviderCreator["ProviderCreator (提供商创建器)"]
end
subgraph "Core Systems (核心系统)"
subgraph "Plugins (插件)"
PluginManager["PluginManager (插件管理)"]
BuiltInPlugins["Built-in Plugins (内置插件)"]
StreamTransforms["Stream Transforms (流转换)"]
end
subgraph "Middleware (中间件)"
MiddlewareWrapper["wrapModelWithMiddlewares() (中间件包装)"]
end
subgraph "Providers (提供商)"
Registry["Provider Registry (注册表)"]
Factory["Provider Factory (工厂)"]
end
end
end
subgraph "Vercel AI SDK"
AICore["ai (核心库)"]
OpenAI["@ai-sdk/openai"]
Anthropic["@ai-sdk/anthropic"]
Google["@ai-sdk/google"]
XAI["@ai-sdk/xai"]
Others["其他 19+ Providers"]
end
subgraph "Future: OpenAI Agents SDK"
AgentSDK["@openai/agents (未来集成)"]
AgentExtensions["Agent Extensions (预留)"]
end
UI --> RuntimeAPI
Components --> RuntimeExecutor
RuntimeAPI --> RuntimeExecutor
RuntimeExecutor --> PluginEngine
RuntimeExecutor --> ModelFactory
PluginEngine --> PluginManager
ModelFactory --> ProviderCreator
ModelFactory --> MiddlewareWrapper
ProviderCreator --> Registry
Registry --> Factory
Factory --> OpenAI
Factory --> Anthropic
Factory --> Google
Factory --> XAI
Factory --> Others
RuntimeExecutor --> AICore
AICore --> streamText
AICore --> generateText
AICore --> streamObject
AICore --> generateObject
PluginManager --> StreamTransforms
PluginManager --> BuiltInPlugins
%% 未来集成路径
RuntimeExecutor -.-> AgentSDK
AgentSDK -.-> AgentExtensions
```
## 3. 包结构设计
### 3.1 新架构文件结构
```
packages/aiCore/
├── src/
│ ├── core/ # 核心层 - 内部实现
│ │ ├── models/ # 模型层 - 模型创建和配置
│ │ │ ├── factory.ts # 模型工厂函数 ✅
│ │ │ ├── ModelCreator.ts # 模型创建器 ✅
│ │ │ ├── ConfigManager.ts # 配置管理器 ✅
│ │ │ ├── types.ts # 模型类型定义 ✅
│ │ │ └── index.ts # 模型层导出 ✅
│ │ ├── runtime/ # 运行时层 - 执行和用户API
│ │ │ ├── executor.ts # 运行时执行器 ✅
│ │ │ ├── pluginEngine.ts # 插件引擎 ✅
│ │ │ ├── types.ts # 运行时类型定义 ✅
│ │ │ └── index.ts # 运行时导出 ✅
│ │ ├── middleware/ # 中间件系统
│ │ │ ├── wrapper.ts # 模型包装器 ✅
│ │ │ ├── manager.ts # 中间件管理器 ✅
│ │ │ ├── types.ts # 中间件类型 ✅
│ │ │ └── index.ts # 中间件导出 ✅
│ │ ├── plugins/ # 插件系统
│ │ │ ├── types.ts # 插件类型定义 ✅
│ │ │ ├── manager.ts # 插件管理器 ✅
│ │ │ ├── built-in/ # 内置插件 ✅
│ │ │ │ ├── logging.ts # 日志插件 ✅
│ │ │ │ ├── webSearchPlugin/ # 网络搜索插件 ✅
│ │ │ │ ├── toolUsePlugin/ # 工具使用插件 ✅
│ │ │ │ └── index.ts # 内置插件导出 ✅
│ │ │ ├── README.md # 插件文档 ✅
│ │ │ └── index.ts # 插件导出 ✅
│ │ ├── providers/ # 提供商管理
│ │ │ ├── registry.ts # 提供商注册表 ✅
│ │ │ ├── factory.ts # 提供商工厂 ✅
│ │ │ ├── creator.ts # 提供商创建器 ✅
│ │ │ ├── types.ts # 提供商类型 ✅
│ │ │ ├── utils.ts # 工具函数 ✅
│ │ │ └── index.ts # 提供商导出 ✅
│ │ ├── options/ # 配置选项
│ │ │ ├── factory.ts # 选项工厂 ✅
│ │ │ ├── types.ts # 选项类型 ✅
│ │ │ ├── xai.ts # xAI 选项 ✅
│ │ │ ├── openrouter.ts # OpenRouter 选项 ✅
│ │ │ ├── examples.ts # 示例配置 ✅
│ │ │ └── index.ts # 选项导出 ✅
│ │ └── index.ts # 核心层导出 ✅
│ ├── types.ts # 全局类型定义 ✅
│ └── index.ts # 包主入口文件 ✅
├── package.json # 包配置文件 ✅
├── tsconfig.json # TypeScript 配置 ✅
├── README.md # 包说明文档 ✅
└── AI_SDK_ARCHITECTURE.md # 本文档 ✅
```
## 4. 架构分层详解
### 4.1 Models Layer (模型层)
**职责**:统一的模型创建和配置管理
**核心文件**
- `factory.ts`: 模型工厂函数 (`createModel`, `createModels`)
- `ProviderCreator.ts`: 底层提供商创建和模型实例化
- `types.ts`: 模型配置类型定义
**设计特点**
- 函数式设计,避免不必要的类抽象
- 统一的模型配置接口
- 自动处理中间件应用
- 支持批量模型创建
**核心API**
```typescript
// 模型配置接口
export interface ModelConfig {
providerId: ProviderId
modelId: string
options: ProviderSettingsMap[ProviderId]
middlewares?: LanguageModelV1Middleware[]
}
// 核心模型创建函数
export async function createModel(config: ModelConfig): Promise<LanguageModel>
export async function createModels(configs: ModelConfig[]): Promise<LanguageModel[]>
```
### 4.2 Runtime Layer (运行时层)
**职责**:运行时执行器和面向用户的 API 接口
**核心组件**
- `executor.ts`: 运行时执行器类
- `pluginEngine.ts`: 插件引擎(原 PluginEnabledAiClient)
- `index.ts`: 便捷函数和工厂方法
**设计特点**
- 提供三种使用方式:类实例、静态工厂、函数式调用
- 自动集成模型创建和插件处理
- 完整的类型安全支持
- 为 OpenAI Agents SDK 预留扩展接口
**核心API**
```typescript
// 运行时执行器
export class RuntimeExecutor<T extends ProviderId = ProviderId> {
static create<T extends ProviderId>(
providerId: T,
options: ProviderSettingsMap[T],
plugins?: AiPlugin[]
): RuntimeExecutor<T>
async streamText(modelId: string, params: StreamTextParams): Promise<StreamTextResult>
async generateText(modelId: string, params: GenerateTextParams): Promise<GenerateTextResult>
async streamObject(modelId: string, params: StreamObjectParams): Promise<StreamObjectResult>
async generateObject(modelId: string, params: GenerateObjectParams): Promise<GenerateObjectResult>
}
// 便捷函数式API
export async function streamText<T extends ProviderId>(
providerId: T,
options: ProviderSettingsMap[T],
modelId: string,
params: StreamTextParams,
plugins?: AiPlugin[]
): Promise<StreamTextResult>
```
### 4.3 Plugin System (插件系统)
**职责**:可扩展的插件架构
**核心组件**
- `PluginManager`: 插件生命周期管理
- `built-in/`: 内置插件集合
- 流转换收集和应用
**设计特点**
- 借鉴 Rollup 的钩子分类设计
- 支持流转换 (`experimental_transform`)
- 内置常用插件(日志、计数等)
- 完整的生命周期钩子
**插件接口**
```typescript
export interface AiPlugin {
name: string
enforce?: 'pre' | 'post'
// 【First】首个钩子 - 只执行第一个返回值的插件
resolveModel?: (modelId: string, context: AiRequestContext) => string | null | Promise<string | null>
loadTemplate?: (templateName: string, context: AiRequestContext) => any | null | Promise<any | null>
// 【Sequential】串行钩子 - 链式执行,支持数据转换
transformParams?: (params: any, context: AiRequestContext) => any | Promise<any>
transformResult?: (result: any, context: AiRequestContext) => any | Promise<any>
// 【Parallel】并行钩子 - 不依赖顺序,用于副作用
onRequestStart?: (context: AiRequestContext) => void | Promise<void>
onRequestEnd?: (context: AiRequestContext, result: any) => void | Promise<void>
onError?: (error: Error, context: AiRequestContext) => void | Promise<void>
// 【Stream】流处理
transformStream?: () => TransformStream
}
```
### 4.4 Middleware System (中间件系统)
**职责**:AI SDK 原生中间件支持
**核心组件**
- `wrapper.ts`: 模型包装函数
**设计哲学**
- 直接使用AI SDK的 `wrapLanguageModel`
- 与插件系统分离,职责明确
- 函数式设计,简化使用
```typescript
export function wrapModelWithMiddlewares(model: LanguageModel, middlewares: LanguageModelV1Middleware[]): LanguageModel
```
### 4.5 Provider System (提供商系统)
**职责**:AI Provider 注册表和动态导入
**核心组件**
- `registry.ts`: 19+ Provider配置和类型
- `factory.ts`: Provider配置工厂
**支持的Providers**
- OpenAI, Anthropic, Google, XAI
- Azure OpenAI, Amazon Bedrock, Google Vertex
- Groq, Together.ai, Fireworks, DeepSeek
- 等19+ AI SDK官方支持的providers
## 5. 使用方式
### 5.1 函数式调用 (推荐 - 简单场景)
```typescript
import { streamText, generateText } from '@cherrystudio/ai-core/runtime'
// 直接函数调用
const stream = await streamText(
'anthropic',
{ apiKey: 'your-api-key' },
'claude-3',
{ messages: [{ role: 'user', content: 'Hello!' }] },
[loggingPlugin]
)
```
### 5.2 执行器实例 (推荐 - 复杂场景)
```typescript
import { createExecutor } from '@cherrystudio/ai-core/runtime'
// 创建可复用的执行器
const executor = createExecutor('openai', { apiKey: 'your-api-key' }, [plugin1, plugin2])
// 多次使用
const stream = await executor.streamText('gpt-4', {
messages: [{ role: 'user', content: 'Hello!' }]
})
const result = await executor.generateText('gpt-4', {
messages: [{ role: 'user', content: 'How are you?' }]
})
```
### 5.3 静态工厂方法
```typescript
import { RuntimeExecutor } from '@cherrystudio/ai-core/runtime'
// 静态创建
const executor = RuntimeExecutor.create('anthropic', { apiKey: 'your-api-key' })
await executor.streamText('claude-3', { messages: [...] })
```
### 5.4 直接模型创建 (高级用法)
```typescript
import { createModel } from '@cherrystudio/ai-core/models'
import { streamText } from 'ai'
// 直接创建模型使用
const model = await createModel({
providerId: 'openai',
modelId: 'gpt-4',
options: { apiKey: 'your-api-key' },
middlewares: [middleware1, middleware2]
})
// 直接使用 AI SDK
const result = await streamText({ model, messages: [...] })
```
## 6. 为 OpenAI Agents SDK 预留的设计
### 6.1 架构兼容性
当前架构完全兼容 OpenAI Agents SDK 的集成需求:
```typescript
// 当前的模型创建
const model = await createModel({
providerId: 'anthropic',
modelId: 'claude-3',
options: { apiKey: 'xxx' }
})
// 将来可以直接用于 OpenAI Agents SDK
import { Agent, run } from '@openai/agents'
const agent = new Agent({
model, // ✅ 直接兼容 LanguageModel 接口
name: 'Assistant',
instructions: '...',
tools: [tool1, tool2]
})
const result = await run(agent, 'user input')
```
### 6.2 预留的扩展点
1. **runtime/agents/** 目录预留
2. **AgentExecutor** 类预留
3. **Agent工具转换插件** 预留
4. **多Agent编排** 预留
### 6.3 未来架构扩展
```
packages/aiCore/src/core/
├── runtime/
│ ├── agents/ # 🚀 未来添加
│ │ ├── AgentExecutor.ts
│ │ ├── WorkflowManager.ts
│ │ └── ConversationManager.ts
│ ├── executor.ts
│ └── index.ts
```
## 7. 架构优势
### 7.1 简化设计
- **移除过度抽象**:删除了 orchestration 层和 creation 层的复杂包装
- **函数式优先**:models 层使用函数而非类
- **直接明了**:runtime 层直接提供用户 API
### 7.2 职责清晰
- **Models**: 专注模型创建和配置
- **Runtime**: 专注执行和用户API
- **Plugins**: 专注扩展功能
- **Providers**: 专注AI Provider管理
### 7.3 类型安全
- 完整的 TypeScript 支持
- AI SDK 类型的直接复用
- 避免类型重复定义
### 7.4 灵活使用
- 三种使用模式满足不同需求
- 从简单函数调用到复杂执行器
- 支持直接AI SDK使用
### 7.5 面向未来
- 为 OpenAI Agents SDK 集成做好准备
- 清晰的扩展点和架构边界
- 模块化设计便于功能添加
## 8. 技术决策记录
### 8.1 为什么选择简化的两层架构?
- **职责分离**:models 专注创建,runtime 专注执行
- **模块化**:每层都有清晰的边界和职责
- **扩展性**:为 Agent 功能预留了清晰的扩展空间
### 8.2 为什么选择函数式设计?
- **简洁性**:避免不必要的类设计
- **性能**:减少对象创建开销
- **易用性**:函数调用更直观
### 8.3 为什么分离插件和中间件?
- **职责明确**: 插件处理应用特定需求
- **原生支持**: 中间件使用AI SDK原生功能
- **灵活性**: 两套系统可以独立演进
## 9. 总结
AI Core架构实现了
### 9.1 核心特点
- **简化架构**: 2 层核心架构,职责清晰
- **函数式设计**: models 层完全函数化
- **类型安全**: 统一的类型定义和 AI SDK 类型复用
- **插件扩展**: 强大的插件系统
- **多种使用方式**: 满足不同复杂度需求
- **Agent 就绪**: 为 OpenAI Agents SDK 集成做好准备
### 9.2 核心价值
- **统一接口**: 一套API支持19+ AI providers
- **灵活使用**: 函数式、实例式、静态工厂式
- **强类型**: 完整的TypeScript支持
- **可扩展**: 插件和中间件双重扩展能力
- **高性能**: 最小化包装直接使用AI SDK
- **面向未来**: Agent SDK集成架构就绪
### 9.3 未来发展
这个架构提供了:
- **优秀的开发体验**: 简洁的API和清晰的使用模式
- **强大的扩展能力**: 为Agent功能预留了完整的架构空间
- **良好的维护性**: 职责分离明确,代码易于维护
- **广泛的适用性**: 既适合简单调用也适合复杂应用

433
packages/aiCore/README.md Normal file
View File

@@ -0,0 +1,433 @@
# @cherrystudio/ai-core
Cherry Studio AI Core 是一个基于 Vercel AI SDK 的统一 AI Provider 接口包,为 AI 应用提供强大的抽象层和插件化架构。
## ✨ 核心亮点
### 🏗️ 优雅的架构设计
- **简化分层**:`models`(模型层)→ `runtime`(运行时层),清晰的职责分离
- **函数式优先**:避免过度抽象,提供简洁直观的 API
- **类型安全**:完整的 TypeScript 支持,直接复用 AI SDK 类型系统
- **最小包装**:直接使用 AI SDK 的接口,避免重复定义和性能损耗
### 🔌 强大的插件系统
- **生命周期钩子**:支持请求全生命周期的扩展点
- **流转换支持**:基于 AI SDK 的 `experimental_transform` 实现流处理
- **插件分类**:First、Sequential、Parallel 三种钩子类型,满足不同场景
- **内置插件**:webSearch、logging、toolUse 等开箱即用的功能
### 🌐 统一多 Provider 接口
- **扩展注册**:支持自定义 Provider 注册,无限扩展能力
- **配置统一**:统一的配置接口,简化多 Provider 管理
### 🚀 多种使用方式
- **函数式调用**:适合简单场景的直接函数调用
- **执行器实例**:适合复杂场景的可复用执行器
- **静态工厂**:便捷的静态创建方法
- **原生兼容**:完全兼容 AI SDK 原生 Provider Registry
### 🔮 面向未来
- **Agent 就绪**:为 OpenAI Agents SDK 集成预留架构空间
- **模块化设计**:独立包结构,支持跨项目复用
- **渐进式迁移**:可以逐步从现有 AI SDK 代码迁移
## 特性
- 🚀 统一的 AI Provider 接口
- 🔄 动态导入支持
- 🛠️ TypeScript 支持
- 📦 强大的插件系统
- 🌍 内置 webSearch(OpenAI、Google、Anthropic、xAI)
- 🎯 多种使用模式(函数式/实例式/静态工厂)
- 🔌 可扩展的 Provider 注册系统
- 🧩 完整的中间件支持
- 📊 插件统计和调试功能
## 支持的 Providers
基于 [AI SDK 官方支持的 providers](https://ai-sdk.dev/providers/ai-sdk-providers)
**核心 Providers(内置支持):**
- OpenAI
- Anthropic
- Google Generative AI
- OpenAI-Compatible
- xAI (Grok)
- Azure OpenAI
- DeepSeek
**扩展 Providers(通过注册 API 支持):**
- Google Vertex AI
- ...
- 自定义 Provider
## 安装
```bash
npm install @cherrystudio/ai-core ai
```
### React Native
如果你在 React Native 项目中使用此包,需要在 `metro.config.js` 中添加以下配置:
```javascript
// metro.config.js
const { getDefaultConfig } = require('expo/metro-config')
const config = getDefaultConfig(__dirname)
// 添加对 @cherrystudio/ai-core 的支持
config.resolver.resolverMainFields = ['react-native', 'browser', 'main']
config.resolver.platforms = ['ios', 'android', 'native', 'web']
module.exports = config
```
还需要安装你要使用的 AI SDK provider:
```bash
npm install @ai-sdk/openai @ai-sdk/anthropic @ai-sdk/google
```
## 使用示例
### 基础用法
```typescript
import { AiCore } from '@cherrystudio/ai-core'
// 创建 OpenAI executor
const executor = AiCore.create('openai', {
apiKey: 'your-api-key'
})
// 流式生成
const result = await executor.streamText('gpt-4', {
messages: [{ role: 'user', content: 'Hello!' }]
})
// 非流式生成
const response = await executor.generateText('gpt-4', {
messages: [{ role: 'user', content: 'Hello!' }]
})
```
### 便捷函数
```typescript
import { createOpenAIExecutor } from '@cherrystudio/ai-core'
// 快速创建 OpenAI executor
const executor = createOpenAIExecutor({
apiKey: 'your-api-key'
})
// 使用 executor
const result = await executor.streamText('gpt-4', {
messages: [{ role: 'user', content: 'Hello!' }]
})
```
### 多 Provider 支持
```typescript
import { AiCore } from '@cherrystudio/ai-core'
// 支持多种 AI providers
const openaiExecutor = AiCore.create('openai', { apiKey: 'openai-key' })
const anthropicExecutor = AiCore.create('anthropic', { apiKey: 'anthropic-key' })
const googleExecutor = AiCore.create('google', { apiKey: 'google-key' })
const xaiExecutor = AiCore.create('xai', { apiKey: 'xai-key' })
```
### 扩展 Provider 注册
对于非内置的 providers,可以通过注册 API 扩展支持:
```typescript
import { registerProvider, AiCore } from '@cherrystudio/ai-core'
// 方式一:导入并注册第三方 provider
import { createGroq } from '@ai-sdk/groq'
registerProvider({
id: 'groq',
name: 'Groq',
creator: createGroq,
supportsImageGeneration: false
})
// 现在可以使用 Groq
const groqExecutor = AiCore.create('groq', { apiKey: 'groq-key' })
// 方式二:动态导入方式注册
registerProvider({
id: 'mistral',
name: 'Mistral AI',
import: () => import('@ai-sdk/mistral'),
creatorFunctionName: 'createMistral'
})
const mistralExecutor = AiCore.create('mistral', { apiKey: 'mistral-key' })
```
## 🔌 插件系统
AI Core 提供了强大的插件系统,支持请求全生命周期的扩展。
### 内置插件
#### webSearchPlugin - 网络搜索插件
为不同 AI Provider 提供统一的网络搜索能力:
```typescript
import { webSearchPlugin } from '@cherrystudio/ai-core/built-in/plugins'
const executor = AiCore.create('openai', { apiKey: 'your-key' }, [
webSearchPlugin({
openai: {
/* OpenAI 搜索配置 */
},
anthropic: { maxUses: 5 },
google: {
/* Google 搜索配置 */
},
xai: {
mode: 'on',
returnCitations: true,
maxSearchResults: 5,
sources: [{ type: 'web' }, { type: 'x' }, { type: 'news' }]
}
})
])
```
#### loggingPlugin - 日志插件
提供详细的请求日志记录:
```typescript
import { createLoggingPlugin } from '@cherrystudio/ai-core/built-in/plugins'
const executor = AiCore.create('openai', { apiKey: 'your-key' }, [
createLoggingPlugin({
logLevel: 'info',
includeParams: true,
includeResult: false
})
])
```
#### promptToolUsePlugin - 提示工具使用插件
为不支持原生 Function Call 的模型提供 prompt 方式的工具调用:
```typescript
import { createPromptToolUsePlugin } from '@cherrystudio/ai-core/built-in/plugins'
// 对于不支持 function call 的模型
const executor = AiCore.create(
'providerId',
{
apiKey: 'your-key',
baseURL: 'https://your-model-endpoint'
},
[
createPromptToolUsePlugin({
enabled: true,
// 可选:自定义系统提示符构建
buildSystemPrompt: (userPrompt, tools) => {
return `${userPrompt}\n\nAvailable tools: ${Object.keys(tools).join(', ')}`
}
})
]
)
```
### 自定义插件
创建自定义插件非常简单:
```typescript
import { definePlugin } from '@cherrystudio/ai-core'
const customPlugin = definePlugin({
name: 'custom-plugin',
enforce: 'pre', // 'pre' | 'post' | undefined
// 在请求开始时记录日志
onRequestStart: async (context) => {
console.log(`Starting request for model: ${context.modelId}`)
},
// 转换请求参数
transformParams: async (params, context) => {
// 添加自定义系统消息
if (params.messages) {
params.messages.unshift({
role: 'system',
content: 'You are a helpful assistant.'
})
}
return params
},
// 处理响应结果
transformResult: async (result, context) => {
// 添加元数据
if (result.text) {
result.metadata = {
processedAt: new Date().toISOString(),
modelId: context.modelId
}
}
return result
}
})
// 使用自定义插件
const executor = AiCore.create('openai', { apiKey: 'your-key' }, [customPlugin])
```
### 使用 AI SDK 原生 Provider 注册表
> https://ai-sdk.dev/docs/reference/ai-sdk-core/provider-registry
除了使用内建的 provider 管理,你还可以使用 AI SDK 原生的 `createProviderRegistry` 来构建自己的 provider 注册表。
#### 基本用法示例
```typescript
import { createClient } from '@cherrystudio/ai-core'
import { createProviderRegistry } from 'ai'
import { createOpenAI } from '@ai-sdk/openai'
import { anthropic } from '@ai-sdk/anthropic'
// 1. 创建 AI SDK 原生注册表
export const registry = createProviderRegistry({
// register provider with prefix and default setup:
anthropic,
// register provider with prefix and custom setup:
openai: createOpenAI({
apiKey: process.env.OPENAI_API_KEY
})
})
// 2. 创建 client:'openai' 可以传空或者传 providerId(内建的 provider)
const client = createClient('openai', {
apiKey: process.env.OPENAI_API_KEY
})
// 3. 方式1使用内建逻辑传统方式
const result1 = await client.streamText('gpt-4', {
messages: [{ role: 'user', content: 'Hello with built-in logic!' }]
})
// 4. 方式2使用自定义注册表灵活方式
const result2 = await client.streamText({
model: registry.languageModel('openai:gpt-4'),
messages: [{ role: 'user', content: 'Hello with custom registry!' }]
})
// 5. 支持的重载方法
await client.generateObject({
model: registry.languageModel('openai:gpt-4'),
schema: z.object({ name: z.string() }),
messages: [{ role: 'user', content: 'Generate a user' }]
})
await client.streamObject({
model: registry.languageModel('anthropic:claude-3-opus-20240229'),
schema: z.object({ items: z.array(z.string()) }),
messages: [{ role: 'user', content: 'Generate a list' }]
})
```
#### 与插件系统配合使用
更强大的是,你还可以将自定义注册表与 Cherry Studio 的插件系统结合使用:
```typescript
import { PluginEnabledAiClient } from '@cherrystudio/ai-core'
import { createProviderRegistry } from 'ai'
import { createOpenAI } from '@ai-sdk/openai'
import { anthropic } from '@ai-sdk/anthropic'
// 1. 创建带插件的客户端
const client = PluginEnabledAiClient.create(
'openai',
{
apiKey: process.env.OPENAI_API_KEY
},
[LoggingPlugin, RetryPlugin]
)
// 2. 创建自定义注册表
const registry = createProviderRegistry({
openai: createOpenAI({ apiKey: process.env.OPENAI_API_KEY }),
anthropic: anthropic({ apiKey: process.env.ANTHROPIC_API_KEY })
})
// 3. 方式1使用内建逻辑 + 完整插件系统
await client.streamText('gpt-4', {
messages: [{ role: 'user', content: 'Hello with plugins!' }]
})
// 4. 方式2使用自定义注册表 + 有限插件支持
await client.streamText({
model: registry.languageModel('anthropic:claude-3-opus-20240229'),
messages: [{ role: 'user', content: 'Hello from Claude!' }]
})
// 5. 支持的方法
await client.generateObject({
model: registry.languageModel('openai:gpt-4'),
schema: z.object({ name: z.string() }),
messages: [{ role: 'user', content: 'Generate a user' }]
})
await client.streamObject({
model: registry.languageModel('openai:gpt-4'),
schema: z.object({ items: z.array(z.string()) }),
messages: [{ role: 'user', content: 'Generate a list' }]
})
```
#### 混合使用的优势
- **灵活性**:可以根据需要选择使用内建逻辑或自定义注册表
- **兼容性**:完全兼容 AI SDK 的 `createProviderRegistry` API
- **渐进式**:可以逐步迁移现有代码,无需一次性重构
- **插件支持**:自定义注册表仍可享受插件系统的部分功能
- **最佳实践**:结合两种方式的优点,既有动态加载的性能优势,又有统一注册表的便利性
## 📚 相关资源
- [Vercel AI SDK 文档](https://ai-sdk.dev/)
- [Cherry Studio 项目](https://github.com/CherryHQ/cherry-studio)
- [AI SDK Providers](https://ai-sdk.dev/providers/ai-sdk-providers)
## 未来版本
- 🔮 多 Agent 编排
- 🔮 可视化插件配置
- 🔮 实时监控和分析
- 🔮 云端插件同步
## 📄 License
MIT License - 详见 [LICENSE](https://github.com/CherryHQ/cherry-studio/blob/main/LICENSE) 文件
---
**Cherry Studio AI Core** - 让 AI 开发更简单、更强大、更灵活 🚀

View File

@@ -0,0 +1,103 @@
/**
* Hub Provider 使用示例
*
* 演示如何使用简化后的Hub Provider功能来路由到多个底层provider
*/
import { createHubProvider, initializeProvider, providerRegistry } from '../src/index'
/**
 * Demonstrates creating and using a Hub Provider that routes model ids to
 * multiple underlying providers (OpenAI / Anthropic) via the shared registry.
 */
async function demonstrateHubProvider() {
  try {
    // 1. Initialize the underlying providers.
    console.log('📦 初始化底层providers...')
    initializeProvider('openai', {
      apiKey: process.env.OPENAI_API_KEY || 'sk-test-key'
    })
    initializeProvider('anthropic', {
      apiKey: process.env.ANTHROPIC_API_KEY || 'sk-ant-test-key'
    })
    // 2. Create the Hub Provider (automatically includes every initialized provider).
    console.log('🌐 创建Hub Provider...')
    const aihubmixProvider = createHubProvider({
      hubId: 'aihubmix',
      debug: true
    })
    // 3. Register the Hub Provider.
    providerRegistry.registerProvider('aihubmix', aihubmixProvider)
    console.log('✅ Hub Provider "aihubmix" 注册成功')
    // 4. Resolve models through the hub.
    console.log('\n🚀 使用Hub模型...')
    // Route to OpenAI through the hub.
    const openaiModel = providerRegistry.languageModel('aihubmix:openai:gpt-4')
    console.log('✓ OpenAI模型已获取:', openaiModel.modelId)
    // Route to Anthropic through the hub.
    const anthropicModel = providerRegistry.languageModel('aihubmix:anthropic:claude-3.5-sonnet')
    console.log('✓ Anthropic模型已获取:', anthropicModel.modelId)
    // 5. Demonstrate error handling.
    console.log('\n❌ 演示错误处理...')
    try {
      // Accessing a provider that was never initialized should throw.
      providerRegistry.languageModel('aihubmix:google:gemini-pro')
    } catch (error) {
      // `catch` variables are `unknown` under strict TS — narrow before touching `.message`.
      console.log('预期错误:', error instanceof Error ? error.message : error)
    }
    try {
      // A malformed model id should throw as well.
      providerRegistry.languageModel('aihubmix:invalid-format')
    } catch (error) {
      console.log('预期错误:', error instanceof Error ? error.message : error)
    }
    // 6. Multiple hub providers can coexist.
    console.log('\n🔄 创建多个Hub Provider...')
    const localHubProvider = createHubProvider({
      hubId: 'local-ai'
    })
    providerRegistry.registerProvider('local-ai', localHubProvider)
    console.log('✅ Hub Provider "local-ai" 注册成功')
    console.log('\n🎉 Hub Provider演示完成')
  } catch (error) {
    console.error('💥 演示过程中发生错误:', error)
  }
}
// Prints a short, copy-pasteable snippet showing the minimal Hub Provider workflow.
function simplifiedUsageExample() {
  console.log('\n📝 简化使用示例:')
  const snippet = `
// 1. 初始化providers
initializeProvider('openai', { apiKey: 'sk-xxx' })
initializeProvider('anthropic', { apiKey: 'sk-ant-xxx' })
// 2. 创建并注册Hub Provider
const hubProvider = createHubProvider({ hubId: 'aihubmix' })
providerRegistry.registerProvider('aihubmix', hubProvider)
// 3. 直接使用
const model1 = providerRegistry.languageModel('aihubmix:openai:gpt-4')
const model2 = providerRegistry.languageModel('aihubmix:anthropic:claude-3.5-sonnet')
`
  console.log(snippet)
}
// Run both demos when this file is executed directly (CommonJS entry-point check).
if (require.main === module) {
  demonstrateHubProvider()
  simplifiedUsageExample()
}
export { demonstrateHubProvider, simplifiedUsageExample }

View File

@@ -0,0 +1,167 @@
/**
* Image Generation Example
* 演示如何使用 aiCore 的文生图功能
*/
import { createExecutor, generateImage } from '../src/index'
/**
 * Walks through four ways of generating images with aiCore:
 * executor instance, direct API call, alternate provider, and plugins.
 */
async function main() {
  // Approach 1: use an executor instance.
  console.log('📸 创建 OpenAI 图像生成执行器...')
  const executor = createExecutor('openai', {
    apiKey: process.env.OPENAI_API_KEY!
  })
  try {
    console.log('🎨 使用执行器生成图像...')
    const result1 = await executor.generateImage('dall-e-3', {
      prompt: 'A futuristic cityscape at sunset with flying cars',
      size: '1024x1024',
      n: 1
    })
    console.log('✅ 图像生成成功!')
    console.log('📊 结果:', {
      imagesCount: result1.images.length,
      mediaType: result1.image.mediaType,
      hasBase64: !!result1.image.base64,
      providerMetadata: result1.providerMetadata
    })
  } catch (error) {
    console.error('❌ 执行器生成失败:', error)
  }
  // Approach 2: call the top-level generateImage API directly.
  try {
    console.log('🎨 使用直接 API 生成图像...')
    const result2 = await generateImage('openai', { apiKey: process.env.OPENAI_API_KEY! }, 'dall-e-3', {
      prompt: 'A magical forest with glowing mushrooms and fairy lights',
      aspectRatio: '16:9',
      providerOptions: {
        openai: {
          quality: 'hd',
          style: 'vivid'
        }
      }
    })
    console.log('✅ 直接 API 生成成功!')
    console.log('📊 结果:', {
      imagesCount: result2.images.length,
      mediaType: result2.image.mediaType,
      hasBase64: !!result2.image.base64
    })
  } catch (error) {
    console.error('❌ 直接 API 生成失败:', error)
  }
  // Approach 3: another provider (Google Imagen), only when a key is configured.
  if (process.env.GOOGLE_API_KEY) {
    try {
      console.log('🎨 使用 Google Imagen 生成图像...')
      const googleExecutor = createExecutor('google', {
        apiKey: process.env.GOOGLE_API_KEY!
      })
      const result3 = await googleExecutor.generateImage('imagen-3.0-generate-002', {
        prompt: 'A serene mountain lake at dawn with mist rising from the water',
        aspectRatio: '1:1'
      })
      console.log('✅ Google Imagen 生成成功!')
      console.log('📊 结果:', {
        imagesCount: result3.images.length,
        mediaType: result3.image.mediaType,
        hasBase64: !!result3.image.base64
      })
    } catch (error) {
      console.error('❌ Google Imagen 生成失败:', error)
    }
  }
  // Approach 4: plugin system support.
  const pluginExample = async () => {
    console.log('🔌 演示插件系统...')
    // Example plugin that rewrites the prompt before the request and tags the result after.
    const promptEnhancerPlugin = {
      name: 'prompt-enhancer',
      transformParams: async (params: any) => {
        console.log('🔧 插件: 增强提示词...')
        return {
          ...params,
          prompt: `${params.prompt}, highly detailed, cinematic lighting, 4K resolution`
        }
      },
      transformResult: async (result: any) => {
        console.log('🔧 插件: 处理结果...')
        return {
          ...result,
          enhanced: true
        }
      }
    }
    const executorWithPlugin = createExecutor(
      'openai',
      {
        apiKey: process.env.OPENAI_API_KEY!
      },
      [promptEnhancerPlugin]
    )
    try {
      const result4 = await executorWithPlugin.generateImage('dall-e-3', {
        prompt: 'A cute robot playing in a garden'
      })
      console.log('✅ 插件系统生成成功!')
      console.log('📊 结果:', {
        imagesCount: result4.images.length,
        enhanced: (result4 as any).enhanced,
        mediaType: result4.image.mediaType
      })
    } catch (error) {
      console.error('❌ 插件系统生成失败:', error)
    }
  }
  await pluginExample()
}
// Demonstrates catching and inspecting a structured image-generation failure.
async function errorHandlingExample() {
  console.log('⚠️ 演示错误处理...')
  try {
    const executor = createExecutor('openai', {
      apiKey: 'invalid-key'
    })
    await executor.generateImage('dall-e-3', {
      prompt: 'Test image'
    })
  } catch (error: unknown) {
    // Narrow from `unknown` instead of using `any`, so property access is type-safe.
    if (error instanceof Error) {
      console.log('✅ 成功捕获错误:', error.constructor.name)
      console.log('📋 错误信息:', error.message)
    }
    // providerId/modelId are attached by aiCore's error type — read them defensively.
    const details = error as { providerId?: string; modelId?: string }
    console.log('🏷️ 提供商ID:', details.providerId)
    console.log('🏷️ 模型ID:', details.modelId)
  }
}
// Run the demos when executed directly: main examples first, then the error demo.
if (require.main === module) {
  main()
    .then(() => {
      console.log('🎉 所有示例完成!')
      return errorHandlingExample()
    })
    .then(() => {
      console.log('🎯 示例程序结束')
      process.exit(0)
    })
    .catch((error) => {
      console.error('💥 程序执行出错:', error)
      process.exit(1)
    })
}

View File

@@ -0,0 +1,85 @@
{
"name": "@cherrystudio/ai-core",
"version": "1.0.0-alpha.11",
"description": "Cherry Studio AI Core - Unified AI Provider Interface Based on Vercel AI SDK",
"main": "dist/index.js",
"module": "dist/index.mjs",
"types": "dist/index.d.ts",
"react-native": "dist/index.js",
"scripts": {
"build": "tsdown",
"dev": "tsc -w",
"clean": "rm -rf dist",
"test": "vitest run",
"test:watch": "vitest"
},
"keywords": [
"ai",
"sdk",
"openai",
"anthropic",
"google",
"cherry-studio",
"vercel-ai-sdk"
],
"author": "Cherry Studio",
"license": "MIT",
"repository": {
"type": "git",
"url": "git+https://github.com/CherryHQ/cherry-studio.git"
},
"bugs": {
"url": "https://github.com/CherryHQ/cherry-studio/issues"
},
"homepage": "https://github.com/CherryHQ/cherry-studio#readme",
"peerDependencies": {
"ai": "^5.0.26"
},
"dependencies": {
"@ai-sdk/anthropic": "^2.0.5",
"@ai-sdk/azure": "^2.0.16",
"@ai-sdk/deepseek": "^1.0.9",
"@ai-sdk/google": "^2.0.7",
"@ai-sdk/openai": "^2.0.19",
"@ai-sdk/openai-compatible": "^1.0.9",
"@ai-sdk/provider": "^2.0.0",
"@ai-sdk/provider-utils": "^3.0.4",
"@ai-sdk/xai": "^2.0.9",
"zod": "^3.25.0"
},
"devDependencies": {
"tsdown": "^0.12.9",
"typescript": "^5.0.0",
"vitest": "^3.2.4"
},
"sideEffects": false,
"engines": {
"node": ">=18.0.0"
},
"files": [
"dist"
],
"exports": {
".": {
"types": "./dist/index.d.ts",
"react-native": "./dist/index.js",
"import": "./dist/index.mjs",
"require": "./dist/index.js",
"default": "./dist/index.js"
},
"./built-in/plugins": {
"types": "./dist/built-in/plugins/index.d.ts",
"react-native": "./dist/built-in/plugins/index.js",
"import": "./dist/built-in/plugins/index.mjs",
"require": "./dist/built-in/plugins/index.js",
"default": "./dist/built-in/plugins/index.js"
},
"./provider": {
"types": "./dist/provider/index.d.ts",
"react-native": "./dist/provider/index.js",
"import": "./dist/provider/index.mjs",
"require": "./dist/provider/index.js",
"default": "./dist/provider/index.js"
}
}
}

View File

@@ -0,0 +1,2 @@
// Stub for the Vite SSR export helper so plain Node environments do not crash
// when it is missing; behaves as an identity passthrough for the value.
;(globalThis as any).__vite_ssr_exportName__ = (name: string, value: any) => value

View File

@@ -0,0 +1,3 @@
# @cherrystudio/ai-core
Core

View File

@@ -0,0 +1,17 @@
/**
 * Core module exports.
 * Internal core functionality consumed by other modules; not part of the
 * end-user-facing API surface.
 */
// Middleware system
export type { NamedMiddleware } from './middleware'
export { createMiddlewares, wrapModelWithMiddlewares } from './middleware'
// Model creation / resolution
export { globalModelResolver, ModelResolver } from './models'
export type { ModelConfig as ModelConfigType } from './models/types'
// Runtime execution
export type { ToolUseRequestContext } from './plugins/built-in/toolUsePlugin/type'
export { createExecutor, createOpenAICompatibleExecutor } from './runtime'
export type { RuntimeConfig } from './runtime/types'

View File

@@ -0,0 +1,8 @@
/**
 * Middleware module exports.
 * Provides generic middleware management capabilities.
 */
export { createMiddlewares } from './manager'
export type { NamedMiddleware } from './types'
export { wrapModelWithMiddlewares } from './wrapper'

View File

@@ -0,0 +1,16 @@
/**
* 中间件管理器
* 专注于 AI SDK 中间件的管理,与插件系统分离
*/
import { LanguageModelV2Middleware } from '@ai-sdk/provider'
/**
 * Build the middleware list used when wrapping a model.
 * Returns a fresh array of the (currently empty) defaults followed by the
 * user-supplied middlewares.
 */
export function createMiddlewares(userMiddlewares: LanguageModelV2Middleware[] = []): LanguageModelV2Middleware[] {
  // Placeholder: default middlewares may be added here in the future.
  const defaults: LanguageModelV2Middleware[] = []
  return defaults.concat(userMiddlewares)
}

View File

@@ -0,0 +1,12 @@
/**
 * Type definitions for the middleware system.
 */
import { LanguageModelV2Middleware } from '@ai-sdk/provider'
/**
 * A middleware paired with an identifying name.
 */
export interface NamedMiddleware {
  // Identifier for this middleware entry
  name: string
  // The AI SDK middleware implementation
  middleware: LanguageModelV2Middleware
}

View File

@@ -0,0 +1,23 @@
/**
* 模型包装工具函数
* 用于将中间件应用到LanguageModel上
*/
import { LanguageModelV2, LanguageModelV2Middleware } from '@ai-sdk/provider'
import { wrapLanguageModel } from 'ai'
/**
 * Apply a list of AI SDK middlewares to a language model.
 * Returns the model untouched when the list is empty.
 */
export function wrapModelWithMiddlewares(
  model: LanguageModelV2,
  middlewares: LanguageModelV2Middleware[]
): LanguageModelV2 {
  return middlewares.length > 0 ? wrapLanguageModel({ model, middleware: middlewares }) : model
}

View File

@@ -0,0 +1,124 @@
/**
* 模型解析器 - models模块的核心
* 负责将modelId解析为AI SDK的LanguageModel实例
* 支持传统格式和命名空间格式
* 集成了来自 ModelCreator 的特殊处理逻辑
*/
import { EmbeddingModelV2, ImageModelV2, LanguageModelV2, LanguageModelV2Middleware } from '@ai-sdk/provider'
import { wrapModelWithMiddlewares } from '../middleware/wrapper'
import { DEFAULT_SEPARATOR, globalRegistryManagement } from '../providers/RegistryManagement'
export class ModelResolver {
  /**
   * Core method: resolve any modelId format into a language model.
   *
   * @param modelId model id; supports 'gpt-4' and namespaced 'anthropic:claude-3' formats
   * @param fallbackProviderId provider id used when modelId is in the traditional format
   * @param providerOptions provider settings (used for OpenAI chat/responses mode selection)
   * @param middlewares middleware array applied to the resolved model
   */
  async resolveLanguageModel(
    modelId: string,
    fallbackProviderId: string,
    providerOptions?: any,
    middlewares?: LanguageModelV2Middleware[]
  ): Promise<LanguageModelV2> {
    let finalProviderId = fallbackProviderId
    let model: LanguageModelV2
    // OpenAI mode selection (migrated from ModelCreator): 'chat' mode maps
    // to the dedicated 'openai-chat' provider id.
    if (fallbackProviderId === 'openai' && providerOptions?.mode === 'chat') {
      finalProviderId = 'openai-chat'
    }
    if (modelId.includes(DEFAULT_SEPARATOR)) {
      // Namespaced format, e.g. 'aihubmix:anthropic:claude-3'.
      model = this.resolveNamespacedModel(modelId)
    } else {
      // Traditional format: qualify with the (possibly remapped) provider id.
      model = this.resolveTraditionalModel(finalProviderId, modelId)
    }
    // Apply middlewares when provided.
    if (middlewares && middlewares.length > 0) {
      model = wrapModelWithMiddlewares(model, middlewares)
    }
    return model
  }
  /**
   * Resolve a text-embedding model (handles both id formats).
   */
  async resolveTextEmbeddingModel(modelId: string, fallbackProviderId: string): Promise<EmbeddingModelV2<string>> {
    if (modelId.includes(DEFAULT_SEPARATOR)) {
      return this.resolveNamespacedEmbeddingModel(modelId)
    }
    return this.resolveTraditionalEmbeddingModel(fallbackProviderId, modelId)
  }
  /**
   * Resolve an image model (handles both id formats).
   */
  async resolveImageModel(modelId: string, fallbackProviderId: string): Promise<ImageModelV2> {
    if (modelId.includes(DEFAULT_SEPARATOR)) {
      return this.resolveNamespacedImageModel(modelId)
    }
    return this.resolveTraditionalImageModel(fallbackProviderId, modelId)
  }
  /**
   * Join providerId and modelId with the registry separator.
   * Extracted so the three traditional resolvers share one implementation.
   */
  private qualify(providerId: string, modelId: string): string {
    return `${providerId}${DEFAULT_SEPARATOR}${modelId}`
  }
  /**
   * Namespaced language model: the id is passed straight to the registry.
   * e.g. 'aihubmix:anthropic:claude-3'
   */
  private resolveNamespacedModel(modelId: string): LanguageModelV2 {
    return globalRegistryManagement.languageModel(modelId as any)
  }
  /**
   * Traditional language model: 'openai' + 'gpt-4' -> 'openai:gpt-4'.
   */
  private resolveTraditionalModel(providerId: string, modelId: string): LanguageModelV2 {
    return globalRegistryManagement.languageModel(this.qualify(providerId, modelId) as any)
  }
  /** Namespaced embedding model. */
  private resolveNamespacedEmbeddingModel(modelId: string): EmbeddingModelV2<string> {
    return globalRegistryManagement.textEmbeddingModel(modelId as any)
  }
  /** Traditional embedding model. */
  private resolveTraditionalEmbeddingModel(providerId: string, modelId: string): EmbeddingModelV2<string> {
    return globalRegistryManagement.textEmbeddingModel(this.qualify(providerId, modelId) as any)
  }
  /** Namespaced image model. */
  private resolveNamespacedImageModel(modelId: string): ImageModelV2 {
    return globalRegistryManagement.imageModel(modelId as any)
  }
  /** Traditional image model. */
  private resolveTraditionalImageModel(providerId: string, modelId: string): ImageModelV2 {
    return globalRegistryManagement.imageModel(this.qualify(providerId, modelId) as any)
  }
}
/**
 * Global model resolver instance.
 */
export const globalModelResolver = new ModelResolver()

View File

@@ -0,0 +1,9 @@
/**
 * Models module exports — simplified version.
 */
// Core model resolver
export { globalModelResolver, ModelResolver } from './ModelResolver'
// Retained type definition (may still be referenced elsewhere)
export type { ModelConfig as ModelConfigType } from './types'

View File

@@ -0,0 +1,15 @@
/**
 * Type definitions for the creation module.
 */
import { LanguageModelV2Middleware } from '@ai-sdk/provider'
import type { ProviderId, ProviderSettingsMap } from '../providers/types'
export interface ModelConfig<T extends ProviderId = ProviderId> {
  // Provider this model belongs to
  providerId: T
  // Model identifier within the provider
  modelId: string
  // Provider settings; `mode` selects the OpenAI chat vs responses API
  providerSettings: ProviderSettingsMap[T] & { mode?: 'chat' | 'responses' }
  // Middlewares applied when the model is created
  middlewares?: LanguageModelV2Middleware[]
  // Extra model parameters
  extraModelConfig?: Record<string, any>
}

View File

@@ -0,0 +1,87 @@
import { streamText } from 'ai'
import {
createAnthropicOptions,
createGenericProviderOptions,
createGoogleOptions,
createOpenAIOptions,
mergeProviderOptions
} from './factory'
// Example 1: strict typing for a known provider (OpenAI).
export function exampleOpenAIWithOptions() {
  const openaiOptions = createOpenAIOptions({
    reasoningEffort: 'medium'
  })
  // Type-checked: the options must conform to OpenAI's settings type.
  return streamText({
    model: {} as any, // replace with a real model in actual use
    prompt: 'Hello',
    providerOptions: openaiOptions
  })
}
// Example 2: Anthropic provider options (extended thinking).
export function exampleAnthropicWithOptions() {
  const anthropicOptions = createAnthropicOptions({
    thinking: {
      type: 'enabled',
      budgetTokens: 1000
    }
  })
  return streamText({
    model: {} as any,
    prompt: 'Hello',
    providerOptions: anthropicOptions
  })
}
// Example 3: Google provider options (thinking configuration).
export function exampleGoogleWithOptions() {
  const googleOptions = createGoogleOptions({
    thinkingConfig: {
      includeThoughts: true,
      thinkingBudget: 1000
    }
  })
  return streamText({
    model: {} as any,
    prompt: 'Hello',
    providerOptions: googleOptions
  })
}
// Example 4: unknown provider — falls back to the generic (unconstrained) type.
export function exampleUnknownProviderWithOptions() {
  const customProviderOptions = createGenericProviderOptions('custom-provider', {
    temperature: 0.7,
    customSetting: 'value',
    anotherOption: true
  })
  return streamText({
    model: {} as any,
    prompt: 'Hello',
    providerOptions: customProviderOptions
  })
}
// Example 5: merging options from several providers into one object.
export function exampleMergedOptions() {
  const openaiOptions = createOpenAIOptions({})
  const customOptions = createGenericProviderOptions('custom', {
    customParam: 'value'
  })
  const mergedOptions = mergeProviderOptions(openaiOptions, customOptions)
  return streamText({
    model: {} as any,
    prompt: 'Hello',
    providerOptions: mergedOptions
  })
}

View File

@@ -0,0 +1,71 @@
import { ExtractProviderOptions, ProviderOptionsMap, TypedProviderOptions } from './types'
/**
 * Build a provider-options object for a known provider.
 * @param provider provider key
 * @param options provider-specific options (type-checked against the map)
 * @returns the options keyed under the provider name
 */
export function createProviderOptions<T extends keyof ProviderOptionsMap>(
  provider: T,
  options: ExtractProviderOptions<T>
): Record<T, ExtractProviderOptions<T>> {
  const wrapped = {} as Record<T, ExtractProviderOptions<T>>
  wrapped[provider] = options
  return wrapped
}
/**
 * Build a provider-options object for an arbitrary (possibly unknown) provider.
 * @param provider provider name
 * @param options free-form provider options
 * @returns the options keyed under the provider name
 */
export function createGenericProviderOptions<T extends string>(
  provider: T,
  options: Record<string, any>
): Record<T, Record<string, any>> {
  const wrapped = {} as Record<T, Record<string, any>>
  wrapped[provider] = options
  return wrapped
}
/**
 * Merge several provider-option objects into one.
 * Later arguments overwrite earlier ones on key collision (shallow merge).
 */
export function mergeProviderOptions(...optionsMap: Partial<TypedProviderOptions>[]): TypedProviderOptions {
  let merged = {} as TypedProviderOptions
  for (const options of optionsMap) {
    merged = { ...merged, ...options }
  }
  return merged
}
/**
 * Convenience helper for OpenAI provider options.
 */
export function createOpenAIOptions(options: ExtractProviderOptions<'openai'>) {
  return createProviderOptions('openai', options)
}
/**
 * Convenience helper for Anthropic provider options.
 */
export function createAnthropicOptions(options: ExtractProviderOptions<'anthropic'>) {
  return createProviderOptions('anthropic', options)
}
/**
 * Convenience helper for Google provider options.
 */
export function createGoogleOptions(options: ExtractProviderOptions<'google'>) {
  return createProviderOptions('google', options)
}
/**
 * Convenience helper for OpenRouter provider options.
 */
export function createOpenRouterOptions(options: ExtractProviderOptions<'openrouter'>) {
  return createProviderOptions('openrouter', options)
}
/**
 * Convenience helper for xAI provider options.
 */
export function createXaiOptions(options: ExtractProviderOptions<'xai'>) {
  return createProviderOptions('xai', options)
}

View File

@@ -0,0 +1,2 @@
export * from './factory'
export * from './types'

View File

@@ -0,0 +1,38 @@
export type OpenRouterProviderOptions = {
  // NOTE(review): exact routing/fallback semantics of this list are not shown
  // here — confirm against the OpenRouter model-routing docs.
  models?: string[]
  /**
   * https://openrouter.ai/docs/use-cases/reasoning-tokens
   * One of `max_tokens` or `effort` is required.
   * If `exclude` is true, reasoning will be removed from the response. Default is false.
   */
  reasoning?: {
    exclude?: boolean
  } & (
    | {
        max_tokens: number
      }
    | {
        effort: 'high' | 'medium' | 'low'
      }
  )
  /**
   * A unique identifier representing your end-user, which can
   * help OpenRouter to monitor and detect abuse.
   */
  user?: string
  // Extra fields forwarded verbatim in the request body
  extraBody?: Record<string, unknown>
  /**
   * Enable usage accounting to get detailed token usage information.
   * https://openrouter.ai/docs/use-cases/usage-accounting
   */
  usage?: {
    /**
     * When true, includes token usage information in the response.
     */
    include: boolean
  }
}

View File

@@ -0,0 +1,33 @@
import { type AnthropicProviderOptions } from '@ai-sdk/anthropic'
import { type GoogleGenerativeAIProviderOptions } from '@ai-sdk/google'
import { type OpenAIResponsesProviderOptions } from '@ai-sdk/openai'
import { type SharedV2ProviderMetadata } from '@ai-sdk/provider'
import { type OpenRouterProviderOptions } from './openrouter'
import { type XaiProviderOptions } from './xai'
// Look up a provider's option type from the AI SDK shared metadata map.
export type ProviderOptions<T extends keyof SharedV2ProviderMetadata> = SharedV2ProviderMetadata[T]
/**
 * Known provider option types; providers missing from this map are unconstrained.
 */
export type ProviderOptionsMap = {
  openai: OpenAIResponsesProviderOptions
  anthropic: AnthropicProviderOptions
  google: GoogleGenerativeAIProviderOptions
  openrouter: OpenRouterProviderOptions
  xai: XaiProviderOptions
}
// Utility type: extract a specific provider's option type from ProviderOptionsMap.
export type ExtractProviderOptions<T extends keyof ProviderOptionsMap> = ProviderOptionsMap[T]
/**
 * Type-safe ProviderOptions:
 * strict types for known providers; arbitrary Record<string, any> for unknown ones.
 */
export type TypedProviderOptions = {
  [K in keyof ProviderOptionsMap]?: ProviderOptionsMap[K]
} & {
  [K in string]?: Record<string, any>
} & SharedV2ProviderMetadata

View File

@@ -0,0 +1,86 @@
// copy from @ai-sdk/xai/xai-chat-options.ts
// Delete this file once @ai-sdk/xai exports xaiProviderOptions directly.
import * as z from 'zod/v4'
const webSourceSchema = z.object({
  type: z.literal('web'),
  country: z.string().length(2).optional(),
  excludedWebsites: z.array(z.string()).max(5).optional(),
  allowedWebsites: z.array(z.string()).max(5).optional(),
  safeSearch: z.boolean().optional()
})
const xSourceSchema = z.object({
  type: z.literal('x'),
  xHandles: z.array(z.string()).optional()
})
const newsSourceSchema = z.object({
  type: z.literal('news'),
  country: z.string().length(2).optional(),
  excludedWebsites: z.array(z.string()).max(5).optional(),
  safeSearch: z.boolean().optional()
})
const rssSourceSchema = z.object({
  type: z.literal('rss'),
  links: z.array(z.url()).max(1) // currently only supports one RSS link
})
const searchSourceSchema = z.discriminatedUnion('type', [
  webSourceSchema,
  xSourceSchema,
  newsSourceSchema,
  rssSourceSchema
])
export const xaiProviderOptions = z.object({
  /**
   * reasoning effort for reasoning models
   * only supported by grok-3-mini and grok-3-mini-fast models
   */
  reasoningEffort: z.enum(['low', 'high']).optional(),
  searchParameters: z
    .object({
      /**
       * search mode preference
       * - "off": disables search completely
       * - "auto": model decides whether to search (default)
       * - "on": always enables search
       */
      mode: z.enum(['off', 'auto', 'on']),
      /**
       * whether to return citations in the response
       * defaults to true
       */
      returnCitations: z.boolean().optional(),
      /**
       * start date for search data (ISO8601 format: YYYY-MM-DD)
       */
      fromDate: z.string().optional(),
      /**
       * end date for search data (ISO8601 format: YYYY-MM-DD)
       */
      toDate: z.string().optional(),
      /**
       * maximum number of search results to consider
       * defaults to 20
       */
      maxSearchResults: z.number().min(1).max(50).optional(),
      /**
       * data sources to search from
       * defaults to ["web", "x"] if not specified
       */
      sources: z.array(searchSourceSchema).optional()
    })
    .optional()
})
export type XaiProviderOptions = z.infer<typeof xaiProviderOptions>

View File

@@ -0,0 +1,257 @@
# AI Core 插件系统
支持四种钩子类型:**First**、**Sequential**、**Parallel** 和 **Stream**
## 🎯 设计理念
- **语义清晰**:不同钩子有不同的执行语义
- **类型安全**TypeScript 完整支持
- **性能优化**First 短路、Parallel 并发、Sequential 链式
- **易于扩展**`enforce` 排序 + 功能分类
## 📋 钩子类型
### 🥇 First 钩子 - 首个有效结果
```typescript
// 只执行第一个返回值的插件,用于解析和查找
resolveModel?: (modelId: string, context: AiRequestContext) => string | null
loadTemplate?: (templateName: string, context: AiRequestContext) => any | null
```
### 🔄 Sequential 钩子 - 链式数据转换
```typescript
// 按顺序链式执行,每个插件可以修改数据
transformParams?: (params: any, context: AiRequestContext) => any
transformResult?: (result: any, context: AiRequestContext) => any
```
### ⚡ Parallel 钩子 - 并行副作用
```typescript
// 并发执行,用于日志、监控等副作用
onRequestStart?: (context: AiRequestContext) => void
onRequestEnd?: (context: AiRequestContext, result: any) => void
onError?: (error: Error, context: AiRequestContext) => void
```
### 🌊 Stream 钩子 - 流处理
```typescript
// 直接使用 AI SDK 的 TransformStream
transformStream?: () => (options) => TransformStream<TextStreamPart, TextStreamPart>
```
## 🚀 快速开始
### 基础用法
```typescript
import { PluginManager, createContext, definePlugin } from '@cherrystudio/ai-core/middleware'
// 创建插件管理器
const pluginManager = new PluginManager()
// 添加插件
pluginManager.use({
name: 'my-plugin',
async transformParams(params, context) {
return { ...params, temperature: 0.7 }
}
})
// 使用插件
const context = createContext('openai', 'gpt-4', { messages: [] })
const transformedParams = await pluginManager.executeSequential(
'transformParams',
{ messages: [{ role: 'user', content: 'Hello' }] },
context
)
```
### 完整示例
```typescript
import {
PluginManager,
ModelAliasPlugin,
LoggingPlugin,
ParamsValidationPlugin,
createContext
} from '@cherrystudio/ai-core/middleware'
// 创建插件管理器
const manager = new PluginManager([
ModelAliasPlugin, // 模型别名解析
ParamsValidationPlugin, // 参数验证
LoggingPlugin // 日志记录
])
// AI 请求流程
async function aiRequest(providerId: string, modelId: string, params: any) {
const context = createContext(providerId, modelId, params)
try {
// 1. 【并行】触发请求开始事件
await manager.executeParallel('onRequestStart', context)
// 2. 【首个】解析模型别名
const resolvedModel = await manager.executeFirst('resolveModel', modelId, context)
context.modelId = resolvedModel || modelId
// 3. 【串行】转换请求参数
const transformedParams = await manager.executeSequential('transformParams', params, context)
// 4. 【流处理】收集流转换器AI SDK 原生支持数组)
const streamTransforms = manager.collectStreamTransforms()
// 5. 调用 AI SDK这里省略具体实现
const result = await callAiSdk(transformedParams, streamTransforms)
// 6. 【串行】转换响应结果
const transformedResult = await manager.executeSequential('transformResult', result, context)
// 7. 【并行】触发请求完成事件
await manager.executeParallel('onRequestEnd', context, transformedResult)
return transformedResult
} catch (error) {
// 8. 【并行】触发错误事件
await manager.executeParallel('onError', context, undefined, error)
throw error
}
}
```
## 🔧 自定义插件
### 模型别名插件
```typescript
const ModelAliasPlugin = definePlugin({
name: 'model-alias',
enforce: 'pre', // 最先执行
async resolveModel(modelId) {
const aliases = {
gpt4: 'gpt-4-turbo-preview',
claude: 'claude-3-sonnet-20240229'
}
return aliases[modelId] || null
}
})
```
### 参数验证插件
```typescript
const ValidationPlugin = definePlugin({
name: 'validation',
async transformParams(params) {
if (!params.messages) {
throw new Error('messages is required')
}
return {
...params,
temperature: params.temperature ?? 0.7,
max_tokens: params.max_tokens ?? 4096
}
}
})
```
### 监控插件
```typescript
const MonitoringPlugin = definePlugin({
name: 'monitoring',
enforce: 'post', // 最后执行
async onRequestEnd(context, result) {
const duration = Date.now() - context.startTime
console.log(`请求耗时: ${duration}ms`)
}
})
```
### 内容过滤插件
```typescript
const FilterPlugin = definePlugin({
name: 'content-filter',
transformStream() {
return () =>
new TransformStream({
transform(chunk, controller) {
if (chunk.type === 'text-delta') {
const filtered = chunk.textDelta.replace(/敏感词/g, '***')
controller.enqueue({ ...chunk, textDelta: filtered })
} else {
controller.enqueue(chunk)
}
}
})
}
})
```
## 📊 执行顺序
### 插件排序
```
enforce: 'pre' → normal → enforce: 'post'
```
### 钩子执行流程
```mermaid
graph TD
A[请求开始] --> B[onRequestStart 并行执行]
B --> C[resolveModel 首个有效]
C --> D[loadTemplate 首个有效]
D --> E[transformParams 串行执行]
E --> F[collectStreamTransforms]
F --> G[AI SDK 调用]
G --> H[transformResult 串行执行]
H --> I[onRequestEnd 并行执行]
G --> J[异常处理]
J --> K[onError 并行执行]
```
## 💡 最佳实践
1. **功能单一**:每个插件专注一个功能
2. **幂等性**:插件应该是幂等的,重复执行不会产生副作用
3. **错误处理**:插件内部处理异常,不要让异常向上传播
4. **性能优化**使用合适的钩子类型First vs Sequential vs Parallel
5. **命名规范**:使用语义化的插件名称
## 🔍 调试工具
```typescript
// 查看插件统计信息
const stats = manager.getStats()
console.log('插件统计:', stats)
// 查看所有插件
const plugins = manager.getPlugins()
console.log(
'已注册插件:',
plugins.map((p) => p.name)
)
```
## ⚡ 性能优势
- **First 钩子**:一旦找到结果立即停止,避免无效计算
- **Parallel 钩子**:真正并发执行,不阻塞主流程
- **Sequential 钩子**:保证数据转换的顺序性
- **Stream 钩子**:直接集成 AI SDK零开销
这个设计兼顾了简洁性和强大功能,为 AI Core 提供了灵活而高效的扩展机制。

View File

@@ -0,0 +1,10 @@
/**
 * Built-in plugin namespace.
 * All built-in plugin names are prefixed with 'built-in:'.
 */
export const BUILT_IN_PLUGIN_PREFIX = 'built-in:'
export { createLoggingPlugin } from './logging'
export { createPromptToolUsePlugin } from './toolUsePlugin/promptToolUsePlugin'
export type { PromptToolUseConfig, ToolUseRequestContext, ToolUseResult } from './toolUsePlugin/type'
export { webSearchPlugin } from './webSearchPlugin'

View File

@@ -0,0 +1,86 @@
/**
 * Built-in plugin: logging.
 * Records key information about AI calls; supports performance monitoring
 * and debugging.
 */
import { definePlugin } from '../index'
import type { AiRequestContext } from '../types'
export interface LoggingConfig {
  // Log level passed to the logger
  level?: 'debug' | 'info' | 'warn' | 'error'
  // Whether to include request params in the start log
  logParams?: boolean
  // Whether to include the result in the completion log
  logResult?: boolean
  // Whether to include request duration in the logs
  logPerformance?: boolean
  // Custom log sink (defaults to console.log)
  logger?: (level: string, message: string, data?: any) => void
}
/**
 * Create the built-in logging plugin.
 * Logs request start / completion / failure, with optional params, result
 * and timing data.
 */
export function createLoggingPlugin(config: LoggingConfig = {}) {
  const { level = 'info', logParams = true, logResult = false, logPerformance = true, logger = console.log } = config
  // Start timestamps keyed by requestId, used to compute durations.
  const startTimes = new Map<string, number>()
  return definePlugin({
    name: 'built-in:logging',
    onRequestStart: (context: AiRequestContext) => {
      const requestId = context.requestId
      startTimes.set(requestId, Date.now())
      logger(level, `🚀 AI Request Started`, {
        requestId,
        providerId: context.providerId,
        modelId: context.modelId,
        originalParams: logParams ? context.originalParams : '[hidden]'
      })
    },
    onRequestEnd: (context: AiRequestContext, result: any) => {
      const requestId = context.requestId
      const startTime = startTimes.get(requestId)
      const duration = startTime ? Date.now() - startTime : undefined
      startTimes.delete(requestId)
      const logData: any = {
        requestId,
        providerId: context.providerId,
        modelId: context.modelId
      }
      // Explicit undefined check: a 0ms duration is valid and was previously
      // dropped by the truthiness test (`logPerformance && duration`).
      if (logPerformance && duration !== undefined) {
        logData.duration = `${duration}ms`
      }
      if (logResult) {
        logData.result = result
      }
      logger(level, `✅ AI Request Completed`, logData)
    },
    onError: (error: Error, context: AiRequestContext) => {
      const requestId = context.requestId
      const startTime = startTimes.get(requestId)
      const duration = startTime ? Date.now() - startTime : undefined
      startTimes.delete(requestId)
      logger('error', `❌ AI Request Failed`, {
        requestId,
        providerId: context.providerId,
        modelId: context.modelId,
        // Same fix as above: report a genuine 0ms duration instead of dropping it.
        duration: duration !== undefined ? `${duration}ms` : undefined,
        error: {
          name: error.name,
          message: error.message,
          stack: error.stack
        }
      })
    }
  })
}

View File

@@ -0,0 +1,139 @@
/**
* 流事件管理器
*
* 负责处理 AI SDK 流事件的发送和管理
* 从 promptToolUsePlugin.ts 中提取出来以降低复杂度
*/
import type { ModelMessage } from 'ai'
import type { AiRequestContext } from '../../types'
import type { StreamController } from './ToolExecutor'
/**
 * Stream event manager class.
 * Emits AI SDK stream events and pipes recursive-call result streams into the
 * current stream. Extracted from promptToolUsePlugin.ts to reduce complexity.
 */
export class StreamEventManager {
  /**
   * Emit the start-step event for a tool-use step.
   */
  sendStepStartEvent(controller: StreamController): void {
    controller.enqueue({
      type: 'start-step',
      request: {},
      warnings: []
    })
  }
  /**
   * Emit the finish-step event, forwarding response/usage/metadata from the chunk.
   */
  sendStepFinishEvent(controller: StreamController, chunk: any): void {
    controller.enqueue({
      type: 'finish-step',
      finishReason: 'stop',
      response: chunk.response,
      usage: chunk.usage,
      providerMetadata: chunk.providerMetadata
    })
  }
  /**
   * Perform the recursive model call and pipe its full stream into the
   * current stream. Errors are converted into non-fatal stream events.
   */
  async handleRecursiveCall(
    controller: StreamController,
    recursiveParams: any,
    context: AiRequestContext,
    stepId: string
  ): Promise<void> {
    try {
      console.log('[MCP Prompt] Starting recursive call after tool execution...')
      const recursiveResult = await context.recursiveCall(recursiveParams)
      if (recursiveResult && recursiveResult.fullStream) {
        await this.pipeRecursiveStream(controller, recursiveResult.fullStream)
      } else {
        console.warn('[MCP Prompt] No fullstream found in recursive result:', recursiveResult)
      }
    } catch (error) {
      this.handleRecursiveCallError(controller, error, stepId)
    }
  }
  /**
   * Forward chunks from the recursive stream into the current stream.
   */
  private async pipeRecursiveStream(controller: StreamController, recursiveStream: ReadableStream): Promise<void> {
    const reader = recursiveStream.getReader()
    try {
      while (true) {
        const { done, value } = await reader.read()
        if (done) {
          break
        }
        if (value.type === 'finish') {
          // Inner (recursive) streams must not emit 'finish' — the outer stream owns it.
          break
        }
        // Forward the chunk into the current stream.
        controller.enqueue(value)
      }
    } finally {
      reader.releaseLock()
    }
  }
  /**
   * Report a failed recursive call as a non-fatal error event, then keep the
   * stream alive with an explanatory text delta.
   */
  private handleRecursiveCallError(controller: StreamController, error: unknown, stepId: string): void {
    console.error('[MCP Prompt] Recursive call failed:', error)
    // AI SDK standard error chunk — does not terminate the stream.
    controller.enqueue({
      type: 'error',
      error: {
        message: error instanceof Error ? error.message : String(error),
        name: error instanceof Error ? error.name : 'RecursiveCallError'
      }
    })
    // Emit a text delta so the stream stays continuous for the consumer.
    controller.enqueue({
      type: 'text-delta',
      id: stepId,
      text: '\n\n[工具执行后递归调用失败,继续对话...]'
    })
  }
  /**
   * Build the params for the recursive follow-up call: prior conversation
   * plus the assistant text and the tool results, with tools re-attached.
   * NOTE: also mutates context.originalParams.messages to the new history.
   */
  buildRecursiveParams(context: AiRequestContext, textBuffer: string, toolResultsText: string, tools: any): any {
    // Extend the conversation with the assistant output and the tool results.
    const newMessages: ModelMessage[] = [
      ...(context.originalParams.messages || []),
      {
        role: 'assistant',
        content: textBuffer
      },
      {
        role: 'user',
        content: toolResultsText
      }
    ]
    // Continue the conversation, passing tools again for further calls.
    const recursiveParams = {
      ...context.originalParams,
      messages: newMessages,
      tools: tools
    }
    // Keep the context's message history in sync.
    context.originalParams.messages = newMessages
    return recursiveParams
  }
}

View File

@@ -0,0 +1,156 @@
/**
* 工具执行器
*
* 负责工具的执行、结果格式化和相关事件发送
* 从 promptToolUsePlugin.ts 中提取出来以降低复杂度
*/
import type { ToolSet } from 'ai'
import type { ToolUseResult } from './type'
/**
 * Result of a single tool execution.
 */
export interface ExecutedResult {
  toolCallId: string
  toolName: string
  // Tool output on success; the error message string when isError is true
  result: any
  isError?: boolean
}
/**
 * Stream controller type (extracted from the AI SDK).
 */
export interface StreamController {
  enqueue(chunk: any): void
}
/**
 * Tool executor class.
 *
 * Runs parsed tool calls sequentially, emits the AI SDK stream events that
 * surround each call, and formats results for prompt-based tool use.
 */
export class ToolExecutor {
  /**
   * Execute multiple tool calls sequentially.
   *
   * For each tool use this emits tool-input-start, tool-call and
   * tool-result (or error) events on the controller, then records the outcome.
   *
   * @param toolUses parsed tool-use requests
   * @param tools available tool set, keyed by tool name
   * @param controller stream controller receiving progress events
   * @returns one ExecutedResult per requested tool use, in request order
   */
  async executeTools(
    toolUses: ToolUseResult[],
    tools: ToolSet,
    controller: StreamController
  ): Promise<ExecutedResult[]> {
    const executedResults: ExecutedResult[] = []
    for (const toolUse of toolUses) {
      try {
        const tool = tools[toolUse.toolName]
        if (!tool || typeof tool.execute !== 'function') {
          throw new Error(`Tool "${toolUse.toolName}" has no execute method`)
        }
        // Announce the upcoming call before doing any work.
        this.sendToolStartEvents(controller, toolUse)
        console.log(`[MCP Prompt Stream] Executing tool: ${toolUse.toolName}`, toolUse.arguments)
        // Emit the tool-call event. BUG FIX: `input` must carry the actual
        // call arguments (as the tool-result event below does), not the
        // tool's JSON input schema.
        controller.enqueue({
          type: 'tool-call',
          toolCallId: toolUse.id,
          toolName: toolUse.toolName,
          input: toolUse.arguments
        })
        const result = await tool.execute(toolUse.arguments, {
          toolCallId: toolUse.id,
          messages: [],
          // NOTE(review): a fresh AbortController means this call can never be
          // cancelled from outside — confirm whether the caller's abort signal
          // should be threaded through instead.
          abortSignal: new AbortController().signal
        })
        // Emit the tool-result event with both input and output.
        controller.enqueue({
          type: 'tool-result',
          toolCallId: toolUse.id,
          toolName: toolUse.toolName,
          input: toolUse.arguments,
          output: result
        })
        executedResults.push({
          toolCallId: toolUse.id,
          toolName: toolUse.toolName,
          result,
          isError: false
        })
      } catch (error) {
        console.error(`[MCP Prompt Stream] Tool execution failed: ${toolUse.toolName}`, error)
        // Record the failure; an error event is emitted inside handleToolError.
        executedResults.push(this.handleToolError(toolUse, error, controller))
      }
    }
    return executedResults
  }

  /**
   * Format executed tool results into Cherry Studio's XML text format,
   * suitable for feeding back to the model as a user message.
   */
  formatToolResults(executedResults: ExecutedResult[]): string {
    return executedResults
      .map((tr) => {
        if (!tr.isError) {
          return `<tool_use_result>\n <name>${tr.toolName}</name>\n <result>${JSON.stringify(tr.result)}</result>\n</tool_use_result>`
        } else {
          const error = tr.result || 'Unknown error'
          return `<tool_use_result>\n <name>${tr.toolName}</name>\n <error>${error}</error>\n</tool_use_result>`
        }
      })
      .join('\n\n')
  }

  /**
   * Emit the events that precede a tool call.
   */
  private sendToolStartEvents(controller: StreamController, toolUse: ToolUseResult): void {
    controller.enqueue({
      type: 'tool-input-start',
      id: toolUse.id,
      toolName: toolUse.toolName
    })
  }

  /**
   * Handle a tool execution failure: emit a standard error event and return
   * an ExecutedResult flagged as an error (the message becomes the result).
   */
  private handleToolError(toolUse: ToolUseResult, error: unknown, controller: StreamController): ExecutedResult {
    const message = error instanceof Error ? error.message : String(error)
    controller.enqueue({
      type: 'error',
      error: message
    })
    return {
      toolCallId: toolUse.id,
      toolName: toolUse.toolName,
      result: message,
      isError: true
    }
  }
}

View File

@@ -0,0 +1,373 @@
/**
* 内置插件MCP Prompt 模式
* 为不支持原生 Function Call 的模型提供 prompt 方式的工具调用
* 内置默认逻辑,支持自定义覆盖
*/
import type { TextStreamPart, ToolSet } from 'ai'
import { definePlugin } from '../../index'
import type { AiRequestContext } from '../../types'
import { StreamEventManager } from './StreamEventManager'
import { ToolExecutor } from './ToolExecutor'
import { PromptToolUseConfig, ToolUseResult } from './type'
/**
 * Default system prompt template (extracted from Cherry Studio).
 * Placeholders {{ TOOL_USE_EXAMPLES }}, {{ AVAILABLE_TOOLS }} and
 * {{ USER_SYSTEM_PROMPT }} are substituted by defaultBuildSystemPrompt.
 */
const DEFAULT_SYSTEM_PROMPT = `In this environment you have access to a set of tools you can use to answer the user's question. \\
You can use one tool per message, and will receive the result of that tool use in the user's response. You use tools step-by-step to accomplish a given task, with each tool use informed by the result of the previous tool use.
## Tool Use Formatting
Tool use is formatted using XML-style tags. The tool name is enclosed in opening and closing tags, and each parameter is similarly enclosed within its own set of tags. Here's the structure:
<tool_use>
<name>{tool_name}</name>
<arguments>{json_arguments}</arguments>
</tool_use>
The tool name should be the exact name of the tool you are using, and the arguments should be a JSON object containing the parameters required by that tool. For example:
<tool_use>
<name>python_interpreter</name>
<arguments>{"code": "5 + 3 + 1294.678"}</arguments>
</tool_use>
The user will respond with the result of the tool use, which should be formatted as follows:
<tool_use_result>
<name>{tool_name}</name>
<result>{result}</result>
</tool_use_result>
The result should be a string, which can represent a file or any other output type. You can use this result as input for the next action.
For example, if the result of the tool use is an image file, you can use it in the next action like this:
<tool_use>
<name>image_transformer</name>
<arguments>{"image": "image_1.jpg"}</arguments>
</tool_use>
Always adhere to this format for the tool use to ensure proper parsing and execution.
## Tool Use Examples
{{ TOOL_USE_EXAMPLES }}
## Tool Use Available Tools
Above example were using notional tools that might not exist for you. You only have access to these tools:
{{ AVAILABLE_TOOLS }}
## Tool Use Rules
Here are the rules you should always follow to solve your task:
1. Always use the right arguments for the tools. Never use variable names as the action arguments, use the value instead.
2. Call a tool only when needed: do not call the search agent if you do not need information, try to solve the task yourself.
3. If no tool call is needed, just answer the question directly.
4. Never re-do a tool call that you previously did with the exact same parameters.
5. For tool use, MAKE SURE use XML tag format as shown in the examples above. Do not use any other format.
# User Instructions
{{ USER_SYSTEM_PROMPT }}
Now Begin! If you solve the task correctly, you will receive a reward of $1,000,000.`
/**
 * Default tool-use examples (extracted from Cherry Studio).
 * Substituted into the {{ TOOL_USE_EXAMPLES }} placeholder of the system prompt.
 */
const DEFAULT_TOOL_USE_EXAMPLES = `
Here are a few examples using notional tools:
---
User: Generate an image of the oldest person in this document.
A: I can use the document_qa tool to find out who the oldest person is in the document.
<tool_use>
<name>document_qa</name>
<arguments>{"document": "document.pdf", "question": "Who is the oldest person mentioned?"}</arguments>
</tool_use>
User: <tool_use_result>
<name>document_qa</name>
<result>John Doe, a 55 year old lumberjack living in Newfoundland.</result>
</tool_use_result>
A: I can use the image_generator tool to create a portrait of John Doe.
<tool_use>
<name>image_generator</name>
<arguments>{"prompt": "A portrait of John Doe, a 55-year-old man living in Canada."}</arguments>
</tool_use>
User: <tool_use_result>
<name>image_generator</name>
<result>image.png</result>
</tool_use_result>
A: the image is generated as image.png
---
User: "What is the result of the following operation: 5 + 3 + 1294.678?"
A: I can use the python_interpreter tool to calculate the result of the operation.
<tool_use>
<name>python_interpreter</name>
<arguments>{"code": "5 + 3 + 1294.678"}</arguments>
</tool_use>
User: <tool_use_result>
<name>python_interpreter</name>
<result>1302.678</result>
</tool_use_result>
A: The result of the operation is 1302.678.
---
User: "Which city has the highest population , Guangzhou or Shanghai?"
A: I can use the search tool to find the population of Guangzhou.
<tool_use>
<name>search</name>
<arguments>{"query": "Population Guangzhou"}</arguments>
</tool_use>
User: <tool_use_result>
<name>search</name>
<result>Guangzhou has a population of 15 million inhabitants as of 2021.</result>
</tool_use_result>
A: I can use the search tool to find the population of Shanghai.
<tool_use>
<name>search</name>
<arguments>{"query": "Population Shanghai"}</arguments>
</tool_use>
User: <tool_use_result>
<name>search</name>
<result>26 million (2019)</result>
</tool_use_result>
Assistant: The population of Shanghai is 26 million, while Guangzhou has a population of 15 million. Therefore, Shanghai has the highest population.`
/**
* 构建可用工具部分(提取自 Cherry Studio
*/
function buildAvailableTools(tools: ToolSet): string {
const availableTools = Object.keys(tools)
.map((toolName: string) => {
const tool = tools[toolName]
return `
<tool>
<name>${toolName}</name>
<description>${tool.description || ''}</description>
<arguments>
${tool.inputSchema ? JSON.stringify(tool.inputSchema) : ''}
</arguments>
</tool>
`
})
.join('\n')
return `<tools>
${availableTools}
</tools>`
}
/**
* 默认的系统提示符构建函数(提取自 Cherry Studio
*/
function defaultBuildSystemPrompt(userSystemPrompt: string, tools: ToolSet): string {
const availableTools = buildAvailableTools(tools)
const fullPrompt = DEFAULT_SYSTEM_PROMPT.replace('{{ TOOL_USE_EXAMPLES }}', DEFAULT_TOOL_USE_EXAMPLES)
.replace('{{ AVAILABLE_TOOLS }}', availableTools)
.replace('{{ USER_SYSTEM_PROMPT }}', userSystemPrompt || '')
return fullPrompt
}
/**
 * Default tool-use parser (ported from Cherry Studio).
 * Parses XML-formatted tool calls out of the model's text output.
 *
 * Accepts either full <tool_use>...</tool_use> blocks or bare inner content
 * (as produced by TagExtractor), which is wrapped before matching.
 *
 * @returns parsed tool-use requests plus the content with the matched blocks
 *          of known tools removed (unknown tools are left in place)
 */
function defaultParseToolUse(content: string, tools: ToolSet): { results: ToolUseResult[]; content: string } {
  if (!content || !tools || Object.keys(tools).length === 0) {
    return { results: [], content: content }
  }
  let contentToProcess = content
  // Bare inner content (no <tool_use> tag) comes from TagExtractor — wrap it.
  if (!content.includes('<tool_use>')) {
    contentToProcess = `<tool_use>\n${content}\n</tool_use>`
  }
  const toolUsePattern =
    /<tool_use>([\s\S]*?)<name>([\s\S]*?)<\/name>([\s\S]*?)<arguments>([\s\S]*?)<\/arguments>([\s\S]*?)<\/tool_use>/g
  const results: ToolUseResult[] = []
  let idx = 0
  // BUG FIX: collect every match up front. The previous version removed each
  // match from the string while `exec` was still iterating it, which
  // desynchronized the regex lastIndex and silently skipped subsequent
  // tool-use blocks.
  const matches = Array.from(contentToProcess.matchAll(toolUsePattern))
  for (const match of matches) {
    const fullMatch = match[0]
    const toolName = match[2].trim()
    const toolArgs = match[4].trim()
    // Arguments should be JSON; fall back to the raw string if parsing fails.
    let parsedArgs: any
    try {
      parsedArgs = JSON.parse(toolArgs)
    } catch {
      parsedArgs = toolArgs
    }
    // Unknown tools are skipped and their markup left in the content.
    const tool = tools[toolName]
    if (!tool) {
      console.warn(`Tool "${toolName}" not found in available tools`)
      continue
    }
    results.push({
      id: `${toolName}-${idx++}`, // unique ID for each tool use
      toolName: toolName,
      arguments: parsedArgs,
      status: 'pending'
    })
    // Strip the matched block from the returned content.
    contentToProcess = contentToProcess.replace(fullMatch, '')
  }
  return { results, content: contentToProcess }
}
/**
 * Create the prompt-based tool-use plugin.
 *
 * transformParams: moves native `tools` into a system prompt (prompt mode).
 * transformStream: buffers text deltas, parses tool-use markup on
 * text-end/finish-step, executes the tools, and recursively continues the
 * conversation with the tool results.
 */
export const createPromptToolUsePlugin = (config: PromptToolUseConfig = {}) => {
  const { enabled = true, buildSystemPrompt = defaultBuildSystemPrompt, parseToolUse = defaultParseToolUse } = config
  return definePlugin({
    name: 'built-in:prompt-tool-use',
    transformParams: (params: any, context: AiRequestContext) => {
      if (!enabled || !params.tools || typeof params.tools !== 'object') {
        return params
      }
      // Stash the tools on the context so transformStream can execute them later.
      context.mcpTools = params.tools
      console.log('tools stored in context', params.tools)
      // Build the tool-use system prompt around the user's own system prompt.
      const userSystemPrompt = typeof params.system === 'string' ? params.system : ''
      const systemPrompt = buildSystemPrompt(userSystemPrompt, params.tools)
      let systemMessage: string | null = systemPrompt
      console.log('config.context', context)
      if (config.createSystemMessage) {
        // A user-supplied factory may post-process (or drop) the system message.
        systemMessage = config.createSystemMessage(systemPrompt, params, context)
      }
      // Remove native tools and switch to prompt mode.
      const transformedParams = {
        ...params,
        ...(systemMessage ? { system: systemMessage } : {}),
        tools: undefined
      }
      context.originalParams = transformedParams
      console.log('transformedParams', transformedParams)
      return transformedParams
    },
    transformStream: (_: any, context: AiRequestContext) => () => {
      let textBuffer = ''
      let stepId = ''
      if (!context.mcpTools) {
        throw new Error('No tools available')
      }
      // Tool executor and stream event manager used while transforming.
      const toolExecutor = new ToolExecutor()
      const streamEventManager = new StreamEventManager()
      type TOOLS = NonNullable<typeof context.mcpTools>
      return new TransformStream<TextStreamPart<TOOLS>, TextStreamPart<TOOLS>>({
        async transform(
          chunk: TextStreamPart<TOOLS>,
          controller: TransformStreamDefaultController<TextStreamPart<TOOLS>>
        ) {
          // Accumulate text content while forwarding the deltas unchanged.
          if (chunk.type === 'text-delta') {
            textBuffer += chunk.text || ''
            stepId = chunk.id || ''
            controller.enqueue(chunk)
            return
          }
          if (chunk.type === 'text-end' || chunk.type === 'finish-step') {
            const tools = context.mcpTools
            if (!tools || Object.keys(tools).length === 0) {
              controller.enqueue(chunk)
              return
            }
            // Parse tool calls out of the buffered text.
            const { results: parsedTools, content: parsedContent } = parseToolUse(textBuffer, tools)
            const validToolUses = parsedTools.filter((t) => t.status === 'pending')
            // No tool calls: pass the original event through untouched.
            if (validToolUses.length === 0) {
              controller.enqueue(chunk)
              return
            }
            if (chunk.type === 'text-end') {
              // Re-emit text-end carrying the text stripped of tool markup.
              controller.enqueue({
                type: 'text-end',
                id: stepId,
                providerMetadata: {
                  text: {
                    value: parsedContent
                  }
                }
              })
              return
            }
            // finish-step: mark the step as ending in tool calls.
            controller.enqueue({
              ...chunk,
              finishReason: 'tool-calls'
            })
            // Emit step-start for the tool-execution phase.
            streamEventManager.sendStepStartEvent(controller)
            // Execute the parsed tool calls.
            const executedResults = await toolExecutor.executeTools(validToolUses, tools, controller)
            // Emit step-finish.
            streamEventManager.sendStepFinishEvent(controller, chunk)
            // Recurse: feed tool results back into the conversation.
            if (validToolUses.length > 0) {
              const toolResultsText = toolExecutor.formatToolResults(executedResults)
              const recursiveParams = streamEventManager.buildRecursiveParams(
                context,
                textBuffer,
                toolResultsText,
                tools
              )
              await streamEventManager.handleRecursiveCall(controller, recursiveParams, context, stepId)
            }
            // Reset accumulated state for the next step.
            textBuffer = ''
            return
          }
          // Every other event type passes through unchanged.
          controller.enqueue(chunk)
        },
        flush() {
          // Cleanup when the stream ends.
          console.log('[MCP Prompt] Stream ended, cleaning up...')
        }
      })
    }
  })
}

View File

@@ -0,0 +1,196 @@
// Copied from https://github.com/vercel/ai/blob/main/packages/ai/core/util/get-potential-start-index.ts
/**
 * Returns the index where searchedText starts in text, or — when no full
 * occurrence exists — the start of the shortest suffix of text that is a
 * prefix of searchedText (a potential match continuing in the next chunk).
 * Returns null if neither exists or searchedText is empty.
 */
export function getPotentialStartIndex(text: string, searchedText: string): number | null {
  // An empty search target can never match.
  if (!searchedText) {
    return null
  }
  // Full occurrence wins.
  const direct = text.indexOf(searchedText)
  if (direct >= 0) {
    return direct
  }
  // Try suffixes from shortest to longest; the first suffix that is a
  // prefix of searchedText marks a potential partial match at the tail.
  for (let suffixLength = 1; suffixLength <= text.length; suffixLength++) {
    const start = text.length - suffixLength
    if (searchedText.startsWith(text.slice(start))) {
      return start
    }
  }
  return null
}
/** Configuration for one tag pair, e.g. openingTag '<think>' / closingTag '</think>'. */
export interface TagConfig {
  openingTag: string
  closingTag: string
  // Optional separator inserted between segments after a tag boundary switch
  separator?: string
}
/** Mutable extraction state carried across streamed text chunks. */
export interface TagExtractionState {
  // Unprocessed text; may end with a partial (incomplete) tag
  textBuffer: string
  isInsideTag: boolean
  isFirstTag: boolean
  isFirstText: boolean
  // True immediately after crossing a tag boundary
  afterSwitch: boolean
  accumulatedTagContent: string
  hasTagContent: boolean
}
/** One unit of extractor output. */
export interface TagExtractionResult {
  content: string
  isTagContent: boolean
  complete: boolean
  // Full accumulated tag body; present only when complete === true
  tagContentExtracted?: string
}
/**
 * Generic tag-pair extraction processor.
 * Handles tag pairs of any form — e.g. <think>...</think>, <tool_use>...</tool_use> —
 * across arbitrarily-split streaming text chunks, buffering partial tags.
 */
export class TagExtractor {
  private config: TagConfig
  private state: TagExtractionState
  constructor(config: TagConfig) {
    this.config = config
    this.state = {
      textBuffer: '',
      isInsideTag: false,
      isFirstTag: true,
      isFirstText: true,
      afterSwitch: false,
      accumulatedTagContent: '',
      hasTagContent: false
    }
  }
  /**
   * Process an incoming text chunk and return extraction results.
   *
   * Emits content marked as inside/outside the tag, holds back text that may
   * be the start of a tag, and emits a `complete` result carrying the full
   * accumulated body each time a closing tag is consumed.
   */
  processText(newText: string): TagExtractionResult[] {
    this.state.textBuffer += newText
    const results: TagExtractionResult[] = []
    // Scan the buffer, alternating between looking for the opening and closing tag.
    while (true) {
      const nextTag = this.state.isInsideTag ? this.config.closingTag : this.config.openingTag
      const startIndex = getPotentialStartIndex(this.state.textBuffer, nextTag)
      if (startIndex == null) {
        // No full or partial tag ahead: flush the whole buffer as content.
        const content = this.state.textBuffer
        if (content.length > 0) {
          results.push({
            content: this.addPrefix(content),
            isTagContent: this.state.isInsideTag,
            complete: false
          })
          if (this.state.isInsideTag) {
            // NOTE(review): addPrefix is invoked a second time here after the
            // call above already cleared afterSwitch, so the accumulated text
            // can differ from the emitted text by the separator — confirm intended.
            this.state.accumulatedTagContent += this.addPrefix(content)
            this.state.hasTagContent = true
          }
        }
        this.state.textBuffer = ''
        break
      }
      // Emit the content that precedes the (potential) tag.
      const contentBeforeTag = this.state.textBuffer.slice(0, startIndex)
      if (contentBeforeTag.length > 0) {
        results.push({
          content: this.addPrefix(contentBeforeTag),
          isTagContent: this.state.isInsideTag,
          complete: false
        })
        if (this.state.isInsideTag) {
          // NOTE(review): same double addPrefix pattern as above — verify.
          this.state.accumulatedTagContent += this.addPrefix(contentBeforeTag)
          this.state.hasTagContent = true
        }
      }
      const foundFullMatch = startIndex + nextTag.length <= this.state.textBuffer.length
      if (foundFullMatch) {
        // A complete tag is present: consume it from the buffer.
        this.state.textBuffer = this.state.textBuffer.slice(startIndex + nextTag.length)
        // If a tag body just ended, emit the completed accumulated content.
        if (this.state.isInsideTag && this.state.hasTagContent) {
          results.push({
            content: '',
            isTagContent: false,
            complete: true,
            tagContentExtracted: this.state.accumulatedTagContent
          })
          this.state.accumulatedTagContent = ''
          this.state.hasTagContent = false
        }
        this.state.isInsideTag = !this.state.isInsideTag
        this.state.afterSwitch = true
        if (this.state.isInsideTag) {
          this.state.isFirstTag = false
        } else {
          this.state.isFirstText = false
        }
      } else {
        // Partial tag at the buffer tail: keep it and wait for more text.
        this.state.textBuffer = this.state.textBuffer.slice(startIndex)
        break
      }
    }
    return results
  }
  /**
   * Finish processing and return any remaining accumulated tag content
   * (e.g. when the stream ended before the closing tag arrived).
   */
  finalize(): TagExtractionResult | null {
    if (this.state.hasTagContent && this.state.accumulatedTagContent) {
      const result = {
        content: '',
        isTagContent: false,
        complete: true,
        tagContentExtracted: this.state.accumulatedTagContent
      }
      this.state.accumulatedTagContent = ''
      this.state.hasTagContent = false
      return result
    }
    return null
  }
  // Prepend the configured separator when content resumes right after a tag
  // boundary switch (but never for the very first tag/text segment).
  private addPrefix(text: string): string {
    const needsPrefix =
      this.state.afterSwitch && (this.state.isInsideTag ? !this.state.isFirstTag : !this.state.isFirstText)
    const prefix = needsPrefix && this.config.separator ? this.config.separator : ''
    this.state.afterSwitch = false
    return prefix + text
  }
  /**
   * Reset all extraction state.
   */
  reset(): void {
    this.state = {
      textBuffer: '',
      isInsideTag: false,
      isFirstTag: true,
      isFirstText: true,
      afterSwitch: false,
      accumulatedTagContent: '',
      hasTagContent: false
    }
  }
}

View File

@@ -0,0 +1,33 @@
import { ToolSet } from 'ai'
import { AiRequestContext } from '../..'
/**
 * Parse result type.
 * Represents a tool-use intent parsed out of an AI response.
 */
export interface ToolUseResult {
  id: string
  toolName: string
  arguments: any
  status: 'pending' | 'invoking' | 'done' | 'error'
}
/** Base configuration shared by tool-use plugins. */
export interface BaseToolUsePluginConfig {
  enabled?: boolean
}
export interface PromptToolUseConfig extends BaseToolUsePluginConfig {
  // Custom system-prompt builder (optional; a default implementation exists)
  buildSystemPrompt?: (userSystemPrompt: string, tools: ToolSet) => string
  // Custom tool-use parser (optional; a default implementation exists)
  parseToolUse?: (content: string, tools: ToolSet) => { results: ToolUseResult[]; content: string }
  // Optional post-processing of the built system message; returning null drops it
  createSystemMessage?: (systemPrompt: string, originalParams: any, context: AiRequestContext) => string | null
}
/**
 * Extended AI request context carrying the MCP tool set.
 */
export interface ToolUseRequestContext extends AiRequestContext {
  mcpTools: ToolSet
}

View File

@@ -0,0 +1,67 @@
import { anthropic } from '@ai-sdk/anthropic'
import { google } from '@ai-sdk/google'
import { openai } from '@ai-sdk/openai'
import { ProviderOptionsMap } from '../../../options/types'
/**
 * Parameter types extracted from the AI SDK tool factories, to stay type-safe.
 */
type OpenAISearchConfig = Parameters<typeof openai.tools.webSearchPreview>[0]
type AnthropicSearchConfig = Parameters<typeof anthropic.tools.webSearch_20250305>[0]
type GoogleSearchConfig = Parameters<typeof google.tools.googleSearch>[0]
/**
 * Full configuration object received at plugin initialization.
 *
 * Mirrors the shape of ProviderOptions so upstream code can manage
 * configuration uniformly.
 */
export interface WebSearchPluginConfig {
  openai?: OpenAISearchConfig
  anthropic?: AnthropicSearchConfig
  xai?: ProviderOptionsMap['xai']['searchParameters']
  google?: GoogleSearchConfig
  'google-vertex'?: GoogleSearchConfig
}
/**
 * Default plugin configuration.
 */
export const DEFAULT_WEB_SEARCH_CONFIG: WebSearchPluginConfig = {
  google: {},
  'google-vertex': {},
  openai: {},
  xai: {
    mode: 'on',
    returnCitations: true,
    maxSearchResults: 5,
    sources: [{ type: 'web' }, { type: 'x' }, { type: 'news' }]
  },
  anthropic: {
    maxUses: 5
  }
}
/** Output shapes of the provider-native web-search tools. */
export type WebSearchToolOutputSchema = {
  // Anthropic tool — manually defined
  anthropicWebSearch: Array<{
    url: string
    title: string
    pageAge: string | null
    encryptedContent: string
    type: string
  }>
  // OpenAI tool — based on observed output
  openaiWebSearch: {
    status: 'completed' | 'failed'
  }
  // Google tool
  googleSearch: {
    webSearchQueries?: string[]
    groundingChunks?: Array<{
      web?: { uri: string; title: string }
    }>
  }
}

View File

@@ -0,0 +1,69 @@
/**
* Web Search Plugin
* 提供统一的网络搜索能力,支持多个 AI Provider
*/
import { anthropic } from '@ai-sdk/anthropic'
import { google } from '@ai-sdk/google'
import { openai } from '@ai-sdk/openai'
import { createXaiOptions, mergeProviderOptions } from '../../../options'
import { definePlugin } from '../../'
import type { AiRequestContext } from '../../types'
import { DEFAULT_WEB_SEARCH_CONFIG, WebSearchPluginConfig } from './helper'
/**
 * Web search plugin.
 *
 * Injects provider-native web-search capability into the request params,
 * dispatching on the active provider id. Unknown providers pass through.
 *
 * @param config - static configuration supplied at plugin creation time
 */
export const webSearchPlugin = (config: WebSearchPluginConfig = DEFAULT_WEB_SEARCH_CONFIG) =>
  definePlugin({
    name: 'webSearch',
    enforce: 'pre',
    transformParams: async (params: any, context: AiRequestContext) => {
      const { providerId } = context
      switch (providerId) {
        case 'openai': {
          if (config.openai) {
            if (!params.tools) params.tools = {}
            params.tools.web_search_preview = openai.tools.webSearchPreview(config.openai)
          }
          break
        }
        case 'anthropic': {
          if (config.anthropic) {
            if (!params.tools) params.tools = {}
            params.tools.web_search = anthropic.tools.webSearch_20250305(config.anthropic)
          }
          break
        }
        case 'google': {
          // case 'google-vertex':
          // NOTE(review): unlike the other branches this one does not check
          // config.google before registering the tool — confirm whether the
          // asymmetry is intentional.
          if (!params.tools) params.tools = {}
          params.tools.web_search = google.tools.googleSearch(config.google || {})
          break
        }
        case 'xai': {
          if (config.xai) {
            // NOTE(review): mode is forced to 'on' here, overriding any mode
            // supplied in config.xai — verify this is deliberate.
            const searchOptions = createXaiOptions({
              searchParameters: { ...config.xai, mode: 'on' }
            })
            params.providerOptions = mergeProviderOptions(params.providerOptions, searchOptions)
          }
          break
        }
      }
      return params
    }
  })
// 导出类型定义供开发者使用
export type { WebSearchPluginConfig, WebSearchToolOutputSchema } from './helper'
// 默认导出
export default webSearchPlugin

View File

@@ -0,0 +1,32 @@
// 核心类型和接口
export type { AiPlugin, AiRequestContext, HookResult, PluginManagerConfig } from './types'
import type { ProviderId } from '../providers'
import type { AiPlugin, AiRequestContext } from './types'
// 插件管理器
export { PluginManager } from './manager'
/**
 * Utility: create a fresh AI request context for one request.
 * The requestId embeds provider, model, timestamp and a random suffix so it
 * is unique per call; recursiveCall starts as a no-op placeholder until the
 * executor wires in the real implementation.
 */
export function createContext<T extends ProviderId>(
  providerId: T,
  modelId: string,
  originalParams: any
): AiRequestContext {
  const randomSuffix = Math.random().toString(36).slice(2)
  const context: AiRequestContext = {
    providerId,
    modelId,
    originalParams,
    metadata: {},
    startTime: Date.now(),
    requestId: `${providerId}-${modelId}-${Date.now()}-${randomSuffix}`,
    // Placeholder — replaced by the real recursive-call implementation later.
    recursiveCall: () => Promise.resolve(null)
  }
  return context
}
// Plugin builder — an identity helper whose overloads preserve the exact type
// of either a plugin object or a plugin factory, giving authors full inference.
export function definePlugin(plugin: AiPlugin): AiPlugin
export function definePlugin<T extends (...args: any[]) => AiPlugin>(pluginFactory: T): T
export function definePlugin(plugin: AiPlugin | ((...args: any[]) => AiPlugin)) {
  return plugin
}

View File

@@ -0,0 +1,184 @@
import { AiPlugin, AiRequestContext } from './types'
/**
 * Plugin manager.
 *
 * Holds the ordered plugin list (pre -> normal -> post) and provides the
 * execution strategy for each hook category: first-result, sequential
 * transformation, parallel side effects, and stream-transform collection.
 */
export class PluginManager {
  private plugins: AiPlugin[] = []

  constructor(plugins: AiPlugin[] = []) {
    this.plugins = this.sortPlugins(plugins)
  }

  /** Register an additional plugin, re-sorting by enforce order. */
  use(plugin: AiPlugin): this {
    this.plugins = this.sortPlugins([...this.plugins, plugin])
    return this
  }

  /** Remove a plugin by name. */
  remove(pluginName: string): this {
    this.plugins = this.plugins.filter((p) => p.name !== pluginName)
    return this
  }

  /** Stable ordering: enforce 'pre' first, unmarked second, 'post' last. */
  private sortPlugins(plugins: AiPlugin[]): AiPlugin[] {
    const inPhase = (phase?: 'pre' | 'post') => plugins.filter((p) => p.enforce === phase)
    return [...inPhase('pre'), ...inPhase(undefined), ...inPhase('post')]
  }

  /**
   * First-hit hook: return the first non-null/undefined result any plugin yields.
   */
  async executeFirst<T>(
    hookName: 'resolveModel' | 'loadTemplate',
    arg: any,
    context: AiRequestContext
  ): Promise<T | null> {
    for (const plugin of this.plugins) {
      const hook = plugin[hookName]
      if (!hook) continue
      const value = await hook(arg, context)
      // `!= null` rejects both null and undefined.
      if (value != null) {
        return value as T
      }
    }
    return null
  }

  /**
   * Sequential hook: thread a value through every plugin's transformer in order.
   */
  async executeSequential<T>(
    hookName: 'transformParams' | 'transformResult',
    initialValue: T,
    context: AiRequestContext
  ): Promise<T> {
    let current = initialValue
    for (const plugin of this.plugins) {
      const hook = plugin[hookName]
      if (hook) {
        current = await hook<T>(current, context)
      }
    }
    return current
  }

  /** Run every plugin's configureContext hook serially. */
  async executeConfigureContext(context: AiRequestContext): Promise<void> {
    for (const plugin of this.plugins) {
      await plugin.configureContext?.(context)
    }
  }

  /**
   * Parallel hooks: fire side-effect hooks concurrently.
   * Uses Promise.all (not allSettled) so plugin errors propagate to the caller.
   */
  async executeParallel(
    hookName: 'onRequestStart' | 'onRequestEnd' | 'onError',
    context: AiRequestContext,
    result?: any,
    error?: Error
  ): Promise<void> {
    const pending: Promise<any>[] = []
    for (const plugin of this.plugins) {
      const hook = plugin[hookName]
      if (!hook) continue
      if (hookName === 'onError' && error) {
        pending.push(Promise.resolve((hook as any)(error, context)))
      } else if (hookName === 'onRequestEnd' && result !== undefined) {
        pending.push(Promise.resolve((hook as any)(context, result)))
      } else if (hookName === 'onRequestStart') {
        pending.push(Promise.resolve((hook as any)(context)))
      }
    }
    await Promise.all(pending)
  }

  /**
   * Collect every plugin's stream-transform factory (AI SDK native support).
   */
  collectStreamTransforms(params: any, context: AiRequestContext) {
    return this.plugins.flatMap((plugin) =>
      plugin.transformStream ? [plugin.transformStream(params, context)] : []
    )
  }

  /**
   * Snapshot of the current plugin list.
   */
  getPlugins(): AiPlugin[] {
    return [...this.plugins]
  }

  /**
   * Summarize plugin counts by enforce phase and by implemented hook.
   */
  getStats() {
    const hookNames = [
      'resolveModel',
      'loadTemplate',
      'transformParams',
      'transformResult',
      'onRequestStart',
      'onRequestEnd',
      'onError',
      'transformStream'
    ] as const
    const stats = {
      total: this.plugins.length,
      pre: 0,
      normal: 0,
      post: 0,
      hooks: {
        resolveModel: 0,
        loadTemplate: 0,
        transformParams: 0,
        transformResult: 0,
        onRequestStart: 0,
        onRequestEnd: 0,
        onError: 0,
        transformStream: 0
      }
    }
    for (const plugin of this.plugins) {
      // Count by enforce phase.
      if (plugin.enforce === 'pre') stats.pre++
      else if (plugin.enforce === 'post') stats.post++
      else stats.normal++
      // Count implemented hooks.
      for (const hookName of hookNames) {
        if (plugin[hookName]) {
          stats.hooks[hookName]++
        }
      }
    }
    return stats
  }
}

View File

@@ -0,0 +1,79 @@
import type { ImageModelV2 } from '@ai-sdk/provider'
import type { LanguageModel, TextStreamPart, ToolSet } from 'ai'
import { type ProviderId } from '../providers/types'
/**
 * Recursive call function type.
 * `any` is used because the params and return type can differ completely
 * between recursion levels.
 */
export type RecursiveCallFn = (newParams: any) => Promise<any>
/**
 * AI request context shared by all plugin hooks during one request.
 */
export interface AiRequestContext {
  providerId: ProviderId
  modelId: string
  originalParams: any
  metadata: Record<string, any>
  startTime: number
  requestId: string
  recursiveCall: RecursiveCallFn
  isRecursiveCall?: boolean
  mcpTools?: ToolSet
  // Plugins may attach arbitrary extra state
  [key: string]: any
}
/**
 * Plugin hook categories.
 */
export interface AiPlugin {
  name: string
  enforce?: 'pre' | 'post'
  // [First] hooks — only the first plugin returning a value wins
  resolveModel?: (
    modelId: string,
    context: AiRequestContext
  ) => Promise<LanguageModel | ImageModelV2 | null> | LanguageModel | ImageModelV2 | null
  loadTemplate?: (templateName: string, context: AiRequestContext) => any | null | Promise<any | null>
  // [Sequential] hooks — run in order, chaining data transformations
  configureContext?: (context: AiRequestContext) => void | Promise<void>
  transformParams?: <T>(params: T, context: AiRequestContext) => T | Promise<T>
  transformResult?: <T>(result: T, context: AiRequestContext) => T | Promise<T>
  // [Parallel] hooks — order-independent, for side effects
  onRequestStart?: (context: AiRequestContext) => void | Promise<void>
  onRequestEnd?: (context: AiRequestContext, result: any) => void | Promise<void>
  onError?: (error: Error, context: AiRequestContext) => void | Promise<void>
  // [Stream] processing — plugs directly into the AI SDK
  transformStream?: (
    params: any,
    context: AiRequestContext
  ) => <TOOLS extends ToolSet>(options?: {
    tools: TOOLS
    stopStream: () => void
  }) => TransformStream<TextStreamPart<TOOLS>, TextStreamPart<TOOLS>>
  // AI SDK native middleware
  // aiSdkMiddlewares?: LanguageModelV1Middleware[]
}
/**
 * Plugin manager configuration.
 */
export interface PluginManagerConfig {
  plugins: AiPlugin[]
  context: Partial<AiRequestContext>
}
/**
 * Hook execution result.
 */
export interface HookResult<T = any> {
  value: T
  stop?: boolean
}

View File

@@ -0,0 +1,101 @@
/**
* Hub Provider - 支持路由到多个底层provider
*
* 支持格式: hubId:providerId:modelId
* 例如: aihubmix:anthropic:claude-3.5-sonnet
*/
import { ProviderV2 } from '@ai-sdk/provider'
import { customProvider } from 'ai'
import { globalRegistryManagement } from './RegistryManagement'
import type { AiSdkMethodName, AiSdkModelReturn, AiSdkModelType } from './types'
export interface HubProviderConfig {
  /** Unique identifier of the hub */
  hubId: string
  /** Enable debug logging — NOTE(review): not read by createHubProvider below; confirm whether it should be wired up */
  debug?: boolean
}
/**
 * Error raised for hub routing failures, carrying the hub id, the target
 * provider id (when known) and the underlying cause.
 */
export class HubProviderError extends Error {
  constructor(
    message: string,
    public readonly hubId: string,
    public readonly providerId?: string,
    public readonly originalError?: Error
  ) {
    super(message)
    this.name = 'HubProviderError'
  }
}
/**
 * Parse a hub model ID of the form "provider:modelId".
 *
 * GENERALIZED: only the first colon is treated as the separator, so model ids
 * that themselves contain colons (e.g. ollama-style "llama3:8b") are routed
 * correctly instead of being rejected.
 *
 * @throws HubProviderError when no separator is present or either part is empty
 */
function parseHubModelId(modelId: string): { provider: string; actualModelId: string } {
  const separatorIndex = modelId.indexOf(':')
  // Reject: no colon, empty provider (":model"), or empty model id ("provider:").
  if (separatorIndex <= 0 || separatorIndex === modelId.length - 1) {
    throw new HubProviderError(`Invalid hub model ID format. Expected "provider:modelId", got: ${modelId}`, 'unknown')
  }
  return {
    provider: modelId.slice(0, separatorIndex),
    actualModelId: modelId.slice(separatorIndex + 1)
  }
}
/**
 * Create a Hub Provider that routes "provider:modelId" model ids to the
 * matching provider registered in the global registry.
 */
export function createHubProvider(config: HubProviderConfig): ProviderV2 {
  const { hubId } = config
  // Fetch an initialized provider instance from the global registry;
  // all failures are surfaced as HubProviderError.
  function getTargetProvider(providerId: string): ProviderV2 {
    try {
      const provider = globalRegistryManagement.getProvider(providerId)
      if (!provider) {
        throw new HubProviderError(
          `Provider "${providerId}" is not initialized. Please call initializeProvider("${providerId}", options) first.`,
          hubId,
          providerId
        )
      }
      return provider
    } catch (error) {
      // NOTE(review): the HubProviderError thrown just above is re-caught here
      // and wrapped a second time ("Failed to get provider …") — confirm the
      // double wrapping is intended.
      throw new HubProviderError(
        `Failed to get provider "${providerId}": ${error instanceof Error ? error.message : 'Unknown error'}`,
        hubId,
        providerId,
        error instanceof Error ? error : undefined
      )
    }
  }
  // Resolve a model of the requested type on the routed provider.
  function resolveModel<T extends AiSdkModelType>(
    modelId: string,
    modelType: T,
    methodName: AiSdkMethodName<T>
  ): AiSdkModelReturn<T> {
    const { provider, actualModelId } = parseHubModelId(modelId)
    const targetProvider = getTargetProvider(provider)
    const fn = targetProvider[methodName] as (id: string) => AiSdkModelReturn<T>
    if (!fn) {
      throw new HubProviderError(`Provider "${provider}" does not support ${modelType}`, hubId, provider)
    }
    return fn(actualModelId)
  }
  // Every lookup goes through the fallback provider, delegating by model type.
  return customProvider({
    fallbackProvider: {
      languageModel: (modelId: string) => resolveModel(modelId, 'text', 'languageModel'),
      textEmbeddingModel: (modelId: string) => resolveModel(modelId, 'embedding', 'textEmbeddingModel'),
      imageModel: (modelId: string) => resolveModel(modelId, 'image', 'imageModel'),
      transcriptionModel: (modelId: string) => resolveModel(modelId, 'transcription', 'transcriptionModel'),
      speechModel: (modelId: string) => resolveModel(modelId, 'speech', 'speechModel')
    }
  })
}

View File

@@ -0,0 +1,221 @@
/**
* Provider 注册表管理器
* 纯粹的管理功能:存储、检索已配置好的 provider 实例
* 基于 AI SDK 原生的 createProviderRegistry
*/
import { EmbeddingModelV2, ImageModelV2, LanguageModelV2, ProviderV2 } from '@ai-sdk/provider'
import { createProviderRegistry, type ProviderRegistryProvider } from 'ai'
// Map of provider id -> configured AI SDK provider instance.
type PROVIDERS = Record<string, ProviderV2>
// Separator between provider id and model id in registry lookups (e.g. "openai|gpt-4o").
export const DEFAULT_SEPARATOR = '|'
// export type MODEL_ID = `${string}${typeof DEFAULT_SEPARATOR}${string}`
/**
 * Provider registry manager.
 * Pure management layer: stores and retrieves already-configured provider
 * instances and delegates model resolution to the AI SDK's native
 * createProviderRegistry. The registry is rebuilt eagerly on every mutation.
 */
export class RegistryManagement<SEPARATOR extends string = typeof DEFAULT_SEPARATOR> {
  private providers: PROVIDERS = {}
  // Keys in `providers` that are aliases rather than real provider ids.
  private aliases: Set<string> = new Set()
  private separator: SEPARATOR
  private registry: ProviderRegistryProvider<PROVIDERS, SEPARATOR> | null = null

  constructor(options: { separator: SEPARATOR } = { separator: DEFAULT_SEPARATOR as SEPARATOR }) {
    this.separator = options.separator
  }

  /**
   * Register an already-configured provider instance.
   * Optional aliases become additional keys pointing at the same instance.
   */
  registerProvider(id: string, provider: ProviderV2, aliases?: string[]): this {
    // Register the primary provider.
    this.providers[id] = provider
    // All aliases reference the same provider instance.
    if (aliases) {
      aliases.forEach((alias) => {
        this.providers[alias] = provider // store a reference, not a copy
        this.aliases.add(alias) // mark the key as an alias
      })
    }
    this.rebuildRegistry()
    return this
  }

  /** Look up a registered provider instance by id (or alias). */
  getProvider(id: string): ProviderV2 | undefined {
    return this.providers[id]
  }

  /**
   * Bulk-register providers.
   * NOTE(review): unlike registerProvider, this path performs no alias handling.
   */
  registerProviders(providers: Record<string, ProviderV2>): this {
    Object.assign(this.providers, providers)
    this.rebuildRegistry()
    return this
  }

  /**
   * Remove a provider and clean up related aliases.
   * Removing a real id also removes every alias that points at its instance;
   * removing an alias only drops the alias entry.
   */
  unregisterProvider(id: string): this {
    const provider = this.providers[id]
    if (!provider) return this
    if (!this.aliases.has(id)) {
      // Real id: collect and delete all aliases referencing this instance.
      const aliasesToRemove: string[] = []
      this.aliases.forEach((alias) => {
        if (this.providers[alias] === provider) {
          aliasesToRemove.push(alias)
        }
      })
      aliasesToRemove.forEach((alias) => {
        delete this.providers[alias]
        this.aliases.delete(alias)
      })
    } else {
      // Alias: only remove the alias record.
      this.aliases.delete(id)
    }
    delete this.providers[id]
    this.rebuildRegistry()
    return this
  }

  /** Rebuild the underlying AI SDK registry immediately after every change. */
  private rebuildRegistry(): void {
    if (Object.keys(this.providers).length === 0) {
      this.registry = null
      return
    }
    this.registry = createProviderRegistry<PROVIDERS, SEPARATOR>(this.providers, {
      separator: this.separator
    })
  }

  /** Return the built registry or throw when nothing is registered yet. */
  private ensureRegistry(): ProviderRegistryProvider<PROVIDERS, SEPARATOR> {
    if (!this.registry) {
      throw new Error('No providers registered')
    }
    return this.registry
  }

  /** Resolve a language model — native AI SDK method. */
  languageModel(id: `${string}${SEPARATOR}${string}`): LanguageModelV2 {
    return this.ensureRegistry().languageModel(id)
  }

  /** Resolve a text-embedding model — native AI SDK method. */
  textEmbeddingModel(id: `${string}${SEPARATOR}${string}`): EmbeddingModelV2<string> {
    return this.ensureRegistry().textEmbeddingModel(id)
  }

  /** Resolve an image model — native AI SDK method. */
  imageModel(id: `${string}${SEPARATOR}${string}`): ImageModelV2 {
    return this.ensureRegistry().imageModel(id)
  }

  /**
   * Resolve a transcription model — native AI SDK method.
   * Return type is derived from the registry instead of `any`.
   */
  transcriptionModel(
    id: `${string}${SEPARATOR}${string}`
  ): ReturnType<ProviderRegistryProvider<PROVIDERS, SEPARATOR>['transcriptionModel']> {
    return this.ensureRegistry().transcriptionModel(id)
  }

  /**
   * Resolve a speech model — native AI SDK method.
   * Return type is derived from the registry instead of `any`.
   */
  speechModel(
    id: `${string}${SEPARATOR}${string}`
  ): ReturnType<ProviderRegistryProvider<PROVIDERS, SEPARATOR>['speechModel']> {
    return this.ensureRegistry().speechModel(id)
  }

  /** List all registered keys (real ids and aliases). */
  getRegisteredProviders(): string[] {
    return Object.keys(this.providers)
  }

  /** Whether any provider is registered. */
  hasProviders(): boolean {
    return Object.keys(this.providers).length > 0
  }

  /** Drop all providers, aliases and the built registry. */
  clear(): this {
    this.providers = {}
    this.aliases.clear()
    this.registry = null
    return this
  }

  /**
   * Resolve the real provider id (used by getAiSdkProviderId).
   * An alias resolves to its real id; a real id is returned unchanged.
   */
  resolveProviderId(id: string): string {
    if (!this.aliases.has(id)) return id // not an alias — return as-is
    // Alias: find the real id whose instance matches.
    const targetProvider = this.providers[id]
    for (const [realId, provider] of Object.entries(this.providers)) {
      if (provider === targetProvider && !this.aliases.has(realId)) {
        return realId
      }
    }
    return id // no real id found; fall back to the alias itself
  }

  /** Whether the given key is an alias. */
  isAlias(id: string): boolean {
    return this.aliases.has(id)
  }

  /** Map of alias -> real provider id for all registered aliases. */
  getAllAliases(): Record<string, string> {
    const result: Record<string, string> = {}
    this.aliases.forEach((alias) => {
      result[alias] = this.resolveProviderId(alias)
    })
    return result
  }
}
/**
 * Global registry manager instance.
 * Uses '|' as the separator because ':' would clash with suffixes such as ':free'.
 */
export const globalRegistryManagement = new RegistryManagement()

View File

@@ -0,0 +1,632 @@
/**
* 测试真正的 AiProviderRegistry 功能
*/
import { beforeEach, describe, expect, it, vi } from 'vitest'
// Mock the AI SDK provider packages so no real clients are constructed.
vi.mock('@ai-sdk/openai', () => ({
  createOpenAI: vi.fn(() => ({ name: 'openai-mock' }))
}))
vi.mock('@ai-sdk/anthropic', () => ({
  createAnthropic: vi.fn(() => ({ name: 'anthropic-mock' }))
}))
vi.mock('@ai-sdk/azure', () => ({
  createAzure: vi.fn(() => ({ name: 'azure-mock' }))
}))
vi.mock('@ai-sdk/deepseek', () => ({
  createDeepSeek: vi.fn(() => ({ name: 'deepseek-mock' }))
}))
vi.mock('@ai-sdk/google', () => ({
  createGoogleGenerativeAI: vi.fn(() => ({ name: 'google-mock' }))
}))
vi.mock('@ai-sdk/openai-compatible', () => ({
  createOpenAICompatible: vi.fn(() => ({ name: 'openai-compatible-mock' }))
}))
vi.mock('@ai-sdk/xai', () => ({
  createXai: vi.fn(() => ({ name: 'xai-mock' }))
}))
import {
cleanup,
clearAllProviders,
createAndRegisterProvider,
createProvider,
getAllProviderConfigAliases,
getAllProviderConfigs,
getInitializedProviders,
getLanguageModel,
getProviderConfig,
getProviderConfigByAlias,
getSupportedProviders,
hasInitializedProviders,
hasProviderConfig,
hasProviderConfigByAlias,
isProviderConfigAlias,
ProviderInitializationError,
providerRegistry,
registerMultipleProviderConfigs,
registerProvider,
registerProviderConfig,
resolveProviderConfigId
} from '../registry'
import type { ProviderConfig } from '../schemas'
// Functional tests for the provider registry (AI SDK packages are mocked above).
describe('Provider Registry 功能测试', () => {
  beforeEach(() => {
    // Reset registry state before each test
    cleanup()
  })
  describe('基础功能', () => {
    it('能够获取支持的 providers 列表', () => {
      const providers = getSupportedProviders()
      expect(Array.isArray(providers)).toBe(true)
      expect(providers.length).toBeGreaterThan(0)
      // Verify the shape of each returned entry
      providers.forEach((provider) => {
        expect(provider).toHaveProperty('id')
        expect(provider).toHaveProperty('name')
        expect(typeof provider.id).toBe('string')
        expect(typeof provider.name).toBe('string')
      })
      // The list includes the base providers
      const providerIds = providers.map((p) => p.id)
      expect(providerIds).toContain('openai')
      expect(providerIds).toContain('anthropic')
      expect(providerIds).toContain('google')
    })
    it('能够获取已初始化的 providers', () => {
      // No providers are initialized in the initial state
      expect(getInitializedProviders()).toEqual([])
      expect(hasInitializedProviders()).toBe(false)
    })
    it('能够访问全局注册管理器', () => {
      expect(providerRegistry).toBeDefined()
      expect(typeof providerRegistry.clear).toBe('function')
      expect(typeof providerRegistry.getRegisteredProviders).toBe('function')
      expect(typeof providerRegistry.hasProviders).toBe('function')
    })
    it('能够获取语言模型', () => {
      // With no provider registered this may return undefined, but it must not throw
      expect(() => getLanguageModel('non-existent')).not.toThrow()
    })
  })
  describe('Provider 配置注册', () => {
    it('能够注册自定义 provider 配置', () => {
      const config: ProviderConfig = {
        id: 'custom-provider',
        name: 'Custom Provider',
        creator: vi.fn(() => ({ name: 'custom' })),
        supportsImageGeneration: false
      }
      const success = registerProviderConfig(config)
      expect(success).toBe(true)
      expect(hasProviderConfig('custom-provider')).toBe(true)
      expect(getProviderConfig('custom-provider')).toEqual(config)
    })
    it('能够注册带别名的 provider 配置', () => {
      const config: ProviderConfig = {
        id: 'custom-provider-with-aliases',
        name: 'Custom Provider with Aliases',
        creator: vi.fn(() => ({ name: 'custom-aliased' })),
        supportsImageGeneration: false,
        aliases: ['alias-1', 'alias-2']
      }
      const success = registerProviderConfig(config)
      expect(success).toBe(true)
      expect(hasProviderConfigByAlias('alias-1')).toBe(true)
      expect(hasProviderConfigByAlias('alias-2')).toBe(true)
      expect(getProviderConfigByAlias('alias-1')).toEqual(config)
      expect(resolveProviderConfigId('alias-1')).toBe('custom-provider-with-aliases')
    })
    it('拒绝无效的配置', () => {
      // Missing required fields
      const invalidConfig = {
        id: 'invalid-provider'
        // name, creator, etc. are absent
      }
      const success = registerProviderConfig(invalidConfig as any)
      expect(success).toBe(false)
    })
    it('能够批量注册 provider 配置', () => {
      const configs: ProviderConfig[] = [
        {
          id: 'provider-1',
          name: 'Provider 1',
          creator: vi.fn(() => ({ name: 'provider-1' })),
          supportsImageGeneration: false
        },
        {
          id: 'provider-2',
          name: 'Provider 2',
          creator: vi.fn(() => ({ name: 'provider-2' })),
          supportsImageGeneration: true
        },
        {
          id: '', // invalid config
          name: 'Invalid Provider',
          creator: vi.fn(() => ({ name: 'invalid' })),
          supportsImageGeneration: false
        } as any
      ]
      const successCount = registerMultipleProviderConfigs(configs)
      expect(successCount).toBe(2) // only the first two succeed
      expect(hasProviderConfig('provider-1')).toBe(true)
      expect(hasProviderConfig('provider-2')).toBe(true)
      expect(hasProviderConfig('')).toBe(false)
    })
    it('能够获取所有配置和别名信息', () => {
      // Register a config with an alias
      registerProviderConfig({
        id: 'test-provider',
        name: 'Test Provider',
        creator: vi.fn(),
        supportsImageGeneration: false,
        aliases: ['test-alias']
      })
      const allConfigs = getAllProviderConfigs()
      expect(Array.isArray(allConfigs)).toBe(true)
      expect(allConfigs.some((config) => config.id === 'test-provider')).toBe(true)
      const aliases = getAllProviderConfigAliases()
      expect(aliases['test-alias']).toBe('test-provider')
      expect(isProviderConfigAlias('test-alias')).toBe(true)
    })
  })
  describe('Provider 创建和注册', () => {
    it('能够创建 provider 实例', async () => {
      const config: ProviderConfig = {
        id: 'test-create-provider',
        name: 'Test Create Provider',
        creator: vi.fn(() => ({ name: 'test-created' })),
        supportsImageGeneration: false
      }
      // Register the config first
      registerProviderConfig(config)
      // Create the provider instance
      const provider = await createProvider('test-create-provider', { apiKey: 'test' })
      expect(provider).toBeDefined()
      expect(config.creator).toHaveBeenCalledWith({ apiKey: 'test' })
    })
    it('能够注册 provider 到全局管理器', () => {
      const mockProvider = { name: 'mock-provider' }
      const config: ProviderConfig = {
        id: 'test-register-provider',
        name: 'Test Register Provider',
        creator: vi.fn(() => mockProvider),
        supportsImageGeneration: false
      }
      // Register the config first
      registerProviderConfig(config)
      // Register the provider with the global manager
      const success = registerProvider('test-register-provider', mockProvider)
      expect(success).toBe(true)
      // Verify registration succeeded
      const registeredProviders = getInitializedProviders()
      expect(registeredProviders).toContain('test-register-provider')
      expect(hasInitializedProviders()).toBe(true)
    })
    it('能够一步完成创建和注册', async () => {
      const config: ProviderConfig = {
        id: 'test-create-and-register',
        name: 'Test Create and Register',
        creator: vi.fn(() => ({ name: 'test-both' })),
        supportsImageGeneration: false
      }
      // Register the config first
      registerProviderConfig(config)
      // Create and register in one step
      const success = await createAndRegisterProvider('test-create-and-register', { apiKey: 'test' })
      expect(success).toBe(true)
      // Verify registration succeeded
      const registeredProviders = getInitializedProviders()
      expect(registeredProviders).toContain('test-create-and-register')
    })
  })
  describe('Registry 管理', () => {
    it('能够清理所有配置和注册的 providers', () => {
      // Register a config
      registerProviderConfig({
        id: 'temp-provider',
        name: 'Temp Provider',
        creator: vi.fn(() => ({ name: 'temp' })),
        supportsImageGeneration: false
      })
      expect(hasProviderConfig('temp-provider')).toBe(true)
      // Clean up
      cleanup()
      expect(hasProviderConfig('temp-provider')).toBe(false)
      // But base configs should be reloaded
      expect(hasProviderConfig('openai')).toBe(true) // base providers are re-initialized
    })
    it('能够单独清理已注册的 providers', () => {
      // Clear all providers
      clearAllProviders()
      expect(getInitializedProviders()).toEqual([])
      expect(hasInitializedProviders()).toBe(false)
    })
    it('ProviderInitializationError 错误类工作正常', () => {
      const error = new ProviderInitializationError('Test error', 'test-provider')
      expect(error.message).toBe('Test error')
      expect(error.providerId).toBe('test-provider')
      expect(error.name).toBe('ProviderInitializationError')
    })
  })
  describe('错误处理', () => {
    it('优雅处理空配置', () => {
      const success = registerProviderConfig(null as any)
      expect(success).toBe(false)
    })
    it('优雅处理未定义配置', () => {
      const success = registerProviderConfig(undefined as any)
      expect(success).toBe(false)
    })
    it('处理空字符串 ID', () => {
      const config = {
        id: '',
        name: 'Empty ID Provider',
        creator: vi.fn(() => ({ name: 'empty' })),
        supportsImageGeneration: false
      }
      const success = registerProviderConfig(config)
      expect(success).toBe(false)
    })
    it('处理创建不存在配置的 provider', async () => {
      await expect(createProvider('non-existent-provider', {})).rejects.toThrow(
        'ProviderConfig not found for id: non-existent-provider'
      )
    })
    it('处理注册不存在配置的 provider', () => {
      const mockProvider = { name: 'mock' }
      const success = registerProvider('non-existent-provider', mockProvider)
      expect(success).toBe(false)
    })
    it('处理获取不存在配置的情况', () => {
      expect(getProviderConfig('non-existent')).toBeUndefined()
      expect(getProviderConfigByAlias('non-existent-alias')).toBeUndefined()
      expect(hasProviderConfig('non-existent')).toBe(false)
      expect(hasProviderConfigByAlias('non-existent-alias')).toBe(false)
    })
    it('处理批量注册时的部分失败', () => {
      const mixedConfigs: ProviderConfig[] = [
        {
          id: 'valid-provider-1',
          name: 'Valid Provider 1',
          creator: vi.fn(() => ({ name: 'valid-1' })),
          supportsImageGeneration: false
        },
        {
          id: '', // invalid config
          name: 'Invalid Provider',
          creator: vi.fn(() => ({ name: 'invalid' })),
          supportsImageGeneration: false
        } as any,
        {
          id: 'valid-provider-2',
          name: 'Valid Provider 2',
          creator: vi.fn(() => ({ name: 'valid-2' })),
          supportsImageGeneration: true
        }
      ]
      const successCount = registerMultipleProviderConfigs(mixedConfigs)
      expect(successCount).toBe(2) // only the two valid configs succeed
      expect(hasProviderConfig('valid-provider-1')).toBe(true)
      expect(hasProviderConfig('valid-provider-2')).toBe(true)
      expect(hasProviderConfig('')).toBe(false)
    })
    it('处理动态导入失败的情况', async () => {
      const config: ProviderConfig = {
        id: 'import-test-provider',
        name: 'Import Test Provider',
        import: vi.fn().mockRejectedValue(new Error('Import failed')),
        creatorFunctionName: 'createTest',
        supportsImageGeneration: false
      }
      registerProviderConfig(config)
      await expect(createProvider('import-test-provider', {})).rejects.toThrow('Import failed')
    })
  })
  describe('集成测试', () => {
    it('正确处理复杂的配置、创建、注册和清理场景', async () => {
      // Verify the initial state
      const initialConfigs = getAllProviderConfigs()
      expect(initialConfigs.length).toBeGreaterThan(0) // base configs are present
      expect(getInitializedProviders()).toEqual([])
      // Register multiple provider configs with aliases
      const configs: ProviderConfig[] = [
        {
          id: 'integration-provider-1',
          name: 'Integration Provider 1',
          creator: vi.fn(() => ({ name: 'integration-1' })),
          supportsImageGeneration: false,
          aliases: ['alias-1', 'short-name-1']
        },
        {
          id: 'integration-provider-2',
          name: 'Integration Provider 2',
          creator: vi.fn(() => ({ name: 'integration-2' })),
          supportsImageGeneration: true,
          aliases: ['alias-2', 'short-name-2']
        }
      ]
      const successCount = registerMultipleProviderConfigs(configs)
      expect(successCount).toBe(2)
      // Verify config registration succeeded
      expect(hasProviderConfig('integration-provider-1')).toBe(true)
      expect(hasProviderConfig('integration-provider-2')).toBe(true)
      expect(hasProviderConfigByAlias('alias-1')).toBe(true)
      expect(hasProviderConfigByAlias('alias-2')).toBe(true)
      // Verify alias mappings
      const aliases = getAllProviderConfigAliases()
      expect(aliases['alias-1']).toBe('integration-provider-1')
      expect(aliases['alias-2']).toBe('integration-provider-2')
      // Create and register providers
      const success1 = await createAndRegisterProvider('integration-provider-1', { apiKey: 'test1' })
      const success2 = await createAndRegisterProvider('integration-provider-2', { apiKey: 'test2' })
      expect(success1).toBe(true)
      expect(success2).toBe(true)
      // Verify registration succeeded
      const registeredProviders = getInitializedProviders()
      expect(registeredProviders).toContain('integration-provider-1')
      expect(registeredProviders).toContain('integration-provider-2')
      expect(hasInitializedProviders()).toBe(true)
      // Clean up
      cleanup()
      // Verify the post-cleanup state
      expect(getInitializedProviders()).toEqual([])
      expect(hasProviderConfig('integration-provider-1')).toBe(false)
      expect(hasProviderConfig('integration-provider-2')).toBe(false)
      expect(getAllProviderConfigAliases()).toEqual({})
      // Base configs should be reloaded
      expect(hasProviderConfig('openai')).toBe(true)
    })
    it('正确处理动态导入配置的 provider', async () => {
      const mockModule = { createCustomProvider: vi.fn(() => ({ name: 'custom-dynamic' })) }
      const dynamicImportConfig: ProviderConfig = {
        id: 'dynamic-import-provider',
        name: 'Dynamic Import Provider',
        import: vi.fn().mockResolvedValue(mockModule),
        creatorFunctionName: 'createCustomProvider',
        supportsImageGeneration: false
      }
      // Register the config
      const configSuccess = registerProviderConfig(dynamicImportConfig)
      expect(configSuccess).toBe(true)
      // Create and register the provider
      const registerSuccess = await createAndRegisterProvider('dynamic-import-provider', { apiKey: 'test' })
      expect(registerSuccess).toBe(true)
      // Verify the import function was invoked
      expect(dynamicImportConfig.import).toHaveBeenCalled()
      expect(mockModule.createCustomProvider).toHaveBeenCalledWith({ apiKey: 'test' })
      // Verify registration succeeded
      expect(getInitializedProviders()).toContain('dynamic-import-provider')
    })
    it('正确处理大量配置的注册和管理', () => {
      const largeConfigList: ProviderConfig[] = []
      // Generate 50 configs
      for (let i = 0; i < 50; i++) {
        largeConfigList.push({
          id: `bulk-provider-${i}`,
          name: `Bulk Provider ${i}`,
          creator: vi.fn(() => ({ name: `bulk-${i}` })),
          supportsImageGeneration: i % 2 === 0, // even indices support image generation
          aliases: [`alias-${i}`, `short-${i}`]
        })
      }
      const successCount = registerMultipleProviderConfigs(largeConfigList)
      expect(successCount).toBe(50)
      // Verify all configs were registered
      const allConfigs = getAllProviderConfigs()
      expect(allConfigs.filter((config) => config.id.startsWith('bulk-provider-')).length).toBe(50)
      // Verify the alias count
      const aliases = getAllProviderConfigAliases()
      const bulkAliases = Object.keys(aliases).filter(
        (alias) => alias.startsWith('alias-') || alias.startsWith('short-')
      )
      expect(bulkAliases.length).toBe(100) // 2 aliases per provider
      // Spot-check a few configs
      expect(hasProviderConfig('bulk-provider-0')).toBe(true)
      expect(hasProviderConfig('bulk-provider-25')).toBe(true)
      expect(hasProviderConfig('bulk-provider-49')).toBe(true)
      // Verify aliases resolve correctly
      expect(resolveProviderConfigId('alias-25')).toBe('bulk-provider-25')
      expect(isProviderConfigAlias('short-30')).toBe(true)
      // Cleanup handles large data sets correctly
      cleanup()
      const cleanupAliases = getAllProviderConfigAliases()
      expect(
        Object.keys(cleanupAliases).filter((alias) => alias.startsWith('alias-') || alias.startsWith('short-'))
      ).toEqual([])
    })
  })
  describe('边界测试', () => {
    it('处理包含特殊字符的 provider IDs', () => {
      const specialCharsConfigs: ProviderConfig[] = [
        {
          id: 'provider-with-dashes',
          name: 'Provider With Dashes',
          creator: vi.fn(() => ({ name: 'dashes' })),
          supportsImageGeneration: false
        },
        {
          id: 'provider_with_underscores',
          name: 'Provider With Underscores',
          creator: vi.fn(() => ({ name: 'underscores' })),
          supportsImageGeneration: false
        },
        {
          id: 'provider.with.dots',
          name: 'Provider With Dots',
          creator: vi.fn(() => ({ name: 'dots' })),
          supportsImageGeneration: false
        }
      ]
      const successCount = registerMultipleProviderConfigs(specialCharsConfigs)
      expect(successCount).toBeGreaterThan(0) // at least some should succeed
      // Verify the supported special-character formats
      if (hasProviderConfig('provider-with-dashes')) {
        expect(getProviderConfig('provider-with-dashes')).toBeDefined()
      }
      if (hasProviderConfig('provider_with_underscores')) {
        expect(getProviderConfig('provider_with_underscores')).toBeDefined()
      }
    })
    it('处理空的批量注册', () => {
      const successCount = registerMultipleProviderConfigs([])
      expect(successCount).toBe(0)
      // Ensure no extra configs were added
      const configsBefore = getAllProviderConfigs().length
      expect(configsBefore).toBeGreaterThan(0) // base configs should exist
    })
    it('处理重复的配置注册', () => {
      const config: ProviderConfig = {
        id: 'duplicate-test-provider',
        name: 'Duplicate Test Provider',
        creator: vi.fn(() => ({ name: 'duplicate' })),
        supportsImageGeneration: false
      }
      // First registration succeeds
      expect(registerProviderConfig(config)).toBe(true)
      expect(hasProviderConfig('duplicate-test-provider')).toBe(true)
      // Re-register the same id (overwrite is allowed)
      const updatedConfig: ProviderConfig = {
        ...config,
        name: 'Updated Duplicate Test Provider'
      }
      expect(registerProviderConfig(updatedConfig)).toBe(true)
      expect(hasProviderConfig('duplicate-test-provider')).toBe(true)
      // Verify the config was updated
      const retrievedConfig = getProviderConfig('duplicate-test-provider')
      expect(retrievedConfig?.name).toBe('Updated Duplicate Test Provider')
    })
    it('处理极长的 ID 和名称', () => {
      const longId = 'very-long-provider-id-' + 'x'.repeat(100)
      const longName = 'Very Long Provider Name ' + 'Y'.repeat(100)
      const config: ProviderConfig = {
        id: longId,
        name: longName,
        creator: vi.fn(() => ({ name: 'long-test' })),
        supportsImageGeneration: false
      }
      const success = registerProviderConfig(config)
      expect(success).toBe(true)
      expect(hasProviderConfig(longId)).toBe(true)
      const retrievedConfig = getProviderConfig(longId)
      expect(retrievedConfig?.name).toBe(longName)
    })
    it('处理大量别名的配置', () => {
      const manyAliases = Array.from({ length: 50 }, (_, i) => `alias-${i}`)
      const config: ProviderConfig = {
        id: 'provider-with-many-aliases',
        name: 'Provider With Many Aliases',
        creator: vi.fn(() => ({ name: 'many-aliases' })),
        supportsImageGeneration: false,
        aliases: manyAliases
      }
      const success = registerProviderConfig(config)
      expect(success).toBe(true)
      // Every alias resolves correctly
      manyAliases.forEach((alias) => {
        expect(hasProviderConfigByAlias(alias)).toBe(true)
        expect(resolveProviderConfigId(alias)).toBe('provider-with-many-aliases')
        expect(isProviderConfigAlias(alias)).toBe(true)
      })
    })
  })
})

View File

@@ -0,0 +1,264 @@
import { describe, expect, it, vi } from 'vitest'
import {
type BaseProviderId,
baseProviderIds,
baseProviderIdSchema,
baseProviders,
type CustomProviderId,
customProviderIdSchema,
providerConfigSchema,
type ProviderId,
providerIdSchema
} from '../schemas'
// Schema and type tests for the provider package (pure validation, no mocks required).
describe('Provider Schemas', () => {
  describe('baseProviders', () => {
    it('包含所有预期的基础 providers', () => {
      expect(baseProviders).toBeDefined()
      expect(Array.isArray(baseProviders)).toBe(true)
      expect(baseProviders.length).toBeGreaterThan(0)
      const expectedIds = [
        'openai',
        'openai-responses',
        'openai-compatible',
        'anthropic',
        'google',
        'xai',
        'azure',
        'deepseek'
      ]
      const actualIds = baseProviders.map((p) => p.id)
      expectedIds.forEach((id) => {
        expect(actualIds).toContain(id)
      })
    })
    it('每个基础 provider 有必要的属性', () => {
      baseProviders.forEach((provider) => {
        expect(provider).toHaveProperty('id')
        expect(provider).toHaveProperty('name')
        expect(provider).toHaveProperty('creator')
        expect(provider).toHaveProperty('supportsImageGeneration')
        expect(typeof provider.id).toBe('string')
        expect(typeof provider.name).toBe('string')
        expect(typeof provider.creator).toBe('function')
        expect(typeof provider.supportsImageGeneration).toBe('boolean')
      })
    })
    it('provider ID 是唯一的', () => {
      const ids = baseProviders.map((p) => p.id)
      const uniqueIds = [...new Set(ids)]
      expect(ids).toEqual(uniqueIds)
    })
  })
  describe('baseProviderIds', () => {
    it('正确提取所有基础 provider IDs', () => {
      expect(baseProviderIds).toBeDefined()
      expect(Array.isArray(baseProviderIds)).toBe(true)
      expect(baseProviderIds.length).toBe(baseProviders.length)
      baseProviders.forEach((provider) => {
        expect(baseProviderIds).toContain(provider.id)
      })
    })
  })
  describe('baseProviderIdSchema', () => {
    it('验证有效的基础 provider IDs', () => {
      baseProviderIds.forEach((id) => {
        expect(baseProviderIdSchema.safeParse(id).success).toBe(true)
      })
    })
    it('拒绝无效的基础 provider IDs', () => {
      const invalidIds = ['invalid', 'not-exists', '']
      invalidIds.forEach((id) => {
        expect(baseProviderIdSchema.safeParse(id).success).toBe(false)
      })
    })
  })
  describe('customProviderIdSchema', () => {
    it('接受有效的自定义 provider IDs', () => {
      const validIds = ['custom-provider', 'my-ai-service', 'company-llm-v2']
      validIds.forEach((id) => {
        expect(customProviderIdSchema.safeParse(id).success).toBe(true)
      })
    })
    it('拒绝与基础 provider IDs 冲突的 IDs', () => {
      baseProviderIds.forEach((id) => {
        expect(customProviderIdSchema.safeParse(id).success).toBe(false)
      })
    })
    it('拒绝空字符串', () => {
      expect(customProviderIdSchema.safeParse('').success).toBe(false)
    })
  })
  describe('providerIdSchema', () => {
    it('接受基础 provider IDs', () => {
      baseProviderIds.forEach((id) => {
        expect(providerIdSchema.safeParse(id).success).toBe(true)
      })
    })
    it('接受有效的自定义 provider IDs', () => {
      const validCustomIds = ['custom-provider', 'my-ai-service']
      validCustomIds.forEach((id) => {
        expect(providerIdSchema.safeParse(id).success).toBe(true)
      })
    })
    it('拒绝无效的 IDs', () => {
      const invalidIds = ['', undefined, null, 123]
      invalidIds.forEach((id) => {
        expect(providerIdSchema.safeParse(id).success).toBe(false)
      })
    })
  })
  describe('providerConfigSchema', () => {
    it('验证带有 creator 的有效配置', () => {
      const validConfig = {
        id: 'custom-provider',
        name: 'Custom Provider',
        creator: vi.fn(),
        supportsImageGeneration: true
      }
      expect(providerConfigSchema.safeParse(validConfig).success).toBe(true)
    })
    it('验证带有 import 配置的有效配置', () => {
      const validConfig = {
        id: 'custom-provider',
        name: 'Custom Provider',
        import: vi.fn(),
        creatorFunctionName: 'createCustom',
        supportsImageGeneration: false
      }
      expect(providerConfigSchema.safeParse(validConfig).success).toBe(true)
    })
    it('拒绝既没有 creator 也没有 import 配置的配置', () => {
      const invalidConfig = {
        id: 'invalid',
        name: 'Invalid Provider',
        supportsImageGeneration: false
      }
      expect(providerConfigSchema.safeParse(invalidConfig).success).toBe(false)
    })
    it('为 supportsImageGeneration 设置默认值', () => {
      const config = {
        id: 'test',
        name: 'Test',
        creator: vi.fn()
      }
      const result = providerConfigSchema.safeParse(config)
      expect(result.success).toBe(true)
      if (result.success) {
        expect(result.data.supportsImageGeneration).toBe(false)
      }
    })
    it('拒绝使用基础 provider ID 的配置', () => {
      const invalidConfig = {
        id: 'openai', // base provider id
        name: 'Should Fail',
        creator: vi.fn()
      }
      expect(providerConfigSchema.safeParse(invalidConfig).success).toBe(false)
    })
    it('拒绝缺少必需字段的配置', () => {
      const invalidConfigs = [
        { name: 'Missing ID', creator: vi.fn() },
        { id: 'missing-name', creator: vi.fn() },
        { id: '', name: 'Empty ID', creator: vi.fn() },
        { id: 'valid-custom', name: '', creator: vi.fn() }
      ]
      invalidConfigs.forEach((config) => {
        expect(providerConfigSchema.safeParse(config).success).toBe(false)
      })
    })
  })
  describe('Schema 验证功能', () => {
    it('baseProviderIdSchema 正确验证基础 provider IDs', () => {
      baseProviderIds.forEach((id) => {
        expect(baseProviderIdSchema.safeParse(id).success).toBe(true)
      })
      expect(baseProviderIdSchema.safeParse('invalid-id').success).toBe(false)
    })
    it('customProviderIdSchema 正确验证自定义 provider IDs', () => {
      const customIds = ['custom-provider', 'my-service', 'company-llm']
      customIds.forEach((id) => {
        expect(customProviderIdSchema.safeParse(id).success).toBe(true)
      })
      // Rejects base provider IDs
      baseProviderIds.forEach((id) => {
        expect(customProviderIdSchema.safeParse(id).success).toBe(false)
      })
    })
    it('providerIdSchema 接受基础和自定义 provider IDs', () => {
      // Base IDs
      baseProviderIds.forEach((id) => {
        expect(providerIdSchema.safeParse(id).success).toBe(true)
      })
      // Custom IDs
      const customIds = ['custom-provider', 'my-service']
      customIds.forEach((id) => {
        expect(providerIdSchema.safeParse(id).success).toBe(true)
      })
    })
    it('providerConfigSchema 验证完整的 provider 配置', () => {
      const validConfig = {
        id: 'custom-provider',
        name: 'Custom Provider',
        creator: vi.fn(),
        supportsImageGeneration: true
      }
      expect(providerConfigSchema.safeParse(validConfig).success).toBe(true)
      const invalidConfig = {
        id: 'openai', // base provider IDs are not allowed
        name: 'OpenAI',
        creator: vi.fn()
      }
      expect(providerConfigSchema.safeParse(invalidConfig).success).toBe(false)
    })
  })
  describe('类型推导', () => {
    it('BaseProviderId 类型正确', () => {
      const id: BaseProviderId = 'openai'
      expect(baseProviderIds).toContain(id)
    })
    it('CustomProviderId 类型是字符串', () => {
      const id: CustomProviderId = 'custom-provider'
      expect(typeof id).toBe('string')
    })
    it('ProviderId 类型支持基础和自定义 IDs', () => {
      const baseId: ProviderId = 'openai'
      const customId: ProviderId = 'custom-provider'
      expect(typeof baseId).toBe('string')
      expect(typeof customId).toBe('string')
    })
  })
})

View File

@@ -0,0 +1,291 @@
/**
 * AI provider configuration factory.
 * Type-safe builders for constructing provider settings objects.
 */
import type { ProviderId, ProviderSettingsMap } from './types'
/**
 * Base configuration shared by every provider.
 */
export interface BaseProviderConfig {
  apiKey?: string
  baseURL?: string
  timeout?: number
  headers?: Record<string, string>
  // Custom fetch implementation — presumably for proxying/interception; confirm at call sites.
  fetch?: typeof globalThis.fetch
}
/**
 * Complete configuration type: base config combined with the
 * provider-specific settings from ProviderSettingsMap.
 */
type CompleteProviderConfig<T extends ProviderId> = BaseProviderConfig & Partial<ProviderSettingsMap[T]>
// Hook that applies provider-specific fields from a raw provider object onto a builder.
type ConfigHandler<T extends ProviderId> = (
  builder: ProviderConfigBuilder<T>,
  provider: CompleteProviderConfig<T>
) => void
// Per-provider special-case handlers; currently only azure needs extra mapping.
const configHandlers: {
  [K in ProviderId]?: ConfigHandler<K>
} = {
  azure: (builder, provider) => {
    const azureBuilder = builder as ProviderConfigBuilder<'azure'>
    const azureProvider = provider as CompleteProviderConfig<'azure'>
    azureBuilder.withAzureConfig({
      apiVersion: azureProvider.apiVersion,
      resourceName: azureProvider.resourceName
    })
  }
}
/**
 * Fluent builder for a single provider's settings object.
 * Collects base fields (apiKey, baseURL, headers, fetch) plus
 * provider-specific fields, then emits ProviderSettingsMap[T] via build().
 */
export class ProviderConfigBuilder<T extends ProviderId = ProviderId> {
  private config: CompleteProviderConfig<T> = {} as CompleteProviderConfig<T>
  constructor(private providerId: T) {}
  /**
   * Set the API key.
   * The two-argument overload is only valid when T is 'openai'.
   */
  withApiKey(apiKey: string): this
  withApiKey(apiKey: string, options: T extends 'openai' ? { organization?: string; project?: string } : never): this
  withApiKey(apiKey: string, options?: any): this {
    this.config.apiKey = apiKey
    // OpenAI-specific organization/project settings (runtime-guarded)
    if (this.providerId === 'openai' && options) {
      const openaiConfig = this.config as CompleteProviderConfig<'openai'>
      if (options.organization) {
        openaiConfig.organization = options.organization
      }
      if (options.project) {
        openaiConfig.project = options.project
      }
    }
    return this
  }
  /**
   * Set the base URL.
   */
  withBaseURL(baseURL: string) {
    this.config.baseURL = baseURL
    return this
  }
  /**
   * Set request-level options: extra headers (merged) and/or a custom fetch.
   */
  withRequestConfig(options: { headers?: Record<string, string>; fetch?: typeof fetch }): this {
    if (options.headers) {
      this.config.headers = { ...this.config.headers, ...options.headers }
    }
    if (options.fetch) {
      this.config.fetch = options.fetch
    }
    return this
  }
  /**
   * Azure OpenAI specific configuration; a silent no-op for other providers.
   */
  withAzureConfig(options: { apiVersion?: string; resourceName?: string }): T extends 'azure' ? this : never
  withAzureConfig(options: any): any {
    if (this.providerId === 'azure') {
      const azureConfig = this.config as CompleteProviderConfig<'azure'>
      if (options.apiVersion) {
        azureConfig.apiVersion = options.apiVersion
      }
      if (options.resourceName) {
        azureConfig.resourceName = options.resourceName
      }
    }
    return this
  }
  /**
   * Merge arbitrary custom parameters into the config (may overwrite earlier fields).
   */
  withCustomParams(params: Record<string, any>) {
    Object.assign(this.config, params)
    return this
  }
  /**
   * Build the final settings object for this provider.
   */
  build(): ProviderSettingsMap[T] {
    return this.config as ProviderSettingsMap[T]
  }
}
/**
* Provider 配置工厂
* 提供便捷的配置创建方法
*/
export class ProviderConfigFactory {
/**
* 创建配置构建器
*/
static builder<T extends ProviderId>(providerId: T): ProviderConfigBuilder<T> {
return new ProviderConfigBuilder(providerId)
}
/**
* 从通用Provider对象创建配置 - 使用更优雅的处理器模式
*/
static fromProvider<T extends ProviderId>(
providerId: T,
provider: CompleteProviderConfig<T>,
options?: {
headers?: Record<string, string>
[key: string]: any
}
): ProviderSettingsMap[T] {
const builder = new ProviderConfigBuilder<T>(providerId)
// 设置基本配置
if (provider.apiKey) {
builder.withApiKey(provider.apiKey)
}
if (provider.baseURL) {
builder.withBaseURL(provider.baseURL)
}
// 设置请求配置
if (options?.headers) {
builder.withRequestConfig({
headers: options.headers
})
}
// 使用配置处理器模式 - 更加优雅和可扩展
const handler = configHandlers[providerId]
if (handler) {
handler(builder, provider)
}
// 添加其他自定义参数
if (options) {
const customOptions = { ...options }
delete customOptions.headers // 已经处理过了
if (Object.keys(customOptions).length > 0) {
builder.withCustomParams(customOptions)
}
}
return builder.build()
}
/**
* 快速创建 OpenAI 配置
*/
static createOpenAI(
apiKey: string,
options?: {
baseURL?: string
organization?: string
project?: string
}
) {
const builder = this.builder('openai')
// 使用类型安全的重载
if (options?.organization || options?.project) {
builder.withApiKey(apiKey, {
organization: options.organization,
project: options.project
})
} else {
builder.withApiKey(apiKey)
}
return builder.withBaseURL(options?.baseURL || 'https://api.openai.com').build()
}
/**
* 快速创建 Anthropic 配置
*/
static createAnthropic(
apiKey: string,
options?: {
baseURL?: string
}
) {
return this.builder('anthropic')
.withApiKey(apiKey)
.withBaseURL(options?.baseURL || 'https://api.anthropic.com')
.build()
}
/**
* 快速创建 Azure OpenAI 配置
*/
static createAzureOpenAI(
apiKey: string,
options: {
baseURL: string
apiVersion?: string
resourceName?: string
}
) {
return this.builder('azure')
.withApiKey(apiKey)
.withBaseURL(options.baseURL)
.withAzureConfig({
apiVersion: options.apiVersion,
resourceName: options.resourceName
})
.build()
}
/**
* 快速创建 Google 配置
*/
static createGoogle(
apiKey: string,
options?: {
baseURL?: string
projectId?: string
location?: string
}
) {
return this.builder('google')
.withApiKey(apiKey)
.withBaseURL(options?.baseURL || 'https://generativelanguage.googleapis.com')
.build()
}
/**
 * Quickly create a Vertex AI configuration.
 * NOTE(review): the implementation is currently disabled — this method takes
 * no arguments and returns undefined. The commented-out code below shows the
 * intended credential/options shape; confirm before re-enabling.
 */
static createVertexAI() {
// credentials: {
// clientEmail: string
// privateKey: string
// },
// options?: {
// project?: string
// location?: string
// }
// return this.builder('google-vertex')
// .withGoogleCredentials(credentials)
// .withGoogleVertexConfig({
// project: options?.project,
// location: options?.location
// })
// .build()
}
/**
 * Quickly create a configuration for an OpenAI-compatible endpoint.
 */
static createOpenAICompatible(baseURL: string, apiKey: string) {
  const builder = this.builder('openai-compatible')
  builder.withBaseURL(baseURL)
  builder.withApiKey(apiKey)
  return builder.build()
}
}
/**
 * Convenience configuration-creation helpers, re-exported as free functions.
 */
export const createProviderConfig = ProviderConfigFactory.fromProvider
export const providerConfigBuilder = ProviderConfigFactory.builder

View File

@@ -0,0 +1,83 @@
/**
* Providers 模块统一导出 - 独立Provider包
*/
// ==================== 核心管理器 ====================
// Provider 注册表管理器
export { globalRegistryManagement, RegistryManagement } from './RegistryManagement'
// Provider 核心功能
export {
// 状态管理
cleanup,
clearAllProviders,
createAndRegisterProvider,
createProvider,
getAllProviderConfigAliases,
getAllProviderConfigs,
getImageModel,
// 工具函数
getInitializedProviders,
getLanguageModel,
getProviderConfig,
getProviderConfigByAlias,
getSupportedProviders,
getTextEmbeddingModel,
hasInitializedProviders,
// 工具函数
hasProviderConfig,
// 别名支持
hasProviderConfigByAlias,
isProviderConfigAlias,
// 错误类型
ProviderInitializationError,
// 全局访问
providerRegistry,
registerMultipleProviderConfigs,
registerProvider,
// 统一Provider系统
registerProviderConfig,
resolveProviderConfigId
} from './registry'
// ==================== 基础数据和类型 ====================
// 基础Provider数据源
export { baseProviderIds, baseProviders } from './schemas'
// 类型定义和Schema
export type {
BaseProviderId,
CustomProviderId,
DynamicProviderRegistration,
ProviderConfig,
ProviderId
} from './schemas' // 从 schemas 导出的类型
export { baseProviderIdSchema, customProviderIdSchema, providerConfigSchema, providerIdSchema } from './schemas' // Schema 导出
export type {
DynamicProviderRegistry,
ExtensibleProviderSettingsMap,
ProviderError,
ProviderSettingsMap,
ProviderTypeRegistrar
} from './types'
// ==================== 工具函数 ====================
// Provider配置工厂
export {
type BaseProviderConfig,
createProviderConfig,
ProviderConfigBuilder,
providerConfigBuilder,
ProviderConfigFactory
} from './factory'
// 工具函数
export { formatPrivateKey } from './utils'
// ==================== 扩展功能 ====================
// Hub Provider 功能
export { createHubProvider, type HubProviderConfig, HubProviderError } from './HubProvider'

View File

@@ -0,0 +1,310 @@
/**
* Provider 初始化器
* 负责根据配置创建 providers 并注册到全局管理器
* 集成了来自 ModelCreator 的特殊处理逻辑
*/
import { customProvider } from 'ai'
import { globalRegistryManagement } from './RegistryManagement'
import { baseProviders, type ProviderConfig } from './schemas'
/**
 * Error raised when a provider cannot be created or registered.
 */
class ProviderInitializationError extends Error {
  /**
   * @param message - human-readable description of the failure
   * @param providerId - id of the provider involved, when known
   * @param cause - underlying error, when available
   */
  constructor(
    message: string,
    public providerId?: string,
    public cause?: Error
  ) {
    super(message)
    this.name = 'ProviderInitializationError'
    // Trim this constructor frame from V8 stack traces, consistent with
    // ImageGenerationError in the runtime errors module.
    if (Error.captureStackTrace) {
      Error.captureStackTrace(this, ProviderInitializationError)
    }
  }
}
// ==================== Global registry export ====================
export { globalRegistryManagement as providerRegistry }
// ==================== Convenience accessors ====================
// Thin wrappers over the global registry manager. The `as any` casts bypass
// the registry's typed id parameter; ids are fully-qualified
// "provider:model" strings (see the generateImage tests).
export const getLanguageModel = (id: string) => globalRegistryManagement.languageModel(id as any)
export const getTextEmbeddingModel = (id: string) => globalRegistryManagement.textEmbeddingModel(id as any)
export const getImageModel = (id: string) => globalRegistryManagement.imageModel(id as any)
// ==================== Utility functions ====================
/**
 * List the built-in providers as plain { id, name } descriptors.
 */
export function getSupportedProviders(): Array<{
  id: string
  name: string
}> {
  return baseProviders.map(({ id, name }) => ({ id, name }))
}
/**
 * Get the ids of all providers that have been initialized (registered with
 * the global registry manager).
 */
export function getInitializedProviders(): string[] {
return globalRegistryManagement.getRegisteredProviders()
}
/**
 * Whether at least one provider has been initialized.
 */
export function hasInitializedProviders(): boolean {
return globalRegistryManagement.hasProviders()
}
// ==================== Unified provider-config system ====================
// Global store of provider configurations (built-in and user-registered).
const providerConfigs = new Map<string, ProviderConfig>()
// Global alias map for provider configs, mirroring the RegistryManagement
// pattern.
const providerConfigAliases = new Map<string, string>() // alias -> realId
/**
 * Seed the store with built-in configs: converts each baseProviders entry to
 * the unified ProviderConfig shape.
 */
function initializeBuiltInConfigs(): void {
baseProviders.forEach((provider) => {
const config: ProviderConfig = {
id: provider.id,
name: provider.name,
creator: provider.creator as any, // cast to tolerate the differing creator signatures
supportsImageGeneration: provider.supportsImageGeneration || false
}
providerConfigs.set(provider.id, config)
})
}
// Register the built-in configs at module load time.
initializeBuiltInConfigs()
/**
 * Step 1: register a provider configuration. Only stores the config (and its
 * alias mappings); no provider instance is created here.
 * @returns false when the config lacks an id/name or storing fails
 */
export function registerProviderConfig(config: ProviderConfig): boolean {
  try {
    // Reject configs missing the required identity fields.
    if (!config.id || !config.name) {
      return false
    }
    // Overriding an existing config (built-in included) is allowed but noisy.
    if (providerConfigs.has(config.id)) {
      console.warn(`ProviderConfig "${config.id}" already exists, will override`)
    }
    providerConfigs.set(config.id, config)
    // Map every alias to the real id, warning on collisions.
    for (const alias of config.aliases ?? []) {
      if (providerConfigAliases.has(alias)) {
        console.warn(`ProviderConfig alias "${alias}" already exists, will override`)
      }
      providerConfigAliases.set(alias, config.id)
    }
    return true
  } catch (error) {
    console.error(`Failed to register ProviderConfig:`, error)
    return false
  }
}
/**
 * Step 2: create a provider instance from a registered configuration.
 * The config is looked up by real id or alias; its creator is either taken
 * directly from the config or resolved via dynamic import, then invoked with
 * the caller-supplied options.
 * @throws when no config exists or no usable creator can be resolved
 */
export async function createProvider(providerId: string, options: any): Promise<any> {
  const config = getProviderConfigByAlias(providerId)
  if (!config) {
    throw new Error(`ProviderConfig not found for id: ${providerId}`)
  }
  try {
    let creator: (options: any) => any
    if (config.creator) {
      // Path 1: creator supplied directly on the config.
      creator = config.creator
    } else if (config.import && config.creatorFunctionName) {
      // Path 2: dynamically import the module and pick the named export.
      const imported = await config.import()
      const candidate = (imported as any)[config.creatorFunctionName]
      if (typeof candidate !== 'function') {
        throw new Error(`Creator function "${config.creatorFunctionName}" not found in imported module`)
      }
      creator = candidate
    } else {
      throw new Error('No valid creator method provided in ProviderConfig')
    }
    // Instantiate the provider with the real options.
    return creator(options)
  } catch (error) {
    console.error(`Failed to create provider "${providerId}":`, error)
    throw error
  }
}
/**
 * Step 3: register a created provider instance with the global manager.
 * 'openai' is special-cased: besides the default registration, an
 * 'openai-chat' variant is registered whose languageModel routes through the
 * provider's chat() API.
 * @returns false when no config exists for the id or registration throws
 */
export function registerProvider(providerId: string, provider: any): boolean {
  try {
    const config = providerConfigs.get(providerId)
    if (!config) {
      console.error(`ProviderConfig not found for id: ${providerId}`)
      return false
    }
    const aliases = config.aliases
    if (providerId !== 'openai') {
      // Common case: register directly (with any aliases).
      globalRegistryManagement.registerProvider(providerId, provider, aliases)
      return true
    }
    // Register the default openai provider ...
    globalRegistryManagement.registerProvider('openai', provider, aliases)
    // ... plus the openai-chat variant built on top of it.
    const chatVariant = customProvider({
      fallbackProvider: {
        ...provider,
        languageModel: (modelId: string) => provider.chat(modelId)
      }
    })
    globalRegistryManagement.registerProvider('openai-chat', chatVariant)
    return true
  } catch (error) {
    console.error(`Failed to register provider "${providerId}" to global registry:`, error)
    return false
  }
}
/**
 * Convenience helper: run step 2 (create) and step 3 (register) in one call.
 * @returns false when either step fails
 */
export async function createAndRegisterProvider(providerId: string, options: any): Promise<boolean> {
  try {
    const instance = await createProvider(providerId, options)
    return registerProvider(providerId, instance)
  } catch (error) {
    console.error(`Failed to create and register provider "${providerId}":`, error)
    return false
  }
}
/**
 * Register a batch of provider configurations.
 * @returns the number of configs that registered successfully
 */
export function registerMultipleProviderConfigs(configs: ProviderConfig[]): number {
  return configs.filter((config) => registerProviderConfig(config)).length
}
/**
 * Whether a provider configuration exists for the exact id.
 * Aliases are NOT resolved here — use hasProviderConfigByAlias for that.
 */
export function hasProviderConfig(providerId: string): boolean {
return providerConfigs.has(providerId)
}
/**
 * Whether a provider configuration exists for the given alias or real id.
 */
export function hasProviderConfigByAlias(aliasOrId: string): boolean {
const realId = resolveProviderConfigId(aliasOrId)
return providerConfigs.has(realId)
}
/**
 * Get every stored provider configuration (built-in and user-registered).
 */
export function getAllProviderConfigs(): ProviderConfig[] {
return Array.from(providerConfigs.values())
}
/**
 * Get a provider configuration by its exact id (no alias resolution).
 */
export function getProviderConfig(providerId: string): ProviderConfig | undefined {
return providerConfigs.get(providerId)
}
/**
* 通过别名或ID获取Provider配置
*/
export function getProviderConfigByAlias(aliasOrId: string): ProviderConfig | undefined {
// 先检查是否为别名如果是则解析为真实ID
const realId = providerConfigAliases.get(aliasOrId) || aliasOrId
return providerConfigs.get(realId)
}
/**
 * Resolve an alias (or pass through a real id) to the real ProviderConfig id.
 */
export function resolveProviderConfigId(aliasOrId: string): string {
return providerConfigAliases.get(aliasOrId) || aliasOrId
}
/**
 * Whether the given id is a registered ProviderConfig alias.
 */
export function isProviderConfigAlias(id: string): boolean {
return providerConfigAliases.has(id)
}
/**
 * Get all alias mappings as a plain object (alias -> real id).
 */
export function getAllProviderConfigAliases(): Record<string, string> {
  // Object.fromEntries consumes the Map's [alias, realId] entries directly,
  // replacing the hand-rolled forEach accumulation.
  return Object.fromEntries(providerConfigAliases)
}
/**
 * Reset all provider state: drop every stored config and alias, clear the
 * global registry, then re-seed the built-in configs so the module stays
 * usable after cleanup.
 */
export function cleanup(): void {
providerConfigs.clear()
providerConfigAliases.clear() // also drop the alias mappings
globalRegistryManagement.clear()
// Restore the built-in configs.
initializeBuiltInConfigs()
}
// Clear only the registered provider instances; stored configs are kept.
export function clearAllProviders(): void {
globalRegistryManagement.clear()
}
// ==================== 导出错误类型 ====================
export { ProviderInitializationError }

View File

@@ -0,0 +1,132 @@
/**
* Provider Config 定义
*/
import { createAnthropic } from '@ai-sdk/anthropic'
import { createAzure } from '@ai-sdk/azure'
import { createDeepSeek } from '@ai-sdk/deepseek'
import { createGoogleGenerativeAI } from '@ai-sdk/google'
import { createOpenAI, type OpenAIProviderSettings } from '@ai-sdk/openai'
import { createOpenAICompatible } from '@ai-sdk/openai-compatible'
import { createXai } from '@ai-sdk/xai'
import * as z from 'zod'
/**
 * Built-in provider definitions.
 * Single source of truth — ids, schemas, and settings types are all derived
 * from this list to avoid duplicate maintenance.
 */
export const baseProviders = [
{
id: 'openai',
name: 'OpenAI',
creator: createOpenAI,
supportsImageGeneration: true
},
{
id: 'openai-responses',
name: 'OpenAI Responses',
// Wraps the OpenAI provider and exposes its Responses API surface.
creator: (options: OpenAIProviderSettings) => createOpenAI(options).responses,
supportsImageGeneration: true
},
{
id: 'openai-compatible',
name: 'OpenAI Compatible',
creator: createOpenAICompatible,
supportsImageGeneration: true
},
{
id: 'anthropic',
name: 'Anthropic',
creator: createAnthropic,
supportsImageGeneration: false
},
{
id: 'google',
name: 'Google Generative AI',
creator: createGoogleGenerativeAI,
supportsImageGeneration: true
},
{
id: 'xai',
name: 'xAI (Grok)',
creator: createXai,
supportsImageGeneration: true
},
{
id: 'azure',
name: 'Azure OpenAI',
creator: createAzure,
supportsImageGeneration: true
},
{
id: 'deepseek',
name: 'DeepSeek',
creator: createDeepSeek,
supportsImageGeneration: false
}
] as const
/**
 * Built-in provider ids, derived from baseProviders.
 * NOTE(review): the double cast forces the mapped array into the non-empty
 * tuple shape z.enum() requires; the per-element literal types are lost in
 * the process — confirm that widening to string is acceptable here.
 */
export const baseProviderIds = baseProviders.map((p) => p.id) as unknown as readonly [string, ...string[]]
/**
 * Schema for built-in provider ids.
 */
export const baseProviderIdSchema = z.enum(baseProviderIds)
/**
 * Schema for user-defined provider ids.
 * Any non-empty string is allowed, except the built-in ids, to avoid clashes.
 */
export const customProviderIdSchema = z
.string()
.min(1)
.refine((id) => !baseProviderIds.includes(id as any), {
message: 'Custom provider ID cannot conflict with base provider IDs'
})
/**
 * Provider id schema — accepts either a built-in or a custom id.
 */
export const providerIdSchema = z.union([baseProviderIdSchema, customProviderIdSchema])
/**
 * Provider configuration schema, used to validate ProviderConfig objects.
 * A config must supply either a direct `creator` function or an
 * `import` + `creatorFunctionName` pair (enforced by the refine below).
 */
export const providerConfigSchema = z
.object({
id: customProviderIdSchema, // only custom (non-built-in) ids may be registered
name: z.string().min(1),
creator: z.function().optional(),
import: z.function().optional(),
creatorFunctionName: z.string().optional(),
supportsImageGeneration: z.boolean().default(false),
imageCreator: z.function().optional(),
validateOptions: z.function().optional(),
aliases: z.array(z.string()).optional()
})
.refine((data) => data.creator || (data.import && data.creatorFunctionName), {
message: 'Must provide either creator function or import configuration'
})
/**
 * Provider id types, inferred from the zod schemas above.
 */
export type ProviderId = z.infer<typeof providerIdSchema>
export type BaseProviderId = z.infer<typeof baseProviderIdSchema>
export type CustomProviderId = z.infer<typeof customProviderIdSchema>
/**
 * Provider configuration type, inferred from providerConfigSchema.
 */
export type ProviderConfig = z.infer<typeof providerConfigSchema>
/**
 * Compatibility alias.
 * @deprecated use ProviderConfig instead
 */
export type DynamicProviderRegistration = ProviderConfig

View File

@@ -0,0 +1,96 @@
import { type AnthropicProviderSettings } from '@ai-sdk/anthropic'
import { type AzureOpenAIProviderSettings } from '@ai-sdk/azure'
import { type DeepSeekProviderSettings } from '@ai-sdk/deepseek'
import { type GoogleGenerativeAIProviderSettings } from '@ai-sdk/google'
import { type OpenAIProviderSettings } from '@ai-sdk/openai'
import { type OpenAICompatibleProviderSettings } from '@ai-sdk/openai-compatible'
import {
EmbeddingModelV2 as EmbeddingModel,
ImageModelV2 as ImageModel,
LanguageModelV2 as LanguageModel,
ProviderV2,
SpeechModelV2 as SpeechModel,
TranscriptionModelV2 as TranscriptionModel
} from '@ai-sdk/provider'
import { type XaiProviderSettings } from '@ai-sdk/xai'
// 导入基于 Zod 的 ProviderId 类型
import { type ProviderId as ZodProviderId } from './schemas'
// Settings map for the statically known, built-in providers.
export interface ExtensibleProviderSettingsMap {
// Built-in static providers
openai: OpenAIProviderSettings
'openai-responses': OpenAIProviderSettings
'openai-compatible': OpenAICompatibleProviderSettings
anthropic: AnthropicProviderSettings
google: GoogleGenerativeAIProviderSettings
xai: XaiProviderSettings
azure: AzureOpenAIProviderSettings
deepseek: DeepSeekProviderSettings
}
// Registry for provider types added at runtime; their settings are untyped.
export interface DynamicProviderRegistry {
[key: string]: any
}
// Combined settings map: static built-ins plus dynamically registered ones.
export type ProviderSettingsMap = ExtensibleProviderSettingsMap & DynamicProviderRegistry
/**
 * Generic provider error carrying the provider id, an optional
 * machine-readable code, and the underlying cause when available.
 */
export class ProviderError extends Error {
  constructor(
    message: string,
    public providerId: string,
    public code?: string,
    public cause?: Error
  ) {
    super(message)
    this.name = 'ProviderError'
    // Trim this constructor frame from V8 stack traces, consistent with
    // ImageGenerationError in the runtime errors module.
    if (Error.captureStackTrace) {
      Error.captureStackTrace(this, ProviderError)
    }
  }
}
// Dynamic ProviderId type — based on the zod schema; supports runtime
// extension and validation.
export type ProviderId = ZodProviderId
// Hook for registering provider settings types at runtime.
export interface ProviderTypeRegistrar {
registerProviderType<T extends string, S>(providerId: T, settingsType: S): void
getProviderSettings<T extends string>(providerId: T): any
}
// Re-export every provider settings type for external consumers.
export type {
AnthropicProviderSettings,
AzureOpenAIProviderSettings,
DeepSeekProviderSettings,
GoogleGenerativeAIProviderSettings,
OpenAICompatibleProviderSettings,
OpenAIProviderSettings,
XaiProviderSettings
}
export const METHOD_MAP = {
text: 'languageModel',
image: 'imageModel',
embedding: 'textEmbeddingModel',
transcription: 'transcriptionModel',
speech: 'speechModel'
} as const satisfies Record<AiSdkModelType, keyof ProviderV2>
export type AiSdkModelMethodMap = Record<AiSdkModelType, keyof ProviderV2>
export type AiSdkModelReturnMap = {
text: LanguageModel
image: ImageModel
embedding: EmbeddingModel<string>
transcription: TranscriptionModel
speech: SpeechModel
}
export type AiSdkMethodName<T extends AiSdkModelType> = (typeof METHOD_MAP)[T]
export type AiSdkModelReturn<T extends AiSdkModelType> = AiSdkModelReturnMap[T]

View File

@@ -0,0 +1,86 @@
/**
 * Format a private key so it carries proper PEM "PRIVATE KEY" header/footer
 * lines and a 64-characters-per-line Base64 body.
 *
 * Escaped "\n" sequences (as found inside JSON strings) are converted to real
 * newlines first. Input already containing BEGIN/END markers is re-normalized;
 * anything else is treated as bare Base64 content and wrapped from scratch.
 * NOTE(review): only the PKCS#8 "PRIVATE KEY" label is recognized — e.g. an
 * "RSA PRIVATE KEY" input would be stripped and re-wrapped under the generic
 * label; confirm that is acceptable for callers.
 * @throws Error when the input is empty, malformed PEM, or not valid Base64
 */
export function formatPrivateKey(privateKey: string): string {
  if (!privateKey || typeof privateKey !== 'string') {
    throw new Error('Private key must be a non-empty string')
  }
  // Turn JSON-escaped "\n" into actual line breaks before checking markers.
  const withNewlines = privateKey.replace(/\\n/g, '\n')
  const isPem =
    withNewlines.includes('-----BEGIN PRIVATE KEY-----') && withNewlines.includes('-----END PRIVATE KEY-----')
  // Already PEM-shaped: re-normalize it. Otherwise rebuild from raw content.
  return isPem ? normalizePemFormat(withNewlines) : reconstructPemKey(withNewlines)
}
/**
 * Re-normalize an existing PEM block: extract the Base64 body between the
 * BEGIN/END marker lines and re-wrap it.
 */
function normalizePemFormat(pemKey: string): string {
  const trimmedLines = pemKey
    .split('\n')
    .map((line) => line.trim())
    .filter((line) => line.length > 0)
  let body = ''
  let sawBegin = false
  let sawEnd = false
  for (const line of trimmedLines) {
    if (line === '-----BEGIN PRIVATE KEY-----') {
      sawBegin = true
    } else if (line === '-----END PRIVATE KEY-----') {
      sawEnd = true
      break
    } else if (sawBegin) {
      // Only lines between the markers contribute to the key body.
      body += line
    }
  }
  if (!sawBegin || !sawEnd || !body) {
    throw new Error('Invalid PEM format: missing BEGIN/END markers or key content')
  }
  return wrapAsPem(body)
}
/**
 * Rebuild a PEM key from raw (possibly unwrapped) content: strip whitespace
 * and any partial BEGIN/END markers, validate the remainder, and wrap it.
 */
function reconstructPemKey(key: string): string {
  const stripped = key
    .replace(/\s+/g, '')
    .replace(/-----BEGIN[^-]*-----/g, '')
    .replace(/-----END[^-]*-----/g, '')
  if (!stripped) {
    throw new Error('Private key content is empty after cleaning')
  }
  // Whatever remains must be plain Base64.
  if (!/^[A-Za-z0-9+/=]+$/.test(stripped)) {
    throw new Error('Private key contains invalid characters (not valid Base64)')
  }
  return wrapAsPem(stripped)
}
/**
 * Wrap Base64 body text in PRIVATE KEY markers, 64 characters per line.
 */
function wrapAsPem(body: string): string {
  const wrapped = body.match(/.{1,64}/g)?.join('\n') || body
  return `-----BEGIN PRIVATE KEY-----\n${wrapped}\n-----END PRIVATE KEY-----`
}

View File

@@ -0,0 +1,540 @@
import { ImageModelV2 } from '@ai-sdk/provider'
import { experimental_generateImage as aiGenerateImage, NoImageGeneratedError } from 'ai'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import { type AiPlugin } from '../../plugins'
import { globalRegistryManagement } from '../../providers/RegistryManagement'
import { ImageGenerationError } from '../errors'
import { RuntimeExecutor } from '../executor'
// Mock dependencies: replace the AI SDK's experimental_generateImage with a
// spy and stub NoImageGeneratedError with a minimal class whose isInstance
// check is controllable from tests.
vi.mock('ai', () => ({
experimental_generateImage: vi.fn(),
NoImageGeneratedError: class NoImageGeneratedError extends Error {
static isInstance = vi.fn()
constructor() {
super('No image generated')
this.name = 'NoImageGeneratedError'
}
}
}))
// Stub the global registry so imageModel lookups are fully controlled.
vi.mock('../../providers/RegistryManagement', () => ({
globalRegistryManagement: {
imageModel: vi.fn()
}
}))
describe('RuntimeExecutor.generateImage', () => {
let executor: RuntimeExecutor<'openai'>
let mockImageModel: ImageModelV2
let mockGenerateImageResult: any
// Fresh executor, model stub, and canned generateImage result before each test.
beforeEach(() => {
// Reset all mocks
vi.clearAllMocks()
// Create executor instance
executor = RuntimeExecutor.create('openai', {
apiKey: 'test-key'
})
// Mock image model
mockImageModel = {
modelId: 'dall-e-3',
provider: 'openai'
} as ImageModelV2
// Mock generateImage result
mockGenerateImageResult = {
image: {
base64: 'base64-encoded-image-data',
uint8Array: new Uint8Array([1, 2, 3]),
mediaType: 'image/png'
},
images: [
{
base64: 'base64-encoded-image-data',
uint8Array: new Uint8Array([1, 2, 3]),
mediaType: 'image/png'
}
],
warnings: [],
providerMetadata: {
openai: {
images: [{ revisedPrompt: 'A detailed prompt' }]
}
},
responses: []
}
// Setup mocks
// NOTE(review): mockReturnValue below is immediately superseded by the
// mockImplementation two lines down — the first call looks redundant.
vi.mocked(globalRegistryManagement.imageModel).mockReturnValue(mockImageModel)
vi.mocked(aiGenerateImage).mockResolvedValue(mockGenerateImageResult)
// Reset mock implementation in case it was changed by previous tests
vi.mocked(globalRegistryManagement.imageModel).mockImplementation(() => mockImageModel)
})
// Happy-path coverage: each generateImage option is forwarded verbatim to the
// AI SDK's experimental_generateImage together with the resolved model.
describe('Basic functionality', () => {
it('should generate a single image with minimal parameters', async () => {
const result = await executor.generateImage('dall-e-3', {
prompt: 'A futuristic cityscape at sunset'
})
// Model ids are namespaced as "provider:model" when resolved via registry.
expect(globalRegistryManagement.imageModel).toHaveBeenCalledWith('openai:dall-e-3')
expect(aiGenerateImage).toHaveBeenCalledWith({
model: mockImageModel,
prompt: 'A futuristic cityscape at sunset'
})
expect(result).toEqual(mockGenerateImageResult)
})
it('should generate image with pre-created model', async () => {
const result = await executor.generateImage(mockImageModel, {
prompt: 'A beautiful landscape'
})
// Note: globalRegistryManagement.imageModel may still be called due to resolveImageModel logic
expect(aiGenerateImage).toHaveBeenCalledWith({
model: mockImageModel,
prompt: 'A beautiful landscape'
})
expect(result).toEqual(mockGenerateImageResult)
})
it('should support multiple images generation', async () => {
await executor.generateImage('dall-e-3', {
prompt: 'A futuristic cityscape',
n: 3
})
expect(aiGenerateImage).toHaveBeenCalledWith({
model: mockImageModel,
prompt: 'A futuristic cityscape',
n: 3
})
})
it('should support size specification', async () => {
await executor.generateImage('dall-e-3', {
prompt: 'A beautiful sunset',
size: '1024x1024'
})
expect(aiGenerateImage).toHaveBeenCalledWith({
model: mockImageModel,
prompt: 'A beautiful sunset',
size: '1024x1024'
})
})
it('should support aspect ratio specification', async () => {
await executor.generateImage('dall-e-3', {
prompt: 'A mountain landscape',
aspectRatio: '16:9'
})
expect(aiGenerateImage).toHaveBeenCalledWith({
model: mockImageModel,
prompt: 'A mountain landscape',
aspectRatio: '16:9'
})
})
it('should support seed for consistent output', async () => {
await executor.generateImage('dall-e-3', {
prompt: 'A cat in space',
seed: 1234567890
})
expect(aiGenerateImage).toHaveBeenCalledWith({
model: mockImageModel,
prompt: 'A cat in space',
seed: 1234567890
})
})
it('should support abort signal', async () => {
const abortController = new AbortController()
await executor.generateImage('dall-e-3', {
prompt: 'A cityscape',
abortSignal: abortController.signal
})
expect(aiGenerateImage).toHaveBeenCalledWith({
model: mockImageModel,
prompt: 'A cityscape',
abortSignal: abortController.signal
})
})
it('should support provider-specific options', async () => {
await executor.generateImage('dall-e-3', {
prompt: 'A space station',
providerOptions: {
openai: {
quality: 'hd',
style: 'vivid'
}
}
})
expect(aiGenerateImage).toHaveBeenCalledWith({
model: mockImageModel,
prompt: 'A space station',
providerOptions: {
openai: {
quality: 'hd',
style: 'vivid'
}
}
})
})
it('should support custom headers', async () => {
await executor.generateImage('dall-e-3', {
prompt: 'A robot',
headers: {
'X-Custom-Header': 'test-value'
}
})
expect(aiGenerateImage).toHaveBeenCalledWith({
model: mockImageModel,
prompt: 'A robot',
headers: {
'X-Custom-Header': 'test-value'
}
})
})
})
// Plugin lifecycle coverage: hook ordering, model resolution override, and
// recursive re-entry through the plugin context.
describe('Plugin integration', () => {
it('should execute plugins in correct order', async () => {
const pluginCallOrder: string[] = []
const testPlugin: AiPlugin = {
name: 'test-plugin',
onRequestStart: vi.fn(async () => {
pluginCallOrder.push('onRequestStart')
}),
transformParams: vi.fn(async (params) => {
pluginCallOrder.push('transformParams')
return { ...params, size: '512x512' }
}),
transformResult: vi.fn(async (result) => {
pluginCallOrder.push('transformResult')
return { ...result, processed: true }
}),
onRequestEnd: vi.fn(async () => {
pluginCallOrder.push('onRequestEnd')
})
}
const executorWithPlugin = RuntimeExecutor.create(
'openai',
{
apiKey: 'test-key'
},
[testPlugin]
)
const result = await executorWithPlugin.generateImage('dall-e-3', {
prompt: 'A test image'
})
expect(pluginCallOrder).toEqual(['onRequestStart', 'transformParams', 'transformResult', 'onRequestEnd'])
expect(testPlugin.transformParams).toHaveBeenCalledWith(
{ prompt: 'A test image' },
expect.objectContaining({
providerId: 'openai',
modelId: 'dall-e-3'
})
)
expect(aiGenerateImage).toHaveBeenCalledWith({
model: mockImageModel,
prompt: 'A test image',
size: '512x512' // Should be transformed by plugin
})
expect(result).toEqual({
...mockGenerateImageResult,
processed: true // Should be transformed by plugin
})
})
it('should handle model resolution through plugins', async () => {
const customImageModel = {
modelId: 'custom-model',
provider: 'openai'
} as ImageModelV2
const modelResolutionPlugin: AiPlugin = {
name: 'model-resolver',
resolveModel: vi.fn(async () => customImageModel)
}
const executorWithPlugin = RuntimeExecutor.create(
'openai',
{
apiKey: 'test-key'
},
[modelResolutionPlugin]
)
await executorWithPlugin.generateImage('dall-e-3', {
prompt: 'A test image'
})
expect(modelResolutionPlugin.resolveModel).toHaveBeenCalledWith(
'dall-e-3',
expect.objectContaining({
providerId: 'openai',
modelId: 'dall-e-3'
})
)
// The plugin-resolved model wins over the registry lookup.
expect(aiGenerateImage).toHaveBeenCalledWith({
model: customImageModel,
prompt: 'A test image'
})
})
it('should support recursive calls from plugins', async () => {
const recursivePlugin: AiPlugin = {
name: 'recursive-plugin',
transformParams: vi.fn(async (params, context) => {
if (!context.isRecursiveCall && params.prompt === 'original') {
// Make a recursive call with modified prompt
await context.recursiveCall({
prompt: 'modified'
})
}
return params
})
}
const executorWithPlugin = RuntimeExecutor.create(
'openai',
{
apiKey: 'test-key'
},
[recursivePlugin]
)
await executorWithPlugin.generateImage('dall-e-3', {
prompt: 'original'
})
expect(recursivePlugin.transformParams).toHaveBeenCalledTimes(2)
expect(aiGenerateImage).toHaveBeenCalledTimes(2)
})
})
// Failure-path coverage: registry failures, SDK rejections, NoImageGenerated
// wrapping, plugin onError hook, and abort propagation.
describe('Error handling', () => {
it('should handle model creation errors', async () => {
const modelError = new Error('Failed to get image model')
vi.mocked(globalRegistryManagement.imageModel).mockImplementation(() => {
throw modelError
})
// Registry failures surface as the domain-specific ImageGenerationError.
await expect(
executor.generateImage('invalid-model', {
prompt: 'A test image'
})
).rejects.toThrow(ImageGenerationError)
})
it('should handle image generation API errors', async () => {
const apiError = new Error('API request failed')
vi.mocked(aiGenerateImage).mockRejectedValue(apiError)
await expect(
executor.generateImage('dall-e-3', {
prompt: 'A test image'
})
).rejects.toThrow('Failed to generate image: API request failed')
})
it('should handle NoImageGeneratedError', async () => {
const noImageError = new NoImageGeneratedError({
cause: new Error('No image generated'),
responses: []
})
vi.mocked(aiGenerateImage).mockRejectedValue(noImageError)
vi.mocked(NoImageGeneratedError.isInstance).mockReturnValue(true)
await expect(
executor.generateImage('dall-e-3', {
prompt: 'A test image'
})
).rejects.toThrow('Failed to generate image: No image generated')
})
it('should execute onError plugin hook on failure', async () => {
const error = new Error('Generation failed')
vi.mocked(aiGenerateImage).mockRejectedValue(error)
const errorPlugin: AiPlugin = {
name: 'error-handler',
onError: vi.fn()
}
const executorWithPlugin = RuntimeExecutor.create(
'openai',
{
apiKey: 'test-key'
},
[errorPlugin]
)
await expect(
executorWithPlugin.generateImage('dall-e-3', {
prompt: 'A test image'
})
).rejects.toThrow('Failed to generate image: Generation failed')
// The hook receives the raw error, not the wrapped one.
expect(errorPlugin.onError).toHaveBeenCalledWith(
error,
expect.objectContaining({
providerId: 'openai',
modelId: 'dall-e-3'
})
)
})
it('should handle abort signal timeout', async () => {
const abortError = new Error('Operation was aborted')
abortError.name = 'AbortError'
vi.mocked(aiGenerateImage).mockRejectedValue(abortError)
const abortController = new AbortController()
setTimeout(() => abortController.abort(), 10)
await expect(
executor.generateImage('dall-e-3', {
prompt: 'A test image',
abortSignal: abortController.signal
})
).rejects.toThrow('Operation was aborted')
})
})
// Provider-id namespacing: the registry lookup key is "<provider>:<model>".
describe('Multiple providers support', () => {
it('should work with different providers', async () => {
const googleExecutor = RuntimeExecutor.create('google', {
apiKey: 'google-key'
})
await googleExecutor.generateImage('imagen-3.0-generate-002', {
prompt: 'A landscape'
})
expect(globalRegistryManagement.imageModel).toHaveBeenCalledWith('google:imagen-3.0-generate-002')
})
it('should support xAI Grok image models', async () => {
const xaiExecutor = RuntimeExecutor.create('xai', {
apiKey: 'xai-key'
})
await xaiExecutor.generateImage('grok-2-image', {
prompt: 'A futuristic robot'
})
expect(globalRegistryManagement.imageModel).toHaveBeenCalledWith('xai:grok-2-image')
})
})
// Pass-through of advanced options and surfacing of warnings/metadata from
// the underlying SDK result.
describe('Advanced features', () => {
it('should support batch image generation with maxImagesPerCall', async () => {
await executor.generateImage('dall-e-3', {
prompt: 'A test image',
n: 10,
maxImagesPerCall: 5
})
expect(aiGenerateImage).toHaveBeenCalledWith({
model: mockImageModel,
prompt: 'A test image',
n: 10,
maxImagesPerCall: 5
})
})
it('should support retries with maxRetries', async () => {
await executor.generateImage('dall-e-3', {
prompt: 'A test image',
maxRetries: 3
})
expect(aiGenerateImage).toHaveBeenCalledWith({
model: mockImageModel,
prompt: 'A test image',
maxRetries: 3
})
})
it('should handle warnings from the model', async () => {
const resultWithWarnings = {
...mockGenerateImageResult,
warnings: [
{
type: 'unsupported-setting',
message: 'Size parameter not supported for this model'
}
]
}
vi.mocked(aiGenerateImage).mockResolvedValue(resultWithWarnings)
const result = await executor.generateImage('dall-e-3', {
prompt: 'A test image',
size: '2048x2048' // Unsupported size
})
expect(result.warnings).toHaveLength(1)
expect(result.warnings[0].type).toBe('unsupported-setting')
})
it('should provide access to provider metadata', async () => {
const result = await executor.generateImage('dall-e-3', {
prompt: 'A test image'
})
expect(result.providerMetadata).toBeDefined()
expect(result.providerMetadata.openai).toBeDefined()
})
it('should provide response metadata', async () => {
const resultWithMetadata = {
...mockGenerateImageResult,
responses: [
{
timestamp: new Date(),
modelId: 'dall-e-3',
headers: { 'x-request-id': 'test-123' }
}
]
}
vi.mocked(aiGenerateImage).mockResolvedValue(resultWithMetadata)
const result = await executor.generateImage('dall-e-3', {
prompt: 'A test image'
})
expect(result.responses).toHaveLength(1)
expect(result.responses[0].modelId).toBe('dall-e-3')
expect(result.responses[0].headers).toEqual({ 'x-request-id': 'test-123' })
})
})
})

View File

@@ -0,0 +1,38 @@
/**
 * Error classes for runtime operations
 */
/**
 * Raised when an image generation request fails.
 */
export class ImageGenerationError extends Error {
  /**
   * @param message - description of the failure
   * @param providerId - provider involved, when known
   * @param modelId - model involved, when known
   * @param cause - underlying error, when available
   */
  constructor(
    message: string,
    public providerId?: string,
    public modelId?: string,
    public cause?: Error
  ) {
    super(message)
    this.name = 'ImageGenerationError'
    // Drop this constructor frame from V8 stack traces.
    if (Error.captureStackTrace) {
      Error.captureStackTrace(this, ImageGenerationError)
    }
  }
}
/**
 * Raised when an image model id cannot be resolved during generation.
 */
export class ImageModelResolutionError extends ImageGenerationError {
  constructor(modelId: string, providerId?: string, cause?: Error) {
    const providerSuffix = providerId ? ` for provider: ${providerId}` : ''
    super(`Failed to resolve image model: ${modelId}${providerSuffix}`, providerId, modelId, cause)
    this.name = 'ImageModelResolutionError'
  }
}

View File

@@ -0,0 +1,346 @@
/**
* 运行时执行器
* 专注于插件化的AI调用处理
*/
import { ImageModelV2, LanguageModelV2, LanguageModelV2Middleware } from '@ai-sdk/provider'
import {
experimental_generateImage as generateImage,
generateObject,
generateText,
LanguageModel,
streamObject,
streamText
} from 'ai'
import { globalModelResolver } from '../models'
import { type ModelConfig } from '../models/types'
import { type AiPlugin, type AiRequestContext, definePlugin } from '../plugins'
import { type ProviderId } from '../providers'
import { ImageGenerationError, ImageModelResolutionError } from './errors'
import { PluginEngine } from './pluginEngine'
import { type RuntimeConfig } from './types'
/**
 * Runtime executor.
 *
 * Every public method funnels its call through the PluginEngine so plugins can
 * resolve the model, transform params/results and observe lifecycle events.
 * Each method accepts either an already-created model instance or a string
 * model id (resolved via the global model resolver).
 */
export class RuntimeExecutor<T extends ProviderId = ProviderId> {
  public pluginEngine: PluginEngine<T>
  // private options: ProviderSettingsMap[T]
  private config: RuntimeConfig<T>

  constructor(config: RuntimeConfig<T>) {
    // if (!isProviderSupported(config.providerId)) {
    //   throw new Error(`Unsupported provider: ${config.providerId}`)
    // }
    // Store options for later use
    // this.options = config.options
    this.config = config
    // Create the plugin-aware engine that executes all calls
    this.pluginEngine = new PluginEngine(config.providerId, config.plugins || [])
  }

  /**
   * Internal plugin that resolves a string model id into a concrete language
   * model, optionally wrapping it with the given middlewares.
   */
  private createResolveModelPlugin(middlewares?: LanguageModelV2Middleware[]) {
    return definePlugin({
      name: '_internal_resolveModel',
      enforce: 'post',
      resolveModel: async (modelId: string) => {
        // Note: extraModelConfig is not supported for now — removed in the new architecture
        return await this.resolveModel(modelId, middlewares)
      }
    })
  }

  /** Internal plugin that resolves a string id into a concrete image model. */
  private createResolveImageModelPlugin() {
    return definePlugin({
      name: '_internal_resolveImageModel',
      enforce: 'post',
      resolveModel: async (modelId: string) => {
        return await this.resolveImageModel(modelId)
      }
    })
  }

  /** Internal plugin that exposes this executor on the request context. */
  private createConfigureContextPlugin() {
    return definePlugin({
      name: '_internal_configureContext',
      configureContext: async (context: AiRequestContext) => {
        context.executor = this
      }
    })
  }

  // === High-level overloads: pass a resolved model directly ===
  /**
   * Streaming text generation — accepts an already-created model (advanced usage)
   * or a string model id with optional middlewares.
   */
  async streamText(
    model: LanguageModel,
    params: Omit<Parameters<typeof streamText>[0], 'model'>
  ): Promise<ReturnType<typeof streamText>>
  async streamText(
    modelId: string,
    params: Omit<Parameters<typeof streamText>[0], 'model'>,
    options?: {
      middlewares?: LanguageModelV2Middleware[]
    }
  ): Promise<ReturnType<typeof streamText>>
  // NOTE(review): sibling implementations type this as `LanguageModel | string`;
  // in AI SDK v5 `LanguageModel` already includes `string` — confirm and unify.
  async streamText(
    modelOrId: LanguageModel,
    params: Omit<Parameters<typeof streamText>[0], 'model'>,
    options?: {
      middlewares?: LanguageModelV2Middleware[]
    }
  ): Promise<ReturnType<typeof streamText>> {
    this.pluginEngine.usePlugins([
      this.createResolveModelPlugin(options?.middlewares),
      this.createConfigureContextPlugin()
    ])
    // Run the call through the plugin pipeline
    return this.pluginEngine.executeStreamWithPlugins(
      'streamText',
      typeof modelOrId === 'string' ? modelOrId : modelOrId.modelId,
      params,
      async (model, transformedParams, streamTransforms) => {
        // A caller-supplied experimental_transform takes precedence over
        // stream transforms collected from plugins.
        const experimental_transform =
          params?.experimental_transform ?? (streamTransforms.length > 0 ? streamTransforms : undefined)
        const finalParams = {
          model,
          ...transformedParams,
          experimental_transform
        } as Parameters<typeof streamText>[0]
        return await streamText(finalParams)
      }
    )
  }

  // === Overloads for the remaining methods ===
  /**
   * Generate text — accepts an already-created model or a string model id.
   */
  async generateText(
    model: LanguageModel,
    params: Omit<Parameters<typeof generateText>[0], 'model'>
  ): Promise<ReturnType<typeof generateText>>
  async generateText(
    modelId: string,
    params: Omit<Parameters<typeof generateText>[0], 'model'>,
    options?: {
      middlewares?: LanguageModelV2Middleware[]
    }
  ): Promise<ReturnType<typeof generateText>>
  async generateText(
    modelOrId: LanguageModel | string,
    params: Omit<Parameters<typeof generateText>[0], 'model'>,
    options?: {
      middlewares?: LanguageModelV2Middleware[]
    }
  ): Promise<ReturnType<typeof generateText>> {
    this.pluginEngine.usePlugins([
      this.createResolveModelPlugin(options?.middlewares),
      this.createConfigureContextPlugin()
    ])
    return this.pluginEngine.executeWithPlugins(
      'generateText',
      typeof modelOrId === 'string' ? modelOrId : modelOrId.modelId,
      params,
      async (model, transformedParams) =>
        generateText({ model, ...transformedParams } as Parameters<typeof generateText>[0])
    )
  }

  /**
   * Generate a structured object — accepts an already-created model or a string id.
   */
  async generateObject(
    model: LanguageModel,
    params: Omit<Parameters<typeof generateObject>[0], 'model'>
  ): Promise<ReturnType<typeof generateObject>>
  async generateObject(
    modelOrId: string,
    params: Omit<Parameters<typeof generateObject>[0], 'model'>,
    options?: {
      middlewares?: LanguageModelV2Middleware[]
    }
  ): Promise<ReturnType<typeof generateObject>>
  async generateObject(
    modelOrId: LanguageModel | string,
    params: Omit<Parameters<typeof generateObject>[0], 'model'>,
    options?: {
      middlewares?: LanguageModelV2Middleware[]
    }
  ): Promise<ReturnType<typeof generateObject>> {
    this.pluginEngine.usePlugins([
      this.createResolveModelPlugin(options?.middlewares),
      this.createConfigureContextPlugin()
    ])
    return this.pluginEngine.executeWithPlugins(
      'generateObject',
      typeof modelOrId === 'string' ? modelOrId : modelOrId.modelId,
      params,
      async (model, transformedParams) =>
        generateObject({ model, ...transformedParams } as Parameters<typeof generateObject>[0])
    )
  }

  /**
   * Stream a structured object — accepts an already-created model or a string id.
   */
  async streamObject(
    model: LanguageModel,
    params: Omit<Parameters<typeof streamObject>[0], 'model'>
  ): Promise<ReturnType<typeof streamObject>>
  async streamObject(
    modelId: string,
    params: Omit<Parameters<typeof streamObject>[0], 'model'>,
    options?: {
      middlewares?: LanguageModelV2Middleware[]
    }
  ): Promise<ReturnType<typeof streamObject>>
  async streamObject(
    modelOrId: LanguageModel | string,
    params: Omit<Parameters<typeof streamObject>[0], 'model'>,
    options?: {
      middlewares?: LanguageModelV2Middleware[]
    }
  ): Promise<ReturnType<typeof streamObject>> {
    this.pluginEngine.usePlugins([
      this.createResolveModelPlugin(options?.middlewares),
      this.createConfigureContextPlugin()
    ])
    return this.pluginEngine.executeWithPlugins(
      'streamObject',
      typeof modelOrId === 'string' ? modelOrId : modelOrId.modelId,
      params,
      async (model, transformedParams) =>
        streamObject({ model, ...transformedParams } as Parameters<typeof streamObject>[0])
    )
  }

  /**
   * Generate an image — accepts an already-created image model or a string id.
   * Wraps any thrown Error in ImageGenerationError with provider/model context.
   */
  async generateImage(
    model: ImageModelV2,
    params: Omit<Parameters<typeof generateImage>[0], 'model'>
  ): Promise<ReturnType<typeof generateImage>>
  async generateImage(
    modelId: string,
    params: Omit<Parameters<typeof generateImage>[0], 'model'>,
    options?: {
      middlewares?: LanguageModelV2Middleware[]
    }
  ): Promise<ReturnType<typeof generateImage>>
  // NOTE(review): the string-id overload above declares `options.middlewares`,
  // but the implementation signature below omits `options` so middlewares are
  // silently ignored — confirm intent and either drop the option or wire it through.
  async generateImage(
    modelOrId: ImageModelV2 | string,
    params: Omit<Parameters<typeof generateImage>[0], 'model'>
  ): Promise<ReturnType<typeof generateImage>> {
    try {
      this.pluginEngine.usePlugins([this.createResolveImageModelPlugin(), this.createConfigureContextPlugin()])
      return await this.pluginEngine.executeImageWithPlugins(
        'generateImage',
        typeof modelOrId === 'string' ? modelOrId : modelOrId.modelId,
        params,
        async (model, transformedParams) => {
          return await generateImage({ model, ...transformedParams })
        }
      )
    } catch (error) {
      if (error instanceof Error) {
        throw new ImageGenerationError(
          `Failed to generate image: ${error.message}`,
          this.config.providerId,
          typeof modelOrId === 'string' ? modelOrId : modelOrId.modelId,
          error
        )
      }
      throw error
    }
  }

  // === Helpers ===
  /**
   * Resolve a model: create it from a string id, or return the instance as-is.
   */
  private async resolveModel(
    modelOrId: LanguageModel,
    middlewares?: LanguageModelV2Middleware[]
  ): Promise<LanguageModelV2> {
    if (typeof modelOrId === 'string') {
      // String model ids are resolved via the new ModelResolver with full parameters
      return await globalModelResolver.resolveLanguageModel(
        modelOrId, // supports 'gpt-4' and 'aihubmix:anthropic:claude-3.5-sonnet'
        this.config.providerId, // fallback provider
        this.config.providerSettings, // provider options
        middlewares // middleware array
      )
    } else {
      // Already a model instance; return as-is
      return modelOrId
    }
  }

  /**
   * Resolve an image model: create it from a string id, or return the instance
   * as-is. Resolution failures are wrapped in ImageModelResolutionError.
   */
  private async resolveImageModel(modelOrId: ImageModelV2 | string): Promise<ImageModelV2> {
    try {
      if (typeof modelOrId === 'string') {
        // String model ids are resolved via the new ModelResolver
        return await globalModelResolver.resolveImageModel(
          modelOrId, // supports 'dall-e-3' and 'aihubmix:openai:dall-e-3'
          this.config.providerId // fallback provider
        )
      } else {
        // Already a model instance; return as-is
        return modelOrId
      }
    } catch (error) {
      throw new ImageModelResolutionError(
        typeof modelOrId === 'string' ? modelOrId : modelOrId.modelId,
        this.config.providerId,
        error instanceof Error ? error : undefined
      )
    }
  }

  // === Static factories ===
  /**
   * Create an executor — type-safe for known providers.
   */
  static create<T extends ProviderId>(
    providerId: T,
    options: ModelConfig<T>['providerSettings'],
    plugins?: AiPlugin[]
  ): RuntimeExecutor<T> {
    return new RuntimeExecutor({
      providerId,
      providerSettings: options,
      plugins
    })
  }

  /**
   * Create an OpenAI-compatible executor.
   */
  static createOpenAICompatible(
    options: ModelConfig<'openai-compatible'>['providerSettings'],
    plugins: AiPlugin[] = []
  ): RuntimeExecutor<'openai-compatible'> {
    return new RuntimeExecutor({
      providerId: 'openai-compatible',
      providerSettings: options,
      plugins
    })
  }
}

View File

@@ -0,0 +1,123 @@
/**
 * Runtime module exports
 * Focused on runtime plugin-driven AI call handling
 */
// 主要的运行时执行器
export { RuntimeExecutor } from './executor'
// 导出类型
export type { RuntimeConfig } from './types'
// === 便捷工厂函数 ===
import { LanguageModelV2Middleware } from '@ai-sdk/provider'
import { type AiPlugin } from '../plugins'
import { type ProviderId, type ProviderSettingsMap } from '../providers/types'
import { RuntimeExecutor } from './executor'
/**
 * Create a runtime executor — type-safe for known providers.
 *
 * @param providerId target provider id
 * @param options provider settings; `mode` selects the OpenAI chat vs responses API
 * @param plugins optional plugins applied to every call made through the executor
 */
export function createExecutor<T extends ProviderId>(
  providerId: T,
  options: ProviderSettingsMap[T] & { mode?: 'chat' | 'responses' },
  plugins?: AiPlugin[]
): RuntimeExecutor<T> {
  return RuntimeExecutor.create(providerId, options, plugins)
}
/**
 * Create an OpenAI-compatible executor.
 *
 * @param options provider settings for the 'openai-compatible' provider
 * @param plugins optional plugins applied to every call made through the executor
 */
export function createOpenAICompatibleExecutor(
  options: ProviderSettingsMap['openai-compatible'] & { mode?: 'chat' | 'responses' },
  plugins: AiPlugin[] = []
): RuntimeExecutor<'openai-compatible'> {
  return RuntimeExecutor.createOpenAICompatible(options, plugins)
}
// === Direct-call APIs (no executor instance required) ===
/**
 * Direct streaming text generation — supports middlewares.
 * Builds a one-off executor and immediately delegates the call to it.
 */
export async function streamText<T extends ProviderId>(
  providerId: T,
  options: ProviderSettingsMap[T] & { mode?: 'chat' | 'responses' },
  modelId: string,
  params: Parameters<RuntimeExecutor<T>['streamText']>[1],
  plugins?: AiPlugin[],
  middlewares?: LanguageModelV2Middleware[]
): Promise<ReturnType<RuntimeExecutor<T>['streamText']>> {
  return createExecutor(providerId, options, plugins).streamText(modelId, params, { middlewares })
}
/**
 * Direct text generation — supports middlewares.
 * Builds a one-off executor and immediately delegates the call to it.
 */
export async function generateText<T extends ProviderId>(
  providerId: T,
  options: ProviderSettingsMap[T] & { mode?: 'chat' | 'responses' },
  modelId: string,
  params: Parameters<RuntimeExecutor<T>['generateText']>[1],
  plugins?: AiPlugin[],
  middlewares?: LanguageModelV2Middleware[]
): Promise<ReturnType<RuntimeExecutor<T>['generateText']>> {
  return createExecutor(providerId, options, plugins).generateText(modelId, params, { middlewares })
}
/**
 * Direct structured-object generation — supports middlewares.
 * Builds a one-off executor and immediately delegates the call to it.
 */
export async function generateObject<T extends ProviderId>(
  providerId: T,
  options: ProviderSettingsMap[T] & { mode?: 'chat' | 'responses' },
  modelId: string,
  params: Parameters<RuntimeExecutor<T>['generateObject']>[1],
  plugins?: AiPlugin[],
  middlewares?: LanguageModelV2Middleware[]
): Promise<ReturnType<RuntimeExecutor<T>['generateObject']>> {
  return createExecutor(providerId, options, plugins).generateObject(modelId, params, { middlewares })
}
/**
 * Direct streaming structured-object generation — supports middlewares.
 * Builds a one-off executor and immediately delegates the call to it.
 */
export async function streamObject<T extends ProviderId>(
  providerId: T,
  options: ProviderSettingsMap[T] & { mode?: 'chat' | 'responses' },
  modelId: string,
  params: Parameters<RuntimeExecutor<T>['streamObject']>[1],
  plugins?: AiPlugin[],
  middlewares?: LanguageModelV2Middleware[]
): Promise<ReturnType<RuntimeExecutor<T>['streamObject']>> {
  return createExecutor(providerId, options, plugins).streamObject(modelId, params, { middlewares })
}
/**
 * Direct image generation — accepts middlewares for signature parity with the
 * other helpers. Builds a one-off executor and delegates the call to it.
 */
export async function generateImage<T extends ProviderId>(
  providerId: T,
  options: ProviderSettingsMap[T] & { mode?: 'chat' | 'responses' },
  modelId: string,
  params: Parameters<RuntimeExecutor<T>['generateImage']>[1],
  plugins?: AiPlugin[],
  middlewares?: LanguageModelV2Middleware[]
): Promise<ReturnType<RuntimeExecutor<T>['generateImage']>> {
  return createExecutor(providerId, options, plugins).generateImage(modelId, params, { middlewares })
}
// === Agent features (reserved) ===
// To be added later under the ../agents/ folder:
// - AgentExecutor.ts
// - WorkflowManager.ts
// - ConversationManager.ts
// and re-exported from here.

View File

@@ -0,0 +1,231 @@
/* eslint-disable @eslint-react/naming-convention/context-name */
import { ImageModelV2 } from '@ai-sdk/provider'
import { LanguageModel } from 'ai'
import { type AiPlugin, createContext, PluginManager } from '../plugins'
import { type ProviderId } from '../providers/types'
/**
 * Plugin-enhanced AI client engine.
 * Focused purely on plugin handling; does not expose a user-facing API.
 *
 * Each execute* method runs the same lifecycle:
 * configureContext → onRequestStart → resolveModel → transformParams →
 * API call → transformResult → onRequestEnd, with onError fired on failure.
 */
export class PluginEngine<T extends ProviderId = ProviderId> {
  private pluginManager: PluginManager

  constructor(
    private readonly providerId: T,
    // private readonly options: ProviderSettingsMap[T],
    plugins: AiPlugin[] = []
  ) {
    this.pluginManager = new PluginManager(plugins)
  }

  /**
   * Add a single plugin.
   */
  use(plugin: AiPlugin): this {
    this.pluginManager.use(plugin)
    return this
  }

  /**
   * Add several plugins at once.
   */
  usePlugins(plugins: AiPlugin[]): this {
    plugins.forEach((plugin) => this.use(plugin))
    return this
  }

  /**
   * Remove a plugin by name.
   */
  removePlugin(pluginName: string): this {
    this.pluginManager.remove(pluginName)
    return this
  }

  /**
   * Get plugin statistics.
   */
  getPluginStats() {
    return this.pluginManager.getStats()
  }

  /**
   * Get all registered plugins.
   */
  getPlugins() {
    return this.pluginManager.getPlugins()
  }

  /**
   * Execute a non-streaming operation through the plugin pipeline.
   * Used by the runtime executor.
   */
  async executeWithPlugins<TParams, TResult>(
    methodName: string,
    modelId: string,
    params: TParams,
    executor: (model: LanguageModel, transformedParams: TParams) => Promise<TResult>,
    _context?: ReturnType<typeof createContext>
  ): Promise<TResult> {
    // Create the request context via createContext (or reuse the one passed in on recursion)
    const context = _context ? _context : createContext(this.providerId, modelId, params)
    // 🔥 Give the context the ability to recursively re-enter the full plugin flow
    context.recursiveCall = async (newParams: any): Promise<TResult> => {
      // Re-run the complete plugin pipeline with the new params
      context.isRecursiveCall = true
      const result = await this.executeWithPlugins(methodName, modelId, newParams, executor, context)
      // NOTE(review): resets the flag unconditionally — nested recursive calls
      // would clear it early; confirm single-level recursion is the only case.
      context.isRecursiveCall = false
      return result
    }
    try {
      // 0. Configure the context
      await this.pluginManager.executeConfigureContext(context)
      // 1. Fire request-start events
      await this.pluginManager.executeParallel('onRequestStart', context)
      // 2. Resolve the model
      const model = await this.pluginManager.executeFirst<LanguageModel>('resolveModel', modelId, context)
      if (!model) {
        throw new Error(`Failed to resolve model: ${modelId}`)
      }
      // 3. Transform request params
      const transformedParams = await this.pluginManager.executeSequential('transformParams', params, context)
      // 4. Perform the actual API call
      const result = await executor(model, transformedParams)
      // 5. Transform the result (non-streaming calls)
      const transformedResult = await this.pluginManager.executeSequential('transformResult', result, context)
      // 6. Fire completion events
      await this.pluginManager.executeParallel('onRequestEnd', context, transformedResult)
      return transformedResult
    } catch (error) {
      // 7. Fire error events
      await this.pluginManager.executeParallel('onError', context, undefined, error as Error)
      throw error
    }
  }

  /**
   * Execute an image-generation operation through the plugin pipeline.
   * Used by the runtime executor.
   */
  async executeImageWithPlugins<TParams, TResult>(
    methodName: string,
    modelId: string,
    params: TParams,
    executor: (model: ImageModelV2, transformedParams: TParams) => Promise<TResult>,
    _context?: ReturnType<typeof createContext>
  ): Promise<TResult> {
    // Create the request context via createContext (or reuse the one passed in on recursion)
    const context = _context ? _context : createContext(this.providerId, modelId, params)
    // 🔥 Give the context the ability to recursively re-enter the full plugin flow
    context.recursiveCall = async (newParams: any): Promise<TResult> => {
      // Re-run the complete plugin pipeline with the new params
      context.isRecursiveCall = true
      const result = await this.executeImageWithPlugins(methodName, modelId, newParams, executor, context)
      context.isRecursiveCall = false
      return result
    }
    try {
      // 0. Configure the context
      await this.pluginManager.executeConfigureContext(context)
      // 1. Fire request-start events
      await this.pluginManager.executeParallel('onRequestStart', context)
      // 2. Resolve the image model
      const model = await this.pluginManager.executeFirst<ImageModelV2>('resolveModel', modelId, context)
      if (!model) {
        throw new Error(`Failed to resolve image model: ${modelId}`)
      }
      // 3. Transform request params
      const transformedParams = await this.pluginManager.executeSequential('transformParams', params, context)
      // 4. Perform the actual API call
      const result = await executor(model, transformedParams)
      // 5. Transform the result
      const transformedResult = await this.pluginManager.executeSequential('transformResult', result, context)
      // 6. Fire completion events
      await this.pluginManager.executeParallel('onRequestEnd', context, transformedResult)
      return transformedResult
    } catch (error) {
      // 7. Fire error events
      await this.pluginManager.executeParallel('onError', context, undefined, error as Error)
      throw error
    }
  }

  /**
   * Execute a streaming call through the plugin pipeline (with stream-transform
   * support). Used by the runtime executor.
   */
  async executeStreamWithPlugins<TParams, TResult>(
    methodName: string,
    modelId: string,
    params: TParams,
    executor: (model: LanguageModel, transformedParams: TParams, streamTransforms: any[]) => Promise<TResult>,
    _context?: ReturnType<typeof createContext>
  ): Promise<TResult> {
    // Create the request context (or reuse the one passed in on recursion)
    const context = _context ? _context : createContext(this.providerId, modelId, params)
    // 🔥 Give the context the ability to recursively re-enter the full plugin flow
    context.recursiveCall = async (newParams: any): Promise<TResult> => {
      // Re-run the complete plugin pipeline with the new params
      context.isRecursiveCall = true
      const result = await this.executeStreamWithPlugins(methodName, modelId, newParams, executor, context)
      context.isRecursiveCall = false
      return result
    }
    try {
      // 0. Configure the context
      await this.pluginManager.executeConfigureContext(context)
      // 1. Fire request-start events
      await this.pluginManager.executeParallel('onRequestStart', context)
      // 2. Resolve the model
      const model = await this.pluginManager.executeFirst<LanguageModel>('resolveModel', modelId, context)
      if (!model) {
        throw new Error(`Failed to resolve model: ${modelId}`)
      }
      // 3. Transform request params
      const transformedParams = await this.pluginManager.executeSequential('transformParams', params, context)
      // 4. Collect stream transforms from plugins
      const streamTransforms = this.pluginManager.collectStreamTransforms(transformedParams, context)
      // 5. Perform the streaming API call
      const result = await executor(model, transformedParams, streamTransforms)
      const transformedResult = await this.pluginManager.executeSequential('transformResult', result, context)
      // 6. Fire completion events (for streaming calls this marks the start of the streamed response)
      await this.pluginManager.executeParallel('onRequestEnd', context, transformedResult)
      return transformedResult
    } catch (error) {
      // 7. Fire error events
      await this.pluginManager.executeParallel('onError', context, undefined, error as Error)
      throw error
    }
  }
}

View File

@@ -0,0 +1,15 @@
/**
 * Runtime layer type definitions
 */
import { type ModelConfig } from '../models/types'
import { type AiPlugin } from '../plugins'
import { type ProviderId } from '../providers/types'
/**
 * Runtime executor configuration.
 */
export interface RuntimeConfig<T extends ProviderId = ProviderId> {
  // Target provider id; also used as the fallback during model resolution.
  providerId: T
  // Provider-specific settings; `mode` selects the OpenAI chat vs responses API.
  providerSettings: ModelConfig<T>['providerSettings'] & { mode?: 'chat' | 'responses' }
  // Optional plugins applied to every call made through the executor.
  plugins?: AiPlugin[]
}

View File

@@ -0,0 +1,46 @@
/**
 * Cherry Studio AI Core Package
 * Unified AI provider interface built on the Vercel AI SDK.
 */
// Imports of classes and functions used internally
// ==================== Primary user API ====================
// Consistency fix: streamObject is exported by ./core/runtime alongside the
// other direct-call helpers but was missing from this re-export list.
export {
  createExecutor,
  createOpenAICompatibleExecutor,
  generateImage,
  generateObject,
  generateText,
  streamObject,
  streamText
} from './core/runtime'
// ==================== Advanced API ====================
export { globalModelResolver as modelResolver } from './core/models'
// ==================== Plugin system ====================
export type { AiPlugin, AiRequestContext, HookResult, PluginManagerConfig } from './core/plugins'
export { createContext, definePlugin, PluginManager } from './core/plugins'
// export { createPromptToolUsePlugin, webSearchPlugin } from './core/plugins/built-in'
export { PluginEngine } from './core/runtime/pluginEngine'
// ==================== Common AI SDK type re-exports ====================
// Re-export frequently used AI SDK types for convenience
export type { LanguageModelV2Middleware, LanguageModelV2StreamPart } from '@ai-sdk/provider'
export type { ToolCall } from '@ai-sdk/provider-utils'
export type { ReasoningPart } from '@ai-sdk/provider-utils'
// ==================== Options ====================
export {
  createAnthropicOptions,
  createGoogleOptions,
  createOpenAIOptions,
  type ExtractProviderOptions,
  mergeProviderOptions,
  type ProviderOptionsMap,
  type TypedProviderOptions
} from './core/options'
// ==================== Package info ====================
// NOTE(review): the package version was bumped to 1.6.0-beta.x elsewhere —
// confirm whether this constant is meant to track package.json.
export const AI_CORE_VERSION = '1.0.0'
export const AI_CORE_NAME = '@cherrystudio/ai-core'

View File

@@ -0,0 +1,2 @@
// Re-export the plugin types
export type { AiPlugin, AiRequestContext, HookResult, PluginManagerConfig } from './core/plugins/types'

View File

@@ -0,0 +1,26 @@
{
"compilerOptions": {
"target": "ES2020",
"module": "ESNext",
"moduleResolution": "bundler",
"declaration": true,
"outDir": "./dist",
"rootDir": "./src",
"strict": true,
"esModuleInterop": true,
"skipLibCheck": true,
"forceConsistentCasingInFileNames": true,
"resolveJsonModule": true,
"allowSyntheticDefaultImports": true,
"noEmitOnError": false,
"experimentalDecorators": true,
"emitDecoratorMetadata": true
},
"include": [
"src/**/*"
],
"exclude": [
"node_modules",
"dist"
]
}

View File

@@ -0,0 +1,14 @@
/**
 * tsdown build configuration.
 * Bundles the package entry points into dist/ as both ESM and CJS,
 * emitting TypeScript declaration files.
 */
import { defineConfig } from 'tsdown'

export default defineConfig({
  entry: {
    index: 'src/index.ts',
    'built-in/plugins/index': 'src/core/plugins/built-in/index.ts',
    'provider/index': 'src/core/providers/index.ts'
  },
  outDir: 'dist',
  format: ['esm', 'cjs'],
  clean: true,
  dts: true,
  tsconfig: 'tsconfig.json'
})

View File

@@ -0,0 +1,15 @@
/**
 * Vitest configuration.
 * Enables global test APIs, maps the '@' alias to src/ and transpiles for node18.
 */
import { defineConfig } from 'vitest/config'

export default defineConfig({
  test: {
    globals: true
  },
  resolve: {
    alias: {
      '@': './src'
    }
  },
  esbuild: {
    target: 'node18'
  }
})

View File

@@ -570,7 +570,8 @@ class McpService {
...tool,
id: buildFunctionCallToolName(server.name, tool.name),
serverId: server.id,
serverName: server.name
serverName: server.name,
type: 'mcp'
}
serverTools.push(serverTool)
})

View File

@@ -0,0 +1,306 @@
/**
 * AI SDK → Cherry Studio chunk adapter
 * Converts the AI SDK fullStream into Cherry Studio's chunk format.
 */
import { loggerService } from '@logger'
import { MCPTool, WebSearchResults, WebSearchSource } from '@renderer/types'
import { Chunk, ChunkType } from '@renderer/types/chunk'
import type { TextStreamPart, ToolSet } from 'ai'
import { ToolCallChunkHandler } from './handleToolCallChunk'
const logger = loggerService.withContext('AiSdkToChunkAdapter')
/**
 * Simplified chunk shape for Cherry Studio streaming events.
 * NOTE(review): not referenced within this file — confirm external usage
 * before removing or changing it.
 */
export interface CherryStudioChunk {
  type: 'text-delta' | 'text-complete' | 'tool-call' | 'tool-result' | 'finish' | 'error'
  text?: string
  toolCall?: any
  toolResult?: any
  finishReason?: string
  usage?: any
  error?: any
}
/**
 * AI SDK → Cherry Studio chunk adapter class.
 * Reads an AI SDK fullStream and converts each stream part into a Cherry
 * Studio chunk, forwarding it through the onChunk callback. Tool-call parts
 * are delegated to a ToolCallChunkHandler.
 */
export class AiSdkToChunkAdapter {
  toolCallHandler: ToolCallChunkHandler

  constructor(
    private onChunk: (chunk: Chunk) => void,
    mcpTools: MCPTool[] = []
  ) {
    this.toolCallHandler = new ToolCallChunkHandler(onChunk, mcpTools)
  }

  /**
   * Process an AI SDK stream result.
   * @param aiSdkResult the AI SDK stream result object
   * @returns the final text content
   */
  async processStream(aiSdkResult: any): Promise<string> {
    // If the result is streaming and exposes a fullStream, drain it first
    if (aiSdkResult.fullStream) {
      await this.readFullStream(aiSdkResult.fullStream)
    }
    // Use streamResult.text to obtain the final aggregated result
    return await aiSdkResult.text
  }

  /**
   * Read the fullStream and convert its parts into Cherry Studio chunks.
   * @param fullStream the AI SDK fullStream (ReadableStream)
   */
  private async readFullStream(fullStream: ReadableStream<TextStreamPart<ToolSet>>) {
    const reader = fullStream.getReader()
    // Accumulators shared across stream parts for this interaction
    const final = {
      text: '',
      reasoningContent: '',
      webSearchResults: [],
      reasoningId: ''
    }
    try {
      while (true) {
        const { done, value } = await reader.read()
        if (done) {
          break
        }
        // Convert and emit the chunk
        this.convertAndEmitChunk(value, final)
      }
    } finally {
      reader.releaseLock()
    }
  }

  /**
   * Convert an AI SDK chunk into a Cherry Studio chunk and invoke the callback.
   * @param chunk the AI SDK chunk data
   */
  private convertAndEmitChunk(
    chunk: TextStreamPart<any>,
    final: { text: string; reasoningContent: string; webSearchResults: any[]; reasoningId: string }
  ) {
    logger.info(`AI SDK chunk type: ${chunk.type}`, chunk)
    switch (chunk.type) {
      // === Text events ===
      case 'text-start':
        this.onChunk({
          type: ChunkType.TEXT_START
        })
        break
      case 'text-delta':
        final.text += chunk.text || ''
        this.onChunk({
          type: ChunkType.TEXT_DELTA,
          // NOTE(review): emits the accumulated text so far, not just this delta —
          // confirm downstream consumers expect cumulative text here
          text: final.text || ''
        })
        break
      case 'text-end':
        this.onChunk({
          type: ChunkType.TEXT_COMPLETE,
          // Prefer provider-supplied final text, falling back to the accumulator
          text: (chunk.providerMetadata?.text?.value as string) ?? final.text ?? ''
        })
        final.text = ''
        break
      case 'reasoning-start':
        // if (final.reasoningId !== chunk.id) {
        final.reasoningId = chunk.id
        this.onChunk({
          type: ChunkType.THINKING_START
        })
        // }
        break
      case 'reasoning-delta':
        final.reasoningContent += chunk.text || ''
        this.onChunk({
          type: ChunkType.THINKING_DELTA,
          text: final.reasoningContent || '',
          thinking_millsec: (chunk.providerMetadata?.metadata?.thinking_millsec as number) || 0
        })
        break
      case 'reasoning-end':
        this.onChunk({
          type: ChunkType.THINKING_COMPLETE,
          text: (chunk.providerMetadata?.metadata?.thinking_content as string) || '',
          thinking_millsec: (chunk.providerMetadata?.metadata?.thinking_millsec as number) || 0
        })
        final.reasoningContent = ''
        break
      // === Tool-call events (raw AI SDK events, when not handled by middleware) ===
      // case 'tool-input-start':
      // case 'tool-input-delta':
      // case 'tool-input-end':
      //   this.toolCallHandler.handleToolCallCreated(chunk)
      //   break
      // case 'tool-input-delta':
      //   this.toolCallHandler.handleToolCallCreated(chunk)
      //   break
      case 'tool-call':
        // Raw tool call (not processed by middleware)
        this.toolCallHandler.handleToolCall(chunk)
        break
      case 'tool-result':
        // Raw tool-call result (not processed by middleware)
        this.toolCallHandler.handleToolResult(chunk)
        break
      // === Step events ===
      // case 'start':
      //   this.onChunk({
      //     type: ChunkType.LLM_RESPONSE_CREATED
      //   })
      //   break
      // TODO: distinguish request start from step start
      // case 'start-step':
      //   this.onChunk({
      //     type: ChunkType.BLOCK_CREATED
      //   })
      //   break
      // case 'step-finish':
      //   this.onChunk({
      //     type: ChunkType.TEXT_COMPLETE,
      //     text: final.text || '' // TEXT_COMPLETE requires a text field
      //   })
      //   final.text = ''
      //   break
      case 'finish-step': {
        const { providerMetadata, finishReason } = chunk
        // Google web search grounding result
        if (providerMetadata?.google?.groundingMetadata) {
          this.onChunk({
            type: ChunkType.LLM_WEB_SEARCH_COMPLETE,
            llm_web_search: {
              results: providerMetadata.google?.groundingMetadata as WebSearchResults,
              source: WebSearchSource.GEMINI
            }
          })
        } else if (final.webSearchResults.length) {
          // Attribute accumulated 'source' parts to the provider that produced them
          const providerName = Object.keys(providerMetadata || {})[0]
          switch (providerName) {
            case WebSearchSource.OPENAI:
              this.onChunk({
                type: ChunkType.LLM_WEB_SEARCH_COMPLETE,
                llm_web_search: {
                  results: final.webSearchResults,
                  source: WebSearchSource.OPENAI_RESPONSE
                }
              })
              break
            default:
              this.onChunk({
                type: ChunkType.LLM_WEB_SEARCH_COMPLETE,
                llm_web_search: {
                  results: final.webSearchResults,
                  source: WebSearchSource.AISDK
                }
              })
              break
          }
        }
        // A tool-call finish means another model round will follow
        if (finishReason === 'tool-calls') {
          this.onChunk({ type: ChunkType.LLM_RESPONSE_CREATED })
        }
        final.webSearchResults = []
        // final.reasoningId = ''
        break
      }
      case 'finish':
        this.onChunk({
          type: ChunkType.BLOCK_COMPLETE,
          response: {
            text: final.text || '',
            reasoning_content: final.reasoningContent || '',
            usage: {
              completion_tokens: chunk.totalUsage.outputTokens || 0,
              prompt_tokens: chunk.totalUsage.inputTokens || 0,
              total_tokens: chunk.totalUsage.totalTokens || 0
            },
            metrics: chunk.totalUsage
              ? {
                  completion_tokens: chunk.totalUsage.outputTokens || 0,
                  time_completion_millsec: 0
                }
              : undefined
          }
        })
        this.onChunk({
          type: ChunkType.LLM_RESPONSE_COMPLETE,
          response: {
            text: final.text || '',
            reasoning_content: final.reasoningContent || '',
            usage: {
              completion_tokens: chunk.totalUsage.outputTokens || 0,
              prompt_tokens: chunk.totalUsage.inputTokens || 0,
              total_tokens: chunk.totalUsage.totalTokens || 0
            },
            metrics: chunk.totalUsage
              ? {
                  completion_tokens: chunk.totalUsage.outputTokens || 0,
                  time_completion_millsec: 0
                }
              : undefined
          }
        })
        break
      // === Source and file events ===
      case 'source':
        if (chunk.sourceType === 'url') {
          // Accumulate URL sources; they are flushed on 'finish-step'
          // if (final.webSearchResults.length === 0) {
          // eslint-disable-next-line @typescript-eslint/no-unused-vars
          const { sourceType: _, ...rest } = chunk
          final.webSearchResults.push(rest)
          // }
          // this.onChunk({
          //   type: ChunkType.LLM_WEB_SEARCH_COMPLETE,
          //   llm_web_search: {
          //     source: WebSearchSource.AISDK,
          //     results: final.webSearchResults
          //   }
          // })
        }
        break
      case 'file':
        // File event — possibly generated image data
        this.onChunk({
          type: ChunkType.IMAGE_COMPLETE,
          image: {
            type: 'base64',
            images: [`data:${chunk.file.mediaType};base64,${chunk.file.base64}`]
          }
        })
        break
      case 'abort':
        this.onChunk({
          type: ChunkType.ERROR,
          error: new DOMException('Request was aborted', 'AbortError')
        })
        break
      case 'error':
        this.onChunk({
          type: ChunkType.ERROR,
          error: chunk.error as Record<string, any>
        })
        break
      default:
        // Other chunk types can be ignored or logged
        // console.log('Unhandled AI SDK chunk type:', chunk.type, chunk)
    }
  }
}
export default AiSdkToChunkAdapter

View File

@@ -0,0 +1,266 @@
/**
 * Tool-call chunk handling module.
 * TODO: "Tool" currently covers provider tools, plain tools and MCP tools; a refactor is needed.
 * Provides tool-call handling APIs; use a fresh instance per interaction.
 */
import { loggerService } from '@logger'
import { BaseTool, MCPTool, MCPToolResponse, NormalToolResponse } from '@renderer/types'
import { Chunk, ChunkType } from '@renderer/types/chunk'
import type { ProviderMetadata, ToolSet, TypedToolCall, TypedToolResult } from 'ai'
// import type {
// AnthropicSearchOutput,
// WebSearchPluginConfig
// } from '@cherrystudio/ai-core/core/plugins/built-in/webSearchPlugin'
const logger = loggerService.withContext('ToolCallChunkHandler')
/**
* 工具调用处理器类
*/
export class ToolCallChunkHandler {
// private onChunk: (chunk: Chunk) => void
private activeToolCalls = new Map<
string,
{
toolCallId: string
toolName: string
args: any
// mcpTool 现在可以是 MCPTool 或我们为 Provider 工具创建的通用类型
tool: BaseTool
}
>()
constructor(
private onChunk: (chunk: Chunk) => void,
private mcpTools: MCPTool[]
) {}
// /**
// * 设置 onChunk 回调
// */
// public setOnChunk(callback: (chunk: Chunk) => void): void {
// this.onChunk = callback
// }
handleToolCallCreated(
chunk:
| {
type: 'tool-input-start'
id: string
toolName: string
providerMetadata?: ProviderMetadata
providerExecuted?: boolean
}
| {
type: 'tool-input-end'
id: string
providerMetadata?: ProviderMetadata
}
| {
type: 'tool-input-delta'
id: string
delta: string
providerMetadata?: ProviderMetadata
}
): void {
switch (chunk.type) {
case 'tool-input-start': {
// 能拿到说明是mcpTool
// if (this.activeToolCalls.get(chunk.id)) return
const tool: BaseTool | MCPTool = {
id: chunk.id,
name: chunk.toolName,
description: chunk.toolName,
type: chunk.toolName.startsWith('builtin_') ? 'builtin' : 'provider'
}
this.activeToolCalls.set(chunk.id, {
toolCallId: chunk.id,
toolName: chunk.toolName,
args: '',
tool
})
const toolResponse: MCPToolResponse | NormalToolResponse = {
id: chunk.id,
tool: tool,
arguments: {},
status: 'pending',
toolCallId: chunk.id
}
this.onChunk({
type: ChunkType.MCP_TOOL_PENDING,
responses: [toolResponse]
})
break
}
case 'tool-input-delta': {
const toolCall = this.activeToolCalls.get(chunk.id)
if (!toolCall) {
logger.warn(`🔧 [ToolCallChunkHandler] Tool call not found: ${chunk.id}`)
return
}
toolCall.args += chunk.delta
break
}
case 'tool-input-end': {
const toolCall = this.activeToolCalls.get(chunk.id)
this.activeToolCalls.delete(chunk.id)
if (!toolCall) {
logger.warn(`🔧 [ToolCallChunkHandler] Tool call not found: ${chunk.id}`)
return
}
// const toolResponse: ToolCallResponse = {
// id: toolCall.toolCallId,
// tool: toolCall.tool,
// arguments: toolCall.args,
// status: 'pending',
// toolCallId: toolCall.toolCallId
// }
// logger.debug('toolResponse', toolResponse)
// this.onChunk({
// type: ChunkType.MCP_TOOL_PENDING,
// responses: [toolResponse]
// })
break
}
}
// if (!toolCall) {
// Logger.warn(`🔧 [ToolCallChunkHandler] Tool call not found: ${chunk.id}`)
// return
// }
// this.onChunk({
// type: ChunkType.MCP_TOOL_CREATED,
// tool_calls: [
// {
// id: chunk.id,
// name: chunk.toolName,
// status: 'pending'
// }
// ]
// })
}
/**
 * Handle a `tool-call` event emitted by the AI SDK stream.
 *
 * Resolves which kind of tool the call targets (provider-executed, builtin,
 * client-side MCP, or a provider fallback), records it in the active-call map
 * so the matching `tool-result` can be correlated later, and emits a
 * MCP_TOOL_PENDING chunk to the renderer.
 */
public handleToolCall(
  chunk: {
    type: 'tool-call'
  } & TypedToolCall<ToolSet>
): void {
  const { toolCallId, toolName, input: args, providerExecuted } = chunk
  if (!toolCallId || !toolName) {
    logger.warn(`🔧 [ToolCallChunkHandler] Invalid tool call chunk: missing toolCallId or toolName`)
    return
  }
  // Resolution order mirrors execution ownership:
  // provider-executed > builtin > client-side MCP > provider fallback.
  let tool: BaseTool
  const matchedMcpTool =
    providerExecuted || toolName.startsWith('builtin_')
      ? undefined
      : (this.mcpTools.find((t) => t.name === toolName) as MCPTool)
  if (providerExecuted) {
    // Tool executed by the provider itself (e.g. web_search).
    logger.info(`[ToolCallChunkHandler] Handling provider-executed tool: ${toolName}`)
    tool = {
      id: toolCallId,
      name: toolName,
      description: toolName,
      type: 'provider'
    } as BaseTool
  } else if (toolName.startsWith('builtin_')) {
    // Builtin tools keep the existing handling path.
    logger.info(`[ToolCallChunkHandler] Handling builtin tool: ${toolName}`)
    tool = {
      id: toolCallId,
      name: toolName,
      description: toolName,
      type: 'builtin'
    } as BaseTool
  } else if (matchedMcpTool) {
    // Client-side MCP tool found in the registered tool list.
    logger.info(`[ToolCallChunkHandler] Handling client-side MCP tool: ${toolName}`)
    tool = matchedMcpTool
  } else {
    // Unknown name: fall back to a provider-typed descriptor.
    tool = {
      id: toolCallId,
      name: toolName,
      description: toolName,
      type: 'provider'
    }
  }
  // Track the active call so handleToolResult can look it up by id.
  this.activeToolCalls.set(toolCallId, {
    toolCallId,
    toolName,
    args,
    tool
  })
  // Build the MCPToolResponse-shaped payload for the pending notification.
  const pendingResponse: MCPToolResponse | NormalToolResponse = {
    id: toolCallId,
    tool: tool,
    arguments: args,
    status: 'pending',
    toolCallId: toolCallId
  }
  // Notify the renderer, if a chunk sink was provided.
  if (this.onChunk) {
    this.onChunk({
      type: ChunkType.MCP_TOOL_PENDING,
      responses: [pendingResponse]
    })
  }
}
/**
 * Handle a `tool-result` event emitted by the AI SDK stream.
 *
 * Correlates the result with the pending call recorded by handleToolCall,
 * emits a MCP_TOOL_COMPLETE chunk, and removes the call from the active map.
 */
public handleToolResult(
  chunk: {
    type: 'tool-result'
  } & TypedToolResult<ToolSet>
): void {
  const { toolCallId, output, input } = chunk
  if (!toolCallId) {
    logger.warn(`🔧 [ToolCallChunkHandler] Invalid tool result chunk: missing toolCallId`)
    return
  }
  // Look up the pending call recorded when the tool-call event arrived.
  const pending = this.activeToolCalls.get(toolCallId)
  if (!pending) {
    logger.warn(`🔧 [ToolCallChunkHandler] Tool call info not found for ID: ${toolCallId}`)
    return
  }
  // Build the completed MCPToolResponse-shaped payload.
  const completedResponse: MCPToolResponse | NormalToolResponse = {
    id: pending.toolCallId,
    tool: pending.tool,
    arguments: input,
    status: 'done',
    response: output,
    toolCallId: toolCallId
  }
  // The call is finished; drop it from the active map (the whole handler
  // instance is discarded once the interaction ends anyway).
  this.activeToolCalls.delete(toolCallId)
  // Notify the renderer, if a chunk sink was provided.
  if (this.onChunk) {
    this.onChunk({
      type: ChunkType.MCP_TOOL_COMPLETE,
      responses: [completedResponse]
    })
  }
}
}

View File

@@ -1,189 +1,16 @@
import { loggerService } from '@logger'
import { ApiClientFactory } from '@renderer/aiCore/clients/ApiClientFactory'
import { BaseApiClient } from '@renderer/aiCore/clients/BaseApiClient'
import { isDedicatedImageGenerationModel, isFunctionCallingModel } from '@renderer/config/models'
import { getProviderByModel } from '@renderer/services/AssistantService'
import { withSpanResult } from '@renderer/services/SpanManagerService'
import { StartSpanParams } from '@renderer/trace/types/ModelSpanEntity'
import type { GenerateImageParams, Model, Provider } from '@renderer/types'
import type { RequestOptions, SdkModel } from '@renderer/types/sdk'
import { isEnabledToolUse } from '@renderer/utils/mcp-tools'
/**
* Cherry Studio AI Core - 统一入口点
*
* 这是新的统一入口,保持向后兼容性
* 默认导出legacy AiProvider以保持现有代码的兼容性
*/
import { AihubmixAPIClient } from './clients/AihubmixAPIClient'
import { VertexAPIClient } from './clients/gemini/VertexAPIClient'
import { NewAPIClient } from './clients/NewAPIClient'
import { OpenAIResponseAPIClient } from './clients/openai/OpenAIResponseAPIClient'
import { CompletionsMiddlewareBuilder } from './middleware/builder'
import { MIDDLEWARE_NAME as AbortHandlerMiddlewareName } from './middleware/common/AbortHandlerMiddleware'
import { MIDDLEWARE_NAME as ErrorHandlerMiddlewareName } from './middleware/common/ErrorHandlerMiddleware'
import { MIDDLEWARE_NAME as FinalChunkConsumerMiddlewareName } from './middleware/common/FinalChunkConsumerMiddleware'
import { applyCompletionsMiddlewares } from './middleware/composer'
import { MIDDLEWARE_NAME as McpToolChunkMiddlewareName } from './middleware/core/McpToolChunkMiddleware'
import { MIDDLEWARE_NAME as RawStreamListenerMiddlewareName } from './middleware/core/RawStreamListenerMiddleware'
import { MIDDLEWARE_NAME as WebSearchMiddlewareName } from './middleware/core/WebSearchMiddleware'
import { MIDDLEWARE_NAME as ImageGenerationMiddlewareName } from './middleware/feat/ImageGenerationMiddleware'
import { MIDDLEWARE_NAME as ThinkingTagExtractionMiddlewareName } from './middleware/feat/ThinkingTagExtractionMiddleware'
import { MIDDLEWARE_NAME as ToolUseExtractionMiddlewareName } from './middleware/feat/ToolUseExtractionMiddleware'
import { MiddlewareRegistry } from './middleware/register'
import type { CompletionsParams, CompletionsResult } from './middleware/schemas'
// 导出Legacy AiProvider作为默认导出保持向后兼容
export { default } from './legacy/index'
const logger = loggerService.withContext('AiProvider')
// 同时导出Modern AiProvider供新代码使用
export { default as ModernAiProvider } from './index_new'
export default class AiProvider {
  // Concrete client created by ApiClientFactory for this provider.
  private apiClient: BaseApiClient

  constructor(provider: Provider) {
    // Use the new ApiClientFactory to get a BaseApiClient instance
    this.apiClient = ApiClientFactory.create(provider)
  }

  /**
   * Run a completions request through the middleware pipeline.
   *
   * Steps: (1) pick the concrete client for the model, (2) assemble the
   * middleware chain based on model/params, (3) wrap the client's
   * createCompletions with the chain, (4) execute it.
   */
  public async completions(params: CompletionsParams, options?: RequestOptions): Promise<CompletionsResult> {
    // 1. Identify the correct client for the model
    const model = params.assistant.model
    if (!model) {
      return Promise.reject(new Error('Model is required'))
    }
    // Choose how to dispatch based on the concrete client type. Aggregator
    // clients (Aihubmix/NewAPI) pick a sub-client per model; response-style
    // clients may be further unwrapped to the underlying API client.
    let client: BaseApiClient
    if (this.apiClient instanceof AihubmixAPIClient) {
      // AihubmixAPIClient: pick the appropriate sub-client for this model
      client = this.apiClient.getClientForModel(model)
      if (client instanceof OpenAIResponseAPIClient) {
        client = client.getClient(model) as BaseApiClient
      }
    } else if (this.apiClient instanceof NewAPIClient) {
      client = this.apiClient.getClientForModel(model)
      if (client instanceof OpenAIResponseAPIClient) {
        client = client.getClient(model) as BaseApiClient
      }
    } else if (this.apiClient instanceof OpenAIResponseAPIClient) {
      // OpenAIResponseAPIClient: choose the API flavor from model traits
      client = this.apiClient.getClient(model) as BaseApiClient
    } else if (this.apiClient instanceof VertexAPIClient) {
      client = this.apiClient.getClient(model) as BaseApiClient
    } else {
      // Any other client is used directly
      client = this.apiClient
    }
    // 2. Build the middleware chain
    const builder = CompletionsMiddlewareBuilder.withDefaults()
    // Dedicated image models use a minimal chain ending in image generation.
    if (isDedicatedImageGenerationModel(model)) {
      builder.clear()
      builder
        .add(MiddlewareRegistry[FinalChunkConsumerMiddlewareName])
        .add(MiddlewareRegistry[ErrorHandlerMiddlewareName])
        .add(MiddlewareRegistry[AbortHandlerMiddlewareName])
        .add(MiddlewareRegistry[ImageGenerationMiddlewareName])
    } else {
      // Existing logic for other models: start from the defaults and prune
      // middlewares that do not apply to this client/params combination.
      logger.silly('Builder Params', params)
      // Use compatibility-type strings to avoid TypeScript narrowing issues
      // introduced by the decorator pattern on wrapped clients.
      const clientTypes = client.getClientCompatibilityType(model)
      const isOpenAICompatible =
        clientTypes.includes('OpenAIAPIClient') || clientTypes.includes('OpenAIResponseAPIClient')
      if (!isOpenAICompatible) {
        logger.silly('ThinkingTagExtractionMiddleware is removed')
        builder.remove(ThinkingTagExtractionMiddlewareName)
      }
      const isAnthropicOrOpenAIResponseCompatible =
        clientTypes.includes('AnthropicAPIClient') ||
        clientTypes.includes('OpenAIResponseAPIClient') ||
        clientTypes.includes('AnthropicVertexAPIClient')
      if (!isAnthropicOrOpenAIResponseCompatible) {
        logger.silly('RawStreamListenerMiddleware is removed')
        builder.remove(RawStreamListenerMiddlewareName)
      }
      if (!params.enableWebSearch) {
        logger.silly('WebSearchMiddleware is removed')
        builder.remove(WebSearchMiddlewareName)
      }
      // No MCP tools: both tool-related middlewares are unnecessary.
      if (!params.mcpTools?.length) {
        builder.remove(ToolUseExtractionMiddlewareName)
        logger.silly('ToolUseExtractionMiddleware is removed')
        builder.remove(McpToolChunkMiddlewareName)
        logger.silly('McpToolChunkMiddleware is removed')
      }
      // Native function calling: prompt-based tool-use extraction not needed.
      if (isEnabledToolUse(params.assistant) && isFunctionCallingModel(model)) {
        builder.remove(ToolUseExtractionMiddlewareName)
        logger.silly('ToolUseExtractionMiddleware is removed')
      }
      // Abort handling only applies to interactive call types.
      if (params.callType !== 'chat' && params.callType !== 'check' && params.callType !== 'translate') {
        logger.silly('AbortHandlerMiddleware is removed')
        builder.remove(AbortHandlerMiddlewareName)
      }
      // Test calls bypass error handling and final chunk consumption.
      if (params.callType === 'test') {
        builder.remove(ErrorHandlerMiddlewareName)
        logger.silly('ErrorHandlerMiddleware is removed')
        builder.remove(FinalChunkConsumerMiddlewareName)
        logger.silly('FinalChunkConsumerMiddleware is removed')
      }
    }
    const middlewares = builder.build()
    logger.silly(
      'middlewares',
      middlewares.map((m) => m.name)
    )
    // 3. Create the wrapped SDK method with middlewares
    const wrappedCompletionMethod = applyCompletionsMiddlewares(client, client.createCompletions, middlewares)
    // 4. Execute the wrapped method with the original params
    const result = wrappedCompletionMethod(params, options)
    return result
  }

  /**
   * Same as completions(), but wrapped in a trace span named after the model
   * and call type so the request shows up in the span tree.
   */
  public async completionsForTrace(params: CompletionsParams, options?: RequestOptions): Promise<CompletionsResult> {
    const traceName = params.assistant.model?.name
      ? `${params.assistant.model?.name}.${params.callType}`
      : `LLM.${params.callType}`
    const traceParams: StartSpanParams = {
      name: traceName,
      tag: 'LLM',
      topicId: params.topicId || '',
      modelName: params.assistant.model?.name
    }
    return await withSpanResult(this.completions.bind(this), traceParams, params, options)
  }

  /** List the models available from the underlying API client. */
  public async models(): Promise<SdkModel[]> {
    return this.apiClient.listModels()
  }

  /**
   * Probe the embedding dimensionality of a model.
   * Rethrows any probe failure after logging it.
   */
  public async getEmbeddingDimensions(model: Model): Promise<number> {
    try {
      // Use the SDK instance to test embedding capabilities.
      // Azure OpenAI must be unwrapped to the concrete client first.
      if (this.apiClient instanceof OpenAIResponseAPIClient && getProviderByModel(model).type === 'azure-openai') {
        this.apiClient = this.apiClient.getClient(model) as BaseApiClient
      }
      const dimensions = await this.apiClient.getEmbeddingDimensions(model)
      return dimensions
    } catch (error) {
      logger.error('Error getting embedding dimensions:', error as Error)
      throw error
    }
  }

  /** Generate images, dispatching to the per-model sub-client for Aihubmix. */
  public async generateImage(params: GenerateImageParams): Promise<string[]> {
    if (this.apiClient instanceof AihubmixAPIClient) {
      const client = this.apiClient.getClientForModel({ id: params.model } as Model)
      return client.generateImage(params)
    }
    return this.apiClient.generateImage(params)
  }

  /** Base URL of the underlying API client. */
  public getBaseURL(): string {
    return this.apiClient.getBaseURL()
  }

  /** API key of the underlying API client. */
  public getApiKey(): string {
    return this.apiClient.getApiKey()
  }
}
// 导出一些常用的类型和工具
export * from './legacy/clients/types'
export * from './legacy/middleware/schemas'

View File

@@ -0,0 +1,518 @@
/**
* Cherry Studio AI Core - 新版本入口
* 集成 @cherrystudio/ai-core 库的渐进式重构方案
*
* 融合方案:简化实现,专注于核心功能
* 1. 优先使用新AI SDK
* 2. 暂时保持接口兼容性
*/
import { createExecutor, generateImage } from '@cherrystudio/ai-core'
import { createAndRegisterProvider } from '@cherrystudio/ai-core/provider'
import { loggerService } from '@logger'
import { isNotSupportedImageSizeModel } from '@renderer/config/models'
import { getEnableDeveloperMode } from '@renderer/hooks/useSettings'
import { addSpan, endSpan } from '@renderer/services/SpanManagerService'
import { StartSpanParams } from '@renderer/trace/types/ModelSpanEntity'
import type { Assistant, GenerateImageParams, Model, Provider } from '@renderer/types'
import { ChunkType } from '@renderer/types/chunk'
import AiSdkToChunkAdapter from './chunk/AiSdkToChunkAdapter'
import LegacyAiProvider from './legacy/index'
import { CompletionsResult } from './legacy/middleware/schemas'
import { AiSdkMiddlewareConfig, buildAiSdkMiddlewares } from './middleware/AiSdkMiddlewareBuilder'
import { buildPlugins } from './plugins/PluginBuilder'
import {
getActualProvider,
isModernSdkSupported,
prepareSpecialProviderConfig,
providerToAiSdkConfig
} from './provider/providerConfig'
import type { StreamTextParams } from './types'
const logger = loggerService.withContext('ModernAiProvider')
export default class ModernAiProvider {
  // Legacy provider kept for methods not yet migrated to the new AI SDK.
  private legacyProvider: LegacyAiProvider
  // AI SDK config derived from the provider; executors are created lazily.
  private config: ReturnType<typeof providerToAiSdkConfig>
  private actualProvider: Provider

  constructor(model: Model, provider?: Provider) {
    this.actualProvider = provider || getActualProvider(model)
    this.legacyProvider = new LegacyAiProvider(this.actualProvider)
    // Only store the config; do not pre-create an executor.
    this.config = providerToAiSdkConfig(this.actualProvider, model)
  }

  public getActualProvider() {
    return this.actualProvider
  }

  /**
   * Entry point for completions. Prepares provider-specific config, then
   * routes to the traced path (developer mode + topicId) or the plain path.
   */
  public async completions(
    modelId: string,
    params: StreamTextParams,
    config: AiSdkMiddlewareConfig & {
      assistant: Assistant
      // topicId for tracing
      topicId?: string
      callType: string
    }
  ) {
    // Prepare special provider configuration (auth material, endpoints, ...)
    await prepareSpecialProviderConfig(this.actualProvider, this.config)
    // Fixed: this previously dumped the full resolved config via console.log;
    // the config may carry credentials, so log only the provider id through
    // the scoped logger instead.
    logger.debug('Prepared special provider config', { providerId: this.config.providerId })
    if (config.topicId && getEnableDeveloperMode()) {
      // TypeScript narrowing: guarantee topicId is a string on the trace path
      const traceConfig = {
        ...config,
        topicId: config.topicId
      }
      return await this._completionsForTrace(modelId, params, traceConfig)
    } else {
      return await this._completions(modelId, params, config)
    }
  }

  /**
   * Plain completions path: register the provider globally (idempotent), then
   * dispatch to image generation or text completion.
   */
  private async _completions(
    modelId: string,
    params: StreamTextParams,
    config: AiSdkMiddlewareConfig & {
      assistant: Assistant
      // topicId for tracing
      topicId?: string
      callType: string
    }
  ): Promise<CompletionsResult> {
    // Register the provider with the global manager
    try {
      await createAndRegisterProvider(this.config.providerId, this.config.options)
      logger.debug('Provider initialized successfully', {
        providerId: this.config.providerId,
        hasOptions: !!this.config.options
      })
    } catch (error) {
      // Registering an already-initialized provider may throw; safe to ignore.
      logger.debug('Provider initialization skipped (may already be initialized)', {
        providerId: this.config.providerId,
        error: error instanceof Error ? error.message : String(error)
      })
    }
    if (config.isImageGenerationEndpoint) {
      return await this.modernImageGeneration(modelId, params, config)
    }
    return await this.modernCompletions(modelId, params, config)
  }

  /**
   * Trace-enabled completions, mirroring legacy completionsForTrace: creates
   * a parent span so AI SDK spans attach to the correct trace context, and
   * ends the span with outputs or the thrown error.
   */
  private async _completionsForTrace(
    modelId: string,
    params: StreamTextParams,
    config: AiSdkMiddlewareConfig & {
      assistant: Assistant
      // topicId for tracing
      topicId: string
      callType: string
    }
  ): Promise<CompletionsResult> {
    const traceName = `${this.actualProvider.name}.${modelId}.${config.callType}`
    const traceParams: StartSpanParams = {
      name: traceName,
      tag: 'LLM',
      topicId: config.topicId,
      modelName: config.assistant.model?.name,
      inputs: params
    }
    logger.info('Starting AI SDK trace span', {
      traceName,
      topicId: config.topicId,
      modelId,
      hasTools: !!params.tools && Object.keys(params.tools).length > 0,
      toolNames: params.tools ? Object.keys(params.tools) : [],
      isImageGeneration: config.isImageGenerationEndpoint
    })
    const span = addSpan(traceParams)
    if (!span) {
      // Tracing is best-effort: degrade to the plain path on span failure.
      logger.warn('Failed to create span, falling back to regular completions', {
        topicId: config.topicId,
        modelId,
        traceName
      })
      return await this._completions(modelId, params, config)
    }
    try {
      logger.info('Created parent span, now calling completions', {
        spanId: span.spanContext().spanId,
        traceId: span.spanContext().traceId,
        topicId: config.topicId,
        modelId,
        parentSpanCreated: true
      })
      const result = await this._completions(modelId, params, config)
      logger.info('Completions finished, ending parent span', {
        spanId: span.spanContext().spanId,
        traceId: span.spanContext().traceId,
        topicId: config.topicId,
        modelId,
        resultLength: result.getText().length
      })
      // Mark the span as successfully completed
      endSpan({
        topicId: config.topicId,
        outputs: result,
        span,
        modelName: modelId // use modelId for consistency
      })
      return result
    } catch (error) {
      logger.error('Error in completionsForTrace, ending parent span with error', error as Error, {
        spanId: span.spanContext().spanId,
        traceId: span.spanContext().traceId,
        topicId: config.topicId,
        modelId
      })
      // Mark the span as failed
      endSpan({
        topicId: config.topicId,
        error: error as Error,
        span,
        modelName: modelId // use modelId for consistency
      })
      throw error
    }
  }

  /**
   * Text completions via the modern AI SDK. Builds plugins and middlewares
   * from the config, then streams — through the chunk adapter when an onChunk
   * callback is provided, otherwise consuming the stream for the final text.
   */
  private async modernCompletions(
    modelId: string,
    params: StreamTextParams,
    config: AiSdkMiddlewareConfig & {
      assistant: Assistant
      topicId?: string
      callType: string
    }
  ): Promise<CompletionsResult> {
    logger.info('Starting modernCompletions', {
      modelId,
      providerId: this.config.providerId,
      topicId: config.topicId,
      hasOnChunk: !!config.onChunk,
      hasTools: !!params.tools && Object.keys(params.tools).length > 0,
      toolCount: params.tools ? Object.keys(params.tools).length : 0
    })
    // Build the plugin list from the call config
    const plugins = buildPlugins(config)
    logger.debug('Built plugins for AI SDK', {
      pluginCount: plugins.length,
      pluginNames: plugins.map((p) => p.name),
      providerId: this.config.providerId,
      topicId: config.topicId
    })
    // Create an executor with the assembled plugins
    const executor = createExecutor(this.config.providerId, this.config.options, plugins)
    logger.debug('Created AI SDK executor', {
      providerId: this.config.providerId,
      hasOptions: !!this.config.options,
      pluginCount: plugins.length
    })
    // Build the middleware list dynamically
    const middlewares = buildAiSdkMiddlewares(config)
    logger.debug('Built AI SDK middlewares', {
      middlewareCount: middlewares.length,
      topicId: config.topicId
    })
    if (config.onChunk) {
      // Streaming with a chunk callback — route through the adapter
      logger.info('Starting streaming with chunk adapter', {
        modelId,
        hasMiddlewares: middlewares.length > 0,
        middlewareCount: middlewares.length,
        hasMcpTools: !!config.mcpTools,
        mcpToolCount: config.mcpTools?.length || 0,
        topicId: config.topicId
      })
      const adapter = new AiSdkToChunkAdapter(config.onChunk, config.mcpTools)
      logger.debug('Final params before streamText', {
        modelId,
        hasMessages: !!params.messages,
        messageCount: params.messages?.length || 0,
        hasTools: !!params.tools && Object.keys(params.tools).length > 0,
        toolNames: params.tools ? Object.keys(params.tools) : [],
        hasSystem: !!params.system,
        topicId: config.topicId
      })
      const streamResult = await executor.streamText(
        modelId,
        { ...params, experimental_context: { onChunk: config.onChunk } },
        middlewares.length > 0 ? { middlewares } : undefined
      )
      logger.info('StreamText call successful, processing stream', {
        modelId,
        topicId: config.topicId,
        hasFullStream: !!streamResult.fullStream
      })
      const finalText = await adapter.processStream(streamResult)
      logger.info('Stream processing completed', {
        modelId,
        topicId: config.topicId,
        finalTextLength: finalText.length
      })
      return {
        getText: () => finalText
      }
    } else {
      // Streaming without an onChunk callback
      logger.info('Starting streaming without chunk callback', {
        modelId,
        hasMiddlewares: middlewares.length > 0,
        middlewareCount: middlewares.length,
        topicId: config.topicId
      })
      const streamResult = await executor.streamText(
        modelId,
        params,
        middlewares.length > 0 ? { middlewares } : undefined
      )
      logger.info('StreamText call successful, waiting for text', {
        modelId,
        topicId: config.topicId
      })
      // Force-consume the stream, otherwise awaiting streamResult.text blocks
      await streamResult?.consumeStream()
      const finalText = await streamResult.text
      logger.info('Text extraction completed', {
        modelId,
        topicId: config.topicId,
        finalTextLength: finalText.length
      })
      return {
        getText: () => finalText
      }
    }
  }

  /**
   * Image generation for the chat endpoint via the modern AI SDK, with
   * streaming-style chunk events (IMAGE_CREATED / IMAGE_COMPLETE /
   * BLOCK_COMPLETE / LLM_RESPONSE_COMPLETE) for the renderer.
   */
  private async modernImageGeneration(
    modelId: string,
    params: StreamTextParams,
    config: AiSdkMiddlewareConfig & {
      assistant: Assistant
      // topicId for tracing
      topicId?: string
      callType: string
    }
  ): Promise<CompletionsResult> {
    const { onChunk } = config
    try {
      // Messages are required to derive a prompt
      if (!params.messages || params.messages.length === 0) {
        throw new Error('No messages provided for image generation.')
      }
      // Extract the prompt from the last user message
      const lastUserMessage = params.messages.findLast((m) => m.role === 'user')
      if (!lastUserMessage) {
        throw new Error('No user message found for image generation.')
      }
      // Use message content directly to avoid type-conversion issues
      const prompt =
        typeof lastUserMessage.content === 'string'
          ? lastUserMessage.content
          : lastUserMessage.content?.map((part) => ('text' in part ? part.text : '')).join('') || ''
      if (!prompt) {
        throw new Error('No prompt found in user message.')
      }
      const startTime = Date.now()
      // Signal that image generation has started
      if (onChunk) {
        onChunk({ type: ChunkType.IMAGE_CREATED })
      }
      // Build image generation parameters
      const imageParams = {
        prompt,
        // Default size; omitted for models that do not accept a size param
        size: isNotSupportedImageSizeModel(config.model) ? undefined : ('1024x1024' as `${number}x${number}`),
        n: 1,
        ...(params.abortSignal && { abortSignal: params.abortSignal })
      }
      // Invoke the new AI SDK image generation
      const result = await generateImage(this.config.providerId, this.config.options, modelId, imageParams)
      // Convert results into data URLs
      const images: string[] = []
      const imageType: 'url' | 'base64' = 'base64'
      if (result.images) {
        for (const image of result.images) {
          if ('base64' in image && image.base64) {
            images.push(`data:${image.mediaType};base64,${image.base64}`)
          }
        }
      }
      // Emit the completed image(s)
      if (onChunk && images.length > 0) {
        onChunk({
          type: ChunkType.IMAGE_COMPLETE,
          image: { type: imageType, images }
        })
      }
      // Emit block/response completion events (mirrors modernCompletions)
      if (onChunk) {
        const usage = {
          prompt_tokens: prompt.length, // rough token estimate from prompt length
          completion_tokens: 0, // image generation has no completion tokens
          total_tokens: prompt.length
        }
        onChunk({
          type: ChunkType.BLOCK_COMPLETE,
          response: {
            usage,
            metrics: {
              completion_tokens: usage.completion_tokens,
              time_first_token_millsec: 0,
              time_completion_millsec: Date.now() - startTime
            }
          }
        })
        // Emit the LLM response completion event
        onChunk({
          type: ChunkType.LLM_RESPONSE_COMPLETE,
          response: {
            usage,
            metrics: {
              completion_tokens: usage.completion_tokens,
              time_first_token_millsec: 0,
              time_completion_millsec: Date.now() - startTime
            }
          }
        })
      }
      return {
        getText: () => '' // image generation returns no text
      }
    } catch (error) {
      // Surface the failure to the renderer before rethrowing
      if (onChunk) {
        onChunk({ type: ChunkType.ERROR, error: error as any })
      }
      throw error
    }
  }

  // Delegate remaining methods to the legacy implementation
  public async models() {
    return this.legacyProvider.models()
  }

  public async getEmbeddingDimensions(model: Model): Promise<number> {
    return this.legacyProvider.getEmbeddingDimensions(model)
  }

  /**
   * Generate images, preferring the modern AI SDK when supported and falling
   * back to the legacy provider on failure or lack of support.
   */
  public async generateImage(params: GenerateImageParams): Promise<string[]> {
    if (isModernSdkSupported(this.actualProvider)) {
      try {
        const result = await this.modernGenerateImage(params)
        return result
      } catch (error) {
        logger.warn('Modern AI SDK generateImage failed, falling back to legacy:', error as Error)
        // Fall back to the legacy implementation
        return this.legacyProvider.generateImage(params)
      }
    }
    // Use the legacy implementation directly
    return this.legacyProvider.generateImage(params)
  }

  /**
   * Image generation via the modern AI SDK for the standalone generateImage
   * API (as opposed to the chat endpoint path above).
   */
  private async modernGenerateImage(params: GenerateImageParams): Promise<string[]> {
    const { model, prompt, imageSize, batchSize, signal } = params
    // Convert to AI SDK parameter shape
    const aiSdkParams = {
      prompt,
      size: (imageSize || '1024x1024') as `${number}x${number}`,
      n: batchSize || 1,
      ...(signal && { abortSignal: signal })
    }
    const result = await generateImage(this.config.providerId, this.config.options, model, aiSdkParams)
    // Convert results into data URLs
    const images: string[] = []
    if (result.images) {
      for (const image of result.images) {
        if ('base64' in image && image.base64) {
          images.push(`data:image/png;base64,${image.base64}`)
        }
      }
    }
    return images
  }

  public getBaseURL(): string {
    return this.legacyProvider.getBaseURL()
  }

  public getApiKey(): string {
    return this.legacyProvider.getApiKey()
  }
}
// 为了方便调试,导出一些工具函数
export { isModernSdkSupported, providerToAiSdkConfig }

View File

@@ -1,4 +1,5 @@
import { loggerService } from '@logger'
import { isVertexAIConfigured } from '@renderer/hooks/useVertexAI'
import { Provider } from '@renderer/types'
import { AihubmixAPIClient } from './AihubmixAPIClient'
@@ -61,6 +62,13 @@ export class ApiClientFactory {
instance = new GeminiAPIClient(provider) as BaseApiClient
break
case 'vertexai':
logger.debug(`Creating VertexAPIClient for provider: ${provider.id}`)
// 检查 VertexAI 配置
if (!isVertexAIConfigured()) {
throw new Error(
'VertexAI is not configured. Please configure project, location and service account credentials.'
)
}
instance = new VertexAPIClient(provider) as BaseApiClient
break
case 'anthropic':

View File

@@ -1,11 +1,11 @@
import { AihubmixAPIClient } from '@renderer/aiCore/clients/AihubmixAPIClient'
import { AnthropicAPIClient } from '@renderer/aiCore/clients/anthropic/AnthropicAPIClient'
import { ApiClientFactory } from '@renderer/aiCore/clients/ApiClientFactory'
import { GeminiAPIClient } from '@renderer/aiCore/clients/gemini/GeminiAPIClient'
import { VertexAPIClient } from '@renderer/aiCore/clients/gemini/VertexAPIClient'
import { NewAPIClient } from '@renderer/aiCore/clients/NewAPIClient'
import { OpenAIAPIClient } from '@renderer/aiCore/clients/openai/OpenAIApiClient'
import { OpenAIResponseAPIClient } from '@renderer/aiCore/clients/openai/OpenAIResponseAPIClient'
import { AihubmixAPIClient } from '@renderer/aiCore/legacy/clients/AihubmixAPIClient'
import { AnthropicAPIClient } from '@renderer/aiCore/legacy/clients/anthropic/AnthropicAPIClient'
import { ApiClientFactory } from '@renderer/aiCore/legacy/clients/ApiClientFactory'
import { GeminiAPIClient } from '@renderer/aiCore/legacy/clients/gemini/GeminiAPIClient'
import { VertexAPIClient } from '@renderer/aiCore/legacy/clients/gemini/VertexAPIClient'
import { NewAPIClient } from '@renderer/aiCore/legacy/clients/NewAPIClient'
import { OpenAIAPIClient } from '@renderer/aiCore/legacy/clients/openai/OpenAIApiClient'
import { OpenAIResponseAPIClient } from '@renderer/aiCore/legacy/clients/openai/OpenAIResponseAPIClient'
import { EndpointType, Model, Provider } from '@renderer/types'
import { beforeEach, describe, expect, it, vi } from 'vitest'

View File

@@ -25,7 +25,6 @@ import {
import { MessageStream } from '@anthropic-ai/sdk/resources/messages/messages'
import AnthropicVertex from '@anthropic-ai/vertex-sdk'
import { loggerService } from '@logger'
import { GenericChunk } from '@renderer/aiCore/middleware/schemas'
import { DEFAULT_MAX_TOKENS } from '@renderer/config/constant'
import { findTokenLimit, isClaudeReasoningModel, isReasoningModel, isWebSearchModel } from '@renderer/config/models'
import { getAssistantSettings } from '@renderer/services/AssistantService'
@@ -64,13 +63,14 @@ import {
import { addImageFileToContents } from '@renderer/utils/formats'
import {
anthropicToolUseToMcpTool,
isEnabledToolUse,
isSupportedToolUse,
mcpToolCallResponseToAnthropicMessage,
mcpToolsToAnthropicTools
} from '@renderer/utils/mcp-tools'
import { findFileBlocks, findImageBlocks } from '@renderer/utils/messageUtils/find'
import { t } from 'i18next'
import { GenericChunk } from '../../middleware/schemas'
import { BaseApiClient } from '../BaseApiClient'
import { AnthropicStreamListener, RawStreamListener, RequestTransformer, ResponseChunkTransformer } from '../types'
@@ -457,7 +457,7 @@ export class AnthropicAPIClient extends BaseApiClient<
const { tools } = this.setupToolsConfig({
mcpTools: mcpTools,
model,
enableToolUse: isEnabledToolUse(assistant)
enableToolUse: isSupportedToolUse(assistant)
})
const systemMessage: TextBlockParam | undefined = systemPrompt

View File

@@ -6,7 +6,7 @@ import {
InvokeModelWithResponseStreamCommand
} from '@aws-sdk/client-bedrock-runtime'
import { loggerService } from '@logger'
import { GenericChunk } from '@renderer/aiCore/middleware/schemas'
import { GenericChunk } from '@renderer/aiCore/legacy/middleware/schemas'
import { DEFAULT_MAX_TOKENS } from '@renderer/config/constant'
import { findTokenLimit, isReasoningModel } from '@renderer/config/models'
import {
@@ -50,7 +50,7 @@ import {
import { convertBase64ImageToAwsBedrockFormat } from '@renderer/utils/aws-bedrock-utils'
import {
awsBedrockToolUseToMcpTool,
isEnabledToolUse,
isSupportedToolUse,
mcpToolCallResponseToAwsBedrockMessage,
mcpToolsToAwsBedrockTools
} from '@renderer/utils/mcp-tools'
@@ -739,7 +739,7 @@ export class AwsBedrockAPIClient extends BaseApiClient<
const { tools } = this.setupToolsConfig({
mcpTools: mcpTools,
model,
enableToolUse: isEnabledToolUse(assistant)
enableToolUse: isSupportedToolUse(assistant)
})
// 3. 处理消息

View File

@@ -18,7 +18,6 @@ import {
} from '@google/genai'
import { loggerService } from '@logger'
import { nanoid } from '@reduxjs/toolkit'
import { GenericChunk } from '@renderer/aiCore/middleware/schemas'
import {
findTokenLimit,
GEMINI_FLASH_MODEL_REGEX,
@@ -55,7 +54,7 @@ import {
import { isToolUseModeFunction } from '@renderer/utils/assistant'
import {
geminiFunctionCallToMcpTool,
isEnabledToolUse,
isSupportedToolUse,
mcpToolCallResponseToGeminiMessage,
mcpToolsToGeminiTools
} from '@renderer/utils/mcp-tools'
@@ -63,6 +62,7 @@ import { findFileBlocks, findImageBlocks, getMainTextContent } from '@renderer/u
import { defaultTimeout, MB } from '@shared/config/constant'
import { t } from 'i18next'
import { GenericChunk } from '../../middleware/schemas'
import { BaseApiClient } from '../BaseApiClient'
import { RequestTransformer, ResponseChunkTransformer } from '../types'
@@ -454,7 +454,7 @@ export class GeminiAPIClient extends BaseApiClient<
const { tools } = this.setupToolsConfig({
mcpTools,
model,
enableToolUse: isEnabledToolUse(assistant)
enableToolUse: isSupportedToolUse(assistant)
})
let messageContents: Content = { role: 'user', parts: [] } // Initialize messageContents

View File

@@ -1,7 +1,7 @@
import { GoogleGenAI } from '@google/genai'
import { loggerService } from '@logger'
import { getVertexAILocation, getVertexAIProjectId, getVertexAIServiceAccount } from '@renderer/hooks/useVertexAI'
import { Model, Provider } from '@renderer/types'
import { createVertexProvider, isVertexProvider } from '@renderer/hooks/useVertexAI'
import { Model, Provider, VertexProvider } from '@renderer/types'
import { isEmpty } from 'lodash'
import { AnthropicVertexClient } from '../anthropic/AnthropicVertexClient'
@@ -12,10 +12,17 @@ export class VertexAPIClient extends GeminiAPIClient {
private authHeaders?: Record<string, string>
private authHeadersExpiry?: number
private anthropicVertexClient: AnthropicVertexClient
private vertexProvider: VertexProvider
constructor(provider: Provider) {
super(provider)
this.anthropicVertexClient = new AnthropicVertexClient(provider)
// 如果传入的是普通 Provider转换为 VertexProvider
if (isVertexProvider(provider)) {
this.vertexProvider = provider
} else {
this.vertexProvider = createVertexProvider(provider)
}
}
override getClientCompatibilityType(model?: Model): string[] {
@@ -56,11 +63,9 @@ export class VertexAPIClient extends GeminiAPIClient {
return this.sdkInstance
}
const serviceAccount = getVertexAIServiceAccount()
const projectId = getVertexAIProjectId()
const location = getVertexAILocation()
const { googleCredentials, project, location } = this.vertexProvider
if (!serviceAccount.privateKey || !serviceAccount.clientEmail || !projectId || !location) {
if (!googleCredentials.privateKey || !googleCredentials.clientEmail || !project || !location) {
throw new Error('Vertex AI settings are not configured')
}
@@ -68,7 +73,7 @@ export class VertexAPIClient extends GeminiAPIClient {
this.sdkInstance = new GoogleGenAI({
vertexai: true,
project: projectId,
project: project,
location: location,
httpOptions: {
apiVersion: this.getApiVersion(),
@@ -84,11 +89,10 @@ export class VertexAPIClient extends GeminiAPIClient {
* service account
*/
private async getServiceAccountAuthHeaders(): Promise<Record<string, string> | undefined> {
const serviceAccount = getVertexAIServiceAccount()
const projectId = getVertexAIProjectId()
const { googleCredentials, project } = this.vertexProvider
// 检查是否配置了 service account
if (!serviceAccount.privateKey || !serviceAccount.clientEmail || !projectId) {
if (!googleCredentials.privateKey || !googleCredentials.clientEmail || !project) {
return undefined
}
@@ -101,10 +105,10 @@ export class VertexAPIClient extends GeminiAPIClient {
try {
// 从主进程获取认证头
this.authHeaders = await window.api.vertexAI.getAuthHeaders({
projectId,
projectId: project,
serviceAccount: {
privateKey: serviceAccount.privateKey,
clientEmail: serviceAccount.clientEmail
privateKey: googleCredentials.privateKey,
clientEmail: googleCredentials.clientEmail
}
})
@@ -125,11 +129,10 @@ export class VertexAPIClient extends GeminiAPIClient {
this.authHeaders = undefined
this.authHeadersExpiry = undefined
const serviceAccount = getVertexAIServiceAccount()
const projectId = getVertexAIProjectId()
const { googleCredentials, project } = this.vertexProvider
if (projectId && serviceAccount.clientEmail) {
window.api.vertexAI.clearAuthCache(projectId, serviceAccount.clientEmail)
if (project && googleCredentials.clientEmail) {
window.api.vertexAI.clearAuthCache(project, googleCredentials.clientEmail)
}
}
}

View File

@@ -69,7 +69,7 @@ import {
} from '@renderer/types/sdk'
import { addImageFileToContents } from '@renderer/utils/formats'
import {
isEnabledToolUse,
isSupportedToolUse,
mcpToolCallResponseToOpenAICompatibleMessage,
mcpToolsToOpenAIChatTools,
openAIToolsToMcpTool
@@ -598,7 +598,7 @@ export class OpenAIAPIClient extends OpenAIBaseClient<
const { tools } = this.setupToolsConfig({
mcpTools: mcpTools,
model,
enableToolUse: isEnabledToolUse(assistant)
enableToolUse: isSupportedToolUse(assistant)
})
// 3. 处理用户消息

View File

@@ -1,6 +1,6 @@
import { loggerService } from '@logger'
import { GenericChunk } from '@renderer/aiCore/middleware/schemas'
import { CompletionsContext } from '@renderer/aiCore/middleware/types'
import { GenericChunk } from '@renderer/aiCore/legacy/middleware/schemas'
import { CompletionsContext } from '@renderer/aiCore/legacy/middleware/types'
import {
isGPT5SeriesModel,
isOpenAIChatCompletionOnlyModel,
@@ -36,7 +36,7 @@ import {
} from '@renderer/types/sdk'
import { addImageFileToContents } from '@renderer/utils/formats'
import {
isEnabledToolUse,
isSupportedToolUse,
mcpToolCallResponseToOpenAIMessage,
mcpToolsToOpenAIResponseTools,
openAIToolsToMcpTool
@@ -388,7 +388,7 @@ export class OpenAIResponseAPIClient extends OpenAIBaseClient<
const { tools: extraTools } = this.setupToolsConfig({
mcpTools: mcpTools,
model,
enableToolUse: isEnabledToolUse(assistant)
enableToolUse: isSupportedToolUse(assistant)
})
systemMessageContent.push(systemMessageInput)

View File

@@ -0,0 +1,189 @@
import { loggerService } from '@logger'
import { ApiClientFactory } from '@renderer/aiCore/legacy/clients/ApiClientFactory'
import { BaseApiClient } from '@renderer/aiCore/legacy/clients/BaseApiClient'
import { isDedicatedImageGenerationModel, isFunctionCallingModel } from '@renderer/config/models'
import { getProviderByModel } from '@renderer/services/AssistantService'
import { withSpanResult } from '@renderer/services/SpanManagerService'
import { StartSpanParams } from '@renderer/trace/types/ModelSpanEntity'
import type { GenerateImageParams, Model, Provider } from '@renderer/types'
import type { RequestOptions, SdkModel } from '@renderer/types/sdk'
import { isSupportedToolUse } from '@renderer/utils/mcp-tools'
import { AihubmixAPIClient } from './clients/AihubmixAPIClient'
import { VertexAPIClient } from './clients/gemini/VertexAPIClient'
import { NewAPIClient } from './clients/NewAPIClient'
import { OpenAIResponseAPIClient } from './clients/openai/OpenAIResponseAPIClient'
import { CompletionsMiddlewareBuilder } from './middleware/builder'
import { MIDDLEWARE_NAME as AbortHandlerMiddlewareName } from './middleware/common/AbortHandlerMiddleware'
import { MIDDLEWARE_NAME as ErrorHandlerMiddlewareName } from './middleware/common/ErrorHandlerMiddleware'
import { MIDDLEWARE_NAME as FinalChunkConsumerMiddlewareName } from './middleware/common/FinalChunkConsumerMiddleware'
import { applyCompletionsMiddlewares } from './middleware/composer'
import { MIDDLEWARE_NAME as McpToolChunkMiddlewareName } from './middleware/core/McpToolChunkMiddleware'
import { MIDDLEWARE_NAME as RawStreamListenerMiddlewareName } from './middleware/core/RawStreamListenerMiddleware'
import { MIDDLEWARE_NAME as WebSearchMiddlewareName } from './middleware/core/WebSearchMiddleware'
import { MIDDLEWARE_NAME as ImageGenerationMiddlewareName } from './middleware/feat/ImageGenerationMiddleware'
import { MIDDLEWARE_NAME as ThinkingTagExtractionMiddlewareName } from './middleware/feat/ThinkingTagExtractionMiddleware'
import { MIDDLEWARE_NAME as ToolUseExtractionMiddlewareName } from './middleware/feat/ToolUseExtractionMiddleware'
import { MiddlewareRegistry } from './middleware/register'
import type { CompletionsParams, CompletionsResult } from './middleware/schemas'
const logger = loggerService.withContext('AiProvider')
export default class AiProvider {
  private apiClient: BaseApiClient

  constructor(provider: Provider) {
    // Delegate construction to the factory so provider-specific wiring
    // (base URL, auth, SDK flavor) stays in one place.
    this.apiClient = ApiClientFactory.create(provider)
  }

  /**
   * 根据模型识别正确的客户端。
   *
   * Aggregator clients (Aihubmix / NewAPI) route to a per-model sub-client;
   * OpenAIResponseAPIClient and VertexAPIClient additionally pick an
   * underlying API flavor per model. All other clients are used as-is.
   */
  private resolveClient(model: Model): BaseApiClient {
    // AihubmixAPIClient / NewAPIClient: 根据模型选择合适的子client
    // (both aggregators share the same routing + unwrap logic)
    if (this.apiClient instanceof AihubmixAPIClient || this.apiClient instanceof NewAPIClient) {
      let client: BaseApiClient = this.apiClient.getClientForModel(model)
      if (client instanceof OpenAIResponseAPIClient) {
        client = client.getClient(model) as BaseApiClient
      }
      return client
    }
    // OpenAIResponseAPIClient / VertexAPIClient: 根据模型特征选择API类型
    if (this.apiClient instanceof OpenAIResponseAPIClient || this.apiClient instanceof VertexAPIClient) {
      return this.apiClient.getClient(model) as BaseApiClient
    }
    // 其他client直接使用
    return this.apiClient
  }

  /**
   * 构建本次调用的中间件链。
   *
   * Dedicated image models get a minimal fixed chain; all other calls start
   * from the default chain and prune middlewares that do not apply to the
   * resolved client, the request params, or the call type.
   */
  private buildMiddlewares(client: BaseApiClient, model: Model, params: CompletionsParams) {
    const builder = CompletionsMiddlewareBuilder.withDefaults()

    // images api: only error/abort handling plus image generation
    if (isDedicatedImageGenerationModel(model)) {
      builder.clear()
      builder
        .add(MiddlewareRegistry[FinalChunkConsumerMiddlewareName])
        .add(MiddlewareRegistry[ErrorHandlerMiddlewareName])
        .add(MiddlewareRegistry[AbortHandlerMiddlewareName])
        .add(MiddlewareRegistry[ImageGenerationMiddlewareName])
      return builder.build()
    }

    // Existing logic for other models
    logger.silly('Builder Params', params)
    // 使用兼容性类型检查避免typescript类型收窄和装饰器模式的问题
    const clientTypes = client.getClientCompatibilityType(model)

    // Thinking-tag extraction only applies to OpenAI-compatible streams.
    const isOpenAICompatible =
      clientTypes.includes('OpenAIAPIClient') || clientTypes.includes('OpenAIResponseAPIClient')
    if (!isOpenAICompatible) {
      logger.silly('ThinkingTagExtractionMiddleware is removed')
      builder.remove(ThinkingTagExtractionMiddlewareName)
    }

    // Raw-stream listening is only needed for Anthropic / OpenAI-Response streams.
    const isAnthropicOrOpenAIResponseCompatible =
      clientTypes.includes('AnthropicAPIClient') ||
      clientTypes.includes('OpenAIResponseAPIClient') ||
      clientTypes.includes('AnthropicVertexAPIClient')
    if (!isAnthropicOrOpenAIResponseCompatible) {
      logger.silly('RawStreamListenerMiddleware is removed')
      builder.remove(RawStreamListenerMiddlewareName)
    }

    if (!params.enableWebSearch) {
      logger.silly('WebSearchMiddleware is removed')
      builder.remove(WebSearchMiddlewareName)
    }

    // Without MCP tools neither extraction nor tool-chunk handling is needed.
    if (!params.mcpTools?.length) {
      builder.remove(ToolUseExtractionMiddlewareName)
      logger.silly('ToolUseExtractionMiddleware is removed')
      builder.remove(McpToolChunkMiddlewareName)
      logger.silly('McpToolChunkMiddleware is removed')
    }

    // Native function calling makes prompt-based tool-use extraction redundant.
    if (isSupportedToolUse(params.assistant) && isFunctionCallingModel(model)) {
      builder.remove(ToolUseExtractionMiddlewareName)
      logger.silly('ToolUseExtractionMiddleware is removed')
    }

    // Abort handling is only wired up for interactive call types.
    if (params.callType !== 'chat' && params.callType !== 'check' && params.callType !== 'translate') {
      logger.silly('AbortHandlerMiddleware is removed')
      builder.remove(AbortHandlerMiddlewareName)
    }

    // Test calls want raw errors/chunks, so strip the consumers.
    if (params.callType === 'test') {
      builder.remove(ErrorHandlerMiddlewareName)
      logger.silly('ErrorHandlerMiddleware is removed')
      builder.remove(FinalChunkConsumerMiddlewareName)
      logger.silly('FinalChunkConsumerMiddleware is removed')
    }

    return builder.build()
  }

  /**
   * Run a completions request through the middleware-wrapped client.
   *
   * @param params  Completion request; `params.assistant.model` is required.
   * @param options Optional per-request SDK options (signal, headers, …).
   * @returns The result produced by the middleware chain.
   */
  public async completions(params: CompletionsParams, options?: RequestOptions): Promise<CompletionsResult> {
    // 1. 根据模型识别正确的客户端
    const model = params.assistant.model
    if (!model) {
      return Promise.reject(new Error('Model is required'))
    }
    const client = this.resolveClient(model)

    // 2. 构建中间件链
    const middlewares = this.buildMiddlewares(client, model, params)
    logger.silly(
      'middlewares',
      middlewares.map((m) => m.name)
    )

    // 3. Create the wrapped SDK method with middlewares
    const wrappedCompletionMethod = applyCompletionsMiddlewares(client, client.createCompletions, middlewares)

    // 4. Execute the wrapped method with the original params
    return wrappedCompletionMethod(params, options)
  }

  /**
   * Same as {@link completions}, but wrapped in a trace span named after the
   * model and call type so the request shows up in the span tree.
   */
  public async completionsForTrace(params: CompletionsParams, options?: RequestOptions): Promise<CompletionsResult> {
    const traceName = params.assistant.model?.name
      ? `${params.assistant.model?.name}.${params.callType}`
      : `LLM.${params.callType}`
    const traceParams: StartSpanParams = {
      name: traceName,
      tag: 'LLM',
      topicId: params.topicId || '',
      modelName: params.assistant.model?.name
    }
    return await withSpanResult(this.completions.bind(this), traceParams, params, options)
  }

  /** List the models available from the underlying provider. */
  public async models(): Promise<SdkModel[]> {
    return this.apiClient.listModels()
  }

  /**
   * Probe the embedding dimension of a model.
   *
   * @throws Re-throws any error from the underlying client after logging it.
   */
  public async getEmbeddingDimensions(model: Model): Promise<number> {
    try {
      // Use the SDK instance to test embedding capabilities
      // NOTE(review): this permanently reassigns this.apiClient to the
      // resolved sub-client, which affects all later calls on this
      // instance — confirm that is intended rather than a local resolve.
      if (this.apiClient instanceof OpenAIResponseAPIClient && getProviderByModel(model).type === 'azure-openai') {
        this.apiClient = this.apiClient.getClient(model) as BaseApiClient
      }
      const dimensions = await this.apiClient.getEmbeddingDimensions(model)
      return dimensions
    } catch (error) {
      logger.error('Error getting embedding dimensions:', error as Error)
      throw error
    }
  }

  /** Generate images, routing Aihubmix requests to the per-model sub-client. */
  public async generateImage(params: GenerateImageParams): Promise<string[]> {
    if (this.apiClient instanceof AihubmixAPIClient) {
      const client = this.apiClient.getClientForModel({ id: params.model } as Model)
      return client.generateImage(params)
    }
    return this.apiClient.generateImage(params)
  }

  /** Base URL of the underlying client. */
  public getBaseURL(): string {
    return this.apiClient.getBaseURL()
  }

  /** API key of the underlying client. */
  public getApiKey(): string {
    return this.apiClient.getApiKey()
  }
}

View File

@@ -1,5 +1,5 @@
import { loggerService } from '@logger'
import { MCPCallToolResponse, MCPTool, MCPToolResponse, Model, ToolCallResponse } from '@renderer/types'
import { MCPCallToolResponse, MCPTool, MCPToolResponse, Model } from '@renderer/types'
import { ChunkType, MCPToolCreatedChunk } from '@renderer/types/chunk'
import { SdkMessageParam, SdkRawOutput, SdkToolCall } from '@renderer/types/sdk'
import {
@@ -230,7 +230,7 @@ async function executeToolCalls(
model: Model,
topicId?: string
): Promise<{ toolResults: SdkMessageParam[]; confirmedToolCalls: SdkToolCall[] }> {
const mcpToolResponses: ToolCallResponse[] = toolCalls
const mcpToolResponses: MCPToolResponse[] = toolCalls
.map((toolCall) => {
const mcpTool = ctx.apiClientInstance.convertSdkToolCallToMcp(toolCall, mcpTools)
if (!mcpTool) {
@@ -238,7 +238,7 @@ async function executeToolCalls(
}
return ctx.apiClientInstance.convertSdkToolCallToMcpToolResponse(toolCall, mcpTool)
})
.filter((t): t is ToolCallResponse => typeof t !== 'undefined')
.filter((t): t is MCPToolResponse => typeof t !== 'undefined')
if (mcpToolResponses.length === 0) {
logger.warn(`No valid MCP tool responses to execute`)

View File

@@ -1,4 +1,4 @@
import { AnthropicAPIClient } from '@renderer/aiCore/clients/anthropic/AnthropicAPIClient'
import { AnthropicAPIClient } from '@renderer/aiCore/legacy/clients/anthropic/AnthropicAPIClient'
import { AnthropicSdkRawChunk, AnthropicSdkRawOutput } from '@renderer/types/sdk'
import { AnthropicStreamListener } from '../../clients/types'

Some files were not shown because too many files have changed in this diff Show More