Implementation: NVIDIA DALI Train Function (PyTorch)
| Knowledge Sources | |
|---|---|
| Domains | Deep_Learning, GPU_Computing, Model_Training |
| Last Updated | 2026-02-08 00:00 GMT |
Overview
The train() function from the NVIDIA DALI ResNet50 example implements a single-epoch PyTorch training loop that consumes data from a DALI iterator, with support for mixed-precision training via GradScaler, distributed gradient synchronization, learning rate warmup, and periodic metric logging.
Description
The train() function implements the integration pattern between a DALI data pipeline and a standard PyTorch training loop. It handles the specific data format produced by the DALIClassificationIterator and orchestrates the complete forward-backward-update cycle for one training epoch.
Key aspects of this implementation:
DALI data unpacking: The function detects whether the data comes from a DALI iterator or a PyTorch DataLoader and unpacks accordingly. For DALI, images are accessed via data[0]["data"] and labels via data[0]["label"].squeeze(-1).long(). The squeeze removes the trailing dimension [B, 1] -> [B] and the long conversion ensures compatibility with CrossEntropyLoss.
Mixed-precision training: The forward pass runs under torch.cuda.amp.autocast for automatic float16 computation on Tensor Cores. The GradScaler scales the loss before the backward pass to prevent gradient underflow in float16, then unscales gradients before the optimizer step. This is orthogonal to DALI's preprocessing, which outputs float32 tensors.
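The underflow problem that GradScaler guards against can be seen directly with a float16 round trip. This is a minimal numpy sketch, not code from the example; the scale value 2**16 is an assumed typical initial loss scale:

```python
import numpy as np

# A gradient this small rounds to zero when cast to float16
# (it is below half of float16's smallest subnormal, ~6e-8) -- underflow.
grad = np.float32(1e-8)
assert np.float16(grad) == 0.0

# Scaling the loss (and hence, by the chain rule, the gradients) by a large
# factor keeps the value in float16's representable range; dividing the
# scale back out before the optimizer step recovers the true gradient.
scale = np.float32(2 ** 16)
scaled = np.float16(grad * scale)      # ~6.55e-4, representable in float16
recovered = np.float32(scaled) / scale
assert scaled != 0.0
assert abs(recovered - grad) / grad < 0.01
```

This is exactly the scale/backward/unscale sequence that `scaler.scale(loss).backward()` followed by `scaler.step(optimizer)` performs on real tensors.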
Learning rate schedule: The adjust_learning_rate function implements a step-decay schedule (divide by 10 every 30 epochs) with a 5-epoch warmup period. The warmup linearly increases the learning rate from 0 to the target LR over the first 5 epochs, which is important for large-batch training stability.
Periodic metric computation: Top-1 and top-5 accuracy are computed every print_freq iterations (default 10) rather than every iteration, since accuracy computation requires a host-device synchronization. In distributed training, metrics are reduced across all workers via all_reduce for accurate global statistics.
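The top-k computation itself can be sketched as follows. This is a numpy stand-in for the torch-based accuracy() helper in main.py (the function name and shapes here are illustrative, and the distributed all_reduce step is omitted):

```python
import numpy as np

def accuracy_topk(logits, targets, topk=(1, 5)):
    """Percentage of rows whose true label appears among the k
    highest-scoring classes, for each k in `topk`."""
    maxk = max(topk)
    # Indices of the maxk highest logits per row, best first
    pred = np.argsort(-logits, axis=1)[:, :maxk]     # [B, maxk]
    correct = pred == np.asarray(targets)[:, None]   # [B, maxk] bool
    return [100.0 * correct[:, :k].any(axis=1).mean() for k in topk]

logits = np.array([[0.1, 0.7, 0.2],
                   [0.6, 0.3, 0.1]])
top1, top2 = accuracy_topk(logits, [1, 1], topk=(1, 2))
# Row 0 is right at k=1; row 1 only at k=2: top1 == 50.0, top2 == 100.0
```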
Profiling hooks: Optional CUDA profiling integration via nvtx range markers and cudaProfiler API calls, enabling detailed performance analysis of the training loop.
Usage
Call this function once per epoch, passing the DALI iterator as the train_loader. The function returns the average batch processing time for the epoch.
Code Reference
Source Location
- Repository: NVIDIA DALI
- File: docs/examples/use_cases/pytorch/resnet50/main.py (lines 563-668)
Signature
def train(train_loader, model, criterion, scaler, optimizer, epoch):
    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    model.train()
    end = time.time()

    is_pytorch_loader = args.data_loader == "pytorch" or args.data_loader == "dali_proxy"
    if is_pytorch_loader:
        do_normalize = args.data_loader == "pytorch"
        data_iterator = data_prefetcher(train_loader, do_normalize=do_normalize)
        data_iterator = iter(data_iterator)
    else:
        data_iterator = train_loader

    for i, data in enumerate(data_iterator):
        if is_pytorch_loader:
            input, target = data
            train_loader_len = len(train_loader)
        else:
            input = data[0]["data"]
            target = data[0]["label"].squeeze(-1).long()
            train_loader_len = int(math.ceil(data_iterator._size / args.batch_size))

        adjust_learning_rate(optimizer, epoch, i, train_loader_len)

        with torch.cuda.amp.autocast(enabled=args.fp16_mode):
            output = model(input)
            loss = criterion(output, target)

        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        if i % args.print_freq == 0:
            prec1, prec5 = accuracy(output.data, target, topk=(1, 5))
            if args.distributed:
                reduced_loss = reduce_tensor(loss.data)
                prec1 = reduce_tensor(prec1)
                prec5 = reduce_tensor(prec5)
            else:
                reduced_loss = loss.data
            losses.update(to_python_float(reduced_loss), input.size(0))
            top1.update(to_python_float(prec1), input.size(0))
            top5.update(to_python_float(prec5), input.size(0))
            # Record average per-batch wall-clock time; synchronize so queued
            # GPU work is included in the measurement
            torch.cuda.synchronize()
            batch_time.update((time.time() - end) / args.print_freq)
            end = time.time()

    return batch_time.avg
Import
import math
import time

import torch
import torch.nn as nn
import torch.optim
import torch.cuda.amp
from torch.nn.parallel import DistributedDataParallel as DDP
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| train_loader | DALIClassificationIterator | Yes | DALI iterator yielding [{"data": Tensor[B,3,H,W], "label": Tensor[B,1]}] per iteration |
| model | nn.Module | Yes | PyTorch model (optionally wrapped in DistributedDataParallel) in .train() mode |
| criterion | nn.Module | Yes | Loss function, typically nn.CrossEntropyLoss().cuda() |
| scaler | torch.cuda.amp.GradScaler | Yes | Gradient scaler for mixed-precision training; pass enabled=False to disable FP16 |
| optimizer | torch.optim.Optimizer | Yes | Optimizer instance, typically SGD with momentum and weight decay |
| epoch | int | Yes | Current epoch number (0-indexed), used for learning rate scheduling and logging |
Outputs
| Name | Type | Description |
|---|---|---|
| avg_batch_time | float | Average wall-clock time per batch across the epoch, used for throughput reporting |
Usage Examples
Main Training Loop
# From main.py (lines 480-514):
scaler = torch.cuda.amp.GradScaler(
    init_scale=args.loss_scale,
    growth_factor=2,
    backoff_factor=0.5,
    growth_interval=100,
    enabled=args.fp16_mode,
)

for epoch in range(args.start_epoch, args.epochs):
    avg_train_time = train(train_loader, model, criterion, scaler, optimizer, epoch)
    [prec1, prec5] = validate(val_loader, model, criterion)

    if args.local_rank == 0:
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)
        save_checkpoint({
            'epoch': epoch + 1,
            'arch': args.arch,
            'state_dict': model.state_dict(),
            'best_prec1': best_prec1,
            'optimizer': optimizer.state_dict(),
        }, is_best)
DALI Data Unpacking Pattern
# Inside the training loop, DALI iterator produces different format than PyTorch:
for i, data in enumerate(train_loader):
    # DALI format:
    input = data[0]["data"]                       # [B, 3, 224, 224] float32, on GPU
    target = data[0]["label"].squeeze(-1).long()  # [B] int64, on GPU
    # Equivalent PyTorch DataLoader format would be:
    # input, target = data
Learning Rate Schedule
def adjust_learning_rate(optimizer, epoch, step, len_epoch):
    """Step decay with warmup: divide LR by 10 every 30 epochs, warmup for 5 epochs."""
    factor = epoch // 30
    if epoch >= 80:
        factor = factor + 1
    lr = args.lr * (0.1 ** factor)

    # Linear warmup for first 5 epochs
    if epoch < 5:
        lr = lr * float(1 + step + epoch * len_epoch) / (5. * len_epoch)

    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
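The schedule can be checked standalone by pulling the arithmetic out of the optimizer update. Here `base_lr` stands in for `args.lr`; the default of 0.1 is an assumption for illustration, not a value taken from the example:

```python
def lr_at(epoch, step, len_epoch, base_lr=0.1):
    """Learning rate produced by the step-decay-with-warmup schedule above."""
    factor = epoch // 30
    if epoch >= 80:            # extra decay step from epoch 80 onward
        factor += 1
    lr = base_lr * (0.1 ** factor)
    if epoch < 5:              # linear warmup over the first 5 epochs
        lr = lr * float(1 + step + epoch * len_epoch) / (5.0 * len_epoch)
    return lr

# First iteration: warmup starts near zero
assert abs(lr_at(0, 0, 100) - 0.1 / 500) < 1e-12
# Last warmup iteration reaches the full base LR
assert lr_at(4, 99, 100) == 0.1
# Step decay: /10 at epoch 30, /1000 at epoch 80 (30-epoch steps plus the extra step)
assert abs(lr_at(30, 0, 100) - 1e-2) < 1e-12
assert abs(lr_at(80, 0, 100) - 1e-4) < 1e-12
```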