Implementation:Speechbrain_Speechbrain_Whisper_HFTransformersInterface
| Field | Value |
|---|---|
| API | Whisper(source, save_path, sampling_rate=16000, encoder_only=False, freeze=False, freeze_encoder=False, output_attentions=False, output_all_hiddens=False, language=None, task="transcribe") |
| Source | speechbrain/lobes/models/huggingface_transformers/whisper.py:L33-636 (class), L90-166 (__init__) |
| Import | from speechbrain.lobes.models.huggingface_transformers.whisper import Whisper |
| Type | Wrapper Doc |
| Inputs | HuggingFace model identifier (e.g., "openai/whisper-medium"), optional local save path |
| Outputs | Initialized Whisper model with encoder, decoder, and tokenizer. Key methods: forward(wav, decoder_input_ids), forward_encoder(mel), forward_decoder(enc_states, dec_ids) |
| Related Principle | Principle:Speechbrain_Speechbrain_Whisper_Model_Loading |
Purpose
Wraps OpenAI's Whisper model from HuggingFace Transformers for use in SpeechBrain's training and inference pipelines. Handles model downloading, mel spectrogram computation, encoder/decoder forwarding, tokenizer configuration, and freezing strategies.
Constructor
from speechbrain.lobes.models.huggingface_transformers.whisper import Whisper
whisper_model = Whisper(
source="openai/whisper-medium",
save_path="results/whisper_fr/save/whisper_checkpoint",
sampling_rate=16000,
encoder_only=False,
freeze=False,
freeze_encoder=True,
output_attentions=False,
output_all_hiddens=False,
language="fr",
task="transcribe",
)
Parameters
| Parameter | Type | Default | Description |
|---|---|---|---|
| source | str | required | HuggingFace hub identifier (e.g., "openai/whisper-tiny", "openai/whisper-medium") |
| save_path | str | required | Local directory where the downloaded model weights are cached |
| sampling_rate | int | 16000 | Expected audio sampling rate in Hz |
| encoder_only | bool | False | If True, deletes the decoder and returns only encoder states. Reduces memory usage for feature extraction tasks |
| freeze | bool | False | If True, freezes the entire model (encoder + decoder). Parameters do not receive gradients |
| freeze_encoder | bool | False | If True, freezes only the encoder while keeping the decoder trainable. A common fine-tuning setting: it preserves the pretrained acoustic representations while adapting the decoder |
| output_attentions | bool | False | If True, returns decoder attention weights. Disables flash attention, increasing memory usage |
| output_all_hiddens | bool | False | If True, returns hidden states from all encoder layers (stacked). E.g., whisper-base returns shape (7, B, T, C) |
| language | str | None | Target language code (e.g., "fr", "de", "en"). Sets the language prefix token for multilingual models |
| task | str | "transcribe" | Either "transcribe" (output in source language) or "translate" (output in English) |
Key Methods
forward(wav, decoder_input_ids)
Full encoder-decoder forward pass.
import torch
# Example forward pass
wavs = torch.randn(2, 48000) # batch of 2, 3 seconds at 16kHz
bos_tokens = torch.tensor([[50258, 50259, 50360, 50364],
[50258, 50259, 50360, 50364]]) # prefix tokens
enc_out, logits, attn = whisper_model(wavs, bos_tokens)
# enc_out: (2, 1500, 1024) - encoder hidden states
# logits: (2, 4, 51865) - decoder output logits
# attn: None (unless output_attentions=True)
Returns: Tuple of (encoder_output, decoder_logits, decoder_attention).
forward_encoder(mel)
Encoder-only forward pass on mel spectrogram features.
mel = whisper_model._get_mel(wavs) # compute mel internally
enc_states = whisper_model.forward_encoder(mel)
# enc_states: (batch, time, hidden_dim)
Returns: Last hidden state of the encoder (or all hidden states if output_all_hiddens=True).
forward_decoder(encoder_states, decoder_input_ids, use_cache=True, past_key_values=None)
Single decoder step with optional KV caching.
logits, attn, past_kv = whisper_model.forward_decoder(
enc_states, bos_tokens, use_cache=True, past_key_values=None
)
# logits: (batch, seq_len, vocab_size)
# attn: None or attention tensor
# past_kv: cached key-value pairs for subsequent steps
Returns: Tuple of (logits, attention, past_key_values).
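The KV-cached decoder step lends itself to an autoregressive greedy loop. The sketch below is illustrative and not part of the SpeechBrain API: greedy_decode and its step_fn argument are hypothetical names, where step_fn stands in for a call such as lambda ids, past: whisper_model.forward_decoder(enc_states, ids, use_cache=True, past_key_values=past).

```python
import torch

def greedy_decode(step_fn, prefix_ids, eos_id, max_new_tokens=50):
    """Greedy autoregressive decoding with KV caching.

    step_fn(ids, past_kv) -> (logits, attn, past_kv), mirroring the
    (logits, attention, past_key_values) tuple of forward_decoder.
    """
    tokens = prefix_ids          # (batch, prefix_len)
    past_kv = None
    step_input = tokens          # first step consumes the whole prefix
    for _ in range(max_new_tokens):
        logits, _, past_kv = step_fn(step_input, past_kv)
        # Pick the most likely token at the last position.
        next_tok = logits[:, -1, :].argmax(dim=-1, keepdim=True)  # (batch, 1)
        tokens = torch.cat([tokens, next_tok], dim=1)
        if (next_tok == eos_id).all():
            break
        step_input = next_tok    # with caching, only the newest token is fed
    return tokens
```

With caching enabled, each step after the first feeds only the newly generated token; the decoder reuses past_key_values for the earlier positions.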
Tokenizer Access
The Whisper tokenizer is accessible via whisper_model.tokenizer:
tokenizer = whisper_model.tokenizer
# Encode text
tokens = tokenizer.encode("Bonjour le monde", add_special_tokens=False)
# [35309, 531, 32828]
# Wrap with special tokens (BOS, language, task, notimestamps, ..., EOS)
full_tokens = tokenizer.build_inputs_with_special_tokens(tokens)
# [50258, 50265, 50360, 50364, 35309, 531, 32828, 50257]
# Decode back to text
text = tokenizer.decode(full_tokens, skip_special_tokens=True)
# "Bonjour le monde"
# Normalize for evaluation
normalized = tokenizer.normalize("Bonjour, le monde!")
# "bonjour le monde"
Special Token Properties
The wrapper exposes cached properties for commonly used special tokens:
whisper_model.bos # <|startoftranscript|> token ID
whisper_model.eos # <|endoftext|> token ID
whisper_model.transcribe # <|transcribe|> token ID
whisper_model.translate # <|translate|> token ID
whisper_model.no_timestamps # <|notimestamps|> token ID
whisper_model.no_speech # no_speech token ID
whisper_model.language_token # language token ID for configured language
whisper_model.is_multilingual # True if vocab_size >= 51865
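These properties can be used to assemble the decoder prefix instead of hardcoding token IDs. The helper below is a hypothetical convenience (build_decoder_prefix is not a SpeechBrain function), shown with placeholder IDs; with a real model you would pass whisper_model.bos, whisper_model.language_token, and so on.

```python
import torch

def build_decoder_prefix(bos, language_token, task_token, no_timestamps, batch_size=1):
    """Assemble the standard Whisper decoder prefix
    <|startoftranscript|><|lang|><|task|><|notimestamps|> for a batch."""
    prefix = [bos, language_token, task_token, no_timestamps]
    return torch.tensor([prefix] * batch_size)

# With a real model this would be:
# prefix = build_decoder_prefix(whisper_model.bos,
#                               whisper_model.language_token,
#                               whisper_model.transcribe,
#                               whisper_model.no_timestamps,
#                               batch_size=2)
```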
YAML Configuration
whisper_hub: openai/whisper-medium
freeze_whisper: False
freeze_encoder: True
language: fr
whisper: !new:speechbrain.lobes.models.huggingface_transformers.whisper.Whisper
source: !ref <whisper_hub>
freeze: !ref <freeze_whisper>
freeze_encoder: !ref <freeze_encoder>
save_path: !ref <save_folder>/whisper_checkpoint
language: !ref <language>
task: "transcribe"
modules:
whisper: !ref <whisper>
Mel Spectrogram Computation
The _get_mel method handles internal mel spectrogram computation:
- Pad or trim audio to exactly 480,000 samples (30 seconds at 16kHz).
- STFT with n_fft=400, hop_length=160 using a Hann window.
- Mel projection using 80 mel filter banks loaded from the HuggingFace feature extractor.
- Log10 scaling with magnitudes clamped at 1e-10, a floor 8.0 below the spectrogram maximum (dynamic-range compression), and offset normalization ((log_spec + 4.0) / 4.0).
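The log-scaling and normalization steps above can be sketched in plain PyTorch. This is a reconstruction of the described normalization logic, not the SpeechBrain code itself; the STFT and mel-projection steps are assumed to have already produced a mel power spectrogram mel_power of shape (batch, n_mels, frames).

```python
import torch

def normalize_log_mel(mel_power):
    """Whisper-style log scaling and normalization of a mel power spectrogram."""
    # Clamp to avoid log of zero, then take log10.
    log_spec = torch.clamp(mel_power, min=1e-10).log10()
    # Floor at 8.0 below the tensor maximum (dynamic-range compression).
    log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
    # Shift and scale into roughly [-1, 1].
    return (log_spec + 4.0) / 4.0
```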
See Also
- Principle:Speechbrain_Speechbrain_Whisper_Model_Loading
- Implementation:Speechbrain_Speechbrain_Whisper_ASR_Compute_Forward
- Implementation:Speechbrain_Speechbrain_S2SWhisperBeamSearcher