Implementation:Deepspeedai DeepSpeed PipelineEngine Eval Batch

From Leeroopedia


Overview

PipelineEngine.eval_batch(), provided by the DeepSpeed library, is a concrete tool for evaluating pipeline-parallel models. It executes a forward-only pipeline schedule for evaluation, with support for returning logits, configurable loss reduction, and broadcasting results to all ranks.

Description

PipelineEngine.eval_batch() executes a forward-only pipeline schedule for evaluation. It supports returning logits from the last stage, computing loss, reducing outputs across micro-batches, and broadcasting results to all ranks.

The method performs the following sequence:

  1. Sets eval_return_logits flag if logits are requested.
  2. Puts the model in eval mode via self.module.eval().
  3. Handles curriculum learning activation shape resets if applicable.
  4. Saves the current training iterator and sets the evaluation data iterator.
  5. Creates an InferenceSchedule with the specified or default number of micro-batches.
  6. Inserts a dist.barrier() to prevent deadlocks with sequential eval calls.
  7. Executes the schedule under torch.no_grad() context.
  8. On the last stage, reduces forward outputs across micro-batches using the configured method.
  9. Optionally broadcasts the loss from the last stage to all pipeline stages.
  10. Restores the training data iterator.
  11. Returns the evaluation output (and optionally logits).

Two related methods handle pipeline checkpointing: module_state_dict() saves per-layer state dicts via PipelineModule.save_state_dict(), and load_module_state_dict() loads per-layer checkpoints via PipelineModule.load_state_dir().
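
The per-layer checkpoint pattern can be illustrated with a small dictionary-based sketch. The helper names and key format below are hypothetical; real DeepSpeed writes one checkpoint file per pipeline layer rather than building an in-memory dict.

```python
def save_per_layer(layers):
    """Save each layer's state under its own key, mimicking per-layer save_state_dict()."""
    return {f"layer_{i:02d}": dict(layer) for i, layer in enumerate(layers)}


def load_per_layer(checkpoint, layers):
    """Restore each layer's state from its own entry, mimicking load_state_dir()."""
    for i, layer in enumerate(layers):
        layer.update(checkpoint[f"layer_{i:02d}"])
```

Keeping one entry (or file) per layer is what lets each pipeline stage save and load only the layers it owns.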

Code Reference

  • Repository: https://github.com/deepspeedai/DeepSpeed
  • File (eval_batch): deepspeed/runtime/pipe/engine.py, lines 427-514
  • File (module_state_dict/load_module_state_dict): deepspeed/runtime/pipe/engine.py, lines 1308-1346
  • File (InferenceSchedule): deepspeed/runtime/pipe/schedule.py, lines 135-186

eval_batch signature:

def eval_batch(self,
               data_iter,
               return_logits=False,
               compute_loss=True,
               reduce_output='avg',
               bcast_loss=True,
               num_micro_batches=None)
# Returns a loss tensor, or the tuple (loss, logits) when return_logits=True.

Import:

# Accessed via the PipelineEngine returned by deepspeed.initialize()
import deepspeed

engine, _, _, _ = deepspeed.initialize(model=pipeline_model, ...)
loss = engine.eval_batch(data_iter=eval_iter)

I/O Contract

Inputs

| Parameter | Type | Required | Default | Description |
|-----------|------|----------|---------|-------------|
| data_iter | Iterator | Yes | (none) | Evaluation data iterator yielding (inputs, labels) tuples |
| return_logits | bool | No | False | If True, return model outputs (logits) alongside loss |
| compute_loss | bool | No | True | If True, compute loss using the model's loss_fn |
| reduce_output | str or None | No | 'avg' | Reduction method: 'avg' to average across micro-batches, or None for no reduction |
| bcast_loss | bool | No | True | If True, broadcast the loss from the last stage to all pipeline stages |
| num_micro_batches | int | No | None | Override the number of micro-batches; defaults to gradient_accumulation_steps |
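
bcast_loss matters because in a pipeline only the last stage computes the loss. The broadcast can be pictured with a list of per-stage values; this pure-Python function is an assumed stand-in for the dist.broadcast collective, not DeepSpeed's implementation.

```python
def broadcast_loss(stage_losses, bcast_loss=True):
    """stage_losses holds one entry per pipeline stage; only the last stage
    has a real loss. With bcast_loss=True, every stage receives that value."""
    if bcast_loss:
        return [stage_losses[-1]] * len(stage_losses)
    return stage_losses  # bcast_loss=False: earlier stages keep no valid loss
```

With bcast_loss=False, code running on earlier stages must not assume the returned value is meaningful.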

Outputs

| Output | Type | Description |
|--------|------|-------------|
| eval_output | torch.Tensor | Evaluation loss tensor (reduced and optionally broadcast) |
| logits (optional) | tuple | If return_logits=True, the return value is (eval_output, logits), where logits are the raw model outputs from the last stage |
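
Because the return shape depends on return_logits, callers that toggle it may want to normalize the result. This tiny helper is illustrative only, not part of DeepSpeed:

```python
def unpack_eval_result(result):
    """Normalize eval_batch's return value to a (loss, logits_or_None) pair."""
    if isinstance(result, tuple):
        loss, logits = result
        return loss, logits
    return result, None
```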

Usage Example

import deepspeed
from deepspeed.pipe import PipelineModule

# Assume engine is already initialized
engine, _, _, _ = deepspeed.initialize(model=pipeline_model, config=config)

# Basic evaluation (eval_batch puts the module in eval mode itself)
total_loss = 0.0
for step in range(eval_steps):
    # Each call consumes one full batch of micro-batches from eval_iter
    loss = engine.eval_batch(data_iter=eval_iter)
    total_loss += loss.item()
avg_loss = total_loss / eval_steps
print(f"Average eval loss: {avg_loss:.4f}")

# Evaluation with logits
loss, logits = engine.eval_batch(
    data_iter=eval_iter,
    return_logits=True
)

# Evaluation with custom micro-batch count
loss = engine.eval_batch(
    data_iter=eval_iter,
    num_micro_batches=8,
    reduce_output='avg',
    bcast_loss=True
)

# Saving and loading pipeline checkpoints
engine.save_checkpoint("checkpoints/", tag="step_1000")
engine.load_checkpoint("checkpoints/", tag="step_1000")


Last updated: 2026-02-09 00:00 GMT
