
Principle:AllenAI Open-Instruct Checkpoint Saving

From Leeroopedia


Knowledge Sources
Domains Machine Learning, Distributed Systems, MLOps
Last Updated 2026-02-07 00:00 GMT

Overview

Checkpoint saving is the process of persisting model weights, tokenizer files, and generation configuration to disk during and after training, with special handling for distributed training and parameter-efficient fine-tuning.

Description

Training large language models can take hours to days. Without checkpointing, a hardware failure, preemption, or bug would require restarting from scratch. Checkpoint saving provides fault tolerance and also produces intermediate model snapshots that can be evaluated or used for downstream tasks.

Checkpoint saving in distributed training is more complex than single-GPU training because:

State dict gathering: In distributed training (e.g., with DeepSpeed ZeRO), model parameters are sharded across multiple GPUs, so saving requires gathering the full state dict onto a single process. HuggingFace Accelerate's get_state_dict() handles this, but it must be called on the wrapped model (not the unwrapped one) to obtain the complete state dict.

Main process saving: Only the main process (rank 0) should write to disk to avoid file conflicts. Other processes may have state_dict = None when using DeepSpeed ZeRO.
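A minimal sketch of these two steps, assuming an object exposing HuggingFace Accelerate's interface (get_state_dict, is_main_process, unwrap_model); the function name save_sharded_model is illustrative:

```python
import os

def save_sharded_model(accelerator, model, output_dir):
    """Gather a full state dict across ranks, then write from rank 0 only.

    `accelerator` is assumed to expose the HuggingFace Accelerate
    interface (get_state_dict, is_main_process, unwrap_model).
    """
    # The gather is a collective: every rank must call it, and the
    # *wrapped* model is passed so ZeRO shards are actually collected.
    state_dict = accelerator.get_state_dict(model)

    # Only the main process writes to disk. Under DeepSpeed ZeRO,
    # `state_dict` may be None on all other ranks.
    if accelerator.is_main_process:
        os.makedirs(output_dir, exist_ok=True)
        unwrapped = accelerator.unwrap_model(model)
        unwrapped.save_pretrained(output_dir, state_dict=state_dict)
```

Note that the gather happens outside the `is_main_process` guard: collectives must run on every rank, or the job deadlocks.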

LoRA adapter saving: When using LoRA, only the adapter weights (not the full model) need to be saved. The PEFT library's save_pretrained() handles this, saving only the small LoRA matrices. This dramatically reduces checkpoint size.
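A hedged sketch of branching on LoRA, assuming a PEFT PeftModel (whose save_pretrained writes only the adapter files); the flag name `use_lora` is illustrative:

```python
import os

def save_model_or_adapter(accelerator, model, output_dir, use_lora):
    """Write either the small LoRA adapter or the full weights.

    With PEFT, calling save_pretrained() on a PeftModel writes only
    the adapter files (adapter config plus the low-rank A/B matrices),
    typically megabytes instead of gigabytes.
    """
    state_dict = accelerator.get_state_dict(model)  # collective, all ranks
    if accelerator.is_main_process:
        os.makedirs(output_dir, exist_ok=True)
        unwrapped = accelerator.unwrap_model(model)
        if use_lora:
            unwrapped.save_pretrained(output_dir)  # adapters only
        else:
            unwrapped.save_pretrained(output_dir, state_dict=state_dict)
```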

Model attribute saving: For composite models (e.g., PPO models with separate policy and value heads), a specific attribute of the model can be saved by filtering the state dict to only include keys starting with the attribute name.
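The prefix-filtering step can be sketched as a plain dict comprehension (the attribute name "policy" below is illustrative, not necessarily open-instruct's naming):

```python
def filter_state_dict_by_attribute(state_dict, attribute):
    """Keep only the entries for one sub-module (e.g. a PPO wrapper's
    'policy'), stripping the prefix so the result loads directly into
    that sub-module."""
    prefix = attribute + "."
    return {
        key[len(prefix):]: value
        for key, value in state_dict.items()
        if key.startswith(prefix)
    }
```

For a composite PPO model, filtering on "policy" drops the value-head weights, and stripping the prefix makes the remaining keys match what the standalone policy model expects at load time.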

Generation config: The model's generation configuration must be set correctly before saving, as it affects inference behavior. For OLMo models, a specific generation config with two EOS tokens is used.

Tokenizer saving: The tokenizer is saved alongside the model to ensure inference uses the exact same tokenization.
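The last two points can be sketched together, assuming HF-style save_pretrained methods and a generation_config attribute; the function name and the example EOS IDs in the test are illustrative, not OLMo's actual token IDs:

```python
def attach_config_and_save(model, tokenizer, generation_config, output_dir):
    """Set the generation config before saving, then save the tokenizer
    next to the weights so inference tokenizes exactly as training did.

    `generation_config` would be a transformers GenerationConfig; for
    OLMo-style models its eos_token_id can hold two token IDs so that
    generation stops on either one.
    """
    model.generation_config = generation_config  # becomes the generate() default
    model.save_pretrained(output_dir)
    tokenizer.save_pretrained(output_dir)
```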

Usage

Use checkpoint saving at regular intervals during training and always at the end of training. Configure the frequency based on the trade-off between checkpoint overhead (disk space and save time) and the compute lost when recovering from a failure.

Theoretical Basis

Checkpoint frequency trade-off:

Cost of checkpointing every N steps:
  Disk usage = (total_steps / N) * checkpoint_size
  Time overhead = (total_steps / N) * save_time

Cost of NOT checkpointing (failure at step T):
  Wasted compute = T * cost_per_step

Optimal N minimizes the expected total cost:
  E[total_cost] = training_cost + P(failure) * (N/2) * cost_per_step + (total_steps / N) * save_cost

Where P(failure) is the probability of a failure during training; on average a failure wastes half a checkpoint interval, i.e. N/2 steps of compute.
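The trade-off can be evaluated numerically; the figures below (10k steps, 20% failure probability, 1 cost unit per step, 50 units per save) are illustrative assumptions, not measurements:

```python
import math

def expected_overhead(N, total_steps, p_failure, cost_per_step, save_cost):
    """Expected extra cost of checkpointing every N steps: a failure
    loses N/2 steps on average, and total_steps/N saves are always paid."""
    return p_failure * (N / 2) * cost_per_step + (total_steps / N) * save_cost

# Evaluate a few candidate intervals under the assumed numbers.
costs = {N: expected_overhead(N, 10_000, 0.2, 1.0, 50.0)
         for N in (100, 250, 500, 1000, 2000)}
best = min(costs, key=costs.get)

# Setting the derivative with respect to N to zero gives the analytic
# optimum: N* = sqrt(2 * total_steps * save_cost / (p_failure * cost_per_step))
n_star = math.sqrt(2 * 10_000 * 50.0 / (0.2 * 1.0))  # about 2236 steps
```

With these numbers the discrete search picks N = 2000, consistent with the analytic optimum of roughly 2236 steps.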

State dict gathering in ZeRO-3:

ZeRO-3 shards the P model parameters across D devices:
  device_i holds: params[i * P/D : (i+1) * P/D]

Gathering for checkpoint:
  all_gather(params) -> full_params on rank 0
  rank 0: save(full_params)
  all other ranks: no-op
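The partitioning and gather above can be simulated in a single process as a toy sketch; real training would use torch.distributed or DeepSpeed collectives instead of these plain list operations:

```python
def shard_params(params, num_devices):
    """ZeRO-3-style partitioning: device i holds the i-th contiguous
    slice of the flat parameter list (assumes even divisibility)."""
    per_device = len(params) // num_devices
    return [params[i * per_device:(i + 1) * per_device]
            for i in range(num_devices)]

def gather_for_checkpoint(shards, rank):
    """Every rank joins the (collective) gather, but only rank 0 keeps
    the reassembled parameters for saving; the others get None."""
    full_params = [p for shard in shards for p in shard]
    return full_params if rank == 0 else None
```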

LoRA checkpoint size:

Full model checkpoint: ~2 * P bytes (bfloat16)
LoRA checkpoint: ~2 * P_lora bytes

For 7B model with rank=64:
  Full: ~14 GB
  LoRA: ~58 MB (0.4% of full)
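The arithmetic behind such figures depends on which modules are adapted (the ~58 MB above implies a somewhat different target-module set than assumed here). As a sketch, assume a 7B-class model with 32 layers, hidden size 4096, rank 64, and LoRA on two 4096x4096 attention projections per layer:

```python
def lora_param_count(num_layers, d_model, rank, matrices_per_layer):
    """Each adapted d x d weight matrix gains two low-rank factors,
    A (rank x d) and B (d x rank): 2 * rank * d parameters per matrix."""
    return num_layers * matrices_per_layer * 2 * rank * d_model

# Assumed config: 32 layers, hidden size 4096, rank 64, LoRA on two
# attention projections (e.g. q_proj and v_proj) per layer.
params = lora_param_count(num_layers=32, d_model=4096, rank=64,
                          matrices_per_layer=2)
size_mb = params * 2 / 1e6  # bfloat16: 2 bytes per parameter
```

Under these assumptions the adapter holds about 33.6M parameters, roughly 67 MB in bfloat16, the same order of magnitude as the quoted ~58 MB.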

Related Pages

Implemented By
