
Implementation:Speechbrain Speechbrain Get Si Snr With Pitwrapper

From Leeroopedia


Field Value
Implementation Name Get_Si_Snr_With_Pitwrapper
API get_si_snr_with_pitwrapper(source, estimate_source), PitWrapper(base_loss), cal_si_snr(source, estimate_source)
Source speechbrain/nnet/losses.py:L988-1018 (get_si_snr_with_pitwrapper), L95-135 (PitWrapper), L1045-1093 (cal_si_snr)
Import from speechbrain.nnet.losses import get_si_snr_with_pitwrapper
Type API Doc
Related Principle Principle:Speechbrain_Speechbrain_Permutation_Invariant_Training

Purpose

The get_si_snr_with_pitwrapper function computes the Scale-Invariant Signal-to-Noise Ratio (SI-SNR) loss with Permutation Invariant Training (PIT). It automatically finds the optimal assignment of predicted sources to ground-truth sources by evaluating all permutations. This is the primary loss function used for training SepFormer and other speech separation models in SpeechBrain.

Function Signatures

get_si_snr_with_pitwrapper

def get_si_snr_with_pitwrapper(source, estimate_source):

PitWrapper

class PitWrapper(nn.Module):
    def __init__(self, base_loss):
        ...

    def forward(self, preds, targets):
        ...

cal_si_snr

def cal_si_snr(source, estimate_source):

Parameters

get_si_snr_with_pitwrapper

Parameter Type Shape Description
source torch.Tensor [B, T, C] Ground truth source signals. B=batch, T=time, C=number of sources.
estimate_source torch.Tensor [B, T, C] Estimated (predicted) source signals, same shape as source.

PitWrapper.__init__

Parameter Type Description
base_loss callable A loss function that takes predictions and targets with no reduction. For speech separation, this is cal_si_snr.

cal_si_snr

Parameter Type Shape Description
source torch.Tensor [T, B, C] Ground truth source signals (note: transposed from outer API)
estimate_source torch.Tensor [T, B, C] Estimated source signals

Outputs

get_si_snr_with_pitwrapper

Output Type Shape Description
loss torch.Tensor [B] Negative SI-SNR loss for each example in the batch, using the optimal source permutation

PitWrapper.forward

Output Type Shape Description
loss torch.Tensor [B] Optimal permutation-invariant loss per batch element
perms list[tuple] B tuples Optimal permutation indices for each batch element

cal_si_snr

Output Type Shape Description
si_snr torch.Tensor [1, B, C] Negative SI-SNR values (negated so that minimization maximizes SI-SNR)

Implementation Details

get_si_snr_with_pitwrapper

This is a convenience wrapper that instantiates PitWrapper with cal_si_snr and returns only the loss:

def get_si_snr_with_pitwrapper(source, estimate_source):
    pit_si_snr = PitWrapper(cal_si_snr)
    loss, perms = pit_si_snr(source, estimate_source)
    return loss

PitWrapper Permutation Search

The PitWrapper.forward() method computes the loss for each batch element independently:

def forward(self, preds, targets):
    losses = []
    perms = []
    for pred, label in zip(preds, targets):
        loss, perm = self._opt_perm_loss(pred, label)
        losses.append(loss)
        perms.append(perm)
    loss = torch.stack(losses)
    return loss, perms

The _opt_perm_loss method builds a pairwise loss matrix by expanding predictions and targets along orthogonal dimensions and applying the base loss function. _fast_pit then iterates over all permutations via itertools.permutations and keeps the one with the lowest mean loss:

def _fast_pit(self, loss_mat):
    loss = None
    assigned_perm = None
    for p in permutations(range(loss_mat.shape[0])):
        c_loss = loss_mat[range(loss_mat.shape[0]), p].mean()
        if loss is None or loss > c_loss:
            loss = c_loss
            assigned_perm = p
    return loss, assigned_perm
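The search above can be exercised in isolation. The sketch below (fast_pit is an illustrative copy of the _fast_pit logic, not the library method) picks the cheapest assignment from a hand-built 2x2 loss matrix in which the off-diagonal pairing is cheaper, so the optimal permutation swaps the two sources:

```python
from itertools import permutations

import torch

def fast_pit(loss_mat):
    # Brute-force search over all source permutations.
    # loss_mat[i, j] is the base loss between prediction i and target j.
    best_loss, best_perm = None, None
    for p in permutations(range(loss_mat.shape[0])):
        c_loss = loss_mat[range(loss_mat.shape[0]), p].mean()
        if best_loss is None or best_loss > c_loss:
            best_loss, best_perm = c_loss, p
    return best_loss, best_perm

# Hypothetical loss matrix for 2 sources
loss_mat = torch.tensor([[5.0, 1.0],
                         [2.0, 6.0]])
loss, perm = fast_pit(loss_mat)
print(loss, perm)  # tensor(1.5000) (1, 0)
```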

cal_si_snr Computation

The SI-SNR calculation follows the standard formula:

def cal_si_snr(source, estimate_source):
    EPS = 1e-8
    num_samples = source.shape[0]  # T (time-first layout)

    # Zero-mean normalization
    mean_target = torch.sum(source, dim=0, keepdim=True) / num_samples
    mean_estimate = torch.sum(estimate_source, dim=0, keepdim=True) / num_samples
    s_target = source - mean_target
    s_estimate = estimate_source - mean_estimate

    # Projection: proj = (<s_hat, s> / ||s||^2) * s
    dot = torch.sum(s_estimate * s_target, dim=0, keepdim=True)
    s_target_energy = torch.sum(s_target ** 2, dim=0, keepdim=True) + EPS
    proj = dot * s_target / s_target_energy

    # Noise: e_noise = s_hat - proj
    e_noise = s_estimate - proj

    # SI-SNR = 10 * log10(||proj||^2 / ||e_noise||^2)
    si_snr_beforelog = torch.sum(proj ** 2, dim=0) / (
        torch.sum(e_noise ** 2, dim=0) + EPS
    )
    si_snr = 10 * torch.log10(si_snr_beforelog + EPS)

    return -si_snr.unsqueeze(0)

Key details:

  • Both source and estimate are zero-mean normalized before computation
  • A small epsilon (1e-8) prevents division by zero
  • The result is negated so that minimizing the loss maximizes SI-SNR
  • Masking is applied to handle variable-length sequences within a batch
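The "scale-invariant" property can be checked numerically with a standalone 1-D implementation of the formula above (si_snr here is an illustrative helper, not the SpeechBrain function): scaling the estimate by any nonzero factor leaves the result unchanged, because the projection and the noise term scale by the same factor.

```python
import torch

def si_snr(source, estimate, eps=1e-8):
    # SI-SNR in dB for 1-D signals, following the formula above
    source = source - source.mean()
    estimate = estimate - estimate.mean()
    proj = (estimate @ source) / (source @ source + eps) * source
    e_noise = estimate - proj
    return 10 * torch.log10(proj.pow(2).sum() / (e_noise.pow(2).sum() + eps) + eps)

torch.manual_seed(0)
s = torch.randn(1000)
est = s + 0.1 * torch.randn(1000)  # noisy estimate

a = si_snr(s, est)
b = si_snr(s, 5.0 * est)  # rescaled estimate
print(torch.allclose(a, b, atol=1e-4))  # True: SI-SNR ignores the gain
```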

Usage Example

import torch
from speechbrain.nnet.losses import get_si_snr_with_pitwrapper

# Simulated ground truth: 2 sources, batch of 3, 100 time steps
source = torch.randn(3, 100, 2)

# Simulated predictions (sources in reversed order)
estimate_source = source[:, :, (1, 0)]

# Compute PIT-SI-SNR loss
loss = get_si_snr_with_pitwrapper(source, estimate_source)
print(loss)  # tensor([very_negative, very_negative, very_negative])
# Negative because sources match perfectly (just permuted)

# Negate for reporting
si_snr_db = -loss
print(si_snr_db)  # High positive values indicating perfect separation

Using PitWrapper Directly

import torch
from speechbrain.nnet.losses import PitWrapper, cal_si_snr

# Same setup as above: 2 sources with channel order swapped
source = torch.randn(3, 100, 2)
estimate_source = source[:, :, (1, 0)]

# Create PIT wrapper with SI-SNR as base loss
pit_loss = PitWrapper(cal_si_snr)

# Compute loss and get optimal permutations
loss, perms = pit_loss(source, estimate_source)
print(f"Loss: {loss}")
print(f"Optimal permutations: {perms}")

# Reorder predictions to match the optimal permutation
reordered = pit_loss.reorder_tensor(estimate_source, perms)

Key Implementation Details

  • PitWrapper is a general-purpose module that works with any unreduced loss function, not just SI-SNR
  • The permutation search complexity is O(C!) where C is the number of sources; this is efficient for C=2 (2 permutations) or C=3 (6 permutations)
  • cal_si_snr expects tensors in [T, B, C] format (time-first), while the outer API get_si_snr_with_pitwrapper accepts [B, T, C] format (batch-first); the PitWrapper handles the transposition internally
  • The reorder_tensor method on PitWrapper allows reordering predictions to match the optimal permutation, useful for downstream processing
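The per-element reordering can be sketched without the library (reorder here is a hypothetical stand-in for PitWrapper.reorder_tensor, assuming the permutation applies along the last, source axis):

```python
import torch

def reorder(tensor, perms):
    # Apply each batch element's optimal permutation along the source axis
    return torch.stack([t[..., list(p)] for t, p in zip(tensor, perms)])

est = torch.arange(12.0).reshape(2, 3, 2)  # [B=2, T=3, C=2]
perms = [(1, 0), (0, 1)]                   # swap channels for element 0 only
out = reorder(est, perms)
print(out[0, :, 0])  # element 0's channel 0 is now its original channel 1
```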

Source File

speechbrain/nnet/losses.py
