Implementation: CarperAI Trlx NeMo PPO Trainer
| Knowledge Sources | |
|---|---|
| Domains | Reinforcement_Learning, NLP, Megatron |
| Last Updated | 2026-02-07 16:00 GMT |
Overview
Concrete tool for orchestrating PPO reinforcement learning training using the NeMo Megatron framework, handling rollout generation, reward computation, and the PPO training loop.
Description
The NeMoPPOTrainer class extends BaseRLTrainer to implement the full PPO training loop on NeMo's Megatron-GPT backend. It handles experience collection (make_experience) by generating text completions, computing rewards, calculating KL penalties against a reference model, inferring log-probabilities and values, and constructing PPORLBatch objects. The learn method orchestrates the outer training loop with rollout generation, PPO optimization epochs, validation, checkpointing, and W&B logging. Supports reward scaling via whitening, reference-based subtraction, or reward clipping.
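As a rough illustration of the reward post-processing the description mentions, the sketch below combines a per-token KL penalty against a reference model with reward whitening. The function names and shapes here are illustrative assumptions, not trlx's exact internals.

```python
import math

def kl_penalized_reward(env_reward, logprob, ref_logprob, kl_coef=0.1):
    """Subtract a KL penalty (estimated as logprob - ref_logprob) from the
    environment reward. Hypothetical helper, not trlx's actual code."""
    return env_reward - kl_coef * (logprob - ref_logprob)

def whiten(rewards):
    """Scale rewards to zero mean / unit variance -- one common
    'reward scaling via whitening' scheme."""
    mean = sum(rewards) / len(rewards)
    var = sum((r - mean) ** 2 for r in rewards) / len(rewards)
    return [(r - mean) / math.sqrt(var + 1e-8) for r in rewards]
```

Whitening keeps the reward scale stable across rollout batches, which in turn keeps PPO's clipped objective in a usable range.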
Usage
Use this trainer when running PPO training on large-scale models (1B+ parameters) on NeMo's Megatron distributed backend. It is registered under the trainer type "NeMoPPOTrainer" and is selected automatically when a NeMo config is used with the PPO method.
Code Reference
Source Location
- Repository: CarperAI_Trlx
- File: trlx/trainer/nemo_ppo_trainer.py
- Lines: 1-441
Signature
```python
@register_trainer
class NeMoPPOTrainer(BaseRLTrainer):
    def __init__(
        self,
        config: TRLConfig,
        metric_fn: Optional[Callable] = None,
        megatron_cfg: Optional[str] = None,
        pretrained_model: Optional[str] = None,
        **kwargs,
    ):
        """
        Args:
            config: TRLConfig with PPO method config.
            metric_fn: Optional evaluation metric function.
            megatron_cfg: Path to NeMo Megatron YAML config.
            pretrained_model: Path to pretrained model weights.
        """

    def make_experience(
        self,
        prompt_iterator: Iterator,
        num_rollouts: int = 1024,
        dp_world: int = 1,
    ) -> List[PPORLElement]:
        """
        Generate rollouts: sample completions, compute rewards, infer logprobs/values.
        Returns list of PPORLElement for training.
        """

    def learn(self) -> None:
        """
        Main training loop: generate experience, run PPO epochs,
        validate, checkpoint, and log metrics.
        """
```
Import
```python
from trlx.trainer.nemo_ppo_trainer import NeMoPPOTrainer
```
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| config | TRLConfig | Yes | Full trlx configuration with PPO method config |
| metric_fn | Callable | No | Evaluation metric function |
| megatron_cfg | str | No | Path to NeMo Megatron YAML config |
| pretrained_model | str | No | Path to pretrained model checkpoint |
| prompt_iterator | Iterator | Yes | Iterator yielding prompt batches for rollout generation |
Outputs
| Name | Type | Description |
|---|---|---|
| make_experience returns | List[PPORLElement] | PPO rollout elements with tokens, logprobs, values, rewards |
| learn | None | Trains the model in-place, logs to W&B, saves checkpoints |
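For orientation, a PPO rollout element bundles roughly the fields listed above. The dataclass below is a simplified stand-in using plain lists instead of tensors; see trlx's actual PPORLElement for the real definition.

```python
from dataclasses import dataclass
from typing import List

@dataclass
class RolloutElement:
    """Simplified stand-in for trlx's PPORLElement."""
    query_tokens: List[int]     # tokenized prompt
    response_tokens: List[int]  # tokenized completion
    logprobs: List[float]       # per-token log-probabilities
    values: List[float]         # per-token value estimates
    rewards: List[float]        # per-token (KL-penalized) rewards

elem = RolloutElement([1, 2], [3, 4], [-0.5, -0.7], [0.1, 0.2], [0.0, 1.0])
```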
Usage Examples
Train with NeMoPPOTrainer
```python
import trlx
from trlx.data.default_configs import TRLConfig, default_ppo_config

# 1. Define reward function
def reward_fn(samples, **kwargs):
    return [0.5] * len(samples)  # Dummy reward

# 2. Configure and train
config = default_ppo_config()
config.train.trainer = "NeMoPPOTrainer"
trainer = trlx.train(
    reward_fn=reward_fn,
    prompts=["Hello, how are you?"] * 100,
    config=config,
)
```