
Principle:Microsoft DeepSpeedExamples Multimodal Model Composition

From Leeroopedia



Metadata

Page Type: Principle
Title: Multimodal_Model_Composition
Sources: Paper: LLaVA (https://arxiv.org/abs/2304.08485); Paper: DeepSpeed-VisualChat (https://arxiv.org/abs/2309.14327)
Domains: Multimodal, Model_Architecture, NLP
Repository: Microsoft/DeepSpeedExamples
Application: DeepSpeed-VisualChat
Status: Active

Overview

An architecture pattern that composes a vision encoder, projection layer, and language decoder into a unified multimodal model for visual question answering and chat.

Description

The DeepSpeed-VisualChat model is a three-stage composed architecture that processes interleaved image and text inputs to produce natural language responses. Rather than training a monolithic multimodal model from scratch, this approach composes three pre-trained or purpose-built components:

Stage 1: Vision Encoder

The vision encoder processes raw images into dense feature sequences:

  • Accepts images of fixed resolution (e.g., 224x224 for CLIP or 448x448 for Qwen-VL)
  • Outputs a tensor of shape [num_images, num_patches, vis_dim]
  • The encoder is frozen during training (no gradient computation)

Supported encoders:

  • Standard CLIP models (e.g., openai/clip-vit-large-patch14) loaded via CLIPVisionModel
  • Qwen-VL's modified CLIP (a ViT-bigG variant with output dimension 4096) loaded as a standalone VisionTransformer
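The freeze-the-encoder pattern can be sketched as follows. This is a minimal illustration using a toy stand-in module and toy dimensions; real code would load an actual encoder (e.g., `CLIPVisionModel.from_pretrained("openai/clip-vit-large-patch14")` from Hugging Face transformers, which yields 257 patch tokens of dimension 1024 per 224x224 image):

```python
import torch
import torch.nn as nn

# Toy stand-in for a CLIP-style vision encoder; dimensions are illustrative
# (real CLIP ViT-L/14: vis_dim=1024, num_patches=257 per 224x224 image).
vis_dim, num_patches = 32, 5
encoder = nn.Sequential(nn.Flatten(1), nn.Linear(3 * 8 * 8, num_patches * vis_dim))

# Freeze: no parameter updates and no gradient bookkeeping for the encoder.
encoder.requires_grad_(False)
encoder.eval()

images = torch.randn(2, 3, 8, 8)      # [num_images, C, H, W]
with torch.no_grad():                 # also skip storing activations
    feats = encoder(images).view(2, num_patches, vis_dim)

print(feats.shape)  # torch.Size([2, 5, 32]): [num_images, num_patches, vis_dim]
```

Running the encoder under `torch.no_grad()` in addition to freezing its parameters is what saves activation memory, since no intermediate tensors are retained for backpropagation.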

Stage 2: Projection Layer

Maps visual features from the vision encoder's dimension to the language decoder's embedding dimension:

  • Three options: baseline (Linear + LayerNorm), vit (CLIPEncoderLayer + Linear + LayerNorm), perceiver (cross-attention with learned queries)
  • The projection layer is always trainable
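The baseline option (Linear + LayerNorm) is the simplest of the three and can be sketched as below; the dimensions are illustrative (a real CLIP ViT-L/14 to LLaMA-2-7B pairing would map 1024 to 4096):

```python
import torch
import torch.nn as nn

# Sketch of the "baseline" projection: Linear + LayerNorm, with toy dims.
vis_dim, lm_dim = 32, 64
projection = nn.Sequential(nn.Linear(vis_dim, lm_dim), nn.LayerNorm(lm_dim))

# Unlike the frozen vision encoder, the projection is always trainable.
assert all(p.requires_grad for p in projection.parameters())

vis_feats = torch.randn(2, 5, vis_dim)   # [num_images, num_patches, vis_dim]
proj_feats = projection(vis_feats)       # mapped into the decoder's space
print(proj_feats.shape)  # torch.Size([2, 5, 64])
```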

Stage 3: Language Decoder

A causal language model that processes the concatenated visual and text token embeddings:

  • Currently supports LLaMA-2 family models
  • Produces autoregressive text generation conditioned on visual context
  • Can be fine-tuned with LoRA adapters for parameter efficiency
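The LoRA idea (frozen base weights plus a small trainable low-rank adapter) can be sketched manually; in practice it is applied to the decoder's attention projections via a library such as peft, and the class below is a hypothetical illustration:

```python
import torch
import torch.nn as nn

# Hypothetical minimal LoRA wrapper: frozen base linear + trainable
# low-rank adapter (rank r), as used for parameter-efficient fine-tuning.
class LoRALinear(nn.Module):
    def __init__(self, base: nn.Linear, r: int = 8, alpha: int = 16):
        super().__init__()
        self.base = base
        self.base.requires_grad_(False)      # base decoder weights stay frozen
        self.lora_a = nn.Linear(base.in_features, r, bias=False)
        self.lora_b = nn.Linear(r, base.out_features, bias=False)
        nn.init.zeros_(self.lora_b.weight)   # adapter starts as a no-op
        self.scale = alpha / r

    def forward(self, x):
        return self.base(x) + self.scale * self.lora_b(self.lora_a(x))

layer = LoRALinear(nn.Linear(64, 64))
trainable = sum(p.numel() for p in layer.parameters() if p.requires_grad)
total = sum(p.numel() for p in layer.parameters())
print(trainable, total)  # 1024 adapter params trainable out of 5184 total
```

Only the adapter matrices receive gradients, which is why LoRA fine-tuning of the decoder is cheap relative to full fine-tuning.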

Concatenation and Interleaving

The core innovation is the interleaving of visual and text tokens within a single sequence:

Input sequence: [system_prompt] [### Image 1:] [vis_tokens_1] [### Question:] [text_tokens] [### Answer:]

For multi-image inputs:

Input sequence: [### Image 1:] [vis_tokens_1] [### Image 2:] [vis_tokens_2] [### Question:] [text] [### Answer:]

The <image> placeholder tokens in the text are replaced with the actual projected visual feature tensors at runtime. The concatenation process:

  1. Text tokens are embedded via the language model's embedding layer
  2. <image> token positions are identified in the input IDs
  3. Projected visual features are inserted at those positions, replacing the placeholder
  4. The resulting mixed embedding sequence is padded to a uniform length
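The four steps above can be sketched for a single sequence with one placeholder; token IDs and dimensions here are toy values, not the application's actual vocabulary:

```python
import torch
import torch.nn as nn

# Sketch of visual-token injection: embed text, find the <image>
# placeholder, and splice in the projected visual features.
IMAGE_TOKEN_ID, vocab, lm_dim, num_patches = 99, 100, 16, 3
embed = nn.Embedding(vocab, lm_dim)

input_ids = torch.tensor([5, 7, IMAGE_TOKEN_ID, 11, 12])  # one <image> slot
vis_feats = torch.randn(num_patches, lm_dim)              # already projected

text_embeds = embed(input_ids)                            # step 1: embed text
pos = (input_ids == IMAGE_TOKEN_ID).nonzero().item()      # step 2: locate slot
mixed = torch.cat([text_embeds[:pos],                     # step 3: replace the
                   vis_feats,                             # placeholder with the
                   text_embeds[pos + 1:]], dim=0)         # visual token run

print(mixed.shape)  # torch.Size([7, 16]): 4 text tokens + 3 visual tokens
```

Step 4 (padding the mixed sequence to a uniform length across the batch) happens afterwards, since sequences with different image counts end up with different lengths.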

Theoretical Basis

Visual Token Injection

The key theoretical insight is that projected visual features can be treated as additional "tokens" in the language model's input:

hidden_states = concat(
    text_embed(tokens_before_image),
    projection(vis_encoder(image)),
    text_embed(tokens_after_image)
)

This works because:

  • The projection layer maps visual features to the same dimensional space as text embeddings
  • The language model's self-attention mechanism can attend to both visual and text tokens
  • Causal masking ensures proper autoregressive generation

Multi-Modal Causal Attention (MMCA)

DeepSpeed-VisualChat introduces MMCA, a modified attention mechanism that distinguishes between visual and text tokens in the attention computation:

attention_mask values:
    0 = padding (ignored)
    1 = text token (standard causal attention)
    2 = image token (visual attention pattern)

When enable_mmca_attention is set, the attention mechanism applies different masking patterns for image-to-text and text-to-image attention, similar to cross-attention but within a unified self-attention framework.
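Constructing the tri-valued mask can be sketched as below. This is illustrative: it marks visual positions by placeholder token ID, whereas the real pipeline tracks the positions where projected visual features were spliced in.

```python
import torch

# Sketch of the tri-valued attention_mask: 0 = padding, 1 = text, 2 = image.
IMAGE_TOKEN_ID, PAD_ID = 99, 0
input_ids = torch.tensor([5, IMAGE_TOKEN_ID, IMAGE_TOKEN_ID, 7, 8, PAD_ID])

mask = torch.ones_like(input_ids)        # default: text token (1)
mask[input_ids == IMAGE_TOKEN_ID] = 2    # mark visual token positions
mask[input_ids == PAD_ID] = 0            # mark padding

print(mask.tolist())  # [1, 2, 2, 1, 1, 0]
```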

Loss Computation

The model computes cross-entropy loss only on the answer tokens, not on the instruction or image tokens:

labels = [-100, -100, ..., -100, answer_token_1, answer_token_2, ..., eos]
         |--- instruction ---|  |---------- answer region ---------|

loss = CrossEntropyLoss(logits[labels != -100], labels[labels != -100])

The -100 label value (matching PyTorch's CrossEntropyLoss ignore index) is used to mask out instruction tokens, image tokens, and padding from the loss computation.
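The masking scheme above can be verified in a few lines: because -100 is PyTorch's default `ignore_index`, passing the full label tensor to cross-entropy is equivalent to explicitly filtering out the masked positions.

```python
import torch
import torch.nn.functional as F

# Answer-only loss masking: instruction/image/pad positions get label -100,
# which F.cross_entropy ignores by default (ignore_index=-100).
vocab = 10
logits = torch.randn(6, vocab)
labels = torch.tensor([-100, -100, -100, 4, 7, 2])  # loss on last 3 tokens

loss_all = F.cross_entropy(logits, labels)          # -100 ignored automatically
keep = labels != -100
loss_manual = F.cross_entropy(logits[keep], labels[keep])

assert torch.allclose(loss_all, loss_manual)        # identical by construction
```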

Trainable vs. Frozen Components

Component               | Trainable?       | Rationale
Vision Encoder          | No (frozen)      | Pre-trained features are sufficient; large parameter count
Projection Layer        | Yes              | Must learn to bridge the specific encoder-decoder pair
Language Embedding      | Yes              | Extended vocabulary for special tokens (<image>, etc.)
Language Decoder (base) | No (frozen)      | Pre-trained language capabilities preserved
Language Decoder (LoRA) | Yes (if enabled) | Small adapter weights for task-specific fine-tuning

Key Considerations

  • Memory efficiency -- The vision encoder runs in torch.no_grad() mode when frozen, saving significant GPU memory by not storing activations for backpropagation.
  • Token limit -- The maximum sequence length (max_seq_len, default 4096) must accommodate both visual tokens and text tokens. With CLIP ViT-L/14, each image contributes ~257 tokens; multiple images can quickly exhaust the context window.
  • Vocabulary extension -- Special tokens (<image>, <im_patch>, <im_start>, <im_end>) are added to the tokenizer and the language model's embedding layer is resized accordingly.
  • Padding strategy -- Variable-length sequences (from different numbers of images) are padded using the padding token embedding, with padding on the right side and divisible-by-8 alignment for hardware efficiency.
  • Gradient checkpointing -- Both the vision encoder and language decoder support gradient checkpointing to reduce memory usage during training at the cost of recomputation.
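The right-padding-with-divisible-by-8-alignment strategy can be sketched as follows; the zero vector here stands in for the padding-token embedding, and the helper name is illustrative:

```python
import torch

# Sketch: right-pad a mixed embedding sequence to a length divisible by 8
# (divisible-by-8 lengths help tensor-core efficiency on modern GPUs).
def pad_to_multiple(seq, pad_vec, multiple=8):
    target = -(-seq.shape[0] // multiple) * multiple  # ceil to the multiple
    extra = pad_vec.expand(target - seq.shape[0], -1)
    return torch.cat([seq, extra], dim=0)             # pad on the right

seq = torch.randn(13, 16)                             # 13 mixed tokens, dim 16
padded = pad_to_multiple(seq, torch.zeros(1, 16))
print(padded.shape)  # torch.Size([16, 16]): 13 rounded up to 16
```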
