Implementation:Speechbrain Speechbrain Whisper ASR Compute Forward
| Field | Value |
|---|---|
| API | ASR.compute_forward(self, batch, stage) and ASR.compute_objectives(self, predictions, batch, stage) |
| Source | recipes/CommonVoice/ASR/transformer/train_with_whisper.py:L29 (class ASR), L30-65 (compute_forward), L67-118 (compute_objectives) |
| Import | Recipe-specific Brain subclass. Uses speechbrain.nnet.losses.nll_loss and speechbrain.nnet.schedulers.NoamScheduler |
| Type | API Doc |
| Inputs | PaddedBatch with audio signals (batch.sig) and tokenized text (batch.tokens_bos, batch.tokens_eos) |
| Outputs | Log probabilities over vocabulary (log_probs), decoded hypotheses (hyps), NLL loss |
| Related Principle | Principle:Speechbrain_Speechbrain_Whisper_Finetuning_With_LR_Scheduling |
Purpose
Implements the core training loop for Whisper fine-tuning as a SpeechBrain Brain subclass. Handles the forward pass through the Whisper encoder-decoder, NLL loss computation, learning rate scheduling, and WER/CER evaluation during validation and testing.
Class Definition
```python
import speechbrain as sb

class ASR(sb.Brain):
    def compute_forward(self, batch, stage):
        """Forward computations from the waveform batches
        to the output probabilities."""
        ...

    def compute_objectives(self, predictions, batch, stage):
        """Computes the NLL loss given predictions and targets."""
        ...

    def on_stage_start(self, stage, epoch):
        """Gets called at the beginning of each epoch."""
        ...

    def on_stage_end(self, stage, stage_loss, epoch):
        """Gets called at the end of an epoch."""
        ...
```
compute_forward(self, batch, stage)
Performs the forward pass through the Whisper model.
```python
def compute_forward(self, batch, stage):
    batch = batch.to(self.device)
    wavs, wav_lens = batch.sig
    bos_tokens, bos_tokens_lens = batch.tokens_bos

    # Optional waveform augmentation during training
    if stage == sb.Stage.TRAIN and hasattr(self.hparams, "wav_augment"):
        wavs, wav_lens = self.hparams.wav_augment(wavs, wav_lens)
        bos_tokens = self.hparams.wav_augment.replicate_labels(bos_tokens)
        bos_tokens_lens = self.hparams.wav_augment.replicate_labels(
            bos_tokens_lens
        )

    # Compute padding mask for the Whisper decoder
    abs_tokens_lens = (bos_tokens_lens * bos_tokens.shape[1]).long()
    pad_mask = (
        torch.arange(abs_tokens_lens.max(), device=self.device)[None, :]
        < abs_tokens_lens[:, None]
    )
    bos_tokens[~pad_mask] = self.tokenizer.pad_token_id

    # Forward through the Whisper encoder + decoder
    enc_out, logits, _ = self.modules.whisper(wavs, bos_tokens)
    log_probs = self.hparams.log_softmax(logits)

    # Decoding for validation/test
    hyps = None
    if stage == sb.Stage.VALID:
        hyps, _, _, _ = self.hparams.valid_search(enc_out.detach(), wav_lens)
    elif stage == sb.Stage.TEST:
        hyps, _, _, _ = self.hparams.test_search(enc_out.detach(), wav_lens)

    return log_probs, hyps, wav_lens
```
Key steps:
- Move batch to device (GPU).
- Extract audio waveforms and BOS token sequences.
- Apply optional data augmentation (speed perturbation, frequency/chunk dropping) during training.
- Compute a padding mask and replace padding positions with pad_token_id.
- Run Whisper's full encoder-decoder forward pass.
- Apply log-softmax to decoder logits.
- During validation: run greedy search (S2SWhisperGreedySearcher) for hypothesis generation.
- During testing: run beam search (S2SWhisperBeamSearcher) for hypothesis generation.
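The padding-mask step above can be illustrated in isolation with a toy batch. SpeechBrain's `PaddedBatch` stores lengths as fractions of the padded length, which is why the code multiplies by `bos_tokens.shape[1]`; the token values and pad id below are made up for illustration.

```python
import torch

# Toy batch: 2 token sequences padded to length 5,
# with relative lengths as in SpeechBrain's PaddedBatch.
bos_tokens = torch.tensor([[1, 7, 8, 9, 2],
                           [1, 7, 2, 0, 0]])
bos_tokens_lens = torch.tensor([1.0, 0.6])  # fractions of max length

pad_token_id = 50257  # hypothetical pad id for illustration

# Convert relative lengths to absolute token counts
abs_tokens_lens = (bos_tokens_lens * bos_tokens.shape[1]).long()  # [5, 3]

# Boolean mask: True where a position holds a real token
pad_mask = (
    torch.arange(abs_tokens_lens.max())[None, :] < abs_tokens_lens[:, None]
)

# Overwrite padding positions with the tokenizer's pad id
bos_tokens[~pad_mask] = pad_token_id
```

After this, every padded position in the second row holds `pad_token_id`, so the decoder sees a consistent pad symbol rather than whatever value the collate function used.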
compute_objectives(self, predictions, batch, stage)
Computes the NLL loss and evaluation metrics.
```python
def compute_objectives(self, predictions, batch, stage):
    (log_probs, hyps, wav_lens) = predictions
    batch = batch.to(self.device)
    ids = batch.id
    tokens_eos, tokens_eos_lens = batch.tokens_eos

    # Replicate labels if augmentation was applied
    if stage == sb.Stage.TRAIN and hasattr(self.hparams, "wav_augment"):
        tokens_eos = self.hparams.wav_augment.replicate_labels(tokens_eos)
        tokens_eos_lens = self.hparams.wav_augment.replicate_labels(
            tokens_eos_lens
        )

    # Compute NLL loss
    loss = self.hparams.nll_loss(
        log_probs, tokens_eos, length=tokens_eos_lens
    )

    # Evaluation metrics (validation/test only)
    if stage != sb.Stage.TRAIN:
        tokens, tokens_lens = batch.tokens

        # Decode hypothesis tokens to text
        predicted_words = [
            self.tokenizer.decode(t, skip_special_tokens=True).strip()
            for t in hyps
        ]

        # Decode target tokens to text
        # (undo_padding is imported from speechbrain.utils.data_utils)
        target_words = undo_padding(tokens, tokens_lens)
        target_words = self.tokenizer.batch_decode(
            target_words, skip_special_tokens=True
        )

        # Apply Whisper normalization if configured
        if hasattr(self.hparams, "normalized_transcripts"):
            predicted_words = [
                self.tokenizer.normalize(text).split(" ")
                for text in predicted_words
            ]
            target_words = [
                self.tokenizer.normalize(text).split(" ")
                for text in target_words
            ]
        else:
            predicted_words = [text.split(" ") for text in predicted_words]
            target_words = [text.split(" ") for text in target_words]

        self.wer_metric.append(ids, predicted_words, target_words)
        self.cer_metric.append(ids, predicted_words, target_words)

    return loss
```
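The `nll_loss` call takes relative lengths so padded positions do not contribute to the loss. A standalone sketch of that masked, length-normalized NLL in plain torch (this approximates `speechbrain.nnet.losses.nll_loss` under its default reduction; the real function also supports label smoothing and other options):

```python
import torch
import torch.nn.functional as F

# Toy predictions: batch=2, time=4, vocab=6
torch.manual_seed(0)
log_probs = F.log_softmax(torch.randn(2, 4, 6), dim=-1)
targets = torch.tensor([[1, 2, 3, 0],
                        [4, 5, 0, 0]])
lengths = torch.tensor([1.0, 0.5])  # relative, as in PaddedBatch

# Convert relative lengths to a boolean mask over valid positions
abs_lens = (lengths * targets.shape[1]).long()          # [4, 2]
mask = torch.arange(targets.shape[1])[None, :] < abs_lens[:, None]

# Gather the log-probability assigned to each target token,
# negate it, and average over the unpadded positions only
per_token = -log_probs.gather(-1, targets.unsqueeze(-1)).squeeze(-1)
loss = (per_token * mask).sum() / mask.sum()
```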
YAML Configuration for Training
```yaml
# Optimizer
lr_whisper: 1e-5
weight_decay: 0.01
warmup_steps: 500
max_grad_norm: 2.0

whisper_opt_class: !name:torch.optim.AdamW
    lr: !ref <lr_whisper>
    weight_decay: !ref <weight_decay>

# Learning rate scheduler
lr_annealing_whisper: !new:speechbrain.nnet.schedulers.NoamScheduler
    lr_initial: !ref <lr_whisper>
    n_warmup_steps: !ref <warmup_steps>

# Loss
nll_loss: !name:speechbrain.nnet.losses.nll_loss
log_softmax: !new:speechbrain.nnet.activations.Softmax
    apply_log: True

# Decoding
valid_search: !new:speechbrain.decoders.seq2seq.S2SWhisperGreedySearcher
    model: !ref <whisper>
    min_decode_ratio: !ref <min_decode_ratio>
    max_decode_ratio: !ref <max_decode_ratio>

test_search: !new:speechbrain.decoders.seq2seq.S2SWhisperBeamSearcher
    module: [!ref <whisper>]
    min_decode_ratio: !ref <min_decode_ratio>
    max_decode_ratio: !ref <max_decode_ratio>
    beam_size: !ref <test_beam_size>
```
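The Noam schedule configured above ramps the learning rate linearly over `n_warmup_steps` and then decays it proportionally to `1/sqrt(step)`. A standalone sketch of that formula, scaled so the peak equals `lr_initial` at the warmup step (this matches our reading of `NoamScheduler`; the function name `noam_lr` is ours, not part of SpeechBrain):

```python
def noam_lr(step, lr_initial=1e-5, n_warmup_steps=500):
    # Linear warmup to lr_initial, then inverse-sqrt decay.
    # The sqrt(n_warmup_steps) factor normalizes the peak to lr_initial.
    scale = n_warmup_steps ** 0.5 * min(
        step ** -0.5, step * n_warmup_steps ** -1.5
    )
    return lr_initial * scale

print(noam_lr(250))   # half of lr_initial, mid-warmup
print(noam_lr(500))   # peak: lr_initial
print(noam_lr(1000))  # decayed below lr_initial
```

With `lr_whisper: 1e-5` and `warmup_steps: 500`, the LR climbs to 1e-5 over the first 500 optimizer steps, then decays, which is the usual recipe for fine-tuning a pretrained model without destroying its weights early on.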
Full Training Script Usage
```python
# Initialize the ASR Brain
asr_brain = ASR(
    modules=hparams["modules"],
    hparams=hparams,
    run_opts=run_opts,
    checkpointer=hparams["checkpointer"],
    opt_class=hparams["whisper_opt_class"],
)

# Load pretrained weights if available
if "pretrainer" in hparams:
    hparams["pretrainer"].collect_files()
    hparams["pretrainer"].load_collected(asr_brain.device)

# Attach the tokenizer
asr_brain.tokenizer = hparams["whisper"].tokenizer

# Train
asr_brain.fit(
    asr_brain.hparams.epoch_counter,
    train_data,
    valid_data,
    train_loader_kwargs=hparams["train_loader_kwargs"],
    valid_loader_kwargs=hparams["valid_loader_kwargs"],
)

# Test (loads the best checkpoint by minimum WER)
asr_brain.evaluate(
    test_data,
    min_key="WER",
    test_loader_kwargs=hparams["test_loader_kwargs"],
)
```
Epoch Lifecycle
| Method | When Called | Action |
|---|---|---|
| on_stage_start | Beginning of each stage | Initializes WER and CER metric computers for validation/test |
| compute_forward | Each batch | Runs Whisper encoder-decoder forward pass and optional decoding |
| compute_objectives | Each batch | Computes NLL loss and appends WER/CER metrics |
| on_stage_end | End of each stage | Logs stats, updates LR scheduler, saves checkpoint (validation) or writes WER file (test) |
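The WER appended in `compute_objectives` follows the standard word-level edit-distance definition (total substitutions, insertions, and deletions divided by total reference words). A minimal pure-Python sketch over the same word-list format (the helper names are ours; the recipe itself uses SpeechBrain's `ErrorRateStats`):

```python
def edit_distance(ref, hyp):
    # Levenshtein distance over word lists; substitutions,
    # insertions, and deletions all cost 1.
    d = list(range(len(hyp) + 1))
    for i, r in enumerate(ref, 1):
        prev, d[0] = d[0], i
        for j, h in enumerate(hyp, 1):
            prev, d[j] = d[j], min(d[j] + 1, d[j - 1] + 1,
                                   prev + (r != h))
    return d[-1]

def wer(refs, hyps):
    # WER = total edits / total reference words over the batch
    edits = sum(edit_distance(r, h) for r, h in zip(refs, hyps))
    words = sum(len(r) for r in refs)
    return edits / words

target_words = [["hello", "world"], ["good", "morning"]]
predicted_words = [["hello", "word"], ["good", "morning"]]
print(wer(target_words, predicted_words))  # 0.25: 1 edit / 4 ref words
```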
See Also
- Principle:Speechbrain_Speechbrain_Whisper_Finetuning_With_LR_Scheduling
- Implementation:Speechbrain_Speechbrain_Whisper_HFTransformersInterface
- Implementation:Speechbrain_Speechbrain_ErrorRateStats_For_Whisper