Akarshan Biswas 885da29f28
feat: add getTokensCount method to compute token usage (#6467)
* feat: add getTokensCount method to compute token usage

Implemented a new async `getTokensCount` function in the LLaMA.cpp extension.
The method validates the model session, checks process health, applies the request template, and tokenizes the resulting prompt to return the token count. Includes detailed error handling for crashed models and API failures, enabling callers to assess token usage before sending completions.
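
A rough sketch of what this could look like on the extension side, assuming the request template is applied client-side and the count comes from llama-server's `/tokenize` endpoint; the helper signatures below (`getSessionForModel`, `isProcessRunning`, `applyTemplate`, `tokenizeViaServer`) are illustrative assumptions, not the extension's actual API:

```typescript
// Illustrative sketch only: these helper signatures are assumptions, not the extension's API.
declare function getSessionForModel(modelId: string): Promise<{ pid: number; port: number } | undefined>
declare function isProcessRunning(pid: number): Promise<boolean>
declare function applyTemplate(session: { port: number }, messages: ChatMessage[]): Promise<string>
declare function tokenizeViaServer(session: { port: number }, prompt: string): Promise<number[]>

interface ChatMessage {
  role: string
  content: string
}

async function getTokensCount(modelId: string, messages: ChatMessage[]): Promise<number> {
  const session = await getSessionForModel(modelId)
  if (!session) throw new Error(`Model ${modelId} is not loaded`)

  // Fail fast if the backing llama-server process has crashed
  if (!(await isProcessRunning(session.pid))) {
    throw new Error(`Model ${modelId} is no longer running`)
  }

  // Apply the request template so the count matches what a completion would actually send
  const prompt = await applyTemplate(session, messages)

  // Tokenize the rendered prompt and return the token count
  const tokens = await tokenizeViaServer(session, prompt)
  return tokens.length
}
```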

* Fix: typos

* chore: update ui token usage

* chore: remove unused code

* feat: add image token handling for multimodal LlamaCPP models

Implemented support for counting image tokens when using vision-enabled models:
- Extended `SessionInfo` with an optional `mmprojPath` that stores the path of the multimodal projector file.
- Propagated `mmproj_path` from the Tauri plugin into the session info.
- Added import of `chatCompletionRequestMessage` and enhanced the token calculation logic in the LlamaCPP extension (a rough sketch follows this list):
  - Detects image content in messages.
  - Reads GGUF metadata from `mmprojPath` to compute accurate image token counts.
  - Provides a fallback estimation if metadata reading fails.
  - Returns the sum of text and image tokens.
- Introduced helper methods `calculateImageTokens` and `estimateImageTokensFallback`.
- Minor clean‑ups such as comment capitalization and debug logging.
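
A rough sketch of the detection half, assuming OpenAI-style `image_url` content parts (which the FE change below switches to); the message and content shapes here are assumptions:

```typescript
// Illustrative only: the message/content shapes are assumptions.
type ContentPart =
  | { type: 'text'; text: string }
  | { type: 'image_url'; image_url: { url: string } }

interface ChatMessage {
  role: string
  content: string | ContentPart[]
}

// Count how many images appear across all messages in the request
function countImages(messages: ChatMessage[]): number {
  let images = 0
  for (const message of messages) {
    if (Array.isArray(message.content)) {
      images += message.content.filter((part) => part.type === 'image_url').length
    }
  }
  return images
}
```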

* chore: update FE send params so messages include content type `image_url`

* fix mmproj path from session info and num tokens calculation

* fix: Correct image token estimation calculation in llamacpp extension

This commit addresses an inaccurate token count for images in the llama.cpp extension.

The previous logic incorrectly calculated the token count based on image patch size and dimensions. This has been replaced with a more precise method that uses the `clip.vision.projection_dim` value from the model metadata.

Additionally, unnecessary debug logging was removed, and a new log was added to show the mmproj metadata for improved visibility.
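
A sketch of how the per-image estimate might be derived, using `clip.vision.projection_dim` as the primary source and a coarse constant as the fallback. The helper names match the ones introduced above, but their signatures, the metadata-reading call, and the fallback value are assumptions; the commit does not spell out the exact formula, so this sketch simply returns the projection dimension:

```typescript
// Illustrative only: readGgufMetadata, the signatures, and the fallback value are assumptions.
declare function readGgufMetadata(path: string): Promise<Record<string, unknown>>

async function calculateImageTokens(mmprojPath: string): Promise<number> {
  try {
    const metadata = await readGgufMetadata(mmprojPath)
    const projectionDim = metadata['clip.vision.projection_dim']
    if (typeof projectionDim === 'number' && projectionDim > 0) {
      return projectionDim
    }
  } catch (e) {
    console.warn('Could not read mmproj metadata, using fallback estimate', e)
  }
  return estimateImageTokensFallback()
}

function estimateImageTokensFallback(): number {
  // Coarse per-image default when metadata is unavailable (value is an assumption)
  return 576
}
```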

* fix per image calc

* fix: crash due to force unwrap

---------

Co-authored-by: Faisal Amir <urmauur@gmail.com>
Co-authored-by: Louis <louis@jan.ai>
2025-09-23 07:52:19 +05:30


use base64::{engine::general_purpose, Engine as _};
use hmac::{Hmac, Mac};
use sha2::Sha256;
use std::collections::HashMap;
use std::process::Stdio;
use std::time::Duration;
use tauri::{Manager, Runtime, State};
use tokio::io::{AsyncBufReadExt, BufReader};
use tokio::process::Command;
use tokio::sync::mpsc;
use tokio::time::Instant;
use crate::device::{get_devices_from_backend, DeviceInfo};
use crate::error::{ErrorCode, LlamacppError, ServerError, ServerResult};
use crate::path::{validate_binary_path, validate_mmproj_path, validate_model_path};
use crate::process::{
    find_session_by_model_id, get_all_active_sessions, get_all_loaded_model_ids,
    get_random_available_port, is_process_running_by_pid,
};
use crate::state::{LLamaBackendSession, LlamacppState, SessionInfo};
use jan_utils::{
    extract_arg_value, parse_port_from_args, setup_library_path, setup_windows_process_flags,
};

#[cfg(unix)]
use crate::process::graceful_terminate_process;
#[cfg(all(windows, target_arch = "x86_64"))]
use crate::process::force_terminate_process;

type HmacSha256 = Hmac<Sha256>;

#[derive(serde::Serialize, serde::Deserialize)]
pub struct UnloadResult {
    success: bool,
    error: Option<String>,
}

/// Load a llama model and start the server
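///
/// Validates the binary, model, and optional mmproj paths, spawns `llama-server`
/// with the given arguments and environment, waits for a readiness log line
/// (or times out after 5 minutes), and registers the session keyed by the child PID.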
#[tauri::command]
pub async fn load_llama_model<R: Runtime>(
    app_handle: tauri::AppHandle<R>,
    backend_path: &str,
    library_path: Option<&str>,
    mut args: Vec<String>,
    envs: HashMap<String, String>,
) -> ServerResult<SessionInfo> {
    let state: State<LlamacppState> = app_handle.state();
    let mut process_map = state.llama_server_process.lock().await;
    log::info!("Attempting to launch server at path: {:?}", backend_path);
    log::info!("Using arguments: {:?}", args);
    validate_binary_path(backend_path)?;
    let port = parse_port_from_args(&args);
    let model_path_pb = validate_model_path(&mut args)?;
    let mmproj_path_pb = validate_mmproj_path(&mut args)?;
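    // Keep the resolved --mmproj path as a plain string so it can be returned in
    // SessionInfo; the frontend uses it to read the projector's GGUF metadata
    // (e.g. for image token estimation).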
    let mmproj_path_string = if mmproj_path_pb.is_some() {
        // Find the actual mmproj path from args after validation/conversion;
        // use .get() so a missing value cannot panic
        args.iter()
            .position(|arg| arg == "--mmproj")
            .and_then(|mmproj_index| args.get(mmproj_index + 1).cloned())
    } else {
        None
    };
    log::info!(
        "MMPROJ Path string: {}",
        mmproj_path_string.as_deref().unwrap_or("None")
    );
    let api_key: String;
    if let Some(api_value) = envs.get("LLAMA_API_KEY") {
        api_key = api_value.to_string();
    } else {
        log::warn!("API key not provided");
        api_key = "".to_string();
    }
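    // llama-server's `-a`/`--alias` value serves as the model ID for this session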
    let model_id = extract_arg_value(&args, "-a");
    // Configure the command to run the server
    let mut command = Command::new(backend_path);
    command.args(args);
    command.envs(envs);
    setup_library_path(library_path, &mut command);
    command.stdout(Stdio::piped());
    command.stderr(Stdio::piped());
    setup_windows_process_flags(&mut command);
    // Spawn the child process
    let mut child = command.spawn().map_err(ServerError::Io)?;
    let stderr = child.stderr.take().expect("stderr was piped");
    let stdout = child.stdout.take().expect("stdout was piped");
    // Channel used by the stderr task to signal that the server reported readiness
    let (ready_tx, mut ready_rx) = mpsc::channel::<bool>(1);
    // Spawn a task that forwards stdout lines to the log (readiness is detected on stderr below)
    let _stdout_task = tokio::spawn(async move {
        let mut reader = BufReader::new(stdout);
        let mut byte_buffer = Vec::new();
        loop {
            byte_buffer.clear();
            match reader.read_until(b'\n', &mut byte_buffer).await {
                Ok(0) => break, // EOF
                Ok(_) => {
                    let line = String::from_utf8_lossy(&byte_buffer);
                    let line = line.trim_end();
                    if !line.is_empty() {
                        log::info!("[llamacpp stdout] {}", line);
                    }
                }
                Err(e) => {
                    log::error!("Error reading stdout: {}", e);
                    break;
                }
            }
        }
    });
    // Spawn a task that captures stderr, watches for the readiness message,
    // and keeps the full output so it can be reported if the server fails
    let stderr_task = tokio::spawn(async move {
        let mut reader = BufReader::new(stderr);
        let mut byte_buffer = Vec::new();
        let mut stderr_buffer = String::new();
        loop {
            byte_buffer.clear();
            match reader.read_until(b'\n', &mut byte_buffer).await {
                Ok(0) => break, // EOF
                Ok(_) => {
                    let line = String::from_utf8_lossy(&byte_buffer);
                    let line = line.trim_end();
                    if !line.is_empty() {
                        stderr_buffer.push_str(line);
                        stderr_buffer.push('\n');
                        log::info!("[llamacpp] {}", line);
                        // Check for readiness indicator - llama-server outputs this when ready
                        let line_lower = line.to_lowercase();
                        if line_lower.contains("server is listening on")
                            || line_lower.contains("starting the main loop")
                            || line_lower.contains("server listening on")
                        {
                            log::info!("Model appears to be ready based on logs: '{}'", line);
                            let _ = ready_tx.send(true).await;
                        }
                    }
                }
                Err(e) => {
                    log::error!("Error reading logs: {}", e);
                    break;
                }
            }
        }
        stderr_buffer
    });
    // Check if process exited early
    if let Some(status) = child.try_wait()? {
        if !status.success() {
            let stderr_output = stderr_task.await.unwrap_or_default();
            log::error!("llama.cpp failed early with code {:?}", status);
            log::error!("{}", stderr_output);
            return Err(LlamacppError::from_stderr(&stderr_output).into());
        }
    }
    // Wait for server to be ready or timeout
    let timeout_duration = Duration::from_secs(300); // 5 minutes timeout
    let start_time = Instant::now();
    log::info!("Waiting for model session to be ready...");
    loop {
        tokio::select! {
            // Server is ready
            Some(true) = ready_rx.recv() => {
                log::info!("Model is ready to accept requests!");
                break;
            }
            // Check for process exit more frequently
            _ = tokio::time::sleep(Duration::from_millis(50)) => {
                // Check if process exited
                if let Some(status) = child.try_wait()? {
                    let stderr_output = stderr_task.await.unwrap_or_default();
                    if !status.success() {
                        log::error!("llama.cpp exited with error code {:?}", status);
                        return Err(LlamacppError::from_stderr(&stderr_output).into());
                    } else {
                        log::error!("llama.cpp exited successfully but without ready signal");
                        return Err(LlamacppError::from_stderr(&stderr_output).into());
                    }
                }
                // Timeout check
                if start_time.elapsed() > timeout_duration {
                    log::error!("Timeout waiting for server to be ready");
                    let _ = child.kill().await;
                    let stderr_output = stderr_task.await.unwrap_or_default();
                    return Err(LlamacppError::new(
                        ErrorCode::ModelLoadTimedOut,
                        "The model took too long to load and timed out.".into(),
                        Some(format!(
                            "Timeout: {}s\n\nStderr:\n{}",
                            timeout_duration.as_secs(),
                            stderr_output
                        )),
                    ).into());
                }
            }
        }
    }
    // Get the PID to use as session ID
    let pid = child.id().map(|id| id as i32).unwrap_or(-1);
    log::info!("Server process started with PID: {} and is ready", pid);
    let session_info = SessionInfo {
        pid,
        port,
        model_id,
        model_path: model_path_pb.display().to_string(),
        api_key,
        mmproj_path: mmproj_path_string,
    };
    // Insert session info to process_map
    process_map.insert(
        pid,
        LLamaBackendSession {
            child,
            info: session_info.clone(),
        },
    );
    Ok(session_info)
}

/// Unload a llama model by terminating its process
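///
/// On Unix the process is terminated gracefully; on Windows (x86_64) it is
/// force-terminated. A PID that is not tracked is treated as already unloaded.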
#[tauri::command]
pub async fn unload_llama_model<R: Runtime>(
    app_handle: tauri::AppHandle<R>,
    pid: i32,
) -> ServerResult<UnloadResult> {
    let state: State<LlamacppState> = app_handle.state();
    let mut map = state.llama_server_process.lock().await;
    if let Some(session) = map.remove(&pid) {
        let mut child = session.child;
        #[cfg(unix)]
        {
            graceful_terminate_process(&mut child).await;
        }
        #[cfg(all(windows, target_arch = "x86_64"))]
        {
            force_terminate_process(&mut child).await;
        }
        Ok(UnloadResult {
            success: true,
            error: None,
        })
    } else {
        log::warn!("No server with PID '{}' found", pid);
        Ok(UnloadResult {
            success: true,
            error: None,
        })
    }
}

/// Get available devices from the llama.cpp backend
#[tauri::command]
pub async fn get_devices(
    backend_path: &str,
    library_path: Option<&str>,
    envs: HashMap<String, String>,
) -> ServerResult<Vec<DeviceInfo>> {
    get_devices_from_backend(backend_path, library_path, envs).await
}

/// Generate API key using HMAC-SHA256
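///
/// Returns the base64 encoding of HMAC-SHA256 over `model_id`, keyed with
/// `api_secret`, so a given (model_id, secret) pair always yields the same key.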
#[tauri::command]
pub fn generate_api_key(model_id: String, api_secret: String) -> Result<String, String> {
    let mut mac = HmacSha256::new_from_slice(api_secret.as_bytes())
        .map_err(|e| format!("Invalid key length: {}", e))?;
    mac.update(model_id.as_bytes());
    let result = mac.finalize();
    let code_bytes = result.into_bytes();
    let hash = general_purpose::STANDARD.encode(code_bytes);
    Ok(hash)
}

/// Check if a process is still running
#[tauri::command]
pub async fn is_process_running<R: Runtime>(
    app_handle: tauri::AppHandle<R>,
    pid: i32,
) -> Result<bool, String> {
    is_process_running_by_pid(app_handle, pid).await
}

/// Get a random available port
#[tauri::command]
pub async fn get_random_port<R: Runtime>(app_handle: tauri::AppHandle<R>) -> Result<u16, String> {
    get_random_available_port(app_handle).await
}

/// Find session information by model ID
#[tauri::command]
pub async fn find_session_by_model<R: Runtime>(
    app_handle: tauri::AppHandle<R>,
    model_id: String,
) -> Result<Option<SessionInfo>, String> {
    find_session_by_model_id(app_handle, &model_id).await
}

/// Get all loaded model IDs
#[tauri::command]
pub async fn get_loaded_models<R: Runtime>(
    app_handle: tauri::AppHandle<R>,
) -> Result<Vec<String>, String> {
    get_all_loaded_model_ids(app_handle).await
}

/// Get all active sessions
#[tauri::command]
pub async fn get_all_sessions<R: Runtime>(
    app_handle: tauri::AppHandle<R>,
) -> Result<Vec<SessionInfo>, String> {
    get_all_active_sessions(app_handle).await
}

/// Get session information by model ID
#[tauri::command]
pub async fn get_session_by_model<R: Runtime>(
    app_handle: tauri::AppHandle<R>,
    model_id: String,
) -> Result<Option<SessionInfo>, String> {
    find_session_by_model_id(app_handle, &model_id).await
}