diff --git a/src/main/ocr/BaseOcrProvider.ts b/src/main/ocr/BaseOcrProvider.ts index b165f9ff2..c7992b141 100644 --- a/src/main/ocr/BaseOcrProvider.ts +++ b/src/main/ocr/BaseOcrProvider.ts @@ -1,7 +1,7 @@ import fs from 'node:fs' import { windowService } from '@main/services/WindowService' -import { FileSource, OcrProvider } from '@types' +import { FileSource, LocalFileSource, OcrProvider } from '@types' import Logger from 'electron-log' import pdfParse from 'pdf-parse' export default abstract class BaseOcrProvider { @@ -12,7 +12,7 @@ export default abstract class BaseOcrProvider { } this.provider = provider } - abstract parseFile(sourceId: string, file: FileSource): Promise<{ processedFile: FileSource }> + abstract parseFile(sourceId: string, file: FileSource): Promise<{ processedFile: LocalFileSource }> /** * 辅助方法:延迟执行 */ diff --git a/src/main/ocr/MistralOcrProvider.ts b/src/main/ocr/MistralOcrProvider.ts new file mode 100644 index 000000000..dd242ba16 --- /dev/null +++ b/src/main/ocr/MistralOcrProvider.ts @@ -0,0 +1,203 @@ +import fs from 'node:fs' + +import { MistralService } from '@main/services/file/MistralService' +import { MistralClientManager } from '@main/services/MistralClientManager' +import { Mistral } from '@mistralai/mistralai' +import { DocumentURLChunk } from '@mistralai/mistralai/models/components/documenturlchunk' +import { ImageURLChunk } from '@mistralai/mistralai/models/components/imageurlchunk' +import { OCRResponse } from '@mistralai/mistralai/models/components/ocrresponse' +import { FileSource, FileTypes, isLocalFile, LocalFileSource, OcrProvider } from '@types' +import Logger from 'electron-log' +import path from 'path' +import { v4 as uuidv4 } from 'uuid' + +import BaseOcrProvider from './BaseOcrProvider' + +type PreuploadResponse = DocumentURLChunk | ImageURLChunk + +export default class MistralOcrProvider extends BaseOcrProvider { + private sdk: Mistral + private fileService: MistralService + + constructor(provider: OcrProvider) { + super(provider) + const clientManager = MistralClientManager.getInstance() + clientManager.initializeClient(provider.apiKey!) + this.sdk = clientManager.getClient() + this.fileService = new MistralService(provider.apiKey!) + } + + private async preupload(file: FileSource): Promise { + let document: PreuploadResponse + if (isLocalFile(file)) { + Logger.info(`OCR preupload started for local file: ${file.path}`) + + const pdfInfo = await this.getPdfInfo(file.path) + if (pdfInfo.pageCount >= 1000) { + throw new Error(`PDF page count (${pdfInfo.pageCount}) exceeds the limit of 1000 pages`) + } + if (pdfInfo.fileSize >= 512 * 1024 * 1024) { + const fileSizeMB = Math.round(pdfInfo.fileSize / (1024 * 1024)) + throw new Error(`PDF file size (${fileSizeMB}MB) exceeds the limit of 300MB`) + } + + if (file.ext.toLowerCase() === '.pdf') { + const uploadResponse = await this.fileService.uploadFile(file) + + if (uploadResponse.status === 'failed') { + Logger.error('File upload failed:', uploadResponse) + throw new Error('Failed to upload file: ' + uploadResponse.displayName) + } + + const fileUrl = await this.sdk.files.getSignedUrl({ + fileId: uploadResponse.fileId + }) + Logger.info('Got signed URL:', fileUrl) + + document = { + type: 'document_url', + documentUrl: fileUrl.url + } + } else { + const base64Image = Buffer.from(fs.readFileSync(file.path)).toString('base64') + document = { + type: 'image_url', + imageUrl: `data:image/png;base64,${base64Image}` + } + } + } else { + if (file.ext.toLowerCase() === '.pdf') { + document = { + type: 'document_url', + documentUrl: file.url + } + } else { + document = { + type: 'image_url', + imageUrl: file.url + } + } + } + + if (!document) { + throw new Error('Unsupported file type') + } + return document + } + + public async parseFile(sourceId: string, file: FileSource): Promise<{ processedFile: LocalFileSource }> { + try { + const document = await this.preupload(file) + const result = await this.sdk.ocr.process({ + model: this.provider.model!, + document: document, + includeImageBase64: true + }) + if (result) { + await this.sendOcrProgress(sourceId, 100) + const processedFile = this.convertFile(result, file) + return { + processedFile + } + } else { + throw new Error('OCR processing failed: OCR response is empty') + } + } catch (error) { + throw new Error('OCR processing failed: ' + error) + } + } + + private convertFile(result: OCRResponse, file: FileSource): LocalFileSource { + // Create a unique directory for this conversion to store images + const conversionId = uuidv4() + let outputPath = '' + let outputFileName = '' + if (isLocalFile(file)) { + outputPath = path.join(path.dirname(file.path), conversionId) + outputFileName = path.basename(file.path, path.extname(file.path)) + fs.mkdirSync(outputPath, { recursive: true }) + } + + const markdownParts: string[] = [] + let counter = 0 + + // Process each page + result.pages.forEach((page) => { + let pageMarkdown = page.markdown + + // Process images from this page + page.images.forEach((image) => { + if (image.imageBase64) { + let imageFormat = 'jpeg' // default format + let imageBase64Data = image.imageBase64 + + // Check for data URL prefix more efficiently + const prefixEnd = image.imageBase64.indexOf(';base64,') + if (prefixEnd > 0) { + const prefix = image.imageBase64.substring(0, prefixEnd) + const formatIndex = prefix.indexOf('image/') + if (formatIndex >= 0) { + imageFormat = prefix.substring(formatIndex + 6) + } + imageBase64Data = image.imageBase64.substring(prefixEnd + 8) + } + + const imageFileName = `img-${counter}.${imageFormat}` + const imagePath = path.join(outputPath, imageFileName) + + // Save image file + try { + fs.writeFileSync(imagePath, Buffer.from(imageBase64Data, 'base64')) + + // Update image reference in markdown + // Use relative path for better portability + const relativeImagePath = `./${imageFileName}` + + // Find the start and end of the image markdown + const imgStart = pageMarkdown.indexOf(image.imageBase64) + if (imgStart >= 0) { + // Find the markdown image syntax around this base64 + const mdStart = pageMarkdown.lastIndexOf('![', imgStart) + const mdEnd = pageMarkdown.indexOf(')', imgStart) + + if (mdStart >= 0 && mdEnd >= 0) { + // Replace just this specific image reference + pageMarkdown = + pageMarkdown.substring(0, mdStart) + + `![Image ${counter}](${relativeImagePath})` + + pageMarkdown.substring(mdEnd + 1) + } + } + + counter++ + } catch (error) { + Logger.error(`Failed to save image ${imageFileName}:`, error) + } + } + }) + + markdownParts.push(pageMarkdown) + }) + + // Combine all markdown content with double newlines for readability + const combinedMarkdown = markdownParts.join('\n\n') + + // Write the markdown content to a file + const mdFileName = `${outputFileName}.md` + const mdFilePath = path.join(outputPath, mdFileName) + fs.writeFileSync(mdFilePath, combinedMarkdown) + + return { + id: conversionId, + name: mdFileName, + origin_name: mdFileName, + path: mdFilePath, + created_at: new Date().toISOString(), + type: FileTypes.DOCUMENT, + ext: '.md', + size: fs.statSync(mdFilePath).size, + count: result.pages.length, + source: 'local' + } + } +} diff --git a/src/main/ocr/OcrProvider.ts b/src/main/ocr/OcrProvider.ts index 032c3062c..22b1efe66 100644 --- a/src/main/ocr/OcrProvider.ts +++ b/src/main/ocr/OcrProvider.ts @@ -1,12 +1,12 @@ -import { FileType, KnowledgeBaseParams } from '@types' +import { FileType, OcrProvider as Provider } from '@types' import BaseOcrProvider from './BaseOcrProvider' import OcrProviderFactory from './OcrProviderFactory' export default class OcrProvider { private sdk: BaseOcrProvider - constructor(base: KnowledgeBaseParams) { - this.sdk = OcrProviderFactory.create(base) + constructor(provider: Provider) { + this.sdk = OcrProviderFactory.create(provider) } public async parseFile(sourceId: string, file: FileType): Promise<{ processedFile: FileType }> { return this.sdk.parseFile(sourceId, file) diff --git a/src/main/services/FileStorage.ts b/src/main/services/FileStorage.ts index 252435740..2ecb34569 100644 --- a/src/main/services/FileStorage.ts +++ b/src/main/services/FileStorage.ts @@ -73,7 +73,8 @@ class FileStorage { size: storedStats.size, ext, type: getFileType(ext), - count: 2 + count: 2, + source: 'local' } } } @@ -112,7 +113,8 @@ class FileStorage { size: stats.size, ext: ext, type: fileType, - count: 1 + count: 1, + source: 'local' as const } }) @@ -177,7 +179,8 @@ class FileStorage { size: stats.size, ext: ext, type: fileType, - count: 1 + count: 1, + source: 'local' } return fileMetadata @@ -201,7 +204,8 @@ class FileStorage { size: stats.size, ext: ext, type: fileType, - count: 1 + count: 1, + source: 'local' } return fileInfo @@ -267,6 +271,16 @@ class FileStorage { } } + public base64File = async ( + _: Electron.IpcMainInvokeEvent, + filePath: string + ): Promise<{ data: Buffer; mime: string }> => { + return { + data: await fs.promises.readFile(filePath), + mime: 'application/pdf' + } + } + public binaryFile = async (_: Electron.IpcMainInvokeEvent, id: string): Promise<{ data: Buffer; mime: string }> => { const filePath = path.join(this.storageDir, id) const data = await fs.promises.readFile(filePath) @@ -424,7 +438,8 @@ class FileStorage { size: stats.size, ext: ext, type: fileType, - count: 1 + count: 1, + source: 'local' } return fileMetadata diff --git a/src/main/services/KnowledgeService.ts b/src/main/services/KnowledgeService.ts index d0bd5e6d8..8b65214d1 100644 --- a/src/main/services/KnowledgeService.ts +++ b/src/main/services/KnowledgeService.ts @@ -177,9 +177,10 @@ class KnowledgeService { task: async () => { // 添加OCR预处理逻辑 let fileToProcess: FileType = file - if (base.preprocessing && file.ext.toLowerCase() === '.pdf') { + if (base.preprocessing && base.ocrProvider && file.ext.toLowerCase() === '.pdf') { try { - const ocrProvider = new OcrProvider(base) + file.source = 'local' + const ocrProvider = new OcrProvider(base.ocrProvider) Logger.info(`Starting OCR processing for file: ${file.path}`) const { processedFile } = await ocrProvider.parseFile(item.id, file) @@ -508,6 +509,10 @@ class KnowledgeService { ): Promise => { return await new Reranker(base).rerank(search, results) } + + public getStorageDir = (): string => { + return this.storageDir + } } export default new KnowledgeService() diff --git a/src/main/services/file/FileServiceManager.ts b/src/main/services/file/FileServiceManager.ts index 7ebdbe922..07237ab0b 100644 --- a/src/main/services/file/FileServiceManager.ts +++ b/src/main/services/file/FileServiceManager.ts @@ -6,6 +6,7 @@ export class FileServiceManager { private static instance: FileServiceManager private services: Map = new Map() + // eslint-disable-next-line @typescript-eslint/no-empty-function private constructor() {} static getInstance(): FileServiceManager { diff --git a/src/main/services/file/GeminiService.ts b/src/main/services/file/GeminiService.ts index b832f51e6..abb8bf606 100644 --- a/src/main/services/file/GeminiService.ts +++ b/src/main/services/file/GeminiService.ts @@ -60,41 +60,26 @@ export class GeminiService extends BaseFileService { return cachedResponse } - const response = await this.fileManager.getFile(fileId) + const response = await this.fileManager.listFiles() - // 根据文件状态设置响应状态 - let status: 'success' | 'processing' | 'failed' | 'unknown' - switch (response.state) { - case FileState.ACTIVE: - status = 'success' - break - case FileState.PROCESSING: - status = 'processing' - break - case FileState.FAILED: - status = 'failed' - break - default: - status = 'unknown' + if (response.files) { + const file = response.files.filter((file) => file.state === FileState.ACTIVE).find((file) => file.name === fileId) + if (file) { + return { + fileId: fileId, + displayName: file.displayName || '', + status: 'success', + originalFile: file + } + } } - const fileResponse: FileUploadResponse = { - fileId, - displayName: response.displayName || '', - status, - originalFile: response + return { + fileId: fileId, + displayName: '', + status: 'failed', + originalFile: undefined } - - // 只缓存成功的文件 - if (status === 'success') { - CacheService.set( - `${GeminiService.FILE_LIST_CACHE_KEY}_${fileId}`, - fileResponse, - GeminiService.FILE_CACHE_DURATION - ) - } - - return fileResponse } async listFiles(): Promise { diff --git a/src/main/services/file/MistralService.ts b/src/main/services/file/MistralService.ts index 88df2c918..863d5ec25 100644 --- a/src/main/services/file/MistralService.ts +++ b/src/main/services/file/MistralService.ts @@ -1,11 +1,14 @@ +import fs from 'node:fs/promises' + +import { Mistral } from '@mistralai/mistralai' import { FileListResponse, FileUploadResponse, LocalFileSource } from '@types' -import { fileFrom } from 'fetch-blob/from.js' +import Logger from 'electron-log' import { MistralClientManager } from '../MistralClientManager' import { BaseFileService } from './BaseFileService' export class MistralService extends BaseFileService { - private readonly client + private readonly client: Mistral constructor(apiKey: string) { super(apiKey) @@ -16,9 +19,12 @@ export class MistralService extends BaseFileService { async uploadFile(file: LocalFileSource): Promise { try { - const blob = await fileFrom(file.path) + const fileBuffer = await fs.readFile(file.path) const response = await this.client.files.upload({ - file: blob, + file: { + fileName: file.path, + content: new Uint8Array(fileBuffer) + }, purpose: 'ocr' }) @@ -28,7 +34,7 @@ export class MistralService extends BaseFileService { status: 'success' } } catch (error) { - console.error('Error uploading file:', error) + Logger.error('Error uploading file:', error) return { fileId: '', displayName: file.origin_name, @@ -78,7 +84,12 @@ export class MistralService extends BaseFileService { } } catch (error) { console.error('Error retrieving file:', error) - throw error + return { + fileId: fileId, + displayName: '', + status: 'failed', + originalFile: undefined + } } } } diff --git a/src/main/utils/file.ts b/src/main/utils/file.ts index e447b0b96..e066c3616 100644 --- a/src/main/utils/file.ts +++ b/src/main/utils/file.ts @@ -57,7 +57,8 @@ export function getAllFiles(dirPath: string, arrayOfFiles: FileType[] = []): Fil count: 1, origin_name: name, type: fileType, - created_at: new Date().toISOString() + created_at: new Date().toISOString(), + source: 'local' } arrayOfFiles.push(fileItem) diff --git a/src/preload/index.d.ts b/src/preload/index.d.ts index bcbb211ad..aadd20990 100644 --- a/src/preload/index.d.ts +++ b/src/preload/index.d.ts @@ -1,8 +1,16 @@ import { ElectronAPI } from '@electron-toolkit/preload' -import type { FileMetadataResponse, ListFilesResponse, UploadFileResponse } from '@google/generative-ai/server' import { ExtractChunkData } from '@llm-tools/embedjs-interfaces' import type { MCPServer, MCPTool } from '@renderer/types' -import { AppInfo, FileType, KnowledgeBaseParams, KnowledgeItem, LanguageVarious, WebDavConfig } from '@renderer/types' +import { + AppInfo, + FileListResponse, + FileType, + FileUploadResponse, + KnowledgeBaseParams, + KnowledgeItem, + LanguageVarious, + WebDavConfig +} from '@renderer/types' import type { LoaderReturn } from '@shared/config/types' import type { OpenDialogOptions } from 'electron' import type { UpdateInfo } from 'electron-updater' @@ -66,6 +74,7 @@ declare global { ) => Promise saveImage: (name: string, data: string) => void base64Image: (fileId: string) => Promise<{ mime: string; base64: string; data: string }> + base64File: (filePath: string) => Promise<{ mime: string; data: Buffer }> download: (url: string) => Promise copy: (fileId: string, destPath: string) => Promise binaryFile: (fileId: string) => Promise<{ data: Buffer; mime: string }> @@ -118,9 +127,9 @@ declare global { resetMinimumSize: () => Promise } fileService: { - upload: (type: string, apiKey: string, file: FileType) => Promise - retrieve: (type: string, apiKey: string, fileId: string) => Promise - list: (type: string, apiKey: string) => Promise + upload: (type: string, apiKey: string, file: FileType) => Promise + retrieve: (type: string, apiKey: string, fileId: string) => Promise + list: (type: string, apiKey: string) => Promise delete: (type: string, apiKey: string, fileId: string) => Promise } selectionMenu: { diff --git a/src/preload/index.ts b/src/preload/index.ts index 1b4dfb9da..f129e0229 100644 --- a/src/preload/index.ts +++ b/src/preload/index.ts @@ -53,6 +53,7 @@ const api = { selectFolder: () => ipcRenderer.invoke('file:selectFolder'), saveImage: (name: string, data: string) => ipcRenderer.invoke('file:saveImage', name, data), base64Image: (fileId: string) => ipcRenderer.invoke('file:base64Image', fileId), + base64File: (filePath: string) => ipcRenderer.invoke('file:base64File', filePath), download: (url: string) => ipcRenderer.invoke('file:download', url), copy: (fileId: string, destPath: string) => ipcRenderer.invoke('file:copy', fileId, destPath), binaryFile: (fileId: string) => ipcRenderer.invoke('file:binaryFile', fileId) diff --git a/src/renderer/src/components/Icons/OcrIcon.tsx b/src/renderer/src/components/Icons/OcrIcon.tsx index ae25e2adb..2565d85fa 100644 --- a/src/renderer/src/components/Icons/OcrIcon.tsx +++ b/src/renderer/src/components/Icons/OcrIcon.tsx @@ -10,6 +10,7 @@ const OcrIcon: FC> = (props) => ( width="16" height="16" strokeWidth="1" + fill="var(--color-text-2)" {...props}> > = (props) => ( width="16" height="16" className="icon" + fill="var(--color-text-2)" {...props}> diff --git a/src/renderer/src/config/ocrProviders.ts b/src/renderer/src/config/ocrProviders.ts index 19c538f20..e03b48f5c 100644 --- a/src/renderer/src/config/ocrProviders.ts +++ b/src/renderer/src/config/ocrProviders.ts @@ -1,8 +1,12 @@ import Doc2xLogo from '@renderer/assets/images/ocr/doc2x.svg' +import MistralLogo from '@renderer/assets/images/providers/mistral.png' + export function getOcrProviderLogo(providerId: string) { switch (providerId) { case 'doc2x': return Doc2xLogo + case 'mistral': + return MistralLogo default: return undefined } @@ -14,5 +18,11 @@ export const OCR_PROVIDER_CONFIG = { official: 'https://doc2x.noedgeai.com', apiKey: 'https://open.noedgeai.com/apiKeys' } + }, + mistral: { + websites: { + official: 'https://mistral.ai', + apiKey: 'https://mistral.ai/api-keys' + } } } diff --git a/src/renderer/src/providers/GeminiProvider.ts b/src/renderer/src/providers/GeminiProvider.ts index 06c7d8c03..9891f9d0a 100644 --- a/src/renderer/src/providers/GeminiProvider.ts +++ b/src/renderer/src/providers/GeminiProvider.ts @@ -30,7 +30,17 @@ import { filterEmptyMessages, filterUserRoleStartMessages } from '@renderer/services/MessagesService' -import { Assistant, FileType, FileTypes, MCPToolResponse, Message, Model, Provider, Suggestion } from '@renderer/types' +import { + Assistant, + FileType, + FileTypes, + FileUploadResponse, + MCPToolResponse, + Message, + Model, + Provider, + Suggestion +} from '@renderer/types' import { removeSpecialCharactersForTopicName } from '@renderer/utils' import { fileToBase64 } from '@renderer/utils/file' import { @@ -84,25 +94,28 @@ export default class GeminiProvider extends BaseProvider { } as InlineDataPart } - // Retrieve file from Gemini uploaded files - const fileMetadata = await window.api.fileService.retrieve(this.provider.type, this.apiKey, file.id) - - if (fileMetadata) { + // 尝试检索文件 + const response: FileUploadResponse = await window.api.fileService.retrieve(this.provider.type, this.apiKey, file.id) + if (response && response.status === 'success') { return { fileData: { - fileUri: fileMetadata.uri, - mimeType: fileMetadata.mimeType + fileUri: response.originalFile.uri, + mimeType: response.originalFile.mimeType } } as FileDataPart + } else { + console.log('file not found', response) } - - // If file is not found, upload it to Gemini - const uploadResult = await window.api.fileService.upload(this.provider.type, this.apiKey, file) - + // 如果文件不存在,上传新文件 + const uploadResponse: FileUploadResponse = await window.api.fileService.upload( + this.provider.type, + this.apiKey, + file + ) return { fileData: { - fileUri: uploadResult.file.uri, - mimeType: uploadResult.file.mimeType + fileUri: uploadResponse.originalFile.file.uri, + mimeType: uploadResponse.originalFile.file.mimeType } } as FileDataPart } diff --git a/src/renderer/src/store/migrate.ts b/src/renderer/src/store/migrate.ts index fc9345d56..1c7055b7b 100644 --- a/src/renderer/src/store/migrate.ts +++ b/src/renderer/src/store/migrate.ts @@ -804,6 +804,13 @@ const migrateConfig = { name: 'Doc2x', apiKey: '', apiHost: 'https://v2.doc2x.noedgeai.com' + }, + { + id: 'mistral', + name: 'Mistral', + model: 'mistral-ocr-latest', + apiKey: '', + apiHost: 'https://api.mistral.ai' } ] } diff --git a/src/renderer/src/store/ocr.ts b/src/renderer/src/store/ocr.ts index 859d2ebe1..27c2346c2 100644 --- a/src/renderer/src/store/ocr.ts +++ b/src/renderer/src/store/ocr.ts @@ -13,6 +13,13 @@ const initialState: OcrState = { name: 'Doc2x', apiKey: '', apiHost: 'https://v2.doc2x.noedgeai.com' + }, + { + id: 'mistral', + name: 'Mistral', + model: 'mistral-ocr-latest', + apiKey: '', + apiHost: 'https://api.mistral.ai' } ], defaultProvider: '' diff --git a/src/renderer/src/types/index.ts b/src/renderer/src/types/index.ts index 70c9b8ed4..2b75c3046 100644 --- a/src/renderer/src/types/index.ts +++ b/src/renderer/src/types/index.ts @@ -340,6 +340,7 @@ export type KnowledgeBaseParams = { rerankModelProvider?: string topN?: number preprocessing?: boolean + ocrProvider?: OcrProvider } export interface OcrProvider { diff --git a/src/renderer/src/utils/file.ts b/src/renderer/src/utils/file.ts index 18b1a26e9..02f20dae3 100644 --- a/src/renderer/src/utils/file.ts +++ b/src/renderer/src/utils/file.ts @@ -1,9 +1,7 @@ -import fs from 'fs' - export const fileToBase64 = async (filePath: string) => { - const buffer = await fs.promises.readFile(filePath) + const result = await window.api.file.base64Image(filePath) return { - data: buffer.toString('base64'), - mimeType: 'application/pdf' + data: result.base64, + mimeType: result.mime } }