
Principle:Speechbrain Speechbrain GAN Based Enhancement Training

From Leeroopedia


Property                Value
Principle Name          GAN_Based_Enhancement_Training
Workflow                Speech_Enhancement_Training
Domains                 GAN_Training, Speech_Enhancement
Source Repository       speechbrain/speechbrain
Knowledge Sources       Fu et al. 2021, "MetricGAN+: An Improved Version of MetricGAN for Speech Enhancement"
Related Implementation  Implementation:Speechbrain_Speechbrain_MetricGanBrain_Fit_Batch

Overview

GAN-Based Enhancement Training applies the Generative Adversarial Network (GAN) framework to speech enhancement, using a generator to enhance noisy speech and a discriminator to evaluate enhancement quality. The key innovation of MetricGAN+ over standard GANs is that the discriminator is trained to predict actual perceptual quality scores (such as PESQ) rather than simple real/fake binary labels. This directly optimizes for human-perceived speech quality.

Theoretical Background

Standard GAN vs. MetricGAN

In a standard GAN for speech enhancement:

  • The generator G transforms noisy speech into enhanced speech
  • The discriminator D classifies speech as "real" (clean) or "fake" (enhanced)
  • The training objective is a minimax game: G tries to fool D, and D tries to distinguish real from fake

The fundamental limitation is that the binary real/fake signal provides coarse gradient information -- the discriminator cannot communicate how much better or worse the enhanced speech is compared to the clean reference.

MetricGAN addresses this by redefining the discriminator's role:

Standard GAN discriminator: D(x) -> {0, 1}       (real or fake)
MetricGAN discriminator:    D(x, ref) -> [0, 1]   (predicted quality score)

The discriminator takes both the evaluated speech and the clean reference as input, and outputs a continuous quality score. This score is trained to match the actual PESQ score of the evaluated speech, providing a differentiable approximation of the non-differentiable PESQ metric.
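As a minimal sketch of this idea, the following PyTorch module scores a (speech, reference) spectrogram pair with a value in [0, 1]. The class name, layer sizes, and kernel size are illustrative assumptions, not the actual SpeechBrain architecture; the point is only the two-input, Sigmoid-bounded output shape of the MetricGAN discriminator:

```python
import torch
import torch.nn as nn

class MetricDiscriminator(nn.Module):
    """Illustrative metric-predicting discriminator: takes the evaluated
    spectrogram plus the clean reference, returns a quality score in [0, 1]."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(2, 16, kernel_size=5),  # 2 channels: evaluated + reference
            nn.LeakyReLU(0.3),
            nn.AdaptiveAvgPool2d(1),
        )
        self.head = nn.Sequential(nn.Flatten(), nn.Linear(16, 1), nn.Sigmoid())

    def forward(self, spec, ref_spec):
        # Stack evaluated and reference magnitude spectrograms as channels.
        x = torch.stack([spec, ref_spec], dim=1)  # (B, 2, T, F)
        return self.head(self.conv(x))  # (B, 1), predicted quality in [0, 1]
```

The Sigmoid head is what makes the [0, 1] PESQ normalization described later a valid regression target.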

Training Procedure

MetricGAN+ training alternates between two phases in each epoch:

Phase 1: Discriminator Training

The discriminator is trained to accurately predict quality scores for three types of inputs:

  1. Clean speech (target score: 1.0, since clean-vs-clean PESQ is perfect)
  2. Enhanced speech (target score: actual normalized PESQ of the enhanced signal)
  3. Noisy speech (target score: actual normalized PESQ of the noisy signal)

The discriminator loss is MSE between predicted and actual scores:

L_D = MSE(D(clean, clean), 1.0)
    + MSE(D(enhanced, clean), PESQ(enhanced, clean))
    + MSE(D(noisy, clean), PESQ(noisy, clean))
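The three-term objective above can be sketched directly in PyTorch. The function name and signature are assumptions for illustration; the normalized PESQ scores are assumed to be precomputed offline and passed in as tensors:

```python
import torch
import torch.nn.functional as F

def discriminator_loss(D, clean, enhanced, noisy, pesq_enh, pesq_noisy):
    """Illustrative three-term discriminator objective. pesq_enh and
    pesq_noisy are normalized PESQ scores in [0, 1], shape (B, 1)."""
    target_clean = torch.ones_like(pesq_enh)  # clean-vs-clean PESQ is perfect
    return (
        F.mse_loss(D(clean, clean), target_clean)
        + F.mse_loss(D(enhanced, clean), pesq_enh)
        + F.mse_loss(D(noisy, clean), pesq_noisy)
    )
```

Because PESQ itself is non-differentiable, it only ever appears here as a fixed regression target; gradients flow through D alone.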

Phase 2: Generator Training

The generator is trained to maximize the discriminator's predicted quality score:

L_G = MSE(D(G(noisy), clean), 1.0) + lambda * MSE(G(noisy)_spec, clean_spec)

The generator's goal is to produce enhanced speech that the (now-trained) discriminator rates as score 1.0 (perfect quality). An optional MSE reconstruction term provides additional gradient stability.
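A corresponding sketch of the generator objective, again with illustrative names (the real SpeechBrain recipe differs in detail); `lam` weights the optional spectral reconstruction term:

```python
import torch
import torch.nn.functional as F

def generator_loss(D, G, noisy, clean, lam=0.0):
    """Illustrative generator objective: push the (frozen) discriminator's
    predicted score toward 1.0, optionally regularized by spectral MSE."""
    enhanced = G(noisy)
    target = torch.ones(enhanced.shape[0], 1)  # "perfect quality" target
    adv = F.mse_loss(D(enhanced, clean), target)
    recon = F.mse_loss(enhanced, clean) if lam > 0 else 0.0
    return adv + lam * recon
```

Note that during this phase only the generator's parameters receive updates; the discriminator acts as a fixed, differentiable stand-in for PESQ.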

Historical Training

A distinctive feature of MetricGAN+ is historical training. Enhanced utterances from previous epochs are saved along with their actual quality scores. During discriminator training, these historical samples are replayed to:

  • Prevent catastrophic forgetting of previously seen enhancement quality levels
  • Provide a richer training distribution spanning multiple quality levels
  • Enable the discriminator to learn a smooth mapping from spectral features to quality scores

The historical set grows over epochs, with a configurable portion (default: 20%) sampled for replay in each epoch.
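The replay sampling itself is simple; a minimal sketch (helper name and the pair layout of the history entries are assumptions for illustration):

```python
import random

def sample_historical(history, portion=0.2):
    """Illustrative historical-replay sampler. `history` is a list of
    (enhanced_utterance_id, normalized_pesq) pairs accumulated over
    previous epochs; a fixed portion (default 20%) is drawn per epoch
    for discriminator replay."""
    k = max(1, int(len(history) * portion)) if history else 0
    return random.sample(history, k)
```

Because the history keeps growing, the absolute number of replayed samples also grows each epoch while the fraction stays fixed.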

Sub-Stage Training Architecture

The training loop is organized into three sub-stages within each epoch:

Sub-Stage    Description                                                                  Optimizer Updated
CURRENT      Discriminator trained on current epoch's clean, enhanced, and noisy speech   Discriminator
HISTORICAL   Discriminator retrained on enhanced speech from previous epochs              Discriminator
GENERATOR    Generator trained to maximize the discriminator's predicted score            Generator

This sub-stage architecture requires a custom fit_batch() that:

  • Maintains separate optimizers for generator and discriminator
  • Routes gradient updates to the correct optimizer based on the current sub-stage
  • Computes different loss objectives depending on the sub-stage
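The routing logic described above can be sketched as follows. The enum, function name, and signature are illustrative assumptions rather than the actual SpeechBrain fit_batch() interface; the essential pattern is that each sub-stage selects its own loss and steps only the matching optimizer:

```python
import torch
from enum import Enum, auto

class SubStage(Enum):
    CURRENT = auto()
    HISTORICAL = auto()
    GENERATOR = auto()

def fit_batch_sketch(sub_stage, d_loss_fn, g_loss_fn, d_opt, g_opt, batch):
    """Illustrative sub-stage routing: discriminator sub-stages update
    d_opt, the generator sub-stage updates g_opt."""
    if sub_stage in (SubStage.CURRENT, SubStage.HISTORICAL):
        loss, opt = d_loss_fn(batch), d_opt
    else:
        loss, opt = g_loss_fn(batch), g_opt
    opt.zero_grad()
    loss.backward()
    opt.step()
    return loss.detach()
```

Keeping the two parameter sets in separate optimizers guarantees that a discriminator step can never leak updates into the generator, and vice versa.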

PESQ as Training Target

PESQ (Perceptual Evaluation of Speech Quality) scores are normalized to [0, 1] for use as training targets:

normalized_pesq = (raw_pesq + 0.5) / 5.0

where raw PESQ ranges from -0.5 to 4.5. This normalization ensures compatibility with the discriminator's Sigmoid output range.
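The normalization is a one-line affine map (function name is illustrative):

```python
def normalize_pesq(raw_pesq: float) -> float:
    """Map a raw PESQ score from [-0.5, 4.5] onto [0, 1], matching the
    discriminator's Sigmoid output range."""
    return (raw_pesq + 0.5) / 5.0
```

The endpoints check out: a raw score of -0.5 maps to 0.0 and a raw score of 4.5 maps to 1.0.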

Dual Optimizer Management

Unlike standard Brain training which uses a single optimizer, MetricGAN+ requires two independent optimizers:

  • Generator optimizer (g_opt_class): Adam with lr=0.0005, updates only generator parameters
  • Discriminator optimizer (d_opt_class): Adam with lr=0.0005, updates only discriminator parameters

The custom init_optimizers() method creates both optimizers and registers them with the checkpointer for resumable training.
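A sketch of that setup, with an illustrative function name and a simplified checkpointer interaction (SpeechBrain's Checkpointer does expose add_recoverable(), but the surrounding wiring here is an assumption):

```python
import torch

def init_optimizers_sketch(generator, discriminator, checkpointer=None, lr=5e-4):
    """Illustrative dual-optimizer setup: each Adam instance sees only its
    own module's parameters, and both are registered for checkpointing so
    resumed training restores optimizer state."""
    g_opt = torch.optim.Adam(generator.parameters(), lr=lr)
    d_opt = torch.optim.Adam(discriminator.parameters(), lr=lr)
    if checkpointer is not None:
        checkpointer.add_recoverable("g_opt", g_opt)
        checkpointer.add_recoverable("d_opt", d_opt)
    return g_opt, d_opt
```

Registering both optimizers matters for Adam in particular, since its moment estimates would otherwise reset on resume.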

Key Design Decisions

  • Metric-driven optimization: By using actual PESQ scores as discriminator targets, the model directly optimizes for human-perceived quality rather than proxy losses
  • Batch size of 1 for generator training: The default configuration uses batch_size=1 during generator training because PESQ computation is expensive and done per-utterance
  • Learnable Sigmoid clamping: The generator's learnable Sigmoid parameters are clamped to a maximum of 3.5 to prevent gradient explosion
  • Three-pass discriminator training: Each epoch trains the discriminator with three data passes (current, historical, current again) before a single generator pass, ensuring the discriminator is well-calibrated
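The Sigmoid clamping mentioned above amounts to an in-place cap applied after each optimizer step; a minimal sketch with an assumed helper name:

```python
import torch

def clamp_learnable_sigmoid(params, max_val=3.5):
    """Illustrative post-step clamp on learnable Sigmoid slope parameters.
    Capping the slope at 3.5 keeps the activation from saturating too
    sharply, which would otherwise invite gradient explosion."""
    with torch.no_grad():
        params.clamp_(max=max_val)
    return params
```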

Comparison with Standard Enhancement Training

Aspect               GAN-Based (MetricGAN+)           Conventional (MSE)
Objective            Adversarial + perceptual metric  Spectral MSE
Optimizers           Two (generator + discriminator)  One
Training complexity  High (multi-stage per epoch)     Low (single forward-backward pass)
PESQ score           Directly optimized               Only monitored
Typical epochs       750                              50
Typical batch size   1                                8
