
Implementation:NVIDIA NeMo Aligner PPOTrainer Fit

From Leeroopedia


Implementation Details
  • Name: PPOTrainer_Fit
  • Type: API Doc
  • Implements Principle: PPO_Training
  • Module: nemo_aligner.algorithms
  • Repository: NeMo Aligner
  • Last Updated: 2026-02-07 00:00 GMT

Overview

A concrete tool, provided by the NeMo Aligner algorithms module, for executing the PPO training loop with a distributed actor-critic architecture.

Description

The PPOTrainer class coordinates the full PPO training cycle: (1) rollout generation via the actor model, (2) reward and value estimation via the remote critic server, (3) advantage computation using GAE, (4) actor policy updates with the clipped PPO objective, and (5) critic training via the remote server. The fit() method orchestrates the main training loop, handling epoch management, metric logging, validation, and checkpointing. It coordinates work across data-parallel ranks and supports optional TRT-LLM resharding for accelerated generation.
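
The advantage step (3) follows standard Generalized Advantage Estimation. The sketch below is a minimal, generic GAE routine for a single trajectory, not NeMo Aligner's exact code; the tensor layout is an assumption, and gamma and lam correspond to cfg.discount_factor and cfg.gae_lambda.

import torch

def compute_gae(rewards, values, gamma=0.99, lam=0.95):
    """Generic GAE for one trajectory: rewards [T], values [T+1] (with bootstrap)."""
    T = rewards.shape[0]
    advantages = torch.zeros_like(rewards)
    last_adv = 0.0
    for t in reversed(range(T)):
        # TD residual: delta_t = r_t + gamma * V(s_{t+1}) - V(s_t)
        delta = rewards[t] + gamma * values[t + 1] - values[t]
        # Recurrence: A_t = delta_t + gamma * lam * A_{t+1}
        last_adv = delta + gamma * lam * last_adv
        advantages[t] = last_adv
    returns = advantages + values[:T]  # regression targets for the critic
    return advantages, returns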

Usage

Used in train_gpt_ppo_actor.py. Requires a running critic server process (serve_ppo_critic.py) and a connected RemoteGPTRMCriticClient.

Code Reference

Source Location

  • Repository: NeMo Aligner
  • File: nemo_aligner/algorithms/ppo.py
  • Lines: L130-641

Signature

class PPOTrainer:
    def __init__(
        self,
        cfg: DictConfig,
        model: MegatronGPTActorModel,
        optimizer,
        scheduler,
        train_dataloader_builder: Callable,
        val_dataloader_builder: Callable,
        collate_fn,
        rm_critic: RemoteGPTRMCriticClient,
        batch_iterator_cls,
        logger,
        ckpt_callback,
        run_timer,
    ):
        ...

    def fit(self) -> None:
        """Main PPO training loop."""

    def generate_ppo_data(self, rollout_batch) -> Tuple[dict, dict]:
        """Generate rollout, get rewards/values, compute advantages."""

    def generate_rollouts(self) -> Tuple[dict, dict, dict]:
        """Full rollout generation with metrics."""

Import

from nemo_aligner.algorithms.ppo import PPOTrainer

I/O Contract

Inputs

  • cfg (DictConfig, required): PPO config with discount_factor, gae_lambda, initial_policy_kl_penalty, ppo_epochs, entropy_bonus, ratio_eps (see the sketch after this list)
  • model (MegatronGPTActorModel, required): PPO actor model
  • rm_critic (RemoteGPTRMCriticClient, required): remote critic/RM client
  • train_dataloader_builder (Callable, required): factory for prompt dataloaders
  • batch_iterator_cls (type, required): batch iteration helper class

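As referenced above, here is a sketch of a matching cfg built with OmegaConf; every value below is illustrative, not a NeMo Aligner default.

from omegaconf import OmegaConf

# Illustrative values only; consult the shipped PPO configs for real defaults.
cfg = OmegaConf.create(
    {
        "discount_factor": 1.0,
        "gae_lambda": 0.95,
        "initial_policy_kl_penalty": 0.02,
        "ppo_epochs": 1,
        "entropy_bonus": 0.0,
        "ratio_eps": 0.2,
    }
)
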
Outputs

  • (side effect) (None): updated actor weights and saved checkpoints
  • metrics (Dict): per-step rewards, init_policy_kl, advantages, returns, actor_loss, critic_loss (illustrated below)
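
For reference, a per-step metrics payload with the keys listed above might look like the following; the key names match the table, but the values and exact structure are illustrative.

# Illustrative per-step metrics; actual shapes/aggregation may differ.
metrics = {
    "rewards": 0.41,           # mean rollout reward from the critic/RM server
    "init_policy_kl": 0.013,   # KL penalty term against the initial policy
    "advantages": 0.0,         # mean GAE advantage (near zero if normalized)
    "returns": 0.52,           # mean critic regression targets
    "actor_loss": 0.087,       # clipped PPO objective value
    "critic_loss": 0.034,      # value loss reported by the remote critic
}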

Usage Examples

from nemo_aligner.algorithms.ppo import PPOTrainer

# All arguments below are constructed in train_gpt_ppo_actor.py; the critic
# server (serve_ppo_critic.py) must already be running and reachable.
ppo_trainer = PPOTrainer(
    cfg=cfg.trainer.ppo,
    model=actor_model,
    optimizer=optimizer,
    scheduler=scheduler,
    train_dataloader_builder=train_dataloader_builder,
    val_dataloader_builder=val_dataloader_builder,
    collate_fn=collate_fn,
    rm_critic=rm_critic,
    batch_iterator_cls=batch_iterator_cls,
    logger=logger,
    ckpt_callback=ckpt_callback,
    run_timer=timer,
)
ppo_trainer.fit()  # runs the full PPO training loop

Related Pages

Knowledge Sources

Reinforcement_Learning, NLP
