Implementation: CarperAI Trlx NeMo ILQL Trainer
| Knowledge Sources | |
|---|---|
| Domains | Reinforcement_Learning, NLP, Megatron |
| Last Updated | 2026-02-07 16:00 GMT |
Overview
Concrete tool for orchestrating ILQL (Implicit Language Q-Learning) training using the NeMo Megatron framework with distributed data loading and precision management.
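For context, the "implicit" part of ILQL comes from fitting the value function with an expectile (asymmetric squared) loss rather than taking an explicit max over actions. A minimal sketch of that loss follows; the tau default is an illustrative value, not one taken from this trainer:

```python
def expectile_loss(diff: float, tau: float = 0.7) -> float:
    """Asymmetric squared loss on diff = Q(s, a) - V(s).

    Positive differences are weighted by tau and negative ones by
    (1 - tau), so minimizing this pushes V(s) toward an upper
    expectile of Q rather than its mean.
    """
    weight = tau if diff > 0 else 1.0 - tau
    return weight * diff * diff
```

With tau > 0.5, underestimating Q (diff > 0) is penalized more heavily than overestimating it, which is what lets ILQL learn a conservative value estimate offline.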
Description
The NeMoILQLTrainer extends BaseRLTrainer to run ILQL training on NeMo's Megatron-GPT backend. It sets up the Megatron trainer with precision plugins (FP16/BF16 via MegatronHalfPrecisionPlugin), the NLPDDPStrategy for distributed data parallelism, and experiment management. It uses ShuffledCyclicSequence to create a virtual dataset that cycles through the real training data with per-epoch shuffling. The megatron_trainer helper function configures the PyTorch Lightning Trainer with all NeMo-specific settings, and the actual training loop is delegated to PyTorch Lightning's trainer.fit() with the ILQLGPT model.
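The ShuffledCyclicSequence contract can be sketched as follows. This is a simplified stand-in that reproduces the behavior described above (cycling with deterministic per-epoch shuffling), not a copy of the trlx implementation, whose internals may differ:

```python
import random
from typing import Sequence


class ShuffledCyclicSequence:
    """Virtual dataset of length `new_length` that cycles over `data`,
    reshuffling the order deterministically on every pass (epoch)."""

    def __init__(self, new_length: int, data: Sequence, seed: int):
        self.data = data
        self.new_length = new_length
        self.seed = seed

    def __len__(self) -> int:
        return self.new_length

    def __getitem__(self, idx: int):
        # Which pass over the underlying data this virtual index falls in,
        # and the offset within that pass.
        epoch, offset = divmod(idx, len(self.data))
        # Seeded per-epoch permutation: reproducible across workers/restarts.
        perm = list(range(len(self.data)))
        random.Random(self.seed + epoch).shuffle(perm)
        return self.data[perm[offset]]
```

Because each epoch's permutation is derived only from `seed` and the epoch number, every data-parallel rank sees the same ordering without any communication.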
Usage
Use this trainer when running ILQL on large-scale models (1B+ parameters) using NeMo's Megatron distributed backend. Registered as "NeMoILQLTrainer" and automatically selected when using NeMo configs with ILQL method.
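Trainer selection happens through the train section of the TRLConfig. An illustrative YAML fragment is shown below; the trainer_kwargs keys mirror the constructor arguments documented in the Signature section, and the file paths are placeholders, not paths from the repository:

```yaml
train:
  trainer: "NeMoILQLTrainer"
  trainer_kwargs:
    megatron_cfg: "megatron_20b.yaml"
    pretrained_model: "/path/to/megatron_20b.nemo"
```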
Code Reference
Source Location
- Repository: CarperAI_Trlx
- File: trlx/trainer/nemo_ilql_trainer.py
- Lines: 1-204
Signature
def megatron_trainer(cfg) -> pytorch_lightning.Trainer:
    """
    Configure a PyTorch Lightning Trainer with NeMo Megatron settings.

    Args:
        cfg: OmegaConf config with trainer, model, and exp_manager sections.

    Returns:
        Configured PyTorch Lightning Trainer.
    """
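The cfg passed to megatron_trainer is an OmegaConf tree. A minimal illustrative fragment with the three required sections follows; the specific field names are assumptions based on standard NeMo Megatron configs, not values copied from this repository:

```yaml
trainer:
  devices: 8
  num_nodes: 1
  accelerator: gpu
  precision: bf16
model:
  tensor_model_parallel_size: 2
  pipeline_model_parallel_size: 1
exp_manager:
  explicit_log_dir: ilql_logs
  create_wandb_logger: false
```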
class ShuffledCyclicSequence:
    def __init__(self, new_length: int, data: Sequence, seed: int):
        """
        Virtual dataset that cycles through data with per-epoch shuffling.

        Args:
            new_length: Virtual dataset length (can be larger than real data).
            data: Underlying data sequence.
            seed: Random seed for reproducible shuffling.
        """
@register_trainer
class NeMoILQLTrainer(BaseRLTrainer):
    def __init__(
        self,
        config: TRLConfig,
        reward_fn: Optional[Callable] = None,
        logit_mask: Optional[torch.Tensor] = None,
        metric_fn: Optional[Callable] = None,
        stop_sequences: Optional[List[str]] = None,
        train_mode: bool = True,
        megatron_cfg: Optional[str] = None,
        pretrained_model: Optional[str] = None,
    ):
        """
        Args:
            config: TRLConfig with ILQL method config.
            reward_fn: Reward function for experience generation.
            logit_mask: Optional logit mask for constrained decoding.
            metric_fn: Optional evaluation metric function.
            stop_sequences: Optional stop sequences for generation.
            train_mode: Whether to train or only evaluate.
            megatron_cfg: Path to NeMo Megatron YAML config.
            pretrained_model: Path to pretrained model weights.
        """

    def learn(self) -> None:
        """Run ILQL training via PyTorch Lightning trainer.fit()."""
Import
from trlx.trainer.nemo_ilql_trainer import NeMoILQLTrainer, megatron_trainer
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| config | TRLConfig | Yes | Full trlx configuration with ILQL method config |
| reward_fn | Callable | No | Reward function for labeling samples |
| metric_fn | Callable | No | Evaluation metric function |
| megatron_cfg | str | No | Path to NeMo Megatron YAML config |
| pretrained_model | str | No | Path to pretrained model checkpoint |
| logit_mask | torch.Tensor | No | Mask for constrained decoding |
Outputs
| Name | Type | Description |
|---|---|---|
| learn | None | Trains the ILQL model in-place via Lightning |
| Trained model | ILQLGPT | ILQL-trained NeMo model with Q/V heads |
Usage Examples
Train ILQL with NeMo
import trlx
from trlx.data.configs import TRLConfig
# 1. Prepare labeled data: (prompt, completion) pairs plus per-sample rewards
samples = [["What is 2+2?", "4"], ["What is 2+2?", "5"]]
rewards = [1.0, 0.0]
# 2. Configure ILQL with NeMo
config = TRLConfig.load_yaml("configs/nemo_ilql_config.yml")
# 3. Train
trainer = trlx.train(
    samples=samples,
    rewards=rewards,
    config=config,
)