Understanding Continual Learning in Neural Networks

Zizhao Hu
January 10, 2024
12 min

One of the most significant challenges in modern AI is building systems that can learn continuously without forgetting previously acquired knowledge. The failure mode behind this challenge, known as catastrophic forgetting, has been a recognized limitation of neural networks since it was first described in the late 1980s.

[Figure: Continual learning visualization]

The Problem: Catastrophic Forgetting

When a neural network is trained on a new task, the weight updates that optimize performance on the new task often degrade performance on previously learned tasks. This happens because:

  1. Shared representations: Neural networks use distributed representations where multiple tasks share the same parameters
  2. Gradient interference: Updates for new tasks can directly conflict with the optimal parameters for old tasks
  3. No explicit memory: Standard networks have no mechanism to protect important learned information

Consider this scenario: You train a model to classify cats vs. dogs with 95% accuracy. Then you train the same model on birds vs. fish. After the second training phase, your cat/dog accuracy might drop to 50%—essentially random chance.

Python
# Demonstrating catastrophic forgetting
import torch
import torch.nn as nn

class SimpleClassifier(nn.Module):
    """Two-layer MLP whose weights are shared by every task it is trained on."""

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        return self.fc2(x)  # raw logits; softmax is applied by the loss

# Illustrative numbers after training on Task A, then Task B:
# Task A accuracy: 95% -> 52%  (catastrophic forgetting!)
# Task B accuracy: 0% -> 94%
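To see the effect without a dataset, the sketch below trains one shared weight vector on two toy quadratic losses in sequence. Everything in it is an assumption made for illustration: the optima `OPT_A` and `OPT_B`, the learning rate, and the helper names stand in for two tasks whose best weights differ.

```python
# Toy illustration (hypothetical setup): two "tasks" share one weight vector,
# and each task's loss is the squared distance to its own optimum.
OPT_A = [1.0, 0.0]  # assumed optimum for Task A
OPT_B = [0.0, 1.0]  # assumed optimum for Task B


def loss(w, opt):
    """Squared distance between the weights and a task's optimum."""
    return sum((wi - oi) ** 2 for wi, oi in zip(w, opt))


def grad(w, opt):
    """Gradient of the squared-distance loss with respect to w."""
    return [2.0 * (wi - oi) for wi, oi in zip(w, opt)]


def train(w, opt, steps=100, lr=0.1):
    """Plain gradient descent on a single task; returns the updated weights."""
    for _ in range(steps):
        g = grad(w, opt)
        w = [wi - lr * gi for wi, gi in zip(w, g)]
    return w


w = [0.0, 0.0]
w = train(w, OPT_A)
loss_a_before = loss(w, OPT_A)  # near 0: Task A is learned
w = train(w, OPT_B)
loss_a_after = loss(w, OPT_A)   # near 2: Task A has been overwritten
loss_b_after = loss(w, OPT_B)   # near 0: Task B is learned

print(f"Task A loss: {loss_a_before:.4f} -> {loss_a_after:.4f}")
print(f"Task B loss after training: {loss_b_after:.4f}")
```

Nothing in the second training phase refers to Task A, so the weights move all the way to Task B's optimum. Real networks forget for the same reason, only in far more dimensions.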

Modern Approaches to Continual Learning

1. Replay-Based Methods

The most intuitive approach is to store examples from previous tasks and replay them during training on new tasks. This is loosely analogous to the way the brain is thought to consolidate memories by replaying them during sleep.

Experience Replay maintains a buffer of past examples:

Python
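import random

class ExperienceReplayBuffer:
    def __init__(self, max_size=1000):
        self.buffer = []
        self.max_size = max_size
        self.num_seen = 0  # total number of samples offered to the buffer

    def add(self, sample):
        self.num_seen += 1
        if len(self.buffer) < self.max_size:
            self.buffer.append(sample)
        else:
            # Reservoir sampling: keep the new sample with probability
            # max_size / num_seen, so every sample seen so far is equally
            # likely to be in the buffer
            idx = random.randint(0, self.num_seen - 1)
            if idx < self.max_size:
                self.buffer[idx] = sample

    def sample(self, batch_size):
        # Draw without replacement; return fewer items if the buffer is small
        return random.sample(self.buffer, min(batch_size, len(self.buffer)))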

The key insight is that mixing old and new data during training keeps earlier tasks represented in every update, which largely preserves performance on them. However, this requires storing raw data, which raises privacy concerns and adds storage costs.
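As a rough sketch of what "mixing old and new data" can look like in practice, the helper below builds one training batch from new-task samples plus samples drawn from memory. The function name `mixed_batch`, the `replay_size` parameter, and the toy data are hypothetical, and a plain list stands in for the replay buffer so the example runs on its own.

```python
import random


def mixed_batch(new_samples, memory, replay_size, rng=random):
    """Combine a batch of new-task samples with samples drawn from memory.

    `memory` is any list of stored (x, y) pairs from earlier tasks.
    """
    k = min(replay_size, len(memory))
    replayed = rng.sample(memory, k) if k > 0 else []
    batch = list(new_samples) + replayed
    rng.shuffle(batch)  # avoid a fixed new-then-old ordering within the batch
    return batch


# Hypothetical data: (features, tag) pairs, tagged by task for readability.
memory = [((i, i + 1), "task_a") for i in range(20)]
new_samples = [((i, -i), "task_b") for i in range(8)]

batch = mixed_batch(new_samples, memory, replay_size=8, rng=random.Random(0))
num_old = sum(1 for _, tag in batch if tag == "task_a")
print(f"batch size = {len(batch)}, replayed = {num_old}")
```

The ratio of replayed to new samples is a tuning choice. An even split is used here only to keep the example simple.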

2. Regularization-Based Methods

Instead of storing data, we can add constraints to the optimization process that prevent important weights from changing too much.

Elastic Weight Consolidation (EWC) uses the Fisher information matrix (in practice, its diagonal) to identify important parameters:

Python
def ewc_loss(model, fisher_matrix, old_params, lambda_ewc=1000):
    """
    EWC regularization loss.

    fisher_matrix and old_params are dicts keyed by parameter name, holding the
    diagonal Fisher estimates and the parameter values saved after the old task.
    """
    loss = 0
    for name, param in model.named_parameters():
        if name in fisher_matrix:
            # Penalize changes to important parameters
            loss += (fisher_matrix[name] * (param - old_params[name]) ** 2).sum()
    return lambda_ewc * loss
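Written as a formula, the penalty computed by `ewc_loss` is shown below. The symbols are notational assumptions: \theta_i is a current parameter, \theta^{*}_i its value saved after the old task (`old_params`), F_i the matching diagonal Fisher entry (`fisher_matrix`), and \lambda is `lambda_ewc`. The original EWC paper writes the coefficient as \lambda/2; the code above folds that factor into \lambda.

```latex
\mathcal{L}_{\text{EWC}}(\theta) = \lambda \sum_{i} F_i \left(\theta_i - \theta^{*}_{i}\right)^2,
\qquad
\mathcal{L}_{\text{total}}(\theta) = \mathcal{L}_{\text{new}}(\theta) + \mathcal{L}_{\text{EWC}}(\theta)
```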

The intuition is that parameters with high Fisher information are crucial for previous tasks, so we should constrain their updates.
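The post does not show how the Fisher values are obtained. One common choice is the empirical diagonal Fisher: the squared gradient of the log-likelihood, averaged over samples from the old task. The sketch below assumes that choice. The function name `estimate_diag_fisher`, the stand-in model, and the random data are hypothetical.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def estimate_diag_fisher(model, inputs, labels):
    """Empirical diagonal Fisher: mean squared gradient of the log-likelihood.

    Uses the true labels and processes one sample at a time, so gradients
    are squared before they are averaged.
    """
    fisher = {name: torch.zeros_like(p) for name, p in model.named_parameters()}
    model.eval()
    for x, y in zip(inputs, labels):
        model.zero_grad()
        log_probs = F.log_softmax(model(x.unsqueeze(0)), dim=-1)
        F.nll_loss(log_probs, y.unsqueeze(0)).backward()
        for name, p in model.named_parameters():
            if p.grad is not None:
                fisher[name] += p.grad.detach() ** 2
    return {name: f / len(inputs) for name, f in fisher.items()}


# Hypothetical stand-in model and random data, just to exercise the function.
torch.manual_seed(0)
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 3))
inputs = torch.randn(16, 4)
labels = torch.randint(0, 3, (16,))

fisher_matrix = estimate_diag_fisher(model, inputs, labels)
old_params = {n: p.detach().clone() for n, p in model.named_parameters()}
```

The two dictionaries have the shape that `ewc_loss` expects: both are keyed by parameter name, and each tensor matches its parameter's shape.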

3. Architecture-Based Methods

Another approach is to dynamically modify the network architecture for each new task.

Progressive Neural Networks add a new column (a set of layers) for each task while freezing all previous columns. Task 2 can read from Column 1 via lateral connections, and Task 3 can read from both Column 1 and Column 2.

Because earlier columns never change, this eliminates forgetting by construction, but the model grows with every task and the task identity must be known at inference time to pick the right column.
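A heavily simplified sketch of the idea, reduced to two columns with one hidden layer each, might look like this. The class name `TwoColumnProgressiveNet` and its layer names are hypothetical, and column 1 is assumed to have been trained on Task 1 already.

```python
import torch
import torch.nn as nn


class TwoColumnProgressiveNet(nn.Module):
    """Two-task sketch: column 1 is frozen, column 2 reads it laterally."""

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        # Column 1 (assumed to be already trained on Task 1).
        self.col1_hidden = nn.Linear(input_dim, hidden_dim)
        self.col1_out = nn.Linear(hidden_dim, output_dim)
        # Column 2 (new, trainable) plus a lateral connection that feeds
        # column 1's hidden features into column 2's output layer.
        self.col2_hidden = nn.Linear(input_dim, hidden_dim)
        self.col2_out = nn.Linear(hidden_dim, output_dim)
        self.lateral = nn.Linear(hidden_dim, output_dim, bias=False)
        # Freeze column 1 so Task 2 training cannot change it.
        for layer in (self.col1_hidden, self.col1_out):
            for p in layer.parameters():
                p.requires_grad = False

    def forward(self, x, task_id):
        h1 = torch.relu(self.col1_hidden(x))
        if task_id == 1:
            return self.col1_out(h1)
        # Task 2: own features plus frozen column-1 features.
        h2 = torch.relu(self.col2_hidden(x))
        return self.col2_out(h2) + self.lateral(h1.detach())


net = TwoColumnProgressiveNet(input_dim=4, hidden_dim=8, output_dim=2)
x = torch.randn(5, 4)
trainable = [n for n, p in net.named_parameters() if p.requires_grad]
print(net(x, task_id=1).shape, net(x, task_id=2).shape)
print(trainable)
```

Only the column 2 and lateral parameters appear in `trainable`, so Task 1 outputs cannot change during Task 2 training. The cost is that every extra task adds another column and another set of lateral connections.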

DREAM: Difficulty-Aware Replay

In my research, I've been working on a method called DREAM (Difficulty-REplay-Augmented Memory) that combines the benefits of replay with intelligent sample selection.

The key insight is that not all samples are equally important. Difficult samples—those the model struggles with—often lie near decision boundaries and are more informative for maintaining performance.

Python
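import torch

def compute_difficulty(model, sample, label):
    """
    Compute sample difficulty based on prediction confidence.

    Expects a single sample (with or without a leading batch dimension of 1)
    and its integer class label.
    """
    with torch.no_grad():
        output = model(sample)
        # Flatten so probs has shape (num_classes,) even for a batch of one
        probs = torch.softmax(output, dim=-1).reshape(-1)
        confidence = probs[label]
        # Lower confidence = higher difficulty
        return 1.0 - confidence.item()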

By prioritizing difficult samples in the replay buffer, DREAM achieves better performance with smaller memory footprints compared to random replay.
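The post does not spell out how DREAM selects or evicts samples, so the following is not the DREAM implementation. It is one plausible way to prioritize difficult samples: a fixed-size buffer backed by a min-heap that evicts the easiest stored sample first. The class name `DifficultyPriorityBuffer` and its methods are hypothetical.

```python
import heapq
import itertools


class DifficultyPriorityBuffer:
    """Fixed-size buffer that keeps the samples with the highest difficulty."""

    def __init__(self, max_size=1000):
        self.max_size = max_size
        self._heap = []                    # min-heap of (difficulty, tie, sample)
        self._counter = itertools.count()  # tie-breaker so samples are never compared

    def add(self, sample, difficulty):
        entry = (difficulty, next(self._counter), sample)
        if len(self._heap) < self.max_size:
            heapq.heappush(self._heap, entry)
        elif difficulty > self._heap[0][0]:
            # Evict the easiest stored sample in favour of a harder one.
            heapq.heapreplace(self._heap, entry)

    def samples(self):
        """Stored samples, hardest first."""
        return [s for _, _, s in sorted(self._heap, reverse=True)]


buffer = DifficultyPriorityBuffer(max_size=3)
for name, difficulty in [("a", 0.1), ("b", 0.9), ("c", 0.4), ("d", 0.7), ("e", 0.05)]:
    buffer.add(name, difficulty)
print(buffer.samples())  # ['b', 'd', 'c']
```

A purely greedy rule like this can fill the buffer with outliers or mislabeled samples, so a practical variant would likely mix in some random samples as well.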

Benchmarking Continual Learning

The field has developed several standard benchmarks:

Benchmark        Tasks   Description
Split MNIST      5       Digit pairs: 0-1, 2-3, 4-5, 6-7, 8-9
Split CIFAR-10   5       Image class pairs
Permuted MNIST   10+     Same task with permuted pixels
CORe50           50      Object recognition from video
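The split and permuted benchmarks are simple to construct from a base dataset. The sketch below builds only the task definitions (class groups and pixel permutations), not the data loading. The helper names `split_classes`, `pixel_permutations`, and `permute` are hypothetical.

```python
import random


def split_classes(num_classes, classes_per_task):
    """Partition class labels 0..num_classes-1 into consecutive tasks."""
    labels = list(range(num_classes))
    return [labels[i:i + classes_per_task]
            for i in range(0, num_classes, classes_per_task)]


def pixel_permutations(num_tasks, num_pixels, seed=0):
    """One fixed pixel permutation per task; task 0 keeps the original order."""
    rng = random.Random(seed)
    perms = [list(range(num_pixels))]
    for _ in range(num_tasks - 1):
        perm = list(range(num_pixels))
        rng.shuffle(perm)
        perms.append(perm)
    return perms


def permute(image, perm):
    """Apply a permutation to a flattened image (a sequence of pixel values)."""
    return [image[i] for i in perm]


split_mnist_tasks = split_classes(num_classes=10, classes_per_task=2)
print(split_mnist_tasks)  # [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]]

perms = pixel_permutations(num_tasks=10, num_pixels=28 * 28)
```

Each permutation is fixed for the lifetime of its task, so every Permuted MNIST task has the same labels but a different input layout.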

Key Metrics

  1. Average Accuracy: Mean accuracy across all tasks after training on the final task
  2. Forgetting Measure: How far accuracy on each old task drops from its earlier best
  3. Forward Transfer: Does learning earlier tasks help on tasks not yet trained?
  4. Backward Transfer: Does learning new tasks improve (or hurt) accuracy on old tasks?
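These metrics are usually computed from an accuracy matrix. The sketch below assumes `acc[i][j]` is the accuracy on task j after training on tasks 0 through i, and uses definitions common in the literature, which may differ in detail from the ones a given paper reports. The function names, the `baseline` argument (accuracy of an untrained model on each task), and the example numbers are hypothetical.

```python
def average_accuracy(acc):
    """Mean accuracy over all tasks after training on the final task."""
    final = acc[-1]
    return sum(final) / len(final)


def forgetting(acc):
    """Mean drop from each old task's best earlier accuracy to its final accuracy."""
    T = len(acc)
    drops = [max(acc[i][j] for i in range(j, T - 1)) - acc[-1][j]
             for j in range(T - 1)]
    return sum(drops) / len(drops)


def backward_transfer(acc):
    """Mean change on each old task between just after learning it and the end."""
    T = len(acc)
    return sum(acc[-1][j] - acc[j][j] for j in range(T - 1)) / (T - 1)


def forward_transfer(acc, baseline):
    """Mean gain on each task, just before training on it, over an untrained baseline."""
    T = len(acc)
    return sum(acc[j - 1][j] - baseline[j] for j in range(1, T)) / (T - 1)


# Made-up accuracy matrix for three tasks (rows: after training task i).
acc = [
    [0.9, 0.5, 0.5],
    [0.6, 0.9, 0.5],
    [0.5, 0.7, 0.9],
]
baseline = [0.5, 0.5, 0.5]

print(f"average accuracy:  {average_accuracy(acc):.2f}")
print(f"forgetting:        {forgetting(acc):.2f}")
print(f"backward transfer: {backward_transfer(acc):.2f}")
print(f"forward transfer:  {forward_transfer(acc, baseline):.2f}")
```

Negative backward transfer and positive forgetting describe the same loss on old tasks from two angles. The two differ only when a task's accuracy peaks some time after it was first learned.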

The Road Ahead

Continual learning remains an open challenge. Current methods still fall short of human-like lifelong learning capabilities. Key research directions include:

  • Meta-learning for continual learning: Learning how to learn without forgetting
  • Neuro-inspired approaches: Drawing from how the brain consolidates memories
  • Curriculum learning: Ordering tasks to maximize positive transfer
  • Sparse representations: Using only a subset of parameters per task

As AI systems become more prevalent in real-world applications, the ability to learn continuously will become essential. A self-driving car, for instance, must adapt to new road conditions without forgetting how to handle familiar ones.

Conclusion

Catastrophic forgetting is a fundamental challenge that highlights the gap between current AI systems and biological intelligence. While significant progress has been made with replay, regularization, and architectural methods, we're still far from achieving true lifelong learning.

The most promising approaches combine multiple strategies: replay for memory consolidation, regularization for parameter protection, and architectural innovations for scalability. As we continue to push the boundaries, the goal is clear—AI systems that learn like we do: continuously, efficiently, and without forgetting.


This post is based on my ongoing research in continual learning at USC. For more details, check out my publications on Google Scholar.
