Fix: Improve Llama.cpp model path handling and error handling (#6045)

* Improve Llama.cpp model path handling and validation

This commit refactors the load_llama_model function to improve how it handles and validates the model path.

Previously, the function extracted the model path but performed no validation on it. This change adds the following improvements:

- It checks for the presence of the -m flag.
- It verifies that a path is actually provided after the -m flag.
- It validates that the specified model path exists on the filesystem.
- It stores the canonical display path of the model in the SessionInfo struct, which is a more robust approach.

These changes make the model loading process more reliable and provide better error reporting for invalid or missing model paths; the sketch below outlines the flow.
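For reference, the new flow condenses to the following sketch. It is simplified from the diff below; validate_model_path is a hypothetical helper name, and the String error type stands in for the crate's ServerError plumbing:

    // Sketch of the validation now performed inside load_llama_model.
    fn validate_model_path(args: &[String]) -> Result<std::path::PathBuf, String> {
        // 1. The -m flag must be present.
        let i = args
            .iter()
            .position(|a| a == "-m")
            .ok_or_else(|| "Missing `-m` flag".to_string())?;
        // 2. A path must follow the flag.
        let path = args
            .get(i + 1)
            .ok_or_else(|| "Missing path after `-m`".to_string())?;
        // 3. The path must exist on the filesystem.
        let pb = std::path::PathBuf::from(path);
        if !pb.exists() {
            return Err(format!("Invalid or inaccessible model path: {}", pb.display()));
        }
        Ok(pb)
    }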

* Exp: Use short path on Windows
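For context: GetShortPathNameW maps a path to its 8.3 "short" form, which is ASCII-only (the new test below asserts exactly this), so handing it to the llama-server child process sidesteps encoding problems with non-ASCII path names. A hypothetical usage of the get_short_path helper added in the diff below; the example path is made up, and 8.3 name generation must be enabled on the volume:

    // Hypothetical example: obtain an ASCII-only 8.3 alias for a
    // model path containing non-ASCII characters.
    #[cfg(windows)]
    fn short_path_example() {
        let long = std::path::Path::new(r"C:\models\тест\model.gguf");
        if let Some(short) = get_short_path(long) {
            // e.g. C:\models\XXXX~1\MODEL~1.GGU (exact form varies by volume)
            println!("short path: {}", short);
        }
    }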

* Fix: Remove error channel and handling in llama.cpp server loading

The previous implementation used a channel to receive error messages from the llama.cpp server's stdout. This proved unreliable: path names can contain the very error strings we matched against, even during normal operation. This commit removes the error channel and the associated error-handling logic.
Server readiness is still determined by checking for the "server is listening" message in stdout. Errors are now handled by relying on the process exit code and capturing the full stderr output if the process fails to start or exits unexpectedly. This approach provides more robust and accurate error detection.
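In outline, the remaining wait loop looks like this simplified sketch; the names (ready_rx, child, stderr_task, ServerError) mirror the diff below, and details such as overall timeout handling are omitted:

    // Simplified sketch: succeed on the readiness message, otherwise
    // poll the child process and report its exit via stderr capture.
    loop {
        tokio::select! {
            Some(true) = ready_rx.recv() => {
                log::info!("Server is ready to accept requests!");
                break;
            }
            _ = tokio::time::sleep(Duration::from_millis(50)) => {
                if let Some(status) = child.try_wait()? {
                    // The process died before signalling readiness:
                    // surface the exit code plus the captured stderr.
                    let stderr_output = stderr_task.await.unwrap_or_default();
                    return Err(ServerError::LlamacppError(format!(
                        "Process exited with {:?}\n\nFull stderr:\n{}",
                        status, stderr_output
                    )));
                }
            }
        }
    }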

* Add else block in Windows path handling

* Add some path-related tests

* Fix Windows tests
Akarshan Biswas authored 2025-08-05 14:17:19 +05:30, committed by GitHub
parent 99567a1102 · commit 088b9d7f25
2 changed files with 105 additions and 41 deletions


@@ -63,8 +63,12 @@ nix = "=0.30.1"
 [target.'cfg(windows)'.dependencies]
 libc = "0.2.172"
+windows-sys = { version = "0.60.2", features = ["Win32_Storage_FileSystem"] }
 
 [target.'cfg(not(any(target_os = "android", target_os = "ios")))'.dependencies]
 tauri-plugin-updater = "2"
 once_cell = "1.18"
 tauri-plugin-single-instance = { version = "2.0.0", features = ["deep-link"] }
 
+[target.'cfg(windows)'.dev-dependencies]
+tempfile = "3.20.0"


@@ -67,13 +67,39 @@ pub struct DeviceInfo {
     pub free: i32,
 }
 
+#[cfg(windows)]
+use std::os::windows::ffi::OsStrExt;
+#[cfg(windows)]
+use std::ffi::OsStr;
+#[cfg(windows)]
+use windows_sys::Win32::Storage::FileSystem::GetShortPathNameW;
+
+#[cfg(windows)]
+pub fn get_short_path<P: AsRef<std::path::Path>>(path: P) -> Option<String> {
+    let wide: Vec<u16> = OsStr::new(path.as_ref())
+        .encode_wide()
+        .chain(Some(0))
+        .collect();
+    let mut buffer = vec![0u16; 260];
+    let len = unsafe { GetShortPathNameW(wide.as_ptr(), buffer.as_mut_ptr(), buffer.len() as u32) };
+    if len > 0 {
+        Some(String::from_utf16_lossy(&buffer[..len as usize]))
+    } else {
+        None
+    }
+}
+
 // --- Load Command ---
 #[tauri::command]
 pub async fn load_llama_model(
     state: State<'_, AppState>,
     backend_path: &str,
     library_path: Option<&str>,
-    args: Vec<String>,
+    mut args: Vec<String>,
 ) -> ServerResult<SessionInfo> {
     let mut process_map = state.llama_server_process.lock().await;
@@ -105,13 +131,38 @@ pub async fn load_llama_model(
             8080
         }
     };
 
-    let model_path = args
+    // FOR MODEL PATH; TODO: DO SIMILARLY FOR MMPROJ PATH
+    let model_path_index = args
         .iter()
         .position(|arg| arg == "-m")
-        .and_then(|i| args.get(i + 1))
-        .cloned()
-        .unwrap_or_default();
+        .ok_or(ServerError::LlamacppError("Missing `-m` flag".into()))?;
+
+    let model_path = args
+        .get(model_path_index + 1)
+        .ok_or(ServerError::LlamacppError("Missing path after `-m`".into()))?
+        .clone();
+
+    let model_path_pb = PathBuf::from(model_path);
+    if !model_path_pb.exists() {
+        return Err(ServerError::LlamacppError(format!(
+            "Invalid or inaccessible model path: {}",
+            model_path_pb.display().to_string(),
+        )));
+    }
+
+    #[cfg(windows)]
+    {
+        // use short path on Windows
+        if let Some(short) = get_short_path(&model_path_pb) {
+            args[model_path_index + 1] = short;
+        } else {
+            args[model_path_index + 1] = model_path_pb.display().to_string();
+        }
+    }
+    #[cfg(not(windows))]
+    {
+        args[model_path_index + 1] = model_path_pb.display().to_string();
+    }
+    // -----------------------------------------------------------------
 
     let api_key = args
         .iter()
@@ -181,7 +232,6 @@ pub async fn load_llama_model(
     // Create channels for communication between tasks
     let (ready_tx, mut ready_rx) = mpsc::channel::<bool>(1);
-    let (error_tx, mut error_rx) = mpsc::channel::<String>(1);
 
     // Spawn task to monitor stdout for readiness
     let _stdout_task = tokio::spawn(async move {
@@ -228,20 +278,10 @@ pub async fn load_llama_model(
             // Check for critical error indicators that should stop the process
             let line_lower = line.to_string().to_lowercase();
-            if line_lower.contains("error loading model")
-                || line_lower.contains("unknown model architecture")
-                || line_lower.contains("fatal")
-                || line_lower.contains("cuda error")
-                || line_lower.contains("out of memory")
-                || line_lower.contains("error")
-                || line_lower.contains("failed")
-            {
-                let _ = error_tx.send(line.to_string()).await;
-            }
             // Check for readiness indicator - llama-server outputs this when ready
-            else if line.contains("server is listening on")
-                || line.contains("starting the main loop")
-                || line.contains("server listening on")
+            if line_lower.contains("server is listening on")
+                || line_lower.contains("starting the main loop")
+                || line_lower.contains("server listening on")
             {
                 log::info!("Server appears to be ready based on stderr: '{}'", line);
                 let _ = ready_tx.send(true).await;
@@ -279,26 +319,6 @@ pub async fn load_llama_model(
                 log::info!("Server is ready to accept requests!");
                 break;
             }
-            // Error occurred
-            Some(error_msg) = error_rx.recv() => {
-                log::error!("Server encountered an error: {}", error_msg);
-                // Give process a moment to exit naturally
-                tokio::time::sleep(Duration::from_millis(100)).await;
-                // Check if process already exited
-                if let Some(status) = child.try_wait()? {
-                    log::info!("Process exited with code {:?}", status);
-                    return Err(ServerError::LlamacppError(error_msg));
-                } else {
-                    log::info!("Process still running, killing it...");
-                    let _ = child.kill().await;
-                }
-                // Get full stderr output
-                let stderr_output = stderr_task.await.unwrap_or_default();
-                return Err(ServerError::LlamacppError(format!("Error: {}\n\nFull stderr:\n{}", error_msg, stderr_output)));
-            }
             // Check for process exit more frequently
             _ = tokio::time::sleep(Duration::from_millis(50)) => {
                 // Check if process exited
@@ -332,7 +352,7 @@ pub async fn load_llama_model(
         pid: pid.clone(),
         port: port,
         model_id: model_id,
-        model_path: model_path,
+        model_path: model_path_pb.display().to_string(),
         api_key: api_key,
     };
@@ -714,6 +734,9 @@ pub fn is_port_available(port: u16) -> bool {
 #[cfg(test)]
 mod tests {
     use super::*;
+    use std::path::PathBuf;
+    #[cfg(windows)]
+    use tempfile;
 
     #[test]
     fn test_parse_multiple_devices() {
@@ -899,4 +922,41 @@ Vulkan1: AMD Radeon Graphics (RADV GFX1151) (87722 MiB, 87722 MiB free)"#;
         let (_start, content) = result.unwrap();
         assert_eq!(content, "8128 MiB, 8128 MiB free");
     }
+
+    #[test]
+    fn test_path_with_uncommon_dir_names() {
+        const UNCOMMON_DIR_NAME: &str = "тест-你好-éàç-🚀";
+
+        #[cfg(windows)]
+        {
+            let dir = tempfile::tempdir().expect("Failed to create temp dir");
+            let long_path = dir.path().join(UNCOMMON_DIR_NAME);
+            std::fs::create_dir(&long_path)
+                .expect("Failed to create test directory with non-ASCII name");
+
+            // get_short_path returns Option<String>; unwrap it before asserting.
+            let short_path = get_short_path(&long_path).expect("Failed to get short path");
+
+            assert!(
+                short_path.is_ascii(),
+                "The resulting short path must be composed of only ASCII characters. Got: {}",
+                short_path
+            );
+            assert!(
+                PathBuf::from(&short_path).exists(),
+                "The returned short path must exist on the filesystem"
+            );
+            assert_ne!(
+                short_path,
+                long_path.to_str().unwrap(),
+                "Short path should not be the same as the long path"
+            );
+        }
+
+        #[cfg(not(windows))]
+        {
+            // On Unix, paths are typically UTF-8 and there is no "short path" concept.
+            let long_path_str = format!("/tmp/{}", UNCOMMON_DIR_NAME);
+            let path_buf = PathBuf::from(&long_path_str);
+            let displayed_path = path_buf.display().to_string();
+            assert_eq!(
+                displayed_path, long_path_str,
+                "Path with non-ASCII characters should be preserved exactly on non-Windows platforms"
+            );
+        }
+    }
 }