
Implementation:Speechbrain Speechbrain SpeakerBrain Compute Forward

From Leeroopedia


| Property | Value |
|----------|-------|
| Implementation Name | SpeakerBrain Compute Forward |
| Type | API Doc |
| Repository | speechbrain/speechbrain |
| Source File | recipes/VoxCeleb/SpeakerRec/train_speaker_embeddings.py:L30-111 |
| Import | Recipe-specific Brain subclass (defined in the training recipe) |
| Related Principle | Principle:Speechbrain_Speechbrain_Speaker_Embedding_Model_Training |

API Signature

class SpeakerBrain(sb.core.Brain):
    """Class for speaker embedding training"""

    def compute_forward(self, batch, stage):
        """Computation pipeline based on an encoder + speaker classifier."""

    def compute_objectives(self, predictions, batch, stage):
        """Computes the loss using speaker-id as label."""

Description

SpeakerBrain is a subclass of speechbrain.core.Brain that implements the training loop for speaker embedding models (ECAPA-TDNN, x-vectors). It defines the forward computation pipeline (feature extraction, normalization, embedding, classification) and the loss computation (cross-entropy with optional augmentation label replication).

compute_forward

Parameters

| Parameter | Type | Description |
|-----------|------|-------------|
| batch | PaddedBatch | A batch from the DataLoader containing batch.sig (waveforms, lengths) and batch.spk_id_encoded (speaker labels). |
| stage | sb.Stage | Current stage: sb.Stage.TRAIN, sb.Stage.VALID, or sb.Stage.TEST. |

Returns

| Output | Type | Description |
|--------|------|-------------|
| outputs | torch.Tensor | Classification logits of shape (batch_size, num_speakers). |
| lens | torch.Tensor | Relative lengths of each waveform in the batch. |
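Relative lengths encode each utterance's true duration as a fraction of the padded batch length. A minimal illustration in plain PyTorch (not SpeechBrain's actual PaddedBatch class; the variable names are only for exposition):

```python
import torch

# Two waveforms of 16000 and 8000 samples, zero-padded to the longest.
true_samples = torch.tensor([16000, 8000])
padded = torch.zeros(2, 16000)
lens = true_samples / true_samples.max()   # relative lengths
print(lens)   # tensor([1.0000, 0.5000])

# Downstream modules (e.g. normalization) use lens to ignore padding:
valid = (lens * padded.shape[1]).long()    # recover absolute sample counts
print(valid)  # tensor([16000, 8000])
```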

Processing Pipeline

def compute_forward(self, batch, stage):
    batch = batch.to(self.device)
    wavs, lens = batch.sig

    # 1. Waveform augmentation (training only)
    if stage == sb.Stage.TRAIN and hasattr(self.hparams, "wav_augment"):
        wavs, lens = self.hparams.wav_augment(wavs, lens)

    # 2. Feature extraction
    feats = self.modules.compute_features(wavs)

    # 3. Mean-variance normalization
    feats = self.modules.mean_var_norm(feats, lens)

    # 4. Embedding model (e.g., ECAPA-TDNN)
    embeddings = self.modules.embedding_model(feats)

    # 5. Speaker classifier
    outputs = self.modules.classifier(embeddings)

    return outputs, lens

Step-by-step:

  1. Move to device: The batch is transferred to the appropriate device (GPU/CPU).
  2. Waveform augmentation: During training only, wav_augment applies noise addition, reverberation, and/or speed perturbation. Augmented copies are appended to the batch, increasing the effective batch size.
  3. Feature extraction: compute_features computes acoustic features (Fbank, MFCC, or Tacotron2 mel spectrogram) from the raw waveform.
  4. Normalization: mean_var_norm applies instance-level mean-variance normalization to the features.
  5. Embedding model: The ECAPA-TDNN (or x-vector) network maps normalized features to fixed-dimensional embeddings.
  6. Classifier: A linear layer maps embeddings to speaker class logits.
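The tensor shapes flowing through these steps can be sketched with stand-in modules (plain PyTorch; these are placeholders, not SpeechBrain's real Fbank/ECAPA-TDNN classes, and the dimensions are illustrative):

```python
import torch

batch_size, samples = 4, 16000                 # 1 s of 16 kHz audio each
n_mels, emb_dim, n_speakers = 80, 192, 7205    # VoxCeleb-like sizes

wavs = torch.randn(batch_size, samples)
lens = torch.ones(batch_size)                  # relative lengths (1.0 = full)

def compute_features(wavs):                    # stand-in: 80-dim Fbank at 100 fps
    frames = wavs.shape[1] // 160
    return torch.randn(wavs.shape[0], frames, n_mels)

def mean_var_norm(feats, lens):                # per-utterance standardization
    mean = feats.mean(dim=1, keepdim=True)
    std = feats.std(dim=1, keepdim=True)
    return (feats - mean) / (std + 1e-8)

embedding_model = torch.nn.Sequential(         # placeholder for ECAPA-TDNN
    torch.nn.Flatten(), torch.nn.LazyLinear(emb_dim)
)
classifier = torch.nn.Linear(emb_dim, n_speakers)

feats = mean_var_norm(compute_features(wavs), lens)   # (4, 100, 80)
embeddings = embedding_model(feats)                   # (4, 192)
outputs = classifier(embeddings)                      # (4, 7205)
print(outputs.shape)  # torch.Size([4, 7205])
```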

compute_objectives

Parameters

| Parameter | Type | Description |
|-----------|------|-------------|
| predictions | tuple | Output of compute_forward: (logits, lens). |
| batch | PaddedBatch | The same batch passed to compute_forward. |
| stage | sb.Stage | Current stage. |

Returns

| Output | Type | Description |
|--------|------|-------------|
| loss | torch.Tensor | Scalar loss value (cross-entropy). |

Implementation

def compute_objectives(self, predictions, batch, stage):
    predictions, lens = predictions
    uttid = batch.id
    spkid, _ = batch.spk_id_encoded

    # Replicate labels to match augmented batch
    if stage == sb.Stage.TRAIN and hasattr(self.hparams, "wav_augment"):
        spkid = self.hparams.wav_augment.replicate_labels(spkid)

    loss = self.hparams.compute_cost(predictions, spkid, lens)

    # Per-batch LR update (if configured)
    if stage == sb.Stage.TRAIN and hasattr(
        self.hparams.lr_annealing, "on_batch_end"
    ):
        self.hparams.lr_annealing.on_batch_end(self.optimizer)

    # Track error metrics during validation
    if stage != sb.Stage.TRAIN:
        self.error_metrics.append(uttid, predictions, spkid, lens)

    return loss

Key behaviors:

  • Label replication: When augmentation creates additional copies of each sample, the speaker labels must be replicated accordingly via wav_augment.replicate_labels().
  • Loss function: hparams.compute_cost is typically speechbrain.nnet.losses.nll_loss (negative log-likelihood / cross-entropy).
  • Per-batch LR scheduling: Some schedulers (e.g., cyclic) update the learning rate after each batch.
  • Error tracking: During validation, classification error rate is tracked for checkpoint selection.
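The label-replication behavior above can be sketched in plain PyTorch (a simplified stand-in for what `wav_augment.replicate_labels()` does when augmentation concatenates extra copies of each utterance onto the batch):

```python
import torch

def replicate_labels(labels, n_augmentations):
    """Repeat labels so they match a batch grown by concatenating
    n_augmentations augmented copies of every utterance."""
    return torch.cat([labels] * (1 + n_augmentations), dim=0)

spkid = torch.tensor([3, 7])             # labels for the original 2 utterances
spkid_aug = replicate_labels(spkid, 2)   # batch grew from 2 to 6 utterances
print(spkid_aug)  # tensor([3, 7, 3, 7, 3, 7])
```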

Epoch Lifecycle Methods

on_stage_start

def on_stage_start(self, stage, epoch=None):
    if stage != sb.Stage.TRAIN:
        self.error_metrics = self.hparams.error_stats()

Initializes error metric tracking at the start of each non-training stage (validation or test).

on_stage_end

def on_stage_end(self, stage, stage_loss, epoch=None):
    stage_stats = {"loss": stage_loss}
    if stage == sb.Stage.TRAIN:
        self.train_stats = stage_stats
    else:
        stage_stats["ErrorRate"] = self.error_metrics.summarize("average")

    if stage == sb.Stage.VALID:
        old_lr, new_lr = self.hparams.lr_annealing(epoch)
        sb.nnet.schedulers.update_learning_rate(self.optimizer, new_lr)
        self.hparams.train_logger.log_stats(
            stats_meta={"epoch": epoch, "lr": old_lr},
            train_stats=self.train_stats,
            valid_stats=stage_stats,
        )
        self.checkpointer.save_and_keep_only(
            meta={"ErrorRate": stage_stats["ErrorRate"]},
            min_keys=["ErrorRate"],
        )

At the end of each validation epoch:

  • Applies learning rate annealing based on epoch number.
  • Logs training and validation statistics.
  • Saves checkpoint and keeps only the best model (minimum error rate).
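The epoch-level annealing step can be illustrated with a simplified "new-bob"-style rule (a sketch of the idea behind schedulers such as speechbrain.nnet.schedulers.NewBobScheduler, not its actual implementation; parameter values here are placeholders):

```python
def new_bob_update(lr, prev_loss, curr_loss,
                   annealing_factor=0.5, improvement_threshold=0.0025):
    """Shrink the learning rate when the relative improvement in
    validation loss falls below a threshold; otherwise keep it."""
    relative_improvement = (prev_loss - curr_loss) / prev_loss
    if relative_improvement < improvement_threshold:
        return lr * annealing_factor
    return lr

print(new_bob_update(1.0, 2.0, 2.0))  # no improvement -> 0.5
print(new_bob_update(1.0, 2.0, 1.0))  # large improvement -> 1.0
```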

Usage Example

import sys
import torch
import speechbrain as sb
from hyperpyyaml import load_hyperpyyaml

# Load hyperparameters
hparams_file, run_opts, overrides = sb.parse_arguments(sys.argv[1:])
with open(hparams_file) as fin:
    hparams = load_hyperpyyaml(fin, overrides)

# Create datasets (see dataio_prep)
train_data, valid_data, label_encoder = dataio_prep(hparams)

# Initialize the Brain
speaker_brain = SpeakerBrain(
    modules=hparams["modules"],
    opt_class=hparams["opt_class"],
    hparams=hparams,
    run_opts=run_opts,
    checkpointer=hparams["checkpointer"],
)

# Train the model
speaker_brain.fit(
    speaker_brain.hparams.epoch_counter,
    train_data,
    valid_data,
    train_loader_kwargs=hparams["dataloader_options"],
    valid_loader_kwargs=hparams["dataloader_options"],
)

Required Module Configuration

The following modules must be defined in the YAML hyperparameters file:

| Module | Description |
|--------|-------------|
| compute_features | Feature extractor (e.g., speechbrain.lobes.features.Fbank) |
| mean_var_norm | Instance normalization (e.g., speechbrain.processing.features.InputNormalization) |
| embedding_model | Speaker encoder (e.g., speechbrain.lobes.models.ECAPA_TDNN.ECAPA_TDNN) |
| classifier | Linear classification head (e.g., speechbrain.lobes.models.ECAPA_TDNN.Classifier) |
| wav_augment | Waveform augmentation module (optional) |
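A hypothetical hyperpyyaml fragment wiring these modules together (illustrative only; the actual recipe YAML differs in detail, and the dimension values here are placeholders):

```yaml
compute_features: !new:speechbrain.lobes.features.Fbank
    n_mels: 80

mean_var_norm: !new:speechbrain.processing.features.InputNormalization
    norm_type: sentence
    std_norm: False

embedding_model: !new:speechbrain.lobes.models.ECAPA_TDNN.ECAPA_TDNN
    input_size: 80
    lin_neurons: 192

classifier: !new:speechbrain.lobes.models.ECAPA_TDNN.Classifier
    input_size: 192
    out_neurons: 7205

modules:
    compute_features: !ref <compute_features>
    mean_var_norm: !ref <mean_var_norm>
    embedding_model: !ref <embedding_model>
    classifier: !ref <classifier>
```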
