
Implementation:NVIDIA NeMo Aligner SupervisedTrainer Fit

From Leeroopedia


Implementation Metadata
Name SupervisedTrainer_Fit
Type API Doc
Implements Principle Supervised_Training_Loop
Repository NeMo Aligner
File nemo_aligner/algorithms/supervised.py
Lines L34-322
Domains Deep_Learning, Training
Last Updated 2026-02-07 00:00 GMT

Overview

A concrete tool, provided by the NeMo Aligner algorithms module, for executing supervised training loops for SFT and reward model training.

Description

The SupervisedTrainer class implements a generic supervised training loop that iterates over epochs and batches, computes loss via the model's get_loss_and_metrics interface, runs validation at configurable intervals, saves checkpoints, and logs metrics. It handles distributed training concerns (batch samplers, gradient accumulation) and supports both single and multi-epoch training. The fit() method is the main entry point that orchestrates the full training run.
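The loop structure described above can be sketched as follows. This is a hedged, torch-free stand-in for illustration only, not the NeMo Aligner source: `fit_sketch`, `validate`, `save_checkpoint`, and `log` are hypothetical names, and the real trainer additionally handles distributed samplers, gradient accumulation, and timers.

```python
# Illustrative sketch of the fit() control flow: iterate epochs and batches,
# compute loss via the model's get_loss_and_metrics, run validation and save
# checkpoints at configured step intervals. All helper names are stand-ins.

def fit_sketch(model, train_batches, cfg, validate, save_checkpoint, log):
    step = 0
    for epoch in range(cfg["max_epochs"]):
        for batch in train_batches:
            loss, metrics = model.get_loss_and_metrics(batch)
            step += 1
            log({"train/loss": loss, **metrics})
            if step % cfg["val_check_interval"] == 0:
                val_loss, val_metrics = validate()
                log({"val/loss": val_loss, **val_metrics})
            if step % cfg["save_interval"] == 0:
                save_checkpoint(step)
```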

Usage

Import this class when running SFT or reward model training. The model must implement SupervisedInterface (get_loss_and_metrics, prepare_for_training_step, finish_training_step). Used in train_gpt_sft.py and train_reward_model.py.
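The three interface methods can be illustrated with a toy model. This is a minimal sketch of the call shape only; the real SupervisedInterface in NeMo Aligner wraps Megatron models, and the class and its internals here are invented for illustration.

```python
# Hypothetical toy model exposing the three SupervisedInterface-style methods.
# A manual gradient step on a 1-D regression stands in for the real
# optimizer/scheduler machinery.

class ToyRegressionModel:
    def __init__(self, weight=0.0, lr=0.1):
        self.weight = weight
        self.lr = lr
        self._training = False

    def prepare_for_training_step(self):
        # real implementation: zero grads, switch to training mode, etc.
        self._training = True

    def get_loss_and_metrics(self, batch):
        # batch: list of (x, y) pairs; mean squared error plus one update step
        loss = grad = 0.0
        for x, y in batch:
            err = self.weight * x - y
            loss += err * err
            grad += 2 * err * x
        n = len(batch)
        loss, grad = loss / n, grad / n
        self.weight -= self.lr * grad
        return loss, {"lr": self.lr}

    def finish_training_step(self):
        # real implementation: finalize grads, step the scheduler, etc.
        self._training = False
```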

Code Reference

Source Location

  • Repository: NeMo Aligner
  • File: nemo_aligner/algorithms/supervised.py
  • Lines: L34-322

Signature

class SupervisedTrainer:
    def __init__(
        self,
        cfg: DictConfig,
        model,              # must implement SupervisedInterface
        optimizer,
        scheduler,
        train_dataloader,
        val_dataloader,
        test_dataloader,
        logger,
        ckpt_callback,
        run_timer,
        run_init_validation=False,
    ):
        ...

    def fit(self) -> None:
        """Main training loop: iterates epochs, trains, validates, saves."""

    def run_validation(self) -> Tuple[float, Dict]:
        """Run validation across all validation dataloaders."""

Import

from nemo_aligner.algorithms.supervised import SupervisedTrainer

I/O Contract

Inputs

Name Type Required Description
cfg DictConfig Yes Training config with max_epochs, val_check_interval, save_interval, limit_val_batches
model SupervisedInterface Yes Model implementing get_loss_and_metrics
optimizer Optimizer Yes PyTorch optimizer (extracted from the PTL model)
scheduler LRScheduler Yes Learning rate scheduler
train_dataloader DataLoader Yes Training data loader with MegatronPretrainingBatchSampler
val_dataloader DataLoader Yes Validation data loader(s)
test_dataloader DataLoader Yes Test data loader
logger Logger Yes Metric logging (WandB, TensorBoard)
ckpt_callback NeMoModelCheckpoint Yes Checkpoint save callback
run_timer Timer Yes Timer bounding the wall-clock duration of the run
run_init_validation bool No Run validation once before training starts (default False)
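The cfg fields named above can be sketched as a plain dict. In practice cfg is an OmegaConf DictConfig loaded from the training YAML; the values here are illustrative placeholders, not NeMo Aligner defaults.

```python
# Hedged sketch of the trainer config fields listed in the Inputs table,
# using a plain dict in place of a DictConfig. Values are illustrative only.
trainer_cfg = {
    "max_epochs": 1,            # number of passes over the training set
    "val_check_interval": 100,  # run validation every N training steps
    "save_interval": 100,       # save a checkpoint every N training steps
    "limit_val_batches": 50,    # cap the number of batches per validation pass
}
```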

Outputs

Name Type Description
(side effect) None Trained model weights (in-place), saved checkpoints, logged metrics

Usage Examples

Running SFT Training

from nemo_aligner.algorithms.supervised import SupervisedTrainer

trainer = SupervisedTrainer(
    cfg=cfg.trainer.sft,
    model=model,
    optimizer=optimizer,
    scheduler=scheduler,
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    test_dataloader=test_dataloader,
    logger=logger,
    ckpt_callback=ckpt_callback,
    run_timer=timer,
)
trainer.fit()

Related Pages

Knowledge Sources

Deep_Learning | Training

