
Implementation:CarperAI Trlx NeMo PPO Model

From Leeroopedia


Knowledge Sources
Domains Reinforcement_Learning, NLP, Megatron
Last Updated 2026-02-07 16:00 GMT

Overview

PPOGPT is a concrete model class for running Proximal Policy Optimization (PPO) on NVIDIA NeMo's Megatron-GPT framework, with a value head, a frozen reference model, and pipeline-parallel support.

Description

The PPOGPT class extends NeMo's MegatronGPTModel with three additions: a ValueHead for advantage estimation; a RefLMHeads module that holds both the policy model and a frozen reference model (with CPU offloading for memory efficiency); and PPO-specific training logic covering the clipped surrogate loss, the value loss, and the KL penalty. It also supports tensor parallelism, pipeline parallelism, selective layer freezing (num_layers_unfrozen), stop-sequence handling during generation, and distributed fused Adam optimization.
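The three loss terms named above can be sketched per token in plain Python. This is a conceptual sketch, not trlx's actual code: the function name, `clip_eps`, and `kl_coef` are illustrative, and in trlx the KL penalty is typically folded into the per-token rewards rather than added directly to the loss.

```python
import math

def ppo_token_losses(logprob, old_logprob, ref_logprob,
                     value, old_value, returns, advantage,
                     clip_eps=0.2, kl_coef=0.1):
    """Per-token PPO loss terms, as computed conceptually by the training step."""
    # Clipped surrogate (policy) loss: ratio of new to old policy probability
    ratio = math.exp(logprob - old_logprob)
    unclipped = ratio * advantage
    clipped = max(min(ratio, 1 + clip_eps), 1 - clip_eps) * advantage
    policy_loss = -min(unclipped, clipped)

    # Clipped value loss against the returns
    v_clipped = max(min(value, old_value + clip_eps), old_value - clip_eps)
    value_loss = max((value - returns) ** 2, (v_clipped - returns) ** 2)

    # KL penalty against the frozen reference model (simple log-ratio estimator)
    kl_penalty = kl_coef * (logprob - ref_logprob)

    return policy_loss, value_loss, kl_penalty
```

The clipping keeps each update close to the rollout policy; the reference model keeps the policy from drifting too far from its pretrained distribution.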

Usage

Use this model class when training PPO on large-scale models (1B+ parameters) that require NeMo's Megatron distributed training. The reference model is used for KL divergence computation and can be offloaded to CPU to reduce GPU memory. For smaller models using HuggingFace Accelerate, use the standard PPO model in modeling_ppo.py instead.

Code Reference

Source Location

Signature

class PPOGPT(MegatronGPTModel):
    def __init__(
        self,
        ppo_config: PPOConfig,
        metric_fn: Optional[Callable] = None,
        stop_sequences: Sequence[str] = (),
        num_layers_unfrozen: Optional[int] = None,
        build_reference_model: bool = True,
        **kwargs,
    ):
        """
        Args:
            ppo_config: PPO hyperparameters (gen_kwargs, etc.).
            metric_fn: Optional evaluation metric function.
            stop_sequences: Sequences that terminate generation.
            num_layers_unfrozen: Number of transformer layers to train (-1 = all).
            build_reference_model: Whether to create a frozen reference model for KL.
            **kwargs: Passed to MegatronGPTModel.
        """

    def training_step(self, batch: PPORLBatch, batch_idx: int) -> torch.Tensor:
        """Execute one PPO training step with clipped surrogate and value losses."""

    def infer_logprobs_and_values(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Compute policy logprobs, reference logprobs, and values for a batch."""

    def generate(
        self,
        inputs: dict,
        length_params: LengthParam,
        sampling_params: Optional[SamplingParam] = None,
    ) -> list:
        """Generate text with optional stop-sequence handling."""

Import

from trlx.models.modeling_nemo_ppo import PPOGPT

I/O Contract

Inputs

Name Type Required Description
ppo_config PPOConfig Yes PPO hyperparameters
metric_fn Callable No Evaluation metric function
stop_sequences Sequence[str] No Stop sequences for generation
num_layers_unfrozen int No Layers to keep trainable (-1 = all)
build_reference_model bool No Whether to build frozen reference copy (default True)
batch PPORLBatch Yes PPO rollout batch for training
input_ids torch.Tensor Yes Token IDs for logprob/value inference
attention_mask torch.Tensor Yes Attention mask for inference

Outputs

Name Type Description
training_step returns torch.Tensor Combined PPO loss (policy + value)
infer_logprobs_and_values returns Tuple[Tensor, Tensor, Tensor] (policy_logprobs, ref_logprobs, values)
generate returns list Generated text sequences
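The first two tensors returned by infer_logprobs_and_values feed the KL penalty: a common PPO-for-LLMs pattern shapes per-token rewards as minus kl_coef times the policy/reference log-ratio, adding the sequence-level score at the final token. A pure-Python sketch of that pattern (names and `kl_coef` are illustrative):

```python
def kl_shaped_rewards(policy_logprobs, ref_logprobs, score, kl_coef=0.1):
    """Per-token rewards: -kl_coef * (logp - ref_logp), plus score on the last token."""
    rewards = [-kl_coef * (p - r) for p, r in zip(policy_logprobs, ref_logprobs)]
    rewards[-1] += score  # sequence-level reward lands on the final token
    return rewards
```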

Usage Examples

Create PPOGPT for NeMo Training

from omegaconf import OmegaConf
from trlx.models.modeling_ppo import PPOConfig
from trlx.models.modeling_nemo_ppo import PPOGPT

# 1. Load NeMo config
megatron_cfg = OmegaConf.load("configs/nemo_configs/megatron_20b.yaml")

# 2. Create PPO config
ppo_config = PPOConfig(
    num_rollouts=128,
    chunk_size=16,
    ppo_epochs=4,
    init_kl_coef=0.1,
    target=6.0,
    gen_kwargs={"temperature": 1.0, "max_new_tokens": 128},
)

# 3. Instantiate model (typically done by NeMoPPOTrainer)
model = PPOGPT(
    ppo_config=ppo_config,
    num_layers_unfrozen=2,
    build_reference_model=True,
    cfg=megatron_cfg.model,
)
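Once instantiated, generation goes through NeMo-style length and sampling dicts. The keys below follow NeMo's LengthParam/SamplingParam TypedDicts as commonly documented; verify them against your NeMo version, and note that the inputs format shown in the commented call is a hypothetical example, not a confirmed trlx API.

```python
# NeMo-style generation parameter dicts (keys assumed from NeMo's TypedDicts)
length_params = {"max_length": 128, "min_length": 0}
sampling_params = {
    "use_greedy": False,
    "temperature": 1.0,
    "top_k": 0,
    "top_p": 0.9,
    "repetition_penalty": 1.0,
    "add_BOS": False,
    "all_probs": False,
    "compute_logprob": False,
}
# Requires a live NeMo model, so the call itself is left commented out:
# outputs = model.generate(
#     inputs={"prompts": ["Explain PPO in one sentence."]},  # hypothetical input format
#     length_params=length_params,
#     sampling_params=sampling_params,
# )
```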
