Implementation: facebookresearch/audiocraft StandardSolver Checkpoints
Overview
The checkpoint management system in AudioCraft is implemented across StandardSolver (the base solver class) and the audiocraft.utils.checkpoint utility module. Together, they provide saving, loading, and restoration of full training state with support for FSDP sharded checkpoints, best state tracking, and safe atomic writes.
Source Locations
| Component | Source File | Lines |
|---|---|---|
| `StandardSolver.save_checkpoints()` | `audiocraft/solvers/base.py` | 282-309 |
| `StandardSolver.load_checkpoints()` | `audiocraft/solvers/base.py` | 314-430 |
| `StandardSolver.restore()` | `audiocraft/solvers/base.py` | 432-454 |
| `checkpoint_name()` | `audiocraft/utils/checkpoint.py` | 28-48 |
| `resolve_checkpoint_path()` | `audiocraft/utils/checkpoint.py` | 56-84 |
| `load_checkpoint()` | `audiocraft/utils/checkpoint.py` | 87-95 |
| `save_checkpoint()` | `audiocraft/utils/checkpoint.py` | 98-101 |
| `flush_stale_checkpoints()` | `audiocraft/utils/checkpoint.py` | 104-122 |
APIs
StandardSolver.save_checkpoints
```python
def save_checkpoints(self) -> None
```
Saves the current training state as a checkpoint. Behavior:
- Only rank 0 saves non-sharded checkpoints; all ranks save for FSDP.
- Saves a periodic epoch checkpoint (every `cfg.checkpoint.save_every` epochs) with optionally reduced state.
- Saves the latest checkpoint (if `cfg.checkpoint.save_last` is true).
- Flushes stale epoch checkpoints to manage disk usage.
StandardSolver.load_checkpoints
```python
def load_checkpoints(
    self,
    load_best: bool = False,
    ignore_state_keys: List[str] = []
) -> Optional[dict]
```
Loads the most recent checkpoint or the one specified by cfg.continue_from. Handles three source types:
| Source | Trigger | Behavior |
|---|---|---|
| Current XP | Checkpoint exists in current experiment folder | Full state restore (model, optimizer, scheduler, EMA, history) |
| Other XP | `continue_from` is set (not `//pretrained/`) | Load best state only, re-initialize optimizer |
| Pretrained | `continue_from` starts with `//pretrained/` | Load pretrained weights as best state |
After loading, verifies epoch consistency across all distributed ranks.
StandardSolver.restore
```python
def restore(
    self,
    load_best: bool = False,
    replay_metrics: bool = False,
    ignore_state_keys: List[str] = []
) -> bool
```
High-level restoration method that:
- Calls `load_checkpoints()` to load state
- Optionally replays past metrics to result loggers
- Returns whether a checkpoint was successfully loaded
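To make the restore contract concrete, here is a hedged, hypothetical re-implementation of the flow described above. It is a sketch, not AudioCraft's actual code: `log_metrics` and the `StubSolver` class are stand-ins invented for illustration, and the real method works against the solver's registered result loggers.

```python
from typing import List, Optional


def restore_sketch(solver, load_best: bool = False,
                   replay_metrics: bool = False,
                   ignore_state_keys: Optional[List[str]] = None) -> bool:
    """Hypothetical sketch of StandardSolver.restore: load a checkpoint,
    optionally replay past per-epoch metrics, and report success."""
    state = solver.load_checkpoints(load_best, ignore_state_keys or [])
    restored = state is not None
    if restored and replay_metrics:
        # Replay every past epoch's metrics to the result loggers.
        for epoch, stages in enumerate(solver.history):
            for stage_name, metrics in stages.items():
                solver.log_metrics(stage_name, epoch, metrics)
    return restored


class StubSolver:
    """Minimal stand-in used only to exercise the sketch."""
    def __init__(self):
        self.history = [{'train': {'loss': 1.0}}, {'train': {'loss': 0.5}}]
        self.replayed = []

    def load_checkpoints(self, load_best, ignore_state_keys):
        return {'model': {}}  # pretend a checkpoint was found

    def log_metrics(self, stage, epoch, metrics):
        self.replayed.append((stage, epoch, metrics['loss']))
```

The boolean return lets training scripts distinguish a fresh run from a resumed one without inspecting the filesystem themselves.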
Checkpoint Utility Functions
checkpoint_name
```python
def checkpoint_name(
    name: Optional[str] = None,
    rank: Optional[int] = None,
    use_fsdp: bool = False
) -> str
```
Generates checkpoint filenames following the convention: checkpoint_{name}.th{.rank}
Examples:
- Latest: `checkpoint.th`
- Epoch 50: `checkpoint_50.th`
- FSDP rank 1: `checkpoint.th.1`
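The naming convention above can be sketched in a few lines. This is a simplified stand-in: the real helper also derives the rank from the distributed context via `flashy` when it is not passed explicitly.

```python
from typing import Optional


def checkpoint_name_sketch(name: Optional[str] = None,
                           rank: Optional[int] = None,
                           use_fsdp: bool = False) -> str:
    """Sketch of the checkpoint_{name}.th{.rank} naming convention."""
    suffix = ''
    if use_fsdp and rank is not None and rank > 0:
        # Non-zero FSDP ranks get a per-rank shard suffix.
        suffix = f'.{rank}'
    middle = f'_{name}' if name is not None else ''
    return f'checkpoint{middle}.th{suffix}'
```

Keeping rank 0's shard named exactly like a non-sharded checkpoint means path resolution works the same way whether or not FSDP is in use.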
resolve_checkpoint_path
```python
def resolve_checkpoint_path(
    sig_or_path: Union[Path, str],
    name: Optional[str] = None,
    use_fsdp: bool = False
) -> Optional[Path]
```
Resolves a checkpoint path from either:
- A Dora signature (prefixed with `//sig/`) -- resolves to `{dora_dir}/xps/{sig}/checkpoint.th`
- A filesystem path -- optionally resolved via `AudioCraftEnvironment.resolve_reference_path()`
Returns None if the resolved path does not exist.
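The resolution logic can be approximated as follows. In this sketch `dora_dir` is passed explicitly, whereas the real helper reads it from the Dora/AudioCraft environment, and the `AudioCraftEnvironment` reference-path step is omitted.

```python
import tempfile
from pathlib import Path
from typing import Optional, Union


def resolve_checkpoint_path_sketch(sig_or_path: Union[Path, str],
                                   dora_dir: Path) -> Optional[Path]:
    """Sketch: map a //sig/ signature or filesystem path to a checkpoint
    file, returning None when nothing exists at the resolved location."""
    sig_or_path = str(sig_or_path)
    if sig_or_path.startswith('//sig/'):
        sig = sig_or_path[len('//sig/'):]
        path = dora_dir / 'xps' / sig / 'checkpoint.th'
    else:
        path = Path(sig_or_path)
        if path.is_dir():
            # A directory is treated as an experiment folder.
            path = path / 'checkpoint.th'
    return path if path.exists() else None
```

Returning `None` (rather than raising) lets the caller fall through to the next checkpoint source, which is how `load_checkpoints` tries several sources in order.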
load_checkpoint / save_checkpoint
```python
def load_checkpoint(checkpoint_path: Path, is_sharded: bool = False) -> Any
def save_checkpoint(state: Any, checkpoint_path: Path, is_sharded: bool = False) -> None
```
Low-level load/save using torch.load/torch.save with sharded checkpoint safety checks.
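A rough sketch of the pair, with `pickle` standing in for `torch.save`/`torch.load` so the example stays dependency-free. The shard-suffix safety check shown here is an assumption illustrating the idea of "sharded checkpoint safety checks", not the library's actual guard, and the real save path also goes through the atomic write mechanism described under "Safe Write Mechanism".

```python
import pickle
import tempfile
from pathlib import Path
from typing import Any


def save_checkpoint_sketch(state: Any, checkpoint_path: Path) -> None:
    # Sketch only: pickle stands in for torch.save.
    checkpoint_path.parent.mkdir(parents=True, exist_ok=True)
    with open(checkpoint_path, 'wb') as f:
        pickle.dump(state, f)


def load_checkpoint_sketch(checkpoint_path: Path,
                           is_sharded: bool = False) -> Any:
    # Assumed guard: refuse to load a per-rank shard (trailing ".{rank}"
    # suffix) when a full, non-sharded checkpoint is expected.
    if not is_sharded and checkpoint_path.suffix.lstrip('.').isdigit():
        raise ValueError('refusing to load a per-rank shard as a full checkpoint')
    with open(checkpoint_path, 'rb') as f:
        return pickle.load(f)
```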
flush_stale_checkpoints
```python
def flush_stale_checkpoints(checkpoint_path: Path, keep_last: Optional[int] = None) -> None
```
Removes old epoch checkpoints, keeping only the most recent keep_last.
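The retention policy can be sketched like this, assuming the `checkpoint_{epoch}.th` naming convention described earlier; the actual implementation may differ in how it enumerates and sorts the files.

```python
import re
import tempfile
from pathlib import Path


def flush_stale_checkpoints_sketch(checkpoint_path: Path,
                                   keep_last: int = 10) -> None:
    """Sketch: scan checkpoint_path's directory for checkpoint_{epoch}.th
    files, sort by epoch, and delete all but the keep_last most recent."""
    pattern = re.compile(r'^checkpoint_(\d+)\.th$')
    found = []
    for f in checkpoint_path.parent.iterdir():
        match = pattern.match(f.name)
        if match:
            found.append((int(match.group(1)), f))
    found.sort()  # oldest epochs first
    for _, stale in found[:max(len(found) - keep_last, 0)]:
        stale.unlink()
```

Note that `checkpoint.th` (the latest checkpoint) does not match the epoch pattern and is therefore never deleted.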
Inputs and Outputs
Inputs (save):
- Training state dictionary containing all registered stateful objects:
  - `model` -- model weights
  - `optimizer` -- optimizer state
  - `lr_scheduler` -- scheduler state
  - `ema` -- EMA module state
  - `best_state` -- best model weights (by validation metric)
  - `fsdp_best_state` -- full best state for FSDP (rank 0 only)
- Epoch counter and training history
Outputs (save):
- Checkpoint files on disk:
  - `checkpoint.th` -- latest checkpoint
  - `checkpoint_{epoch}.th` -- periodic epoch checkpoints
  - `checkpoint.th.{rank}` -- per-rank sharded checkpoints (FSDP)
Inputs (load):
- Checkpoint file path (auto-resolved or specified via `continue_from`)
Outputs (load):
- Restored state dictionary loaded into all registered stateful objects
Save Flow
```python
def save_checkpoints(self):
    is_sharded = self.cfg.fsdp.use
    if not flashy.distrib.is_rank_zero() and not is_sharded:
        return
    state = self.state_dict()
    epoch = self.epoch - 1
    # Periodic epoch checkpoint
    if self.cfg.checkpoint.save_every:
        if epoch % self.cfg.checkpoint.save_every == 0:
            minimal_state = state
            if self.cfg.checkpoint.keep_every_states is not None:
                minimal_state = {
                    name: source for name, source in state.items()
                    if name in self.cfg.checkpoint.keep_every_states
                }
            checkpoint.save_checkpoint(
                minimal_state, self.epoch_checkpoint_path(epoch), is_sharded)
    # Latest checkpoint
    if self.cfg.checkpoint.save_last:
        checkpoint.save_checkpoint(state, self.checkpoint_path(), is_sharded)
    # Cleanup old checkpoints
    checkpoint.flush_stale_checkpoints(self.checkpoint_path())
```
Load Flow
```python
def load_checkpoints(self, load_best=False, ignore_state_keys=[]):
    # Condensed excerpt: some locals (the checkpoint paths, is_sharded,
    # the //pretrained/ branch) are elided for readability.
    # Determine source
    if rank0_checkpoint_path.exists():
        # Resume current experiment
        load_from_path = current_checkpoint_path
        checkpoint_source = CheckpointSource.CURRENT_XP
    elif self.cfg.continue_from:
        # Continue from other experiment
        load_from_path = checkpoint.resolve_checkpoint_path(self.cfg.continue_from)
        checkpoint_source = CheckpointSource.OTHER
    # Load state
    state = checkpoint.load_checkpoint(load_from_path, is_sharded)
    # For non-current-XP sources, only keep best state
    if checkpoint_source != CheckpointSource.CURRENT_XP:
        load_best = True
        state = {key: state[key] for key in self._continue_best_source_keys
                 if key in state}
    # Apply state
    self.load_state_dict(state)
    # Verify distributed consistency
    avg_epoch = flashy.distrib.average_metrics({'epoch': float(self.epoch)})['epoch']
    if avg_epoch != float(self.epoch):
        raise RuntimeError("Inconsistent checkpoint loading across ranks")
    # Load best state if requested
    if load_best:
        self._load_new_state_dict(self.best_state.state_dict())
        if self.ema is not None:
            self.initialize_ema()  # Re-initialize EMA from best state
```
Safe Write Mechanism
The `_safe_save_checkpoint` function in `checkpoint.py` uses:
- A write to a temporary file via `flashy.utils.write_and_rename`
- Barrier synchronization between FSDP ranks
- A `.tmp.done` token file to coordinate completion across ranks
- An atomic rename from temporary to final path
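The single-process core of this pattern can be sketched as below. This is an illustration of the write-then-rename idea, not `flashy`'s implementation: the real helper adds the cross-rank barriers and `.tmp.done` token described above, and the `.tmp` naming here is an assumption.

```python
import os
import tempfile
from contextlib import contextmanager
from pathlib import Path


@contextmanager
def write_and_rename_sketch(path: Path):
    """Sketch of atomic checkpoint writing: write the full payload to a
    sibling temp file, then rename it over the final path in one step."""
    tmp = path.with_name(path.name + '.tmp')
    with open(tmp, 'wb') as f:
        yield f
        f.flush()
        os.fsync(f.fileno())  # ensure bytes hit disk before the rename
    # os.replace is atomic when tmp and path share a filesystem, so a reader
    # only ever observes the old complete file or the new complete file.
    os.replace(tmp, path)
```

Because the rename is atomic, a crash mid-write leaves at most a stale `.tmp` file behind and never a truncated `checkpoint.th`.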
Configuration Parameters
| Parameter | Default | Description |
|---|---|---|
| `checkpoint.save_last` | `true` | Save the latest checkpoint every epoch |
| `checkpoint.save_every` | `50` | Save a periodic epoch checkpoint every N epochs |
| `checkpoint.keep_last` | `10` | Number of epoch checkpoints to retain |
| `checkpoint.keep_every_states` | `null` | If set, only save these state keys in periodic checkpoints |
Dependencies
- `torch` -- `torch.save`/`torch.load` for serialization
- `flashy` -- distributed utilities (rank detection, barriers, `write_and_rename`)
- `dora` -- experiment folder management
- `audiocraft.utils.checkpoint` -- low-level checkpoint utilities
- `audiocraft.environment.AudioCraftEnvironment` -- reference path resolution