diff --git a/core/src/api/index.ts b/core/src/api/index.ts index c7dd9146e..7fb8eeb38 100644 --- a/core/src/api/index.ts +++ b/core/src/api/index.ts @@ -49,7 +49,7 @@ export enum DownloadEvent { export enum LocalImportModelEvent { onLocalImportModelUpdate = 'onLocalImportModelUpdate', - onLocalImportModelError = 'onLocalImportModelError', + onLocalImportModelFailed = 'onLocalImportModelFailed', onLocalImportModelSuccess = 'onLocalImportModelSuccess', onLocalImportModelFinished = 'onLocalImportModelFinished', } diff --git a/core/src/core.ts b/core/src/core.ts index 8831c6001..6e2442c2b 100644 --- a/core/src/core.ts +++ b/core/src/core.ts @@ -65,7 +65,7 @@ const joinPath: (paths: string[]) => Promise = (paths) => global.core.ap * @param path - The path to retrieve. * @returns {Promise} A promise that resolves with the basename. */ -const baseName: (paths: string[]) => Promise = (path) => global.core.api?.baseName(path) +const baseName: (paths: string) => Promise = (path) => global.core.api?.baseName(path) /** * Opens an external URL in the default web browser. diff --git a/core/src/types/model/modelImport.ts b/core/src/types/model/modelImport.ts index 8977c42a0..7c72a691b 100644 --- a/core/src/types/model/modelImport.ts +++ b/core/src/types/model/modelImport.ts @@ -19,4 +19,5 @@ export type ImportingModel = { status: ImportingModelStatus format: string percentage?: number + error?: string } diff --git a/extensions/model-extension/src/index.ts b/extensions/model-extension/src/index.ts index dd5bcdf26..fb1f26885 100644 --- a/extensions/model-extension/src/index.ts +++ b/extensions/model-extension/src/index.ts @@ -16,6 +16,7 @@ import { OptionType, ImportingModel, LocalImportModelEvent, + baseName, } from '@janhq/core' import { extractFileName } from './helpers/path' @@ -488,7 +489,7 @@ export default class JanModelExtension extends ModelExtension { return } - const binaryFileName = extractFileName(modelBinaryPath, '') + const binaryFileName = await baseName(modelBinaryPath) const model: Model = { ...defaultModel, @@ -555,7 +556,7 @@ export default class JanModelExtension extends ModelExtension { model: ImportingModel, optionType: OptionType ): Promise { - const binaryName = extractFileName(model.path, '').replace(/\s/g, '') + const binaryName = (await baseName(model.path)).replace(/\s/g, '') let modelFolderName = binaryName if (binaryName.endsWith(JanModelExtension._supportedModelFormat)) { @@ -568,7 +569,7 @@ export default class JanModelExtension extends ModelExtension { const modelFolderPath = await this.getModelFolderName(modelFolderName) await fs.mkdirSync(modelFolderPath) - const uniqueFolderName = modelFolderPath.split('/').pop() + const uniqueFolderName = await baseName(modelFolderPath) const modelBinaryFile = binaryName.endsWith( JanModelExtension._supportedModelFormat ) @@ -637,14 +638,21 @@ export default class JanModelExtension extends ModelExtension { for (const model of models) { events.emit(LocalImportModelEvent.onLocalImportModelUpdate, model) - const importedModel = await this.importModel(model, optionType) - - events.emit(LocalImportModelEvent.onLocalImportModelSuccess, { - ...model, - modelId: importedModel.id, - }) - importedModels.push(importedModel) + try { + const importedModel = await this.importModel(model, optionType) + events.emit(LocalImportModelEvent.onLocalImportModelSuccess, { + ...model, + modelId: importedModel.id, + }) + importedModels.push(importedModel) + } catch (err) { + events.emit(LocalImportModelEvent.onLocalImportModelFailed, { + ...model, + error: err, + }) + } } + events.emit( LocalImportModelEvent.onLocalImportModelFinished, importedModels diff --git a/web/containers/Providers/ModelImportListener.tsx b/web/containers/Providers/ModelImportListener.tsx index 60347ba40..f1ca2a768 100644 --- a/web/containers/Providers/ModelImportListener.tsx +++ b/web/containers/Providers/ModelImportListener.tsx @@ -12,6 +12,7 @@ import { useSetAtom } from 'jotai' import { snackbar } from '../Toast' import { + setImportingModelErrorAtom, setImportingModelSuccessAtom, updateImportingModelProgressAtom, } from '@/helpers/atoms/Model.atom' @@ -21,6 +22,7 @@ const ModelImportListener = ({ children }: PropsWithChildren) => { updateImportingModelProgressAtom ) const setImportingModelSuccess = useSetAtom(setImportingModelSuccessAtom) + const setImportingModelFailed = useSetAtom(setImportingModelErrorAtom) const onImportModelUpdate = useCallback( async (state: ImportingModel) => { @@ -30,6 +32,14 @@ const ModelImportListener = ({ children }: PropsWithChildren) => { [updateImportingModelProgress] ) + const onImportModelFailed = useCallback( + async (state: ImportingModel) => { + if (!state.importId) return + setImportingModelFailed(state.importId, state.error ?? '') + }, + [setImportingModelFailed] + ) + const onImportModelSuccess = useCallback( (state: ImportingModel) => { if (!state.modelId) return @@ -62,6 +72,10 @@ const ModelImportListener = ({ children }: PropsWithChildren) => { LocalImportModelEvent.onLocalImportModelFinished, onImportModelFinished ) + events.on( + LocalImportModelEvent.onLocalImportModelFailed, + onImportModelFailed + ) return () => { console.debug('ModelImportListener: unregistering event listeners...') @@ -77,8 +91,17 @@ const ModelImportListener = ({ children }: PropsWithChildren) => { LocalImportModelEvent.onLocalImportModelFinished, onImportModelFinished ) + events.off( + LocalImportModelEvent.onLocalImportModelFailed, + onImportModelFailed + ) } - }, [onImportModelUpdate, onImportModelSuccess, onImportModelFinished]) + }, [ + onImportModelUpdate, + onImportModelSuccess, + onImportModelFinished, + onImportModelFailed, + ]) return {children} } diff --git a/web/helpers/atoms/Model.atom.ts b/web/helpers/atoms/Model.atom.ts index 7a6aa6440..da6dc5918 100644 --- a/web/helpers/atoms/Model.atom.ts +++ b/web/helpers/atoms/Model.atom.ts @@ -67,6 +67,24 @@ export const updateImportingModelProgressAtom = atom( } ) +export const setImportingModelErrorAtom = atom( + null, + (get, set, importId: string, error: string) => { + const model = get(importingModelsAtom).find((x) => x.importId === importId) + if (!model) return + const newModel: ImportingModel = { + ...model, + status: 'FAILED', + } + + console.error(`Importing model ${model} failed`, error) + const newList = get(importingModelsAtom).map((m) => + m.importId === importId ? newModel : m + ) + set(importingModelsAtom, newList) + } +) + export const setImportingModelSuccessAtom = atom( null, (get, set, importId: string, modelId: string) => { diff --git a/web/hooks/useDropModelBinaries.ts b/web/hooks/useDropModelBinaries.ts new file mode 100644 index 000000000..c08e1dc73 --- /dev/null +++ b/web/hooks/useDropModelBinaries.ts @@ -0,0 +1,55 @@ +import { useCallback } from 'react' + +import { ImportingModel } from '@janhq/core' +import { useSetAtom } from 'jotai' + +import { v4 as uuidv4 } from 'uuid' + +import { snackbar } from '@/containers/Toast' + +import { getFileInfoFromFile } from '@/utils/file' + +import { setImportModelStageAtom } from './useImportModel' + +import { importingModelsAtom } from '@/helpers/atoms/Model.atom' + +export default function useDropModelBinaries() { + const setImportingModels = useSetAtom(importingModelsAtom) + const setImportModelStage = useSetAtom(setImportModelStageAtom) + + const onDropModels = useCallback( + async (acceptedFiles: File[]) => { + const files = await getFileInfoFromFile(acceptedFiles) + + const unsupportedFiles = files.filter( + (file) => !file.path.endsWith('.gguf') + ) + const supportedFiles = files.filter((file) => file.path.endsWith('.gguf')) + + const importingModels: ImportingModel[] = supportedFiles.map((file) => ({ + importId: uuidv4(), + modelId: undefined, + name: file.name.replace('.gguf', ''), + description: '', + path: file.path, + tags: [], + size: file.size, + status: 'PREPARING', + format: 'gguf', + })) + if (unsupportedFiles.length > 0) { + snackbar({ + description: `File has to be a .gguf file`, + type: 'error', + }) + } + if (importingModels.length === 0) return + + setImportingModels(importingModels) + setImportModelStage('MODEL_SELECTED') + }, + [setImportModelStage, setImportingModels] + ) + + return { onDropModels } +} diff --git a/web/screens/Settings/EditModelInfoModal/index.tsx b/web/screens/Settings/EditModelInfoModal/index.tsx index bb87b7ed9..bc9d6521d 100644 --- a/web/screens/Settings/EditModelInfoModal/index.tsx +++ b/web/screens/Settings/EditModelInfoModal/index.tsx @@ -1,6 +1,12 @@ -import { useCallback, useEffect, useMemo, useState } from 'react' +import { useCallback, useEffect, useState } from 'react' -import { Model, ModelEvent, events, openFileExplorer } from '@janhq/core' +import { + Model, + ModelEvent, + events, + joinPath, + openFileExplorer, +} from '@janhq/core' import { Modal, ModalContent, @@ -47,6 +53,7 @@ const EditModelInfoModal: React.FC = () => { const janDataFolder = useAtomValue(janDataFolderPathAtom) const updateImportingModel = useSetAtom(updateImportingModelAtom) const { updateModelInfo } = useImportModel() + const [modelPath, setModelPath] = useState('') const editingModel = importingModels.find( (model) => model.importId === editingModelId @@ -88,13 +95,19 @@ const EditModelInfoModal: React.FC = () => { setEditingModelId(undefined) } - const modelFolderPath = useMemo(() => { - return `${janDataFolder}/models/${editingModel?.modelId}` + useEffect(() => { + const getModelPath = async () => { + const modelId = editingModel?.modelId + if (!modelId) return '' + const path = await joinPath([janDataFolder, 'models', modelId]) + setModelPath(path) + } + getModelPath() }, [janDataFolder, editingModel]) const onShowInFinderClick = useCallback(() => { - openFileExplorer(modelFolderPath) - }, [modelFolderPath]) + openFileExplorer(modelPath) + }, [modelPath]) if (!editingModel) { setImportModelStage('IMPORTING_MODEL') @@ -104,7 +117,10 @@ const EditModelInfoModal: React.FC = () => { } return ( - + Edit Model Information @@ -130,7 +146,7 @@ const EditModelInfoModal: React.FC = () => {
- {modelFolderPath} + {modelPath}