Implementation: CarperAI Trlx NeMo SFT Trainer
| Knowledge Sources | |
|---|---|
| Domains | Supervised_Learning, NLP, Megatron |
| Last Updated | 2026-02-07 16:00 GMT |
Overview
Concrete tool for orchestrating supervised fine-tuning (SFT) using the NeMo Megatron framework with distributed data loading and loss masking.
Description
The NeMoSFTTrainer extends BaseRLTrainer to run supervised fine-tuning on NeMo's Megatron-GPT backend. It tokenizes training samples, constructs loss masks that exclude padding tokens (and, optionally, end-of-document tokens), pads batches to a uniform length, and delegates training to PyTorch Lightning via the SFTGPT model. It reuses ShuffledCyclicSequence from the NeMo ILQL trainer for virtual dataset cycling, and the megatron_trainer helper to configure the Lightning Trainer with NeMo precision plugins and the DDP strategy.
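The loss-masking and batch-padding steps described above can be sketched in pure Python. This is an illustrative sketch only; the function names and token IDs here are assumptions, not the trainer's actual helpers, which operate on tensors inside the Megatron data pipeline.

```python
def build_loss_mask(token_ids, pad_id, eod_id=None, mask_eod=False):
    """1.0 where the token contributes to the loss, 0.0 for padding
    (and, optionally, end-of-document tokens)."""
    return [
        0.0 if t == pad_id or (mask_eod and t == eod_id) else 1.0
        for t in token_ids
    ]

def pad_to_uniform(batch, pad_id):
    """Right-pad every sequence to the longest length in the batch."""
    max_len = max(len(seq) for seq in batch)
    return [seq + [pad_id] * (max_len - len(seq)) for seq in batch]

batch = pad_to_uniform([[5, 7, 2], [3, 4]], pad_id=0)  # [[5, 7, 2], [3, 4, 0]]
masks = [build_loss_mask(seq, pad_id=0, eod_id=2, mask_eod=True) for seq in batch]
# masks == [[1.0, 1.0, 0.0], [1.0, 1.0, 0.0]]
```

Masking EOD tokens is optional in the real trainer; when enabled, the model is not penalized for its predictions at document boundaries.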
Usage
Use this trainer for supervised fine-tuning of large-scale models (1B+ parameters) on NeMo's Megatron distributed backend. It is registered as "NeMoSFTTrainer" and is selected automatically when a NeMo config is used with the SFT method.
Code Reference
Source Location
- Repository: CarperAI_Trlx
- File: trlx/trainer/nemo_sft_trainer.py
- Lines: 1-140
Signature
@register_trainer
class NeMoSFTTrainer(BaseRLTrainer):
    def __init__(
        self,
        config: TRLConfig,
        metric_fn: Optional[Callable] = None,
        megatron_cfg: Optional[str] = None,
        pretrained_model: Optional[str] = None,
        **kwargs,
    ):
        """
        Args:
            config: TRLConfig with SFT method config.
            metric_fn: Optional evaluation metric function.
            megatron_cfg: Path to NeMo Megatron YAML config.
            pretrained_model: Path to pretrained model weights.
        """

    def learn(self) -> None:
        """
        Tokenize samples, build loss masks, configure datasets,
        and run training via PyTorch Lightning trainer.fit().
        """

    def make_experience(self, samples, seq_length: int):
        """
        Tokenize raw text samples into BatchEncoding objects.

        Args:
            samples: List of text strings to tokenize.
            seq_length: Maximum sequence length for truncation.
        """
Import
from trlx.trainer.nemo_sft_trainer import NeMoSFTTrainer
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| config | TRLConfig | Yes | Full trlx configuration with SFT method config |
| metric_fn | Callable | No | Evaluation metric function |
| megatron_cfg | str | No | Path to NeMo Megatron YAML config |
| pretrained_model | str | No | Path to pretrained model checkpoint |
| samples | List[str] | Yes | Training text samples for tokenization |
Outputs
| Name | Type | Description |
|---|---|---|
| learn | None | Trains the SFT model in-place via Lightning |
| Trained model | SFTGPT | Fine-tuned NeMo model |
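The virtual-dataset cycling attributed above to ShuffledCyclicSequence can be illustrated with a minimal stand-in: a wrapper that presents a dataset of length n as one of a larger virtual length by cycling through a fixed shuffled order. The class below is an approximation for illustration, not NeMo's implementation.

```python
import random

class CyclicSequence:
    """Illustrative stand-in for ShuffledCyclicSequence: expose a dataset
    of length n as one of virtual length new_len, cycling through a fixed
    shuffled order so repeated passes are deterministic."""

    def __init__(self, data, new_len, seed=0):
        self.data, self.new_len = data, new_len
        self.order = list(range(len(data)))
        random.Random(seed).shuffle(self.order)

    def __len__(self):
        return self.new_len

    def __getitem__(self, i):
        return self.data[self.order[i % len(self.data)]]

ds = CyclicSequence(["a", "b", "c"], new_len=7)
print(len(ds), [ds[i] for i in range(7)])
```

This lets the Lightning trainer request an arbitrary number of steps per epoch without the underlying dataset size being a constraint.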
Usage Examples
SFT with NeMo
import trlx
from trlx.data.configs import TRLConfig
# 1. Prepare training samples
samples = [
"Human: What is AI?\nAssistant: AI stands for Artificial Intelligence.",
"Human: Hello\nAssistant: Hi there!",
]
# 2. Configure SFT with NeMo
config = TRLConfig.load_yaml("configs/nemo_sft_config.yml")
# 3. Train
trainer = trlx.train(
samples=samples,
config=config,
)