From fbfaaf43c5b1b29001a4aba2f151d9f21975de5d Mon Sep 17 00:00:00 2001 From: Thien Tran Date: Thu, 29 May 2025 17:32:09 +0800 Subject: [PATCH] download CUDA libs if needed --- extensions/llamacpp-extension/src/backend.ts | 85 +++++++++++++++----- 1 file changed, 64 insertions(+), 21 deletions(-) diff --git a/extensions/llamacpp-extension/src/backend.ts b/extensions/llamacpp-extension/src/backend.ts index cd6efbde7..63346d257 100644 --- a/extensions/llamacpp-extension/src/backend.ts +++ b/extensions/llamacpp-extension/src/backend.ts @@ -100,31 +100,46 @@ export async function isBackendInstalled(backend: string, version: string): Prom export async function downloadBackend(backend: string, version: string): Promise { const janDataFolderPath = await getJanDataFolderPath() - const backendPath = await joinPath([janDataFolderPath, 'llamacpp', 'backends', backend, version]) + const llamacppPath = await joinPath([janDataFolderPath, 'llamacpp']) const downloadManager = window.core.extensionManager.getByName('@janhq/download-extension') - const url = `https://github.com/menloresearch/llama.cpp/releases/download/${version}/llama-${version}-bin-${backend}.tar.gz` - const savePath = await joinPath([backendPath, 'backend.tar.gz']) + + const downloadItems = [ + { + url: `https://github.com/menloresearch/llama.cpp/releases/download/${version}/llama-${version}-bin-${backend}.tar.gz`, + save_path: await joinPath([llamacppPath, 'backends', backend, version, 'backend.tar.gz']), + } + ] + + // also download CUDA runtime + cuBLAS + cuBLASLt if needed + if (backend.includes('cu11.7') && (await _isCudaInstalled('11.7'))) { + downloadItems.push({ + url: `https://github.com/menloresearch/llama.cpp/releases/download/${version}/cudart-llama-bin-linux-cu11.7-x64.tar.gz`, + save_path: await joinPath([llamacppPath, 'lib', 'cuda.tar.gz']), + }) + } else if (backend.includes('cu12.0') && (await _isCudaInstalled('12.0'))) { + downloadItems.push({ + url: `https://github.com/menloresearch/llama.cpp/releases/download/${version}/cudart-llama-bin-linux-cu12.4-x64.tar.gz`, + save_path: await joinPath([llamacppPath, 'lib', 'cuda.tar.gz']), + }) + } + const taskId = `llamacpp-${version}-${backend}`.replace(/\./g, '-') const downloadType = 'Engine' - console.log(`Downloading backend ${backend} version ${version} from ${url} to ${savePath}`) + console.log(`Downloading backend ${backend} version ${version}: ${downloadItems}`) let downloadCompleted = false try { - await downloadManager.downloadFile( - url, - savePath, - taskId, - (transferred: number, total: number) => { - events.emit('onFileDownloadUpdate', { - modelId: taskId, - percent: transferred / total, - size: { transferred, total }, - downloadType, - }) - downloadCompleted = transferred === total - } - ) + const onProgress = (transferred: number, total: number) => { + events.emit('onFileDownloadUpdate', { + modelId: taskId, + percent: transferred / total, + size: { transferred, total }, + downloadType, + }) + downloadCompleted = transferred === total + } + await downloadManager.downloadFiles(downloadItems, taskId, onProgress) // once we reach this point, it either means download finishes or it was cancelled. // if there was an error, it would have been caught above @@ -133,12 +148,18 @@ export async function downloadBackend(backend: string, version: string): Promise return } - await invoke('decompress', { path: savePath, outputDir: backendPath }) - await fs.rm(savePath) + // decompress the downloaded tar.gz files + for (const { save_path } of downloadItems) { + if (save_path.endsWith('.tar.gz')) { + const parentDir = save_path.substring(0, save_path.lastIndexOf('/')) + await invoke('decompress', { path: save_path, outputDir: parentDir }) + await fs.rm(save_path) + } + } events.emit('onFileDownloadSuccess', { modelId: taskId, downloadType }) } catch (error) { - console.error(`Failed to download backend ${backend}:`, error) + console.error(`Failed to download backend ${backend}: `, error) events.emit('onFileDownloadError', { modelId: taskId, downloadType }) throw error } @@ -156,3 +177,25 @@ async function _fetchGithubReleases( } return response.json() } + +async function _isCudaInstalled(version: string): Promise { + const sysInfo = await window.core.api.getSystemInfo() + const os_type = sysInfo.os_type + + // not sure the reason behind this naming convention + const libnameLookup = { + 'windows-11.7': `cudart64_110.dll`, + 'windows-12.0': `cudart64_12.dll`, + 'linux-11.7': `libcudart.so.11.0`, + 'linux-12.0': `libcudart.so.12`, + } + const key = `${os_type}-${version}` + if (!(key in libnameLookup)) { + return false + } + + const libname = libnameLookup[key] + const janDataFolderPath = await getJanDataFolderPath() + const cudartPath = await joinPath([janDataFolderPath, 'llamacpp', 'lib', libname]) + return await fs.existsSync(cudartPath) +}