

Implementation:NVIDIA NeMo Aligner RSTrainer Fit



Knowledge Sources
Domains: NLP, Alignment
Last Updated: 2026-02-08 00:00 GMT

Overview

RSTrainer is the trainer class that coordinates the full Rejection Sampling training loop, including rollout generation, reward-based selection, model training, validation, checkpointing, and logging.

Description

The RSTrainer class orchestrates the Rejection Sampling alignment workflow. It manages the alternating cycle of inference (generation + reward scoring) and optimization (SFT on selected responses). Key responsibilities include:

  • Rollout generation: For each batch of prompts, the trainer calls model.infer() multiple times (controlled by num_rollouts_per_prompt), scores each response using the reward model (rm.infer_rm()), and selects the top-K responses via select_topk().
  • Data preparation: The generate_rs_data() method constructs training-ready tensors by creating response masks (which exclude prompt tokens from the loss) and padding all tensors to a uniform global sequence length across data-parallel workers.
  • Training: The run_training() method iterates over micro-batches of selected responses, computing the SFT loss and performing gradient updates with gradient clipping.
  • Validation: The run_validation() method generates a single response per prompt (no top-K selection) and computes reward metrics.
  • Checkpointing and state: The trainer maintains step, consumed_samples, and rs_optimization_step counters, and supports saving/loading via state_dict() and load_state_dict().
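The data-preparation step can be illustrated with a minimal, pure-Python sketch. It assumes each sample is a list of token IDs whose first prompt_length tokens belong to the prompt, and it pads to a caller-supplied global length with a hypothetical pad_id. The function name build_rs_batch and its arguments are illustrative only; this is not NeMo-Aligner's generate_rs_data(). In the real trainer the global length must agree across data-parallel workers (for example, the maximum of the per-worker maxima); here it is simply passed in.

```python
from typing import List, Tuple


def build_rs_batch(
    samples: List[List[int]],
    prompt_lengths: List[int],
    global_seq_len: int,
    pad_id: int = 0,
) -> Tuple[List[List[int]], List[List[int]]]:
    """Pad token sequences and build loss masks that ignore prompt tokens.

    Hypothetical helper for illustration: a mask value of 1 means the token
    contributes to the SFT loss, 0 means it is ignored.
    """
    tokens_out, masks_out = [], []
    for tokens, prompt_len in zip(samples, prompt_lengths):
        if len(tokens) > global_seq_len:
            raise ValueError("sequence longer than global_seq_len")
        if not 0 <= prompt_len <= len(tokens):
            raise ValueError("prompt length out of range")
        pad = global_seq_len - len(tokens)
        # Prompt tokens are masked out; response tokens are kept.
        mask = [0] * prompt_len + [1] * (len(tokens) - prompt_len)
        tokens_out.append(tokens + [pad_id] * pad)
        # Padding positions never contribute to the loss.
        masks_out.append(mask + [0] * pad)
    return tokens_out, masks_out


if __name__ == "__main__":
    samples = [[5, 6, 7, 8], [5, 9, 10]]
    # Single-process stand-in for the cross-worker maximum length.
    tokens, masks = build_rs_batch(samples, [2, 1], max(map(len, samples)))
    print(tokens)  # [[5, 6, 7, 8], [5, 9, 10, 0]]
    print(masks)   # [[0, 0, 1, 1], [0, 1, 1, 0]]
```

Padding every worker to the same global length keeps tensor shapes identical across data-parallel ranks, which collective operations generally require.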

The fit() method is the main entry point. It iterates over epochs and global steps, calling generate_rollouts() and run_training() at each step, with periodic validation and checkpointing controlled by val_check_interval and save_interval.
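The alternating cycle can be sketched as a plain Python loop. This is a simplified sketch, not NeMo-Aligner's fit(): epochs, the dataloader, the run timer and logging are omitted, generate_rollouts and run_training are passed in as callables, and the rule "validate every val_check_interval steps, save every save_interval steps, and do both at the final step" is an assumption about the schedule.

```python
from typing import Callable, List


def run_rs_schedule(
    max_steps: int,
    val_check_interval: int,
    save_interval: int,
    generate_rollouts: Callable[[int], list],
    run_training: Callable[[list], None],
) -> List[str]:
    """Simplified RS outer loop that records which periodic actions ran."""
    events = []
    for step in range(1, max_steps + 1):
        rollouts = generate_rollouts(step)  # inference: generate, score, select
        run_training(rollouts)              # optimization: SFT on selected responses
        is_last = step == max_steps
        # Non-positive intervals disable the action in this sketch.
        if val_check_interval > 0 and (step % val_check_interval == 0 or is_last):
            events.append(f"val@{step}")
        if save_interval > 0 and (step % save_interval == 0 or is_last):
            events.append(f"save@{step}")
    return events


if __name__ == "__main__":
    print(run_rs_schedule(5, 2, 3, lambda s: [s], lambda batch: None))
    # ['val@2', 'save@3', 'val@4', 'val@5', 'save@5']
```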

Usage

Import and instantiate RSTrainer when setting up a Rejection Sampling training run. It is typically created in the train_gpt_rs_actor.py entry-point script after the model, optimizer, scheduler, dataloaders, reward model client, and logger have been initialized.

Code Reference

Source Location

Signature

class RSTrainer:
    """Trainer to coordinate RS training
    """

    def __init__(
        self,
        cfg: DictConfig,
        model,
        optimizer,
        scheduler,
        train_dataloader,
        val_dataloader,
        logger,
        ckpt_callback,
        run_timer,
        num_rollouts_per_prompt,
        top_n_rollouts,
        rm,
    ):
        ...

Key Method: fit()

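def fit(self):
    # Iterate from the current epoch (non-zero when resuming); exit if max_epochs is already reached.
    epoch_iter = range(self.epoch, self.cfg.max_epochs)
    if len(epoch_iter) <= 0:
        return
    for _ in epoch_iter:
        # Steps left in this epoch, capped by the global max_steps budget.
        num_steps_in_epoch = min(
            self.max_steps - self.step, self.num_steps_per_epoch - self.step % self.num_steps_per_epoch
        )
        loop_iter = range(num_steps_in_epoch)
        # No steps left (max_steps reached): stop training.
        if not loop_iter:
            return
        # Fresh iterator over prompt batches for this epoch.
        dataloader_iter = iter(self.train_dataloader)
        # Progress bar over global RS steps, starting at the current step.
        global_pbar = tqdm(loop_iter, initial=self.step, total=self.max_steps, leave=True, desc="RS Global Step")
        ...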
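The num_steps_in_epoch expression lets fit() resume mid-epoch and stop exactly at max_steps. The helper below restates that single expression as a standalone function (the name steps_remaining_in_epoch is hypothetical) so it can be checked with concrete numbers; max_steps=250 and num_steps_per_epoch=100 are arbitrary example values.

```python
def steps_remaining_in_epoch(step: int, max_steps: int, num_steps_per_epoch: int) -> int:
    """Restatement of the num_steps_in_epoch expression used in fit()."""
    return min(max_steps - step, num_steps_per_epoch - step % num_steps_per_epoch)


if __name__ == "__main__":
    print(steps_remaining_in_epoch(0, 250, 100))    # 100: fresh start, full epoch
    print(steps_remaining_in_epoch(130, 250, 100))  # 70: resumed 30 steps into an epoch
    print(steps_remaining_in_epoch(200, 250, 100))  # 50: last epoch truncated by max_steps
    print(steps_remaining_in_epoch(250, 250, 100))  # 0: empty range, fit() returns
```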

Import

from nemo_aligner.algorithms.rs import RSTrainer

I/O Contract

Inputs

Name | Type | Required | Description
cfg | DictConfig | Yes | Trainer configuration (from cfg.trainer.rs), including max_epochs, model_gbs, gradient_clip_val, val_check_interval, save_interval and rollout_batch_seq_length
model | MegatronGPTRSModel | Yes | The actor model implementing AlignableGenerativeInterface, with infer(), get_loss_and_metrics() and training lifecycle methods
optimizer | Optimizer | Yes | PyTorch optimizer extracted from the PyTorch Lightning (PTL) model
scheduler | LRScheduler | Yes | Learning rate scheduler extracted from the PTL model
train_dataloader | DataLoader | Yes | Training dataloader providing prompt batches with the keys text and length
val_dataloader | DataLoader | Yes | Validation dataloader providing prompt batches
logger | CustomLoggerWrapper | Yes | Logger for metrics and table logging (e.g., WandB)
ckpt_callback | Callback | Yes | Custom checkpoint callback for saving model states
run_timer | Timer | Yes | Timer that enforces the maximum training time
num_rollouts_per_prompt | int | Yes | Number of candidate responses to generate per prompt
top_n_rollouts | int | Yes | Number of top-scoring responses per prompt to keep for training
rm | RemoteGPTRMClient | Yes | Remote reward model client for scoring generated responses
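num_rollouts_per_prompt and top_n_rollouts work as a pair: of the N candidates generated for each prompt, the K with the highest reward are kept, so in this sketch a batch of B prompts yields B × K training samples out of B × N scored rollouts. The function below is a hypothetical, pure-Python illustration of that grouping and ranking. It is not NeMo-Aligner's select_topk(), whose signature is not shown on this page. It identifies each rollout's prompt by an integer ID and breaks reward ties in favor of the earlier rollout.

```python
from collections import defaultdict
from typing import Dict, List, Sequence


def select_top_n(prompt_ids: Sequence[int], rewards: Sequence[float], top_n: int) -> List[int]:
    """Return indices of the top_n highest-reward rollouts for each prompt (hypothetical)."""
    by_prompt: Dict[int, List[int]] = defaultdict(list)
    for idx, pid in enumerate(prompt_ids):
        by_prompt[pid].append(idx)
    selected: List[int] = []
    for pid in sorted(by_prompt):
        # Highest reward first; Python's sort is stable, so ties keep the earlier rollout.
        ranked = sorted(by_prompt[pid], key=lambda i: rewards[i], reverse=True)
        selected.extend(ranked[:top_n])
    return sorted(selected)


if __name__ == "__main__":
    # Two prompts, three rollouts each (num_rollouts_per_prompt=3).
    prompt_ids = [0, 0, 0, 1, 1, 1]
    rewards = [0.1, 0.9, 0.4, -0.2, 0.3, 0.3]
    print(select_top_n(prompt_ids, rewards, 1))  # [1, 4]
    print(select_top_n(prompt_ids, rewards, 2))  # [1, 2, 4, 5]
```

Setting top_n_rollouts equal to num_rollouts_per_prompt keeps every rollout; smaller values trade training data volume for higher average reward in the selected set.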

Outputs

Name | Type | Description
None (side effects) | N/A | fit() trains the model in place, logs metrics and saves checkpoints; it does not return a value.

Usage Examples

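from nemo_aligner.algorithms.rs import RSTrainer

# cfg, ptl_model, optimizer, scheduler, train_dataloader, val_dataloader, rm,
# logger, ckpt_callback, timer and custom_trainer_state_dict are assumed to be
# created earlier in the entry-point script (e.g. train_gpt_rs_actor.py).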
rs_trainer = RSTrainer(
    cfg=cfg.trainer.rs,
    model=ptl_model,
    optimizer=optimizer,
    scheduler=scheduler,
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    rm=rm,
    logger=logger,
    ckpt_callback=ckpt_callback,
    run_timer=timer,
    num_rollouts_per_prompt=cfg.model.rs.num_rollouts_per_prompt,
    top_n_rollouts=cfg.model.rs.top_n_rollouts,
)

# Restore the trainer's counters when resuming from a checkpoint.
if custom_trainer_state_dict is not None:
    rs_trainer.load_state_dict(custom_trainer_state_dict)

rs_trainer.fit()
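Resuming works because the trainer's progress counters are part of its checkpointed state. The dataclass below is a hypothetical, minimal illustration of the state_dict()/load_state_dict() round trip using the three counters named in the Description (step, consumed_samples, rs_optimization_step); the real RSTrainer's dictionary keys and any additional fields may differ.

```python
from dataclasses import asdict, dataclass
from typing import Dict


@dataclass
class RSCounters:
    """Hypothetical stand-in for the progress counters RSTrainer checkpoints."""

    step: int = 0
    consumed_samples: int = 0
    rs_optimization_step: int = 0

    def state_dict(self) -> Dict[str, int]:
        return asdict(self)

    def load_state_dict(self, state: Dict[str, int]) -> None:
        # Restore exactly these three counters; a missing key raises KeyError.
        for key in ("step", "consumed_samples", "rs_optimization_step"):
            setattr(self, key, int(state[key]))


if __name__ == "__main__":
    before = RSCounters(step=3, consumed_samples=192, rs_optimization_step=6)
    saved = before.state_dict()
    after = RSCounters()
    after.load_state_dict(saved)
    print(after == before)  # True
```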
