
Implementation:Facebookresearch Audiocraft MusicGenSolver run step

From Leeroopedia

Overview

MusicGenSolver.run_step is the core training step method that performs a single forward-backward pass on a batch of audio data. It tokenizes audio, prepares conditioning attributes, computes cross-entropy loss across codebooks, and performs optimizer updates with gradient clipping, mixed precision scaling, and distributed synchronization.

Source Location

Source file:          audiocraft/solvers/musicgen.py, lines 363-442
Import:               from audiocraft.solvers.musicgen import MusicGenSolver
Class:                MusicGenSolver(base.StandardSolver)
build_model():        audiocraft/solvers/musicgen.py, lines 140-169
StandardSolver.run(): audiocraft/solvers/base.py, lines 489-499
Solver instantiation: audiocraft/solvers/builders.py, lines 44-65 (get_solver)

API

MusicGenSolver.run_step(
    idx: int,
    batch: Tuple[torch.Tensor, List[SegmentWithAttributes]],
    metrics: dict
) -> dict

Parameters

idx (int): Step index within the current epoch
batch (Tuple[torch.Tensor, List[SegmentWithAttributes]]): Audio tensor of shape [B, C, T] and a list of per-segment metadata
metrics (dict): Metrics dictionary to populate

Return Value

dict containing:

ce: Average cross-entropy loss across codebooks
ppl: Perplexity, exp(ce)
ce_q1, ce_q2, ...: Per-codebook cross-entropy
ppl_q1, ppl_q2, ...: Per-codebook perplexity
lr: Current learning rate (training only)
grad_norm: Global gradient norm before clipping, as returned by clip_grad_norm_ (training only)
grad_scale: GradScaler scale (when float16 mixed precision is active)

Inputs and Outputs

Inputs:

  • Batch of audio tensors and metadata from the dataloader
  • Hydra configuration (via self.cfg)
  • Frozen compression model (via self.compression_model)
  • Trainable LM model (via self.model)

Outputs:

  • Metrics dictionary with loss values and training statistics
  • Updated model parameters (during training)

Internal Execution Flow

The run_step method performs these operations in sequence:

1. Prepare Tokens and Attributes

Calls _prepare_tokens_and_attributes(batch) which:

# Extract audio and metadata
audio, infos = batch
audio = audio.to(self.device)

# Prepare conditioning attributes with CFG and attribute dropout
attributes = [info.to_condition_attributes() for info in infos]
attributes = self.model.cfg_dropout(attributes)
attributes = self.model.att_dropout(attributes)
tokenized = self.model.condition_provider.tokenize(attributes)

# Encode audio to discrete tokens
audio_tokens, scale = self.compression_model.encode(audio)

# Compute condition tensors
condition_tensors = self.model.condition_provider(tokenized)

# Build padding mask
padding_mask = torch.ones_like(audio_tokens, dtype=torch.bool)
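As a rough shape walkthrough of the encode step (assuming MusicGen's published defaults of 32 kHz audio, a 50 Hz token frame rate, and K = 4 residual codebooks; the exact values come from the compression model's configuration):

```python
# Hypothetical shape sketch: how compression_model.encode maps a waveform
# batch to the discrete token grid that the LM is trained on.
sample_rate = 32_000   # Hz (MusicGen default)
frame_rate = 50        # token frames per second
K = 4                  # number of residual codebooks
B, C, seconds = 8, 1, 30

T = sample_rate * seconds   # waveform length in samples
S = frame_rate * seconds    # token sequence length per codebook

audio_shape = (B, C, T)     # input to compression_model.encode
token_shape = (B, K, S)     # audio_tokens produced by encode
print(audio_shape, token_shape)
```

So a 30-second mono clip becomes a [B, 4, 1500] integer tensor, and the padding mask above has that same [B, K, S] shape.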

2. Forward Pass

with self.autocast:
    # conditions are passed as precomputed tensors, so the attribute list is empty
    model_output = self.model.compute_predictions(audio_tokens, [], condition_tensors)
    logits = model_output.logits
    mask = padding_mask & model_output.mask  # combine padding and pattern mask
    ce, ce_per_codebook = self._compute_cross_entropy(logits, audio_tokens, mask)
    loss = ce

3. Backward Pass and Optimization (training only)

if self.is_training:
    metrics['lr'] = self.optimizer.param_groups[0]['lr']

    # Scale loss for mixed precision
    if self.scaler is not None:
        loss = self.scaler.scale(loss)

    # Backward with distributed sync
    if self.cfg.fsdp.use:
        loss.backward()
        flashy.distrib.average_tensors(self.model.buffers())
    elif self.cfg.optim.eager_sync:
        with flashy.distrib.eager_sync_model(self.model):
            loss.backward()
    else:
        loss.backward()
        flashy.distrib.sync_model(self.model)

    # Unscale gradients before clipping when the GradScaler is active
    if self.scaler is not None:
        self.scaler.unscale_(self.optimizer)

    # Gradient clipping (clip_grad_norm_ returns the pre-clipping global norm)
    if self.cfg.optim.max_norm:
        metrics['grad_norm'] = torch.nn.utils.clip_grad_norm_(
            self.model.parameters(), self.cfg.optim.max_norm)

    # Optimizer step, routed through the scaler under mixed precision
    if self.scaler is None:
        self.optimizer.step()
    else:
        self.scaler.step(self.optimizer)
        self.scaler.update()
    if self.lr_scheduler:
        self.lr_scheduler.step()
    self.optimizer.zero_grad()
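The clipping step above rescales every gradient by max_norm / total_norm whenever the global L2 norm exceeds max_norm. A minimal pure-Python sketch of that rule (clip_by_global_norm is a hypothetical helper, not the torch implementation):

```python
import math

def clip_by_global_norm(grads, max_norm, eps=1e-6):
    """Scale gradient values so their global L2 norm is at most max_norm.

    Returns the (possibly rescaled) gradients and the pre-clipping norm,
    mirroring what torch.nn.utils.clip_grad_norm_ reports.
    """
    total_norm = math.sqrt(sum(g * g for g in grads))
    if total_norm > max_norm:
        scale = max_norm / (total_norm + eps)
        grads = [g * scale for g in grads]
    return grads, total_norm

grads, norm = clip_by_global_norm([3.0, 4.0], max_norm=1.0)
# the global norm of [3, 4] is 5, so both gradients shrink by ~1/5
```

This is why the grad_norm metric can exceed max_norm: it records the norm measured before rescaling.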

4. Metrics Collection

metrics['ce'] = ce
metrics['ppl'] = torch.exp(ce)
for k, ce_q in enumerate(ce_per_codebook):
    metrics[f'ce_q{k + 1}'] = ce_q
    metrics[f'ppl_q{k + 1}'] = torch.exp(ce_q)
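Since perplexity is simply the exponential of the cross-entropy (measured in nats), the reported ppl values can be reproduced by hand. The codebook cardinality below (2048, the default for the EnCodec models MusicGen builds on) sets a natural reference point:

```python
import math

ce = 4.0            # hypothetical average cross-entropy in nats
ppl = math.exp(ce)  # ~54.6: as uncertain as a uniform choice over ~55 tokens

# Sanity check: uniform uncertainty over a 2048-entry codebook gives a
# perplexity equal to the codebook size, the worst untrained-model baseline.
uniform_ce = math.log(2048)
assert math.isclose(math.exp(uniform_ce), 2048)
```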

Cross-Entropy Computation

The _compute_cross_entropy method at lines 219-251:

def _compute_cross_entropy(self, logits, targets, mask):
    # logits: [B, K, T, card]; targets and mask: [B, K, T]
    B, K, T = targets.shape
    ce = torch.zeros([], device=targets.device)
    ce_per_codebook = []
    for k in range(K):
        # Flatten batch and time for codebook k, then keep valid positions only
        logits_k = logits[:, k, ...].contiguous().view(-1, logits.size(-1))
        targets_k = targets[:, k, ...].contiguous().view(-1)
        mask_k = mask[:, k, ...].contiguous().view(-1)
        ce_targets = targets_k[mask_k]
        ce_logits = logits_k[mask_k]
        q_ce = F.cross_entropy(ce_logits, ce_targets)
        ce += q_ce
        ce_per_codebook.append(q_ce.detach())
    # Average over codebooks so the loss scale is independent of K
    ce = ce / K
    return ce, ce_per_codebook
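The same masked, per-codebook averaging can be illustrated without torch. This pure-Python sketch (a hypothetical mini-example with two codebooks, two timesteps, and a 3-entry vocabulary) mirrors the loop above: flatten each codebook, keep only positions where the mask is True, average the token-level negative log-likelihoods, then average over codebooks:

```python
import math

def softmax_nll(logits, target):
    """Negative log-likelihood of `target` under a softmax over `logits`."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return log_z - logits[target]

def masked_ce_per_codebook(logits, targets, mask):
    """logits: [K][T][card]; targets, mask: [K][T]. Returns (avg_ce, per_codebook)."""
    K = len(targets)
    per_codebook = []
    for k in range(K):
        losses = [softmax_nll(logits[k][t], targets[k][t])
                  for t in range(len(targets[k])) if mask[k][t]]
        per_codebook.append(sum(losses) / len(losses))
    return sum(per_codebook) / K, per_codebook

# Second timestep of codebook 0 is masked out (e.g. a padded position).
logits = [[[2.0, 0.0, 0.0], [0.0, 2.0, 0.0]],
          [[0.0, 0.0, 2.0], [2.0, 0.0, 0.0]]]
targets = [[0, 1], [2, 0]]
mask = [[True, False], [True, True]]
ce, ce_per_codebook = masked_ce_per_codebook(logits, targets, mask)
```

Masked positions contribute nothing to the loss, which is how padding and the delayed-pattern special tokens are excluded from training.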

Model Building

The build_model() method at lines 140-169 initializes:

# Load frozen compression model
self.compression_model = CompressionSolver.wrapped_model_from_checkpoint(
    self.cfg, self.cfg.compression_model_checkpoint, device=self.device)

# Instantiate trainable LM
self.model = models.builders.get_lm_model(self.cfg).to(self.device)

# Setup optimization
self.optimizer = builders.get_optimizer(
    builders.get_optim_parameter_groups(self.model), self.cfg.optim)
self.lr_scheduler = builders.get_lr_scheduler(
    self.optimizer, self.cfg.schedule, self.total_updates)
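Among the schedules get_lr_scheduler can build is cosine decay with linear warmup. As an illustrative sketch of that curve only (cosine_with_warmup is a hypothetical standalone function, not the audiocraft scheduler class):

```python
import math

def cosine_with_warmup(step, total_updates, warmup=100, base_lr=1e-4, min_lr=0.0):
    """Illustrative LR curve: linear warmup to base_lr, then cosine decay to min_lr."""
    if step < warmup:
        return base_lr * step / max(1, warmup)
    progress = (step - warmup) / max(1, total_updates - warmup)
    return min_lr + 0.5 * (base_lr - min_lr) * (1 + math.cos(math.pi * progress))

lrs = [cosine_with_warmup(s, total_updates=1000) for s in range(1000)]
# ramps up over the first 100 steps, peaks at base_lr, then decays smoothly
```

Because run_step calls lr_scheduler.step() after every optimizer update, total_updates here counts optimizer steps, not epochs.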

Related Builder Functions

get_solver(cfg) -- solvers/builders.py, lines 44-65: instantiates MusicGenSolver from config
get_optimizer(params, cfg) -- solvers/builders.py, lines 95-121: creates the Adam/AdamW/DAdaptAdam optimizer
get_lr_scheduler(optimizer, cfg, total_updates) -- solvers/builders.py, lines 124-165: creates the LR scheduler
get_lm_model(cfg) -- models/builders.py, line 136 onward: instantiates the transformer LM with conditioning

Dependencies

  • torch, torch.nn.functional -- core PyTorch
  • flashy -- distributed training utilities, metric averaging
  • xformers (optional) -- memory-efficient attention
  • audiocraft.solvers.base.StandardSolver -- base solver class
