Implementation:Facebookresearch Audiocraft CompressionSolver run step

Overview

Concrete implementation of the EnCodec compression training loop within Audiocraft. The CompressionSolver class (extending base.StandardSolver) orchestrates model construction, loss computation, gradient balancing, discriminator training, and optimizer steps. The central method run_step() executes one training or validation iteration.

Description

The CompressionSolver initializes adversarial losses, auxiliary losses, info-only losses, and a gradient Balancer from the Hydra config. During each training step, run_step() performs the forward pass through the EnCodec model, computes all losses, updates the discriminator, applies the balancer to compute the generator gradient, clips gradients, and steps the optimizer.

Key loss components wired into the solver (a wiring sketch follows the list):

  • MRSTFTLoss (audiocraft/losses/stftloss.py) -- multi-resolution STFT spectral convergence and magnitude loss
  • MultiScaleMelSpectrogramLoss (audiocraft/losses/specloss.py) -- multi-scale mel spectrogram L1 and L2 loss
  • AdversarialLoss (audiocraft/adversarial/losses.py) -- wraps the discriminator and computes generator/discriminator losses
  • MultiScaleSTFTDiscriminator (audiocraft/adversarial/discriminators/msstftd.py) -- the adversary network
  • Balancer (audiocraft/losses/balancer.py) -- gradient-based loss balancing
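
A minimal sketch of how this wiring looks inside __init__, using the builder helpers from audiocraft.solvers.builders. The control flow below is an approximation reconstructed from the description above, not a verbatim copy of the source:

import torch.nn as nn
from audiocraft.solvers import builders

# Inside CompressionSolver.__init__() (approximate):
self.adv_losses = builders.get_adversarial_losses(self.cfg)  # adversary name -> AdversarialLoss
self.aux_losses = nn.ModuleDict()    # losses that feed the Balancer
self.info_losses = nn.ModuleDict()   # logged only, never backpropagated
loss_weights = {}
for loss_name, weight in self.cfg.losses.items():
    if loss_name in ['adv', 'feat']:
        # adversarial and feature-matching weights apply once per adversary
        for adv_name in self.adv_losses:
            loss_weights[f'{loss_name}_{adv_name}'] = weight
    elif weight > 0:
        self.aux_losses[loss_name] = builders.get_loss(loss_name, self.cfg)
        loss_weights[loss_name] = weight
    else:
        # zero-weight losses are tracked for information only
        self.info_losses[loss_name] = builders.get_loss(loss_name, self.cfg)
self.balancer = builders.get_balancer(loss_weights, self.cfg.balancer)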

Usage

Import when extending or debugging the compression solver:

from audiocraft.solvers.compression import CompressionSolver

The solver is normally constructed by the Audiocraft training framework through Dora and Hydra configuration rather than instantiated directly.
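
In practice training is launched through Dora (e.g. dora run solver=compression/encodec_base_24khz). For debugging outside the framework, the solver can be built by hand from a fully resolved configuration. A sketch, assuming such a config is available as an omegaconf.DictConfig; the YAML path below is hypothetical:

import omegaconf
from audiocraft.solvers.compression import CompressionSolver

cfg = omegaconf.OmegaConf.load('my_resolved_compression_config.yaml')  # hypothetical path
solver = CompressionSolver(cfg)
solver.build_model()  # builds the EnCodec model and optimizer from cfg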

Code Reference

Source Location

  • Repository: facebookresearch/audiocraft
  • File: audiocraft/solvers/compression.py (lines 27--329)
  • Key method: run_step() at lines 83--174
  • Model construction: build_model() at lines 59--66
  • Loss setup: __init__() at lines 34--52

Supporting Source Files

Loss Component Sources

  Component                     File                                               Lines
  MRSTFTLoss                    audiocraft/losses/stftloss.py                      164--207
  MultiScaleMelSpectrogramLoss  audiocraft/losses/specloss.py                      96--149
  AdversarialLoss               audiocraft/adversarial/losses.py                   26--135
  FeatureMatchingLoss           audiocraft/adversarial/losses.py                   201--228
  MultiScaleSTFTDiscriminator   audiocraft/adversarial/discriminators/msstftd.py   94--134
  Balancer                      audiocraft/losses/balancer.py                      14--136
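
These components can also be exercised in isolation, which helps when debugging a single loss term. A short sketch assuming default constructor arguments (check the signatures in the files above):

import torch
from audiocraft.losses import MRSTFTLoss, MultiScaleMelSpectrogramLoss

y = torch.randn(4, 1, 24000)        # [B, C, T] reference audio
y_pred = torch.randn(4, 1, 24000)   # [B, C, T] reconstruction

stft_loss = MRSTFTLoss()(y_pred, y)                                    # scalar tensor
mel_loss = MultiScaleMelSpectrogramLoss(sample_rate=24000)(y_pred, y)  # scalar tensor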

Signature

class CompressionSolver(base.StandardSolver):
    def __init__(self, cfg: omegaconf.DictConfig):
        ...

    def build_model(self):
        """Instantiate EnCodec model and optimizer from config."""
        self.model = models.builders.get_compression_model(self.cfg).to(self.device)
        self.optimizer = builders.get_optimizer(self.model.parameters(), self.cfg.optim)
        ...

    def run_step(self, idx: int, batch: torch.Tensor, metrics: dict) -> dict:
        """Perform one training or validation step on a given batch."""
        ...

    def evaluate(self):
        """Run audio reconstruction evaluation with ViSQOL and SI-SNR."""
        ...

    @staticmethod
    def model_from_checkpoint(
        checkpoint_path: Union[Path, str],
        device: Union[torch.device, str] = 'cpu',
    ) -> models.CompressionModel:
        """Load a trained CompressionModel from a checkpoint path or Dora sig."""
        ...

Import

from audiocraft.solvers.compression import CompressionSolver

Dependencies

  • flashy -- distributed training utilities, EMA averaging, model synchronization
  • torch -- core tensor operations, autograd, optimizer
  • omegaconf -- Hydra configuration management

I/O Contract

Inputs

Input Contract

  • idx (int) -- Current step index within the epoch.
  • batch (torch.Tensor, [B, C, T]) -- Raw audio tensors from the dataset loader; moved to device inside the method.
  • metrics (dict) -- Mutable dictionary for accumulating metrics; passed in by the training loop.
  • self.cfg (omegaconf.DictConfig) -- Hydra configuration specifying loss weights, optimizer params, adversarial settings, and model architecture.
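
For orientation, these arguments are supplied roughly as follows by the surrounding loop. This is a simplified sketch; the real loop in base.StandardSolver also handles logging, checkpointing, and distributed synchronization:

# Simplified calling convention (sketch):
for idx, batch in enumerate(loader):
    metrics = {}
    metrics = self.run_step(idx, batch, metrics)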

Outputs

Output Contract

  • metrics (dict) -- Updated metrics dictionary containing bandwidth (kbps), d_loss (discriminator loss), g_loss (generator loss from the balancer), adv (aggregated adversarial loss), feat (aggregated feature-matching loss), penalty (commitment loss), ratio1/ratio2 (gradient-norm diagnostics), and individual values for each auxiliary and info loss.
  • Trained model (EncodecModel) -- After full training, self.model holds the trained encoder, decoder, and quantizer; the best state is tracked via register_best_state('model').
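
An illustrative shape of the returned metrics dictionary after one training step. All values are made up, and the exact key set depends on which losses are enabled in the config:

# Example contents of `metrics` after one training step (values illustrative):
{
    'bandwidth': 6.0,   # kbps of the quantized representation
    'd_loss': 1.43,     # discriminator loss (training only)
    'g_loss': 4.12,     # balanced generator loss from the Balancer
    'adv': 0.97,        # aggregated adversarial loss
    'feat': 2.10,       # aggregated feature-matching loss
    'penalty': 0.04,    # RVQ commitment loss
    'ratio1': 1.00,     # gradient-norm diagnostics
    'ratio2': 0.98,
    'msspec': 1.85,     # one auxiliary-loss entry per configured loss
}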

Usage Examples

Example 1: Training Step Internals

The core training logic within run_step(), showing how losses are computed and balanced.

# Inside CompressionSolver.run_step():

# 1. Forward pass through EnCodec
x = batch.to(self.device)
y = x.clone()
qres = self.model(x)       # returns QuantizedResult
y_pred = qres.x            # reconstructed audio

# 2. Discriminator update (stochastic, every N steps)
if self.is_training and len(self.adv_losses) > 0:
    if torch.rand(1, generator=self.rng).item() <= 1 / self.cfg.adversarial.every:
        for adv_name, adversary in self.adv_losses.items():
            disc_loss = adversary.train_adv(y_pred, y)

# 3. Compute balanced losses
balanced_losses = {}
for adv_name, adversary in self.adv_losses.items():
    adv_loss, feat_loss = adversary(y_pred, y)
    balanced_losses[f'adv_{adv_name}'] = adv_loss
    balanced_losses[f'feat_{adv_name}'] = feat_loss

for loss_name, criterion in self.aux_losses.items():
    balanced_losses[loss_name] = criterion(y_pred, y)

# 4. Gradient balancing and backward (training only)
if self.is_training:
    metrics['g_loss'] = self.balancer.backward(balanced_losses, y_pred)

    # 5. Gradient clipping and optimizer step
    flashy.distrib.sync_model(self.model)
    if self.cfg.optim.max_norm:
        torch.nn.utils.clip_grad_norm_(
            self.model.parameters(), self.cfg.optim.max_norm)
    self.optimizer.step()
    self.optimizer.zero_grad()

Example 2: Loading a Trained Compression Model

Using the solver's static method to load a trained model from a checkpoint.

from audiocraft.solvers.compression import CompressionSolver

# Load from a Dora experiment signature
model = CompressionSolver.model_from_checkpoint(
    '//sig/my_encodec_experiment',
    device='cuda',
)

# Or load from a pretrained model
model = CompressionSolver.model_from_checkpoint(
    '//pretrained/facebook/encodec_32khz',
    device='cuda',
)
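
The returned CompressionModel exposes the usual encode/decode interface, so a loaded checkpoint can be applied directly. A brief sketch, assuming wav is a [B, C, T] waveform at the model's sample rate:

import torch

wav = torch.randn(1, 1, model.sample_rate).to('cuda')  # one second of (random) audio
with torch.no_grad():
    codes, scale = model.encode(wav)            # discrete codes from the quantizer
    reconstructed = model.decode(codes, scale)  # waveform reconstruction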
