From 34b254e2d8eb3ea45339e5bab6c592628db1d101 Mon Sep 17 00:00:00 2001 From: Akarshan Date: Tue, 30 Sep 2025 11:14:18 +0530 Subject: [PATCH] fix: Improve KV cache estimation robustness The KV cache size calculation in estimate_kv_cache_internal now includes a fallback mechanism for models that do not explicitly define key_length and value_length in the GGUF metadata. If these attention keys are missing, the head dimension (and thus key/value length) is calculated using the formula embedding_length / total_heads. This improves robustness and compatibility with GGUF models that don't have the proper keys in metadata. Also adds logging of the full model metadata for easier debugging of the estimation process. --- .../tauri-plugin-llamacpp/src/gguf/utils.rs | 33 +++++++++++++++++-- 1 file changed, 31 insertions(+), 2 deletions(-) diff --git a/src-tauri/plugins/tauri-plugin-llamacpp/src/gguf/utils.rs b/src-tauri/plugins/tauri-plugin-llamacpp/src/gguf/utils.rs index 50e3f4a14..cdbbf92d5 100644 --- a/src-tauri/plugins/tauri-plugin-llamacpp/src/gguf/utils.rs +++ b/src-tauri/plugins/tauri-plugin-llamacpp/src/gguf/utils.rs @@ -62,6 +62,7 @@ pub async fn estimate_kv_cache_internal( ctx_size: Option, ) -> Result { log::info!("Received ctx_size parameter: {:?}", ctx_size); + log::info!("Received model metadata:\n{:?}", &meta); let arch = meta .get("general.architecture") .ok_or(KVCacheError::ArchitectureNotFound)?; @@ -94,15 +95,43 @@ pub async fn estimate_kv_cache_internal( let key_len_key = format!("{}.attention.key_length", arch); let val_len_key = format!("{}.attention.value_length", arch); - let key_len = meta + let mut key_len = meta .get(&key_len_key) .and_then(|s| s.parse::().ok()) .unwrap_or(0); - let val_len = meta + let mut val_len = meta .get(&val_len_key) .and_then(|s| s.parse::().ok()) .unwrap_or(0); + // Fallback: calculate from embedding_length if key/val lengths not found + if key_len == 0 || val_len == 0 { + let emb_len_key = format!("{}.embedding_length", arch); + let emb_len = meta + .get(&emb_len_key) + .and_then(|s| s.parse::().ok()) + .unwrap_or(0); + + if emb_len > 0 && n_head > 0 { + // For most transformers: head_dim = embedding_length / total_heads + let total_heads = meta + .get(&n_head_key) + .and_then(|s| s.parse::().ok()) + .unwrap_or(n_head); + + let head_dim = emb_len / total_heads; + key_len = head_dim; + val_len = head_dim; + + log::info!( + "Calculated key_len and val_len from embedding_length: {} / {} heads = {} per head", + emb_len, + total_heads, + head_dim + ); + } + } + if key_len == 0 || val_len == 0 { return Err(KVCacheError::EmbeddingLengthInvalid); }