Implementation: ARISE Initiative Robomimic TrainUtils run_epoch
| Knowledge Sources | |
|---|---|
| Domains | Robotics, Training, Optimization |
| Last Updated | 2026-02-15 08:00 GMT |
Overview
Concrete tool for running a single epoch of training or validation over batched demonstration data provided by the robomimic training utilities module.
Description
The run_epoch function iterates over a DataLoader for a specified number of steps (or the full dataset), calling the algorithm's processing and training methods on each batch. It handles data iterator reset at epoch boundaries, collects per-step metrics, and returns averaged statistics. When validate=True, the model is set to eval mode and no gradient updates are performed.
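The loop described above can be sketched in plain Python with stand-in objects. This is a hedged illustration, not robomimic's implementation: `ToyAlgo` and the list-based loader are assumptions that mimic the two Algo methods run_epoch calls (`process_batch_for_training` and `train_on_batch`), and the sketch omits timing metrics, eval-mode switching, and observation normalization.

```python
class ToyAlgo:
    """Stand-in for a robomimic Algo exposing the two methods run_epoch calls."""

    def process_batch_for_training(self, batch):
        # Real Algo implementations reshape and postprocess the batch here.
        return batch

    def train_on_batch(self, batch, epoch, validate=False):
        # Real implementations run forward/backward passes; we return a toy loss.
        return {"Loss": float(sum(batch)) / len(batch)}


def run_epoch_sketch(model, data_loader, epoch, validate=False, num_steps=None):
    """Simplified mirror of the run_epoch control flow described above."""
    if num_steps is None:
        num_steps = len(data_loader)
    data_iter = iter(data_loader)
    step_logs = []
    for _ in range(num_steps):
        try:
            batch = next(data_iter)
        except StopIteration:
            # Epoch boundary: reset the iterator and continue serving batches.
            data_iter = iter(data_loader)
            batch = next(data_iter)
        batch = model.process_batch_for_training(batch)
        step_logs.append(model.train_on_batch(batch, epoch, validate=validate))
    # Average every logged scalar across all steps, as run_epoch does.
    keys = step_logs[0].keys()
    return {k: sum(log[k] for log in step_logs) / len(step_logs) for k in keys}


avg = run_epoch_sketch(ToyAlgo(), [[1.0, 2.0], [3.0, 5.0]], epoch=1)
```

Here `avg` holds the per-key mean over both batches, matching the "averaged statistics" return contract.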
Usage
Call this function in the per-epoch training loop, once for training and optionally once for validation with validate=True. Requires a fully instantiated Algo and a DataLoader wrapping a SequenceDataset.
Code Reference
Source Location
- Repository: robomimic
- File: robomimic/utils/train_utils.py
- Lines: L637-721
Signature
def run_epoch(model, data_loader, epoch, validate=False, num_steps=None, obs_normalization_stats=None):
"""
Run an epoch of training or validation.
Args:
model (Algo instance): model to train
data_loader (DataLoader instance): data loader that will be used to serve batches of data
epoch (int): epoch number
validate (bool): whether this is a training epoch or validation epoch
num_steps (int): if provided, this epoch lasts for a fixed number of batches
obs_normalization_stats (dict or None): if provided, maps observation keys to dicts
with a "mean" and "std" of shape (1, ...)
Returns:
step_log_all (dict): dictionary of logged training metrics averaged across all batches
"""
Import
import robomimic.utils.train_utils as TrainUtils
# Call as:
train_log = TrainUtils.run_epoch(model=model, data_loader=train_loader, epoch=epoch)
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| model | Algo | Yes | Algorithm instance to train or validate |
| data_loader | DataLoader | Yes | PyTorch DataLoader wrapping a SequenceDataset |
| epoch | int | Yes | Current epoch number (passed to model.train_on_batch) |
| validate | bool | No | If True, runs validation (no gradient updates). Default: False |
| num_steps | int | No | Fixed number of batches per epoch. Default: len(data_loader) |
| obs_normalization_stats | dict | No | Observation normalization statistics |
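The `obs_normalization_stats` dict follows the shape stated in the docstring: each observation key maps to a `"mean"` and `"std"` of shape `(1, ...)`. The sketch below shows the expected structure and a standard `(obs - mean) / std` application; the key name `robot0_eef_pos` and the `normalize_obs` helper are illustrative assumptions, not robomimic APIs.

```python
import numpy as np

# Stats keyed by observation name, with broadcastable (1, ...) mean and std,
# as described in the run_epoch docstring. Values here are made up.
obs_normalization_stats = {
    "robot0_eef_pos": {
        "mean": np.zeros((1, 3)),
        "std": np.full((1, 3), 2.0),
    },
}


def normalize_obs(obs, stats):
    """Apply per-key (obs - mean) / std normalization (illustrative helper)."""
    return {k: (v - stats[k]["mean"]) / stats[k]["std"] for k, v in obs.items()}


obs = {"robot0_eef_pos": np.array([[4.0, -2.0, 6.0]])}
normed = normalize_obs(obs, obs_normalization_stats)
```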
Outputs
| Name | Type | Description |
|---|---|---|
| step_log_all | dict | Averaged training metrics across all batches; includes Time_Data_Loading, Time_Process_Batch, Time_Train_Batch, Time_Log_Info (in minutes), Time_Epoch (total epoch time in minutes), plus algorithm-specific loss metrics |
Usage Examples
Training and Validation Epochs
import torch
import robomimic.utils.train_utils as TrainUtils
for epoch in range(1, num_epochs + 1):
# Training epoch
train_log = TrainUtils.run_epoch(
model=model,
data_loader=train_loader,
epoch=epoch,
validate=False,
num_steps=config.experiment.epoch_every_n_steps,
obs_normalization_stats=obs_normalization_stats,
)
print(f"Train Loss: {train_log.get('Loss', 'N/A')}")
# Validation epoch
if valid_loader is not None:
with torch.no_grad():
valid_log = TrainUtils.run_epoch(
model=model,
data_loader=valid_loader,
epoch=epoch,
validate=True,
num_steps=config.experiment.validation_epoch_every_n_steps,
obs_normalization_stats=obs_normalization_stats,
)
print(f"Valid Loss: {valid_log.get('Loss', 'N/A')}")