Chapter 11: Datasets — Loading, Preprocessing, and Synthetic Data
A model is only as good as its training data. In this chapter we take a deep dive into the data pipeline: how to download, inspect, preprocess, and efficiently feed the TinyStories dataset into our training loop. We also explore synthetic data generation — using an existing language model to create new training examples — which is how the TinyStories dataset itself was created.
PyTorch separates data handling into two abstractions: Dataset (maps indices to samples) and DataLoader (iterates over a Dataset in batches, with optional shuffling and parallel loading). This separation makes it easy to swap in different datasets without changing the training code.
For large datasets that don’t fit in RAM, we use the HuggingFace datasets library, which provides memory-mapped Arrow files, streaming mode, and a convenient preprocessing pipeline with .map() and .filter().
1. Explore TinyStories with the datasets Library
# --- Setup: download and inspect the TinyStories dataset ---
from datasets import load_dataset, DatasetDict
import os

# All chapter artefacts (plots, generated text) are written under this directory.
DATA_DIR = "data"
os.makedirs(DATA_DIR, exist_ok=True)

# Load the full TinyStories dataset (downloads once, caches locally).
# Returns a DatasetDict with "train" and "validation" splits.
print("Loading TinyStories …")
ds = load_dataset("roneneldan/TinyStories")
print(ds)
print(f"\nTrain examples : {len(ds['train']):,}")
print(f"Val examples : {len(ds['validation']):,}")
# Peek at the first 300 characters of the first training story.
print(f"\nFirst story:\n{ds['train'][0]['text'][:300]}")
2. Dataset Statistics
# --- Dataset statistics: distribution of story lengths ---
import matplotlib
matplotlib.use("Agg")  # headless backend: render straight to file, no display
import matplotlib.pyplot as plt
import numpy as np

# Character counts for the first `sample_size` training stories.
print("Computing story length statistics …")
sample_size = 10_000
char_counts = np.array([len(ds["train"][i]["text"]) for i in range(sample_size)])

print(f"Mean length : {char_counts.mean():,.0f} chars")
print(f"Median length : {np.median(char_counts):,.0f} chars")
print(f"Max length : {char_counts.max():,} chars")
print(f"Min length : {char_counts.min():,} chars")

# Histogram of the sampled lengths, saved as a PNG artefact.
fig, ax = plt.subplots(figsize=(9, 4))
ax.hist(char_counts, bins=60, color="steelblue", edgecolor="white")
ax.set_xlabel("Story length (characters)")
ax.set_ylabel("Count")
ax.set_title(f"TinyStories length distribution (n={sample_size:,})")
fig.tight_layout()
fig.savefig("data/ch11_story_lengths.png", dpi=100)
print("Saved → data/ch11_story_lengths.png")
3. Preprocessing Pipeline
# --- Tokeniser setup ---
from transformers import AutoTokenizer

# Use the GPT-2 byte-level BPE tokeniser (50,257-token vocabulary).
tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token  # GPT-2 has no pad token by default
BLOCK_SIZE = 512  # number of tokens per training example
def tokenize_and_chunk(examples: dict, *, tok=None, block_size=None) -> dict:
    """
    Tokenise a batch of stories and concatenate them into fixed-length blocks.

    An EOS token is appended after every story so block boundaries never fuse
    two stories into one uninterrupted run of text.

    Args:
        examples: batch dict from ``datasets.map`` containing a "text" column.
        tok: tokenizer override (defaults to the module-level ``tokenizer``).
        block_size: block length override (defaults to ``BLOCK_SIZE``).

    Returns:
        dict with key "input_ids": a list of equal-length token-id blocks.
        Any trailing partial block (< block_size tokens) is discarded.
    """
    tok = tokenizer if tok is None else tok
    block_size = BLOCK_SIZE if block_size is None else block_size

    # Tokenise all texts in the batch (no truncation: we chunk ourselves).
    tokens = tok(
        examples["text"],
        add_special_tokens=True,
        truncation=False,
    )["input_ids"]

    # Concatenate all token lists, separated by EOS.
    flat = []
    for ids in tokens:
        flat.extend(ids + [tok.eos_token_id])

    # Keep only complete blocks.
    # FIX: the previous upper bound `len(flat) - block_size` dropped the final
    # block even when it was complete (e.g. len(flat) == 2 * block_size yielded
    # one block instead of two).
    n_full = (len(flat) // block_size) * block_size
    input_ids = [flat[i : i + block_size] for i in range(0, n_full, block_size)]
    return {"input_ids": input_ids}
# Apply preprocessing (use num_proc for parallelism on larger datasets).
# Demo runs on a 5k-story subset; swap in ds["train"] for the full corpus.
print("Tokenising and chunking …")
small_ds = ds["train"].select(range(5_000))
tokenised = small_ds.map(
    tokenize_and_chunk,
    batched=True,          # fn receives a dict of up to 500 stories at once
    batch_size=500,
    remove_columns=["text"],  # drop raw text; keep only the token blocks
    desc="Tokenising",
)
print(f"Tokenised blocks: {len(tokenised):,}")
print(f"Block shape example: {len(tokenised[0]['input_ids'])} tokens")
4. Custom PyTorch Dataset Class
import torch
from torch.utils.data import Dataset, DataLoader
class TinyStoriesDataset(Dataset):
    """
    Map-style dataset yielding (context, target) tensor pairs for causal
    language modelling; the target is the context shifted left by one token.
    """

    def __init__(self, hf_dataset, block_size: int = 512):
        self.data = hf_dataset
        self.block_size = block_size

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, idx: int):
        token_ids = self.data[idx]["input_ids"]
        block = torch.tensor(token_ids, dtype=torch.long)
        # Context is every token but the last; target is every token but
        # the first — the standard next-token-prediction alignment.
        return block[:-1], block[1:]
# Instantiate train and validation datasets.
train_dataset = TinyStoriesDataset(tokenised)
# Split 90/10 for demo purposes.
n_val = len(train_dataset) // 10
n_train = len(train_dataset) - n_val
# NOTE(review): random_split without a fixed generator produces a different
# split on every run — pass generator=torch.Generator().manual_seed(...) if
# reproducibility matters.
train_set, val_set = torch.utils.data.random_split(train_dataset, [n_train, n_val])
print(f"Train blocks: {len(train_set):,} | Val blocks: {len(val_set):,}")
5. Efficient DataLoader
def collate_fn(batch, pad_token_id=None):
    """
    Collate (input, target) pairs into padded batch tensors.

    Inputs are right-padded with the pad id; targets are right-padded with
    -100, the index PyTorch's cross-entropy ignores, so padding never
    contributes to the loss.

    Args:
        batch: list of (input_ids, target_ids) pairs of 1-D long tensors.
        pad_token_id: padding id override; defaults to the module-level
            ``tokenizer.pad_token_id`` (keeps the original behaviour).

    Returns:
        (xs_padded, ys_padded) long tensors of shape (batch, max_len).
    """
    if pad_token_id is None:
        pad_token_id = tokenizer.pad_token_id
    xs, ys = zip(*batch)
    max_len = max(x.size(0) for x in xs)
    # dtype=torch.long is explicit: older torch versions made torch.full
    # return float, which breaks the cat with the long input tensors.
    xs_padded = torch.stack([
        torch.cat([x, torch.full((max_len - x.size(0),), pad_token_id, dtype=torch.long)])
        for x in xs
    ])
    ys_padded = torch.stack([
        torch.cat([y, torch.full((max_len - y.size(0),), -100, dtype=torch.long)])
        for y in ys
    ])
    return xs_padded, ys_padded  # -100 is ignored by cross-entropy
# DataLoaders: batched, shuffled, parallel loading.
# FIX: collate_fn was defined above but never passed to either loader, so the
# default collate silently ran instead. With the fixed-length blocks produced
# here the result is identical, but wiring it in restores the intended
# padding behaviour for any variable-length data.
train_loader = DataLoader(
    train_set,
    batch_size=8,
    shuffle=True,
    num_workers=2,    # parallel data loading
    pin_memory=True,  # faster CPU→GPU transfer
    drop_last=True,   # ensure fixed batch sizes
    collate_fn=collate_fn,
)
val_loader = DataLoader(
    val_set, batch_size=8, shuffle=False, num_workers=2, collate_fn=collate_fn
)

# Inspect a batch.
xb, yb = next(iter(train_loader))
print(f"Batch x shape: {xb.shape} dtype: {xb.dtype}")
print(f"Batch y shape: {yb.shape} dtype: {yb.dtype}")
print(f"Sample tokens: {tokenizer.decode(xb[0][:30].tolist())!r}")
6. Streaming Dataset (for Very Large Corpora)
from datasets import load_dataset
# Streaming mode: data is downloaded and processed on-the-fly.
# Useful when the full dataset doesn't fit on disk or in RAM.
stream_ds = load_dataset("roneneldan/TinyStories", split="train", streaming=True)
# Approximate shuffling: examples are drawn from a 10k-element buffer
# (exact global shuffling is impossible on a stream).
stream_ds = stream_ds.shuffle(seed=42, buffer_size=10_000)
# Wrap in a PyTorch IterableDataset for use with DataLoader
from torch.utils.data import IterableDataset
class StreamingTinyStories(IterableDataset):
    """
    Iterable dataset that tokenises a streamed text corpus on the fly and
    yields (context, target) pairs of length ``block_size``.

    Tokens from consecutive stories are pooled in a running buffer
    (EOS-separated); whenever the buffer holds at least ``block_size + 1``
    tokens, one pair is emitted — the extra token supplies the shifted target.
    """

    def __init__(self, hf_iterable_dataset, tokenizer, block_size=512):
        self.ds = hf_iterable_dataset
        self.tokenizer = tokenizer
        self.block_size = block_size
        self.buffer = []

    def __iter__(self):
        window = self.block_size + 1  # one extra token for the shifted target
        for example in self.ds:
            encoded = self.tokenizer.encode(example["text"])
            self.buffer.extend(encoded)
            self.buffer.append(self.tokenizer.eos_token_id)
            while len(self.buffer) >= window:
                chunk = torch.tensor(self.buffer[:window], dtype=torch.long)
                # Advance by block_size only: the last token of this chunk is
                # reused as the first token of the next one, matching the
                # shift-by-one target setup.
                del self.buffer[: self.block_size]
                yield chunk[:-1], chunk[1:]
# Wrap the stream in a DataLoader. NOTE(review): num_workers is left at 0 on
# purpose — with multiple workers each worker would re-iterate the full
# stream, duplicating data unless sharding is handled manually.
stream_dataset = StreamingTinyStories(stream_ds, tokenizer, block_size=256)
stream_loader = DataLoader(stream_dataset, batch_size=4)
# Get one batch from the stream.
xb_stream, yb_stream = next(iter(stream_loader))
print(f"Streaming batch shape: {xb_stream.shape}")
7. Synthetic Story Generation
from transformers import pipeline

# --- Synthetic story generation with an off-the-shelf GPT-2 ---
print("Loading GPT-2 for synthetic story generation …")
generator = pipeline("text-generation", model="gpt2", device=-1)  # -1 = CPU

# Prompts for generating synthetic children's stories.
prompts = [
    "Once upon a time, a little rabbit",
    "There was a small village where",
    "The friendly dragon wanted to",
]
synthetic_stories = []
for prompt in prompts:
    # Sampled decoding: temperature 0.8 + top-k 50 trades diversity
    # against coherence.
    output = generator(
        prompt,
        max_new_tokens=100,
        do_sample=True,
        temperature=0.8,
        top_k=50,
        num_return_sequences=1,
    )
    story = output[0]["generated_text"]  # includes the prompt as a prefix
    synthetic_stories.append(story)
    print(f"\nPrompt: {repr(prompt)}")
    print(f"Story : {story[:200]}")
# Save synthetic stories for fine-tuning augmentation.
synth_file = os.path.join(DATA_DIR, "synthetic_stories.txt")
# FIX: explicit encoding — open() without it uses the platform default
# (e.g. cp1252 on Windows), which can crash on non-ASCII model output.
with open(synth_file, "w", encoding="utf-8") as f:
    for story in synthetic_stories:
        f.write(story.strip() + "\n")
print(f"\nSaved {len(synthetic_stories)} synthetic stories → {synth_file}")
8. Data Quality Filtering
import re
def is_valid_story(text: str) -> bool:
    """
    Heuristic quality filter for TinyStories examples.

    A story passes when it is 100–5000 characters long, contains at least
    one sentence-ending punctuation mark, and is at least 80% ASCII.
    """
    n_chars = len(text)
    # Length gate: discard fragments and runaway texts.
    if not 100 <= n_chars <= 5000:
        return False
    # Must contain some sentence-ending punctuation.
    if not any(ch in ".!?" for ch in text):
        return False
    # Reject mostly non-ASCII (garbled / wrong-language) text.
    ascii_ratio = sum(ch.isascii() for ch in text) / n_chars
    return ascii_ratio >= 0.8
# Apply the quality filter to the small dataset.
before = len(small_ds)
# input_columns=["text"] passes just the text string to the predicate,
# instead of the whole example dict.
filtered_ds = small_ds.filter(is_valid_story, input_columns=["text"])
after = len(filtered_ds)
print(f"Before filtering: {before:,} | After: {after:,} | Removed: {before-after:,}")
9. Summary
| Tool | Purpose |
|---|---|
| `load_dataset()` | Download and cache HuggingFace datasets |
| `.map()` | Vectorised, parallel preprocessing |
| `.filter()` | Remove low-quality examples |
| `Dataset` + `DataLoader` | Batched, shuffled, parallel PyTorch data loading |
| `IterableDataset` | Streaming from disk/network — no full RAM needed |
| `pipeline("text-generation")` | Quick synthetic data generation |
Chapter 12 introduces the KV cache — the inference-time optimisation that makes autoregressive generation practical.