Implementation:NVIDIA NeMo Aligner Train SteerLM2


Knowledge Sources
Domains: Natural Language Processing, Controllable Generation, Alignment
Last Updated: 2026-02-08 00:00 GMT

Overview

train_steerlm2.py is the entry-point script for SteerLM v2 attribute-conditioned supervised fine-tuning (SFT). It orchestrates dataset construction, model loading, and supervised training with importance-weighted losses.

Description

This script sets up the complete SteerLM v2 training pipeline:

  1. Configuration -- Uses Hydra with the config path "conf" and config name "gpt_sft". Configuration is modified via _modify_config to inject SteerLM-specific settings including steerlm2 micro-batch sizes, PEFT configuration, and chat prompt templates.
  2. Model loading -- Loads a pretrained GPTSteerLMModel from a NeMo checkpoint, applying configuration modifications and optionally initializing PEFT adapters.
  3. Dataset construction -- Uses a custom build_sft_dataset function (defined locally, not the one from nemo_aligner.data.nlp.builders) that instantiates SteerLM2Dataset objects. These datasets contain responses annotated with importance sampling weights and proposal distribution log-probabilities.
  4. Dataloader creation -- Builds train and validation dataloaders using build_dataloader with the dataset's own collate_fn.
  5. Training execution -- Creates a SupervisedTrainer with the SFT configuration section and calls fit().
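To make the idea of an importance-weighted loss from the steps above more concrete, here is a minimal sketch in plain Python. It is not the loss implemented by GPTSteerLMModel. It assumes a generic self-normalized importance weighting, where each response's log-likelihood under the model is scaled by a normalized weight. The function names and the numeric values are hypothetical.

```python
import math


def normalize_weights(log_weights):
    """Self-normalize importance weights given in log space (a softmax)."""
    peak = max(log_weights)  # subtract the max for numerical stability
    exps = [math.exp(lw - peak) for lw in log_weights]
    total = sum(exps)
    return [e / total for e in exps]


def importance_weighted_nll(model_log_probs, log_weights):
    """Weighted negative log-likelihood over the responses to one prompt.

    Assumption: this generic form stands in for the SteerLM v2 loss,
    whose exact definition is not given in this page.
    """
    weights = normalize_weights(log_weights)
    return -sum(w * lp for w, lp in zip(weights, model_log_probs))


# Hypothetical values: three responses with equal importance weights.
model_log_probs = [-12.0, -15.0, -9.0]
log_weights = [0.0, 0.0, 0.0]
loss = importance_weighted_nll(model_log_probs, log_weights)
```

With equal weights the result reduces to the plain average negative log-likelihood. Unequal weights shift the loss toward the responses that the importance weights favor.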

The _modify_config function is responsible for transferring fine-tuning configuration parameters (micro-batch size, global-batch size, PEFT settings, dropout, precision, SteerLM v2 settings, chat prompt templates) from the training config to the model config loaded from the checkpoint.
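The transfer described above can be sketched with plain dictionaries. This is a simplified illustration, not the actual _modify_config: the real function operates on Hydra/OmegaConf configs, and the name modify_config_sketch, the OPTIONAL_KEYS tuple, and the example values are assumptions made for this sketch.

```python
from copy import deepcopy

# Hypothetical subset of optional settings copied when present.
OPTIONAL_KEYS = ("peft", "precision", "hidden_dropout")


def modify_config_sketch(checkpoint_cfg, training_cfg):
    """Return a copy of the checkpoint's model config with training
    settings written over it. The inputs are left untouched."""
    merged = deepcopy(checkpoint_cfg)
    model_cfg = training_cfg["model"]
    train_ds = model_cfg["data"]["train_ds"]

    # Batch sizes come from the training dataset section.
    merged["micro_batch_size"] = train_ds["micro_batch_size"]
    merged["global_batch_size"] = train_ds["global_batch_size"]

    # SteerLM v2 settings are copied as a block.
    merged["steerlm2"] = deepcopy(model_cfg["steerlm2"])

    # Optional settings are copied only when the training config sets them.
    for key in OPTIONAL_KEYS:
        if key in model_cfg:
            merged[key] = deepcopy(model_cfg[key])
    return merged


# Hypothetical configs for illustration.
checkpoint_cfg = {"micro_batch_size": 1, "global_batch_size": 8, "precision": 32}
training_cfg = {
    "model": {
        "precision": "bf16",
        "data": {"train_ds": {"micro_batch_size": 2, "global_batch_size": 128}},
        "steerlm2": {"forward_micro_batch_size": 4, "micro_batch_size": 2},
    }
}
merged_cfg = modify_config_sketch(checkpoint_cfg, training_cfg)
```

Copying into a fresh dictionary keeps the checkpoint's original config intact, which makes it easy to compare the values before and after the override.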

The script supports optional data sampling via cfg.model.data.sample, which controls whether the dataset size is limited based on max_steps.
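One plausible way to derive the sample limit is sketched below. It assumes that, when sampling is enabled, the limit is max_steps multiplied by the global batch size, and that a negative max_steps means no limit. Both rules, and the helper name resolve_num_samples, are assumptions rather than a copy of the script's logic.

```python
def resolve_num_samples(sample, max_steps, global_batch_size):
    """Return the number of training samples to draw, or None for no limit.

    Assumption: one optimizer step consumes one global batch, so the
    samples needed for a run are max_steps * global_batch_size.
    """
    if not sample:
        return None  # sampling disabled: use the dataset as-is
    if max_steps < 0:
        return None  # negative max_steps treated as "no step limit"
    return max_steps * global_batch_size


# Hypothetical settings: 100 steps with a global batch of 16.
limited = resolve_num_samples(sample=True, max_steps=100, global_batch_size=16)
unlimited = resolve_num_samples(sample=False, max_steps=100, global_batch_size=16)
```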

Usage

Run this script to perform SteerLM v2 training. It requires a dataset in the SteerLM v2 format, in which each response carries attribute labels, importance weights, and proposal distribution log-probabilities.

Code Reference

Source Location

Signature

def build_sft_dataset(data_cfg, tokenizer, num_samples, answer_only_loss=True, is_chat=True, special_tokens=None):

def _modify_config(gpt_cfg, cfg, add_cfg_to_tree=False):

@hydra_runner(config_path="conf", config_name="gpt_sft")
def main(cfg) -> None:

Import

from nemo_aligner.models.nlp.gpt.gpt_steerlm_model import GPTSteerLMModel
from nemo_aligner.data.nlp.datasets import SteerLM2Dataset
from nemo_aligner.algorithms.supervised import SupervisedTrainer

I/O Contract

Inputs

| Name | Type | Required | Description |
|---|---|---|---|
| cfg.model.restore_from_path | str | Yes | Path to the pretrained NeMo checkpoint |
| cfg.model.data.train_ds | DictConfig | Yes | Training dataset configuration with file_path, max_seq_length, micro_batch_size, global_batch_size |
| cfg.model.data.validation_ds | DictConfig | Yes | Validation dataset configuration |
| cfg.model.data.chat | bool | Yes | Whether the data is in chat format |
| cfg.model.data.chat_prompt_tokens | DictConfig | No | Special tokens for chat formatting |
| cfg.model.steerlm2 | DictConfig | Yes | SteerLM v2 specific configuration including forward_micro_batch_size and micro_batch_size |
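The required inputs listed above can be checked before launching a run. The sketch below walks a nested dictionary and reports any required key that is absent. It is an illustrative helper, not part of the script: the name missing_keys and the use of plain dictionaries in place of a DictConfig are assumptions.

```python
# Required inputs from the table, with the leading "cfg." dropped.
REQUIRED_KEYS = (
    "model.restore_from_path",
    "model.data.train_ds",
    "model.data.validation_ds",
    "model.data.chat",
    "model.steerlm2",
)


def missing_keys(cfg, required=REQUIRED_KEYS):
    """Return the dotted keys from `required` that are absent in `cfg`."""
    missing = []
    for dotted in required:
        node = cfg
        for part in dotted.split("."):
            if not isinstance(node, dict) or part not in node:
                missing.append(dotted)
                break
            node = node[part]
    return missing


# Hypothetical config that omits the validation set and SteerLM v2 section.
example_cfg = {
    "model": {
        "restore_from_path": "/path/to/model.nemo",
        "data": {"train_ds": {"file_path": "/path/to/train.jsonl"}, "chat": True},
    }
}
problems = missing_keys(example_cfg)
```

Reporting all missing keys at once, instead of failing on the first, saves repeated launches when several overrides were forgotten.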

Outputs

| Name | Type | Description |
|---|---|---|
| Trained model checkpoint | File | Saved NeMo checkpoint of the SteerLM v2 trained model |
| Training logs | Logs | Loss and distance metrics logged via the experiment logger |

Usage Examples

# Command-line invocation (run from the repository root):
python examples/nlp/gpt/train_steerlm2.py \
    model.restore_from_path=/path/to/model.nemo \
    model.data.train_ds.file_path=/path/to/steerlm2_train.jsonl \
    model.data.validation_ds.file_path=/path/to/steerlm2_val.jsonl \
    model.data.chat=True \
    model.steerlm2.forward_micro_batch_size=4 \
    model.steerlm2.micro_batch_size=2
