Chapter 09: Need for Speed II — Precision (fp32, fp16, bf16, fp8)

Standard neural network training uses float32 (FP32): 32-bit numbers with 8 exponent bits and 23 mantissa bits. FP32 gives ample numerical range and precision for most computations. But it is also wasteful: weights, activations, and gradients that vary over many orders of magnitude only truly need a few bits of precision.

Mixed precision training keeps a master copy of weights in FP32 (for numerical stability in the optimiser) while performing the forward and backward passes in FP16 or BF16. This halves memory consumption and, on modern Tensor Core GPUs, doubles or quadruples arithmetic throughput.

FP16 (IEEE half-precision): 5 exponent bits, 10 mantissa bits. Maximum value ≈65 504. Activations can overflow, and small gradients easily underflow to zero during training, so a GradScaler is required: it scales the loss up before the backward pass (keeping gradients representable) and scales gradients back down before the optimiser step.

BF16 (bfloat16): 8 exponent bits, 7 mantissa bits. Same dynamic range as FP32, much lower precision. Overflow is essentially impossible, so no GradScaler is needed. BF16 is available on Ampere+ GPUs (A100, RTX 3090+) and is the recommended format for modern LLM training.

1. Float Formats Explained

import torch
import struct

def float_bits(x: float, dtype) -> str:
    """Return the binary (IEEE-754 style) bit pattern of *x* in the given dtype.

    Args:
        x: Python float to encode.
        dtype: one of torch.float32, torch.float16, torch.bfloat16.

    Returns:
        The bit pattern as a string of '0'/'1' characters
        (32 chars for float32, 16 for the half formats; sign bit first).

    Raises:
        ValueError: if *dtype* is not one of the three supported formats.
    """
    t = torch.tensor(x, dtype=dtype)
    if dtype == torch.float32:
        # Round-trip through struct to get the raw 32-bit word of a C float.
        bits = struct.unpack("I", struct.pack("f", t.item()))[0]
        return f"{bits:032b}"
    if dtype in (torch.float16, torch.bfloat16):
        # view() reinterprets the 16 stored bits as a *signed* int16;
        # masking with 0xFFFF maps negatives back to the unsigned pattern.
        bits = t.view(torch.int16).item() & 0xFFFF
        return f"{bits:016b}"
    raise ValueError(f"unsupported dtype: {dtype}")

# Show the bit layout of the same value in all three formats.
val = 3.14
for fmt, tag in ((torch.float32, "  float32"),
                 (torch.float16, "  float16"),
                 (torch.bfloat16, " bfloat16")):
    print(f"{tag} bits: {float_bits(val, fmt)}")

# Range/precision characteristics of each dtype, straight from torch.finfo.
for label, dt in (("float32", torch.float32),
                  ("float16", torch.float16),
                  ("bfloat16", torch.bfloat16)):
    fi = torch.finfo(dt)
    print(f"\n{label}:")
    print(f"  bits     : {fi.bits}")
    print(f"  max      : {fi.max:.2e}")
    print(f"  min      : {fi.min:.2e}")
    print(f"  smallest : {fi.tiny:.2e}")
    print(f"  eps      : {fi.eps:.2e}")

2. Memory Savings

import torch.nn as nn

def model_memory_mb(model: nn.Module, dtype) -> float:
    """Estimate the parameter memory of *model*, in MB, at the given dtype.

    Counts every parameter once and multiplies by the byte width of
    *dtype* (bits // 8), so it reflects what the weights *would* occupy
    if cast to that precision.
    """
    total_params = 0
    for param in model.parameters():
        total_params += param.numel()
    width_bytes = torch.finfo(dtype).bits // 8
    return total_params * width_bytes / 1e6

# Build a small MLP and report what its weights would occupy per precision.
_layers = [
    nn.Linear(1024, 4096), nn.GELU(),
    nn.Linear(4096, 4096), nn.GELU(),
    nn.Linear(4096, 1024),
]
model = nn.Sequential(*_layers)
n_params = sum(p.numel() for p in model.parameters())
print(f"Model parameters: {n_params:,}")

for dt in (torch.float32, torch.float16, torch.bfloat16):
    print(f"  {str(dt):20s}: {model_memory_mb(model, dt):.1f} MB")

3. Mixed Precision Training with torch.autocast

import time

# Use the GPU when available; every benchmark below honours this flag.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Module.to() moves parameters in place, so `optim` below sees the
# on-device tensors.
model  = model.to(device)
optim  = torch.optim.AdamW(model.parameters(), lr=1e-3)

# Dummy batch: 64 samples matching the MLP's 1024-dim input.
x = torch.randn(64, 1024, device=device)

# ---- FP32 baseline ----
# Time 20 full training steps (forward, backward, optimiser update) in FP32.
model = model.float()
if device == "cuda":
    # Drain any pending GPU work so the timer starts from an idle device;
    # without this, earlier async kernels would be billed to this loop.
    torch.cuda.synchronize()
t0 = time.perf_counter()
for _ in range(20):
    out  = model(x)
    loss = out.sum()
    loss.backward()
    optim.step()
    optim.zero_grad()
if device == "cuda": torch.cuda.synchronize()
fp32_ms = (time.perf_counter() - t0) * 1000 / 20
print(f"FP32 step time: {fp32_ms:.2f} ms")

# ---- Mixed Precision (BF16 preferred; fall back to FP16) ----
# CPU autocast only supports bfloat16, and torch.cuda.is_bf16_supported()
# must not be queried on CPU-only builds — guard on the device first.
if device == "cuda":
    amp_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
else:
    amp_dtype = torch.bfloat16
print(f"Using AMP dtype: {amp_dtype}")

# GradScaler is only needed for FP16 (BF16 has same range as FP32).
# torch.cuda.amp.GradScaler is deprecated; use the device-generic
# torch.amp.GradScaler instead. When disabled it is a transparent no-op.
scaler = torch.amp.GradScaler(device, enabled=(amp_dtype == torch.float16))

if device == "cuda":
    torch.cuda.synchronize()  # start timing from an idle GPU
t0 = time.perf_counter()
for _ in range(20):
    with torch.autocast(device_type=device, dtype=amp_dtype):
        out  = model(x)
        loss = out.sum()

    # scale → backward → step → update; each call passes through
    # unchanged when the scaler is disabled (BF16 path).
    scaler.scale(loss).backward()
    scaler.step(optim)
    scaler.update()
    optim.zero_grad()
if device == "cuda": torch.cuda.synchronize()
amp_ms = (time.perf_counter() - t0) * 1000 / 20

print(f"AMP  step time : {amp_ms:.2f} ms")
print(f"Speedup        : {fp32_ms / amp_ms:.2f}×")

4. GradScaler Deep Dive

# The GradScaler prevents FP16 underflow in gradients:
#   1. Multiply loss by a large scale factor before backward()
#   2. After backward(), divide gradients by the scale factor
#   3. If any gradient is inf/nan, skip the update and reduce the scale
#   4. Periodically increase the scale if updates have been stable

scaler = torch.cuda.amp.GradScaler(
    init_scale     = 2**16,    # initial scale (65 536)
    growth_factor  = 2.0,      # multiply scale by this after `growth_interval` steps
    backoff_factor = 0.5,      # multiply scale by this after inf/nan
    growth_interval= 2000,     # steps between scale increases
    enabled        = (amp_dtype == torch.float16),
)

print(f"Initial scale: {scaler.get_scale()}")

# One training step with explicit scaler calls
with torch.autocast(device_type=device, dtype=amp_dtype):
    out  = model(x)
    loss = out.mean()

scaler.scale(loss).backward()      # backward in FP16/BF16; scale grads
scaler.unscale_(optim)             # divide grads by scale (optional explicit step)
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)   # clip AFTER unscale
scaler.step(optim)                 # only applies update if no inf/nan
scaler.update()                    # adjust scale for next step
optim.zero_grad()
print(f"Scale after step: {scaler.get_scale()}")

5. Measure Memory Usage

# Compare peak GPU memory for one Linear layer's forward+backward pass in
# FP32 vs BF16.
# NOTE(review): the model/optimizer from earlier sections are still resident
# on the GPU here; reset_peak_memory_stats() only resets the *peak* counter,
# it does not free live tensors, so both peaks include that baseline and the
# printed ratio understates the true saving — confirm whether that is
# intentional for this demo.
if device == "cuda":
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()

    # FP32
    model_fp32 = nn.Linear(4096, 4096).to(device)
    x_fp32     = torch.randn(64, 4096, device=device, dtype=torch.float32)
    out_fp32   = model_fp32(x_fp32)
    out_fp32.sum().backward()
    peak_fp32  = torch.cuda.max_memory_allocated() / 1e6
    # Free the FP32 tensors so they do not inflate the BF16 measurement.
    del model_fp32, x_fp32, out_fp32

    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()

    # BF16
    model_bf16 = nn.Linear(4096, 4096).to(device, dtype=torch.bfloat16)
    x_bf16     = torch.randn(64, 4096, device=device, dtype=torch.bfloat16)
    out_bf16   = model_bf16(x_bf16)
    out_bf16.sum().backward()
    peak_bf16  = torch.cuda.max_memory_allocated() / 1e6

    print(f"Peak GPU memory — FP32: {peak_fp32:.1f} MB  |  BF16: {peak_bf16:.1f} MB")
    print(f"Memory reduction: {peak_fp32 / peak_bf16:.2f}×")
else:
    print("GPU memory stats require CUDA")

6. FP8 Preview

# FP8 (E4M3 and E5M2) is available on Hopper GPUs (H100) via transformer-engine
# or torch._scaled_mm. Here we illustrate the formats conceptually.

# FP8 machine epsilon is 2^-mantissa_bits: E4M3 → 2^-3 = 0.125,
# E5M2 → 2^-2 = 0.25 (the previous values 0.00195/0.0039 were wrong —
# they correspond to 9/8 mantissa bits, which no FP8 format has).
print("FP8 formats:")
print("  E4M3: 4 exponent bits, 3 mantissa bits → max ≈ 448, eps ≈ 0.125")
print("  E5M2: 5 exponent bits, 2 mantissa bits → max ≈ 57344, eps ≈ 0.25")
print()
print("FP8 is used for forward pass only (weights + activations).")
print("Gradients and master weights remain in BF16/FP32.")
print("Expected speedup over BF16: ~2× on H100 Tensor Cores.")

# Illustrate the quantisation error: round-trip FP32 values through each
# 16-bit format and see how much resolution survives.
x = torch.linspace(-10, 10, 1000)
fp32_vals = x                    # reference (alias; x is never mutated below)
fp16_vals = x.half().float()     # round-trip through FP16
bf16_vals = x.bfloat16().float() # round-trip through BF16

import matplotlib
matplotlib.use("Agg")            # headless backend: render straight to file
import matplotlib.pyplot as plt
from pathlib import Path

# Plot the absolute round-trip error of FP16 and BF16 against FP32.
fig, ax = plt.subplots(figsize=(10, 4))
ax.plot(x, (fp16_vals - fp32_vals).abs(), label="FP16 error", alpha=0.8)
ax.plot(x, (bf16_vals - fp32_vals).abs(), label="BF16 error", alpha=0.8)
ax.set_xlabel("Value")
ax.set_ylabel("Absolute quantisation error")
ax.set_title("FP16 vs BF16 quantisation error")
ax.legend()
ax.grid(True, alpha=0.3)
plt.tight_layout()
# savefig raises FileNotFoundError if the target directory is missing,
# so create it first.
Path("data").mkdir(parents=True, exist_ok=True)
plt.savefig("data/ch09_precision.png", dpi=100)
print("Saved → data/ch09_precision.png")

7. Practical Recommendations

# Hoist the cheat sheet into a named constant so the print call stays short.
_CHEAT_SHEET = """
Precision cheat sheet
─────────────────────
Training on A100/H100:   BF16 with torch.autocast  (no GradScaler needed)
Training on V100/3090:   FP16 with torch.autocast + GradScaler
Inference on any GPU:    FP16 or BF16  (no GradScaler needed)
Inference on CPU:        FP32 (FP16 is slow on CPU; BF16 ok on recent AVX-512)
Fine-tuning (memory):    4-bit with bitsandbytes  (Chapter 14)

Rule of thumb:
  • BF16 = FP32 stability + FP16 speed
  • If you have BF16 support, always use it over FP16
"""
print(_CHEAT_SHEET)

8. Summary

Format      Bits   Max      Use case
float32     32     3.4e38   Master weights, optimiser state
float16     16     65 504   Inference; training w/ GradScaler
bfloat16    16     3.4e38   Training on modern GPUs (preferred)
FP8 (E4M3)  8      448      Hopper forward passes (emerging)

Chapter 10 scales training further with distributed data parallelism (DDP) across multiple GPUs or machines.