Jump to content

Connect Leeroopedia MCP: Equip your AI agents to search best practices, build plans, verify code, diagnose failures, and look up hyperparameter defaults.

Implementation:Speechbrain Speechbrain MetricGanBrain Fit Batch

From Leeroopedia


Property Value
Implementation Name MetricGanBrain_Fit_Batch
API MetricGanBrain.fit_batch(self, batch)
Source File recipes/Voicebank/enhance/MetricGAN/train.py -- Class: L48, fit_batch: L299-346, compute_objectives: L83-163
Import Recipe-specific Brain subclass (not importable as library)
Type API Doc
Workflow Speech_Enhancement_Training
Domains GAN_Training, Speech_Enhancement
Related Principle Principle:Speechbrain_Speechbrain_GAN_Based_Enhancement_Training

Purpose

MetricGanBrain is a custom sb.Brain subclass that implements the MetricGAN+ training procedure for speech enhancement. The core method fit_batch() manages the alternating generator/discriminator optimization, while compute_objectives() implements the sub-stage-specific loss computation that uses actual PESQ scores as discriminator training targets.

Class Definition

class SubStage(Enum):
    """Identifies which phase of MetricGAN+ training is active."""

    # Explicit values match exactly what auto() would assign (1, 2, 3),
    # so member identity and .value are unchanged.
    GENERATOR = 1   # generator (enhancement model) update pass
    CURRENT = 2     # discriminator pass over current epoch data
    HISTORICAL = 3  # discriminator pass over replayed historical samples

class MetricGanBrain(sb.Brain):
    """Brain class for MetricGAN+ speech enhancement training.

    Manages dual optimizers (one for the generator, one for the
    discriminator — see ``init_optimizers``), sub-stage training
    driven by the ``SubStage`` enum, and historical sample replay
    for adversarial perceptual metric optimization.
    """

fit_batch Method

def fit_batch(self, batch):
    """Compute gradients and update either D or G based on sub-stage.

    Arguments
    ---------
    batch : object
        Batch object with the signals needed by ``compute_forward`` /
        ``compute_objectives`` (e.g. ``noisy_sig``, ``clean_sig``).

    Returns
    -------
    torch.Tensor or int
        Detached loss, averaged over the updates performed on this batch
        (0 if the current sub-stage matched no branch).
    """
    predictions = self.compute_forward(batch, sb.Stage.TRAIN)
    loss_tracker = 0

    def _optimize(optimizer, loss):
        # Shared backward/clip/step sequence, previously triplicated
        # across the three sub-stage branches.
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(
            self.modules.parameters(), self.max_grad_norm
        )
        optimizer.step()

    if self.sub_stage == SubStage.CURRENT:
        # Discriminator training on current data: clean, enhanced, noisy
        for mode in ["clean", "enh", "noisy"]:
            loss = self.compute_objectives(
                predictions, batch, sb.Stage.TRAIN, f"D_{mode}"
            )
            _optimize(self.d_optimizer, loss)
            # Average the three per-mode losses into one tracked value.
            loss_tracker += loss.detach() / 3

    elif self.sub_stage == SubStage.HISTORICAL:
        # Discriminator training on historical enhanced samples
        loss = self.compute_objectives(
            predictions, batch, sb.Stage.TRAIN, "D_enh"
        )
        _optimize(self.d_optimizer, loss)
        loss_tracker += loss.detach()

    elif self.sub_stage == SubStage.GENERATOR:
        # Clamp learnable sigmoid params to prevent gradient explosion,
        # and replace any NaN entries with the clamp ceiling.
        for name, param in self.modules.generator.named_parameters():
            if "Learnable_sigmoid" in name:
                param.data = torch.clamp(param, max=3.5)
                param.data[torch.isnan(param)] = 3.5

        loss = self.compute_objectives(
            predictions, batch, sb.Stage.TRAIN, "generator"
        )
        _optimize(self.g_optimizer, loss)
        loss_tracker += loss.detach()

    return loss_tracker

compute_objectives Method

The compute_objectives method dispatches on the optim_name parameter to compute the appropriate loss for each sub-stage:

def compute_objectives(self, predictions, batch, stage, optim_name=""):
    """Given the network predictions and targets compute the total loss.

    Arguments
    ---------
    predictions : torch.Tensor
        Enhanced waveforms from ``compute_forward``.
    batch : object
        Batch carrying ``clean_sig`` (and, per sub-stage, ``noisy_sig``,
        ``score`` or ``id``).
    stage : sb.Stage
        Current stage (unused in the dispatch below, kept for the
        Brain interface).
    optim_name : str
        One of "generator", "D_clean", "D_enh", "D_noisy"; selects
        the sub-stage-specific target/estimate pair.

    Returns
    -------
    torch.Tensor
        Scalar cost for the selected objective.
    """
    predict_wav = predictions
    predict_spec = self.compute_feats(predict_wav)
    clean_wav, lens = batch.clean_sig
    clean_spec = self.compute_feats(clean_wav)

    if optim_name == "generator":
        # Generator aims for score of 1.0 (perfect quality); optional
        # spectral reconstruction term is added at the end.
        target_score = torch.ones(self.batch_size, 1, device=self.device)
        est_score = self.est_score(predict_spec, clean_spec)
        mse_cost = self.hparams.compute_cost(predict_spec, clean_spec, lens)

    elif optim_name == "D_clean":
        # Discriminator learns: clean speech -> score 1.0
        target_score = torch.ones(self.batch_size, 1, device=self.device)
        est_score = self.est_score(clean_spec, clean_spec)

    elif optim_name == "D_enh" and self.sub_stage == SubStage.CURRENT:
        # Discriminator learns: enhanced speech -> actual PESQ score.
        # FIX: `ids` was referenced without being defined in this method.
        # NOTE(review): derived from batch.id here — confirm against the
        # full recipe, which may post-process ids before scoring.
        ids = batch.id
        target_score = self.score(ids, predict_wav, clean_wav, lens)
        est_score = self.est_score(predict_spec, clean_spec)

    elif optim_name == "D_enh" and self.sub_stage == SubStage.HISTORICAL:
        # Discriminator relearns: historical enhanced speech -> saved score
        target_score = batch.score.unsqueeze(1).float()
        est_score = self.est_score(predict_spec, clean_spec)

    elif optim_name == "D_noisy":
        # Discriminator learns: noisy speech -> actual PESQ score
        noisy_wav, _ = batch.noisy_sig
        noisy_spec = self.compute_feats(noisy_wav)
        ids = batch.id  # see NOTE(review) above about id derivation
        target_score = self.score(ids, noisy_wav, clean_wav, lens)
        est_score = self.est_score(noisy_spec, clean_spec)

    else:
        # FIX: an unrecognized optim_name previously fell through to a
        # confusing NameError on `est_score`; fail loudly instead.
        raise ValueError(f"Unknown optim_name: {optim_name!r}")

    cost = self.hparams.compute_cost(est_score, target_score)
    if optim_name == "generator":
        cost += self.hparams.mse_weight * mse_cost

    return cost

compute_forward Method

def compute_forward(self, batch, stage):
    """Given an input batch computes the enhanced signal."""
    batch = batch.to(self.device)

    # Historical replay batches already carry pre-computed enhanced wavs,
    # so no generator pass is needed.
    if self.sub_stage == SubStage.HISTORICAL:
        enhanced_wav, _ = batch.enh_sig
        return enhanced_wav

    noisy_wav, lengths = batch.noisy_sig
    noisy_spec = self.compute_feats(noisy_wav)

    # The generator predicts a spectral mask, floored at min_mask so
    # the signal is never fully suppressed.
    mask = self.modules.generator(noisy_spec, lengths=lengths)
    mask = mask.clamp(min=self.hparams.min_mask).squeeze(2)
    masked_spec = mask * noisy_spec

    # Invert back to a waveform; expm1 presumably undoes a log1p
    # feature compression — confirm against compute_feats.
    return self.hparams.resynth(torch.expm1(masked_spec), noisy_wav)

Training Orchestration

The epoch-level training is managed by on_stage_start and train_discriminator:

def train_discriminator(self):
    """Run up to three one-epoch data passes to update the discriminator."""
    loader_kwargs = self.hparams.dataloader_options

    # Schedule: current data, optional historical replay, current again.
    schedule = [(SubStage.CURRENT, self.train_set)]
    if self.historical_set:
        schedule.append((SubStage.HISTORICAL, self.historical_set))
    schedule.append((SubStage.CURRENT, self.train_set))

    for sub_stage, dataset in schedule:
        self.sub_stage = sub_stage
        self.fit(range(1), dataset, train_loader_kwargs=loader_kwargs)

Dual Optimizer Initialization

def init_optimizers(self):
    """Create the generator and discriminator optimizers.

    Also registers both optimizers with the checkpointer (if any) so
    their state is saved and restored across runs.
    """
    gen_params = self.modules.generator.parameters()
    disc_params = self.modules.discriminator.parameters()
    self.g_optimizer = self.hparams.g_opt_class(gen_params)
    self.d_optimizer = self.hparams.d_opt_class(disc_params)

    ckpt = self.checkpointer
    if ckpt is not None:
        ckpt.add_recoverable("g_opt", self.g_optimizer)
        ckpt.add_recoverable("d_opt", self.d_optimizer)

Usage Example

Full Training Pipeline

import sys
import speechbrain as sb
from hyperpyyaml import load_hyperpyyaml

# FIX: dataio_prep is called below but was never imported.
# NOTE(review): assumes train.py defines dataio_prep — confirm.
from train import MetricGanBrain, SubStage, dataio_prep

# Load configuration (YAML hyperparameters + CLI overrides)
hparams_file, run_opts, overrides = sb.parse_arguments(sys.argv[1:])
with open(hparams_file, encoding="utf-8") as fin:
    hparams = load_hyperpyyaml(fin, overrides)

# Prepare data manifests for the Voicebank corpus
from voicebank_prepare import prepare_voicebank
prepare_voicebank(
    data_folder=hparams["data_folder"],
    save_folder=hparams["data_folder"],
)

# Create datasets
datasets = dataio_prep(hparams)

# Initialize Brain
se_brain = MetricGanBrain(
    modules=hparams["modules"],
    hparams=hparams,
    run_opts=run_opts,
    checkpointer=hparams["checkpointer"],
)
# Recipe-specific state that fit_batch/compute_objectives expect
se_brain.train_set = datasets["train"]
se_brain.historical_set = {}
se_brain.noisy_scores = {}
se_brain.batch_size = hparams["dataloader_options"]["batch_size"]
se_brain.sub_stage = SubStage.GENERATOR

# Train
se_brain.fit(
    epoch_counter=hparams["epoch_counter"],
    train_set=datasets["train"],
    valid_set=datasets["valid"],
    train_loader_kwargs=hparams["dataloader_options"],
    valid_loader_kwargs=hparams["valid_dataloader_options"],
)

# Evaluate on the test split using the best checkpoint by target metric
test_stats = se_brain.evaluate(
    test_set=datasets["test"],
    max_key=hparams["target_metric"],
    test_loader_kwargs=hparams["dataloader_options"],
)

Key Configuration Parameters

Parameter Default Description
target_metric "pesq" Metric used as discriminator target ("pesq" or "stoi")
G_lr 0.0005 Generator learning rate
D_lr 0.0005 Discriminator learning rate
mse_weight 0 Weight for spectral MSE reconstruction loss in generator objective
min_mask 0.05 Minimum mask value (prevents complete signal suppression)
number_of_epochs 750 Total training epochs
number_of_samples 100 Samples per epoch for generator training
history_portion 0.2 Fraction of historical set used per epoch
train_N_batch 1 Batch size for training
valid_N_batch 20 Batch size for validation

Inputs and Outputs

Inputs (per batch):

  • batch.noisy_sig: Noisy speech waveform tensor and lengths
  • batch.clean_sig: Clean speech reference waveform tensor and lengths
  • batch.noisy_wav: Path to noisy audio file (for identification)

Outputs:

  • Generator loss: MSE(D(enhanced, clean), 1.0) + mse_weight * spectral_MSE
  • Discriminator loss: MSE(D(input, clean), actual_score) for each input type
  • Enhanced wavs: Written to enhanced_folder during validation/test

See Also

Related Pages

Page Connections

Double-click a node to navigate. Click and hold a node to expand its connections.
Principle
Implementation
Heuristic
Environment