
Implementation:Deepspeedai DeepSpeed DeepSpeedEngine Backward Step

From Leeroopedia


Knowledge Sources
Domains Distributed_Training, Gradient_Optimization, Memory_Optimization
Last Updated 2026-02-09 00:00 GMT

Overview

A concrete tool, provided by the DeepSpeed library, for executing distributed backward propagation and optimizer stepping with ZeRO optimization.

Description

DeepSpeedEngine.backward() executes backward propagation with ZeRO-aware gradient handling. It performs:

  • Loss scaling: Applies fp16 dynamic loss scaling via the ZeRO optimizer's loss scaler, or via the torch autocast GradScaler when native AMP is enabled
  • Gradient computation: Calls loss.backward() with support for retain_graph and create_graph (eigenvalue computation)
  • Gradient accumulation scaling: Optionally scales the loss by gradient_accumulation_steps (controlled by scale_wrt_gas)
  • AMP support: For NVIDIA Apex AMP, wraps backward in amp.scale_loss() with delayed unscaling during accumulation
  • Compiled autograd: Supports torch.compile integration via compiled_autograd context
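The two scaling behaviors above (dynamic loss scaling and accumulation scaling) can be sketched in plain Python. This is an illustrative sketch, not DeepSpeed's actual code; the function name and structure are assumptions:

```python
def scale_loss_for_backward(loss, loss_scale=1.0, grad_acc_steps=1, scale_wrt_gas=True):
    """Sketch of the loss adjustments applied before loss.backward().

    - fp16 dynamic loss scaling multiplies the loss by the current scale
      so that small gradients survive the limited fp16 dynamic range
      (the optimizer later unscales the gradients before the update).
    - With gradient accumulation, the loss is divided by
      gradient_accumulation_steps (when scale_wrt_gas is True) so the
      accumulated gradient is an average over the micro-batches.
    """
    scaled = loss * loss_scale
    if scale_wrt_gas and grad_acc_steps > 1:
        scaled = scaled / grad_acc_steps
    return scaled

# e.g. loss 2.0, loss scale 1024, 4 accumulation steps
print(scale_loss_for_backward(2.0, loss_scale=1024.0, grad_acc_steps=4))  # 512.0
```

Setting scale_wrt_gas=False skips only the accumulation division, matching the parameter exposed on backward().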

DeepSpeedEngine.step() performs the optimizer update at gradient accumulation boundaries. It:

  • Checks accumulation boundary: Only performs synchronization and update at the boundary
  • Calls _take_model_step(): Which triggers gradient allreduce/reduce-scatter, optimizer.step(), LR scheduler step, and gradient zeroing
  • Handles progressive layer drop: Updates PLD state if enabled
  • Handles eigenvalue computation: Computes block eigenvalues for quantization if enabled
  • Manages flops profiling: Triggers profiler at the configured step
  • Increments global step counter: Tracks total training steps
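The boundary bookkeeping above can be sketched as a small counter. This is a minimal sketch of the control flow, not DeepSpeed's internals; the class and method names are illustrative:

```python
class AccumulationBoundaryCounter:
    """Sketch of step()'s accumulation-boundary bookkeeping."""

    def __init__(self, gradient_accumulation_steps):
        self.gas = gradient_accumulation_steps
        self.micro_steps = 0
        self.global_steps = 0

    def is_gradient_accumulation_boundary(self):
        # A boundary occurs after every `gas` micro-steps.
        return (self.micro_steps + 1) % self.gas == 0

    def step(self):
        boundary = self.is_gradient_accumulation_boundary()
        if boundary:
            # At the boundary, DeepSpeed would allreduce/reduce-scatter
            # gradients, call optimizer.step(), step the LR scheduler,
            # zero the gradients, and increment the global step counter.
            self.global_steps += 1
        self.micro_steps += 1
        return boundary

counter = AccumulationBoundaryCounter(gradient_accumulation_steps=4)
updates = sum(counter.step() for _ in range(8))
print(updates, counter.global_steps)  # 2 2
```

Calling step() on every micro-batch is therefore safe: only every fourth call (here) triggers an actual parameter update.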

Usage

Use engine.backward(loss) and engine.step() in the training loop instead of loss.backward() and optimizer.step(). These are accessed via the engine object returned by deepspeed.initialize().

Code Reference

Source Location

  • Repository: DeepSpeed
  • File: deepspeed/runtime/engine.py
  • Lines: 2547-2589 (backward), 2722-2770 (step)

Signature

@instrument_w_nvtx
def backward(self, loss, retain_graph=False, scale_wrt_gas=True):
    r"""Execute backward pass on the loss

    Arguments:
        loss: Torch tensor on which to execute backward propagation
        retain_graph: bool, default: false
            forward on user defined choice of retain_graph
        scale_wrt_gas: bool, default: true
            whether to scale gradients and return value by gradient accumulation steps
    """
def step(self, lr_kwargs=None):
    r"""Execute the weight update step after forward and backward propagation
    on effective_train_batch.
    """

Import

# Accessed via the engine returned by deepspeed.initialize()
import deepspeed

engine, optimizer, _, _ = deepspeed.initialize(model=model, config=config)

# Then use:
engine.backward(loss)
engine.step()

I/O Contract

Inputs (backward)

Name Type Required Description
loss torch.Tensor Yes Scalar loss tensor on which to execute backward propagation
retain_graph bool No Whether to retain the computation graph after backward (default: False)
scale_wrt_gas bool No Whether to scale gradients by gradient accumulation steps (default: True)

Inputs (step)

Name Type Required Description
lr_kwargs dict No Extra keyword arguments passed to the learning rate scheduler step function

Outputs

Name Type Description
(side effect) None Updated model parameters synchronized across all ranks; global step counter incremented at accumulation boundary
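Per the contract above, lr_kwargs is forwarded to the learning rate scheduler's step function at the accumulation boundary. A hedged sketch with a stand-in scheduler (the keyword-forwarding pattern is an assumption based on the parameter's description, not confirmed DeepSpeed source):

```python
class DummyScheduler:
    """Stand-in scheduler that records what step() receives."""

    def __init__(self):
        self.received = None

    def step(self, **kwargs):
        self.received = kwargs

def scheduler_step(lr_scheduler, lr_kwargs=None):
    # Forward any extra keyword arguments to the scheduler,
    # as described for engine.step(lr_kwargs=...).
    lr_scheduler.step(**(lr_kwargs or {}))

sched = DummyScheduler()
scheduler_step(sched, lr_kwargs={"epoch": 3})
print(sched.received)  # {'epoch': 3}
```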

Usage Examples

import deepspeed
import torch
import torch.nn as nn

model = nn.Linear(1024, 10)
config = {
    "train_batch_size": 32,
    "gradient_accumulation_steps": 4,
    "zero_optimization": {"stage": 2},
    "fp16": {"enabled": True},
    "optimizer": {"type": "Adam", "params": {"lr": 1e-4}},
}

engine, optimizer, _, _ = deepspeed.initialize(
    model=model, config=config, model_parameters=model.parameters()
)

criterion = nn.CrossEntropyLoss()

# Training loop
for batch_idx, (inputs, labels) in enumerate(dataloader):
    outputs = engine(inputs)
    loss = criterion(outputs, labels)

    # Replaces loss.backward() - handles scaling and accumulation
    engine.backward(loss)

    # Replaces optimizer.step() + optimizer.zero_grad()
    # Only synchronizes and updates at accumulation boundary
    engine.step()

# With retain_graph for multiple backward passes
# (criterion1 and criterion2 are two separate loss functions)
loss1 = criterion1(outputs, labels)
loss2 = criterion2(outputs, labels)
engine.backward(loss1, retain_graph=True)
engine.backward(loss2)
engine.step()
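As a sanity check on the config in the first example: DeepSpeed requires that train_batch_size equal train_micro_batch_size_per_gpu × gradient_accumulation_steps × world_size. A quick arithmetic helper (the function name is hypothetical):

```python
def effective_train_batch(micro_batch_per_gpu, grad_acc_steps, world_size):
    # DeepSpeed validates: train_batch_size ==
    #   train_micro_batch_size_per_gpu * gradient_accumulation_steps * world_size
    return micro_batch_per_gpu * grad_acc_steps * world_size

# The config above (train_batch_size=32, gradient_accumulation_steps=4)
# implies a micro-batch of 8 per GPU on a single-GPU run:
print(effective_train_batch(8, 4, 1))  # 32
```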

Related Pages

Implements Principle
