Chapter 04: Attention — Softmax, Scaled Dot-Product, Positional Encoding
The fundamental limitation of the MLP language model from Chapter 3 is its fixed context window: it can see at most N previous characters and treats them all equally regardless of their relevance. The attention mechanism removes both constraints. It dynamically computes a weighted average of past token representations, where the weights are derived from the content of the tokens themselves.
Attention was first used for neural machine translation by Bahdanau et al. (2015); the scaled dot-product form we implement here was introduced in “Attention is All You Need” (Vaswani et al., 2017) and is the core operation in every modern LLM. The key formula is:
\[\text{Attention}(Q, K, V) = \text{softmax}\!\left(\frac{QK^\top}{\sqrt{d_k}}\right)V\]Q (queries), K (keys), and V (values) are linear projections of the input. The dot products Q Kᵀ measure pairwise compatibility; dividing by √dₖ keeps the values in a reasonable range so softmax doesn’t saturate; and the resulting weights select which value vectors to mix.
We also implement causal masking (each position can only attend to itself and earlier positions) and sinusoidal positional encoding (so the model knows where in the sequence each token appears). Together these two components are what turn a set-based operation into an ordered sequence model.
1. Scaled Dot-Product Attention
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
def scaled_dot_product_attention(
Q: torch.Tensor,
K: torch.Tensor,
V: torch.Tensor,
mask: torch.Tensor = None,
) -> tuple[torch.Tensor, torch.Tensor]:
"""
Q: (B, heads, T, head_dim)
K: (B, heads, T, head_dim)
V: (B, heads, T, head_dim)
mask: (T, T) or (B, 1, T, T) — True/1 means MASKED OUT (ignored)
Returns:
out: (B, heads, T, head_dim) — attention-weighted values
weights:(B, heads, T, T) — attention probabilities (for viz)
"""
d_k = Q.shape[-1]
# Raw attention scores
scores = Q @ K.transpose(-2, -1) / math.sqrt(d_k) # (B, H, T, T)
if mask is not None:
# Replace masked positions with -inf so softmax gives 0 weight
scores = scores.masked_fill(mask, float("-inf"))
weights = F.softmax(scores, dim=-1) # (B, H, T, T)
out = weights @ V # (B, H, T, head_dim)
return out, weights
2. Causal (Autoregressive) Mask
def causal_mask(T: int, device="cpu") -> torch.Tensor:
    """
    Build a (T, T) boolean causal (autoregressive) mask.

    mask[i, j] = True means query position i must NOT attend to key
    position j — i.e. j > i, a future position.

    Args:
        T: sequence length.
        device: device to allocate the mask on.

    Returns:
        (T, T) bool tensor, True strictly above the main diagonal.
    """
    # Construct directly in bool dtype (avoids the float-ones-then-cast
    # round trip); triu with diagonal=1 keeps only the future positions.
    return torch.triu(torch.ones(T, T, dtype=torch.bool, device=device), diagonal=1)
# Demo: render the causal mask for a short sequence of length 6.
T = 6
mask = causal_mask(T)
print("Causal mask (True = masked out):")
print(mask.int())  # int() renders True/False as a compact 1/0 grid
3. Sinusoidal Positional Encoding
class SinusoidalPositionalEncoding(nn.Module):
    """
    Fixed (non-learned) positional encoding from "Attention is All You Need".

        PE(pos, 2i)   = sin(pos / 10000^(2i/d_model))
        PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model))

    The table is deterministic, so it needs no training and extends to
    any position up to `max_len`.
    """

    def __init__(self, d_model: int, max_len: int = 2048):
        """
        Args:
            d_model: embedding dimension (odd values are supported).
            max_len: maximum sequence length the table covers.
        """
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        pos = torch.arange(max_len).unsqueeze(1)  # (max_len, 1)
        # 10000^(-2i/d_model), computed in log space for numerical stability.
        div = torch.exp(
            torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(pos * div)  # even dimensions
        # For odd d_model there is one fewer odd column than even column,
        # so trim the frequency vector to match (no-op when d_model is even).
        pe[:, 1::2] = torch.cos(pos * div[: d_model // 2])  # odd dimensions
        # Register as buffer (not Parameter): moves to GPU with .to(device)
        # but is excluded from gradient updates and optimizer state.
        self.register_buffer("pe", pe.unsqueeze(0))  # (1, max_len, d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Add positional encodings to x.

        Args:
            x: (B, T, d_model) input embeddings, T <= max_len.

        Returns:
            (B, T, d_model) tensor: x plus the first T positional encodings.

        Raises:
            ValueError: if the sequence is longer than the precomputed table.
        """
        T = x.size(1)
        if T > self.pe.size(1):
            raise ValueError(
                f"sequence length {T} exceeds max_len {self.pe.size(1)}"
            )
        return x + self.pe[:, :T]
4. Multi-Head Attention
class MultiHeadAttention(nn.Module):
    """
    Runs `n_heads` attention heads in parallel, each with dimension
    d_model // n_heads. Multiple heads let the model attend to different
    aspects of the sequence simultaneously; their outputs are concatenated
    and mixed by a final linear projection.
    """

    def __init__(self, d_model: int, n_heads: int, dropout: float = 0.1):
        """
        Args:
            d_model: model (embedding) dimension; must divide by n_heads.
            n_heads: number of parallel attention heads.
            dropout: dropout probability on the output projection.

        Raises:
            ValueError: if d_model is not divisible by n_heads.
        """
        super().__init__()
        # Raise instead of assert: asserts are stripped under `python -O`.
        if d_model % n_heads != 0:
            raise ValueError("d_model must be divisible by n_heads")
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        # Single fused projection for Q, K, V — one matmul instead of three.
        self.qkv_proj = nn.Linear(d_model, 3 * d_model, bias=False)
        self.out_proj = nn.Linear(d_model, d_model, bias=False)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor, mask: torch.Tensor | None = None):
        """
        Args:
            x: (B, T, d_model) input sequence.
            mask: optional boolean mask; either (T, T) or already
                broadcastable to (B, n_heads, T, T). True = masked out.

        Returns:
            out:          (B, T, d_model) attended representation.
            attn_weights: (B, n_heads, T, T) attention probabilities.
        """
        B, T, C = x.shape
        # Project to Q, K, V in one shot, then split along channels.
        qkv = self.qkv_proj(x)          # (B, T, 3*C)
        Q, K, V = qkv.split(C, dim=-1)  # each (B, T, C)

        def split_heads(t: torch.Tensor) -> torch.Tensor:
            # (B, T, C) → (B, n_heads, T, head_dim)
            return t.view(B, T, self.n_heads, self.head_dim).transpose(1, 2)

        Q, K, V = split_heads(Q), split_heads(K), split_heads(V)
        # Promote only a plain (T, T) mask so it broadcasts over batch and
        # heads; a higher-rank mask (e.g. (B, 1, T, T)) is passed through.
        if mask is not None and mask.dim() == 2:
            mask = mask.unsqueeze(0).unsqueeze(0)  # (1, 1, T, T)
        out, attn_weights = scaled_dot_product_attention(Q, K, V, mask)
        # Re-merge heads: (B, n_heads, T, head_dim) → (B, T, d_model)
        out = out.transpose(1, 2).contiguous().view(B, T, C)
        out = self.dropout(self.out_proj(out))
        return out, attn_weights
5. Demo — Attention on a Short Sequence
torch.manual_seed(0)
d_model = 64
n_heads = 4
T = 10
B = 1

mha = MultiHeadAttention(d_model, n_heads)
pos_enc = SinusoidalPositionalEncoding(d_model)

# Random input standing in for token embeddings, with positions added.
x = pos_enc(torch.randn(B, T, d_model))
mask = causal_mask(T)
out, weights = mha(x, mask)

print(f"Input shape: {x.shape}")
print(f"Output shape: {out.shape}")
print(f"Attention weight shape: {weights.shape}")

# Average the per-head patterns, then inspect a single query position.
avg_weights = weights[0].mean(dim=0)  # (T, T)
print(f"\nAttention from position 5 (avg across heads):")
print(avg_weights[5].detach().numpy().round(3))
print("(positions 6-9 are zero due to causal mask)")
6. Visualise Attention Patterns
import matplotlib
matplotlib.use("Agg")  # headless backend: render straight to file
import matplotlib.pyplot as plt

# One heatmap per head, side by side, sharing the [0, 1] color scale.
fig, axes = plt.subplots(1, n_heads, figsize=(14, 3))
for h, ax in enumerate(axes):
    im = ax.imshow(
        weights[0, h].detach().numpy(), vmin=0, vmax=1, cmap="Blues"
    )
    ax.set_title(f"Head {h}")
    ax.set_xlabel("Key position")
    ax.set_ylabel("Query position")
    plt.colorbar(im, ax=ax)
plt.suptitle("Multi-Head Attention Patterns (causal)")
plt.tight_layout()
plt.savefig("data/ch04_attention.png", dpi=100)
print("Saved → data/ch04_attention.png")
7. Compare Sinusoidal vs Learned Positional Encoding
import matplotlib
matplotlib.use("Agg")  # headless backend: render straight to file
import matplotlib.pyplot as plt

d_model = 64
pos_enc = SinusoidalPositionalEncoding(d_model, max_len=50)
pe_table = pos_enc.pe[0].detach().numpy()  # (50, 64)

plt.figure(figsize=(12, 4))
# Transpose so embedding dimensions run along y and positions along x.
plt.imshow(pe_table.T, aspect="auto", cmap="RdBu")
plt.colorbar()
plt.xlabel("Position")
plt.ylabel("Dimension")
plt.title("Sinusoidal Positional Encoding (d=64)")
plt.tight_layout()
plt.savefig("data/ch04_positional_encoding.png", dpi=100)
print("Saved → data/ch04_positional_encoding.png")
8. Summary
| Concept | Key Detail |
|---|---|
| Scaled dot-product | Divide by √dₖ to prevent softmax saturation |
| Causal mask | Future positions → −∞ score → 0 attention weight |
| Multi-head | h parallel heads each of size d/h; results concatenated |
| Sinusoidal PE | Fixed, deterministic; works for sequences longer than training |
| Learned PE | Usually slightly better; limited to max training length |
Chapter 5 will assemble these pieces (attention, LayerNorm, residual connections) into a complete GPT-2 style transformer and train it on TinyStories.