Compare commits

...

2 Commits

Author SHA1 Message Date
MyPrototypeWhat
60d6fbe8f4 feat: refactor web search to provider-specific tools with advanced parameters
- Add ExaSearchTool and TavilySearchTool with provider-specific parameters
- Extend type system for Exa (neural search, date filters) and Tavily (AI answers, search depth)
- Update all providers to support ProviderSpecificParams interface
- Add searchResultAdapters for unified citation conversion
- Remove rawContent from LLM output and storage to reduce token usage
- Support favicon, highlights, answer, images metadata
- Update UI components to handle new tool names
- Preserve existing RAG compression and token cutoff capabilities

Breaking changes: None (backward compatible with existing providers)
2025-10-13 17:53:40 +08:00
MyPrototypeWhat
ff378ca567 feat: enhance web search functionality with abort signal support
- Updated WebSearchTool to accept an abort signal in the execute method.
- Modified various WebSearchProvider classes to include httpOptions for search methods, allowing for abort signal handling.
- Improved WebSearchService to prioritize external abort signals for better request management.
- Enhanced MessageTool to reflect tool status with appropriate UI feedback.
2025-10-09 17:44:42 +08:00
20 changed files with 900 additions and 89 deletions

View File

@@ -24,8 +24,10 @@ import { generateText } from 'ai'
import { isEmpty } from 'lodash'
import { MemoryProcessor } from '../../services/MemoryProcessor'
import { exaSearchTool } from '../tools/ExaSearchTool'
import { knowledgeSearchTool } from '../tools/KnowledgeSearchTool'
import { memorySearchTool } from '../tools/MemorySearchTool'
import { tavilySearchTool } from '../tools/TavilySearchTool'
import { webSearchToolWithPreExtractedKeywords } from '../tools/WebSearchTool'
const logger = loggerService.withContext('SearchOrchestrationPlugin')
@@ -316,13 +318,28 @@ export const searchOrchestrationPlugin = (assistant: Assistant, topicId: string)
const needsSearch = analysisResult.websearch.question && analysisResult.websearch.question[0] !== 'not_needed'
if (needsSearch) {
// onChunk({ type: ChunkType.EXTERNEL_TOOL_IN_PROGRESS })
// logger.info('🌐 Adding web search tool with pre-extracted keywords')
// 根据 Provider ID 动态选择工具
switch (assistant.webSearchProviderId) {
case 'exa':
logger.info('🌐 Adding Exa search tool (provider-specific)')
// Exa 工具直接接受单个查询字符串,使用第一个问题或合并所有问题
params.tools['builtin_exa_search'] = exaSearchTool(context.requestId)
break
case 'tavily':
logger.info('🌐 Adding Tavily search tool (provider-specific)')
// Tavily 工具直接接受单个查询字符串
params.tools['builtin_tavily_search'] = tavilySearchTool(context.requestId)
break
default:
logger.info('🌐 Adding web search tool with pre-extracted keywords')
// 其他 Provider 使用通用的 WebSearchTool
params.tools['builtin_web_search'] = webSearchToolWithPreExtractedKeywords(
assistant.webSearchProviderId,
analysisResult.websearch,
context.requestId
)
break
}
}
}

View File

@@ -0,0 +1,166 @@
import { loggerService } from '@logger'
import { REFERENCE_PROMPT } from '@renderer/config/prompts'
import WebSearchService from '@renderer/services/WebSearchService'
import { ProviderSpecificParams, WebSearchProviderResponse } from '@renderer/types'
import { ExtractResults } from '@renderer/utils/extract'
import { type InferToolInput, type InferToolOutput, tool } from 'ai'
import { z } from 'zod'
const logger = loggerService.withContext('ExaSearchTool')
/**
 * Exa-specific search tool that exposes Exa's advanced search options
 * (search type, category filter, published/crawl date ranges, autoprompt)
 * to the LLM.
 *
 * @param requestId Identifier forwarded unchanged to `WebSearchService.processWebsearch`.
 * @returns An AI SDK tool registered under the name `builtin_exa_search`.
 * @throws Error when `WebSearchService.getWebSearchProvider('exa')` returns nothing.
 */
export const exaSearchTool = (requestId: string) => {
  const webSearchProvider = WebSearchService.getWebSearchProvider('exa')
  if (!webSearchProvider) {
    throw new Error('Exa provider not found or not configured')
  }
  return tool({
    name: 'builtin_exa_search',
    description: `Advanced AI-powered search using Exa.ai with neural understanding and filtering capabilities.
Key Features:
- Neural Search: AI-powered semantic search that understands intent
- Search Type: Choose between neural (AI), keyword (traditional), or auto mode
- Category Filter: Focus on specific content types (company, research paper, news, etc.)
- Date Range: Filter by publication date
- Auto-prompt: Let Exa optimize your query automatically
Best for: Research, finding specific types of content, semantic search, and understanding complex queries.`,
    inputSchema: z.object({
      query: z.string().describe('The search query - be specific and clear'),
      numResults: z.number().min(1).max(20).optional().describe('Number of results to return (1-20, default: 5)'),
      type: z
        .enum(['neural', 'keyword', 'auto', 'fast'])
        .optional()
        .describe(
          'Search type: neural (embeddings-based), keyword (Google-like SERP), auto (default, intelligently combines both), or fast (streamlined versions)'
        ),
      category: z
        .string()
        .optional()
        .describe(
          'Filter by content category: company, research paper, news, github, tweet, movie, song, personal site, pdf, etc.'
        ),
      startPublishedDate: z
        .string()
        .optional()
        .describe('Start date filter based on published date in ISO 8601 format (YYYY-MM-DD or YYYY-MM-DDTHH:MM:SSZ)'),
      endPublishedDate: z
        .string()
        .optional()
        .describe('End date filter based on published date in ISO 8601 format (YYYY-MM-DD or YYYY-MM-DDTHH:MM:SSZ)'),
      startCrawlDate: z
        .string()
        .optional()
        .describe('Start date filter based on crawl date in ISO 8601 format (YYYY-MM-DD or YYYY-MM-DDTHH:MM:SSZ)'),
      endCrawlDate: z
        .string()
        .optional()
        .describe('End date filter based on crawl date in ISO 8601 format (YYYY-MM-DD or YYYY-MM-DDTHH:MM:SSZ)'),
      useAutoprompt: z.boolean().optional().describe('Let Exa optimize your query automatically (recommended: true)')
    }),
    execute: async (params, { abortSignal }) => {
      // Provider-specific parameters. `query` travels through `extractResults`
      // below. NOTE: `numResults` is accepted by the schema but intentionally
      // not forwarded here (the original author's note says it is controlled
      // by the system).
      const providerParams: ProviderSpecificParams = {
        exa: {
          type: params.type,
          category: params.category,
          startPublishedDate: params.startPublishedDate,
          endPublishedDate: params.endPublishedDate,
          startCrawlDate: params.startCrawlDate,
          endCrawlDate: params.endCrawlDate,
          useAutoprompt: params.useAutoprompt
        }
      }
      // Wrap the single query in the ExtractResults shape processWebsearch expects.
      const extractResults: ExtractResults = {
        websearch: {
          question: [params.query]
        }
      }
      // Go through the shared processWebsearch entry point; per the original
      // author's note this keeps the common pipeline (timestamps, blacklist,
      // tracing, compression) in effect.
      const finalResults: WebSearchProviderResponse = await WebSearchService.processWebsearch(
        webSearchProvider,
        extractResults,
        requestId,
        abortSignal,
        providerParams
      )
      logger.info(`Exa search completed: ${finalResults.results.length} results for "${params.query}"`)
      return finalResults
    },
    toModelOutput: (results) => {
      let summary = 'No search results found.'
      if (results.query && results.results.length > 0) {
        summary = `Found ${results.results.length} relevant sources using Exa AI search. Use [number] format to cite specific information.`
      }
      // One citation per result. Exa-specific metadata is appended only when
      // present, in a fixed key order (favicon, author, publishedDate, score,
      // highlights) so the serialized JSON stays stable.
      const citationData = results.results.map((result, index) => ({
        number: index + 1,
        title: result.title,
        content: result.content,
        url: result.url,
        ...('favicon' in result && result.favicon ? { favicon: result.favicon } : {}),
        ...('author' in result && result.author ? { author: result.author } : {}),
        ...('publishedDate' in result && result.publishedDate ? { publishedDate: result.publishedDate } : {}),
        ...('score' in result && result.score !== undefined ? { score: result.score } : {}),
        ...('highlights' in result && result.highlights ? { highlights: result.highlights } : {})
      }))
      // Format the citations with REFERENCE_PROMPT.
      const referenceContent = `\`\`\`json\n${JSON.stringify(citationData, null, 2)}\n\`\`\``
      // The references come from arbitrary web pages. Passing them as a
      // replacement *string* would let `$$`, `$&`, `` $` `` and `$'` be
      // interpreted as replacement patterns and corrupt the prompt, so a
      // replacer function is used to insert the text verbatim.
      const fullInstructions = REFERENCE_PROMPT.replace(
        '{question}',
        "Based on the Exa search results, please answer the user's question with proper citations."
      ).replace('{references}', () => referenceContent)
      return {
        type: 'content',
        value: [
          {
            type: 'text',
            text: 'Exa AI Search: Neural search with semantic understanding and rich metadata (author, publish date, highlights).'
          },
          {
            type: 'text',
            text: summary
          },
          {
            type: 'text',
            text: fullInstructions
          }
        ]
      }
    }
  })
}
export type ExaSearchToolOutput = InferToolOutput<ReturnType<typeof exaSearchTool>>
export type ExaSearchToolInput = InferToolInput<ReturnType<typeof exaSearchTool>>

View File

@@ -0,0 +1,161 @@
import { loggerService } from '@logger'
import { REFERENCE_PROMPT } from '@renderer/config/prompts'
import WebSearchService from '@renderer/services/WebSearchService'
import { ProviderSpecificParams, WebSearchProviderResponse } from '@renderer/types'
import { ExtractResults } from '@renderer/utils/extract'
import { type InferToolInput, type InferToolOutput, tool } from 'ai'
import { z } from 'zod'
const logger = loggerService.withContext('TavilySearchTool')
/**
 * Tavily-specific search tool that exposes Tavily's advanced search options
 * (topic, search depth, AI answer, raw content, images) to the LLM.
 *
 * @param requestId Identifier forwarded unchanged to `WebSearchService.processWebsearch`.
 * @returns An AI SDK tool registered under the name `builtin_tavily_search`.
 * @throws Error when `WebSearchService.getWebSearchProvider('tavily')` returns nothing.
 */
export const tavilySearchTool = (requestId: string) => {
  const webSearchProvider = WebSearchService.getWebSearchProvider('tavily')
  if (!webSearchProvider) {
    throw new Error('Tavily provider not found or not configured')
  }
  return tool({
    name: 'builtin_tavily_search',
    description: `AI-powered search using Tavily with direct answers and comprehensive content extraction.
Key Features:
- Direct AI Answer: Get a concise, factual answer extracted from search results
- Search Depth: Choose between basic (fast) or advanced (comprehensive) search
- Topic Focus: Filter by general, news, or finance topics
- Full Content: Access complete webpage content, not just snippets
- Rich Media: Optionally include relevant images from search results
Best for: Quick factual answers, news monitoring, financial research, and comprehensive content analysis.`,
    inputSchema: z.object({
      query: z.string().describe('The search query - be specific and clear'),
      maxResults: z
        .number()
        .min(1)
        .max(20)
        .optional()
        .describe('Maximum number of results to return (1-20, default: 5)'),
      topic: z
        .enum(['general', 'news', 'finance'])
        .optional()
        .describe('Topic filter: general (default), news (latest news), or finance (financial/market data)'),
      searchDepth: z
        .enum(['basic', 'advanced'])
        .optional()
        .describe('Search depth: basic (faster, top results) or advanced (slower, more comprehensive)'),
      includeAnswer: z
        .boolean()
        .optional()
        .describe('Include AI-generated direct answer extracted from results (default: true)'),
      includeRawContent: z
        .boolean()
        .optional()
        .describe('Include full webpage content instead of just snippets (default: true)'),
      includeImages: z.boolean().optional().describe('Include relevant images from search results (default: false)')
    }),
    execute: async (params, { abortSignal }) => {
      try {
        // Provider-specific parameters. NOTE: `maxResults` is accepted by the
        // schema but is not forwarded from here.
        const providerParams: ProviderSpecificParams = {
          tavily: {
            topic: params.topic,
            searchDepth: params.searchDepth,
            includeAnswer: params.includeAnswer,
            includeRawContent: params.includeRawContent,
            includeImages: params.includeImages
          }
        }
        // Wrap the single query in the ExtractResults shape processWebsearch expects.
        const extractResults: ExtractResults = {
          websearch: {
            question: [params.query]
          }
        }
        // Go through the shared processWebsearch entry point; per the original
        // author's note this keeps the common pipeline (timestamps, blacklist,
        // tracing, compression) in effect.
        const finalResults: WebSearchProviderResponse = await WebSearchService.processWebsearch(
          webSearchProvider,
          extractResults,
          requestId,
          abortSignal,
          providerParams
        )
        logger.info(`Tavily search completed: ${finalResults.results.length} results for "${params.query}"`)
        return finalResults
      } catch (error) {
        // Treat the call as cancelled when the caller's signal is aborted, not
        // only when the rejection is a DOMException: abort errors are not
        // guaranteed to have that shape, and wrapping them below would report
        // a user cancellation as a search failure.
        const wasAborted =
          abortSignal?.aborted === true || (error instanceof DOMException && error.name === 'AbortError')
        if (wasAborted) {
          logger.info('Tavily search aborted')
          throw error
        }
        logger.error('Tavily search failed:', error as Error)
        throw new Error(`Tavily search failed: ${error instanceof Error ? error.message : 'Unknown error'}`)
      }
    },
    toModelOutput: (results) => {
      let summary = 'No search results found.'
      if (results.query && results.results.length > 0) {
        summary = `Found ${results.results.length} relevant sources using Tavily AI search. Use [number] format to cite specific information.`
      }
      // Tavily's `answer` and `images` belong to the response as a whole, but
      // arrive here copied onto every result. Pick them up once instead of
      // repeating them in every citation, which only inflates the prompt.
      let directAnswer: string | undefined
      let relatedImages: unknown
      for (const result of results.results) {
        if (directAnswer === undefined && 'answer' in result && result.answer) {
          directAnswer = String(result.answer)
        }
        if (relatedImages === undefined && 'images' in result && result.images && result.images.length > 0) {
          relatedImages = result.images
        }
      }
      const citationData = results.results.map((result, index) => ({
        number: index + 1,
        title: result.title,
        content: result.content,
        url: result.url,
        ...('score' in result && result.score !== undefined ? { score: result.score } : {})
      }))
      // Format the citations with REFERENCE_PROMPT.
      const referenceContent = `\`\`\`json\n${JSON.stringify(citationData, null, 2)}\n\`\`\``
      // The references come from arbitrary web pages. Passing them as a
      // replacement *string* would let `$$`, `$&`, `` $` `` and `$'` be
      // interpreted as replacement patterns and corrupt the prompt, so a
      // replacer function is used to insert the text verbatim.
      const fullInstructions = REFERENCE_PROMPT.replace(
        '{question}',
        "Based on the Tavily search results, please answer the user's question with proper citations."
      ).replace('{references}', () => referenceContent)
      return {
        type: 'content',
        value: [
          {
            type: 'text',
            text: 'Tavily AI Search: AI-powered with direct answers, full content extraction, and optional image results.'
          },
          {
            type: 'text',
            text: summary
          },
          // Response-level extras, emitted at most once each.
          ...(directAnswer !== undefined
            ? [{ type: 'text' as const, text: `Tavily direct answer: ${directAnswer}` }]
            : []),
          ...(relatedImages !== undefined
            ? [{ type: 'text' as const, text: `Related images: ${JSON.stringify(relatedImages)}` }]
            : []),
          {
            type: 'text',
            text: fullInstructions
          }
        ]
      }
    }
  })
}
export type TavilySearchToolOutput = InferToolOutput<ReturnType<typeof tavilySearchTool>>
export type TavilySearchToolInput = InferToolInput<ReturnType<typeof tavilySearchTool>>

View File

@@ -40,7 +40,7 @@ You can use this tool as-is to search with the prepared queries, or provide addi
.describe('Optional additional context, keywords, or specific focus to enhance the search')
}),
execute: async ({ additionalContext }) => {
execute: async ({ additionalContext }, { abortSignal }) => {
let finalQueries = [...extractedKeywords.question]
if (additionalContext?.trim()) {
@@ -67,7 +67,15 @@ You can use this tool as-is to search with the prepared queries, or provide addi
links: extractedKeywords.links
}
}
searchResults = await WebSearchService.processWebsearch(webSearchProvider!, extractResults, requestId)
// abortSignal?.addEventListener('abort', () => {
// console.log('tool_call_abortSignal', abortSignal?.aborted)
// })
searchResults = await WebSearchService.processWebsearch(
webSearchProvider!,
extractResults,
requestId,
abortSignal
)
return searchResults
},

View File

@@ -1,5 +1,8 @@
import { NormalToolResponse } from '@renderer/types'
import type { ToolMessageBlock } from '@renderer/types/newMessage'
import { MessageBlockStatus, ToolMessageBlock } from '@renderer/types/newMessage'
import { TFunction } from 'i18next'
import { Pause } from 'lucide-react'
import { useTranslation } from 'react-i18next'
import { MessageAgentTools } from './MessageAgentTools'
import { MessageKnowledgeSearchToolTitle } from './MessageKnowledgeSearch'
@@ -35,14 +38,28 @@ const isAgentTool = (toolName: string) => {
return false
}
const ChooseTool = (toolResponse: NormalToolResponse): React.ReactNode | null => {
const ChooseTool = (
toolResponse: NormalToolResponse,
status: MessageBlockStatus,
t: TFunction
): React.ReactNode | null => {
let toolName = toolResponse.tool.name
const toolType = toolResponse.tool.type
if (toolName.startsWith(prefix)) {
toolName = toolName.slice(prefix.length)
if (status === MessageBlockStatus.PAUSED) {
return (
<div className="flex items-center gap-1">
<Pause className="h-4 w-4" />
<span>{t('message.tools.aborted')}</span>
</div>
)
}
switch (toolName) {
case 'web_search':
case 'web_search_preview':
case 'exa_search':
case 'tavily_search':
return toolType === 'provider' ? null : <MessageWebSearchToolTitle toolResponse={toolResponse} />
case 'knowledge_search':
return <MessageKnowledgeSearchToolTitle toolResponse={toolResponse} />
@@ -58,12 +75,13 @@ const ChooseTool = (toolResponse: NormalToolResponse): React.ReactNode | null =>
}
export default function MessageTool({ block }: Props) {
const { t } = useTranslation()
// FIXME: 语义错误,这里已经不是 MCP tool 了,更改rawMcpToolResponse需要改用户数据, 所以暂时保留
const toolResponse = block.metadata?.rawMcpToolResponse as NormalToolResponse
if (!toolResponse) return null
const toolRenderer = ChooseTool(toolResponse as NormalToolResponse)
const toolRenderer = ChooseTool(toolResponse as NormalToolResponse, block.status, t)
if (!toolRenderer) return null

View File

@@ -1,3 +1,5 @@
import { ExaSearchToolInput, ExaSearchToolOutput } from '@renderer/aiCore/tools/ExaSearchTool'
import { TavilySearchToolInput, TavilySearchToolOutput } from '@renderer/aiCore/tools/TavilySearchTool'
import { WebSearchToolInput, WebSearchToolOutput } from '@renderer/aiCore/tools/WebSearchTool'
import Spinner from '@renderer/components/Spinner'
import { NormalToolResponse } from '@renderer/types'
@@ -8,17 +10,31 @@ import styled from 'styled-components'
const { Text } = Typography
// 联合类型 - 支持多种搜索工具
type SearchToolInput = WebSearchToolInput | ExaSearchToolInput | TavilySearchToolInput
type SearchToolOutput = WebSearchToolOutput | ExaSearchToolOutput | TavilySearchToolOutput
export const MessageWebSearchToolTitle = ({ toolResponse }: { toolResponse: NormalToolResponse }) => {
const { t } = useTranslation()
const toolInput = toolResponse.arguments as WebSearchToolInput
const toolOutput = toolResponse.response as WebSearchToolOutput
const toolInput = toolResponse.arguments as SearchToolInput
const toolOutput = toolResponse.response as SearchToolOutput
// Pick the text shown next to the spinner depending on which search tool
// produced the input: the generic web search tool carries `additionalContext`,
// the provider-specific tools carry `query`.
const getQueryText = () => {
  // `toolInput` is only a type assertion over `toolResponse.arguments`; at
  // runtime it may still be undefined, null or a primitive. The `in` operator
  // throws a TypeError on non-objects, so bail out first.
  if (typeof toolInput !== 'object' || toolInput === null) {
    return ''
  }
  if ('additionalContext' in toolInput) {
    return toolInput.additionalContext ?? ''
  }
  if ('query' in toolInput) {
    return toolInput.query ?? ''
  }
  return ''
}
return toolResponse.status !== 'done' ? (
<Spinner
text={
<PrepareToolWrapper>
{t('message.searching')}
<span>{toolInput?.additionalContext ?? ''}</span>
<span>{getQueryText()}</span>
</PrepareToolWrapper>
}
/>

View File

@@ -1,5 +1,5 @@
import { WebSearchState } from '@renderer/store/websearch'
import { WebSearchProvider, WebSearchProviderResponse } from '@renderer/types'
import { ProviderSpecificParams, WebSearchProvider, WebSearchProviderResponse } from '@renderer/types'
export default abstract class BaseWebSearchProvider {
// @ts-ignore this
@@ -16,7 +16,8 @@ export default abstract class BaseWebSearchProvider {
abstract search(
query: string,
websearch: WebSearchState,
httpOptions?: RequestInit
httpOptions?: RequestInit,
providerParams?: ProviderSpecificParams
): Promise<WebSearchProviderResponse>
public getApiHost() {

View File

@@ -1,6 +1,6 @@
import { loggerService } from '@logger'
import { WebSearchState } from '@renderer/store/websearch'
import { WebSearchProvider, WebSearchProviderResponse } from '@renderer/types'
import { ProviderSpecificParams, WebSearchProvider, WebSearchProviderResponse } from '@renderer/types'
import { BochaSearchParams, BochaSearchResponse } from '@renderer/utils/bocha'
import BaseWebSearchProvider from './BaseWebSearchProvider'
@@ -18,7 +18,12 @@ export default class BochaProvider extends BaseWebSearchProvider {
}
}
public async search(query: string, websearch: WebSearchState): Promise<WebSearchProviderResponse> {
public async search(
query: string,
websearch: WebSearchState,
httpOptions?: RequestInit,
_providerParams?: ProviderSpecificParams
): Promise<WebSearchProviderResponse> {
try {
if (!query.trim()) {
throw new Error('Search query cannot be empty')
@@ -44,7 +49,8 @@ export default class BochaProvider extends BaseWebSearchProvider {
headers: {
...this.defaultHeaders(),
...headers
}
},
signal: httpOptions?.signal
})
if (!response.ok) {

View File

@@ -1,9 +1,15 @@
import { WebSearchProviderResponse } from '@renderer/types'
import { WebSearchState } from '@renderer/store/websearch'
import { ProviderSpecificParams, WebSearchProviderResponse } from '@renderer/types'
import BaseWebSearchProvider from './BaseWebSearchProvider'
export default class DefaultProvider extends BaseWebSearchProvider {
search(): Promise<WebSearchProviderResponse> {
search(
_query: string,
_websearch: WebSearchState,
_httpOptions?: RequestInit,
_providerParams?: ProviderSpecificParams
): Promise<WebSearchProviderResponse> {
throw new Error('Method not implemented.')
}
}

View File

@@ -1,14 +1,53 @@
import { ExaClient } from '@agentic/exa'
import { loggerService } from '@logger'
import { WebSearchState } from '@renderer/store/websearch'
import { WebSearchProvider, WebSearchProviderResponse } from '@renderer/types'
import {
ExaSearchResult as ExaSearchResultType,
ProviderSpecificParams,
WebSearchProvider,
WebSearchProviderResponse
} from '@renderer/types'
import BaseWebSearchProvider from './BaseWebSearchProvider'
const logger = loggerService.withContext('ExaProvider')
export default class ExaProvider extends BaseWebSearchProvider {
private exa: ExaClient
/** JSON body that this provider POSTs to `${apiHost}/search`. */
interface ExaSearchRequest {
  // Required by every request.
  query: string
  numResults: number

  // How the search is performed.
  type?: 'keyword' | 'neural' | 'auto' | 'fast'
  useAutoprompt?: boolean
  category?: string

  // Which content fields to return for each result.
  contents?: {
    text?: boolean
    highlights?: boolean
    summary?: boolean
  }

  // Date-range filters (ISO 8601 strings).
  startPublishedDate?: string
  endPublishedDate?: string
  startCrawlDate?: string
  endCrawlDate?: string

  // Domain allow / deny lists.
  includeDomains?: string[]
  excludeDomains?: string[]
}
/** One entry of the `results` array in an Exa search response; most fields may be null or absent. */
interface ExaSearchResult {
  url: string | null
  title: string | null
  // Page text; requested via `contents.text` in the request body.
  text?: string | null
  // Highlight snippets; requested via `contents.highlights`.
  highlights?: string[]
  score?: number
  author?: string | null
  publishedDate?: string | null
  favicon?: string | null
}
/** Shape this provider expects from the Exa search endpoint's JSON response. */
interface ExaSearchResponse {
  results: ExaSearchResult[]
  // When present, the provider reports this as the effective query.
  autopromptString?: string
  resolvedSearchType?: string
}
export default class ExaProvider extends BaseWebSearchProvider {
constructor(provider: WebSearchProvider) {
super(provider)
if (!this.apiKey) {
@@ -17,34 +56,138 @@ export default class ExaProvider extends BaseWebSearchProvider {
if (!this.apiHost) {
throw new Error('API host is required for Exa provider')
}
this.exa = new ExaClient({ apiKey: this.apiKey, apiBaseUrl: this.apiHost })
}
public async search(query: string, websearch: WebSearchState): Promise<WebSearchProviderResponse> {
/**
 * Single search entry point for the Exa provider.
 *
 * Delegates to `searchWithParams`. When `providerParams.exa` is supplied its
 * fields are spread into the request; otherwise only `useAutoprompt: true` is
 * set on top of the query and result count.
 *
 * @param query Search query text.
 * @param websearch Web search state; `maxResults` is used as the result count.
 * @param httpOptions Optional request options; only `signal` is read.
 * @param providerParams Optional provider-specific parameters; only `exa` is read.
 */
public async search(
  query: string,
  websearch: WebSearchState,
  httpOptions?: RequestInit,
  providerParams?: ProviderSpecificParams
): Promise<WebSearchProviderResponse> {
  const signal = httpOptions?.signal ?? undefined
  const numResults = websearch.maxResults
  const exaOptions = providerParams?.exa

  // No Exa-specific options: fall back to the default request.
  if (!exaOptions) {
    return this.searchWithParams({ query, numResults, useAutoprompt: true, signal })
  }

  // Exa-specific options present: spread them over the base request.
  return this.searchWithParams({ query, numResults, ...exaOptions, signal })
}
/**
* 使用完整参数进行搜索(支持 Exa 的所有高级功能)
*/
public async searchWithParams(params: {
query: string
numResults?: number
type?: 'keyword' | 'neural' | 'auto' | 'fast'
category?: string
startPublishedDate?: string
endPublishedDate?: string
startCrawlDate?: string
endCrawlDate?: string
useAutoprompt?: boolean
includeDomains?: string[]
excludeDomains?: string[]
signal?: AbortSignal
}): Promise<WebSearchProviderResponse> {
try {
if (!query.trim()) {
if (!params.query.trim()) {
throw new Error('Search query cannot be empty')
}
const response = await this.exa.search({
query,
numResults: Math.max(1, websearch.maxResults),
const requestBody: ExaSearchRequest = {
query: params.query,
numResults: Math.max(1, params.numResults || 5),
contents: {
text: true
text: true,
highlights: true // 获取高亮片段
},
useAutoprompt: params.useAutoprompt ?? true
}
// 添加可选参数
if (params.type) {
requestBody.type = params.type
}
if (params.category) {
requestBody.category = params.category
}
if (params.startPublishedDate) {
requestBody.startPublishedDate = params.startPublishedDate
}
if (params.endPublishedDate) {
requestBody.endPublishedDate = params.endPublishedDate
}
if (params.startCrawlDate) {
requestBody.startCrawlDate = params.startCrawlDate
}
if (params.endCrawlDate) {
requestBody.endCrawlDate = params.endCrawlDate
}
if (params.includeDomains && params.includeDomains.length > 0) {
requestBody.includeDomains = params.includeDomains
}
if (params.excludeDomains && params.excludeDomains.length > 0) {
requestBody.excludeDomains = params.excludeDomains
}
const response = await fetch(`${this.apiHost}/search`, {
method: 'POST',
headers: {
'x-api-key': this.apiKey!,
'Content-Type': 'application/json'
},
body: JSON.stringify(requestBody),
signal: params.signal
})
if (!response.ok) {
const errorText = await response.text()
throw new Error(`Exa API error (${response.status}): ${errorText}`)
}
const data: ExaSearchResponse = await response.json()
// 返回完整的 Exa 结果(包含 favicon、author、score 等字段)
return {
query: response.autopromptString,
results: response.results.slice(0, websearch.maxResults).map((result) => {
return {
query: data.autopromptString || params.query,
results: data.results.slice(0, params.numResults || 5).map(
(result): ExaSearchResultType => ({
title: result.title || 'No title',
content: result.text || '',
url: result.url || ''
}
url: result.url || '',
favicon: result.favicon || undefined,
publishedDate: result.publishedDate || undefined,
author: result.author || undefined,
score: result.score,
highlights: result.highlights
})
)
}
} catch (error) {
if (error instanceof DOMException && error.name === 'AbortError') {
throw error
}
logger.error('Exa search failed:', error as Error)
throw new Error(`Search failed: ${error instanceof Error ? error.message : 'Unknown error'}`)
}

View File

@@ -2,7 +2,12 @@ import { loggerService } from '@logger'
import { nanoid } from '@reduxjs/toolkit'
import store from '@renderer/store'
import { WebSearchState } from '@renderer/store/websearch'
import { WebSearchProvider, WebSearchProviderResponse, WebSearchProviderResult } from '@renderer/types'
import {
ProviderSpecificParams,
WebSearchProvider,
WebSearchProviderResponse,
WebSearchProviderResult
} from '@renderer/types'
import { createAbortPromise } from '@renderer/utils/abortController'
import { isAbortError } from '@renderer/utils/error'
import { fetchWebContent, noContent } from '@renderer/utils/fetch'
@@ -27,7 +32,8 @@ export default class LocalSearchProvider extends BaseWebSearchProvider {
public async search(
query: string,
websearch: WebSearchState,
httpOptions?: RequestInit
httpOptions?: RequestInit,
_providerParams?: ProviderSpecificParams
): Promise<WebSearchProviderResponse> {
const uid = nanoid()
const language = store.getState().settings.language

View File

@@ -1,7 +1,7 @@
import { SearxngClient } from '@agentic/searxng'
import { loggerService } from '@logger'
import { WebSearchState } from '@renderer/store/websearch'
import { WebSearchProvider, WebSearchProviderResponse } from '@renderer/types'
import { ProviderSpecificParams, WebSearchProvider, WebSearchProviderResponse } from '@renderer/types'
import { fetchWebContent, noContent } from '@renderer/utils/fetch'
import axios from 'axios'
import ky from 'ky'
@@ -95,7 +95,12 @@ export default class SearxngProvider extends BaseWebSearchProvider {
}
}
public async search(query: string, websearch: WebSearchState): Promise<WebSearchProviderResponse> {
public async search(
query: string,
websearch: WebSearchState,
httpOptions?: RequestInit,
_providerParams?: ProviderSpecificParams
): Promise<WebSearchProviderResponse> {
try {
if (!query) {
throw new Error('Search query cannot be empty')
@@ -124,7 +129,7 @@ export default class SearxngProvider extends BaseWebSearchProvider {
// Fetch content for each URL concurrently
const fetchPromises = validItems.map(async (item) => {
// Logger.log(`Fetching content for ${item.url}...`)
return await fetchWebContent(item.url, 'markdown', this.provider.usingBrowser)
return await fetchWebContent(item.url, 'markdown', this.provider.usingBrowser, httpOptions)
})
// Wait for all fetches to complete

View File

@@ -1,14 +1,45 @@
import { TavilyClient } from '@agentic/tavily'
import { loggerService } from '@logger'
import { WebSearchState } from '@renderer/store/websearch'
import { WebSearchProvider, WebSearchProviderResponse } from '@renderer/types'
import {
ProviderSpecificParams,
TavilySearchResult as TavilySearchResultType,
WebSearchProvider,
WebSearchProviderResponse
} from '@renderer/types'
import BaseWebSearchProvider from './BaseWebSearchProvider'
const logger = loggerService.withContext('TavilyProvider')
export default class TavilyProvider extends BaseWebSearchProvider {
private tvly: TavilyClient
/** JSON body (snake_case) that this provider POSTs to `${apiHost}/search`; the API key is merged in at send time. */
interface TavilySearchRequest {
  query: string
  max_results?: number

  // How the search is performed.
  topic?: 'general' | 'news' | 'finance'
  search_depth?: 'basic' | 'advanced'

  // Optional extras to include in the response.
  include_answer?: boolean
  include_images?: boolean
  include_raw_content?: boolean

  // Domain allow / deny lists.
  include_domains?: string[]
  exclude_domains?: string[]
}
/** One entry of the `results` array in a Tavily search response. */
interface TavilySearchResult {
  url: string
  title: string
  // Snippet text; used as the result content when `raw_content` is missing.
  content: string
  // Preferred over `content` by the provider when present.
  raw_content?: string
  score?: number
}
/** Shape this provider expects from the Tavily search endpoint's JSON response. */
interface TavilySearchResponse {
  query: string
  results: TavilySearchResult[]
  // Response-level extras; the provider copies both onto every mapped result.
  answer?: string
  images?: string[]
  response_time?: number
}
export default class TavilyProvider extends BaseWebSearchProvider {
constructor(provider: WebSearchProvider) {
super(provider)
if (!this.apiKey) {
@@ -17,30 +48,119 @@ export default class TavilyProvider extends BaseWebSearchProvider {
if (!this.apiHost) {
throw new Error('API host is required for Tavily provider')
}
this.tvly = new TavilyClient({ apiKey: this.apiKey, apiBaseUrl: this.apiHost })
}
public async search(query: string, websearch: WebSearchState): Promise<WebSearchProviderResponse> {
/**
 * Single search entry point for the Tavily provider.
 *
 * Delegates to `searchWithParams`. When `providerParams.tavily` is supplied its
 * fields are spread into the request; otherwise only `includeRawContent: true`
 * is set on top of the query and result count.
 *
 * @param query Search query text.
 * @param websearch Web search state; `maxResults` is used as the result count.
 * @param httpOptions Optional request options; only `signal` is read.
 * @param providerParams Optional provider-specific parameters; only `tavily` is read.
 */
public async search(
  query: string,
  websearch: WebSearchState,
  httpOptions?: RequestInit,
  providerParams?: ProviderSpecificParams
): Promise<WebSearchProviderResponse> {
  const signal = httpOptions?.signal ?? undefined
  const maxResults = websearch.maxResults
  const tavilyOptions = providerParams?.tavily

  // No Tavily-specific options: fall back to the default request.
  if (!tavilyOptions) {
    return this.searchWithParams({ query, maxResults, includeRawContent: true, signal })
  }

  // Tavily-specific options present: spread them over the base request.
  return this.searchWithParams({ query, maxResults, ...tavilyOptions, signal })
}
/**
* 使用完整参数进行搜索(支持 Tavily 的所有高级功能)
*/
public async searchWithParams(params: {
query: string
maxResults?: number
topic?: 'general' | 'news' | 'finance'
searchDepth?: 'basic' | 'advanced'
includeAnswer?: boolean
includeRawContent?: boolean
includeImages?: boolean
includeDomains?: string[]
excludeDomains?: string[]
signal?: AbortSignal
}): Promise<WebSearchProviderResponse> {
try {
if (!query.trim()) {
if (!params.query.trim()) {
throw new Error('Search query cannot be empty')
}
const result = await this.tvly.search({
query,
max_results: Math.max(1, websearch.maxResults)
})
return {
query: result.query,
results: result.results.slice(0, websearch.maxResults).map((result) => {
return {
title: result.title || 'No title',
content: result.content || '',
url: result.url || ''
const requestBody: TavilySearchRequest = {
query: params.query,
max_results: Math.max(1, params.maxResults || 5),
include_raw_content: params.includeRawContent ?? true,
include_answer: params.includeAnswer ?? true,
include_images: params.includeImages ?? false
}
// 添加可选参数
if (params.topic) {
requestBody.topic = params.topic
}
if (params.searchDepth) {
requestBody.search_depth = params.searchDepth
}
if (params.includeDomains && params.includeDomains.length > 0) {
requestBody.include_domains = params.includeDomains
}
if (params.excludeDomains && params.excludeDomains.length > 0) {
requestBody.exclude_domains = params.excludeDomains
}
const response = await fetch(`${this.apiHost}/search`, {
method: 'POST',
headers: {
'Content-Type': 'application/json'
},
body: JSON.stringify({
...requestBody,
api_key: this.apiKey
}),
signal: params.signal
})
if (!response.ok) {
const errorText = await response.text()
throw new Error(`Tavily API error (${response.status}): ${errorText}`)
}
const data: TavilySearchResponse = await response.json()
// 返回完整的 Tavily 结果(包含 answer、images 等字段)
return {
query: data.query,
results: data.results.slice(0, params.maxResults || 5).map(
(item): TavilySearchResultType => ({
title: item.title || 'No title',
content: item.raw_content || item.content || '',
url: item.url || '',
rawContent: item.raw_content,
score: item.score,
answer: data.answer, // Tavily 的直接答案
images: data.images // Tavily 的图片
})
)
}
} catch (error) {
if (error instanceof DOMException && error.name === 'AbortError') {
throw error
}
logger.error('Tavily search failed:', error as Error)
throw new Error(`Search failed: ${error instanceof Error ? error.message : 'Unknown error'}`)
}

View File

@@ -1,6 +1,6 @@
import { loggerService } from '@logger'
import { WebSearchState } from '@renderer/store/websearch'
import { WebSearchProvider, WebSearchProviderResponse } from '@renderer/types'
import { ProviderSpecificParams, WebSearchProvider, WebSearchProviderResponse } from '@renderer/types'
import BaseWebSearchProvider from './BaseWebSearchProvider'
@@ -43,7 +43,12 @@ export default class ZhipuProvider extends BaseWebSearchProvider {
}
}
public async search(query: string, websearch: WebSearchState): Promise<WebSearchProviderResponse> {
public async search(
query: string,
websearch: WebSearchState,
httpOptions?: RequestInit,
_providerParams?: ProviderSpecificParams
): Promise<WebSearchProviderResponse> {
try {
if (!query.trim()) {
throw new Error('Search query cannot be empty')
@@ -62,7 +67,8 @@ export default class ZhipuProvider extends BaseWebSearchProvider {
'Content-Type': 'application/json',
...this.defaultHeaders()
},
body: JSON.stringify(requestBody)
body: JSON.stringify(requestBody),
signal: httpOptions?.signal
})
if (!response.ok) {

View File

@@ -1,6 +1,6 @@
import { withSpanResult } from '@renderer/services/SpanManagerService'
import type { WebSearchState } from '@renderer/store/websearch'
import { WebSearchProvider, WebSearchProviderResponse } from '@renderer/types'
import { ProviderSpecificParams, WebSearchProvider, WebSearchProviderResponse } from '@renderer/types'
import { filterResultWithBlacklist } from '@renderer/utils/blacklistMatchPattern'
import BaseWebSearchProvider from './BaseWebSearchProvider'
@@ -24,10 +24,11 @@ export default class WebSearchEngineProvider {
public async search(
query: string,
websearch: WebSearchState,
httpOptions?: RequestInit
httpOptions?: RequestInit,
providerParams?: ProviderSpecificParams
): Promise<WebSearchProviderResponse> {
const callSearch = async ({ query, websearch }) => {
return await this.sdk.search(query, websearch, httpOptions)
const callSearch = async ({ query, websearch, providerParams }) => {
return await this.sdk.search(query, websearch, httpOptions, providerParams)
}
const traceParams = {
@@ -38,7 +39,7 @@ export default class WebSearchEngineProvider {
modelName: this.modelName
}
const result = await withSpanResult(callSearch, traceParams, { query, websearch })
const result = await withSpanResult(callSearch, traceParams, { query, websearch, providerParams })
return await filterResultWithBlacklist(result, websearch)
}

View File

@@ -10,6 +10,7 @@ import {
KnowledgeBase,
KnowledgeItem,
KnowledgeReference,
ProviderSpecificParams,
WebSearchProvider,
WebSearchProviderResponse,
WebSearchProviderResult,
@@ -161,13 +162,17 @@ class WebSearchService {
* @public
* @param provider 搜索提供商
* @param query 搜索查询
* @param httpOptions HTTP选项包含signal等
* @param spanId Span ID用于追踪
* @param providerParams Provider特定参数如Exa的category、Tavily的searchDepth等
* @returns 搜索响应
*/
public async search(
provider: WebSearchProvider,
query: string,
httpOptions?: RequestInit,
spanId?: string
spanId?: string,
providerParams?: ProviderSpecificParams
): Promise<WebSearchProviderResponse> {
const websearch = this.getWebSearchState()
const webSearchEngine = new WebSearchEngineProvider(provider, spanId)
@@ -178,7 +183,7 @@ class WebSearchService {
formattedQuery = `today is ${dayjs().format('YYYY-MM-DD')} \r\n ${query}`
}
return await webSearchEngine.search(formattedQuery, websearch, httpOptions)
return await webSearchEngine.search(formattedQuery, websearch, httpOptions, providerParams)
}
/**
@@ -424,13 +429,17 @@ class WebSearchService {
* @param webSearchProvider - 要使用的网络搜索提供商
* @param extractResults - 包含搜索问题和链接的提取结果对象
* @param requestId - 唯一的请求标识符,用于状态跟踪和资源管理
* @param externalSignal - 可选的 AbortSignal 用于取消请求
* @param providerParams - 可选的 Provider 特定参数(如 Exa 的 category、Tavily 的 searchDepth 等)
*
* @returns 包含搜索结果的响应对象
*/
public async processWebsearch(
webSearchProvider: WebSearchProvider,
extractResults: ExtractResults,
requestId: string
requestId: string,
externalSignal?: AbortSignal,
providerParams?: ProviderSpecificParams
): Promise<WebSearchProviderResponse> {
// 重置状态
await this.setWebSearchStatus(requestId, { phase: 'default' })
@@ -441,8 +450,8 @@ class WebSearchService {
return { results: [] }
}
// 使用请求特定的signal如果没有则回退到全局signal
const signal = this.getRequestState(requestId).signal || this.signal
// 优先使用外部传入的signal其次是请求特定的signal最后回退到全局signal
const signal = externalSignal || this.getRequestState(requestId).signal || this.signal
const span = webSearchProvider.topicId
? addSpan({
@@ -473,8 +482,9 @@ class WebSearchService {
return { query: 'summaries', results: contents }
}
// 执行搜索
const searchPromises = questions.map((q) =>
this.search(webSearchProvider, q, { signal }, span?.spanContext().spanId)
this.search(webSearchProvider, q, { signal }, span?.spanContext().spanId, providerParams)
)
const searchResults = await Promise.allSettled(searchPromises)

View File

@@ -84,7 +84,8 @@ export const createToolCallbacks = (deps: ToolCallbacksDependencies) => {
}
blockManager.smartBlockUpdate(existingBlockId, changes, MessageBlockType.TOOL, true)
// Handle citation block creation for web search results
if (toolResponse.tool.name === 'builtin_web_search' && toolResponse.response) {
const webSearchTools = ['builtin_web_search', 'builtin_exa_search', 'builtin_tavily_search']
if (webSearchTools.includes(toolResponse.tool.name) && toolResponse.response) {
const citationBlock = createCitationBlock(
assistantMsgId,
{

View File

@@ -4,6 +4,7 @@ import { createEntityAdapter, createSelector, createSlice, type PayloadAction }
import { AISDKWebSearchResult, Citation, WebSearchProviderResponse, WebSearchSource } from '@renderer/types'
import type { CitationMessageBlock, MessageBlock } from '@renderer/types/newMessage'
import { MessageBlockType } from '@renderer/types/newMessage'
import { adaptSearchResultsToCitations } from '@renderer/utils/searchResultAdapters'
import type OpenAI from 'openai'
import type { RootState } from './index' // 确认 RootState 从 store/index.ts 导出
@@ -217,17 +218,12 @@ export const formatCitationsFromBlock = (block: CitationMessageBlock | undefined
type: 'websearch'
})) || []
break
case WebSearchSource.WEBSEARCH:
formattedCitations =
(block.response.results as WebSearchProviderResponse)?.results?.map((result, index) => ({
number: index + 1,
url: result.url,
title: result.title,
content: result.content,
showFavicon: true,
type: 'websearch'
})) || []
case WebSearchSource.WEBSEARCH: {
const results = (block.response.results as WebSearchProviderResponse)?.results || []
// 使用适配器统一转换,自动处理 Provider 特定字段(如 Exa 的 favicon、Tavily 的 answer 等)
formattedCitations = adaptSearchResultsToCitations(results)
break
}
case WebSearchSource.AISDK:
formattedCitations =
(block.response?.results as AISDKWebSearchResult[])?.map((result, index) => ({

View File

@@ -575,17 +575,63 @@ export type WebSearchProvider = {
modelName?: string
}
export type WebSearchProviderResult = {
// Base search result shape (every provider must supply these fields)
export interface BaseSearchResult {
title: string
content: string
url: string
}
// Exa provider-specific extension of the base result; every added field is optional
export interface ExaSearchResult extends BaseSearchResult {
favicon?: string
publishedDate?: string
author?: string
score?: number
highlights?: string[]
}
// Tavily provider-specific extension of the base result; every added field is optional
export interface TavilySearchResult extends BaseSearchResult {
answer?: string // Tavily's AI-generated direct answer
images?: string[]
rawContent?: string
score?: number
}
// Union of the base result and the provider-specific result shapes (backward compatible:
// a plain { title, content, url } object is still a valid result)
export type WebSearchProviderResult = BaseSearchResult | ExaSearchResult | TavilySearchResult
// Response of a web search: the list of results plus an optional query string
export type WebSearchProviderResponse = {
query?: string
results: WebSearchProviderResult[]
}
// Provider-specific parameter types
// Optional search parameters for the Exa provider (search type, category, and
// published/crawl date range bounds); all fields may be omitted
export interface ExaSearchParams {
type?: 'neural' | 'keyword' | 'auto' | 'fast'
category?: string
startPublishedDate?: string
endPublishedDate?: string
startCrawlDate?: string
endCrawlDate?: string
useAutoprompt?: boolean
}
// Optional search parameters for the Tavily provider (topic, search depth, and
// flags selecting which extra payloads to include); all fields may be omitted
export interface TavilySearchParams {
topic?: 'general' | 'news' | 'finance'
searchDepth?: 'basic' | 'advanced'
includeAnswer?: boolean
includeRawContent?: boolean
includeImages?: boolean
}
// Container for provider-specific parameters: one optional entry per supported
// provider (note: this is an interface keyed by provider name, not a union type)
export interface ProviderSpecificParams {
exa?: ExaSearchParams
tavily?: TavilySearchParams
}
export type AISDKWebSearchResult = Omit<Extract<LanguageModelV2Source, { sourceType: 'url' }>, 'sourceType'>
export type WebSearchResults =
@@ -813,6 +859,7 @@ export interface Citation {
hostname?: string
content?: string
showFavicon?: boolean
favicon?: string // 新增:直接的 favicon URL来自 Provider
type?: string
metadata?: Record<string, any>
}

View File

@@ -0,0 +1,77 @@
/**
* 搜索结果适配器
* 将不同 Provider 的搜索结果统一转换为 Citation 格式
*/
import type { Citation, WebSearchProviderResult } from '@renderer/types'
/**
 * Convert a single web-search result into a `Citation`.
 *
 * The base fields (`url`, `title`, `content`) are always copied. Optional
 * provider extras are picked up only when present and non-empty:
 * `favicon` becomes a top-level citation field, while `publishedDate`,
 * `author`, `score`, `highlights`, `answer` and `images` are gathered under
 * `metadata`. The `metadata` key is omitted entirely when none of those
 * extras is set.
 *
 * @param result - Search result, possibly carrying provider-specific fields
 * @param index - Zero-based position of the result; the citation number is `index + 1`
 * @returns The citation built from `result`
 */
export function adaptSearchResultToCitation(result: WebSearchProviderResult, index: number): Citation {
  // Fields every provider supplies, plus the favicon when the result carries one.
  const citation: Citation = {
    number: index + 1,
    url: result.url,
    title: result.title,
    content: result.content,
    showFavicon: true,
    type: 'websearch',
    ...('favicon' in result && result.favicon ? { favicon: result.favicon } : {})
  }

  // Optional extras; the spread order fixes the key order of the metadata object.
  // `score` is kept whenever it is defined, so a score of 0 is preserved.
  const metadata: Record<string, any> = {
    ...('publishedDate' in result && result.publishedDate ? { publishedDate: result.publishedDate } : {}),
    ...('author' in result && result.author ? { author: result.author } : {}),
    ...('score' in result && result.score !== undefined ? { score: result.score } : {}),
    ...('highlights' in result && result.highlights && result.highlights.length > 0
      ? { highlights: result.highlights }
      : {}),
    ...('answer' in result && result.answer ? { answer: result.answer } : {}),
    ...('images' in result && result.images && result.images.length > 0 ? { images: result.images } : {})
  }

  // Attach metadata only when at least one extra was collected.
  return Object.keys(metadata).length === 0 ? citation : { ...citation, metadata }
}
/**
 * Convert a list of search results into citations.
 *
 * Each result is passed to `adaptSearchResultToCitation` together with its
 * array index, so citations keep the input order and are numbered from 1.
 *
 * @param results - Search results to convert
 * @returns One citation per input result
 */
export function adaptSearchResultsToCitations(results: WebSearchProviderResult[]): Citation[] {
  // `map` supplies (item, index, array); the adapter only consumes the first two.
  return results.map(adaptSearchResultToCitation)
}