
Implementation:Speechbrain Speechbrain Whisper HFTransformersInterface

From Leeroopedia


API: Whisper(source, save_path, sampling_rate=16000, encoder_only=False, freeze=False, freeze_encoder=False, output_attentions=False, output_all_hiddens=False, language=None, task="transcribe")
Source: speechbrain/lobes/models/huggingface_transformers/whisper.py:L33-636 (class), L90-166 (__init__)
Import: from speechbrain.lobes.models.huggingface_transformers.whisper import Whisper
Type: Wrapper Doc
Inputs: HuggingFace model identifier (e.g., "openai/whisper-medium"), optional local save path
Outputs: Initialized Whisper model with encoder, decoder, and tokenizer. Key methods: forward(wav, decoder_input_ids), forward_encoder(mel), forward_decoder(enc_states, dec_ids)
Related Principle: Principle:Speechbrain_Speechbrain_Whisper_Model_Loading

Purpose

Wraps OpenAI's Whisper model from HuggingFace Transformers for use in SpeechBrain's training and inference pipelines. Handles model downloading, mel spectrogram computation, encoder/decoder forwarding, tokenizer configuration, and freezing strategies.

Constructor

from speechbrain.lobes.models.huggingface_transformers.whisper import Whisper

whisper_model = Whisper(
    source="openai/whisper-medium",
    save_path="results/whisper_fr/save/whisper_checkpoint",
    sampling_rate=16000,
    encoder_only=False,
    freeze=False,
    freeze_encoder=True,
    output_attentions=False,
    output_all_hiddens=False,
    language="fr",
    task="transcribe",
)

Parameters

source (str, required): HuggingFace hub identifier (e.g., "openai/whisper-tiny", "openai/whisper-medium")
save_path (str, required): Local directory where the downloaded model weights are cached
sampling_rate (int, default 16000): Expected audio sampling rate in Hz
encoder_only (bool, default False): If True, deletes the decoder and returns only encoder states. Reduces memory usage for feature-extraction tasks
freeze (bool, default False): If True, freezes the entire model (encoder + decoder); parameters receive no gradients
freeze_encoder (bool, default False): If True, freezes only the encoder while keeping the decoder trainable. Recommended for fine-tuning
output_attentions (bool, default False): If True, returns decoder attention weights. Disables flash attention, increasing memory usage
output_all_hiddens (bool, default False): If True, returns stacked hidden states from all encoder layers. E.g., whisper-base returns shape (7, B, T, C)
language (str, default None): Target language code (e.g., "fr", "de", "en"). Sets the language prefix token for multilingual models
task (str, default "transcribe"): Either "transcribe" (output in the source language) or "translate" (output in English)
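The interplay of freeze and freeze_encoder can be sketched with a stand-in module. This is a hypothetical illustration of the requires_grad pattern, not the wrapper's actual code; the real class applies the same idea to the HuggingFace Whisper encoder and decoder:

```python
import torch.nn as nn

# Stand-in encoder/decoder pair (hypothetical) to illustrate
# the two freezing strategies described above.
model = nn.ModuleDict({
    "encoder": nn.Linear(4, 4),
    "decoder": nn.Linear(4, 4),
})

def apply_freeze(model, freeze=False, freeze_encoder=False):
    # freeze=True disables gradients everywhere; freeze_encoder=True
    # stops gradients only through the encoder.
    if freeze:
        for p in model.parameters():
            p.requires_grad = False
    elif freeze_encoder:
        for p in model["encoder"].parameters():
            p.requires_grad = False

apply_freeze(model, freeze_encoder=True)
print(all(not p.requires_grad for p in model["encoder"].parameters()))  # True
print(all(p.requires_grad for p in model["decoder"].parameters()))      # True
```

With freeze_encoder=True, only the decoder's parameters accumulate gradients during fine-tuning, which matches the recommended setup above.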

Key Methods

forward(wav, decoder_input_ids)

Full encoder-decoder forward pass.

import torch

# Example forward pass
wavs = torch.randn(2, 48000)  # batch of 2, 3 seconds at 16kHz
bos_tokens = torch.tensor([[50258, 50259, 50360, 50364],
                            [50258, 50259, 50360, 50364]])  # prefix tokens

enc_out, logits, attn = whisper_model(wavs, bos_tokens)
# enc_out: (2, 1500, 1024) - encoder hidden states
# logits: (2, 4, 51865) - decoder output logits
# attn: None (unless output_attentions=True)

Returns: Tuple of (encoder_output, decoder_logits, decoder_attention).

forward_encoder(mel)

Encoder-only forward pass on mel spectrogram features.

mel = whisper_model._get_mel(wavs)  # compute mel internally
enc_states = whisper_model.forward_encoder(mel)
# enc_states: (batch, time, hidden_dim)

Returns: Last hidden state of the encoder (or all hidden states if output_all_hiddens=True).

forward_decoder(encoder_states, decoder_input_ids, use_cache=True, past_key_values=None)

Single decoder step with optional KV caching.

logits, attn, past_kv = whisper_model.forward_decoder(
    enc_states, bos_tokens, use_cache=True, past_key_values=None
)
# logits: (batch, seq_len, vocab_size)
# attn: None or attention tensor
# past_kv: cached key-value pairs for subsequent steps

Returns: Tuple of (logits, attention, past_key_values).
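A typical use of past_key_values is a greedy decoding loop that feeds only the newest token after the first step. The sketch below uses a dummy decode_step as a stand-in for whisper_model.forward_decoder (which would additionally take encoder states), so it stays runnable without downloading a model; the caching pattern is the point:

```python
import torch

# Dummy decoder standing in for whisper_model.forward_decoder:
# logits favour token (last_token + 1) % vocab, and the "cache"
# is just a step counter to mimic past_key_values growing.
def decode_step(tokens, past_kv=None):
    vocab = 8
    last = tokens[:, -1]
    logits = torch.zeros(tokens.size(0), 1, vocab)
    logits[torch.arange(tokens.size(0)), 0, (last + 1) % vocab] = 1.0
    new_kv = (past_kv or 0) + 1
    return logits, None, new_kv

tokens = torch.tensor([[0]])
past_kv = None
for _ in range(3):
    # With KV caching, only the newest token is fed after the first step.
    step_in = tokens if past_kv is None else tokens[:, -1:]
    logits, _, past_kv = decode_step(step_in, past_kv)
    next_tok = logits[:, -1].argmax(dim=-1, keepdim=True)
    tokens = torch.cat([tokens, next_tok], dim=1)

print(tokens.tolist())  # [[0, 1, 2, 3]]
```

In real usage, the loop would start from the BOS prefix tokens, pass enc_states on every call, and stop when whisper_model.eos is produced.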

Tokenizer Access

The Whisper tokenizer is accessible via whisper_model.tokenizer:

tokenizer = whisper_model.tokenizer

# Encode text
tokens = tokenizer.encode("Bonjour le monde", add_special_tokens=False)
# [35309, 531, 32828]

# Wrap with special tokens (BOS, language, task, notimestamps, ..., EOS)
full_tokens = tokenizer.build_inputs_with_special_tokens(tokens)
# [50258, 50265, 50360, 50364, 35309, 531, 32828, 50257]

# Decode back to text
text = tokenizer.decode(full_tokens, skip_special_tokens=True)
# "Bonjour le monde"

# Normalize for evaluation
normalized = tokenizer.normalize("Bonjour, le monde!")
# "bonjour le monde"

Special Token Properties

The wrapper exposes cached properties for commonly used special tokens:

whisper_model.bos           # <|startoftranscript|> token ID
whisper_model.eos           # <|endoftext|> token ID
whisper_model.transcribe    # <|transcribe|> token ID
whisper_model.translate     # <|translate|> token ID
whisper_model.no_timestamps # <|notimestamps|> token ID
whisper_model.no_speech     # no_speech token ID
whisper_model.language_token  # language token ID for configured language
whisper_model.is_multilingual # True if vocab_size >= 51865
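These properties are typically combined to build the decoder prompt. The sketch below hard-codes the multilingual-vocabulary IDs quoted earlier in this page (<|startoftranscript|>=50258, French=50265, <|transcribe|>=50360, <|notimestamps|>=50364) so it runs without a model; real code would read whisper_model.bos, whisper_model.language_token, etc.:

```python
import torch

# Token IDs taken from the tokenizer example above (assumed values);
# in practice, read them from the whisper_model properties.
bos, lang, transcribe, no_ts = 50258, 50265, 50360, 50364

batch_size = 2
# One prompt row per utterance in the batch.
prefix = torch.tensor([[bos, lang, transcribe, no_ts]]).repeat(batch_size, 1)
print(prefix.shape)  # torch.Size([2, 4])
```

The resulting tensor has the same shape and content as the bos_tokens used in the forward() example above (there with the English language token).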

YAML Configuration

whisper_hub: openai/whisper-medium
freeze_whisper: False
freeze_encoder: True
language: fr

whisper: !new:speechbrain.lobes.models.huggingface_transformers.whisper.Whisper
    source: !ref <whisper_hub>
    freeze: !ref <freeze_whisper>
    freeze_encoder: !ref <freeze_encoder>
    save_path: !ref <save_folder>/whisper_checkpoint
    language: !ref <language>
    task: "transcribe"

modules:
    whisper: !ref <whisper>

Mel Spectrogram Computation

The _get_mel method handles internal mel spectrogram computation:

  1. Pad or trim audio to exactly 480,000 samples (30 seconds at 16kHz).
  2. STFT with n_fft=400, hop_length=160 using a Hann window.
  3. Mel projection using 80 mel filter banks loaded from the HuggingFace feature extractor.
  4. Log scaling with clamping at 1e-10, max normalization (-8.0 dB floor), and offset normalization ((log + 4) / 4).
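The four steps above can be sketched in plain PyTorch. This is an illustrative approximation, not the wrapper's _get_mel: the real method loads its 80 mel filter banks from the HuggingFace feature extractor, while random filters here keep the example self-contained:

```python
import torch

def get_mel_sketch(wav, n_mels=80, n_fft=400, hop=160, n_samples=480_000):
    # 1. Pad or trim to exactly 480,000 samples (30 s at 16 kHz).
    if wav.size(-1) < n_samples:
        wav = torch.nn.functional.pad(wav, (0, n_samples - wav.size(-1)))
    wav = wav[..., :n_samples]
    # 2. STFT with a Hann window; drop the final frame.
    stft = torch.stft(wav, n_fft, hop, window=torch.hann_window(n_fft),
                      return_complex=True)
    power = stft[..., :-1].abs() ** 2
    # 3. Mel projection (random stand-in for the HF filter banks).
    filters = torch.rand(n_mels, n_fft // 2 + 1)
    mel = filters @ power
    # 4. Log scale with 1e-10 clamp, -8.0 floor below the max, offset norm.
    log_spec = torch.log10(mel.clamp(min=1e-10))
    log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
    return (log_spec + 4.0) / 4.0

mel = get_mel_sketch(torch.randn(1, 48_000))
print(mel.shape)  # torch.Size([1, 80, 3000])
```

The fixed 3000-frame output explains the (2, 1500, 1024) encoder shape in the forward() example: Whisper's encoder downsamples the 3000 mel frames by a factor of two.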
