Files
cherry-studio/src/renderer/src/services/KnowledgeService.ts
T
2025-07-03 16:23:02 +08:00

202 lines
6.3 KiB
TypeScript

import type { ExtractChunkData } from '@cherrystudio/embedjs-interfaces'
import AiProvider from '@renderer/aiCore'
import { DEFAULT_KNOWLEDGE_DOCUMENT_COUNT, DEFAULT_KNOWLEDGE_THRESHOLD } from '@renderer/config/constant'
import { getEmbeddingMaxContext } from '@renderer/config/embedings'
import Logger from '@renderer/config/logger'
import store from '@renderer/store'
import { FileMetadata, KnowledgeBase, KnowledgeBaseParams, KnowledgeReference } from '@renderer/types'
import { ExtractResults } from '@renderer/utils/extract'
import { isEmpty } from 'lodash'
import { getProviderByModel } from './AssistantService'
import FileManager from './FileManager'
export const getKnowledgeBaseParams = (base: KnowledgeBase): KnowledgeBaseParams => {
const provider = getProviderByModel(base.model)
const rerankProvider = getProviderByModel(base.rerankModel)
const aiProvider = new AiProvider(provider)
const rerankAiProvider = new AiProvider(rerankProvider)
let host = aiProvider.getBaseURL()
const rerankHost = rerankAiProvider.getBaseURL()
if (provider.type === 'gemini') {
host = host + '/v1beta/openai/'
}
let chunkSize = base.chunkSize
const maxChunkSize = getEmbeddingMaxContext(base.model.id)
if (maxChunkSize) {
if (chunkSize && chunkSize > maxChunkSize) {
chunkSize = maxChunkSize
}
if (!chunkSize && maxChunkSize < 1024) {
chunkSize = maxChunkSize
}
}
return {
id: base.id,
model: base.model.id,
provider: base.model.provider,
dimensions: base.dimensions,
apiKey: aiProvider.getApiKey() || 'secret',
apiVersion: provider.apiVersion,
baseURL: host,
chunkSize,
chunkOverlap: base.chunkOverlap,
rerankBaseURL: rerankHost,
rerankApiKey: rerankAiProvider.getApiKey() || 'secret',
rerankModel: base.rerankModel?.id,
rerankModelProvider: rerankProvider.name.toLowerCase(),
// topN: base.topN,
// preprocessing: base.preprocessing,
preprocessOrOcrProvider: base.preprocessOrOcrProvider
// topN: base.topN
}
}
export const getFileFromUrl = async (url: string): Promise<FileMetadata | null> => {
console.log('getFileFromUrl', url)
let fileName = ''
if (url && url.includes('CherryStudio')) {
if (url.includes('/Data/Files')) {
fileName = url.split('/Data/Files/')[1]
}
if (url.includes('\\Data\\Files')) {
fileName = url.split('\\Data\\Files\\')[1]
}
}
console.log('fileName', fileName)
if (fileName) {
const actualFileName = fileName.split(/[/\\]/).pop() || fileName
console.log('actualFileName', actualFileName)
const fileId = actualFileName.split('.')[0]
const file = await FileManager.getFile(fileId)
if (file) {
return file
}
}
return null
}
export const getKnowledgeSourceUrl = async (item: ExtractChunkData & { file: FileMetadata | null }) => {
if (item.metadata.source.startsWith('http')) {
return item.metadata.source
}
if (item.file) {
return `[${item.file.origin_name}](http://file/${item.file.name})`
}
return item.metadata.source
}
export const searchKnowledgeBase = async (
query: string,
base: KnowledgeBase,
rewrite?: string
): Promise<Array<ExtractChunkData & { file: FileMetadata | null }>> => {
try {
const baseParams = getKnowledgeBaseParams(base)
const documentCount = base.documentCount || DEFAULT_KNOWLEDGE_DOCUMENT_COUNT
const threshold = base.threshold || DEFAULT_KNOWLEDGE_THRESHOLD
// 执行搜索
const searchResults = await window.api.knowledgeBase.search({
search: rewrite || query,
base: baseParams
})
// 过滤阈值不达标的结果
const filteredResults = searchResults.filter((item) => item.score >= threshold)
// 如果有rerank模型,执行重排
let rerankResults = filteredResults
if (base.rerankModel && filteredResults.length > 0) {
rerankResults = await window.api.knowledgeBase.rerank({
search: rewrite || query,
base: baseParams,
results: filteredResults
})
}
// 限制文档数量
const limitedResults = rerankResults.slice(0, documentCount)
// 处理文件信息
return await Promise.all(
limitedResults.map(async (item) => {
const file = await getFileFromUrl(item.metadata.source)
console.log('Knowledge search item:', item, 'File:', file)
return { ...item, file }
})
)
} catch (error) {
Logger.error(`Error searching knowledge base ${base.name}:`, error)
throw error
}
}
export const processKnowledgeSearch = async (
extractResults: ExtractResults,
knowledgeBaseIds: string[] | undefined
): Promise<KnowledgeReference[]> => {
if (
!extractResults.knowledge?.question ||
extractResults.knowledge.question.length === 0 ||
isEmpty(knowledgeBaseIds)
) {
Logger.log('No valid question found in extractResults.knowledge')
return []
}
const questions = extractResults.knowledge.question
const rewrite = extractResults.knowledge.rewrite
const bases = store.getState().knowledge.bases.filter((kb) => knowledgeBaseIds?.includes(kb.id))
if (!bases || bases.length === 0) {
Logger.log('Skipping knowledge search: No matching knowledge bases found.')
return []
}
// 为每个知识库执行多问题搜索
const baseSearchPromises = bases.map(async (base) => {
// 为每个问题搜索并合并结果
const allResults = await Promise.all(questions.map((question) => searchKnowledgeBase(question, base, rewrite)))
// 合并结果并去重
const flatResults = allResults.flat()
const uniqueResults = Array.from(
new Map(flatResults.map((item) => [item.metadata.uniqueId || item.pageContent, item])).values()
).sort((a, b) => b.score - a.score)
// 转换为引用格式
return await Promise.all(
uniqueResults.map(
async (item, index) =>
({
id: index + 1,
content: item.pageContent,
sourceUrl: await getKnowledgeSourceUrl(item),
type: 'file'
}) as KnowledgeReference
)
)
})
// 汇总所有知识库的结果
const resultsPerBase = await Promise.all(baseSearchPromises)
const allReferencesRaw = resultsPerBase.flat().filter((ref): ref is KnowledgeReference => !!ref)
// 重新为引用分配ID
return allReferencesRaw.map((ref, index) => ({
...ref,
id: index + 1
}))
}