

Implementation:Alibaba ROLL Checkpointing



Domains: Checkpointing, Distributed_Computing
Last Updated: 2026-02-07 20:00 GMT

Overview

A checkpoint management module for distributed Megatron-Core models, handling save and load operations with support for tensor, pipeline, and expert parallelism.

Description

This module provides a comprehensive set of functions for managing checkpoints in distributed Megatron-Core training environments. It is adapted from Megatron-LM's training/checkpointing.py and extends it with support for multiple parallelism dimensions (tensor, pipeline, and expert). The module uses a tracker file (named by the TRACKER_FILENAME constant) at the checkpoint root to record the latest checkpoint iteration, enabling seamless training resumption. Key capabilities include:

  • Tracker-based metadata management: Reading and writing iteration metadata via a tracker file at the checkpoint root (see the sketch after this list)
  • Parallelism-aware directory naming: Constructing checkpoint paths that encode tensor rank, pipeline rank, and expert rank (e.g., mp_rank_00, mp_rank_00_000_001)
  • Rank-0 checkpoint discovery: Searching through all possible parallelism configurations to locate the rank-0 checkpoint for inspection or conversion
  • Distributed checkpoint detection: Integration with megatron.core.dist_checkpointing to handle both legacy single-file and modern distributed checkpoint formats
  • State dict save/load: Convenience functions for saving model configuration alongside state dictionaries and loading them back
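
A minimal sketch of the tracker mechanism, following the Megatron-LM convention this module is adapted from; the exact TRACKER_FILENAME value and the "release" parsing shown here are assumptions, not guarantees about this module:

import os

# Assumed value, mirroring Megatron-LM; the module's actual constant may differ.
TRACKER_FILENAME = "latest_checkpointed_iteration.txt"

def tracker_path(checkpoints_path: str) -> str:
    # The tracker file sits at the checkpoint root directory.
    return os.path.join(checkpoints_path, TRACKER_FILENAME)

def read_tracker(tracker_filename: str) -> tuple[int, bool]:
    # The file holds either an integer iteration or the literal string
    # "release"; read_metadata returns (iteration, is_release).
    with open(tracker_filename) as f:
        metastring = f.read().strip()
    release = metastring == "release"
    iteration = 0 if release else int(metastring)
    return iteration, release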

Usage

Use this module when implementing checkpoint save/load logic for Megatron-Core distributed training. Call save_config_and_state_dict to persist a model's configuration and weights, and load_state_dict_from_checkpoint to load a saved state dictionary back. The find_checkpoint_rank_0 function is particularly useful for inspection or conversion tools that need to locate the rank-0 checkpoint without knowing the parallelism configuration.
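
For example, a typical resumption flow reads the tracker and then resolves the checkpoint path for the calling rank (a minimal sketch using only the functions documented below; rank arguments are left unset so they are auto-detected):

from mcore_adapter.checkpointing import (
    get_checkpoint_tracker_filename,
    read_metadata,
    get_checkpoint_name,
)

tracker = get_checkpoint_tracker_filename("/checkpoints/my_model")
iteration, release = read_metadata(tracker)
# Rank arguments default to None and are auto-detected, so this resolves
# the checkpoint file for the current rank's parallel configuration.
ckpt_file = get_checkpoint_name("/checkpoints/my_model", iteration, release)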

Code Reference

Source Location

Signature

def get_checkpoint_tracker_filename(checkpoints_path: str) -> str: ...

def ensure_directory_exists(filename: str, check_parent: bool = True) -> None: ...

def read_metadata(tracker_filename: str) -> tuple[int, bool]: ...

def get_checkpoint_dir(
    checkpoints_path: str,
    iteration: int = 1,
    release: bool = False,
    pipeline_parallel: bool | None = None,
    tensor_rank: int | None = None,
    pipeline_rank: int | None = None,
    expert_parallel: bool | None = None,
    expert_rank: int | None = None,
    return_base_dir: bool = False,
) -> str: ...

def get_checkpoint_name(
    checkpoints_path: str,
    iteration: int = 1,
    release: bool = False,
    pipeline_parallel: bool | None = None,
    tensor_rank: int | None = None,
    pipeline_rank: int | None = None,
    expert_parallel: bool | None = None,
    expert_rank: int | None = None,
    return_base_dir: bool = False,
) -> str: ...
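
The on-disk layout these path helpers encode can be illustrated with the sketch below; the zero-padding widths and the release/ directory name are assumptions carried over from Megatron-LM conventions:

import os

def sketch_checkpoint_name(
    checkpoints_path: str,
    iteration: int,
    release: bool = False,
    tensor_rank: int = 0,
    pipeline_rank: int | None = None,
    expert_rank: int | None = None,
) -> str:
    # Release checkpoints live under release/, others under iter_<7 digits>/.
    directory = "release" if release else f"iter_{iteration:07d}"
    # Rank suffix: mp_rank_00 (TP only), mp_rank_00_001 (TP + PP),
    # mp_rank_00_000_001 (TP + PP + EP), as in the examples above.
    suffix = f"mp_rank_{tensor_rank:02d}"
    if pipeline_rank is not None:
        suffix += f"_{pipeline_rank:03d}"
    if expert_rank is not None:
        suffix += f"_{expert_rank:03d}"
    return os.path.join(checkpoints_path, directory, suffix, "model_optim_rng.pt")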

def find_checkpoint_rank_0(
    checkpoints_path: str, iteration: int, release: bool = False
) -> str | None: ...

def _load_base_checkpoint(
    load_dir: str,
    rank0: bool = False,
    sharded_state_dict: dict | None = None,
    exit_on_missing_checkpoint: bool = True,
    checkpoint_step: int | None = None,
) -> tuple[dict | None, str, bool]: ...

def load_state_dict_from_checkpoint(checkpoint_dir: str) -> dict | None: ...

def save_config_and_state_dict(
    save_directory: str, config, state_dict: dict
) -> None: ...
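
One plausible shape for the save path is sketched below; the call to config.save_pretrained follows the I/O Contract (the config is documented to expose that method), while the state-dict file name is an assumption for illustration:

import os
import torch

def sketch_save(save_directory: str, config, state_dict: dict) -> None:
    os.makedirs(save_directory, exist_ok=True)
    # Persist the model configuration alongside the weights.
    config.save_pretrained(save_directory)
    # File name is assumed; the module may choose a different layout.
    torch.save(state_dict, os.path.join(save_directory, "model_optim_rng.pt"))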

Import

from mcore_adapter.checkpointing import (
    get_checkpoint_tracker_filename,
    ensure_directory_exists,
    read_metadata,
    get_checkpoint_dir,
    get_checkpoint_name,
    find_checkpoint_rank_0,
    load_state_dict_from_checkpoint,
    save_config_and_state_dict,
)

I/O Contract

Inputs

Name | Type | Required | Description
checkpoints_path | str | Yes | Root directory containing checkpoint subdirectories
iteration | int | No | Training iteration number (default 1)
release | bool | No | Whether this is a release checkpoint (default False)
pipeline_parallel | bool or None | No | Whether pipeline parallelism is used; auto-detected if None
tensor_rank | int or None | No | Tensor model parallel rank; auto-detected if None
pipeline_rank | int or None | No | Pipeline model parallel rank; auto-detected if None
expert_parallel | bool or None | No | Whether expert parallelism is used; auto-detected if None
expert_rank | int or None | No | Expert model parallel rank; auto-detected if None
return_base_dir | bool | No | If True, return just the iteration directory without the rank suffix
load_dir | str | Yes | Directory from which to load the checkpoint
rank0 | bool | No | If True, load only the rank-0 checkpoint
sharded_state_dict | dict or None | No | Sharded state dict used when loading distributed-format checkpoints
exit_on_missing_checkpoint | bool | No | If True, raise an error when the metadata file is missing
checkpoint_step | int or None | No | Override the iteration number read from the tracker file
save_directory | str | Yes | Directory to save the checkpoint into
config | object | Yes | Model configuration with a save_pretrained method
state_dict | dict | Yes | Model state dictionary to save

Outputs

Name | Type | Description
tracker_filename | str | Full path to the tracker file
checkpoint_dir | str | Directory path for the checkpoint at the given parallelism configuration
checkpoint_name | str | Full path to the checkpoint file (model_optim_rng.pt)
state_dict | dict or None | Loaded state dictionary, or None if the checkpoint is not found
(read_metadata) | tuple[int, bool] | Iteration number and whether it is a release checkpoint
(find_checkpoint_rank_0) | str or None | Path to the rank-0 checkpoint, or None if not found

Usage Examples

from mcore_adapter.checkpointing import (
    save_config_and_state_dict,
    load_state_dict_from_checkpoint,
    find_checkpoint_rank_0,
    read_metadata,
    get_checkpoint_tracker_filename,
)

# Save a checkpoint
save_config_and_state_dict("/checkpoints/my_model", config, model.state_dict())

# Load a checkpoint
state_dict = load_state_dict_from_checkpoint("/checkpoints/my_model")
if state_dict is not None:
    model.load_state_dict(state_dict)

# Find rank 0 checkpoint for inspection
tracker = get_checkpoint_tracker_filename("/checkpoints/my_model")
iteration, release = read_metadata(tracker)
rank0_path = find_checkpoint_rank_0("/checkpoints/my_model", iteration, release)
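
# Resolve the iteration-level directory without the mp_rank_* suffix
# (a sketch using the documented return_base_dir flag; get_checkpoint_name
# is imported from the same module, as listed in the Import section)
from mcore_adapter.checkpointing import get_checkpoint_name

base_dir = get_checkpoint_name(
    "/checkpoints/my_model", iteration, release, return_base_dir=True
)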
