
Principle:Deepspeedai DeepSpeed ZeRO Training Loop

From Leeroopedia


Knowledge Sources
Domains Distributed_Training, Gradient_Optimization, Memory_Optimization
Last Updated 2026-02-09 00:00 GMT

Overview

A distributed training loop pattern that uses DeepSpeed's engine abstraction for backpropagation, with ZeRO-optimized gradient handling and optimizer stepping.

Description

The ZeRO Training Loop replaces standard PyTorch loss.backward() and optimizer.step() with DeepSpeed engine methods that handle gradient accumulation, mixed-precision scaling, ZeRO-partitioned gradient allreduce, and optimizer state management automatically.

The loop consists of two critical engine methods:

  • engine.backward(loss): Handles loss scaling (fp16 dynamic scaling or static scaling), gradient computation via autograd, and gradient accumulation across micro-batches. For ZeRO stages 1-3, gradients are computed but not yet synchronized -- synchronization is deferred to the accumulation boundary.
  • engine.step(): Handles the gradient synchronization and optimizer update at the gradient accumulation boundary:
    • ZeRO Stage 1: AllReduce gradients, then each rank updates its 1/N partition of optimizer states and parameters
    • ZeRO Stage 2: ReduceScatter gradients (each rank receives only its 1/N partition), then each rank updates its local partition
    • ZeRO Stage 3: ReduceScatter gradients, each rank updates its parameter shard, and parameters remain partitioned for the next forward pass

The engine automatically tracks gradient accumulation boundaries and only performs synchronization and optimizer updates at the boundary.

Usage

Replace loss.backward() with engine.backward(loss) and optimizer.step() / optimizer.zero_grad() with engine.step() in the training loop. The engine handles gradient zeroing, accumulation, synchronization, and optimizer updates automatically.
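The substitution can be sketched with a toy stand-in engine. This is a pure-Python mock that borrows only the backward/step method names from DeepSpeed and uses a single scalar "parameter" and assumed accumulation depth; it is meant to show the boundary behavior, not the real DeepSpeedEngine implementation:

```python
class MockZeroEngine:
    """Toy stand-in for a DeepSpeed engine: accumulates scalar 'gradients'
    locally and only applies an optimizer update (and zeroes gradients)
    at the gradient accumulation boundary."""

    def __init__(self, accumulation_steps):
        self.accumulation_steps = accumulation_steps
        self.micro_step = 0
        self.accumulated_grad = 0.0
        self.param = 1.0          # single toy parameter
        self.lr = 0.1

    def backward(self, grad):
        # Stands in for loss scaling + autograd + local accumulation.
        self.accumulated_grad += grad
        self.micro_step += 1

    def step(self):
        # Sync/update/zero only at the accumulation boundary;
        # otherwise this call is a no-op.
        if self.micro_step % self.accumulation_steps == 0:
            self.param -= self.lr * self.accumulated_grad  # optimizer update
            self.accumulated_grad = 0.0                    # zero_grad()

engine = MockZeroEngine(accumulation_steps=2)
for grad in [0.5, 0.5, 1.0, 1.0]:   # four micro-batches
    engine.backward(grad)
    engine.step()                    # updates only at micro-steps 2 and 4
# param: 1.0 - 0.1*(0.5+0.5) - 0.1*(1.0+1.0) = 0.7
```

Note that the caller invokes step() after every micro-batch; the engine itself decides whether that call crosses an accumulation boundary.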

Theoretical Basis

ZeRO gradient handling varies by stage:

  • Stage 1: Partitions optimizer states. Each rank maintains a full copy of gradients (AllReduce), but only updates 1/N of the optimizer states and parameters. After update, an AllGather broadcasts updated parameters.
  • Stage 2: Additionally partitions gradients. Uses ReduceScatter instead of AllReduce so each rank receives only its 1/N gradient partition. Each rank updates its local optimizer states and parameters. AllGather broadcasts updated parameters.
  • Stage 3: Additionally partitions parameters. Same as Stage 2 for gradients, but parameters remain partitioned between forward/backward passes, gathered on-demand via AllGather.
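The difference between the Stage 1 and Stage 2/3 collectives can be illustrated with plain lists standing in for per-rank gradient buffers. This is an arithmetic sketch of what AllReduce and ReduceScatter compute, not an implementation of the actual NCCL collectives:

```python
def all_reduce(per_rank_grads):
    """Every rank ends with the full summed gradient vector (ZeRO-1)."""
    summed = [sum(vals) for vals in zip(*per_rank_grads)]
    return [list(summed) for _ in per_rank_grads]

def reduce_scatter(per_rank_grads):
    """Rank i ends with only the i-th 1/N chunk of the summed
    gradients (ZeRO-2/3)."""
    n = len(per_rank_grads)
    summed = [sum(vals) for vals in zip(*per_rank_grads)]
    chunk = len(summed) // n
    return [summed[i * chunk:(i + 1) * chunk] for i in range(n)]

# Two ranks, each holding local gradients for a 4-element parameter.
grads = [[1.0, 2.0, 3.0, 4.0],
         [0.5, 0.5, 0.5, 0.5]]

full = all_reduce(grads)       # both ranks hold [1.5, 2.5, 3.5, 4.5]
parts = reduce_scatter(grads)  # rank 0: [1.5, 2.5]; rank 1: [3.5, 4.5]
```

After reduce_scatter each rank holds exactly the gradient slice matching its optimizer-state partition, which is why Stage 2/3 can drop the remaining N-1 gradient chunks entirely.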

Gradient accumulation: Handled by deferring allreduce/reduce-scatter until the accumulation boundary. During non-boundary micro-batch steps, gradients are accumulated locally. At the boundary, accumulated gradients are synchronized across ranks.
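Deferring the reduction is exact because reduction is linear: summing micro-batch gradients locally and reducing once gives the same result as reducing every micro-batch. A small numeric check of this equivalence, with lists standing in for two ranks' gradient buffers:

```python
rank_a = [[1.0, 2.0], [3.0, 4.0]]   # rank A's grads for 2 micro-batches
rank_b = [[0.5, 0.5], [0.5, 0.5]]   # rank B's grads for 2 micro-batches

# Eager: reduce across ranks after every micro-batch, then accumulate.
eager = [0.0, 0.0]
for ga, gb in zip(rank_a, rank_b):
    reduced = [a + b for a, b in zip(ga, gb)]
    eager = [e + r for e, r in zip(eager, reduced)]

# Deferred (ZeRO): accumulate locally, reduce once at the boundary.
local_a = [sum(v) for v in zip(*rank_a)]   # [4.0, 6.0]
local_b = [sum(v) for v in zip(*rank_b)]   # [1.0, 1.0]
deferred = [a + b for a, b in zip(local_a, local_b)]

assert eager == deferred == [5.0, 7.0]
```

The deferred variant pays the communication cost once per accumulation boundary instead of once per micro-batch.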

Mixed precision interaction:

  • fp16: Dynamic loss scaling with automatic scale adjustment on overflow detection
  • bf16: No loss scaling needed (larger dynamic range)
  • AMP: NVIDIA Apex AMP integration with delayed unscaling during gradient accumulation
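The fp16 dynamic loss scaling policy can be sketched as: halve the scale (and skip the optimizer step) when an overflow is detected, and double it back after a window of consecutive overflow-free steps. The constants below are illustrative defaults, not DeepSpeed's exact configuration values:

```python
class DynamicLossScaler:
    """Sketch of fp16 dynamic loss scaling: back off on overflow,
    probe a larger scale after `window` clean steps."""

    def __init__(self, init_scale=2.0 ** 16, window=1000, factor=2.0):
        self.scale = init_scale
        self.window = window
        self.factor = factor
        self.good_steps = 0

    def update(self, found_overflow):
        if found_overflow:
            self.scale /= self.factor   # back off; skip this optimizer step
            self.good_steps = 0
            return False                # signal: do not apply the update
        self.good_steps += 1
        if self.good_steps == self.window:
            self.scale *= self.factor   # try a larger scale again
            self.good_steps = 0
        return True                     # safe to apply the update

scaler = DynamicLossScaler(init_scale=8.0, window=2)
applied = [scaler.update(ovf) for ovf in [False, True, False, False]]
# overflow at step 2 halves the scale to 4.0; the two clean steps
# after it double it back to 8.0
```

Because bf16 shares fp32's exponent range, overflows of this kind are rare enough that the scaler is simply not needed there.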

Paper: ZeRO: Memory Optimizations Toward Training Trillion Parameter Models

Pseudo-code:

# Abstract ZeRO training loop
for micro_batch in dataloader:
    inputs, labels = micro_batch
    outputs = engine(inputs)
    loss = criterion(outputs, labels)

    # Handles loss scaling + gradient computation + accumulation
    engine.backward(loss)

    # At accumulation boundary: sync gradients, update optimizer, zero grads
    engine.step()
    # Internally:
    #   if is_gradient_accumulation_boundary():
    #       reduce_scatter(gradients)   # ZeRO-2/3
    #       optimizer.step()            # update local partition
    #       all_gather(parameters)      # ZeRO-1/2 (ZeRO-3 defers)
    #       zero_grad()

Related Pages

Implemented By
