Principle:Huggingface Transformers Distributed Gradient Synchronization

From Leeroopedia
Knowledge Sources
Domains Distributed_Computing, Training, Optimization
Last Updated 2026-02-13 00:00 GMT

Overview

Distributed gradient synchronization ensures that all processes computing on different data or sequence shards produce consistent parameter updates by averaging their gradients before the optimizer step.

Description

In distributed training, each process computes gradients on its own local data shard (data parallelism) or sequence shard (context parallelism). For the model to converge correctly, these gradients must be averaged across all relevant processes before the optimizer updates the parameters. Without gradient synchronization, each process would diverge to a different set of weights.
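As a toy illustration of why synchronization is required, the sketch below simulates two data-parallel replicas of a one-parameter model trained with SGD. The model, the loss 0.5 * (w - x)^2, and the helper names (local_grad, train) are hypothetical. A plain average of the replicas' gradients stands in for the all-reduce.

```python
# Toy model: one scalar weight w, loss_i(w) = 0.5 * (w - x_i)**2, so grad_i = w - x_i.
# Two "ranks" hold different data shards; compare unsynchronized vs averaged updates.

def local_grad(w, shard):
    """Mean gradient of 0.5 * (w - x)**2 over one rank's shard."""
    return sum(w - x for x in shard) / len(shard)

def train(shards, steps=3, lr=0.5, sync=True):
    weights = [1.0 for _ in shards]  # every replica starts from the same weight
    for _ in range(steps):
        grads = [local_grad(w, s) for w, s in zip(weights, shards)]
        if sync:
            avg = sum(grads) / len(grads)  # stand-in for an all-reduce with AVG
            grads = [avg] * len(grads)
        weights = [w - lr * g for w, g in zip(weights, grads)]
    return weights

shards = [[0.0, 2.0], [4.0, 6.0]]
print(train(shards, sync=False))  # [1.0, 4.5]   replicas drift apart
print(train(shards, sync=True))   # [2.75, 2.75] replicas stay identical
```

Without the average, the replica whose shard has mean 1 never moves while the other heads toward 5; with it, both replicas take identical steps toward the global mean of 3.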

In a 3D parallel training setup, gradient synchronization is nuanced because different parallelism axes require different treatment:

  • Tensor Parallelism (TP): No explicit gradient synchronization is needed along the TP axis, because the communication inserted by the TP hooks (an all-reduce after rowwise layers in the forward pass, and an all-reduce in the backward pass before colwise layers) already produces correct gradients through the autograd graph.
  • Data Parallelism (DP): When FSDP or DDP is used, it reduces gradients across the DP mesh automatically during the backward pass (DDP all-reduces them, while FSDP typically reduce-scatters them so that each rank keeps only its gradient shard).
  • Context Parallelism (CP): Gradients from different CP ranks must be combined, because each rank saw only a portion of the sequence.
  • Combined DP+CP: When both DP and CP are active, gradients must be synchronized across the joint DP+CP mesh to account for both data and sequence sharding.

The all-reduce operation for gradient synchronization has two main modes:

  • SUM followed by division: Each gradient is summed across ranks and then divided by the mesh size. This is used for DTensor gradients, where the all-reduce is applied to the extracted local tensor rather than to the DTensor itself, so the averaging has to be done manually afterward.
  • AVG (average): The all-reduce computes the average directly. This is used for regular (non-DTensor) gradients.
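The two modes produce the same values and differ only in which reduction primitive is called. The sketch below simulates both on plain lists; the all_reduce helper is a hypothetical stand-in for torch.distributed.all_reduce with ReduceOp.SUM or ReduceOp.AVG, not a real distributed call.

```python
# Simulated all-reduce over a process group: every rank ends up with the same result.
# Each rank's gradient is a flat list of floats.

def all_reduce(grads_per_rank, op):
    """Return the per-rank results of an all-reduce with op in {"sum", "avg"}."""
    k = len(grads_per_rank)
    summed = [sum(vals) for vals in zip(*grads_per_rank)]
    result = summed if op == "sum" else [v / k for v in summed]
    return [list(result) for _ in range(k)]  # each rank gets its own copy

def sum_then_divide(grads_per_rank):
    # Mode described for DTensor gradients: SUM all-reduce, then divide by the mesh size.
    k = len(grads_per_rank)
    return [[v / k for v in g] for g in all_reduce(grads_per_rank, "sum")]

def direct_avg(grads_per_rank):
    # Mode described for regular tensors: a single AVG all-reduce.
    return all_reduce(grads_per_rank, "avg")

grads = [[1.0, 2.0], [3.0, 6.0], [5.0, 1.0], [7.0, 3.0]]  # 4 ranks
print(sum_then_divide(grads))  # [[4.0, 3.0], [4.0, 3.0], [4.0, 3.0], [4.0, 3.0]]
print(direct_avg(grads))       # same result
```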

After gradient synchronization, gradient clipping is applied to prevent gradient explosion, followed by the optimizer step.

Usage

Gradient synchronization must be performed after every backward pass and before the optimizer step. The specific synchronization scope depends on the parallelism configuration:

  • If only FSDP/DDP is used (no CP), FSDP/DDP handles gradient synchronization automatically.
  • If CP is active (with or without DDP), an explicit all-reduce across the CP mesh (or the combined DP+CP mesh) is required.
  • The use_ddp flag indicates whether DDP/FSDP already handles the DP dimension. When it is set, the explicit all-reduce covers only the CP dimension; otherwise it runs over the combined DP+CP mesh.
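The decision in the list above can be summarized as a small helper that returns which mesh axes still need a manual all-reduce after the backward pass. The function name explicit_sync_axes and the early exit for single-rank groups are assumptions for illustration, not part of any Transformers API.

```python
import math

def explicit_sync_axes(dp_size, cp_size, use_ddp):
    """Return the mesh axes needing an explicit gradient all-reduce (hypothetical helper).

    TP never appears: its hooks already communicate through autograd.
    """
    if use_ddp:
        # DDP/FSDP already reduced over DP during backward; only CP is left.
        axes = {"cp": cp_size}
    else:
        # No DDP/FSDP: reduce over the combined (flattened) DP+CP mesh.
        axes = {"dp": dp_size, "cp": cp_size}
    # A group that contains a single rank has nothing to reduce.
    group_size = math.prod(axes.values())
    return list(axes) if group_size > 1 else []

print(explicit_sync_axes(dp_size=4, cp_size=1, use_ddp=True))   # [] (DDP/FSDP covers it)
print(explicit_sync_axes(dp_size=4, cp_size=2, use_ddp=True))   # ['cp']
print(explicit_sync_axes(dp_size=4, cp_size=2, use_ddp=False))  # ['dp', 'cp']
```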

Theoretical Basis

Gradient synchronization is rooted in the mathematical equivalence between:

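  1. Computing the gradient on the full (global) batch of N samples: g = (1/N) * sum_i grad(loss_i)
  2. Computing local gradients on K shards and averaging them: g = (1/K) * sum_k g_k, where g_k = (1/n_k) * sum_{i in shard k} grad(loss_i) and n_k is the number of samples in shard k.

The two are equal when every shard has the same size, n_k = N/K, because then (1/K) * sum_k (K/N) * sum_{i in shard k} grad(loss_i) = (1/N) * sum_i grad(loss_i).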
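A quick numeric check of this equivalence, using scalar stand-ins for per-sample gradients. The plain average of shard means matches the global mean only when all shards have the same size; with unequal shards it is biased toward the smaller shard. The sample values and helper names are illustrative.

```python
# Per-sample "gradients" for a global batch of N = 6 samples (scalars for simplicity).
per_sample = [1.0, 2.0, 3.0, 4.0, 5.0, 9.0]

def mean(xs):
    return sum(xs) / len(xs)

def avg_of_shard_means(shards):
    # g = (1/K) * sum_k g_k, where g_k is the mean gradient on shard k
    return mean([mean(s) for s in shards])

global_g = mean(per_sample)                   # (1/N) * sum_i grad(loss_i)
equal = [per_sample[0:3], per_sample[3:6]]    # n_k = N/K = 3
unequal = [per_sample[0:2], per_sample[2:6]]  # n_1 = 2, n_2 = 4

print(global_g, avg_of_shard_means(equal), avg_of_shard_means(unequal))  # 4.0 4.0 3.375
```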

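The all-reduce collective computes the sum (or average) across all ranks in a single collective call. With the ring all-reduce algorithm (a reduce-scatter followed by an all-gather), each rank sends and receives 2 * (K-1)/K * |g| elements, where K is the number of ranks and |g| is the gradient size. This per-rank volume approaches 2 * |g| as K grows, so it is nearly independent of the number of ranks.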
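To make the 2 * (K-1)/K * |g| figure concrete, the sketch below simulates a sum all-reduce over a ring of K ranks in pure Python and counts the elements each rank sends. It assumes the gradient length is divisible by K; real libraries such as NCCL also handle uneven chunk sizes and may choose other algorithms (for example trees), so this only illustrates the textbook ring algorithm.

```python
def ring_all_reduce(data):
    """Sum all-reduce over a simulated ring. Returns (per-rank results, elements sent per rank)."""
    k, n = len(data), len(data[0])
    assert n % k == 0, "sketch assumes the gradient splits evenly into k chunks"
    c = n // k
    bufs = [list(d) for d in data]
    sent = [0] * k

    def chunk(i):
        return range(i * c, (i + 1) * c)

    # Phase 1, reduce-scatter: after k-1 steps rank r holds the full sum of chunk (r+1) % k.
    for step in range(k - 1):
        # Snapshot the payloads so all sends in one step happen "simultaneously".
        outgoing = [(r, (r - step) % k) for r in range(k)]
        payloads = [[bufs[r][j] for j in chunk(idx)] for r, idx in outgoing]
        for (r, idx), payload in zip(outgoing, payloads):
            dst = (r + 1) % k
            for j, v in zip(chunk(idx), payload):
                bufs[dst][j] += v
            sent[r] += len(payload)

    # Phase 2, all-gather: each rank forwards the fully reduced chunk it most recently obtained.
    for step in range(k - 1):
        outgoing = [(r, (r + 1 - step) % k) for r in range(k)]
        payloads = [[bufs[r][j] for j in chunk(idx)] for r, idx in outgoing]
        for (r, idx), payload in zip(outgoing, payloads):
            dst = (r + 1) % k
            for j, v in zip(chunk(idx), payload):
                bufs[dst][j] = v
            sent[r] += len(payload)

    return bufs, sent

k = 4
grads = [[float(r * 10 + j) for j in range(8)] for r in range(k)]  # |g| = 8
results, sent = ring_all_reduce(grads)
expected = [sum(g[j] for g in grads) for j in range(8)]
print(all(res == expected for res in results))  # True
print(sent)  # [12, 12, 12, 12], i.e. 2 * (4-1)/4 * 8 per rank
```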

When using DTensor (Distributed Tensor) gradients from tensor parallelism, the gradients are already distributed across the TP mesh. All-reducing a DTensor gradient requires extracting the local tensor, performing the all-reduce on the CP/DP mesh (which is a different mesh from the TP mesh), and then reconstructing the DTensor. This cross-mesh communication is a key complexity in 3D parallel training.
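The cross-mesh step can be illustrated without a real cluster. The sketch below lays out a hypothetical 2 x 2 grid of (TP, CP) ranks, where each rank holds one row-shard of a TP-sharded gradient, as a DTensor with a Shard(0) placement would. Every CP group contains the ranks that share a TP index, so their local shards line up element by element and can be summed directly. In PyTorch the corresponding steps would be DTensor.to_local(), torch.distributed.all_reduce on the CP (or DP+CP) process group, and DTensor.from_local with the original mesh and placements; here they are replaced by list operations.

```python
TP, CP = 2, 2

# Full (unsharded) gradient that each CP rank computed from its part of the sequence.
full_grad_per_cp = {0: [1.0, 2.0, 3.0, 4.0], 1: [3.0, 6.0, 5.0, 0.0]}

def shard(values, tp_rank, tp_size):
    """Row-shard along dim 0, mimicking a Shard(0) placement on the TP mesh."""
    n = len(values) // tp_size
    return values[tp_rank * n:(tp_rank + 1) * n]

# Analogue of to_local(): the local tensor held by rank (tp, cp).
local = {(tp, cp): shard(full_grad_per_cp[cp], tp, TP) for tp in range(TP) for cp in range(CP)}

# SUM all-reduce within each CP group (ranks with the same tp index), then divide by CP.
for tp in range(TP):
    group = [(tp, cp) for cp in range(CP)]
    summed = [sum(vals) for vals in zip(*(local[r] for r in group))]
    for r in group:
        local[r] = [v / CP for v in summed]

def full_tensor(cp):
    """Analogue of from_local() followed by full_tensor(): reassemble the TP shards."""
    return [v for tp in range(TP) for v in local[(tp, cp)]]

print(full_tensor(0))  # [2.0, 4.0, 4.0, 2.0], the element-wise mean of both CP gradients
```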

Gradient clipping after synchronization bounds the global gradient norm, preventing training instability. The norm is computed across all parameters (which may be DTensors requiring full_tensor() for a correct norm), and if it exceeds the threshold, every gradient is multiplied by the same factor, threshold / norm, so that the clipped global norm equals the threshold.
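A minimal sketch of global-norm clipping on plain Python lists, assuming the gradients have already been synchronized and, for DTensors, materialized as full tensors. It follows the same rule as torch.nn.utils.clip_grad_norm_ (scale by max_norm / (total_norm + eps) only when that factor is below 1), but the function itself is a hypothetical stand-in.

```python
import math

def clip_grad_norm(grads, max_norm, eps=1e-6):
    """Scale all gradients in place so that their global L2 norm is at most max_norm.

    grads: list of flat gradient lists, one per parameter.
    Returns the global norm measured before clipping.
    """
    # Global norm over *all* parameters, not a per-parameter norm.
    total_norm = math.sqrt(sum(v * v for g in grads for v in g))
    clip_coef = max_norm / (total_norm + eps)
    if clip_coef < 1.0:
        # One shared factor preserves the direction of the full gradient vector.
        for g in grads:
            for i in range(len(g)):
                g[i] *= clip_coef
    return total_norm

grads = [[3.0], [4.0]]
total = clip_grad_norm(grads, max_norm=1.0)
print(total)  # 5.0
print(grads)  # approximately [[0.6], [0.8]]
```

Using a single factor for every parameter, rather than clipping each parameter separately, keeps the update pointing in the same direction as the synchronized gradient.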
