Implementation:NVIDIA NeMo Aligner Train SteerLM2
| Knowledge Sources | |
|---|---|
| Domains | Natural Language Processing, Controllable Generation, Alignment |
| Last Updated | 2026-02-08 00:00 GMT |
Overview
train_steerlm2.py is the entry-point script for SteerLM v2 attribute-conditioned supervised fine-tuning (SFT). It orchestrates dataset construction, model loading, and supervised training with importance-weighted losses.
Description
This script sets up the complete SteerLM v2 training pipeline:
- Configuration -- Uses Hydra with the config path "conf" and config name "gpt_sft". The _modify_config function then injects SteerLM-specific settings into the model config, including steerlm2 micro-batch sizes, PEFT configuration, and chat prompt templates.
- Model loading -- Loads a pretrained GPTSteerLMModel from a NeMo checkpoint, applying configuration modifications and optionally initializing PEFT adapters.
- Dataset construction -- Uses a custom build_sft_dataset function (defined locally, not the one from nemo_aligner.data.nlp.builders) that instantiates SteerLM2Dataset objects. These datasets contain responses annotated with importance sampling weights and proposal distribution log-probabilities.
- Dataloader creation -- Builds train and validation dataloaders using build_dataloader with the dataset's own collate_fn.
- Training execution -- Creates a SupervisedTrainer with the SFT configuration section and calls fit().
The _modify_config function is responsible for transferring fine-tuning configuration parameters (micro-batch size, global-batch size, PEFT settings, dropout, precision, SteerLM v2 settings, chat prompt templates) from the training config to the model config loaded from the checkpoint.
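The transfer described above can be illustrated with a minimal sketch. It is not the script's actual _modify_config: plain dictionaries stand in for the Hydra/OmegaConf objects, the function name transfer_finetune_settings is hypothetical, and key names such as "peft", "precision", and "hidden_dropout" are illustrative assumptions rather than confirmed field names.

```python
from copy import deepcopy


def transfer_finetune_settings(checkpoint_cfg: dict, train_cfg: dict) -> dict:
    """Return a copy of checkpoint_cfg with fine-tuning settings from train_cfg applied.

    Hypothetical stand-in for _modify_config; plain dicts replace the real config objects.
    """
    merged = deepcopy(checkpoint_cfg)  # never mutate the config restored from the checkpoint
    model = train_cfg["model"]

    # Batch sizes are read from the training dataset section (assumed location).
    merged["micro_batch_size"] = model["data"]["train_ds"]["micro_batch_size"]
    merged["global_batch_size"] = model["data"]["train_ds"]["global_batch_size"]

    # Optional sections are copied only when the training config provides them.
    # The key names here are illustrative, not confirmed field names.
    for key in ("steerlm2", "peft", "precision", "hidden_dropout"):
        if key in model:
            merged[key] = deepcopy(model[key])

    # Chat prompt templates are optional, matching the I/O contract.
    chat_tokens = model["data"].get("chat_prompt_tokens")
    if chat_tokens is not None:
        merged.setdefault("data", {})["chat_prompt_tokens"] = deepcopy(chat_tokens)
    return merged
```

Returning a modified copy keeps the checkpoint's own settings (model architecture, tokenizer, and so on) intact while letting the training run override only the fine-tuning parameters.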
The script supports optional data sampling via cfg.model.data.sample. When it is enabled, the number of training samples is capped based on max_steps; otherwise the dataset size is not limited.
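One plausible way to derive the sample cap is sketched below. The formula (max_steps multiplied by the global batch size, with a negative max_steps meaning "no cap") is an assumption based on a common convention in SFT scripts, not something this page confirms, and resolve_num_samples is a hypothetical helper name.

```python
from typing import Optional


def resolve_num_samples(sample: bool, max_steps: int, global_batch_size: int) -> Optional[int]:
    """Return a sample cap for the training set, or None to use the dataset as-is.

    Assumed rule: one global batch is consumed per step, so max_steps steps
    need max_steps * global_batch_size samples.
    """
    if not sample:
        return None  # sampling disabled: no cap on dataset size
    if max_steps < 0:
        return None  # a negative max_steps is treated as "unbounded" (assumption)
    return max_steps * global_batch_size
```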
Usage
Run this script to perform SteerLM v2 training. It requires a dataset in the SteerLM v2 format, in which each response carries attribute labels, importance weights, and proposal distribution log-probabilities.
Code Reference
Source Location
- Repository: NVIDIA_NeMo_Aligner
- File: examples/nlp/gpt/train_steerlm2.py
- Lines: 1-285
Signature
def build_sft_dataset(data_cfg, tokenizer, num_samples, answer_only_loss=True, is_chat=True, special_tokens=None):
def _modify_config(gpt_cfg, cfg, add_cfg_to_tree=False):
@hydra_runner(config_path="conf", config_name="gpt_sft")
def main(cfg) -> None:
Import
from nemo_aligner.models.nlp.gpt.gpt_steerlm_model import GPTSteerLMModel
from nemo_aligner.data.nlp.datasets import SteerLM2Dataset
from nemo_aligner.algorithms.supervised import SupervisedTrainer
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| cfg.model.restore_from_path | str | Yes | Path to the pretrained NeMo checkpoint |
| cfg.model.data.train_ds | DictConfig | Yes | Training dataset configuration with file_path, max_seq_length, micro_batch_size, global_batch_size |
| cfg.model.data.validation_ds | DictConfig | Yes | Validation dataset configuration |
| cfg.model.data.chat | bool | Yes | Whether the data is in chat format |
| cfg.model.data.chat_prompt_tokens | DictConfig | No | Special tokens for chat formatting |
| cfg.model.steerlm2 | DictConfig | Yes | SteerLM v2 specific configuration including forward_micro_batch_size and micro_batch_size |
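The required inputs in the table above can be checked before launching a run. The following is a minimal sketch that uses a plain nested dictionary in place of the Hydra config; missing_required_keys and REQUIRED_KEYS are hypothetical names, and the script itself may not perform any such check.

```python
from typing import Any, List

# Dotted paths marked "Required: Yes" in the inputs table.
REQUIRED_KEYS = [
    "model.restore_from_path",
    "model.data.train_ds",
    "model.data.validation_ds",
    "model.data.chat",
    "model.steerlm2",
]


def missing_required_keys(cfg: dict, required: List[str] = REQUIRED_KEYS) -> List[str]:
    """Return the dotted paths from `required` that are absent or None in cfg."""
    missing = []
    for dotted in required:
        node: Any = cfg
        for part in dotted.split("."):
            # A value of False (e.g. chat=False) is valid; only None or absence counts as missing.
            if not isinstance(node, dict) or part not in node or node[part] is None:
                missing.append(dotted)
                break
            node = node[part]
    return missing
```

Reporting every missing path at once, rather than failing on the first, makes it easier to fix a command line in a single pass.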
Outputs
| Name | Type | Description |
|---|---|---|
| Trained model checkpoint | File | Saved NeMo checkpoint of the SteerLM v2 trained model |
| Training logs | Logs | Loss and distance metrics logged via the experiment logger |
Usage Examples
Command-line invocation:
python examples/nlp/gpt/train_steerlm2.py \
    model.restore_from_path=/path/to/model.nemo \
    model.data.train_ds.file_path=/path/to/steerlm2_train.jsonl \
    model.data.validation_ds.file_path=/path/to/steerlm2_val.jsonl \
    model.data.chat=True \
    model.steerlm2.forward_micro_batch_size=4 \
    model.steerlm2.micro_batch_size=2
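When the script is launched from another Python program, the dotted key=value overrides in the command above can be generated from a nested dictionary. The helper below is a hypothetical convenience and not part of the repository; it handles only simple scalar values and does no quoting or escaping.

```python
from typing import List


def to_overrides(settings: dict, prefix: str = "") -> List[str]:
    """Flatten a nested dict into Hydra-style "a.b.c=value" override strings."""
    overrides: List[str] = []
    for key, value in settings.items():
        dotted = f"{prefix}.{key}" if prefix else key
        if isinstance(value, dict):
            # Recurse into sub-sections, extending the dotted path.
            overrides.extend(to_overrides(value, dotted))
        else:
            overrides.append(f"{dotted}={value}")
    return overrides
```

The resulting list can be appended to the argument list passed to subprocess.run, after the path to train_steerlm2.py.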