# Heuristic: Hugging Face Transformers FSDP Activation Checkpointing Tip
| Knowledge Sources | |
|---|---|
| Domains | Distributed_Training, Optimization, FSDP |
| Last Updated | 2026-02-13 20:00 GMT |
## Overview
When using FSDP full shard, use `activation_checkpointing` in `fsdp_config` instead of `gradient_checkpointing` in `TrainingArguments` to avoid redundant AllGather operations in the backward pass.
## Description
There is a subtle but important difference between `gradient_checkpointing=True` in `TrainingArguments` and `activation_checkpointing` in `fsdp_config` when using FSDP with full sharding. The former uses PyTorch's built-in gradient checkpointing, which is unaware of FSDP sharding and causes redundant AllGather collective operations during the backward pass. The latter uses FSDP-native activation checkpointing, which coordinates with the sharding strategy and avoids the unnecessary communication overhead.
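As an illustration, the recommended setup can be sketched as plain keyword arguments for `TrainingArguments` (a sketch only; the exact accepted keys depend on your Transformers version, and `transformers` itself is not imported here):

```python
# Sketch of the recommended FSDP configuration, assuming the Transformers
# convention that `fsdp_config` is a dict of FSDP-specific options.
# These kwargs would be passed as TrainingArguments(**training_kwargs).

training_kwargs = {
    "output_dir": "out",
    "fsdp": "full_shard",
    # Correct: FSDP-native activation checkpointing lives in fsdp_config.
    "fsdp_config": {"activation_checkpointing": True},
    # Wrong for FSDP full shard: do NOT also set gradient_checkpointing=True
    # at the top level; that is what triggers the redundant AllGather warning.
    "gradient_checkpointing": False,
}
```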
## Usage
Apply this when using FSDP with the `full_shard` or `hybrid_shard` sharding strategies and you want to enable activation checkpointing. The Transformers library emits a warning when it detects this misconfiguration.
## The Insight (Rule of Thumb)
- Action: Remove `gradient_checkpointing=True` from `TrainingArguments`. Instead, set `activation_checkpointing=True` inside `fsdp_config`.
- Value: Avoids one redundant AllGather per checkpointed layer per training step.
- Trade-off: None; this is strictly better when using FSDP full shard.
- Scope: Only applies when using FSDP with `full_shard` or `hybrid_shard`. For non-FSDP setups or `shard_grad_op`, the `TrainingArguments` `gradient_checkpointing` flag is fine.
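The scope rule above can be expressed as a small decision helper. This is a hypothetical function for illustration only, not part of the Transformers API:

```python
from typing import Optional


def checkpointing_kwargs(fsdp_strategy: Optional[str]) -> dict:
    """Decide where to enable checkpointing, per the scope rule:
    FSDP full_shard/hybrid_shard -> use fsdp_config's activation_checkpointing;
    otherwise the plain TrainingArguments gradient_checkpointing flag is fine.
    """
    if fsdp_strategy in ("full_shard", "hybrid_shard"):
        return {
            "fsdp_config": {"activation_checkpointing": True},
            "gradient_checkpointing": False,
        }
    # Non-FSDP or shard_grad_op: standard gradient checkpointing is fine.
    return {"gradient_checkpointing": True}
```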
## Reasoning
With FSDP full sharding, model parameters are sharded across GPUs. During the forward pass, FSDP performs an AllGather to reconstruct full parameters. With standard gradient checkpointing (from `TrainingArguments`), the recomputation during the backward pass triggers an additional AllGather for each checkpointed segment, because the standard checkpointing mechanism does not coordinate with FSDP's parameter lifecycle. FSDP-native activation checkpointing avoids this by integrating the checkpoint/recompute schedule with FSDP's parameter-gather schedule.
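A toy cost model makes the per-layer difference concrete. This is illustrative only; real FSDP communication schedules depend on wrapping policy, prefetching, and other settings:

```python
def allgathers_per_step(num_layers: int, fsdp_native_ckpt: bool) -> int:
    """Count AllGather ops per training step in a simplified model where
    each layer is one FSDP unit and one checkpointed segment.

    - Forward: 1 AllGather per layer to materialize sharded params.
    - Backward: 1 AllGather per layer for the gradient computation.
    - Naive (TrainingArguments) checkpointing recomputes the forward
      without coordinating with FSDP, costing 1 extra AllGather per layer.
    """
    forward = num_layers
    backward = num_layers
    extra_recompute = 0 if fsdp_native_ckpt else num_layers
    return forward + backward + extra_recompute


naive = allgathers_per_step(32, fsdp_native_ckpt=False)   # 96
native = allgathers_per_step(32, fsdp_native_ckpt=True)   # 64
```

Under this model, a 32-layer model saves 32 AllGathers per step, matching the "one redundant AllGather per checkpointed layer" rule of thumb above.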
## Code Evidence
Warning from `src/transformers/training_args.py:2665-2672`:

```python
if self.gradient_checkpointing and (
    FSDPOption.FULL_SHARD in self.fsdp or FSDPOption.HYBRID_SHARD in self.fsdp
):
    logger.warning(
        "When using FSDP full shard, instead of using `gradient_checkpointing` in TrainingArguments, please"
        " use `activation_checkpointing` in `fsdp_config`. The former introduces a redundant AllGather"
        " operation in backward pass. Reference: https://github.com/huggingface/transformers/issues/30404"
    )
```
Conflict check in Trainer, from `src/transformers/trainer.py:805-807` (excerpt; the error message is truncated in the source):

```python
if fsdp_plugin.activation_checkpointing and self.args.gradient_checkpointing:
    raise ValueError(
        "The activation_checkpointing in FSDP config and the gradient_checkpointing in training arg "
        # ... (message truncated in the source excerpt)
    )
```