

Implementation:Huggingface Transformers All Reduce Grads

From Leeroopedia
Knowledge Sources
Domains Distributed_Computing, Training, Optimization
Last Updated 2026-02-13 00:00 GMT

Overview

Concrete API for all-reducing gradients across data-parallel and context-parallel ranks, with special handling for DTensor gradients, as defined in the Hugging Face Transformers 3D parallel training example.

Description

The all_reduce_grads function synchronizes gradients across the appropriate device mesh after the backward pass. Which mesh it uses depends on whether DDP/FSDP is active:

  • With DDP (use_ddp=True): DDP already handles gradient sync across the DP mesh, so all_reduce_grads only synchronizes across the CP mesh.
  • Without DDP (use_ddp=False): The function synchronizes across a flattened DP+CP mesh that combines both data-parallel and context-parallel dimensions.
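To make the two cases concrete, here is a small pure-Python sketch that lists which global ranks end up in the same gradient-sync group. It assumes ranks are laid out row-major over a ("dp", "tp", "cp") mesh (cp varying fastest), which is the layout we assume for a mesh built from ranks 0..world_size-1 in order. The helpers rank_of and sync_group are hypothetical and do not appear in the source.

```python
from itertools import product


def rank_of(d, t, c, tp, cp):
    # Assumed row-major layout over ("dp", "tp", "cp"): cp varies fastest.
    return (d * tp + t) * cp + c


def sync_group(rank, dp, tp, cp, use_ddp):
    """Global ranks whose gradients are reduced together with `rank`."""
    d, rem = divmod(rank, tp * cp)
    t, c = divmod(rem, cp)
    if use_ddp:
        # DDP has already averaged over "dp"; only the "cp" peers remain.
        return [rank_of(d, t, cc, tp, cp) for cc in range(cp)]
    # Flattened "dp_cp" group: every (dp, cp) pair with the same tp index.
    return sorted(rank_of(dd, t, cc, tp, cp) for dd, cc in product(range(dp), range(cp)))


if __name__ == "__main__":
    print(sync_group(0, dp=2, tp=2, cp=2, use_ddp=True))   # [0, 1]
    print(sync_group(0, dp=2, tp=2, cp=2, use_ddp=False))  # [0, 1, 4, 5]
```

Under this assumed layout the group size is cp_size with DDP and dp_size * cp_size without it, matching mesh.size() in the function, and ranks holding different tensor-parallel shards never share a group.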

For each parameter with a gradient, the function handles two cases:

  • DTensor gradients (from tensor-parallel parameters): The local tensor is extracted from the DTensor, all-reduced with ReduceOp.SUM, manually divided by the mesh size, and then reconstructed as a DTensor with the original device mesh and placements.
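  • Regular tensor gradients: The gradient is all-reduced in place with ReduceOp.AVG. PyTorch documents ReduceOp.AVG as available only with the NCCL backend, so this path assumes NCCL.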
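Both branches are intended to leave every rank with the group average. The pure-Python simulation below (lists stand in for per-rank gradient shards; nothing here calls torch.distributed) checks that summing across the group and then dividing by the group size, as the DTensor branch does, gives the same values as averaging directly, as the ReduceOp.AVG branch does. The helper names are hypothetical.

```python
def all_reduce_sum(per_rank):
    # Every rank receives the element-wise sum of all ranks' tensors.
    summed = [sum(vals) for vals in zip(*per_rank)]
    return [list(summed) for _ in per_rank]


def dtensor_branch(per_rank):
    # DTensor path: SUM over the group, then divide by the group size.
    n = len(per_rank)
    return [[v / n for v in grad] for grad in all_reduce_sum(per_rank)]


def avg_branch(per_rank):
    # Regular-tensor path: ReduceOp.AVG, i.e. the element-wise mean.
    n = len(per_rank)
    mean = [sum(vals) / n for vals in zip(*per_rank)]
    return [list(mean) for _ in per_rank]


if __name__ == "__main__":
    grads = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]  # 4 ranks in the group
    print(dtensor_branch(grads) == avg_branch(grads))  # True
    print(avg_branch(grads)[0])                        # [4.0, 5.0]
```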

After gradient synchronization, the training loop clips gradients either with FSDP's own clip_grad_norm_ method, when the model is FSDP-wrapped (detected by checking whether the model has that method), or otherwise with a custom clip_grad_norm_ function that handles DTensor parameters correctly.
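The clipping step rescales all gradients by one shared factor derived from the global norm. The sketch below works on plain Python lists and mirrors the formula used by torch.nn.utils.clip_grad_norm_ (scale factor max_norm / (total_norm + 1e-6), capped at 1). It is not the example's clip_grad_norm_, and the name clip_by_global_norm is hypothetical. It also ignores sharding: with DTensor gradients each rank holds only part of every gradient, so the per-shard contributions must be combined across the mesh before the scale factor is computed.

```python
import math


def clip_by_global_norm(grads, max_norm, norm_type=2.0, eps=1e-6):
    """Scale `grads` (lists of floats) in place so their global norm is at
    most max_norm; return the norm measured before any scaling."""
    if norm_type == math.inf:
        total = max(abs(v) for g in grads for v in g)
    else:
        total = sum(abs(v) ** norm_type for g in grads for v in g) ** (1.0 / norm_type)
    coef = max_norm / (total + eps)
    if coef < 1.0:
        # One shared factor keeps the direction of the full gradient vector.
        for g in grads:
            for i in range(len(g)):
                g[i] *= coef
    return total


if __name__ == "__main__":
    grads = [[3.0], [4.0]]
    print(clip_by_global_norm(grads, max_norm=1.0))  # 5.0
```

Because the returned value is the norm before scaling, gradnorm in the usage example reflects the unclipped gradients, consistent with the Outputs table.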

Usage

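Call this function after loss.backward() and after exiting the context-parallel context, but before gradient clipping and optimizer.step(). It must run on every training step whenever cp_size > 1, or whenever dp_size > 1 and DDP is not in use; in any other configuration the sync mesh has a single rank and the function returns without performing a collective.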
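The condition above reduces to the same check the function makes internally (mesh.size() > 1), because the sync group has cp_size ranks with DDP and dp_size * cp_size ranks without it. A hypothetical helper, not part of the example, that evaluates it from the parallelism sizes:

```python
def needs_grad_sync(dp_size, cp_size, use_ddp):
    # Size of the group all_reduce_grads would reduce over.
    group_size = cp_size if use_ddp else dp_size * cp_size
    return group_size > 1


if __name__ == "__main__":
    print(needs_grad_sync(dp_size=4, cp_size=1, use_ddp=True))   # False: DDP covers DP
    print(needs_grad_sync(dp_size=4, cp_size=1, use_ddp=False))  # True
```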

Code Reference

Source Location

  • Repository: transformers
  • File: examples/3D_parallel.py
  • Lines: 355-430 (all_reduce_grads and clip_grad_norm_)

Signature

def all_reduce_grads(model, world_mesh, use_ddp):
    """All reduce gradients across dp_cp if applicable."""
    ...

def clip_grad_norm_(
    parameters: Iterable[torch.Tensor],
    max_norm: float,
    norm_type: float = 2.0,
    error_if_nonfinite: bool = False,
    foreach: bool | None = None,
) -> torch.Tensor:
    """Clip the gradient norm of an iterable of parameters."""
    ...

Import

import torch.distributed as dist
from torch.distributed.tensor import DTensor

I/O Contract

Inputs (all_reduce_grads)

Name | Type | Required | Description
model | nn.Module | Yes | The model whose parameter gradients will be all-reduced.
world_mesh | DeviceMesh | Yes | The full 3D world mesh with "dp", "tp", and "cp" dimensions.
use_ddp | bool | Yes | Whether DDP/FSDP is already handling DP gradient sync. If True, only sync across CP. If False, sync across flattened DP+CP.

Inputs (clip_grad_norm_)

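Name | Type | Required | Description
parameters | Iterable[torch.Tensor] | Yes | Model parameters whose gradients will be clipped.
max_norm | float | Yes | Maximum gradient norm threshold (e.g. 1.0).
norm_type | float | No | Type of norm to use (default 2.0 for the L2 norm).
error_if_nonfinite | bool | No | Defaults to False (see Signature).
foreach | bool or None | No | Defaults to None; the usage example passes foreach=True.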

Outputs

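Name | Type | Description
(side effect) | None | all_reduce_grads updates parameter gradients: regular gradients are averaged in place, while DTensor gradients are replaced by a new DTensor built from the reduced local shard.
total_norm | torch.Tensor | clip_grad_norm_ returns the total gradient norm before clipping.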

Usage Examples

Basic Usage

# After loss.backward() and exiting CP context:
all_reduce_grads(model, world_mesh, use_ddp=use_ddp)

# Gradient clipping
if hasattr(model, "clip_grad_norm_"):
    # FSDP-wrapped model has its own clip_grad_norm_
    gradnorm = model.clip_grad_norm_(max_norm=1.0, norm_type=2.0)
else:
    # Custom clip_grad_norm_ that handles DTensors
    gradnorm = clip_grad_norm_(model.parameters(), max_norm=1.0, norm_type=2.0, foreach=True)

optimizer.step()

Internal DTensor Handling

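import torch.distributed as dist
from torch.distributed.tensor import DTensor


def all_reduce_grads(model, world_mesh, use_ddp):
    """All reduce gradients across dp_cp if applicable."""
    cp_mesh = world_mesh["cp"]
    if use_ddp:
        mesh = cp_mesh  # DDP handles DP sync; only CP ranks remain
    else:
        # Merge the "dp" and "cp" dimensions into one 1D "dp_cp" mesh
        # (_flatten is a private DeviceMesh method)
        mesh = world_mesh["dp", "cp"]._flatten(mesh_dim_name="dp_cp")

    if dist.is_initialized() and mesh.size() > 1:
        for name, param in model.named_parameters():
            if param.grad is not None:
                if isinstance(param.grad, DTensor):
                    # TP gradient: reduce the local shard, average manually,
                    # then rewrap it with the original mesh and placements
                    local_grad = param.grad.to_local()
                    dist.all_reduce(local_grad, op=dist.ReduceOp.SUM, group=mesh.get_group())
                    local_grad = local_grad / mesh.size()
                    param.grad = DTensor.from_local(
                        local_grad,
                        device_mesh=param.grad.device_mesh,
                        placements=param.grad.placements,
                    )
                else:
                    # Plain tensor gradient: average in place across the group
                    dist.all_reduce(param.grad, op=dist.ReduceOp.AVG, group=mesh.get_group())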

Related Pages

Implements Principle

Requires Environment
