Chapter 14: Finetuning I — SFT, LoRA, and Chat Templates

A pretrained language model learns to predict the next token from vast amounts of text, but it doesn’t inherently know how to follow instructions or answer questions. Supervised Fine-Tuning (SFT) on a curated instruction dataset teaches the model to be a useful assistant.

Full fine-tuning updates every parameter — expensive for large models. Parameter-Efficient Fine-Tuning (PEFT), particularly LoRA (Low-Rank Adaptation), offers a better trade-off: we freeze the pretrained weights and inject small trainable low-rank matrices at each attention layer. A 7B model has ~28GB of weights in fp32; LoRA typically adds only ~50MB of trainable parameters while achieving 90%+ of full fine-tune quality.

LoRA decomposes the weight update ΔW ∈ ℝ^{d×k} as a product of two low-rank matrices: ΔW = B × A, where B ∈ ℝ^{d×r} and A ∈ ℝ^{r×k} with rank r ≪ min(d,k). During the forward pass, W_new x = (W_frozen + B @ A) x. Only A and B are trained; W_frozen is never touched.

In this chapter we download the Alpaca instruction dataset from HuggingFace, implement LoRA from scratch, and show how to use the PEFT library for production-quality fine-tuning.

1. Download the Alpaca Instruction Dataset

from datasets import load_dataset
import os

DATA_DIR = "data"
os.makedirs(DATA_DIR, exist_ok=True)

print("Loading Alpaca instruction dataset …")
# tatsu-lab/alpaca: 52K instruction-following examples generated with
# OpenAI's text-davinci-003 (self-instruct style).
alpaca = load_dataset("tatsu-lab/alpaca", split="train")

print(f"Dataset size: {len(alpaca):,} examples")
print(f"Columns: {alpaca.column_names}")
print(f"\nSample:")
first_example = alpaca[0]
for field, text in first_example.items():
    # Show each field truncated to 120 characters (all Alpaca fields are strings).
    print(f"  {field}: {repr(text[:120])}")

2. Format Chat Templates

def format_alpaca(example: dict) -> str:
    """
    Turn one Alpaca record into a single prompt+response training string.

    The prompt half describes the task (plus optional input context); the
    model is trained to continue it with the response, terminated by a
    sentinel marker.
    """
    has_context = bool(example.get("input", "").strip())
    if has_context:
        preamble = (
            "Below is an instruction that describes a task, paired with an input "
            "that provides further context. Write a response that appropriately "
            "completes the request.\n\n"
        )
        middle = (
            f"### Instruction:\n{example['instruction']}\n\n"
            f"### Input:\n{example['input']}\n\n"
        )
    else:
        preamble = (
            "Below is an instruction that describes a task. Write a response "
            "that appropriately completes the request.\n\n"
        )
        middle = f"### Instruction:\n{example['instruction']}\n\n"
    return f"{preamble}{middle}### Response:\n{example['output']}\n\n### End"

# Format first few examples
# Preview the formatted prompt for the first two records (truncated to 400 chars).
for example_no in (0, 1):
    print(f"--- Example {example_no+1} ---")
    formatted = format_alpaca(alpaca[example_no])
    print(formatted[:400])
    print()

3. LoRA Implementation from Scratch

import torch
import torch.nn as nn
import math

class LoRALinear(nn.Module):
    """
    Drop-in replacement for nn.Linear with LoRA weights.

    The effective weight is: W_effective = W_frozen + (B @ A) * (alpha / r)
    where alpha/r is a scaling factor that controls LoRA's contribution.
    Only `lora_A` and `lora_B` receive gradients; `weight` and `bias` hold
    the (frozen) pretrained values.
    """

    def __init__(
        self,
        in_features:  int,
        out_features: int,
        rank:         int   = 8,
        alpha:        float = 16.0,
        dropout:      float = 0.05,
    ):
        super().__init__()
        self.in_features  = in_features
        self.out_features = out_features
        self.rank         = rank
        self.scaling      = alpha / rank

        # Frozen pretrained weight (won't receive gradients)
        self.weight = nn.Parameter(
            torch.empty(out_features, in_features), requires_grad=False
        )
        self.bias = nn.Parameter(torch.zeros(out_features), requires_grad=False)

        # LoRA A: Kaiming-uniform init (same scheme nn.Linear uses)
        self.lora_A = nn.Parameter(torch.empty(rank, in_features))
        # LoRA B: initialised to zero so ΔW = 0 at the start of training
        self.lora_B = nn.Parameter(torch.zeros(out_features, rank))

        self.lora_dropout = nn.Dropout(p=dropout)

        # Placeholder init for the base weight; from_linear() overwrites it
        # with the real pretrained values.
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Base forward (frozen weights)
        base = nn.functional.linear(x, self.weight, self.bias)
        # LoRA update: x → dropout → A → B → scale
        lora = self.lora_dropout(x) @ self.lora_A.T @ self.lora_B.T
        return base + lora * self.scaling

    @classmethod
    def from_linear(cls, linear: nn.Linear, rank: int = 8, alpha: float = 16.0,
                    dropout: float = 0.05):
        """Wrap an existing nn.Linear with LoRA, copying pretrained weights.

        `dropout` is forwarded to the LoRA branch; its default matches the
        previous hard-coded value, so existing callers are unaffected.
        """
        lora = cls(linear.in_features, linear.out_features,
                   rank=rank, alpha=alpha, dropout=dropout)
        lora.weight.data.copy_(linear.weight.data)
        if linear.bias is not None:
            lora.bias.data.copy_(linear.bias.data)
        return lora

    def trainable_params(self) -> list[nn.Parameter]:
        """Return only the parameters that should be optimised (A and B)."""
        return [self.lora_A, self.lora_B]

    def extra_repr(self) -> str:
        return f"rank={self.rank}, scaling={self.scaling:.3f}"

4. Inject LoRA into a Model

def inject_lora(model: nn.Module, target_modules: set[str],
                rank: int = 8, alpha: float = 16.0) -> nn.Module:
    """
    Replace every nn.Linear whose attribute name (last dotted component of
    its module path) is in `target_modules` with a LoRALinear, and freeze
    all other parameters.

    Returns the same model object, modified in place.

    Note: matching is on the exact attribute name — the previous
    `endswith` matching would also catch e.g. "xq_proj" for target
    "q_proj".
    """
    # First, freeze everything; the LoRA factors are re-enabled below.
    for param in model.parameters():
        param.requires_grad = False

    # Snapshot the name→module table once. Only leaf Linears are swapped
    # out, so parent modules — the only entries we look up — stay valid.
    module_table = dict(model.named_modules())

    for name, module in list(model.named_modules()):
        if not isinstance(module, nn.Linear):
            continue
        parent_name, _, child_name = name.rpartition(".")
        if child_name not in target_modules:
            continue
        parent = module_table[parent_name] if parent_name else model
        lora_layer = LoRALinear.from_linear(module, rank=rank, alpha=alpha)
        setattr(parent, child_name, lora_layer)
        # Only LoRA parameters are trainable
        lora_layer.lora_A.requires_grad = True
        lora_layer.lora_B.requires_grad = True

    total  = sum(p.numel() for p in model.parameters())
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"Total params:     {total:,}")
    print(f"Trainable params: {trainable:,}  ({100*trainable/total:.2f}%)")
    return model


# Demo on a tiny GPT-like model
class TinyGPT(nn.Module):
    """A deliberately minimal GPT-shaped stack used to demo LoRA injection."""

    def __init__(self, d=128, n_heads=4, vocab=1000, n_layers=2):
        super().__init__()
        self.emb = nn.Embedding(vocab, d)
        blocks = []
        for _ in range(n_layers):
            blocks.append(nn.ModuleDict({
                "q_proj": nn.Linear(d, d, bias=False),
                "k_proj": nn.Linear(d, d, bias=False),
                "v_proj": nn.Linear(d, d, bias=False),
                "out_proj": nn.Linear(d, d, bias=False),
                "ffn": nn.Linear(d, 4 * d, bias=False),
                "ffn2": nn.Linear(4 * d, d, bias=False),
            }))
        self.layers = nn.ModuleList(blocks)
        self.head = nn.Linear(d, vocab, bias=False)

    def forward(self, x):
        hidden = self.emb(x)
        for block in self.layers:
            # Simplified: projections only, no attention scores or masking.
            queries = block["q_proj"](hidden)
            keys = block["k_proj"](hidden)
            values = block["v_proj"](hidden)
            hidden = hidden + block["out_proj"](values)
            hidden = hidden + block["ffn2"](torch.relu(block["ffn"](hidden)))
        return self.head(hidden)

model = TinyGPT()
n_params_before = sum(p.numel() for p in model.parameters())
print("Before LoRA injection:")
print(f"  All params: {n_params_before:,}")

# Standard LoRA practice: adapt only the query and value projections.
inject_lora(model, target_modules={"q_proj", "v_proj"}, rank=4, alpha=8)

5. SFT Training Loop

from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import get_peft_model, LoraConfig, TaskType

print("\nLoading GPT-2 with PEFT LoRA …")
model_name = "gpt2"
tokenizer  = AutoTokenizer.from_pretrained(model_name)
# GPT-2 has no pad token; reuse EOS for padding.
tokenizer.pad_token = tokenizer.eos_token

base_model = AutoModelForCausalLM.from_pretrained(model_name)

# PEFT LoRA config
lora_config = LoraConfig(
    task_type     = TaskType.CAUSAL_LM,
    r             = 8,                          # rank
    lora_alpha    = 16,
    lora_dropout  = 0.05,
    target_modules= ["c_attn", "c_proj"],       # GPT-2 attention projections
    bias          = "none",
)

peft_model = get_peft_model(base_model, lora_config)
peft_model.print_trainable_parameters()

# Prepare a small batch from Alpaca
device = "cuda" if torch.cuda.is_available() else "cpu"
peft_model = peft_model.to(device)
# Optimise only the trainable (LoRA) parameters — the base model is frozen,
# so handing AdamW the full parameter list just wastes state memory.
optim = torch.optim.AdamW(
    (p for p in peft_model.parameters() if p.requires_grad),
    lr=2e-4, weight_decay=0.0,
)

# Format and tokenise a small batch of examples
texts = [format_alpaca(alpaca[i]) for i in range(8)]
batch = tokenizer(
    texts, return_tensors="pt", padding=True,
    truncation=True, max_length=256
)
input_ids      = batch["input_ids"].to(device)
attention_mask = batch["attention_mask"].to(device)

# Build labels: set padding positions to -100 so the loss ignores them.
# Since pad == eos here, labels=input_ids would otherwise train the model
# to predict padding tokens.
labels = input_ids.clone()
labels[attention_mask == 0] = -100

# SFT step
peft_model.train()
outputs = peft_model(input_ids=input_ids, attention_mask=attention_mask,
                     labels=labels)
loss = outputs.loss
print(f"\nSFT loss (first batch): {loss.item():.4f}")

optim.zero_grad()
loss.backward()
nn.utils.clip_grad_norm_(peft_model.parameters(), 1.0)
optim.step()

print("LoRA fine-tuning step complete.")

6. Save and Load LoRA Weights

from peft import PeftModel

# Persisting a LoRA run stores only the adapter weights — a few MB,
# not the full base model.
peft_model.save_pretrained("data/lora_adapter")
print("LoRA adapter saved → data/lora_adapter/")

# Reloading: start from a fresh base model, then attach the saved adapter.
base_model_reload = AutoModelForCausalLM.from_pretrained("gpt2")
loaded_peft = PeftModel.from_pretrained(base_model_reload, "data/lora_adapter")
print("LoRA adapter loaded successfully.")

# For deployment the low-rank update can be folded back into the base
# weights, yielding a plain model with zero inference overhead.
merged = loaded_peft.merge_and_unload()
print(f"Merged model type: {type(merged).__name__}")

7. Chat Template for Instruction Following

# Modern chat models use structured templates with role markers
SYSTEM_PROMPT = "You are a helpful assistant that tells creative children's stories."

def make_chat_prompt(user_message: str) -> str:
    """
    Format a user message using a simple ChatML-style template.
    GPT-4, Mistral, LLaMA-3 all use variants of this structure.
    """
    return (
        f"<|system|>\n{SYSTEM_PROMPT}\n"
        f"<|user|>\n{user_message}\n"
        f"<|assistant|>\n"
    )

prompt = make_chat_prompt("Tell me a short story about a dragon who is afraid of fire.")
print(prompt)
n_prompt_tokens = len(tokenizer.encode(prompt))
print(f"Prompt length: {n_prompt_tokens} tokens")

8. Summary

| Technique | Trainable params | Accuracy vs full FT |
|---|---|---|
| Full fine-tune | 100% | 100% |
| LoRA (r=8) | ~0.1–1% | ~95% |
| LoRA (r=64) | ~1–5% | ~98% |
| Prefix tuning | ~0.1% | ~90% |
| Prompt tuning | <0.01% | ~85% |

Chapter 15 moves beyond supervised learning to reinforcement learning from human feedback (RLHF) and the simpler DPO alternative.