Implementation:Speechbrain Speechbrain Train NMF
| Knowledge Sources | |
|---|---|
| Domains | Interpretability, Sound_Classification |
| Last Updated | 2026-02-09 00:00 GMT |
Overview
A concrete tool from the SpeechBrain library for training a Non-Negative Matrix Factorization (NMF) model with amortized inference on ESC-50 data.
Description
This recipe trains an NMF model with amortized inference for interpretable sound classification on the ESC-50 environmental sound dataset. The NMFBrain class extends sb.core.Brain and implements a pipeline that: (1) computes the STFT of the input waveforms, (2) takes the log-magnitude spectrogram, (3) encodes it with an NMF encoder into latent activations, and (4) reconstructs the spectrogram with an NMF decoder. Training minimizes the L2 reconstruction error between the original and reconstructed log-magnitude spectrograms. During validation, the model periodically saves side-by-side plots of original and reconstructed spectrograms as PNG images for inspection. The factorization learns a dictionary of spectral basis components, which can be inspected to see which frequency patterns characterize different sound classes.
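The four steps and the loss can be sketched in plain PyTorch as follows. This is a minimal illustration under stated assumptions, not the recipe's code: the `TinyNMF` module, its layer sizes, the STFT settings, and the use of `log1p` for the log-magnitude are all choices made for this example. In the recipe itself, the encoder, decoder, and feature settings come from the YAML hyperparameter file.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TinyNMF(nn.Module):
    """Hypothetical stand-in for the recipe's NMF encoder/decoder pair."""

    def __init__(self, n_freq: int, n_comp: int):
        super().__init__()
        # Amortized inference: a network predicts the activations directly
        # instead of running iterative NMF updates for every input.
        self.encoder = nn.Conv1d(n_freq, n_comp, kernel_size=1)
        # Unconstrained parameter; the ReLU in `dictionary` keeps W >= 0.
        self.W_raw = nn.Parameter(torch.rand(n_freq, n_comp))

    @property
    def dictionary(self) -> torch.Tensor:
        return F.relu(self.W_raw)

    def forward(self, X: torch.Tensor):
        H = F.relu(self.encoder(X))              # (batch, n_comp, frames)
        Xhat = torch.matmul(self.dictionary, H)  # (batch, n_freq, frames)
        return Xhat, H


def log_magnitude(wavs: torch.Tensor, n_fft: int = 1024, hop: int = 512):
    """Steps (1) and (2): STFT, then log-compressed magnitude."""
    window = torch.hann_window(n_fft)
    spec = torch.stft(
        wavs, n_fft=n_fft, hop_length=hop, window=window, return_complex=True
    )
    return torch.log1p(spec.abs())               # (batch, n_freq, frames)


torch.manual_seed(0)
wavs = torch.randn(2, 16000)                     # two random 1-second clips
X = log_magnitude(wavs)
model = TinyNMF(n_freq=X.shape[1], n_comp=16)
Xhat, H = model(X)                               # steps (3) and (4)
loss = ((X - Xhat) ** 2).mean()                  # L2 reconstruction error
loss.backward()
```

Parameterizing the dictionary through a ReLU is only one way to keep it non-negative; the recipe's decoder may enforce the constraint differently.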
Usage
Run as a training recipe with a YAML hyperparameter file. Requires the ESC-50 dataset to be downloaded. Uses the dataio_prep function from train_l2i for data loading.
Code Reference
Source Location
- Repository: SpeechBrain
- File: recipes/ESC50/interpret/train_nmf.py
Signature
class NMFBrain(sb.core.Brain):
    """The SpeechBrain class to train Non-Negative Factorization with Amortized Inference."""

    def compute_forward(self, batch, stage=sb.Stage.TRAIN):
        """Calculates the forward pass for NMF."""
        ...

    def compute_objectives(self, predictions, batch, stage=sb.Stage.TRAIN):
        """Computes the L2-error to train the NMF model."""
        ...

    def on_stage_end(self, stage, stage_loss, epoch=None):
        """Gets called at the end of an epoch."""
        ...
Import
python train_nmf.py hparams/nmf.yaml --data_folder /path/to/ESC-50-master
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| hparams_file | str | Yes | Path to YAML hyperparameter file (e.g., hparams/nmf.yaml) |
| --data_folder | str | Yes | Path to the ESC-50 dataset root directory |
| batch.sig | tuple(torch.Tensor, torch.Tensor) | Yes | Waveform tensor and lengths from the dataloader |
Outputs
| Name | Type | Description |
|---|---|---|
| Xhat | torch.Tensor | Reconstructed log-magnitude spectrogram from NMF decoder |
| loss | torch.Tensor (scalar) | L2 reconstruction error between target and predicted spectrograms |
| nmf_rec/*.png | image files | Visual comparisons of original vs. reconstructed spectrograms |
Usage Examples
# Train NMF model on ESC-50
python train_nmf.py hparams/nmf.yaml --data_folder /data/ESC-50-master
# The model learns spectral basis components (dictionary)
# that can be inspected for interpretability
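One simple way to inspect a learned dictionary is to find the frequency at which each basis component peaks. The sketch below assumes the dictionary has already been exported as a NumPy array `W` of shape `(n_freq, n_components)`. The `peak_frequencies` helper is hypothetical, the array is synthetic, and the sample rate and FFT size are placeholder values rather than the recipe's settings.

```python
import numpy as np


def peak_frequencies(W, sample_rate, n_fft):
    """Return the peak frequency in Hz of each dictionary column.

    Hypothetical helper: W has shape (n_fft // 2 + 1, n_components),
    one row per one-sided STFT frequency bin.
    """
    W = np.asarray(W, dtype=float)
    if W.shape[0] != n_fft // 2 + 1:
        raise ValueError("W must have n_fft // 2 + 1 rows")
    bin_hz = sample_rate / n_fft          # width of one STFT bin in Hz
    return W.argmax(axis=0) * bin_hz


# Synthetic dictionary: three components peaking at bins 10, 100, and 200.
n_fft, sample_rate = 1024, 16000
W = np.full((n_fft // 2 + 1, 3), 0.01)
for k, peak_bin in enumerate([10, 100, 200]):
    W[peak_bin, k] = 1.0

peaks = peak_frequencies(W, sample_rate, n_fft)
print(peaks)  # peak frequencies: 156.25, 1562.5, and 3125.0 Hz
```

A peak frequency is a coarse summary. Plotting each column against frequency gives a fuller picture of what a component represents.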