Implementation:NVIDIA NeMo Aligner GPTSteerLMModel

From Leeroopedia


Knowledge Sources
Domains Natural Language Processing, Controllable Generation, Alignment
Last Updated 2026-02-08 00:00 GMT

Overview

GPTSteerLMModel extends GPTSFTModel to implement SteerLM v2 training with importance-weighted baseline corrections for attribute-conditioned controllable text generation.

Description

The GPTSteerLMModel class inherits from GPTSFTModel and overrides the loss computation and training loop to support the SteerLM v2 iterative training methodology.

The class provides the following key capabilities:

  • loss_func -- Computes the unweighted negative log-likelihood (NLL) over the tokens of each sequence in the batch; used during baseline weight computation.
  • weight_loss_func -- Applies per-response importance weights to the NLL loss, normalizes by the average number of valid tokens, and sums across the micro-batch.
  • compute_baselilne_weights -- Performs a forward-only pass to compute the NLL of every response, then computes the baseline softmax probabilities and the importance sampling weights. It also computes the KL divergence between the target and baseline distributions, reported as the "distance" metric. (The misspelling "baselilne" is preserved from the source code.)
  • get_forward_output_only_for_weight_func -- Returns a closure that runs a no-gradient forward pass to compute NLL loss for baseline estimation.
  • get_forward_output_and_loss_func -- Returns a closure that computes the weighted training loss using the pre-computed importance weights.
  • get_loss_and_metrics -- Orchestrates the two-phase training step: first computing baseline weights, then performing the weighted forward-backward pass.
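The arithmetic behind the capabilities above can be sketched in plain Python for a single prompt group. This is a minimal illustration and not the NeMo-Aligner implementation: the helper names are hypothetical, and three things are assumptions because this page does not spell them out: that the sequence NLL is a masked sum over tokens, that the baseline is a softmax over log P(y|a,x) - log Q(y|a,x), and that the per-response weight is the normalized target weight minus the baseline probability.

```python
import math


def sequence_nll(token_nll, loss_mask):
    """Masked sum of per-token NLL for one sequence (assumed reduction)."""
    return sum(n * m for n, m in zip(token_nll, loss_mask))


def softmax(xs):
    """Numerically stable softmax over a list of floats."""
    top = max(xs)
    exps = [math.exp(x - top) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def baseline_weights(nll, log_q, target_ws):
    """Return (weights, distance) for the responses of one prompt.

    Assumed form: baseline_i = softmax(log P_i - log Q_i), with log P_i = -nll_i,
    weight_i = target_i - baseline_i, and distance = KL(target || baseline).
    """
    baseline = softmax([-n - q for n, q in zip(nll, log_q)])
    total_w = sum(target_ws)
    target = [w / total_w for w in target_ws]  # normalize defensively
    weights = [t - b for t, b in zip(target, baseline)]
    distance = sum(t * math.log(t / b) for t, b in zip(target, baseline) if t > 0)
    return weights, distance


def weighted_loss(nll, weights, avg_num_valid_tokens):
    """Weight each response's NLL, normalize by the average token count, and sum."""
    return sum(w * n for w, n in zip(weights, nll)) / avg_num_valid_tokens


# Illustrative numbers only: two responses to the same prompt.
nll = [sequence_nll([0.5, 1.5, 2.0], [1, 1, 0]), sequence_nll([1.0, 2.0, 9.0], [1, 1, 0])]
weights, distance = baseline_weights(nll, log_q=[-0.5, -0.2], target_ws=[3.0, 1.0])
loss = weighted_loss(nll, weights, avg_num_valid_tokens=2.0)
```

Under these assumptions the weights of a group sum to zero and the distance is zero exactly when the baseline already matches the target, which is consistent with reading the "distance" metric as a convergence signal.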

The model uses two different micro-batch sizes: steerlm2.forward_micro_batch_size for the baseline computation pass and steerlm2.micro_batch_size for the actual training pass.
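As a rough illustration of the two micro-batch sizes, the sketch below splits the same set of responses once per phase. The chunking helper and the numeric values are hypothetical; only the two configuration key names come from this page. Presumably the forward-only pass can use a larger micro-batch because it keeps no activations for a backward pass, but that rationale is an assumption.

```python
def split_into_micro_batches(items, micro_batch_size):
    """Split a list into consecutive chunks of at most micro_batch_size items."""
    if micro_batch_size <= 0:
        raise ValueError("micro_batch_size must be positive")
    return [items[i:i + micro_batch_size] for i in range(0, len(items), micro_batch_size)]


# Illustrative values, not defaults from NeMo-Aligner.
steerlm2 = {"forward_micro_batch_size": 4, "micro_batch_size": 2}

responses = list(range(8))  # stand-ins for the responses in one global batch

# Phase 1: forward-only pass that feeds the baseline weight computation.
forward_chunks = split_into_micro_batches(responses, steerlm2["forward_micro_batch_size"])

# Phase 2: weighted forward-backward pass.
train_chunks = split_into_micro_batches(responses, steerlm2["micro_batch_size"])
```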

Usage

Import and use this class when implementing SteerLM v2 training. It is designed to be used with the SupervisedTrainer and requires a dataset that provides num_responses, log(Q(y|a,x)) (log-probabilities under the proposal distribution), and ws (target importance sampling weights).

Code Reference

Source Location

  • Repository: NVIDIA_NeMo_Aligner
  • File: nemo_aligner/models/nlp/gpt/gpt_steerlm_model.py
  • Lines: 53-239

Signature

class GPTSteerLMModel(GPTSFTModel):
    def loss_func(self, loss_mask, output_tensor): ...
    def weight_loss_func(self, loss_mask, avg_num_valid_tokens_in_ub, output_tensor, weight): ...
    def compute_baselilne_weights(self, batch): ...
    def get_forward_output_only_for_weight_func(self): ...
    def get_forward_output_and_loss_func(self, validation_step=False): ...
    def get_loss_and_metrics(self, batch, forward_only): ...

Import

from nemo_aligner.models.nlp.gpt.gpt_steerlm_model import GPTSteerLMModel

I/O Contract

Inputs

Name Type Required Description
batch["tokens"] Tensor Yes Input token IDs of shape [B, seq_length]
batch["labels"] Tensor Yes Target label IDs of shape [B, seq_length]
batch["loss_mask"] Tensor Yes Loss mask of shape [B, seq_length]
batch["attention_mask"] Tensor Yes Attention mask tensor
batch["position_ids"] Tensor Yes Position ID tensor
batch["num_responses"] Tensor Yes Number of responses per prompt group
batch["log(Q(y|a,x))"] Tensor Yes Log-probability of each response under the proposal distribution
batch["ws"] Tensor Yes Target importance sampling weights
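A quick pre-flight check against the input contract above might look like the following. The helper is hypothetical and not part of NeMo-Aligner; it only verifies that the keys listed in the table are present and does not validate tensor shapes or dtypes.

```python
# Keys taken from the Inputs table; the helper itself is illustrative.
REQUIRED_BATCH_KEYS = (
    "tokens",
    "labels",
    "loss_mask",
    "attention_mask",
    "position_ids",
    "num_responses",
    "log(Q(y|a,x))",
    "ws",
)


def missing_batch_keys(batch):
    """Return the required keys that are absent from a batch dictionary."""
    return [key for key in REQUIRED_BATCH_KEYS if key not in batch]


# Example: a batch that lacks the SteerLM v2 specific fields.
incomplete = {
    "tokens": None,
    "labels": None,
    "loss_mask": None,
    "attention_mask": None,
    "position_ids": None,
}
missing = missing_batch_keys(incomplete)
```

Running a check like this before training makes it easier to spot a dataset that was prepared for plain SFT rather than for SteerLM v2.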

Outputs

Name Type Description
loss_value float Importance-weighted training loss
metrics dict Dictionary with "loss" (weighted loss) and "distance" (KL divergence between target and baseline distributions)

Usage Examples

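from nemo_aligner.models.nlp.gpt.gpt_steerlm_model import GPTSteerLMModel
from nemo_aligner.utils.utils import load_from_nemo

# cfg, trainer, and _modify_config are assumed to be defined by the surrounding training script.
# Model is typically loaded via load_from_nemo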
ptl_model, updated_cfg = load_from_nemo(
    GPTSteerLMModel,
    cfg,
    trainer,
    strict=True,
    modify_config_fn=_modify_config,
    restore_path=cfg.model.restore_from_path,
    return_updated_cfg=True,
)

# Training step returns loss and metrics including KL distance
loss_value, metrics = ptl_model.get_loss_and_metrics(batch, forward_only=False)
# metrics["distance"] tracks convergence of the model to the target distribution
