Implementation:Facebookresearch Audiocraft MusicGenSolver run step
Overview
MusicGenSolver.run_step is the core training step method that performs a single forward-backward pass on a batch of audio data. It tokenizes audio, prepares conditioning attributes, computes cross-entropy loss across codebooks, and performs optimizer updates with gradient clipping, mixed precision scaling, and distributed synchronization.
Source Location
| Property | Value |
|---|---|
| Source file | `audiocraft/solvers/musicgen.py` lines 363-442 |
| Import | `from audiocraft.solvers.musicgen import MusicGenSolver` |
| Class | `MusicGenSolver(base.StandardSolver)` |
| `build_model()` | `audiocraft/solvers/musicgen.py` lines 140-169 |
| `StandardSolver.run()` | `audiocraft/solvers/base.py` lines 489-499 |
| Solver instantiation | `audiocraft/solvers/builders.py` lines 44-65 (`get_solver`) |
API
```python
MusicGenSolver.run_step(
    idx: int,
    batch: Tuple[torch.Tensor, List[SegmentWithAttributes]],
    metrics: dict
) -> dict
```
Parameters
| Parameter | Type | Description |
|---|---|---|
| `idx` | `int` | Step index within the current epoch |
| `batch` | `Tuple[torch.Tensor, List[SegmentWithAttributes]]` | Audio tensor `[B, C, T]` and metadata list |
| `metrics` | `dict` | Metrics dictionary to populate |
Return Value
dict containing:
| Key | Description |
|---|---|
| `ce` | Average cross-entropy loss across codebooks |
| `ppl` | Perplexity (`exp(ce)`) |
| `ce_q1`, `ce_q2`, ... | Per-codebook cross-entropy |
| `ppl_q1`, `ppl_q2`, ... | Per-codebook perplexity |
| `lr` | Current learning rate (training only) |
| `grad_norm` | Gradient norm after clipping (training only) |
| `grad_scale` | GradScaler scale (when using float16) |
Inputs and Outputs
Inputs:
- Batch of audio tensors and metadata from the dataloader
- Hydra configuration (via `self.cfg`)
- Frozen compression model (via `self.compression_model`)
- Trainable LM model (via `self.model`)
Outputs:
- Metrics dictionary with loss values and training statistics
- Updated model parameters (during training)
Internal Execution Flow
The `run_step` method performs these operations in sequence:
1. Prepare Tokens and Attributes

Calls `_prepare_tokens_and_attributes(batch)`, which:

```python
# Extract audio and metadata
audio, infos = batch
audio = audio.to(self.device)

# Prepare conditioning attributes with CFG and attribute dropout
attributes = [info.to_condition_attributes() for info in infos]
attributes = self.model.cfg_dropout(attributes)
attributes = self.model.att_dropout(attributes)
tokenized = self.model.condition_provider.tokenize(attributes)

# Encode audio to discrete tokens
audio_tokens, scale = self.compression_model.encode(audio)

# Compute condition tensors
condition_tensors = self.model.condition_provider(tokenized)

# Build padding mask
padding_mask = torch.ones_like(audio_tokens, dtype=torch.bool)
```
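As a rough illustration of the shapes involved, here is a toy stand-in for the tokenization step. All sizes and names below are made up for the sketch (an EnCodec-style compression model maps audio `[B, C, T]` to token tensors `[B, K, S]` with `K` codebooks), not audiocraft defaults:

```python
import torch

# Illustrative sizes, not audiocraft defaults
B, C, T = 4, 1, 32000        # batch, channels, audio samples
K, S = 4, 100                # codebooks, token frames after downsampling
card = 2048                  # codebook cardinality (vocabulary size)

audio = torch.randn(B, C, T)
# Stand-in for compression_model.encode(audio): discrete tokens per codebook
audio_tokens = torch.randint(0, card, (B, K, S))
# All positions valid, mirroring the torch.ones_like padding mask above
padding_mask = torch.ones_like(audio_tokens, dtype=torch.bool)

assert audio_tokens.shape == (B, K, S)
```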
2. Forward Pass
```python
with self.autocast:
    model_output = self.model.compute_predictions(audio_tokens, [], condition_tensors)
    logits = model_output.logits
    mask = padding_mask & model_output.mask  # combine padding and pattern mask
    ce, ce_per_codebook = self._compute_cross_entropy(logits, audio_tokens, mask)
    loss = ce
```
3. Backward Pass and Optimization (training only)
```python
if self.is_training:
    # Scale loss for mixed precision
    if self.scaler is not None:
        loss = self.scaler.scale(loss)

    # Backward with distributed sync
    if self.cfg.fsdp.use:
        loss.backward()
        flashy.distrib.average_tensors(self.model.buffers())
    elif self.cfg.optim.eager_sync:
        with flashy.distrib.eager_sync_model(self.model):
            loss.backward()
    else:
        loss.backward()
        flashy.distrib.sync_model(self.model)

    # Unscale gradients first so the norm is measured and clipped at true scale
    if self.scaler is not None:
        self.scaler.unscale_(self.optimizer)
    if self.cfg.optim.max_norm:
        metrics['grad_norm'] = torch.nn.utils.clip_grad_norm_(
            self.model.parameters(), self.cfg.optim.max_norm)

    # Optimizer step (through the scaler when mixed precision is active)
    if self.scaler is None:
        self.optimizer.step()
    else:
        self.scaler.step(self.optimizer)
        self.scaler.update()
    self.lr_scheduler.step()
    self.optimizer.zero_grad()
```
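PyTorch's recommended mixed-precision update order is scale → backward → unscale → clip → step → update. A minimal runnable sketch of that pattern on a toy model (`enabled=False` makes the scaler a pass-through so the sketch runs on CPU; the model and `max_norm` value are illustrative):

```python
import torch
import torch.nn as nn

model = nn.Linear(8, 2)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
# Pass-through scaler for a CPU demo; on GPU float16 you would enable it
scaler = torch.cuda.amp.GradScaler(enabled=False)

x, y = torch.randn(16, 8), torch.randn(16, 2)
loss = nn.functional.mse_loss(model(x), y)

scaler.scale(loss).backward()          # backward on the (possibly scaled) loss
scaler.unscale_(optimizer)             # restore true gradient magnitudes
grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
scaler.step(optimizer)                 # skips the step if grads overflowed
scaler.update()                        # adjust the scale factor
optimizer.zero_grad()
```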
4. Metrics Collection
```python
metrics['ce'] = ce
metrics['ppl'] = torch.exp(ce)
for k, ce_q in enumerate(ce_per_codebook):
    metrics[f'ce_q{k + 1}'] = ce_q
    metrics[f'ppl_q{k + 1}'] = torch.exp(ce_q)
```
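The `ce`/`ppl` relation can be checked with plain Python: `ce` is the mean of the per-codebook cross-entropies (in nats) and `ppl = exp(ce)` is the effective number of equally likely tokens. The per-codebook values below are made-up numbers for illustration:

```python
import math

ce_per_codebook = [5.2, 4.8, 4.5, 4.1]   # illustrative, in nats
ce = sum(ce_per_codebook) / len(ce_per_codebook)
ppl = math.exp(ce)

metrics = {'ce': ce, 'ppl': ppl}
for k, ce_q in enumerate(ce_per_codebook):
    metrics[f'ce_q{k + 1}'] = ce_q
    metrics[f'ppl_q{k + 1}'] = math.exp(ce_q)
```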
Cross-Entropy Computation
The `_compute_cross_entropy` method at lines 219-251:

```python
def _compute_cross_entropy(self, logits, targets, mask):
    B, K, T = targets.shape
    ce = torch.zeros([], device=targets.device)
    ce_per_codebook = []
    for k in range(K):
        # Flatten batch and time dimensions for codebook k
        logits_k = logits[:, k, ...].contiguous().view(-1, logits.size(-1))
        targets_k = targets[:, k, ...].contiguous().view(-1)
        mask_k = mask[:, k, ...].contiguous().view(-1)
        # Keep only valid (unmasked) positions
        ce_targets = targets_k[mask_k]
        ce_logits = logits_k[mask_k]
        q_ce = F.cross_entropy(ce_logits, ce_targets)
        ce += q_ce
        ce_per_codebook.append(q_ce.detach())
    # Average over codebooks
    ce = ce / K
    return ce, ce_per_codebook
```
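The same masking-and-averaging logic can be exercised standalone on random inputs. The re-implementation below is a sketch for experimentation (`masked_codebook_ce` is a name invented here), not audiocraft's own function:

```python
import torch
import torch.nn.functional as F

def masked_codebook_ce(logits, targets, mask):
    """Cross-entropy over masked positions, averaged over K codebooks.

    logits: [B, K, T, card]; targets and mask: [B, K, T].
    """
    B, K, T = targets.shape
    ce = torch.zeros([], device=targets.device)
    ce_per_codebook = []
    for k in range(K):
        logits_k = logits[:, k].reshape(-1, logits.size(-1))
        targets_k = targets[:, k].reshape(-1)
        mask_k = mask[:, k].reshape(-1)
        q_ce = F.cross_entropy(logits_k[mask_k], targets_k[mask_k])
        ce += q_ce
        ce_per_codebook.append(q_ce.detach())
    return ce / K, ce_per_codebook

# Illustrative sizes and random data
B, K, T, card = 2, 4, 10, 32
logits = torch.randn(B, K, T, card)
targets = torch.randint(0, card, (B, K, T))
mask = torch.rand(B, K, T) > 0.1          # mask out ~10% of positions
ce, ce_per_codebook = masked_codebook_ce(logits, targets, mask)
```

With untrained random logits, `ce` lands near `ln(card)`, the loss of a uniform predictor.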
Model Building
The `build_model()` method at lines 140-169 initializes:

```python
# Load frozen compression model
self.compression_model = CompressionSolver.wrapped_model_from_checkpoint(
    self.cfg, self.cfg.compression_model_checkpoint, device=self.device)

# Instantiate trainable LM
self.model = models.builders.get_lm_model(self.cfg).to(self.device)

# Setup optimization
self.optimizer = builders.get_optimizer(
    builders.get_optim_parameter_groups(self.model), self.cfg.optim)
self.lr_scheduler = builders.get_lr_scheduler(
    self.optimizer, self.cfg.schedule, self.total_updates)
```
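The optimizer/scheduler wiring can be sketched with plain PyTorch in place of audiocraft's builders. The cosine-with-warmup schedule and every hyperparameter below are illustrative choices, not audiocraft's defaults:

```python
import math
import torch
import torch.nn as nn

model = nn.Linear(16, 4)                  # toy stand-in for the LM
total_updates, warmup = 1000, 100         # illustrative schedule lengths

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, betas=(0.9, 0.95))

def cosine_with_warmup(step):
    # Linear warmup, then cosine decay to zero over the remaining updates
    if step < warmup:
        return step / max(1, warmup)
    progress = (step - warmup) / max(1, total_updates - warmup)
    return 0.5 * (1 + math.cos(math.pi * min(1.0, progress)))

lr_scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, cosine_with_warmup)
```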
Related Builder Functions
| Function | Source | Purpose |
|---|---|---|
| `get_solver(cfg)` | `solvers/builders.py:L44-65` | Instantiates MusicGenSolver from config |
| `get_optimizer(params, cfg)` | `solvers/builders.py:L95-121` | Creates Adam/AdamW/DAdaptAdam optimizer |
| `get_lr_scheduler(optimizer, cfg, total_updates)` | `solvers/builders.py:L124-165` | Creates LR scheduler |
| `get_lm_model(cfg)` | `models/builders.py:L136+` | Instantiates transformer LM with conditioning |
Dependencies
- `torch`, `torch.nn.functional` -- core PyTorch
- `flashy` -- distributed training utilities, metric averaging
- `xformers` (optional) -- memory-efficient attention
- `audiocraft.solvers.base.StandardSolver` -- base solver class
Related Pages
- Principle:Facebookresearch_Audiocraft_MusicGen_Training_Execution
- Environment:Facebookresearch_Audiocraft_Python_PyTorch_CUDA_Environment
- Environment:Facebookresearch_Audiocraft_XFormers_Memory_Efficient_Attention
- Heuristic:Facebookresearch_Audiocraft_FSDP_Distributed_Training_Tips
- Heuristic:Facebookresearch_Audiocraft_Chroma_Conditioning_Cache_Requirement