Merge pull request #6659 from menloresearch/fix/6626
fix: Improve KV cache estimation robustness
This commit is contained in:
commit
04fcd788a3
@ -62,6 +62,7 @@ pub async fn estimate_kv_cache_internal(
|
|||||||
ctx_size: Option<u64>,
|
ctx_size: Option<u64>,
|
||||||
) -> Result<KVCacheEstimate, KVCacheError> {
|
) -> Result<KVCacheEstimate, KVCacheError> {
|
||||||
log::info!("Received ctx_size parameter: {:?}", ctx_size);
|
log::info!("Received ctx_size parameter: {:?}", ctx_size);
|
||||||
|
log::info!("Received model metadata:\n{:?}", &meta);
|
||||||
let arch = meta
|
let arch = meta
|
||||||
.get("general.architecture")
|
.get("general.architecture")
|
||||||
.ok_or(KVCacheError::ArchitectureNotFound)?;
|
.ok_or(KVCacheError::ArchitectureNotFound)?;
|
||||||
@ -94,15 +95,43 @@ pub async fn estimate_kv_cache_internal(
|
|||||||
let key_len_key = format!("{}.attention.key_length", arch);
|
let key_len_key = format!("{}.attention.key_length", arch);
|
||||||
let val_len_key = format!("{}.attention.value_length", arch);
|
let val_len_key = format!("{}.attention.value_length", arch);
|
||||||
|
|
||||||
let key_len = meta
|
let mut key_len = meta
|
||||||
.get(&key_len_key)
|
.get(&key_len_key)
|
||||||
.and_then(|s| s.parse::<u64>().ok())
|
.and_then(|s| s.parse::<u64>().ok())
|
||||||
.unwrap_or(0);
|
.unwrap_or(0);
|
||||||
let val_len = meta
|
let mut val_len = meta
|
||||||
.get(&val_len_key)
|
.get(&val_len_key)
|
||||||
.and_then(|s| s.parse::<u64>().ok())
|
.and_then(|s| s.parse::<u64>().ok())
|
||||||
.unwrap_or(0);
|
.unwrap_or(0);
|
||||||
|
|
||||||
|
// Fallback: derive key_len and val_len from embedding_length when either one is missing, unparsable, or zero
|
||||||
|
if key_len == 0 || val_len == 0 {
|
||||||
|
let emb_len_key = format!("{}.embedding_length", arch);
|
||||||
|
let emb_len = meta
|
||||||
|
.get(&emb_len_key)
|
||||||
|
.and_then(|s| s.parse::<u64>().ok())
|
||||||
|
.unwrap_or(0);
|
||||||
|
|
||||||
|
if emb_len > 0 && n_head > 0 {
|
||||||
|
// For most transformers: head_dim = embedding_length / total_heads
|
||||||
|
let total_heads = meta
|
||||||
|
.get(&n_head_key)
|
||||||
|
.and_then(|s| s.parse::<u64>().ok())
|
||||||
|
.unwrap_or(n_head);
|
||||||
|
|
||||||
|
let head_dim = emb_len / total_heads;
|
||||||
|
key_len = head_dim;
|
||||||
|
val_len = head_dim;
|
||||||
|
|
||||||
|
log::info!(
|
||||||
|
"Calculated key_len and val_len from embedding_length: {} / {} heads = {} per head",
|
||||||
|
emb_len,
|
||||||
|
total_heads,
|
||||||
|
head_dim
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if key_len == 0 || val_len == 0 {
|
if key_len == 0 || val_len == 0 {
|
||||||
return Err(KVCacheError::EmbeddingLengthInvalid);
|
return Err(KVCacheError::EmbeddingLengthInvalid);
|
||||||
}
|
}
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user