
Implementation:NVIDIA NeMo Aligner CriticServerTrainer Run

From Leeroopedia


Implementation Details
Name CriticServerTrainer_Run
Type API Doc
Implements Principle Critic_Server_Deployment
Module nemo_aligner.algorithms
Repository NeMo Aligner
Last Updated 2026-02-07 00:00 GMT

Overview

A concrete tool, provided by the NeMo Aligner algorithms module, for serving a combined critic and reward model as a trainable PyTriton HTTP service during PPO training.

Description

The CriticServerTrainer class implements a PyTriton server that exposes three endpoints: (1) infer, which returns critic values and, optionally, reward model scores; (2) train, which accepts returns and masks and updates the critic weights; and (3) save, which persists a checkpoint. Rank 0 runs the Triton HTTP server while all other ranks run a subscriber loop so that distributed operations stay synchronized. The class supports combined RM+critic inference for efficiency when both models share the same GPU allocation.
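The rank-0/subscriber split can be sketched as follows. This is a minimal illustration of the pattern, not the actual NeMo Aligner code: `command_queue` stands in for the collective broadcast that rank 0 uses to signal the other ranks, and the handler names are hypothetical.

```python
from queue import Queue

def subscriber_loop(command_queue, handlers):
    """Worker-rank loop: wait for a command broadcast by rank 0, run the
    matching collective handler, and exit when told to stop."""
    handled = []
    while True:
        cmd = command_queue.get()
        if cmd == "stop":
            break
        handlers[cmd]()          # e.g. joins rank 0 in a distributed forward/backward
        handled.append(cmd)
    return handled

# Rank 0 would enqueue one command per HTTP request it serves, then "stop".
q = Queue()
for cmd in ["infer", "train", "save", "stop"]:
    q.put(cmd)

calls = []
handlers = {name: (lambda n=name: calls.append(n)) for name in ("infer", "train", "save")}
handled = subscriber_loop(q, handlers)
```

The real implementation uses torch distributed collectives rather than a queue, but the control flow is the same: worker ranks are driven entirely by commands originating from rank 0's HTTP endpoints.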

Usage

Used exclusively in PPO training via serve_ppo_critic.py. The actor process communicates with this server via RemoteGPTRMCriticClient.
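Because PyTriton is Triton under the hood, the actor's requests follow the standard KServe-v2-style HTTP inference format. The sketch below only assembles an illustrative request body; the model/endpoint name (`critic_infer` here) is hypothetical, and in practice RemoteGPTRMCriticClient handles all of this for you.

```python
import numpy as np

def build_infer_payload(tokens: np.ndarray, sequence_lengths: np.ndarray) -> dict:
    """Assemble a KServe-v2-style inference request body for the infer
    endpoint (model name and exact tensor names are illustrative)."""
    def tensor(name, arr):
        return {
            "name": name,
            "shape": list(arr.shape),
            "datatype": "INT64",          # matches the np.int64 I/O contract below
            "data": arr.ravel().tolist(),
        }
    return {"inputs": [tensor("tokens", tokens),
                       tensor("sequence_lengths", sequence_lengths)]}

tokens = np.array([[101, 7, 8, 0]], dtype=np.int64)   # one sequence, padded to length 4
seq_lens = np.array([[3]], dtype=np.int64)            # only 3 positions are real tokens
payload = build_infer_payload(tokens, seq_lens)       # would be POSTed as JSON
```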

Code Reference

Source Location

  • Repository: NeMo Aligner
  • File: nemo_aligner/algorithms/critic_server_trainer.py
  • Lines: L49-381

Signature

class CriticServerTrainer:
    def __init__(
        self,
        cfg: DictConfig,
        model,                           # MegatronGPTCriticModel
        optimizer,
        scheduler,
        logger,
        ckpt_callback,
        tokenize_func: Callable,
        gbs: int,                        # Global batch size
        model_forward_micro_batch_size: int,
    ):
        ...

    def run_server(self) -> None:
        """Start PyTriton server with infer/train/save endpoints."""

    def run_inference(self, **inputs) -> Dict[str, np.ndarray]:
        """Return values (and optionally rewards)."""

    def run_training(self, **inputs) -> Dict[str, np.ndarray]:
        """Update critic weights using returns and advantages."""

Import

from nemo_aligner.algorithms.critic_server_trainer import CriticServerTrainer

I/O Contract

Inputs (run_inference)

Name Type Required Description
tokens np.int64 Yes Token sequences (prompt + response)
sequence_lengths np.int64 Yes Sequence lengths

Outputs (run_inference)

Name Type Description
values np.float32 Critic value estimates per token position
rewards np.float32 (Optional) Reward model scores if combine_rm_and_critic_server=True
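Putting the two tables together: tokens are padded to a common length, and sequence_lengths tells the consumer which positions of values are valid. A small sketch with illustrative shapes (the arange array stands in for real server output):

```python
import numpy as np

# Batch of 2 sequences padded to length 5; only the first L positions of
# each row carry meaningful critic values.
values = np.arange(10, dtype=np.float32).reshape(2, 5)   # stand-in for "values" output
sequence_lengths = np.array([3, 5], dtype=np.int64)

position = np.arange(values.shape[1])
valid = position[None, :] < sequence_lengths[:, None]    # (2, 5) boolean mask
masked_values = np.where(valid, values, 0.0)             # zero out padding positions
```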

Inputs (run_training)

Name Type Required Description
tokens np.int64 Yes Token sequences
returns np.float32 Yes Computed returns for critic targets
prev_values np.float32 Yes Previous value estimates
mask np.float32 Yes Response token mask
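These inputs map naturally onto the standard PPO clipped value loss, in which prev_values bounds how far the new prediction may move and mask restricts the loss to response tokens. The sketch below is a generic formulation, not NeMo Aligner's exact loss (the clip setting and details live in the source file above); the `clip` value is an assumed default.

```python
import numpy as np

def clipped_value_loss(values, prev_values, returns, mask, clip=0.2):
    """Generic PPO-style critic loss: clip the new value prediction to lie
    within `clip` of the previous one, take the worse (max) of the clipped
    and unclipped squared errors, and average over masked tokens."""
    clipped = prev_values + np.clip(values - prev_values, -clip, clip)
    loss_unclipped = (values - returns) ** 2
    loss_clipped = (clipped - returns) ** 2
    per_token = np.maximum(loss_unclipped, loss_clipped)
    return float((per_token * mask).sum() / mask.sum())

values      = np.array([[0.5, 0.9]], dtype=np.float32)   # new critic predictions
prev_values = np.array([[0.4, 0.4]], dtype=np.float32)   # estimates from rollout time
returns     = np.array([[1.0, 1.0]], dtype=np.float32)   # critic targets
mask        = np.array([[1.0, 1.0]], dtype=np.float32)   # both are response tokens
loss_mean = clipped_value_loss(values, prev_values, returns, mask)
```

The second token shows the clip in action: the prediction moved from 0.4 to 0.9, so the clipped branch evaluates at 0.6 and its larger error dominates.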

Outputs (run_training)

Name Type Description
loss_mean np.float32 Critic training loss

Usage Examples

from nemo_aligner.algorithms.critic_server_trainer import CriticServerTrainer

critic_trainer = CriticServerTrainer(
    cfg=cfg.trainer.ppo,
    model=critic_model,
    optimizer=optimizer,
    scheduler=scheduler,
    logger=logger,
    ckpt_callback=ckpt_callback,
    tokenize_func=tokenize_func,
    gbs=cfg.model.global_batch_size,
    model_forward_micro_batch_size=cfg.model.forward_micro_batch_size,
)
critic_trainer.run_server()  # Blocks: serves until training completes

Related Pages

Knowledge Sources

Reinforcement_Learning, Distributed_Systems
