Compare commits

..

3 Commits

Author SHA1 Message Date
suyao
7f8d0b06ee Merge branch 'main' into fix/check-api-key 2025-12-01 16:37:43 +08:00
suyao
4be5fedeec fix 2025-12-01 00:07:43 +08:00
suyao
163e016759 fix: enhance provider handling and API key rotation logic in AiProvider 2025-12-01 00:01:01 +08:00
6 changed files with 144 additions and 175 deletions

View File

@@ -478,16 +478,13 @@ class FileStorage {
}
}
/**
* Core file reading logic that handles both documents and text files.
*
* @private
* @param filePath - Full path to the file
* @param detectEncoding - Whether to auto-detect text file encoding
* @returns Promise resolving to the extracted text content
* @throws Error if file reading fails
*/
private async readFileCore(filePath: string, detectEncoding: boolean = false): Promise<string> {
public readFile = async (
_: Electron.IpcMainInvokeEvent,
id: string,
detectEncoding: boolean = false
): Promise<string> => {
const filePath = path.join(this.storageDir, id)
const fileExtension = path.extname(filePath)
if (documentExts.includes(fileExtension)) {
@@ -507,7 +504,7 @@ class FileStorage {
return data
} catch (error) {
chdir(originalCwd)
logger.error('Failed to read document file:', error as Error)
logger.error('Failed to read file:', error as Error)
throw error
}
}
@@ -519,72 +516,11 @@ class FileStorage {
return fs.readFileSync(filePath, 'utf-8')
}
} catch (error) {
logger.error('Failed to read text file:', error as Error)
logger.error('Failed to read file:', error as Error)
throw new Error(`Failed to read file: ${filePath}.`)
}
}
/**
* Reads and extracts content from a stored file.
*
* Supports multiple file formats including:
* - Complex documents: .pdf, .doc, .docx, .pptx, .xlsx, .odt, .odp, .ods
* - Text files: .txt, .md, .json, .csv, etc.
* - Code files: .js, .ts, .py, .java, etc.
*
* For document formats, extracts text content using specialized parsers:
* - .doc files: Uses word-extractor library
* - Other Office formats: Uses officeparser library
*
* For text files, can optionally detect encoding automatically.
*
* @param _ - Electron IPC invoke event (unused)
* @param id - File identifier with extension (e.g., "uuid.docx")
* @param detectEncoding - Whether to auto-detect text file encoding (default: false)
* @returns Promise resolving to the extracted text content of the file
* @throws Error if file reading fails or file is not found
*
* @example
* // Read a DOCX file
* const content = await readFile(event, "document.docx");
*
* @example
* // Read a text file with encoding detection
* const content = await readFile(event, "text.txt", true);
*
* @example
* // Read a PDF file
* const content = await readFile(event, "manual.pdf");
*/
public readFile = async (
_: Electron.IpcMainInvokeEvent,
id: string,
detectEncoding: boolean = false
): Promise<string> => {
const filePath = path.join(this.storageDir, id)
return this.readFileCore(filePath, detectEncoding)
}
/**
* Reads and extracts content from an external file path.
*
* Similar to readFile, but operates on external file paths instead of stored files.
* Supports the same file formats including complex documents and text files.
*
* @param _ - Electron IPC invoke event (unused)
* @param filePath - Absolute path to the external file
* @param detectEncoding - Whether to auto-detect text file encoding (default: false)
* @returns Promise resolving to the extracted text content of the file
* @throws Error if file does not exist or reading fails
*
* @example
* // Read an external DOCX file
* const content = await readExternalFile(event, "/path/to/document.docx");
*
* @example
* // Read an external text file with encoding detection
* const content = await readExternalFile(event, "/path/to/text.txt", true);
*/
public readExternalFile = async (
_: Electron.IpcMainInvokeEvent,
filePath: string,
@@ -594,7 +530,40 @@ class FileStorage {
throw new Error(`File does not exist: ${filePath}`)
}
return this.readFileCore(filePath, detectEncoding)
const fileExtension = path.extname(filePath)
if (documentExts.includes(fileExtension)) {
const originalCwd = process.cwd()
try {
chdir(this.tempDir)
if (fileExtension === '.doc') {
const extractor = new WordExtractor()
const extracted = await extractor.extract(filePath)
chdir(originalCwd)
return extracted.getBody()
}
const data = await officeParser.parseOfficeAsync(filePath)
chdir(originalCwd)
return data
} catch (error) {
chdir(originalCwd)
logger.error('Failed to read file:', error as Error)
throw error
}
}
try {
if (detectEncoding) {
return readTextFileWithAutoEncoding(filePath)
} else {
return fs.readFileSync(filePath, 'utf-8')
}
} catch (error) {
logger.error('Failed to read file:', error as Error)
throw new Error(`Failed to read file: ${filePath}.`)
}
}
public createTempFile = async (_: Electron.IpcMainInvokeEvent, fileName: string): Promise<string> => {

View File

@@ -120,9 +120,12 @@ export default class ModernAiProvider {
throw new Error('Model is required for completions. Please use constructor with model parameter.')
}
// 每次请求时重新生成配置以确保API key轮换生效
this.config = providerToAiSdkConfig(this.actualProvider, this.model)
logger.debug('Generated provider config for completions', this.config)
// Config is now set in constructor, ApiService handles key rotation before passing provider
if (!this.config) {
// If config wasn't set in constructor (when provider only), generate it now
this.config = providerToAiSdkConfig(this.actualProvider, this.model!)
}
logger.debug('Using provider config for completions', this.config)
// 检查 config 是否存在
if (!this.config) {

View File

@@ -29,32 +29,6 @@ import { azureAnthropicProviderCreator } from './config/azure-anthropic'
import { COPILOT_DEFAULT_HEADERS } from './constants'
import { getAiSdkProviderId } from './factory'
/**
* 获取轮询的API key
* 复用legacy架构的多key轮询逻辑
*/
function getRotatedApiKey(provider: Provider): string {
const keys = provider.apiKey.split(',').map((key) => key.trim())
const keyName = `provider:${provider.id}:last_used_key`
if (keys.length === 1) {
return keys[0]
}
const lastUsedKey = window.keyv.get(keyName)
if (!lastUsedKey) {
window.keyv.set(keyName, keys[0])
return keys[0]
}
const currentIndex = keys.indexOf(lastUsedKey)
const nextIndex = (currentIndex + 1) % keys.length
const nextKey = keys[nextIndex]
window.keyv.set(keyName, nextKey)
return nextKey
}
/**
* 处理特殊provider的转换逻辑
*/
@@ -161,7 +135,7 @@ export function providerToAiSdkConfig(actualProvider: Provider, model: Model): A
const { baseURL, endpoint } = routeToEndpoint(actualProvider.apiHost)
const baseConfig = {
baseURL: baseURL,
apiKey: getRotatedApiKey(actualProvider)
apiKey: actualProvider.apiKey
}
const isCopilotProvider = actualProvider.id === SystemProviderIds.copilot

View File

@@ -39,7 +39,6 @@ import {
detectLanguage,
determineTargetLanguage
} from '@renderer/utils/translate'
import { documentExts } from '@shared/config/constant'
import { imageExts, MB, textExts } from '@shared/config/constant'
import { Button, Flex, FloatButton, Popover, Tooltip, Typography } from 'antd'
import type { TextAreaRef } from 'antd/es/input/TextArea'
@@ -67,7 +66,7 @@ const TranslatePage: FC = () => {
const { prompt, getLanguageByLangcode, settings } = useTranslate()
const { autoCopy } = settings
const { shikiMarkdownIt } = useCodeStyle()
const { onSelectFile, selecting, clearFiles } = useFiles({ extensions: [...imageExts, ...textExts, ...documentExts] })
const { onSelectFile, selecting, clearFiles } = useFiles({ extensions: [...imageExts, ...textExts] })
const { ocr } = useOcr()
const { setTimeoutTimer } = useTimer()
@@ -485,56 +484,33 @@ const TranslatePage: FC = () => {
const readFile = useCallback(
async (file: FileMetadata) => {
const _readFile = async () => {
let isText: boolean
try {
const fileExtension = getFileExtension(file.path)
// 检查文件是否为文本文件
isText = await isTextFile(file.path)
} catch (e) {
logger.error('Failed to check if file is text.', e as Error)
window.toast.error(t('translate.files.error.check_type') + ': ' + formatErrorMessage(e))
return
}
// Check if file is supported format (text file or document file)
let isText: boolean
const isDocument: boolean = documentExts.includes(fileExtension)
if (!isText) {
window.toast.error(t('common.file.not_supported', { type: getFileExtension(file.path) }))
logger.error('Unsupported file type.')
return
}
if (!isDocument) {
try {
// For non-document files, check if it's a text file
isText = await isTextFile(file.path)
} catch (e) {
logger.error('Failed to check file type.', e as Error)
window.toast.error(t('translate.files.error.check_type') + ': ' + formatErrorMessage(e))
return
}
} else {
isText = false
}
if (!isText && !isDocument) {
window.toast.error(t('common.file.not_supported', { type: fileExtension }))
logger.error('Unsupported file type.')
return
}
// File size check - document files allowed to be larger
const maxSize = isDocument ? 20 * MB : 5 * MB
if (file.size > maxSize) {
window.toast.error(t('translate.files.error.too_large') + ` (0 ~ ${maxSize / MB} MB)`)
return
}
let result: string
// the threshold may be too large
if (file.size > 5 * MB) {
window.toast.error(t('translate.files.error.too_large') + ' (0 ~ 5 MB)')
} else {
try {
if (isDocument) {
// Use the new document reading API
result = await window.api.file.readExternal(file.path, true)
} else {
// Read text file
result = await window.api.fs.readText(file.path)
}
const result = await window.api.fs.readText(file.path)
setText(text + result)
} catch (e) {
logger.error('Failed to read file.', e as Error)
logger.error('Failed to read text file.', e as Error)
window.toast.error(t('translate.files.error.unknown') + ': ' + formatErrorMessage(e))
}
} catch (e) {
logger.error('Failed to read file.', e as Error)
window.toast.error(t('translate.files.error.unknown') + ': ' + formatErrorMessage(e))
}
}
const promise = _readFile()

View File

@@ -8,8 +8,8 @@ import { isDedicatedImageGenerationModel, isEmbeddingModel, isFunctionCallingMod
import { getStoreSetting } from '@renderer/hooks/useSettings'
import i18n from '@renderer/i18n'
import store from '@renderer/store'
import type { FetchChatCompletionParams } from '@renderer/types'
import type { Assistant, MCPServer, MCPTool, Model, Provider } from '@renderer/types'
import { type FetchChatCompletionParams, isSystemProvider } from '@renderer/types'
import type { StreamTextParams } from '@renderer/types/aiCoreTypes'
import { type Chunk, ChunkType } from '@renderer/types/chunk'
import type { Message, ResponseError } from '@renderer/types/newMessage'
@@ -22,7 +22,8 @@ import { purifyMarkdownImages } from '@renderer/utils/markdown'
import { isPromptToolUse, isSupportedToolUse } from '@renderer/utils/mcp-tools'
import { findFileBlocks, getMainTextContent } from '@renderer/utils/messageUtils/find'
import { containsSupportedVariables, replacePromptVariables } from '@renderer/utils/prompt'
import { isEmpty, takeRight } from 'lodash'
import { NOT_SUPPORT_API_KEY_PROVIDERS } from '@renderer/utils/provider'
import { cloneDeep, isEmpty, takeRight } from 'lodash'
import type { ModernAiProviderConfig } from '../aiCore/index_new'
import AiProviderNew from '../aiCore/index_new'
@@ -43,6 +44,8 @@ import {
// } from './MessagesService'
// import WebSearchService from './WebSearchService'
// FIXME: 这里太多重复逻辑,需要重构
const logger = loggerService.withContext('ApiService')
export async function fetchMcpTools(assistant: Assistant) {
@@ -95,7 +98,15 @@ export async function fetchChatCompletion({
modelId: assistant.model?.id,
modelName: assistant.model?.name
})
const AI = new AiProviderNew(assistant.model || getDefaultModel())
// Get base provider and apply API key rotation
const baseProvider = getProviderByModel(assistant.model || getDefaultModel())
const providerWithRotatedKey = {
...cloneDeep(baseProvider),
apiKey: getRotatedApiKey(baseProvider)
}
const AI = new AiProviderNew(assistant.model || getDefaultModel(), providerWithRotatedKey)
const provider = AI.getActualProvider()
const mcpTools: MCPTool[] = []
@@ -172,7 +183,13 @@ export async function fetchMessagesSummary({ messages, assistant }: { messages:
return null
}
const AI = new AiProviderNew(model)
// Apply API key rotation
const providerWithRotatedKey = {
...cloneDeep(provider),
apiKey: getRotatedApiKey(provider)
}
const AI = new AiProviderNew(model, providerWithRotatedKey)
const topicId = messages?.find((message) => message.topicId)?.topicId || ''
@@ -271,7 +288,13 @@ export async function fetchNoteSummary({ content, assistant }: { content: string
return null
}
const AI = new AiProviderNew(model)
// Apply API key rotation
const providerWithRotatedKey = {
...cloneDeep(provider),
apiKey: getRotatedApiKey(provider)
}
const AI = new AiProviderNew(model, providerWithRotatedKey)
// only 2000 char and no images
const truncatedContent = content.substring(0, 2000)
@@ -359,7 +382,13 @@ export async function fetchGenerate({
return ''
}
const AI = new AiProviderNew(model)
// Apply API key rotation
const providerWithRotatedKey = {
...cloneDeep(provider),
apiKey: getRotatedApiKey(provider)
}
const AI = new AiProviderNew(model, providerWithRotatedKey)
const assistant = getDefaultAssistant()
assistant.model = model
@@ -404,28 +433,44 @@ export async function fetchGenerate({
export function hasApiKey(provider: Provider) {
if (!provider) return false
if (['ollama', 'lmstudio', 'vertexai', 'cherryai'].includes(provider.id)) return true
if (isSystemProvider(provider) && NOT_SUPPORT_API_KEY_PROVIDERS.includes(provider.id)) return true
return !isEmpty(provider.apiKey)
}
/**
* Get the first available embedding model from enabled providers
* 获取轮询的API key
* 复用legacy架构的多key轮询逻辑
*/
// function getFirstEmbeddingModel() {
// const providers = store.getState().llm.providers.filter((p) => p.enabled)
function getRotatedApiKey(provider: Provider): string {
const keys = provider.apiKey.split(',').map((key) => key.trim())
const keyName = `provider:${provider.id}:last_used_key`
// for (const provider of providers) {
// const embeddingModel = provider.models.find((model) => isEmbeddingModel(model))
// if (embeddingModel) {
// return embeddingModel
// }
// }
if (keys.length === 1) {
return keys[0]
}
// return undefined
// }
const lastUsedKey = window.keyv.get(keyName)
if (!lastUsedKey) {
window.keyv.set(keyName, keys[0])
return keys[0]
}
const currentIndex = keys.indexOf(lastUsedKey)
const nextIndex = (currentIndex + 1) % keys.length
const nextKey = keys[nextIndex]
window.keyv.set(keyName, nextKey)
return nextKey
}
export async function fetchModels(provider: Provider): Promise<SdkModel[]> {
const AI = new AiProviderNew(provider)
// Apply API key rotation
const providerWithRotatedKey = {
...cloneDeep(provider),
apiKey: getRotatedApiKey(provider)
}
const AI = new AiProviderNew(providerWithRotatedKey)
try {
return await AI.models()
@@ -435,12 +480,7 @@ export async function fetchModels(provider: Provider): Promise<SdkModel[]> {
}
export function checkApiProvider(provider: Provider): void {
if (
provider.id !== 'ollama' &&
provider.id !== 'lmstudio' &&
provider.type !== 'vertexai' &&
provider.id !== 'copilot'
) {
if (isSystemProvider(provider) && !NOT_SUPPORT_API_KEY_PROVIDERS.includes(provider.id)) {
if (!provider.apiKey) {
window.toast.error(i18n.t('message.error.enter.api.label'))
throw new Error(i18n.t('message.error.enter.api.label'))
@@ -461,8 +501,7 @@ export function checkApiProvider(provider: Provider): void {
export async function checkApi(provider: Provider, model: Model, timeout = 15000): Promise<void> {
checkApiProvider(provider)
// Don't pass in provider parameter. We need auto-format URL
const ai = new AiProviderNew(model)
const ai = new AiProviderNew(model, provider)
const assistant = getDefaultAssistant()
assistant.model = model

View File

@@ -183,3 +183,11 @@ export const isSupportAPIVersionProvider = (provider: Provider) => {
}
return provider.apiOptions?.isNotSupportAPIVersion !== false
}
export const NOT_SUPPORT_API_KEY_PROVIDERS: readonly SystemProviderId[] = [
'ollama',
'lmstudio',
'vertexai',
'aws-bedrock',
'copilot'
]