Merge branch 'feat-knowlege-ocr' of https://github.com/eeee0717/cherry-studio into feat/ocr

This commit is contained in:
suyao
2025-06-03 14:58:21 +08:00
15 changed files with 426 additions and 65 deletions
+1 -1
View File
@@ -1,6 +1,6 @@
{
"name": "CherryStudio",
"version": "1.3.12",
"version": "1.5.0-rc.1",
"private": true,
"description": "A powerful AI assistant for producer.",
"main": "./out/main/index.js",
+5
View File
@@ -295,6 +295,11 @@ export default class Doc2xOcrProvider extends BaseOcrProvider {
const response = await axios.get(url, { responseType: 'arraybuffer' })
fs.writeFileSync(zipPath, response.data)
// 确保提取目录存在
if (!fs.existsSync(extractPath)) {
fs.mkdirSync(extractPath, { recursive: true })
}
// 解压文件
const zip = new streamZip({ file: zipPath })
zip.extract(null, extractPath, (err) => {
+333
View File
@@ -0,0 +1,333 @@
import fs from 'node:fs'
import path from 'node:path'
import { FileType, OcrProvider } from '@types'
import AdmZip from 'adm-zip'
import axios from 'axios'
import Logger from 'electron-log'
import BaseOcrProvider from './BaseOcrProvider'
type ApiResponse<T> = {
code: number
data: T
msg?: string
trace_id?: string
}
type BatchUploadResponse = {
batch_id: string
file_urls: string[]
}
type ExtractProgress = {
extracted_pages: number
total_pages: number
start_time: string
}
type ExtractFileResult = {
file_name: string
state: 'done' | 'waiting-file' | 'pending' | 'running' | 'converting' | 'failed'
err_msg: string
full_zip_url?: string
extract_progress?: ExtractProgress
}
type ExtractResultResponse = {
batch_id: string
extract_result: ExtractFileResult[]
}
export default class MineruOcrProvider extends BaseOcrProvider {
constructor(provider: OcrProvider) {
super(provider)
}
public async parseFile(sourceId: string, file: FileType): Promise<{ processedFile: FileType }> {
try {
Logger.info(`MinerU OCR processing started: ${file.path}`)
await this.validateFile(file.path)
// 1. 获取上传URL并上传文件
const batchId = await this.uploadFile(file)
Logger.info(`MinerU file upload completed: batch_id=${batchId}`)
// 2. 等待处理完成并获取结果
const extractResult = await this.waitForCompletion(sourceId, batchId, file.origin_name)
Logger.info(`MinerU processing completed for batch: ${batchId}`)
// 3. 下载并解压文件
const { path: outputPath } = await this.downloadAndExtractFile(extractResult.full_zip_url!, file.path)
// 4. 创建处理后的文件信息
return {
processedFile: this.createProcessedFileInfo(file, outputPath)
}
} catch (error: any) {
Logger.error(`MinerU OCR processing failed for ${file.path}: ${error.message}`)
throw new Error(`OCR processing failed: ${error.message}`)
}
}
private async validateFile(filePath: string): Promise<void> {
const pdfBuffer = await fs.promises.readFile(filePath)
const doc = await this.readPdf(new Uint8Array(pdfBuffer))
// 文件页数小于600页
if (doc.numPages >= 600) {
throw new Error(`PDF page count (${doc.numPages}) exceeds the limit of 600 pages`)
}
// 文件大小小于200MB
if (pdfBuffer.length >= 200 * 1024 * 1024) {
const fileSizeMB = Math.round(pdfBuffer.length / (1024 * 1024))
throw new Error(`PDF file size (${fileSizeMB}MB) exceeds the limit of 200MB`)
}
}
private createProcessedFileInfo(file: FileType, outputPath: string): FileType {
// 查找解压后的主要文件
let finalPath = ''
let finalName = file.origin_name.replace('.pdf', '.md')
try {
const files = fs.readdirSync(outputPath)
const mdFile = files.find((f) => f.endsWith('.md'))
if (mdFile) {
const originalMdPath = path.join(outputPath, mdFile)
const newMdPath = path.join(outputPath, finalName)
// 重命名文件为原始文件名
try {
fs.renameSync(originalMdPath, newMdPath)
finalPath = newMdPath
Logger.info(`Renamed markdown file from ${mdFile} to ${finalName}`)
} catch (renameError) {
Logger.warn(`Failed to rename file ${mdFile} to ${finalName}: ${renameError}`)
// 如果重命名失败,使用原文件
finalPath = originalMdPath
finalName = mdFile
}
}
} catch (error) {
Logger.warn(`Failed to read output directory ${outputPath}: ${error}`)
finalPath = path.join(outputPath, `${file.id}.md`)
}
return {
...file,
name: finalName,
path: finalPath,
ext: '.md',
size: fs.existsSync(finalPath) ? fs.statSync(finalPath).size : 0
}
}
private async downloadAndExtractFile(zipUrl: string, originalFilePath: string): Promise<{ path: string }> {
const dirPath = path.dirname(originalFilePath)
const baseName = path.basename(originalFilePath, path.extname(originalFilePath))
const zipPath = path.join(dirPath, `${baseName}.zip`)
const extractPath = path.join(dirPath, `${baseName}`)
Logger.info(`Downloading MinerU result to: ${zipPath}`)
try {
// 下载ZIP文件
const response = await axios.get(zipUrl, { responseType: 'arraybuffer' })
fs.writeFileSync(zipPath, response.data)
Logger.info(`Downloaded ZIP file: ${zipPath}`)
// 确保提取目录存在
if (!fs.existsSync(extractPath)) {
fs.mkdirSync(extractPath, { recursive: true })
}
// 解压文件
const zip = new AdmZip(zipPath)
zip.extractAllTo(extractPath, true)
Logger.info(`Extracted files to: ${extractPath}`)
// 删除临时ZIP文件
fs.unlinkSync(zipPath)
return { path: extractPath }
} catch (error) {
Logger.error(`Failed to download and extract file: ${error instanceof Error ? error.message : String(error)}`)
throw new Error('Failed to download and extract file')
}
}
private async uploadFile(file: FileType): Promise<string> {
try {
// 步骤1: 获取上传URL
const { batchId, fileUrls } = await this.getBatchUploadUrls(file)
Logger.info(`Got upload URLs for batch: ${batchId}`)
console.log('batchId:', batchId, 'fileurls:', fileUrls)
// 步骤2: 上传文件到获取的URL
await this.putFileToUrl(file.path, fileUrls[0])
Logger.info(`File uploaded successfully: ${file.path}`)
return batchId
} catch (error) {
Logger.error(`Failed to upload file ${file.path}: ${error instanceof Error ? error.message : String(error)}`)
throw new Error('Failed to upload file')
}
}
private async getBatchUploadUrls(file: FileType): Promise<{ batchId: string; fileUrls: string[] }> {
const endpoint = `${this.provider.apiHost}/api/v4/file-urls/batch`
const payload = {
language: 'auto',
enable_formula: true,
enable_table: true,
files: [
{
name: file.origin_name,
is_ocr: true,
data_id: file.id
}
]
}
try {
const response = await fetch(endpoint, {
method: 'POST',
headers: {
'Content-Type': 'application/json',
Authorization: `Bearer ${this.provider.apiKey}`
},
body: JSON.stringify(payload)
})
if (response.ok) {
const data: ApiResponse<BatchUploadResponse> = await response.json()
if (data.code === 0 && data.data) {
const { batch_id, file_urls } = data.data
return {
batchId: batch_id,
fileUrls: file_urls
}
} else {
throw new Error(`API returned error: ${data.msg || JSON.stringify(data)}`)
}
} else {
throw new Error(`HTTP ${response.status}: ${response.statusText}`)
}
} catch (error) {
Logger.error(`Failed to get batch upload URLs: ${error instanceof Error ? error.message : String(error)}`)
throw new Error('Failed to get upload URLs')
}
}
private async putFileToUrl(filePath: string, uploadUrl: string): Promise<void> {
try {
const fileBuffer = await fs.promises.readFile(filePath)
const response = await fetch(uploadUrl, {
method: 'PUT',
body: fileBuffer,
headers: {
'Content-Length': fileBuffer.length.toString()
}
})
if (!response.ok) {
throw new Error(`Upload failed with status ${response.status}: ${response.statusText}`)
}
Logger.info(`File uploaded successfully to: ${uploadUrl}`)
} catch (error) {
Logger.error(
`Failed to upload file to URL ${uploadUrl}: ${error instanceof Error ? error.message : String(error)}`
)
throw new Error('Failed to upload file to provided URL')
}
}
private async getExtractResults(batchId: string): Promise<ExtractResultResponse> {
const endpoint = `${this.provider.apiHost}/api/v4/extract-results/batch/${batchId}`
try {
const response = await fetch(endpoint, {
method: 'GET',
headers: {
'Content-Type': 'application/json',
Authorization: `Bearer ${this.provider.apiKey}`
}
})
if (response.ok) {
const data: ApiResponse<ExtractResultResponse> = await response.json()
if (data.code === 0 && data.data) {
return data.data
} else {
throw new Error(`API returned error: ${data.msg || JSON.stringify(data)}`)
}
} else {
throw new Error(`HTTP ${response.status}: ${response.statusText}`)
}
} catch (error) {
Logger.error(
`Failed to get extract results for batch ${batchId}: ${error instanceof Error ? error.message : String(error)}`
)
throw new Error('Failed to get extract results')
}
}
private async waitForCompletion(
sourceId: string,
batchId: string,
fileName: string,
maxRetries: number = 60,
intervalMs: number = 5000
): Promise<ExtractFileResult> {
let retries = 0
while (retries < maxRetries) {
try {
const result = await this.getExtractResults(batchId)
// 查找对应文件的处理结果
const fileResult = result.extract_result.find((item) => item.file_name === fileName)
if (!fileResult) {
throw new Error(`File ${fileName} not found in batch results`)
}
// 检查处理状态
if (fileResult.state === 'done' && fileResult.full_zip_url) {
Logger.info(`Processing completed for file: ${fileName}`)
return fileResult
} else if (fileResult.state === 'failed') {
throw new Error(`Processing failed for file: ${fileName}, error: ${fileResult.err_msg}`)
} else if (fileResult.state === 'running') {
// 发送进度更新
if (fileResult.extract_progress) {
const progress = Math.round(
(fileResult.extract_progress.extracted_pages / fileResult.extract_progress.total_pages) * 100
)
await this.sendOcrProgress(sourceId, progress)
Logger.info(`File ${fileName} processing progress: ${progress}%`)
} else {
// 如果没有具体进度信息,发送一个通用进度
await this.sendOcrProgress(sourceId, 50)
Logger.info(`File ${fileName} is still processing...`)
}
}
} catch (error) {
Logger.warn(`Failed to check status for batch ${batchId}, retry ${retries + 1}/${maxRetries}`)
if (retries === maxRetries - 1) {
throw error
}
}
retries++
await new Promise((resolve) => setTimeout(resolve, intervalMs))
}
throw new Error(`Processing timeout for batch: ${batchId}`)
}
}
+3
View File
@@ -6,6 +6,7 @@ import BaseOcrProvider from './BaseOcrProvider'
import DefaultOcrProvider from './DefaultOcrProvider'
import Doc2xOcrProvider from './Doc2xOcrProvider'
import MacSysOcrProvider from './MacSysOcrProvider'
import MineruOcrProvider from './MineruOcrProvider'
import MistralOcrProvider from './MistralOcrProvider'
export default class OcrProviderFactory {
static create(provider: OcrProvider): BaseOcrProvider {
@@ -19,6 +20,8 @@ export default class OcrProviderFactory {
Logger.warn('[OCR] System OCR provider is only available on macOS')
}
return new MacSysOcrProvider(provider)
case 'mineru':
return new MineruOcrProvider(provider)
default:
return new DefaultOcrProvider(provider)
}
Binary file not shown.

After

Width:  |  Height:  |  Size: 48 KiB

@@ -1 +0,0 @@
<svg width='42' height='42' viewBox='0 0 42 42' fill='none' xmlns='http://www.w3.org/2000/svg'><rect x='28.6606' y='8.44495' width='6.92163' height='12.0376' rx='3.46082' transform='rotate(45 28.6606 8.44495)' fill='#7748F9'/><rect x='16.957' y='20.1488' width='6.92163' height='12.0376' rx='3.46082' transform='rotate(45 16.957 20.1488)' fill='#7748F9'/><rect x='20.1489' y='25.0432' width='6.92163' height='12.0376' rx='3.46082' transform='rotate(-45 20.1489 25.0432)' fill='#BFABFB'/><rect x='8.44482' y='13.3394' width='6.92163' height='12.0376' rx='3.46082' transform='rotate(-45 8.44482 13.3394)' fill='#BFABFB'/></svg>

Before

Width:  |  Height:  |  Size: 625 B

Binary file not shown.

After

Width:  |  Height:  |  Size: 16 KiB

+10 -1
View File
@@ -1,4 +1,5 @@
import Doc2xLogo from '@renderer/assets/images/ocr/doc2x.svg'
import Doc2xLogo from '@renderer/assets/images/ocr/doc2x.png'
import MinerULogo from '@renderer/assets/images/ocr/mineru.jpg'
import MacOSLogo from '@renderer/assets/images/providers/macos.svg'
import MistralLogo from '@renderer/assets/images/providers/mistral.png'
@@ -10,6 +11,8 @@ export function getOcrProviderLogo(providerId: string) {
return MistralLogo
case 'system':
return MacOSLogo
case 'mineru':
return MinerULogo
default:
return undefined
}
@@ -27,5 +30,11 @@ export const OCR_PROVIDER_CONFIG = {
official: 'https://mistral.ai',
apiKey: 'https://mistral.ai/api-keys'
}
},
mineru: {
websites: {
official: 'https://mineru.net/',
apiKey: 'https://mineru.net/apiManage'
}
}
}
+1 -1
View File
@@ -1936,4 +1936,4 @@
}
}
}
}
}
+1 -1
View File
@@ -1936,4 +1936,4 @@
}
}
}
}
}
+58 -58
View File
@@ -824,14 +824,14 @@
"seed_desc_tip": "相同的种子和提示词可以生成相似的图片,设置 -1 每次生成都不一样",
"title": "图片",
"magic_prompt_option": "提示词增强",
"model": "版本",
"model": "模型",
"aspect_ratio": "画幅比例",
"style_type": "风格",
"rendering_speed": "渲染速度",
"learn_more": "了解更多",
"paint_course": "教程",
"prompt_placeholder_edit": "输入你的图片描述,文本绘制用 \"双引号\" 包裹",
"proxy_required": "目前需要打开代理才能查看生成图片,后续会支持国内直连",
"proxy_required": "打开代理并开启”TUN模式“查看生成图片或复制到浏览器打开,后续会支持国内直连",
"image_file_required": "请先上传图片",
"image_file_retry": "请重新上传图片",
"image_placeholder": "暂无图片",
@@ -854,7 +854,7 @@
"generate": "绘图",
"edit": "编辑",
"remix": "混合",
"upscale": "放大"
"upscale": "高清增强"
},
"generate": {
"model_tip": "模型版本:V3 为最新版本,V2 为之前版本,V2A 为快速模型、V_1 为初代模型,_TURBO 为加速版本",
@@ -950,7 +950,7 @@
"zhinao": "360智脑",
"zhipu": "智谱AI",
"voyageai": "Voyage AI",
"qiniu": "七牛云",
"qiniu": "七牛云 AI 推理",
"tokenflux": "TokenFlux"
},
"restore": {
@@ -1452,7 +1452,7 @@
"messages.input.send_shortcuts": "发送快捷键",
"messages.input.show_estimated_tokens": "显示预估 Token 数",
"messages.input.title": "输入设置",
"messages.input.enable_quick_triggers": "启用 '/' 和 '@' 触发快捷菜单",
"messages.input.enable_quick_triggers": "启用 / 和 @ 触发快捷菜单",
"messages.input.enable_delete_model": "启用删除键删除输入的模型/附件",
"messages.markdown_rendering_input_message": "Markdown 渲染输入消息",
"messages.math_engine": "数学公式引擎",
@@ -1642,7 +1642,6 @@
"zoom_out": "缩小界面",
"zoom_reset": "重置缩放"
},
"theme.auto": "自动",
"theme.dark": "深色",
"theme.light": "浅色",
"theme.title": "主题",
@@ -1650,6 +1649,14 @@
"theme.window.style.title": "窗口样式",
"theme.window.style.transparent": "透明窗口",
"title": "设置",
"topic.position": "话题位置",
"topic.position.left": "左侧",
"topic.position.right": "右侧",
"topic.show.time": "显示话题时间",
"topic.pin_to_top": "固定话题置顶",
"tray.onclose": "关闭时最小化到托盘",
"tray.show": "显示托盘图标",
"tray.title": "托盘",
"quickPhrase": {
"title": "快捷短语",
"add": "添加短语",
@@ -1696,64 +1703,57 @@
"service_tier.default": "默认",
"service_tier.flex": "灵活"
},
"topic.pin_to_top": "固定话题置顶",
"topic.position": "话题位置",
"topic.position.left": "左侧",
"topic.position.right": "右侧",
"topic.show.time": "显示话题时间",
"tray.onclose": "关闭时最小化到托盘",
"tray.show": "显示托盘图标",
"tray.title": "托盘",
"theme.auto": "自动",
"tool": {
"title": "[to be translated]:Tools Settings",
"title": "工具设置",
"ocr": {
"title": "[to be translated]:OCR",
"provider": "[to be translated]:OCR Provider",
"provider_placeholder": "[to be translated]:Choose an OCR provider",
"title": "OCR",
"provider": "OCR 服务商",
"provider_placeholder": "选择一个 OCR 服务商",
"mac_system_ocr_options": {
"mode": {
"title": "[to be translated]:Recognition Mode",
"accurate": "[to be translated]:Accurate",
"fast": "[to be translated]:Fast"
"title": "识别模式",
"accurate": "准确",
"fast": "快速"
},
"min_confidence": "[to be translated]:Minimum Confidence"
"min_confidence": "最低置信度"
}
},
"websearch": {
"blacklist": "[to be translated]:Blacklist",
"blacklist_description": "[to be translated]:Results from the following websites will not appear in search results",
"blacklist_tooltip": "[to be translated]:Please use the following format (separated by newlines)\nPattern matching: *://*.example.com/*\nRegular expression: /example\\.(net|org)/",
"check": "[to be translated]:Check",
"check_failed": "[to be translated]:Verification failed",
"check_success": "[to be translated]:Verification successful",
"get_api_key": "[to be translated]:Get API Key",
"no_provider_selected": "[to be translated]:Please select a search service provider before checking.",
"search_max_result": "[to be translated]:Number of search results",
"search_provider": "[to be translated]:Search service provider",
"search_provider_placeholder": "[to be translated]:Choose a search service provider.",
"search_result_default": "[to be translated]:Default",
"search_with_time": "[to be translated]:Search with dates included",
"blacklist": "黑名单",
"blacklist_description": "在搜索结果中不会出现以下网站的结果",
"blacklist_tooltip": "请使用以下格式(换行分隔)\n匹配模式: *://*.example.com/*\n正则表达式: /example\\.(net|org)/",
"check": "检测",
"check_failed": "验证失败",
"check_success": "验证成功",
"overwrite": "覆盖服务商搜索",
"overwrite_tooltip": "强制使用搜索服务商而不是大语言模型进行搜索",
"get_api_key": "点击这里获取密钥",
"no_provider_selected": "请选择搜索服务商后再检测",
"search_max_result": "搜索结果个数",
"search_provider": "搜索服务商",
"search_provider_placeholder": "选择一个搜索服务商",
"subscribe": "黑名单订阅",
"subscribe_update": "立即更新",
"subscribe_add": "添加订阅",
"subscribe_url": "订阅源地址",
"subscribe_name": "替代名字",
"subscribe_name.placeholder": "当下载的订阅源没有名称时所使用的替代名称",
"subscribe_add_success": "订阅源添加成功!",
"subscribe_delete": "删除订阅源",
"search_result_default": "默认",
"search_with_time": "搜索包含日期",
"tavily": {
"api_key": "[to be translated]:Tavily API Key",
"api_key.placeholder": "[to be translated]:Enter Tavily API Key",
"description": "[to be translated]:Tavily is a search engine tailored for AI agents, delivering real-time, accurate results, intelligent query suggestions, and in-depth research capabilities.",
"title": "[to be translated]:Tavily"
"api_key": "Tavily API 密钥",
"api_key.placeholder": "请输入 Tavily API 密钥",
"description": "Tavily 是一个为 AI 代理量身定制的搜索引擎,提供实时、准确的结果、智能查询建议和深入的研究能力",
"title": "Tavily"
},
"title": "[to be translated]:Web Search",
"subscribe": "[to be translated]:Blacklist Subscription",
"subscribe_update": "[to be translated]:Update",
"subscribe_add": "[to be translated]:Add Subscription",
"subscribe_url": "[to be translated]:Subscription Url",
"subscribe_name": "[to be translated]:Alternative name",
"subscribe_name.placeholder": "[to be translated]:Alternative name used when the downloaded subscription feed has no name.",
"subscribe_add_success": "[to be translated]:Subscription feed added successfully!",
"subscribe_delete": "[to be translated]:Delete",
"overwrite": "[to be translated]:Override search service",
"overwrite_tooltip": "[to be translated]:Force use search service instead of LLM",
"apikey": "[to be translated]:API key",
"free": "[to be translated]:Free",
"content_limit": "[to be translated]:Content length limit",
"content_limit_tooltip": "[to be translated]:Limit the content length of the search results; content that exceeds the limit will be truncated."
"title": "网络搜索",
"apikey": "API 密钥",
"free": "免费",
"content_limit": "内容长度限制",
"content_limit_tooltip": "限制搜索结果的内容长度, 超过限制的内容将被截断"
}
}
},
@@ -1784,10 +1784,10 @@
"input.placeholder": "输入文本进行翻译",
"output.placeholder": "翻译",
"processing": "翻译中...",
"scroll_sync.disable": "关闭滚动同步",
"scroll_sync.enable": "开启滚动同步",
"title": "翻译",
"tooltip.newline": "换行"
"tooltip.newline": "换行",
"scroll_sync.disable": "禁用滚动同步",
"scroll_sync.enable": "启用滚动同步"
},
"tray": {
"quit": "退出",
@@ -1936,4 +1936,4 @@
}
}
}
}
}
+1 -1
View File
@@ -1936,4 +1936,4 @@
}
}
}
}
}
+1 -1
View File
@@ -52,7 +52,7 @@ const persistedReducer = persistReducer(
{
key: 'cherry-studio',
storage,
version: 108,
version: 109,
blacklist: ['runtime', 'messages', 'messageBlocks'],
migrate
},
+6
View File
@@ -1488,6 +1488,12 @@ const migrateConfig = {
model: 'mistral-ocr-latest',
apiKey: '',
apiHost: 'https://api.mistral.ai'
},
{
id: 'mineru',
name: 'MinerU',
apiKey: '',
apiHost: 'https://mineru.net'
}
]
}
+6
View File
@@ -21,6 +21,12 @@ const initialState: OcrState = {
apiKey: '',
apiHost: 'https://api.mistral.ai'
},
{
id: 'mineru',
name: 'MinerU',
apiKey: '',
apiHost: 'https://mineru.net'
},
{
id: 'system',
name: 'System(Mac Only)',