

Heuristic:Huggingface Transformers FSDP Activation Checkpointing Tip

Domains: Distributed_Training, Optimization, FSDP
Last Updated: 2026-02-13 20:00 GMT

Overview

When using FSDP with full_shard (or hybrid_shard), use activation_checkpointing in fsdp_config instead of gradient_checkpointing in TrainingArguments to avoid redundant AllGather operations.

Description

There is a subtle but important difference between gradient_checkpointing=True in TrainingArguments and activation_checkpointing in fsdp_config when FSDP uses full sharding. The former enables the model's own gradient checkpointing, built on PyTorch's standard checkpointing, which is unaware of FSDP sharding; recomputing activations in the backward pass therefore triggers redundant AllGather collectives. The latter uses FSDP-native activation checkpointing, which coordinates with the sharding strategy and avoids the extra communication.

Usage

Apply this when using FSDP with the full_shard or hybrid_shard sharding strategy and you want to enable activation checkpointing. Transformers logs a warning when gradient_checkpointing=True in TrainingArguments is combined with either strategy.

The Insight (Rule of Thumb)

  • Action: Remove gradient_checkpointing=True from TrainingArguments. Instead, set activation_checkpointing=True inside fsdp_config.
  • Value: Avoids one redundant AllGather per checkpointed layer per training step.
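  • Trade-off: No extra trade-off relative to the TrainingArguments flag. Both recompute activations to save memory, but the FSDP-native path skips the redundant communication, so it is strictly better under FSDP full shard.
  • Scope: Only applies when using FSDP with full_shard or hybrid_shard. For non-FSDP training or the shard_grad_op strategy, the gradient_checkpointing flag in TrainingArguments is fine.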
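As a minimal sketch, the recommended settings can be written as the keyword arguments one would pass to TrainingArguments. The fsdp string "full_shard auto_wrap", the output_dir value, and the commented-out transformer_layer_cls_to_wrap entry are illustrative assumptions, not values taken from this page; the plain-dict form keeps the snippet runnable without Transformers installed.

```python
# Keyword arguments for transformers.TrainingArguments (sketch; values are illustrative).
training_kwargs = {
    "output_dir": "out",             # assumption: any output directory
    "fsdp": "full_shard auto_wrap",  # full sharding: the case this tip targets
    "fsdp_config": {
        # FSDP-native activation checkpointing, as recommended above.
        "activation_checkpointing": True,
        # "transformer_layer_cls_to_wrap": ["LlamaDecoderLayer"],  # assumption: model-specific
    },
    # Keep the TrainingArguments flag off; Trainer checks for both being enabled.
    "gradient_checkpointing": False,
}

# With Transformers installed:
# from transformers import TrainingArguments
# args = TrainingArguments(**training_kwargs)
```

The only change from a typical gradient-checkpointing setup is where the flag lives: it moves from the top level into fsdp_config.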

Reasoning

With FSDP full sharding, model parameters are sharded across GPUs. During the forward pass, FSDP performs an AllGather to reconstruct each unit's full parameters and frees them again afterward. With standard gradient checkpointing (from TrainingArguments), the recomputation during the backward pass triggers an additional AllGather for each checkpointed segment, because the standard checkpointing mechanism does not coordinate with FSDP's parameter lifecycle. FSDP-native activation checkpointing avoids this by integrating the checkpoint/recompute schedule with FSDP's parameter gather schedule, so the recomputation reuses the parameters already gathered for the backward pass.
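The arithmetic behind the "one extra AllGather per checkpointed layer" claim can be checked with a simplified counting model. The function name allgathers_per_step is hypothetical, and the model assumes one FSDP unit per checkpointed layer with parameters resharded after each forward and backward use; it ignores prefetching and the root FSDP unit, so it illustrates the difference rather than measuring real traffic.

```python
def allgathers_per_step(num_layers: int, fsdp_native: bool) -> int:
    """Count parameter AllGathers in one training step (simplified, hypothetical model).

    Assumes one FSDP unit per checkpointed layer under full_shard,
    with no prefetching and no root-module effects.
    """
    forward = num_layers   # one AllGather per unit in the forward pass
    backward = num_layers  # one AllGather per unit before computing its gradients
    # Non-FSDP-aware checkpointing re-runs each layer's forward through the FSDP
    # unit during backward, gathering the parameters a second time.
    recompute = 0 if fsdp_native else num_layers
    return forward + backward + recompute


for native in (False, True):
    print(f"fsdp_native={native}: {allgathers_per_step(32, native)} AllGathers")
# fsdp_native=False: 96 AllGathers
# fsdp_native=True: 64 AllGathers
```

For a 32-layer model the difference is exactly 32, one per checkpointed layer per step.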

Code Evidence

Warning from src/transformers/training_args.py:2665-2672:

if self.gradient_checkpointing and (
    FSDPOption.FULL_SHARD in self.fsdp or FSDPOption.HYBRID_SHARD in self.fsdp
):
    logger.warning(
        "When using FSDP full shard, instead of using `gradient_checkpointing` in TrainingArguments, please"
        " use `activation_checkpointing` in `fsdp_config`. The former introduces a redundant AllGather"
        " operation in backward pass. Reference: https://github.com/huggingface/transformers/issues/30404"
    )

Conflict check in Trainer from src/transformers/trainer.py:805-807:

if fsdp_plugin.activation_checkpointing and self.args.gradient_checkpointing:
    raise ValueError(
        "The activation_checkpointing in FSDP config and the gradient_checkpointing in training arg "  # message truncated in this excerpt
    )

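The two checks in these excerpts can be mirrored in a pre-flight helper that runs before building TrainingArguments. The function resolve_checkpointing, its string/dict interface, and its behavior of rewriting the settings (instead of only reporting the problem) are hypothetical choices for illustration; it reproduces only the conditions visible in the excerpts: full_shard or hybrid_shard in fsdp, and both flags set at once.

```python
import logging

logger = logging.getLogger(__name__)

# Sharding options for which training_args.py warns about gradient_checkpointing.
FULL_SHARD_OPTIONS = {"full_shard", "hybrid_shard"}


def resolve_checkpointing(fsdp: str, gradient_checkpointing: bool, fsdp_config: dict) -> tuple:
    """Hypothetical helper: return (gradient_checkpointing, fsdp_config) following the tip.

    fsdp is a space-separated option string such as "full_shard auto_wrap".
    """
    options = set(fsdp.split())
    config = dict(fsdp_config)  # copy so the caller's dict is not mutated
    if not options:
        # Not using FSDP: the TrainingArguments flag is fine as-is.
        return gradient_checkpointing, config
    if gradient_checkpointing and config.get("activation_checkpointing"):
        # Mirrors the Trainer conflict check: never enable both flags.
        logger.warning("Both flags set; keeping only fsdp_config['activation_checkpointing'].")
        return False, config
    if gradient_checkpointing and options & FULL_SHARD_OPTIONS:
        # Mirrors the training_args.py warning: move the request into fsdp_config.
        logger.warning("Moving gradient_checkpointing to fsdp_config['activation_checkpointing'].")
        config["activation_checkpointing"] = True
        return False, config
    # shard_grad_op, or no checkpointing requested: leave settings unchanged.
    return gradient_checkpointing, config


gc, cfg = resolve_checkpointing("full_shard auto_wrap", True, {})
print(gc, cfg)  # False {'activation_checkpointing': True}
```

Returning new values rather than mutating the inputs keeps the helper easy to test; the resulting flags would then be passed to TrainingArguments.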