feat: remove faiss database (#10178)

This commit is contained in:
Chen Tao
2025-09-15 17:59:46 +08:00
committed by GitHub
parent 7f9f5514a4
commit e3d2bb2ec6
45 changed files with 376 additions and 2598 deletions
+1 -1
View File
@@ -28,7 +28,7 @@ import DxtService from './services/DxtService'
import { ExportService } from './services/ExportService'
import { fileStorage as fileManager } from './services/FileStorage'
import FileService from './services/FileSystemService'
import KnowledgeService from './services/knowledge/KnowledgeService'
import KnowledgeService from './services/KnowledgeService'
import mcpService from './services/MCPService'
import MemoryService from './services/memory/MemoryService'
import { openTraceWindow, setTraceWindowTitle } from './services/NodeTraceService'
@@ -1,63 +0,0 @@
import { VoyageEmbeddings } from '@langchain/community/embeddings/voyage'
import type { Embeddings } from '@langchain/core/embeddings'
import { OllamaEmbeddings } from '@langchain/ollama'
import { AzureOpenAIEmbeddings, OpenAIEmbeddings } from '@langchain/openai'
import { ApiClient, SystemProviderIds } from '@types'
import { isJinaEmbeddingsModel, JinaEmbeddings } from './JinaEmbeddings'
export default class EmbeddingsFactory {
static create({ embedApiClient, dimensions }: { embedApiClient: ApiClient; dimensions?: number }): Embeddings {
const batchSize = 10
const { model, provider, apiKey, apiVersion, baseURL } = embedApiClient
if (provider === SystemProviderIds.ollama) {
let baseUrl = baseURL
if (baseURL.includes('v1/')) {
baseUrl = baseURL.replace('v1/', '')
}
const headers = apiKey
? {
Authorization: `Bearer ${apiKey}`
}
: undefined
return new OllamaEmbeddings({
model: model,
baseUrl,
...headers
})
} else if (provider === SystemProviderIds.voyageai) {
return new VoyageEmbeddings({
modelName: model,
apiKey,
outputDimension: dimensions,
batchSize
})
}
if (isJinaEmbeddingsModel(model)) {
return new JinaEmbeddings({
model,
apiKey,
batchSize,
dimensions,
baseUrl: baseURL
})
}
if (apiVersion !== undefined) {
return new AzureOpenAIEmbeddings({
azureOpenAIApiKey: apiKey,
azureOpenAIApiVersion: apiVersion,
azureOpenAIApiDeploymentName: model,
azureOpenAIEndpoint: baseURL,
dimensions,
batchSize
})
}
return new OpenAIEmbeddings({
model,
apiKey,
dimensions,
batchSize,
configuration: { baseURL }
})
}
}
@@ -1,199 +0,0 @@
import { Embeddings, type EmbeddingsParams } from '@langchain/core/embeddings'
import { chunkArray } from '@langchain/core/utils/chunk_array'
import { getEnvironmentVariable } from '@langchain/core/utils/env'
import { z } from 'zod'
const jinaModelSchema = z.union([
z.literal('jina-clip-v2'),
z.literal('jina-embeddings-v3'),
z.literal('jina-colbert-v2'),
z.literal('jina-clip-v1'),
z.literal('jina-colbert-v1-en'),
z.literal('jina-embeddings-v2-base-es'),
z.literal('jina-embeddings-v2-base-code'),
z.literal('jina-embeddings-v2-base-de'),
z.literal('jina-embeddings-v2-base-zh'),
z.literal('jina-embeddings-v2-base-en')
])
type JinaModel = z.infer<typeof jinaModelSchema>
export const isJinaEmbeddingsModel = (model: string): model is JinaModel => {
return jinaModelSchema.safeParse(model).success
}
interface JinaEmbeddingsParams extends EmbeddingsParams {
/** Model name to use */
model: JinaModel
baseUrl?: string
/**
* Timeout to use when making requests to Jina.
*/
timeout?: number
/**
* The maximum number of documents to embed in a single request.
*/
batchSize?: number
/**
* Whether to strip new lines from the input text.
*/
stripNewLines?: boolean
/**
* The dimensions of the embedding.
*/
dimensions?: number
/**
* Scales the embedding so its Euclidean (L2) norm becomes 1, preserving direction. Useful when downstream involves dot-product, classification, visualization..
*/
normalized?: boolean
}
type JinaMultiModelInput =
| {
text: string
image?: never
}
| {
image: string
text?: never
}
type JinaEmbeddingsInput = string | JinaMultiModelInput
interface EmbeddingCreateParams {
model: JinaEmbeddingsParams['model']
/**
* input can be strings or JinaMultiModelInputs,if you want embed image,you should use JinaMultiModelInputs
*/
input: JinaEmbeddingsInput[]
dimensions: number
task?: 'retrieval.query' | 'retrieval.passage'
}
interface EmbeddingResponse {
model: string
object: string
usage: {
total_tokens: number
prompt_tokens: number
}
data: {
object: string
index: number
embedding: number[]
}[]
}
interface EmbeddingErrorResponse {
detail: string
}
export class JinaEmbeddings extends Embeddings implements JinaEmbeddingsParams {
model: JinaEmbeddingsParams['model'] = 'jina-clip-v2'
batchSize = 24
baseUrl = 'https://api.jina.ai/v1/embeddings'
stripNewLines = true
dimensions = 1024
apiKey: string
constructor(
fields?: Partial<JinaEmbeddingsParams> & {
apiKey?: string
}
) {
const fieldsWithDefaults = { maxConcurrency: 2, ...fields }
super(fieldsWithDefaults)
const apiKey =
fieldsWithDefaults?.apiKey || getEnvironmentVariable('JINA_API_KEY') || getEnvironmentVariable('JINA_AUTH_TOKEN')
if (!apiKey) throw new Error('Jina API key not found')
this.apiKey = apiKey
this.baseUrl = fieldsWithDefaults?.baseUrl ? `${fieldsWithDefaults?.baseUrl}embeddings` : this.baseUrl
this.model = fieldsWithDefaults?.model ?? this.model
this.dimensions = fieldsWithDefaults?.dimensions ?? this.dimensions
this.batchSize = fieldsWithDefaults?.batchSize ?? this.batchSize
this.stripNewLines = fieldsWithDefaults?.stripNewLines ?? this.stripNewLines
}
private doStripNewLines(input: JinaEmbeddingsInput[]) {
if (this.stripNewLines) {
return input.map((i) => {
if (typeof i === 'string') {
return i.replace(/\n/g, ' ')
}
if (i.text) {
return { text: i.text.replace(/\n/g, ' ') }
}
return i
})
}
return input
}
async embedDocuments(input: JinaEmbeddingsInput[]): Promise<number[][]> {
const batches = chunkArray(this.doStripNewLines(input), this.batchSize)
const batchRequests = batches.map((batch) => {
const params = this.getParams(batch)
return this.embeddingWithRetry(params)
})
const batchResponses = await Promise.all(batchRequests)
const embeddings: number[][] = []
for (let i = 0; i < batchResponses.length; i += 1) {
const batch = batches[i]
const batchResponse = batchResponses[i] || []
for (let j = 0; j < batch.length; j += 1) {
embeddings.push(batchResponse[j])
}
}
return embeddings
}
async embedQuery(input: JinaEmbeddingsInput): Promise<number[]> {
const params = this.getParams(this.doStripNewLines([input]), true)
const embeddings = (await this.embeddingWithRetry(params)) || [[]]
return embeddings[0]
}
private getParams(input: JinaEmbeddingsInput[], query?: boolean): EmbeddingCreateParams {
return {
model: this.model,
input,
dimensions: this.dimensions,
task: query ? 'retrieval.query' : this.model === 'jina-clip-v2' ? undefined : 'retrieval.passage'
}
}
private async embeddingWithRetry(body: EmbeddingCreateParams) {
const response = await fetch(this.baseUrl, {
method: 'POST',
headers: {
'Content-Type': 'application/json',
Authorization: `Bearer ${this.apiKey}`
},
body: JSON.stringify(body)
})
const embeddingData: EmbeddingResponse | EmbeddingErrorResponse = await response.json()
if ('detail' in embeddingData && embeddingData.detail) {
throw new Error(`${embeddingData.detail}`)
}
return (embeddingData as EmbeddingResponse).data.map(({ embedding }) => embedding)
}
}
@@ -1,25 +0,0 @@
import type { Embeddings as BaseEmbeddings } from '@langchain/core/embeddings'
import { TraceMethod } from '@mcp-trace/trace-core'
import { ApiClient } from '@types'
import EmbeddingsFactory from './EmbeddingsFactory'
export default class TextEmbeddings {
private sdk: BaseEmbeddings
constructor({ embedApiClient, dimensions }: { embedApiClient: ApiClient; dimensions?: number }) {
this.sdk = EmbeddingsFactory.create({
embedApiClient,
dimensions
})
}
@TraceMethod({ spanName: 'embedDocuments', tag: 'Embeddings' })
public async embedDocuments(texts: string[]): Promise<number[][]> {
return this.sdk.embedDocuments(texts)
}
@TraceMethod({ spanName: 'embedQuery', tag: 'Embeddings' })
public async embedQuery(text: string): Promise<number[]> {
return this.sdk.embedQuery(text)
}
}
@@ -1,97 +0,0 @@
import { BaseDocumentLoader } from '@langchain/core/document_loaders/base'
import { Document } from '@langchain/core/documents'
import { readTextFileWithAutoEncoding } from '@main/utils/file'
import MarkdownIt from 'markdown-it'
export class MarkdownLoader extends BaseDocumentLoader {
private path: string
private md: MarkdownIt
constructor(path: string) {
super()
this.path = path
this.md = new MarkdownIt()
}
public async load(): Promise<Document[]> {
const content = await readTextFileWithAutoEncoding(this.path)
return this.parseMarkdown(content)
}
private parseMarkdown(content: string): Document[] {
const tokens = this.md.parse(content, {})
const documents: Document[] = []
let currentSection: {
heading?: string
level?: number
content: string
startLine?: number
} = { content: '' }
let i = 0
while (i < tokens.length) {
const token = tokens[i]
if (token.type === 'heading_open') {
// Save previous section if it has content
if (currentSection.content.trim()) {
documents.push(
new Document({
pageContent: currentSection.content.trim(),
metadata: {
source: this.path,
heading: currentSection.heading || 'Introduction',
level: currentSection.level || 0,
startLine: currentSection.startLine || 0
}
})
)
}
// Start new section
const level = parseInt(token.tag.slice(1)) // Extract number from h1, h2, etc.
const headingContent = tokens[i + 1]?.content || ''
currentSection = {
heading: headingContent,
level: level,
content: '',
startLine: token.map?.[0] || 0
}
// Skip heading_open, inline, heading_close tokens
i += 3
continue
}
// Add token content to current section
if (token.content) {
currentSection.content += token.content
}
// Add newlines for block tokens
if (token.block && token.type !== 'heading_close') {
currentSection.content += '\n'
}
i++
}
// Add the last section
if (currentSection.content.trim()) {
documents.push(
new Document({
pageContent: currentSection.content.trim(),
metadata: {
source: this.path,
heading: currentSection.heading || 'Introduction',
level: currentSection.level || 0,
startLine: currentSection.startLine || 0
}
})
)
}
return documents
}
}
@@ -1,50 +0,0 @@
import { BaseDocumentLoader } from '@langchain/core/document_loaders/base'
import { Document } from '@langchain/core/documents'
export class NoteLoader extends BaseDocumentLoader {
private text: string
private sourceUrl?: string
constructor(
public _text: string,
public _sourceUrl?: string
) {
super()
this.text = _text
this.sourceUrl = _sourceUrl
}
/**
* A protected method that takes a `raw` string as a parameter and returns
* a promise that resolves to an array containing the raw text as a single
* element.
* @param raw The raw text to be parsed.
* @returns A promise that resolves to an array containing the raw text as a single element.
*/
protected async parse(raw: string): Promise<string[]> {
return [raw]
}
public async load(): Promise<Document[]> {
const metadata = { source: this.sourceUrl || 'note' }
const parsed = await this.parse(this.text)
parsed.forEach((pageContent, i) => {
if (typeof pageContent !== 'string') {
throw new Error(`Expected string, at position ${i} got ${typeof pageContent}`)
}
})
return parsed.map(
(pageContent, i) =>
new Document({
pageContent,
metadata:
parsed.length === 1
? metadata
: {
...metadata,
line: i + 1
}
})
)
}
}
@@ -1,170 +0,0 @@
import { BaseDocumentLoader } from '@langchain/core/document_loaders/base'
import { Document } from '@langchain/core/documents'
import { Innertube } from 'youtubei.js'
// ... (接口定义 YoutubeConfig 和 VideoMetadata 保持不变)
/**
* Configuration options for the YoutubeLoader class. Includes properties
* such as the videoId, language, and addVideoInfo.
*/
interface YoutubeConfig {
videoId: string
language?: string
addVideoInfo?: boolean
// 新增一个选项,用于控制输出格式
transcriptFormat?: 'text' | 'srt'
}
/**
* Metadata of a YouTube video. Includes properties such as the source
* (videoId), description, title, view_count, author, and category.
*/
interface VideoMetadata {
source: string
description?: string
title?: string
view_count?: number
author?: string
category?: string
}
/**
* A document loader for loading data from YouTube videos. It uses the
* youtubei.js library to fetch the transcript and video metadata.
* @example
* ```typescript
* const loader = new YoutubeLoader({
* videoId: "VIDEO_ID",
* language: "en",
* addVideoInfo: true,
* transcriptFormat: "srt" // 获取 SRT 格式
* });
* const docs = await loader.load();
* console.log(docs[0].pageContent);
* ```
*/
export class YoutubeLoader extends BaseDocumentLoader {
private videoId: string
private language?: string
private addVideoInfo: boolean
// 新增格式化选项的私有属性
private transcriptFormat: 'text' | 'srt'
constructor(config: YoutubeConfig) {
super()
this.videoId = config.videoId
this.language = config?.language
this.addVideoInfo = config?.addVideoInfo ?? false
// 初始化格式化选项,默认为 'text' 以保持向后兼容
this.transcriptFormat = config?.transcriptFormat ?? 'text'
}
/**
* Extracts the videoId from a YouTube video URL.
* @param url The URL of the YouTube video.
* @returns The videoId of the YouTube video.
*/
private static getVideoID(url: string): string {
const match = url.match(/.*(?:youtu.be\/|v\/|u\/\w\/|embed\/|watch\?v=)([^#&?]*).*/)
if (match !== null && match[1].length === 11) {
return match[1]
} else {
throw new Error('Failed to get youtube video id from the url')
}
}
/**
* Creates a new instance of the YoutubeLoader class from a YouTube video
* URL.
* @param url The URL of the YouTube video.
* @param config Optional configuration options for the YoutubeLoader instance, excluding the videoId.
* @returns A new instance of the YoutubeLoader class.
*/
static createFromUrl(url: string, config?: Omit<YoutubeConfig, 'videoId'>): YoutubeLoader {
const videoId = YoutubeLoader.getVideoID(url)
return new YoutubeLoader({ ...config, videoId })
}
/**
* [新增] 辅助函数:将毫秒转换为 SRT 时间戳格式 (HH:MM:SS,ms)
* @param ms 毫秒数
* @returns 格式化后的时间字符串
*/
private static formatTimestamp(ms: number): string {
const totalSeconds = Math.floor(ms / 1000)
const hours = Math.floor(totalSeconds / 3600)
.toString()
.padStart(2, '0')
const minutes = Math.floor((totalSeconds % 3600) / 60)
.toString()
.padStart(2, '0')
const seconds = (totalSeconds % 60).toString().padStart(2, '0')
const milliseconds = (ms % 1000).toString().padStart(3, '0')
return `${hours}:${minutes}:${seconds},${milliseconds}`
}
/**
* Loads the transcript and video metadata from the specified YouTube
* video. It can return the transcript as plain text or in SRT format.
* @returns An array of Documents representing the retrieved data.
*/
async load(): Promise<Document[]> {
const metadata: VideoMetadata = {
source: this.videoId
}
try {
const youtube = await Innertube.create({
lang: this.language,
retrieve_player: false
})
const info = await youtube.getInfo(this.videoId)
const transcriptData = await info.getTranscript()
if (!transcriptData.transcript.content?.body?.initial_segments) {
throw new Error('Transcript segments not found in the response.')
}
const segments = transcriptData.transcript.content.body.initial_segments
let pageContent: string
// 根据 transcriptFormat 选项决定如何格式化字幕
if (this.transcriptFormat === 'srt') {
// [修改] 将字幕片段格式化为 SRT 格式
pageContent = segments
.map((segment, index) => {
const srtIndex = index + 1
const startTime = YoutubeLoader.formatTimestamp(Number(segment.start_ms))
const endTime = YoutubeLoader.formatTimestamp(Number(segment.end_ms))
const text = segment.snippet?.text || '' // 使用 segment.snippet.text
return `${srtIndex}\n${startTime} --> ${endTime}\n${text}`
})
.join('\n\n') // 每个 SRT 块之间用两个换行符分隔
} else {
// [原始逻辑] 拼接为纯文本
pageContent = segments.map((segment) => segment.snippet?.text || '').join(' ')
}
if (this.addVideoInfo) {
const basicInfo = info.basic_info
metadata.description = basicInfo.short_description
metadata.title = basicInfo.title
metadata.view_count = basicInfo.view_count
metadata.author = basicInfo.author
}
const document = new Document({
pageContent,
metadata
})
return [document]
} catch (e: unknown) {
throw new Error(`Failed to get YouTube video transcription: ${(e as Error).message}`)
}
}
}
@@ -1,235 +0,0 @@
import { DocxLoader } from '@langchain/community/document_loaders/fs/docx'
import { EPubLoader } from '@langchain/community/document_loaders/fs/epub'
import { PDFLoader } from '@langchain/community/document_loaders/fs/pdf'
import { PPTXLoader } from '@langchain/community/document_loaders/fs/pptx'
import { CheerioWebBaseLoader } from '@langchain/community/document_loaders/web/cheerio'
import { SitemapLoader } from '@langchain/community/document_loaders/web/sitemap'
import { FaissStore } from '@langchain/community/vectorstores/faiss'
import { Document } from '@langchain/core/documents'
import { loggerService } from '@logger'
import { UrlSource } from '@main/utils/knowledge'
import { LoaderReturn } from '@shared/config/types'
import { FileMetadata, FileTypes, KnowledgeBaseParams } from '@types'
import { randomUUID } from 'crypto'
import { JSONLoader } from 'langchain/document_loaders/fs/json'
import { TextLoader } from 'langchain/document_loaders/fs/text'
import { SplitterFactory } from '../splitter'
import { MarkdownLoader } from './MarkdownLoader'
import { NoteLoader } from './NoteLoader'
import { YoutubeLoader } from './YoutubeLoader'
const logger = loggerService.withContext('KnowledgeService File Loader')
type LoaderInstance =
| TextLoader
| PDFLoader
| PPTXLoader
| DocxLoader
| JSONLoader
| EPubLoader
| CheerioWebBaseLoader
| YoutubeLoader
| SitemapLoader
| NoteLoader
| MarkdownLoader
/**
* 为文档数组中的每个文档的 metadata 添加类型信息。
*/
function formatDocument(docs: Document[], type: string): Document[] {
return docs.map((doc) => ({
...doc,
metadata: {
...doc.metadata,
type: type
}
}))
}
/**
* 通用文档处理管道
*/
async function processDocuments(
base: KnowledgeBaseParams,
vectorStore: FaissStore,
docs: Document[],
loaderType: string,
splitterType?: string
): Promise<LoaderReturn> {
const formattedDocs = formatDocument(docs, loaderType)
const splitter = SplitterFactory.create({
chunkSize: base.chunkSize,
chunkOverlap: base.chunkOverlap,
...(splitterType && { type: splitterType })
})
const splitterResults = await splitter.splitDocuments(formattedDocs)
const ids = splitterResults.map(() => randomUUID())
await vectorStore.addDocuments(splitterResults, { ids })
return {
entriesAdded: splitterResults.length,
uniqueId: ids[0] || '',
uniqueIds: ids,
loaderType
}
}
/**
* 通用加载器执行函数
*/
async function executeLoader(
base: KnowledgeBaseParams,
vectorStore: FaissStore,
loaderInstance: LoaderInstance,
loaderType: string,
identifier: string,
splitterType?: string
): Promise<LoaderReturn> {
const emptyResult: LoaderReturn = {
entriesAdded: 0,
uniqueId: '',
uniqueIds: [],
loaderType
}
try {
const docs = await loaderInstance.load()
return await processDocuments(base, vectorStore, docs, loaderType, splitterType)
} catch (error) {
logger.error(`Error loading or processing ${identifier} with loader ${loaderType}: ${error}`)
return emptyResult
}
}
/**
* 文件扩展名到加载器的映射
*/
const FILE_LOADER_MAP: Record<string, { loader: new (path: string) => LoaderInstance; type: string }> = {
'.pdf': { loader: PDFLoader, type: 'pdf' },
'.txt': { loader: TextLoader, type: 'text' },
'.pptx': { loader: PPTXLoader, type: 'pptx' },
'.docx': { loader: DocxLoader, type: 'docx' },
'.doc': { loader: DocxLoader, type: 'doc' },
'.json': { loader: JSONLoader, type: 'json' },
'.epub': { loader: EPubLoader, type: 'epub' },
'.md': { loader: MarkdownLoader, type: 'markdown' }
}
export async function addFileLoader(
base: KnowledgeBaseParams,
vectorStore: FaissStore,
file: FileMetadata
): Promise<LoaderReturn> {
const fileExt = file.ext.toLowerCase()
const loaderConfig = FILE_LOADER_MAP[fileExt]
if (!loaderConfig) {
// 默认使用文本加载器
const loaderInstance = new TextLoader(file.path)
const type = fileExt.replace('.', '') || 'unknown'
return executeLoader(base, vectorStore, loaderInstance, type, file.path)
}
const loaderInstance = new loaderConfig.loader(file.path)
return executeLoader(base, vectorStore, loaderInstance, loaderConfig.type, file.path)
}
export async function addWebLoader(
base: KnowledgeBaseParams,
vectorStore: FaissStore,
url: string,
source: UrlSource
): Promise<LoaderReturn> {
let loaderInstance: CheerioWebBaseLoader | YoutubeLoader | undefined
let splitterType: string | undefined
switch (source) {
case 'normal':
loaderInstance = new CheerioWebBaseLoader(url)
break
case 'youtube':
loaderInstance = YoutubeLoader.createFromUrl(url, {
addVideoInfo: true,
transcriptFormat: 'srt'
})
splitterType = 'srt'
break
}
if (!loaderInstance) {
return {
entriesAdded: 0,
uniqueId: '',
uniqueIds: [],
loaderType: source
}
}
return executeLoader(base, vectorStore, loaderInstance, source, url, splitterType)
}
export async function addSitemapLoader(
base: KnowledgeBaseParams,
vectorStore: FaissStore,
url: string
): Promise<LoaderReturn> {
const loaderInstance = new SitemapLoader(url)
return executeLoader(base, vectorStore, loaderInstance, 'sitemap', url)
}
export async function addNoteLoader(
base: KnowledgeBaseParams,
vectorStore: FaissStore,
content: string,
sourceUrl: string
): Promise<LoaderReturn> {
const loaderInstance = new NoteLoader(content, sourceUrl)
return executeLoader(base, vectorStore, loaderInstance, 'note', sourceUrl)
}
export async function addVideoLoader(
base: KnowledgeBaseParams,
vectorStore: FaissStore,
files: FileMetadata[]
): Promise<LoaderReturn> {
const srtFile = files.find((f) => f.type === FileTypes.TEXT)
const videoFile = files.find((f) => f.type === FileTypes.VIDEO)
const emptyResult: LoaderReturn = {
entriesAdded: 0,
uniqueId: '',
uniqueIds: [],
loaderType: 'video'
}
if (!srtFile || !videoFile) {
return emptyResult
}
try {
const loaderInstance = new TextLoader(srtFile.path)
const originalDocs = await loaderInstance.load()
const docsWithVideoMeta = originalDocs.map(
(doc) =>
new Document({
...doc,
metadata: {
...doc.metadata,
video: {
path: videoFile.path,
name: videoFile.origin_name
}
}
})
)
return await processDocuments(base, vectorStore, docsWithVideoMeta, 'video', 'srt')
} catch (error) {
logger.error(`Error loading or processing file ${srtFile.path} with loader video: ${error}`)
return emptyResult
}
}
@@ -1,55 +0,0 @@
import { BM25Retriever } from '@langchain/community/retrievers/bm25'
import { FaissStore } from '@langchain/community/vectorstores/faiss'
import { BaseRetriever } from '@langchain/core/retrievers'
import { loggerService } from '@main/services/LoggerService'
import { type KnowledgeBaseParams } from '@types'
import { type Document } from 'langchain/document'
import { EnsembleRetriever } from 'langchain/retrievers/ensemble'
const logger = loggerService.withContext('RetrieverFactory')
export class RetrieverFactory {
/**
* 根据提供的参数创建一个 LangChain 检索器 (Retriever)。
* @param base 知识库配置参数。
* @param vectorStore 一个已初始化的向量存储实例。
* @param documents 文档列表,用于初始化 BM25Retriever。
* @returns 返回一个 BaseRetriever 实例。
*/
public createRetriever(base: KnowledgeBaseParams, vectorStore: FaissStore, documents: Document[]): BaseRetriever {
const retrieverType = base.retriever?.mode ?? 'hybrid'
const retrieverWeight = base.retriever?.weight ?? 0.5
const searchK = base.documentCount ?? 5
logger.info(`Creating retriever of type: ${retrieverType} with k=${searchK}`)
switch (retrieverType) {
case 'bm25':
if (documents.length === 0) {
throw new Error('BM25Retriever requires documents, but none were provided or found.')
}
logger.info('Create BM25 Retriever')
return BM25Retriever.fromDocuments(documents, { k: searchK })
case 'hybrid': {
if (documents.length === 0) {
logger.warn('No documents provided for BM25 part of hybrid search. Falling back to vector search only.')
return vectorStore.asRetriever(searchK)
}
const vectorstoreRetriever = vectorStore.asRetriever(searchK)
const bm25Retriever = BM25Retriever.fromDocuments(documents, { k: searchK })
logger.info('Create Hybrid Retriever')
return new EnsembleRetriever({
retrievers: [bm25Retriever, vectorstoreRetriever],
weights: [retrieverWeight, 1 - retrieverWeight]
})
}
case 'vector':
default:
logger.info('Create Vector Retriever')
return vectorStore.asRetriever(searchK)
}
}
}
@@ -1,133 +0,0 @@
import { Document } from '@langchain/core/documents'
import { TextSplitter, TextSplitterParams } from 'langchain/text_splitter'
// 定义一个接口来表示解析后的单个字幕片段
interface SrtSegment {
text: string
startTime: number // in seconds
endTime: number // in seconds
}
// 辅助函数:将 SRT 时间戳字符串 (HH:MM:SS,ms) 转换为秒
function srtTimeToSeconds(time: string): number {
const parts = time.split(':')
const secondsAndMs = parts[2].split(',')
const hours = parseInt(parts[0], 10)
const minutes = parseInt(parts[1], 10)
const seconds = parseInt(secondsAndMs[0], 10)
const milliseconds = parseInt(secondsAndMs[1], 10)
return hours * 3600 + minutes * 60 + seconds + milliseconds / 1000
}
export class SrtSplitter extends TextSplitter {
constructor(fields?: Partial<TextSplitterParams>) {
// 传入 chunkSize 和 chunkOverlap
super(fields)
}
splitText(): Promise<string[]> {
throw new Error('Method not implemented.')
}
// 核心方法:重写 splitDocuments 来实现自定义逻辑
async splitDocuments(documents: Document[]): Promise<Document[]> {
const allChunks: Document[] = []
for (const doc of documents) {
// 1. 解析 SRT 内容
const segments = this.parseSrt(doc.pageContent)
if (segments.length === 0) continue
// 2. 将字幕片段组合成块
const chunks = this.mergeSegmentsIntoChunks(segments, doc.metadata)
allChunks.push(...chunks)
}
return allChunks
}
// 辅助方法:解析整个 SRT 字符串
private parseSrt(srt: string): SrtSegment[] {
const segments: SrtSegment[] = []
const blocks = srt.trim().split(/\n\n/)
for (const block of blocks) {
const lines = block.split('\n')
if (lines.length < 3) continue
const timeMatch = lines[1].match(/(\d{2}:\d{2}:\d{2},\d{3}) --> (\d{2}:\d{2}:\d{2},\d{3})/)
if (!timeMatch) continue
const startTime = srtTimeToSeconds(timeMatch[1])
const endTime = srtTimeToSeconds(timeMatch[2])
const text = lines.slice(2).join(' ').trim()
segments.push({ text, startTime, endTime })
}
return segments
}
// 辅助方法:将解析后的片段合并成每 5 段一个块
private mergeSegmentsIntoChunks(segments: SrtSegment[], baseMetadata: Record<string, any>): Document[] {
const chunks: Document[] = []
let currentChunkText = ''
let currentChunkStartTime = 0
let currentChunkEndTime = 0
let segmentCount = 0
for (const segment of segments) {
if (segmentCount === 0) {
currentChunkStartTime = segment.startTime
}
currentChunkText += (currentChunkText ? ' ' : '') + segment.text
currentChunkEndTime = segment.endTime
segmentCount++
// 当累积到 5 段时,创建一个新的 Document
if (segmentCount === 5) {
const metadata: Record<string, any> = {
...baseMetadata,
startTime: currentChunkStartTime,
endTime: currentChunkEndTime
}
if (baseMetadata.source_url) {
metadata.source_url_with_timestamp = `${baseMetadata.source_url}?t=${Math.floor(currentChunkStartTime)}s`
}
chunks.push(
new Document({
pageContent: currentChunkText,
metadata
})
)
// 重置计数器和临时变量
currentChunkText = ''
currentChunkStartTime = 0
currentChunkEndTime = 0
segmentCount = 0
}
}
// 如果还有剩余的片段,创建最后一个 Document
if (segmentCount > 0) {
const metadata: Record<string, any> = {
...baseMetadata,
startTime: currentChunkStartTime,
endTime: currentChunkEndTime
}
if (baseMetadata.source_url) {
metadata.source_url_with_timestamp = `${baseMetadata.source_url}?t=${Math.floor(currentChunkStartTime)}s`
}
chunks.push(
new Document({
pageContent: currentChunkText,
metadata
})
)
}
return chunks
}
}
@@ -1,31 +0,0 @@
import { RecursiveCharacterTextSplitter, TextSplitter } from '@langchain/textsplitters'
import { SrtSplitter } from './SrtSplitter'
export type SplitterConfig = {
chunkSize?: number
chunkOverlap?: number
type?: 'recursive' | 'srt' | string
}
export class SplitterFactory {
/**
* Creates a TextSplitter instance based on the provided configuration.
* @param config - The configuration object specifying the splitter type and its parameters.
* @returns An instance of a TextSplitter, or null if no splitting is required.
*/
public static create(config: SplitterConfig): TextSplitter {
switch (config.type) {
case 'srt':
return new SrtSplitter({
chunkSize: config.chunkSize,
chunkOverlap: config.chunkOverlap
})
case 'recursive':
default:
return new RecursiveCharacterTextSplitter({
chunkSize: config.chunkSize,
chunkOverlap: config.chunkOverlap
})
}
}
}
@@ -1,3 +1,18 @@
/**
* Knowledge Service - Manages knowledge bases using RAG (Retrieval-Augmented Generation)
*
* This service handles creation, management, and querying of knowledge bases from various sources
* including files, directories, URLs, sitemaps, and notes.
*
* Features:
* - Concurrent task processing with workload management
* - Multiple data source support
* - Vector database integration
*
* For detailed documentation, see:
* @see {@link ../../../docs/technical/KnowledgeService.md}
*/
import * as fs from 'node:fs'
import path from 'node:path'
@@ -9,32 +24,87 @@ import { loggerService } from '@logger'
import Embeddings from '@main/knowledge/embedjs/embeddings/Embeddings'
import { addFileLoader } from '@main/knowledge/embedjs/loader'
import { NoteLoader } from '@main/knowledge/embedjs/loader/noteLoader'
import { preprocessingService } from '@main/knowledge/preprocess/PreprocessingService'
import PreprocessProvider from '@main/knowledge/preprocess/PreprocessProvider'
import Reranker from '@main/knowledge/reranker/Reranker'
import { fileStorage } from '@main/services/FileStorage'
import { windowService } from '@main/services/WindowService'
import { getDataPath } from '@main/utils'
import { getAllFiles } from '@main/utils/file'
import { TraceMethod } from '@mcp-trace/trace-core'
import { MB } from '@shared/config/constant'
import { LoaderReturn } from '@shared/config/types'
import type { LoaderReturn } from '@shared/config/types'
import { IpcChannel } from '@shared/IpcChannel'
import { FileMetadata, KnowledgeBaseParams, KnowledgeSearchResult } from '@types'
import { FileMetadata, KnowledgeBaseParams, KnowledgeItem, KnowledgeSearchResult } from '@types'
import { v4 as uuidv4 } from 'uuid'
import { windowService } from '../WindowService'
import {
IKnowledgeFramework,
KnowledgeBaseAddItemOptionsNonNullableAttribute,
LoaderDoneReturn,
LoaderTask,
LoaderTaskItem,
LoaderTaskItemState
} from './IKnowledgeFramework'
const logger = loggerService.withContext('MainKnowledgeService')
export class EmbedJsFramework implements IKnowledgeFramework {
private storageDir: string
private ragApplications: Map<string, RAGApplication> = new Map()
private pendingDeleteFile: string
private dbInstances: Map<string, LibSqlDb> = new Map()
export interface KnowledgeBaseAddItemOptions {
base: KnowledgeBaseParams
item: KnowledgeItem
forceReload?: boolean
userId?: string
}
interface KnowledgeBaseAddItemOptionsNonNullableAttribute {
base: KnowledgeBaseParams
item: KnowledgeItem
forceReload: boolean
userId: string
}
interface EvaluateTaskWorkload {
workload: number
}
type LoaderDoneReturn = LoaderReturn | null
enum LoaderTaskItemState {
PENDING,
PROCESSING,
DONE
}
interface LoaderTaskItem {
state: LoaderTaskItemState
task: () => Promise<unknown>
evaluateTaskWorkload: EvaluateTaskWorkload
}
interface LoaderTask {
loaderTasks: LoaderTaskItem[]
loaderDoneReturn: LoaderDoneReturn
}
interface LoaderTaskOfSet {
loaderTasks: Set<LoaderTaskItem>
loaderDoneReturn: LoaderDoneReturn
}
interface QueueTaskItem {
taskPromise: () => Promise<unknown>
resolve: () => void
evaluateTaskWorkload: EvaluateTaskWorkload
}
const loaderTaskIntoOfSet = (loaderTask: LoaderTask): LoaderTaskOfSet => {
return {
loaderTasks: new Set(loaderTask.loaderTasks),
loaderDoneReturn: loaderTask.loaderDoneReturn
}
}
class KnowledgeService {
private storageDir = path.join(getDataPath(), 'KnowledgeBase')
private pendingDeleteFile = path.join(this.storageDir, 'knowledge_pending_delete.json')
// Byte based
private workload = 0
private processingItemCount = 0
private knowledgeItemProcessingQueueMappingPromise: Map<LoaderTaskOfSet, () => void> = new Map()
private ragApplications: Map<string, RAGApplication> = new Map()
private dbInstances: Map<string, LibSqlDb> = new Map()
private static MAXIMUM_WORKLOAD = 80 * MB
private static MAXIMUM_PROCESSING_ITEM_COUNT = 30
private static ERROR_LOADER_RETURN: LoaderReturn = {
entriesAdded: 0,
uniqueId: '',
@@ -43,9 +113,7 @@ export class EmbedJsFramework implements IKnowledgeFramework {
status: 'failed'
}
constructor(storageDir: string) {
this.storageDir = storageDir
this.pendingDeleteFile = path.join(this.storageDir, 'knowledge_pending_delete.json')
constructor() {
this.initStorageDir()
this.cleanupOnStartup()
}
@@ -160,28 +228,33 @@ export class EmbedJsFramework implements IKnowledgeFramework {
logger.info(`Startup cleanup completed: ${deletedCount}/${pendingDeleteIds.length} knowledge bases deleted`)
}
private async getRagApplication(base: KnowledgeBaseParams): Promise<RAGApplication> {
if (this.ragApplications.has(base.id)) {
return this.ragApplications.get(base.id)!
private getRagApplication = async ({
id,
embedApiClient,
dimensions,
documentCount
}: KnowledgeBaseParams): Promise<RAGApplication> => {
if (this.ragApplications.has(id)) {
return this.ragApplications.get(id)!
}
let ragApplication: RAGApplication
const embeddings = new Embeddings({
embedApiClient: base.embedApiClient,
dimensions: base.dimensions
embedApiClient,
dimensions
})
try {
const libSqlDb = new LibSqlDb({ path: path.join(this.storageDir, base.id) })
const libSqlDb = new LibSqlDb({ path: path.join(this.storageDir, id) })
// Save database instance for later closing
this.dbInstances.set(base.id, libSqlDb)
this.dbInstances.set(id, libSqlDb)
ragApplication = await new RAGApplicationBuilder()
.setModel('NO_MODEL')
.setEmbeddingModel(embeddings)
.setVectorDatabase(libSqlDb)
.setSearchResultCount(base.documentCount || 30)
.setSearchResultCount(documentCount || 30)
.build()
this.ragApplications.set(base.id, ragApplication)
this.ragApplications.set(id, ragApplication)
} catch (e) {
logger.error('Failed to create RAGApplication:', e as Error)
throw new Error(`Failed to create RAGApplication: ${e}`)
@@ -189,14 +262,17 @@ export class EmbedJsFramework implements IKnowledgeFramework {
return ragApplication
}
async initialize(base: KnowledgeBaseParams): Promise<void> {
public create = async (_: Electron.IpcMainInvokeEvent, base: KnowledgeBaseParams): Promise<void> => {
await this.getRagApplication(base)
}
async reset(base: KnowledgeBaseParams): Promise<void> {
const ragApp = await this.getRagApplication(base)
await ragApp.reset()
public reset = async (_: Electron.IpcMainInvokeEvent, base: KnowledgeBaseParams): Promise<void> => {
const ragApplication = await this.getRagApplication(base)
await ragApplication.reset()
}
async delete(id: string): Promise<void> {
public async delete(_: Electron.IpcMainInvokeEvent, id: string): Promise<void> {
logger.debug(`delete id: ${id}`)
await this.cleanupKnowledgeResources(id)
@@ -209,41 +285,15 @@ export class EmbedJsFramework implements IKnowledgeFramework {
this.pendingDeleteManager.add(id)
}
}
getLoaderTask(options: KnowledgeBaseAddItemOptionsNonNullableAttribute): LoaderTask {
const { item } = options
const getRagApplication = () => this.getRagApplication(options.base)
switch (item.type) {
case 'file':
return this.fileTask(getRagApplication, options)
case 'directory':
return this.directoryTask(getRagApplication, options)
case 'url':
return this.urlTask(getRagApplication, options)
case 'sitemap':
return this.sitemapTask(getRagApplication, options)
case 'note':
return this.noteTask(getRagApplication, options)
default:
return {
loaderTasks: [],
loaderDoneReturn: null
}
}
}
async remove(options: { uniqueIds: string[]; base: KnowledgeBaseParams }): Promise<void> {
const ragApp = await this.getRagApplication(options.base)
for (const id of options.uniqueIds) {
await ragApp.deleteLoader(id)
}
private maximumLoad() {
return (
this.processingItemCount >= KnowledgeService.MAXIMUM_PROCESSING_ITEM_COUNT ||
this.workload >= KnowledgeService.MAXIMUM_WORKLOAD
)
}
async search(options: { search: string; base: KnowledgeBaseParams }): Promise<KnowledgeSearchResult[]> {
const ragApp = await this.getRagApplication(options.base)
return await ragApp.search(options.search)
}
private fileTask(
getRagApplication: () => Promise<RAGApplication>,
ragApplication: RAGApplication,
options: KnowledgeBaseAddItemOptionsNonNullableAttribute
): LoaderTask {
const { base, item, forceReload, userId } = options
@@ -256,8 +306,7 @@ export class EmbedJsFramework implements IKnowledgeFramework {
task: async () => {
try {
// Add preprocessing logic
const ragApplication = await getRagApplication()
const fileToProcess: FileMetadata = await preprocessingService.preprocessFile(file, base, item, userId)
const fileToProcess: FileMetadata = await this.preprocessing(file, base, item, userId)
// Use processed file for loading
return addFileLoader(ragApplication, fileToProcess, base, forceReload)
@@ -268,7 +317,7 @@ export class EmbedJsFramework implements IKnowledgeFramework {
.catch((e) => {
logger.error(`Error in addFileLoader for ${file.name}: ${e}`)
const errorResult: LoaderReturn = {
...EmbedJsFramework.ERROR_LOADER_RETURN,
...KnowledgeService.ERROR_LOADER_RETURN,
message: e.message,
messageSource: 'embedding'
}
@@ -278,7 +327,7 @@ export class EmbedJsFramework implements IKnowledgeFramework {
} catch (e: any) {
logger.error(`Preprocessing failed for ${file.name}: ${e}`)
const errorResult: LoaderReturn = {
...EmbedJsFramework.ERROR_LOADER_RETURN,
...KnowledgeService.ERROR_LOADER_RETURN,
message: e.message,
messageSource: 'preprocess'
}
@@ -295,7 +344,7 @@ export class EmbedJsFramework implements IKnowledgeFramework {
return loaderTask
}
private directoryTask(
getRagApplication: () => Promise<RAGApplication>,
ragApplication: RAGApplication,
options: KnowledgeBaseAddItemOptionsNonNullableAttribute
): LoaderTask {
const { base, item, forceReload } = options
@@ -322,9 +371,8 @@ export class EmbedJsFramework implements IKnowledgeFramework {
for (const file of files) {
loaderTasks.push({
state: LoaderTaskItemState.PENDING,
task: async () => {
const ragApplication = await getRagApplication()
return addFileLoader(ragApplication, file, base, forceReload)
task: () =>
addFileLoader(ragApplication, file, base, forceReload)
.then((result) => {
loaderDoneReturn.entriesAdded += 1
processedFiles += 1
@@ -335,12 +383,11 @@ export class EmbedJsFramework implements IKnowledgeFramework {
.catch((err) => {
logger.error('Failed to add dir loader:', err)
return {
...EmbedJsFramework.ERROR_LOADER_RETURN,
...KnowledgeService.ERROR_LOADER_RETURN,
message: `Failed to add dir loader: ${err.message}`,
messageSource: 'embedding'
}
})
},
}),
evaluateTaskWorkload: { workload: file.size }
})
}
@@ -352,7 +399,7 @@ export class EmbedJsFramework implements IKnowledgeFramework {
}
private urlTask(
getRagApplication: () => Promise<RAGApplication>,
ragApplication: RAGApplication,
options: KnowledgeBaseAddItemOptionsNonNullableAttribute
): LoaderTask {
const { base, item, forceReload } = options
@@ -362,8 +409,7 @@ export class EmbedJsFramework implements IKnowledgeFramework {
loaderTasks: [
{
state: LoaderTaskItemState.PENDING,
task: async () => {
const ragApplication = await getRagApplication()
task: () => {
const loaderReturn = ragApplication.addLoader(
new WebLoader({
urlOrContent: content,
@@ -387,7 +433,7 @@ export class EmbedJsFramework implements IKnowledgeFramework {
.catch((err) => {
logger.error('Failed to add url loader:', err)
return {
...EmbedJsFramework.ERROR_LOADER_RETURN,
...KnowledgeService.ERROR_LOADER_RETURN,
message: `Failed to add url loader: ${err.message}`,
messageSource: 'embedding'
}
@@ -402,7 +448,7 @@ export class EmbedJsFramework implements IKnowledgeFramework {
}
private sitemapTask(
getRagApplication: () => Promise<RAGApplication>,
ragApplication: RAGApplication,
options: KnowledgeBaseAddItemOptionsNonNullableAttribute
): LoaderTask {
const { base, item, forceReload } = options
@@ -412,9 +458,8 @@ export class EmbedJsFramework implements IKnowledgeFramework {
loaderTasks: [
{
state: LoaderTaskItemState.PENDING,
task: async () => {
const ragApplication = await getRagApplication()
return ragApplication
task: () =>
ragApplication
.addLoader(
new SitemapLoader({ url: content, chunkSize: base.chunkSize, chunkOverlap: base.chunkOverlap }) as any,
forceReload
@@ -432,12 +477,11 @@ export class EmbedJsFramework implements IKnowledgeFramework {
.catch((err) => {
logger.error('Failed to add sitemap loader:', err)
return {
...EmbedJsFramework.ERROR_LOADER_RETURN,
...KnowledgeService.ERROR_LOADER_RETURN,
message: `Failed to add sitemap loader: ${err.message}`,
messageSource: 'embedding'
}
})
},
}),
evaluateTaskWorkload: { workload: 20 * MB }
}
],
@@ -447,7 +491,7 @@ export class EmbedJsFramework implements IKnowledgeFramework {
}
private noteTask(
getRagApplication: () => Promise<RAGApplication>,
ragApplication: RAGApplication,
options: KnowledgeBaseAddItemOptionsNonNullableAttribute
): LoaderTask {
const { base, item, forceReload } = options
@@ -460,8 +504,7 @@ export class EmbedJsFramework implements IKnowledgeFramework {
loaderTasks: [
{
state: LoaderTaskItemState.PENDING,
task: async () => {
const ragApplication = await getRagApplication()
task: () => {
const loaderReturn = ragApplication.addLoader(
new NoteLoader({
text: content,
@@ -484,7 +527,7 @@ export class EmbedJsFramework implements IKnowledgeFramework {
.catch((err) => {
logger.error('Failed to add note loader:', err)
return {
...EmbedJsFramework.ERROR_LOADER_RETURN,
...KnowledgeService.ERROR_LOADER_RETURN,
message: `Failed to add note loader: ${err.message}`,
messageSource: 'embedding'
}
@@ -497,4 +540,199 @@ export class EmbedJsFramework implements IKnowledgeFramework {
}
return loaderTask
}
private processingQueueHandle() {
const getSubtasksUntilMaximumLoad = (): QueueTaskItem[] => {
const queueTaskList: QueueTaskItem[] = []
that: for (const [task, resolve] of this.knowledgeItemProcessingQueueMappingPromise) {
for (const item of task.loaderTasks) {
if (this.maximumLoad()) {
break that
}
const { state, task: taskPromise, evaluateTaskWorkload } = item
if (state !== LoaderTaskItemState.PENDING) {
continue
}
const { workload } = evaluateTaskWorkload
this.workload += workload
this.processingItemCount += 1
item.state = LoaderTaskItemState.PROCESSING
queueTaskList.push({
taskPromise: () =>
taskPromise().then(() => {
this.workload -= workload
this.processingItemCount -= 1
task.loaderTasks.delete(item)
if (task.loaderTasks.size === 0) {
this.knowledgeItemProcessingQueueMappingPromise.delete(task)
resolve()
}
this.processingQueueHandle()
}),
resolve: () => {},
evaluateTaskWorkload
})
}
}
return queueTaskList
}
const subTasks = getSubtasksUntilMaximumLoad()
if (subTasks.length > 0) {
const subTaskPromises = subTasks.map(({ taskPromise }) => taskPromise())
Promise.all(subTaskPromises).then(() => {
subTasks.forEach(({ resolve }) => resolve())
})
}
}
private appendProcessingQueue(task: LoaderTask): Promise<LoaderReturn> {
return new Promise((resolve) => {
this.knowledgeItemProcessingQueueMappingPromise.set(loaderTaskIntoOfSet(task), () => {
resolve(task.loaderDoneReturn!)
})
})
}
public add = (_: Electron.IpcMainInvokeEvent, options: KnowledgeBaseAddItemOptions): Promise<LoaderReturn> => {
return new Promise((resolve) => {
const { base, item, forceReload = false, userId = '' } = options
const optionsNonNullableAttribute = { base, item, forceReload, userId }
this.getRagApplication(base)
.then((ragApplication) => {
const task = (() => {
switch (item.type) {
case 'file':
return this.fileTask(ragApplication, optionsNonNullableAttribute)
case 'directory':
return this.directoryTask(ragApplication, optionsNonNullableAttribute)
case 'url':
return this.urlTask(ragApplication, optionsNonNullableAttribute)
case 'sitemap':
return this.sitemapTask(ragApplication, optionsNonNullableAttribute)
case 'note':
return this.noteTask(ragApplication, optionsNonNullableAttribute)
default:
return null
}
})()
if (task) {
this.appendProcessingQueue(task).then(() => {
resolve(task.loaderDoneReturn!)
})
this.processingQueueHandle()
} else {
resolve({
...KnowledgeService.ERROR_LOADER_RETURN,
message: 'Unsupported item type',
messageSource: 'embedding'
})
}
})
.catch((err) => {
logger.error('Failed to add item:', err)
resolve({
...KnowledgeService.ERROR_LOADER_RETURN,
message: `Failed to add item: ${err.message}`,
messageSource: 'embedding'
})
})
})
}
@TraceMethod({ spanName: 'remove', tag: 'Knowledge' })
public async remove(
_: Electron.IpcMainInvokeEvent,
{ uniqueId, uniqueIds, base }: { uniqueId: string; uniqueIds: string[]; base: KnowledgeBaseParams }
): Promise<void> {
const ragApplication = await this.getRagApplication(base)
logger.debug(`Remove Item UniqueId: ${uniqueId}`)
for (const id of uniqueIds) {
await ragApplication.deleteLoader(id)
}
}
@TraceMethod({ spanName: 'RagSearch', tag: 'Knowledge' })
public async search(
_: Electron.IpcMainInvokeEvent,
{ search, base }: { search: string; base: KnowledgeBaseParams }
): Promise<KnowledgeSearchResult[]> {
const ragApplication = await this.getRagApplication(base)
return await ragApplication.search(search)
}
@TraceMethod({ spanName: 'rerank', tag: 'Knowledge' })
public async rerank(
_: Electron.IpcMainInvokeEvent,
{ search, base, results }: { search: string; base: KnowledgeBaseParams; results: KnowledgeSearchResult[] }
): Promise<KnowledgeSearchResult[]> {
if (results.length === 0) {
return results
}
return await new Reranker(base).rerank(search, results)
}
public getStorageDir = (): string => {
return this.storageDir
}
private preprocessing = async (
file: FileMetadata,
base: KnowledgeBaseParams,
item: KnowledgeItem,
userId: string
): Promise<FileMetadata> => {
let fileToProcess: FileMetadata = file
if (base.preprocessProvider && file.ext.toLowerCase() === '.pdf') {
try {
const provider = new PreprocessProvider(base.preprocessProvider.provider, userId)
const filePath = fileStorage.getFilePathById(file)
// Check if file has already been preprocessed
const alreadyProcessed = await provider.checkIfAlreadyProcessed(file)
if (alreadyProcessed) {
logger.debug(`File already preprocess processed, using cached result: ${filePath}`)
return alreadyProcessed
}
// Execute preprocessing
logger.debug(`Starting preprocess processing for scanned PDF: ${filePath}`)
const { processedFile, quota } = await provider.parseFile(item.id, file)
fileToProcess = processedFile
const mainWindow = windowService.getMainWindow()
mainWindow?.webContents.send('file-preprocess-finished', {
itemId: item.id,
quota: quota
})
} catch (err) {
logger.error(`Preprocess processing failed: ${err}`)
// If preprocessing fails, use original file
// fileToProcess = file
throw new Error(`Preprocess processing failed: ${err}`)
}
}
return fileToProcess
}
public checkQuota = async (
_: Electron.IpcMainInvokeEvent,
base: KnowledgeBaseParams,
userId: string
): Promise<number> => {
try {
if (base.preprocessProvider && base.preprocessProvider.type === 'preprocess') {
const provider = new PreprocessProvider(base.preprocessProvider.provider, userId)
return await provider.checkQuota()
}
throw new Error('No preprocess provider configured')
} catch (err) {
logger.error(`Failed to check quota: ${err}`)
throw new Error(`Failed to check quota: ${err}`)
}
}
}
export default new KnowledgeService()
@@ -1,72 +0,0 @@
import { LoaderReturn } from '@shared/config/types'
import { KnowledgeBaseParams, KnowledgeItem, KnowledgeSearchResult } from '@types'
export interface KnowledgeBaseAddItemOptions {
base: KnowledgeBaseParams
item: KnowledgeItem
forceReload?: boolean
userId?: string
}
export interface KnowledgeBaseAddItemOptionsNonNullableAttribute {
base: KnowledgeBaseParams
item: KnowledgeItem
forceReload: boolean
userId: string
}
export interface EvaluateTaskWorkload {
workload: number
}
export type LoaderDoneReturn = LoaderReturn | null
export enum LoaderTaskItemState {
PENDING,
PROCESSING,
DONE
}
export interface LoaderTaskItem {
state: LoaderTaskItemState
task: () => Promise<unknown>
evaluateTaskWorkload: EvaluateTaskWorkload
}
export interface LoaderTask {
loaderTasks: LoaderTaskItem[]
loaderDoneReturn: LoaderDoneReturn
}
export interface LoaderTaskOfSet {
loaderTasks: Set<LoaderTaskItem>
loaderDoneReturn: LoaderDoneReturn
}
export interface QueueTaskItem {
taskPromise: () => Promise<unknown>
resolve: () => void
evaluateTaskWorkload: EvaluateTaskWorkload
}
export const loaderTaskIntoOfSet = (loaderTask: LoaderTask): LoaderTaskOfSet => {
return {
loaderTasks: new Set(loaderTask.loaderTasks),
loaderDoneReturn: loaderTask.loaderDoneReturn
}
}
export interface IKnowledgeFramework {
/** 为给定知识库初始化框架资源 */
initialize(base: KnowledgeBaseParams): Promise<void>
/** 重置知识库,删除其所有内容 */
reset(base: KnowledgeBaseParams): Promise<void>
/** 删除与知识库关联的资源,包括文件 */
delete(id: string): Promise<void>
/** 生成用于添加条目的任务对象,由队列处理 */
getLoaderTask(options: KnowledgeBaseAddItemOptionsNonNullableAttribute): LoaderTask
/** 从知识库中删除特定条目 */
remove(options: { uniqueIds: string[]; base: KnowledgeBaseParams }): Promise<void>
/** 搜索知识库 */
search(options: { search: string; base: KnowledgeBaseParams }): Promise<KnowledgeSearchResult[]>
}
@@ -1,48 +0,0 @@
import path from 'node:path'
import { KnowledgeBaseParams } from '@types'
import { app } from 'electron'
import { EmbedJsFramework } from './EmbedJsFramework'
import { IKnowledgeFramework } from './IKnowledgeFramework'
import { LangChainFramework } from './LangChainFramework'
class KnowledgeFrameworkFactory {
private static instance: KnowledgeFrameworkFactory
private frameworks: Map<string, IKnowledgeFramework> = new Map()
private storageDir: string
private constructor(storageDir: string) {
this.storageDir = storageDir
}
public static getInstance(storageDir: string): KnowledgeFrameworkFactory {
if (!KnowledgeFrameworkFactory.instance) {
KnowledgeFrameworkFactory.instance = new KnowledgeFrameworkFactory(storageDir)
}
return KnowledgeFrameworkFactory.instance
}
public getFramework(base: KnowledgeBaseParams): IKnowledgeFramework {
const frameworkType = base.framework || 'embedjs' // 如果未指定,默认为 embedjs
if (this.frameworks.has(frameworkType)) {
return this.frameworks.get(frameworkType)!
}
let framework: IKnowledgeFramework
switch (frameworkType) {
case 'langchain':
framework = new LangChainFramework(this.storageDir)
break
case 'embedjs':
default:
framework = new EmbedJsFramework(this.storageDir)
break
}
this.frameworks.set(frameworkType, framework)
return framework
}
}
export const knowledgeFrameworkFactory = KnowledgeFrameworkFactory.getInstance(
path.join(app.getPath('userData'), 'Data', 'KnowledgeBase')
)
@@ -1,190 +0,0 @@
import * as fs from 'node:fs'
import path from 'node:path'
import { loggerService } from '@logger'
import { preprocessingService } from '@main/knowledge/preprocess/PreprocessingService'
import Reranker from '@main/knowledge/reranker/Reranker'
import { TraceMethod } from '@mcp-trace/trace-core'
import { MB } from '@shared/config/constant'
import { LoaderReturn } from '@shared/config/types'
import { KnowledgeBaseParams, KnowledgeSearchResult } from '@types'
import { app } from 'electron'
import {
KnowledgeBaseAddItemOptions,
LoaderTask,
loaderTaskIntoOfSet,
LoaderTaskItemState,
LoaderTaskOfSet,
QueueTaskItem
} from './IKnowledgeFramework'
import { knowledgeFrameworkFactory } from './KnowledgeFrameworkFactory'
const logger = loggerService.withContext('MainKnowledgeService')
class KnowledgeService {
private storageDir = path.join(app.getPath('userData'), 'Data', 'KnowledgeBase')
private workload = 0
private processingItemCount = 0
private knowledgeItemProcessingQueueMappingPromise: Map<LoaderTaskOfSet, () => void> = new Map()
private static MAXIMUM_WORKLOAD = 80 * MB
private static MAXIMUM_PROCESSING_ITEM_COUNT = 30
private static ERROR_LOADER_RETURN: LoaderReturn = {
entriesAdded: 0,
uniqueId: '',
uniqueIds: [''],
loaderType: '',
status: 'failed'
}
constructor() {
this.initStorageDir()
}
private initStorageDir = (): void => {
if (!fs.existsSync(this.storageDir)) {
fs.mkdirSync(this.storageDir, { recursive: true })
}
}
private maximumLoad() {
return (
this.processingItemCount >= KnowledgeService.MAXIMUM_PROCESSING_ITEM_COUNT ||
this.workload >= KnowledgeService.MAXIMUM_WORKLOAD
)
}
private processingQueueHandle() {
const getSubtasksUntilMaximumLoad = (): QueueTaskItem[] => {
const queueTaskList: QueueTaskItem[] = []
that: for (const [task, resolve] of this.knowledgeItemProcessingQueueMappingPromise) {
for (const item of task.loaderTasks) {
if (this.maximumLoad()) {
break that
}
const { state, task: taskPromise, evaluateTaskWorkload } = item
if (state !== LoaderTaskItemState.PENDING) {
continue
}
const { workload } = evaluateTaskWorkload
this.workload += workload
this.processingItemCount += 1
item.state = LoaderTaskItemState.PROCESSING
queueTaskList.push({
taskPromise: () =>
taskPromise().then(() => {
this.workload -= workload
this.processingItemCount -= 1
task.loaderTasks.delete(item)
if (task.loaderTasks.size === 0) {
this.knowledgeItemProcessingQueueMappingPromise.delete(task)
resolve()
}
this.processingQueueHandle()
}),
resolve: () => {},
evaluateTaskWorkload
})
}
}
return queueTaskList
}
const subTasks = getSubtasksUntilMaximumLoad()
if (subTasks.length > 0) {
const subTaskPromises = subTasks.map(({ taskPromise }) => taskPromise())
Promise.all(subTaskPromises).then(() => {
subTasks.forEach(({ resolve }) => resolve())
})
}
}
private appendProcessingQueue(task: LoaderTask): Promise<LoaderReturn> {
return new Promise((resolve) => {
this.knowledgeItemProcessingQueueMappingPromise.set(loaderTaskIntoOfSet(task), () => {
resolve(task.loaderDoneReturn!)
})
})
}
public async create(_: Electron.IpcMainInvokeEvent, base: KnowledgeBaseParams): Promise<void> {
logger.info(`Creating knowledge base: ${JSON.stringify(base)}`)
const framework = knowledgeFrameworkFactory.getFramework(base)
await framework.initialize(base)
}
public async reset(_: Electron.IpcMainInvokeEvent, base: KnowledgeBaseParams): Promise<void> {
const framework = knowledgeFrameworkFactory.getFramework(base)
await framework.reset(base)
}
public async delete(_: Electron.IpcMainInvokeEvent, base: KnowledgeBaseParams, id: string): Promise<void> {
logger.info(`Deleting knowledge base: ${JSON.stringify(base)}`)
const framework = knowledgeFrameworkFactory.getFramework(base)
await framework.delete(id)
}
public add = async (_: Electron.IpcMainInvokeEvent, options: KnowledgeBaseAddItemOptions): Promise<LoaderReturn> => {
logger.info(`Adding item to knowledge base: ${JSON.stringify(options)}`)
return new Promise((resolve) => {
const { base, item, forceReload = false, userId = '' } = options
const framework = knowledgeFrameworkFactory.getFramework(base)
const task = framework.getLoaderTask({ base, item, forceReload, userId })
if (task) {
this.appendProcessingQueue(task).then(() => {
resolve(task.loaderDoneReturn!)
})
this.processingQueueHandle()
} else {
resolve({
...KnowledgeService.ERROR_LOADER_RETURN,
message: 'Unsupported item type',
messageSource: 'embedding'
})
}
})
}
public async remove(
_: Electron.IpcMainInvokeEvent,
{ uniqueIds, base }: { uniqueIds: string[]; base: KnowledgeBaseParams }
): Promise<void> {
logger.info(`Removing items from knowledge base: ${JSON.stringify({ uniqueIds, base })}`)
const framework = knowledgeFrameworkFactory.getFramework(base)
await framework.remove({ uniqueIds, base })
}
public async search(
_: Electron.IpcMainInvokeEvent,
{ search, base }: { search: string; base: KnowledgeBaseParams }
): Promise<KnowledgeSearchResult[]> {
logger.info(`Searching knowledge base: ${JSON.stringify({ search, base })}`)
const framework = knowledgeFrameworkFactory.getFramework(base)
return framework.search({ search, base })
}
@TraceMethod({ spanName: 'rerank', tag: 'Knowledge' })
public async rerank(
_: Electron.IpcMainInvokeEvent,
{ search, base, results }: { search: string; base: KnowledgeBaseParams; results: KnowledgeSearchResult[] }
): Promise<KnowledgeSearchResult[]> {
logger.info(`Reranking knowledge base: ${JSON.stringify({ search, base, results })}`)
if (results.length === 0) {
return results
}
return await new Reranker(base).rerank(search, results)
}
public getStorageDir = (): string => {
return this.storageDir
}
public async checkQuota(_: Electron.IpcMainInvokeEvent, base: KnowledgeBaseParams, userId: string): Promise<number> {
return preprocessingService.checkQuota(base, userId)
}
}
export default new KnowledgeService()
@@ -1,557 +0,0 @@
import * as fs from 'node:fs'
import path from 'node:path'
import { FaissStore } from '@langchain/community/vectorstores/faiss'
import type { Document } from '@langchain/core/documents'
import { loggerService } from '@logger'
import TextEmbeddings from '@main/knowledge/langchain/embeddings/TextEmbeddings'
import {
addFileLoader,
addNoteLoader,
addSitemapLoader,
addVideoLoader,
addWebLoader
} from '@main/knowledge/langchain/loader'
import { RetrieverFactory } from '@main/knowledge/langchain/retriever'
import { preprocessingService } from '@main/knowledge/preprocess/PreprocessingService'
import { getAllFiles } from '@main/utils/file'
import { getUrlSource } from '@main/utils/knowledge'
import { MB } from '@shared/config/constant'
import { LoaderReturn } from '@shared/config/types'
import { IpcChannel } from '@shared/IpcChannel'
import {
FileMetadata,
isKnowledgeDirectoryItem,
isKnowledgeFileItem,
isKnowledgeNoteItem,
isKnowledgeSitemapItem,
isKnowledgeUrlItem,
isKnowledgeVideoItem,
KnowledgeBaseParams,
KnowledgeSearchResult
} from '@types'
import { uuidv4 } from 'zod'
import { windowService } from '../WindowService'
import {
IKnowledgeFramework,
KnowledgeBaseAddItemOptionsNonNullableAttribute,
LoaderDoneReturn,
LoaderTask,
LoaderTaskItem,
LoaderTaskItemState
} from './IKnowledgeFramework'
const logger = loggerService.withContext('LangChainFramework')
export class LangChainFramework implements IKnowledgeFramework {
private storageDir: string
private static ERROR_LOADER_RETURN: LoaderReturn = {
entriesAdded: 0,
uniqueId: '',
uniqueIds: [''],
loaderType: '',
status: 'failed'
}
constructor(storageDir: string) {
this.storageDir = storageDir
this.initStorageDir()
}
private initStorageDir = (): void => {
if (!fs.existsSync(this.storageDir)) {
fs.mkdirSync(this.storageDir, { recursive: true })
}
}
private async createDatabase(base: KnowledgeBaseParams): Promise<void> {
const dbPath = path.join(this.storageDir, base.id)
const embeddings = this.getEmbeddings(base)
const vectorStore = new FaissStore(embeddings, {})
const mockDocument: Document = {
pageContent: 'Create Database Document',
metadata: {}
}
await vectorStore.addDocuments([mockDocument], { ids: ['1'] })
await vectorStore.save(dbPath)
await vectorStore.delete({ ids: ['1'] })
await vectorStore.save(dbPath)
}
private getEmbeddings(base: KnowledgeBaseParams): TextEmbeddings {
return new TextEmbeddings({
embedApiClient: base.embedApiClient,
dimensions: base.dimensions
})
}
private async getVectorStore(base: KnowledgeBaseParams): Promise<FaissStore> {
const embeddings = this.getEmbeddings(base)
const vectorStore = await FaissStore.load(path.join(this.storageDir, base.id), embeddings)
return vectorStore
}
async initialize(base: KnowledgeBaseParams): Promise<void> {
await this.createDatabase(base)
}
async reset(base: KnowledgeBaseParams): Promise<void> {
const dbPath = path.join(this.storageDir, base.id)
if (fs.existsSync(dbPath)) {
fs.rmSync(dbPath, { recursive: true })
}
// 立即重建空索引,避免随后加载时报错
await this.createDatabase(base)
}
async delete(id: string): Promise<void> {
const dbPath = path.join(this.storageDir, id)
if (fs.existsSync(dbPath)) {
fs.rmSync(dbPath, { recursive: true })
}
}
getLoaderTask(options: KnowledgeBaseAddItemOptionsNonNullableAttribute): LoaderTask {
const { item } = options
const getStore = () => this.getVectorStore(options.base)
switch (item.type) {
case 'file':
return this.fileTask(getStore, options)
case 'directory':
return this.directoryTask(getStore, options)
case 'url':
return this.urlTask(getStore, options)
case 'sitemap':
return this.sitemapTask(getStore, options)
case 'note':
return this.noteTask(getStore, options)
case 'video':
return this.videoTask(getStore, options)
default:
return {
loaderTasks: [],
loaderDoneReturn: null
}
}
}
async remove(options: { uniqueIds: string[]; base: KnowledgeBaseParams }): Promise<void> {
const { uniqueIds, base } = options
const vectorStore = await this.getVectorStore(base)
logger.info(`[ KnowledgeService Remove Item UniqueIds: ${uniqueIds}]`)
await vectorStore.delete({ ids: uniqueIds })
await vectorStore.save(path.join(this.storageDir, base.id))
}
async search(options: { search: string; base: KnowledgeBaseParams }): Promise<KnowledgeSearchResult[]> {
const { search, base } = options
logger.info(`search base: ${JSON.stringify(base)}`)
try {
const vectorStore = await this.getVectorStore(base)
// 如果是 bm25 或 hybrid 模式,则从数据库获取所有文档
const documents: Document[] = await this.getAllDocuments(base)
if (documents.length === 0) return []
const retrieverFactory = new RetrieverFactory()
const retriever = retrieverFactory.createRetriever(base, vectorStore, documents)
const results = await retriever.invoke(search)
logger.info(`Search Results: ${JSON.stringify(results)}`)
// VectorStoreRetriever 和 EnsembleRetriever 会将分数附加到 metadata.score
// BM25Retriever 默认不返回分数,所以我们需要处理这种情况
return results.map((item) => {
return {
pageContent: item.pageContent,
metadata: item.metadata,
// 如果 metadata 中没有 score,提供一个默认值
score: typeof item.metadata.score === 'number' ? item.metadata.score : 0
}
})
} catch (error: any) {
logger.error(`Error during search in knowledge base ${base.id}: ${error.message}`)
return []
}
}
private fileTask(
getVectorStore: () => Promise<FaissStore>,
options: KnowledgeBaseAddItemOptionsNonNullableAttribute
): LoaderTask {
const { base, item, userId } = options
if (!isKnowledgeFileItem(item)) {
logger.error(`Invalid item type for fileTask: expected 'file', got '${item.type}'`)
return {
loaderTasks: [],
loaderDoneReturn: {
...LangChainFramework.ERROR_LOADER_RETURN,
message: `Invalid item type: expected 'file', got '${item.type}'`,
messageSource: 'validation'
}
}
}
const file = item.content
const loaderTask: LoaderTask = {
loaderTasks: [
{
state: LoaderTaskItemState.PENDING,
task: async () => {
try {
const vectorStore = await getVectorStore()
// 添加预处理逻辑
const fileToProcess: FileMetadata = await preprocessingService.preprocessFile(file, base, item, userId)
// 使用处理后的文件进行加载
return addFileLoader(base, vectorStore, fileToProcess)
.then((result) => {
loaderTask.loaderDoneReturn = result
return result
})
.then(async () => {
await vectorStore.save(path.join(this.storageDir, base.id))
})
.catch((e) => {
logger.error(`Error in addFileLoader for ${file.name}: ${e}`)
const errorResult: LoaderReturn = {
...LangChainFramework.ERROR_LOADER_RETURN,
message: e.message,
messageSource: 'embedding'
}
loaderTask.loaderDoneReturn = errorResult
return errorResult
})
} catch (e: any) {
logger.error(`Preprocessing failed for ${file.name}: ${e}`)
const errorResult: LoaderReturn = {
...LangChainFramework.ERROR_LOADER_RETURN,
message: e.message,
messageSource: 'preprocess'
}
loaderTask.loaderDoneReturn = errorResult
return errorResult
}
},
evaluateTaskWorkload: { workload: file.size }
}
],
loaderDoneReturn: null
}
return loaderTask
}
private directoryTask(
getVectorStore: () => Promise<FaissStore>,
options: KnowledgeBaseAddItemOptionsNonNullableAttribute
): LoaderTask {
const { base, item } = options
if (!isKnowledgeDirectoryItem(item)) {
logger.error(`Invalid item type for directoryTask: expected 'directory', got '${item.type}'`)
return {
loaderTasks: [],
loaderDoneReturn: {
...LangChainFramework.ERROR_LOADER_RETURN,
message: `Invalid item type: expected 'directory', got '${item.type}'`,
messageSource: 'validation'
}
}
}
const directory = item.content
const files = getAllFiles(directory)
const totalFiles = files.length
let processedFiles = 0
const sendDirectoryProcessingPercent = (totalFiles: number, processedFiles: number) => {
const mainWindow = windowService.getMainWindow()
mainWindow?.webContents.send(IpcChannel.DirectoryProcessingPercent, {
itemId: item.id,
percent: (processedFiles / totalFiles) * 100
})
}
const loaderDoneReturn: LoaderDoneReturn = {
entriesAdded: 0,
uniqueId: `DirectoryLoader_${uuidv4()}`,
uniqueIds: [],
loaderType: 'DirectoryLoader'
}
const loaderTasks: LoaderTaskItem[] = []
for (const file of files) {
loaderTasks.push({
state: LoaderTaskItemState.PENDING,
task: async () => {
const vectorStore = await getVectorStore()
return addFileLoader(base, vectorStore, file)
.then((result) => {
loaderDoneReturn.entriesAdded += 1
processedFiles += 1
sendDirectoryProcessingPercent(totalFiles, processedFiles)
loaderDoneReturn.uniqueIds.push(result.uniqueId)
return result
})
.then(async () => {
await vectorStore.save(path.join(this.storageDir, base.id))
})
.catch((err) => {
logger.error(err)
return {
...LangChainFramework.ERROR_LOADER_RETURN,
message: `Failed to add dir loader: ${err.message}`,
messageSource: 'embedding'
}
})
},
evaluateTaskWorkload: { workload: file.size }
})
}
return {
loaderTasks,
loaderDoneReturn
}
}
private urlTask(
getVectorStore: () => Promise<FaissStore>,
options: KnowledgeBaseAddItemOptionsNonNullableAttribute
): LoaderTask {
const { base, item } = options
if (!isKnowledgeUrlItem(item)) {
logger.error(`Invalid item type for urlTask: expected 'url', got '${item.type}'`)
return {
loaderTasks: [],
loaderDoneReturn: {
...LangChainFramework.ERROR_LOADER_RETURN,
message: `Invalid item type: expected 'url', got '${item.type}'`,
messageSource: 'validation'
}
}
}
const url = item.content
const loaderTask: LoaderTask = {
loaderTasks: [
{
state: LoaderTaskItemState.PENDING,
task: async () => {
// 使用处理后的网页进行加载
const vectorStore = await getVectorStore()
return addWebLoader(base, vectorStore, url, getUrlSource(url))
.then((result) => {
loaderTask.loaderDoneReturn = result
return result
})
.then(async () => {
await vectorStore.save(path.join(this.storageDir, base.id))
})
.catch((e) => {
logger.error(`Error in addWebLoader for ${url}: ${e}`)
const errorResult: LoaderReturn = {
...LangChainFramework.ERROR_LOADER_RETURN,
message: e.message,
messageSource: 'embedding'
}
loaderTask.loaderDoneReturn = errorResult
return errorResult
})
},
evaluateTaskWorkload: { workload: 2 * MB }
}
],
loaderDoneReturn: null
}
return loaderTask
}
private sitemapTask(
getVectorStore: () => Promise<FaissStore>,
options: KnowledgeBaseAddItemOptionsNonNullableAttribute
): LoaderTask {
const { base, item } = options
if (!isKnowledgeSitemapItem(item)) {
logger.error(`Invalid item type for sitemapTask: expected 'sitemap', got '${item.type}'`)
return {
loaderTasks: [],
loaderDoneReturn: {
...LangChainFramework.ERROR_LOADER_RETURN,
message: `Invalid item type: expected 'sitemap', got '${item.type}'`,
messageSource: 'validation'
}
}
}
const url = item.content
const loaderTask: LoaderTask = {
loaderTasks: [
{
state: LoaderTaskItemState.PENDING,
task: async () => {
// 使用处理后的网页进行加载
const vectorStore = await getVectorStore()
return addSitemapLoader(base, vectorStore, url)
.then((result) => {
loaderTask.loaderDoneReturn = result
return result
})
.then(async () => {
await vectorStore.save(path.join(this.storageDir, base.id))
})
.catch((e) => {
logger.error(`Error in addWebLoader for ${url}: ${e}`)
const errorResult: LoaderReturn = {
...LangChainFramework.ERROR_LOADER_RETURN,
message: e.message,
messageSource: 'embedding'
}
loaderTask.loaderDoneReturn = errorResult
return errorResult
})
},
evaluateTaskWorkload: { workload: 2 * MB }
}
],
loaderDoneReturn: null
}
return loaderTask
}
private noteTask(
getVectorStore: () => Promise<FaissStore>,
options: KnowledgeBaseAddItemOptionsNonNullableAttribute
): LoaderTask {
const { base, item } = options
if (!isKnowledgeNoteItem(item)) {
logger.error(`Invalid item type for noteTask: expected 'note', got '${item.type}'`)
return {
loaderTasks: [],
loaderDoneReturn: {
...LangChainFramework.ERROR_LOADER_RETURN,
message: `Invalid item type: expected 'note', got '${item.type}'`,
messageSource: 'validation'
}
}
}
const content = item.content
const sourceUrl = item.sourceUrl ?? ''
logger.info(`noteTask ${content}, ${sourceUrl}`)
const encoder = new TextEncoder()
const contentBytes = encoder.encode(content)
const loaderTask: LoaderTask = {
loaderTasks: [
{
state: LoaderTaskItemState.PENDING,
task: async () => {
// 使用处理后的笔记进行加载
const vectorStore = await getVectorStore()
return addNoteLoader(base, vectorStore, content, sourceUrl)
.then((result) => {
loaderTask.loaderDoneReturn = result
return result
})
.then(async () => {
await vectorStore.save(path.join(this.storageDir, base.id))
})
.catch((e) => {
logger.error(`Error in addNoteLoader for ${sourceUrl}: ${e}`)
const errorResult: LoaderReturn = {
...LangChainFramework.ERROR_LOADER_RETURN,
message: e.message,
messageSource: 'embedding'
}
loaderTask.loaderDoneReturn = errorResult
return errorResult
})
},
evaluateTaskWorkload: { workload: contentBytes.length }
}
],
loaderDoneReturn: null
}
return loaderTask
}
private videoTask(
getVectorStore: () => Promise<FaissStore>,
options: KnowledgeBaseAddItemOptionsNonNullableAttribute
): LoaderTask {
const { base, item } = options
if (!isKnowledgeVideoItem(item)) {
logger.error(`Invalid item type for videoTask: expected 'video', got '${item.type}'`)
return {
loaderTasks: [],
loaderDoneReturn: {
...LangChainFramework.ERROR_LOADER_RETURN,
message: `Invalid item type: expected 'video', got '${item.type}'`,
messageSource: 'validation'
}
}
}
const files = item.content
const loaderTask: LoaderTask = {
loaderTasks: [
{
state: LoaderTaskItemState.PENDING,
task: async () => {
const vectorStore = await getVectorStore()
return addVideoLoader(base, vectorStore, files)
.then((result) => {
loaderTask.loaderDoneReturn = result
return result
})
.then(async () => {
await vectorStore.save(path.join(this.storageDir, base.id))
})
.catch((e) => {
logger.error(`Preprocessing failed for ${files[0].name}: ${e}`)
const errorResult: LoaderReturn = {
...LangChainFramework.ERROR_LOADER_RETURN,
message: e.message,
messageSource: 'preprocess'
}
loaderTask.loaderDoneReturn = errorResult
return errorResult
})
},
evaluateTaskWorkload: { workload: files[0].size }
}
],
loaderDoneReturn: null
}
return loaderTask
}
private async getAllDocuments(base: KnowledgeBaseParams): Promise<Document[]> {
logger.info(`Fetching all documents from database for knowledge base: ${base.id}`)
try {
const results = (await this.getVectorStore(base)).docstore._docs
const documents: Document[] = Array.from(results.values())
logger.info(`Fetched ${documents.length} documents for BM25/Hybrid retriever.`)
return documents
} catch (e) {
logger.error(`Could not fetch documents from database for base ${base.id}: ${e}`)
// 如果表不存在或查询失败,返回空数组
return []
}
}
}