# Implementation:Speechbrain Speechbrain Get Si Snr With Pitwrapper
| Field | Value |
|---|---|
| Implementation Name | Get_Si_Snr_With_Pitwrapper |
| API | `get_si_snr_with_pitwrapper(source, estimate_source)`, `PitWrapper(base_loss)`, `cal_si_snr(source, estimate_source)` |
| Source | `speechbrain/nnet/losses.py`: L988-1018 (`get_si_snr_with_pitwrapper`), L95-135 (`PitWrapper`), L1045-1093 (`cal_si_snr`) |
| Import | `from speechbrain.nnet.losses import get_si_snr_with_pitwrapper` |
| Type | API Doc |
| Related Principle | Principle:Speechbrain_Speechbrain_Permutation_Invariant_Training |
## Purpose
The get_si_snr_with_pitwrapper function computes the Scale-Invariant Signal-to-Noise Ratio (SI-SNR) loss with Permutation Invariant Training (PIT). It automatically finds the optimal assignment of predicted sources to ground-truth sources by evaluating all permutations. This is the primary loss function used for training SepFormer and other speech separation models in SpeechBrain.
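To see why permutation invariance matters, consider a separator that recovers both sources perfectly but in swapped channel order. The following is a toy pure-Python sketch (using a hypothetical mean-squared-error as the pairwise loss, not the SpeechBrain API): a naive channel-by-channel loss penalizes the swap heavily, while a PIT-style loss evaluates both assignments and keeps the cheapest.

```python
# Toy illustration of permutation ambiguity (hypothetical MSE pairwise loss).
def mse(a, b):
    return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)

sources = [[1.0, 1.0, 1.0], [-1.0, -1.0, -1.0]]
# The separator recovered both sources, but in swapped order
estimates = [sources[1], sources[0]]

# Naive channel-by-channel loss penalizes the swap heavily
naive = sum(mse(s, e) for s, e in zip(sources, estimates)) / 2

# PIT evaluates both assignments and keeps the cheapest one
pit = min(
    (mse(sources[0], estimates[0]) + mse(sources[1], estimates[1])) / 2,
    (mse(sources[0], estimates[1]) + mse(sources[1], estimates[0])) / 2,
)
print(naive, pit)  # 4.0 0.0
```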
## Function Signatures

### get_si_snr_with_pitwrapper

```python
def get_si_snr_with_pitwrapper(source, estimate_source):
```

### PitWrapper

```python
class PitWrapper(nn.Module):
    def __init__(self, base_loss):
        ...
    def forward(self, preds, targets):
        ...
```

### cal_si_snr

```python
def cal_si_snr(source, estimate_source):
```
## Parameters

### get_si_snr_with_pitwrapper

| Parameter | Type | Shape | Description |
|---|---|---|---|
| `source` | torch.Tensor | [B, T, C] | Ground-truth source signals. B = batch, T = time, C = number of sources. |
| `estimate_source` | torch.Tensor | [B, T, C] | Estimated (predicted) source signals, same shape as `source`. |

### PitWrapper.__init__

| Parameter | Type | Description |
|---|---|---|
| `base_loss` | callable | A loss function that takes predictions and targets with no reduction. For speech separation, this is `cal_si_snr`. |

### cal_si_snr

| Parameter | Type | Shape | Description |
|---|---|---|---|
| `source` | torch.Tensor | [T, B, C] | Ground-truth source signals (note: transposed from the outer API). |
| `estimate_source` | torch.Tensor | [T, B, C] | Estimated source signals. |
## Outputs

### get_si_snr_with_pitwrapper

| Output | Type | Shape | Description |
|---|---|---|---|
| `loss` | torch.Tensor | [B] | Negative SI-SNR loss for each example in the batch, using the optimal source permutation. |

### PitWrapper.forward

| Output | Type | Shape | Description |
|---|---|---|---|
| `loss` | torch.Tensor | [B] | Optimal permutation-invariant loss per batch element. |
| `perms` | list[tuple] | B tuples | Optimal permutation indices for each batch element. |

### cal_si_snr

| Output | Type | Shape | Description |
|---|---|---|---|
| `si_snr` | torch.Tensor | [1, B, C] | Negative SI-SNR values (negated so that minimizing the loss maximizes SI-SNR). |
## Implementation Details

### get_si_snr_with_pitwrapper

This is a convenience wrapper that instantiates `PitWrapper` with `cal_si_snr` and returns only the loss:

```python
def get_si_snr_with_pitwrapper(source, estimate_source):
    pit_si_snr = PitWrapper(cal_si_snr)
    loss, perms = pit_si_snr(source, estimate_source)
    return loss
```
### PitWrapper Permutation Search

The `PitWrapper.forward()` method computes the loss for each batch element independently:

```python
def forward(self, preds, targets):
    losses = []
    perms = []
    for pred, label in zip(preds, targets):
        loss, perm = self._opt_perm_loss(pred, label)
        losses.append(loss)
        perms.append(perm)
    loss = torch.stack(losses)
    return loss, perms
```
The `_opt_perm_loss` method builds a pairwise loss matrix by expanding predictions and targets along orthogonal dimensions and applying the base loss function. `_fast_pit` then iterates over all permutations of that matrix using `itertools.permutations`, keeping the assignment with the lowest mean loss:

```python
def _fast_pit(self, loss_mat):
    loss = None
    assigned_perm = None
    for p in permutations(range(loss_mat.shape[0])):
        c_loss = loss_mat[range(loss_mat.shape[0]), p].mean()
        if loss is None or loss > c_loss:
            loss = c_loss
            assigned_perm = p
    return loss, assigned_perm
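The same search can be sketched in pure Python over a plain nested-list loss matrix (a hypothetical `fast_pit` helper, not the SpeechBrain method): entry `[i][j]` is the loss of pairing target `i` with estimate `j`, and the search returns the cheapest assignment.

```python
# Pure-Python sketch of the permutation search in _fast_pit (toy data).
from itertools import permutations

def fast_pit(loss_mat):
    """Return (best mean loss, best permutation) over all assignments."""
    n = len(loss_mat)
    best_loss, best_perm = None, None
    for p in permutations(range(n)):
        # Pair target i with estimate p[i] and average the pairwise losses
        c_loss = sum(loss_mat[i][p[i]] for i in range(n)) / n
        if best_loss is None or c_loss < best_loss:
            best_loss, best_perm = c_loss, p
    return best_loss, best_perm

# Toy 2x2 pairwise-loss matrix: the swapped assignment (1, 0) is cheaper
loss_mat = [[5.0, 1.0],
            [2.0, 6.0]]
print(fast_pit(loss_mat))  # (1.5, (1, 0))
```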
### cal_si_snr Computation

The SI-SNR calculation follows the standard formula. The excerpt below is simplified for readability (consistent variable names, masking of padded positions omitted):

```python
def cal_si_snr(source, estimate_source):
    EPS = 1e-8
    num_samples = source.shape[0]  # time dimension comes first: [T, B, C]

    # Zero-mean normalization along time
    mean_target = torch.sum(source, dim=0, keepdim=True) / num_samples
    mean_estimate = torch.sum(estimate_source, dim=0, keepdim=True) / num_samples
    s_target = source - mean_target
    s_estimate = estimate_source - mean_estimate

    # Projection: s_proj = (<s_hat, s> / ||s||^2) * s
    dot = torch.sum(s_estimate * s_target, dim=0, keepdim=True)
    s_target_energy = torch.sum(s_target ** 2, dim=0, keepdim=True) + EPS
    proj = dot * s_target / s_target_energy

    # Noise: e_noise = s_hat - s_proj
    e_noise = s_estimate - proj

    # SI-SNR = 10 * log10(||s_proj||^2 / ||e_noise||^2)
    si_snr_beforelog = torch.sum(proj ** 2, dim=0) / (
        torch.sum(e_noise ** 2, dim=0) + EPS
    )
    si_snr = 10 * torch.log10(si_snr_beforelog + EPS)
    return -si_snr.unsqueeze(0)
```
Key details:
- Both source and estimate are zero-mean normalized before computation
- A small epsilon (1e-8) prevents division by zero
- The result is negated so that minimizing the loss maximizes SI-SNR
- Masking is applied to handle variable-length sequences within a batch
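The formula above can be checked with a minimal pure-Python sketch for a single 1-D signal (a hypothetical `si_snr` helper, not the SpeechBrain API). Because the estimate is projected onto the target before the energy ratio is taken, rescaling the estimate leaves the result essentially unchanged:

```python
# Pure-Python sketch of SI-SNR for one 1-D signal (hypothetical helper).
import math

EPS = 1e-8

def si_snr(source, estimate):
    # Zero-mean both signals
    n = len(source)
    s = [x - sum(source) / n for x in source]
    s_hat = [x - sum(estimate) / n for x in estimate]
    # Project the estimate onto the target: (<s_hat, s> / ||s||^2) * s
    dot = sum(a * b for a, b in zip(s_hat, s))
    energy = sum(a * a for a in s) + EPS
    proj = [dot * a / energy for a in s]
    # Residual noise and the log energy ratio
    noise = [a - b for a, b in zip(s_hat, proj)]
    ratio = sum(a * a for a in proj) / (sum(a * a for a in noise) + EPS)
    return 10 * math.log10(ratio + EPS)

sig = [0.5, -1.0, 2.0, 0.25]
est = [0.4, -0.9, 2.2, 0.2]
scaled = [10 * x for x in est]
# Scale invariance: rescaling the estimate barely changes the SI-SNR
print(abs(si_snr(sig, est) - si_snr(sig, scaled)) < 1e-4)  # True
```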
## Usage Example

```python
import torch
from speechbrain.nnet.losses import get_si_snr_with_pitwrapper

# Simulated ground truth: batch of 3, 100 time steps, 2 sources -> [B, T, C]
source = torch.randn(3, 100, 2)

# Simulated predictions: the same sources in reversed order
estimate_source = source[:, :, (1, 0)]

# Compute the PIT SI-SNR loss
loss = get_si_snr_with_pitwrapper(source, estimate_source)
print(loss)  # strongly negative values: the sources match perfectly, just permuted

# Negate for reporting in dB
si_snr_db = -loss
print(si_snr_db)  # high positive values indicating perfect separation
```
## Using PitWrapper Directly

```python
import torch
from speechbrain.nnet.losses import PitWrapper, cal_si_snr

source = torch.randn(3, 100, 2)
estimate_source = source[:, :, (1, 0)]

# Create a PIT wrapper with SI-SNR as the base loss
pit_loss = PitWrapper(cal_si_snr)

# Compute the loss and recover the optimal permutations
loss, perms = pit_loss(source, estimate_source)
print(f"Loss: {loss}")
print(f"Optimal permutations: {perms}")

# Reorder the predictions to match the optimal permutation
reordered = pit_loss.reorder_tensor(estimate_source, perms)
```
## Key Implementation Details

- `PitWrapper` is a general-purpose module that works with any unreduced loss function, not just SI-SNR.
- The permutation search complexity is O(C!), where C is the number of sources; this is efficient for C=2 (2 permutations) or C=3 (6 permutations).
- `cal_si_snr` expects tensors in [T, B, C] format (time-first), while the outer API `get_si_snr_with_pitwrapper` accepts [B, T, C] format (batch-first); the `PitWrapper` handles the transposition internally.
- The `reorder_tensor` method on `PitWrapper` allows reordering predictions to match the optimal permutation, useful for downstream processing.
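The reordering step can be sketched in pure Python over nested lists standing in for a [B, T, C] tensor (a hypothetical `reorder` helper mirroring what `reorder_tensor` does, with toy data):

```python
# Pure-Python sketch of per-example channel reordering (toy [B, T, C] data).
def reorder(estimates, perms):
    """Apply each example's permutation to its channel dimension."""
    return [
        [[frame[c] for c in perm] for frame in example]
        for example, perm in zip(estimates, perms)
    ]

batch = [[[0.1, 0.9], [0.2, 0.8]]]  # B=1, T=2, C=2
perms = [(1, 0)]                    # swap the two channels of example 0
print(reorder(batch, perms))  # [[[0.9, 0.1], [0.8, 0.2]]]
```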
## Source File

`speechbrain/nnet/losses.py`