fix: download mutilple binaries (#2043)

Signed-off-by: James <james@jan.ai>
Co-authored-by: James <james@jan.ai>
This commit is contained in:
NamH 2024-02-16 11:32:14 +07:00 committed by GitHub
parent b7e94aac02
commit 42da19a463
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 152 additions and 123 deletions

View File

@ -12,8 +12,9 @@ import {
DownloadEvent,
DownloadRoute,
ModelEvent,
DownloadState,
} from '@janhq/core'
import { DownloadState } from '@janhq/core/.'
import { extractFileName } from './helpers/path'
/**

View File

@ -32,7 +32,8 @@ export default function DownloadingState() {
.map((a) => a.size.total + a.size.total)
.reduce((partialSum, a) => partialSum + a, 0)
const totalPercentage = ((totalCurrentProgress / totalSize) * 100).toFixed(2)
const totalPercentage =
totalSize !== 0 ? ((totalCurrentProgress / totalSize) * 100).toFixed(2) : 0
return (
<Fragment>

View File

@ -1,4 +1,4 @@
import { useMemo } from 'react'
import { useCallback } from 'react'
import { Model } from '@janhq/core'
@ -14,7 +14,7 @@ import {
Progress,
} from '@janhq/uikit'
import { atom, useAtomValue } from 'jotai'
import { useAtomValue } from 'jotai'
import useDownloadModel from '@/hooks/useDownloadModel'
@ -30,14 +30,21 @@ type Props = {
}
const ModalCancelDownload: React.FC<Props> = ({ model, isFromList }) => {
const downloadingModels = useAtomValue(getDownloadingModelAtom)
const downloadAtom = useMemo(
() => atom((get) => get(modelDownloadStateAtom)[model.id]),
[model.id]
)
const downloadState = useAtomValue(downloadAtom)
const cancelText = `Cancel ${formatDownloadPercentage(downloadState.percent)}`
const { abortModelDownload } = useDownloadModel()
const downloadingModels = useAtomValue(getDownloadingModelAtom)
const allDownloadStates = useAtomValue(modelDownloadStateAtom)
const downloadState = allDownloadStates[model.id]
const cancelText = `Cancel ${formatDownloadPercentage(downloadState.percent)}`
const onAbortDownloadClick = useCallback(() => {
if (downloadState?.modelId) {
const model = downloadingModels.find(
(model) => model.id === downloadState.modelId
)
if (model) abortModelDownload(model)
}
}, [downloadState, downloadingModels, abortModelDownload])
return (
<Modal>
@ -77,17 +84,7 @@ const ModalCancelDownload: React.FC<Props> = ({ model, isFromList }) => {
<Button themes="ghost">No</Button>
</ModalClose>
<ModalClose asChild>
<Button
themes="danger"
onClick={() => {
if (downloadState?.modelId) {
const model = downloadingModels.find(
(model) => model.id === downloadState.modelId
)
if (model) abortModelDownload(model)
}
}}
>
<Button themes="danger" onClick={onAbortDownloadClick}>
Yes
</Button>
</ModalClose>

View File

@ -3,7 +3,7 @@ import { PropsWithChildren, useCallback, useEffect } from 'react'
import React from 'react'
import { DownloadEvent, events } from '@janhq/core'
import { DownloadEvent, events, DownloadState } from '@janhq/core'
import { useSetAtom } from 'jotai'
import { setDownloadStateAtom } from '@/hooks/useDownloadState'

View File

@ -7,14 +7,13 @@ import {
abortDownload,
joinPath,
ModelArtifact,
DownloadState,
} from '@janhq/core'
import { useSetAtom } from 'jotai'
import { FeatureToggleContext } from '@/context/FeatureToggle'
import { modelBinFileName } from '@/utils/model'
import { setDownloadStateAtom } from './useDownloadState'
import { extensionManager } from '@/extension/ExtensionManager'
@ -29,7 +28,7 @@ export default function useDownloadModel() {
async (model: Model) => {
const childProgresses: DownloadState[] = model.sources.map(
(source: ModelArtifact) => ({
filename: source.filename,
fileName: source.filename,
modelId: model.id,
time: {
elapsed: 0,
@ -47,7 +46,7 @@ export default function useDownloadModel() {
// set an initial download state
setDownloadState({
filename: '',
fileName: '',
modelId: model.id,
time: {
elapsed: 0,
@ -70,11 +69,12 @@ export default function useDownloadModel() {
[ignoreSSL, proxy, addDownloadingModel, setDownloadState]
)
const abortModelDownload = async (model: Model) => {
await abortDownload(
await joinPath(['models', model.id, modelBinFileName(model)])
)
}
const abortModelDownload = useCallback(async (model: Model) => {
for (const source of model.sources) {
const path = await joinPath(['models', model.id, source.filename])
await abortDownload(path)
}
}, [])
return {
downloadModel,

View File

@ -1,3 +1,4 @@
import { DownloadState } from '@janhq/core'
import { atom } from 'jotai'
import { toaster } from '@/containers/Toast'
@ -20,18 +21,35 @@ export const setDownloadStateAtom = atom(
const currentState = { ...get(modelDownloadStateAtom) }
if (state.downloadState === 'end') {
// download successfully
delete currentState[state.modelId]
set(removeDownloadingModelAtom, state.modelId)
const model = get(configuredModelsAtom).find(
(e) => e.id === state.modelId
const modelDownloadState = currentState[state.modelId]
const updatedChildren: DownloadState[] =
modelDownloadState.children!.filter(
(m) => m.fileName !== state.fileName
)
updatedChildren.push(state)
modelDownloadState.children = updatedChildren
currentState[state.modelId] = modelDownloadState
const isAllChildrenDownloadEnd = modelDownloadState.children?.every(
(m) => m.downloadState === 'end'
)
if (model) set(downloadedModelsAtom, (prev) => [...prev, model])
toaster({
title: 'Download Completed',
description: `Download ${state.modelId} completed`,
type: 'success',
})
if (isAllChildrenDownloadEnd) {
// download successfully
delete currentState[state.modelId]
set(removeDownloadingModelAtom, state.modelId)
const model = get(configuredModelsAtom).find(
(e) => e.id === state.modelId
)
if (model) set(downloadedModelsAtom, (prev) => [...prev, model])
toaster({
title: 'Download Completed',
description: `Download ${state.modelId} completed`,
type: 'success',
})
}
} else if (state.downloadState === 'error') {
// download error
delete currentState[state.modelId]
@ -59,7 +77,62 @@ export const setDownloadStateAtom = atom(
}
} else {
// download in progress
currentState[state.modelId] = state
if (state.size.total === 0) {
// this is initial state, just set the state
currentState[state.modelId] = state
set(modelDownloadStateAtom, currentState)
return
}
const modelDownloadState = currentState[state.modelId]
if (!modelDownloadState) {
console.debug('setDownloadStateAtom: modelDownloadState not found')
return
}
// delete the children if the filename is matched and replace the new state
const updatedChildren: DownloadState[] =
modelDownloadState.children!.filter(
(m) => m.fileName !== state.fileName
)
updatedChildren.push(state)
// re-calculate the overall progress if we have all the children download data
const isAnyChildDownloadNotReady = updatedChildren.some(
(m) => m.size.total === 0
)
modelDownloadState.children = updatedChildren
if (isAnyChildDownloadNotReady) {
// just update the children
currentState[state.modelId] = modelDownloadState
set(modelDownloadStateAtom, currentState)
return
}
const parentTotalSize = modelDownloadState.size.total
if (parentTotalSize === 0) {
// calculate the total size of the parent by sum all children total size
const totalSize = updatedChildren.reduce(
(acc, m) => acc + m.size.total,
0
)
modelDownloadState.size.total = totalSize
}
// calculate the total transferred size by sum all children transferred size
const transferredSize = updatedChildren.reduce(
(acc, m) => acc + m.size.transferred,
0
)
modelDownloadState.size.transferred = transferredSize
modelDownloadState.percent = transferredSize / parentTotalSize
currentState[state.modelId] = modelDownloadState
}
set(modelDownloadStateAtom, currentState)

View File

@ -1,6 +1,4 @@
/* eslint-disable react/display-name */
import { forwardRef, useState } from 'react'
import { useState } from 'react'
import { Model } from '@janhq/core'
import { Badge } from '@janhq/uikit'
@ -11,7 +9,7 @@ type Props = {
model: Model
}
const ExploreModelItem = forwardRef<HTMLDivElement, Props>(({ model }, ref) => {
const ExploreModelItem: React.FC<Props> = ({ model }) => {
const [open, setOpen] = useState('')
const handleToggle = () => {
@ -23,10 +21,7 @@ const ExploreModelItem = forwardRef<HTMLDivElement, Props>(({ model }, ref) => {
}
return (
<div
ref={ref}
className="mb-6 flex flex-col overflow-hidden rounded-xl border border-border bg-background/60"
>
<div className="mb-6 flex flex-col overflow-hidden rounded-xl border border-border bg-background/60">
<ExploreModelItemHeader
model={model}
onClick={handleToggle}
@ -82,17 +77,11 @@ const ExploreModelItem = forwardRef<HTMLDivElement, Props>(({ model }, ref) => {
</span>
<p className="mt-2 font-medium uppercase">{model.format}</p>
</div>
{/* <div className="mt-4">
<span className="font-semibold text-muted-foreground">
Compatibility
</span>
<p className="mt-2 font-medium">-</p>
</div> */}
</div>
</div>
)}
</div>
)
})
}
export default ExploreModelItem

View File

@ -1,5 +1,4 @@
/* eslint-disable react-hooks/exhaustive-deps */
import { useCallback, useMemo } from 'react'
import { useCallback } from 'react'
import { Model } from '@janhq/core'
import {
@ -12,7 +11,7 @@ import {
TooltipTrigger,
} from '@janhq/uikit'
import { atom, useAtomValue } from 'jotai'
import { useAtomValue } from 'jotai'
import { ChevronDownIcon } from 'lucide-react'
@ -25,8 +24,6 @@ import { MainViewState } from '@/constants/screens'
import { useCreateNewThread } from '@/hooks/useCreateNewThread'
import useDownloadModel from '@/hooks/useDownloadModel'
import { modelDownloadStateAtom } from '@/hooks/useDownloadState'
import { useMainViewState } from '@/hooks/useMainViewState'
import { toGibibytes } from '@/utils/converter'
@ -34,7 +31,10 @@ import { toGibibytes } from '@/utils/converter'
import { assistantsAtom } from '@/helpers/atoms/Assistant.atom'
import { serverEnabledAtom } from '@/helpers/atoms/LocalServer.atom'
import { downloadedModelsAtom } from '@/helpers/atoms/Model.atom'
import {
downloadedModelsAtom,
getDownloadingModelAtom,
} from '@/helpers/atoms/Model.atom'
import {
nvidiaTotalVramAtom,
totalRamAtom,
@ -46,12 +46,32 @@ type Props = {
open: string
}
const getLabel = (size: number, ram: number) => {
if (size * 1.25 >= ram) {
return (
<Badge className="rounded-md" themes="danger">
Not enough RAM
</Badge>
)
} else {
return (
<Badge className="rounded-md" themes="success">
Recommended
</Badge>
)
}
}
const ExploreModelItemHeader: React.FC<Props> = ({ model, onClick, open }) => {
const { downloadModel } = useDownloadModel()
const downloadingModels = useAtomValue(getDownloadingModelAtom)
const downloadedModels = useAtomValue(downloadedModelsAtom)
const { requestCreateNewThread } = useCreateNewThread()
const totalRam = useAtomValue(totalRamAtom)
const nvidiaTotalVram = useAtomValue(nvidiaTotalVramAtom)
const { setMainViewState } = useMainViewState()
// Default nvidia returns vram in MB, need to convert to bytes to match the unit of totalRamW
let ram = nvidiaTotalVram * 1024 * 1024
if (ram === 0) {
@ -60,16 +80,9 @@ const ExploreModelItemHeader: React.FC<Props> = ({ model, onClick, open }) => {
const serverEnabled = useAtomValue(serverEnabledAtom)
const assistants = useAtomValue(assistantsAtom)
const downloadAtom = useMemo(
() => atom((get) => get(modelDownloadStateAtom)[model.id]),
[model.id]
)
const downloadState = useAtomValue(downloadAtom)
const { setMainViewState } = useMainViewState()
const onDownloadClick = useCallback(() => {
downloadModel(model)
}, [model])
}, [model, downloadModel])
const isDownloaded = downloadedModels.find((md) => md.id === model.id) != null
@ -85,6 +98,8 @@ const ExploreModelItemHeader: React.FC<Props> = ({ model, onClick, open }) => {
</Button>
)
const isDownloading = downloadingModels.some((md) => md.id === model.id)
const onUseModelClick = useCallback(async () => {
if (assistants.length === 0) {
alert('No assistant available')
@ -92,7 +107,7 @@ const ExploreModelItemHeader: React.FC<Props> = ({ model, onClick, open }) => {
}
await requestCreateNewThread(assistants[0], model)
setMainViewState(MainViewState.Thread)
}, [])
}, [assistants, model, requestCreateNewThread, setMainViewState])
if (isDownloaded) {
downloadButton = (
@ -117,26 +132,10 @@ const ExploreModelItemHeader: React.FC<Props> = ({ model, onClick, open }) => {
)}
</Tooltip>
)
} else if (downloadState != null) {
} else if (isDownloading) {
downloadButton = <ModalCancelDownload model={model} />
}
const getLabel = (size: number) => {
if (size * 1.25 >= ram) {
return (
<Badge className="rounded-md" themes="danger">
Not enough RAM
</Badge>
)
} else {
return (
<Badge className="rounded-md" themes="success">
Recommended
</Badge>
)
}
}
return (
<div
className="cursor-pointer rounded-t-md bg-background"
@ -159,7 +158,7 @@ const ExploreModelItemHeader: React.FC<Props> = ({ model, onClick, open }) => {
<span className="mr-4 font-semibold text-muted-foreground">
{toGibibytes(model.metadata.size)}
</span>
{getLabel(model.metadata.size)}
{getLabel(model.metadata.size, ram)}
{downloadButton}
<ChevronDownIcon

View File

@ -1,21 +0,0 @@
type DownloadState = {
modelId: string
filename: string
time: DownloadTime
speed: number
percent: number
size: DownloadSize
children?: DownloadState[]
error?: string
downloadState: 'downloading' | 'error' | 'end'
}
type DownloadTime = {
elapsed: number
remaining: number
}
type DownloadSize = {
total: number
transferred: number
}

View File

@ -1,10 +0,0 @@
import { Model } from '@janhq/core'
export const modelBinFileName = (model: Model) => {
const modelFormatExt = '.gguf'
const extractedFileName = model.sources[0]?.url.split('/').pop() ?? model.id
const fileName = extractedFileName.toLowerCase().endsWith(modelFormatExt)
? extractedFileName
: model.id
return fileName
}