
Implementation:Alibaba ROLL McaTrainer

From Leeroopedia


Knowledge Sources
Domains Training, Distributed_Computing
Last Updated 2026-02-07 20:00 GMT

Overview

Advanced PyTorch trainer for Megatron-Core distributed training supporting model/pipeline/expert parallelism, sequence packing, per-token loss, and distributed checkpoint management.

Description

trainer.py implements McaTrainer, a comprehensive training loop that extends HuggingFace's Trainer class to work with Megatron-Core's distributed primitives. It replaces the standard single-GPU training step with Megatron-Core's forward_backward_func pipeline scheduler, which handles micro-batching, pipeline-parallel communication, and gradient accumulation across distributed ranks.

Core architecture:

  • Distributed Data Parallel (DDP): Wraps each model chunk in Megatron-Core's DistributedDataParallel with configurable grad reduction, distributed optimizer support, and parameter gather overlap.
  • Pipeline scheduling: Uses get_forward_backward_func() to dispatch forward/backward passes across pipeline stages with proper micro-batch interleaving.
  • Sequence packing: Supports packing multiple sequences into a single batch using PackedSeqParams with THD format for efficient attention computation.
  • Context parallelism: Integrates with Megatron-Core's context parallel infrastructure for sequence splitting across CP ranks.
  • Per-token loss: Optionally computes loss normalized by actual non-padded token count rather than batch size, with proper all-reduce across context parallel groups.
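The per-token normalization above can be sketched in plain Python. This is an illustrative toy, not the trainer's code: the real implementation operates on tensors and all-reduces the token count across context-parallel ranks.

```python
# Hedged sketch: per-token loss normalization. Padded positions are
# excluded from both the numerator and the denominator, unlike
# per-sample normalization which divides by batch size.

def per_token_loss(token_losses, pad_mask):
    """Average loss over non-padded tokens only.

    token_losses: per-token loss values for one micro-batch.
    pad_mask: 1 for real tokens, 0 for padding (same length).
    """
    total = sum(l * m for l, m in zip(token_losses, pad_mask))
    n_tokens = sum(pad_mask)
    return total / n_tokens

losses = [2.0, 1.0, 3.0, 0.0]   # value at the padded position is ignored
mask   = [1, 1, 1, 0]
print(per_token_loss(losses, mask))  # 2.0
```

Dividing by the true token count keeps gradients comparable across micro-batches whose padding ratios differ.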

Training loop (_inner_training_loop, lines 789-925):

  1. Calls _prepare_train_loop to set up dataloaders, optimizer, scheduler, and state
  2. Iterates over epochs with cyclic data iteration
  3. For each step, collects micro-batches via _get_step_iterator_and_seq_length
  4. Calls training_step which runs the forward-backward pipeline schedule
  5. Handles gradient norm tracking, skipped iterations, MoE auxiliary losses, and MTP losses
  6. Manages checkpointing, evaluation, and logging through callback handlers
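The control flow of steps 2-4 can be sketched with stub functions. Names like `train` and the accumulation-window logic here are illustrative only, not the actual McaTrainer API; the real loop delegates micro-batch scheduling to `forward_backward_func`.

```python
# Hedged skeleton of the epoch / micro-batch / optimizer-step structure
# described above, with stubs in place of the pipeline schedule.

def train(data, accum_steps, num_epochs):
    optimizer_steps = 0
    for _ in range(num_epochs):
        it = iter(data)
        while True:
            # Step 3: collect one window of micro-batches.
            micro_batches = []
            try:
                for _ in range(accum_steps):
                    micro_batches.append(next(it))
            except StopIteration:
                break  # drop the incomplete accumulation window
            # Step 4: forward-backward over the window, then one step.
            for mb in micro_batches:
                pass  # a real forward_backward(mb) would accumulate grads
            optimizer_steps += 1  # optimizer.step() + lr_scheduler.step()
    return optimizer_steps

# 10 samples, accumulation of 4 -> 2 optimizer steps per epoch.
print(train(list(range(10)), accum_steps=4, num_epochs=2))  # 4
```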

Optimizer and scheduler:

  • Uses Megatron-Core's get_megatron_optimizer which supports distributed optimizer (ZeRO-like sharding), CPU offloading, and proper gradient clipping
  • Custom LR scheduler via get_megatron_lr_scheduler
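For intuition, a warmup-then-cosine schedule of the shape Megatron-style LR schedulers implement can be written in a few lines. The function name and constants below are illustrative, not `get_megatron_lr_scheduler` itself.

```python
import math

# Hedged sketch: linear warmup to max_lr, then cosine decay to min_lr.

def lr_at(step, max_lr, min_lr, warmup_steps, total_steps):
    if step < warmup_steps:
        return max_lr * (step + 1) / warmup_steps      # linear warmup
    # Cosine decay from max_lr to min_lr over the remaining steps.
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return min_lr + 0.5 * (max_lr - min_lr) * (1 + math.cos(math.pi * progress))

print(lr_at(0, 1e-3, 1e-5, warmup_steps=10, total_steps=100))    # max_lr / 10
print(lr_at(100, 1e-3, 1e-5, warmup_steps=10, total_steps=100))  # min_lr
```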

Checkpoint management:

  • Model checkpoints use Megatron-Core's distributed checkpointing with fully parallel save/load strategies
  • Optimizer state saved with FullyParallelSaveStrategyWrapper for efficient distributed I/O
  • RNG state saved per-process for reproducible training resume
  • PEFT adapter checkpoints handled separately with adapter config detection
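The per-process RNG snapshotting that makes resume reproducible can be demonstrated with Python's `random` module; the real trainer snapshots the torch CPU and CUDA generators per rank, but the save/restore pattern is the same.

```python
import random

# Hedged sketch: checkpoint the RNG state, keep sampling, then restore
# the state and verify the resumed run replays the identical stream.

random.seed(1234)
_ = [random.random() for _ in range(3)]   # some training progress

state = random.getstate()                 # "checkpoint" the RNG
after_ckpt = [random.random() for _ in range(2)]

random.setstate(state)                    # "resume" from checkpoint
resumed = [random.random() for _ in range(2)]

print(after_ckpt == resumed)  # True
```

Without restoring generator state, data shuffling and dropout after resume would diverge from the uninterrupted run.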

Usage

Use McaTrainer as the primary trainer for any Megatron-Core based training job. It is instantiated in the training entry points (pt_mca_train, sft_mca_train) with a VirtualModels instance, training arguments, tokenizer, and data collator.

Code Reference

Source Location

Signature

class McaTrainer(Trainer):
    metrics_keys = ["loss"]
    _language_input_names = ["input_ids", "attention_mask", "labels", "position_ids"]
    ckpt_sharding_type = "fully_sharded_model_space"

Key Methods

__init__

def __init__(
    self,
    model: "VirtualModels" = None,
    args: TrainingArguments = None,
    **kwargs,
)  # lines 85-117

Initializes the trainer by patching PyTorch shard operations, calling initialize_megatron, and constructing the HF Trainer superclass (passing model[0] so the constructor receives a single module), then replacing self.model with the full VirtualModels. Wraps models in DDP, sets up forward_backward_func, and optionally configures distributed optimizer save strategies.

_prepare_model

def _prepare_model(self, models: "VirtualModels") -> List["DistributedDataParallel"]  # lines 119-140

Wraps each model chunk in Megatron-Core DistributedDataParallel with configurable grad reduce settings. Disables bucketing for model chunks beyond the first (communication is already overlapped with compute).

get_train_dataloader

def get_train_dataloader(self) -> DataLoader  # lines 167-198

Creates a training DataLoader distributed across data-parallel ranks using prepare_data_loader from Accelerate. Always drops the last incomplete batch. Uses Megatron-Core's mpu for world size and rank.
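The effect of data-parallel sharding with drop_last can be sketched on raw indices. This is illustrative only; the real code uses Accelerate's prepare_data_loader with Megatron-Core's mpu rank and world size rather than manual index math.

```python
# Hedged sketch: split samples across data-parallel ranks, dropping the
# last incomplete global batch so every rank sees the same step count.

def shard_indices(n_samples, batch_size, dp_rank, dp_world_size):
    global_batch = batch_size * dp_world_size
    usable = (n_samples // global_batch) * global_batch  # drop_last
    out = []
    for start in range(0, usable, global_batch):
        lo = start + dp_rank * batch_size  # this rank's slice of the batch
        out.extend(range(lo, lo + batch_size))
    return out

# 10 samples, batch 2, 2 ranks -> 8 usable samples, 2 dropped.
print(shard_indices(10, 2, dp_rank=0, dp_world_size=2))  # [0, 1, 4, 5]
print(shard_indices(10, 2, dp_rank=1, dp_world_size=2))  # [2, 3, 6, 7]
```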

_prepare_train_inputs

def _prepare_train_inputs(self, data_iterator: Iterator) -> Dict[str, Tensor | Any]  # lines 231-254

Prepares a single micro-batch: applies sequence packing or constructs left-to-right attention masks and position IDs, then slices the batch for context parallelism.
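The unpacked path's mask and position-ID construction can be sketched on plain lists. This is a toy under assumed conventions (padding positions receive position 0); the real code builds tensors, and context-parallel slicing happens afterwards.

```python
# Hedged sketch: position IDs over real tokens and a causal
# ("left-to-right") attention mask for one short sequence.

def make_position_ids(pad_mask):
    """0,1,2,... over real tokens; 0 at padded positions (assumed)."""
    pos, nxt = [], 0
    for m in pad_mask:
        pos.append(nxt if m else 0)
        nxt += m
    return pos

def make_causal_mask(n):
    """mask[i][j] == 1 iff position i may attend to position j <= i."""
    return [[1 if j <= i else 0 for j in range(n)] for i in range(n)]

print(make_position_ids([1, 1, 1, 0]))  # [0, 1, 2, 0]
print(make_causal_mask(3))  # [[1, 0, 0], [1, 1, 0], [1, 1, 1]]
```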

_packing_sequence

def _packing_sequence(self, inputs: Dict[str, Tensor | Any])  # lines 289-324

Packs multiple sequences into a single batch using THD format. Computes cumulative sequence lengths, creates PackedSeqParams, and reshapes inputs to [1, total_tokens, ...]. Requires Transformer Engine implementation.
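The cumulative-sequence-length bookkeeping behind THD packing (the quantity Transformer Engine attention kernels consume as cu_seqlens) reduces to a prefix sum; PackedSeqParams itself is not constructed in this sketch.

```python
# Hedged sketch: boundaries of each packed sequence within the single
# [1, total_tokens, ...] row that _packing_sequence produces.

def cu_seqlens(seq_lengths):
    """Prefix sums with a leading 0: sequence i spans out[i]..out[i+1]."""
    out = [0]
    for n in seq_lengths:
        out.append(out[-1] + n)
    return out

# Three sequences packed into one row of total_tokens = 10.
lens = [3, 5, 2]
print(cu_seqlens(lens))      # [0, 3, 8, 10]
print(cu_seqlens(lens)[-1])  # 10 == total packed tokens
```

Attention kernels use these boundaries to keep tokens from different packed sequences from attending to each other.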

training_step

def training_step(
    self,
    models: List[DistributedDataParallel],
    data_iterator,
    seq_length,
)  # lines 417-454

Executes a complete training step including forward-backward pass via forward_backward_func, optimizer step, and LR scheduler step. Returns (loss, metrics_tensors, skipped_iter, grad_norm, num_zeros_in_grad).

_inner_training_loop

def _inner_training_loop(
    self,
    batch_size=None,
    args=None,
    resume_from_checkpoint=None,
    trial=None,
    ignore_keys_for_eval=None,
)  # lines 789-925

Main training loop. Manages epoch iteration, data skipping for resume, cyclic data iteration, gradient accumulation, loss tracking (including per-token loss), and final metrics reporting.

create_optimizer

def create_optimizer(self)  # lines 482-501

Creates a Megatron-Core optimizer with OptimizerConfig from training arguments. Supports Adam/SGD, distributed optimizer, CPU offloading, and gradient clipping.

create_scheduler

def create_scheduler(self, num_training_steps: int, optimizer=None)  # lines 503-506

Creates a Megatron-compatible LR scheduler via get_megatron_lr_scheduler.

_load_from_checkpoint

def _load_from_checkpoint(self, resume_from_checkpoint, model=None)  # lines 508-533

Loads model weights from checkpoint. Handles both PEFT adapter checkpoints (detecting adapter subdirectories with adapter_config.json) and standard MCA checkpoints.
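The adapter-detection step can be sketched with the standard library. Directory names below are illustrative; the assumed convention is simply that a subdirectory containing adapter_config.json marks a PEFT adapter checkpoint.

```python
import json
import os
import tempfile

# Hedged sketch: classify a checkpoint directory by scanning its
# subdirectories for an adapter_config.json file.

def find_adapter_dirs(ckpt_dir):
    return sorted(
        name for name in os.listdir(ckpt_dir)
        if os.path.isfile(os.path.join(ckpt_dir, name, "adapter_config.json"))
    )

with tempfile.TemporaryDirectory() as ckpt:
    os.makedirs(os.path.join(ckpt, "default"))
    with open(os.path.join(ckpt, "default", "adapter_config.json"), "w") as f:
        json.dump({"peft_type": "LORA"}, f)
    os.makedirs(os.path.join(ckpt, "iter_0000100"))  # non-adapter subdir
    print(find_adapter_dirs(ckpt))  # ['default']
```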

evaluation_loop

@torch.no_grad()
def evaluation_loop(
    self,
    dataloader: DataLoader,
    description: str,
    prediction_loss_only: Optional[bool] = None,
    ignore_keys: Optional[List[str]] = None,
    metric_key_prefix: str = "eval",
) -> EvalLoopOutput  # lines 1010-1046

Forward-only evaluation loop using the pipeline scheduler. Streams evaluation inputs, runs forward passes, gathers metrics across distributed ranks, and returns EvalLoopOutput.

save_model

def save_model(self, output_dir: str = None, _internal_call: bool = False)  # lines 1082-1091

Saves model in MCA format (unless save_only_model and save_hf_model are set), optionally exports to HF format, and saves tokenizer and training arguments.

Import

import torch
from megatron.core import dist_checkpointing, mpu, tensor_parallel
from megatron.core.distributed import DistributedDataParallel, finalize_model_grads
from megatron.core.optimizer import OptimizerConfig, get_megatron_optimizer
from megatron.core.pipeline_parallel import get_forward_backward_func
from transformers.trainer import Trainer

from mcore_adapter.trainer.trainer import McaTrainer

I/O Contract

Inputs

  • model (VirtualModels, required): Model container with virtual pipeline model chunks
  • args (TrainingArguments, required): Training arguments with parallelism, optimizer, and training configs
  • tokenizer (PreTrainedTokenizerBase, optional): Tokenizer for padding and data collation
  • data_collator (callable, optional): Function to collate batch samples
  • train_dataset (Dataset, optional): Training dataset
  • eval_dataset (Dataset, optional): Evaluation dataset
  • resume_from_checkpoint (str, optional): Path to checkpoint directory for resuming training

Outputs

  • TrainOutput: Contains global_step, training_loss, and the metrics dict
  • EvalLoopOutput: Contains the metrics dict and num_samples (predictions are None)

Usage Examples

from mcore_adapter.models import AutoModel
from mcore_adapter.trainer import McaTrainer
from mcore_adapter.training_args import Seq2SeqTrainingArguments

# Create training arguments
args = Seq2SeqTrainingArguments(
    output_dir="/path/to/output",
    tensor_model_parallel_size=2,
    pipeline_model_parallel_size=2,
    use_distributed_optimizer=True,
    overlap_grad_reduce=True,
    bf16=True,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=2,
    num_train_epochs=3,
    sequence_packing=True,
)

# Load model
model = AutoModel.from_pretrained("Qwen/Qwen2.5-7B", args)

# Create trainer
trainer = McaTrainer(
    model=model,
    args=args,
    tokenizer=tokenizer,
    data_collator=data_collator,
    train_dataset=train_dataset,
)

# Train
trainer.train()

# Save model
trainer.save_model("/path/to/output")
