diff --git a/src/main/ocr/BaseOcrProvider.ts b/src/main/ocr/BaseOcrProvider.ts new file mode 100644 index 000000000..c9025f355 --- /dev/null +++ b/src/main/ocr/BaseOcrProvider.ts @@ -0,0 +1,200 @@ +import fs from 'node:fs' +import path from 'node:path' + +import { windowService } from '@main/services/WindowService' +import { getFileExt } from '@main/utils/file' +import { FileMetadata, OcrProvider } from '@types' +import { createCanvas, loadImage } from 'canvas' +import { app } from 'electron' +import { TypedArray } from 'pdfjs-dist/types/src/display/api' + +export default abstract class BaseOcrProvider { + protected provider: OcrProvider + public storageDir = path.join(app.getPath('userData'), 'Data', 'Files') + + constructor(provider: OcrProvider) { + if (!provider) { + throw new Error('OCR provider is not set') + } + this.provider = provider + } + abstract parseFile(sourceId: string, file: FileMetadata): Promise<{ processedFile: FileMetadata }> + + /** + * 检查文件是否已经被预处理过 + * 统一检测方法:如果 Data/Files/{file.id} 是目录,说明已被预处理 + * @param file 文件信息 + * @returns 如果已处理返回处理后的文件信息,否则返回null + */ + public async checkIfAlreadyProcessed(file: FileMetadata): Promise { + try { + // 检查 Data/Files/{file.id} 是否是目录 + const preprocessDirPath = path.join(this.storageDir, file.id) + + if (fs.existsSync(preprocessDirPath)) { + const stats = await fs.promises.stat(preprocessDirPath) + + // 如果是目录,说明已经被预处理过 + if (stats.isDirectory()) { + // 查找目录中的处理结果文件 + const files = await fs.promises.readdir(preprocessDirPath) + + // 查找主要的处理结果文件(.md 或 .txt) + const processedFile = files.find((fileName) => fileName.endsWith('.md') || fileName.endsWith('.txt')) + + if (processedFile) { + const processedFilePath = path.join(preprocessDirPath, processedFile) + const processedStats = await fs.promises.stat(processedFilePath) + const ext = getFileExt(processedFile) + + return { + ...file, + name: file.name.replace(file.ext, ext), + path: processedFilePath, + ext: ext, + size: processedStats.size, + created_at: processedStats.birthtime.toISOString() + } + } + } + } + + return null + } catch (error) { + // 如果检查过程中出现错误,返回null表示未处理 + return null + } + } + + /** + * 辅助方法:延迟执行 + */ + public delay = (ms: number): Promise => { + return new Promise((resolve) => setTimeout(resolve, ms)) + } + + public async readPdf( + source: string | URL | TypedArray, + passwordCallback?: (fn: (password: string) => void, reason: string) => string + ) { + const { getDocument } = await import('pdfjs-dist/legacy/build/pdf.mjs') + const documentLoadingTask = getDocument(source) + if (passwordCallback) { + documentLoadingTask.onPassword = passwordCallback + } + + const document = await documentLoadingTask.promise + return document + } + + public async sendOcrProgress(sourceId: string, progress: number): Promise { + const mainWindow = windowService.getMainWindow() + mainWindow?.webContents.send('file-preprocess-progress', { + itemId: sourceId, + progress: progress + }) + } + + /** + * 将文件移动到附件目录 + * @param fileId 文件id + * @param filePaths 需要移动的文件路径数组 + * @returns 移动后的文件路径数组 + */ + public moveToAttachmentsDir(fileId: string, filePaths: string[]): string[] { + const attachmentsPath = path.join(this.storageDir, fileId) + if (!fs.existsSync(attachmentsPath)) { + fs.mkdirSync(attachmentsPath, { recursive: true }) + } + + const movedPaths: string[] = [] + + for (const filePath of filePaths) { + if (fs.existsSync(filePath)) { + const fileName = path.basename(filePath) + const destPath = path.join(attachmentsPath, fileName) + fs.copyFileSync(filePath, destPath) + fs.unlinkSync(filePath) // 删除原文件,实现"移动" + movedPaths.push(destPath) + } + } + return movedPaths + } + + public async cropImage(image: Buffer | string) { + const img = await loadImage(image) + const width = img.width + const height = img.height + + const canvas = createCanvas(width, height) + const context = canvas.getContext('2d') + + context.drawImage(img, 0, 0) + + const data = context.getImageData(0, 0, width, height).data + + const top = scanY(true) + const bottom = scanY(false) + const left = scanX(true) + const right = scanX(false) + + if (top === null || bottom === null || left === null || right === null) { + console.error('image is empty') + return canvas.toBuffer() + } + + const new_width = right - left + const new_height = bottom - top + + canvas.width = new_width + canvas.height = new_height + + context.drawImage(img, left, top, new_width, new_height, 0, 0, new_width, new_height) + + return canvas.toBuffer() + + // get pixel RGB data: + function getRGB(x: number, y: number) { + return { + red: data[(width * y + x) * 4], + green: data[(width * y + x) * 4 + 1], + blue: data[(width * y + x) * 4 + 2] + } + } + + // check if pixel is a color other than white: + function isColor(rgb: { red: number; green: number; blue: number }) { + return rgb.red == 255 && rgb.green == 255 && rgb.blue == 255 + } + + // scan top and bottom edges of image: + function scanY(top: boolean) { + const offset = top ? 1 : -1 + + for (let y = top ? 0 : height - 1; top ? y < height : y > -1; y += offset) { + for (let x = 0; x < width; x++) { + if (!isColor(getRGB(x, y))) { + return y + } + } + } + + return null + } + + // scan left and right edges of image: + function scanX(left: boolean) { + const offset = left ? 1 : -1 + + for (let x = left ? 0 : width - 1; left ? x < width : x > -1; x += offset) { + for (let y = 0; y < height; y++) { + if (!isColor(getRGB(x, y))) { + return x + } + } + } + + return null + } + } +} diff --git a/src/main/ocr/DefaultOcrProvider.ts b/src/main/ocr/DefaultOcrProvider.ts new file mode 100644 index 000000000..83c8d51c9 --- /dev/null +++ b/src/main/ocr/DefaultOcrProvider.ts @@ -0,0 +1,12 @@ +import { FileMetadata, OcrProvider } from '@types' + +import BaseOcrProvider from './BaseOcrProvider' + +export default class DefaultOcrProvider extends BaseOcrProvider { + constructor(provider: OcrProvider) { + super(provider) + } + public parseFile(): Promise<{ processedFile: FileMetadata }> { + throw new Error('Method not implemented.') + } +} diff --git a/src/main/preprocess/MacSysOcrProvider.ts b/src/main/ocr/MacSysOcrProvider.ts similarity index 92% rename from src/main/preprocess/MacSysOcrProvider.ts rename to src/main/ocr/MacSysOcrProvider.ts index 78b52960f..a565ebf95 100644 --- a/src/main/preprocess/MacSysOcrProvider.ts +++ b/src/main/ocr/MacSysOcrProvider.ts @@ -1,13 +1,13 @@ import { isMac } from '@main/constant' -import { FileMetadata, PreprocessProvider } from '@types' +import { FileMetadata, OcrProvider } from '@types' import Logger from 'electron-log' import * as fs from 'fs' import * as path from 'path' import { TextItem } from 'pdfjs-dist/types/src/display/api' -import BasePreprocessProvider from './BasePreprocessProvider' +import BaseOcrProvider from './BaseOcrProvider' -export default class MacSysOcrProvider extends BasePreprocessProvider { +export default class MacSysOcrProvider extends BaseOcrProvider { private readonly MIN_TEXT_LENGTH = 1000 private MacOCR: any @@ -32,7 +32,7 @@ export default class MacSysOcrProvider extends BasePreprocessProvider { return level === 0 ? this.MacOCR.RECOGNITION_LEVEL_FAST : this.MacOCR.RECOGNITION_LEVEL_ACCURATE } - constructor(provider: PreprocessProvider) { + constructor(provider: OcrProvider) { super(provider) } @@ -61,7 +61,7 @@ export default class MacSysOcrProvider extends BasePreprocessProvider { writeStream.write(ocrResult.text + '\n') // Update progress - await this.sendPreprocessProgress(sourceId, (pageNum / totalPages) * 100) + await this.sendOcrProgress(sourceId, (pageNum / totalPages) * 100) } } diff --git a/src/main/ocr/OcrProvider.ts b/src/main/ocr/OcrProvider.ts new file mode 100644 index 000000000..c4a85d53b --- /dev/null +++ b/src/main/ocr/OcrProvider.ts @@ -0,0 +1,23 @@ +import { FileMetadata, PreprocessProvider as Provider } from '@types' + +import BaseOcrProvider from './BaseOcrProvider' +import OcrProviderFactory from './OcrProviderFactory' + +export default class OcrProvider { + private sdk: BaseOcrProvider + constructor(provider: Provider) { + this.sdk = OcrProviderFactory.create(provider) + } + public async parseFile(sourceId: string, file: FileMetadata): Promise<{ processedFile: FileMetadata }> { + return this.sdk.parseFile(sourceId, file) + } + + /** + * 检查文件是否已经被预处理过 + * @param file 文件信息 + * @returns 如果已处理返回处理后的文件信息,否则返回null + */ + public async checkIfAlreadyProcessed(file: FileMetadata): Promise { + return this.sdk.checkIfAlreadyProcessed(file) + } +} diff --git a/src/main/ocr/OcrProviderFactory.ts b/src/main/ocr/OcrProviderFactory.ts new file mode 100644 index 000000000..96d95a63a --- /dev/null +++ b/src/main/ocr/OcrProviderFactory.ts @@ -0,0 +1,20 @@ +import { isMac } from '@main/constant' +import { OcrProvider } from '@types' +import Logger from 'electron-log' + +import BaseOcrProvider from './BaseOcrProvider' +import DefaultOcrProvider from './DefaultOcrProvider' +import MacSysOcrProvider from './MacSysOcrProvider' +export default class OcrProviderFactory { + static create(provider: OcrProvider): BaseOcrProvider { + switch (provider.id) { + case 'system': + if (!isMac) { + Logger.warn('[OCR] System OCR provider is only available on macOS') + } + return new MacSysOcrProvider(provider) + default: + return new DefaultOcrProvider(provider) + } + } +} diff --git a/src/main/preprocess/BasePreprocessProvider.ts b/src/main/preprocess/BasePreprocessProvider.ts index bf528639d..63648475e 100644 --- a/src/main/preprocess/BasePreprocessProvider.ts +++ b/src/main/preprocess/BasePreprocessProvider.ts @@ -4,7 +4,6 @@ import path from 'node:path' import { windowService } from '@main/services/WindowService' import { getFileExt } from '@main/utils/file' import { FileMetadata, PreprocessProvider } from '@types' -import { createCanvas, loadImage } from 'canvas' import { app } from 'electron' import { TypedArray } from 'pdfjs-dist/types/src/display/api' @@ -120,81 +119,4 @@ export default abstract class BasePreprocessProvider { } return movedPaths } - - public async cropImage(image: Buffer | string) { - const img = await loadImage(image) - const width = img.width - const height = img.height - - const canvas = createCanvas(width, height) - const context = canvas.getContext('2d') - - context.drawImage(img, 0, 0) - - const data = context.getImageData(0, 0, width, height).data - - const top = scanY(true) - const bottom = scanY(false) - const left = scanX(true) - const right = scanX(false) - - if (top === null || bottom === null || left === null || right === null) { - console.error('image is empty') - return canvas.toBuffer() - } - - const new_width = right - left - const new_height = bottom - top - - canvas.width = new_width - canvas.height = new_height - - context.drawImage(img, left, top, new_width, new_height, 0, 0, new_width, new_height) - - return canvas.toBuffer() - - // get pixel RGB data: - function getRGB(x: number, y: number) { - return { - red: data[(width * y + x) * 4], - green: data[(width * y + x) * 4 + 1], - blue: data[(width * y + x) * 4 + 2] - } - } - - // check if pixel is a color other than white: - function isColor(rgb: { red: number; green: number; blue: number }) { - return rgb.red == 255 && rgb.green == 255 && rgb.blue == 255 - } - - // scan top and bottom edges of image: - function scanY(top: boolean) { - const offset = top ? 1 : -1 - - for (let y = top ? 0 : height - 1; top ? y < height : y > -1; y += offset) { - for (let x = 0; x < width; x++) { - if (!isColor(getRGB(x, y))) { - return y - } - } - } - - return null - } - - // scan left and right edges of image: - function scanX(left: boolean) { - const offset = left ? 1 : -1 - - for (let x = left ? 0 : width - 1; left ? x < width : x > -1; x += offset) { - for (let y = 0; y < height; y++) { - if (!isColor(getRGB(x, y))) { - return x - } - } - } - - return null - } - } } diff --git a/src/main/preprocess/MineruPreprocessProvider.ts b/src/main/preprocess/MineruPreprocessProvider.ts index e3bfbe475..844175618 100644 --- a/src/main/preprocess/MineruPreprocessProvider.ts +++ b/src/main/preprocess/MineruPreprocessProvider.ts @@ -58,7 +58,7 @@ export default class MineruPreprocessProvider extends BasePreprocessProvider { Logger.info(`MinerU processing completed for batch: ${batchId}`) // 3. 下载并解压文件 - const { path: outputPath } = await this.downloadAndExtractFile(extractResult.full_zip_url!, file.path) + const { path: outputPath } = await this.downloadAndExtractFile(extractResult.full_zip_url!, file) // 4. 创建处理后的文件信息 return { @@ -125,11 +125,11 @@ export default class MineruPreprocessProvider extends BasePreprocessProvider { } } - private async downloadAndExtractFile(zipUrl: string, originalFilePath: string): Promise<{ path: string }> { + private async downloadAndExtractFile(zipUrl: string, file: FileMetadata): Promise<{ path: string }> { const dirPath = this.storageDir - const baseName = path.basename(originalFilePath, path.extname(originalFilePath)) - const zipPath = path.join(dirPath, `${baseName}.zip`) - const extractPath = path.join(dirPath, `${baseName}`) + + const zipPath = path.join(dirPath, `${file.id}.zip`) + const extractPath = path.join(dirPath, `${file.id}`) Logger.info(`Downloading MinerU result to: ${zipPath}`) diff --git a/src/main/preprocess/PreprocessProviderFactory.ts b/src/main/preprocess/PreprocessProviderFactory.ts index ad17fc0f0..98c5a1e2a 100644 --- a/src/main/preprocess/PreprocessProviderFactory.ts +++ b/src/main/preprocess/PreprocessProviderFactory.ts @@ -1,11 +1,8 @@ -import { isMac } from '@main/constant' import { PreprocessProvider } from '@types' -import Logger from 'electron-log' import BasePreprocessProvider from './BasePreprocessProvider' import DefaultPreprocessProvider from './DefaultPreprocessProvider' import Doc2xPreprocessProvider from './Doc2xPreprocessProvider' -import MacSysOcrProvider from './MacSysOcrProvider' import MineruPreprocessProvider from './MineruPreprocessProvider' import MistralPreprocessProvider from './MistralPreprocessProvider' export default class PreprocessProviderFactory { @@ -15,11 +12,6 @@ export default class PreprocessProviderFactory { return new Doc2xPreprocessProvider(provider) case 'mistral': return new MistralPreprocessProvider(provider) - case 'system': - if (!isMac) { - Logger.warn('[OCR] System OCR provider is only available on macOS') - } - return new MacSysOcrProvider(provider) case 'mineru': return new MineruPreprocessProvider(provider) default: diff --git a/src/main/services/KnowledgeService.ts b/src/main/services/KnowledgeService.ts index 2c00dd875..959a86017 100644 --- a/src/main/services/KnowledgeService.ts +++ b/src/main/services/KnowledgeService.ts @@ -23,6 +23,7 @@ import { SitemapLoader } from '@cherrystudio/embedjs-loader-sitemap' import { WebLoader } from '@cherrystudio/embedjs-loader-web' import Embeddings from '@main/embeddings/Embeddings' import { addFileLoader } from '@main/loader' +import OcrProvider from '@main/ocr/OcrProvider' import PreprocessProvider from '@main/preprocess/PreprocessProvider' import Reranker from '@main/reranker/Reranker' import { windowService } from '@main/services/WindowService' @@ -498,12 +499,16 @@ class KnowledgeService { item: KnowledgeItem ): Promise => { let fileToProcess: FileMetadata = file - if (base.preprocessProvider && file.ext.toLowerCase() === '.pdf') { + if (base.preprocessOrOcrProvider && file.ext.toLowerCase() === '.pdf') { try { - const preprocessProvider = new PreprocessProvider(base.preprocessProvider) - + let provider: PreprocessProvider | OcrProvider + if (base.preprocessOrOcrProvider.type === 'preprocess') { + provider = new PreprocessProvider(base.preprocessOrOcrProvider.provider) + } else { + provider = new OcrProvider(base.preprocessOrOcrProvider.provider) + } // 首先检查文件是否已经被预处理过 - const alreadyProcessed = await preprocessProvider.checkIfAlreadyProcessed(file) + const alreadyProcessed = await provider.checkIfAlreadyProcessed(file) if (alreadyProcessed) { Logger.info(`File already preprocess processed, using cached result: ${file.path}`) return alreadyProcessed @@ -511,7 +516,7 @@ class KnowledgeService { // 执行预处理 Logger.info(`Starting preprocess processing for scanned PDF: ${file.path}`) - const { processedFile } = await preprocessProvider.parseFile(item.id, file) + const { processedFile } = await provider.parseFile(item.id, file) fileToProcess = processedFile } catch (err) { Logger.error(`Preprocess processing failed: ${err}`) diff --git a/src/renderer/src/config/ocrProviders.ts b/src/renderer/src/config/ocrProviders.ts new file mode 100644 index 000000000..5e482e10e --- /dev/null +++ b/src/renderer/src/config/ocrProviders.ts @@ -0,0 +1,12 @@ +import MacOSLogo from '@renderer/assets/images/providers/macos.svg' + +export function getOcrProviderLogo(providerId: string) { + switch (providerId) { + case 'system': + return MacOSLogo + default: + return undefined + } +} + +export const OCR_PROVIDER_CONFIG = {} diff --git a/src/renderer/src/config/preprocessProviders.ts b/src/renderer/src/config/preprocessProviders.ts index f8e1eaeab..587e6ea7f 100644 --- a/src/renderer/src/config/preprocessProviders.ts +++ b/src/renderer/src/config/preprocessProviders.ts @@ -1,6 +1,5 @@ import Doc2xLogo from '@renderer/assets/images/ocr/doc2x.png' import MinerULogo from '@renderer/assets/images/ocr/mineru.jpg' -import MacOSLogo from '@renderer/assets/images/providers/macos.svg' import MistralLogo from '@renderer/assets/images/providers/mistral.png' export function getPreprocessProviderLogo(providerId: string) { @@ -9,8 +8,6 @@ export function getPreprocessProviderLogo(providerId: string) { return Doc2xLogo case 'mistral': return MistralLogo - case 'system': - return MacOSLogo case 'mineru': return MinerULogo default: diff --git a/src/renderer/src/hooks/useKnowledge.ts b/src/renderer/src/hooks/useKnowledge.ts index 8d5afcfba..10e04896b 100644 --- a/src/renderer/src/hooks/useKnowledge.ts +++ b/src/renderer/src/hooks/useKnowledge.ts @@ -145,11 +145,7 @@ export const useKnowledge = (baseId: string) => { } } if (item.type === 'file' && typeof item.content === 'object') { - // await FileManager.deleteFile(item.content.id) - // file name 1.xxx.pdf - // get 1.xxx - // remove .pdf ext - await window.api.file.deleteDir(item.content.name.split('.').slice(0, -1).join('.')) + await window.api.file.deleteDir(item.content.id) } } // 刷新项目 diff --git a/src/renderer/src/hooks/useOcr.ts b/src/renderer/src/hooks/useOcr.ts new file mode 100644 index 000000000..7f83fd9c2 --- /dev/null +++ b/src/renderer/src/hooks/useOcr.ts @@ -0,0 +1,45 @@ +import { RootState } from '@renderer/store' +import { + setDefaultOcrProvider as _setDefaultOcrProvider, + updateOcrProvider as _updateOcrProvider, + updateOcrProviders as _updateOcrProviders +} from '@renderer/store/ocr' +import { OcrProvider } from '@renderer/types' +import { useDispatch, useSelector } from 'react-redux' + +export const useOcrProvider = (id: string) => { + const dispatch = useDispatch() + const ocrProviders = useSelector((state: RootState) => state.ocr.providers) + const provider = ocrProviders.find((provider) => provider.id === id) + if (!provider) { + throw new Error(`OCR provider with id ${id} not found`) + } + const updateOcrProvider = (ocrProvider: OcrProvider) => { + dispatch(_updateOcrProvider(ocrProvider)) + } + return { provider, updateOcrProvider } +} + +export const useOcrProviders = () => { + const dispatch = useDispatch() + const ocrProviders = useSelector((state: RootState) => state.ocr.providers) + return { + ocrProviders: ocrProviders, + updateOcrProviders: (ocrProviders: OcrProvider[]) => dispatch(_updateOcrProviders(ocrProviders)) + } +} + +export const useDefaultOcrProvider = () => { + const defaultProviderId = useSelector((state: RootState) => state.ocr.defaultProvider) + const { ocrProviders } = useOcrProviders() + const dispatch = useDispatch() + const provider = defaultProviderId ? ocrProviders.find((provider) => provider.id === defaultProviderId) : undefined + + const setDefaultOcrProvider = (ocrProvider: OcrProvider) => { + dispatch(_setDefaultOcrProvider(ocrProvider.id)) + } + const updateDefaultOcrProvider = (ocrProvider: OcrProvider) => { + dispatch(_updateOcrProvider(ocrProvider)) + } + return { provider, setDefaultOcrProvider, updateDefaultOcrProvider } +} diff --git a/src/renderer/src/i18n/locales/en-us.json b/src/renderer/src/i18n/locales/en-us.json index 0949e958d..749cdc340 100644 --- a/src/renderer/src/i18n/locales/en-us.json +++ b/src/renderer/src/i18n/locales/en-us.json @@ -1726,6 +1726,13 @@ "title": "Pre Process", "provider": "Pre Process Provider", "provider_placeholder": "Choose a Pre Process provider", + "preprocess_tooltip": "Setting the document preprocessing service provider in Settings -> Tools can effectively improve the retrieval performance for complex document formats." + }, + "ocr": { + "title": "OCR", + "provider": "OCR Provider", + "provider_placeholder": "Choose an OCR provider", + "ocr_tooltip": "Setting the OCR service provider in Settings -> Tools can effectively improve the text recognition performance for images and documents.", "mac_system_ocr_options": { "mode": { "title": "Recognition Mode", diff --git a/src/renderer/src/i18n/locales/zh-cn.json b/src/renderer/src/i18n/locales/zh-cn.json index e8bfc420e..cbed5dbeb 100644 --- a/src/renderer/src/i18n/locales/zh-cn.json +++ b/src/renderer/src/i18n/locales/zh-cn.json @@ -1772,6 +1772,13 @@ "title": "文档预处理", "provider": "文档预处理", "provider_placeholder": "选择一个文档预处理服务商", + "preprocess_tooltip": "在设置 -> 工具中设置文档预处理服务商,可以有效提升复杂文档格式的检索效果" + }, + "ocr": { + "title": "OCR 设置", + "provider": "OCR 服务商", + "provider_placeholder": "选择一个 OCR 服务商", + "ocr_tooltip": "在设置 -> 工具中设置 OCR 服务商,可以有效提升图片文字识别效果", "mac_system_ocr_options": { "mode": { "title": "识别模式", @@ -2042,4 +2049,4 @@ } } } -} \ No newline at end of file +} diff --git a/src/renderer/src/pages/knowledge/components/AddKnowledgePopup.tsx b/src/renderer/src/pages/knowledge/components/AddKnowledgePopup.tsx index dee3276ea..c58d43783 100644 --- a/src/renderer/src/pages/knowledge/components/AddKnowledgePopup.tsx +++ b/src/renderer/src/pages/knowledge/components/AddKnowledgePopup.tsx @@ -6,12 +6,13 @@ import { isEmbeddingModel, isRerankModel } from '@renderer/config/models' import { NOT_SUPPORTED_REANK_PROVIDERS } from '@renderer/config/providers' // import { SUPPORTED_REANK_PROVIDERS } from '@renderer/config/providers' import { useKnowledgeBases } from '@renderer/hooks/useKnowledge' +import { useOcrProviders } from '@renderer/hooks/useOcr' import { usePreprocessProviders } from '@renderer/hooks/usePreprocess' import { useProviders } from '@renderer/hooks/useProvider' import AiProvider from '@renderer/providers/AiProvider' import { getKnowledgeBaseParams } from '@renderer/services/KnowledgeService' import { getModelUniqId } from '@renderer/services/ModelService' -import { KnowledgeBase, Model, PreprocessProvider } from '@renderer/types' +import { KnowledgeBase, Model, OcrProvider, PreprocessProvider } from '@renderer/types' import { getErrorMessage } from '@renderer/utils/error' import { Alert, Input, InputNumber, Modal, Select, Slider, Tabs, TabsProps, Tooltip } from 'antd' import { find, sortBy } from 'lodash' @@ -37,7 +38,8 @@ const PopupContainer: React.FC = ({ title, resolve }) => { const [newBase, setNewBase] = useState({} as KnowledgeBase) const { preprocessProviders } = usePreprocessProviders() - const [selectedProvider, setSelectedProvider] = useState(undefined) + const { ocrProviders } = useOcrProviders() + const [selectedProvider, setSelectedProvider] = useState(undefined) const embeddingModels = useMemo(() => { return providers @@ -89,6 +91,20 @@ const PopupContainer: React.FC = ({ title, resolve }) => { .filter((group) => group.options.length > 0) }, [providers, t]) + const preprocessOrOcrSelectOptions = useMemo(() => { + const preprocessOptions = { + label: t('settings.tool.preprocess.provider'), + title: t('settings.tool.preprocess.provider'), + options: preprocessProviders.filter((p) => p.apiKey !== '').map((p) => ({ value: p.id, label: p.name })) + } + const ocrOptions = { + label: t('settings.tool.ocr.provider'), + title: t('settings.tool.ocr.provider'), + options: ocrProviders.filter((p) => p.apiKey !== '').map((p) => ({ value: p.id, label: p.name })) + } + return [preprocessOptions, ocrOptions] + }, [ocrProviders, preprocessProviders]) + const onOk = async () => { try { // const values = await form.validateFields() @@ -167,6 +183,44 @@ const PopupContainer: React.FC = ({ title, resolve }) => { /> + +
+ {t('settings.tool.preprocess.title')} + + + +
+ { - const provider = preprocessProviders.find((p) => p.id === value) - setSelectedProvider(provider) - setNewBase({ ...newBase, preprocessProvider: provider }) - }} - placeholder={t('settings.tool.preprocess.provider_placeholder')} - options={preprocessProviders.filter((p) => p.apiKey !== '').map((p) => ({ value: p.id, label: p.name }))} - allowClear - /> -
-
{t('knowledge.document_count')} diff --git a/src/renderer/src/pages/knowledge/components/KnowledgeSettings.tsx b/src/renderer/src/pages/knowledge/components/KnowledgeSettings.tsx index 41d65ad27..f659167ad 100644 --- a/src/renderer/src/pages/knowledge/components/KnowledgeSettings.tsx +++ b/src/renderer/src/pages/knowledge/components/KnowledgeSettings.tsx @@ -4,6 +4,7 @@ import { DEFAULT_KNOWLEDGE_DOCUMENT_COUNT } from '@renderer/config/constant' import { getEmbeddingMaxContext } from '@renderer/config/embedings' import { isEmbeddingModel, isRerankModel } from '@renderer/config/models' import { useKnowledge } from '@renderer/hooks/useKnowledge' +import { useOcrProviders } from '@renderer/hooks/useOcr' import { usePreprocessProviders } from '@renderer/hooks/usePreprocess' import { useProviders } from '@renderer/hooks/useProvider' import { getModelUniqId } from '@renderer/services/ModelService' @@ -24,7 +25,11 @@ interface Props extends ShowParams { const PopupContainer: React.FC = ({ base: _base, resolve }) => { const { preprocessProviders } = usePreprocessProviders() - const [selectedProvider, setSelectedProvider] = useState(_base.preprocessProvider) + const { ocrProviders } = useOcrProviders() + + const [selectedProvider, setSelectedProvider] = useState( + _base.preprocessOrOcrProvider?.provider + ) const [open, setOpen] = useState(true) const { t } = useTranslation() @@ -65,6 +70,19 @@ const PopupContainer: React.FC = ({ base: _base, resolve }) => { })) .filter((group) => group.options.length > 0) + const preprocessOptions = { + label: t('settings.tool.preprocess.provider'), + title: t('settings.tool.preprocess.provider'), + options: preprocessProviders.filter((p) => p.apiKey !== '').map((p) => ({ value: p.id, label: p.name })) + } + const ocrOptions = { + label: t('settings.tool.ocr.provider'), + title: t('settings.tool.ocr.provider'), + options: ocrProviders.filter((p) => p.apiKey !== '').map((p) => ({ value: p.id, label: p.name })) + } + + const preprocessOrOcrSelectOptions = [preprocessOptions, ocrOptions].filter((group) => group.options.length > 0) + const onOk = async () => { try { console.log('newbase', newBase) @@ -99,6 +117,44 @@ const PopupContainer: React.FC = ({ base: _base, resolve }) => { /> + +
+ {t('settings.tool.preprocess.title')} + + + +
+ { - const provider = preprocessProviders.find((p) => p.id === value) - setSelectedProvider(provider) - setNewBase({ ...newBase, preprocessProvider: provider }) - }} - placeholder={t('settings.tool.preprocess.provider_placeholder')} - options={preprocessProviders.filter((p) => p.apiKey !== '').map((p) => ({ value: p.id, label: p.name }))} - allowClear - /> -
-
{t('knowledge.document_count')} diff --git a/src/renderer/src/pages/settings/ToolSettings/OcrSettings/OcrSettings.tsx b/src/renderer/src/pages/settings/ToolSettings/OcrSettings/OcrSettings.tsx new file mode 100644 index 000000000..3e4d703a4 --- /dev/null +++ b/src/renderer/src/pages/settings/ToolSettings/OcrSettings/OcrSettings.tsx @@ -0,0 +1,168 @@ +import { ExportOutlined } from '@ant-design/icons' +import { getOcrProviderLogo, OCR_PROVIDER_CONFIG } from '@renderer/config/ocrProviders' +import { useOcrProvider } from '@renderer/hooks/useOcr' +import { formatApiKeys } from '@renderer/services/ApiService' +import { OcrProvider } from '@renderer/types' +import { hasObjectKey } from '@renderer/utils' +import { Avatar, Divider, Flex, Input, InputNumber, Segmented } from 'antd' +import Link from 'antd/es/typography/Link' +import { FC, useEffect, useState } from 'react' +import { useTranslation } from 'react-i18next' +import styled from 'styled-components' + +import { + SettingDivider, + SettingHelpLink, + SettingHelpText, + SettingHelpTextRow, + SettingRow, + SettingRowTitle, + SettingSubtitle, + SettingTitle +} from '../..' + +interface Props { + provider: OcrProvider +} + +const OcrProviderSettings: FC = ({ provider: _provider }) => { + const { provider: ocrProvider, updateOcrProvider } = useOcrProvider(_provider.id) + const { t } = useTranslation() + const [apiKey, setApiKey] = useState(ocrProvider.apiKey || '') + const [apiHost, setApiHost] = useState(ocrProvider.apiHost || '') + const [options, setOptions] = useState(ocrProvider.options || {}) + + const ocrProviderConfig = OCR_PROVIDER_CONFIG[ocrProvider.id] + const apiKeyWebsite = ocrProviderConfig?.websites?.apiKey + const officialWebsite = ocrProviderConfig?.websites?.official + + useEffect(() => { + setApiKey(ocrProvider.apiKey ?? '') + setApiHost(ocrProvider.apiHost ?? '') + setOptions(ocrProvider.options ?? {}) + }, [ocrProvider.apiKey, ocrProvider.apiHost, ocrProvider.options]) + + const onUpdateApiKey = () => { + if (apiKey !== ocrProvider.apiKey) { + updateOcrProvider({ ...ocrProvider, apiKey }) + } + } + + const onUpdateApiHost = () => { + let trimmedHost = apiHost?.trim() || '' + if (trimmedHost.endsWith('/')) { + trimmedHost = trimmedHost.slice(0, -1) + } + if (trimmedHost !== ocrProvider.apiHost) { + updateOcrProvider({ ...ocrProvider, apiHost: trimmedHost }) + } else { + setApiHost(ocrProvider.apiHost || '') + } + } + + const onUpdateOptions = (key: string, value: any) => { + const newOptions = { ...options, [key]: value } + setOptions(newOptions) + updateOcrProvider({ ...ocrProvider, options: newOptions }) + } + + return ( + <> + + + + + {ocrProvider.name} + {officialWebsite && ocrProviderConfig?.websites && ( + + + + )} + + + + {hasObjectKey(ocrProvider, 'apiKey') && ( + <> + {t('settings.provider.api_key')} + + setApiKey(formatApiKeys(e.target.value))} + onBlur={onUpdateApiKey} + spellCheck={false} + type="password" + autoFocus={apiKey === ''} + /> + + + + {t('settings.provider.get_api_key')} + + {t('settings.provider.api_key.tip')} + + + )} + + {hasObjectKey(ocrProvider, 'apiHost') && ( + <> + + {t('settings.provider.api_host')} + + + setApiHost(e.target.value)} + onBlur={onUpdateApiHost} + /> + + + )} + + {hasObjectKey(ocrProvider, 'options') && ocrProvider.id === 'system' && ( + <> + + + {t('settings.tool.ocr.mac_system_ocr_options.mode.title')} + onUpdateOptions('recognitionLevel', value)} + /> + + + + {t('settings.tool.ocr.mac_system_ocr_options.min_confidence')} + onUpdateOptions('minConfidence', value)} + min={0} + max={1} + step={0.1} + /> + + + )} + + ) +} + +const ProviderName = styled.span` + font-size: 14px; + font-weight: 500; +` +const ProviderLogo = styled(Avatar)` + border: 0.5px solid var(--color-border); +` + +export default OcrProviderSettings diff --git a/src/renderer/src/pages/settings/ToolSettings/OcrSettings/index.tsx b/src/renderer/src/pages/settings/ToolSettings/OcrSettings/index.tsx new file mode 100644 index 000000000..1a3b2d2b5 --- /dev/null +++ b/src/renderer/src/pages/settings/ToolSettings/OcrSettings/index.tsx @@ -0,0 +1,58 @@ +import { isMac } from '@renderer/config/constant' +import { useTheme } from '@renderer/context/ThemeProvider' +import { useDefaultOcrProvider, useOcrProviders } from '@renderer/hooks/useOcr' +import { PreprocessProvider } from '@renderer/types' +import { Select } from 'antd' +import { FC, useState } from 'react' +import { useTranslation } from 'react-i18next' + +import { SettingContainer, SettingDivider, SettingGroup, SettingRow, SettingRowTitle, SettingTitle } from '../..' +import OcrProviderSettings from './OcrSettings' + +const OcrSettings: FC = () => { + const { ocrProviders } = useOcrProviders() + const { provider: defaultProvider, setDefaultOcrProvider } = useDefaultOcrProvider() + const { t } = useTranslation() + const [selectedProvider, setSelectedProvider] = useState(defaultProvider) + const { theme: themeMode } = useTheme() + + function updateSelectedOcrProvider(providerId: string) { + const provider = ocrProviders.find((p) => p.id === providerId) + if (!provider) { + return + } + setDefaultOcrProvider(provider) + setSelectedProvider(provider) + } + + return ( + + + {t('settings.tool.ocr.title')} + + + {t('settings.tool.ocr.provider')} +
+