Implementation: Alibaba ROLL Checkpointing
| Knowledge Sources | |
|---|---|
| Domains | Checkpointing, Distributed_Computing |
| Last Updated | 2026-02-07 20:00 GMT |
Overview
Checkpoint management module for distributed Megatron-Core models, handling save and load operations with tensor, pipeline, and expert parallelism support.
Description
This module provides a comprehensive set of functions for managing checkpoints in distributed Megatron-Core training environments. It is adapted from Megatron-LM's training/checkpointing.py and extends it with support for multiple parallelism dimensions (tensor, pipeline, and expert). The module uses a tracker file (named by the TRACKER_FILENAME constant) at the checkpoint root to record the latest checkpoint iteration, enabling seamless training resumption. Key capabilities include:
- Tracker-based metadata management: Reading and writing iteration metadata via a tracker file at the checkpoint root
- Parallelism-aware directory naming: Constructing checkpoint paths that encode tensor rank, pipeline rank, and expert rank (e.g., mp_rank_00, mp_rank_00_000_001)
- Rank-0 checkpoint discovery: Searching through all possible parallelism configurations to locate the rank-0 checkpoint for inspection or conversion
- Distributed checkpoint detection: Integration with megatron.core.dist_checkpointing to handle both legacy single-file and modern distributed checkpoint formats
- State dict save/load: Convenience functions for saving model configuration alongside state dictionaries and loading them back
Usage
Use this module when implementing checkpoint save/load logic for Megatron-Core distributed training. Call save_config_and_state_dict to persist a model's configuration and weights, and load_state_dict_from_checkpoint to load a saved state dict back from a checkpoint directory. The find_checkpoint_rank_0 function is particularly useful for inspection or conversion tools that need to locate the rank-0 checkpoint without knowing the parallelism configuration in advance.
Code Reference
Source Location
- Repository: Alibaba_ROLL
- File: mcore_adapter/src/mcore_adapter/checkpointing.py
- Lines: 1-279
Signature
def get_checkpoint_tracker_filename(checkpoints_path: str) -> str: ...
def ensure_directory_exists(filename: str, check_parent: bool = True) -> None: ...
def read_metadata(tracker_filename: str) -> tuple[int, bool]: ...
def get_checkpoint_dir(
    checkpoints_path: str,
    iteration: int = 1,
    release: bool = False,
    pipeline_parallel: bool | None = None,
    tensor_rank: int | None = None,
    pipeline_rank: int | None = None,
    expert_parallel: bool | None = None,
    expert_rank: int | None = None,
    return_base_dir: bool = False,
) -> str: ...
def get_checkpoint_name(
    checkpoints_path: str,
    iteration: int = 1,
    release: bool = False,
    pipeline_parallel: bool | None = None,
    tensor_rank: int | None = None,
    pipeline_rank: int | None = None,
    expert_parallel: bool | None = None,
    expert_rank: int | None = None,
    return_base_dir: bool = False,
) -> str: ...
def find_checkpoint_rank_0(
    checkpoints_path: str, iteration: int, release: bool = False
) -> str | None: ...
def _load_base_checkpoint(
    load_dir: str,
    rank0: bool = False,
    sharded_state_dict: dict | None = None,
    exit_on_missing_checkpoint: bool = True,
    checkpoint_step: int | None = None,
) -> tuple[dict | None, str, bool]: ...
def load_state_dict_from_checkpoint(checkpoint_dir: str) -> dict | None: ...
def save_config_and_state_dict(
    save_directory: str, config, state_dict: dict
) -> None: ...
Import
from mcore_adapter.checkpointing import (
    get_checkpoint_tracker_filename,
    ensure_directory_exists,
    read_metadata,
    get_checkpoint_dir,
    get_checkpoint_name,
    find_checkpoint_rank_0,
    load_state_dict_from_checkpoint,
    save_config_and_state_dict,
)
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| checkpoints_path | str | Yes | Root directory containing checkpoint subdirectories |
| iteration | int | No | Training iteration number (default 1) |
| release | bool | No | Whether this is a release checkpoint (default False) |
| pipeline_parallel | bool or None | No | Whether pipeline parallelism is used; auto-detected if None |
| tensor_rank | int or None | No | Tensor model parallel rank; auto-detected if None |
| pipeline_rank | int or None | No | Pipeline model parallel rank; auto-detected if None |
| expert_parallel | bool or None | No | Whether expert parallelism is used; auto-detected if None |
| expert_rank | int or None | No | Expert model parallel rank; auto-detected if None |
| return_base_dir | bool | No | If True, return just the iteration directory without rank suffix |
| load_dir | str | Yes | Directory from which to load checkpoint |
| rank0 | bool | No | If True, load only rank 0 checkpoint |
| exit_on_missing_checkpoint | bool | No | If True, raise error when metadata file is missing |
| checkpoint_step | int or None | No | Override iteration number read from tracker file |
| save_directory | str | Yes | Directory to save checkpoint into |
| config | object | Yes | Model configuration with save_pretrained method |
| state_dict | dict | Yes | Model state dictionary to save |
Outputs
| Name | Type | Description |
|---|---|---|
| tracker_filename | str | Full path to the tracker file |
| checkpoint_dir | str | Directory path for the checkpoint at given parallelism config |
| checkpoint_name | str | Full path to the checkpoint file (model_optim_rng.pt) |
| state_dict | dict or None | Loaded state dictionary, or None if checkpoint not found |
| (read_metadata) | tuple[int, bool] | Iteration number and whether it is a release checkpoint |
| (find_checkpoint_rank_0) | str or None | Path to rank 0 checkpoint, or None if not found |
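The read_metadata contract above (iteration number plus a release flag) follows the Megatron-LM tracker-file convention: a small text file at the checkpoint root containing either the latest iteration number or the literal string "release". The sketch below is a hypothetical reader for illustration; the real implementation is read_metadata, and the tracker filename used here is an assumed value (the module defines it via the TRACKER_FILENAME constant).

```python
import os
import tempfile

TRACKER_FILENAME = "latest_checkpointed_iteration.txt"  # assumed value

def example_read_metadata(tracker_path: str) -> tuple[int, bool]:
    # The tracker file holds a single token: an integer iteration,
    # or "release" for a release checkpoint.
    with open(tracker_path) as f:
        text = f.read().strip()
    if text == "release":
        return 0, True          # release checkpoint; iteration is unused
    return int(text), False     # regular checkpoint at this iteration

# Demo against a throwaway tracker file.
with tempfile.TemporaryDirectory() as root:
    path = os.path.join(root, TRACKER_FILENAME)
    with open(path, "w") as f:
        f.write("1000")
    print(example_read_metadata(path))  # (1000, False)
```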
Usage Examples
from mcore_adapter.checkpointing import (
    save_config_and_state_dict,
    load_state_dict_from_checkpoint,
    find_checkpoint_rank_0,
    read_metadata,
    get_checkpoint_tracker_filename,
)

# Save a checkpoint
save_config_and_state_dict("/checkpoints/my_model", config, model.state_dict())

# Load a checkpoint
state_dict = load_state_dict_from_checkpoint("/checkpoints/my_model")
if state_dict is not None:
    model.load_state_dict(state_dict)

# Find rank 0 checkpoint for inspection
tracker = get_checkpoint_tracker_filename("/checkpoints/my_model")
iteration, release = read_metadata(tracker)
rank0_path = find_checkpoint_rank_0("/checkpoints/my_model", iteration, release)