Compare commits

..

3 Commits

357 changed files with 6153 additions and 38661 deletions

View File

@@ -1,17 +1,86 @@
version: 2
updates:
- package-ecosystem: 'github-actions'
directory: '/'
- package-ecosystem: "npm"
directory: "/"
schedule:
interval: 'monthly'
interval: "monthly"
open-pull-requests-limit: 7
target-branch: "main"
commit-message:
prefix: "chore"
include: "scope"
groups:
# 核心框架
core-framework:
patterns:
- "react"
- "react-dom"
- "electron"
- "typescript"
- "@types/react*"
- "@types/node"
update-types:
- "minor"
- "patch"
# Electron 生态和构建工具
electron-build:
patterns:
- "electron-*"
- "@electron*"
- "vite"
- "@vitejs/*"
- "dotenv-cli"
- "rollup-plugin-*"
- "@swc/*"
update-types:
- "minor"
- "patch"
# 测试工具
testing-tools:
patterns:
- "vitest"
- "@vitest/*"
- "playwright"
- "@playwright/*"
- "eslint*"
- "@eslint*"
- "prettier"
- "husky"
- "lint-staged"
update-types:
- "minor"
- "patch"
# CherryStudio 自定义包
cherrystudio-packages:
patterns:
- "@cherrystudio/*"
update-types:
- "minor"
- "patch"
# 兜底其他 dependencies
other-dependencies:
dependency-type: "production"
# 兜底其他 devDependencies
other-dev-dependencies:
dependency-type: "development"
- package-ecosystem: "github-actions"
directory: "/"
schedule:
interval: "weekly"
open-pull-requests-limit: 3
commit-message:
prefix: 'ci'
include: 'scope'
prefix: "ci"
include: "scope"
groups:
github-actions:
patterns:
- '*'
- "*"
update-types:
- 'minor'
- 'patch'
- "minor"
- "patch"

View File

@@ -79,7 +79,6 @@ jobs:
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
RENDERER_VITE_AIHUBMIX_SECRET: ${{ vars.RENDERER_VITE_AIHUBMIX_SECRET }}
NODE_OPTIONS: --max-old-space-size=8192
MAIN_VITE_MINERU_API_KEY: ${{ vars.MAIN_VITE_MINERU_API_KEY }}
- name: Build Mac
if: matrix.os == 'macos-latest'
@@ -96,7 +95,6 @@ jobs:
RENDERER_VITE_AIHUBMIX_SECRET: ${{ vars.RENDERER_VITE_AIHUBMIX_SECRET }}
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
NODE_OPTIONS: --max-old-space-size=8192
MAIN_VITE_MINERU_API_KEY: ${{ vars.MAIN_VITE_MINERU_API_KEY }}
- name: Build Windows
if: matrix.os == 'windows-latest'
@@ -107,7 +105,6 @@ jobs:
GH_TOKEN: ${{ secrets.GITHUB_TOKEN }}
RENDERER_VITE_AIHUBMIX_SECRET: ${{ vars.RENDERER_VITE_AIHUBMIX_SECRET }}
NODE_OPTIONS: --max-old-space-size=8192
MAIN_VITE_MINERU_API_KEY: ${{ vars.MAIN_VITE_MINERU_API_KEY }}
- name: Release
uses: ncipollo/release-action@v1

View File

@@ -1,8 +1,7 @@
{
"editor.formatOnSave": true,
"editor.codeActionsOnSave": {
"source.fixAll.eslint": "explicit",
"source.organizeImports": "never"
"source.fixAll.eslint": "explicit"
},
"search.exclude": {
"**/dist/**": true,

View File

@@ -1,69 +0,0 @@
diff --git a/es/dropdown/dropdown.js b/es/dropdown/dropdown.js
index 986877a762b9ad0aca596a8552732cd12d2eaabb..1f18aa2ea745e68950e4cee16d4d655f5c835fd5 100644
--- a/es/dropdown/dropdown.js
+++ b/es/dropdown/dropdown.js
@@ -2,7 +2,7 @@
import * as React from 'react';
import LeftOutlined from "@ant-design/icons/es/icons/LeftOutlined";
-import RightOutlined from "@ant-design/icons/es/icons/RightOutlined";
+import { ChevronRight } from 'lucide-react';
import classNames from 'classnames';
import RcDropdown from 'rc-dropdown';
import useEvent from "rc-util/es/hooks/useEvent";
@@ -158,8 +158,10 @@ const Dropdown = props => {
className: `${prefixCls}-menu-submenu-arrow`
}, direction === 'rtl' ? (/*#__PURE__*/React.createElement(LeftOutlined, {
className: `${prefixCls}-menu-submenu-arrow-icon`
- })) : (/*#__PURE__*/React.createElement(RightOutlined, {
- className: `${prefixCls}-menu-submenu-arrow-icon`
+ })) : (/*#__PURE__*/React.createElement(ChevronRight, {
+ size: 16,
+ strokeWidth: 1.8,
+ className: `${prefixCls}-menu-submenu-arrow-icon lucide-custom`
}))),
mode: "vertical",
selectable: false,
diff --git a/es/dropdown/style/index.js b/es/dropdown/style/index.js
index 768c01783002c6901c85a73061ff6b3e776a60ce..39b1b95a56cdc9fb586a193c3adad5141f5cf213 100644
--- a/es/dropdown/style/index.js
+++ b/es/dropdown/style/index.js
@@ -240,7 +240,8 @@ const genBaseStyle = token => {
marginInlineEnd: '0 !important',
color: token.colorTextDescription,
fontSize: fontSizeIcon,
- fontStyle: 'normal'
+ fontStyle: 'normal',
+ marginTop: 3,
}
}
}),
diff --git a/es/select/useIcons.js b/es/select/useIcons.js
index 959115be936ef8901548af2658c5dcfdc5852723..c812edd52123eb0faf4638b1154fcfa1b05b513b 100644
--- a/es/select/useIcons.js
+++ b/es/select/useIcons.js
@@ -4,10 +4,10 @@ import * as React from 'react';
import CheckOutlined from "@ant-design/icons/es/icons/CheckOutlined";
import CloseCircleFilled from "@ant-design/icons/es/icons/CloseCircleFilled";
import CloseOutlined from "@ant-design/icons/es/icons/CloseOutlined";
-import DownOutlined from "@ant-design/icons/es/icons/DownOutlined";
import LoadingOutlined from "@ant-design/icons/es/icons/LoadingOutlined";
import SearchOutlined from "@ant-design/icons/es/icons/SearchOutlined";
import { devUseWarning } from '../_util/warning';
+import { ChevronDown } from 'lucide-react';
export default function useIcons(_ref) {
let {
suffixIcon,
@@ -56,8 +56,10 @@ export default function useIcons(_ref) {
className: iconCls
}));
}
- return getSuffixIconNode(/*#__PURE__*/React.createElement(DownOutlined, {
- className: iconCls
+ return getSuffixIconNode(/*#__PURE__*/React.createElement(ChevronDown, {
+ size: 16,
+ strokeWidth: 1.8,
+ className: `${iconCls} lucide-custom`
}));
};
}

105
README.md
View File

@@ -1,33 +1,3 @@
<div align="right" >
<details>
<summary >🌐 Language</summary>
<div>
<div align="right">
<p><a href="https://openaitx.github.io/view.html?user=CherryHQ&project=cherry-studio&lang=en">English</a></p>
<p><a href="https://openaitx.github.io/view.html?user=CherryHQ&project=cherry-studio&lang=zh-CN">简体中文</a></p>
<p><a href="https://openaitx.github.io/view.html?user=CherryHQ&project=cherry-studio&lang=zh-TW">繁體中文</a></p>
<p><a href="https://openaitx.github.io/view.html?user=CherryHQ&project=cherry-studio&lang=ja">日本語</a></p>
<p><a href="https://openaitx.github.io/view.html?user=CherryHQ&project=cherry-studio&lang=ko">한국어</a></p>
<p><a href="https://openaitx.github.io/view.html?user=CherryHQ&project=cherry-studio&lang=hi">हिन्दी</a></p>
<p><a href="https://openaitx.github.io/view.html?user=CherryHQ&project=cherry-studio&lang=th">ไทย</a></p>
<p><a href="https://openaitx.github.io/view.html?user=CherryHQ&project=cherry-studio&lang=fr">Français</a></p>
<p><a href="https://openaitx.github.io/view.html?user=CherryHQ&project=cherry-studio&lang=de">Deutsch</a></p>
<p><a href="https://openaitx.github.io/view.html?user=CherryHQ&project=cherry-studio&lang=es">Español</a></p>
<p><a href="https://openaitx.github.io/view.html?user=CherryHQ&project=cherry-studio&lang=it">Itapano</a></p>
<p><a href="https://openaitx.github.io/view.html?user=CherryHQ&project=cherry-studio&lang=ru">Русский</a></p>
<p><a href="https://openaitx.github.io/view.html?user=CherryHQ&project=cherry-studio&lang=pt">Português</a></p>
<p><a href="https://openaitx.github.io/view.html?user=CherryHQ&project=cherry-studio&lang=nl">Nederlands</a></p>
<p><a href="https://openaitx.github.io/view.html?user=CherryHQ&project=cherry-studio&lang=pl">Polski</a></p>
<p><a href="https://openaitx.github.io/view.html?user=CherryHQ&project=cherry-studio&lang=ar">العربية</a></p>
<p><a href="https://openaitx.github.io/view.html?user=CherryHQ&project=cherry-studio&lang=fa">فارسی</a></p>
<p><a href="https://openaitx.github.io/view.html?user=CherryHQ&project=cherry-studio&lang=tr">Türkçe</a></p>
<p><a href="https://openaitx.github.io/view.html?user=CherryHQ&project=cherry-studio&lang=vi">Tiếng Việt</a></p>
<p><a href="https://openaitx.github.io/view.html?user=CherryHQ&project=cherry-studio&lang=id">Bahasa Indonesia</a></p>
</div>
</div>
</details>
</div>
<h1 align="center">
<a href="https://github.com/CherryHQ/cherry-studio/releases">
<img src="https://github.com/CherryHQ/cherry-studio/blob/main/build/icon.png?raw=true" width="150" height="150" alt="banner" /><br>
@@ -197,78 +167,6 @@ For more detailed guidelines, please refer to our [Contributing Guide](./CONTRIB
Thank you for your support and contributions!
# 🔧 Developer Co-creation Program
We are launching the Cherry Studio Developer Co-creation Program to foster a healthy and positive-feedback loop within the open-source ecosystem. We believe that great software is built collaboratively, and every merged pull request breathes new life into the project.
We sincerely invite you to join our ranks of contributors and shape the future of Cherry Studio with us.
## Contributor Rewards Program
To give back to our core contributors and create a virtuous cycle, we have established the following long-term incentive plan.
**The inaugural tracking period for this program will be Q3 2025 (July, August, September). Rewards for this cycle will be distributed on October 1st.**
Within any tracking period (e.g., July 1st to September 30th for the first cycle), any developer who contributes more than **30 meaningful commits** to any of Cherry Studio's open-source projects on GitHub is eligible for the following benefits:
- **Cursor Subscription Sponsorship**: Receive a **$70 USD** credit or reimbursement for your [Cursor](https://cursor.sh/) subscription, making AI your most efficient coding partner.
- **Unlimited Model Access**: Get **unlimited** API calls for the **DeepSeek** and **Qwen** models.
- **Cutting-Edge Tech Access**: Enjoy occasional perks, including API access to models like **Claude**, **Gemini**, and **OpenAI**, keeping you at the forefront of technology.
## Growing Together & Future Plans
A vibrant community is the driving force behind any sustainable open-source project. As Cherry Studio grows, so will our rewards program. We are committed to continuously aligning our benefits with the best-in-class tools and resources in the industry. This ensures our core contributors receive meaningful support, creating a positive cycle where developers, the community, and the project grow together.
**Moving forward, the project will also embrace an increasingly open stance to give back to the entire open-source community.**
## How to Get Started?
We look forward to your first Pull Request!
You can start by exploring our repositories, picking up a `good first issue`, or proposing your own enhancements. Every commit is a testament to the spirit of open source.
Thank you for your interest and contributions.
Let's build together.
# 🏢 Enterprise Edition
Building on the Community Edition, we are proud to introduce **Cherry Studio Enterprise Edition**—a privately deployable AI productivity and management platform designed for modern teams and enterprises.
The Enterprise Edition addresses core challenges in team collaboration by centralizing the management of AI resources, knowledge, and data. It empowers organizations to enhance efficiency, foster innovation, and ensure compliance, all while maintaining 100% control over their data in a secure environment.
## Core Advantages
- **Unified Model Management**: Centrally integrate and manage various cloud-based LLMs (e.g., OpenAI, Anthropic, Google Gemini) and locally deployed private models. Employees can use them out-of-the-box without individual configuration.
- **Enterprise-Grade Knowledge Base**: Build, manage, and share team-wide knowledge bases. Ensure knowledge is retained and consistent, enabling team members to interact with AI based on unified and accurate information.
- **Fine-Grained Access Control**: Easily manage employee accounts and assign role-based permissions for different models, knowledge bases, and features through a unified admin backend.
- **Fully Private Deployment**: Deploy the entire backend service on your on-premises servers or private cloud, ensuring your data remains 100% private and under your control to meet the strictest security and compliance standards.
- **Reliable Backend Services**: Provides stable API services, enterprise-grade data backup and recovery mechanisms to ensure business continuity.
## ✨ Online Demo
> 🚧 **Public Beta Notice**
>
> The Enterprise Edition is currently in its early public beta stage, and we are actively iterating and optimizing its features. We are aware that it may not be perfectly stable yet. If you encounter any issues or have valuable suggestions during your trial, we would be very grateful if you could contact us via email to provide feedback.
**🔗 [Cherry Studio Enterprise](https://www.cherry-ai.com/enterprise)**
## Version Comparison
| Feature | Community Edition | Enterprise Edition |
| :---------------- | :----------------------------------------- | :-------------------------------------------------------------------------------------------------------------------------------------- |
| **Open Source** | ✅ Yes | ⭕️ part. released to cust. |
| **Cost** | Free for Personal Use / Commercial License | Buyout / Subscription Fee |
| **Admin Backend** | — | ● Centralized **Model** Access<br>● **Employee** Management<br>● Shared **Knowledge Base**<br>● **Access** Control<br>● **Data** Backup |
| **Server** | — | ✅ Dedicated Private Deployment |
## Get the Enterprise Edition
We believe the Enterprise Edition will become your team's AI productivity engine. If you are interested in Cherry Studio Enterprise Edition and would like to learn more, request a quote, or schedule a demo, please contact us.
- **For Business Inquiries & Purchasing**:
**📧 [bd@cherry-ai.com](mailto:bd@cherry-ai.com)**
# 🔗 Related Projects
- [one-api](https://github.com/songquanpeng/one-api):LLM API management and distribution system, supporting mainstream models like OpenAI, Azure, and Anthropic. Features unified API interface, suitable for key management and secondary distribution.
@@ -287,7 +185,6 @@ We believe the Enterprise Edition will become your team's AI productivity engine
[![Star History Chart](https://api.star-history.com/svg?repos=CherryHQ/cherry-studio&type=Timeline)](https://star-history.com/#CherryHQ/cherry-studio&Timeline)
<!-- Links & Images -->
[deepwiki-shield]: https://img.shields.io/badge/Deepwiki-CherryHQ-0088CC?style=plastic
[deepwiki-link]: https://deepwiki.com/CherryHQ/cherry-studio
[twitter-shield]: https://img.shields.io/badge/Twitter-CherryStudioApp-0088CC?style=plastic&logo=x
@@ -298,7 +195,6 @@ We believe the Enterprise Edition will become your team's AI productivity engine
[telegram-link]: https://t.me/CherryStudioAI
<!-- Links & Images -->
[github-stars-shield]: https://img.shields.io/github/stars/CherryHQ/cherry-studio?style=social
[github-stars-link]: https://github.com/CherryHQ/cherry-studio/stargazers
[github-forks-shield]: https://img.shields.io/github/forks/CherryHQ/cherry-studio?style=social
@@ -309,7 +205,6 @@ We believe the Enterprise Edition will become your team's AI productivity engine
[github-contributors-link]: https://github.com/CherryHQ/cherry-studio/graphs/contributors
<!-- Links & Images -->
[license-shield]: https://img.shields.io/badge/License-AGPLv3-important.svg?style=plastic&logo=gnu
[license-link]: https://www.gnu.org/licenses/agpl-3.0
[commercial-shield]: https://img.shields.io/badge/License-Contact-white.svg?style=plastic&logoColor=white&logo=telegram&color=blue

View File

@@ -11,11 +11,6 @@ electronLanguages:
- en # for macOS
directories:
buildResources: build
protocols:
- name: Cherry Studio
schemes:
- cherrystudio
files:
- '**/*'
- '!**/{.vscode,.yarn,.yarn-lock,.github,.cursorrules,.prettierrc}'
@@ -53,11 +48,7 @@ files:
- '!node_modules/pdf-parse/lib/pdf.js/{v1.9.426,v1.10.88,v2.0.550}'
- '!node_modules/mammoth/{mammoth.browser.js,mammoth.browser.min.js}'
- '!node_modules/selection-hook/prebuilds/**/*' # we rebuild .node, don't use prebuilds
- '!node_modules/pdfjs-dist/web/**/*'
- '!node_modules/pdfjs-dist/legacy/web/*'
- '!node_modules/selection-hook/node_modules' # we don't need what in the node_modules dir
- '!node_modules/selection-hook/src' # we don't need source files
- '!**/*.{h,iobj,ipdb,tlog,recipe,vcxproj,vcxproj.filters,Makefile,*.Makefile}' # filter .node build files
- '!**/*.{h,iobj,ipdb,tlog,recipe,vcxproj,vcxproj.filters}' # filter .node build files
asarUnpack:
- resources/**
- '**/*.{metal,exp,lib}'
@@ -99,7 +90,6 @@ linux:
artifactName: ${productName}-${version}-${arch}.${ext}
target:
- target: AppImage
- target: deb
maintainer: electronjs.org
category: Utility
desktop:
@@ -117,9 +107,11 @@ afterSign: scripts/notarize.js
artifactBuildCompleted: scripts/artifact-build-completed.js
releaseInfo:
releaseNotes: |
划词助手:支持 macOS 系统
文档处理:增加 MinerU、Doc2xMistral 等服务商支持
知识库:新的知识库界面,增加扫描版 PDF 支持
OCRmacOS 增加系统 OCR 支持
服务商:支持一键添加服务商,新增 PH8 大模型开放平台, 支持 PPIO OAuth 登录
修复Linux下数据目录移动问题
划词助手:支持文本选择快捷键、开关快捷键、思考块支持和引用功能
复制功能新增纯文本复制去除Markdown格式符号
知识库:支持设置向量维度修复Ollama分数错误和维度编辑问题
多语言:增加模型名称多语言提示和翻译源语言手动选择
文件管理:修复主题/消息删除时文件未清理问题,优化文件选择流程
模型修复Gemini模型推理预算、Voyage AI嵌入问题和DeepSeek翻译模型更新
图像功能统一图片查看器支持Base64图片渲染修复图片预览相关问题
UI实现标签折叠/拖拽排序,修复气泡溢出,增加引文索引显示

View File

@@ -1,5 +1,4 @@
import react from '@vitejs/plugin-react-swc'
import { CodeInspectorPlugin } from 'code-inspector-plugin'
import { defineConfig, externalizeDepsPlugin } from 'electron-vite'
import { resolve } from 'path'
import { visualizer } from 'rollup-plugin-visualizer'
@@ -20,7 +19,7 @@ export default defineConfig({
},
build: {
rollupOptions: {
external: ['@libsql/client', 'bufferutil', 'utf-8-validate', '@cherrystudio/mac-system-ocr'],
external: ['@libsql/client', 'bufferutil', 'utf-8-validate'],
output: {
// 彻底禁用代码分割 - 返回 null 强制单文件打包
manualChunks: undefined,
@@ -60,14 +59,6 @@ export default defineConfig({
]
]
}),
// 只在开发环境下启用 CodeInspectorPlugin
...(process.env.NODE_ENV === 'development'
? [
CodeInspectorPlugin({
bundler: 'vite'
})
]
: []),
...visualizerPlugin('renderer')
],
resolve: {
@@ -77,16 +68,12 @@ export default defineConfig({
}
},
optimizeDeps: {
exclude: ['pyodide'],
esbuildOptions: {
target: 'esnext' // for dev
}
exclude: ['pyodide']
},
worker: {
format: 'es'
},
build: {
target: 'esnext', // for build
rollupOptions: {
input: {
index: resolve(__dirname, 'src/renderer/index.html'),

View File

@@ -1,6 +1,6 @@
{
"name": "CherryStudio",
"version": "1.4.8",
"version": "1.4.2",
"private": true,
"description": "A powerful AI assistant for producer.",
"main": "./out/main/index.js",
@@ -58,17 +58,13 @@
"prepare": "husky"
},
"dependencies": {
"@cherrystudio/pdf-to-img-napi": "^0.0.1",
"@libsql/client": "0.14.0",
"@libsql/win32-x64-msvc": "^0.4.7",
"@strongtz/win32-arm64-msvc": "^0.4.7",
"jsdom": "26.1.0",
"macos-release": "^3.4.0",
"node-stream-zip": "^1.15.0",
"notion-helper": "^1.3.22",
"os-proxy-config": "^1.1.2",
"pdfjs-dist": "4.10.38",
"selection-hook": "^1.0.4",
"selection-hook": "^0.9.23",
"turndown": "7.2.0"
},
"devDependencies": {
@@ -103,13 +99,12 @@
"@kangfenmao/keyv-storage": "^0.1.0",
"@langchain/community": "^0.3.36",
"@langchain/ollama": "^0.2.1",
"@mistralai/mistralai": "^1.6.0",
"@modelcontextprotocol/sdk": "^1.11.4",
"@mozilla/readability": "^0.6.0",
"@notionhq/client": "^2.2.15",
"@playwright/test": "^1.52.0",
"@reduxjs/toolkit": "^2.2.5",
"@shikijs/markdown-it": "^3.7.0",
"@shikijs/markdown-it": "^3.4.2",
"@swc/plugin-styled-components": "^7.1.5",
"@tanstack/react-query": "^5.27.0",
"@testing-library/dom": "^10.4.0",
@@ -128,31 +123,28 @@
"@types/react-infinite-scroll-component": "^5.0.0",
"@types/react-window": "^1",
"@types/tinycolor2": "^1",
"@types/word-extractor": "^1",
"@uiw/codemirror-extensions-langs": "^4.23.14",
"@uiw/codemirror-themes-all": "^4.23.14",
"@uiw/react-codemirror": "^4.23.14",
"@uiw/codemirror-extensions-langs": "^4.23.12",
"@uiw/codemirror-themes-all": "^4.23.12",
"@uiw/react-codemirror": "^4.23.12",
"@vitejs/plugin-react-swc": "^3.9.0",
"@vitest/browser": "^3.1.4",
"@vitest/coverage-v8": "^3.1.4",
"@vitest/ui": "^3.1.4",
"@vitest/web-worker": "^3.1.4",
"@xyflow/react": "^12.4.4",
"antd": "patch:antd@npm%3A5.24.7#~/.yarn/patches/antd-npm-5.24.7-356a553ae5.patch",
"antd": "^5.22.5",
"archiver": "^7.0.1",
"async-mutex": "^0.5.0",
"axios": "^1.7.3",
"browser-image-compression": "^2.0.2",
"code-inspector-plugin": "^0.20.14",
"color": "^5.0.0",
"country-flag-emoji-polyfill": "0.1.8",
"dayjs": "^1.11.11",
"dexie": "^4.0.8",
"dexie-react-hooks": "^1.1.7",
"diff": "^7.0.0",
"docx": "^9.0.2",
"dotenv-cli": "^7.4.2",
"electron": "35.6.0",
"electron": "35.4.0",
"electron-builder": "26.0.15",
"electron-devtools-installer": "^3.2.0",
"electron-log": "^5.1.5",
@@ -181,9 +173,10 @@
"lru-cache": "^11.1.0",
"lucide-react": "^0.487.0",
"markdown-it": "^14.1.0",
"mermaid": "^11.7.0",
"mermaid": "^11.6.0",
"mime": "^4.0.4",
"motion": "^12.10.5",
"node-stream-zip": "^1.15.0",
"npx-scope-finder": "^1.2.0",
"officeparser": "^4.1.1",
"openai": "patch:openai@npm%3A5.1.0#~/.yarn/patches/openai-npm-5.1.0-0e7b3ccb07.patch",
@@ -197,7 +190,7 @@
"react-hotkeys-hook": "^4.6.1",
"react-i18next": "^14.1.2",
"react-infinite-scroll-component": "^6.1.0",
"react-markdown": "^10.1.0",
"react-markdown": "^9.0.1",
"react-redux": "^9.1.2",
"react-router": "6",
"react-router-dom": "6",
@@ -206,31 +199,27 @@
"redux": "^5.0.1",
"redux-persist": "^6.0.0",
"rehype-katex": "^7.0.1",
"rehype-mathjax": "^7.1.0",
"rehype-mathjax": "^7.0.0",
"rehype-raw": "^7.0.0",
"remark-cjk-friendly": "^1.2.0",
"remark-gfm": "^4.0.1",
"remark-cjk-friendly": "^1.1.0",
"remark-gfm": "^4.0.0",
"remark-math": "^6.0.0",
"remove-markdown": "^0.6.2",
"rollup-plugin-visualizer": "^5.12.0",
"sass": "^1.88.0",
"shiki": "^3.7.0",
"shiki": "^3.4.2",
"string-width": "^7.2.0",
"styled-components": "^6.1.11",
"tar": "^7.4.3",
"tiny-pinyin": "^1.3.2",
"tokenx": "^1.1.0",
"tokenx": "^0.4.1",
"typescript": "^5.6.2",
"uuid": "^10.0.0",
"vite": "6.2.6",
"vitest": "^3.1.4",
"webdav": "^5.8.0",
"word-extractor": "^1.0.4",
"zipread": "^1.3.3"
},
"optionalDependencies": {
"@cherrystudio/mac-system-ocr": "^0.2.2"
},
"resolutions": {
"pdf-parse@npm:1.1.1": "patch:pdf-parse@npm%3A1.1.1#~/.yarn/patches/pdf-parse-npm-1.1.1-04a6109b2a.patch",
"@langchain/openai@npm:^0.3.16": "patch:@langchain/openai@npm%3A0.3.16#~/.yarn/patches/@langchain-openai-npm-0.3.16-e525b59526.patch",

View File

@@ -3,8 +3,6 @@ export enum IpcChannel {
App_ClearCache = 'app:clear-cache',
App_SetLaunchOnBoot = 'app:set-launch-on-boot',
App_SetLanguage = 'app:set-language',
App_SetEnableSpellCheck = 'app:set-enable-spell-check',
App_SetSpellCheckLanguages = 'app:set-spell-check-languages',
App_ShowUpdateDialog = 'app:show-update-dialog',
App_CheckForUpdate = 'app:check-for-update',
App_Reload = 'app:reload',
@@ -15,33 +13,20 @@ export enum IpcChannel {
App_SetTrayOnClose = 'app:set-tray-on-close',
App_SetTheme = 'app:set-theme',
App_SetAutoUpdate = 'app:set-auto-update',
App_SetTestPlan = 'app:set-test-plan',
App_SetTestChannel = 'app:set-test-channel',
App_SetFeedUrl = 'app:set-feed-url',
App_HandleZoomFactor = 'app:handle-zoom-factor',
App_Select = 'app:select',
App_HasWritePermission = 'app:has-write-permission',
App_Copy = 'app:copy',
App_SetStopQuitApp = 'app:set-stop-quit-app',
App_SetAppDataPath = 'app:set-app-data-path',
App_GetDataPathFromArgs = 'app:get-data-path-from-args',
App_FlushAppData = 'app:flush-app-data',
App_IsNotEmptyDir = 'app:is-not-empty-dir',
App_RelaunchApp = 'app:relaunch-app',
App_IsBinaryExist = 'app:is-binary-exist',
App_GetBinaryPath = 'app:get-binary-path',
App_InstallUvBinary = 'app:install-uv-binary',
App_InstallBunBinary = 'app:install-bun-binary',
App_MacIsProcessTrusted = 'app:mac-is-process-trusted',
App_MacRequestProcessTrust = 'app:mac-request-process-trust',
App_QuoteToMain = 'app:quote-to-main',
Notification_Send = 'notification:send',
Notification_OnClick = 'notification:on-click',
Webview_SetOpenLinkExternal = 'webview:set-open-link-external',
Webview_SetSpellCheckEnabled = 'webview:set-spell-check-enabled',
// Open
Open_Path = 'open:path',
@@ -74,9 +59,6 @@ export enum IpcChannel {
Mcp_ServersUpdated = 'mcp:servers-updated',
Mcp_CheckConnectivity = 'mcp:check-connectivity',
// Python
Python_Execute = 'python:execute',
//copilot
Copilot_GetAuthMessage = 'copilot:get-auth-message',
Copilot_GetCopilotToken = 'copilot:get-copilot-token',
@@ -118,7 +100,6 @@ export enum IpcChannel {
KnowledgeBase_Remove = 'knowledge-base:remove',
KnowledgeBase_Search = 'knowledge-base:search',
KnowledgeBase_Rerank = 'knowledge-base:rerank',
KnowledgeBase_Check_Quota = 'knowledge-base:check-quota',
//file
File_Open = 'file:open',
@@ -129,10 +110,9 @@ export enum IpcChannel {
File_Clear = 'file:clear',
File_Read = 'file:read',
File_Delete = 'file:delete',
File_DeleteDir = 'file:deleteDir',
File_Get = 'file:get',
File_SelectFolder = 'file:selectFolder',
File_CreateTempFile = 'file:createTempFile',
File_Create = 'file:create',
File_Write = 'file:write',
File_WriteWithId = 'file:writeWithId',
File_SaveImage = 'file:saveImage',
@@ -145,12 +125,6 @@ export enum IpcChannel {
File_GetPdfInfo = 'file:getPdfInfo',
Fs_Read = 'fs:read',
// file service
FileService_Upload = 'file-service:upload',
FileService_List = 'file-service:list',
FileService_Delete = 'file-service:delete',
FileService_Retrieve = 'file-service:retrieve',
Export_Word = 'export:word',
Shortcuts_Update = 'shortcuts:update',

View File

@@ -1,7 +1,7 @@
export const imageExts = ['.jpg', '.jpeg', '.png', '.gif', '.bmp', '.webp']
export const videoExts = ['.mp4', '.avi', '.mov', '.wmv', '.flv', '.mkv']
export const audioExts = ['.mp3', '.wav', '.ogg', '.flac', '.aac']
export const documentExts = ['.pdf', '.doc', '.docx', '.pptx', '.xlsx', '.odt', '.odp', '.ods']
export const documentExts = ['.pdf', '.docx', '.pptx', '.xlsx', '.odt', '.odp', '.ods']
export const thirdPartyApplicationExts = ['.draftsExport']
export const bookExts = ['.epub']
const textExtsByCategory = new Map([
@@ -406,16 +406,6 @@ export const defaultLanguage = 'en-US'
export enum FeedUrl {
PRODUCTION = 'https://releases.cherry-ai.com',
GITHUB_LATEST = 'https://github.com/CherryHQ/cherry-studio/releases/latest/download',
PRERELEASE_LOWEST = 'https://github.com/CherryHQ/cherry-studio/releases/download/v1.4.0'
EARLY_ACCESS = 'https://github.com/CherryHQ/cherry-studio/releases/latest/download'
}
export enum UpgradeChannel {
LATEST = 'latest', // 最新稳定版本
RC = 'rc', // 公测版本
BETA = 'beta' // 预览版本
}
export const defaultTimeout = 10 * 1000 * 60
export const occupiedDirs = ['logs', 'Network', 'Partitions/webview/Network']
export const defaultTimeout = 5 * 1000 * 60

File diff suppressed because it is too large Load Diff

View File

@@ -1,11 +1,6 @@
import { ProcessingStatus } from '@types'
export type LoaderReturn = {
entriesAdded: number
uniqueId: string
uniqueIds: string[]
loaderType: string
status?: ProcessingStatus
message?: string
messageSource?: 'preprocess' | 'embedding'
}

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@@ -2,12 +2,12 @@ const fs = require('fs')
const path = require('path')
const os = require('os')
const { execSync } = require('child_process')
const StreamZip = require('node-stream-zip')
const AdmZip = require('adm-zip')
const { downloadWithRedirects } = require('./download')
// Base URL for downloading bun binaries
const BUN_RELEASE_BASE_URL = 'https://gitcode.com/CherryHQ/bun/releases/download'
const DEFAULT_BUN_VERSION = '1.2.17' // Default fallback version
const DEFAULT_BUN_VERSION = '1.2.9' // Default fallback version
// Mapping of platform+arch to binary package name
const BUN_PACKAGES = {
@@ -66,36 +66,35 @@ async function downloadBunBinary(platform, arch, version = DEFAULT_BUN_VERSION,
// Extract the zip file using adm-zip
console.log(`Extracting ${packageName} to ${binDir}...`)
const zip = new StreamZip.async({ file: tempFilename })
const zip = new AdmZip(tempFilename)
zip.extractAllTo(tempdir, true)
// Get all entries in the zip file
const entries = await zip.entries()
// Move files using Node.js fs
const sourceDir = path.join(tempdir, packageName.split('.')[0])
const files = fs.readdirSync(sourceDir)
// Extract files directly to binDir, flattening the directory structure
for (const entry of Object.values(entries)) {
if (!entry.isDirectory) {
// Get just the filename without path
const filename = path.basename(entry.name)
const outputPath = path.join(binDir, filename)
for (const file of files) {
const sourcePath = path.join(sourceDir, file)
const destPath = path.join(binDir, file)
console.log(`Extracting ${entry.name} -> ${filename}`)
await zip.extract(entry.name, outputPath)
// Make executable files executable on Unix-like systems
if (platform !== 'win32') {
try {
fs.chmodSync(outputPath, 0o755)
} catch (chmodError) {
console.error(`Warning: Failed to set executable permissions on ${filename}`)
return false
}
fs.copyFileSync(sourcePath, destPath)
fs.unlinkSync(sourcePath)
// Set executable permissions for non-Windows platforms
if (platform !== 'win32') {
try {
// 755 permission: rwxr-xr-x
fs.chmodSync(destPath, '755')
} catch (error) {
console.warn(`Warning: Failed to set executable permissions: ${error.message}`)
}
console.log(`Extracted ${entry.name} -> ${outputPath}`)
}
}
await zip.close()
// Clean up
fs.unlinkSync(tempFilename)
fs.rmSync(sourceDir, { recursive: true })
console.log(`Successfully installed bun ${version} for ${platformKey}`)
return true
} catch (error) {

View File

@@ -2,33 +2,34 @@ const fs = require('fs')
const path = require('path')
const os = require('os')
const { execSync } = require('child_process')
const StreamZip = require('node-stream-zip')
const tar = require('tar')
const AdmZip = require('adm-zip')
const { downloadWithRedirects } = require('./download')
// Base URL for downloading uv binaries
const UV_RELEASE_BASE_URL = 'https://gitcode.com/CherryHQ/uv/releases/download'
const DEFAULT_UV_VERSION = '0.7.13'
const DEFAULT_UV_VERSION = '0.6.14'
// Mapping of platform+arch to binary package name
const UV_PACKAGES = {
'darwin-arm64': 'uv-aarch64-apple-darwin.zip',
'darwin-x64': 'uv-x86_64-apple-darwin.zip',
'darwin-arm64': 'uv-aarch64-apple-darwin.tar.gz',
'darwin-x64': 'uv-x86_64-apple-darwin.tar.gz',
'win32-arm64': 'uv-aarch64-pc-windows-msvc.zip',
'win32-ia32': 'uv-i686-pc-windows-msvc.zip',
'win32-x64': 'uv-x86_64-pc-windows-msvc.zip',
'linux-arm64': 'uv-aarch64-unknown-linux-gnu.zip',
'linux-ia32': 'uv-i686-unknown-linux-gnu.zip',
'linux-ppc64': 'uv-powerpc64-unknown-linux-gnu.zip',
'linux-ppc64le': 'uv-powerpc64le-unknown-linux-gnu.zip',
'linux-s390x': 'uv-s390x-unknown-linux-gnu.zip',
'linux-x64': 'uv-x86_64-unknown-linux-gnu.zip',
'linux-armv7l': 'uv-armv7-unknown-linux-gnueabihf.zip',
'linux-arm64': 'uv-aarch64-unknown-linux-gnu.tar.gz',
'linux-ia32': 'uv-i686-unknown-linux-gnu.tar.gz',
'linux-ppc64': 'uv-powerpc64-unknown-linux-gnu.tar.gz',
'linux-ppc64le': 'uv-powerpc64le-unknown-linux-gnu.tar.gz',
'linux-s390x': 'uv-s390x-unknown-linux-gnu.tar.gz',
'linux-x64': 'uv-x86_64-unknown-linux-gnu.tar.gz',
'linux-armv7l': 'uv-armv7-unknown-linux-gnueabihf.tar.gz',
// MUSL variants
'linux-musl-arm64': 'uv-aarch64-unknown-linux-musl.zip',
'linux-musl-ia32': 'uv-i686-unknown-linux-musl.zip',
'linux-musl-x64': 'uv-x86_64-unknown-linux-musl.zip',
'linux-musl-armv6l': 'uv-arm-unknown-linux-musleabihf.zip',
'linux-musl-armv7l': 'uv-armv7-unknown-linux-musleabihf.zip'
'linux-musl-arm64': 'uv-aarch64-unknown-linux-musl.tar.gz',
'linux-musl-ia32': 'uv-i686-unknown-linux-musl.tar.gz',
'linux-musl-x64': 'uv-x86_64-unknown-linux-musl.tar.gz',
'linux-musl-armv6l': 'uv-arm-unknown-linux-musleabihf.tar.gz',
'linux-musl-armv7l': 'uv-armv7-unknown-linux-musleabihf.tar.gz'
}
/**
@@ -65,35 +66,46 @@ async function downloadUvBinary(platform, arch, version = DEFAULT_UV_VERSION, is
console.log(`Extracting ${packageName} to ${binDir}...`)
const zip = new StreamZip.async({ file: tempFilename })
// 根据文件扩展名选择解压方法
if (packageName.endsWith('.zip')) {
// 使用 adm-zip 处理 zip 文件
const zip = new AdmZip(tempFilename)
zip.extractAllTo(binDir, true)
fs.unlinkSync(tempFilename)
console.log(`Successfully installed uv ${version} for ${platform}-${arch}`)
return true
} else {
// tar.gz 文件的处理保持不变
await tar.x({
file: tempFilename,
cwd: tempdir,
z: true
})
// Get all entries in the zip file
const entries = await zip.entries()
// Move files using Node.js fs
const sourceDir = path.join(tempdir, packageName.split('.')[0])
const files = fs.readdirSync(sourceDir)
for (const file of files) {
const sourcePath = path.join(sourceDir, file)
const destPath = path.join(binDir, file)
fs.copyFileSync(sourcePath, destPath)
fs.unlinkSync(sourcePath)
// Extract files directly to binDir, flattening the directory structure
for (const entry of Object.values(entries)) {
if (!entry.isDirectory) {
// Get just the filename without path
const filename = path.basename(entry.name)
const outputPath = path.join(binDir, filename)
console.log(`Extracting ${entry.name} -> ${filename}`)
await zip.extract(entry.name, outputPath)
// Make executable files executable on Unix-like systems
// Set executable permissions for non-Windows platforms
if (platform !== 'win32') {
try {
fs.chmodSync(outputPath, 0o755)
} catch (chmodError) {
console.error(`Warning: Failed to set executable permissions on ${filename}`)
return false
fs.chmodSync(destPath, '755')
} catch (error) {
console.warn(`Warning: Failed to set executable permissions: ${error.message}`)
}
}
console.log(`Extracted ${entry.name} -> ${outputPath}`)
}
// Clean up
fs.unlinkSync(tempFilename)
fs.rmSync(sourceDir, { recursive: true })
}
await zip.close()
fs.unlinkSync(tempFilename)
console.log(`Successfully installed uv ${version} for ${platform}-${arch}`)
return true
} catch (error) {

View File

@@ -23,9 +23,6 @@ exports.default = async function (context) {
const node_modules_path = path.join(context.appOutDir, 'resources', 'app.asar.unpacked', 'node_modules')
const _arch = arch === Arch.arm64 ? ['linux-arm64-gnu', 'linux-arm64-musl'] : ['linux-x64-gnu', 'linux-x64-musl']
keepPackageNodeFiles(node_modules_path, '@libsql', _arch)
// 删除 macOS 专用的 OCR 包
removeMacOnlyPackages(node_modules_path)
}
if (platform === 'windows') {
@@ -38,8 +35,6 @@ exports.default = async function (context) {
keepPackageNodeFiles(node_modules_path, '@strongtz', ['win32-x64-msvc'])
keepPackageNodeFiles(node_modules_path, '@libsql', ['win32-x64-msvc'])
}
removeMacOnlyPackages(node_modules_path)
}
if (platform === 'windows') {
@@ -48,22 +43,6 @@ exports.default = async function (context) {
}
}
/**
* 删除 macOS 专用的包
* @param {string} nodeModulesPath
*/
function removeMacOnlyPackages(nodeModulesPath) {
const macOnlyPackages = ['@cherrystudio/mac-system-ocr']
macOnlyPackages.forEach((packageName) => {
const packagePath = path.join(nodeModulesPath, packageName)
if (fs.existsSync(packagePath)) {
fs.rmSync(packagePath, { recursive: true, force: true })
console.log(`[After Pack] Removed macOS-only package: ${packageName}`)
}
})
}
/**
* 使用指定架构的 node_modules 文件
* @param {*} nodeModulesPath

View File

@@ -1,19 +1,16 @@
/**
* 使用 OpenAI 兼容的模型生成 i18n 文本,并更新到 translate 目录
*
* API_KEY=sk-xxxx BASE_URL=xxxx MODEL=xxxx ts-node scripts/update-i18n.ts
* Paratera_API_KEY=sk-abcxxxxxxxxxxxxxxxxxxxxxxx123 ts-node scripts/update-i18n.ts
*/
const API_KEY = process.env.API_KEY
const BASE_URL = process.env.BASE_URL || 'https://llmapi.paratera.com/v1'
const MODEL = process.env.MODEL || 'Qwen3-235B-A22B'
// OCOOL API KEY
const Paratera_API_KEY = process.env.Paratera_API_KEY
const INDEX = [
// 语言的名称代码用来翻译的模型
{ name: 'France', code: 'fr-fr', model: MODEL },
{ name: 'Spanish', code: 'es-es', model: MODEL },
{ name: 'Portuguese', code: 'pt-pt', model: MODEL },
{ name: 'Greek', code: 'el-gr', model: MODEL }
// 语言的名称 代码 用来翻译的模型
{ name: 'France', code: 'fr-fr', model: 'Qwen3-235B-A22B' },
{ name: 'Spanish', code: 'es-es', model: 'Qwen3-235B-A22B' },
{ name: 'Portuguese', code: 'pt-pt', model: 'Qwen3-235B-A22B' },
{ name: 'Greek', code: 'el-gr', model: 'Qwen3-235B-A22B' }
]
const fs = require('fs')
@@ -22,8 +19,8 @@ import OpenAI from 'openai'
const zh = JSON.parse(fs.readFileSync('src/renderer/src/i18n/locales/zh-cn.json', 'utf8')) as object
const openai = new OpenAI({
apiKey: API_KEY,
baseURL: BASE_URL
apiKey: Paratera_API_KEY,
baseURL: 'https://llmapi.paratera.com/v1'
})
// 递归遍历翻译

View File

@@ -1,33 +0,0 @@
import { occupiedDirs } from '@shared/config/constant'
import { app } from 'electron'
import fs from 'fs'
import path from 'path'
import { initAppDataDir } from './utils/file'
app.isPackaged && initAppDataDir()
// 在主进程中复制 appData 中某些一直被占用的文件
// 在renderer进程还没有启动时主进程可以复制这些文件到新的appData中
function copyOccupiedDirsInMainProcess() {
const newAppDataPath = process.argv
.slice(1)
.find((arg) => arg.startsWith('--new-data-path='))
?.split('--new-data-path=')[1]
if (!newAppDataPath) {
return
}
if (process.platform === 'win32') {
const appDataPath = app.getPath('userData')
occupiedDirs.forEach((dir) => {
const dirPath = path.join(appDataPath, dir)
const newDirPath = path.join(newAppDataPath, dir)
if (fs.existsSync(dirPath)) {
fs.cpSync(dirPath, newDirPath, { recursive: true })
}
})
}
}
copyOccupiedDirsInMainProcess()

View File

@@ -1,6 +1,7 @@
import { app } from 'electron'
import { getDataPath } from './utils'
const isDev = process.env.NODE_ENV === 'development'
if (isDev) {

View File

@@ -1,6 +1,6 @@
interface IFilterList {
WINDOWS: string[]
MAC: string[]
MAC?: string[]
}
interface IFinetunedList {
@@ -45,17 +45,14 @@ export const SELECTION_PREDEFINED_BLACKLIST: IFilterList = {
'sldworks.exe',
// Remote Desktop
'mstsc.exe'
],
MAC: []
]
}
export const SELECTION_FINETUNED_LIST: IFinetunedList = {
EXCLUDE_CLIPBOARD_CURSOR_DETECT: {
WINDOWS: ['acrobat.exe', 'wps.exe', 'cajviewer.exe'],
MAC: []
WINDOWS: ['acrobat.exe', 'wps.exe', 'cajviewer.exe']
},
INCLUDE_CLIPBOARD_DELAY_READ: {
WINDOWS: ['acrobat.exe', 'wps.exe', 'cajviewer.exe', 'foxitphantom.exe'],
MAC: []
WINDOWS: ['acrobat.exe', 'wps.exe', 'cajviewer.exe', 'foxitphantom.exe']
}
}

View File

@@ -1,8 +1,3 @@
// don't reorder this file, it's used to initialize the app data dir and
// other which should be run before the main process is ready
// eslint-disable-next-line
import './bootstrap'
import '@main/config'
import { electronApp, optimizer } from '@electron-toolkit/utils'
@@ -25,6 +20,7 @@ import selectionService, { initSelectionService } from './services/SelectionServ
import { registerShortcuts } from './services/ShortcutService'
import { TrayService } from './services/TrayService'
import { windowService } from './services/WindowService'
import { setUserDataDir } from './utils/file'
Logger.initialize()
@@ -76,6 +72,9 @@ if (!app.requestSingleInstanceLock()) {
app.quit()
process.exit(0)
} else {
// Portable dir must be setup before app ready
setUserDataDir()
// This method will be called when Electron has finished
// initialization and is ready to create browser windows.
// Some APIs can only be used after this event occurs.
@@ -124,27 +123,19 @@ if (!app.requestSingleInstanceLock()) {
registerProtocolClient(app)
// macOS specific: handle protocol when app is already running
app.on('open-url', (event, url) => {
event.preventDefault()
handleProtocolUrl(url)
})
const handleOpenUrl = (args: string[]) => {
const url = args.find((arg) => arg.startsWith(CHERRY_STUDIO_PROTOCOL + '://'))
if (url) handleProtocolUrl(url)
}
// for windows to start with url
handleOpenUrl(process.argv)
// Listen for second instance
app.on('second-instance', (_event, argv) => {
windowService.showMainWindow()
// Protocol handler for Windows/Linux
// The commandLine is an array of strings where the last item might be the URL
handleOpenUrl(argv)
const url = argv.find((arg) => arg.startsWith(CHERRY_STUDIO_PROTOCOL + '://'))
if (url) handleProtocolUrl(url)
})
app.on('browser-window-created', (_, window) => {

View File

@@ -1,14 +1,13 @@
import fs from 'node:fs'
import { arch } from 'node:os'
import path from 'node:path'
import { isLinux, isMac, isWin } from '@main/constant'
import { isMac, isWin } from '@main/constant'
import { getBinaryPath, isBinaryExists, runInstallScript } from '@main/utils/process'
import { handleZoomFactor } from '@main/utils/zoom'
import { UpgradeChannel } from '@shared/config/constant'
import { FeedUrl } from '@shared/config/constant'
import { IpcChannel } from '@shared/IpcChannel'
import { FileMetadata, Provider, Shortcut, ThemeMode } from '@types'
import { BrowserWindow, dialog, ipcMain, session, shell, systemPreferences, webContents } from 'electron'
import { Shortcut, ThemeMode } from '@types'
import { BrowserWindow, ipcMain, session, shell } from 'electron'
import log from 'electron-log'
import { Notification } from 'src/renderer/src/types/notification'
@@ -17,16 +16,14 @@ import BackupManager from './services/BackupManager'
import { configManager } from './services/ConfigManager'
import CopilotService from './services/CopilotService'
import { ExportService } from './services/ExportService'
import FileService from './services/FileService'
import FileStorage from './services/FileStorage'
import FileService from './services/FileSystemService'
import KnowledgeService from './services/KnowledgeService'
import mcpService from './services/MCPService'
import NotificationService from './services/NotificationService'
import * as NutstoreService from './services/NutstoreService'
import ObsidianVaultService from './services/ObsidianVaultService'
import { ProxyConfig, proxyManager } from './services/ProxyManager'
import { pythonService } from './services/PythonService'
import { FileServiceManager } from './services/remotefile/FileServiceManager'
import { searchService } from './services/SearchService'
import { SelectionService } from './services/SelectionService'
import { registerShortcuts, unregisterAllShortcuts } from './services/ShortcutService'
@@ -37,7 +34,7 @@ import { setOpenLinkExternal } from './services/WebviewService'
import { windowService } from './services/WindowService'
import { calculateDirectorySize, getResourcePath } from './utils'
import { decrypt, encrypt } from './utils/aes'
import { getCacheDir, getConfigDir, getFilesDir, hasWritePermission, updateAppDataConfig } from './utils/file'
import { getCacheDir, getConfigDir, getFilesDir } from './utils/file'
import { compress, decompress } from './utils/zip'
const fileManager = new FileStorage()
@@ -50,9 +47,6 @@ export function registerIpc(mainWindow: BrowserWindow, app: Electron.App) {
const appUpdater = new AppUpdater(mainWindow)
const notificationService = new NotificationService(mainWindow)
// Initialize Python service with main window
pythonService.setMainWindow(mainWindow)
ipcMain.handle(IpcChannel.App_Info, () => ({
version: app.getVersion(),
isPackaged: app.isPackaged,
@@ -63,8 +57,7 @@ export function registerIpc(mainWindow: BrowserWindow, app: Electron.App) {
resourcesPath: getResourcePath(),
logsPath: log.transports.file.getFile().path,
arch: arch(),
isPortable: isWin && 'PORTABLE_EXECUTABLE_DIR' in process.env,
installPath: path.dirname(app.getPath('exe'))
isPortable: isWin && 'PORTABLE_EXECUTABLE_DIR' in process.env
}))
ipcMain.handle(IpcChannel.App_Proxy, async (_, proxy: string) => {
@@ -92,27 +85,6 @@ export function registerIpc(mainWindow: BrowserWindow, app: Electron.App) {
configManager.setLanguage(language)
})
// spell check
ipcMain.handle(IpcChannel.App_SetEnableSpellCheck, (_, isEnable: boolean) => {
// disable spell check for all webviews
const webviews = webContents.getAllWebContents()
webviews.forEach((webview) => {
webview.session.setSpellCheckerEnabled(isEnable)
})
})
// spell check languages
ipcMain.handle(IpcChannel.App_SetSpellCheckLanguages, (_, languages: string[]) => {
if (languages.length === 0) {
return
}
const windows = BrowserWindow.getAllWindows()
windows.forEach((window) => {
window.webContents.session.setSpellCheckerLanguages(languages)
})
configManager.set('spellCheckLanguages', languages)
})
// launch on boot
ipcMain.handle(IpcChannel.App_SetLaunchOnBoot, (_, openAtLogin: boolean) => {
// Set login item settings for windows and mac
@@ -143,34 +115,10 @@ export function registerIpc(mainWindow: BrowserWindow, app: Electron.App) {
configManager.setAutoUpdate(isActive)
})
ipcMain.handle(IpcChannel.App_SetTestPlan, async (_, isActive: boolean) => {
log.info('set test plan', isActive)
if (isActive !== configManager.getTestPlan()) {
appUpdater.cancelDownload()
configManager.setTestPlan(isActive)
}
ipcMain.handle(IpcChannel.App_SetFeedUrl, (_, feedUrl: FeedUrl) => {
appUpdater.setFeedUrl(feedUrl)
})
ipcMain.handle(IpcChannel.App_SetTestChannel, async (_, channel: UpgradeChannel) => {
log.info('set test channel', channel)
if (channel !== configManager.getTestChannel()) {
appUpdater.cancelDownload()
configManager.setTestChannel(channel)
}
})
//only for mac
if (isMac) {
ipcMain.handle(IpcChannel.App_MacIsProcessTrusted, (): boolean => {
return systemPreferences.isTrustedAccessibilityClient(false)
})
//return is only the current state, not the new state
ipcMain.handle(IpcChannel.App_MacRequestProcessTrust, (): boolean => {
return systemPreferences.isTrustedAccessibilityClient(true)
})
}
ipcMain.handle(IpcChannel.Config_Set, (_, key: string, value: any, isNotify: boolean = false) => {
configManager.set(key, value, isNotify)
})
@@ -227,113 +175,6 @@ export function registerIpc(mainWindow: BrowserWindow, app: Electron.App) {
}
})
let preventQuitListener: ((event: Electron.Event) => void) | null = null
ipcMain.handle(IpcChannel.App_SetStopQuitApp, (_, stop: boolean = false, reason: string = '') => {
if (stop) {
// Only add listener if not already added
if (!preventQuitListener) {
preventQuitListener = (event: Electron.Event) => {
event.preventDefault()
notificationService.sendNotification({
title: reason,
message: reason
} as Notification)
}
app.on('before-quit', preventQuitListener)
}
} else {
// Remove listener if it exists
if (preventQuitListener) {
app.removeListener('before-quit', preventQuitListener)
preventQuitListener = null
}
}
})
// Select app data path
ipcMain.handle(IpcChannel.App_Select, async (_, options: Electron.OpenDialogOptions) => {
try {
const { canceled, filePaths } = await dialog.showOpenDialog(options)
if (canceled || filePaths.length === 0) {
return null
}
return filePaths[0]
} catch (error: any) {
log.error('Failed to select app data path:', error)
return null
}
})
ipcMain.handle(IpcChannel.App_HasWritePermission, async (_, filePath: string) => {
return hasWritePermission(filePath)
})
// Set app data path
ipcMain.handle(IpcChannel.App_SetAppDataPath, async (_, filePath: string) => {
updateAppDataConfig(filePath)
app.setPath('userData', filePath)
})
ipcMain.handle(IpcChannel.App_GetDataPathFromArgs, () => {
return process.argv
.slice(1)
.find((arg) => arg.startsWith('--new-data-path='))
?.split('--new-data-path=')[1]
})
ipcMain.handle(IpcChannel.App_FlushAppData, () => {
BrowserWindow.getAllWindows().forEach((w) => {
w.webContents.session.flushStorageData()
w.webContents.session.cookies.flushStore()
w.webContents.session.closeAllConnections()
})
session.defaultSession.flushStorageData()
session.defaultSession.cookies.flushStore()
session.defaultSession.closeAllConnections()
})
ipcMain.handle(IpcChannel.App_IsNotEmptyDir, async (_, path: string) => {
return fs.readdirSync(path).length > 0
})
// Copy user data to new location
ipcMain.handle(IpcChannel.App_Copy, async (_, oldPath: string, newPath: string, occupiedDirs: string[] = []) => {
try {
await fs.promises.cp(oldPath, newPath, {
recursive: true,
filter: (src) => {
if (occupiedDirs.some((dir) => src.startsWith(path.resolve(dir)))) {
return false
}
return true
}
})
return { success: true }
} catch (error: any) {
log.error('Failed to copy user data:', error)
return { success: false, error: error.message }
}
})
// Relaunch app
ipcMain.handle(IpcChannel.App_RelaunchApp, (_, options?: Electron.RelaunchOptions) => {
// Fix for .AppImage
if (isLinux && process.env.APPIMAGE) {
log.info('Relaunching app with options:', process.env.APPIMAGE, options)
// On Linux, we need to use the APPIMAGE environment variable to relaunch
// https://github.com/electron-userland/electron-builder/issues/1727#issuecomment-769896927
options = options || {}
options.execPath = process.env.APPIMAGE
options.args = options.args || []
options.args.unshift('--appimage-extract-and-run')
}
app.relaunch(options)
app.exit(0)
})
// check for update
ipcMain.handle(IpcChannel.App_CheckForUpdate, async () => {
return await appUpdater.checkForUpdates()
@@ -378,10 +219,9 @@ export function registerIpc(mainWindow: BrowserWindow, app: Electron.App) {
ipcMain.handle(IpcChannel.File_Clear, fileManager.clear)
ipcMain.handle(IpcChannel.File_Read, fileManager.readFile)
ipcMain.handle(IpcChannel.File_Delete, fileManager.deleteFile)
ipcMain.handle('file:deleteDir', fileManager.deleteDir)
ipcMain.handle(IpcChannel.File_Get, fileManager.getFile)
ipcMain.handle(IpcChannel.File_SelectFolder, fileManager.selectFolder)
ipcMain.handle(IpcChannel.File_CreateTempFile, fileManager.createTempFile)
ipcMain.handle(IpcChannel.File_Create, fileManager.createTempFile)
ipcMain.handle(IpcChannel.File_Write, fileManager.writeFile)
ipcMain.handle(IpcChannel.File_WriteWithId, fileManager.writeFileWithId)
ipcMain.handle(IpcChannel.File_SaveImage, fileManager.saveImage)
@@ -393,27 +233,6 @@ export function registerIpc(mainWindow: BrowserWindow, app: Electron.App) {
ipcMain.handle(IpcChannel.File_Copy, fileManager.copyFile)
ipcMain.handle(IpcChannel.File_BinaryImage, fileManager.binaryImage)
// file service
ipcMain.handle(IpcChannel.FileService_Upload, async (_, provider: Provider, file: FileMetadata) => {
const service = FileServiceManager.getInstance().getService(provider)
return await service.uploadFile(file)
})
ipcMain.handle(IpcChannel.FileService_List, async (_, provider: Provider) => {
const service = FileServiceManager.getInstance().getService(provider)
return await service.listFiles()
})
ipcMain.handle(IpcChannel.FileService_Delete, async (_, provider: Provider, fileId: string) => {
const service = FileServiceManager.getInstance().getService(provider)
return await service.deleteFile(fileId)
})
ipcMain.handle(IpcChannel.FileService_Retrieve, async (_, provider: Provider, fileId: string) => {
const service = FileServiceManager.getInstance().getService(provider)
return await service.retrieveFile(fileId)
})
// fs
ipcMain.handle(IpcChannel.Fs_Read, FileService.readFile)
@@ -443,7 +262,6 @@ export function registerIpc(mainWindow: BrowserWindow, app: Electron.App) {
ipcMain.handle(IpcChannel.KnowledgeBase_Remove, KnowledgeService.remove)
ipcMain.handle(IpcChannel.KnowledgeBase_Search, KnowledgeService.search)
ipcMain.handle(IpcChannel.KnowledgeBase_Rerank, KnowledgeService.rerank)
ipcMain.handle(IpcChannel.KnowledgeBase_Check_Quota, KnowledgeService.checkQuota)
// window
ipcMain.handle(IpcChannel.Windows_SetMinimumSize, (_, width: number, height: number) => {
@@ -495,14 +313,6 @@ export function registerIpc(mainWindow: BrowserWindow, app: Electron.App) {
ipcMain.handle(IpcChannel.Mcp_GetInstallInfo, mcpService.getInstallInfo)
ipcMain.handle(IpcChannel.Mcp_CheckConnectivity, mcpService.checkMcpConnectivity)
// Register Python execution handler
ipcMain.handle(
IpcChannel.Python_Execute,
async (_, script: string, context?: Record<string, any>, timeout?: number) => {
return await pythonService.executeScript(script, context, timeout)
}
)
ipcMain.handle(IpcChannel.App_IsBinaryExist, (_, name: string) => isBinaryExists(name))
ipcMain.handle(IpcChannel.App_GetBinaryPath, (_, name: string) => getBinaryPath(name))
ipcMain.handle(IpcChannel.App_InstallUvBinary, () => runInstallScript('install-uv.js'))
@@ -548,12 +358,6 @@ export function registerIpc(mainWindow: BrowserWindow, app: Electron.App) {
setOpenLinkExternal(webviewId, isExternal)
)
ipcMain.handle(IpcChannel.Webview_SetSpellCheckEnabled, (_, webviewId: number, isEnable: boolean) => {
const webview = webContents.fromId(webviewId)
if (!webview) return
webview.session.setSpellCheckerEnabled(isEnable)
})
// store sync
storeSyncService.registerIpcHandler()

View File

@@ -1,44 +0,0 @@
import { BaseLoader } from '@cherrystudio/embedjs-interfaces'
import { cleanString } from '@cherrystudio/embedjs-utils'
import { RecursiveCharacterTextSplitter } from '@langchain/textsplitters'
import md5 from 'md5'
export class NoteLoader extends BaseLoader<{ type: 'NoteLoader' }> {
private readonly text: string
private readonly sourceUrl?: string
constructor({
text,
sourceUrl,
chunkSize,
chunkOverlap
}: {
text: string
sourceUrl?: string
chunkSize?: number
chunkOverlap?: number
}) {
super(`NoteLoader_${md5(text + (sourceUrl || ''))}`, { text, sourceUrl }, chunkSize ?? 2000, chunkOverlap ?? 0)
this.text = text
this.sourceUrl = sourceUrl
}
override async *getUnfilteredChunks() {
const chunker = new RecursiveCharacterTextSplitter({
chunkSize: this.chunkSize,
chunkOverlap: this.chunkOverlap
})
const chunks = await chunker.splitText(cleanString(this.text))
for (const chunk of chunks) {
yield {
pageContent: chunk,
metadata: {
type: 'NoteLoader' as const,
source: this.sourceUrl || 'note'
}
}
}
}
}

View File

@@ -4,7 +4,7 @@ import { JsonLoader, LocalPathLoader, RAGApplication, TextLoader } from '@cherry
import type { AddLoaderReturn } from '@cherrystudio/embedjs-interfaces'
import { WebLoader } from '@cherrystudio/embedjs-loader-web'
import { LoaderReturn } from '@shared/config/types'
import { FileMetadata, KnowledgeBaseParams } from '@types'
import { FileType, KnowledgeBaseParams } from '@types'
import Logger from 'electron-log'
import { DraftsExportLoader } from './draftsExportLoader'
@@ -16,7 +16,6 @@ const FILE_LOADER_MAP: Record<string, string> = {
// 内置类型
'.pdf': 'common',
'.csv': 'common',
'.doc': 'common',
'.docx': 'common',
'.pptx': 'common',
'.xlsx': 'common',
@@ -39,7 +38,7 @@ const FILE_LOADER_MAP: Record<string, string> = {
export async function addOdLoader(
ragApplication: RAGApplication,
file: FileMetadata,
file: FileType,
base: KnowledgeBaseParams,
forceReload: boolean
): Promise<AddLoaderReturn> {
@@ -65,7 +64,7 @@ export async function addOdLoader(
export async function addFileLoader(
ragApplication: RAGApplication,
file: FileMetadata,
file: FileType,
base: KnowledgeBaseParams,
forceReload: boolean
): Promise<LoaderReturn> {

View File

@@ -6,7 +6,6 @@ import DifyKnowledgeServer from './dify-knowledge'
import FetchServer from './fetch'
import FileSystemServer from './filesystem'
import MemoryServer from './memory'
import PythonServer from './python'
import ThinkingServer from './sequentialthinking'
export function createInMemoryMCPServer(name: string, args: string[] = [], envs: Record<string, string> = {}): Server {
@@ -32,9 +31,6 @@ export function createInMemoryMCPServer(name: string, args: string[] = [], envs:
const difyKey = envs.DIFY_KEY
return new DifyKnowledgeServer(difyKey, args).server
}
case '@cherry/python': {
return new PythonServer().server
}
default:
throw new Error(`Unknown in-memory MCP server: ${name}`)
}

View File

@@ -1,113 +0,0 @@
import { pythonService } from '@main/services/PythonService'
import { Server } from '@modelcontextprotocol/sdk/server/index.js'
import { CallToolRequestSchema, ErrorCode, ListToolsRequestSchema, McpError } from '@modelcontextprotocol/sdk/types.js'
import Logger from 'electron-log'
/**
* Python MCP Server for executing Python code using Pyodide
*/
class PythonServer {
public server: Server
constructor() {
this.server = new Server(
{
name: 'python-server',
version: '1.0.0'
},
{
capabilities: {
tools: {}
}
}
)
this.setupRequestHandlers()
}
private setupRequestHandlers() {
// List available tools
this.server.setRequestHandler(ListToolsRequestSchema, async () => {
return {
tools: [
{
name: 'python_execute',
description: `Execute Python code using Pyodide in a sandboxed environment. Supports most Python standard library and scientific packages.
The code will be executed with Python 3.12.
Dependencies may be defined via PEP 723 script metadata, e.g. to install "pydantic", the script should start
with a comment of the form:
# /// script
# dependencies = ['pydantic']
# ///
print('python code here')`,
inputSchema: {
type: 'object',
properties: {
code: {
type: 'string',
description: 'The Python code to execute'
},
context: {
type: 'object',
description: 'Optional context variables to pass to the Python execution environment',
additionalProperties: true
},
timeout: {
type: 'number',
description: 'Timeout in milliseconds (default: 60000)',
default: 60000
}
},
required: ['code']
}
}
]
}
})
// Handle tool calls
this.server.setRequestHandler(CallToolRequestSchema, async (request) => {
const { name, arguments: args } = request.params
if (name !== 'python_execute') {
throw new McpError(ErrorCode.MethodNotFound, `Tool ${name} not found`)
}
try {
const {
code,
context = {},
timeout = 60000
} = args as {
code: string
context?: Record<string, any>
timeout?: number
}
if (!code || typeof code !== 'string') {
throw new McpError(ErrorCode.InvalidParams, 'Code parameter is required and must be a string')
}
Logger.info('Executing Python code via Pyodide')
const result = await pythonService.executeScript(code, context, timeout)
return {
content: [
{
type: 'text',
text: result
}
]
}
} catch (error) {
const errorMessage = error instanceof Error ? error.message : String(error)
Logger.error('Python execution error:', errorMessage)
throw new McpError(ErrorCode.InternalError, `Python execution failed: ${errorMessage}`)
}
})
}
}
export default PythonServer

View File

@@ -106,7 +106,6 @@ class SequentialThinkingServer {
type: 'text',
text: JSON.stringify(
{
thought: validatedInput.thought,
thoughtNumber: validatedInput.thoughtNumber,
totalThoughts: validatedInput.totalThoughts,
nextThoughtNeeded: validatedInput.nextThoughtNeeded,

View File

@@ -1,122 +0,0 @@
import fs from 'node:fs'
import path from 'node:path'
import { windowService } from '@main/services/WindowService'
import { getFileExt } from '@main/utils/file'
import { FileMetadata, OcrProvider } from '@types'
import { app } from 'electron'
import { TypedArray } from 'pdfjs-dist/types/src/display/api'
export default abstract class BaseOcrProvider {
protected provider: OcrProvider
public storageDir = path.join(app.getPath('userData'), 'Data', 'Files')
constructor(provider: OcrProvider) {
if (!provider) {
throw new Error('OCR provider is not set')
}
this.provider = provider
}
abstract parseFile(sourceId: string, file: FileMetadata): Promise<{ processedFile: FileMetadata; quota?: number }>
/**
* 检查文件是否已经被预处理过
* 统一检测方法:如果 Data/Files/{file.id} 是目录,说明已被预处理
* @param file 文件信息
* @returns 如果已处理返回处理后的文件信息否则返回null
*/
public async checkIfAlreadyProcessed(file: FileMetadata): Promise<FileMetadata | null> {
try {
// 检查 Data/Files/{file.id} 是否是目录
const preprocessDirPath = path.join(this.storageDir, file.id)
if (fs.existsSync(preprocessDirPath)) {
const stats = await fs.promises.stat(preprocessDirPath)
// 如果是目录,说明已经被预处理过
if (stats.isDirectory()) {
// 查找目录中的处理结果文件
const files = await fs.promises.readdir(preprocessDirPath)
// 查找主要的处理结果文件(.md 或 .txt
const processedFile = files.find((fileName) => fileName.endsWith('.md') || fileName.endsWith('.txt'))
if (processedFile) {
const processedFilePath = path.join(preprocessDirPath, processedFile)
const processedStats = await fs.promises.stat(processedFilePath)
const ext = getFileExt(processedFile)
return {
...file,
name: file.name.replace(file.ext, ext),
path: processedFilePath,
ext: ext,
size: processedStats.size,
created_at: processedStats.birthtime.toISOString()
}
}
}
}
return null
} catch (error) {
// 如果检查过程中出现错误返回null表示未处理
return null
}
}
/**
* 辅助方法:延迟执行
*/
public delay = (ms: number): Promise<void> => {
return new Promise((resolve) => setTimeout(resolve, ms))
}
public async readPdf(
source: string | URL | TypedArray,
passwordCallback?: (fn: (password: string) => void, reason: string) => string
) {
const { getDocument } = await import('pdfjs-dist/legacy/build/pdf.mjs')
const documentLoadingTask = getDocument(source)
if (passwordCallback) {
documentLoadingTask.onPassword = passwordCallback
}
const document = await documentLoadingTask.promise
return document
}
public async sendOcrProgress(sourceId: string, progress: number): Promise<void> {
const mainWindow = windowService.getMainWindow()
mainWindow?.webContents.send('file-ocr-progress', {
itemId: sourceId,
progress: progress
})
}
/**
* 将文件移动到附件目录
* @param fileId 文件id
* @param filePaths 需要移动的文件路径数组
* @returns 移动后的文件路径数组
*/
public moveToAttachmentsDir(fileId: string, filePaths: string[]): string[] {
const attachmentsPath = path.join(this.storageDir, fileId)
if (!fs.existsSync(attachmentsPath)) {
fs.mkdirSync(attachmentsPath, { recursive: true })
}
const movedPaths: string[] = []
for (const filePath of filePaths) {
if (fs.existsSync(filePath)) {
const fileName = path.basename(filePath)
const destPath = path.join(attachmentsPath, fileName)
fs.copyFileSync(filePath, destPath)
fs.unlinkSync(filePath) // 删除原文件,实现"移动"
movedPaths.push(destPath)
}
}
return movedPaths
}
}

View File

@@ -1,12 +0,0 @@
import { FileMetadata, OcrProvider } from '@types'
import BaseOcrProvider from './BaseOcrProvider'
export default class DefaultOcrProvider extends BaseOcrProvider {
constructor(provider: OcrProvider) {
super(provider)
}
public parseFile(): Promise<{ processedFile: FileMetadata }> {
throw new Error('Method not implemented.')
}
}

View File

@@ -1,128 +0,0 @@
import { isMac } from '@main/constant'
import { FileMetadata, OcrProvider } from '@types'
import Logger from 'electron-log'
import * as fs from 'fs'
import * as path from 'path'
import { TextItem } from 'pdfjs-dist/types/src/display/api'
import BaseOcrProvider from './BaseOcrProvider'
export default class MacSysOcrProvider extends BaseOcrProvider {
private readonly MIN_TEXT_LENGTH = 1000
private MacOCR: any
private async initMacOCR() {
if (!isMac) {
throw new Error('MacSysOcrProvider is only available on macOS')
}
if (!this.MacOCR) {
try {
// @ts-ignore This module is optional and only installed/available on macOS. Runtime checks prevent execution on other platforms.
const module = await import('@cherrystudio/mac-system-ocr')
this.MacOCR = module.default
} catch (error) {
Logger.error('[OCR] Failed to load mac-system-ocr:', error)
throw error
}
}
return this.MacOCR
}
private getRecognitionLevel(level?: number) {
return level === 0 ? this.MacOCR.RECOGNITION_LEVEL_FAST : this.MacOCR.RECOGNITION_LEVEL_ACCURATE
}
constructor(provider: OcrProvider) {
super(provider)
}
private async processPages(
results: any,
totalPages: number,
sourceId: string,
writeStream: fs.WriteStream
): Promise<void> {
await this.initMacOCR()
// TODO: 下个版本后面使用批处理以及p-queue来优化
for (let i = 0; i < totalPages; i++) {
// Convert pages to buffers
const pageNum = i + 1
const pageBuffer = await results.getPage(pageNum)
// Process batch
const ocrResult = await this.MacOCR.recognizeFromBuffer(pageBuffer, {
ocrOptions: {
recognitionLevel: this.getRecognitionLevel(this.provider.options?.recognitionLevel),
minConfidence: this.provider.options?.minConfidence || 0.5
}
})
// Write results in order
writeStream.write(ocrResult.text + '\n')
// Update progress
await this.sendOcrProgress(sourceId, (pageNum / totalPages) * 100)
}
}
public async isScanPdf(buffer: Buffer): Promise<boolean> {
const doc = await this.readPdf(new Uint8Array(buffer))
const pageLength = doc.numPages
let counts = 0
const pagesToCheck = Math.min(pageLength, 10)
for (let i = 0; i < pagesToCheck; i++) {
const page = await doc.getPage(i + 1)
const pageData = await page.getTextContent()
const pageText = pageData.items.map((item) => (item as TextItem).str).join('')
counts += pageText.length
if (counts >= this.MIN_TEXT_LENGTH) {
return false
}
}
return true
}
public async parseFile(sourceId: string, file: FileMetadata): Promise<{ processedFile: FileMetadata }> {
Logger.info(`[OCR] Starting OCR process for file: ${file.name}`)
if (file.ext === '.pdf') {
try {
const { pdf } = await import('@cherrystudio/pdf-to-img-napi')
const pdfBuffer = await fs.promises.readFile(file.path)
const results = await pdf(pdfBuffer, {
scale: 2
})
const totalPages = results.length
const baseDir = path.dirname(file.path)
const baseName = path.basename(file.path, path.extname(file.path))
const txtFileName = `${baseName}.txt`
const txtFilePath = path.join(baseDir, txtFileName)
const writeStream = fs.createWriteStream(txtFilePath)
await this.processPages(results, totalPages, sourceId, writeStream)
await new Promise<void>((resolve, reject) => {
writeStream.end(() => {
Logger.info(`[OCR] OCR process completed successfully for ${file.origin_name}`)
resolve()
})
writeStream.on('error', reject)
})
const movedPaths = this.moveToAttachmentsDir(file.id, [txtFilePath])
return {
processedFile: {
...file,
name: txtFileName,
path: movedPaths[0],
ext: '.txt',
size: fs.statSync(movedPaths[0]).size
}
}
} catch (error) {
Logger.error('[OCR] Error during OCR process:', error)
throw error
}
}
return { processedFile: file }
}
}

View File

@@ -1,26 +0,0 @@
import { FileMetadata, OcrProvider as Provider } from '@types'
import BaseOcrProvider from './BaseOcrProvider'
import OcrProviderFactory from './OcrProviderFactory'
export default class OcrProvider {
private sdk: BaseOcrProvider
constructor(provider: Provider) {
this.sdk = OcrProviderFactory.create(provider)
}
public async parseFile(
sourceId: string,
file: FileMetadata
): Promise<{ processedFile: FileMetadata; quota?: number }> {
return this.sdk.parseFile(sourceId, file)
}
/**
* 检查文件是否已经被预处理过
* @param file 文件信息
* @returns 如果已处理返回处理后的文件信息否则返回null
*/
public async checkIfAlreadyProcessed(file: FileMetadata): Promise<FileMetadata | null> {
return this.sdk.checkIfAlreadyProcessed(file)
}
}

View File

@@ -1,20 +0,0 @@
import { isMac } from '@main/constant'
import { OcrProvider } from '@types'
import Logger from 'electron-log'
import BaseOcrProvider from './BaseOcrProvider'
import DefaultOcrProvider from './DefaultOcrProvider'
import MacSysOcrProvider from './MacSysOcrProvider'
export default class OcrProviderFactory {
static create(provider: OcrProvider): BaseOcrProvider {
switch (provider.id) {
case 'system':
if (!isMac) {
Logger.warn('[OCR] System OCR provider is only available on macOS')
}
return new MacSysOcrProvider(provider)
default:
return new DefaultOcrProvider(provider)
}
}
}

View File

@@ -1,126 +0,0 @@
import fs from 'node:fs'
import path from 'node:path'
import { windowService } from '@main/services/WindowService'
import { getFileExt } from '@main/utils/file'
import { FileMetadata, PreprocessProvider } from '@types'
import { app } from 'electron'
import { TypedArray } from 'pdfjs-dist/types/src/display/api'
export default abstract class BasePreprocessProvider {
protected provider: PreprocessProvider
protected userId?: string
public storageDir = path.join(app.getPath('userData'), 'Data', 'Files')
constructor(provider: PreprocessProvider, userId?: string) {
if (!provider) {
throw new Error('Preprocess provider is not set')
}
this.provider = provider
this.userId = userId
}
abstract parseFile(sourceId: string, file: FileMetadata): Promise<{ processedFile: FileMetadata; quota?: number }>
abstract checkQuota(): Promise<number>
/**
* 检查文件是否已经被预处理过
* 统一检测方法:如果 Data/Files/{file.id} 是目录,说明已被预处理
* @param file 文件信息
* @returns 如果已处理返回处理后的文件信息否则返回null
*/
public async checkIfAlreadyProcessed(file: FileMetadata): Promise<FileMetadata | null> {
try {
// 检查 Data/Files/{file.id} 是否是目录
const preprocessDirPath = path.join(this.storageDir, file.id)
if (fs.existsSync(preprocessDirPath)) {
const stats = await fs.promises.stat(preprocessDirPath)
// 如果是目录,说明已经被预处理过
if (stats.isDirectory()) {
// 查找目录中的处理结果文件
const files = await fs.promises.readdir(preprocessDirPath)
// 查找主要的处理结果文件(.md 或 .txt
const processedFile = files.find((fileName) => fileName.endsWith('.md') || fileName.endsWith('.txt'))
if (processedFile) {
const processedFilePath = path.join(preprocessDirPath, processedFile)
const processedStats = await fs.promises.stat(processedFilePath)
const ext = getFileExt(processedFile)
return {
...file,
name: file.name.replace(file.ext, ext),
path: processedFilePath,
ext: ext,
size: processedStats.size,
created_at: processedStats.birthtime.toISOString()
}
}
}
}
return null
} catch (error) {
// 如果检查过程中出现错误返回null表示未处理
return null
}
}
/**
* 辅助方法:延迟执行
*/
public delay = (ms: number): Promise<void> => {
return new Promise((resolve) => setTimeout(resolve, ms))
}
public async readPdf(
source: string | URL | TypedArray,
passwordCallback?: (fn: (password: string) => void, reason: string) => string
) {
const { getDocument } = await import('pdfjs-dist/legacy/build/pdf.mjs')
const documentLoadingTask = getDocument(source)
if (passwordCallback) {
documentLoadingTask.onPassword = passwordCallback
}
const document = await documentLoadingTask.promise
return document
}
public async sendPreprocessProgress(sourceId: string, progress: number): Promise<void> {
const mainWindow = windowService.getMainWindow()
mainWindow?.webContents.send('file-preprocess-progress', {
itemId: sourceId,
progress: progress
})
}
/**
* 将文件移动到附件目录
* @param fileId 文件id
* @param filePaths 需要移动的文件路径数组
* @returns 移动后的文件路径数组
*/
public moveToAttachmentsDir(fileId: string, filePaths: string[]): string[] {
const attachmentsPath = path.join(this.storageDir, fileId)
if (!fs.existsSync(attachmentsPath)) {
fs.mkdirSync(attachmentsPath, { recursive: true })
}
const movedPaths: string[] = []
for (const filePath of filePaths) {
if (fs.existsSync(filePath)) {
const fileName = path.basename(filePath)
const destPath = path.join(attachmentsPath, fileName)
fs.copyFileSync(filePath, destPath)
fs.unlinkSync(filePath) // 删除原文件,实现"移动"
movedPaths.push(destPath)
}
}
return movedPaths
}
}

View File

@@ -1,16 +0,0 @@
import { FileMetadata, PreprocessProvider } from '@types'
import BasePreprocessProvider from './BasePreprocessProvider'
export default class DefaultPreprocessProvider extends BasePreprocessProvider {
constructor(provider: PreprocessProvider) {
super(provider)
}
public parseFile(): Promise<{ processedFile: FileMetadata }> {
throw new Error('Method not implemented.')
}
public checkQuota(): Promise<number> {
throw new Error('Method not implemented.')
}
}

View File

@@ -1,329 +0,0 @@
import fs from 'node:fs'
import path from 'node:path'
import { FileMetadata, PreprocessProvider } from '@types'
import AdmZip from 'adm-zip'
import axios, { AxiosRequestConfig } from 'axios'
import Logger from 'electron-log'
import BasePreprocessProvider from './BasePreprocessProvider'
type ApiResponse<T> = {
code: string
data: T
message?: string
}
type PreuploadResponse = {
uid: string
url: string
}
type StatusResponse = {
status: string
progress: number
}
type ParsedFileResponse = {
status: string
url: string
}
export default class Doc2xPreprocessProvider extends BasePreprocessProvider {
constructor(provider: PreprocessProvider) {
super(provider)
}
private async validateFile(filePath: string): Promise<void> {
const pdfBuffer = await fs.promises.readFile(filePath)
const doc = await this.readPdf(new Uint8Array(pdfBuffer))
// 文件页数小于1000页
if (doc.numPages >= 1000) {
throw new Error(`PDF page count (${doc.numPages}) exceeds the limit of 1000 pages`)
}
// 文件大小小于300MB
if (pdfBuffer.length >= 300 * 1024 * 1024) {
const fileSizeMB = Math.round(pdfBuffer.length / (1024 * 1024))
throw new Error(`PDF file size (${fileSizeMB}MB) exceeds the limit of 300MB`)
}
}
public async parseFile(sourceId: string, file: FileMetadata): Promise<{ processedFile: FileMetadata }> {
try {
Logger.info(`Preprocess processing started: ${file.path}`)
// 步骤1: 准备上传
const { uid, url } = await this.preupload()
Logger.info(`Preprocess preupload completed: uid=${uid}`)
await this.validateFile(file.path)
// 步骤2: 上传文件
await this.putFile(file.path, url)
// 步骤3: 等待处理完成
await this.waitForProcessing(sourceId, uid)
Logger.info(`Preprocess parsing completed successfully for: ${file.path}`)
// 步骤4: 导出文件
const { path: outputPath } = await this.exportFile(file, uid)
// 步骤5: 创建处理后的文件信息
return {
processedFile: this.createProcessedFileInfo(file, outputPath)
}
} catch (error) {
Logger.error(
`Preprocess processing failed for ${file.path}: ${error instanceof Error ? error.message : String(error)}`
)
throw error
}
}
private createProcessedFileInfo(file: FileMetadata, outputPath: string): FileMetadata {
const outputFilePath = `${outputPath}/${file.name.split('.').slice(0, -1).join('.')}.md`
return {
...file,
name: file.name.replace('.pdf', '.md'),
path: outputFilePath,
ext: '.md',
size: fs.statSync(outputFilePath).size
}
}
/**
* 导出文件
* @param file 文件信息
* @param uid 预上传响应的uid
* @returns 导出文件的路径
*/
public async exportFile(file: FileMetadata, uid: string): Promise<{ path: string }> {
Logger.info(`Exporting file: ${file.path}`)
// 步骤1: 转换文件
await this.convertFile(uid, file.path)
Logger.info(`File conversion completed for: ${file.path}`)
// 步骤2: 等待导出并获取URL
const exportUrl = await this.waitForExport(uid)
// 步骤3: 下载并解压文件
return this.downloadFile(exportUrl, file)
}
/**
* 等待处理完成
* @param sourceId 源文件ID
* @param uid 预上传响应的uid
*/
private async waitForProcessing(sourceId: string, uid: string): Promise<void> {
while (true) {
await this.delay(1000)
const { status, progress } = await this.getStatus(uid)
await this.sendPreprocessProgress(sourceId, progress)
Logger.info(`Preprocess processing status: ${status}, progress: ${progress}%`)
if (status === 'success') {
return
} else if (status === 'failed') {
throw new Error('Preprocess processing failed')
}
}
}
/**
* 等待导出完成
* @param uid 预上传响应的uid
* @returns 导出文件的url
*/
private async waitForExport(uid: string): Promise<string> {
while (true) {
await this.delay(1000)
const { status, url } = await this.getParsedFile(uid)
Logger.info(`Export status: ${status}`)
if (status === 'success' && url) {
return url
} else if (status === 'failed') {
throw new Error('Export failed')
}
}
}
/**
* 预上传文件
* @returns 预上传响应的url和uid
*/
private async preupload(): Promise<PreuploadResponse> {
const config = this.createAuthConfig()
const endpoint = `${this.provider.apiHost}/api/v2/parse/preupload`
try {
const { data } = await axios.post<ApiResponse<PreuploadResponse>>(endpoint, null, config)
if (data.code === 'success' && data.data) {
return data.data
} else {
throw new Error(`API returned error: ${data.message || JSON.stringify(data)}`)
}
} catch (error) {
Logger.error(`Failed to get preupload URL: ${error instanceof Error ? error.message : String(error)}`)
throw new Error('Failed to get preupload URL')
}
}
/**
* 上传文件
* @param filePath 文件路径
* @param url 预上传响应的url
*/
private async putFile(filePath: string, url: string): Promise<void> {
try {
const fileStream = fs.createReadStream(filePath)
const response = await axios.put(url, fileStream)
if (response.status !== 200) {
throw new Error(`HTTP status ${response.status}: ${response.statusText}`)
}
} catch (error) {
Logger.error(`Failed to upload file ${filePath}: ${error instanceof Error ? error.message : String(error)}`)
throw new Error('Failed to upload file')
}
}
private async getStatus(uid: string): Promise<StatusResponse> {
const config = this.createAuthConfig()
const endpoint = `${this.provider.apiHost}/api/v2/parse/status?uid=${uid}`
try {
const response = await axios.get<ApiResponse<StatusResponse>>(endpoint, config)
if (response.data.code === 'success' && response.data.data) {
return response.data.data
} else {
throw new Error(`API returned error: ${response.data.message || JSON.stringify(response.data)}`)
}
} catch (error) {
Logger.error(`Failed to get status for uid ${uid}: ${error instanceof Error ? error.message : String(error)}`)
throw new Error('Failed to get processing status')
}
}
/**
* Preprocess文件
* @param uid 预上传响应的uid
* @param filePath 文件路径
*/
private async convertFile(uid: string, filePath: string): Promise<void> {
const fileName = path.basename(filePath).split('.')[0]
const config = {
...this.createAuthConfig(),
headers: {
...this.createAuthConfig().headers,
'Content-Type': 'application/json'
}
}
const payload = {
uid,
to: 'md',
formula_mode: 'normal',
filename: fileName
}
const endpoint = `${this.provider.apiHost}/api/v2/convert/parse`
try {
const response = await axios.post<ApiResponse<any>>(endpoint, payload, config)
if (response.data.code !== 'success') {
throw new Error(`API returned error: ${response.data.message || JSON.stringify(response.data)}`)
}
} catch (error) {
Logger.error(`Failed to convert file ${filePath}: ${error instanceof Error ? error.message : String(error)}`)
throw new Error('Failed to convert file')
}
}
/**
* 获取解析后的文件信息
* @param uid 预上传响应的uid
* @returns 解析后的文件信息
*/
private async getParsedFile(uid: string): Promise<ParsedFileResponse> {
const config = this.createAuthConfig()
const endpoint = `${this.provider.apiHost}/api/v2/convert/parse/result?uid=${uid}`
try {
const response = await axios.get<ApiResponse<ParsedFileResponse>>(endpoint, config)
if (response.status === 200 && response.data.data) {
return response.data.data
} else {
throw new Error(`HTTP status ${response.status}: ${response.statusText}`)
}
} catch (error) {
Logger.error(
`Failed to get parsed file for uid ${uid}: ${error instanceof Error ? error.message : String(error)}`
)
throw new Error('Failed to get parsed file information')
}
}
/**
* 下载文件
* @param url 导出文件的url
* @param file 文件信息
* @returns 下载文件的路径
*/
private async downloadFile(url: string, file: FileMetadata): Promise<{ path: string }> {
const dirPath = this.storageDir
// 使用统一的存储路径Data/Files/{file.id}/
const extractPath = path.join(dirPath, file.id)
const zipPath = path.join(dirPath, `${file.id}.zip`)
// 确保目录存在
fs.mkdirSync(dirPath, { recursive: true })
fs.mkdirSync(extractPath, { recursive: true })
Logger.info(`Downloading to export path: ${zipPath}`)
try {
// 下载文件
const response = await axios.get(url, { responseType: 'arraybuffer' })
fs.writeFileSync(zipPath, response.data)
// 确保提取目录存在
if (!fs.existsSync(extractPath)) {
fs.mkdirSync(extractPath, { recursive: true })
}
// 解压文件
const zip = new AdmZip(zipPath)
zip.extractAllTo(extractPath, true)
Logger.info(`Extracted files to: ${extractPath}`)
// 删除临时ZIP文件
fs.unlinkSync(zipPath)
return { path: extractPath }
} catch (error) {
Logger.error(`Failed to download and extract file: ${error instanceof Error ? error.message : String(error)}`)
throw new Error('Failed to download and extract file')
}
}
private createAuthConfig(): AxiosRequestConfig {
return {
headers: {
Authorization: `Bearer ${this.provider.apiKey}`
}
}
}
public checkQuota(): Promise<number> {
throw new Error('Method not implemented.')
}
}

View File

@@ -1,399 +0,0 @@
import fs from 'node:fs'
import path from 'node:path'
import { FileMetadata, PreprocessProvider } from '@types'
import AdmZip from 'adm-zip'
import axios from 'axios'
import Logger from 'electron-log'
import BasePreprocessProvider from './BasePreprocessProvider'
type ApiResponse<T> = {
code: number
data: T
msg?: string
trace_id?: string
}
type BatchUploadResponse = {
batch_id: string
file_urls: string[]
}
type ExtractProgress = {
extracted_pages: number
total_pages: number
start_time: string
}
type ExtractFileResult = {
file_name: string
state: 'done' | 'waiting-file' | 'pending' | 'running' | 'converting' | 'failed'
err_msg: string
full_zip_url?: string
extract_progress?: ExtractProgress
}
type ExtractResultResponse = {
batch_id: string
extract_result: ExtractFileResult[]
}
type QuotaResponse = {
code: number
data: {
user_left_quota: number
total_left_quota: number
}
msg?: string
trace_id?: string
}
export default class MineruPreprocessProvider extends BasePreprocessProvider {
constructor(provider: PreprocessProvider, userId?: string) {
super(provider, userId)
// todo免费期结束后删除
this.provider.apiKey = this.provider.apiKey || import.meta.env.MAIN_VITE_MINERU_API_KEY
}
public async parseFile(
sourceId: string,
file: FileMetadata
): Promise<{ processedFile: FileMetadata; quota: number }> {
try {
Logger.info(`MinerU preprocess processing started: ${file.path}`)
await this.validateFile(file.path)
// 1. 获取上传URL并上传文件
const batchId = await this.uploadFile(file)
Logger.info(`MinerU file upload completed: batch_id=${batchId}`)
// 2. 等待处理完成并获取结果
const extractResult = await this.waitForCompletion(sourceId, batchId, file.origin_name)
Logger.info(`MinerU processing completed for batch: ${batchId}`)
// 3. 下载并解压文件
const { path: outputPath } = await this.downloadAndExtractFile(extractResult.full_zip_url!, file)
// 4. check quota
const quota = await this.checkQuota()
// 5. 创建处理后的文件信息
return {
processedFile: this.createProcessedFileInfo(file, outputPath),
quota
}
} catch (error: any) {
Logger.error(`MinerU preprocess processing failed for ${file.path}: ${error.message}`)
throw new Error(error.message)
}
}
public async checkQuota() {
try {
const quota = await fetch(`${this.provider.apiHost}/api/v4/quota`, {
method: 'GET',
headers: {
'Content-Type': 'application/json',
Authorization: `Bearer ${this.provider.apiKey}`,
token: this.userId ?? ''
}
})
if (!quota.ok) {
throw new Error(`HTTP ${quota.status}: ${quota.statusText}`)
}
const response: QuotaResponse = await quota.json()
return response.data.user_left_quota
} catch (error) {
console.error('Error checking quota:', error)
throw error
}
}
private async validateFile(filePath: string): Promise<void> {
const quota = await this.checkQuota()
const pdfBuffer = await fs.promises.readFile(filePath)
const doc = await this.readPdf(new Uint8Array(pdfBuffer))
// 文件页数小于600页
if (doc.numPages >= 600) {
throw new Error(`PDF page count (${doc.numPages}) exceeds the limit of 600 pages`)
}
// 文件大小小于200MB
if (pdfBuffer.length >= 200 * 1024 * 1024) {
const fileSizeMB = Math.round(pdfBuffer.length / (1024 * 1024))
throw new Error(`PDF file size (${fileSizeMB}MB) exceeds the limit of 200MB`)
}
// 检查配额
if (quota <= 0 || quota - doc.numPages <= 0) {
throw new Error('MinerU解析配额不足请申请企业账户或自行部署剩余额度' + quota)
}
}
private createProcessedFileInfo(file: FileMetadata, outputPath: string): FileMetadata {
// 查找解压后的主要文件
let finalPath = ''
let finalName = file.origin_name.replace('.pdf', '.md')
try {
const files = fs.readdirSync(outputPath)
const mdFile = files.find((f) => f.endsWith('.md'))
if (mdFile) {
const originalMdPath = path.join(outputPath, mdFile)
const newMdPath = path.join(outputPath, finalName)
// 重命名文件为原始文件名
try {
fs.renameSync(originalMdPath, newMdPath)
finalPath = newMdPath
Logger.info(`Renamed markdown file from ${mdFile} to ${finalName}`)
} catch (renameError) {
Logger.warn(`Failed to rename file ${mdFile} to ${finalName}: ${renameError}`)
// 如果重命名失败,使用原文件
finalPath = originalMdPath
finalName = mdFile
}
}
} catch (error) {
Logger.warn(`Failed to read output directory ${outputPath}: ${error}`)
finalPath = path.join(outputPath, `${file.id}.md`)
}
return {
...file,
name: finalName,
path: finalPath,
ext: '.md',
size: fs.existsSync(finalPath) ? fs.statSync(finalPath).size : 0
}
}
private async downloadAndExtractFile(zipUrl: string, file: FileMetadata): Promise<{ path: string }> {
const dirPath = this.storageDir
const zipPath = path.join(dirPath, `${file.id}.zip`)
const extractPath = path.join(dirPath, `${file.id}`)
Logger.info(`Downloading MinerU result to: ${zipPath}`)
try {
// 下载ZIP文件
const response = await axios.get(zipUrl, { responseType: 'arraybuffer' })
fs.writeFileSync(zipPath, response.data)
Logger.info(`Downloaded ZIP file: ${zipPath}`)
// 确保提取目录存在
if (!fs.existsSync(extractPath)) {
fs.mkdirSync(extractPath, { recursive: true })
}
// 解压文件
const zip = new AdmZip(zipPath)
zip.extractAllTo(extractPath, true)
Logger.info(`Extracted files to: ${extractPath}`)
// 删除临时ZIP文件
fs.unlinkSync(zipPath)
return { path: extractPath }
} catch (error: any) {
Logger.error(`Failed to download and extract file: ${error.message}`)
throw new Error(error.message)
}
}
private async uploadFile(file: FileMetadata): Promise<string> {
try {
// 步骤1: 获取上传URL
const { batchId, fileUrls } = await this.getBatchUploadUrls(file)
Logger.info(`Got upload URLs for batch: ${batchId}`)
console.log('batchId:', batchId, 'fileurls:', fileUrls)
// 步骤2: 上传文件到获取的URL
await this.putFileToUrl(file.path, fileUrls[0])
Logger.info(`File uploaded successfully: ${file.path}`)
return batchId
} catch (error: any) {
Logger.error(`Failed to upload file ${file.path}: ${error.message}`)
throw new Error(error.message)
}
}
private async getBatchUploadUrls(file: FileMetadata): Promise<{ batchId: string; fileUrls: string[] }> {
const endpoint = `${this.provider.apiHost}/api/v4/file-urls/batch`
const payload = {
language: 'auto',
enable_formula: true,
enable_table: true,
files: [
{
name: file.origin_name,
is_ocr: true,
data_id: file.id
}
]
}
try {
const response = await fetch(endpoint, {
method: 'POST',
headers: {
'Content-Type': 'application/json',
Authorization: `Bearer ${this.provider.apiKey}`,
token: this.userId ?? '',
Accept: '*/*'
},
body: JSON.stringify(payload)
})
if (response.ok) {
const data: ApiResponse<BatchUploadResponse> = await response.json()
if (data.code === 0 && data.data) {
const { batch_id, file_urls } = data.data
return {
batchId: batch_id,
fileUrls: file_urls
}
} else {
throw new Error(`API returned error: ${data.msg || JSON.stringify(data)}`)
}
} else {
throw new Error(`HTTP ${response.status}: ${response.statusText}`)
}
} catch (error: any) {
Logger.error(`Failed to get batch upload URLs: ${error.message}`)
throw new Error(error.message)
}
}
private async putFileToUrl(filePath: string, uploadUrl: string): Promise<void> {
try {
const fileBuffer = await fs.promises.readFile(filePath)
const response = await fetch(uploadUrl, {
method: 'PUT',
body: fileBuffer,
headers: {
'Content-Type': 'application/pdf'
}
// headers: {
// 'Content-Length': fileBuffer.length.toString()
// }
})
if (!response.ok) {
// 克隆 response 以避免消费 body stream
const responseClone = response.clone()
try {
const responseBody = await responseClone.text()
const errorInfo = {
status: response.status,
statusText: response.statusText,
url: response.url,
type: response.type,
redirected: response.redirected,
headers: Object.fromEntries(response.headers.entries()),
body: responseBody
}
console.error('Response details:', errorInfo)
throw new Error(`Upload failed with status ${response.status}: ${responseBody}`)
} catch (parseError) {
throw new Error(`Upload failed with status ${response.status}. Could not parse response body.`)
}
}
Logger.info(`File uploaded successfully to: ${uploadUrl}`)
} catch (error: any) {
Logger.error(`Failed to upload file to URL ${uploadUrl}: ${error}`)
throw new Error(error.message)
}
}
private async getExtractResults(batchId: string): Promise<ExtractResultResponse> {
const endpoint = `${this.provider.apiHost}/api/v4/extract-results/batch/${batchId}`
try {
const response = await fetch(endpoint, {
method: 'GET',
headers: {
'Content-Type': 'application/json',
Authorization: `Bearer ${this.provider.apiKey}`,
token: this.userId ?? ''
}
})
if (response.ok) {
const data: ApiResponse<ExtractResultResponse> = await response.json()
if (data.code === 0 && data.data) {
return data.data
} else {
throw new Error(`API returned error: ${data.msg || JSON.stringify(data)}`)
}
} else {
throw new Error(`HTTP ${response.status}: ${response.statusText}`)
}
} catch (error: any) {
Logger.error(`Failed to get extract results for batch ${batchId}: ${error.message}`)
throw new Error(error.message)
}
}
private async waitForCompletion(
sourceId: string,
batchId: string,
fileName: string,
maxRetries: number = 60,
intervalMs: number = 5000
): Promise<ExtractFileResult> {
let retries = 0
while (retries < maxRetries) {
try {
const result = await this.getExtractResults(batchId)
// 查找对应文件的处理结果
const fileResult = result.extract_result.find((item) => item.file_name === fileName)
if (!fileResult) {
throw new Error(`File ${fileName} not found in batch results`)
}
// 检查处理状态
if (fileResult.state === 'done' && fileResult.full_zip_url) {
Logger.info(`Processing completed for file: ${fileName}`)
return fileResult
} else if (fileResult.state === 'failed') {
throw new Error(`Processing failed for file: ${fileName}, error: ${fileResult.err_msg}`)
} else if (fileResult.state === 'running') {
// 发送进度更新
if (fileResult.extract_progress) {
const progress = Math.round(
(fileResult.extract_progress.extracted_pages / fileResult.extract_progress.total_pages) * 100
)
await this.sendPreprocessProgress(sourceId, progress)
Logger.info(`File ${fileName} processing progress: ${progress}%`)
} else {
// 如果没有具体进度信息,发送一个通用进度
await this.sendPreprocessProgress(sourceId, 50)
Logger.info(`File ${fileName} is still processing...`)
}
}
} catch (error) {
Logger.warn(`Failed to check status for batch ${batchId}, retry ${retries + 1}/${maxRetries}`)
if (retries === maxRetries - 1) {
throw error
}
}
retries++
await new Promise((resolve) => setTimeout(resolve, intervalMs))
}
throw new Error(`Processing timeout for batch: ${batchId}`)
}
}

View File

@@ -1,187 +0,0 @@
import fs from 'node:fs'
import { MistralClientManager } from '@main/services/MistralClientManager'
import { MistralService } from '@main/services/remotefile/MistralService'
import { Mistral } from '@mistralai/mistralai'
import { DocumentURLChunk } from '@mistralai/mistralai/models/components/documenturlchunk'
import { ImageURLChunk } from '@mistralai/mistralai/models/components/imageurlchunk'
import { OCRResponse } from '@mistralai/mistralai/models/components/ocrresponse'
import { FileMetadata, FileTypes, PreprocessProvider, Provider } from '@types'
import Logger from 'electron-log'
import path from 'path'
import BasePreprocessProvider from './BasePreprocessProvider'
type PreuploadResponse = DocumentURLChunk | ImageURLChunk
export default class MistralPreprocessProvider extends BasePreprocessProvider {
private sdk: Mistral
private fileService: MistralService
constructor(provider: PreprocessProvider) {
super(provider)
const clientManager = MistralClientManager.getInstance()
const aiProvider: Provider = {
id: provider.id,
type: 'mistral',
name: provider.name,
apiKey: provider.apiKey!,
apiHost: provider.apiHost!,
models: []
}
clientManager.initializeClient(aiProvider)
this.sdk = clientManager.getClient()
this.fileService = new MistralService(aiProvider)
}
private async preupload(file: FileMetadata): Promise<PreuploadResponse> {
let document: PreuploadResponse
Logger.info(`preprocess preupload started for local file: ${file.path}`)
if (file.ext.toLowerCase() === '.pdf') {
const uploadResponse = await this.fileService.uploadFile(file)
if (uploadResponse.status === 'failed') {
Logger.error('File upload failed:', uploadResponse)
throw new Error('Failed to upload file: ' + uploadResponse.displayName)
}
await this.sendPreprocessProgress(file.id, 15)
const fileUrl = await this.sdk.files.getSignedUrl({
fileId: uploadResponse.fileId
})
Logger.info('Got signed URL:', fileUrl)
await this.sendPreprocessProgress(file.id, 20)
document = {
type: 'document_url',
documentUrl: fileUrl.url
}
} else {
const base64Image = Buffer.from(fs.readFileSync(file.path)).toString('base64')
document = {
type: 'image_url',
imageUrl: `data:image/png;base64,${base64Image}`
}
}
if (!document) {
throw new Error('Unsupported file type')
}
return document
}
public async parseFile(sourceId: string, file: FileMetadata): Promise<{ processedFile: FileMetadata }> {
try {
const document = await this.preupload(file)
const result = await this.sdk.ocr.process({
model: this.provider.model!,
document: document,
includeImageBase64: true
})
if (result) {
await this.sendPreprocessProgress(sourceId, 100)
const processedFile = this.convertFile(result, file)
return {
processedFile
}
} else {
throw new Error('preprocess processing failed: OCR response is empty')
}
} catch (error) {
throw new Error('preprocess processing failed: ' + error)
}
}
private convertFile(result: OCRResponse, file: FileMetadata): FileMetadata {
// 使用统一的存储路径Data/Files/{file.id}/
const conversionId = file.id
const outputPath = path.join(this.storageDir, file.id)
// const outputPath = this.storageDir
const outputFileName = path.basename(file.path, path.extname(file.path))
fs.mkdirSync(outputPath, { recursive: true })
const markdownParts: string[] = []
let counter = 0
// Process each page
result.pages.forEach((page) => {
let pageMarkdown = page.markdown
// Process images from this page
page.images.forEach((image) => {
if (image.imageBase64) {
let imageFormat = 'jpeg' // default format
let imageBase64Data = image.imageBase64
// Check for data URL prefix more efficiently
const prefixEnd = image.imageBase64.indexOf(';base64,')
if (prefixEnd > 0) {
const prefix = image.imageBase64.substring(0, prefixEnd)
const formatIndex = prefix.indexOf('image/')
if (formatIndex >= 0) {
imageFormat = prefix.substring(formatIndex + 6)
}
imageBase64Data = image.imageBase64.substring(prefixEnd + 8)
}
const imageFileName = `img-${counter}.${imageFormat}`
const imagePath = path.join(outputPath, imageFileName)
// Save image file
try {
fs.writeFileSync(imagePath, Buffer.from(imageBase64Data, 'base64'))
// Update image reference in markdown
// Use relative path for better portability
const relativeImagePath = `./${imageFileName}`
// Find the start and end of the image markdown
const imgStart = pageMarkdown.indexOf(image.imageBase64)
if (imgStart >= 0) {
// Find the markdown image syntax around this base64
const mdStart = pageMarkdown.lastIndexOf('![', imgStart)
const mdEnd = pageMarkdown.indexOf(')', imgStart)
if (mdStart >= 0 && mdEnd >= 0) {
// Replace just this specific image reference
pageMarkdown =
pageMarkdown.substring(0, mdStart) +
`![Image ${counter}](${relativeImagePath})` +
pageMarkdown.substring(mdEnd + 1)
}
}
counter++
} catch (error) {
Logger.error(`Failed to save image ${imageFileName}:`, error)
}
}
})
markdownParts.push(pageMarkdown)
})
// Combine all markdown content with double newlines for readability
const combinedMarkdown = markdownParts.join('\n\n')
// Write the markdown content to a file
const mdFileName = `${outputFileName}.md`
const mdFilePath = path.join(outputPath, mdFileName)
fs.writeFileSync(mdFilePath, combinedMarkdown)
return {
id: conversionId,
name: file.name.replace(/\.[^/.]+$/, '.md'),
origin_name: file.origin_name,
path: mdFilePath,
created_at: new Date().toISOString(),
type: FileTypes.DOCUMENT,
ext: '.md',
size: fs.statSync(mdFilePath).size,
count: 1
} as FileMetadata
}
public checkQuota(): Promise<number> {
throw new Error('Method not implemented.')
}
}

View File

@@ -1,30 +0,0 @@
import { FileMetadata, PreprocessProvider as Provider } from '@types'
import BasePreprocessProvider from './BasePreprocessProvider'
import PreprocessProviderFactory from './PreprocessProviderFactory'
export default class PreprocessProvider {
private sdk: BasePreprocessProvider
constructor(provider: Provider, userId?: string) {
this.sdk = PreprocessProviderFactory.create(provider, userId)
}
public async parseFile(
sourceId: string,
file: FileMetadata
): Promise<{ processedFile: FileMetadata; quota?: number }> {
return this.sdk.parseFile(sourceId, file)
}
public async checkQuota(): Promise<number> {
return this.sdk.checkQuota()
}
/**
* 检查文件是否已经被预处理过
* @param file 文件信息
* @returns 如果已处理返回处理后的文件信息否则返回null
*/
public async checkIfAlreadyProcessed(file: FileMetadata): Promise<FileMetadata | null> {
return this.sdk.checkIfAlreadyProcessed(file)
}
}

View File

@@ -1,21 +0,0 @@
import { PreprocessProvider } from '@types'
import BasePreprocessProvider from './BasePreprocessProvider'
import DefaultPreprocessProvider from './DefaultPreprocessProvider'
import Doc2xPreprocessProvider from './Doc2xPreprocessProvider'
import MineruPreprocessProvider from './MineruPreprocessProvider'
import MistralPreprocessProvider from './MistralPreprocessProvider'
export default class PreprocessProviderFactory {
static create(provider: PreprocessProvider, userId?: string): BasePreprocessProvider {
switch (provider.id) {
case 'doc2x':
return new Doc2xPreprocessProvider(provider)
case 'mistral':
return new MistralPreprocessProvider(provider)
case 'mineru':
return new MineruPreprocessProvider(provider, userId)
default:
return new DefaultPreprocessProvider(provider)
}
}
}

View File

@@ -17,7 +17,7 @@ export default abstract class BaseReranker {
* Get Rerank Request Url
*/
protected getRerankUrl() {
if (this.base.rerankModelProvider === 'bailian') {
if (this.base.rerankModelProvider === 'dashscope') {
return 'https://dashscope.aliyuncs.com/api/v1/services/rerank/text-rerank/text-rerank'
}
@@ -50,7 +50,7 @@ export default abstract class BaseReranker {
documents,
top_k: topN
}
} else if (provider === 'bailian') {
} else if (provider === 'dashscope') {
return {
model: this.base.rerankModel,
input: {
@@ -82,11 +82,11 @@ export default abstract class BaseReranker {
*/
protected extractRerankResult(data: any) {
const provider = this.base.rerankModelProvider
if (provider === 'bailian') {
if (provider === 'dashscope') {
return data.output.results
} else if (provider === 'voyageai') {
return data.data
} else if (provider?.includes('tei')) {
} else if (provider === 'mis-tei') {
return data.map((item: any) => {
return {
index: item.index,

View File

@@ -1,12 +1,11 @@
import { isWin } from '@main/constant'
import { locales } from '@main/utils/locales'
import { generateUserAgent } from '@main/utils/systemInfo'
import { FeedUrl, UpgradeChannel } from '@shared/config/constant'
import { FeedUrl } from '@shared/config/constant'
import { IpcChannel } from '@shared/IpcChannel'
import { CancellationToken, UpdateInfo } from 'builder-util-runtime'
import { UpdateInfo } from 'builder-util-runtime'
import { app, BrowserWindow, dialog } from 'electron'
import logger from 'electron-log'
import { AppUpdater as _AppUpdater, autoUpdater, NsisUpdater, UpdateCheckResult } from 'electron-updater'
import { AppUpdater as _AppUpdater, autoUpdater, NsisUpdater } from 'electron-updater'
import path from 'path'
import icon from '../../../build/icon.png?asset'
@@ -15,8 +14,6 @@ import { configManager } from './ConfigManager'
export default class AppUpdater {
autoUpdater: _AppUpdater = autoUpdater
private releaseInfo: UpdateInfo | undefined
private cancellationToken: CancellationToken = new CancellationToken()
private updateCheckResult: UpdateCheckResult | null = null
constructor(mainWindow: BrowserWindow) {
logger.transports.file.level = 'info'
@@ -25,11 +22,9 @@ export default class AppUpdater {
autoUpdater.forceDevUpdateConfig = !app.isPackaged
autoUpdater.autoDownload = configManager.getAutoUpdate()
autoUpdater.autoInstallOnAppQuit = configManager.getAutoUpdate()
autoUpdater.requestHeaders = {
...autoUpdater.requestHeaders,
'User-Agent': generateUserAgent()
}
autoUpdater.setFeedURL(configManager.getFeedUrl())
// 检测下载错误
autoUpdater.on('error', (error) => {
// 简单记录错误信息和时间戳
logger.error('更新异常', {
@@ -69,35 +64,6 @@ export default class AppUpdater {
this.autoUpdater = autoUpdater
}
private async _getPreReleaseVersionFromGithub(channel: UpgradeChannel) {
try {
logger.info('get pre release version from github', channel)
const responses = await fetch('https://api.github.com/repos/CherryHQ/cherry-studio/releases?per_page=8', {
headers: {
Accept: 'application/vnd.github+json',
'X-GitHub-Api-Version': '2022-11-28',
'Accept-Language': 'en-US,en;q=0.9'
}
})
const data = (await responses.json()) as GithubReleaseInfo[]
const release: GithubReleaseInfo | undefined = data.find((item: GithubReleaseInfo) => {
return item.prerelease && item.tag_name.includes(`-${channel}.`)
})
logger.info('release info', release)
if (!release) {
return null
}
logger.info('release info', release.tag_name)
return `https://github.com/CherryHQ/cherry-studio/releases/download/${release.tag_name}`
} catch (error) {
logger.error('Failed to get latest not draft version from github:', error)
return null
}
}
private async _getIpCountry() {
try {
// add timeout using AbortController
@@ -127,72 +93,9 @@ export default class AppUpdater {
autoUpdater.autoInstallOnAppQuit = isActive
}
private _getChannelByVersion(version: string) {
if (version.includes(`-${UpgradeChannel.BETA}.`)) {
return UpgradeChannel.BETA
}
if (version.includes(`-${UpgradeChannel.RC}.`)) {
return UpgradeChannel.RC
}
return UpgradeChannel.LATEST
}
private _getTestChannel() {
const currentChannel = this._getChannelByVersion(app.getVersion())
const savedChannel = configManager.getTestChannel()
if (currentChannel === UpgradeChannel.LATEST) {
return savedChannel || UpgradeChannel.RC
}
if (savedChannel === currentChannel) {
return savedChannel
}
// if the upgrade channel is not equal to the current channel, use the latest channel
return UpgradeChannel.LATEST
}
private async _setFeedUrl() {
const testPlan = configManager.getTestPlan()
if (testPlan) {
const channel = this._getTestChannel()
if (channel === UpgradeChannel.LATEST) {
this.autoUpdater.channel = UpgradeChannel.LATEST
this.autoUpdater.setFeedURL(FeedUrl.GITHUB_LATEST)
return
}
const preReleaseUrl = await this._getPreReleaseVersionFromGithub(channel)
if (preReleaseUrl) {
this.autoUpdater.setFeedURL(preReleaseUrl)
this.autoUpdater.channel = channel
return
}
// if no prerelease url, use lowest prerelease version to avoid error
this.autoUpdater.setFeedURL(FeedUrl.PRERELEASE_LOWEST)
this.autoUpdater.channel = UpgradeChannel.LATEST
return
}
this.autoUpdater.channel = UpgradeChannel.LATEST
this.autoUpdater.setFeedURL(FeedUrl.PRODUCTION)
const ipCountry = await this._getIpCountry()
logger.info('ipCountry', ipCountry)
if (ipCountry.toLowerCase() !== 'cn') {
this.autoUpdater.setFeedURL(FeedUrl.GITHUB_LATEST)
}
}
public cancelDownload() {
this.cancellationToken.cancel()
this.cancellationToken = new CancellationToken()
if (this.autoUpdater.autoDownload) {
this.updateCheckResult?.cancellationToken?.cancel()
}
public setFeedUrl(feedUrl: FeedUrl) {
autoUpdater.setFeedURL(feedUrl)
configManager.setFeedUrl(feedUrl)
}
public async checkForUpdates() {
@@ -203,26 +106,23 @@ export default class AppUpdater {
}
}
await this._setFeedUrl()
// disable downgrade after change the channel
this.autoUpdater.allowDowngrade = false
// github and gitcode don't support multiple range download
this.autoUpdater.disableDifferentialDownload = true
const ipCountry = await this._getIpCountry()
logger.info('ipCountry', ipCountry)
if (ipCountry !== 'CN') {
this.autoUpdater.setFeedURL(FeedUrl.EARLY_ACCESS)
}
try {
this.updateCheckResult = await this.autoUpdater.checkForUpdates()
if (this.updateCheckResult?.isUpdateAvailable && !this.autoUpdater.autoDownload) {
const update = await this.autoUpdater.checkForUpdates()
if (update?.isUpdateAvailable && !this.autoUpdater.autoDownload) {
// 如果 autoDownload 为 false则需要再调用下面的函数触发下
// do not use await, because it will block the return of this function
logger.info('downloadUpdate manual by check for updates', this.cancellationToken)
this.autoUpdater.downloadUpdate(this.cancellationToken)
this.autoUpdater.downloadUpdate()
}
return {
currentVersion: this.autoUpdater.currentVersion,
updateInfo: this.updateCheckResult?.updateInfo
updateInfo: update?.updateInfo
}
} catch (error) {
logger.error('Failed to check for update:', error)
@@ -278,11 +178,7 @@ export default class AppUpdater {
return releaseNotes.map((note) => note.note).join('\n')
}
}
interface GithubReleaseInfo {
draft: boolean
prerelease: boolean
tag_name: string
}
interface ReleaseNoteInfo {
readonly version: string
readonly note: string | null

View File

@@ -9,7 +9,6 @@ import StreamZip from 'node-stream-zip'
import * as path from 'path'
import { CreateDirectoryOptions, FileStat } from 'webdav'
import { getDataPath } from '../utils'
import WebDav from './WebDav'
import { windowService } from './WindowService'
@@ -254,7 +253,7 @@ class BackupManager {
Logger.log('[backup] step 3: restore Data directory')
// 恢复 Data 目录
const sourcePath = path.join(this.tempDir, 'Data')
const destPath = getDataPath()
const destPath = path.join(app.getPath('userData'), 'Data')
const dataExists = await fs.pathExists(sourcePath)
const dataFiles = dataExists ? await fs.readdir(sourcePath) : []

View File

@@ -1,4 +1,4 @@
import { defaultLanguage, UpgradeChannel, ZOOM_SHORTCUTS } from '@shared/config/constant'
import { defaultLanguage, FeedUrl, ZOOM_SHORTCUTS } from '@shared/config/constant'
import { LanguageVarious, Shortcut, ThemeMode } from '@types'
import { app } from 'electron'
import Store from 'electron-store'
@@ -16,8 +16,7 @@ export enum ConfigKeys {
ClickTrayToShowQuickAssistant = 'clickTrayToShowQuickAssistant',
EnableQuickAssistant = 'enableQuickAssistant',
AutoUpdate = 'autoUpdate',
TestPlan = 'testPlan',
TestChannel = 'testChannel',
FeedUrl = 'feedUrl',
EnableDataCollection = 'enableDataCollection',
SelectionAssistantEnabled = 'selectionAssistantEnabled',
SelectionAssistantTriggerMode = 'selectionAssistantTriggerMode',
@@ -143,20 +142,12 @@ export class ConfigManager {
this.set(ConfigKeys.AutoUpdate, value)
}
getTestPlan(): boolean {
return this.get<boolean>(ConfigKeys.TestPlan, false)
getFeedUrl(): string {
return this.get<string>(ConfigKeys.FeedUrl, FeedUrl.PRODUCTION)
}
setTestPlan(value: boolean) {
this.set(ConfigKeys.TestPlan, value)
}
getTestChannel(): UpgradeChannel {
return this.get<UpgradeChannel>(ConfigKeys.TestChannel)
}
setTestChannel(value: UpgradeChannel) {
this.set(ConfigKeys.TestChannel, value)
setFeedUrl(value: FeedUrl) {
this.set(ConfigKeys.FeedUrl, value)
}
getEnableDataCollection(): boolean {

View File

@@ -4,29 +4,18 @@ import { locales } from '../utils/locales'
import { configManager } from './ConfigManager'
class ContextMenu {
public contextMenu(w: Electron.WebContents) {
w.on('context-menu', (_event, properties) => {
public contextMenu(w: Electron.BrowserWindow) {
w.webContents.on('context-menu', (_event, properties) => {
const template: MenuItemConstructorOptions[] = this.createEditMenuItems(properties)
const filtered = template.filter((item) => item.visible !== false)
if (filtered.length > 0) {
let template = [...filtered, ...this.createInspectMenuItems(w)]
const dictionarySuggestions = this.createDictionarySuggestions(properties, w)
if (dictionarySuggestions.length > 0) {
template = [
...dictionarySuggestions,
{ type: 'separator' },
this.createSpellCheckMenuItem(properties, w),
{ type: 'separator' },
...template
]
}
const menu = Menu.buildFromTemplate(template)
const menu = Menu.buildFromTemplate([...filtered, ...this.createInspectMenuItems(w)])
menu.popup()
}
})
}
private createInspectMenuItems(w: Electron.WebContents): MenuItemConstructorOptions[] {
private createInspectMenuItems(w: Electron.BrowserWindow): MenuItemConstructorOptions[] {
const locale = locales[configManager.getLanguage()]
const { common } = locale.translation
const template: MenuItemConstructorOptions[] = [
@@ -34,7 +23,7 @@ class ContextMenu {
id: 'inspect',
label: common.inspect,
click: () => {
w.toggleDevTools()
w.webContents.toggleDevTools()
},
enabled: true
}
@@ -83,53 +72,6 @@ class ContextMenu {
return template
}
private createSpellCheckMenuItem(
properties: Electron.ContextMenuParams,
w: Electron.WebContents
): MenuItemConstructorOptions {
const hasText = properties.selectionText.length > 0
return {
id: 'learnSpelling',
label: '&Learn Spelling',
visible: Boolean(properties.isEditable && hasText && properties.misspelledWord),
click: () => {
w.session.addWordToSpellCheckerDictionary(properties.misspelledWord)
}
}
}
private createDictionarySuggestions(
properties: Electron.ContextMenuParams,
w: Electron.WebContents
): MenuItemConstructorOptions[] {
const hasText = properties.selectionText.length > 0
if (!hasText || !properties.misspelledWord) {
return []
}
if (properties.dictionarySuggestions.length === 0) {
return [
{
id: 'dictionarySuggestions',
label: 'No Guesses Found',
visible: true,
enabled: false
}
]
}
return properties.dictionarySuggestions.map((suggestion) => ({
id: 'dictionarySuggestions',
label: suggestion,
visible: Boolean(properties.isEditable && hasText && properties.misspelledWord),
click: (menuItem: Electron.MenuItem) => {
w.replaceMisspelling(menuItem.label)
}
}))
}
}
export const contextMenu = new ContextMenu()

View File

@@ -1,6 +1,6 @@
import { getFilesDir, getFileType, getTempDir } from '@main/utils/file'
import { documentExts, imageExts, MB } from '@shared/config/constant'
import { FileMetadata } from '@types'
import { FileType } from '@types'
import * as crypto from 'crypto'
import {
dialog,
@@ -19,7 +19,6 @@ import { getDocument } from 'officeparser/pdfjs-dist-build/pdf.js'
import * as path from 'path'
import { chdir } from 'process'
import { v4 as uuidv4 } from 'uuid'
import WordExtractor from 'word-extractor'
class FileStorage {
private storageDir = getFilesDir()
@@ -53,9 +52,8 @@ class FileStorage {
})
}
findDuplicateFile = async (filePath: string): Promise<FileMetadata | null> => {
findDuplicateFile = async (filePath: string): Promise<FileType | null> => {
const stats = fs.statSync(filePath)
console.log('stats', stats, filePath)
const fileSize = stats.size
const files = await fs.promises.readdir(this.storageDir)
@@ -93,7 +91,7 @@ class FileStorage {
public selectFile = async (
_: Electron.IpcMainInvokeEvent,
options?: OpenDialogOptions
): Promise<FileMetadata[] | null> => {
): Promise<FileType[] | null> => {
const defaultOptions: OpenDialogOptions = {
properties: ['openFile']
}
@@ -152,7 +150,7 @@ class FileStorage {
}
}
public uploadFile = async (_: Electron.IpcMainInvokeEvent, file: FileMetadata): Promise<FileMetadata> => {
public uploadFile = async (_: Electron.IpcMainInvokeEvent, file: FileType): Promise<FileType> => {
const duplicateFile = await this.findDuplicateFile(file.path)
if (duplicateFile) {
@@ -176,7 +174,7 @@ class FileStorage {
const stats = await fs.promises.stat(destPath)
const fileType = getFileType(ext)
const fileMetadata: FileMetadata = {
const fileMetadata: FileType = {
id: uuid,
origin_name,
name: uuid + ext,
@@ -191,7 +189,7 @@ class FileStorage {
return fileMetadata
}
public getFile = async (_: Electron.IpcMainInvokeEvent, filePath: string): Promise<FileMetadata | null> => {
public getFile = async (_: Electron.IpcMainInvokeEvent, filePath: string): Promise<FileType | null> => {
if (!fs.existsSync(filePath)) {
return null
}
@@ -200,7 +198,7 @@ class FileStorage {
const ext = path.extname(filePath)
const fileType = getFileType(ext)
const fileInfo: FileMetadata = {
const fileInfo: FileType = {
id: uuidv4(),
origin_name: path.basename(filePath),
name: path.basename(filePath),
@@ -216,36 +214,16 @@ class FileStorage {
}
public deleteFile = async (_: Electron.IpcMainInvokeEvent, id: string): Promise<void> => {
if (!fs.existsSync(path.join(this.storageDir, id))) {
return
}
await fs.promises.unlink(path.join(this.storageDir, id))
}
public deleteDir = async (_: Electron.IpcMainInvokeEvent, id: string): Promise<void> => {
if (!fs.existsSync(path.join(this.storageDir, id))) {
return
}
await fs.promises.rm(path.join(this.storageDir, id), { recursive: true })
}
public readFile = async (_: Electron.IpcMainInvokeEvent, id: string): Promise<string> => {
const filePath = path.join(this.storageDir, id)
const fileExtension = path.extname(filePath)
if (documentExts.includes(fileExtension)) {
if (documentExts.includes(path.extname(filePath))) {
const originalCwd = process.cwd()
try {
chdir(this.tempDir)
if (fileExtension === '.doc') {
const extractor = new WordExtractor()
const extracted = await extractor.extract(filePath)
chdir(originalCwd)
return extracted.getBody()
}
const data = await officeParser.parseOfficeAsync(filePath)
chdir(originalCwd)
return data
@@ -263,8 +241,8 @@ class FileStorage {
if (!fs.existsSync(this.tempDir)) {
fs.mkdirSync(this.tempDir, { recursive: true })
}
return path.join(this.tempDir, `temp_file_${uuidv4()}_${fileName}`)
const tempFilePath = path.join(this.tempDir, `temp_file_${uuidv4()}_${fileName}`)
return tempFilePath
}
public writeFile = async (
@@ -291,7 +269,7 @@ class FileStorage {
}
}
public saveBase64Image = async (_: Electron.IpcMainInvokeEvent, base64Data: string): Promise<FileMetadata> => {
public saveBase64Image = async (_: Electron.IpcMainInvokeEvent, base64Data: string): Promise<FileType> => {
try {
if (!base64Data) {
throw new Error('Base64 data is required')
@@ -317,7 +295,7 @@ class FileStorage {
await fs.promises.writeFile(destPath, buffer)
const fileMetadata: FileMetadata = {
const fileMetadata: FileType = {
id: uuid,
origin_name: uuid + ext,
name: uuid + ext,
@@ -374,7 +352,7 @@ class FileStorage {
public open = async (
_: Electron.IpcMainInvokeEvent,
options: OpenDialogOptions
): Promise<{ fileName: string; filePath: string; content?: Buffer; size: number } | null> => {
): Promise<{ fileName: string; filePath: string; content: Buffer } | null> => {
try {
const result: OpenDialogReturnValue = await dialog.showOpenDialog({
title: '打开文件',
@@ -386,16 +364,8 @@ class FileStorage {
if (!result.canceled && result.filePaths.length > 0) {
const filePath = result.filePaths[0]
const fileName = filePath.split('/').pop() || ''
const stats = await fs.promises.stat(filePath)
// If the file is less than 2GB, read the content
if (stats.size < 2 * 1024 * 1024 * 1024) {
const content = await readFile(filePath)
return { fileName, filePath, content, size: stats.size }
}
// For large files, only return file information, do not read content
return { fileName, filePath, size: stats.size }
const content = await readFile(filePath)
return { fileName, filePath, content }
}
return null
@@ -476,7 +446,7 @@ class FileStorage {
_: Electron.IpcMainInvokeEvent,
url: string,
isUseContentType?: boolean
): Promise<FileMetadata> => {
): Promise<FileType> => {
try {
const response = await fetch(url)
if (!response.ok) {
@@ -518,7 +488,7 @@ class FileStorage {
const stats = await fs.promises.stat(destPath)
const fileType = getFileType(ext)
const fileMetadata: FileMetadata = {
const fileMetadata: FileType = {
id: uuid,
origin_name: filename,
name: uuid + ext,

View File

@@ -16,24 +16,21 @@
import * as fs from 'node:fs'
import path from 'node:path'
import { RAGApplication, RAGApplicationBuilder } from '@cherrystudio/embedjs'
import { RAGApplication, RAGApplicationBuilder, TextLoader } from '@cherrystudio/embedjs'
import type { ExtractChunkData } from '@cherrystudio/embedjs-interfaces'
import { LibSqlDb } from '@cherrystudio/embedjs-libsql'
import { SitemapLoader } from '@cherrystudio/embedjs-loader-sitemap'
import { WebLoader } from '@cherrystudio/embedjs-loader-web'
import Embeddings from '@main/knowledage/embeddings/Embeddings'
import { addFileLoader } from '@main/knowledage/loader'
import { NoteLoader } from '@main/knowledage/loader/noteLoader'
import Reranker from '@main/knowledage/reranker/Reranker'
import OcrProvider from '@main/ocr/OcrProvider'
import PreprocessProvider from '@main/preprocess/PreprocessProvider'
import Embeddings from '@main/embeddings/Embeddings'
import { addFileLoader } from '@main/loader'
import Reranker from '@main/reranker/Reranker'
import { windowService } from '@main/services/WindowService'
import { getDataPath } from '@main/utils'
import { getAllFiles } from '@main/utils/file'
import { MB } from '@shared/config/constant'
import type { LoaderReturn } from '@shared/config/types'
import { IpcChannel } from '@shared/IpcChannel'
import { FileMetadata, KnowledgeBaseParams, KnowledgeItem } from '@types'
import { FileType, KnowledgeBaseParams, KnowledgeItem } from '@types'
import { app } from 'electron'
import Logger from 'electron-log'
import { v4 as uuidv4 } from 'uuid'
@@ -41,14 +38,12 @@ export interface KnowledgeBaseAddItemOptions {
base: KnowledgeBaseParams
item: KnowledgeItem
forceReload?: boolean
userId?: string
}
interface KnowledgeBaseAddItemOptionsNonNullableAttribute {
base: KnowledgeBaseParams
item: KnowledgeItem
forceReload: boolean
userId: string
}
interface EvaluateTaskWorkload {
@@ -93,20 +88,14 @@ const loaderTaskIntoOfSet = (loaderTask: LoaderTask): LoaderTaskOfSet => {
}
class KnowledgeService {
private storageDir = path.join(getDataPath(), 'KnowledgeBase')
private storageDir = path.join(app.getPath('userData'), 'Data', 'KnowledgeBase')
// Byte based
private workload = 0
private processingItemCount = 0
private knowledgeItemProcessingQueueMappingPromise: Map<LoaderTaskOfSet, () => void> = new Map()
private static MAXIMUM_WORKLOAD = 80 * MB
private static MAXIMUM_PROCESSING_ITEM_COUNT = 30
private static ERROR_LOADER_RETURN: LoaderReturn = {
entriesAdded: 0,
uniqueId: '',
uniqueIds: [''],
loaderType: '',
status: 'failed'
}
private static ERROR_LOADER_RETURN: LoaderReturn = { entriesAdded: 0, uniqueId: '', uniqueIds: [''], loaderType: '' }
constructor() {
this.initStorageDir()
@@ -154,13 +143,12 @@ class KnowledgeService {
this.getRagApplication(base)
}
public reset = async (_: Electron.IpcMainInvokeEvent, base: KnowledgeBaseParams): Promise<void> => {
public reset = async (_: Electron.IpcMainInvokeEvent, { base }: { base: KnowledgeBaseParams }): Promise<void> => {
const ragApplication = await this.getRagApplication(base)
await ragApplication.reset()
}
public delete = async (_: Electron.IpcMainInvokeEvent, id: string): Promise<void> => {
console.log('id', id)
const dbPath = path.join(this.storageDir, id)
if (fs.existsSync(dbPath)) {
fs.rmSync(dbPath, { recursive: true })
@@ -173,49 +161,28 @@ class KnowledgeService {
this.workload >= KnowledgeService.MAXIMUM_WORKLOAD
)
}
private fileTask(
ragApplication: RAGApplication,
options: KnowledgeBaseAddItemOptionsNonNullableAttribute
): LoaderTask {
const { base, item, forceReload, userId } = options
const file = item.content as FileMetadata
const { base, item, forceReload } = options
const file = item.content as FileType
const loaderTask: LoaderTask = {
loaderTasks: [
{
state: LoaderTaskItemState.PENDING,
task: async () => {
try {
// 添加预处理逻辑
const fileToProcess: FileMetadata = await this.preprocessing(file, base, item, userId)
// 使用处理后的文件进行加载
return addFileLoader(ragApplication, fileToProcess, base, forceReload)
.then((result) => {
loaderTask.loaderDoneReturn = result
return result
})
.catch((e) => {
Logger.error(`Error in addFileLoader for ${file.name}: ${e}`)
const errorResult: LoaderReturn = {
...KnowledgeService.ERROR_LOADER_RETURN,
message: e.message,
messageSource: 'embedding'
}
loaderTask.loaderDoneReturn = errorResult
return errorResult
})
} catch (e: any) {
Logger.error(`Preprocessing failed for ${file.name}: ${e}`)
const errorResult: LoaderReturn = {
...KnowledgeService.ERROR_LOADER_RETURN,
message: e.message,
messageSource: 'preprocess'
}
loaderTask.loaderDoneReturn = errorResult
return errorResult
}
},
task: () =>
addFileLoader(ragApplication, file, base, forceReload)
.then((result) => {
loaderTask.loaderDoneReturn = result
return result
})
.catch((err) => {
Logger.error(err)
return KnowledgeService.ERROR_LOADER_RETURN
}),
evaluateTaskWorkload: { workload: file.size }
}
],
@@ -224,6 +191,7 @@ class KnowledgeService {
return loaderTask
}
private directoryTask(
ragApplication: RAGApplication,
options: KnowledgeBaseAddItemOptionsNonNullableAttribute
@@ -263,11 +231,7 @@ class KnowledgeService {
})
.catch((err) => {
Logger.error(err)
return {
...KnowledgeService.ERROR_LOADER_RETURN,
message: `Failed to add dir loader: ${err.message}`,
messageSource: 'embedding'
}
return KnowledgeService.ERROR_LOADER_RETURN
}),
evaluateTaskWorkload: { workload: file.size }
})
@@ -313,11 +277,7 @@ class KnowledgeService {
})
.catch((err) => {
Logger.error(err)
return {
...KnowledgeService.ERROR_LOADER_RETURN,
message: `Failed to add url loader: ${err.message}`,
messageSource: 'embedding'
}
return KnowledgeService.ERROR_LOADER_RETURN
})
},
evaluateTaskWorkload: { workload: 2 * MB }
@@ -357,11 +317,7 @@ class KnowledgeService {
})
.catch((err) => {
Logger.error(err)
return {
...KnowledgeService.ERROR_LOADER_RETURN,
message: `Failed to add sitemap loader: ${err.message}`,
messageSource: 'embedding'
}
return KnowledgeService.ERROR_LOADER_RETURN
}),
evaluateTaskWorkload: { workload: 20 * MB }
}
@@ -377,7 +333,6 @@ class KnowledgeService {
): LoaderTask {
const { base, item, forceReload } = options
const content = item.content as string
const sourceUrl = (item as any).sourceUrl
const encoder = new TextEncoder()
const contentBytes = encoder.encode(content)
@@ -387,12 +342,7 @@ class KnowledgeService {
state: LoaderTaskItemState.PENDING,
task: () => {
const loaderReturn = ragApplication.addLoader(
new NoteLoader({
text: content,
sourceUrl,
chunkSize: base.chunkSize,
chunkOverlap: base.chunkOverlap
}),
new TextLoader({ text: content, chunkSize: base.chunkSize, chunkOverlap: base.chunkOverlap }),
forceReload
) as Promise<LoaderReturn>
@@ -407,11 +357,7 @@ class KnowledgeService {
})
.catch((err) => {
Logger.error(err)
return {
...KnowledgeService.ERROR_LOADER_RETURN,
message: `Failed to add note loader: ${err.message}`,
messageSource: 'embedding'
}
return KnowledgeService.ERROR_LOADER_RETURN
})
},
evaluateTaskWorkload: { workload: contentBytes.length }
@@ -477,10 +423,10 @@ class KnowledgeService {
})
}
public add = async (_: Electron.IpcMainInvokeEvent, options: KnowledgeBaseAddItemOptions): Promise<LoaderReturn> => {
public add = (_: Electron.IpcMainInvokeEvent, options: KnowledgeBaseAddItemOptions): Promise<LoaderReturn> => {
return new Promise((resolve) => {
const { base, item, forceReload = false, userId = '' } = options
const optionsNonNullableAttribute = { base, item, forceReload, userId }
const { base, item, forceReload = false } = options
const optionsNonNullableAttribute = { base, item, forceReload }
this.getRagApplication(base)
.then((ragApplication) => {
const task = (() => {
@@ -506,20 +452,12 @@ class KnowledgeService {
})
this.processingQueueHandle()
} else {
resolve({
...KnowledgeService.ERROR_LOADER_RETURN,
message: 'Unsupported item type',
messageSource: 'embedding'
})
resolve(KnowledgeService.ERROR_LOADER_RETURN)
}
})
.catch((err) => {
Logger.error(err)
resolve({
...KnowledgeService.ERROR_LOADER_RETURN,
message: `Failed to add item: ${err.message}`,
messageSource: 'embedding'
})
resolve(KnowledgeService.ERROR_LOADER_RETURN)
})
})
}
@@ -552,69 +490,6 @@ class KnowledgeService {
}
return await new Reranker(base).rerank(search, results)
}
public getStorageDir = (): string => {
return this.storageDir
}
private preprocessing = async (
file: FileMetadata,
base: KnowledgeBaseParams,
item: KnowledgeItem,
userId: string
): Promise<FileMetadata> => {
let fileToProcess: FileMetadata = file
if (base.preprocessOrOcrProvider && file.ext.toLowerCase() === '.pdf') {
try {
let provider: PreprocessProvider | OcrProvider
if (base.preprocessOrOcrProvider.type === 'preprocess') {
provider = new PreprocessProvider(base.preprocessOrOcrProvider.provider, userId)
} else {
provider = new OcrProvider(base.preprocessOrOcrProvider.provider)
}
// 首先检查文件是否已经被预处理过
const alreadyProcessed = await provider.checkIfAlreadyProcessed(file)
if (alreadyProcessed) {
Logger.info(`File already preprocess processed, using cached result: ${file.path}`)
return alreadyProcessed
}
// 执行预处理
Logger.info(`Starting preprocess processing for scanned PDF: ${file.path}`)
const { processedFile, quota } = await provider.parseFile(item.id, file)
fileToProcess = processedFile
const mainWindow = windowService.getMainWindow()
mainWindow?.webContents.send('file-preprocess-finished', {
itemId: item.id,
quota: quota
})
} catch (err) {
Logger.error(`Preprocess processing failed: ${err}`)
// 如果预处理失败,使用原始文件
// fileToProcess = file
throw new Error(`Preprocess processing failed: ${err}`)
}
}
return fileToProcess
}
public checkQuota = async (
_: Electron.IpcMainInvokeEvent,
base: KnowledgeBaseParams,
userId: string
): Promise<number> => {
try {
if (base.preprocessOrOcrProvider && base.preprocessOrOcrProvider.type === 'preprocess') {
const provider = new PreprocessProvider(base.preprocessOrOcrProvider.provider, userId)
return await provider.checkQuota()
}
throw new Error('No preprocess provider configured')
} catch (err) {
Logger.error(`Failed to check quota: ${err}`)
throw new Error(`Failed to check quota: ${err}`)
}
}
}
export default new KnowledgeService()

View File

@@ -1,33 +0,0 @@
import { Mistral } from '@mistralai/mistralai'
import { Provider } from '@types'
export class MistralClientManager {
private static instance: MistralClientManager
private client: Mistral | null = null
// eslint-disable-next-line @typescript-eslint/no-empty-function
private constructor() {}
public static getInstance(): MistralClientManager {
if (!MistralClientManager.instance) {
MistralClientManager.instance = new MistralClientManager()
}
return MistralClientManager.instance
}
public initializeClient(provider: Provider): void {
if (!this.client) {
this.client = new Mistral({
apiKey: provider.apiKey,
serverURL: provider.apiHost
})
}
}
public getClient(): Mistral {
if (!this.client) {
throw new Error('Mistral client not initialized. Call initializeClient first.')
}
return this.client
}
}

View File

@@ -19,7 +19,7 @@ export function registerProtocolClient(app: Electron.App) {
}
}
app.setAsDefaultProtocolClient(CHERRY_STUDIO_PROTOCOL)
app.setAsDefaultProtocolClient('cherrystudio')
}
export function handleProtocolUrl(url: string) {

View File

@@ -1,102 +0,0 @@
import { randomUUID } from 'node:crypto'
import { BrowserWindow, ipcMain } from 'electron'
interface PythonExecutionRequest {
id: string
script: string
context: Record<string, any>
timeout: number
}
interface PythonExecutionResponse {
id: string
result?: string
error?: string
}
/**
* Service for executing Python code by communicating with the PyodideService in the renderer process
*/
export class PythonService {
private static instance: PythonService | null = null
private mainWindow: BrowserWindow | null = null
private pendingRequests = new Map<string, { resolve: (value: string) => void; reject: (error: Error) => void }>()
private constructor() {
// Private constructor for singleton pattern
this.setupIpcHandlers()
}
public static getInstance(): PythonService {
if (!PythonService.instance) {
PythonService.instance = new PythonService()
}
return PythonService.instance
}
private setupIpcHandlers() {
// Handle responses from renderer
ipcMain.on('python-execution-response', (_, response: PythonExecutionResponse) => {
const request = this.pendingRequests.get(response.id)
if (request) {
this.pendingRequests.delete(response.id)
if (response.error) {
request.reject(new Error(response.error))
} else {
request.resolve(response.result || '')
}
}
})
}
public setMainWindow(mainWindow: BrowserWindow) {
this.mainWindow = mainWindow
}
/**
* Execute Python code by sending request to renderer PyodideService
*/
public async executeScript(
script: string,
context: Record<string, any> = {},
timeout: number = 60000
): Promise<string> {
if (!this.mainWindow) {
throw new Error('Main window not set in PythonService')
}
return new Promise((resolve, reject) => {
const requestId = randomUUID()
// Store the request
this.pendingRequests.set(requestId, { resolve, reject })
// Set up timeout
const timeoutId = setTimeout(() => {
this.pendingRequests.delete(requestId)
reject(new Error('Python execution timed out'))
}, timeout + 5000) // Add 5s buffer for IPC communication
// Update resolve/reject to clear timeout
const originalResolve = resolve
const originalReject = reject
this.pendingRequests.set(requestId, {
resolve: (value: string) => {
clearTimeout(timeoutId)
originalResolve(value)
},
reject: (error: Error) => {
clearTimeout(timeoutId)
originalReject(error)
}
})
// Send request to renderer
const request: PythonExecutionRequest = { id: requestId, script, context, timeout }
this.mainWindow?.webContents.send('python-execution-request', request)
})
}
}
export const pythonService = PythonService.getInstance()

View File

@@ -1,7 +1,7 @@
import { SELECTION_FINETUNED_LIST, SELECTION_PREDEFINED_BLACKLIST } from '@main/configs/SelectionConfig'
import { isDev, isMac, isWin } from '@main/constant'
import { isDev, isWin } from '@main/constant'
import { IpcChannel } from '@shared/IpcChannel'
import { BrowserWindow, ipcMain, screen, systemPreferences } from 'electron'
import { BrowserWindow, ipcMain, screen } from 'electron'
import Logger from 'electron-log'
import { join } from 'path'
import type {
@@ -16,12 +16,9 @@ import type { ActionItem } from '../../renderer/src/types/selectionTypes'
import { ConfigKeys, configManager } from './ConfigManager'
import storeSyncService from './StoreSyncService'
const isSupportedOS = isWin || isMac
let SelectionHook: SelectionHookConstructor | null = null
try {
//since selection-hook v1.0.0, it supports macOS
if (isSupportedOS) {
if (isWin) {
SelectionHook = require('selection-hook')
}
} catch (error) {
@@ -121,7 +118,7 @@ export class SelectionService {
}
public static getInstance(): SelectionService | null {
if (!isSupportedOS) return null
if (!isWin) return null
if (!SelectionService.instance) {
SelectionService.instance = new SelectionService()
@@ -216,8 +213,6 @@ export class SelectionService {
blacklist: SelectionHook!.FilterMode.EXCLUDE_LIST
}
const predefinedBlacklist = isWin ? SELECTION_PREDEFINED_BLACKLIST.WINDOWS : SELECTION_PREDEFINED_BLACKLIST.MAC
let combinedList: string[] = list
let combinedMode = mode
@@ -226,7 +221,7 @@ export class SelectionService {
switch (mode) {
case 'blacklist':
//combine the predefined blacklist with the user-defined blacklist
combinedList = [...new Set([...list, ...predefinedBlacklist])]
combinedList = [...new Set([...list, ...SELECTION_PREDEFINED_BLACKLIST.WINDOWS])]
break
case 'whitelist':
combinedList = [...list]
@@ -234,7 +229,7 @@ export class SelectionService {
case 'default':
default:
//use the predefined blacklist as the default filter list
combinedList = [...predefinedBlacklist]
combinedList = [...SELECTION_PREDEFINED_BLACKLIST.WINDOWS]
combinedMode = 'blacklist'
break
}
@@ -248,21 +243,14 @@ export class SelectionService {
private setHookFineTunedList() {
if (!this.selectionHook) return
const excludeClipboardCursorDetectList = isWin
? SELECTION_FINETUNED_LIST.EXCLUDE_CLIPBOARD_CURSOR_DETECT.WINDOWS
: SELECTION_FINETUNED_LIST.EXCLUDE_CLIPBOARD_CURSOR_DETECT.MAC
const includeClipboardDelayReadList = isWin
? SELECTION_FINETUNED_LIST.INCLUDE_CLIPBOARD_DELAY_READ.WINDOWS
: SELECTION_FINETUNED_LIST.INCLUDE_CLIPBOARD_DELAY_READ.MAC
this.selectionHook.setFineTunedList(
SelectionHook!.FineTunedListType.EXCLUDE_CLIPBOARD_CURSOR_DETECT,
excludeClipboardCursorDetectList
SELECTION_FINETUNED_LIST.EXCLUDE_CLIPBOARD_CURSOR_DETECT.WINDOWS
)
this.selectionHook.setFineTunedList(
SelectionHook!.FineTunedListType.INCLUDE_CLIPBOARD_DELAY_READ,
includeClipboardDelayReadList
SELECTION_FINETUNED_LIST.INCLUDE_CLIPBOARD_DELAY_READ.WINDOWS
)
}
@@ -271,28 +259,11 @@ export class SelectionService {
* @returns {boolean} Success status of service start
*/
public start(): boolean {
if (!this.selectionHook) {
this.logError(new Error('SelectionService start(): instance is null'))
if (!this.selectionHook || this.started) {
this.logError(new Error('SelectionService start(): instance is null or already started'))
return false
}
if (this.started) {
this.logError(new Error('SelectionService start(): already started'))
return false
}
//On macOS, we need to check if the process is trusted
if (isMac) {
if (!systemPreferences.isTrustedAccessibilityClient(false)) {
this.logError(
new Error(
'SelectionSerice not started: process is not trusted on macOS, please turn on the Accessibility permission'
)
)
return false
}
}
try {
//make sure the toolbar window is ready
this.createToolbarWindow()
@@ -335,7 +306,6 @@ export class SelectionService {
if (!this.selectionHook) return false
this.selectionHook.stop()
this.selectionHook.cleanup() //already remove all listeners
//reset the listener states
@@ -346,7 +316,6 @@ export class SelectionService {
this.toolbarWindow.close()
this.toolbarWindow = null
}
this.closePreloadedActionWindows()
this.started = false
@@ -397,29 +366,21 @@ export class SelectionService {
this.toolbarWindow = new BrowserWindow({
width: toolbarWidth,
height: toolbarHeight,
show: false,
frame: false,
transparent: true,
alwaysOnTop: true,
skipTaskbar: true,
autoHideMenuBar: true,
resizable: false,
minimizable: false,
maximizable: false,
fullscreenable: false, // [macOS] must be false
movable: true,
focusable: false,
hasShadow: false,
thickFrame: false,
roundedCorners: true,
backgroundMaterial: 'none',
// Platform specific settings
// [macOS] DO NOT set type to 'panel', it will not work because it conflicts with other settings
// [macOS] DO NOT set focusable to false, it will make other windows bring to front together
...(isWin ? { type: 'toolbar', focusable: false } : {}),
hiddenInMissionControl: true, // [macOS only]
acceptFirstMouse: true, // [macOS only]
type: 'toolbar',
show: false,
webPreferences: {
preload: join(__dirname, '../preload/index.js'),
contextIsolation: true,
@@ -431,9 +392,7 @@ export class SelectionService {
// Hide when losing focus
this.toolbarWindow.on('blur', () => {
if (this.toolbarWindow!.isVisible()) {
this.hideToolbar()
}
this.hideToolbar()
})
// Clean up when closed
@@ -447,13 +406,6 @@ export class SelectionService {
// Add show/hide event listeners
this.toolbarWindow.on('show', () => {
this.toolbarWindow?.webContents.send(IpcChannel.Selection_ToolbarVisibilityChange, true)
// [macOS] force the toolbar window to be visible on current desktop
// but it will make docker icon flash. And we found that it's not necessary now.
// will remove after testing
// if (isMac) {
// this.toolbarWindow!.setVisibleOnAllWorkspaces(false)
// }
})
this.toolbarWindow.on('hide', () => {
@@ -508,22 +460,11 @@ export class SelectionService {
//set the window to always on top (highest level)
//should set every time the window is shown
this.toolbarWindow!.setAlwaysOnTop(true, 'screen-saver')
// [macOS] force the toolbar window to be visible on current desktop
// but it will make docker icon flash. And we found that it's not necessary now.
// will remove after testing
// if (isMac) {
// this.toolbarWindow!.setVisibleOnAllWorkspaces(true, { visibleOnFullScreen: true })
// }
// [macOS] MUST use `showInactive()` to prevent other windows bring to front together
// [Windows] is OK for both `show()` and `showInactive()` because of `focusable: false`
this.toolbarWindow!.showInactive()
this.toolbarWindow!.show()
/**
* [Windows]
* In Windows 10, setOpacity(1) will make the window completely transparent
* It's a strange behavior, so we don't use it for compatibility
* In Windows 10, setOpacity(1) will make the window completely transparent
* It's a strange behavior, so we don't use it for compatibility
*/
// this.toolbarWindow!.setOpacity(1)
@@ -536,52 +477,10 @@ export class SelectionService {
public hideToolbar(): void {
if (!this.isToolbarAlive()) return
this.stopHideByMouseKeyListener()
// [Windows] just hide the toolbar window is enough
if (!isMac) {
this.toolbarWindow!.hide()
return
}
/************************************************
* [macOS] the following code is only for macOS
*************************************************/
// [macOS] a HACKY way
// make sure other windows do not bring to front when toolbar is hidden
// get all focusable windows and set them to not focusable
const focusableWindows: BrowserWindow[] = []
for (const window of BrowserWindow.getAllWindows()) {
if (!window.isDestroyed() && window.isVisible()) {
if (window.isFocusable()) {
focusableWindows.push(window)
window.setFocusable(false)
}
}
}
// this.toolbarWindow!.setOpacity(0)
this.toolbarWindow!.hide()
// set them back to focusable after 50ms
setTimeout(() => {
for (const window of focusableWindows) {
if (!window.isDestroyed()) {
window.setFocusable(true)
}
}
}, 50)
// [macOS] hacky way
// Because toolbar is not a FOCUSED window, so the hover status will remain when next time show
// so we just send mouseMove event to the toolbar window to make the hover status disappear
this.toolbarWindow!.webContents.sendInputEvent({
type: 'mouseMove',
x: -1,
y: -1
})
return
this.stopHideByMouseKeyListener()
}
/**
@@ -621,71 +520,71 @@ export class SelectionService {
/**
* Calculate optimal toolbar position based on selection context
* Ensures toolbar stays within screen boundaries and follows selection direction
* @param refPoint Reference point for positioning, must be INTEGER
* @param point Reference point for positioning, must be INTEGER
* @param orientation Preferred position relative to reference point
* @returns Calculated screen coordinates for toolbar, INTEGER
*/
private calculateToolbarPosition(refPoint: Point, orientation: RelativeOrientation): Point {
private calculateToolbarPosition(point: Point, orientation: RelativeOrientation): Point {
// Calculate initial position based on the specified anchor
const posPoint: Point = { x: 0, y: 0 }
let posX: number, posY: number
const { toolbarWidth, toolbarHeight } = this.getToolbarRealSize()
switch (orientation) {
case 'topLeft':
posPoint.x = refPoint.x - toolbarWidth
posPoint.y = refPoint.y - toolbarHeight
posX = point.x - toolbarWidth
posY = point.y - toolbarHeight
break
case 'topRight':
posPoint.x = refPoint.x
posPoint.y = refPoint.y - toolbarHeight
posX = point.x
posY = point.y - toolbarHeight
break
case 'topMiddle':
posPoint.x = refPoint.x - toolbarWidth / 2
posPoint.y = refPoint.y - toolbarHeight
posX = point.x - toolbarWidth / 2
posY = point.y - toolbarHeight
break
case 'bottomLeft':
posPoint.x = refPoint.x - toolbarWidth
posPoint.y = refPoint.y
posX = point.x - toolbarWidth
posY = point.y
break
case 'bottomRight':
posPoint.x = refPoint.x
posPoint.y = refPoint.y
posX = point.x
posY = point.y
break
case 'bottomMiddle':
posPoint.x = refPoint.x - toolbarWidth / 2
posPoint.y = refPoint.y
posX = point.x - toolbarWidth / 2
posY = point.y
break
case 'middleLeft':
posPoint.x = refPoint.x - toolbarWidth
posPoint.y = refPoint.y - toolbarHeight / 2
posX = point.x - toolbarWidth
posY = point.y - toolbarHeight / 2
break
case 'middleRight':
posPoint.x = refPoint.x
posPoint.y = refPoint.y - toolbarHeight / 2
posX = point.x
posY = point.y - toolbarHeight / 2
break
case 'center':
posPoint.x = refPoint.x - toolbarWidth / 2
posPoint.y = refPoint.y - toolbarHeight / 2
posX = point.x - toolbarWidth / 2
posY = point.y - toolbarHeight / 2
break
default:
// Default to 'topMiddle' if invalid position
posPoint.x = refPoint.x - toolbarWidth / 2
posPoint.y = refPoint.y - toolbarHeight / 2
posX = point.x - toolbarWidth / 2
posY = point.y - toolbarHeight / 2
}
//use original point to get the display
const display = screen.getDisplayNearestPoint(refPoint)
const display = screen.getDisplayNearestPoint({ x: point.x, y: point.y })
// Ensure toolbar stays within screen boundaries
posPoint.x = Math.round(
Math.max(display.workArea.x, Math.min(posPoint.x, display.workArea.x + display.workArea.width - toolbarWidth))
posX = Math.round(
Math.max(display.workArea.x, Math.min(posX, display.workArea.x + display.workArea.width - toolbarWidth))
)
posPoint.y = Math.round(
Math.max(display.workArea.y, Math.min(posPoint.y, display.workArea.y + display.workArea.height - toolbarHeight))
posY = Math.round(
Math.max(display.workArea.y, Math.min(posY, display.workArea.y + display.workArea.height - toolbarHeight))
)
return posPoint
return { x: posX, y: posY }
}
private isSamePoint(point1: Point, point2: Point): boolean {
@@ -874,11 +773,8 @@ export class SelectionService {
}
if (!isLogical) {
// [macOS] don't need to convert by screenToDipPoint
if (!isMac) {
refPoint = screen.screenToDipPoint(refPoint)
}
//screenToDipPoint can be float, so we need to round it
refPoint = screen.screenToDipPoint(refPoint)
refPoint = { x: Math.round(refPoint.x), y: Math.round(refPoint.y) }
}
@@ -936,8 +832,8 @@ export class SelectionService {
return
}
//data point is physical coordinates, convert to logical coordinates(only for windows/linux)
const mousePoint = isMac ? { x: data.x, y: data.y } : screen.screenToDipPoint({ x: data.x, y: data.y })
//data point is physical coordinates, convert to logical coordinates
const mousePoint = screen.screenToDipPoint({ x: data.x, y: data.y })
const bounds = this.toolbarWindow!.getBounds()
@@ -1070,8 +966,7 @@ export class SelectionService {
frame: false,
transparent: true,
autoHideMenuBar: true,
titleBarStyle: 'hidden', // [macOS]
trafficLightPosition: { x: 12, y: 9 }, // [macOS]
titleBarStyle: 'hidden',
hasShadow: false,
thickFrame: false,
show: false,
@@ -1148,27 +1043,6 @@ export class SelectionService {
if (!actionWindow.isDestroyed()) {
actionWindow.destroy()
}
// [macOS] a HACKY way
// make sure other windows do not bring to front when action window is closed
if (isMac) {
const focusableWindows: BrowserWindow[] = []
for (const window of BrowserWindow.getAllWindows()) {
if (!window.isDestroyed() && window.isVisible()) {
if (window.isFocusable()) {
focusableWindows.push(window)
window.setFocusable(false)
}
}
}
setTimeout(() => {
for (const window of focusableWindows) {
if (!window.isDestroyed()) {
window.setFocusable(true)
}
}
}, 50)
}
})
//remember the action window size
@@ -1214,26 +1088,22 @@ export class SelectionService {
//center way
if (!this.isFollowToolbar || !this.toolbarWindow) {
const display = screen.getDisplayNearestPoint(screen.getCursorScreenPoint())
const workArea = display.workArea
const centerX = workArea.x + (workArea.width - actionWindowWidth) / 2
const centerY = workArea.y + (workArea.height - actionWindowHeight) / 2
actionWindow.setBounds({
width: actionWindowWidth,
height: actionWindowHeight,
x: Math.round(centerX),
y: Math.round(centerY)
})
if (this.isRemeberWinSize) {
actionWindow.setBounds({
width: actionWindowWidth,
height: actionWindowHeight
})
}
actionWindow.show()
this.hideToolbar()
return
}
//follow toolbar
const toolbarBounds = this.toolbarWindow!.getBounds()
const display = screen.getDisplayNearestPoint(screen.getCursorScreenPoint())
const display = screen.getDisplayNearestPoint({ x: toolbarBounds.x, y: toolbarBounds.y })
const workArea = display.workArea
const GAP = 6 // 6px gap from screen edges
@@ -1344,7 +1214,7 @@ export class SelectionService {
selectionService?.hideToolbar()
})
ipcMain.handle(IpcChannel.Selection_WriteToClipboard, (_, text: string): boolean => {
ipcMain.handle(IpcChannel.Selection_WriteToClipboard, (_, text: string) => {
return selectionService?.writeToClipboard(text) ?? false
})
@@ -1421,7 +1291,7 @@ export class SelectionService {
* @returns {boolean} Success status of initialization
*/
export function initSelectionService(): boolean {
if (!isSupportedOS) return false
if (!isWin) return false
configManager.subscribe(ConfigKeys.SelectionAssistantEnabled, (enabled: boolean) => {
//avoid closure

View File

@@ -1,4 +1,4 @@
import { isLinux, isMac, isWin } from '@main/constant'
import { isMac } from '@main/constant'
import { locales } from '@main/utils/locales'
import { app, Menu, MenuItemConstructorOptions, nativeImage, nativeTheme, Tray } from 'electron'
@@ -6,7 +6,6 @@ import icon from '../../../build/tray_icon.png?asset'
import iconDark from '../../../build/tray_icon_dark.png?asset'
import iconLight from '../../../build/tray_icon_light.png?asset'
import { ConfigKeys, configManager } from './ConfigManager'
import selectionService from './SelectionService'
import { windowService } from './WindowService'
export class TrayService {
@@ -30,14 +29,14 @@ export class TrayService {
const iconPath = isMac ? (nativeTheme.shouldUseDarkColors ? iconLight : iconDark) : icon
const tray = new Tray(iconPath)
if (isWin) {
if (process.platform === 'win32') {
tray.setImage(iconPath)
} else if (isMac) {
} else if (process.platform === 'darwin') {
const image = nativeImage.createFromPath(iconPath)
const resizedImage = image.resize({ width: 16, height: 16 })
resizedImage.setTemplateImage(true)
tray.setImage(resizedImage)
} else if (isLinux) {
} else if (process.platform === 'linux') {
const image = nativeImage.createFromPath(iconPath)
const resizedImage = image.resize({ width: 16, height: 16 })
tray.setImage(resizedImage)
@@ -47,7 +46,7 @@ export class TrayService {
this.updateContextMenu()
if (isLinux) {
if (process.platform === 'linux') {
this.tray.setContextMenu(this.contextMenu)
}
@@ -70,29 +69,19 @@ export class TrayService {
private updateContextMenu() {
const locale = locales[configManager.getLanguage()]
const { tray: trayLocale, selection: selectionLocale } = locale.translation
const { tray: trayLocale } = locale.translation
const quickAssistantEnabled = configManager.getEnableQuickAssistant()
const selectionAssistantEnabled = configManager.getSelectionAssistantEnabled()
const enableQuickAssistant = configManager.getEnableQuickAssistant()
const template = [
{
label: trayLocale.show_window,
click: () => windowService.showMainWindow()
},
quickAssistantEnabled && {
enableQuickAssistant && {
label: trayLocale.show_mini_window,
click: () => windowService.showMiniWindow()
},
(isWin || isMac) && {
label: selectionLocale.name + (selectionAssistantEnabled ? ' - On' : ' - Off'),
click: () => {
if (selectionService) {
selectionService.toggleEnabled()
this.updateContextMenu()
}
}
},
{ type: 'separator' },
{
label: trayLocale.quit,
@@ -129,10 +118,6 @@ export class TrayService {
configManager.subscribe(ConfigKeys.EnableQuickAssistant, () => {
this.updateContextMenu()
})
configManager.subscribe(ConfigKeys.SelectionAssistantEnabled, () => {
this.updateContextMenu()
})
}
private quit() {

View File

@@ -56,7 +56,7 @@ export class WindowService {
minHeight: 600,
show: false,
autoHideMenuBar: true,
transparent: false,
transparent: isMac,
vibrancy: 'sidebar',
visualEffectState: 'active',
titleBarStyle: 'hidden',
@@ -95,7 +95,6 @@ export class WindowService {
this.setupMaximize(mainWindow, mainWindowState.isMaximized)
this.setupContextMenu(mainWindow)
this.setupSpellCheck(mainWindow)
this.setupWindowEvents(mainWindow)
this.setupWebContentsHandlers(mainWindow)
this.setupWindowLifecycleEvents(mainWindow)
@@ -103,18 +102,6 @@ export class WindowService {
this.loadMainWindowContent(mainWindow)
}
private setupSpellCheck(mainWindow: BrowserWindow) {
const enableSpellCheck = configManager.get('enableSpellCheck', false)
if (enableSpellCheck) {
try {
const spellCheckLanguages = configManager.get('spellCheckLanguages', []) as string[]
spellCheckLanguages.length > 0 && mainWindow.webContents.session.setSpellCheckerLanguages(spellCheckLanguages)
} catch (error) {
Logger.error('Failed to set spell check languages:', error as Error)
}
}
}
private setupMainWindowMonitor(mainWindow: BrowserWindow) {
mainWindow.webContents.on('render-process-gone', (_, details) => {
Logger.error(`Renderer process crashed with: ${JSON.stringify(details)}`)
@@ -143,10 +130,9 @@ export class WindowService {
}
private setupContextMenu(mainWindow: BrowserWindow) {
contextMenu.contextMenu(mainWindow.webContents)
// setup context menu for all webviews like miniapp
app.on('web-contents-created', (_, webContents) => {
contextMenu.contextMenu(webContents)
contextMenu.contextMenu(mainWindow)
app.on('browser-window-created', (_, win) => {
contextMenu.contextMenu(win)
})
// Dangerous API
@@ -451,7 +437,8 @@ export class WindowService {
preload: join(__dirname, '../preload/index.js'),
sandbox: false,
webSecurity: false,
webviewTag: true
webviewTag: true,
backgroundThrottling: false
}
})

View File

@@ -1,13 +0,0 @@
import { FileListResponse, FileMetadata, FileUploadResponse, Provider } from '@types'
export abstract class BaseFileService {
protected readonly provider: Provider
protected constructor(provider: Provider) {
this.provider = provider
}
abstract uploadFile(file: FileMetadata): Promise<FileUploadResponse>
abstract deleteFile(fileId: string): Promise<void>
abstract listFiles(): Promise<FileListResponse>
abstract retrieveFile(fileId: string): Promise<FileUploadResponse>
}

View File

@@ -1,41 +0,0 @@
import { Provider } from '@types'
import { BaseFileService } from './BaseFileService'
import { GeminiService } from './GeminiService'
import { MistralService } from './MistralService'
export class FileServiceManager {
private static instance: FileServiceManager
private services: Map<string, BaseFileService> = new Map()
// eslint-disable-next-line @typescript-eslint/no-empty-function
private constructor() {}
static getInstance(): FileServiceManager {
if (!this.instance) {
this.instance = new FileServiceManager()
}
return this.instance
}
getService(provider: Provider): BaseFileService {
const type = provider.type
let service = this.services.get(type)
if (!service) {
switch (type) {
case 'gemini':
service = new GeminiService(provider)
break
case 'mistral':
service = new MistralService(provider)
break
default:
throw new Error(`Unsupported service type: ${type}`)
}
this.services.set(type, service)
}
return service
}
}

View File

@@ -1,190 +0,0 @@
import { File, Files, FileState, GoogleGenAI } from '@google/genai'
import { FileListResponse, FileMetadata, FileUploadResponse, Provider } from '@types'
import Logger from 'electron-log'
import { v4 as uuidv4 } from 'uuid'
import { CacheService } from '../CacheService'
import { BaseFileService } from './BaseFileService'
export class GeminiService extends BaseFileService {
private static readonly FILE_LIST_CACHE_KEY = 'gemini_file_list'
private static readonly FILE_CACHE_DURATION = 48 * 60 * 60 * 1000
private static readonly LIST_CACHE_DURATION = 3000
protected readonly fileManager: Files
constructor(provider: Provider) {
super(provider)
this.fileManager = new GoogleGenAI({
vertexai: false,
apiKey: provider.apiKey,
httpOptions: {
baseUrl: provider.apiHost
}
}).files
}
async uploadFile(file: FileMetadata): Promise<FileUploadResponse> {
try {
const uploadResult = await this.fileManager.upload({
file: file.path,
config: {
mimeType: 'application/pdf',
name: file.id,
displayName: file.origin_name
}
})
// 根据文件状态设置响应状态
let status: 'success' | 'processing' | 'failed' | 'unknown'
switch (uploadResult.state) {
case FileState.ACTIVE:
status = 'success'
break
case FileState.PROCESSING:
status = 'processing'
break
case FileState.FAILED:
status = 'failed'
break
default:
status = 'unknown'
}
const response: FileUploadResponse = {
fileId: uploadResult.name || '',
displayName: file.origin_name,
status,
originalFile: {
type: 'gemini',
file: uploadResult
}
}
// 只缓存成功的文件
if (status === 'success') {
const cacheKey = `${GeminiService.FILE_LIST_CACHE_KEY}_${response.fileId}`
CacheService.set<FileUploadResponse>(cacheKey, response, GeminiService.FILE_CACHE_DURATION)
}
return response
} catch (error) {
Logger.error('Error uploading file to Gemini:', error)
return {
fileId: '',
displayName: file.origin_name,
status: 'failed',
originalFile: undefined
}
}
}
async retrieveFile(fileId: string): Promise<FileUploadResponse> {
try {
const cachedResponse = CacheService.get<FileUploadResponse>(`${GeminiService.FILE_LIST_CACHE_KEY}_${fileId}`)
Logger.info('[GeminiService] cachedResponse', cachedResponse)
if (cachedResponse) {
return cachedResponse
}
const files: File[] = []
for await (const f of await this.fileManager.list()) {
files.push(f)
}
Logger.info('[GeminiService] files', files)
const file = files
.filter((file) => file.state === FileState.ACTIVE)
.find((file) => file.name?.substring(6) === fileId) // 去掉 files/ 前缀
Logger.info('[GeminiService] file', file)
if (file) {
return {
fileId: fileId,
displayName: file.displayName || '',
status: 'success',
originalFile: {
type: 'gemini',
file
}
}
}
return {
fileId: fileId,
displayName: '',
status: 'failed',
originalFile: undefined
}
} catch (error) {
Logger.error('Error retrieving file from Gemini:', error)
return {
fileId: fileId,
displayName: '',
status: 'failed',
originalFile: undefined
}
}
}
async listFiles(): Promise<FileListResponse> {
try {
const cachedList = CacheService.get<FileListResponse>(GeminiService.FILE_LIST_CACHE_KEY)
if (cachedList) {
return cachedList
}
const geminiFiles: File[] = []
for await (const f of await this.fileManager.list()) {
geminiFiles.push(f)
}
const fileList: FileListResponse = {
files: geminiFiles
.filter((file) => file.state === FileState.ACTIVE)
.map((file) => {
// 更新单个文件的缓存
const fileResponse: FileUploadResponse = {
fileId: file.name || uuidv4(),
displayName: file.displayName || '',
status: 'success',
originalFile: {
type: 'gemini',
file
}
}
CacheService.set(
`${GeminiService.FILE_LIST_CACHE_KEY}_${file.name}`,
fileResponse,
GeminiService.FILE_CACHE_DURATION
)
return {
id: file.name || uuidv4(),
displayName: file.displayName || '',
size: Number(file.sizeBytes),
status: 'success',
originalFile: {
type: 'gemini',
file
}
}
})
}
// 更新文件列表缓存
CacheService.set(GeminiService.FILE_LIST_CACHE_KEY, fileList, GeminiService.LIST_CACHE_DURATION)
return fileList
} catch (error) {
Logger.error('Error listing files from Gemini:', error)
return { files: [] }
}
}
async deleteFile(fileId: string): Promise<void> {
try {
await this.fileManager.delete({ name: fileId })
Logger.info(`File ${fileId} deleted from Gemini`)
} catch (error) {
Logger.error('Error deleting file from Gemini:', error)
throw error
}
}
}

View File

@@ -1,104 +0,0 @@
import fs from 'node:fs/promises'
import { Mistral } from '@mistralai/mistralai'
import { FileListResponse, FileMetadata, FileUploadResponse, Provider } from '@types'
import Logger from 'electron-log'
import { MistralClientManager } from '../MistralClientManager'
import { BaseFileService } from './BaseFileService'
export class MistralService extends BaseFileService {
private readonly client: Mistral
constructor(provider: Provider) {
super(provider)
const clientManager = MistralClientManager.getInstance()
clientManager.initializeClient(provider)
this.client = clientManager.getClient()
}
async uploadFile(file: FileMetadata): Promise<FileUploadResponse> {
try {
const fileBuffer = await fs.readFile(file.path)
const response = await this.client.files.upload({
file: {
fileName: file.origin_name,
content: new Uint8Array(fileBuffer)
},
purpose: 'ocr'
})
return {
fileId: response.id,
displayName: file.origin_name,
status: 'success',
originalFile: {
type: 'mistral',
file: response
}
}
} catch (error) {
Logger.error('Error uploading file:', error)
return {
fileId: '',
displayName: file.origin_name,
status: 'failed'
}
}
}
async listFiles(): Promise<FileListResponse> {
try {
const response = await this.client.files.list({})
return {
files: response.data.map((file) => ({
id: file.id,
displayName: file.filename || '',
size: file.sizeBytes,
status: 'success', // All listed files are processed,
originalFile: {
type: 'mistral',
file
}
}))
}
} catch (error) {
Logger.error('Error listing files:', error)
return { files: [] }
}
}
async deleteFile(fileId: string): Promise<void> {
try {
await this.client.files.delete({
fileId
})
Logger.info(`File ${fileId} deleted`)
} catch (error) {
Logger.error('Error deleting file:', error)
throw error
}
}
async retrieveFile(fileId: string): Promise<FileUploadResponse> {
try {
const response = await this.client.files.retrieve({
fileId
})
return {
fileId: response.id,
displayName: response.filename || '',
status: 'success' // Retrieved files are always processed
}
} catch (error) {
Logger.error('Error retrieving file:', error)
return {
fileId: fileId,
displayName: '',
status: 'failed',
originalFile: undefined
}
}
}
}

View File

@@ -1,47 +1,37 @@
import { IpcChannel } from '@shared/IpcChannel'
import Logger from 'electron-log'
import { windowService } from '../WindowService'
export async function handleProvidersProtocolUrl(url: URL) {
export function handleProvidersProtocolUrl(url: URL) {
const params = new URLSearchParams(url.search)
switch (url.pathname) {
case '/api-keys': {
// jsonConfig example:
// {
// "id": "tokenflux",
// "baseUrl": "https://tokenflux.ai/v1",
// "apiKey": "sk-xxxx",
// "name": "TokenFlux", // optional
// "type": "openai" // optional
// "apiKey": "sk-xxxx"
// }
// cherrystudio://providers/api-keys?v=1&data={base64Encode(JSON.stringify(jsonConfig))}
// replace + and / to _ and - because + and / are processed by URLSearchParams
const processedSearch = url.search.replaceAll('+', '_').replaceAll('/', '-')
const params = new URLSearchParams(processedSearch)
// cherrystudio://providers/api-keys?data={base64Encode(JSON.stringify(jsonConfig))}
const data = params.get('data')
const mainWindow = windowService.getMainWindow()
const version = params.get('v')
if (version == '1') {
// TODO: handle different version
Logger.info('handleProvidersProtocolUrl', { data, version })
}
// add check there is window.navigate function in mainWindow
if (
mainWindow &&
!mainWindow.isDestroyed() &&
(await mainWindow.webContents.executeJavaScript(`typeof window.navigate === 'function'`))
) {
mainWindow.webContents.executeJavaScript(`window.navigate('/settings/provider?addProviderData=${data}')`)
if (data) {
const stringify = Buffer.from(data, 'base64').toString('utf8')
Logger.info('get api keys from urlschema: ', stringify)
const jsonConfig = JSON.parse(stringify)
Logger.info('get api keys from urlschema: ', jsonConfig)
const mainWindow = windowService.getMainWindow()
if (mainWindow && !mainWindow.isDestroyed()) {
mainWindow.webContents.send(IpcChannel.Provider_AddKey, jsonConfig)
mainWindow.webContents.executeJavaScript(`window.navigate('/settings/provider?id=${jsonConfig.id}')`)
}
} else {
setTimeout(() => {
handleProvidersProtocolUrl(url)
}, 1000)
Logger.error('No data found in URL')
}
break
}
default:
Logger.error(`Unknown MCP protocol URL: ${url}`)
console.error(`Unknown MCP protocol URL: ${url}`)
break
}
}

View File

@@ -92,7 +92,6 @@ describe('file', () => {
it('should return DOCUMENT for document extensions', () => {
expect(getFileType('.pdf')).toBe(FileTypes.DOCUMENT)
expect(getFileType('.pptx')).toBe(FileTypes.DOCUMENT)
expect(getFileType('.doc')).toBe(FileTypes.DOCUMENT)
expect(getFileType('.docx')).toBe(FileTypes.DOCUMENT)
expect(getFileType('.xlsx')).toBe(FileTypes.DOCUMENT)
expect(getFileType('.odt')).toBe(FileTypes.DOCUMENT)

View File

@@ -2,26 +2,12 @@ import * as fs from 'node:fs'
import os from 'node:os'
import path from 'node:path'
import { isLinux, isPortable } from '@main/constant'
import { isMac } from '@main/constant'
import { audioExts, documentExts, imageExts, textExts, videoExts } from '@shared/config/constant'
import { FileMetadata, FileTypes } from '@types'
import { FileType, FileTypes } from '@types'
import { app } from 'electron'
import { v4 as uuidv4 } from 'uuid'
export function initAppDataDir() {
const appDataPath = getAppDataPathFromConfig()
if (appDataPath) {
app.setPath('userData', appDataPath)
return
}
if (isPortable) {
const portableDir = process.env.PORTABLE_EXECUTABLE_DIR
app.setPath('userData', path.join(portableDir || app.getPath('exe'), 'data'))
return
}
}
// 创建文件类型映射表,提高查找效率
const fileTypeMap = new Map<string, FileTypes>()
@@ -37,112 +23,12 @@ function initFileTypeMap() {
// 初始化映射表
initFileTypeMap()
export function hasWritePermission(path: string) {
try {
fs.accessSync(path, fs.constants.W_OK)
return true
} catch (error) {
return false
}
}
function getAppDataPathFromConfig() {
try {
const configPath = path.join(getConfigDir(), 'config.json')
if (!fs.existsSync(configPath)) {
return null
}
const config = JSON.parse(fs.readFileSync(configPath, 'utf-8'))
if (!config.appDataPath) {
return null
}
let executablePath = app.getPath('exe')
if (isLinux && process.env.APPIMAGE) {
// 如果是 AppImage 打包的应用,直接使用 APPIMAGE 环境变量
// 这样可以确保获取到正确的可执行文件路径
executablePath = path.join(path.dirname(process.env.APPIMAGE), 'cherry-studio.appimage')
}
let appDataPath = null
// 兼容旧版本
if (config.appDataPath && typeof config.appDataPath === 'string') {
appDataPath = config.appDataPath
// 将旧版本数据迁移到新版本
appDataPath && updateAppDataConfig(appDataPath)
} else {
appDataPath = config.appDataPath.find(
(item: { executablePath: string }) => item.executablePath === executablePath
)?.dataPath
}
if (appDataPath && fs.existsSync(appDataPath) && hasWritePermission(appDataPath)) {
return appDataPath
}
return null
} catch (error) {
return null
}
}
export function updateAppDataConfig(appDataPath: string) {
const configDir = getConfigDir()
if (!fs.existsSync(configDir)) {
fs.mkdirSync(configDir, { recursive: true })
}
// config.json
// appDataPath: [{ executablePath: string, dataPath: string }]
const configPath = path.join(getConfigDir(), 'config.json')
let executablePath = app.getPath('exe')
if (isLinux && process.env.APPIMAGE) {
executablePath = path.join(path.dirname(process.env.APPIMAGE), 'cherry-studio.appimage')
}
if (!fs.existsSync(configPath)) {
fs.writeFileSync(configPath, JSON.stringify({ appDataPath: [{ executablePath, dataPath: appDataPath }] }, null, 2))
return
}
const config = JSON.parse(fs.readFileSync(configPath, 'utf-8'))
if (!config.appDataPath || (config.appDataPath && typeof config.appDataPath !== 'object')) {
config.appDataPath = []
}
const existingPath = config.appDataPath.find(
(item: { executablePath: string }) => item.executablePath === executablePath
)
if (existingPath) {
existingPath.dataPath = appDataPath
} else {
config.appDataPath.push({ executablePath, dataPath: appDataPath })
}
fs.writeFileSync(configPath, JSON.stringify(config, null, 2))
}
export function getFileType(ext: string): FileTypes {
ext = ext.toLowerCase()
return fileTypeMap.get(ext) || FileTypes.OTHER
}
export function getFileDir(filePath: string) {
return path.dirname(filePath)
}
export function getFileName(filePath: string) {
return path.basename(filePath)
}
export function getFileExt(filePath: string) {
return path.extname(filePath)
}
export function getAllFiles(dirPath: string, arrayOfFiles: FileMetadata[] = []): FileMetadata[] {
export function getAllFiles(dirPath: string, arrayOfFiles: FileType[] = []): FileType[] {
const files = fs.readdirSync(dirPath)
files.forEach((file) => {
@@ -164,7 +50,7 @@ export function getAllFiles(dirPath: string, arrayOfFiles: FileMetadata[] = []):
const name = path.basename(file)
const size = fs.statSync(fullPath).size
const fileItem: FileMetadata = {
const fileItem: FileType = {
id: uuidv4(),
name,
path: fullPath,
@@ -202,3 +88,12 @@ export function getCacheDir() {
export function getAppConfigDir(name: string) {
return path.join(getConfigDir(), name)
}
export function setUserDataDir() {
if (!isMac) {
const dir = path.join(path.dirname(app.getPath('exe')), 'data')
if (fs.existsSync(dir) && fs.statSync(dir).isDirectory()) {
app.setPath('userData', dir)
}
}
}

View File

@@ -49,7 +49,7 @@ export async function getBinaryPath(name?: string): Promise<string> {
const binaryName = await getBinaryName(name)
const binariesDir = path.join(os.homedir(), '.cherrystudio', 'bin')
const binariesDirExists = fs.existsSync(binariesDir)
const binariesDirExists = await fs.existsSync(binariesDir)
return binariesDirExists ? path.join(binariesDir, binaryName) : binaryName
}

View File

@@ -1,92 +0,0 @@
import { app } from 'electron'
import macosRelease from 'macos-release'
import os from 'os'
/**
* System information interface
*/
export interface SystemInfo {
platform: NodeJS.Platform
arch: string
osRelease: string
appVersion: string
osString: string
archString: string
}
/**
* Get basic system constants for quick access
* @returns {Object} Basic system constants
*/
export function getSystemConstants() {
return {
platform: process.platform,
arch: process.arch,
osRelease: os.release(),
appVersion: app.getVersion()
}
}
/**
* Get system information
* @returns {SystemInfo} Complete system information object
*/
export function getSystemInfo(): SystemInfo {
const platform = process.platform
const arch = process.arch
const osRelease = os.release()
const appVersion = app.getVersion()
let osString = ''
switch (platform) {
case 'win32': {
// Get Windows version
const parts = osRelease.split('.')
const buildNumber = parseInt(parts[2], 10)
osString = buildNumber >= 22000 ? 'Windows 11' : 'Windows 10'
break
}
case 'darwin': {
// macOS version handling using macos-release for better accuracy
try {
const macVersionInfo = macosRelease()
const versionString = macVersionInfo.version.replace(/\./g, '_') // 15.0.0 -> 15_0_0
osString = arch === 'arm64' ? `Mac OS X ${versionString}` : `Intel Mac OS X ${versionString}` // Mac OS X 15_0_0
} catch (error) {
// Fallback to original logic if macos-release fails
const macVersion = osRelease.split('.').slice(0, 2).join('_')
osString = arch === 'arm64' ? `Mac OS X ${macVersion}` : `Intel Mac OS X ${macVersion}`
}
break
}
case 'linux': {
osString = `Linux ${arch}`
break
}
default: {
osString = `${platform} ${arch}`
}
}
const archString = arch === 'x64' ? 'x86_64' : arch === 'arm64' ? 'arm64' : arch
return {
platform,
arch,
osRelease,
appVersion,
osString,
archString
}
}
/**
* Generate User-Agent string based on user system data
* @returns {string} Dynamically generated User-Agent string
*/
export function generateUserAgent(): string {
const systemInfo = getSystemInfo()
return `Mozilla/5.0 (${systemInfo.osString}; ${systemInfo.archString}) AppleWebKit/537.36 (KHTML, like Gecko) CherryStudio/${systemInfo.appVersion} Chrome/124.0.0.0 Safari/537.36`
}

View File

@@ -1,19 +1,8 @@
import type { ExtractChunkData } from '@cherrystudio/embedjs-interfaces'
import { electronAPI } from '@electron-toolkit/preload'
import { UpgradeChannel } from '@shared/config/constant'
import { FeedUrl } from '@shared/config/constant'
import { IpcChannel } from '@shared/IpcChannel'
import {
FileListResponse,
FileMetadata,
FileUploadResponse,
KnowledgeBaseParams,
KnowledgeItem,
MCPServer,
Provider,
Shortcut,
ThemeMode,
WebDavConfig
} from '@types'
import { FileType, KnowledgeBaseParams, KnowledgeItem, MCPServer, Shortcut, ThemeMode, WebDavConfig } from '@types'
import { contextBridge, ipcRenderer, OpenDialogOptions, shell, webUtils } from 'electron'
import { Notification } from 'src/renderer/src/types/notification'
import { CreateDirectoryOptions } from 'webdav'
@@ -28,35 +17,18 @@ const api = {
checkForUpdate: () => ipcRenderer.invoke(IpcChannel.App_CheckForUpdate),
showUpdateDialog: () => ipcRenderer.invoke(IpcChannel.App_ShowUpdateDialog),
setLanguage: (lang: string) => ipcRenderer.invoke(IpcChannel.App_SetLanguage, lang),
setEnableSpellCheck: (isEnable: boolean) => ipcRenderer.invoke(IpcChannel.App_SetEnableSpellCheck, isEnable),
setSpellCheckLanguages: (languages: string[]) => ipcRenderer.invoke(IpcChannel.App_SetSpellCheckLanguages, languages),
setLaunchOnBoot: (isActive: boolean) => ipcRenderer.invoke(IpcChannel.App_SetLaunchOnBoot, isActive),
setLaunchToTray: (isActive: boolean) => ipcRenderer.invoke(IpcChannel.App_SetLaunchToTray, isActive),
setTray: (isActive: boolean) => ipcRenderer.invoke(IpcChannel.App_SetTray, isActive),
setTrayOnClose: (isActive: boolean) => ipcRenderer.invoke(IpcChannel.App_SetTrayOnClose, isActive),
setTestPlan: (isActive: boolean) => ipcRenderer.invoke(IpcChannel.App_SetTestPlan, isActive),
setTestChannel: (channel: UpgradeChannel) => ipcRenderer.invoke(IpcChannel.App_SetTestChannel, channel),
setFeedUrl: (feedUrl: FeedUrl) => ipcRenderer.invoke(IpcChannel.App_SetFeedUrl, feedUrl),
setTheme: (theme: ThemeMode) => ipcRenderer.invoke(IpcChannel.App_SetTheme, theme),
handleZoomFactor: (delta: number, reset: boolean = false) =>
ipcRenderer.invoke(IpcChannel.App_HandleZoomFactor, delta, reset),
setAutoUpdate: (isActive: boolean) => ipcRenderer.invoke(IpcChannel.App_SetAutoUpdate, isActive),
select: (options: Electron.OpenDialogOptions) => ipcRenderer.invoke(IpcChannel.App_Select, options),
hasWritePermission: (path: string) => ipcRenderer.invoke(IpcChannel.App_HasWritePermission, path),
setAppDataPath: (path: string) => ipcRenderer.invoke(IpcChannel.App_SetAppDataPath, path),
getDataPathFromArgs: () => ipcRenderer.invoke(IpcChannel.App_GetDataPathFromArgs),
copy: (oldPath: string, newPath: string, occupiedDirs: string[] = []) =>
ipcRenderer.invoke(IpcChannel.App_Copy, oldPath, newPath, occupiedDirs),
setStopQuitApp: (stop: boolean, reason: string) => ipcRenderer.invoke(IpcChannel.App_SetStopQuitApp, stop, reason),
flushAppData: () => ipcRenderer.invoke(IpcChannel.App_FlushAppData),
isNotEmptyDir: (path: string) => ipcRenderer.invoke(IpcChannel.App_IsNotEmptyDir, path),
relaunchApp: (options?: Electron.RelaunchOptions) => ipcRenderer.invoke(IpcChannel.App_RelaunchApp, options),
openWebsite: (url: string) => ipcRenderer.invoke(IpcChannel.Open_Website, url),
getCacheSize: () => ipcRenderer.invoke(IpcChannel.App_GetCacheSize),
clearCache: () => ipcRenderer.invoke(IpcChannel.App_ClearCache),
mac: {
isProcessTrusted: (): Promise<boolean> => ipcRenderer.invoke(IpcChannel.App_MacIsProcessTrusted),
requestProcessTrust: (): Promise<boolean> => ipcRenderer.invoke(IpcChannel.App_MacRequestProcessTrust)
},
notification: {
send: (notification: Notification) => ipcRenderer.invoke(IpcChannel.Notification_Send, notification)
},
@@ -90,25 +62,13 @@ const api = {
},
file: {
select: (options?: OpenDialogOptions) => ipcRenderer.invoke(IpcChannel.File_Select, options),
upload: (file: FileMetadata) => ipcRenderer.invoke(IpcChannel.File_Upload, file),
upload: (file: FileType) => ipcRenderer.invoke(IpcChannel.File_Upload, file),
delete: (fileId: string) => ipcRenderer.invoke(IpcChannel.File_Delete, fileId),
deleteDir: (dirPath: string) => ipcRenderer.invoke(IpcChannel.File_DeleteDir, dirPath),
read: (fileId: string) => ipcRenderer.invoke(IpcChannel.File_Read, fileId),
clear: () => ipcRenderer.invoke(IpcChannel.File_Clear),
get: (filePath: string) => ipcRenderer.invoke(IpcChannel.File_Get, filePath),
/**
* 创建一个空的临时文件
* @param fileName 文件名
* @returns 临时文件路径
*/
createTempFile: (fileName: string): Promise<string> => ipcRenderer.invoke(IpcChannel.File_CreateTempFile, fileName),
/**
* 写入文件
* @param filePath 文件路径
* @param data 数据
*/
create: (fileName: string) => ipcRenderer.invoke(IpcChannel.File_Create, fileName),
write: (filePath: string, data: Uint8Array | string) => ipcRenderer.invoke(IpcChannel.File_Write, filePath, data),
writeWithId: (id: string, content: string) => ipcRenderer.invoke(IpcChannel.File_WriteWithId, id, content),
open: (options?: OpenDialogOptions) => ipcRenderer.invoke(IpcChannel.File_Open, options),
openPath: (path: string) => ipcRenderer.invoke(IpcChannel.File_OpenPath, path),
@@ -116,12 +76,12 @@ const api = {
ipcRenderer.invoke(IpcChannel.File_Save, path, content, options),
selectFolder: () => ipcRenderer.invoke(IpcChannel.File_SelectFolder),
saveImage: (name: string, data: string) => ipcRenderer.invoke(IpcChannel.File_SaveImage, name, data),
binaryImage: (fileId: string) => ipcRenderer.invoke(IpcChannel.File_BinaryImage, fileId),
base64Image: (fileId: string) => ipcRenderer.invoke(IpcChannel.File_Base64Image, fileId),
saveBase64Image: (data: string) => ipcRenderer.invoke(IpcChannel.File_SaveBase64Image, data),
download: (url: string, isUseContentType?: boolean) =>
ipcRenderer.invoke(IpcChannel.File_Download, url, isUseContentType),
copy: (fileId: string, destPath: string) => ipcRenderer.invoke(IpcChannel.File_Copy, fileId, destPath),
binaryImage: (fileId: string) => ipcRenderer.invoke(IpcChannel.File_BinaryImage, fileId),
base64File: (fileId: string) => ipcRenderer.invoke(IpcChannel.File_Base64File, fileId),
pdfInfo: (fileId: string) => ipcRenderer.invoke(IpcChannel.File_GetPdfInfo, fileId),
getPathForFile: (file: File) => webUtils.getPathForFile(file)
@@ -143,38 +103,31 @@ const api = {
add: ({
base,
item,
userId,
forceReload = false
}: {
base: KnowledgeBaseParams
item: KnowledgeItem
userId?: string
forceReload?: boolean
}) => ipcRenderer.invoke(IpcChannel.KnowledgeBase_Add, { base, item, forceReload, userId }),
}) => ipcRenderer.invoke(IpcChannel.KnowledgeBase_Add, { base, item, forceReload }),
remove: ({ uniqueId, uniqueIds, base }: { uniqueId: string; uniqueIds: string[]; base: KnowledgeBaseParams }) =>
ipcRenderer.invoke(IpcChannel.KnowledgeBase_Remove, { uniqueId, uniqueIds, base }),
search: ({ search, base }: { search: string; base: KnowledgeBaseParams }) =>
ipcRenderer.invoke(IpcChannel.KnowledgeBase_Search, { search, base }),
rerank: ({ search, base, results }: { search: string; base: KnowledgeBaseParams; results: ExtractChunkData[] }) =>
ipcRenderer.invoke(IpcChannel.KnowledgeBase_Rerank, { search, base, results }),
checkQuota: ({ base, userId }: { base: KnowledgeBaseParams; userId: string }) =>
ipcRenderer.invoke(IpcChannel.KnowledgeBase_Check_Quota, base, userId)
ipcRenderer.invoke(IpcChannel.KnowledgeBase_Rerank, { search, base, results })
},
window: {
setMinimumSize: (width: number, height: number) =>
ipcRenderer.invoke(IpcChannel.Windows_SetMinimumSize, width, height),
resetMinimumSize: () => ipcRenderer.invoke(IpcChannel.Windows_ResetMinimumSize)
},
fileService: {
upload: (provider: Provider, file: FileMetadata): Promise<FileUploadResponse> =>
ipcRenderer.invoke(IpcChannel.FileService_Upload, provider, file),
list: (provider: Provider): Promise<FileListResponse> => ipcRenderer.invoke(IpcChannel.FileService_List, provider),
delete: (provider: Provider, fileId: string) => ipcRenderer.invoke(IpcChannel.FileService_Delete, provider, fileId),
retrieve: (provider: Provider, fileId: string): Promise<FileUploadResponse> =>
ipcRenderer.invoke(IpcChannel.FileService_Retrieve, provider, fileId)
},
selectionMenu: {
action: (action: string) => ipcRenderer.invoke('selection-menu:action', action)
gemini: {
uploadFile: (file: FileType, { apiKey, baseURL }: { apiKey: string; baseURL: string }) =>
ipcRenderer.invoke(IpcChannel.Gemini_UploadFile, file, { apiKey, baseURL }),
base64File: (file: FileType) => ipcRenderer.invoke(IpcChannel.Gemini_Base64File, file),
retrieveFile: (file: FileType, apiKey: string) => ipcRenderer.invoke(IpcChannel.Gemini_RetrieveFile, file, apiKey),
listFiles: (apiKey: string) => ipcRenderer.invoke(IpcChannel.Gemini_ListFiles, apiKey),
deleteFile: (fileId: string, apiKey: string) => ipcRenderer.invoke(IpcChannel.Gemini_DeleteFile, fileId, apiKey)
},
vertexAI: {
@@ -217,10 +170,6 @@ const api = {
getInstallInfo: () => ipcRenderer.invoke(IpcChannel.Mcp_GetInstallInfo),
checkMcpConnectivity: (server: any) => ipcRenderer.invoke(IpcChannel.Mcp_CheckConnectivity, server)
},
python: {
execute: (script: string, context?: Record<string, any>, timeout?: number) =>
ipcRenderer.invoke(IpcChannel.Python_Execute, script, context, timeout)
},
shell: {
openExternal: (url: string, options?: Electron.OpenExternalOptions) => shell.openExternal(url, options)
},
@@ -263,9 +212,7 @@ const api = {
},
webview: {
setOpenLinkExternal: (webviewId: number, isExternal: boolean) =>
ipcRenderer.invoke(IpcChannel.Webview_SetOpenLinkExternal, webviewId, isExternal),
setSpellCheckEnabled: (webviewId: number, isEnable: boolean) =>
ipcRenderer.invoke(IpcChannel.Webview_SetSpellCheckEnabled, webviewId, isEnable)
ipcRenderer.invoke(IpcChannel.Webview_SetOpenLinkExternal, webviewId, isExternal)
},
storeSync: {
subscribe: () => ipcRenderer.invoke(IpcChannel.StoreSync_Subscribe),

View File

@@ -2,45 +2,42 @@
<html lang="zh-CN">
<head>
<meta charset="UTF-8" />
<meta name="viewport" content="initial-scale=1, width=device-width" />
<meta http-equiv="Content-Security-Policy"
content="default-src 'self'; connect-src blob: *; script-src 'self' 'unsafe-eval' *; worker-src 'self' blob:; style-src 'self' 'unsafe-inline' *; font-src 'self' data: *; img-src 'self' data: file: * blob:; frame-src * file:" />
<title>Cherry Studio Selection Toolbar</title>
<meta charset="UTF-8" />
<meta name="viewport" content="initial-scale=1, width=device-width" />
<meta http-equiv="Content-Security-Policy"
content="default-src 'self'; connect-src blob: *; script-src 'self' 'unsafe-eval' *; worker-src 'self' blob:; style-src 'self' 'unsafe-inline' *; font-src 'self' data: *; img-src 'self' data: file: * blob:; frame-src * file:" />
<title>Cherry Studio Selection Toolbar</title>
</head>
<body>
<div id="root"></div>
<script type="module" src="/src/windows/selection/toolbar/entryPoint.tsx"></script>
<style>
html {
margin: 0 !important;
background-color: transparent !important;
background-image: none !important;
<div id="root"></div>
<script type="module" src="/src/windows/selection/toolbar/entryPoint.tsx"></script>
<style>
html {
margin: 0;
}
}
body {
margin: 0;
padding: 0;
overflow: hidden;
width: 100vw;
height: 100vh;
body {
margin: 0 !important;
padding: 0 !important;
overflow: hidden !important;
width: 100vw !important;
height: 100vh !important;
-webkit-user-select: none;
-moz-user-select: none;
-ms-user-select: none;
user-select: none;
}
-webkit-user-select: none;
-moz-user-select: none;
-ms-user-select: none;
user-select: none;
}
#root {
margin: 0 !important;
padding: 0 !important;
width: max-content !important;
height: fit-content !important;
}
</style>
#root {
margin: 0;
padding: 0;
width: max-content !important;
height: fit-content !important;
}
</style>
</body>
</html>

View File

@@ -20,7 +20,6 @@ import {
SdkToolCall
} from '@renderer/types/sdk'
import { CompletionsContext } from '../middleware/types'
import { AnthropicAPIClient } from './anthropic/AnthropicAPIClient'
import { BaseApiClient } from './BaseApiClient'
import { GeminiAPIClient } from './gemini/GeminiAPIClient'
@@ -42,19 +41,11 @@ export class AihubmixAPIClient extends BaseApiClient {
constructor(provider: Provider) {
super(provider)
const providerExtraHeaders = {
...provider,
extra_headers: {
...provider.extra_headers,
'APP-Code': 'MLTG2087'
}
}
// 初始化各个client - 现在有类型安全
const claudeClient = new AnthropicAPIClient(providerExtraHeaders)
const geminiClient = new GeminiAPIClient({ ...providerExtraHeaders, apiHost: 'https://aihubmix.com/gemini' })
const openaiClient = new OpenAIResponseAPIClient(providerExtraHeaders)
const defaultClient = new OpenAIAPIClient(providerExtraHeaders)
const claudeClient = new AnthropicAPIClient(provider)
const geminiClient = new GeminiAPIClient({ ...provider, apiHost: 'https://aihubmix.com/gemini' })
const openaiClient = new OpenAIResponseAPIClient(provider)
const defaultClient = new OpenAIAPIClient(provider)
this.clients.set('claude', claudeClient)
this.clients.set('gemini', geminiClient)
@@ -66,13 +57,6 @@ export class AihubmixAPIClient extends BaseApiClient {
this.currentClient = this.defaultClient as BaseApiClient
}
override getBaseURL(): string {
if (!this.currentClient) {
return this.provider.apiHost
}
return this.currentClient.getBaseURL()
}
/**
* 类型守卫确保client是BaseApiClient的实例
*/
@@ -179,8 +163,8 @@ export class AihubmixAPIClient extends BaseApiClient {
return this.currentClient.getRequestTransformer()
}
getResponseChunkTransformer(ctx: CompletionsContext): ResponseChunkTransformer<SdkRawChunk> {
return this.currentClient.getResponseChunkTransformer(ctx)
getResponseChunkTransformer(): ResponseChunkTransformer<SdkRawChunk> {
return this.currentClient.getResponseChunkTransformer()
}
convertMcpToolsToSdkTools(mcpTools: MCPTool[]): SdkTool[] {

View File

@@ -7,7 +7,6 @@ import { GeminiAPIClient } from './gemini/GeminiAPIClient'
import { VertexAPIClient } from './gemini/VertexAPIClient'
import { OpenAIAPIClient } from './openai/OpenAIApiClient'
import { OpenAIResponseAPIClient } from './openai/OpenAIResponseAPIClient'
import { PPIOAPIClient } from './ppio/PPIOAPIClient'
/**
* Factory for creating ApiClient instances based on provider configuration
@@ -32,11 +31,6 @@ export class ApiClientFactory {
instance = new AihubmixAPIClient(provider) as BaseApiClient
return instance
}
if (provider.id === 'ppio') {
console.log(`[ApiClientFactory] Creating PPIOAPIClient for provider: ${provider.id}`)
instance = new PPIOAPIClient(provider) as BaseApiClient
return instance
}
// 然后检查标准的provider type
switch (provider.type) {

View File

@@ -37,13 +37,12 @@ import {
} from '@renderer/types/sdk'
import { isJSON, parseJSON } from '@renderer/utils'
import { addAbortController, removeAbortController } from '@renderer/utils/abortController'
import { findFileBlocks, getContentWithTools, getMainTextContent } from '@renderer/utils/messageUtils/find'
import { findFileBlocks, getMainTextContent } from '@renderer/utils/messageUtils/find'
import { defaultTimeout } from '@shared/config/constant'
import Logger from 'electron-log/renderer'
import { isEmpty } from 'lodash'
import { CompletionsContext } from '../middleware/types'
import { ApiClient, RequestTransformer, ResponseChunkTransformer } from './types'
import { ApiClient, RawStreamListener, RequestTransformer, ResponseChunkTransformer } from './types'
/**
* Abstract base class for API clients.
@@ -96,7 +95,7 @@ export abstract class BaseApiClient<
// 在 CoreRequestToSdkParamsMiddleware中使用
abstract getRequestTransformer(): RequestTransformer<TSdkParams, TMessageParam>
// 在RawSdkChunkToGenericChunkMiddleware中使用
abstract getResponseChunkTransformer(ctx: CompletionsContext): ResponseChunkTransformer<TRawChunk>
abstract getResponseChunkTransformer(): ResponseChunkTransformer<TRawChunk>
/**
* 工具转换
@@ -111,7 +110,7 @@ export abstract class BaseApiClient<
abstract buildSdkMessages(
currentReqMessages: TMessageParam[],
output: TRawOutput | string | undefined,
output: TRawOutput | string,
toolResults: TMessageParam[],
toolCalls?: TToolCall[]
): TMessageParam[]
@@ -130,6 +129,17 @@ export abstract class BaseApiClient<
*/
abstract extractMessagesFromSdkPayload(sdkPayload: TSdkParams): TMessageParam[]
/**
* 附加原始流监听器
*/
public attachRawStreamListener<TListener extends RawStreamListener<TRawChunk>>(
rawOutput: TRawOutput,
// eslint-disable-next-line @typescript-eslint/no-unused-vars
_listener: TListener
): TRawOutput {
return rawOutput
}
/**
* 通用函数
**/
@@ -209,8 +219,7 @@ export abstract class BaseApiClient<
}
public async getMessageContent(message: Message): Promise<string> {
const content = getContentWithTools(message)
const content = getMainTextContent(message)
if (isEmpty(content)) {
return ''
}
@@ -274,7 +283,6 @@ export abstract class BaseApiClient<
const webSearch: WebSearchResponse = window.keyv.get(`web-search-${message.id}`)
if (webSearch) {
window.keyv.remove(`web-search-${message.id}`)
return (webSearch.results as WebSearchProviderResponse).results.map(
(result, index) =>
({
@@ -300,7 +308,6 @@ export abstract class BaseApiClient<
const knowledgeReferences: KnowledgeReference[] = window.keyv.get(`knowledge-search-${message.id}`)
if (!isEmpty(knowledgeReferences)) {
window.keyv.remove(`knowledge-search-${message.id}`)
// Logger.log(`Found ${knowledgeReferences.length} knowledge base references in cache for ID: ${message.id}`)
return knowledgeReferences
}

View File

@@ -52,7 +52,7 @@ import {
TextDeltaChunk,
ThinkingDeltaChunk
} from '@renderer/types/chunk'
import { type Message } from '@renderer/types/newMessage'
import type { Message } from '@renderer/types/newMessage'
import {
AnthropicSdkMessageParam,
AnthropicSdkParams,
@@ -66,7 +66,7 @@ import {
mcpToolCallResponseToAnthropicMessage,
mcpToolsToAnthropicTools
} from '@renderer/utils/mcp-tools'
import { findFileBlocks, findImageBlocks } from '@renderer/utils/messageUtils/find'
import { findFileBlocks, findImageBlocks, getMainTextContent } from '@renderer/utils/messageUtils/find'
import { buildSystemPrompt } from '@renderer/utils/prompt'
import { BaseApiClient } from '../BaseApiClient'
@@ -90,12 +90,11 @@ export class AnthropicAPIClient extends BaseApiClient<
return this.sdkInstance
}
this.sdkInstance = new Anthropic({
apiKey: this.apiKey,
apiKey: this.getApiKey(),
baseURL: this.getBaseURL(),
dangerouslyAllowBrowser: true,
defaultHeaders: {
'anthropic-beta': 'output-128k-2025-02-19',
...this.provider.extra_headers
'anthropic-beta': 'output-128k-2025-02-19'
}
})
return this.sdkInstance
@@ -126,7 +125,7 @@ export class AnthropicAPIClient extends BaseApiClient<
// @ts-ignore sdk未提供
override async getEmbeddingDimensions(): Promise<number> {
throw new Error("Anthropic SDK doesn't support getEmbeddingDimensions method.")
return 0
}
override getTemperature(assistant: Assistant, model: Model): number | undefined {
@@ -192,7 +191,7 @@ export class AnthropicAPIClient extends BaseApiClient<
const parts: MessageParam['content'] = [
{
type: 'text',
text: await this.getMessageContent(message)
text: getMainTextContent(message)
}
]
@@ -368,13 +367,12 @@ export class AnthropicAPIClient extends BaseApiClient<
* Anthropic专用的原始流监听器
* 处理MessageStream对象的特定事件
*/
attachRawStreamListener(
override attachRawStreamListener(
rawOutput: AnthropicSdkRawOutput,
listener: RawStreamListener<AnthropicSdkRawChunk>
): AnthropicSdkRawOutput {
console.log(`[AnthropicApiClient] 附加流监听器到原始输出`)
// 专用的Anthropic事件处理
const anthropicListener = listener as AnthropicStreamListener
// 检查是否为MessageStream
if (rawOutput instanceof MessageStream) {
console.log(`[AnthropicApiClient] 检测到 Anthropic MessageStream附加专用监听器`)
@@ -389,6 +387,9 @@ export class AnthropicAPIClient extends BaseApiClient<
})
}
// 专用的Anthropic事件处理
const anthropicListener = listener as AnthropicStreamListener
if (anthropicListener.onContentBlock) {
rawOutput.on('contentBlock', anthropicListener.onContentBlock)
}
@@ -412,10 +413,6 @@ export class AnthropicAPIClient extends BaseApiClient<
return rawOutput
}
if (anthropicListener.onMessage) {
anthropicListener.onMessage(rawOutput)
}
// 对于非MessageStream响应
return rawOutput
}
@@ -493,8 +490,7 @@ export class AnthropicAPIClient extends BaseApiClient<
system: systemMessage ? [systemMessage] : undefined,
thinking: this.getBudgetToken(assistant, model),
tools: tools.length > 0 ? tools : undefined,
// 只在对话场景下应用自定义参数,避免影响翻译、总结等其他业务逻辑
...(coreRequest.callType === 'chat' ? this.getCustomParameters(assistant) : {})
...this.getCustomParameters(assistant)
}
const finalParams: MessageCreateParams = streamOutput
@@ -522,7 +518,6 @@ export class AnthropicAPIClient extends BaseApiClient<
async transform(rawChunk: AnthropicSdkRawChunk, controller: TransformStreamDefaultController<GenericChunk>) {
switch (rawChunk.type) {
case 'message': {
let i = 0
for (const content of rawChunk.content) {
switch (content.type) {
case 'text': {
@@ -533,8 +528,7 @@ export class AnthropicAPIClient extends BaseApiClient<
break
}
case 'tool_use': {
toolCalls[i] = content
i++
toolCalls[0] = content
break
}
case 'thinking': {
@@ -556,22 +550,6 @@ export class AnthropicAPIClient extends BaseApiClient<
}
}
}
if (i > 0) {
controller.enqueue({
type: ChunkType.MCP_TOOL_CREATED,
tool_calls: Object.values(toolCalls)
} as MCPToolCreatedChunk)
}
controller.enqueue({
type: ChunkType.LLM_RESPONSE_COMPLETE,
response: {
usage: {
prompt_tokens: rawChunk.usage.input_tokens || 0,
completion_tokens: rawChunk.usage.output_tokens || 0,
total_tokens: (rawChunk.usage.input_tokens || 0) + (rawChunk.usage.output_tokens || 0)
}
}
})
break
}
case 'content_block_start': {

View File

@@ -1,7 +1,7 @@
import {
Content,
createPartFromUri,
File,
FileState,
FunctionCall,
GenerateContentConfig,
GenerateImagesConfig,
@@ -10,6 +10,7 @@ import {
HarmCategory,
Modality,
Model as GeminiModel,
Pager,
Part,
SafetySetting,
SendMessageParameters,
@@ -21,17 +22,17 @@ import { GenericChunk } from '@renderer/aiCore/middleware/schemas'
import {
findTokenLimit,
GEMINI_FLASH_MODEL_REGEX,
isGeminiReasoningModel,
isGemmaModel,
isSupportedThinkingTokenGeminiModel,
isVisionModel
} from '@renderer/config/models'
import { CacheService } from '@renderer/services/CacheService'
import { estimateTextTokens } from '@renderer/services/TokenService'
import {
Assistant,
EFFORT_RATIO,
FileMetadata,
FileType,
FileTypes,
FileUploadResponse,
GenerateImageParams,
MCPCallToolResponse,
MCPTool,
@@ -59,7 +60,7 @@ import {
} from '@renderer/utils/mcp-tools'
import { findFileBlocks, findImageBlocks, getMainTextContent } from '@renderer/utils/messageUtils/find'
import { buildSystemPrompt } from '@renderer/utils/prompt'
import { defaultTimeout, MB } from '@shared/config/constant'
import { MB } from '@shared/config/constant'
import { BaseApiClient } from '../BaseApiClient'
import { RequestTransformer, ResponseChunkTransformer } from '../types'
@@ -84,7 +85,7 @@ export class GeminiAPIClient extends BaseApiClient<
...rest,
config: {
...rest.config,
abortSignal: options?.signal,
abortSignal: options?.abortSignal,
httpOptions: {
...rest.config?.httpOptions,
timeout: options?.timeout
@@ -117,7 +118,7 @@ export class GeminiAPIClient extends BaseApiClient<
aspectRatio: imageSize,
abortSignal: signal,
httpOptions: {
timeout: defaultTimeout
timeout: 5 * 60 * 1000
}
}
const response = await sdk.models.generateImages({
@@ -146,12 +147,15 @@ export class GeminiAPIClient extends BaseApiClient<
override async getEmbeddingDimensions(model: Model): Promise<number> {
const sdk = await this.getSdkInstance()
const data = await sdk.models.embedContent({
model: model.id,
contents: [{ role: 'user', parts: [{ text: 'hi' }] }]
})
return data.embeddings?.[0]?.values?.length || 0
try {
const data = await sdk.models.embedContent({
model: model.id,
contents: [{ role: 'user', parts: [{ text: 'hi' }] }]
})
return data.embeddings?.[0]?.values?.length || 0
} catch (e) {
return 0
}
}
override async listModels(): Promise<GeminiModel[]> {
@@ -175,10 +179,7 @@ export class GeminiAPIClient extends BaseApiClient<
apiVersion: this.getApiVersion(),
httpOptions: {
baseUrl: this.getBaseURL(),
apiVersion: this.getApiVersion(),
headers: {
...this.provider.extra_headers
}
apiVersion: this.getApiVersion()
}
})
@@ -197,7 +198,7 @@ export class GeminiAPIClient extends BaseApiClient<
* @param file - The file
* @returns The part
*/
private async handlePdfFile(file: FileMetadata): Promise<Part> {
private async handlePdfFile(file: FileType): Promise<Part> {
const smallFileSize = 20 * MB
const isSmallFile = file.size < smallFileSize
@@ -212,17 +213,26 @@ export class GeminiAPIClient extends BaseApiClient<
}
// Retrieve file from Gemini uploaded files
const fileMetadata: FileUploadResponse = await window.api.fileService.retrieve(this.provider, file.id)
const fileMetadata: File | undefined = await this.retrieveFile(file)
if (fileMetadata.status === 'success') {
const remoteFile = fileMetadata.originalFile?.file as File
return createPartFromUri(remoteFile.uri!, remoteFile.mimeType!)
if (fileMetadata) {
return {
fileData: {
fileUri: fileMetadata.uri,
mimeType: fileMetadata.mimeType
} as Part['fileData']
}
}
// If file is not found, upload it to Gemini
const result = await window.api.fileService.upload(this.provider, file)
const remoteFile = result.originalFile?.file as File
return createPartFromUri(remoteFile.uri!, remoteFile.mimeType!)
const result = await this.uploadFile(file)
return {
fileData: {
fileUri: result.uri,
mimeType: result.mimeType
} as Part['fileData']
}
}
/**
@@ -233,7 +243,6 @@ export class GeminiAPIClient extends BaseApiClient<
private async convertMessageToSdkParam(message: Message): Promise<Content> {
const role = message.role === 'user' ? 'user' : 'model'
const parts: Part[] = [{ text: await this.getMessageContent(message) }]
// Add any generated images from previous responses
const imageBlocks = findImageBlocks(message)
for (const imageBlock of imageBlocks) {
@@ -384,32 +393,31 @@ export class GeminiAPIClient extends BaseApiClient<
* @returns The reasoning effort
*/
private getBudgetToken(assistant: Assistant, model: Model) {
if (isSupportedThinkingTokenGeminiModel(model)) {
if (isGeminiReasoningModel(model)) {
const reasoningEffort = assistant?.settings?.reasoning_effort
// 如果thinking_budget是undefined不思考
if (reasoningEffort === undefined) {
return GEMINI_FLASH_MODEL_REGEX.test(model.id)
? {
thinkingConfig: {
thinkingBudget: 0
}
}
: {}
}
if (reasoningEffort === 'auto') {
return {
thinkingConfig: {
includeThoughts: true,
thinkingBudget: -1
includeThoughts: false,
...(GEMINI_FLASH_MODEL_REGEX.test(model.id) ? { thinkingBudget: 0 } : {})
} as ThinkingConfig
}
}
const effortRatio = EFFORT_RATIO[reasoningEffort]
if (effortRatio > 1) {
return {
thinkingConfig: {
includeThoughts: true
}
}
}
const effortRatio = EFFORT_RATIO[reasoningEffort]
const { min, max } = findTokenLimit(model.id) || { min: 0, max: 0 }
// 计算 budgetTokens确保不低于 min
const budget = Math.floor((max - min) * effortRatio + min)
const { max } = findTokenLimit(model.id) || { max: 0 }
const budget = Math.floor(max * effortRatio)
return {
thinkingConfig: {
@@ -458,7 +466,7 @@ export class GeminiAPIClient extends BaseApiClient<
systemInstruction = await buildSystemPrompt(assistant.prompt || '', mcpTools, assistant)
}
let messageContents: Content = { role: 'user', parts: [] } // Initialize messageContents
let messageContents: Content
const history: Content[] = []
// 3. 处理用户消息
if (typeof messages === 'string') {
@@ -467,13 +475,10 @@ export class GeminiAPIClient extends BaseApiClient<
parts: [{ text: messages }]
}
} else {
const userLastMessage = messages.pop()
if (userLastMessage) {
messageContents = await this.convertMessageToSdkParam(userLastMessage)
for (const message of messages) {
history.push(await this.convertMessageToSdkParam(message))
}
messages.push(userLastMessage)
const userLastMessage = messages.pop()!
messageContents = await this.convertMessageToSdkParam(userLastMessage)
for (const message of messages) {
history.push(await this.convertMessageToSdkParam(message))
}
}
@@ -486,10 +491,6 @@ export class GeminiAPIClient extends BaseApiClient<
if (isGemmaModel(model) && assistant.prompt) {
const isFirstMessage = history.length === 0
if (isFirstMessage && messageContents) {
const userMessageText =
messageContents.parts && messageContents.parts.length > 0
? (messageContents.parts[0] as Part).text || ''
: ''
const systemMessage = [
{
text:
@@ -497,7 +498,7 @@ export class GeminiAPIClient extends BaseApiClient<
systemInstruction +
'<end_of_turn>\n' +
'<start_of_turn>user\n' +
userMessageText +
(messageContents?.parts?.[0] as Part).text +
'<end_of_turn>'
}
] as Part[]
@@ -514,7 +515,13 @@ export class GeminiAPIClient extends BaseApiClient<
const newMessageContents =
isRecursiveCall && recursiveSdkMessages && recursiveSdkMessages.length > 0
? recursiveSdkMessages[recursiveSdkMessages.length - 1]
? {
...messageContents,
parts: [
...(messageContents.parts || []),
...(recursiveSdkMessages[recursiveSdkMessages.length - 1].parts || [])
]
}
: messageContents
const generateContentConfig: GenerateContentConfig = {
@@ -526,8 +533,7 @@ export class GeminiAPIClient extends BaseApiClient<
tools: tools,
...(enableGenerateImage ? this.getGenerateImageParameter() : {}),
...this.getBudgetToken(assistant, model),
// 只在对话场景下应用自定义参数,避免影响翻译、总结等其他业务逻辑
...(coreRequest.callType === 'chat' ? this.getCustomParameters(assistant) : {})
...this.getCustomParameters(assistant)
}
const param: GeminiSdkParams = {
@@ -549,7 +555,7 @@ export class GeminiAPIClient extends BaseApiClient<
getResponseChunkTransformer(): ResponseChunkTransformer<GeminiSdkRawChunk> {
return () => ({
async transform(chunk: GeminiSdkRawChunk, controller: TransformStreamDefaultController<GenericChunk>) {
const toolCalls: FunctionCall[] = []
let toolCalls: FunctionCall[] = []
if (chunk.candidates && chunk.candidates.length > 0) {
for (const candidate of chunk.candidates) {
if (candidate.content) {
@@ -577,8 +583,6 @@ export class GeminiAPIClient extends BaseApiClient<
]
}
})
} else if (part.functionCall) {
toolCalls.push(part.functionCall)
}
})
}
@@ -593,6 +597,9 @@ export class GeminiAPIClient extends BaseApiClient<
}
} as LLMWebSearchCompleteChunk)
}
if (chunk.functionCalls) {
toolCalls = toolCalls.concat(chunk.functionCalls)
}
controller.enqueue({
type: ChunkType.LLM_RESPONSE_COMPLETE,
response: {
@@ -678,19 +685,16 @@ export class GeminiAPIClient extends BaseApiClient<
toolCalls: FunctionCall[]
): Content[] {
const parts: Part[] = []
const modelParts: Part[] = []
if (output) {
modelParts.push({
parts.push({
text: output
})
}
toolCalls.forEach((toolCall) => {
modelParts.push({
parts.push({
functionCall: toolCall
})
})
parts.push(
...toolResults
.map((ts) => ts.parts)
@@ -700,21 +704,10 @@ export class GeminiAPIClient extends BaseApiClient<
const userMessage: Content = {
role: 'user',
parts: []
parts: parts
}
if (modelParts.length > 0) {
currentReqMessages.push({
role: 'model',
parts: modelParts
})
}
if (parts.length > 0) {
userMessage.parts?.push(...parts)
currentReqMessages.push(userMessage)
}
return currentReqMessages
return [...currentReqMessages, userMessage]
}
override estimateMessageTokens(message: GeminiSdkMessageParam): number {
@@ -741,27 +734,64 @@ export class GeminiAPIClient extends BaseApiClient<
}
public extractMessagesFromSdkPayload(sdkPayload: GeminiSdkParams): GeminiSdkMessageParam[] {
const messageParam: GeminiSdkMessageParam = {
role: 'user',
parts: []
}
if (Array.isArray(sdkPayload.message)) {
sdkPayload.message.forEach((part) => {
if (typeof part === 'string') {
messageParam.parts?.push({ text: part })
} else if (typeof part === 'object') {
messageParam.parts?.push(part)
}
})
}
return [...(sdkPayload.history || []), messageParam]
return sdkPayload.history || []
}
private async base64File(file: FileMetadata) {
private async uploadFile(file: FileType): Promise<File> {
return await this.sdkInstance!.files.upload({
file: file.path,
config: {
mimeType: 'application/pdf',
name: file.id,
displayName: file.origin_name
}
})
}
private async base64File(file: FileType) {
const { data } = await window.api.file.base64File(file.id + file.ext)
return {
data,
mimeType: 'application/pdf'
}
}
private async retrieveFile(file: FileType): Promise<File | undefined> {
const cachedResponse = CacheService.get<any>('gemini_file_list')
if (cachedResponse) {
return this.processResponse(cachedResponse, file)
}
const response = await this.sdkInstance!.files.list()
CacheService.set('gemini_file_list', response, 3000)
return this.processResponse(response, file)
}
private async processResponse(response: Pager<File>, file: FileType) {
for await (const f of response) {
if (f.state === FileState.ACTIVE) {
if (f.displayName === file.origin_name && Number(f.sizeBytes) === file.size) {
return f
}
}
}
return undefined
}
// @ts-ignore unused
private async listFiles(): Promise<File[]> {
const files: File[] = []
for await (const f of await this.sdkInstance!.files.list()) {
files.push(f)
}
return files
}
// @ts-ignore unused
private async deleteFile(fileId: string) {
await this.sdkInstance!.files.delete({ name: fileId })
}
}

View File

@@ -113,12 +113,6 @@ export class OpenAIAPIClient extends OpenAIBaseClient<
}
if (!reasoningEffort) {
if (model.provider === 'openrouter') {
if (isSupportedThinkingTokenGeminiModel(model) && !GEMINI_FLASH_MODEL_REGEX.test(model.id)) {
return {}
}
return { reasoning: { enabled: false, exclude: true } }
}
if (isSupportedThinkingTokenQwenModel(model)) {
return { enable_thinking: false }
}
@@ -128,16 +122,12 @@ export class OpenAIAPIClient extends OpenAIBaseClient<
}
if (isSupportedThinkingTokenGeminiModel(model)) {
// openrouter没有提供一个不推理的选项先隐藏
if (this.provider.id === 'openrouter') {
return { reasoning: { max_tokens: 0, exclude: true } }
}
if (GEMINI_FLASH_MODEL_REGEX.test(model.id)) {
return {
extra_body: {
google: {
thinking_config: {
thinking_budget: 0
}
}
}
}
return { reasoning_effort: 'none' }
}
return {}
}
@@ -180,37 +170,12 @@ export class OpenAIAPIClient extends OpenAIBaseClient<
}
// OpenAI models
if (isSupportedReasoningEffortOpenAIModel(model)) {
if (isSupportedReasoningEffortOpenAIModel(model) || isSupportedThinkingTokenGeminiModel(model)) {
return {
reasoning_effort: reasoningEffort
}
}
if (isSupportedThinkingTokenGeminiModel(model)) {
if (reasoningEffort === 'auto') {
return {
extra_body: {
google: {
thinking_config: {
thinking_budget: -1,
include_thoughts: true
}
}
}
}
}
return {
extra_body: {
google: {
thinking_config: {
thinking_budget: budgetTokens,
include_thoughts: true
}
}
}
}
}
// Claude models
if (isSupportedThinkingTokenClaudeModel(model)) {
const maxTokens = assistant.settings?.maxTokens
@@ -372,14 +337,10 @@ export class OpenAIAPIClient extends OpenAIBaseClient<
public buildSdkMessages(
currentReqMessages: OpenAISdkMessageParam[],
output: string | undefined,
output: string,
toolResults: OpenAISdkMessageParam[],
toolCalls: OpenAI.Chat.Completions.ChatCompletionMessageToolCall[]
): OpenAISdkMessageParam[] {
if (!output && toolCalls.length === 0) {
return [...currentReqMessages, ...toolResults]
}
const assistantMessage: OpenAISdkMessageParam = {
role: 'assistant',
content: output,
@@ -507,8 +468,7 @@ export class OpenAIAPIClient extends OpenAIBaseClient<
...this.getProviderSpecificParameters(assistant, model),
...this.getReasoningEffort(assistant, model),
...getOpenAIWebSearchParams(model, enableWebSearch),
// 只在对话场景下应用自定义参数,避免影响翻译、总结等其他业务逻辑
...(coreRequest.callType === 'chat' ? this.getCustomParameters(assistant) : {})
...this.getCustomParameters(assistant)
}
// Create the appropriate parameters object based on whether streaming is enabled
@@ -530,7 +490,7 @@ export class OpenAIAPIClient extends OpenAIBaseClient<
}
// 在RawSdkChunkToGenericChunkMiddleware中使用
getResponseChunkTransformer(): ResponseChunkTransformer<OpenAISdkRawChunk> {
getResponseChunkTransformer = (): ResponseChunkTransformer<OpenAISdkRawChunk> => {
let hasBeenCollectedWebSearch = false
const collectWebSearchData = (
chunk: OpenAISdkRawChunk,
@@ -624,67 +584,18 @@ export class OpenAIAPIClient extends OpenAIBaseClient<
return null
}
const toolCalls: OpenAI.Chat.Completions.ChatCompletionMessageToolCall[] = []
let isFinished = false
let lastUsageInfo: any = null
/**
* 统一的完成信号发送逻辑
* - 有 finish_reason 时
* - 无 finish_reason 但是流正常结束时
*/
const emitCompletionSignals = (controller: TransformStreamDefaultController<GenericChunk>) => {
if (isFinished) return
if (toolCalls.length > 0) {
controller.enqueue({
type: ChunkType.MCP_TOOL_CREATED,
tool_calls: toolCalls
})
}
const usage = lastUsageInfo || {
prompt_tokens: 0,
completion_tokens: 0,
total_tokens: 0
}
controller.enqueue({
type: ChunkType.LLM_RESPONSE_COMPLETE,
response: { usage }
})
// 防止重复发送
isFinished = true
}
return (context: ResponseChunkTransformerContext) => ({
async transform(chunk: OpenAISdkRawChunk, controller: TransformStreamDefaultController<GenericChunk>) {
// 持续更新usage信息
if (chunk.usage) {
lastUsageInfo = {
prompt_tokens: chunk.usage.prompt_tokens || 0,
completion_tokens: chunk.usage.completion_tokens || 0,
total_tokens: (chunk.usage.prompt_tokens || 0) + (chunk.usage.completion_tokens || 0)
}
}
// 处理chunk
if ('choices' in chunk && chunk.choices && chunk.choices.length > 0) {
const choice = chunk.choices[0]
if (!choice) return
// 对于流式响应,使用 delta对于非流式响应使用 message
// 然而某些 OpenAI 兼容平台在非流式请求时会错误地返回一个空对象的 delta 字段。
// 如果 delta 为空对象,应当忽略它并回退到 message避免造成内容缺失。
let contentSource: OpenAISdkRawContentSource | null = null
if ('delta' in choice && choice.delta && Object.keys(choice.delta).length > 0) {
contentSource = choice.delta
} else if ('message' in choice) {
contentSource = choice.message
}
// 对于流式响应使用delta对于非流式响应使用message
const contentSource: OpenAISdkRawContentSource | null =
'delta' in choice ? choice.delta : 'message' in choice ? choice.message : null
if (!contentSource) return
@@ -740,6 +651,12 @@ export class OpenAIAPIClient extends OpenAIBaseClient<
// 处理finish_reason发送流结束信号
if ('finish_reason' in choice && choice.finish_reason) {
Logger.debug(`[OpenAIApiClient] Stream finished with reason: ${choice.finish_reason}`)
if (toolCalls.length > 0) {
controller.enqueue({
type: ChunkType.MCP_TOOL_CREATED,
tool_calls: toolCalls
})
}
const webSearchData = collectWebSearchData(chunk, contentSource, context)
if (webSearchData) {
controller.enqueue({
@@ -747,17 +664,18 @@ export class OpenAIAPIClient extends OpenAIBaseClient<
llm_web_search: webSearchData
})
}
emitCompletionSignals(controller)
controller.enqueue({
type: ChunkType.LLM_RESPONSE_COMPLETE,
response: {
usage: {
prompt_tokens: chunk.usage?.prompt_tokens || 0,
completion_tokens: chunk.usage?.completion_tokens || 0,
total_tokens: (chunk.usage?.prompt_tokens || 0) + (chunk.usage?.completion_tokens || 0)
}
}
})
}
}
},
// 流正常结束时,检查是否需要发送完成信号
flush(controller) {
if (isFinished) return
Logger.debug('[OpenAIApiClient] Stream ended without finish_reason, emitting fallback completion signals')
emitCompletionSignals(controller)
}
})
}

View File

@@ -85,13 +85,16 @@ export abstract class OpenAIBaseClient<
override async getEmbeddingDimensions(model: Model): Promise<number> {
const sdk = await this.getSdkInstance()
const data = await sdk.embeddings.create({
model: model.id,
input: model?.provider === 'baidu-cloud' ? ['hi'] : 'hi',
encoding_format: 'float'
})
return data.data[0].embedding.length
try {
const data = await sdk.embeddings.create({
model: model.id,
input: model?.provider === 'baidu-cloud' ? ['hi'] : 'hi',
encoding_format: 'float'
})
return data.data[0].embedding.length
} catch (e) {
return 0
}
}
override async listModels(): Promise<OpenAI.Models.Model[]> {
@@ -135,7 +138,7 @@ export abstract class OpenAIBaseClient<
return this.sdkInstance
}
let apiKeyForSdkInstance = this.apiKey
let apiKeyForSdkInstance = this.provider.apiKey
if (this.provider.id === 'copilot') {
const defaultHeaders = store.getState().copilot.defaultHeaders
@@ -159,7 +162,6 @@ export abstract class OpenAIBaseClient<
baseURL: this.getBaseURL(),
defaultHeaders: {
...this.defaultHeaders(),
...this.provider.extra_headers,
...(this.provider.id === 'copilot' ? { 'editor-version': 'vscode/1.97.2' } : {}),
...(this.provider.id === 'copilot' ? { 'copilot-vision-request': 'true' } : {})
}

View File

@@ -1,5 +1,4 @@
import { GenericChunk } from '@renderer/aiCore/middleware/schemas'
import { CompletionsContext } from '@renderer/aiCore/middleware/types'
import {
isOpenAIChatCompletionOnlyModel,
isSupportedReasoningEffortOpenAIModel,
@@ -7,7 +6,7 @@ import {
} from '@renderer/config/models'
import { estimateTextTokens } from '@renderer/services/TokenService'
import {
FileMetadata,
FileType,
FileTypes,
MCPCallToolResponse,
MCPTool,
@@ -39,7 +38,6 @@ import { buildSystemPrompt } from '@renderer/utils/prompt'
import { MB } from '@shared/config/constant'
import { isEmpty } from 'lodash'
import OpenAI from 'openai'
import { ResponseInput } from 'openai/resources/responses/responses'
import { RequestTransformer, ResponseChunkTransformer } from '../types'
import { OpenAIAPIClient } from './OpenAIApiClient'
@@ -78,11 +76,10 @@ export class OpenAIResponseAPIClient extends OpenAIBaseClient<
return new OpenAI({
dangerouslyAllowBrowser: true,
apiKey: this.apiKey,
apiKey: this.provider.apiKey,
baseURL: this.getBaseURL(),
defaultHeaders: {
...this.defaultHeaders(),
...this.provider.extra_headers
...this.defaultHeaders()
}
})
}
@@ -95,7 +92,7 @@ export class OpenAIResponseAPIClient extends OpenAIBaseClient<
return await sdk.responses.create(payload, options)
}
private async handlePdfFile(file: FileMetadata): Promise<OpenAI.Responses.ResponseInputFile | undefined> {
private async handlePdfFile(file: FileType): Promise<OpenAI.Responses.ResponseInputFile | undefined> {
if (file.size > 32 * MB) return undefined
try {
const pageCount = await window.api.file.pdfInfo(file.id + file.ext)
@@ -228,29 +225,17 @@ export class OpenAIResponseAPIClient extends OpenAIBaseClient<
return
}
private convertResponseToMessageContent(response: OpenAI.Responses.Response): ResponseInput {
const content: OpenAI.Responses.ResponseInput = []
content.push(...response.output)
return content
}
public buildSdkMessages(
currentReqMessages: OpenAIResponseSdkMessageParam[],
output: OpenAI.Responses.Response | undefined,
output: string,
toolResults: OpenAIResponseSdkMessageParam[],
toolCalls: OpenAIResponseSdkToolCall[]
): OpenAIResponseSdkMessageParam[] {
if (!output && toolCalls.length === 0) {
return [...currentReqMessages, ...toolResults]
const assistantMessage: OpenAIResponseSdkMessageParam = {
role: 'assistant',
content: [{ type: 'input_text', text: output }]
}
if (!output) {
return [...currentReqMessages, ...(toolCalls || []), ...(toolResults || [])]
}
const content = this.convertResponseToMessageContent(output)
const newReqMessages = [...currentReqMessages, ...content, ...(toolResults || [])]
const newReqMessages = [...currentReqMessages, assistantMessage, ...(toolCalls || []), ...(toolResults || [])]
return newReqMessages
}
@@ -386,6 +371,10 @@ export class OpenAIResponseAPIClient extends OpenAIBaseClient<
})
}
const toolChoices: OpenAI.Responses.ToolChoiceTypes = {
type: 'web_search_preview'
}
tools = tools.concat(extraTools)
const commonParams = {
model: model.id,
@@ -398,10 +387,10 @@ export class OpenAIResponseAPIClient extends OpenAIBaseClient<
max_output_tokens: maxTokens,
stream: streamOutput,
tools: !isEmpty(tools) ? tools : undefined,
tool_choice: enableWebSearch ? toolChoices : undefined,
service_tier: this.getServiceTier(model),
...(this.getReasoningEffort(assistant, model) as OpenAI.Reasoning),
// 只在对话场景下应用自定义参数,避免影响翻译、总结等其他业务逻辑
...(coreRequest.callType === 'chat' ? this.getCustomParameters(assistant) : {})
...this.getCustomParameters(assistant)
}
const sdkParams: OpenAIResponseSdkParams = streamOutput
? {
@@ -418,18 +407,13 @@ export class OpenAIResponseAPIClient extends OpenAIBaseClient<
}
}
getResponseChunkTransformer(ctx: CompletionsContext): ResponseChunkTransformer<OpenAIResponseSdkRawChunk> {
getResponseChunkTransformer(): ResponseChunkTransformer<OpenAIResponseSdkRawChunk> {
const toolCalls: OpenAIResponseSdkToolCall[] = []
const outputItems: OpenAI.Responses.ResponseOutputItem[] = []
let hasBeenCollectedToolCalls = false
let hasReasoningSummary = false
return () => ({
async transform(chunk: OpenAIResponseSdkRawChunk, controller: TransformStreamDefaultController<GenericChunk>) {
// 处理chunk
if ('output' in chunk) {
if (ctx._internal?.toolProcessingState) {
ctx._internal.toolProcessingState.output = chunk
}
for (const output of chunk.output) {
switch (output.type) {
case 'message':
@@ -471,22 +455,6 @@ export class OpenAIResponseAPIClient extends OpenAIBaseClient<
})
}
}
if (toolCalls.length > 0) {
controller.enqueue({
type: ChunkType.MCP_TOOL_CREATED,
tool_calls: toolCalls
})
}
controller.enqueue({
type: ChunkType.LLM_RESPONSE_COMPLETE,
response: {
usage: {
prompt_tokens: chunk.usage?.input_tokens || 0,
completion_tokens: chunk.usage?.output_tokens || 0,
total_tokens: chunk.usage?.total_tokens || 0
}
}
})
} else {
switch (chunk.type) {
case 'response.output_item.added':
@@ -494,16 +462,6 @@ export class OpenAIResponseAPIClient extends OpenAIBaseClient<
outputItems.push(chunk.item)
}
break
case 'response.reasoning_summary_part.added':
if (hasReasoningSummary) {
const separator = '\n\n'
controller.enqueue({
type: ChunkType.THINKING_DELTA,
text: separator
})
}
hasReasoningSummary = true
break
case 'response.reasoning_summary_text.delta':
controller.enqueue({
type: ChunkType.THINKING_DELTA,
@@ -544,8 +502,7 @@ export class OpenAIResponseAPIClient extends OpenAIBaseClient<
if (outputItem.type === 'function_call') {
toolCalls.push({
...outputItem,
arguments: chunk.arguments,
status: 'completed'
arguments: chunk.arguments
})
}
}
@@ -561,26 +518,15 @@ export class OpenAIResponseAPIClient extends OpenAIBaseClient<
}
})
}
if (toolCalls.length > 0 && !hasBeenCollectedToolCalls) {
if (toolCalls.length > 0) {
controller.enqueue({
type: ChunkType.MCP_TOOL_CREATED,
tool_calls: toolCalls
})
hasBeenCollectedToolCalls = true
}
break
}
case 'response.completed': {
if (ctx._internal?.toolProcessingState) {
ctx._internal.toolProcessingState.output = chunk.response
}
if (toolCalls.length > 0 && !hasBeenCollectedToolCalls) {
controller.enqueue({
type: ChunkType.MCP_TOOL_CREATED,
tool_calls: toolCalls
})
hasBeenCollectedToolCalls = true
}
const completion_tokens = chunk.response.usage?.output_tokens || 0
const total_tokens = chunk.response.usage?.total_tokens || 0
controller.enqueue({

View File

@@ -1,65 +0,0 @@
import { isSupportedModel } from '@renderer/config/models'
import { Provider } from '@renderer/types'
import OpenAI from 'openai'
import { OpenAIAPIClient } from '../openai/OpenAIApiClient'
export class PPIOAPIClient extends OpenAIAPIClient {
constructor(provider: Provider) {
super(provider)
}
override async listModels(): Promise<OpenAI.Models.Model[]> {
try {
const sdk = await this.getSdkInstance()
// PPIO requires three separate requests to get all model types
const [chatModelsResponse, embeddingModelsResponse, rerankerModelsResponse] = await Promise.all([
// Chat/completion models
sdk.request({
method: 'get',
path: '/models'
}),
// Embedding models
sdk.request({
method: 'get',
path: '/models?model_type=embedding'
}),
// Reranker models
sdk.request({
method: 'get',
path: '/models?model_type=reranker'
})
])
// Extract models from all responses
// @ts-ignore - PPIO response structure may not be typed
const allModels = [
...((chatModelsResponse as any)?.data || []),
...((embeddingModelsResponse as any)?.data || []),
...((rerankerModelsResponse as any)?.data || [])
]
// Process and standardize model data
const processedModels = allModels.map((model: any) => ({
id: model.id || model.name,
description: model.description || model.display_name || model.summary,
object: 'model' as const,
owned_by: model.owned_by || model.publisher || model.organization || 'ppio',
created: model.created || Date.now()
}))
// Clean up model IDs and filter supported models
processedModels.forEach((model) => {
if (model.id) {
model.id = model.id.trim()
}
})
return processedModels.filter(isSupportedModel)
} catch (error) {
console.error('Error listing PPIO models:', error)
return []
}
}
}

View File

@@ -3,8 +3,6 @@ import { Assistant, MCPTool, MCPToolResponse, Model, ToolCallResponse } from '@r
import { Provider } from '@renderer/types'
import {
AnthropicSdkRawChunk,
OpenAIResponseSdkRawChunk,
OpenAIResponseSdkRawOutput,
OpenAISdkRawChunk,
SdkMessageParam,
SdkParams,
@@ -16,7 +14,6 @@ import {
import OpenAI from 'openai'
import { CompletionsParams, GenericChunk } from '../middleware/schemas'
import { CompletionsContext } from '../middleware/types'
/**
* 原始流监听器接口
@@ -36,14 +33,6 @@ export interface OpenAIStreamListener extends RawStreamListener<OpenAISdkRawChun
onFinishReason?: (reason: string) => void
}
/**
* OpenAI Response 专用的流监听器
*/
export interface OpenAIResponseStreamListener<TChunk extends OpenAIResponseSdkRawChunk = OpenAIResponseSdkRawChunk>
extends RawStreamListener<TChunk> {
onMessage?: (response: OpenAIResponseSdkRawOutput) => void
}
/**
* Anthropic 专用的流监听器
*/
@@ -112,7 +101,7 @@ export interface ApiClient<
// SDK相关方法
getSdkInstance(): Promise<TSdkInstance> | TSdkInstance
getRequestTransformer(): RequestTransformer<TSdkParams, TMessageParam>
getResponseChunkTransformer(ctx: CompletionsContext): ResponseChunkTransformer<TRawChunk>
getResponseChunkTransformer(): ResponseChunkTransformer<TRawChunk>
// 原始流监听方法
attachRawStreamListener?(rawOutput: TRawOutput, listener: RawStreamListener<TRawChunk>): TRawOutput

View File

@@ -11,7 +11,6 @@ import { AnthropicAPIClient } from './clients/anthropic/AnthropicAPIClient'
import { OpenAIResponseAPIClient } from './clients/openai/OpenAIResponseAPIClient'
import { CompletionsMiddlewareBuilder } from './middleware/builder'
import { MIDDLEWARE_NAME as AbortHandlerMiddlewareName } from './middleware/common/AbortHandlerMiddleware'
import { MIDDLEWARE_NAME as ErrorHandlerMiddlewareName } from './middleware/common/ErrorHandlerMiddleware'
import { MIDDLEWARE_NAME as FinalChunkConsumerMiddlewareName } from './middleware/common/FinalChunkConsumerMiddleware'
import { applyCompletionsMiddlewares } from './middleware/composer'
import { MIDDLEWARE_NAME as McpToolChunkMiddlewareName } from './middleware/core/McpToolChunkMiddleware'
@@ -63,7 +62,6 @@ export default class AiProvider {
builder.clear()
builder
.add(MiddlewareRegistry[FinalChunkConsumerMiddlewareName])
.add(MiddlewareRegistry[ErrorHandlerMiddlewareName])
.add(MiddlewareRegistry[AbortHandlerMiddlewareName])
.add(MiddlewareRegistry[ImageGenerationMiddlewareName])
} else {
@@ -76,7 +74,7 @@ export default class AiProvider {
if (!(this.apiClient instanceof OpenAIAPIClient)) {
builder.remove(ThinkingTagExtractionMiddlewareName)
}
if (!(this.apiClient instanceof AnthropicAPIClient) && !(this.apiClient instanceof OpenAIResponseAPIClient)) {
if (!(this.apiClient instanceof AnthropicAPIClient)) {
builder.remove(RawStreamListenerMiddlewareName)
}
if (!params.enableWebSearch) {
@@ -114,7 +112,7 @@ export default class AiProvider {
return dimensions
} catch (error) {
console.error('Error getting embedding dimensions:', error)
throw error
return 0
}
}

View File

@@ -1,4 +1,5 @@
import { Chunk } from '@renderer/types/chunk'
import { isAbortError } from '@renderer/utils/error'
import { CompletionsResult } from '../schemas'
import { CompletionsContext } from '../types'
@@ -25,26 +26,29 @@ export const ErrorHandlerMiddleware =
// 尝试执行下一个中间件
return await next(ctx, params)
} catch (error: any) {
console.log('ErrorHandlerMiddleware_error', error)
// 1. 使用通用的工具函数将错误解析为标准格式
const errorChunk = createErrorChunk(error)
// 2. 调用从外部传入的 onError 回调
if (params.onError) {
params.onError(error)
}
// 3. 根据配置决定是重新抛出错误,还是将其作为流的一部分向下传递
if (shouldThrow) {
throw error
}
// 如果不抛出,则创建一个只包含该错误块的流并向下传递
const errorStream = new ReadableStream<Chunk>({
start(controller) {
controller.enqueue(errorChunk)
controller.close()
let errorStream: ReadableStream<Chunk> | undefined
// 有些sdk的abort error 是直接抛出的
if (!isAbortError(error)) {
// 1. 使用通用的工具函数将错误解析为标准格式
const errorChunk = createErrorChunk(error)
// 2. 调用从外部传入的 onError 回调
if (params.onError) {
params.onError(error)
}
})
// 3. 根据配置决定是重新抛出错误,还是将其作为流的一部分向下传递
if (shouldThrow) {
throw error
}
// 如果不抛出,则创建一个只包含该错误块的流并向下传递
errorStream = new ReadableStream<Chunk>({
start(controller) {
controller.enqueue(errorChunk)
controller.close()
}
})
}
return {
rawOutput: undefined,

View File

@@ -153,7 +153,7 @@ function createToolHandlingTransform(
if (toolResult.length > 0) {
const output = ctx._internal.toolProcessingState?.output
const newParams = buildParamsWithToolResults(ctx, currentParams, output, toolResult, toolCalls)
const newParams = buildParamsWithToolResults(ctx, currentParams, output!, toolResult, toolCalls)
await executeWithToolHandling(newParams, depth + 1)
}
} catch (error) {
@@ -243,7 +243,7 @@ async function executeToolUseResponses(
function buildParamsWithToolResults(
ctx: CompletionsContext,
currentParams: CompletionsParams,
output: SdkRawOutput | string | undefined,
output: SdkRawOutput | string,
toolResults: SdkMessageParam[],
toolCalls: SdkToolCall[]
): CompletionsParams {
@@ -255,10 +255,6 @@ function buildParamsWithToolResults(
// 从回复中构建助手消息
const newReqMessages = apiClient.buildSdkMessages(currentReqMessages, output, toolResults, toolCalls)
if (output && ctx._internal.toolProcessingState) {
ctx._internal.toolProcessingState.output = undefined
}
// 估算新增消息的 token 消耗并累加到 usage 中
if (ctx._internal.observer?.usage && newReqMessages.length > currentReqMessages.length) {
try {

View File

@@ -15,6 +15,8 @@ export const RawStreamListenerMiddleware: CompletionsMiddleware =
// 在这里可以监听到从SDK返回的最原始流
if (result.rawOutput) {
console.log(`[${MIDDLEWARE_NAME}] 检测到原始SDK输出准备附加监听器`)
const providerType = ctx.apiClientInstance.provider.type
// TODO: 后面下放到AnthropicAPIClient
if (providerType === 'anthropic') {

View File

@@ -37,7 +37,7 @@ export const ResponseTransformMiddleware: CompletionsMiddleware =
}
// 获取响应转换器
const responseChunkTransformer = apiClient.getResponseChunkTransformer(ctx)
const responseChunkTransformer = apiClient.getResponseChunkTransformer?.()
if (!responseChunkTransformer) {
Logger.warn(`[${MIDDLEWARE_NAME}] No ResponseChunkTransformer available, skipping transformation`)
return result

View File

@@ -25,6 +25,7 @@ export const StreamAdapterMiddleware: CompletionsMiddleware =
// 但是这个中间件的职责是流适配,是否在这调用优待商榷
// 调用下游中间件
const result = await next(ctx, params)
if (
result.rawOutput &&
!(result.rawOutput instanceof ReadableStream) &&

View File

@@ -14,6 +14,8 @@ export const TransformCoreToSdkParamsMiddleware: CompletionsMiddleware =
() =>
(next) =>
async (ctx: CompletionsContext, params: CompletionsParams): Promise<CompletionsResult> => {
Logger.debug(`🔄 [${MIDDLEWARE_NAME}] Starting core to SDK params transformation:`, ctx)
const internal = ctx._internal
// 🔧 检测递归调用:检查 params 中是否携带了预处理的 SDK 消息

View File

@@ -1,5 +1,5 @@
import { ChunkType } from '@renderer/types/chunk'
import { flushLinkConverterBuffer, smartLinkConverter } from '@renderer/utils/linkConverter'
import { smartLinkConverter } from '@renderer/utils/linkConverter'
import { CompletionsParams, CompletionsResult, GenericChunk } from '../schemas'
import { CompletionsContext, CompletionsMiddleware } from '../types'
@@ -42,46 +42,20 @@ export const WebSearchMiddleware: CompletionsMiddleware =
const providerType = model.provider || 'openai'
// 使用当前可用的Web搜索结果进行链接转换
const text = chunk.text
const result = smartLinkConverter(text, providerType, isFirstChunk)
const processedText = smartLinkConverter(text, providerType, isFirstChunk)
if (isFirstChunk) {
isFirstChunk = false
}
// - 如果有内容被缓冲说明convertLinks正在等待后续chunk不使用原文本避免重复
// - 如果没有内容被缓冲且结果为空,可能是其他处理问题,使用原文本作为安全回退
let finalText: string
if (result.hasBufferedContent) {
// 有内容被缓冲使用处理后的结果可能为空等待后续chunk
finalText = result.text
} else {
// 没有内容被缓冲,可以安全使用回退逻辑
finalText = result.text || text
}
// 只有当finalText不为空时才发送chunk
if (finalText) {
controller.enqueue({
...chunk,
text: finalText
})
}
controller.enqueue({
...chunk,
text: processedText
})
} else if (chunk.type === ChunkType.LLM_WEB_SEARCH_COMPLETE) {
// 暂存Web搜索结果用于链接完善
ctx._internal.webSearchState!.results = chunk.llm_web_search
// 将Web搜索完成事件继续传递下去
controller.enqueue(chunk)
} else if (chunk.type === ChunkType.LLM_RESPONSE_COMPLETE) {
// 流结束时清空链接转换器的buffer并处理剩余内容
const remainingText = flushLinkConverterBuffer()
if (remainingText) {
controller.enqueue({
type: ChunkType.TEXT_DELTA,
text: remainingText
})
}
// 继续传递LLM_RESPONSE_COMPLETE事件
controller.enqueue(chunk)
} else {
controller.enqueue(chunk)
}

View File

@@ -1,9 +1,7 @@
import { BaseApiClient } from '@renderer/aiCore/clients/BaseApiClient'
import { isDedicatedImageGenerationModel } from '@renderer/config/models'
import FileManager from '@renderer/services/FileManager'
import { ChunkType } from '@renderer/types/chunk'
import { findImageBlocks, getMainTextContent } from '@renderer/utils/messageUtils/find'
import { defaultTimeout } from '@shared/config/constant'
import OpenAI from 'openai'
import { toFile } from 'openai/uploads'
@@ -19,6 +17,7 @@ export const ImageGenerationMiddleware: CompletionsMiddleware =
const { assistant, messages } = params
const client = context.apiClientInstance as BaseApiClient<OpenAI>
const signal = context._internal?.flowControl?.abortSignal
if (!assistant.model || !isDedicatedImageGenerationModel(assistant.model) || typeof messages === 'string') {
return next(context, params)
}
@@ -48,7 +47,7 @@ export const ImageGenerationMiddleware: CompletionsMiddleware =
const userImages = await Promise.all(
userImageBlocks.map(async (block) => {
if (!block.file) return null
const binaryData: Uint8Array = await FileManager.readBinaryImage(block.file)
const binaryData: Uint8Array = await window.api.file.binaryImage(block.file.id)
const mimeType = `${block.file.type}/${block.file.ext.slice(1)}`
return await toFile(new Blob([binaryData]), block.file.origin_name || 'image.png', { type: mimeType })
})
@@ -75,7 +74,8 @@ export const ImageGenerationMiddleware: CompletionsMiddleware =
const startTime = Date.now()
let response: OpenAI.Images.ImagesResponse
const options = { signal, timeout: defaultTimeout }
const options = { signal, timeout: 300_000 }
if (imageFiles.length > 0) {
response = await sdk.images.edit(

View File

@@ -11,13 +11,11 @@ export const MIDDLEWARE_NAME = 'ThinkingTagExtractionMiddleware'
// 不同模型的思考标签配置
const reasoningTags: TagConfig[] = [
{ openingTag: '<think>', closingTag: '</think>', separator: '\n' },
{ openingTag: '<thought>', closingTag: '</thought>', separator: '\n' },
{ openingTag: '###Thinking', closingTag: '###Response', separator: '\n' }
]
const getAppropriateTag = (model?: Model): TagConfig => {
if (model?.id?.includes('qwen3')) return reasoningTags[0]
if (model?.id?.includes('gemini-2.5')) return reasoningTags[1]
// 可以在这里添加更多模型特定的标签配置
return reasoningTags[0] // 默认使用 <think> 标签
}

View File

@@ -1,13 +0,0 @@
@font-face {
font-family: 'Twemoji Country Flags';
unicode-range:
U+1F1E6-1F1FF, U+1F3F4, U+E0062-E0063, U+E0065, U+E0067, U+E006C, U+E006E, U+E0073-E0074, U+E0077, U+E007F;
/*https://github.com/beyondkmp/country-flag-emoji-polyfill/blob/master/font/TwemojiCountryFlags.woff2 */
src: url('TwemojiCountryFlags.woff2') format('woff2');
font-display: swap;
}
/* 国旗字体样式类 */
.country-flag-font {
font-family: 'Twemoji Country Flags', 'Apple Color Emoji', 'Segoe UI Emoji', sans-serif;
}

View File

@@ -1,6 +1,6 @@
@font-face {
font-family: 'iconfont'; /* Project id 4753420 */
src: url('iconfont.woff2?t=1742793497518') format('woff2');
src: url('iconfont.woff2?t=1742184675192') format('woff2');
}
.iconfont {
@@ -11,18 +11,6 @@
-moz-osx-font-smoothing: grayscale;
}
.icon-plugin:before {
content: '\e612';
}
.icon-tools:before {
content: '\e762';
}
.icon-OCRshibie:before {
content: '\e658';
}
.icon-obsidian:before {
content: '\e677';
}

View File

@@ -1,8 +1 @@
<svg width="22" height="22" viewBox="13 -2 25 22" xmlns="http://www.w3.org/2000/svg">
<g id="White=False">
<g id="if">
<path d="M21.2002 3.73454C22.5633 3.73454 23.0666 2.89917 23.0666 1.86812C23.0666 0.837081 22.5623 0.00170898 21.2002 0.00170898C19.838 0.00170898 19.3337 0.837081 19.3337 1.86812C19.3337 2.89917 19.838 3.73454 21.2002 3.73454Z" fill="#0033FF"/>
<path d="M27.7336 4.13435V5.33473H24.6668V8.00171H27.7336V14.6687H22.6668V5.33567H15.9998V8.00265H19.7336V14.6696H15.3337V17.3366H35.3337V14.6696H30.6668V8.00265H35.3337V5.33567H30.6668V2.66869H35.3337V0.00170898H31.8671C29.5877 0.00170898 27.7336 1.8559 27.7336 4.13529V4.13435Z" fill="#0033FF"/>
</g>
</g>
</svg>
<svg fill="currentColor" fill-rule="evenodd" height="1em" style="flex:none;line-height:1" viewBox="0 0 24 24" width="1em" xmlns="http://www.w3.org/2000/svg"><title>Dify</title><clipPath id="lobe-icons-dify-fill"><path d="M1 0h10.286c6.627 0 12 5.373 12 12s-5.373 12-12 12H1V0z"></path></clipPath><foreignObject clip-path="url(#lobe-icons-dify-fill)" height="24" style="background:conic-gradient(from 180deg at 50% 50%, #0222C3, #8FB1F4, #FFFFFF)" width="24"></foreignObject></svg>

Before

Width:  |  Height:  |  Size: 680 B

After

Width:  |  Height:  |  Size: 480 B

Binary file not shown.

Before

Width:  |  Height:  |  Size: 48 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 16 KiB

Some files were not shown because too many files have changed in this diff Show More