Implementation: NVIDIA TransformerEngine (TE) Distributed Checkpoint
Overview
Concrete tool for tensor-parallel (TP)-aware activation checkpointing provided by TransformerEngine.
Description
te.distributed.checkpoint wraps a module's forward pass with activation checkpointing, adding TE-specific features: distributed activation saving across the tensor-parallel group, RNG state tracking for reproducible recomputation, and support for reentrant/non-reentrant modes.
The function intercepts the forward pass of the given callable and:
- Saves only the input tensors (not intermediate activations) during the forward pass.
- Captures CUDA RNG states (both default and tracked model-parallel states) at the checkpoint boundary.
- Optionally distributes the first saved activation across the tensor-parallel group using reduce-scatter, so each GPU keeps only 1/tp_size of that tensor (sized in the sketch after this list).
- Recomputes the forward pass during backward, restoring RNG states for bit-identical results.
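As a rough, made-up sizing (the shapes are illustrative, not from the source) of what distributing the first saved activation buys, assuming a bf16 hidden_states tensor in [sequence, batch, hidden] layout:

# Hypothetical shapes; 2 bytes per element for bf16.
seq, batch, hidden, tp_size = 4096, 8, 4096, 8
full_bytes = seq * batch * hidden * 2     # ~256 MiB kept on every GPU without distribution
per_gpu_bytes = full_bytes // tp_size     # ~32 MiB per GPU once reduce-scattered across the TP group
print(full_bytes / 2**20, per_gpu_bytes / 2**20)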
Two modes are supported (contrasted with plain PyTorch in the sketch after this list):
- Reentrant mode (use_reentrant=True): Uses PyTorch's legacy reentrant checkpointing. The recomputed forward pass runs inside torch.no_grad().
- Non-reentrant mode (use_reentrant=False): Uses PyTorch's newer non-reentrant checkpointing, which supports more complex autograd graphs.
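Both modes delegate to PyTorch's own checkpoint utility, on top of which TE adds RNG-state tracking and activation distribution. The contrast is visible with plain torch.utils.checkpoint in this standalone sketch, independent of TE:

import torch
from torch.utils.checkpoint import checkpoint as torch_checkpoint

def block(x):
    # Stand-in for an arbitrary sub-module forward pass.
    return torch.nn.functional.gelu(x @ x.transpose(0, 1))

x = torch.randn(64, 64, device="cuda", requires_grad=True)

# Reentrant mode: the recompute runs under torch.no_grad().
y_reentrant = torch_checkpoint(block, x, use_reentrant=True)

# Non-reentrant mode: uses saved-tensor hooks and supports more complex
# autograd graphs.
y_non_reentrant = torch_checkpoint(block, x, use_reentrant=False)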
Source
transformer_engine/pytorch/distributed.py, function checkpoint at L644-794
Import
from transformer_engine.pytorch.distributed import checkpoint
Signature
def checkpoint(
    function: Callable,
    *args: Tuple[torch.Tensor, ...],
    distribute_saved_activations: bool = False,
    get_rng_state_tracker: Callable = None,
    tp_group: ProcessGroup = None,
    use_reentrant: bool = True,
    **kwargs,
) -> Tuple[torch.Tensor, ...]:
I/O
| Direction | Description |
|---|---|
| Input | function (Callable): the module or callable whose forward pass should be checkpointed. *args: input tensors to the function. |
| Output | Tuple[torch.Tensor, ...]: same as the function's output, but activations are recomputed during backward instead of being stored. |
Key Parameters
| Parameter | Type | Default | Description |
|---|---|---|---|
| function | Callable | required | The callable (module or function) whose forward pass is wrapped with activation checkpointing. |
| distribute_saved_activations | bool | False | If True, distributes the first saved activation tensor across the tensor-parallel group via reduce-scatter, reducing per-GPU memory. |
| get_rng_state_tracker | Callable | None | A callable that returns the CudaRNGStatesTracker instance; used to save and restore model-parallel RNG states for reproducible recomputation. |
| tp_group | ProcessGroup | None | The tensor-parallel process group. Required when distribute_saved_activations=True. |
| use_reentrant | bool | True | Whether to use reentrant (True) or non-reentrant (False) checkpointing mode. |
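The get_rng_state_tracker and tp_group parameters are easiest to see with a concrete setup. The sketch below is illustrative only: it assumes a single-node NCCL job in which every rank belongs to one tensor-parallel group, and that CudaRNGStatesTracker (defined alongside checkpoint in transformer_engine.pytorch.distributed) follows the Megatron-style add()/fork() interface. It also defines the hidden_states and attention_mask placeholders used in the example below.

import torch
import torch.distributed as dist
from transformer_engine.pytorch.distributed import CudaRNGStatesTracker

# Illustrative single-node setup: all ranks form one tensor-parallel group.
dist.init_process_group(backend="nccl")
tp_size = dist.get_world_size()
tp_rank = dist.get_rank()
torch.cuda.set_device(tp_rank)
tp_group = dist.new_group(ranks=list(range(tp_size)))

# One tracked RNG stream per rank so dropout is reproducible on recompute.
_RNG_TRACKER = CudaRNGStatesTracker()
_RNG_TRACKER.add("model-parallel-rng", 1234 + tp_rank)

def get_rng_tracker():
    return _RNG_TRACKER

# Dummy inputs for the example below: [sequence, batch, hidden] layout.
hidden_states = torch.randn(2048, 2, 4096, device="cuda", requires_grad=True)
attention_mask = None  # no masking; supply a boolean mask if needed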
Example Usage
import transformer_engine.pytorch as te
from transformer_engine.pytorch.distributed import checkpoint

# Define a TransformerLayer
layer = te.TransformerLayer(
    hidden_size=4096,
    ffn_hidden_size=11008,
    num_attention_heads=32,
    tp_group=tp_group,
    tp_size=tp_size,
)

# Wrap the layer's forward with activation checkpointing
output = checkpoint(
    layer,
    hidden_states,
    attention_mask,
    distribute_saved_activations=True,
    get_rng_state_tracker=get_rng_tracker,
    tp_group=tp_group,
    use_reentrant=False,
)

# During backward, the forward pass of `layer` is automatically
# re-executed to recompute intermediate activations.
output.sum().backward()
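One rough, unofficial way to gauge the effect is to compare peak CUDA memory for a forward+backward pass with and without the wrapper, using PyTorch's built-in memory statistics (this reuses the layer, inputs, and tracker defined above):

import torch

def peak_mib(run):
    # Run a forward+backward closure and report peak CUDA memory in MiB.
    torch.cuda.reset_peak_memory_stats()
    run()
    torch.cuda.synchronize()
    return torch.cuda.max_memory_allocated() / 2**20

plain = peak_mib(lambda: layer(hidden_states, attention_mask).sum().backward())
ckpt = peak_mib(lambda: checkpoint(
    layer,
    hidden_states,
    attention_mask,
    get_rng_state_tracker=get_rng_tracker,
    tp_group=tp_group,
    use_reentrant=False,
).sum().backward())
print(f"plain: {plain:.0f} MiB  checkpointed: {ckpt:.0f} MiB")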
Related
Environment:NVIDIA_TransformerEngine_CUDA_Toolkit_Requirements