Heuristic: SpeechBrain GAN Dual Optimizer Pattern
| Knowledge Sources | |
|---|---|
| Domains | Speech_Enhancement, GAN_Training |
| Last Updated | 2026-02-09 20:00 GMT |
Overview
A dual-optimizer pattern for GAN-based speech enhancement: the discriminator is trained first in each batch, with separate optimizers and schedulers for the generator and discriminator.
Description
SpeechBrain's GAN-based enhancement recipes (MetricGAN+, HiFi-GAN, SEGAN) implement a consistent dual-optimizer pattern where the discriminator and generator have completely separate optimizers. The key insight is that the discriminator must be trained first in each iteration to provide useful gradient signal to the generator. MetricGAN+ uses a triple-pass discriminator training (current data, historical data, current again) with historical replay to prevent catastrophic forgetting. HiFi-GAN re-evaluates the discriminator after its update to give the generator the most current feedback. SEGAN uses `retain_graph=True` to share the computation graph between discriminator and generator backward passes.
Usage
Apply when implementing GAN-based enhancement or vocoder training in SpeechBrain. Override `init_optimizers()` in your Brain subclass to create separate generator and discriminator optimizers. Always train discriminator before generator in each batch.
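The pattern above can be sketched as a minimal Brain-style class. This is an illustrative skeleton, not SpeechBrain's actual API: the class name, loss formulas, and learning rate are assumptions; only the `init_optimizers()` / `fit_batch()` structure and the D-before-G ordering come from the recipes described here.

```python
import torch

class GanBrain:
    """Sketch of the dual-optimizer pattern (illustrative, not SpeechBrain's API)."""

    def __init__(self, generator, discriminator, lr=1e-4):
        self.generator = generator
        self.discriminator = discriminator
        self.lr = lr
        self.init_optimizers()

    def init_optimizers(self):
        # Completely separate optimizers for discriminator and generator.
        self.optimizer_d = torch.optim.Adam(self.discriminator.parameters(), lr=self.lr)
        self.optimizer_g = torch.optim.Adam(self.generator.parameters(), lr=self.lr)

    def fit_batch(self, noisy, clean):
        # 1) Train the discriminator first, on detached generator output.
        with torch.no_grad():
            enhanced = self.generator(noisy)
        d_fake = self.discriminator(enhanced)
        d_real = self.discriminator(clean)
        loss_d = torch.mean((d_real - 1) ** 2) + torch.mean(d_fake ** 2)
        self.optimizer_d.zero_grad()
        loss_d.backward()
        self.optimizer_d.step()

        # 2) Then train the generator against the just-updated discriminator.
        enhanced = self.generator(noisy)
        loss_g = torch.mean((self.discriminator(enhanced) - 1) ** 2)
        self.optimizer_g.zero_grad()
        loss_g.backward()
        self.optimizer_g.step()
        return loss_d.item(), loss_g.item()
```

Training D on detached generator output keeps the D update from back-propagating into G; the G update then sees a discriminator that has already adapted to the current batch.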
The Insight (Rule of Thumb)
- Action: Override `init_optimizers()` to create `self.optimizer_g` and `self.optimizer_d`. Train D first, then G, in each `fit_batch()` call.
- Value: MetricGAN: 3 D passes per G pass. HiFi-GAN: 1 D pass + re-evaluate before G pass. SEGAN: `retain_graph=True` for shared computation.
- Trade-off: More D passes make D stronger but slow training. Too few D passes mean G gets poor gradient signal. MetricGAN's historical replay adds memory overhead.
- Critical detail: MetricGAN clamps `Learnable_sigmoid` parameters to a maximum of 3.5 to keep gradients finite, and replaces NaN values with 3.5 via `param[param != param] = 3.5` (NaN is the only value not equal to itself).
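The triple-pass discriminator schedule with historical replay can be sketched as follows. The buffer capacity, sampling policy, and function names are illustrative assumptions; only the pass order (current data, historical data, current again) comes from the MetricGAN+ recipe described above.

```python
import random
import torch

class ReplayBuffer:
    """Stores past batches so D can revisit older enhancement quality levels."""

    def __init__(self, capacity=100):
        self.capacity = capacity
        self.storage = []

    def push(self, batch):
        if len(self.storage) >= self.capacity:
            self.storage.pop(0)  # drop the oldest entry
        self.storage.append(batch)

    def sample(self):
        return random.choice(self.storage)

def train_discriminator(disc, optimizer_d, loss_fn, current_batch, buffer):
    buffer.push(current_batch)
    # Pass 1: current data; pass 2: historical replay; pass 3: current again.
    for batch in (current_batch, buffer.sample(), current_batch):
        optimizer_d.zero_grad()
        loss = loss_fn(disc, batch)
        loss.backward()
        optimizer_d.step()
```

Each pass is a full zero-grad/backward/step cycle, so D takes three updates for every single generator update.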
Reasoning
In adversarial training, the discriminator must be sufficiently strong to provide meaningful gradient information to the generator. If D is too weak, G receives near-zero gradients and stops improving. The historical replay in MetricGAN prevents the discriminator from forgetting what older enhancement quality levels look like, which stabilizes training. The Learnable_sigmoid clamping at 3.5 prevents a specific failure mode where the sigmoid slope grows unbounded, turning it into a step function with vanishing gradients everywhere except at zero.
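The vanishing-gradient failure mode can be checked numerically. The sketch below assumes only the general shape of a learnable sigmoid (a sigmoid with a multiplicative slope parameter); SpeechBrain's exact `Learnable_sigmoid` formula may differ.

```python
import torch

def slope_sigmoid(x, slope):
    # A sigmoid with a multiplicative slope parameter: the general shape a
    # learnable sigmoid takes (assumption; SpeechBrain's exact form may differ).
    return torch.sigmoid(slope * x)

x = torch.tensor(0.5, requires_grad=True)

# Slope at the clamp value 3.5: the gradient is still usable.
slope_sigmoid(x, slope=3.5).backward()
grad_clamped = x.grad.item()

x.grad = None
# Runaway slope: the sigmoid approaches a step function and the gradient
# vanishes everywhere except near x = 0.
slope_sigmoid(x, slope=50.0).backward()
grad_runaway = x.grad.item()
```

Since d/dx sigmoid(s*x) = s * sig * (1 - sig), the factor sig * (1 - sig) collapses toward zero faster than the slope s grows once s*x is large, so the clamp at 3.5 keeps the activation in a regime where gradients still flow.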
Code from `recipes/Voicebank/enhance/MetricGAN/train.py:327-332`:
```python
for name, param in self.modules.generator.named_parameters():
    if "Learnable_sigmoid" in name:
        param.data = torch.clamp(param, max=3.5)
        param.data[param != param] = 3.5  # set 'nan' to 3.5
```
Code from `recipes/LJSpeech/TTS/vocoder/hifigan/train.py:82-111`:
```python
# First train the discriminator
# (loss computations are elided in this excerpt)
self.optimizer_d.zero_grad()
loss_d.backward()
self.optimizer_d.step()

# Re-evaluate discriminator with updated weights before training generator
scores_fake, feats_fake = self.modules.discriminator(y_g_hat)
scores_real, feats_real = self.modules.discriminator(y)

# Then train the generator
self.optimizer_g.zero_grad()
loss_g.backward()
self.optimizer_g.step()
```