Implementation:Speechbrain Speechbrain SpeakerBrain Compute Forward
| Property | Value |
|---|---|
| Implementation Name | SpeakerBrain Compute Forward |
| Type | API Doc |
| Repository | speechbrain/speechbrain |
| Source File | recipes/VoxCeleb/SpeakerRec/train_speaker_embeddings.py:L30-111 |
| Import | Recipe-specific Brain subclass (defined in the training recipe) |
| Related Principle | Principle:Speechbrain_Speechbrain_Speaker_Embedding_Model_Training |
API Signature
```python
class SpeakerBrain(sb.core.Brain):
    """Class for speaker embedding training"""

    def compute_forward(self, batch, stage):
        """Computation pipeline based on an encoder + speaker classifier."""

    def compute_objectives(self, predictions, batch, stage):
        """Computes the loss using speaker-id as label."""
```
Description
SpeakerBrain is a subclass of speechbrain.core.Brain that implements the training loop for speaker embedding models (ECAPA-TDNN, x-vectors). It defines the forward computation pipeline (feature extraction, normalization, embedding, classification) and the loss computation (cross-entropy with optional augmentation label replication).
compute_forward
Parameters
| Parameter | Type | Description |
|---|---|---|
| batch | PaddedBatch | A batch from the DataLoader containing batch.sig (waveforms, lengths) and batch.spk_id_encoded (speaker labels). |
| stage | sb.Stage | Current stage: sb.Stage.TRAIN, sb.Stage.VALID, or sb.Stage.TEST. |
Returns
| Output | Type | Description |
|---|---|---|
| outputs | torch.Tensor | Classification logits of shape (batch_size, num_speakers). |
| lens | torch.Tensor | Relative lengths of each waveform in the batch. |
Processing Pipeline
```python
def compute_forward(self, batch, stage):
    batch = batch.to(self.device)
    wavs, lens = batch.sig

    # 1. Waveform augmentation (training only)
    if stage == sb.Stage.TRAIN and hasattr(self.hparams, "wav_augment"):
        wavs, lens = self.hparams.wav_augment(wavs, lens)

    # 2. Feature extraction
    feats = self.modules.compute_features(wavs)

    # 3. Mean-variance normalization
    feats = self.modules.mean_var_norm(feats, lens)

    # 4. Embedding model (e.g., ECAPA-TDNN)
    embeddings = self.modules.embedding_model(feats)

    # 5. Speaker classifier
    outputs = self.modules.classifier(embeddings)

    return outputs, lens
```
Step-by-step:
- Move to device: The batch is transferred to the appropriate device (GPU/CPU).
- Waveform augmentation: During training only, `wav_augment` applies noise addition, reverberation, and/or speed perturbation. Augmented copies are appended to the batch, increasing the effective batch size.
- Feature extraction: `compute_features` computes acoustic features (Fbank, MFCC, or Tacotron2 mel spectrogram) from the raw waveform.
- Normalization: `mean_var_norm` applies instance-level mean-variance normalization to the features.
- Embedding model: The ECAPA-TDNN (or x-vector) network maps normalized features to fixed-dimensional embeddings.
- Classifier: A linear layer maps embeddings to speaker class logits.
compute_objectives
Parameters
| Parameter | Type | Description |
|---|---|---|
| predictions | tuple | Output of compute_forward: (logits, lens). |
| batch | PaddedBatch | The same batch passed to compute_forward. |
| stage | sb.Stage | Current stage. |
Returns
| Output | Type | Description |
|---|---|---|
| loss | torch.Tensor | Scalar loss value (cross-entropy). |
Implementation
```python
def compute_objectives(self, predictions, batch, stage):
    predictions, lens = predictions
    uttid = batch.id
    spkid, _ = batch.spk_id_encoded

    # Replicate labels to match augmented batch
    if stage == sb.Stage.TRAIN and hasattr(self.hparams, "wav_augment"):
        spkid = self.hparams.wav_augment.replicate_labels(spkid)

    loss = self.hparams.compute_cost(predictions, spkid, lens)

    # Per-batch LR update (if configured)
    if stage == sb.Stage.TRAIN and hasattr(
        self.hparams.lr_annealing, "on_batch_end"
    ):
        self.hparams.lr_annealing.on_batch_end(self.optimizer)

    # Track error metrics during validation
    if stage != sb.Stage.TRAIN:
        self.error_metrics.append(uttid, predictions, spkid, lens)

    return loss
```
Key behaviors:
- Label replication: When augmentation creates additional copies of each sample, the speaker labels must be replicated accordingly via `wav_augment.replicate_labels()`.
- Loss function: `hparams.compute_cost` is typically `speechbrain.nnet.losses.nll_loss` (negative log-likelihood / cross-entropy).
- Per-batch LR scheduling: Some schedulers (e.g., cyclic) update the learning rate after each batch.
- Error tracking: During validation, classification error rate is tracked for checkpoint selection.
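The label-replication behavior above can be illustrated with plain Python lists. This is a hypothetical stand-in: the real `replicate_labels` works on tensors, and whether the original (unaugmented) copies are kept depends on how the augmenter is configured.

```python
def replicate_labels(labels, n_augmentations):
    """Sketch of label replication for an augmented batch.

    If the augmenter appends n_augmentations full copies of the
    batch after the originals, the label list must be tiled the
    same way so logits and targets stay aligned row-for-row.
    """
    return labels * (1 + n_augmentations)
```

With two augmented copies, a batch of three utterances produces nine rows of logits, so the labels `[0, 1, 2]` become `[0, 1, 2, 0, 1, 2, 0, 1, 2]`.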
Epoch Lifecycle Methods
on_stage_start
```python
def on_stage_start(self, stage, epoch=None):
    if stage != sb.Stage.TRAIN:
        self.error_metrics = self.hparams.error_stats()
```
Initializes error metric tracking at the start of each validation epoch.
on_stage_end
```python
def on_stage_end(self, stage, stage_loss, epoch=None):
    stage_stats = {"loss": stage_loss}
    if stage == sb.Stage.TRAIN:
        self.train_stats = stage_stats
    else:
        stage_stats["ErrorRate"] = self.error_metrics.summarize("average")

    if stage == sb.Stage.VALID:
        old_lr, new_lr = self.hparams.lr_annealing(epoch)
        sb.nnet.schedulers.update_learning_rate(self.optimizer, new_lr)
        self.hparams.train_logger.log_stats(
            stats_meta={"epoch": epoch, "lr": old_lr},
            train_stats=self.train_stats,
            valid_stats=stage_stats,
        )
        self.checkpointer.save_and_keep_only(
            meta={"ErrorRate": stage_stats["ErrorRate"]},
            min_keys=["ErrorRate"],
        )
```
At the end of each validation epoch:
- Applies learning rate annealing based on epoch number.
- Logs training and validation statistics.
- Saves checkpoint and keeps only the best model (minimum error rate).
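The retention policy behind `save_and_keep_only(min_keys=["ErrorRate"])` amounts to keeping only the checkpoint whose tracked metric is smallest. A hypothetical sketch over plain dicts (the real checkpointer also deletes pruned checkpoints from disk and supports multiple keys):

```python
def keep_only_best(checkpoints, key="ErrorRate"):
    """Sketch of min-key checkpoint retention: return only the
    checkpoint with the minimum value of the tracked metric."""
    best = min(checkpoints, key=lambda ckpt: ckpt["meta"][key])
    return [best]
```

For instance, given checkpoints with error rates 0.12, 0.08, and 0.10, only the 0.08 checkpoint survives.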
Usage Example
```python
import sys
import speechbrain as sb
from hyperpyyaml import load_hyperpyyaml

# Load hyperparameters
hparams_file, run_opts, overrides = sb.parse_arguments(sys.argv[1:])
with open(hparams_file) as fin:
    hparams = load_hyperpyyaml(fin, overrides)

# Create datasets (see dataio_prep)
train_data, valid_data, label_encoder = dataio_prep(hparams)

# Initialize the Brain
speaker_brain = SpeakerBrain(
    modules=hparams["modules"],
    opt_class=hparams["opt_class"],
    hparams=hparams,
    run_opts=run_opts,
    checkpointer=hparams["checkpointer"],
)

# Train the model
speaker_brain.fit(
    speaker_brain.hparams.epoch_counter,
    train_data,
    valid_data,
    train_loader_kwargs=hparams["dataloader_options"],
    valid_loader_kwargs=hparams["dataloader_options"],
)
```
Required Module Configuration
The following modules must be defined in the YAML hyperparameters file:
| Module | Description |
|---|---|
| compute_features | Feature extractor (e.g., speechbrain.lobes.features.Fbank) |
| mean_var_norm | Instance normalization (e.g., speechbrain.processing.features.InputNormalization) |
| embedding_model | Speaker encoder (e.g., speechbrain.lobes.models.ECAPA_TDNN.ECAPA_TDNN) |
| classifier | Linear classification head (e.g., speechbrain.lobes.models.ECAPA_TDNN.Classifier) |
| wav_augment | (optional) Waveform augmentation module |
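For reference, the modules above typically map to HyperPyYAML entries along these lines. This is an illustrative sketch only: exact class arguments and values (e.g., the number of output speakers) vary by recipe and SpeechBrain version.

```yaml
compute_features: !new:speechbrain.lobes.features.Fbank
    n_mels: 80

mean_var_norm: !new:speechbrain.processing.features.InputNormalization
    norm_type: sentence

embedding_model: !new:speechbrain.lobes.models.ECAPA_TDNN.ECAPA_TDNN
    input_size: 80
    lin_neurons: 192

classifier: !new:speechbrain.lobes.models.ECAPA_TDNN.Classifier
    input_size: 192
    out_neurons: 7205  # number of training speakers (illustrative)

modules:
    compute_features: !ref <compute_features>
    mean_var_norm: !ref <mean_var_norm>
    embedding_model: !ref <embedding_model>
    classifier: !ref <classifier>
```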
See Also
- Principle:Speechbrain_Speechbrain_Speaker_Embedding_Model_Training
- Implementation:Speechbrain_Speechbrain_Speaker_Dataio_Prep
- Implementation:Speechbrain_Speechbrain_Compute_Embeddings