Principle: AllenAI open-instruct Checkpoint Saving
| Knowledge Sources | |
|---|---|
| Domains | Machine Learning, Distributed Systems, MLOps |
| Last Updated | 2026-02-07 00:00 GMT |
Overview
Checkpoint saving is the process of persisting model weights, tokenizer files, and generation configuration to disk during and after training, with special handling for distributed training and parameter-efficient fine-tuning.
Description
Training large language models can take hours to days. Without checkpointing, a hardware failure, preemption, or bug would require restarting from scratch. Checkpoint saving provides fault tolerance and also produces intermediate model snapshots that can be evaluated or used for downstream tasks.
Checkpoint saving in distributed training is more complex than single-GPU training because:
State dict gathering: In distributed training (e.g., with DeepSpeed ZeRO), model parameters are sharded across multiple GPUs, so saving requires gathering the full state dict onto a single process. Accelerate's `Accelerator.get_state_dict()` handles this, but it must be called on the wrapped model (not the unwrapped one) to recover the complete state dict.
Main process saving: Only the main process (rank 0) should write to disk to avoid file conflicts. Other processes may have state_dict = None when using DeepSpeed ZeRO.
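The gather-and-save pattern from the two points above can be sketched as follows. This is a minimal sketch, not the project's exact implementation: `save_with_accelerate` is an illustrative name, and it assumes a HuggingFace Accelerate `Accelerator`, a wrapped training model, and a Transformers tokenizer are passed in.

```python
def save_with_accelerate(accelerator, model, tokenizer, output_dir):
    """Gather the (possibly ZeRO-sharded) state dict and save from rank 0.

    Sketch only: `accelerator` is an Accelerate `Accelerator`, `model` is
    the *wrapped* training model, `tokenizer` a HuggingFace tokenizer.
    """
    # Must be called on the wrapped model; with DeepSpeed ZeRO-3 this
    # triggers the cross-rank gather of the full parameters.
    state_dict = accelerator.get_state_dict(model)

    # Only the main process writes to disk; on other ranks state_dict
    # may be None under DeepSpeed ZeRO.
    unwrapped = accelerator.unwrap_model(model)
    unwrapped.save_pretrained(
        output_dir,
        is_main_process=accelerator.is_main_process,
        save_function=accelerator.save,
        state_dict=state_dict,
    )
    if accelerator.is_main_process:
        # Save the tokenizer alongside the weights so inference uses
        # the exact same tokenization.
        tokenizer.save_pretrained(output_dir)
```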
LoRA adapter saving: When using LoRA, only the adapter weights (not the full model) need to be saved. The PEFT library's save_pretrained() handles this, saving only the small LoRA matrices. This dramatically reduces checkpoint size.
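A hedged sketch of the adapter-only save path, assuming the model wraps a PEFT `PeftModel` (the function name is illustrative):

```python
def save_lora_adapter(accelerator, model, output_dir):
    """Save only the LoRA adapter weights, not the full base model.

    Sketch only: assumes `model` wraps a PEFT PeftModel.
    """
    unwrapped = accelerator.unwrap_model(model)
    if accelerator.is_main_process:
        # PeftModel.save_pretrained writes only the small low-rank A/B
        # matrices plus the adapter config, so the checkpoint is a tiny
        # fraction of the full model's size.
        unwrapped.save_pretrained(output_dir)
    # Keep ranks in sync so no process races ahead of the save.
    accelerator.wait_for_everyone()
```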
Model attribute saving: For composite models (e.g., PPO models with separate policy and value heads), a specific attribute of the model can be saved by filtering the state dict to only include keys starting with the attribute name.
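The prefix-filtering step is plain dictionary manipulation and can be sketched directly (the helper name is illustrative):

```python
def filter_state_dict_by_attribute(state_dict, attribute):
    """Keep only keys under `attribute` and strip its prefix.

    E.g. for a PPO model with a `policy` submodule, attribute="policy"
    extracts the policy weights in a form loadable by the standalone
    policy model.
    """
    prefix = attribute + "."
    return {
        key[len(prefix):]: value
        for key, value in state_dict.items()
        if key.startswith(prefix)
    }
```

For example, filtering `{"policy.lm_head.weight": ..., "value_head.weight": ...}` with `attribute="policy"` keeps only `{"lm_head.weight": ...}`.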
Generation config: The model's generation configuration must be set correctly before saving, as it affects inference behavior. For OLMo models, a specific generation config with two EOS tokens is used.
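Setting the generation config before saving might look like the sketch below. The token IDs are deliberately left as parameters: the actual values are model-specific (per the point above, OLMo uses two EOS token IDs), and nothing here should be read as the exact values used.

```python
def set_generation_config(model, eos_token_ids, pad_token_id):
    """Attach a generation config to the model before saving.

    Sketch only: token IDs are placeholders supplied by the caller.
    """
    from transformers import GenerationConfig  # lazy import for the sketch

    # `eos_token_id` accepts a list, so generation stops on any of the
    # listed tokens (e.g. two EOS tokens for OLMo-style models).
    model.generation_config = GenerationConfig(
        eos_token_id=eos_token_ids,
        pad_token_id=pad_token_id,
    )
```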
Tokenizer saving: The tokenizer is saved alongside the model to ensure inference uses the exact same tokenization.
Usage
Use checkpoint saving at regular intervals during training and always at the end of training. Configure the frequency based on the trade-off between disk space and recovery time.
Theoretical Basis
Checkpoint frequency trade-off:
Cost of checkpointing every N steps:
Disk usage = (total_steps / N) * checkpoint_size
Time overhead = (total_steps / N) * save_time
Cost of NOT checkpointing (failure at step T):
Wasted compute = T * cost_per_step
Optimal N minimizes:
E[total_cost] = training_cost + P(failure) * (N/2) * cost_per_step + (total_steps / N) * save_cost
Where P(failure) is the probability of failure during training.
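Setting the derivative of the N-dependent terms to zero gives a closed form: P*c/2 - T*s/N^2 = 0, so N* = sqrt(2*T*s / (P*c)). A quick sketch checks this against a brute-force search (the numbers are purely illustrative):

```python
import math

def optimal_interval(total_steps, save_cost, p_failure, cost_per_step):
    """Closed-form minimizer of P*(N/2)*c + (T/N)*s over N."""
    return math.sqrt(2 * total_steps * save_cost / (p_failure * cost_per_step))

def expected_cost(n, total_steps, save_cost, p_failure, cost_per_step):
    # Expected rework after a failure plus cumulative save overhead.
    return p_failure * (n / 2) * cost_per_step + (total_steps / n) * save_cost

# Illustrative numbers: 10k steps, a save costs 60 step-equivalents,
# 50% chance of one failure during the run.
T, s, p, c = 10_000, 60.0, 0.5, 1.0
n_star = optimal_interval(T, s, p, c)
n_brute = min(range(1, T + 1), key=lambda n: expected_cost(n, T, s, p, c))
print(round(n_star), n_brute)  # closed form matches the brute-force search
```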
State dict gathering in ZeRO-3:
ZeRO-3 shards parameters across D devices:
device_i holds: params[i * P/D : (i+1) * P/D]
Gathering for checkpoint:
all_gather(params) -> full_params on rank 0
rank 0: save(full_params)
all other ranks: no-op
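The shard-then-gather flow above can be simulated with plain Python lists (a toy model of D ranks with no real communication, just to make the indexing concrete):

```python
def shard(params, num_devices):
    """Split a flat parameter list into contiguous per-device shards."""
    per_dev = len(params) // num_devices
    return [params[i * per_dev:(i + 1) * per_dev] for i in range(num_devices)]

def gather_and_save(shards, rank):
    """Toy gather: reconstruct the full params; only rank 0 'saves'."""
    full_params = [p for s in shards for p in s]
    return full_params if rank == 0 else None  # non-zero ranks: no-op

params = list(range(8))
shards = shard(params, num_devices=4)  # device_i holds params[i*P/D:(i+1)*P/D]
assert gather_and_save(shards, rank=0) == params
assert gather_and_save(shards, rank=1) is None
```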
LoRA checkpoint size:
Full model checkpoint: ~2 * P bytes (bfloat16)
LoRA checkpoint: ~2 * P_lora bytes
For 7B model with rank=64:
Full: ~14 GB
LoRA: ~58 MB (0.4% of full)
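The adapter size follows from the LoRA parameter count: each adapted d_out x d_in matrix adds r*(d_in + d_out) parameters for its low-rank A and B factors. A back-of-the-envelope calculator, with the caveat that the module choices and dimensions below are illustrative (the exact figure depends on which projections are adapted):

```python
def lora_params(d_in, d_out, rank):
    """Parameters added per adapted matrix: A (rank x d_in) + B (d_out x rank)."""
    return rank * (d_in + d_out)

def adapter_bytes(num_layers, adapted_per_layer, d_model, rank, bytes_per_param=2):
    """Total adapter size, assuming square d_model x d_model projections."""
    return num_layers * adapted_per_layer * lora_params(d_model, d_model, rank) * bytes_per_param

# Illustrative 7B-class config: 32 layers, q/v projections, hidden 4096, rank 64.
size = adapter_bytes(num_layers=32, adapted_per_layer=2, d_model=4096, rank=64)
print(f"{size / 2**20:.0f} MiB")  # tens of MiB, versus ~14 GB for the full model
```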