Multi-Modal Learning: Bridging Vision and Language

Zizhao Hu
December 20, 2023
15 min

Humans effortlessly combine information from multiple senses—we see an apple and simultaneously recall its taste, texture, and the word "apple." For decades, AI systems processed each modality in isolation. But recent breakthroughs in multi-modal learning are finally enabling machines to connect vision and language in powerful ways.

[Figure: Multi-Modal AI Visualization]

Why Multi-Modal Matters

Consider the limitations of single-modality AI:

  • Vision-only: A model can identify "a person running" but can't describe why they're running or respond to questions
  • Language-only: A model can discuss "red sports cars" but has no grounding in what "red" actually looks like

Multi-modal learning addresses these limitations by creating shared representations that bridge different modalities.

The Evolution of Vision-Language Models

Early Approaches: Separate Encoders

The first generation of vision-language models used separate, pre-trained encoders for each modality:

Python
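import torch
import torch.nn as nn
from torchvision.models import resnet50, ResNet50_Weights

class EarlyVLM(nn.Module):
    def __init__(self, vocab_size=30000, txt_dim=512):
        super().__init__()
        # Vision: pre-trained ResNet-50 with the classifier removed -> 2048-d features
        self.image_encoder = resnet50(weights=ResNet50_Weights.DEFAULT)
        self.image_encoder.fc = nn.Identity()
        # Language: token embedding + LSTM -> 512-d final hidden state
        self.embed = nn.Embedding(vocab_size, txt_dim)
        self.text_encoder = nn.LSTM(txt_dim, txt_dim, batch_first=True)
        self.fusion = nn.Linear(2048 + txt_dim, 1024)  # Late fusion

    def forward(self, image, text):
        img_features = self.image_encoder(image)           # [B, 2048]
        _, (h_n, _) = self.text_encoder(self.embed(text))  # h_n: [1, B, 512]
        combined = torch.cat([img_features, h_n[-1]], dim=-1)
        return self.fusion(combined)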

The problem? The representations were learned independently and didn't share a common semantic space.

The Transformer Revolution

The transformer architecture changed this. Its attention mechanism operates on any sequence of embeddings, whether they come from image patches or word tokens.

Vision Transformer (ViT) showed that images could be processed as sequences:

Python
import torch

def image_to_patches(image, patch_size=16):
    """Convert image to sequence of patches"""
    # image: [B, C, H, W]; rows/columns beyond a multiple of patch_size are dropped
    B, C, H, W = image.shape
    patches = (
        image.unfold(2, patch_size, patch_size)
             .unfold(3, patch_size, patch_size)
    )
    # patches: [B, C, H//P, W//P, P, P]
    patches = patches.permute(0, 2, 3, 1, 4, 5)
    patches = patches.reshape(B, -1, C * patch_size * patch_size)
    return patches  # [B, num_patches, patch_dim]

This allowed the same transformer architecture to process both images and text!
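As a quick sanity check on the shapes, here is a minimal sketch that cuts the same kind of non-overlapping patches with `torch.nn.functional.unfold`. The function name `patchify` and the 224x224 input size are assumptions chosen for illustration. With 16x16 patches, such an image gives (224/16)^2 = 196 patches of 3 * 16 * 16 = 768 values each.

```python
import torch
import torch.nn.functional as F

def patchify(image, patch_size=16):
    """Cut [B, C, H, W] images into non-overlapping, flattened patches."""
    # F.unfold returns [B, C * P * P, num_patches]
    cols = F.unfold(image, kernel_size=patch_size, stride=patch_size)
    return cols.transpose(1, 2)  # [B, num_patches, C * P * P]

image = torch.randn(2, 3, 224, 224)  # illustrative batch of two RGB images
patches = patchify(image)
print(patches.shape)  # torch.Size([2, 196, 768])
```

Each row of the result is one "token" that a linear layer can project to the transformer's hidden size.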

CLIP: Contrastive Language-Image Pre-training

OpenAI's CLIP was a watershed moment. It learned to align image and text representations through contrastive learning on 400 million image-text pairs.

Python
import torch
import torch.nn.functional as F

def clip_loss(image_embeddings, text_embeddings, temperature=0.07):
    """
    Contrastive loss for CLIP
    """
    # Normalize embeddings so the dot product is a cosine similarity
    image_embeddings = F.normalize(image_embeddings, dim=-1)
    text_embeddings = F.normalize(text_embeddings, dim=-1)

    # Compute similarity matrix: entry (i, j) compares image i with text j
    logits = image_embeddings @ text_embeddings.T / temperature

    # Labels: diagonal elements should be highest (matching pairs)
    labels = torch.arange(len(logits), device=logits.device)

    # Symmetric loss: image-to-text and text-to-image
    loss_i2t = F.cross_entropy(logits, labels)
    loss_t2i = F.cross_entropy(logits.T, labels)

    return (loss_i2t + loss_t2i) / 2

The key insight: by training on millions of naturally occurring image-caption pairs from the internet, CLIP learned rich, transferable representations without explicit task-specific labels.
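One way this transferability shows up is zero-shot classification: embed a text prompt for each class name, embed the image, and pick the class whose text is most similar. The sketch below is a minimal illustration, not CLIP's actual code. The function name `zero_shot_classify` is hypothetical, and the random tensors stand in for embeddings that would come from the trained encoders.

```python
import torch
import torch.nn.functional as F

def zero_shot_classify(image_embedding, class_text_embeddings, temperature=0.07):
    """Return class probabilities for one image, given one text embedding per class."""
    image_embedding = F.normalize(image_embedding, dim=-1)
    class_text_embeddings = F.normalize(class_text_embeddings, dim=-1)
    logits = class_text_embeddings @ image_embedding / temperature  # [num_classes]
    return logits.softmax(dim=-1)

# Stand-in embeddings (assumption): real ones come from the image and text encoders
torch.manual_seed(0)
text_emb = torch.randn(3, 64)                    # prompts for 3 class names
image_emb = text_emb[1] + 0.1 * torch.randn(64)  # an image that lies close to class 1
probs = zero_shot_classify(image_emb, text_emb)
print(probs.argmax().item())
```

No classifier head is trained here; the class set can be changed just by changing the text prompts.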

Key Architectures

1. Dual Encoder Models

Models like CLIP use a separate encoder for each modality and project both outputs into a shared embedding space, where similar concepts end up close together regardless of modality.

Pros: Efficient retrieval (pre-compute embeddings)
Cons: Limited cross-modal interaction
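The retrieval advantage can be sketched as follows: image embeddings are normalized once and cached, so each text query costs a single matrix-vector product. The names `build_index` and `search` are hypothetical, and the random gallery stands in for real encoder outputs.

```python
import torch
import torch.nn.functional as F

def build_index(image_embeddings):
    """Normalize once, offline; the result can be cached and reused for every query."""
    return F.normalize(image_embeddings, dim=-1)

def search(index, text_embedding, k=3):
    """Return the indices of the k images most similar to the text query."""
    query = F.normalize(text_embedding, dim=-1)
    scores = index @ query  # cosine similarity to every image, shape [N]
    return scores.topk(k).indices

# Stand-in data (assumption): 1000 image embeddings and a query near image 42
torch.manual_seed(0)
gallery = torch.randn(1000, 64)
index = build_index(gallery)
query = gallery[42] + 0.1 * torch.randn(64)
top = search(index, query)
print(top[0].item())
```

The image encoder never runs at query time, which is why dual encoders scale well to large collections.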

2. Fusion Models

Models like BLIP and Flamingo deeply integrate vision and language through cross-attention:

Python
import torch.nn as nn

class CrossAttention(nn.Module):
    def __init__(self, dim, num_heads):
        super().__init__()
        # batch_first=True so inputs are laid out as [batch, sequence, dim]
        self.attention = nn.MultiheadAttention(dim, num_heads, batch_first=True)

    def forward(self, query, key_value):
        """
        query: text features, [B, num_tokens, dim]
        key_value: image features, [B, num_patches, dim]
        """
        # Each text token attends over all image patches
        output, weights = self.attention(query, key_value, key_value)
        return output  # Text enriched with visual information

Pros: Rich cross-modal reasoning
Cons: Computationally expensive

3. Unified Models

The latest generation uses a single transformer for both modalities:

Image patches and word tokens are concatenated into a single sequence and fed through one unified transformer, producing a joint representation that captures cross-modal relationships.
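A minimal sketch of this idea, assuming flattened image patches and integer token ids as inputs: both are projected to the same width, tagged with a learned modality embedding, concatenated, and passed through one standard encoder. The class name `UnifiedEncoder` and all sizes are illustrative, and positional embeddings are omitted for brevity.

```python
import torch
import torch.nn as nn

class UnifiedEncoder(nn.Module):
    def __init__(self, patch_dim=768, vocab_size=30000, dim=256, num_heads=4, num_layers=2):
        super().__init__()
        self.patch_proj = nn.Linear(patch_dim, dim)       # image patches -> model width
        self.token_embed = nn.Embedding(vocab_size, dim)  # word ids -> model width
        self.modality_embed = nn.Embedding(2, dim)        # 0 = image, 1 = text
        layer = nn.TransformerEncoderLayer(
            dim, num_heads, dim_feedforward=4 * dim, batch_first=True
        )
        self.encoder = nn.TransformerEncoder(layer, num_layers)

    def forward(self, patches, token_ids):
        img = self.patch_proj(patches) + self.modality_embed.weight[0]
        txt = self.token_embed(token_ids) + self.modality_embed.weight[1]
        sequence = torch.cat([img, txt], dim=1)  # [B, num_patches + num_tokens, dim]
        return self.encoder(sequence)            # self-attention spans both modalities

model = UnifiedEncoder()
out = model(torch.randn(2, 196, 768), torch.randint(0, 30000, (2, 12)))
print(out.shape)  # torch.Size([2, 208, 256])
```

Because every layer attends over the full concatenated sequence, text tokens and image patches can interact from the first layer onward.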

Applications

Visual Question Answering (VQA)

Given an image and a question, produce an answer:

Input: [Image of a kitchen] + "What color is the refrigerator?"
Output: "The refrigerator is silver/stainless steel."

Image Captioning

Generate natural language descriptions of images:

Python
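# Assumed to come from the tokenizer: BOS_TOKEN and EOS_TOKEN (special token ids)
# and decode (token ids -> string). The model methods are an illustrative interface.
def generate_caption(model, image, max_length=50):
    """Autoregressive caption generation"""
    tokens = [BOS_TOKEN]
    image_features = model.encode_image(image)  # encoded once, reused at every step

    for _ in range(max_length):
        # Re-encode the caption so far and predict the next token from both modalities
        text_features = model.encode_text(tokens)
        combined = model.fuse(image_features, text_features)
        next_token = model.predict_next(combined)

        if next_token == EOS_TOKEN:
            break
        tokens.append(next_token)

    return decode(tokens)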

Visual Grounding

Locate objects in an image based on natural language descriptions:

Input: "Find the red car in the parking lot"
Output: Bounding box coordinates [x, y, width, height]

Text-to-Image Generation

Models like DALL-E and Stable Diffusion generate images from text:

Input: "A cyberpunk city at sunset, neon lights reflecting on wet streets"
Output: [Generated Image]

Challenges and Research Directions

1. Hallucination

Vision-language models sometimes generate plausible-sounding but incorrect descriptions. A model might describe "a cat sitting on a couch" when the image shows a dog.

Python
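import torch
import torch.nn.functional as F

# Detecting potential hallucinations
def check_consistency(model, image, caption, num_samples=5):
    """
    Cross-check a caption against other captions generated for the same image.
    Assumes generate_caption samples stochastically (greedy decoding would return
    the same caption every time) and that model.encode_text accepts a caption
    and returns a 1-D embedding.
    """
    # Generate multiple captions and include the one under test
    captions = [caption] + [generate_caption(model, image) for _ in range(num_samples)]

    # Check semantic consistency via pairwise cosine similarity
    embeddings = torch.stack([model.encode_text(c) for c in captions])
    embeddings = F.normalize(embeddings, dim=-1)
    similarity_matrix = embeddings @ embeddings.T

    # Ignore the diagonal (each caption trivially matches itself)
    off_diagonal = ~torch.eye(len(captions), dtype=torch.bool)

    # Low similarity suggests uncertainty/potential hallucination
    return similarity_matrix[off_diagonal].mean()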

2. Compositional Understanding

Models struggle with compositional concepts:

  • "A red cube on a blue sphere" vs. "A blue cube on a red sphere"
  • Understanding spatial relationships
  • Counting objects accurately
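One intuition for the first example, offered here as an illustration rather than a diagnosis of any particular model: if a text representation behaves like an unordered bag of words, the two captions become indistinguishable. The helper `bag_of_words` is hypothetical.

```python
from collections import Counter

def bag_of_words(caption):
    """Order-free representation: word counts only."""
    return Counter(caption.lower().split())

a = "A red cube on a blue sphere"
b = "A blue cube on a red sphere"

print(a == b)                              # False: the captions differ
print(bag_of_words(a) == bag_of_words(b))  # True: same words, same counts
```

Telling the captions apart requires binding each attribute to the right object, which depends on word order and structure rather than on word content alone.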

3. Bias and Fairness

Training data biases propagate to models. CLIP, for instance, has been shown to exhibit demographic biases in its associations.

4. Efficiency

Large vision-language models require significant compute. Research focuses on:

  • Knowledge distillation
  • Efficient attention mechanisms
  • Model pruning and quantization
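As an example of the first item, knowledge distillation trains a small student model to match a large teacher's output distribution. The sketch below shows the commonly used temperature-softened KL loss. The function name and the temperature of 2.0 are assumptions, and in practice this term is usually combined with the ordinary task loss.

```python
import torch
import torch.nn.functional as F

def distillation_loss(student_logits, teacher_logits, temperature=2.0):
    """KL divergence between temperature-softened teacher and student distributions."""
    student_log_probs = F.log_softmax(student_logits / temperature, dim=-1)
    teacher_probs = F.softmax(teacher_logits / temperature, dim=-1)
    kl = F.kl_div(student_log_probs, teacher_probs, reduction="batchmean")
    # Scaling by T^2 keeps gradient magnitudes comparable across temperatures
    return kl * temperature ** 2

# Stand-in logits (assumption): real ones come from the two models on the same batch
torch.manual_seed(0)
teacher_logits = torch.randn(4, 10)
student_logits = torch.randn(4, 10)
loss = distillation_loss(student_logits, teacher_logits)
```

The loss is zero when the student reproduces the teacher's distribution exactly and grows as the two diverge.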

My Research: Static Key Attention

In my recent work, I've been exploring ways to improve the efficiency of attention mechanisms in vision transformers. The key insight is that not all attention patterns need to be dynamically computed.

Static Key Attention pre-computes certain attention patterns, reducing computational cost while maintaining performance:

Python
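import torch
import torch.nn as nn
import torch.nn.functional as F

class StaticKeyAttention(nn.Module):
    def __init__(self, dim, num_static_keys):
        super().__init__()
        # Static keys learned during training and shared by every input.
        # num_static_keys must equal the number of tokens N so that the
        # [B, N, num_static_keys] attention map can weight the [B, N, dim] values.
        self.static_keys = nn.Parameter(torch.randn(num_static_keys, dim))
        self.query_proj = nn.Linear(dim, dim)
        self.value_proj = nn.Linear(dim, dim)

    def forward(self, x):
        # x: [B, N, dim]
        queries = self.query_proj(x)
        values = self.value_proj(x)

        # Attention with static keys (no key projection of the input)
        attention = queries @ self.static_keys.T  # [B, N, num_static_keys]
        attention = F.softmax(attention, dim=-1)

        return attention @ values  # [B, N, dim]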
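To give a rough sense of where the savings come from in the simplified module above, the sketch below counts multiply-accumulate operations for one single-head layer with and without a key projection. This is back-of-the-envelope arithmetic under stated assumptions (one static key per token, ViT-Base-like sizes of 196 tokens and width 768), not a measured result. The function `attention_macs` is hypothetical.

```python
def attention_macs(num_tokens, dim, static_keys=False):
    """Approximate multiply-accumulates for one single-head attention layer."""
    # Dynamic attention projects Q, K and V; static keys skip the K projection
    num_projections = 2 if static_keys else 3
    projection_cost = num_projections * num_tokens * dim * dim
    score_cost = num_tokens * num_tokens * dim  # queries @ keys^T
    mix_cost = num_tokens * num_tokens * dim    # softmax(scores) @ values
    return projection_cost + score_cost + mix_cost

dynamic = attention_macs(196, 768)
static = attention_macs(196, 768, static_keys=True)
print(f"dynamic: {dynamic:,} MACs")
print(f"static:  {static:,} MACs")
print(f"saved:   {1 - static / dynamic:.1%}")
```

Under these assumptions the saving is exactly one projection, `num_tokens * dim * dim` operations per layer; the quadratic score and mixing terms are unchanged.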

The Future

Vision-language AI is rapidly evolving. Key trends include:

  1. Unified models: Single architectures that handle any modality combination
  2. World models: Learning physical intuition from video
  3. Embodied AI: Robots that understand language commands and visual scenes
  4. Multimodal reasoning: Combining vision, language, and symbolic reasoning

The goal is AI systems with human-like multimodal understanding—systems that don't just process images and text separately but truly comprehend the world through multiple complementary channels.


For more on my vision-language research, see my publications on Google Scholar.
