Principle: Huggingface Transformers Distributed Checkpointing
| Knowledge Sources | |
|---|---|
| Domains | Distributed_Computing, Training, Checkpointing |
| Last Updated | 2026-02-13 00:00 GMT |
Overview
Distributed checkpointing saves and restores distributed training state across multiple processes, preserving the sharded representation of model parameters and optimizer states without requiring global gathering.
Description
In distributed training with tensor parallelism, data parallelism, and context parallelism, the model parameters and optimizer states are distributed across multiple devices in various sharded configurations. Distributed Checkpointing (DCP) saves this distributed state directly -- each rank writes its local shard to storage -- rather than first gathering all shards to a single rank and writing a monolithic checkpoint.
This approach offers several advantages:
- Memory efficiency: No single rank needs enough memory to hold the full model and optimizer state.
- Parallelized I/O: All ranks write simultaneously, reducing checkpoint wall time roughly in proportion to the number of ranks.
- Reshardability: DCP checkpoints can be loaded with a different number of ranks or a different parallelism configuration, as the framework handles redistribution automatically.
The DCP API in PyTorch (torch.distributed.checkpoint) is built around the Stateful protocol: any object that implements state_dict() and load_state_dict() can be saved and restored. The typical pattern wraps both the model and optimizer in an application state object (AppState) that uses get_state_dict and set_state_dict from PyTorch's distributed checkpoint utilities to properly handle DTensor, FSDP, and TP state.
Key components:
- AppState: A Stateful wrapper that manages model and optimizer state dicts together.
- get_state_dict / set_state_dict: Utilities that extract/restore state dicts in a distributed-aware manner, handling DTensor placement, FSDP flat parameters, and other distributed tensor formats.
- dcp.save / dcp.load: The top-level APIs that coordinate multi-rank checkpoint I/O.
Usage
Use distributed checkpointing when:
- Training with any form of model parallelism (TP, FSDP with sharding) where the model state is distributed.
- You want to save and resume training without requiring all state to be gathered to a single rank.
- You need the flexibility to resume training with a different parallelism topology.
- Checkpoint size is large enough that single-rank gathering would cause OOM or unacceptable latency.
DCP is typically called at the end of training or at periodic intervals. It requires the distributed process group to be active.
Theoretical Basis
Distributed checkpointing addresses the checkpointing scalability problem in large-scale distributed training. In a system with N ranks, each holding M/N bytes of model and optimizer state (where M is the total state size), the naive approach of gathering to rank 0 requires:
- Memory: O(M) on rank 0, which may exceed GPU or host memory for large models.
- Communication: O(M) bytes transferred to rank 0.
- I/O time: Sequential write of M bytes.
With DCP:
- Memory: O(M/N) per rank (no additional memory needed).
- Communication: Only local metadata coordination.
- I/O time: Parallel writes of M/N bytes per rank, reducing wall time by up to Nx.
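To make the comparison concrete, here is a rough wall-time model under simplifying assumptions (every writer has the same storage bandwidth, metadata coordination is free); the function name and the 2 TiB / 256-rank numbers are illustrative only.

```python
GIB = 1024 ** 3  # one gibibyte in bytes


def naive_vs_dcp(total_bytes, num_ranks, bw_bytes_per_s):
    """Estimate checkpoint wall time for gather-to-rank-0 vs. DCP.

    Naive: rank 0 sequentially writes all M bytes.
    DCP:   each of N ranks writes M/N bytes in parallel.
    """
    naive = total_bytes / bw_bytes_per_s
    dcp = (total_bytes / num_ranks) / bw_bytes_per_s
    return naive, dcp


# Example: 2 TiB of model + optimizer state, 256 ranks, 1 GiB/s per writer.
naive_s, dcp_s = naive_vs_dcp(2048 * GIB, 256, 1 * GIB)
# naive_s is ~34 minutes of sequential writing; dcp_s is a few seconds,
# a speedup equal to the rank count under these idealized assumptions.
```

In practice shared filesystems and object stores impose aggregate bandwidth limits, so the realized speedup is usually below the ideal Nx.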
The Stateful protocol provides a clean abstraction that decouples the checkpoint format from the parallelism strategy. Because get_state_dict preserves DTensor metadata (device mesh, placements) alongside the tensor data, the set_state_dict / dcp.load path can redistribute tensors to a potentially different mesh topology at load time.