Implementation:NVIDIA TransformerEngine Prepare TE Modules For FSDP
Overview
Utility function to prepare TransformerEngine modules for FSDP compatibility.
Description
prepare_te_modules_for_fsdp(fsdp_root) traverses the FSDP-wrapped module tree and injects FSDP process group references into all TE modules. This enables proper FP8 weight scatter/gather during FSDP's parameter lifecycle.
The function performs the following operations:
- Traverses the entire module tree rooted at fsdp_root, identifying the TransformerEngine modules it contains.
- Extracts the FSDP process group from the FSDP wrapper enclosing each TE module.
- Injects that process group reference into each TE module, so that the module's FP8 weight copies can be scattered and gathered at the same time FSDP scatters and gathers its own flat parameters.
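The steps above can be illustrated with a toy module tree. This is a minimal pure-Python sketch, not TransformerEngine's implementation: Module, FSDPWrapper, TEModule, inject_fsdp_groups and the fsdp_group attribute are hypothetical names, and process groups are stand-in strings. It shows the rule from the second bullet: each TE module receives the group of the closest FSDP wrapper above it, and the tree is modified in place.

```python
from dataclasses import dataclass, field
from typing import List, Optional


# Hypothetical stand-ins: none of these classes exist in PyTorch or TE.
@dataclass
class Module:
    name: str
    children: List["Module"] = field(default_factory=list)


@dataclass
class FSDPWrapper(Module):
    process_group: str = "default_pg"  # placeholder for a real ProcessGroup


@dataclass
class TEModule(Module):
    fsdp_group: Optional[str] = None  # illustrative attribute, set by injection


def inject_fsdp_groups(root: Module) -> None:
    """Give every TE module the process group of its nearest enclosing wrapper."""
    stack = [(root, None)]
    while stack:
        module, group = stack.pop()
        if isinstance(module, FSDPWrapper):
            group = module.process_group  # an inner wrapper overrides the outer one
        if isinstance(module, TEModule):
            module.fsdp_group = group  # mutate in place; nothing is returned
        for child in module.children:
            stack.append((child, group))


# Usage: one TE module directly under the root, one inside a nested wrapper.
te_a = TEModule("te_a")
te_b = TEModule("te_b")
inner = FSDPWrapper("inner", children=[te_b], process_group="pg_inner")
root = FSDPWrapper("root", children=[te_a, inner], process_group="pg_root")
result = inject_fsdp_groups(root)
print(te_a.fsdp_group, te_b.fsdp_group)  # pg_root pg_inner
```

Tracking the group alongside each stack entry keeps the lookup of the enclosing wrapper O(1) per module instead of walking back up the tree.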
This is an in-place operation: it modifies the TE modules directly and does not return a new model. It must be called after FSDP wrapping because it depends on the FSDP wrapper's process group information.
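The two constraints in the paragraph above (call after wrapping, mutate rather than return) can be sketched with another toy model. Module, TEModule, FSDPWrapper and prepare are hypothetical names unrelated to TE's internals, and the RuntimeError raised for an unwrapped root is an assumption of this sketch, not the real function's error type.

```python
from typing import List, Optional


class Module:
    """Hypothetical stand-in for torch.nn.Module."""

    def __init__(self, children: Optional[List["Module"]] = None):
        self.children = children or []


class TEModule(Module):
    """Hypothetical TE module; fsdp_group stays None until injection."""

    fsdp_group: Optional[str] = None


class FSDPWrapper(Module):
    """Hypothetical FSDP wrapper that owns a process group."""

    def __init__(self, module: Module, process_group: str):
        super().__init__([module])
        self.process_group = process_group


def prepare(root: Module) -> None:
    # The process group only exists once a wrapper is present, so reject an unwrapped root.
    if not isinstance(root, FSDPWrapper):
        raise RuntimeError("prepare() must be called after FSDP wrapping")
    pending = [root]
    while pending:
        module = pending.pop()
        if isinstance(module, TEModule):
            module.fsdp_group = root.process_group  # modify the existing object
        pending.extend(module.children)


layer = TEModule()
try:
    prepare(layer)  # too early: there is no wrapper yet
except RuntimeError as err:
    print(err)

model = FSDPWrapper(layer, process_group="pg0")
ret = prepare(model)  # correct order; keep using `model`, do not rebind it to `ret`
print(ret, layer.fsdp_group)  # None pg0
```

For real code, the practical consequence is to call prepare_te_modules_for_fsdp(model) as a statement: writing model = prepare_te_modules_for_fsdp(model) would rebind model to None.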
Source
transformer_engine/pytorch/distributed.py, function prepare_te_modules_for_fsdp at L2003-2036
Import
from transformer_engine.pytorch.distributed import prepare_te_modules_for_fsdp
Signature
def prepare_te_modules_for_fsdp(fsdp_root: torch.nn.Module) -> None:
I/O
| Direction | Description |
|---|---|
| Input | fsdp_root (torch.nn.Module): The FSDP-wrapped root module containing TE modules in its subtree. |
| Output | None. The function modifies TE modules in place by injecting FSDP process group references. |
Example Usage
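```python
import os

import torch
import torch.distributed as dist
import transformer_engine.pytorch as te
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from transformer_engine.pytorch.distributed import prepare_te_modules_for_fsdp

# Step 0: Initialize the default process group (assumes a torchrun launch, which sets LOCAL_RANK)
dist.init_process_group(backend="nccl")
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

# Step 1: Build the model with TE modules
model = te.TransformerLayer(
    hidden_size=4096,
    ffn_hidden_size=11008,
    num_attention_heads=32,
)

# Step 2: Wrap with FSDP
model = FSDP(model, device_id=torch.cuda.current_device())

# Step 3: Prepare TE modules for FSDP compatibility (in place; must follow FSDP wrapping)
prepare_te_modules_for_fsdp(model)

# Step 4: Proceed with training
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
for _ in range(10):
    # Synthetic [seq_len, batch, hidden_size] input stands in for a real dataloader
    batch = torch.randn(128, 4, 4096, device="cuda")
    output = model(batch)
    loss = output.sum()
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
```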
Related
Environment:NVIDIA_TransformerEngine_CUDA_Toolkit_Requirements