From 0c5fbc102c194175857a24c964d2d30aedf5a504 Mon Sep 17 00:00:00 2001
From: Akarshan
Date: Mon, 13 Oct 2025 19:29:09 +0530
Subject: [PATCH] refactor: Simplify Tauri plugin calls and enhance 'Flash Attention' setting

This commit improves the llama.cpp extension, focusing on the 'Flash Attention' setting and refactoring Tauri plugin interactions for better code clarity and maintainability. Backend interaction is streamlined by removing the unnecessary `libraryPath` argument from the Tauri plugin commands for loading models and listing devices.

* **Simplified API Calls:** The `loadLlamaModel`, `unloadLlamaModel`, and `get_devices` functions in both the extension and the Tauri plugin now resolve the library path internally from the backend executable's location.
* **Decoupled Logic:** The extension (`src/index.ts`) now uses the new, simplified Tauri plugin functions, which improves modularity and reduces boilerplate in the extension.
* **Type Consistency:** Added an `UnloadResult` interface to `guest-js/index.ts` for consistency with the plugin's Rust return type.
* **Updated UI Control:** The 'Flash Attention' setting in `settings.json` changes from a boolean checkbox to a string-based dropdown offering **'auto'**, **'on'**, and **'off'**.
* **Improved Logic:** The extension logic in `src/index.ts` now handles the string-based `flash_attn` configuration, passing the value (`'auto'`, `'on'`, or `'off'`) directly as a command-line argument to the llama.cpp backend. The version-gated logic previously required for older llama.cpp builds is removed.

This refactoring cleans up the extension's codebase and moves environment and path setup concerns into the Tauri plugin, where they are most relevant.
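For illustration, a minimal sketch of how an extension might drive the simplified wrappers; the model path and server arguments below are placeholders, not values from this patch:

```ts
import {
  loadLlamaModel,
  unloadLlamaModel,
} from '@janhq/tauri-plugin-llamacpp-api'

async function startAndStop(backendPath: string): Promise<void> {
  // The plugin now derives the shared-library search path from the
  // backend executable's directory, so no libraryPath is passed.
  const args = ['-m', '/path/to/model.gguf', '--flash-attn', 'auto'] // placeholder model path
  const envs: Record<string, string> = {}

  const sInfo = await loadLlamaModel(backendPath, args, envs)
  console.log('server started, pid:', sInfo.pid, 'port:', sInfo.port)

  const result = await unloadLlamaModel(sInfo.pid)
  if (!result.success) {
    console.error('unload failed:', result.error)
  }
}
```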
---
 extensions/llamacpp-extension/settings.json  |  9 +++-
 extensions/llamacpp-extension/src/index.ts   | 41 +++++++------------
 .../tauri-plugin-llamacpp/guest-js/index.ts  | 26 +++++++-----
 .../tauri-plugin-llamacpp/src/commands.rs    | 10 ++---
 .../tauri-plugin-llamacpp/src/device.rs      |  9 ++--
 .../tauri-plugin-llamacpp/src/gguf/utils.rs  |  1 -
 6 files changed, 45 insertions(+), 51 deletions(-)

diff --git a/extensions/llamacpp-extension/settings.json b/extensions/llamacpp-extension/settings.json
index ce5fc62e4..ac4706858 100644
--- a/extensions/llamacpp-extension/settings.json
+++ b/extensions/llamacpp-extension/settings.json
@@ -149,9 +149,14 @@
     "key": "flash_attn",
     "title": "Flash Attention",
     "description": "Enable Flash Attention for optimized performance.",
-    "controllerType": "checkbox",
+    "controllerType": "dropdown",
     "controllerProps": {
-      "value": false
+      "value": "auto",
+      "options": [
+        { "value": "auto", "name": "Auto" },
+        { "value": "on", "name": "ON" },
+        { "value": "off", "name": "OFF" }
+      ]
     }
   },
   {
diff --git a/extensions/llamacpp-extension/src/index.ts b/extensions/llamacpp-extension/src/index.ts
index 8d4f277b6..631220a92 100644
--- a/extensions/llamacpp-extension/src/index.ts
+++ b/extensions/llamacpp-extension/src/index.ts
@@ -38,10 +38,12 @@ import { invoke } from '@tauri-apps/api/core'
 import { getProxyConfig } from './util'
 import { basename } from '@tauri-apps/api/path'
 import {
+  loadLlamaModel,
   readGgufMetadata,
   getModelSize,
   isModelSupported,
   planModelLoadInternal,
+  unloadLlamaModel,
 } from '@janhq/tauri-plugin-llamacpp-api'
 import { getSystemUsage, getSystemInfo } from '@janhq/tauri-plugin-hardware-api'
 
@@ -69,7 +71,7 @@ type LlamacppConfig = {
   device: string
   split_mode: string
   main_gpu: number
-  flash_attn: boolean
+  flash_attn: string
   cont_batching: boolean
   no_mmap: boolean
   mlock: boolean
@@ -1646,14 +1648,11 @@ export default class llamacpp_extension extends AIEngine {
       args.push('--split-mode', cfg.split_mode)
     if (cfg.main_gpu !== undefined && cfg.main_gpu != 0)
       args.push('--main-gpu', String(cfg.main_gpu))
+    // Note: Older llama.cpp versions are no longer supported
+    if (cfg.flash_attn !== undefined && cfg.flash_attn !== '') args.push('--flash-attn', String(cfg.flash_attn)) //default: auto = ON when supported
 
     // Boolean flags
     if (cfg.ctx_shift) args.push('--context-shift')
-    if (Number(version.replace(/^b/, '')) >= 6325) {
-      if (!cfg.flash_attn) args.push('--flash-attn', 'off') //default: auto = ON when supported
-    } else {
-      if (cfg.flash_attn) args.push('--flash-attn')
-    }
     if (cfg.cont_batching) args.push('--cont-batching')
     args.push('--no-mmap')
     if (cfg.mlock) args.push('--mlock')
@@ -1688,20 +1687,9 @@ export default class llamacpp_extension extends AIEngine {
     logger.info('Calling Tauri command llama_load with args:', args)
 
     const backendPath = await getBackendExePath(backend, version)
-    const libraryPath = await joinPath([await this.getProviderPath(), 'lib'])
 
     try {
-      // TODO: add LIBRARY_PATH
-      const sInfo = await invoke<SessionInfo>(
-        'plugin:llamacpp|load_llama_model',
-        {
-          backendPath,
-          libraryPath,
-          args,
-          envs,
-          isEmbedding,
-        }
-      )
+      const sInfo = await loadLlamaModel(backendPath, args, envs)
       return sInfo
     } catch (error) {
       logger.error('Error in load command:\n', error)
@@ -1717,12 +1705,7 @@ export default class llamacpp_extension extends AIEngine {
     const pid = sInfo.pid
     try {
      // Pass the PID as the session_id
-      const result = await invoke<UnloadResult>(
-        'plugin:llamacpp|unload_llama_model',
-        {
-          pid: pid,
-        }
-      )
+      const result = await unloadLlamaModel(pid)
 
       // If successful, remove from active sessions
       if (result.success) {
@@ -2042,7 +2025,10 @@ export default class llamacpp_extension extends AIEngine {
     if (sysInfo?.os_type === 'linux' && Array.isArray(sysInfo.gpus)) {
       const usage = await getSystemUsage()
       if (usage && Array.isArray(usage.gpus)) {
-        const uuidToUsage: Record<string, { total_memory: number; used_memory: number }> = {}
+        const uuidToUsage: Record<
+          string,
+          { total_memory: number; used_memory: number }
+        > = {}
         for (const u of usage.gpus as any[]) {
           if (u && typeof u.uuid === 'string') {
             uuidToUsage[u.uuid] = u
@@ -2082,7+2068,10 @@ export default class llamacpp_extension extends AIEngine {
             typeof u.used_memory === 'number'
           ) {
             const total = Math.max(0, Math.floor(u.total_memory))
-            const free = Math.max(0, Math.floor(u.total_memory - u.used_memory))
+            const free = Math.max(
+              0,
+              Math.floor(u.total_memory - u.used_memory)
+            )
             return { ...dev, mem: total, free }
           }
         }
diff --git a/src-tauri/plugins/tauri-plugin-llamacpp/guest-js/index.ts b/src-tauri/plugins/tauri-plugin-llamacpp/guest-js/index.ts
index 7c0e3e4be..b31133da5 100644
--- a/src-tauri/plugins/tauri-plugin-llamacpp/guest-js/index.ts
+++ b/src-tauri/plugins/tauri-plugin-llamacpp/guest-js/index.ts
@@ -2,11 +2,17 @@ import { invoke } from '@tauri-apps/api/core'
 
 // Types
 export interface SessionInfo {
-  pid: number
-  port: number
-  model_id: string
-  model_path: string
-  api_key: string
+  pid: number;
+  port: number;
+  model_id: string;
+  model_path: string;
+  api_key: string;
+  mmproj_path?: string;
+}
+
+export interface UnloadResult {
+  success: boolean;
+  error?: string;
 }
 
 export interface DeviceInfo {
@@ -29,19 +35,17 @@ export async function cleanupLlamaProcesses(): Promise<void> {
 
 // LlamaCpp server commands
 export async function loadLlamaModel(
   backendPath: string,
-  libraryPath?: string,
-  args: string[] = [],
-  isEmbedding: boolean = false
+  args: string[],
+  envs: Record<string, string>
 ): Promise<SessionInfo> {
   return await invoke('plugin:llamacpp|load_llama_model', {
     backendPath,
-    libraryPath,
     args,
-    isEmbedding,
+    envs
   })
 }
 
-export async function unloadLlamaModel(pid: number): Promise<void> {
+export async function unloadLlamaModel(pid: number): Promise<UnloadResult> {
   return await invoke('plugin:llamacpp|unload_llama_model', { pid })
 }
diff --git a/src-tauri/plugins/tauri-plugin-llamacpp/src/commands.rs b/src-tauri/plugins/tauri-plugin-llamacpp/src/commands.rs
index 1d898b4d9..2b14f5ca7 100644
--- a/src-tauri/plugins/tauri-plugin-llamacpp/src/commands.rs
+++ b/src-tauri/plugins/tauri-plugin-llamacpp/src/commands.rs
@@ -41,7 +41,6 @@ pub struct UnloadResult {
 pub async fn load_llama_model(
     app_handle: tauri::AppHandle,
     backend_path: &str,
-    library_path: Option<&str>,
     mut args: Vec<String>,
     envs: HashMap<String, String>,
     is_embedding: bool,
@@ -52,7 +51,7 @@
     log::info!("Attempting to launch server at path: {:?}", backend_path);
     log::info!("Using arguments: {:?}", args);
 
-    validate_binary_path(backend_path)?;
+    let bin_path = validate_binary_path(backend_path)?;
 
     let port = parse_port_from_args(&args);
     let model_path_pb = validate_model_path(&mut args)?;
@@ -83,11 +82,11 @@
     let model_id = extract_arg_value(&args, "-a");
 
     // Configure the command to run the server
-    let mut command = Command::new(backend_path);
+    let mut command = Command::new(&bin_path);
     command.args(args);
     command.envs(envs);
-    setup_library_path(library_path, &mut command);
+    setup_library_path(bin_path.parent().and_then(|p| p.to_str()), &mut command);
 
     command.stdout(Stdio::piped());
     command.stderr(Stdio::piped());
     setup_windows_process_flags(&mut command);
@@ -280,10 +279,9 @@ pub async fn unload_llama_model(
 #[tauri::command]
 pub async fn get_devices(
     backend_path: &str,
-    library_path: Option<&str>,
     envs: HashMap<String, String>,
 ) -> ServerResult<Vec<DeviceInfo>> {
-    get_devices_from_backend(backend_path, library_path, envs).await
+    get_devices_from_backend(backend_path, envs).await
 }
 
 /// Generate API key using HMAC-SHA256
diff --git a/src-tauri/plugins/tauri-plugin-llamacpp/src/device.rs b/src-tauri/plugins/tauri-plugin-llamacpp/src/device.rs
index 80b0293ac..922e70c14 100644
--- a/src-tauri/plugins/tauri-plugin-llamacpp/src/device.rs
+++ b/src-tauri/plugins/tauri-plugin-llamacpp/src/device.rs
@@ -19,20 +19,19 @@ pub struct DeviceInfo {
 
 pub async fn get_devices_from_backend(
     backend_path: &str,
-    library_path: Option<&str>,
     envs: HashMap<String, String>,
 ) -> ServerResult<Vec<DeviceInfo>> {
     log::info!("Getting devices from server at path: {:?}", backend_path);
 
-    validate_binary_path(backend_path)?;
+    let bin_path = validate_binary_path(backend_path)?;
 
     // Configure the command to run the server with --list-devices
-    let mut command = Command::new(backend_path);
+    let mut command = Command::new(&bin_path);
     command.arg("--list-devices");
     command.envs(envs);
 
     // Set up library path
-    setup_library_path(library_path, &mut command);
+    setup_library_path(bin_path.parent().and_then(|p| p.to_str()), &mut command);
 
     command.stdout(Stdio::piped());
     command.stderr(Stdio::piped());
@@ -410,4 +409,4 @@ AnotherInvalid
         assert_eq!(result[0].id, "Vulkan0");
         assert_eq!(result[1].id, "CUDA0");
     }
-}
\ No newline at end of file
+}
diff --git a/src-tauri/plugins/tauri-plugin-llamacpp/src/gguf/utils.rs b/src-tauri/plugins/tauri-plugin-llamacpp/src/gguf/utils.rs
index cdbbf92d5..10dc66f48 100644
--- a/src-tauri/plugins/tauri-plugin-llamacpp/src/gguf/utils.rs
+++ b/src-tauri/plugins/tauri-plugin-llamacpp/src/gguf/utils.rs
@@ -62,7 +62,6 @@ pub async fn estimate_kv_cache_internal(
     ctx_size: Option,
 ) -> Result {
     log::info!("Received ctx_size parameter: {:?}", ctx_size);
-    log::info!("Received model metadata:\n{:?}", &meta);
     let arch = meta
         .get("general.architecture")
         .ok_or(KVCacheError::ArchitectureNotFound)?;
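A minimal usage sketch for the updated device listing, assuming a direct command invocation (the guest-js wrapper for devices is not shown in this diff) and showing only the `id` field of `DeviceInfo`:

```ts
import { invoke } from '@tauri-apps/api/core'

// Partial view of the plugin's DeviceInfo; only `id` is exercised by the tests in this patch.
interface DeviceInfo {
  id: string
}

async function listDevices(backendPath: string): Promise<DeviceInfo[]> {
  // The library path is now resolved inside the plugin from the backend
  // binary's directory, so only the backend path and env vars are passed.
  return await invoke<DeviceInfo[]>('plugin:llamacpp|get_devices', {
    backendPath,
    envs: {},
  })
}
```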