

Implementation:Microsoft DeepSpeedExamples VisProjection

From Leeroopedia


Implementation: VisProjection

Metadata

Field | Value
Page Type | Implementation (Pattern Doc)
Title | VisProjection
Repository | Microsoft/DeepSpeedExamples
Application | DeepSpeed-VisualChat
File | applications/DeepSpeed-VisualChat/utils/model/vis_proj.py
Lines | 15-153
Language | Python
Status | Active

Overview

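Provides the vision-to-language projection modules of DeepSpeed-VisualChat: a ViT-style mode (one CLIPEncoderLayer followed by Linear + LayerNorm) and a Perceiver-style mode (learned queries that cross-attend to the visual tokens). A plain Linear + LayerNorm baseline is also available through build_projection().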

Code Reference

VisProjection_vit (Lines 15-25)

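import torch.nn as nn
from transformers.models.clip.modeling_clip import CLIPEncoderLayer

class VisProjection_vit(nn.Module):
    def __init__(self, vis_config, lang_dim):
        super().__init__()
        # One extra CLIP transformer block (self-attention + MLP) at the vision width
        self.vis_layer = CLIPEncoderLayer(vis_config)
        # Token-wise map from the vision width to the language width
        self.projection = nn.Sequential(
            nn.Linear(vis_config.hidden_size, lang_dim),
            nn.LayerNorm(lang_dim, eps=1e-12))

    def forward(self, vis_input):
        # attention_mask=None, causal_attention_mask=None;
        # [0] selects the hidden states from the returned tuple
        vis_feature = self.vis_layer(vis_input, None, None)[0]
        return self.projection(vis_feature)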

This module chains a single CLIPEncoderLayer (self-attention + MLP) with a linear projection and layer normalization. The CLIPEncoderLayer gives the patch tokens one more round of self-attention before they are mapped from vis_config.hidden_size to lang_dim; the number of tokens is unchanged.
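The token-wise Linear + LayerNorm step at the end of VisProjection_vit (which is also the whole of the baseline projection) can be sketched in plain Python. This is an illustration, not the module itself: it assumes LayerNorm's default affine parameters (weight 1, bias 0), uses made-up toy weights, and omits the preceding CLIPEncoderLayer.

```python
import math

def linear(x, weight, bias):
    # y_j = sum_i x_i * W[j][i] + b_j (weight stored as [out, in], like nn.Linear)
    return [sum(xi * wji for xi, wji in zip(x, row)) + b for row, b in zip(weight, bias)]

def layer_norm(y, eps=1e-12):
    # LayerNorm with its default affine init (gamma = 1, beta = 0)
    mean = sum(y) / len(y)
    var = sum((v - mean) ** 2 for v in y) / len(y)  # biased variance, as in nn.LayerNorm
    return [(v - mean) / math.sqrt(var + eps) for v in y]

def project_tokens(tokens, weight, bias, eps=1e-12):
    # Applied independently to every token, so the sequence length is preserved
    return [layer_norm(linear(t, weight, bias), eps) for t in tokens]

# Toy example: 3 tokens with vis_dim = 2 mapped to lang_dim = 4 (hypothetical weights)
tokens = [[1.0, 2.0], [0.5, -1.0], [3.0, 0.0]]
weight = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [1.0, -1.0]]
bias = [0.0, 0.0, 0.0, 0.0]
out = project_tokens(tokens, weight, bias)
print(len(out), len(out[0]))  # 3 4
```

Because every token is projected on its own, the output has as many tokens as the input, which is why the vit and baseline modes keep the patch count.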

VisProjection_perceiver (Lines 97-153)

import torch
import torch.nn as nn
from torch.nn.init import trunc_normal_
# get_2d_sincos_pos_embed and get_abs_pos are defined in the same file (see Supporting Functions)

class VisProjection_perceiver(nn.Module):
    def __init__(self, vis_config, lang_dim):
        super().__init__()
        grid_size = 16
        self.num_queries = grid_size ** 2     # 256 learned queries
        self.embed_dim = lang_dim
        self.num_heads = lang_dim // 128      # e.g., 32 for 4096-dim (head_dim 128)

        self.pos_embed = nn.Parameter(
            torch.from_numpy(
                get_2d_sincos_pos_embed(lang_dim, grid_size)
            ).float()
        ).requires_grad_(False)               # fixed positional embeddings

        self.query = nn.Parameter(torch.zeros(self.num_queries, lang_dim))
        trunc_normal_(self.query, std=.02)    # truncated normal init

        self.kv_proj = nn.Linear(vis_config.hidden_size, lang_dim)
        self.attn = nn.MultiheadAttention(lang_dim, self.num_heads)  # expects (seq, batch, dim)
        self.ln_q = nn.LayerNorm(lang_dim, eps=1e-12)
        self.ln_kv = nn.LayerNorm(lang_dim, eps=1e-12)
        self.projection = nn.Sequential(
            nn.LayerNorm(lang_dim, eps=1e-12),
            nn.Linear(lang_dim, lang_dim))

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, x, attn_mask=None):
        # x.size(1) still counts the CLS token; get_abs_pos takes int(sqrt(...)),
        # which truncates num_patches + 1 back to the patch-grid side length
        pos_embed = get_abs_pos(self.pos_embed, x.size(1))

        x = x[:, 1:, :]                     # remove CLS token
        x = self.kv_proj(x)                  # project to lang_dim
        x = self.ln_kv(x).permute(1, 0, 2)  # (seq, batch, dim) for attn

        N = x.shape[1]                       # batch size (dim 1 after the permute)
        q = self.ln_q(self.query)
        out = self.attn(
            self._repeat(q, N) + self.pos_embed.unsqueeze(1),  # queries: (256, N, dim)
            x + pos_embed.unsqueeze(1),                         # keys: visual tokens + pos
            x,                                                  # values: visual tokens
            attn_mask=attn_mask)[0]
        return self.projection(out.permute(1, 0, 2))  # back to (batch, 256, dim)

    def _repeat(self, query, N: int):
        return query.unsqueeze(1).repeat(1, N, 1)   # (256, dim) -> (256, N, dim)
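The key property of the Perceiver mode is that the number of output tokens equals the number of learned queries (256), whatever the number of visual tokens. The plain-Python sketch below shows this with single-head scaled dot-product cross-attention. It is a simplification, not the module above: nn.MultiheadAttention additionally applies learned input and output projections and splits lang_dim into lang_dim // 128 heads, and the LayerNorms and positional embeddings are omitted here.

```python
import math

def softmax(scores):
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]  # subtract the max for numerical stability
    total = sum(exps)
    return [e / total for e in exps]

def cross_attention(queries, keys, values):
    """Single-head scaled dot-product attention: one output row per query."""
    d = len(queries[0])
    out = []
    for q in queries:
        scores = [sum(qi * ki for qi, ki in zip(q, k)) / math.sqrt(d) for k in keys]
        weights = softmax(scores)
        out.append([sum(w * v[j] for w, v in zip(weights, values))
                    for j in range(len(values[0]))])
    return out

# 4 hypothetical "learned" queries attending over 9 or 16 visual tokens
q = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [-1.0, 0.5]]
kv9 = [[float(i), float(i % 3)] for i in range(9)]
kv16 = [[float(i), float(i % 4)] for i in range(16)]
print(len(cross_attention(q, kv9, kv9)))    # 4
print(len(cross_attention(q, kv16, kv16)))  # 4
```

Each output row is a convex combination of the value rows, so the queries act as a fixed-size summary of a variable-length visual sequence.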

Supporting Functions

Positional Embedding Interpolation (Lines 29-45)

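import math
import torch.nn.functional as F

def get_abs_pos(abs_pos, tgt_size):
    """Interpolate positional embeddings when resolution changes."""
    # abs_pos: [src_size**2, dim]; tgt_size: number of target tokens
    src_size = int(math.sqrt(abs_pos.size(0)))   # side of the source grid
    tgt_size = int(math.sqrt(tgt_size))          # side of the target grid (truncated)
    dtype = abs_pos.dtype

    if src_size != tgt_size:
        # [N, dim] -> [1, dim, src, src], bicubic resize, then back to [tgt*tgt, dim]
        return F.interpolate(
            abs_pos.float().reshape(1, src_size, src_size, -1).permute(0, 3, 1, 2),
            size=(tgt_size, tgt_size),
            mode="bicubic",
            align_corners=False,
        ).permute(0, 2, 3, 1).flatten(0, 2).to(dtype=dtype)
    else:
        return abs_pos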
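VisProjection_perceiver.forward passes x.size(1), which still includes the CLS token, to get_abs_pos. This works because int(sqrt(g*g + 1)) == g for any grid side g >= 1. The sketch below mirrors only the size arithmetic (no tensors). The token counts are illustrative examples: a patch-14 ViT at 224 px gives a 16 x 16 grid plus CLS (257 tokens), and at 336 px a 24 x 24 grid plus CLS (577 tokens).

```python
import math

def key_grid_sizes(num_pos_embed, num_vis_tokens):
    """Mirror the size arithmetic in get_abs_pos (no tensors involved)."""
    src_size = int(math.sqrt(num_pos_embed))   # side of the stored 16 x 16 grid
    tgt_size = int(math.sqrt(num_vis_tokens))  # int() truncates, absorbing a +1 CLS token
    return src_size, tgt_size, src_size != tgt_size  # last item: bicubic resize needed?

print(key_grid_sizes(256, 257))  # (16, 16, False): stored embeddings used as-is
print(key_grid_sizes(256, 577))  # (16, 24, True): resized to 24 x 24 = 576 rows
```

When the target grid differs, only the key-side positional embeddings are resized; the 256 query positions always use the stored 16 x 16 grid.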

2D Sine-Cosine (sincos) Positional Embeddings (Lines 48-94)

import numpy as np

def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
    """Generate 2D sincos positional embeddings.
    Returns: pos_embed of shape [grid_size*grid_size, embed_dim]
    """
    grid_h = np.arange(grid_size, dtype=np.float32)
    grid_w = np.arange(grid_size, dtype=np.float32)
    grid = np.meshgrid(grid_w, grid_h)  # w passed first, so grid[0] holds column indices
    grid = np.stack(grid, axis=0)
    grid = grid.reshape([2, 1, grid_size, grid_size])
    # Helper defined in the same file (not shown on this page)
    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    if cls_token:
        # optional all-zero row prepended for a CLS token
        pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
    return pos_embed
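The body of get_2d_sincos_pos_embed_from_grid is not shown on this page. Assuming it follows the widely used MAE-style construction (half of the channels encode grid[0], the other half grid[1], and each half is [sin | cos] with frequencies 1 / 10000^(k / (D/4))), a pure-Python sketch looks like this. The function names sincos_1d and sincos_2d are hypothetical.

```python
import math

def sincos_1d(embed_dim, positions):
    # Assumed MAE-style 1D encoding: [sin(p * w_k) | cos(p * w_k)], w_k = 1 / 10000**(k / half)
    assert embed_dim % 2 == 0
    half = embed_dim // 2
    omega = [1.0 / 10000 ** (k / half) for k in range(half)]
    rows = []
    for p in positions:
        angles = [p * w for w in omega]
        rows.append([math.sin(a) for a in angles] + [math.cos(a) for a in angles])
    return rows

def sincos_2d(embed_dim, grid_size):
    # Row-major over (h, w), matching how the numpy grid is flattened
    assert embed_dim % 4 == 0
    coord0 = [float(w) for h in range(grid_size) for w in range(grid_size)]  # grid[0]: column
    coord1 = [float(h) for h in range(grid_size) for w in range(grid_size)]  # grid[1]: row
    emb0 = sincos_1d(embed_dim // 2, coord0)
    emb1 = sincos_1d(embed_dim // 2, coord1)
    return [a + b for a, b in zip(emb0, emb1)]  # concatenate -> [grid_size**2, embed_dim]

pe = sincos_2d(8, 4)
print(len(pe), len(pe[0]))  # 16 8
```

Because the embeddings are a fixed function of position, the module stores them with requires_grad_(False), and lang_dim must be divisible by 4 under this construction.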

I/O Contract

VisProjection_vit

Direction | Parameter | Type | Shape | Description
Init input | vis_config | CLIPVisionConfig | -- | Vision encoder configuration (hidden_size, etc.)
Init input | lang_dim | int | -- | Language model hidden dimension (e.g., 4096)
Forward input | vis_input | torch.Tensor | [batch, num_patches, vis_dim] | Visual features from the vision encoder (vis_dim = vis_config.hidden_size)
Forward output | (return) | torch.Tensor | [batch, num_patches, lang_dim] | Projected features in the language embedding space

VisProjection_perceiver

Direction | Parameter | Type | Shape | Description
Init input | vis_config | CLIPVisionConfig | -- | Vision encoder configuration
Init input | lang_dim | int | -- | Language model hidden dimension; lang_dim // 128 sets the number of attention heads
Forward input | x | torch.Tensor | [batch, num_patches+1, vis_dim] | Visual features including the leading CLS token (dropped inside forward)
Forward input | attn_mask | torch.Tensor or None | -- | Optional attention mask for the cross-attention
Forward output | (return) | torch.Tensor | [batch, 256, lang_dim] | 256 projected query tokens, independent of num_patches
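The output shapes in the I/O contract can be summarized by a small shape-only helper. This is a sketch for checking expectations, not part of the repository; projected_shape is a hypothetical name, and the baseline branch assumes the token-wise Linear + LayerNorm returned by build_projection().

```python
def projected_shape(vis_proj, batch, seq_len, lang_dim, num_queries=256):
    """Output shape of each projection mode (shapes only, no tensors)."""
    if vis_proj in ("baseline", "vit"):
        return (batch, seq_len, lang_dim)      # token-wise: sequence length preserved
    if vis_proj == "perceiver":
        return (batch, num_queries, lang_dim)  # fixed number of learned queries
    raise ValueError(f"unknown vis_proj: {vis_proj!r}")

print(projected_shape("vit", 2, 257, 4096))        # (2, 257, 4096)
print(projected_shape("perceiver", 2, 257, 4096))  # (2, 256, 4096)
```

The fixed 256-token output of the Perceiver mode keeps the number of image tokens fed to the language model constant, whereas the other two modes scale with the vision encoder's sequence length.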

Usage Example

Selecting the Projection Type

The projection type is selected via the --vis_proj command-line argument and instantiated in DeepSpeedViLModel.build_projection():

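def build_projection(self, vis_config, lang_dim):
    if self.args.vis_proj == 'vit':
        # extra CLIPEncoderLayer + Linear + LayerNorm
        return VisProjection_vit(vis_config, lang_dim=lang_dim)
    elif self.args.vis_proj == 'baseline':
        # token-wise Linear + LayerNorm only
        return nn.Sequential(
            nn.Linear(vis_config.hidden_size, lang_dim),
            nn.LayerNorm(lang_dim, eps=1e-12))
    elif self.args.vis_proj == 'perceiver':
        # 256 learned queries cross-attending to the visual tokens
        return VisProjection_perceiver(vis_config, lang_dim=lang_dim)
    # any other value falls through and returns None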

Command-Line Selection

# Use ViT-style projection (CLIPEncoderLayer + Linear + LayerNorm)
deepspeed training/main.py --vis_proj vit ...

# Use Perceiver cross-attention projection
deepspeed training/main.py --vis_proj perceiver ...

# Use simple baseline projection (required for Qwen-VL)
deepspeed training/main.py --vis_proj baseline ...
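Since build_projection() returns None for an unrecognized value, restricting the flag at parse time is one way to fail early. The sketch below is hypothetical: it shows how such a flag could be declared with argparse, and the actual definition in training/main.py (default, help text, whether it is required) may differ.

```python
import argparse

def build_parser():
    # Hypothetical declaration; the real --vis_proj flag in training/main.py may differ
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--vis_proj",
        type=str,
        required=True,  # design choice of this sketch: no silent default
        choices=["vit", "baseline", "perceiver"],
        help="vision-to-language projection used by build_projection()",
    )
    return parser

args = build_parser().parse_args(["--vis_proj", "perceiver"])
print(args.vis_proj)  # perceiver
```

With choices set, argparse rejects an unknown value such as --vis_proj mlp with a usage error instead of letting build_projection() return None.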

Architecture Details

Projection Type | Trainable Parameters | Output Tokens | Key Components
baseline | vis_dim * lang_dim + lang_dim (Linear) + 2 * lang_dim (LayerNorm) | Same as input patches | Linear + LayerNorm
vit | CLIPEncoderLayer params + Linear + LayerNorm | Same as input patches | CLIPEncoderLayer + Linear + LayerNorm
perceiver | query (256 * lang_dim) + kv_proj + attn + ln_q + ln_kv + projection (pos_embed is frozen) | Fixed 256 | Learned queries, MultiheadAttention, fixed 2D sincos positional embeddings
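The trainable-parameter column can be turned into concrete counts. This sketch assumes the Hugging Face CLIPEncoderLayer layout (q/k/v/out projections with bias, an fc1/fc2 MLP, two LayerNorms) and PyTorch's default nn.MultiheadAttention packing (in_proj of 3 * lang_dim rows plus out_proj, all with bias). The example sizes vis_dim = 1024, MLP width 4096 and lang_dim = 4096 are illustrative assumptions, not values stated on this page.

```python
def linear_params(d_in, d_out):
    return d_in * d_out + d_out                 # weight + bias

def layernorm_params(d):
    return 2 * d                                # gamma + beta

def clip_encoder_layer_params(d, d_ff):
    attn = 4 * linear_params(d, d)              # q_proj, k_proj, v_proj, out_proj
    mlp = linear_params(d, d_ff) + linear_params(d_ff, d)  # fc1, fc2
    return attn + mlp + 2 * layernorm_params(d)  # layer_norm1, layer_norm2

def projection_params(vis_proj, vis_dim, lang_dim, d_ff=None, num_queries=256):
    if vis_proj == "baseline":
        return linear_params(vis_dim, lang_dim) + layernorm_params(lang_dim)
    if vis_proj == "vit":
        return (clip_encoder_layer_params(vis_dim, d_ff)
                + linear_params(vis_dim, lang_dim) + layernorm_params(lang_dim))
    if vis_proj == "perceiver":
        return (num_queries * lang_dim                    # learned queries (pos_embed frozen)
                + linear_params(vis_dim, lang_dim)        # kv_proj
                + 4 * lang_dim * lang_dim + 4 * lang_dim  # MHA in_proj + out_proj
                + 2 * layernorm_params(lang_dim)          # ln_q, ln_kv
                + layernorm_params(lang_dim)              # projection LayerNorm
                + linear_params(lang_dim, lang_dim))      # projection Linear
    raise ValueError(f"unknown vis_proj: {vis_proj!r}")

for mode in ("baseline", "vit", "perceiver"):
    print(mode, projection_params(mode, 1024, 4096, d_ff=4096))
```

Under these assumptions the Perceiver head is dominated by its lang_dim x lang_dim matrices (attention and output projection), while the vit mode's cost is mostly the extra CLIP block at the vision width.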

Dependencies

  • torch.nn -- PyTorch neural network modules
  • torch.nn.functional -- Functional operations (interpolation)
  • transformers.models.clip.modeling_clip.CLIPEncoderLayer -- CLIP transformer layer
  • torch.nn.init.trunc_normal_ -- Truncated normal initialization
  • numpy -- Positional embedding computation
  • math -- Square root for grid sizes in get_abs_pos
