Principle:NVIDIA TransformerEngine FSDP Integration

From Leeroopedia
Revision as of 17:09, 16 February 2026 by Admin (Auto-imported from principles/NVIDIA_TransformerEngine_FSDP_Integration.md)


Overview

Integrating TransformerEngine (TE) modules with PyTorch's Fully Sharded Data Parallel (FSDP) for distributed training.

Description

FSDP shards model parameters across GPUs and all-gathers them on demand for the forward and backward passes. TransformerEngine modules need special preparation because their FP8 state (amax histories, scaling factors) must track FSDP's parameter lifecycle of gathering and resharding. The prepare_te_modules_for_fsdp utility handles this by injecting FSDP process group references into TE modules.

The integration challenge arises from the interaction between two systems:

  • FSDP manages parameter sharding: it shards parameters across GPUs when the model is wrapped, all-gathers them before each forward and backward pass, and reshards them (frees the gathered full copies) afterward.
  • TransformerEngine maintains FP8 quantization metadata (amax histories, scaling factors) alongside parameters. This metadata must stay synchronized as FSDP gathers and reshards the parameters.

Without proper preparation, TE modules are unaware of FSDP's parameter lifecycle, leading to:

  • Stale FP8 scales that do not reflect the current gathered parameters.
  • Incorrect amax tracking when parameters are in their sharded state.
  • Communication errors when FP8 metadata is not properly distributed across FSDP ranks.

The preparation step bridges these two systems by injecting FSDP process group references into every TE module in the model tree, enabling TE's internal FP8 management to coordinate with FSDP's parameter lifecycle.
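To make the injection step concrete, here is a toy model of it in plain Python. It is a sketch under stated assumptions, not TE's implementation: ProcessGroup, Module, FSDPUnit and inject_fsdp_groups are hypothetical stand-ins, and the rule that each TE module receives the group of its nearest enclosing FSDP unit is an illustrative assumption about how a tree walk of this kind would work.

```python
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class ProcessGroup:
    """Hypothetical stand-in for a torch.distributed process group."""
    ranks: List[int]


@dataclass
class Module:
    """Hypothetical module-tree node (not torch.nn.Module)."""
    name: str
    children: List["Module"] = field(default_factory=list)
    is_te_module: bool = False
    fsdp_group: Optional[ProcessGroup] = None  # filled in by the injection step


@dataclass
class FSDPUnit(Module):
    """Hypothetical FSDP wrapper: everything beneath it is sharded over `group`."""
    group: Optional[ProcessGroup] = None


def inject_fsdp_groups(node: Module, group: Optional[ProcessGroup] = None) -> int:
    """Attach the nearest enclosing FSDP group to every TE module; return the count."""
    if isinstance(node, FSDPUnit):
        group = node.group  # the innermost FSDP unit owns the parameters below it
    injected = 0
    if node.is_te_module and group is not None:
        node.fsdp_group = group
        injected += 1
    for child in node.children:
        injected += inject_fsdp_groups(child, group)
    return injected


# Usage: a root FSDP unit wrapping two per-layer FSDP units plus a plain module.
world = ProcessGroup(ranks=[0, 1, 2, 3])
layer0 = Module("layer0", is_te_module=True)
layer1 = Module("layer1", is_te_module=True)
head = Module("head")
root = FSDPUnit(
    "root",
    children=[
        FSDPUnit("unit0", children=[layer0], group=world),
        FSDPUnit("unit1", children=[layer1], group=world),
        head,
    ],
    group=world,
)
print(inject_fsdp_groups(root))  # 2
```

Only the two TE layers receive a group reference; the non-TE head is left untouched, which mirrors the page's description of targeting TE modules specifically.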

Theoretical Basis

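FSDP groups the model's modules into wrapped units (individually or through an auto-wrap policy) and manages each unit's parameter sharding and unsharding through a defined lifecycle. Under the default FULL_SHARD strategy, the lifecycle is: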

  1. Pre-forward: FSDP all-gathers sharded parameters to reconstruct full parameters.
  2. Forward: The module executes its forward pass with full parameters.
  3. Post-forward: FSDP reshards the parameters, freeing the gathered full copies so that each rank keeps only its local shard.
  4. Pre-backward: FSDP all-gathers parameters again for gradient computation.
  5. Post-backward: Gradients are reduce-scattered, and parameters return to sharded state.
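The five steps above can be traced with a single-process simulation. This is a toy sketch: Python lists stand in for tensors, and shard, all_gather and reduce_scatter are hypothetical helpers that mimic the collectives on one machine. The loss on rank r is assumed to be 0.5 * (w · x_r)^2, and the reduce-scatter is shown as a plain sum.

```python
from typing import List

Vector = List[float]


def shard(full: Vector, world_size: int) -> List[Vector]:
    """Split a flat parameter into equal contiguous shards, one per rank."""
    assert len(full) % world_size == 0, "toy model assumes no padding"
    k = len(full) // world_size
    return [full[r * k:(r + 1) * k] for r in range(world_size)]


def all_gather(shards: List[Vector]) -> List[Vector]:
    """Every rank receives the concatenation of all shards."""
    full = [x for s in shards for x in s]
    return [list(full) for _ in shards]


def reduce_scatter(per_rank: List[Vector]) -> List[Vector]:
    """Sum full-size vectors elementwise across ranks; rank r keeps shard r."""
    summed = [sum(vals) for vals in zip(*per_rank)]
    return shard(summed, len(per_rank))


def dot(a: Vector, b: Vector) -> float:
    return sum(x * y for x, y in zip(a, b))


def fsdp_step(param_shards: List[Vector], inputs: List[Vector]):
    """One toy step with loss_r = 0.5 * (w . x_r)^2 on each rank r."""
    # 1. Pre-forward: all-gather the full parameters on every rank.
    full = all_gather(param_shards)
    # 2. Forward: each rank computes y_r = w . x_r with the full parameters.
    outputs = [dot(full[r], inputs[r]) for r in range(len(inputs))]
    # 3. Post-forward: reshard by dropping the gathered copies.
    del full
    # 4. Pre-backward: all-gather again; dloss_r/dw = (w . x_r) * x_r.
    full = all_gather(param_shards)
    grads = [[dot(full[r], x) * xi for xi in x] for r, x in enumerate(inputs)]
    del full
    # 5. Post-backward: reduce-scatter gradients; parameters stay sharded.
    return outputs, reduce_scatter(grads)


param_shards = shard([1.0, 2.0, 3.0, 4.0], world_size=2)  # [[1, 2], [3, 4]]
inputs = [[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0]]      # one sample per rank
outputs, grad_shards = fsdp_step(param_shards, inputs)
print(outputs)      # [1.0, 2.0]
print(grad_shards)  # [[1.0, 2.0], [0.0, 0.0]]
```

Each rank holds full parameters only between a gather and the following reshard, and it ends the step holding just its own shard of parameters and gradients.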

TE's FP8 quantizers maintain scaling metadata (amax, scale, scale_inverse) that must be consistent with the parameter state at each stage. The preparation step ensures that when FSDP gathers parameters for the forward pass, TE's FP8 quantizers also properly synchronize their state across the FSDP group.
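To see why this metadata must be consistent across the FSDP group, consider the delayed-scaling formula used by TE's DelayedScaling recipe: scale = (FP8_MAX / amax) / 2^margin, with scale_inverse = 1 / scale used for dequantization. The sketch below is a plain-Python toy. compute_scale and local_amax are hypothetical helpers, the weight values are made up, and the max over ranks stands in for a max all-reduce. It shows how an amax taken from one shard alone (the "incorrect amax tracking" failure listed earlier) produces a scale that overflows FP8 for values held on another rank.

```python
import math

E4M3_MAX = 448.0  # largest finite magnitude in the FP8 E4M3 format


def compute_scale(amax: float, fp8_max: float = E4M3_MAX, margin: int = 0) -> float:
    """Delayed-scaling-style factor: scale = (fp8_max / amax) / 2**margin."""
    if amax <= 0.0 or not math.isfinite(amax):
        return 1.0  # toy choice: fall back to a neutral scale
    return (fp8_max / amax) / (2 ** margin)


def local_amax(shard):
    """Absolute maximum over the values a single rank holds."""
    return max(abs(v) for v in shard)


# One weight tensor sharded across two ranks (hypothetical values).
shards = [[0.5, -3.5], [7.0, 1.0]]
amaxes = [local_amax(s) for s in shards]   # [3.5, 7.0]
global_amax = max(amaxes)                  # what a max all-reduce would return

shard_scale = compute_scale(amaxes[0])     # rank 0 using only its shard: 128.0
synced_scale = compute_scale(global_amax)  # group-wide amax: 64.0
scale_inv = 1.0 / synced_scale             # used to dequantize: 0.015625

full = [v for s in shards for v in s]
print(max(abs(v) * shard_scale for v in full))   # 896.0, exceeds 448 and would saturate
print(max(abs(v) * synced_scale for v in full))  # 448.0, fits E4M3 exactly
```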

Usage

Apply this preparation step after wrapping a TE model with FSDP to ensure proper FP8 state management. The typical workflow is:

  1. Build the model using TE modules (te.TransformerLayer, te.Linear, etc.).
  2. Wrap the model with torch.distributed.fsdp.FullyShardedDataParallel.
  3. Call prepare_te_modules_for_fsdp(fsdp_model) on the wrapped model.
  4. Proceed with training as normal.
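A minimal end-to-end sketch of the four-step workflow follows. It assumes a TE release in which prepare_te_modules_for_fsdp is importable from transformer_engine.pytorch.distributed and te.fp8_autocast accepts fp8_recipe and fp8_group; names may differ in other releases. It also assumes FP8-capable GPUs and a torchrun launch. Wrapping each te.TransformerLayer as its own FSDP unit through transformer_auto_wrap_policy, the layer sizes, and the recipe settings are illustrative choices, not requirements stated on this page.

```python
import os
from functools import partial


def main() -> None:
    # Third-party imports are deferred so this file can be imported without a GPU
    # stack; launch it with e.g. `torchrun --nproc_per_node=<N> this_script.py`.
    import torch
    import torch.distributed as dist
    from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
    from torch.distributed.fsdp.wrap import transformer_auto_wrap_policy
    import transformer_engine.pytorch as te
    from transformer_engine.common.recipe import DelayedScaling, Format
    from transformer_engine.pytorch.distributed import prepare_te_modules_for_fsdp

    dist.init_process_group(backend="nccl")
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
    all_gpus = dist.new_group(backend="nccl")

    hidden, ffn, heads, seq, batch = 1024, 4096, 16, 128, 4

    # Step 1: build the model from TE modules (created on the current CUDA device).
    model = torch.nn.Sequential(*[te.TransformerLayer(hidden, ffn, heads) for _ in range(4)])

    # Step 2: wrap with FSDP; here each te.TransformerLayer becomes its own FSDP unit.
    model = FSDP(
        model,
        process_group=all_gpus,
        use_orig_params=True,
        auto_wrap_policy=partial(transformer_auto_wrap_policy, transformer_layer_cls={te.TransformerLayer}),
        device_id=torch.cuda.current_device(),
    )

    # Step 3: inject FSDP process group references into the wrapped TE modules.
    prepare_te_modules_for_fsdp(model)

    # Step 4: train as usual, running the forward pass under FP8 autocast.
    optim = torch.optim.AdamW(model.parameters(), lr=1e-4)
    recipe = DelayedScaling(fp8_format=Format.HYBRID, amax_history_len=16, amax_compute_algo="max")
    for _ in range(3):
        x = torch.randn(seq, batch, hidden, device="cuda")  # TE default layout: [seq, batch, hidden]
        with te.fp8_autocast(enabled=True, fp8_recipe=recipe, fp8_group=all_gpus):
            out = model(x)
        loss = out.float().pow(2).mean()
        loss.backward()
        optim.step()
        optim.zero_grad(set_to_none=True)

    dist.destroy_process_group()


# torchrun sets LOCAL_RANK; skip the run when the file is not launched under it.
if __name__ == "__main__" and "LOCAL_RANK" in os.environ:
    main()
```

The ordering matters: prepare_te_modules_for_fsdp is called on the already-wrapped model, so the FSDP units and their process groups exist by the time the references are injected.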

This is required whenever:

  • Using TE modules with FSDP (PyTorch's native FullyShardedDataParallel, including frameworks that build on it).
  • FP8 training is enabled with FSDP.
  • The model contains any te.TransformerLayer, te.Linear, or te.LayerNormLinear modules under an FSDP wrapper.
