

Principle:Speechbrain Speechbrain Custom Batch Training For Separation

From Leeroopedia


Principle Name: Custom_Batch_Training_For_Separation
Domain(s): Training_Strategy, Speech_Separation
Description: Overriding the default training loop to handle non-standard training procedures like gradient clipping and sample filtering
Related Implementation: Implementation:Speechbrain_Speechbrain_Separation_Fit_Batch

Overview

Speech separation models require custom training logic that goes beyond what a standard training loop provides. The Separation class in SpeechBrain overrides the default Brain.fit_batch() and Brain.evaluate_batch() methods to incorporate nonfinite loss handling, gradient clipping, loss-based sample filtering, and optional audio saving during evaluation.

Theoretical Foundation

Why Custom Training Logic Is Needed

Standard supervised training loops assume that every batch produces a valid, finite loss and that all gradients are well-behaved. Speech separation models violate these assumptions in several ways:

  1. Nonfinite losses: Corrupted audio, silence, or numerical instability in SI-SNR computation can produce NaN or infinite losses. These must be detected and the offending batch must be skipped.
  2. Gradient explosion: The self-attention mechanism in transformer-based models like SepFormer is prone to gradient explosion, especially early in training. Gradient clipping is essential for stable convergence.
  3. Easy sample domination: In later stages of training, many examples become "too easy" (very high SI-SNR) and contribute little to learning. Filtering these out focuses training on harder, more informative examples.

Nonfinite Loss Handling

When a batch produces a loss that is nonfinite (NaN or infinity) or exceeds a configured upper limit (loss_upper_lim), the training step is skipped entirely:

  • The gradient is not computed
  • The optimizer is not stepped
  • A counter (nonfinite_count) is incremented for monitoring
  • The loss is set to 0.0 to avoid corrupting running statistics

This approach is more robust than simply clamping the loss, because a clamped but still anomalous gradient could destabilize subsequent updates.
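The skip logic described above can be sketched as a standalone check. Note that `should_skip_batch` and the surrounding variable names are hypothetical names for this sketch, not SpeechBrain's API:

```python
import torch

def should_skip_batch(loss: torch.Tensor, loss_upper_lim: float = 999999.0) -> bool:
    """True when the step must be skipped: nonfinite loss or loss above the limit."""
    return (not torch.isfinite(loss).all()) or bool((loss > loss_upper_lim).any())

nonfinite_count = 0
loss = torch.tensor(float("nan"))
if should_skip_batch(loss):
    nonfinite_count += 1        # counter for monitoring
    loss = torch.tensor(0.0)    # keep running statistics clean; no backward/step
```

Because a NaN compares false against the upper limit, the finiteness check must come first; the `or` then catches finite-but-huge losses.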

Gradient Clipping

After computing the backward pass but before the optimizer step, gradients are clipped to a maximum norm:

if self.hparams.clip_grad_norm >= 0:
    self.scaler.unscale_(self.optimizer)
    torch.nn.utils.clip_grad_norm_(
        self.modules.parameters(),
        self.hparams.clip_grad_norm,
    )

The default clipping value is 5.0. The unscale_ call is necessary when using mixed-precision training (via GradScaler) because gradients must be in their true scale before clipping.
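To make the ordering concrete, here is a minimal runnable sketch of an AMP-style step on a toy model. The model, optimizer, and CPU-only `GradScaler(enabled=False)` are illustrative assumptions for this sketch, not SpeechBrain code:

```python
import torch

torch.manual_seed(0)
model = torch.nn.Linear(8, 1)
opt = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = torch.cuda.amp.GradScaler(enabled=False)  # disabled so the demo runs on CPU
clip_grad_norm = 5.0

x, y = torch.randn(4, 8), torch.randn(4, 1)
loss = torch.nn.functional.mse_loss(model(x), y)
scaler.scale(loss).backward()              # backward pass (loss is scaled when AMP is on)
if clip_grad_norm >= 0:
    scaler.unscale_(opt)                   # gradients back to true scale before clipping
    total_norm = torch.nn.utils.clip_grad_norm_(
        model.parameters(), clip_grad_norm
    )
scaler.step(opt)                           # optimizer step (skipped on nonfinite grads)
scaler.update()
opt.zero_grad()
```

`clip_grad_norm_` returns the pre-clip total norm, which is useful to log for spotting gradient explosions early in training.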

Threshold-by-Loss Filtering

When threshold_byloss is enabled, the per-example losses within a batch are filtered to include only those exceeding a threshold value (default: -30). Since the training loss is the negative SI-SNR, this keeps examples whose SI-SNR is still below 30 dB:

if self.hparams.threshold_byloss:
    th = self.hparams.threshold
    loss = loss[loss > th]
    if loss.nelement() > 0:
        loss = loss.mean()

This mechanism effectively implements curriculum-like training by focusing the model on examples it has not yet mastered. If all examples in a batch are below the threshold (i.e., all are "easy"), the batch is treated similarly to a nonfinite loss and skipped.
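The effect of the filter can be seen on a toy batch of per-example loss values (the numbers are invented for illustration):

```python
import torch

threshold = -30.0  # default threshold on the negative-SI-SNR loss
# Per-example losses: the first and third examples are "easy" (below the threshold).
per_example = torch.tensor([-35.0, -10.0, -32.0, -5.0])

kept = per_example[per_example > threshold]   # keep only the hard examples
if kept.nelement() > 0:
    loss = kept.mean()                        # average over hard examples only
else:
    loss = torch.tensor(0.0)                  # all easy: treat like a nonfinite batch, skip the step
```

Here only the two hard examples survive, so the batch loss is their mean rather than being diluted by already-mastered examples.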

Custom Evaluation Logic

The evaluate_batch() override adds:

  • Conditional audio saving: During the test stage, separated audio can be saved to disk for listening tests and qualitative evaluation
  • Configurable save count: The n_audio_to_save parameter limits how many examples are saved, preventing disk space issues on large test sets
  • Gradient-free computation: All evaluation is wrapped in torch.no_grad() for memory efficiency
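A minimal sketch of this evaluation pattern (`evaluate_batch_sketch` and the toy model are hypothetical; the real logic lives in the Separation class's evaluate_batch override):

```python
import torch

def evaluate_batch_sketch(model, batches, save_audio=False, n_audio_to_save=2):
    """Gradient-free evaluation; saves at most n_audio_to_save separated examples."""
    saved, losses = 0, []
    with torch.no_grad():                          # no autograd graph during eval
        for mixture, target in batches:
            est = model(mixture)
            losses.append(torch.nn.functional.mse_loss(est, target).item())
            if save_audio and saved < n_audio_to_save:
                saved += 1                         # real code would write est to disk here
    return sum(losses) / len(losses), saved

model = torch.nn.Identity()
batches = [(torch.randn(1, 100), torch.randn(1, 100)) for _ in range(5)]
mean_loss, n_saved = evaluate_batch_sketch(model, batches, save_audio=True)
```

Capping the save count keeps qualitative inspection cheap even when the test set has thousands of mixtures.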

Training Flow

The complete custom training flow for a single batch:

  1. Unpack batch: extract mixture, source signals, and optional noise
  2. Forward pass: encoder, mask net, decoder (within mixed-precision context)
  3. Compute SI-SNR loss with PIT
  4. Apply threshold-by-loss filtering (if enabled)
  5. Check for nonfinite or upper-limit-exceeding loss
  6. If loss is valid: backward pass, gradient clipping, optimizer step
  7. If loss is invalid: increment nonfinite counter, set loss to zero
  8. Zero gradients
  9. Return detached loss for logging
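The nine steps can be condensed into a schematic single-batch step. This is a sketch under the defaults listed in the Key Parameters table, using a mean-squared-error stand-in for the SI-SNR/PIT loss; it is not SpeechBrain's exact implementation:

```python
import torch

def fit_batch_sketch(model, opt, mixture, target,
                     threshold=-30.0, clip_grad_norm=5.0, loss_upper_lim=999999.0):
    """Schematic single-batch training step mirroring the flow above."""
    est = model(mixture)                              # 2. forward pass
    per_example = ((est - target) ** 2).mean(dim=-1)  # 3. stand-in for SI-SNR + PIT
    kept = per_example[per_example > threshold]       # 4. threshold-by-loss filtering
    loss = kept.mean() if kept.nelement() > 0 else torch.tensor(float("nan"))
    if torch.isfinite(loss) and loss < loss_upper_lim:  # 5. validity check
        loss.backward()                               # 6. backward, clip, step
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip_grad_norm)
        opt.step()
    else:
        loss = torch.tensor(0.0)                      # 7. invalid: zero loss (counter omitted)
    opt.zero_grad()                                   # 8. zero gradients
    return loss.detach()                              # 9. detached loss for logging

model = torch.nn.Linear(16, 16)
opt = torch.optim.SGD(model.parameters(), lr=0.01)
batch_loss = fit_batch_sketch(model, opt, torch.randn(4, 16), torch.randn(4, 16))
```

Unpacking (step 1) and the mixed-precision context are omitted here to keep the control flow of steps 4 through 9 visible.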

Key Parameters

Parameter         Type    Default     Description
threshold_byloss  bool    True        Enable loss-based sample filtering
threshold         float   -30         SI-SNR threshold for sample filtering (dB)
clip_grad_norm    float   5           Maximum gradient norm for clipping
loss_upper_lim    float   999999      Upper limit for acceptable loss values
save_audio        bool    False       Whether to save separated audio during test
n_audio_to_save   int     (optional)  Maximum number of audio examples to save
