
Heuristic:Hpcaitech ColossalAI Gradient Checkpointing Memory Tip

From Leeroopedia



Knowledge Sources
Domains Optimization, Memory_Management
Last Updated 2026-02-09 03:00 GMT

Overview

Enable gradient checkpointing to reduce VRAM usage during training, but always call `model.train()` first and use `use_reentrant=False`.

Description

Gradient checkpointing reduces peak GPU memory by recomputing activations during the backward pass instead of storing them. In ColossalAI, enabling gradient checkpointing requires a specific call order: `model.train()` must be called before `model.gradient_checkpointing_enable()`, otherwise checkpointing silently fails. The `use_reentrant=False` flag is recommended for compatibility with modern PyTorch autograd.
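The recompute-instead-of-store idea can be sketched without any framework. The toy chain of functions below is purely illustrative (it is not the ColossalAI or PyTorch API): the forward pass caches only one input per segment, and the backward pass rebuilds a segment's activations on demand from that cached input.

```python
# Toy illustration of activation checkpointing: instead of caching every
# intermediate activation for the backward pass, store only segment-boundary
# inputs and recompute the activations inside a segment when they are needed.

def forward_with_checkpoints(fns, x, segment_size):
    """Run a chain of functions, caching only one input per segment."""
    boundary_inputs = []  # the only values we "keep in memory"
    for i, fn in enumerate(fns):
        if i % segment_size == 0:
            boundary_inputs.append(x)
        x = fn(x)
    return x, boundary_inputs

def recompute_segment(fns, boundary_inputs, segment_size, seg_idx):
    """During backward, rebuild one segment's activations from its boundary input."""
    x = boundary_inputs[seg_idx]
    acts = []
    for fn in fns[seg_idx * segment_size:(seg_idx + 1) * segment_size]:
        x = fn(x)
        acts.append(x)
    return acts

layers = [lambda v, k=k: v * 2 + k for k in range(8)]  # 8 toy "layers"
out, saved = forward_with_checkpoints(layers, 1.0, segment_size=4)
# Only 2 boundary inputs are stored instead of 8 activations; the second
# segment's activations are recoverable exactly by recomputation.
```

The memory saving comes from `boundary_inputs` being much smaller than the full activation list; the cost is rerunning each segment's forward once during the backward pass.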

Usage

Use this when training large models (7B+ parameters) on GPUs with limited VRAM, or when encountering CUDA OOM errors during the backward pass. This is standard practice for ColossalAI training workflows.

The Insight (Rule of Thumb)

  • Action: Call `model.train()` before `model.gradient_checkpointing_enable()`.
  • Value: Pass `gradient_checkpointing_kwargs={"use_reentrant": False}` for modern PyTorch.
  • Trade-off: ~20-30% slower training for ~50-60% VRAM reduction.
  • Warning: LoRA and gradient checkpointing are incompatible for some models (e.g., ChatGLM). When using LoRA, verify compatibility first.
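The trade-off numbers above can be sanity-checked with the classic sqrt(L) checkpointing arithmetic. This is a back-of-envelope sketch assuming uniform layers and a backward pass costing roughly 2x the forward pass; in practice the observed slowdown and savings are milder than this worst case because weights and optimizer states also occupy VRAM.

```python
import math

L = 64                       # transformer layers (illustrative)
naive_acts = L               # activations stored without checkpointing
seg = int(math.sqrt(L))      # sqrt(L) segment size minimizes peak activation memory
ckpt_acts = L // seg + seg   # segment boundaries + one recomputed segment in flight

mem_reduction = 1 - ckpt_acts / naive_acts   # 75% fewer stored activations

# Compute overhead: roughly one extra forward pass for recomputation,
# on top of forward (1x) + backward (~2x) cost.
fwd, bwd = 1.0, 2.0
slowdown = (fwd + bwd + fwd) / (fwd + bwd) - 1   # ~33% slower, worst case
```

Activation memory shrinks by ~75% in this model, but since activations are only part of total VRAM, the end-to-end reduction lands closer to the ~50-60% cited above; likewise the ~33% worst-case recompute overhead shows up as ~20-30% wall-clock slowdown.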

Reasoning

The `model.train()` call sets the model to training mode, which is a prerequisite for gradient checkpointing hooks to be registered. Without it, the checkpointing hooks are silently skipped, resulting in no memory savings and confusing OOM errors. The `use_reentrant=False` flag avoids known issues with PyTorch's reentrant autograd and is the recommended setting for PyTorch >= 2.0.
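The silent-failure mode can be modeled with a toy class (an assumption-laden sketch, not the real ColossalAI or HuggingFace implementation): enabling checkpointing just sets a flag, and the forward pass only honors that flag while the model is in training mode, so forgetting `model.train()` produces no error, just no memory savings.

```python
# Toy model mimicking the silent failure described above: checkpointing
# only takes effect while the model is in training mode, so forgetting
# model.train() leaves memory usage unchanged with no error raised.

class ToyModel:
    def __init__(self):
        self.training = False      # models typically start in eval mode
        self.ckpt_enabled = False

    def train(self):
        self.training = True

    def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):
        self.ckpt_enabled = True   # just a flag; no error if model is in eval mode

    def forward_uses_checkpointing(self):
        # Checkpointing is applied only when BOTH conditions hold.
        return self.ckpt_enabled and self.training

wrong = ToyModel()
wrong.gradient_checkpointing_enable()        # train() never called
assert wrong.forward_uses_checkpointing() is False   # silently inactive

right = ToyModel()
right.train()                                # training mode first
right.gradient_checkpointing_enable(
    gradient_checkpointing_kwargs={"use_reentrant": False}
)
assert right.forward_uses_checkpointing() is True
```

This is why the failure produces "confusing OOM errors" rather than an exception: nothing is wrong with the call itself, it simply never takes effect.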

Code Evidence

From `applications/Colossal-LLaMA/train.py:193-198`:

```python
# this is essential, otherwise the grad checkpoint will not work.
model.train()

if args.use_grad_checkpoint:
    model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
    coordinator.print_on_master(msg="Gradient checkpointing enabled successfully")
```

From `applications/ColossalChat/coati/distributed/grpo_consumer.py:74-75`:

```python
self.policy_model.train()
self.policy_model.gradient_checkpointing_enable()
```
