Heuristic: facebookresearch/audiocraft FSDP Distributed Training Tips
| Knowledge Sources | |
|---|---|
| Domains | Distributed_Training, Optimization, Infrastructure |
| Last Updated | 2026-02-13 23:00 GMT |
Overview
FSDP workarounds for AudioCraft: FULL_SHARD is unsupported (breaks LM generation), state dict switching requires manual handling due to PyTorch bugs, and gradient sync must precede gradient clipping.
Description
AudioCraft's FSDP (Fully Sharded Data Parallel) integration carries several critical workarounds for PyTorch FSDP limitations discovered during development. This is essential tribal knowledge for anyone training AudioCraft models at scale across multiple GPUs.
The three key insights are: (1) the FULL_SHARD strategy is explicitly unsupported because it would flush model weights on every forward pass, making autoregressive generation impractically slow; (2) FSDP state dict type switching must be done manually due to a PyTorch API bug; (3) the FSDP wrapper reference is written directly into `__dict__` (bypassing normal attribute assignment) so it is not serialized into checkpoints.
Usage
Apply these tips when setting up multi-GPU training with FSDP. These constraints affect training configuration, checkpoint saving/loading, and inference with FSDP-wrapped models.
The Insight (Rule of Thumb)
- Action 1: Never use `ShardingStrategy.FULL_SHARD` with AudioCraft LM models. Use `SHARD_GRAD_OP` or `NO_SHARD` instead.
- Action 2: Use manual state dict type switching via the `switch_to_full_state_dict()` context manager instead of the FSDP `state_dict_type` API.
- Action 3: Always call `flashy.distrib.sync_model()` before `clip_grad_norm_()` to ensure gradient synchronization precedes clipping.
- Action 4: Store the FSDP wrapper reference in `module.__dict__['_fsdp']` to keep it out of the state dict.
- Trade-off: `SHARD_GRAD_OP` uses more memory than `FULL_SHARD` (it keeps full parameters resident during the forward pass) but enables generation. Manual state dict handling adds code complexity but avoids PyTorch bugs.
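Action 1 can be enforced up front before any model is wrapped. A minimal sketch, using a stand-in `ShardingStrategy` enum (the real one is `torch.distributed.fsdp.ShardingStrategy`); `validate_sharding_strategy` is a hypothetical helper for illustration, not an AudioCraft API:

```python
from enum import Enum


# Stand-in for torch.distributed.fsdp.ShardingStrategy (assumption: real
# code would import the actual enum from PyTorch instead).
class ShardingStrategy(Enum):
    FULL_SHARD = 1
    SHARD_GRAD_OP = 2
    NO_SHARD = 3


def validate_sharding_strategy(strategy: ShardingStrategy) -> ShardingStrategy:
    """Reject FULL_SHARD up front, mirroring AudioCraft's assert: FULL_SHARD
    would flush weights after every forward pass, which is incompatible with
    token-by-token autoregressive generation."""
    if strategy is ShardingStrategy.FULL_SHARD:
        raise ValueError(
            "FULL_SHARD is unsupported for AudioCraft LM models; "
            "use SHARD_GRAD_OP or NO_SHARD instead."
        )
    return strategy
```

Failing fast at configuration time gives a clear error instead of a mysterious slowdown or hang at generation time.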
Reasoning
FULL_SHARD incompatibility: In autoregressive generation, the model performs sequential forward passes (one per token). FULL_SHARD discards parameters after each forward pass and re-gathers them for the next call. This causes massive communication overhead during generation where thousands of sequential forward passes are needed.
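A back-of-envelope calculation makes the overhead concrete. All numbers below are illustrative assumptions (model size, dtype, generation length), not AudioCraft specifics:

```python
# Parameter re-gather traffic per generated token under FULL_SHARD, which
# discards and re-gathers the full parameters on every forward pass.
params = 1.5e9           # assumed LM parameter count
bytes_per_param = 2      # fp16
tokens = 1000            # assumed generation length (one forward per token)

per_forward_gb = params * bytes_per_param / 1e9   # GB gathered per forward
total_gb = per_forward_gb * tokens                # GB over the whole generation
print(f"{per_forward_gb:.1f} GB per forward, {total_gb:.0f} GB total")
# → 3.0 GB per forward, 3000 GB total
```

Under `SHARD_GRAD_OP` the full parameters are gathered once and stay resident across forward passes, so the gather cost does not scale with the number of generated tokens.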
State dict bug: PyTorch's FSDP.state_dict_type() context manager has a known bug that prevents clean switching between sharded and full state dicts. The manual approach uses FullStateDictConfig(offload_to_cpu=True, rank0_only=True) to gather full weights on rank 0 CPU only.
Gradient sync before clipping: In distributed training, each GPU has partial gradients. Clipping before synchronization would clip based on incomplete gradient information, leading to inconsistent updates across devices.
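A toy numeric example of why the order matters, using plain Python lists as stand-ins for per-rank gradients (`clip` mimics the norm-based rescaling that `clip_grad_norm_` applies):

```python
import math


def l2(v):
    """Euclidean norm of a flat gradient vector."""
    return math.sqrt(sum(x * x for x in v))


def clip(v, max_norm):
    """Rescale v so its L2 norm is at most max_norm (like clip_grad_norm_)."""
    n = l2(v)
    scale = min(1.0, max_norm / n) if n > 0 else 1.0
    return [x * scale for x in v]


# Two workers hold partial gradients for the same parameters (toy values).
g_rank0 = [3.0, 0.0]
g_rank1 = [0.0, 4.0]

# Correct order: synchronize (average across ranks) first, then clip.
synced = [(a + b) / 2 for a, b in zip(g_rank0, g_rank1)]
right = clip(synced, max_norm=1.0)   # ≈ [0.6, 0.8], norm 1.0

# Wrong order: each rank clips its partial gradient, then they average.
# The result has a different norm AND a different direction.
wrong = [(a + b) / 2 for a, b in zip(clip(g_rank0, 1.0), clip(g_rank1, 1.0))]
# wrong == [0.5, 0.5], norm ≈ 0.71
```

Clipping partial gradients distorts both the magnitude and the direction of the update, which is why `sync_model()` must run before `clip_grad_norm_()`.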
Code Evidence
FULL_SHARD assertion from audiocraft/optim/fsdp.py:84-89:
assert sharding_strategy_config != ShardingStrategy.FULL_SHARD, \
    "Not supported at the moment, requires a bit more work."
# Would require being smart about LM generation, would flush weights every step
Manual state dict switching from audiocraft/optim/fsdp.py:36-48:
@contextmanager
def switch_to_full_state_dict(models: tp.List[FSDP]):
    # Another bug in FSDP makes it that we cannot use the `state_dict_type` API,
    # so let's do thing manually.
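The general shape behind the manual approach is "save, override, restore" around a `yield`. A minimal sketch of that pattern using a generic attribute swap; the actual AudioCraft helper swaps each FSDP module's state dict configuration (to `FullStateDictConfig(offload_to_cpu=True, rank0_only=True)`) rather than a plain attribute:

```python
from contextlib import contextmanager


@contextmanager
def switch_attribute(objs, name, value):
    """Temporarily override an attribute on several objects, restoring the
    originals on exit even if the body raises. This is the pattern behind
    switch_to_full_state_dict(); the FSDP-specific version is not shown."""
    saved = [getattr(o, name) for o in objs]
    for o in objs:
        setattr(o, name, value)
    try:
        yield
    finally:
        for o, old in zip(objs, saved):
            setattr(o, name, old)
```

The `try/finally` restore is the important part: a checkpointing error inside the context must not leave modules stuck in full-state-dict mode.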
FSDP wrapper stored in __dict__ from audiocraft/optim/fsdp.py:110-116:
# Let the wrapped model know about the wrapping!
# We use __dict__ to avoid it going into the state dict.
for module in FSDP.fsdp_modules(wrapped):
    original = module._fsdp_wrapped_module
    original.__dict__['_fsdp'] = module
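Why `__dict__` works: `torch.nn.Module.__setattr__` intercepts normal attribute assignment to register submodules and parameters, which then land in the state dict; writing into `__dict__` directly bypasses that hook. A minimal stand-in class demonstrating the bypass (`TrackingModule` is illustrative, not part of AudioCraft or PyTorch):

```python
# Stand-in mimicking how torch.nn.Module intercepts attribute assignment:
# anything set via normal `obj.name = value` is tracked (and would be
# serialized), while writing into __dict__ skips the hook entirely.
class TrackingModule:
    def __init__(self):
        # Bootstrap the tracking dict without triggering __setattr__.
        self.__dict__['_tracked'] = {}

    def __setattr__(self, name, value):
        # Analogous to nn.Module registering submodules/parameters.
        self._tracked[name] = value

    def state_dict(self):
        return dict(self._tracked)


m = TrackingModule()
m.weight = 1.0                   # goes through __setattr__ -> tracked
m.__dict__['_fsdp'] = object()   # bypasses the hook -> not in state dict
```

The wrapper stays reachable as `m._fsdp` via ordinary attribute lookup, but never shows up when the module is checkpointed.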
Gradient sync before clipping from audiocraft/solvers/musicgen.py:409:
flashy.distrib.sync_model(self.model)
if self.cfg.optim.max_norm:
    torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.cfg.optim.max_norm)
CUDA sync debug check from audiocraft/solvers/musicgen.py:365-373:
check_synchronization_points = idx == 1 and self.device == 'cuda'
if check_synchronization_points:
    torch.cuda.set_sync_debug_mode('warn')