

Implementation:Speechbrain Speechbrain Train NMF



Knowledge Sources
Domains: Interpretability, Sound_Classification
Last Updated: 2026-02-09 00:00 GMT

Overview

Concrete tool, provided by the SpeechBrain library, for training a Non-Negative Matrix Factorization (NMF) model with amortized inference on the ESC-50 dataset.

Description

This recipe trains an NMF model with amortized inference for interpretable sound classification on the ESC-50 environmental sound dataset. "Amortized" means that an encoder network predicts the NMF activations in a single forward pass, instead of solving a separate optimization problem for every input. The NMFBrain class extends sb.core.Brain and implements a four-step pipeline: (1) compute the STFT of the input waveforms, (2) take the log-magnitude spectrogram, (3) encode it with an NMF encoder into non-negative latent activations, and (4) reconstruct the spectrogram with an NMF decoder. Training minimizes the L2 reconstruction error between the original and reconstructed log-magnitude spectrograms. During validation, the model periodically saves visual comparisons of original and reconstructed spectrograms as PNG images for inspection. The learned factorization yields a dictionary of spectral basis components. You can inspect these components to see which frequency patterns characterize different sound classes.
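The four steps above can be sketched in plain PyTorch as follows. This is a minimal illustration under stated assumptions, not the recipe's code. ToyNMFEncoder, ToyNMFDecoder, log_magnitude, and the STFT and K settings are all hypothetical. The sketch uses log1p as the log-magnitude transform because it keeps the spectrogram non-negative, which NMF requires; whether the recipe uses exactly this transform is an assumption. The decoder holds a non-negative dictionary W, so the reconstruction is Xhat = W H, the core NMF model.

```python
import torch
import torch.nn as nn

N_FFT, HOP, K = 512, 256, 32  # illustrative settings, not the recipe's hparams


class ToyNMFEncoder(nn.Module):
    """Hypothetical amortized encoder: one forward pass maps X to activations H."""

    def __init__(self, n_freq: int, n_components: int):
        super().__init__()
        # 1x1 convolution: each frame's spectrum -> K component activations
        self.proj = nn.Conv1d(n_freq, n_components, kernel_size=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x: (batch, freq, frames); ReLU keeps H non-negative
        return torch.relu(self.proj(x))


class ToyNMFDecoder(nn.Module):
    """Hypothetical decoder holding a non-negative dictionary W of spectral templates."""

    def __init__(self, n_freq: int, n_components: int):
        super().__init__()
        self.raw_w = nn.Parameter(torch.rand(n_freq, n_components))

    def dictionary(self) -> torch.Tensor:
        # softplus maps the unconstrained parameters to W > 0
        return nn.functional.softplus(self.raw_w)

    def forward(self, h: torch.Tensor) -> torch.Tensor:
        # (freq, K) @ (batch, K, frames) -> (batch, freq, frames), i.e. Xhat = W H
        return torch.matmul(self.dictionary(), h)


def log_magnitude(wavs: torch.Tensor) -> torch.Tensor:
    """Steps (1)-(2): STFT, then log1p of the magnitude (non-negative by construction)."""
    spec = torch.stft(wavs, n_fft=N_FFT, hop_length=HOP,
                      window=torch.hann_window(N_FFT), return_complex=True)
    return torch.log1p(spec.abs())


torch.manual_seed(0)
n_freq = N_FFT // 2 + 1
encoder, decoder = ToyNMFEncoder(n_freq, K), ToyNMFDecoder(n_freq, K)

wavs = torch.randn(2, 16000)          # two dummy waveforms of 16000 samples
X = log_magnitude(wavs)               # steps (1)-(2): (2, 257, 63)
H = encoder(X)                        # step (3): (2, 32, 63)
Xhat = decoder(H)                     # step (4): (2, 257, 63)
loss = torch.mean((X - Xhat) ** 2)    # L2 reconstruction error
```

The 1x1 convolution is only the simplest possible encoder. The sketch relies on a single property: both H and W stay non-negative, so every spectrogram frame is explained as an additive mix of dictionary templates.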

Usage

Run as a training recipe with a YAML hyperparameter file. Requires the ESC-50 dataset to be downloaded. Uses the dataio_prep function from train_l2i for data loading.

Code Reference

Source Location

Signature

import speechbrain as sb


class NMFBrain(sb.core.Brain):
    """The SpeechBrain class to train Non-Negative Factorization with Amortized Inference."""

    def compute_forward(self, batch, stage=sb.Stage.TRAIN):
        """Calculates the forward pass for NMF."""
        ...

    def compute_objectives(self, predictions, batch, stage=sb.Stage.TRAIN):
        """Computes the L2-error to train the NMF model."""
        ...

    def on_stage_end(self, stage, stage_loss, epoch=None):
        """Gets called at the end of an epoch."""
        ...
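The following framework-free analogue makes the roles of the three NMFBrain hooks concrete. ToyNMFBrain, its fit loop, the save_every interval, and all sizes are hypothetical. The sketch does not use SpeechBrain's Brain class or its real signatures. In the sketch:

- compute_forward produces the target spectrogram and its reconstruction.
- compute_objectives returns the L2 loss as a scalar tensor, so it can be backpropagated.
- on_stage_end records the epoch loss and marks the epochs at which a PNG comparison would be written.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ToyNMFBrain:
    """Hypothetical stand-in for the NMFBrain hooks; not SpeechBrain's API."""

    def __init__(self, n_fft=128, n_components=8, lr=1e-2, save_every=10):
        n_freq = n_fft // 2 + 1
        self.n_fft = n_fft
        self.encoder = nn.Conv1d(n_freq, n_components, kernel_size=1)
        self.raw_w = nn.Parameter(torch.rand(n_freq, n_components))
        params = list(self.encoder.parameters()) + [self.raw_w]
        self.optimizer = torch.optim.Adam(params, lr=lr)
        self.save_every = save_every  # hypothetical "save a PNG every N epochs" interval
        self.history, self.saved_epochs = [], []

    def compute_forward(self, wavs):
        spec = torch.stft(wavs, n_fft=self.n_fft, hop_length=self.n_fft // 2,
                          window=torch.hann_window(self.n_fft), return_complex=True)
        X = torch.log1p(spec.abs())                      # non-negative log-magnitude
        H = F.softplus(self.encoder(X))                  # strictly positive activations
        Xhat = torch.matmul(F.softplus(self.raw_w), H)   # non-negative dictionary W
        return X, Xhat

    def compute_objectives(self, predictions):
        X, Xhat = predictions
        return torch.mean((X - Xhat) ** 2)               # scalar tensor, supports backward()

    def on_stage_end(self, stage_loss, epoch):
        self.history.append(stage_loss)                  # plain float, like stage_loss
        if epoch % self.save_every == 0:
            # the recipe would write original-vs-reconstruction PNGs at this point
            self.saved_epochs.append(epoch)

    def fit(self, wavs, epochs):
        for epoch in range(1, epochs + 1):
            loss = self.compute_objectives(self.compute_forward(wavs))
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            self.on_stage_end(loss.item(), epoch)


torch.manual_seed(0)
brain = ToyNMFBrain()
brain.fit(torch.randn(4, 4000), epochs=30)
print(brain.history[0], brain.history[-1], brain.saved_epochs)
```

In the real recipe, sb.core.Brain supplies the epoch loop and the optimizer step. The sketch only mirrors how the work is divided among the three hooks.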

Import

python train_nmf.py hparams/nmf.yaml --data_folder /path/to/ESC-50-master

I/O Contract

Inputs

Name          | Type                              | Required | Description
hparams_file  | str                               | Yes      | Path to YAML hyperparameter file (e.g., hparams/nmf.yaml)
--data_folder | str                               | Yes      | Path to the ESC-50 dataset root directory
batch.sig     | tuple(torch.Tensor, torch.Tensor) | Yes      | Waveform tensor and lengths from the dataloader

Outputs

Name          | Type                  | Description
Xhat          | torch.Tensor          | Reconstructed log-magnitude spectrogram from the NMF decoder
loss          | torch.Tensor (scalar) | L2 reconstruction error between the target and reconstructed spectrograms
nmf_rec/*.png | image files           | Visual comparisons of original vs. reconstructed spectrograms
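As a sanity check on the I/O contract, the snippet below does two things:

- It computes the spectrogram shape that torch.stft (with its default center=True) produces for a batch of 5-second clips. ESC-50 clips are all 5 s long.
- It converts SpeechBrain-style relative lengths, the second element of batch.sig, into sample counts.

The sample rate, n_fft, and hop length are illustrative values, not values read from hparams/nmf.yaml. The helper names are hypothetical. The (batch, freq, frames) layout is the one torch.stft returns; the recipe's own tensors may be arranged differently.

```python
import torch


def expected_spec_shape(batch_size, n_samples, n_fft, hop_length):
    """(batch, freq_bins, frames) produced by torch.stft with center=True."""
    return (batch_size, n_fft // 2 + 1, 1 + n_samples // hop_length)


def absolute_lengths(rel_lens, max_len):
    """Convert relative lengths (fraction of the longest item) to sample counts."""
    return torch.round(rel_lens * max_len).long()


sample_rate, n_fft, hop = 16000, 512, 256   # illustrative values, not from nmf.yaml
wavs = torch.randn(2, 5 * sample_rate)       # two 5 s clips, as in ESC-50
lens = torch.tensor([1.0, 1.0])              # equal-length clips -> relative lengths of 1.0

spec = torch.stft(wavs, n_fft=n_fft, hop_length=hop,
                  window=torch.hann_window(n_fft), return_complex=True)
print(tuple(spec.shape), expected_spec_shape(2, wavs.shape[-1], n_fft, hop))
# (2, 257, 313) (2, 257, 313)
print(absolute_lengths(lens, wavs.shape[-1]).tolist())
# [80000, 80000]
```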

Usage Examples

# Train NMF model on ESC-50
python train_nmf.py hparams/nmf.yaml --data_folder /data/ESC-50-master

# The model learns spectral basis components (dictionary)
# that can be inspected for interpretability
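One way to inspect the learned dictionary is to ask two questions: where does each spectral basis component peak in frequency, and which components does a given clip activate most strongly? The sketch assumes a non-negative dictionary W of shape (freq_bins, K) and activations H of shape (K, frames). It does not cover how to extract W from the trained NMF decoder, so random tensors stand in for trained values. The helper names and the 16 kHz, 512-point STFT settings are hypothetical.

```python
import torch


def component_peak_frequencies(W, sample_rate, n_fft):
    """Peak frequency in Hz of each dictionary column W[:, k]."""
    if W.shape[0] != n_fft // 2 + 1:
        raise ValueError("W must have n_fft // 2 + 1 rows (one per STFT bin)")
    bin_hz = sample_rate / n_fft          # width of one STFT bin in Hz
    peak_bins = W.argmax(dim=0)           # strongest frequency bin per component
    return peak_bins.float() * bin_hz


def top_components(H, k=3):
    """Indices of the k components with the largest total activation over time."""
    return torch.topk(H.sum(dim=-1), k).indices


torch.manual_seed(0)
sample_rate, n_fft, K = 16000, 512, 8
W = torch.rand(n_fft // 2 + 1, K)   # placeholder for a trained, non-negative dictionary
H = torch.rand(K, 100)              # placeholder activations for one clip

peaks = component_peak_frequencies(W, sample_rate, n_fft)
for k in top_components(H).tolist():
    print(f"component {k}: peak at {peaks[k]:.0f} Hz")
```

To relate components to sound classes, you could average H over all clips of each class before ranking. This is a suggested extension, not a step the recipe is documented to perform.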
