
Implementation:Alibaba ROLL TrainingArguments

From Leeroopedia


Knowledge Sources
Domains Configuration, Training, Distributed_Computing
Last Updated 2026-02-07 20:00 GMT

Overview

Training argument dataclasses combining Megatron-Core parallelism configuration with HuggingFace training arguments, providing a unified interface for distributed LLM training.

Description

training_args.py defines a hierarchy of dataclass-based argument classes that merge Megatron-Core's distributed computing configuration with HuggingFace's training hyperparameters:

DistributingParallelArguments (lines 16-278) is the base class containing all parallelism and model architecture configuration fields. It serves a dual purpose:

  • As command-line arguments parsed by HfArgumentParser
  • As configuration that overrides checkpoint-stored model config values (high priority)

Key design principle: Most arguments default to None to avoid accidentally overwriting checkpoint configurations. Only training-only parameters (not affecting model checkpoints) have non-None defaults.
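This override pattern can be sketched as follows; the helper `apply_overrides` and the two-field class are illustrative, not part of the actual source:

```python
from dataclasses import dataclass, fields
from typing import Optional

@dataclass
class ParallelArgs:
    # None means "keep the value stored in the checkpoint config"
    tensor_model_parallel_size: Optional[int] = None
    pipeline_model_parallel_size: Optional[int] = None

def apply_overrides(checkpoint_config: dict, args: ParallelArgs) -> dict:
    """Command-line values take priority; None fields leave the checkpoint untouched."""
    merged = dict(checkpoint_config)
    for f in fields(args):
        value = getattr(args, f.name)
        if value is not None:
            merged[f.name] = value
    return merged

ckpt = {"tensor_model_parallel_size": 8, "pipeline_model_parallel_size": 1}
merged = apply_overrides(ckpt, ParallelArgs(tensor_model_parallel_size=2))
# Tensor parallelism is overridden to 2; pipeline size keeps the checkpoint value.
```

Because unspecified fields stay None, a user can change one parallelism degree on the command line without clobbering the rest of the checkpoint's configuration.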

The class includes configuration for:

  • Parallelism: tensor, pipeline, sequence, context, and expert model parallelism
  • Virtual pipeline: Virtual pipeline model parallel size and pipeline layout
  • Recomputation: Granularity (full/selective), method (uniform/block), specific modules, and number of layers
  • Fusion: Bias activation fusion and RoPE fusion
  • MoE configuration: Token dispatcher type, aux loss coefficient, grouped GEMM, capacity factor, token drop policy, shared expert overlap, and router dtype
  • MTP: Multi-Token Prediction number of layers
  • FP8: Recipe, parameter format, and format selection
  • Additional configs: A catch-all dictionary or JSON file path for minor configurations

MegatronArguments (lines 281-353) extends DistributingParallelArguments with Megatron-specific training options:

  • Distributed optimizer settings (use, overlap grad reduce, overlap param gather, delay grad reduce)
  • DDP configuration (bucket size, average in collective, NaN checking)
  • Optimizer selection (Adam/SGD) with CPU offloading
  • HuggingFace model export toggle
  • Sequence packing enable

TrainingArguments (lines 356-366) is the final class combining MegatronArguments with HuggingFace TrainingArguments. Its __post_init__ enables FP32 gradient accumulation for bf16 training and disables DeepSpeed and wandb reporting.
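The adjustments described above can be sketched as a standalone function; `finalize` and its signature are hypothetical, condensing what __post_init__ does to these three settings:

```python
from typing import List, Optional, Tuple

def finalize(bf16: bool, report_to: List[str],
             deepspeed: Optional[str]) -> Tuple[bool, List[str], Optional[str]]:
    # Sketch of the described __post_init__ behavior, not the actual source.
    accumulate_allreduce_grads_in_fp32 = bool(bf16)  # FP32 grad accumulation for bf16
    report_to = [r for r in report_to if r != "wandb"]  # wandb reporting disabled
    deepspeed = None  # DeepSpeed is disabled for Megatron-Core training
    return accumulate_allreduce_grads_in_fp32, report_to, deepspeed

fp32_accum, reporters, ds = finalize(True, ["wandb", "tensorboard"], "ds_config.json")
```

Here `fp32_accum` comes back True, `reporters` keeps only tensorboard, and `ds` is cleared regardless of the value passed in.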

Seq2SeqTrainingArguments (lines 369-379) is the Seq2Seq variant combining MegatronArguments with HuggingFace Seq2SeqTrainingArguments.

Usage

Use TrainingArguments or Seq2SeqTrainingArguments as the primary training configuration class. Parse from command-line via HfArgumentParser, or instantiate directly with the desired parallelism and training settings.

Code Reference

Source Location

Key Classes

DistributingParallelArguments

@dataclass
class DistributingParallelArguments:
    tensor_model_parallel_size: Optional[int] = None
    pipeline_model_parallel_size: Optional[int] = None
    sequence_parallel: bool = False
    virtual_pipeline_model_parallel_size: Optional[int] = None
    context_parallel_size: Optional[int] = None
    expert_model_parallel_size: Optional[int] = None
    recompute_granularity: Optional[Literal["full", "selective"]] = None
    moe_token_dispatcher_type: Literal["allgather", "alltoall"] = "allgather"
    mtp_num_layers: Optional[int] = None
    calculate_per_token_loss: bool = False
    fp8_recipe: Optional[str] = None
    additional_configs: Optional[Union[dict, str]] = field(default_factory=dict)
    # ... (30+ fields total)

Key methods:

  • __post_init__() (lines 241-278): Validates argument consistency: loads additional_configs from JSON when given as a string path, converts recompute_modules from a comma-separated string to a list, raises ValueError when variable_seq_lengths is combined with the incompatible allgather dispatcher, and auto-computes virtual_pipeline_model_parallel_size from pipeline_model_parallel_layout.
  • get_config_dict() (lines 274-278): Returns non-None fields as a dictionary, merging additional_configs into the top level.
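A minimal sketch of the described get_config_dict() behavior, using a reduced field set (the rotary_base key is an arbitrary example of an additional config, not taken from the source):

```python
from dataclasses import dataclass, field, fields
from typing import Optional

@dataclass
class ArgsSketch:
    tensor_model_parallel_size: Optional[int] = None
    pipeline_model_parallel_size: Optional[int] = None
    sequence_parallel: bool = False
    additional_configs: dict = field(default_factory=dict)

    def get_config_dict(self) -> dict:
        # Non-None fields become top-level keys; additional_configs is merged in.
        config = {
            f.name: getattr(self, f.name)
            for f in fields(self)
            if f.name != "additional_configs" and getattr(self, f.name) is not None
        }
        config.update(self.additional_configs)
        return config

cfg = ArgsSketch(
    tensor_model_parallel_size=4,
    additional_configs={"rotary_base": 1000000},
).get_config_dict()
# pipeline_model_parallel_size is None, so it is omitted from the result;
# rotary_base is lifted out of additional_configs to the top level.
```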

MegatronArguments

@dataclass
class MegatronArguments(DistributingParallelArguments):
    accumulate_allreduce_grads_in_fp32: bool = False
    use_distributed_optimizer: bool = False
    overlap_grad_reduce: bool = False
    delay_grad_reduce: bool = True
    overlap_param_gather: bool = False
    optimizer: str = "adam"
    optimizer_cpu_offload: bool = False
    save_hf_model: bool = False
    sequence_packing: bool = False
    # ... additional fields

Key methods:

  • __post_init__() (lines 338-344): Asserts that overlap_param_gather requires use_distributed_optimizer and overlap_grad_reduce.
  • from_json_file(json_file_path) (lines 347-350): Class method to load arguments from a JSON file.
  • allow_variable_seq_lengths() (lines 352-353): Returns True if variable_seq_lengths is enabled or pipeline parallelism is not used.
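These methods can be sketched on a reduced field set; MegatronArgsSketch is illustrative, not the real class:

```python
import json
from dataclasses import dataclass

@dataclass
class MegatronArgsSketch:
    use_distributed_optimizer: bool = False
    overlap_grad_reduce: bool = False
    overlap_param_gather: bool = False

    def __post_init__(self):
        # overlap_param_gather only makes sense when parameters are sharded
        # (distributed optimizer) and gradient reduction is already overlapped.
        if self.overlap_param_gather:
            assert self.use_distributed_optimizer and self.overlap_grad_reduce, (
                "overlap_param_gather requires use_distributed_optimizer "
                "and overlap_grad_reduce"
            )

    @classmethod
    def from_json_file(cls, json_file_path: str) -> "MegatronArgsSketch":
        with open(json_file_path) as f:
            return cls(**json.load(f))

# Enabling overlap_param_gather alone trips the consistency assertion.
try:
    MegatronArgsSketch(overlap_param_gather=True)
    raised = False
except AssertionError:
    raised = True
```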

TrainingArguments

@dataclass
class TrainingArguments(MegatronArguments, HFTrainingArguments):
    def __post_init__(self):
        # Enables FP32 gradient accumulation for bf16 training
        # Disables DeepSpeed
        # Calls both parents' __post_init__
        # Removes wandb from report_to
        ...

Seq2SeqTrainingArguments

@dataclass
class Seq2SeqTrainingArguments(MegatronArguments, HFSeq2SeqTrainingArguments):
    def __post_init__(self):
        # Same initialization as TrainingArguments
        ...

Import

from dataclasses import dataclass, field
from typing import Literal, Optional, Union
from megatron.core.transformer.pipeline_parallel_layer_layout import PipelineParallelLayerLayout
from transformers import TrainingArguments as HFTrainingArguments
from transformers import Seq2SeqTrainingArguments as HFSeq2SeqTrainingArguments

from mcore_adapter.training_args import (
    DistributingParallelArguments,
    MegatronArguments,
    TrainingArguments,
    Seq2SeqTrainingArguments,
)

I/O Contract

Inputs

Name Type Required Description
tensor_model_parallel_size Optional[int] No Degree of tensor model parallelism (default: None, uses checkpoint value)
pipeline_model_parallel_size Optional[int] No Degree of pipeline model parallelism (default: None)
sequence_parallel bool No Enable sequence parallelism (default: False)
context_parallel_size Optional[int] No Degree of context parallelism (default: None)
expert_model_parallel_size Optional[int] No Degree of expert model parallelism (default: None)
use_distributed_optimizer bool No Enable ZeRO-like distributed optimizer (default: False)
overlap_grad_reduce bool No Overlap gradient reduction with backward pass (default: False)
recompute_granularity Optional[str] No Activation recomputation: full or selective (default: None)
sequence_packing bool No Enable sequence packing for efficient batching (default: False)
fp8_recipe Optional[str] No FP8 training recipe: mxfp8 or blockwise (default: None)
additional_configs Optional[dict or str] No Extra config dict or JSON file path (default: empty dict)

Outputs

Name Type Description
args TrainingArguments Fully validated training arguments ready for McaTrainer

Usage Examples

from mcore_adapter.training_args import Seq2SeqTrainingArguments
from transformers import HfArgumentParser

# Parse from command line
parser = HfArgumentParser(Seq2SeqTrainingArguments)
args = parser.parse_args_into_dataclasses()[0]

# Or instantiate directly
args = Seq2SeqTrainingArguments(
    output_dir="/path/to/output",
    tensor_model_parallel_size=4,
    pipeline_model_parallel_size=2,
    virtual_pipeline_model_parallel_size=2,
    sequence_parallel=True,
    context_parallel_size=2,
    expert_model_parallel_size=2,
    use_distributed_optimizer=True,
    overlap_grad_reduce=True,
    overlap_param_gather=True,
    recompute_granularity="selective",
    bf16=True,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    learning_rate=1e-5,
    num_train_epochs=3,
    sequence_packing=True,
    save_hf_model=True,
)

# Load additional configs from JSON
args_with_extras = Seq2SeqTrainingArguments(
    output_dir="/path/to/output",
    additional_configs="/path/to/extra_config.json",
)

# Get parallelism config as dict
config_dict = args.get_config_dict()
