Implementation:NVIDIA NeMo Aligner RSTrainer Fit
| Knowledge Sources | |
|---|---|
| Domains | NLP, Alignment |
| Last Updated | 2026-02-08 00:00 GMT |
Overview
RSTrainer is the trainer class that coordinates the full Rejection Sampling training loop, including rollout generation, reward-based selection, model training, validation, checkpointing, and logging.
Description
The RSTrainer class orchestrates the Rejection Sampling alignment workflow. It manages the alternating cycle of inference (generation + reward scoring) and optimization (SFT on selected responses). Key responsibilities include:
- Rollout generation: For each batch of prompts, the trainer calls model.infer() multiple times (controlled by num_rollouts_per_prompt), scores each response using the reward model (rm.infer_rm()), and selects the top-K responses (K = top_n_rollouts) via select_topk().
- Data preparation: The generate_rs_data() method constructs training-ready tensors by creating response masks (masking out prompt tokens from the loss) and padding all tensors to a uniform global sequence length across data-parallel workers.
- Training: The run_training() method iterates over micro-batches of selected responses, computing the SFT loss and performing gradient updates with gradient clipping.
- Validation: The run_validation() method generates a single response per prompt (no top-K selection) and computes reward metrics.
- Checkpointing and state: The trainer maintains step, consumed_samples, and rs_optimization_step counters and supports saving/loading via state_dict() and load_state_dict().
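As a rough illustration of the rollout-selection and data-preparation steps, the pure-Python sketch below keeps the top_n_rollouts highest-reward responses for each prompt, then builds a loss mask that zeroes out prompt tokens and right-pads every sequence to a common length. The helper names (select_topk_per_prompt, build_loss_mask, pad_to_length) and the pad value 0 are hypothetical and do not mirror the actual select_topk() or generate_rs_data() signatures. In NeMo Aligner these operations run on tensors and the global length is agreed across data-parallel workers; here that is reduced to a maximum over a local list.

```python
from typing import List, Sequence


def select_topk_per_prompt(
    rewards: Sequence[Sequence[float]], top_n_rollouts: int
) -> List[List[int]]:
    """Hypothetical helper: for each prompt, return the indices of its
    top_n_rollouts highest-reward candidate responses (best first)."""
    selected = []
    for prompt_rewards in rewards:
        if top_n_rollouts > len(prompt_rewards):
            raise ValueError("top_n_rollouts cannot exceed num_rollouts_per_prompt")
        ranked = sorted(range(len(prompt_rewards)), key=lambda i: prompt_rewards[i], reverse=True)
        selected.append(ranked[:top_n_rollouts])
    return selected


def build_loss_mask(prompt_len: int, total_len: int) -> List[int]:
    """Hypothetical helper: 0 for prompt tokens (excluded from the SFT loss), 1 for response tokens."""
    return [0] * prompt_len + [1] * (total_len - prompt_len)


def pad_to_length(seq: List[int], length: int, pad_value: int = 0) -> List[int]:
    """Right-pad a sequence to `length` (assumed pad value 0, so padding is also masked out)."""
    return seq + [pad_value] * (length - len(seq))


# Toy data: 2 prompts, num_rollouts_per_prompt=3 candidates each, top_n_rollouts=1.
rewards = [[0.1, 0.9, 0.4], [0.7, 0.2, 0.5]]
print(select_topk_per_prompt(rewards, top_n_rollouts=1))  # [[1], [0]]

# Token ids for two selected (prompt + response) sequences and their prompt lengths.
sequences = [[5, 6, 7, 8], [5, 6, 9, 10, 11, 12]]
prompt_lens = [2, 2]
global_len = max(len(s) for s in sequences)  # stand-in for the cross-worker maximum
tokens = [pad_to_length(s, global_len) for s in sequences]
masks = [pad_to_length(build_loss_mask(p, len(s)), global_len) for s, p in zip(sequences, prompt_lens)]
print(masks)  # [[0, 0, 1, 1, 0, 0], [0, 0, 1, 1, 1, 1]]
```

Sorting indices rather than values keeps the link back to the original candidate, which is what later steps need in order to fetch the selected response tokens.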
The fit() method is the main entry point. It iterates over epochs and global steps, calling generate_rollouts() and run_training() at each step, with periodic validation and checkpointing controlled by val_check_interval and save_interval.
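The step-level control flow of fit() can be simulated without a model. In the minimal sketch below, each iteration stands in for one generate_rollouts() plus run_training() pair. The function plan_fit_steps is hypothetical, and the trigger rule (validate or save when the step count after the update is a multiple of val_check_interval or save_interval, and always at the final step) is an assumption; the exact conditions in rs.py may differ.

```python
from typing import Dict, List


def plan_fit_steps(
    start_step: int, max_steps: int, val_check_interval: int, save_interval: int
) -> List[Dict[str, object]]:
    """Hypothetical simulation of the fit() step loop.

    Assumption: validation/saving run when the step counter after an update is a
    multiple of the corresponding interval, or when training ends.
    """
    plan = []
    for step in range(start_step, max_steps):
        # Placeholder for generate_rollouts() followed by run_training().
        completed = step + 1  # step counter after this optimization step
        is_train_end = completed == max_steps
        plan.append(
            {
                "step": completed,
                "run_val": completed % val_check_interval == 0 or is_train_end,
                "save": completed % save_interval == 0 or is_train_end,
            }
        )
    return plan


for entry in plan_fit_steps(start_step=0, max_steps=5, val_check_interval=2, save_interval=4):
    print(entry)
```

Under these assumptions, with max_steps=5, val_check_interval=2, and save_interval=4, validation runs after steps 2, 4, and 5, and checkpoints are saved after steps 4 and 5.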
Usage
Import and instantiate RSTrainer when setting up a Rejection Sampling training run. It is typically created in the train_gpt_rs_actor.py entry point script after the model, optimizer, scheduler, dataloaders, reward model client, and logger have been initialized.
Code Reference
Source Location
- Repository: NVIDIA_NeMo_Aligner
- File: nemo_aligner/algorithms/rs.py
- Lines: 41-478
Signature
```python
from omegaconf import DictConfig


class RSTrainer:
    """Trainer to coordinate RS training"""

    def __init__(
        self,
        cfg: DictConfig,
        model,
        optimizer,
        scheduler,
        train_dataloader,
        val_dataloader,
        logger,
        ckpt_callback,
        run_timer,
        num_rollouts_per_prompt,
        top_n_rollouts,
        rm,
    ):
        ...
```
Key Method: fit()
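```python
def fit(self):
    # Resume from the saved epoch; nothing to do if max_epochs has already been reached.
    epoch_iter = range(self.epoch, self.cfg.max_epochs)
    if len(epoch_iter) <= 0:
        return

    for _ in epoch_iter:
        # Steps left in this epoch (minus those already completed on resume),
        # capped by the steps left in the whole run.
        num_steps_in_epoch = min(
            self.max_steps - self.step, self.num_steps_per_epoch - self.step % self.num_steps_per_epoch
        )
        loop_iter = range(num_steps_in_epoch)

        if not loop_iter:
            return

        dataloader_iter = iter(self.train_dataloader)
        # The progress bar tracks global steps, starting from the resumed step count.
        global_pbar = tqdm(loop_iter, initial=self.step, total=self.max_steps, leave=True, desc="RS Global Step")
        ...
```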
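The num_steps_in_epoch expression caps the current epoch at whichever is smaller: the steps remaining in the whole run (max_steps - step) or the steps remaining in the current epoch, where step % num_steps_per_epoch accounts for steps already completed before a resume. The standalone function below is a hypothetical helper that restates that expression so the arithmetic can be checked in isolation:

```python
def steps_in_current_epoch(step: int, max_steps: int, num_steps_per_epoch: int) -> int:
    """Restates the num_steps_in_epoch expression from fit() as a pure function."""
    remaining_in_run = max_steps - step
    remaining_in_epoch = num_steps_per_epoch - step % num_steps_per_epoch
    return min(remaining_in_run, remaining_in_epoch)


# Fresh start: 10 steps per epoch, 25 steps total -> a full epoch of 10 steps.
print(steps_in_current_epoch(step=0, max_steps=25, num_steps_per_epoch=10))   # 10
# Resumed at step 13 -> finish the 7 remaining steps of the second epoch.
print(steps_in_current_epoch(step=13, max_steps=25, num_steps_per_epoch=10))  # 7
# Final epoch at step 20 -> only 5 steps remain before max_steps.
print(steps_in_current_epoch(step=20, max_steps=25, num_steps_per_epoch=10))  # 5
# Already at max_steps -> 0, so range(0) is empty and fit() returns early.
print(steps_in_current_epoch(step=25, max_steps=25, num_steps_per_epoch=10))  # 0
```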
Import
```python
from nemo_aligner.algorithms.rs import RSTrainer
```
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| cfg | DictConfig | Yes | Trainer configuration (from cfg.trainer.rs), includes max_epochs, model_gbs, gradient_clip_val, val_check_interval, save_interval, rollout_batch_seq_length |
| model | MegatronGPTRSModel | Yes | The actor model implementing AlignableGenerativeInterface, with infer(), get_loss_and_metrics(), and training lifecycle methods |
| optimizer | Optimizer | Yes | PyTorch optimizer extracted from the PTL model |
| scheduler | LRScheduler | Yes | Learning rate scheduler extracted from the PTL model |
| train_dataloader | DataLoader | Yes | Training dataloader providing prompt batches with keys text and length |
| val_dataloader | DataLoader | Yes | Validation dataloader providing prompt batches |
| logger | CustomLoggerWrapper | Yes | Logger for metrics and table logging (e.g., WandB) |
| ckpt_callback | Callback | Yes | Custom checkpoint callback for saving model states |
| run_timer | Timer | Yes | Timer that enforces maximum training time limits |
| num_rollouts_per_prompt | int | Yes | Number of candidate responses to generate per prompt |
| top_n_rollouts | int | Yes | Number of top-scoring responses to select per prompt for training |
| rm | RemoteGPTRMClient | Yes | Remote reward model client for scoring generated responses |
Outputs
| Name | Type | Description |
|---|---|---|
| None (side effects) | N/A | The fit() method trains the model in-place, logs metrics, and saves checkpoints. It does not return a value. |
Usage Examples
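```python
from nemo_aligner.algorithms.rs import RSTrainer

# cfg, ptl_model, optimizer, scheduler, the dataloaders, rm, logger, ckpt_callback, timer,
# and custom_trainer_state_dict are assumed to be set up earlier in the entry-point
# script (e.g., train_gpt_rs_actor.py).
rs_trainer = RSTrainer(
    cfg=cfg.trainer.rs,
    model=ptl_model,
    optimizer=optimizer,
    scheduler=scheduler,
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    rm=rm,
    logger=logger,
    ckpt_callback=ckpt_callback,
    run_timer=timer,
    num_rollouts_per_prompt=cfg.model.rs.num_rollouts_per_prompt,
    top_n_rollouts=cfg.model.rs.top_n_rollouts,
)

# Restore the trainer counters when resuming from a checkpoint.
if custom_trainer_state_dict is not None:
    rs_trainer.load_state_dict(custom_trainer_state_dict)

# Run the full rejection sampling loop.
rs_trainer.fit()
```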
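The resume path above relies on the trainer exporting and restoring its counters (step, consumed_samples, and rs_optimization_step). The toy class below is a hypothetical stand-in, not RSTrainer's actual implementation, and its key names are assumptions; it only illustrates the state_dict() / load_state_dict() round trip that lets fit() continue from where a previous run stopped.

```python
from typing import Dict


class ToyRSTrainerState:
    """Hypothetical stand-in for the counters RSTrainer tracks (not the real class)."""

    def __init__(self) -> None:
        self.step = 0
        self.consumed_samples = 0
        self.rs_optimization_step = 0

    def state_dict(self) -> Dict[str, int]:
        # Assumed key names; the real state_dict() may store different or additional keys.
        return {
            "step": self.step,
            "consumed_samples": self.consumed_samples,
            "rs_optimization_step": self.rs_optimization_step,
        }

    def load_state_dict(self, state_dict: Dict[str, int]) -> None:
        # Overwrite the fresh counters with the saved values.
        self.step = state_dict["step"]
        self.consumed_samples = state_dict["consumed_samples"]
        self.rs_optimization_step = state_dict["rs_optimization_step"]


# Counters from an interrupted run (illustrative numbers only).
before = ToyRSTrainerState()
before.step, before.consumed_samples, before.rs_optimization_step = 3, 192, 6
saved = before.state_dict()

# A fresh trainer restored from the saved counters continues where the first left off.
after = ToyRSTrainerState()
after.load_state_dict(saved)
print(after.state_dict())  # {'step': 3, 'consumed_samples': 192, 'rs_optimization_step': 6}
```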