

Implementation:Alibaba ROLL ModelFactory

From Leeroopedia


Knowledge Sources
Domains: Model_Architecture, Distributed_Computing
Last Updated: 2026-02-07 20:00 GMT

Overview

Factory module for creating Megatron-based GPT models with virtual pipeline parallelism, distributed weight loading, and HuggingFace export support.

Description

model_factory.py provides the core model creation infrastructure for MCoreAdapter. It contains three primary components:

VirtualModels (lines 47-180) is a wrapper around a list of model chunks that supports virtual pipeline model parallelism (interleaved scheduling). Each chunk handles a subset of transformer layers, and VirtualModels provides a unified interface for:

  • Saving and loading state dictionaries (including PEFT adapter support)
  • Aggregating parameters and FLOPs across all chunks
  • Exporting to HuggingFace format via ModelConverter
  • Distributed checkpointing with per-virtual-stage sharded state dicts
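The list-of-chunks pattern can be sketched in plain Python. This is a minimal illustration, not the mcore_adapter code. ToyChunk and ToyVirtualModels are hypothetical stand-ins, and the per-layer parameter count is made up. The sketch builds one chunk per virtual pipeline stage, passes each chunk its vp_stage index, and aggregates across chunks so callers see a single model.

```python
from typing import Any, List, Optional


class ToyChunk:
    """Hypothetical model chunk standing in for one McaGPTModel instance."""

    def __init__(self, num_layers: int, vp_stage: int):
        self.vp_stage = vp_stage
        # Pretend every layer holds 10 parameters (illustrative only).
        self.params = [10] * num_layers

    def num_parameters(self) -> int:
        return sum(self.params)


class ToyVirtualModels:
    """Hypothetical list-of-chunks wrapper: one chunk per virtual pipeline stage."""

    def __init__(self, cls, vp_size: Optional[int], *args, **kwargs):
        n = vp_size or 1  # unset virtual pipeline size means a single chunk
        # Each chunk is told which virtual stage it represents.
        self.models: List[Any] = [cls(*args, vp_stage=i, **kwargs) for i in range(n)]

    def __len__(self) -> int:
        return len(self.models)

    def __getitem__(self, i: int) -> Any:
        return self.models[i]

    def num_parameters(self) -> int:
        # Aggregate over all chunks so the wrapper behaves like one model.
        return sum(m.num_parameters() for m in self.models)
```

For example, `ToyVirtualModels(ToyChunk, 2, num_layers=7)` yields two chunks with vp_stage 0 and 1. Its total count is the sum of both chunks.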

PretrainedModel (lines 183-295) is the base model class inheriting from MegatronModule and ModuleUtilsMixin. Its from_pretrained class method orchestrates the full model loading pipeline:

  1. Loads or converts configuration via config_class.from_pretrained
  2. Creates virtual model instances
  3. Detects checkpoint format (MCA vs HuggingFace)
  4. Checks distributed config compatibility for MCA checkpoints
  5. Converts HuggingFace weights to MCA format if needed via ModelConverter
  6. Loads state dict with strict assertion checking
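Steps 3 and 4 can be sketched as a format probe plus a compatibility check. This is a hedged illustration, not the actual logic. The marker file names (MCA_MARKER, HF_MARKER) and both helper functions are hypothetical, and the real MCA layout may use different files. The sketch also applies the strictest rule, an exact match of tensor and pipeline sizes. The real check may be more permissive.

```python
import os
from typing import Dict

# Hypothetical marker file names; the real checkpoint layouts may differ.
MCA_MARKER = "mca_config.json"
HF_MARKER = "config.json"


def detect_checkpoint_format(path: str) -> str:
    """Return "mca" or "hf" depending on which marker file the directory contains."""
    if os.path.isfile(os.path.join(path, MCA_MARKER)):
        return "mca"
    if os.path.isfile(os.path.join(path, HF_MARKER)):
        return "hf"
    raise ValueError(f"unrecognized checkpoint directory: {path}")


def is_parallel_compatible(saved: Dict[str, int], current: Dict[str, int]) -> bool:
    """Strict sketch: MCA weights are loadable only if the saved parallel sizes match the current run."""
    keys = ("tensor_model_parallel_size", "pipeline_model_parallel_size")
    return all(saved.get(k, 1) == current.get(k, 1) for k in keys)
```

In this sketch, an "hf" result or a failed compatibility check would route the load through weight conversion (step 5) instead of a direct load.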

McaGPTModel (lines 298-370) extends Megatron-Core's GPTModel with MCoreAdapter-specific initialization, including transformer layer spec construction that handles MoE, Transformer Engine, local implementations, RMSNorm, and Multi-Token Prediction (MTP) block specs.

Usage

Use McaGPTModel.from_pretrained() (via AutoModel) to load a pretrained model from any supported checkpoint format with the desired parallelism configuration.

Code Reference

Source Location

Key Classes

VirtualModels

class VirtualModels:
    def __init__(self, cls, config: "McaModelConfig", *args, **kwargs)

Wrapper for virtual pipeline model parallelism. Creates N model chunks, where N = virtual_pipeline_model_parallel_size (1 if unset). Each chunk is initialized with its vp_stage index.
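The layers each chunk holds follow from its pipeline rank and vp_stage. The sketch below assumes Megatron-Core's uniform interleaved layout, where layers are split evenly with no custom first-stage or last-stage layer counts. chunk_layers is a hypothetical helper. With 28 layers, a pipeline size of 2 and a virtual size of 2, pipeline rank 0 holds layers 0-6 and 14-20, and rank 1 holds layers 7-13 and 21-27.

```python
from typing import List


def chunk_layers(num_layers: int, pp_size: int, vp_size: int, pp_rank: int, vp_stage: int) -> List[int]:
    """Global layer indices held by the (pp_rank, vp_stage) chunk under interleaving.

    Assumes num_layers divides evenly by pp_size * vp_size.
    """
    per_chunk = num_layers // (pp_size * vp_size)
    # Each virtual stage covers a contiguous block of num_layers // vp_size layers,
    # which is split across pipeline ranks in order.
    offset = vp_stage * (num_layers // vp_size) + pp_rank * per_chunk
    return list(range(offset, offset + per_chunk))
```

Interleaving gives each pipeline rank several non-contiguous slices of the network. This shrinks the pipeline bubble, at the cost of more point-to-point communication between stages.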

Key methods:

  • save_pretrained(save_directory) (lines 58-73): Saves model with PEFT adapter support. For PEFT models, saves adapter config and state dict separately per adapter name.
  • load_state_dict(state_dict, strict) (lines 75-100): Loads state dict with PEFT support. For multi-VP models, loads per-chunk state dicts keyed as model0, model1, etc.
  • save_pretrained_as_hf(save_directory, save_safetensors, max_shard_size) (lines 160-167): Exports model to HuggingFace format using ModelConverter with in-flight conversion to minimize memory.
  • sharded_state_dict(prefix) (lines 172-180): Returns distributed checkpoint-compatible sharded state dict, setting virtual PP rank for each model chunk.
  • get_batch_on_this_cp_rank(*args, **kwargs) (lines 169-170): Delegates to the first model chunk for context parallel batch slicing.
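The "model0", "model1", ... keying used by load_state_dict for multi-VP models can be sketched as a pair of helpers. gather_state_dict and split_state_dict are hypothetical names. Returning a flat dict when there is only one chunk is an assumption of this sketch; the source describes only the multi-VP case.

```python
from typing import Any, Dict, List


def gather_state_dict(chunk_states: List[Dict[str, Any]]) -> Dict[str, Any]:
    """Combine per-chunk state dicts; multi-chunk models get "model0", "model1", ... keys."""
    if len(chunk_states) == 1:
        return chunk_states[0]  # assumption: a single chunk keeps a flat state dict
    return {f"model{i}": sd for i, sd in enumerate(chunk_states)}


def split_state_dict(state_dict: Dict[str, Any], num_chunks: int) -> List[Dict[str, Any]]:
    """Inverse of gather_state_dict: hand each chunk the entry stored under its key."""
    if num_chunks == 1:
        return [state_dict]
    return [state_dict[f"model{i}"] for i in range(num_chunks)]
```

Keying by chunk index keeps parameter names inside each chunk unchanged. Each chunk can therefore call its own load_state_dict on exactly the sub-dict it saved.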

PretrainedModel

class PretrainedModel(MegatronModule, ModuleUtilsMixin):
    config_class = McaModelConfig

Key methods:

  • from_pretrained(model_name_or_path, args, use_cpu_initialization, tokenizer) (lines 186-241): Full model loading pipeline. Returns VirtualModels instance. Handles tokenizer-driven vocab resizing, MCA/HF checkpoint detection, and weight conversion.
  • get_batch_on_this_cp_rank(batch, dim3_keys) (lines 248-283): Implements context parallelism sequence splitting with load balancing. Splits each sequence into 2 * CP chunks, then assigns them to CP ranks in a zigzag pattern: rank r receives chunks r and 2 * CP - 1 - r. Every rank thus holds one early chunk and one late chunk, which balances the causal-attention work.
  • save_pretrained(save_directory, state_dict) (lines 243-246): Saves model config and sharded state dict to directory.
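The zigzag assignment used by get_batch_on_this_cp_rank can be shown on a plain Python list. This is a sketch, not the method itself: zigzag_cp_slice is a hypothetical helper that works on one sequence, while the real method slices tensors in a batch dict along the sequence dimension.

```python
from typing import List, Sequence, TypeVar

T = TypeVar("T")


def zigzag_cp_slice(seq: Sequence[T], cp_size: int, cp_rank: int) -> List[T]:
    """Return the part of `seq` kept by context-parallel rank `cp_rank`.

    The sequence is cut into 2 * cp_size equal chunks. Rank r keeps chunk r and
    chunk 2 * cp_size - 1 - r, pairing an early chunk with a late one.
    """
    num_chunks = 2 * cp_size
    if len(seq) % num_chunks != 0:
        raise ValueError("sequence length must be divisible by 2 * cp_size")
    size = len(seq) // num_chunks
    first, second = cp_rank, num_chunks - 1 - cp_rank
    return list(seq[first * size:(first + 1) * size]) + list(seq[second * size:(second + 1) * size])
```

With 8 tokens and CP = 2, rank 0 keeps positions [0, 1, 6, 7] and rank 1 keeps [2, 3, 4, 5]. Under causal attention, the cost of a token grows with its position. Pairing chunk r with its mirror chunk gives every rank the same total, which a naive contiguous split would not.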

McaGPTModel

class McaGPTModel(GPTModel, PretrainedModel):
    main_input_name: str = "input_ids"
    config_class = McaModelConfig

Key methods:

  • __init__(config, **kwargs) (lines 302-331): Determines pre_process and post_process flags from pipeline stage, constructs layer specs, initializes GPTModel superclass, sets tensor parallel attributes, and registers a logits postprocessing hook.
  • _get_transformer_layer_spec(config) (lines 333-360): Returns the appropriate Megatron-Core layer spec based on configuration: decoder block spec for MoE, TE spec for Transformer Engine, or local spec otherwise. Patches RMSNorm and shared expert gate settings.
  • _get_mtp_block_spec(config, vp_stage) (lines 362-370): Returns Multi-Token Prediction block spec when mtp_num_layers > 0.
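The decisions made in McaGPTModel.__init__ and the spec helpers can be sketched as pure functions. This is a hedged illustration. SpecConfig and its field names are a hypothetical subset of McaModelConfig, and the sketch returns the names of the Megatron-Core builders instead of calling them. It also omits the RMSNorm and shared-expert-gate patching. The pre_process and post_process rule follows Megatron's convention: only the very first chunk (pipeline rank 0, vp_stage 0) owns the embedding, and only the very last chunk owns the output layer.

```python
from dataclasses import dataclass
from typing import Optional, Tuple


@dataclass
class SpecConfig:
    # Hypothetical subset of McaModelConfig fields, for illustration only.
    num_moe_experts: Optional[int] = None
    transformer_impl: str = "transformer_engine"  # or "local"
    mtp_num_layers: Optional[int] = None


def stage_flags(pp_rank: int, pp_size: int, vp_stage: int, vp_size: int) -> Tuple[bool, bool]:
    """pre_process (embedding) only on the first chunk, post_process (output layer) only on the last."""
    pre_process = pp_rank == 0 and vp_stage == 0
    post_process = pp_rank == pp_size - 1 and vp_stage == vp_size - 1
    return pre_process, post_process


def choose_layer_spec(config: SpecConfig) -> str:
    """Name the Megatron-Core spec builder this dispatch would call."""
    if config.num_moe_experts:
        return "get_gpt_decoder_block_spec"
    if config.transformer_impl == "transformer_engine":
        return "get_gpt_layer_with_transformer_engine_spec"
    return "get_gpt_layer_local_spec"


def wants_mtp_block(config: SpecConfig) -> bool:
    """An MTP block spec is built only when mtp_num_layers > 0."""
    return (config.mtp_num_layers or 0) > 0
```

With a single pipeline stage and no virtual stages, one chunk is both first and last, so both flags are True.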

Import

import torch
from megatron.core import mpu, tensor_parallel
from megatron.core.models.gpt import GPTModel
from megatron.core.models.gpt.gpt_layer_specs import (
    get_gpt_decoder_block_spec, get_gpt_layer_local_spec,
    get_gpt_layer_with_transformer_engine_spec, get_gpt_mtp_block_spec,
)
from megatron.core.transformer.module import MegatronModule
from mcore_adapter.models.model_factory import McaGPTModel, PretrainedModel, VirtualModels

I/O Contract

Inputs

Name | Type | Required | Description
model_name_or_path | str | Yes | Path to a pretrained model checkpoint (MCA or HuggingFace format)
args | TrainingArguments | No | Training arguments with parallelism and precision settings
use_cpu_initialization | bool | No | Initialize the model on CPU instead of GPU (default: False)
tokenizer | PreTrainedTokenizer | No | Tokenizer for optional vocabulary resizing
config | McaModelConfig | Yes (__init__ only) | Model configuration with architecture and parallelism settings

Outputs

Name | Type | Description
models | VirtualModels | Container of model chunks with loaded weights, ready for training

Usage Examples

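import torch
from mcore_adapter.models import AutoModel
from mcore_adapter.training_args import TrainingArguments

# Run under a distributed launcher (e.g. torchrun). The world size must be a
# multiple of tensor x pipeline parallel size (2 x 2 = 4 here).

# Load model with tensor, pipeline, and virtual pipeline parallelism
args = TrainingArguments(
    tensor_model_parallel_size=2,
    pipeline_model_parallel_size=2,
    virtual_pipeline_model_parallel_size=2,  # 2 model chunks per pipeline rank
    bf16=True,
    output_dir="/tmp/output",
)
models = AutoModel.from_pretrained("Qwen/Qwen2.5-7B", args)

# Access individual virtual pipeline chunks (one model per VP stage);
# counts cover only the parameters held by this rank
for i, model_chunk in enumerate(models):
    print(f"Chunk {i}: {model_chunk.num_parameters()} parameters")

# Save in MCA format
models.save_pretrained("/path/to/output")

# Export to HuggingFace format
models.save_pretrained_as_hf("/path/to/hf_output")

# Context parallel batch slicing on example [batch, seq_len] tensors;
# seq_len should be divisible by 2 * context_parallel_size
input_ids = torch.randint(0, 1000, (1, 1024), device="cuda")
attn_mask = torch.ones_like(input_ids)
labels = input_ids.clone()
batch = {"input_ids": input_ids, "attention_mask": attn_mask, "labels": labels}
cp_batch = models.get_batch_on_this_cp_rank(batch)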
