Implementation:NVIDIA DALI TF Runner CTL
| Knowledge Sources | |
|---|---|
| Domains | Image_Classification, TensorFlow |
| Last Updated | 2026-02-08 16:00 GMT |
Overview
Implements a custom training loop (CTL) runner for distributed ResNet training with TensorFlow, supporting mixed precision (FP16), Horovod multi-GPU scaling, DALI data loading, and XLA JIT compilation.
Description
The `train_ctl()` function is the main entry point for training ResNet models with a TensorFlow custom training loop (as opposed to Keras `model.fit()`). It orchestrates the complete training workflow:
- GPU configuration with Horovod for multi-GPU distributed training
- mixed precision policy setup (FP16 with configurable loss scaling)
- learning rate scheduling with warmup and piecewise constant decay
- dataset creation via the `image_processing` module, which supports both DALI and native TF pipelines
- model construction
- the training/validation loop
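The warmup-plus-piecewise-decay schedule can be sketched in plain Python. This is not the runner's code. The base rate of 0.1 per 256 images, the 5 warmup epochs, the decay boundaries at epochs 30, 60 and 80 with multipliers 1.0, 0.1, 0.01 and 0.001, and the linear scaling of the peak rate with the global batch size are all illustrative assumptions.

```python
def lr_at_epoch(epoch, base_lr=0.1, per_gpu_batch=256, num_gpus=1,
                warmup_epochs=5, boundaries=(30, 60, 80),
                multipliers=(1.0, 0.1, 0.01, 0.001)):
    """Learning rate at a (possibly fractional) epoch.

    All defaults are illustrative assumptions, not values from runner_ctl.py.
    """
    if len(multipliers) != len(boundaries) + 1:
        raise ValueError("expected exactly one more multiplier than boundaries")
    # Assumed linear scaling rule: base_lr is defined for a global batch of 256.
    peak_lr = base_lr * per_gpu_batch * num_gpus / 256
    if epoch < warmup_epochs:
        # Linear warmup from 0 up to peak_lr.
        return peak_lr * epoch / warmup_epochs
    # Piecewise constant decay: the first boundary not yet reached selects the multiplier.
    for boundary, multiplier in zip(boundaries, multipliers):
        if epoch < boundary:
            return peak_lr * multiplier
    return peak_lr * multipliers[-1]


if __name__ == "__main__":
    for e in (0, 2.5, 10, 45, 70, 85):
        print(e, lr_at_epoch(e, num_gpus=8))
```

Under these assumptions, eight GPUs with a per-GPU batch of 256 warm up to a peak rate of 0.8, which then drops by 10x at each boundary.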
The training step is decorated with `@tf.function` for graph compilation. It runs the forward pass and computes the loss (sparse categorical cross-entropy plus the model's regularization losses). When training in FP16, it multiplies the loss by `loss_scale`. It then computes gradients through `hvd.DistributedGradientTape`, which all-reduces them across workers, divides the gradients by the same loss scale, and applies them with the optimizer. On the first batch, Horovod broadcasts the model variables and optimizer state from rank 0 to synchronize all workers.
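The step ordering (scale the loss, differentiate, average across workers, unscale, apply, broadcast on the first batch) can be illustrated with a stdlib-only sketch on a one-weight linear model. Several assumptions apply. `allreduce_mean` and the list copy in `train_step` are hypothetical stand-ins for Horovod's all-reduce and `hvd.broadcast_variables`. The optimizer is plain SGD rather than momentum SGD. Rounding through `struct`'s half-precision `'e'` format only approximates gradients held in FP16.

```python
import struct


def to_fp16(value):
    """Round a Python float to the nearest IEEE 754 half-precision value.

    struct raises OverflowError if the value exceeds the FP16 range.
    """
    return struct.unpack('<e', struct.pack('<e', value))[0]


def local_grad(w, batch, loss_scale):
    """Gradient of loss_scale * mean((w*x - y)**2) w.r.t. w, rounded to FP16."""
    g = sum(2.0 * (w * x - y) * x for x, y in batch) / len(batch)
    return to_fp16(loss_scale * g)


def allreduce_mean(values):
    """Hypothetical stand-in for a Horovod all-reduce that averages across workers."""
    return sum(values) / len(values)


def train_step(weights, batches, lr, loss_scale, first_batch):
    """One synchronous SGD step; weights[i] and batches[i] belong to worker i."""
    scaled = [local_grad(w, b, loss_scale) for w, b in zip(weights, batches)]
    grad = allreduce_mean(scaled) / loss_scale  # unscale in full precision
    weights = [w - lr * grad for w in weights]
    if first_batch:
        # Copy rank 0's weights to every worker so all continue from identical state.
        weights = [weights[0]] * len(weights)
    return weights
```

With `loss_scale=1.0`, a gradient of about 2e-8 rounds to zero in FP16 and the weight never moves. With `loss_scale=128.0`, the scaled value survives the rounding and the unscaled update is non-zero.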
The validation step computes the loss and top-1/top-5 accuracy. The training loop supports epoch-based or batch-based iteration, metric logging every `display_every` steps, optional TensorBoard summaries (written by rank 0 only), checkpoint saving and restoration, and model export in HDF5 format. Each epoch includes a training phase and a validation phase, and the metrics are reset between epochs.
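Keras provides `tf.keras.metrics.SparseTopKCategoricalAccuracy` for this kind of metric. The pure-Python class below only illustrates the accumulate-then-reset pattern implied by resetting metrics between epochs. The class name `TopKAccuracy` and its tie-breaking rule are assumptions of this sketch.

```python
class TopKAccuracy:
    """Running top-k accuracy over integer labels; call reset() between epochs."""

    def __init__(self, k):
        self.k = k
        self.reset()

    def reset(self):
        self.hits = 0
        self.count = 0

    def update(self, scores_batch, labels):
        for scores, label in zip(scores_batch, labels):
            # Indices of the k highest scores; sorted() stays stable with
            # reverse=True, so tied scores keep the lower index first.
            top_k = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:self.k]
            self.hits += int(label in top_k)
            self.count += 1

    def result(self):
        return self.hits / self.count if self.count else 0.0
```

A training loop would keep one instance per k (for example `TopKAccuracy(1)` and `TopKAccuracy(5)`), call `update` per batch, log `result()`, and call `reset()` at the start of each epoch.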
The `params` dictionary configures image dimensions, batch size, data directory, precision mode, DALI mode, XLA compilation, and logging/export options (checkpoint, export, and TensorBoard directories, plus the metric display interval).
Usage
Use this function as the training driver for ResNet-N models in the DALI examples. It is called from the main training script with a model constructor function and a parameters dictionary.
Code Reference
Source Location
- Repository: NVIDIA_DALI
- File: docs/examples/use_cases/tensorflow/resnet-n/nvutils/runner_ctl.py
- Lines: 1-351
Signature
```python
def train_ctl(model_func, params):
    """Custom training loop for distributed ResNet training.

    Args:
        model_func: Callable that returns a Keras model, accepting
            num_classes and batch_size keyword arguments.
        params: Dictionary with training configuration parameters.
    """
    ...
```
Import
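```python
from nvutils.runner_ctl import train_ctl

# resnet50 (model constructor) and training_params (dict) come from the calling script.
train_ctl(model_func=resnet50, params=training_params)
```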
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| model_func | callable | Yes | Function returning a Keras model; called with keyword arguments `num_classes` and `batch_size` |
| params['image_width'] | int | Yes | Input image width (e.g., 224) |
| params['image_height'] | int | Yes | Input image height (e.g., 224) |
| params['batch_size'] | int | Yes | Per-GPU batch size |
| params['precision'] | str | Yes | 'fp16' or 'fp32' |
| params['data_dir'] | str | No | Directory with TFRecord data (None for synthetic data) |
| params['data_idx_dir'] | str | No | Directory with DALI index files |
| params['dali_mode'] | str | No | DALI mode: 'CPU', 'GPU', or None |
| params['num_iter'] | int | Yes | Number of epochs or batches to train, as selected by `iter_unit` |
| params['iter_unit'] | str | Yes | 'epoch' or 'batch' |
| params['momentum'] | float | Yes | SGD momentum value |
| params['loss_scale'] | float | Yes | Loss scale factor applied when `precision` is 'fp16' |
| params['use_xla'] | bool | Yes | Whether to enable XLA JIT compilation |
| params['log_dir'] | str | No | Directory for checkpoints |
| params['export_dir'] | str | No | Directory for model export |
| params['tensorboard_dir'] | str | No | Directory for TensorBoard logs |
| params['display_every'] | int | Yes | Steps between metric display |
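
Because `train_ctl()` takes a plain dictionary, a small pre-flight check can catch typos before any GPU work starts. `validate_params` is a hypothetical helper, not part of `nvutils`. Its key lists follow the Required column of the table above. Filling absent optional keys with `None` is an assumption, made so that direct `params[...]` lookups would not raise `KeyError`.

```python
REQUIRED_KEYS = ('image_width', 'image_height', 'batch_size', 'precision',
                 'num_iter', 'iter_unit', 'momentum', 'loss_scale',
                 'use_xla', 'display_every')
OPTIONAL_KEYS = ('data_dir', 'data_idx_dir', 'dali_mode', 'log_dir',
                 'export_dir', 'tensorboard_dir')


def validate_params(params):
    """Hypothetical check for a train_ctl params dict.

    Returns a copy in which missing optional keys are set to None.
    """
    missing = [k for k in REQUIRED_KEYS if k not in params]
    if missing:
        raise KeyError(f"missing required params: {missing}")
    if params['precision'] not in ('fp16', 'fp32'):
        raise ValueError(f"precision must be 'fp16' or 'fp32', got {params['precision']!r}")
    if params['iter_unit'] not in ('epoch', 'batch'):
        raise ValueError(f"iter_unit must be 'epoch' or 'batch', got {params['iter_unit']!r}")
    if params.get('dali_mode') not in ('CPU', 'GPU', None):
        raise ValueError(f"dali_mode must be 'CPU', 'GPU' or None, got {params['dali_mode']!r}")
    if params['batch_size'] <= 0 or params['num_iter'] <= 0:
        raise ValueError("batch_size and num_iter must be positive")
    checked = dict(params)  # leave the caller's dict untouched
    for key in OPTIONAL_KEYS:
        checked.setdefault(key, None)
    return checked
```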
Outputs
| Name | Type | Description |
|---|---|---|
| Trained model | tf.keras.Model | ResNet model with updated weights (saved to log_dir/export_dir if specified) |
| Training logs | console/TensorBoard | Loss, top-1/top-5 accuracy, learning rate, throughput metrics |
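
Throughput is typically reported in images per second. This sketch assumes a global count (per-GPU batch times number of workers) averaged over each window of `display_every` steps. It is not stated whether the runner reports global or per-GPU throughput, or how it times steps, so treat `throughput_log` as a hypothetical illustration.

```python
def throughput_log(step_times, per_gpu_batch, num_gpus, display_every):
    """Yield (step, images_per_second) for each complete window of display_every steps.

    Assumes every worker processes per_gpu_batch images per step, so the
    global batch is per_gpu_batch * num_gpus. A trailing partial window is skipped.
    """
    global_batch = per_gpu_batch * num_gpus
    for end in range(display_every, len(step_times) + 1, display_every):
        window = step_times[end - display_every:end]
        yield end, global_batch * display_every / sum(window)
```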
Usage Examples
Train ResNet-50 with DALI and FP16
```python
from nvutils.runner_ctl import train_ctl
from resnet_model import resnet50

params = {
    'image_width': 224,
    'image_height': 224,
    'image_format': 'channels_last',
    'distort_color': False,
    'momentum': 0.9,
    'loss_scale': 128.0,
    'data_dir': '/data/imagenet/tfrecords',
    'data_idx_dir': '/data/imagenet/dali_idx',
    'batch_size': 256,
    'num_iter': 90,
    'iter_unit': 'epoch',
    'log_dir': '/tmp/resnet_ckpt',
    'export_dir': '/tmp/resnet_export',
    'tensorboard_dir': '/tmp/resnet_tb',
    'display_every': 100,
    'precision': 'fp16',
    'dali_mode': 'GPU',
    'use_xla': True,
}

train_ctl(model_func=resnet50, params=params)
```
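When `iter_unit` is `'epoch'`, the number of optimizer steps depends on the dataset size and the global batch size. This sketch assumes the 1,281,167-image ImageNet (ILSVRC-2012) training split and drops a partial final batch. The helper `total_steps` and that rounding rule are assumptions, not the runner's logic.

```python
IMAGENET_TRAIN_IMAGES = 1281167  # images in the ILSVRC-2012 training split


def total_steps(num_iter, iter_unit, per_gpu_batch, num_gpus,
                num_images=IMAGENET_TRAIN_IMAGES):
    """Number of optimizer steps implied by num_iter and iter_unit.

    Dropping a partial final batch (floor division) is an assumption.
    """
    if iter_unit == 'batch':
        return num_iter
    if iter_unit == 'epoch':
        steps_per_epoch = num_images // (per_gpu_batch * num_gpus)
        return num_iter * steps_per_epoch
    raise ValueError(f"iter_unit must be 'epoch' or 'batch', got {iter_unit!r}")
```

On 8 GPUs (an assumed setup), the example configuration would give 625 steps per epoch and 56,250 steps over 90 epochs under these assumptions.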