Principle:Speechbrain Speechbrain GAN Based Enhancement Training
| Property | Value |
|---|---|
| Principle Name | GAN_Based_Enhancement_Training |
| Workflow | Speech_Enhancement_Training |
| Domains | GAN_Training, Speech_Enhancement |
| Source Repository | speechbrain/speechbrain |
| Knowledge Sources | Fu et al. 2021 "MetricGAN+: An Improved Version of MetricGAN for Speech Enhancement" |
| Related Implementation | Implementation:Speechbrain_Speechbrain_MetricGanBrain_Fit_Batch |
Overview
GAN-Based Enhancement Training applies the Generative Adversarial Network (GAN) framework to speech enhancement, using a generator to enhance noisy speech and a discriminator to evaluate enhancement quality. The key innovation of MetricGAN+ over standard GANs is that the discriminator is trained to predict actual perceptual quality scores (such as PESQ) rather than simple real/fake binary labels. This directly optimizes for human-perceived speech quality.
Theoretical Background
Standard GAN vs. MetricGAN
In a standard GAN for speech enhancement:
- The generator G transforms noisy speech into enhanced speech
- The discriminator D classifies speech as "real" (clean) or "fake" (enhanced)
- The training objective is a minimax game: G tries to fool D, and D tries to distinguish real from fake
The fundamental limitation is that the binary real/fake signal provides coarse gradient information -- the discriminator cannot communicate how much better or worse the enhanced speech is compared to the clean reference.
MetricGAN addresses this by redefining the discriminator's role:
Standard GAN discriminator: D(x) -> {0, 1} (real or fake)
MetricGAN discriminator: D(x, ref) -> [0, 1] (predicted quality score)
The discriminator takes both the evaluated speech and the clean reference as input, and outputs a continuous quality score. This score is trained to match the actual PESQ score of the evaluated speech, providing a differentiable approximation of the non-differentiable PESQ metric.
Training Procedure
MetricGAN+ training alternates between two phases in each epoch:
Phase 1: Discriminator Training
The discriminator is trained to accurately predict quality scores for three types of inputs:
- Clean speech (target score: 1.0, since clean-vs-clean PESQ is perfect)
- Enhanced speech (target score: actual normalized PESQ of the enhanced signal)
- Noisy speech (target score: actual normalized PESQ of the noisy signal)
The discriminator loss is MSE between predicted and actual scores:
L_D = MSE(D(clean, clean), 1.0)
+ MSE(D(enhanced, clean), PESQ(enhanced, clean))
+ MSE(D(noisy, clean), PESQ(noisy, clean))
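The discriminator loss above can be sketched as a small PyTorch function. This is a minimal illustration, not SpeechBrain's actual code: `D` is assumed to be any callable `(x, ref) -> score in [0, 1]`, and the PESQ scores are assumed to be pre-computed and normalized.

```python
import torch
import torch.nn.functional as F

def discriminator_loss(D, clean, enhanced, noisy, pesq_enh, pesq_noisy):
    """MSE between the discriminator's predicted and true quality scores.

    D          -- callable (x, ref) -> predicted quality score in [0, 1]
    pesq_enh   -- normalized PESQ of the enhanced signal vs. clean
    pesq_noisy -- normalized PESQ of the noisy signal vs. clean
    """
    one = torch.ones_like(pesq_enh)
    return (
        F.mse_loss(D(clean, clean), one)            # clean vs. clean: target 1.0
        + F.mse_loss(D(enhanced, clean), pesq_enh)  # match the actual PESQ score
        + F.mse_loss(D(noisy, clean), pesq_noisy)
    )
```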
Phase 2: Generator Training
The generator is trained to maximize the discriminator's predicted quality score:
L_G = MSE(D(G(noisy), clean), 1.0) + lambda * MSE(G(noisy)_spec, clean_spec)
The generator's goal is to produce enhanced speech that the (now-trained) discriminator rates as score 1.0 (perfect quality). An optional MSE reconstruction term provides additional gradient stability.
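A corresponding sketch of the generator objective, under the same assumptions as above (hypothetical callables `G` and `D`; `lam` is the reconstruction weight written as lambda in the formula):

```python
import torch
import torch.nn.functional as F

def generator_loss(D, G, noisy, clean, lam=0.0):
    """Adversarial quality term plus optional MSE reconstruction term."""
    enhanced = G(noisy)
    score = D(enhanced, clean)
    adv = F.mse_loss(score, torch.ones_like(score))  # chase a "perfect" score of 1.0
    rec = F.mse_loss(enhanced, clean)                # optional stability term
    return adv + lam * rec
```

Note that only the generator's parameters receive updates from this loss; the discriminator is held fixed during this phase.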
Historical Training
A distinctive feature of MetricGAN+ is historical training. Enhanced utterances from previous epochs are saved along with their actual quality scores. During discriminator training, these historical samples are replayed to:
- Prevent catastrophic forgetting of previously seen enhancement quality levels
- Provide a richer training distribution spanning multiple quality levels
- Enable the discriminator to learn a smooth mapping from spectral features to quality scores
The historical set grows over epochs, with a configurable portion (default: 20%) sampled for replay in each epoch.
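The replay mechanism can be sketched as a simple growing buffer. The class and method names here are illustrative, not SpeechBrain's API; only the behavior (monotonic growth, 20% default replay fraction) follows the description above.

```python
import random

class HistoryBuffer:
    """Stores (enhanced utterance, true quality score) pairs across epochs."""

    def __init__(self, replay_portion=0.2):
        self.samples = []                  # grows monotonically over training
        self.replay_portion = replay_portion

    def add(self, utterance, score):
        self.samples.append((utterance, score))

    def sample_for_replay(self):
        """Draw the configured fraction for discriminator replay this epoch."""
        k = max(1, int(len(self.samples) * self.replay_portion))
        return random.sample(self.samples, k)
```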
Sub-Stage Training Architecture
The training loop is organized into three sub-stages within each epoch:
| Sub-Stage | Description | Optimizer Updated |
|---|---|---|
| CURRENT | Discriminator trained on current epoch's clean, enhanced, and noisy speech | Discriminator |
| HISTORICAL | Discriminator retrained on enhanced speech from previous epochs | Discriminator |
| GENERATOR | Generator trained to maximize discriminator's predicted score | Generator |
This sub-stage architecture requires a custom fit_batch() that:
- Maintains separate optimizers for generator and discriminator
- Routes gradient updates to the correct optimizer based on the current sub-stage
- Computes different loss objectives depending on the sub-stage
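The routing logic can be sketched as follows. This is a simplified stand-in for the real fit_batch() (losses are passed as thunks for brevity); the sub-stage names mirror the table above.

```python
import torch
from enum import Enum, auto

class SubStage(Enum):
    CURRENT = auto()
    HISTORICAL = auto()
    GENERATOR = auto()

def fit_batch_sketch(sub_stage, d_loss, g_loss, g_opt, d_opt):
    """Route the backward pass to the optimizer owning the current sub-stage."""
    if sub_stage is SubStage.GENERATOR:
        loss, opt = g_loss(), g_opt   # generator objective, generator params
    else:                             # CURRENT or HISTORICAL
        loss, opt = d_loss(), d_opt   # discriminator objective, discriminator params
    opt.zero_grad()
    loss.backward()
    opt.step()
    return loss.detach()
```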
PESQ as Training Target
PESQ (Perceptual Evaluation of Speech Quality) scores are normalized to [0, 1] for use as training targets:
normalized_pesq = (raw_pesq + 0.5) / 5.0
where raw PESQ ranges from -0.5 to 4.5. This normalization ensures compatibility with the discriminator's Sigmoid output range.
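As a helper (the function name is illustrative, not SpeechBrain's API):

```python
def normalize_pesq(raw_pesq: float) -> float:
    """Map a raw PESQ score in [-0.5, 4.5] onto [0, 1]."""
    return (raw_pesq + 0.5) / 5.0
```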
Dual Optimizer Management
Unlike standard Brain training which uses a single optimizer, MetricGAN+ requires two independent optimizers:
- Generator optimizer (g_opt_class): Adam with lr=0.0005, updates only generator parameters
- Discriminator optimizer (d_opt_class): Adam with lr=0.0005, updates only discriminator parameters
The custom init_optimizers() method creates both optimizers and registers them with the checkpointer for resumable training.
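A minimal sketch of this pattern, not SpeechBrain's exact method signature (the checkpointer registration call assumes SpeechBrain's Checkpointer.add_recoverable interface):

```python
import torch

def init_optimizers(generator, discriminator, checkpointer=None, lr=5e-4):
    """Create one Adam optimizer per sub-network."""
    g_opt = torch.optim.Adam(generator.parameters(), lr=lr)
    d_opt = torch.optim.Adam(discriminator.parameters(), lr=lr)
    if checkpointer is not None:
        # Registering the optimizers makes training resumable from a checkpoint.
        checkpointer.add_recoverable("g_opt", g_opt)
        checkpointer.add_recoverable("d_opt", d_opt)
    return g_opt, d_opt
```

Because each optimizer only sees its own sub-network's parameters, a backward pass through the wrong loss cannot accidentally update the other model.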
Key Design Decisions
- Metric-driven optimization: By using actual PESQ scores as discriminator targets, the model directly optimizes for human-perceived quality rather than proxy losses
- Batch size of 1 for generator training: The default configuration uses batch_size=1 during generator training because PESQ computation is expensive and done per-utterance
- Learnable Sigmoid clamping: The generator's learnable Sigmoid parameters are clamped to a maximum of 3.5 to prevent gradient explosion
- Three-pass discriminator training: Each epoch trains the discriminator with three data passes (current, historical, current again) before a single generator pass, ensuring the discriminator is well-calibrated
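The learnable Sigmoid clamping can be sketched as below. The module is illustrative: the 1.2 output scale follows the MetricGAN+ paper's learnable sigmoid, but attribute and method names here are hypothetical.

```python
import torch

class LearnableSigmoid(torch.nn.Module):
    """Sigmoid with a learnable slope, as used in the MetricGAN+ generator mask."""

    def __init__(self, size):
        super().__init__()
        self.slope = torch.nn.Parameter(torch.ones(size))

    def forward(self, x):
        return 1.2 * torch.sigmoid(self.slope * x)

    def clamp_slope(self, max_value=3.5):
        # Called after each optimizer step to keep gradients bounded.
        with torch.no_grad():
            self.slope.clamp_(max=max_value)
```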
Comparison with Standard Enhancement Training
| Aspect | GAN-Based (MetricGAN+) | Conventional (MSE) |
|---|---|---|
| Objective | Adversarial + perceptual metric | Spectral MSE |
| Optimizers | Two (generator + discriminator) | One |
| Training complexity | High (multi-stage per epoch) | Low (single forward-backward) |
| PESQ score | Directly optimized | Only monitored |
| Typical epochs | 750 | 50 |
| Typical batch size | 1 | 8 |
See Also
- Implementation:Speechbrain_Speechbrain_MetricGanBrain_Fit_Batch -- The concrete implementation of the GAN training loop
- Heuristic:Speechbrain_Speechbrain_GAN_Dual_Optimizer_Pattern
- Principle:Speechbrain_Speechbrain_Conventional_Enhancement_Training -- The simpler alternative training approach
- Principle:Speechbrain_Speechbrain_Perceptual_Quality_Evaluation -- The metrics that MetricGAN+ optimizes for