Implementation:Microsoft DeepSpeedExamples ViT ImageNet Finetuning
| Knowledge Sources | |
|---|---|
| Domains | Computer Vision, Data Efficiency, Fine-Tuning |
| Last Updated | 2026-02-07 12:00 GMT |
Overview
A training script for fine-tuning Vision Transformer (ViT) and other image classification models on ImageNet with DeepSpeed data efficiency features including Random-LTD.
Description
This script provides a full ImageNet training pipeline for Vision Transformer models with DeepSpeed integration. The architecture is selected by name with --arch from a registry in the local models package, which includes custom ViT implementations. The script handles distributed training setup, data loading with ImageNet-specific augmentations from the timm library, and the complete train/validate/checkpoint cycle.
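A name-to-constructor registry of the kind described above can be sketched as follows. This is a hypothetical illustration: `MODEL_REGISTRY`, `register_model`, `get_model`, and the toy constructors (which return plain dicts instead of `nn.Module`s) are assumptions, not the contents of the script's `models` package.

```python
from argparse import Namespace
from typing import Callable, Dict

# Hypothetical registry: architecture name -> constructor.
MODEL_REGISTRY: Dict[str, Callable[..., dict]] = {}


def register_model(fn: Callable[..., dict]) -> Callable[..., dict]:
    """Record a constructor under its function name and return it unchanged."""
    MODEL_REGISTRY[fn.__name__] = fn
    return fn


@register_model
def vit_base_patch16_224(img_size: int = 224) -> dict:
    # Stand-in for building an nn.Module: returns a description of the model.
    return {"arch": "vit_base_patch16_224", "patch_size": 16, "img_size": img_size}


@register_model
def resnet18(img_size: int = 224) -> dict:
    return {"arch": "resnet18", "img_size": img_size}


def get_model(args: Namespace) -> dict:
    """Build the model named by args.arch (a sketch in the spirit of _get_model)."""
    if args.arch not in MODEL_REGISTRY:
        raise ValueError(f"unknown --arch {args.arch!r}; available: {sorted(MODEL_REGISTRY)}")
    return MODEL_REGISTRY[args.arch](img_size=args.img_size)


if __name__ == "__main__":
    print(get_model(Namespace(arch="vit_base_patch16_224", img_size=224)))
```

With this pattern, defining and registering a new constructor is enough to make it selectable through --arch.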
The key differentiator of this script is its integration with DeepSpeed's Random-LTD (random layerwise token dropping) data efficiency technique via the convert_to_random_ltd helper. When enabled with the --random_ltd flag, the ViT model is modified so that intermediate transformer layers process only a random subset of the input tokens, reducing computational cost while aiming to preserve accuracy. The script imports the Block class from the custom ViT implementation and passes it to convert_to_random_ltd, which uses it to identify the transformer layers to wrap with token dropping.
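The per-layer idea behind random layerwise token dropping can be sketched in plain Python: sample a random subset of token positions, run the layer only on those tokens (gather), and write the outputs back to their original positions (scatter), so dropped tokens pass through the layer unchanged. This is a conceptual sketch under simplifying assumptions: tokens are plain lists, `toy_layer` is a hypothetical stand-in for a transformer `Block`, and the number of kept tokens is fixed. DeepSpeed's implementation works on tensors and varies the kept sequence length over training according to a schedule.

```python
import random
from typing import Callable, List, Sequence

Token = List[float]


def random_ltd_layer(
    tokens: Sequence[Token],
    layer: Callable[[List[Token]], List[Token]],
    keep: int,
    rng: random.Random,
) -> List[Token]:
    """Apply `layer` to `keep` randomly chosen tokens; the rest bypass the layer."""
    n = len(tokens)
    if keep >= n:
        return layer(list(tokens))  # nothing to drop
    # Sort the sampled positions so the layer sees kept tokens in their original order.
    kept_idx = sorted(rng.sample(range(n), keep))
    processed = layer([tokens[i] for i in kept_idx])  # gather, then run the layer
    out = [list(t) for t in tokens]  # dropped tokens keep their input values
    for i, new_tok in zip(kept_idx, processed):  # scatter results back
        out[i] = new_tok
    return out


def toy_layer(xs: List[Token]) -> List[Token]:
    # Hypothetical stand-in for a transformer Block: adds 1 to every feature.
    return [[v + 1.0 for v in x] for x in xs]


if __name__ == "__main__":
    rng = random.Random(0)
    seq = [[float(i)] for i in range(8)]  # 8 tokens with 1 feature each
    y = random_ltd_layer(seq, toy_layer, keep=4, rng=rng)
    changed = [i for i, (a, b) in enumerate(zip(seq, y)) if a != b]
    print("positions processed by the layer:", changed)
```

Because the layer only sees `keep` of the `n` tokens, its cost drops roughly in proportion for that layer, while the full sequence is still available to later layers.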
Training features include configurable learning rate scheduling (cosine or step), optional warmup, distributed data-parallel training via DeepSpeed, and per-epoch tracking of loss, top-1 accuracy, and top-5 accuracy. These results are stored in a history dictionary and saved with the checkpoints for later analysis.
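The two schedule types plus optional warmup can be sketched as one function of the (possibly fractional) epoch. The script's exact formulas are not given here, so this is a hypothetical sketch: linear warmup from 0, cosine decay to 0, and a step schedule that multiplies the rate by `gamma=0.1` every `step_size=30` epochs (the convention of the classic PyTorch ImageNet example) are all assumptions.

```python
import math


def learning_rate(
    epoch: float,
    base_lr: float,
    total_epochs: int,
    scheduler: str = "cosine",
    warmup_epochs: int = 0,
    step_size: int = 30,
    gamma: float = 0.1,
) -> float:
    """Learning rate at a given epoch under a hypothetical warmup + cosine/step schedule."""
    if warmup_epochs > 0 and epoch < warmup_epochs:
        # Linear warmup from 0 up to base_lr.
        return base_lr * epoch / warmup_epochs
    if scheduler == "cosine":
        # Cosine decay from base_lr to 0 over the post-warmup epochs.
        progress = (epoch - warmup_epochs) / max(1, total_epochs - warmup_epochs)
        return 0.5 * base_lr * (1.0 + math.cos(math.pi * min(progress, 1.0)))
    if scheduler == "step":
        # Multiply by gamma once every step_size epochs, counted after warmup.
        return base_lr * gamma ** int((epoch - warmup_epochs) // step_size)
    raise ValueError(f"unknown scheduler {scheduler!r}")


if __name__ == "__main__":
    for e in (0, 30, 45, 60, 89):
        print(e, round(learning_rate(e, 0.1, 90), 5), round(learning_rate(e, 0.1, 90, "step"), 5))
```

Calling it with a fractional epoch such as `epoch + i / num_batches` turns it into a per-iteration schedule.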
Usage
Use this script when training or fine-tuning ViT models on ImageNet-scale datasets with DeepSpeed, particularly when leveraging data efficiency techniques like Random-LTD. Launch via deepspeed main_imagenet.py --arch vit_base_patch16_224 --data /path/to/imagenet --deepspeed_config ds_config.json.
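DeepSpeed requires the batch-size fields in the JSON config to be consistent: `train_batch_size` must equal `train_micro_batch_size_per_gpu` × `gradient_accumulation_steps` × number of GPUs. The helper below is a hypothetical sketch that builds a minimal ds_config.json satisfying that identity; the `fp16` setting and the 8-GPU example are illustrative assumptions, and it does not include any Random-LTD settings.

```python
import json


def make_ds_config(micro_batch_per_gpu: int, grad_accum_steps: int,
                   world_size: int, fp16: bool = True) -> dict:
    """Build a minimal DeepSpeed config whose batch-size fields agree."""
    return {
        # DeepSpeed checks: train_batch_size ==
        #   train_micro_batch_size_per_gpu * gradient_accumulation_steps * world_size
        "train_batch_size": micro_batch_per_gpu * grad_accum_steps * world_size,
        "train_micro_batch_size_per_gpu": micro_batch_per_gpu,
        "gradient_accumulation_steps": grad_accum_steps,
        "fp16": {"enabled": fp16},
    }


if __name__ == "__main__":
    # 8 GPUs x 32 images per GPU x 1 accumulation step = 256, the documented default batch size.
    print(json.dumps(make_ds_config(32, 1, 8), indent=2))
```

Redirecting the printed JSON to ds_config.json gives a file that can be passed with --deepspeed_config.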
Code Reference
Source Location
- Repository: Microsoft_DeepSpeedExamples
- File: training/data_efficiency/vit_finetuning/main_imagenet.py
- Lines: 1-568
Signature
```python
def _get_model(args):
    """Create a model from the architecture registry."""
    ...


def _get_dist_model(gpu, args):
    """Create a distributed model with DeepSpeed or DataParallel."""
    ...


def main():
    """Entry point: parse args, setup distributed, launch main_worker."""
    ...


def main_worker(gpu, ngpus_per_node, args):
    """Main training worker: setup model, data, optimizer, run train/validate loop."""
    ...


def train(scheduler, train_loader, model, criterion, optimizer, epoch, args):
    """Run one training epoch."""
    ...


def validate(val_loader, model, criterion, args):
    """Evaluate model on validation set."""
    ...
```
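The outer train/validate/checkpoint cycle that main_worker runs can be sketched as below. It is a hypothetical, dependency-free sketch: `run`, the callback signatures, and the history keys are assumptions, and `pickle` stands in for `torch.save` so the example runs without PyTorch. The checkpoint file names follow the Outputs table further down: checkpoint.pth.tar is rewritten every epoch and copied to model_best.pth.tar whenever validation top-1 improves.

```python
import os
import pickle
import shutil
import tempfile
from typing import Callable, Dict, List, Tuple

Stats = Tuple[float, float, float]  # (loss, top-1 %, top-5 %)


def save_checkpoint(state: dict, is_best: bool, directory: str,
                    filename: str = "checkpoint.pth.tar") -> None:
    """Write the latest checkpoint; copy it to model_best.pth.tar if it is the best so far."""
    path = os.path.join(directory, filename)
    with open(path, "wb") as f:
        pickle.dump(state, f)  # a PyTorch script would use torch.save here
    if is_best:
        shutil.copyfile(path, os.path.join(directory, "model_best.pth.tar"))


def run(epochs: int, train_one_epoch: Callable[[int], Stats],
        validate: Callable[[], Stats], directory: str) -> Dict[str, List[float]]:
    """Hypothetical outer loop: train, validate, record history, checkpoint each epoch."""
    history: Dict[str, List[float]] = {
        f"{split}_{metric}": [] for split in ("train", "val") for metric in ("loss", "acc1", "acc5")
    }
    best_acc1 = 0.0
    for epoch in range(epochs):
        train_stats = train_one_epoch(epoch)
        val_stats = validate()
        for split, (loss, acc1, acc5) in (("train", train_stats), ("val", val_stats)):
            history[f"{split}_loss"].append(loss)
            history[f"{split}_acc1"].append(acc1)
            history[f"{split}_acc5"].append(acc5)
        is_best = val_stats[1] > best_acc1
        best_acc1 = max(val_stats[1], best_acc1)
        # A real checkpoint would also hold the model and optimizer state_dicts.
        save_checkpoint({"epoch": epoch + 1, "best_acc1": best_acc1, "history": history},
                        is_best, directory)
    return history


if __name__ == "__main__":
    fake_val = iter([(2.0, 50.0, 75.0), (1.5, 60.0, 82.0), (1.6, 55.0, 80.0)])
    with tempfile.TemporaryDirectory() as d:
        run(3, lambda e: (3.0 - e, 40.0 + e, 70.0 + e), lambda: next(fake_val), d)
        with open(os.path.join(d, "model_best.pth.tar"), "rb") as f:
            print("best epoch:", pickle.load(f)["epoch"])
```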
Import
```python
# This is a standalone training script, not typically imported.
# Run via: deepspeed main_imagenet.py [args]
import deepspeed
from deepspeed.runtime.data_pipeline.data_routing.helper import convert_to_random_ltd, save_without_random_ltd
from models.vit import Block
```
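The two DeepSpeed helpers imported above suggest a wrap-then-unwrap pattern: convert_to_random_ltd wraps every module of a given class (here `Block`) with a token-dropping layer, and save_without_random_ltd produces weights laid out like the unconverted model. The toy sketch below illustrates only that idea in plain Python. `Module`, `TokenDropWrapper`, `convert`, and `state_dict_without_wrappers` are hypothetical names, not DeepSpeed's API; the wrapper performs no token dropping and is assumed to store the original layer under a child named `inner`.

```python
from typing import Dict, Iterator, Optional, Tuple


class Module:
    """Tiny stand-in for torch.nn.Module: named parameters plus named child modules."""

    def __init__(self, params: Optional[Dict[str, float]] = None, **children: "Module") -> None:
        self.params = dict(params or {})
        self.children = dict(children)

    def named_modules(self, prefix: str = "") -> Iterator[Tuple[str, "Module"]]:
        yield prefix, self
        for name, child in self.children.items():
            yield from child.named_modules(f"{prefix}.{name}" if prefix else name)

    def state_dict(self, prefix: str = "") -> Dict[str, float]:
        out = {prefix + k: v for k, v in self.params.items()}
        for name, child in self.children.items():
            out.update(child.state_dict(f"{prefix}{name}."))
        return out


class Block(Module):
    """Hypothetical transformer block, standing in for models.vit.Block."""


class TokenDropWrapper(Module):
    """Hypothetical wrapper; a real one would drop tokens before calling `inner`."""

    def __init__(self, inner: Module) -> None:
        super().__init__(inner=inner)


def convert(root: Module, convert_type: type) -> Module:
    """Wrap every submodule of type `convert_type` (the idea behind convert_to_random_ltd)."""
    # Collect targets first so the tree is not mutated while it is being walked.
    targets = [name for name, m in root.named_modules() if name and isinstance(m, convert_type)]
    for name in targets:
        *path, leaf = name.split(".")
        parent = root
        for part in path:
            parent = parent.children[part]
        parent.children[leaf] = TokenDropWrapper(parent.children[leaf])
    return root


def state_dict_without_wrappers(root: Module) -> Dict[str, float]:
    """Remove the wrapper's name segment so keys match the unconverted model."""
    return {k.replace(".inner.", "."): v for k, v in root.state_dict().items()}


if __name__ == "__main__":
    vit = Module(blocks=Module(**{"0": Block({"w": 1.0}), "1": Block({"w": 2.0})}),
                 head=Module({"b": 0.5}))
    convert(vit, Block)
    print(sorted(vit.state_dict()))
    print(sorted(state_dict_without_wrappers(vit)))
```

Saving without the wrapper segments means the checkpoint can be loaded into a model that was never converted.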
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| data | str | No | Path to ImageNet dataset directory (default: 'imagenet') |
| --arch | str | No | Model architecture name (default: 'resnet18') |
| --epochs | int | No | Number of total epochs to run (default: 90) |
| --batch-size | int | No | Mini-batch size across all GPUs (default: 256) |
| --lr | float | No | Initial learning rate (default: 0.1) |
| --img_size | int | No | Input image size (default: 224) |
| --scheduler | str | No | Learning rate scheduler type: 'cosine' or 'step' (default: 'cosine') |
| --random_ltd | flag | No | Enable Random-LTD token dropping for data efficiency |
| --deepspeed_config | str | No | Path to DeepSpeed JSON configuration file |
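The inputs above map naturally onto an `argparse` parser. The sketch below is hypothetical and mirrors only the documented names and defaults, not the script's exact code. `data` is modeled as an optional positional argument because the table lists it without a `--` prefix and with an 'imagenet' default; `--local_rank` is included because the deepspeed launcher passes it to each process by default; and DeepSpeed scripts commonly register `--deepspeed_config` via `deepspeed.add_config_arguments(parser)` rather than by hand.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    """Hypothetical parser mirroring the documented inputs and defaults."""
    p = argparse.ArgumentParser(description="ViT ImageNet fine-tuning (sketch)")
    # Assumed optional positional, matching the unprefixed 'data' entry in the table.
    p.add_argument("data", nargs="?", default="imagenet", help="path to ImageNet")
    p.add_argument("--arch", default="resnet18", help="model architecture name")
    p.add_argument("--epochs", type=int, default=90)
    p.add_argument("--batch-size", type=int, default=256, help="mini-batch size across all GPUs")
    p.add_argument("--lr", type=float, default=0.1)
    p.add_argument("--img_size", type=int, default=224)
    p.add_argument("--scheduler", choices=["cosine", "step"], default="cosine")
    p.add_argument("--random_ltd", action="store_true", help="enable Random-LTD")
    # In a real DeepSpeed script this is usually added by deepspeed.add_config_arguments(p).
    p.add_argument("--deepspeed_config", type=str, default=None)
    # Passed to each worker process by the deepspeed launcher.
    p.add_argument("--local_rank", type=int, default=-1)
    return p


if __name__ == "__main__":
    print(build_parser().parse_args(["--arch", "vit_base_patch16_224", "--random_ltd"]))
```

Parsing an empty argument list yields exactly the defaults listed in the table.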
Outputs
| Name | Type | Description |
|---|---|---|
| Trained model | nn.Module | Fine-tuned image classification model |
| checkpoint.pth.tar | file | Training checkpoint with model state, optimizer state, and epoch |
| model_best.pth.tar | file | Best model checkpoint based on top-1 validation accuracy |
| Training metrics | dict | History of train/val loss, top-1 accuracy, and top-5 accuracy per epoch |
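Top-1 and top-5 accuracy count a prediction as correct when the true class is the single highest-scoring class, or among the five highest, respectively. The dependency-free sketch below computes both as percentages; `topk_accuracy` is a hypothetical name, and a PyTorch implementation would typically use `Tensor.topk` on the logits instead of sorting in Python.

```python
from typing import List, Sequence, Tuple


def topk_accuracy(scores: Sequence[Sequence[float]], targets: Sequence[int],
                  ks: Tuple[int, ...] = (1, 5)) -> List[float]:
    """Percentage of samples whose true class is among the k highest-scoring classes."""
    hits = [0] * len(ks)
    for row, target in zip(scores, targets):
        # Class indices ordered from highest to lowest score.
        ranked = sorted(range(len(row)), key=lambda c: row[c], reverse=True)
        for j, k in enumerate(ks):
            if target in ranked[:k]:
                hits[j] += 1
    return [100.0 * h / len(targets) for h in hits]


if __name__ == "__main__":
    logits = [[0.1, 0.7, 0.2], [0.5, 0.3, 0.2], [0.2, 0.3, 0.5]]
    labels = [1, 1, 0]
    print(topk_accuracy(logits, labels, ks=(1, 2)))
```

Averaging these per-batch values weighted by batch size gives the per-epoch figures stored in the history dictionary.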
Usage Examples
```bash
# Fine-tune ViT-Base on ImageNet with DeepSpeed
deepspeed training/data_efficiency/vit_finetuning/main_imagenet.py \
    --arch vit_base_patch16_224 \
    --data /path/to/imagenet \
    --epochs 90 \
    --batch-size 256 \
    --lr 0.1 \
    --scheduler cosine \
    --deepspeed_config ds_config.json

# Fine-tune with Random-LTD for data efficiency
deepspeed main_imagenet.py \
    --arch vit_base_patch16_224 \
    --data /path/to/imagenet \
    --random_ltd \
    --deepspeed_config ds_config_random_ltd.json
```