Chapter 12: Inference I — KV Cache
During autoregressive generation, the model predicts one token at a time and appends it to the sequence. Naively, each new token requires recomputing the entire attention matrix over the full sequence — including all the tokens we already processed. Each step therefore does work quadratic in the number of tokens so far, and the total cost of generating a sequence grows roughly cubically with its length — making long sequences impossibly slow.
The Key-Value (KV) cache solves this: we cache the K and V projections for every past token. When we generate the next token, we only compute Q, K, V for the new token, then concatenate its K and V with the cached versions to compute attention. The computational cost per step stays roughly constant instead of growing with sequence length.
The memory cost is O(n_layers × L × H × d), where L is sequence length, H is number of heads, and d is head dimension (times two, for K and V). For a LLaMA-7B-style model (32 layers, 32 heads, d = 128) generating 1,024 tokens in FP16, the KV cache takes roughly 0.5 GB of VRAM per sequence — a significant but manageable overhead that grows linearly with both sequence length and batch size.
1. Attention Without KV Cache (baseline)
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import time
class AttentionNoCache(nn.Module):
    """Causal self-attention without any caching.

    Every call re-projects Q, K and V for the whole sequence and rebuilds
    the full (T, T) attention matrix, so autoregressive decoding with this
    module redoes all past work at every step — the baseline the KV cache
    is compared against.
    """

    def __init__(self, d_model: int, n_heads: int):
        super().__init__()
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        # One fused projection for Q, K and V, plus the output projection.
        self.qkv = nn.Linear(d_model, 3 * d_model, bias=False)
        self.out = nn.Linear(d_model, d_model, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """x: (B, T, d_model) — the full sequence each call. Returns same shape."""
        batch, seq_len, width = x.shape
        heads, dim_h = self.n_heads, self.head_dim

        # Project once, then split into per-head Q/K/V of shape (B, H, T, hd).
        projected = self.qkv(x).view(batch, seq_len, 3, heads, dim_h)
        projected = projected.permute(2, 0, 3, 1, 4)
        q, k, v = projected.unbind(0)

        # Scaled dot-product scores: (B, H, T, T).
        logits = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(dim_h)

        # Hide future positions: strictly-above-diagonal entries get -inf.
        future = torch.ones(seq_len, seq_len, device=x.device).triu(diagonal=1).bool()
        logits = logits.masked_fill(future, float("-inf"))

        weights = F.softmax(logits, dim=-1)
        context = torch.matmul(weights, v).transpose(1, 2).reshape(batch, seq_len, width)
        return self.out(context)
2. Attention With KV Cache
class AttentionWithCache(nn.Module):
    """Causal self-attention that reuses cached K/V during decoding.

    On each call only the incoming tokens' Q/K/V are projected; keys and
    values from earlier calls are pulled out of ``kv_cache`` and concatenated
    along the time axis, keeping the per-step cost flat during generation.
    """

    def __init__(self, d_model: int, n_heads: int):
        super().__init__()
        self.n_heads = n_heads
        self.head_dim = d_model // n_heads
        self.d_model = d_model
        self.qkv = nn.Linear(d_model, 3 * d_model, bias=False)
        self.out = nn.Linear(d_model, d_model, bias=False)

    def forward(
        self,
        x: torch.Tensor,
        kv_cache: dict = None,
        layer_idx: int = 0,
    ) -> tuple[torch.Tensor, dict]:
        """
        x: (B, T_new, d_model) — T_new is 1 while decoding, full T at prefill.
        kv_cache: per-layer dict of accumulated (K, V) tensors, or None to
            disable caching entirely for this call.
        layer_idx: key under which this layer's cache entry lives.
        Returns (output of shape (B, T_new, d_model), updated kv_cache).
        """
        batch, t_new, width = x.shape
        heads, dim_h = self.n_heads, self.head_dim

        # Q/K/V for the incoming token(s) only.
        proj = self.qkv(x).view(batch, t_new, 3, heads, dim_h).permute(2, 0, 3, 1, 4)
        q, k_step, v_step = proj.unbind(0)

        # Prepend whatever this layer has already cached.
        if kv_cache is not None and layer_idx in kv_cache:
            k_prev, v_prev = kv_cache[layer_idx]
            k = torch.cat([k_prev, k_step], dim=2)  # (B, H, T_past+T_new, hd)
            v = torch.cat([v_prev, v_step], dim=2)
        else:
            k, v = k_step, v_step

        # Store the extended tensors back. Detached: the cache is inference
        # state, not part of the autograd graph.
        if kv_cache is not None:
            kv_cache[layer_idx] = (k.detach(), v.detach())

        t_total = k.shape[2]

        # Scaled dot-product attention: (B, H, T_new, T_total).
        logits = q @ k.transpose(-2, -1) / math.sqrt(dim_h)

        # A mask is only needed when several query positions arrive at once
        # (prefill); a single decoded token may attend to everything cached.
        # The diagonal offset aligns the (T_new, T_total) mask so each new
        # row sees all past tokens plus itself, but nothing after it.
        if t_new > 1:
            future = torch.triu(
                torch.ones(t_new, t_total, device=x.device),
                diagonal=t_total - t_new + 1,
            ).bool()
            logits = logits.masked_fill(future, float("-inf"))

        weights = F.softmax(logits, dim=-1)
        merged = (weights @ v).transpose(1, 2).reshape(batch, t_new, width)
        return self.out(merged), kv_cache
3. GPT with KV Cache
class GPTWithCache(nn.Module):
    """Minimal decoder-only LM used to demonstrate KV caching.

    NOTE(demo): each "layer" here is attention only — a real transformer
    block would add LayerNorm and an MLP around it; they are omitted to
    keep the caching logic front and center.
    """

    def __init__(self, vocab_size: int, d_model: int, n_heads: int,
                 n_layers: int, max_len: int):
        super().__init__()
        self.token_emb = nn.Embedding(vocab_size, d_model)
        # Learned absolute positions: generation beyond max_len positions
        # is unsupported (pos_emb lookup would fail).
        self.pos_emb = nn.Embedding(max_len, d_model)
        self.layers = nn.ModuleList([
            AttentionWithCache(d_model, n_heads) for _ in range(n_layers)
        ])
        self.ln_f = nn.LayerNorm(d_model)
        self.lm_head = nn.Linear(d_model, vocab_size, bias=False)
        # Weight tying: the output projection shares the embedding matrix.
        self.lm_head.weight = self.token_emb.weight

    def forward(self, idx: torch.Tensor,
                kv_cache: dict = None,
                start_pos: int = 0) -> tuple[torch.Tensor, dict]:
        """
        idx: (B, T) token ids — during cached decoding T is 1.
        kv_cache: per-layer accumulated K/V; a fresh dict is created when
            None, so a stand-alone call is a plain (prefill-style) forward.
        start_pos: absolute position of idx[:, 0] — required so position
            embeddings stay correct when only the newest token is passed in.
        Returns (logits of shape (B, T, vocab_size), kv_cache).
        """
        B, T = idx.shape
        positions = torch.arange(start_pos, start_pos + T, device=idx.device)
        x = self.token_emb(idx) + self.pos_emb(positions)
        if kv_cache is None:
            kv_cache = {}
        for i, layer in enumerate(self.layers):
            x, kv_cache = layer(x, kv_cache=kv_cache, layer_idx=i)
        x = self.ln_f(x)
        logits = self.lm_head(x)  # (B, T, vocab_size)
        return logits, kv_cache

    @torch.no_grad()
    def generate_with_cache(self, prompt_ids: torch.Tensor,
                            max_new_tokens: int,
                            temperature: float = 1.0,
                            itos: dict = None) -> list[int]:
        """Generate up to max_new_tokens tokens autoregressively via the KV cache.

        FIX: the original always emitted one token even when
        max_new_tokens == 0 (the prefill sample was appended before the
        loop), and never applied the '.' early-stop to that first sample.
        Both are corrected; behavior for max_new_tokens >= 1 without an
        early stop is unchanged.

        prompt_ids: (1, T_prompt) token ids.
        temperature: softmax temperature for sampling.
        itos: optional id -> string map; generation stops after a '.' token.
        Returns the full token list (prompt followed by generated ids).
        """
        self.eval()
        generated = prompt_ids[0].tolist()
        if max_new_tokens <= 0:
            return generated

        # PREFILL: one forward pass over the whole prompt populates the cache.
        logits, kv_cache = self(prompt_ids, kv_cache={})
        next_logit = logits[:, -1, :] / temperature
        next_id = torch.multinomial(F.softmax(next_logit, dim=-1), 1)
        generated.append(next_id.item())
        pos = prompt_ids.shape[1]
        if itos and itos.get(next_id.item()) == '.':
            return generated

        # DECODE: feed only the newest token; the cache supplies the rest.
        for step in range(max_new_tokens - 1):
            logits, kv_cache = self(next_id, kv_cache=kv_cache,
                                    start_pos=pos + step)
            next_logit = logits[:, -1, :] / temperature
            next_id = torch.multinomial(F.softmax(next_logit, dim=-1), 1)
            generated.append(next_id.item())
            if itos and itos.get(next_id.item()) == '.':
                break
        return generated
4. Benchmark: With vs Without KV Cache
import time
device = "cuda" if torch.cuda.is_available() else "cpu"

# Small configuration so the benchmark runs quickly on CPU too.
VOCAB = 128
D_MODEL = 128
N_HEADS = 4
N_LAYERS = 4
MAX_LEN = 256

model_cache = GPTWithCache(VOCAB, D_MODEL, N_HEADS, N_LAYERS, MAX_LEN).to(device)
# The same model serves as its own "no cache" baseline: generate_no_cache
# below simply never carries the KV cache across steps, which is exactly
# what a cache-free implementation would do. (The original code built a
# dead `nn.Transformer(...) if False else model_cache` expression here —
# the Transformer branch could never execute, so it has been removed.)
model_nocache = model_cache
@torch.no_grad()
def generate_no_cache(model, prompt: torch.Tensor, n_new: int) -> list[int]:
    """Naive generation baseline.

    Feeds the entire (growing) sequence through the model at every step and
    discards any cache the model builds internally, so all past positions
    are recomputed each iteration.
    """
    seq = prompt.clone()
    for _ in range(n_new):
        step_logits, _ = model(seq, kv_cache=None)
        probs = F.softmax(step_logits[:, -1], dim=-1)
        sampled = torch.multinomial(probs, 1)
        seq = torch.cat([seq, sampled], dim=1)
    return seq[0].tolist()
prompt = torch.randint(0, VOCAB, (1, 16), device=device)
N_NEW = 64
REPEATS = 5

def _sync():
    # CUDA launches are asynchronous; synchronize so the wall-clock timers
    # measure completed work, not just kernel-queueing time.
    if device == "cuda":
        torch.cuda.synchronize()

# Warm up both paths once so lazy initialisation (kernel compilation,
# allocator growth) is not billed to the first timed run — the original
# started the timer cold and without a pre-sync, biasing the comparison.
_ = generate_no_cache(model_cache, prompt, N_NEW)
_ = model_cache.generate_with_cache(prompt, N_NEW)
_sync()

# Benchmark: no cache (full sequence re-fed every step).
t0 = time.perf_counter()
for _ in range(REPEATS):
    _ = generate_no_cache(model_cache, prompt, N_NEW)
_sync()
no_cache_ms = (time.perf_counter() - t0) / REPEATS * 1000

# Benchmark: with KV cache (one token per step after prefill).
t0 = time.perf_counter()
for _ in range(REPEATS):
    _ = model_cache.generate_with_cache(prompt, N_NEW)
_sync()
cache_ms = (time.perf_counter() - t0) / REPEATS * 1000

print(f"No cache : {no_cache_ms:.1f} ms (generates {N_NEW} tokens)")
print(f"KV cache : {cache_ms:.1f} ms (generates {N_NEW} tokens)")
print(f"Speedup : {no_cache_ms / cache_ms:.2f}×")
5. KV Cache Memory Layout
def kv_cache_memory_mb(
    batch_size: int,
    seq_len: int,
    n_layers: int,
    n_heads: int,
    head_dim: int,
    dtype: torch.dtype = torch.float16,
) -> float:
    """Estimate KV cache size in MB (decimal: 1 MB = 1e6 bytes).

    The cache holds one K and one V tensor of shape
    (batch, heads, seq_len, head_dim) per layer — hence the factor of 2.

    dtype may now be any torch dtype, including the integer types used by
    quantized caches (e.g. torch.int8). The original used torch.finfo,
    which raises TypeError for non-floating dtypes; element_size() covers
    both families.
    """
    bytes_per_elem = torch.empty((), dtype=dtype).element_size()
    total_elements = 2 * batch_size * n_layers * n_heads * seq_len * head_dim
    return total_elements * bytes_per_elem / 1e6
print("KV Cache Memory Estimates (FP16):")
# Per architecture: (display name, n_layers, n_heads, head_dim).
configs = [
    ("GPT-2 Small", 12, 12, 64),
    ("GPT-2 Large", 36, 20, 64),
    ("LLaMA-7B", 32, 32, 128),
]
SEQ_LENS = (512, 2048, 8192)
for name, depth, heads, hd in configs:
    for length in SEQ_LENS:
        size_mb = kv_cache_memory_mb(
            batch_size=1, seq_len=length,
            n_layers=depth, n_heads=heads, head_dim=hd,
        )
        print(f" {name:15s} seq={length:5d}: {size_mb:6.1f} MB")
6. Summary
| Concept | Detail |
|---|---|
| Prefill phase | Full prompt processed in one forward pass |
| Decode phase | One token at a time; K and V concatenated to cache |
| Memory cost | 2 × B × L × H × d × dtype_bytes per layer |
| Speedup | Per-step attention cost drops from O(L²) to O(L); total generation cost from roughly cubic to quadratic |
Chapter 13 covers quantisation: compressing model weights to 8-bit or 4-bit integers to drastically reduce memory without sacrificing much accuracy.