Principle:Speechbrain Speechbrain Custom Batch Training For Separation
| Field | Value |
|---|---|
| Principle Name | Custom_Batch_Training_For_Separation |
| Domain(s) | Training_Strategy, Speech_Separation |
| Description | Overriding the default training loop to handle non-standard training procedures like gradient clipping and sample filtering |
| Related Implementation | Implementation:Speechbrain_Speechbrain_Separation_Fit_Batch |
Overview
Speech separation models require custom training logic that goes beyond what a standard training loop provides. The Separation class in SpeechBrain overrides the default Brain.fit_batch() and Brain.evaluate_batch() methods to incorporate nonfinite loss handling, gradient clipping, loss-based sample filtering, and optional audio saving during evaluation.
Theoretical Foundation
Why Custom Training Logic Is Needed
Standard supervised training loops assume that every batch produces a valid, finite loss and that all gradients are well-behaved. Speech separation models violate these assumptions in several ways:
- Nonfinite losses: Corrupted audio, silence, or numerical instability in SI-SNR computation can produce NaN or infinite losses. These must be detected and the offending batch must be skipped.
- Gradient explosion: The self-attention mechanism in transformer-based models like SepFormer is prone to gradient explosion, especially early in training. Gradient clipping is essential for stable convergence.
- Easy sample domination: In later stages of training, many examples become "too easy" (very high SI-SNR) and contribute little to learning. Filtering these out focuses training on harder, more informative examples.
Nonfinite Loss Handling
When a batch produces a loss that is nonfinite (NaN or infinity) or exceeds a configured upper limit (loss_upper_lim), the training step is skipped entirely:
- The gradient is not computed
- The optimizer is not stepped
- A counter (`nonfinite_count`) is incremented for monitoring
- The loss is set to 0.0 to avoid corrupting running statistics
This approach is more robust than simply clamping the loss, because a clamped but still anomalous gradient could destabilize subsequent updates.
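The skip logic can be sketched in a few lines. This is a minimal stand-in, not SpeechBrain's actual API; the names `loss_upper_lim` and `nonfinite_count` mirror the hyperparameter and counter described above.

```python
import math

def handle_loss(loss: float, loss_upper_lim: float, nonfinite_count: int):
    """Return (loss_for_logging, should_step, nonfinite_count).

    Sketch of the nonfinite-loss check: a NaN/inf loss, or one above
    loss_upper_lim, skips the update and is logged as 0.0.
    """
    if not math.isfinite(loss) or loss > loss_upper_lim:
        # Skip the batch entirely: no backward pass, no optimizer step.
        return 0.0, False, nonfinite_count + 1
    return loss, True, nonfinite_count
```

A caller would only run `loss.backward()` and `optimizer.step()` when `should_step` is true.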
Gradient Clipping
After computing the backward pass but before the optimizer step, gradients are clipped to a maximum norm:
```python
if self.hparams.clip_grad_norm >= 0:
    self.scaler.unscale_(self.optimizer)
    torch.nn.utils.clip_grad_norm_(
        self.modules.parameters(),
        self.hparams.clip_grad_norm,
    )
```
The default clipping value is 5.0. The unscale_ call is necessary when using mixed-precision training (via GradScaler) because gradients must be in their true scale before clipping.
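To make the clipping operation itself concrete, here is a pure-Python sketch of what norm-based clipping does to a flat gradient vector: if the L2 norm exceeds the maximum, every component is rescaled by the same factor. This mirrors the behavior of `torch.nn.utils.clip_grad_norm_` but is a simplified stand-in, not its implementation.

```python
import math

def clip_grad_norm(grads, max_norm, eps=1e-6):
    """Rescale `grads` so their L2 norm is at most `max_norm`.

    Returns (clipped_grads, original_norm). The small eps guards
    against division issues, as in PyTorch's implementation.
    """
    total_norm = math.sqrt(sum(g * g for g in grads))
    if total_norm > max_norm:
        scale = max_norm / (total_norm + eps)
        grads = [g * scale for g in grads]
    return grads, total_norm
```

The key point is that the direction of the gradient is preserved; only its magnitude is capped.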
Threshold-by-Loss Filtering
When `threshold_byloss` is enabled, the per-example losses within a batch are filtered to retain only those exceeding a threshold value (default: -30 dB SI-SNR):
```python
if self.hparams.threshold_byloss:
    th = self.hparams.threshold
    loss = loss[loss > th]
    if loss.nelement() > 0:
        loss = loss.mean()
```
This mechanism effectively implements curriculum-like training by focusing the model on examples it has not yet mastered. If all examples in a batch are below the threshold (i.e., all are "easy"), the batch is treated similarly to a nonfinite loss and skipped.
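The filtering plus the all-easy skip can be sketched as a small pure-Python function (a stand-in for the tensor masking shown above, with hypothetical names):

```python
def filter_and_reduce(per_example_losses, threshold):
    """Keep only losses above `threshold`, then average them.

    Returns (mean_loss, has_hard_examples). If every example is
    below the threshold, the batch should be skipped, analogous
    to the nonfinite-loss case.
    """
    hard = [l for l in per_example_losses if l > threshold]
    if hard:
        return sum(hard) / len(hard), True
    return 0.0, False  # all examples "easy": skip this batch
```

Note the asymmetry: a mostly easy batch still trains on its few hard examples, while a fully easy batch is skipped outright.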
Custom Evaluation Logic
The evaluate_batch() override adds:
- Conditional audio saving: During the test stage, separated audio can be saved to disk for listening tests and qualitative evaluation
- Configurable save count: The `n_audio_to_save` parameter limits how many examples are saved, preventing disk-space issues on large test sets
- Gradient-free computation: All evaluation is wrapped in `torch.no_grad()` for memory efficiency
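The save-count cap can be illustrated with a small sketch. The function name and the callback parameter are hypothetical, standing in for whatever writes the separated audio to disk:

```python
def maybe_save_audio(batch_id, saved_so_far, save_audio, n_audio_to_save, save_fn):
    """Save one example via `save_fn` only while under the cap.

    Returns the updated saved-example count, so the caller can
    thread it through the evaluation loop.
    """
    if save_audio and saved_so_far < n_audio_to_save:
        save_fn(batch_id)
        return saved_so_far + 1
    return saved_so_far
```

Once `n_audio_to_save` examples have been written, further batches are evaluated but not saved.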
Training Flow
The complete custom training flow for a single batch:
1. Unpack the batch: extract the mixture, source signals, and optional noise
2. Forward pass: encoder, masking network, decoder (within the mixed-precision context)
3. Compute the SI-SNR loss with PIT
4. Apply threshold-by-loss filtering (if enabled)
5. Check for a nonfinite or upper-limit-exceeding loss
6. If the loss is valid: backward pass, gradient clipping, optimizer step
7. If the loss is invalid: increment the nonfinite counter, set the loss to zero
8. Zero the gradients
9. Return the detached loss for logging
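The flow above can be sketched end to end. This is a simplified driver with stand-in callables for the forward pass and optimizer; none of the names are SpeechBrain's actual API.

```python
import math

def fit_batch(batch, forward, backward_and_step, zero_grad,
              threshold=-30.0, loss_upper_lim=999999.0, state=None):
    """Sketch of one custom training step: filter, validate, update."""
    state = state or {"nonfinite_count": 0}
    per_example = forward(batch)                       # forward pass + PIT loss
    hard = [l for l in per_example if l > threshold]   # threshold-by-loss filter
    loss = sum(hard) / len(hard) if hard else float("nan")
    if math.isfinite(loss) and loss < loss_upper_lim:  # validity check
        backward_and_step(loss)                        # backward, clip, step
    else:
        state["nonfinite_count"] += 1                  # skip invalid batch
        loss = 0.0
    zero_grad()
    return loss, state                                 # detached loss for logging
```

An all-easy batch yields an empty filtered list, which the NaN check then routes into the skip branch, matching the behavior described earlier.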
Key Parameters
| Parameter | Type | Default | Description |
|---|---|---|---|
| `threshold_byloss` | bool | True | Enable loss-based sample filtering |
| `threshold` | float | -30 | SI-SNR threshold for sample filtering (dB) |
| `clip_grad_norm` | float | 5 | Maximum gradient norm for clipping |
| `loss_upper_lim` | float | 999999 | Upper limit for acceptable loss values |
| `save_audio` | bool | False | Whether to save separated audio during test |
| `n_audio_to_save` | int | (optional) | Maximum number of audio examples to save |
See Also
- Implementation:Speechbrain_Speechbrain_Separation_Fit_Batch
- Heuristic:Speechbrain_Speechbrain_Gradient_Clipping_Strategy
- Heuristic:Speechbrain_Speechbrain_Separation_Loss_Filtering
- Heuristic:Speechbrain_Speechbrain_Nonfinite_Loss_Handling
- Principle:Speechbrain_Speechbrain_Permutation_Invariant_Training
- Principle:Speechbrain_Speechbrain_SepFormer_Model_Configuration