Implementation:Microsoft DeepSpeedExamples Vision Transformer Model

From Leeroopedia
Revision as of 15:42, 16 February 2026 by Admin (talk | contribs) (Auto-imported from implementations/Microsoft_DeepSpeedExamples_Vision_Transformer_Model.md)


Knowledge Sources
Domains: Computer Vision, Deep Learning, Transformer Architecture
Last Updated: 2026-02-07 12:00 GMT

Overview

A PyTorch implementation of the Vision Transformer (ViT) for image recognition, based on the paper "An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale", with support for multiple model variants and pretrained weights.

Description

This module implements the Vision Transformer (ViT) architecture described in the paper "An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale" (Dosovitskiy et al., 2020). It contains the main VisionTransformer class together with its building blocks: Attention, LayerScale, Block, ResPostBlock, and ParallelBlock.

The VisionTransformer class implements the full ViT pipeline: patch embedding (splitting an image into fixed-size patches and projecting them), positional embedding, a sequence of transformer encoder blocks with multi-head self-attention and MLP layers, and a classification head. It supports multiple pooling strategies (class token or average pooling), configurable depth and width, stochastic depth (drop path), and various weight initialization schemes (JAX, MoCo).
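The first stages of the pipeline described above (patch embedding, class token, positional embedding) can be sketched as follows. This is a minimal illustration, not the module's own code: TinyPatchEmbed and embed_tokens are hypothetical names, and the zero-initialized class token and positional embedding stand in for learned parameters.

```python
import torch
import torch.nn as nn


class TinyPatchEmbed(nn.Module):
    """Hypothetical stand-in for a patch-embedding layer (not the module's PatchEmbed)."""

    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        self.num_patches = (img_size // patch_size) ** 2
        # A convolution whose kernel and stride both equal the patch size
        # projects each non-overlapping patch to one embedding vector.
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        x = self.proj(x)                     # (B, D, H/P, W/P)
        return x.flatten(2).transpose(1, 2)  # (B, N, D)


def embed_tokens(images, patch_embed, cls_token, pos_embed):
    """Prepend the class token to the patch tokens and add positional embeddings."""
    tokens = patch_embed(images)
    cls = cls_token.expand(tokens.shape[0], -1, -1)
    tokens = torch.cat([cls, tokens], dim=1)
    return tokens + pos_embed


patch_embed = TinyPatchEmbed()
cls_token = torch.zeros(1, 1, 768)                            # learned in a real model
pos_embed = torch.zeros(1, patch_embed.num_patches + 1, 768)  # learned in a real model
tokens = embed_tokens(torch.randn(2, 3, 224, 224), patch_embed, cls_token, pos_embed)
print(tokens.shape)  # torch.Size([2, 197, 768])
```

A 224x224 image with 16x16 patches yields a 14x14 grid, so 196 patch tokens plus one class token enter the encoder blocks.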

The module defines default configurations for a wide range of ViT variants, including tiny, small, base, large, and huge models with different patch sizes (8, 16, 32) and input resolutions (224, 384). Pretrained weights are available from several sources: Google's official JAX checkpoints, ImageNet-21k pretraining, SAM (Sharpness-Aware Minimization) training, and DINO self-supervised pretraining. The implementation builds on the timm library, and each variant is registered through register_model so that it can be created by name from the timm model registry.
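To make the effect of patch size and resolution concrete, the sketch below computes the encoder sequence length for a few combinations. The size table lists the commonly published ViT widths and depths; it is assumed here for illustration and is not read from this module's default configurations.

```python
# Commonly published ViT sizes; assumed for illustration rather than
# taken from this module's default configurations.
VIT_SIZES = {
    "tiny": {"embed_dim": 192, "depth": 12, "num_heads": 3},
    "small": {"embed_dim": 384, "depth": 12, "num_heads": 6},
    "base": {"embed_dim": 768, "depth": 12, "num_heads": 12},
    "large": {"embed_dim": 1024, "depth": 24, "num_heads": 16},
    "huge": {"embed_dim": 1280, "depth": 32, "num_heads": 16},
}


def sequence_length(img_size, patch_size, class_token=True):
    """Number of tokens the encoder sees for a square input image."""
    if img_size % patch_size != 0:
        raise ValueError("img_size must be divisible by patch_size")
    grid = img_size // patch_size
    return grid * grid + (1 if class_token else 0)


for img_size, patch_size in [(224, 8), (224, 16), (224, 32), (384, 16)]:
    print(f"patch{patch_size}_{img_size}: {sequence_length(img_size, patch_size)} tokens")

for name, cfg in VIT_SIZES.items():
    print(f"{name}: head_dim = {cfg['embed_dim'] // cfg['num_heads']}")
```

Halving the patch size roughly quadruples the token count (785 tokens for patch 8 versus 197 for patch 16 at 224 pixels), which matters because self-attention cost grows quadratically with sequence length.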

Usage

Use this module when you need a Vision Transformer model for image classification tasks, especially within the DeepSpeed data efficiency framework. The Block class is specifically imported by the ViT fine-tuning scripts to enable Random-LTD token dropping. Import individual classes for custom architectures or use the timm model registry to create standard configurations.
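The idea behind layerwise token dropping can be sketched as running a layer on a random subset of tokens while the remaining tokens bypass it. The code below is a rough illustration under stated assumptions, not the DeepSpeed Random-LTD API: random_token_subset_forward is a hypothetical helper, a small nn.Sequential stands in for Block, and details such as always keeping the class token are omitted.

```python
import torch
import torch.nn as nn


def random_token_subset_forward(layer, tokens, keep):
    """Apply `layer` to `keep` randomly chosen tokens per sample; others bypass it."""
    batch, seq_len, dim = tokens.shape
    # Rank random scores to draw `keep` distinct positions per sample, then
    # sort them so the kept tokens stay in their original order.
    scores = torch.rand(batch, seq_len, device=tokens.device)
    kept = scores.argsort(dim=1)[:, :keep].sort(dim=1).values  # (B, keep)
    index = kept.unsqueeze(-1).expand(-1, -1, dim)              # (B, keep, D)
    processed = layer(tokens.gather(1, index))                  # (B, keep, D)
    # Write processed tokens back to their positions; the rest are unchanged.
    return tokens.scatter(1, index, processed), kept


torch.manual_seed(0)
layer = nn.Sequential(nn.LayerNorm(32), nn.Linear(32, 32))  # hypothetical stand-in for Block
tokens = torch.randn(2, 10, 32)
out, kept = random_token_subset_forward(layer, tokens, keep=4)
print(out.shape, kept.shape)  # torch.Size([2, 10, 32]) torch.Size([2, 4])
```

Because the layer only sees `keep` tokens, its compute cost shrinks accordingly, while the output keeps the full sequence length for the next layer.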

Code Reference

Source Location

Signature

class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
        ...

class LayerScale(nn.Module):
    def __init__(self, dim, init_values=1e-5, inplace=False):
        ...

class Block(nn.Module):
    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0.,
                 attn_drop=0., init_values=None, drop_path=0.,
                 act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        ...

class ResPostBlock(nn.Module):
    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0.,
                 attn_drop=0., init_values=None, drop_path=0.,
                 act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        ...

class ParallelBlock(nn.Module):
    def __init__(self, dim, num_heads, num_parallel=2, mlp_ratio=4., qkv_bias=False,
                 init_values=None, drop=0., attn_drop=0., drop_path=0.,
                 act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        ...

class VisionTransformer(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000,
                 global_pool='token', embed_dim=768, depth=12, num_heads=12,
                 mlp_ratio=4., qkv_bias=True, init_values=None, class_token=True,
                 no_embed_class=False, fc_norm=None, drop_rate=0.,
                 attn_drop_rate=0., drop_path_rate=0., weight_init='',
                 embed_layer=PatchEmbed, norm_layer=None, act_layer=None,
                 block_fn=Block):
        ...
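As a worked example of how the constructor arguments determine model size, the sketch below estimates the parameter count from the default values in the VisionTransformer signature. It assumes a standard layout that the signatures alone do not confirm: a convolutional patch projection with bias, a learned positional embedding that includes the class-token position, biased qkv and MLP layers, no LayerScale (init_values=None), a final LayerNorm, and a linear head. The helper names are hypothetical.

```python
def linear_params(n_in, n_out, bias=True):
    """Parameters of a dense layer: weight matrix plus optional bias."""
    return n_in * n_out + (n_out if bias else 0)


def vit_param_count(img_size=224, patch_size=16, in_chans=3, num_classes=1000,
                    embed_dim=768, depth=12, mlp_ratio=4.0):
    """Estimate ViT parameters under the assumptions stated in the lead-in."""
    num_patches = (img_size // patch_size) ** 2
    hidden = int(embed_dim * mlp_ratio)
    layer_norm = 2 * embed_dim  # scale and shift

    patch_embed = linear_params(patch_size * patch_size * in_chans, embed_dim)
    cls_token = embed_dim
    pos_embed = (num_patches + 1) * embed_dim
    block = (
        layer_norm
        + linear_params(embed_dim, 3 * embed_dim)  # fused q, k, v projection
        + linear_params(embed_dim, embed_dim)      # attention output projection
        + layer_norm
        + linear_params(embed_dim, hidden)         # MLP expansion
        + linear_params(hidden, embed_dim)         # MLP contraction
    )
    head = linear_params(embed_dim, num_classes)
    return patch_embed + cls_token + pos_embed + depth * block + layer_norm + head


print(f"{vit_param_count():,}")  # 86,567,656
```

Note that num_heads does not appear: it changes how the embedding dimension is split across heads, not the number of weights.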

Import

from models.vit import VisionTransformer, Block, Attention
from timm import create_model

# Or create via timm registry:
model = create_model('vit_base_patch16_224', pretrained=True)

I/O Contract

Inputs

| Name | Type | Required | Description |
|---|---|---|---|
| img_size | int | No | Input image size (default: 224) |
| patch_size | int | No | Size of each image patch (default: 16) |
| in_chans | int | No | Number of input channels (default: 3) |
| num_classes | int | No | Number of classification classes (default: 1000) |
| embed_dim | int | No | Transformer embedding dimension (default: 768) |
| depth | int | No | Number of transformer blocks (default: 12) |
| num_heads | int | No | Number of attention heads (default: 12) |
| mlp_ratio | float | No | Ratio of MLP hidden dim to embedding dim (default: 4.0) |
| drop_rate | float | No | Dropout rate (default: 0.0) |
| drop_path_rate | float | No | Stochastic depth rate (default: 0.0) |
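The stochastic depth rate is usually not applied uniformly. A common convention in ViT implementations is to increase the per-block rate linearly from 0 at the first block to drop_path_rate at the last. Whether this module follows that exact schedule is an assumption; drop_path_schedule below is a hypothetical helper that illustrates the convention.

```python
def drop_path_schedule(drop_path_rate, depth):
    """Per-block drop-path rates rising linearly from 0 to drop_path_rate."""
    if depth < 1:
        raise ValueError("depth must be at least 1")
    if depth == 1:
        return [0.0]
    return [drop_path_rate * i / (depth - 1) for i in range(depth)]


rates = drop_path_schedule(drop_path_rate=0.1, depth=12)
for block_index, rate in enumerate(rates):
    # Each block keeps its residual branch with probability 1 - rate during training.
    print(f"block {block_index:2d}: drop {rate:.3f}, keep {1 - rate:.3f}")
```

Under this schedule early blocks are almost never skipped, while the deepest blocks are regularized most strongly.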

Outputs

| Name | Type | Description |
|---|---|---|
| logits | torch.Tensor | Classification logits of shape (batch_size, num_classes) |
| features | torch.Tensor | Feature embeddings of shape (batch_size, embed_dim) when used as feature extractor |
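The relationship between the two outputs can be sketched as a pooling step followed by a linear head. The snippet below is an illustration under assumptions: pool_tokens is a hypothetical helper modeled on the 'token' and 'avg' options of global_pool, and a random tensor stands in for the encoder output.

```python
import torch
import torch.nn as nn


def pool_tokens(tokens, global_pool="token", num_prefix_tokens=1):
    """Reduce encoder output (B, N, D) to one feature vector per image (B, D)."""
    if global_pool == "token":
        return tokens[:, 0]                               # class token only
    if global_pool == "avg":
        return tokens[:, num_prefix_tokens:].mean(dim=1)  # mean over patch tokens
    raise ValueError(f"unsupported global_pool: {global_pool!r}")


tokens = torch.randn(4, 197, 768)          # stand-in for the encoder output
features = pool_tokens(tokens, "token")    # (4, 768), the feature embeddings
head = nn.Linear(768, 1000)                # classification head
logits = head(features)                    # (4, 1000), the classification logits
print(features.shape, logits.shape)  # torch.Size([4, 768]) torch.Size([4, 1000])
```

With average pooling the class token is excluded, so the feature vector summarizes only the patch tokens.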

Usage Examples

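import torch
from timm import create_model
from models.vit import VisionTransformer

# Create a ViT-Base/16 model with randomly initialized weights
model = VisionTransformer(
    img_size=224,
    patch_size=16,
    embed_dim=768,
    depth=12,
    num_heads=12,
    num_classes=1000
)
model.eval()  # disable dropout and stochastic depth for inference

# Forward pass on a random batch of four 224x224 RGB images
images = torch.randn(4, 3, 224, 224)
with torch.no_grad():
    logits = model(images)  # shape: (4, 1000)

# Alternatively, create a registered variant by name via the timm registry
# (pretrained=True loads pretrained weights)
model = create_model('vit_base_patch16_224', pretrained=True, num_classes=1000)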
