
Heuristic:Facebookresearch Audiocraft Gradient Balancing For Multi Loss

From Leeroopedia
Domains Optimization, Deep_Learning, Audio_Generation
Last Updated 2026-02-13 23:00 GMT

Overview

Gradient-based loss balancing technique that normalizes partial gradients using EMA-tracked norms, making loss weights represent gradient fractions rather than arbitrary multipliers.

Description

AudioCraft's Balancer class implements gradient-level loss balancing from the EnCodec paper. Instead of simply weighting losses by scalar multipliers (which are scale-dependent and hard to tune), the Balancer:

  1. Computes partial gradients of each loss with respect to the model output
  2. Tracks the average gradient norm of each loss using exponential moving average (EMA)
  3. Rescales each gradient so that its contribution to the total update is proportional to its configured weight

This means a weight of 0.5 guarantees that loss contributes 50% of the total gradient magnitude, regardless of the absolute loss scale. The combined gradient is G = sum_i total_norm * g_i / avg(||g_i||) * w_i / sum_j(w_j), where g_i is the partial gradient of loss i with respect to the model output and avg(||g_i||) is its EMA-tracked norm.
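The three steps above can be sketched in NumPy. This is a pedagogical simplification, not AudioCraft's implementation: exact per-gradient norms stand in for the EMA-tracked averages.

```python
import numpy as np

def balance_gradients(grads, weights, total_norm=1.0, eps=1e-12):
    """Combine per-loss partial gradients so that each loss contributes a
    fraction of the total gradient norm proportional to its weight.

    grads:   dict name -> gradient array (d loss_i / d y)
    weights: dict name -> float weight w_i
    """
    w_sum = sum(weights.values())
    out = np.zeros_like(next(iter(grads.values())))
    for name, g in grads.items():
        norm = np.linalg.norm(g) + eps  # stand-in for the EMA norm
        out += total_norm * (g / norm) * (weights[name] / w_sum)
    return out

# Two gradients at wildly different scales (orthogonal, for clarity):
g1 = np.array([100.0, 0.0])   # large-magnitude loss gradient
g2 = np.array([0.0, 0.01])    # small-magnitude loss gradient
G = balance_gradients({"l1": g1, "adv": g2}, {"l1": 0.5, "adv": 0.5})
# Each loss contributes exactly half of total_norm despite the 10^4 scale gap.
```

With equal weights of 0.5, both normalized gradients are rescaled to contribute 0.5 of the unit total norm each, which is exactly the "weights as gradient fractions" guarantee.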

Usage

Use gradient balancing when training with multiple losses that operate at different scales (e.g., reconstruction loss + adversarial loss + perceptual loss in EnCodec compression training). Without balancing, one loss can dominate training simply because its gradients are larger in magnitude, not because it matters more.
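A toy NumPy illustration of the domination problem (the gradient values are hypothetical, chosen only to show the scale mismatch):

```python
import numpy as np

# Hypothetical partial gradients of two losses w.r.t. the model output y.
g_recon = np.array([50.0, 50.0])   # reconstruction loss: large gradients
g_adv = np.array([0.3, -0.3])      # adversarial loss: small gradients

# Naive equally-weighted sum: the update direction is almost entirely
# determined by the reconstruction term.
g_total = 0.5 * g_recon + 0.5 * g_adv
adv_share = np.linalg.norm(0.5 * g_adv) / np.linalg.norm(g_total)
print(f"adversarial share of update magnitude: {adv_share:.1%}")
```

Here the adversarial loss accounts for well under 1% of the update magnitude even though it was given an equal scalar weight.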

The Insight (Rule of Thumb)

  • Action: Enable Balancer with balance_grads=True and configure ema_decay=0.999 for stable gradient norm tracking.
  • Value: EMA decay 0.999 (slow adaptation, stable training). Per-batch normalization (per_batch_item=True) prevents batch outliers from distorting the running average.
  • Trade-off: Gradient balancing adds a backward pass per loss to compute partial gradients, roughly doubling the backward computation cost. The gradient rescaling also means the effective learning rate for each loss changes dynamically.
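To see why a decay of 0.999 gives stable tracking, here is a minimal EMA step (a sketch of the general recurrence, not the Balancer's exact bookkeeping):

```python
def ema_update(avg, value, decay=0.999):
    """One exponential-moving-average step: decay*avg + (1-decay)*value."""
    return decay * avg + (1.0 - decay) * value

# A single outlier batch barely moves the running norm estimate...
avg = 1.0
avg_after_outlier = ema_update(avg, 100.0)  # one batch with a 100x norm spike
# ...avg moves from 1.0 to ~1.1: the spike shifts the estimate by only ~10%.

# ...while a sustained shift in gradient statistics is tracked over time.
avg = avg_after_outlier
for _ in range(5000):
    avg = ema_update(avg, 2.0)  # norms settle at 2.0; avg converges there
```

The slow decay is what makes a one-off noisy batch nearly invisible to the normalizer while still letting the running norm follow genuine drift in gradient statistics over thousands of steps.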

Reasoning

In multi-loss training (e.g., EnCodec with reconstruction, adversarial, spectrogram, and perceptual losses), raw loss magnitudes can differ by orders of magnitude. A reconstruction L1 loss might produce gradients 100x larger than a GAN discriminator loss. Simple scalar weighting requires careful manual tuning per experiment.

The Balancer solves this by making weights relative: a weight ratio of 2:1 guarantees the first loss contributes twice the gradient magnitude of the second, regardless of their absolute scales. The EMA tracking adapts to changing gradient statistics during training without introducing instability.
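The relative-weight guarantee can be checked numerically: rescaling a raw gradient by any constant leaves the balanced contributions, and their 2:1 ratio, unchanged. A sketch with exact norms standing in for the EMAs:

```python
import numpy as np

def balanced_parts(grads, weights, total_norm=1.0):
    """Return each loss's balanced gradient contribution
    (exact norms stand in for the EMA-tracked norms)."""
    w_sum = sum(weights.values())
    return {name: total_norm * g / np.linalg.norm(g) * weights[name] / w_sum
            for name, g in grads.items()}

weights = {"a": 2.0, "b": 1.0}
grads = {"a": np.array([3.0, 4.0]),       # raw norm 5.0
         "b": np.array([0.001, 0.0])}     # raw norm 0.001 (5000x smaller)
parts = balanced_parts(grads, weights)
ratio = np.linalg.norm(parts["a"]) / np.linalg.norm(parts["b"])
# ratio is 2.0: loss "a" contributes twice the gradient magnitude of "b".

# Scale invariance: multiplying a raw gradient by 1000 changes nothing.
grads_rescaled = {"a": 1000.0 * grads["a"], "b": grads["b"]}
parts_rescaled = balanced_parts(grads_rescaled, weights)
```

Because each gradient is divided by its own norm before weighting, only the configured ratio survives, which is what removes the per-experiment manual tuning.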

Code Evidence

Balancer formula from audiocraft/losses/balancer.py:14-60:

# The loss balancer combines losses together to compute gradients for the backward.
# Given y = f(...), and a number of losses l1(y, ...), l2(y, ...), the balancer can
# efficiently normalize the partial gradients d l1 / d y, d l2 / dy before summing them
# to achieve a desired ratio between the losses.

# Balanced gradient G = sum_i total_norm * g_i / avg(||g_i||) * w_i / sum(w_i)

EMA-based gradient norm tracking from audiocraft/losses/balancer.py:86-136:

class Balancer:
    def __init__(self, weights: tp.Dict[str, float], balance_grads: bool = True,
                 total_norm: float = 1., ema_decay: float = 0.999,
                 per_batch_item: bool = True):
        # ...
        self.ema_decay = ema_decay
        self.per_batch_item = per_batch_item
