Heuristic:FlagOpen FlagEmbedding Gradient Checkpointing Tip

From Leeroopedia



Knowledge Sources
Domains: Optimization, Memory_Management
Last Updated: 2026-02-09 21:00 GMT

Overview

Enable gradient checkpointing during fine-tuning to reduce GPU memory usage by 40-60% at the cost of roughly 20-30% slower training.

Description

Gradient checkpointing (also called activation checkpointing) is a memory optimization that avoids storing every intermediate activation during the forward pass. Only a subset of activations is kept in memory; the rest are recomputed during the backward pass when the gradients need them. FlagEmbedding supports this through the `--gradient_checkpointing` training argument, which is propagated to all model types: encoder-only, decoder-only, and M3 embedders, as well as all reranker variants.
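The mechanism can be illustrated without any deep-learning framework. The sketch below is a toy, not FlagEmbedding's or PyTorch's implementation: it assumes a hypothetical chain of scalar layers y → sin(w·y), keeps only the input of every `segment`-th layer during the forward pass, and recomputes each segment from that checkpoint during the backward pass. Both versions return the gradient for every weight plus the peak number of activations held, so the two strategies can be compared.

```python
import math

# Hypothetical toy model: a chain of scalar layers, y_out = sin(w * y_in).


def layer(w, y):
    return math.sin(w * y)


def layer_backward(w, y, grad_out):
    """Backprop through one layer; returns (grad w.r.t. input y, grad w.r.t. w)."""
    d = math.cos(w * y) * grad_out
    return d * w, d * y


def grads_full(weights, x):
    """Baseline: keep every layer input from the forward pass."""
    inputs, y = [], x
    for w in weights:
        inputs.append(y)
        y = layer(w, y)
    grad, w_grads = 1.0, [0.0] * len(weights)  # start from d(out)/d(out) = 1
    for i in reversed(range(len(weights))):
        grad, w_grads[i] = layer_backward(weights[i], inputs[i], grad)
    return w_grads, len(inputs)  # number of stored activations


def grads_checkpointed(weights, x, segment):
    """Keep only each segment's input; recompute the rest during backward."""
    n, checkpoints, y = len(weights), {}, x
    for i, w in enumerate(weights):
        if i % segment == 0:
            checkpoints[i] = y  # checkpoint: the input of a segment
        y = layer(w, y)  # all other activations are discarded
    grad, w_grads, peak_recomputed = 1.0, [0.0] * n, 0
    for start in sorted(checkpoints, reverse=True):
        end = min(start + segment, n)
        local = [checkpoints[start]]  # recompute this segment's layer inputs
        for i in range(start, end - 1):
            local.append(layer(weights[i], local[-1]))
        peak_recomputed = max(peak_recomputed, len(local) - 1)
        for i in reversed(range(start, end)):
            grad, w_grads[i] = layer_backward(weights[i], local[i - start], grad)
    # Peak = all checkpoints plus the largest set of recomputed activations.
    return w_grads, len(checkpoints) + peak_recomputed


weights = [0.5 + 0.01 * i for i in range(16)]
g_full, stored_full = grads_full(weights, 0.3)
g_ckpt, stored_ckpt = grads_checkpointed(weights, 0.3, segment=4)
print(stored_full, stored_ckpt)  # 16 7
print(all(math.isclose(a, b) for a, b in zip(g_full, g_ckpt)))  # True
```

With 16 layers and `segment=4`, the checkpointed run holds at most 7 activations instead of 16 and returns the same gradients, at the price of re-running most of each segment's forward pass.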

Usage

Use this heuristic when:

  • VRAM is limited: Fine-tuning large models (7B+) on consumer GPUs such as the RTX 3090/4090 (24 GB)
  • Maximizing batch size: You want a larger per-device batch size within a fixed GPU memory budget
  • Training large models: Any model whose activation memory would otherwise exceed available VRAM

The Insight (Rule of Thumb)

  • Action: Set `--gradient_checkpointing` in the training script arguments.
  • Value: Boolean flag (True/False).
  • Trade-off: Reduces VRAM usage by 40-60% at the cost of approximately 20-30% slower training due to recomputation of activations during backward pass.
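  • Companion: Combine with `--negatives_cross_device` (which shares in-batch negatives across GPUs) and `--pad_to_multiple_of 8` (which pads sequence lengths to multiples of 8, a shape that fp16 Tensor Core kernels handle efficiently).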
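The 20-30% slowdown is consistent with a common back-of-the-envelope estimate. Assuming a backward pass costs about twice a forward pass, recomputing the forward activations once turns a step of about three forward-pass units into about four, an upper bound of roughly 33%. The helper below is a hypothetical sketch of that arithmetic, not a FlagEmbedding measurement; `recomputed_fraction` models the share of forward compute that actually sits inside checkpointed layers.

```python
def checkpointing_overhead(backward_to_forward=2.0, recomputed_fraction=1.0):
    """Estimated fractional slowdown of one training step due to recomputation.

    Assumed rule of thumb: the backward pass costs `backward_to_forward`
    forward passes, and `recomputed_fraction` of the forward pass is run a
    second time during backward.
    """
    baseline = 1.0 + backward_to_forward  # forward + backward, in forward-pass units
    with_checkpointing = baseline + recomputed_fraction  # plus the recomputed forward
    return with_checkpointing / baseline - 1.0


print(f"{checkpointing_overhead():.0%}")  # 33%
print(f"{checkpointing_overhead(recomputed_fraction=0.75):.0%}")  # 25%
```

Measured slowdowns usually land below this bound because step time also includes data loading, inter-GPU communication, and optimizer updates, none of which are repeated.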

Reasoning

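For Transformer models, activation memory grows linearly with batch size and number of layers, and at least linearly with sequence length (quadratically for the attention-score tensors of standard, non-fused attention). For a 7B-parameter model with 32 layers, storing all activations for a single batch can consume 10-20 GB. In the Hugging Face implementation that FlagEmbedding delegates to, gradient checkpointing treats each transformer layer as a checkpoint: only the hidden states entering each layer are stored, and the activations inside a layer (attention scores, MLP intermediates) are recomputed during the backward pass. Weights, gradients, and optimizer states are unaffected, which is why the overall VRAM saving is smaller than the reduction in activation memory.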
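To get a feel for the numbers, the sketch below estimates activation memory for an assumed 7B-style decoder (hidden size h = 4096, a = 32 attention heads, 32 layers) with a batch of b = 8 sequences of s = 512 tokens. The per-layer estimate of s·b·h·(34 + 5·a·s/h) bytes for fp16 activations with standard attention is an assumption borrowed from Korthikanti et al.'s (2022) analysis of GPT-style layers, not a FlagEmbedding figure; fused attention kernels shrink the quadratic term, and exact numbers vary by architecture.

```python
def layer_activation_bytes(seq_len, batch, hidden, heads):
    """Assumed fp16 activation footprint of one transformer layer with standard
    attention: s*b*h*(34 + 5*a*s/h) bytes, written in integer form."""
    return 34 * seq_len * batch * hidden + 5 * heads * seq_len * seq_len * batch


def activation_bytes(seq_len, batch, hidden, heads, layers, checkpointing):
    per_layer = layer_activation_bytes(seq_len, batch, hidden, heads)
    if not checkpointing:
        return layers * per_layer  # every layer's intermediates kept for backward
    # Per-layer checkpointing: keep each layer's fp16 input (2 bytes per element),
    # plus the intermediates of the single layer being recomputed at a time.
    return 2 * seq_len * batch * hidden * layers + per_layer


cfg = dict(seq_len=512, batch=8, hidden=4096, heads=32, layers=32)
full = activation_bytes(**cfg, checkpointing=False)
ckpt = activation_bytes(**cfg, checkpointing=True)
print(f"without checkpointing: {full / 1e9:.1f} GB")  # 29.0 GB
print(f"with checkpointing:    {ckpt / 1e9:.1f} GB")  # 2.0 GB
```

Under these assumptions, per-layer checkpointing shrinks activation memory by more than an order of magnitude, while the weights, gradients, and optimizer states that make up the rest of the VRAM budget stay the same size.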

All FlagEmbedding model classes implement `gradient_checkpointing_enable()`:

```python
# From FlagEmbedding/finetune/embedder/encoder_only/base/modeling.py:168-172
def gradient_checkpointing_enable(self, **kwargs):
    """
    Activates gradient checkpointing for the current model.
    """
    self.model.gradient_checkpointing_enable(**kwargs)
```

The runners activate it conditionally:

```python
# From FlagEmbedding/finetune/embedder/encoder_only/base/runner.py:60
if self.training_args.gradient_checkpointing:
    model.gradient_checkpointing_enable()
```
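The two excerpts form a short delegation chain: the runner checks the flag, and the wrapper forwards the call to the underlying Hugging Face model. The self-contained sketch below reproduces that chain with stub classes so it runs without a GPU or model download. `StubBackbone`, `StubEmbedder`, `StubTrainingArgs`, and `maybe_enable_checkpointing` are hypothetical names, not FlagEmbedding or Hugging Face APIs; the stub backbone only records the call that a real model would act on.

```python
from dataclasses import dataclass


class StubBackbone:
    """Hypothetical stand-in for a Hugging Face model; records how it was enabled."""

    def __init__(self):
        self.checkpointing_enabled = False
        self.enable_kwargs = None

    def gradient_checkpointing_enable(self, **kwargs):
        self.checkpointing_enabled = True
        self.enable_kwargs = kwargs


class StubEmbedder:
    """Hypothetical wrapper mirroring the delegation shown in modeling.py."""

    def __init__(self, model):
        self.model = model

    def gradient_checkpointing_enable(self, **kwargs):
        # Pass the call and any options straight through to the backbone.
        self.model.gradient_checkpointing_enable(**kwargs)


@dataclass
class StubTrainingArgs:
    gradient_checkpointing: bool = False


def maybe_enable_checkpointing(model, training_args):
    """Mirror of the runner's conditional activation."""
    if training_args.gradient_checkpointing:
        model.gradient_checkpointing_enable()
    return model


on = maybe_enable_checkpointing(StubEmbedder(StubBackbone()), StubTrainingArgs(True))
off = maybe_enable_checkpointing(StubEmbedder(StubBackbone()), StubTrainingArgs())
print(on.model.checkpointing_enabled, off.model.checkpointing_enabled)  # True False

direct = StubEmbedder(StubBackbone())
direct.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
print(direct.model.enable_kwargs)  # {'gradient_checkpointing_kwargs': {'use_reentrant': False}}
```

Because the wrapper forwards `**kwargs` unchanged, options such as `gradient_checkpointing_kwargs={"use_reentrant": False}`, which recent Transformers versions accept, would reach the real backbone the same way.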

Example shell script usage from `examples/finetune/embedder/encoder_only/base.sh`:

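```shell
# Abbreviated excerpt: the model, data, and output arguments from base.sh are not shown.
torchrun --nproc_per_node 2 \
    -m FlagEmbedding.finetune.embedder.encoder_only.base \
    --gradient_checkpointing True \
    --negatives_cross_device \
    --temperature 0.02
```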
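Part of the payoff from freed memory in embedding fine-tuning is a larger pool of negatives per query, which is where `--negatives_cross_device` comes in. The hypothetical helper below counts them under stated assumptions: each query brings `group_size` passages (one positive plus hard negatives, assumed to correspond to FlagEmbedding's `--train_group_size`), in-batch negatives are used, and the cross-device option gathers passages from every GPU before scoring.

```python
def negatives_per_query(queries_per_device, group_size, world_size, cross_device):
    """Rough count of negatives each query is contrasted against.

    Assumes in-batch negatives: each query is scored against every passage in the
    batch (gathered from all GPUs when `cross_device` is True), and exactly one of
    those passages is its positive.
    """
    devices = world_size if cross_device else 1
    passages = queries_per_device * group_size * devices
    return passages - 1


# Two GPUs, as in the torchrun example; batch and group sizes are assumptions.
print(negatives_per_query(4, 8, world_size=2, cross_device=False))  # 31
print(negatives_per_query(4, 8, world_size=2, cross_device=True))   # 63
print(negatives_per_query(8, 8, world_size=2, cross_device=True))   # 127
```

Under these assumptions, doubling the per-device batch, which gradient checkpointing can make affordable, roughly doubles the negatives each query sees.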
