Implementation:NVIDIA NeMo Aligner MegatronGPT RS Actor
| Knowledge Sources | |
|---|---|
| Domains | NLP, Alignment |
| Last Updated | 2026-02-08 00:00 GMT |
Overview
MegatronGPTRSModel is the actor model class for Rejection Sampling training, extending MegatronGPTModel with the AlignableGenerativeInterface to support both text generation (inference) and supervised fine-tuning on selected responses.
Description
The MegatronGPTRSModel class combines three base classes:
- NLPAdapterModelMixin -- enables PEFT (Parameter-Efficient Fine-Tuning) adapter support.
- MegatronGPTModel -- provides the core Megatron-LM GPT model with tensor/pipeline parallelism.
- AlignableGenerativeInterface -- defines the interface contract for alignable models that support generation and training lifecycle methods.
Key capabilities include:
- Inference: The infer() method generates text given prompt tokens using the TrackLengthGPTModelTextGenerationStrategy, which tracks response lengths during generation. Generation parameters (length and sampling) are configured via cfg.rs.length_params and cfg.rs.sampling_params.
- Training: The get_loss_and_metrics() method computes the SFT loss as the negative masked mean of log-probabilities over response tokens. It uses Megatron's pipeline-parallel forward-backward function.
- Adam state offloading: When cfg.rs.offload_adam_states is enabled, the model offloads distributed Adam optimizer states to CPU during inference to free GPU memory, and reloads them before training.
- Training lifecycle: Implements prepare_for_training(), prepare_for_inference(), finish_inference(), and related methods that handle batch size configuration, activation checkpointing, sequence parallelism, and eval/train mode toggling.
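The loss described under Training can be illustrated with a minimal pure-Python sketch. It is not the class's actual implementation, which operates on tensors inside Megatron's forward-backward function; the helper names masked_mean and sft_loss and the sample numbers are assumptions made for illustration only.

```python
def masked_mean(values, mask):
    """Average `values` over the positions where `mask` is 1."""
    count = sum(mask)
    if count == 0:
        raise ValueError("mask selects no tokens")
    return sum(v * m for v, m in zip(values, mask)) / count


def sft_loss(logprobs, response_mask):
    """Negative masked mean of per-token log-probabilities (hypothetical helper)."""
    return -masked_mean(logprobs, response_mask)


# Four tokens: the first two belong to the prompt (mask 0),
# the last two belong to the response (mask 1).
logprobs = [-0.5, -1.0, -2.0, -4.0]
response_mask = [0, 0, 1, 1]
loss = sft_loss(logprobs, response_mask)
```

Only the response positions contribute, so prompt tokens do not influence the gradient in this formulation.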
Usage
Import MegatronGPTRSModel when loading the pretrained model for RS training. It is loaded from a NeMo checkpoint via load_from_nemo() in the training entry script.
Code Reference
Source Location
- Repository: NVIDIA_NeMo_Aligner
- File: nemo_aligner/models/nlp/gpt/megatron_gpt_rs_actor.py
- Lines: 53-262
Signature
class MegatronGPTRSModel(NLPAdapterModelMixin, MegatronGPTModel, AlignableGenerativeInterface):
    def __init__(self, cfg: DictConfig, trainer: Trainer):
Import
from nemo_aligner.models.nlp.gpt.megatron_gpt_rs_actor import MegatronGPTRSModel
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| cfg | DictConfig | Yes | Model configuration including rs.length_params, rs.sampling_params, rs.offload_adam_states, rs.forward_micro_batch_size, micro_batch_size, global_batch_size |
| trainer | pytorch_lightning.Trainer | Yes | PyTorch Lightning trainer instance |
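
As a rough illustration of how the configuration keys listed above nest, the sketch below uses a plain Python dict in place of a DictConfig. Every value is a placeholder, and the sub-keys under length_params and sampling_params (max_length, min_length, use_greedy, temperature, top_k, top_p) are assumptions based on common NeMo generation settings, not taken from this page. The check_batch_sizes helper and its data_parallel_size argument are hypothetical.

```python
# Illustrative values only; none of these numbers come from the source page.
rs_model_cfg = {
    "micro_batch_size": 1,
    "global_batch_size": 64,
    "rs": {
        "forward_micro_batch_size": 4,
        "offload_adam_states": True,
        # Assumed sub-keys for generation length limits.
        "length_params": {"max_length": 512, "min_length": 1},
        # Assumed sub-keys for sampling behaviour.
        "sampling_params": {
            "use_greedy": False,
            "temperature": 1.0,
            "top_k": 0,
            "top_p": 1.0,
        },
    },
}


def check_batch_sizes(cfg, data_parallel_size):
    """Hypothetical sanity check: the global batch must split evenly
    into micro-batches across the data-parallel ranks."""
    per_step = cfg["micro_batch_size"] * data_parallel_size
    return cfg["global_batch_size"] % per_step == 0


batch_sizes_ok = check_batch_sizes(rs_model_cfg, data_parallel_size=8)
```

In a real run these settings would normally live in the YAML file that the training entry script loads into cfg.model.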
Key Method: infer()
| Name | Type | Required | Description |
|---|---|---|---|
| inference_batch | dict | Yes | Dictionary with text (prompt token IDs, LongTensor) and length (prompt lengths, LongTensor) |
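
The shape of inference_batch can be sketched without any tensor library. The example below pads variable-length prompts into a rectangular batch and records the original lengths. The helper build_inference_batch, the pad_id value, and the choice of right-padding are all assumptions; in practice both entries would be converted to LongTensor before being passed to infer().

```python
def build_inference_batch(prompts, pad_id=0):
    """Pad token-ID lists on the right and record each prompt's true length.

    `pad_id=0` is a placeholder; the real padding token depends on the tokenizer.
    """
    lengths = [len(p) for p in prompts]
    max_len = max(lengths)
    text = [p + [pad_id] * (max_len - len(p)) for p in prompts]
    return {"text": text, "length": lengths}


# Two prompts of different lengths (token IDs are made up).
inference_batch = build_inference_batch([[11, 12, 13], [21, 22]])
```

Keeping the lengths alongside the padded tokens lets the generation code tell real prompt tokens from padding.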
Outputs
| Name | Type | Description |
|---|---|---|
| rollout_batch | dict | Dictionary containing response_tokens (LongTensor), response_lengths (LongTensor), and prompt_lengths (LongTensor), all on GPU |
Key Method: get_loss_and_metrics()
| Name | Type | Description |
|---|---|---|
| loss_mean | float | Mean SFT loss (negative mean log-prob on response tokens) |
| metrics | dict | Dictionary with key loss containing the reduced loss value |
Usage Examples
from nemo_aligner.models.nlp.gpt.megatron_gpt_rs_actor import MegatronGPTRSModel
from nemo_aligner.utils.utils import load_from_nemo
ptl_model = load_from_nemo(
    MegatronGPTRSModel,
    cfg.model,
    trainer,
    strict=True,
    restore_path=cfg.pretrained_checkpoint.restore_from_path,
)
# Inference
rollout_batch = ptl_model.infer(inference_batch)
# Training
loss_mean, metrics = ptl_model.get_loss_and_metrics(batch=batch, forward_only=False)
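
In an actual training loop, the infer() and get_loss_and_metrics() calls above would be bracketed by the lifecycle methods listed in the Description. The toy class below is a stand-in that only mimics the documented behaviour (eval/train toggling and Adam state offloading). It does not use NeMo, the class name ToyRSActor and its attributes are hypothetical, and the exact method in which the real model moves optimizer state is an assumption.

```python
class ToyRSActor:
    """Stand-in that records where the optimizer state 'lives' at each phase."""

    def __init__(self, offload_adam_states=True):
        self.offload_adam_states = offload_adam_states
        self.adam_state_device = "gpu"
        self.mode = "train"
        self.events = []

    def prepare_for_inference(self):
        # Switch to eval mode and, if enabled, free GPU memory for generation.
        self.mode = "eval"
        if self.offload_adam_states:
            self.adam_state_device = "cpu"
        self.events.append("prepare_for_inference")

    def infer(self, inference_batch):
        if self.mode != "eval":
            raise RuntimeError("call prepare_for_inference() first")
        self.events.append("infer")
        # Echo the prompts back; a real model would append generated tokens.
        return {"response_tokens": inference_batch["text"]}

    def finish_inference(self):
        self.mode = "train"
        self.events.append("finish_inference")

    def prepare_for_training(self):
        # Bring the optimizer state back before any parameter update.
        if self.offload_adam_states:
            self.adam_state_device = "gpu"
        self.events.append("prepare_for_training")


actor = ToyRSActor(offload_adam_states=True)
actor.prepare_for_inference()
device_during_inference = actor.adam_state_device
rollout = actor.infer({"text": [[1, 2, 3]], "length": [3]})
actor.finish_inference()
actor.prepare_for_training()
device_during_training = actor.adam_state_device
```

The point of the ordering is that optimizer state is off the GPU only while generation runs, and is back in place before the training step.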