Train Your LLM from Scratch
A comprehensive guide to training a language model from scratch, from data preparation and tokenization through pretraining, instruction tuning, and reasoning/alignment with RLHF and DPO.
Prerequisites
- Python proficiency
- PyTorch basics
- Understanding of transformer architecture
Stage 0: Preparation
Before writing a single training loop, you need three things: an environment that won't crash mid-run, a tokenizer that can represent your data, and a data pipeline that feeds batches efficiently. Skip any of these and you'll waste GPU hours debugging.
Environment Setup
```bash
# Create a clean environment
conda create -n llm-train python=3.11 -y
conda activate llm-train

# Core dependencies
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121
pip install transformers datasets tokenizers wandb accelerate
pip install flash-attn --no-build-isolation  # FlashAttention-2

# Verify GPU
python -c "import torch; print(f'CUDA: {torch.cuda.is_available()}, GPUs: {torch.cuda.device_count()}')"
```
Building a Tokenizer
We'll train a BPE tokenizer from scratch using HuggingFace `tokenizers`. This is the same approach used by Llama, GPT, and Mistral.
```python
from tokenizers import Tokenizer, models, trainers, pre_tokenizers, decoders

def train_tokenizer(
    data_files: list[str],
    vocab_size: int = 32_000,
    save_path: str = "tokenizer.json"
):
    tokenizer = Tokenizer(models.BPE())
    tokenizer.pre_tokenizer = pre_tokenizers.ByteLevel(add_prefix_space=False)
    tokenizer.decoder = decoders.ByteLevel()

    trainer = trainers.BpeTrainer(
        vocab_size=vocab_size,
        special_tokens=["<|pad|>", "<|eos|>", "<|bos|>", "<|unk|>"],
        min_frequency=2,
        show_progress=True,
    )

    tokenizer.train(data_files, trainer)
    tokenizer.save(save_path)
    print(f"Trained tokenizer: {tokenizer.get_vocab_size()} tokens")
    return tokenizer

# Train on your corpus
tokenizer = train_tokenizer(["corpus_part1.txt", "corpus_part2.txt"])

# Test it
encoded = tokenizer.encode("The transformer architecture revolutionized NLP.")
print(f"Tokens: {encoded.tokens}")
print(f"IDs: {encoded.ids}")
```
Key decisions:
- Vocab size: 32K is a good default. Larger (64K) improves multilingual coverage; smaller (16K) reduces embedding parameters
- Special tokens: At minimum you need pad, eos, bos, unk. Add `<|im_start|>` and `<|im_end|>` if you plan to instruction-tune later
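With weight tying (used later in this guide), the embedding matrix is the main place vocab size costs parameters. A quick back-of-the-envelope check at d_model=768, the 125M config used below:

```python
def embedding_params(vocab_size: int, d_model: int) -> int:
    """Parameter count of the token embedding matrix (vocab_size x d_model)."""
    return vocab_size * d_model

# With d_model=768:
for v in (16_000, 32_000, 64_000):
    print(f"{v:>6} tokens -> {embedding_params(v, 768) / 1e6:.1f}M embedding params")
# 16K -> 12.3M, 32K -> 24.6M, 64K -> 49.2M
```

For a ~125M-parameter model, the jump from 16K to 64K vocab is a ~37M-parameter difference, which is why vocab size matters more at small scale than for 70B-class models.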
Data Pipeline
Efficient data loading is critical. We pack multiple documents into fixed-length sequences to maximize GPU utilization:
```python
import torch
from torch.utils.data import Dataset, DataLoader
from pathlib import Path
import numpy as np

class PretrainingDataset(Dataset):
    """Concatenates all documents and chunks into fixed-length sequences."""

    def __init__(
        self,
        data_dir: str,
        tokenizer,
        seq_length: int = 2048,
        stride: int = 2048,  # No overlap by default
    ):
        self.seq_length = seq_length
        self.stride = stride

        # Tokenize and concatenate all files
        all_ids = []
        for file in sorted(Path(data_dir).glob("*.txt")):
            text = file.read_text(encoding="utf-8")
            encoded = tokenizer.encode(text)
            all_ids.extend(encoded.ids)
            all_ids.append(tokenizer.token_to_id("<|eos|>"))

        self.data = np.array(all_ids, dtype=np.uint16)
        self.n_chunks = max(1, (len(self.data) - seq_length) // stride)
        print(f"Dataset: {len(self.data):,} tokens → {self.n_chunks:,} chunks")

    def __len__(self):
        return self.n_chunks

    def __getitem__(self, idx):
        start = idx * self.stride
        end = start + self.seq_length + 1  # +1 for target shift

        # uint16 saves RAM on disk/in memory, but must be widened before
        # handing to torch (embedding lookups need int64 indices)
        chunk = torch.from_numpy(self.data[start:end].astype(np.int64))
        x = chunk[:-1]  # Input
        y = chunk[1:]   # Target (shifted by 1)
        return x, y

def create_dataloaders(
    train_dir: str,
    val_dir: str,
    tokenizer,
    batch_size: int = 8,
    seq_length: int = 2048,
):
    train_ds = PretrainingDataset(train_dir, tokenizer, seq_length)
    val_ds = PretrainingDataset(val_dir, tokenizer, seq_length)

    train_loader = DataLoader(
        train_ds, batch_size=batch_size, shuffle=True,
        num_workers=4, pin_memory=True, drop_last=True,
    )
    val_loader = DataLoader(
        val_ds, batch_size=batch_size, shuffle=False,
        num_workers=2, pin_memory=True,
    )
    return train_loader, val_loader
```
Data Preparation Checklist
| Step | Action | Why |
|---|---|---|
| 1 | Deduplicate documents | Prevents memorization of repeated text |
| 2 | Filter low-quality text | Removes boilerplate, ads, HTML artifacts |
| 3 | Shuffle at document level | Prevents domain clustering in batches |
| 4 | Split train/val (99/1) | Val set should be representative but small |
| 5 | Tokenize and save as binary | Avoids re-tokenizing every training run |
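Step 1 of the checklist can start as simple as exact-duplicate removal. A minimal hash-based sketch (real pipelines add fuzzy dedup such as MinHash on top of this):

```python
import hashlib

def deduplicate(docs: list[str]) -> list[str]:
    """Drop exact duplicates by hashing whitespace-normalized text."""
    seen, unique = set(), []
    for doc in docs:
        key = hashlib.sha256(" ".join(doc.split()).encode("utf-8")).hexdigest()
        if key not in seen:
            seen.add(key)
            unique.append(doc)
    return unique

docs = ["Hello  world", "Hello world", "Different doc"]
print(len(deduplicate(docs)))  # 2: the first two normalize to the same text
```

Hashing the normalized text keeps memory at one digest per document, so this scales to corpora that don't fit in RAM if you stream documents through it.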
Hyperparameter Cheat Sheet
For a ~125M parameter model (good for learning):
```python
config = {
    "vocab_size": 32_000,
    "d_model": 768,
    "n_heads": 12,
    "n_layers": 12,
    "d_ff": 3072,  # 4 * d_model
    "seq_length": 2048,
    "dropout": 0.1,
    "batch_size": 8,  # Per GPU
    "gradient_accumulation": 4,
    "lr": 3e-4,
    "warmup_steps": 1000,
    "total_steps": 100_000,
    "weight_decay": 0.1,
}
```
Next: Stage 1 — Pretraining, where we build the model and write the training loop.
Stage 1: Pretraining
Pretraining is where a language model learns to predict the next token. This is the most compute-intensive phase — you're teaching the model the statistical patterns of language from raw text.
Model Architecture
We'll build a decoder-only transformer (GPT-style). This is the same architecture used by GPT, Llama, and Mistral:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class RMSNorm(nn.Module):
    """RMSNorm (used by Llama, Mistral instead of LayerNorm)."""
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        norm = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
        return x * norm * self.weight


class CausalSelfAttention(nn.Module):
    def __init__(self, d_model: int, n_heads: int, max_seq_len: int = 2048):
        super().__init__()
        assert d_model % n_heads == 0
        self.n_heads = n_heads
        self.d_k = d_model // n_heads

        self.qkv = nn.Linear(d_model, 3 * d_model, bias=False)
        self.proj = nn.Linear(d_model, d_model, bias=False)

        # Causal mask
        mask = torch.triu(torch.ones(max_seq_len, max_seq_len), diagonal=1).bool()
        self.register_buffer("mask", mask)

    def forward(self, x):
        B, T, C = x.shape

        # Project Q, K, V in one shot
        qkv = self.qkv(x).reshape(B, T, 3, self.n_heads, self.d_k)
        q, k, v = qkv.unbind(dim=2)
        q, k, v = [t.transpose(1, 2) for t in (q, k, v)]  # [B, heads, T, d_k]

        # Scaled dot-product attention with causal mask
        scores = (q @ k.transpose(-2, -1)) / math.sqrt(self.d_k)
        scores = scores.masked_fill(self.mask[:T, :T], float("-inf"))
        attn = F.softmax(scores, dim=-1)

        out = (attn @ v).transpose(1, 2).reshape(B, T, C)
        return self.proj(out)


class FeedForward(nn.Module):
    """SwiGLU feedforward (used by Llama, Mistral)."""
    def __init__(self, d_model: int, d_ff: int):
        super().__init__()
        self.w1 = nn.Linear(d_model, d_ff, bias=False)
        self.w2 = nn.Linear(d_ff, d_model, bias=False)
        self.w3 = nn.Linear(d_model, d_ff, bias=False)

    def forward(self, x):
        return self.w2(F.silu(self.w1(x)) * self.w3(x))


class TransformerBlock(nn.Module):
    def __init__(self, d_model, n_heads, d_ff):
        super().__init__()
        self.norm1 = RMSNorm(d_model)
        self.attn = CausalSelfAttention(d_model, n_heads)
        self.norm2 = RMSNorm(d_model)
        self.ff = FeedForward(d_model, d_ff)

    def forward(self, x):
        x = x + self.attn(self.norm1(x))
        x = x + self.ff(self.norm2(x))
        return x


class GPT(nn.Module):
    def __init__(self, vocab_size, d_model, n_heads, n_layers, d_ff, max_seq_len):
        super().__init__()
        self.tok_emb = nn.Embedding(vocab_size, d_model)
        self.pos_emb = nn.Embedding(max_seq_len, d_model)
        self.blocks = nn.ModuleList([
            TransformerBlock(d_model, n_heads, d_ff) for _ in range(n_layers)
        ])
        self.norm = RMSNorm(d_model)
        self.head = nn.Linear(d_model, vocab_size, bias=False)

        # Weight tying
        self.head.weight = self.tok_emb.weight

        n_params = sum(p.numel() for p in self.parameters())
        print(f"Model parameters: {n_params / 1e6:.1f}M")

    def forward(self, idx, targets=None):
        B, T = idx.shape
        pos = torch.arange(T, device=idx.device)

        x = self.tok_emb(idx) + self.pos_emb(pos)
        for block in self.blocks:
            x = block(x)
        x = self.norm(x)
        logits = self.head(x)

        loss = None
        if targets is not None:
            loss = F.cross_entropy(
                logits.view(-1, logits.size(-1)),
                targets.view(-1),
                ignore_index=-1,
            )
        return logits, loss
```
Training Loop
Here's the complete training loop with mixed precision, gradient accumulation, and cosine LR schedule:
```python
from torch.cuda.amp import autocast, GradScaler
import wandb

def train(model, train_loader, val_loader, config):
    device = torch.device("cuda")
    model = model.to(device)

    optimizer = torch.optim.AdamW(
        model.parameters(),
        lr=config["lr"],
        betas=(0.9, 0.95),
        weight_decay=config["weight_decay"],
    )

    # Cosine LR with warmup, decaying to 10% of peak
    def lr_schedule(step):
        if step < config["warmup_steps"]:
            return step / config["warmup_steps"]
        progress = (step - config["warmup_steps"]) / (config["total_steps"] - config["warmup_steps"])
        return 0.1 + 0.9 * 0.5 * (1 + math.cos(math.pi * progress))

    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_schedule)
    # Loss scaling matters for fp16; with bf16 autocast it is a harmless no-op safeguard
    scaler = GradScaler()

    wandb.init(project="llm-from-scratch", config=config)

    global_step = 0
    for epoch in range(config.get("epochs", 1)):
        model.train()
        for batch_idx, (x, y) in enumerate(train_loader):
            x, y = x.to(device), y.to(device)

            with autocast(dtype=torch.bfloat16):
                _, loss = model(x, y)
                loss = loss / config["gradient_accumulation"]

            scaler.scale(loss).backward()

            if (batch_idx + 1) % config["gradient_accumulation"] == 0:
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()
                scheduler.step()
                global_step += 1

                if global_step % 100 == 0:
                    wandb.log({
                        "train/loss": loss.item() * config["gradient_accumulation"],
                        "train/lr": scheduler.get_last_lr()[0],
                        "train/step": global_step,
                    })

                if global_step % 1000 == 0:
                    val_loss = validate(model, val_loader, device)
                    wandb.log({"val/loss": val_loss, "val/perplexity": math.exp(val_loss)})
                    print(f"Step {global_step} | Val loss: {val_loss:.4f} | PPL: {math.exp(val_loss):.2f}")
                    save_checkpoint(model, optimizer, global_step)

                if global_step >= config["total_steps"]:
                    return
```
Validation
Perplexity is the standard metric for pretraining — lower is better. A perplexity of 20 means the model is, on average, as uncertain as if it were choosing uniformly among 20 tokens.
```python
@torch.no_grad()
def validate(model, val_loader, device, max_batches=50):
    model.eval()
    total_loss = 0
    n_batches = 0

    for x, y in val_loader:
        x, y = x.to(device), y.to(device)
        with autocast(dtype=torch.bfloat16):
            _, loss = model(x, y)
        total_loss += loss.item()
        n_batches += 1
        if n_batches >= max_batches:
            break

    model.train()
    return total_loss / n_batches

def save_checkpoint(model, optimizer, step, path="checkpoints"):
    Path(path).mkdir(exist_ok=True)
    torch.save({
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
        "step": step,
    }, f"{path}/step_{step}.pt")
```
What to Watch For
| Metric | Healthy | Unhealthy |
|---|---|---|
| Training loss | Smooth downward curve | Spikes, plateaus early |
| Val loss | Tracks train loss closely | Diverges from train loss |
| Gradient norm | Stable around 0.1–1.0 | Exploding (>10) or vanishing |
| Learning rate | Smooth warmup → cosine decay | — |
| Perplexity | Steadily decreasing | Stuck above 100 after 10K steps |
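To log the gradient-norm row above, note that `clip_grad_norm_` returns the global norm measured before clipping, so one call both clips and produces the value to log. A small sketch with a layer whose gradient norm is known analytically:

```python
import torch
import torch.nn as nn

def clip_and_log(model: nn.Module, max_norm: float = 1.0) -> float:
    """Clip gradients in place and return the pre-clip global norm for logging."""
    total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
    return float(total_norm)

# Tiny demo: for loss = sum(w @ x) with x = [1, 1], grad w.r.t. w is [1, 1],
# so the global norm is sqrt(2) ≈ 1.414
layer = nn.Linear(2, 1, bias=False)
layer.weight.data.fill_(1.0)
layer(torch.ones(1, 2)).sum().backward()
print(f"grad norm: {clip_and_log(layer):.3f}")  # grad norm: 1.414
```

In the training loop above you would log this value to wandb alongside the loss; a sustained climb toward 10+ is an early warning of divergence.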
Generating Text (Sanity Check)
```python
@torch.no_grad()
def generate(model, tokenizer, prompt, max_new_tokens=100, temperature=0.8):
    model.eval()
    device = next(model.parameters()).device
    ids = tokenizer.encode(prompt).ids
    x = torch.tensor([ids], device=device)

    for _ in range(max_new_tokens):
        logits, _ = model(x[:, -2048:])  # Truncate to max seq len
        logits = logits[:, -1, :]
        if temperature <= 0:
            next_id = logits.argmax(dim=-1, keepdim=True)  # Greedy decoding
        else:
            probs = F.softmax(logits / temperature, dim=-1)
            next_id = torch.multinomial(probs, 1)
        x = torch.cat([x, next_id], dim=1)

        if next_id.item() == tokenizer.token_to_id("<|eos|>"):
            break

    return tokenizer.decode(x[0].tolist())
```
Next: Stage 2 — Instruction Tuning, where we teach the model to follow instructions.
Stage 2: Instruction Tuning
A pretrained model can complete text, but it can't follow instructions. Instruction tuning (SFT — Supervised Fine-Tuning) teaches the model to respond helpfully to user requests. This is the step that turns a "text completer" into a "chatbot."
Chat Template
First, define a chat format your model will learn:
```python
CHAT_TEMPLATE = """<|im_start|>system
{system}<|im_end|>
<|im_start|>user
{user}<|im_end|>
<|im_start|>assistant
{assistant}<|im_end|>"""

def format_conversation(system: str, user: str, assistant: str) -> str:
    return CHAT_TEMPLATE.format(
        system=system, user=user, assistant=assistant
    )

# Example
formatted = format_conversation(
    system="You are a helpful assistant.",
    user="What is gradient descent?",
    assistant="Gradient descent is an optimization algorithm..."
)
```
Instruction Dataset
```python
from datasets import load_dataset
from torch.utils.data import Dataset

class InstructionDataset(Dataset):
    def __init__(self, tokenizer, max_length=2048, split="train"):
        # Use a public instruction dataset
        raw = load_dataset("tatsu-lab/alpaca", split=split)
        self.samples = []
        self.tokenizer = tokenizer
        self.max_length = max_length

        for item in raw:
            instruction = item["instruction"]
            if item.get("input"):
                instruction += f"\n\nInput: {item['input']}"

            text = format_conversation(
                system="You are a helpful assistant.",
                user=instruction,
                assistant=item["output"],
            )
            ids = tokenizer.encode(text).ids
            if len(ids) <= max_length:
                self.samples.append(ids)

        print(f"Instruction dataset: {len(self.samples)} samples")

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        ids = self.samples[idx]

        # Pad to max_length
        pad_id = self.tokenizer.token_to_id("<|pad|>")
        padded = ids + [pad_id] * (self.max_length - len(ids))
        x = torch.tensor(padded[:-1], dtype=torch.long)
        y = torch.tensor(padded[1:], dtype=torch.long)

        # Mask padding so it contributes no loss (cross_entropy uses ignore_index=-1)
        y[len(ids) - 1:] = -1
        # Mask the prompt: only compute loss on the assistant's response.
        # (Simplified heuristic; a real implementation locates the
        # <|im_start|>assistant marker and masks everything before it.)
        y[:len(ids) // 2] = -1
        return x, y
```
LoRA: Parameter-Efficient Fine-Tuning
Full fine-tuning updates all parameters. LoRA freezes the base model and adds small trainable matrices, reducing memory by ~10x:
```python
class LoRALinear(nn.Module):
    """Low-Rank Adaptation for efficient fine-tuning."""
    def __init__(self, base_layer: nn.Linear, rank: int = 16, alpha: float = 32):
        super().__init__()
        self.base = base_layer
        self.base.weight.requires_grad = False  # Freeze base

        d_in, d_out = base_layer.in_features, base_layer.out_features
        self.lora_A = nn.Parameter(torch.randn(d_in, rank) * 0.01)
        self.lora_B = nn.Parameter(torch.zeros(rank, d_out))
        self.scale = alpha / rank

    def forward(self, x):
        base_out = self.base(x)
        lora_out = (x @ self.lora_A @ self.lora_B) * self.scale
        return base_out + lora_out

def apply_lora(model, rank=16, alpha=32):
    """Freeze the whole base model, then add LoRA to attention projections."""
    for p in model.parameters():
        p.requires_grad = False  # Only the LoRA matrices added below stay trainable

    for name, module in model.named_modules():
        if isinstance(module, CausalSelfAttention):
            module.qkv = LoRALinear(module.qkv, rank, alpha)
            module.proj = LoRALinear(module.proj, rank, alpha)

    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total = sum(p.numel() for p in model.parameters())
    print(f"LoRA: {trainable/1e6:.1f}M trainable / {total/1e6:.1f}M total ({100*trainable/total:.1f}%)")
```
SFT Training Loop
```python
def instruction_tune(model, train_ds, val_ds, config):
    device = torch.device("cuda")
    model = model.to(device)

    # Only optimize LoRA parameters
    optimizer = torch.optim.AdamW(
        [p for p in model.parameters() if p.requires_grad],
        lr=2e-5,  # Much lower LR than pretraining
        weight_decay=0.01,
    )

    train_loader = DataLoader(train_ds, batch_size=4, shuffle=True)
    val_loader = DataLoader(val_ds, batch_size=4)

    for epoch in range(3):  # SFT typically needs only 1-3 epochs
        model.train()
        epoch_loss = 0
        for x, y in train_loader:
            x, y = x.to(device), y.to(device)

            with autocast(dtype=torch.bfloat16):
                _, loss = model(x, y)

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
            optimizer.zero_grad()
            epoch_loss += loss.item()

        val_loss = validate(model, val_loader, device)
        print(f"Epoch {epoch+1} | Train: {epoch_loss/len(train_loader):.4f} | Val: {val_loss:.4f}")
```
Evaluation
After instruction tuning, evaluate on standard benchmarks:
```python
def evaluate_instruction_following(model, tokenizer, test_prompts):
    """Manual evaluation of instruction following quality."""
    results = []
    for prompt in test_prompts:
        text = format_conversation(
            system="You are a helpful assistant.",
            user=prompt,
            assistant="",
        )
        # Remove the final <|im_end|> so the model generates the response
        text = text.rsplit("<|im_end|>", 1)[0]

        response = generate(model, tokenizer, text, max_new_tokens=256)
        results.append({"prompt": prompt, "response": response})
        print(f"Q: {prompt}")
        print(f"A: {response}\n")
    return results

test_prompts = [
    "Explain quantum computing in simple terms.",
    "Write a Python function to find the nth Fibonacci number.",
    "What are the pros and cons of remote work?",
    "Summarize the key ideas of the attention mechanism.",
]
```
SFT Pitfalls
| Issue | Symptom | Fix |
|---|---|---|
| Catastrophic forgetting | Model loses general knowledge | Lower LR, use LoRA, fewer epochs |
| Overfitting | Val loss increases after epoch 1 | More data, higher dropout, early stopping |
| Template leakage | Model outputs literal `<\|im_start\|>` markers in responses | Register template tokens as special tokens; strip them when decoding |
| Repetition | Model loops the same phrase | Add repetition penalty during generation |
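The repetition-penalty fix from the last row can be sketched as a sampling-time tweak: down-weight any token that already appears in the context (the divide-positive/multiply-negative form popularized by CTRL; a simplified single-sequence sketch):

```python
import torch

def apply_repetition_penalty(logits: torch.Tensor, generated_ids: list[int],
                             penalty: float = 1.2) -> torch.Tensor:
    """Penalize already-seen tokens: positive logits divided, negative multiplied."""
    logits = logits.clone()
    for tok in set(generated_ids):
        if logits[tok] > 0:
            logits[tok] /= penalty   # Shrink toward zero
        else:
            logits[tok] *= penalty   # Push further below zero
    return logits

logits = torch.tensor([2.0, -1.0, 0.5])
out = apply_repetition_penalty(logits, [0, 1])  # Tokens 0 and 1 already generated
```

Dropped into the `generate` loop above, this would run on `logits[:, -1, :]` just before the temperature division; a penalty of 1.1 to 1.3 is a common range.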
Next: Stage 3 — Reasoning, where we teach the model to think step by step.
Stage 3: Reasoning & Alignment
An instruction-tuned model follows commands, but it doesn't think. This stage teaches the model to reason step-by-step and aligns its behavior with human preferences using RLHF or DPO.
Chain-of-Thought Training Data
The key insight: if you train on data that contains explicit reasoning steps, the model learns to reason. We create "thinking" traces:
```python
COT_TEMPLATE = """<|im_start|>system
You are a helpful assistant. Think step by step before answering.<|im_end|>
<|im_start|>user
{question}<|im_end|>
<|im_start|>assistant
<think>
{reasoning}
</think>

{answer}<|im_end|>"""

# Example training data
cot_examples = [
    {
        "question": "If a train travels 120 km in 2 hours, what is its speed in m/s?",
        "reasoning": """Step 1: Find speed in km/h.
Speed = distance / time = 120 km / 2 h = 60 km/h.

Step 2: Convert km/h to m/s.
1 km = 1000 m, 1 h = 3600 s.
60 km/h = 60 × 1000 / 3600 = 16.67 m/s.""",
        "answer": "The train's speed is approximately 16.67 m/s."
    },
    {
        "question": "A store has a 25% off sale. If an item costs $80, and there's an additional 10% member discount applied after, what's the final price?",
        "reasoning": """Step 1: Apply the 25% sale discount.
25% of $80 = $20. Price after sale: $80 - $20 = $60.

Step 2: Apply the 10% member discount on the sale price.
10% of $60 = $6. Price after member discount: $60 - $6 = $54.

Step 3: The discounts are applied sequentially, not combined.
Total discount is not 35% — it's 25% then 10% of the reduced price.""",
        "answer": "The final price is $54.00."
    },
]
```
Rejection Sampling for Reasoning Data
Generate multiple attempts and keep only the ones that arrive at the correct answer:
```python
def generate_cot_data(model, tokenizer, problems, n_samples=8, temperature=0.7):
    """Generate chain-of-thought data via rejection sampling."""
    good_samples = []

    for problem in problems:
        prompt = f"""<|im_start|>system
Think step by step.<|im_end|>
<|im_start|>user
{problem['question']}<|im_end|>
<|im_start|>assistant
<think>
"""
        candidates = []
        for _ in range(n_samples):
            response = generate(model, tokenizer, prompt,
                                max_new_tokens=512, temperature=temperature)
            # Check if the final answer matches
            if problem["answer"] in response:
                candidates.append(response)

        if candidates:
            # Keep the shortest correct reasoning (Occam's razor)
            best = min(candidates, key=len)
            good_samples.append({
                "question": problem["question"],
                "response": best,
            })

    print(f"Generated {len(good_samples)}/{len(problems)} valid CoT samples")
    return good_samples
```
DPO: Direct Preference Optimization
DPO is a simpler alternative to RLHF. Instead of training a reward model, you directly optimize on preference pairs (chosen vs rejected):
```python
class DPOTrainer:
    def __init__(self, model, ref_model, tokenizer, beta=0.1, lr=5e-7):
        self.model = model
        self.ref_model = ref_model  # Frozen copy of the SFT model
        self.tokenizer = tokenizer
        self.beta = beta

        # Freeze reference model
        for p in self.ref_model.parameters():
            p.requires_grad = False

        self.optimizer = torch.optim.AdamW(
            model.parameters(), lr=lr, weight_decay=0.01
        )

    def compute_log_probs(self, model, input_ids, labels):
        """Mean log-probability of the target tokens.

        `labels` are assumed to be the next-token targets (inputs shifted
        left by one), with -1 at prompt/padding positions.
        """
        logits, _ = model(input_ids)
        log_probs = F.log_softmax(logits, dim=-1)

        # Gather log probs for actual tokens; clamp masked -1 labels to a
        # valid index first (gather rejects negative indices)
        safe_labels = labels.clamp_min(0)
        token_log_probs = log_probs.gather(-1, safe_labels.unsqueeze(-1)).squeeze(-1)

        # Mask prompt/padding positions
        mask = (labels != -1).float()
        return (token_log_probs * mask).sum(-1) / mask.sum(-1)

    def dpo_loss(self, chosen_ids, chosen_labels, rejected_ids, rejected_labels):
        """DPO loss: maximize margin between chosen and rejected."""
        # Policy log probs
        pi_chosen = self.compute_log_probs(self.model, chosen_ids, chosen_labels)
        pi_rejected = self.compute_log_probs(self.model, rejected_ids, rejected_labels)

        # Reference log probs (no gradient)
        with torch.no_grad():
            ref_chosen = self.compute_log_probs(self.ref_model, chosen_ids, chosen_labels)
            ref_rejected = self.compute_log_probs(self.ref_model, rejected_ids, rejected_labels)

        # DPO objective
        pi_logratios = pi_chosen - pi_rejected
        ref_logratios = ref_chosen - ref_rejected
        logits = self.beta * (pi_logratios - ref_logratios)

        loss = -F.logsigmoid(logits).mean()
        return loss

    def train_step(self, batch):
        loss = self.dpo_loss(
            batch["chosen_ids"], batch["chosen_labels"],
            batch["rejected_ids"], batch["rejected_labels"],
        )
        loss.backward()
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
        self.optimizer.step()
        self.optimizer.zero_grad()
        return loss.item()
```
Preference Data Format
```python
preference_pairs = [
    {
        "prompt": "Explain recursion.",
        "chosen": "<think>\nRecursion is when a function calls itself...\n</think>\n\nRecursion is a programming concept where a function calls itself to solve smaller instances of the same problem...",
        "rejected": "Recursion is when something is recursive."
    },
    {
        "prompt": "Is water wet?",
        "chosen": "<think>\nThis is a nuanced question. 'Wet' typically means...\n</think>\n\nThis depends on how you define 'wet.' If wet means 'covered in water,' then water itself isn't wet — it *makes* things wet...",
        "rejected": "Yes, water is wet because it's a liquid."
    },
]
```
Evaluation: Reasoning Benchmarks
```python
def extract_answer(text: str) -> str:
    """Minimal answer parser: grab the first token after the last '####'."""
    if "####" not in text:
        return ""
    tail = text.split("####")[-1].strip()
    return tail.split()[0].rstrip(".,") if tail else ""

def evaluate_reasoning(model, tokenizer, benchmark="gsm8k"):
    """Evaluate on GSM8K (grade school math)."""
    dataset = load_dataset("gsm8k", "main", split="test")
    correct = 0
    total = 0

    for item in dataset:
        prompt = f"""<|im_start|>system
Solve step by step. Put your final numerical answer after "#### ".<|im_end|>
<|im_start|>user
{item['question']}<|im_end|>
<|im_start|>assistant
<think>
"""
        # Near-greedy decoding (generate() divides logits by temperature,
        # so exactly 0 would divide by zero)
        response = generate(model, tokenizer, prompt, max_new_tokens=512, temperature=1e-4)

        # Extract the final answer
        predicted = extract_answer(response)
        expected = item["answer"].split("####")[-1].strip()

        if predicted == expected:
            correct += 1
        total += 1

    accuracy = correct / total
    print(f"{benchmark} accuracy: {accuracy:.1%} ({correct}/{total})")
    return accuracy
```
Inference & Efficiency Metrics
These metrics measure how well an AI model runs on hardware and its responsiveness in production.
Throughput (Tokens Per Second)
Throughput measures the total volume of output tokens a model generates every second. High TPS is critical for high-traffic applications and batch processing.
For a given model, throughput depends on:
- Hardware: GPU type (A100, H100), number of GPUs, interconnect bandwidth
- Batch size: Larger batches improve throughput but increase latency
- Quantization: INT8/INT4 quantization reduces memory and increases speed at the cost of some quality
- Serving framework: vLLM, TensorRT-LLM, and SGLang provide optimized inference kernels
| Model Size | Typical TPS (A100) | Typical TPS (H100) |
|---|---|---|
| 7B | 40-80 | 80-150 |
| 13B | 25-50 | 50-100 |
| 70B | 8-15 | 20-40 |
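A throughput number is easy to measure yourself. A minimal timing harness, where the `generate_fn` callable is a hypothetical stand-in for your real generation call:

```python
import time

def measure_throughput(generate_fn, n_tokens: int) -> float:
    """Output tokens per second for a callable that generates n_tokens tokens."""
    start = time.perf_counter()
    generate_fn(n_tokens)
    elapsed = time.perf_counter() - start
    return n_tokens / elapsed

# Stand-in for a real model call: pretend each token takes ~1 ms
tps = measure_throughput(lambda n: time.sleep(0.001 * n), 100)
print(f"{tps:.0f} tokens/sec")
```

For real serving numbers, measure at realistic batch sizes and prompt lengths: throughput under concurrent load is what the table above reflects, not single-request speed.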
Time to First Token (TTFT)
TTFT is the delay between a user sending a prompt and seeing the first token of the response. Sub-200ms is the standard for a "snappy" user experience.
TTFT is dominated by the prefill phase, where the model processes all input tokens in parallel and builds the KV cache. Techniques to reduce TTFT:
- Speculative decoding: Use a small draft model to propose tokens, verified by the large model
- Prefix caching: Cache the KV states of common system prompts
- Chunked prefill: Break long prompts into chunks to overlap with decode
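TTFT is just as easy to measure when your serving stack streams tokens. A sketch using a generator as a hypothetical stand-in for a streaming API:

```python
import time

def time_to_first_token(token_stream) -> float:
    """Seconds from request start until the first token arrives."""
    start = time.perf_counter()
    next(iter(token_stream))  # Block until the first token
    return time.perf_counter() - start

def fake_stream():
    time.sleep(0.05)  # Simulated prefill latency
    yield "Hello"
    yield "world"

print(f"TTFT: {time_to_first_token(fake_stream()) * 1000:.0f} ms")
```

The same pattern works against any streaming endpoint: start the clock when the request is sent and stop it at the first streamed chunk.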
Context Window
The context window is the "short-term memory" of the model, measured in tokens. A larger window allows the AI to process entire books or massive codebases in a single pass.
| Model | Context Length | ~Pages of Text |
|---|---|---|
| GPT-4o | 128K | ~300 pages |
| Claude 3.5 | 200K | ~500 pages |
| Gemini 1.5 Pro | 2M | ~5,000 pages |
Key techniques for extending context:
- RoPE scaling: Rotary Position Embeddings with frequency scaling
- Ring Attention: Distributes long sequences across GPUs
- Sliding Window Attention: Used by Mistral to limit attention to local windows
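One concrete RoPE-scaling recipe is linear position interpolation: divide positions by a scale factor so a longer context reuses the angle range seen during training. A minimal sketch (the frequency setup follows the standard RoPE formula; `scale` is the interpolation factor):

```python
import numpy as np

def rope_angles(positions, d_head: int = 64, base: float = 10000.0, scale: float = 1.0):
    """RoPE rotation angles; scale > 1 compresses positions (linear interpolation)."""
    inv_freq = 1.0 / base ** (np.arange(0, d_head, 2) / d_head)
    return np.outer(np.asarray(positions) / scale, inv_freq)

# With 2x interpolation, position 4096 sees the same angles position 2048
# saw during training, so the model stays in-distribution
a = rope_angles([4096], scale=2.0)
b = rope_angles([2048], scale=1.0)
print(np.allclose(a, b))  # True
```

In practice interpolation is usually followed by a short fine-tune on long sequences; frequency-aware variants (e.g. NTK-aware scaling, YaRN) adjust `inv_freq` per band instead of scaling all positions uniformly.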
GPU & Memory Utilization
Tracks how much hardware resources (VRAM) the model consumes. Lower utilization per query allows for more simultaneous users.
Key optimization techniques:
- FlashAttention-2: Reduces memory from O(n^2) to O(n) for attention computation
- PagedAttention: Used by vLLM, manages KV-cache memory like OS virtual memory pages
- Continuous batching: Dynamically adds/removes requests from running batches
- Model parallelism: Tensor, pipeline, and expert parallelism for large models
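The KV cache is what usually limits simultaneous users, and its size is easy to estimate: 2 tensors (K and V) per layer, each heads × head_dim × sequence × batch. A sketch using Llama-2-7B-like shapes (32 layers, 32 heads, head dim 128) at fp16:

```python
def kv_cache_bytes(n_layers: int, n_heads: int, d_head: int,
                   seq_len: int, batch: int, dtype_bytes: int = 2) -> int:
    """KV-cache size: 2 (K and V) x layers x heads x head_dim x seq x batch x bytes."""
    return 2 * n_layers * n_heads * d_head * seq_len * batch * dtype_bytes

# 7B-class model at 4K context, batch 1, fp16:
gb = kv_cache_bytes(32, 32, 128, 4096, 1) / 2**30
print(f"{gb:.1f} GiB")  # 2.0 GiB
```

At batch 16 that same cache is 32 GiB, more than the weights themselves, which is why PagedAttention-style cache management and grouped-query attention matter so much for serving.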
Quality & Intelligence Metrics
These quantify how "smart" or accurate a model's outputs are.
Perplexity
Perplexity is a mathematical measure of how "surprised" a model is by new data. Lower is better, indicating the model has a stronger internal grasp of language patterns.
Perplexity = exp(average cross-entropy loss). A perplexity of 10 means the model is, on average, "10-way uncertain" about each next token.
| Stage | Typical Perplexity |
|---|---|
| Early pretraining | 100-1000+ |
| Converged pretraining | 5-15 |
| Domain-specific fine-tune | 3-8 |
Important caveat: Perplexity only measures next-token prediction quality on a held-out set. A model with great perplexity can still produce bad instruction-following results.
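The formula above in code, using the fact that a model that is exactly uniform over V tokens has perplexity V:

```python
import math
import torch
import torch.nn.functional as F

def perplexity(logits: torch.Tensor, targets: torch.Tensor) -> float:
    """exp(mean cross-entropy) over the target tokens."""
    return math.exp(F.cross_entropy(logits, targets).item())

# Uniform logits over a 10-token vocabulary -> cross-entropy log(10) -> PPL 10
V = 10
logits = torch.zeros(4, V)           # All-zero logits = uniform distribution
targets = torch.randint(0, V, (4,))
print(round(perplexity(logits, targets)))  # 10
```

This is exactly what the training loop above logs as `val/perplexity`: `math.exp(val_loss)`.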
Accuracy & F1 Score
Standard metrics for classification and extraction tasks:
- Accuracy: Percentage of correct predictions overall
- Precision: Of items flagged as positive, how many actually are? (Reduces false positives)
- Recall: Of all actual positives, how many did we find? (Reduces false negatives)
- F1 Score: The harmonic mean of precision and recall, the "gold standard" for balancing both
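These definitions in code, for the binary case:

```python
def precision_recall_f1(preds: list[int], labels: list[int]):
    """Binary precision/recall/F1 from parallel 0/1 prediction and label lists."""
    tp = sum(p == 1 and l == 1 for p, l in zip(preds, labels))  # True positives
    fp = sum(p == 1 and l == 0 for p, l in zip(preds, labels))  # False positives
    fn = sum(p == 0 and l == 1 for p, l in zip(preds, labels))  # False negatives
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return precision, recall, f1

p, r, f1 = precision_recall_f1([1, 1, 0, 0], [1, 0, 1, 0])
print(p, r, f1)  # 0.5 0.5 0.5
```

The harmonic mean punishes imbalance: a classifier with precision 1.0 but recall 0.1 scores F1 ≈ 0.18, not the 0.55 an arithmetic mean would suggest.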
For LLM benchmarks, the most commonly referenced evaluations include:
- MMLU: 57 subjects ranging from STEM to humanities
- HumanEval: Code generation benchmark
- GSM8K: Grade school math reasoning
- HellaSwag: Commonsense reasoning
Elo Rating (Human Preference)
Elo rating, borrowed from chess, is used by the LMSYS Chatbot Arena to rank models based on blind side-by-side human testing. Users see two anonymous model outputs and pick the better one.
This is arguably the most reliable quality signal because:
- It captures holistic quality (helpfulness, safety, style, accuracy)
- It's resistant to benchmark gaming: models can't overfit to specific test sets
- It reflects real user preferences, not proxy metrics
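The Elo update itself is two lines. A sketch of one pairwise comparison (K=32 is a common default from chess; Chatbot Arena fits a more robust Bradley-Terry-style model over all battles rather than updating sequentially):

```python
def elo_update(r_a: float, r_b: float, score_a: float, k: float = 32.0):
    """One Elo update; score_a is 1 if A wins, 0 if B wins, 0.5 for a tie."""
    expected_a = 1.0 / (1.0 + 10 ** ((r_b - r_a) / 400.0))
    r_a_new = r_a + k * (score_a - expected_a)
    r_b_new = r_b + k * ((1.0 - score_a) - (1.0 - expected_a))
    return r_a_new, r_b_new

# An upset win by the lower-rated model moves both ratings the most
print(elo_update(1000.0, 1200.0, score_a=1.0))
```

Because the update depends on the *expected* outcome, beating a much stronger model earns more rating than beating a weak one, and total rating is conserved across the pair.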
Hallucination Rate
The hallucination rate measures how frequently a model generates factually incorrect or unsupported information. This is one of the biggest challenges in deploying LLMs.
Two types of hallucination:
- Intrinsic: Contradicts the provided source material
- Extrinsic: Generates information not supported by any source
Mitigation strategies:
- Retrieval-Augmented Generation (RAG): Ground responses in retrieved documents
- Chain-of-thought prompting: Force step-by-step reasoning
- Citation training: Train models to cite sources for their factual claims
- Confidence calibration: Train models to say "I don't know" when uncertain
Scaling Laws
Both efficiency and quality metrics improve with scale, but in predictable ways described by scaling laws:
- Compute-optimal training (Chinchilla scaling): The optimal model size and data size grow proportionally with compute budget
- Inference scaling: Techniques like test-time compute allow models to "think longer" on harder problems, trading latency for quality
- Data scaling: Textbooks Are All You Need showed that high-quality data can substitute for raw scale
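The Chinchilla result is often reduced to a rule of thumb of roughly 20 training tokens per parameter, which makes the compute-optimal data budget easy to estimate (a heuristic, not the full scaling-law fit):

```python
def chinchilla_tokens(n_params: float) -> float:
    """Compute-optimal training tokens via the ~20 tokens-per-parameter heuristic."""
    return 20.0 * n_params

for params in (125e6, 1e9, 7e9):
    print(f"{params/1e9:.3g}B params -> {chinchilla_tokens(params)/1e9:.1f}B tokens")
# 0.125B -> 2.5B, 1B -> 20B, 7B -> 140B
```

So the 125M model in this guide is compute-optimal around 2.5B training tokens; note that models intended for cheap inference (e.g. Llama) are deliberately trained far past this point.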
The Full Training Pipeline
| Stage | What | Data | Epochs | LR | Result |
|---|---|---|---|---|---|
| 0 | Preparation | Raw text corpus | — | — | Tokenizer + data pipeline |
| 1 | Pretraining | Raw text | 1 | 3e-4 | Next-token predictor |
| 2 | SFT | Instruction pairs | 1–3 | 2e-5 | Instruction follower |
| 3a | CoT | Reasoning traces | 1–2 | 1e-5 | Step-by-step thinker |
| 3b | DPO | Preference pairs | 1 | 5e-7 | Aligned reasoner |
What You've Built
By completing all four stages, you've built a model that:
- Understands language (pretraining)
- Follows instructions (SFT)
- Thinks before answering (chain-of-thought)
- Prefers good answers over bad ones (DPO alignment)
This is the same pipeline used by frontier models — just at a smaller scale. The architecture, loss functions, and training stages are identical.
This completes the "Train Your LLM from Scratch" series. For production-scale training, explore DeepSpeed, FSDP, and multi-node distributed training.