Implementation:Microsoft DeepSpeedExamples ImageNet Training DeepSpeed
| Knowledge Sources | |
|---|---|
| Domains | Computer Vision, Distributed Training, Deep Learning |
| Last Updated | 2026-02-07 12:00 GMT |
Overview
A complete ImageNet image classification training script using DeepSpeed for distributed training with torchvision models.
Description
This script implements a standard ImageNet training pipeline integrated with DeepSpeed for distributed training. It accepts torchvision model architectures (ResNet, VGG, AlexNet, etc.) and handles the complete training lifecycle: model creation, distributed initialization via DeepSpeed, data loading with standard ImageNet preprocessing (random resized crop, horizontal flip, normalization), training with cross-entropy loss, and validation with top-1 and top-5 accuracy metrics.
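The top-1 and top-5 metrics mentioned above count a prediction as correct when the true class is among the k highest-scoring classes. The following is a minimal, framework-free sketch of that idea; the function name `topk_accuracy` and the toy scores are illustrative assumptions, and the actual script presumably computes this on tensors rather than Python lists.

```python
from typing import List, Sequence


def topk_accuracy(scores: Sequence[Sequence[float]],
                  targets: Sequence[int],
                  ks: Sequence[int] = (1, 5)) -> List[float]:
    """Return the top-k accuracy in percent for each k in ks.

    scores[i][c] is the score for sample i and class c;
    targets[i] is the correct class index for sample i.
    """
    max_k = max(ks)
    correct = {k: 0 for k in ks}
    for row, target in zip(scores, targets):
        # Class indices sorted by descending score, truncated to the largest k.
        ranked = sorted(range(len(row)), key=lambda c: row[c], reverse=True)[:max_k]
        for k in ks:
            if target in ranked[:k]:
                correct[k] += 1
    return [100.0 * correct[k] / len(targets) for k in ks]


# Toy example with 3 classes, so top-2 is used instead of top-5.
scores = [
    [0.1, 0.7, 0.2],  # highest score: class 1
    [0.5, 0.3, 0.2],  # highest score: class 0
    [0.2, 0.3, 0.5],  # highest score: class 2
    [0.6, 0.3, 0.1],  # highest score: class 0
]
targets = [1, 1, 2, 2]
top1, top2 = topk_accuracy(scores, targets, ks=(1, 2))
```

With these toy inputs, two of four samples are correct at k=1 and three of four at k=2.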
The training loop uses DeepSpeed's engine for gradient computation and parameter updates, calling model.backward(loss) and model.step() instead of the usual PyTorch loss.backward() and optimizer.step(). The script gathers loss and accuracy metrics across all GPUs after training and writes them to an Excel spreadsheet (via openpyxl) for analysis. It supports checkpoint saving and resuming, with best-model tracking based on validation accuracy.
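The engine-driven loop described above can be sketched as follows. This is an illustration under stated assumptions, not the script's own code: `train_one_epoch`, `FakeEngine`, and `squared_error` are hypothetical names, and `FakeEngine` is a stand-in that only records calls so the sketch runs without DeepSpeed or a GPU. In real use the engine would be the first value returned by `deepspeed.initialize(...)`.

```python
def train_one_epoch(engine, loader, criterion):
    """Run one epoch with a DeepSpeed-style engine and return the mean loss."""
    total, batches = 0.0, 0
    for images, target in loader:
        output = engine(images)           # forward pass through the wrapped model
        loss = criterion(output, target)
        engine.backward(loss)             # takes the place of loss.backward()
        engine.step()                     # takes the place of optimizer.step()
        total += float(loss)
        batches += 1
    return total / max(batches, 1)


class FakeEngine:
    """Hypothetical stand-in that records calls; not part of DeepSpeed."""

    def __init__(self):
        self.calls = []

    def __call__(self, images):
        self.calls.append("forward")
        return sum(images)                # toy "model": sum of the inputs

    def backward(self, loss):
        self.calls.append("backward")

    def step(self):
        self.calls.append("step")


def squared_error(output, target):
    """Toy criterion used only for this sketch."""
    return (output - target) ** 2


engine = FakeEngine()
loader = [([1.0, 2.0], 3.0), ([0.5, 0.5], 2.0)]  # two toy (input, target) batches
mean_loss = train_one_epoch(engine, loader, squared_error)
```

The point of the sketch is the call order: each batch goes through forward, `backward(loss)`, then `step()`, with no separate optimizer object in the loop body.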
The script exposes the usual distributed training settings: world size, local rank, and a multiprocessing distributed mode. It includes the utility classes AverageMeter, which tracks running statistics, and ProgressMeter, which formats console output, plus a Summary enum that selects how each metric is summarized. The validation function supports both real and dummy (synthetic) data for benchmarking.
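A minimal sketch of how such meter utilities typically work is shown below. It is loosely modeled on the class names above; the constructor arguments, the `render` method, and the format strings are assumptions for illustration and may differ from the script.

```python
class AverageMeter:
    """Tracks the latest value and the running average of a metric."""

    def __init__(self, name, fmt=":f"):
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        self.val = 0.0
        self.sum = 0.0
        self.count = 0
        self.avg = 0.0

    def update(self, val, n=1):
        # n is the number of samples that val was averaged over.
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def __str__(self):
        template = "{name} {val" + self.fmt + "} ({avg" + self.fmt + "})"
        return template.format(name=self.name, val=self.val, avg=self.avg)


class ProgressMeter:
    """Formats a batch counter followed by a list of meters."""

    def __init__(self, num_batches, meters, prefix=""):
        width = len(str(num_batches))
        self.batch_fmt = "[{:" + str(width) + "d}/" + str(num_batches) + "]"
        self.meters = meters
        self.prefix = prefix

    def render(self, batch):
        entries = [self.prefix + self.batch_fmt.format(batch)]
        entries += [str(meter) for meter in self.meters]
        return "\t".join(entries)


loss = AverageMeter("Loss", ":.2f")
loss.update(2.0, n=4)   # first batch: mean loss 2.0 over 4 samples
loss.update(1.0, n=4)   # second batch: mean loss 1.0 over 4 samples
progress = ProgressMeter(100, [loss], prefix="Epoch: [0]")
line = progress.render(10)
```

Weighting each update by the batch size `n` keeps the running average correct when the final batch is smaller than the rest.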
Usage
Use this script for training standard torchvision image classification models on ImageNet with DeepSpeed. It serves as a reference example for integrating DeepSpeed into a conventional PyTorch ImageNet training workflow. Launch via deepspeed training/imagenet/main.py --arch resnet50 /path/to/imagenet --deepspeed_config ds_config.json.
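The launch command refers to a ds_config.json file that is not reproduced here. As a rough illustration only, the snippet below builds a small configuration using the documented DeepSpeed keys `train_batch_size`, `steps_per_print`, and `fp16`. The values are assumptions, not the configuration shipped with the example, and the script may require additional settings.

```python
import json

# Illustrative values only; adjust to match the script and the cluster.
ds_config = {
    "train_batch_size": 256,     # global batch size across all GPUs
    "steps_per_print": 100,      # how often the engine logs progress
    "fp16": {"enabled": False},  # mixed precision switched off in this sketch
}

config_text = json.dumps(ds_config, indent=2)


def write_config(path="ds_config.json"):
    """Write the sketch configuration to disk."""
    with open(path, "w") as handle:
        handle.write(config_text + "\n")
```

Calling `write_config()` produces a file that can be passed through `--deepspeed_config`.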
Code Reference
Source Location
- Repository: Microsoft_DeepSpeedExamples
- File: training/imagenet/main.py
- Lines: 1-508
Signature
def main():
    """Entry point: parse args, setup distributed environment, launch main_worker."""
    ...

def main_worker(gpu, ngpus_per_node, args):
    """Main training worker: create model, DeepSpeed init, data loading, train/validate loop."""
    ...

def train(train_loader, model, criterion, optimizer, epoch, device, args):
    """Run one training epoch using DeepSpeed engine."""
    ...

def validate(val_loader, model, criterion, args):
    """Evaluate model on the validation set, computing top-1 and top-5 accuracy."""
    ...

class AverageMeter(object):
    """Computes and stores the average and current value."""
    ...

class ProgressMeter(object):
    """Displays training progress with formatted metrics."""
    ...
Import
# This is a standalone training script, not typically imported.
# Run via: deepspeed training/imagenet/main.py [args]
import deepspeed
import torchvision.models as models
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| data | str | No | Path to ImageNet dataset directory (default: 'imagenet') |
| --arch | str | No | Model architecture from torchvision (default: 'resnet18') |
| --epochs | int | No | Number of total epochs to run (default: 90) |
| --batch-size | int | No | Mini-batch size across all GPUs (default: 256) |
| --lr | float | No | Initial learning rate (default: 0.1) |
| --momentum | float | No | SGD momentum (default: 0.9) |
| --weight-decay | float | No | Weight decay (default: 1e-4) |
| --pretrained | flag | No | Use pretrained model from torchvision |
| --local_rank | int | No | Local rank for distributed training (default: -1) |
| --deepspeed_config | str | No | Path to DeepSpeed JSON configuration file |
| --evaluate | flag | No | Evaluate the model on the validation set instead of training |
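The inputs above map naturally onto an argparse parser. The sketch below covers only the arguments listed in the table, with the stated defaults; `build_parser` is a hypothetical helper. `--deepspeed_config` is registered by hand here so that the sketch runs without DeepSpeed installed, whereas a real script would more likely obtain it from `deepspeed.add_config_arguments(parser)`.

```python
import argparse


def build_parser():
    """Build a parser for the arguments listed in the inputs table."""
    parser = argparse.ArgumentParser(description="ImageNet training (sketch)")
    # Optional positional: falls back to 'imagenet' when omitted.
    parser.add_argument("data", metavar="DIR", nargs="?", default="imagenet")
    parser.add_argument("--arch", default="resnet18")
    parser.add_argument("--epochs", type=int, default=90)
    parser.add_argument("--batch-size", type=int, default=256)
    parser.add_argument("--lr", type=float, default=0.1)
    parser.add_argument("--momentum", type=float, default=0.9)
    parser.add_argument("--weight-decay", type=float, default=1e-4)
    parser.add_argument("--pretrained", action="store_true")
    parser.add_argument("--local_rank", type=int, default=-1)
    # Assumption: added manually as a stand-in for the DeepSpeed helper.
    parser.add_argument("--deepspeed_config", type=str, default=None)
    return parser


args = build_parser().parse_args(
    ["--arch", "resnet50", "/path/to/imagenet",
     "--deepspeed_config", "ds_config.json"]
)
defaults = build_parser().parse_args([])
```

Hyphenated flags such as `--batch-size` become attributes with underscores (`args.batch_size`), while `--local_rank` keeps its underscore as written.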
Outputs
| Name | Type | Description |
|---|---|---|
| Trained model | nn.Module | Trained image classification model |
| checkpoint.pth.tar | file | Training checkpoint with model and optimizer state |
| model_best.pth.tar | file | Best model based on validation top-1 accuracy |
| Acc_loss_log.xlsx | file | Excel spreadsheet with per-GPU loss and accuracy metrics |
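The relationship between checkpoint.pth.tar and model_best.pth.tar can be sketched as follows: every epoch overwrites the rolling checkpoint, and the file is copied to the best-model name whenever validation top-1 accuracy improves. This is an assumption-laden illustration: `save_checkpoint` here is a hypothetical helper, `pickle` is swapped in for `torch.save` so the sketch runs without PyTorch, and the accuracy values are made up.

```python
import os
import pickle
import shutil
import tempfile


def save_checkpoint(state, is_best, directory, filename="checkpoint.pth.tar"):
    """Write the rolling checkpoint and copy it when it is the best so far."""
    path = os.path.join(directory, filename)
    with open(path, "wb") as handle:
        pickle.dump(state, handle)  # stand-in for torch.save(state, path)
    if is_best:
        shutil.copyfile(path, os.path.join(directory, "model_best.pth.tar"))
    return path


best_acc1 = 0.0
with tempfile.TemporaryDirectory() as tmp:
    # Made-up validation top-1 accuracies for three epochs.
    for epoch, acc1 in enumerate([61.2, 68.9, 67.5]):
        is_best = acc1 > best_acc1
        best_acc1 = max(acc1, best_acc1)
        state = {"epoch": epoch + 1, "best_acc1": best_acc1}
        save_checkpoint(state, is_best, tmp)

    with open(os.path.join(tmp, "checkpoint.pth.tar"), "rb") as handle:
        last_state = pickle.load(handle)
    with open(os.path.join(tmp, "model_best.pth.tar"), "rb") as handle:
        best_state = pickle.load(handle)
```

After the loop, the rolling checkpoint holds the state from the last epoch, while the best-model file still holds the state from the second epoch, where accuracy peaked.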
Usage Examples
# Train ResNet-50 on ImageNet with DeepSpeed
# deepspeed training/imagenet/main.py \
# --arch resnet50 \
# /path/to/imagenet \
# --epochs 90 \
# --batch-size 256 \
# --lr 0.1 \
# --deepspeed_config ds_config.json
# Evaluate a pretrained model
# deepspeed training/imagenet/main.py \
# --arch resnet50 \
# --pretrained \
# --evaluate \
# /path/to/imagenet \
# --deepspeed_config ds_config.json