

Principle:Facebookresearch Audiocraft MusicGen Training Execution

From Leeroopedia

Overview

MusicGen Training Execution describes the core training loop that trains an autoregressive transformer language model on discrete audio tokens. The model learns to predict the next token in a sequence of codebook codes, conditioned on text descriptions (and optionally melody/style embeddings). The training procedure employs cross-entropy loss computed per-codebook, codebook pattern masking for multi-stream token generation, and classifier-free guidance dropout for improved inference quality.

Theoretical Foundations

Autoregressive Language Modeling on Audio Tokens

MusicGen (Copet et al., 2023, arXiv:2306.05284) reformulates music generation as a language modeling problem over discrete audio tokens. Given a sequence of tokens produced by a neural audio codec (EnCodec), the transformer learns:

P(c_t | c_{<t}, text)

where c_t denotes the codebook token(s) at step t of the serialized sequence (one per codebook when several streams are predicted in parallel) and text is the conditioning information. The model is trained with standard next-token prediction using cross-entropy loss.

Multi-Codebook Token Generation with Codebook Patterns

Unlike text language models that produce a single token per timestep, audio codecs produce K parallel codebook streams (e.g., K=4 for MusicGen). A central element of MusicGen is its codebook (token interleaving) pattern system, which defines how these parallel streams are serialized for autoregressive modeling:

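  • Delay pattern (default) -- Codebook k is delayed by k timesteps: codebook 0 predicts frame t at step t, codebook 1 at step t+1, and so on. A frame is only complete K-1 steps after its first token, but all K codebooks are still predicted in parallel at each step, so the sequence grows by only a few steps.
  • Parallel pattern -- All codebooks of a frame are predicted at the same step (no delay), so the sequence length stays T, but codebooks within a frame are predicted independently of each other.
  • Unrolled pattern -- Codebooks are fully serialized into a single stream of K·T steps, which models every dependency at the cost of a K times longer sequence.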

The pattern determines a mask over the [B, K, T] token tensor indicating which positions are valid prediction targets; the training loss is computed only at those valid positions.
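To make the delay pattern concrete, here is a minimal pure-Python sketch for a single example (batch dimension dropped). The names apply_delay_pattern, revert_delay_pattern and SPECIAL are hypothetical and not audiocraft's API; audiocraft's own pattern providers may lay out the sequence differently (for example, with extra special-token steps). The mask here marks which slots of the interleaved sequence hold real tokens and are therefore loss targets.

```python
from typing import List, Tuple

SPECIAL = -1  # hypothetical id meaning "no real token in this slot"


def apply_delay_pattern(codes: List[List[int]]) -> Tuple[List[List[int]], List[List[bool]]]:
    """Delay codebook k by k steps. Returns (interleaved codes, validity mask), both [K][T + K - 1]."""
    num_codebooks, num_frames = len(codes), len(codes[0])
    length = num_frames + num_codebooks - 1  # last codebook finishes K-1 steps late
    delayed = [[SPECIAL] * length for _ in range(num_codebooks)]
    mask = [[False] * length for _ in range(num_codebooks)]
    for k in range(num_codebooks):
        for t in range(num_frames):
            delayed[k][t + k] = codes[k][t]
            mask[k][t + k] = True  # real token here, so this slot is a loss target
    return delayed, mask


def revert_delay_pattern(delayed: List[List[int]], num_frames: int) -> List[List[int]]:
    """Undo the shift so codebook k of frame t lines up again at index t."""
    return [[row[t + k] for t in range(num_frames)] for k, row in enumerate(delayed)]


codes = [[1, 2, 3], [4, 5, 6]]  # K=2 codebooks, T=3 frames
delayed, mask = apply_delay_pattern(codes)
print(delayed)  # [[1, 2, 3, -1], [-1, 4, 5, 6]]
```

For this toy input (K=2, T=3) the three patterns differ in sequence length: parallel needs 3 steps, this delay layout 4, and the unrolled pattern K·T = 6.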

Cross-Entropy Loss per Codebook

The model outputs logits of shape [B, K, T, card] where card is the codebook cardinality (e.g., 2048). Cross-entropy is computed independently for each codebook:

CE_k = CrossEntropy(logits[:, k, :, :], targets[:, k, :]), evaluated only where mask[:, k, :] is true

The final loss is the average across codebooks: CE = (1/K) * sum(CE_k). Per-codebook metrics (ce_q1, ce_q2, ...) are tracked for monitoring, since with residual quantization the earlier codebooks capture the coarser, typically more perceptually important structure.
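The masked per-codebook loss can be sketched in plain Python as follows. This is an illustration under assumptions, not audiocraft's code: masked_codebook_ce is a hypothetical name, the batch dimension is omitted, and each CE_k is a mean over that codebook's valid positions before the K values are averaged. Perplexity is commonly reported as exp(CE).

```python
import math
from typing import List, Tuple


def log_softmax(xs: List[float]) -> List[float]:
    """Numerically stable log-softmax over one vector of logits."""
    m = max(xs)
    lse = m + math.log(sum(math.exp(x - m) for x in xs))
    return [x - lse for x in xs]


def masked_codebook_ce(logits: List[List[List[float]]], targets: List[List[int]],
                       mask: List[List[bool]]) -> Tuple[float, List[float]]:
    """logits [K][T][card], targets [K][T], mask [K][T] -> (mean CE over codebooks, [CE_k])."""
    per_codebook = []
    for k in range(len(logits)):
        total, count = 0.0, 0
        for t in range(len(logits[k])):
            if mask[k][t]:  # masked-out positions contribute nothing
                total -= log_softmax(logits[k][t])[targets[k][t]]
                count += 1
        per_codebook.append(total / max(count, 1))
    return sum(per_codebook) / len(per_codebook), per_codebook


# K=2, T=2, card=2; the masked position has a confidently wrong logit but is ignored.
logits = [[[0.0, 0.0], [0.0, 0.0]], [[0.0, 0.0], [100.0, 0.0]]]
targets = [[0, 1], [1, 1]]
mask = [[True, True], [True, False]]
ce, per_cb = masked_codebook_ce(logits, targets, mask)
ppl = math.exp(ce)  # each codebook contributes log(2) here, so perplexity is about 2
```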

Classifier-Free Guidance (CFG) Dropout

During training, conditioning information (text descriptions, metadata) is randomly dropped with probability cfg_prob (typically 0.1-0.3). This trains the model to generate both conditionally and unconditionally. At inference time, the conditional and unconditional predictions are combined:

logits_guided = logits_uncond + cfg_coef * (logits_cond - logits_uncond)

This typically improves how closely the generated audio follows the text prompt; larger cfg_coef values strengthen the conditioning.
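Both halves of classifier-free guidance fit in a few lines. In this sketch, cfg_dropout and apply_cfg are hypothetical helpers, None stands in for the "null" condition, and the decision is made independently per example, which is an assumption (audiocraft's actual dropout granularity may differ). apply_cfg is the guidance formula above applied elementwise.

```python
import random
from typing import List, Optional


def cfg_dropout(conditions: List[Optional[str]], p: float, rng: random.Random) -> List[Optional[str]]:
    """Training time: replace each condition by None (the null condition) with probability p."""
    return [None if rng.random() < p else c for c in conditions]


def apply_cfg(cond_logits: List[float], uncond_logits: List[float], cfg_coef: float) -> List[float]:
    """Inference time: logits_uncond + cfg_coef * (logits_cond - logits_uncond), elementwise."""
    return [u + cfg_coef * (c - u) for c, u in zip(cond_logits, uncond_logits)]


rng = random.Random(0)
batch = ["lofi hip hop", "epic orchestral", "80s synthpop"]
print(cfg_dropout(batch, p=0.3, rng=rng))  # some entries may become None
print(apply_cfg([2.0, 0.0], [1.0, 1.0], cfg_coef=3.0))  # [4.0, -2.0]
```

With cfg_coef = 1 the formula reduces to the conditional logits; values above 1 push further away from the unconditional prediction.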

Attribute Dropout

Beyond full CFG dropout, individual conditioning attributes (text, melody, style) can be dropped independently, each with its own configurable probability. This is handled by the model's att_dropout module before the conditioning provider processes the attributes.
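Per-attribute dropout differs from CFG dropout only in that each attribute has its own probability. The sketch below is hypothetical (attribute_dropout, the attribute names and the None convention are illustrative, not audiocraft's API).

```python
import random
from typing import Dict, Optional


def attribute_dropout(attributes: Dict[str, Optional[object]], probs: Dict[str, float],
                      rng: random.Random) -> Dict[str, Optional[object]]:
    """Drop each named attribute independently with its own probability (None = dropped)."""
    out = dict(attributes)  # copy so the caller's sample is not mutated
    for name, p in probs.items():
        if name in out and rng.random() < p:
            out[name] = None
    return out


sample = {"description": "calm piano", "melody": [0.1, 0.3, 0.2]}
print(attribute_dropout(sample, {"description": 0.0, "melody": 1.0}, random.Random(0)))
# {'description': 'calm piano', 'melody': None}
```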

Key Principles

  • Frozen tokenizer, trainable LM -- The compression model (EnCodec) is frozen during training. Only the transformer LM and its conditioning components receive gradients.
  • Distributed training support -- The training loop supports FSDP (Fully Sharded Data Parallel) as well as data-parallel training in which gradients are synchronized across processes, optionally eagerly during the backward pass.
  • Mixed precision -- Autocast (float16 or bfloat16) is used for the forward pass, with optional gradient scaling (mainly relevant for float16, whose narrow range can underflow small gradients).
  • Gradient clipping -- Gradient norms are clipped to a configurable maximum (typically 1.0) to prevent training instability.
  • EMA tracking -- An exponential moving average of model weights is maintained and used for validation, evaluation, and generation stages.
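Gradient clipping and EMA tracking are both simple vector operations. The sketch below treats all parameters and gradients as one flat list of floats; clip_grad_norm and ema_update are hypothetical names (PyTorch's own clipping utility is torch.nn.utils.clip_grad_norm_), and no decay value is implied for MusicGen.

```python
import math
from typing import List


def clip_grad_norm(grads: List[float], max_norm: float, eps: float = 1e-6) -> float:
    """Rescale grads in place so their global L2 norm is at most ~max_norm; return the pre-clip norm."""
    total_norm = math.sqrt(sum(g * g for g in grads))
    clip_coef = max_norm / (total_norm + eps)
    if clip_coef < 1.0:  # only shrink, never enlarge
        for i in range(len(grads)):
            grads[i] *= clip_coef
    return total_norm


def ema_update(ema_params: List[float], params: List[float], decay: float) -> None:
    """In place: ema <- decay * ema + (1 - decay) * param."""
    for i, p in enumerate(params):
        ema_params[i] = decay * ema_params[i] + (1.0 - decay) * p


grads = [3.0, 4.0]
print(clip_grad_norm(grads, max_norm=1.0))  # 5.0; grads are now roughly [0.6, 0.8]
ema = [0.0]
ema_update(ema, [1.0], decay=0.9)  # ema moves 10% of the way toward the live weight
```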

Training Loop Structure

StandardSolver.run() orchestrates the full training loop:

  1. Restore -- Load checkpoint and replay metrics history.
  2. For each epoch:
    1. Train -- Run run_step() on each batch from the training dataloader.
    2. Valid -- Run run_step() on validation data (with EMA weights swapped in).
    3. Update best state -- Compare validation metric to historical best.
    4. Evaluate (periodic) -- Run generation and metric computation using the best state.
    5. Generate (periodic) -- Generate audio samples for qualitative inspection.
    6. Commit -- Save checkpoints and log metrics.
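The ordering of these stages can be sketched as a toy schedule. This is not StandardSolver's implementation: run_schedule, evaluate_every and generate_every are hypothetical, and the function only records which stage would run at each epoch, given a precomputed list of validation losses.

```python
from typing import List


def run_schedule(valid_losses: List[float], evaluate_every: int, generate_every: int) -> List[str]:
    """Return the ordered stages a StandardSolver-like loop would run (toy illustration)."""
    log = ["restore"]  # load checkpoint, replay metrics history
    best = float("inf")
    for epoch, loss in enumerate(valid_losses, start=1):
        log.append(f"train@{epoch}")
        log.append(f"valid@{epoch}")  # EMA weights would be swapped in here
        if loss < best:  # update best state
            best = loss
            log.append(f"best@{epoch}")
        if epoch % evaluate_every == 0:
            log.append(f"evaluate@{epoch}")  # would use the best state
        if epoch % generate_every == 0:
            log.append(f"generate@{epoch}")
        log.append(f"commit@{epoch}")  # save checkpoint, log metrics
    return log


print(run_schedule([3.0, 2.5, 2.7], evaluate_every=2, generate_every=3))
```

In this toy run, epoch 3 does not update the best state (2.7 > 2.5), so any evaluation after it would still use the epoch-2 weights.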

Role in the MusicGen Training Pipeline

Training execution is the central stage of the pipeline. It consumes:

  • Prepared audio data (from the dataset)
  • Discrete tokens (from the frozen tokenizer)
  • A fully composed configuration (from the Hydra system)

And produces:

  • Trained model weights
  • Per-epoch metrics (cross-entropy, perplexity, per-codebook metrics)
  • Checkpoints for resumption and deployment
