Implementation:Facebookresearch Audiocraft Checkpoint Manipulation
Overview
Checkpoint_Manipulation documents the user-space patterns for preparing AudioCraft model checkpoints for fine-tuning. AudioCraft provides no dedicated API for this step; the patterns rely on direct torch.load()/torch.save() calls and on editing the loaded dictionary according to the checkpoint structure conventions.
Implements
Principle:Facebookresearch_Audiocraft_Finetuning_Checkpoint_Preparation
Pattern Documentation
This is a pattern doc describing user-space checkpoint manipulation techniques. The primary operations are:
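import torch
# Load onto CPU. On PyTorch 2.6 or later, checkpoints that contain non-tensor
# objects (such as a DictConfig) need weights_only=False; use it only for trusted files.
pkg = torch.load(path, 'cpu')
# Manipulate
# ... modify pkg dict ...
# Save
torch.save(pkg, new_path)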
Checkpoint Key Reference
| Key | Type | Description | Used In |
|---|---|---|---|
| best_state | dict | Best model state, structured as {'model': OrderedDict(...)} | Standard (non-FSDP) training |
| fsdp_best_state | dict | Gathered FSDP best state, structured as {'model': OrderedDict(...)} | FSDP distributed training |
| xp.cfg | DictConfig | Full Hydra experiment configuration (model architecture, dataset, optimizer, etc.) | All training modes |
| exported | bool | Whether the checkpoint is an export (inference-only) vs. full training checkpoint | Export detection |
| version | str | AudioCraft version string | Version tracking |
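The key layout in the table can be checked programmatically before a checkpoint is edited. The following is a minimal sketch, not part of AudioCraft: the helper names `select_model_state` and `summarize_checkpoint` are hypothetical, and the logic assumes only the keys listed above plus the flat `best_state` layout of exported checkpoints described later in this document. It works on the plain dict returned by torch.load(), so a stand-in dict with dummy values is used here.

```python
from collections import OrderedDict
from typing import Any, Dict, Mapping


def select_model_state(pkg: Mapping[str, Any]) -> Mapping[str, Any]:
    """Return the model state dict from a checkpoint package.

    Hypothetical helper (not an AudioCraft API). Exported checkpoints are
    assumed to keep a flat state dict under 'best_state'; training
    checkpoints nest it under 'model', and the gathered FSDP state is
    preferred when it is present and non-empty.
    """
    if pkg.get('exported'):
        return pkg['best_state']
    if pkg.get('fsdp_best_state'):
        return pkg['fsdp_best_state']['model']
    return pkg['best_state']['model']


def summarize_checkpoint(pkg: Mapping[str, Any]) -> Dict[str, Any]:
    """Collect a few facts about a checkpoint package for quick inspection."""
    exported = bool(pkg.get('exported'))
    state = select_model_state(pkg)
    return {
        'kind': 'exported' if exported else 'training',
        'uses_fsdp_state': (not exported) and bool(pkg.get('fsdp_best_state')),
        'num_entries': len(state),
        'cfg_type': type(pkg.get('xp.cfg')).__name__,
    }


# Stand-in for the dict returned by torch.load(path, 'cpu'); values are dummies.
demo_pkg = {
    'best_state': {'model': OrderedDict(a=1, b=2)},
    'fsdp_best_state': {},
    'xp.cfg': {'optim': {'lr': 1e-4}},
}
summary = summarize_checkpoint(demo_pkg)
print(summary)
```

On a real checkpoint, `cfg_type` would be expected to report `DictConfig` for a training checkpoint and `str` for an export, which gives a quick way to tell the two apart when the `exported` key is missing.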
Complete Fine-tuning Preparation Example
import torch
from omegaconf import OmegaConf
# 1. Load the pretrained checkpoint
pkg = torch.load('/path/to/pretrained/checkpoint_best.th', 'cpu')
# 2. Access and modify the configuration
cfg = pkg['xp.cfg']
# Adjust training hyperparameters
cfg.optim.lr = 1e-5
cfg.optim.epochs = 50
# Point to the new dataset (verify these key names against your own xp.cfg)
cfg.dataset.train.path = '/path/to/finetune/dataset'
cfg.dataset.valid.path = '/path/to/finetune/valid'
# 3. Optionally modify the state dict
if pkg.get('fsdp_best_state'):
    state = pkg['fsdp_best_state']['model']
else:
    state = pkg['best_state']['model']
# Example: inspect available keys
print("Model keys:", list(state.keys())[:10])
# Example: remove a conditioning module for re-initialization
keys_to_remove = [k for k in state if 'condition_provider.conditioners.description' in k]
for k in keys_to_remove:
    del state[k]
# 4. Update the checkpoint dict
if pkg.get('fsdp_best_state'):
    pkg['fsdp_best_state']['model'] = state
else:
    pkg['best_state']['model'] = state
pkg['xp.cfg'] = cfg
# 5. Save the modified checkpoint
torch.save(pkg, '/path/to/finetune/checkpoint_init.th')
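The key-removal step in the example can be factored into a small reusable function. This is a sketch under stated assumptions: `drop_matching_keys` is a hypothetical helper, not an AudioCraft function, and the parameter names in the demo state dict are made up for illustration. Whether the training code accepts a state dict with missing entries depends on how the checkpoint is loaded, which this sketch does not cover.

```python
from collections import OrderedDict
from typing import Any, List, MutableMapping


def drop_matching_keys(state: MutableMapping[str, Any], fragment: str) -> List[str]:
    """Delete every entry whose key contains `fragment`; return the removed keys.

    The state dict is modified in place, mirroring the example above.
    """
    removed = [k for k in state if fragment in k]
    for k in removed:
        del state[k]
    return removed


# Illustrative key names only; real checkpoints hold tensors as values.
demo_state = OrderedDict([
    ('condition_provider.conditioners.description.output_proj.weight', 0),
    ('condition_provider.conditioners.description.output_proj.bias', 0),
    ('transformer.layers.0.linear1.weight', 0),
])
removed = drop_matching_keys(demo_state, 'condition_provider.conditioners.description')
print(f"removed {len(removed)} keys, {len(demo_state)} remain")
```

Returning the removed keys makes it easy to log exactly which parameters will be re-initialized before the modified checkpoint is saved.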
Working with Exported Checkpoints
Exported checkpoints (from export_lm()) have a flattened structure compared to training checkpoints. If you start from an export rather than a training checkpoint, wrap the flat state dict back into the training layout:
import torch
from omegaconf import OmegaConf
# Exported checkpoint structure (simplified)
exported_pkg = torch.load('state_dict.bin', 'cpu')
# Keys: best_state (flat state dict), xp.cfg (YAML string), version, exported
# To use as a training checkpoint, wrap it back
training_pkg = {
    'best_state': {'model': exported_pkg['best_state']},
    'fsdp_best_state': {},
    'xp.cfg': OmegaConf.create(exported_pkg['xp.cfg']),
}
torch.save(training_pkg, 'checkpoint_for_finetuning.th')
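The wrapping step can also be written as a function that refuses input that is not an export. This is a minimal sketch: `to_training_layout` and its `parse_cfg` parameter are hypothetical names, not AudioCraft APIs, and the demo package uses dummy values with a placeholder version string. The config conversion is passed in as a callable so the sketch itself needs only the standard library.

```python
from typing import Any, Callable, Dict, Mapping


def to_training_layout(
    exported_pkg: Mapping[str, Any],
    parse_cfg: Callable[[Any], Any] = lambda cfg: cfg,
) -> Dict[str, Any]:
    """Wrap an exported package into the training checkpoint layout.

    `parse_cfg` converts the stored config value. The default returns it
    unchanged; pass OmegaConf.create to turn the YAML string into a
    DictConfig, as in the manual example above.
    """
    if not exported_pkg.get('exported'):
        raise ValueError("expected an exported checkpoint ('exported' is not set)")
    return {
        'best_state': {'model': exported_pkg['best_state']},
        'fsdp_best_state': {},
        'xp.cfg': parse_cfg(exported_pkg['xp.cfg']),
    }


# Stand-in for torch.load('state_dict.bin', 'cpu'); all values are dummies.
demo_export = {
    'best_state': {'w': 0},
    'xp.cfg': 'optim:\n  lr: 0.0001\n',
    'version': '0.0.0',  # placeholder, not a real release
    'exported': True,
}
training_pkg = to_training_layout(demo_export)
print(sorted(training_pkg))
```

Calling `to_training_layout(exported_pkg, parse_cfg=OmegaConf.create)` would reproduce the manual wrapping shown above, with the added check that the input carries the `exported` flag.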
Config Serialization Note
In training checkpoints, xp.cfg is a live DictConfig object. In exported checkpoints, it is a YAML string (produced by OmegaConf.to_yaml()). The loaders in audiocraft/models/loaders.py handle both formats by calling OmegaConf.create() on the value: a DictConfig comes back as an equivalent DictConfig, and a YAML string is parsed into one.