refactor: Simplify Tauri plugin calls and enhance 'Flash Attention' setting
This commit introduces significant improvements to the llama.cpp extension, focusing on the 'Flash Attention' setting and refactoring Tauri plugin interactions for better code clarity and maintenance. The backend interaction is streamlined by removing the unnecessary `libraryPath` argument from the Tauri plugin commands for loading models and listing devices. * **Simplified API Calls:** The `loadLlamaModel`, `unloadLlamaModel`, and `get_devices` functions in both the extension and the Tauri plugin now manage the library path internally based on the backend executable's location. * **Decoupled Logic:** The extension (`src/index.ts`) now uses the new, simplified Tauri plugin functions, which enhances modularity and reduces boilerplate code in the extension. * **Type Consistency:** Added `UnloadResult` interface to `guest-js/index.ts` for consistency. * **Updated UI Control:** The 'Flash Attention' setting in `settings.json` is changed from a boolean checkbox to a string-based dropdown, offering **'auto'**, **'on'**, and **'off'** options. * **Improved Logic:** The extension logic in `src/index.ts` is updated to correctly handle the new string-based `flash_attn` configuration. It now passes the string value (`'auto'`, `'on'`, or `'off'`) directly as a command-line argument to the llama.cpp backend, simplifying the version-checking logic previously required for older llama.cpp versions. The old, complex logic tied to specific backend versions is removed. This refactoring cleans up the extension's codebase and moves environment and path setup concerns into the Tauri plugin where they are most relevant.
This commit is contained in:
parent
653ecdb494
commit
0c5fbc102c
@ -149,9 +149,14 @@
|
||||
"key": "flash_attn",
|
||||
"title": "Flash Attention",
|
||||
"description": "Enable Flash Attention for optimized performance.",
|
||||
"controllerType": "checkbox",
|
||||
"controllerType": "dropdown",
|
||||
"controllerProps": {
|
||||
"value": false
|
||||
"value": "auto",
|
||||
"options": [
|
||||
{ "value": "auto", "name": "Auto" },
|
||||
{ "value": "on", "name": "ON" },
|
||||
{ "value": "off", "name": "OFF" }
|
||||
]
|
||||
}
|
||||
},
|
||||
{
|
||||
|
||||
@ -38,10 +38,12 @@ import { invoke } from '@tauri-apps/api/core'
|
||||
import { getProxyConfig } from './util'
|
||||
import { basename } from '@tauri-apps/api/path'
|
||||
import {
|
||||
loadLlamaModel,
|
||||
readGgufMetadata,
|
||||
getModelSize,
|
||||
isModelSupported,
|
||||
planModelLoadInternal,
|
||||
unloadLlamaModel,
|
||||
} from '@janhq/tauri-plugin-llamacpp-api'
|
||||
import { getSystemUsage, getSystemInfo } from '@janhq/tauri-plugin-hardware-api'
|
||||
|
||||
@ -69,7 +71,7 @@ type LlamacppConfig = {
|
||||
device: string
|
||||
split_mode: string
|
||||
main_gpu: number
|
||||
flash_attn: boolean
|
||||
flash_attn: string
|
||||
cont_batching: boolean
|
||||
no_mmap: boolean
|
||||
mlock: boolean
|
||||
@ -1646,14 +1648,11 @@ export default class llamacpp_extension extends AIEngine {
|
||||
args.push('--split-mode', cfg.split_mode)
|
||||
if (cfg.main_gpu !== undefined && cfg.main_gpu != 0)
|
||||
args.push('--main-gpu', String(cfg.main_gpu))
|
||||
// Note: Older llama.cpp versions are no longer supported
|
||||
if (cfg.flash_attn !== undefined || cfg.flash_attn === '') args.push('--flash-attn', String(cfg.flash_attn)) //default: auto = ON when supported
|
||||
|
||||
// Boolean flags
|
||||
if (cfg.ctx_shift) args.push('--context-shift')
|
||||
if (Number(version.replace(/^b/, '')) >= 6325) {
|
||||
if (!cfg.flash_attn) args.push('--flash-attn', 'off') //default: auto = ON when supported
|
||||
} else {
|
||||
if (cfg.flash_attn) args.push('--flash-attn')
|
||||
}
|
||||
if (cfg.cont_batching) args.push('--cont-batching')
|
||||
args.push('--no-mmap')
|
||||
if (cfg.mlock) args.push('--mlock')
|
||||
@ -1688,20 +1687,9 @@ export default class llamacpp_extension extends AIEngine {
|
||||
|
||||
logger.info('Calling Tauri command llama_load with args:', args)
|
||||
const backendPath = await getBackendExePath(backend, version)
|
||||
const libraryPath = await joinPath([await this.getProviderPath(), 'lib'])
|
||||
|
||||
try {
|
||||
// TODO: add LIBRARY_PATH
|
||||
const sInfo = await invoke<SessionInfo>(
|
||||
'plugin:llamacpp|load_llama_model',
|
||||
{
|
||||
backendPath,
|
||||
libraryPath,
|
||||
args,
|
||||
envs,
|
||||
isEmbedding,
|
||||
}
|
||||
)
|
||||
const sInfo = await loadLlamaModel(backendPath, args, envs)
|
||||
return sInfo
|
||||
} catch (error) {
|
||||
logger.error('Error in load command:\n', error)
|
||||
@ -1717,12 +1705,7 @@ export default class llamacpp_extension extends AIEngine {
|
||||
const pid = sInfo.pid
|
||||
try {
|
||||
// Pass the PID as the session_id
|
||||
const result = await invoke<UnloadResult>(
|
||||
'plugin:llamacpp|unload_llama_model',
|
||||
{
|
||||
pid: pid,
|
||||
}
|
||||
)
|
||||
const result = await unloadLlamaModel(pid)
|
||||
|
||||
// If successful, remove from active sessions
|
||||
if (result.success) {
|
||||
@ -2042,7 +2025,10 @@ export default class llamacpp_extension extends AIEngine {
|
||||
if (sysInfo?.os_type === 'linux' && Array.isArray(sysInfo.gpus)) {
|
||||
const usage = await getSystemUsage()
|
||||
if (usage && Array.isArray(usage.gpus)) {
|
||||
const uuidToUsage: Record<string, { total_memory: number; used_memory: number }> = {}
|
||||
const uuidToUsage: Record<
|
||||
string,
|
||||
{ total_memory: number; used_memory: number }
|
||||
> = {}
|
||||
for (const u of usage.gpus as any[]) {
|
||||
if (u && typeof u.uuid === 'string') {
|
||||
uuidToUsage[u.uuid] = u
|
||||
@ -2082,7 +2068,10 @@ export default class llamacpp_extension extends AIEngine {
|
||||
typeof u.used_memory === 'number'
|
||||
) {
|
||||
const total = Math.max(0, Math.floor(u.total_memory))
|
||||
const free = Math.max(0, Math.floor(u.total_memory - u.used_memory))
|
||||
const free = Math.max(
|
||||
0,
|
||||
Math.floor(u.total_memory - u.used_memory)
|
||||
)
|
||||
return { ...dev, mem: total, free }
|
||||
}
|
||||
}
|
||||
|
||||
@ -2,11 +2,17 @@ import { invoke } from '@tauri-apps/api/core'
|
||||
|
||||
// Types
|
||||
export interface SessionInfo {
|
||||
pid: number
|
||||
port: number
|
||||
model_id: string
|
||||
model_path: string
|
||||
api_key: string
|
||||
pid: number;
|
||||
port: number;
|
||||
model_id: string;
|
||||
model_path: string;
|
||||
api_key: string;
|
||||
mmproj_path?: string;
|
||||
}
|
||||
|
||||
export interface UnloadResult {
|
||||
success: boolean;
|
||||
error?: string;
|
||||
}
|
||||
|
||||
export interface DeviceInfo {
|
||||
@ -29,19 +35,17 @@ export async function cleanupLlamaProcesses(): Promise<void> {
|
||||
// LlamaCpp server commands
|
||||
export async function loadLlamaModel(
|
||||
backendPath: string,
|
||||
libraryPath?: string,
|
||||
args: string[] = [],
|
||||
isEmbedding: boolean = false
|
||||
args: string[],
|
||||
envs: Record<string, string>
|
||||
): Promise<SessionInfo> {
|
||||
return await invoke('plugin:llamacpp|load_llama_model', {
|
||||
backendPath,
|
||||
libraryPath,
|
||||
args,
|
||||
isEmbedding,
|
||||
envs
|
||||
})
|
||||
}
|
||||
|
||||
export async function unloadLlamaModel(pid: number): Promise<void> {
|
||||
export async function unloadLlamaModel(pid: number): Promise<UnloadResult> {
|
||||
return await invoke('plugin:llamacpp|unload_llama_model', { pid })
|
||||
}
|
||||
|
||||
|
||||
@ -41,7 +41,6 @@ pub struct UnloadResult {
|
||||
pub async fn load_llama_model<R: Runtime>(
|
||||
app_handle: tauri::AppHandle<R>,
|
||||
backend_path: &str,
|
||||
library_path: Option<&str>,
|
||||
mut args: Vec<String>,
|
||||
envs: HashMap<String, String>,
|
||||
is_embedding: bool,
|
||||
@ -52,7 +51,7 @@ pub async fn load_llama_model<R: Runtime>(
|
||||
log::info!("Attempting to launch server at path: {:?}", backend_path);
|
||||
log::info!("Using arguments: {:?}", args);
|
||||
|
||||
validate_binary_path(backend_path)?;
|
||||
let bin_path = validate_binary_path(backend_path)?;
|
||||
|
||||
let port = parse_port_from_args(&args);
|
||||
let model_path_pb = validate_model_path(&mut args)?;
|
||||
@ -83,11 +82,11 @@ pub async fn load_llama_model<R: Runtime>(
|
||||
let model_id = extract_arg_value(&args, "-a");
|
||||
|
||||
// Configure the command to run the server
|
||||
let mut command = Command::new(backend_path);
|
||||
let mut command = Command::new(&bin_path);
|
||||
command.args(args);
|
||||
command.envs(envs);
|
||||
|
||||
setup_library_path(library_path, &mut command);
|
||||
setup_library_path(bin_path.parent().and_then(|p| p.to_str()), &mut command);
|
||||
command.stdout(Stdio::piped());
|
||||
command.stderr(Stdio::piped());
|
||||
setup_windows_process_flags(&mut command);
|
||||
@ -280,10 +279,9 @@ pub async fn unload_llama_model<R: Runtime>(
|
||||
#[tauri::command]
|
||||
pub async fn get_devices(
|
||||
backend_path: &str,
|
||||
library_path: Option<&str>,
|
||||
envs: HashMap<String, String>,
|
||||
) -> ServerResult<Vec<DeviceInfo>> {
|
||||
get_devices_from_backend(backend_path, library_path, envs).await
|
||||
get_devices_from_backend(backend_path, envs).await
|
||||
}
|
||||
|
||||
/// Generate API key using HMAC-SHA256
|
||||
|
||||
@ -19,20 +19,19 @@ pub struct DeviceInfo {
|
||||
|
||||
pub async fn get_devices_from_backend(
|
||||
backend_path: &str,
|
||||
library_path: Option<&str>,
|
||||
envs: HashMap<String, String>,
|
||||
) -> ServerResult<Vec<DeviceInfo>> {
|
||||
log::info!("Getting devices from server at path: {:?}", backend_path);
|
||||
|
||||
validate_binary_path(backend_path)?;
|
||||
let bin_path = validate_binary_path(backend_path)?;
|
||||
|
||||
// Configure the command to run the server with --list-devices
|
||||
let mut command = Command::new(backend_path);
|
||||
let mut command = Command::new(&bin_path);
|
||||
command.arg("--list-devices");
|
||||
command.envs(envs);
|
||||
|
||||
// Set up library path
|
||||
setup_library_path(library_path, &mut command);
|
||||
setup_library_path(bin_path.parent().and_then(|p| p.to_str()), &mut command);
|
||||
|
||||
command.stdout(Stdio::piped());
|
||||
command.stderr(Stdio::piped());
|
||||
@ -410,4 +409,4 @@ AnotherInvalid
|
||||
assert_eq!(result[0].id, "Vulkan0");
|
||||
assert_eq!(result[1].id, "CUDA0");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -62,7 +62,6 @@ pub async fn estimate_kv_cache_internal(
|
||||
ctx_size: Option<u64>,
|
||||
) -> Result<KVCacheEstimate, KVCacheError> {
|
||||
log::info!("Received ctx_size parameter: {:?}", ctx_size);
|
||||
log::info!("Received model metadata:\n{:?}", &meta);
|
||||
let arch = meta
|
||||
.get("general.architecture")
|
||||
.ok_or(KVCacheError::ArchitectureNotFound)?;
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user