Implementation:CarperAI Trlx NeMo SFT Trainer

From Leeroopedia


Knowledge Sources

Domains: Supervised_Learning, NLP, Megatron
Last Updated: 2026-02-07 16:00 GMT

Overview

Concrete tool for orchestrating supervised fine-tuning (SFT) using the NeMo Megatron framework with distributed data loading and loss masking.

Description

The NeMoSFTTrainer extends BaseRLTrainer to run supervised fine-tuning on NeMo's Megatron-GPT backend. It handles tokenization of training samples, constructs loss masks that exclude padding tokens and optionally end-of-document tokens, pads batches to uniform length, and delegates training to PyTorch Lightning via the SFTGPT model. It uses ShuffledCyclicSequence from the NeMo ILQL trainer for virtual dataset cycling, and the megatron_trainer helper to configure the Lightning Trainer with NeMo precision plugins and the DDP strategy.
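The padding and loss-masking steps described above can be sketched as follows. This is a minimal, self-contained illustration rather than the trainer's actual code; the PAD_ID and EOD_ID values are placeholder assumptions, not NeMo's real token IDs.

```python
# Sketch: pad each tokenized sample to a uniform length and build a
# loss mask that excludes padding tokens and, optionally,
# end-of-document tokens. Token IDs are illustrative.
from typing import List, Tuple

PAD_ID = 0  # assumed padding token id
EOD_ID = 1  # assumed end-of-document token id

def pad_and_mask(
    batch: List[List[int]],
    mask_eod: bool = False,
) -> Tuple[List[List[int]], List[List[float]]]:
    max_len = max(len(ids) for ids in batch)
    padded, masks = [], []
    for ids in batch:
        # Pad to the longest sequence in the batch.
        row = ids + [PAD_ID] * (max_len - len(ids))
        # Zero out loss on padding (and optionally EOD) positions.
        mask = [
            0.0 if tok == PAD_ID or (mask_eod and tok == EOD_ID) else 1.0
            for tok in row
        ]
        padded.append(row)
        masks.append(mask)
    return padded, masks
```

In the real trainer this happens per-batch before handing tensors to Lightning; here plain lists keep the sketch dependency-free.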

Usage

Use this trainer when performing supervised fine-tuning on large-scale models (1B+ parameters) with NeMo's Megatron distributed backend. It is registered as "NeMoSFTTrainer" and selected automatically when a NeMo config is used with the SFT method.
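Trainer selection happens through the trlx config. A hypothetical excerpt of such a YAML config might look like the following; the field names follow trlx's TrainConfig, but the file names and paths are illustrative assumptions:

```yaml
train:
  trainer: "NeMoSFTTrainer"   # routes trlx.train() to this trainer
  trainer_kwargs:
    megatron_cfg: "megatron_20b.yaml"        # NeMo Megatron YAML config
    pretrained_model: "/path/to/checkpoint"  # pretrained NeMo weights
```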

Code Reference

Source Location

trlx/trainer/nemo_sft_trainer.py

Signature

@register_trainer
class NeMoSFTTrainer(BaseRLTrainer):
    def __init__(
        self,
        config: TRLConfig,
        metric_fn: Optional[Callable] = None,
        megatron_cfg: Optional[str] = None,
        pretrained_model: Optional[str] = None,
        **kwargs,
    ):
        """
        Args:
            config: TRLConfig with SFT method config.
            metric_fn: Optional evaluation metric function.
            megatron_cfg: Path to NeMo Megatron YAML config.
            pretrained_model: Path to pretrained model weights.
        """

    def learn(self) -> None:
        """
        Tokenize samples, build loss masks, configure datasets,
        and run training via PyTorch Lightning trainer.fit().
        """

    def make_experience(self, samples, seq_length: int):
        """
        Tokenize raw text samples into BatchEncoding objects.

        Args:
            samples: List of text strings to tokenize.
            seq_length: Maximum sequence length for truncation.
        """

Import

from trlx.trainer.nemo_sft_trainer import NeMoSFTTrainer

I/O Contract

Inputs

Name Type Required Description
config TRLConfig Yes Full trlx configuration with SFT method config
metric_fn Callable No Evaluation metric function
megatron_cfg str No Path to NeMo Megatron YAML config
pretrained_model str No Path to pretrained model checkpoint
samples List[str] Yes Training text samples for tokenization

Outputs

Name Type Description
learn None Trains the SFT model in-place via Lightning
Trained model SFTGPT Fine-tuned NeMo model

Usage Examples

SFT with NeMo

import trlx
from trlx.data.default_configs import TRLConfig

# 1. Prepare training samples
samples = [
    "Human: What is AI?\nAssistant: AI stands for Artificial Intelligence.",
    "Human: Hello\nAssistant: Hi there!",
]

# 2. Configure SFT with NeMo
config = TRLConfig.load_yaml("configs/nemo_sft_config.yml")

# 3. Train
trainer = trlx.train(
    samples=samples,
    config=config,
)
