diff --git a/core/src/types/api/index.ts b/core/src/types/api/index.ts index c0de0f5e8..093314a15 100644 --- a/core/src/types/api/index.ts +++ b/core/src/types/api/index.ts @@ -69,11 +69,11 @@ export enum DownloadRoute { } export enum DownloadEvent { - onFileDownloadUpdate = 'DownloadUpdated', - onFileDownloadError = 'DownloadError', - onFileDownloadSuccess = 'DownloadSuccess', - onFileDownloadStopped = 'DownloadStopped', - onFileDownloadStarted = 'DownloadStarted', + onFileDownloadUpdate = 'onFileDownloadUpdate', + onFileDownloadError = 'onFileDownloadError', + onFileDownloadSuccess = 'onFileDownloadSuccess', + onFileDownloadStopped = 'onFileDownloadStopped', + onFileDownloadStarted = 'onFileDownloadStarted', onFileUnzipSuccess = 'onFileUnzipSuccess', } diff --git a/extensions/model-extension/src/cortex.ts b/extensions/model-extension/src/cortex.ts index b0acd6d08..c690f0c16 100644 --- a/extensions/model-extension/src/cortex.ts +++ b/extensions/model-extension/src/cortex.ts @@ -25,6 +25,14 @@ type ModelList = { data: any[] } +enum DownloadTypes { + DownloadUpdated = 'onFileDownloadUpdate', + DownloadError = 'onFileDownloadError', + DownloadSuccess = 'onFileDownloadSuccess', + DownloadStopped = 'onFileDownloadStopped', + DownloadStarted = 'onFileDownloadStarted', +} + export class CortexAPI implements ICortexAPI { queue = new PQueue({ concurrency: 1 }) socket?: WebSocket = undefined @@ -159,17 +167,16 @@ export class CortexAPI implements ICortexAPI { this.socket.addEventListener('message', (event) => { const data = JSON.parse(event.data) const transferred = data.task.items.reduce( - (accumulator, currentValue) => - accumulator + currentValue.downloadedBytes, + (acc, cur) => acc + cur.downloadedBytes, 0 ) const total = data.task.items.reduce( - (accumulator, currentValue) => accumulator + currentValue.bytes, + (acc, cur) => acc + cur.bytes, 0 ) const percent = (transferred / total || 0) * 100 - events.emit(data.type, { + events.emit(DownloadTypes[data.type], {
modelId: data.task.id, percent: percent, size: { @@ -178,7 +185,7 @@ export class CortexAPI implements ICortexAPI { }, }) // Update models list from Hub - if (data.type === DownloadEvent.onFileDownloadSuccess) { + if (DownloadTypes[data.type] === DownloadTypes.DownloadSuccess) { // Delay for the state update from cortex.cpp // Just to be sure setTimeout(() => { diff --git a/extensions/model-extension/src/index.ts b/extensions/model-extension/src/index.ts index 54e91a6aa..3696acd79 100644 --- a/extensions/model-extension/src/index.ts +++ b/extensions/model-extension/src/index.ts @@ -4,9 +4,15 @@ import { InferenceEngine, joinPath, dirName, + ModelManager, + abortDownload, + DownloadState, + events, + DownloadEvent, } from '@janhq/core' import { CortexAPI } from './cortex' -import { scanModelsFolder } from './model-json' +import { scanModelsFolder } from './legacy/model-json' +import { downloadModel } from './legacy/download' declare const SETTINGS: Array @@ -34,6 +40,9 @@ export default class JanModelExtension extends ModelExtension { this.getModels().then((models) => { this.registerModels(models) }) + + // Listen to app download events + this.handleDesktopEvents() } /** @@ -48,6 +57,17 @@ export default class JanModelExtension extends ModelExtension { * @returns A Promise that resolves when the model is downloaded. */ async pullModel(model: string, id?: string): Promise { + if (id) { + const modelDto: Model = ModelManager.instance().get(id) + // Clip vision model - should not be handled by cortex.cpp + // TensorRT model - should not be handled by cortex.cpp + if ( + modelDto?.engine === InferenceEngine.nitro_tensorrt_llm || + modelDto?.settings?.vision_model + ) { + return downloadModel(modelDto) + } + } /** * Sending POST to /models/pull/{id} endpoint to pull the model */ @@ -61,10 +81,24 @@ export default class JanModelExtension extends ModelExtension { * @returns {Promise} A promise that resolves when the download has been cancelled.
*/ async cancelModelPull(model: string): Promise { + if (model) { + const modelDto: Model = ModelManager.instance().get(model) + // Clip vision model - should not be handled by cortex.cpp + // TensorRT model - should not be handled by cortex.cpp + if ( + modelDto?.engine === InferenceEngine.nitro_tensorrt_llm || + modelDto?.settings?.vision_model + ) { + for (const source of modelDto.sources) { + await abortDownload(await joinPath(['models', modelDto.id, source.filename])) + } + if (modelDto.sources.length > 0) return + } + } /** * Sending DELETE to /models/pull/{id} endpoint to cancel a model pull */ - this.cortexAPI.cancelModelPull(model) + return this.cortexAPI.cancelModelPull(model) } /** @@ -87,14 +121,18 @@ export default class JanModelExtension extends ModelExtension { * should compare and try import */ let currentModels: Model[] = [] + + /** + * Legacy models should be supported + */ + let legacyModels = await scanModelsFolder() + try { if (!localStorage.getItem(ExtensionEnum.downloadedModels)) { // Updated from an older version than 0.5.5 // Scan through the models folder and import them (Legacy flow) // Return models immediately - currentModels = await scanModelsFolder().then((models) => { - return models ??
[] - }) + currentModels = legacyModels } else { currentModels = JSON.parse( localStorage.getItem(ExtensionEnum.downloadedModels) @@ -116,7 +154,7 @@ export default class JanModelExtension extends ModelExtension { await this.cortexAPI.getModels().then((models) => { const existingIds = models.map((e) => e.id) toImportModels = toImportModels.filter( - (e: Model) => !existingIds.includes(e.id) + (e: Model) => !existingIds.includes(e.id) && !e.settings?.vision_model ) }) @@ -147,13 +185,15 @@ export default class JanModelExtension extends ModelExtension { } /** - * All models are imported successfully before - * just return models from cortex.cpp + * Models are imported successfully before + * Now return models from cortex.cpp and merge with legacy models which are not imported */ return ( this.cortexAPI.getModels().then((models) => { - return models - }) ?? Promise.resolve([]) + return models.concat( + legacyModels.filter((e) => !models.some((x) => x.id === e.id)) + ) + }) ?? Promise.resolve(legacyModels) ) } @@ -175,4 +215,31 @@ export default class JanModelExtension extends ModelExtension { async importModel(model: string, modelPath: string): Promise { return this.cortexAPI.importModel(model, modelPath) } + + /** + * Handle download state from main app + */ + handleDesktopEvents() { + if (typeof window !== 'undefined' && window.electronAPI) { + window.electronAPI.onFileDownloadUpdate( + async (_event: string, state: DownloadState | undefined) => { + if (!state) return + state.downloadState = 'downloading' + events.emit(DownloadEvent.onFileDownloadUpdate, state) + } + ) + window.electronAPI.onFileDownloadError( + async (_event: string, state: DownloadState) => { + state.downloadState = 'error' + events.emit(DownloadEvent.onFileDownloadError, state) + } + ) + window.electronAPI.onFileDownloadSuccess( + async (_event: string, state: DownloadState) => { + state.downloadState = 'end' + events.emit(DownloadEvent.onFileDownloadSuccess, state) + } + ) + } + } } diff --git
a/extensions/model-extension/src/legacy/download.ts b/extensions/model-extension/src/legacy/download.ts new file mode 100644 index 000000000..a1a998daf --- /dev/null +++ b/extensions/model-extension/src/legacy/download.ts @@ -0,0 +1,97 @@ +import { + downloadFile, + DownloadRequest, + fs, + GpuSetting, + InferenceEngine, + joinPath, + Model, +} from '@janhq/core' + +export const downloadModel = async ( + model: Model, + gpuSettings?: GpuSetting, + network?: { ignoreSSL?: boolean; proxy?: string } +): Promise<void> => { + const homedir = 'file://models' + const supportedGpuArch = ['ampere', 'ada'] + // Create corresponding directory + const modelDirPath = await joinPath([homedir, model.id]) + if (!(await fs.existsSync(modelDirPath))) await fs.mkdir(modelDirPath) + + if (model.engine === InferenceEngine.nitro_tensorrt_llm) { + if (!gpuSettings || gpuSettings.gpus.length === 0) { + console.error('No GPU found. Please check your GPU setting.') + return + } + const firstGpu = gpuSettings.gpus[0] + if (!firstGpu.name.toLowerCase().includes('nvidia')) { + console.error('No Nvidia GPU found. Please check your GPU setting.') + return + } + const gpuArch = firstGpu.arch + if (gpuArch === undefined) { + console.error('No GPU architecture found. Please check your GPU setting.') + return + } + + if (!supportedGpuArch.includes(gpuArch)) { + console.debug( + `Your GPU: ${JSON.stringify(firstGpu)} is not supported.
Only 30xx, 40xx series are supported.` + ) + return + } + + const os = 'windows' // TODO: remove this hard coded value + + const newSources = model.sources.map((source) => { + const newSource = { ...source } + newSource.url = newSource.url + .replace(/<os>/g, os) + .replace(/<gpuarch>/g, gpuArch) + return newSource + }) + model.sources = newSources + } + + console.debug(`Download sources: ${JSON.stringify(model.sources)}`) + + if (model.sources.length > 1) { + // path to model binaries + for (const source of model.sources) { + const path = await joinPath([ + modelDirPath, + source.filename || extractFileName(source.url, '.gguf'), + ]) + + const downloadRequest: DownloadRequest = { + url: source.url, + localPath: path, + modelId: model.id, + } + downloadFile(downloadRequest, network) + } + } else { + const fileName = extractFileName(model.sources[0]?.url, '.gguf') + const path = await joinPath([modelDirPath, fileName]) + const downloadRequest: DownloadRequest = { + url: model.sources[0]?.url, + localPath: path, + modelId: model.id, + } + downloadFile(downloadRequest, network) + } +} + +/** + * try to retrieve the download file name from the source url + */ +function extractFileName(url: string, fileExtension: string): string { + if (!url) return fileExtension + + const extractedFileName = url.split('/').pop() + const fileName = extractedFileName.toLowerCase().endsWith(fileExtension) + ?
extractedFileName + : extractedFileName + fileExtension + return fileName +} diff --git a/extensions/model-extension/src/model-json.test.ts b/extensions/model-extension/src/legacy/model-json.test.ts similarity index 100% rename from extensions/model-extension/src/model-json.test.ts rename to extensions/model-extension/src/legacy/model-json.test.ts diff --git a/extensions/model-extension/src/model-json.ts b/extensions/model-extension/src/legacy/model-json.ts similarity index 97% rename from extensions/model-extension/src/model-json.ts rename to extensions/model-extension/src/legacy/model-json.ts index 46eee3482..646ae85d7 100644 --- a/extensions/model-extension/src/model-json.ts +++ b/extensions/model-extension/src/legacy/model-json.ts @@ -71,7 +71,7 @@ export const scanModelsFolder = async (): Promise => { file.toLowerCase().endsWith('.gguf') || // GGUF file.toLowerCase().endsWith('.engine') // Tensort-LLM ) - })?.length > 0 // TODO: find better way (can use basename to check the file name with source url) + })?.length >= (model.sources?.length ?? 
1) ) }) diff --git a/web/containers/Providers/EventListener.tsx b/web/containers/Providers/EventListener.tsx index 5df59b0fd..37711ee0d 100644 --- a/web/containers/Providers/EventListener.tsx +++ b/web/containers/Providers/EventListener.tsx @@ -50,7 +50,7 @@ const EventListenerWrapper = ({ children }: PropsWithChildren) => { setDownloadState(state) } }, - [setDownloadState, setInstallingExtension] + [addDownloadingModel, setDownloadState, setInstallingExtension] ) const onFileDownloadError = useCallback( @@ -64,7 +64,7 @@ const EventListenerWrapper = ({ children }: PropsWithChildren) => { removeDownloadingModel(state.modelId) } }, - [setDownloadState, removeInstallingExtension] + [removeInstallingExtension, setDownloadState, removeDownloadingModel] ) const onFileDownloadStopped = useCallback( @@ -79,7 +79,7 @@ const EventListenerWrapper = ({ children }: PropsWithChildren) => { removeDownloadingModel(state.modelId) } }, - [setDownloadState, removeInstallingExtension] + [removeInstallingExtension, setDownloadState, removeDownloadingModel] ) const onFileDownloadSuccess = useCallback( @@ -92,7 +92,7 @@ const EventListenerWrapper = ({ children }: PropsWithChildren) => { } events.emit(ModelEvent.OnModelsUpdate, {}) }, - [setDownloadState] + [removeDownloadingModel, setDownloadState] ) const onFileUnzipSuccess = useCallback( @@ -121,7 +121,7 @@ const EventListenerWrapper = ({ children }: PropsWithChildren) => { events.off(DownloadEvent.onFileDownloadUpdate, onFileDownloadUpdate) events.off(DownloadEvent.onFileDownloadError, onFileDownloadError) events.off(DownloadEvent.onFileDownloadSuccess, onFileDownloadSuccess) - events.off(DownloadEvent.onFileDownloadSuccess, onFileDownloadSuccess) + events.off(DownloadEvent.onFileDownloadStopped, onFileDownloadStopped) events.off(DownloadEvent.onFileUnzipSuccess, onFileUnzipSuccess) } }, [