Akarshan Biswas 5c3a6fec32
feat: Add support for custom environment variables to llama.cpp (#6256)
This commit adds a new setting `llamacpp_env` to the llama.cpp extension, allowing users to specify custom environment variables. These variables are passed to the backend process when it starts.

A new function `parseEnvFromString` is introduced to handle the parsing of the semicolon-separated key-value pairs from the user input. The environment variables are then used in the `load` function and when listing available devices. This enables more flexible configuration of the llama.cpp backend, such as specifying visible GPUs for Vulkan.
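
For illustration, a minimal Rust sketch of that parsing step is shown below. The actual `parseEnvFromString` lives in the extension's TypeScript code and is not part of this file; the sketch only mirrors the semicolon-separated `KEY=VALUE` convention and produces the `HashMap<String, String>` shape that the Tauri commands further down receive. The example input string, and the choice to trim whitespace and skip malformed entries, are assumptions made for the sketch rather than details taken from the commit.

use std::collections::HashMap;

/// Sketch of the semicolon-separated `KEY=VALUE` parsing described above.
/// Entries without an `=` or with an empty key are skipped; whitespace is trimmed.
fn parse_env_from_string(input: &str) -> HashMap<String, String> {
    input
        .split(';')
        .filter_map(|pair| {
            let (key, value) = pair.trim().split_once('=')?;
            let key = key.trim();
            if key.is_empty() {
                return None;
            }
            Some((key.to_string(), value.trim().to_string()))
        })
        .collect()
}

fn main() {
    let envs = parse_env_from_string("GGML_VK_VISIBLE_DEVICES=0;LLAMA_API_KEY=my-secret");
    assert_eq!(envs.get("LLAMA_API_KEY").map(String::as_str), Some("my-secret"));
    println!("{envs:?}");
}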

This change also updates the Tauri command `get_devices` to accept environment variables, ensuring that device discovery respects the user's settings.
2025-08-21 15:50:37 +05:30
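
As a rough sketch of the forwarding mechanism (not code from the commit), the example below builds such a map by hand and applies it to a child process the same way the `load_llama_model` command in the listing does via `Command::envs`. The variable name `GGML_VK_VISIBLE_DEVICES` is assumed here as an example of a Vulkan device-visibility setting, and `printenv` (Unix) merely stands in for the real llama-server binary:

use std::collections::HashMap;
use std::process::Command;

fn main() -> std::io::Result<()> {
    // Assumed example map, as the extension would produce it from the
    // `llamacpp_env` setting (names and values are placeholders).
    let envs: HashMap<String, String> =
        HashMap::from([("GGML_VK_VISIBLE_DEVICES".to_string(), "0".to_string())]);

    // The whole map is applied to the child before it is spawned, so the
    // backend process starts with the user's variables set. The same map is
    // also passed to `get_devices`, keeping device discovery consistent with
    // what the spawned server will see.
    let status = Command::new("printenv")
        .arg("GGML_VK_VISIBLE_DEVICES")
        .envs(&envs)
        .status()?;
    println!("child exited with {status}");
    Ok(())
}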


use base64::{engine::general_purpose, Engine as _};
use hmac::{Hmac, Mac};
use sha2::Sha256;
use std::collections::HashMap;
use std::process::Stdio;
use std::time::Duration;
use tauri::{Manager, Runtime, State};
use tokio::io::{AsyncBufReadExt, BufReader};
use tokio::process::Command;
use tokio::sync::mpsc;
use tokio::time::Instant;
use crate::device::{get_devices_from_backend, DeviceInfo};
use crate::error::{ErrorCode, LlamacppError, ServerError, ServerResult};
use crate::path::{validate_binary_path, validate_model_path, validate_mmproj_path};
use crate::process::{
    find_session_by_model_id, get_all_active_sessions, get_all_loaded_model_ids,
    get_random_available_port, is_process_running_by_pid,
};
use crate::state::{LLamaBackendSession, LlamacppState, SessionInfo};
use jan_utils::{
    extract_arg_value, parse_port_from_args, setup_library_path, setup_windows_process_flags,
};
#[cfg(unix)]
use crate::process::graceful_terminate_process;
#[cfg(all(windows, target_arch = "x86_64"))]
use crate::process::force_terminate_process;

type HmacSha256 = Hmac<Sha256>;

#[derive(serde::Serialize, serde::Deserialize)]
pub struct UnloadResult {
    success: bool,
    error: Option<String>,
}

/// Load a llama model and start the server
#[tauri::command]
pub async fn load_llama_model<R: Runtime>(
    app_handle: tauri::AppHandle<R>,
    backend_path: &str,
    library_path: Option<&str>,
    mut args: Vec<String>,
    envs: HashMap<String, String>,
) -> ServerResult<SessionInfo> {
    let state: State<LlamacppState> = app_handle.state();
    let mut process_map = state.llama_server_process.lock().await;

    log::info!("Attempting to launch server at path: {:?}", backend_path);
    log::info!("Using arguments: {:?}", args);

    validate_binary_path(backend_path)?;
    let port = parse_port_from_args(&args);
    let model_path_pb = validate_model_path(&mut args)?;
    let _mmproj_path_pb = validate_mmproj_path(&mut args)?;

    let api_key = match envs.get("LLAMA_API_KEY") {
        Some(api_value) => api_value.to_string(),
        None => {
            log::warn!("API key not provided");
            String::new()
        }
    };

    let model_id = extract_arg_value(&args, "-a");

    // Configure the command to run the server
    let mut command = Command::new(backend_path);
    command.args(args);
    command.envs(envs);
    setup_library_path(library_path, &mut command);
    command.stdout(Stdio::piped());
    command.stderr(Stdio::piped());
    setup_windows_process_flags(&mut command);

    // Spawn the child process
    let mut child = command.spawn().map_err(ServerError::Io)?;
    let stderr = child.stderr.take().expect("stderr was piped");
    let stdout = child.stdout.take().expect("stdout was piped");

    // Create channels for communication between tasks
    let (ready_tx, mut ready_rx) = mpsc::channel::<bool>(1);

    // Spawn a task to log stdout output (readiness is detected on stderr)
    let _stdout_task = tokio::spawn(async move {
        let mut reader = BufReader::new(stdout);
        let mut byte_buffer = Vec::new();
        loop {
            byte_buffer.clear();
            match reader.read_until(b'\n', &mut byte_buffer).await {
                Ok(0) => break, // EOF
                Ok(_) => {
                    let line = String::from_utf8_lossy(&byte_buffer);
                    let line = line.trim_end();
                    if !line.is_empty() {
                        log::info!("[llamacpp stdout] {}", line);
                    }
                }
                Err(e) => {
                    log::error!("Error reading stdout: {}", e);
                    break;
                }
            }
        }
    });

    // Spawn a task to capture stderr and watch for the readiness indicator
    let stderr_task = tokio::spawn(async move {
        let mut reader = BufReader::new(stderr);
        let mut byte_buffer = Vec::new();
        let mut stderr_buffer = String::new();
        loop {
            byte_buffer.clear();
            match reader.read_until(b'\n', &mut byte_buffer).await {
                Ok(0) => break, // EOF
                Ok(_) => {
                    let line = String::from_utf8_lossy(&byte_buffer);
                    let line = line.trim_end();
                    if !line.is_empty() {
                        stderr_buffer.push_str(line);
                        stderr_buffer.push('\n');
                        log::info!("[llamacpp] {}", line);

                        // Check for readiness indicator - llama-server outputs this when ready
                        let line_lower = line.to_lowercase();
                        if line_lower.contains("server is listening on")
                            || line_lower.contains("starting the main loop")
                            || line_lower.contains("server listening on")
                        {
                            log::info!("Model appears to be ready based on logs: '{}'", line);
                            let _ = ready_tx.send(true).await;
                        }
                    }
                }
                Err(e) => {
                    log::error!("Error reading logs: {}", e);
                    break;
                }
            }
        }
        stderr_buffer
    });

    // Check if the process exited early
    if let Some(status) = child.try_wait()? {
        if !status.success() {
            let stderr_output = stderr_task.await.unwrap_or_default();
            log::error!("llama.cpp failed early with code {:?}", status);
            log::error!("{}", stderr_output);
            return Err(LlamacppError::from_stderr(&stderr_output).into());
        }
    }

    // Wait for the server to be ready, or time out
    let timeout_duration = Duration::from_secs(180); // 3 minutes timeout
    let start_time = Instant::now();
    log::info!("Waiting for model session to be ready...");

    loop {
        tokio::select! {
            // Server is ready
            Some(true) = ready_rx.recv() => {
                log::info!("Model is ready to accept requests!");
                break;
            }
            // Poll every 50 ms for process exit and timeout
            _ = tokio::time::sleep(Duration::from_millis(50)) => {
                // Check if the process exited
                if let Some(status) = child.try_wait()? {
                    let stderr_output = stderr_task.await.unwrap_or_default();
                    if !status.success() {
                        log::error!("llama.cpp exited with error code {:?}", status);
                        return Err(LlamacppError::from_stderr(&stderr_output).into());
                    } else {
                        log::error!("llama.cpp exited successfully but without ready signal");
                        return Err(LlamacppError::from_stderr(&stderr_output).into());
                    }
                }
                // Timeout check
                if start_time.elapsed() > timeout_duration {
                    log::error!("Timeout waiting for server to be ready");
                    let _ = child.kill().await;
                    let stderr_output = stderr_task.await.unwrap_or_default();
                    return Err(LlamacppError::new(
                        ErrorCode::ModelLoadTimedOut,
                        "The model took too long to load and timed out.".into(),
                        Some(format!(
                            "Timeout: {}s\n\nStderr:\n{}",
                            timeout_duration.as_secs(),
                            stderr_output
                        )),
                    ).into());
                }
            }
        }
    }

    // Use the PID as the session ID
    let pid = child.id().map(|id| id as i32).unwrap_or(-1);
    log::info!("Server process started with PID: {} and is ready", pid);

    let session_info = SessionInfo {
        pid,
        port,
        model_id,
        model_path: model_path_pb.display().to_string(),
        api_key,
    };

    // Insert the session into the process map
    process_map.insert(
        pid,
        LLamaBackendSession {
            child,
            info: session_info.clone(),
        },
    );

    Ok(session_info)
}

/// Unload a llama model by terminating its process
#[tauri::command]
pub async fn unload_llama_model<R: Runtime>(
    app_handle: tauri::AppHandle<R>,
    pid: i32,
) -> ServerResult<UnloadResult> {
    let state: State<LlamacppState> = app_handle.state();
    let mut map = state.llama_server_process.lock().await;

    if let Some(session) = map.remove(&pid) {
        let mut child = session.child;
        #[cfg(unix)]
        {
            graceful_terminate_process(&mut child).await;
        }
        #[cfg(all(windows, target_arch = "x86_64"))]
        {
            force_terminate_process(&mut child).await;
        }
        Ok(UnloadResult {
            success: true,
            error: None,
        })
    } else {
        // Unknown PID: treat the unload as a no-op success so repeated
        // unload calls stay idempotent.
        log::warn!("No server with PID '{}' found", pid);
        Ok(UnloadResult {
            success: true,
            error: None,
        })
    }
}

/// Get available devices from the llama.cpp backend
#[tauri::command]
pub async fn get_devices(
    backend_path: &str,
    library_path: Option<&str>,
    envs: HashMap<String, String>,
) -> ServerResult<Vec<DeviceInfo>> {
    get_devices_from_backend(backend_path, library_path, envs).await
}

/// Generate API key using HMAC-SHA256
#[tauri::command]
pub fn generate_api_key(model_id: String, api_secret: String) -> Result<String, String> {
    let mut mac = HmacSha256::new_from_slice(api_secret.as_bytes())
        .map_err(|e| format!("Invalid key length: {}", e))?;
    mac.update(model_id.as_bytes());
    let result = mac.finalize();
    let code_bytes = result.into_bytes();
    let hash = general_purpose::STANDARD.encode(code_bytes);
    Ok(hash)
}

/// Check if a process is still running
#[tauri::command]
pub async fn is_process_running<R: Runtime>(
    app_handle: tauri::AppHandle<R>,
    pid: i32,
) -> Result<bool, String> {
    is_process_running_by_pid(app_handle, pid).await
}

/// Get a random available port
#[tauri::command]
pub async fn get_random_port<R: Runtime>(app_handle: tauri::AppHandle<R>) -> Result<u16, String> {
    get_random_available_port(app_handle).await
}

/// Find session information by model ID
#[tauri::command]
pub async fn find_session_by_model<R: Runtime>(
    app_handle: tauri::AppHandle<R>,
    model_id: String,
) -> Result<Option<SessionInfo>, String> {
    find_session_by_model_id(app_handle, &model_id).await
}

/// Get all loaded model IDs
#[tauri::command]
pub async fn get_loaded_models<R: Runtime>(
    app_handle: tauri::AppHandle<R>,
) -> Result<Vec<String>, String> {
    get_all_loaded_model_ids(app_handle).await
}

/// Get all active sessions
#[tauri::command]
pub async fn get_all_sessions<R: Runtime>(
    app_handle: tauri::AppHandle<R>,
) -> Result<Vec<SessionInfo>, String> {
    get_all_active_sessions(app_handle).await
}

/// Get session information by model ID
#[tauri::command]
pub async fn get_session_by_model<R: Runtime>(
    app_handle: tauri::AppHandle<R>,
    model_id: String,
) -> Result<Option<SessionInfo>, String> {
    find_session_by_model_id(app_handle, &model_id).await
}