Heuristic: SpeechBrain Separation Loss Filtering
| Knowledge Sources | |
|---|---|
| Domains | Speech_Separation, Optimization |
| Last Updated | 2026-02-09 20:00 GMT |
Overview
Hard-example mining technique that filters out "easy" samples whose SI-SNR loss is already better than -30 dB, and skips batches whose loss exceeds 999999 (indicating catastrophic model output).
Description
SpeechBrain separation recipes implement two complementary loss-filtering mechanisms. First, threshold-by-loss removes samples whose SI-SNR loss is already very good (below -30 dB, i.e., SI-SNR above 30 dB), focusing the gradient budget on harder examples. Second, a loss upper limit skips entire batches where the loss exceeds 999999 (indicating catastrophic model output), preventing destructive gradient updates. Together, these mechanisms stabilize separation training and accelerate convergence.
Usage
Apply when training speech separation models (SepFormer, Conv-TasNet, DPTNet) on WSJ0Mix, LibriMix, WHAMandWHAMR, Aishell1Mix, or DNS datasets. Threshold-by-loss is enabled by default in all separation YAML configs via `threshold_byloss: True` and `threshold: -30`.
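The three hyperparameter keys named above would appear in a recipe's YAML file roughly as follows (key names and values are from this card; the comments and surrounding layout are illustrative, not copied from any specific recipe):

```yaml
# Loss filtering (separation recipe hyperparameters)
threshold_byloss: True    # enable hard-example mining on per-sample losses
threshold: -30            # drop samples with SI-SNR loss <= -30 dB
loss_upper_lim: 999999    # skip any batch whose loss reaches this value
```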
The Insight (Rule of Thumb)
- Action: Set `threshold_byloss: True` and `threshold: -30` in separation YAML config. Set `loss_upper_lim: 999999`.
- Value: threshold = -30 dB SI-SNR; loss_upper_lim = 999999
- Trade-off: More aggressive thresholds (e.g., -10) may discard too many samples and cause empty batches. The -30 dB value is conservative enough to only filter already-excellent separations.
Reasoning
In separation training, many samples in a batch may already be well-separated (especially in later epochs). Computing gradients for these "easy" samples wastes compute and dilutes the gradient signal from difficult examples. The -30 dB threshold was chosen empirically: an SI-SNR loss of -30 dB or lower (i.e., SI-SNR of 30 dB or higher) indicates near-perfect separation that provides negligible learning signal. The loss upper limit of 999999 catches rare but destructive cases where the model produces catastrophic output (e.g., all-zero or all-noise predictions), whose enormous loss values would corrupt the model weights if backpropagated.
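The two filters can be sketched without PyTorch at all. This is a minimal torch-free illustration of the logic, not SpeechBrain code: the constants mirror the YAML keys `threshold` and `loss_upper_lim`, while the function name and the toy loss values are invented for the example.

```python
# Torch-free sketch of SpeechBrain-style loss filtering.
# `THRESHOLD` and `LOSS_UPPER_LIM` mirror the YAML keys `threshold` and
# `loss_upper_lim`; `filter_batch_loss` is a hypothetical helper.

THRESHOLD = -30.0          # keep only samples with SI-SNR loss above this
LOSS_UPPER_LIM = 999999.0  # skip the whole batch at or beyond this value

def filter_batch_loss(per_sample_losses):
    """Return (mean_loss, do_backward) after applying both filters."""
    # 1) threshold-by-loss: drop "easy" samples (loss <= -30 dB,
    #    i.e. SI-SNR >= 30 dB, near-perfect separation)
    kept = [l for l in per_sample_losses if l > THRESHOLD]
    if not kept:
        return 0.0, False  # every sample was easy: skip this batch
    mean_loss = sum(kept) / len(kept)
    # 2) loss upper limit: skip catastrophic batches entirely
    if mean_loss >= LOSS_UPPER_LIM:
        return 0.0, False
    return mean_loss, True

# Batch with two easy samples (-35, -42 dB) and two hard ones (-5, 3 dB):
loss, do_bw = filter_batch_loss([-35.0, -5.0, 3.0, -42.0])
# Gradient comes only from the hard samples: mean of (-5, 3) = -1.0
```

Note how an all-easy batch and a catastrophic batch are handled the same way: the batch contributes no gradient at all, which is exactly what the skip branch in the recipe code below does.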
Code from `recipes/WSJ0Mix/separation/train.py:116-142`:
```python
# hard threshold the easy dataitems
if self.hparams.threshold_byloss:
    th = self.hparams.threshold
    loss = loss[loss > th]
    if loss.nelement() > 0:
        loss = loss.mean()
else:
    loss = loss.mean()

if loss.nelement() > 0 and loss < self.hparams.loss_upper_lim:
    self.scaler.scale(loss).backward()
else:
    self.nonfinite_count += 1
    logger.info(
        "infinite loss or empty loss! it happened %d times so far"
        " - skipping this batch", self.nonfinite_count
    )
    loss.data = torch.tensor(0.0).to(self.device)
```