
Heuristic:SpeechBrain GAN Dual Optimizer Pattern

From Leeroopedia




Knowledge Sources
Domains Speech_Enhancement, GAN_Training
Last Updated 2026-02-09 20:00 GMT

Overview

A dual-optimizer pattern for GAN-based speech enhancement: the discriminator is trained first in each iteration, and the generator and discriminator keep entirely separate optimizers and schedulers.

Description

SpeechBrain's GAN-based enhancement recipes (MetricGAN+, HiFi-GAN, SEGAN) implement a consistent dual-optimizer pattern where the discriminator and generator have completely separate optimizers. The key insight is that the discriminator must be trained first in each iteration to provide useful gradient signal to the generator. MetricGAN+ uses a triple-pass discriminator training (current data, historical data, current again) with historical replay to prevent catastrophic forgetting. HiFi-GAN re-evaluates the discriminator after its update to give the generator the most current feedback. SEGAN uses `retain_graph=True` to share the computation graph between discriminator and generator backward passes.
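The shared batch structure behind all three recipes can be sketched as follows. This is a minimal illustration, not the recipe code: the stand-in linear modules, loss formulas, and learning rates are assumptions made for brevity; only the D-first / G-second ordering and the `detach()` on the discriminator's fake input reflect the pattern described above.

```python
import torch

# Stand-in networks (assumptions; real recipes use enhancement/vocoder models)
gen = torch.nn.Linear(4, 4)
disc = torch.nn.Linear(4, 1)
opt_g = torch.optim.Adam(gen.parameters(), lr=1e-4)
opt_d = torch.optim.Adam(disc.parameters(), lr=1e-4)

clean = torch.randn(8, 4)
noisy = torch.randn(8, 4)

# 1) Discriminator step first: judge real vs. enhanced (detached).
enhanced = gen(noisy)
d_real = disc(clean)
d_fake = disc(enhanced.detach())  # detach so no gradients flow into G
loss_d = torch.mean((d_real - 1) ** 2) + torch.mean(d_fake ** 2)
opt_d.zero_grad()
loss_d.backward()
opt_d.step()

# 2) Generator step second, scored by the just-updated discriminator.
d_fake = disc(gen(noisy))
loss_g = torch.mean((d_fake - 1) ** 2)
opt_g.zero_grad()
loss_g.backward()
opt_g.step()
```

Training D first means the generator's gradient in step 2 is computed against the discriminator's most recent weights rather than stale ones.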

Usage

Apply when implementing GAN-based enhancement or vocoder training in SpeechBrain. Override `init_optimizers()` in your Brain subclass to create separate generator and discriminator optimizers. Always train discriminator before generator in each batch.
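The `init_optimizers()` override can be sketched like this. The class below is an illustrative stand-in, not a real `speechbrain.Brain` subclass: the class name, constructor, and learning rates are assumptions, and a real recipe would build the sub-networks from its hparams file. Only the key move is faithful to the pattern: one optimizer per sub-network, stored under separate attributes.

```python
import torch

class GANBrainSketch:
    """Illustrative stand-in for a speechbrain.Brain subclass."""

    def __init__(self, generator, discriminator):
        self.generator = generator
        self.discriminator = discriminator

    def init_optimizers(self):
        # Separate optimizers, one per sub-network, as in the GAN recipes.
        self.optimizer_g = torch.optim.Adam(self.generator.parameters(), lr=2e-4)
        self.optimizer_d = torch.optim.Adam(self.discriminator.parameters(), lr=2e-4)

brain = GANBrainSketch(torch.nn.Linear(4, 4), torch.nn.Linear(4, 1))
brain.init_optimizers()
```

Because the two optimizers hold disjoint parameter sets, a discriminator step can never silently update generator weights, and vice versa.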

The Insight (Rule of Thumb)

  • Action: Override `init_optimizers()` to create `self.optimizer_g` and `self.optimizer_d`. Train D first, then G, in each `fit_batch()` call.
  • Value: MetricGAN: 3 D passes per G pass. HiFi-GAN: 1 D pass + re-evaluate before G pass. SEGAN: `retain_graph=True` for shared computation.
  • Trade-off: More D passes make D stronger but slow training. Too few D passes mean G gets poor gradient signal. MetricGAN's historical replay adds memory overhead.
  • Critical detail: MetricGAN clamps Learnable_sigmoid parameters to a maximum of 3.5 to keep the sigmoid slope (and its gradients) bounded, and replaces any NaN values with 3.5 via the self-comparison trick `param[param != param] = 3.5` (NaN is the only value not equal to itself).
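MetricGAN's historical replay can be sketched as a bounded pool of past enhanced batches that the discriminator revisits between two passes on current data. The pool size, eviction policy, and sampling policy below are illustrative assumptions, not the recipe's exact values; only the current/historical/current triple-pass ordering follows the pattern described above.

```python
import random

class ReplayPool:
    """Bounded FIFO pool of past enhanced batches (sizes are assumptions)."""

    def __init__(self, max_size=100):
        self.max_size = max_size
        self.items = []

    def add(self, item):
        self.items.append(item)
        if len(self.items) > self.max_size:
            self.items.pop(0)  # evict the oldest batch

    def sample(self):
        return random.choice(self.items)

pool = ReplayPool(max_size=3)
for step in range(5):
    pool.add(f"enhanced_batch_{step}")

# Triple-pass discriminator order: current data, a historical batch, current again.
passes = ["current", pool.sample(), "current"]
```

Replaying older enhanced samples keeps the discriminator calibrated on the full range of enhancement quality it has seen, not just the generator's latest outputs.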

Reasoning

In adversarial training, the discriminator must be sufficiently strong to provide meaningful gradient information to the generator. If D is too weak, G receives near-zero gradients and stops improving. The historical replay in MetricGAN prevents the discriminator from forgetting what older enhancement quality levels look like, which stabilizes training. The Learnable_sigmoid clamping at 3.5 prevents a specific failure mode where the sigmoid slope grows unbounded, turning it into a step function with vanishing gradients everywhere except at zero.

Code from `recipes/Voicebank/enhance/MetricGAN/train.py:327-332`:

for name, param in self.modules.generator.named_parameters():
    if "Learnable_sigmoid" in name:
        param.data = torch.clamp(param, max=3.5)
        param.data[param != param] = 3.5  # set 'nan' to 3.5

Code from `recipes/LJSpeech/TTS/vocoder/hifigan/train.py:82-111`:

# First train the discriminator
self.optimizer_d.zero_grad()
loss_d.backward()
self.optimizer_d.step()

# Re-evaluate discriminator with updated weights before training generator
scores_fake, feats_fake = self.modules.discriminator(y_g_hat)
scores_real, feats_real = self.modules.discriminator(y)

# Then train the generator
self.optimizer_g.zero_grad()
loss_g.backward()
self.optimizer_g.step()
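SEGAN's variant, a single forward pass whose computation graph is kept alive across both backward passes, can be sketched as below. The stand-in modules and loss formulas are assumptions; the `retain_graph=True` on the discriminator's backward is the point being illustrated: without it, PyTorch frees the generator's portion of the graph after the first backward, and the generator's backward would fail.

```python
import torch

gen = torch.nn.Linear(4, 4)   # stand-in generator (assumption)
disc = torch.nn.Linear(4, 1)  # stand-in discriminator (assumption)
opt_g = torch.optim.Adam(gen.parameters(), lr=1e-4)
opt_d = torch.optim.Adam(disc.parameters(), lr=1e-4)

noisy = torch.randn(8, 4)
enhanced = gen(noisy)          # one forward pass; its graph is shared

# Discriminator backward: retain the graph so G can reuse it.
loss_d = torch.mean(disc(enhanced) ** 2)
opt_d.zero_grad()
loss_d.backward(retain_graph=True)
opt_d.step()

# Generator backward: traverses the retained graph through `enhanced`.
loss_g = torch.mean((disc(enhanced) - 1) ** 2)
opt_g.zero_grad()
loss_g.backward()
opt_g.step()
```

The trade-off versus recomputing the forward pass (as HiFi-GAN does) is memory for speed: the retained graph holds intermediate activations until the second backward completes.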
