Principle:Microsoft DeepSpeedExamples Vision Language Projection
Metadata
| Field | Value |
|---|---|
| Page Type | Principle |
| Title | Vision_Language_Projection |
| Sources | Paper: LLaVA (https://arxiv.org/abs/2304.08485), Paper: Perceiver (https://arxiv.org/abs/2103.03206) |
| Domains | Multimodal, Model_Architecture |
| Repository | Microsoft/DeepSpeedExamples |
| Application | DeepSpeed-VisualChat |
| Status | Active |
Overview
A technique for bridging vision and language modalities by projecting visual features into the language model's embedding space.
Description
Vision encoders and language models operate in fundamentally different embedding spaces with different hidden dimensions. A projection layer is required to map the output of the vision encoder into a representation that the language decoder can process alongside text token embeddings.
DeepSpeed-VisualChat supports three projection approaches:
1. Baseline Linear Projection
The simplest approach: a single linear transformation followed by layer normalization.
proj(F) = LayerNorm(Linear(F))
where F in R^(num_patches x vis_dim) -> R^(num_patches x lang_dim)
- Pros -- Minimal parameters, fast computation, straightforward gradient flow
- Cons -- No additional feature refinement; the number of visual tokens equals the number of patches (can be large)
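The baseline can be sketched in a few lines of PyTorch. This is an illustrative module, not the repo's actual implementation; the class name and dimensions are chosen here to match the CLIP ViT-L/14 + LLaMA-2-7B example used later in this page.

```python
import torch
import torch.nn as nn

class BaselineProjection(nn.Module):
    """Linear map from the vision hidden size to the language hidden size, then LayerNorm."""
    def __init__(self, vis_dim: int, lang_dim: int):
        super().__init__()
        self.linear = nn.Linear(vis_dim, lang_dim)
        self.norm = nn.LayerNorm(lang_dim)

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        # features: (batch, num_patches, vis_dim) -> (batch, num_patches, lang_dim)
        return self.norm(self.linear(features))

proj = BaselineProjection(vis_dim=1024, lang_dim=4096)
out = proj(torch.randn(2, 257, 1024))
print(out.shape)  # torch.Size([2, 257, 4096])
```

Note that the number of output tokens is unchanged: every patch becomes one language-space token.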
2. ViT Linear Projection
Adds a CLIPEncoderLayer before the linear projection, providing one additional round of self-attention over the visual features:
proj(F) = LayerNorm(Linear(CLIPEncoderLayer(F)))
- Pros -- The extra transformer layer allows cross-patch feature refinement before projection
- Cons -- Slightly more parameters and computation than baseline
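A sketch of the ViT variant, using `nn.TransformerEncoderLayer` as a stand-in for the `CLIPEncoderLayer` the repo actually instantiates (the stand-in differs in internals such as activation and norm placement, but shows the same structure: one self-attention round, then the linear projection):

```python
import torch
import torch.nn as nn

class VitProjection(nn.Module):
    """One extra self-attention layer over patch features, then Linear + LayerNorm.
    nn.TransformerEncoderLayer is a stand-in for the CLIPEncoderLayer used in the repo."""
    def __init__(self, vis_dim: int, lang_dim: int, num_heads: int = 16):
        super().__init__()
        self.refine = nn.TransformerEncoderLayer(
            d_model=vis_dim, nhead=num_heads, batch_first=True)
        self.linear = nn.Linear(vis_dim, lang_dim)
        self.norm = nn.LayerNorm(lang_dim)

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        # Cross-patch refinement happens before the dimension change.
        return self.norm(self.linear(self.refine(features)))
```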
3. Perceiver Cross-Attention Projection
Uses a set of learned queries (256 tokens arranged in a 16x16 grid) that attend to the visual features via cross-attention. This is inspired by the Perceiver architecture (Jaegle et al., 2021):
proj(F) = Linear(LayerNorm(MultiheadAttn(Q=learned_queries + pos, K=F' + pos, V=F')))
where:
F' = LayerNorm(Linear_kv(F[:, 1:, :])) # project and remove CLS token
learned_queries in R^(256 x lang_dim)
pos = sincos_2d_positional_embeddings
- Pros -- Compresses variable-length visual features into a fixed number of tokens (256), regardless of input resolution; the learned queries act as a cross-modal information bottleneck
- Cons -- More parameters; the learned queries need sufficient training to capture relevant visual information
Theoretical Basis
Dimension Mismatch Problem
Consider a CLIP ViT-L/14 encoder producing features of dimension 1024 and a LLaMA-2-7B decoder with embedding dimension 4096. Without projection:
vis_features in R^(num_patches x 1024)
text_embeddings in R^(seq_len x 4096)
# Cannot concatenate: dimension mismatch!
The projection layer resolves this:
proj_features = projection(vis_features) in R^(num_visual_tokens x 4096)
combined = concat(proj_features, text_embeddings) in R^((num_visual_tokens + seq_len) x 4096)
Linear Projection (LLaVA-style)
The LLaVA approach (Liu et al., 2023) demonstrates that a simple linear projection can be highly effective:
proj(F) = LayerNorm(W * F + b)
where W in R^(lang_dim x vis_dim), b in R^(lang_dim)
The key insight is that pre-trained vision and language features already share significant semantic structure from their respective contrastive/generative pre-training, so a linear mapping suffices to align them.
Perceiver Cross-Attention (Resampler)
The Perceiver-style approach uses multi-head cross-attention with learned queries:
Q = learned_queries in R^(256 x lang_dim)
K = V = projected_visual_features in R^((num_patches - 1) x lang_dim)
Attention(Q, K, V) = softmax(Q * K^T / sqrt(d_k)) * V
Output in R^(256 x lang_dim)
Key design choices in DeepSpeed-VisualChat's implementation:
- 256 learned queries from a 16x16 grid (`grid_size = 16`, `num_queries = 16^2 = 256`)
- Sine-cosine 2D positional embeddings -- Fixed (non-learned) positional encodings added to both queries and keys
- CLS token removal -- The CLS token is stripped (`x[:, 1:, :]`) before cross-attention, since the queries learn their own global representations
- Number of attention heads computed as `lang_dim // 128` (e.g., 32 heads for LLaMA-2-7B with `lang_dim = 4096`)
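These design choices can be combined into a compact sketch. This is a simplified stand-in for the repo's module: the 2D sin-cos positional embeddings are omitted for brevity, and the class/parameter names are illustrative.

```python
import torch
import torch.nn as nn

class PerceiverProjection(nn.Module):
    """256 learned queries cross-attend to projected patch features.
    Simplified sketch: the 2D sin-cos positional embeddings are omitted."""
    def __init__(self, vis_dim: int, lang_dim: int, grid_size: int = 16):
        super().__init__()
        n = grid_size * grid_size                      # 16^2 = 256 queries
        self.queries = nn.Parameter(torch.zeros(n, lang_dim))
        nn.init.trunc_normal_(self.queries, std=0.02)  # truncated normal init
        self.kv_proj = nn.Linear(vis_dim, lang_dim)    # Linear_kv in the formula
        self.kv_norm = nn.LayerNorm(lang_dim)
        self.attn = nn.MultiheadAttention(
            lang_dim, num_heads=lang_dim // 128, batch_first=True)
        self.out_norm = nn.LayerNorm(lang_dim)
        self.out_proj = nn.Linear(lang_dim, lang_dim)

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        # features: (batch, num_patches, vis_dim); drop the CLS token at index 0
        kv = self.kv_norm(self.kv_proj(features[:, 1:, :]))
        q = self.queries.unsqueeze(0).expand(features.size(0), -1, -1)
        out, _ = self.attn(q, kv, kv)                  # (batch, 256, lang_dim)
        return self.out_proj(self.out_norm(out))

# Small dims for the demo; DeepSpeed-VisualChat would use lang_dim=4096 (32 heads).
proj = PerceiverProjection(vis_dim=1024, lang_dim=512)
tokens = proj(torch.randn(1, 257, 1024))  # always (1, 256, 512), per input size
```

Regardless of how many patches come in, exactly 256 tokens come out, which is the compression property the pros/cons above describe.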
Positional Embedding Interpolation
When the input image resolution changes, the number of patches changes. The positional embeddings are interpolated using bicubic interpolation:
pos_embed_new = bicubic_interpolate(
pos_embed.reshape(src_size, src_size, C),
target_size=(tgt_size, tgt_size)
)
This enables resolution flexibility without retraining the positional embeddings.
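A minimal sketch of this interpolation with `torch.nn.functional.interpolate` (the function name and the 16-to-24 grid example are illustrative; a 24x24 grid would correspond to, e.g., 336x336 input with 14x14 patches):

```python
import torch
import torch.nn.functional as F

def interpolate_pos_embed(pos_embed: torch.Tensor,
                          src_size: int, tgt_size: int) -> torch.Tensor:
    """Resize (src_size*src_size, C) positional embeddings to
    (tgt_size*tgt_size, C) via bicubic interpolation."""
    C = pos_embed.shape[-1]
    # (src*src, C) -> (1, C, src, src), the layout F.interpolate expects
    grid = pos_embed.reshape(src_size, src_size, C).permute(2, 0, 1).unsqueeze(0)
    grid = F.interpolate(grid, size=(tgt_size, tgt_size),
                         mode="bicubic", align_corners=False)
    return grid.squeeze(0).permute(1, 2, 0).reshape(tgt_size * tgt_size, C)

pos = torch.randn(16 * 16, 1024)              # 16x16 grid of 1024-dim embeddings
new_pos = interpolate_pos_embed(pos, 16, 24)  # resize to a 24x24 grid
print(new_pos.shape)  # torch.Size([576, 1024])
```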
Key Considerations
- Projection type selection -- When using Qwen-VL's modified CLIP (which has an internal Perceiver-like attention pooling), only the `baseline` projection is supported to avoid redundancy. For standard CLIP models, all three projection types (`baseline`, `vit`, `perceiver`) are available.
- Number of visual tokens -- Baseline and ViT projections preserve the number of patches (e.g., 257 tokens for CLIP ViT-L/14 at 224x224). The Perceiver projection reduces this to a fixed 256 tokens.
- Training strategy -- The projection layer is always trained (frozen encoder, frozen language decoder, trainable projection).
- Weight initialization -- The Perceiver projection uses truncated normal initialization (`std=0.02`) for the learned queries and linear layers, and constant initialization for the LayerNorm parameters.
Related Pages
- Implementation:Microsoft_DeepSpeedExamples_VisProjection -- The concrete projection module implementations
- Principle:Microsoft_DeepSpeedExamples_Vision_Encoder_Extraction -- The vision encoder that produces features for projection
- Principle:Microsoft_DeepSpeedExamples_Multimodal_Model_Composition -- The full model that uses the projection layer