Heuristic:Facebookresearch Audiocraft FSDP Distributed Training Tips

From Leeroopedia
Knowledge Sources
Domains Distributed_Training, Optimization, Infrastructure
Last Updated 2026-02-13 23:00 GMT

Overview

FSDP workarounds for AudioCraft: the FULL_SHARD strategy is unsupported (it would re-gather the model weights on every autoregressive generation step), switching to a full state dict is handled manually to work around a PyTorch FSDP bug, and gradients must be synchronized before they are clipped.

Description

AudioCraft's FSDP (Fully Sharded Data Parallel) integration contains several critical workarounds for PyTorch FSDP limitations encountered during development. This is essential tribal knowledge for anyone training AudioCraft models at scale across multiple GPUs.

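The three key insights are: (1) the FULL_SHARD strategy is explicitly unsupported because it would free and re-gather the model weights on every forward pass, which makes token-by-token autoregressive generation impractically slow; (2) switching the FSDP state dict type must be done manually because, according to a comment in the source, a PyTorch FSDP bug prevents use of the state_dict_type API; (3) the reference to the FSDP wrapper is written directly into the wrapped module's __dict__, bypassing nn.Module attribute registration, so that it is not serialized into checkpoints.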

Usage

Apply these tips when setting up multi-GPU training with FSDP. These constraints affect training configuration, checkpoint saving/loading, and inference with FSDP-wrapped models.

The Insight (Rule of Thumb)

  • Action 1: Never use ShardingStrategy.FULL_SHARD with AudioCraft LM models. Use SHARD_GRAD_OP or NO_SHARD instead.
  • Action 2: Use the switch_to_full_state_dict() context manager, which switches the state dict type manually, instead of the FSDP state_dict_type API.
  • Action 3: Always call flashy.distrib.sync_model() before clip_grad_norm_() so that gradients are synchronized before they are clipped.
  • Action 4: Store the reference to the FSDP wrapper in the inner (wrapped) module's __dict__['_fsdp'] so that it does not appear in the state dict.
  • Trade-off: SHARD_GRAD_OP uses more memory than FULL_SHARD because it does not reshard parameters after the forward pass (they stay unsharded until after backward), but it keeps generation practical. Manual state dict handling adds code complexity but avoids the PyTorch bug.

Reasoning

FULL_SHARD incompatibility: In autoregressive generation, the model performs sequential forward passes, one per generation step. FULL_SHARD frees the unsharded parameters after each forward pass and must all-gather them again for the next call, so the communication cost is paid on every step. With thousands of sequential forward passes per generation, this overhead becomes prohibitive.
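A rough, back-of-the-envelope sketch of why re-gathering on every step hurts. The model size, precision, step count and GPU count below are hypothetical, and the estimate assumes every forward pass all-gathers the full parameter set, ignoring per-unit details of how FSDP reshards.

```python
def full_shard_gather_bytes(num_params: int, bytes_per_param: int,
                            decode_steps: int, world_size: int) -> int:
    """Approximate bytes each rank receives from parameter all-gathers during
    generation, assuming every forward pass re-gathers all parameters
    (the FULL_SHARD behaviour described above). Each all-gather delivers the
    (world_size - 1) / world_size share of the weights this rank does not own."""
    full_bytes = num_params * bytes_per_param
    per_step = full_bytes * (world_size - 1) // world_size
    return per_step * decode_steps


# Hypothetical setup: 1.5B parameters in fp16, 1,500 sequential forward passes
# (e.g. 30 s of audio at an assumed 50 steps per second), 8 GPUs.
total = full_shard_gather_bytes(1_500_000_000, 2, 1_500, 8)
print(f"{total / 1e12:.2f} TB per rank")  # 3.94 TB per rank
```

The exact figure matters less than the linear growth in decode_steps: under this assumption, every extra generated step repeats the full gather.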

State dict bug: According to a comment in the AudioCraft source, a bug in PyTorch FSDP prevents using the FSDP.state_dict_type() context manager to switch between sharded and full state dicts. The manual approach uses FullStateDictConfig(offload_to_cpu=True, rank0_only=True), which gathers the full weights into CPU memory on rank 0 only.
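The set-then-restore pattern behind a manual switch can be sketched without PyTorch. This is a toy under stated assumptions: ToyModel and the mode strings are hypothetical stand-ins, and the body of AudioCraft's switch_to_full_state_dict is not shown on this page. In real PyTorch, a per-module setter exists as the static method FullyShardedDataParallel.set_state_dict_type(module, state_dict_type, state_dict_config); whether AudioCraft calls exactly that is an assumption.

```python
from contextlib import contextmanager


class ToyModel:
    """Hypothetical stand-in for an FSDP-wrapped model: only tracks its state dict mode."""

    def __init__(self, mode: str = "LOCAL_STATE_DICT"):
        self.state_dict_mode = mode


@contextmanager
def switch_state_dict_mode(models, new_mode):
    # Remember each model's current mode so it can be restored exactly.
    previous = [m.state_dict_mode for m in models]
    for m in models:
        m.state_dict_mode = new_mode
    try:
        yield
    finally:
        # Restore the old modes even if checkpoint saving raised an error.
        for m, mode in zip(models, previous):
            m.state_dict_mode = mode


models = [ToyModel(), ToyModel("SHARDED_STATE_DICT")]
with switch_state_dict_mode(models, "FULL_STATE_DICT"):
    inside = [m.state_dict_mode for m in models]  # both FULL_STATE_DICT
after = [m.state_dict_mode for m in models]       # original modes restored
```

The try/finally is the important design choice: if saving fails part-way, the models are not left in the wrong state dict mode.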

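Gradient sync before clipping: In distributed training, each GPU holds gradients computed on its own slice of the batch until they are averaged across ranks. Clipping before synchronization would compare max_norm against each rank's local gradient norm instead of the norm of the averaged gradient, so the applied update would not match the intended global clipping.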
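A small framework-free simulation of the ordering problem with two hypothetical ranks. average stands in for the gradient all-reduce (assumed here to be a mean across ranks), and clip follows the scaling rule of torch.nn.utils.clip_grad_norm_: multiply by max_norm / (norm + 1e-6) when that factor is below 1.

```python
import math


def l2_norm(vec):
    return math.sqrt(sum(x * x for x in vec))


def clip(vec, max_norm):
    # Scale the whole vector down if its norm exceeds max_norm (never scale up).
    scale = min(1.0, max_norm / (l2_norm(vec) + 1e-6))
    return [x * scale for x in vec]


def average(vecs):
    # Stand-in for the cross-rank all-reduce that synchronizes gradients.
    return [sum(xs) / len(vecs) for xs in zip(*vecs)]


# Local gradients on two hypothetical ranks; each has norm 5.
rank_grads = [[3.0, 4.0], [-3.0, 4.0]]
max_norm = 2.0

sync_then_clip = clip(average(rank_grads), max_norm)              # ~[0.0, 2.0]
clip_then_sync = average([clip(g, max_norm) for g in rank_grads])  # ~[0.0, 1.6]
print(round(l2_norm(sync_then_clip), 3), round(l2_norm(clip_then_sync), 3))  # 2.0 1.6
```

Clipping first shrinks each rank's gradient by its own norm, so the averaged update ends up smaller than max_norm calls for (norm 1.6 instead of 2.0 here); synchronizing first clips the gradient that is actually applied.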

Code Evidence

FULL_SHARD assertion from audiocraft/optim/fsdp.py:84-89:

assert sharding_strategy_config != ShardingStrategy.FULL_SHARD, \
    "Not supported at the moment, requires a bit more work."
# Would require being smart about LM generation, would flush weights every step

Manual state dict switching from audiocraft/optim/fsdp.py:36-48:

@contextmanager
def switch_to_full_state_dict(models: tp.List[FSDP]):
    # Another bug in FSDP makes it that we cannot use the `state_dict_type` API,
    # so let's do thing manually.
    ...  # rest of the function body not shown in this excerpt

FSDP wrapper stored in __dict__ from audiocraft/optim/fsdp.py:110-116:

# Let the wrapped model know about the wrapping!
# We use __dict__ to avoid it going into the state dict.
for module in FSDP.fsdp_modules(wrapped):
    original = module._fsdp_wrapped_module
    original.__dict__['_fsdp'] = module
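Why writing into __dict__ keeps the back-reference out of the state dict can be shown with a toy that mimics one piece of torch.nn.Module behaviour: assigning a module-valued attribute with normal attribute syntax registers it as a child, and state_dict() walks the registered children. ToyModule is hypothetical and far simpler than nn.Module.

```python
class ToyModule:
    """Hypothetical, minimal imitation of how torch.nn.Module registers children."""

    def __init__(self):
        self._modules = {}  # registered child modules (a dict, so stored normally)
        self._params = {}   # name -> value, stand-in for parameters

    def __setattr__(self, name, value):
        # Like nn.Module: assigning a module attribute registers it as a child.
        if isinstance(value, ToyModule):
            self._modules[name] = value
        else:
            object.__setattr__(self, name, value)

    def state_dict(self, prefix=""):
        out = {prefix + k: v for k, v in self._params.items()}
        for child_name, child in self._modules.items():
            out.update(child.state_dict(prefix + child_name + "."))
        return out


# A "wrapper" that owns the original module, as FSDP does.
original = ToyModule()
original._params["weight"] = 1.0
wrapper = ToyModule()
wrapper._fsdp_wrapped_module = original  # registered as a child

# Back-reference written straight into __dict__: this bypasses __setattr__,
# so it stays a plain attribute and never becomes a registered child.
original.__dict__["_fsdp"] = wrapper
print(wrapper.state_dict())  # {'_fsdp_wrapped_module.weight': 1.0}

# For contrast, normal assignment registers the back-reference as a child.
other = ToyModule()
other._fsdp = ToyModule()
print("_fsdp" in other._modules)  # True
```

In the toy, registering the wrapper as a child of the module it wraps would also create a cycle for the recursive state_dict() walk to follow; the __dict__ write avoids both problems while original._fsdp still resolves normally.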

Gradient sync before clipping from audiocraft/solvers/musicgen.py:409:

flashy.distrib.sync_model(self.model)
if self.cfg.optim.max_norm:
    torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.cfg.optim.max_norm)

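CUDA sync debug check from audiocraft/solvers/musicgen.py:365-373 (for the batch with idx == 1, and only on CUDA, PyTorch is told to warn about operations that force a CPU-GPU synchronization):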

check_synchronization_points = idx == 1 and self.device == 'cuda'
if check_synchronization_points:
    torch.cuda.set_sync_debug_mode('warn')
