Implementation:Speechbrain Speechbrain SEBrain Compute Forward

From Leeroopedia


Property              Value
Implementation Name   SEBrain_Compute_Forward
API                   SEBrain.compute_forward(self, batch, stage) and SEBrain.compute_objectives(self, predictions, batch, stage)
Source File           recipes/Voicebank/enhance/spectral_mask/train.py (class: L26, compute_forward: L27-43, compute_objectives: L52-92)
Import                Recipe-specific Brain subclass (not importable as a library)
Type                  API Doc
Workflow              Speech_Enhancement_Training
Domains               Speech_Enhancement, Deep_Learning
Related Principle     Principle:Speechbrain_Speechbrain_Conventional_Enhancement_Training

Purpose

SEBrain (Speech Enhancement Brain) is a custom sb.Brain subclass that implements conventional supervised training for speech enhancement with spectral masking. It provides the compute_forward() and compute_objectives() methods, which plug into SpeechBrain's standard training loop: compute_forward() predicts a mask on log-compressed STFT magnitudes, applies it, and resynthesizes the enhanced waveform using the noisy phase; compute_objectives() computes an MSE loss either on the masked spectrum (the default) or on the resynthesized waveform (when waveform_target is True).
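The order in which a Brain subclass's hooks are called can be illustrated with a deliberately tiny loop. ToyBrain below is hypothetical and is not SpeechBrain code: the real sb.Brain.fit() additionally handles devices, distributed training, mixed precision, checkpointing and validation, and passes sb.Stage enum values rather than strings. The sketch only shows that on_stage_start runs once, compute_forward and compute_objectives run per batch, and on_stage_end receives the averaged stage loss.

```python
import torch


class ToyBrain:
    """Hypothetical, heavily simplified loop showing the order of the Brain hooks."""

    def __init__(self, model, lr=1e-3):
        self.model = model
        self.optimizer = torch.optim.Adam(model.parameters(), lr=lr)
        self.calls = []  # records which stage hooks ran, in order

    def on_stage_start(self, stage, epoch=None):
        self.calls.append(("on_stage_start", stage))

    def compute_forward(self, batch, stage):
        noisy, _clean = batch
        return self.model(noisy)

    def compute_objectives(self, predictions, batch, stage):
        _noisy, clean = batch
        return torch.nn.functional.mse_loss(predictions, clean)

    def on_stage_end(self, stage, stage_loss, epoch=None):
        self.calls.append(("on_stage_end", stage))

    def fit_one_epoch(self, batches, epoch=1):
        stage = "TRAIN"  # the real loop passes sb.Stage.TRAIN
        self.on_stage_start(stage, epoch)
        total = 0.0
        for batch in batches:
            predictions = self.compute_forward(batch, stage)
            loss = self.compute_objectives(predictions, batch, stage)
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            total += loss.item()
        stage_loss = total / len(batches)  # average loss handed to on_stage_end
        self.on_stage_end(stage, stage_loss, epoch)
        return stage_loss


torch.manual_seed(0)
batches = [(torch.randn(4, 10), torch.randn(4, 10)) for _ in range(3)]
brain = ToyBrain(torch.nn.Linear(10, 10))
loss = brain.fit_one_epoch(batches)
print(brain.calls)  # [('on_stage_start', 'TRAIN'), ('on_stage_end', 'TRAIN')]
```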

Class Definition

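# Imports used by the SEBrain methods documented on this page
import os

import torch
import torchaudio
from pesq import pesq

import speechbrain as sb
from speechbrain.nnet.loss.stoi_loss import stoi_loss
from speechbrain.processing.features import spectral_magnitude
from speechbrain.utils.metric_stats import MetricStats


class SEBrain(sb.Brain):
    """Brain class for speech enhancement training.

    The model always predicts a spectral mask; the 'waveform_target'
    hyperparameter selects whether the loss is computed on the masked
    spectrum (default) or on the resynthesized waveform.
    """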

compute_forward Method

def compute_forward(self, batch, stage):
    """Forward computations from the waveform batches to the enhanced output.

    Arguments
    ---------
    batch : PaddedBatch
        Contains 'noisy_sig' (noisy waveform) and 'clean_sig' (clean waveform).
    stage : sb.Stage
        One of TRAIN, VALID, or TEST.

    Returns
    -------
    predict_spec : torch.Tensor
        Masked, log1p-compressed magnitude spectrum [batch, time, freq].
    predict_wav : torch.Tensor
        Reconstructed enhanced waveform [batch, samples].
    """
    batch = batch.to(self.device)
    noisy_wavs, lens = batch.noisy_sig
    noisy_feats = self.compute_feats(noisy_wavs)

    # Predict spectral mask using the model
    mask = self.modules.model(noisy_feats)
    mask = torch.squeeze(mask, 2)

    # Apply mask via signal approximation (SA)
    predict_spec = torch.mul(mask, noisy_feats)

    # Undo log1p with expm1, then reconstruct the waveform via ISTFT using the noisy phase
    predict_wav = self.hparams.resynth(
        torch.expm1(predict_spec), noisy_wavs
    )

    return predict_spec, predict_wav
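The masking and resynthesis steps can be made concrete with a self-contained PyTorch sketch. It is not the recipe's code: compute_STFT and resynth are configured in the recipe's YAML, whereas this sketch assumes a Hann window, n_fft=512, hop=256 and torch.stft's centered framing, and the helper names stft_log_mag_and_phase and enhance_sketch are hypothetical. With an all-ones mask the masked spectrum equals the noisy one, so resynthesis should return the input up to numerical error.

```python
import torch

# Assumptions (not read from the recipe's YAML): Hann window, n_fft=512,
# hop=256 (32 ms / 16 ms at 16 kHz), centered framing as in torch.stft.
N_FFT, HOP = 512, 256
WINDOW = torch.hann_window(N_FFT)


def stft_log_mag_and_phase(wavs):
    """Return log1p magnitude [B, T, F] and phase [B, F, T] of a waveform batch."""
    spec = torch.stft(wavs, N_FFT, hop_length=HOP, window=WINDOW, return_complex=True)
    return torch.log1p(spec.abs()).transpose(1, 2), torch.angle(spec)


def enhance_sketch(noisy_wavs, mask):
    """Apply a [B, T, F] mask in the log1p domain and resynthesize with the noisy phase."""
    noisy_feats, noisy_phase = stft_log_mag_and_phase(noisy_wavs)
    predict_spec = mask * noisy_feats                       # masked log1p magnitude
    magnitude = torch.expm1(predict_spec).transpose(1, 2)  # undo log1p -> [B, F, T]
    enhanced_spec = torch.polar(magnitude, noisy_phase)    # combine with noisy phase
    predict_wav = torch.istft(
        enhanced_spec, N_FFT, hop_length=HOP, window=WINDOW,
        length=noisy_wavs.shape[1],
    )
    return predict_spec, predict_wav


# Sanity check: an all-ones mask should give back (approximately) the input.
torch.manual_seed(0)
noisy = torch.randn(2, 16000)
feats, _ = stft_log_mag_and_phase(noisy)
spec_out, wav_out = enhance_sketch(noisy, torch.ones_like(feats))
print(spec_out.shape, wav_out.shape)  # torch.Size([2, 63, 257]) torch.Size([2, 16000])
```

Reusing the noisy phase is why only a magnitude mask needs to be learned; the trade-off is that phase errors in the noisy signal are carried into the enhanced output.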

compute_feats Method

def compute_feats(self, wavs):
    """Feature computation pipeline.

    Arguments
    ---------
    wavs : torch.Tensor
        Raw waveform tensor [batch, samples].

    Returns
    -------
    feats : torch.Tensor
        Log-compressed spectral magnitude [batch, time, freq].
    """
    feats = self.hparams.compute_STFT(wavs)
    feats = spectral_magnitude(feats, power=0.5)
    feats = torch.log1p(feats)
    return feats
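The feature pipeline in compute_feats can be mirrored with torch.stft for illustration. In this sketch, compute_feats_sketch is a hypothetical name, and the Hann window and 512/256 framing are assumptions. It spells out that spectral_magnitude(power=0.5) takes the square root of the power spectrum (SpeechBrain also adds a small stabilizing epsilon), i.e. the ordinary magnitude, and shows the [batch, time, freq] layout the docstring refers to.

```python
import torch


def compute_feats_sketch(wavs, n_fft=512, hop=256):
    """Hypothetical stand-in for compute_feats: log1p(|STFT|) in [B, T, F] layout."""
    spec = torch.stft(wavs, n_fft, hop_length=hop,
                      window=torch.hann_window(n_fft), return_complex=True)
    # power=0.5 applied to the power spectrum (re^2 + im^2) gives the magnitude |X|
    power_spec = spec.real.pow(2) + spec.imag.pow(2)
    magnitude = power_spec.pow(0.5)
    # log1p compresses the dynamic range and maps 0 -> 0; expm1 inverts it
    feats = torch.log1p(magnitude)
    return feats.transpose(1, 2)  # torch.stft returns [B, F, T]; convert to [B, T, F]


wavs = torch.randn(3, 16000)  # batch of three 1 s noise signals at 16 kHz
feats = compute_feats_sketch(wavs)
print(feats.shape)  # torch.Size([3, 63, 257])
```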

compute_objectives Method

def compute_objectives(self, predictions, batch, stage):
    """Computes the loss given the predicted and targeted outputs.

    Arguments
    ---------
    predictions : tuple
        (predict_spec, predict_wav) from compute_forward.
    batch : PaddedBatch
        Contains 'clean_sig' (target waveform).
    stage : sb.Stage
        Current training stage.

    Returns
    -------
    loss : torch.Tensor
        Scalar loss value.
    """
    predict_spec, predict_wav = predictions
    clean_wavs, lens = batch.clean_sig

    if getattr(self.hparams, "waveform_target", False):
        # Waveform-domain loss
        loss = self.hparams.compute_cost(predict_wav, clean_wavs, lens)
        self.loss_metric.append(
            batch.id, predict_wav, clean_wavs, lens, reduction="batch"
        )
    else:
        # Spectral-domain loss (default)
        clean_spec = self.compute_feats(clean_wavs)
        loss = self.hparams.compute_cost(predict_spec, clean_spec, lens)
        self.loss_metric.append(
            batch.id, predict_spec, clean_spec, lens, reduction="batch"
        )

    if stage != sb.Stage.TRAIN:
        # Compute perceptual metrics during validation/test
        self.stoi_metric.append(
            batch.id, predict_wav, clean_wavs, lens, reduction="batch"
        )
        self.pesq_metric.append(
            batch.id, predict=predict_wav, target=clean_wavs, lengths=lens
        )

        # Write enhanced wavs to file during test
        if stage == sb.Stage.TEST:
            lens = lens * clean_wavs.shape[1]
            for name, pred_wav, length in zip(batch.id, predict_wav, lens):
                name += ".wav"
                enhance_path = os.path.join(
                    self.hparams.enhanced_folder, name
                )
                torchaudio.save(
                    enhance_path,
                    torch.unsqueeze(pred_wav[: int(length)].cpu(), 0),
                    16000,
                )

    return loss
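Both loss branches pass lens to compute_cost so that zero-padded frames do not contribute to the loss. The idea is sketched below with a hypothetical masked_mse helper; SpeechBrain's own loss utilities handle lengths internally, and their exact normalization (for example the per-utterance values produced with reduction="batch") may differ from this sketch.

```python
import torch


def masked_mse(pred, target, rel_lens):
    """Hypothetical helper: MSE over the unpadded frames of each utterance.

    pred, target: [B, T, ...] tensors; rel_lens: [B] fractions of T that are real data.
    """
    batch, steps = pred.shape[0], pred.shape[1]
    abs_lens = torch.round(rel_lens * steps).long()
    valid = torch.arange(steps).unsqueeze(0) < abs_lens.unsqueeze(1)  # [B, T] bool
    # Reshape to [B, T, 1, ...] so the mask broadcasts over trailing dims (e.g. freq)
    valid = valid.view(batch, steps, *([1] * (pred.dim() - 2))).to(pred.dtype)
    valid = valid.expand_as(pred)
    sq_err = (pred - target).pow(2) * valid
    return sq_err.sum() / valid.sum()


pred = torch.zeros(2, 10, 4)
target = torch.ones(2, 10, 4)
target[1, 5:] = 100.0  # garbage in the padded half of the second utterance
rel_lens = torch.tensor([1.0, 0.5])
print(masked_mse(pred, target, rel_lens))  # tensor(1.)
```

A plain mean over the whole tensor would be dominated by the padded values here, which is exactly what passing lens avoids.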

Stage Callbacks

on_stage_start

def on_stage_start(self, stage, epoch=None):
    """Gets called at the beginning of each stage (train, valid, or test)."""
    self.loss_metric = MetricStats(metric=self.hparams.compute_cost)
    self.stoi_metric = MetricStats(metric=stoi_loss)

    def pesq_eval(pred_wav, target_wav):
        return pesq(
            fs=16000, ref=target_wav.numpy(),
            deg=pred_wav.numpy(), mode="wb",
        )

    if stage != sb.Stage.TRAIN:
        self.pesq_metric = MetricStats(
            metric=pesq_eval, n_jobs=1, batch_eval=False
        )
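With batch_eval=False, a metric such as pesq_eval is applied one utterance at a time, with the lengths argument used to drop padding first, and summarize("average") averages the per-utterance scores. The hypothetical ToyMetricStats class below is a simplified stand-in written to illustrate that behaviour; MetricStats' actual implementation may differ in detail. It uses a mean-absolute-error metric instead of PESQ so that it runs without the pesq package.

```python
import torch


class ToyMetricStats:
    """Hypothetical, simplified accumulator in the spirit of MetricStats(batch_eval=False)."""

    def __init__(self, metric):
        self.metric = metric
        self.ids, self.scores = [], []

    def append(self, ids, predict, target, lengths):
        # Convert relative lengths to sample counts, then score each
        # utterance on its unpadded part only.
        abs_lens = (lengths * predict.shape[1]).round().int()
        for utt_id, pred, tgt, n in zip(ids, predict, target, abs_lens):
            n = int(n)
            self.ids.append(utt_id)
            self.scores.append(float(self.metric(pred[:n], tgt[:n])))

    def summarize(self, field="average"):
        if field != "average":
            raise ValueError("this sketch only supports 'average'")
        return sum(self.scores) / len(self.scores)


def mean_abs_error(pred_wav, target_wav):
    """Stand-in per-utterance metric (PESQ would need the pesq package)."""
    return (pred_wav - target_wav).abs().mean()


stats = ToyMetricStats(metric=mean_abs_error)
predict = torch.zeros(2, 4)
target = torch.tensor([[1.0, 1.0, 1.0, 1.0], [2.0, 2.0, 9.0, 9.0]])
stats.append(["utt1", "utt2"], predict, target, lengths=torch.tensor([1.0, 0.5]))
print(stats.summarize("average"))  # 1.5
```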

on_stage_end

def on_stage_end(self, stage, stage_loss, epoch=None):
    """Gets called at the end of each stage (train, valid, or test)."""
    if stage == sb.Stage.TRAIN:
        self.train_loss = stage_loss
    else:
        stats = {
            "loss": stage_loss,
            "pesq": self.pesq_metric.summarize("average"),
            # stoi_loss returns the negative STOI, so flip the sign for reporting
            "stoi": -self.stoi_metric.summarize("average"),
        }

    if stage == sb.Stage.VALID:
        self.hparams.train_logger.log_stats(
            {"Epoch": epoch},
            train_stats={"loss": self.train_loss},
            valid_stats=stats,
        )
        # Save checkpoint based on best PESQ
        self.checkpointer.save_and_keep_only(
            meta=stats, max_keys=["pesq"]
        )
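Two details of on_stage_end are easy to miss: the STOI average is negated because stoi_loss reports negative STOI, and save_and_keep_only(max_keys=["pesq"]) ranks checkpoints by validation PESQ alone. The plain-Python sketch below uses hypothetical helper names and dummy numbers (not real results) to show that selection; the real Checkpointer may also retain other checkpoints, such as the most recent one.

```python
def negate_stoi_summary(stoi_loss_average):
    """stoi_loss reports -STOI, so negate its average to get STOI."""
    return -stoi_loss_average


def keep_best(checkpoint_metas, max_key):
    """Hypothetical helper: return the meta dict with the largest value of max_key."""
    return max(checkpoint_metas, key=lambda meta: meta[max_key])


# Dummy per-epoch validation stats, for illustration only.
metas = [
    {"epoch": 1, "pesq": 2.0, "stoi": negate_stoi_summary(-0.80)},
    {"epoch": 2, "pesq": 2.5, "stoi": negate_stoi_summary(-0.85)},
    {"epoch": 3, "pesq": 2.3, "stoi": negate_stoi_summary(-0.87)},
]
best = keep_best(metas, "pesq")
print(best["epoch"])  # 2 (epoch 3 has higher STOI but lower PESQ)
```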

Usage Examples

Full Training Pipeline

import sys

import speechbrain as sb
from hyperpyyaml import load_hyperpyyaml
from speechbrain.utils.distributed import run_on_main

# Recipe-local helpers: SEBrain and dataio_prep are defined in train.py itself,
# and voicebank_prepare.py ships with the Voicebank recipes.
from voicebank_prepare import prepare_voicebank

# Load hyperparameters from the YAML file, applying command-line overrides
hparams_file, run_opts, overrides = sb.parse_arguments(sys.argv[1:])
with open(hparams_file, encoding="utf-8") as fin:
    hparams = load_hyperpyyaml(fin, overrides)

# Create the output folder (logs, checkpoints) and save the hyperparameters
sb.create_experiment_directory(
    experiment_directory=hparams["output_folder"],
    hyperparams_to_save=hparams_file,
    overrides=overrides,
)

# Prepare data manifests (runs on the main process only under DDP)
run_on_main(
    prepare_voicebank,
    kwargs={
        "data_folder": hparams["data_folder"],
        "save_folder": hparams["output_folder"],
        "skip_prep": hparams["skip_prep"],
    },
)

# Create train/valid/test datasets with audio pipelines
datasets = dataio_prep(hparams)

# Initialize SEBrain
se_brain = SEBrain(
    modules=hparams["modules"],
    opt_class=hparams["opt_class"],
    hparams=hparams,
    run_opts=run_opts,
    checkpointer=hparams["checkpointer"],
)

# Train with the standard Brain.fit() loop
se_brain.fit(
    epoch_counter=se_brain.hparams.epoch_counter,
    train_set=datasets["train"],
    valid_set=datasets["valid"],
    train_loader_kwargs=hparams["dataloader_options"],
    valid_loader_kwargs=hparams["dataloader_options"],
)

# Evaluate on the test set, loading the checkpoint with the best PESQ
test_stats = se_brain.evaluate(
    test_set=datasets["test"],
    max_key="pesq",
    test_loader_kwargs=hparams["dataloader_options"],
)

Running from Command Line

# Train with default BLSTM model
python train.py hparams/train.yaml --data_folder /data/noisy-vctk-16k

# Train with 2D-FCN model
python train.py hparams/train.yaml --data_folder /data/noisy-vctk-16k \
    --models '!include:models/2DFCN.yaml'

# Train with waveform-domain loss
python train.py hparams/train.yaml --data_folder /data/noisy-vctk-16k \
    --waveform_target True

Data Flow Diagram

noisy_wav ──> STFT ──> spectral_magnitude(power=0.5) ──> log1p ──> noisy_feats
noisy_feats ──> model ──> mask
mask * noisy_feats ──> predict_spec
predict_spec ──> expm1 ──> resynth (ISTFT with noisy phase) ──> predict_wav

clean_wav ──> STFT ──> spectral_magnitude(power=0.5) ──> log1p ──> clean_spec

default:               MSE(predict_spec, clean_spec) = loss
waveform_target=True:  MSE(predict_wav, clean_wav)   = loss

Inputs and Outputs

Inputs (per batch):

  • batch.noisy_sig: Tuple of (noisy waveform tensor [B, T], relative lengths [B])
  • batch.clean_sig: Tuple of (clean waveform tensor [B, T], relative lengths [B])

Outputs:

  • Training loss: Spectral MSE (or waveform MSE if waveform_target=True)
  • Validation and test metrics: PESQ (wide-band, roughly 1-4.5) and STOI (0-1), computed on enhanced waveforms
  • Enhanced wavs: Written to enhanced_folder during test stage
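The relative lengths in batch.noisy_sig and batch.clean_sig are fractions of the padded length, so they must be converted to sample counts before padding is dropped, as compute_objectives does before writing enhanced files. The hypothetical trim_padded helper below sketches that conversion. It rounds before calling int(); the recipe calls int() directly, which truncates and could drop a sample if floating-point error lands just below a whole number.

```python
import torch


def trim_padded(wavs, rel_lens):
    """Hypothetical helper: split a padded batch [B, T] into unpadded 1-D waveforms."""
    abs_lens = torch.round(rel_lens * wavs.shape[1]).long()  # relative -> samples
    return [wav[: int(n)] for wav, n in zip(wavs, abs_lens)]


batch = torch.arange(12.0).view(2, 6)  # two "waveforms", padded to 6 samples
rel_lens = torch.tensor([1.0, 0.5])    # the second one has only 3 real samples
pieces = trim_padded(batch, rel_lens)
print([p.tolist() for p in pieces])  # [[0.0, 1.0, 2.0, 3.0, 4.0, 5.0], [6.0, 7.0, 8.0]]
# torchaudio.save expects [channels, samples], hence the unsqueeze(0) in the recipe
print(pieces[1].unsqueeze(0).shape)  # torch.Size([1, 3])
```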

Key Configuration Parameters

Parameter          Default       Description
waveform_target    False         If True, compute loss in waveform domain instead of spectral domain
number_of_epochs   50            Total training epochs
N_batch            8             Batch size
lr                 0.0001        Learning rate for Adam optimizer
sorting            "ascending"   Sort training data by length for efficient batching
N_fft              512           FFT size (32 ms at 16 kHz)
Win_length         32            Window length in milliseconds
Hop_length         16            Hop length in milliseconds
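The millisecond settings translate into sample and frame counts at 16 kHz as follows. This arithmetic sketch assumes torch.stft-style framing (a centered STFT yields one frame per hop plus one); the helper names are hypothetical, and the STFT configured in the recipe may pad differently.

```python
def ms_to_samples(ms, sample_rate=16000):
    """Convert a duration in milliseconds to a whole number of samples."""
    return int(round(ms * sample_rate / 1000))


def num_frames(num_samples, hop, n_fft=512, center=True):
    """Number of STFT frames, assuming torch.stft-style framing."""
    if center:
        # centered framing pads n_fft // 2 samples on each side
        return 1 + num_samples // hop
    return 1 + (num_samples - n_fft) // hop


win = ms_to_samples(32)          # Win_length = 32 ms
hop = ms_to_samples(16)          # Hop_length = 16 ms
freq_bins = 512 // 2 + 1         # one-sided bins for N_fft = 512
frames = num_frames(16000, hop)  # frames for 1 s of audio
print(win, hop, freq_bins, frames)  # 512 256 257 63
```

So the 32 ms window matches N_fft exactly, and a 50% hop gives 257 frequency bins and about 63 frames per second of audio.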
