
Implementation:CarperAI Trlx NeMo SFT Model

From Leeroopedia


Knowledge Sources
Domains Supervised_Learning, NLP, Megatron
Last Updated 2026-02-07 16:00 GMT

Overview

A concrete tool for supervised fine-tuning (SFT) of language models on NVIDIA NeMo's Megatron-GPT framework, with pipeline-parallel and tensor-parallel support.

Description

The SFTGPT class extends NeMo's MegatronGPTModel with custom data loading via distributed batch sampling, causal language-model loss computation with loss masking (to ignore padding tokens), checkpoint loading with pipeline-parallel resharding, and configurable text generation. It uses vocab_parallel_cross_entropy for efficient distributed loss computation across tensor-parallel ranks.
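To make the loss-masking idea concrete, here is a minimal pure-Python sketch: per-token negative log-likelihood is zeroed at padding positions and averaged over the remaining tokens. This is an illustration only, not the actual vocab_parallel_cross_entropy implementation (which runs distributed across tensor-parallel ranks); the function name is hypothetical.

```python
import math

def masked_lm_loss(token_log_probs, loss_mask):
    """Average negative log-likelihood over non-padding tokens.

    token_log_probs: log-probability assigned to each target token
    loss_mask: 1.0 for real tokens, 0.0 for padding
    """
    total = sum(-lp * m for lp, m in zip(token_log_probs, loss_mask))
    count = sum(loss_mask)
    return total / count

# The third position is padding (mask 0.0), so it contributes nothing:
loss = masked_lm_loss(
    [math.log(0.5), math.log(0.25), math.log(0.9)],
    [1.0, 1.0, 0.0],
)
```

Without the mask, confident predictions on padding tokens would dilute the loss; with it, only real tokens drive the gradient.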

Usage

Use this model class when performing supervised fine-tuning on large-scale models (1B+ parameters) that require NeMo's Megatron distributed training. For smaller models, use the standard HuggingFace Accelerate-based SFT trainer (accelerate_sft_trainer.py) instead.

Code Reference

Source Location

trlx/models/modeling_nemo_sft.py

Signature

class SFTGPT(MegatronGPTModel):
    def __init__(
        self,
        sft_config: SFTConfig,
        metric_fn: Optional[Callable] = None,
        **kwargs,
    ):
        """
        Args:
            sft_config: SFT configuration with gen_kwargs.
            metric_fn: Optional evaluation metric function.
            **kwargs: Passed to MegatronGPTModel.
        """

    def load_from_pretrained(self, checkpoint_dir: str):
        """Load pretrained weights with pipeline-parallel resharding."""

    def training_step(self, batch, batch_idx) -> torch.Tensor:
        """Compute causal LM loss with loss masking."""

    def generate(
        self,
        inputs: dict,
        length_params: LengthParam,
        sampling_params: Optional[SamplingParam] = None,
    ) -> list:
        """Generate text with configurable sampling parameters."""

Import

from trlx.models.modeling_nemo_sft import SFTGPT

I/O Contract

Inputs

Name Type Required Description
sft_config SFTConfig Yes SFT configuration with generation kwargs
metric_fn Callable No Evaluation metric function
batch dict Yes Training batch with tokens, loss_mask, attention_mask, position_ids
checkpoint_dir str No Path to pretrained checkpoint
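The batch fields above can be pictured with a toy example. This sketch only shows the field names and their relationship (the loss mask and attention mask zero out padding positions); the token IDs, padding convention, and batch shape are illustrative assumptions, not what NeMo's data loader actually emits.

```python
# Toy training batch matching the field names in the I/O contract.
# One sequence of length 4, where the last position is padding.
seq_len = 4
batch = {
    "tokens": [[101, 2054, 2003, 0]],        # token IDs, padded with 0
    "loss_mask": [[1.0, 1.0, 1.0, 0.0]],     # 0.0 at padding positions
    "attention_mask": [[1, 1, 1, 0]],        # attend only to real tokens
    "position_ids": [list(range(seq_len))],  # 0..seq_len-1 per sequence
}
```

In practice these fields are tensors produced by the distributed batch sampler; the loss mask and attention mask agree on which positions are padding.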

Outputs

Name Type Description
training_step returns torch.Tensor Cross-entropy loss (masked, averaged across DP group)
generate returns list Generated text sequences
validation_step returns dict Validation loss and generated samples

Usage Examples

Use SFTGPT with NeMo Megatron

from omegaconf import OmegaConf
from trlx.models.modeling_nemo_sft import SFTGPT
from trlx.trainer.accelerate_sft_trainer import SFTConfig

# 1. Load NeMo config
megatron_cfg = OmegaConf.load("configs/nemo_configs/megatron_1.3b.yaml")

# 2. Create SFT config
sft_config = SFTConfig(
    gen_kwargs={"temperature": 0.7, "max_new_tokens": 128},
)

# 3. Instantiate model (typically done by NeMoSFTTrainer)
model = SFTGPT(sft_config=sft_config, cfg=megatron_cfg.model)

# 4. Load pretrained weights
model.load_from_pretrained("/path/to/checkpoint")
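After loading weights, generation takes LengthParam and SamplingParam arguments. In NeMo these are plain dicts; the keys below follow common NeMo text-generation conventions but should be checked against your NeMo version, and the prompt format in the commented call is an assumption, so it is left commented rather than presented as runnable.

```python
# Sketch of generation parameters for SFTGPT.generate (keys assumed
# from NeMo's text-generation conventions; verify against your version).
length_params = {"max_length": 128, "min_length": 1}
sampling_params = {
    "use_greedy": False,       # sample instead of greedy decoding
    "temperature": 0.7,        # matches gen_kwargs in the SFT config above
    "top_k": 0,                # 0 disables top-k filtering
    "top_p": 0.9,              # nucleus sampling threshold
    "repetition_penalty": 1.0,
}

# outputs = model.generate(
#     inputs={"prompts": ["Question: What is SFT?\nAnswer:"]},
#     length_params=length_params,
#     sampling_params=sampling_params,
# )
```

Per the signature above, sampling_params is optional; when omitted, the model falls back to its configured defaults.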
