Implementation:Speechbrain Speechbrain Separation Fit Batch

From Leeroopedia


Field Value
Implementation Name Separation_Fit_Batch
API Separation.fit_batch(self, batch) and Separation.evaluate_batch(self, batch, stage)
Source recipes/LibriMix/separation/train.py:L44 (class), L114-163 (fit_batch), L165-186 (evaluate_batch)
Import Recipe-specific, part of recipes/LibriMix/separation/train.py
Type API Doc
Related Principle Principle:Speechbrain_Speechbrain_Custom_Batch_Training_For_Separation

Purpose

The Separation class extends speechbrain.Brain to implement custom training and evaluation logic for speech separation. The fit_batch() method adds mixed-precision scaling, gradient clipping, nonfinite-loss detection, and loss-based sample filtering on top of the standard training step. The evaluate_batch() method adds optional audio saving during the test stage.

Class Definition

class Separation(sb.Brain):
    def compute_forward(self, mix, targets, stage, noise=None):
        """Forward computations from the mixture to the separated signals."""
        ...

    def compute_objectives(self, predictions, targets):
        """Computes the si-snr loss"""
        return self.hparams.loss(targets, predictions)

    def fit_batch(self, batch):
        """Trains one batch"""
        ...

    def evaluate_batch(self, batch, stage):
        """Computations needed for validation/test batches"""
        ...

fit_batch Method

Signature

def fit_batch(self, batch):

Parameters

Parameter Type Description
batch PaddedBatch A batch object with attributes: mix_sig, s1_sig, s2_sig, optionally s3_sig and noise_sig

Inputs

Attribute Type Description
batch.mix_sig (Tensor, Tensor) Mixture waveform tensor [B, T] and relative lengths [B]
batch.s1_sig (Tensor, Tensor) First speaker source and lengths
batch.s2_sig (Tensor, Tensor) Second speaker source and lengths
batch.s3_sig (Tensor, Tensor) (optional) Third speaker source and lengths
batch.noise_sig (Tensor, Tensor) (optional) Noise signal and lengths

Output

Returns a scalar torch.Tensor (detached, on CPU) representing the mean loss for the batch, or 0.0 when the batch is skipped.

Implementation

def fit_batch(self, batch):
    # Unpacking batch list
    mixture = batch.mix_sig
    targets = [batch.s1_sig, batch.s2_sig]
    if self.hparams.use_wham_noise:
        noise = batch.noise_sig[0]
    else:
        noise = None

    if self.hparams.num_spks == 3:
        targets.append(batch.s3_sig)

    with self.training_ctx:
        predictions, targets = self.compute_forward(
            mixture, targets, sb.Stage.TRAIN, noise
        )
        loss = self.compute_objectives(predictions, targets)

        # Hard threshold the easy data items
        if self.hparams.threshold_byloss:
            th = self.hparams.threshold
            loss = loss[loss > th]
            if loss.nelement() > 0:
                loss = loss.mean()
        else:
            loss = loss.mean()

    if loss.nelement() > 0 and loss < self.hparams.loss_upper_lim:
        self.scaler.scale(loss).backward()
        if self.hparams.clip_grad_norm >= 0:
            self.scaler.unscale_(self.optimizer)
            torch.nn.utils.clip_grad_norm_(
                self.modules.parameters(),
                self.hparams.clip_grad_norm,
            )
        self.scaler.step(self.optimizer)
        self.scaler.update()
    else:
        self.nonfinite_count += 1
        logger.info(
            "infinite loss or empty loss! it happened {} times so far "
            "- skipping this batch".format(self.nonfinite_count)
        )
        loss.data = torch.tensor(0.0).to(self.device)
    self.optimizer.zero_grad()

    return loss.detach().cpu()

Processing Flow

  1. Unpack batch: extract the mixture signal, source targets, and optional WHAM! noise
  2. Forward pass: run the encoder, masking network, and decoder inside the mixed-precision context
  3. Compute loss: SI-SNR with the PIT wrapper, returning per-example losses of shape [B]
  4. Threshold filtering: if threshold_byloss is enabled, keep only examples whose loss exceeds threshold
  5. Loss validation: check that the loss is non-empty and below loss_upper_lim
  6. Backward pass: scale the loss for mixed precision and compute gradients
  7. Gradient clipping: unscale the gradients, then clip them to clip_grad_norm
  8. Optimizer step: update the model parameters
  9. Skip handling: if the loss is empty, nonfinite, or above the limit, skip the update, increment nonfinite_count, and zero the loss
  10. Zero gradients: clear accumulated gradients before the next batch
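The threshold-filtering and loss-validation steps above (steps 4 and 5) can be sketched in isolation. This is a framework-free sketch using plain Python lists in place of tensors; filter_and_validate is an illustrative helper, not part of the recipe, though threshold, loss_upper_lim, and threshold_byloss mirror the hparams of the same names.

```python
import math


def filter_and_validate(per_example_losses, threshold, loss_upper_lim,
                        threshold_byloss=True):
    """Mimic steps 4-5 of fit_batch: keep only hard examples, then
    decide whether the resulting mean loss is safe to backpropagate.
    Returns the mean loss, or None when the batch should be skipped."""
    if threshold_byloss:
        # Keep only the "hard" examples whose loss exceeds the threshold.
        kept = [l for l in per_example_losses if l > threshold]
    else:
        kept = list(per_example_losses)
    if not kept:
        # Empty loss after filtering -> skip this batch entirely.
        return None
    mean_loss = sum(kept) / len(kept)
    if not math.isfinite(mean_loss) or mean_loss >= loss_upper_lim:
        # Nonfinite or implausibly large loss -> skip this batch.
        return None
    return mean_loss
```

With threshold_byloss enabled, a batch where every example is "easy" (all losses below the threshold) is skipped rather than averaged, which keeps gradient updates focused on informative samples.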

evaluate_batch Method

Signature

def evaluate_batch(self, batch, stage):

Parameters

Parameter Type Description
batch PaddedBatch Same structure as fit_batch input
stage sb.Stage Either sb.Stage.VALID or sb.Stage.TEST

Output

Returns a scalar torch.Tensor (detached) representing the mean loss for the batch.

Implementation

def evaluate_batch(self, batch, stage):
    snt_id = batch.id
    mixture = batch.mix_sig
    targets = [batch.s1_sig, batch.s2_sig]
    if self.hparams.num_spks == 3:
        targets.append(batch.s3_sig)

    with torch.no_grad():
        predictions, targets = self.compute_forward(mixture, targets, stage)
        loss = self.compute_objectives(predictions, targets)

    # Manage audio file saving
    if stage == sb.Stage.TEST and self.hparams.save_audio:
        if hasattr(self.hparams, "n_audio_to_save"):
            if self.hparams.n_audio_to_save > 0:
                self.save_audio(snt_id[0], mixture, targets, predictions)
                self.hparams.n_audio_to_save += -1
        else:
            self.save_audio(snt_id[0], mixture, targets, predictions)

    return loss.mean().detach()
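The audio-saving bookkeeping above can be factored into a small pure function. A minimal sketch, in which should_save is a hypothetical helper (not part of the recipe) while save_audio and n_audio_to_save mirror the hparams used in evaluate_batch:

```python
def should_save(stage_is_test, save_audio, n_audio_to_save=None):
    """Decide whether to save this batch's audio, mirroring the logic
    in evaluate_batch: save only at test time, and when a counter is
    configured, only while it remains positive.

    Returns (save_now, updated_counter)."""
    if not (stage_is_test and save_audio):
        return False, n_audio_to_save
    if n_audio_to_save is None:
        # No counter configured: save audio for every batch.
        return True, None
    if n_audio_to_save > 0:
        # Counter configured and positive: save, then decrement.
        return True, n_audio_to_save - 1
    return False, n_audio_to_save
```

The caller would store the updated counter back into hparams, just as evaluate_batch decrements self.hparams.n_audio_to_save in place.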

Usage Example

import sys
import speechbrain as sb
from hyperpyyaml import load_hyperpyyaml

# Load configuration
hparams_file, run_opts, overrides = sb.parse_arguments(sys.argv[1:])
with open(hparams_file, encoding="utf-8") as fin:
    hparams = load_hyperpyyaml(fin, overrides)

# Initialize the Separation Brain
separator = Separation(
    modules=hparams["modules"],
    opt_class=hparams["optimizer"],
    hparams=hparams,
    run_opts=run_opts,
    checkpointer=hparams["checkpointer"],
)

# Re-initialize parameters (if not using pretrained model)
for module in separator.modules.values():
    separator.reset_layer_recursively(module)

# Training
separator.fit(
    separator.hparams.epoch_counter,
    train_data,
    valid_data,
    train_loader_kwargs=hparams["dataloader_opts"],
    valid_loader_kwargs=hparams["dataloader_opts"],
)

# Evaluation
separator.evaluate(test_data, min_key="si-snr")

Key Implementation Details

  • Mixed precision: The training context (self.training_ctx) and self.scaler (GradScaler) handle fp16/bf16 mixed-precision training
  • Gradient clipping order: Gradients must be unscaled before clipping when using mixed precision, hence self.scaler.unscale_(self.optimizer) is called before clip_grad_norm_
  • Nonfinite counter: The self.nonfinite_count attribute tracks how many batches have been skipped, useful for diagnosing data quality issues
  • Loss zeroing: When a batch is skipped, the loss tensor data is set to 0.0 to avoid corrupting logging statistics
  • Audio saving control: The n_audio_to_save counter decrements with each saved example, providing precise control over disk usage during evaluation
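The unscale-before-clip ordering can be illustrated without torch. In the framework-free sketch below (all names are illustrative, not SpeechBrain or PyTorch APIs), clipping the still-scaled gradients against max_norm compares the wrong magnitude, so the loss scale must be divided out first:

```python
import math


def clip_by_norm(grads, max_norm):
    """Scale gradients down so their global L2 norm is at most max_norm."""
    total = math.sqrt(sum(g * g for g in grads))
    if total > max_norm:
        factor = max_norm / total
        grads = [g * factor for g in grads]
    return grads


# Gradients as a GradScaler-style loss scaler sees them: the true
# gradients multiplied by the loss scale.
scale = 1024.0
true_grads = [3.0, 4.0]                 # global L2 norm is 5.0
scaled = [g * scale for g in true_grads]

# Correct order: unscale first, then clip against max_norm.
unscaled = [g / scale for g in scaled]
clipped = clip_by_norm(unscaled, max_norm=1.0)   # resulting norm: 1.0

# Wrong order: clipping the scaled gradients compares 5120.0 to
# max_norm, so after unscaling the effective norm is only 1/1024.
wrong = [g / scale for g in clip_by_norm(scaled, max_norm=1.0)]
```

This is why fit_batch calls self.scaler.unscale_(self.optimizer) before torch.nn.utils.clip_grad_norm_: otherwise clip_grad_norm would effectively shrink with the current loss scale.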

Source File

recipes/LibriMix/separation/train.py
