Chapter 01: Bigram Language Model

A language model assigns probabilities to sequences of tokens. The simplest possible language model is the bigram model: it predicts the next token based solely on the current token. Despite its simplicity, the bigram model captures real statistical patterns in text and serves as the perfect starting point for understanding everything that follows in this course.

In this chapter we work at the character level — each token is a single character. This keeps the vocabulary tiny (around 65 characters for English text) so the entire probability table fits in memory and we can reason about it directly. Later chapters will move to subword tokenisation (BPE) and much larger models, but the core ideas remain identical.

We will use the TinyStories dataset — a collection of short children’s stories generated by GPT-4 and GPT-3.5. Every chapter in this course uses TinyStories as its training corpus. We download it once here and save it locally so subsequent chapters can load it without hitting the network again.

By the end of this chapter you will be able to count bigram frequencies, convert counts to probabilities, sample novel text from the distribution, and measure model quality with negative log-likelihood (NLL) loss — the same loss function used to train billion-parameter LLMs.

1. Download TinyStories

import os
from datasets import load_dataset

DATA_DIR = "data"
TRAIN_FILE = os.path.join(DATA_DIR, "tinystories_train.txt")
VAL_FILE   = os.path.join(DATA_DIR, "tinystories_val.txt")

os.makedirs(DATA_DIR, exist_ok=True)


def _save_story_slice(split: str, max_stories: int, out_path: str) -> None:
    """Stream the first ``max_stories`` stories of a TinyStories split to disk.

    Each story is stripped of surrounding whitespace and written followed by
    a single newline.

    The text is written to a temporary ``.part`` file that is renamed onto
    ``out_path`` only after the loop has finished. The callers below skip the
    download whenever ``out_path`` exists, so writing in place would let an
    interrupted run leave a truncated file that is never repaired.

    Args:
        split: Dataset split name passed to ``load_dataset``.
        max_stories: Maximum number of stories to write.
        out_path: Destination text file (UTF-8).
    """
    # streaming=True avoids loading all 2M stories into RAM at once
    ds = load_dataset("roneneldan/TinyStories", split=split, streaming=True)

    tmp_path = out_path + ".part"
    with open(tmp_path, "w", encoding="utf-8") as f:
        for i, example in enumerate(ds):
            if i >= max_stories:
                break
            f.write(example["text"].strip() + "\n")

    # Publish the finished file in a single rename step.
    os.replace(tmp_path, out_path)


if not os.path.exists(TRAIN_FILE):
    print("Downloading TinyStories …")
    # We take a manageable slice (50 000 stories) for local experiments.
    # Raise the limit to keep more of the dataset (≈ 2 GB in full).
    _save_story_slice("train", 50_000, TRAIN_FILE)
    print(f"Saved training slice → {TRAIN_FILE}")
else:
    print(f"Training file already exists: {TRAIN_FILE}")

if not os.path.exists(VAL_FILE):
    _save_story_slice("validation", 5_000, VAL_FILE)
    print(f"Saved validation slice → {VAL_FILE}")
else:
    print(f"Validation file already exists: {VAL_FILE}")

2. Build the Character Vocabulary

# Load the whole training corpus into memory as a single string
with open(TRAIN_FILE, "r", encoding="utf-8") as f:
    text = f.read()

print(f"Corpus length: {len(text):,} characters")

# The vocabulary is the sorted list of distinct characters in the corpus
chars = sorted(set(text))
vocab_size = len(chars)
print(f"Vocabulary size: {vocab_size}")
print(f"Characters: {''.join(chars[:40])} …")

# Two lookup tables that are exact inverses of each other
itos = dict(enumerate(chars))                  # index → string
stoi = {ch: i for i, ch in itos.items()}       # string → index

def encode(s: str) -> list[int]:
    """Return the integer id of every character in ``s``, in order.

    Ids are looked up in the module-level ``stoi`` table at call time; a
    character that is missing from the table raises ``KeyError``.
    """
    token_ids = [stoi[symbol] for symbol in s]
    return token_ids

def decode(ids: list[int]) -> str:
    """Turn a list of integer ids back into the string they spell.

    Ids are looked up in the module-level ``itos`` table at call time; an id
    that is missing from the table raises ``KeyError``.
    """
    pieces = [itos[token_id] for token_id in ids]
    return "".join(pieces)

# Sanity check: encoding a string and decoding the ids must give it back
sample = "Once upon a time"
round_trip = decode(encode(sample))
assert round_trip == sample
print("Encode/decode round-trip: OK")

3. Count Bigram Frequencies

import torch

from collections import Counter

# We use '.' as the special start/end token and pin it to index 0.
# Rebuild stoi/itos with '.' prepended (and removed from its sorted position).
# NOTE: '.' is also the literal full stop, so a period inside a story and a
# story boundary share the same index.
chars_with_special = ['.'] + [c for c in chars if c != '.']
stoi = {ch: i for i, ch in enumerate(chars_with_special)}
itos = {i: ch for ch, i in stoi.items()}
vocab_size = len(chars_with_special)

# Tally (current, next) index pairs in a Counter first. Incrementing single
# tensor elements once per bigram is far slower than counting hashable pairs.
pair_counts = Counter()

# Every non-empty line of the training file is treated as one sequence
with open(TRAIN_FILE, "r", encoding="utf-8") as f:
    for line in f:
        story = line.strip()
        if not story:
            continue
        # Wrap each story with start/end sentinel '.'
        tokens = [stoi['.']] + [stoi[c] for c in story] + [stoi['.']]
        pair_counts.update(zip(tokens, tokens[1:]))

# N[i, j] = number of times character i is followed by character j.
# Allocated only now, once vocab_size reflects the sentinel-aware vocabulary.
N = torch.zeros((vocab_size, vocab_size), dtype=torch.int32)
for (ch1, ch2), count in pair_counts.items():
    N[ch1, ch2] = count

print(f"Bigram count matrix shape: {N.shape}")
print(f"Total bigrams counted: {N.sum().item():,}")

4. Compute Bigram Probabilities

# Laplace (add-one) smoothing: every bigram gets one extra count before
# normalising, so no entry of P is exactly zero and log P is always finite.
smoothed_counts = (N + 1).float()
row_totals = smoothed_counts.sum(dim=1, keepdim=True)
P = smoothed_counts / row_totals  # each row is a distribution over the next char

print("Row sums (should all be 1.0):", P.sum(dim=1)[:5])

5. Sample from the Model

import random

def generate(P, itos, stoi, max_chars: int = 200, seed: int = 42) -> str:
    """Draw one piece of text from the bigram table, one character at a time.

    Sampling starts from the '.' sentinel and repeatedly draws the next index
    from the row of ``P`` belonging to the current index. It stops when '.'
    is drawn or after ``max_chars`` draws; the sentinel itself is never part
    of the returned string.

    Args:
        P: Table whose row ``i`` holds the next-character probabilities for
            index ``i``.
        itos: Mapping from index to character.
        stoi: Mapping from character to index; must contain '.'.
        max_chars: Maximum number of draws.
        seed: Passed to ``torch.manual_seed`` before sampling, so the same
            seed reproduces the same text.

    Returns:
        The sampled characters joined into a single string.
    """
    torch.manual_seed(seed)
    current = stoi['.']      # every sample begins at the sentinel
    pieces = []
    draws = 0

    while draws < max_chars:
        draws += 1
        # Row `current` of P is the distribution over the following character
        current = torch.multinomial(P[current], num_samples=1).item()
        next_char = itos[current]
        if next_char == '.':  # sentinel drawn: the sample is complete
            break
        pieces.append(next_char)

    return "".join(pieces)

# Print three samples; the loop index doubles as the random seed
for seed in range(3):
    print(f"--- Sample {seed + 1} ---")
    print(generate(P, itos, stoi, seed=seed), end="\n\n")

6. Compute Negative Log-Likelihood Loss

def compute_nll(P, stoi, filepath: str, max_chars: int = 100_000) -> float:
    """
    Negative log-likelihood: average -log P(next_char | current_char).
    Lower is better. A uniform model over V chars gives NLL = log(V).

    Every non-empty line of ``filepath`` is stripped, wrapped in the '.'
    sentinel on both sides, and scored bigram by bigram. Characters that are
    missing from ``stoi`` are scored as the sentinel.

    Args:
        P: 2-D torch tensor of bigram probabilities, indexed [current, next].
        stoi: Mapping from character to index; must contain '.'.
        filepath: UTF-8 text file to score.
        max_chars: Soft budget on scored bigrams. It is checked after each
            line, so the line that crosses the budget is still scored in full.

    Returns:
        The mean negative log-probability over the scored bigrams, or NaN
        when the file contains no non-empty line.
    """
    boundary = stoi['.']
    # Take the log of the whole table once instead of once per bigram, and
    # convert to nested Python lists so each lookup is a plain float read.
    log_P = torch.log(P).tolist()

    total_log_prob = 0.0
    n = 0  # number of bigrams scored so far

    with open(filepath, "r", encoding="utf-8") as f:
        for line in f:
            story = line.strip()
            if not story:
                continue
            tokens = [boundary] + [stoi.get(c, boundary) for c in story] + [boundary]
            for ch1, ch2 in zip(tokens, tokens[1:]):
                total_log_prob += log_P[ch1][ch2]
                n += 1
            if n >= max_chars:
                break

    if n == 0:
        # Nothing was scored; the average is undefined rather than a crash.
        return float("nan")
    return -total_log_prob / n

# Score both corpora with the default budget, so each figure is computed on
# the leading portion of its file rather than on the whole file.
train_nll, val_nll = (
    compute_nll(P, stoi, corpus_path) for corpus_path in (TRAIN_FILE, VAL_FILE)
)

print(f"Train NLL: {train_nll:.4f}")
print(f"Val   NLL: {val_nll:.4f}")
# Reference point: a model that spreads probability evenly scores log(V)
print(f"Baseline (uniform): {torch.log(torch.tensor(float(vocab_size))).item():.4f}")

7. Visualise the Most Common Bigrams

import matplotlib
matplotlib.use("Agg")          # no display required
import matplotlib.pyplot as plt

# Rank every cell of the count matrix and keep the 20 largest
counts_flat = N.view(-1)
top_indices = counts_flat.argsort(descending=True)[:20]

labels = []
values = []
for flat_idx in top_indices.tolist():
    # A flat index splits into (row, column) = (first char, second char)
    i, j = divmod(flat_idx, vocab_size)
    labels.append(f"{repr(itos[i])}{repr(itos[j])}")
    values.append(counts_flat[flat_idx].item())

# One bar per bigram, labelled with the repr of both characters so that
# whitespace characters stay visible on the axis
bar_positions = range(len(labels))
fig, ax = plt.subplots(figsize=(12, 5))
ax.bar(bar_positions, values)
ax.set_xticks(bar_positions)
ax.set_xticklabels(labels, rotation=45, ha="right")
ax.set_title("Top-20 Most Frequent Bigrams (TinyStories)")
ax.set_ylabel("Count")
plt.tight_layout()
plt.savefig("data/bigram_frequencies.png", dpi=100)
print("Saved plot → data/bigram_frequencies.png")

8. Summary

| Concept | Takeaway |
| --- | --- |
| Bigram model | $P(\text{next} \mid \text{current})$ — a simple but real language model |
| Laplace smoothing | Prevents zero probabilities; essential for log-likelihood |
| NLL loss | $-\frac{1}{N}\sum_t \log P(x_t \mid x_{t-1})$ — the universal LM objective |
| Sampling | `torch.multinomial` draws tokens from a learned distribution |

In Chapter 2 we will build micrograd, a tiny autodiff engine, which will let us move from hand-computed count tables to gradient-based learning.