Implementation: CarperAI Trlx NeMo SFT Model
| Knowledge Sources | |
|---|---|
| Domains | Supervised_Learning, NLP, Megatron |
| Last Updated | 2026-02-07 16:00 GMT |
Overview
Concrete tool for supervised fine-tuning (SFT) of language models on NVIDIA NeMo's Megatron-GPT framework with pipeline-parallel and tensor-parallel support.
Description
The SFTGPT class extends NeMo's MegatronGPTModel to add custom data loading with distributed batch sampling, causal language model loss computation with loss masking (to ignore padding tokens), checkpoint loading with pipeline-parallel resharding, and configurable text generation. It uses vocab_parallel_cross_entropy for efficient distributed loss computation across tensor-parallel ranks.
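The loss-masking idea can be sketched in plain Python. This is a simplified stand-in for the masked loss described above, not trlx's actual tensor-parallel implementation; the function name and list-based inputs are illustrative only:

```python
def masked_lm_loss(per_token_losses, loss_mask):
    """Average cross-entropy over non-padding tokens only.

    per_token_losses: one loss value per token position.
    loss_mask: 1.0 for real tokens, 0.0 for padding (ignored).
    The real implementation operates on torch tensors and reduces
    across tensor-parallel ranks via vocab_parallel_cross_entropy.
    """
    total = sum(l * m for l, m in zip(per_token_losses, loss_mask))
    n_real = sum(loss_mask)
    return total / max(n_real, 1.0)  # guard against all-padding batches

# Padding positions (mask 0.0) do not influence the loss:
loss = masked_lm_loss([2.0, 1.0, 3.0, 9.9], [1.0, 1.0, 1.0, 0.0])
```

Note that the last (padded) position's loss of 9.9 is excluded, so only the three real tokens are averaged.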
Usage
Use this model class when performing supervised fine-tuning on large-scale models (1B+ parameters) that require NeMo's Megatron distributed training. For smaller models using HuggingFace Accelerate, use the standard SFT trainer (accelerate_sft_trainer.py) instead.
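The guidance above can be expressed as a rule of thumb. This helper is hypothetical (trlx does not ship such a function), and the 1B threshold is a heuristic, not a hard limit:

```python
def pick_sft_backend(n_params: int) -> str:
    """Illustrative backend choice per the usage guidance:
    NeMo Megatron for large models, HuggingFace Accelerate otherwise."""
    return "nemo" if n_params >= 1_000_000_000 else "accelerate"

pick_sft_backend(1_300_000_000)  # a 1.3B model -> "nemo"
pick_sft_backend(125_000_000)    # a 125M model -> "accelerate"
```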
Code Reference
Source Location
- Repository: CarperAI_Trlx
- File: trlx/models/modeling_nemo_sft.py
- Lines: 1-523
Signature
class SFTGPT(MegatronGPTModel):
    def __init__(
        self,
        sft_config: SFTConfig,
        metric_fn: Optional[Callable] = None,
        **kwargs,
    ):
        """
        Args:
            sft_config: SFT configuration with gen_kwargs.
            metric_fn: Optional evaluation metric function.
            **kwargs: Passed to MegatronGPTModel.
        """

    def load_from_pretrained(self, checkpoint_dir: str):
        """Load pretrained weights with pipeline-parallel resharding."""

    def training_step(self, batch, batch_idx) -> torch.Tensor:
        """Compute causal LM loss with loss masking."""

    def generate(
        self,
        inputs: dict,
        length_params: LengthParam,
        sampling_params: Optional[SamplingParam] = None,
    ) -> list:
        """Generate text with configurable sampling parameters."""
Import
from trlx.models.modeling_nemo_sft import SFTGPT
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| sft_config | SFTConfig | Yes | SFT configuration with generation kwargs |
| metric_fn | Callable | No | Evaluation metric function |
| batch | dict | Yes | Training batch with tokens, loss_mask, attention_mask, position_ids |
| checkpoint_dir | str | No | Path to pretrained checkpoint |
Outputs
| Name | Type | Description |
|---|---|---|
| training_step returns | torch.Tensor | Cross-entropy loss (masked, averaged across DP group) |
| generate returns | list | Generated text sequences |
| validation_step returns | dict | Validation loss and generated samples |
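A training batch with the fields listed above can be assembled as follows. This helper is hypothetical and returns plain lists for illustration; the real trainer produces torch tensors via a distributed batch sampler, and the attention mask shown here is a simplification of Megatron's causal mask:

```python
def make_sft_batch(token_ids, max_len, pad_id=0):
    """Pad one sequence and build the batch fields training_step consumes:
    tokens, loss_mask, attention_mask, position_ids."""
    n = len(token_ids)
    pad = max_len - n
    tokens = token_ids + [pad_id] * pad
    loss_mask = [1.0] * n + [0.0] * pad   # padding excluded from the loss
    attention_mask = [1] * n + [0] * pad  # simplified; Megatron builds a causal mask
    position_ids = list(range(max_len))
    return {
        "tokens": tokens,
        "loss_mask": loss_mask,
        "attention_mask": attention_mask,
        "position_ids": position_ids,
    }

batch = make_sft_batch([101, 7592, 2088], max_len=6)
```

The loss_mask zeros align with the pad positions in tokens, which is what lets training_step ignore padding when averaging the cross-entropy.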
Usage Examples
Use SFTGPT with NeMo Megatron
from omegaconf import OmegaConf
from trlx.models.modeling_nemo_sft import SFTGPT
from trlx.trainer.accelerate_sft_trainer import SFTConfig
# 1. Load NeMo config
megatron_cfg = OmegaConf.load("configs/nemo_configs/megatron_1.3b.yaml")
# 2. Create SFT config
sft_config = SFTConfig(
    gen_kwargs={"temperature": 0.7, "max_new_tokens": 128},
)
# 3. Instantiate model (typically done by NeMoSFTTrainer)
model = SFTGPT(sft_config=sft_config, cfg=megatron_cfg.model)
# 4. Load pretrained weights
model.load_from_pretrained("/path/to/checkpoint")
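Generation can then be driven with NeMo-style parameter dicts. The key names below follow NeMo's LengthParam/SamplingParam conventions as I understand them, and the values are examples; verify both against your installed NeMo version:

```python
# Length and sampling controls for SFTGPT.generate (example values).
length_params = {"max_length": 128, "min_length": 0}
sampling_params = {
    "use_greedy": False,      # sample instead of greedy decoding
    "temperature": 0.7,       # matches the gen_kwargs above
    "top_k": 0,
    "top_p": 0.9,
    "repetition_penalty": 1.0,
}

# Hypothetical call; requires an initialized model and a tokenized prompt batch:
# outputs = model.generate(inputs, length_params, sampling_params=sampling_params)
```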