Implementation:CarperAI Trlx NeMo ILQL Trainer


Knowledge Sources
Domains: Reinforcement_Learning, NLP, Megatron
Last Updated: 2026-02-07 16:00 GMT

Overview

Concrete tool for orchestrating ILQL (Implicit Language Q-Learning) training using the NeMo Megatron framework with distributed data loading and precision management.

Description

NeMoILQLTrainer extends BaseRLTrainer to run ILQL training on NeMo's Megatron-GPT backend. It sets up the Megatron trainer with precision plugins (FP16/BF16 via MegatronHalfPrecisionPlugin), the NLPDDPStrategy for distributed data parallelism, and experiment management. The megatron_trainer helper function configures the PyTorch Lightning Trainer with these NeMo-specific settings. The trainer uses ShuffledCyclicSequence to create a virtual dataset that cycles through the real training data and reshuffles it on each pass (epoch). The training loop itself is delegated to PyTorch Lightning's trainer.fit(), called with the ILQLGPT model.
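The idea of a virtual dataset that cycles through real data with per-epoch shuffling can be sketched in plain Python. This is a minimal illustration, not trlx's ShuffledCyclicSequence: the class name CyclicShuffleSketch and the choice to derive each pass's permutation from seed + epoch are assumptions made for the example.

```python
import random
from typing import List, Sequence


class CyclicShuffleSketch:
    """Illustrative stand-in (hypothetical), not trlx's ShuffledCyclicSequence."""

    def __init__(self, new_length: int, data: Sequence, seed: int):
        self.new_length = new_length  # virtual length, may exceed len(data)
        self.data = data
        self.seed = seed

    def __len__(self) -> int:
        return self.new_length

    def _permutation(self, epoch: int) -> List[int]:
        # Assumption: one deterministic permutation per pass over the data,
        # derived from (seed + epoch). Recomputed on every access for brevity.
        order = list(range(len(self.data)))
        random.Random(self.seed + epoch).shuffle(order)
        return order

    def __getitem__(self, idx: int):
        if not 0 <= idx < self.new_length:
            raise IndexError(idx)
        # epoch = which pass over the real data; offset = position in that pass
        epoch, offset = divmod(idx, len(self.data))
        return self.data[self._permutation(epoch)[offset]]


virtual = CyclicShuffleSketch(new_length=7, data=["a", "b", "c"], seed=0)
items = [virtual[i] for i in range(len(virtual))]
```

Each block of len(data) consecutive indices visits every real item exactly once, so the virtual length can be set to whatever number of samples the training schedule needs without duplicating the underlying data.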

Usage

Use this trainer when running ILQL on large-scale models (1B+ parameters) using NeMo's Megatron distributed backend. It is registered as "NeMoILQLTrainer" and is selected automatically when a NeMo config is used with the ILQL method.
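Registration by name is a common decorator-plus-dictionary pattern. The sketch below illustrates that pattern only. The names _TRAINERS, register_trainer_sketch, get_trainer_sketch, and DemoTrainer are hypothetical, and the case-insensitive lookup is an assumption, not a statement about how trlx implements register_trainer.

```python
from typing import Dict, Type

# Hypothetical registry mapping a lowercased class name to the class itself.
_TRAINERS: Dict[str, Type] = {}


def register_trainer_sketch(cls: Type) -> Type:
    """Class decorator: record the class under its own name and return it unchanged."""
    _TRAINERS[cls.__name__.lower()] = cls
    return cls


def get_trainer_sketch(name: str) -> Type:
    """Look a trainer class up by name (case-insensitive in this sketch)."""
    try:
        return _TRAINERS[name.lower()]
    except KeyError:
        raise ValueError(f"No trainer registered under {name!r}") from None


@register_trainer_sketch
class DemoTrainer:
    def __init__(self, config: dict):
        self.config = config


# A config can then name the trainer as a string and have it resolved at runtime.
trainer_cls = get_trainer_sketch("DemoTrainer")
trainer = trainer_cls(config={"method": "ilql"})
```

With this pattern, the trainer named in a configuration file resolves to the trainer class without the calling code importing that class directly.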

Code Reference

Source Location

Signature

def megatron_trainer(cfg) -> pytorch_lightning.Trainer:
    """
    Configure a PyTorch Lightning Trainer with NeMo Megatron settings.

    Args:
        cfg: OmegaConf config with trainer, model, and exp_manager sections.

    Returns:
        Configured PyTorch Lightning Trainer.
    """


class ShuffledCyclicSequence:
    def __init__(self, new_length: int, data: Sequence, seed: int):
        """
        Virtual dataset that cycles through data with per-epoch shuffling.

        Args:
            new_length: Virtual dataset length (can be larger than real data).
            data: Underlying data sequence.
            seed: Random seed for reproducible shuffling.
        """


@register_trainer
class NeMoILQLTrainer(BaseRLTrainer):
    def __init__(
        self,
        config: TRLConfig,
        reward_fn: Optional[Callable] = None,
        logit_mask: Optional[torch.Tensor] = None,
        metric_fn: Optional[Callable] = None,
        stop_sequences: Optional[List[str]] = None,
        train_mode: bool = True,
        megatron_cfg: Optional[str] = None,
        pretrained_model: Optional[str] = None,
    ):
        """
        Args:
            config: TRLConfig with ILQL method config.
            reward_fn: Reward function for experience generation.
            logit_mask: Optional logit mask for constrained decoding.
            metric_fn: Optional evaluation metric function.
            stop_sequences: Optional stop sequences for generation.
            train_mode: Whether to train or only evaluate.
            megatron_cfg: Path to NeMo Megatron YAML config.
            pretrained_model: Path to pretrained model weights.
        """

    def learn(self) -> None:
        """Run ILQL training via PyTorch Lightning trainer.fit()."""

Import

from trlx.trainer.nemo_ilql_trainer import NeMoILQLTrainer, megatron_trainer

I/O Contract

Inputs

Name Type Required Description
config TRLConfig Yes Full trlx configuration with ILQL method config
reward_fn Callable No Reward function for labeling samples
metric_fn Callable No Evaluation metric function
megatron_cfg str No Path to NeMo Megatron YAML config
pretrained_model str No Path to pretrained model checkpoint
logit_mask torch.Tensor No Mask for constrained decoding
stop_sequences List[str] No Stop sequences for generation
train_mode bool No Whether to train or only evaluate (default: True)
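The table gives reward_fn and metric_fn only as Callable. As a rough sketch of what such callables might look like, the example below assumes the convention used in trlx's example scripts: each receives a list of sample strings plus extra keyword arguments, the reward function returns one float per sample, and the metric function returns a dictionary of named per-sample values. The function names and the toy length-based scoring are hypothetical.

```python
from typing import Dict, List


def length_reward_fn(samples: List[str], **kwargs) -> List[float]:
    """Toy reward (hypothetical): shorter samples score higher."""
    return [-float(len(s)) for s in samples]


def length_metric_fn(samples: List[str], **kwargs) -> Dict[str, List[float]]:
    """Toy metric (hypothetical): report each sample's length under one key."""
    return {"length": [float(len(s)) for s in samples]}


scores = length_reward_fn(["4", "five"])
metrics = length_metric_fn(["4", "five"])
```

Accepting **kwargs keeps the callables tolerant of any additional arguments the caller passes alongside the samples.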

Outputs

Name Type Description
learn None Trains the ILQL model in-place via Lightning
Trained model ILQLGPT ILQL-trained NeMo model with Q/V heads

Usage Examples

Train ILQL with NeMo

import trlx
from trlx.data.configs import TRLConfig

# 1. Prepare labeled data: [prompt, completion] pairs, with one reward per pair
samples = [["What is 2+2?", "4"], ["What is 2+2?", "5"]]
rewards = [1.0, 0.0]

# 2. Configure ILQL with NeMo
config = TRLConfig.load_yaml("configs/nemo_ilql_config.yml")

# 3. Train (offline ILQL takes samples and their rewards as separate arguments)
trainer = trlx.train(
    samples=samples,
    rewards=rewards,
    config=config,
)
