# Implementation: deepspeedai/DeepSpeed DeepSpeedEngine Backward/Step
| Knowledge Sources | |
|---|---|
| Domains | Distributed_Training, Gradient_Optimization, Memory_Optimization |
| Last Updated | 2026-02-09 00:00 GMT |
## Overview
Concrete tool for executing distributed backward propagation and optimizer stepping with ZeRO optimization provided by the DeepSpeed library.
## Description

DeepSpeedEngine.backward() executes backward propagation with ZeRO-aware gradient handling. It performs:
- Loss scaling: applies fp16 dynamic loss scaling through the ZeRO optimizer's loss-scaled backward path, or through a torch GradScaler when torch autocast is enabled
- Gradient computation: calls loss.backward() with support for retain_graph and create_graph (the latter used for eigenvalue computation)
- Gradient accumulation scaling: optionally divides the loss by gradient_accumulation_steps (controlled by scale_wrt_gas)
- AMP support: for NVIDIA Apex AMP, wraps backward in amp.scale_loss() with delayed unscaling during accumulation
- Compiled autograd: supports torch.compile integration via a compiled_autograd context
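The gradient accumulation scaling can be sketched in isolation. This is a hypothetical helper (not DeepSpeed source); the function name `scale_loss_wrt_gas` is an assumption made for illustration, but the arithmetic mirrors what the engine does: dividing each micro-batch loss by the number of accumulation steps so the summed gradients average correctly over the effective batch.

```python
# Hypothetical sketch of the scale_wrt_gas behavior in engine.backward():
# divide the loss by gradient_accumulation_steps so that gradients summed
# across micro-batches match the mean over the effective batch.
def scale_loss_wrt_gas(loss_value: float,
                       gradient_accumulation_steps: int,
                       scale_wrt_gas: bool = True) -> float:
    if scale_wrt_gas and gradient_accumulation_steps > 1:
        return loss_value / gradient_accumulation_steps
    return loss_value

print(scale_loss_wrt_gas(2.0, 4))                       # 0.5
print(scale_loss_wrt_gas(2.0, 4, scale_wrt_gas=False))  # 2.0
```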
DeepSpeedEngine.step() performs the optimizer update at gradient accumulation boundaries. It:
- Checks accumulation boundary: Only performs synchronization and update at the boundary
- Calls _take_model_step(): Which triggers gradient allreduce/reduce-scatter, optimizer.step(), LR scheduler step, and gradient zeroing
- Handles progressive layer drop: Updates PLD state if enabled
- Handles eigenvalue computation: Computes block eigenvalues for quantization if enabled
- Manages flops profiling: Triggers profiler at the configured step
- Increments global step counter: Tracks total training steps
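The boundary-gated behavior of step() can be sketched as follows. This is a simplified stand-in (the class `StepGate` is invented for illustration, not DeepSpeed's implementation): only every gradient_accumulation_steps-th call triggers a real optimizer update and increments the global step counter.

```python
# Hypothetical sketch of the accumulation-boundary check in engine.step().
class StepGate:
    def __init__(self, gradient_accumulation_steps: int):
        self.gas = gradient_accumulation_steps
        self.micro_steps = 0
        self.global_steps = 0

    def is_gradient_accumulation_boundary(self) -> bool:
        return (self.micro_steps + 1) % self.gas == 0

    def step(self) -> bool:
        """Return True when a real optimizer update happened."""
        at_boundary = self.is_gradient_accumulation_boundary()
        if at_boundary:
            # Here the real engine would allreduce/reduce-scatter gradients,
            # call optimizer.step(), step the LR scheduler, and zero grads.
            self.global_steps += 1
        self.micro_steps += 1
        return at_boundary

gate = StepGate(gradient_accumulation_steps=4)
updates = [gate.step() for _ in range(8)]
print(updates)            # [False, False, False, True, False, False, False, True]
print(gate.global_steps)  # 2
```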
## Usage
Use engine.backward(loss) and engine.step() in the training loop instead of loss.backward() and optimizer.step(). These are accessed via the engine object returned by deepspeed.initialize().
## Code Reference

### Source Location
- Repository: DeepSpeed
- File: deepspeed/runtime/engine.py
- Lines: 2547-2589 (backward), 2722-2770 (step)
### Signature

```python
@instrument_w_nvtx
def backward(self, loss, retain_graph=False, scale_wrt_gas=True):
    r"""Execute backward pass on the loss.

    Arguments:
        loss: torch tensor on which to execute backward propagation
        retain_graph: bool, default: False
            whether to retain the autograd graph after backward
        scale_wrt_gas: bool, default: True
            whether to scale the loss (and return value) by
            gradient accumulation steps
    """

def step(self, lr_kwargs=None):
    r"""Execute the weight update step after forward and backward
    propagation on the effective train batch.
    """
```
### Import

```python
# Accessed via the engine returned by deepspeed.initialize()
import deepspeed

engine, optimizer, _, _ = deepspeed.initialize(model=model, config=config)

# Then use:
engine.backward(loss)
engine.step()
```
## I/O Contract

### Inputs (backward)
| Name | Type | Required | Description |
|---|---|---|---|
| loss | torch.Tensor | Yes | Scalar loss tensor on which to execute backward propagation |
| retain_graph | bool | No | Whether to retain the computation graph after backward (default: False) |
| scale_wrt_gas | bool | No | Whether to scale gradients by gradient accumulation steps (default: True) |
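A quick numeric check of why scale_wrt_gas divides by the accumulation count: summing per-micro-batch contributions of loss / gradient_accumulation_steps reproduces the mean over the whole effective batch. The loss values below are illustrative, not taken from any real run.

```python
# Illustrative check of the scale_wrt_gas arithmetic: accumulating
# loss / gas across micro-batches equals the full effective-batch mean.
micro_batch_losses = [4.0, 2.0, 6.0, 8.0]   # made-up values
gas = len(micro_batch_losses)

accumulated = sum(loss / gas for loss in micro_batch_losses)
full_batch_mean = sum(micro_batch_losses) / gas
print(accumulated, full_batch_mean)  # 5.0 5.0
```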
### Inputs (step)
| Name | Type | Required | Description |
|---|---|---|---|
| lr_kwargs | dict | No | Extra keyword arguments passed to the learning rate scheduler step function |
### Outputs
| Name | Type | Description |
|---|---|---|
| (side effect) | None | Updated model parameters synchronized across all ranks; global step counter incremented at accumulation boundary |
## Usage Examples

```python
import deepspeed
import torch
import torch.nn as nn

model = nn.Linear(1024, 10)
config = {
    "train_batch_size": 32,
    "gradient_accumulation_steps": 4,
    "zero_optimization": {"stage": 2},
    "fp16": {"enabled": True},
    "optimizer": {"type": "Adam", "params": {"lr": 1e-4}},
}
engine, optimizer, _, _ = deepspeed.initialize(
    model=model, config=config, model_parameters=model.parameters()
)
criterion = nn.CrossEntropyLoss()

# Training loop
for batch_idx, (inputs, labels) in enumerate(dataloader):
    outputs = engine(inputs)
    loss = criterion(outputs, labels)
    # Replaces loss.backward() - handles scaling and accumulation
    engine.backward(loss)
    # Replaces optimizer.step() + optimizer.zero_grad();
    # only synchronizes and updates at the accumulation boundary
    engine.step()

# With retain_graph for multiple backward passes
# (criterion1 and criterion2 are two separate loss functions)
loss1 = criterion1(outputs, labels)
loss2 = criterion2(outputs, labels)
engine.backward(loss1, retain_graph=True)
engine.backward(loss2)
engine.step()
```