
Workflow:NVIDIA TransformerEngine FSDP Distributed Training

From Leeroopedia


Knowledge Sources
Domains LLMs, FP8_Training, Distributed_Training
Last Updated 2026-02-07 21:00 GMT

Overview

End-to-end process for training Transformer Engine models using PyTorch's Fully Sharded Data Parallel (FSDP) with FP8 precision and optional deferred initialization for memory-efficient multi-GPU training.

Description

This workflow demonstrates how to combine NVIDIA Transformer Engine (TE) with PyTorch's FSDP strategy for distributed training across multiple GPUs. FSDP shards model parameters, gradients, and optimizer states across data-parallel workers, so each GPU stores only a fraction of the model state. TE's FP8 precision, available on Hopper, Ada, and Blackwell GPUs, runs the matrix multiplications inside Transformer layers in FP8 for higher throughput than BF16. Together, these make it possible to train large Transformer models on multi-GPU clusters with lower per-GPU memory and higher compute throughput. The workflow covers both standard initialization and deferred initialization (on the meta device) for models whose parameters do not fit in a single GPU's memory.
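
As a rough illustration of what sharding buys, the sketch below estimates per-GPU model-state memory (parameters, gradients, and Adam optimizer states) with and without full sharding. It assumes 16 bytes per parameter (an FP32 weight, an FP32 gradient, and two FP32 Adam moment buffers) and ignores activations, the transient unsharded copies FSDP all-gathers for compute, and TE's FP8 buffers. The function name and the 1B-parameter example are hypothetical.

```python
# Assumed per-parameter storage in bytes: FP32 weight + FP32 gradient + two FP32 Adam moments.
BYTES_PER_PARAM = 4 + 4 + 8


def model_state_bytes_per_gpu(num_params: int, world_size: int, full_shard: bool) -> float:
    """Estimate steady-state model-state memory on one GPU.

    Without sharding every rank keeps a full replica; with full sharding each
    rank keeps roughly 1/world_size of parameters, gradients, and optimizer states.
    """
    total = num_params * BYTES_PER_PARAM
    return total / world_size if full_shard else float(total)


if __name__ == "__main__":
    params = 1_000_000_000  # hypothetical 1B-parameter model
    for shard in (False, True):
        gb = model_state_bytes_per_gpu(params, world_size=8, full_shard=shard) / 1e9
        print(f"full_shard={shard}: {gb:.1f} GB per GPU")
```

Under these assumptions, a 1B-parameter model on 8 GPUs needs 16 GB of model state per GPU without sharding versus 2 GB with full sharding, before activations are counted.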

Key outputs:

  • A multi-GPU training setup with FSDP-sharded TE Transformer layers
  • FP8 mixed precision training with activation checkpointing
  • Support for models too large to fit on a single GPU via deferred initialization

Usage

Execute this workflow when you need to train a Transformer model across multiple GPUs using data parallelism with full sharding, and want to leverage FP8 precision for additional performance gains. This is appropriate for training runs where the model or its optimizer states are too large for a single GPU, or when you want to maximize throughput across a multi-GPU node.

Execution Steps

Step 1: Initialize Distributed Environment

Set up the distributed process group and the CUDA RNG state tracker. Initialize torch.distributed with the NCCL backend, assign each process to its local GPU, and create a CudaRNGStatesTracker so that dropout behaves deterministically under activation checkpointing. This ensures that dropout masks are regenerated exactly during the recomputation phase.

Key considerations:

  • Use torch.distributed.init_process_group with the NCCL backend
  • Create CudaRNGStatesTracker and add a "model-parallel-rng" state
  • Set the device for each rank using torch.cuda.set_device(local_rank)

Step 2: Build Transformer Model With TE Modules

Construct the model using TE modules such as TransformerLayer, MultiheadAttention, and LayerNormMLP. Stack multiple transformer layers to form the full model. Optionally initialize on the meta device for deferred initialization, which avoids materializing the full model on any single GPU before FSDP sharding.

Key considerations:

  • Use te.TransformerLayer for complete transformer blocks
  • For deferred init: use device="meta" during model construction
  • Configure attention, normalization, and activation parameters per model architecture

Step 3: Apply Activation Checkpointing

Wrap transformer layers with TE's activation checkpointing function to trade compute for memory. During the forward pass, intermediate activations are discarded and recomputed during the backward pass. Use te.distributed.checkpoint with the RNG state tracker to ensure correct dropout behavior during recomputation.

What happens:

  • Intermediate activations inside each checkpointed layer are not stored during the forward pass; only the layer inputs are saved
  • During backward, activations are recomputed from saved inputs
  • The RNG tracker ensures dropout produces identical masks during recomputation
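
To make the role of the RNG tracker concrete, here is a toy, framework-free model of the mechanism: dropout driven by a seeded Python RNG, where the state saved before the forward pass is restored before recomputation. It illustrates the idea only and is not TE's implementation, which tracks CUDA RNG states; all names are hypothetical.

```python
import random


def toy_dropout(values, p, rng):
    """Zero each element with probability p and scale survivors by 1 / (1 - p)."""
    return [0.0 if rng.random() < p else v / (1.0 - p) for v in values]


def forward_and_recompute(values, p=0.5, seed=0):
    """Return (forward, naive_recompute, replayed_recompute) dropout outputs."""
    rng = random.Random(seed)
    saved_state = rng.getstate()             # stashed alongside the checkpointed inputs
    forward = toy_dropout(values, p, rng)    # original forward pass
    naive = toy_dropout(values, p, rng)      # recompute with the RNG already advanced
    rng.setstate(saved_state)                # restore, as the RNG tracker does
    replayed = toy_dropout(values, p, rng)   # recompute from the restored state
    return forward, naive, replayed


if __name__ == "__main__":
    fwd, naive, replay = forward_and_recompute([1.0] * 8)
    print("replay matches forward:", replay == fwd)
    print("naive recompute matches forward:", naive == fwd)
```

The replayed recomputation always reproduces the forward-pass mask, whereas the naive one draws fresh random numbers and will generally produce a different mask, which would make the recomputed activations disagree with the ones used in the forward pass.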

Step 4: Configure and Apply FSDP Wrapping

Wrap the model with PyTorch's FullyShardedDataParallel, then call te.distributed.prepare_te_modules_for_fsdp on the FSDP-wrapped root module so that TE modules can shard internal buffers, such as FP8 weight copies, that FSDP does not manage on its own. Set the sharding strategy (FULL_SHARD, SHARD_GRAD_OP, or NO_SHARD), the mixed precision policy, and the auto-wrap policy that controls the granularity of sharding.

Key considerations:

  • Call prepare_te_modules_for_fsdp after wrapping with FSDP, passing it the FSDP-wrapped root module
  • Set appropriate auto_wrap_policy to shard at transformer layer boundaries
  • Configure mixed_precision policy for BF16 parameter and gradient communication
  • For deferred init: use FSDP's param_init_fn to materialize parameters during wrapping

Step 5: Execute FP8 Training Loop

Run the training loop with the te.autocast context manager for FP8 precision. Each iteration performs the forward pass inside the autocast context, computes the loss, runs the backward pass outside autocast, and updates parameters with the optimizer. FSDP handles communication automatically: it all-gathers each wrapped unit's parameters before that unit's forward and backward computation and reduce-scatters gradients after the backward pass.

Key considerations:

  • Wrap only the forward pass with te.autocast
  • Configure the FP8 recipe (DelayedScaling or Float8CurrentScaling)
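  • The optimizer step updates only the local parameter shards on each rank; FSDP all-gathers full parameters just for forward and backward computation
  • Clip gradients with the FSDP module's clip_grad_norm_ method rather than torch.nn.utils.clip_grad_norm_, because gradients are sharded and the norm must be computed across all ranks; for gradient accumulation, FSDP's no_sync() context skips gradient reduction on intermediate micro-batches at the cost of extra memory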
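
Because FSDP leaves each rank holding only its shard of the gradients, a per-rank norm underestimates the true gradient norm. The toy calculation below uses plain Python lists as stand-ins for per-rank shards (an illustration, not FSDP code; the function names are hypothetical) to show how the global L2 norm is recovered by summing squared per-shard norms, which is the reduction a sharded clip has to perform across ranks.

```python
import math


def global_grad_norm(shards):
    """L2 norm of a gradient split across ranks.

    Each inner list stands in for the gradient shard one rank owns. Summing the
    squared local norms (an all-reduce in a real run) and taking the square root
    gives the norm of the full, unsharded gradient.
    """
    local_sq = [sum(g * g for g in shard) for shard in shards]  # per-rank work
    return math.sqrt(sum(local_sq))                             # "all-reduce" + sqrt


def clip_coefficient(total_norm, max_norm, eps=1e-6):
    """Scale factor applied to every shard so the global norm is at most max_norm."""
    return min(1.0, max_norm / (total_norm + eps))
```

For shards [3, 4] and [12], the local norms are 5 and 12, but the global norm is 13; clipping each rank by its local norm would therefore scale the shards inconsistently.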
