
Implementation:NVIDIA NeMo Aligner ReinforceTrainer Fit

From Leeroopedia


Implementation Details

  • Name: ReinforceTrainer_Fit
  • Type: API Doc
  • Implements Principle: REINFORCE_Training
  • Module: nemo_aligner.algorithms
  • Repository: NeMo Aligner
  • Last Updated: 2026-02-07 00:00 GMT

Overview

Concrete tool, provided by the NeMo Aligner algorithms module, for executing the REINFORCE training loop for language model alignment without a critic network.

Description

The ReinforceTrainer class coordinates REINFORCE-based alignment: (1) it generates responses with the actor, (2) scores them with the remote reward model, (3) computes an RLOO baseline from the batch rewards, (4) applies a KL penalty between the current and reference policies, and (5) updates the actor with the REINFORCE policy gradient. The fit() method runs the main loop, handling epoch iteration, metric logging, validation, and checkpointing. It is simpler than PPOTrainer because it has no critic training phase.
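Steps (3)-(5) above can be sketched in a few lines. This is an illustrative pure-Python outline, not the actual NeMo Aligner implementation; the function names and the per-sample treatment of rewards are assumptions made for clarity.

```python
def rloo_baseline(rewards):
    """RLOO (leave-one-out) baseline: for each sample, the mean reward of
    the *other* samples in the batch, so the baseline is independent of the
    sample it is subtracted from."""
    n = len(rewards)
    total = sum(rewards)
    return [(total - r) / (n - 1) for r in rewards]

def reinforce_loss(logprobs, rewards, kl_to_init_policy, kl_penalty):
    """KL-shaped rewards minus the RLOO baseline give per-sample advantages
    that scale each sequence log-probability (illustrative sketch)."""
    # (4) penalize drift from the reference (initial) policy
    shaped = [r - kl_penalty * kl for r, kl in zip(rewards, kl_to_init_policy)]
    # (3) leave-one-out baseline over the batch
    baseline = rloo_baseline(shaped)
    advantages = [s - b for s, b in zip(shaped, baseline)]
    # (5) REINFORCE: maximize advantage-weighted log-prob, i.e. minimize its negation
    return -sum(a * lp for a, lp in zip(advantages, logprobs)) / len(logprobs)
```

In the real trainer these quantities are batched tensors and the loss is computed inside the Megatron model, but the shape of the computation is the same.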

Usage

Used in train_gpt_reinforce_actor.py. Requires a running reward model server. Does not require a critic server.

Code Reference

Source Location

  • Repository: NeMo Aligner
  • File: nemo_aligner/algorithms/reinforce.py
  • Lines: L126-599

Signature

class ReinforceTrainer:
    def __init__(
        self,
        cfg: DictConfig,
        model: MegatronGPTReinforceActorModel,
        optimizer,
        scheduler,
        train_dataloader_builder: Callable,
        val_dataloader_builder: Callable,
        collate_fn,
        rm: RemoteGPTRMClient,
        batch_iterator_cls,
        logger,
        ckpt_callback,
        run_timer,
    ):
        ...

    def fit(self) -> None:
        """Main REINFORCE training loop."""

    def generate_reinforce_data(self, rollout_batch) -> Tuple[dict, dict]:
        """Generate rollout, get rewards, compute baseline."""

    def generate_rollouts(self) -> Tuple[dict, dict, dict]:
        """Full rollout generation with metrics."""

Import

from nemo_aligner.algorithms.reinforce import ReinforceTrainer

I/O Contract

Inputs

  • cfg (DictConfig, required): REINFORCE config, including initial_policy_kl_penalty, discount_factor, reinforce_epochs
  • model (MegatronGPTReinforceActorModel, required): REINFORCE actor model
  • rm (RemoteGPTRMClient, required): remote reward model client
  • train_dataloader_builder (Callable, required): factory for prompt dataloaders
  • batch_iterator_cls (type, required): batch iteration helper

Outputs

  • (side effect): updated actor weights and saved checkpoints
  • metrics (Dict): per-step rewards_with_kl, init_policy_kl, baseline, actor_loss
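The rewards_with_kl and init_policy_kl metrics relate to the initial_policy_kl_penalty config key roughly as follows. This is a hedged sketch using a simple sample-based KL estimate (sum of log-probability differences over the response tokens); the exact estimator and names inside NeMo Aligner may differ.

```python
def kl_shaped_reward(reward, actor_logprobs, init_logprobs, kl_penalty):
    """Return (rewards_with_kl, init_policy_kl) for one response.

    init_policy_kl is a sample-based estimate of KL(actor || initial policy)
    over the response tokens; kl_penalty corresponds to the
    initial_policy_kl_penalty config key. Illustrative, not the actual code.
    """
    init_policy_kl = sum(a - i for a, i in zip(actor_logprobs, init_logprobs))
    return reward - kl_penalty * init_policy_kl, init_policy_kl
```

A larger initial_policy_kl_penalty therefore trades reward-model score against staying close to the initial policy.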

Usage Examples

from nemo_aligner.algorithms.reinforce import ReinforceTrainer

reinforce_trainer = ReinforceTrainer(
    cfg=cfg.trainer.reinforce,
    model=actor_model,
    optimizer=optimizer,
    scheduler=scheduler,
    train_dataloader_builder=train_dataloader_builder,
    val_dataloader_builder=val_dataloader_builder,
    collate_fn=collate_fn,
    rm=rm_client,
    batch_iterator_cls=batch_iterator_cls,
    logger=logger,
    ckpt_callback=ckpt_callback,
    run_timer=timer,
)
reinforce_trainer.fit()

Related Pages

Knowledge Sources

Reinforcement_Learning, NLP
