Implementation: facebookresearch/audiocraft CompressionSolver run_step
| Knowledge Sources | |
|---|---|
| Domains | |
| Last Updated | 2026-02-13 00:00 GMT |
Overview
Concrete implementation of the EnCodec compression training loop within Audiocraft. The CompressionSolver class (extending base.StandardSolver) orchestrates model construction, loss computation, gradient balancing, discriminator training, and optimizer steps. The central method run_step() executes one training or validation iteration.
Description
The CompressionSolver initializes adversarial losses, auxiliary losses, info-only losses, and a gradient Balancer from the Hydra config. During each training step, run_step() performs the forward pass through the EnCodec model, computes all losses, updates the discriminator, applies the balancer to compute the generator gradient, clips gradients, and steps the optimizer.
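The sequence described above can be sketched with stdlib-only stand-ins that record the order of operations; none of these names are audiocraft APIs, and each callable is a stub standing in for the real component named in its comment.

```python
# Stdlib-only sketch of the run_step() operation order described above.
# Every callable is a recording stub, not an audiocraft API.
trace = []

model = lambda x: trace.append('forward') or x                   # EnCodec forward pass
train_disc = lambda y_pred, y: trace.append('disc_update')       # adversary update
compute_losses = lambda y_pred, y: trace.append('losses') or {}  # adversarial + aux losses
balancer_backward = lambda losses, y_pred: trace.append('balanced_backward')
clip_grads = lambda: trace.append('clip')                        # gradient clipping
optimizer_step = lambda: trace.append('step')                    # optimizer update

def run_step_sketch(batch, is_training=True):
    trace.clear()
    y = batch
    y_pred = model(y)                      # forward pass through the model
    if is_training:
        train_disc(y_pred, y)              # discriminator update (training only)
    losses = compute_losses(y_pred, y)     # all generator-side losses
    if is_training:
        balancer_backward(losses, y_pred)  # balancer computes the generator gradient
        clip_grads()
        optimizer_step()
    return list(trace)

order = run_step_sketch(batch=[0.0], is_training=True)
eval_order = run_step_sketch(batch=[0.0], is_training=False)
```

In validation mode only the forward pass and loss computation run; the discriminator update, backward, and optimizer step are skipped.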
Key loss components wired into the solver:
- MRSTFTLoss (audiocraft/losses/stftloss.py) -- multi-resolution STFT spectral convergence and magnitude loss
- MultiScaleMelSpectrogramLoss (audiocraft/losses/specloss.py) -- multi-scale mel spectrogram L1 and L2 loss
- AdversarialLoss (audiocraft/adversarial/losses.py) -- wraps the discriminator and computes generator/discriminator losses
- MultiScaleSTFTDiscriminator (audiocraft/adversarial/discriminators/msstftd.py) -- the adversary network
- Balancer (audiocraft/losses/balancer.py) -- gradient-based loss balancing
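For intuition on what gradient-based loss balancing means, the sketch below rescales each loss's gradient so that it contributes in proportion to its configured weight rather than its raw magnitude. This is a simplified stand-in for the Balancer, assuming per-loss gradients with respect to the model output are available as plain vectors; all names here are illustrative, not the library's API.

```python
# Simplified sketch of gradient-based loss balancing: each loss's
# gradient is rescaled so its norm matches that loss's share of a
# total gradient budget, then the rescaled gradients are summed.
import math

def balance_gradients(grads: dict, weights: dict, total_norm: float = 1.0) -> list:
    """Rescale each per-loss gradient to a norm proportional to its weight,
    then sum into a single gradient for the backward pass."""
    norms = {name: math.sqrt(sum(g * g for g in grad)) for name, grad in grads.items()}
    weight_sum = sum(weights[name] for name in grads)
    out = [0.0] * len(next(iter(grads.values())))
    for name, grad in grads.items():
        desired = total_norm * weights[name] / weight_sum  # this loss's norm budget
        scale = desired / (norms[name] + 1e-12)            # avoid division by zero
        for i, g in enumerate(grad):
            out[i] += scale * g
    return out

# Two losses with wildly different gradient scales end up contributing
# equally, because their weights are equal.
combined = balance_gradients(
    grads={'l1': [100.0, 0.0], 'msspec': [0.0, 0.001]},
    weights={'l1': 1.0, 'msspec': 1.0},
)
```

Without balancing, the raw sum of these gradients would be dominated by the loss with the larger magnitude; rescaling by norm makes the weights, not the loss scales, control each term's influence.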
Usage
Import when extending or debugging the compression solver:
from audiocraft.solvers.compression import CompressionSolver
The solver is typically instantiated by the Audiocraft training framework through Dora and Hydra configuration, not by direct instantiation.
Code Reference
Source Location
- Repository: facebookresearch/audiocraft
- File: audiocraft/solvers/compression.py (lines 27--329)
- Key method: run_step() at lines 83--174
- Model construction: build_model() at lines 59--66
- Loss setup: __init__() at lines 34--52
Supporting Source Files
| Component | File | Lines |
|---|---|---|
| MRSTFTLoss | audiocraft/losses/stftloss.py | 164--207 |
| MultiScaleMelSpectrogramLoss | audiocraft/losses/specloss.py | 96--149 |
| AdversarialLoss | audiocraft/adversarial/losses.py | 26--135 |
| FeatureMatchingLoss | audiocraft/adversarial/losses.py | 201--228 |
| MultiScaleSTFTDiscriminator | audiocraft/adversarial/discriminators/msstftd.py | 94--134 |
| Balancer | audiocraft/losses/balancer.py | 14--136 |
Signature
class CompressionSolver(base.StandardSolver):
def __init__(self, cfg: omegaconf.DictConfig):
...
def build_model(self):
"""Instantiate EnCodec model and optimizer from config."""
self.model = models.builders.get_compression_model(self.cfg).to(self.device)
self.optimizer = builders.get_optimizer(self.model.parameters(), self.cfg.optim)
...
def run_step(self, idx: int, batch: torch.Tensor, metrics: dict) -> dict:
"""Perform one training or validation step on a given batch."""
...
def evaluate(self):
"""Run audio reconstruction evaluation with ViSQOL and SI-SNR."""
...
@staticmethod
def model_from_checkpoint(
checkpoint_path: Union[Path, str],
device: Union[torch.device, str] = 'cpu',
) -> models.CompressionModel:
"""Load a trained CompressionModel from a checkpoint path or Dora sig."""
...
Import
from audiocraft.solvers.compression import CompressionSolver
Dependencies
- flashy -- distributed training utilities, EMA averaging, model synchronization
- torch -- core tensor operations, autograd, optimizer
- omegaconf -- Hydra configuration management
I/O Contract
Inputs
| Name | Type | Description |
|---|---|---|
| idx | int | Current step index within the epoch. |
| batch | torch.Tensor [B, C, T] | Raw audio tensors from the dataset loader. Moved to device inside the method. |
| metrics | dict | Mutable dictionary for accumulating metrics. Passed in by the training loop. |
| self.cfg | omegaconf.DictConfig | Hydra configuration specifying loss weights, optimizer params, adversarial settings, and model architecture. |
Outputs
| Name | Type | Description |
|---|---|---|
| metrics | dict | Updated metrics dictionary containing: bandwidth (kbps), d_loss (discriminator loss), g_loss (generator loss from the balancer), adv (aggregated adversarial loss), feat (aggregated feature matching loss), penalty (commitment loss), ratio1/ratio2 (gradient norm diagnostics), and individual values for each auxiliary and info loss. |
| Trained model | EncodecModel | After full training, self.model contains the trained encoder, decoder, and quantizer. Best state is tracked via register_best_state('model'). |
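For intuition on the bandwidth metric: a residual vector quantizer emitting n_q codebooks of `bins` entries costs n_q * log2(bins) bits per frame, so kbps follows directly from the frame rate. The sketch below computes that figure; the specific numbers are illustrative, not read from any config.

```python
# Relationship between quantizer configuration and the reported
# bandwidth (kbps): bits per frame times frames per second.
import math

def bandwidth_kbps(n_q: int, bins: int, frame_rate: float) -> float:
    """Bandwidth of an RVQ stream with n_q codebooks of `bins` entries."""
    bits_per_frame = n_q * math.log2(bins)  # e.g. 4 * 11 bits = 44 bits/frame
    return bits_per_frame * frame_rate / 1000.0

# Illustrative: 4 codebooks of 2048 entries at 50 frames/s -> 2.2 kbps
bw = bandwidth_kbps(n_q=4, bins=2048, frame_rate=50.0)
```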
Usage Examples
Example 1: Training Step Internals
The core training logic within run_step(), showing how losses are computed and balanced.
# Inside CompressionSolver.run_step():
# 1. Forward pass through EnCodec
x = batch.to(self.device)
y = x.clone()
qres = self.model(x) # returns QuantizedResult
y_pred = qres.x # reconstructed audio
# 2. Discriminator update (stochastic; on average once every cfg.adversarial.every steps)
if self.is_training and len(self.adv_losses) > 0:
if torch.rand(1, generator=self.rng).item() <= 1 / self.cfg.adversarial.every:
for adv_name, adversary in self.adv_losses.items():
disc_loss = adversary.train_adv(y_pred, y)
# 3. Compute balanced losses
balanced_losses = {}
for adv_name, adversary in self.adv_losses.items():
adv_loss, feat_loss = adversary(y_pred, y)
balanced_losses[f'adv_{adv_name}'] = adv_loss
balanced_losses[f'feat_{adv_name}'] = feat_loss
for loss_name, criterion in self.aux_losses.items():
balanced_losses[loss_name] = criterion(y_pred, y)
# 4. Gradient balancing and backward
metrics['g_loss'] = self.balancer.backward(balanced_losses, y_pred)
# 5. Optimizer step
flashy.distrib.sync_model(self.model)
self.optimizer.step()
self.optimizer.zero_grad()
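Note that the discriminator branch in step 2 fires on a random draw rather than a fixed modulus: accepting whenever a uniform sample falls at or below 1/every yields, in expectation, one discriminator update per `every` generator steps. A quick stdlib simulation of that schedule (names here are illustrative):

```python
# Simulate the accept-with-probability-1/every discriminator schedule.
import random

def count_disc_updates(steps: int, every: int, seed: int = 0) -> int:
    """Count how often a uniform draw falls at or below 1/every."""
    rng = random.Random(seed)
    return sum(1 for _ in range(steps) if rng.random() <= 1 / every)

updates = count_disc_updates(steps=10_000, every=4)
# expected value is steps / every = 2_500, with binomial spread around it
```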
Example 2: Loading a Trained Compression Model
Using the solver's static method to load a trained model from a checkpoint.
from audiocraft.solvers.compression import CompressionSolver
# Load from a Dora experiment signature
model = CompressionSolver.model_from_checkpoint(
'//sig/my_encodec_experiment',
device='cuda',
)
# Or load from a pretrained model
model = CompressionSolver.model_from_checkpoint(
'//pretrained/facebook/encodec_32khz',
device='cuda',
)