Merge pull request #6659 from menloresearch/fix/6626

fix: Improve KV cache estimation robustness
This commit is contained in:
Nguyen Ngoc Minh 2025-09-30 13:42:28 +07:00 committed by GitHub
commit 04fcd788a3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -62,6 +62,7 @@ pub async fn estimate_kv_cache_internal(
ctx_size: Option<u64>,
) -> Result<KVCacheEstimate, KVCacheError> {
log::info!("Received ctx_size parameter: {:?}", ctx_size);
log::info!("Received model metadata:\n{:?}", &meta);
let arch = meta
.get("general.architecture")
.ok_or(KVCacheError::ArchitectureNotFound)?;
@ -94,15 +95,43 @@ pub async fn estimate_kv_cache_internal(
let key_len_key = format!("{}.attention.key_length", arch);
let val_len_key = format!("{}.attention.value_length", arch);
let key_len = meta
let mut key_len = meta
.get(&key_len_key)
.and_then(|s| s.parse::<u64>().ok())
.unwrap_or(0);
let val_len = meta
let mut val_len = meta
.get(&val_len_key)
.and_then(|s| s.parse::<u64>().ok())
.unwrap_or(0);
// Fallback: calculate from embedding_length if key/val lengths not found
if key_len == 0 || val_len == 0 {
let emb_len_key = format!("{}.embedding_length", arch);
let emb_len = meta
.get(&emb_len_key)
.and_then(|s| s.parse::<u64>().ok())
.unwrap_or(0);
if emb_len > 0 && n_head > 0 {
// For most transformers: head_dim = embedding_length / total_heads
let total_heads = meta
.get(&n_head_key)
.and_then(|s| s.parse::<u64>().ok())
.unwrap_or(n_head);
let head_dim = emb_len / total_heads;
key_len = head_dim;
val_len = head_dim;
log::info!(
"Calculated key_len and val_len from embedding_length: {} / {} heads = {} per head",
emb_len,
total_heads,
head_dim
);
}
}
if key_len == 0 || val_len == 0 {
return Err(KVCacheError::EmbeddingLengthInvalid);
}