feat: split ocr and preprocess

This commit is contained in:
eeee0717
2025-06-04 20:45:57 +08:00
parent 88204878b0
commit 62cfccc035
26 changed files with 761 additions and 162 deletions
+200
View File
@@ -0,0 +1,200 @@
import fs from 'node:fs'
import path from 'node:path'
import { windowService } from '@main/services/WindowService'
import { getFileExt } from '@main/utils/file'
import { FileMetadata, OcrProvider } from '@types'
import { createCanvas, loadImage } from 'canvas'
import { app } from 'electron'
import { TypedArray } from 'pdfjs-dist/types/src/display/api'
export default abstract class BaseOcrProvider {
protected provider: OcrProvider
public storageDir = path.join(app.getPath('userData'), 'Data', 'Files')
constructor(provider: OcrProvider) {
if (!provider) {
throw new Error('OCR provider is not set')
}
this.provider = provider
}
abstract parseFile(sourceId: string, file: FileMetadata): Promise<{ processedFile: FileMetadata }>
/**
* 检查文件是否已经被预处理过
* 统一检测方法:如果 Data/Files/{file.id} 是目录,说明已被预处理
* @param file 文件信息
* @returns 如果已处理返回处理后的文件信息,否则返回null
*/
public async checkIfAlreadyProcessed(file: FileMetadata): Promise<FileMetadata | null> {
try {
// 检查 Data/Files/{file.id} 是否是目录
const preprocessDirPath = path.join(this.storageDir, file.id)
if (fs.existsSync(preprocessDirPath)) {
const stats = await fs.promises.stat(preprocessDirPath)
// 如果是目录,说明已经被预处理过
if (stats.isDirectory()) {
// 查找目录中的处理结果文件
const files = await fs.promises.readdir(preprocessDirPath)
// 查找主要的处理结果文件(.md 或 .txt)
const processedFile = files.find((fileName) => fileName.endsWith('.md') || fileName.endsWith('.txt'))
if (processedFile) {
const processedFilePath = path.join(preprocessDirPath, processedFile)
const processedStats = await fs.promises.stat(processedFilePath)
const ext = getFileExt(processedFile)
return {
...file,
name: file.name.replace(file.ext, ext),
path: processedFilePath,
ext: ext,
size: processedStats.size,
created_at: processedStats.birthtime.toISOString()
}
}
}
}
return null
} catch (error) {
// 如果检查过程中出现错误,返回null表示未处理
return null
}
}
/**
* 辅助方法:延迟执行
*/
public delay = (ms: number): Promise<void> => {
return new Promise((resolve) => setTimeout(resolve, ms))
}
public async readPdf(
source: string | URL | TypedArray,
passwordCallback?: (fn: (password: string) => void, reason: string) => string
) {
const { getDocument } = await import('pdfjs-dist/legacy/build/pdf.mjs')
const documentLoadingTask = getDocument(source)
if (passwordCallback) {
documentLoadingTask.onPassword = passwordCallback
}
const document = await documentLoadingTask.promise
return document
}
public async sendOcrProgress(sourceId: string, progress: number): Promise<void> {
const mainWindow = windowService.getMainWindow()
mainWindow?.webContents.send('file-preprocess-progress', {
itemId: sourceId,
progress: progress
})
}
/**
* 将文件移动到附件目录
* @param fileId 文件id
* @param filePaths 需要移动的文件路径数组
* @returns 移动后的文件路径数组
*/
public moveToAttachmentsDir(fileId: string, filePaths: string[]): string[] {
const attachmentsPath = path.join(this.storageDir, fileId)
if (!fs.existsSync(attachmentsPath)) {
fs.mkdirSync(attachmentsPath, { recursive: true })
}
const movedPaths: string[] = []
for (const filePath of filePaths) {
if (fs.existsSync(filePath)) {
const fileName = path.basename(filePath)
const destPath = path.join(attachmentsPath, fileName)
fs.copyFileSync(filePath, destPath)
fs.unlinkSync(filePath) // 删除原文件,实现"移动"
movedPaths.push(destPath)
}
}
return movedPaths
}
public async cropImage(image: Buffer | string) {
const img = await loadImage(image)
const width = img.width
const height = img.height
const canvas = createCanvas(width, height)
const context = canvas.getContext('2d')
context.drawImage(img, 0, 0)
const data = context.getImageData(0, 0, width, height).data
const top = scanY(true)
const bottom = scanY(false)
const left = scanX(true)
const right = scanX(false)
if (top === null || bottom === null || left === null || right === null) {
console.error('image is empty')
return canvas.toBuffer()
}
const new_width = right - left
const new_height = bottom - top
canvas.width = new_width
canvas.height = new_height
context.drawImage(img, left, top, new_width, new_height, 0, 0, new_width, new_height)
return canvas.toBuffer()
// get pixel RGB data:
function getRGB(x: number, y: number) {
return {
red: data[(width * y + x) * 4],
green: data[(width * y + x) * 4 + 1],
blue: data[(width * y + x) * 4 + 2]
}
}
// check if pixel is a color other than white:
function isColor(rgb: { red: number; green: number; blue: number }) {
return rgb.red == 255 && rgb.green == 255 && rgb.blue == 255
}
// scan top and bottom edges of image:
function scanY(top: boolean) {
const offset = top ? 1 : -1
for (let y = top ? 0 : height - 1; top ? y < height : y > -1; y += offset) {
for (let x = 0; x < width; x++) {
if (!isColor(getRGB(x, y))) {
return y
}
}
}
return null
}
// scan left and right edges of image:
function scanX(left: boolean) {
const offset = left ? 1 : -1
for (let x = left ? 0 : width - 1; left ? x < width : x > -1; x += offset) {
for (let y = 0; y < height; y++) {
if (!isColor(getRGB(x, y))) {
return x
}
}
}
return null
}
}
}
+12
View File
@@ -0,0 +1,12 @@
import { FileMetadata, OcrProvider } from '@types'
import BaseOcrProvider from './BaseOcrProvider'
export default class DefaultOcrProvider extends BaseOcrProvider {
constructor(provider: OcrProvider) {
super(provider)
}
public parseFile(): Promise<{ processedFile: FileMetadata }> {
throw new Error('Method not implemented.')
}
}
@@ -1,13 +1,13 @@
import { isMac } from '@main/constant'
import { FileMetadata, PreprocessProvider } from '@types'
import { FileMetadata, OcrProvider } from '@types'
import Logger from 'electron-log'
import * as fs from 'fs'
import * as path from 'path'
import { TextItem } from 'pdfjs-dist/types/src/display/api'
import BasePreprocessProvider from './BasePreprocessProvider'
import BaseOcrProvider from './BaseOcrProvider'
export default class MacSysOcrProvider extends BasePreprocessProvider {
export default class MacSysOcrProvider extends BaseOcrProvider {
private readonly MIN_TEXT_LENGTH = 1000
private MacOCR: any
@@ -32,7 +32,7 @@ export default class MacSysOcrProvider extends BasePreprocessProvider {
return level === 0 ? this.MacOCR.RECOGNITION_LEVEL_FAST : this.MacOCR.RECOGNITION_LEVEL_ACCURATE
}
constructor(provider: PreprocessProvider) {
constructor(provider: OcrProvider) {
super(provider)
}
@@ -61,7 +61,7 @@ export default class MacSysOcrProvider extends BasePreprocessProvider {
writeStream.write(ocrResult.text + '\n')
// Update progress
await this.sendPreprocessProgress(sourceId, (pageNum / totalPages) * 100)
await this.sendOcrProgress(sourceId, (pageNum / totalPages) * 100)
}
}
+23
View File
@@ -0,0 +1,23 @@
import { FileMetadata, PreprocessProvider as Provider } from '@types'
import BaseOcrProvider from './BaseOcrProvider'
import OcrProviderFactory from './OcrProviderFactory'
export default class OcrProvider {
private sdk: BaseOcrProvider
constructor(provider: Provider) {
this.sdk = OcrProviderFactory.create(provider)
}
public async parseFile(sourceId: string, file: FileMetadata): Promise<{ processedFile: FileMetadata }> {
return this.sdk.parseFile(sourceId, file)
}
/**
* 检查文件是否已经被预处理过
* @param file 文件信息
* @returns 如果已处理返回处理后的文件信息,否则返回null
*/
public async checkIfAlreadyProcessed(file: FileMetadata): Promise<FileMetadata | null> {
return this.sdk.checkIfAlreadyProcessed(file)
}
}
+20
View File
@@ -0,0 +1,20 @@
import { isMac } from '@main/constant'
import { OcrProvider } from '@types'
import Logger from 'electron-log'
import BaseOcrProvider from './BaseOcrProvider'
import DefaultOcrProvider from './DefaultOcrProvider'
import MacSysOcrProvider from './MacSysOcrProvider'
export default class OcrProviderFactory {
static create(provider: OcrProvider): BaseOcrProvider {
switch (provider.id) {
case 'system':
if (!isMac) {
Logger.warn('[OCR] System OCR provider is only available on macOS')
}
return new MacSysOcrProvider(provider)
default:
return new DefaultOcrProvider(provider)
}
}
}
@@ -4,7 +4,6 @@ import path from 'node:path'
import { windowService } from '@main/services/WindowService'
import { getFileExt } from '@main/utils/file'
import { FileMetadata, PreprocessProvider } from '@types'
import { createCanvas, loadImage } from 'canvas'
import { app } from 'electron'
import { TypedArray } from 'pdfjs-dist/types/src/display/api'
@@ -120,81 +119,4 @@ export default abstract class BasePreprocessProvider {
}
return movedPaths
}
public async cropImage(image: Buffer | string) {
const img = await loadImage(image)
const width = img.width
const height = img.height
const canvas = createCanvas(width, height)
const context = canvas.getContext('2d')
context.drawImage(img, 0, 0)
const data = context.getImageData(0, 0, width, height).data
const top = scanY(true)
const bottom = scanY(false)
const left = scanX(true)
const right = scanX(false)
if (top === null || bottom === null || left === null || right === null) {
console.error('image is empty')
return canvas.toBuffer()
}
const new_width = right - left
const new_height = bottom - top
canvas.width = new_width
canvas.height = new_height
context.drawImage(img, left, top, new_width, new_height, 0, 0, new_width, new_height)
return canvas.toBuffer()
// get pixel RGB data:
function getRGB(x: number, y: number) {
return {
red: data[(width * y + x) * 4],
green: data[(width * y + x) * 4 + 1],
blue: data[(width * y + x) * 4 + 2]
}
}
// check if pixel is a color other than white:
function isColor(rgb: { red: number; green: number; blue: number }) {
return rgb.red == 255 && rgb.green == 255 && rgb.blue == 255
}
// scan top and bottom edges of image:
function scanY(top: boolean) {
const offset = top ? 1 : -1
for (let y = top ? 0 : height - 1; top ? y < height : y > -1; y += offset) {
for (let x = 0; x < width; x++) {
if (!isColor(getRGB(x, y))) {
return y
}
}
}
return null
}
// scan left and right edges of image:
function scanX(left: boolean) {
const offset = left ? 1 : -1
for (let x = left ? 0 : width - 1; left ? x < width : x > -1; x += offset) {
for (let y = 0; y < height; y++) {
if (!isColor(getRGB(x, y))) {
return x
}
}
}
return null
}
}
}
@@ -58,7 +58,7 @@ export default class MineruPreprocessProvider extends BasePreprocessProvider {
Logger.info(`MinerU processing completed for batch: ${batchId}`)
// 3. 下载并解压文件
const { path: outputPath } = await this.downloadAndExtractFile(extractResult.full_zip_url!, file.path)
const { path: outputPath } = await this.downloadAndExtractFile(extractResult.full_zip_url!, file)
// 4. 创建处理后的文件信息
return {
@@ -125,11 +125,11 @@ export default class MineruPreprocessProvider extends BasePreprocessProvider {
}
}
private async downloadAndExtractFile(zipUrl: string, originalFilePath: string): Promise<{ path: string }> {
private async downloadAndExtractFile(zipUrl: string, file: FileMetadata): Promise<{ path: string }> {
const dirPath = this.storageDir
const baseName = path.basename(originalFilePath, path.extname(originalFilePath))
const zipPath = path.join(dirPath, `${baseName}.zip`)
const extractPath = path.join(dirPath, `${baseName}`)
const zipPath = path.join(dirPath, `${file.id}.zip`)
const extractPath = path.join(dirPath, `${file.id}`)
Logger.info(`Downloading MinerU result to: ${zipPath}`)
@@ -1,11 +1,8 @@
import { isMac } from '@main/constant'
import { PreprocessProvider } from '@types'
import Logger from 'electron-log'
import BasePreprocessProvider from './BasePreprocessProvider'
import DefaultPreprocessProvider from './DefaultPreprocessProvider'
import Doc2xPreprocessProvider from './Doc2xPreprocessProvider'
import MacSysOcrProvider from './MacSysOcrProvider'
import MineruPreprocessProvider from './MineruPreprocessProvider'
import MistralPreprocessProvider from './MistralPreprocessProvider'
export default class PreprocessProviderFactory {
@@ -15,11 +12,6 @@ export default class PreprocessProviderFactory {
return new Doc2xPreprocessProvider(provider)
case 'mistral':
return new MistralPreprocessProvider(provider)
case 'system':
if (!isMac) {
Logger.warn('[OCR] System OCR provider is only available on macOS')
}
return new MacSysOcrProvider(provider)
case 'mineru':
return new MineruPreprocessProvider(provider)
default:
+10 -5
View File
@@ -23,6 +23,7 @@ import { SitemapLoader } from '@cherrystudio/embedjs-loader-sitemap'
import { WebLoader } from '@cherrystudio/embedjs-loader-web'
import Embeddings from '@main/embeddings/Embeddings'
import { addFileLoader } from '@main/loader'
import OcrProvider from '@main/ocr/OcrProvider'
import PreprocessProvider from '@main/preprocess/PreprocessProvider'
import Reranker from '@main/reranker/Reranker'
import { windowService } from '@main/services/WindowService'
@@ -498,12 +499,16 @@ class KnowledgeService {
item: KnowledgeItem
): Promise<FileMetadata> => {
let fileToProcess: FileMetadata = file
if (base.preprocessProvider && file.ext.toLowerCase() === '.pdf') {
if (base.preprocessOrOcrProvider && file.ext.toLowerCase() === '.pdf') {
try {
const preprocessProvider = new PreprocessProvider(base.preprocessProvider)
let provider: PreprocessProvider | OcrProvider
if (base.preprocessOrOcrProvider.type === 'preprocess') {
provider = new PreprocessProvider(base.preprocessOrOcrProvider.provider)
} else {
provider = new OcrProvider(base.preprocessOrOcrProvider.provider)
}
// 首先检查文件是否已经被预处理过
const alreadyProcessed = await preprocessProvider.checkIfAlreadyProcessed(file)
const alreadyProcessed = await provider.checkIfAlreadyProcessed(file)
if (alreadyProcessed) {
Logger.info(`File already preprocess processed, using cached result: ${file.path}`)
return alreadyProcessed
@@ -511,7 +516,7 @@ class KnowledgeService {
// 执行预处理
Logger.info(`Starting preprocess processing for scanned PDF: ${file.path}`)
const { processedFile } = await preprocessProvider.parseFile(item.id, file)
const { processedFile } = await provider.parseFile(item.id, file)
fileToProcess = processedFile
} catch (err) {
Logger.error(`Preprocess processing failed: ${err}`)
+12
View File
@@ -0,0 +1,12 @@
import MacOSLogo from '@renderer/assets/images/providers/macos.svg'
export function getOcrProviderLogo(providerId: string) {
switch (providerId) {
case 'system':
return MacOSLogo
default:
return undefined
}
}
export const OCR_PROVIDER_CONFIG = {}
@@ -1,6 +1,5 @@
import Doc2xLogo from '@renderer/assets/images/ocr/doc2x.png'
import MinerULogo from '@renderer/assets/images/ocr/mineru.jpg'
import MacOSLogo from '@renderer/assets/images/providers/macos.svg'
import MistralLogo from '@renderer/assets/images/providers/mistral.png'
export function getPreprocessProviderLogo(providerId: string) {
@@ -9,8 +8,6 @@ export function getPreprocessProviderLogo(providerId: string) {
return Doc2xLogo
case 'mistral':
return MistralLogo
case 'system':
return MacOSLogo
case 'mineru':
return MinerULogo
default:
+1 -5
View File
@@ -145,11 +145,7 @@ export const useKnowledge = (baseId: string) => {
}
}
if (item.type === 'file' && typeof item.content === 'object') {
// await FileManager.deleteFile(item.content.id)
// file name 1.xxx.pdf
// get 1.xxx
// remove .pdf ext
await window.api.file.deleteDir(item.content.name.split('.').slice(0, -1).join('.'))
await window.api.file.deleteDir(item.content.id)
}
}
// 刷新项目
+45
View File
@@ -0,0 +1,45 @@
import { RootState } from '@renderer/store'
import {
setDefaultOcrProvider as _setDefaultOcrProvider,
updateOcrProvider as _updateOcrProvider,
updateOcrProviders as _updateOcrProviders
} from '@renderer/store/ocr'
import { OcrProvider } from '@renderer/types'
import { useDispatch, useSelector } from 'react-redux'
export const useOcrProvider = (id: string) => {
const dispatch = useDispatch()
const ocrProviders = useSelector((state: RootState) => state.ocr.providers)
const provider = ocrProviders.find((provider) => provider.id === id)
if (!provider) {
throw new Error(`OCR provider with id ${id} not found`)
}
const updateOcrProvider = (ocrProvider: OcrProvider) => {
dispatch(_updateOcrProvider(ocrProvider))
}
return { provider, updateOcrProvider }
}
export const useOcrProviders = () => {
const dispatch = useDispatch()
const ocrProviders = useSelector((state: RootState) => state.ocr.providers)
return {
ocrProviders: ocrProviders,
updateOcrProviders: (ocrProviders: OcrProvider[]) => dispatch(_updateOcrProviders(ocrProviders))
}
}
export const useDefaultOcrProvider = () => {
const defaultProviderId = useSelector((state: RootState) => state.ocr.defaultProvider)
const { ocrProviders } = useOcrProviders()
const dispatch = useDispatch()
const provider = defaultProviderId ? ocrProviders.find((provider) => provider.id === defaultProviderId) : undefined
const setDefaultOcrProvider = (ocrProvider: OcrProvider) => {
dispatch(_setDefaultOcrProvider(ocrProvider.id))
}
const updateDefaultOcrProvider = (ocrProvider: OcrProvider) => {
dispatch(_updateOcrProvider(ocrProvider))
}
return { provider, setDefaultOcrProvider, updateDefaultOcrProvider }
}
+7
View File
@@ -1726,6 +1726,13 @@
"title": "Pre Process",
"provider": "Pre Process Provider",
"provider_placeholder": "Choose a Pre Process provider",
"preprocess_tooltip": "Setting the document preprocessing service provider in Settings -> Tools can effectively improve the retrieval performance for complex document formats."
},
"ocr": {
"title": "OCR",
"provider": "OCR Provider",
"provider_placeholder": "Choose an OCR provider",
"ocr_tooltip": "Setting the OCR service provider in Settings -> Tools can effectively improve the text recognition performance for images and documents.",
"mac_system_ocr_options": {
"mode": {
"title": "Recognition Mode",
+8 -1
View File
@@ -1772,6 +1772,13 @@
"title": "文档预处理",
"provider": "文档预处理",
"provider_placeholder": "选择一个文档预处理服务商",
"preprocess_tooltip": "在设置 -> 工具中设置文档预处理服务商,可以有效提升复杂文档格式的检索效果"
},
"ocr": {
"title": "OCR 设置",
"provider": "OCR 服务商",
"provider_placeholder": "选择一个 OCR 服务商",
"ocr_tooltip": "在设置 -> 工具中设置 OCR 服务商,可以有效提升图片文字识别效果",
"mac_system_ocr_options": {
"mode": {
"title": "识别模式",
@@ -2042,4 +2049,4 @@
}
}
}
}
}
@@ -6,12 +6,13 @@ import { isEmbeddingModel, isRerankModel } from '@renderer/config/models'
import { NOT_SUPPORTED_REANK_PROVIDERS } from '@renderer/config/providers'
// import { SUPPORTED_REANK_PROVIDERS } from '@renderer/config/providers'
import { useKnowledgeBases } from '@renderer/hooks/useKnowledge'
import { useOcrProviders } from '@renderer/hooks/useOcr'
import { usePreprocessProviders } from '@renderer/hooks/usePreprocess'
import { useProviders } from '@renderer/hooks/useProvider'
import AiProvider from '@renderer/providers/AiProvider'
import { getKnowledgeBaseParams } from '@renderer/services/KnowledgeService'
import { getModelUniqId } from '@renderer/services/ModelService'
import { KnowledgeBase, Model, PreprocessProvider } from '@renderer/types'
import { KnowledgeBase, Model, OcrProvider, PreprocessProvider } from '@renderer/types'
import { getErrorMessage } from '@renderer/utils/error'
import { Alert, Input, InputNumber, Modal, Select, Slider, Tabs, TabsProps, Tooltip } from 'antd'
import { find, sortBy } from 'lodash'
@@ -37,7 +38,8 @@ const PopupContainer: React.FC<Props> = ({ title, resolve }) => {
const [newBase, setNewBase] = useState<KnowledgeBase>({} as KnowledgeBase)
const { preprocessProviders } = usePreprocessProviders()
const [selectedProvider, setSelectedProvider] = useState<PreprocessProvider | undefined>(undefined)
const { ocrProviders } = useOcrProviders()
const [selectedProvider, setSelectedProvider] = useState<PreprocessProvider | OcrProvider | undefined>(undefined)
const embeddingModels = useMemo(() => {
return providers
@@ -89,6 +91,20 @@ const PopupContainer: React.FC<Props> = ({ title, resolve }) => {
.filter((group) => group.options.length > 0)
}, [providers, t])
const preprocessOrOcrSelectOptions = useMemo(() => {
const preprocessOptions = {
label: t('settings.tool.preprocess.provider'),
title: t('settings.tool.preprocess.provider'),
options: preprocessProviders.filter((p) => p.apiKey !== '').map((p) => ({ value: p.id, label: p.name }))
}
const ocrOptions = {
label: t('settings.tool.ocr.provider'),
title: t('settings.tool.ocr.provider'),
options: ocrProviders.filter((p) => p.apiKey !== '').map((p) => ({ value: p.id, label: p.name }))
}
return [preprocessOptions, ocrOptions]
}, [ocrProviders, preprocessProviders])
const onOk = async () => {
try {
// const values = await form.validateFields()
@@ -167,6 +183,44 @@ const PopupContainer: React.FC<Props> = ({ title, resolve }) => {
/>
</SettingsItem>
<SettingsItem>
<div className="settings-label">
{t('settings.tool.preprocess.title')}
<Tooltip title={t('settings.tool.preprocess.preprocess_tooltip')} placement="right">
<InfoCircleOutlined style={{ marginLeft: 8 }} />
</Tooltip>
</div>
<Select
value={selectedProvider?.id}
style={{ width: '100%' }}
onChange={(value: string) => {
const type = preprocessProviders.find((p) => p.id === value) ? 'preprocess' : 'ocr'
const provider = (type === 'preprocess' ? preprocessProviders : ocrProviders).find(
(p) => p.id === value
)
if (!provider) {
setSelectedProvider(undefined)
setNewBase({
...newBase,
preprocessOrOcrProvider: undefined
})
return
}
setSelectedProvider(provider)
setNewBase({
...newBase,
preprocessOrOcrProvider: {
type: type,
provider: provider
}
})
}}
placeholder={t('settings.tool.preprocess.provider_placeholder')}
options={preprocessOrOcrSelectOptions}
allowClear
/>
</SettingsItem>
<SettingsItem>
<div className="settings-label">
{t('models.embedding_model')}
@@ -202,22 +256,6 @@ const PopupContainer: React.FC<Props> = ({ title, resolve }) => {
/>
</SettingsItem>
<SettingsItem>
<div className="settings-label">{t('settings.tool.preprocess.title')}</div>
<Select
value={selectedProvider?.id}
style={{ width: '100%' }}
onChange={(value: string) => {
const provider = preprocessProviders.find((p) => p.id === value)
setSelectedProvider(provider)
setNewBase({ ...newBase, preprocessProvider: provider })
}}
placeholder={t('settings.tool.preprocess.provider_placeholder')}
options={preprocessProviders.filter((p) => p.apiKey !== '').map((p) => ({ value: p.id, label: p.name }))}
allowClear
/>
</SettingsItem>
<SettingsItem>
<div className="settings-label">
{t('knowledge.document_count')}
@@ -4,6 +4,7 @@ import { DEFAULT_KNOWLEDGE_DOCUMENT_COUNT } from '@renderer/config/constant'
import { getEmbeddingMaxContext } from '@renderer/config/embedings'
import { isEmbeddingModel, isRerankModel } from '@renderer/config/models'
import { useKnowledge } from '@renderer/hooks/useKnowledge'
import { useOcrProviders } from '@renderer/hooks/useOcr'
import { usePreprocessProviders } from '@renderer/hooks/usePreprocess'
import { useProviders } from '@renderer/hooks/useProvider'
import { getModelUniqId } from '@renderer/services/ModelService'
@@ -24,7 +25,11 @@ interface Props extends ShowParams {
const PopupContainer: React.FC<Props> = ({ base: _base, resolve }) => {
const { preprocessProviders } = usePreprocessProviders()
const [selectedProvider, setSelectedProvider] = useState<PreprocessProvider | undefined>(_base.preprocessProvider)
const { ocrProviders } = useOcrProviders()
const [selectedProvider, setSelectedProvider] = useState<PreprocessProvider | undefined>(
_base.preprocessOrOcrProvider?.provider
)
const [open, setOpen] = useState(true)
const { t } = useTranslation()
@@ -65,6 +70,19 @@ const PopupContainer: React.FC<Props> = ({ base: _base, resolve }) => {
}))
.filter((group) => group.options.length > 0)
const preprocessOptions = {
label: t('settings.tool.preprocess.provider'),
title: t('settings.tool.preprocess.provider'),
options: preprocessProviders.filter((p) => p.apiKey !== '').map((p) => ({ value: p.id, label: p.name }))
}
const ocrOptions = {
label: t('settings.tool.ocr.provider'),
title: t('settings.tool.ocr.provider'),
options: ocrProviders.filter((p) => p.apiKey !== '').map((p) => ({ value: p.id, label: p.name }))
}
const preprocessOrOcrSelectOptions = [preprocessOptions, ocrOptions].filter((group) => group.options.length > 0)
const onOk = async () => {
try {
console.log('newbase', newBase)
@@ -99,6 +117,44 @@ const PopupContainer: React.FC<Props> = ({ base: _base, resolve }) => {
/>
</SettingsItem>
<SettingsItem>
<div className="settings-label">
{t('settings.tool.preprocess.title')}
<Tooltip title={t('settings.tool.preprocess.preprocess_tooltip')} placement="right">
<InfoCircleOutlined style={{ marginLeft: 8 }} />
</Tooltip>
</div>
<Select
value={selectedProvider?.id}
style={{ width: '100%' }}
onChange={(value: string) => {
const type = preprocessProviders.find((p) => p.id === value) ? 'preprocess' : 'ocr'
const provider = (type === 'preprocess' ? preprocessProviders : ocrProviders).find(
(p) => p.id === value
)
if (!provider) {
setSelectedProvider(undefined)
setNewBase({
...newBase,
preprocessOrOcrProvider: undefined
})
return
}
setSelectedProvider(provider)
setNewBase({
...newBase,
preprocessOrOcrProvider: {
type: type,
provider: provider
}
})
}}
placeholder={t('settings.tool.preprocess.provider_placeholder')}
options={preprocessOrOcrSelectOptions}
allowClear
/>
</SettingsItem>
<SettingsItem>
<div className="settings-label">
{t('models.embedding_model')}
@@ -137,22 +193,6 @@ const PopupContainer: React.FC<Props> = ({ base: _base, resolve }) => {
/>
</SettingsItem>
<SettingsItem>
<div className="settings-label">{t('settings.tool.preprocess.title')}</div>
<Select
value={selectedProvider?.id}
style={{ width: '100%' }}
onChange={(value: string) => {
const provider = preprocessProviders.find((p) => p.id === value)
setSelectedProvider(provider)
setNewBase({ ...newBase, preprocessProvider: provider })
}}
placeholder={t('settings.tool.preprocess.provider_placeholder')}
options={preprocessProviders.filter((p) => p.apiKey !== '').map((p) => ({ value: p.id, label: p.name }))}
allowClear
/>
</SettingsItem>
<SettingsItem>
<div className="settings-label">
{t('knowledge.document_count')}
@@ -0,0 +1,168 @@
import { ExportOutlined } from '@ant-design/icons'
import { getOcrProviderLogo, OCR_PROVIDER_CONFIG } from '@renderer/config/ocrProviders'
import { useOcrProvider } from '@renderer/hooks/useOcr'
import { formatApiKeys } from '@renderer/services/ApiService'
import { OcrProvider } from '@renderer/types'
import { hasObjectKey } from '@renderer/utils'
import { Avatar, Divider, Flex, Input, InputNumber, Segmented } from 'antd'
import Link from 'antd/es/typography/Link'
import { FC, useEffect, useState } from 'react'
import { useTranslation } from 'react-i18next'
import styled from 'styled-components'
import {
SettingDivider,
SettingHelpLink,
SettingHelpText,
SettingHelpTextRow,
SettingRow,
SettingRowTitle,
SettingSubtitle,
SettingTitle
} from '../..'
interface Props {
provider: OcrProvider
}
const OcrProviderSettings: FC<Props> = ({ provider: _provider }) => {
const { provider: ocrProvider, updateOcrProvider } = useOcrProvider(_provider.id)
const { t } = useTranslation()
const [apiKey, setApiKey] = useState(ocrProvider.apiKey || '')
const [apiHost, setApiHost] = useState(ocrProvider.apiHost || '')
const [options, setOptions] = useState(ocrProvider.options || {})
const ocrProviderConfig = OCR_PROVIDER_CONFIG[ocrProvider.id]
const apiKeyWebsite = ocrProviderConfig?.websites?.apiKey
const officialWebsite = ocrProviderConfig?.websites?.official
useEffect(() => {
setApiKey(ocrProvider.apiKey ?? '')
setApiHost(ocrProvider.apiHost ?? '')
setOptions(ocrProvider.options ?? {})
}, [ocrProvider.apiKey, ocrProvider.apiHost, ocrProvider.options])
const onUpdateApiKey = () => {
if (apiKey !== ocrProvider.apiKey) {
updateOcrProvider({ ...ocrProvider, apiKey })
}
}
const onUpdateApiHost = () => {
let trimmedHost = apiHost?.trim() || ''
if (trimmedHost.endsWith('/')) {
trimmedHost = trimmedHost.slice(0, -1)
}
if (trimmedHost !== ocrProvider.apiHost) {
updateOcrProvider({ ...ocrProvider, apiHost: trimmedHost })
} else {
setApiHost(ocrProvider.apiHost || '')
}
}
const onUpdateOptions = (key: string, value: any) => {
const newOptions = { ...options, [key]: value }
setOptions(newOptions)
updateOcrProvider({ ...ocrProvider, options: newOptions })
}
return (
<>
<SettingTitle>
<Flex align="center" gap={8}>
<ProviderLogo shape="square" src={getOcrProviderLogo(ocrProvider.id)} size={16} />
<ProviderName> {ocrProvider.name}</ProviderName>
{officialWebsite && ocrProviderConfig?.websites && (
<Link target="_blank" href={ocrProviderConfig.websites.official}>
<ExportOutlined style={{ color: 'var(--color-text)', fontSize: '12px' }} />
</Link>
)}
</Flex>
</SettingTitle>
<Divider style={{ width: '100%', margin: '10px 0' }} />
{hasObjectKey(ocrProvider, 'apiKey') && (
<>
<SettingSubtitle style={{ marginTop: 5, marginBottom: 10 }}>{t('settings.provider.api_key')}</SettingSubtitle>
<Flex gap={8}>
<Input.Password
value={apiKey}
placeholder={t('settings.provider.api_key')}
onChange={(e) => setApiKey(formatApiKeys(e.target.value))}
onBlur={onUpdateApiKey}
spellCheck={false}
type="password"
autoFocus={apiKey === ''}
/>
</Flex>
<SettingHelpTextRow style={{ justifyContent: 'space-between', marginTop: 5 }}>
<SettingHelpLink target="_blank" href={apiKeyWebsite}>
{t('settings.provider.get_api_key')}
</SettingHelpLink>
<SettingHelpText>{t('settings.provider.api_key.tip')}</SettingHelpText>
</SettingHelpTextRow>
</>
)}
{hasObjectKey(ocrProvider, 'apiHost') && (
<>
<SettingSubtitle style={{ marginTop: 5, marginBottom: 10 }}>
{t('settings.provider.api_host')}
</SettingSubtitle>
<Flex>
<Input
value={apiHost}
placeholder={t('settings.provider.api_host')}
onChange={(e) => setApiHost(e.target.value)}
onBlur={onUpdateApiHost}
/>
</Flex>
</>
)}
{hasObjectKey(ocrProvider, 'options') && ocrProvider.id === 'system' && (
<>
<SettingDivider style={{ marginTop: 15, marginBottom: 12 }} />
<SettingRow>
<SettingRowTitle>{t('settings.tool.ocr.mac_system_ocr_options.mode.title')}</SettingRowTitle>
<Segmented
options={[
{
label: t('settings.tool.ocr.mac_system_ocr_options.mode.accurate'),
value: 1
},
{
label: t('settings.tool.ocr.mac_system_ocr_options.mode.fast'),
value: 0
}
]}
value={options.recognitionLevel}
onChange={(value) => onUpdateOptions('recognitionLevel', value)}
/>
</SettingRow>
<SettingDivider style={{ marginTop: 15, marginBottom: 12 }} />
<SettingRow>
<SettingRowTitle>{t('settings.tool.ocr.mac_system_ocr_options.min_confidence')}</SettingRowTitle>
<InputNumber
value={options.minConfidence}
onChange={(value) => onUpdateOptions('minConfidence', value)}
min={0}
max={1}
step={0.1}
/>
</SettingRow>
</>
)}
</>
)
}
const ProviderName = styled.span`
font-size: 14px;
font-weight: 500;
`
const ProviderLogo = styled(Avatar)`
border: 0.5px solid var(--color-border);
`
export default OcrProviderSettings
@@ -0,0 +1,58 @@
import { isMac } from '@renderer/config/constant'
import { useTheme } from '@renderer/context/ThemeProvider'
import { useDefaultOcrProvider, useOcrProviders } from '@renderer/hooks/useOcr'
import { PreprocessProvider } from '@renderer/types'
import { Select } from 'antd'
import { FC, useState } from 'react'
import { useTranslation } from 'react-i18next'
import { SettingContainer, SettingDivider, SettingGroup, SettingRow, SettingRowTitle, SettingTitle } from '../..'
import OcrProviderSettings from './OcrSettings'
const OcrSettings: FC = () => {
const { ocrProviders } = useOcrProviders()
const { provider: defaultProvider, setDefaultOcrProvider } = useDefaultOcrProvider()
const { t } = useTranslation()
const [selectedProvider, setSelectedProvider] = useState<PreprocessProvider | undefined>(defaultProvider)
const { theme: themeMode } = useTheme()
function updateSelectedOcrProvider(providerId: string) {
const provider = ocrProviders.find((p) => p.id === providerId)
if (!provider) {
return
}
setDefaultOcrProvider(provider)
setSelectedProvider(provider)
}
return (
<SettingContainer theme={themeMode}>
<SettingGroup theme={themeMode}>
<SettingTitle>{t('settings.tool.ocr.title')}</SettingTitle>
<SettingDivider />
<SettingRow>
<SettingRowTitle>{t('settings.tool.ocr.provider')}</SettingRowTitle>
<div style={{ display: 'flex', gap: '8px' }}>
<Select
value={selectedProvider?.id}
style={{ width: '200px' }}
onChange={(value: string) => updateSelectedOcrProvider(value)}
placeholder={t('settings.tool.ocr.provider_placeholder')}
options={ocrProviders.map((p) => ({
value: p.id,
label: p.name,
disabled: !isMac && p.id === 'system' // 在非 Mac 系统下禁用 system 选项
}))}
/>
</div>
</SettingRow>
</SettingGroup>
{selectedProvider && (
<SettingGroup theme={themeMode}>
<OcrProviderSettings provider={selectedProvider} />
</SettingGroup>
)}
</SettingContainer>
)
}
export default OcrSettings
@@ -1,4 +1,5 @@
import { GlobalOutlined } from '@ant-design/icons'
import OcrIcon from '@renderer/components/Icons/OcrIcon'
import { HStack } from '@renderer/components/Layout'
import ListItem from '@renderer/components/ListItem'
import { FileCode } from 'lucide-react'
@@ -6,6 +7,7 @@ import { FC, useState } from 'react'
import { useTranslation } from 'react-i18next'
import styled from 'styled-components'
import OcrSettings from './OcrSettings'
import PreprocessSettings from './PreprocessSettings'
import WebSearchSettings from './WebSearchSettings'
@@ -14,8 +16,8 @@ const ToolSettings: FC = () => {
const [menu, setMenu] = useState<string>('web-search')
const menuItems = [
{ key: 'web-search', title: 'settings.tool.websearch.title', icon: <GlobalOutlined style={{ fontSize: 16 }} /> },
{ key: 'preprocess', title: 'settings.tool.preprocess.title', icon: <FileCode size={16} /> }
// { key: 'ocr', title: 'settings.tool.ocr.title', icon: <OcrIcon /> }
{ key: 'preprocess', title: 'settings.tool.preprocess.title', icon: <FileCode size={16} /> },
{ key: 'ocr', title: 'settings.tool.ocr.title', icon: <OcrIcon /> }
]
return (
<Container>
@@ -33,7 +35,7 @@ const ToolSettings: FC = () => {
</MenuList>
{menu == 'web-search' && <WebSearchSettings />}
{menu == 'preprocess' && <PreprocessSettings />}
{/* {menu == 'ocr' && <OcrSettings />} */}
{menu == 'ocr' && <OcrSettings />}
</Container>
)
}
@@ -51,7 +51,7 @@ export const getKnowledgeBaseParams = (base: KnowledgeBase): KnowledgeBaseParams
rerankModelProvider: base.rerankModel?.provider,
// topN: base.topN,
// preprocessing: base.preprocessing,
preprocessProvider: base.preprocessProvider
preprocessOrOcrProvider: base.preprocessOrOcrProvider
}
}
+3 -1
View File
@@ -17,6 +17,7 @@ import migrate from './migrate'
import minapps from './minapps'
import newMessagesReducer from './newMessage'
import nutstore from './nutstore'
import ocr from './ocr'
import paintings from './paintings'
import preprocess from './preprocess'
import runtime from './runtime'
@@ -34,6 +35,7 @@ const rootReducer = combineReducers({
llm,
settings,
runtime,
ocr,
shortcuts,
knowledge,
minapps,
@@ -52,7 +54,7 @@ const persistedReducer = persistReducer(
{
key: 'cherry-studio',
storage,
version: 109,
version: 110,
blacklist: ['runtime', 'messages', 'messageBlocks'],
migrate
},
+2 -2
View File
@@ -1505,8 +1505,8 @@ const migrateConfig = {
}
]
}
if (!state.preprocess.providers.find((provider) => provider.id === 'system')) {
state.preprocess.providers.push({
if (!state.ocr.providers.find((provider) => provider.id === 'system')) {
state.ocr.providers.push({
id: 'system',
name: 'System(Mac Only)',
options: {
+46
View File
@@ -0,0 +1,46 @@
import { createSlice, PayloadAction } from '@reduxjs/toolkit'
import { OcrProvider } from '@renderer/types'
export interface OcrState {
providers: OcrProvider[]
defaultProvider: string
}
const initialState: OcrState = {
providers: [
{
id: 'system',
name: 'System(Mac Only)',
options: {
recognitionLevel: 0,
minConfidence: 0.5
}
}
],
defaultProvider: ''
}
const ocrSlice = createSlice({
name: 'ocr',
initialState,
reducers: {
setDefaultOcrProvider(state, action: PayloadAction<string>) {
state.defaultProvider = action.payload
},
setOcrProviders(state, action: PayloadAction<OcrProvider[]>) {
state.providers = action.payload
},
updateOcrProviders(state, action: PayloadAction<OcrProvider[]>) {
state.providers = action.payload
},
updateOcrProvider(state, action: PayloadAction<OcrProvider>) {
const index = state.providers.findIndex((provider) => provider.id === action.payload.id)
if (index !== -1) {
state.providers[index] = action.payload
}
}
}
})
export const { updateOcrProviders, updateOcrProvider, setDefaultOcrProvider, setOcrProviders } = ocrSlice.actions
export default ocrSlice.reducer
-8
View File
@@ -26,14 +26,6 @@ const initialState: PreprocessState = {
model: 'mistral-ocr-latest',
apiKey: '',
apiHost: 'https://api.mistral.ai'
},
{
id: 'system',
name: 'System(Mac Only)',
options: {
recognitionLevel: 0,
minConfidence: 0.5
}
}
],
defaultProvider: 'mineru'
+17 -2
View File
@@ -382,7 +382,10 @@ export interface KnowledgeBase {
rerankModel?: Model
// topN?: number
// preprocessing?: boolean
preprocessProvider?: PreprocessProvider
preprocessOrOcrProvider?: {
type: 'preprocess' | 'ocr'
provider: PreprocessProvider | OcrProvider
}
}
export type KnowledgeBaseParams = {
@@ -400,7 +403,10 @@ export type KnowledgeBaseParams = {
rerankModelProvider?: string
documentCount?: number
// preprocessing?: boolean
preprocessProvider?: PreprocessProvider
preprocessOrOcrProvider?: {
type: 'preprocess' | 'ocr'
provider: PreprocessProvider | OcrProvider
}
}
export interface PreprocessProvider {
@@ -412,6 +418,15 @@ export interface PreprocessProvider {
options?: any
}
export interface OcrProvider {
id: string
name: string
apiKey?: string
apiHost?: string
model?: string
options?: any
}
export type GenerateImageParams = {
model: string
prompt: string