
Heuristic:Microsoft DeepSpeedExamples Gradient Checkpointing Tradeoff

From Leeroopedia



Knowledge Sources
Domains: Optimization, Deep_Learning, LLMs
Last Updated: 2026-02-07 13:00 GMT

Overview

Gradient checkpointing raises total training compute from roughly 3x to 4x the cost of a forward pass (about 33% more) in exchange for a 30-50% reduction in VRAM usage, which makes it essential for training large models on memory-constrained hardware.

Description

Gradient checkpointing (also called activation checkpointing) saves memory by not storing all intermediate activations during the forward pass. Instead, only a subset of activations is kept, and the rest are recomputed during the backward pass. In the DeepSpeed-Chat codebase, the performance calculation explicitly uses a factor of 4 (with checkpointing) versus 3 (without checkpointing) when computing TFLOPS. This factor represents the total computation relative to a single forward pass: 1x forward + 2x backward + 1x recomputation = 4x (with checkpointing) versus 1x forward + 2x backward = 3x (without). LoRA further reduces this factor by proportionally decreasing the number of parameters that need gradient computation.
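
The store-a-subset-and-recompute idea can be illustrated with a toy example. The following is a minimal sketch in plain Python, not the DeepSpeed or PyTorch implementation: the "layers" are scalar `tanh(w * x)` functions, and the names `layer_forward`, `layer_backward`, `grad_full`, and `grad_checkpointed` are hypothetical helpers introduced only for this illustration. It compares a baseline that keeps every layer input with a checkpointed version that keeps one input per segment and reruns each segment during the backward pass.

```python
import math

def layer_forward(x, w):
    """Toy layer: scale by a weight, then apply tanh."""
    return math.tanh(w * x)

def layer_backward(x, w, grad_out):
    """Gradient w.r.t. the layer input x, given the gradient of the layer output."""
    y = math.tanh(w * x)
    return grad_out * w * (1.0 - y * y)

def grad_full(x0, weights):
    """Baseline: keep every layer input alive for the backward pass."""
    inputs, x, forward_calls = [], x0, 0
    for w in weights:
        inputs.append(x)
        x = layer_forward(x, w)
        forward_calls += 1
    grad = 1.0
    for x_in, w in zip(reversed(inputs), reversed(weights)):
        grad = layer_backward(x_in, w, grad)
    return grad, forward_calls, len(inputs)

def grad_checkpointed(x0, weights, segment):
    """Keep only the input of every `segment`-th layer; recompute the rest."""
    checkpoints, x, forward_calls = {}, x0, 0
    for i, w in enumerate(weights):
        if i % segment == 0:
            checkpoints[i] = x  # the only activations kept after the forward pass
        x = layer_forward(x, w)
        forward_calls += 1
    stored = len(checkpoints)
    grad = 1.0
    for start in reversed(range(0, len(weights), segment)):
        seg_weights = weights[start:start + segment]
        # Rerun the whole segment forward from its checkpoint to rebuild its inputs.
        inputs, x = [], checkpoints[start]
        for w in seg_weights:
            inputs.append(x)
            x = layer_forward(x, w)
            forward_calls += 1
        # Backpropagate through the segment using the recomputed inputs.
        for x_in, w in zip(reversed(inputs), reversed(seg_weights)):
            grad = layer_backward(x_in, w, grad)
    return grad, forward_calls, stored

weights = [0.5 + 0.1 * i for i in range(16)]
print(grad_full(0.3, weights))
print(grad_checkpointed(0.3, weights, segment=4))
```

With 16 layers and a segment size of 4, both functions return the same gradient, but the checkpointed version runs 32 layer forwards instead of 16 and keeps 4 layer inputs after the forward pass instead of 16. That doubled forward work is the extra 1x in the factor of 4.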

Usage

Enable gradient checkpointing when GPU VRAM is insufficient for your model and batch size. Disable it when training is compute-bound and memory is not a constraint. This tradeoff is most impactful when fine-tuning 7B+ parameter models on workstation or consumer GPUs (A6000, RTX 3090/4090).

The Insight (Rule of Thumb)

  • Action: Pass the `--gradient_checkpointing` flag in the training arguments, or call `model.gradient_checkpointing_enable()` in code.
  • Value: VRAM reduction of ~30-50%; compute overhead of ~33% (factor 3 to 4).
  • Trade-off: Slower training speed for lower memory usage.
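  • Interaction with LoRA: LoRA lowers the overall factor by `(1 - k)`, where `k = lora_dim * 2 / hidden_size`; it does not change the cost of the recomputation itself. With checkpointing the effective factor is `4 - (1 - k)`; without it, `3 - (1 - k)`. For a 128-dim LoRA on a 2048-hidden model, k = 0.125, so the factor drops from 4.0 to 3.125 with checkpointing (3.0 to 2.125 without).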
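
The rule of thumb above can be turned into a small calculator. This is a sketch that mirrors the two snippets quoted from `perf.py` under Code Evidence; the function name `flops_factor` and its signature are hypothetical and do not exist in DeepSpeed-Chat.

```python
def flops_factor(gradient_checkpointing, lora_dim=0, hidden_size=None):
    """Total compute as a multiple of one forward pass (hypothetical helper)."""
    # 1x forward + 2x backward, plus 1x recomputation when checkpointing.
    factor = 4.0 if gradient_checkpointing else 3.0
    if lora_dim > 0:
        # Fraction of the weight-gradient work that LoRA still has to do.
        k = lora_dim * 2 / hidden_size
        factor -= (1 - k)
    return factor

print(flops_factor(True))              # 4.0
print(flops_factor(False))             # 3.0
print(flops_factor(True, 128, 2048))   # 3.125
print(flops_factor(False, 128, 2048))  # 2.125
```

Under this formula the recomputation always costs one forward pass, so the relative slowdown from enabling checkpointing is 4/3 for full fine-tuning and 3.125/2.125 for the LoRA example.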

Reasoning

The memory savings come from not storing O(layers x batch x seq_len x hidden) intermediate activations. For a 7B model with batch size 4 and sequence length 2048, this can save 10-20GB of VRAM. The compute cost is predictable: one additional forward pass through the model. This tradeoff is almost always worthwhile for memory-constrained setups because the forward pass costs roughly half as much as the backward pass (1x versus 2x in the factor above), so repeating it adds only about a third to the total compute while freeing a large share of memory.
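
The scaling in the estimate above can be checked with a back-of-the-envelope calculation. This is a rough sketch under stated assumptions: a 7B model is taken to have 32 layers and a hidden size of 4096, activations are stored in 16-bit precision, and `tensors_per_layer` is a hypothetical placeholder for how many hidden-sized tensors each layer keeps for the backward pass. Attention score matrices and framework overhead are ignored.

```python
GIB = 1024 ** 3

def activation_bytes(layers, batch, seq_len, hidden,
                     bytes_per_value=2, tensors_per_layer=1):
    """Bytes needed to store hidden-sized activations (hypothetical helper)."""
    return layers * batch * seq_len * hidden * bytes_per_value * tensors_per_layer

# One hidden-sized tensor per layer: roughly what remains when only
# layer inputs are checkpointed (assumed shapes, see lead-in).
per_tensor = activation_bytes(layers=32, batch=4, seq_len=2048, hidden=4096)
print(per_tensor / GIB)  # 2.0
```

Each hidden-sized tensor kept per layer therefore costs about 2 GiB in this configuration. The total saved depends on how many such tensors a given architecture would otherwise retain per layer, which this sketch does not try to determine.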

Code Evidence:

Factor calculation from `applications/DeepSpeed-Chat/dschat/utils/perf.py:19`:

checkpoint_activations_factor = 4 if args.gradient_checkpointing else 3

LoRA interaction from `dschat/utils/perf.py:67-74`:

if args.lora_dim > 0:
    k = args.lora_dim * 2 / config.hidden_size
    checkpoint_activations_factor -= (1 - k)

Gradient clipping configuration from `dschat/utils/ds_utils.py:59-60`:

"gradient_clipping": 1.0,
"prescale_gradients": False,
