Heuristic: Alibaba ROLL Gradient Checkpointing Recomputation
| Knowledge Sources | |
|---|---|
| Domains | Optimization, Memory_Management |
| Last Updated | 2026-02-07 19:00 GMT |
Overview
Activation recomputation that trades 20-30% compute overhead for 50-60% VRAM reduction, with auto-detection of reentrant mode based on model architecture.
Description
ROLL supports two forms of activation memory reduction: HuggingFace gradient checkpointing (controlled by the `disable_gradient_checkpointing` flag) and Megatron-Core recomputation (controlled by `recompute_granularity`). HuggingFace gradient checkpointing works by discarding intermediate activations during the forward pass and recomputing them during the backward pass. The framework auto-detects whether to use reentrant or non-reentrant checkpointing based on the model architecture. Megatron-Core provides finer control: `full` recomputation recomputes all layers, while `selective` recomputation recomputes only attention and keeps FFN activations. MoE layers have dedicated recomputation via `moe_layer_recompute`.
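The core trade can be seen in a framework-free sketch (all names here are illustrative, not ROLL APIs): a plain forward pass saves every activation for backward, while a checkpointed forward saves only segment-boundary activations and re-derives the rest later.

```python
def layer(x, w):
    return x * w + 1.0  # stand-in for one transformer layer

def forward_full(x, weights):
    acts = []                      # keep every activation for backward
    for w in weights:
        x = layer(x, w)
        acts.append(x)
    return x, acts

def forward_checkpointed(x, weights, segment=4):
    saved = [x]                    # keep only segment-boundary activations;
    for i, w in enumerate(weights):  # the rest are recomputed during backward
        x = layer(x, w)
        if (i + 1) % segment == 0 and (i + 1) < len(weights):
            saved.append(x)
    return x, saved

weights = [1.01] * 32
y_full, acts = forward_full(0.5, weights)
y_ckpt, saved = forward_checkpointed(0.5, weights)
print(len(acts), len(saved))  # same output, far fewer stored tensors
```

The backward pass then re-runs `layer` from the nearest saved boundary whenever it needs a discarded activation, which is where the extra compute overhead comes from.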
Usage
Enable gradient checkpointing when GPU VRAM is insufficient for the model (e.g., 7B+ models on consumer GPUs). It is standard practice whenever the model does not fit in GPU memory without it. For the Megatron backend, use `recompute_granularity: full` for maximum memory savings or `selective` for a balance between speed and memory.
The Insight (Rule of Thumb)
- Action: Keep `disable_gradient_checkpointing=False` (the default). For Megatron: set `recompute_granularity: full` or `selective`.
- Value: ~50-60% VRAM reduction at ~20-30% compute overhead. MoE layers: use `moe_layer_recompute: True` for additional savings.
- Trade-off: Full recomputation saves the most memory but is slowest. Selective recomputation balances speed/memory by keeping FFN activations.
- V100 GPUs: Always enable recomputation with `recompute_granularity: full` and `use_distributed_optimizer: true`.
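Assembled into a config, these knobs might look roughly like the following YAML. The section names (`actor_train`, `model_args`, `strategy_config`) are a sketch of where such keys could sit; consult ROLL's actual config schema for exact placement.

```yaml
# Hedged sketch; key placement may differ in ROLL's schema.
actor_train:
  model_args:
    disable_gradient_checkpointing: false   # HF checkpointing on (default)
  strategy_args:
    strategy_config:
      recompute_granularity: full           # or: selective (keeps FFN activations)
      moe_layer_recompute: true             # extra savings for MoE layers
      use_distributed_optimizer: true       # pair with full recompute on V100s
```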
Reasoning
Transformer activations scale as O(batch_size * seq_len * hidden_dim * num_layers). For a 7B model with 2048 sequence length, this can exceed 20GB of activation memory alone. Sqrt-style recomputation reduces the stored activations to O(sqrt(num_layers)) segment boundaries at the cost of re-running the forward pass for each segment during backward. The compute overhead is modest because modern GPUs are often memory-bandwidth-bound rather than compute-bound during training.
Model provider gradient checkpointing from `roll/models/model_providers.py:249-259`:
```python
if not model_args.disable_gradient_checkpointing:
    use_reentrant = model_args.gradient_checkpointing_use_reentrant
    if use_reentrant is None:
        use_reentrant = ...  # auto-detected from the model architecture (logic elided)
    model.gradient_checkpointing_enable(
        gradient_checkpointing_kwargs={"use_reentrant": bool(use_reentrant)}
    )
```
`use_cache` must be disabled during training, from `roll/distributed/strategy/deepspeed_strategy.py:184`:

```python
# set use_cache=False manually for the same reason as HfInferStrategy
use_cache=False,
```
Qwen-VL special case from `roll/distributed/strategy/hf_strategy.py:88-90`:

```python
# in Qwen2-vl/Qwen2.5-vl, use_cache=False should be set manually to
# avoid unexpected behavior since model.generation_config.cache_implementation
# is not None
```