# Heuristic: Microsoft DeepSpeedExamples Gradient Checkpointing Tradeoff
| Knowledge Sources | |
|---|---|
| Domains | Optimization, Deep_Learning, LLMs |
| Last Updated | 2026-02-07 13:00 GMT |
## Overview
Gradient checkpointing raises total training FLOPs from 3x to 4x the cost of a forward pass (~33% compute overhead) but cuts activation VRAM usage by roughly 30-50%, making it essential for memory-constrained large-model training.
## Description
Gradient checkpointing (also called activation checkpointing) saves memory by not storing all intermediate activations during the forward pass. Instead, only a subset of activations is kept, and the rest are recomputed during the backward pass. In the DeepSpeed-Chat codebase, the performance calculation explicitly uses a factor of 4 (with checkpointing) versus 3 (without) when computing TFLOPS. This factor expresses total computation relative to a single forward pass, with the backward pass counted as roughly twice the forward-pass FLOPs: 1x forward + 2x backward + 1x recomputation = 4x (with checkpointing) versus 1x forward + 2x backward = 3x (without). LoRA further reduces this factor by proportionally decreasing the number of parameters that need gradient computation.
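The factor accounting above can be expressed as a small calculator. This is a simplified sketch of the bookkeeping, not the actual DeepSpeed-Chat `perf.py`; the 2x-backward assumption follows the standard FLOPs accounting described above:

```python
def train_flops_factor(gradient_checkpointing: bool) -> int:
    """Total training FLOPs as a multiple of one forward pass.

    Standard accounting: the backward pass costs ~2x the forward pass,
    and gradient checkpointing adds one extra forward recomputation.
    """
    forward = 1
    backward = 2  # ~2x forward FLOPs
    recompute = 1 if gradient_checkpointing else 0
    return forward + backward + recompute

assert train_flops_factor(True) == 4   # with checkpointing
assert train_flops_factor(False) == 3  # without
```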
## Usage
Enable gradient checkpointing when GPU VRAM is insufficient for your model and batch size; disable it when training is compute-bound and memory is not a constraint. The tradeoff matters most when fine-tuning 7B+ parameter models on consumer and workstation GPUs (RTX 3090/4090, A6000).
## The Insight (Rule of Thumb)
- Action: Set `--gradient_checkpointing` flag in training arguments, or `model.gradient_checkpointing_enable()` in code.
- Value: VRAM reduction of ~30-50%; compute overhead of ~33% (factor 3 to 4).
- Trade-off: Slower training speed for lower memory usage.
- Interaction with LoRA: LoRA reduces the checkpointing overhead proportionally. The effective factor is `4 - (1 - k)` where `k = lora_dim * 2 / hidden_size`. For a 128-dim LoRA on a 2048-hidden model, k = 0.125, so the factor drops from 4.0 to ~3.125.
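The LoRA adjustment can be checked numerically. This sketch mirrors the formula in the bullet above; the concrete `lora_dim` and `hidden_size` values are illustrative:

```python
def effective_checkpoint_factor(lora_dim: int, hidden_size: int,
                                gradient_checkpointing: bool = True) -> float:
    """FLOPs factor relative to one forward pass, with the LoRA reduction
    applied as in DeepSpeed-Chat's perf accounting."""
    factor = 4.0 if gradient_checkpointing else 3.0
    if lora_dim > 0:
        # Fraction of parameters that still need gradient computation.
        k = lora_dim * 2 / hidden_size
        factor -= (1 - k)
    return factor

# 128-dim LoRA on a 2048-hidden model: k = 0.125, so 4.0 drops to 3.125.
assert effective_checkpoint_factor(128, 2048) == 3.125
```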
## Reasoning
The memory savings come from not storing O(layers x batch x seq_len x hidden) intermediate activations. For a 7B model with batch size 4 and sequence length 2048, this can save 10-20GB of VRAM. The compute cost is predictable: exactly one additional forward pass through the checkpointed layers. The tradeoff is almost always worthwhile for memory-constrained setups because a forward pass costs roughly half the FLOPs of a backward pass, so the ~33% overhead stays modest relative to the memory savings.
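A rough back-of-envelope check of the memory claim. The per-layer count of ~8 stored activation tensors is an assumption for illustration, not a figure from the source; real numbers depend heavily on the architecture and implementation:

```python
def activation_memory_gb(layers: int, batch: int, seq_len: int, hidden: int,
                         bytes_per_elem: int = 2,
                         tensors_per_layer: int = 8) -> float:
    """Rough estimate of activation memory (GB) stored without checkpointing.

    Assumes each layer keeps ~tensors_per_layer intermediate tensors of
    shape (batch, seq_len, hidden) in fp16 (2 bytes per element).
    """
    elems = layers * batch * seq_len * hidden * tensors_per_layer
    return elems * bytes_per_elem / 1024**3

# 7B-class model: 32 layers, hidden 4096, batch 4, seq 2048 -> 16.0 GB,
# consistent with the 10-20GB range cited above.
print(activation_memory_gb(32, 4, 2048, 4096))  # 16.0
```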
## Code Evidence
Factor calculation from `applications/DeepSpeed-Chat/dschat/utils/perf.py:19`:
```python
checkpoint_activations_factor = 4 if args.gradient_checkpointing else 3
```
LoRA interaction from `dschat/utils/perf.py:67-74`:
```python
if args.lora_dim > 0:
    k = args.lora_dim * 2 / config.hidden_size
    checkpoint_activations_factor -= (1 - k)
```
Gradient clipping configuration from `dschat/utils/ds_utils.py:59-60`:
```python
"gradient_clipping": 1.0,
"prescale_gradients": False,
```