Fix: Improve Llama.cpp model path handling and error handling (#6045)
* Improve Llama.cpp model path handling and validation This commit refactors the load_llama_model function to improve how it handles and validates the model path. Previously, the function extracted the model path but did not perform any validation. This change adds the following improvements: It now checks for the presence of the -m flag. It verifies that a path is provided after the -m flag. It validates that the specified model path actually exists on the filesystem. It ensures that the SessionInfo struct stores the canonical display path of the model, which is a more robust approach. These changes make the model loading process more reliable and provide better error handling for invalid or missing model paths. * Exp: Use short path on Windows * Fix: Remove error channel and handling in llama.cpp server loading The previous implementation used a channel to receive error messages from the llama.cpp server's stdout. However, this proved unreliable as the path names can contain 'errors strings' that we use to check even during normal operation. This commit removes the error channel and associated error handling logic. The server readiness is still determined by checking for the "server is listening" message in stdout. Errors are now handled by relying on the process exit code and capturing the full stderr output if the process fails to start or exits unexpectedly. This approach provides a more robust and accurate error detection mechanism. * Add else block in Windows path handling * Add some path related tests * Fix windows tests
This commit is contained in:
parent
99567a1102
commit
088b9d7f25
@ -63,8 +63,12 @@ nix = "=0.30.1"
|
||||
|
||||
[target.'cfg(windows)'.dependencies]
|
||||
libc = "0.2.172"
|
||||
windows-sys = { version = "0.60.2", features = ["Win32_Storage_FileSystem"] }
|
||||
|
||||
[target.'cfg(not(any(target_os = "android", target_os = "ios")))'.dependencies]
|
||||
tauri-plugin-updater = "2"
|
||||
once_cell = "1.18"
|
||||
tauri-plugin-single-instance = { version = "2.0.0", features = ["deep-link"] }
|
||||
|
||||
[target.'cfg(windows)'.dev-dependencies]
|
||||
tempfile = "3.20.0"
|
||||
|
||||
@ -67,13 +67,39 @@ pub struct DeviceInfo {
|
||||
pub free: i32,
|
||||
}
|
||||
|
||||
#[cfg(windows)]
|
||||
use std::os::windows::ffi::OsStrExt;
|
||||
|
||||
#[cfg(windows)]
|
||||
use std::ffi::OsStr;
|
||||
|
||||
#[cfg(windows)]
|
||||
use windows_sys::Win32::Storage::FileSystem::GetShortPathNameW;
|
||||
|
||||
#[cfg(windows)]
|
||||
pub fn get_short_path<P: AsRef<std::path::Path>>(path: P) -> Option<String> {
|
||||
let wide: Vec<u16> = OsStr::new(path.as_ref())
|
||||
.encode_wide()
|
||||
.chain(Some(0))
|
||||
.collect();
|
||||
|
||||
let mut buffer = vec![0u16; 260];
|
||||
let len = unsafe { GetShortPathNameW(wide.as_ptr(), buffer.as_mut_ptr(), buffer.len() as u32) };
|
||||
|
||||
if len > 0 {
|
||||
Some(String::from_utf16_lossy(&buffer[..len as usize]))
|
||||
} else {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
// --- Load Command ---
|
||||
#[tauri::command]
|
||||
pub async fn load_llama_model(
|
||||
state: State<'_, AppState>,
|
||||
backend_path: &str,
|
||||
library_path: Option<&str>,
|
||||
args: Vec<String>,
|
||||
mut args: Vec<String>,
|
||||
) -> ServerResult<SessionInfo> {
|
||||
let mut process_map = state.llama_server_process.lock().await;
|
||||
|
||||
@ -105,13 +131,38 @@ pub async fn load_llama_model(
|
||||
8080
|
||||
}
|
||||
};
|
||||
|
||||
let model_path = args
|
||||
// FOR MODEL PATH; TODO: DO SIMILARLY FOR MMPROJ PATH
|
||||
let model_path_index = args
|
||||
.iter()
|
||||
.position(|arg| arg == "-m")
|
||||
.and_then(|i| args.get(i + 1))
|
||||
.cloned()
|
||||
.unwrap_or_default();
|
||||
.ok_or(ServerError::LlamacppError("Missing `-m` flag".into()))?;
|
||||
|
||||
let model_path = args
|
||||
.get(model_path_index + 1)
|
||||
.ok_or(ServerError::LlamacppError("Missing path after `-m`".into()))?
|
||||
.clone();
|
||||
|
||||
let model_path_pb = PathBuf::from(model_path);
|
||||
if !model_path_pb.exists() {
|
||||
return Err(ServerError::LlamacppError(format!(
|
||||
"Invalid or inaccessible model path: {}",
|
||||
model_path_pb.display().to_string(),
|
||||
)));
|
||||
}
|
||||
#[cfg(windows)]
|
||||
{
|
||||
// use short path on Windows
|
||||
if let Some(short) = get_short_path(&model_path_pb) {
|
||||
args[model_path_index + 1] = short;
|
||||
} else {
|
||||
args[model_path_index + 1] = model_path_pb.display().to_string();
|
||||
}
|
||||
}
|
||||
#[cfg(not(windows))]
|
||||
{
|
||||
args[model_path_index + 1] = model_path_pb.display().to_string();
|
||||
}
|
||||
// -----------------------------------------------------------------
|
||||
|
||||
let api_key = args
|
||||
.iter()
|
||||
@ -181,7 +232,6 @@ pub async fn load_llama_model(
|
||||
|
||||
// Create channels for communication between tasks
|
||||
let (ready_tx, mut ready_rx) = mpsc::channel::<bool>(1);
|
||||
let (error_tx, mut error_rx) = mpsc::channel::<String>(1);
|
||||
|
||||
// Spawn task to monitor stdout for readiness
|
||||
let _stdout_task = tokio::spawn(async move {
|
||||
@ -228,20 +278,10 @@ pub async fn load_llama_model(
|
||||
|
||||
// Check for critical error indicators that should stop the process
|
||||
let line_lower = line.to_string().to_lowercase();
|
||||
if line_lower.contains("error loading model")
|
||||
|| line_lower.contains("unknown model architecture")
|
||||
|| line_lower.contains("fatal")
|
||||
|| line_lower.contains("cuda error")
|
||||
|| line_lower.contains("out of memory")
|
||||
|| line_lower.contains("error")
|
||||
|| line_lower.contains("failed")
|
||||
{
|
||||
let _ = error_tx.send(line.to_string()).await;
|
||||
}
|
||||
// Check for readiness indicator - llama-server outputs this when ready
|
||||
else if line.contains("server is listening on")
|
||||
|| line.contains("starting the main loop")
|
||||
|| line.contains("server listening on")
|
||||
if line_lower.contains("server is listening on")
|
||||
|| line_lower.contains("starting the main loop")
|
||||
|| line_lower.contains("server listening on")
|
||||
{
|
||||
log::info!("Server appears to be ready based on stderr: '{}'", line);
|
||||
let _ = ready_tx.send(true).await;
|
||||
@ -279,26 +319,6 @@ pub async fn load_llama_model(
|
||||
log::info!("Server is ready to accept requests!");
|
||||
break;
|
||||
}
|
||||
// Error occurred
|
||||
Some(error_msg) = error_rx.recv() => {
|
||||
log::error!("Server encountered an error: {}", error_msg);
|
||||
|
||||
// Give process a moment to exit naturally
|
||||
tokio::time::sleep(Duration::from_millis(100)).await;
|
||||
|
||||
// Check if process already exited
|
||||
if let Some(status) = child.try_wait()? {
|
||||
log::info!("Process exited with code {:?}", status);
|
||||
return Err(ServerError::LlamacppError(error_msg));
|
||||
} else {
|
||||
log::info!("Process still running, killing it...");
|
||||
let _ = child.kill().await;
|
||||
}
|
||||
|
||||
// Get full stderr output
|
||||
let stderr_output = stderr_task.await.unwrap_or_default();
|
||||
return Err(ServerError::LlamacppError(format!("Error: {}\n\nFull stderr:\n{}", error_msg, stderr_output)));
|
||||
}
|
||||
// Check for process exit more frequently
|
||||
_ = tokio::time::sleep(Duration::from_millis(50)) => {
|
||||
// Check if process exited
|
||||
@ -332,7 +352,7 @@ pub async fn load_llama_model(
|
||||
pid: pid.clone(),
|
||||
port: port,
|
||||
model_id: model_id,
|
||||
model_path: model_path,
|
||||
model_path: model_path_pb.display().to_string(),
|
||||
api_key: api_key,
|
||||
};
|
||||
|
||||
@ -714,6 +734,9 @@ pub fn is_port_available(port: u16) -> bool {
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::path::PathBuf;
|
||||
#[cfg(windows)]
|
||||
use tempfile;
|
||||
|
||||
#[test]
|
||||
fn test_parse_multiple_devices() {
|
||||
@ -899,4 +922,41 @@ Vulkan1: AMD Radeon Graphics (RADV GFX1151) (87722 MiB, 87722 MiB free)"#;
|
||||
let (_start, content) = result.unwrap();
|
||||
assert_eq!(content, "8128 MiB, 8128 MiB free");
|
||||
}
|
||||
#[test]
|
||||
fn test_path_with_uncommon_dir_names() {
|
||||
const UNCOMMON_DIR_NAME: &str = "тест-你好-éàç-🚀";
|
||||
#[cfg(windows)]
|
||||
{
|
||||
let dir = tempfile::tempdir().expect("Failed to create temp dir");
|
||||
let long_path = dir.path().join(UNCOMMON_DIR_NAME);
|
||||
std::fs::create_dir(&long_path)
|
||||
.expect("Failed to create test directory with non-ASCII name");
|
||||
let short_path = get_short_path(&long_path);
|
||||
assert!(
|
||||
short_path.is_ascii(),
|
||||
"The resulting short path must be composed of only ASCII characters. Got: {}",
|
||||
short_path
|
||||
);
|
||||
assert!(
|
||||
PathBuf::from(&short_path).exists(),
|
||||
"The returned short path must exist on the filesystem"
|
||||
);
|
||||
assert_ne!(
|
||||
short_path,
|
||||
long_path.to_str().unwrap(),
|
||||
"Short path should not be the same as the long path"
|
||||
);
|
||||
}
|
||||
#[cfg(not(windows))]
|
||||
{
|
||||
// On Unix, paths are typically UTF-8 and there's no "short path" concept.
|
||||
let long_path_str = format!("/tmp/{}", UNCOMMON_DIR_NAME);
|
||||
let path_buf = PathBuf::from(&long_path_str);
|
||||
let displayed_path = path_buf.display().to_string();
|
||||
assert_eq!(
|
||||
displayed_path, long_path_str,
|
||||
"Path with non-ASCII characters should be preserved exactly on non-Windows platforms"
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user