

Implementation:NVIDIA DALI TF Runner CTL

From Leeroopedia


Knowledge Sources
Domains: Image_Classification, TensorFlow
Last Updated: 2026-02-08 16:00 GMT

Overview

Implements a custom training loop (CTL) runner for distributed ResNet training with TensorFlow, supporting mixed precision (FP16), Horovod multi-GPU scaling, DALI data loading, and XLA JIT compilation.

Description

The `train_ctl()` function is the main entry point for training ResNet models using a TensorFlow custom training loop (as opposed to Keras `model.fit()`). It orchestrates the complete training workflow: GPU configuration with Horovod for multi-GPU distributed training, mixed precision policy setup (FP16 with configurable loss scaling), learning rate scheduling with warmup and piecewise constant decay, dataset creation via the `image_processing` module (supporting both DALI and native TF pipelines), model construction, and the training/validation loop.
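The learning-rate schedule (linear warmup followed by piecewise constant decay) can be sketched in plain Python. This is a minimal sketch, not the runner's own helper: the function name `piecewise_lr`, the base rate of 0.1 per 256 images, the 5 warmup epochs, and the decay boundaries and multipliers are hypothetical values chosen to resemble common ImageNet ResNet recipes.

```python
def piecewise_lr(step, global_batch, epoch_size,
                 base_lr=0.1, warmup_epochs=5,
                 boundaries=(30, 60, 80),
                 multipliers=(1.0, 0.1, 0.01, 0.001)):
    """Learning rate at a given global step (all schedule values hypothetical).

    The peak rate follows the linear scaling rule: base_lr per 256 images.
    """
    if len(multipliers) != len(boundaries) + 1:
        raise ValueError("need exactly one more multiplier than boundaries")
    steps_per_epoch = epoch_size // global_batch
    epoch = step / steps_per_epoch            # fractional epoch
    peak_lr = base_lr * global_batch / 256.0
    if epoch < warmup_epochs:
        # Linear warmup from 0 up to peak_lr over warmup_epochs.
        return peak_lr * epoch / warmup_epochs
    # Piecewise constant decay: use the multiplier of the current interval.
    for boundary, mult in zip(boundaries, multipliers):
        if epoch < boundary:
            return peak_lr * mult
    return peak_lr * multipliers[-1]
```

With 8 workers and a per-GPU batch of 256 (global batch 2048), the peak rate under these assumptions is 0.8, dropping to 0.08 after epoch 30. Scaling the peak with the global batch keeps the per-sample update size roughly constant as more GPUs are added.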

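The training step is decorated with `@tf.function` for graph compilation. It performs the forward pass, computes the loss (sparse categorical cross-entropy plus the model's regularization losses), multiplies the loss by the loss scale when training in FP16, computes gradients through `hvd.DistributedGradientTape` (which all-reduces them across workers), divides the gradients by the same loss scale, and applies them with the optimizer. On the first batch, Horovod broadcasts the model variables and optimizer state from rank 0 so that all workers start from identical values. Horovod's own TF2 examples perform this broadcast after the first gradient step, because the optimizer's variables are only created when gradients are first applied.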
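Why manual loss scaling matters can be shown without TensorFlow. The sketch below uses the half-precision format code `'e'` of Python's `struct` module to emulate FP16 storage of a single gradient value. The helper names and the gradient value 1e-8 are illustrative; the sketch assumes static loss scaling with a factor like `params['loss_scale']` and ignores everything else about FP16 arithmetic.

```python
import struct


def to_fp16(x):
    """Round a Python float to IEEE 754 half precision and back."""
    return struct.unpack('<e', struct.pack('<e', x))[0]


def scaled_fp16_grad(grad, loss_scale):
    """Emulate static loss scaling for one gradient value.

    The gradient of (loss * S) is S times the gradient of loss, so the FP16
    value carried through backprop is grad * S. It is then divided by S in
    higher precision, mirroring the unscaling step.
    """
    fp16_value = to_fp16(grad * loss_scale)  # what an FP16 tensor can hold
    return fp16_value / loss_scale           # unscale outside FP16
```

With a scale of 1.0 the gradient 1e-8 rounds to zero in FP16, since the smallest positive FP16 subnormal is 2^-24 (about 6e-8). With a scale of 128 the scaled value survives storage and unscales to roughly 9.8e-9, within a few percent of the true gradient.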

The validation step computes the loss and top-1/top-5 accuracy. The training loop supports epoch-based or batch-based iteration, periodic metric logging, optional TensorBoard summary writing (rank 0 only), checkpoint saving and restoration, and model export in HDF5 format. Each epoch consists of a training phase followed by a validation phase, and metrics are reset between epochs.
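The top-1 and top-5 metrics reported by the validation step are both instances of top-k accuracy. A framework-free sketch of the metric follows; the function name and the list-of-lists input layout are illustrative, and this is not the runner's code.

```python
def top_k_accuracy(logits, labels, k):
    """Fraction of samples whose true label is among the k highest scores.

    logits: list of per-class score lists; labels: list of int class ids.
    """
    hits = 0
    for scores, label in zip(logits, labels):
        # Indices of the k largest scores; the stable sort breaks ties
        # in favor of the lower class index.
        top_k = sorted(range(len(scores)), key=lambda i: scores[i],
                       reverse=True)[:k]
        hits += label in top_k
    return hits / len(labels)
```

For example, with scores `[[0.1, 0.7, 0.2], [0.5, 0.3, 0.2]]` and labels `[2, 1]`, top-1 accuracy is 0.0 while top-2 accuracy is 1.0, which is why top-5 is always at least as high as top-1.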

The runner supports configurable parameters including image dimensions, batch size, data directory, precision mode, DALI mode, XLA compilation, and various logging/export options.

Usage

Use this function as the training driver for ResNet-N models in the DALI examples. It is called from the main training script with a model constructor function and a parameters dictionary.

Code Reference

Source Location

Signature

def train_ctl(model_func, params):
    """Custom training loop for distributed ResNet training.

    Args:
        model_func: Callable that returns a Keras model, accepting
            num_classes and batch_size keyword arguments.
        params: Dictionary with training configuration parameters.
    """
    ...

Import

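from nvutils.runner_ctl import train_ctl
from resnet_model import resnet50  # model constructor, as in the usage example

# training_params: a dict with the keys listed under I/O Contract
train_ctl(model_func=resnet50, params=training_params)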

I/O Contract

Inputs

Name | Type | Required | Description
model_func | callable | Yes | Function returning a Keras model; called with num_classes and batch_size keyword arguments
params['image_width'] | int | Yes | Input image width in pixels (e.g., 224)
params['image_height'] | int | Yes | Input image height in pixels (e.g., 224)
params['batch_size'] | int | Yes | Per-GPU batch size (the global batch is this value times the number of workers)
params['precision'] | str | Yes | 'fp16' or 'fp32'
params['data_dir'] | str or None | No | Directory with TFRecord data (None for synthetic data)
params['data_idx_dir'] | str or None | No | Directory with DALI index files
params['dali_mode'] | str or None | No | DALI mode: 'CPU', 'GPU', or None (native TF pipeline)
params['num_iter'] | int | Yes | Number of epochs or batches to train, depending on iter_unit
params['iter_unit'] | str | Yes | 'epoch' or 'batch'
params['momentum'] | float | Yes | SGD momentum value
params['loss_scale'] | float | Yes | Loss scale factor for FP16 training
params['use_xla'] | bool | Yes | Whether to enable XLA JIT compilation
params['log_dir'] | str or None | No | Directory for checkpoints
params['export_dir'] | str or None | No | Directory for model export
params['tensorboard_dir'] | str or None | No | Directory for TensorBoard logs
params['display_every'] | int | Yes | Number of steps between metric log lines
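How `num_iter`, `iter_unit`, and the per-GPU `batch_size` combine into a step count can be sketched as follows. `resolve_iterations` is a hypothetical helper, not part of `nvutils`; treating batch mode as a single epoch capped at one full pass over the data is an assumption of this sketch.

```python
def resolve_iterations(params, num_train_samples, num_workers):
    """Return (num_epochs, steps_per_epoch) from num_iter and iter_unit.

    Assumes params['batch_size'] is per GPU, so one global step consumes
    batch_size * num_workers samples (hypothetical helper).
    """
    global_batch = params['batch_size'] * num_workers
    full_epoch_steps = num_train_samples // global_batch
    unit = params['iter_unit'].lower()
    if unit == 'epoch':
        # num_iter counts full passes over the training set.
        return params['num_iter'], full_epoch_steps
    if unit == 'batch':
        # num_iter counts steps; run them as one, possibly partial, epoch.
        return 1, min(params['num_iter'], full_epoch_steps)
    raise ValueError(
        f"iter_unit must be 'epoch' or 'batch', got {params['iter_unit']!r}")
```

For the 1,281,167-image ImageNet training set on 8 GPUs with a per-GPU batch of 256, this gives 625 steps per epoch; the remainder that does not fill a global batch is dropped.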

Outputs

Name | Type | Description
Trained model | tf.keras.Model | ResNet model with updated weights (checkpoints saved to log_dir and the exported model to export_dir, if specified)
Training logs | console/TensorBoard | Loss, top-1/top-5 accuracy, learning rate, and throughput metrics
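The throughput metric is, in essence, images processed divided by wall-clock time. A sketch of the arithmetic follows, assuming the figure is aggregated over all workers from the per-GPU batch size; the function name is illustrative, and the runner's exact timing window is an assumption left open here.

```python
def images_per_second(per_gpu_batch, num_workers, num_steps, elapsed_seconds):
    """Aggregate training throughput across all workers (illustrative)."""
    if elapsed_seconds <= 0:
        raise ValueError("elapsed_seconds must be positive")
    # Each step processes one per-GPU batch on every worker.
    return per_gpu_batch * num_workers * num_steps / elapsed_seconds
```

For example, 100 steps of 256 images on each of 8 GPUs in 20 seconds give 10,240 images/s. Excluding the first few steps from the timed interval is a common choice, since `@tf.function` tracing and XLA compilation make them slower.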

Usage Examples

Train ResNet-50 with DALI and FP16

from nvutils.runner_ctl import train_ctl
from resnet_model import resnet50

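params = {
    'image_width': 224,
    'image_height': 224,
    'image_format': 'channels_last',
    'distort_color': False,
    'momentum': 0.9,                            # SGD momentum
    'loss_scale': 128.0,                        # used because precision is 'fp16'
    'data_dir': '/data/imagenet/tfrecords',     # TFRecord data
    'data_idx_dir': '/data/imagenet/dali_idx',  # DALI index files
    'batch_size': 256,                          # per GPU, not global
    'num_iter': 90,                             # 90 epochs, since iter_unit is 'epoch'
    'iter_unit': 'epoch',
    'log_dir': '/tmp/resnet_ckpt',              # checkpoints
    'export_dir': '/tmp/resnet_export',         # exported model
    'tensorboard_dir': '/tmp/resnet_tb',        # written by rank 0 only
    'display_every': 100,                       # log metrics every 100 steps
    'precision': 'fp16',
    'dali_mode': 'GPU',                         # DALI pipeline in GPU mode
    'use_xla': True,
}

# Under Horovod, one copy of this script runs per GPU (e.g. launched with horovodrun).
train_ctl(model_func=resnet50, params=params)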
