e25007ce0d
* feat: Add search summary model and related functionality - Introduce new search summary model configuration in settings - Implement search summary prompt and model selection - Add support for generating search keywords across providers - Update localization files with new search summary model translations - Enhance web search functionality with search summary generation * refactor: Improve web search error handling and async flow * fix: Update migration version for settings search summary prompt * refactor(webSearch): Remove search summary model references from settings and localization files - Deleted search summary model entries from English, Japanese, Russian, Chinese, and Traditional Chinese localization files. - Refactored ModelSettings component to remove search summary model handling. - Updated related services and settings to eliminate search summary model dependencies. * refactor(llm): Remove search summary model from state and related hooks
362 lines
9.6 KiB
TypeScript
362 lines
9.6 KiB
TypeScript
import { getOpenAIWebSearchParams } from '@renderer/config/models'
|
||
import { SEARCH_SUMMARY_PROMPT } from '@renderer/config/prompts'
|
||
import i18n from '@renderer/i18n'
|
||
import store from '@renderer/store'
|
||
import { setGenerating } from '@renderer/store/runtime'
|
||
import { Assistant, Message, Model, Provider, Suggestion } from '@renderer/types'
|
||
import { formatMessageError, isAbortError } from '@renderer/utils/error'
|
||
import { cloneDeep, findLast, isEmpty } from 'lodash'
|
||
|
||
import AiProvider from '../providers/AiProvider'
|
||
import {
|
||
getAssistantProvider,
|
||
getDefaultAssistant,
|
||
getDefaultModel,
|
||
getProviderByModel,
|
||
getTopNamingModel,
|
||
getTranslateModel
|
||
} from './AssistantService'
|
||
import { EVENT_NAMES, EventEmitter } from './EventService'
|
||
import { filterMessages, filterUsefulMessages } from './MessagesService'
|
||
import { estimateMessagesUsage } from './TokenService'
|
||
import WebSearchService from './WebSearchService'
|
||
|
||
/**
 * Runs a chat completion for `message` and reports progress through `onResponse`.
 *
 * Flow, in order:
 * 1. When web search is enabled globally and for the assistant, and
 *    `getOpenAIWebSearchParams` yields nothing for the assistant's model, a
 *    search query is derived from the last user message (plus the last
 *    assistant answer, if any) via `fetchSearchSummary`, the search is run and
 *    its result is stored in `message.metadata.webSearch` and in `window.keyv`.
 *    A failing search is logged and does not abort the completion.
 * 2. `AI.completions` is called; every chunk is merged into `message` and
 *    forwarded to `onResponse` with status `'pending'`.
 * 3. If no completion token count was reported, usage is estimated with
 *    `estimateMessagesUsage`.
 *
 * Errors thrown inside the flow are not rethrown: abort errors set the status
 * to `'paused'`, anything else sets `'error'` and fills `message.error`.
 *
 * @param message The message being generated; it is mutated in place.
 * @param messages Conversation history passed to the provider.
 * @param assistant Assistant whose provider, model and settings are used.
 * @param onResponse Callback invoked on every status/content update.
 * @returns The final message (status `'success'`, `'paused'` or `'error'`).
 */
export async function fetchChatCompletion({
  message,
  messages,
  assistant,
  onResponse
}: {
  message: Message
  messages: Message[]
  assistant: Assistant
  onResponse: (message: Message) => void
}) {
  const provider = getAssistantProvider(assistant)
  const webSearchProvider = WebSearchService.getWebSearchProvider()
  const AI = new AiProvider(provider)

  try {
    let _messages: Message[] = []
    let isFirstChunk = true

    // Search web
    if (WebSearchService.isWebSearchEnabled() && assistant.enableWebSearch && assistant.model) {
      const webSearchParams = getOpenAIWebSearchParams(assistant, assistant.model)

      // Only run our own search when no web search params were produced for this model.
      if (isEmpty(webSearchParams)) {
        const lastMessage = findLast(messages, (m) => m.role === 'user')
        const lastAnswer = findLast(messages, (m) => m.role === 'assistant')
        const hasKnowledgeBase = !isEmpty(lastMessage?.knowledgeBaseIds)

        if (lastMessage) {
          if (hasKnowledgeBase) {
            window.message.info({
              content: i18n.t('message.ignore.knowledge.base'),
              key: 'knowledge-base-no-match-info'
            })
          }

          try {
            // Wait for the search keywords to be generated
            const searchSummaryAssistant = getDefaultAssistant()
            searchSummaryAssistant.model = assistant.model || getDefaultModel()
            searchSummaryAssistant.prompt = SEARCH_SUMMARY_PROMPT
            const keywords = await fetchSearchSummary({
              messages: lastAnswer ? [lastAnswer, lastMessage] : [lastMessage],
              assistant: searchSummaryAssistant
            })

            // Fall back to the raw user message when no keywords were produced.
            const query = keywords || lastMessage.content

            // Mark the message as "searching"
            onResponse({ ...message, status: 'searching' })

            // Wait for the search to finish
            const webSearch = await WebSearchService.search(webSearchProvider, query)

            // Store the search result on the message and in the key-value cache
            message.metadata = {
              ...message.metadata,
              webSearch: webSearch
            }
            window.keyv.set(`web-search-${lastMessage?.id}`, webSearch)
          } catch (error) {
            // Best effort: a failed search must not prevent the completion.
            console.error('Web search failed:', error)
          }
        }
      }
    }

    const allMCPTools = await window.api.mcp.listTools()
    await AI.completions({
      messages: filterUsefulMessages(messages),
      assistant,
      onFilterMessages: (messages) => (_messages = messages),
      onChunk: ({ text, reasoning_content, usage, metrics, search, citations, mcpToolResponse }) => {
        // Parenthesised on purpose: `a + b || ''` parses as `(a + b) || ''`,
        // which would append the string "undefined" for chunks without text.
        message.content = message.content + (text || '')
        message.usage = usage
        message.metrics = metrics

        if (reasoning_content) {
          message.reasoning_content = (message.reasoning_content || '') + reasoning_content
        }

        if (search) {
          message.metadata = { ...message.metadata, groundingMetadata: search }
        }

        if (mcpToolResponse) {
          message.metadata = { ...message.metadata, mcpTools: cloneDeep(mcpToolResponse) }
        }

        // Handle citations from Perplexity API
        if (isFirstChunk && citations) {
          message.metadata = {
            ...message.metadata,
            citations
          }
          isFirstChunk = false
        }

        onResponse({ ...message, status: 'pending' })
      },
      mcpTools: allMCPTools
    })

    message.status = 'success'

    // No completion token count was reported: estimate usage locally.
    if (!message.usage?.completion_tokens) {
      message.usage = await estimateMessagesUsage({
        assistant,
        messages: [..._messages, message]
      })

      // Set metrics.completion_tokens when metrics exist but lack the count
      const estimatedCompletionTokens = message.usage?.completion_tokens
      if (message.metrics && estimatedCompletionTokens && !message.metrics.completion_tokens) {
        message = {
          ...message,
          metrics: {
            ...message.metrics,
            completion_tokens: estimatedCompletionTokens
          }
        }
      }
    }
  } catch (error: any) {
    if (isAbortError(error)) {
      message.status = 'paused'
    } else {
      message.status = 'error'
      message.error = formatMessageError(error)
    }
  }

  // Emit chat completion event
  EventEmitter.emit(EVENT_NAMES.RECEIVE_MESSAGE, message)
  onResponse(message)

  // Reset generating state
  store.dispatch(setGenerating(false))

  return message
}
|
||
|
||
/** Arguments accepted by `fetchTranslate`. */
interface FetchTranslateProps {
  /** Message handed to the provider's `translate` call. */
  message: Message
  /** Assistant handed to the provider's `translate` call. */
  assistant: Assistant
  /**
   * Optional callback forwarded to the provider's `translate` call.
   * NOTE(review): presumably receives the translated text as it arrives — confirm against AiProvider.
   */
  onResponse?: (text: string) => void
}
|
||
|
||
/**
 * Translates `message` using the model returned by `getTranslateModel()`.
 *
 * @param message Message forwarded to the provider's `translate` call.
 * @param assistant Assistant forwarded to the provider's `translate` call.
 * @param onResponse Optional callback forwarded to the provider's `translate` call.
 * @returns The provider's translation result, or `''` when the call throws.
 * @throws Error when no translate model is available or its provider fails the API key check.
 */
export async function fetchTranslate({ message, assistant, onResponse }: FetchTranslateProps) {
  const translateModel = getTranslateModel()
  if (!translateModel) {
    throw new Error(i18n.t('error.provider_disabled'))
  }

  const translateProvider = getProviderByModel(translateModel)
  if (!hasApiKey(translateProvider)) {
    throw new Error(i18n.t('error.no_api_key'))
  }

  const client = new AiProvider(translateProvider)

  try {
    const translated = await client.translate(message, assistant, onResponse)
    return translated
  } catch {
    // Translation failures are reported to the caller as an empty string.
    return ''
  }
}
|
||
|
||
/**
 * Asks the provider for a summary of `messages`.
 *
 * The model is the first available of: `getTopNamingModel()`, `assistant.model`,
 * `getDefaultModel()`.
 *
 * @param messages Messages to summarise; they are passed through `filterMessages` first.
 * @param assistant Assistant forwarded to the provider's `summaries` call.
 * @returns The provider's result, or `null` when the provider fails the API key
 *   check or the call throws.
 */
export async function fetchMessagesSummary({ messages, assistant }: { messages: Message[]; assistant: Assistant }) {
  const summaryModel = getTopNamingModel() || assistant.model || getDefaultModel()
  const summaryProvider = getProviderByModel(summaryModel)

  if (!hasApiKey(summaryProvider)) {
    return null
  }

  const client = new AiProvider(summaryProvider)

  try {
    const summary = await client.summaries(filterMessages(messages), assistant)
    return summary
  } catch {
    return null
  }
}
|
||
|
||
/**
 * Asks the provider to condense `messages` into a search summary.
 *
 * Uses `assistant.model`, falling back to `getDefaultModel()`.
 *
 * @param messages Messages forwarded unchanged to the provider's `summaryForSearch` call.
 * @param assistant Assistant forwarded to the provider's `summaryForSearch` call.
 * @returns The provider's result, or `null` when the provider fails the API key
 *   check or the call throws.
 */
export async function fetchSearchSummary({ messages, assistant }: { messages: Message[]; assistant: Assistant }) {
  const searchModel = assistant.model || getDefaultModel()
  const searchProvider = getProviderByModel(searchModel)

  if (!hasApiKey(searchProvider)) {
    return null
  }

  const client = new AiProvider(searchProvider)

  try {
    const summary = await client.summaryForSearch(messages, assistant)
    return summary
  } catch {
    return null
  }
}
|
||
|
||
/**
 * Generates text with the default model.
 *
 * @param prompt Prompt forwarded to the provider's `generateText` call.
 * @param content Content forwarded to the provider's `generateText` call.
 * @returns The generated text, or `''` when the default model's provider fails
 *   the API key check or the call throws.
 */
export async function fetchGenerate({ prompt, content }: { prompt: string; content: string }): Promise<string> {
  const defaultProvider = getProviderByModel(getDefaultModel())

  if (!hasApiKey(defaultProvider)) {
    return ''
  }

  const client = new AiProvider(defaultProvider)

  try {
    const generated = await client.generateText({ prompt, content })
    return generated
  } catch {
    return ''
  }
}
|
||
|
||
/**
 * Fetches follow-up suggestions for a conversation.
 *
 * Suggestions are only requested when the assistant has a model whose
 * `owned_by` is `'graphrag'` and whose id does not end with `'global'`;
 * in every other case an empty list is returned without calling the provider.
 *
 * @param messages Conversation messages; they are passed through `filterMessages` first.
 * @param assistant Assistant whose model and provider are used.
 * @returns The provider's suggestions, or `[]` when skipped or when the call throws.
 */
export async function fetchSuggestions({
  messages,
  assistant
}: {
  messages: Message[]
  assistant: Assistant
}): Promise<Suggestion[]> {
  const model = assistant.model

  // Same checks, same order as three separate guards: missing model,
  // non-graphrag owner, then the 'global' id suffix.
  if (!model || model.owned_by !== 'graphrag' || model.id.endsWith('global')) {
    return []
  }

  const client = new AiProvider(getAssistantProvider(assistant))

  try {
    const suggestions = await client.suggestions(filterMessages(messages), assistant)
    return suggestions
  } catch {
    return []
  }
}
|
||
|
||
/**
 * Validates a provider's basic settings: API key, API host and model list.
 *
 * Checks run in that order and stop at the first failure. Each failure shows
 * an error toast through `window.message.error` and is returned as an `Error`
 * carrying the same translated text. The API key check is skipped for the
 * `'ollama'` and `'lmstudio'` providers.
 *
 * @param provider Provider configuration to validate.
 * @returns `{ valid: true, error: null }` when all checks pass, otherwise
 *   `{ valid: false, error }` describing the first failed check.
 */
export function checkApiProvider(provider: Provider): {
  valid: boolean
  error: Error | null
} {
  const key = 'api-check'
  const style = { marginTop: '3vh' }

  // Shows the toast for a failed check and builds the matching result object.
  const fail = (
    messageKey: 'message.error.enter.api.key' | 'message.error.enter.api.host' | 'message.error.enter.model'
  ) => {
    window.message.error({ content: i18n.t(messageKey), key, style })
    return {
      valid: false,
      error: new Error(i18n.t(messageKey))
    }
  }

  const requiresApiKey = !(provider.id === 'ollama' || provider.id === 'lmstudio')

  if (requiresApiKey && !provider.apiKey) {
    return fail('message.error.enter.api.key')
  }

  if (!provider.apiHost) {
    return fail('message.error.enter.api.host')
  }

  if (isEmpty(provider.models)) {
    return fail('message.error.enter.model')
  }

  return {
    valid: true,
    error: null
  }
}
|
||
|
||
/**
 * Checks that `provider` is usable with `model`.
 *
 * Runs the local settings validation (`checkApiProvider`) first; only when it
 * passes is the provider's own `check(model)` awaited.
 *
 * @param provider Provider configuration to check.
 * @param model Model forwarded to the provider's `check` call.
 * @returns An object with `valid` and `error`, taken from whichever check decided the outcome.
 */
export async function checkApi(provider: Provider, model: Model) {
  const precheck = checkApiProvider(provider)

  if (!precheck.valid) {
    return {
      valid: precheck.valid,
      error: precheck.error
    }
  }

  const client = new AiProvider(provider)
  const outcome = await client.check(model)

  return {
    valid: outcome.valid,
    error: outcome.error
  }
}
|
||
|
||
/**
 * Tells whether `provider` passes the API key requirement.
 *
 * @param provider Provider to inspect; a falsy value yields `false`.
 * @returns `true` for the `'ollama'` and `'lmstudio'` providers regardless of
 *   their key, otherwise `true` only when `provider.apiKey` is non-empty.
 */
function hasApiKey(provider: Provider) {
  if (!provider) {
    return false
  }

  const keyNotRequired = provider.id === 'ollama' || provider.id === 'lmstudio'
  return keyNotRequired || !isEmpty(provider.apiKey)
}
|
||
|
||
/**
 * Lists the models exposed by `provider`.
 *
 * @param provider Provider to query.
 * @returns The result of the provider's `models()` call, or `[]` when it throws.
 */
export async function fetchModels(provider: Provider) {
  const client = new AiProvider(provider)

  try {
    const models = await client.models()
    return models
  } catch {
    return []
  }
}
|
||
|
||
/**
 * Format API keys
 *
 * Normalises the separators in a raw key string to ASCII commas: every
 * full-width comma (U+FF0C), space and newline becomes `,`. Existing ASCII
 * commas are kept, and consecutive separators are not collapsed.
 *
 * @param value Raw key string
 * @returns Formatted key string
 */
export const formatApiKeys = (value: string) => {
  // One pass is equivalent to sequential replacements because the replacement
  // character `,` is not itself matched by the class.
  return value.replace(/[\uFF0C \n]/g, ',')
}
|