download CUDA libs if needed

This commit is contained in:
Thien Tran 2025-05-29 17:32:09 +08:00 committed by Louis
parent 40cd7e962a
commit fbfaaf43c5
No known key found for this signature in database
GPG Key ID: 44FA9F4D33C37DE2

View File

@ -100,22 +100,37 @@ export async function isBackendInstalled(backend: string, version: string): Prom
export async function downloadBackend(backend: string, version: string): Promise<void> { export async function downloadBackend(backend: string, version: string): Promise<void> {
const janDataFolderPath = await getJanDataFolderPath() const janDataFolderPath = await getJanDataFolderPath()
const backendPath = await joinPath([janDataFolderPath, 'llamacpp', 'backends', backend, version]) const llamacppPath = await joinPath([janDataFolderPath, 'llamacpp'])
const downloadManager = window.core.extensionManager.getByName('@janhq/download-extension') const downloadManager = window.core.extensionManager.getByName('@janhq/download-extension')
const url = `https://github.com/menloresearch/llama.cpp/releases/download/${version}/llama-${version}-bin-${backend}.tar.gz`
const savePath = await joinPath([backendPath, 'backend.tar.gz']) const downloadItems = [
{
url: `https://github.com/menloresearch/llama.cpp/releases/download/${version}/llama-${version}-bin-${backend}.tar.gz`,
save_path: await joinPath([llamacppPath, 'backends', backend, version, 'backend.tar.gz']),
}
]
// also download CUDA runtime + cuBLAS + cuBLASLt if needed
if (backend.includes('cu11.7') && (await _isCudaInstalled('11.7'))) {
downloadItems.push({
url: `https://github.com/menloresearch/llama.cpp/releases/download/${version}/cudart-llama-bin-linux-cu11.7-x64.tar.gz`,
save_path: await joinPath([llamacppPath, 'lib', 'cuda.tar.gz']),
})
} else if (backend.includes('cu12.0') && (await _isCudaInstalled('12.0'))) {
downloadItems.push({
url: `https://github.com/menloresearch/llama.cpp/releases/download/${version}/cudart-llama-bin-linux-cu12.4-x64.tar.gz`,
save_path: await joinPath([llamacppPath, 'lib', 'cuda.tar.gz']),
})
}
const taskId = `llamacpp-${version}-${backend}`.replace(/\./g, '-') const taskId = `llamacpp-${version}-${backend}`.replace(/\./g, '-')
const downloadType = 'Engine' const downloadType = 'Engine'
console.log(`Downloading backend ${backend} version ${version} from ${url} to ${savePath}`) console.log(`Downloading backend ${backend} version ${version}: ${downloadItems}`)
let downloadCompleted = false let downloadCompleted = false
try { try {
await downloadManager.downloadFile( const onProgress = (transferred: number, total: number) => {
url,
savePath,
taskId,
(transferred: number, total: number) => {
events.emit('onFileDownloadUpdate', { events.emit('onFileDownloadUpdate', {
modelId: taskId, modelId: taskId,
percent: transferred / total, percent: transferred / total,
@ -124,7 +139,7 @@ export async function downloadBackend(backend: string, version: string): Promise
}) })
downloadCompleted = transferred === total downloadCompleted = transferred === total
} }
) await downloadManager.downloadFiles(downloadItems, taskId, onProgress)
// once we reach this point, it either means download finishes or it was cancelled. // once we reach this point, it either means download finishes or it was cancelled.
// if there was an error, it would have been caught above // if there was an error, it would have been caught above
@ -133,12 +148,18 @@ export async function downloadBackend(backend: string, version: string): Promise
return return
} }
await invoke('decompress', { path: savePath, outputDir: backendPath }) // decompress the downloaded tar.gz files
await fs.rm(savePath) for (const { save_path } of downloadItems) {
if (save_path.endsWith('.tar.gz')) {
const parentDir = save_path.substring(0, save_path.lastIndexOf('/'))
await invoke('decompress', { path: save_path, outputDir: parentDir })
await fs.rm(save_path)
}
}
events.emit('onFileDownloadSuccess', { modelId: taskId, downloadType }) events.emit('onFileDownloadSuccess', { modelId: taskId, downloadType })
} catch (error) { } catch (error) {
console.error(`Failed to download backend ${backend}:`, error) console.error(`Failed to download backend ${backend}: `, error)
events.emit('onFileDownloadError', { modelId: taskId, downloadType }) events.emit('onFileDownloadError', { modelId: taskId, downloadType })
throw error throw error
} }
@ -156,3 +177,25 @@ async function _fetchGithubReleases(
} }
return response.json() return response.json()
} }
async function _isCudaInstalled(version: string): Promise<boolean> {
const sysInfo = await window.core.api.getSystemInfo()
const os_type = sysInfo.os_type
// not sure the reason behind this naming convention
const libnameLookup = {
'windows-11.7': `cudart64_110.dll`,
'windows-12.0': `cudart64_12.dll`,
'linux-11.7': `libcudart.so.11.0`,
'linux-12.0': `libcudart.so.12`,
}
const key = `${os_type}-${version}`
if (!(key in libnameLookup)) {
return false
}
const libname = libnameLookup[key]
const janDataFolderPath = await getJanDataFolderPath()
const cudartPath = await joinPath([janDataFolderPath, 'llamacpp', 'lib', libname])
return await fs.existsSync(cudartPath)
}