Implementation:ARISE Initiative Robomimic TrainUtils run epoch

From Leeroopedia
Knowledge Sources
Domains Robotics, Training, Optimization
Last Updated 2026-02-15 08:00 GMT

Overview

A concrete utility for running a single epoch of training or validation over batched demonstration data, provided by robomimic's training utilities module.

Description

The run_epoch function iterates over a DataLoader for a specified number of steps (or the full dataset), calling the algorithm's processing and training methods on each batch. It handles data iterator reset at epoch boundaries, collects per-step metrics, and returns averaged statistics. When validate=True, the model is set to eval mode and no gradient updates are performed.
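The control flow described above can be sketched in simplified form. This is an illustrative approximation, not the actual robomimic source: `MockAlgo`, its method names, and the toy batches are stand-ins chosen to mirror the behavior described (iterator reset at epoch boundaries, per-step metric collection, and averaging).

```python
# Simplified sketch of the run_epoch control flow described above.
# MockAlgo and the toy data are stand-ins, not robomimic classes.

class MockAlgo:
    """Stand-in for a robomimic Algo: processes a batch and returns metrics."""
    def set_train(self):
        self.training = True

    def set_eval(self):
        self.training = False

    def process_batch_for_training(self, batch):
        return batch  # a real Algo would move tensors to device, normalize, etc.

    def train_on_batch(self, batch, epoch, validate=False):
        # Return per-step metrics; no gradient step is taken when validate=True.
        return {"Loss": sum(batch) / len(batch)}


def run_epoch_sketch(model, data_loader, epoch, validate=False, num_steps=None):
    if validate:
        model.set_eval()
    else:
        model.set_train()
    if num_steps is None:
        num_steps = len(data_loader)

    data_iter = iter(data_loader)
    step_logs = []
    for _ in range(num_steps):
        try:
            batch = next(data_iter)
        except StopIteration:
            data_iter = iter(data_loader)  # reset iterator at epoch boundary
            batch = next(data_iter)
        input_batch = model.process_batch_for_training(batch)
        step_logs.append(model.train_on_batch(input_batch, epoch, validate=validate))

    # Average each metric across all steps.
    keys = step_logs[0].keys()
    return {k: sum(log[k] for log in step_logs) / len(step_logs) for k in keys}


batches = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
log = run_epoch_sketch(MockAlgo(), batches, epoch=1, num_steps=5)
print(log)  # "Loss" averaged over 5 steps; the iterator wraps around once
```

Note that `num_steps` may exceed the number of batches: the iterator simply restarts, so a "fixed steps per epoch" schedule decouples epoch length from dataset size.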

Usage

Call this function in the per-epoch training loop, once for training and optionally once for validation with validate=True. Requires a fully instantiated Algo and a DataLoader wrapping a SequenceDataset.

Code Reference

Source Location

  • Repository: robomimic
  • File: robomimic/utils/train_utils.py
  • Lines: L637-721

Signature

def run_epoch(model, data_loader, epoch, validate=False, num_steps=None, obs_normalization_stats=None):
    """
    Run an epoch of training or validation.

    Args:
        model (Algo instance): model to train
        data_loader (DataLoader instance): data loader that will be used to serve batches of data
        epoch (int): epoch number
        validate (bool): whether this is a training epoch or validation epoch
        num_steps (int): if provided, this epoch lasts for a fixed number of batches
        obs_normalization_stats (dict or None): if provided, maps observation keys to dicts
            with a "mean" and "std" of shape (1, ...)

    Returns:
        step_log_all (dict): dictionary of logged training metrics averaged across all batches
    """

Import

import robomimic.utils.train_utils as TrainUtils

# Call as:
train_log = TrainUtils.run_epoch(model=model, data_loader=train_loader, epoch=epoch)

I/O Contract

Inputs

  • model (Algo, required): algorithm instance to train or validate
  • data_loader (DataLoader, required): PyTorch DataLoader wrapping a SequenceDataset
  • epoch (int, required): current epoch number (passed to model.train_on_batch)
  • validate (bool, optional): if True, runs validation with no gradient updates; default False
  • num_steps (int, optional): fixed number of batches per epoch; default len(data_loader)
  • obs_normalization_stats (dict, optional): observation normalization statistics mapping observation keys to dicts with a "mean" and "std"

Outputs

  • step_log_all (dict): training metrics averaged across all batches; includes Time_Data_Loading, Time_Process_Batch, Time_Train_Batch, and Time_Log_Info (in minutes), Time_Epoch (total epoch time in minutes), plus algorithm-specific loss metrics
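A downstream consumer typically prefixes these keys before sending them to a logger. The sketch below is a hypothetical example, not robomimic code: the sample values and the `summarize` helper are assumptions, with only the key names taken from the table above.

```python
# Hypothetical consumer of the averaged metrics dict returned by run_epoch.
# Key names mirror the output table above; the values here are made up.
step_log_all = {
    "Time_Data_Loading": 0.02,  # minutes
    "Time_Train_Batch": 0.75,   # minutes
    "Time_Epoch": 0.80,         # minutes
    "Loss": 0.1234,             # algorithm-specific metric
}


def summarize(step_log, prefix="Train"):
    """Prefix keys for a logger and convert timing entries from minutes to seconds."""
    out = {}
    for k, v in step_log.items():
        if k.startswith("Time_"):
            out[f"{prefix}/{k} (s)"] = v * 60.0
        else:
            out[f"{prefix}/{k}"] = v
    return out


summary = summarize(step_log_all)
print(summary["Train/Time_Epoch (s)"])  # 48.0
```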

Usage Examples

Training and Validation Epochs

import torch

import robomimic.utils.train_utils as TrainUtils

for epoch in range(1, num_epochs + 1):
    # Training epoch
    train_log = TrainUtils.run_epoch(
        model=model,
        data_loader=train_loader,
        epoch=epoch,
        validate=False,
        num_steps=config.experiment.epoch_every_n_steps,
        obs_normalization_stats=obs_normalization_stats,
    )
    print(f"Train Loss: {train_log.get('Loss', 'N/A')}")

    # Validation epoch
    if valid_loader is not None:
        with torch.no_grad():
            valid_log = TrainUtils.run_epoch(
                model=model,
                data_loader=valid_loader,
                epoch=epoch,
                validate=True,
                num_steps=config.experiment.validation_epoch_every_n_steps,
                obs_normalization_stats=obs_normalization_stats,
            )
        print(f"Valid Loss: {valid_log.get('Loss', 'N/A')}")
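The per-epoch validation metrics can also drive checkpoint selection. The following is a hedged sketch under stated assumptions: the "Loss" key matches the metric name used above, while `track_best` and the save callback are hypothetical helpers, not part of robomimic's API.

```python
# Hypothetical best-checkpoint tracking driven by per-epoch validation logs.
# The "Loss" key is assumed; save_checkpoint is any user-supplied callback.
def track_best(valid_logs, save_checkpoint):
    """Save a checkpoint whenever validation loss improves; return the best epoch."""
    best_loss = float("inf")
    best_epoch = None
    for epoch, log in enumerate(valid_logs, start=1):
        loss = log.get("Loss")
        if loss is not None and loss < best_loss:
            best_loss = loss
            best_epoch = epoch
            save_checkpoint(epoch)
    return best_epoch, best_loss


saved = []
valid_logs = [{"Loss": 0.9}, {"Loss": 0.5}, {"Loss": 0.7}]
best_epoch, best_loss = track_best(valid_logs, saved.append)
print(best_epoch, best_loss)  # 2 0.5
```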

Related Pages

Implements Principle

Requires Environment
