Implementation Details

| Field | Value |
| --- | --- |
| Name | CriticServerTrainer_Run |
| Type | API Doc |
| Implements Principle | Critic_Server_Deployment |
| Module | nemo_aligner.algorithms |
| Repository | NeMo Aligner |
| Last Updated | 2026-02-07 00:00 GMT |
Overview
Serves a combined critic and reward model as a trainable PyTriton HTTP service for PPO training. Provided by the NeMo Aligner algorithms module.
Description
The CriticServerTrainer class implements a PyTriton server that exposes three endpoints: (1) infer, which returns critic values and, optionally, reward model scores; (2) train, which accepts returns and masks and updates the critic weights; and (3) save, which persists a checkpoint. Rank 0 runs the Triton HTTP server while the other ranks run a subscriber loop for synchronized distributed operations. When the reward model and critic share the same GPU allocation, the class supports combined RM+critic inference for efficiency.
Usage
Used exclusively in PPO training via serve_ppo_critic.py. The actor process communicates with this server through RemoteGPTRMCriticClient.
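The three-endpoint contract can be sketched as a plain dispatch table. The handlers below are hypothetical stand-ins that only mirror the documented input/output shapes and dtypes; they are not the NeMo Aligner implementation, which registers its handlers with PyTriton on rank 0.

```python
import numpy as np


def infer(tokens: np.ndarray, sequence_lengths: np.ndarray) -> dict:
    """Stand-in for the infer endpoint: per-position critic values,
    plus optional per-sequence reward model scores."""
    batch, seq = tokens.shape
    return {
        "values": np.zeros((batch, seq), dtype=np.float32),  # critic values
        "rewards": np.zeros((batch,), dtype=np.float32),     # optional RM scores
    }


def train(tokens, returns, prev_values, mask) -> dict:
    """Stand-in for the train endpoint: returns the mean critic loss."""
    return {"loss_mean": np.float32(0.0)}


def save() -> dict:
    """Stand-in for the save endpoint: checkpointing is a side effect."""
    return {}


ENDPOINTS = {"infer": infer, "train": train, "save": save}
```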
Code Reference
Source Location
- Repository: NeMo Aligner
- File:
nemo_aligner/algorithms/critic_server_trainer.py
- Lines: L49-381
Signature
class CriticServerTrainer:
    def __init__(
        self,
        cfg: DictConfig,
        model,  # MegatronGPTCriticModel
        optimizer,
        scheduler,
        logger,
        ckpt_callback,
        tokenize_func: Callable,
        gbs: int,  # global batch size
        model_forward_micro_batch_size: int,
    ):
        ...

    def run_server(self) -> None:
        """Start PyTriton server with infer/train/save endpoints."""

    def run_inference(self, **inputs) -> Dict[str, np.ndarray]:
        """Return values (and optionally rewards)."""

    def run_training(self, **inputs) -> Dict[str, np.ndarray]:
        """Update critic weights from returns, previous values, and masks."""
Import
from nemo_aligner.algorithms.critic_server_trainer import CriticServerTrainer
I/O Contract
Inputs (run_inference)
| Name | Type | Required | Description |
| --- | --- | --- | --- |
| tokens | np.int64 | Yes | Token sequences (prompt + response) |
| sequence_lengths | np.int64 | Yes | Length of each token sequence |
Outputs (run_inference)
| Name | Type | Description |
| --- | --- | --- |
| values | np.float32 | Critic value estimates per token position |
| rewards | np.float32 | (Optional) Reward model scores if combine_rm_and_critic_server=True |
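A minimal sketch of assembling a valid run_inference payload with the dtypes above. The pad token id and the example sequences are illustrative assumptions; only the dtypes and the pad-to-max-length layout follow the contract.

```python
import numpy as np

pad_id = 0  # assumed pad token id
prompts = [[5, 6, 7, 8], [5, 6]]  # two token sequences of unequal length

# Right-pad every sequence to the batch maximum.
max_len = max(len(p) for p in prompts)
tokens = np.full((len(prompts), max_len), pad_id, dtype=np.int64)
for i, p in enumerate(prompts):
    tokens[i, : len(p)] = p

# True (unpadded) length of each sequence.
sequence_lengths = np.array([len(p) for p in prompts], dtype=np.int64)
```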
Inputs (run_training)
| Name | Type | Required | Description |
| --- | --- | --- | --- |
| tokens | np.int64 | Yes | Token sequences |
| returns | np.float32 | Yes | Computed returns for critic targets |
| prev_values | np.float32 | Yes | Previous value estimates |
| mask | np.float32 | Yes | Response token mask |
Outputs (run_training)
| Name | Type | Description |
| --- | --- | --- |
| loss_mean | np.float32 | Critic training loss |
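The response token mask marks which positions contribute to the critic loss. A sketch, assuming the mask is 1.0 over response tokens and 0.0 over prompt and padding positions (the exact convention is defined by the server; the lengths here are illustrative):

```python
import numpy as np

seq_len = 8
prompt_lengths = np.array([3, 5], dtype=np.int64)    # assumed per-sample prompt lengths
sequence_lengths = np.array([6, 8], dtype=np.int64)  # prompt + response lengths

# Mask position j of sample i iff prompt_lengths[i] <= j < sequence_lengths[i].
positions = np.arange(seq_len)[None, :]              # shape (1, seq_len)
mask = (
    (positions >= prompt_lengths[:, None]) & (positions < sequence_lengths[:, None])
).astype(np.float32)
# row 0 masks positions 3..5, row 1 masks positions 5..7
```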
Usage Examples
from nemo_aligner.algorithms.critic_server_trainer import CriticServerTrainer

critic_trainer = CriticServerTrainer(
    cfg=cfg.trainer.ppo,
    model=critic_model,
    optimizer=optimizer,
    scheduler=scheduler,
    logger=logger,
    ckpt_callback=ckpt_callback,
    tokenize_func=tokenize_func,
    gbs=cfg.model.global_batch_size,
    model_forward_micro_batch_size=cfg.model.forward_micro_batch_size,
)
critic_trainer.run_server()  # Blocks: serves until training completes
Related Pages
Knowledge Sources
Reinforcement_Learning, Distributed_Systems