diff --git a/extensions/llamacpp-extension/src/index.ts b/extensions/llamacpp-extension/src/index.ts index f1a750138..d5d13804f 100644 --- a/extensions/llamacpp-extension/src/index.ts +++ b/extensions/llamacpp-extension/src/index.ts @@ -2012,6 +2012,69 @@ export default class llamacpp_extension extends AIEngine { libraryPath, envs, }) + // On Linux with AMD GPUs, llama.cpp via Vulkan may report UMA (shared) memory as device-local. + // For clearer UX, override with dedicated VRAM from the hardware plugin when available. + try { + const sysInfo = await getSystemInfo() + if (sysInfo?.os_type === 'linux' && Array.isArray(sysInfo.gpus)) { + const usage = await getSystemUsage() + if (usage && Array.isArray(usage.gpus)) { + const uuidToUsage: Record = {} + for (const u of usage.gpus as any[]) { + if (u && typeof u.uuid === 'string') { + uuidToUsage[u.uuid] = u + } + } + + const indexToAmdUuid = new Map() + for (const gpu of sysInfo.gpus as any[]) { + const vendorStr = + typeof gpu?.vendor === 'string' + ? gpu.vendor + : typeof gpu?.vendor === 'object' && gpu.vendor !== null + ? String(gpu.vendor) + : '' + if ( + vendorStr.toUpperCase().includes('AMD') && + gpu?.vulkan_info && + typeof gpu.vulkan_info.index === 'number' && + typeof gpu.uuid === 'string' + ) { + indexToAmdUuid.set(gpu.vulkan_info.index, gpu.uuid) + } + } + + if (indexToAmdUuid.size > 0) { + const adjusted = dList.map((dev) => { + if (dev.id?.startsWith('Vulkan')) { + const match = /^Vulkan(\d+)/.exec(dev.id) + if (match) { + const vIdx = Number(match[1]) + const uuid = indexToAmdUuid.get(vIdx) + if (uuid) { + const u = uuidToUsage[uuid] + if ( + u && + typeof u.total_memory === 'number' && + typeof u.used_memory === 'number' + ) { + const total = Math.max(0, Math.floor(u.total_memory)) + const free = Math.max(0, Math.floor(u.total_memory - u.used_memory)) + return { ...dev, mem: total, free } + } + } + } + } + return dev + }) + return adjusted + } + } + } + } catch (e) { + logger.warn('Device memory override (AMD/Linux) failed:', e) + } + return dList } catch (error) { logger.error('Failed to query devices:\n', error)