Dinh Long Nguyen 5cd81bc6e8
feat: improve testing (#6395)
* add more test rust test

* fix servicehub test

* fix tauri failing on windows
2025-09-09 12:16:25 +07:00

413 lines
13 KiB
Rust

use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::process::Stdio;
use std::time::Duration;
use tokio::process::Command;
use tokio::time::timeout;
use crate::error::{ErrorCode, LlamacppError, ServerError, ServerResult};
use crate::path::validate_binary_path;
use jan_utils::{setup_library_path, setup_windows_process_flags};
/// A compute device reported by the backend's `--list-devices` output.
///
/// Built by `parse_device_line` from lines such as
/// `"CUDA0: NVIDIA GeForce RTX 4090 (24576 MiB, 24000 MiB free)"`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DeviceInfo {
/// Backend device identifier: the text before the first ':' (e.g. `"Vulkan0"`, `"CUDA0"`).
pub id: String,
/// Device description: the text between the ID and the trailing memory summary.
pub name: String,
/// Total device memory in MiB, as reported by the backend.
pub mem: i32,
/// Free device memory in MiB, as reported by the backend when the list was produced.
pub free: i32,
}
/// Runs `<backend_path> --list-devices` and parses the devices it reports.
///
/// # Arguments
/// * `backend_path` - path to the llama-server binary; checked with `validate_binary_path`
///   before anything is spawned.
/// * `library_path` - optional library directory handed to `setup_library_path` for the
///   child process.
/// * `envs` - extra environment variables set on the child process.
///
/// # Errors
/// * the error from `validate_binary_path` if the binary path is rejected;
/// * `ErrorCode::InternalError` if the process does not finish within 30 seconds
///   (the child is killed in that case);
/// * `ServerError::Io` if the process cannot be spawned or awaited;
/// * an error built from the process's stderr if it exits unsuccessfully;
/// * any error from `parse_device_output` (e.g. the device section is missing).
pub async fn get_devices_from_backend(
    backend_path: &str,
    library_path: Option<&str>,
    envs: HashMap<String, String>,
) -> ServerResult<Vec<DeviceInfo>> {
    // Upper bound on how long `--list-devices` may run before we give up.
    const LIST_DEVICES_TIMEOUT: Duration = Duration::from_secs(30);

    log::info!("Getting devices from server at path: {:?}", backend_path);
    validate_binary_path(backend_path)?;

    // Configure the command to run the server with --list-devices
    let mut command = Command::new(backend_path);
    command.arg("--list-devices");
    command.envs(envs);
    // Set up library path
    setup_library_path(library_path, &mut command);
    command.stdout(Stdio::piped());
    command.stderr(Stdio::piped());
    // If the timeout below fires, the `output()` future and the child handle it owns are
    // dropped. Tokio does not kill a child on drop by default, so without this a hung
    // backend would keep running in the background after we have reported the timeout.
    command.kill_on_drop(true);
    setup_windows_process_flags(&mut command);

    // Execute the command and wait for completion
    let output = timeout(LIST_DEVICES_TIMEOUT, command.output())
        .await
        .map_err(|_| {
            LlamacppError::new(
                ErrorCode::InternalError,
                "Timeout waiting for device list".into(),
                None,
            )
        })?
        .map_err(ServerError::Io)?;

    // Check if command executed successfully
    if !output.status.success() {
        let stderr = String::from_utf8_lossy(&output.stderr);
        log::error!("llama-server --list-devices failed: {}", stderr);
        return Err(LlamacppError::from_stderr(&stderr).into());
    }

    // Parse the output
    let stdout = String::from_utf8_lossy(&output.stdout);
    log::info!("Device list output:\n{}", stdout);
    parse_device_output(&stdout)
}
fn parse_device_output(output: &str) -> ServerResult<Vec<DeviceInfo>> {
let mut devices = Vec::new();
let mut found_devices_section = false;
for raw in output.lines() {
// detect header (ignoring whitespace)
if raw.trim() == "Available devices:" {
found_devices_section = true;
continue;
}
if !found_devices_section {
continue;
}
// skip blank lines
if raw.trim().is_empty() {
continue;
}
// now parse any non-blank line after the header
let line = raw.trim();
if let Some(device) = parse_device_line(line)? {
devices.push(device);
}
}
if devices.is_empty() && found_devices_section {
log::warn!("No devices found in output");
} else if !found_devices_section {
return Err(LlamacppError::new(
ErrorCode::DeviceListParseFailed,
"Could not find 'Available devices:' section in the backend output.".into(),
Some(output.to_string()),
)
.into());
}
Ok(devices)
}
/// Parses one device line of the form `"<id>: <name> (<total> MiB, <free> MiB free)"`.
///
/// Returns `Ok(None)` (after logging a warning) when the line has no ':' or no
/// recognisable trailing memory summary; otherwise returns the parsed device.
/// Examples of accepted lines:
///   "Vulkan0: Intel(R) Arc(tm) A750 Graphics (DG2) (8128 MiB, 8128 MiB free)"
///   "CUDA0: NVIDIA GeForce RTX 4090 (24576 MiB, 24000 MiB free)"
///   "SYCL0: Intel(R) Arc(TM) A750 Graphics (8000 MiB, 7721 MiB free)"
fn parse_device_line(line: &str) -> ServerResult<Option<DeviceInfo>> {
    let line = line.trim();
    log::info!("Parsing device line: '{}'", line);

    // The ID is everything before the first ':'; the remainder holds name + memory.
    let (raw_id, raw_rest) = match line.split_once(':') {
        Some(halves) => halves,
        None => {
            log::warn!("Skipping malformed device line: {}", line);
            return Ok(None);
        }
    };
    let id = raw_id.trim();
    let rest = raw_rest.trim();

    let device = find_memory_pattern(rest).and_then(|(memory_start, memory_content)| {
        // memory_content looks like "8128 MiB, 8128 MiB free".
        let mut fields = memory_content.split(',');
        let (total_field, free_field) = (fields.next()?, fields.next()?);
        // Both fields are parsed before either result is inspected.
        match (
            parse_memory_value(total_field.trim()),
            parse_memory_value(free_field.trim()),
        ) {
            (Ok(mem), Ok(free)) => {
                let name = rest[..memory_start].trim();
                log::info!(
                    "Parsed device - ID: '{}', Name: '{}', Mem: {}, Free: {}",
                    id,
                    name,
                    mem,
                    free
                );
                Some(DeviceInfo {
                    id: id.to_string(),
                    name: name.to_string(),
                    mem,
                    free,
                })
            }
            _ => None,
        }
    });

    if device.is_none() {
        log::warn!("Could not parse device line: {}", line);
    }
    Ok(device)
}
/// Locates the right-most parenthesised group in `text` whose contents are a memory summary.
///
/// Each candidate group runs from a `(` to the first `)` after it and is accepted only if
/// `is_memory_pattern` approves its contents. Returns the byte offset of the group's
/// opening `(` within `text` together with the text between the parentheses, or `None`
/// when no group qualifies.
fn find_memory_pattern(text: &str) -> Option<(usize, &str)> {
    // Walking the opening parentheses right-to-left makes the first qualifying group
    // the right-most one.
    text.char_indices()
        .rev()
        .filter(|&(_, ch)| ch == '(')
        .find_map(move |(open_idx, _)| {
            let after_open = &text[open_idx + 1..];
            let close_offset = after_open.find(')')?;
            let inner = &after_open[..close_offset];
            if is_memory_pattern(inner) {
                Some((open_idx, inner))
            } else {
                None
            }
        })
}
/// Returns `true` if `content` (the text between a pair of parentheses) is a memory
/// summary of the exact form `"<total> MiB, <free> MiB free"`.
///
/// Rules:
/// * There must be exactly two comma-separated fields, with the total first and the
///   `free` field second. A swapped or extra field is rejected, so callers can rely on
///   field order.
/// * Each number must be a non-negative integer that fits in an `i32`, matching what
///   `parse_memory_value` can later parse.
/// * Whitespace between tokens is flexible, but the unit must be exactly `MiB`.
fn is_memory_pattern(content: &str) -> bool {
    /// True when `field`'s whitespace-separated tokens are a non-negative `i32`
    /// followed by exactly the tokens in `suffix`.
    fn field_matches(field: &str, suffix: &[&str]) -> bool {
        let mut tokens = field.split_whitespace();
        let number_ok = matches!(tokens.next().map(|t| t.parse::<i32>()), Some(Ok(n)) if n >= 0);
        number_ok && tokens.eq(suffix.iter().copied())
    }

    let mut fields = content.split(',');
    let (total, free) = match (fields.next(), fields.next(), fields.next()) {
        (Some(total), Some(free), None) => (total, free),
        _ => return false,
    };

    field_matches(total, &["MiB"]) && field_matches(free, &["MiB", "free"])
}
/// Parses the leading integer of a memory field such as `"8000 MiB"` or `"7721 MiB free"`.
///
/// Only the first whitespace-separated token is parsed; any following tokens (the unit,
/// the `free` marker) are ignored.
///
/// # Errors
/// Returns `ErrorCode::DeviceListParseFailed` when `mem_str` has no tokens at all, or when
/// its first token is not a valid `i32`.
fn parse_memory_value(mem_str: &str) -> ServerResult<i32> {
    let number_str = match mem_str.split_whitespace().next() {
        Some(token) => token,
        None => {
            return Err(LlamacppError::new(
                ErrorCode::DeviceListParseFailed,
                format!("empty memory value: {}", mem_str),
                None,
            )
            .into())
        }
    };

    match number_str.parse::<i32>() {
        Ok(value) => Ok(value),
        Err(_) => Err(LlamacppError::new(
            ErrorCode::DeviceListParseFailed,
            format!("Could not parse memory value: '{}'", number_str),
            None,
        )
        .into()),
    }
}
// Unit tests for the `--list-devices` output parser. Only the pure parsing helpers are
// exercised; `get_devices_from_backend` spawns a process and is not covered here.
#[cfg(test)]
mod tests {
use super::*;
// Well-formed "<total> MiB, <free> MiB free" summaries are accepted, including zeros.
#[test]
fn test_is_memory_pattern_valid() {
assert!(is_memory_pattern("8128 MiB, 8128 MiB free"));
assert!(is_memory_pattern("1024 MiB, 512 MiB free"));
assert!(is_memory_pattern("16384 MiB, 12000 MiB free"));
assert!(is_memory_pattern("0 MiB, 0 MiB free"));
}
// Wrong units, missing comma or fields, a non-`free` second field, and non-numeric
// values are all rejected.
#[test]
fn test_is_memory_pattern_invalid() {
assert!(!is_memory_pattern("8128 MB, 8128 MB free")); // Wrong unit
assert!(!is_memory_pattern("8128 MiB 8128 MiB free")); // Missing comma
assert!(!is_memory_pattern("8128 MiB, 8128 MiB used")); // Wrong second part
assert!(!is_memory_pattern("not_a_number MiB, 8128 MiB free")); // Invalid number
assert!(!is_memory_pattern("8128 MiB")); // Missing second part
assert!(!is_memory_pattern("")); // Empty string
assert!(!is_memory_pattern("8128 MiB, free")); // Missing number in second part
}
// The memory group is found even when the device name contains its own
// parenthesised text, and its returned start offset is past the name.
#[test]
fn test_find_memory_pattern() {
let text = "Intel(R) Arc(tm) A750 Graphics (DG2) (8128 MiB, 4096 MiB free)";
let result = find_memory_pattern(text);
assert!(result.is_some());
let (start_idx, content) = result.unwrap();
assert!(start_idx > 0);
assert_eq!(content, "8128 MiB, 4096 MiB free");
}
// When several memory groups appear, the right-most one is returned.
#[test]
fn test_find_memory_pattern_multiple_parentheses() {
let text = "Device (test) with (1024 MiB, 512 MiB free) and (2048 MiB, 1024 MiB free)";
let result = find_memory_pattern(text);
assert!(result.is_some());
let (_, content) = result.unwrap();
// Should return the LAST valid memory pattern
assert_eq!(content, "2048 MiB, 1024 MiB free");
}
// Text with no parentheses, or with only non-memory groups, yields None.
#[test]
fn test_find_memory_pattern_no_match() {
let text = "No memory info here";
assert!(find_memory_pattern(text).is_none());
let text_with_invalid = "Some text (invalid memory info) here";
assert!(find_memory_pattern(text_with_invalid).is_none());
}
// The leading integer is parsed; trailing unit / `free` tokens are ignored.
#[test]
fn test_parse_memory_value() {
assert_eq!(parse_memory_value("8128 MiB").unwrap(), 8128);
assert_eq!(parse_memory_value("7721 MiB free").unwrap(), 7721);
assert_eq!(parse_memory_value("0 MiB").unwrap(), 0);
assert_eq!(parse_memory_value("24576 MiB").unwrap(), 24576);
}
// Empty, whitespace-only, and non-numeric inputs are errors.
#[test]
fn test_parse_memory_value_invalid() {
assert!(parse_memory_value("").is_err());
assert!(parse_memory_value("not_a_number MiB").is_err());
assert!(parse_memory_value(" ").is_err());
}
// Vulkan line: the "(DG2)" group stays part of the name; only the memory group is split off.
#[test]
fn test_parse_device_line_vulkan() {
let line = "Vulkan0: Intel(R) Arc(tm) A750 Graphics (DG2) (8128 MiB, 8128 MiB free)";
let result = parse_device_line(line).unwrap();
assert!(result.is_some());
let device = result.unwrap();
assert_eq!(device.id, "Vulkan0");
assert_eq!(device.name, "Intel(R) Arc(tm) A750 Graphics (DG2)");
assert_eq!(device.mem, 8128);
assert_eq!(device.free, 8128);
}
// CUDA line with differing total and free values.
#[test]
fn test_parse_device_line_cuda() {
let line = "CUDA0: NVIDIA GeForce RTX 4090 (24576 MiB, 24000 MiB free)";
let result = parse_device_line(line).unwrap();
assert!(result.is_some());
let device = result.unwrap();
assert_eq!(device.id, "CUDA0");
assert_eq!(device.name, "NVIDIA GeForce RTX 4090");
assert_eq!(device.mem, 24576);
assert_eq!(device.free, 24000);
}
// SYCL line format.
#[test]
fn test_parse_device_line_sycl() {
let line = "SYCL0: Intel(R) Arc(TM) A750 Graphics (8000 MiB, 7721 MiB free)";
let result = parse_device_line(line).unwrap();
assert!(result.is_some());
let device = result.unwrap();
assert_eq!(device.id, "SYCL0");
assert_eq!(device.name, "Intel(R) Arc(TM) A750 Graphics");
assert_eq!(device.mem, 8000);
assert_eq!(device.free, 7721);
}
// Malformed lines are skipped (Ok(None)) rather than reported as errors.
#[test]
fn test_parse_device_line_malformed() {
// Missing colon
let result = parse_device_line("Vulkan0 Intel Graphics (8128 MiB, 8128 MiB free)").unwrap();
assert!(result.is_none());
// Missing memory info
let result = parse_device_line("Vulkan0: Intel Graphics").unwrap();
assert!(result.is_none());
// Invalid memory format
let result = parse_device_line("Vulkan0: Intel Graphics (invalid memory)").unwrap();
assert!(result.is_none());
}
// Text before the header is ignored. After it, unparseable lines (the footer) are
// dropped, and device order is preserved.
#[test]
fn test_parse_device_output_valid() {
let output = r#"
Some header text
Available devices:
Vulkan0: Intel(R) Arc(tm) A750 Graphics (DG2) (8128 MiB, 8128 MiB free)
CUDA0: NVIDIA GeForce RTX 4090 (24576 MiB, 24000 MiB free)
SYCL0: Intel(R) Arc(TM) A750 Graphics (8000 MiB, 7721 MiB free)
Some footer text
"#;
let result = parse_device_output(output).unwrap();
assert_eq!(result.len(), 3);
assert_eq!(result[0].id, "Vulkan0");
assert_eq!(result[0].name, "Intel(R) Arc(tm) A750 Graphics (DG2)");
assert_eq!(result[0].mem, 8128);
assert_eq!(result[1].id, "CUDA0");
assert_eq!(result[1].name, "NVIDIA GeForce RTX 4090");
assert_eq!(result[1].mem, 24576);
assert_eq!(result[2].id, "SYCL0");
assert_eq!(result[2].name, "Intel(R) Arc(TM) A750 Graphics");
assert_eq!(result[2].mem, 8000);
}
// A missing "Available devices:" header is an error.
#[test]
fn test_parse_device_output_no_devices_section() {
let output = "Some output without Available devices section";
let result = parse_device_output(output);
assert!(result.is_err());
}
// A header followed by no parseable device lines yields an empty list, not an error.
#[test]
fn test_parse_device_output_empty_devices() {
let output = r#"
Some header text
Available devices:
Some footer text
"#;
let result = parse_device_output(output).unwrap();
assert_eq!(result.len(), 0);
}
// Unparseable lines between valid ones are skipped without affecting the rest.
#[test]
fn test_parse_device_output_mixed_valid_invalid() {
let output = r#"
Available devices:
Vulkan0: Intel(R) Arc(tm) A750 Graphics (DG2) (8128 MiB, 8128 MiB free)
InvalidLine: No memory info
CUDA0: NVIDIA GeForce RTX 4090 (24576 MiB, 24000 MiB free)
AnotherInvalid
"#;
let result = parse_device_output(output).unwrap();
assert_eq!(result.len(), 2); // Only valid lines should be parsed
assert_eq!(result[0].id, "Vulkan0");
assert_eq!(result[1].id, "CUDA0");
}
}