feat: add file attachment

This commit is contained in:
kangfenmao
2024-09-18 18:00:49 +08:00
parent 6f5dccd595
commit 6e7e5cb1f1
32 changed files with 825 additions and 273 deletions
+7
View File
@@ -1,5 +1,6 @@
import { electronApp, optimizer } from '@electron-toolkit/utils'
import { app, BrowserWindow } from 'electron'
import installExtension, { REDUX_DEVTOOLS } from 'electron-devtools-installer'
import { registerIpc } from './ipc'
import { updateUserDataPath } from './utils/upgrade'
@@ -30,6 +31,12 @@ app.whenReady().then(async () => {
const mainWindow = createMainWindow()
registerIpc(mainWindow, app)
if (process.env.NODE_ENV === 'development') {
installExtension(REDUX_DEVTOOLS)
.then((name) => console.log(`Added Extension: ${name}`))
.catch((err) => console.log('An error occurred: ', err))
}
})
// Quit when all windows are closed, except on macOS. There, it's common
+3 -23
View File
@@ -1,8 +1,5 @@
import { FileType } from '@types'
import { BrowserWindow, ipcMain, OpenDialogOptions, session, shell } from 'electron'
import Logger from 'electron-log'
import fs from 'fs'
import path from 'path'
import { appConfig, titleBarOverlayDark, titleBarOverlayLight } from './config'
import AppUpdater from './services/AppUpdater'
@@ -38,29 +35,12 @@ export function registerIpc(mainWindow: BrowserWindow, app: Electron.App) {
ipcMain.handle('zip:compress', (_, text: string) => compress(text))
ipcMain.handle('zip:decompress', (_, text: Buffer) => decompress(text))
ipcMain.handle('image:base64', async (_, filePath) => {
try {
const data = await fs.promises.readFile(filePath)
const base64 = data.toString('base64')
const mime = `image/${path.extname(filePath).slice(1)}`
return {
mime,
base64,
data: `data:${mime};base64,${base64}`
}
} catch (error) {
Logger.error('Error reading file:', error)
return ''
}
})
ipcMain.handle('file:base64Image', async (_, id) => await fileManager.base64Image(id))
ipcMain.handle('file:select', async (_, options?: OpenDialogOptions) => await fileManager.selectFile(options))
ipcMain.handle('file:upload', async (_, file: FileType) => await fileManager.uploadFile(file))
ipcMain.handle('file:clear', async () => await fileManager.clear())
ipcMain.handle('file:delete', async (_, fileId: string) => {
await fileManager.deleteFile(fileId)
return { success: true }
})
ipcMain.handle('file:read', async (_, id: string) => await fileManager.readFile(id))
ipcMain.handle('file:delete', async (_, id: string) => await fileManager.deleteFile(id))
ipcMain.handle('minapp', (_, args) => {
createMinappWindow({
+17
View File
@@ -135,6 +135,23 @@ class File {
await fs.promises.unlink(path.join(this.storageDir, id))
}
async readFile(id: string): Promise<string> {
const filePath = path.join(this.storageDir, id)
return fs.readFileSync(filePath, 'utf8')
}
async base64Image(id: string): Promise<{ mime: string; base64: string; data: string }> {
const filePath = path.join(this.storageDir, id)
const data = await fs.promises.readFile(filePath)
const base64 = data.toString('base64')
const mime = `image/${path.extname(filePath).slice(1)}`
return {
mime,
base64,
data: `data:${mime};base64,${base64}`
}
}
async clear(): Promise<void> {
await fs.promises.rmdir(this.storageDir, { recursive: true })
await this.initStorageDir()
+92 -1
View File
@@ -56,12 +56,103 @@ export function getFileType(ext: string): FileTypes {
const imageExts = ['.jpg', '.jpeg', '.png', '.gif', '.bmp', '.webp']
const videoExts = ['.mp4', '.avi', '.mov', '.wmv', '.flv', '.mkv']
const audioExts = ['.mp3', '.wav', '.ogg', '.flac', '.aac']
const documentExts = ['.pdf', '.doc', '.docx', '.xls', '.xlsx', '.ppt', '.pptx', '.txt']
const documentExts = ['.pdf', '.doc', '.docx', '.xls', '.xlsx', '.ppt', '.pptx']
const textExts = [
'.txt', // 普通文本文件
'.md', // Markdown 文件
'.mdx', // Markdown 文件
'.html', // HTML 文件
'.htm', // HTML 文件的另一种扩展名
'.xml', // XML 文件
'.json', // JSON 文件
'.yaml', // YAML 文件
'.yml', // YAML 文件的另一种扩展名
'.csv', // 逗号分隔值文件
'.tsv', // 制表符分隔值文件
'.ini', // 配置文件
'.log', // 日志文件
'.rtf', // 富文本格式文件
'.tex', // LaTeX 文件
'.srt', // 字幕文件
'.xhtml', // XHTML 文件
'.nfo', // 信息文件(主要用于场景发布)
'.conf', // 配置文件
'.config', // 配置文件
'.env', // 环境变量文件
'.properties', // 配置属性文件
'.latex', // LaTeX 文档文件
'.rst', // reStructuredText 文件
'.php', // PHP 脚本文件,包含嵌入的 HTML
'.js', // JavaScript 文件(部分是文本,部分可能包含代码)
'.ts', // TypeScript 文件
'.jsp', // JavaServer Pages 文件
'.aspx', // ASP.NET 文件
'.bat', // Windows 批处理文件
'.sh', // Unix/Linux Shell 脚本文件
'.py', // Python 脚本文件
'.rb', // Ruby 脚本文件
'.pl', // Perl 脚本文件
'.sql', // SQL 脚本文件
'.css', // Cascading Style Sheets 文件
'.less', // Less CSS 预处理器文件
'.scss', // Sass CSS 预处理器文件
'.sass', // Sass 文件
'.styl', // Stylus CSS 预处理器文件
'.coffee', // CoffeeScript 文件
'.ino', // Arduino 代码文件
'.ino', // Arduino 代码文件
'.asm', // Assembly 语言文件
'.go', // Go 语言文件
'.scala', // Scala 语言文件
'.swift', // Swift 语言文件
'.kt', // Kotlin 语言文件
'.rs', // Rust 语言文件
'.lua', // Lua 语言文件
'.groovy', // Groovy 语言文件
'.dart', // Dart 语言文件
'.hs', // Haskell 语言文件
'.clj', // Clojure 语言文件
'.cljs', // ClojureScript 语言文件
'.elm', // Elm 语言文件
'.erl', // Erlang 语言文件
'.ex', // Elixir 语言文件
'.exs', // Elixir 脚本文件
'.pug', // Pug (formerly Jade) 模板文件
'.haml', // Haml 模板文件
'.slim', // Slim 模板文件
'.tpl', // 模板文件(通用)
'.ejs', // Embedded JavaScript 模板文件
'.hbs', // Handlebars 模板文件
'.mustache', // Mustache 模板文件
'.jade', // Jade 模板文件 (已重命名为 Pug)
'.twig', // Twig 模板文件
'.blade', // Blade 模板文件 (Laravel)
'.vue', // Vue.js 单文件组件
'.jsx', // React JSX 文件
'.tsx', // React TSX 文件
'.graphql', // GraphQL 查询语言文件
'.gql', // GraphQL 查询语言文件
'.proto', // Protocol Buffers 文件
'.thrift', // Thrift 文件
'.toml', // TOML 配置文件
'.edn', // Clojure 数据表示文件
'.cake', // CakePHP 配置文件
'.ctp', // CakePHP 视图文件
'.cfm', // ColdFusion 标记语言文件
'.cfc', // ColdFusion 组件文件
'.m', // Objective-C 源文件
'.mm', // Objective-C++ 源文件
'.gradle', // Gradle 构建文件
'.groovy', // Gradle 构建文件
'.gradle', // Gradle 构建文件
'.kts' // Kotlin Script 文件
]
ext = ext.toLowerCase()
if (imageExts.includes(ext)) return FileTypes.IMAGE
if (videoExts.includes(ext)) return FileTypes.VIDEO
if (audioExts.includes(ext)) return FileTypes.AUDIO
if (textExts.includes(ext)) return FileTypes.TEXT
if (documentExts.includes(ext)) return FileTypes.DOCUMENT
return FileTypes.OTHER
}
+3 -4
View File
@@ -24,12 +24,11 @@ declare global {
file: {
select: (options?: OpenDialogOptions) => Promise<FileType[] | null>
upload: (file: FileType) => Promise<FileType>
delete: (fileId: string) => Promise<{ success: boolean }>
delete: (fileId: string) => Promise<void>
read: (fileId: string) => Promise<string>
base64Image: (fileId: string) => Promise<{ mime: string; base64: string; data: string }>
clear: () => Promise<void>
}
image: {
base64: (filePath: string) => Promise<{ mime: string; base64: string; data: string }>
}
}
}
}
+2 -3
View File
@@ -20,10 +20,9 @@ const api = {
select: (options?: OpenDialogOptions) => ipcRenderer.invoke('file:select', options),
upload: (filePath: string) => ipcRenderer.invoke('file:upload', filePath),
delete: (fileId: string) => ipcRenderer.invoke('file:delete', fileId),
read: (fileId: string) => ipcRenderer.invoke('file:read', fileId),
base64Image: (fileId: string) => ipcRenderer.invoke('file:base64Image', fileId),
clear: () => ipcRenderer.invoke('file:clear')
},
image: {
base64: (filePath: string) => ipcRenderer.invoke('image:base64', filePath)
}
}
Binary file not shown.

After

Width:  |  Height:  |  Size: 1.4 KiB

+92
View File
@@ -7,3 +7,95 @@ export const platform = window.electron?.process?.platform
export const isMac = platform === 'darwin'
export const isWindows = platform === 'win32' || platform === 'win64'
export const isLinux = platform === 'linux'
export const imageExts = ['jpg', 'png', 'jpeg']
export const textExts = [
'.txt', // 普通文本文件
'.md', // Markdown 文件
'.mdx', // Markdown 文件
'.html', // HTML 文件
'.htm', // HTML 文件的另一种扩展名
'.xml', // XML 文件
'.json', // JSON 文件
'.yaml', // YAML 文件
'.yml', // YAML 文件的另一种扩展名
'.csv', // 逗号分隔值文件
'.tsv', // 制表符分隔值文件
'.ini', // 配置文件
'.log', // 日志文件
'.rtf', // 富文本格式文件
'.tex', // LaTeX 文件
'.srt', // 字幕文件
'.xhtml', // XHTML 文件
'.nfo', // 信息文件(主要用于场景发布)
'.conf', // 配置文件
'.config', // 配置文件
'.env', // 环境变量文件
'.properties', // 配置属性文件
'.latex', // LaTeX 文档文件
'.rst', // reStructuredText 文件
'.php', // PHP 脚本文件,包含嵌入的 HTML
'.js', // JavaScript 文件(部分是文本,部分可能包含代码)
'.ts', // TypeScript 文件
'.jsp', // JavaServer Pages 文件
'.aspx', // ASP.NET 文件
'.bat', // Windows 批处理文件
'.sh', // Unix/Linux Shell 脚本文件
'.py', // Python 脚本文件
'.rb', // Ruby 脚本文件
'.pl', // Perl 脚本文件
'.sql', // SQL 脚本文件
'.css', // Cascading Style Sheets 文件
'.less', // Less CSS 预处理器文件
'.scss', // Sass CSS 预处理器文件
'.sass', // Sass 文件
'.styl', // Stylus CSS 预处理器文件
'.coffee', // CoffeeScript 文件
'.ino', // Arduino 代码文件
'.ino', // Arduino 代码文件
'.asm', // Assembly 语言文件
'.go', // Go 语言文件
'.scala', // Scala 语言文件
'.swift', // Swift 语言文件
'.kt', // Kotlin 语言文件
'.rs', // Rust 语言文件
'.lua', // Lua 语言文件
'.groovy', // Groovy 语言文件
'.dart', // Dart 语言文件
'.hs', // Haskell 语言文件
'.clj', // Clojure 语言文件
'.cljs', // ClojureScript 语言文件
'.elm', // Elm 语言文件
'.erl', // Erlang 语言文件
'.ex', // Elixir 语言文件
'.exs', // Elixir 脚本文件
'.pug', // Pug (formerly Jade) 模板文件
'.haml', // Haml 模板文件
'.slim', // Slim 模板文件
'.tpl', // 模板文件(通用)
'.ejs', // Embedded JavaScript 模板文件
'.hbs', // Handlebars 模板文件
'.mustache', // Mustache 模板文件
'.jade', // Jade 模板文件 (已重命名为 Pug)
'.twig', // Twig 模板文件
'.blade', // Blade 模板文件 (Laravel)
'.vue', // Vue.js 单文件组件
'.jsx', // React JSX 文件
'.tsx', // React TSX 文件
'.graphql', // GraphQL 查询语言文件
'.gql', // GraphQL 查询语言文件
'.proto', // Protocol Buffers 文件
'.thrift', // Thrift 文件
'.toml', // TOML 配置文件
'.edn', // Clojure 数据表示文件
'.cake', // CakePHP 配置文件
'.ctp', // CakePHP 视图文件
'.cfm', // ColdFusion 标记语言文件
'.cfc', // ColdFusion 组件文件
'.m', // Objective-C 源文件
'.mm', // Objective-C++ 源文件
'.gradle', // Gradle 构建文件
'.groovy', // Gradle 构建文件
'.gradle', // Gradle 构建文件
'.kts' // Kotlin Script 文件
]
+4 -1
View File
@@ -14,6 +14,7 @@ import GemmaModelLogo from '@renderer/assets/images/models/gemma.jpeg'
import HailuoModelLogo from '@renderer/assets/images/models/hailuo.png'
import LlamaModelLogo from '@renderer/assets/images/models/llama.jpeg'
import MicrosoftModelLogo from '@renderer/assets/images/models/microsoft.png'
import MinicpmModelLogo from '@renderer/assets/images/models/minicpm.webp'
import MixtralModelLogo from '@renderer/assets/images/models/mixtral.jpeg'
import PalmModelLogo from '@renderer/assets/images/models/palm.svg'
import QwenModelLogo from '@renderer/assets/images/models/qwen.png'
@@ -91,6 +92,7 @@ export function getModelLogo(modelId: string) {
}
const logoMap = {
o1: OpenAiProviderLogo,
gpt: ChatGPTModelLogo,
glm: ChatGLMModelLogo,
deepseek: DeepSeekModelLogo,
@@ -112,7 +114,8 @@ export function getModelLogo(modelId: string) {
abab: HailuoModelLogo,
'ep-202': DoubaoModelLogo,
cohere: CohereModelLogo,
command: CohereModelLogo
command: CohereModelLogo,
minicpm: MinicpmModelLogo
}
for (const key in logoMap) {
+2 -2
View File
@@ -88,7 +88,7 @@ const resources = {
'input.send': 'Send',
'input.pause': 'Pause',
'input.settings': 'Settings',
'input.upload': 'Upload image png、jpg、jpeg',
'input.upload': 'Upload image or text file',
'input.context_count.tip': 'Context Count',
'input.estimated_tokens.tip': 'Estimated tokens',
'settings.temperature': 'Temperature',
@@ -356,7 +356,7 @@ const resources = {
'input.send': '发送',
'input.pause': '暂停',
'input.settings': '设置',
'input.upload': '上传图片 png、jpg、jpeg',
'input.upload': '上传图片或纯文本文件',
'input.context_count.tip': '上下文数',
'input.estimated_tokens.tip': '预估 token 数',
'settings.temperature': '模型温度',
+13 -8
View File
@@ -1,7 +1,8 @@
import { Navbar, NavbarCenter } from '@renderer/components/app/Navbar'
import { VStack } from '@renderer/components/Layout'
import db from '@renderer/databases'
import { FileType } from '@renderer/types'
import { FileType, FileTypes } from '@renderer/types'
import { getFileDirectory } from '@renderer/utils'
import { Image, Table } from 'antd'
import dayjs from 'dayjs'
import { useLiveQuery } from 'dexie-react-hooks'
@@ -13,13 +14,17 @@ const FilesPage: FC = () => {
const { t } = useTranslation()
const files = useLiveQuery<FileType[]>(() => db.files.toArray())
const dataSource = files?.map((file) => ({
key: file.id,
file: <Image src={'file://' + file.path} preview={false} style={{ maxHeight: '40px' }} />,
name: <a href={'file://' + file.path}>{file.origin_name}</a>,
size: `${(file.size / 1024 / 1024).toFixed(2)} MB`,
created_at: dayjs(file.created_at).format('MM-DD HH:mm')
}))
const dataSource = files?.map((file) => {
const isImage = file.type === FileTypes.IMAGE
const ImageView = <Image src={'file://' + file.path} preview={false} style={{ maxHeight: '40px' }} />
return {
key: file.id,
file: isImage ? ImageView : file.origin_name,
name: <a href={'file://' + getFileDirectory(file.path)}>{file.origin_name}</a>,
size: `${(file.size / 1024 / 1024).toFixed(2)} MB`,
created_at: dayjs(file.created_at).format('MM-DD HH:mm')
}
})
const columns = [
{
@@ -1,4 +1,5 @@
import { PaperClipOutlined } from '@ant-design/icons'
import { imageExts, textExts } from '@renderer/config/constant'
import { isVisionModel } from '@renderer/config/models'
import { FileType, Model } from '@renderer/types'
import { Tooltip } from 'antd'
@@ -14,18 +15,13 @@ interface Props {
const AttachmentButton: FC<Props> = ({ model, files, setFiles, ToolbarButton }) => {
const { t } = useTranslation()
const extensions = isVisionModel(model) ? [...imageExts, ...textExts] : [...textExts]
const onSelectFile = async () => {
const _files = await window.api.file.select({
filters: [{ name: 'Files', extensions: ['jpg', 'png', 'jpeg'] }]
})
const _files = await window.api.file.select({ filters: [{ name: 'Files', extensions }] })
_files && setFiles(_files)
}
if (!isVisionModel(model)) {
return null
}
return (
<Tooltip placement="top" title={t('chat.input.upload')} arrow>
<ToolbarButton type="text" className={files.length ? 'active' : ''} onClick={onSelectFile}>
@@ -15,7 +15,7 @@ import { useRuntime, useShowTopics } from '@renderer/hooks/useStore'
import { getDefaultTopic } from '@renderer/services/assistant'
import { EVENT_NAMES, EventEmitter } from '@renderer/services/event'
import FileManager from '@renderer/services/file'
import { estimateInputTokenCount } from '@renderer/services/messages'
import { estimateTextTokens } from '@renderer/services/tokens'
import store, { useAppDispatch, useAppSelector } from '@renderer/store'
import { setGenerating, setSearching } from '@renderer/store/runtime'
import { Assistant, FileType, Message, Topic } from '@renderer/types'
@@ -92,7 +92,7 @@ const Inputbar: FC<Props> = ({ assistant, setActiveTopic }) => {
setExpend(false)
}, [assistant.id, assistant.topics, generating, files, text])
const inputTokenCount = useMemo(() => estimateInputTokenCount(text), [text])
const inputTokenCount = useMemo(() => estimateTextTokens(text), [text])
const handleKeyDown = (event: React.KeyboardEvent<HTMLTextAreaElement>) => {
const isEnterPressed = event.keyCode == 13
@@ -44,8 +44,8 @@ const TokenCount: FC<Props> = ({ estimateTokenCount, inputTokenCount, contextCou
<PicCenterOutlined />
</Tooltip>
</ToolbarButton>
<Container {...props}>
<Popover content={PopoverContent} title="" mouseEnterDelay={0.6}>
<Container>
<Popover content={PopoverContent}>
<MenuOutlined /> {contextCount}
<Divider type="vertical" style={{ marginTop: 0, marginLeft: 5, marginRight: 5 }} />
<ArrowUpOutlined />
@@ -46,14 +46,12 @@ const MessageItem: FC<Props> = ({ message, index, showMenu, onDeleteMessage }) =
const { assistant, setModel } = useAssistant(message.assistantId)
const model = useModel(message.modelId)
const { userName, showMessageDivider, messageFont, fontSize } = useSettings()
const { generating } = useRuntime()
const [copied, setCopied] = useState(false)
const isLastMessage = index === 0
const isUserMessage = message.role === 'user'
const isAssistantMessage = message.role === 'assistant'
const canRegenerate = isLastMessage && isAssistantMessage
const showMetadata = Boolean(message.usage) && !generating
const onCopy = useCallback(() => {
navigator.clipboard.writeText(removeTrailingDoubleSpaces(message.content))
@@ -133,7 +131,7 @@ const MessageItem: FC<Props> = ({ message, index, showMenu, onDeleteMessage }) =
style={{
borderRadius: '20%',
cursor: 'pointer',
border: isLocalAi ? '1px solid var(--color-border)' : ''
border: '1px solid var(--color-border)'
}}
onClick={showMiniApp}>
{avatarName}
@@ -206,18 +204,39 @@ const MessageItem: FC<Props> = ({ message, index, showMenu, onDeleteMessage }) =
)}
</MenusBar>
)}
{showMetadata && (
<MessageMetadata>
Tokens: {message?.usage?.total_tokens} | {message?.usage?.prompt_tokens} |
{message?.usage?.completion_tokens}
</MessageMetadata>
)}
<MessgeTokens message={message} />
</MessageFooter>
</MessageContentContainer>
</MessageContainer>
)
}
const MessgeTokens: React.FC<{ message: Message }> = ({ message }) => {
const { generating } = useRuntime()
if (!message.usage) {
return null
}
if (message.role === 'user') {
return <MessageMetadata>Tokens: {message?.usage?.total_tokens}</MessageMetadata>
}
if (generating) {
return null
}
if (message.role === 'assistant') {
return (
<MessageMetadata>
Tokens: {message?.usage?.total_tokens} | {message?.usage?.prompt_tokens} | {message?.usage?.completion_tokens}
</MessageMetadata>
)
}
return null
}
const MessageContent: React.FC<{ message: Message }> = ({ message }) => {
const { t } = useTranslation()
@@ -1,5 +1,6 @@
import { Message } from '@renderer/types'
import { Image as AntdImage } from 'antd'
import { FileTypes, Message } from '@renderer/types'
import { getFileDirectory } from '@renderer/utils'
import { Image as AntdImage, Upload } from 'antd'
import { FC } from 'react'
import styled from 'styled-components'
@@ -8,9 +9,27 @@ interface Props {
}
const MessageAttachments: FC<Props> = ({ message }) => {
if (message?.files && message.files[0]?.type === FileTypes.IMAGE) {
return (
<Container>
{message.files?.map((image) => <Image src={'file://' + image.path} key={image.id} width="33%" />)}
</Container>
)
}
return (
<Container>
{message.files?.map((image) => <Image src={'file://' + image.path} key={image.id} width="33%" />)}
<Container style={{ marginTop: -5 }}>
<Upload
listType="picture"
disabled
onPreview={(item) => item.url && window.open(getFileDirectory(item.url))}
fileList={message.files?.map((file) => ({
uid: file.id,
url: 'file://' + file.path,
status: 'done',
name: file.origin_name
}))}
/>
</Container>
)
}
@@ -4,16 +4,12 @@ import { getTopic, TopicManager } from '@renderer/hooks/useTopic'
import { fetchChatCompletion, fetchMessagesSummary } from '@renderer/services/api'
import { getDefaultTopic } from '@renderer/services/assistant'
import { EVENT_NAMES, EventEmitter } from '@renderer/services/event'
import {
deleteMessageFiles,
estimateHistoryTokenCount,
filterMessages,
getContextCount
} from '@renderer/services/messages'
import { deleteMessageFiles, filterMessages, getContextCount } from '@renderer/services/messages'
import { estimateHistoryTokens, estimateMessageUsage } from '@renderer/services/tokens'
import { Assistant, Message, Model, Topic } from '@renderer/types'
import { getBriefInfo, runAsyncFunction, uuid } from '@renderer/utils'
import { t } from 'i18next'
import { last, reverse, take } from 'lodash'
import { flatten, last, reverse, take } from 'lodash'
import { FC, useCallback, useEffect, useRef, useState } from 'react'
import styled from 'styled-components'
@@ -34,12 +30,15 @@ const Messages: FC<Props> = ({ assistant, topic, setActiveTopic }) => {
const { updateTopic, addTopic } = useAssistant(assistant.id)
const onSendMessage = useCallback(
(message: Message) => {
async (message: Message) => {
if (message.role === 'user') {
message.usage = await estimateMessageUsage(message)
}
const _messages = [...messages, message]
setMessages(_messages)
db.topics.put({ id: topic.id, messages: _messages })
},
[messages, topic]
[messages, topic.id]
)
const autoRenameTopic = useCallback(async () => {
@@ -68,9 +67,14 @@ const Messages: FC<Props> = ({ assistant, topic, setActiveTopic }) => {
const unsubscribes = [
EventEmitter.on(EVENT_NAMES.SEND_MESSAGE, async (msg: Message) => {
onSendMessage(msg)
fetchChatCompletion({ assistant, messages: [...messages, msg], topic, onResponse: setLastMessage })
fetchChatCompletion({
assistant,
messages: [...messages, msg],
topic,
onResponse: setLastMessage
})
}),
EventEmitter.on(EVENT_NAMES.AI_CHAT_COMPLETION, async (msg: Message) => {
EventEmitter.on(EVENT_NAMES.RECEIVE_MESSAGE, async (msg: Message) => {
setLastMessage(null)
onSendMessage(msg)
setTimeout(() => EventEmitter.emit(EVENT_NAMES.AI_AUTO_RENAME), 100)
@@ -98,6 +102,7 @@ const Messages: FC<Props> = ({ assistant, topic, setActiveTopic }) => {
const lastMessage = last(messages)
if (lastMessage && lastMessage.type === 'clear') {
onDeleteMessage(lastMessage)
return
}
@@ -117,16 +122,37 @@ const Messages: FC<Props> = ({ assistant, topic, setActiveTopic }) => {
} as Message)
}),
EventEmitter.on(EVENT_NAMES.NEW_BRANCH, async (index: number) => {
const _topic = getDefaultTopic()
_topic.name = topic.name
await db.topics.add({ id: _topic.id, messages: take(messages, messages.length - index) })
addTopic(_topic)
setActiveTopic(_topic)
const newTopic = getDefaultTopic()
newTopic.name = topic.name
const branchMessages = take(messages, messages.length - index)
// 将分支的消息放入数据库
await db.topics.add({ id: newTopic.id, messages: branchMessages })
addTopic(newTopic)
setActiveTopic(newTopic)
autoRenameTopic()
// 由于复制了消息,消息中附带的文件的总数变了,需要更新
const filesArr = branchMessages.map((m) => m.files)
const files = flatten(filesArr).filter(Boolean)
files.map(async (f) => {
const file = await db.files.get({ id: f?.id })
file && db.files.update(file.id, { count: file.count + 1 })
})
})
]
return () => unsubscribes.forEach((unsub) => unsub())
}, [addTopic, assistant, autoRenameTopic, messages, onSendMessage, setActiveTopic, topic, updateTopic])
}, [
addTopic,
assistant,
autoRenameTopic,
messages,
onDeleteMessage,
onSendMessage,
setActiveTopic,
topic,
updateTopic
])
useEffect(() => {
runAsyncFunction(async () => {
@@ -140,9 +166,11 @@ const Messages: FC<Props> = ({ assistant, topic, setActiveTopic }) => {
}, [messages])
useEffect(() => {
EventEmitter.emit(EVENT_NAMES.ESTIMATED_TOKEN_COUNT, {
tokensCount: estimateHistoryTokenCount(assistant, messages),
contextCount: getContextCount(assistant, messages)
runAsyncFunction(async () => {
EventEmitter.emit(EVENT_NAMES.ESTIMATED_TOKEN_COUNT, {
tokensCount: await estimateHistoryTokens(assistant, messages),
contextCount: getContextCount(assistant, messages)
})
})
}, [assistant, messages])
@@ -37,7 +37,7 @@ const Suggestions: FC<Props> = ({ assistant, messages, lastMessage }) => {
useEffect(() => {
const unsubscribes = [
EventEmitter.on(EVENT_NAMES.AI_CHAT_COMPLETION, async (msg: Message) => {
EventEmitter.on(EVENT_NAMES.RECEIVE_MESSAGE, async (msg: Message) => {
setLoadingSuggestions(true)
const _suggestions = await fetchSuggestions({ assistant, messages: [...messages, msg] })
if (_suggestions.length) {
+2 -6
View File
@@ -10,12 +10,8 @@ export default class AiProvider {
this.sdk = ProviderFactory.create(provider)
}
public async completions(
messages: Message[],
assistant: Assistant,
onChunk: ({ text, usage }: { text?: string; usage?: OpenAI.Completions.CompletionUsage }) => void
): Promise<void> {
return this.sdk.completions(messages, assistant, onChunk)
public async completions({ messages, assistant, onChunk, onFilterMessages }: CompletionsParams): Promise<void> {
return this.sdk.completions({ messages, assistant, onChunk, onFilterMessages })
}
public async translate(message: Message, assistant: Assistant): Promise<string> {
+51 -33
View File
@@ -4,8 +4,8 @@ import { DEFAULT_MAX_TOKENS } from '@renderer/config/constant'
import { getAssistantSettings, getDefaultModel, getTopNamingModel } from '@renderer/services/assistant'
import { EVENT_NAMES } from '@renderer/services/event'
import { filterContextMessages, filterMessages } from '@renderer/services/messages'
import { Assistant, Message, Provider, Suggestion } from '@renderer/types'
import { first, sum, takeRight } from 'lodash'
import { Assistant, FileTypes, Message, Provider, Suggestion } from '@renderer/types'
import { first, flatten, sum, takeRight } from 'lodash'
import OpenAI from 'openai'
import BaseProvider from './BaseProvider'
@@ -18,49 +18,67 @@ export default class AnthropicProvider extends BaseProvider {
this.sdk = new Anthropic({ apiKey: provider.apiKey, baseURL: this.getBaseURL() })
}
private async getMessageContent(message: Message): Promise<MessageParam['content']> {
private async getMessageParam(message: Message): Promise<MessageParam[]> {
const file = first(message.files)
if (!file) {
return message.content
if (file) {
if (file.type === FileTypes.IMAGE) {
const base64Data = await window.api.file.base64Image(file.id + file.ext)
return [
{
role: message.role,
content: [
{ type: 'text', text: message.content },
{
type: 'image',
source: {
data: base64Data.base64,
media_type: base64Data.mime.replace('jpg', 'jpeg') as any,
type: 'base64'
}
}
]
} as MessageParam
]
}
if (file.type === FileTypes.TEXT) {
return [
{
role: message.role,
content: message.content
} as MessageParam,
{
role: 'assistant',
content: (await window.api.file.read(file.id + file.ext)).trimEnd()
} as MessageParam
]
}
}
if (file.type === 'image') {
const base64Data = await window.api.image.base64(file.path)
return [
{ type: 'text', text: message.content },
{
type: 'image',
source: {
data: base64Data.base64,
media_type: base64Data.mime.replace('jpg', 'jpeg') as any,
type: 'base64'
}
}
]
}
return message.content
return [
{
role: message.role,
content: message.content
} as MessageParam
]
}
public async completions(
messages: Message[],
assistant: Assistant,
onChunk: ({ text, usage }: { text?: string; usage?: OpenAI.Completions.CompletionUsage }) => void
) {
public async completions({ messages, assistant, onChunk, onFilterMessages }: CompletionsParams) {
const defaultModel = getDefaultModel()
const model = assistant.model || defaultModel
const { contextCount, maxTokens } = getAssistantSettings(assistant)
const userMessages: MessageParam[] = []
let userMessagesParams: MessageParam[][] = []
const _messages = filterMessages(filterContextMessages(takeRight(messages, contextCount + 2)))
for (const message of filterMessages(filterContextMessages(takeRight(messages, contextCount + 2)))) {
userMessages.push({
role: message.role,
content: await this.getMessageContent(message)
})
onFilterMessages(_messages)
for (const message of _messages) {
userMessagesParams = userMessagesParams.concat(await this.getMessageParam(message))
}
const userMessages = flatten(userMessagesParams)
if (first(userMessages)?.role === 'assistant') {
userMessages.shift()
}
@@ -69,7 +87,7 @@ export default class AnthropicProvider extends BaseProvider {
const stream = this.sdk.messages
.stream({
model: model.id,
messages: userMessages.filter(Boolean) as MessageParam[],
messages: userMessages,
max_tokens: maxTokens || DEFAULT_MAX_TOKENS,
temperature: assistant?.settings?.temperature,
system: assistant.prompt,
+1 -5
View File
@@ -20,11 +20,7 @@ export default abstract class BaseProvider {
return this.provider.id === 'ollama' ? getOllamaKeepAliveTime() : undefined
}
abstract completions(
messages: Message[],
assistant: Assistant,
onChunk: ({ text, usage }: { text?: string; usage?: OpenAI.Completions.CompletionUsage }) => void
): Promise<void>
abstract completions({ messages, assistant, onChunk, onFilterMessages }: CompletionsParams): Promise<void>
abstract translate(message: Message, assistant: Assistant): Promise<string>
abstract summaries(messages: Message[], assistant: Assistant): Promise<string>
abstract suggestions(messages: Message[], assistant: Assistant): Promise<Suggestion[]>
+55 -42
View File
@@ -1,10 +1,10 @@
import { Content, GoogleGenerativeAI, InlineDataPart, Part } from '@google/generative-ai'
import { Content, GoogleGenerativeAI, InlineDataPart, TextPart } from '@google/generative-ai'
import { getAssistantSettings, getDefaultModel, getTopNamingModel } from '@renderer/services/assistant'
import { EVENT_NAMES } from '@renderer/services/event'
import { filterContextMessages, filterMessages } from '@renderer/services/messages'
import { Assistant, Message, Provider, Suggestion } from '@renderer/types'
import { Assistant, FileTypes, Message, Provider, Suggestion } from '@renderer/types'
import axios from 'axios'
import { first, isEmpty, takeRight } from 'lodash'
import { first, flatten, isEmpty, takeRight } from 'lodash'
import OpenAI from 'openai'
import BaseProvider from './BaseProvider'
@@ -17,42 +17,67 @@ export default class GeminiProvider extends BaseProvider {
this.sdk = new GoogleGenerativeAI(provider.apiKey)
}
private async getMessageParts(message: Message): Promise<Part[]> {
private async getMessageContents(message: Message): Promise<Content[]> {
const file = first(message.files)
const role = message.role === 'user' ? 'user' : 'model'
if (file && file.type === 'image') {
const base64Data = await window.api.image.base64(file.path)
return [
{
text: message.content
},
{
inlineData: {
data: base64Data.base64,
mimeType: base64Data.mime
if (file) {
if (file.type === FileTypes.IMAGE) {
const base64Data = await window.api.file.base64Image(file.id + file.ext)
return [
{
role: message.role,
parts: [
{ text: message.content } as TextPart,
{
inlineData: {
data: base64Data.base64,
mimeType: base64Data.mime
}
} as InlineDataPart
]
}
} as InlineDataPart
]
]
}
if (file.type === FileTypes.TEXT) {
return [
{
role: 'model',
parts: [{ text: await window.api.file.read(file.id + file.ext) } as TextPart]
},
{
role,
parts: [{ text: message.content } as TextPart]
}
]
}
}
return [{ text: message.content }]
return [
{
role,
parts: [{ text: message.content } as TextPart]
}
]
}
public async completions(
messages: Message[],
assistant: Assistant,
onChunk: ({ text, usage }: { text?: string; usage?: OpenAI.Completions.CompletionUsage }) => void
) {
public async completions({ messages, assistant, onChunk, onFilterMessages }: CompletionsParams) {
const defaultModel = getDefaultModel()
const model = assistant.model || defaultModel
const { contextCount, maxTokens } = getAssistantSettings(assistant)
const userMessages = filterMessages(filterContextMessages(takeRight(messages, contextCount + 1))).map((message) => {
return {
role: message.role,
message
}
})
const userMessages = filterMessages(filterContextMessages(takeRight(messages, contextCount + 1)))
onFilterMessages(userMessages)
const userLastMessage = userMessages.pop()
let historyContents: Content[][] = []
for (const message of userMessages) {
historyContents = historyContents.concat(await this.getMessageContents(message))
}
const history = flatten(historyContents)
const geminiModel = this.sdk.getGenerativeModel({
model: model.id,
@@ -63,21 +88,9 @@ export default class GeminiProvider extends BaseProvider {
}
})
const userLastMessage = userMessages.pop()
const history: Content[] = []
for (const message of userMessages) {
history.push({
role: message.role === 'user' ? 'user' : 'model',
parts: await this.getMessageParts(message.message)
})
}
const chat = geminiModel.startChat({ history })
const message = await this.getMessageParts(userLastMessage?.message!)
const userMessagesStream = await chat.sendMessageStream(message)
const messageContents = await this.getMessageContents(userLastMessage!)
const userMessagesStream = await chat.sendMessageStream(messageContents[0].parts)
for await (const chunk of userMessagesStream.stream) {
if (window.keyv.get(EVENT_NAMES.CHAT_COMPLETION_PAUSED)) break
+66 -35
View File
@@ -2,7 +2,7 @@ import { isLocalAi } from '@renderer/config/env'
import { getAssistantSettings, getDefaultModel, getTopNamingModel } from '@renderer/services/assistant'
import { EVENT_NAMES } from '@renderer/services/event'
import { filterContextMessages, filterMessages } from '@renderer/services/messages'
import { Assistant, Message, Provider, Suggestion } from '@renderer/types'
import { Assistant, FileTypes, Message, Provider, Suggestion } from '@renderer/types'
import { removeQuotes } from '@renderer/utils'
import { first, takeRight } from 'lodash'
import OpenAI from 'openai'
@@ -26,61 +26,92 @@ export default class OpenAIProvider extends BaseProvider {
})
}
private async getMessageContent(message: Message): Promise<string | ChatCompletionContentPart[]> {
const file = first(message.files)
if (!file) {
return message.content
private isSupportStreamOutput(modelId: string): boolean {
if (this.provider.id === 'openai' && modelId.includes('o1-')) {
return false
}
if (file.type === 'image') {
const base64Data = await window.api.image.base64(file.path)
return [
{ type: 'text', text: message.content },
{
type: 'image_url',
image_url: {
url: base64Data.data
}
}
]
}
return message.content
return true
}
async completions(
messages: Message[],
assistant: Assistant,
onChunk: ({ text, usage }: { text?: string; usage?: OpenAI.Completions.CompletionUsage }) => void
): Promise<void> {
private async getMessageParam(message: Message): Promise<OpenAI.Chat.Completions.ChatCompletionMessageParam[]> {
const file = first(message.files)
const content: string | ChatCompletionContentPart[] = message.content
if (file) {
if (file.type === FileTypes.IMAGE) {
const image = await window.api.file.base64Image(file.id + file.ext)
return [
{
role: message.role,
content: [
{ type: 'text', text: message.content },
{
type: 'image_url',
image_url: {
url: image.data
}
}
]
} as ChatCompletionMessageParam
]
}
if (file.type === FileTypes.TEXT) {
return [
{
role: 'assistant',
content: await window.api.file.read(file.id + file.ext)
} as ChatCompletionMessageParam,
{
role: message.role,
content
} as ChatCompletionMessageParam
]
}
}
return [
{
role: message.role,
content
} as ChatCompletionMessageParam
]
}
async completions({ messages, assistant, onChunk, onFilterMessages }: CompletionsParams): Promise<void> {
const defaultModel = getDefaultModel()
const model = assistant.model || defaultModel
const { contextCount, maxTokens } = getAssistantSettings(assistant)
const systemMessage = assistant.prompt ? { role: 'system', content: assistant.prompt } : undefined
const userMessages: ChatCompletionMessageParam[] = []
let userMessages: ChatCompletionMessageParam[] = []
for (const message of filterMessages(filterContextMessages(takeRight(messages, contextCount + 1)))) {
userMessages.push({
role: message.role,
content: await this.getMessageContent(message)
} as ChatCompletionMessageParam)
const _messages = filterMessages(filterContextMessages(takeRight(messages, contextCount + 1)))
onFilterMessages(_messages)
for (const message of _messages) {
userMessages = userMessages.concat(await this.getMessageParam(message))
}
// @ts-ignore key is not typed
const stream = await this.sdk.chat.completions.create({
model: model.id,
messages: [systemMessage, ...userMessages].filter(Boolean) as ChatCompletionMessageParam[],
stream: true,
stream: this.isSupportStreamOutput(model.id),
temperature: assistant?.settings?.temperature,
max_tokens: maxTokens,
keep_alive: this.keepAliveTime
})
for await (const chunk of stream) {
if (window.keyv.get(EVENT_NAMES.CHAT_COMPLETION_PAUSED)) break
onChunk({ text: chunk.choices[0]?.delta?.content || '', usage: chunk.usage })
if (window.keyv.get(EVENT_NAMES.CHAT_COMPLETION_PAUSED)) {
break
}
onChunk({
text: chunk.choices[0]?.delta?.content || '',
usage: chunk.usage
})
}
}
+11
View File
@@ -0,0 +1,11 @@
interface ChunkCallbackData {
text?: string
usage?: OpenAI.Completions.CompletionUsage
}
interface CompletionsParams {
messages: Message[]
assistant: Assistant
onChunk: ({ text, usage }: ChunkCallbackData) => void
onFilterMessages: (messages: Message[]) => void
}
+22 -7
View File
@@ -15,7 +15,8 @@ import {
getTranslateModel
} from './assistant'
import { EVENT_NAMES, EventEmitter } from './event'
import { estimateMessagesToken, filterMessages } from './messages'
import { filterMessages } from './messages'
import { estimateMessagesUsage } from './tokens'
export async function fetchChatCompletion({
messages,
@@ -61,13 +62,27 @@ export async function fetchChatCompletion({
}, 1000)
try {
await AI.completions(messages, assistant, ({ text, usage }) => {
message.content = message.content + text || ''
message.usage = usage
onResponse({ ...message, status: 'pending' })
let _messages: Message[] = []
await AI.completions({
messages,
assistant,
onFilterMessages: (messages) => (_messages = messages),
onChunk: ({ text, usage }) => {
message.content = message.content + text || ''
message.usage = usage
onResponse({ ...message, status: 'pending' })
}
})
message.status = 'success'
message.usage = message.usage || (await estimateMessagesToken({ assistant, messages: [...messages, message] }))
if (!message.usage) {
message.usage = await estimateMessagesUsage({
assistant,
messages: [..._messages, message]
})
}
} catch (error: any) {
message.content = `Error: ${error.message}`
message.status = 'error'
@@ -83,7 +98,7 @@ export async function fetchChatCompletion({
message.status = window.keyv.get(EVENT_NAMES.CHAT_COMPLETION_PAUSED) ? 'paused' : message.status
// Emit chat completion event
EventEmitter.emit(EVENT_NAMES.AI_CHAT_COMPLETION, message)
EventEmitter.emit(EVENT_NAMES.RECEIVE_MESSAGE, message)
// Reset generating state
store.dispatch(setGenerating(false))
+1 -1
View File
@@ -4,7 +4,7 @@ export const EventEmitter = new Emittery()
export const EVENT_NAMES = {
SEND_MESSAGE: 'SEND_MESSAGE',
AI_CHAT_COMPLETION: 'AI_CHAT_COMPLETION',
RECEIVE_MESSAGE: 'RECEIVE_MESSAGE',
AI_AUTO_RENAME: 'AI_AUTO_RENAME',
CLEAR_MESSAGES: 'CLEAR_MESSAGES',
ADD_ASSISTANT: 'ADD_ASSISTANT',
+1 -48
View File
@@ -1,10 +1,7 @@
import { DEFAULT_CONEXTCOUNT } from '@renderer/config/constant'
import { Assistant, Message } from '@renderer/types'
import { GPTTokens } from 'gpt-tokens'
import { isEmpty, last, takeRight } from 'lodash'
import { CompletionUsage } from 'openai/resources'
import { isEmpty, takeRight } from 'lodash'
import { getAssistantSettings } from './assistant'
import FileManager from './file'
export const filterMessages = (messages: Message[]) => {
@@ -36,50 +33,6 @@ export function getContextCount(assistant: Assistant, messages: Message[]) {
return messagesCount - (clearIndex + 1)
}
export function estimateInputTokenCount(text: string) {
const input = new GPTTokens({
model: 'gpt-4o',
messages: [{ role: 'user', content: text }]
})
return input.usedTokens - 7
}
export async function estimateMessagesToken({
assistant,
messages
}: {
assistant: Assistant
messages: Message[]
}): Promise<CompletionUsage> {
const responseMessageContent = last(messages)?.content
const inputMessageContent = messages[messages.length - 2]?.content
const completion_tokens = await estimateInputTokenCount(responseMessageContent ?? '')
const prompt_tokens = await estimateInputTokenCount(assistant.prompt + inputMessageContent ?? '')
return {
completion_tokens,
prompt_tokens: prompt_tokens,
total_tokens: prompt_tokens + completion_tokens
} as CompletionUsage
}
export function estimateHistoryTokenCount(assistant: Assistant, msgs: Message[]) {
const { contextCount } = getAssistantSettings(assistant)
const all = new GPTTokens({
model: 'gpt-4o',
messages: [
{ role: 'system', content: assistant.prompt },
...filterMessages(filterContextMessages(takeRight(msgs, contextCount))).map((message) => ({
role: message.role,
content: message.content
}))
]
})
return all.usedTokens - 7
}
export function deleteMessageFiles(message: Message) {
message.files && FileManager.deleteFiles(message.files.map((f) => f.id))
}
+130
View File
@@ -0,0 +1,130 @@
import { Assistant, FileType, FileTypes, Message } from '@renderer/types'
import { GPTTokens } from 'gpt-tokens'
import { flatten, takeRight } from 'lodash'
import { CompletionUsage } from 'openai/resources'
import { getAssistantSettings } from './assistant'
import { filterContextMessages, filterMessages } from './messages'
interface MessageItem {
name?: string
role: 'system' | 'user' | 'assistant'
content: string
}
async function getFileContent(file: FileType) {
if (!file) {
return ''
}
const fileId = file.id + file.ext
if (file.type === FileTypes.IMAGE) {
const data = await window.api.file.base64Image(fileId)
return data.data
}
if (file.type === FileTypes.TEXT) {
return await window.api.file.read(fileId)
}
return ''
}
async function getMessageParam(message: Message): Promise<MessageItem[]> {
const param: MessageItem[] = []
param.push({
role: message.role,
content: message.content
})
if (message.files) {
for (const file of message.files) {
param.push({
role: 'assistant',
content: await getFileContent(file)
})
}
}
return param
}
export function estimateTextTokens(text: string) {
const { usedTokens } = new GPTTokens({
model: 'gpt-4o',
messages: [{ role: 'user', content: text }]
})
return usedTokens - 7
}
export async function estimateMessageUsage(message: Message): Promise<CompletionUsage> {
const { usedTokens, promptUsedTokens, completionUsedTokens } = new GPTTokens({
model: 'gpt-4o',
messages: await getMessageParam(message)
})
const hasImage = message.files?.some((f) => f.type === FileTypes.IMAGE)
return {
prompt_tokens: promptUsedTokens,
completion_tokens: completionUsedTokens,
total_tokens: hasImage ? Math.floor(usedTokens / 80) : usedTokens - 7
}
}
export async function estimateMessagesUsage({
assistant,
messages
}: {
assistant: Assistant
messages: Message[]
}): Promise<CompletionUsage> {
const outputMessage = messages.pop()!
const prompt_tokens = await estimateHistoryTokens(assistant, messages)
const { completion_tokens } = await estimateMessageUsage(outputMessage)
return {
prompt_tokens: await estimateHistoryTokens(assistant, messages),
completion_tokens,
total_tokens: prompt_tokens + completion_tokens
} as CompletionUsage
}
export async function estimateHistoryTokens(assistant: Assistant, msgs: Message[]) {
const { contextCount } = getAssistantSettings(assistant)
const messages = filterMessages(filterContextMessages(takeRight(msgs, contextCount)))
// 有 usage 数据的消息,快速计算总数
const uasageTokens = messages
.filter((m) => m.usage)
.reduce((acc, message) => {
const inputTokens = message.usage?.total_tokens ?? 0
const outputTokens = message.usage!.completion_tokens ?? 0
return acc + (message.role === 'user' ? inputTokens : outputTokens)
}, 0)
// 没有 usage 数据的消息,需要计算每条消息的 token
let allMessages: MessageItem[][] = []
for (const message of messages.filter((m) => !m.usage)) {
const items = await getMessageParam(message)
allMessages = allMessages.concat(items)
}
const { usedTokens } = new GPTTokens({
model: 'gpt-4o',
messages: [
{
role: 'system',
content: assistant.prompt
},
...flatten(allMessages)
]
})
return usedTokens - 7 + uasageTokens
}
+2
View File
@@ -97,12 +97,14 @@ export interface FileType {
type: FileTypes
created_at: Date
count: number
tokens?: number
}
export enum FileTypes {
IMAGE = 'image',
VIDEO = 'video',
AUDIO = 'audio',
TEXT = 'text',
DOCUMENT = 'document',
OTHER = 'other'
}
+6
View File
@@ -229,3 +229,9 @@ export function removeTrailingDoubleSpaces(markdown: string): string {
// 使用正则表达式匹配末尾的两个空格,并替换为空字符串
return markdown.replace(/ {2}$/gm, '')
}
export function getFileDirectory(filePath: string) {
const parts = filePath.split('/')
const directory = parts.slice(0, -1).join('/')
return directory
}