From 42da19a4632c9ceadfddb6dcf85e5b160a3438ee Mon Sep 17 00:00:00 2001 From: NamH Date: Fri, 16 Feb 2024 11:32:14 +0700 Subject: [PATCH] fix: download mutilple binaries (#2043) Signed-off-by: James Co-authored-by: James --- extensions/model-extension/src/index.ts | 3 +- .../BottomBar/DownloadingState/index.tsx | 3 +- web/containers/ModalCancelDownload/index.tsx | 37 ++++--- web/containers/Providers/EventListener.tsx | 2 +- web/hooks/useDownloadModel.ts | 18 ++-- web/hooks/useDownloadState.ts | 97 ++++++++++++++++--- .../ExploreModels/ExploreModelItem/index.tsx | 19 +--- .../ExploreModelItemHeader/index.tsx | 65 ++++++------- web/types/downloadState.d.ts | 21 ---- web/utils/model.ts | 10 -- 10 files changed, 152 insertions(+), 123 deletions(-) delete mode 100644 web/types/downloadState.d.ts delete mode 100644 web/utils/model.ts diff --git a/extensions/model-extension/src/index.ts b/extensions/model-extension/src/index.ts index 817794c74..926e65ee5 100644 --- a/extensions/model-extension/src/index.ts +++ b/extensions/model-extension/src/index.ts @@ -12,8 +12,9 @@ import { DownloadEvent, DownloadRoute, ModelEvent, + DownloadState, } from '@janhq/core' -import { DownloadState } from '@janhq/core/.' + import { extractFileName } from './helpers/path' /** diff --git a/web/containers/Layout/BottomBar/DownloadingState/index.tsx b/web/containers/Layout/BottomBar/DownloadingState/index.tsx index c7191d0b9..dcebacd3c 100644 --- a/web/containers/Layout/BottomBar/DownloadingState/index.tsx +++ b/web/containers/Layout/BottomBar/DownloadingState/index.tsx @@ -32,7 +32,8 @@ export default function DownloadingState() { .map((a) => a.size.total + a.size.total) .reduce((partialSum, a) => partialSum + a, 0) - const totalPercentage = ((totalCurrentProgress / totalSize) * 100).toFixed(2) + const totalPercentage = + totalSize !== 0 ? ((totalCurrentProgress / totalSize) * 100).toFixed(2) : 0 return ( diff --git a/web/containers/ModalCancelDownload/index.tsx b/web/containers/ModalCancelDownload/index.tsx index 8d08665f4..d52fbe5e9 100644 --- a/web/containers/ModalCancelDownload/index.tsx +++ b/web/containers/ModalCancelDownload/index.tsx @@ -1,4 +1,4 @@ -import { useMemo } from 'react' +import { useCallback } from 'react' import { Model } from '@janhq/core' @@ -14,7 +14,7 @@ import { Progress, } from '@janhq/uikit' -import { atom, useAtomValue } from 'jotai' +import { useAtomValue } from 'jotai' import useDownloadModel from '@/hooks/useDownloadModel' @@ -30,14 +30,21 @@ type Props = { } const ModalCancelDownload: React.FC = ({ model, isFromList }) => { - const downloadingModels = useAtomValue(getDownloadingModelAtom) - const downloadAtom = useMemo( - () => atom((get) => get(modelDownloadStateAtom)[model.id]), - [model.id] - ) - const downloadState = useAtomValue(downloadAtom) - const cancelText = `Cancel ${formatDownloadPercentage(downloadState.percent)}` const { abortModelDownload } = useDownloadModel() + const downloadingModels = useAtomValue(getDownloadingModelAtom) + const allDownloadStates = useAtomValue(modelDownloadStateAtom) + const downloadState = allDownloadStates[model.id] + + const cancelText = `Cancel ${formatDownloadPercentage(downloadState.percent)}` + + const onAbortDownloadClick = useCallback(() => { + if (downloadState?.modelId) { + const model = downloadingModels.find( + (model) => model.id === downloadState.modelId + ) + if (model) abortModelDownload(model) + } + }, [downloadState, downloadingModels, abortModelDownload]) return ( @@ -77,17 +84,7 @@ const ModalCancelDownload: React.FC = ({ model, isFromList }) => { - diff --git a/web/containers/Providers/EventListener.tsx b/web/containers/Providers/EventListener.tsx index a72faf924..100805e17 100644 --- a/web/containers/Providers/EventListener.tsx +++ b/web/containers/Providers/EventListener.tsx @@ -3,7 +3,7 @@ import { PropsWithChildren, useCallback, useEffect } from 'react' import React from 'react' -import { DownloadEvent, events } from '@janhq/core' +import { DownloadEvent, events, DownloadState } from '@janhq/core' import { useSetAtom } from 'jotai' import { setDownloadStateAtom } from '@/hooks/useDownloadState' diff --git a/web/hooks/useDownloadModel.ts b/web/hooks/useDownloadModel.ts index 3c544a24c..2dfef0445 100644 --- a/web/hooks/useDownloadModel.ts +++ b/web/hooks/useDownloadModel.ts @@ -7,14 +7,13 @@ import { abortDownload, joinPath, ModelArtifact, + DownloadState, } from '@janhq/core' import { useSetAtom } from 'jotai' import { FeatureToggleContext } from '@/context/FeatureToggle' -import { modelBinFileName } from '@/utils/model' - import { setDownloadStateAtom } from './useDownloadState' import { extensionManager } from '@/extension/ExtensionManager' @@ -29,7 +28,7 @@ export default function useDownloadModel() { async (model: Model) => { const childProgresses: DownloadState[] = model.sources.map( (source: ModelArtifact) => ({ - filename: source.filename, + fileName: source.filename, modelId: model.id, time: { elapsed: 0, @@ -47,7 +46,7 @@ export default function useDownloadModel() { // set an initial download state setDownloadState({ - filename: '', + fileName: '', modelId: model.id, time: { elapsed: 0, @@ -70,11 +69,12 @@ export default function useDownloadModel() { [ignoreSSL, proxy, addDownloadingModel, setDownloadState] ) - const abortModelDownload = async (model: Model) => { - await abortDownload( - await joinPath(['models', model.id, modelBinFileName(model)]) - ) - } + const abortModelDownload = useCallback(async (model: Model) => { + for (const source of model.sources) { + const path = await joinPath(['models', model.id, source.filename]) + await abortDownload(path) + } + }, []) return { downloadModel, diff --git a/web/hooks/useDownloadState.ts b/web/hooks/useDownloadState.ts index 863c612ed..3fdd2cbc6 100644 --- a/web/hooks/useDownloadState.ts +++ b/web/hooks/useDownloadState.ts @@ -1,3 +1,4 @@ +import { DownloadState } from '@janhq/core' import { atom } from 'jotai' import { toaster } from '@/containers/Toast' @@ -20,18 +21,35 @@ export const setDownloadStateAtom = atom( const currentState = { ...get(modelDownloadStateAtom) } if (state.downloadState === 'end') { - // download successfully - delete currentState[state.modelId] - set(removeDownloadingModelAtom, state.modelId) - const model = get(configuredModelsAtom).find( - (e) => e.id === state.modelId + const modelDownloadState = currentState[state.modelId] + + const updatedChildren: DownloadState[] = + modelDownloadState.children!.filter( + (m) => m.fileName !== state.fileName + ) + updatedChildren.push(state) + modelDownloadState.children = updatedChildren + currentState[state.modelId] = modelDownloadState + + const isAllChildrenDownloadEnd = modelDownloadState.children?.every( + (m) => m.downloadState === 'end' ) - if (model) set(downloadedModelsAtom, (prev) => [...prev, model]) - toaster({ - title: 'Download Completed', - description: `Download ${state.modelId} completed`, - type: 'success', - }) + + if (isAllChildrenDownloadEnd) { + // download successfully + delete currentState[state.modelId] + set(removeDownloadingModelAtom, state.modelId) + + const model = get(configuredModelsAtom).find( + (e) => e.id === state.modelId + ) + if (model) set(downloadedModelsAtom, (prev) => [...prev, model]) + toaster({ + title: 'Download Completed', + description: `Download ${state.modelId} completed`, + type: 'success', + }) + } } else if (state.downloadState === 'error') { // download error delete currentState[state.modelId] @@ -59,7 +77,62 @@ export const setDownloadStateAtom = atom( } } else { // download in progress - currentState[state.modelId] = state + if (state.size.total === 0) { + // this is initial state, just set the state + currentState[state.modelId] = state + set(modelDownloadStateAtom, currentState) + return + } + + const modelDownloadState = currentState[state.modelId] + if (!modelDownloadState) { + console.debug('setDownloadStateAtom: modelDownloadState not found') + return + } + + // delete the children if the filename is matched and replace the new state + const updatedChildren: DownloadState[] = + modelDownloadState.children!.filter( + (m) => m.fileName !== state.fileName + ) + + updatedChildren.push(state) + + // re-calculate the overall progress if we have all the children download data + const isAnyChildDownloadNotReady = updatedChildren.some( + (m) => m.size.total === 0 + ) + + modelDownloadState.children = updatedChildren + + if (isAnyChildDownloadNotReady) { + // just update the children + currentState[state.modelId] = modelDownloadState + set(modelDownloadStateAtom, currentState) + + return + } + + const parentTotalSize = modelDownloadState.size.total + if (parentTotalSize === 0) { + // calculate the total size of the parent by sum all children total size + const totalSize = updatedChildren.reduce( + (acc, m) => acc + m.size.total, + 0 + ) + + modelDownloadState.size.total = totalSize + } + + // calculate the total transferred size by sum all children transferred size + const transferredSize = updatedChildren.reduce( + (acc, m) => acc + m.size.transferred, + 0 + ) + modelDownloadState.size.transferred = transferredSize + modelDownloadState.percent = transferredSize / parentTotalSize + + currentState[state.modelId] = modelDownloadState } set(modelDownloadStateAtom, currentState) diff --git a/web/screens/ExploreModels/ExploreModelItem/index.tsx b/web/screens/ExploreModels/ExploreModelItem/index.tsx index 553c73a49..9cdfbc01a 100644 --- a/web/screens/ExploreModels/ExploreModelItem/index.tsx +++ b/web/screens/ExploreModels/ExploreModelItem/index.tsx @@ -1,6 +1,4 @@ -/* eslint-disable react/display-name */ - -import { forwardRef, useState } from 'react' +import { useState } from 'react' import { Model } from '@janhq/core' import { Badge } from '@janhq/uikit' @@ -11,7 +9,7 @@ type Props = { model: Model } -const ExploreModelItem = forwardRef(({ model }, ref) => { +const ExploreModelItem: React.FC = ({ model }) => { const [open, setOpen] = useState('') const handleToggle = () => { @@ -23,10 +21,7 @@ const ExploreModelItem = forwardRef(({ model }, ref) => { } return ( -
+
(({ model }, ref) => {

{model.format}

- {/*
- - Compatibility - -

-

-
*/}
)} ) -}) +} export default ExploreModelItem diff --git a/web/screens/ExploreModels/ExploreModelItemHeader/index.tsx b/web/screens/ExploreModels/ExploreModelItemHeader/index.tsx index cf8c68821..7af5d3d97 100644 --- a/web/screens/ExploreModels/ExploreModelItemHeader/index.tsx +++ b/web/screens/ExploreModels/ExploreModelItemHeader/index.tsx @@ -1,5 +1,4 @@ -/* eslint-disable react-hooks/exhaustive-deps */ -import { useCallback, useMemo } from 'react' +import { useCallback } from 'react' import { Model } from '@janhq/core' import { @@ -12,7 +11,7 @@ import { TooltipTrigger, } from '@janhq/uikit' -import { atom, useAtomValue } from 'jotai' +import { useAtomValue } from 'jotai' import { ChevronDownIcon } from 'lucide-react' @@ -25,8 +24,6 @@ import { MainViewState } from '@/constants/screens' import { useCreateNewThread } from '@/hooks/useCreateNewThread' import useDownloadModel from '@/hooks/useDownloadModel' -import { modelDownloadStateAtom } from '@/hooks/useDownloadState' - import { useMainViewState } from '@/hooks/useMainViewState' import { toGibibytes } from '@/utils/converter' @@ -34,7 +31,10 @@ import { toGibibytes } from '@/utils/converter' import { assistantsAtom } from '@/helpers/atoms/Assistant.atom' import { serverEnabledAtom } from '@/helpers/atoms/LocalServer.atom' -import { downloadedModelsAtom } from '@/helpers/atoms/Model.atom' +import { + downloadedModelsAtom, + getDownloadingModelAtom, +} from '@/helpers/atoms/Model.atom' import { nvidiaTotalVramAtom, totalRamAtom, @@ -46,12 +46,32 @@ type Props = { open: string } +const getLabel = (size: number, ram: number) => { + if (size * 1.25 >= ram) { + return ( + + Not enough RAM + + ) + } else { + return ( + + Recommended + + ) + } +} + const ExploreModelItemHeader: React.FC = ({ model, onClick, open }) => { const { downloadModel } = useDownloadModel() + const downloadingModels = useAtomValue(getDownloadingModelAtom) const downloadedModels = useAtomValue(downloadedModelsAtom) const { requestCreateNewThread } = useCreateNewThread() const totalRam = useAtomValue(totalRamAtom) + const nvidiaTotalVram = useAtomValue(nvidiaTotalVramAtom) + const { setMainViewState } = useMainViewState() + // Default nvidia returns vram in MB, need to convert to bytes to match the unit of totalRamW let ram = nvidiaTotalVram * 1024 * 1024 if (ram === 0) { @@ -60,16 +80,9 @@ const ExploreModelItemHeader: React.FC = ({ model, onClick, open }) => { const serverEnabled = useAtomValue(serverEnabledAtom) const assistants = useAtomValue(assistantsAtom) - const downloadAtom = useMemo( - () => atom((get) => get(modelDownloadStateAtom)[model.id]), - [model.id] - ) - const downloadState = useAtomValue(downloadAtom) - const { setMainViewState } = useMainViewState() - const onDownloadClick = useCallback(() => { downloadModel(model) - }, [model]) + }, [model, downloadModel]) const isDownloaded = downloadedModels.find((md) => md.id === model.id) != null @@ -85,6 +98,8 @@ const ExploreModelItemHeader: React.FC = ({ model, onClick, open }) => { ) + const isDownloading = downloadingModels.some((md) => md.id === model.id) + const onUseModelClick = useCallback(async () => { if (assistants.length === 0) { alert('No assistant available') @@ -92,7 +107,7 @@ const ExploreModelItemHeader: React.FC = ({ model, onClick, open }) => { } await requestCreateNewThread(assistants[0], model) setMainViewState(MainViewState.Thread) - }, []) + }, [assistants, model, requestCreateNewThread, setMainViewState]) if (isDownloaded) { downloadButton = ( @@ -117,26 +132,10 @@ const ExploreModelItemHeader: React.FC = ({ model, onClick, open }) => { )} ) - } else if (downloadState != null) { + } else if (isDownloading) { downloadButton = } - const getLabel = (size: number) => { - if (size * 1.25 >= ram) { - return ( - - Not enough RAM - - ) - } else { - return ( - - Recommended - - ) - } - } - return (
= ({ model, onClick, open }) => { {toGibibytes(model.metadata.size)} - {getLabel(model.metadata.size)} + {getLabel(model.metadata.size, ram)} {downloadButton} { - const modelFormatExt = '.gguf' - const extractedFileName = model.sources[0]?.url.split('/').pop() ?? model.id - const fileName = extractedFileName.toLowerCase().endsWith(modelFormatExt) - ? extractedFileName - : model.id - return fileName -}