Intermediate · ML Fundamentals

Understanding Transformer Architecture

A visual guide to the transformer architecture that powers modern LLMs like GPT and LLaMA.

25 min
Deep Learning, NLP, Attention

Prerequisites

  • Linear algebra basics
  • Neural network fundamentals
  • Python/PyTorch experience

The Transformer architecture, introduced in the landmark paper "Attention Is All You Need" (2017), revolutionized natural language processing and laid the foundation for modern LLMs like GPT, LLaMA, and Claude. In this tutorial, we'll break down the architecture piece by piece.

The Big Picture

At its core, a Transformer processes sequences by:

  1. Converting tokens to embeddings
  2. Adding positional information
  3. Processing through attention and feedforward layers
  4. Producing output predictions
Input Tokens → Embeddings → [N × Transformer Blocks] → Output

Each block contains:

  • Multi-Head Attention
  • Feed-Forward Network
  • Layer Normalization
  • Residual Connections

Step 1: Token Embeddings

First, we convert discrete tokens (words, subwords) into continuous vectors:

Python
import torch
import torch.nn as nn

class TokenEmbedding(nn.Module):
    def __init__(self, vocab_size: int, d_model: int):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.d_model = d_model

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: [batch_size, seq_len] → [batch_size, seq_len, d_model]
        # Scale by sqrt(d_model) as per original paper
        return self.embedding(x) * (self.d_model ** 0.5)

# Example
vocab_size = 50000
d_model = 512
embedding = TokenEmbedding(vocab_size, d_model)

tokens = torch.tensor([[1, 42, 156, 7]])  # [1, 4]
embedded = embedding(tokens)  # [1, 4, 512]

Step 2: Positional Encoding

Unlike RNNs, Transformers process all tokens in parallel. To capture sequence order, we add positional information:

Python
import math

class PositionalEncoding(nn.Module):
    def __init__(self, d_model: int, max_seq_len: int = 5000, dropout: float = 0.1):
        super().__init__()
        self.dropout = nn.Dropout(dropout)

        # Create position encodings
        position = torch.arange(max_seq_len).unsqueeze(1)  # [max_seq_len, 1]
        div_term = torch.exp(
            torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)
        )

        pe = torch.zeros(max_seq_len, d_model)
        pe[:, 0::2] = torch.sin(position * div_term)  # Even indices
        pe[:, 1::2] = torch.cos(position * div_term)  # Odd indices

        # Register as buffer (not a parameter)
        self.register_buffer('pe', pe.unsqueeze(0))  # [1, max_seq_len, d_model]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: [batch_size, seq_len, d_model]
        seq_len = x.size(1)
        x = x + self.pe[:, :seq_len, :]
        return self.dropout(x)

Why sinusoidal? The sine/cosine functions allow the model to:

  • Learn relative positions (PE[pos+k] can be expressed as a linear function of PE[pos])
  • Generalize to longer sequences than seen during training
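These properties are easiest to see by computing a tiny encoding directly from the formulas above. This is a self-contained sketch with toy sizes (d_model = 8, max_len = 10), separate from the module:

```python
import math
import torch

d_model = 8
max_len = 10

position = torch.arange(max_len).unsqueeze(1).float()
div_term = torch.exp(
    torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
)

pe = torch.zeros(max_len, d_model)
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)

# Position 0 encodes as sin(0) = 0 on even dims and cos(0) = 1 on odd dims
print(pe[0])  # tensor([0., 1., 0., 1., 0., 1., 0., 1.])

# Every value is bounded in [-1, 1], so the encoding never
# overwhelms the token embeddings, regardless of position
print(bool((pe.abs() <= 1.0).all()))  # True
```

Because each dimension oscillates at a different frequency, positions beyond those seen in training still receive well-defined, bounded encodings.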

Step 3: Self-Attention (The Core Innovation)

Self-attention computes relationships between all pairs of tokens:

Python
class ScaledDotProductAttention(nn.Module):
    def __init__(self, dropout: float = 0.1):
        super().__init__()
        self.dropout = nn.Dropout(dropout)

    def forward(
        self,
        query: torch.Tensor,    # [batch, heads, seq_len, d_k]
        key: torch.Tensor,      # [batch, heads, seq_len, d_k]
        value: torch.Tensor,    # [batch, heads, seq_len, d_v]
        mask: torch.Tensor | None = None
    ) -> tuple[torch.Tensor, torch.Tensor]:
        d_k = query.size(-1)

        # Compute attention scores
        # [batch, heads, seq_len, seq_len]
        scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)

        # Apply mask (for causal attention in decoders)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))

        # Softmax to get attention weights
        attention_weights = torch.softmax(scores, dim=-1)
        attention_weights = self.dropout(attention_weights)

        # Apply attention to values
        output = torch.matmul(attention_weights, value)

        return output, attention_weights

Intuition: Each token "queries" for relevant information from all other tokens. The dot product between query and key determines relevance, and values carry the actual information.
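To make the intuition concrete, here is the same computation written out for a single head on random tensors, independent of the module above (shapes are toy values):

```python
import math
import torch

torch.manual_seed(0)
batch, heads, seq_len, d_k = 1, 1, 4, 8
q = torch.randn(batch, heads, seq_len, d_k)
k = torch.randn(batch, heads, seq_len, d_k)
v = torch.randn(batch, heads, seq_len, d_k)

# Scores: how relevant each key position is to each query position
scores = q @ k.transpose(-2, -1) / math.sqrt(d_k)  # [1, 1, 4, 4]

# Softmax turns each row of scores into a probability
# distribution over the sequence positions
weights = torch.softmax(scores, dim=-1)
print(weights.sum(dim=-1))  # every row sums to 1

# Output: a weighted mixture of the value vectors
output = weights @ v
print(output.shape)  # torch.Size([1, 1, 4, 8])
```

Each output row is therefore a convex combination of value vectors, with the mixing weights determined entirely by query-key similarity.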

Step 4: Multi-Head Attention

Instead of single attention, we use multiple "heads" to capture different types of relationships:

Python
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model: int, num_heads: int, dropout: float = 0.1):
        super().__init__()
        assert d_model % num_heads == 0

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads

        # Linear projections
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

        self.attention = ScaledDotProductAttention(dropout)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        mask: torch.Tensor | None = None
    ) -> torch.Tensor:
        batch_size = query.size(0)

        # Project and reshape for multi-head: [batch, seq, d_model] → [batch, heads, seq, d_k]
        Q = self.W_q(query).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        K = self.W_k(key).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        V = self.W_v(value).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)

        # Apply attention
        attn_output, _ = self.attention(Q, K, V, mask)

        # Concatenate heads: [batch, heads, seq, d_k] → [batch, seq, d_model]
        attn_output = attn_output.transpose(1, 2).contiguous().view(
            batch_size, -1, self.d_model
        )

        # Final projection
        return self.W_o(attn_output)

Why multiple heads? Different heads can focus on:

  • Syntactic relationships (subject-verb agreement)
  • Semantic relationships (word meanings)
  • Positional patterns (nearby words)
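The split-and-merge reshaping is the part that most often trips people up, so here is a shape walkthrough on random data (toy sizes, standalone from the class above):

```python
import torch

batch, seq_len, d_model, num_heads = 2, 5, 512, 8
d_k = d_model // num_heads  # 64 dims per head

x = torch.randn(batch, seq_len, d_model)

# Split: [batch, seq, d_model] → [batch, heads, seq, d_k]
per_head = x.view(batch, seq_len, num_heads, d_k).transpose(1, 2)
print(per_head.shape)  # torch.Size([2, 8, 5, 64])

# Merge back: [batch, heads, seq, d_k] → [batch, seq, d_model]
merged = per_head.transpose(1, 2).contiguous().view(batch, seq_len, d_model)

# The round trip is lossless: no information is created or destroyed,
# the 512 dims are just viewed as 8 independent groups of 64
print(torch.equal(merged, x))  # True
```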

Step 5: Feed-Forward Network

After attention, each position passes through a feedforward network:

Python
class FeedForward(nn.Module):
    def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_ff)
        self.linear2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)
        self.activation = nn.GELU()  # Modern transformers use GELU

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: [batch, seq, d_model]
        x = self.linear1(x)      # [batch, seq, d_ff]
        x = self.activation(x)
        x = self.dropout(x)
        x = self.linear2(x)      # [batch, seq, d_model]
        return x

Typically, d_ff = 4 * d_model. This expansion allows the model to process information in a higher-dimensional space.
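This expansion is also where most of a block's parameters live. With d_model = 512 and d_ff = 2048, a quick count (using a bare nn.Sequential stand-in for the module above):

```python
import torch.nn as nn

d_model, d_ff = 512, 2048
ff = nn.Sequential(
    nn.Linear(d_model, d_ff),   # up-projection
    nn.GELU(),
    nn.Linear(d_ff, d_model),   # down-projection
)

n_params = sum(p.numel() for p in ff.parameters())
# (512*2048 + 2048) + (2048*512 + 512) = 2,099,712 weights and biases
print(n_params)  # 2099712
```

For comparison, the four d_model × d_model projections in multi-head attention contribute roughly half that, about 1.05M parameters.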

Step 6: Transformer Block

Combining everything with residual connections and layer normalization:

Python
class TransformerBlock(nn.Module):
    def __init__(
        self,
        d_model: int,
        num_heads: int,
        d_ff: int,
        dropout: float = 0.1,
        pre_norm: bool = True  # Modern transformers use pre-norm
    ):
        super().__init__()
        self.pre_norm = pre_norm

        self.attention = MultiHeadAttention(d_model, num_heads, dropout)
        self.ff = FeedForward(d_model, d_ff, dropout)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor, mask: torch.Tensor | None = None) -> torch.Tensor:
        if self.pre_norm:
            # Pre-norm (GPT-style): normalize once, reuse for Q, K, and V
            h = self.norm1(x)
            attn_out = self.attention(h, h, h, mask)
            x = x + self.dropout(attn_out)

            ff_out = self.ff(self.norm2(x))
            x = x + self.dropout(ff_out)
        else:
            # Post-norm (original transformer)
            attn_out = self.attention(x, x, x, mask)
            x = self.norm1(x + self.dropout(attn_out))

            ff_out = self.ff(x)
            x = self.norm2(x + self.dropout(ff_out))

        return x
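The pre-norm residual pattern can be isolated in a few lines. In this standalone sketch, a toy linear map stands in for attention or the FFN (the names here are illustrative):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_model = 16
norm = nn.LayerNorm(d_model)
sublayer = nn.Linear(d_model, d_model)  # stand-in for attention/FFN

x = torch.randn(2, 4, d_model)

# Pre-norm: normalize the sublayer's input, then add the residual
out = x + sublayer(norm(x))
print(out.shape)  # torch.Size([2, 4, 16])

# Even if the sublayer outputs all zeros, the residual path passes
# x through unchanged; this identity path is what keeps gradients
# flowing through deep stacks of blocks
nn.init.zeros_(sublayer.weight)
nn.init.zeros_(sublayer.bias)
print(torch.equal(x + sublayer(norm(x)), x))  # True
```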

Step 7: Complete Decoder (GPT-style)

Putting it all together for a decoder-only model (like GPT):

Python
class GPTModel(nn.Module):
    def __init__(
        self,
        vocab_size: int,
        d_model: int = 512,
        num_heads: int = 8,
        num_layers: int = 6,
        d_ff: int = 2048,
        max_seq_len: int = 1024,
        dropout: float = 0.1,
    ):
        super().__init__()

        self.token_embedding = TokenEmbedding(vocab_size, d_model)
        self.pos_encoding = PositionalEncoding(d_model, max_seq_len, dropout)

        self.layers = nn.ModuleList([
            TransformerBlock(d_model, num_heads, d_ff, dropout)
            for _ in range(num_layers)
        ])

        self.norm = nn.LayerNorm(d_model)
        self.output = nn.Linear(d_model, vocab_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: [batch, seq_len] token indices
        seq_len = x.size(1)

        # Create causal mask
        mask = torch.triu(
            torch.ones(seq_len, seq_len, device=x.device), diagonal=1
        ).bool()
        mask = ~mask  # Invert: True = attend, False = mask

        # Embedding + positional encoding
        x = self.token_embedding(x)
        x = self.pos_encoding(x)

        # Transformer blocks
        for layer in self.layers:
            x = layer(x, mask)

        # Output projection
        x = self.norm(x)
        logits = self.output(x)  # [batch, seq, vocab_size]

        return logits
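The causal mask built in forward() is worth seeing on its own. For seq_len = 4, it is a lower-triangular pattern where each position may attend only to itself and earlier positions:

```python
import torch

seq_len = 4
mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool()
mask = ~mask  # Invert: True = attend, False = mask out

# Row i = query position i; column j = key position j
print(mask.int())
# tensor([[1, 0, 0, 0],
#         [1, 1, 0, 0],
#         [1, 1, 1, 0],
#         [1, 1, 1, 1]])
```

This is what makes the model autoregressive: during training, the prediction at position i can never peek at tokens after i.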

Key Concepts Summary

  • Token Embedding: Convert discrete tokens to vectors
  • Positional Encoding: Add sequence order information
  • Self-Attention: Model relationships between all tokens
  • Multi-Head: Capture different types of relationships
  • Feed-Forward: Process each position independently
  • Layer Norm: Stabilize training
  • Residual Connections: Enable gradient flow in deep networks

Modern Improvements

Since the original paper, several improvements have been made:

  1. Pre-normalization: Apply LayerNorm before (not after) attention and FFN
  2. Rotary Position Embeddings (RoPE): Better handling of relative positions
  3. Grouped Query Attention (GQA): More efficient multi-head attention
  4. SwiGLU Activation: Improved feedforward networks
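As one concrete example, a SwiGLU feedforward layer replaces the single activation with a gated product. This is a minimal sketch of the idea; the layer sizes are illustrative, not taken from any particular model:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class SwiGLUFeedForward(nn.Module):
    """FFN variant used by models such as LLaMA: a SiLU-gated linear unit."""
    def __init__(self, d_model: int, d_ff: int):
        super().__init__()
        self.w_gate = nn.Linear(d_model, d_ff, bias=False)
        self.w_up = nn.Linear(d_model, d_ff, bias=False)
        self.w_down = nn.Linear(d_ff, d_model, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Gate: SiLU(x W_gate) elementwise-scales the "up" projection
        return self.w_down(F.silu(self.w_gate(x)) * self.w_up(x))

# d_ff is often shrunk to ~(8/3) * d_model to keep the parameter
# count comparable to a standard 4x FFN (three matrices vs. two)
ff = SwiGLUFeedForward(d_model=512, d_ff=1365)
out = ff(torch.randn(2, 4, 512))
print(out.shape)  # torch.Size([2, 4, 512])
```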

Next Steps

Now that you understand the architecture:

  • Implement a small GPT from scratch
  • Experiment with different hyperparameters
  • Explore pre-trained models on Hugging Face
  • Study specific improvements like FlashAttention

This tutorial is part of the ML Fundamentals series. Understanding transformers is essential for working with modern LLMs.
