Chapter 11: Datasets — Loading, Preprocessing, and Synthetic Data

A model is only as good as its training data. In this chapter we take a deep dive into the data pipeline: how to download, inspect, preprocess, and efficiently feed the TinyStories dataset into our training loop. We also explore synthetic data generation — using an existing language model to create new training examples — which is how the TinyStories dataset itself was created.

PyTorch separates data handling into two abstractions: Dataset (maps indices to samples) and DataLoader (iterates over a Dataset in batches, with optional shuffling and parallel loading). This separation makes it easy to swap in different datasets without changing the training code.

For large datasets that don’t fit in RAM, we use the HuggingFace datasets library, which provides memory-mapped Arrow files, streaming mode, and a convenient preprocessing pipeline with .map() and .filter().

1. Explore TinyStories with the datasets Library

from datasets import load_dataset, DatasetDict
import os

# All downloaded/derived artefacts for this chapter live under data/
DATA_DIR = "data"
os.makedirs(DATA_DIR, exist_ok=True)

# Load the full TinyStories dataset (downloads once, caches locally)
print("Loading TinyStories …")
ds = load_dataset("roneneldan/TinyStories")

# DatasetDict with 'train' and 'validation' splits; each row has a "text" field
print(ds)
print(f"\nTrain examples : {len(ds['train']):,}")
print(f"Val   examples : {len(ds['validation']):,}")
print(f"\nFirst story:\n{ds['train'][0]['text'][:300]}")

2. Dataset Statistics

import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
import numpy as np

# Character-length statistics over a sample of the training split
print("Computing story length statistics …")
sample_size = 10_000
stories = (ds["train"][i]["text"] for i in range(sample_size))
lengths = [len(story) for story in stories]

# Summarise central tendency and extremes of the sampled lengths
mean_len   = np.mean(lengths)
median_len = np.median(lengths)
longest    = max(lengths)
shortest   = min(lengths)

print(f"Mean   length : {mean_len:,.0f} chars")
print(f"Median length : {median_len:,.0f} chars")
print(f"Max    length : {longest:,} chars")
print(f"Min    length : {shortest:,} chars")

# Histogram of the length distribution, saved to disk (Agg backend, no display)
plt.figure(figsize=(9, 4))
plt.hist(lengths, bins=60, color="steelblue", edgecolor="white")
plt.xlabel("Story length (characters)")
plt.ylabel("Count")
plt.title(f"TinyStories length distribution (n={sample_size:,})")
plt.tight_layout()
plt.savefig("data/ch11_story_lengths.png", dpi=100)
print("Saved → data/ch11_story_lengths.png")

3. Preprocessing Pipeline

from transformers import AutoTokenizer

# Use the GPT-2 tokeniser (50,257-token byte-level BPE vocabulary)
tokenizer = AutoTokenizer.from_pretrained("gpt2")
# GPT-2 ships without a pad token; reuse EOS so padding code paths work
tokenizer.pad_token = tokenizer.eos_token   # GPT-2 has no pad token by default

BLOCK_SIZE = 512   # number of tokens per training example

def tokenize_and_chunk(examples: dict) -> dict:
    """
    Tokenise a batch of stories and concatenate them into fixed-length blocks.

    Stories are joined into one long token stream with an EOS token after
    each story, then split into non-overlapping BLOCK_SIZE-token chunks.
    Any trailing remainder shorter than BLOCK_SIZE is discarded.

    Args:
        examples: HF ``datasets`` batch dict with a "text" list of strings.

    Returns:
        dict with "input_ids": a list of BLOCK_SIZE-long token-id lists
        (row count may differ from the input row count — allowed for
        batched ``.map()``).
    """
    # Tokenise all texts in the batch (no truncation: we chunk ourselves)
    tokens = tokenizer(
        examples["text"],
        add_special_tokens=True,
        truncation=False,
    )["input_ids"]

    # Concatenate all token lists, separating stories with EOS
    flat = []
    for ids in tokens:
        flat.extend(ids)
        flat.append(tokenizer.eos_token_id)

    # Split into fixed-length chunks. Bug fix: the stop bound needs "+ 1" —
    # without it, range() drops the final chunk even when len(flat) is an
    # exact multiple of BLOCK_SIZE, silently losing training data.
    input_ids = [flat[i : i + BLOCK_SIZE]
                 for i in range(0, len(flat) - BLOCK_SIZE + 1, BLOCK_SIZE)]

    return {"input_ids": input_ids}

# Apply preprocessing (use num_proc for parallelism on larger datasets)
print("Tokenising and chunking …")
# Demo on a 5 000-story subset to keep runtime short
small_ds = ds["train"].select(range(5_000))
tokenised = small_ds.map(
    tokenize_and_chunk,
    batched=True,
    batch_size=500,           # stories handed to tokenize_and_chunk per call
    remove_columns=["text"],  # drop raw text; only token blocks remain
    desc="Tokenising",
)
# A batched map may return a different number of rows than it received,
# so 'tokenised' is a dataset of fixed-length blocks, not of stories.
print(f"Tokenised blocks: {len(tokenised):,}")
print(f"Block shape example: {len(tokenised[0]['input_ids'])} tokens")

4. Custom PyTorch Dataset Class

import torch
from torch.utils.data import Dataset, DataLoader

class TinyStoriesDataset(Dataset):
    """
    Map-style Dataset producing (context, target) tensor pairs for causal
    language modelling; the target is the input shifted left by one token.
    """

    def __init__(self, hf_dataset, block_size: int = 512):
        # Underlying indexable dataset whose rows carry an "input_ids" list
        self.data = hf_dataset
        self.block_size = block_size

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, idx: int):
        # Tensorise once, then slice: context drops the last token,
        # targets drop the first (shift-by-one for next-token prediction).
        token_ids = torch.tensor(self.data[idx]["input_ids"], dtype=torch.long)
        return token_ids[:-1], token_ids[1:]


# Instantiate the map-style dataset over the tokenised blocks
train_dataset = TinyStoriesDataset(tokenised)

# Split 90/10 for demo purposes.
# Fix: seed the split so the train/val partition is reproducible — an
# unseeded random_split produces a different partition every run, which
# makes validation numbers incomparable between runs.
n_val  = len(train_dataset) // 10
n_train = len(train_dataset) - n_val
split_gen = torch.Generator().manual_seed(42)
train_set, val_set = torch.utils.data.random_split(
    train_dataset, [n_train, n_val], generator=split_gen
)

print(f"Train blocks: {len(train_set):,}  |  Val blocks: {len(val_set):,}")

5. Efficient DataLoader

def collate_fn(batch):
    """
    Collate (input, target) pairs into padded batch tensors.

    Inputs are right-padded with the tokeniser's pad id; targets are
    padded with -100, which nn.CrossEntropyLoss ignores by default, so
    padded positions contribute no loss.

    Args:
        batch: list of (x, y) pairs of 1-D LongTensors from the Dataset.

    Returns:
        (xs_padded, ys_padded): two (batch, max_len) LongTensors.
    """
    xs, ys = zip(*batch)
    # pad_sequence performs the ragged-to-rectangular copy in one call,
    # replacing the hand-rolled full/cat/stack loop of the original
    # (which also relied on torch.full's inferred dtype).
    xs_padded = torch.nn.utils.rnn.pad_sequence(
        xs, batch_first=True, padding_value=tokenizer.pad_token_id
    )
    ys_padded = torch.nn.utils.rnn.pad_sequence(
        ys, batch_first=True, padding_value=-100
    )
    return xs_padded, ys_padded   # -100 is ignored by cross-entropy

# Fix: collate_fn (defined above) was never passed to the loaders, so the
# default collation ran instead — wire it in so padding/-100 masking applies.
train_loader = DataLoader(
    train_set,
    batch_size  = 8,
    shuffle     = True,
    num_workers = 2,             # parallel data loading
    pin_memory  = True,          # faster CPU→GPU transfer
    drop_last   = True,          # ensure fixed batch sizes
    collate_fn  = collate_fn,    # pad inputs with pad_id, targets with -100
)

val_loader = DataLoader(
    val_set, batch_size=8, shuffle=False, num_workers=2, collate_fn=collate_fn
)

# Inspect a batch: both tensors are (batch, seq_len) int64
xb, yb = next(iter(train_loader))
print(f"Batch x shape: {xb.shape}  dtype: {xb.dtype}")
print(f"Batch y shape: {yb.shape}  dtype: {yb.dtype}")
print(f"Sample tokens: {tokenizer.decode(xb[0][:30].tolist())!r}")

6. Streaming Dataset (for Very Large Corpora)

from datasets import load_dataset

# Streaming mode: data is downloaded and processed on-the-fly
# Useful when the full dataset doesn't fit on disk or in RAM
stream_ds = load_dataset("roneneldan/TinyStories", split="train", streaming=True)
# Approximate shuffle: maintains a 10 000-example buffer and samples from it
stream_ds = stream_ds.shuffle(seed=42, buffer_size=10_000)

# Wrap in a PyTorch IterableDataset for use with DataLoader
from torch.utils.data import IterableDataset

class StreamingTinyStories(IterableDataset):
    """
    Wrap a streaming HF dataset as an IterableDataset yielding
    (input, target) LongTensor pairs for causal LM training.

    Tokens from consecutive stories are packed into one EOS-separated
    stream and emitted in block_size-token windows; the target sequence
    is the window shifted by one token.
    """

    def __init__(self, hf_iterable_dataset, tokenizer, block_size=512):
        self.ds         = hf_iterable_dataset
        self.tokenizer  = tokenizer
        self.block_size = block_size
        # Retained for backward compatibility; iteration now uses a local
        # buffer (see __iter__) and never mutates instance state.
        self.buffer     = []

    def __iter__(self):
        # Bug fix: the original accumulated tokens in self.buffer, so a
        # second pass over the dataset (or two concurrent iterators, e.g.
        # DataLoader workers) started from stale leftover tokens. A fresh
        # local buffer makes every iteration independent.
        buffer = []
        eos = self.tokenizer.eos_token_id
        for example in self.ds:
            buffer.extend(self.tokenizer.encode(example["text"]))
            buffer.append(eos)
            # Emit every complete window of block_size+1 tokens; windows
            # overlap by one token so every token appears once as a target.
            while len(buffer) >= self.block_size + 1:
                chunk = buffer[: self.block_size + 1]
                buffer = buffer[self.block_size :]
                x = torch.tensor(chunk[:-1], dtype=torch.long)
                y = torch.tensor(chunk[1:],  dtype=torch.long)
                yield x, y

# Smaller block size (256) to keep the streaming demo fast
stream_dataset = StreamingTinyStories(stream_ds, tokenizer, block_size=256)
# NOTE(review): num_workers is left at the default 0 — with >0 workers each
# worker would replay the same stream and duplicate data; confirm before scaling.
stream_loader  = DataLoader(stream_dataset, batch_size=4)

# Get one batch from the stream
xb_stream, yb_stream = next(iter(stream_loader))
print(f"Streaming batch shape: {xb_stream.shape}")

7. Synthetic Story Generation

from transformers import pipeline

print("Loading GPT-2 for synthetic story generation …")
# device=-1 forces CPU inference; fine for a handful of short generations
generator = pipeline("text-generation", model="gpt2", device=-1)   # -1 = CPU

# Prompts for generating synthetic children's stories
prompts = [
    "Once upon a time, a little rabbit",
    "There was a small village where",
    "The friendly dragon wanted to",
]

synthetic_stories = []
for prompt in prompts:
    # Sampled decoding: temperature + top-k gives varied but coherent text
    output = generator(
        prompt,
        max_new_tokens = 100,
        do_sample      = True,
        temperature    = 0.8,
        top_k          = 50,
        num_return_sequences = 1,
    )
    story = output[0]["generated_text"]   # includes the prompt itself
    synthetic_stories.append(story)
    print(f"\nPrompt: {repr(prompt)}")
    print(f"Story : {story[:200]}")

# Save synthetic stories for fine-tuning augmentation.
# Fix: open with an explicit UTF-8 encoding — GPT-2 output can contain
# non-ASCII characters, and the platform-default encoding (e.g. cp1252
# on Windows) would raise UnicodeEncodeError on them.
synth_file = os.path.join(DATA_DIR, "synthetic_stories.txt")
with open(synth_file, "w", encoding="utf-8") as f:
    for story in synthetic_stories:
        f.write(story.strip() + "\n")
print(f"\nSaved {len(synthetic_stories)} synthetic stories → {synth_file}")

8. Data Quality Filtering

import re

def is_valid_story(text: str) -> bool:
    """
    Basic quality filter for TinyStories examples.

    Accepts a story only if it is 100–5000 characters long, contains at
    least one sentence-ending punctuation mark, and is at least 80% ASCII.
    """
    n_chars = len(text)
    # Length gate: reject anything outside the 100–5000 character window
    if not (100 <= n_chars <= 5000):
        return False
    # Require at least one sentence-ending punctuation mark
    if re.search(r"[.!?]", text) is None:
        return False
    # Garbled-text gate: at least 80% of characters must be ASCII
    ascii_fraction = sum(1 for ch in text if ch.isascii()) / n_chars
    return ascii_fraction >= 0.8

# Apply filter to the small dataset
before = len(small_ds)
# input_columns=["text"] passes just the story string to the predicate
filtered_ds = small_ds.filter(is_valid_story, input_columns=["text"])
after  = len(filtered_ds)
print(f"Before filtering: {before:,}  |  After: {after:,}  |  Removed: {before-after:,}")

9. Summary

Tool — Purpose
load_dataset() — Download and cache HuggingFace datasets
.map() — Vectorised, parallel preprocessing
.filter() — Remove low-quality examples
Dataset + DataLoader — Batched, shuffled, parallel PyTorch data loading
IterableDataset — Streaming from disk/network, no full RAM needed
pipeline("text-generation") — Quick synthetic data generation

Chapter 12 introduces the KV cache — the inference-time optimisation that makes autoregressive generation practical.