

Implementation:Facebookresearch Habitat lab BaseTrainer

Domains Embodied_AI, Reinforcement_Learning, Training_Infrastructure
Last Updated 2026-02-15 00:00 GMT

Overview

BaseTrainer and BaseRLTrainer are abstract base classes that define the training and evaluation interfaces for all Habitat baselines trainers, providing checkpoint management, preemption handling, and progress tracking.

Description

The module defines two classes forming a trainer hierarchy:

BaseTrainer is the most generic trainer template, providing:

  • train(): Abstract method for the training loop (must be overridden).
  • eval(): Complete evaluation pipeline. It evaluates either a single checkpoint or every checkpoint in a directory, polling for new checkpoints with a 2-second sleep interval. It supports resuming after preemption by saving and loading resume state with save_resume_state()/load_resume_state(), and configures video output to tensorboard or disk.
  • _eval_checkpoint(): Abstract method for evaluating a single checkpoint.
  • save_checkpoint() / load_checkpoint(): Abstract methods for checkpoint persistence.
  • _get_resume_state_config_or_new_config(): Handles configuration reconciliation when resuming training, optionally using the original training config.
  • _add_preemption_signal_handlers(): Sets up SLURM signal handlers for graceful preemption.
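The checkpoint-polling behaviour of eval() can be sketched in plain Python. This is a simplified illustration, not the habitat_baselines source: the helper names next_checkpoint and eval_all_checkpoints are hypothetical, checkpoints are assumed to be ordered by modification time, resume_index stands in for the state that load_resume_state() would restore, and the loop stops after num_checkpoints so it can terminate (the real eval keeps waiting for new checkpoints).

```python
import glob
import os
import time
from typing import Callable, List, Optional


def next_checkpoint(ckpt_dir: str, prev_index: int) -> Optional[str]:
    """Return the checkpoint after `prev_index`, or None if it is not written yet.

    Files are ordered by modification time (ties broken by name).
    """
    paths = [p for p in glob.glob(os.path.join(ckpt_dir, "*")) if os.path.isfile(p)]
    paths.sort(key=lambda p: (os.path.getmtime(p), p))
    index = prev_index + 1
    return paths[index] if index < len(paths) else None


def eval_all_checkpoints(
    ckpt_dir: str,
    evaluate: Callable[[str, int], None],
    num_checkpoints: int,
    resume_index: int = -1,
    poll_secs: float = 2.0,
) -> List[str]:
    """Evaluate checkpoints in order until `num_checkpoints` have been seen.

    `resume_index` stands in for loaded resume state: the index of the last
    checkpoint fully evaluated before the job was preempted (-1 = none).
    """
    evaluated: List[str] = []
    prev_index = resume_index
    while prev_index + 1 < num_checkpoints:
        ckpt = next_checkpoint(ckpt_dir, prev_index)
        if ckpt is None:
            time.sleep(poll_secs)  # training has not written it yet
            continue
        prev_index += 1
        evaluate(ckpt, prev_index)
        evaluated.append(ckpt)
        # A real trainer would persist prev_index here (resume state) so a
        # preempted evaluation job continues instead of starting over.
    return evaluated
```

Passing resume_index=0 skips the first checkpoint, which is how saved resume state lets a restarted evaluation job pick up where it left off.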

BaseRLTrainer extends BaseTrainer for reinforcement learning with:

  • Progress tracking via num_updates_done and num_steps_done.
  • percent_done(): Computes training progress as a fraction (nominally 0.0 to 1.0) of whichever limit is configured, num_updates or total_num_steps.
  • is_done(): Returns True once training is complete, i.e. once percent_done() reaches 1.0.
  • should_checkpoint(): Determines whether a checkpoint should be saved, based on either num_checkpoints (checkpoints evenly spaced over training) or checkpoint_interval (a checkpoint every fixed number of updates).
  • _should_save_resume_state(): Checks whether preemption resume state should be saved based on SLURM batch job status and configurable intervals.
  • Constructor validation: exactly one of num_updates/total_num_steps and exactly one of num_checkpoints/checkpoint_interval must be set (the unused option is set to -1); otherwise the constructor raises an error.

Usage

These classes are never instantiated directly. Concrete trainers like PPOTrainer or VERTrainer inherit from BaseRLTrainer and implement the abstract methods. The baseline_registry is used to look up and instantiate the appropriate trainer class based on configuration.
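In habitat_baselines, trainers are registered with the @baseline_registry.register_trainer(name=...) decorator and resolved with baseline_registry.get_trainer(...). The toy registry below only mimics that lookup pattern so it runs without Habitat installed; TrainerRegistry, PPOLikeTrainer and the trainer_name key are hypothetical.

```python
from typing import Callable, Dict


class TrainerRegistry:
    """Toy name -> trainer-class registry illustrating the lookup pattern;
    not habitat_baselines' own baseline_registry."""

    def __init__(self) -> None:
        self._trainers: Dict[str, type] = {}

    def register_trainer(self, name: str) -> Callable[[type], type]:
        # Returns a class decorator that records `cls` under `name`.
        def wrap(cls: type) -> type:
            if name in self._trainers:
                raise KeyError(f"trainer {name!r} already registered")
            self._trainers[name] = cls
            return cls

        return wrap

    def get_trainer(self, name: str) -> type:
        return self._trainers[name]


registry = TrainerRegistry()


@registry.register_trainer(name="ppo")
class PPOLikeTrainer:
    def __init__(self, config: dict) -> None:
        self.config = config


# The config names the trainer; the registry resolves it to a class.
config = {"trainer_name": "ppo"}
trainer = registry.get_trainer(config["trainer_name"])(config)
```

Keeping the mapping in a registry means the launcher never imports concrete trainer classes by name; it only needs the string from the config.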

Code Reference

Source Location

Signature

class BaseTrainer:
    config: "DictConfig"
    flush_secs: float
    supported_tasks: ClassVar[List[str]]

    def train(self) -> None: ...
    def eval(self) -> None: ...
    def _eval_checkpoint(
        self,
        checkpoint_path: str,
        writer: TensorboardWriter,
        checkpoint_index: int = 0,
    ) -> None: ...
    def save_checkpoint(self, file_name) -> None: ...
    def load_checkpoint(self, checkpoint_path, *args, **kwargs) -> Dict: ...


class BaseRLTrainer(BaseTrainer):
    device: torch.device
    config: "DictConfig"
    video_option: List[str]
    num_updates_done: int
    num_steps_done: int

    def __init__(self, config: "DictConfig") -> None: ...
    def percent_done(self) -> float: ...
    def is_done(self) -> bool: ...
    def should_checkpoint(self) -> bool: ...

Import

from habitat_baselines.common.base_trainer import BaseTrainer, BaseRLTrainer

I/O Contract

Inputs

  • config (DictConfig, required): Habitat baselines configuration containing the training parameters.

Outputs

  • percent_done() (float): Training progress as a fraction, nominally between 0.0 and 1.0 (it can slightly exceed 1.0 if the last update overshoots total_num_steps).
  • is_done() (bool): Whether training has completed.
  • should_checkpoint() (bool): Whether a checkpoint should be saved at the current step.
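To make these outputs concrete, here is a dependency-free sketch of the bookkeeping described for BaseRLTrainer, using a plain dataclass in place of a DictConfig. ProgressConfig and ProgressTracker are hypothetical names, -1 is assumed to mark an unused option, and the evenly spaced mode counts saved checkpoints so float rounding does not accumulate; it is not claimed to match the library's exact arithmetic.

```python
from dataclasses import dataclass
from typing import List


@dataclass
class ProgressConfig:
    # -1 marks the unused option of each pair.
    num_updates: int = -1
    total_num_steps: int = -1
    num_checkpoints: int = -1
    checkpoint_interval: int = -1


class ProgressTracker:
    """Hypothetical stand-in for BaseRLTrainer's progress bookkeeping."""

    def __init__(self, cfg: ProgressConfig) -> None:
        # Exactly one option of each pair must be set.
        if (cfg.num_updates == -1) == (cfg.total_num_steps == -1):
            raise ValueError("set exactly one of num_updates / total_num_steps")
        if (cfg.num_checkpoints == -1) == (cfg.checkpoint_interval == -1):
            raise ValueError("set exactly one of num_checkpoints / checkpoint_interval")
        self.cfg = cfg
        self.num_updates_done = 0
        self.num_steps_done = 0
        self._checkpoints_saved = 0

    def percent_done(self) -> float:
        if self.cfg.num_updates != -1:
            return self.num_updates_done / self.cfg.num_updates
        return self.num_steps_done / self.cfg.total_num_steps

    def is_done(self) -> bool:
        return self.percent_done() >= 1.0

    def should_checkpoint(self) -> bool:
        if self.cfg.num_checkpoints != -1:
            # Evenly spaced: the k-th checkpoint is due at k / num_checkpoints.
            due_at = (self._checkpoints_saved + 1) / self.cfg.num_checkpoints
            if self.percent_done() >= due_at:
                self._checkpoints_saved += 1
                return True
            return False
        # Fixed interval: every `checkpoint_interval` updates.
        return self.num_updates_done % self.cfg.checkpoint_interval == 0


# Usage: 100 updates, 4 evenly spaced checkpoints.
tracker = ProgressTracker(ProgressConfig(num_updates=100, num_checkpoints=4))
ckpt_updates: List[int] = []
while not tracker.is_done():
    tracker.num_updates_done += 1  # ... one training update ...
    if tracker.should_checkpoint():
        ckpt_updates.append(tracker.num_updates_done)
```

With num_updates=100 and num_checkpoints=4, the loop checkpoints after updates 25, 50, 75 and 100.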

Usage Examples

Basic Usage

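from typing import Dict

from habitat_baselines.common.base_trainer import BaseRLTrainer
from habitat_baselines.common.baseline_registry import baseline_registry
from habitat_baselines.common.tensorboard_utils import TensorboardWriter

# Hypothetical value: environment steps collected per update. In a real
# trainer this comes from the config (number of environments x rollout length).
STEPS_PER_UPDATE = 128


# Register the trainer so it can be looked up by name ("my_custom" is hypothetical).
@baseline_registry.register_trainer(name="my_custom")
class MyCustomTrainer(BaseRLTrainer):
    supported_tasks = ["Nav-v0"]

    def __init__(self, config):
        # BaseRLTrainer validates num_updates/total_num_steps and
        # num_checkpoints/checkpoint_interval here.
        super().__init__(config)

    def train(self) -> None:
        # Custom training loop; runs until percent_done() reaches 1.0
        while not self.is_done():
            # ... collect rollouts and perform one update ...
            self.num_updates_done += 1
            self.num_steps_done += STEPS_PER_UPDATE

            if self.should_checkpoint():
                self.save_checkpoint(f"ckpt.{self.num_updates_done}.pth")

    def _eval_checkpoint(
        self,
        checkpoint_path: str,
        writer: TensorboardWriter,
        checkpoint_index: int = 0,
    ) -> None:
        state_dict = self.load_checkpoint(checkpoint_path)
        # ... restore the model from state_dict and run evaluation episodes,
        # logging metrics to `writer` at step `checkpoint_index` ...

    def save_checkpoint(self, file_name) -> None:
        # ... save model state ...
        pass

    def load_checkpoint(self, checkpoint_path, *args, **kwargs) -> Dict:
        # ... load model state ...
        return {}


# Using the trainer. `config` is assumed to be a habitat-baselines DictConfig
# loaded elsewhere (for example with habitat_baselines.config.default.get_config).
trainer = MyCustomTrainer(config)
trainer.train()  # Training loop
trainer.eval()   # Evaluation loop over checkpoints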
