Heuristic:FlagOpen FlagEmbedding Gradient Checkpointing Tip
| Knowledge Sources | |
|---|---|
| Domains | Optimization, Memory_Management |
| Last Updated | 2026-02-09 21:00 GMT |
Overview
Enable gradient checkpointing during fine-tuning to cut GPU memory usage by roughly 40-60%, at the cost of roughly 20-30% slower training.
Description
Gradient checkpointing (also called activation checkpointing) is a memory optimization that avoids storing every intermediate activation during the forward pass. Only a subset of activations is kept in memory; the rest are recomputed during the backward pass when their gradients are needed. FlagEmbedding exposes this through the `--gradient_checkpointing` training argument, which is propagated to all model types (encoder-only, decoder-only, and M3 embedders, as well as all reranker variants).
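The store-versus-recompute idea can be illustrated with a dependency-free toy in plain Python (this is not FlagEmbedding or PyTorch code, and every name in it is hypothetical). A chain of scalar `tanh` "layers" is differentiated twice: once keeping every layer input, and once keeping only the input of every `segment`-th layer, recomputing the missing inputs segment by segment during the backward pass. In real training, PyTorch's `torch.utils.checkpoint` plays this role, typically once per Transformer layer.

```python
import math


def layer(w, x):
    # One toy "layer": scale by a scalar weight, then apply tanh.
    return math.tanh(w * x)


def layer_grads(w, x, grad_out):
    # Local backward of y = tanh(w * x); returns (dL/dx, dL/dw).
    y = math.tanh(w * x)
    d_pre = grad_out * (1.0 - y * y)
    return d_pre * w, d_pre * x


def backprop_full(weights, x0):
    # Standard backprop: keep the input of every layer for the backward pass.
    inputs, x = [], x0
    for w in weights:
        inputs.append(x)
        x = layer(w, x)
    grad, w_grads = 1.0, [0.0] * len(weights)  # loss = final output
    for i in reversed(range(len(weights))):
        grad, w_grads[i] = layer_grads(weights[i], inputs[i], grad)
    return w_grads, len(inputs)


def backprop_checkpointed(weights, x0, segment):
    # Forward: keep only the input of every `segment`-th layer (the checkpoints).
    checkpoints, x = {}, x0
    for i, w in enumerate(weights):
        if i % segment == 0:
            checkpoints[i] = x
        x = layer(w, x)
    grad, w_grads = 1.0, [0.0] * len(weights)
    # Backward: visit segments last to first, recomputing each segment's inputs
    # from its checkpoint just before backpropagating through it.
    for start in sorted(checkpoints, reverse=True):
        end = min(start + segment, len(weights))
        seg_inputs = [checkpoints[start]]
        for i in range(start, end - 1):
            seg_inputs.append(layer(weights[i], seg_inputs[-1]))
        for i in reversed(range(start, end)):
            grad, w_grads[i] = layer_grads(weights[i], seg_inputs[i - start], grad)
    return w_grads, len(checkpoints)


weights = [0.5 + 0.1 * i for i in range(12)]
full_grads, kept_full = backprop_full(weights, 0.3)
ckpt_grads, kept_ckpt = backprop_checkpointed(weights, 0.3, segment=4)
```

Both runs yield the same weight gradients, but the checkpointed run keeps 3 layer inputs after the forward pass instead of 12 and pays for it with 9 extra layer evaluations during the backward pass.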
Usage
Use this heuristic when:
- VRAM is limited: Fine-tuning large models (7B+) on consumer GPUs (RTX 3090/4090 with 24GB)
- Maximizing batch size: Fitting a larger per-device batch (and therefore more in-batch negatives) within fixed GPU memory
- Training large models: Any model whose activations would otherwise not fit in available VRAM
The Insight (Rule of Thumb)
- Action: Set `--gradient_checkpointing` in the training script arguments.
- Value: Boolean flag (True/False).
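- Trade-off: Reduces VRAM usage by 40-60% at the cost of approximately 20-30% slower training, because activations are recomputed during the backward pass.
- Companion: Pair it with `--negatives_cross_device`, which gathers embeddings from all GPUs so each query is scored against more negatives, and `--pad_to_multiple_of 8`, which pads sequence lengths to a multiple of 8 so fp16/bf16 matrix multiplications map well onto Tensor Cores.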
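The time side of the trade-off can be checked with a back-of-envelope model. It assumes the common rule of thumb that a backward pass costs about twice the FLOPs of a forward pass, so re-running the whole forward pass once adds about one third more compute per step. The function and its parameters are illustrative assumptions, not FlagEmbedding settings.

```python
def checkpointing_slowdown(backward_to_forward=2.0, recomputed_fraction=1.0,
                           non_compute_fraction=0.0):
    """Rough relative slowdown of one training step caused by recomputation.

    Assumptions (rules of thumb, not measurements):
    - compute per step = forward F + backward (backward_to_forward * F);
    - checkpointing re-runs `recomputed_fraction` of the forward pass;
    - `non_compute_fraction` of step time (data loading, communication,
      optimizer step) is not repeated by recomputation.
    """
    compute_units = 1.0 + backward_to_forward
    extra_units = recomputed_fraction
    return (1.0 - non_compute_fraction) * extra_units / compute_units


print(f"{checkpointing_slowdown():.0%}")                           # 33%
print(f"{checkpointing_slowdown(non_compute_fraction=0.25):.0%}")  # 25%
```

The pure-compute bound of about 33% shrinks toward the quoted 20-30% once part of each step is spent on work that recomputation does not repeat.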
Reasoning
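For Transformer models, activation memory grows linearly with batch size and number of layers, and at least linearly with sequence length (the attention-score tensors grow quadratically with it unless a fused attention kernel such as FlashAttention avoids materializing them). For a 7B-parameter model with 32 layers, storing all activations for a single batch can consume 10-20GB, depending on batch size and sequence length. Gradient checkpointing reduces this to storing only the activations at checkpoint boundaries (in Hugging Face Transformers, the input of each Transformer layer) and recomputing everything inside a layer when the backward pass reaches it.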
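The magnitudes in this paragraph can be estimated with the per-layer activation formula from Korthikanti et al. (2022), "Reducing Activation Recomputation in Large Transformer Models", which covers GPT-style layers with 16-bit activations. Applying it to a 7B-class model (hidden size 4096, 32 heads, 32 layers) is an assumption: LLaMA-style layers have no dropout and a different MLP, so treat the result as an order-of-magnitude figure. The function names are hypothetical.

```python
GIB = 1024 ** 3


def layer_activation_bytes(b, s, h, a):
    # Approximate activation memory of one GPT-style Transformer layer with
    # 16-bit activations (Korthikanti et al., 2022): s*b*h*(34 + 5*a*s/h) bytes.
    # The 5*a*s/h term covers attention scores, which grow with s squared.
    return s * b * h * (34 + 5 * a * s / h)


def activation_bytes(b, s, h, a, n_layers, checkpointing):
    per_layer = layer_activation_bytes(b, s, h, a)
    if not checkpointing:
        # Every layer's activations are kept until the backward pass.
        return n_layers * per_layer
    # Per-layer checkpointing keeps only each layer's 16-bit input (2*s*b*h
    # bytes) plus one layer's activations while that layer is recomputed.
    return n_layers * 2 * s * b * h + per_layer


# 7B-class dimensions (hidden size 4096, 32 heads, 32 layers), batch 4, 512 tokens.
dims = dict(b=4, s=512, h=4096, a=32, n_layers=32)
print(activation_bytes(**dims, checkpointing=False) / GIB)  # 13.5
print(activation_bytes(**dims, checkpointing=True) / GIB)   # 0.921875
```

This gives 13.5 GiB of activations without checkpointing and about 0.92 GiB with it. The whole-GPU saving is smaller than this ratio suggests (the 40-60% quoted earlier) because weights, gradients, and optimizer states are not affected by checkpointing.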
All FlagEmbedding model classes implement `gradient_checkpointing_enable()`:
```python
# From FlagEmbedding/finetune/embedder/encoder_only/base/modeling.py:168-172
def gradient_checkpointing_enable(self, **kwargs):
    """
    Activates gradient checkpointing for the current model.
    """
    self.model.gradient_checkpointing_enable(**kwargs)
```
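The delegation pattern can be exercised without a GPU by swapping the Hugging Face model for a stub. Both classes below are hypothetical stand-ins: `BackboneStub` mimics the `gradient_checkpointing_enable(gradient_checkpointing_kwargs=None)` method that `transformers.PreTrainedModel` provides, and `EmbedderWrapperSketch` mirrors the FlagEmbedding method above. Forwarding `**kwargs` lets options such as `{"use_reentrant": False}` (a `torch.utils.checkpoint` setting) reach the backbone unchanged.

```python
class BackboneStub:
    """Hypothetical stand-in for a Hugging Face PreTrainedModel."""

    def __init__(self):
        self.is_gradient_checkpointing = False
        self.checkpointing_kwargs = None

    def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):
        # The real method wraps each Transformer layer with activation
        # checkpointing; this stub only records that it was called and how.
        self.is_gradient_checkpointing = True
        self.checkpointing_kwargs = gradient_checkpointing_kwargs


class EmbedderWrapperSketch:
    """Hypothetical wrapper mirroring FlagEmbedding's delegating method."""

    def __init__(self, model):
        self.model = model

    def gradient_checkpointing_enable(self, **kwargs):
        # Forward every keyword argument unchanged to the wrapped backbone.
        self.model.gradient_checkpointing_enable(**kwargs)


wrapper = EmbedderWrapperSketch(BackboneStub())
wrapper.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
```

A plausible reason for the pass-through (an inference, not stated in the source): recent versions of the Hugging Face `Trainer` call `model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=...)` on the model they are given, which here is the wrapper rather than the backbone.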
The runners activate it conditionally:
```python
# From FlagEmbedding/finetune/embedder/encoder_only/base/runner.py:60
if self.training_args.gradient_checkpointing:
    model.gradient_checkpointing_enable()
```
Example shell script usage from `examples/finetune/embedder/encoder_only/base.sh` (abridged to the arguments discussed here):
```shell
torchrun --nproc_per_node 2 \
    -m FlagEmbedding.finetune.embedder.encoder_only.base \
    --gradient_checkpointing True \
    --negatives_cross_device \
    --temperature 0.02
```
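The command passes `--gradient_checkpointing True` but a bare `--negatives_cross_device`. Both spellings work because Hugging Face's `HfArgumentParser` (which, as we understand it, FlagEmbedding's entry points use) registers boolean fields with an optional value that becomes `True` when omitted. The sketch below reproduces that behaviour with the standard-library `argparse` and then applies the runner's conditional; `str_to_bool`, `build_parser`, and `ModelStub` are hypothetical names, not FlagEmbedding code.

```python
import argparse


def str_to_bool(value):
    # Hypothetical helper: accept common spellings of a boolean CLI value.
    lowered = str(value).lower()
    if lowered in ("yes", "true", "t", "y", "1"):
        return True
    if lowered in ("no", "false", "f", "n", "0"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean, got {value!r}")


def build_parser():
    parser = argparse.ArgumentParser()
    for flag in ("--gradient_checkpointing", "--negatives_cross_device"):
        # nargs="?" + const=True: a bare flag means True, "--flag False" means False.
        parser.add_argument(flag, type=str_to_bool, nargs="?", const=True, default=False)
    parser.add_argument("--temperature", type=float)
    return parser


class ModelStub:
    """Hypothetical stand-in for an embedder model."""

    def __init__(self):
        self.checkpointing = False

    def gradient_checkpointing_enable(self, **kwargs):
        self.checkpointing = True


args = build_parser().parse_args(
    ["--gradient_checkpointing", "True", "--negatives_cross_device", "--temperature", "0.02"]
)
model = ModelStub()
# Same conditional as the runner: only enable checkpointing when the flag is set.
if args.gradient_checkpointing:
    model.gradient_checkpointing_enable()
```

Omitting `--gradient_checkpointing` entirely leaves the default `False`, so the model is trained without recomputation.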