CS336 Notes: Lecture 11 - Scaling Laws 2
"20 tokens per parameter" is not a law. It's a lower bound.
This lecture from Stanford CS336 covers how real teams use scaling laws. The core lesson: the ideas from Chinchilla survive contact with production, but the constants change with architecture, data quality, and how you handle hyperparameters.
What We Want From Scaling
When you scale a model, you want three things:
- Hyperparameters that don't need full retuning at each size.
- A sensible tradeoff between model size and training tokens.
- Performance you can predict from compute before you spend it.
After Chinchilla and ChatGPT, most big labs stopped publishing their full scaling recipes. We learn scaling practice from groups that stayed open: Cerebras-GPT, MiniCPM, and DeepSeek.
Cerebras-GPT: muP in Action
Cerebras-GPT spans 0.1B to 13B parameters and follows Chinchilla-style compute-optimal training.
Their focus: parameterization. They compare standard parameterization (SP) to muP.
Core result: on a log-log plot of test loss vs compute, SP runs bounce around the predicted scaling line across sizes because the optimal learning rate shifts with width. With muP, performance tracks the predicted line smoothly.
Why this matters: If learning rates stay stable as you widen, you can do heavy hyperparameter search on tiny models, then scale up with confidence.
In SP: Weights initialize with the standard deviation ~ 1/sqrt(fanin). You use one global learning rate for the whole model.
In muP: The big change is learning-rate scaling by layer. For Adam-like optimizers, per-layer learning rates scale down with width (or fanin).
The workflow: tune hard on very small models. Use muP to scale width up while keeping hyperparameters mostly fixed.
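A toy illustration of that rule for Adam-like optimizers (the function name and the relative-to-base-width form are my own framing; real muP implementations also handle embeddings, output layers, and attention separately):

```python
# Tune at a small base width, then rescale the hidden-layer LR as width grows.
def hidden_layer_lr(base_lr: float, base_width: int, width: int) -> float:
    return base_lr * base_width / width   # lr shrinks like 1/fanin as width grows

# Swept and tuned at width 256, reused at width 4096 without a new sweep:
print(hidden_layer_lr(6e-3, 256, 4096))  # 3.75e-4
```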
MiniCPM: WSD and High Token Ratios
MiniCPM trains small models (1.2B to 2.4B) with serious compute to make them unusually strong for their size.
How they use muP: Embeddings get a fixed scale. Residual blocks use depth-related scaling. Per-layer learning rates scale with width. Their best hyperparameters look consistent with Cerebras, with differences around a constant factor.
Their workflow: Use very small models (9M and 30M) to search for aspect ratio, learning rate, and core settings. Fix those choices and scale to mid-sized models.
Critical Batch Size
They study the point where increasing batch size stops giving meaningful gains.
For each model size, they train over a range of batch sizes and track final loss. The critical batch size grows as you push to lower loss. The relationship is close to a power law.
Once you know your target loss, you can read off a principled batch size instead of guessing.
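A minimal sketch of that read-off, with made-up numbers rather than MiniCPM's measurements: fit the loss-to-critical-batch-size relationship as a power law in log space, then evaluate it at the target loss.

```python
import numpy as np

loss  = np.array([3.2, 2.9, 2.6, 2.4])          # final loss from sweeps
bcrit = np.array([0.5e6, 1.0e6, 2.2e6, 3.9e6])  # critical batch size (tokens)

slope, intercept = np.polyfit(np.log(loss), np.log(bcrit), 1)  # log-log line
target_loss = 2.2
print(np.exp(intercept) * target_loss ** slope)  # suggested batch size
```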
Learning Rate Stability
They test muP's claim: does optimal learning rate stay stable across width?
They sweep learning rates at multiple model sizes. They see a wide "good" region and a minimum at nearly the same learning rate across sizes. muP makes learning rates width-invariant in practice.
WSD Learning Rate Schedule
WSD (Warmup-Stable-Decay) is a practical trick that makes Chinchilla-style analysis cheap.
The problem: Chinchilla-style work needs many runs: multiple model sizes and token budgets per size. With cosine schedules, the learning-rate curve depends on total steps. You can't treat a checkpoint from a long run as if it were the end of a shorter run. That makes scaling studies expensive.
WSD fixes this with three pieces:
- Warmup: ramp up to full LR.
- Stable: hold LR constant on a long plateau.
- Decay: drop LR quickly to a small value.
To simulate a shorter run, pick an earlier checkpoint from the plateau and apply the same decay from that point. Each rewind gives a clean schedule for a different effective token budget without retraining from scratch.
The real win is efficiency for scaling experiments.
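A small sketch of such a schedule; the linear decay shape and the argument names are my own choices, not MiniCPM's exact settings:

```python
def wsd_lr(step, peak_lr, warmup_steps, decay_start, decay_steps, min_lr=0.0):
    if step < warmup_steps:                        # warmup: ramp to peak
        return peak_lr * step / warmup_steps
    if step < decay_start:                         # stable: hold the plateau
        return peak_lr
    frac = min(1.0, (step - decay_start) / decay_steps)
    return peak_lr + frac * (min_lr - peak_lr)     # decay: drop fast to min_lr

# To emulate a shorter run, branch from a plateau checkpoint at step S and
# rerun only the decay phase with decay_start=S; the plateau is shared.
```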
Their Chinchilla Analysis
Using WSD, they run a Chinchilla-style compute-optimal analysis with two estimation methods: the lower-envelope method and two-variable fitting.
Their headline result: optimal tokens-per-parameter ratio around 192. Far above 20.
The takeaway: "20 tokens per parameter" depends on architecture, data quality, and optimization. In modern settings, expect higher.
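For intuition on what the ratio implies, here is a back-of-envelope sizing under the standard C ≈ 6ND approximation (the 1e21 FLOPs budget is just an example):

```python
# N is parameters, D is tokens, and C = 6*N*D with D = ratio * N.
def compute_optimal(compute_flops: float, tokens_per_param: float):
    n = (compute_flops / (6 * tokens_per_param)) ** 0.5
    return n, tokens_per_param * n               # (params, tokens)

print(compute_optimal(1e21, 20))   # Chinchilla ratio: ~2.9e9 params, ~5.8e10 tokens
print(compute_optimal(1e21, 192))  # MiniCPM's ratio:  ~9.3e8 params, ~1.8e11 tokens
```

Same compute, but the higher ratio pushes the optimum toward a smaller model trained on far more tokens.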
DeepSeek: Scaling Hyperparameters Directly
DeepSeek doesn't use muP. They treat learning rate and batch size as things you scale with compute by fitting empirical laws.
Workflow:
- Pick smaller model sizes.
- Sweep learning rates and batch sizes.
- Record loss over a grid, find broad minimum.
- Repeat at larger compute budgets.
- Record best settings at each level.
- Fit trends and extrapolate to the target run.
Batch size follows a clean scaling trend. Learning rate is noisier but still fits.
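A sketch of that fit-and-extrapolate step with invented sweep results (not DeepSeek's numbers): fit each quantity as a power law in compute, then evaluate at the target budget.

```python
import numpy as np

compute = np.array([1e17, 3e17, 1e18, 3e18, 1e19])           # FLOPs of small runs
best_bs = np.array([3e5, 4.5e5, 7e5, 1.0e6, 1.5e6])           # best batch size (tokens)
best_lr = np.array([4.2e-4, 3.8e-4, 3.4e-4, 3.1e-4, 2.8e-4])  # best learning rate

def power_law_fit(x, y):
    slope, intercept = np.polyfit(np.log(x), np.log(y), 1)
    return lambda c: np.exp(intercept) * c ** slope

bs_at, lr_at = power_law_fit(compute, best_bs), power_law_fit(compute, best_lr)
print(bs_at(1e22), lr_at(1e22))    # settings extrapolated to the target run
```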
They also use a WSD-like schedule for the same reason: reusability across data budgets without rerunning everything.
Their Chinchilla Analysis
They vary model size and tokens at multiple compute budgets. They fit quadratics in tokens at fixed compute and identify optima.
They see clean, smooth isoFLOP curves. Their observed results match predictions closely. The payoff: reduced surprise.
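A minimal version of that fit at one compute budget, with illustrative numbers: fit loss against log(tokens) with a quadratic and take the parabola's vertex as the compute-optimal token count.

```python
import numpy as np

tokens = np.array([2e10, 4e10, 8e10, 1.6e11, 3.2e11])  # runs at fixed FLOPs
loss   = np.array([2.41, 2.33, 2.30, 2.32, 2.40])

a, b, _ = np.polyfit(np.log(tokens), loss, 2)
print(np.exp(-b / (2 * a)))        # compute-optimal token count at this budget
```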
Newer Results
LLaMA 3: Meta's isoFLOPs analysis finds optimal ratio around 39:1, roughly double Chinchilla's 20:1.
Hunyuan (MoE): For MoE, "active parameters" matter more than total. IsoFLOPs using active parameters finds about 96 tokens per active parameter. The method is stable; the numbers depend on architecture.
MiniMax-01 (long context, linear attention): Compares standard attention to linear-time attention using scaling curves. Efficient attention variants scale similarly per compute.
Why muP Exists
muP's goal is narrow: keep learning rate stable as width grows. If that works, you tune on small models and scale width without retuning.
Two stability conditions:
Condition 1 (activation stability at init): As width grows, each activation coordinate stays order-one, so the layer's activation L2 norm scales like sqrt(width).
Condition 2 (update stability after one step): After one optimizer step, activation change per coordinate stays order-one.
These conditions define "width-invariant training."
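Condition 1 is easy to check numerically. A quick demo with synthetic Gaussian inputs and variance ~ 1/fanin weights (not from any of the papers): per-coordinate activations stay order-one and the layer norm grows like sqrt(width).

```python
import numpy as np

rng = np.random.default_rng(0)
for width in [256, 1024, 4096]:
    x = rng.standard_normal(width)                             # O(1) coordinates
    W = rng.standard_normal((width, width)) / np.sqrt(width)   # var = 1/fanin
    y = W @ x
    print(width, round(np.abs(y).mean(), 2), round(np.linalg.norm(y) / np.sqrt(width), 2))
```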
muP Derivation
Step 1 (activation stability gives init rule): To keep activations the same scale as width changes, initialize weights with variance ~ 1/fanin. Standard deviation ~ 1/sqrt(fanin). This is Xavier/Kaiming-style. Most practice already matches it.
Step 2 (update stability gives learning rate rule): For Adam-like optimizers, per-layer learning rates should scale like 1/fanin. When fanin is model width, that's "learning rate scales like 1/width" at the layer level.
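The reasoning behind that rule, sketched with constants dropped: each output coordinate sums over fanin inputs, an Adam-like optimizer moves every weight by roughly the learning rate, and the moves are gradient-aligned, so they add up coherently.

```latex
% Sketch, constants ignored; Adam-like updates satisfy |\Delta W_{ij}| \approx \eta.
y_i = \sum_{j=1}^{f} W_{ij}\,x_j, \quad x_j = O(1), \quad f = \text{fan-in}
\;\;\Longrightarrow\;\;
\Delta y_i = \sum_{j=1}^{f} \Delta W_{ij}\,x_j \approx \eta\, f
\;\;\Longrightarrow\;\;
\eta \propto \tfrac{1}{f}\ \text{ to keep } \Delta y_i = O(1).
```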
This is the practical core of muP. Initialization is mostly already fine. Learning-rate scaling makes transfer work.
Standard vs muP
Standard parameterization: Init is often close to muP's needs. But one global learning rate means wider models need smaller learning rates. The optimum drifts.
muP: Keep init stable. Scale per-layer learning rates with fanin for Adam-like optimizers. The best learning rate stays put as you widen.
Practical muP Recipe
For non-embedding layers:
- Initialize weights with variance ~ 1/fanin.
- Scale each layer's AdamW learning rate ~ 1/fanin.
Treat embeddings separately with a fixed scale.
Transformers have a subtlety: standard attention uses 1/sqrt(d) in logits. Some muP treatments prefer 1/d for tighter update stability.
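Putting the recipe together, a minimal PyTorch sketch (my own helper, not the reference mup package, which also covers output layers and the attention scaling above):

```python
import torch
import torch.nn as nn

def mup_param_groups(model: nn.Module, base_lr: float):
    """Init matrix-like weights with var ~ 1/fanin and scale their AdamW
    learning rate ~ 1/fanin; keep embeddings, biases, and norms at base_lr."""
    fixed, groups = [], []
    for name, p in model.named_parameters():
        if p.ndim < 2 or "embed" in name:      # embeddings / biases / norm gains
            fixed.append(p)
        else:
            fan_in = p.shape[-1]               # nn.Linear weight is (out, in)
            nn.init.normal_(p, mean=0.0, std=fan_in ** -0.5)
            groups.append({"params": [p], "lr": base_lr / fan_in})
    return [{"params": fixed, "lr": base_lr}] + groups

model = nn.Sequential(nn.Linear(1024, 4096), nn.GELU(), nn.Linear(4096, 1024))
opt = torch.optim.AdamW(mup_param_groups(model, base_lr=1e-2))
```

Whether the 1/fanin factor is absolute, as here, or relative to a tuned base width (base_lr * base_fanin / fanin) is a convention; the two differ only in how base_lr is defined.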
Empirical Evidence
Large studies test muP by tuning learning rate on small width, then training larger widths with the same rate.
They find: With muP, loss-vs-learning-rate curves line up across widths. The best learning rate transfers. With standard parameterization, the same rate often becomes too large at higher widths.
Robustness: Changing MLP nonlinearities preserves transfer. Changing batch sizes doesn't break it. Small init tweaks matter less than learning-rate scaling.
Failure modes: Extra learnable gains in wrong places can break invariance. Sign-based optimizers like Lion don't match muP's assumptions. Strong weight decay can spoil transfer.
A hero run up to ~10B parameters trained successfully with a learning rate chosen from small-scale muP sweeps.
A Practical Recipe
1. Fix the architecture's aspect ratio. Keep ratios like d_model : d_mlp : heads consistent as you scale.
2. Decide how you'll handle hyperparameters.
Option A (muP): Implement muP init and per-layer learning-rate scaling for AdamW. Tune on small models. Scale width with minimal retuning.
Option B (DeepSeek-style): Run small-scale sweeps. Fit scaling trends for batch size and learning rate vs compute. Extrapolate.
3. Use isoFLOPs analysis. At multiple compute budgets, vary parameters and tokens. Fit curves. Pick the token budget that minimizes loss.
4. Use WSD to make experiments affordable. Warmup, long plateau, fast decay. Reuse plateau checkpoints to emulate many token budgets per model size.
5. Treat 20 tokens per parameter as a floor, not a rule. Expect higher optimal ratios. Re-measure when architecture or data changes materially.
The Core Lesson
Across Cerebras-GPT, MiniCPM, DeepSeek, LLaMA 3, Hunyuan, and MiniMax, the same pattern shows up.
IsoFLOPs analysis is reliable and worth doing. muP makes width scaling easier by stabilizing learning rates. WSD makes scaling experiments cheaper.
Scaling is no longer "make it bigger and hope." It's a loop: predict, train, measure, adjust. Spend compute where the curves say it matters.