Principle:Speechbrain Speechbrain Whisper Finetuning With LR Scheduling
| Field | Value |
|---|---|
| Concept | Fine-tuning pretrained encoder-decoder models with learning rate warmup and decay schedules |
| Domains | Transfer_Learning, ASR, Optimization |
| Knowledge Sources | Vaswani et al. 2017 "Attention is All You Need" (Noam scheduler); Radford et al. 2023 "Robust Speech Recognition via Large-Scale Weak Supervision" |
| Related Implementation | Implementation:Speechbrain_Speechbrain_Whisper_ASR_Compute_Forward |
Overview
Fine-tuning large pretrained models such as Whisper requires careful learning rate management to balance preserving pretrained knowledge with adapting to the target domain. The Noam learning rate scheduler, combined with AdamW optimization and encoder freezing, provides an effective training regime for Whisper fine-tuning within SpeechBrain.
Learning Rate Scheduling
The Noam scheduler (from "Attention is All You Need") provides a warmup phase followed by inverse square root decay:
LR(step) = d_model^(-0.5) * min(step^(-0.5), step * warmup_steps^(-1.5))
This schedule has two phases:
- Warmup phase (step < warmup_steps): The learning rate increases linearly from near-zero to the peak value. This prevents large, destructive gradient updates in the early steps when the model has not yet adapted to the new data distribution.
- Decay phase (step >= warmup_steps): The learning rate decays proportionally to the inverse square root of the step number, gradually reducing the update magnitude as training converges.
For Whisper fine-tuning, typical hyperparameters are:
- lr_initial: 1e-5 (a conservative starting point for fine-tuning)
- warmup_steps: 500 (gives the optimizer statistics time to stabilize on the target data before the peak learning rate is reached)
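The schedule above can be written as a small function. This is a sketch of the formula as given, not SpeechBrain's scheduler class; `d_model=512` is illustrative (the real value depends on the Whisper variant), and frameworks typically rescale the curve so its peak matches a configured `lr_initial`.

```python
# Hedged sketch of the Noam schedule; d_model and warmup_steps are
# illustrative defaults, not values read from any recipe file.
def noam_lr(step, d_model=512, warmup_steps=500):
    """LR(step) = d_model^-0.5 * min(step^-0.5, step * warmup_steps^-1.5).

    Linear warmup while step < warmup_steps, then inverse-sqrt decay.
    """
    return d_model ** -0.5 * min(step ** -0.5, step * warmup_steps ** -1.5)
```

The two branches of the `min` cross exactly at `step == warmup_steps`, which is where the peak learning rate occurs.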
Optimization Strategy
The fine-tuning recipe uses AdamW (Adam with decoupled weight decay):
- Weight decay: 0.01 (provides L2 regularization to prevent overfitting on small target datasets)
- Gradient clipping: max_grad_norm=2.0 (prevents gradient explosion during fine-tuning)
- Gradient accumulation: grad_accumulation_factor=2 (effective batch size = batch_size * 2)
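The three settings above combine into a standard PyTorch training loop. The sketch below uses a stand-in linear model and random data (both hypothetical, not the recipe's); the loss is divided by the accumulation factor so the accumulated gradient matches what a single larger batch would produce.

```python
import torch

# Stand-in model and data; only the optimizer settings mirror the text.
model = torch.nn.Linear(8, 4)
init_weight = model.weight.detach().clone()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5, weight_decay=0.01)
grad_accumulation_factor = 2
max_grad_norm = 2.0

batches = [(torch.randn(16, 8), torch.randn(16, 4)) for _ in range(4)]
for step, (x, y) in enumerate(batches):
    # Scale so gradients summed over the accumulation window average out.
    loss = torch.nn.functional.mse_loss(model(x), y) / grad_accumulation_factor
    loss.backward()
    if (step + 1) % grad_accumulation_factor == 0:
        # Clip the accumulated gradient, then update (effective batch = 2 * 16).
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
        optimizer.step()
        optimizer.zero_grad()
```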
Encoder Freezing
A key strategy for parameter-efficient fine-tuning is freezing the encoder while training only the decoder:
- The encoder has already learned robust acoustic representations from 680K hours of diverse audio data. These representations generalize well across languages and domains.
- The decoder contains the language model component, which benefits most from adaptation to a specific language or domain.
- Freezing the encoder reduces trainable parameters significantly (roughly by half), reducing GPU memory requirements and training time.
- This approach is particularly effective for language adaptation, where the acoustic features are similar but the target language differs.
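In PyTorch terms, freezing means setting `requires_grad = False` on the encoder's parameters so they receive no gradients. The sketch below uses a toy encoder-decoder as a stand-in for Whisper; the layer sizes are illustrative only.

```python
import torch.nn as nn

# Toy encoder-decoder standing in for Whisper (sizes are illustrative).
class TinySeq2Seq(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = nn.Sequential(nn.Linear(80, 256), nn.Linear(256, 256))
        self.decoder = nn.Sequential(nn.Linear(256, 256), nn.Linear(256, 100))

model = TinySeq2Seq()

# Freeze the encoder: its parameters receive no gradients, so the
# optimizer never updates them even if they are passed to it.
for p in model.encoder.parameters():
    p.requires_grad = False

trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
```

Counting `trainable` versus `total` this way is also a quick sanity check that the freeze was applied before training starts.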
Loss Function
The training uses Negative Log-Likelihood (NLL) loss computed over the decoder's output distribution:
- The decoder produces logits over the vocabulary for each position.
- Log-softmax is applied to obtain log probabilities.
- NLL loss is computed between the log probabilities and the target token sequence (tokens_eos).
- The loss is masked to ignore padding tokens using the provided sequence lengths.
Optional label smoothing can be applied to prevent the model from becoming overconfident, which is especially useful when fine-tuning on small datasets.
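The masked NLL computation can be sketched as follows. This is a generic implementation, not SpeechBrain's `nll_loss` (which takes relative lengths); absolute lengths are used here to keep the sketch short, and all tensor shapes are illustrative.

```python
import torch
import torch.nn.functional as F

def masked_nll(log_probs, targets, lengths):
    """NLL over (batch, time, vocab) log-probs, ignoring padded positions.

    `lengths` holds absolute target lengths; positions at or beyond a
    sequence's length contribute nothing to the loss.
    """
    B, T, V = log_probs.shape
    nll = F.nll_loss(log_probs.reshape(-1, V), targets.reshape(-1),
                     reduction="none").reshape(B, T)
    mask = torch.arange(T).unsqueeze(0) < lengths.unsqueeze(1)  # (B, T)
    return (nll * mask).sum() / mask.sum()

logits = torch.randn(2, 5, 10)             # decoder logits (batch, time, vocab)
log_probs = F.log_softmax(logits, dim=-1)  # step 2: log-softmax
targets = torch.randint(0, 10, (2, 5))     # stand-in for tokens_eos
lengths = torch.tensor([5, 3])             # second sequence has 2 pad steps
loss = masked_nll(log_probs, targets, lengths)
```

Because the mask zeroes out padded positions, changing a target token inside the padding region leaves the loss unchanged.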
Forward Pass
The compute_forward method implements the full forward pass:
- Audio augmentation (optional): Time-domain augmentations (speed perturbation, frequency dropping, chunk dropping) can be applied during training.
- Padding mask computation: Decoder input sequences are padded with the pad_token_id expected by Whisper's decoder.
- Encoder-decoder forward: Waveforms and BOS tokens are passed through the Whisper model to obtain logits.
- Log-softmax: Logits are converted to log probabilities.
- Decoding (validation/test only): Greedy search or beam search is applied to generate hypotheses.
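The greedy variant of the decoding step can be sketched as an autoregressive loop: at each step the most probable token is appended and fed back in until EOS. This is a generic sketch, not SpeechBrain's searcher class; `decoder_step` is a hypothetical callable returning next-token scores for a given prefix.

```python
def greedy_search(decoder_step, bos_id, eos_id, max_len=100):
    """Autoregressive greedy decoding.

    `decoder_step(tokens)` is assumed to return a list of vocabulary
    scores for the next token given the prefix `tokens`.
    """
    tokens = [bos_id]
    for _ in range(max_len):
        scores = decoder_step(tokens)
        # Pick the highest-scoring token and feed it back in.
        next_tok = max(range(len(scores)), key=scores.__getitem__)
        if next_tok == eos_id:
            break
        tokens.append(next_tok)
    return tokens[1:]  # drop BOS; hypothesis tokens only
```

Beam search generalizes this loop by keeping the top-k prefixes at each step instead of a single argmax.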
Checkpointing
The training recipe saves checkpoints based on the best validation WER:
```python
checkpointer.save_and_keep_only(
    meta={"WER": stage_stats["WER"]},
    min_keys=["WER"],
)
```
This ensures that only the model with the lowest validation Word Error Rate is kept, preventing storage waste and enabling reliable model selection.