diff --git a/core/src/browser/extensions/engines/LocalOAIEngine.ts b/core/src/browser/extensions/engines/LocalOAIEngine.ts index e8bd8cdf2..b54f8fbde 100644 --- a/core/src/browser/extensions/engines/LocalOAIEngine.ts +++ b/core/src/browser/extensions/engines/LocalOAIEngine.ts @@ -1,4 +1,4 @@ -import { executeOnMain, systemInformation, dirName } from '../../core' +import { executeOnMain, systemInformation, dirName, joinPath, getJanDataFolderPath } from '../../core' import { events } from '../../events' import { Model, ModelEvent } from '../../../types' import { OAIEngine } from './OAIEngine' @@ -29,13 +29,46 @@ export abstract class LocalOAIEngine extends OAIEngine { /** * Load the model. */ - override async loadModel(model: Model): Promise { - return Promise.resolve() + override async loadModel(model: Model & { file_path?: string }): Promise { + if (model.engine.toString() !== this.provider) return + const modelFolder = 'file_path' in model && model.file_path ? await dirName(model.file_path) : await this.getModelFilePath(model.id) + const systemInfo = await systemInformation() + const res = await executeOnMain( + this.nodeModule, + this.loadModelFunctionName, + { + modelFolder, + model, + }, + systemInfo + ) + + if (res?.error) { + events.emit(ModelEvent.OnModelFail, { error: res.error }) + return Promise.reject(res.error) + } else { + this.loadedModel = model + events.emit(ModelEvent.OnModelReady, model) + return Promise.resolve() + } } /** * Stops the model. */ override async unloadModel(model?: Model) { - return Promise.resolve() + if (model?.engine && model.engine?.toString() !== this.provider) return Promise.resolve() + + this.loadedModel = undefined + await executeOnMain(this.nodeModule, this.unloadModelFunctionName).then(() => { + events.emit(ModelEvent.OnModelStopped, {}) + }) } + + /// Legacy + private getModelFilePath = async ( + id: string, + ): Promise => { + return joinPath([await getJanDataFolderPath(), 'models', id]) + } + /// } diff --git a/extensions/inference-cortex-extension/src/index.ts b/extensions/inference-cortex-extension/src/index.ts index 8143a71cf..45f0e5fe0 100644 --- a/extensions/inference-cortex-extension/src/index.ts +++ b/extensions/inference-cortex-extension/src/index.ts @@ -118,19 +118,6 @@ export default class JanInferenceCortexExtension extends LocalOAIEngine { .then() } - private async modelPath( - model: Model & { file_path?: string } - ): Promise { - if (!model.file_path) return model.id - return await joinPath([ - await dirName(model.file_path), - model.sources[0]?.filename ?? - model.settings?.llama_model_path ?? - model.sources[0]?.url.split('/').pop() ?? - model.id, - ]) - } - /** * Do health check on cortex.cpp * @returns diff --git a/extensions/model-extension/src/index.ts b/extensions/model-extension/src/index.ts index 3e0af0172..a42fc2a52 100644 --- a/extensions/model-extension/src/index.ts +++ b/extensions/model-extension/src/index.ts @@ -14,6 +14,7 @@ import { CortexAPI } from './cortex' import { scanModelsFolder } from './legacy/model-json' import { downloadModel } from './legacy/download' import { systemInformation } from '@janhq/core' +import { deleteModelFiles } from './legacy/delete' declare const SETTINGS: Array @@ -50,7 +51,7 @@ export default class JanModelExtension extends ModelExtension { * Called when the extension is unloaded. * @override */ - async onUnload() {} + async onUnload() { } /** * Downloads a machine learning model. @@ -92,7 +93,7 @@ export default class JanModelExtension extends ModelExtension { ) { for (const source of modelDto.sources) { const path = await joinPath(['models', modelDto.id, source.filename]) - return abortDownload(path) + await abortDownload(path) } } } @@ -108,7 +109,14 @@ export default class JanModelExtension extends ModelExtension { * @returns A Promise that resolves when the model is deleted. */ async deleteModel(model: string): Promise { + const modelDto: Model = ModelManager.instance().get(model) return this.cortexAPI.deleteModel(model) + .catch(e => console.debug(e)) + .finally(async () => { + // Delete legacy model files + await deleteModelFiles(modelDto) + .catch(e => console.debug(e)) + }) } /** @@ -174,9 +182,9 @@ export default class JanModelExtension extends ModelExtension { await joinPath([ await dirName(model.file_path), model.sources[0]?.filename ?? - model.settings?.llama_model_path ?? - model.sources[0]?.url.split('/').pop() ?? - model.id, + model.settings?.llama_model_path ?? + model.sources[0]?.url.split('/').pop() ?? + model.id, ]) ) ) diff --git a/extensions/model-extension/src/legacy/delete.ts b/extensions/model-extension/src/legacy/delete.ts new file mode 100644 index 000000000..a46d90ea5 --- /dev/null +++ b/extensions/model-extension/src/legacy/delete.ts @@ -0,0 +1,18 @@ +import { fs, joinPath, Model } from "@janhq/core" + +export const deleteModelFiles = async (model: Model) => { + try { + const dirPath = await joinPath(['file://models', model.id]) + + // remove all files under dirPath except model.json + const files = await fs.readdirSync(dirPath) + const deletePromises = files.map(async (fileName: string) => { + if (fileName !== 'model.json') { + return fs.unlinkSync(await joinPath([dirPath, fileName])) + } + }) + await Promise.allSettled(deletePromises) + } catch (err) { + console.error(err) + } +} \ No newline at end of file diff --git a/extensions/model-extension/src/legacy/model-json.ts b/extensions/model-extension/src/legacy/model-json.ts index 646ae85d7..c47b7c661 100644 --- a/extensions/model-extension/src/legacy/model-json.ts +++ b/extensions/model-extension/src/legacy/model-json.ts @@ -1,4 +1,4 @@ -import { Model, fs, joinPath } from '@janhq/core' +import { InferenceEngine, Model, fs, joinPath } from '@janhq/core' //// LEGACY MODEL FOLDER //// /** * Scan through models folder and return downloaded models @@ -71,7 +71,7 @@ export const scanModelsFolder = async (): Promise => { file.toLowerCase().endsWith('.gguf') || // GGUF file.toLowerCase().endsWith('.engine') // Tensort-LLM ) - })?.length >= (model.sources?.length ?? 1) + })?.length >= (model.engine === InferenceEngine.nitro_tensorrt_llm ? 1 : (model.sources?.length ?? 1)) ) }) diff --git a/web/containers/Providers/EventListener.tsx b/web/containers/Providers/EventListener.tsx index 37711ee0d..af91b6027 100644 --- a/web/containers/Providers/EventListener.tsx +++ b/web/containers/Providers/EventListener.tsx @@ -88,7 +88,8 @@ const EventListenerWrapper = ({ children }: PropsWithChildren) => { if (state.downloadType !== 'extension') { state.downloadState = 'end' setDownloadState(state) - removeDownloadingModel(state.modelId) + if (state.percent !== 0) + removeDownloadingModel(state.modelId) } events.emit(ModelEvent.OnModelsUpdate, {}) }, diff --git a/web/hooks/useDownloadState.ts b/web/hooks/useDownloadState.ts index 59267749e..9aaa00bc4 100644 --- a/web/hooks/useDownloadState.ts +++ b/web/hooks/useDownloadState.ts @@ -108,6 +108,7 @@ export const setDownloadStateAtom = atom( ) modelDownloadState.children = updatedChildren + if (isAnyChildDownloadNotReady) { // just update the children currentState[state.modelId] = modelDownloadState @@ -115,23 +116,17 @@ export const setDownloadStateAtom = atom( return } - const parentTotalSize = modelDownloadState.size.total - if (parentTotalSize === 0) { - // calculate the total size of the parent by sum all children total size - const totalSize = updatedChildren.reduce( - (acc, m) => acc + m.size.total, - 0 - ) - - modelDownloadState.size.total = totalSize - } - + const parentTotalSize = updatedChildren.reduce( + (acc, m) => acc + m.size.total, + 0 + ) // calculate the total transferred size by sum all children transferred size const transferredSize = updatedChildren.reduce( (acc, m) => acc + m.size.transferred, 0 ) modelDownloadState.size.transferred = transferredSize + modelDownloadState.percent = parentTotalSize === 0 ? 0 : transferredSize / parentTotalSize currentState[state.modelId] = modelDownloadState