CS336 Notes: Lecture 6 - Kernels and Triton
Most GPU optimization advice is wrong because it skips the first step: measuring where time actually goes.
This lecture from Stanford CS336 covers GPU kernel optimization with Triton. The constraint that matters most: memory bandwidth, not compute.
What Actually Limits GPU Performance
A GPU has many streaming multiprocessors (SMs). Each SM runs threads grouped into blocks and warps (32 threads executing in lockstep).
The memory hierarchy determines speed:
- Global DRAM: large, slow. This is where your tensors live.
- Caches: smaller, faster. Recently used data stays here.
- Registers and shared memory: tiny, very fast, local to an SM.
Fast kernels follow one pattern: load from DRAM once, reuse values in registers or shared memory, write back once. Every extra global memory trip costs you.
Arithmetic Intensity Tells You the Bottleneck
Arithmetic intensity = FLOPs per byte moved.
GPU compute improves faster than memory bandwidth each generation. That makes most operations memory-bound, not compute-bound.
Rules of thumb:
- Large, well-tiled matrix multiplies: often compute-bound.
- Elementwise ops, small matmuls, poorly structured work: usually memory-bound.
If you don't know your arithmetic intensity, you don't know what to fix.
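A quick back-of-the-envelope check (fp32, counting only DRAM traffic and assuming each tensor crosses memory once) makes the contrast concrete:

# Elementwise add of two N-element fp32 tensors:
# N FLOPs, 3 * N * 4 bytes moved (read A, read B, write C).
N = 1 << 24
print(N / (3 * N * 4))               # ~0.08 FLOPs/byte -> memory-bound

# Square matmul C = A @ B with n x n fp32 matrices:
# 2 * n**3 FLOPs, roughly 3 * n**2 * 4 bytes moved.
n = 4096
print((2 * n**3) / (3 * n**2 * 4))   # ~683 FLOPs/byte -> compute-bound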
Benchmarking GPU Code Without Fooling Yourself
Three traps kill accurate GPU benchmarks:
First-call overhead. The first CUDA op compiles or loads code. That's startup, not steady state. Run warmup iterations before timing.
Asynchronous execution. Python returns immediately after queueing CUDA work. If you time without syncing, you measure the queue, not the kernel. Call torch.cuda.synchronize() before and after timing.
Noise. Thermals, OS activity, and other processes cause variance. Run multiple trials and average.
A simple benchmark helper should: warm up, sync, time several trials, sync after each, report the mean.
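A minimal sketch of such a helper (the name and defaults are illustrative):

import time
import torch

def benchmark(fn, *args, warmup: int = 3, trials: int = 10) -> float:
    """Mean wall-clock time of fn(*args) in milliseconds."""
    for _ in range(warmup):           # warm up: trigger compilation/loading, warm caches
        fn(*args)
    torch.cuda.synchronize()          # make sure all warmup work has finished
    times = []
    for _ in range(trials):
        start = time.time()
        fn(*args)
        torch.cuda.synchronize()      # wait for the queued kernels, not just the launch
        times.append(time.time() - start)
    return 1000 * sum(times) / trials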
PyTorch Profiler: Where Time Goes
The profiler breaks a workload into Python calls, aten ops, CUDA kernels, and launch overhead.
Example: A + B
Python calls result = A + B. The profiler shows aten::add, a CUDA elementwise kernel, and cudaLaunchKernel. The GPU kernel often runs in microseconds. The CPU time can be dominated by the aten dispatch layer and launch overhead.
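A minimal profiling run for this example (shapes are arbitrary):

import torch
from torch.profiler import profile, ProfilerActivity

A = torch.randn(4096, 4096, device="cuda")
B = torch.randn(4096, 4096, device="cuda")

with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    result = A + B
    torch.cuda.synchronize()

# The table lists aten::add, the elementwise CUDA kernel, and cudaLaunchKernel.
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))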
Example: torch.cdist(A, B)
This breaks into matmuls (often 70%+ of time), plus elementwise ops like pow, subtract, add, plus reductions and sqrt. Each becomes one or more CUDA kernels. If matmul dominates, that's where gains come from.
Built-in ops often use fused kernels. Naive Python math versions launch many small kernels and pay both launch overhead and extra memory traffic. Profiling reveals the difference.
Nsight Systems: CPU and GPU Together
The PyTorch profiler is useful. Nsight Systems shows the full picture: GPU timeline, CPU threads, and NVTX ranges like "forward" and "backward".
Key behaviors to watch:
The CPU queues kernels while the GPU executes them later: the CPU often runs far ahead, launching work much faster than the GPU finishes it.
Sync points stall. Printing a GPU scalar forces a synchronize. The CPU waits idle for the GPU to finish. Frequent printing cuts GPU utilization.
Python slowness often doesn't matter when most time is on GPU. The exception: huge pure-Python loops or frequent syncs.
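The named ranges come from NVTX annotations in the training script; a small sketch (the model and sizes are placeholders), run under something like nsys profile python train.py:

import torch
import torch.cuda.nvtx as nvtx

model = torch.nn.Linear(1024, 1024).cuda()
x = torch.randn(64, 1024, device="cuda")

nvtx.range_push("forward")            # shows up as a named range on the Nsight timeline
out = model(x)
nvtx.range_pop()

nvtx.range_push("backward")
out.sum().backward()
nvtx.range_pop()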
Kernel Fusion: Why It Matters
Fusion combines multiple operations into one kernel. Intermediates stay in registers instead of going to global memory. You cut memory traffic and reduce launch overhead.
GeLU example:
Naive manual PyTorch GeLU: write the formula as raw PyTorch math (multiplies, adds, tanh, x**3). Each op launches a separate kernel. Result: many kernels, many global memory passes. Time: about 8.1 ms on a large tensor.
Built-in torch.nn.functional.gelu(approximate="tanh"): uses a fused CUDA kernel. All math happens inside one kernel. Intermediates stay in registers. Time: about 1.1 ms. That's 8x faster.
Custom fused kernels in CUDA C++ or Triton: also one kernel. Time: about 1.8 ms. Much faster than naive, slightly slower than PyTorch's tuned version.
Fusion is a big win for memory-bound elementwise chains.
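For reference, the naive manual version is just the tanh-approximation formula written as separate PyTorch ops (the function name is illustrative):

import torch

def manual_gelu(x: torch.Tensor) -> torch.Tensor:
    # 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
    # Every multiply, add, pow, and tanh below launches its own elementwise kernel
    # and reads/writes the full tensor in global memory.
    return 0.5 * x * (1 + torch.tanh(0.7978845608 * (x + 0.044715 * x**3)))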
Writing a Custom CUDA Kernel
Split the work into a GPU kernel and a CPU wrapper.
Kernel function (runs on GPU):
- Mark it __global__.
- Compute i = blockIdx.x * blockDim.x + threadIdx.x.
- If i < num_elements, load input, compute GeLU, store output.
Wrapper function (runs on CPU, called from Python):
- Check input is on CUDA and contiguous.
- Allocate output with torch::empty_like(x).
- Choose a block size (e.g., 1024) and compute the number of blocks.
- Launch the kernel.
Debug tip: set CUDA_LAUNCH_BLOCKING=1 so errors surface at the launch that caused them instead of at a later, unrelated sync point.
Result: custom CUDA GeLU drops time from 8.1 ms to about 1.8 ms.
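Putting the kernel and wrapper together: a minimal sketch using torch.utils.cpp_extension.load_inline, assuming a contiguous float32 tensor (names are illustrative; the math is the same tanh approximation as above):

import torch
from torch.utils.cpp_extension import load_inline

cuda_source = """
__global__ void gelu_kernel(const float* in, float* out, int num_elements) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < num_elements) {
        float x = in[i];
        float a = 0.7978845608f * (x + 0.044715f * x * x * x);
        out[i] = 0.5f * x * (1.0f + tanhf(a));   // tanh-approximation GeLU
    }
}

torch::Tensor gelu(torch::Tensor x) {
    TORCH_CHECK(x.is_cuda() && x.is_contiguous(), "expected a contiguous CUDA tensor");
    auto y = torch::empty_like(x);
    int num_elements = x.numel();
    int block_size = 1024;
    int num_blocks = (num_elements + block_size - 1) / block_size;
    gelu_kernel<<<num_blocks, block_size>>>(
        x.data_ptr<float>(), y.data_ptr<float>(), num_elements);
    return y;
}
"""

module = load_inline(
    name="custom_gelu",
    cpp_sources="torch::Tensor gelu(torch::Tensor x);",
    cuda_sources=cuda_source,
    functions=["gelu"],
)

y = module.gelu(torch.randn(8_000_000, device="cuda"))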
Writing the Same Kernel in Triton
Triton lets you write GPU kernels in Python at the block level: you write the code for one program instance, which processes one block of data, and Triton handles memory coalescing and many other low-level details.
A Triton GeLU kernel (tanh approximation):
import triton
import triton.language as tl

@triton.jit
def gelu_kernel(x_ptr, y_ptr, num_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < num_elements
    x = tl.load(x_ptr + offsets, mask=mask, other=0.0)
    a = 0.7978845608 * (x + 0.044715 * x * x * x)   # sqrt(2/pi) * (x + 0.044715 x^3)
    y = x * tl.sigmoid(2 * a)                       # equals 0.5 * x * (1 + tanh(a))
    tl.store(y_ptr + offsets, y, mask=mask)
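A small launch wrapper (illustrative name) turns the kernel above into a drop-in function:

import torch
import triton

def triton_gelu(x: torch.Tensor, BLOCK_SIZE: int = 1024) -> torch.Tensor:
    assert x.is_cuda and x.is_contiguous()
    y = torch.empty_like(x)
    n = x.numel()
    grid = (triton.cdiv(n, BLOCK_SIZE),)      # one program instance per BLOCK_SIZE elements
    gelu_kernel[grid](x, y, n, BLOCK_SIZE=BLOCK_SIZE)
    return y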
Performance: Triton GeLU runs in about 1.84 ms, close to the CUDA C++ kernel. It's easier to write and change. It reads like vectorized Python math.
Softmax in Triton
Softmax is harder than GeLU because it needs a row-wise reduction: subtract max, exponentiate, sum, divide.
Simple design for small softmax:
- One block handles one row.
- Number of blocks equals number of rows.
- BLOCK_SIZE is at least the number of columns, rounded up to a power of two.
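A sketch of that design (names are illustrative; assumes a 2-D contiguous fp32 input whose rows fit in one block):

import torch
import triton
import triton.language as tl

@triton.jit
def softmax_kernel(x_ptr, y_ptr, n_cols, row_stride, BLOCK_SIZE: tl.constexpr):
    row = tl.program_id(0)                    # one program (block) per row
    cols = tl.arange(0, BLOCK_SIZE)
    mask = cols < n_cols                      # BLOCK_SIZE may exceed n_cols
    x = tl.load(x_ptr + row * row_stride + cols, mask=mask, other=float("-inf"))
    x = x - tl.max(x, axis=0)                 # subtract the row max for stability
    num = tl.exp(x)
    y = num / tl.sum(num, axis=0)
    tl.store(y_ptr + row * row_stride + cols, y, mask=mask)

def triton_softmax(x: torch.Tensor) -> torch.Tensor:
    n_rows, n_cols = x.shape
    y = torch.empty_like(x)
    BLOCK_SIZE = triton.next_power_of_2(n_cols)   # whole row in one block
    softmax_kernel[(n_rows,)](x, y, n_cols, x.stride(0), BLOCK_SIZE=BLOCK_SIZE)
    return y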
Performance on a large matrix:
- Naive manual PyTorch softmax: about 3.7 s.
- PyTorch built-in softmax: about 1.5 s.
- torch.compile softmax: about 1.3 s.
- Triton softmax: about 1.9 s.
torch.compile can beat both built-in and a simple Triton kernel by searching better tilings and fusion choices.
torch.compile and Auto-Fusion
torch.compile traces your code, builds a graph, fuses elementwise chains (often with Triton), and picks better kernels based on shapes and hardware.
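A minimal example, compiling the same kind of manual GeLU as above (the function name is illustrative):

import torch

@torch.compile   # traces the op chain and emits a fused Triton kernel for it
def gelu_tanh(x: torch.Tensor) -> torch.Tensor:
    return 0.5 * x * (1 + torch.tanh(0.7978845608 * (x + 0.044715 * x**3)))

x = torch.randn(16_000_000, device="cuda")
y = gelu_tanh(x)   # first call pays the compilation cost; later calls reuse the fused kernel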
GeLU times:
- Manual naive math: 8.1 ms.
- Built-in F.gelu: 1.1 ms.
- Custom CUDA or Triton: about 1.8 ms.
- torch.compile on manual GeLU: about 1.47 ms, using a generated fused Triton kernel.
Many wins come from compile and fusion without writing kernels yourself.
When to Hand-Optimize
Measure first. Benchmark with warmup and torch.cuda.synchronize().
Profile to find what dominates. Use the PyTorch profiler for operator breakdown. Use Nsight Systems to see overlap and sync stalls.
Use high-level tools before custom kernels. Prefer built-in ops and tuned libraries. Try torch.compile for fusion and kernel selection. Use Triton when you need a new kernel but want to stay in Python. Drop to C++ and CUDA only when the pattern is unusual or truly critical.
Think in memory traffic and arithmetic intensity. Fuse elementwise chains to cut global reads and writes. Reuse values in registers and shared memory. Avoid unnecessary CPU-GPU transfers, like frequent .item() in training loops.
Custom kernels are worth it for:
- New building blocks that libraries don't cover yet.
- Tight inner loops that profiling proves are hotspots.
- Cases where compilers miss important hardware behavior.
The test: profile first. If you can't point to the hotspot, you're guessing. Guessing wastes time.