* feat: add getTokensCount method to compute token usage

  Implemented a new async `getTokensCount` function in the LLaMA.cpp extension. The method validates the model session, checks process health, applies the request template, and tokenizes the resulting prompt to return the token count. Includes detailed error handling for crashed models and API failures, enabling callers to assess token usage before sending completions.

* fix: typos

* chore: update ui token usage

* chore: remove unused code

* feat: add image token handling for multimodal LlamaCPP models

  Implemented support for counting image tokens when using vision-enabled models:

  - Extended `SessionInfo` with an optional `mmprojPath` to store the multimodal projector file.
  - Propagated `mmproj_path` from the Tauri plugin into the session info.
  - Added an import of `chatCompletionRequestMessage` and enhanced the token calculation logic in the LlamaCPP extension:
    - Detects image content in messages.
    - Reads GGUF metadata from `mmprojPath` to compute accurate image token counts.
    - Provides a fallback estimation if metadata reading fails.
    - Returns the sum of text and image tokens.
  - Introduced helper methods `calculateImageTokens` and `estimateImageTokensFallback`.
  - Minor clean-ups such as comment capitalization and debug logging.

* chore: update FE send params so message content includes type image_url

* fix: mmproj path from session info and num tokens calculation

* fix: correct image token estimation calculation in llamacpp extension (see the sketch below)

  This commit addresses an inaccurate token count for images in the llama.cpp extension. The previous logic incorrectly calculated the token count based on image patch size and dimensions. It has been replaced with a more precise method that uses the clip.vision.projection_dim value from the model metadata. Additionally, unnecessary debug logging was removed, and a new log was added to show the mmproj metadata for improved visibility.

* fix: per-image calc

* fix: crash due to force unwrap

---------

Co-authored-by: Faisal Amir <urmauur@gmail.com>
Co-authored-by: Louis <louis@jan.ai>
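The image-token handling described above lives in the TypeScript LlamaCPP extension (`calculateImageTokens` / `estimateImageTokensFallback`); the Rust file below only carries `mmproj_path` through `SessionInfo`. As a minimal Rust sketch of the idea the commits describe: read the `clip.vision.projection_dim` entry from the mmproj GGUF metadata, fall back to a fixed estimate when the metadata cannot be read, and report text tokens plus per-image tokens. The helper names, the fallback value of 576, and the assumption that `projection_dim` maps directly to a per-image token count are illustrative assumptions, not the extension's actual implementation.

```rust
use std::collections::HashMap;

/// Fallback used when the mmproj metadata cannot be read (assumed value).
const FALLBACK_IMAGE_TOKENS: u32 = 576;

/// Estimate how many tokens a single image contributes, based on the
/// `clip.vision.projection_dim` entry of the mmproj GGUF metadata.
fn estimate_image_tokens(metadata: Option<&HashMap<String, String>>) -> u32 {
    metadata
        .and_then(|m| m.get("clip.vision.projection_dim"))
        .and_then(|v| v.parse::<u32>().ok())
        .unwrap_or(FALLBACK_IMAGE_TOKENS)
}

/// Total prompt tokens = text tokens plus the per-image estimate for every
/// image attached to the request.
fn total_prompt_tokens(
    text_tokens: u32,
    image_count: u32,
    metadata: Option<&HashMap<String, String>>,
) -> u32 {
    text_tokens + image_count * estimate_image_tokens(metadata)
}

fn main() {
    // With no readable mmproj metadata, the fallback estimate is used.
    assert_eq!(total_prompt_tokens(120, 2, None), 120 + 2 * FALLBACK_IMAGE_TOKENS);
}
```

In the real extension, the text-token count comes from applying the request template and tokenizing the resulting prompt against the running llama-server session, as described in the `getTokensCount` commit above.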
347 lines
11 KiB
Rust
use base64::{engine::general_purpose, Engine as _};
use hmac::{Hmac, Mac};
use sha2::Sha256;
use std::collections::HashMap;
use std::process::Stdio;
use std::time::Duration;
use tauri::{Manager, Runtime, State};
use tokio::io::{AsyncBufReadExt, BufReader};
use tokio::process::Command;
use tokio::sync::mpsc;
use tokio::time::Instant;

use crate::device::{get_devices_from_backend, DeviceInfo};
use crate::error::{ErrorCode, LlamacppError, ServerError, ServerResult};
use crate::path::{validate_binary_path, validate_mmproj_path, validate_model_path};
use crate::process::{
    find_session_by_model_id, get_all_active_sessions, get_all_loaded_model_ids,
    get_random_available_port, is_process_running_by_pid,
};
use crate::state::{LLamaBackendSession, LlamacppState, SessionInfo};
use jan_utils::{
    extract_arg_value, parse_port_from_args, setup_library_path, setup_windows_process_flags,
};

#[cfg(unix)]
use crate::process::graceful_terminate_process;

#[cfg(all(windows, target_arch = "x86_64"))]
use crate::process::force_terminate_process;

type HmacSha256 = Hmac<Sha256>;

#[derive(serde::Serialize, serde::Deserialize)]
pub struct UnloadResult {
    success: bool,
    error: Option<String>,
}

/// Load a llama model and start the server
#[tauri::command]
pub async fn load_llama_model<R: Runtime>(
    app_handle: tauri::AppHandle<R>,
    backend_path: &str,
    library_path: Option<&str>,
    mut args: Vec<String>,
    envs: HashMap<String, String>,
) -> ServerResult<SessionInfo> {
    let state: State<LlamacppState> = app_handle.state();
    let mut process_map = state.llama_server_process.lock().await;

    log::info!("Attempting to launch server at path: {:?}", backend_path);
    log::info!("Using arguments: {:?}", args);

    validate_binary_path(backend_path)?;

    let port = parse_port_from_args(&args);
    let model_path_pb = validate_model_path(&mut args)?;
    let mmproj_path_pb = validate_mmproj_path(&mut args)?;

    let mmproj_path_string = if mmproj_path_pb.is_some() {
        // Find the actual mmproj path from args after validation/conversion.
        // `get` avoids a panic if "--mmproj" happens to be the last argument.
        args.iter()
            .position(|arg| arg == "--mmproj")
            .and_then(|idx| args.get(idx + 1).cloned())
    } else {
        None
    };

    log::info!(
        "MMPROJ Path string: {}",
        mmproj_path_string.as_deref().unwrap_or("None")
    );

    let api_key = match envs.get("LLAMA_API_KEY") {
        Some(api_value) => api_value.to_string(),
        None => {
            log::warn!("API key not provided");
            String::new()
        }
    };

    let model_id = extract_arg_value(&args, "-a");

    // Configure the command to run the server
    let mut command = Command::new(backend_path);
    command.args(args);
    command.envs(envs);

    setup_library_path(library_path, &mut command);
    command.stdout(Stdio::piped());
    command.stderr(Stdio::piped());
    setup_windows_process_flags(&mut command);

    // Spawn the child process
    let mut child = command.spawn().map_err(ServerError::Io)?;

    let stderr = child.stderr.take().expect("stderr was piped");
    let stdout = child.stdout.take().expect("stdout was piped");

    // Create channels for communication between tasks
    let (ready_tx, mut ready_rx) = mpsc::channel::<bool>(1);

    // Spawn task to monitor stdout for readiness
    let _stdout_task = tokio::spawn(async move {
        let mut reader = BufReader::new(stdout);
        let mut byte_buffer = Vec::new();

        loop {
            byte_buffer.clear();
            match reader.read_until(b'\n', &mut byte_buffer).await {
                Ok(0) => break, // EOF
                Ok(_) => {
                    let line = String::from_utf8_lossy(&byte_buffer);
                    let line = line.trim_end();
                    if !line.is_empty() {
                        log::info!("[llamacpp stdout] {}", line);
                    }
                }
                Err(e) => {
                    log::error!("Error reading stdout: {}", e);
                    break;
                }
            }
        }
    });

    // Spawn task to capture stderr and monitor for errors
    let stderr_task = tokio::spawn(async move {
        let mut reader = BufReader::new(stderr);
        let mut byte_buffer = Vec::new();
        let mut stderr_buffer = String::new();

        loop {
            byte_buffer.clear();
            match reader.read_until(b'\n', &mut byte_buffer).await {
                Ok(0) => break, // EOF
                Ok(_) => {
                    let line = String::from_utf8_lossy(&byte_buffer);
                    let line = line.trim_end();

                    if !line.is_empty() {
                        stderr_buffer.push_str(line);
                        stderr_buffer.push('\n');
                        log::info!("[llamacpp] {}", line);

                        // Check for readiness indicator - llama-server outputs this when ready
                        let line_lower = line.to_lowercase();
                        if line_lower.contains("server is listening on")
                            || line_lower.contains("starting the main loop")
                            || line_lower.contains("server listening on")
                        {
                            log::info!("Model appears to be ready based on logs: '{}'", line);
                            let _ = ready_tx.send(true).await;
                        }
                    }
                }
                Err(e) => {
                    log::error!("Error reading logs: {}", e);
                    break;
                }
            }
        }

        stderr_buffer
    });

    // Check if process exited early
    if let Some(status) = child.try_wait()? {
        if !status.success() {
            let stderr_output = stderr_task.await.unwrap_or_default();
            log::error!("llama.cpp failed early with code {:?}", status);
            log::error!("{}", stderr_output);
            return Err(LlamacppError::from_stderr(&stderr_output).into());
        }
    }

    // Wait for server to be ready or timeout
    let timeout_duration = Duration::from_secs(300); // 5 minutes timeout
    let start_time = Instant::now();
    log::info!("Waiting for model session to be ready...");
    loop {
        tokio::select! {
            // Server is ready
            Some(true) = ready_rx.recv() => {
                log::info!("Model is ready to accept requests!");
                break;
            }
            // Check for process exit more frequently
            _ = tokio::time::sleep(Duration::from_millis(50)) => {
                // Check if process exited
                if let Some(status) = child.try_wait()? {
                    let stderr_output = stderr_task.await.unwrap_or_default();
                    if !status.success() {
                        log::error!("llama.cpp exited with error code {:?}", status);
                    } else {
                        log::error!("llama.cpp exited successfully but without ready signal");
                    }
                    return Err(LlamacppError::from_stderr(&stderr_output).into());
                }

                // Timeout check
                if start_time.elapsed() > timeout_duration {
                    log::error!("Timeout waiting for server to be ready");
                    let _ = child.kill().await;
                    let stderr_output = stderr_task.await.unwrap_or_default();
                    return Err(LlamacppError::new(
                        ErrorCode::ModelLoadTimedOut,
                        "The model took too long to load and timed out.".into(),
                        Some(format!(
                            "Timeout: {}s\n\nStderr:\n{}",
                            timeout_duration.as_secs(),
                            stderr_output
                        )),
                    )
                    .into());
                }
            }
        }
    }

    // Get the PID to use as session ID
    let pid = child.id().map(|id| id as i32).unwrap_or(-1);

    log::info!("Server process started with PID: {} and is ready", pid);
    let session_info = SessionInfo {
        pid,
        port,
        model_id,
        model_path: model_path_pb.display().to_string(),
        api_key,
        mmproj_path: mmproj_path_string,
    };

    // Insert session info into process_map
    process_map.insert(
        pid,
        LLamaBackendSession {
            child,
            info: session_info.clone(),
        },
    );

    Ok(session_info)
}

/// Unload a llama model by terminating its process
#[tauri::command]
pub async fn unload_llama_model<R: Runtime>(
    app_handle: tauri::AppHandle<R>,
    pid: i32,
) -> ServerResult<UnloadResult> {
    let state: State<LlamacppState> = app_handle.state();
    let mut map = state.llama_server_process.lock().await;

    if let Some(session) = map.remove(&pid) {
        let mut child = session.child;

        #[cfg(unix)]
        {
            graceful_terminate_process(&mut child).await;
        }

        #[cfg(all(windows, target_arch = "x86_64"))]
        {
            force_terminate_process(&mut child).await;
        }

        Ok(UnloadResult {
            success: true,
            error: None,
        })
    } else {
        log::warn!("No server with PID '{}' found", pid);
        Ok(UnloadResult {
            success: true,
            error: None,
        })
    }
}

/// Get available devices from the llama.cpp backend
#[tauri::command]
pub async fn get_devices(
    backend_path: &str,
    library_path: Option<&str>,
    envs: HashMap<String, String>,
) -> ServerResult<Vec<DeviceInfo>> {
    get_devices_from_backend(backend_path, library_path, envs).await
}

/// Generate API key using HMAC-SHA256
#[tauri::command]
pub fn generate_api_key(model_id: String, api_secret: String) -> Result<String, String> {
    let mut mac = HmacSha256::new_from_slice(api_secret.as_bytes())
        .map_err(|e| format!("Invalid key length: {}", e))?;
    mac.update(model_id.as_bytes());
    let result = mac.finalize();
    let code_bytes = result.into_bytes();
    let hash = general_purpose::STANDARD.encode(code_bytes);
    Ok(hash)
}
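
// Illustration only (not part of the original file): a minimal sketch of how a
// caller holding the same `api_secret` can recompute and verify the key that
// `generate_api_key` derives for a given `model_id`. The test module name and
// the sample inputs below are assumptions made for this example.
#[cfg(test)]
mod generate_api_key_sketch {
    use super::*;

    #[test]
    fn same_inputs_yield_same_key() {
        // HMAC-SHA256 is deterministic: identical (model_id, secret) pairs
        // always produce the same base64-encoded key.
        let a = generate_api_key("model-a".to_string(), "secret".to_string()).unwrap();
        let b = generate_api_key("model-a".to_string(), "secret".to_string()).unwrap();
        assert_eq!(a, b);

        // Changing the model id changes the derived key.
        let c = generate_api_key("model-b".to_string(), "secret".to_string()).unwrap();
        assert_ne!(a, c);
    }
}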

/// Check if a process is still running
#[tauri::command]
pub async fn is_process_running<R: Runtime>(
    app_handle: tauri::AppHandle<R>,
    pid: i32,
) -> Result<bool, String> {
    is_process_running_by_pid(app_handle, pid).await
}

/// Get a random available port
#[tauri::command]
pub async fn get_random_port<R: Runtime>(app_handle: tauri::AppHandle<R>) -> Result<u16, String> {
    get_random_available_port(app_handle).await
}

/// Find session information by model ID
#[tauri::command]
pub async fn find_session_by_model<R: Runtime>(
    app_handle: tauri::AppHandle<R>,
    model_id: String,
) -> Result<Option<SessionInfo>, String> {
    find_session_by_model_id(app_handle, &model_id).await
}

/// Get all loaded model IDs
#[tauri::command]
pub async fn get_loaded_models<R: Runtime>(
    app_handle: tauri::AppHandle<R>,
) -> Result<Vec<String>, String> {
    get_all_loaded_model_ids(app_handle).await
}

/// Get all active sessions
#[tauri::command]
pub async fn get_all_sessions<R: Runtime>(
    app_handle: tauri::AppHandle<R>,
) -> Result<Vec<SessionInfo>, String> {
    get_all_active_sessions(app_handle).await
}

/// Get session information by model ID
#[tauri::command]
pub async fn get_session_by_model<R: Runtime>(
    app_handle: tauri::AppHandle<R>,
    model_id: String,
) -> Result<Option<SessionInfo>, String> {
    find_session_by_model_id(app_handle, &model_id).await
}