Chapter 04: Attention — Softmax, Scaled Dot-Product, Positional Encoding

The fundamental limitation of the MLP language model from Chapter 3 is its fixed context window: it can see at most N previous characters, and it combines them with fixed learned weights that cannot adapt to which of those characters are actually relevant in a given context. The attention mechanism addresses both problems. It dynamically computes a weighted average of past token representations, where the weights are derived from the content of the tokens themselves, so the relevance of each earlier token is decided afresh for every input.

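Scaled dot-product attention was introduced in “Attention Is All You Need” (Vaswani et al., 2017), building on earlier attention mechanisms for neural machine translation (Bahdanau et al., 2014), and it is the core operation in every modern LLM. The key formula is: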

\[\text{Attention}(Q, K, V) = \text{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}}\right)V\]

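Q (queries), K (keys), and V (values) are linear projections of the input. The dot products Q Kᵀ measure pairwise compatibility. The scaling matters because of how dot products grow with dimension: if the components of a query and a key are independent with zero mean and unit variance, their dot product has variance dₖ. Without scaling, the scores therefore grow with the head size and push softmax toward saturated, nearly one-hot outputs with vanishing gradients; dividing by √dₖ restores unit variance. Finally, the resulting softmax weights select which value vectors to mix.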

We also implement causal masking (each position can only attend to itself and earlier positions) and sinusoidal positional encoding (so the model knows where in the sequence each token appears). Together these two components are what turn a set-based operation into an ordered sequence model.

1. Scaled Dot-Product Attention

import torch
import torch.nn as nn
import torch.nn.functional as F
import math

def scaled_dot_product_attention(
    Q: torch.Tensor,
    K: torch.Tensor,
    V: torch.Tensor,
    mask: torch.Tensor = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Compute softmax(Q Kᵀ / sqrt(d_k)) V.

    Q: (B, heads, T_q, head_dim)
    K: (B, heads, T_k, head_dim)
    V: (B, heads, T_k, head_dim)
    mask: broadcastable to (B, heads, T_q, T_k), e.g. (T, T) or (B, 1, T, T).
          True / nonzero means MASKED OUT (ignored). Integer 0/1 masks are
          accepted and converted to bool before use.

    Returns:
        out:    (B, heads, T_q, head_dim) — attention-weighted values
        weights:(B, heads, T_q, T_k)      — attention probabilities (for viz)

    Note: a query row whose keys are *all* masked produces NaN weights
    (softmax over all -inf). A causal mask never does this, because every
    position can always attend to itself.
    """
    d_k = Q.shape[-1]
    # Raw attention scores, scaled so their variance does not grow with d_k
    scores = Q @ K.transpose(-2, -1) / math.sqrt(d_k)   # (B, H, T_q, T_k)

    if mask is not None:
        # masked_fill only accepts bool masks; convert so that 0/1 integer
        # masks work as documented. Masked positions get -inf so softmax
        # assigns them exactly zero weight.
        scores = scores.masked_fill(mask.to(torch.bool), float("-inf"))

    weights = F.softmax(scores, dim=-1)   # (B, H, T_q, T_k)
    out     = weights @ V                 # (B, H, T_q, head_dim)
    return out, weights

2. Causal (Autoregressive) Mask

def causal_mask(T: int, device="cpu") -> torch.Tensor:
    """
    Build a (T, T) boolean causal mask.

    Entry [i, j] is True exactly when j > i, i.e. when key position j lies in
    the future of query position i and must be hidden from it. The diagonal
    and everything below it are False (visible).
    """
    positions = torch.arange(T, device=device)
    # Row vector (1, T) holds the key index j, column vector (T, 1) holds the
    # query index i; broadcasting the comparison yields the (T, T) grid j > i.
    key_idx = positions.unsqueeze(0)
    query_idx = positions.unsqueeze(1)
    return key_idx > query_idx

# Print the causal mask for a short sequence (T=6); 1 marks a hidden future position
T = 6
mask = causal_mask(T)
print("Causal mask (True = masked out):", mask.int(), sep="\n")

3. Sinusoidal Positional Encoding

class SinusoidalPositionalEncoding(nn.Module):
    """
    Fixed (non-learned) positional encoding from "Attention Is All You Need".

        PE(pos, 2i)   = sin(pos / 10000^(2i/d_model))
        PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model))

    The table is precomputed for positions 0 .. max_len-1 and added to the
    input in forward(). Odd d_model is supported: the last (even-indexed)
    column then carries a sine with no matching cosine.

    Args:
        d_model: width of the inputs this encoding is added to.
        max_len: number of positions precomputed; forward() raises
                 ValueError for longer sequences.
    """
    def __init__(self, d_model: int, max_len: int = 2048):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        pos = torch.arange(max_len, dtype=torch.float32).unsqueeze(1)   # (max_len, 1)
        # 1 / 10000^(2i/d_model), computed in log-space
        div = torch.exp(
            torch.arange(0, d_model, 2, dtype=torch.float32)
            * (-math.log(10000.0) / d_model)
        )                                                  # (ceil(d_model/2),)
        angles = pos * div                                 # (max_len, ceil(d_model/2))
        pe[:, 0::2] = torch.sin(angles)                    # even dimensions
        # There are only floor(d_model/2) odd columns, one fewer than even
        # columns when d_model is odd, so trim the angle table to fit.
        pe[:, 1::2] = torch.cos(angles[:, : d_model // 2])  # odd dimensions
        # Register as buffer so it moves to GPU with .to(device)
        self.register_buffer("pe", pe.unsqueeze(0))   # (1, max_len, d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        x: (B, T, d_model) with T <= max_len.

        Returns x + PE[:T], broadcast over the batch dimension.
        Raises ValueError if T exceeds the precomputed table length.
        """
        seq_len = x.size(1)
        max_len = self.pe.size(1)
        if seq_len > max_len:
            raise ValueError(
                f"sequence length {seq_len} exceeds max_len {max_len} "
                "of the positional-encoding table"
            )
        return x + self.pe[:, :seq_len]

4. Multi-Head Attention

class MultiHeadAttention(nn.Module):
    """
    Runs `n_heads` attention heads in parallel, each with dimension d_model // n_heads.
    Multiple heads let the model attend to different aspects of the sequence simultaneously.

    Args:
        d_model: input/output width; must be divisible by n_heads.
        n_heads: number of heads.
        dropout: dropout probability applied to the output projection.

    Raises:
        ValueError: if d_model is not divisible by n_heads.
    """
    def __init__(self, d_model: int, n_heads: int, dropout: float = 0.1):
        super().__init__()
        # Explicit check rather than `assert`, which is stripped under `python -O`
        if d_model % n_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by n_heads ({n_heads})"
            )
        self.n_heads  = n_heads
        self.head_dim = d_model // n_heads

        # Single fused projection for Q, K, V — more efficient than three separate ones
        self.qkv_proj = nn.Linear(d_model, 3 * d_model, bias=False)
        self.out_proj = nn.Linear(d_model, d_model, bias=False)
        self.dropout  = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor, mask: torch.Tensor = None):
        """
        x:    (B, T, d_model)
        mask: optional, True = masked out. Accepted shapes:
              (T, T) shared by all batches/heads, (B, T, T) per batch,
              or any 4-D mask broadcastable to (B, n_heads, T, T),
              e.g. (B, 1, T, T).

        Returns:
            out:          (B, T, d_model)
            attn_weights: (B, n_heads, T, T)
        """
        B, T, C = x.shape

        # Project to Q, K, V and reshape into heads
        qkv = self.qkv_proj(x)                            # (B, T, 3*C)
        Q, K, V = qkv.split(C, dim=-1)                    # each (B, T, C)

        def split_heads(t):
            # (B, T, C) → (B, n_heads, T, head_dim)
            return t.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)

        Q, K, V = split_heads(Q), split_heads(K), split_heads(V)

        # Bring the mask to 4-D so it broadcasts against (B, n_heads, T, T).
        # Only lower-rank masks are expanded; unconditionally unsqueezing a
        # 4-D mask would produce an invalid 6-D tensor.
        if mask is not None:
            if mask.dim() == 2:          # (T, T)    → (1, 1, T, T)
                mask = mask[None, None]
            elif mask.dim() == 3:        # (B, T, T) → (B, 1, T, T)
                mask = mask.unsqueeze(1)

        out, attn_weights = scaled_dot_product_attention(Q, K, V, mask)

        # Re-merge heads: (B, n_heads, T, head_dim) → (B, T, d_model)
        out = out.transpose(1, 2).contiguous().view(B, T, C)
        out = self.dropout(self.out_proj(out))
        return out, attn_weights

5. Demo — Attention on a Short Sequence

torch.manual_seed(0)

d_model = 64
n_heads = 4
T       = 10
B       = 1

mha = MultiHeadAttention(d_model, n_heads)
pos_enc = SinusoidalPositionalEncoding(d_model)
# Inference-only demo: switch off dropout so the output is deterministic.
mha.eval()

# Random input (in practice this would be token embeddings)
x = torch.randn(B, T, d_model)
x = pos_enc(x)   # add positional information

mask = causal_mask(T)
with torch.no_grad():   # no backward pass here, so skip building a graph
    out, weights = mha(x, mask)

print(f"Input shape:  {x.shape}")
print(f"Output shape: {out.shape}")
print(f"Attention weight shape: {weights.shape}")

# Show the average attention pattern across heads for one query position
query_pos = 5
avg_weights = weights[0].mean(dim=0)   # (T, T)
print(f"\nAttention from position {query_pos} (avg across heads):")
print(avg_weights[query_pos].detach().numpy().round(3))
if query_pos + 1 < T:
    print(f"(positions {query_pos + 1}-{T - 1} are zero due to causal mask)")

6. Visualise Attention Patterns

from pathlib import Path

import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt

# savefig does not create missing directories, so ensure data/ exists first
attn_png = Path("data") / "ch04_attention.png"
attn_png.parent.mkdir(parents=True, exist_ok=True)

# squeeze=False keeps `axes` 2-D, so indexing also works when n_heads == 1
fig, axes = plt.subplots(1, n_heads, figsize=(14, 3), squeeze=False)
for h, ax in enumerate(axes[0]):
    im = ax.imshow(weights[0, h].detach().numpy(), vmin=0, vmax=1,
                   cmap="Blues")
    ax.set_title(f"Head {h}")
    ax.set_xlabel("Key position")
    ax.set_ylabel("Query position")
    fig.colorbar(im, ax=ax)

fig.suptitle("Multi-Head Attention Patterns (causal)")
fig.tight_layout()
fig.savefig(attn_png, dpi=100)
plt.close(fig)   # release the figure; scripts that loop would otherwise accumulate them
print("Saved → data/ch04_attention.png")

7. Visualise Sinusoidal Positional Encoding (see the summary table below for how it compares with learned positional encoding)

from pathlib import Path

import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt

d_model = 64
pos_enc = SinusoidalPositionalEncoding(d_model, max_len=50)
pe = pos_enc.pe[0].detach().numpy()   # (50, d_model)

# savefig does not create missing directories, so ensure data/ exists first
pe_png = Path("data") / "ch04_positional_encoding.png"
pe_png.parent.mkdir(parents=True, exist_ok=True)

fig = plt.figure(figsize=(12, 4))
plt.imshow(pe.T, aspect="auto", cmap="RdBu")
plt.colorbar()
plt.xlabel("Position")
plt.ylabel("Dimension")
plt.title(f"Sinusoidal Positional Encoding (d={d_model})")
plt.tight_layout()
fig.savefig(pe_png, dpi=100)
plt.close(fig)   # release the figure once it is written
print("Saved → data/ch04_positional_encoding.png")

8. Summary

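| Concept | Key detail |
|---|---|
| Scaled dot-product | Divide by √dₖ to prevent softmax saturation |
| Causal mask | Future positions → −∞ score → 0 attention weight |
| Multi-head | h parallel heads, each of size d/h; outputs concatenated, then passed through an output projection |
| Sinusoidal PE | Fixed and deterministic; can be computed for any position, though quality beyond the training length is not guaranteed (this implementation precomputes `max_len` positions) |
| Learned PE | Comparable quality in practice (Vaswani et al. report nearly identical results); cannot represent positions beyond its table size |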

Chapter 5 will assemble these pieces (attention, LayerNorm, residual connections) into a complete GPT-2 style transformer and train it on TinyStories.