Implementation:NVIDIA NeMo Aligner SPINTrainer Fit
| Knowledge Sources | |
|---|---|
| Domains | NLP, Alignment |
| Last Updated | 2026-02-08 00:00 GMT |
Overview
SPINTrainer is the trainer class that coordinates SPIN (Self-Play Fine-Tuning) training, managing the nested loop of iterations and epochs, generation of self-play responses, reference policy log-probability computation, training, validation, and checkpointing.
Description
The SPINTrainer class orchestrates the full SPIN training workflow. Its key responsibilities include:
- Dataloader augmentation: The `augment_dataloader()` method wraps the training dataloader as a generator. For each batch, it generates responses using the reference policy weights (via `cpu_weight_swap()`), constructs paired actual/generated batches with masks, computes reference policy log-probabilities for both the actual and the generated responses, and yields augmented batches containing all the data needed for training.
- Training steps: The `train_single_step()` method performs a single optimization step. It zeroes the gradients, computes the loss and metrics via the model's `get_loss_and_metrics()`, clips the gradients, and steps the optimizer and scheduler.
- Validation: The `run_validation()` method evaluates with the vanilla SFT loss rather than the SPIN loss, which avoids costly generation during validation.
- Nested loop structure: The `fit()` method iterates over iterations (outer loop) and epochs (inner loop). After each complete iteration, the reference policy weights are updated to match the current model weights via `retrieve_model_state_dict_in_cpu()`. The KL penalty can be scheduled per iteration via `model.set_KL_penalty_by_iteration()`.
- State management: The trainer tracks `step` and `consumed_samples` and derives `epoch` and `iteration` from them as computed properties. It supports saving and restoring this state for continuation training.
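The generator pattern behind dataloader augmentation can be sketched in plain Python. This is a toy, not the NeMo Aligner implementation: `ToySPINModel`, `swap_to_reference()`, `generate()`, `logprobs()`, and the `prompts`/`answers` batch keys are hypothetical, strings stand in for weights and token IDs, and masks are omitted. It illustrates two points from the description: generation and reference log-probabilities both use the reference weights, which are swapped in only temporarily, and augmented batches are produced lazily, one per input batch.

```python
from contextlib import contextmanager
from typing import Dict, Iterable, Iterator, List


class ToySPINModel:
    """Hypothetical stand-in for the SPIN model; not the NeMo Aligner API."""

    def __init__(self) -> None:
        self.weights = "current"  # the policy being trained
        self.ref_policy_state_dict = "reference"  # frozen reference policy

    @contextmanager
    def swap_to_reference(self):
        # Plays the role of cpu_weight_swap(): load reference weights, restore on exit.
        saved = self.weights
        self.weights = self.ref_policy_state_dict
        try:
            yield
        finally:
            self.weights = saved

    def generate(self, prompts: List[str]) -> List[str]:
        # Tag each response with the weights that produced it.
        return [f"{p} -> gen[{self.weights}]" for p in prompts]

    def logprobs(self, responses: List[str]) -> List[float]:
        # Placeholder score; a real model would run a forward pass.
        return [-float(len(r)) for r in responses]


def augment_dataloader(model: ToySPINModel, dataloader: Iterable[Dict]) -> Iterator[Dict]:
    """Lazily wrap each batch with self-play responses and reference log-probs."""
    for batch in dataloader:
        actual = batch["answers"]
        with model.swap_to_reference():
            generated = model.generate(batch["prompts"])
            ref_actual = model.logprobs(actual)
            ref_generated = model.logprobs(generated)
        # The reference weights are already swapped out again when the batch is yielded.
        yield {
            "actual": actual,
            "generated": generated,
            "ref_logprobs_actual": ref_actual,
            "ref_logprobs_generated": ref_generated,
        }
```

Because `augment_dataloader()` is a generator, the expensive generation for a batch runs only when the training loop requests that batch.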
The spin_custom_collate function is also defined in the same module. It collates batches from GPTSFTChatDataset into tensors with keys: prompts_and_answers, masks, prompts_only, answers_only, prompt_lengths, and combined_lengths.
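A dependency-free sketch of what such a collate step does follows. It assumes each dataset item carries `input_ids`, `mask`, `context_ids`, and `answer_ids` as lists of token IDs; these input keys are an assumption and are not confirmed here. The hypothetical `toy_spin_collate` right-pads each field with `eos_id` (or `False` for the mask) and records the unpadded lengths. The real function returns tensors and also accepts the position-ID, attention-mask, and EOD flags shown in its signature, which this sketch ignores.

```python
from typing import Dict, List, Sequence


def _pad(seqs: Sequence[Sequence], pad_value) -> List[list]:
    # Right-pad every sequence to the length of the longest one in the batch.
    width = max(len(s) for s in seqs)
    return [list(s) + [pad_value] * (width - len(s)) for s in seqs]


def toy_spin_collate(batch: List[Dict[str, list]], eos_id: int) -> Dict[str, list]:
    """Hypothetical list-based sketch of the collate step (no tensors)."""
    input_ids = [item["input_ids"] for item in batch]  # prompt + answer tokens
    masks = [item["mask"] for item in batch]  # loss mask over input_ids
    context_ids = [item["context_ids"] for item in batch]  # prompt tokens only
    answer_ids = [item["answer_ids"] for item in batch]  # answer tokens only
    return {
        "prompts_and_answers": _pad(input_ids, eos_id),
        "masks": _pad(masks, False),
        "prompts_only": _pad(context_ids, eos_id),
        "answers_only": _pad(answer_ids, eos_id),
        "prompt_lengths": [len(c) for c in context_ids],
        "combined_lengths": [len(x) for x in input_ids],
    }
```

Keeping `prompt_lengths` and `combined_lengths` separate from the padded tensors lets later stages locate where each answer starts and ends despite the padding.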
Usage
Import and instantiate SPINTrainer when setting up a SPIN training run. It is typically created in the train_gpt_spin.py entry point script after the model, optimizer, scheduler, dataloaders, and logger have been initialized.
Code Reference
Source Location
- Repository: NVIDIA_NeMo_Aligner
- File: nemo_aligner/algorithms/spin.py
- Lines: 78-541
Signature
```python
class SPINTrainer:
    """Trainer to coordinate SPIN SFT training
    """

    def __init__(
        self,
        cfg: DictConfig,
        model,
        optimizer,
        scheduler,
        train_dataloader,
        val_dataloader,
        test_dataloader,
        logger,
        ckpt_callback,
        run_timer,
    ):
        ...
```
Key Method: fit()
```python
def fit(self):
    ...
    self.run_timer.start_time()
    iterations_iter = range(self.iteration, self.cfg.max_iterations)
    ...
    # outer loop: one pass per SPIN iteration
    for _ in iterations_iter:
        epoch_iter = range(self.epoch, self.cfg.max_epochs)
        ...
        # the KL penalty may be scheduled per iteration
        self.model.set_KL_penalty_by_iteration(self.iteration)
        # inner loop: epochs within the current iteration
        for _ in epoch_iter:
            ...
            global_pbar = tqdm(
                self.augment_dataloader(self.train_dataloader),
                initial=self.step,
                total=self.max_steps,
                leave=True,
                desc="Training steps",
            )
            for _, global_batch in zip(loop_iter, global_pbar):
                ...
        # update the reference policy weights
        self.model.ref_policy_state_dict = retrieve_model_state_dict_in_cpu(
            self.model, megatron_amp_O2=self.model.cfg.get("megatron_amp_O2", False)
        )
```
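The nested loop and the reference-policy refresh can be modeled with a toy trainer. `ToyNestedLoopTrainer` is hypothetical, integers stand in for weights, and the formulas for `epoch` and `iteration` are an assumption about how they are derived from `step`, chosen to be consistent with their description as computed properties. Each logged tuple records the reference weights seen during that step, which shows that the reference policy changes only at iteration boundaries.

```python
from typing import List, Tuple


class ToyNestedLoopTrainer:
    """Hypothetical model of fit()'s iteration/epoch structure; not the NeMo Aligner API."""

    def __init__(self, steps_per_epoch: int, max_epochs: int, max_iterations: int) -> None:
        self.steps_per_epoch = steps_per_epoch
        self.max_epochs = max_epochs
        self.max_iterations = max_iterations
        self.step = 0
        self.weights = 0  # stands in for the trainable model weights
        self.ref_weights = 0  # stands in for ref_policy_state_dict
        self.log: List[Tuple[int, int, int, int]] = []  # (iteration, epoch, step, ref_weights)

    @property
    def epoch(self) -> int:
        # Assumed derivation: completed epochs, wrapped within one iteration.
        return (self.step // self.steps_per_epoch) % self.max_epochs

    @property
    def iteration(self) -> int:
        # Assumed derivation: completed epochs divided by epochs per iteration.
        return (self.step // self.steps_per_epoch) // self.max_epochs

    def fit(self) -> None:
        for _ in range(self.iteration, self.max_iterations):  # outer loop: iterations
            for _ in range(self.epoch, self.max_epochs):  # inner loop: epochs
                # Run only the steps left in this epoch (handles a mid-epoch resume).
                for _ in range(self.steps_per_epoch - self.step % self.steps_per_epoch):
                    self.log.append((self.iteration, self.epoch, self.step, self.ref_weights))
                    self.weights += 1  # one "optimizer step"
                    self.step += 1
            # After a full iteration, the reference policy catches up to the model.
            self.ref_weights = self.weights
```

With two steps per epoch, two epochs, and two iterations, every step of iteration 0 sees reference weights `0`, and every step of iteration 1 sees `4`, the model state at the end of iteration 0.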
Key Function: spin_custom_collate()
```python
def spin_custom_collate(batch, eos_id, reset_position_ids=False, reset_attention_mask=False, eod_mask_loss=False):
    ...
```
Import
```python
from nemo_aligner.algorithms.spin import SPINTrainer, spin_custom_collate
```
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| cfg | DictConfig | Yes | Trainer configuration (from cfg.trainer.spin); includes max_epochs, max_iterations, max_steps, gradient_clip_val, val_check_interval, save_interval, limit_val_batches, and limit_train_batches |
| model | MegatronGPTSPINModel | Yes | The SPIN model implementing SupervisedInterface, with get_loss_and_metrics(), get_ref_policy_logprobs(), and generation capabilities |
| optimizer | Optimizer | Yes | PyTorch optimizer extracted from the PTL model |
| scheduler | LRScheduler | Yes | Learning rate scheduler extracted from the PTL model |
| train_dataloader | DataLoader | Yes | Training dataloader providing SFT chat data with prompts and ground-truth responses |
| val_dataloader | DataLoader | Yes | Validation dataloader |
| test_dataloader | DataLoader | No | Test dataloader (can be None) |
| logger | CustomLoggerWrapper | Yes | Logger for metrics and table logging |
| ckpt_callback | Callback | Yes | Custom checkpoint callback for saving model states |
| run_timer | Timer | Yes | Timer that enforces maximum training time limits |
Outputs
| Name | Type | Description |
|---|---|---|
| None (side effects) | N/A | The fit() method trains the model in place, logs metrics, saves checkpoints, and updates the reference policy. It has no return value. |
Usage Examples
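```python
from nemo_aligner.algorithms.spin import SPINTrainer

# cfg, ptl_model, optimizer, scheduler, the dataloaders, logger, ckpt_callback,
# and timer are created earlier in the entry-point script (train_gpt_spin.py).
spin_trainer = SPINTrainer(
    cfg=cfg.trainer.spin,
    model=ptl_model,
    optimizer=optimizer,
    scheduler=scheduler,
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    test_dataloader=None,
    logger=logger,
    ckpt_callback=ckpt_callback,
    run_timer=timer,
)

# Restore the saved trainer state when continuing a previous run.
if custom_trainer_state_dict is not None:
    spin_trainer.load_state_dict(custom_trainer_state_dict)

spin_trainer.fit()
```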
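Continuation training only needs the trainer's counters to survive a restart. The following is a minimal sketch under stated assumptions: a hypothetical `ToyTrainerState` stores just `step` and `consumed_samples` (leaving `epoch` and `iteration` to be derived), and a JSON round trip stands in for the checkpoint callback. The real trainer's state dict may carry more fields than this.

```python
import json
from dataclasses import asdict, dataclass


@dataclass
class ToyTrainerState:
    """Hypothetical minimal trainer state; epoch and iteration are derived, not stored."""

    step: int = 0
    consumed_samples: int = 0

    def advance(self, global_batch_size: int) -> None:
        # One optimizer step consumes one global batch of samples.
        self.step += 1
        self.consumed_samples += global_batch_size


def save_state(state: ToyTrainerState) -> str:
    # Serialize to a plain dict, the shape a load_state_dict()-style method would receive.
    return json.dumps(asdict(state))


def load_state(blob: str) -> ToyTrainerState:
    # Rebuild the counters so training resumes from the saved step.
    return ToyTrainerState(**json.loads(blob))


# Usage: train 5 steps, "checkpoint", then resume in a fresh object.
state = ToyTrainerState()
for _ in range(5):
    state.advance(global_batch_size=8)
resumed = load_state(save_state(state))
print(resumed)  # ToyTrainerState(step=5, consumed_samples=40)
```

Storing only the raw counters and deriving everything else avoids saving values that could disagree with each other after a restore.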