Advanced

Train Your LLM from Scratch

A comprehensive guide to training a language model from scratch — from data preparation and tokenization through pretraining, instruction tuning, and reasoning with RLHF/DPO.

2 hours
LLM Training, PyTorch, Pretraining, SFT, LoRA, RLHF, DPO

Prerequisites

  • Python proficiency
  • PyTorch basics
  • Understanding of transformer architecture

Stage 0: Preparation

Before writing a single training loop, you need three things: an environment that won't crash mid-run, a tokenizer that can represent your data, and a data pipeline that feeds batches efficiently. Skip any of these and you'll waste GPU hours debugging.

Environment Setup

Bash
# Create a clean environment
conda create -n llm-train python=3.11 -y
conda activate llm-train

# Core dependencies
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121
pip install transformers datasets tokenizers wandb accelerate
pip install flash-attn --no-build-isolation  # FlashAttention-2

# Verify GPU
python -c "import torch; print(f'CUDA: {torch.cuda.is_available()}, GPUs: {torch.cuda.device_count()}')"

Building a Tokenizer

We'll train a BPE tokenizer from scratch using HuggingFace `tokenizers`. This is the same approach used by Llama, GPT, and Mistral.

Python
from tokenizers import Tokenizer, models, trainers, pre_tokenizers, decoders

def train_tokenizer(
    data_files: list[str],
    vocab_size: int = 32_000,
    save_path: str = "tokenizer.json"
):
    tokenizer = Tokenizer(models.BPE())
    tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=False)
    tokenizer.decoder = decoders.ByteLevel()

    trainer = trainers.BpeTrainer(
        vocab_size=vocab_size,
        special_tokens=["<|pad|>", "<|eos|>", "<|bos|>", "<|unk|>"],
        min_frequency=2,
        show_progress=True,
    )

    tokenizer.train(data_files, trainer)
    tokenizer.save(save_path)
    print(f"Trained tokenizer: {tokenizer.get_vocab_size()} tokens")
    return tokenizer

# Train on your corpus
tokenizer = train_tokenizer(["corpus_part1.txt", "corpus_part2.txt"])

# Test it
encoded = tokenizer.encode("The transformer architecture revolutionized NLP.")
print(f"Tokens: {encoded.tokens}")
print(f"IDs:    {encoded.ids}")

Key decisions:

  • Vocab size: 32K is a good default. Larger (64K) improves multilingual coverage; smaller (16K) reduces embedding parameters
  • Special tokens: At minimum you need pad, eos, bos, unk. Add `<|im_start|>` and `<|im_end|>` if you plan to instruction-tune later
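The vocab-size trade-off is easy to quantify: with weight tying, the token embedding is a `vocab_size × d_model` matrix, so the embedding cost scales linearly with vocabulary size. A quick back-of-envelope check, using d_model = 768 from the 125M config later in this stage (the helper name is illustrative):

```python
def embedding_params(vocab_size: int, d_model: int = 768) -> int:
    """Parameters in a (tied) token embedding matrix."""
    return vocab_size * d_model

# Doubling the vocab from 32K to 64K adds ~24.6M parameters,
# a large fraction of a ~125M-parameter model.
print(f"32K vocab: {embedding_params(32_000) / 1e6:.1f}M params")
print(f"64K vocab: {embedding_params(64_000) / 1e6:.1f}M params")
```

At this scale, a 64K vocabulary spends roughly a third of the parameter budget on embeddings alone, which is why small models usually stay near 32K.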

Data Pipeline

Efficient data loading is critical. We pack multiple documents into fixed-length sequences to maximize GPU utilization:

Python
import torch
from torch.utils.data import Dataset, DataLoader
from pathlib import Path
import numpy as np

class PretrainingDataset(Dataset):
    """Concatenates all documents and chunks into fixed-length sequences."""

    def __init__(
        self,
        data_dir: str,
        tokenizer,
        seq_length: int = 2048,
        stride: int = 2048,  # No overlap by default
    ):
        self.seq_length = seq_length
        self.stride = stride

        # Tokenize and concatenate all files
        all_ids = []
        for file in sorted(Path(data_dir).glob("*.txt")):
            text = file.read_text(encoding="utf-8")
            encoded = tokenizer.encode(text)
            all_ids.extend(encoded.ids)
            all_ids.append(tokenizer.token_to_id("<|eos|>"))

        self.data = np.array(all_ids, dtype=np.uint16)
        self.n_chunks = max(1, (len(self.data) - seq_length) // stride)
        print(f"Dataset: {len(self.data):,} tokens → {self.n_chunks:,} chunks")

    def __len__(self):
        return self.n_chunks

    def __getitem__(self, idx):
        start = idx * self.stride
        end = start + self.seq_length + 1  # +1 for target shift

        # Cast to int64: uint16 saves disk/RAM, but torch has no uint16 dtype
        chunk = torch.from_numpy(self.data[start:end].astype(np.int64))
        x = chunk[:-1]   # Input
        y = chunk[1:]    # Target (shifted by 1)
        return x, y

def create_dataloaders(
    train_dir: str,
    val_dir: str,
    tokenizer,
    batch_size: int = 8,
    seq_length: int = 2048,
):
    train_ds = PretrainingDataset(train_dir, tokenizer, seq_length)
    val_ds = PretrainingDataset(val_dir, tokenizer, seq_length)

    train_loader = DataLoader(
        train_ds, batch_size=batch_size, shuffle=True,
        num_workers=4, pin_memory=True, drop_last=True,
    )
    val_loader = DataLoader(
        val_ds, batch_size=batch_size, shuffle=False,
        num_workers=2, pin_memory=True,
    )
    return train_loader, val_loader

Data Preparation Checklist

Step | Action | Why
1 | Deduplicate documents | Prevents memorization of repeated text
2 | Filter low-quality text | Removes boilerplate, ads, HTML artifacts
3 | Shuffle at document level | Prevents domain clustering in batches
4 | Split train/val (99/1) | Val set should be representative but small
5 | Tokenize and save as binary | Avoids re-tokenizing every training run
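Step 1 of the checklist can be sketched with simple content hashing. This is exact-match deduplication only; production pipelines usually add near-duplicate detection (e.g. MinHash), which this sketch does not attempt:

```python
import hashlib

def deduplicate(documents: list[str]) -> list[str]:
    """Drop exact duplicates, keeping the first occurrence (order-preserving)."""
    seen: set[str] = set()
    unique = []
    for doc in documents:
        # Normalize whitespace so trivially reformatted copies also collide
        key = hashlib.sha256(" ".join(doc.split()).encode("utf-8")).hexdigest()
        if key not in seen:
            seen.add(key)
            unique.append(doc)
    return unique

docs = ["Hello world.", "Hello  world.", "Something else."]
print(len(deduplicate(docs)))  # 2 -- the reformatted copy is dropped
```

Hashing instead of storing full documents keeps memory bounded even on corpora with billions of documents.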

Hyperparameter Cheat Sheet

For a ~125M parameter model (good for learning):

Python
config = {
    "vocab_size": 32_000,
    "d_model": 768,
    "n_heads": 12,
    "n_layers": 12,
    "d_ff": 3072,          # 4 * d_model
    "seq_length": 2048,
    "dropout": 0.1,
    "batch_size": 8,       # Per GPU
    "gradient_accumulation": 4,
    "lr": 3e-4,
    "warmup_steps": 1000,
    "total_steps": 100_000,
    "weight_decay": 0.1,
}
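You can sanity-check the headline parameter count directly from this config. Below is a rough count matching the Stage 1 architecture (tied embeddings, no biases, three-matrix SwiGLU FFN). Note that keeping d_ff = 4·d_model with SwiGLU lands somewhat above the classic ~125M figure, which assumes a two-matrix FFN:

```python
def count_params(cfg: dict) -> int:
    """Approximate parameter count for the Stage 1 GPT architecture."""
    d, dff, n_layers = cfg["d_model"], cfg["d_ff"], cfg["n_layers"]
    emb = cfg["vocab_size"] * d          # token embedding (tied with LM head)
    pos = cfg["seq_length"] * d          # learned positional embedding
    attn = 4 * d * d                     # qkv (3*d^2) + output proj (d^2), no bias
    ffn = 3 * d * dff                    # SwiGLU: w1, w2, w3
    norms = 2 * d                        # two RMSNorms per block
    return emb + pos + n_layers * (attn + ffn + norms) + d  # + final norm

cfg = {"vocab_size": 32_000, "d_model": 768, "n_layers": 12,
       "d_ff": 3072, "seq_length": 2048}
print(f"{count_params(cfg) / 1e6:.1f}M")  # 139.4M
```

Shrinking d_ff toward (8/3)·d_model, as Llama does for its SwiGLU layers, would bring the total back down while keeping the same structure.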

Next: Stage 1 — Pretraining, where we build the model and write the training loop.


Stage 1: Pretraining

Pretraining is where a language model learns to predict the next token. This is the most compute-intensive phase — you're teaching the model the statistical patterns of language from raw text.

Model Architecture

We'll build a decoder-only transformer (GPT-style). This is the same architecture used by GPT, Llama, and Mistral:

Python
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class RMSNorm(nn.Module):
    """RMSNorm (used by Llama and Mistral instead of LayerNorm)."""
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        norm = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        return x * norm * self.weight


class CausalSelfAttention(nn.Module):
    def __init__(self, d_model: int, n_heads: int, max_seq_len: int = 2048):
        super().__init__()
        assert d_model % n_heads == 0
        self.n_heads = n_heads
        self.d_k = d_model // n_heads

        self.qkv = nn.Linear(d_model, 3 * d_model, bias=False)
        self.proj = nn.Linear(d_model, d_model, bias=False)

        # Causal mask
        mask = torch.triu(torch.ones(max_seq_len, max_seq_len), diagonal=1).bool()
        self.register_buffer("mask", mask)

    def forward(self, x):
        B, T, C = x.shape

        # Project Q, K, V in one shot
        qkv = self.qkv(x).reshape(B, T, 3, self.n_heads, self.d_k)
        q, k, v = qkv.unbind(dim=2)
        q, k, v = [t.transpose(1, 2) for t in (q, k, v)]  # [B, heads, T, d_k]

        # Scaled dot-product attention with causal mask
        scores = (q @ k.transpose(-2, -1)) / math.sqrt(self.d_k)
        scores = scores.masked_fill(self.mask[:T, :T], float("-inf"))
        attn = F.softmax(scores, dim=-1)

        out = (attn @ v).transpose(1, 2).reshape(B, T, C)
        return self.proj(out)


class FeedForward(nn.Module):
    """SwiGLU feedforward (used by Llama, Mistral)."""
    def __init__(self, d_model: int, d_ff: int):
        super().__init__()
        self.w1 = nn.Linear(d_model, d_ff, bias=False)
        self.w2 = nn.Linear(d_ff, d_model, bias=False)
        self.w3 = nn.Linear(d_model, d_ff, bias=False)

    def forward(self, x):
        return self.w2(F.silu(self.w1(x)) * self.w3(x))


class TransformerBlock(nn.Module):
    def __init__(self, d_model, n_heads, d_ff):
        super().__init__()
        self.norm1 = RMSNorm(d_model)
        self.attn = CausalSelfAttention(d_model, n_heads)
        self.norm2 = RMSNorm(d_model)
        self.ff = FeedForward(d_model, d_ff)

    def forward(self, x):
        x = x + self.attn(self.norm1(x))
        x = x + self.ff(self.norm2(x))
        return x


class GPT(nn.Module):
    def __init__(self, vocab_size, d_model, n_heads, n_layers, d_ff, max_seq_len):
        super().__init__()
        self.tok_emb = nn.Embedding(vocab_size, d_model)
        self.pos_emb = nn.Embedding(max_seq_len, d_model)
        self.blocks = nn.ModuleList([
            TransformerBlock(d_model, n_heads, d_ff) for _ in range(n_layers)
        ])
        self.norm = RMSNorm(d_model)
        self.head = nn.Linear(d_model, vocab_size, bias=False)

        # Weight tying
        self.head.weight = self.tok_emb.weight

        n_params = sum(p.numel() for p in self.parameters())
        print(f"Model parameters: {n_params / 1e6:.1f}M")

    def forward(self, idx, targets=None):
        B, T = idx.shape
        pos = torch.arange(T, device=idx.device)

        x = self.tok_emb(idx) + self.pos_emb(pos)
        for block in self.blocks:
            x = block(x)
        x = self.norm(x)
        logits = self.head(x)

        loss = None
        if targets is not None:
            loss = F.cross_entropy(
                logits.view(-1, logits.size(-1)),
                targets.view(-1),
                ignore_index=-1,
            )
        return logits, loss

Training Loop

Here's the complete training loop with mixed precision, gradient accumulation, and cosine LR schedule:

Python
from torch.cuda.amp import autocast
import wandb

def train(model, train_loader, val_loader, config):
    device = torch.device("cuda")
    model = model.to(device)

    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=config["lr"],
        betas=(0.9, 0.95),
        weight_decay=config["weight_decay"],
    )

    # Cosine LR with warmup (decays to 10% of the peak LR)
    def lr_schedule(step):
        if step < config["warmup_steps"]:
            return step / config["warmup_steps"]
        progress = (step - config["warmup_steps"]) / (config["total_steps"] - config["warmup_steps"])
        return 0.1 + 0.9 * 0.5 * (1 + math.cos(math.pi * progress))

    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_schedule)
    # No GradScaler needed: bfloat16 keeps fp32's exponent range, so the
    # loss scaling required for fp16 training is unnecessary here.

    wandb.init(project="llm-from-scratch", config=config)

    global_step = 0
    for epoch in range(config.get("epochs", 1)):
        model.train()
        for batch_idx, (x, y) in enumerate(train_loader):
            x, y = x.to(device), y.to(device)

            with autocast(dtype=torch.bfloat16):
                _, loss = model(x, y)
                loss = loss / config["gradient_accumulation"]

            loss.backward()

            if (batch_idx + 1) % config["gradient_accumulation"] == 0:
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                optimizer.zero_grad()
                scheduler.step()
                global_step += 1

                if global_step % 100 == 0:
                    wandb.log({
                        "train/loss": loss.item() * config["gradient_accumulation"],
                        "train/lr": scheduler.get_last_lr()[0],
                        "train/step": global_step,
                    })

                if global_step % 1000 == 0:
                    val_loss = validate(model, val_loader, device)
                    wandb.log({"val/loss": val_loss, "val/perplexity": math.exp(val_loss)})
                    print(f"Step {global_step} | Val loss: {val_loss:.4f} | PPL: {math.exp(val_loss):.2f}")
                    save_checkpoint(model, optimizer, global_step)

                if global_step >= config["total_steps"]:
                    return

Validation

Perplexity is the standard metric for pretraining — lower is better. A perplexity of 20 means the model is "20-way confused" on average.

Python
@torch.no_grad()
def validate(model, val_loader, device, max_batches=50):
    model.eval()
    total_loss = 0
    n_batches = 0

    for x, y in val_loader:
        x, y = x.to(device), y.to(device)
        with autocast(dtype=torch.bfloat16):
            _, loss = model(x, y)
        total_loss += loss.item()
        n_batches += 1
        if n_batches >= max_batches:
            break

    model.train()
    return total_loss / n_batches

def save_checkpoint(model, optimizer, step, path="checkpoints"):
    Path(path).mkdir(exist_ok=True)
    torch.save({
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "step": step,
    }, f"{path}/step_{step}.pt")

What to Watch For

Metric | Healthy | Unhealthy
Training loss | Smooth downward curve | Spikes, plateaus early
Val loss | Tracks train loss closely | Diverges from train loss
Gradient norm | Stable around 0.1–1.0 | Exploding (>10) or vanishing
Learning rate | Smooth warmup → cosine decay | n/a
Perplexity | Steadily decreasing | Stuck above 100 after 10K steps

Generating Text (Sanity Check)

Python
@torch.no_grad()
def generate(model, tokenizer, prompt, max_new_tokens=100, temperature=0.8):
    model.eval()
    device = next(model.parameters()).device
    ids = tokenizer.encode(prompt).ids
    x = torch.tensor([ids], device=device)

    for _ in range(max_new_tokens):
        logits, _ = model(x[:, -2048:])  # Truncate to max seq len
        logits = logits[:, -1, :]
        if temperature <= 0:
            # Greedy decoding (avoids division by zero at temperature=0)
            next_id = logits.argmax(dim=-1, keepdim=True)
        else:
            probs = F.softmax(logits / temperature, dim=-1)
            next_id = torch.multinomial(probs, 1)
        x = torch.cat([x, next_id], dim=1)

        if next_id.item() == tokenizer.token_to_id("<|eos|>"):
            break

    return tokenizer.decode(x[0].tolist())

Next: Stage 2 — Instruction Tuning, where we teach the model to follow instructions.


Stage 2: Instruction Tuning

A pretrained model can complete text, but it can't follow instructions. Instruction tuning (SFT — Supervised Fine-Tuning) teaches the model to respond helpfully to user requests. This is the step that turns a "text completer" into a "chatbot."

Chat Template

First, define a chat format your model will learn:

Python
CHAT_TEMPLATE = """<|im_start|>system
{system}<|im_end|>
<|im_start|>user
{user}<|im_end|>
<|im_start|>assistant
{assistant}<|im_end|>"""

def format_conversation(system: str, user: str, assistant: str) -> str:
    return CHAT_TEMPLATE.format(
        system=system, user=user, assistant=assistant
    )

# Example
formatted = format_conversation(
    system="You are a helpful assistant.",
    user="What is gradient descent?",
    assistant="Gradient descent is an optimization algorithm..."
)

Instruction Dataset

Python
from datasets import load_dataset
from torch.utils.data import Dataset

class InstructionDataset(Dataset):
    def __init__(self, tokenizer, max_length=2048, split="train"):
        # Use a public instruction dataset
        raw = load_dataset("tatsu-lab/alpaca", split=split)
        self.samples = []
        self.tokenizer = tokenizer
        self.max_length = max_length

        for item in raw:
            if not item["output"]:
                continue
            instruction = item["instruction"]
            if item.get("input"):
                instruction += f"\n\nInput: {item['input']}"

            text = format_conversation(
                system="You are a helpful assistant.",
                user=instruction,
                assistant=item["output"],
            )
            ids = tokenizer.encode(text).ids
            # Tokenize the prompt alone to find where the response starts
            prompt = text[: text.rindex(item["output"])]
            prompt_len = len(tokenizer.encode(prompt).ids)
            if len(ids) <= max_length:
                self.samples.append((ids, prompt_len))

        print(f"Instruction dataset: {len(self.samples)} samples")

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        ids, prompt_len = self.samples[idx]

        # Pad to max_length
        padded = ids + [self.tokenizer.token_to_id("<|pad|>")] * (self.max_length - len(ids))
        x = torch.tensor(padded[:-1], dtype=torch.long)
        y = torch.tensor(padded[1:], dtype=torch.long)

        # Only compute loss on the assistant's response: mask the
        # system + user prompt and all padding with ignore_index (-1)
        y[: prompt_len - 1] = -1
        y[len(ids) - 1 :] = -1
        return x, y

LoRA: Parameter-Efficient Fine-Tuning

Full fine-tuning updates all parameters. LoRA freezes the base model and adds small trainable matrices, reducing memory by ~10x:

Python
class LoRALinear(nn.Module):
    """Low-Rank Adaptation for efficient fine-tuning."""
    def __init__(self, base_layer: nn.Linear, rank: int = 16, alpha: float = 32):
        super().__init__()
        self.base = base_layer
        self.base.weight.requires_grad = False  # Freeze base

        d_in, d_out = base_layer.in_features, base_layer.out_features
        self.lora_A = nn.Parameter(torch.randn(d_in, rank) * 0.01)
        self.lora_B = nn.Parameter(torch.zeros(rank, d_out))  # Zero init: starts as a no-op
        self.scale = alpha / rank

    def forward(self, x):
        base_out = self.base(x)
        lora_out = (x @ self.lora_A @ self.lora_B) * self.scale
        return base_out + lora_out

def apply_lora(model, rank=16, alpha=32):
    """Freeze the base model, then add LoRA to attention projection layers."""
    # Freeze everything first so only the new LoRA matrices train
    for p in model.parameters():
        p.requires_grad = False

    for name, module in model.named_modules():
        if isinstance(module, CausalSelfAttention):
            module.qkv = LoRALinear(module.qkv, rank, alpha)
            module.proj = LoRALinear(module.proj, rank, alpha)

    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total = sum(p.numel() for p in model.parameters())
    print(f"LoRA: {trainable/1e6:.1f}M trainable / {total/1e6:.1f}M total ({100*trainable/total:.1f}%)")

SFT Training Loop

Python
def instruction_tune(model, train_ds, val_ds, config):
    device = torch.device("cuda")
    model = model.to(device)

    # Only optimize LoRA parameters
    optimizer = torch.optim.AdamW(
        [p for p in model.parameters() if p.requires_grad],
        lr=2e-5,            # Much lower LR than pretraining
        weight_decay=0.01,
    )

    train_loader = DataLoader(train_ds, batch_size=4, shuffle=True)
    val_loader = DataLoader(val_ds, batch_size=4)

    for epoch in range(3):  # SFT typically needs only 1-3 epochs
        model.train()
        epoch_loss = 0
        for x, y in train_loader:
            x, y = x.to(device), y.to(device)

            with autocast(dtype=torch.bfloat16):
                _, loss = model(x, y)

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            optimizer.zero_grad()
            epoch_loss += loss.item()

        val_loss = validate(model, val_loader, device)
        print(f"Epoch {epoch+1} | Train: {epoch_loss/len(train_loader):.4f} | Val: {val_loss:.4f}")

Evaluation

After instruction tuning, evaluate instruction-following quality with manual spot checks (automated reasoning benchmarks come in Stage 3):

Python
def evaluate_instruction_following(model, tokenizer, test_prompts):
    """Manual evaluation of instruction following quality."""
    results = []
    for prompt in test_prompts:
        text = format_conversation(
            system="You are a helpful assistant.",
            user=prompt,
            assistant="",
        )
        # Remove the final <|im_end|> so the model generates the response
        text = text.rsplit("<|im_end|>", 1)[0]

        response = generate(model, tokenizer, text, max_new_tokens=256)
        results.append({"prompt": prompt, "response": response})
        print(f"Q: {prompt}")
        print(f"A: {response}\n")
    return results

test_prompts = [
    "Explain quantum computing in simple terms.",
    "Write a Python function to find the nth Fibonacci number.",
    "What are the pros and cons of remote work?",
    "Summarize the key ideas of the attention mechanism.",
]

SFT Pitfalls

Issue | Symptom | Fix
Catastrophic forgetting | Model loses general knowledge | Lower LR, use LoRA, fewer epochs
Overfitting | Val loss increases after epoch 1 | More data, higher dropout, early stopping
Template leakage | Model outputs raw `<|im_start|>` tags | Register template markers as special tokens; stop generation at `<|im_end|>`
Repetition | Model loops the same phrase | Add repetition penalty during generation
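The repetition-penalty fix from the last row can be sketched as a small transform on the logits before sampling (shown on plain floats for clarity; in `generate` you would apply the same rule to the logits tensor). The divide-positive/multiply-negative rule follows the formulation popularized by CTRL; `penalty = 1.2` is a common starting point:

```python
def apply_repetition_penalty(logits: list[float],
                             generated_ids: list[int],
                             penalty: float = 1.2) -> list[float]:
    """Make already-generated tokens less likely to be sampled again.

    Positive logits are divided by `penalty`, negative ones multiplied,
    so a seen token's score always moves down.
    """
    out = list(logits)
    for tok in set(generated_ids):
        out[tok] = out[tok] / penalty if out[tok] > 0 else out[tok] * penalty
    return out

# Tokens 0 and 1 were already generated; token 2 is untouched
print(apply_repetition_penalty([2.0, -1.0, 0.5], [0, 1], penalty=2.0))
# [1.0, -2.0, 0.5]
```

Keep the penalty modest: values much above ~1.3 start to distort fluency rather than just break loops.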

Next: Stage 3 — Reasoning, where we teach the model to think step by step.


Stage 3: Reasoning & Alignment

An instruction-tuned model follows commands, but it doesn't think. This stage teaches the model to reason step-by-step and aligns its behavior with human preferences using RLHF or DPO.

Chain-of-Thought Training Data

The key insight: if you train on data that contains explicit reasoning steps, the model learns to reason. We create "thinking" traces:

Python
COT_TEMPLATE = """<|im_start|>system
You are a helpful assistant. Think step by step before answering.<|im_end|>
<|im_start|>user
{question}<|im_end|>
<|im_start|>assistant
<think>
{reasoning}
</think>

{answer}<|im_end|>"""

# Example training data
cot_examples = [
    {
        "question": "If a train travels 120 km in 2 hours, what is its speed in m/s?",
        "reasoning": """Step 1: Find speed in km/h.
Speed = distance / time = 120 km / 2 h = 60 km/h.

Step 2: Convert km/h to m/s.
1 km = 1000 m, 1 h = 3600 s.
60 km/h = 60 × 1000 / 3600 = 16.67 m/s.""",
        "answer": "The train's speed is approximately 16.67 m/s."
    },
    {
        "question": "A store has a 25% off sale. If an item costs $80, and there's an additional 10% member discount applied after, what's the final price?",
        "reasoning": """Step 1: Apply the 25% sale discount.
25% of $80 = $20. Price after sale: $80 - $20 = $60.

Step 2: Apply the 10% member discount on the sale price.
10% of $60 = $6. Price after member discount: $60 - $6 = $54.

Step 3: The discounts are applied sequentially, not combined.
Total discount is not 35% — it's 25% then 10% of the reduced price.""",
        "answer": "The final price is $54.00."
    },
]

Rejection Sampling for Reasoning Data

Generate multiple attempts and keep only the ones that arrive at the correct answer:

Python
def generate_cot_data(model, tokenizer, problems, n_samples=8, temperature=0.7):
    """Generate chain-of-thought data via rejection sampling."""
    good_samples = []

    for problem in problems:
        prompt = f"""<|im_start|>system
Think step by step.<|im_end|>
<|im_start|>user
{problem['question']}<|im_end|>
<|im_start|>assistant
<think>
"""
        candidates = []
        for _ in range(n_samples):
            response = generate(model, tokenizer, prompt,
                                max_new_tokens=512, temperature=temperature)
            # Check if the final answer matches
            if problem["answer"] in response:
                candidates.append(response)

        if candidates:
            # Keep the shortest correct reasoning (Occam's razor)
            best = min(candidates, key=len)
            good_samples.append({
                "question": problem["question"],
                "response": best,
            })

    print(f"Generated {len(good_samples)}/{len(problems)} valid CoT samples")
    return good_samples

DPO: Direct Preference Optimization

DPO is a simpler alternative to RLHF. Instead of training a reward model, you directly optimize on preference pairs (chosen vs rejected):

Python
class DPOTrainer:
    def __init__(self, model, ref_model, tokenizer, beta=0.1, lr=5e-7):
        self.model = model
        self.ref_model = ref_model  # Frozen copy of the SFT model
        self.tokenizer = tokenizer
        self.beta = beta

        # Freeze reference model
        for p in self.ref_model.parameters():
            p.requires_grad = False

        self.optimizer = torch.optim.AdamW(
            model.parameters(), lr=lr, weight_decay=0.01
        )

    def compute_log_probs(self, model, input_ids, labels):
        """Compute average log probability of the target tokens."""
        logits, _ = model(input_ids)
        log_probs = F.log_softmax(logits, dim=-1)

        # Gather log probs for actual tokens. Clamp the -1 mask values to a
        # valid index first (gather cannot take negative indices); those
        # positions are zeroed out by the mask below anyway.
        token_log_probs = log_probs.gather(
            -1, labels.clamp(min=0).unsqueeze(-1)
        ).squeeze(-1)

        # Mask padding / prompt positions
        mask = (labels != -1).float()
        return (token_log_probs * mask).sum(-1) / mask.sum(-1)

    def dpo_loss(self, chosen_ids, chosen_labels, rejected_ids, rejected_labels):
        """DPO loss: maximize margin between chosen and rejected."""
        # Policy log probs
        pi_chosen = self.compute_log_probs(self.model, chosen_ids, chosen_labels)
        pi_rejected = self.compute_log_probs(self.model, rejected_ids, rejected_labels)

        # Reference log probs (no gradient)
        with torch.no_grad():
            ref_chosen = self.compute_log_probs(self.ref_model, chosen_ids, chosen_labels)
            ref_rejected = self.compute_log_probs(self.ref_model, rejected_ids, rejected_labels)

        # DPO objective
        pi_logratios = pi_chosen - pi_rejected
        ref_logratios = ref_chosen - ref_rejected
        logits = self.beta * (pi_logratios - ref_logratios)

        loss = -F.logsigmoid(logits).mean()
        return loss

    def train_step(self, batch):
        loss = self.dpo_loss(
            batch["chosen_ids"], batch["chosen_labels"],
            batch["rejected_ids"], batch["rejected_labels"],
        )
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
        self.optimizer.step()
        self.optimizer.zero_grad()
        return loss.item()

Preference Data Format

Python
preference_pairs = [
    {
        "prompt": "Explain recursion.",
        "chosen": "<think>\nRecursion is when a function calls itself...\n</think>\n\nRecursion is a programming concept where a function calls itself to solve smaller instances of the same problem...",
        "rejected": "Recursion is when something is recursive."
    },
    {
        "prompt": "Is water wet?",
        "chosen": "<think>\nThis is a nuanced question. 'Wet' typically means...\n</think>\n\nThis depends on how you define 'wet.' If wet means 'covered in water,' then water itself isn't wet — it *makes* things wet...",
        "rejected": "Yes, water is wet because it's a liquid."
    },
]

Evaluation: Reasoning Benchmarks

Python
def extract_answer(text: str) -> str:
    """Pull the final numeric answer after the '#### ' marker."""
    if "####" not in text:
        return ""
    return text.split("####")[-1].strip().rstrip(".")

def evaluate_reasoning(model, tokenizer, benchmark="gsm8k"):
    """Evaluate on GSM8K (grade school math)."""
    dataset = load_dataset("gsm8k", "main", split="test")
    correct = 0
    total = 0

    for item in dataset:
        prompt = f"""<|im_start|>system
Solve step by step. Put your final numerical answer after "#### ".<|im_end|>
<|im_start|>user
{item['question']}<|im_end|>
<|im_start|>assistant
<think>
"""
        response = generate(model, tokenizer, prompt, max_new_tokens=512, temperature=0.0)

        # Extract the final answer
        predicted = extract_answer(response)
        expected = item["answer"].split("####")[-1].strip()

        if predicted == expected:
            correct += 1
        total += 1

    accuracy = correct / total
    print(f"{benchmark} accuracy: {accuracy:.1%} ({correct}/{total})")
    return accuracy

Inference & Efficiency Metrics

These metrics measure how well an AI model runs on hardware and its responsiveness in production.

Throughput (Tokens Per Second)

Throughput measures the total volume of output tokens a model generates every second. High TPS is critical for high-traffic applications and batch processing.

For a given model, throughput depends on:

  • Hardware: GPU type (A100, H100), number of GPUs, interconnect bandwidth
  • Batch size: Larger batches improve throughput but increase latency
  • Quantization: INT8/INT4 quantization reduces memory and increases speed at the cost of some quality
  • Serving framework: vLLM, TensorRT-LLM, and SGLang provide optimized inference kernels
Model Size | Typical TPS (A100) | Typical TPS (H100)
7B | 40-80 | 80-150
13B | 25-50 | 50-100
70B | 8-15 | 20-40
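Numbers like these are easy to reproduce for your own stack. Below is a minimal timing sketch around any decode function (`generate_fn` is a hypothetical stand-in for your model's decode loop; on GPU, call `torch.cuda.synchronize()` before each clock read so you measure completed work, not just kernel launches):

```python
import time

def measure_tps(generate_fn, n_tokens: int = 128) -> float:
    """Decode throughput in tokens per second for a single sequence."""
    start = time.perf_counter()
    generate_fn(n_tokens)                  # must produce exactly n_tokens
    elapsed = time.perf_counter() - start
    return n_tokens / elapsed

# Stand-in "model" that takes ~1 ms per token
def fake_generate(n: int) -> None:
    for _ in range(n):
        time.sleep(0.001)

print(f"{measure_tps(fake_generate):.0f} tokens/s")
```

For batched serving, multiply by batch size to get aggregate TPS, and always report the batch size alongside the number.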

Time to First Token (TTFT)

TTFT is the delay between a user sending a prompt and seeing the very first token of the response. Sub-200ms is the standard for a "snappy" user experience.

TTFT is dominated by the prefill phase — where the model processes all input tokens in parallel to build the KV cache. Techniques to reduce TTFT:

  • Speculative decoding: Use a small draft model to propose tokens, verified by the large model
  • Prefix caching: Cache the KV states of common system prompts
  • Chunked prefill: Break long prompts into chunks to overlap with decode
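TTFT is just as easy to measure, given a streaming interface: time from request to the first yielded token. `stream_fn` below is a hypothetical generator; to measure the tutorial's `generate`, you would restructure it to `yield` tokens as they are sampled:

```python
import time

def measure_ttft(stream_fn) -> float:
    """Seconds from request start until the first streamed token arrives."""
    start = time.perf_counter()
    for _token in stream_fn():                 # generator yields decoded tokens
        return time.perf_counter() - start     # stop at the very first one
    raise RuntimeError("stream produced no tokens")

def fake_stream():
    time.sleep(0.05)                           # stand-in for the prefill phase
    yield "first"
    yield "second"

print(f"TTFT: {measure_ttft(fake_stream) * 1000:.0f} ms")
```

In production, report TTFT as a percentile (p50/p99) rather than a mean, since long prompts create a heavy tail.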

Context Window

The context window is the "short-term memory" of the model, measured in tokens. A larger window allows the AI to process entire books or massive codebases in a single pass.

Model | Context Length | ~Pages of Text
GPT-4o | 128K | ~300 pages
Claude 3.5 | 200K | ~500 pages
Gemini 1.5 Pro | 2M | ~5,000 pages

Key techniques for extending context:

  • RoPE scaling / position interpolation: Stretch rotary position embeddings to cover longer sequences with a short fine-tune
  • Sliding-window attention: Each token attends only to a local window, bounding memory (used by Mistral)
  • FlashAttention: Computes exact attention without materializing the full attention matrix

GPU & Memory Utilization

Tracks how much of the hardware's resources (VRAM) a model consumes. Lower utilization per query allows more simultaneous users.

Key optimization techniques:

  • Quantization: INT8/INT4 weights cut the memory footprint several-fold
  • PagedAttention: Allocates KV-cache memory in blocks to avoid fragmentation (vLLM)
  • Continuous batching: Admits new requests as old ones finish, keeping the GPU saturated


Quality & Intelligence Metrics

These quantify how "smart" or accurate a model's outputs are.

Perplexity

Perplexity is a mathematical measure of how "surprised" a model is by new data. Lower is better, indicating the model has a stronger internal grasp of language patterns.

Perplexity = exp(average cross-entropy loss). A perplexity of 10 means the model is, on average, "10-way uncertain" about each next token.
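That relationship is one line of code, handy as a sanity check when logging validation metrics:

```python
import math

def perplexity(mean_cross_entropy: float) -> float:
    """PPL = exp(average per-token cross-entropy, measured in nats)."""
    return math.exp(mean_cross_entropy)

# A loss of ln(10) ≈ 2.303 nats is exactly 10-way uncertainty per token
print(round(perplexity(math.log(10)), 2))   # 10.0
```

Note the direction of the mapping: small loss improvements near convergence (e.g. 2.3 → 2.2) shave only a point or so of perplexity, while the same delta early in training moves it by hundreds.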

Stage | Typical Perplexity
Early pretraining | 100-1000+
Converged pretraining | 5-15
Domain-specific fine-tune | 3-8

Important caveat: Perplexity only measures next-token prediction quality on a held-out set. A model with great perplexity can still produce bad instruction-following results.

Accuracy & F1 Score

Standard metrics for classification and extraction tasks:

  • Accuracy: Percentage of correct predictions overall
  • Precision: Of items flagged as positive, how many actually are? (Reduces false positives)
  • Recall: Of all actual positives, how many did we find? (Reduces false negatives)
  • F1 Score: The harmonic mean of precision and recall — the "gold standard" for balancing both
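These definitions map directly onto a few lines of code; a minimal sketch for binary labels:

```python
def precision_recall_f1(y_true: list[int], y_pred: list[int]) -> tuple[float, float, float]:
    """Binary precision, recall, and F1 (harmonic mean of the first two)."""
    tp = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 1)
    fp = sum(1 for t, p in zip(y_true, y_pred) if t == 0 and p == 1)
    fn = sum(1 for t, p in zip(y_true, y_pred) if t == 1 and p == 0)
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return precision, recall, f1

# One true positive, one false positive, one false negative
p, r, f1 = precision_recall_f1([1, 1, 0, 0], [1, 0, 1, 0])
print(f"P={p:.2f} R={r:.2f} F1={f1:.2f}")   # P=0.50 R=0.50 F1=0.50
```

Because F1 is a harmonic mean, it punishes imbalance: a model with precision 0.9 but recall 0.1 scores only 0.18, not 0.5.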

For LLM benchmarks, the most commonly referenced evaluations include:

  • MMLU: 57 subjects ranging from STEM to humanities
  • HumanEval: Code generation benchmark
  • GSM8K: Grade school math reasoning
  • HellaSwag: Commonsense reasoning

Elo Rating (Human Preference)

Elo rating, borrowed from chess, is used by the LMSYS Chatbot Arena to rank models based on blind side-by-side human testing. Users see two anonymous model outputs and pick the better one.

This is arguably the most reliable quality signal because:

  • It captures holistic quality (helpfulness, safety, style, accuracy)
  • It's resistant to benchmark gaming — models can't overfit to specific test sets
  • It reflects real user preferences, not proxy metrics
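The Elo update itself is two lines: an expected score from the rating gap, then a K-weighted correction. K = 32 is a common choice borrowed from chess; Chatbot Arena's published ratings use a more elaborate Bradley-Terry fit, so treat this as the textbook version:

```python
def elo_update(r_a: float, r_b: float, score_a: float, k: float = 32.0):
    """Return updated (r_a, r_b) after one comparison.

    score_a is 1.0 if A wins, 0.0 if B wins, 0.5 for a tie.
    """
    expected_a = 1.0 / (1.0 + 10 ** ((r_b - r_a) / 400.0))
    delta = k * (score_a - expected_a)
    return r_a + delta, r_b - delta

# Equal ratings, A wins: A gains exactly K/2 = 16 points
print(elo_update(1500, 1500, 1.0))   # (1516.0, 1484.0)
```

The 400-point scale means a 400-point gap predicts roughly a 10:1 win ratio; beating a much stronger model moves your rating far more than beating a weaker one.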

Hallucination Rate

The hallucination rate measures how frequently a model generates factually incorrect or unsupported information. This is one of the biggest challenges in deploying LLMs.

Two types of hallucination:

  • Intrinsic: Contradicts the provided source material
  • Extrinsic: Generates information not supported by any source

Mitigation strategies:

  • Retrieval-augmented generation (RAG): Ground responses in retrieved source documents
  • Citation requirements: Have the model attribute claims to specific sources
  • Abstention: Train or prompt the model to say "I don't know" rather than guess

Scaling Laws

Both efficiency and quality metrics improve with scale, but in predictable ways described by scaling laws:

  • Compute-optimal training (Chinchilla scaling): The optimal model size and data size grow proportionally with compute budget
  • Inference scaling: Techniques like test-time compute allow models to "think longer" on harder problems, trading latency for quality
  • Data scaling: Textbooks Are All You Need showed that high-quality data can substitute for raw scale
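The Chinchilla result compresses into a usable rule of thumb: roughly 20 training tokens per parameter at compute-optimality (the exact ratio depends on the fit used, so treat this as an estimate):

```python
def chinchilla_tokens(n_params: float, tokens_per_param: float = 20.0) -> float:
    """Rough compute-optimal token budget (Chinchilla rule of thumb)."""
    return n_params * tokens_per_param

# The ~125M model from Stage 0 would want roughly 2.5B training tokens
print(f"{chinchilla_tokens(125e6) / 1e9:.1f}B tokens")
```

By this rule, the Stage 0 budget of 100K steps × 8 batch × 4 accumulation × 2048 tokens ≈ 6.6B tokens comfortably exceeds the ~2.5B-token optimum; overtraining past the optimum is common in practice because it buys a better model per inference dollar.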

The Full Training Pipeline

Stage | What | Data | Epochs | LR | Result
0 | Preparation | Raw text corpus | n/a | n/a | Tokenizer + data pipeline
1 | Pretraining | Raw text | 1 | 3e-4 | Next-token predictor
2 | SFT | Instruction pairs | 1–3 | 2e-5 | Instruction follower
3a | CoT | Reasoning traces | 1–2 | 1e-5 | Step-by-step thinker
3b | DPO | Preference pairs | 1 | 5e-7 | Aligned reasoner

What You've Built

By completing all four stages, you've built a model that:

  1. Understands language (pretraining)
  2. Follows instructions (SFT)
  3. Thinks before answering (chain-of-thought)
  4. Prefers good answers over bad ones (DPO alignment)

This is the same pipeline used by frontier models — just at a smaller scale. The architecture, loss functions, and training stages closely mirror theirs.


This completes the "Train Your LLM from Scratch" series. For production-scale training, explore DeepSpeed, FSDP, and multi-node distributed training.

Part of Train Your LLM from Scratch