Merge pull request #6659 from menloresearch/fix/6626
fix: Improve KV cache estimation robustness
This commit is contained in:
commit
04fcd788a3
@ -62,6 +62,7 @@ pub async fn estimate_kv_cache_internal(
|
||||
ctx_size: Option<u64>,
|
||||
) -> Result<KVCacheEstimate, KVCacheError> {
|
||||
log::info!("Received ctx_size parameter: {:?}", ctx_size);
|
||||
log::info!("Received model metadata:\n{:?}", &meta);
|
||||
let arch = meta
|
||||
.get("general.architecture")
|
||||
.ok_or(KVCacheError::ArchitectureNotFound)?;
|
||||
@ -94,15 +95,43 @@ pub async fn estimate_kv_cache_internal(
|
||||
let key_len_key = format!("{}.attention.key_length", arch);
|
||||
let val_len_key = format!("{}.attention.value_length", arch);
|
||||
|
||||
let key_len = meta
|
||||
let mut key_len = meta
|
||||
.get(&key_len_key)
|
||||
.and_then(|s| s.parse::<u64>().ok())
|
||||
.unwrap_or(0);
|
||||
let val_len = meta
|
||||
let mut val_len = meta
|
||||
.get(&val_len_key)
|
||||
.and_then(|s| s.parse::<u64>().ok())
|
||||
.unwrap_or(0);
|
||||
|
||||
// Fallback: calculate from embedding_length if key/val lengths not found
|
||||
if key_len == 0 || val_len == 0 {
|
||||
let emb_len_key = format!("{}.embedding_length", arch);
|
||||
let emb_len = meta
|
||||
.get(&emb_len_key)
|
||||
.and_then(|s| s.parse::<u64>().ok())
|
||||
.unwrap_or(0);
|
||||
|
||||
if emb_len > 0 && n_head > 0 {
|
||||
// For most transformers: head_dim = embedding_length / total_heads
|
||||
let total_heads = meta
|
||||
.get(&n_head_key)
|
||||
.and_then(|s| s.parse::<u64>().ok())
|
||||
.unwrap_or(n_head);
|
||||
|
||||
let head_dim = emb_len / total_heads;
|
||||
key_len = head_dim;
|
||||
val_len = head_dim;
|
||||
|
||||
log::info!(
|
||||
"Calculated key_len and val_len from embedding_length: {} / {} heads = {} per head",
|
||||
emb_len,
|
||||
total_heads,
|
||||
head_dim
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
if key_len == 0 || val_len == 0 {
|
||||
return Err(KVCacheError::EmbeddingLengthInvalid);
|
||||
}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user