

Implementation:NVIDIA NeMo Aligner MegatronGPT RS Actor



Knowledge Sources
Domains: NLP, Alignment
Last Updated: 2026-02-08 00:00 GMT

Overview

MegatronGPTRSModel is the actor model class for Rejection Sampling (RS) training. It extends MegatronGPTModel and implements AlignableGenerativeInterface, so the same model can both generate text (inference) and be fine-tuned with a supervised loss on the selected responses.

Description

The MegatronGPTRSModel class combines three base classes:

  • NLPAdapterModelMixin -- enables PEFT (Parameter-Efficient Fine-Tuning) adapter support.
  • MegatronGPTModel -- provides the core Megatron-LM GPT model with tensor/pipeline parallelism.
  • AlignableGenerativeInterface -- defines the interface contract for alignable models that support generation and training lifecycle methods.

Key capabilities include:

  • Inference: The infer() method generates text given prompt tokens using the TrackLengthGPTModelTextGenerationStrategy, which tracks response lengths during generation. Generation parameters (length and sampling) are configured via cfg.rs.length_params and cfg.rs.sampling_params.
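  • Training: The get_loss_and_metrics() method computes the SFT loss as the negative masked mean of log-probabilities over response tokens (prompt and padding tokens are masked out). The forward pass, and the backward pass unless forward_only=True, runs through Megatron's pipeline-parallel forward-backward function.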
  • Adam state offloading: When cfg.rs.offload_adam_states is enabled, the model offloads distributed Adam optimizer states to CPU during inference to free GPU memory, and reloads them before training.
  • Training lifecycle: Implements prepare_for_training(), prepare_for_inference(), finish_inference(), and related methods that handle batch size configuration, activation checkpointing, sequence parallelism, and eval/train mode toggling.
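The loss described in the Training item above can be illustrated with a minimal pure-Python sketch. It assumes per-token log-probabilities and a 0/1 mask that is 1 on response tokens and 0 on prompt and padding tokens, with the mean taken over all selected tokens in the batch. The helper name masked_sft_loss is hypothetical; this is not the NeMo Aligner code, which works on GPU tensors inside Megatron's forward-backward function.

```python
from typing import Sequence


def masked_sft_loss(logprobs: Sequence[Sequence[float]],
                    mask: Sequence[Sequence[int]]) -> float:
    """Negative masked mean of per-token log-probabilities (hypothetical helper).

    `mask` is 1 for response tokens and 0 for prompt/padding tokens.
    The mean is taken over every masked-in token in the whole batch.
    """
    total = 0.0
    count = 0
    for row_lp, row_mask in zip(logprobs, mask):
        for lp, m in zip(row_lp, row_mask):
            total += lp * m  # prompt/padding positions contribute 0
            count += m
    if count == 0:
        raise ValueError("mask selects no response tokens")
    return -total / count


# Two samples; the first token of each is prompt and is ignored.
loss = masked_sft_loss(
    logprobs=[[-9.0, -1.0, -2.0], [-9.0, -0.5, -0.5]],
    mask=[[0, 1, 1], [0, 1, 1]],
)
```

Masking out the prompt means the model is trained only to reproduce the selected response, not the prompt it was conditioned on.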

Usage

Import MegatronGPTRSModel when loading the pretrained model for RS training. It is loaded from a NeMo checkpoint via load_from_nemo() in the training entry script.

Code Reference

Source Location

  • Repository: NVIDIA_NeMo_Aligner
  • File: nemo_aligner/models/nlp/gpt/megatron_gpt_rs_actor.py
  • Lines: 53-262

Signature

```python
class MegatronGPTRSModel(NLPAdapterModelMixin, MegatronGPTModel, AlignableGenerativeInterface):
    def __init__(self, cfg: DictConfig, trainer: Trainer):
```

Import

```python
from nemo_aligner.models.nlp.gpt.megatron_gpt_rs_actor import MegatronGPTRSModel
```

I/O Contract

Inputs

| Name | Type | Required | Description |
|---|---|---|---|
| cfg | DictConfig | Yes | Model configuration, including rs.length_params, rs.sampling_params, rs.offload_adam_states, rs.forward_micro_batch_size, micro_batch_size, and global_batch_size |
| trainer | pytorch_lightning.Trainer | Yes | PyTorch Lightning trainer instance |
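For orientation, the cfg fields listed above might be laid out as follows. This sketch uses a plain nested dict instead of an OmegaConf DictConfig, all values are illustrative, and the sub-keys under length_params and sampling_params are assumptions modeled on NeMo's text-generation settings rather than copied from the RS configuration. The helper num_micro_batches is hypothetical; it shows the usual Megatron relationship between global_batch_size, micro_batch_size, and the data-parallel size.

```python
# Illustrative values only; sub-keys of length_params/sampling_params are assumptions.
model_cfg = {
    "micro_batch_size": 1,
    "global_batch_size": 64,
    "rs": {
        "forward_micro_batch_size": 4,   # assumed: micro-batch size used for generation
        "offload_adam_states": True,     # offload distributed Adam states to CPU during inference
        "length_params": {"max_length": 512, "min_length": 1},
        "sampling_params": {
            "use_greedy": False,
            "temperature": 1.0,
            "top_k": 0,
            "top_p": 1.0,
            "repetition_penalty": 1.0,
            "end_strings": ["<|endoftext|>"],
        },
    },
}


def num_micro_batches(global_batch_size: int, micro_batch_size: int,
                      data_parallel_size: int) -> int:
    """Gradient-accumulation steps per optimizer step (hypothetical helper)."""
    per_step = micro_batch_size * data_parallel_size
    if global_batch_size % per_step != 0:
        raise ValueError(
            "global_batch_size must be divisible by micro_batch_size * data_parallel_size"
        )
    return global_batch_size // per_step
```

With global_batch_size=64, micro_batch_size=1, and 8 data-parallel ranks, each rank would accumulate gradients over 8 micro-batches per optimizer step.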

Key Method: infer()

| Name | Type | Required | Description |
|---|---|---|---|
| inference_batch | dict | Yes | Dictionary with keys text (prompt token IDs, LongTensor) and length (prompt lengths, LongTensor) |
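A caller supplies prompts as a padded matrix of token IDs together with the true length of each prompt. The sketch below builds that structure in pure Python, assuming right-padding; build_inference_batch and pad_id are hypothetical names. In real use the two lists would be converted to LongTensors (for example with torch.tensor(..., dtype=torch.long)) before calling infer().

```python
from typing import Dict, List


def build_inference_batch(prompts: List[List[int]], pad_id: int = 0) -> Dict[str, List]:
    """Right-pad tokenized prompts to a common length (hypothetical helper).

    Returns the two keys infer() expects: "text" (padded token IDs)
    and "length" (unpadded prompt lengths).
    """
    if not prompts:
        raise ValueError("need at least one prompt")
    max_len = max(len(p) for p in prompts)
    text = [p + [pad_id] * (max_len - len(p)) for p in prompts]
    length = [len(p) for p in prompts]
    return {"text": text, "length": length}


batch = build_inference_batch([[5, 6, 7], [8]], pad_id=0)
```

Keeping the unpadded lengths alongside the padded IDs lets generation start each sample at its own prompt boundary instead of after the padding.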

Outputs

| Name | Type | Description |
|---|---|---|
| rollout_batch | dict | Dictionary containing response_tokens (LongTensor), response_lengths (LongTensor), and prompt_lengths (LongTensor), all on GPU |
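Downstream steps, such as scoring and selecting responses, usually need each generated continuation on its own. The sketch below assumes, without confirmation from this page, that response_tokens holds the full prompt-plus-response sequence and that response_lengths counts prompt plus response tokens; under that assumption each response is the slice between the two lengths. The function split_responses is hypothetical and works on plain lists rather than GPU tensors.

```python
from typing import Dict, List


def split_responses(rollout_batch: Dict[str, List]) -> List[List[int]]:
    """Extract the generated tokens of each sample (hypothetical helper).

    Assumes response_tokens[i] = prompt + response (+ padding) and that
    response_lengths[i] counts prompt + response tokens.
    """
    out = []
    for tokens, p_len, r_len in zip(rollout_batch["response_tokens"],
                                    rollout_batch["prompt_lengths"],
                                    rollout_batch["response_lengths"]):
        out.append(list(tokens[p_len:r_len]))  # drop prompt and trailing padding
    return out


rollout = {
    "response_tokens": [[5, 6, 7, 11, 12, 0], [8, 21, 0, 0, 0, 0]],
    "prompt_lengths": [3, 1],
    "response_lengths": [5, 2],
}
responses = split_responses(rollout)
```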

Key Method: get_loss_and_metrics()

| Name | Type | Description |
|---|---|---|
| loss_mean | float | Mean SFT loss (negative mean log-probability over response tokens) |
| metrics | dict | Dictionary whose loss key holds the reduced loss value |

Usage Examples

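```python
from nemo_aligner.models.nlp.gpt.megatron_gpt_rs_actor import MegatronGPTRSModel
from nemo_aligner.utils.utils import load_from_nemo

# `cfg` (an OmegaConf config) and `trainer` are created earlier in the training entry script.
ptl_model = load_from_nemo(
    MegatronGPTRSModel,
    cfg.model,
    trainer,
    strict=True,
    restore_path=cfg.pretrained_checkpoint.restore_from_path,
)

# Inference: inference_batch holds "text" (prompt token IDs) and "length" (prompt lengths).
rollout_batch = ptl_model.infer(inference_batch)

# Training: `batch` holds the selected responses; forward_only=False also runs the backward pass.
loss_mean, metrics = ptl_model.get_loss_and_metrics(batch=batch, forward_only=False)
```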
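In a rejection-sampling loop, generation and training calls are normally bracketed by the lifecycle hooks listed under Description. The stub below only records call order to make that bracketing concrete; StubRSActor, rs_iteration, and select_fn are hypothetical, the selection of the best-scoring responses is left abstract, and the sketch does not reproduce the NeMo Aligner trainer.

```python
from typing import Any, Callable, Dict, List


class StubRSActor:
    """Hypothetical stand-in that records which lifecycle hooks are called."""

    def __init__(self) -> None:
        self.calls: List[str] = []

    def prepare_for_inference(self) -> None:
        self.calls.append("prepare_for_inference")

    def infer(self, inference_batch: Dict[str, Any]) -> Dict[str, Any]:
        self.calls.append("infer")
        return {"response_tokens": [], "response_lengths": [], "prompt_lengths": []}

    def finish_inference(self) -> None:
        self.calls.append("finish_inference")

    def prepare_for_training(self) -> None:
        self.calls.append("prepare_for_training")

    def get_loss_and_metrics(self, batch: Dict[str, Any], forward_only: bool):
        self.calls.append("get_loss_and_metrics")
        return 0.0, {"loss": 0.0}


def rs_iteration(model, inference_batch: Dict[str, Any],
                 select_fn: Callable[[Dict[str, Any]], Dict[str, Any]]):
    """One generate-then-train step; select_fn picks the responses to train on."""
    model.prepare_for_inference()   # e.g. eval mode, inference batch sizes
    rollout_batch = model.infer(inference_batch)
    model.finish_inference()        # undo inference-only settings
    train_batch = select_fn(rollout_batch)
    model.prepare_for_training()    # e.g. train mode, training batch sizes
    return model.get_loss_and_metrics(batch=train_batch, forward_only=False)


actor = StubRSActor()
loss_mean, metrics = rs_iteration(actor, {"text": [[1]], "length": [1]}, lambda rb: rb)
```

Keeping inference and training inside separate prepare/finish brackets is what makes it possible to toggle modes and, when cfg.rs.offload_adam_states is set, move optimizer state off the GPU only while generating.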
