Implementation: Alibaba ROLL McaTrainer
| Knowledge Sources | |
|---|---|
| Domains | Training, Distributed_Computing |
| Last Updated | 2026-02-07 20:00 GMT |
Overview
Advanced PyTorch trainer for Megatron-Core distributed training supporting model/pipeline/expert parallelism, sequence packing, per-token loss, and distributed checkpoint management.
Description
trainer.py implements McaTrainer, a comprehensive trainer that extends HuggingFace's Trainer class to work with Megatron-Core's distributed primitives. It replaces the standard single-GPU training step with Megatron-Core's forward_backward_func pipeline scheduler, which handles micro-batching, pipeline-parallel communication, and gradient accumulation across distributed ranks.
Core architecture:
- Distributed Data Parallel (DDP): Wraps each model chunk in Megatron-Core's DistributedDataParallel with configurable grad reduction, distributed optimizer support, and parameter gather overlap.
- Pipeline scheduling: Uses get_forward_backward_func() to dispatch forward/backward passes across pipeline stages with proper micro-batch interleaving.
- Sequence packing: Supports packing multiple sequences into a single batch using PackedSeqParams with THD format for efficient attention computation.
- Context parallelism: Integrates with Megatron-Core's context parallel infrastructure for sequence splitting across CP ranks.
- Per-token loss: Optionally computes loss normalized by actual non-padded token count rather than batch size, with proper all-reduce across context parallel groups.
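The per-token normalization in the last bullet can be sketched in a single-rank form (a hedged illustration, not McaTrainer's actual code; in context-parallel training the sum/count pair is additionally all-reduced across the CP group before dividing):

```python
import torch

def per_token_loss(token_losses, labels, ignore_index=-100):
    # Normalize the summed loss by the number of non-padded tokens
    # rather than by batch size. The all-reduce of (sum, count) across
    # the context-parallel group is omitted in this single-rank sketch.
    mask = (labels != ignore_index).to(token_losses.dtype)
    num_tokens = mask.sum().clamp(min=1)
    return (token_losses * mask).sum() / num_tokens
```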
Training loop (_inner_training_loop, lines 789-925):
- Calls _prepare_train_loop to set up dataloaders, optimizer, scheduler, and state
- Iterates over epochs with cyclic data iteration
- For each step, collects micro-batches via _get_step_iterator_and_seq_length
- Calls training_step which runs the forward-backward pipeline schedule
- Handles gradient norm tracking, skipped iterations, MoE auxiliary losses, and MTP losses
- Manages checkpointing, evaluation, and logging through callback handlers
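The step structure of the loop above can be sketched as follows (a hypothetical skeleton with illustrative names, not McaTrainer's actual API): collect grad_accum_steps micro-batches, hand them to the pipeline scheduler as one combined forward/backward, then take one optimizer step.

```python
from typing import Any, Dict, List

def train_loop(data: List[Dict[str, Any]], grad_accum_steps: int) -> int:
    # Hypothetical skeleton of the gradient-accumulation structure.
    optimizer_steps = 0
    it = iter(data)
    while True:
        micro_batches = []
        try:
            for _ in range(grad_accum_steps):
                micro_batches.append(next(it))
        except StopIteration:
            break  # the incomplete trailing accumulation window is dropped
        # forward_backward_func(...) would consume micro_batches here,
        # handling pipeline communication and micro-batch interleaving.
        optimizer_steps += 1  # optimizer.step(); lr_scheduler.step()
    return optimizer_steps
```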
Optimizer and scheduler:
- Uses Megatron-Core's get_megatron_optimizer which supports distributed optimizer (ZeRO-like sharding), CPU offloading, and proper gradient clipping
- Custom LR scheduler via get_megatron_lr_scheduler
Checkpoint management:
- Model checkpoints use Megatron-Core's distributed checkpointing with fully parallel save/load strategies
- Optimizer state saved with FullyParallelSaveStrategyWrapper for efficient distributed I/O
- RNG state saved per-process for reproducible training resume
- PEFT adapter checkpoints handled separately with adapter config detection
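The per-process RNG bookkeeping behind reproducible resume can be illustrated with plain PyTorch (a minimal sketch; the CUDA and tensor-parallel RNG trackers that Megatron-Core also snapshots are omitted):

```python
import torch

# Each rank snapshots its generator state at save time and restores it
# on load, so post-resume random draws replay exactly.
rng_state = {"torch_cpu": torch.get_rng_state()}  # stored in the checkpoint
a = torch.rand(4)                                 # draws after the save...
torch.set_rng_state(rng_state["torch_cpu"])       # ...state restored on resume
b = torch.rand(4)                                 # replays identical draws
assert torch.equal(a, b)
```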
Usage
Use McaTrainer as the primary trainer for any Megatron-Core based training job. It is instantiated in the training entry points (pt_mca_train, sft_mca_train) with a VirtualModels instance, training arguments, tokenizer, and data collator.
Code Reference
Source Location
- Repository: Alibaba_ROLL
- File: mcore_adapter/src/mcore_adapter/trainer/trainer.py
- Lines: 1-1097
Signature
class McaTrainer(Trainer):
metrics_keys = ["loss"]
_language_input_names = ["input_ids", "attention_mask", "labels", "position_ids"]
ckpt_sharding_type = "fully_sharded_model_space"
Key Methods
__init__
def __init__(
self,
model: "VirtualModels" = None,
args: TrainingArguments = None,
**kwargs,
) # lines 85-117
Initializes the trainer by patching PyTorch shard operations, calling initialize_megatron, and constructing the HF Trainer superclass with model[0] as a temporary stand-in before replacing self.model with the full VirtualModels. It then wraps the models in DDP, sets up forward_backward_func, and optionally configures distributed-optimizer save strategies.
_prepare_model
def _prepare_model(self, models: "VirtualModels") -> List["DistributedDataParallel"] # lines 119-140
Wraps each model chunk in Megatron-Core DistributedDataParallel with configurable grad reduce settings. Disables bucketing for model chunks beyond the first (communication is already overlapped with compute).
get_train_dataloader
def get_train_dataloader(self) -> DataLoader # lines 167-198
Creates a training DataLoader distributed across data-parallel ranks using prepare_data_loader from Accelerate. Always drops the last incomplete batch. Uses Megatron-Core's mpu for world size and rank.
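The drop-last sharding across data-parallel ranks can be pictured with a small helper (an illustrative round-robin scheme, not the literal Accelerate implementation):

```python
def shard_indices(num_samples, rank, world_size, batch_size):
    # Every data-parallel rank sees the same number of complete batches;
    # the trailing incomplete batch is discarded, matching the
    # always-drop-last behavior described above.
    per_rank = (num_samples // (batch_size * world_size)) * batch_size
    return [rank + i * world_size for i in range(per_rank)]
```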
_prepare_train_inputs
def _prepare_train_inputs(self, data_iterator: Iterator) -> Dict[str, Tensor | Any] # lines 231-254
Prepares a single micro-batch: applies sequence packing or constructs causal (left-to-right) attention masks and position IDs, then slices the batch for context parallelism.
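The unpacked path can be sketched as follows (a hedged illustration; shapes and the boolean convention are assumptions, and Megatron-Core's mask convention may be inverted, i.e. True = masked out):

```python
import torch

def build_mask_and_positions(input_ids):
    # Causal (lower-triangular) attention mask plus sequential position
    # ids for an unpacked [batch, seqlen] micro-batch.
    bsz, seqlen = input_ids.shape
    attention_mask = torch.tril(
        torch.ones(seqlen, seqlen, dtype=torch.bool)
    )
    position_ids = torch.arange(seqlen).unsqueeze(0).expand(bsz, seqlen)
    return attention_mask, position_ids
```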
_packing_sequence
def _packing_sequence(self, inputs: Dict[str, Tensor | Any]) # lines 289-324
Packs multiple sequences into a single batch using THD format. Computes cumulative sequence lengths, creates PackedSeqParams, and reshapes inputs to [1, total_tokens, ...]. Requires Transformer Engine implementation.
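The cumulative-sequence-length computation at the heart of THD packing can be sketched like this (an illustrative helper; the real _packing_sequence also builds PackedSeqParams and reshapes the packed inputs to [1, total_tokens, ...]):

```python
import torch

def make_cu_seqlens(seq_lengths):
    # Boundaries for THD-format packed attention: [0, l0, l0+l1, ...].
    # Transformer Engine expects int32 cumulative lengths.
    lens = torch.tensor([0] + list(seq_lengths))
    return torch.cumsum(lens, dim=0, dtype=torch.int32)
```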
training_step
def training_step(
self,
models: List[DistributedDataParallel],
data_iterator,
seq_length,
) # lines 417-454
Executes a complete training step including forward-backward pass via forward_backward_func, optimizer step, and LR scheduler step. Returns (loss, metrics_tensors, skipped_iter, grad_norm, num_zeros_in_grad).
_inner_training_loop
def _inner_training_loop(
self,
batch_size=None,
args=None,
resume_from_checkpoint=None,
trial=None,
ignore_keys_for_eval=None,
) # lines 789-925
Main training loop. Manages epoch iteration, data skipping for resume, cyclic data iteration, gradient accumulation, loss tracking (including per-token loss), and final metrics reporting.
create_optimizer
def create_optimizer(self) # lines 482-501
Creates a Megatron-Core optimizer with OptimizerConfig from training arguments. Supports Adam/SGD, distributed optimizer, CPU offloading, and gradient clipping.
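A rough sketch of the kind of configuration this builds (field names follow megatron.core.optimizer.OptimizerConfig, but the exact mapping from TrainingArguments is an assumption, and values here are placeholders):

```python
from megatron.core.optimizer import OptimizerConfig, get_megatron_optimizer

# Hedged sketch: illustrative values, not McaTrainer's exact mapping.
config = OptimizerConfig(
    optimizer="adam",
    lr=1e-5,
    weight_decay=0.01,
    adam_beta1=0.9,
    adam_beta2=0.95,
    clip_grad=1.0,
    bf16=True,
    use_distributed_optimizer=True,  # ZeRO-like sharding across DP ranks
)
# optimizer = get_megatron_optimizer(config, model_chunks)
# where model_chunks is the list of DDP-wrapped models from _prepare_model
```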
create_scheduler
def create_scheduler(self, num_training_steps: int, optimizer=None) # lines 503-506
Creates a Megatron-compatible LR scheduler via get_megatron_lr_scheduler.
_load_from_checkpoint
def _load_from_checkpoint(self, resume_from_checkpoint, model=None) # lines 508-533
Loads model weights from checkpoint. Handles both PEFT adapter checkpoints (detecting adapter subdirectories with adapter_config.json) and standard MCA checkpoints.
evaluation_loop
@torch.no_grad()
def evaluation_loop(
self,
dataloader: DataLoader,
description: str,
prediction_loss_only: Optional[bool] = None,
ignore_keys: Optional[List[str]] = None,
metric_key_prefix: str = "eval",
) -> EvalLoopOutput # lines 1010-1046
Forward-only evaluation loop using the pipeline scheduler. Streams evaluation inputs, runs forward passes, gathers metrics across distributed ranks, and returns EvalLoopOutput.
save_model
def save_model(self, output_dir: str = None, _internal_call: bool = False) # lines 1082-1091
Saves model in MCA format (unless save_only_model and save_hf_model are set), optionally exports to HF format, and saves tokenizer and training arguments.
Import
import torch
from megatron.core import dist_checkpointing, mpu, tensor_parallel
from megatron.core.distributed import DistributedDataParallel, finalize_model_grads
from megatron.core.optimizer import OptimizerConfig, get_megatron_optimizer
from megatron.core.pipeline_parallel import get_forward_backward_func
from transformers.trainer import Trainer
from mcore_adapter.trainer.trainer import McaTrainer
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| model | VirtualModels | Yes | Model container with virtual pipeline model chunks |
| args | TrainingArguments | Yes | Training arguments with parallelism, optimizer, and training configs |
| tokenizer | PreTrainedTokenizerBase | No | Tokenizer for padding and data collation |
| data_collator | callable | No | Function to collate batch samples |
| train_dataset | Dataset | No | Training dataset |
| eval_dataset | Dataset | No | Evaluation dataset |
| resume_from_checkpoint | str | No | Path to checkpoint directory for resuming training |
Outputs
| Name | Type | Description |
|---|---|---|
| TrainOutput | TrainOutput | Contains global_step, training_loss, and metrics dict |
| EvalLoopOutput | EvalLoopOutput | Contains metrics dict, num_samples (predictions are None) |
Usage Examples
from mcore_adapter.models import AutoModel
from mcore_adapter.trainer import McaTrainer
from mcore_adapter.training_args import Seq2SeqTrainingArguments
# Create training arguments
args = Seq2SeqTrainingArguments(
output_dir="/path/to/output",
tensor_model_parallel_size=2,
pipeline_model_parallel_size=2,
use_distributed_optimizer=True,
overlap_grad_reduce=True,
bf16=True,
per_device_train_batch_size=4,
gradient_accumulation_steps=2,
num_train_epochs=3,
sequence_packing=True,
)
# Load model
model = AutoModel.from_pretrained("Qwen/Qwen2.5-7B", args)
# Create trainer (tokenizer, data_collator, train_dataset prepared beforehand)
trainer = McaTrainer(
model=model,
args=args,
tokenizer=tokenizer,
data_collator=data_collator,
train_dataset=train_dataset,
)
# Train
trainer.train()
# Save model
trainer.save_model("/path/to/output")