Files
cherry-studio/src/renderer/src/services/ApiService.ts
T
SuYao e25007ce0d feat: Support search info with search summary model (#2443)
* feat: Add search summary model and related functionality

- Introduce new search summary model configuration in settings
- Implement search summary prompt and model selection
- Add support for generating search keywords across providers
- Update localization files with new search summary model translations
- Enhance web search functionality with search summary generation

* refactor: Improve web search error handling and async flow

* fix: Update migration version for settings search summary prompt

* refactor(webSearch): Remove search summary model references from settings and localization files

- Deleted search summary model entries from English, Japanese, Russian, Chinese, and Traditional Chinese localization files.
- Refactored ModelSettings component to remove search summary model handling.
- Updated related services and settings to eliminate search summary model dependencies.

* refactor(llm): Remove search summary model from state and related hooks
2025-03-19 13:09:47 +08:00

362 lines
9.6 KiB
TypeScript
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import { getOpenAIWebSearchParams } from '@renderer/config/models'
import { SEARCH_SUMMARY_PROMPT } from '@renderer/config/prompts'
import i18n from '@renderer/i18n'
import store from '@renderer/store'
import { setGenerating } from '@renderer/store/runtime'
import { Assistant, Message, Model, Provider, Suggestion } from '@renderer/types'
import { formatMessageError, isAbortError } from '@renderer/utils/error'
import { cloneDeep, findLast, isEmpty } from 'lodash'
import AiProvider from '../providers/AiProvider'
import {
getAssistantProvider,
getDefaultAssistant,
getDefaultModel,
getProviderByModel,
getTopNamingModel,
getTranslateModel
} from './AssistantService'
import { EVENT_NAMES, EventEmitter } from './EventService'
import { filterMessages, filterUsefulMessages } from './MessagesService'
import { estimateMessagesUsage } from './TokenService'
import WebSearchService from './WebSearchService'
/**
 * Runs a chat completion for `message` and streams progress through `onResponse`.
 *
 * Flow, as implemented below:
 * 1. If web search is enabled (globally and on the assistant), the assistant has a
 *    model, and `getOpenAIWebSearchParams` returns nothing for that model, a search
 *    query is derived from the last user message (and last assistant answer, when
 *    present) and the results are stored on `message.metadata.webSearch`. A failure
 *    in this step is logged and does not abort the completion.
 * 2. The completion is requested from the provider; every chunk is merged into
 *    `message` and forwarded to `onResponse` with status `'pending'`.
 * 3. If the provider reported no completion token usage, usage is estimated.
 *
 * Errors from steps 2-3 are not rethrown: the message is marked `'paused'` for abort
 * errors and `'error'` (with a formatted error) otherwise.
 *
 * NOTE: `message` is mutated in place while chunks arrive.
 *
 * @param message The assistant message being generated.
 * @param messages Conversation history passed to the provider.
 * @param assistant Assistant configuration (model, web search flag, ...).
 * @param onResponse Invoked on every status/content update and once at the end.
 * @returns The final message (status `'success'`, `'paused'` or `'error'`).
 */
export async function fetchChatCompletion({
  message,
  messages,
  assistant,
  onResponse
}: {
  message: Message
  messages: Message[]
  assistant: Assistant
  onResponse: (message: Message) => void
}) {
  const provider = getAssistantProvider(assistant)
  const webSearchProvider = WebSearchService.getWebSearchProvider()
  const AI = new AiProvider(provider)

  try {
    let _messages: Message[] = []
    let isFirstChunk = true

    // Search web
    if (WebSearchService.isWebSearchEnabled() && assistant.enableWebSearch && assistant.model) {
      const webSearchParams = getOpenAIWebSearchParams(assistant, assistant.model)

      // Only run the external search when no web-search params were produced for this model.
      if (isEmpty(webSearchParams)) {
        const lastMessage = findLast(messages, (m) => m.role === 'user')
        const lastAnswer = findLast(messages, (m) => m.role === 'assistant')
        const hasKnowledgeBase = !isEmpty(lastMessage?.knowledgeBaseIds)

        if (lastMessage) {
          if (hasKnowledgeBase) {
            window.message.info({
              content: i18n.t('message.ignore.knowledge.base'),
              key: 'knowledge-base-no-match-info'
            })
          }

          try {
            // Wait for the search keywords to be generated
            const searchSummaryAssistant = getDefaultAssistant()
            searchSummaryAssistant.model = assistant.model || getDefaultModel()
            searchSummaryAssistant.prompt = SEARCH_SUMMARY_PROMPT

            const keywords = await fetchSearchSummary({
              messages: lastAnswer ? [lastAnswer, lastMessage] : [lastMessage],
              assistant: searchSummaryAssistant
            })

            // Fall back to the raw user message when no keywords were produced
            const query = keywords || lastMessage.content

            // Update the message status to "searching"
            onResponse({ ...message, status: 'searching' })

            // Wait for the search to finish
            const webSearch = await WebSearchService.search(webSearchProvider, query)

            // Attach the search results to the message
            message.metadata = {
              ...message.metadata,
              webSearch: webSearch
            }

            window.keyv.set(`web-search-${lastMessage?.id}`, webSearch)
          } catch (error) {
            // Best effort: a failed search must not prevent the completion itself
            console.error('Web search failed:', error)
          }
        }
      }
    }

    const allMCPTools = await window.api.mcp.listTools()

    await AI.completions({
      messages: filterUsefulMessages(messages),
      assistant,
      onFilterMessages: (messages) => (_messages = messages),
      onChunk: ({ text, reasoning_content, usage, metrics, search, citations, mcpToolResponse }) => {
        // Default each operand separately: `a + b || ''` parses as `(a + b) || ''`,
        // which would append the literal "undefined" for chunks without text.
        message.content = (message.content || '') + (text || '')
        message.usage = usage
        message.metrics = metrics

        if (reasoning_content) {
          message.reasoning_content = (message.reasoning_content || '') + reasoning_content
        }

        if (search) {
          message.metadata = { ...message.metadata, groundingMetadata: search }
        }

        if (mcpToolResponse) {
          message.metadata = { ...message.metadata, mcpTools: cloneDeep(mcpToolResponse) }
        }

        // Handle citations from Perplexity API
        if (isFirstChunk && citations) {
          message.metadata = {
            ...message.metadata,
            citations
          }
          isFirstChunk = false
        }

        onResponse({ ...message, status: 'pending' })
      },
      mcpTools: allMCPTools
    })

    message.status = 'success'

    // Estimate usage when the provider did not report completion tokens
    if (!message.usage || !message?.usage?.completion_tokens) {
      message.usage = await estimateMessagesUsage({
        assistant,
        messages: [..._messages, message]
      })

      // Set metrics.completion_tokens
      if (message.metrics && message?.usage?.completion_tokens) {
        if (!message.metrics?.completion_tokens) {
          message = {
            ...message,
            metrics: {
              ...message.metrics,
              completion_tokens: message.usage.completion_tokens
            }
          }
        }
      }
    }
  } catch (error: any) {
    if (isAbortError(error)) {
      message.status = 'paused'
    } else {
      message.status = 'error'
      message.error = formatMessageError(error)
    }
  }

  // Emit chat completion event
  EventEmitter.emit(EVENT_NAMES.RECEIVE_MESSAGE, message)
  onResponse(message)

  // Reset generating state
  store.dispatch(setGenerating(false))

  return message
}
/** Arguments accepted by `fetchTranslate`. */
interface FetchTranslateProps {
// Message handed to the provider's `translate` call.
message: Message
// Assistant configuration forwarded to the provider's `translate` call.
assistant: Assistant
// Optional callback receiving text; forwarded as-is to the provider
// (presumably invoked with partial/streamed output — confirm in AiProvider).
onResponse?: (text: string) => void
}
/**
 * Translates `message` using the configured translate model.
 *
 * @throws Error when no translate model is configured, or when the model's
 *         provider fails the `hasApiKey` check.
 * @returns The provider's translation result, or `''` if the provider call fails.
 */
export async function fetchTranslate({ message, assistant, onResponse }: FetchTranslateProps) {
  const translateModel = getTranslateModel()
  if (!translateModel) {
    throw new Error(i18n.t('error.provider_disabled'))
  }

  const translateProvider = getProviderByModel(translateModel)
  if (!hasApiKey(translateProvider)) {
    throw new Error(i18n.t('error.no_api_key'))
  }

  const client = new AiProvider(translateProvider)

  try {
    const translated = await client.translate(message, assistant, onResponse)
    return translated
  } catch {
    // Provider failures are deliberately reported as an empty translation.
    return ''
  }
}
/**
 * Asks the provider for a summary of the given conversation.
 *
 * The model is picked in order: top-naming model, the assistant's model, the default model.
 *
 * @returns The provider's summary, or `null` when the provider fails the
 *          `hasApiKey` check or the request fails.
 */
export async function fetchMessagesSummary({ messages, assistant }: { messages: Message[]; assistant: Assistant }) {
  const summaryModel = getTopNamingModel() || assistant.model || getDefaultModel()
  const summaryProvider = getProviderByModel(summaryModel)

  if (hasApiKey(summaryProvider)) {
    const client = new AiProvider(summaryProvider)

    try {
      const summary = await client.summaries(filterMessages(messages), assistant)
      return summary
    } catch {
      // Best effort: any failure yields "no summary".
      return null
    }
  }

  return null
}
/**
 * Asks the provider to condense `messages` into a search summary via `summaryForSearch`.
 *
 * Uses the assistant's model, falling back to the default model.
 *
 * @returns The provider's result, or `null` when the provider fails the
 *          `hasApiKey` check or the request fails.
 */
export async function fetchSearchSummary({ messages, assistant }: { messages: Message[]; assistant: Assistant }) {
  const searchModel = assistant.model || getDefaultModel()
  const searchProvider = getProviderByModel(searchModel)

  if (hasApiKey(searchProvider)) {
    const client = new AiProvider(searchProvider)

    try {
      const summary = await client.summaryForSearch(messages, assistant)
      return summary
    } catch {
      // Best effort: any failure yields "no summary".
      return null
    }
  }

  return null
}
/**
 * Generates text from `prompt` and `content` with the default model.
 *
 * @returns The generated text, or `''` when the default model's provider fails
 *          the `hasApiKey` check or the request fails.
 */
export async function fetchGenerate({ prompt, content }: { prompt: string; content: string }): Promise<string> {
  const defaultProvider = getProviderByModel(getDefaultModel())

  if (hasApiKey(defaultProvider)) {
    const client = new AiProvider(defaultProvider)

    try {
      const generated = await client.generateText({ prompt, content })
      return generated
    } catch {
      // Best effort: any failure yields an empty string.
      return ''
    }
  }

  return ''
}
/**
 * Fetches follow-up suggestions for a conversation.
 *
 * Suggestions are only requested when the assistant has a model that is owned by
 * `'graphrag'` and whose id does not end with `'global'`; every other case, and any
 * provider failure, yields an empty list.
 */
export async function fetchSuggestions({
  messages,
  assistant
}: {
  messages: Message[]
  assistant: Assistant
}): Promise<Suggestion[]> {
  const activeModel = assistant.model

  // Short-circuit order matters: `activeModel` must exist before its fields are read.
  if (!activeModel || activeModel.owned_by !== 'graphrag' || activeModel.id.endsWith('global')) {
    return []
  }

  const client = new AiProvider(getAssistantProvider(assistant))

  try {
    const suggestions = await client.suggestions(filterMessages(messages), assistant)
    return suggestions
  } catch {
    return []
  }
}
/**
 * Validates a provider's basic settings: API key, API host and model list.
 *
 * Checks run in that order and stop at the first failure. On failure a localized
 * error toast is shown (keyed `'api-check'`) and the same text is returned as an
 * `Error`. The API key check is skipped for the `ollama` and `lmstudio` providers.
 */
export function checkApiProvider(provider: Provider): {
  valid: boolean
  error: Error | null
} {
  const key = 'api-check'
  const style = { marginTop: '3vh' }

  // Shows the localized toast and builds the matching failure result.
  const reject = (i18nKey: string) => {
    const content = i18n.t(i18nKey)
    window.message.error({ content, key, style })
    return { valid: false, error: new Error(content) }
  }

  const requiresApiKey = provider.id !== 'ollama' && provider.id !== 'lmstudio'

  if (requiresApiKey && !provider.apiKey) {
    return reject('message.error.enter.api.key')
  }

  if (!provider.apiHost) {
    return reject('message.error.enter.api.host')
  }

  if (isEmpty(provider.models)) {
    return reject('message.error.enter.model')
  }

  return { valid: true, error: null }
}
/**
 * Checks that `provider` can serve `model`.
 *
 * Runs the local settings validation (`checkApiProvider`) first and returns its
 * failure unchanged; only when that passes is the provider's own `check` invoked.
 *
 * @returns `{ valid, error }` from whichever check decided the outcome.
 */
export async function checkApi(provider: Provider, model: Model) {
  const precheck = checkApiProvider(provider)

  if (!precheck.valid) {
    return {
      valid: precheck.valid,
      error: precheck.error
    }
  }

  const client = new AiProvider(provider)
  const outcome = await client.check(model)

  return {
    valid: outcome.valid,
    error: outcome.error
  }
}
/**
 * Tells whether `provider` can be used as far as its API key is concerned.
 *
 * A missing provider is never usable; `ollama` and `lmstudio` are always accepted
 * regardless of key; every other provider needs a non-empty `apiKey`.
 */
function hasApiKey(provider: Provider) {
  if (!provider) {
    return false
  }

  const keyNotRequired = provider.id === 'ollama' || provider.id === 'lmstudio'
  if (keyNotRequired) {
    return true
  }

  const keyMissing = isEmpty(provider.apiKey)
  return !keyMissing
}
/**
 * Lists the models offered by `provider`.
 *
 * @returns The provider's model list, or an empty array if the request fails.
 */
export async function fetchModels(provider: Provider) {
  const client = new AiProvider(provider)

  try {
    const models = await client.models()
    return models
  } catch {
    // Best effort: a failing provider simply exposes no models.
    return []
  }
}
/**
 * Format API keys
 *
 * Normalizes a user-entered list of API keys into a comma-separated string:
 * full-width commas (U+FF0C), spaces and newlines are each replaced by an
 * ASCII comma. All other characters are left untouched.
 *
 * The full-width comma is written as an escape so it cannot be silently lost or
 * confused with a look-alike character; an empty search pattern here would insert
 * a comma between every character of the key.
 *
 * @param value Raw key string
 * @returns Formatted key string
 */
export const formatApiKeys = (value: string) => {
  return value.replace(/[\uFF0C \n]/g, ',')
}