Heuristic: hpcaitech ColossalAI Gradient Checkpointing Memory Tip
| Knowledge Sources | |
|---|---|
| Domains | Optimization, Memory_Management |
| Last Updated | 2026-02-09 03:00 GMT |
Overview
Enable gradient checkpointing to reduce VRAM usage during training, but always call `model.train()` first and pass `use_reentrant=False`.
Description
Gradient checkpointing reduces peak GPU memory by recomputing activations during the backward pass instead of storing them. In ColossalAI, enabling gradient checkpointing requires a specific call order: `model.train()` must be called before `model.gradient_checkpointing_enable()`, otherwise checkpointing silently fails. The `use_reentrant=False` flag is recommended for compatibility with modern PyTorch autograd.
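The call-order dependency can be illustrated with a minimal pure-Python stand-in (a hypothetical `ToyModel`, not the real ColossalAI or PyTorch API): if checkpointing hooks only register while the model is in training mode, enabling them before `train()` becomes a silent no-op, which is exactly the failure mode described above.

```python
class ToyModel:
    """Hypothetical stand-in mimicking the train() / gradient_checkpointing_enable() ordering."""

    def __init__(self):
        self.training = False
        self.checkpointing_active = False

    def train(self):
        # Switch to training mode; the prerequisite for checkpoint hooks.
        self.training = True

    def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):
        # Hooks register only in training mode; otherwise this is a silent no-op.
        if self.training:
            self.checkpointing_active = True


# Wrong order: enable before train() -- silently does nothing.
wrong = ToyModel()
wrong.gradient_checkpointing_enable()
wrong.train()
print(wrong.checkpointing_active)  # False: no memory savings, later OOM

# Correct order: train() first, then enable.
right = ToyModel()
right.train()
right.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
print(right.checkpointing_active)  # True
```

The real failure is worse than this sketch suggests, because nothing is printed or raised: training proceeds normally until the backward pass runs out of memory.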
Usage
Use this when training large models (7B+ parameters) on GPUs with limited VRAM, or when encountering CUDA OOM errors during the backward pass. Standard practice for all ColossalAI training workflows.
The Insight (Rule of Thumb)
- Action: Call `model.train()` before `model.gradient_checkpointing_enable()`.
- Value: Pass `gradient_checkpointing_kwargs={"use_reentrant": False}` for modern PyTorch.
- Trade-off: ~20-30% slower training for ~50-60% VRAM reduction.
- Warning: LoRA and gradient checkpointing are incompatible for some models (e.g., ChatGLM). When using LoRA, verify compatibility first.
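The trade-off row can be sanity-checked with back-of-envelope arithmetic (illustrative numbers, not measurements): checkpointing keeps roughly only layer-boundary activations and recomputes the rest, but weights and optimizer state are untouched, so the end-to-end saving lands near ~50% rather than the ~90% seen on activations alone.

```python
def activation_memory_gb(num_layers, per_layer_gb, checkpointed):
    """Rough activation-memory estimate. With checkpointing, only layer-boundary
    activations are kept (approximated here as ~10% of each layer's footprint);
    everything else is recomputed during the backward pass."""
    if checkpointed:
        return num_layers * per_layer_gb * 0.1
    return num_layers * per_layer_gb


# Illustrative 32-layer model with 1.5 GB of activations per layer.
baseline = activation_memory_gb(32, 1.5, checkpointed=False)  # 48.0 GB
with_ckpt = activation_memory_gb(32, 1.5, checkpointed=True)  # ~4.8 GB

# Total VRAM also includes weights and optimizer state, which checkpointing
# does not touch, so the overall saving is smaller than the activation saving.
fixed = 40.0  # weights + optimizer state, GB (illustrative)
total_before = fixed + baseline
total_after = fixed + with_ckpt
print(f"total: {total_before:.1f} GB -> {total_after:.1f} GB "
      f"({1 - total_after / total_before:.0%} saved)")
```

The recomputation cost is one extra partial forward pass per step, which is where the ~20-30% slowdown comes from.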
Reasoning
The `model.train()` call sets the model to training mode, which is a prerequisite for gradient checkpointing hooks to be registered. Without it, the checkpointing hooks are silently skipped, resulting in no memory savings and confusing OOM errors. The `use_reentrant=False` flag avoids known issues with PyTorch's reentrant autograd and is the recommended setting for PyTorch >= 2.0.
Code Evidence
From `applications/Colossal-LLaMA/train.py:193-198`:
```python
# this is essential, otherwise the grad checkpoint will not work.
model.train()
if args.use_grad_checkpoint:
    model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
    coordinator.print_on_master(msg="Gradient checkpointing enabled successfully")
```
From `applications/ColossalChat/coati/distributed/grpo_consumer.py:74-75`:
```python
self.policy_model.train()
self.policy_model.gradient_checkpointing_enable()
```