
Implementation:Deepspeedai DeepSpeed PipelineEngine Train Batch

From Leeroopedia


Overview

A concrete tool provided by the DeepSpeed library for executing one complete pipeline-parallel training iteration. PipelineEngine.train_batch() runs a full 1F1B (one-forward, one-backward) schedule across all micro-batches and concludes with an optimizer step.

Description

PipelineEngine.train_batch() executes a complete 1F1B training schedule across all micro-batches. It pulls micro-batches from the data iterator, exchanges activations and gradients between adjacent stages, executes forward/backward passes according to the TrainSchedule, accumulates gradients, and performs the optimizer step.

The method performs the following sequence:

  1. Verifies that gradients are enabled (raises RuntimeError if not).
  2. Handles curriculum learning activation shape resets if applicable.
  3. Sets the data iterator (if provided externally).
  4. Puts the model in training mode.
  5. Creates a TrainSchedule with micro_batches = gradient_accumulation_steps.
  6. Executes the schedule via _exec_schedule(), which dispatches each PipeInstruction to its handler.
  7. Aggregates the total loss across micro-batches and data-parallel ranks.
  8. Logs throughput and loss information at configured intervals.
  9. Returns the aggregated loss tensor.

The TrainSchedule generates 2 * (M + S - 1) total steps, where M is the number of micro-batches and S is the number of pipeline stages; each step produces a list of PipeInstruction objects. The engine's _INSTRUCTION_MAP maps each instruction type to the corresponding execution method (e.g., ForwardPass maps to _exec_forward_pass).
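As a concrete illustration, the following minimal sketch (not DeepSpeed source; the instruction classes and handlers are hypothetical stand-ins) reproduces the step-count arithmetic and the map-based dispatch pattern described above:

M, S = 8, 4                    # micro-batches and pipeline stages
total_steps = 2 * (M + S - 1)  # 2 * (8 + 4 - 1) = 22 steps

# Hypothetical instruction types and handlers mirroring _INSTRUCTION_MAP
class ForwardPass: pass
class BackwardPass: pass

def exec_forward(cmd):
    pass  # would run the forward pass for one micro-batch

def exec_backward(cmd):
    pass  # would run the backward pass for one micro-batch

INSTRUCTION_MAP = {ForwardPass: exec_forward, BackwardPass: exec_backward}

for cmd in [ForwardPass(), BackwardPass()]:
    INSTRUCTION_MAP[type(cmd)](cmd)  # dispatch each instruction to its handler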

Code Reference

train_batch signature:

def train_batch(self, data_iter=None) -> torch.Tensor

TrainSchedule signature:

class TrainSchedule(PipeSchedule):
    def __init__(self, micro_batches, stages, stage_id)

Import:

import deepspeed

# Accessed via the PipelineEngine returned by deepspeed.initialize()
engine, _, _, _ = deepspeed.initialize(model=pipeline_model, ...)
loss = engine.train_batch(data_iter=train_iter)
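The schedule itself can also be inspected directly. A small sketch, assuming the class still lives at deepspeed.runtime.pipe.schedule in current source trees:

from deepspeed.runtime.pipe.schedule import TrainSchedule

# Build the schedule that the first stage (stage_id=0) of a 4-stage
# pipeline would execute for 8 micro-batches, then list each step's
# instruction names
sched = TrainSchedule(micro_batches=8, stages=4, stage_id=0)
for step_id, instructions in enumerate(sched.steps()):
    print(step_id, [type(cmd).__name__ for cmd in instructions])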

I/O Contract

Inputs

  • data_iter (Iterator, optional; default: None): Training data iterator. If not provided, uses the internal iterator set via set_dataiterator() or the training data passed to deepspeed.initialize().

Data Iterator Requirements

The data iterator must yield tuples of (inputs, labels):

  • inputs: A tensor or tuple of tensors consumed by the first stage.
  • labels: A tensor or tuple of tensors consumed by the last stage (passed to loss_fn).
  • A total of gradient_accumulation_steps entries will be pulled per train_batch() call.
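As a concrete illustration, a hand-rolled iterator satisfying this contract might look like the following (toy_batches is a hypothetical helper; the shapes match the usage example below):

import torch

def toy_batches(num_batches, micro_batch_size=4):
    # Each yielded tuple is (inputs, labels): inputs feed the first
    # pipeline stage, labels feed the last stage's loss_fn
    for _ in range(num_batches):
        inputs = torch.randn(micro_batch_size, 1024)
        labels = torch.randint(0, 1024, (micro_batch_size,))
        yield inputs, labels

train_iter = toy_batches(num_batches=1000)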

Outputs

  • loss (torch.Tensor): Aggregated loss, averaged across micro-batches, then averaged across data-parallel ranks, then broadcast to all pipeline stages. Model parameters are updated as a side effect.
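The aggregation order can be pictured with a single-process sketch (loss values are made up; the collective communication is omitted):

import torch

# Per-micro-batch losses collected on the last pipeline stage
micro_losses = [torch.tensor(2.0), torch.tensor(1.5),
                torch.tensor(1.0), torch.tensor(0.5)]
loss = torch.stack(micro_losses).mean()  # average over micro-batches -> 1.25
# In the real engine this mean is then averaged across data-parallel
# ranks and broadcast to every pipeline stage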

Usage Example

import deepspeed
from deepspeed.pipe import PipelineModule, LayerSpec
from deepspeed.utils import RepeatingLoader
import torch
import torch.nn as nn

# Set up a 24-layer model split across 4 pipeline stages
layers = [LayerSpec(nn.Linear, 1024, 1024) for _ in range(24)]
model = PipelineModule(layers=layers, num_stages=4,
                       loss_fn=nn.CrossEntropyLoss())

# train_batch() performs the optimizer step, so an optimizer must be
# supplied via the config (as here) or the optimizer argument
engine, _, _, _ = deepspeed.initialize(
    model=model,
    config={
        "train_batch_size": 32,
        "train_micro_batch_size_per_gpu": 4,
        "optimizer": {"type": "Adam", "params": {"lr": 1e-3}},
    }
)

# Wrap the user-provided DataLoader (my_dataloader) so the iterator
# never raises StopIteration mid-schedule
train_loader = RepeatingLoader(my_dataloader)
train_iter = iter(train_loader)

# Training loop; total_steps and log_interval are user-defined
for step in range(total_steps):
    loss = engine.train_batch(data_iter=train_iter)
    if step % log_interval == 0:
        print(f"Step {step}, Loss: {loss.item()}")


Last updated: 2026-02-09 00:00 GMT
