Principle: Allenai Open-Instruct GRPO Checkpointing
| Knowledge Sources | |
|---|---|
| Domains | Training Infrastructure, Reinforcement Learning |
| Last Updated | 2026-02-07 00:00 GMT |
Overview
GRPO checkpointing is the process of periodically saving model weights, optimizer states, and training metadata during RL training to enable resumption from failures and export of trained models.
Description
RL training runs are typically long (hours to days) and expensive. Checkpointing serves multiple purposes:
- Fault tolerance: If training crashes (GPU failure, OOM, node preemption), the run can resume from the last checkpoint rather than starting over.
- Model selection: Intermediate checkpoints allow selecting the best model based on evaluation metrics, rather than being forced to use the final model.
- Model distribution: The final trained model is saved in HuggingFace format for sharing and deployment.
The GRPO checkpointing system has two distinct modes:
- Lightweight model checkpoints (controlled by save_freq): Save only the model weights and tokenizer at regular intervals. These are fast to write and can be used for evaluation or fine-tuning. Old checkpoints are pruned so that only the last N are kept (keep_last_n_checkpoints).
- Full training state checkpoints (controlled by checkpoint_state_freq): Save the complete DeepSpeed state, including model weights, optimizer states, learning rate scheduler state, RNG states for all devices, data loader state (epoch, batch position, excluded indices), and the training step counter. These enable exact resumption of training from the checkpoint.
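The two cadences can be sketched as modulo checks inside the training loop. This is a minimal illustration, not the actual open-instruct implementation: the save functions are stand-ins that only create directories, and the parameter names mirror the configuration names from the text.

```python
import os
import shutil

def save_model_checkpoint(output_dir: str, step: int) -> str:
    """Lightweight save: model weights + tokenizer only (stubbed here)."""
    path = os.path.join(output_dir, f"step_{step}")
    os.makedirs(path, exist_ok=True)
    return path

def prune_checkpoints(output_dir: str, keep_last_n: int) -> None:
    """Delete all but the newest N lightweight checkpoints."""
    ckpts = sorted(
        (d for d in os.listdir(output_dir) if d.startswith("step_")),
        key=lambda d: int(d.split("_")[1]),
    )
    for old in ckpts[:-keep_last_n]:
        shutil.rmtree(os.path.join(output_dir, old))

def training_loop(total_steps, save_freq, checkpoint_state_freq,
                  output_dir, state_dir, keep_last_n_checkpoints):
    for step in range(1, total_steps + 1):
        # ... one GRPO training step would run here ...
        if step % save_freq == 0:
            save_model_checkpoint(output_dir, step)
            prune_checkpoints(output_dir, keep_last_n_checkpoints)
        if step % checkpoint_state_freq == 0:
            # Full DeepSpeed state (optimizer, scheduler, RNG, loader)
            # would be written here; only the directory is created.
            os.makedirs(os.path.join(state_dir, f"global_step{step}"),
                        exist_ok=True)
```

Note that lightweight checkpoints are pruned while full state checkpoints are not pruned in this sketch; real setups often keep only the most recent full state as well, since it is much larger.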
The final model save additionally:
- Converts from DeepSpeed format to standard HuggingFace format.
- Optionally pushes to the HuggingFace Hub.
- Optionally launches downstream evaluation jobs on Beaker.
- Optionally uploads to Google Cloud Storage for archival.
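The optional post-save steps amount to a flag-gated dispatch after the format conversion. The sketch below models only that control flow; the flag names other than push_to_hub (launch_beaker_evals, gs_bucket_path) are assumptions, and the actions are logged rather than performed.

```python
from dataclasses import dataclass
from typing import List, Optional

@dataclass
class FinalSaveArgs:
    push_to_hub: bool = False
    launch_beaker_evals: bool = False   # assumed flag name
    gs_bucket_path: Optional[str] = None  # assumed flag name

def finalize_model(hf_dir: str, args: FinalSaveArgs, log: List[str]) -> None:
    # The conversion to HuggingFace format always happens.
    log.append(f"convert deepspeed -> huggingface at {hf_dir}")
    if args.push_to_hub:
        log.append("push to HuggingFace Hub")
    if args.launch_beaker_evals:
        log.append("launch Beaker evaluation jobs")
    if args.gs_bucket_path is not None:
        log.append(f"upload to {args.gs_bucket_path}")
```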
Usage
Checkpointing is configured at experiment setup time via save_freq, checkpoint_state_freq, checkpoint_state_dir, output_dir, and push_to_hub. The main training loop calls checkpointing functions at the appropriate intervals.
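A hypothetical invocation wiring these options together might look like the following; the script name, flag spellings, and paths are illustrative assumptions, not the project's exact CLI.

```shell
# Illustrative only: script name, flag syntax, and paths are assumed.
python open_instruct/grpo_train.py \
  --save_freq 100 \
  --checkpoint_state_freq 500 \
  --checkpoint_state_dir /checkpoints/grpo_run1 \
  --output_dir ./output/grpo_run1 \
  --push_to_hub
```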
Theoretical Basis
Checkpoint Consistency
For distributed training, checkpoint consistency requires that all ranks save their state at the same logical point:
For each checkpoint:
    barrier()            -- all ranks reach the same training step
    rank_0:    save model weights (gathered from all ranks if ZeRO stage 3)
    all_ranks: save DeepSpeed state (optimizer, scheduler)
    all_ranks: save RNG states (CPU, CUDA per device, numpy, python)
    rank_0:    save data loader state (epoch, batch position, excluded indices)
    barrier()            -- all ranks confirm save completion
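The rank-dependent save pattern can be sketched in plain Python. Here barrier() is a no-op stand-in for torch.distributed.barrier(), and the files are empty placeholders; in a real distributed run every rank executes this function concurrently.

```python
import os

def barrier() -> None:
    pass  # all ranks would synchronize here (torch.distributed.barrier())

def save_checkpoint(rank: int, ckpt_dir: str) -> None:
    barrier()  # all ranks reach the same training step
    os.makedirs(ckpt_dir, exist_ok=True)
    if rank == 0:
        # Model weights, gathered from all ranks under ZeRO stage 3.
        open(os.path.join(ckpt_dir, "model_weights.bin"), "wb").close()
        # Data loader state: epoch, batch position, excluded indices.
        open(os.path.join(ckpt_dir, "dataloader_state.json"), "w").close()
    # Every rank saves its own optimizer/scheduler shard and RNG states.
    open(os.path.join(ckpt_dir, f"deepspeed_state_rank{rank}.bin"), "wb").close()
    open(os.path.join(ckpt_dir, f"rng_state_rank{rank}.pkl"), "wb").close()
    barrier()  # all ranks confirm save completion
```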
Resumption Protocol
When resuming from a checkpoint:
1. Load DeepSpeed state (model weights, optimizer, scheduler)
2. Restore RNG states on each device
3. Restore data loader state:
a. Set epoch to checkpoint epoch
b. Reshuffle with the same seed + epoch
c. Skip to the checkpointed batch position
4. Restore data preparation actor state:
a. Set training step to checkpoint step
b. Restore iterator dataloader state
5. Resume training loop from the next step
This protocol ensures exact reproducibility: given the same hardware configuration and random seeds, a resumed run will produce identical results to an uninterrupted run.
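Two of the steps above can be demonstrated in miniature with Python's stdlib RNG: restoring a saved RNG state makes a resumed run draw the same numbers as an uninterrupted one, and seeding the shuffle with seed + epoch makes the data order recoverable so the loader can skip to the checkpointed batch position. (Real resumption restores CUDA and numpy states too.)

```python
import pickle
import random

# --- RNG state restore: resumed draws match the uninterrupted run ---
rng = random.Random(42)
_warmup = [rng.random() for _ in range(3)]      # draws before the checkpoint
state = pickle.dumps(rng.getstate())            # checkpointed RNG state

continued = [rng.random() for _ in range(3)]    # uninterrupted run continues

resumed_rng = random.Random()
resumed_rng.setstate(pickle.loads(state))       # resumed run restores state
resumed = [resumed_rng.random() for _ in range(3)]
assert resumed == continued                     # exact reproducibility

# --- Deterministic reshuffle with seed + epoch, then skip to position ---
def epoch_order(seed: int, epoch: int, n: int) -> list:
    order = list(range(n))
    random.Random(seed + epoch).shuffle(order)
    return order

order = epoch_order(seed=7, epoch=2, n=10)
remaining = order[4:]   # checkpointed batch position = 4; skip consumed batches
```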
Reference Policy Checkpointing
When using a reference policy with Polyak updates, the reference policy weights must also be saved and restored. They are stored as a separate pytorch_model.bin in a ref_policy/ subdirectory of the checkpoint.
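A minimal sketch of that layout, using pickle as a stand-in for torch.save (the helper name is hypothetical):

```python
import os
import pickle

def save_ref_policy(ckpt_dir: str, ref_state_dict: dict) -> str:
    """Write reference-policy weights to <ckpt_dir>/ref_policy/pytorch_model.bin."""
    ref_dir = os.path.join(ckpt_dir, "ref_policy")
    os.makedirs(ref_dir, exist_ok=True)
    path = os.path.join(ref_dir, "pytorch_model.bin")
    with open(path, "wb") as f:
        pickle.dump(ref_state_dict, f)  # torch.save in a real trainer
    return path
```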