Implementation:NVIDIA NeMo Aligner Train GPT RS Actor
| Knowledge Sources | |
|---|---|
| Domains | NLP, Alignment |
| Last Updated | 2026-02-08 00:00 GMT |
Overview
train_gpt_rs_actor.py is the entry point script for launching Rejection Sampling (RS) training of a GPT model using NeMo Aligner.
Description
This script wires together all components required for RS training:
- Configuration loading: Uses Hydra (@hydra_runner) with config path "conf" and config name "gpt_rs_actor". Loads and overrides the model config from the pretrained checkpoint.
- Trainer and experiment setup: Creates a PyTorch Lightning trainer via resolve_and_create_trainer(cfg, "rs") and initializes experiment management.
- Model loading: Loads a pretrained MegatronGPTRSModel from a NeMo checkpoint, then optionally initializes PEFT adapters via init_peft().
- Data preparation: Builds RLHF train/validation datasets and dataloaders using build_train_valid_test_rlhf_datasets(). The collate function pads sequences to the maximum generation length.
- Optimizer and scheduler: Extracts the optimizer and scheduler from the PTL model. A dummy dataloader is used to configure NeMo's internal max-steps calculation.
- Reward model client: Instantiates RemoteGPTRMClient to communicate with an external reward model service.
- RSTrainer instantiation: Creates the RSTrainer with all dependencies, optionally restores trainer state from a checkpoint, and calls rs_trainer.fit().
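The padding behaviour mentioned under data preparation can be illustrated with a minimal sketch. This is not the collate function used by the script; pad_batch, its parameters, and the pad token id of 0 are hypothetical names chosen for illustration, and plain Python lists stand in for tensors.

```python
from typing import List


def pad_batch(sequences: List[List[int]], max_length: int, pad_id: int = 0) -> List[List[int]]:
    """Right-pad every token sequence to exactly max_length (hypothetical helper)."""
    padded = []
    for seq in sequences:
        if len(seq) > max_length:
            # A real pipeline might truncate instead; this sketch refuses.
            raise ValueError(f"sequence of length {len(seq)} exceeds max_length={max_length}")
        padded.append(seq + [pad_id] * (max_length - len(seq)))
    return padded


# Two prompts of different lengths end up with the same length after padding.
batch = pad_batch([[5, 6, 7], [8]], max_length=4, pad_id=0)
print(batch)
```

Padding to one fixed length keeps every batch the same shape, which is the property the description above relies on.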
Usage
Run this script via the command line with Hydra configuration overrides to launch RS training. It requires a pretrained NeMo GPT checkpoint and a running remote reward model service.
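To make the idea of dotted configuration overrides concrete, here is a simplified pure-Python stand-in for what Hydra does when it receives key=value arguments. The function apply_overrides is hypothetical and is not part of Hydra or NeMo Aligner; unlike Hydra, it keeps every value as a string and performs no type parsing or validation.

```python
from typing import Any, Dict, List


def apply_overrides(config: Dict[str, Any], overrides: List[str]) -> Dict[str, Any]:
    """Apply 'a.b.c=value' style overrides to a nested dict (illustrative only)."""
    for item in overrides:
        dotted_key, _, raw_value = item.partition("=")
        *parents, leaf = dotted_key.split(".")
        node = config
        for name in parents:
            # Create intermediate sections on demand.
            node = node.setdefault(name, {})
        node[leaf] = raw_value  # values stay strings in this sketch
    return config


cfg = apply_overrides(
    {"remote_rm": {"host": "localhost"}},
    ["remote_rm.port=5555", "model.rs.top_n_rollouts=1"],
)
print(cfg)
```

Each dotted key addresses one leaf of the nested configuration, so an override such as model.rs.top_n_rollouts=1 changes a single setting and leaves the rest of the loaded config untouched.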
Code Reference
Source Location
- Repository: NVIDIA_NeMo_Aligner
- File: examples/nlp/gpt/train_gpt_rs_actor.py
- Lines: 1-171
Signature
@hydra_runner(config_path="conf", config_name="gpt_rs_actor")
def main(cfg) -> None:
Import
from nemo_aligner.algorithms.rs import RSTrainer
from nemo_aligner.models.nlp.gpt.megatron_gpt_rs_actor import MegatronGPTRSModel
from nemo_aligner.models.nlp.gpt.reward_critic_clients import RemoteGPTRMClient
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| cfg | DictConfig | Yes | Hydra configuration object containing pretrained_checkpoint.restore_from_path, model, trainer, exp_manager, remote_rm, and data configuration |
| cfg.pretrained_checkpoint.restore_from_path | str | Yes | Path to the pretrained NeMo GPT model checkpoint |
| cfg.remote_rm | DictConfig | Yes | Configuration for the remote reward model client (host, port, etc.) |
| cfg.model.rs.num_rollouts_per_prompt | int | Yes | Number of candidate responses to generate per prompt |
| cfg.model.rs.top_n_rollouts | int | Yes | Number of top-scoring responses to keep for training |
| cfg.model.rs.rollout_micro_batch_size | int | Yes | Micro batch size for rollout generation |
| cfg.model.rs.num_rollout_samples | int | Yes | Global batch size for rollout generation |
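The relationship between num_rollouts_per_prompt and top_n_rollouts can be sketched as a selection step: generate several candidates per prompt, score them, and keep only the best few. The code below is a minimal illustration under stated assumptions, not the RSTrainer implementation; select_top_rollouts is a hypothetical helper, and it assumes candidates for the same prompt are stored contiguously.

```python
from typing import List, Sequence, Tuple


def select_top_rollouts(
    responses: Sequence[str],
    rewards: Sequence[float],
    num_rollouts_per_prompt: int,
    top_n_rollouts: int,
) -> List[Tuple[str, float]]:
    """Keep the top_n_rollouts highest-reward responses for each prompt."""
    if not 0 < top_n_rollouts <= num_rollouts_per_prompt:
        raise ValueError("top_n_rollouts must be in [1, num_rollouts_per_prompt]")
    if len(responses) != len(rewards) or len(responses) % num_rollouts_per_prompt != 0:
        raise ValueError("expected num_rollouts_per_prompt scored responses per prompt")

    kept: List[Tuple[str, float]] = []
    for start in range(0, len(responses), num_rollouts_per_prompt):
        end = start + num_rollouts_per_prompt
        group = list(zip(responses[start:end], rewards[start:end]))
        # Highest reward first; ties keep their original order (stable sort).
        group.sort(key=lambda pair: pair[1], reverse=True)
        kept.extend(group[:top_n_rollouts])
    return kept


# Two prompts ("a" and "b"), four candidates each, keep the single best per prompt.
responses = ["a1", "a2", "a3", "a4", "b1", "b2", "b3", "b4"]
rewards = [0.1, 0.9, 0.4, 0.3, 0.7, 0.2, 0.8, 0.5]
best = select_top_rollouts(responses, rewards, num_rollouts_per_prompt=4, top_n_rollouts=1)
print(best)
```

With these settings only one of every four generated responses reaches the training step, which is why top_n_rollouts must not exceed num_rollouts_per_prompt.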
Outputs
| Name | Type | Description |
|---|---|---|
| None (side effects) | N/A | Trains the model in-place, saves checkpoints, and logs metrics. No return value. |
Usage Examples
# Command-line invocation:
# python examples/nlp/gpt/train_gpt_rs_actor.py \
# pretrained_checkpoint.restore_from_path=/path/to/model.nemo \
# model.rs.num_rollouts_per_prompt=4 \
# model.rs.top_n_rollouts=1 \
# remote_rm.host=localhost \
# remote_rm.port=5555