CS336 Notes: Lecture 8 - Parallelism 2
Training on many GPUs is about one thing: getting the most compute while moving the least data.
This lecture from Stanford CS336 covers hands-on distributed training. The constraint that matters: memory and interconnect bandwidth form a hierarchy, and your job is to keep work close to data.
The Memory Hierarchy
A node typically has 8 GPUs. Each GPU has streaming multiprocessors (SMs) that do the math. Around them sits a speed ladder:
- L1 cache / shared memory: tiny, very fast, local to each SM.
- HBM (high bandwidth memory): larger, slower than L1, on each GPU.
- NVLink: connects GPUs in the same node, fast.
- NVSwitch: the switch fabric that extends NVLink to every GPU in a node (and, in NVLink Switch systems, across nodes); another step down the ladder.
- PCIe and Ethernet: slower links with more overhead, used for host traffic and cross-node communication.
Best case: your data is in L1. If not, it comes from HBM. In multi-GPU setups, it may come from another GPU.
The goal: do most math on data that's close. Avoid moving data over NVLink, NVSwitch, PCIe, and Ethernet unless you must. Keep arithmetic intensity high.
Collective Operations
Collectives are standard multi-device communication patterns. They're the building blocks of all distributed training.
Terms:
- world_size: number of devices in the group.
- rank: the device ID, from 0 to world_size - 1.
Main operations:
- Broadcast: one rank sends a tensor to all ranks.
- Scatter: one rank sends different slices to each rank.
- Gather: one rank collects values from all ranks.
- Reduce: combines values from all ranks with an op (sum, min, max) and writes the result on one rank.
- all_gather: like gather, but every rank gets the full collected result.
- reduce_scatter: reduce across ranks, then scatter slices of the reduced result.
- all_reduce: everyone gets the reduced result. Conceptually a reduce followed by a broadcast, usually implemented as a reduce_scatter followed by an all_gather.
Memory hooks: Reduce = combine across ranks. Scatter = split outputs across ranks. Gather = collect outputs onto one rank. All = everyone receives the output.
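For concreteness, here is a minimal sketch of two of the "one rank" collectives in torch.distributed. It assumes a process group is already initialized (the gloo backend handles CPU tensors), and the tensor contents are purely illustrative.

```python
import torch
import torch.distributed as dist

def broadcast_and_gather(rank: int, world_size: int):
    # Assumes dist.init_process_group(...) has already run in this process
    # (see the setup sketch below); gloo works for these CPU tensors.

    # Broadcast: rank 0's values overwrite every rank's copy.
    t = torch.arange(4, dtype=torch.float32) if rank == 0 else torch.zeros(4)
    dist.broadcast(t, src=0)                     # every rank now holds [0, 1, 2, 3]

    # Gather: rank 0 collects one tensor from each rank; other ranks pass no list.
    local = torch.full((4,), float(rank))
    out = [torch.empty(4) for _ in range(world_size)] if rank == 0 else None
    dist.gather(local, gather_list=out, dst=0)   # on rank 0, out[i] came from rank i
```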
From Hardware to Software
NCCL (NVIDIA Collective Communication Library) understands topology: NVLink, NVSwitch, PCIe. It implements collectives with ring and tree algorithms. On startup, ranks discover topology and pick paths. Collectives launch CUDA kernels that move data GPU-to-GPU with little CPU work.
PyTorch's torch.distributed exposes collectives like all_reduce, reduce_scatter, all_gather, broadcast, and barrier. It supports backends: nccl for GPU collectives, gloo for CPU collectives (useful for debugging).
Typical setup:
- Launch one Python process per rank.
- Call dist.init_process_group with world_size, rank, backend, and an init method.
- Call torch.distributed collectives on tensors.
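A minimal sketch of that setup, assuming a single machine, CPU tensors, and the gloo backend; the address, port, and world size are illustrative.

```python
import os
import torch.distributed as dist
import torch.multiprocessing as mp

def setup(rank: int, world_size: int):
    # Rank 0 acts as the rendezvous point; the values here are illustrative.
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "29500"
    # gloo runs collectives on CPU tensors (handy for debugging);
    # swap in "nccl" for GPU tensors on a multi-GPU node.
    dist.init_process_group(backend="gloo", rank=rank, world_size=world_size)

def worker(rank: int, world_size: int):
    setup(rank, world_size)
    # ... call torch.distributed collectives on tensors here ...
    dist.barrier()        # wait until every rank reaches this point
    print(f"rank {rank} of {world_size} is up")
    dist.destroy_process_group()

if __name__ == "__main__":
    world_size = 4
    # One Python process per rank; mp.spawn passes the process index as `rank`.
    mp.spawn(worker, args=(world_size,), nprocs=world_size, join=True)
```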
barrier makes all processes wait until everyone reaches the same point. Useful for debugging, ordered printing, and sometimes correctness.
all_reduce Example
Each rank creates t = [0, 1, 2, 3] + rank.
Rank 0: [0, 1, 2, 3]
Rank 1: [1, 2, 3, 4]
Rank 2: [2, 3, 4, 5]
Rank 3: [3, 4, 5, 6]
Call dist.all_reduce(t, op=SUM).
Afterward, every rank holds the element-wise sum: [6, 10, 14, 18].
The tensor updates in place. Every rank ends with the same result.
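A sketch of that example, assuming each rank runs inside an initialized process group like the setup above (gloo/CPU and nccl/GPU both support all_reduce):

```python
import torch
import torch.distributed as dist

def all_reduce_example(rank: int):
    # Assumes dist.init_process_group(...) has already run in this process.
    t = torch.arange(4, dtype=torch.float32) + rank   # rank 1 holds [1, 2, 3, 4]
    dist.all_reduce(t, op=dist.ReduceOp.SUM)          # in place; returns when done
    # With world_size = 4, every rank now holds [6, 10, 14, 18].
    print(f"rank {rank}: {t.tolist()}")
```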
reduce_scatter Example
Each rank starts with an input tensor whose first dimension is world_size (shape [4, N] when world_size = 4).
reduce_scatter reduces across ranks, then splits the result: slice 0 goes to rank 0, slice 1 to rank 1, and so on. Each rank receives one slice (shape [N]).
If you follow reduce_scatter with all_gather, you rebuild the same output as all_reduce.
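A sketch of that equivalence, assuming an initialized nccl group with one GPU per rank (gloo does not implement reduce_scatter) and a recent PyTorch that provides the tensor variants of these collectives:

```python
import torch
import torch.distributed as dist

def reduce_scatter_then_all_gather(rank: int, world_size: int, n: int = 8) -> torch.Tensor:
    device = torch.device(f"cuda:{rank}")            # assumes one GPU per rank on one node
    x = torch.randn(world_size, n, device=device)    # first dimension indexed by rank

    # Reduce across ranks, then keep only slice `rank` of the result: shape [n].
    shard = torch.empty(n, device=device)
    dist.reduce_scatter_tensor(shard, x, op=dist.ReduceOp.SUM)

    # all_gather the shards back in rank order; reshaped, this matches what
    # dist.all_reduce(x) would have produced on every rank.
    full = torch.empty(world_size * n, device=device)
    dist.all_gather_into_tensor(full, shard)
    return full.view(world_size, n)
```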
Two common questions:
How does it know where slices go? It uses the first dimension. Slice i goes to rank i. Shapes must match across ranks.
How do ranks stay in sync? All ranks must call the same collectives in the same order with compatible shapes. If one rank skips a call, the others wait forever.
Benchmarking Collectives
To measure communication speed:
- Warm up first so kernels are loaded.
- Synchronize before and after timing.
- Use large tensors so fixed overhead doesn't dominate.
- Compute effective bandwidth from bytes moved and time.
all_reduce benchmark example:
world_size = 4. Each rank has 100,000,000 float32 values (400 MB per rank).
A rough estimate of bytes moved: about 2 × size_bytes × (world_size − 1) in total across the ring. The factor of 2 comes from the two phases of a ring all_reduce: one pass to reduce the data, one to redistribute the result to every rank.
The example measured about 277 GB/s, below H100 NVLink peak (around 900 GB/s). That gap is normal. It depends on tensor size, algorithm, overlap, and topology.
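A benchmark sketch along these lines, assuming an initialized nccl group with one GPU per rank; the byte count follows the rough estimate above, so read the printed figure as an aggregate effective bandwidth, not a precise hardware measurement:

```python
import time
import torch
import torch.distributed as dist

def benchmark_all_reduce(rank: int, world_size: int,
                         numel: int = 100_000_000, iters: int = 5) -> None:
    device = torch.device(f"cuda:{rank}")
    t = torch.randn(numel, dtype=torch.float32, device=device)   # ~400 MB per rank

    for _ in range(3):                              # warm-up: load kernels, set up channels
        dist.all_reduce(t, op=dist.ReduceOp.SUM)

    torch.cuda.synchronize(device)                  # make sure warm-up finished
    start = time.time()
    for _ in range(iters):
        dist.all_reduce(t, op=dist.ReduceOp.SUM)
    torch.cuda.synchronize(device)                  # wait for the last kernel before timing
    seconds = (time.time() - start) / iters

    size_bytes = t.element_size() * t.numel()
    moved_bytes = 2 * size_bytes * (world_size - 1)  # rough total-traffic estimate from above
    if rank == 0:
        print(f"effective bandwidth ~ {moved_bytes / seconds / 1e9:.1f} GB/s")
```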
Key points: Know what's sent and received. Estimate bytes per operation. Time with synchronization. Compare effective bandwidth to hardware limits.
Data Parallel Training
Data parallelism is the most common setup. Every GPU holds a full copy of the model. The batch splits across GPUs.
Example with a deep MLP:
Input batch shape: [batch_size, hidden_dim]. Each layer: [hidden_dim × hidden_dim] matmul + nonlinearity.
Batch split: local_batch_size = batch_size / world_size. Rank r uses its slice.
Model and optimizer: every rank stores full parameters, every rank has its own optimizer.
Training loop per rank:
- Forward on local batch.
- Compute local loss (differs across ranks because data differs).
- Backward to get local gradients.
- Synchronize gradients: all_reduce each parameter's gradient with SUM, divide by world_size for average.
- After all_reduce, gradients match across ranks.
- Optimizer step: each rank updates its local model. Parameters stay identical.
Properties: Parameters start equal (same init, same RNG seed). Parameters stay equal because gradients synchronize every step.
The key rule: all_reduce is communication and synchronization. If any rank skips it, other ranks block.
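A sketch of one data-parallel step, assuming an initialized process group and a model constructed from the same seed on every rank; the loss function and the slicing of the global batch are illustrative:

```python
import torch
import torch.distributed as dist
import torch.nn.functional as F

def data_parallel_step(model: torch.nn.Module, optimizer: torch.optim.Optimizer,
                       batch_x: torch.Tensor, batch_y: torch.Tensor,
                       rank: int, world_size: int) -> None:
    # Slice the global batch; assumes batch_size divides evenly by world_size.
    local_bs = batch_x.shape[0] // world_size
    x = batch_x[rank * local_bs:(rank + 1) * local_bs]
    y = batch_y[rank * local_bs:(rank + 1) * local_bs]

    optimizer.zero_grad()
    loss = F.mse_loss(model(x), y)     # local loss differs across ranks
    loss.backward()                    # local gradients

    # Synchronize: sum each gradient across ranks, then average.
    for p in model.parameters():
        if p.grad is not None:
            dist.all_reduce(p.grad, op=dist.ReduceOp.SUM)
            p.grad /= world_size

    # Identical gradients -> identical updates -> parameters stay identical.
    optimizer.step()
```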
Tensor Model Parallelism
Tensor parallelism splits the model itself, not the batch.
Why you need it: Sometimes the model is too large for one GPU, even with batch size 1.
MLP example: split the hidden dimension
local_num_dim = hidden_dim / world_size. Each rank holds [hidden_dim × local_num_dim] for each layer.
Forward pass:
- Start with activations x (shape [batch_size × hidden_dim]), identical on all ranks.
- For each layer: each rank computes local_x = x @ local_W (shape [batch_size × local_num_dim]). Apply nonlinearity. all_gather the shards from every rank. Concatenate to rebuild x as [batch_size × hidden_dim].
- Repeat for next layer.
At the end, all ranks have the same final activations.
Costs: Each rank stores 1/world_size of parameters, so you can scale width. But you pay communication every layer. This only works well with very fast links like NVLink.
Backward follows the same idea in reverse, using reduce_scatter and all_reduce patterns to combine gradients.
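A forward-only sketch, assuming an initialized group, per-layer weight shards of shape [hidden_dim, hidden_dim / world_size], and ReLU as an illustrative nonlinearity. Note that dist.all_gather does not propagate gradients, so a real implementation pairs this with autograd-aware collectives for the backward pass.

```python
import torch
import torch.distributed as dist

def tensor_parallel_forward(x: torch.Tensor, local_weights: list[torch.Tensor],
                            world_size: int) -> torch.Tensor:
    # x: [batch_size, hidden_dim], identical on all ranks.
    # Each local weight: [hidden_dim, hidden_dim // world_size] (this rank's shard).
    for local_w in local_weights:
        local_x = torch.relu(x @ local_w)          # [batch_size, hidden_dim // world_size]
        # Collect every rank's shard and rebuild the full activation.
        shards = [torch.empty_like(local_x) for _ in range(world_size)]
        dist.all_gather(shards, local_x)
        x = torch.cat(shards, dim=-1)              # [batch_size, hidden_dim]
    return x
```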
Pipeline Parallelism
Pipeline parallelism splits the model by depth.
Divide layers into chunks, assign chunks to ranks, pass activations forward.
Example: A 4-layer MLP, world_size = 2. Rank 0 holds layers 0 and 1. Rank 1 holds layers 2 and 3.
Naive pipeline problem: Rank 1 waits while rank 0 processes the full batch, then rank 0 waits while rank 1 finishes. Creates bubbles.
Microbatching fills the pipeline: Split the batch into microbatches. Example: batch size 128 becomes 4 microbatches of 32.
Forward schedule:
Rank 0, for each microbatch: run layers 0 and 1, send activations to rank 1.
Rank 1, for each microbatch: recv activations from rank 0, run layers 2 and 3, store outputs.
With microbatches, rank 0 works on microbatch k+1 while rank 1 works on microbatch k.
Point-to-point primitives: send(tensor, dst_rank) sends a tensor. recv(tensor, src_rank) receives into a provided tensor. In simple code, these block; isend and irecv return handles for asynchronous transfers.
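A two-rank, forward-only sketch using those primitives, assuming an initialized process group and that both ranks agree on the number and shape of the microbatches (on rank 1 only the shapes are used):

```python
import torch
import torch.distributed as dist

def pipeline_forward(rank: int, stage: torch.nn.Module,
                     microbatches: list[torch.Tensor], hidden_dim: int) -> list[torch.Tensor]:
    # stage holds layers 0-1 on rank 0 and layers 2-3 on rank 1.
    outputs = []
    for mb in microbatches:
        if rank == 0:
            acts = stage(mb)                        # run the first chunk of layers
            dist.send(acts, dst=1)                  # blocking point-to-point send
        else:
            acts = torch.empty(mb.shape[0], hidden_dim)
            dist.recv(acts, src=0)                  # blocks until the matching send arrives
            outputs.append(stage(acts))             # run the remaining layers
    return outputs
```

Once send returns for microbatch k, rank 0 moves on to microbatch k+1 while rank 1 is still computing on microbatch k, which is exactly the overlap described above.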
Backward pass: The last stage computes loss and gradients, sends gradients back. Earlier stages backprop through their layers and send gradients further back.
Safety rules: Every send needs a matching recv. Sends between the same pair preserve order. If a send has no matching recv, you deadlock.
Higher-Level Sharding with JAX
PyTorch often makes you write collectives directly.
In JAX with TPUs: you describe the model and its logical axes (batch, width, heads, sequence) and a sharding plan. You specify which axes are sharded or replicated. The compiler generates the required collectives.
Tools like Levanter express FSDP-like and tensor-parallel layouts by describing sharding. The compiler chooses the communication.
The Same Ideas Repeat
Memory and links form a speed ladder from small/fast to large/slow.
You want computation close to data. You want high FLOPs per byte moved. You want to overlap compute and communication.
Parallelism is choosing how to split the work:
- Data parallel splits the batch.
- Tensor parallel splits the hidden dimension.
- Pipeline parallel splits layers across ranks.
- Other axes exist too, like sequence length in attention.
Every design trades off three resources:
- Compute: sometimes cheaper to recompute than to store or move.
- Memory: activations and parameters consume limited GPU memory.
- Communication: moving tensors between devices is slow, so minimize it or hide it.
Hardware keeps improving. Models keep growing to fill it. The limits don't go away. Understanding these patterns is what lets you use the hardware well.