
Implementation:NVIDIA DALI Train Function PyTorch

From Leeroopedia


Knowledge Sources
Domains Deep_Learning, GPU_Computing, Model_Training
Last Updated 2026-02-08 00:00 GMT

Overview

The train() function from the NVIDIA DALI ResNet50 example implements a single-epoch PyTorch training loop that consumes data from a DALI iterator, with support for mixed-precision training via GradScaler, distributed gradient synchronization, learning rate warmup, and periodic metric logging.

Description

The train() function implements the integration pattern between a DALI data pipeline and a standard PyTorch training loop. It handles the specific data format produced by the DALIClassificationIterator and orchestrates the complete forward-backward-update cycle for one training epoch.

Key aspects of this implementation:

DALI data unpacking: The function detects whether the data comes from a DALI iterator or a PyTorch DataLoader and unpacks it accordingly. For DALI, images are accessed via data[0]["data"] and labels via data[0]["label"].squeeze(-1).long(). The squeeze removes the trailing dimension ([B, 1] -> [B]) and .long() casts the labels to int64, as required by CrossEntropyLoss.

Mixed-precision training: The forward pass runs under torch.cuda.amp.autocast for automatic float16 computation on Tensor Cores. The GradScaler scales the loss before the backward pass to prevent gradient underflow in float16, then unscales gradients before the optimizer step. This is orthogonal to DALI's preprocessing, which outputs float32 tensors.
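
The forward-backward-update ordering described above can be sketched in isolation (a minimal sketch: the model, optimizer, and data here are placeholders, and enabled=False makes everything a plain float32 update when CUDA is unavailable):

```python
import torch

use_amp = torch.cuda.is_available()
device = "cuda" if use_amp else "cpu"

model = torch.nn.Linear(8, 2).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
criterion = torch.nn.CrossEntropyLoss()
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

x = torch.randn(4, 8, device=device)
y = torch.randint(0, 2, (4,), device=device)

with torch.cuda.amp.autocast(enabled=use_amp):
    loss = criterion(model(x), y)   # forward may run in float16 on Tensor Cores

optimizer.zero_grad()
scaler.scale(loss).backward()       # scale loss to avoid fp16 gradient underflow
scaler.step(optimizer)              # unscales grads; skips the step on inf/NaN
scaler.update()                     # adjusts the scale factor for next iteration
```

Note that scaler.step silently skips the optimizer step when unscaled gradients contain infs or NaNs, which is why the scale factor must be re-tuned by scaler.update() every iteration.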

Learning rate schedule: The adjust_learning_rate function implements a step-decay schedule (divide by 10 every 30 epochs, with an additional tenfold decay after epoch 80) plus a 5-epoch warmup period. The warmup linearly increases the learning rate from 0 to the target LR over the first 5 epochs, which is important for large-batch training stability.

Periodic metric computation: Top-1 and top-5 accuracy are computed every print_freq iterations (default 10) rather than every iteration, since accuracy computation requires a host-device synchronization. In distributed training, metrics are reduced across all workers via all_reduce for accurate global statistics.
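
A sketch of the two helpers this paragraph refers to, assuming the standard top-k formulation and a hypothetical explicit world_size argument for the all-reduce average (the real reduce_tensor reads the world size from args):

```python
import torch
import torch.distributed as dist

def accuracy(output, target, topk=(1,)):
    """Top-k accuracy in percent; returns one 0-dim tensor per k."""
    maxk = max(topk)
    batch_size = target.size(0)
    _, pred = output.topk(maxk, dim=1, largest=True, sorted=True)
    pred = pred.t()                                    # [maxk, B]
    correct = pred.eq(target.view(1, -1).expand_as(pred))
    res = []
    for k in topk:
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res

def reduce_tensor(tensor, world_size):
    # sum the metric across workers, then average; requires an
    # initialized process group (not shown here)
    rt = tensor.clone()
    dist.all_reduce(rt, op=dist.ReduceOp.SUM)
    return rt / world_size
```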

Profiling hooks: Optional CUDA profiling integration via nvtx range markers and cudaProfiler API calls, enabling detailed performance analysis of the training loop.
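
These hooks can be sketched as a wrapper around a single training step (profiled_step is a hypothetical helper; torch.cuda.nvtx and torch.cuda.profiler map to the nvtx/cudaProfiler APIs mentioned above, and the wrapper degrades to a plain call without CUDA):

```python
import torch

def profiled_step(step_fn, i, start_iter=100, stop_iter=200):
    # hypothetical wrapper: profile only iterations start_iter..stop_iter
    if not torch.cuda.is_available():
        return step_fn()                # nvtx/cudaProfiler are CUDA-only
    if i == start_iter:
        torch.cuda.profiler.start()     # cudaProfilerStart()
    torch.cuda.nvtx.range_push(f"iteration_{i}")  # named range in the timeline
    out = step_fn()
    torch.cuda.nvtx.range_pop()
    if i == stop_iter:
        torch.cuda.profiler.stop()      # cudaProfilerStop()
    return out
```

Running under `nsys profile` (or nvprof) then shows each iteration as a labeled range, restricted to the chosen iteration window.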

Usage

Call this function once per epoch, passing the DALI iterator as the train_loader. The function returns the average batch processing time for the epoch.

Code Reference

Source Location

  • Repository: NVIDIA DALI
  • File: docs/examples/use_cases/pytorch/resnet50/main.py (lines 563-668)

Signature

def train(train_loader, model, criterion, scaler, optimizer, epoch):
    batch_time = AverageMeter()
    losses = AverageMeter()
    top1 = AverageMeter()
    top5 = AverageMeter()

    model.train()
    end = time.time()

    is_pytorch_loader = args.data_loader == "pytorch" or args.data_loader == "dali_proxy"
    if is_pytorch_loader:
        do_normalize = args.data_loader == "pytorch"
        data_iterator = data_prefetcher(train_loader, do_normalize=do_normalize)
        data_iterator = iter(data_iterator)
    else:
        data_iterator = train_loader

    for i, data in enumerate(data_iterator):
        if is_pytorch_loader:
            input, target = data
            train_loader_len = len(train_loader)
        else:
            input = data[0]["data"]
            target = data[0]["label"].squeeze(-1).long()
            train_loader_len = int(math.ceil(data_iterator._size / args.batch_size))

        adjust_learning_rate(optimizer, epoch, i, train_loader_len)

        with torch.cuda.amp.autocast(enabled=args.fp16_mode):
            output = model(input)
            loss = criterion(output, target)

        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        if i % args.print_freq == 0:
            # measure elapsed time, averaged over the print_freq interval
            # (synchronize so queued GPU work is included)
            torch.cuda.synchronize()
            batch_time.update((time.time() - end) / args.print_freq)
            end = time.time()

            prec1, prec5 = accuracy(output.data, target, topk=(1, 5))
            if args.distributed:
                reduced_loss = reduce_tensor(loss.data)
                prec1 = reduce_tensor(prec1)
                prec5 = reduce_tensor(prec5)
            else:
                reduced_loss = loss.data

            losses.update(to_python_float(reduced_loss), input.size(0))
            top1.update(to_python_float(prec1), input.size(0))
            top5.update(to_python_float(prec5), input.size(0))

    return batch_time.avg

Import

import math
import time

import torch
import torch.nn as nn
import torch.optim
import torch.cuda.amp
from torch.nn.parallel import DistributedDataParallel as DDP

I/O Contract

Inputs

  • train_loader (DALIClassificationIterator, required): DALI iterator yielding [{"data": Tensor[B,3,H,W], "label": Tensor[B,1]}] per iteration
  • model (nn.Module, required): PyTorch model (optionally wrapped in DistributedDataParallel) in .train() mode
  • criterion (nn.Module, required): Loss function, typically nn.CrossEntropyLoss().cuda()
  • scaler (torch.cuda.amp.GradScaler, required): Gradient scaler for mixed-precision training; pass enabled=False to disable FP16
  • optimizer (torch.optim.Optimizer, required): Optimizer instance, typically SGD with momentum and weight decay
  • epoch (int, required): Current epoch number (0-indexed), used for learning rate scheduling and logging

Outputs

  • avg_batch_time (float): Average wall-clock time per batch across the epoch, used for throughput reporting
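
The returned value converts directly into throughput; for example (the batch size and GPU count below are illustrative):

```python
def throughput(avg_batch_time, batch_size, world_size=1):
    # images processed per second across all workers
    return batch_size * world_size / avg_batch_time

# 256 images per GPU on 8 GPUs at 0.25 s per batch:
images_per_sec = throughput(0.25, 256, world_size=8)
```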

Usage Examples

Main Training Loop

# From main.py (lines 480-514):
scaler = torch.cuda.amp.GradScaler(
    init_scale=args.loss_scale,
    growth_factor=2,
    backoff_factor=0.5,
    growth_interval=100,
    enabled=args.fp16_mode,
)

for epoch in range(args.start_epoch, args.epochs):
    avg_train_time = train(train_loader, model, criterion, scaler, optimizer, epoch)
    [prec1, prec5] = validate(val_loader, model, criterion)

    if args.local_rank == 0:
        is_best = prec1 > best_prec1
        best_prec1 = max(prec1, best_prec1)
        save_checkpoint({
            'epoch': epoch + 1,
            'arch': args.arch,
            'state_dict': model.state_dict(),
            'best_prec1': best_prec1,
            'optimizer': optimizer.state_dict(),
        }, is_best)

DALI Data Unpacking Pattern

# Inside the training loop, DALI iterator produces different format than PyTorch:
for i, data in enumerate(train_loader):
    # DALI format:
    input = data[0]["data"]                        # [B, 3, 224, 224] float32 GPU
    target = data[0]["label"].squeeze(-1).long()   # [B] int64 GPU

    # Equivalent PyTorch DataLoader format would be:
    # input, target = data

Learning Rate Schedule

def adjust_learning_rate(optimizer, epoch, step, len_epoch):
    """Step decay with warmup: divide LR by 10 every 30 epochs, warmup for 5 epochs."""
    factor = epoch // 30
    if epoch >= 80:
        factor = factor + 1
    lr = args.lr * (0.1 ** factor)

    # Linear warmup for first 5 epochs
    if epoch < 5:
        lr = lr * float(1 + step + epoch * len_epoch) / (5. * len_epoch)

    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
