Chapter 03: N-gram Neural Language Model (MLP, MatMul, GELU)
The bigram model from Chapter 1 only looked one character back. We can do much better by conditioning on several past characters — an N-gram context. Rather than storing a table of counts (which would grow exponentially with context length), we learn a neural network that maps a context window of characters to a probability distribution over the next character.
This architecture, popularised by Bengio et al. (2003), rests on two ideas that remain central to modern LLMs: (1) representing discrete tokens as dense, learned embeddings and (2) training those embeddings jointly with the rest of the network by gradient descent, with the gradients computed by backpropagation. The embedding lookup + linear projection + nonlinearity pattern also appears inside every transformer.
We also switch from ReLU to GELU (Gaussian Error Linear Unit), the activation function used by GPT-2 and most modern transformers. GELU is smoother than ReLU near zero, which helps gradients flow more easily through deep networks.
We use the TinyStories data saved in Chapter 1. If you haven’t run Chapter 1 yet, the download block below will fetch it for you.
1. Load TinyStories
import os
from datasets import load_dataset
# Everything this chapter reads or writes lives under one local directory,
# which is created up front if it does not exist yet.
DATA_DIR = "data"
os.makedirs(DATA_DIR, exist_ok=True)

# One plain-text file per dataset split.
TRAIN_FILE = os.path.join(DATA_DIR, "tinystories_train.txt")
VAL_FILE = os.path.join(DATA_DIR, "tinystories_val.txt")
def download_tinystories():
    """Fetch the TinyStories subsets used in this chapter, if missing.

    Streams the "roneneldan/TinyStories" dataset and writes the first
    50,000 training stories to TRAIN_FILE and the first 5,000 validation
    stories to VAL_FILE, one stripped story plus a newline per example.
    A split whose file already exists is skipped.
    """
    for split, path, n in [("train", TRAIN_FILE, 50_000),
                           ("validation", VAL_FILE, 5_000)]:
        if os.path.exists(path):
            continue
        print(f"Downloading TinyStories ({split}) …")
        ds = load_dataset("roneneldan/TinyStories", split=split, streaming=True)
        # Write to a temporary sibling and rename only on success. Otherwise
        # an interrupted download would leave a truncated file that the
        # os.path.exists() check above silently accepts on the next run.
        tmp_path = path + ".tmp"
        try:
            with open(tmp_path, "w", encoding="utf-8") as f:
                for i, ex in enumerate(ds):
                    if i >= n:
                        break
                    f.write(ex["text"].strip() + "\n")
            os.replace(tmp_path, path)
        finally:
            # After a successful rename the temp file no longer exists;
            # after a failure this removes the partial output.
            if os.path.exists(tmp_path):
                os.remove(tmp_path)
        print(f" Saved → {path}")
# Make sure both split files exist, then read each one fully into memory.
download_tinystories()

with open(TRAIN_FILE, encoding="utf-8") as train_fh:
    train_text = train_fh.read()

with open(VAL_FILE, encoding="utf-8") as val_fh:
    val_text = val_fh.read()

print(f"Train chars: {len(train_text):,}")
print(f"Val chars: {len(val_text):,}")
2. Build Vocabulary & Encode Data
import torch
# The vocabulary is derived from the training split only. '.' is pinned to
# index 0; every other character follows in sorted order.
other_chars = set(train_text)
other_chars.discard('.')
chars = ['.', *sorted(other_chars)]

# index -> character, and its inverse character -> index
itos = dict(enumerate(chars))
stoi = {ch: idx for idx, ch in itos.items()}

vocab_size = len(chars)
print(f"Vocabulary size: {vocab_size}")
def encode(s: str) -> list[int]:
    """Return the vocabulary index of every character in *s*.

    Characters that are not in ``stoi`` (possible for validation text,
    since the vocabulary is built from the training split) fall back to
    the index of '.'.
    """
    lookup = stoi.get
    return [lookup(ch, stoi['.']) for ch in s]
def decode(ids) -> str:
    """Return the string spelled by the vocabulary indices in *ids*.

    Raises KeyError for an index that is not a key of ``itos``.
    """
    return "".join(map(itos.__getitem__, ids))
3. Build N-gram Training Examples
CONTEXT_LEN = 8  # how many characters we condition on


def build_dataset(text: str):
    """Turn *text* into (context, next-character) training pairs.

    A window of CONTEXT_LEN encoded characters slides over the text one
    position at a time:
        X[i] = ids[i : i + CONTEXT_LEN]   context characters (integers)
        Y[i] = ids[i + CONTEXT_LEN]       next character (integer)

    Returns:
        (X, Y) where X is a long tensor of shape (N, CONTEXT_LEN), Y is a
        long tensor of shape (N,), and N = max(len(text) - CONTEXT_LEN, 0).

    X and Y are views into one shared 1-D tensor of encoded characters, so
    the windows are not materialised on the host. Treat them as read-only:
    an in-place write to one window would show up in its neighbours.
    """
    ids = torch.tensor(encode(text), dtype=torch.long)
    n_samples = ids.numel() - CONTEXT_LEN
    if n_samples <= 0:
        # Text too short for a single pair: return correctly shaped empty
        # tensors so downstream indexing and batching still work.
        return (torch.empty((0, CONTEXT_LEN), dtype=torch.long),
                torch.empty((0,), dtype=torch.long))

    # unfold exposes every length-CONTEXT_LEN window as a strided view,
    # avoiding a Python loop that builds one small list per character.
    # The final window has no following character, so it is dropped.
    X = ids.unfold(0, CONTEXT_LEN, 1)[:n_samples]
    Y = ids[CONTEXT_LEN:]
    return X, Y
# Build (context, target) tensors for the training and validation splits.
(Xtr, Ytr), (Xva, Yva) = map(build_dataset, (train_text, val_text))
print(f"Train samples: {Xtr.shape} Val samples: {Xva.shape}")
4. GELU Activation
import torch.nn as nn
import math
class GELU(nn.Module):
    """
    Gaussian Error Linear Unit (tanh approximation).

    The exact definition is GELU(x) = x * Φ(x), with Φ the standard normal
    CDF. This module evaluates the usual closed-form approximation
        0.5 * x * (1 + tanh(√(2/π) * (x + 0.044715 * x³)))
    element-wise, so the output has the same shape as the input.
    """

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Cubic correction term inside the tanh.
        inner = x + 0.044715 * x ** 3
        # tanh(...) lies in (-1, 1), so 1 + tanh(...) acts as a soft gate.
        gate = 1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * inner)
        return 0.5 * x * gate
# Alternatively: nn.GELU(approximate="tanh") — PyTorch's built-in version of this tanh approximation (plain nn.GELU() uses the exact erf-based form)
5. MLP Language Model
class NGramMLP(nn.Module):
    """Fixed-window character language model.

    Each of the ``context_len`` input characters is embedded, the embeddings
    are concatenated, and a two-hidden-layer MLP with GELU activations maps
    the result to one logit per vocabulary entry for the next character.
    """

    def __init__(self, vocab_size: int, emb_dim: int, context_len: int,
                 hidden_dim: int):
        super().__init__()
        # Kept so generate() pads to the window this model was built for
        # rather than relying on a module-level constant.
        self.context_len = context_len
        # Embedding table: each character gets a dense vector of size emb_dim
        self.embedding = nn.Embedding(vocab_size, emb_dim)
        # After embedding lookup and flattening, input dim = context_len * emb_dim
        in_dim = context_len * emb_dim
        self.net = nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            GELU(),
            nn.Linear(hidden_dim, hidden_dim),
            GELU(),
            nn.Linear(hidden_dim, vocab_size),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        x: (B, context_len) integer token indices
        returns: (B, vocab_size) logits
        """
        # Embed each token in the context window
        emb = self.embedding(x)            # (B, context_len, emb_dim)
        emb = emb.view(emb.size(0), -1)    # (B, context_len * emb_dim) — flatten
        return self.net(emb)

    @torch.no_grad()
    def generate(self, stoi, itos, max_chars=200, temperature=1.0, seed=0):
        """Autoregressively sample text, one character at a time.

        Args:
            stoi: character -> index mapping; must contain '.'.
            itos: index -> character mapping.
            max_chars: upper bound on the number of characters returned.
            temperature: divisor applied to the logits before the softmax.
            seed: value passed to torch.manual_seed, which reseeds the
                global torch RNG.

        Returns:
            The sampled characters joined into one string.

        The context starts as ``context_len`` copies of the '.' token, and
        sampling stops as soon as '.' is drawn. Because '.' is also an
        ordinary character of the vocabulary, output ends at the first
        sampled period, which is not included in the result.
        """
        torch.manual_seed(seed)
        # Build inputs on the model's own device; a CPU tensor would fail
        # against parameters that have been moved to CUDA.
        device = self.embedding.weight.device
        was_training = self.training
        self.eval()
        context = [stoi['.']] * self.context_len  # start with padding tokens
        result = []
        try:
            for _ in range(max_chars):
                x = torch.tensor([context], dtype=torch.long, device=device)
                logits = self(x)[0]  # (vocab_size,)
                probs = torch.softmax(logits / temperature, dim=-1)
                idx = torch.multinomial(probs, 1).item()
                if itos[idx] == '.':
                    break
                result.append(itos[idx])
                context = context[1:] + [idx]  # slide window
        finally:
            # Leave the module in whatever mode the caller had it in.
            self.train(was_training)
        return "".join(result)
6. Training Loop
# Seed before the model is created so weight initialisation and mini-batch
# sampling are reproducible.
torch.manual_seed(42)

device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {device}")

model = NGramMLP(
    vocab_size=vocab_size,
    emb_dim=32,
    context_len=CONTEXT_LEN,
    hidden_dim=256,
).to(device)
print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")

optimiser = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)
loss_fn = nn.CrossEntropyLoss()

BATCH_SIZE = 256
STEPS = 5_000

# Move both splits to the training device once, up front.
Xtr, Ytr = Xtr.to(device), Ytr.to(device)
Xva, Yva = Xva.to(device), Yva.to(device)

train_losses, val_losses = [], []

for step in range(STEPS):
    model.train()

    # Sample a random mini-batch (with replacement) from the training pairs.
    idx = torch.randint(len(Xtr), (BATCH_SIZE,))
    xb, yb = Xtr[idx], Ytr[idx]

    # One optimisation step: clear old gradients, forward, backward, update.
    optimiser.zero_grad()
    logits = model(xb)
    loss = loss_fn(logits, yb)
    loss.backward()
    optimiser.step()
    train_losses.append(loss.item())

    # Only every 500th step gets a validation estimate and a log line.
    if step % 500 != 0:
        continue

    model.eval()
    with torch.no_grad():
        # Estimate validation loss on a random sample of 2048 pairs.
        val_idx = torch.randint(len(Xva), (2048,))
        val_logits = model(Xva[val_idx])
        val_loss = loss_fn(val_logits, Yva[val_idx]).item()
    val_losses.append(val_loss)
    print(f"Step {step:5d} train loss: {loss.item():.4f} val loss: {val_loss:.4f}")
7. Generate Stories
# Sample three texts, each from its own seed, at a slightly sharpened
# temperature.
model.eval()
for seed in range(3):
    print(f"\n--- Generated story {seed + 1} ---")
    story = model.generate(stoi, itos, max_chars=300, temperature=0.8, seed=seed)
    print(story)
8. Plot Embeddings with PCA
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
# Reduce the learned character embeddings to two principal components so
# they can be drawn on a plane.
emb_weights = model.embedding.weight.detach().cpu().numpy()
pca = PCA(n_components=2)
coords = pca.fit_transform(emb_weights)

fig, ax = plt.subplots(figsize=(10, 8))
for i, ch in itos.items():
    point = (coords[i, 0], coords[i, 1])
    ax.scatter(*point, s=20, color="steelblue")
    # repr() keeps whitespace characters such as '\n' and ' ' legible.
    ax.annotate(repr(ch), point, fontsize=7)
ax.set_title("Character Embeddings (PCA)")
fig.tight_layout()
fig.savefig("data/ch03_embeddings.png", dpi=100)
print("Saved → data/ch03_embeddings.png")
9. Summary
| Component | Purpose |
|---|---|
| `nn.Embedding` | Learns dense vector representations for each token |
| Flatten + Linear | Projects concatenated context embeddings to hidden space |
| GELU | Smooth non-linearity; better gradient flow than ReLU |
| Cross-entropy | Equivalent to NLL for one-hot targets |
| AdamW | Adaptive learning rates + weight decay (see Chapter 7) |
Chapter 4 introduces the attention mechanism, which replaces the fixed-size context window with a dynamic, content-based weighting of all past tokens.