Chapter 12: Inference I — KV Cache

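During autoregressive generation, the model predicts one token at a time and appends it to the sequence. Naively, each new token requires re-running the model over the full sequence — recomputing the K/V projections and the entire attention matrix for all the tokens we already processed. Step t therefore repeats work proportional to t (and to t² for the attention scores), so the total cost grows at least quadratically with the number of tokens generated, making long sequences prohibitively slow.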

The Key-Value (KV) cache solves this: we cache the K and V projections for every past token. When we generate the next token, we only compute Q, K, V for the new token, then concatenate its K and V with the cached versions to compute attention. The projection and feed-forward cost per step becomes constant. Attention shrinks from a full t × t recomputation to a single query against t cached keys, so its per-step cost grows only linearly with sequence length.

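The memory cost is 2 × B × N × L × H × d elements: one K and one V tensor per layer, where B is batch size, N is the number of layers, L is sequence length, H is number of heads, and d is head dimension. For a LLaMA-7B-sized model (32 layers, 32 heads, head dimension 128) in FP16, that is 2 × 32 × 32 × 128 × 2 bytes ≈ 0.5 MB per token. So 1,024 tokens need about 0.54 GB of VRAM and a 4,096-token context about 2.1 GB — a significant but manageable overhead.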

1. Attention Without KV Cache (baseline)

import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import time

class AttentionNoCache(nn.Module):
    """
    Baseline multi-head causal self-attention that keeps no state between calls.

    Every call projects the whole input to Q/K/V and builds the full T x T
    score matrix. Generating token t this way therefore redoes all of the work
    for tokens 0 .. t-1.
    """

    def __init__(self, d_model: int, n_heads: int):
        super().__init__()
        self.n_heads  = n_heads
        self.head_dim = d_model // n_heads
        # Fused projection: one matmul yields Q, K and V side by side.
        self.qkv      = nn.Linear(d_model, 3 * d_model, bias=False)
        self.out      = nn.Linear(d_model, d_model, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """x: (B, T, d_model) -> (B, T, d_model); position t attends to positions <= t."""
        batch, seq, width = x.shape

        # (B, T, 3*C) -> (3, B, H, T, hd) -> three (B, H, T, hd) tensors
        q, k, v = (
            self.qkv(x)
            .view(batch, seq, 3, self.n_heads, self.head_dim)
            .permute(2, 0, 3, 1, 4)
            .unbind(0)
        )

        logits = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.head_dim)
        # Strict upper triangle = keys lying in the future of each query row.
        future = torch.ones(seq, seq, device=x.device).triu(diagonal=1).bool()
        weights = logits.masked_fill(future, float("-inf")).softmax(dim=-1)

        mixed = torch.matmul(weights, v).transpose(1, 2).reshape(batch, seq, width)
        return self.out(mixed)

2. Attention With KV Cache

class AttentionWithCache(nn.Module):
    """
    Multi-head causal self-attention that can reuse keys/values across calls.

    Prefill: pass several tokens at once (T_new > 1). A causal mask stops each
    query from seeing later positions, including when earlier tokens are
    already cached. Decode: pass one token (T_new == 1). Its K/V are appended
    to the cached ones and the single query attends to all of them.
    The cache is a plain dict mapping ``layer_idx`` -> ``(K, V)``, each shaped
    (B, n_heads, T_cached, head_dim), and it is updated in place.
    """

    def __init__(self, d_model: int, n_heads: int):
        super().__init__()
        self.n_heads  = n_heads
        self.head_dim = d_model // n_heads
        self.d_model  = d_model
        # Fused projection producing Q, K and V in a single matmul.
        self.qkv      = nn.Linear(d_model, 3 * d_model, bias=False)
        self.out      = nn.Linear(d_model, d_model, bias=False)

    def forward(
        self,
        x: torch.Tensor,
        kv_cache: dict = None,
        layer_idx: int = 0,
    ) -> tuple[torch.Tensor, dict]:
        """
        x:         (B, T_new, d_model); T_new = prompt length at prefill, 1 when decoding
        kv_cache:  {layer_idx: (K, V)} or None. If a dict is given, the entry for
                   layer_idx is replaced by the extended (detached) K and V.
                   If None, nothing is cached.
        layer_idx: key under which this layer stores its K/V in kv_cache
        Returns:   (output of shape (B, T_new, d_model), the same kv_cache object)
        """
        batch, n_new, width = x.shape

        # (B, T_new, 3*C) -> three (B, H, T_new, hd) tensors for the new tokens only
        q, k_step, v_step = (
            self.qkv(x)
            .view(batch, n_new, 3, self.n_heads, self.head_dim)
            .permute(2, 0, 3, 1, 4)
            .unbind(0)
        )

        has_past = kv_cache is not None and layer_idx in kv_cache
        if has_past:
            k_past, v_past = kv_cache[layer_idx]
            keys = torch.cat([k_past, k_step], dim=2)      # (B, H, T_past+T_new, hd)
            values = torch.cat([v_past, v_step], dim=2)
        else:
            keys, values = k_step, v_step

        if kv_cache is not None:
            kv_cache[layer_idx] = (keys.detach(), values.detach())

        n_total = keys.shape[2]
        scores = torch.matmul(q, keys.transpose(-2, -1)) / math.sqrt(self.head_dim)

        if n_new > 1:
            # New query i sits at absolute position (n_total - n_new + i);
            # hide every key that comes after it.
            offset = n_total - n_new + 1
            future = torch.ones(n_new, n_total, device=x.device).triu(diagonal=offset).bool()
            scores = scores.masked_fill(future, float("-inf"))

        weights = scores.softmax(dim=-1)
        context = torch.matmul(weights, values).transpose(1, 2).reshape(batch, n_new, width)
        return self.out(context), kv_cache

3. GPT with KV Cache

class GPTWithCache(nn.Module):
    """Toy decoder-only language model built from ``AttentionWithCache`` layers.

    Layout: token embedding + learned position embedding, then ``n_layers``
    attention layers applied in sequence, a final LayerNorm, and an output head
    whose weight is tied to the token embedding. This demo stack has no
    residual connections and no MLP blocks. It exists to show how one
    ``kv_cache`` dict (keyed by layer index) threads through every layer.
    """

    def __init__(self, vocab_size: int, d_model: int, n_heads: int,
                 n_layers: int, max_len: int):
        super().__init__()
        self.token_emb = nn.Embedding(vocab_size, d_model)
        self.pos_emb   = nn.Embedding(max_len, d_model)   # one row per absolute position
        self.layers    = nn.ModuleList([
            AttentionWithCache(d_model, n_heads) for _ in range(n_layers)
        ])
        self.ln_f    = nn.LayerNorm(d_model)
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False)
        # Weight tying: the output projection reuses the embedding matrix.
        self.lm_head.weight = self.token_emb.weight

    def forward(self, idx: torch.Tensor,
                kv_cache: dict = None,
                start_pos: int = 0) -> tuple[torch.Tensor, dict]:
        """
        idx:       (B, T) token ids — T = prompt length at prefill, 1 while decoding
        kv_cache:  dict {layer_idx: (K, V)} shared by all layers, extended in place.
                   If None, a fresh dict is created and filled, so the return
                   value is always a dict.
        start_pos: absolute position of idx[:, 0]; when decoding it should equal
                   the number of tokens already cached so position embeddings line up.
        Returns:   (logits of shape (B, T, vocab_size), kv_cache)
        Raises:    IndexError if positions [start_pos, start_pos + T) fall outside
                   the position-embedding table (0 .. max_len - 1).
        """
        B, T = idx.shape
        max_len = self.pos_emb.num_embeddings
        # Fail early with a readable message instead of an opaque embedding
        # lookup failure (a device-side assert on CUDA).
        if start_pos < 0 or start_pos + T > max_len:
            raise IndexError(
                f"positions [{start_pos}, {start_pos + T}) are outside the "
                f"position table of size max_len={max_len}"
            )
        positions = torch.arange(start_pos, start_pos + T, device=idx.device)

        x = self.token_emb(idx) + self.pos_emb(positions)

        if kv_cache is None:
            kv_cache = {}

        for i, layer in enumerate(self.layers):
            x, kv_cache = layer(x, kv_cache=kv_cache, layer_idx=i)

        x      = self.ln_f(x)
        logits = self.lm_head(x)     # (B, T, vocab_size)
        return logits, kv_cache

    @staticmethod
    def _sample_next(last_logits: torch.Tensor, temperature: float) -> torch.Tensor:
        """Choose the next token id from (B, vocab) logits; returns a (B, 1) tensor.

        temperature == 0 selects greedy (argmax) decoding. Any other value
        divides the logits before sampling from the softmax distribution.
        """
        if temperature == 0:
            # logits / 0 would give inf/NaN probabilities that torch.multinomial rejects.
            return last_logits.argmax(dim=-1, keepdim=True)
        probs = F.softmax(last_logits / temperature, dim=-1)
        return torch.multinomial(probs, 1)

    @torch.no_grad()
    def generate_with_cache(self, prompt_ids: torch.Tensor,
                            max_new_tokens: int,
                            temperature: float = 1.0,
                            itos: dict = None) -> list[int]:
        """Generate up to ``max_new_tokens`` tokens, reusing the KV cache between steps.

        prompt_ids:     (1, T_prompt) token ids. Only batch size 1 is supported:
                        the result is built from row 0 and ``.item()`` is called
                        on each sampled id.
        max_new_tokens: number of tokens to generate; <= 0 returns the prompt unchanged.
        temperature:    softmax temperature; 0 selects greedy decoding.
        itos:           optional id -> string map. When given, generation stops right
                        after a token that maps to '.' (that token is kept).
        Returns:        prompt ids followed by the generated ids, as a list of ints.
        """
        self.eval()
        generated = prompt_ids[0].tolist()
        if max_new_tokens <= 0:
            return generated

        def hit_stop(token_id: int) -> bool:
            return bool(itos) and itos.get(token_id) == '.'

        # PREFILL: one forward pass over the full prompt fills every layer's cache.
        logits, kv_cache = self(prompt_ids, kv_cache={})
        next_id = self._sample_next(logits[:, -1, :], temperature)
        generated.append(next_id.item())
        pos = prompt_ids.shape[1]   # absolute position of next_id

        # DECODE: feed only the newest token; earlier K/V come from the cache.
        for step in range(max_new_tokens - 1):
            if hit_stop(generated[-1]):   # also covers the token sampled at prefill
                break
            logits, kv_cache = self(next_id, kv_cache=kv_cache, start_pos=pos + step)
            next_id = self._sample_next(logits[:, -1, :], temperature)
            generated.append(next_id.item())

        return generated

4. Benchmark: With vs Without KV Cache

import time

device = "cuda" if torch.cuda.is_available() else "cpu"

VOCAB    = 128
D_MODEL  = 128
N_HEADS  = 4
N_LAYERS = 4
MAX_LEN  = 256

model_cache = GPTWithCache(VOCAB, D_MODEL, N_HEADS, N_LAYERS, MAX_LEN).to(device)
# Both benchmarks run the *same* weights; only the decoding strategy differs
# (full re-computation each step vs. KV-cache reuse), so the timing
# comparison isolates the effect of the cache.
model_nocache = model_cache

@torch.no_grad()
def generate_no_cache(model, prompt: torch.Tensor, n_new: int) -> list[int]:
    """Sample ``n_new`` tokens by re-running ``model`` on the whole sequence each step.

    This is the baseline the KV cache is measured against: no state is carried
    between steps, so every step re-processes the prompt plus all tokens
    generated so far. ``model`` is called as ``model(ids, kv_cache=None)`` and
    must return a ``(logits, _)`` pair with logits shaped (B, T, vocab).
    Returns the token ids of batch row 0 (prompt followed by sampled tokens).
    """
    sequence = prompt.clone()
    for _ in range(n_new):
        all_logits, _unused = model(sequence, kv_cache=None)
        probs = F.softmax(all_logits[:, -1], dim=-1)
        sampled = torch.multinomial(probs, 1)
        sequence = torch.cat([sequence, sampled], dim=1)
    return sequence[0].tolist()

prompt = torch.randint(0, VOCAB, (1, 16), device=device)
N_NEW  = 64
N_RUNS = 5   # timed repetitions per method; reported times are averages


def _sync_device() -> None:
    """Wait for queued GPU work so wall-clock timings include all of it."""
    if device == "cuda":
        torch.cuda.synchronize()


# Warm-up: the first calls on a device pay one-off costs (CUDA context and
# kernel initialisation, caching-allocator growth). Without this, whichever
# method is timed first absorbs those costs and the speedup is misreported.
generate_no_cache(model_cache, prompt, N_NEW)
model_cache.generate_with_cache(prompt, N_NEW)
_sync_device()

# Benchmark no cache
t0 = time.perf_counter()
for _ in range(N_RUNS):
    _ = generate_no_cache(model_cache, prompt, N_NEW)
_sync_device()
no_cache_ms = (time.perf_counter() - t0) / N_RUNS * 1000

# Benchmark with cache
t0 = time.perf_counter()
for _ in range(N_RUNS):
    _ = model_cache.generate_with_cache(prompt, N_NEW)
_sync_device()
cache_ms = (time.perf_counter() - t0) / N_RUNS * 1000

print(f"No cache : {no_cache_ms:.1f} ms  (generates {N_NEW} tokens)")
print(f"KV cache : {cache_ms:.1f} ms  (generates {N_NEW} tokens)")
print(f"Speedup  : {no_cache_ms / cache_ms:.2f}×")

5. KV Cache Memory Layout

def kv_cache_memory_mb(
    batch_size: int,
    seq_len: int,
    n_layers: int,
    n_heads: int,
    head_dim: int,
    dtype: torch.dtype = torch.float16,
) -> float:
    """Estimate KV cache memory in MB (10**6 bytes).

    Counts one K and one V tensor per layer, each holding
    batch_size * n_heads * seq_len * head_dim elements. The total is therefore
    2 * batch_size * n_layers * n_heads * seq_len * head_dim elements times the
    byte size of ``dtype``. For grouped-/multi-query attention, pass the number
    of K/V heads as ``n_heads``.

    Any torch dtype is accepted, including integer types used by quantised
    caches (e.g. ``torch.int8``).
    """
    # element_size() works for every dtype; torch.finfo only accepts
    # floating-point dtypes and raises TypeError for e.g. torch.int8.
    bytes_per_elem = torch.empty((), dtype=dtype).element_size()
    # K tensor + V tensor per layer
    total_elements = 2 * batch_size * n_layers * n_heads * seq_len * head_dim
    return total_elements * bytes_per_elem / 1e6

print("KV Cache Memory Estimates (FP16):")
# Each entry: (display name, n_layers, n_heads, head_dim)
configs = [
    ("GPT-2 Small",  12,  12,   64),
    ("GPT-2 Large",  36,  20,   64),
    ("LLaMA-7B",     32,  32,  128),
]

for name, n_layers, n_heads, head_dim in configs:
    for seq_len in (512, 2048, 8192):
        mb = kv_cache_memory_mb(1, seq_len, n_layers, n_heads, head_dim)
        print(f"  {name:15s} seq={seq_len:5d}: {mb:6.1f} MB")

6. Summary

Concept Detail
Prefill phase Full prompt processed in one forward pass
Decode phase One token at a time; K and V concatenated to cache
Memory cost 2 × B × L × H × d × dtype_bytes per layer
Speedup Total decode work: projections/MLP quadratic → linear, attention cubic → quadratic in sequence length

Chapter 13 covers quantisation: compressing model weights to 8-bit or 4-bit integers to drastically reduce memory without sacrificing much accuracy.