feat: add file attachment
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
import { electronApp, optimizer } from '@electron-toolkit/utils'
|
||||
import { app, BrowserWindow } from 'electron'
|
||||
import installExtension, { REDUX_DEVTOOLS } from 'electron-devtools-installer'
|
||||
|
||||
import { registerIpc } from './ipc'
|
||||
import { updateUserDataPath } from './utils/upgrade'
|
||||
@@ -30,6 +31,12 @@ app.whenReady().then(async () => {
|
||||
const mainWindow = createMainWindow()
|
||||
|
||||
registerIpc(mainWindow, app)
|
||||
|
||||
if (process.env.NODE_ENV === 'development') {
|
||||
installExtension(REDUX_DEVTOOLS)
|
||||
.then((name) => console.log(`Added Extension: ${name}`))
|
||||
.catch((err) => console.log('An error occurred: ', err))
|
||||
}
|
||||
})
|
||||
|
||||
// Quit when all windows are closed, except on macOS. There, it's common
|
||||
|
||||
+3
-23
@@ -1,8 +1,5 @@
|
||||
import { FileType } from '@types'
|
||||
import { BrowserWindow, ipcMain, OpenDialogOptions, session, shell } from 'electron'
|
||||
import Logger from 'electron-log'
|
||||
import fs from 'fs'
|
||||
import path from 'path'
|
||||
|
||||
import { appConfig, titleBarOverlayDark, titleBarOverlayLight } from './config'
|
||||
import AppUpdater from './services/AppUpdater'
|
||||
@@ -38,29 +35,12 @@ export function registerIpc(mainWindow: BrowserWindow, app: Electron.App) {
|
||||
ipcMain.handle('zip:compress', (_, text: string) => compress(text))
|
||||
ipcMain.handle('zip:decompress', (_, text: Buffer) => decompress(text))
|
||||
|
||||
ipcMain.handle('image:base64', async (_, filePath) => {
|
||||
try {
|
||||
const data = await fs.promises.readFile(filePath)
|
||||
const base64 = data.toString('base64')
|
||||
const mime = `image/${path.extname(filePath).slice(1)}`
|
||||
return {
|
||||
mime,
|
||||
base64,
|
||||
data: `data:${mime};base64,${base64}`
|
||||
}
|
||||
} catch (error) {
|
||||
Logger.error('Error reading file:', error)
|
||||
return ''
|
||||
}
|
||||
})
|
||||
|
||||
ipcMain.handle('file:base64Image', async (_, id) => await fileManager.base64Image(id))
|
||||
ipcMain.handle('file:select', async (_, options?: OpenDialogOptions) => await fileManager.selectFile(options))
|
||||
ipcMain.handle('file:upload', async (_, file: FileType) => await fileManager.uploadFile(file))
|
||||
ipcMain.handle('file:clear', async () => await fileManager.clear())
|
||||
ipcMain.handle('file:delete', async (_, fileId: string) => {
|
||||
await fileManager.deleteFile(fileId)
|
||||
return { success: true }
|
||||
})
|
||||
ipcMain.handle('file:read', async (_, id: string) => await fileManager.readFile(id))
|
||||
ipcMain.handle('file:delete', async (_, id: string) => await fileManager.deleteFile(id))
|
||||
|
||||
ipcMain.handle('minapp', (_, args) => {
|
||||
createMinappWindow({
|
||||
|
||||
@@ -135,6 +135,23 @@ class File {
|
||||
await fs.promises.unlink(path.join(this.storageDir, id))
|
||||
}
|
||||
|
||||
async readFile(id: string): Promise<string> {
|
||||
const filePath = path.join(this.storageDir, id)
|
||||
return fs.readFileSync(filePath, 'utf8')
|
||||
}
|
||||
|
||||
async base64Image(id: string): Promise<{ mime: string; base64: string; data: string }> {
|
||||
const filePath = path.join(this.storageDir, id)
|
||||
const data = await fs.promises.readFile(filePath)
|
||||
const base64 = data.toString('base64')
|
||||
const mime = `image/${path.extname(filePath).slice(1)}`
|
||||
return {
|
||||
mime,
|
||||
base64,
|
||||
data: `data:${mime};base64,${base64}`
|
||||
}
|
||||
}
|
||||
|
||||
async clear(): Promise<void> {
|
||||
await fs.promises.rmdir(this.storageDir, { recursive: true })
|
||||
await this.initStorageDir()
|
||||
|
||||
+92
-1
@@ -56,12 +56,103 @@ export function getFileType(ext: string): FileTypes {
|
||||
const imageExts = ['.jpg', '.jpeg', '.png', '.gif', '.bmp', '.webp']
|
||||
const videoExts = ['.mp4', '.avi', '.mov', '.wmv', '.flv', '.mkv']
|
||||
const audioExts = ['.mp3', '.wav', '.ogg', '.flac', '.aac']
|
||||
const documentExts = ['.pdf', '.doc', '.docx', '.xls', '.xlsx', '.ppt', '.pptx', '.txt']
|
||||
const documentExts = ['.pdf', '.doc', '.docx', '.xls', '.xlsx', '.ppt', '.pptx']
|
||||
const textExts = [
|
||||
'.txt', // 普通文本文件
|
||||
'.md', // Markdown 文件
|
||||
'.mdx', // Markdown 文件
|
||||
'.html', // HTML 文件
|
||||
'.htm', // HTML 文件的另一种扩展名
|
||||
'.xml', // XML 文件
|
||||
'.json', // JSON 文件
|
||||
'.yaml', // YAML 文件
|
||||
'.yml', // YAML 文件的另一种扩展名
|
||||
'.csv', // 逗号分隔值文件
|
||||
'.tsv', // 制表符分隔值文件
|
||||
'.ini', // 配置文件
|
||||
'.log', // 日志文件
|
||||
'.rtf', // 富文本格式文件
|
||||
'.tex', // LaTeX 文件
|
||||
'.srt', // 字幕文件
|
||||
'.xhtml', // XHTML 文件
|
||||
'.nfo', // 信息文件(主要用于场景发布)
|
||||
'.conf', // 配置文件
|
||||
'.config', // 配置文件
|
||||
'.env', // 环境变量文件
|
||||
'.properties', // 配置属性文件
|
||||
'.latex', // LaTeX 文档文件
|
||||
'.rst', // reStructuredText 文件
|
||||
'.php', // PHP 脚本文件,包含嵌入的 HTML
|
||||
'.js', // JavaScript 文件(部分是文本,部分可能包含代码)
|
||||
'.ts', // TypeScript 文件
|
||||
'.jsp', // JavaServer Pages 文件
|
||||
'.aspx', // ASP.NET 文件
|
||||
'.bat', // Windows 批处理文件
|
||||
'.sh', // Unix/Linux Shell 脚本文件
|
||||
'.py', // Python 脚本文件
|
||||
'.rb', // Ruby 脚本文件
|
||||
'.pl', // Perl 脚本文件
|
||||
'.sql', // SQL 脚本文件
|
||||
'.css', // Cascading Style Sheets 文件
|
||||
'.less', // Less CSS 预处理器文件
|
||||
'.scss', // Sass CSS 预处理器文件
|
||||
'.sass', // Sass 文件
|
||||
'.styl', // Stylus CSS 预处理器文件
|
||||
'.coffee', // CoffeeScript 文件
|
||||
'.ino', // Arduino 代码文件
|
||||
'.ino', // Arduino 代码文件
|
||||
'.asm', // Assembly 语言文件
|
||||
'.go', // Go 语言文件
|
||||
'.scala', // Scala 语言文件
|
||||
'.swift', // Swift 语言文件
|
||||
'.kt', // Kotlin 语言文件
|
||||
'.rs', // Rust 语言文件
|
||||
'.lua', // Lua 语言文件
|
||||
'.groovy', // Groovy 语言文件
|
||||
'.dart', // Dart 语言文件
|
||||
'.hs', // Haskell 语言文件
|
||||
'.clj', // Clojure 语言文件
|
||||
'.cljs', // ClojureScript 语言文件
|
||||
'.elm', // Elm 语言文件
|
||||
'.erl', // Erlang 语言文件
|
||||
'.ex', // Elixir 语言文件
|
||||
'.exs', // Elixir 脚本文件
|
||||
'.pug', // Pug (formerly Jade) 模板文件
|
||||
'.haml', // Haml 模板文件
|
||||
'.slim', // Slim 模板文件
|
||||
'.tpl', // 模板文件(通用)
|
||||
'.ejs', // Embedded JavaScript 模板文件
|
||||
'.hbs', // Handlebars 模板文件
|
||||
'.mustache', // Mustache 模板文件
|
||||
'.jade', // Jade 模板文件 (已重命名为 Pug)
|
||||
'.twig', // Twig 模板文件
|
||||
'.blade', // Blade 模板文件 (Laravel)
|
||||
'.vue', // Vue.js 单文件组件
|
||||
'.jsx', // React JSX 文件
|
||||
'.tsx', // React TSX 文件
|
||||
'.graphql', // GraphQL 查询语言文件
|
||||
'.gql', // GraphQL 查询语言文件
|
||||
'.proto', // Protocol Buffers 文件
|
||||
'.thrift', // Thrift 文件
|
||||
'.toml', // TOML 配置文件
|
||||
'.edn', // Clojure 数据表示文件
|
||||
'.cake', // CakePHP 配置文件
|
||||
'.ctp', // CakePHP 视图文件
|
||||
'.cfm', // ColdFusion 标记语言文件
|
||||
'.cfc', // ColdFusion 组件文件
|
||||
'.m', // Objective-C 源文件
|
||||
'.mm', // Objective-C++ 源文件
|
||||
'.gradle', // Gradle 构建文件
|
||||
'.groovy', // Gradle 构建文件
|
||||
'.gradle', // Gradle 构建文件
|
||||
'.kts' // Kotlin Script 文件
|
||||
]
|
||||
|
||||
ext = ext.toLowerCase()
|
||||
if (imageExts.includes(ext)) return FileTypes.IMAGE
|
||||
if (videoExts.includes(ext)) return FileTypes.VIDEO
|
||||
if (audioExts.includes(ext)) return FileTypes.AUDIO
|
||||
if (textExts.includes(ext)) return FileTypes.TEXT
|
||||
if (documentExts.includes(ext)) return FileTypes.DOCUMENT
|
||||
return FileTypes.OTHER
|
||||
}
|
||||
|
||||
Vendored
+3
-4
@@ -24,12 +24,11 @@ declare global {
|
||||
file: {
|
||||
select: (options?: OpenDialogOptions) => Promise<FileType[] | null>
|
||||
upload: (file: FileType) => Promise<FileType>
|
||||
delete: (fileId: string) => Promise<{ success: boolean }>
|
||||
delete: (fileId: string) => Promise<void>
|
||||
read: (fileId: string) => Promise<string>
|
||||
base64Image: (fileId: string) => Promise<{ mime: string; base64: string; data: string }>
|
||||
clear: () => Promise<void>
|
||||
}
|
||||
image: {
|
||||
base64: (filePath: string) => Promise<{ mime: string; base64: string; data: string }>
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -20,10 +20,9 @@ const api = {
|
||||
select: (options?: OpenDialogOptions) => ipcRenderer.invoke('file:select', options),
|
||||
upload: (filePath: string) => ipcRenderer.invoke('file:upload', filePath),
|
||||
delete: (fileId: string) => ipcRenderer.invoke('file:delete', fileId),
|
||||
read: (fileId: string) => ipcRenderer.invoke('file:read', fileId),
|
||||
base64Image: (fileId: string) => ipcRenderer.invoke('file:base64Image', fileId),
|
||||
clear: () => ipcRenderer.invoke('file:clear')
|
||||
},
|
||||
image: {
|
||||
base64: (filePath: string) => ipcRenderer.invoke('image:base64', filePath)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Binary file not shown.
|
After Width: | Height: | Size: 1.4 KiB |
@@ -7,3 +7,95 @@ export const platform = window.electron?.process?.platform
|
||||
export const isMac = platform === 'darwin'
|
||||
export const isWindows = platform === 'win32' || platform === 'win64'
|
||||
export const isLinux = platform === 'linux'
|
||||
|
||||
export const imageExts = ['jpg', 'png', 'jpeg']
|
||||
export const textExts = [
|
||||
'.txt', // 普通文本文件
|
||||
'.md', // Markdown 文件
|
||||
'.mdx', // Markdown 文件
|
||||
'.html', // HTML 文件
|
||||
'.htm', // HTML 文件的另一种扩展名
|
||||
'.xml', // XML 文件
|
||||
'.json', // JSON 文件
|
||||
'.yaml', // YAML 文件
|
||||
'.yml', // YAML 文件的另一种扩展名
|
||||
'.csv', // 逗号分隔值文件
|
||||
'.tsv', // 制表符分隔值文件
|
||||
'.ini', // 配置文件
|
||||
'.log', // 日志文件
|
||||
'.rtf', // 富文本格式文件
|
||||
'.tex', // LaTeX 文件
|
||||
'.srt', // 字幕文件
|
||||
'.xhtml', // XHTML 文件
|
||||
'.nfo', // 信息文件(主要用于场景发布)
|
||||
'.conf', // 配置文件
|
||||
'.config', // 配置文件
|
||||
'.env', // 环境变量文件
|
||||
'.properties', // 配置属性文件
|
||||
'.latex', // LaTeX 文档文件
|
||||
'.rst', // reStructuredText 文件
|
||||
'.php', // PHP 脚本文件,包含嵌入的 HTML
|
||||
'.js', // JavaScript 文件(部分是文本,部分可能包含代码)
|
||||
'.ts', // TypeScript 文件
|
||||
'.jsp', // JavaServer Pages 文件
|
||||
'.aspx', // ASP.NET 文件
|
||||
'.bat', // Windows 批处理文件
|
||||
'.sh', // Unix/Linux Shell 脚本文件
|
||||
'.py', // Python 脚本文件
|
||||
'.rb', // Ruby 脚本文件
|
||||
'.pl', // Perl 脚本文件
|
||||
'.sql', // SQL 脚本文件
|
||||
'.css', // Cascading Style Sheets 文件
|
||||
'.less', // Less CSS 预处理器文件
|
||||
'.scss', // Sass CSS 预处理器文件
|
||||
'.sass', // Sass 文件
|
||||
'.styl', // Stylus CSS 预处理器文件
|
||||
'.coffee', // CoffeeScript 文件
|
||||
'.ino', // Arduino 代码文件
|
||||
'.ino', // Arduino 代码文件
|
||||
'.asm', // Assembly 语言文件
|
||||
'.go', // Go 语言文件
|
||||
'.scala', // Scala 语言文件
|
||||
'.swift', // Swift 语言文件
|
||||
'.kt', // Kotlin 语言文件
|
||||
'.rs', // Rust 语言文件
|
||||
'.lua', // Lua 语言文件
|
||||
'.groovy', // Groovy 语言文件
|
||||
'.dart', // Dart 语言文件
|
||||
'.hs', // Haskell 语言文件
|
||||
'.clj', // Clojure 语言文件
|
||||
'.cljs', // ClojureScript 语言文件
|
||||
'.elm', // Elm 语言文件
|
||||
'.erl', // Erlang 语言文件
|
||||
'.ex', // Elixir 语言文件
|
||||
'.exs', // Elixir 脚本文件
|
||||
'.pug', // Pug (formerly Jade) 模板文件
|
||||
'.haml', // Haml 模板文件
|
||||
'.slim', // Slim 模板文件
|
||||
'.tpl', // 模板文件(通用)
|
||||
'.ejs', // Embedded JavaScript 模板文件
|
||||
'.hbs', // Handlebars 模板文件
|
||||
'.mustache', // Mustache 模板文件
|
||||
'.jade', // Jade 模板文件 (已重命名为 Pug)
|
||||
'.twig', // Twig 模板文件
|
||||
'.blade', // Blade 模板文件 (Laravel)
|
||||
'.vue', // Vue.js 单文件组件
|
||||
'.jsx', // React JSX 文件
|
||||
'.tsx', // React TSX 文件
|
||||
'.graphql', // GraphQL 查询语言文件
|
||||
'.gql', // GraphQL 查询语言文件
|
||||
'.proto', // Protocol Buffers 文件
|
||||
'.thrift', // Thrift 文件
|
||||
'.toml', // TOML 配置文件
|
||||
'.edn', // Clojure 数据表示文件
|
||||
'.cake', // CakePHP 配置文件
|
||||
'.ctp', // CakePHP 视图文件
|
||||
'.cfm', // ColdFusion 标记语言文件
|
||||
'.cfc', // ColdFusion 组件文件
|
||||
'.m', // Objective-C 源文件
|
||||
'.mm', // Objective-C++ 源文件
|
||||
'.gradle', // Gradle 构建文件
|
||||
'.groovy', // Gradle 构建文件
|
||||
'.gradle', // Gradle 构建文件
|
||||
'.kts' // Kotlin Script 文件
|
||||
]
|
||||
|
||||
@@ -14,6 +14,7 @@ import GemmaModelLogo from '@renderer/assets/images/models/gemma.jpeg'
|
||||
import HailuoModelLogo from '@renderer/assets/images/models/hailuo.png'
|
||||
import LlamaModelLogo from '@renderer/assets/images/models/llama.jpeg'
|
||||
import MicrosoftModelLogo from '@renderer/assets/images/models/microsoft.png'
|
||||
import MinicpmModelLogo from '@renderer/assets/images/models/minicpm.webp'
|
||||
import MixtralModelLogo from '@renderer/assets/images/models/mixtral.jpeg'
|
||||
import PalmModelLogo from '@renderer/assets/images/models/palm.svg'
|
||||
import QwenModelLogo from '@renderer/assets/images/models/qwen.png'
|
||||
@@ -91,6 +92,7 @@ export function getModelLogo(modelId: string) {
|
||||
}
|
||||
|
||||
const logoMap = {
|
||||
o1: OpenAiProviderLogo,
|
||||
gpt: ChatGPTModelLogo,
|
||||
glm: ChatGLMModelLogo,
|
||||
deepseek: DeepSeekModelLogo,
|
||||
@@ -112,7 +114,8 @@ export function getModelLogo(modelId: string) {
|
||||
abab: HailuoModelLogo,
|
||||
'ep-202': DoubaoModelLogo,
|
||||
cohere: CohereModelLogo,
|
||||
command: CohereModelLogo
|
||||
command: CohereModelLogo,
|
||||
minicpm: MinicpmModelLogo
|
||||
}
|
||||
|
||||
for (const key in logoMap) {
|
||||
|
||||
@@ -88,7 +88,7 @@ const resources = {
|
||||
'input.send': 'Send',
|
||||
'input.pause': 'Pause',
|
||||
'input.settings': 'Settings',
|
||||
'input.upload': 'Upload image png、jpg、jpeg',
|
||||
'input.upload': 'Upload image or text file',
|
||||
'input.context_count.tip': 'Context Count',
|
||||
'input.estimated_tokens.tip': 'Estimated tokens',
|
||||
'settings.temperature': 'Temperature',
|
||||
@@ -356,7 +356,7 @@ const resources = {
|
||||
'input.send': '发送',
|
||||
'input.pause': '暂停',
|
||||
'input.settings': '设置',
|
||||
'input.upload': '上传图片 png、jpg、jpeg',
|
||||
'input.upload': '上传图片或纯文本文件',
|
||||
'input.context_count.tip': '上下文数',
|
||||
'input.estimated_tokens.tip': '预估 token 数',
|
||||
'settings.temperature': '模型温度',
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
import { Navbar, NavbarCenter } from '@renderer/components/app/Navbar'
|
||||
import { VStack } from '@renderer/components/Layout'
|
||||
import db from '@renderer/databases'
|
||||
import { FileType } from '@renderer/types'
|
||||
import { FileType, FileTypes } from '@renderer/types'
|
||||
import { getFileDirectory } from '@renderer/utils'
|
||||
import { Image, Table } from 'antd'
|
||||
import dayjs from 'dayjs'
|
||||
import { useLiveQuery } from 'dexie-react-hooks'
|
||||
@@ -13,13 +14,17 @@ const FilesPage: FC = () => {
|
||||
const { t } = useTranslation()
|
||||
const files = useLiveQuery<FileType[]>(() => db.files.toArray())
|
||||
|
||||
const dataSource = files?.map((file) => ({
|
||||
key: file.id,
|
||||
file: <Image src={'file://' + file.path} preview={false} style={{ maxHeight: '40px' }} />,
|
||||
name: <a href={'file://' + file.path}>{file.origin_name}</a>,
|
||||
size: `${(file.size / 1024 / 1024).toFixed(2)} MB`,
|
||||
created_at: dayjs(file.created_at).format('MM-DD HH:mm')
|
||||
}))
|
||||
const dataSource = files?.map((file) => {
|
||||
const isImage = file.type === FileTypes.IMAGE
|
||||
const ImageView = <Image src={'file://' + file.path} preview={false} style={{ maxHeight: '40px' }} />
|
||||
return {
|
||||
key: file.id,
|
||||
file: isImage ? ImageView : file.origin_name,
|
||||
name: <a href={'file://' + getFileDirectory(file.path)}>{file.origin_name}</a>,
|
||||
size: `${(file.size / 1024 / 1024).toFixed(2)} MB`,
|
||||
created_at: dayjs(file.created_at).format('MM-DD HH:mm')
|
||||
}
|
||||
})
|
||||
|
||||
const columns = [
|
||||
{
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import { PaperClipOutlined } from '@ant-design/icons'
|
||||
import { imageExts, textExts } from '@renderer/config/constant'
|
||||
import { isVisionModel } from '@renderer/config/models'
|
||||
import { FileType, Model } from '@renderer/types'
|
||||
import { Tooltip } from 'antd'
|
||||
@@ -14,18 +15,13 @@ interface Props {
|
||||
|
||||
const AttachmentButton: FC<Props> = ({ model, files, setFiles, ToolbarButton }) => {
|
||||
const { t } = useTranslation()
|
||||
const extensions = isVisionModel(model) ? [...imageExts, ...textExts] : [...textExts]
|
||||
|
||||
const onSelectFile = async () => {
|
||||
const _files = await window.api.file.select({
|
||||
filters: [{ name: 'Files', extensions: ['jpg', 'png', 'jpeg'] }]
|
||||
})
|
||||
const _files = await window.api.file.select({ filters: [{ name: 'Files', extensions }] })
|
||||
_files && setFiles(_files)
|
||||
}
|
||||
|
||||
if (!isVisionModel(model)) {
|
||||
return null
|
||||
}
|
||||
|
||||
return (
|
||||
<Tooltip placement="top" title={t('chat.input.upload')} arrow>
|
||||
<ToolbarButton type="text" className={files.length ? 'active' : ''} onClick={onSelectFile}>
|
||||
|
||||
@@ -15,7 +15,7 @@ import { useRuntime, useShowTopics } from '@renderer/hooks/useStore'
|
||||
import { getDefaultTopic } from '@renderer/services/assistant'
|
||||
import { EVENT_NAMES, EventEmitter } from '@renderer/services/event'
|
||||
import FileManager from '@renderer/services/file'
|
||||
import { estimateInputTokenCount } from '@renderer/services/messages'
|
||||
import { estimateTextTokens } from '@renderer/services/tokens'
|
||||
import store, { useAppDispatch, useAppSelector } from '@renderer/store'
|
||||
import { setGenerating, setSearching } from '@renderer/store/runtime'
|
||||
import { Assistant, FileType, Message, Topic } from '@renderer/types'
|
||||
@@ -92,7 +92,7 @@ const Inputbar: FC<Props> = ({ assistant, setActiveTopic }) => {
|
||||
setExpend(false)
|
||||
}, [assistant.id, assistant.topics, generating, files, text])
|
||||
|
||||
const inputTokenCount = useMemo(() => estimateInputTokenCount(text), [text])
|
||||
const inputTokenCount = useMemo(() => estimateTextTokens(text), [text])
|
||||
|
||||
const handleKeyDown = (event: React.KeyboardEvent<HTMLTextAreaElement>) => {
|
||||
const isEnterPressed = event.keyCode == 13
|
||||
|
||||
@@ -44,8 +44,8 @@ const TokenCount: FC<Props> = ({ estimateTokenCount, inputTokenCount, contextCou
|
||||
<PicCenterOutlined />
|
||||
</Tooltip>
|
||||
</ToolbarButton>
|
||||
<Container {...props}>
|
||||
<Popover content={PopoverContent} title="" mouseEnterDelay={0.6}>
|
||||
<Container>
|
||||
<Popover content={PopoverContent}>
|
||||
<MenuOutlined /> {contextCount}
|
||||
<Divider type="vertical" style={{ marginTop: 0, marginLeft: 5, marginRight: 5 }} />
|
||||
<ArrowUpOutlined />
|
||||
|
||||
@@ -46,14 +46,12 @@ const MessageItem: FC<Props> = ({ message, index, showMenu, onDeleteMessage }) =
|
||||
const { assistant, setModel } = useAssistant(message.assistantId)
|
||||
const model = useModel(message.modelId)
|
||||
const { userName, showMessageDivider, messageFont, fontSize } = useSettings()
|
||||
const { generating } = useRuntime()
|
||||
const [copied, setCopied] = useState(false)
|
||||
|
||||
const isLastMessage = index === 0
|
||||
const isUserMessage = message.role === 'user'
|
||||
const isAssistantMessage = message.role === 'assistant'
|
||||
const canRegenerate = isLastMessage && isAssistantMessage
|
||||
const showMetadata = Boolean(message.usage) && !generating
|
||||
|
||||
const onCopy = useCallback(() => {
|
||||
navigator.clipboard.writeText(removeTrailingDoubleSpaces(message.content))
|
||||
@@ -133,7 +131,7 @@ const MessageItem: FC<Props> = ({ message, index, showMenu, onDeleteMessage }) =
|
||||
style={{
|
||||
borderRadius: '20%',
|
||||
cursor: 'pointer',
|
||||
border: isLocalAi ? '1px solid var(--color-border)' : ''
|
||||
border: '1px solid var(--color-border)'
|
||||
}}
|
||||
onClick={showMiniApp}>
|
||||
{avatarName}
|
||||
@@ -206,18 +204,39 @@ const MessageItem: FC<Props> = ({ message, index, showMenu, onDeleteMessage }) =
|
||||
)}
|
||||
</MenusBar>
|
||||
)}
|
||||
{showMetadata && (
|
||||
<MessageMetadata>
|
||||
Tokens: {message?.usage?.total_tokens} | ↑{message?.usage?.prompt_tokens} | ↓
|
||||
{message?.usage?.completion_tokens}
|
||||
</MessageMetadata>
|
||||
)}
|
||||
<MessgeTokens message={message} />
|
||||
</MessageFooter>
|
||||
</MessageContentContainer>
|
||||
</MessageContainer>
|
||||
)
|
||||
}
|
||||
|
||||
const MessgeTokens: React.FC<{ message: Message }> = ({ message }) => {
|
||||
const { generating } = useRuntime()
|
||||
|
||||
if (!message.usage) {
|
||||
return null
|
||||
}
|
||||
|
||||
if (message.role === 'user') {
|
||||
return <MessageMetadata>Tokens: {message?.usage?.total_tokens}</MessageMetadata>
|
||||
}
|
||||
|
||||
if (generating) {
|
||||
return null
|
||||
}
|
||||
|
||||
if (message.role === 'assistant') {
|
||||
return (
|
||||
<MessageMetadata>
|
||||
Tokens: {message?.usage?.total_tokens} | ↑{message?.usage?.prompt_tokens} | ↓{message?.usage?.completion_tokens}
|
||||
</MessageMetadata>
|
||||
)
|
||||
}
|
||||
|
||||
return null
|
||||
}
|
||||
|
||||
const MessageContent: React.FC<{ message: Message }> = ({ message }) => {
|
||||
const { t } = useTranslation()
|
||||
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import { Message } from '@renderer/types'
|
||||
import { Image as AntdImage } from 'antd'
|
||||
import { FileTypes, Message } from '@renderer/types'
|
||||
import { getFileDirectory } from '@renderer/utils'
|
||||
import { Image as AntdImage, Upload } from 'antd'
|
||||
import { FC } from 'react'
|
||||
import styled from 'styled-components'
|
||||
|
||||
@@ -8,9 +9,27 @@ interface Props {
|
||||
}
|
||||
|
||||
const MessageAttachments: FC<Props> = ({ message }) => {
|
||||
if (message?.files && message.files[0]?.type === FileTypes.IMAGE) {
|
||||
return (
|
||||
<Container>
|
||||
{message.files?.map((image) => <Image src={'file://' + image.path} key={image.id} width="33%" />)}
|
||||
</Container>
|
||||
)
|
||||
}
|
||||
|
||||
return (
|
||||
<Container>
|
||||
{message.files?.map((image) => <Image src={'file://' + image.path} key={image.id} width="33%" />)}
|
||||
<Container style={{ marginTop: -5 }}>
|
||||
<Upload
|
||||
listType="picture"
|
||||
disabled
|
||||
onPreview={(item) => item.url && window.open(getFileDirectory(item.url))}
|
||||
fileList={message.files?.map((file) => ({
|
||||
uid: file.id,
|
||||
url: 'file://' + file.path,
|
||||
status: 'done',
|
||||
name: file.origin_name
|
||||
}))}
|
||||
/>
|
||||
</Container>
|
||||
)
|
||||
}
|
||||
|
||||
@@ -4,16 +4,12 @@ import { getTopic, TopicManager } from '@renderer/hooks/useTopic'
|
||||
import { fetchChatCompletion, fetchMessagesSummary } from '@renderer/services/api'
|
||||
import { getDefaultTopic } from '@renderer/services/assistant'
|
||||
import { EVENT_NAMES, EventEmitter } from '@renderer/services/event'
|
||||
import {
|
||||
deleteMessageFiles,
|
||||
estimateHistoryTokenCount,
|
||||
filterMessages,
|
||||
getContextCount
|
||||
} from '@renderer/services/messages'
|
||||
import { deleteMessageFiles, filterMessages, getContextCount } from '@renderer/services/messages'
|
||||
import { estimateHistoryTokens, estimateMessageUsage } from '@renderer/services/tokens'
|
||||
import { Assistant, Message, Model, Topic } from '@renderer/types'
|
||||
import { getBriefInfo, runAsyncFunction, uuid } from '@renderer/utils'
|
||||
import { t } from 'i18next'
|
||||
import { last, reverse, take } from 'lodash'
|
||||
import { flatten, last, reverse, take } from 'lodash'
|
||||
import { FC, useCallback, useEffect, useRef, useState } from 'react'
|
||||
import styled from 'styled-components'
|
||||
|
||||
@@ -34,12 +30,15 @@ const Messages: FC<Props> = ({ assistant, topic, setActiveTopic }) => {
|
||||
const { updateTopic, addTopic } = useAssistant(assistant.id)
|
||||
|
||||
const onSendMessage = useCallback(
|
||||
(message: Message) => {
|
||||
async (message: Message) => {
|
||||
if (message.role === 'user') {
|
||||
message.usage = await estimateMessageUsage(message)
|
||||
}
|
||||
const _messages = [...messages, message]
|
||||
setMessages(_messages)
|
||||
db.topics.put({ id: topic.id, messages: _messages })
|
||||
},
|
||||
[messages, topic]
|
||||
[messages, topic.id]
|
||||
)
|
||||
|
||||
const autoRenameTopic = useCallback(async () => {
|
||||
@@ -68,9 +67,14 @@ const Messages: FC<Props> = ({ assistant, topic, setActiveTopic }) => {
|
||||
const unsubscribes = [
|
||||
EventEmitter.on(EVENT_NAMES.SEND_MESSAGE, async (msg: Message) => {
|
||||
onSendMessage(msg)
|
||||
fetchChatCompletion({ assistant, messages: [...messages, msg], topic, onResponse: setLastMessage })
|
||||
fetchChatCompletion({
|
||||
assistant,
|
||||
messages: [...messages, msg],
|
||||
topic,
|
||||
onResponse: setLastMessage
|
||||
})
|
||||
}),
|
||||
EventEmitter.on(EVENT_NAMES.AI_CHAT_COMPLETION, async (msg: Message) => {
|
||||
EventEmitter.on(EVENT_NAMES.RECEIVE_MESSAGE, async (msg: Message) => {
|
||||
setLastMessage(null)
|
||||
onSendMessage(msg)
|
||||
setTimeout(() => EventEmitter.emit(EVENT_NAMES.AI_AUTO_RENAME), 100)
|
||||
@@ -98,6 +102,7 @@ const Messages: FC<Props> = ({ assistant, topic, setActiveTopic }) => {
|
||||
const lastMessage = last(messages)
|
||||
|
||||
if (lastMessage && lastMessage.type === 'clear') {
|
||||
onDeleteMessage(lastMessage)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -117,16 +122,37 @@ const Messages: FC<Props> = ({ assistant, topic, setActiveTopic }) => {
|
||||
} as Message)
|
||||
}),
|
||||
EventEmitter.on(EVENT_NAMES.NEW_BRANCH, async (index: number) => {
|
||||
const _topic = getDefaultTopic()
|
||||
_topic.name = topic.name
|
||||
await db.topics.add({ id: _topic.id, messages: take(messages, messages.length - index) })
|
||||
addTopic(_topic)
|
||||
setActiveTopic(_topic)
|
||||
const newTopic = getDefaultTopic()
|
||||
newTopic.name = topic.name
|
||||
const branchMessages = take(messages, messages.length - index)
|
||||
|
||||
// 将分支的消息放入数据库
|
||||
await db.topics.add({ id: newTopic.id, messages: branchMessages })
|
||||
addTopic(newTopic)
|
||||
setActiveTopic(newTopic)
|
||||
autoRenameTopic()
|
||||
|
||||
// 由于复制了消息,消息中附带的文件的总数变了,需要更新
|
||||
const filesArr = branchMessages.map((m) => m.files)
|
||||
const files = flatten(filesArr).filter(Boolean)
|
||||
files.map(async (f) => {
|
||||
const file = await db.files.get({ id: f?.id })
|
||||
file && db.files.update(file.id, { count: file.count + 1 })
|
||||
})
|
||||
})
|
||||
]
|
||||
return () => unsubscribes.forEach((unsub) => unsub())
|
||||
}, [addTopic, assistant, autoRenameTopic, messages, onSendMessage, setActiveTopic, topic, updateTopic])
|
||||
}, [
|
||||
addTopic,
|
||||
assistant,
|
||||
autoRenameTopic,
|
||||
messages,
|
||||
onDeleteMessage,
|
||||
onSendMessage,
|
||||
setActiveTopic,
|
||||
topic,
|
||||
updateTopic
|
||||
])
|
||||
|
||||
useEffect(() => {
|
||||
runAsyncFunction(async () => {
|
||||
@@ -140,9 +166,11 @@ const Messages: FC<Props> = ({ assistant, topic, setActiveTopic }) => {
|
||||
}, [messages])
|
||||
|
||||
useEffect(() => {
|
||||
EventEmitter.emit(EVENT_NAMES.ESTIMATED_TOKEN_COUNT, {
|
||||
tokensCount: estimateHistoryTokenCount(assistant, messages),
|
||||
contextCount: getContextCount(assistant, messages)
|
||||
runAsyncFunction(async () => {
|
||||
EventEmitter.emit(EVENT_NAMES.ESTIMATED_TOKEN_COUNT, {
|
||||
tokensCount: await estimateHistoryTokens(assistant, messages),
|
||||
contextCount: getContextCount(assistant, messages)
|
||||
})
|
||||
})
|
||||
}, [assistant, messages])
|
||||
|
||||
|
||||
@@ -37,7 +37,7 @@ const Suggestions: FC<Props> = ({ assistant, messages, lastMessage }) => {
|
||||
|
||||
useEffect(() => {
|
||||
const unsubscribes = [
|
||||
EventEmitter.on(EVENT_NAMES.AI_CHAT_COMPLETION, async (msg: Message) => {
|
||||
EventEmitter.on(EVENT_NAMES.RECEIVE_MESSAGE, async (msg: Message) => {
|
||||
setLoadingSuggestions(true)
|
||||
const _suggestions = await fetchSuggestions({ assistant, messages: [...messages, msg] })
|
||||
if (_suggestions.length) {
|
||||
|
||||
@@ -10,12 +10,8 @@ export default class AiProvider {
|
||||
this.sdk = ProviderFactory.create(provider)
|
||||
}
|
||||
|
||||
public async completions(
|
||||
messages: Message[],
|
||||
assistant: Assistant,
|
||||
onChunk: ({ text, usage }: { text?: string; usage?: OpenAI.Completions.CompletionUsage }) => void
|
||||
): Promise<void> {
|
||||
return this.sdk.completions(messages, assistant, onChunk)
|
||||
public async completions({ messages, assistant, onChunk, onFilterMessages }: CompletionsParams): Promise<void> {
|
||||
return this.sdk.completions({ messages, assistant, onChunk, onFilterMessages })
|
||||
}
|
||||
|
||||
public async translate(message: Message, assistant: Assistant): Promise<string> {
|
||||
|
||||
@@ -4,8 +4,8 @@ import { DEFAULT_MAX_TOKENS } from '@renderer/config/constant'
|
||||
import { getAssistantSettings, getDefaultModel, getTopNamingModel } from '@renderer/services/assistant'
|
||||
import { EVENT_NAMES } from '@renderer/services/event'
|
||||
import { filterContextMessages, filterMessages } from '@renderer/services/messages'
|
||||
import { Assistant, Message, Provider, Suggestion } from '@renderer/types'
|
||||
import { first, sum, takeRight } from 'lodash'
|
||||
import { Assistant, FileTypes, Message, Provider, Suggestion } from '@renderer/types'
|
||||
import { first, flatten, sum, takeRight } from 'lodash'
|
||||
import OpenAI from 'openai'
|
||||
|
||||
import BaseProvider from './BaseProvider'
|
||||
@@ -18,49 +18,67 @@ export default class AnthropicProvider extends BaseProvider {
|
||||
this.sdk = new Anthropic({ apiKey: provider.apiKey, baseURL: this.getBaseURL() })
|
||||
}
|
||||
|
||||
private async getMessageContent(message: Message): Promise<MessageParam['content']> {
|
||||
private async getMessageParam(message: Message): Promise<MessageParam[]> {
|
||||
const file = first(message.files)
|
||||
|
||||
if (!file) {
|
||||
return message.content
|
||||
if (file) {
|
||||
if (file.type === FileTypes.IMAGE) {
|
||||
const base64Data = await window.api.file.base64Image(file.id + file.ext)
|
||||
return [
|
||||
{
|
||||
role: message.role,
|
||||
content: [
|
||||
{ type: 'text', text: message.content },
|
||||
{
|
||||
type: 'image',
|
||||
source: {
|
||||
data: base64Data.base64,
|
||||
media_type: base64Data.mime.replace('jpg', 'jpeg') as any,
|
||||
type: 'base64'
|
||||
}
|
||||
}
|
||||
]
|
||||
} as MessageParam
|
||||
]
|
||||
}
|
||||
if (file.type === FileTypes.TEXT) {
|
||||
return [
|
||||
{
|
||||
role: message.role,
|
||||
content: message.content
|
||||
} as MessageParam,
|
||||
{
|
||||
role: 'assistant',
|
||||
content: (await window.api.file.read(file.id + file.ext)).trimEnd()
|
||||
} as MessageParam
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
if (file.type === 'image') {
|
||||
const base64Data = await window.api.image.base64(file.path)
|
||||
return [
|
||||
{ type: 'text', text: message.content },
|
||||
{
|
||||
type: 'image',
|
||||
source: {
|
||||
data: base64Data.base64,
|
||||
media_type: base64Data.mime.replace('jpg', 'jpeg') as any,
|
||||
type: 'base64'
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
return message.content
|
||||
return [
|
||||
{
|
||||
role: message.role,
|
||||
content: message.content
|
||||
} as MessageParam
|
||||
]
|
||||
}
|
||||
|
||||
public async completions(
|
||||
messages: Message[],
|
||||
assistant: Assistant,
|
||||
onChunk: ({ text, usage }: { text?: string; usage?: OpenAI.Completions.CompletionUsage }) => void
|
||||
) {
|
||||
public async completions({ messages, assistant, onChunk, onFilterMessages }: CompletionsParams) {
|
||||
const defaultModel = getDefaultModel()
|
||||
const model = assistant.model || defaultModel
|
||||
const { contextCount, maxTokens } = getAssistantSettings(assistant)
|
||||
|
||||
const userMessages: MessageParam[] = []
|
||||
let userMessagesParams: MessageParam[][] = []
|
||||
const _messages = filterMessages(filterContextMessages(takeRight(messages, contextCount + 2)))
|
||||
|
||||
for (const message of filterMessages(filterContextMessages(takeRight(messages, contextCount + 2)))) {
|
||||
userMessages.push({
|
||||
role: message.role,
|
||||
content: await this.getMessageContent(message)
|
||||
})
|
||||
onFilterMessages(_messages)
|
||||
|
||||
for (const message of _messages) {
|
||||
userMessagesParams = userMessagesParams.concat(await this.getMessageParam(message))
|
||||
}
|
||||
|
||||
const userMessages = flatten(userMessagesParams)
|
||||
|
||||
if (first(userMessages)?.role === 'assistant') {
|
||||
userMessages.shift()
|
||||
}
|
||||
@@ -69,7 +87,7 @@ export default class AnthropicProvider extends BaseProvider {
|
||||
const stream = this.sdk.messages
|
||||
.stream({
|
||||
model: model.id,
|
||||
messages: userMessages.filter(Boolean) as MessageParam[],
|
||||
messages: userMessages,
|
||||
max_tokens: maxTokens || DEFAULT_MAX_TOKENS,
|
||||
temperature: assistant?.settings?.temperature,
|
||||
system: assistant.prompt,
|
||||
|
||||
@@ -20,11 +20,7 @@ export default abstract class BaseProvider {
|
||||
return this.provider.id === 'ollama' ? getOllamaKeepAliveTime() : undefined
|
||||
}
|
||||
|
||||
abstract completions(
|
||||
messages: Message[],
|
||||
assistant: Assistant,
|
||||
onChunk: ({ text, usage }: { text?: string; usage?: OpenAI.Completions.CompletionUsage }) => void
|
||||
): Promise<void>
|
||||
abstract completions({ messages, assistant, onChunk, onFilterMessages }: CompletionsParams): Promise<void>
|
||||
abstract translate(message: Message, assistant: Assistant): Promise<string>
|
||||
abstract summaries(messages: Message[], assistant: Assistant): Promise<string>
|
||||
abstract suggestions(messages: Message[], assistant: Assistant): Promise<Suggestion[]>
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
import { Content, GoogleGenerativeAI, InlineDataPart, Part } from '@google/generative-ai'
|
||||
import { Content, GoogleGenerativeAI, InlineDataPart, TextPart } from '@google/generative-ai'
|
||||
import { getAssistantSettings, getDefaultModel, getTopNamingModel } from '@renderer/services/assistant'
|
||||
import { EVENT_NAMES } from '@renderer/services/event'
|
||||
import { filterContextMessages, filterMessages } from '@renderer/services/messages'
|
||||
import { Assistant, Message, Provider, Suggestion } from '@renderer/types'
|
||||
import { Assistant, FileTypes, Message, Provider, Suggestion } from '@renderer/types'
|
||||
import axios from 'axios'
|
||||
import { first, isEmpty, takeRight } from 'lodash'
|
||||
import { first, flatten, isEmpty, takeRight } from 'lodash'
|
||||
import OpenAI from 'openai'
|
||||
|
||||
import BaseProvider from './BaseProvider'
|
||||
@@ -17,42 +17,67 @@ export default class GeminiProvider extends BaseProvider {
|
||||
this.sdk = new GoogleGenerativeAI(provider.apiKey)
|
||||
}
|
||||
|
||||
private async getMessageParts(message: Message): Promise<Part[]> {
|
||||
private async getMessageContents(message: Message): Promise<Content[]> {
|
||||
const file = first(message.files)
|
||||
const role = message.role === 'user' ? 'user' : 'model'
|
||||
|
||||
if (file && file.type === 'image') {
|
||||
const base64Data = await window.api.image.base64(file.path)
|
||||
return [
|
||||
{
|
||||
text: message.content
|
||||
},
|
||||
{
|
||||
inlineData: {
|
||||
data: base64Data.base64,
|
||||
mimeType: base64Data.mime
|
||||
if (file) {
|
||||
if (file.type === FileTypes.IMAGE) {
|
||||
const base64Data = await window.api.file.base64Image(file.id + file.ext)
|
||||
return [
|
||||
{
|
||||
role: message.role,
|
||||
parts: [
|
||||
{ text: message.content } as TextPart,
|
||||
{
|
||||
inlineData: {
|
||||
data: base64Data.base64,
|
||||
mimeType: base64Data.mime
|
||||
}
|
||||
} as InlineDataPart
|
||||
]
|
||||
}
|
||||
} as InlineDataPart
|
||||
]
|
||||
]
|
||||
}
|
||||
if (file.type === FileTypes.TEXT) {
|
||||
return [
|
||||
{
|
||||
role: 'model',
|
||||
parts: [{ text: await window.api.file.read(file.id + file.ext) } as TextPart]
|
||||
},
|
||||
{
|
||||
role,
|
||||
parts: [{ text: message.content } as TextPart]
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
return [{ text: message.content }]
|
||||
return [
|
||||
{
|
||||
role,
|
||||
parts: [{ text: message.content } as TextPart]
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
public async completions(
|
||||
messages: Message[],
|
||||
assistant: Assistant,
|
||||
onChunk: ({ text, usage }: { text?: string; usage?: OpenAI.Completions.CompletionUsage }) => void
|
||||
) {
|
||||
public async completions({ messages, assistant, onChunk, onFilterMessages }: CompletionsParams) {
|
||||
const defaultModel = getDefaultModel()
|
||||
const model = assistant.model || defaultModel
|
||||
const { contextCount, maxTokens } = getAssistantSettings(assistant)
|
||||
|
||||
const userMessages = filterMessages(filterContextMessages(takeRight(messages, contextCount + 1))).map((message) => {
|
||||
return {
|
||||
role: message.role,
|
||||
message
|
||||
}
|
||||
})
|
||||
const userMessages = filterMessages(filterContextMessages(takeRight(messages, contextCount + 1)))
|
||||
onFilterMessages(userMessages)
|
||||
|
||||
const userLastMessage = userMessages.pop()
|
||||
|
||||
let historyContents: Content[][] = []
|
||||
|
||||
for (const message of userMessages) {
|
||||
historyContents = historyContents.concat(await this.getMessageContents(message))
|
||||
}
|
||||
|
||||
const history = flatten(historyContents)
|
||||
|
||||
const geminiModel = this.sdk.getGenerativeModel({
|
||||
model: model.id,
|
||||
@@ -63,21 +88,9 @@ export default class GeminiProvider extends BaseProvider {
|
||||
}
|
||||
})
|
||||
|
||||
const userLastMessage = userMessages.pop()
|
||||
|
||||
const history: Content[] = []
|
||||
|
||||
for (const message of userMessages) {
|
||||
history.push({
|
||||
role: message.role === 'user' ? 'user' : 'model',
|
||||
parts: await this.getMessageParts(message.message)
|
||||
})
|
||||
}
|
||||
|
||||
const chat = geminiModel.startChat({ history })
|
||||
const message = await this.getMessageParts(userLastMessage?.message!)
|
||||
|
||||
const userMessagesStream = await chat.sendMessageStream(message)
|
||||
const messageContents = await this.getMessageContents(userLastMessage!)
|
||||
const userMessagesStream = await chat.sendMessageStream(messageContents[0].parts)
|
||||
|
||||
for await (const chunk of userMessagesStream.stream) {
|
||||
if (window.keyv.get(EVENT_NAMES.CHAT_COMPLETION_PAUSED)) break
|
||||
|
||||
@@ -2,7 +2,7 @@ import { isLocalAi } from '@renderer/config/env'
|
||||
import { getAssistantSettings, getDefaultModel, getTopNamingModel } from '@renderer/services/assistant'
|
||||
import { EVENT_NAMES } from '@renderer/services/event'
|
||||
import { filterContextMessages, filterMessages } from '@renderer/services/messages'
|
||||
import { Assistant, Message, Provider, Suggestion } from '@renderer/types'
|
||||
import { Assistant, FileTypes, Message, Provider, Suggestion } from '@renderer/types'
|
||||
import { removeQuotes } from '@renderer/utils'
|
||||
import { first, takeRight } from 'lodash'
|
||||
import OpenAI from 'openai'
|
||||
@@ -26,61 +26,92 @@ export default class OpenAIProvider extends BaseProvider {
|
||||
})
|
||||
}
|
||||
|
||||
private async getMessageContent(message: Message): Promise<string | ChatCompletionContentPart[]> {
|
||||
const file = first(message.files)
|
||||
|
||||
if (!file) {
|
||||
return message.content
|
||||
private isSupportStreamOutput(modelId: string): boolean {
|
||||
if (this.provider.id === 'openai' && modelId.includes('o1-')) {
|
||||
return false
|
||||
}
|
||||
|
||||
if (file.type === 'image') {
|
||||
const base64Data = await window.api.image.base64(file.path)
|
||||
return [
|
||||
{ type: 'text', text: message.content },
|
||||
{
|
||||
type: 'image_url',
|
||||
image_url: {
|
||||
url: base64Data.data
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
return message.content
|
||||
return true
|
||||
}
|
||||
|
||||
async completions(
|
||||
messages: Message[],
|
||||
assistant: Assistant,
|
||||
onChunk: ({ text, usage }: { text?: string; usage?: OpenAI.Completions.CompletionUsage }) => void
|
||||
): Promise<void> {
|
||||
private async getMessageParam(message: Message): Promise<OpenAI.Chat.Completions.ChatCompletionMessageParam[]> {
|
||||
const file = first(message.files)
|
||||
|
||||
const content: string | ChatCompletionContentPart[] = message.content
|
||||
|
||||
if (file) {
|
||||
if (file.type === FileTypes.IMAGE) {
|
||||
const image = await window.api.file.base64Image(file.id + file.ext)
|
||||
return [
|
||||
{
|
||||
role: message.role,
|
||||
content: [
|
||||
{ type: 'text', text: message.content },
|
||||
{
|
||||
type: 'image_url',
|
||||
image_url: {
|
||||
url: image.data
|
||||
}
|
||||
}
|
||||
]
|
||||
} as ChatCompletionMessageParam
|
||||
]
|
||||
}
|
||||
if (file.type === FileTypes.TEXT) {
|
||||
return [
|
||||
{
|
||||
role: 'assistant',
|
||||
content: await window.api.file.read(file.id + file.ext)
|
||||
} as ChatCompletionMessageParam,
|
||||
{
|
||||
role: message.role,
|
||||
content
|
||||
} as ChatCompletionMessageParam
|
||||
]
|
||||
}
|
||||
}
|
||||
|
||||
return [
|
||||
{
|
||||
role: message.role,
|
||||
content
|
||||
} as ChatCompletionMessageParam
|
||||
]
|
||||
}
|
||||
|
||||
async completions({ messages, assistant, onChunk, onFilterMessages }: CompletionsParams): Promise<void> {
|
||||
const defaultModel = getDefaultModel()
|
||||
const model = assistant.model || defaultModel
|
||||
const { contextCount, maxTokens } = getAssistantSettings(assistant)
|
||||
|
||||
const systemMessage = assistant.prompt ? { role: 'system', content: assistant.prompt } : undefined
|
||||
const userMessages: ChatCompletionMessageParam[] = []
|
||||
let userMessages: ChatCompletionMessageParam[] = []
|
||||
|
||||
for (const message of filterMessages(filterContextMessages(takeRight(messages, contextCount + 1)))) {
|
||||
userMessages.push({
|
||||
role: message.role,
|
||||
content: await this.getMessageContent(message)
|
||||
} as ChatCompletionMessageParam)
|
||||
const _messages = filterMessages(filterContextMessages(takeRight(messages, contextCount + 1)))
|
||||
onFilterMessages(_messages)
|
||||
|
||||
for (const message of _messages) {
|
||||
userMessages = userMessages.concat(await this.getMessageParam(message))
|
||||
}
|
||||
|
||||
// @ts-ignore key is not typed
|
||||
const stream = await this.sdk.chat.completions.create({
|
||||
model: model.id,
|
||||
messages: [systemMessage, ...userMessages].filter(Boolean) as ChatCompletionMessageParam[],
|
||||
stream: true,
|
||||
stream: this.isSupportStreamOutput(model.id),
|
||||
temperature: assistant?.settings?.temperature,
|
||||
max_tokens: maxTokens,
|
||||
keep_alive: this.keepAliveTime
|
||||
})
|
||||
|
||||
for await (const chunk of stream) {
|
||||
if (window.keyv.get(EVENT_NAMES.CHAT_COMPLETION_PAUSED)) break
|
||||
onChunk({ text: chunk.choices[0]?.delta?.content || '', usage: chunk.usage })
|
||||
if (window.keyv.get(EVENT_NAMES.CHAT_COMPLETION_PAUSED)) {
|
||||
break
|
||||
}
|
||||
|
||||
onChunk({
|
||||
text: chunk.choices[0]?.delta?.content || '',
|
||||
usage: chunk.usage
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
+11
@@ -0,0 +1,11 @@
|
||||
interface ChunkCallbackData {
|
||||
text?: string
|
||||
usage?: OpenAI.Completions.CompletionUsage
|
||||
}
|
||||
|
||||
interface CompletionsParams {
|
||||
messages: Message[]
|
||||
assistant: Assistant
|
||||
onChunk: ({ text, usage }: ChunkCallbackData) => void
|
||||
onFilterMessages: (messages: Message[]) => void
|
||||
}
|
||||
@@ -15,7 +15,8 @@ import {
|
||||
getTranslateModel
|
||||
} from './assistant'
|
||||
import { EVENT_NAMES, EventEmitter } from './event'
|
||||
import { estimateMessagesToken, filterMessages } from './messages'
|
||||
import { filterMessages } from './messages'
|
||||
import { estimateMessagesUsage } from './tokens'
|
||||
|
||||
export async function fetchChatCompletion({
|
||||
messages,
|
||||
@@ -61,13 +62,27 @@ export async function fetchChatCompletion({
|
||||
}, 1000)
|
||||
|
||||
try {
|
||||
await AI.completions(messages, assistant, ({ text, usage }) => {
|
||||
message.content = message.content + text || ''
|
||||
message.usage = usage
|
||||
onResponse({ ...message, status: 'pending' })
|
||||
let _messages: Message[] = []
|
||||
|
||||
await AI.completions({
|
||||
messages,
|
||||
assistant,
|
||||
onFilterMessages: (messages) => (_messages = messages),
|
||||
onChunk: ({ text, usage }) => {
|
||||
message.content = message.content + text || ''
|
||||
message.usage = usage
|
||||
onResponse({ ...message, status: 'pending' })
|
||||
}
|
||||
})
|
||||
|
||||
message.status = 'success'
|
||||
message.usage = message.usage || (await estimateMessagesToken({ assistant, messages: [...messages, message] }))
|
||||
|
||||
if (!message.usage) {
|
||||
message.usage = await estimateMessagesUsage({
|
||||
assistant,
|
||||
messages: [..._messages, message]
|
||||
})
|
||||
}
|
||||
} catch (error: any) {
|
||||
message.content = `Error: ${error.message}`
|
||||
message.status = 'error'
|
||||
@@ -83,7 +98,7 @@ export async function fetchChatCompletion({
|
||||
message.status = window.keyv.get(EVENT_NAMES.CHAT_COMPLETION_PAUSED) ? 'paused' : message.status
|
||||
|
||||
// Emit chat completion event
|
||||
EventEmitter.emit(EVENT_NAMES.AI_CHAT_COMPLETION, message)
|
||||
EventEmitter.emit(EVENT_NAMES.RECEIVE_MESSAGE, message)
|
||||
|
||||
// Reset generating state
|
||||
store.dispatch(setGenerating(false))
|
||||
|
||||
@@ -4,7 +4,7 @@ export const EventEmitter = new Emittery()
|
||||
|
||||
export const EVENT_NAMES = {
|
||||
SEND_MESSAGE: 'SEND_MESSAGE',
|
||||
AI_CHAT_COMPLETION: 'AI_CHAT_COMPLETION',
|
||||
RECEIVE_MESSAGE: 'RECEIVE_MESSAGE',
|
||||
AI_AUTO_RENAME: 'AI_AUTO_RENAME',
|
||||
CLEAR_MESSAGES: 'CLEAR_MESSAGES',
|
||||
ADD_ASSISTANT: 'ADD_ASSISTANT',
|
||||
|
||||
@@ -1,10 +1,7 @@
|
||||
import { DEFAULT_CONEXTCOUNT } from '@renderer/config/constant'
|
||||
import { Assistant, Message } from '@renderer/types'
|
||||
import { GPTTokens } from 'gpt-tokens'
|
||||
import { isEmpty, last, takeRight } from 'lodash'
|
||||
import { CompletionUsage } from 'openai/resources'
|
||||
import { isEmpty, takeRight } from 'lodash'
|
||||
|
||||
import { getAssistantSettings } from './assistant'
|
||||
import FileManager from './file'
|
||||
|
||||
export const filterMessages = (messages: Message[]) => {
|
||||
@@ -36,50 +33,6 @@ export function getContextCount(assistant: Assistant, messages: Message[]) {
|
||||
return messagesCount - (clearIndex + 1)
|
||||
}
|
||||
|
||||
export function estimateInputTokenCount(text: string) {
|
||||
const input = new GPTTokens({
|
||||
model: 'gpt-4o',
|
||||
messages: [{ role: 'user', content: text }]
|
||||
})
|
||||
|
||||
return input.usedTokens - 7
|
||||
}
|
||||
|
||||
export async function estimateMessagesToken({
|
||||
assistant,
|
||||
messages
|
||||
}: {
|
||||
assistant: Assistant
|
||||
messages: Message[]
|
||||
}): Promise<CompletionUsage> {
|
||||
const responseMessageContent = last(messages)?.content
|
||||
const inputMessageContent = messages[messages.length - 2]?.content
|
||||
const completion_tokens = await estimateInputTokenCount(responseMessageContent ?? '')
|
||||
const prompt_tokens = await estimateInputTokenCount(assistant.prompt + inputMessageContent ?? '')
|
||||
return {
|
||||
completion_tokens,
|
||||
prompt_tokens: prompt_tokens,
|
||||
total_tokens: prompt_tokens + completion_tokens
|
||||
} as CompletionUsage
|
||||
}
|
||||
|
||||
export function estimateHistoryTokenCount(assistant: Assistant, msgs: Message[]) {
|
||||
const { contextCount } = getAssistantSettings(assistant)
|
||||
|
||||
const all = new GPTTokens({
|
||||
model: 'gpt-4o',
|
||||
messages: [
|
||||
{ role: 'system', content: assistant.prompt },
|
||||
...filterMessages(filterContextMessages(takeRight(msgs, contextCount))).map((message) => ({
|
||||
role: message.role,
|
||||
content: message.content
|
||||
}))
|
||||
]
|
||||
})
|
||||
|
||||
return all.usedTokens - 7
|
||||
}
|
||||
|
||||
export function deleteMessageFiles(message: Message) {
|
||||
message.files && FileManager.deleteFiles(message.files.map((f) => f.id))
|
||||
}
|
||||
|
||||
@@ -0,0 +1,130 @@
|
||||
import { Assistant, FileType, FileTypes, Message } from '@renderer/types'
|
||||
import { GPTTokens } from 'gpt-tokens'
|
||||
import { flatten, takeRight } from 'lodash'
|
||||
import { CompletionUsage } from 'openai/resources'
|
||||
|
||||
import { getAssistantSettings } from './assistant'
|
||||
import { filterContextMessages, filterMessages } from './messages'
|
||||
|
||||
interface MessageItem {
|
||||
name?: string
|
||||
role: 'system' | 'user' | 'assistant'
|
||||
content: string
|
||||
}
|
||||
|
||||
async function getFileContent(file: FileType) {
|
||||
if (!file) {
|
||||
return ''
|
||||
}
|
||||
|
||||
const fileId = file.id + file.ext
|
||||
|
||||
if (file.type === FileTypes.IMAGE) {
|
||||
const data = await window.api.file.base64Image(fileId)
|
||||
return data.data
|
||||
}
|
||||
|
||||
if (file.type === FileTypes.TEXT) {
|
||||
return await window.api.file.read(fileId)
|
||||
}
|
||||
|
||||
return ''
|
||||
}
|
||||
|
||||
async function getMessageParam(message: Message): Promise<MessageItem[]> {
|
||||
const param: MessageItem[] = []
|
||||
|
||||
param.push({
|
||||
role: message.role,
|
||||
content: message.content
|
||||
})
|
||||
|
||||
if (message.files) {
|
||||
for (const file of message.files) {
|
||||
param.push({
|
||||
role: 'assistant',
|
||||
content: await getFileContent(file)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return param
|
||||
}
|
||||
|
||||
export function estimateTextTokens(text: string) {
|
||||
const { usedTokens } = new GPTTokens({
|
||||
model: 'gpt-4o',
|
||||
messages: [{ role: 'user', content: text }]
|
||||
})
|
||||
|
||||
return usedTokens - 7
|
||||
}
|
||||
|
||||
export async function estimateMessageUsage(message: Message): Promise<CompletionUsage> {
|
||||
const { usedTokens, promptUsedTokens, completionUsedTokens } = new GPTTokens({
|
||||
model: 'gpt-4o',
|
||||
messages: await getMessageParam(message)
|
||||
})
|
||||
|
||||
const hasImage = message.files?.some((f) => f.type === FileTypes.IMAGE)
|
||||
|
||||
return {
|
||||
prompt_tokens: promptUsedTokens,
|
||||
completion_tokens: completionUsedTokens,
|
||||
total_tokens: hasImage ? Math.floor(usedTokens / 80) : usedTokens - 7
|
||||
}
|
||||
}
|
||||
|
||||
export async function estimateMessagesUsage({
|
||||
assistant,
|
||||
messages
|
||||
}: {
|
||||
assistant: Assistant
|
||||
messages: Message[]
|
||||
}): Promise<CompletionUsage> {
|
||||
const outputMessage = messages.pop()!
|
||||
|
||||
const prompt_tokens = await estimateHistoryTokens(assistant, messages)
|
||||
const { completion_tokens } = await estimateMessageUsage(outputMessage)
|
||||
|
||||
return {
|
||||
prompt_tokens: await estimateHistoryTokens(assistant, messages),
|
||||
completion_tokens,
|
||||
total_tokens: prompt_tokens + completion_tokens
|
||||
} as CompletionUsage
|
||||
}
|
||||
|
||||
export async function estimateHistoryTokens(assistant: Assistant, msgs: Message[]) {
|
||||
const { contextCount } = getAssistantSettings(assistant)
|
||||
const messages = filterMessages(filterContextMessages(takeRight(msgs, contextCount)))
|
||||
|
||||
// 有 usage 数据的消息,快速计算总数
|
||||
const uasageTokens = messages
|
||||
.filter((m) => m.usage)
|
||||
.reduce((acc, message) => {
|
||||
const inputTokens = message.usage?.total_tokens ?? 0
|
||||
const outputTokens = message.usage!.completion_tokens ?? 0
|
||||
return acc + (message.role === 'user' ? inputTokens : outputTokens)
|
||||
}, 0)
|
||||
|
||||
// 没有 usage 数据的消息,需要计算每条消息的 token
|
||||
let allMessages: MessageItem[][] = []
|
||||
|
||||
for (const message of messages.filter((m) => !m.usage)) {
|
||||
const items = await getMessageParam(message)
|
||||
allMessages = allMessages.concat(items)
|
||||
}
|
||||
|
||||
const { usedTokens } = new GPTTokens({
|
||||
model: 'gpt-4o',
|
||||
messages: [
|
||||
{
|
||||
role: 'system',
|
||||
content: assistant.prompt
|
||||
},
|
||||
...flatten(allMessages)
|
||||
]
|
||||
})
|
||||
|
||||
return usedTokens - 7 + uasageTokens
|
||||
}
|
||||
@@ -97,12 +97,14 @@ export interface FileType {
|
||||
type: FileTypes
|
||||
created_at: Date
|
||||
count: number
|
||||
tokens?: number
|
||||
}
|
||||
|
||||
export enum FileTypes {
|
||||
IMAGE = 'image',
|
||||
VIDEO = 'video',
|
||||
AUDIO = 'audio',
|
||||
TEXT = 'text',
|
||||
DOCUMENT = 'document',
|
||||
OTHER = 'other'
|
||||
}
|
||||
|
||||
@@ -229,3 +229,9 @@ export function removeTrailingDoubleSpaces(markdown: string): string {
|
||||
// 使用正则表达式匹配末尾的两个空格,并替换为空字符串
|
||||
return markdown.replace(/ {2}$/gm, '')
|
||||
}
|
||||
|
||||
export function getFileDirectory(filePath: string) {
|
||||
const parts = filePath.split('/')
|
||||
const directory = parts.slice(0, -1).join('/')
|
||||
return directory
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user