Implementation:NVIDIA NeMo Aligner PPOTrainer Fit
| Implementation Details | |
|---|---|
| Name | PPOTrainer_Fit |
| Type | API Doc |
| Implements Principle | PPO_Training |
| Module | nemo_aligner.algorithms |
| Repository | NeMo Aligner |
| Last Updated | 2026-02-07 00:00 GMT |
Overview
Concrete tool, provided by the NeMo Aligner algorithms module, for executing the PPO training loop with a distributed actor-critic architecture.
Description
The PPOTrainer class coordinates the full PPO training cycle: (1) rollout generation via the actor model, (2) reward and value estimation via the remote critic server, (3) advantage computation using GAE, (4) actor policy update using clipped PPO objective, (5) critic training via the remote server. The fit() method orchestrates the main training loop with epoch management, metric logging, validation, and checkpointing. It manages distributed coordination across data-parallel ranks and handles optional TRT-LLM resharding for accelerated generation.
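Steps (3) and (4) above rest on two standard pieces of math: Generalized Advantage Estimation and the clipped surrogate objective. The following is a minimal pure-Python sketch of that math for orientation only; it is not the NeMo Aligner implementation (which operates on batched tensors across data-parallel ranks), and the function names and scalar inputs are illustrative.

```python
# Illustrative sketch of the advantage/loss math behind steps (3)-(4).
# NOT the NeMo Aligner source; operates on plain lists for clarity.

def gae_advantages(rewards, values, gamma=0.99, lam=0.95):
    """Generalized Advantage Estimation over one trajectory.

    `values` must hold len(rewards) + 1 entries: the value estimate for
    each timestep plus a bootstrap value for the final state.
    """
    advantages = [0.0] * len(rewards)
    gae = 0.0
    for t in reversed(range(len(rewards))):
        # One-step TD residual, then the exponentially weighted sum.
        delta = rewards[t] + gamma * values[t + 1] - values[t]
        gae = delta + gamma * lam * gae
        advantages[t] = gae
    return advantages

def clipped_ppo_loss(ratio, advantage, ratio_eps=0.2):
    """Clipped PPO surrogate for a single token (averaged in practice).

    `ratio` is exp(logp_new - logp_old); `ratio_eps` corresponds to the
    cfg field of the same name in the Inputs table below.
    """
    clipped = max(min(ratio, 1.0 + ratio_eps), 1.0 - ratio_eps)
    return -min(ratio * advantage, clipped * advantage)
```

With gamma and lam set to 1.0, the advantage at each step is simply the sum of the remaining TD residuals, which makes the recursion easy to check by hand.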
Usage
Used in train_gpt_ppo_actor.py. Requires a running critic server process (serve_ppo_critic.py) and a connected RemoteGPTRMCriticClient.
Code Reference
Source Location
- Repository: NeMo Aligner
- File: nemo_aligner/algorithms/ppo.py
- Lines: L130-641
Signature
class PPOTrainer:
    def __init__(
        self,
        cfg: DictConfig,
        model: MegatronGPTActorModel,
        optimizer,
        scheduler,
        train_dataloader_builder: Callable,
        val_dataloader_builder: Callable,
        collate_fn,
        rm_critic: RemoteGPTRMCriticClient,
        batch_iterator_cls,
        logger,
        ckpt_callback,
        run_timer,
    ):
        ...

    def fit(self) -> None:
        """Main PPO training loop."""

    def generate_ppo_data(self, rollout_batch) -> Tuple[dict, dict]:
        """Generate rollout, get rewards/values, compute advantages."""

    def generate_rollouts(self) -> Tuple[dict, dict, dict]:
        """Full rollout generation with metrics."""
Import
from nemo_aligner.algorithms.ppo import PPOTrainer
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| cfg | DictConfig | Yes | PPO config: discount_factor, gae_lambda, initial_policy_kl_penalty, ppo_epochs, entropy_bonus, ratio_eps |
| model | MegatronGPTActorModel | Yes | PPO actor model |
| rm_critic | RemoteGPTRMCriticClient | Yes | Remote critic/RM client |
| train_dataloader_builder | Callable | Yes | Factory for prompt dataloaders |
| batch_iterator_cls | type | Yes | Batch iteration helper class |
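The `initial_policy_kl_penalty` field in `cfg` governs the standard RLHF reward shaping against the frozen initial policy. A hedged sketch of that shaping, using illustrative scalar arguments rather than the trainer's actual batched tensors:

```python
# Sketch of KL-penalized reward shaping (standard RLHF form), shown for
# orientation; not the exact NeMo Aligner code or its tensor layout.

def kl_shaped_reward(reward, actor_logprob, init_logprob, kl_penalty=0.02):
    """Subtract a per-token KL penalty against the initial policy.

    `kl_penalty` plays the role of cfg.initial_policy_kl_penalty; the
    logprob difference estimates KL(actor || initial policy) per token.
    """
    init_policy_kl = actor_logprob - init_logprob
    return reward - kl_penalty * init_policy_kl
```

The `init_policy_kl` term also appears in the per-step metrics listed under Outputs, which makes it easy to monitor drift from the initial policy during training.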
Outputs
| Name | Type | Description |
|---|---|---|
| (side effect) | None | Updated actor weights, saved checkpoints |
| metrics | Dict | Per-step: rewards, init_policy_kl, advantages, returns, actor_loss, critic_loss |
Usage Examples
from nemo_aligner.algorithms.ppo import PPOTrainer
ppo_trainer = PPOTrainer(
cfg=cfg.trainer.ppo,
model=actor_model,
optimizer=optimizer,
scheduler=scheduler,
train_dataloader_builder=train_dataloader_builder,
val_dataloader_builder=val_dataloader_builder,
collate_fn=collate_fn,
rm_critic=rm_critic,
batch_iterator_cls=batch_iterator_cls,
logger=logger,
ckpt_callback=ckpt_callback,
run_timer=timer,
)
ppo_trainer.fit()
Related Pages
- Principle:NVIDIA_NeMo_Aligner_PPO_Training
- Environment:NVIDIA_NeMo_Aligner_NeMo_Framework_GPU_Environment
- Environment:NVIDIA_NeMo_Aligner_PyTriton_Serving_Environment
- Environment:NVIDIA_NeMo_Aligner_TensorRT_LLM_Acceleration_Environment
- Heuristic:NVIDIA_NeMo_Aligner_Higher_Stability_Log_Probs
- Heuristic:NVIDIA_NeMo_Aligner_Adam_State_Offloading_Tip
- Heuristic:NVIDIA_NeMo_Aligner_PPO_NCCL_Algorithm_Setting
- Heuristic:NVIDIA_NeMo_Aligner_PPO_Critic_Warmup_Tip