
Implementation:Alibaba ROLL Patcher

From Leeroopedia


Knowledge Sources
Domains Checkpointing, Bug_Fixes
Last Updated 2026-02-07 20:00 GMT

Overview

Monkey-patches PyTorch distributed checkpoint functions to fix performance and correctness bugs in shard overlap detection and global plan validation.

Description

This module applies targeted monkey-patches to PyTorch's distributed checkpointing subsystem, addressing two specific upstream bugs that affect checkpoint save/load performance and correctness at scale.

patch_torch_find_nd_overlapping_shards (lines 27-79): Fixes PyTorch issue #166941 by replacing the default N-dimensional shard overlap detection with an efficient sweep-line algorithm. The original implementation had quadratic complexity when checking for overlapping shard metadata, which became a bottleneck with many shards. The patched version:

  1. Selects the sweep dimension with the largest range for optimal pruning
  2. Sorts shard indices by their offset along the sweep dimension
  3. Maintains an active set of shards sorted by their end position
  4. Uses bisect_right to efficiently prune shards that can no longer overlap
  5. Only checks _check_shard_metadata_pair_overlap for shards still in the active window

This reduces the complexity from O(N^2) to approximately O(N log N) in typical cases.
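The steps above can be sketched as follows. This is an illustrative standalone version, not the actual patch: shards are represented as hypothetical (offsets, sizes) tuples rather than PyTorch's ShardMetadata objects, and the function names are invented for the example.

```python
from bisect import bisect_right, insort

def shards_overlap(a, b):
    # Two N-dimensional boxes overlap iff their ranges intersect on every
    # dimension (half-open intervals [offset, offset + size)).
    return all(
        o1 < o2 + s2 and o2 < o1 + s1
        for o1, s1, o2, s2 in zip(a[0], a[1], b[0], b[1])
    )

def find_overlapping_pair(shards):
    """Return indices of one overlapping pair, or None if shards are disjoint."""
    if not shards:
        return None
    ndims = len(shards[0][0])

    # 1. Pick the sweep dimension with the largest coordinate range.
    def span(d):
        lo = min(s[0][d] for s in shards)
        hi = max(s[0][d] + s[1][d] for s in shards)
        return hi - lo
    sweep = max(range(ndims), key=span)

    # 2. Sort shard indices by their offset along the sweep dimension.
    order = sorted(range(len(shards)), key=lambda i: shards[i][0][sweep])

    # 3. Active set: (end_position, index) pairs kept sorted by end position.
    active = []
    for i in order:
        start = shards[i][0][sweep]
        # 4. bisect_right prunes shards whose end <= start: they can no
        #    longer overlap anything at or beyond the current position.
        del active[:bisect_right(active, (start, len(shards)))]
        # 5. Remaining shards still span `start` on the sweep dimension;
        #    only these need the full per-pair overlap check.
        for _, j in active:
            if shards_overlap(shards[i], shards[j]):
                return (j, i)
        insort(active, (shards[i][0][sweep] + shards[i][1][sweep], i))
    return None
```

Each shard enters and leaves the active window once, so the cost is dominated by the initial sort plus the pair checks inside overlapping windows, which is how the quadratic scan is avoided in the common non-pathological case.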

patch_torch_validate_global_plan (lines 82-160): Fixes PyTorch issue #163548 in the global checkpoint plan validation. The replacement function:

  • Skips BytesStorageMetadata entries and zero-dimensional tensors
  • Validates chunk bounds and computes chunk volumes
  • Uses the same sweep-line algorithm (on dimension 0) to detect overlapping chunks efficiently
  • Verifies that the combined chunk volume equals the tensor volume for multi-rank plans
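The bounds, overlap, and volume checks above can be sketched like this. Again a hypothetical standalone version: chunks are (offsets, sizes) tuples rather than the ChunkStorageMetadata objects the real replacement validates, and the function name is invented.

```python
from bisect import bisect_right, insort
from math import prod

def boxes_overlap(a, b):
    # Half-open interval intersection on every dimension.
    return all(o1 < o2 + s2 and o2 < o1 + s1
               for o1, s1, o2, s2 in zip(a[0], a[1], b[0], b[1]))

def validate_tensor_chunks(tensor_size, chunks):
    """Return True if chunks lie in bounds, don't overlap, and tile the tensor."""
    total = 0
    for offsets, sizes in chunks:
        # Validate chunk bounds and accumulate chunk volumes.
        for off, sz, dim in zip(offsets, sizes, tensor_size):
            if off < 0 or sz < 0 or off + sz > dim:
                return False
        total += prod(sizes)

    # Sweep-line overlap detection on dimension 0, as described above.
    order = sorted(range(len(chunks)), key=lambda i: chunks[i][0][0])
    active = []  # (end_position_on_dim0, index), sorted by end
    for i in order:
        start = chunks[i][0][0]
        del active[:bisect_right(active, (start, len(chunks)))]
        for _, j in active:
            if boxes_overlap(chunks[i], chunks[j]):
                return False
        insort(active, (chunks[i][0][0] + chunks[i][1][0], i))

    # Combined chunk volume must equal the full tensor volume.
    return total == prod(tensor_size)
```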

Both patches are applied by directly assigning the new function implementations to the corresponding module-level names in PyTorch's internal modules.

Usage

Call these patch functions at application startup, before any distributed checkpoint operations; they are typically invoked once during the mcore_adapter initialization phase. Both patches are safe to apply unconditionally, as they implement the same contract as the originals with improved performance.

Code Reference

Source Location

Signature

def patch_torch_find_nd_overlapping_shards() -> None: ...

def patch_torch_validate_global_plan() -> None: ...

Import

from mcore_adapter.patcher import (
    patch_torch_find_nd_overlapping_shards,
    patch_torch_validate_global_plan,
)

I/O Contract

Inputs

Name Type Required Description
(none) N/A N/A Both functions take no arguments

Outputs

Name Type Description
(patch_torch_find_nd_overlapping_shards) None Side effect: replaces torch.distributed._shard.sharding_spec._internals._find_nd_overlapping_shards with sweep-line implementation
(patch_torch_validate_global_plan) None Side effect: replaces torch.distributed.checkpoint.default_planner._validate_global_plan with optimized validation

Usage Examples

from mcore_adapter.patcher import (
    patch_torch_find_nd_overlapping_shards,
    patch_torch_validate_global_plan,
)

# Apply patches at startup before any checkpoint operations
patch_torch_find_nd_overlapping_shards()
patch_torch_validate_global_plan()

# Now distributed checkpoint save/load will use the patched implementations
# Example: saving a distributed checkpoint (state_dict and storage_writer
# are assumed to have been prepared beforehand)
import torch.distributed.checkpoint as dist_cp
dist_cp.save_state_dict(state_dict, storage_writer)
