refactor: Simplify Tauri plugin calls and enhance 'Flash Attention' setting

This commit improves the llama.cpp extension by reworking the 'Flash Attention' setting and refactoring the Tauri plugin interactions for better code clarity and maintainability.

The backend interaction is streamlined by removing the unnecessary `libraryPath` argument from the Tauri plugin commands for loading models and listing devices.

* **Simplified API Calls:** The `load_llama_model` and `get_devices` plugin commands now derive the library path internally from the backend executable's location, and the extension calls the `loadLlamaModel` and `unloadLlamaModel` wrappers from `@janhq/tauri-plugin-llamacpp-api` instead of raw `invoke` calls (see the sketch after this list).
* **Decoupled Logic:** The extension (`src/index.ts`) now uses the new, simplified Tauri plugin functions, which enhances modularity and reduces boilerplate code in the extension.
* **Type Consistency:** Added an `UnloadResult` interface to `guest-js/index.ts` so that `unloadLlamaModel` returns a typed result instead of `void`.
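
A minimal sketch of how the extension now drives the plugin through the guest-js wrappers (the helper name `loadAndUnload` and the placeholder inputs are illustrative, not part of the commit; in the extension the real values come from `getBackendExePath()`, the generated argument list, and the prepared environment map):

```ts
import {
  loadLlamaModel,
  unloadLlamaModel,
} from '@janhq/tauri-plugin-llamacpp-api'

// Illustrative helper, not part of the extension code.
async function loadAndUnload(
  backendPath: string,
  args: string[],
  envs: Record<string, string>
) {
  // No libraryPath argument any more; the plugin derives it from backendPath.
  const sInfo = await loadLlamaModel(backendPath, args, envs)

  // Unload by PID and inspect the typed UnloadResult.
  const result = await unloadLlamaModel(sInfo.pid)
  if (!result.success) {
    console.error('Unload failed:', result.error)
  }
}
```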

* **Updated UI Control:** The 'Flash Attention' setting in `settings.json` is changed from a boolean checkbox to a string-based dropdown, offering **'auto'**, **'on'**, and **'off'** options.
* **Improved Logic:** The extension logic in `src/index.ts` now handles the string-based `flash_attn` configuration: the value (`'auto'`, `'on'`, or `'off'`) is passed directly as a command-line argument to the llama.cpp backend, and the version-checking logic previously required for older llama.cpp builds is removed (see the sketch below).
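
A sketch of the simplified flag handling, assuming `cfg` stands in for the parsed `LlamacppConfig` (where `flash_attn` is now a plain string):

```ts
// Sketch only: cfg mirrors the relevant slice of LlamacppConfig.
const cfg: { flash_attn?: string } = { flash_attn: 'auto' } // 'auto' | 'on' | 'off'
const args: string[] = []

if (cfg.flash_attn !== undefined && cfg.flash_attn !== '') {
  // Pass the value straight through; llama.cpp treats 'auto' as ON when supported.
  args.push('--flash-attn', cfg.flash_attn)
}
// args -> ['--flash-attn', 'auto']
```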

This refactoring cleans up the extension's codebase and moves environment and path setup concerns into the Tauri plugin where they are most relevant.
Akarshan 2025-10-13 19:29:09 +05:30
parent 653ecdb494
commit 0c5fbc102c
GPG Key ID: D75C9634A870665F
6 changed files with 45 additions and 51 deletions

View File

@@ -149,9 +149,14 @@
"key": "flash_attn",
"title": "Flash Attention",
"description": "Enable Flash Attention for optimized performance.",
"controllerType": "checkbox",
"controllerType": "dropdown",
"controllerProps": {
"value": false
"value": "auto",
"options": [
{ "value": "auto", "name": "Auto" },
{ "value": "on", "name": "ON" },
{ "value": "off", "name": "OFF" }
]
}
},
{

View File

@@ -38,10 +38,12 @@ import { invoke } from '@tauri-apps/api/core'
import { getProxyConfig } from './util'
import { basename } from '@tauri-apps/api/path'
import {
loadLlamaModel,
readGgufMetadata,
getModelSize,
isModelSupported,
planModelLoadInternal,
unloadLlamaModel,
} from '@janhq/tauri-plugin-llamacpp-api'
import { getSystemUsage, getSystemInfo } from '@janhq/tauri-plugin-hardware-api'
@@ -69,7 +71,7 @@ type LlamacppConfig = {
device: string
split_mode: string
main_gpu: number
flash_attn: boolean
flash_attn: string
cont_batching: boolean
no_mmap: boolean
mlock: boolean
@@ -1646,14 +1648,11 @@ export default class llamacpp_extension extends AIEngine {
args.push('--split-mode', cfg.split_mode)
if (cfg.main_gpu !== undefined && cfg.main_gpu != 0)
args.push('--main-gpu', String(cfg.main_gpu))
// Note: Older llama.cpp versions are no longer supported
if (cfg.flash_attn !== undefined && cfg.flash_attn !== '') args.push('--flash-attn', String(cfg.flash_attn)) // default: auto = ON when supported
// Boolean flags
if (cfg.ctx_shift) args.push('--context-shift')
if (Number(version.replace(/^b/, '')) >= 6325) {
if (!cfg.flash_attn) args.push('--flash-attn', 'off') //default: auto = ON when supported
} else {
if (cfg.flash_attn) args.push('--flash-attn')
}
if (cfg.cont_batching) args.push('--cont-batching')
args.push('--no-mmap')
if (cfg.mlock) args.push('--mlock')
@@ -1688,20 +1687,9 @@ export default class llamacpp_extension extends AIEngine {
logger.info('Calling Tauri command llama_load with args:', args)
const backendPath = await getBackendExePath(backend, version)
const libraryPath = await joinPath([await this.getProviderPath(), 'lib'])
try {
// TODO: add LIBRARY_PATH
const sInfo = await invoke<SessionInfo>(
'plugin:llamacpp|load_llama_model',
{
backendPath,
libraryPath,
args,
envs,
isEmbedding,
}
)
const sInfo = await loadLlamaModel(backendPath, args, envs)
return sInfo
} catch (error) {
logger.error('Error in load command:\n', error)
@@ -1717,12 +1705,7 @@ export default class llamacpp_extension extends AIEngine {
const pid = sInfo.pid
try {
// Pass the PID as the session_id
const result = await invoke<UnloadResult>(
'plugin:llamacpp|unload_llama_model',
{
pid: pid,
}
)
const result = await unloadLlamaModel(pid)
// If successful, remove from active sessions
if (result.success) {
@@ -2042,7 +2025,10 @@ export default class llamacpp_extension extends AIEngine {
if (sysInfo?.os_type === 'linux' && Array.isArray(sysInfo.gpus)) {
const usage = await getSystemUsage()
if (usage && Array.isArray(usage.gpus)) {
const uuidToUsage: Record<string, { total_memory: number; used_memory: number }> = {}
const uuidToUsage: Record<
string,
{ total_memory: number; used_memory: number }
> = {}
for (const u of usage.gpus as any[]) {
if (u && typeof u.uuid === 'string') {
uuidToUsage[u.uuid] = u
@@ -2082,7 +2068,10 @@ export default class llamacpp_extension extends AIEngine {
typeof u.used_memory === 'number'
) {
const total = Math.max(0, Math.floor(u.total_memory))
const free = Math.max(0, Math.floor(u.total_memory - u.used_memory))
const free = Math.max(
0,
Math.floor(u.total_memory - u.used_memory)
)
return { ...dev, mem: total, free }
}
}

View File

@@ -2,11 +2,17 @@ import { invoke } from '@tauri-apps/api/core'
// Types
export interface SessionInfo {
pid: number
port: number
model_id: string
model_path: string
api_key: string
pid: number;
port: number;
model_id: string;
model_path: string;
api_key: string;
mmproj_path?: string;
}
export interface UnloadResult {
success: boolean;
error?: string;
}
export interface DeviceInfo {
@@ -29,19 +35,17 @@ export async function cleanupLlamaProcesses(): Promise<void> {
// LlamaCpp server commands
export async function loadLlamaModel(
backendPath: string,
libraryPath?: string,
args: string[] = [],
isEmbedding: boolean = false
args: string[],
envs: Record<string, string>
): Promise<SessionInfo> {
return await invoke('plugin:llamacpp|load_llama_model', {
backendPath,
libraryPath,
args,
isEmbedding,
envs
})
}
export async function unloadLlamaModel(pid: number): Promise<void> {
export async function unloadLlamaModel(pid: number): Promise<UnloadResult> {
return await invoke('plugin:llamacpp|unload_llama_model', { pid })
}

View File

@@ -41,7 +41,6 @@ pub struct UnloadResult {
pub async fn load_llama_model<R: Runtime>(
app_handle: tauri::AppHandle<R>,
backend_path: &str,
library_path: Option<&str>,
mut args: Vec<String>,
envs: HashMap<String, String>,
is_embedding: bool,
@@ -52,7 +51,7 @@ pub async fn load_llama_model<R: Runtime>(
log::info!("Attempting to launch server at path: {:?}", backend_path);
log::info!("Using arguments: {:?}", args);
validate_binary_path(backend_path)?;
let bin_path = validate_binary_path(backend_path)?;
let port = parse_port_from_args(&args);
let model_path_pb = validate_model_path(&mut args)?;
@@ -83,11 +82,11 @@
let model_id = extract_arg_value(&args, "-a");
// Configure the command to run the server
let mut command = Command::new(backend_path);
let mut command = Command::new(&bin_path);
command.args(args);
command.envs(envs);
setup_library_path(library_path, &mut command);
setup_library_path(bin_path.parent().and_then(|p| p.to_str()), &mut command);
command.stdout(Stdio::piped());
command.stderr(Stdio::piped());
setup_windows_process_flags(&mut command);
@@ -280,10 +279,9 @@ pub async fn unload_llama_model<R: Runtime>(
#[tauri::command]
pub async fn get_devices(
backend_path: &str,
library_path: Option<&str>,
envs: HashMap<String, String>,
) -> ServerResult<Vec<DeviceInfo>> {
get_devices_from_backend(backend_path, library_path, envs).await
get_devices_from_backend(backend_path, envs).await
}
/// Generate API key using HMAC-SHA256

View File

@@ -19,20 +19,19 @@ pub struct DeviceInfo {
pub async fn get_devices_from_backend(
backend_path: &str,
library_path: Option<&str>,
envs: HashMap<String, String>,
) -> ServerResult<Vec<DeviceInfo>> {
log::info!("Getting devices from server at path: {:?}", backend_path);
validate_binary_path(backend_path)?;
let bin_path = validate_binary_path(backend_path)?;
// Configure the command to run the server with --list-devices
let mut command = Command::new(backend_path);
let mut command = Command::new(&bin_path);
command.arg("--list-devices");
command.envs(envs);
// Set up library path
setup_library_path(library_path, &mut command);
setup_library_path(bin_path.parent().and_then(|p| p.to_str()), &mut command);
command.stdout(Stdio::piped());
command.stderr(Stdio::piped());
@@ -410,4 +409,4 @@ AnotherInvalid
assert_eq!(result[0].id, "Vulkan0");
assert_eq!(result[1].id, "CUDA0");
}
}
}

View File

@@ -62,7 +62,6 @@ pub async fn estimate_kv_cache_internal(
ctx_size: Option<u64>,
) -> Result<KVCacheEstimate, KVCacheError> {
log::info!("Received ctx_size parameter: {:?}", ctx_size);
log::info!("Received model metadata:\n{:?}", &meta);
let arch = meta
.get("general.architecture")
.ok_or(KVCacheError::ArchitectureNotFound)?;