Implementation: Hugging Face Transformers All-Reduce Grads
| Knowledge Sources | |
|---|---|
| Domains | Distributed_Computing, Training, Optimization |
| Last Updated | 2026-02-13 00:00 GMT |
Overview
Concrete API for all-reducing gradients across data-parallel and context-parallel ranks, with special handling for DTensor gradients, as defined in the Hugging Face Transformers 3D parallel training example.
Description
The all_reduce_grads function synchronizes gradients across the appropriate mesh after the backward pass. Its behavior adapts based on whether DDP/FSDP is active:
- With DDP (use_ddp=True): DDP already handles gradient sync across the DP mesh, so all_reduce_grads only synchronizes across the CP mesh.
- Without DDP (use_ddp=False): the function synchronizes across a flattened DP+CP mesh that combines the data-parallel and context-parallel dimensions.
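The mesh-selection rule above can be sketched as a small decision function. This is an illustrative summary, not code from the Transformers example; the function name is hypothetical.

```python
def grad_sync_mesh_dims(use_ddp: bool) -> tuple[str, ...]:
    """Which world-mesh dimensions the manual all-reduce spans (illustrative)."""
    if use_ddp:
        # DDP already averages gradients over the "dp" dimension,
        # so only the context-parallel ranks still need syncing.
        return ("cp",)
    # Without DDP, gradients are synced over the flattened dp+cp mesh.
    return ("dp", "cp")

print(grad_sync_mesh_dims(True))   # ('cp',)
print(grad_sync_mesh_dims(False))  # ('dp', 'cp')
```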
For each parameter with a gradient, the function handles two cases:
- DTensor gradients (from tensor-parallel parameters): the local tensor is extracted from the DTensor, all-reduced with ReduceOp.SUM, manually divided by the mesh size, and then reconstructed as a DTensor with the original device mesh and placements.
- Regular tensor gradients: a simple all-reduce with ReduceOp.AVG is performed.
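The two branches compute the same average; the DTensor path just does it in two steps (SUM, then divide). A torch-free sketch with hypothetical per-rank gradient values makes the equivalence concrete:

```python
# Hypothetical gradient value for one parameter element on each of 4 ranks.
per_rank_grads = [0.2, 0.4, 0.6, 0.8]
mesh_size = len(per_rank_grads)

# DTensor branch: all_reduce with ReduceOp.SUM, then manual division.
summed = sum(per_rank_grads)
avg_via_sum = summed / mesh_size

# Regular-tensor branch: all_reduce with ReduceOp.AVG does both at once.
avg_direct = sum(per_rank_grads) / mesh_size

assert avg_via_sum == avg_direct  # both branches yield the mean gradient
print(avg_via_sum)  # 0.5
```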
After gradient synchronization, the training loop applies gradient clipping via either FSDP's clip_grad_norm_ (when FSDP is active) or a custom clip_grad_norm_ function that handles DTensor parameters correctly.
Usage
Call this function after loss.backward() and after exiting the context-parallel context, but before optimizer.step(). It must be called on every training step when cp_size > 1 or when dp_size > 1 without DDP.
Code Reference
Source Location
- Repository: transformers
- File: examples/3D_parallel.py
- Lines: 355-430 (all_reduce_grads and clip_grad_norm_)
Signature
def all_reduce_grads(model, world_mesh, use_ddp):
"""All reduce gradients across dp_cp if applicable."""
...
def clip_grad_norm_(
parameters: Iterable[torch.Tensor],
max_norm: float,
norm_type: float = 2.0,
error_if_nonfinite: bool = False,
foreach: bool | None = None,
) -> torch.Tensor:
"""Clip the gradient norm of an iterable of parameters."""
...
Import
import torch.distributed as dist
from torch.distributed.tensor import DTensor
I/O Contract
Inputs (all_reduce_grads)
| Name | Type | Required | Description |
|---|---|---|---|
| model | nn.Module | Yes | The model whose parameter gradients will be all-reduced. |
| world_mesh | DeviceMesh | Yes | The full 3D world mesh with "dp", "tp", and "cp" dimensions. |
| use_ddp | bool | Yes | Whether DDP/FSDP is already handling DP gradient sync. If True, only sync across CP. If False, sync across flattened DP+CP. |
Inputs (clip_grad_norm_)
| Name | Type | Required | Description |
|---|---|---|---|
| parameters | Iterable[torch.Tensor] | Yes | Model parameters whose gradients will be clipped. |
| max_norm | float | Yes | Maximum gradient norm threshold (e.g. 1.0). |
| norm_type | float | No | Type of norm to use (default 2.0 for the L2 norm). |
| error_if_nonfinite | bool | No | If True, raise an error when the total norm is NaN or Inf (default False). |
| foreach | bool \| None | No | Whether to use the faster foreach-based implementation (default None: decide automatically). |
Outputs
| Name | Type | Description |
|---|---|---|
| (side effect) | None | all_reduce_grads modifies parameter gradients in-place. |
| total_norm | torch.Tensor | clip_grad_norm_ returns the total gradient norm before clipping. |
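For intuition on the clipping contract above (clip in place, return the pre-clipping norm), here is a minimal torch-free sketch of L2 gradient clipping. The function name and eps value are illustrative, not the Transformers implementation:

```python
import math

def clip_l2(grads: list[float], max_norm: float) -> tuple[list[float], float]:
    """Scale gradients so their L2 norm is at most max_norm (illustrative)."""
    total_norm = math.sqrt(sum(g * g for g in grads))
    scale = max_norm / (total_norm + 1e-6)  # eps guards against division by zero
    if scale < 1.0:  # only ever scale down, never up
        grads = [g * scale for g in grads]
    return grads, total_norm

clipped, norm = clip_l2([3.0, 4.0], max_norm=1.0)
print(norm)  # 5.0 -- the norm *before* clipping, matching clip_grad_norm_'s return
```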
Usage Examples
Basic Usage
# After loss.backward() and after exiting the CP context:
all_reduce_grads(model, world_mesh, use_ddp=use_ddp)

# Gradient clipping
if hasattr(model, "clip_grad_norm_"):
    # FSDP-wrapped model provides its own clip_grad_norm_
    gradnorm = model.clip_grad_norm_(max_norm=1.0, norm_type=2.0)
else:
    # Custom clip_grad_norm_ that handles DTensor gradients
    gradnorm = clip_grad_norm_(model.parameters(), max_norm=1.0, norm_type=2.0, foreach=True)

optimizer.step()
Internal DTensor Handling
def all_reduce_grads(model, world_mesh, use_ddp):
    cp_mesh = world_mesh["cp"]
    if use_ddp:
        mesh = cp_mesh  # DDP already handles the DP gradient sync
    else:
        # Flatten the DP and CP sub-meshes into a single process group
        mesh = world_mesh["dp", "cp"]._flatten(mesh_dim_name="dp_cp")
    if dist.is_initialized() and mesh.size() > 1:
        for name, param in model.named_parameters():
            if param.grad is not None:
                if isinstance(param.grad, DTensor):
                    # All-reduce the local shard with SUM, then average manually
                    local_grad = param.grad.to_local()
                    dist.all_reduce(local_grad, op=dist.ReduceOp.SUM, group=mesh.get_group())
                    local_grad = local_grad / mesh.size()
                    # Rebuild the DTensor with the original mesh and placements
                    param.grad = DTensor.from_local(
                        local_grad,
                        device_mesh=param.grad.device_mesh,
                        placements=param.grad.placements,
                    )
                else:
                    dist.all_reduce(param.grad, op=dist.ReduceOp.AVG, group=mesh.get_group())