Implementation:NVIDIA NeMo Aligner SPINTrainer Fit

From Leeroopedia


Knowledge Sources
Domains: NLP, Alignment
Last Updated: 2026-02-08 00:00 GMT

Overview

SPINTrainer is the trainer class that coordinates SPIN (Self-Play Fine-Tuning) training, managing the nested loop of iterations and epochs, generation of self-play responses, reference policy log-probability computation, training, validation, and checkpointing.

Description

The SPINTrainer class orchestrates the full SPIN training workflow. Its key responsibilities include:

  • Dataloader augmentation: The augment_dataloader() method wraps the training dataloader as a generator. For each batch, it (1) generates responses with the reference policy weights (via cpu_weight_swap()), (2) constructs paired actual/generated batches with masks, (3) computes reference policy log-probabilities for both the actual and the generated responses, and (4) yields an augmented batch containing all the data needed for training.
  • Training steps: The train_single_step() method performs a single optimization step: zero gradients, compute loss and metrics via the model's get_loss_and_metrics(), clip gradients, and step the optimizer and scheduler.
  • Validation: The run_validation() method evaluates the model with the vanilla SFT loss rather than the SPIN loss. This keeps validation efficient because it avoids costly response generation.
  • Nested loop structure: The fit() method iterates over iterations (outer loop) and epochs (inner loop). After each complete iteration, the reference policy weights are updated to match the current model weights via retrieve_model_state_dict_in_cpu(). The KL penalty can be scheduled per iteration via model.set_KL_penalty_by_iteration().
  • State management: The trainer tracks step and consumed_samples, and exposes epoch and iteration as computed properties. This state can be saved and restored so that an interrupted run can resume training.
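The dataloader augmentation described above can be illustrated with a minimal, pure-Python sketch. It is not the NeMo-Aligner implementation: ToyModel, swap_weights, and augment_batches are hypothetical names, and swap_weights only stands in for the role that cpu_weight_swap() plays in the real trainer. The point is the pattern: a generator that temporarily switches to the reference weights, generates a response, restores the current weights, and yields the enriched batch.

```python
from contextlib import contextmanager


class ToyModel:
    """Hypothetical stand-in for the SPIN model (not the NeMo-Aligner class)."""

    def __init__(self):
        self.weights = "current"
        self.ref_weights = "reference"

    def generate(self, prompt):
        # Tag the response with whichever weights are active right now.
        return f"{prompt} -> reply from {self.weights}"


@contextmanager
def swap_weights(model, new_weights):
    """Temporarily run the model with new_weights, then restore the originals."""
    saved = model.weights
    model.weights = new_weights
    try:
        yield
    finally:
        model.weights = saved


def augment_batches(model, dataloader):
    """Yield each batch plus a response generated with the reference weights."""
    for batch in dataloader:
        with swap_weights(model, model.ref_weights):
            generated = [model.generate(p) for p in batch["prompts"]]
        # The training weights are active again here.
        yield {**batch, "generated": generated}


model = ToyModel()
batches = [{"prompts": ["hi"], "actual": ["hello"]}]
augmented = list(augment_batches(model, batches))
```

Because the wrapper is a generator, generation happens lazily, one batch at a time, as the training loop consumes it.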

The spin_custom_collate function is also defined in the same module. It collates batches from GPTSFTChatDataset into tensors with keys: prompts_and_answers, masks, prompts_only, answers_only, prompt_lengths, and combined_lengths.
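As a rough illustration of what a collate function with these output keys does, the sketch below pads variable-length token lists and records their lengths. It is an assumption-laden toy, not the library function: toy_collate, the input field names "prompt" and "answer", the use of eos_id as the padding value, and the use of plain lists instead of tensors are all choices made for this example.

```python
def pad(seqs, pad_id):
    """Right-pad every sequence to the length of the longest one."""
    width = max(len(s) for s in seqs)
    return [s + [pad_id] * (width - len(s)) for s in seqs]


def toy_collate(batch, eos_id):
    # "prompt" and "answer" are hypothetical field names for this sketch.
    prompts = [ex["prompt"] for ex in batch]
    answers = [ex["answer"] for ex in batch]
    combined = [p + a for p, a in zip(prompts, answers)]
    # Assumed mask convention: 1 on answer tokens, 0 on prompt tokens and padding.
    masks = [[0] * len(p) + [1] * len(a) for p, a in zip(prompts, answers)]
    return {
        "prompts_and_answers": pad(combined, eos_id),
        "masks": pad(masks, 0),
        "prompts_only": pad(prompts, eos_id),
        "answers_only": pad(answers, eos_id),
        "prompt_lengths": [len(p) for p in prompts],
        "combined_lengths": [len(c) for c in combined],
    }


example = [
    {"prompt": [1, 2], "answer": [3]},
    {"prompt": [4], "answer": [5, 6, 7]},
]
out = toy_collate(example, eos_id=0)
```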

Usage

Import and instantiate SPINTrainer when setting up a SPIN training run. It is typically created in the train_gpt_spin.py entry point script after the model, optimizer, scheduler, dataloaders, and logger have been initialized.

Code Reference

Source Location

Signature

class SPINTrainer:
    """Trainer to coordinate SPIN SFT training
    """

    def __init__(
        self,
        cfg: DictConfig,
        model,
        optimizer,
        scheduler,
        train_dataloader,
        val_dataloader,
        test_dataloader,
        logger,
        ckpt_callback,
        run_timer,
    ):

Key Method: fit()

def fit(self):
    ...
    self.run_timer.start_time()
    iterations_iter = range(self.iteration, self.cfg.max_iterations)
    ...
    for _ in iterations_iter:
        epoch_iter = range(self.epoch, self.cfg.max_epochs)
        ...
        self.model.set_KL_penalty_by_iteration(self.iteration)
        for _ in epoch_iter:
            ...
            global_pbar = tqdm(
                self.augment_dataloader(self.train_dataloader),
                initial=self.step,
                total=self.max_steps,
                leave=True,
                desc="Training steps",
            )
            for _, global_batch in zip(loop_iter, global_pbar):
                ...
        # update the reference policy weights
        self.model.ref_policy_state_dict = retrieve_model_state_dict_in_cpu(
            self.model, megatron_amp_O2=self.model.cfg.get("megatron_amp_O2", False)
        )
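
The control flow of the excerpt above can be mimicked with a small, self-contained toy. This is only a sketch under simplifying assumptions: toy_fit is a hypothetical function, the "weights" are a one-entry dict, and incrementing a number stands in for an optimizer step. It shows that the reference snapshot stays frozen for every epoch of an iteration and is refreshed only once the iteration is complete.

```python
import copy


def toy_fit(weights, max_iterations, max_epochs, steps_per_epoch):
    """Nested iteration/epoch loop with a reference snapshot per iteration."""
    ref = copy.deepcopy(weights)  # initial reference policy
    history = []
    for iteration in range(max_iterations):
        for epoch in range(max_epochs):
            for _ in range(steps_per_epoch):
                # Record which reference snapshot this step trains against.
                history.append((iteration, epoch, ref["w"]))
                weights["w"] += 1  # stand-in for one optimizer step
        # Update the reference policy only after the full iteration.
        ref = copy.deepcopy(weights)
    return weights, ref, history


final, ref, history = toy_fit({"w": 0}, max_iterations=2, max_epochs=2, steps_per_epoch=3)
```

In this toy run, all six steps of the first iteration see the initial reference value, and all six steps of the second iteration see the value reached at the end of the first.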

Key Function: spin_custom_collate()

def spin_custom_collate(batch, eos_id, reset_position_ids=False, reset_attention_mask=False, eod_mask_loss=False):

Import

from nemo_aligner.algorithms.spin import SPINTrainer, spin_custom_collate

I/O Contract

Inputs

| Name | Type | Required | Description |
|------|------|----------|-------------|
| cfg | DictConfig | Yes | Trainer configuration (from cfg.trainer.spin); includes max_epochs, max_iterations, max_steps, gradient_clip_val, val_check_interval, save_interval, limit_val_batches, limit_train_batches |
| model | MegatronGPTSPINModel | Yes | The SPIN model implementing SupervisedInterface, with get_loss_and_metrics(), get_ref_policy_logprobs(), and generation capabilities |
| optimizer | Optimizer | Yes | PyTorch optimizer extracted from the PTL model |
| scheduler | LRScheduler | Yes | Learning rate scheduler extracted from the PTL model |
| train_dataloader | DataLoader | Yes | Training dataloader providing SFT chat data with prompts and ground-truth responses |
| val_dataloader | DataLoader | Yes | Validation dataloader |
| test_dataloader | DataLoader | No | Test dataloader (can be None) |
| logger | CustomLoggerWrapper | Yes | Logger for metrics and table logging |
| ckpt_callback | Callback | Yes | Custom checkpoint callback for saving model states |
| run_timer | Timer | Yes | Timer that enforces maximum training time limits |

Outputs

| Name | Type | Description |
|------|------|-------------|
| None (side effects) | N/A | The fit() method trains the model in-place, logs metrics, saves checkpoints, and updates the reference policy. It has no return value. |

Usage Examples

from nemo_aligner.algorithms.spin import SPINTrainer

spin_trainer = SPINTrainer(
    cfg=cfg.trainer.spin,
    model=ptl_model,
    optimizer=optimizer,
    scheduler=scheduler,
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,
    test_dataloader=None,
    logger=logger,
    ckpt_callback=ckpt_callback,
    run_timer=timer,
)

if custom_trainer_state_dict is not None:
    spin_trainer.load_state_dict(custom_trainer_state_dict)

spin_trainer.fit()
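
The load_state_dict() call in the example restores trainer progress before fit() resumes. The sketch below shows one way such state handling could look. It is an assumption, not the library's code: ToyTrainerState, the state_dict() layout, and the formulas that derive epoch and iteration from step are illustrative guesses consistent with the description of step and consumed_samples as the tracked state.

```python
class ToyTrainerState:
    """Hypothetical minimal trainer state with derived epoch and iteration."""

    def __init__(self, steps_per_epoch, max_epochs):
        self.steps_per_epoch = steps_per_epoch
        self.max_epochs = max_epochs
        self.step = 0
        self.consumed_samples = 0

    @property
    def epoch(self):
        # Assumed derivation: epochs completed so far, wrapped within an iteration.
        return (self.step // self.steps_per_epoch) % self.max_epochs

    @property
    def iteration(self):
        # Assumed derivation: one iteration spans max_epochs full epochs.
        return (self.step // self.steps_per_epoch) // self.max_epochs

    def state_dict(self):
        return {"step": self.step, "consumed_samples": self.consumed_samples}

    def load_state_dict(self, state):
        self.step = state["step"]
        self.consumed_samples = state["consumed_samples"]


# Simulate a checkpoint taken after 7 steps with 8 samples per step.
saved = {"step": 7, "consumed_samples": 56}
restored = ToyTrainerState(steps_per_epoch=3, max_epochs=2)
restored.load_state_dict(saved)
```

Storing only the counters and recomputing epoch and iteration on demand keeps the checkpoint small and avoids the derived values drifting out of sync with step.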
