Workflow:NVIDIA TransformerEngine FSDP Distributed Training
| Knowledge Sources | |
|---|---|
| Domains | LLMs, FP8_Training, Distributed_Training |
| Last Updated | 2026-02-07 21:00 GMT |
Overview
End-to-end process for training Transformer Engine models using PyTorch's Fully Sharded Data Parallel (FSDP) with FP8 precision and optional deferred initialization for memory-efficient multi-GPU training.
Description
This workflow demonstrates how to combine NVIDIA Transformer Engine with PyTorch's FSDP strategy for distributed training across multiple GPUs. FSDP shards model parameters, gradients, and optimizer states across data-parallel workers to reduce per-GPU memory consumption. TE's FP8 execution runs the layers' matrix multiplications on FP8 Tensor Cores (Hopper, Ada, and Blackwell GPUs), raising matmul throughput relative to BF16. Together they make it practical to train large Transformer models on multi-GPU clusters. The workflow covers both standard initialization and deferred initialization (meta device) for models that exceed single-GPU memory.
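The memory benefit of full sharding can be estimated with simple arithmetic. The helper below is a hypothetical sketch, not part of TE or PyTorch. It assumes Adam with FP32 sharded parameters (4 bytes), FP32 gradients (4 bytes), and two FP32 optimizer moments (8 bytes), giving 16 bytes per parameter. It ignores activations, FP8 weight caches, and the temporarily all-gathered parameters of the layer currently being computed.

```python
def per_gpu_state_gb(num_params: float, world_size: int, bytes_per_param: int = 16) -> float:
    """Persistent model-state memory per GPU under full sharding, in GB (1e9 bytes).

    Assumption: bytes_per_param = 16 covers FP32 params + FP32 grads + two FP32
    Adam moments. FULL_SHARD splits all three evenly across world_size ranks.
    """
    if world_size < 1:
        raise ValueError("world_size must be >= 1")
    return num_params * bytes_per_param / world_size / 1e9
```

Under these assumptions, an 8-billion-parameter model needs about 128 GB of model state on a single GPU, but about 16 GB per GPU when fully sharded over 8 GPUs.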
Key outputs:
- A multi-GPU training setup with FSDP-sharded TE Transformer layers
- FP8 mixed precision training with activation checkpointing
- Support for models too large to fit on a single GPU via deferred initialization
Usage
Execute this workflow when you need to train a Transformer model across multiple GPUs using data parallelism with full sharding, and want to leverage FP8 precision for additional performance gains. This is appropriate for training runs where the model or its optimizer states are too large for a single GPU, or when you want to maximize throughput across a multi-GPU node.
Execution Steps
Step 1: Initialize Distributed Environment
Set up the distributed process group and CUDA RNG state tracker. Assign each process to its local GPU, initialize torch.distributed with the NCCL backend, and create a CudaRNGStatesTracker for deterministic dropout behavior during activation checkpointing. The tracker lets the recomputation phase regenerate exactly the dropout masks used in the original forward pass.
Key considerations:
- Set the device for each rank with torch.cuda.set_device(local_rank) before creating the tracker, because adding a state seeds and snapshots the RNG of the current CUDA device
- Use torch.distributed.init_process_group with the NCCL backend
- Create a CudaRNGStatesTracker and add a "model-parallel-rng" state with a seed
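The steps above can be sketched as follows. This assumes the job is launched with torchrun, which exports RANK, LOCAL_RANK, and WORLD_SIZE for each worker. The function names (rank_info_from_env, get_rng_tracker, init_distributed) and the seed value are hypothetical. The torch and TE imports are deferred into init_distributed so that the small helpers remain importable on machines without GPUs.

```python
import os

_RNG_TRACKER = None  # set by init_distributed(); read through get_rng_tracker()


def rank_info_from_env(env=None):
    """Return (rank, local_rank, world_size) from the variables torchrun exports."""
    env = os.environ if env is None else env
    return int(env["RANK"]), int(env["LOCAL_RANK"]), int(env["WORLD_SIZE"])


def get_rng_tracker():
    """Zero-argument callable for TE APIs that take `get_rng_state_tracker`."""
    return _RNG_TRACKER


def init_distributed(seed: int = 1234):
    """Bind to the local GPU, start NCCL, and register the RNG tracker state."""
    global _RNG_TRACKER
    import torch
    import torch.distributed as dist
    from transformer_engine.pytorch.distributed import CudaRNGStatesTracker

    rank, local_rank, world_size = rank_info_from_env()
    # Select the GPU first: tracker.add() seeds and snapshots the RNG state
    # of the *current* CUDA device.
    torch.cuda.set_device(local_rank)
    # Without an explicit init_method, the "env://" default reads
    # MASTER_ADDR/MASTER_PORT, which torchrun provides.
    dist.init_process_group(backend="nccl")

    tracker = CudaRNGStatesTracker()
    tracker.add("model-parallel-rng", seed)
    _RNG_TRACKER = tracker
    return rank, local_rank, world_size
```

Handing out the tracker through a zero-argument function matches how TE's checkpointing API receives it (a callable that returns the tracker) rather than the tracker object itself.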
Step 2: Build Transformer Model With TE Modules
Construct the model using TE modules such as TransformerLayer, MultiheadAttention, and LayerNormMLP. Stack multiple transformer layers to form the full model. Optionally initialize on the meta device for deferred initialization, which avoids materializing the full model on any single GPU before FSDP sharding.
Key considerations:
- Use te.TransformerLayer for complete transformer blocks
- For deferred init: use device="meta" during model construction
- Match attention, normalization, and activation settings to the target architecture (for example num_gqa_groups, normalization="RMSNorm", activation="swiglu")
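A minimal sketch of building the layer stack, with or without deferred initialization. ModelConfig and build_model are hypothetical names and the sizes are illustrative. te.TransformerLayer takes hidden_size, ffn_hidden_size, and num_attention_heads positionally, plus keyword arguments such as layer_number, get_rng_state_tracker, and device. Here params_dtype is left at TE's default, on the assumption that FSDP's mixed-precision policy handles the BF16 cast.

```python
from dataclasses import dataclass


@dataclass
class ModelConfig:
    """Hypothetical config container for this sketch (not a TE class)."""
    hidden_size: int = 1024
    ffn_hidden_size: int = 4096
    num_attention_heads: int = 16
    num_layers: int = 4

    def __post_init__(self):
        # Each attention head gets hidden_size // num_attention_heads channels.
        if self.hidden_size % self.num_attention_heads != 0:
            raise ValueError("hidden_size must be divisible by num_attention_heads")


def build_model(cfg: ModelConfig, deferred_init: bool, get_rng_state_tracker=None):
    """Stack te.TransformerLayer blocks; deferred_init=True allocates on 'meta'."""
    import torch
    import transformer_engine.pytorch as te

    # "meta" records shapes and dtypes only; FSDP materializes real storage later.
    device = "meta" if deferred_init else "cuda"
    layers = [
        te.TransformerLayer(
            cfg.hidden_size,
            cfg.ffn_hidden_size,
            cfg.num_attention_heads,
            layer_number=i + 1,  # position in the stack (1-based)
            get_rng_state_tracker=get_rng_state_tracker,
            device=device,
        )
        for i in range(cfg.num_layers)
    ]
    return torch.nn.Sequential(*layers)
```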
Step 3: Apply Activation Checkpointing
Run each transformer layer's forward pass through TE's activation checkpointing function, te.distributed.checkpoint, to trade compute for memory: intermediate activations are discarded during the forward pass and recomputed during the backward pass. Pass the RNG state tracker through its get_rng_state_tracker argument so that dropout behaves identically during recomputation.
What happens:
- Intermediate activations inside each checkpointed layer are not kept after its forward pass; only the layer inputs are saved
- During backward, activations are recomputed from saved inputs
- The RNG tracker ensures dropout produces identical masks during recomputation
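A minimal sketch of the checkpointed layer loop, meant to be called from a model's forward method (for example, a custom nn.Module that holds the layer stack). The function name is hypothetical. It assumes te.distributed.checkpoint accepts the layer, its positional inputs, a get_rng_state_tracker keyword, and forwards any remaining keyword arguments to the layer. The checkpoint_fn parameter exists only so the control flow can be exercised without a GPU.

```python
def run_layers_with_checkpointing(layers, hidden_states, get_rng_state_tracker=None,
                                  checkpoint_fn=None, **layer_kwargs):
    """Apply each layer in sequence, recomputing its activations in backward."""
    if checkpoint_fn is None:
        # Default to TE's activation checkpointing.
        from transformer_engine.pytorch.distributed import checkpoint as checkpoint_fn
    for layer in layers:
        # The layer input is saved; activations inside the layer are regenerated
        # during backward, with dropout reproduced via the RNG tracker.
        hidden_states = checkpoint_fn(
            layer,
            hidden_states,
            get_rng_state_tracker=get_rng_state_tracker,
            **layer_kwargs,  # forwarded to the layer, e.g. attention_mask
        )
    return hidden_states
```

In a real run, pass the getter from Step 1 (a zero-argument callable returning the tracker) as get_rng_state_tracker.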
Step 4: Configure and Apply FSDP Wrapping
Wrap the model with PyTorch's FullyShardedDataParallel, then pass the FSDP-wrapped root to te.distributed.prepare_te_modules_for_fsdp so that TE modules scatter and gather their cached FP8 weight copies together with FSDP's flat parameters. Set the sharding strategy (ShardingStrategy.FULL_SHARD, SHARD_GRAD_OP, or NO_SHARD), the mixed precision policy, and the auto-wrap policy, which controls the granularity of sharding.
Key considerations:
- Call prepare_te_modules_for_fsdp after wrapping with FSDP; it expects the FSDP-wrapped root module
- Set auto_wrap_policy to shard at transformer layer boundaries, for example transformer_auto_wrap_policy with transformer_layer_cls={te.TransformerLayer}
- Configure mixed_precision policy for BF16 parameter and gradient communication
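- For deferred init: FSDP materializes meta-device parameters during wrapping, by default by calling each module's reset_parameters() (which TE modules implement); pass param_init_fn only if you need custom initialization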
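A minimal sketch of the wrapping step using PyTorch's FSDP (FSDP1) API. The helper names are hypothetical. Choosing use_orig_params=True and BF16 for both param_dtype and reduce_dtype are assumptions that follow the considerations above. The key ordering point is that prepare_te_modules_for_fsdp runs on the returned FSDP root, after wrapping.

```python
import functools

# Names used in the text mapped to torch.distributed.fsdp.ShardingStrategy members.
_STRATEGIES = {
    "full_shard": "FULL_SHARD",        # shard params, grads, and optimizer state
    "shard_grad_op": "SHARD_GRAD_OP",  # params stay unsharded between fwd and bwd
    "no_shard": "NO_SHARD",            # replicate everything, DDP-like
}


def strategy_member_name(choice: str) -> str:
    """Normalize e.g. 'shard grad op' to the ShardingStrategy member name."""
    key = choice.strip().lower().replace(" ", "_")
    if key not in _STRATEGIES:
        raise ValueError(f"unknown sharding strategy: {choice!r}")
    return _STRATEGIES[key]


def wrap_with_fsdp(model, strategy: str = "full_shard"):
    import torch
    from torch.distributed.fsdp import (
        FullyShardedDataParallel as FSDP,
        MixedPrecision,
        ShardingStrategy,
    )
    from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
    import transformer_engine.pytorch as te
    from transformer_engine.pytorch.distributed import prepare_te_modules_for_fsdp

    # One FSDP unit per TE transformer block.
    wrap_policy = functools.partial(
        transformer_auto_wrap_policy, transformer_layer_cls={te.TransformerLayer}
    )
    fsdp_model = FSDP(
        model,
        sharding_strategy=getattr(ShardingStrategy, strategy_member_name(strategy)),
        auto_wrap_policy=wrap_policy,
        # BF16 all-gathers and gradient reductions; reduce_dtype=torch.float32
        # trades bandwidth for more accurate gradient sums.
        mixed_precision=MixedPrecision(param_dtype=torch.bfloat16,
                                       reduce_dtype=torch.bfloat16),
        # Meta-device modules are materialized on this GPU during wrapping.
        device_id=torch.cuda.current_device(),
        use_orig_params=True,
    )
    # Must receive the FSDP-wrapped root, so it runs after wrapping.
    prepare_te_modules_for_fsdp(fsdp_model)
    return fsdp_model
```

Build the optimizer from fsdp_model.parameters() after this call, so that it holds the sharded parameters rather than the pre-wrap ones.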
Step 5: Execute FP8 Training Loop
Run the training loop with the te.autocast context manager for FP8 precision. Each iteration performs the forward pass inside the autocast context, computes the loss, runs the backward pass outside autocast, and updates parameters with the optimizer. With full sharding, FSDP all-gathers each unit's parameters before its forward and backward computation and reduce-scatters the gradients afterward, so no manual communication is needed.
Key considerations:
- Wrap only the forward pass with te.autocast
- Configure the FP8 recipe (DelayedScaling or Float8CurrentScaling)
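- The optimizer step runs on each rank's local parameter shards and sharded gradients; FSDP gathers full parameters only for forward and backward computation
- Gradient clipping and accumulation need FSDP-aware calls: use the FSDP root's clip_grad_norm_() so the norm is computed over all shards, and run non-final micro-batches inside model.no_sync() to skip gradient reduction (at the cost of holding unsharded gradients until the final micro-batch)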
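A minimal sketch of the training loop with gradient accumulation. It uses te.fp8_autocast with a DelayedScaling recipe from transformer_engine.common.recipe, the context-manager name TE releases have long provided. If your TE version exposes the te.autocast spelling used in the text, substitute it after checking its signature. The function names, the MSE loss, and the (inputs, targets) batch format are hypothetical. The optimizer is assumed to have been built on the FSDP-wrapped model's parameters.

```python
import contextlib


def is_sync_step(micro_step: int, accum_steps: int) -> bool:
    """True on the last micro-batch of each gradient-accumulation window."""
    return (micro_step + 1) % accum_steps == 0


def train(fsdp_model, optimizer, batches, accum_steps=1, max_grad_norm=1.0, fp8_group=None):
    import torch.nn.functional as F
    import transformer_engine.pytorch as te
    from transformer_engine.common.recipe import DelayedScaling, Format

    # HYBRID: E4M3 for forward tensors, E5M2 for backward gradients.
    fp8_recipe = DelayedScaling(fp8_format=Format.HYBRID, amax_history_len=16,
                                amax_compute_algo="max")
    fsdp_model.train()
    for micro_step, (inputs, targets) in enumerate(batches):
        sync = is_sync_step(micro_step, accum_steps)
        # no_sync() skips the gradient reduce-scatter; gradients accumulate
        # locally (unsharded) until the final micro-batch of the window.
        grad_ctx = contextlib.nullcontext() if sync else fsdp_model.no_sync()
        with grad_ctx:
            # Only the forward pass runs under FP8 autocast. fp8_group is the
            # process group for amax reduction (None keeps TE's default).
            with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, fp8_group=fp8_group):
                outputs = fsdp_model(inputs)
            loss = F.mse_loss(outputs.float(), targets.float()) / accum_steps
            loss.backward()
        if sync:
            # FSDP's clip_grad_norm_ computes the total norm across all shards.
            fsdp_model.clip_grad_norm_(max_grad_norm)
            optimizer.step()  # updates this rank's local shards only
            optimizer.zero_grad(set_to_none=True)
```

Dividing the loss by accum_steps keeps the accumulated gradient equal to the average over the window, so the learning rate does not need to change with the accumulation factor.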