Heuristic: facebookresearch/audiocraft Gradient Balancing for Multi-Loss Training
| Knowledge Sources | |
|---|---|
| Domains | Optimization, Deep_Learning, Audio_Generation |
| Last Updated | 2026-02-13 23:00 GMT |
Overview
Gradient-based loss balancing technique that normalizes partial gradients using EMA-tracked norms, making loss weights represent gradient fractions rather than arbitrary multipliers.
Description
AudioCraft's Balancer class implements gradient-level loss balancing from the EnCodec paper. Instead of simply weighting losses by scalar multipliers (which are scale-dependent and hard to tune), the Balancer:
- Computes partial gradients of each loss with respect to the model output
- Tracks the average gradient norm of each loss using exponential moving average (EMA)
- Rescales each gradient so that its contribution to the total update is proportional to its configured weight
This means a weight of 0.5 guarantees that loss contributes 50% of the total gradient magnitude, regardless of the absolute loss scale. The formula is G = sum_i(total_norm * g_i / avg(||g_i||) * w_i / sum(w_i)), where avg(||g_i||) is the EMA-tracked norm of the i-th partial gradient and w_i its configured weight.
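The combination step can be sketched in a few lines of pure Python. This is a minimal illustration of the formula above, not the actual Balancer (which operates on torch tensors and tracks the EMA norms itself); the function name and dict-based signature here are hypothetical.

```python
def balance_gradients(partial_grads, weights, ema_norms, total_norm=1.0):
    """Combine per-loss partial gradients d l_i / d y into one balanced gradient.

    partial_grads: dict name -> list[float], gradient of each loss w.r.t. model output
    weights:       dict name -> float, desired gradient fractions
    ema_norms:     dict name -> float, EMA estimates of each ||g_i||
    Implements G = sum_i total_norm * g_i / avg(||g_i||) * w_i / sum(w_i).
    """
    w_sum = sum(weights.values())
    dim = len(next(iter(partial_grads.values())))
    balanced = [0.0] * dim
    for name, g in partial_grads.items():
        # Normalize g to unit norm (via its EMA norm), then scale by its weight fraction.
        scale = total_norm * weights[name] / (w_sum * ema_norms[name])
        for k in range(dim):
            balanced[k] += scale * g[k]
    return balanced
```

With equal weights, two gradients whose raw norms differ by a factor of 50 still end up contributing equally to the combined gradient.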
Usage
Use gradient balancing when training with multiple losses that operate at different scales (e.g., reconstruction loss + adversarial loss + perceptual loss in EnCodec compression training). Without balancing, one loss can dominate training simply because its gradients are larger in magnitude, not because it matters more.
The Insight (Rule of Thumb)
- Action: Enable `Balancer` with `balance_grads=True` and configure `ema_decay=0.999` for stable gradient-norm tracking.
- Value: EMA decay 0.999 (slow adaptation, stable training). Per-batch normalization (`per_batch_item=True`) prevents batch outliers from distorting the running average.
- Trade-off: Gradient balancing adds a backward pass per loss to compute partial gradients, roughly doubling the backward computation cost. The gradient rescaling also means the effective learning rate for each loss changes dynamically.
Reasoning
In multi-loss training (e.g., EnCodec with reconstruction, adversarial, spectrogram, and perceptual losses), raw loss magnitudes can differ by orders of magnitude. A reconstruction L1 loss might produce gradients 100x larger than a GAN discriminator loss. Simple scalar weighting requires careful manual tuning per experiment.
The Balancer solves this by making weights relative: a weight ratio of 2:1 guarantees the first loss contributes twice the gradient magnitude of the second, regardless of their absolute scales. The EMA tracking adapts to changing gradient statistics during training without introducing instability.
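A quick numeric check makes the relative-weight claim concrete. Below, two hypothetical gradients differ in raw norm by four orders of magnitude, yet after rescaling (here using the true norms in place of EMA estimates, and total_norm = 1) their contribution norms land exactly in the configured 2:1 ratio.

```python
import math

def norm(v):
    return math.sqrt(sum(x * x for x in v))

# Hypothetical raw partial gradients at very different scales.
g_recon = [100.0, 0.0]   # reconstruction gradient, norm 100
g_adv = [0.0, 0.01]      # adversarial gradient, norm 0.01

weights = {"recon": 2.0, "adv": 1.0}  # desired 2:1 gradient contribution
grads = {"recon": g_recon, "adv": g_adv}

w_sum = sum(weights.values())
contributions = {}
for name, g in grads.items():
    # Rescale each gradient to norm w_i / sum(w), erasing its raw scale.
    scale = weights[name] / (w_sum * norm(g))
    contributions[name] = [scale * x for x in g]
```

After rescaling, `norm(contributions["recon"])` is exactly twice `norm(contributions["adv"])`, even though the raw norms differed by a factor of 10,000.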
Code Evidence
Balancer formula from audiocraft/losses/balancer.py:14-60:
```python
# The loss balancer combines losses together to compute gradients for the backward.
# Given y = f(...), and a number of losses l1(y, ...), l2(y, ...), the balancer can
# efficiently normalize the partial gradients d l1 / d y, d l2 / d y before summing them
# to achieve a desired ratio between the losses.
# Balanced gradient G = sum_i total_norm * g_i / avg(||g_i||) * w_i / sum(w_i)
```
EMA-based gradient norm tracking from audiocraft/losses/balancer.py:86-136:
```python
class Balancer:
    def __init__(self, weights: tp.Dict[str, float], rescale_grads: bool = True,
                 total_norm: float = 1., ema_decay: float = 0.999,
                 per_batch_item: bool = True):
        # ...
        self.ema_decay = ema_decay
        self.per_batch_item = per_batch_item
```