
Implementation:NVIDIA TransformerEngine TE Distributed Checkpoint

From Leeroopedia


Overview

A concrete tool for tensor-parallel-aware (TP-aware) activation checkpointing provided by TransformerEngine.

Description

te.distributed.checkpoint wraps a module's forward pass with activation checkpointing, adding TE-specific features: distributed activation saving across the tensor-parallel group, RNG state tracking for reproducible recomputation, and support for reentrant/non-reentrant modes.

The function intercepts the forward pass of the given callable and:

  • Saves only the input tensors (not intermediate activations) during the forward pass.
  • Captures CUDA RNG states (both default and tracked model-parallel states) at the checkpoint boundary.
  • Optionally distributes the first saved activation across the tensor-parallel group using reduce-scatter, reducing that activation's per-GPU memory to 1/tp_size of its full size.
  • Recomputes the forward pass during backward, restoring RNG states for bit-identical results.
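The save-inputs-then-recompute pattern described above is the same one implemented by PyTorch's stock torch.utils.checkpoint (TE layers TP activation distribution and model-parallel RNG tracking on top of it). A minimal CPU sketch of the core mechanism:

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint as pt_checkpoint

torch.manual_seed(0)
mlp = nn.Sequential(nn.Linear(8, 32), nn.GELU(), nn.Linear(32, 8))
x = torch.randn(4, 8, requires_grad=True)

# Reference: ordinary forward/backward, storing all intermediate activations.
ref = mlp(x)
ref.sum().backward()
ref_grad = x.grad.clone()
x.grad = None

# Checkpointed: only the input tensor is saved; the forward pass of `mlp`
# is re-executed during backward to recompute the intermediate activations.
out = pt_checkpoint(mlp, x, use_reentrant=False)
out.sum().backward()

# Gradients match because recomputation reproduces the same forward pass.
print(torch.allclose(ref_grad, x.grad))  # True
```

The trade is memory for compute: the backward pass runs the forward twice overall, but no intermediate activation of `mlp` is held between forward and backward.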

Two modes are supported:

  • Reentrant mode (use_reentrant=True): Uses PyTorch's legacy reentrant checkpointing. The recomputed forward pass runs inside torch.no_grad().
  • Non-reentrant mode (use_reentrant=False): Uses PyTorch's newer non-reentrant checkpointing, which supports more complex autograd graphs.
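Both modes can be exercised with stock PyTorch on CPU; this is a sketch of the mode flag only, without TE's TP distribution or RNG tracking:

```python
import torch
from torch.utils.checkpoint import checkpoint as pt_checkpoint

def f(a):
    return torch.sin(a) * torch.cos(a)

a = torch.randn(5, requires_grad=True)

# Reentrant mode: the recomputed forward runs inside torch.no_grad(); note
# this mode only supports torch.autograd.backward(), not torch.autograd.grad().
out_r = pt_checkpoint(f, a, use_reentrant=True)
out_r.sum().backward()
g_r = a.grad.clone()
a.grad = None

# Non-reentrant mode: recomputes via saved-tensor hooks, which supports
# more autograd graph shapes (including torch.autograd.grad()).
out_n = pt_checkpoint(f, a, use_reentrant=False)
out_n.sum().backward()
g_n = a.grad.clone()

print(torch.allclose(g_r, g_n))  # True: both modes yield the same gradients
```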

Source

transformer_engine/pytorch/distributed.py, function checkpoint at L644-794

Import

from transformer_engine.pytorch.distributed import checkpoint

Signature

def checkpoint(
    function: Callable,
    *args: Tuple[torch.Tensor, ...],
    distribute_saved_activations: bool = False,
    get_rng_state_tracker: Callable = None,
    tp_group: ProcessGroup = None,
    use_reentrant: bool = True,
    **kwargs,
) -> Tuple[torch.Tensor, ...]:

I/O

Direction Description
Input function (Callable): the module or callable whose forward pass should be checkpointed. *args: the input tensors passed to the function.
Output Tuple[torch.Tensor, ...]: the same values the function would normally return, but intermediate activations are recomputed during backward instead of being stored.

Key Parameters

Parameter Type Default Description
function Callable required The callable (module or function) whose forward pass is wrapped with activation checkpointing.
distribute_saved_activations bool False If True, distributes the first saved activation tensor across the tensor-parallel group via reduce-scatter, reducing that tensor's per-GPU memory to 1/tp_size of its full size.
get_rng_state_tracker Callable None A callable that returns the CudaRNGStatesTracker instance. Used to save/restore model-parallel RNG states for reproducible recomputation.
tp_group ProcessGroup None The tensor-parallel process group. Required when distribute_saved_activations=True.
use_reentrant bool True Whether to use reentrant (True) or non-reentrant (False) checkpointing mode.
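The object returned by get_rng_state_tracker follows the Megatron-style tracker interface: named RNG states plus a fork() context manager that swaps a tracked state in and out around a region of random ops. The class below is a hypothetical CPU-only stand-in to illustrate that interface, not TE's CUDA implementation:

```python
import contextlib
import torch

class RNGStatesTracker:
    """CPU stand-in illustrating the RNG-states-tracker interface."""
    def __init__(self):
        self.states = {}

    def add(self, name, seed):
        # Register a named RNG state seeded independently of the default one.
        orig = torch.get_rng_state()
        torch.manual_seed(seed)
        self.states[name] = torch.get_rng_state()
        torch.set_rng_state(orig)

    def get_states(self):
        return {k: v.clone() for k, v in self.states.items()}

    def set_states(self, states):
        self.states = states

    @contextlib.contextmanager
    def fork(self, name="model-parallel-rng"):
        # Swap in the tracked state, run the region, save it back, restore.
        orig = torch.get_rng_state()
        torch.set_rng_state(self.states[name])
        try:
            yield
        finally:
            self.states[name] = torch.get_rng_state()
            torch.set_rng_state(orig)

tracker = RNGStatesTracker()
tracker.add("model-parallel-rng", seed=1234)

saved = tracker.get_states()   # what checkpoint captures at the boundary
with tracker.fork():
    first = torch.rand(3)

tracker.set_states(saved)      # what recomputation restores during backward
with tracker.fork():
    replay = torch.rand(3)

print(torch.equal(first, replay))  # True: bit-identical random draws
```

Restoring the saved states before re-running the forward is how recomputation reproduces dropout masks bit-for-bit, which is why checkpointed and non-checkpointed training produce identical gradients.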

Example Usage

import transformer_engine.pytorch as te
from transformer_engine.pytorch.distributed import checkpoint

# Assumes an initialized tensor-parallel setup: `tp_group` and `tp_size`
# come from torch.distributed, `get_rng_tracker` is a callable returning
# the CudaRNGStatesTracker, and `hidden_states` / `attention_mask` are
# the layer's input tensors.

# Define a TransformerLayer
layer = te.TransformerLayer(
    hidden_size=4096,
    ffn_hidden_size=11008,
    num_attention_heads=32,
    tp_group=tp_group,
    tp_size=tp_size,
)

# Wrap the layer's forward with activation checkpointing
output = checkpoint(
    layer,
    hidden_states,
    attention_mask,
    distribute_saved_activations=True,
    get_rng_state_tracker=get_rng_tracker,
    tp_group=tp_group,
    use_reentrant=False,
)

# During backward, the forward pass of `layer` is automatically
# re-executed to recompute intermediate activations.
output.sum().backward()
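The memory effect of distribute_saved_activations can be sketched without GPUs: conceptually, each tensor-parallel rank retains only a 1/tp_size shard of the flattened saved activation, and the full tensor is reassembled across the group before recomputation. A single-process sketch of that bookkeeping (pure PyTorch, no distributed collectives):

```python
import torch

tp_size = 4
activation = torch.randn(8, 16)  # tensor saved at the checkpoint boundary

# Forward: each rank keeps only its contiguous 1/tp_size shard.
shards = list(torch.chunk(activation.reshape(-1), tp_size))
per_rank_numel = shards[0].numel()
print(per_rank_numel * tp_size == activation.numel())  # True

# Backward: the shards are reassembled into the full tensor before the
# forward pass is re-executed for recomputation.
restored = torch.cat(shards).reshape(activation.shape)
print(torch.equal(restored, activation))  # True
```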

Related

Environment:NVIDIA_TransformerEngine_CUDA_Toolkit_Requirements
