

Implementation:Speechbrain Speechbrain Whisper ASR Compute Forward

From Leeroopedia


API: ASR.compute_forward(self, batch, stage) and ASR.compute_objectives(self, predictions, batch, stage)
Source: recipes/CommonVoice/ASR/transformer/train_with_whisper.py:L29 (class ASR), L30-65 (compute_forward), L67-118 (compute_objectives)
Import: none; ASR is a Brain subclass defined in the recipe script itself. Through the hparams file it uses speechbrain.nnet.losses.nll_loss and speechbrain.nnet.schedulers.NoamScheduler.
Type: API Doc
Inputs: PaddedBatch with audio signals (batch.sig) and tokenized text (batch.tokens_bos, batch.tokens_eos, plus batch.tokens for evaluation)
Outputs: log-probabilities over the vocabulary (log_probs), decoded hypotheses (hyps) and wav_lens from compute_forward; the NLL loss from compute_objectives
Related Principle: Principle:Speechbrain_Speechbrain_Whisper_Finetuning_With_LR_Scheduling

Purpose

Implements the per-batch and per-stage hooks that SpeechBrain's Brain training loop calls when fine-tuning Whisper: the forward pass through the Whisper encoder-decoder, the NLL loss, learning-rate scheduling, and WER/CER evaluation during validation and testing.

Class Definition

import torch
import speechbrain as sb
from speechbrain.utils.data_utils import undo_padding  # used in compute_objectives

class ASR(sb.Brain):
    def compute_forward(self, batch, stage):
        """Forward computations from the waveform batches
        to the output probabilities."""
        ...

    def compute_objectives(self, predictions, batch, stage):
        """Computes the NLL loss from the predictions and the targets."""
        ...

    def on_stage_start(self, stage, epoch):
        """Called at the beginning of each stage (train, valid, or test)."""
        ...

    def on_stage_end(self, stage, stage_loss, epoch):
        """Called at the end of each stage."""
        ...

compute_forward(self, batch, stage)

Performs the forward pass through the Whisper model.

def compute_forward(self, batch, stage):
    batch = batch.to(self.device)
    wavs, wav_lens = batch.sig
    bos_tokens, bos_tokens_lens = batch.tokens_bos

    # Optional waveform augmentation during training
    if stage == sb.Stage.TRAIN and hasattr(self.hparams, "wav_augment"):
        wavs, wav_lens = self.hparams.wav_augment(wavs, wav_lens)
        bos_tokens = self.hparams.wav_augment.replicate_labels(bos_tokens)
        bos_tokens_lens = self.hparams.wav_augment.replicate_labels(
            bos_tokens_lens
        )

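    # Convert relative lengths to absolute token counts. Rounding keeps float
    # error from dropping a real token, and building the mask over the full
    # padded width guarantees it has the same shape as bos_tokens.
    abs_tokens_lens = torch.round(bos_tokens_lens * bos_tokens.shape[1]).long()
    pad_mask = (
        torch.arange(bos_tokens.shape[1], device=self.device)[None, :]
        < abs_tokens_lens[:, None]
    )
    # PaddedBatch pads with 0; the Whisper decoder expects pad_token_id instead
    bos_tokens[~pad_mask] = self.tokenizer.pad_token_id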

    # Forward through Whisper encoder + decoder
    enc_out, logits, _ = self.modules.whisper(wavs, bos_tokens)
    log_probs = self.hparams.log_softmax(logits)

    # Decoding for validation/test
    hyps = None
    if stage == sb.Stage.VALID:
        hyps, _, _, _ = self.hparams.valid_search(
            enc_out.detach(), wav_lens
        )
    elif stage == sb.Stage.TEST:
        hyps, _, _, _ = self.hparams.test_search(
            enc_out.detach(), wav_lens
        )

    return log_probs, hyps, wav_lens
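The padding step is easiest to follow on a toy batch. PaddedBatch stores each sequence length as a fraction of the padded length and pads with 0 by default, and 0 is a regular token id in Whisper's BPE vocabulary rather than a padding symbol, which is why padded positions are overwritten with pad_token_id. The plain-Python sketch below mirrors that logic without torch; the helper name mask_decoder_padding and the pad id 99 are hypothetical.

```python
from typing import List


def mask_decoder_padding(
    tokens: List[List[int]], rel_lens: List[float], pad_token_id: int
) -> List[List[int]]:
    """Hypothetical helper: overwrite positions past each true length.

    tokens: right-padded token ids, every row of the same length T.
    rel_lens: lengths relative to T, in (0, 1], as PaddedBatch stores them.
    """
    padded_len = len(tokens[0])
    masked = []
    for row, rel in zip(tokens, rel_lens):
        # Relative length -> absolute number of real tokens
        abs_len = round(rel * padded_len)
        masked.append([tok if i < abs_len else pad_token_id for i, tok in enumerate(row)])
    return masked


batch = [[5, 6, 7, 8], [9, 10, 0, 0]]  # second row was zero-padded
print(mask_decoder_padding(batch, rel_lens=[1.0, 0.5], pad_token_id=99))
# [[5, 6, 7, 8], [9, 10, 99, 99]]
```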

Key steps:

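  1. Move the batch to self.device (typically a GPU).
  2. Extract the audio waveforms and the BOS-prefixed token sequences used as decoder inputs.
  3. During training, apply the optional wav_augment pipeline (speed perturbation, frequency/chunk dropping) and replicate the token labels so they stay aligned with the augmented batch.
  4. Convert relative token lengths to absolute lengths, build a padding mask, and overwrite padded positions with the tokenizer's pad_token_id.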
  5. Run Whisper's full encoder-decoder forward pass.
  6. Apply log-softmax to decoder logits.
  7. During validation: run greedy search (S2SWhisperGreedySearcher) for hypothesis generation.
  8. During testing: run beam search (S2SWhisperBeamSearcher) for hypothesis generation.
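Step 6 produces the log-probabilities that compute_objectives later scores against batch.tokens_eos. The plain-Python sketch below shows a numerically stable log-softmax followed by a length-masked mean NLL. It assumes that nll_loss with its default reduction averages over non-padded target positions only, which is our reading of SpeechBrain's masked loss; label smoothing and the other nll_loss options are omitted, and the function names are hypothetical.

```python
import math
from typing import List


def log_softmax(logits: List[float]) -> List[float]:
    # Subtract the max before exponentiating to avoid overflow
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))
    return [x - log_z for x in logits]


def masked_nll(
    logits: List[List[List[float]]],  # [batch][time][vocab]
    targets: List[List[int]],         # [batch][time]
    rel_lens: List[float],            # relative target lengths in (0, 1]
) -> float:
    """Mean negative log-likelihood over non-padded target positions."""
    total, count = 0.0, 0
    for seq_logits, seq_targets, rel in zip(logits, targets, rel_lens):
        n_valid = round(rel * len(seq_targets))
        for t in range(n_valid):
            total -= log_softmax(seq_logits[t])[seq_targets[t]]
            count += 1
    return total / count


# Uniform logits give log(4) per scored token; the extreme logits at the
# padded position of the second sequence are never scored.
logits = [[[0.0] * 4, [0.0] * 4], [[0.0] * 4, [50.0, -50.0, 3.0, 1.0]]]
targets = [[1, 2], [3, 1]]
loss = masked_nll(logits, targets, rel_lens=[1.0, 0.5])
print(round(loss, 4), round(math.log(4), 4))  # 1.3863 1.3863
```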

compute_objectives(self, predictions, batch, stage)

Computes the NLL loss and, during validation and testing, accumulates WER/CER statistics.

def compute_objectives(self, predictions, batch, stage):
    (log_probs, hyps, wav_lens) = predictions
    batch = batch.to(self.device)
    ids = batch.id
    tokens_eos, tokens_eos_lens = batch.tokens_eos

    # Replicate labels if augmentation was applied
    if stage == sb.Stage.TRAIN and hasattr(self.hparams, "wav_augment"):
        tokens_eos = self.hparams.wav_augment.replicate_labels(tokens_eos)
        tokens_eos_lens = self.hparams.wav_augment.replicate_labels(
            tokens_eos_lens
        )

    # Compute NLL loss
    loss = self.hparams.nll_loss(
        log_probs, tokens_eos, length=tokens_eos_lens
    )

    # Evaluation metrics (validation/test only)
    if stage != sb.Stage.TRAIN:
        tokens, tokens_lens = batch.tokens

        # Decode hypothesis tokens to text
        predicted_words = [
            self.tokenizer.decode(t, skip_special_tokens=True).strip()
            for t in hyps
        ]

        # Decode target tokens to text
        target_words = undo_padding(tokens, tokens_lens)
        target_words = self.tokenizer.batch_decode(
            target_words, skip_special_tokens=True
        )

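        # Apply Whisper text normalization when the hparams file defines
        # normalized_transcripts (hasattr tests presence, not truthiness,
        # so even `normalized_transcripts: False` enables it)
        if hasattr(self.hparams, "normalized_transcripts"):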
            predicted_words = [
                self.tokenizer.normalize(text).split(" ")
                for text in predicted_words
            ]
            target_words = [
                self.tokenizer.normalize(text).split(" ")
                for text in target_words
            ]
        else:
            predicted_words = [text.split(" ") for text in predicted_words]
            target_words = [text.split(" ") for text in target_words]

        self.wer_metric.append(ids, predicted_words, target_words)
        self.cer_metric.append(ids, predicted_words, target_words)

    return loss
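For intuition about what wer_metric and cer_metric accumulate: as far as we know the recipe builds both from SpeechBrain's ErrorRateStats (the CER one configured to split words into characters), which reports a corpus-level rate, total edits divided by total reference units, times 100. The sketch below is a plain-Python stand-in with hypothetical function names, not that class, and it skips the alignment details the real class records.

```python
from typing import List, Sequence


def edit_distance(ref: Sequence[str], hyp: Sequence[str]) -> int:
    """Levenshtein distance; substitutions, insertions and deletions cost 1."""
    prev = list(range(len(hyp) + 1))
    for i, r in enumerate(ref, start=1):
        cur = [i] + [0] * len(hyp)
        for j, h in enumerate(hyp, start=1):
            cur[j] = min(
                prev[j] + 1,             # deletion
                cur[j - 1] + 1,          # insertion
                prev[j - 1] + (r != h),  # substitution (or match at cost 0)
            )
        prev = cur
    return prev[-1]


def error_rate(refs: List[List[str]], hyps: List[List[str]]) -> float:
    """Corpus-level error rate in percent: total edits / total reference units."""
    edits = sum(edit_distance(r, h) for r, h in zip(refs, hyps))
    units = sum(len(r) for r in refs)
    return 100.0 * edits / units


# Word lists in the same shape compute_objectives passes to the metrics
targets = [["the", "cat", "sat"], ["hello", "world"]]
predicted = [["the", "cat", "sat"], ["hello", "word"]]

wer = error_rate(targets, predicted)  # 1 edit over 5 reference words
# CER: rejoin the words with spaces and compare character sequences
cer = error_rate(
    [list(" ".join(w)) for w in targets],
    [list(" ".join(w)) for w in predicted],
)
print(wer, cer)
```

Scoring at corpus level means long utterances weigh more than short ones, unlike an average of per-utterance rates.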

YAML Configuration for Training

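# Excerpt: <whisper>, min_decode_ratio, max_decode_ratio and test_beam_size
# are defined elsewhere in the same hparams file.

# Optimizer
lr_whisper: 1.0e-5  # dotted mantissa parses as a float under YAML 1.1 and 1.2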
weight_decay: 0.01
warmup_steps: 500
max_grad_norm: 2.0

whisper_opt_class: !name:torch.optim.AdamW
    lr: !ref <lr_whisper>
    weight_decay: !ref <weight_decay>

# Learning rate scheduler
lr_annealing_whisper: !new:speechbrain.nnet.schedulers.NoamScheduler
    lr_initial: !ref <lr_whisper>
    n_warmup_steps: !ref <warmup_steps>

# Loss
nll_loss: !name:speechbrain.nnet.losses.nll_loss

log_softmax: !new:speechbrain.nnet.activations.Softmax
    apply_log: True

# Decoding
valid_search: !new:speechbrain.decoders.seq2seq.S2SWhisperGreedySearcher
    model: !ref <whisper>
    min_decode_ratio: !ref <min_decode_ratio>
    max_decode_ratio: !ref <max_decode_ratio>

test_search: !new:speechbrain.decoders.seq2seq.S2SWhisperBeamSearcher
    module: [!ref <whisper>]
    min_decode_ratio: !ref <min_decode_ratio>
    max_decode_ratio: !ref <max_decode_ratio>
    beam_size: !ref <test_beam_size>
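NoamScheduler raises the learning rate linearly over the first n_warmup_steps calls and then decays it in proportion to the inverse square root of the call count. The sketch below reproduces that formula as we understand SpeechBrain's implementation, assuming model_size is left unset so that the peak equals lr_initial; noam_lr is a hypothetical helper. Because warmup is counted in scheduler calls, the scheduler is designed to be called once per optimizer update.

```python
def noam_lr(step: int, lr_initial: float, n_warmup_steps: int) -> float:
    """Hypothetical helper: learning rate after `step` scheduler calls (step >= 1).

    Assumes model_size is unset, so the scale is normalized by
    sqrt(n_warmup_steps) and peaks at lr_initial when step == n_warmup_steps.
    """
    normalize = n_warmup_steps ** 0.5
    # Warmup term (step / warmup^1.5) vs. inverse-sqrt decay term (step^-0.5)
    scale = normalize * min(step ** -0.5, step * n_warmup_steps ** -1.5)
    return lr_initial * scale


# lr_whisper and warmup_steps as in the YAML above
for step in (50, 250, 500, 2000):
    print(step, noam_lr(step, 1.0e-5, 500))
```

With these settings the rate reaches half of lr_whisper at step 250, peaks at 1e-5 at step 500, and is back to half by step 2000.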

Full Training Script Usage

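import sys

import speechbrain as sb
from hyperpyyaml import load_hyperpyyaml

# Parse CLI arguments and load hparams; data preparation and the
# train_data / valid_data / test_data pipelines are omitted here
hparams_file, run_opts, overrides = sb.parse_arguments(sys.argv[1:])
with open(hparams_file, encoding="utf-8") as fin:
    hparams = load_hyperpyyaml(fin, overrides)

# Initialize the ASR Brain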
asr_brain = ASR(
    modules=hparams["modules"],
    hparams=hparams,
    run_opts=run_opts,
    checkpointer=hparams["checkpointer"],
    opt_class=hparams["whisper_opt_class"],
)

# Load pretrained weights if available
if "pretrainer" in hparams:
    hparams["pretrainer"].collect_files()
    hparams["pretrainer"].load_collected(asr_brain.device)

# Attach tokenizer
asr_brain.tokenizer = hparams["whisper"].tokenizer

# Train
asr_brain.fit(
    asr_brain.hparams.epoch_counter,
    train_data,
    valid_data,
    train_loader_kwargs=hparams["train_loader_kwargs"],
    valid_loader_kwargs=hparams["valid_loader_kwargs"],
)

# Test (loads best checkpoint by min WER)
asr_brain.evaluate(
    test_data,
    min_key="WER",
    test_loader_kwargs=hparams["test_loader_kwargs"],
)

Epoch Lifecycle

Method (when called): action

  on_stage_start (beginning of each stage): initializes the WER and CER metric computers for validation/test.
  compute_forward (each batch): runs the Whisper encoder-decoder forward pass and optional decoding.
  compute_objectives (each batch): computes the NLL loss and appends WER/CER metrics.
  on_stage_end (end of each stage): logs stats, updates the LR scheduler, and saves a checkpoint (validation) or writes the WER file (test).
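The call order above can be illustrated with a toy loop. ToyBrain and run_stage are hypothetical stand-ins that only record hook order; they ignore everything else Brain.fit and Brain.evaluate do (optimization, gradient clipping, checkpointing, device handling).

```python
from typing import Any, List


class ToyBrain:
    """Hypothetical stand-in that only records which hooks run, in order."""

    def __init__(self) -> None:
        self.calls: List[str] = []

    def on_stage_start(self, stage: str, epoch: int) -> None:
        self.calls.append(f"on_stage_start:{stage}")

    def compute_forward(self, batch: Any, stage: str) -> Any:
        self.calls.append("compute_forward")
        return batch

    def compute_objectives(self, predictions: Any, batch: Any, stage: str) -> float:
        self.calls.append("compute_objectives")
        return 0.0

    def on_stage_end(self, stage: str, stage_loss: float, epoch: int) -> None:
        self.calls.append(f"on_stage_end:{stage}")


def run_stage(brain: ToyBrain, stage: str, batches: List[Any], epoch: int) -> float:
    """One stage: start hook, forward + objectives per batch, end hook."""
    brain.on_stage_start(stage, epoch)
    total = 0.0
    for batch in batches:
        predictions = brain.compute_forward(batch, stage)
        total += brain.compute_objectives(predictions, batch, stage)
    avg_loss = total / max(len(batches), 1)
    brain.on_stage_end(stage, avg_loss, epoch)
    return avg_loss


brain = ToyBrain()
run_stage(brain, "TRAIN", [1, 2], epoch=1)
run_stage(brain, "VALID", [3], epoch=1)
print(brain.calls)
```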
