Principle:Facebookresearch Audiocraft Finetuning Checkpoint Preparation
Overview
Finetuning Checkpoint Preparation is the process of manually editing model checkpoints so that they can seed a fine-tuning run on new data or tasks. Unlike the automated export pipeline, it involves direct manipulation of PyTorch state dictionaries and Hydra configurations to adapt a pretrained model's checkpoint for use as the starting point of a new training run. This is a pattern-level principle: there is no dedicated API; instead, users work directly with torch.load(), torch.save(), and dictionary operations.
Theoretical Background
Fine-tuning a pretrained model often requires modifications to the checkpoint that go beyond simple loading:
- Configuration adjustment: The original training configuration may need changes for the new task (e.g., different learning rate, dataset paths, or model architecture modifications such as adding new conditioning heads).
- State dict surgery: When the model architecture changes between pretraining and fine-tuning (e.g., adding a new layer), the state dictionary must be manually adjusted to include new randomly-initialized parameters while preserving pretrained weights.
- Key remapping: Different training frameworks or model versions may use different naming conventions for the same parameters, requiring key renaming in the state dictionary.
- Selective weight loading: Only certain parts of the model (e.g., the transformer backbone but not the conditioning modules) may need to be loaded from the pretrained checkpoint.
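Key remapping and selective weight loading can both be expressed as small functions over the state dictionary's keys. The sketch below is illustrative only: the helper names `remap_keys` and `select_by_prefix`, and the key prefixes `transformer.`, `backbone.`, and `cond_provider.`, are hypothetical and not part of AudioCraft. Strings stand in for the tensors so that the example runs without a checkpoint.

```python
from collections import OrderedDict


def remap_keys(state, old_prefix, new_prefix):
    """Return a copy of `state` in which a leading `old_prefix` is replaced by `new_prefix`."""
    out = OrderedDict()
    for key, value in state.items():
        if key.startswith(old_prefix):
            key = new_prefix + key[len(old_prefix):]
        out[key] = value
    return out


def select_by_prefix(state, prefixes):
    """Keep only the entries whose key starts with one of `prefixes`."""
    prefixes = tuple(prefixes)
    return OrderedDict((k, v) for k, v in state.items() if k.startswith(prefixes))


# Toy stand-in for a state dict; real values would be torch.Tensor objects.
# All key names here are hypothetical.
pretrained = OrderedDict([
    ("transformer.layers.0.weight", "w0"),
    ("transformer.layers.0.bias", "b0"),
    ("cond_provider.embed.weight", "e0"),
])

renamed = remap_keys(pretrained, "transformer.", "backbone.")
backbone_only = select_by_prefix(renamed, ["backbone."])
print(list(backbone_only))
```

With a real model, a filtered dictionary like `backbone_only` would typically be passed to `model.load_state_dict(backbone_only, strict=False)`, so that parameters outside the selection keep their fresh initialization.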
Checkpoint Dictionary Structure
AudioCraft training checkpoints contain a structured dictionary with the following important keys:
| Key | Type | Description |
|---|---|---|
| best_state | dict | The best model state dictionary, nested as {'model': state_dict} |
| fsdp_best_state | dict | FSDP-specific best state, nested as {'model': state_dict}; present when training with FSDP |
| xp.cfg | DictConfig | The full Hydra/OmegaConf experiment configuration |
| Optimizer states | dict | Optimizer parameter groups, momentum buffers, etc. |
| EMA states | dict | Exponential moving average of model parameters |
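Because the model weights may live under either best_state or fsdp_best_state, it can help to resolve the right entry once before editing. The following is a minimal sketch, assuming that a non-empty fsdp_best_state should take precedence over best_state; the helper names `get_model_state` and `summarize` are hypothetical, and a plain dictionary stands in for a loaded checkpoint.

```python
def get_model_state(pkg):
    """Return the nested model state dict from a loaded training checkpoint.

    Assumption: a non-empty 'fsdp_best_state' is preferred over 'best_state'.
    """
    fsdp_state = pkg.get("fsdp_best_state")
    if fsdp_state:
        return fsdp_state["model"]
    best_state = pkg.get("best_state")
    if not best_state:
        raise KeyError("checkpoint has neither 'fsdp_best_state' nor 'best_state'")
    return best_state["model"]


def summarize(pkg):
    """Map each top-level checkpoint key to the type name of its value."""
    return {key: type(value).__name__ for key, value in pkg.items()}


# Toy stand-in for the result of torch.load(...); values are placeholders.
pkg = {
    "best_state": {"model": {"w": 1}},
    "fsdp_best_state": {},
    "xp.cfg": {"optim": {"lr": 1e-4}},
}
print(summarize(pkg))
print(get_model_state(pkg))
```

Inspecting the top-level keys this way before any surgery is a cheap check that the file is a full training checkpoint rather than an exported inference package.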
Common Manipulation Patterns
Pattern 1: Config Override
Modify the experiment configuration before saving:
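import torch
# weights_only=False: the checkpoint stores non-tensor objects such as the DictConfig,
# which recent PyTorch versions refuse to unpickle by default. Only load trusted files.
pkg = torch.load('checkpoint_best.th', map_location='cpu', weights_only=False)
cfg = pkg['xp.cfg']
cfg.optim.lr = 1e-5  # lower learning rate for fine-tuning
# Illustrative key: use whichever dataset path key your configuration actually defines
cfg.dataset.train.path = '/new/dataset/path'
pkg['xp.cfg'] = cfg
torch.save(pkg, 'checkpoint_finetuning.th')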
Pattern 2: State Dict Surgery
Add, remove, or rename keys in the model state dictionary:
import torch
# weights_only=False: the checkpoint stores non-tensor objects (e.g. the config). Only load trusted files.
pkg = torch.load('checkpoint_best.th', map_location='cpu', weights_only=False)
state = pkg['best_state']['model']
# Rename keys first, so the weights being kept survive the cleanup below
state['new_module.weight'] = state.pop('old_module.weight')
# Remove the remaining keys that no longer exist in the new architecture
keys_to_remove = [k for k in state if 'old_module' in k]
for k in keys_to_remove:
    del state[k]
pkg['best_state']['model'] = state
torch.save(pkg, 'checkpoint_modified.th')
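When the new architecture adds parameters, one common approach is to start from the new model's freshly initialized state dictionary (for example, the result of `model.state_dict()`) and copy in only the pretrained entries whose key and shape match. The sketch below illustrates that idea under stated assumptions: `merge_pretrained` is a hypothetical helper, the key names are invented, and a small `Param` tuple with a `shape` field stands in for `torch.Tensor` so that the example runs without PyTorch.

```python
from collections import namedtuple

# Stand-in for torch.Tensor: only the `shape` attribute matters for this sketch.
Param = namedtuple("Param", ["shape", "source"])


def merge_pretrained(fresh_state, pretrained_state):
    """Copy pretrained entries into a fresh state dict where key and shape match.

    Returns (merged, loaded_keys, skipped_keys). Entries that exist only in
    `fresh_state` keep their fresh initialization.
    """
    merged = dict(fresh_state)
    loaded, skipped = [], []
    for key, value in pretrained_state.items():
        target = fresh_state.get(key)
        if target is not None and getattr(target, "shape", None) == getattr(value, "shape", None):
            merged[key] = value
            loaded.append(key)
        else:
            skipped.append(key)  # key missing in the new model, or shape changed
    return merged, loaded, skipped


# Hypothetical key names for illustration only.
fresh = {
    "backbone.weight": Param((4, 4), "init"),
    "embed.weight": Param((10, 4), "init"),     # vocabulary grew from 8 to 10
    "new_head.weight": Param((2, 4), "init"),   # layer added for fine-tuning
}
pretrained = {
    "backbone.weight": Param((4, 4), "pretrained"),
    "embed.weight": Param((8, 4), "pretrained"),
    "old_head.weight": Param((2, 4), "pretrained"),
}

merged, loaded, skipped = merge_pretrained(fresh, pretrained)
print("loaded:", loaded)
print("skipped:", skipped)
```

Reviewing the skipped list before saving helps catch unintended mismatches, such as a typo in a renamed key, which would otherwise silently leave a layer at its random initialization.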
Design Rationale
- No dedicated API by design: Checkpoint manipulation for fine-tuning is inherently task-specific. A general-purpose API would either be too restrictive or too complex. Direct dictionary manipulation gives maximum flexibility.
- Standard PyTorch patterns: Using torch.load()/torch.save() and dictionary operations is idiomatic PyTorch, requiring no AudioCraft-specific knowledge beyond the checkpoint key structure.
- Separation from export: This is intentionally separate from the export pipeline. Export produces minimal inference-only packages; fine-tuning preparation produces modified training checkpoints that retain the full training state structure.