Chapter 07: Optimization — Initialization, SGD, Adam, AdamW
Training a neural network is an optimisation problem: find the weights that minimise the loss. The choice of optimiser and weight-initialisation scheme has a dramatic effect on how quickly (and whether) a network converges.
Weight initialisation matters because activations and gradients flow through long chains of matrix multiplications. If the weights start too large, activations explode; too small, and they vanish. Xavier (Glorot) initialisation keeps the activation variance roughly constant through linear layers, assuming symmetric activations such as tanh or sigmoid; Kaiming (He) initialisation accounts for the fact that ReLU zeroes roughly half of its inputs, which halves the variance, so the weight variance must be doubled to compensate.
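Both schemes ship with `torch.nn.init`; before rebuilding them from scratch in section 1, here is a minimal sketch of how they are typically applied to a layer (the gain and nonlinearity arguments shown are the conventional choices, not requirements):

```python
import torch.nn as nn

layer_tanh = nn.Linear(512, 512)
nn.init.xavier_uniform_(layer_tanh.weight, gain=nn.init.calculate_gain("tanh"))

layer_relu = nn.Linear(512, 512)
nn.init.kaiming_normal_(layer_relu.weight, nonlinearity="relu")

# Biases are usually just zeroed.
nn.init.zeros_(layer_tanh.bias)
nn.init.zeros_(layer_relu.bias)
```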
Adam (Adaptive Moment Estimation) maintains per-parameter estimates of the first and second moments of the gradient, effectively giving each weight its own adaptive learning rate. This makes it much less sensitive to the global learning rate than vanilla SGD. AdamW (Adam with decoupled weight decay) fixes a subtle flaw in how Adam handles L2 regularisation: in the original Adam, the decay term is added to the gradient and therefore gets rescaled by the per-parameter adaptive learning rates, so weights with a large gradient history are effectively decayed less; AdamW applies the decay directly to the weights, independent of the moment estimates.
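In practice this usually means reaching for the built-in `torch.optim.AdamW`. A minimal sketch of a typical setup (the hyperparameter values are illustrative, and exempting biases and normalisation parameters from decay is a common convention rather than a requirement):

```python
import torch
import torch.nn as nn

net = nn.Sequential(nn.Linear(512, 512), nn.LayerNorm(512), nn.Linear(512, 512))

# Apply weight decay to weight matrices only, not to biases or norm parameters.
decay, no_decay = [], []
for name, p in net.named_parameters():
    (decay if p.ndim >= 2 else no_decay).append(p)

optimizer = torch.optim.AdamW(
    [{"params": decay, "weight_decay": 0.1},
     {"params": no_decay, "weight_decay": 0.0}],
    lr=3e-4, betas=(0.9, 0.95),
)
```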
Learning rate scheduling is equally important. Cosine annealing decays the LR from its peak value to near zero following a cosine curve, often with a short warmup phase at the start to prevent unstable early updates.
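Section 3 implements this schedule from scratch; roughly the same shape can also be assembled from PyTorch's built-in schedulers, as in this sketch (step counts and learning rates are illustrative):

```python
import torch
import torch.nn as nn
from torch.optim.lr_scheduler import LinearLR, CosineAnnealingLR, SequentialLR

opt_demo = torch.optim.AdamW(nn.Linear(8, 8).parameters(), lr=3e-4)

warmup = LinearLR(opt_demo, start_factor=0.01, total_iters=200)  # linear warmup to 3e-4
cosine = CosineAnnealingLR(opt_demo, T_max=2800, eta_min=3e-5)   # cosine decay to 3e-5
scheduler = SequentialLR(opt_demo, [warmup, cosine], milestones=[200])

# scheduler.step() is then called once per optimisation step, after opt_demo.step().
```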
1. Weight Initialisation Schemes
```python
import torch
import torch.nn as nn
import math
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt


def xavier_uniform_(tensor: torch.Tensor, gain: float = 1.0):
    """
    Xavier uniform: U(-a, a) where a = gain * sqrt(6 / (fan_in + fan_out)).
    Keeps activation variance constant for symmetric activations (tanh, sigmoid).
    """
    fan_in, fan_out = nn.init._calculate_fan_in_and_fan_out(tensor)
    a = gain * math.sqrt(6.0 / (fan_in + fan_out))
    return tensor.uniform_(-a, a)


def kaiming_normal_(tensor: torch.Tensor, mode: str = "fan_in",
                    nonlinearity: str = "relu"):
    """
    Kaiming (He) normal: N(0, std²) where std = sqrt(2 / fan).
    Accounts for the variance-halving effect of ReLU.
    """
    fan_in, fan_out = nn.init._calculate_fan_in_and_fan_out(tensor)
    fan = fan_in if mode == "fan_in" else fan_out
    # Gain for relu = sqrt(2), for linear = 1
    gain = math.sqrt(2.0) if nonlinearity == "relu" else 1.0
    std = gain / math.sqrt(fan)
    return tensor.normal_(0, std)


# Compare activations at initialisation with different schemes
torch.manual_seed(42)
n_layers = 10
n_units = 512


def trace_activations(init_fn, activation):
    x = torch.randn(256, n_units)
    stds = [x.std().item()]
    for _ in range(n_layers):
        W = torch.empty(n_units, n_units)
        init_fn(W)
        x = activation(x @ W)
        stds.append(x.std().item())
    return stds


relu = torch.relu
tanh = torch.tanh

results = {
    "Random N(0,1) + ReLU": trace_activations(lambda W: W.normal_(0, 1), relu),
    "Xavier Uniform + Tanh": trace_activations(xavier_uniform_, tanh),
    "Kaiming Normal + ReLU": trace_activations(kaiming_normal_, relu),
    "Std=0.01 + ReLU": trace_activations(lambda W: W.normal_(0, 0.01), relu),
}

plt.figure(figsize=(10, 5))
for label, stds in results.items():
    plt.plot(stds, marker="o", label=label)
plt.xlabel("Layer depth")
plt.ylabel("Activation std")
plt.title("Activation variance under different initialisations")
plt.legend()
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.savefig("data/ch07_init.png", dpi=100)
print("Saved → data/ch07_init.png")
```
2. Optimizers from Scratch
```python
class SGD:
    def __init__(self, params, lr: float = 0.01):
        self.params = list(params)
        self.lr = lr

    def step(self):
        for p in self.params:
            p.data -= self.lr * p.grad

    def zero_grad(self):
        for p in self.params:
            if p.grad is not None:
                p.grad.zero_()


class SGDMomentum:
    def __init__(self, params, lr: float = 0.01, momentum: float = 0.9):
        self.params = list(params)
        self.lr = lr
        self.momentum = momentum
        # Velocity buffers initialised to zero
        self.velocity = [torch.zeros_like(p) for p in self.params]

    def step(self):
        for p, v in zip(self.params, self.velocity):
            v.mul_(self.momentum).add_(p.grad)  # v = β*v + g
            p.data.sub_(self.lr * v)            # θ -= lr * v

    def zero_grad(self):
        for p in self.params:
            if p.grad is not None:
                p.grad.zero_()


class Adam:
    """Adam: Adaptive Moment Estimation (Kingma & Ba, 2015)."""

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8):
        self.params = list(params)
        self.lr = lr
        self.b1, self.b2 = betas
        self.eps = eps
        self.t = 0
        self.m = [torch.zeros_like(p) for p in self.params]  # 1st moment
        self.v = [torch.zeros_like(p) for p in self.params]  # 2nd moment

    def step(self):
        self.t += 1
        for p, m, v in zip(self.params, self.m, self.v):
            g = p.grad
            m.mul_(self.b1).add_((1 - self.b1) * g)      # m = β1*m + (1-β1)*g
            v.mul_(self.b2).add_((1 - self.b2) * g * g)  # v = β2*v + (1-β2)*g²
            # Bias correction
            m_hat = m / (1 - self.b1 ** self.t)
            v_hat = v / (1 - self.b2 ** self.t)
            p.data -= self.lr * m_hat / (v_hat.sqrt() + self.eps)

    def zero_grad(self):
        for p in self.params:
            if p.grad is not None:
                p.grad.zero_()


class AdamW:
    """AdamW: Adam with decoupled weight decay (Loshchilov & Hutter, 2019)."""

    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8,
                 weight_decay=0.01):
        self.params = list(params)
        self.lr = lr
        self.b1, self.b2 = betas
        self.eps = eps
        self.wd = weight_decay
        self.t = 0
        self.m = [torch.zeros_like(p) for p in self.params]
        self.v = [torch.zeros_like(p) for p in self.params]

    def step(self):
        self.t += 1
        for p, m, v in zip(self.params, self.m, self.v):
            g = p.grad
            # Weight decay applied directly (decoupled from the adaptive step)
            p.data.mul_(1 - self.lr * self.wd)
            m.mul_(self.b1).add_((1 - self.b1) * g)
            v.mul_(self.b2).add_((1 - self.b2) * g * g)
            m_hat = m / (1 - self.b1 ** self.t)
            v_hat = v / (1 - self.b2 ** self.t)
            p.data -= self.lr * m_hat / (v_hat.sqrt() + self.eps)

    def zero_grad(self):
        for p in self.params:
            if p.grad is not None:
                p.grad.zero_()
```
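A quick sanity check (not part of the original comparison below): on a tiny problem the from-scratch Adam should track `torch.optim.Adam` step for step, up to floating-point noise.

```python
torch.manual_seed(0)
p_scratch = nn.Parameter(torch.randn(4))
p_builtin = nn.Parameter(p_scratch.detach().clone())

opt_scratch = Adam([p_scratch], lr=1e-2)
opt_builtin = torch.optim.Adam([p_builtin], lr=1e-2)

for _ in range(5):
    for p, opt in [(p_scratch, opt_scratch), (p_builtin, opt_builtin)]:
        opt.zero_grad()
        (p ** 2).sum().backward()
        opt.step()

print(torch.allclose(p_scratch, p_builtin, atol=1e-6))  # expect True
```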
3. Learning Rate Schedules
```python
def cosine_lr_schedule(step: int, warmup_steps: int,
                       max_steps: int, max_lr: float,
                       min_lr: float | None = None) -> float:
    """
    Linear warmup followed by cosine decay.
    Used by GPT-2, GPT-3, and most modern LLMs.
    """
    if min_lr is None:
        min_lr = max_lr / 10
    if step < warmup_steps:
        # Linear ramp-up
        return max_lr * step / warmup_steps
    if step >= max_steps:
        return min_lr
    # Cosine decay from max_lr to min_lr
    progress = (step - warmup_steps) / (max_steps - warmup_steps)
    coeff = 0.5 * (1.0 + math.cos(math.pi * progress))
    return min_lr + coeff * (max_lr - min_lr)


# Visualise the schedule
steps = list(range(3000))
lrs = [cosine_lr_schedule(s, warmup_steps=200, max_steps=3000,
                          max_lr=3e-4) for s in steps]

plt.figure(figsize=(8, 4))
plt.plot(steps, lrs, color="coral")
plt.axvline(200, linestyle="--", color="gray", label="End of warmup")
plt.xlabel("Step")
plt.ylabel("Learning rate")
plt.title("Cosine LR schedule with linear warmup")
plt.legend()
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.savefig("data/ch07_lr_schedule.png", dpi=100)
print("Saved → data/ch07_lr_schedule.png")
```
4. Compare Optimisers on a Toy Problem
```python
def rosenbrock(x, y):
    """Classic non-convex optimisation benchmark."""
    return (1 - x) ** 2 + 100 * (y - x ** 2) ** 2


optimiser_classes = {
    "SGD (lr=0.001)": lambda p: SGD(p, lr=0.001),
    "SGD+Momentum (lr=0.01)": lambda p: SGDMomentum(p, lr=0.01),
    "Adam (lr=0.01)": lambda p: Adam(p, lr=0.01),
    "AdamW (lr=0.01)": lambda p: AdamW(p, lr=0.01, weight_decay=0.0),
}

results_opt = {}
STEPS_OPT = 500

for name, make_opt in optimiser_classes.items():
    x = nn.Parameter(torch.tensor([-1.0]))
    y = nn.Parameter(torch.tensor([1.0]))
    opt = make_opt([x, y])
    losses_opt = []
    for _ in range(STEPS_OPT):
        loss = rosenbrock(x, y)
        opt.zero_grad()
        loss.backward()
        opt.step()
        losses_opt.append(loss.item())
    results_opt[name] = losses_opt
    print(f"{name}: final loss = {losses_opt[-1]:.6f} x={x.item():.4f} y={y.item():.4f}")

plt.figure(figsize=(10, 5))
for name, losses_opt in results_opt.items():
    plt.semilogy(losses_opt, label=name)
plt.xlabel("Step")
plt.ylabel("Rosenbrock loss (log scale)")
plt.title("Optimizer comparison on Rosenbrock function")
plt.legend()
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.savefig("data/ch07_optimisers.png", dpi=100)
print("Saved → data/ch07_optimisers.png")
```
5. Gradient Clipping
```python
# Gradient clipping prevents exploding gradients in early training.
# PyTorch's built-in implementation:
model = nn.Linear(512, 512)
dummy_loss = model(torch.randn(32, 512)).sum()
dummy_loss.backward()

total_norm_before = nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
print(f"Gradient norm before clip: {total_norm_before:.4f}")

total_norm_after = sum(p.grad.norm() ** 2
                       for p in model.parameters() if p.grad is not None) ** 0.5
print(f"Gradient norm after clip:  {total_norm_after:.4f}")
```
6. Summary
| Technique | When to use |
|---|---|
| Xavier init | Linear layers with tanh/sigmoid activations |
| Kaiming init | Linear layers with ReLU activations |
| SGD + momentum | Computer vision, well-tuned schedule |
| Adam | Most NLP tasks; good default |
| AdamW | LLM training; weight decay must be decoupled |
| Cosine schedule + warmup | Standard for all large-scale LLM training runs |
Chapter 8 focuses on hardware: how to exploit GPUs and measure their performance to train models orders of magnitude faster.