Principle: deepspeedai/DeepSpeed ZeRO Training Loop
| Knowledge Sources | |
|---|---|
| Domains | Distributed_Training, Gradient_Optimization, Memory_Optimization |
| Last Updated | 2026-02-09 00:00 GMT |
Overview
The distributed training loop pattern using DeepSpeed's engine abstraction for backward propagation with ZeRO-optimized gradient handling and optimizer stepping.
Description
The ZeRO Training Loop replaces standard PyTorch loss.backward() and optimizer.step() with DeepSpeed engine methods that handle gradient accumulation, mixed-precision scaling, ZeRO-partitioned gradient allreduce, and optimizer state management automatically.
The loop consists of two critical engine methods:
- engine.backward(loss): Handles loss scaling (fp16 dynamic scaling or static scaling), gradient computation via autograd, and gradient accumulation across micro-batches. For ZeRO stages 1-3, gradients are computed but not yet synchronized -- synchronization is deferred to the accumulation boundary.
- engine.step(): Handles gradient synchronization and the optimizer update at the gradient accumulation boundary:
  - ZeRO Stage 1: AllReduce gradients, then each rank updates its 1/N partition of optimizer states and parameters
  - ZeRO Stage 2: ReduceScatter gradients (each rank receives only its 1/N partition), then each rank updates its local partition
  - ZeRO Stage 3: ReduceScatter gradients, each rank updates its parameter shard, and parameters remain partitioned for the next forward pass
The engine automatically tracks gradient accumulation boundaries and only performs synchronization and optimizer updates at the boundary.
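The boundary bookkeeping can be sketched with a simple micro-step counter. This is an illustrative model, not DeepSpeed internals; the `AccumulationTracker` class is hypothetical, though the boundary query is modeled on the engine's `is_gradient_accumulation_boundary()` method:

```python
# Hypothetical sketch of gradient-accumulation boundary tracking.
# Not DeepSpeed's implementation; it models the observable behavior only.
class AccumulationTracker:
    def __init__(self, gradient_accumulation_steps: int):
        self.gas = gradient_accumulation_steps
        self.micro_steps = 0

    def backward(self):
        # Called once per micro-batch; only counts, never synchronizes.
        self.micro_steps += 1

    def is_gradient_accumulation_boundary(self) -> bool:
        # True on every gas-th micro-batch: this is where gradient
        # synchronization and optimizer.step() actually happen.
        return self.micro_steps % self.gas == 0

tracker = AccumulationTracker(gradient_accumulation_steps=4)
boundaries = []
for _ in range(8):
    tracker.backward()
    boundaries.append(tracker.is_gradient_accumulation_boundary())
# Boundaries fall on micro-batches 4 and 8; all other steps accumulate locally.
```

On non-boundary steps, engine.step() is still called in the loop; it simply returns after the accumulation bookkeeping without touching the optimizer.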
Usage
Replace loss.backward() with engine.backward(loss) and optimizer.step() / optimizer.zero_grad() with engine.step() in the training loop. The engine handles gradient zeroing, accumulation, synchronization, and optimizer updates automatically.
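The accumulation and ZeRO behavior described above is driven by the JSON config passed to deepspeed.initialize(), which returns the engine used in the loop. A minimal illustrative config (the key names are standard DeepSpeed config fields; the specific values here are arbitrary examples):

```json
{
  "train_micro_batch_size_per_gpu": 4,
  "gradient_accumulation_steps": 8,
  "zero_optimization": { "stage": 2 },
  "fp16": { "enabled": true }
}
```

With this config, engine.step() performs the ReduceScatter and optimizer update only on every 8th micro-batch.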
Theoretical Basis
ZeRO gradient handling varies by stage:
- Stage 1: Partitions optimizer states. Each rank maintains a full copy of gradients (AllReduce), but only updates 1/N of the optimizer states and parameters. After update, an AllGather broadcasts updated parameters.
- Stage 2: Additionally partitions gradients. Uses ReduceScatter instead of AllReduce so each rank receives only its 1/N gradient partition. Each rank updates its local optimizer states and parameters. AllGather broadcasts updated parameters.
- Stage 3: Additionally partitions parameters. Same as Stage 2 for gradients, but parameters remain partitioned between forward/backward passes, gathered on-demand via AllGather.
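The memory consequence of the AllReduce vs. ReduceScatter choice can be illustrated with plain Python lists standing in for per-rank gradient buffers. This is a conceptual sketch of the collectives' results, not DeepSpeed or NCCL internals:

```python
# Conceptual sketch: AllReduce vs ReduceScatter over N ranks, with lists
# as stand-ins for gradient buffers. Illustrative only.
def all_reduce(per_rank_grads):
    # Every rank ends up with the full summed gradient (ZeRO-1).
    summed = [sum(vals) for vals in zip(*per_rank_grads)]
    return [summed[:] for _ in per_rank_grads]

def reduce_scatter(per_rank_grads):
    # Each rank ends up with only its 1/N slice of the summed gradient
    # (ZeRO-2/3), so gradient memory per rank shrinks by a factor of N.
    n = len(per_rank_grads)
    summed = [sum(vals) for vals in zip(*per_rank_grads)]
    shard = len(summed) // n
    return [summed[r * shard:(r + 1) * shard] for r in range(n)]

grads = [[1.0, 2.0, 3.0, 4.0],      # rank 0's local gradients
         [10.0, 20.0, 30.0, 40.0]]  # rank 1's local gradients
allreduced = all_reduce(grads)      # both ranks hold [11.0, 22.0, 33.0, 44.0]
scattered = reduce_scatter(grads)   # rank 0: [11.0, 22.0]; rank 1: [33.0, 44.0]
```

Real implementations average rather than sum when gradients should be mean-reduced over the data-parallel group, but the partitioning pattern is the same.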
Gradient accumulation: Handled by deferring allreduce/reduce-scatter until the accumulation boundary. During non-boundary micro-batch steps, gradients are accumulated locally. At the boundary, accumulated gradients are synchronized across ranks.
Mixed precision interaction:
- fp16: Dynamic loss scaling with automatic scale adjustment on overflow detection
- bf16: No loss scaling needed (larger dynamic range)
- AMP: NVIDIA Apex AMP integration with delayed unscaling during gradient accumulation
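The fp16 dynamic loss-scaling scheme can be sketched as follows. This is a simplified, hypothetical scaler: real fp16 scalers (including DeepSpeed's) have more configuration knobs, but the back-off-on-overflow / grow-after-stable-window logic is the core idea:

```python
import math

# Simplified dynamic loss scaler (illustrative; real implementations differ
# in details such as minimum scale and hysteresis).
class DynamicLossScaler:
    def __init__(self, init_scale=2.0**16, scale_factor=2.0, scale_window=1000):
        self.scale = init_scale
        self.scale_factor = scale_factor
        self.scale_window = scale_window  # overflow-free steps before growing
        self.good_steps = 0

    def scale_loss(self, loss):
        # Multiply the loss before backward so small fp16 gradients
        # do not underflow to zero.
        return loss * self.scale

    def update(self, grads):
        # Inf/NaN in the gradients means the scale was too large.
        overflow = any(math.isinf(g) or math.isnan(g) for g in grads)
        if overflow:
            self.scale /= self.scale_factor  # back off
            self.good_steps = 0
            return False                     # caller skips optimizer.step()
        self.good_steps += 1
        if self.good_steps % self.scale_window == 0:
            self.scale *= self.scale_factor  # grow after a stable window
        return True

scaler = DynamicLossScaler(init_scale=8.0, scale_factor=2.0, scale_window=2)
ok1 = scaler.update([0.1, 0.2])      # no overflow: step proceeds, scale 8.0
ok2 = scaler.update([float("inf")])  # overflow: step skipped, scale drops to 4.0
```

When update() returns False, the engine skips the optimizer step for that boundary entirely, which is why occasional "overflow, skipping step" events are expected early in fp16 training.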
Paper: ZeRO: Memory Optimizations Toward Training Trillion Parameter Models
Pseudo-code:

```python
# Abstract ZeRO training loop
for micro_batch, labels in dataloader:
    outputs = engine.forward(micro_batch)
    loss = criterion(outputs, labels)

    # Handles loss scaling + gradient computation + accumulation
    engine.backward(loss)

    # At accumulation boundary: sync gradients, update optimizer, zero grads
    engine.step()
    # Internally:
    #   if is_gradient_accumulation_boundary():
    #       reduce_scatter(gradients)   # ZeRO-2/3
    #       optimizer.step()            # update local partition
    #       all_gather(parameters)      # ZeRO-1/2 (ZeRO-3 defers)
    #       zero_grad()
```