Implementation:Facebookresearch Audiocraft StandardSolver Checkpoints

Overview

The checkpoint management system in AudioCraft is implemented across StandardSolver (the base solver class) and the audiocraft.utils.checkpoint utility module. Together, they provide saving, loading, and restoration of full training state with support for FSDP sharded checkpoints, best state tracking, and safe atomic writes.

Source Locations

Component                            Source File                     Lines
StandardSolver.save_checkpoints()    audiocraft/solvers/base.py      282-309
StandardSolver.load_checkpoints()    audiocraft/solvers/base.py      314-430
StandardSolver.restore()             audiocraft/solvers/base.py      432-454
checkpoint_name()                    audiocraft/utils/checkpoint.py  28-48
resolve_checkpoint_path()            audiocraft/utils/checkpoint.py  56-84
load_checkpoint()                    audiocraft/utils/checkpoint.py  87-95
save_checkpoint()                    audiocraft/utils/checkpoint.py  98-101
flush_stale_checkpoints()            audiocraft/utils/checkpoint.py  104-122

APIs

StandardSolver.save_checkpoints

def save_checkpoints(self) -> None

Saves the current training state as a checkpoint. Behavior:

  • Only rank 0 saves non-sharded checkpoints; all ranks save for FSDP.
  • Saves a periodic epoch checkpoint (every cfg.checkpoint.save_every epochs) with optionally reduced state.
  • Saves the latest checkpoint (if cfg.checkpoint.save_last is True).
  • Flushes stale epoch checkpoints to manage disk usage.
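
A minimal call-site sketch follows; the epoch loop and the run_epoch name are assumptions used for illustration, only save_checkpoints() itself is documented here.

def train_loop(solver, num_epochs: int) -> None:
    """Illustrative epoch loop; only solver.save_checkpoints() is the documented API."""
    for _ in range(num_epochs):
        solver.run_epoch()         # hypothetical per-epoch training step
        solver.save_checkpoints()  # rank 0 only for plain checkpoints, every rank under FSDP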

StandardSolver.load_checkpoints

def load_checkpoints(
    self,
    load_best: bool = False,
    ignore_state_keys: List[str] = []
) -> Optional[dict]

Loads the most recent checkpoint or the one specified by cfg.continue_from. Handles three source types:

Source       Trigger                                          Behavior
Current XP   Checkpoint exists in current experiment folder   Full state restore (model, optimizer, scheduler, EMA, history)
Other XP     continue_from is set (not //pretrained/)         Load best state only, re-initialize optimizer
Pretrained   continue_from starts with //pretrained/          Load pretrained weights as best state

After loading, the method verifies that the epoch counter is consistent across all distributed ranks.
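
The selection between the three sources can be written as a small standalone sketch. The CURRENT_XP and OTHER member names also appear in the load flow shown later on this page; the PRETRAINED member name and the helper itself are illustrative, not the library's code.

from enum import Enum
from pathlib import Path
from typing import Optional


class CheckpointSource(Enum):
    CURRENT_XP = 'current_xp'
    OTHER = 'other'
    PRETRAINED = 'pretrained'


def classify_source(xp_checkpoint: Path, continue_from: Optional[str]) -> Optional[CheckpointSource]:
    """Mirror the source selection from the table above (illustrative helper)."""
    if xp_checkpoint.exists():
        return CheckpointSource.CURRENT_XP      # resume current experiment with full state
    if not continue_from:
        return None                             # nothing to load, start from scratch
    if continue_from.startswith('//pretrained/'):
        return CheckpointSource.PRETRAINED      # pretrained weights become the best state
    return CheckpointSource.OTHER               # other XP: best state only, optimizer re-initialized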

StandardSolver.restore

def restore(
    self,
    load_best: bool = False,
    replay_metrics: bool = False,
    ignore_state_keys: List[str] = []
) -> bool

High-level restoration method that:

  1. Calls load_checkpoints() to load state
  2. Optionally replays past metrics to result loggers
  3. Returns whether a checkpoint was successfully loaded
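
A usage sketch follows; the wrapper function and its messages are illustrative, while restore(), its replay_metrics argument, and the epoch attribute come from this page.

def resume_if_possible(solver) -> bool:
    """Restore training state before entering the training loop.

    `solver` is a StandardSolver instance; replay_metrics=True re-emits the
    metrics of already-completed epochs to the result loggers.
    """
    restored = solver.restore(replay_metrics=True)
    if restored:
        print(f'resumed training at epoch {solver.epoch}')
    else:
        print('no checkpoint found, starting from scratch')
    return restored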

Checkpoint Utility Functions

checkpoint_name

def checkpoint_name(
    name: Optional[str] = None,
    rank: Optional[int] = None,
    use_fsdp: bool = False
) -> str

Generates checkpoint filenames following the convention checkpoint[_{name}].th[.{rank}]: the name part is omitted for the latest checkpoint, and the rank suffix is only appended for FSDP sharded checkpoints.

Examples:

  • Latest: checkpoint.th
  • Epoch 50: checkpoint_50.th
  • FSDP rank 1: checkpoint.th.1
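
The examples translate into the following calls; the expected strings are inferred from the convention above, and rank=0 is passed explicitly so the result does not depend on the current distributed rank.

from audiocraft.utils.checkpoint import checkpoint_name

assert checkpoint_name(rank=0) == 'checkpoint.th'                    # latest
assert checkpoint_name(name='50', rank=0) == 'checkpoint_50.th'      # epoch 50
assert checkpoint_name(rank=1, use_fsdp=True) == 'checkpoint.th.1'   # FSDP shard for rank 1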

resolve_checkpoint_path

def resolve_checkpoint_path(
    sig_or_path: Union[Path, str],
    name: Optional[str] = None,
    use_fsdp: bool = False
) -> Optional[Path]

Resolves a checkpoint path from either:

  • A Dora signature (prefixed with //sig/) -- resolves to {dora_dir}/xps/{sig}/checkpoint.th
  • A filesystem path -- optionally resolved via AudioCraftEnvironment.resolve_reference_path()

Returns None if the resolved path does not exist.
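
A usage sketch; the signature and path values below are hypothetical examples of the two accepted forms.

from audiocraft.utils.checkpoint import resolve_checkpoint_path

# Dora signature form: resolved under {dora_dir}/xps/{sig}/checkpoint.th
path = resolve_checkpoint_path('//sig/1a2b3c4d')

# Filesystem path form, optionally resolved through AudioCraftEnvironment
# reference path resolution.
path = resolve_checkpoint_path('/checkpoints/my_run/checkpoint.th')

if path is None:
    print('no checkpoint found at the requested location')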

load_checkpoint / save_checkpoint

def load_checkpoint(checkpoint_path: Path, is_sharded: bool = False) -> Any
def save_checkpoint(state: Any, checkpoint_path: Path, is_sharded: bool = False) -> None

Low-level load/save using torch.load/torch.save with sharded checkpoint safety checks.
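
A round-trip sketch for the non-sharded case; the path and state contents are arbitrary, and running it in a single process assumes the underlying distributed utilities fall back gracefully outside a distributed job.

from pathlib import Path

import torch

from audiocraft.utils.checkpoint import load_checkpoint, save_checkpoint

state = {'epoch': 10, 'weights': torch.zeros(4)}    # arbitrary example state
path = Path('/tmp/example_checkpoint.th')

save_checkpoint(state, path, is_sharded=False)      # torch.save behind the safe write mechanism
restored = load_checkpoint(path, is_sharded=False)  # torch.load, sharding checks skipped
assert restored['epoch'] == 10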

flush_stale_checkpoints

def flush_stale_checkpoints(checkpoint_path: Path, keep_last: Optional[int] = None) -> None

Removes old epoch checkpoints, keeping only the most recent keep_last.
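
For example (the path is hypothetical and keep_last mirrors the checkpoint.keep_last default listed below):

from pathlib import Path

from audiocraft.utils.checkpoint import flush_stale_checkpoints

# Remove checkpoint_{epoch}.th files beyond the 10 most recent ones; the
# latest checkpoint.th itself is not an epoch checkpoint and is left in place.
flush_stale_checkpoints(Path('/checkpoints/my_run/checkpoint.th'), keep_last=10)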

Inputs and Outputs

Inputs (save):

  • Training state dictionary containing all registered stateful objects:
    • model -- model weights
    • optimizer -- optimizer state
    • lr_scheduler -- scheduler state
    • ema -- EMA module state
    • best_state -- best model weights (by validation metric)
    • fsdp_best_state -- full best state for FSDP (rank 0 only)
    • Epoch counter and training history

Outputs (save):

  • Checkpoint files on disk:
    • checkpoint.th -- latest checkpoint
    • checkpoint_{epoch}.th -- periodic epoch checkpoints
    • checkpoint.th.{rank} -- per-rank sharded checkpoints (FSDP)

Inputs (load):

  • Checkpoint file path (auto-resolved or specified via continue_from)

Outputs (load):

  • Restored state dictionary loaded into all registered stateful objects
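
A saved checkpoint can be inspected directly with torch.load; the keys listed in the comment follow the state layout above, and any further keys depend on which stateful objects the solver registers.

import torch

state = torch.load('checkpoint.th', map_location='cpu')
print(sorted(state.keys()))
# Expect entries such as 'model', 'optimizer', 'lr_scheduler', 'ema',
# 'best_state' and 'fsdp_best_state', plus the epoch counter and history.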

Save Flow

def save_checkpoints(self):
    is_sharded = self.cfg.fsdp.use
    if not flashy.distrib.is_rank_zero() and not is_sharded:
        return
    state = self.state_dict()
    epoch = self.epoch - 1

    # Periodic epoch checkpoint
    if self.cfg.checkpoint.save_every:
        if epoch % self.cfg.checkpoint.save_every == 0:
            minimal_state = state
            if self.cfg.checkpoint.keep_every_states is not None:
                minimal_state = {name: source for name, source in state.items()
                                 if name in self.cfg.checkpoint.keep_every_states}
            checkpoint.save_checkpoint(minimal_state,
                self.epoch_checkpoint_path(epoch), is_sharded)

    # Latest checkpoint
    if self.cfg.checkpoint.save_last:
        checkpoint.save_checkpoint(state,
            self.checkpoint_path(), is_sharded)

    # Cleanup old checkpoints
    checkpoint.flush_stale_checkpoints(self.checkpoint_path())

Load Flow

def load_checkpoints(self, load_best=False, ignore_state_keys=[]):
    # Determine source. rank0_checkpoint_path / current_checkpoint_path both
    # point at the current XP's checkpoint.th (the rank-0 file vs. this rank's
    # sharded file when FSDP is enabled).
    if rank0_checkpoint_path.exists():
        # Resume current experiment
        load_from_path = current_checkpoint_path
        checkpoint_source = CheckpointSource.CURRENT_XP
    elif self.cfg.continue_from:
        # Continue from other experiment
        load_from_path = checkpoint.resolve_checkpoint_path(self.cfg.continue_from)
        checkpoint_source = CheckpointSource.OTHER

    # Load state
    state = checkpoint.load_checkpoint(load_from_path, is_sharded)

    # For non-current-XP sources, only keep best state
    if checkpoint_source != CheckpointSource.CURRENT_XP:
        load_best = True
        state = {key: state[key] for key in self._continue_best_source_keys
                 if key in state}

    # Apply state
    self.load_state_dict(state)

    # Verify distributed consistency
    avg_epoch = flashy.distrib.average_metrics({'epoch': float(self.epoch)})['epoch']
    if avg_epoch != float(self.epoch):
        raise RuntimeError("Inconsistent checkpoint loading across ranks")

    # Load best state if requested
    if load_best:
        self._load_new_state_dict(self.best_state.state_dict())
        if self.ema is not None:
            self.initialize_ema()  # Re-init EMA from best state

Safe Write Mechanism

The _safe_save_checkpoint function in checkpoint.py uses:

  1. Write to a temporary file via flashy.utils.write_and_rename
  2. Barrier synchronization between FSDP ranks
  3. A .tmp.done token file to coordinate completion across ranks
  4. Atomic rename from temporary to final path
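
A single-process sketch of the write-then-rename idea; the real implementation additionally uses flashy.utils.write_and_rename, rank barriers, and the .tmp.done token described above.

import os
from pathlib import Path

import torch


def atomic_save_sketch(state, checkpoint_path: Path) -> None:
    """Write the full file under a temporary name, then rename it atomically.

    Readers never observe a partially written checkpoint: they either see the
    previous file or the fully written new one.
    """
    tmp_path = checkpoint_path.with_name(checkpoint_path.name + '.tmp')
    torch.save(state, tmp_path)
    os.replace(tmp_path, checkpoint_path)   # atomic rename on POSIX filesystems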

Configuration Parameters

Parameter                     Default  Description
checkpoint.save_last          true     Save the latest checkpoint (checkpoint.th) every epoch
checkpoint.save_every         50       Save a periodic epoch checkpoint every N epochs
checkpoint.keep_last          10       Number of epoch checkpoints to retain
checkpoint.keep_every_states  null     If set, only save these state keys in periodic checkpoints
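
The defaults above can be reproduced as an OmegaConf node for experimentation; grouping them under a checkpoint key mirrors how cfg.checkpoint is accessed in the code on this page.

from omegaconf import OmegaConf

cfg = OmegaConf.create({
    'checkpoint': {
        'save_last': True,          # write checkpoint.th every epoch
        'save_every': 50,           # write checkpoint_{epoch}.th every 50 epochs
        'keep_last': 10,            # retain only the 10 most recent epoch checkpoints
        'keep_every_states': None,  # None keeps the full state in periodic checkpoints
    }
})
assert cfg.checkpoint.save_every == 50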

Dependencies

  • torch -- torch.save, torch.load for serialization
  • flashy -- distributed utilities (rank detection, barriers, write_and_rename)
  • dora -- experiment folder management
  • audiocraft.utils.checkpoint -- low-level checkpoint utilities
  • audiocraft.environment.AudioCraftEnvironment -- reference path resolution
