Implementation:NVIDIA NeMo Aligner GPTSteerLMModel
| Knowledge Sources | |
|---|---|
| Domains | Natural Language Processing, Controllable Generation, Alignment |
| Last Updated | 2026-02-08 00:00 GMT |
Overview
GPTSteerLMModel extends GPTSFTModel to implement SteerLM v2 training with importance-weighted baseline corrections for attribute-conditioned controllable text generation.
Description
The GPTSteerLMModel class inherits from GPTSFTModel and overrides the loss computation and training loop to support the SteerLM v2 iterative training methodology.
The class provides the following key capabilities:
- loss_func -- Computes the per-token negative log-likelihood for each sequence in the batch (no weighting), used during baseline weight computation.
- weight_loss_func -- Applies per-response importance weights to the NLL loss, normalizes by the average number of valid tokens, and sums across the micro-batch.
- compute_baselilne_weights -- Performs a forward-only pass to compute the NLL of every response, then derives the baseline softmax probabilities and the importance sampling weights. It also computes the KL distance between the target and baseline distributions. (The misspelling "baselilne" is the actual method name in the source code.)
- get_forward_output_only_for_weight_func -- Returns a closure that runs a no-gradient forward pass to compute NLL loss for baseline estimation.
- get_forward_output_and_loss_func -- Returns a closure that computes the weighted training loss using the pre-computed importance weights.
- get_loss_and_metrics -- Orchestrates the two-phase training step: first computing baseline weights, then performing the weighted forward-backward pass.
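The baseline phase described above can be sketched in plain Python. This is an illustration under stated assumptions, not the NeMo Aligner code: it assumes the baseline for a prompt group is a softmax over log P(y|a,x) - log Q(y|a,x), that the effective per-response weight is the target weight minus the baseline probability, that the target weights are normalized within each group, and that the distance is KL(target || baseline). The helper names `group_softmax` and `baseline_weights` are hypothetical.

```python
import math


def group_softmax(logits):
    """Numerically stable softmax over one prompt group."""
    peak = max(logits)
    exps = [math.exp(v - peak) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]


def baseline_weights(nll, log_q, ws, num_responses):
    """Return (per-response weights, per-group KL distances).

    nll           -- summed NLL per response, i.e. -log P(y|a,x)
    log_q         -- log Q(y|a,x) per response
    ws            -- target weights, assumed to sum to 1 within each group
    num_responses -- group sizes; the flat lists are laid out group by group
    """
    weights, distances = [], []
    start = 0
    for n in num_responses:
        end = start + n
        # Assumption: baseline is proportional to P / Q inside the group.
        logits = [-nll[i] - log_q[i] for i in range(start, end)]
        baseline = group_softmax(logits)
        target = ws[start:end]
        # KL(target || baseline); zero-weight terms contribute nothing.
        kl = sum(w * math.log(w / b) for w, b in zip(target, baseline) if w > 0)
        distances.append(kl)
        # Assumption: effective weight is target minus baseline.
        weights.extend(w - b for w, b in zip(target, baseline))
        start = end
    return weights, distances
```

Under these assumptions the weights vanish and the distance reaches zero once the baseline matches the target, which is consistent with using the distance as a convergence signal.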
The model uses two different micro-batch sizes: steerlm2.forward_micro_batch_size for the baseline computation pass and steerlm2.micro_batch_size for the actual training pass.
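For the training pass, the arithmetic attributed to weight_loss_func (weight each response's NLL, normalize by the average number of valid tokens, sum over the micro-batch) might look like the following plain-Python sketch. Nested lists stand in for tensors, the function name `weighted_micro_batch_loss` is hypothetical, and the exact order of masking and normalization in the real method is an assumption.

```python
def weighted_micro_batch_loss(token_nll, loss_mask, weights, avg_num_valid_tokens):
    """Importance-weighted loss for one training micro-batch.

    token_nll            -- per-token NLL, one row per response
    loss_mask            -- 0/1 mask with the same shape as token_nll
    weights              -- one pre-computed importance weight per response
    avg_num_valid_tokens -- average count of unmasked tokens per response
    """
    total = 0.0
    for nll_row, mask_row, weight in zip(token_nll, loss_mask, weights):
        # Sum the NLL over unmasked tokens only.
        sequence_nll = sum(l * m for l, m in zip(nll_row, mask_row))
        # Weight the sequence loss and normalize by the average token count.
        total += weight * sequence_nll / avg_num_valid_tokens
    return total


# Two responses with 2 and 1 valid tokens, so the average is 1.5.
loss = weighted_micro_batch_loss(
    token_nll=[[1.0, 2.0, 3.0], [0.5, 0.5, 0.5]],
    loss_mask=[[0, 1, 1], [0, 1, 0]],
    weights=[0.5, -0.5],
    avg_num_valid_tokens=1.5,
)
```

Weights may be negative in this sketch, so the resulting value is a surrogate objective rather than a likelihood.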
Usage
Import and use this class when implementing SteerLM v2 training. It is designed to be used with the SupervisedTrainer and requires a dataset that provides num_responses, log(Q(y|a,x)) (log-probabilities under the proposal distribution), and ws (target importance sampling weights).
Code Reference
Source Location
- Repository: NVIDIA_NeMo_Aligner
- File: nemo_aligner/models/nlp/gpt/gpt_steerlm_model.py
- Lines: 53-239
Signature
class GPTSteerLMModel(GPTSFTModel):
    def loss_func(self, loss_mask, output_tensor):
    def weight_loss_func(self, loss_mask, avg_num_valid_tokens_in_ub, output_tensor, weight):
    def compute_baselilne_weights(self, batch):
    def get_forward_output_only_for_weight_func(self):
    def get_forward_output_and_loss_func(self, validation_step=False):
    def get_loss_and_metrics(self, batch, forward_only):
Import
from nemo_aligner.models.nlp.gpt.gpt_steerlm_model import GPTSteerLMModel
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| batch["tokens"] | Tensor | Yes | Input token IDs of shape [B, seq_length] |
| batch["labels"] | Tensor | Yes | Target label IDs of shape [B, seq_length] |
| batch["loss_mask"] | Tensor | Yes | Loss mask of shape [B, seq_length] |
| batch["attention_mask"] | Tensor | Yes | Attention mask tensor |
| batch["position_ids"] | Tensor | Yes | Position ID tensor |
| batch["num_responses"] | Tensor | Yes | Number of responses per prompt group |
| batch["log(Q(y\|a,x))"] | Tensor | Yes | Log-probability of each response under the proposal distribution |
| batch["ws"] | Tensor | Yes | Target importance sampling weights |
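A quick sanity check of a batch against this contract can be sketched as follows. The flat layout, in which the batch dimension equals the sum of num_responses and responses for the same prompt are stored contiguously, is an assumption made for illustration. Plain lists stand in for tensors, and both helper names are hypothetical.

```python
REQUIRED_KEYS = [
    "tokens", "labels", "loss_mask", "attention_mask", "position_ids",
    "num_responses", "log(Q(y|a,x))", "ws",
]


def group_slices(num_responses):
    """Return one (start, end) index pair per prompt group."""
    slices, start = [], 0
    for n in num_responses:
        slices.append((start, start + n))
        start += n
    return slices


def check_steerlm_batch(batch):
    """Validate keys and per-response lengths, then return the group slices."""
    missing = [key for key in REQUIRED_KEYS if key not in batch]
    if missing:
        raise KeyError(f"batch is missing keys: {missing}")
    # Assumption: one row per response, grouped by prompt.
    total = sum(batch["num_responses"])
    for key in ("tokens", "log(Q(y|a,x))", "ws"):
        if len(batch[key]) != total:
            raise ValueError(f"{key} has {len(batch[key])} rows, expected {total}")
    return group_slices(batch["num_responses"])
```

The returned slices can then be used to index the per-response fields one prompt group at a time.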
Outputs
| Name | Type | Description |
|---|---|---|
| loss_value | float | Importance-weighted training loss |
| metrics | dict | Dictionary with "loss" (weighted loss) and "distance" (KL divergence between target and baseline distributions) |
Usage Examples
from nemo_aligner.models.nlp.gpt.gpt_steerlm_model import GPTSteerLMModel
from nemo_aligner.utils.utils import load_from_nemo

# cfg, trainer, and _modify_config are assumed to be defined by the
# surrounding training script; they are not created here.

# Model is typically loaded via load_from_nemo
ptl_model, updated_cfg = load_from_nemo(
    GPTSteerLMModel,
    cfg,
    trainer,
    strict=True,
    modify_config_fn=_modify_config,
    restore_path=cfg.model.restore_from_path,
    return_updated_cfg=True,
)

# Training step returns loss and metrics including KL distance
loss_value, metrics = ptl_model.get_loss_and_metrics(batch, forward_only=False)
# metrics["distance"] tracks convergence of the model to the target distribution