Implementation:Deepspeedai DeepSpeed PipelineEngine Train Batch
Overview
A concrete tool, provided by the DeepSpeed library, for executing one complete pipeline-parallel training iteration: PipelineEngine.train_batch() runs a full 1F1B (one-forward-one-backward) training schedule across all micro-batches.
Description
PipelineEngine.train_batch() executes a complete 1F1B training schedule across all micro-batches. It loads micro-batches from the data iterator, sends activations between stages, executes forward/backward passes according to the TrainSchedule, accumulates gradients, and performs the optimizer step.
The method performs the following sequence:
- Verifies that gradients are enabled (raises RuntimeError if not).
- Handles curriculum learning activation shape resets if applicable.
- Sets the data iterator (if provided externally).
- Puts the model in training mode.
- Creates a TrainSchedule with micro_batches = gradient_accumulation_steps.
- Executes the schedule via _exec_schedule(), which dispatches each PipeInstruction to its handler.
- Aggregates the total loss across micro-batches and data-parallel ranks.
- Logs throughput and loss information at configured intervals.
- Returns the aggregated loss tensor.
The TrainSchedule generates 2 * (M + S - 1) total steps, where M is the number of micro-batches and S is the number of pipeline stages; each step produces a list of PipeInstruction objects. The engine's _INSTRUCTION_MAP maps each instruction type to the corresponding execution method (e.g., ForwardPass maps to _exec_forward_pass).
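The step count above can be sketched directly. This is an illustrative helper, not DeepSpeed code; only the 2 * (M + S - 1) formula comes from the documentation:

```python
def train_schedule_steps(micro_batches: int, stages: int) -> int:
    """Total 1F1B schedule steps per stage: one forward slot and one
    backward slot per micro-batch, plus warm-up/drain pipeline bubbles."""
    return 2 * (micro_batches + stages - 1)

# e.g. 8 micro-batches over 4 pipeline stages -> 22 steps per stage
print(train_schedule_steps(8, 4))  # -> 22
```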
Code Reference
- Repository: https://github.com/deepspeedai/DeepSpeed
- File (train_batch): deepspeed/runtime/pipe/engine.py, lines 337-425
- File (TrainSchedule): deepspeed/runtime/pipe/schedule.py, lines 189-299
train_batch signature:
def train_batch(self, data_iter=None) -> torch.Tensor
TrainSchedule signature:
class TrainSchedule(PipeSchedule):
def __init__(self, micro_batches, stages, stage_id)
Import:
# Accessed via PipelineEngine returned by deepspeed.initialize()
engine, _, _, _ = deepspeed.initialize(model=pipeline_model, ...)
loss = engine.train_batch(data_iter=train_iter)
I/O Contract
Inputs
| Parameter | Type | Required | Default | Description |
|---|---|---|---|---|
| data_iter | Iterator | No | None | Training data iterator. If not provided, uses the internal iterator set via set_dataiterator() or from deepspeed.initialize() training data. |
Data Iterator Requirements
The data iterator must yield tuples of (inputs, labels):
- inputs: A tensor or tuple of tensors consumed by the first stage.
- labels: A tensor or tuple of tensors consumed by the last stage (passed to loss_fn).
- A total of gradient_accumulation_steps entries will be pulled per train_batch() call.
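A minimal sketch of an iterator that satisfies this contract. The function name and shapes are hypothetical; only the (inputs, labels) tuple shape comes from the requirements above:

```python
import torch

def make_pipeline_iter(num_batches, batch_size=4, in_features=1024, num_classes=10):
    """Yield (inputs, labels) tuples shaped as train_batch() expects."""
    for _ in range(num_batches):
        inputs = torch.randn(batch_size, in_features)          # consumed by the first stage
        labels = torch.randint(0, num_classes, (batch_size,))  # consumed by the last stage
        yield (inputs, labels)
```

Because train_batch() pulls gradient_accumulation_steps tuples per call, the iterator must not be exhausted mid-batch (hence the RepeatingLoader in the usage example below).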
Outputs
| Output | Type | Description |
|---|---|---|
| loss | torch.Tensor | Aggregated loss: averaged across micro-batches, then averaged across data-parallel ranks, then broadcast to all pipeline stages. Model parameters are updated as a side effect. |
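The aggregation order in the table can be illustrated with plain means. This is a sketch of the stated semantics, not DeepSpeed's implementation (which averages locally and then all-reduces across data-parallel ranks):

```python
def aggregate_losses(micro_batch_losses_per_rank):
    """micro_batch_losses_per_rank: one list of per-micro-batch losses per
    data-parallel rank. Average within each rank first, then across ranks."""
    rank_means = [sum(losses) / len(losses) for losses in micro_batch_losses_per_rank]
    return sum(rank_means) / len(rank_means)

# Two DP ranks, two micro-batches each
print(aggregate_losses([[1.0, 3.0], [2.0, 4.0]]))  # -> 2.5
```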
Usage Example
import deepspeed
from deepspeed.pipe import PipelineModule, LayerSpec
from deepspeed.utils import RepeatingLoader
import torch
import torch.nn as nn
# Setup model and engine
layers = [LayerSpec(nn.Linear, 1024, 1024) for _ in range(24)]
model = PipelineModule(layers=layers, num_stages=4,
loss_fn=nn.CrossEntropyLoss())
engine, _, _, _ = deepspeed.initialize(
model=model,
config={"train_batch_size": 32, "train_micro_batch_size_per_gpu": 4}
)
# Create a repeating data loader to avoid StopIteration
train_loader = RepeatingLoader(my_dataloader)  # my_dataloader: your existing DataLoader
train_iter = iter(train_loader)
# Training loop
for step in range(total_steps):
loss = engine.train_batch(data_iter=train_iter)
if step % log_interval == 0:
print(f"Step {step}, Loss: {loss.item()}")
Related Pages
- Principle:Deepspeedai_DeepSpeed_Pipeline_Training_Schedule
- Implementation:Deepspeedai_DeepSpeed_PipelineEngine_Init
- Implementation:Deepspeedai_DeepSpeed_PipelineEngine_Eval_Batch
Knowledge Sources
Last updated: 2026-02-09 00:00 GMT