Implementation: Microsoft DeepSpeedExamples VisProjection
Metadata
| Field | Value |
|---|---|
| Page Type | Implementation (Pattern Doc) |
| Title | VisProjection |
| Repository | Microsoft/DeepSpeedExamples |
| Application | DeepSpeed-VisualChat |
| File | applications/DeepSpeed-VisualChat/utils/model/vis_proj.py |
| Lines | 15-153 |
| Language | Python |
| Status | Active |
Overview
Projection modules that map visual features into the language model's embedding space in DeepSpeed-VisualChat. Two variants are provided: a ViT-style projection (one CLIP encoder layer followed by a linear map) and a Perceiver-style cross-attention resampler.
Code Reference
VisProjection_vit (Lines 15-25)
class VisProjection_vit(nn.Module):
    def __init__(self, vis_config, lang_dim):
        super().__init__()
        self.vis_layer = CLIPEncoderLayer(vis_config)
        self.projection = nn.Sequential(
            nn.Linear(vis_config.hidden_size, lang_dim),
            nn.LayerNorm(lang_dim, eps=1e-12))

    def forward(self, vis_input):
        vis_feature = self.vis_layer(vis_input, None, None)[0]
        return self.projection(vis_feature)
This module chains a single CLIPEncoderLayer (self-attention + FFN) with a linear projection and layer normalization. The CLIPEncoderLayer provides one additional round of cross-patch feature refinement before the dimension mapping.
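The net effect of the projection head (the same Linear + LayerNorm combination the baseline uses) is a per-token change of feature dimension, leaving the token count untouched. A minimal numpy sketch with toy dimensions (hypothetical values; the real ones would be e.g. vis_dim=1024 for CLIP ViT-L and lang_dim=4096 for LLaMA-7B):

```python
import numpy as np

# Toy sizes (hypothetical, not the repository's actual dimensions)
batch, num_patches, vis_dim, lang_dim = 2, 9, 8, 16

rng = np.random.default_rng(0)
vis_feature = rng.standard_normal((batch, num_patches, vis_dim))

# Linear projection: vis_dim -> lang_dim
W = rng.standard_normal((vis_dim, lang_dim)) * 0.02
b = np.zeros(lang_dim)
x = vis_feature @ W + b

# LayerNorm over the last dimension (eps matches the module's 1e-12)
out = (x - x.mean(-1, keepdims=True)) / np.sqrt(x.var(-1, keepdims=True) + 1e-12)

print(out.shape)  # (2, 9, 16): token count preserved, feature dim mapped to lang_dim
```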
VisProjection_perceiver (Lines 97-153)
class VisProjection_perceiver(nn.Module):
    def __init__(self, vis_config, lang_dim):
        super().__init__()
        grid_size = 16
        self.num_queries = grid_size ** 2  # 256 learned queries
        self.embed_dim = lang_dim
        self.num_heads = lang_dim // 128  # e.g., 32 for 4096-dim
        self.pos_embed = nn.Parameter(
            torch.from_numpy(
                get_2d_sincos_pos_embed(lang_dim, grid_size)
            ).float()
        ).requires_grad_(False)  # fixed positional embeddings
        self.query = nn.Parameter(torch.zeros(self.num_queries, lang_dim))
        trunc_normal_(self.query, std=.02)  # truncated normal init
        self.kv_proj = nn.Linear(vis_config.hidden_size, lang_dim)
        self.attn = nn.MultiheadAttention(lang_dim, self.num_heads)
        self.ln_q = nn.LayerNorm(lang_dim, eps=1e-12)
        self.ln_kv = nn.LayerNorm(lang_dim, eps=1e-12)
        self.projection = nn.Sequential(
            nn.LayerNorm(lang_dim, eps=1e-12),
            nn.Linear(lang_dim, lang_dim))
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def forward(self, x, attn_mask=None):
        pos_embed = get_abs_pos(self.pos_embed, x.size(1))
        x = x[:, 1:, :]  # remove CLS token
        x = self.kv_proj(x)  # project to lang_dim
        x = self.ln_kv(x).permute(1, 0, 2)  # (seq, batch, dim) for attn
        N = x.shape[1]
        q = self.ln_q(self.query)
        out = self.attn(
            self._repeat(q, N) + self.pos_embed.unsqueeze(1),
            x + pos_embed.unsqueeze(1),
            x,
            attn_mask=attn_mask)[0]
        return self.projection(out.permute(1, 0, 2))  # back to (batch, seq, dim)

    def _repeat(self, query, N: int):
        return query.unsqueeze(1).repeat(1, N, 1)
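Conceptually, the resampler compresses a variable number of patch tokens into a fixed set of query tokens via cross-attention: the queries attend over the patches, so the output length is set by the queries, not the input. A toy single-head numpy sketch (hypothetical small dimensions, omitting the head split, layer norms, and output projection) illustrates why:

```python
import numpy as np

rng = np.random.default_rng(0)
num_queries, num_patches, dim = 4, 10, 8  # toy sizes; the real module uses 256 queries

q = rng.standard_normal((num_queries, dim))   # learned queries (cf. ln_q(self.query))
kv = rng.standard_normal((num_patches, dim))  # projected patches (cf. ln_kv(kv_proj(x)))

# Single-head cross-attention: softmax(q k^T / sqrt(d)) v, with keys = values = kv
scores = q @ kv.T / np.sqrt(dim)              # (num_queries, num_patches)
weights = np.exp(scores - scores.max(axis=-1, keepdims=True))
weights /= weights.sum(axis=-1, keepdims=True)
out = weights @ kv                            # (num_queries, dim)

print(out.shape)  # (4, 8): output length equals num_queries, regardless of num_patches
```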
Supporting Functions
Positional Embedding Interpolation (Lines 29-45)
def get_abs_pos(abs_pos, tgt_size):
    """Interpolate positional embeddings when resolution changes."""
    src_size = int(math.sqrt(abs_pos.size(0)))
    tgt_size = int(math.sqrt(tgt_size))
    dtype = abs_pos.dtype
    if src_size != tgt_size:
        return F.interpolate(
            abs_pos.float().reshape(1, src_size, src_size, -1).permute(0, 3, 1, 2),
            size=(tgt_size, tgt_size),
            mode="bicubic",
            align_corners=False,
        ).permute(0, 2, 3, 1).flatten(0, 2).to(dtype=dtype)
    else:
        return abs_pos
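The function assumes the token count is a perfect square and resizes the 2D grid only when the patch count changes. A simplified numpy stand-in (nearest-neighbor upsampling via np.repeat instead of the torch bicubic F.interpolate, and handling only integer upscaling) shows the same reshape-resize-flatten flow:

```python
import math
import numpy as np

def resize_pos_embed(abs_pos, tgt_tokens):
    """Nearest-neighbor stand-in for the bicubic interpolation in get_abs_pos."""
    src = int(math.sqrt(abs_pos.shape[0]))
    tgt = int(math.sqrt(tgt_tokens))
    if src == tgt:
        return abs_pos
    assert tgt % src == 0, "toy stand-in only handles integer upscaling"
    scale = tgt // src
    grid = abs_pos.reshape(src, src, -1)
    # Repeat each grid cell scale x scale times (nearest-neighbor upsampling)
    grid = np.repeat(np.repeat(grid, scale, axis=0), scale, axis=1)
    return grid.reshape(tgt * tgt, -1)

pos = np.arange(16 * 16 * 4, dtype=np.float32).reshape(256, 4)  # 16x16 grid, dim 4
print(resize_pos_embed(pos, 256).shape)   # unchanged: (256, 4)
print(resize_pos_embed(pos, 1024).shape)  # upsampled to a 32x32 grid: (1024, 4)
```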
2D Sine-Cosine Positional Embeddings (Lines 48-94)
def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False):
    """Generate 2D sincos positional embeddings.

    Returns: pos_embed of shape [grid_size*grid_size, embed_dim]
    """
    grid_h = np.arange(grid_size, dtype=np.float32)
    grid_w = np.arange(grid_size, dtype=np.float32)
    grid = np.meshgrid(grid_w, grid_h)
    grid = np.stack(grid, axis=0)
    grid = grid.reshape([2, 1, grid_size, grid_size])
    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    if cls_token:
        pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
    return pos_embed
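The helper get_2d_sincos_pos_embed_from_grid is not shown above. The sketch below reconstructs it following the standard MAE-style implementation this pattern derives from (a reconstruction for illustration, not necessarily byte-identical to the repository's version): half of embed_dim encodes the height coordinate and half the width, each via sin/cos at geometrically spaced frequencies.

```python
import numpy as np

def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    # Half the channels carry sin, half cos, at frequencies 1/10000^(2i/D)
    assert embed_dim % 2 == 0
    omega = np.arange(embed_dim // 2, dtype=np.float32) / (embed_dim / 2.)
    omega = 1. / 10000**omega                            # (D/2,)
    out = np.einsum('m,d->md', pos.reshape(-1), omega)   # (M, D/2)
    return np.concatenate([np.sin(out), np.cos(out)], axis=1)  # (M, D)

def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    # Half of embed_dim encodes the height coordinate, half the width coordinate
    emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0])
    emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1])
    return np.concatenate([emb_h, emb_w], axis=1)        # (H*W, D)

# Shape check with toy sizes (the module itself uses grid_size=16, embed_dim=lang_dim)
grid_size, embed_dim = 16, 64
grid_h = np.arange(grid_size, dtype=np.float32)
grid_w = np.arange(grid_size, dtype=np.float32)
grid = np.stack(np.meshgrid(grid_w, grid_h), axis=0).reshape(2, 1, grid_size, grid_size)
pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
print(pos_embed.shape)  # (256, 64)
```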
I/O Contract
VisProjection_vit
| Direction | Parameter | Type | Shape | Description |
|---|---|---|---|---|
| Init Input | vis_config | CLIPVisionConfig | -- | Vision encoder configuration (hidden_size, etc.) |
| Init Input | lang_dim | int | -- | Language model hidden dimension (e.g., 4096) |
| Forward Input | vis_input | torch.Tensor | [batch, num_patches, vis_dim] | Visual features from vision encoder |
| Forward Output | (return) | torch.Tensor | [batch, num_patches, lang_dim] | Projected features in language space |
VisProjection_perceiver
| Direction | Parameter | Type | Shape | Description |
|---|---|---|---|---|
| Init Input | vis_config | CLIPVisionConfig | -- | Vision encoder configuration |
| Init Input | lang_dim | int | -- | Language model hidden dimension |
| Forward Input | x | torch.Tensor | [batch, num_patches+1, vis_dim] | Visual features (including CLS token) |
| Forward Input | attn_mask | torch.Tensor or None | optional | Attention mask for cross-attention |
| Forward Output | (return) | torch.Tensor | [batch, 256, lang_dim] | Fixed 256 projected query tokens |
Usage Example
Selecting the Projection Type
The projection type is selected via the --vis_proj command-line argument and instantiated in DeepSpeedViLModel.build_projection():
def build_projection(self, vis_config, lang_dim):
    if self.args.vis_proj == 'vit':
        return VisProjection_vit(vis_config, lang_dim=lang_dim)
    elif self.args.vis_proj == 'baseline':
        return nn.Sequential(
            nn.Linear(vis_config.hidden_size, lang_dim),
            nn.LayerNorm(lang_dim, eps=1e-12))
    elif self.args.vis_proj == 'perceiver':
        return VisProjection_perceiver(vis_config, lang_dim=lang_dim)
Command-Line Selection
# Use ViT linear projection
deepspeed training/main.py --vis_proj vit ...
# Use Perceiver cross-attention projection
deepspeed training/main.py --vis_proj perceiver ...
# Use simple baseline projection (required for Qwen-VL)
deepspeed training/main.py --vis_proj baseline ...
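A flag like --vis_proj with a fixed set of choices is typically declared via argparse. The sketch below is an illustrative reconstruction, not the repository's actual option definition (the real one lives in the training script, and its default may differ):

```python
import argparse

parser = argparse.ArgumentParser()
# Hypothetical reconstruction of the --vis_proj flag
parser.add_argument('--vis_proj', type=str, default='baseline',
                    choices=['vit', 'baseline', 'perceiver'],
                    help='Vision-to-language projection type')

args = parser.parse_args(['--vis_proj', 'perceiver'])
print(args.vis_proj)  # perceiver
```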
Architecture Details
| Projection Type | Trainable Parameters | Output Tokens | Key Components |
|---|---|---|---|
| baseline | vis_dim * lang_dim + lang_dim (Linear + LN) | Same as input patches | Linear + LayerNorm |
| vit | CLIPEncoderLayer params + Linear + LN | Same as input patches | CLIPEncoderLayer + Linear + LayerNorm |
| perceiver | Queries + kv_proj + attn + LN + projection | Fixed 256 | Learned queries, MultiheadAttention, sincos pos embed |
Dependencies
- torch.nn -- PyTorch neural network modules
- torch.nn.functional -- Functional operations (interpolation)
- transformers.models.clip.modeling_clip.CLIPEncoderLayer -- CLIP transformer layer
- torch.nn.init.trunc_normal_ -- Truncated normal initialization
- numpy -- Positional embedding computation
Related Pages
- Principle:Microsoft_DeepSpeedExamples_Vision_Language_Projection -- The theoretical basis for vision-language projection
- Implementation:Microsoft_DeepSpeedExamples_Extract_Qwen_VL -- The vision encoder that produces features for projection
- Implementation:Microsoft_DeepSpeedExamples_Create_DSVL_Model -- The model that uses the projection layer
- Environment:Microsoft_DeepSpeedExamples_VisualChat_Training_Environment