fix: Improve KV cache estimation robustness

The KV cache size calculation in estimate_kv_cache_internal now includes a fallback mechanism for models that do not explicitly define key_length and value_length in the GGUF metadata.

If these attention keys are missing, the head dimension (and thus key/value length) is calculated using the formula embedding_length / total_heads. This improves robustness and compatibility with GGUF models that don't have the proper keys in metadata.

Also adds logging of the full model metadata for easier debugging of the estimation process.
This commit is contained in:
Akarshan 2025-09-30 11:14:18 +05:30
parent d315522c5a
commit 34b254e2d8
No known key found for this signature in database
GPG Key ID: D75C9634A870665F

View File

@ -62,6 +62,7 @@ pub async fn estimate_kv_cache_internal(
ctx_size: Option<u64>,
) -> Result<KVCacheEstimate, KVCacheError> {
log::info!("Received ctx_size parameter: {:?}", ctx_size);
log::info!("Received model metadata:\n{:?}", &meta);
let arch = meta
.get("general.architecture")
.ok_or(KVCacheError::ArchitectureNotFound)?;
@ -94,15 +95,43 @@ pub async fn estimate_kv_cache_internal(
let key_len_key = format!("{}.attention.key_length", arch);
let val_len_key = format!("{}.attention.value_length", arch);
let key_len = meta
let mut key_len = meta
.get(&key_len_key)
.and_then(|s| s.parse::<u64>().ok())
.unwrap_or(0);
let val_len = meta
let mut val_len = meta
.get(&val_len_key)
.and_then(|s| s.parse::<u64>().ok())
.unwrap_or(0);
// Fallback: calculate from embedding_length if key/val lengths not found
if key_len == 0 || val_len == 0 {
let emb_len_key = format!("{}.embedding_length", arch);
let emb_len = meta
.get(&emb_len_key)
.and_then(|s| s.parse::<u64>().ok())
.unwrap_or(0);
if emb_len > 0 && n_head > 0 {
// For most transformers: head_dim = embedding_length / total_heads
let total_heads = meta
.get(&n_head_key)
.and_then(|s| s.parse::<u64>().ok())
.unwrap_or(n_head);
let head_dim = emb_len / total_heads;
key_len = head_dim;
val_len = head_dim;
log::info!(
"Calculated key_len and val_len from embedding_length: {} / {} heads = {} per head",
emb_len,
total_heads,
head_dim
);
}
}
if key_len == 0 || val_len == 0 {
return Err(KVCacheError::EmbeddingLengthInvalid);
}