| Implementation Details | |
| --- | --- |
| Name | MegatronGPT_Actor_And_Critic_Client |
| Type | API Doc |
| Implements Principle | PPO_Actor_Critic_Setup |
| Module | nemo_aligner.models.nlp.gpt |
| Repository | NeMo Aligner |
| Last Updated | 2026-02-07 00:00 GMT |
Overview
Concrete tooling from the NeMo Aligner models module for PPO actor model initialization and for HTTP communication with a remote critic/reward model.
Description
MegatronGPTActorModel extends MegatronGPTModel with PPO-specific capabilities: text generation (infer), log-probability computation, entropy calculation, and the PPO clipped ratio loss function. It implements AlignableGenerativeInterface for integration with the PPO training loop. RemoteGPTRMCriticClient provides HTTP communication with the critic server via PyTriton FuturesModelClient, supporting inference (get values + rewards), training (send returns for critic update), and checkpoint saving.
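The clipped ratio loss mentioned above follows the standard PPO surrogate objective. A minimal scalar sketch of that objective (the function name and plain-Python lists are illustrative, not the model's internal tensor API):

```python
import math

def ppo_clipped_loss(log_probs, old_log_probs, advantages, eps=0.2):
    """Per-token PPO clipped surrogate loss, averaged over tokens (sketch)."""
    losses = []
    for lp, olp, adv in zip(log_probs, old_log_probs, advantages):
        # Probability ratio pi_new / pi_old, computed from log-probs.
        ratio = math.exp(lp - olp)
        # Clip the ratio to [1 - eps, 1 + eps] to bound the policy update.
        clipped = max(min(ratio, 1 + eps), 1 - eps)
        # Maximize the pessimistic (min) surrogate => minimize its negation.
        losses.append(-min(ratio * adv, clipped * adv))
    return sum(losses) / len(losses)
```

When the policy has not moved (ratio = 1), the loss reduces to the negated mean advantage; once the ratio leaves the clip band, the gradient through that token is cut off.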
Usage
Used in PPO training scripts. The actor model is loaded from a pretrained checkpoint; the critic client connects to a critic server running on a separate GPU allocation.
Code Reference
Source Location
- Repository: NeMo Aligner
- Files:
  - nemo_aligner/models/nlp/gpt/megatron_gpt_ppo_actor.py (L64-417)
  - nemo_aligner/models/nlp/gpt/reward_critic_clients.py (L100-182)
Signature
class MegatronGPTActorModel(NLPAdapterModelMixin, MegatronGPTModel, AlignableGenerativeInterface):
    def __init__(self, cfg: DictConfig, trainer: Trainer):
        ...

    def infer(self, inference_batch: dict) -> dict:
        """Generate responses. Returns response_tokens, response_lengths, prompt_lengths, is_end."""

    def get_inference_log_probs(self, response_tokens, forward_micro_batch_size) -> Tensor:
        """Compute log probabilities for generated tokens."""

class RemoteGPTRMCriticClient:
    def __init__(self, cfg: DictConfig):
        ...

    def infer_rm_critic(self, rollout_batch: dict) -> RMCriticFutureResult:
        """Get rewards and values from the remote server."""

    def train(self, rollout_batch: dict) -> None:
        """Send training data to the critic server for a weight update."""

    def save(self) -> None:
        """Trigger a checkpoint save on the critic server."""
Import
from nemo_aligner.models.nlp.gpt.megatron_gpt_ppo_actor import MegatronGPTActorModel
from nemo_aligner.models.nlp.gpt.reward_critic_clients import RemoteGPTRMCriticClient
I/O Contract
Inputs (MegatronGPTActorModel.infer)
| Name | Type | Required | Description |
| --- | --- | --- | --- |
| inference_batch | dict | Yes | Dict with prompt token tensors and generation config |
Outputs (MegatronGPTActorModel.infer)
| Name | Type | Description |
| --- | --- | --- |
| response_tokens | Tensor | Generated full sequences (prompt + response) |
| response_lengths | Tensor | Total sequence lengths |
| prompt_lengths | Tensor | Prompt-only lengths |
| is_end | Tensor | Whether generation reached EOS |
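Together, these outputs are enough to slice the generated continuation out of each padded prompt+response row. A minimal sketch using plain lists in place of tensors (the helper name is hypothetical):

```python
def extract_responses(response_tokens, response_lengths, prompt_lengths):
    """Slice each generated continuation out of the padded full sequences (sketch)."""
    responses = []
    for row, total, prompt in zip(response_tokens, response_lengths, prompt_lengths):
        # Tokens [prompt:total] are the response; everything past `total` is padding.
        responses.append(row[prompt:total])
    return responses

# Two padded rows: prompts of length 2 and 1, total lengths 4 and 3.
tokens = [[5, 6, 7, 8, 0], [9, 10, 11, 0, 0]]
resp = extract_responses(tokens, [4, 3], [2, 1])
```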
Inputs (RemoteGPTRMCriticClient)
| Name | Type | Required | Description |
| --- | --- | --- | --- |
| cfg | DictConfig | Yes | Server connection config with IP and port |
Outputs (RemoteGPTRMCriticClient.infer_rm_critic)
| Name | Type | Description |
| --- | --- | --- |
| rewards | np.ndarray | Reward model scores |
| values | np.ndarray | Critic value estimates |
Usage Examples
from nemo_aligner.models.nlp.gpt.megatron_gpt_ppo_actor import MegatronGPTActorModel
from nemo_aligner.models.nlp.gpt.reward_critic_clients import RemoteGPTRMCriticClient
# Load actor model
actor_model = load_from_nemo(MegatronGPTActorModel, model_cfg, trainer, restore_path=path)
# Create critic client
rm_critic = RemoteGPTRMCriticClient(cfg.remote_critic_rm)
# Generate responses
rollout = actor_model.infer(prompt_batch)
# Get rewards and values
result = rm_critic.infer_rm_critic(rollout)
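Because the client is built on PyTriton's FuturesModelClient, infer_rm_critic returns a future rather than blocking, so scoring can overlap other work until the result is needed. A stand-in sketch of that pattern using concurrent.futures (MockCriticClient and its scoring are illustrative, not the real client):

```python
from concurrent.futures import ThreadPoolExecutor

class MockCriticClient:
    """Stand-in for RemoteGPTRMCriticClient's future-based interface (sketch)."""

    def __init__(self):
        self._pool = ThreadPoolExecutor(max_workers=1)

    def infer_rm_critic(self, rollout_batch):
        # Return immediately with a future; the caller overlaps other work
        # (e.g. log-prob computation) and calls .result() only when the
        # rewards and values are actually needed.
        return self._pool.submit(self._score, rollout_batch)

    def _score(self, rollout_batch):
        # Placeholder scoring: one reward and one value per sequence.
        n = len(rollout_batch["response_tokens"])
        rewards = [0.0] * n
        values = [[0.0] for _ in range(n)]
        return rewards, values

client = MockCriticClient()
future = client.infer_rm_critic({"response_tokens": [[1, 2], [3, 4]]})
# ... other rollout work happens here ...
rewards, values = future.result()
```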
Related Pages
Knowledge Sources
Reinforcement_Learning, NLP