Principle:Facebookresearch Audiocraft Training Checkpoint Management

From Leeroopedia

Overview

Training Checkpoint Management is the system for persisting and restoring training state -- model weights, optimizer state, learning rate scheduler state, exponential moving average (EMA) weights, best model state, and epoch counters -- to enable experiment resumption, failure recovery, and model deployment. In the MusicGen training pipeline, checkpoint management must handle the complexities of distributed training (including FSDP sharded checkpoints), safe atomic writes to prevent corruption, and integration with Dora's experiment tracking system.

Theoretical Foundations

Training State Persistence

Deep learning training is a stateful process. At any point, the training state consists of:

  • Model weights -- The current parameters of the neural network.
  • Optimizer state -- Momentum buffers, adaptive learning rate statistics (e.g., Adam's first and second moment estimates).
  • LR scheduler state -- The current step count and any scheduler-specific state.
  • EMA state -- Exponential moving average of model weights, used for stable evaluation.
  • Best state -- A copy of the model weights that achieved the best validation metric.
  • Epoch counter -- Current epoch and training history.

All of these must be saved together to enable exact resumption of training. Missing or inconsistent state can lead to training divergence or wasted compute.
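Here is a minimal sketch of why these components travel together. It bundles them into one object and persists that object as a single file. The sketch uses the standard library's pickle in place of torch.save, with plain dicts standing in for real state_dict() outputs. TrainingState, save_state and load_state are hypothetical names, not Audiocraft APIs.

```python
import pickle
from dataclasses import asdict, dataclass, field
from typing import Any, Dict, List


@dataclass
class TrainingState:
    """Hypothetical container for every piece of state needed to resume exactly."""
    model: Dict[str, Any]          # model weights
    optimizer: Dict[str, Any]      # e.g. Adam moment estimates and step count
    lr_scheduler: Dict[str, Any]   # scheduler step count and internals
    ema: Dict[str, Any]            # exponential moving average of the weights
    best_state: Dict[str, Any]     # weights with the best validation metric so far
    epoch: int                     # current epoch
    history: List[Dict[str, Any]] = field(default_factory=list)  # per-epoch metrics


def save_state(state: TrainingState, path: str) -> None:
    # Serialize all components in one call so they always come from the same step.
    with open(path, "wb") as f:
        pickle.dump(asdict(state), f)


def load_state(path: str) -> TrainingState:
    # Rebuild the container from the plain dict that was written.
    with open(path, "rb") as f:
        return TrainingState(**pickle.load(f))
```

Because everything is written in one call, a reload restores every component from the same point in training.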

Distributed Checkpoint Challenges

When training with Fully Sharded Data Parallel (FSDP), model parameters are sharded across GPUs. This creates two checkpoint formats:

  • Rank-0 consolidated checkpoint -- A single file containing the full model state, used for evaluation and deployment. Only rank 0 writes this.
  • Sharded per-rank checkpoints -- Each rank saves its own shard. These are used for FSDP training resumption where each rank needs only its shard.
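The following toy sketch makes the difference between the two formats concrete. Each parameter is a flat list split evenly across ranks. The functions shard and consolidate are hypothetical. Real FSDP also flattens and pads parameters, so this illustrates only the data layout, not PyTorch's API.

```python
from typing import Dict, List

# Toy model: parameter name -> flat list of floats.
Params = Dict[str, List[float]]


def shard(params: Params, rank: int, world_size: int) -> Params:
    """Return the contiguous slice of every parameter owned by `rank` (simplified, no padding)."""
    out: Params = {}
    for name, flat in params.items():
        per_rank = -(-len(flat) // world_size)  # ceiling division
        out[name] = flat[rank * per_rank:(rank + 1) * per_rank]
    return out


def consolidate(shards: List[Params]) -> Params:
    """Concatenate per-rank shards (ordered by rank) back into full parameters."""
    full: Params = {}
    for shard_params in shards:
        for name, piece in shard_params.items():
            full.setdefault(name, []).extend(piece)
    return full
```

Resuming sharded training needs only each rank's own slice. Evaluation and export need the consolidated dict, which only rank 0 writes.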

The checkpoint system must handle:

  • Atomic writes -- Using write-and-rename patterns to prevent corruption if a process dies mid-write.
  • Barrier synchronization -- Ensuring all ranks have finished writing before any rank starts reading.
  • Cross-format loading -- Loading rank-0 checkpoints into FSDP training (and vice versa).
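The write-and-rename pattern from the first point can be sketched with the standard library as follows. atomic_save is a hypothetical helper, and pickle stands in for torch.save. The new bytes are written to a temporary file in the same directory and then moved over the target with os.replace. Readers therefore see either the previous checkpoint or the complete new one, never a truncated file.

```python
import os
import pickle
import tempfile
from typing import Any


def atomic_save(obj: Any, path: str) -> None:
    """Write `obj` to `path` so the target is only ever the old or the complete new file."""
    directory = os.path.dirname(os.path.abspath(path))
    # Same directory as the target, so the final rename stays on one filesystem.
    fd, tmp_path = tempfile.mkstemp(dir=directory, prefix=".tmp_", suffix=".ckpt")
    try:
        with os.fdopen(fd, "wb") as f:
            pickle.dump(obj, f)
            f.flush()
            os.fsync(f.fileno())  # push bytes to disk before the rename exposes them
        os.replace(tmp_path, path)  # atomic rename on POSIX filesystems
    except BaseException:
        # On any failure, leave the previous checkpoint untouched and remove the temp file.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise
```

In a distributed job, each rank would call such a helper for its own file. It would then wait at a barrier (for example torch.distributed.barrier()) before any rank reads. That synchronization is omitted here.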

Best State Management

During training, the best state tracks the model weights that achieved the lowest validation metric (typically cross-entropy). This is maintained separately from the latest state because:

  • Evaluation and generation use the best state, not the latest state.
  • If training diverges late, the best state preserves the best-performing weights.
  • When loading a checkpoint from a different experiment (via continue_from), only the best state is transferred.

The best state is updated after each validation epoch by comparing the current validation metric to historical values.
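Here is a minimal sketch of that comparison. It assumes lower is better and that the tracked metric is named "ce". BestStateTracker is a hypothetical helper, not Audiocraft's class. The deep copy matters because training keeps updating the live weights in place.

```python
import copy
from typing import Dict, List, Optional


class BestStateTracker:
    """Hypothetical helper: keep a snapshot of the weights with the lowest validation metric."""

    def __init__(self, metric_name: str = "ce") -> None:
        self.metric_name = metric_name
        self.history: List[float] = []
        self.best_state: Optional[Dict] = None

    def update(self, valid_metrics: Dict[str, float], model_state: Dict) -> bool:
        """Record this validation result; snapshot the weights if it beats every earlier value."""
        current = valid_metrics[self.metric_name]
        is_best = not self.history or current < min(self.history)
        self.history.append(current)
        if is_best:
            # Deep copy so later in-place training updates cannot alter the snapshot.
            self.best_state = copy.deepcopy(model_state)
        return is_best
```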

Experiment Resumption

The checkpoint system supports three resumption scenarios:

  1. Resume current experiment -- Load the latest checkpoint from the current Dora experiment folder. Restores full training state including optimizer.
  2. Continue from another experiment -- Load only the best state from another experiment's checkpoint (identified by Dora sig or path). Optimizer and scheduler are re-initialized.
  3. Continue from pretrained -- Load a pretrained model (e.g., from Hugging Face) and use only its weights as initialization. Triggered by the //pretrained/ prefix in continue_from.
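The three scenarios can be sketched as a small dispatch function. Source and resolve_resume are hypothetical names. The sketch assumes that an existing checkpoint in the current experiment takes precedence over continue_from. This precedence rule is not stated above; it is a design choice that lets a preempted job resume where it stopped rather than re-initializing.

```python
from enum import Enum
from typing import Optional, Tuple

PRETRAINED_PREFIX = "//pretrained/"


class Source(Enum):
    """Hypothetical labels for where the loaded state comes from."""
    CURRENT_XP = "current_xp"
    OTHER_XP = "other_xp"
    PRETRAINED = "pretrained"


def resolve_resume(current_xp_checkpoint: Optional[str],
                   continue_from: Optional[str]) -> Tuple[Optional[Source], Optional[str], bool]:
    """Return (source, location, load_full_state) for a training run."""
    if current_xp_checkpoint is not None:
        # Scenario 1: weights, optimizer, scheduler, EMA and epoch are all restored.
        return Source.CURRENT_XP, current_xp_checkpoint, True
    if continue_from is None:
        return None, None, False  # fresh start
    if continue_from.startswith(PRETRAINED_PREFIX):
        # Scenario 3: pretrained weights only, identified by the name after the prefix.
        return Source.PRETRAINED, continue_from[len(PRETRAINED_PREFIX):], False
    # Scenario 2: another experiment's sig or path; only its best state is transferred.
    return Source.OTHER_XP, continue_from, False
```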

Key Principles

  • Checkpoint-per-epoch -- A checkpoint is saved at the end of every epoch. Additionally, periodic full checkpoints (e.g., every 50 epochs) are saved for long-term preservation.
  • Atomic writes -- Checkpoints use a write-to-temporary-then-rename pattern, so a process killed mid-write (e.g., by the OOM killer or job preemption) never leaves a truncated file in place of the previous checkpoint.
  • Consistent distributed state -- After loading, all ranks verify they have the same epoch number to detect out-of-sync checkpoint corruption.
  • Stale checkpoint flushing -- Old per-epoch checkpoints are automatically deleted to manage disk usage, keeping only the most recent N.
  • State source registration -- Components register themselves as stateful via register_stateful(), and the solver's state_dict() and load_state_dict() handle serialization of all registered components.
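The sketch below covers two of these principles with the standard library: stateful registration and stale-checkpoint flushing. MiniSolver, StepCounter and flush_stale are hypothetical. Their signatures are illustrative and are not Audiocraft's register_stateful(). The naming scheme checkpoint_<epoch>.th is also an assumption.

```python
import os
import re
from typing import Any, Dict, List


class StepCounter:
    """Hypothetical stateful component (e.g. a scheduler stand-in)."""

    def __init__(self) -> None:
        self.last_step = 0

    def step(self) -> None:
        self.last_step += 1

    def state_dict(self) -> Dict[str, int]:
        return {"last_step": self.last_step}

    def load_state_dict(self, state: Dict[str, int]) -> None:
        self.last_step = state["last_step"]


class MiniSolver:
    """Hypothetical solver: registered components are serialized together."""

    def __init__(self) -> None:
        self._stateful: Dict[str, Any] = {}
        self.epoch = 0

    def register_stateful(self, name: str, obj: Any) -> None:
        # Any object exposing state_dict()/load_state_dict() can be registered.
        self._stateful[name] = obj

    def state_dict(self) -> Dict[str, Any]:
        state = {name: obj.state_dict() for name, obj in self._stateful.items()}
        state["epoch"] = self.epoch
        return state

    def load_state_dict(self, state: Dict[str, Any]) -> None:
        self.epoch = state["epoch"]
        for name, obj in self._stateful.items():
            obj.load_state_dict(state[name])


def flush_stale(folder: str, keep_last: int) -> List[str]:
    """Delete all but the newest `keep_last` per-epoch files; return the removed names."""
    pattern = re.compile(r"checkpoint_(\d+)\.th")
    found = []
    for fname in os.listdir(folder):
        match = pattern.fullmatch(fname)
        if match:
            found.append((int(match.group(1)), fname))
    found.sort()  # numeric epoch order, so epoch 100 sorts after epoch 99
    stale = found[:-keep_last] if keep_last > 0 else found
    for _, fname in stale:
        os.remove(os.path.join(folder, fname))
    return [fname for _, fname in stale]
```

Registration keeps serialization in one place. Adding a new stateful component does not require touching the save or load code.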

Role in the MusicGen Training Pipeline

Checkpoint management is the final stage of each training epoch and the first operation when resuming:

  1. On resume -- StandardSolver.restore() is called at the beginning of run(), loading the latest checkpoint and replaying metric history.
  2. On epoch end -- StandardSolver.commit() calls save_checkpoints() to persist the current state.
  3. On evaluation -- The best state is swapped in via swap_best_state() context manager.
  4. On deployment -- The best state checkpoint is extracted for model export.
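The swap in step 3 can be sketched as a context manager. It loads the best weights for the duration of the with block and restores the latest weights afterwards, even on error. Model and this swap_best_state are hypothetical stand-ins that assume a state_dict()/load_state_dict() interface. Audiocraft's own swap_best_state() may differ in detail.

```python
import copy
from contextlib import contextmanager
from typing import Dict, Iterator


class Model:
    """Hypothetical stand-in for an nn.Module: weights are a plain dict."""

    def __init__(self, weights: Dict[str, float]) -> None:
        self.weights = weights

    def state_dict(self) -> Dict[str, float]:
        return copy.deepcopy(self.weights)

    def load_state_dict(self, state: Dict[str, float]) -> None:
        self.weights = copy.deepcopy(state)


@contextmanager
def swap_best_state(model: Model, best_state: Dict[str, float]) -> Iterator[None]:
    """Temporarily load `best_state` into `model`; restore the latest weights on exit."""
    latest = model.state_dict()
    model.load_state_dict(best_state)
    try:
        yield
    finally:
        # Runs even if evaluation raises, so training continues from the latest weights.
        model.load_state_dict(latest)
```

The try/finally is the key design choice. Without it, an exception during evaluation or generation could leave the best weights loaded, and training would silently continue from them.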
