CS336 Notes: Lecture 5 - GPUs
Compute has grown faster than memory bandwidth. Moving data is now the main bottleneck. This lecture teaches you to think about GPU performance in terms of bytes moved, not just FLOPs.
The core constraint: modern GPUs are often memory-bound, not compute-bound. Hardware-efficient algorithms are memory-efficient algorithms.
Eight Claims About GPU Performance
GPUs are built for throughput, not latency. Many simple cores, little control logic, same instruction across many data points at once.
Modern LLM progress comes as much from GPU hardware and parallelism as from new model ideas.
Compute has grown far faster than memory bandwidth. Moving data to and from global memory is where time goes.
GPU performance collapses when control flow splits within a warp or when memory access is sloppy. Too many global reads and writes, bad access patterns, weak tiling, and misaligned sizes all hurt.
You can win back speed with a small set of moves. Lower precision, kernel fusion, recomputation, coalesced memory access, and tiling into shared memory.
Matmul speed curves look wavy. Performance depends on tile sizes, SM count, burst access, and whether dimensions divide cleanly.
FlashAttention avoids reading or writing the full n² attention matrix to global memory. It combines tiled matmuls, online softmax, and recomputation.
The roofline model tells you whether you're memory-bound or compute-bound. Low arithmetic intensity means memory-bound. High arithmetic intensity means compute-bound.
Why GPUs Matter
Large language models follow scaling laws: more compute and more data tend to improve performance.
For decades, compute scaling came from faster single cores. Dennard scaling and Moore's law delivered more, smaller, faster, lower-power transistors.
Then single-thread gains flattened. Transistor counts kept rising, but clock speeds and single-thread performance stopped climbing.
Deep learning shifted the game to parallelism: do more operations at once instead of making one core faster.
GPU generations (K20 → V100 → A100 → H100) show huge gains in operations per second, especially for matrix multiplication.
To use that hardware, you need to understand how GPUs execute work and what keeps them fed.
CPU vs GPU
CPUs and GPUs optimize for different goals.
CPU
- Optimized for low latency on one task.
- A few complex cores.
- Heavy control logic: branch prediction, out-of-order execution, big caches.
- Strong on branched, irregular workloads.
GPU
- Optimized for high throughput across many similar tasks.
- Many simple compute units with minimal control per unit.
- Uses SIMT: one instruction runs on many threads at once, each on different data.
- Best for large batches of similar arithmetic, especially vector and matrix math.
If you have tasks T1, T2, T3, T4:
- A CPU tends to finish each quickly, one at a time.
- A GPU tends to finish the whole set faster, even if any single task takes longer.
GPU Anatomy
Streaming Multiprocessor (SM)
The basic scheduling and control unit. Owns fast local storage: registers and shared/L1 memory. Runs many warps in flight to hide stalls.
Streaming Processors (SPs)
Simple arithmetic units inside an SM. Execute the same instruction across different data.
Tensor Cores / Matrix Units
Specialized hardware for fast matrix multiplication. Tensor-core matmul throughput is far higher than the general-purpose FLOP rate. Modern models need to be matmul-heavy to fully use the GPU.
An A100 has over 100 SMs, each with many SPs and tensor cores.
Threads, Warps, and Blocks
GPU work is arranged into threads, warps, and blocks.
Thread: Smallest unit of work. Has its own registers and local state.
Warp: A group of 32 threads. Executes in lockstep: same instruction, different data.
Block: A group of threads scheduled onto one SM. Threads in a block can use shared memory and synchronize with each other.
Blocks map to SMs. Warps schedule within blocks. Threads execute together within a warp under SIMT.
Branch Divergence
A warp runs one instruction at a time. If threads choose different branches, the warp can't do both branches at once.
Example: if thread_id < 4 do A else do B.
Within a warp:
- Threads 0-3 run A while the rest sit idle.
- Then the remaining threads run B while threads 0-3 sit idle.
The warp effectively runs both paths serially and pays for both. This is branch divergence. Any conditional that splits threads within a warp is expensive.
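A toy Python model of the cost, purely for illustration (this is not how you'd write GPU code): a path is issued whenever at least one lane in the warp needs it, and lanes on the other path idle through it.

```python
# Toy model of SIMT branch divergence (illustration only, not GPU code).
WARP_SIZE = 32

def warp_cycles(predicate, cost_a, cost_b):
    takes_a = [predicate(lane) for lane in range(WARP_SIZE)]
    # The warp issues a path whenever at least one lane needs it;
    # lanes on the other path sit idle while that path runs.
    cycles = 0
    if any(takes_a):
        cycles += cost_a
    if not all(takes_a):
        cycles += cost_b
    return cycles

print(warp_cycles(lambda lane: lane < 4, cost_a=100, cost_b=100))  # 200: divergent warp pays for both
print(warp_cycles(lambda lane: True, cost_a=100, cost_b=100))      # 100: uniform warp pays once
```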
GPU Memory Hierarchy
GPUs have a strict memory hierarchy. The closer the memory, the faster it is, and the smaller it is.
Registers: Per-thread storage. Fastest.
Shared / L1 Memory: Per-SM memory shared across a block. Very fast. Good for tiles and reused data.
L2 Cache: Shared across SMs. Slower than shared/L1, faster than DRAM.
Global Memory (HBM/DRAM): Off-chip. Large and much slower than on-chip memory.
If kernels constantly read and write global memory, SMs spend time waiting. On modern GPUs, many workloads are limited by memory bandwidth, not compute.
TPUs and Other Accelerators
TPUs follow similar ideas with a simpler structure.
A TPU has a core with:
- A small scalar unit for control.
- A vector unit for elementwise ops.
- A large matrix unit (MXU) for batch matmuls.
- Fast on-core memories for activations and vectors.
- High-bandwidth memory outside the core.
Both GPUs and TPUs do most work as matrix multiplications, rely on fast on-chip memory, and scale by running many devices in clusters.
Compute vs Memory Scaling
Different parts of the system have scaled at different rates:
- Host-device links (PCIe, NVLink) improved slowly.
- Global memory bandwidth (GDDR → HBM2E) improved more, but still modest.
- Compute, especially matmul FLOPs, grew by orders of magnitude.
Older GPUs could be FLOP-bound. Modern GPUs are often memory-bound.
The Roofline Model
The roofline model frames performance as a function of arithmetic intensity.
- Horizontal axis: FLOPs per byte moved.
- Vertical axis: achieved FLOPs per second.
Two regimes:
- Low intensity: memory-bound. Throughput rises with intensity.
- High intensity: compute-bound. Throughput plateaus at peak compute.
Most of the tricks in this lecture either increase arithmetic intensity or cut memory traffic.
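As a back-of-the-envelope check, you can estimate the ridge point from peak compute and bandwidth and compare it to a kernel's arithmetic intensity. The numbers below are rough A100-like figures used only for illustration.

```python
# Back-of-the-envelope roofline numbers (rough A100-like figures, assumed for illustration).
peak_flops = 312e12      # ~312 TFLOP/s bf16 tensor-core peak (assumed)
peak_bytes = 2.0e12      # ~2 TB/s HBM bandwidth (assumed)

ridge = peak_flops / peak_bytes          # FLOPs per byte needed to become compute-bound
print(f"ridge point: ~{ridge:.0f} FLOPs/byte")

# Elementwise ReLU in bf16: 1 FLOP per element, 2 bytes read + 2 bytes written.
print(f"ReLU: {1 / 4} FLOPs/byte -> memory-bound")

# Square n x n matmul in bf16: 2n^3 FLOPs, roughly 3 * n^2 * 2 bytes if A, B, C each move once.
n = 4096
print(f"{n}x{n} matmul: ~{(2 * n**3) / (3 * n**2 * 2):.0f} FLOPs/byte -> compute-bound")
```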
Lower Precision
Fewer bits per number means fewer bytes moved, and often more math per cycle.
Example: ReLU, x = max(0, x).
float32: Read 4 bytes, write 4 bytes. 8 bytes per element.
float16: Read 2 bytes, write 2 bytes. 4 bytes per element.
Same operations, half the traffic, so arithmetic intensity doubles. Lower precision also unlocks tensor cores, which hit peak throughput on fp16/bf16 inputs.
In practice:
- Inputs and weights often use fp16, bf16, or int8.
- Matmul accumulates in fp32 for stability.
- Sensitive ops (exp, softmax, norms) may use higher precision or special handling.
Mixed precision needs careful casting, but the speed gains are large.
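In PyTorch this usually means autocast: a minimal sketch, with placeholder sizes, that runs the matmul in bf16 while keeping the weights in fp32.

```python
import torch

# Minimal mixed-precision sketch (placeholder sizes; falls back to CPU if no GPU).
device = "cuda" if torch.cuda.is_available() else "cpu"
model = torch.nn.Linear(4096, 4096).to(device)        # master weights stay fp32
x = torch.randn(8, 4096, device=device)

with torch.autocast(device_type=device, dtype=torch.bfloat16):
    y = model(x)                   # matmul runs in bf16 (on tensor cores on a GPU)
    loss = y.square().mean()

loss.backward()                    # autocast inserts the casts; weight grads land in fp32
print(y.dtype, model.weight.dtype) # torch.bfloat16 torch.float32
```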
Operator Fusion
Naive GPU code often launches one kernel per small operation. That creates repeated round trips to global memory.
Example: y = sin(x)² + cos(x)².
Naive:
- sin → write s
- cos → write c
- square s → write s2
- square c → write c2
- add → write y
Each step reads from global and writes back.
Fused:
- Load x once.
- Compute everything in registers or shared memory.
- Write y once.
Global traffic drops and speed rises. Tools like torch.compile can fuse many chains automatically.
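A small sketch of the example above: the eager function launches one kernel per op, and wrapping it in torch.compile lets the compiler fuse the chain into fewer kernels. Speedups depend on the GPU and PyTorch version.

```python
import torch

def trig_identity(x):
    # Eager mode: separate kernels for sin, cos, the two squares, and the add,
    # each one reading its input from and writing its output to global memory.
    return torch.sin(x) ** 2 + torch.cos(x) ** 2

# torch.compile can fuse the whole chain: load x once, compute in registers, write y once.
fused = torch.compile(trig_identity)

x = torch.randn(10_000_000, device="cuda" if torch.cuda.is_available() else "cpu")
y = fused(x)   # first call compiles; later calls reuse the fused kernel
print(torch.allclose(y, torch.ones_like(y), atol=1e-3))
```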
Recomputation Instead of Storing
Backprop often stores intermediate activations for reuse. Those activations live in global memory, so storing and rereading them is costly.
Example: three stacked sigmoids.
Naive forward:
- Compute s1, s2, s3.
- Store all three and the output.
Naive backward:
- Read s1, s2, s3 back from global memory.
- Compute gradients.
Recomputation:
- In forward, don't store internal sigmoids.
- In backward, recompute them from x inside the kernel, then compute gradients.
You trade extra compute for less memory traffic. On modern GPUs, that's often a win because compute is cheap and memory is scarce. This is similar to checkpointing, but used for speed as much as memory.
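PyTorch exposes this trade as activation checkpointing. A minimal sketch of the stacked-sigmoid example, with arbitrary shapes: the checkpointed region recomputes its intermediates during backward instead of storing them.

```python
import torch
from torch.utils.checkpoint import checkpoint

def three_sigmoids(x):
    # Inside a checkpointed region, s1 and s2 are not saved for backward.
    s1 = torch.sigmoid(x)
    s2 = torch.sigmoid(s1)
    return torch.sigmoid(s2)

x = torch.randn(1024, 1024, requires_grad=True)

# During backward, the forward is rerun to rebuild s1 and s2 instead of
# reading them back from memory. use_reentrant=False is the modern API.
y = checkpoint(three_sigmoids, x, use_reentrant=False)
y.sum().backward()
print(x.grad.shape)   # gradients match the non-checkpointed version
```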
Burst Mode and Memory Coalescing
Global memory moves data in bursts, not single scalars.
Fetching element 0 typically pulls in a whole aligned chunk. The expensive part is starting the transfer. Nearby bytes are cheap once the burst begins.
If threads in a warp read nearby addresses, hardware can coalesce those reads into a few burst requests. If they read scattered addresses, you trigger many bursts and waste bandwidth.
Coalesced access: Threads read contiguous addresses. High effective bandwidth.
Non-coalesced access: Threads read far apart addresses. Bandwidth collapses.
Matrix example: with row-major storage, if each thread walks across its own row, then at every step the warp's threads touch addresses a full row apart, so the accesses are strided and coalescing is poor. If consecutive threads in the warp read consecutive elements of the same row, the accesses are contiguous and coalescing is good.
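You normally see coalescing only at the CUDA level, but a rough proxy leaks up to PyTorch: copying from a contiguous tensor streams memory in order, while copying from a transposed (strided) view does not. A timing sketch, assuming a CUDA device; the framework's copy kernel partially hides the gap, so treat the numbers as illustrative.

```python
import time
import torch

# Rough PyTorch-level proxy for coalescing (assumes a CUDA device; illustrative only).
n = 8192
x = torch.randn(n, n, device="cuda")
dst = torch.empty(n, n, device="cuda")
dst.copy_(x); dst.copy_(x.t())   # warm up the kernels before timing

def timed(fn, iters=10):
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(iters):
        fn()
    torch.cuda.synchronize()
    return time.perf_counter() - start

print("copy from contiguous source :", timed(lambda: dst.copy_(x)))      # in-order reads
print("copy from transposed source :", timed(lambda: dst.copy_(x.t())))  # strided reads
```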
Tiling and Matrix Multiplication
Matmul is the core GPU workload. Naively, threads repeatedly pull the same A and B values from global memory.
Tiling fixes this.
- Split A and B into tiles.
- Load an A tile and a B tile into shared memory.
- Compute partial results for a C tile.
- Repeat over tile pairs that contribute to that C tile.
Now each global value is loaded far fewer times. Within a tile, reuse happens from shared memory. This cuts global reads by roughly the tile size factor and shifts work onto fast on-chip memory.
Tile choice is constrained by shared memory size, warp structure, coalescing, and whether dimensions divide cleanly.
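A reference-level Python sketch of the tiling loop structure, with an arbitrary tile size. In a real kernel the tile loads would go to shared memory and the inner products to tensor cores; this version only shows how the reuse is organized.

```python
import torch

def tiled_matmul(A, B, tile=64):
    # Each (i, j) output tile accumulates contributions from matching tiles of
    # A and B. In a real kernel, A_tile and B_tile would live in shared memory.
    n, k = A.shape
    k2, m = B.shape
    assert k == k2
    C = torch.zeros(n, m, dtype=A.dtype)
    for i in range(0, n, tile):
        for j in range(0, m, tile):
            acc = torch.zeros(min(tile, n - i), min(tile, m - j), dtype=A.dtype)
            for p in range(0, k, tile):
                A_tile = A[i:i + tile, p:p + tile]   # "load into shared memory"
                B_tile = B[p:p + tile, j:j + tile]
                acc += A_tile @ B_tile               # reuse both tiles many times
            C[i:i + tile, j:j + tile] = acc
    return C

A, B = torch.randn(256, 192), torch.randn(192, 128)
print(torch.allclose(tiled_matmul(A, B), A @ B, atol=1e-4))
```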
Why Matmul Speed Looks Wavy
Throughput vs matrix size often rises overall, but with weird dips and waves.
Divisibility by Tile Size
If dimensions are multiples of tile, warp, and burst sizes, hardware stays busy. If not, you get partial tiles, idle threads, and extra memory transactions.
Awkward sizes can be much slower than nearby "nice" sizes.
Tiles vs SM Count
Each tile maps to a block and then an SM.
Example: if a matrix needs 98 tiles and the GPU has 108 SMs, everything runs in one wave. If it needs 120 tiles, 108 run in the first wave and 12 run in a nearly empty second wave. Throughput drops at those boundaries.
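The arithmetic behind the example, with the 108-SM figure and a simplifying assumption of one block per SM per wave:

```python
import math

# Wave scheduling on a 108-SM GPU (simplified: assume one block per SM per wave).
sms = 108
for tiles in (98, 108, 120, 216, 217):
    waves = math.ceil(tiles / sms)
    last_wave = tiles - (waves - 1) * sms   # blocks left for the final wave
    print(f"{tiles:3d} tiles -> {waves} wave(s), last wave fills {last_wave}/{sms} SMs")
```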
Burst Alignment
If tile widths align with DRAM burst boundaries, rows fit neatly into a few bursts. Add a column and rows can cross a burst boundary, doubling required bursts.
These effects combine into the "wavy" plots.
FlashAttention
Standard attention does:
S = QKᵀ
softmax(S) row-wise
output = softmax(S)V
For sequence length n, S is n×n. Storing S or the softmax weights is O(n²) memory. Reading and writing that matrix from global memory is the real cost.
You can't avoid O(n²) math for general attention, but you can avoid O(n²) global memory traffic.
FlashAttention computes exact attention while keeping HBM access far below the naive approach.
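To make the memory cost concrete, here is a naive PyTorch attention with placeholder sizes (the 1/√d scale is the usual one): the n×n score and weight matrices dwarf Q, K, and V.

```python
import torch

def naive_attention(Q, K, V):
    # Materializes the full n x n score and weight matrices.
    d = Q.shape[-1]
    S = Q @ K.T / d ** 0.5              # (n, n) scores
    P = torch.softmax(S, dim=-1)        # (n, n) attention weights
    return P @ V

n, d = 8192, 64
Q, K, V = (torch.randn(n, d) for _ in range(3))
out = naive_attention(Q, K, V)

# The n x n intermediates dominate the traffic: ~n^2 * 4 bytes each in fp32.
print(f"scores: {n * n * 4 / 2**20:.0f} MiB vs Q/K/V: {n * d * 4 / 2**20:.0f} MiB each")
```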
Online Softmax for Tiled Attention
Softmax is row-global, which seems to prevent tiling. Online softmax solves that.
Stable softmax uses:
- max for the row
- exp(x - max)
- sum of exps
- divide
Online softmax processes the row in chunks:
- Maintain a running max m and running sum d for the processed prefix.
- When a new chunk arrives, update m and d, rescaling d if the max increases.
This lets you stream over score tiles, update normalization on the fly, and avoid storing the full row or the full n×n score matrix in global memory.
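A small sketch of online softmax over a single row processed in chunks, checked against the ordinary softmax; the chunk size is arbitrary.

```python
import torch

def online_softmax(row, chunk=128):
    # One pass over the row in chunks, keeping only a running max (m) and a
    # running sum of exponentials (d); d is rescaled whenever m increases.
    m = torch.tensor(float("-inf"))
    d = torch.tensor(0.0)
    for start in range(0, row.numel(), chunk):
        block = row[start:start + chunk]
        m_new = torch.maximum(m, block.max())
        d = d * torch.exp(m - m_new) + torch.exp(block - m_new).sum()
        m = m_new
    # A second pass (here vectorized) applies the final normalization.
    return torch.exp(row - m) / d

row = torch.randn(1000)
print(torch.allclose(online_softmax(row), torch.softmax(row, dim=0), atol=1e-5))
```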
FlashAttention Forward Pass
The forward pass combines:
- Tiled matmuls for QKᵀ and for applying V.
- Online softmax to normalize across tiles.
- Shared memory and registers to avoid n² storage.
Sketch:
- Partition Q, K, and V along the sequence dimension.
- For each query tile:
  - Load the Q tile and a K tile into shared memory.
  - Compute the score tile.
  - Update the running max and running sum for each row.
  - Accumulate the partial output with the matching V tile, rescaling when the running max changes.
  - Move to the next K tile.
Once all K tiles are processed, the running sum gives the final normalization, the accumulated output is divided by it, and the tile is written out. The full score matrix is never materialized in global memory.
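A reference-level Python sketch of the tiled forward pass for a single head, with arbitrary tile sizes. It keeps only the running max, running sum, and output accumulator per query tile (the state a real kernel would hold on-chip); it is a sketch of the idea, not the actual FlashAttention kernel.

```python
import torch

def flash_attention_forward(Q, K, V, tile=128):
    # Tiled exact attention: the n x n score matrix is never materialized.
    n, d = Q.shape
    scale = d ** -0.5
    O = torch.empty_like(Q)
    for qs in range(0, n, tile):
        q = Q[qs:qs + tile] * scale
        row_max = torch.full((q.shape[0],), float("-inf"))  # running max per row
        row_sum = torch.zeros(q.shape[0])                   # running sum of exps
        acc = torch.zeros(q.shape[0], d)                    # unnormalized output
        for ks in range(0, n, tile):
            k, v = K[ks:ks + tile], V[ks:ks + tile]
            s = q @ k.T                                      # score tile
            new_max = torch.maximum(row_max, s.max(dim=1).values)
            p = torch.exp(s - new_max[:, None])              # local exponentials
            alpha = torch.exp(row_max - new_max)             # rescale old state
            row_sum = row_sum * alpha + p.sum(dim=1)
            acc = acc * alpha[:, None] + p @ v
            row_max = new_max
        O[qs:qs + tile] = acc / row_sum[:, None]             # final normalization
    return O

n, d = 1024, 64
Q, K, V = (torch.randn(n, d) for _ in range(3))
ref = torch.softmax(Q @ K.T / d ** 0.5, dim=-1) @ V
print(torch.allclose(flash_attention_forward(Q, K, V), ref, atol=1e-4))
```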
FlashAttention Backward Pass
Backward is harder because naive gradients want the softmax outputs, which are n².
FlashAttention recomputes what it needs:
- During backward, loop over tiles again.
- Recreate local scores and softmax values from Q, K, V using the forward formulas.
- Compute gradients for that tile.
- Discard intermediates.
More compute, far less memory traffic and storage. On modern GPUs, that trade is favorable.
What's Next
LLM speed comes from using GPUs well, not only from better models.
GPUs are massively parallel, but they are fragile: branch divergence hurts, and memory behavior dominates.
Compute has outrun memory. Bytes moved are the limiting resource.
The main toolkit is small:
- Use lower precision where you can.
- Fuse operators to reduce global reads and writes.
- Recompute intermediates instead of storing them.
- Coalesce global memory accesses.
- Tile big operations into shared memory with sizes that match warp, burst, and SM structure.
FlashAttention is the clean example: tiling, online softmax, recomputation, and matmul hardware turn attention from a memory problem into a GPU-friendly workload.
Next lecture: multi-GPU training. Data parallelism, model parallelism, and how to keep many GPUs fed without drowning in communication.