

Implementation:NVIDIA TransformerEngine Prepare TE Modules For FSDP

From Leeroopedia
Revision as of 15:59, 16 February 2026 by Admin (auto-imported from implementations/NVIDIA_TransformerEngine_Prepare_TE_Modules_For_FSDP.md)


Overview

Utility function that prepares TransformerEngine (TE) modules for training under PyTorch Fully Sharded Data Parallel (FSDP).

Description

prepare_te_modules_for_fsdp(fsdp_root) traverses an FSDP-wrapped module tree and injects FSDP process group references into the TE modules it contains. With these references, TE can scatter and gather its FP8 weight copies at the same points in FSDP's parameter lifecycle where FSDP scatters and gathers its own sharded parameters.

The function performs the following operations:

  • Traverses the entire module tree rooted at fsdp_root, identifying all TransformerEngine modules.
  • Extracts the FSDP process group from the FSDP wrapper enclosing each TE module.
  • Injects the process group reference into each TE module so that FP8 metadata (amax histories, scaling factors) can be properly synchronized during FSDP's all-gather and reduce-scatter operations.
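The three steps above can be pictured with a small, self-contained model. The sketch below uses hypothetical stand-in classes (Module, TEModule, FSDPWrapper) and a hypothetical inject_fsdp_groups helper in place of PyTorch and TE, and process groups are plain string labels, so it runs anywhere. It illustrates the idea of walking the tree and handing each TE module the group of the FSDP wrapper that encloses it; it is not TransformerEngine's actual implementation.

```python
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class Module:
    """Hypothetical stand-in for torch.nn.Module: a name plus child modules."""
    name: str
    children: List["Module"] = field(default_factory=list)


@dataclass
class TEModule(Module):
    """Hypothetical stand-in for a TransformerEngine module."""
    fsdp_group: Optional[str] = None  # set by the injection pass


@dataclass
class FSDPWrapper(Module):
    """Hypothetical stand-in for an FSDP wrapper; the group is just a label here."""
    process_group: str = "default_pg"


def inject_fsdp_groups(node: Module, enclosing_group: Optional[str] = None) -> None:
    """Give every TE module the process group of its nearest enclosing FSDP wrapper."""
    if isinstance(node, FSDPWrapper):
        enclosing_group = node.process_group  # the innermost wrapper wins
    if isinstance(node, TEModule) and enclosing_group is not None:
        node.fsdp_group = enclosing_group  # mutate in place; nothing is returned
    for child in node.children:
        inject_fsdp_groups(child, enclosing_group)


# Usage: a root wrapper holding an embedding, a TE head, and a nested wrapped block.
layer = TEModule("te_layer")
head = TEModule("te_head")
plain = Module("embedding")
block = FSDPWrapper("block", children=[layer], process_group="pg_block")
root = FSDPWrapper("root", children=[plain, head, block], process_group="pg_root")

inject_fsdp_groups(root)
print(head.fsdp_group, layer.fsdp_group)  # pg_root pg_block
```

Threading the current group down the recursion means a TE module nested under several wrappers picks up the innermost one, while non-TE modules such as the embedding are left untouched.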

This is an in-place operation: it modifies the TE modules directly and does not return a new model. It must be called after FSDP wrapping because it depends on the FSDP wrapper's process group information.
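The ordering requirement and the in-place contract can be illustrated the same way. In this sketch, FakeTELayer, FakeFSDP, and prepare_sketch are hypothetical stand-ins, not TE or PyTorch APIs: prepare_sketch rejects an unwrapped root, because there is no process group to read yet, and it mutates the wrapped layer rather than returning a new model. The TypeError is an assumption of the sketch; this page does not state what error, if any, the real function raises.

```python
class FakeTELayer:
    """Hypothetical stand-in for a TE module; gains a group once prepared."""

    def __init__(self):
        self.fsdp_group = None


class FakeFSDP:
    """Hypothetical stand-in for an FSDP wrapper holding a process group."""

    def __init__(self, module, process_group):
        self.module = module
        self.process_group = process_group


def prepare_sketch(root):
    """Sketch of the ordering rule: the root must already be wrapped."""
    if not isinstance(root, FakeFSDP):
        raise TypeError("root must be FSDP-wrapped before preparing TE modules")
    if isinstance(root.module, FakeTELayer):
        root.module.fsdp_group = root.process_group  # in-place mutation
    # Implicitly returns None, mirroring the documented signature.


layer = FakeTELayer()

# Wrong order: calling before wrapping has no process group to inject.
rejected = False
try:
    prepare_sketch(layer)
except TypeError:
    rejected = True

# Right order: wrap first, then prepare; the same layer object is updated.
wrapped = FakeFSDP(layer, process_group="pg0")
result = prepare_sketch(wrapped)
print(rejected, result, layer.fsdp_group)  # True None pg0
```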

Source

transformer_engine/pytorch/distributed.py, function prepare_te_modules_for_fsdp at L2003-2036

Import

from transformer_engine.pytorch.distributed import prepare_te_modules_for_fsdp

Signature

def prepare_te_modules_for_fsdp(fsdp_root: torch.nn.Module) -> None:

I/O

Input: fsdp_root (torch.nn.Module), the FSDP-wrapped root module whose subtree contains the TE modules.
Output: None. The function modifies TE modules in place by injecting FSDP process group references.

Example Usage

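import os

import torch
import torch.distributed as dist
import transformer_engine.pytorch as te
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from transformer_engine.pytorch.distributed import prepare_te_modules_for_fsdp

# Step 0: Initialize distributed state (assumes launch via torchrun, which sets LOCAL_RANK)
dist.init_process_group(backend="nccl")
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

# Step 1: Build the model with TE modules
model = te.TransformerLayer(
    hidden_size=4096,
    ffn_hidden_size=11008,
    num_attention_heads=32,
)

# Step 2: Wrap with FSDP
model = FSDP(model, device_id=torch.cuda.current_device())

# Step 3: Prepare TE modules for FSDP compatibility (must come after wrapping)
prepare_te_modules_for_fsdp(model)

# Step 4: Proceed with training
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
for step in range(10):
    # Synthetic input in TE's default [seq_len, batch, hidden] layout,
    # standing in for batches from a real dataloader
    batch = torch.randn(512, 2, 4096, device="cuda")
    # Run the forward pass in FP8 (requires a GPU with FP8 support, e.g. Hopper)
    with te.fp8_autocast(enabled=True):
        output = model(batch)
    loss = output.sum()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

dist.destroy_process_group()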

Related

Environment:NVIDIA_TransformerEngine_CUDA_Toolkit_Requirements
