Implementation:Speechbrain_Speechbrain_SEBrain_Compute_Forward
| Property | Value |
|---|---|
| Implementation Name | SEBrain_Compute_Forward |
| API | `SEBrain.compute_forward(self, batch, stage)` and `SEBrain.compute_objectives(self, predictions, batch, stage)` |
| Source File | `recipes/Voicebank/enhance/spectral_mask/train.py` (class: L26, `compute_forward`: L27-43, `compute_objectives`: L52-92) |
| Import | Recipe-specific Brain subclass (not importable as a library) |
| Type | API Doc |
| Workflow | Speech_Enhancement_Training |
| Domains | Speech_Enhancement, Deep_Learning |
| Related Principle | Principle:Speechbrain_Speechbrain_Conventional_Enhancement_Training |
Purpose
SEBrain (Speech Enhancement Brain) is a custom sb.Brain subclass that implements conventional supervised training for speech enhancement using spectral masking or waveform mapping. It provides the compute_forward() and compute_objectives() methods that plug into SpeechBrain's standard training loop to perform STFT-based mask prediction, spectral reconstruction, and MSE loss computation.
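To make the training-loop contract concrete, here is a minimal, framework-free sketch of how `fit()` drives these two methods: each step calls `compute_forward()` for predictions, then `compute_objectives()` for the loss. The `ToyBrain` class and its scalar "model" below are illustrative inventions, not SpeechBrain API; the real Brain operates on torch tensors with autograd.

```python
# Minimal sketch of the Brain training-loop contract: fit() repeatedly calls
# compute_forward() to get predictions, then compute_objectives() for the loss.
# ToyBrain and its numeric "model" are hypothetical stand-ins for illustration.

class ToyBrain:
    def __init__(self, gain=0.5):
        self.gain = gain  # stands in for the mask-predicting model

    def compute_forward(self, batch, stage):
        noisy = batch["noisy_sig"]
        # "Enhance" by scaling; the real SEBrain predicts a spectral mask
        return [x * self.gain for x in noisy]

    def compute_objectives(self, predictions, batch, stage):
        clean = batch["clean_sig"]
        # Mean squared error between prediction and target
        return sum((p - c) ** 2 for p, c in zip(predictions, clean)) / len(clean)

    def fit_batch(self, batch):
        # One training step, minus the backward pass and optimizer update
        predictions = self.compute_forward(batch, stage="TRAIN")
        loss = self.compute_objectives(predictions, batch, stage="TRAIN")
        return loss


brain = ToyBrain(gain=1.0)
batch = {"noisy_sig": [1.0, 2.0], "clean_sig": [1.0, 2.0]}
print(brain.fit_batch(batch))  # 0.0: identical signals give zero MSE
```

The same division of labor (forward pass separated from loss computation) is what lets SpeechBrain's standard `fit()` loop run any recipe unchanged.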
Class Definition
```python
class SEBrain(sb.Brain):
    """Brain class for speech enhancement training.

    Supports both spectral mask and waveform mapping approaches,
    selectable via the 'waveform_target' hyperparameter.
    """
```
compute_forward Method
```python
def compute_forward(self, batch, stage):
    """Forward computations from the waveform batches to the enhanced output.

    Arguments
    ---------
    batch : PaddedBatch
        Contains 'noisy_sig' (noisy waveform) and 'clean_sig' (clean waveform).
    stage : sb.Stage
        One of TRAIN, VALID, or TEST.

    Returns
    -------
    predict_spec : torch.Tensor
        Enhanced spectral features [batch, time, freq].
    predict_wav : torch.Tensor
        Reconstructed enhanced waveform [batch, samples].
    """
    batch = batch.to(self.device)
    noisy_wavs, lens = batch.noisy_sig
    noisy_feats = self.compute_feats(noisy_wavs)

    # Predict a spectral mask using the model
    mask = self.modules.model(noisy_feats)
    mask = torch.squeeze(mask, 2)

    # Apply the mask via signal approximation (SA)
    predict_spec = torch.mul(mask, noisy_feats)

    # Reconstruct the waveform via ISTFT, reusing the original noisy phase
    predict_wav = self.hparams.resynth(
        torch.expm1(predict_spec), noisy_wavs
    )
    return predict_spec, predict_wav
```
compute_feats Method
```python
def compute_feats(self, wavs):
    """Feature computation pipeline.

    Arguments
    ---------
    wavs : torch.Tensor
        Raw waveform tensor [batch, samples].

    Returns
    -------
    feats : torch.Tensor
        Log-compressed spectral magnitude [batch, time, freq].
    """
    feats = self.hparams.compute_STFT(wavs)
    feats = spectral_magnitude(feats, power=0.5)
    feats = torch.log1p(feats)
    return feats
```
compute_objectives Method
```python
def compute_objectives(self, predictions, batch, stage):
    """Computes the loss given the predicted and targeted outputs.

    Arguments
    ---------
    predictions : tuple
        (predict_spec, predict_wav) from compute_forward.
    batch : PaddedBatch
        Contains 'clean_sig' (target waveform).
    stage : sb.Stage
        Current training stage.

    Returns
    -------
    loss : torch.Tensor
        Scalar loss value.
    """
    predict_spec, predict_wav = predictions
    clean_wavs, lens = batch.clean_sig

    if getattr(self.hparams, "waveform_target", False):
        # Waveform-domain loss
        loss = self.hparams.compute_cost(predict_wav, clean_wavs, lens)
        self.loss_metric.append(
            batch.id, predict_wav, clean_wavs, lens, reduction="batch"
        )
    else:
        # Spectral-domain loss (default)
        clean_spec = self.compute_feats(clean_wavs)
        loss = self.hparams.compute_cost(predict_spec, clean_spec, lens)
        self.loss_metric.append(
            batch.id, predict_spec, clean_spec, lens, reduction="batch"
        )

    if stage != sb.Stage.TRAIN:
        # Compute perceptual metrics during validation/test
        self.stoi_metric.append(
            batch.id, predict_wav, clean_wavs, lens, reduction="batch"
        )
        self.pesq_metric.append(
            batch.id, predict=predict_wav, target=clean_wavs, lengths=lens
        )

        # Write enhanced wavs to file during test
        if stage == sb.Stage.TEST:
            lens = lens * clean_wavs.shape[1]
            for name, pred_wav, length in zip(batch.id, predict_wav, lens):
                name += ".wav"
                enhance_path = os.path.join(
                    self.hparams.enhanced_folder, name
                )
                torchaudio.save(
                    enhance_path,
                    torch.unsqueeze(pred_wav[: int(length)].cpu(), 0),
                    16000,
                )
    return loss
```
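The test-stage branch above converts SpeechBrain's relative lengths (fractions of the padded batch length) into absolute sample counts with `lens * clean_wavs.shape[1]` before trimming and saving. A pure-Python illustration of that conversion, with lists standing in for torch tensors:

```python
# PaddedBatch stores signal lengths as fractions of the longest (padded)
# signal in the batch. Before saving, each enhanced wav is trimmed to
# rel_len * padded_len samples so the written file has its true duration.
padded_len = 32000           # samples in the padded batch, e.g. 2 s at 16 kHz
rel_lens = [1.0, 0.5, 0.75]  # relative lengths, as in batch.clean_sig[1]

abs_lens = [int(r * padded_len) for r in rel_lens]
print(abs_lens)  # [32000, 16000, 24000]

# Trimming removes the zero padding appended during batching
wav = [0.1] * padded_len
trimmed = wav[: abs_lens[1]]
assert len(trimmed) == 16000
```

Without this trim, every saved file in a batch would share the padded length and end in silence.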
Stage Callbacks
on_stage_start
```python
def on_stage_start(self, stage, epoch=None):
    """Gets called at the beginning of each epoch."""
    self.loss_metric = MetricStats(metric=self.hparams.compute_cost)
    self.stoi_metric = MetricStats(metric=stoi_loss)

    def pesq_eval(pred_wav, target_wav):
        return pesq(
            fs=16000,
            ref=target_wav.numpy(),
            deg=pred_wav.numpy(),
            mode="wb",
        )

    if stage != sb.Stage.TRAIN:
        self.pesq_metric = MetricStats(
            metric=pesq_eval, n_jobs=1, batch_eval=False
        )
```
on_stage_end
```python
def on_stage_end(self, stage, stage_loss, epoch=None):
    """Gets called at the end of an epoch."""
    if stage == sb.Stage.TRAIN:
        self.train_loss = stage_loss
    else:
        stats = {
            "loss": stage_loss,
            "pesq": self.pesq_metric.summarize("average"),
            "stoi": -self.stoi_metric.summarize("average"),
        }

    if stage == sb.Stage.VALID:
        self.hparams.train_logger.log_stats(
            {"Epoch": epoch},
            train_stats={"loss": self.train_loss},
            valid_stats=stats,
        )
        # Save checkpoint based on best PESQ
        self.checkpointer.save_and_keep_only(
            meta=stats, max_keys=["pesq"]
        )
```
Usage Examples
Full Training Pipeline
```python
import sys
import speechbrain as sb
from hyperpyyaml import load_hyperpyyaml

# Load hyperparameters
hparams_file, run_opts, overrides = sb.parse_arguments(sys.argv[1:])
with open(hparams_file, encoding="utf-8") as fin:
    hparams = load_hyperpyyaml(fin, overrides)

# Prepare data
from voicebank_prepare import prepare_voicebank
from speechbrain.utils.distributed import run_on_main

run_on_main(
    prepare_voicebank,
    kwargs={
        "data_folder": hparams["data_folder"],
        "save_folder": hparams["output_folder"],
        "skip_prep": hparams["skip_prep"],
    },
)

# Create datasets with audio pipelines
datasets = dataio_prep(hparams)

# Initialize SEBrain
se_brain = SEBrain(
    modules=hparams["modules"],
    opt_class=hparams["opt_class"],
    hparams=hparams,
    run_opts=run_opts,
    checkpointer=hparams["checkpointer"],
)

# Train with the standard Brain.fit() loop
se_brain.fit(
    epoch_counter=se_brain.hparams.epoch_counter,
    train_set=datasets["train"],
    valid_set=datasets["valid"],
    train_loader_kwargs=hparams["dataloader_options"],
    valid_loader_kwargs=hparams["dataloader_options"],
)

# Evaluate on the test set using the best-PESQ checkpoint
test_stats = se_brain.evaluate(
    test_set=datasets["test"],
    max_key="pesq",
    test_loader_kwargs=hparams["dataloader_options"],
)
```
Running from Command Line
```bash
# Train with the default BLSTM model
python train.py hparams/train.yaml --data_folder /data/noisy-vctk-16k

# Train with the 2D-FCN model
python train.py hparams/train.yaml --data_folder /data/noisy-vctk-16k \
    --models '!include:models/2DFCN.yaml'

# Train with a waveform-domain loss
python train.py hparams/train.yaml --data_folder /data/noisy-vctk-16k \
    --waveform_target True
```
Data Flow Diagram
```
noisy_wav  ──> STFT ──> spectral_magnitude(power=0.5) ──> log1p ──> noisy_feats
noisy_feats ──> model ──> mask
mask × noisy_feats = predict_spec          (signal approximation)
predict_spec ──> expm1 ──> ISTFT (noisy phase) ──> predict_wav

clean_wav  ──> STFT ──> spectral_magnitude(power=0.5) ──> log1p ──> clean_spec
loss = MSE(predict_spec, clean_spec)
```
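The central masking step is a plain elementwise product between the predicted mask and the compressed noisy features. A toy pure-Python version, with flat lists standing in for the real `[batch, time, freq]` tensors:

```python
# Signal approximation (SA): the enhanced spectrum is mask * noisy_features,
# elementwise. Mask values near 1 keep a time-frequency bin; values near 0
# suppress it. Toy 1-D lists replace the real [batch, time, freq] tensors.
noisy_feats = [0.9, 0.8, 0.1, 0.7]   # log-compressed noisy magnitudes
mask = [1.0, 0.9, 0.1, 1.0]          # model output, roughly in [0, 1]

predict_spec = [m * f for m, f in zip(mask, noisy_feats)]
print(predict_spec)
```

Because only magnitudes are masked, the noisy phase is reused unchanged at the ISTFT stage; phase errors are a known limitation of this family of methods.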
Inputs and Outputs
Inputs (per batch):
- `batch.noisy_sig`: tuple of (noisy waveform tensor [B, T], relative lengths [B])
- `batch.clean_sig`: tuple of (clean waveform tensor [B, T], relative lengths [B])

Outputs:
- Training loss: spectral MSE (or waveform MSE if `waveform_target=True`)
- Validation metrics: PESQ (1-4.5) and STOI (0-1), computed on the enhanced waveforms
- Enhanced wavs: written to `enhanced_folder` during the test stage
Key Configuration Parameters
| Parameter | Default | Description |
|---|---|---|
| `waveform_target` | `False` | If True, compute the loss in the waveform domain instead of the spectral domain |
| `number_of_epochs` | 50 | Total training epochs |
| `N_batch` | 8 | Batch size |
| `lr` | 0.0001 | Learning rate for the Adam optimizer |
| `sorting` | `"ascending"` | Sort training data by length for efficient batching |
| `N_fft` | 512 | FFT size (32 ms at 16 kHz) |
| `Win_length` | 32 | Window length in milliseconds |
| `Hop_length` | 16 | Hop length in milliseconds |
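The window and hop settings are given in milliseconds; converting them to samples confirms the table's "32 ms at 16 kHz" note. The `ms_to_samples` helper and the no-padding frame count below are illustrative arithmetic, not recipe code (actual STFT framing varies slightly by implementation and centering):

```python
# Convert the recipe's millisecond STFT settings to samples at 16 kHz.
SAMPLE_RATE = 16000

def ms_to_samples(ms, sample_rate=SAMPLE_RATE):
    return int(sample_rate * ms / 1000)

win_samples = ms_to_samples(32)   # Win_length: 32 ms
hop_samples = ms_to_samples(16)   # Hop_length: 16 ms
print(win_samples, hop_samples)   # 512 256

# Approximate frame count for a 1-second signal, assuming no padding
# and no centering (real STFTs may add a few edge frames)
n_frames = 1 + (SAMPLE_RATE - win_samples) // hop_samples
print(n_frames)  # 61
```

The 50% hop (256 of 512 samples) is the standard overlap that makes overlap-add ISTFT reconstruction well conditioned.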
See Also
- Principle:Speechbrain_Speechbrain_Conventional_Enhancement_Training -- The theoretical foundation for conventional training
- Implementation:Speechbrain_Speechbrain_Load_Hyperpyyaml_Enhancement -- How architecture and hyperparameters are configured
- Implementation:Speechbrain_Speechbrain_Prepare_Voicebank -- How training data is prepared
- Implementation:Speechbrain_Speechbrain_Composite_Eval_Metrics -- Detailed evaluation metrics