Implementation: CarperAI Trlx NeMo PPO Model
| Knowledge Sources | |
|---|---|
| Domains | Reinforcement_Learning, NLP, Megatron |
| Last Updated | 2026-02-07 16:00 GMT |
Overview
Concrete tool for running Proximal Policy Optimization (PPO) on NVIDIA NeMo's Megatron-GPT framework, with a value head, a frozen reference model, and pipeline-parallel support.
Description
The PPOGPT class extends NeMo's MegatronGPTModel to add a ValueHead for advantage estimation, a RefLMHeads module that maintains both the policy model and a frozen reference model (with CPU offloading for memory efficiency), and PPO-specific training logic including the clipped surrogate loss, value loss, and KL penalty computation. It supports tensor parallelism, pipeline parallelism, selective layer freezing (num_layers_unfrozen), stop-sequence handling during generation, and distributed fused Adam optimization.
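The clipped surrogate and value losses mentioned above can be illustrated with a minimal per-token sketch in plain Python. This is an assumption-laden simplification: the actual implementation vectorizes these computations over Megatron tensor-parallel tensors, and ppo_losses is a hypothetical helper name.

```python
import math

def ppo_losses(logprob, old_logprob, advantage, value, old_value, ret,
               cliprange=0.2, cliprange_value=0.2):
    """Per-token PPO losses: clipped surrogate policy loss and clipped
    value loss (scalar sketch, not the real tensor implementation)."""
    # Probability ratio between the current policy and the rollout policy.
    ratio = math.exp(logprob - old_logprob)
    # Clipped surrogate objective: keep the pessimistic (larger-loss) branch.
    clipped_ratio = min(max(ratio, 1.0 - cliprange), 1.0 + cliprange)
    pg_loss = max(-advantage * ratio, -advantage * clipped_ratio)
    # Value loss, clipped around the value predicted at rollout time.
    v_clipped = min(max(value, old_value - cliprange_value),
                    old_value + cliprange_value)
    vf_loss = 0.5 * max((value - ret) ** 2, (v_clipped - ret) ** 2)
    return pg_loss, vf_loss
```

The final training loss is typically the policy loss plus a coefficient-weighted value loss, averaged over the response tokens of the rollout batch.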
Usage
Use this model class when training PPO on large-scale models (1B+ parameters) that require NeMo's Megatron distributed training. The reference model is used for KL divergence computation and can be offloaded to CPU to reduce GPU memory. For smaller models using HuggingFace Accelerate, use the standard PPO model in modeling_ppo.py instead.
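The reference model's role in KL computation can be sketched as a per-token penalty folded into the rollout rewards. This is a list-based illustration under assumed conventions (kl_penalized_rewards is a hypothetical helper; trlx applies the penalty to token-level rewards, with the sequence-level score added at the final token):

```python
def kl_penalized_rewards(score, logprobs, ref_logprobs, kl_coef=0.1):
    """Fold a per-token KL penalty into rollout rewards.

    score: sequence-level reward, added to the final token.
    logprobs / ref_logprobs: per-token log-probabilities from the
    policy and the frozen reference model, respectively.
    """
    rewards = []
    for lp, ref_lp in zip(logprobs, ref_logprobs):
        # Approximate per-token KL contribution: log p_policy - log p_ref.
        rewards.append(-kl_coef * (lp - ref_lp))
    rewards[-1] += score  # the scalar reward lands on the last token
    return rewards
```

Offloading the reference model to CPU trades the memory of a second full model copy for the latency of moving activations (or weights) across the PCIe bus during logprob inference.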
Code Reference
Source Location
- Repository: CarperAI_Trlx
- File: trlx/models/modeling_nemo_ppo.py
- Lines: 1-1222
Signature
class PPOGPT(MegatronGPTModel):
def __init__(
self,
ppo_config: PPOConfig,
metric_fn: Optional[Callable] = None,
stop_sequences: Sequence[str] = (),
num_layers_unfrozen: Optional[int] = None,
build_reference_model: bool = True,
**kwargs,
):
"""
Args:
ppo_config: PPO hyperparameters (gen_kwargs, etc.).
metric_fn: Optional evaluation metric function.
stop_sequences: Sequences that terminate generation.
num_layers_unfrozen: Number of transformer layers to train (-1 = all).
build_reference_model: Whether to create a frozen reference model for KL.
**kwargs: Passed to MegatronGPTModel.
"""
def training_step(self, batch: PPORLBatch, batch_idx: int) -> torch.Tensor:
"""Execute one PPO training step with clipped surrogate and value losses."""
def infer_logprobs_and_values(
self,
input_ids: torch.Tensor,
attention_mask: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
"""Compute policy logprobs, reference logprobs, and values for a batch."""
def generate(
self,
inputs: dict,
length_params: LengthParam,
sampling_params: Optional[SamplingParam] = None,
) -> list:
"""Generate text with optional stop-sequence handling."""
Import
from trlx.models.modeling_nemo_ppo import PPOGPT
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| ppo_config | PPOConfig | Yes | PPO hyperparameters |
| metric_fn | Callable | No | Evaluation metric function |
| stop_sequences | Sequence[str] | No | Stop sequences for generation |
| num_layers_unfrozen | int | No | Layers to keep trainable (-1 = all) |
| build_reference_model | bool | No | Whether to build frozen reference copy (default True) |
| batch | PPORLBatch | Yes | PPO rollout batch for training |
| input_ids | torch.Tensor | Yes | Token IDs for logprob/value inference |
| attention_mask | torch.Tensor | Yes | Attention mask for inference |
Outputs
| Name | Type | Description |
|---|---|---|
| training_step returns | torch.Tensor | Combined PPO loss (policy + value) |
| infer_logprobs_and_values returns | Tuple[Tensor, Tensor, Tensor] | (policy_logprobs, ref_logprobs, values) |
| generate returns | list | Generated text sequences |
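The values returned by infer_logprobs_and_values feed advantage estimation for the value head. A hedged sketch of generalized advantage estimation (GAE), as commonly used in trlx-style PPO, over one trajectory of plain floats (gamma and lam defaults are illustrative, not taken from the source):

```python
def gae_advantages(rewards, values, gamma=1.0, lam=0.95):
    """Generalized advantage estimation over one trajectory.

    Walks the trajectory backwards, accumulating discounted TD errors;
    returns (advantages, returns), where returns are the value-head targets.
    """
    advantages = []
    last_adv = 0.0
    for t in reversed(range(len(rewards))):
        next_value = values[t + 1] if t + 1 < len(values) else 0.0
        delta = rewards[t] + gamma * next_value - values[t]  # TD error
        last_adv = delta + gamma * lam * last_adv
        advantages.append(last_adv)
    advantages.reverse()
    returns = [a + v for a, v in zip(advantages, values)]
    return advantages, returns
```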
Usage Examples
Create PPOGPT for NeMo Training
from omegaconf import OmegaConf
from trlx.models.modeling_ppo import PPOConfig
from trlx.models.modeling_nemo_ppo import PPOGPT
# 1. Load NeMo config
megatron_cfg = OmegaConf.load("configs/nemo_configs/megatron_20b.yaml")
# 2. Create PPO config
ppo_config = PPOConfig(
num_rollouts=128,
chunk_size=16,
ppo_epochs=4,
init_kl_coef=0.1,
target=6.0,
gen_kwargs={"temperature": 1.0, "max_new_tokens": 128},
)
# 3. Instantiate model (typically done by NeMoPPOTrainer)
model = PPOGPT(
ppo_config=ppo_config,
num_layers_unfrozen=2,
build_reference_model=True,
cfg=megatron_cfg.model,
)