Understanding Transformer Architecture
A visual guide to the transformer architecture that powers modern LLMs like GPT and LLaMA.
Prerequisites
- Linear algebra basics
- Neural network fundamentals
- Python/PyTorch experience
The Transformer architecture, introduced in the landmark paper "Attention Is All You Need" (Vaswani et al., 2017), revolutionized natural language processing and laid the foundation for modern LLMs like GPT, LLaMA, and Claude. In this tutorial, we'll break down the architecture piece by piece.
The Big Picture
At its core, a Transformer processes sequences by:
- Converting tokens to embeddings
- Adding positional information
- Processing through attention and feedforward layers
- Producing output predictions
Input Tokens → Embeddings → [N × Transformer Blocks] → Output

Each block contains:
- Multi-Head Attention
- Feed-Forward Network
- Layer Normalization
- Residual Connections
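Before building each piece by hand, the end-to-end flow above can be sanity-checked with PyTorch's built-in layers. This is just a shape walkthrough with arbitrary toy dimensions, not the model we implement below:

```python
import torch
import torch.nn as nn

vocab_size, d_model, seq_len = 1000, 64, 10

tokens = torch.randint(0, vocab_size, (2, seq_len))   # [batch, seq]
emb = nn.Embedding(vocab_size, d_model)(tokens)       # [batch, seq, d_model]

# One transformer block (PyTorch's reference implementation)
block = nn.TransformerEncoderLayer(d_model, nhead=4, batch_first=True)
out = block(emb)                                      # [batch, seq, d_model]

# Project back to vocabulary for output predictions
logits = nn.Linear(d_model, vocab_size)(out)          # [batch, seq, vocab_size]
print(logits.shape)  # torch.Size([2, 10, 1000])
```

Every stage preserves the batch and sequence dimensions; only the last dimension changes, which is the pattern you will see throughout this tutorial.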
Step 1: Token Embeddings
First, we convert discrete tokens (words, subwords) into continuous vectors:
```python
import torch
import torch.nn as nn

class TokenEmbedding(nn.Module):
    def __init__(self, vocab_size: int, d_model: int):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.d_model = d_model

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: [batch_size, seq_len] → [batch_size, seq_len, d_model]
        # Scale by sqrt(d_model) as per original paper
        return self.embedding(x) * (self.d_model ** 0.5)

# Example
vocab_size = 50000
d_model = 512
embedding = TokenEmbedding(vocab_size, d_model)

tokens = torch.tensor([[1, 42, 156, 7]])  # [1, 4]
embedded = embedding(tokens)  # [1, 4, 512]
```

Step 2: Positional Encoding
Unlike RNNs, Transformers process all tokens in parallel. To capture sequence order, we add positional information:
```python
import math

class PositionalEncoding(nn.Module):
    def __init__(self, d_model: int, max_seq_len: int = 5000, dropout: float = 0.1):
        super().__init__()
        self.dropout = nn.Dropout(dropout)

        # Create position encodings
        position = torch.arange(max_seq_len).unsqueeze(1)  # [max_seq_len, 1]
        div_term = torch.exp(
            torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model)
        )

        pe = torch.zeros(max_seq_len, d_model)
        pe[:, 0::2] = torch.sin(position * div_term)  # Even indices
        pe[:, 1::2] = torch.cos(position * div_term)  # Odd indices

        # Register as buffer (not a parameter)
        self.register_buffer('pe', pe.unsqueeze(0))  # [1, max_seq_len, d_model]

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: [batch_size, seq_len, d_model]
        seq_len = x.size(1)
        x = x + self.pe[:, :seq_len, :]
        return self.dropout(x)
```

Why sinusoidal? The sine/cosine functions allow the model to:
- Learn relative positions (PE[pos+k] is a linear function of PE[pos])
- Generalize to longer sequences than seen during training
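The relative-position property can be verified numerically: for each frequency, the (sin, cos) pair at position pos+k is a fixed rotation of the pair at position pos, by the angle-addition identities. A small sketch, reusing the encoding formula from the class above:

```python
import math
import torch

d_model, max_len = 8, 100
position = torch.arange(max_len).unsqueeze(1).float()
div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))

pe = torch.zeros(max_len, d_model)
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)

# Rotating (sin(w*p), cos(w*p)) by angle k*w gives (sin(w*(p+k)), cos(w*(p+k)))
# -- the rotation depends only on the offset k, not on p.
p, k = 7, 5
w = float(div_term[0])
rot = torch.tensor([[math.cos(k * w), -math.sin(k * w)],
                    [math.sin(k * w),  math.cos(k * w)]])
pair_p  = torch.tensor([pe[p, 0], pe[p, 1]])
pair_pk = torch.tensor([pe[p + k, 0], pe[p + k, 1]])
print(torch.allclose(pair_p @ rot, pair_pk, atol=1e-5))  # True
```

This is exactly what lets an attention head express "k tokens away" as a learned linear map, regardless of absolute position.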
Step 3: Self-Attention (The Core Innovation)
Self-attention computes relationships between all pairs of tokens:
```python
from typing import Optional

class ScaledDotProductAttention(nn.Module):
    def __init__(self, dropout: float = 0.1):
        super().__init__()
        self.dropout = nn.Dropout(dropout)

    def forward(
        self,
        query: torch.Tensor,   # [batch, heads, seq_len, d_k]
        key: torch.Tensor,     # [batch, heads, seq_len, d_k]
        value: torch.Tensor,   # [batch, heads, seq_len, d_v]
        mask: Optional[torch.Tensor] = None
    ) -> tuple[torch.Tensor, torch.Tensor]:
        d_k = query.size(-1)

        # Compute attention scores
        # [batch, heads, seq_len, seq_len]
        scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)

        # Apply mask (for causal attention in decoders)
        if mask is not None:
            scores = scores.masked_fill(mask == 0, float('-inf'))

        # Softmax to get attention weights
        attention_weights = torch.softmax(scores, dim=-1)
        attention_weights = self.dropout(attention_weights)

        # Apply attention to values
        output = torch.matmul(attention_weights, value)

        return output, attention_weights
```

Intuition: Each token "queries" for relevant information from all other tokens. The dot product between query and key determines relevance, and values carry the actual information.
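As a quick sanity check, the same computation can be done by hand on a toy tensor. The key invariant: softmax makes each query position's weights a probability distribution over the sequence (rows sum to 1):

```python
import math
import torch

torch.manual_seed(0)
q = torch.randn(1, 1, 4, 8)  # [batch, heads, seq_len, d_k]
k = torch.randn(1, 1, 4, 8)
v = torch.randn(1, 1, 4, 8)

# Scores, scaled by sqrt(d_k) to keep softmax gradients well-behaved
scores = q @ k.transpose(-2, -1) / math.sqrt(q.size(-1))  # [1, 1, 4, 4]
weights = torch.softmax(scores, dim=-1)

out = weights @ v  # [1, 1, 4, 8] -- weighted mixture of value vectors

print(weights.sum(dim=-1))  # every row sums to 1
print(out.shape)            # torch.Size([1, 1, 4, 8])
```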
Step 4: Multi-Head Attention
Instead of single attention, we use multiple "heads" to capture different types of relationships:
```python
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model: int, num_heads: int, dropout: float = 0.1):
        super().__init__()
        assert d_model % num_heads == 0

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads

        # Linear projections
        self.W_q = nn.Linear(d_model, d_model)
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model)

        self.attention = ScaledDotProductAttention(dropout)

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        mask: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        batch_size = query.size(0)

        # Project and reshape for multi-head: [batch, seq, d_model] → [batch, heads, seq, d_k]
        Q = self.W_q(query).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        K = self.W_k(key).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)
        V = self.W_v(value).view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)

        # Apply attention
        attn_output, _ = self.attention(Q, K, V, mask)

        # Concatenate heads: [batch, heads, seq, d_k] → [batch, seq, d_model]
        attn_output = attn_output.transpose(1, 2).contiguous().view(
            batch_size, -1, self.d_model
        )

        # Final projection
        return self.W_o(attn_output)
```

Why multiple heads? Different heads can focus on:
- Syntactic relationships (subject-verb agreement)
- Semantic relationships (word meanings)
- Positional patterns (nearby words)
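The head split itself is pure reshaping: d_model is sliced into num_heads chunks of d_k, and concatenating afterwards restores the original layout exactly. A minimal round-trip check:

```python
import torch

batch, seq, d_model, num_heads = 2, 5, 16, 4
d_k = d_model // num_heads

x = torch.randn(batch, seq, d_model)

# Split: [batch, seq, d_model] → [batch, heads, seq, d_k]
heads = x.view(batch, seq, num_heads, d_k).transpose(1, 2)

# Merge back: [batch, heads, seq, d_k] → [batch, seq, d_model]
merged = heads.transpose(1, 2).contiguous().view(batch, seq, d_model)

print(torch.equal(x, merged))  # True -- no information is lost or reordered
```

Because the split is lossless, the extra expressiveness comes entirely from attention being computed independently per chunk, not from the reshape.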
Step 5: Feed-Forward Network
After attention, each position passes through a feedforward network:
```python
class FeedForward(nn.Module):
    def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1):
        super().__init__()
        self.linear1 = nn.Linear(d_model, d_ff)
        self.linear2 = nn.Linear(d_ff, d_model)
        self.dropout = nn.Dropout(dropout)
        self.activation = nn.GELU()  # Modern transformers use GELU

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: [batch, seq, d_model]
        x = self.linear1(x)   # [batch, seq, d_ff]
        x = self.activation(x)
        x = self.dropout(x)
        x = self.linear2(x)   # [batch, seq, d_model]
        return x
```

Typically, d_ff = 4 * d_model. This expansion allows the model to process information in a higher-dimensional space.
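The 4× expansion also makes the feedforward network a large share of the parameter budget. For d_model = 512, a quick count (weights plus biases of the two linear layers):

```python
import torch.nn as nn

d_model = 512
d_ff = 4 * d_model  # 2048

ff = nn.Sequential(nn.Linear(d_model, d_ff), nn.GELU(), nn.Linear(d_ff, d_model))

n_params = sum(p.numel() for p in ff.parameters())
# (512*2048 + 2048) + (2048*512 + 512) = 2,099,712
print(n_params)  # 2099712
```

Compare with roughly 4 * d_model^2 ≈ 1M parameters for the four projections in multi-head attention at the same d_model: the FFN is about twice as large.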
Step 6: Transformer Block
Combining everything with residual connections and layer normalization:
```python
class TransformerBlock(nn.Module):
    def __init__(
        self,
        d_model: int,
        num_heads: int,
        d_ff: int,
        dropout: float = 0.1,
        pre_norm: bool = True  # Modern transformers use pre-norm
    ):
        super().__init__()
        self.pre_norm = pre_norm

        self.attention = MultiHeadAttention(d_model, num_heads, dropout)
        self.ff = FeedForward(d_model, d_ff, dropout)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        if self.pre_norm:
            # Pre-norm (GPT-style): normalize once, then attend
            normed = self.norm1(x)
            attn_out = self.attention(normed, normed, normed, mask)
            x = x + self.dropout(attn_out)

            ff_out = self.ff(self.norm2(x))
            x = x + self.dropout(ff_out)
        else:
            # Post-norm (original transformer)
            attn_out = self.attention(x, x, x, mask)
            x = self.norm1(x + self.dropout(attn_out))

            ff_out = self.ff(x)
            x = self.norm2(x + self.dropout(ff_out))

        return x
```

Step 7: Complete Decoder (GPT-style)
Putting it all together for a decoder-only model (like GPT):
```python
class GPTModel(nn.Module):
    def __init__(
        self,
        vocab_size: int,
        d_model: int = 512,
        num_heads: int = 8,
        num_layers: int = 6,
        d_ff: int = 2048,
        max_seq_len: int = 1024,
        dropout: float = 0.1,
    ):
        super().__init__()

        self.token_embedding = TokenEmbedding(vocab_size, d_model)
        self.pos_encoding = PositionalEncoding(d_model, max_seq_len, dropout)

        self.layers = nn.ModuleList([
            TransformerBlock(d_model, num_heads, d_ff, dropout)
            for _ in range(num_layers)
        ])

        self.norm = nn.LayerNorm(d_model)
        self.output = nn.Linear(d_model, vocab_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: [batch, seq_len] token indices
        seq_len = x.size(1)

        # Create causal mask
        mask = torch.triu(
            torch.ones(seq_len, seq_len, device=x.device), diagonal=1
        ).bool()
        mask = ~mask  # Invert: True = attend, False = mask

        # Embedding + positional encoding
        x = self.token_embedding(x)
        x = self.pos_encoding(x)

        # Transformer blocks
        for layer in self.layers:
            x = layer(x, mask)

        # Output projection
        x = self.norm(x)
        logits = self.output(x)  # [batch, seq, vocab_size]

        return logits
```

Key Concepts Summary
| Component | Purpose |
|---|---|
| Token Embedding | Convert discrete tokens to vectors |
| Positional Encoding | Add sequence order information |
| Self-Attention | Model relationships between all tokens |
| Multi-Head | Capture different types of relationships |
| Feed-Forward | Process each position independently |
| Layer Norm | Stabilize training |
| Residual Connections | Enable gradient flow in deep networks |
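One detail from the decoder worth inspecting in isolation is the causal mask: position i may attend only to positions ≤ i, so no token can see the future. Building it for a length-4 sequence (True = attend):

```python
import torch

seq_len = 4
# Upper triangle above the diagonal marks "future" positions; invert it
mask = ~torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool()
print(mask.int())
# tensor([[1, 0, 0, 0],
#         [1, 1, 0, 0],
#         [1, 1, 1, 0],
#         [1, 1, 1, 1]])
```

The lower-triangular pattern is what turns bidirectional self-attention into an autoregressive language model.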
Modern Improvements
Since the original paper, several improvements have been made:
- Pre-normalization: Apply LayerNorm before (not after) attention and FFN
- Rotary Position Embeddings (RoPE): Better handling of relative positions
- Grouped Query Attention (GQA): More efficient multi-head attention
- SwiGLU Activation: Improved feedforward networks
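As one example of these improvements, RoPE can be sketched in a few lines. This is a simplified rotate-halves variant (the layout used by GPT-NeoX/LLaMA-style implementations), not a drop-in replacement for any particular library; the key property is that the dot product between rotated queries and keys depends only on their relative offset:

```python
import torch

def rotary_embed(x: torch.Tensor, base: float = 10000.0) -> torch.Tensor:
    # x: [batch, heads, seq_len, d_k] with even d_k
    b, h, t, d = x.shape
    half = d // 2
    # One rotation frequency per coordinate pair
    freqs = base ** (-torch.arange(half).float() / half)        # [half]
    angles = torch.arange(t).float()[:, None] * freqs[None, :]  # [t, half]
    cos, sin = angles.cos(), angles.sin()
    x1, x2 = x[..., :half], x[..., half:]
    # Rotate each (x1, x2) pair by its position-dependent angle
    return torch.cat([x1 * cos - x2 * sin, x1 * sin + x2 * cos], dim=-1)

# Place the same vector at every position: after rotation, the dot
# product between two positions depends only on their offset.
v = torch.randn(16)
x = v.view(1, 1, 1, 16).expand(1, 1, 8, 16).clone()
r = rotary_embed(x)
d01 = torch.dot(r[0, 0, 0], r[0, 0, 1])  # offset 1, positions 0 and 1
d23 = torch.dot(r[0, 0, 2], r[0, 0, 3])  # offset 1, positions 2 and 3
print(torch.allclose(d01, d23, atol=1e-4))  # True
```

Unlike the additive sinusoidal encoding from Step 2, RoPE is applied inside attention, to the queries and keys directly, so relative position is baked into every attention score.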
Next Steps
Now that you understand the architecture:
- Implement a small GPT from scratch
- Experiment with different hyperparameters
- Explore pre-trained models on Hugging Face
- Study specific improvements like FlashAttention
This tutorial is part of the ML Fundamentals series. Understanding transformers is essential for working with modern LLMs.