
Implementation:Microsoft DeepSpeedExamples ViT ImageNet Finetuning

From Leeroopedia


Knowledge Sources
Domains Computer Vision, Data Efficiency, Fine-Tuning
Last Updated 2026-02-07 12:00 GMT

Overview

A training script for fine-tuning Vision Transformer (ViT) and other image classification models on ImageNet with DeepSpeed data efficiency features including Random-LTD.

Description

This script provides a full ImageNet training pipeline for Vision Transformer models with DeepSpeed integration. It supports multiple model architectures through a registry-based system, where models are loaded from a local models module that includes custom ViT implementations. The script handles distributed training setup, data loading with ImageNet-specific augmentations from the timm library, and the complete train/validate/checkpoint cycle.
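The registry-based model lookup described above can be illustrated in plain Python. This is a sketch of the general pattern only, assuming a decorator-populated dictionary keyed by the `--arch` string. The `MODEL_REGISTRY`, `register_model` and `get_model` names are hypothetical. The real `models` module and `_get_model(args)` build actual `nn.Module` instances and may be organized differently; the constructors here return plain dicts so the sketch runs without PyTorch.

```python
from typing import Callable, Dict

# Hypothetical registry mapping an --arch name to a constructor.
# The constructors are placeholders that return config-like dicts.
MODEL_REGISTRY: Dict[str, Callable[..., dict]] = {}


def register_model(name: str):
    """Decorator that records a constructor under `name`."""
    def decorator(fn: Callable[..., dict]) -> Callable[..., dict]:
        if name in MODEL_REGISTRY:
            raise ValueError(f"duplicate architecture name: {name}")
        MODEL_REGISTRY[name] = fn
        return fn
    return decorator


@register_model("vit_base_patch16_224")
def vit_base_patch16_224(img_size: int = 224, num_classes: int = 1000) -> dict:
    # Stand-in for a real ViT constructor.
    return {"arch": "vit_base_patch16_224", "img_size": img_size,
            "patch_size": 16, "num_classes": num_classes}


@register_model("resnet18")
def resnet18(num_classes: int = 1000, **_ignored) -> dict:
    # Accepts and ignores extra options such as img_size.
    return {"arch": "resnet18", "num_classes": num_classes}


def get_model(arch: str, **kwargs) -> dict:
    """Resolve `arch` through the registry (the role played by _get_model)."""
    try:
        constructor = MODEL_REGISTRY[arch]
    except KeyError:
        raise ValueError(f"unknown architecture {arch!r}; "
                         f"choose from {sorted(MODEL_REGISTRY)}") from None
    return constructor(**kwargs)
```

A registry keeps the training loop independent of the concrete architecture: adding a model only requires registering a new constructor name.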

The key differentiator of this script is its integration with DeepSpeed's Random-LTD (random and layerwise token dropping) data efficiency technique via the convert_to_random_ltd helper. When the --random_ltd flag is set, the ViT model is modified so that intermediate transformer layers process only a random subset of tokens, reducing compute per training step while aiming to preserve accuracy. The script imports the Block class from its custom ViT implementation so that these transformer blocks can be identified and wrapped for token dropping.
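The token-dropping idea can be sketched without any framework. In this simplified illustration, each intermediate layer draws its own random subset of `keep` token positions, runs only those tokens through the layer, and copies the remaining tokens through unchanged; the first and last layers see the full sequence. The function names, the fixed `keep` value and the list-of-lists token representation are assumptions for the sketch. DeepSpeed's implementation works on tensors and varies the kept length according to a schedule set in its configuration.

```python
import random
from typing import Callable, List, Sequence

Token = List[float]
Layer = Callable[[List[Token]], List[Token]]


def random_ltd_layer(tokens: Sequence[Token], layer_fn: Layer,
                     keep: int, rng: random.Random) -> List[Token]:
    """Apply `layer_fn` to `keep` randomly chosen tokens; others skip it."""
    seq_len = len(tokens)
    if not 0 < keep <= seq_len:
        raise ValueError("keep must be in (0, seq_len]")
    # Sample positions without replacement; sorting preserves token order.
    kept_idx = sorted(rng.sample(range(seq_len), keep))
    processed = layer_fn([tokens[i] for i in kept_idx])
    out = [list(t) for t in tokens]  # dropped tokens pass through as copies
    for i, new_tok in zip(kept_idx, processed):
        out[i] = new_tok  # scatter processed tokens back to their positions
    return out


def forward_with_random_ltd(tokens: Sequence[Token], layers: List[Layer],
                            keep: int, rng: random.Random) -> List[Token]:
    """Full sequence in the first and last layers, random subsets in between."""
    hidden = [list(t) for t in tokens]
    for depth, layer_fn in enumerate(layers):
        if 0 < depth < len(layers) - 1:
            # "Layerwise": every intermediate layer samples independently.
            hidden = random_ltd_layer(hidden, layer_fn, keep, rng)
        else:
            hidden = layer_fn(hidden)
    return hidden


def add_one(toks: List[Token]) -> List[Token]:
    """Toy 'layer' that adds 1.0 to every feature, for demonstration."""
    return [[v + 1.0 for v in t] for t in toks]
```

With 8 tokens, 4 toy layers and `keep=3`, every token is updated by the first and last layers, while only 3 tokens per intermediate layer receive the extra updates.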

Training features include configurable learning rate scheduling (cosine or step) with optional warmup, distributed data-parallel training via DeepSpeed, and tracking of loss, top-1 accuracy, and top-5 accuracy. Per-epoch results are stored in a history dictionary and checkpointed for later analysis.
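The two scheduler options with optional warmup can be sketched as a single function of the epoch. The exact formulas used by main_imagenet.py are not shown on this page, so the following are assumptions: linear warmup from 0 to the base rate, cosine annealing to 0 over the remaining epochs, and a step schedule that multiplies the rate by `gamma` every `step_size` epochs (the 0.1x every 30 epochs convention of the PyTorch ImageNet example). The function name and its parameters are hypothetical.

```python
import math


def lr_at_epoch(epoch: float, base_lr: float, total_epochs: int,
                scheduler: str = "cosine", warmup_epochs: int = 0,
                step_size: int = 30, gamma: float = 0.1) -> float:
    """Learning rate for a (possibly fractional) epoch, under stated assumptions."""
    # Linear warmup: ramp from 0 up to base_lr.
    if warmup_epochs > 0 and epoch < warmup_epochs:
        return base_lr * epoch / warmup_epochs
    if scheduler == "cosine":
        # Half-cosine decay from base_lr to 0 after warmup ends.
        span = max(total_epochs - warmup_epochs, 1)
        progress = min((epoch - warmup_epochs) / span, 1.0)
        return 0.5 * base_lr * (1.0 + math.cos(math.pi * progress))
    if scheduler == "step":
        # Multiply by gamma once per completed block of step_size epochs.
        return base_lr * gamma ** int(epoch // step_size)
    raise ValueError(f"unknown scheduler: {scheduler!r}")
```

Passing a fractional epoch (for example `epoch + batch_idx / num_batches`) lets the same function be evaluated per iteration rather than per epoch.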

Usage

Use this script when training or fine-tuning ViT models on ImageNet-scale datasets with DeepSpeed, particularly when leveraging data efficiency techniques like Random-LTD. Launch via deepspeed main_imagenet.py --arch vit_base_patch16_224 --data /path/to/imagenet --deepspeed_config ds_config.json.

Code Reference

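Source Location

Repository: microsoft/DeepSpeedExamples
File: training/data_efficiency/vit_finetuning/main_imagenet.py
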
Signature

def _get_model(args):
    """Create a model from the architecture registry."""
    ...

def _get_dist_model(gpu, args):
    """Create a distributed model with DeepSpeed or DataParallel."""
    ...

def main():
    """Entry point: parse args, setup distributed, launch main_worker."""
    ...

def main_worker(gpu, ngpus_per_node, args):
    """Main training worker: setup model, data, optimizer, run train/validate loop."""
    ...

def train(scheduler, train_loader, model, criterion, optimizer, epoch, args):
    """Run one training epoch."""
    ...

def validate(val_loader, model, criterion, args):
    """Evaluate model on validation set."""
    ...
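The train/validate/checkpoint cycle driven by main_worker can be sketched as an epoch loop. In this simplified version, model construction, data loading and the DeepSpeed engine are replaced by caller-supplied callables; `run_training`, `train_one_epoch`, `evaluate` and `save` are hypothetical names, not functions from the script. The `best_acc1` / `is_best` bookkeeping and saving `epoch + 1` follow the PyTorch ImageNet example and are assumed, not confirmed, to match main_imagenet.py.

```python
from typing import Callable, Dict, List, Tuple

Metrics = Tuple[float, float, float]  # (loss, top1, top5)


def run_training(epochs: int,
                 train_one_epoch: Callable[[int], Metrics],
                 evaluate: Callable[[int], Metrics],
                 save: Callable[[dict, bool], None]) -> Dict[str, List[float]]:
    """Train, validate and checkpoint once per epoch; return the history."""
    history: Dict[str, List[float]] = {
        f"{split}_{name}": []
        for split in ("train", "val") for name in ("loss", "top1", "top5")}
    best_acc1 = 0.0
    for epoch in range(epochs):
        train_metrics = train_one_epoch(epoch)
        val_metrics = evaluate(epoch)
        for split, metrics in (("train", train_metrics), ("val", val_metrics)):
            for name, value in zip(("loss", "top1", "top5"), metrics):
                history[f"{split}_{name}"].append(value)
        # The best checkpoint is chosen by top-1 validation accuracy.
        acc1 = val_metrics[1]
        is_best = acc1 > best_acc1
        best_acc1 = max(acc1, best_acc1)
        state = {"epoch": epoch + 1, "best_acc1": best_acc1,
                 "history": {k: list(v) for k, v in history.items()}}
        save(state, is_best)  # e.g. write checkpoint, copy to best if is_best
    return history
```

Copying the history into each saved state avoids later epochs mutating a checkpoint that has already been handed to `save`.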

Import

# This is a standalone training script, not typically imported.
# Run via: deepspeed main_imagenet.py [args]
import deepspeed
from deepspeed.runtime.data_pipeline.data_routing.helper import convert_to_random_ltd, save_without_random_ltd
from models.vit import Block
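As described above, convert_to_random_ltd is given the model together with a layer class (presumably called as `convert_to_random_ltd(model, Block)`), and the matching blocks are wrapped for token dropping. The sketch below reproduces only that traversal-and-replace step using a tiny stand-in for `torch.nn.Module`. The `Module`, `Block`, `TokenDropWrapper`, `convert_layers` and `build_toy_vit` names are hypothetical, and DeepSpeed's helper may perform additional setup that is not shown here.

```python
class Module:
    """Tiny stand-in for torch.nn.Module: children live in self._modules."""

    def __init__(self):
        self._modules = {}

    def add_module(self, name, child):
        self._modules[name] = child
        return child

    def named_modules(self, prefix=""):
        # Depth-first walk yielding (dotted_name, module) pairs.
        yield prefix, self
        for name, child in self._modules.items():
            yield from child.named_modules(f"{prefix}.{name}" if prefix else name)


class Block(Module):
    """Hypothetical stand-in for models.vit.Block."""


class TokenDropWrapper(Module):
    """Hypothetical stand-in for DeepSpeed's token-dropping wrapper layer."""

    def __init__(self, layer):
        super().__init__()
        self.add_module("layer", layer)


def set_submodule(root, dotted_name, new_child):
    """Replace the submodule at `dotted_name` (e.g. 'blocks.3')."""
    *parents, leaf = dotted_name.split(".")
    node = root
    for part in parents:
        node = node._modules[part]
    node._modules[leaf] = new_child


def convert_layers(model, layer_type, wrapper_cls):
    """Wrap every submodule of `layer_type`; return the wrapped names."""
    # Collect matches first so the tree is not mutated while being walked.
    targets = [(name, m) for name, m in model.named_modules()
               if isinstance(m, layer_type)]
    for name, module in targets:
        set_submodule(model, name, wrapper_cls(module))
    return [name for name, _ in targets]


def build_toy_vit(depth=12):
    """Toy ViT layout: patch embedding, a stack of Blocks, a head."""
    vit = Module()
    vit.add_module("patch_embed", Module())
    blocks = vit.add_module("blocks", Module())
    for i in range(depth):
        blocks.add_module(str(i), Block())
    vit.add_module("head", Module())
    return vit
```

Matching by class is why the script needs to import Block: the helper has to know which submodules are transformer layers, while the patch embedding and classification head are left untouched.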

I/O Contract

Inputs

Name                Type   Required  Description
data                str    No        Path to ImageNet dataset directory (default: 'imagenet')
--arch              str    No        Model architecture name (default: 'resnet18')
--epochs            int    No        Number of total epochs to run (default: 90)
--batch-size        int    No        Mini-batch size across all GPUs (default: 256)
--lr                float  No        Initial learning rate (default: 0.1)
--img_size          int    No        Input image size (default: 224)
--scheduler         str    No        Learning rate scheduler type: 'cosine' or 'step' (default: 'cosine')
--random_ltd        flag   No        Enable Random-LTD token dropping for data efficiency
--deepspeed_config  str    No        Path to DeepSpeed JSON configuration file
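The arguments in the table above can be mirrored with argparse. This is a sketch, not the script's actual parser. It assumes that `data` is positional, because the table lists it without a `--` prefix (as in the PyTorch ImageNet example). If the script instead defines `--data`, as the launch commands on this page suggest, that declaration would need to change. DeepSpeed scripts commonly register `--deepspeed_config` through `deepspeed.add_config_arguments(parser)`; here it is declared directly so that the sketch runs without DeepSpeed. The `--local_rank` option is included because the `deepspeed` launcher passes it to each process by default.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    """Parser mirroring the Inputs table (sketch; `data` assumed positional)."""
    p = argparse.ArgumentParser(description="ViT ImageNet fine-tuning (sketch)")
    p.add_argument("data", nargs="?", default="imagenet",
                   help="path to the ImageNet dataset directory")
    p.add_argument("--arch", default="resnet18", help="model architecture name")
    p.add_argument("--epochs", type=int, default=90)
    p.add_argument("--batch-size", type=int, default=256,
                   help="mini-batch size across all GPUs")  # dest: batch_size
    p.add_argument("--lr", type=float, default=0.1)
    p.add_argument("--img_size", type=int, default=224)
    p.add_argument("--scheduler", choices=["cosine", "step"], default="cosine")
    p.add_argument("--random_ltd", action="store_true",
                   help="enable Random-LTD token dropping")
    p.add_argument("--deepspeed_config", type=str, default=None,
                   help="path to the DeepSpeed JSON config")
    # The deepspeed launcher appends --local_rank=<n> to each worker's args.
    p.add_argument("--local_rank", type=int, default=-1)
    return p
```

Note that argparse converts `--batch-size` to the attribute `batch_size`.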

Outputs

Name                Type       Description
Trained model       nn.Module  Fine-tuned image classification model
checkpoint.pth.tar  file       Training checkpoint with model state, optimizer state, and epoch
model_best.pth.tar  file       Best model checkpoint based on top-1 validation accuracy
Training metrics    dict       History of train/val loss, top-1 accuracy, and top-5 accuracy per epoch
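The top-1 and top-5 figures in the training metrics are standard top-k accuracies: a sample counts as correct if its true class is among the k highest-scoring classes. The framework-free sketch below reports percentages, as the `accuracy()` helper in the PyTorch ImageNet example does. The script itself presumably computes these on torch tensors, and the `topk_accuracy` name is hypothetical.

```python
from typing import List, Sequence


def topk_accuracy(logits: Sequence[Sequence[float]], targets: Sequence[int],
                  ks=(1, 5)) -> List[float]:
    """Percentage of samples whose target is among the k highest logits."""
    if not logits or len(logits) != len(targets):
        raise ValueError("need a non-empty batch with one target per row")
    results = []
    for k in ks:
        hits = 0
        for row, target in zip(logits, targets):
            # Class indices of the k largest scores (ties go to the lower index).
            top = sorted(range(len(row)), key=lambda j: (-row[j], j))[:k]
            hits += target in top
        results.append(100.0 * hits / len(targets))
    return results
```

For example, with two samples where the first is predicted correctly and the second has its true class ranked fifth, top-1 is 50.0 and top-5 is 100.0.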

Usage Examples

# Fine-tune ViT-Base on ImageNet with DeepSpeed
# deepspeed training/data_efficiency/vit_finetuning/main_imagenet.py \
#     --arch vit_base_patch16_224 \
#     --data /path/to/imagenet \
#     --epochs 90 \
#     --batch-size 256 \
#     --lr 0.1 \
#     --scheduler cosine \
#     --deepspeed_config ds_config.json

# Fine-tune with Random-LTD for data efficiency
# deepspeed main_imagenet.py \
#     --arch vit_base_patch16_224 \
#     --data /path/to/imagenet \
#     --random_ltd \
#     --deepspeed_config ds_config_random_ltd.json
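This page does not show the contents of ds_config.json. The following builds a minimal config using DeepSpeed's standard batch-size keys, under the assumption that nothing else is required for a basic run. The values are illustrative, and the `make_ds_config` name is hypothetical. The Random-LTD run (ds_config_random_ltd.json) would also need DeepSpeed's data-efficiency section, which is not reproduced here.

```python
import json


def make_ds_config(micro_batch_per_gpu: int, grad_accum_steps: int,
                   world_size: int) -> dict:
    """Return a minimal DeepSpeed config dict (illustrative values only)."""
    if min(micro_batch_per_gpu, grad_accum_steps, world_size) < 1:
        raise ValueError("all sizes must be positive")
    return {
        # DeepSpeed expects the global batch size to equal
        # micro batch per GPU * gradient accumulation steps * number of GPUs.
        "train_batch_size": micro_batch_per_gpu * grad_accum_steps * world_size,
        "train_micro_batch_size_per_gpu": micro_batch_per_gpu,
        "gradient_accumulation_steps": grad_accum_steps,
        "steps_per_print": 100,
    }


if __name__ == "__main__":
    # 8 GPUs * 32 images each gives a global batch of 256.
    print(json.dumps(make_ds_config(32, 1, 8), indent=2))
```

Redirecting the printed JSON to a file produces a config that can be passed via --deepspeed_config.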
