Implementation:Speechbrain Speechbrain Separation Fit Batch
| Field | Value |
|---|---|
| Implementation Name | Separation_Fit_Batch |
| API | `Separation.fit_batch(self, batch)` and `Separation.evaluate_batch(self, batch, stage)` |
| Source | recipes/LibriMix/separation/train.py:L44 (class), L114-163 (fit_batch), L165-186 (evaluate_batch) |
| Import | Recipe-specific, part of recipes/LibriMix/separation/train.py |
| Type | API Doc |
| Related Principle | Principle:Speechbrain_Speechbrain_Custom_Batch_Training_For_Separation |
Purpose
The Separation class extends speechbrain.Brain to implement custom training and evaluation logic for speech separation. The fit_batch() method handles gradient clipping, nonfinite loss detection, and loss-based sample filtering. The evaluate_batch() method adds optional audio saving during the test stage.
Class Definition
```python
class Separation(sb.Brain):
    def compute_forward(self, mix, targets, stage, noise=None):
        """Forward computations from the mixture to the separated signals."""
        ...

    def compute_objectives(self, predictions, targets):
        """Computes the si-snr loss"""
        return self.hparams.loss(targets, predictions)

    def fit_batch(self, batch):
        """Trains one batch"""
        ...

    def evaluate_batch(self, batch, stage):
        """Computations needed for validation/test batches"""
        ...
```
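The per-example losses that `compute_objectives` returns come from `self.hparams.loss`, which in this recipe is SpeechBrain's SI-SNR loss with a PIT wrapper. As a rough illustration only, a stand-alone negative SI-SNR for a single source pair might look like the sketch below (this is an assumption for clarity, not the recipe's actual loss, which also resolves source permutations):

```python
import torch

def si_snr_loss(target, estimate, eps=1e-8):
    """Negative SI-SNR per example for [B, T] signals (illustrative sketch)."""
    # Zero-mean both signals along the time axis
    target = target - target.mean(dim=-1, keepdim=True)
    estimate = estimate - estimate.mean(dim=-1, keepdim=True)
    # Project the estimate onto the target to isolate the "signal" component
    s_target = (estimate * target).sum(-1, keepdim=True) * target / (
        target.pow(2).sum(-1, keepdim=True) + eps
    )
    e_noise = estimate - s_target
    ratio = s_target.pow(2).sum(-1) / (e_noise.pow(2).sum(-1) + eps)
    # Negate so that lower loss = better separation, shape [B]
    return -10 * torch.log10(ratio + eps)

t = torch.randn(2, 16000)
loss = si_snr_loss(t, t)  # perfect estimate -> strongly negative loss
```

Note that the loss is returned per example (shape `[B]`), not reduced to a scalar; `fit_batch` relies on this for its threshold filtering step.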
fit_batch Method
Signature
```python
def fit_batch(self, batch):
```
Parameters
| Parameter | Type | Description |
|---|---|---|
| `batch` | PaddedBatch | A batch object with attributes `mix_sig`, `s1_sig`, `s2_sig`, and optionally `s3_sig` and `noise_sig` |
Inputs
| Attribute | Type | Description |
|---|---|---|
| `batch.mix_sig` | (Tensor, Tensor) | Mixture waveform tensor [B, T] and relative lengths [B] |
| `batch.s1_sig` | (Tensor, Tensor) | First speaker source and lengths |
| `batch.s2_sig` | (Tensor, Tensor) | Second speaker source and lengths |
| `batch.s3_sig` | (Tensor, Tensor) | (optional) Third speaker source and lengths |
| `batch.noise_sig` | (Tensor, Tensor) | (optional) Noise signal and lengths |
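Each attribute above is a pair of a padded waveform tensor and relative lengths in (0, 1]. A hypothetical collate helper (not part of the recipe; SpeechBrain's `PaddedBatch` handles this internally) shows how such pairs are formed:

```python
import torch

def collate_padded(waveforms):
    """Pad variable-length 1-D waveforms to [B, T] and return relative lengths.

    Illustrative stand-in for the (signal, lengths) tuples held by
    attributes such as batch.mix_sig.
    """
    max_len = max(w.shape[0] for w in waveforms)
    batch = torch.zeros(len(waveforms), max_len)
    lengths = torch.zeros(len(waveforms))
    for i, w in enumerate(waveforms):
        batch[i, : w.shape[0]] = w          # zero-pad to the longest signal
        lengths[i] = w.shape[0] / max_len   # relative length in (0, 1]
    return batch, lengths

mix, lens = collate_padded([torch.randn(16000), torch.randn(8000)])
# mix has shape [2, 16000]; lens is tensor([1.0, 0.5])
```

The relative lengths let downstream code mask out padding when computing losses or statistics.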
Output
Returns a scalar torch.Tensor (detached, on CPU) representing the mean loss for the batch.
Implementation
```python
def fit_batch(self, batch):
    # Unpacking batch list
    mixture = batch.mix_sig
    targets = [batch.s1_sig, batch.s2_sig]
    if self.hparams.use_wham_noise:
        noise = batch.noise_sig[0]
    else:
        noise = None
    if self.hparams.num_spks == 3:
        targets.append(batch.s3_sig)

    with self.training_ctx:
        predictions, targets = self.compute_forward(
            mixture, targets, sb.Stage.TRAIN, noise
        )
        loss = self.compute_objectives(predictions, targets)

        # Hard threshold the easy data items
        if self.hparams.threshold_byloss:
            th = self.hparams.threshold
            loss = loss[loss > th]
            if loss.nelement() > 0:
                loss = loss.mean()
        else:
            loss = loss.mean()

    if loss.nelement() > 0 and loss < self.hparams.loss_upper_lim:
        self.scaler.scale(loss).backward()
        if self.hparams.clip_grad_norm >= 0:
            self.scaler.unscale_(self.optimizer)
            torch.nn.utils.clip_grad_norm_(
                self.modules.parameters(),
                self.hparams.clip_grad_norm,
            )
        self.scaler.step(self.optimizer)
        self.scaler.update()
    else:
        self.nonfinite_count += 1
        logger.info(
            "infinite loss or empty loss! it happened {} times so far "
            "- skipping this batch".format(self.nonfinite_count)
        )
        loss.data = torch.tensor(0.0).to(self.device)
    self.optimizer.zero_grad()

    return loss.detach().cpu()
```
Processing Flow
- Unpack batch: Extract mixture signal, source targets, and optional noise
- Forward pass: Run encoder, mask net, decoder within mixed-precision context
- Compute loss: SI-SNR with PIT wrapper, returning per-example losses [B]
- Threshold filtering: If enabled, keep only losses above the threshold
- Loss validation: Check that loss has elements and is below the upper limit
- Backward pass: Scale loss for mixed precision, compute gradients
- Gradient clipping: Unscale gradients, clip to max norm
- Optimizer step: Update model parameters
- Nonfinite handling: If loss is invalid, skip update, increment counter
- Zero gradients: Clear accumulated gradients for next batch
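Steps 4 and 5 of the flow above can be sketched in isolation. The threshold and upper-limit values here are illustrative assumptions; the recipe reads them from `hparams`:

```python
import torch

def filter_and_reduce(loss, threshold, upper_lim):
    """Drop per-example losses at or below `threshold`, reduce to a mean,
    and decide whether the batch should be backpropagated (sketch of
    steps 4-5; not the recipe's exact code)."""
    kept = loss[loss > threshold]
    # Fall back to the unfiltered mean if every example was filtered out
    loss = kept.mean() if kept.nelement() > 0 else loss.mean()
    # Step 5: only update when the loss is finite and below the cap
    do_backward = bool(torch.isfinite(loss)) and loss.item() < upper_lim
    return loss, do_backward

# SI-SNR losses: lower = better, so -35.0 is an "easy" example
per_example = torch.tensor([-35.0, -10.0, -5.0])
mean_loss, ok = filter_and_reduce(per_example, threshold=-30.0, upper_lim=999999)
# mean_loss is -7.5: the easy -35.0 example was excluded from the mean
```

Because lower SI-SNR loss means easier separation, keeping only losses *above* the threshold focuses gradient updates on the harder examples.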
evaluate_batch Method
Signature
```python
def evaluate_batch(self, batch, stage):
```
Parameters
| Parameter | Type | Description |
|---|---|---|
| `batch` | PaddedBatch | Same structure as the `fit_batch` input |
| `stage` | sb.Stage | Either `sb.Stage.VALID` or `sb.Stage.TEST` |
Output
Returns a scalar torch.Tensor (detached) representing the mean loss for the batch.
Implementation
```python
def evaluate_batch(self, batch, stage):
    snt_id = batch.id
    mixture = batch.mix_sig
    targets = [batch.s1_sig, batch.s2_sig]
    if self.hparams.num_spks == 3:
        targets.append(batch.s3_sig)

    with torch.no_grad():
        predictions, targets = self.compute_forward(mixture, targets, stage)
        loss = self.compute_objectives(predictions, targets)

    # Manage audio file saving
    if stage == sb.Stage.TEST and self.hparams.save_audio:
        if hasattr(self.hparams, "n_audio_to_save"):
            if self.hparams.n_audio_to_save > 0:
                self.save_audio(snt_id[0], mixture, targets, predictions)
                self.hparams.n_audio_to_save += -1
        else:
            self.save_audio(snt_id[0], mixture, targets, predictions)

    return loss.mean().detach()
```
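The `n_audio_to_save` bookkeeping above can be isolated in a small sketch (the class and method names here are hypothetical, used only to demonstrate the decrement pattern):

```python
class AudioSaveBudget:
    """Save at most N examples, then silently stop: the same budget
    pattern evaluate_batch applies via hparams.n_audio_to_save."""

    def __init__(self, n_audio_to_save):
        self.n_audio_to_save = n_audio_to_save
        self.saved = []

    def maybe_save(self, snt_id):
        if self.n_audio_to_save > 0:
            self.saved.append(snt_id)    # stands in for self.save_audio(...)
            self.n_audio_to_save += -1   # same decrement style as the recipe

budget = AudioSaveBudget(n_audio_to_save=2)
for sid in ["utt1", "utt2", "utt3"]:
    budget.maybe_save(sid)
# budget.saved == ["utt1", "utt2"]; the third example is skipped
```

When `n_audio_to_save` is absent from `hparams`, the recipe instead saves audio for every test batch.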
Usage Example
```python
import sys

import speechbrain as sb
from hyperpyyaml import load_hyperpyyaml

# Load configuration
hparams_file, run_opts, overrides = sb.parse_arguments(sys.argv[1:])
with open(hparams_file, encoding="utf-8") as fin:
    hparams = load_hyperpyyaml(fin, overrides)

# Initialize the Separation Brain
separator = Separation(
    modules=hparams["modules"],
    opt_class=hparams["optimizer"],
    hparams=hparams,
    run_opts=run_opts,
    checkpointer=hparams["checkpointer"],
)

# Re-initialize parameters (if not using a pretrained model)
for module in separator.modules.values():
    separator.reset_layer_recursively(module)

# Training
separator.fit(
    separator.hparams.epoch_counter,
    train_data,
    valid_data,
    train_loader_kwargs=hparams["dataloader_opts"],
    valid_loader_kwargs=hparams["dataloader_opts"],
)

# Evaluation
separator.evaluate(test_data, min_key="si-snr")
```
Key Implementation Details
- Mixed precision: the training context (`self.training_ctx`) and `self.scaler` (a `GradScaler`) handle fp16/bf16 mixed-precision training
- Gradient clipping order: with mixed precision, gradients must be unscaled before clipping, so `self.scaler.unscale_(self.optimizer)` is called before `clip_grad_norm_`
- Nonfinite counter: the `self.nonfinite_count` attribute tracks how many batches have been skipped, useful for diagnosing data quality issues
- Loss zeroing: when a batch is skipped, the loss tensor's data is set to 0.0 to avoid corrupting logging statistics
- Audio saving control: the `n_audio_to_save` counter decrements with each saved example, giving precise control over disk usage during evaluation
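The unscale-then-clip ordering matters because `clip_grad_norm_` must see true gradient magnitudes rather than scaled ones. The clipping step itself can be sketched in plain fp32, without the `GradScaler` (model and loss here are illustrative stand-ins):

```python
import torch

# Toy model and a deliberately large loss to produce oversized gradients
model = torch.nn.Linear(4, 1)
opt = torch.optim.SGD(model.parameters(), lr=0.1)

loss = model(torch.randn(8, 4)).pow(2).mean() * 1000.0
loss.backward()

# clip_grad_norm_ returns the pre-clip global L2 norm and rescales
# all gradients in place so the global norm does not exceed max_norm
pre_clip = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=5.0)
post_clip = torch.sqrt(sum(p.grad.pow(2).sum() for p in model.parameters()))
opt.step()
opt.zero_grad()
```

With AMP, the recipe inserts `scaler.unscale_(optimizer)` between `backward()` and `clip_grad_norm_`, then finishes the step with `scaler.step(...)` and `scaler.update()` instead of a plain `opt.step()`.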
Source File
recipes/LibriMix/separation/train.py
See Also
- Principle:Speechbrain_Speechbrain_Custom_Batch_Training_For_Separation
- Implementation:Speechbrain_Speechbrain_Load_Hyperpyyaml_SepFormer
- Implementation:Speechbrain_Speechbrain_Get_Si_Snr_With_Pitwrapper
Related Pages
- Principle:Speechbrain_Speechbrain_Custom_Batch_Training_For_Separation
- Environment:Speechbrain_Speechbrain_PyTorch_CUDA_Runtime
- Environment:Speechbrain_Speechbrain_Multi_GPU_DDP
- Heuristic:Speechbrain_Speechbrain_Gradient_Clipping_Strategy
- Heuristic:Speechbrain_Speechbrain_Separation_Loss_Filtering
- Heuristic:Speechbrain_Speechbrain_Nonfinite_Loss_Handling