Principle:Mlfoundations Open flamingo Distributed Checkpointing
Overview
Fault tolerance strategy that saves model, optimizer, and scheduler state to disk at regular intervals during distributed training, with special handling for FSDP-sharded parameters.
Description
Checkpointing during distributed training requires gathering sharded parameters onto a single rank for saving. With FSDP, the model state dict must be gathered with the FULL_STATE_DICT state-dict type before saving. Only the trainable parameters are saved, together with the embeddings (kept even when frozen), which reduces checkpoint size significantly. The checkpoint contains model_state_dict, optimizer_state_dict, lr_scheduler_state_dict, and the epoch number. Optionally, the checkpoint can be uploaded to Weights & Biases (wandb), and the previous checkpoint can be deleted to save disk space.
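The save procedure described above can be sketched as follows. This is a minimal sketch assuming PyTorch 2.x, not OpenFlamingo's own code. The helper names filter_to_trainable and save_checkpoint, their parameters, the checkpoint_{epoch}.pt file naming, and the rule that embeddings are recognised by "embed" in the parameter name are all hypothetical. Under FSDP the sketch assumes use_orig_params=True, so that per-parameter requires_grad flags are visible. With flattened parameters, the filter removes nothing.

```python
import os

import torch
import torch.nn as nn


def filter_to_trainable(model: nn.Module, state_dict: dict) -> dict:
    """Drop frozen parameters from a full state dict, but keep embeddings.

    Hypothetical helper: embeddings are recognised by "embed" in the name.
    Buffers never appear in named_parameters(), so they are always kept.
    """
    frozen = set()
    for name, p in model.named_parameters():
        # Strip FSDP / activation-checkpoint wrapper segments so names match
        # the clean keys produced by a FULL_STATE_DICT.
        name = name.replace("_fsdp_wrapped_module.", "")
        name = name.replace("_checkpoint_wrapped_module.", "")
        if not p.requires_grad and "embed" not in name:
            frozen.add(name)
    return {k: v for k, v in state_dict.items() if k not in frozen}


def save_checkpoint(model, optimizer, lr_scheduler, epoch, run_dir, rank=0,
                    use_fsdp=False, upload_to_wandb=False, delete_previous=False):
    """Save model/optimizer/scheduler state; return the file path on rank 0."""
    if use_fsdp:
        # Imported lazily so the non-FSDP path needs no distributed setup.
        from torch.distributed.fsdp import FullOptimStateDictConfig
        from torch.distributed.fsdp import FullStateDictConfig, StateDictType
        from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

        # Gather full (unsharded) tensors, offloaded to CPU, onto rank 0 only.
        # state_dict() and optim_state_dict() are collectives: all ranks call them.
        FSDP.set_state_dict_type(
            model,
            StateDictType.FULL_STATE_DICT,
            FullStateDictConfig(offload_to_cpu=True, rank0_only=True),
            FullOptimStateDictConfig(rank0_only=True),
        )
        model_state = model.state_dict()
        optim_state = FSDP.optim_state_dict(model, optimizer)
    else:
        model_state = model.state_dict()
        optim_state = optimizer.state_dict()

    if rank != 0:
        return None  # only rank 0 holds the gathered state and writes it

    os.makedirs(run_dir, exist_ok=True)
    path = os.path.join(run_dir, f"checkpoint_{epoch}.pt")
    torch.save(
        {
            "epoch": epoch,
            "model_state_dict": filter_to_trainable(model, model_state),
            "optimizer_state_dict": optim_state,
            "lr_scheduler_state_dict": lr_scheduler.state_dict(),
        },
        path,
    )
    if upload_to_wandb:
        import wandb  # assumes wandb.init() was already called on this rank

        wandb.save(path)
    if delete_previous:
        previous = os.path.join(run_dir, f"checkpoint_{epoch - 1}.pt")
        if os.path.exists(previous):
            os.remove(previous)
    return path
```

Using offload_to_cpu=True with rank0_only=True means the full model never has to fit in a single GPU's memory. The other ranks also receive empty state dicts instead of redundant full copies.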
Usage
At the end of each training epoch, or at regular step intervals, so that training can resume from the most recent checkpoint after a failure.
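Resuming reverses the save. The sketch below is a minimal illustration for a model that is not FSDP-wrapped. An FSDP-wrapped model would additionally need set_state_dict_type() and FSDP.optim_state_dict_to_load() before loading. The function name resume_from_checkpoint is hypothetical. The sketch also assumes that the stored epoch is the last completed one, so training restarts at epoch + 1. Because frozen weights were filtered out at save time, the model must first be initialised with its pretrained weights and then loaded with strict=False.

```python
import torch
import torch.nn as nn


def resume_from_checkpoint(path, model: nn.Module, optimizer, lr_scheduler) -> int:
    """Restore state saved by a trainable-only checkpoint; return the next epoch."""
    ckpt = torch.load(path, map_location="cpu")
    # strict=False: frozen weights were not saved and must already be loaded
    # from the pretrained initialisation. Unexpected keys still indicate a
    # mismatch between the checkpoint and the model, so reject them.
    result = model.load_state_dict(ckpt["model_state_dict"], strict=False)
    if result.unexpected_keys:
        raise KeyError(f"unexpected keys in checkpoint: {result.unexpected_keys}")
    optimizer.load_state_dict(ckpt["optimizer_state_dict"])
    lr_scheduler.load_state_dict(ckpt["lr_scheduler_state_dict"])
    # Assumption: "epoch" records the last completed epoch.
    return ckpt["epoch"] + 1
```

The function tolerates missing keys but rejects unexpected ones. This stops strict=False from silently accepting a checkpoint that belongs to a different architecture.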
Theoretical Basis
Under FSDP, each GPU holds only a shard of every wrapped parameter. Checkpointing therefore requires reconstructing the full parameter tensors on a single rank (typically rank 0) before serialization. FSDP provides set_state_dict_type() to configure how state dicts are gathered. With FULL_STATE_DICT, every rank must take part in the gather, even though only rank 0 writes the file. Saving only the trainable parameters (Perceiver + cross-attention + optional embeddings) reduces checkpoint size from that of the full model to roughly the trainable fraction.
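To make the size reduction concrete, the sketch below measures the bytes stored in a full state dict and compares them with a state dict filtered to trainable parameters plus embeddings. The toy model, its layer sizes, and both helper names are hypothetical and chosen only for illustration. Real proportions depend on the actual vision encoder and language model.

```python
import torch.nn as nn


def state_dict_bytes(state_dict: dict) -> int:
    """Total tensor storage in a state dict, in bytes."""
    return sum(t.numel() * t.element_size() for t in state_dict.values())


def trainable_plus_embeddings(model: nn.Module) -> dict:
    """State dict without frozen, non-embedding parameters."""
    frozen = {n for n, p in model.named_parameters()
              if not p.requires_grad and "embed" not in n}
    return {k: v for k, v in model.state_dict().items() if k not in frozen}


# Hypothetical toy model: a frozen backbone, a frozen embedding table,
# and a small trainable cross-attention stand-in (float32 throughout).
model = nn.ModuleDict({
    "embed": nn.Embedding(1000, 64),
    "backbone": nn.Linear(64, 4096),
    "xattn": nn.Linear(64, 64),
})
for name in ("embed", "backbone"):
    model[name].requires_grad_(False)

full = state_dict_bytes(model.state_dict())
kept = state_dict_bytes(trainable_plus_embeddings(model))
print(full, kept)  # 1337600 272640
```

Here the kept state is about 20% of the full one. The embedding table accounts for most of what remains, so keeping embeddings has a visible cost when the vocabulary is large.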