Stochastic gradient descent and variants
In this series (18 parts)
- What is optimization and why ML needs it
- Convex sets and convex functions
- Optimality conditions: first order
- Optimality conditions: second order
- Line search methods
- Least squares: the closed-form solution
- Steepest descent (gradient descent)
- Newton's method for optimization
- Quasi-Newton methods: BFGS and L-BFGS
- Conjugate gradient methods
- Constrained optimization and Lagrangian duality
- KKT conditions
- Penalty and barrier methods
- Interior point methods
- The simplex method
- Frank-Wolfe method
- Optimization in dynamic programming and optimal control
- Stochastic gradient descent and variants
Prerequisites
You should understand gradient descent and how the gradient tells you the direction of steepest ascent. Everything in this article builds on that foundation.
Why not just use gradient descent?
Standard (batch) gradient descent computes the gradient using the entire dataset:

θ_{t+1} = θ_t − η ∇f(θ_t)

where f(θ) = (1/n) Σᵢ fᵢ(θ) is the average loss over the n training examples.
For a dataset with n = 10 million images, computing the full gradient requires a forward and backward pass through all 10 million examples. That is one gradient step. You might need hundreds of steps to converge. This is too slow.
Stochastic gradient descent (SGD) fixes this by using a single example (or a small batch) to estimate the gradient. Each step is noisy, but you get to take many more steps in the same wall-clock time.
Stochastic gradient descent
At each step, sample a random index i uniformly from {1, …, n} and update:

θ_{t+1} = θ_t − η ∇fᵢ(θ_t)
The stochastic gradient is an unbiased estimate of the true gradient:

𝔼ᵢ[∇fᵢ(θ)] = (1/n) Σᵢ ∇fᵢ(θ) = ∇f(θ)
On average, you move in the right direction. But any single step might be off. The noise is both a blessing (helps escape shallow local minima, provides implicit regularization) and a curse (prevents exact convergence without decreasing the learning rate).
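The unbiasedness claim is easy to check numerically. A minimal sketch on a synthetic least-squares problem (the data, shapes, and names here are illustrative, not from the text):

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy least-squares problem: f_i(w) = (x_i . w - y_i)^2, f(w) = average of the f_i.
X = rng.normal(size=(100, 3))
y = rng.normal(size=100)
w = rng.normal(size=3)

# Per-example gradients: grad f_i(w) = 2 (x_i . w - y_i) x_i
residuals = X @ w - y
per_example_grads = 2 * residuals[:, None] * X   # shape (100, 3)

# Full-batch gradient computed directly in matrix form.
full_grad = 2 * X.T @ (X @ w - y) / len(y)

# Averaging the per-example gradients recovers the full gradient exactly,
# so a uniformly sampled single-example gradient is unbiased.
print(np.allclose(per_example_grads.mean(axis=0), full_grad))  # True
```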
Learning rate schedule
For SGD to converge, the learning rates η_t must satisfy:

Σ_t η_t = ∞ and Σ_t η_t² < ∞

The classic choice is η_t = η₀ / t. In practice, people use step decay (halve the learning rate every few epochs), cosine annealing, or warmup followed by decay.
Mini-batch SGD
Pure SGD uses one sample. Mini-batch SGD uses a batch B of m samples:

θ_{t+1} = θ_t − (η/m) Σ_{i∈B} ∇fᵢ(θ_t)

The mini-batch gradient has lower variance than the single-sample gradient by a factor of m:

Var[(1/m) Σ_{i∈B} ∇fᵢ] = (1/m) Var[∇fᵢ]
Typical batch sizes: 32, 64, 128, 256. Larger batches give smoother gradients but less frequent updates. There is also a practical reason: GPU parallelism. A batch of 256 processes almost as fast as a batch of 1 on modern hardware, so you get 256x variance reduction nearly for free.
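The 1/m variance reduction is easy to verify empirically. A sketch, using a fixed synthetic population of per-example gradients (all values illustrative):

```python
import numpy as np

rng = np.random.default_rng(1)
n = 10_000
# Scalar toy problem: per-example gradients with mean 2.0 and variance 1.0.
g = rng.normal(loc=2.0, scale=1.0, size=n)

def minibatch_grad_variance(batch_size, trials=20_000):
    # Empirical variance of the mean of `batch_size` sampled gradients.
    idx = rng.integers(0, n, size=(trials, batch_size))
    return g[idx].mean(axis=1).var()

v1 = minibatch_grad_variance(1)
v64 = minibatch_grad_variance(64)
print(v1 / v64)  # close to 64: variance shrinks by the batch size
```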
Example 1: SGD vs batch GD on a quadratic
Problem: Minimize f(x) = x². Suppose we have n = 3 "data points" with individual losses:

f₁(x) = 2x², f₂(x) = x², f₃(x) = 0

Average: f(x) = (f₁(x) + f₂(x) + f₃(x)) / 3 = x².
True gradient: f′(x) = 2x.
Stochastic gradients: f₁′(x) = 4x, f₂′(x) = 2x, f₃′(x) = 0.
Batch GD with η = 0.3, starting at x₀ = 4:
Smooth, steady convergence: 4 → 1.6 → 0.64 → 0.256.
SGD with η = 0.3, starting at x₀ = 4. Random sample order: i = 1, 2, 2.
Noisy path: 4 → −0.8 → −0.32 → −0.128. The first step overshoots dramatically (the stochastic gradient 4x was 2x the true gradient 2x), but we still make progress. After three steps, |x| went from 4 to 0.128, which is actually better than batch GD's 0.256.
This is SGD in a nutshell: noisy individual steps, but fast overall progress because each step is cheap.
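A few lines of Python reproduce this kind of comparison. A sketch, assuming the toy split f₁(x) = 2x², f₂(x) = x², f₃(x) = 0 with η = 0.3 (these particular numbers are illustrative):

```python
# Batch GD vs SGD on f(x) = x^2 split into three per-example losses.
grads = [lambda x: 4 * x, lambda x: 2 * x, lambda x: 0.0]  # grads of f1, f2, f3
eta, x0 = 0.3, 4.0

# Batch GD: use the true gradient 2x (the average of the three).
x = x0
batch_path = [x]
for _ in range(3):
    x = x - eta * 2 * x
    batch_path.append(x)

# SGD: one sampled gradient per step (sample order fixed here: f1, f2, f2).
x = x0
sgd_path = [x]
for i in [0, 1, 1]:
    x = x - eta * grads[i](x)
    sgd_path.append(x)

print(batch_path)  # ~[4.0, 1.6, 0.64, 0.256]
print(sgd_path)    # ~[4.0, -0.8, -0.32, -0.128]
```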
Momentum
SGD oscillates, especially in narrow valleys where the gradient direction swings back and forth. Momentum smooths this out by maintaining a running average of past gradients:

v_{t+1} = β v_t + ∇f(θ_t)
θ_{t+1} = θ_t − η v_{t+1}

Here β is the momentum coefficient, typically β = 0.9. The velocity v accumulates gradient history: consistent gradient directions build up speed, while oscillating directions cancel out.
Think of a ball rolling down a hill. Without momentum, it changes direction instantly based on the local slope. With momentum, it has inertia. It rolls faster down consistent slopes and resists sudden direction changes.
Nesterov momentum
A slight but important twist: compute the gradient at the "lookahead" position θ_t − ηβv_t instead of the current position:

v_{t+1} = β v_t + ∇f(θ_t − η β v_t)
θ_{t+1} = θ_t − η v_{t+1}
Nesterov momentum has better theoretical convergence for convex problems and tends to work slightly better in practice.
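The lookahead update fits in a few lines. A sketch of one Nesterov step (the lr and beta values below are illustrative defaults, not prescribed by the text):

```python
def nesterov_step(theta, v, grad_fn, lr=0.05, beta=0.9):
    """One Nesterov step: evaluate the gradient at the lookahead point."""
    lookahead = theta - lr * beta * v
    v_new = beta * v + grad_fn(lookahead)
    theta_new = theta - lr * v_new
    return theta_new, v_new

# Minimize f(x) = x^2 (gradient 2x) from x = 10.
theta, v = 10.0, 0.0
for _ in range(100):
    theta, v = nesterov_step(theta, v, lambda x: 2 * x)
print(theta)  # very close to 0
```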
Example 2: Momentum vs plain SGD
Problem: Minimize f(x, y) = 10x² + y².
This is a narrow valley: the curvature in x is 10x the curvature in y. Gradient descent oscillates in x and converges slowly in y.
Starting point: (x₀, y₀) = (1, 10). Learning rate η = 0.05.
For simplicity, we use the full gradient (not stochastic) to isolate the effect of momentum.
∇f(x, y) = (20x, 2y).
Plain GD, 5 steps:

| Step | x | y | f(x, y) |
|---|---|---|---|
| 0 | 1.000 | 10.000 | 110.0 |
| 1 | 0.0 | 9.000 | 81.0 |
| 2 | 0.0 | 8.100 | 65.61 |
| 3 | 0.0 | 7.290 | 53.14 |
| 4 | 0.0 | 6.561 | 43.05 |
| 5 | 0.0 | 5.905 | 34.87 |

x converged in one step (lucky: η = 0.05 and curvature 20 give the update multiplier 1 − 20 · 0.05 = 0). y converges slowly: y_k = 10 · 0.9^k.
SGD with momentum (β = 0.9), same setup, update v ← βv + ∇f, (x, y) ← (x, y) − ηv:
Initialize v = (0, 0).
Step 0 to 1: g = (20, 20), v = (20, 20), (x, y) = (0, 9).
Step 1 to 2: g = (0, 18), v = (18, 36), (x, y) = (−0.9, 7.2).
Step 2 to 3: g = (−18, 14.4), v = (−1.8, 46.8), (x, y) = (−0.81, 4.86).
Step 3 to 4: g = (−16.2, 9.72), v = (−17.82, 51.84), (x, y) = (0.081, 2.268).
Step 4 to 5: g = (1.62, 4.54), v = (−14.42, 51.19), (x, y) = (0.802, −0.292).
With momentum, y progress: 10 → 9 → 7.2 → 4.86 → 2.27 → −0.29. Much faster convergence in y! But x oscillates a bit: 1 → 0 → −0.9 → −0.81 → 0.08 → 0.80. The momentum builds up speed in the consistent direction (y shrinking) while the oscillations in x partially cancel.
After 5 steps: plain GD has f = 34.87, momentum has f ≈ 6.52. Momentum wins by a wide margin.
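The arithmetic can be checked mechanically. A sketch comparing plain GD with heavy-ball momentum, assuming the quadratic f(x, y) = 10x² + y² with η = 0.05 and β = 0.9:

```python
import numpy as np

def f(p):  # the narrow-valley quadratic from the example
    return 10 * p[0] ** 2 + p[1] ** 2

def grad(p):
    return np.array([20 * p[0], 2 * p[1]])

eta, beta, steps = 0.05, 0.9, 5
p_gd = np.array([1.0, 10.0])    # plain gradient descent iterate
p_mom = np.array([1.0, 10.0])   # momentum iterate
v = np.zeros(2)                 # velocity

for _ in range(steps):
    p_gd = p_gd - eta * grad(p_gd)
    v = beta * v + grad(p_mom)
    p_mom = p_mom - eta * v

print(f(p_gd), f(p_mom))  # ~34.87 vs ~6.52
```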
Adaptive learning rates
The learning rate is the most important hyperparameter in SGD. Too large, you diverge. Too small, you converge painfully slowly. Different parameters might need different learning rates (a bias term vs a weight matrix, for instance).
Adaptive methods give each parameter its own effective learning rate based on the history of its gradients.
AdaGrad
Accumulate the sum of squared gradients for each parameter:

G_t = G_{t−1} + g_t²
θ_{t+1} = θ_t − η g_t / (√G_t + ε)

Parameters with large accumulated gradients get smaller learning rates. This is great for sparse features (NLP, recommendations) where some features are rare and need larger updates. The downside: G_t only grows, so the learning rate monotonically decreases and can become too small.
RMSProp
Fix AdaGrad's shrinking learning rate by using an exponential moving average instead of a sum:

v_t = ρ v_{t−1} + (1 − ρ) g_t²
θ_{t+1} = θ_t − η g_t / (√v_t + ε)

The decay rate ρ (typically 0.99) controls how far back the moving average looks. Old gradients are exponentially forgotten, preventing the learning rate from collapsing.
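The difference between the two accumulators is easy to see by tracking the effective learning rate under a constant gradient (a toy setting; the values are illustrative):

```python
import numpy as np

eta, eps, rho = 0.1, 1e-8, 0.99
G = 0.0   # AdaGrad: running SUM of squared gradients
v = 0.0   # RMSProp: exponential moving AVERAGE of squared gradients

adagrad_lr, rmsprop_lr = [], []
for t in range(1000):
    g = 1.0   # constant gradient magnitude for illustration
    G += g ** 2
    v = rho * v + (1 - rho) * g ** 2
    adagrad_lr.append(eta / (np.sqrt(G) + eps))
    rmsprop_lr.append(eta / (np.sqrt(v) + eps))

print(adagrad_lr[-1])  # keeps shrinking: ~0.1/sqrt(1000) ~ 0.0032
print(rmsprop_lr[-1])  # stabilizes near eta = 0.1 (v -> 1)
```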
Adam
Adam combines momentum and RMSProp:

m_t = β₁ m_{t−1} + (1 − β₁) g_t
v_t = β₂ v_{t−1} + (1 − β₂) g_t²

Bias correction (important in early steps, when m_t and v_t are biased toward zero):

m̂_t = m_t / (1 − β₁ᵗ)
v̂_t = v_t / (1 − β₂ᵗ)

Update:

θ_{t+1} = θ_t − η m̂_t / (√v̂_t + ε)

Default hyperparameters: η = 0.001, β₁ = 0.9, β₂ = 0.999, ε = 10⁻⁸.
Adam is the default optimizer for most deep learning tasks. It usually works well out of the box with the default parameters.
Example 3: Five steps of Adam vs plain SGD
Problem: Minimize f(x) = x². True gradient: f′(x) = 2x. Starting point: x₀ = 10.
We use the true gradient (not stochastic) to focus on the optimizer behavior.
Plain SGD with η = 0.1:

| Step | x | f′(x) = 2x |
|---|---|---|
| 0 | 10.0 | 20.0 |
| 1 | 8.0 | 16.0 |
| 2 | 6.4 | 12.8 |
| 3 | 5.12 | 10.24 |
| 4 | 4.096 | 8.192 |

After 5 steps: x₅ = 3.277, f(x₅) = 10.74.
Adam with η = 1.0 (Adam's default is 0.001, but we use 1.0 to see clear steps), β₁ = 0.9, β₂ = 0.999:

Starting from x = 10.0 with g = 20.0:
Step 1: m = 2.0, v = 0.4, m̂ = 20.0, v̂ = 400.0, step size 1.000, x = 9.000.
Step 2: m = 3.6, v = 0.724, m̂ = 18.95, v̂ = 362.0, step size 0.996, x = 8.004.
Step 3: m = 4.84, v = 0.979, m̂ = 17.86, v̂ = 326.7, step size 0.988, x = 7.016.
Step 4: m = 5.76, v = 1.175, m̂ = 16.75, v̂ = 294.2, step size 0.976, x = 6.039.
Step 5: m = 6.39, v = 1.320, m̂ = 15.61, v̂ = 264.5, step size 0.960, x = 5.080.
Summary after 5 steps:

| Method | x₅ | f(x₅) |
|---|---|---|
| Plain SGD (η = 0.1) | 3.277 | 10.74 |
| Adam (η = 1.0) | 5.080 | 25.80 |
In this simple case, plain SGD actually wins because it has a well-tuned learning rate for this specific problem. Adam takes steps of roughly constant size (close to the learning rate η) regardless of gradient magnitude, which is a feature for complex loss landscapes but a disadvantage on simple quadratics. Adam shines on problems with:
- Very different scales across parameters
- Noisy gradients
- Saddle points and flat regions
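The five-step comparison above can be reproduced in a short script (the update follows the standard Adam equations; the final x ≈ 5.08 agrees with the table up to rounding):

```python
# Five steps of Adam (lr = 1.0) vs plain SGD (lr = 0.1) on f(x) = x^2, x0 = 10.
beta1, beta2, eps = 0.9, 0.999, 1e-8

x_sgd = 10.0
for _ in range(5):
    x_sgd -= 0.1 * 2 * x_sgd   # x <- 0.8 x each step

x_adam, m, v = 10.0, 0.0, 0.0
for t in range(1, 6):
    g = 2 * x_adam
    m = beta1 * m + (1 - beta1) * g
    v = beta2 * v + (1 - beta2) * g ** 2
    m_hat = m / (1 - beta1 ** t)          # bias correction
    v_hat = v / (1 - beta2 ** t)
    x_adam -= 1.0 * m_hat / (v_hat ** 0.5 + eps)  # each step is close to lr

print(x_sgd, x_adam)  # ~3.2768 and ~5.08
```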
Learning rate schedules
Even with adaptive methods, the learning rate schedule matters. Common choices:
Step decay
Drop the learning rate by a factor γ (e.g., 0.1) every k epochs. Simple and effective. Used in most ResNet training recipes.
Cosine annealing
η_t = η_min + ½ (η_max − η_min)(1 + cos(π t / T))

The learning rate smoothly decreases from η_max to η_min over T steps. Popular in modern training (vision transformers, language models).
Warmup
Start with a very small learning rate and linearly increase it to the target over the first few thousand steps:

η_t = η_target · t / T_warmup for t ≤ T_warmup
Warmup stabilizes training when the initial gradients are large and unreliable (common with random initialization in transformers).
One-cycle policy
Increase the learning rate from small to large over the first half of training, then decrease it back down. This is surprisingly effective and was popularized by Leslie Smith.
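These schedules are small pure functions of the step count. A sketch (the hyperparameter values are illustrative):

```python
import math

def step_decay(t, lr0=0.1, gamma=0.1, every=30):
    """Multiply the learning rate by gamma every `every` epochs."""
    return lr0 * gamma ** (t // every)

def cosine_annealing(t, T, lr_max=0.1, lr_min=0.0):
    """Smooth decay from lr_max to lr_min over T steps."""
    return lr_min + 0.5 * (lr_max - lr_min) * (1 + math.cos(math.pi * t / T))

def warmup_then_cosine(t, warmup=1000, T=100_000, lr_max=0.1):
    """Linear warmup for the first `warmup` steps, cosine decay after."""
    if t < warmup:
        return lr_max * t / warmup
    return cosine_annealing(t - warmup, T - warmup, lr_max=lr_max)

print(step_decay(0), step_decay(30), step_decay(60))        # 0.1, 0.01, 0.001
print(cosine_annealing(0, 100), cosine_annealing(100, 100)) # 0.1, ~0.0
```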
Which optimizer to use
| Scenario | Recommendation |
|---|---|
| Starting a new project | Adam with defaults (η = 0.001, β₁ = 0.9, β₂ = 0.999) |
| Training CNNs (image classification) | SGD + momentum (β = 0.9) + step decay schedule |
| Training transformers | Adam or AdamW with warmup + cosine schedule |
| Sparse features (NLP, recommendations) | Adam or AdaGrad |
| Need best final accuracy | SGD + momentum (often beats Adam at convergence with proper tuning) |
| Quick prototyping | Adam (less tuning required) |
The general pattern: Adam converges faster initially, but well-tuned SGD with momentum often reaches a better final solution. This has been observed repeatedly in computer vision, though Adam with weight decay (AdamW) has narrowed the gap.
*Figure: Convergence comparison of SGD variants. Adam and RMSProp adapt their learning rates and converge faster than plain SGD. Momentum smooths out oscillations for faster convergence than vanilla SGD.*

*Figure: Optimization paths on f(x, y) = x² + 10y². SGD oscillates in the high-curvature y direction. Momentum damps the oscillations. Adam takes a more direct path to the minimum.*
AdamW: weight decay done right
Standard L2 regularization adds (λ/2)‖θ‖² to the loss. In Adam, this interacts poorly with the adaptive learning rate: the regularization gradient λθ gets scaled by 1/(√v̂_t + ε), weakening it for parameters with large gradients.
AdamW decouples weight decay from the gradient:

θ_{t+1} = θ_t − η (m̂_t / (√v̂_t + ε) + λ θ_t)

The λθ_t term is applied directly, not through the adaptive mechanism. This is now the standard optimizer for training language models.
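The decoupling is a one-line difference in code. A sketch of a single AdamW step (hyperparameter values are illustrative):

```python
import numpy as np

def adamw_step(theta, m, v, g, t, lr=1e-3, b1=0.9, b2=0.999, eps=1e-8, wd=0.01):
    """One AdamW step: weight decay acts directly on theta,
    not through the adaptively scaled gradient."""
    m = b1 * m + (1 - b1) * g
    v = b2 * v + (1 - b2) * g ** 2
    m_hat = m / (1 - b1 ** t)
    v_hat = v / (1 - b2 ** t)
    theta = theta - lr * (m_hat / (np.sqrt(v_hat) + eps) + wd * theta)
    return theta, m, v

theta, m, v = np.array([1.0, -2.0]), np.zeros(2), np.zeros(2)
theta, m, v = adamw_step(theta, m, v, g=np.zeros(2), t=1)
# With a zero loss gradient, Adam + L2 would push the decay term through the
# 1/(sqrt(v_hat) + eps) scaling, so its strength would depend on gradient
# history; AdamW shrinks the weights by exactly lr * wd * theta.
print(theta)
```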
Gradient noise and generalization
A surprising finding in deep learning: SGD with small batches often generalizes better than full-batch gradient descent, even when both reach the same training loss. The noise in SGD acts as an implicit regularizer, pushing the optimizer toward “flat” minima that generalize well.
This connects to the bias-variance tradeoff:
- Large batch, low noise: converges to sharp minima. Low training loss, potentially high test loss.
- Small batch, high noise: finds flat minima. Slightly higher training loss, better test loss.
The optimal batch size depends on the problem. Too small wastes GPU parallelism. Too large hurts generalization. Most practitioners use batch sizes of 32 to 512 and tune from there.
Practical tips
- Start with Adam at its default learning rate (0.001). If results are not good enough, try SGD + momentum with a carefully tuned schedule.
- Use gradient clipping (clip the gradient norm to a maximum value, say 1.0) to prevent exploding gradients, especially in RNNs and transformers.
- Monitor the gradient norm during training. If it spikes, your learning rate is too high or there is a data issue.
- Learning rate finder: sweep the learning rate from very small to large (say 10⁻⁷ up to 1) over one epoch and plot the loss. The best learning rate is usually about 10x smaller than the one where the loss starts increasing.
- Weight decay (on the order of 10⁻⁴ to 10⁻²) usually helps. Use AdamW, not Adam + L2 regularization.
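The gradient clipping mentioned above can be sketched in a few lines (global-norm clipping, with the usual illustrative max norm of 1.0):

```python
import numpy as np

def clip_grad_norm(grads, max_norm=1.0):
    """Scale all gradients down if their global L2 norm exceeds max_norm."""
    total_norm = np.sqrt(sum(np.sum(g ** 2) for g in grads))
    if total_norm > max_norm:
        scale = max_norm / total_norm
        grads = [g * scale for g in grads]
    return grads, total_norm

grads = [np.array([3.0, 4.0])]            # norm 5.0
clipped, norm = clip_grad_norm(grads, max_norm=1.0)
print(norm, np.linalg.norm(clipped[0]))   # 5.0 -> clipped norm 1.0
```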
Python: implementing the optimizers
```python
import numpy as np

class SGD:
    def __init__(self, lr=0.01):
        self.lr = lr

    def step(self, params, grads):
        return params - self.lr * grads

class SGDMomentum:
    def __init__(self, lr=0.01, beta=0.9):
        self.lr = lr
        self.beta = beta
        self.v = None

    def step(self, params, grads):
        if self.v is None:
            self.v = np.zeros_like(params)
        self.v = self.beta * self.v + grads
        return params - self.lr * self.v

class Adam:
    def __init__(self, lr=0.001, beta1=0.9, beta2=0.999, eps=1e-8):
        self.lr = lr
        self.beta1 = beta1
        self.beta2 = beta2
        self.eps = eps
        self.m = None
        self.v = None
        self.t = 0

    def step(self, params, grads):
        if self.m is None:
            self.m = np.zeros_like(params)
            self.v = np.zeros_like(params)
        self.t += 1
        self.m = self.beta1 * self.m + (1 - self.beta1) * grads
        self.v = self.beta2 * self.v + (1 - self.beta2) * grads**2
        m_hat = self.m / (1 - self.beta1**self.t)
        v_hat = self.v / (1 - self.beta2**self.t)
        return params - self.lr * m_hat / (np.sqrt(v_hat) + self.eps)

# Compare on f(x) = x^2, starting at x = 10
for name, opt in [("SGD", SGD(0.1)),
                  ("Momentum", SGDMomentum(0.05, 0.9)),
                  ("Adam", Adam(1.0))]:
    x = 10.0
    trajectory = [x]
    for _ in range(20):
        g = 2 * x  # gradient of x^2
        x = opt.step(x, g)
        trajectory.append(float(x))
    print(f"{name:10s}: final x = {x:.4f}, f(x) = {x**2:.4f}")
```
Summary table
| Optimizer | Update rule (simplified) | Memory | Good for |
|---|---|---|---|
| SGD | θ ← θ − ηg | — | Simple problems, theory |
| Momentum | v ← βv + g; θ ← θ − ηv | 1x params | CNNs, well-tuned training |
| AdaGrad | G ← G + g²; θ ← θ − ηg/(√G + ε) | 1x params | Sparse features |
| RMSProp | v ← ρv + (1−ρ)g²; θ ← θ − ηg/(√v + ε) | 1x params | RNNs, non-stationary |
| Adam | EMAs of g and g² + bias correction | 2x params | Default choice for DL |
| AdamW | Adam + decoupled weight decay | 2x params | Transformers, LLMs |
What comes next
With SGD and its variants, you have the optimizers that power all of modern machine learning. The next step is to see how these tools get applied: from loss functions and regularization to training pipelines for real models. Head over to what is machine learning to start the ML series, where optimization meets data.