Correctness first, then speed
Everything in this series so far produces a correct inference engine — it generates the right tokens. But correct is not the same as fast. The difference between a naive correct implementation and a production-grade one is often 10× or more in throughput. That gap is closed by a stack of optimizations — each one a separate technique, each one compounding on the others.
This chapter covers the four most important optimizations and — crucially — which bottleneck each one attacks. Recall from Chapter 06 that prefill is compute-bound and decode is memory-bound → Ch.06. Different optimizations target different bottlenecks, which is exactly why they stack rather than overlap.
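The compute-bound versus memory-bound split can be made concrete with a rough arithmetic-intensity estimate for a single weight matmul. This is a back-of-the-envelope sketch, not a profiler: the model width of 4096, the 2-byte values, and the `RIDGE_FLOPS_PER_BYTE` balance point are all assumed for illustration, and `tokens` stands for however many tokens share one pass over the weights.

```python
def arithmetic_intensity(tokens: int, d_model: int, bytes_per_value: int = 2) -> float:
    """FLOPs per byte moved for one (tokens x d_model) @ (d_model x d_model) matmul."""
    flops = 2 * tokens * d_model * d_model                      # one multiply and one add per term
    weight_bytes = d_model * d_model * bytes_per_value          # weights are read once
    activation_bytes = 2 * tokens * d_model * bytes_per_value   # read the input, write the output
    return flops / (weight_bytes + activation_bytes)


# Hypothetical balance point: FLOPs the GPU can execute per byte it can load from HBM.
RIDGE_FLOPS_PER_BYTE = 150.0


def bound_by(tokens: int, d_model: int = 4096) -> str:
    """Classify a matmul as compute-bound or memory-bound under the assumed ridge point."""
    intensity = arithmetic_intensity(tokens, d_model)
    return "compute" if intensity > RIDGE_FLOPS_PER_BYTE else "memory"


prefill = arithmetic_intensity(tokens=4096, d_model=4096)  # a long prompt processed at once
decode = arithmetic_intensity(tokens=1, d_model=4096)      # one new token

print(f"prefill: {prefill:.0f} FLOPs/byte -> {bound_by(4096)}-bound")
print(f"decode:  {decode:.2f} FLOPs/byte -> {bound_by(1)}-bound")
```

Under these assumptions, prefill does over a thousand FLOPs per byte loaded while decode does about one, which is why the two phases respond to different optimizations.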
FlashAttention — never write the giant matrix
Attention is the most expensive operation in a transformer. The naive way to compute it has a hidden performance killer: it creates an enormous intermediate matrix and writes it to slow memory. FlashAttention eliminates that write entirely. To understand how, we need to revisit the two kinds of GPU memory → Ch.01.
The memory hierarchy — the root of the problem
A GPU has two relevant kinds of memory. HBM (High Bandwidth Memory) is large (tens of gigabytes) but relatively slow. SRAM (on-chip memory) is tiny by comparison, on the order of tens of megabytes across the whole chip and far less per compute unit, but extraordinarily fast: roughly 10× the bandwidth of HBM. The catch: SRAM is far too small to hold the big intermediate results of attention. So naive attention computes a result in SRAM, writes it out to HBM, reads it back, and so on, and all that HBM traffic is the bottleneck.
What naive attention does wrong
For a sequence of N tokens, naive attention computes an N×N matrix of attention scores — every token's relationship to every other token. For N=4,000 tokens, that's a 4,000×4,000 matrix = 16 million numbers, per attention head, per layer. This giant matrix is written to HBM, read back for the softmax, written again, read again. The matrix itself is never the final output — it's a throwaway intermediate — but moving it to and from HBM dominates the time.
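The size of that throwaway matrix is easy to work out. The sketch below assumes 2-byte (fp16) scores and a hypothetical model with 32 heads and 32 layers; neither figure comes from a specific model.

```python
def score_matrix_bytes(seq_len: int, bytes_per_value: int = 2) -> int:
    """Bytes needed to hold one N x N attention score matrix."""
    return seq_len * seq_len * bytes_per_value


N = 4_000
NUM_HEADS = 32    # hypothetical model shape
NUM_LAYERS = 32   # hypothetical model shape

per_head = score_matrix_bytes(N)                 # 16 million values at 2 bytes each
per_forward = per_head * NUM_HEADS * NUM_LAYERS  # summed over every head in every layer

print(f"{per_head / 1e6:.0f} MB per head per layer")               # 32 MB
print(f"{per_forward / 1e9:.1f} GB written across one forward pass")  # 32.8 GB
```

The second number is a total over the whole forward pass, not memory held at one moment, and each matrix is written once and read back at least once, so the actual HBM traffic is a multiple of it.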
Tiling — the core trick
Tiling means breaking the big computation into small blocks ("tiles") that each fit in fast SRAM. FlashAttention loads a tile of queries and a tile of keys into SRAM, computes their partial attention, updates a running result, then moves to the next tile — all without ever writing the full attention matrix to HBM. Through a clever running-softmax technique, it produces exactly the same answer as naive attention, but with a fraction of the HBM traffic.
By never materialising the N×N matrix in HBM, FlashAttention reduces memory traffic dramatically: often 5–10× less HBM I/O for long sequences. Since naive attention spends most of its time moving that matrix rather than doing arithmetic, less HBM traffic translates almost directly to speedup. It also reduces memory usage: the O(N²) matrix never exists, so memory scales linearly with sequence length instead of quadratically.
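The running-softmax idea can be shown in plain Python for a single query. This is a sketch of the math only, under simplifying assumptions: it tiles over keys but not queries, runs on the CPU, and uses small hand-picked vectors. The real FlashAttention is a CUDA kernel that tiles both and keeps each tile in on-chip memory.

```python
import math


def naive_attention(q, keys, values):
    """Reference: compute every score for one query, then softmax over all of them."""
    scale = 1.0 / math.sqrt(len(q))
    scores = [scale * sum(a * b for a, b in zip(q, k)) for k in keys]
    top = max(scores)
    weights = [math.exp(s - top) for s in scores]
    total = sum(weights)
    dim = len(values[0])
    return [sum(w * v[i] for w, v in zip(weights, values)) / total for i in range(dim)]


def tiled_attention(q, keys, values, tile_size=2):
    """Same result, but only one tile of scores exists at any time."""
    scale = 1.0 / math.sqrt(len(q))
    running_max = float("-inf")    # largest score seen so far
    running_sum = 0.0              # softmax denominator, relative to running_max
    acc = [0.0] * len(values[0])   # unnormalised weighted sum of values
    for start in range(0, len(keys), tile_size):
        k_tile = keys[start:start + tile_size]
        v_tile = values[start:start + tile_size]
        scores = [scale * sum(a * b for a, b in zip(q, k)) for k in k_tile]
        new_max = max(running_max, max(scores))
        # Rescale everything accumulated so far to the new maximum.
        correction = math.exp(running_max - new_max)
        running_sum *= correction
        acc = [a * correction for a in acc]
        for s, v in zip(scores, v_tile):
            w = math.exp(s - new_max)
            running_sum += w
            acc = [a + w * x for a, x in zip(acc, v)]
        running_max = new_max
    return [a / running_sum for a in acc]


q = [1.0, 0.5]
keys = [[0.2, 0.1], [1.5, -0.3], [0.0, 2.0], [-1.0, 0.4], [0.7, 0.7]]
values = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5], [2.0, -1.0], [-0.5, 3.0]]

reference = naive_attention(q, keys, values)
tiled = tiled_attention(q, keys, values, tile_size=2)
print(max(abs(a - b) for a, b in zip(reference, tiled)))  # difference at rounding-error level
```

The correction factor is what makes the result exact: whenever a later tile raises the maximum, the earlier partial sums are rescaled to match instead of being recomputed.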
nano-vLLM calls the flash-attn library's functions: flash_attn_varlen_func for prefill and flash_attn_with_kvcache for decode → Ch.06. These are highly optimised CUDA kernels written by experts. This is the right engineering choice: use the battle-tested implementation for the single most performance-critical operation, rather than reinventing it.
Kernel fusion — do more per trip to memory
To understand kernel fusion, first understand what a kernel is: a single function that runs on the GPU → Ch.03. Every operation — add, multiply, normalise — is a kernel. Each kernel reads its inputs from HBM, computes, and writes its output back to HBM. The problem: if you run ten small kernels in sequence, you pay for ten round-trips to slow HBM, even though the data could have stayed on-chip between steps.
A common example: the SwiGLU activation in the feed-forward network applies a gating function (SiLU) to one projection and multiplies the result elementwise by a second projection. Done as separate kernels, that's multiple HBM round-trips. Fused into one kernel, the intermediate values never leave SRAM. nano-vLLM relies on fused kernels from PyTorch and the flash-attn library, and on torch.compile (covered later in this chapter) to fuse operations automatically.
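A simple counting model shows why fusing the elementwise part of SwiGLU helps. This is not a GPU kernel: it is a plain-Python sketch in which every value a "kernel" reads or writes counts as one unit of HBM traffic, and the function names are made up for the illustration.

```python
import math


def silu(x: float) -> float:
    """SiLU gate: x * sigmoid(x)."""
    return x / (1.0 + math.exp(-x))


def swiglu_unfused(gate, up):
    """Two separate kernels; the activated gate is a full-size intermediate."""
    traffic = 0
    activated = [silu(g) for g in gate]            # kernel 1: read gate, write activated
    traffic += len(gate) + len(activated)
    out = [a * u for a, u in zip(activated, up)]   # kernel 2: read activated and up, write out
    traffic += len(activated) + len(up) + len(out)
    return out, traffic


def swiglu_fused(gate, up):
    """One kernel; the activated gate never leaves on-chip memory."""
    out = [silu(g) * u for g, u in zip(gate, up)]  # read gate and up, write out
    traffic = len(gate) + len(up) + len(out)
    return out, traffic


gate = [0.5, -1.0, 2.0, 0.0]
up = [1.0, 2.0, -0.5, 3.0]

unfused_out, unfused_traffic = swiglu_unfused(gate, up)
fused_out, fused_traffic = swiglu_fused(gate, up)
print(unfused_traffic, fused_traffic)  # 20 12
```

The arithmetic is identical in both versions; only the number of values crossing the memory boundary changes, from 5 per element to 3 per element in this model.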
CUDA Graphs — eliminate the launch overhead
This optimization attacks a subtle but significant cost in the decode phase: kernel launch overhead. Every time the CPU tells the GPU to run a kernel, there's a small setup cost — the CPU has to prepare and dispatch the instruction. For one big kernel, this overhead is negligible. But decode runs hundreds of tiny kernels per token, and the launch overhead for all of them adds up to a large fraction of the total time.
Why decode is especially hurt by launch overhead
Recall that decode processes just one token at a time → Ch.06. Each token's forward pass involves hundreds of small GPU operations — one per layer, per sub-operation. Each is a separate kernel launch. The GPU work for each is tiny (one token!), so the CPU-side launch overhead can actually be larger than the GPU compute itself. The GPU sits idle, waiting for the CPU to dispatch the next tiny kernel. This is a CPU bottleneck masquerading as a GPU problem.
Capture once, replay many times
Capture — record the kernel sequence
The first time a decode step of a given shape runs, CUDA records every kernel launch into a graph — a single replayable object capturing the entire sequence of GPU operations and their dependencies. This capture happens once and has a one-time cost.
Replay — fire the whole graph at once
On every subsequent decode step, instead of launching hundreds of kernels individually, the CPU issues a single "replay this graph" command. The GPU executes the entire recorded sequence with no per-kernel CPU involvement. The launch overhead for hundreds of kernels collapses into one. For decode, this can improve throughput by 20–40%.
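A toy timing model makes the saving visible. Every number here is hypothetical (300 kernels per step, 4 µs of GPU work and 5 µs of CPU launch cost per kernel), and the model assumes the eager path is limited by whichever of the two is slower for each kernel.

```python
def decode_step_us(num_kernels: int, gpu_us: int, launch_us: int, use_graph: bool) -> int:
    """Estimated time for one decode step, in microseconds."""
    if use_graph:
        # One launch for the whole recorded sequence, then pure GPU work.
        return launch_us + num_kernels * gpu_us
    # Eager: each kernel is gated by the slower of its launch and its execution.
    return num_kernels * max(gpu_us, launch_us)


# Hypothetical figures for a small-batch decode step.
NUM_KERNELS = 300
GPU_US = 4
LAUNCH_US = 5

eager = decode_step_us(NUM_KERNELS, GPU_US, LAUNCH_US, use_graph=False)
graphed = decode_step_us(NUM_KERNELS, GPU_US, LAUNCH_US, use_graph=True)
print(eager, graphed)                          # 1500 1205
print(f"{eager / graphed:.2f}x throughput")    # 1.24x
```

With these made-up inputs the gain is about 24%. The model also shows when graphs stop helping: once the GPU work per kernel exceeds the launch cost, as in prefill, the eager path is no longer launch-bound.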
torch.compile — let the compiler optimise
torch.compile is PyTorch's optimising compiler. Normally PyTorch runs in "eager mode" — each operation executes immediately, one Python line at a time, exactly as written. torch.compile instead analyses your whole model ahead of time, then automatically applies optimizations like kernel fusion → Ch.03, removing redundant operations, and generating efficient fused GPU code — often via Triton.
The trade-off: torch.compile spends time compiling on the first run (the "warmup"), which can take seconds to minutes. But every run after that is faster. For a long-running inference server that starts once and serves millions of requests, paying a one-time compile cost for permanent speedup is an excellent trade. It also composes with the other techniques — torch.compile can automatically generate fused kernels and works alongside CUDA Graphs.
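The trade-off is a break-even calculation. The sketch below uses hypothetical timings in milliseconds (a 60-second compile, 500 ms per request in eager mode, 400 ms once compiled); real numbers depend on the model and hardware.

```python
import math


def break_even_requests(compile_ms: int, eager_ms: int, compiled_ms: int) -> int:
    """Requests needed before the one-time compile cost has paid for itself."""
    saved_per_request = eager_ms - compiled_ms
    if saved_per_request <= 0:
        raise ValueError("the compiled path must be faster than eager to pay back")
    return math.ceil(compile_ms / saved_per_request)


# Hypothetical timings, in milliseconds.
requests = break_even_requests(compile_ms=60_000, eager_ms=500, compiled_ms=400)
print(requests)  # 600
```

A server that handles millions of requests passes that point almost immediately, while a script that runs a handful of prompts may never reach it.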
Stack the optimizations — watch them compound
Layer the optimizations one at a time and throughput (tokens/second) climbs while per-token latency drops. They compound: each one multiplies on top of the others, because each attacks a different bottleneck. With all of them off, you are running the naive baseline.
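The difference between compounding and merely adding up is simple arithmetic. The per-optimization factors below are hypothetical, picked only to illustrate the shape of the effect, and the model assumes each factor applies to whatever throughput the previous ones left.

```python
import math

# Hypothetical throughput multipliers, one per optimization.
speedups = {
    "flash_attention": 1.5,
    "kernel_fusion": 1.3,
    "cuda_graphs": 1.4,
    "torch_compile": 1.4,
}

compounded = math.prod(speedups.values())                    # each gain scales the last
additive = 1.0 + sum(s - 1.0 for s in speedups.values())     # gains merely summed

print(f"compounded: {compounded:.1f}x")  # 3.8x
print(f"additive:   {additive:.1f}x")    # 2.6x
```

How close a real system gets to the compounded figure depends on how cleanly the bottlenecks separate; where two optimizations remove the same cost, the combined gain is smaller.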
The optimizations in code
nano-vLLM applies these optimizations with surprisingly little code, because it leans on battle-tested libraries. Here's how the decode-path CUDA Graph capture works:
```python
import torch


class ModelRunner:
    def capture_decode_graph(self, batch_size: int):
        # Capture a CUDA graph for a fixed decode batch size.
        # Decode always processes exactly 1 token per sequence,
        # a fixed shape, which is what makes graph capture possible.
        # (self.static_input is assumed to be a preallocated GPU buffer
        # sized for batch_size.)

        # Warmup run first: required before capture so memory is allocated
        for _ in range(3):
            self._run_decode(self.static_input)

        # Record every kernel launch into a replayable graph
        self.graph = torch.cuda.CUDAGraph()
        with torch.cuda.graph(self.graph):
            # Everything inside here is recorded, not executed normally
            self.static_output = self._run_decode(self.static_input)

    def decode_step(self, input_ids):
        # Copy new data into the static input buffers the graph expects
        self.static_input.copy_(input_ids)
        # Replay the ENTIRE recorded kernel sequence in one shot,
        # with no per-kernel CPU launch overhead
        self.graph.replay()
        return self.static_output.clone()
```
```python
class ModelConfig:
    # When enforce_eager=True, CUDA graphs are DISABLED.
    # Useful for debugging (eager mode gives clear error messages)
    # but slower. Production runs with enforce_eager=False.
    enforce_eager: bool = False
    # When False, nano-vLLM captures decode CUDA graphs at startup
    # for a range of batch sizes, trading startup time for speed.
```
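Capturing graphs "for a range of batch sizes" implies choosing, at run time, which captured graph to replay. One plausible scheme is sketched below: capture a fixed list of sizes and pick the smallest one that fits, padding the unused slots. The bucket schedule and the function names are assumptions for illustration, not code taken from nano-vLLM.

```python
import bisect


def graph_batch_sizes(max_batch_size: int) -> list[int]:
    """Hypothetical capture schedule: dense at small sizes, then steps of 16."""
    small = [b for b in (1, 2, 4, 8) if b <= max_batch_size]
    return small + list(range(16, max_batch_size + 1, 16))


def pick_graph(batch_size: int, captured: list[int]) -> int:
    """Smallest captured size that can hold the batch; extra slots are padding."""
    index = bisect.bisect_left(captured, batch_size)
    if index == len(captured):
        raise ValueError("batch exceeds every captured graph; run this step eagerly")
    return captured[index]


captured = graph_batch_sizes(64)
print(captured)                  # [1, 2, 4, 8, 16, 32, 48, 64]
print(pick_graph(5, captured))   # 8
print(pick_graph(17, captured))  # 32
```

The design choice is a trade between startup time and memory on one side (every captured size costs both) and wasted padding on the other (coarser buckets mean more idle slots per replay).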
nano-vLLM uses the flash-attn library for attention, PyTorch's native CUDA Graph API for decode, and a clean separation of prefill (eager) and decode (graphed) paths. The whole optimization layer is a few hundred lines because the hard parts live in well-tested external libraries. This is exactly how you should build real systems: don't reinvent the performance-critical primitives; compose them well.
Which optimization targets which bottleneck
The key to understanding the stack is knowing what each layer attacks. This is why they compose so well — minimal overlap.
FlashAttention → attention memory
Targets the HBM traffic of the attention operation specifically. Biggest win for long sequences, where the N×N matrix would be enormous. Helps both prefill and decode.
Kernel fusion → memory round-trips
Targets the HBM round-trips between small operations across the whole model. Keeps intermediates in SRAM. Broad, consistent improvement everywhere.
CUDA Graphs → CPU launch overhead
Targets the per-kernel dispatch cost, which dominates decode. The single biggest decode-specific win. Useless for prefill (variable shapes), essential for decode.
torch.compile → redundant work
Targets redundant operations and missed fusion opportunities across the whole graph. A compiler-level catch-all that finds optimizations humans miss.
Things beginners get wrong about optimization
enforce_eager=True disables CUDA Graphs precisely because eager mode gives readable error messages and easier debugging. The right move is eager mode while developing, full optimization in production. Blindly enabling everything can make a system harder to debug with little benefit during development.
Quiz
Three questions on the optimization stack.
1. Why does FlashAttention provide a speedup, given that it computes exactly the same mathematical result as naive attention?
2. CUDA Graphs are used for nano-vLLM's decode path but not its prefill path. Why?
3. Why do the four optimizations multiply together (≈3.8×) rather than just adding up?
What you now know
Optimization is a stack, not a single trick. FlashAttention, kernel fusion, CUDA Graphs, and torch.compile each attack a different bottleneck and compound multiplicatively — together often 3–4× over a naive baseline.
FlashAttention avoids the HBM write. By tiling attention into SRAM-sized chunks and using a running softmax, it never materialises the N×N matrix in HBM. Exact same result, 5–10× less memory traffic.
Kernel fusion cuts memory round-trips. Combining several operations into one kernel keeps intermediates in fast SRAM instead of bouncing to slow HBM between each step.
CUDA Graphs kill launch overhead. Capture the decode kernel sequence once, replay it as one unit — eliminating per-kernel CPU dispatch cost. Works for decode (fixed shape), not prefill (variable shape).
torch.compile optimises the whole graph. A one-time compile cost buys permanent speedup via automatic fusion and redundancy elimination — an excellent trade for a long-running server.
Match the optimization to the bottleneck. Knowing prefill is compute-bound and decode is memory-bound → Ch.06 is what tells you which optimization helps which phase. nano-vLLM composes libraries rather than reinventing them.