
Implementation:Speechbrain Speechbrain Separation Save Results

From Leeroopedia


Field Value
Implementation Name Separation_Save_Results
API Separation.save_results(self, test_data) and Separation.save_audio(self, snt_id, mixture, targets, predictions)
Source recipes/LibriMix/separation/train.py:L296-390 (save_results), L392-427 (save_audio)
Import Recipe-specific. Uses mir_eval.separation.bss_eval_sources
Type API Doc
Related Principle Principle:Speechbrain_Speechbrain_Source_Separation_Evaluation

Purpose

The save_results method computes comprehensive evaluation metrics (SDR, SDRi, SI-SNR, SI-SNRi) for every test utterance and writes them to a CSV file. The save_audio method saves the separated audio, ground-truth sources, and mixture waveforms to disk for qualitative evaluation through listening tests.

Method Signatures

save_results

def save_results(self, test_data):

save_audio

def save_audio(self, snt_id, mixture, targets, predictions):

Parameters

save_results

Parameter Type Description
test_data DynamicItemDataset The test dataset, created by dataio_prep()

save_audio

Parameter Type Description
snt_id str Unique identifier for the utterance
mixture (Tensor, Tensor) Mixture signal and lengths from batch
targets Tensor Ground truth source signals [B, T, C]
predictions Tensor Model predictions [B, T, C]
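The [B, T, C] layout above can be illustrated with a quick shape check (a hypothetical zero array stands in for real model output; sizes are assumptions):

```python
import numpy as np

# Shape convention sketch: predictions and targets are [B, T, C]
# (batch, time, num_sources). save_audio indexes predictions[0, :, ns]
# to extract source ns of the first batch element as a 1-D waveform.
B, T, C = 1, 16000, 2  # hypothetical sizes: batch 1, 1 s at 16 kHz, 2 speakers
predictions = np.zeros((B, T, C), dtype=np.float32)

source_1 = predictions[0, :, 0]  # first estimated source as a 1-D signal
print(source_1.shape)  # (16000,)
```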

Outputs

save_results

Produces a CSV file at {output_folder}/test_results.csv with the following columns:

Column Type Description
snt_id str Utterance identifier (final row is "avg")
sdr float Signal-to-Distortion Ratio (dB)
sdr_i float SDR improvement over mixture baseline (dB)
si-snr float Scale-Invariant Signal-to-Noise Ratio (dB)
si-snr_i float SI-SNR improvement over mixture baseline (dB)

The final row of the CSV contains the mean values across all test utterances.

save_audio

Produces WAV files in {save_folder}/audio_results/:

Filename Pattern Description
item{snt_id}_source{N}hat.wav Estimated (separated) source N
item{snt_id}_source{N}.wav Ground truth source N
item{snt_id}_mix.wav Original mixture

All audio files are peak-normalized (divided by absolute maximum) before saving.

Implementation Details

save_results

def save_results(self, test_data):
    from mir_eval.separation import bss_eval_sources

    save_file = os.path.join(self.hparams.output_folder, "test_results.csv")

    all_sdrs = []
    all_sdrs_i = []
    all_sisnrs = []
    all_sisnrs_i = []
    csv_columns = ["snt_id", "sdr", "sdr_i", "si-snr", "si-snr_i"]

    test_loader = sb.dataio.dataloader.make_dataloader(
        test_data, **self.hparams.dataloader_opts
    )

    with open(save_file, "w", newline="", encoding="utf-8") as results_csv:
        writer = csv.DictWriter(results_csv, fieldnames=csv_columns)
        writer.writeheader()

        for i, batch in enumerate(tqdm(test_loader)):
            mixture, mix_len = batch.mix_sig
            snt_id = batch.id
            targets = [batch.s1_sig, batch.s2_sig]
            if self.hparams.num_spks == 3:
                targets.append(batch.s3_sig)

            with torch.no_grad():
                predictions, targets = self.compute_forward(
                    batch.mix_sig, targets, sb.Stage.TEST
                )

            # Compute SI-SNR
            sisnr = self.compute_objectives(predictions, targets)

            # Compute SI-SNR improvement
            mixture_signal = torch.stack(
                [mixture] * self.hparams.num_spks, dim=-1
            )
            sisnr_baseline = self.compute_objectives(mixture_signal, targets)
            sisnr_i = sisnr - sisnr_baseline

            # Compute SDR using mir_eval
            sdr, _, _, _ = bss_eval_sources(
                targets[0].t().cpu().numpy(),
                predictions[0].t().detach().cpu().numpy(),
            )

            sdr_baseline, _, _, _ = bss_eval_sources(
                targets[0].t().cpu().numpy(),
                mixture_signal[0].t().detach().cpu().numpy(),
            )

            sdr_i = sdr.mean() - sdr_baseline.mean()

            # Write per-utterance results
            row = {
                "snt_id": snt_id[0],
                "sdr": sdr.mean(),
                "sdr_i": sdr_i,
                "si-snr": -sisnr.item(),
                "si-snr_i": -sisnr_i.item(),
            }
            writer.writerow(row)

            all_sdrs.append(sdr.mean())
            all_sdrs_i.append(sdr_i.mean())
            all_sisnrs.append(-sisnr.item())
            all_sisnrs_i.append(-sisnr_i.item())

        # Write average row
        row = {
            "snt_id": "avg",
            "sdr": np.array(all_sdrs).mean(),
            "sdr_i": np.array(all_sdrs_i).mean(),
            "si-snr": np.array(all_sisnrs).mean(),
            "si-snr_i": np.array(all_sisnrs_i).mean(),
        }
        writer.writerow(row)

save_audio

def save_audio(self, snt_id, mixture, targets, predictions):
    save_path = os.path.join(self.hparams.save_folder, "audio_results")
    if not os.path.exists(save_path):
        os.mkdir(save_path)

    for ns in range(self.hparams.num_spks):
        # Estimated source
        signal = predictions[0, :, ns]
        signal = signal / signal.abs().max()
        save_file = os.path.join(
            save_path, "item{}_source{}hat.wav".format(snt_id, ns + 1)
        )
        torchaudio.save(
            save_file, signal.unsqueeze(0).cpu(), self.hparams.sample_rate
        )

        # Original source
        signal = targets[0, :, ns]
        signal = signal / signal.abs().max()
        save_file = os.path.join(
            save_path, "item{}_source{}.wav".format(snt_id, ns + 1)
        )
        torchaudio.save(
            save_file, signal.unsqueeze(0).cpu(), self.hparams.sample_rate
        )

    # Mixture
    signal = mixture[0][0, :]
    signal = signal / signal.abs().max()
    save_file = os.path.join(save_path, "item{}_mix.wav".format(snt_id))
    torchaudio.save(
        save_file, signal.unsqueeze(0).cpu(), self.hparams.sample_rate
    )
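As a hedged sketch of the save step (the recipe uses torchaudio.save; the stdlib wave module is substituted here so the example is self-contained), peak normalization followed by a 16-bit mono WAV write looks like:

```python
import io
import wave

import numpy as np

# Illustrative stand-in for the save_audio logic: peak-normalize a mono
# signal, quantize to 16-bit PCM, and write a WAV container.
sample_rate = 8000
t = np.arange(sample_rate) / sample_rate
signal = 0.3 * np.sin(2 * np.pi * 440 * t)   # quiet 440 Hz tone
signal = signal / np.abs(signal).max()        # peak-normalize: max |x| == 1.0
pcm = (signal * 32767).astype(np.int16)       # 16-bit quantization

buf = io.BytesIO()
with wave.open(buf, "wb") as wav:
    wav.setnchannels(1)            # mono, matching signal.unsqueeze(0) above
    wav.setsampwidth(2)            # 2 bytes per sample (16-bit)
    wav.setframerate(sample_rate)
    wav.writeframes(pcm.tobytes())

print(len(buf.getvalue()) > 44)  # True: RIFF header plus PCM payload
```

Writing to an in-memory buffer keeps the sketch filesystem-free; the recipe writes the same bytes to the item{...}.wav paths listed earlier.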

Metric Computation Details

SI-SNR and SI-SNRi

SI-SNR is computed using the model's own loss function (get_si_snr_with_pitwrapper), which handles permutation alignment. The improvement metric is:

SI-SNRi = SI-SNR(target, estimate) - SI-SNR(target, mixture)

The baseline is computed by treating the unprocessed mixture as the "estimate" for each source. Note that the SI-SNR values are negated when written to CSV (since the loss function returns negative SI-SNR for optimization).
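The formula above can be sketched numerically. This is an illustrative single-pair SI-SNR on synthetic signals; the recipe's get_si_snr_with_pitwrapper additionally performs permutation-invariant alignment, which is omitted here:

```python
import numpy as np

def si_snr(target, estimate, eps=1e-8):
    # Single-pair SI-SNR in dB: project the estimate onto the target,
    # then compare the projected energy to the residual energy.
    target = target - target.mean()
    estimate = estimate - estimate.mean()
    s_target = np.dot(estimate, target) / (np.dot(target, target) + eps) * target
    e_noise = estimate - s_target
    return 10 * np.log10(np.dot(s_target, s_target) / (np.dot(e_noise, e_noise) + eps))

rng = np.random.default_rng(0)
speech = rng.standard_normal(16000)
interference = rng.standard_normal(16000)
mixture = speech + interference
estimate = speech + 0.1 * interference  # hypothetical separator output

# SI-SNRi: improvement of the estimate over the unprocessed-mixture baseline.
sisnr_i = si_snr(speech, estimate) - si_snr(speech, mixture)
print(sisnr_i > 0)  # True: the estimate beats the mixture baseline
```

The projection step is what makes the metric scale-invariant: rescaling the estimate rescales s_target and e_noise equally, leaving the ratio unchanged.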

SDR and SDRi

SDR is computed using mir_eval.separation.bss_eval_sources, which implements the BSS_EVAL framework:

SDRi = SDR(target, estimate) - SDR(target, mixture)

The SDR computation is performed on CPU using numpy arrays (first batch element only, as batch size is typically 1 during evaluation).
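The per-source averaging behind sdr_i in the code above can be shown with plain numpy (the values are hypothetical; the real per-source SDR arrays come from bss_eval_sources):

```python
import numpy as np

# bss_eval_sources returns one SDR value per source; the recipe averages
# across sources for both the estimate and the mixture baseline, then
# subtracts the two means to get the improvement.
sdr = np.array([15.2, 14.8])         # hypothetical per-source SDR (dB)
sdr_baseline = np.array([2.4, 2.0])  # hypothetical SDR of the raw mixture

sdr_i = sdr.mean() - sdr_baseline.mean()
print(round(float(sdr_i), 2))  # 12.8
```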

Usage Example

import sys
import speechbrain as sb
from hyperpyyaml import load_hyperpyyaml

hparams_file, run_opts, overrides = sb.parse_arguments(sys.argv[1:])
with open(hparams_file, encoding="utf-8") as fin:
    hparams = load_hyperpyyaml(fin, overrides)

# Prepare data
train_data, valid_data, test_data = dataio_prep(hparams)

# Initialize and load trained model
separator = Separation(
    modules=hparams["modules"],
    opt_class=hparams["optimizer"],
    hparams=hparams,
    run_opts=run_opts,
    checkpointer=hparams["checkpointer"],
)

# Run evaluation (computes loss on test set)
separator.evaluate(test_data, min_key="si-snr")

# Save detailed per-utterance metrics to CSV
separator.save_results(test_data)

# The output CSV will be at: {output_folder}/test_results.csv

Example Output CSV

snt_id,sdr,sdr_i,si-snr,si-snr_i
0,15.234,12.891,14.982,12.643
1,13.876,11.532,13.654,11.315
2,16.012,13.669,15.789,13.450
avg,15.041,12.697,14.808,12.469
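A few lines of stdlib csv are enough to consume this file, for example to separate the summary row from the per-utterance rows (the CSV text below is a shortened hypothetical variant of the example above):

```python
import csv
import io

# Hypothetical test_results.csv contents following the documented schema.
csv_text = """snt_id,sdr,sdr_i,si-snr,si-snr_i
0,15.234,12.891,14.982,12.643
1,13.876,11.532,13.654,11.315
avg,14.555,12.2115,14.318,11.979
"""

rows = list(csv.DictReader(io.StringIO(csv_text)))
per_utt = [r for r in rows if r["snt_id"] != "avg"]     # per-utterance rows
avg_row = next(r for r in rows if r["snt_id"] == "avg")  # final summary row

# Recompute the mean SDR and compare against the summary row.
mean_sdr = sum(float(r["sdr"]) for r in per_utt) / len(per_utt)
print(abs(mean_sdr - float(avg_row["sdr"])) < 1e-9)  # True
```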

Key Implementation Details

  • Lazy import: mir_eval is imported inside save_results() rather than at module level, making it an optional dependency
  • Negation convention: compute_objectives returns the negative SI-SNR (so that training minimizes the loss). When writing to CSV, the value is negated back to report the conventional positive-is-better metric: "si-snr": -sisnr.item()
  • Batch size assumption: SDR computation uses targets[0] and predictions[0], processing one example at a time (batch_size=1 during evaluation)
  • Peak normalization: All saved audio is divided by its absolute maximum to prevent clipping and ensure consistent playback levels
  • Progress tracking: The test loop uses tqdm for progress bar display
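The lazy-import point can be demonstrated with a hypothetical missing dependency: the import statement only runs, and only fails, when the function body executes:

```python
def save_results_stub():
    # Mirrors the lazy-import pattern in save_results: the optional
    # dependency is imported inside the function body, not at module level.
    from a_missing_optional_dep import bss_eval_sources  # hypothetical module
    return bss_eval_sources

# Defining the function above raised nothing; calling it triggers the import.
try:
    save_results_stub()
    outcome = "imported"
except ImportError:
    outcome = "deferred ImportError"
print(outcome)  # deferred ImportError
```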

Dependencies

  • mir_eval: BSS_EVAL SDR computation (mir_eval.separation.bss_eval_sources)
  • torchaudio: Audio file saving
  • numpy: Metric aggregation
  • csv: CSV file writing
  • tqdm: Progress bar

Source File

recipes/LibriMix/separation/train.py

See Also

  • Principle:Speechbrain_Speechbrain_Source_Separation_Evaluation