
Implementation:Huggingface Transformers FSDP Wrapping

From Leeroopedia
Knowledge Sources
Domains Distributed_Computing, Training
Last Updated 2026-02-13 00:00 GMT

Overview

A concrete tool for wrapping a tensor-parallel model with PyTorch FSDP to provide data-parallel gradient synchronization.

Description

This wrapper applies FullyShardedDataParallel (FSDP) to a model that has already been loaded with tensor parallelism. In the 3D parallel example, FSDP is configured with ShardingStrategy.NO_SHARD, which means parameters are fully replicated across data-parallel ranks (equivalent to standard DDP behavior). The FSDP wrapper handles gradient all-reduce during the backward pass.

The wrapping is conditional: it is only applied when the distributed environment is initialized and the data-parallel mesh size is greater than 1. When FSDP is applied, a use_ddp flag is set to True; the gradient synchronization logic later checks this flag to determine whether DDP/FSDP already handles gradient sync for the DP dimension.
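A minimal sketch of how downstream gradient-sync logic could consult the use_ddp flag; the sync_gradients helper and ReduceOp.AVG choice are illustrative assumptions, not from the source file:

```python
import torch.distributed as dist


def sync_gradients(model, dp_mesh, use_ddp: bool):
    """Reduce gradients over the DP ranks unless FSDP already did so."""
    if use_ddp:
        # FSDP (NO_SHARD) already all-reduced gradients across the DP mesh
        # during backward(), so there is nothing to do here.
        return
    # Otherwise, average gradients manually over the DP process group.
    group = dp_mesh.get_group()
    for p in model.parameters():
        if p.grad is not None:
            dist.all_reduce(p.grad, op=dist.ReduceOp.AVG, group=group)
```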

Usage

Apply this wrapper after loading the model with TP and before starting the training loop. It is needed when dp_size > 1 to ensure gradients are synchronized across data-parallel ranks. When dp_size == 1, FSDP is not needed and is skipped.

Code Reference

Source Location

  • Repository: transformers
  • File: examples/3D_parallel.py
  • Lines: 150-153

Signature

FSDP(module, device_mesh=dp_mesh, sharding_strategy=ShardingStrategy.NO_SHARD)

Import

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import ShardingStrategy

I/O Contract

Inputs

  • module (nn.Module, required) — The model to wrap. In the 3D parallel case, this is already tensor-parallel sharded.
  • device_mesh (DeviceMesh, required) — The DP sub-mesh extracted from the world mesh via world_mesh["dp"].
  • sharding_strategy (ShardingStrategy, required) — The sharding strategy. NO_SHARD replicates parameters (DDP-like); other options include FULL_SHARD and SHARD_GRAD_OP.

Outputs

  • model (FSDP) — The FSDP-wrapped model with automatic gradient synchronization across the DP mesh.
  • use_ddp (bool) — Flag set to True, indicating that FSDP handles DP gradient sync (used by downstream gradient all-reduce logic).

Usage Examples

Basic Usage

import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import ShardingStrategy

# model already loaded with TP; dp_mesh is the DP sub-mesh (world_mesh["dp"])
use_ddp = False
if dist.is_initialized() and dp_mesh.size() > 1:
    model = FSDP(model, device_mesh=dp_mesh, sharding_strategy=ShardingStrategy.NO_SHARD)
    use_ddp = True

model.train()

With Full Sharding for Memory Savings

# Use FULL_SHARD for maximum memory savings (ZeRO Stage 3)
model = FSDP(
    model,
    device_mesh=dp_mesh,
    sharding_strategy=ShardingStrategy.FULL_SHARD,
)
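With FULL_SHARD, parameters stay sharded across ranks, so checkpointing typically gathers a full state dict through FSDP's state-dict context manager. A hedged sketch using the torch.distributed.fsdp state-dict API; the function name and save path are illustrative assumptions:

```python
import torch
from torch.distributed.fsdp import (
    FullyShardedDataParallel as FSDP,
    StateDictType,
    FullStateDictConfig,
)


def save_full_checkpoint(model, path="checkpoint.pt"):
    # Gather the full (unsharded) state dict, offloaded to CPU on rank 0 only.
    cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
    with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, cfg):
        state = model.state_dict()
    if torch.distributed.get_rank() == 0:
        torch.save(state, path)
```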

