
Implementation:Alibaba ROLL MegatronTrainStrategy Train Step

From Leeroopedia


Knowledge Sources
Domains Distributed_Training, Reinforcement_Learning
Last Updated 2026-02-07 20:00 GMT

Overview

A concrete, Megatron-Core-based distributed training step for policy optimization, provided by the Alibaba ROLL library.

Description

The MegatronTrainStrategy.train_step method executes a single training step using Megatron-Core's distributed training infrastructure. It handles micro-batch splitting, forward-backward pipeline scheduling, gradient accumulation, optimizer stepping with memory offloading, and metric aggregation. This is a wrapper around Megatron-Core's native training loop adapted for RL loss functions.
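The control flow described above can be sketched in plain Python. This is an illustrative outline only: the helper names (split_into_micro_batches, train_step_sketch), the micro-batch size, and the metric handling below are hypothetical stand-ins, not the real ROLL or Megatron-Core API.

```python
# Illustrative sketch of the train_step control flow: micro-batch
# splitting, loss evaluation per micro-batch (where gradients would
# accumulate), a single optimizer step, and metric aggregation.
# All names here are hypothetical, not ROLL's actual internals.

def split_into_micro_batches(batch, micro_batch_size):
    """Split a list-of-samples batch into fixed-size micro-batches."""
    return [batch[i:i + micro_batch_size]
            for i in range(0, len(batch), micro_batch_size)]

def train_step_sketch(batch, loss_func, micro_batch_size=2):
    micro_batches = split_into_micro_batches(batch, micro_batch_size)
    accumulated = []
    for mb in micro_batches:
        # Forward-backward on one micro-batch; in the real strategy,
        # gradients accumulate across micro-batches before a single
        # optimizer step (with optimizer state offloading).
        loss, metrics = loss_func(mb)
        accumulated.append(metrics)
    # (optimizer step and grad-norm computation would happen here)
    # Aggregate per-micro-batch metrics by averaging.
    keys = accumulated[0].keys()
    return {k: sum(m[k] for m in accumulated) / len(accumulated)
            for k in keys}
```

For example, a global batch of 8 samples with a micro-batch size of 2 yields 4 forward-backward passes whose metrics are averaged into one metrics dict.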

Usage

This strategy is the default training backend for large-scale RLVR training with tensor and pipeline parallelism. It is used when the worker configuration specifies Megatron-Core as the training strategy.
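As a rough illustration, a worker configuration selecting this backend might look like the following YAML fragment. The key names and values here are assumptions for illustration; consult the ROLL example configs for the exact schema.

```yaml
# Hypothetical worker config fragment (key names are illustrative):
actor_train:
  strategy_args:
    strategy_name: megatron_train
    strategy_config:
      tensor_model_parallel_size: 2
      pipeline_model_parallel_size: 1
```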

Code Reference

Source Location

  • Repository: Alibaba ROLL
  • File: roll/distributed/strategy/megatron_strategy.py
  • Lines: L1078-1220

Signature

def train_step(
    self,
    batch: DataProto,
    loss_func: Callable
) -> Dict[str, float]:
    """
    Execute single training step with Megatron-LM distributed training.

    Args:
        batch: DataProto containing model inputs (input_ids, attention_mask,
               advantages, old_log_probs, response_mask)
        loss_func: Loss computation function (PPO/GRPO/Reinforce++ loss)

    Returns:
        Dict[str, float]: Training metrics including grad_norm,
                          MOE losses (if applicable), MTP losses (if applicable)
    """

Import

from roll.distributed.strategy.megatron_strategy import MegatronTrainStrategy

I/O Contract

Inputs

Name       Type       Required  Description
batch      DataProto  Yes       Training batch with advantages, old_log_probs, response_mask, input_ids
loss_func  Callable   Yes       PPO/GRPO/Reinforce++ loss function
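To illustrate the loss_func side of this contract, here is a minimal per-token PPO clipped-surrogate sketch in plain Python. It is a simplification under stated assumptions: real ROLL loss functions operate on batched tensors, and the exact masking, KL estimator, and metric names may differ.

```python
import math

def ppo_clip_loss_sketch(log_probs, old_log_probs, advantages,
                         response_mask, clip_eps=0.2):
    """Minimal per-token PPO clipped-surrogate loss (illustrative only).

    Inputs are flat lists of per-token values; response_mask marks
    which tokens contribute (1) and which are ignored (0).
    """
    losses, clipped_flags, kls = [], [], []
    for lp, old_lp, adv, m in zip(log_probs, old_log_probs,
                                  advantages, response_mask):
        if not m:
            continue
        ratio = math.exp(lp - old_lp)            # importance ratio
        unclipped = ratio * adv
        clipped = max(min(ratio, 1 + clip_eps), 1 - clip_eps) * adv
        losses.append(-min(unclipped, clipped))  # maximize surrogate
        clipped_flags.append(clipped < unclipped)  # clipping was active
        kls.append(old_lp - lp)                  # crude approx-KL
    n = len(losses)
    return {
        "actor/pg_loss": sum(losses) / n,
        "actor/clipfrac": sum(clipped_flags) / n,
        "actor/approxkl": sum(kls) / n,
    }
```

With identical new and old log-probs the ratio is 1 everywhere, so no token is clipped and the approximate KL is zero.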

Outputs

Name     Type              Description
metrics  Dict[str, float]  Training metrics: grad_norm, actor/pg_loss, actor/clipfrac, actor/approxkl

Usage Examples

Training Step in Pipeline

# Called internally by the RLVR pipeline via cluster dispatch:
# actor_train.execute_all_sync("train_step", batch)

# The worker's train_step delegates to MegatronTrainStrategy internally:
metrics = megatron_strategy.train_step(
    batch=data_with_advantages,
    loss_func=actor_worker.loss_func
)

# Returned metrics
print(metrics)
# {'grad_norm': 1.23, 'actor/pg_loss': 0.05, 'actor/clipfrac': 0.12}
