Implementation: Hugging Face Transformers FSDP Wrapping
| Knowledge Sources | |
|---|---|
| Domains | Distributed_Computing, Training |
| Last Updated | 2026-02-13 00:00 GMT |
Overview
Concrete tool for wrapping a tensor-parallel model with PyTorch FullyShardedDataParallel (FSDP) to provide data-parallel gradient synchronization.
Description
This wrapper applies FullyShardedDataParallel (FSDP) to a model that has already been loaded with tensor parallelism. In the 3D parallel example, FSDP is configured with ShardingStrategy.NO_SHARD, which means parameters are fully replicated across data-parallel ranks (equivalent to standard DDP behavior). The FSDP wrapper handles gradient all-reduce during the backward pass.
The wrapping is conditional: it is only applied when the distributed environment is initialized and the data-parallel mesh size is greater than 1. A use_ddp flag is set to True when FSDP is applied, which is later used by the gradient synchronization logic to determine whether DDP/FSDP already handles gradient sync for the DP dimension.
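The wrap-or-skip decision described above can be expressed as a small pure function (a sketch; `should_wrap_fsdp` is a hypothetical name, not from the repository):

```python
def should_wrap_fsdp(dist_initialized: bool, dp_size: int) -> bool:
    """Wrap with FSDP only when a process group exists and there is
    more than one data-parallel rank to synchronize across."""
    return dist_initialized and dp_size > 1
```

When this returns True, the model is wrapped and `use_ddp` is set to True; otherwise FSDP is skipped entirely.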
Usage
Apply this wrapper after loading the model with TP and before starting the training loop. It is needed when dp_size > 1 to ensure gradients are synchronized across data-parallel ranks. When dp_size == 1, FSDP is not needed and is skipped.
Code Reference
Source Location
- Repository: transformers
- File: examples/3D_parallel.py
- Lines: 150-153
Signature
FSDP(module, device_mesh=dp_mesh, sharding_strategy=ShardingStrategy.NO_SHARD)
Import
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import ShardingStrategy
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| module | nn.Module | Yes | The model to wrap. In the 3D parallel case, this is already tensor-parallel sharded. |
| device_mesh | DeviceMesh | Yes | The DP sub-mesh extracted from the world mesh via world_mesh["dp"]. |
| sharding_strategy | ShardingStrategy | Yes | The sharding strategy. NO_SHARD replicates parameters (DDP-like). Other options: FULL_SHARD, SHARD_GRAD_OP. |
Outputs
| Name | Type | Description |
|---|---|---|
| model | FSDP | The FSDP-wrapped model with automatic gradient synchronization across the DP mesh. |
| use_ddp | bool | Flag set to True indicating that FSDP handles DP gradient sync (used by downstream gradient all-reduce logic). |
Usage Examples
Basic Usage
```python
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import ShardingStrategy

# model already loaded with TP; dp_mesh = world_mesh["dp"]
use_ddp = False
if dist.is_initialized() and dp_mesh.size() > 1:
    model = FSDP(model, device_mesh=dp_mesh, sharding_strategy=ShardingStrategy.NO_SHARD)
    use_ddp = True
model.train()
```
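Downstream gradient-sync code can then consult `use_ddp` to decide whether a manual DP all-reduce is still required (a minimal sketch, assuming manual sync is only needed when no wrapper already handles the DP dimension; `needs_manual_dp_allreduce` is a hypothetical helper, not from the repository):

```python
def needs_manual_dp_allreduce(use_ddp: bool, dp_size: int) -> bool:
    """A manual all-reduce over the DP mesh is only needed when there is
    more than one DP rank and no DDP/FSDP wrapper is installed."""
    return dp_size > 1 and not use_ddp
```

With the FSDP wrapper installed, this returns False and the training loop can skip its own DP all-reduce, avoiding double synchronization.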
With Full Sharding for Memory Savings
```python
# Use FULL_SHARD for maximum memory savings (ZeRO Stage 3-style sharding)
model = FSDP(
    model,
    device_mesh=dp_mesh,
    sharding_strategy=ShardingStrategy.FULL_SHARD,
)
```