Principle: Hugging Face Transformers Fully Sharded Data Parallelism
| Knowledge Sources | |
|---|---|
| Domains | Distributed_Computing, Training |
| Last Updated | 2026-02-13 00:00 GMT |
Overview
Fully Sharded Data Parallelism replicates or shards a model across data-parallel ranks to enable synchronized training on different data subsets while reducing per-GPU memory usage.
Description
Data Parallelism (DP) is the most fundamental distributed training strategy: each GPU holds a copy of the model and processes a different mini-batch of data. After the backward pass, gradients are averaged across all data-parallel ranks using an all-reduce operation, ensuring that all replicas converge to the same weights.
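The gradient-averaging step above can be sketched as a plain-Python simulation (this mimics the effect of the all-reduce, not the actual torch.distributed call):

```python
def all_reduce_mean(per_rank_grads):
    """Simulate a gradient all-reduce: every rank ends up holding the
    element-wise mean of all ranks' local gradients."""
    world_size = len(per_rank_grads)
    n = len(per_rank_grads[0])
    # Sum corresponding gradient entries across ranks, then average.
    averaged = [sum(rank[i] for rank in per_rank_grads) / world_size
                for i in range(n)]
    # After the all-reduce, every rank holds the same averaged gradient,
    # so all replicas apply the same update and stay in sync.
    return [list(averaged) for _ in range(world_size)]

# Two ranks computed different local gradients on different mini-batches.
local = [[1.0, 2.0], [3.0, 6.0]]
synced = all_reduce_mean(local)
# Every rank now holds [2.0, 4.0].
```

Because every replica receives the identical averaged gradient, all model copies take the same optimizer step and remain bit-wise synchronized.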
Fully Sharded Data Parallelism (FSDP) extends this concept by optionally sharding model parameters, gradients, and optimizer states across data-parallel ranks, reducing per-GPU memory consumption. PyTorch's FullyShardedDataParallel wrapper provides configurable sharding strategies:
- FULL_SHARD: Shards parameters, gradients, and optimizer states (equivalent to ZeRO Stage 3). Minimizes memory but requires an all-gather to reconstruct full parameters before each forward and backward pass, and a reduce-scatter to re-shard gradients afterward.
- SHARD_GRAD_OP: Shards gradients and optimizer states only (equivalent to ZeRO Stage 2).
- NO_SHARD: No parameter sharding; behaves like standard DDP (Distributed Data Parallel) with replicated parameters and all-reduce for gradient synchronization.
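The memory implications of the three strategies can be made concrete with a rough per-rank estimator. This is a back-of-the-envelope sketch, not the FSDP API; the function name and default byte counts (fp32 parameters/gradients, an Adam-style optimizer holding two fp32 moments) are illustrative assumptions:

```python
def per_rank_bytes(num_params, world_size, strategy,
                   bytes_per_param=4, optim_bytes_per_param=8):
    """Rough steady-state per-rank memory for parameters, gradients,
    and optimizer state under each FSDP sharding strategy.
    Assumes fp32 params/grads and two fp32 optimizer moments."""
    p = g = o = 1.0  # fraction of each component held on one rank
    if strategy == "FULL_SHARD":       # ZeRO-3: shard everything
        p = g = o = 1.0 / world_size
    elif strategy == "SHARD_GRAD_OP":  # ZeRO-2: shard grads + optim state
        g = o = 1.0 / world_size
    elif strategy != "NO_SHARD":       # NO_SHARD: full replica per rank
        raise ValueError(f"unknown strategy: {strategy}")
    return num_params * (p * bytes_per_param
                         + g * bytes_per_param
                         + o * optim_bytes_per_param)

# A 1B-parameter model on 8 ranks under each strategy, in GiB:
for s in ("NO_SHARD", "SHARD_GRAD_OP", "FULL_SHARD"):
    print(s, round(per_rank_bytes(1_000_000_000, 8, s) / 2**30, 2), "GiB")
```

Note how SHARD_GRAD_OP already recovers most of the savings when the optimizer state dominates, while FULL_SHARD divides every component by the world size.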
In the 3D parallel training setup, FSDP operates along the DP sub-mesh -- the set of ranks that share the same TP and CP coordinates. When combined with tensor parallelism (which already reduces per-GPU parameter memory), NO_SHARD is often sufficient because the TP-sharded model is small enough to replicate across DP ranks. The FSDP wrapper still provides the essential gradient synchronization infrastructure.
Usage
Use FSDP wrapping when:
- Training with multiple data-parallel replicas that need gradient synchronization.
- You want to leverage PyTorch's built-in DDP/FSDP infrastructure rather than manual all-reduce calls.
- Memory is constrained and parameter/gradient sharding is needed (use FULL_SHARD or SHARD_GRAD_OP).
- You are in a 3D parallel setup where TP already reduces per-GPU model size; NO_SHARD then provides DDP-like behavior with simpler semantics.
The FSDP wrapper is applied after model loading (including TP sharding) and before the training loop begins.
Theoretical Basis
FSDP is based on the ZeRO (Zero Redundancy Optimizer) family of techniques from DeepSpeed (Rajbhandari et al., 2020):
- ZeRO Stage 1: Partitions optimizer states across ranks.
- ZeRO Stage 2: Additionally partitions gradients.
- ZeRO Stage 3: Additionally partitions parameters.
The fundamental insight is that in standard data parallelism, every rank stores a complete copy of all parameters, gradients, and optimizer states, creating N-fold redundancy for N ranks. ZeRO eliminates this redundancy by distributing these components and reconstructing them on demand through communication.
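The redundancy argument can be made quantitative with the accounting from Rajbhandari et al. (2020): mixed-precision Adam stores roughly 16 bytes per parameter (2 for fp16 weights, 2 for fp16 gradients, 12 for fp32 master weights, momentum, and variance). The helper below is a sketch of that arithmetic, not library code:

```python
def zero_per_rank_gib(psi, n, stage):
    """Per-rank memory (GiB) for model states under ZeRO, following the
    16*psi mixed-precision Adam accounting of Rajbhandari et al. (2020):
    2 bytes fp16 params + 2 bytes fp16 grads + 12 bytes optimizer state
    (fp32 master params, momentum, variance). Stage 0 is plain DP."""
    params, grads, optim = 2 * psi, 2 * psi, 12 * psi
    if stage >= 1:   # Stage 1 partitions optimizer states
        optim /= n
    if stage >= 2:   # Stage 2 additionally partitions gradients
        grads /= n
    if stage >= 3:   # Stage 3 additionally partitions parameters
        params /= n
    return (params + grads + optim) / 2**30

# The paper's running example, 7.5B parameters on 64 ranks: plain DP
# needs ~112 GiB of model state per rank; Stage 3 needs under 2 GiB.
print(zero_per_rank_gib(7.5e9, 64, 0), zero_per_rank_gib(7.5e9, 64, 3))
```

The N-fold redundancy is visible directly: each successive stage divides one more component by the world size, at the cost of extra communication to reconstruct it on demand.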
With ShardingStrategy.NO_SHARD, FSDP degenerates to standard DDP: all parameters are replicated, and an all-reduce (sum followed by division by world_size) synchronizes gradients after the backward pass. This is the simplest correct approach when combined with TP, because TP already partitions the large weight matrices across devices.
The interaction between DP and TP is managed through the device mesh: FSDP only communicates within the DP sub-mesh, ensuring that gradient all-reduce does not interfere with TP communication.
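To illustrate the sub-mesh idea, the sketch below groups global ranks into DP groups: ranks sharing the same (TP, CP) coordinates synchronize gradients with each other and no one else. The dp-major rank layout here is a hypothetical convention for illustration, not the actual device-mesh code:

```python
from itertools import product

def dp_submeshes(dp, tp, cp):
    """Group global ranks into DP sub-meshes for a (dp, tp, cp) mesh
    laid out dp-major: rank = d*(tp*cp) + t*cp + c. Gradient all-reduce
    runs only within each returned group, so it never overlaps with
    TP or CP communication."""
    groups = {}
    for d, t, c in product(range(dp), range(tp), range(cp)):
        rank = d * tp * cp + t * cp + c
        groups.setdefault((t, c), []).append(rank)
    return groups

# dp=2, tp=2, cp=1 over 4 global ranks: ranks {0, 2} and {1, 3} each
# form a DP group whose members share identical TP/CP coordinates.
submeshes = dp_submeshes(2, 2, 1)
```

Each group has exactly dp members, and every global rank appears in exactly one group, which is what makes the DP all-reduce and the TP collectives independent.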