Chapter 10: Need for Speed III — Distributed Training (DDP, ZeRO)

A single GPU has a fixed amount of memory and a fixed peak FLOPS. To train models larger than a single GPU can hold — or simply to train faster — we spread the work across multiple GPUs and multiple machines.

Data Parallelism is the simplest approach: every GPU holds a complete copy of the model, but each processes a different mini-batch. After each backward pass, gradients are all-reduced (summed and averaged) across all GPUs so every replica stays in sync. PyTorch’s DistributedDataParallel (DDP) implements this with efficient NCCL communication that overlaps gradient synchronisation with backward computation.

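ZeRO (Zero Redundancy Optimizer) from DeepSpeed eliminates the redundancy of data parallelism, in which every GPU stores an identical copy of the model states. ZeRO Stage 1 shards the optimiser states across GPUs; Stage 2 also shards gradients; Stage 3 additionally shards the parameters. With ZeRO-3 the per-GPU memory for model states shrinks by a factor of N, where N is the number of GPUs. The largest trainable model therefore grows roughly N×; activations and temporary buffers still occupy memory on every GPU.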

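When a large batch does not fit in memory — for example, because you only have one GPU — gradient accumulation achieves the same effect. Run several forward/backward passes on smaller micro-batches, scaling each loss by 1/(number of micro-batches) so the gradients summed in .grad equal the large-batch average. Then call optimizer.step() once.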

1. Distributed Training with torchrun

# save this as train_ddp.py and launch with:
#   torchrun --nproc_per_node=<NUM_GPUS> train_ddp.py

import os
import torch
import torch.nn as nn
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, TensorDataset, DistributedSampler

def _select_device(local_rank: int) -> str:
    """Pick this process's device and, on CUDA machines, make it current.

    ``torch.cuda.set_device`` only accepts CUDA devices, so it is called only
    when CUDA is available; CPU-only runs (Gloo backend) simply use "cpu".

    Args:
        local_rank: Index of this worker on its node (torchrun's LOCAL_RANK).

    Returns:
        ``"cuda:<local_rank>"`` when CUDA is available, otherwise ``"cpu"``.
    """
    if not torch.cuda.is_available():
        return "cpu"
    device = f"cuda:{local_rank}"
    torch.cuda.set_device(device)
    return device


def main():
    """Train a small MLP with DistributedDataParallel and save a checkpoint.

    Meant to be launched with ``torchrun``, which sets RANK, LOCAL_RANK and
    WORLD_SIZE for every worker. Uses NCCL when CUDA is available and Gloo
    otherwise. Rank 0 prints the per-epoch loss averaged over all ranks and
    writes ``data/ddp_checkpoint.pt``.
    """
    # torchrun sets these environment variables automatically
    rank = int(os.environ.get("RANK", 0))
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    world_size = int(os.environ.get("WORLD_SIZE", 1))

    # Initialise the process group for NCCL (GPU) or Gloo (CPU)
    backend = "nccl" if torch.cuda.is_available() else "gloo"
    dist.init_process_group(backend=backend)

    try:
        device = _select_device(local_rank)

        if rank == 0:
            print(f"World size: {world_size}  |  Backend: {backend}")

        # ------------------------------------------------------------
        # Build model and wrap with DDP
        # ------------------------------------------------------------
        model = nn.Sequential(
            nn.Linear(512, 2048), nn.GELU(),
            nn.Linear(2048, 2048), nn.GELU(),
            nn.Linear(2048, 512),
        ).to(device)

        # DDP wraps the model; gradient all_reduce happens during backward()
        model = DDP(model, device_ids=[local_rank] if torch.cuda.is_available() else None)

        # ------------------------------------------------------------
        # Dataset: DistributedSampler hands each rank a disjoint subset of
        # *indices*, so every rank must build the same dataset. A fixed-seed
        # generator guarantees that for this synthetic data.
        # ------------------------------------------------------------
        gen = torch.Generator().manual_seed(0)
        N = 10_000
        X = torch.randn(N, 512, generator=gen)
        Y = torch.randn(N, 512, generator=gen)
        dataset = TensorDataset(X, Y)

        sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank,
                                     shuffle=True, drop_last=True)
        # Pinned host memory only helps host-to-GPU copies.
        loader = DataLoader(dataset, batch_size=64, sampler=sampler,
                            num_workers=2, pin_memory=torch.cuda.is_available())

        optim = torch.optim.AdamW(model.parameters(), lr=1e-3)

        # ------------------------------------------------------------
        # Training loop
        # ------------------------------------------------------------
        for epoch in range(3):
            sampler.set_epoch(epoch)   # ensures different shuffling each epoch
            total_loss = 0.0
            for xb, yb in loader:
                xb, yb = xb.to(device), yb.to(device)
                loss = nn.functional.mse_loss(model(xb), yb)
                optim.zero_grad()
                loss.backward()   # DDP all-reduces gradients here
                optim.step()
                total_loss += loss.item()

            # Each rank only saw its own shard, so average the per-rank means.
            # The sampler gives every rank the same number of samples, hence
            # the same number of batches. SUM + divide is used because
            # ReduceOp.AVG is not supported by the Gloo backend.
            epoch_loss = torch.tensor(total_loss / len(loader), device=device)
            dist.all_reduce(epoch_loss)
            epoch_loss /= world_size

            if rank == 0:
                print(f"Epoch {epoch}  loss: {epoch_loss.item():.4f}")

        # Save checkpoint only on rank 0
        if rank == 0:
            os.makedirs("data", exist_ok=True)
            # DDP wraps the model, so access the inner module for saving
            torch.save(model.module.state_dict(), "data/ddp_checkpoint.pt")
            print("Checkpoint saved → data/ddp_checkpoint.pt")
    finally:
        # Always tear down the process group, even if training raised.
        dist.destroy_process_group()

# torchrun sets LOCAL_RANK for every worker it launches. Guarding on it means
# `torchrun --nproc_per_node=<NUM_GPUS> train_ddp.py` actually trains, while
# executing this code elsewhere (e.g. a notebook cell, where __name__ is also
# "__main__") only defines main() and prints the launch hint.
if __name__ == "__main__" and "LOCAL_RANK" in os.environ:
    main()
else:
    print("DDP script defined. Launch with: torchrun --nproc_per_node=2 train_ddp.py")

2. Gradient Accumulation (Large Effective Batches on Limited Memory)

import torch
import torch.nn as nn

# Gradient accumulation: emulate a batch of MICRO_BATCH_SIZE × ACCUM_STEPS
# samples while only ever holding one micro-batch of activations in memory.

device = "cuda" if torch.cuda.is_available() else "cpu"

model = nn.Sequential(
    nn.Linear(256, 1024),
    nn.GELU(),
    nn.Linear(1024, 256),
).to(device)

optim = torch.optim.AdamW(model.parameters(), lr=3e-4)

MICRO_BATCH_SIZE = 32
ACCUM_STEPS = 8  # effective batch size = 32 × 8 = 256

print(f"Effective batch size: {MICRO_BATCH_SIZE * ACCUM_STEPS}")


def _scaled_micro_loss():
    """Run forward/backward on one random micro-batch; return its scaled loss.

    The loss is divided by ACCUM_STEPS before backward() so that the gradients
    summed over all micro-batches match those of one large batch.
    """
    inputs = torch.randn(MICRO_BATCH_SIZE, 256, device=device)
    targets = torch.randn(MICRO_BATCH_SIZE, 256, device=device)
    scaled = nn.functional.mse_loss(model(inputs), targets) / ACCUM_STEPS
    scaled.backward()
    return scaled.item()


for step in range(20):
    # Gradients from each micro-batch add up in .grad until step() runs.
    accum_loss = sum(_scaled_micro_loss() for _ in range(ACCUM_STEPS))

    # One optimiser update per ACCUM_STEPS micro-batches.
    nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    optim.step()
    optim.zero_grad()

    if step % 5 == 0:
        print(f"Step {step}  loss: {accum_loss:.4f}")

3. DDP Communication Overhead Simulation

import torch.distributed as dist

def simulate_allreduce(tensor_size: int, world_size: int, *,
                       bytes_per_elem: int = 4,
                       bandwidth_gbps: float = 600.0) -> str:
    """
    Estimate the per-GPU traffic and transfer time of a ring all-reduce.

    A ring all-reduce makes every GPU send 2 * (world_size - 1) / world_size
    times the tensor's size. The time estimate assumes that traffic moves at
    exactly ``bandwidth_gbps`` and ignores latency and compute overlap, so it
    is an idealised figure.

    Args:
        tensor_size: Number of elements being reduced (e.g. a model's
            parameter count for a full gradient sync). Must be >= 0.
        world_size: Number of participating GPUs. Must be >= 1.
        bytes_per_elem: Bytes per element; 4 for float32 (default),
            2 for fp16/bf16 gradients.
        bandwidth_gbps: Assumed link bandwidth in GB/s (1 GB = 1e9 bytes).
            NOTE: quoted NVLink figures such as 600 GB/s are often aggregate
            bidirectional numbers; check your hardware before trusting the
            result.

    Returns:
        A one-line summary; sizes are in MB (1e6 bytes), time in ms.

    Raises:
        ValueError: if any argument is out of range.
    """
    if world_size < 1:
        raise ValueError(f"world_size must be >= 1, got {world_size}")
    if tensor_size < 0:
        raise ValueError(f"tensor_size must be >= 0, got {tensor_size}")
    if bytes_per_elem <= 0:
        raise ValueError(f"bytes_per_elem must be > 0, got {bytes_per_elem}")
    if bandwidth_gbps <= 0:
        raise ValueError(f"bandwidth_gbps must be > 0, got {bandwidth_gbps}")

    total_bytes = tensor_size * bytes_per_elem
    # Ring all-reduce traffic per GPU (reduce-scatter + all-gather phases)
    traffic_bytes = 2 * (world_size - 1) / world_size * total_bytes
    time_ms = traffic_bytes / (bandwidth_gbps * 1e9) * 1e3
    return (f"  Tensor: {total_bytes/1e6:.1f} MB  "
            f"| Per-GPU traffic: {traffic_bytes/1e6:.1f} MB  "
            f"| At {bandwidth_gbps:g} GB/s NVLink: {time_ms:.2f} ms")

# Parameter counts of a few well-known models. The labels carry only the
# model name; the size is formatted once in the heading printed below.
param_counts = {
    "GPT-2 Small": 117_000_000,
    "GPT-2 Large": 774_000_000,
    "LLaMA-7B":    7_000_000_000,
}

for name, params in param_counts.items():
    print(f"\n{name} ({params/1e6:,.0f}M params), world_size=8:")
    print(simulate_allreduce(params, world_size=8))

4. ZeRO Stages Explained

For a model with $\Phi$ parameters trained with Adam in mixed precision, the table gives the memory per GPU for model states (parameters, gradients and optimizer states) under each ZeRO stage. Activations are not included.

| Strategy | Memory per GPU |
|---|---|
| Baseline DDP | $16\Phi$ bytes ($2\Phi$ FP16 params + $2\Phi$ FP16 grads + $12\Phi$ FP32 optimizer state: master weights, momentum and variance) |
| ZeRO Stage 1 (shard optimizer state) | $(4 + 12/N)\Phi$ bytes |
| ZeRO Stage 2 (shard grads too) | $(2 + 14/N)\Phi$ bytes |
| ZeRO Stage 3 (shard everything) | $(16/N)\Phi$ bytes |

Here $N$ is the number of GPUs. For a 7B model on 8 GPUs, the baseline needs 112 GB per GPU for model states alone, which won’t fit on an 80 GB A100. ZeRO-3 needs only 14 GB per GPU.

5. Using DeepSpeed ZeRO

# DeepSpeed config for ZeRO Stage 2, written to data/ds_config.json.
#
# DeepSpeed requires
#   train_batch_size == micro_batch_per_gpu * gradient_accumulation_steps * world_size
# and fills in whichever of these is omitted. With 8 GPUs this config gives
# 256 / (8 * 8) = 4 samples per GPU per micro-step.
ds_config = {
    "train_batch_size": 256,
    "gradient_accumulation_steps": 8,
    "optimizer": {
        "type": "AdamW",
        "params": {"lr": 3e-4, "betas": [0.9, 0.95], "weight_decay": 0.1},
    },
    # "WarmupCosineLR" is DeepSpeed's built-in warmup + cosine-decay schedule;
    # "WarmupCosineAnnealing" is not a recognised scheduler type.
    "scheduler": {
        "type": "WarmupCosineLR",
        "params": {"warmup_num_steps": 1000, "total_num_steps": 100000},
    },
    "fp16": {"enabled": True},
    "zero_optimization": {
        "stage": 2,
        "overlap_comm": True,               # overlap gradient reduction with backward
        "contiguous_gradients": True,       # copy grads into one buffer to limit fragmentation
        "reduce_bucket_size": 500_000_000,  # elements per reduction bucket (5e8, as an int)
    },
}

import json
import os

os.makedirs("data", exist_ok=True)
with open("data/ds_config.json", "w", encoding="utf-8") as f:
    json.dump(ds_config, f, indent=2)
print("DeepSpeed config saved → data/ds_config.json")

# Launch with:
# deepspeed --num_gpus=8 train.py --deepspeed data/ds_config.json

print("""
DeepSpeed ZeRO training launch:
  deepspeed --num_gpus 8 train.py --deepspeed data/ds_config.json

FSDP (Fully Sharded Data Parallel) — PyTorch native alternative:
  from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy
  model = FSDP(model, sharding_strategy=ShardingStrategy.FULL_SHARD)
""")

6. Practical Single-Machine Multi-GPU Setup

Quick-start: 2-GPU training with torchrun

  1. Install required packages: pip install torch torchvision
  2. Create train.py using the DDP template above (main() function).
  3. Launch training: torchrun --standalone --nproc_per_node=2 train.py

Key DDP pitfalls to avoid:

  • Always call sampler.set_epoch(epoch) for proper shuffling
  • Save checkpoints only on rank == 0
  • Use model.module.state_dict() not model.state_dict() for DDP models
  • Wrap model with DDP after moving it to the correct device
  • Set NCCL_DEBUG=INFO to debug communication issues

7. Summary

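| Strategy | Memory/GPU (model states) | Communication | Complexity |
|---|---|---|---|
| Single GPU | 1× | None | Trivial |
| DDP | 1× (replicated) | All-reduce grads | Low |
| ZeRO Stage 1 | ~(0.25 + 0.75/N)× | Reduce grads + all-gather updated params | Medium |
| ZeRO Stage 2 | ~(0.125 + 0.875/N)× | Reduce-scatter grads + all-gather updated params | Medium |
| ZeRO Stage 3 | ~1/N× | All-gather params (forward and backward) + reduce-scatter grads | High |
| Gradient accum | 1× (activation memory scales with micro-batch size) | None | Trivial |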

Chapter 11 covers the data pipeline: downloading, preprocessing, and efficiently loading TinyStories for our training runs.