Principle:ContextualAI HALOs Model Checkpointing
| Knowledge Sources | |
|---|---|
| Domains | Deep_Learning, Infrastructure |
| Last Updated | 2026-02-08 03:00 GMT |
Overview
A persistence strategy that saves model weights, optimizer state, and training metrics to disk, with automatic merging of LoRA adapters into the base model and support for distributed (FSDP) training.
Description
Model checkpointing in LLM training must handle several complexities beyond a naive `torch.save`: distributed model sharding (FSDP splits weights across GPUs), parameter-efficient adapters (LoRA weights must be merged back into the base model for inference), and training resumability (optimizer and scheduler state must be preserved).
The HALOs checkpointing principle addresses these by:
- Using Accelerate's `get_state_dict` to properly gather sharded weights from FSDP
- Manually merging LoRA weights into the base model using the formula `W_merged = W_base + (alpha / r) * B @ A` when `final_save=True`
- Saving tokenizer, optimizer state, scheduler state, and metrics JSON alongside model weights
- Supporting both intermediate checkpoints (during training) and final checkpoints (with LoRA merge)
Usage
Model checkpointing is automatically invoked at the end of training and optionally at regular intervals. It is used in every training workflow: SFT, preference alignment, reward model training, and online iterative alignment (where each round saves a checkpoint that the next round loads).
Theoretical Basis
LoRA Merge Formula
When using Low-Rank Adaptation (LoRA), the model has additional low-rank matrices A and B. The merged weight is computed as:

`W_merged = W_base + (alpha / r) * B @ A`

Where:
- `W_base` is the original pre-trained weight
- `A` and `B` are the LoRA matrices
- `alpha` is the LoRA scaling factor
- `r` is the LoRA rank
This merge produces a single set of weights that can be loaded without the PEFT library.
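The merge formula can be demonstrated directly. This is a minimal sketch using NumPy in place of PyTorch tensors; the function name and shapes are illustrative, following the usual LoRA convention where `W_base` is `(d_out, d_in)`, `B` is `(d_out, r)`, and `A` is `(r, d_in)`.

```python
import numpy as np

def merge_lora(w_base: np.ndarray, a: np.ndarray, b: np.ndarray,
               alpha: float, r: int) -> np.ndarray:
    """Return W_merged = W_base + (alpha / r) * B @ A."""
    return w_base + (alpha / r) * (b @ a)

# Tiny worked example: a rank-1 update on a 2x2 weight.
w = np.zeros((2, 2))
b = np.array([[1.0], [0.0]])   # (d_out=2, r=1)
a = np.array([[2.0, 0.0]])     # (r=1, d_in=2)
merged = merge_lora(w, a, b, alpha=4.0, r=1)
# scaling is alpha/r = 4, so merged = [[8, 0], [0, 0]]
```

Because the update is folded into `W_base`, the merged weights load as an ordinary state dict, with no adapter hooks or PEFT dependency at inference time.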