Principle: NVIDIA TransformerEngine Activation Checkpointing
Overview
Activation checkpointing trades compute for memory: intermediate activations are discarded during the forward pass and recomputed during the backward pass.
Description
Activation checkpointing (also known as gradient checkpointing) reduces peak memory usage by not storing intermediate activations from the forward pass. During the backward pass, the forward pass is re-run to recompute needed activations on the fly. TransformerEngine's version adds tensor-parallel awareness and RNG state management for reproducible recomputation.
In standard backpropagation, all intermediate activations from the forward pass are stored in memory so they can be used during the backward pass to compute gradients. For deep models with many layers, this memory requirement can become prohibitive. Activation checkpointing addresses this by:
- Discarding intermediate activations after the forward pass completes.
- Saving only the inputs at checkpoint boundaries (typically layer boundaries).
- Recomputing discarded activations during the backward pass by re-executing the forward pass from the nearest checkpoint.
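The three steps above can be sketched in framework-free Python. The function and variable names below (`forward_with_checkpoints`, `recompute_segment`, `segment_size`) are invented for this toy example and are not TransformerEngine API:

```python
# Illustrative sketch: store only each segment's input during the forward
# pass, then re-run the forward from the nearest checkpoint to rebuild the
# activations a segment's backward pass needs.

def forward_with_checkpoints(layers, x, segment_size):
    """Run forward, storing only the input at each checkpoint boundary."""
    checkpoints = []                 # the only state kept after forward
    for i, layer in enumerate(layers):
        if i % segment_size == 0:
            checkpoints.append(x)    # checkpoint boundary: save segment input
        x = layer(x)                 # intermediate activations are discarded
    return x, checkpoints

def recompute_segment(layers, checkpoints, seg_idx, segment_size):
    """Re-execute forward from the nearest checkpoint to rebuild activations."""
    x = checkpoints[seg_idx]
    acts = []
    start = seg_idx * segment_size
    for layer in layers[start:start + segment_size]:
        acts.append(x)
        x = layer(x)
    return acts                      # inputs to each layer in this segment

# Toy model: eight "layers", each doubling its input.
layers = [lambda v: 2 * v] * 8
out, ckpts = forward_with_checkpoints(layers, 1, segment_size=4)
print(out)                                     # 256
print(ckpts)                                   # [1, 16] — only 2 values stored, not 8
print(recompute_segment(layers, ckpts, 1, 4))  # [16, 32, 64, 128]
```

Real implementations do the same bookkeeping with tensors and autograd, but the memory pattern is identical: only checkpoint inputs survive the forward pass.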
TransformerEngine's implementation extends standard PyTorch checkpointing with:
- Tensor-parallel activation distribution: The first saved activation can be scattered across TP ranks, reducing per-GPU memory for that tensor by a factor of 1/tp_size.
- RNG state tracking: Captures and restores CUDA RNG states so that stochastic operations (e.g., dropout) produce identical results during recomputation.
- Reentrant and non-reentrant modes: Supports both PyTorch checkpointing paradigms.
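The tensor-parallel distribution idea can be illustrated with plain Python lists standing in for CUDA tensors; `scatter` and `all_gather` here are toy stand-ins for the NCCL collectives a real implementation would use:

```python
# Toy sketch of scattering a saved activation across tensor-parallel ranks:
# each rank holds only 1/tp_size of the checkpointed tensor, and the shards
# are gathered back before the forward pass is re-executed.

def scatter(activation, tp_size):
    """Split a flat activation evenly across tp_size ranks."""
    shard = len(activation) // tp_size
    return [activation[r * shard:(r + 1) * shard] for r in range(tp_size)]

def all_gather(shards):
    """Reassemble the full activation before recomputation."""
    return [v for shard in shards for v in shard]

saved = list(range(8))               # a saved checkpoint input of 8 elements
shards = scatter(saved, tp_size=4)
print([len(s) for s in shards])      # [2, 2, 2, 2] — each rank stores 1/4
print(all_gather(shards) == saved)   # True — full tensor restored for recompute
```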
Theoretical Basis
For a model with N layers, standard training stores O(N) activations. With checkpointing:
- Only checkpoint boundaries store activations; intermediate ones are recomputed during backward.
- This trades roughly one extra forward pass per checkpointed segment for significantly reduced memory, approximately 33% extra compute under the common assumption that the backward pass costs about twice the forward pass.
- With optimal checkpoint placement (a checkpoint roughly every sqrt(N) layers), activation memory drops from O(N) to O(sqrt(N)).
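The O(sqrt(N)) result can be checked numerically. With k evenly spaced checkpoints over N layers, peak activation memory is roughly k (stored boundary inputs) plus N/k (one segment's activations rebuilt during backward); minimizing k + N/k lands at k = sqrt(N):

```python
# Numeric check of the sqrt(N) checkpoint-placement claim.
import math

def peak_memory(n_layers, k):
    # k stored checkpoint inputs + one recomputed segment of n_layers/k activations
    return k + n_layers / k

N = 64
best_k = min(range(1, N + 1), key=lambda k: peak_memory(N, k))
print(best_k, math.isqrt(N))     # 8 8 — the optimum sits at sqrt(N)
print(peak_memory(N, best_k))    # 16.0, i.e. 2*sqrt(N), versus 64 without checkpointing
```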
The key constraint is reproducibility: the recomputed forward pass must produce bit-identical results to the original. This requires:
- Saving and restoring the CUDA RNG state at each checkpoint boundary.
- Ensuring deterministic execution of all operations within the checkpointed region.
- Properly handling tensor-parallel communication during recomputation.
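The RNG requirement above can be demonstrated with Python's `random` module standing in for the CUDA RNG state that TransformerEngine captures; the pattern (save state at the boundary, restore it before recompute) is the same:

```python
# Sketch of RNG-state bookkeeping for reproducible recomputation: the
# dropout mask drawn during recompute must match the original exactly.
import random

def dropout_mask(n, p):
    """Draw a keep/drop decision per element."""
    return [random.random() > p for _ in range(n)]

rng_state = random.getstate()        # save RNG state at the checkpoint boundary
original = dropout_mask(10, p=0.5)   # forward pass draws a mask

random.setstate(rng_state)           # restore state before re-running forward
recomputed = dropout_mask(10, p=0.5)

print(original == recomputed)        # True — recompute reproduces the mask
```

Without the save/restore step, the recomputed dropout mask would differ from the original and the gradients would be computed against the wrong activations.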
Usage
Use when GPU memory is insufficient for the desired batch size or model size. Common scenarios include:
- FSDP training where memory is already sharded but still insufficient for large models.
- Large model training where per-layer activations consume significant memory.
- Increasing effective batch size by freeing activation memory for more data.
Activation checkpointing is typically applied at the TransformerLayer granularity: each layer's forward pass is wrapped in a checkpoint, so only the layer inputs (not intermediate attention/MLP activations) are stored.
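A back-of-the-envelope accounting shows what per-layer granularity buys. The activation count per layer below is invented for illustration, not a TransformerEngine measurement:

```python
# Toy accounting: without checkpointing, every intermediate activation
# (QKV projections, attention probabilities, MLP hidden states, ...) stays
# live until backward; with per-layer checkpointing, only each layer's
# input is kept.

ACTIVATIONS_PER_LAYER = 6    # hypothetical count of intermediates per layer
n_layers = 24

stored_plain = n_layers * ACTIVATIONS_PER_LAYER   # everything kept
stored_ckpt = n_layers * 1                        # only the layer inputs

print(stored_plain, stored_ckpt)   # 144 vs 24 tensors held through forward
```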