
Implementation:CarperAI Trlx NeMo PPO Trainer

From Leeroopedia


Knowledge Sources
Domains Reinforcement_Learning, NLP, Megatron
Last Updated 2026-02-07 16:00 GMT

Overview

A concrete trainer class that orchestrates PPO reinforcement learning on the NeMo Megatron framework, handling rollout generation, reward computation, and the PPO training loop.

Description

The NeMoPPOTrainer class extends BaseRLTrainer to implement the full PPO training loop on NeMo's Megatron-GPT backend. Its make_experience method collects experience by generating text completions, computing rewards, calculating KL penalties against a reference model, inferring log-probabilities and values, and constructing PPORLBatch objects. The learn method orchestrates the outer training loop: rollout generation, PPO optimization epochs, validation, checkpointing, and W&B logging. The trainer supports reward scaling via whitening, reference-based subtraction, or reward clipping.
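The KL-penalty step of experience collection can be sketched as follows. This is a minimal illustration of the general PPO-with-KL-penalty recipe, not the trlx source; the function name and the placement of the sequence score on the final token are illustrative assumptions.

```python
# Hedged sketch (not the trlx implementation): combine a scalar sequence
# score with per-token KL penalties against a frozen reference model.
def kl_penalized_rewards(logprobs, ref_logprobs, score, kl_coef=0.05):
    """logprobs / ref_logprobs: per-token log-probabilities from the
    policy and the reference model; score: scalar reward for the sample."""
    # Per-token penalty: -kl_coef * (logprob_policy - logprob_reference),
    # which discourages the policy from drifting away from the reference.
    rewards = [-kl_coef * (lp - rlp) for lp, rlp in zip(logprobs, ref_logprobs)]
    # The sequence-level reward is credited at the final generated token.
    rewards[-1] += score
    return rewards
```

In this sketch, tokens where the policy assigns higher log-probability than the reference receive a negative penalty, and only the last token carries the reward-model score.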

Usage

Use this trainer when running PPO training on large-scale models (1B+ parameters) with NeMo's Megatron distributed backend. It is registered under the trainer name "NeMoPPOTrainer" and is selected automatically when a NeMo config is used with the PPO method.

Code Reference

Source Location

Signature

@register_trainer
class NeMoPPOTrainer(BaseRLTrainer):
    def __init__(
        self,
        config: TRLConfig,
        metric_fn: Optional[Callable] = None,
        megatron_cfg: Optional[str] = None,
        pretrained_model: Optional[str] = None,
        **kwargs,
    ):
        """
        Args:
            config: TRLConfig with PPO method config.
            metric_fn: Optional evaluation metric function.
            megatron_cfg: Path to NeMo Megatron YAML config.
            pretrained_model: Path to pretrained model weights.
        """

    def make_experience(
        self,
        prompt_iterator: Iterator,
        num_rollouts: int = 1024,
        dp_world: int = 1,
    ) -> List[PPORLElement]:
        """
        Generate rollouts: sample completions, compute rewards, infer logprobs/values.
        Returns list of PPORLElement for training.
        """

    def learn(self) -> None:
        """
        Main training loop: generate experience, run PPO epochs,
        validate, checkpoint, and log metrics.
        """

Import

from trlx.trainer.nemo_ppo_trainer import NeMoPPOTrainer

I/O Contract

Inputs

Name Type Required Description
config TRLConfig Yes Full trlx configuration with PPO method config
metric_fn Callable No Evaluation metric function
megatron_cfg str No Path to NeMo Megatron YAML config
pretrained_model str No Path to pretrained model checkpoint
prompt_iterator Iterator Yes Iterator yielding prompt batches for rollout generation

Outputs

Name Type Description
make_experience List[PPORLElement] PPO rollout elements with tokens, logprobs, values, rewards
learn None Trains the model in-place, logs to W&B, saves checkpoints
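As a hedged illustration of the make_experience output, a rollout element in the contract above can be pictured as a small dataclass. Field names follow trlx's PPORLElement, but this is a simplified sketch: in the real library the fields hold torch tensors, not Python lists.

```python
from dataclasses import dataclass
from typing import List

@dataclass
class PPORLElement:
    # One rollout: the prompt tokens, the generated tokens, and
    # per-generated-token statistics used by the PPO update.
    query_tensor: List[int]     # prompt token ids
    response_tensor: List[int]  # generated token ids
    logprobs: List[float]       # policy log-probs per generated token
    values: List[float]         # value-head estimates per generated token
    rewards: List[float]        # KL-penalized rewards per generated token
```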

Usage Examples

Train with NeMoPPOTrainer

import trlx
from trlx.data.default_configs import TRLConfig, default_ppo_config

# 1. Define reward function
def reward_fn(samples, **kwargs):
    return [0.5] * len(samples)  # Dummy reward

# 2. Configure and train
config = default_ppo_config()
config.train.trainer = "NeMoPPOTrainer"

trainer = trlx.train(
    reward_fn=reward_fn,
    prompts=["Hello, how are you?"] * 100,
    config=config,
)
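The reward-scaling modes mentioned in the Description can also be illustrated. Below is a minimal sketch of the "whitening" mode under the usual interpretation (normalize a batch of rewards to zero mean and unit variance); the function name is illustrative, not the trlx API.

```python
def whiten_rewards(rewards, eps=1e-8):
    # Hedged sketch: standardize a batch of rewards so PPO updates see
    # zero-mean, unit-variance signals regardless of reward-fn scale.
    n = len(rewards)
    mean = sum(rewards) / n
    var = sum((r - mean) ** 2 for r in rewards) / n
    return [(r - mean) / (var ** 0.5 + eps) for r in rewards]
```

Whitening makes training less sensitive to the absolute scale of the reward function, which is useful when reward models output uncalibrated scores.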
