update blogpost

This commit is contained in:
Thien Tran 2025-09-26 07:59:54 +00:00
parent b6270dadd3
commit ade55d13ed


@ -28,14 +28,107 @@ On the journey to ensure vLLM runs as fast as possible on our fleet of RTX PRO 6
Before discussing the details of our perf model, we also want to provide you with a napkin-math version. To predict decode performance in terms of tok/s, take the memory bandwidth and divide it by the model size. Model size can be estimated by multiplying the parameter count by 2 for BF16, or by 1 and 0.5 for INT8 and INT4 respectively. For example, say you want to know the theoretical performance of running [Qwen3-4B-Instruct-2507](https://huggingface.co/Qwen/Qwen3-4B-Instruct-2507) on an [RTX 4090](https://www.nvidia.com/en-us/geforce/graphics-cards/40-series/rtx-4090/). The GPU has 1000 GB/s of VRAM bandwidth, and the model is estimated to be 4 x 2 = 8 GB. The theoretical decode speed would be 1000 / 8 = 125 tok/s.
$$$
\text{decode speed (tok/s)} \approx \frac{\text{memory bandwidth (bytes/s)}}{\text{model size (bytes)}}
$$$
This estimate is actually quite accurate for short context sizes, as you will see later when we walk through the performance model. To get a slightly better estimate, you can also look up the HuggingFace repo, which shows the exact repo size. In our case of Qwen3-4B-Instruct-2507, the repo size is 8.06 GB, so it's not much different from our times-2 rule.
TODO: screenshot of HF
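In code, the napkin math is just one division. Here is a minimal sketch; the function name is ours, and the numbers are the rounded values from the example above:
```python
# Napkin-math decode estimate: memory bandwidth divided by model size.
# Illustrative sketch using the rounded numbers from the example above.

def napkin_decode_tok_s(mem_bw_gb_s: float, params_billions: float, bytes_per_param: float = 2.0) -> float:
    model_size_gb = params_billions * bytes_per_param  # BF16 -> 2 bytes per parameter
    return mem_bw_gb_s / model_size_gb

# Qwen3-4B-Instruct-2507 (BF16) on RTX 4090 (~1000 GB/s)
print(napkin_decode_tok_s(mem_bw_gb_s=1000, params_billions=4))  # ~125 tok/s
```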
The biggest problem with this approach is that it does not take the effect of long context into account - doing attention on a large KV cache takes time! It also doesn't predict the latency of the **prefill** operation, i.e. Time to First Token (TTFT). We aim to predict these metrics more accurately with our perf model!
### Break down into ops: Matmul and Attention
An LLM forward pass consists of a series of operations. Let's break it down and figure out which ones actually matter.
Let's take a top-down view of a typical transformer-based LLM. We have an embedding layer, a series of repeated hidden layers, and finally the Language Modelling (LM) head. Each hidden layer consists of two modules: an attention module, which typically employs Grouped Query Attention (GQA), and a Multi-Layer Perceptron (MLP) module.
TODO: image of transformer arch
There are a lot of operations in these layers: matrix multiplication (matmul), attention, RMS norm, activation functions, and so on. Though they can look overwhelming, we only need to consider **matmul** and **attention**, as they account for most of the runtime.
Module | Main operations
-------|----------------
Input embedding | Embed tokens (ignore)
Hidden layers (repeated N times) |
- Attention module | Query, Key, Value projections (matmul), Attention, Output projection (matmul)
- MLP module | Up and Down projections (matmul)
LM Head | Matmul
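To make the table concrete, here is a sketch that lists the matmul shapes inside one hidden layer. The dimension values are assumed, Qwen3-4B-like numbers rather than values read from any config file, and the projection names follow the common HuggingFace convention:
```python
# Illustrative matmul weight shapes for one hidden layer of a GQA transformer.
# The dimensions are assumed values; check the model config for exact numbers.
hidden_size = 2560
num_q_heads, num_kv_heads, head_dim = 32, 8, 128
intermediate_size = 9728

layer_matmuls = {
    # Attention module
    "q_proj": (hidden_size, num_q_heads * head_dim),
    "k_proj": (hidden_size, num_kv_heads * head_dim),
    "v_proj": (hidden_size, num_kv_heads * head_dim),
    "o_proj": (num_q_heads * head_dim, hidden_size),
    # MLP module
    "up_proj": (hidden_size, intermediate_size),
    "down_proj": (intermediate_size, hidden_size),
}

# Processing T tokens turns each weight of shape (K, N) into a (T, K) x (K, N) matmul.
for name, (K, N) in layer_matmuls.items():
    print(f"{name}: (T, {K}) x ({K}, {N})")
```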
A given matmul or attention operation, applied to inputs of a specific shape, can be characterized as either **compute-bound** or **memory-bound**. These bounds are the theoretical limits on how fast the operation can run, assuming we can fully utilize the compute units and memory bandwidth of the GPU (or any other kind of hardware).
#### Characterizing Matmul
Matmul takes in two input matrices, A with shape `(M, K)` and B with shape `(K, N)`, and produces an output matrix C with shape `(M, N)`. Mathematically speaking, each output element is a dot product between a row of A and a column of B.
$$$
C_{mn} = \sum_{k=0}^{K-1} A_{mk}B_{kn}, \forall 0\leq m<M, 0\leq n<N
$$$
TODO: show a diagram
We want to count:
1. The number of floating point operations needed to compute the matmul, or FLOPs. This will determine the compute bound.
2. The amount of data read from and written to VRAM. This will determine the memory bound.
Recall that each output element is a dot product of two vectors of size K. This involves 2K FLOPs, since we need 1 multiplication and 1 addition for each pair of elements. To compute the whole output of size `(M, N)`, we therefore need **2MNK FLOPs**. This is the most important fact you need to know about matmul: there are 2MNK floating point operations.
For the amount of data transferred, regardless of how a particular matmul algorithm is implemented, we **minimally** have to read the input data from VRAM and write the output data to VRAM. Assuming BF16 data type, which has 2 bytes per element, matmul transfers at least 2(MK + KN + MN) bytes.
To determine whether a particular matmul is compute-bound or memory-bound, we compute the expected duration under each assumption; whichever is larger tells us which resource the operation is bound by. For this we need the actual hardware specifications: compute throughput in floating point operations per second (FLOPS, with a capital S) and memory bandwidth. Take the RTX PRO 6000 as an example.
There is also another approach using **arithmetic intensity**, but it boils down to the same comparison.
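To make the comparison concrete, here is a minimal sketch of the check. The two hardware constants below are placeholders, not official RTX PRO 6000 numbers; substitute the measured values from the section further down.
```python
# A minimal sketch of the compute-bound vs memory-bound check for a BF16 matmul.
# The two hardware numbers are placeholders - plug in your own measured values.
ACHIEVABLE_FLOPS = 200e12   # assumed achievable BF16 FLOPS (placeholder)
ACHIEVABLE_MEM_BW = 1.5e12  # assumed achievable memory bandwidth in bytes/s (placeholder)

def matmul_bound(M: int, N: int, K: int, dtype_bytes: int = 2) -> str:
    flops = 2 * M * N * K                           # 2MNK floating point operations
    nbytes = dtype_bytes * (M * K + K * N + M * N)  # read A and B, write C
    t_compute = flops / ACHIEVABLE_FLOPS            # duration if perfectly compute-bound
    t_memory = nbytes / ACHIEVABLE_MEM_BW           # duration if perfectly memory-bound
    bound = "compute-bound" if t_compute >= t_memory else "memory-bound"
    return f"M={M}: compute {t_compute*1e6:.1f} us vs memory {t_memory*1e6:.1f} us -> {bound}"

print(matmul_bound(M=4096, N=4096, K=4096))  # prefill-like matmul: compute-bound
print(matmul_bound(M=1, N=4096, K=4096))     # decode-like matmul: memory-bound
```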
#### Characterizing Attention
For each attention head, attention is computed as
$$$
S = QK^\top, \quad P = \mathrm{softmax}(S), \quad O = PV
$$$
where S and P are the attention weights before and after softmax, respectively. If we ignore the softmax operation (which we do), attention is simply two back-to-back matmuls, so its FLOPs count is the sum of the FLOPs of the two matmuls. For the amount of data transferred, we only count the inputs (Q, K, V) and the output (O). In other words, we assume intermediates like S and P are never written to VRAM (this is in fact the key innovation of [Flash Attention](https://github.com/Dao-AILab/flash-attention)).
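As a sketch of this accounting under the two-matmul view above (the function and the GQA dimensions below are illustrative, not taken from any specific model):
```python
# FLOPs / bytes accounting for attention, ignoring softmax and assuming BF16 (2 bytes/element).
# S and P are assumed to stay on-chip (Flash Attention style), so only Q, K, V, O touch VRAM.

def attention_counts(q_len: int, kv_len: int, num_q_heads: int, num_kv_heads: int,
                     head_dim: int, dtype_bytes: int = 2):
    # Per query head: Q @ K^T is (q_len, d) x (d, kv_len) -> 2 * q_len * kv_len * d FLOPs,
    # and P @ V is (q_len, kv_len) x (kv_len, d) -> another 2 * q_len * kv_len * d FLOPs.
    flops = num_q_heads * 2 * (2 * q_len * kv_len * head_dim)

    q_bytes = num_q_heads * q_len * head_dim * dtype_bytes
    kv_bytes = 2 * num_kv_heads * kv_len * head_dim * dtype_bytes  # K and V
    o_bytes = num_q_heads * q_len * head_dim * dtype_bytes
    return flops, q_bytes + kv_bytes + o_bytes

# Decode step (q_len=1) over an 8192-token KV cache, with 32 query heads and 8 KV heads
flops, nbytes = attention_counts(q_len=1, kv_len=8192, num_q_heads=32, num_kv_heads=8, head_dim=128)
print(f"{flops / 1e9:.2f} GFLOPs, {nbytes / 1e6:.2f} MB moved")
```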
### Theoretical vs Actual hardware specs
Our perf model can only be as good as its inputs, i.e. the raw performance numbers of the GPU. We can consider more operations and overheads involved in an LLM's forward pass, but it would be pointless if our assumed GPU FLOPS and memory bandwidth are not accurate. Fortunately, we don't need to do guesswork. We can just measure them!
Generally, matmul is guaranteed to be compute-bound when the shapes are sufficiently large. In other words, by measuring the runtime of a large matmul, we can estimate the realistically achievable FLOPS of the GPU. We can safely assume the default matmul implementation in PyTorch is near optimal for large problem sizes. In practice, we sweep through large matrix sizes and take the highest achieved FLOPS.
```python
import time

import torch
from triton.testing import do_bench

max_flops = 0
for size in range(4096, 16384, 2048):
    A = torch.randn(size, size, dtype=torch.bfloat16, device="cuda")
    B = torch.randn(size, size, dtype=torch.bfloat16, device="cuda").T
    time.sleep(0.5)  # brief pause between problem sizes to reduce thermal effects
    latency_ms = do_bench(lambda: torch.mm(A, B))
    flops = (2 * size * size * size) / (latency_ms * 1e-3)  # 2MNK FLOPs / seconds
    max_flops = max(max_flops, flops)

print(f"{max_flops * 1e-12:.2f} TFLOPS")
```
Estimating memory bandwidth is much easier. We can use `memcpy` as the optimal memory operation, which simply reads data from one location and writes it to a new location. For a tensor of N bytes, `memcpy` performs N bytes of memory reads and N bytes of memory writes.
```python
import torch
from triton.testing import do_bench

max_mem_bw = 0
for size in range(1, 4):
    x = torch.randn(size * (1 << 30), device="cuda")  # size * 2**30 FP32 elements, i.e. 4 * size GiB
    latency_ms = do_bench(lambda: x.clone())
    mem_bw = x.nbytes * 2 / (latency_ms * 1e-3)  # N bytes read + N bytes written per copy
    max_mem_bw = max(max_mem_bw, mem_bw)

print(f"{max_mem_bw * 1e-9:.2f} GB/s")
```
<CTABlog />