Implementation:Speechbrain Speechbrain Train PIQ
| Knowledge Sources | |
|---|---|
| Domains | Interpretability, Sound_Classification |
| Last Updated | 2026-02-09 00:00 GMT |
Overview
Concrete tool for training a PIQ (Posthoc Interpretation via Quantization) interpreter that explains the decisions of an audio classifier, provided as a recipe in the SpeechBrain library.
Description
This recipe trains a PIQ interpreter model that provides posthoc interpretability for an audio classifier on the ESC-50 dataset. The PIQ class extends InterpreterBrain and implements a pipeline that: (1) preprocesses waveforms into STFT log-power spectrograms, (2) runs them through a frozen pretrained classifier to obtain hidden representations and class predictions, (3) passes the classifier's intermediate features through a learned decoder (psi) that produces an interpretation spectrogram highlighting which time-frequency regions are important for the classification, and (4) optionally applies vector quantization (VQ) for discrete interpretation. The interpretation can operate in mask mode (sigmoid-gated multiplicative mask on the input spectrogram) or softplus mode (thresholded activation). The interpret_computation_steps method provides the full interpretation pipeline including spectral phase recovery for potential audio reconstruction.
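The two output modes described above can be illustrated with a minimal sketch. It is written with plain Python lists so that it runs without PyTorch, whereas the recipe itself operates on torch tensors. The function name `interpret`, the parameter `mask_th`, the toy values, and the choice to threshold at a fraction of the maximum activation are assumptions made for illustration and are not taken from the recipe.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def softplus(z):
    # Numerically stable form of log(1 + exp(z)).
    return max(z, 0.0) + math.log1p(math.exp(-abs(z)))


def interpret(log_power, decoder_out, use_mask_output=True, mask_th=0.5):
    """Return (X_int, xhat) for spectrograms given as nested lists [time][freq].

    Hypothetical helper: mask_th and the max-relative threshold are assumptions.
    """
    if use_mask_output:
        # Mask mode: a sigmoid gate in (0, 1) multiplies the input spectrogram.
        xhat = [[sigmoid(z) for z in row] for row in decoder_out]
        x_int = [[m * x for m, x in zip(m_row, x_row)]
                 for m_row, x_row in zip(xhat, log_power)]
    else:
        # Softplus mode: keep only bins whose activation exceeds the threshold.
        xhat = [[softplus(z) for z in row] for row in decoder_out]
        th = mask_th * max(max(row) for row in xhat)
        x_int = [[x if a > th else 0.0 for a, x in zip(a_row, x_row)]
                 for a_row, x_row in zip(xhat, log_power)]
    return x_int, xhat


# Toy 2x2 power spectrogram, compressed with log1p as a log-power representation.
power = [[0.0, 3.0], [8.0, 1.0]]
log_power = [[math.log1p(p) for p in row] for row in power]
decoder_out = [[-4.0, 4.0], [0.0, 2.0]]  # stand-in for the decoder (psi) output

x_int_mask, xhat_mask = interpret(log_power, decoder_out, use_mask_output=True)
x_int_thr, xhat_thr = interpret(log_power, decoder_out, use_mask_output=False)
```

In mask mode every bin is scaled smoothly, while in the thresholded variant each bin is either kept unchanged or zeroed.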
Usage
Run the script as a training recipe, passing a YAML hyperparameter file. The frozen classifier must be pretrained separately before PIQ training. The recipe uses the ESC-50 dataset and prepares it with esc50_prepare.
Code Reference
Source Location
- Repository: SpeechBrain
- File: recipes/ESC50/interpret/train_piq.py
Signature
class PIQ(InterpreterBrain):
    """Class for interpreter training."""

    def interpret_computation_steps(self, wavs, print_probability=False):
        """Computation steps to get the interpretation spectrogram."""
        ...

    def compute_forward(self, batch, stage):
        """Computation pipeline based on an encoder + sound classifier."""
        ...

    def compute_objectives(self, predictions, batch, stage):
        """Computes the objectives for PIQ training."""
        ...
Import
The recipe is run as a script rather than imported from the speechbrain package:
python train_piq.py hparams/piq.yaml --data_folder /path/to/ESC-50-master
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| hparams_file | str | Yes | Path to YAML hyperparameter file (e.g., hparams/piq.yaml) |
| --data_folder | str | Yes | Path to the ESC-50 dataset root directory |
| batch.sig | tuple(torch.Tensor, torch.Tensor) | Yes | Waveform tensor and relative lengths (fractions of the longest waveform in the batch) |
| wavs | torch.Tensor | Yes (interpret) | Waveforms to interpret |
| use_vq | bool | No | Whether to use vector quantization in the decoder (from hparams) |
| use_mask_output | bool | No | Whether to use sigmoid mask mode (from hparams) |
Outputs
| Name | Type | Description |
|---|---|---|
| X_int | torch.Tensor | Interpretation spectrogram highlighting important time-frequency regions |
| xhat | torch.Tensor | Raw decoder output (mask or activation map) |
| X_stft_phase | torch.Tensor | Spectral phase for potential audio reconstruction |
| class_pred | torch.Tensor | Predicted class from the frozen classifier |
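
As a rough illustration of how X_int and X_stft_phase could be combined for audio reconstruction, the sketch below rebuilds complex STFT bins from an interpretation spectrogram and the original phase. It assumes the interpretation is a log1p-compressed power spectrogram, so the magnitude is recovered as sqrt(expm1(x)); the helper name `recombine` and that scaling are assumptions of this sketch, and the recipe's own inversion may differ. An inverse STFT, not shown here, would then turn the complex bins back into a waveform.

```python
import cmath
import math


def recombine(x_int_log_power, phase):
    """Rebuild complex STFT bins from a log1p-power spectrogram and a phase map.

    Hypothetical helper: inputs are nested lists indexed as [time][freq].
    """
    return [
        [cmath.rect(math.sqrt(math.expm1(x)), p) for x, p in zip(x_row, p_row)]
        for x_row, p_row in zip(x_int_log_power, phase)
    ]


# Toy example: two bins of one frame; the second is fully suppressed.
original = [[3 + 4j, 1 - 2j]]
phase = [[cmath.phase(z) for z in row] for row in original]
x_int = [[math.log1p(abs(original[0][0]) ** 2), 0.0]]

reconstructed = recombine(x_int, phase)
```

A bin that the interpretation leaves untouched is recovered up to floating-point error, and a bin set to zero contributes nothing to the reconstruction.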
Usage Examples
# Train PIQ interpreter on ESC-50
python train_piq.py hparams/piq.yaml --data_folder /data/ESC-50-master
# With vector quantization enabled
python train_piq.py hparams/piq_vq.yaml --data_folder /data/ESC-50-master