
Principle:Speechbrain Speechbrain Whisper Finetuning With LR Scheduling

From Leeroopedia


  • Concept: Fine-tuning pretrained encoder-decoder models with learning rate warmup and decay schedules
  • Domains: Transfer_Learning, ASR, Optimization
  • Knowledge Sources: Vaswani et al. 2017, "Attention Is All You Need" (Noam scheduler); Radford et al. 2023, "Robust Speech Recognition via Large-Scale Weak Supervision"
  • Related Implementation: Implementation:Speechbrain_Speechbrain_Whisper_ASR_Compute_Forward

Overview

Fine-tuning large pretrained models such as Whisper requires careful learning rate management to balance preserving pretrained knowledge with adapting to the target domain. The Noam learning rate scheduler, combined with AdamW optimization and encoder freezing, provides an effective training regime for Whisper fine-tuning within SpeechBrain.

Learning Rate Scheduling

The Noam scheduler (from "Attention is All You Need") provides a warmup phase followed by inverse square root decay:

LR(step) = d_model^(-0.5) * min(step^(-0.5), step * warmup_steps^(-1.5))

This schedule has two phases:

  • Warmup phase (step < warmup_steps): The learning rate increases linearly from near-zero to the peak value. This prevents large, destructive gradient updates in the early steps when the model has not yet adapted to the new data distribution.
  • Decay phase (step >= warmup_steps): The learning rate decays proportionally to the inverse square root of the step number, gradually reducing the update magnitude as training converges.

For Whisper fine-tuning, typical hyperparameters are:

  • lr_initial: 1e-5 (a conservative starting point for fine-tuning)
  • warmup_steps: 500 (sufficient for the model to warm up on the target data)
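
The schedule and the hyperparameters above can be sketched in plain Python. This is a minimal illustration, not SpeechBrain's `NoamScheduler` itself; in particular, the scaling is an assumption here, chosen so that the peak learning rate (reached at `warmup_steps`) equals `lr_initial`:

```python
def noam_lr(step, lr_initial=1e-5, warmup_steps=500):
    """Noam schedule: linear warmup, then inverse-square-root decay.

    Normalized (an assumption of this sketch) so the peak equals lr_initial.
    """
    step = max(step, 1)  # avoid division by zero at step 0
    scale = warmup_steps ** 0.5
    return lr_initial * scale * min(step ** -0.5, step * warmup_steps ** -1.5)
```

Plotting `noam_lr` over training steps shows the linear ramp up to step 500 followed by the slow `1/sqrt(step)` decay.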

Optimization Strategy

The fine-tuning recipe uses AdamW (Adam with decoupled weight decay):

  • Weight decay: 0.01 (provides L2 regularization to prevent overfitting on small target datasets)
  • Gradient clipping: max_grad_norm=2.0 (prevents gradient explosion during fine-tuning)
  • Gradient accumulation: grad_accumulation_factor=2 (effective batch size = batch_size * 2)
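
The three pieces above combine into a standard PyTorch training-step pattern. The sketch below uses a toy linear model and MSE loss purely as stand-ins (the recipe applies the same pattern to Whisper with NLL loss); note the loss is divided by the accumulation factor so accumulated gradients average over the effective batch:

```python
import torch

# Toy model and synthetic batches; stand-ins for Whisper and real audio data.
model = torch.nn.Linear(4, 2)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5, weight_decay=0.01)

grad_accumulation_factor = 2
max_grad_norm = 2.0
batches = [(torch.randn(8, 4), torch.randn(8, 2)) for _ in range(4)]

for i, (x, y) in enumerate(batches):
    # Scale the loss so accumulated gradients average over the effective batch
    loss = torch.nn.functional.mse_loss(model(x), y) / grad_accumulation_factor
    loss.backward()
    if (i + 1) % grad_accumulation_factor == 0:
        # Clip the global gradient norm, then apply the (decoupled weight decay) update
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        optimizer.step()
        optimizer.zero_grad()
```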

Encoder Freezing

A key strategy for parameter-efficient fine-tuning is freezing the encoder while training only the decoder:

  • The encoder has already learned robust acoustic representations from 680K hours of diverse audio data. These representations generalize well across languages and domains.
  • The decoder contains the language model component, which benefits most from adaptation to a specific language or domain.
  • Freezing the encoder roughly halves the number of trainable parameters, lowering GPU memory requirements and training time.
  • This approach is particularly effective for language adaptation, where the acoustic features are similar but the target language differs.
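
In PyTorch terms, freezing reduces to setting `requires_grad = False` on the encoder's parameters. The module below is a toy stand-in with the same encoder/decoder split as Whisper (layer sizes are illustrative only):

```python
import torch

class ToyWhisper(torch.nn.Module):
    """Stand-in with an encoder/decoder split like Whisper's; sizes are illustrative."""
    def __init__(self):
        super().__init__()
        self.encoder = torch.nn.Linear(16, 16)
        self.decoder = torch.nn.Linear(16, 32)

model = ToyWhisper()

# Freeze the encoder: its parameters are excluded from gradient updates
for p in model.encoder.parameters():
    p.requires_grad = False

trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
```

Only the decoder's parameters now receive gradients; the optimizer can be built over `filter(lambda p: p.requires_grad, model.parameters())` to skip the frozen weights entirely.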

Loss Function

The training uses Negative Log-Likelihood (NLL) loss computed over the decoder's output distribution:

  1. The decoder produces logits over the vocabulary for each position.
  2. Log-softmax is applied to obtain log probabilities.
  3. NLL loss is computed between the log probabilities and the target token sequence (tokens_eos).
  4. The loss is masked to ignore padding tokens using the provided sequence lengths.

Optional label smoothing can be applied to prevent the model from becoming overconfident, which is especially useful when fine-tuning on small datasets.
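
The four steps above can be sketched with tensors directly. Shapes and the relative-length masking convention are illustrative assumptions (SpeechBrain passes lengths as fractions of the padded length), not the recipe's exact code:

```python
import torch

batch, T, vocab = 2, 5, 10                      # illustrative sizes
logits = torch.randn(batch, T, vocab)           # 1. decoder logits per position
log_probs = torch.log_softmax(logits, dim=-1)   # 2. log probabilities
targets = torch.randint(0, vocab, (batch, T))   # 3. stand-in for tokens_eos

# 4. Build a padding mask from relative sequence lengths (SpeechBrain-style)
rel_lens = torch.tensor([1.0, 0.6])
abs_lens = (rel_lens * T).round().long()
mask = torch.arange(T)[None, :] < abs_lens[:, None]

# Per-token NLL, then average over non-padding tokens only
token_nll = torch.nn.functional.nll_loss(
    log_probs.reshape(-1, vocab), targets.reshape(-1), reduction="none"
).view(batch, T)
loss = (token_nll * mask).sum() / mask.sum()
```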

Forward Pass

The compute_forward method implements the full forward pass:

  1. Audio augmentation (optional): Time-domain augmentations (speed perturbation, frequency dropping, chunk dropping) can be applied during training.
  2. Padding mask computation: Decoder input sequences are padded with the pad_token_id expected by Whisper's decoder.
  3. Encoder-decoder forward: Waveforms and BOS tokens are passed through the Whisper model to obtain logits.
  4. Log-softmax: Logits are converted to log probabilities.
  5. Decoding (validation/test only): Greedy search or beam search is applied to generate hypotheses.
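
Steps 3-5 can be sketched with a toy encoder-decoder standing in for Whisper (augmentation and padding are assumed done upstream; the model, shapes, and greedy decoding here are illustrative, not the recipe's actual compute_forward):

```python
import torch

class ToySeq2Seq(torch.nn.Module):
    """Stand-in for Whisper's encoder-decoder; dimensions are illustrative."""
    def __init__(self, feat_dim=8, vocab_size=16):
        super().__init__()
        self.encoder = torch.nn.Linear(feat_dim, feat_dim)
        self.embed = torch.nn.Embedding(vocab_size, feat_dim)
        self.out = torch.nn.Linear(feat_dim, vocab_size)

    def forward(self, wavs, tokens):
        ctx = self.encoder(wavs).mean(dim=1, keepdim=True)  # pooled "acoustic" context
        dec = self.embed(tokens) + ctx                      # condition decoder on audio
        return self.out(dec)                                # logits over vocabulary

def compute_forward_sketch(model, wavs, bos_tokens, stage="train"):
    logits = model(wavs, bos_tokens)                # 3. encoder-decoder forward
    log_probs = torch.log_softmax(logits, dim=-1)   # 4. log-softmax
    hyps = None
    if stage != "train":                            # 5. decode only at valid/test
        hyps = log_probs.argmax(dim=-1)             # greedy stand-in for beam search
    return log_probs, hyps

model = ToySeq2Seq()
wavs = torch.randn(2, 20, 8)                # (batch, time, features)
bos_tokens = torch.randint(0, 16, (2, 5))   # BOS-prefixed decoder inputs
log_probs, hyps = compute_forward_sketch(model, wavs, bos_tokens, stage="valid")
```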

Checkpointing

The training recipe saves checkpoints based on the best validation WER:

checkpointer.save_and_keep_only(
    meta={"WER": stage_stats["WER"]},
    min_keys=["WER"],
)

This ensures that only the model with the lowest validation Word Error Rate is kept, preventing storage waste and enabling reliable model selection.
