Implementation:Microsoft DeepSpeedExamples Vision Transformer Model
| Knowledge Sources | |
|---|---|
| Domains | Computer Vision, Deep Learning, Transformer Architecture |
| Last Updated | 2026-02-07 12:00 GMT |
Overview
A PyTorch implementation of the Vision Transformer (ViT) model for image recognition, based on the paper "An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale", with support for multiple model variants and pretrained weights.
Description
This module implements the Vision Transformer (ViT) architecture described in the paper "An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale" (Dosovitskiy et al., 2020). It contains the core model and its building blocks: Attention, Block, ResPostBlock, ParallelBlock, LayerScale, and the main VisionTransformer class.
The VisionTransformer class implements the full ViT pipeline: patch embedding (splitting an image into fixed-size patches and projecting them), positional embedding, a sequence of transformer encoder blocks with multi-head self-attention and MLP layers, and a classification head. It supports multiple pooling strategies (class token or average pooling), configurable depth and width, stochastic depth (drop path), and various weight initialization schemes (JAX, MoCo).
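The patch-embedding step fixes the length of the token sequence that the encoder blocks process. The following is a minimal sketch of that arithmetic, assuming square images and square patches; `vit_sequence_length` is a hypothetical helper written for illustration and is not part of the module. Rejecting image sizes that are not a multiple of the patch size is a choice made by this sketch.

```python
def vit_sequence_length(img_size=224, patch_size=16, class_token=True):
    """Return (patches_per_side, num_patches, seq_len) for a square image.

    Hypothetical helper: mirrors the arithmetic of splitting an image into
    non-overlapping patches and optionally prepending one class token.
    """
    if img_size % patch_size != 0:
        raise ValueError("img_size must be divisible by patch_size in this sketch")
    patches_per_side = img_size // patch_size
    num_patches = patches_per_side ** 2
    seq_len = num_patches + (1 if class_token else 0)
    return patches_per_side, num_patches, seq_len


print(vit_sequence_length())         # (14, 196, 197)
print(vit_sequence_length(384, 16))  # (24, 576, 577)
print(vit_sequence_length(224, 32))  # (7, 49, 50)
```

Smaller patches or larger inputs lengthen the sequence quadratically, which is why the patch-8 and 384-pixel variants are considerably more expensive in self-attention.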
The module defines default configurations for a wide range of ViT variants, including tiny, small, base, large, and huge models with different patch sizes (8, 16, 32) and input resolutions (224, 384). The configurations reference pretrained weights from several sources: Google's official JAX checkpoints, ImageNet-21k pretraining, SAM (Sharpness-Aware Minimization) training, and DINO self-supervised pretraining. The implementation is built on top of the timm library, and each variant is registered with register_model so that it can be created by name through timm's create_model.
Usage
Use this module when you need a Vision Transformer model for image classification tasks, especially within the DeepSpeed data efficiency framework. The Block class is specifically imported by the ViT fine-tuning scripts to enable Random-LTD token dropping. Import individual classes for custom architectures or use the timm model registry to create standard configurations.
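The idea behind token dropping can be illustrated without any framework code. The sketch below is not the DeepSpeed Random-LTD implementation: it is a plain-Python illustration, assuming that a layer processes a random subset of tokens, that the first (class) token is always kept, and that skipped tokens pass through unchanged. The function `random_ltd_layer` and all of its parameters are hypothetical.

```python
import random


def random_ltd_layer(tokens, layer_fn, keep, rng, keep_first=True):
    """Apply layer_fn to a random subset of tokens; the rest bypass the layer.

    tokens:   list of per-token values (stand-ins for embedding vectors)
    layer_fn: maps a list of kept tokens to a list of the same length
    keep:     total number of tokens the layer should process
    """
    n = len(tokens)
    start = 1 if keep_first else 0
    n_random = keep - start
    if not 0 <= n_random <= n - start:
        raise ValueError("keep is out of range for this sequence")
    # Always keep index 0 if requested, then sample the remaining positions.
    kept = sorted(([0] if keep_first else []) + rng.sample(range(start, n), n_random))
    processed = layer_fn([tokens[i] for i in kept])
    # Scatter processed tokens back; dropped tokens keep their input values.
    out = list(tokens)
    for i, value in zip(kept, processed):
        out[i] = value
    return out, kept


rng = random.Random(0)
tokens = [float(i) for i in range(8)]
out, kept = random_ltd_layer(tokens, lambda xs: [x + 100.0 for x in xs], keep=4, rng=rng)
print(kept, out)
```

Because the layer only sees the kept tokens, its attention cost shrinks with the square of the kept length, which is the source of the training-time savings.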
Code Reference
Source Location
- Repository: Microsoft_DeepSpeedExamples
- File: training/data_efficiency/vit_finetuning/models/vit.py
- Lines: 1-1120
Signature
class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
        ...
class LayerScale(nn.Module):
    def __init__(self, dim, init_values=1e-5, inplace=False):
        ...
class Block(nn.Module):
    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0.,
                 attn_drop=0., init_values=None, drop_path=0.,
                 act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        ...
class ResPostBlock(nn.Module):
    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0.,
                 attn_drop=0., init_values=None, drop_path=0.,
                 act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        ...
class ParallelBlock(nn.Module):
    def __init__(self, dim, num_heads, num_parallel=2, mlp_ratio=4., qkv_bias=False,
                 init_values=None, drop=0., attn_drop=0., drop_path=0.,
                 act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        ...
class VisionTransformer(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000,
                 global_pool='token', embed_dim=768, depth=12, num_heads=12,
                 mlp_ratio=4., qkv_bias=True, init_values=None, class_token=True,
                 no_embed_class=False, fc_norm=None, drop_rate=0.,
                 attn_drop_rate=0., drop_path_rate=0., weight_init='',
                 embed_layer=PatchEmbed, norm_layer=None, act_layer=None,
                 block_fn=Block):
        ...
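To show how the constructor arguments above interact, the following sketch estimates the number of trainable parameters from the hyperparameters alone. It is an illustration under stated assumptions, not code from the module: a convolutional patch projection with bias, a learned positional embedding covering the class token (no_embed_class=False), standard pre-norm blocks with two LayerNorms each, no LayerScale (init_values=None), a final LayerNorm, and a single linear classification head. `vit_param_count` is a hypothetical helper.

```python
def vit_param_count(img_size=224, patch_size=16, in_chans=3, num_classes=1000,
                    embed_dim=768, depth=12, num_heads=12, mlp_ratio=4.0,
                    qkv_bias=True, class_token=True):
    """Estimate trainable parameters of a standard ViT (hypothetical helper)."""
    if embed_dim % num_heads != 0:
        raise ValueError("embed_dim must be divisible by num_heads")
    d = embed_dim
    hidden = int(d * mlp_ratio)
    num_patches = (img_size // patch_size) ** 2

    patch_embed = patch_size * patch_size * in_chans * d + d  # projection weight + bias
    cls_tok = d if class_token else 0
    pos_embed = (num_patches + (1 if class_token else 0)) * d

    # Attention: fused qkv projection (optional bias) plus output projection.
    attn = 3 * d * d + (3 * d if qkv_bias else 0) + d * d + d
    # MLP: two linear layers with biases.
    mlp = d * hidden + hidden + hidden * d + d
    # Two LayerNorms per block, each with a weight and a bias vector.
    norms = 2 * 2 * d
    block = attn + mlp + norms

    final_norm = 2 * d
    head = d * num_classes + num_classes
    return patch_embed + cls_tok + pos_embed + depth * block + final_norm + head


print(vit_param_count())  # 86567656 for the default (ViT-Base/16) arguments
```

Note that num_heads does not change the count: it only splits embed_dim into per-head slices, which is why embed_dim must be divisible by it.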
Import
from models.vit import VisionTransformer, Block, Attention
# Or create a registered configuration via the timm registry:
from timm import create_model
model = create_model('vit_base_patch16_224', pretrained=True)
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| img_size | int | No | Input image size (default: 224) |
| patch_size | int | No | Size of each image patch (default: 16) |
| in_chans | int | No | Number of input channels (default: 3) |
| num_classes | int | No | Number of classification classes (default: 1000) |
| embed_dim | int | No | Transformer embedding dimension (default: 768) |
| depth | int | No | Number of transformer blocks (default: 12) |
| num_heads | int | No | Number of attention heads (default: 12) |
| mlp_ratio | float | No | Ratio of MLP hidden dim to embedding dim (default: 4.0) |
| drop_rate | float | No | Dropout rate (default: 0.0) |
| drop_path_rate | float | No | Stochastic depth rate (default: 0.0) |
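
The drop_path_rate value is a single number, while stochastic depth is applied per block. timm-style ViT implementations commonly spread it linearly across the blocks, from 0 at the first block to drop_path_rate at the last; whether this file does exactly that is an assumption here. A minimal plain-Python sketch of such a schedule, with `drop_path_schedule` as a hypothetical helper:

```python
def drop_path_schedule(drop_path_rate=0.0, depth=12):
    """Per-block drop-path rates rising linearly from 0 to drop_path_rate.

    Hypothetical helper illustrating a linear stochastic-depth decay rule.
    """
    if depth < 1:
        raise ValueError("depth must be at least 1")
    if depth == 1:
        return [0.0]
    return [drop_path_rate * i / (depth - 1) for i in range(depth)]


rates = drop_path_schedule(drop_path_rate=0.1, depth=12)
print([round(r, 4) for r in rates])
```

Under this rule, early blocks are almost never skipped and the last block is skipped with probability drop_path_rate during training.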
Outputs
| Name | Type | Description |
|---|---|---|
| logits | torch.Tensor | Classification logits of shape (batch_size, num_classes) |
| features | torch.Tensor | Feature embeddings of shape (batch_size, embed_dim) when used as feature extractor |
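
The feature vector of shape (batch_size, embed_dim) results from pooling the token sequence, using either the class token or the average of the patch tokens. The sketch below shows both strategies for a single sample using nested lists instead of tensors. `pool_tokens` is a hypothetical helper, and the assumption that average pooling skips the prefix (class) token follows common timm behaviour rather than anything stated in this document.

```python
def pool_tokens(tokens, global_pool="token", num_prefix_tokens=1):
    """Reduce a (seq_len, embed_dim) token list to one embed_dim vector.

    Hypothetical helper: 'token' returns the first (class) token, 'avg'
    averages the tokens that follow the prefix tokens.
    """
    if global_pool == "token":
        return list(tokens[0])
    if global_pool == "avg":
        patches = tokens[num_prefix_tokens:]
        if not patches:
            raise ValueError("no patch tokens to average")
        return [sum(column) / len(patches) for column in zip(*patches)]
    raise ValueError(f"unsupported global_pool: {global_pool!r}")


# One class token followed by two patch tokens, embed_dim = 2.
sample = [[1.0, 1.0], [2.0, 4.0], [4.0, 8.0]]
print(pool_tokens(sample, "token"))  # [1.0, 1.0]
print(pool_tokens(sample, "avg"))    # [3.0, 6.0]
```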
Usage Examples
import torch
from models.vit import VisionTransformer
# Create a ViT-Base model (patch size 16, 224x224 input)
model = VisionTransformer(
    img_size=224,
    patch_size=16,
    embed_dim=768,
    depth=12,
    num_heads=12,
    num_classes=1000
)
# Forward pass on a batch of 4 random RGB images
images = torch.randn(4, 3, 224, 224)
logits = model(images)  # shape: (4, 1000)
# Create via timm registry (pretrained=True downloads the pretrained weights)
from timm import create_model
model = create_model('vit_base_patch16_224', pretrained=True, num_classes=1000)