Heuristic:Eric mitchell Direct preference optimization Activation Checkpointing Memory
| Knowledge Sources | |
|---|---|
| Domains | Optimization, Deep_Learning |
| Last Updated | 2026-02-08 02:00 GMT |
Overview
Enable FSDP activation checkpointing to reduce VRAM usage by trading compute for memory, allowing larger batch sizes on memory-constrained GPUs.
Description
Activation checkpointing (also called gradient checkpointing) reduces peak memory usage during training by not keeping most intermediate activations from the forward pass. Only selected tensors are stored (here, the input to each transformer block), and the discarded activations are recomputed during the backward pass. The DPO codebase implements this for FSDPTrainer using PyTorch's `apply_activation_checkpointing` API, with non-reentrant checkpointing at the transformer block level. It is controlled by the `activation_checkpointing` config flag.
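The recompute-instead-of-store mechanism can be seen on a toy model. The sketch below is illustrative only. `ToyBlock`, `ToyModel` and `max_grad_difference` are hypothetical stand-ins for a transformer and its blocks. It calls `torch.utils.checkpoint.checkpoint` with `use_reentrant=False` directly, rather than using the FSDP wrapper API that the DPO code relies on. Its purpose is to show that checkpointed and plain runs produce matching gradients.

```python
import copy

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class ToyBlock(nn.Module):
    """Hypothetical stand-in for a transformer block such as GPT2Block."""

    def __init__(self, dim: int):
        super().__init__()
        self.ff = nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.ff(x)


class ToyModel(nn.Module):
    def __init__(self, dim: int = 16, n_layers: int = 4, use_checkpointing: bool = False):
        super().__init__()
        self.blocks = nn.ModuleList([ToyBlock(dim) for _ in range(n_layers)])
        self.use_checkpointing = use_checkpointing

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for block in self.blocks:
            if self.use_checkpointing and self.training:
                # Non-reentrant checkpoint: keep only the block input now and
                # recompute the block's internal activations during backward.
                x = checkpoint(block, x, use_reentrant=False)
            else:
                x = block(x)
        return x


def max_grad_difference(seed: int = 0) -> float:
    """Run one backward pass with and without checkpointing and compare gradients."""
    torch.manual_seed(seed)
    plain = ToyModel()
    ckpt = copy.deepcopy(plain)  # identical weights
    ckpt.use_checkpointing = True
    x = torch.randn(2, 8, 16)
    plain(x).sum().backward()
    ckpt(x).sum().backward()
    return max(
        (p.grad - q.grad).abs().max().item()
        for p, q in zip(plain.parameters(), ckpt.parameters())
    )


print(max_grad_difference())
```

The printed difference should be zero or within floating-point rounding. The saving comes from not holding each block's feed-forward intermediates between the forward and backward passes.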
Usage
Use this heuristic when VRAM is the bottleneck and you cannot fit batch size >= 2 per GPU even with mixed precision. The README advises: "Activation checkpointing doesn't always increase throughput, but if you're stuck at batch size per GPU of 1, it's worth a try." Only available for FSDPTrainer.
The Insight (Rule of Thumb)
- Action: Pass `activation_checkpointing=true` on the command line.
- Value: Boolean flag, default `false`.
- Trade-off: Significant VRAM reduction (typically 40-60%, depending on model size and sequence length) at the cost of roughly 20-30% longer step times, because activations are recomputed during the backward pass. This is only beneficial if the freed memory lets you increase the per-GPU batch size.
- Compatibility: Only implemented for FSDPTrainer. Uses non-reentrant checkpoint wrapper. Requires `torch.distributed.algorithms._checkpoint.checkpoint_wrapper` module.
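Because checkpointing makes each step slower, the trade-off above is an empirical question: does the larger batch raise examples per second? The following is a minimal sketch that assumes you have already timed short runs with and without checkpointing. `BenchmarkRun` and both helper functions are hypothetical, and the timings in the usage lines are made-up numbers for illustration.

```python
from dataclasses import dataclass


@dataclass
class BenchmarkRun:
    """Measurements from a short benchmark run (hypothetical helper, not from the DPO repo)."""

    per_gpu_batch_size: int
    seconds_per_step: float

    @property
    def examples_per_second(self) -> float:
        return self.per_gpu_batch_size / self.seconds_per_step


def break_even_step_time(baseline: BenchmarkRun, checkpointed_batch_size: int) -> float:
    """Slowest step time at which a checkpointed run still matches baseline throughput."""
    return checkpointed_batch_size / baseline.examples_per_second


def checkpointing_pays_off(baseline: BenchmarkRun, checkpointed: BenchmarkRun) -> bool:
    """True only if the larger batch more than compensates for the recomputation overhead."""
    return checkpointed.examples_per_second > baseline.examples_per_second


baseline = BenchmarkRun(per_gpu_batch_size=1, seconds_per_step=0.5)
print(break_even_step_time(baseline, checkpointed_batch_size=2))  # 1.0
print(checkpointing_pays_off(baseline, BenchmarkRun(2, 0.8)))  # True
print(checkpointing_pays_off(baseline, BenchmarkRun(1, 0.6)))  # False
```

The last line shows the case the rule of thumb warns about. Checkpointing at the same batch size only adds recomputation, so it can never improve throughput on its own.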
Reasoning
Transformer models store large activation tensors (batch_size × seq_length × hidden_dim) at each layer for the backward pass. For large models (6.9B+ parameters), these activations become the primary VRAM bottleneck. By checkpointing at the transformer block level, only the inputs to each block are saved, and the block's internal activations are recomputed during the backward pass.
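The block-level argument can be made concrete with a rough back-of-the-envelope model. This is a simplification under stated assumptions. It counts only tensors of shape batch × seq × hidden and ignores attention-score tensors, which grow with seq_length². The value `tensors_per_layer=16` is an arbitrary illustrative guess, not a measurement. The layer count and hidden size match Pythia-6.9B (32 layers, hidden size 4096), and `bytes_per_elem=2` assumes bf16.

```python
def activation_bytes(
    n_layers: int,
    batch_size: int,
    seq_len: int,
    hidden_dim: int,
    tensors_per_layer: int,
    bytes_per_elem: int = 2,
    checkpoint_blocks: bool = False,
) -> int:
    """Rough activation-memory model; counts only (batch, seq, hidden)-shaped tensors."""
    one_tensor = batch_size * seq_len * hidden_dim * bytes_per_elem
    if not checkpoint_blocks:
        # Every layer keeps all of its intermediate activations until backward.
        return n_layers * tensors_per_layer * one_tensor
    # Each block keeps only its input; during backward, one block at a time
    # recomputes its internals, so roughly one layer's worth is live at once.
    return n_layers * one_tensor + tensors_per_layer * one_tensor


MiB = 2**20
# Pythia-6.9B-like shape; tensors_per_layer=16 is an assumption, not a measured value.
shape = dict(n_layers=32, batch_size=1, seq_len=512, hidden_dim=4096, tensors_per_layer=16)
print(activation_bytes(**shape) / MiB)  # 2048.0
print(activation_bytes(**shape, checkpoint_blocks=True) / MiB)  # 192.0
```

Under this model, stored activations scale as n_layers × tensors_per_layer without checkpointing and as n_layers + tensors_per_layer with it. That difference explains why the saving grows with depth.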
Code evidence from `trainers.py:461-484`:
```python
if config.activation_checkpointing:
    rank0_print('Attempting to enable activation checkpointing...')
    try:
        from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
            checkpoint_wrapper,
            apply_activation_checkpointing,
            CheckpointImpl,
        )
        # Factory that wraps a module with non-reentrant checkpointing and no CPU offload.
        non_reentrant_wrapper = functools.partial(
            checkpoint_wrapper,
            offload_to_cpu=False,
            checkpoint_impl=CheckpointImpl.NO_REENTRANT,
        )
    except Exception as e:
        rank0_print('FSDP activation checkpointing not available:', e)
    else:
        # wrap_class is the model's transformer block class, so only whole blocks are checkpointed.
        check_fn = lambda submodule: isinstance(submodule, wrap_class)
        rank0_print('Applying activation checkpointing wrapper to policy...')
        apply_activation_checkpointing(self.policy, checkpoint_wrapper_fn=non_reentrant_wrapper, check_fn=check_fn)
        rank0_print('FSDP activation checkpointing enabled!')
```
Key implementation details:
- Uses `CheckpointImpl.NO_REENTRANT` (the recommended mode for FSDP).
- `offload_to_cpu=False` — activations are recomputed on GPU, not offloaded to CPU.
- The `check_fn` applies checkpointing only to transformer block instances (e.g., `GPT2Block`, `GPTNeoXLayer`).
- Gracefully degrades: if the checkpoint API is unavailable, it prints a warning and continues without checkpointing.
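The same pattern can also be exercised outside FSDP on a plain `nn.Module`. This is a convenient way to check that `check_fn` selects the intended blocks before launching a full run. The sketch below rests on several assumptions. `ToyBlock` and `enable_block_checkpointing` are hypothetical names. The wrapper omits `offload_to_cpu=False` and relies on the default of no CPU offload. It catches only `ImportError`, rather than the broad `Exception` used in `trainers.py`. Because `_checkpoint` is a private PyTorch module, its behavior may change between releases.

```python
import copy
import functools

import torch
import torch.nn as nn


class ToyBlock(nn.Module):
    """Hypothetical stand-in for a transformer block such as GPTNeoXLayer."""

    def __init__(self, dim: int):
        super().__init__()
        self.ff = nn.Sequential(nn.Linear(dim, 4 * dim), nn.GELU(), nn.Linear(4 * dim, dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + self.ff(x)


def enable_block_checkpointing(model: nn.Module, block_cls: type) -> bool:
    """Wrap every block_cls submodule in a non-reentrant checkpoint wrapper.

    Returns False and leaves the model untouched if the private API is unavailable.
    """
    try:
        from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
            CheckpointImpl,
            apply_activation_checkpointing,
            checkpoint_wrapper,
        )
    except ImportError as e:
        print("activation checkpointing not available:", e)
        return False
    wrapper = functools.partial(checkpoint_wrapper, checkpoint_impl=CheckpointImpl.NO_REENTRANT)
    apply_activation_checkpointing(
        model,
        checkpoint_wrapper_fn=wrapper,
        # Only whole blocks are wrapped; their Linear/GELU children are left alone.
        check_fn=lambda submodule: isinstance(submodule, block_cls),
    )
    return True


torch.manual_seed(0)
model = nn.Sequential(*[ToyBlock(16) for _ in range(4)])
reference = copy.deepcopy(model)  # unwrapped copy with identical weights
enabled = enable_block_checkpointing(model, ToyBlock)

x = torch.randn(2, 8, 16)
model(x).sum().backward()
reference(x).sum().backward()
wrapped = sum(type(m).__name__ == "CheckpointWrapper" for m in model.modules())
print(enabled, wrapped)
```

On an install where the import succeeds, this should report four wrapped blocks. The gradients should also match those of the unwrapped copy.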