Implementation: Alibaba ROLL Patcher
| Knowledge Sources | |
|---|---|
| Domains | Checkpointing, Bug_Fixes |
| Last Updated | 2026-02-07 20:00 GMT |
Overview
Monkey-patches PyTorch distributed checkpoint functions to fix performance and correctness bugs in shard overlap detection and global plan validation.
Description
This module applies targeted monkey-patches to PyTorch's distributed checkpointing subsystem, addressing two specific upstream bugs that affect checkpoint save/load performance and correctness at scale.
patch_torch_find_nd_overlapping_shards (lines 27-79): Fixes PyTorch issue #166941 by replacing the default N-dimensional shard overlap detection with an efficient sweep-line algorithm. The original implementation had quadratic complexity when checking for overlapping shard metadata, which became a bottleneck with many shards. The patched version:
- Selects the sweep dimension with the largest range for optimal pruning
- Sorts shard indices by their offset along the sweep dimension
- Maintains an active set of shards sorted by their end position
- Uses bisect_right to efficiently prune shards that can no longer overlap
- Only checks _check_shard_metadata_pair_overlap for shards still in the active window
This reduces the complexity from O(N^2) to approximately O(N log N) in typical cases.
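The steps above can be sketched in plain Python. This is a minimal illustration of the sweep-line idea, not the actual patched code: `boxes_overlap` and `find_overlapping_pairs` are illustrative names, and each shard here is simplified to a list of `(offset, size)` pairs, one per dimension.

```python
from bisect import bisect_right, insort

def boxes_overlap(a, b):
    # N-d boxes overlap iff their intervals overlap on every dimension.
    return all(ao < bo + bsz and bo < ao + asz
               for (ao, asz), (bo, bsz) in zip(a, b))

def find_overlapping_pairs(shards, sweep_dim=0):
    """Sweep along one dimension; each shard is [(offset, size), ...]."""
    n = len(shards)
    # Sort shard indices by their offset along the sweep dimension.
    order = sorted(range(n), key=lambda i: shards[i][sweep_dim][0])
    active = []   # (end, index) pairs, kept sorted by end position
    pairs = []
    for i in order:
        start, size = shards[i][sweep_dim]
        # Prune shards whose interval ends at or before `start`: they
        # cannot overlap shard i (or any later shard) on this dimension.
        del active[:bisect_right(active, (start, n))]
        # Every remaining active shard overlaps i on the sweep dimension,
        # so the full N-d pairwise check runs only on this active window.
        for _end, j in active:
            if boxes_overlap(shards[i], shards[j]):
                pairs.append((min(i, j), max(i, j)))
        insort(active, (start + size, i))
    return pairs
```

In the real patch, the sweep dimension is chosen as the one with the largest range, which maximizes how aggressively the active set can be pruned.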
patch_torch_validate_global_plan (lines 82-160): Fixes PyTorch issue #163548 in the global checkpoint plan validation. The replacement function:
- Skips BytesStorageMetadata entries and zero-dimensional tensors
- Validates chunk bounds and computes chunk volumes
- Uses the same sweep-line algorithm (on dimension 0) to detect overlapping chunks efficiently
- Verifies that the combined chunk volume equals the tensor volume for multi-rank plans
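The bounds and volume checks in this contract can be illustrated with a small standalone function. This is a hedged sketch under simplified assumptions: `validate_chunks` is a hypothetical helper, not the patched PyTorch function, and it assumes the overlap check has already ruled out overlapping chunks.

```python
import math

def validate_chunks(tensor_shape, chunks):
    """Each chunk is (offsets, sizes). Returns True when every chunk
    lies within the tensor bounds and, given no overlaps, the chunk
    volumes sum exactly to the tensor volume."""
    tensor_volume = math.prod(tensor_shape)
    covered = 0
    for offsets, sizes in chunks:
        for off, size, dim in zip(offsets, sizes, tensor_shape):
            # An out-of-bounds chunk invalidates the plan immediately.
            if off < 0 or size < 0 or off + size > dim:
                return False
        covered += math.prod(sizes)
    # Non-overlapping chunks must tile the tensor exactly.
    return covered == tensor_volume
```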
Both patches are applied by directly assigning the new function implementations to the corresponding module-level names in PyTorch's internal modules.
Usage
Call these patch functions at application startup, before any distributed checkpoint operations. They are typically invoked once during the mcore_adapter initialization phase. Both patches are safe to apply unconditionally as they implement the same contract as the originals with improved performance.
Code Reference
Source Location
- Repository: Alibaba_ROLL
- File: mcore_adapter/src/mcore_adapter/patcher.py
- Lines: 1-160
Signature
```python
def patch_torch_find_nd_overlapping_shards() -> None: ...
def patch_torch_validate_global_plan() -> None: ...
```
Import
```python
from mcore_adapter.patcher import (
    patch_torch_find_nd_overlapping_shards,
    patch_torch_validate_global_plan,
)
```
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| (none) | N/A | N/A | Both functions take no arguments |
Outputs
| Name | Type | Description |
|---|---|---|
| (patch_torch_find_nd_overlapping_shards) | None | Side effect: replaces torch.distributed._shard.sharding_spec._internals._find_nd_overlapping_shards with sweep-line implementation |
| (patch_torch_validate_global_plan) | None | Side effect: replaces torch.distributed.checkpoint.default_planner._validate_global_plan with optimized validation |
Usage Examples
```python
from mcore_adapter.patcher import (
    patch_torch_find_nd_overlapping_shards,
    patch_torch_validate_global_plan,
)

# Apply the patches at startup, before any checkpoint operations.
patch_torch_find_nd_overlapping_shards()
patch_torch_validate_global_plan()

# Distributed checkpoint save/load now uses the patched implementations.
# Example: saving a distributed checkpoint. `state_dict` and the target
# directory are placeholders for your own values.
import torch.distributed.checkpoint as dist_cp

storage_writer = dist_cp.FileSystemWriter("/path/to/checkpoint")
dist_cp.save_state_dict(state_dict, storage_writer)
```