Files
cherry-studio/src/renderer/src/services/ApiService.ts
T
SuYao fcdae8f41f feat(UI, OpenAI): support OpenAI-4o-web-search add support for web search citations (#3524)
* feat(UI, OpenAI): support  OpenAI 4o web search add support for web search citations

- refactor: Introduced a new CitationsList component to display citations in MessageContent.
- feat: Enhanced message handling to support web search results and annotations from OpenAI.
- refactor: Removed the deprecated MessageSearchResults component for cleaner code structure.
- refactor: Added utility functions for link conversion and URL extraction from Markdown.

* chore: remove debug logging from ProxyManager

* revert(OpenAIProvider): streamline reasoning check for stream output handling

* chore(OpenAIProvider): correct placement of webSearch in response object

* fix(patches): update OpenAI package version and remove patch references

- Integrated dayjs for dynamic date formatting in prompts.ts.

* feat(Citation, Favicon): enhance OpenAI web search support and citation handling

- Improved FallbackFavicon component to cache failed favicon URLs.
- Support all web search citation preview
- Added support for Hunyuan search model in OpenAIProvider and ApiService.

* refactor(provider/AI): move additional search parameters to AI Provider
2025-04-06 09:11:59 +08:00

463 lines
13 KiB
TypeScript
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import {
getOpenAIWebSearchParams,
isHunyuanSearchModel,
isOpenAIWebSearch,
isZhipuModel
} from '@renderer/config/models'
import { SEARCH_SUMMARY_PROMPT } from '@renderer/config/prompts'
import i18n from '@renderer/i18n'
import store from '@renderer/store'
import { setGenerating } from '@renderer/store/runtime'
import { Assistant, MCPTool, Message, Model, Provider, Suggestion } from '@renderer/types'
import { formatMessageError, isAbortError } from '@renderer/utils/error'
import { withGenerateImage } from '@renderer/utils/formats'
import {
cleanLinkCommas,
completeLinks,
convertLinks,
convertLinksToHunyuan,
convertLinksToOpenRouter,
convertLinksToZhipu,
extractUrlsFromMarkdown
} from '@renderer/utils/linkConverter'
import { cloneDeep, findLast, isEmpty } from 'lodash'
import AiProvider from '../providers/AiProvider'
import {
getAssistantProvider,
getDefaultAssistant,
getDefaultModel,
getProviderByModel,
getTopNamingModel,
getTranslateModel
} from './AssistantService'
import { EVENT_NAMES, EventEmitter } from './EventService'
import { filterContextMessages, filterMessages, filterUsefulMessages } from './MessagesService'
import { estimateMessagesUsage } from './TokenService'
import WebSearchService from './WebSearchService'
export async function fetchChatCompletion({
message,
messages,
assistant,
onResponse
}: {
message: Message
messages: Message[]
assistant: Assistant
onResponse: (message: Message) => void
}) {
const provider = getAssistantProvider(assistant)
const webSearchProvider = WebSearchService.getWebSearchProvider()
const AI = new AiProvider(provider)
try {
let _messages: Message[] = []
let isFirstChunk = true
let query = ''
// Search web
if (WebSearchService.isWebSearchEnabled() && assistant.enableWebSearch && assistant.model) {
const webSearchParams = getOpenAIWebSearchParams(assistant, assistant.model)
if (isEmpty(webSearchParams) && !isOpenAIWebSearch(assistant.model)) {
const lastMessage = findLast(messages, (m) => m.role === 'user')
const lastAnswer = findLast(messages, (m) => m.role === 'assistant')
const hasKnowledgeBase = !isEmpty(lastMessage?.knowledgeBaseIds)
if (lastMessage) {
if (hasKnowledgeBase) {
window.message.info({
content: i18n.t('message.ignore.knowledge.base'),
key: 'knowledge-base-no-match-info'
})
}
// 更新消息状态为搜索中
onResponse({ ...message, status: 'searching' })
try {
// 等待关键词生成完成
const searchSummaryAssistant = getDefaultAssistant()
searchSummaryAssistant.model = assistant.model || getDefaultModel()
searchSummaryAssistant.prompt = SEARCH_SUMMARY_PROMPT
// 如果启用搜索增强模式,则使用搜索增强模式
if (WebSearchService.isEnhanceModeEnabled()) {
const keywords = await fetchSearchSummary({
messages: lastAnswer ? [lastAnswer, lastMessage] : [lastMessage],
assistant: searchSummaryAssistant
})
if (keywords) {
query = keywords
}
} else {
query = lastMessage.content
}
// 等待搜索完成
const webSearch = await WebSearchService.search(webSearchProvider, query)
// 处理搜索结果
message.metadata = {
...message.metadata,
webSearch: webSearch
}
window.keyv.set(`web-search-${lastMessage?.id}`, webSearch)
} catch (error) {
console.error('Web search failed:', error)
}
}
}
}
const lastUserMessage = findLast(messages, (m) => m.role === 'user')
// Get MCP tools
const mcpTools: MCPTool[] = []
const enabledMCPs = lastUserMessage?.enabledMCPs
if (enabledMCPs && enabledMCPs.length > 0) {
for (const mcpServer of enabledMCPs) {
const tools = await window.api.mcp.listTools(mcpServer)
const availableTools = tools.filter((tool: any) => !mcpServer.disabledTools?.includes(tool.name))
mcpTools.push(...availableTools)
}
}
await AI.completions({
messages: filterUsefulMessages(filterContextMessages(messages)),
assistant,
onFilterMessages: (messages) => (_messages = messages),
onChunk: ({
text,
reasoning_content,
usage,
metrics,
webSearch,
search,
annotations,
citations,
mcpToolResponse,
generateImage
}) => {
if (assistant.model) {
if (isOpenAIWebSearch(assistant.model)) {
text = convertLinks(text || '', isFirstChunk)
} else if (assistant.model.provider === 'openrouter' && assistant.enableWebSearch) {
text = convertLinksToOpenRouter(text || '', isFirstChunk)
} else if (assistant.enableWebSearch) {
if (isZhipuModel(assistant.model)) {
text = convertLinksToZhipu(text || '', isFirstChunk)
} else if (isHunyuanSearchModel(assistant.model)) {
text = convertLinksToHunyuan(text || '', webSearch || [], isFirstChunk)
}
}
}
if (isFirstChunk) {
isFirstChunk = false
}
message.content = message.content + text || ''
message.usage = usage
message.metrics = metrics
if (reasoning_content) {
message.reasoning_content = (message.reasoning_content || '') + reasoning_content
}
if (mcpToolResponse) {
message.metadata = { ...message.metadata, mcpTools: cloneDeep(mcpToolResponse) }
}
if (generateImage && generateImage.images.length > 0) {
const existingImages = message.metadata?.generateImage?.images || []
generateImage.images = [...existingImages, ...generateImage.images]
console.log('generateImage', generateImage)
message.metadata = {
...message.metadata,
generateImage: generateImage
}
}
// Handle citations from Perplexity API
if (citations) {
message.metadata = {
...message.metadata,
citations
}
}
// Handle web search from Gemini
if (search) {
message.metadata = { ...message.metadata, groundingMetadata: search }
}
// Handle annotations from OpenAI
if (annotations) {
message.metadata = {
...message.metadata,
annotations: annotations
}
}
// Handle web search from Zhipu or Hunyuan
if (webSearch) {
message.metadata = {
...message.metadata,
webSearchInfo: webSearch
}
}
// Handle citations from Openrouter
if (assistant.model?.provider === 'openrouter' && assistant.enableWebSearch) {
const extractedUrls = extractUrlsFromMarkdown(message.content)
if (extractedUrls.length > 0) {
message.metadata = {
...message.metadata,
citations: extractedUrls
}
}
}
if (assistant.enableWebSearch) {
message.content = cleanLinkCommas(message.content)
if (webSearch && isZhipuModel(assistant.model)) {
message.content = completeLinks(message.content, webSearch)
}
}
onResponse({ ...message, status: 'pending' })
},
mcpTools: mcpTools
})
message.status = 'success'
message = withGenerateImage(message)
if (!message.usage || !message?.usage?.completion_tokens) {
message.usage = await estimateMessagesUsage({
assistant,
messages: [..._messages, message]
})
// Set metrics.completion_tokens
if (message.metrics && message?.usage?.completion_tokens) {
if (!message.metrics?.completion_tokens) {
message = {
...message,
metrics: {
...message.metrics,
completion_tokens: message.usage.completion_tokens
}
}
}
}
}
console.log('message', message)
} catch (error: any) {
if (isAbortError(error)) {
message.status = 'paused'
} else {
message.status = 'error'
message.error = formatMessageError(error)
}
}
// Emit chat completion event
EventEmitter.emit(EVENT_NAMES.RECEIVE_MESSAGE, message)
onResponse(message)
// Reset generating state
store.dispatch(setGenerating(false))
return message
}
interface FetchTranslateProps {
message: Message
assistant: Assistant
onResponse?: (text: string) => void
}
export async function fetchTranslate({ message, assistant, onResponse }: FetchTranslateProps) {
const model = getTranslateModel()
if (!model) {
throw new Error(i18n.t('error.provider_disabled'))
}
const provider = getProviderByModel(model)
if (!hasApiKey(provider)) {
throw new Error(i18n.t('error.no_api_key'))
}
const AI = new AiProvider(provider)
try {
return await AI.translate(message, assistant, onResponse)
} catch (error: any) {
return ''
}
}
export async function fetchMessagesSummary({ messages, assistant }: { messages: Message[]; assistant: Assistant }) {
const model = getTopNamingModel() || assistant.model || getDefaultModel()
const provider = getProviderByModel(model)
if (!hasApiKey(provider)) {
return null
}
const AI = new AiProvider(provider)
try {
const text = await AI.summaries(filterMessages(messages), assistant)
// Remove all quotes from the text
return text?.replace(/["']/g, '') || null
} catch (error: any) {
return null
}
}
export async function fetchSearchSummary({ messages, assistant }: { messages: Message[]; assistant: Assistant }) {
const model = assistant.model || getDefaultModel()
const provider = getProviderByModel(model)
if (!hasApiKey(provider)) {
return null
}
const AI = new AiProvider(provider)
try {
return await AI.summaryForSearch(messages, assistant)
} catch (error: any) {
return null
}
}
export async function fetchGenerate({ prompt, content }: { prompt: string; content: string }): Promise<string> {
const model = getDefaultModel()
const provider = getProviderByModel(model)
if (!hasApiKey(provider)) {
return ''
}
const AI = new AiProvider(provider)
try {
return await AI.generateText({ prompt, content })
} catch (error: any) {
return ''
}
}
export async function fetchSuggestions({
messages,
assistant
}: {
messages: Message[]
assistant: Assistant
}): Promise<Suggestion[]> {
const model = assistant.model
if (!model) {
return []
}
if (model.id.endsWith('global')) {
return []
}
const provider = getAssistantProvider(assistant)
const AI = new AiProvider(provider)
try {
return await AI.suggestions(filterMessages(messages), assistant)
} catch (error: any) {
return []
}
}
// Helper function to validate provider's basic settings such as API key, host, and model list
export function checkApiProvider(provider: Provider): {
valid: boolean
error: Error | null
} {
const key = 'api-check'
const style = { marginTop: '3vh' }
if (provider.id !== 'ollama' && provider.id !== 'lmstudio') {
if (!provider.apiKey) {
window.message.error({ content: i18n.t('message.error.enter.api.key'), key, style })
return {
valid: false,
error: new Error(i18n.t('message.error.enter.api.key'))
}
}
}
if (!provider.apiHost) {
window.message.error({ content: i18n.t('message.error.enter.api.host'), key, style })
return {
valid: false,
error: new Error(i18n.t('message.error.enter.api.host'))
}
}
if (isEmpty(provider.models)) {
window.message.error({ content: i18n.t('message.error.enter.model'), key, style })
return {
valid: false,
error: new Error(i18n.t('message.error.enter.model'))
}
}
return {
valid: true,
error: null
}
}
export async function checkApi(provider: Provider, model: Model) {
const validation = checkApiProvider(provider)
if (!validation.valid) {
return {
valid: validation.valid,
error: validation.error
}
}
const AI = new AiProvider(provider)
const { valid, error } = await AI.check(model)
return {
valid,
error
}
}
function hasApiKey(provider: Provider) {
if (!provider) return false
if (provider.id === 'ollama' || provider.id === 'lmstudio') return true
return !isEmpty(provider.apiKey)
}
export async function fetchModels(provider: Provider) {
const AI = new AiProvider(provider)
try {
return await AI.models()
} catch (error) {
return []
}
}
/**
* Format API keys
* @param value Raw key string
* @returns Formatted key string
*/
export const formatApiKeys = (value: string) => {
return value.replaceAll('', ',').replaceAll(' ', ',').replaceAll(' ', '').replaceAll('\n', ',')
}