
Implementation:Speechbrain Speechbrain S2SWhisperBeamSearcher

From Leeroopedia


Field Value
API S2SWhisperGreedySearcher(model, temperature=0.0, use_kv_cache=True, suppress_blank=True, suppress_tokens="-1", sample_len=None, prefix=None, prompt=None, **kwargs) and S2SWhisperBeamSearcher(module, temperature=1.0, use_kv_cache=True, suppress_blank=True, suppress_tokens="-1", sample_len=None, prefix=None, prompt=None, **kwargs). Both: forward(enc_states, wav_len) -> (hyps, lengths, scores, log_probs)
Source speechbrain/decoders/seq2seq.py:L352-565 (S2SWhisperGreedySearcher), L1855-2118 (S2SWhisperBeamSearcher)
Import from speechbrain.decoders.seq2seq import S2SWhisperGreedySearcher, S2SWhisperBeamSearcher
Type API Doc
Inputs enc_states (encoder output tensor from Whisper encoder), wav_len (relative length tensor)
Outputs hyps (List[List[int]] - decoded token IDs), lengths, scores, log_probs
Related Principle Principle:Speechbrain_Speechbrain_Beam_Search_Decoding

Purpose

Provides Whisper-specific greedy and beam search decoders that handle Whisper's unique token structure (language, task, timestamp tokens), KV caching for efficient autoregressive decoding, and token suppression for clean ASR outputs.

S2SWhisperGreedySearcher

Used during validation for fast hypothesis generation.

Constructor

from speechbrain.decoders.seq2seq import S2SWhisperGreedySearcher

greedy_searcher = S2SWhisperGreedySearcher(
    model=whisper_model,
    temperature=0.0,
    use_kv_cache=True,
    suppress_blank=True,
    suppress_tokens="-1",
    min_decode_ratio=0.0,
    max_decode_ratio=1.0,
)

Parameters

Parameter Type Default Description
model Whisper required The SpeechBrain Whisper model instance
temperature float 0.0 Decoding temperature. 0.0 = argmax (deterministic greedy)
use_kv_cache bool True Enable key-value caching for efficient autoregressive decoding
suppress_blank bool True Suppress blank tokens and EOS at the first decoding step
suppress_tokens str or list "-1" Token IDs to suppress. "-1" suppresses non-speech tokens defined in model.non_speech_tokens
sample_len int None Maximum number of tokens to sample. Defaults to max_attn_tokens // 2
prefix str or list None Optional prefix to prepend to the decoder input
prompt str or list None Optional prompt for conditional generation
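To make the temperature and token-suppression parameters concrete, here is a minimal pure-Python sketch of a single greedy step; the logits and token IDs are invented for illustration and this is not SpeechBrain's implementation.

```python
import math
import random

def greedy_step(logits, temperature=0.0, suppressed=()):
    """One decoding step (illustrative sketch): suppress the given token
    IDs, then take the argmax when temperature is 0.0 (deterministic
    greedy) or sample from the softmax otherwise."""
    scores = list(logits)
    for tok in suppressed:          # token suppression: force -inf
        scores[tok] = -math.inf
    if temperature == 0.0:          # argmax = deterministic greedy
        return max(range(len(scores)), key=lambda i: scores[i])
    probs = [math.exp(s / temperature) for s in scores]
    total = sum(probs)
    return random.choices(range(len(scores)), weights=[p / total for p in probs])[0]

logits = [1.2, 3.4, 0.5, 2.9]
print(greedy_step(logits))                  # argmax -> token 1
print(greedy_step(logits, suppressed=[1]))  # token 1 suppressed -> token 3
```

With temperature 0.0 (the greedy searcher's default) the output is fully deterministic, which is why it is preferred for fast validation-time decoding.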

S2SWhisperBeamSearcher

Used during testing for higher-quality hypothesis generation.

Constructor

from speechbrain.decoders.seq2seq import S2SWhisperBeamSearcher

beam_searcher = S2SWhisperBeamSearcher(
    module=[whisper_model],
    temperature=1.0,
    use_kv_cache=True,
    suppress_blank=True,
    suppress_tokens="-1",
    min_decode_ratio=0.0,
    max_decode_ratio=1.0,
    beam_size=8,
)

Note: The module parameter takes a list containing the Whisper model (not the model directly). This follows the S2SBeamSearcher interface convention where module[0] is accessed internally.

Additional Parameters

Parameter Type Default Description
beam_size int required (via kwargs) Number of beams to maintain during search. Typical value: 8
module list required List containing the Whisper model as its first element
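To make the beam_size parameter concrete, here is a toy, framework-free sketch of one beam-pruning step: each beam is expanded with every vocabulary token, cumulative log-probabilities are summed, and only the top beam_size candidates survive. All scores below are invented for illustration.

```python
import math

def beam_step(beams, step_log_probs, beam_size):
    """Expand each (tokens, score) beam with every vocabulary token,
    then keep the beam_size candidates with the highest cumulative
    log-probability."""
    candidates = []
    for (tokens, score), lp in zip(beams, step_log_probs):
        for tok, tok_lp in enumerate(lp):
            candidates.append((tokens + [tok], score + tok_lp))
    candidates.sort(key=lambda c: c[1], reverse=True)
    return candidates[:beam_size]

beams = [([0], math.log(0.6)), ([1], math.log(0.4))]
step_log_probs = [
    [math.log(0.7), math.log(0.3)],  # next-token log-probs for beam 0
    [math.log(0.2), math.log(0.8)],  # next-token log-probs for beam 1
]
best = beam_step(beams, step_log_probs, beam_size=2)
print([t for t, _ in best])  # -> [[0, 0], [1, 1]]
```

A larger beam_size explores more candidate transcripts per step at proportionally higher compute cost, which is why beam search is reserved for test-time decoding in the recipe below.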

Usage in Training Recipe

# In compute_forward:
enc_out, logits, _ = self.modules.whisper(wavs, bos_tokens)

# Validation: greedy search for speed
if stage == sb.Stage.VALID:
    hyps, _, _, _ = self.hparams.valid_search(
        enc_out.detach(), wav_lens
    )

# Testing: beam search for quality
elif stage == sb.Stage.TEST:
    hyps, _, _, _ = self.hparams.test_search(
        enc_out.detach(), wav_lens
    )

# Decode hypotheses to text
predicted_words = [
    tokenizer.decode(t, skip_special_tokens=True).strip()
    for t in hyps
]

YAML Configuration

min_decode_ratio: 0.0
max_decode_ratio: 1.0
test_beam_size: 8

valid_search: !new:speechbrain.decoders.seq2seq.S2SWhisperGreedySearcher
    model: !ref <whisper>
    min_decode_ratio: !ref <min_decode_ratio>
    max_decode_ratio: !ref <max_decode_ratio>

test_search: !new:speechbrain.decoders.seq2seq.S2SWhisperBeamSearcher
    module: [!ref <whisper>]
    min_decode_ratio: !ref <min_decode_ratio>
    max_decode_ratio: !ref <max_decode_ratio>
    beam_size: !ref <test_beam_size>

Key Internal Methods

reset_mem(batch_size, device)

Initializes the decoder memory with Whisper's initial token sequence:

# Initial tokens are derived from the tokenizer's prefix_tokens:
# [<|startoftranscript|>, <|language|>, <|task|>, <|notimestamps|>]
# The last token becomes the first decoder input.
# The remaining tokens are stored as memory.
memory_tokens = self.initial_tokens[:-1]
mem = torch.tensor([memory_tokens] * batch_size).to(device)

Also resets the KV cache to None at the start of each new utterance.
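The memory split above can be illustrated with a toy example; the four token IDs below merely stand in for the kind of IDs Whisper's multilingual tokenizer produces and are not guaranteed values.

```python
# Hypothetical IDs standing in for the prefix sequence
# [<|startoftranscript|>, <|language|>, <|task|>, <|notimestamps|>]
initial_tokens = [50258, 50259, 50359, 50363]

memory_tokens = initial_tokens[:-1]  # stored as decoder memory
first_input = initial_tokens[-1]     # first token fed to the decoder

print(memory_tokens)  # -> [50258, 50259, 50359]
print(first_input)    # -> 50363
```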

forward_step(inp_tokens, memory, enc_states, enc_lens)

Performs a single decoding step:

  1. Appends the new input token to the memory sequence.
  2. Runs model.forward_decoder with optional KV caching.
  3. Computes no-speech probabilities at the BOS position.
  4. Extracts logits for the last position.
  5. Updates the KV cache.
  6. Applies blank suppression and token suppression.
  7. Returns logits (greedy) or log probabilities (beam search).
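The steps above can be sketched with a dummy decoder to show the KV-cache pattern: on the first step the full prefix is processed, and afterwards only the newest token is fed alongside the cache. Everything here (DummyDecoder, its fake logits, the cache being a simple counter) is hypothetical, not SpeechBrain's code.

```python
class DummyDecoder:
    """Stands in for model.forward_decoder; returns fake logits and a
    'cache' that simply records how many tokens it has processed."""
    def forward(self, tokens, kv_cache=None):
        seen = (kv_cache or 0) + len(tokens)
        logits = [float(seen % 5 == v) for v in range(5)]  # fake logits
        return logits, seen  # (last-position logits, updated cache)

def forward_step(inp_token, memory, model, kv_cache):
    memory = memory + [inp_token]                 # 1. append new token
    if kv_cache is None:
        logits, kv_cache = model.forward(memory)  # first step: full prefix
    else:
        # cached steps: only the newest token is fed to the decoder
        logits, kv_cache = model.forward([inp_token], kv_cache)
    next_token = max(range(len(logits)), key=lambda i: logits[i])  # argmax
    return next_token, memory, kv_cache

model = DummyDecoder()
memory, cache, tok = [], None, 3
for _ in range(3):
    tok, memory, cache = forward_step(tok, memory, model, cache)
print(memory, cache)  # -> [3, 1, 2] 3
```

The key point is that with a KV cache each step does work proportional to one token rather than the whole prefix, which is what makes autoregressive Whisper decoding efficient.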

_check_end_condition

Checks whether the generated sequence has reached the maximum attention window length (max_attn_tokens - sample_begin).

permute_mem (Beam Search only)

Reorders the decoder memory and KV cache when beams are pruned. Uses _reorder_cache to select the correct past key-value pairs for surviving beams.
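A minimal pure-Python illustration of that reordering (the cache entries here are placeholder strings, not real key-value tensors): surviving-beam indices select rows from both the memory and the cache so the two stay aligned.

```python
def permute(memory, kv_cache, beam_indices):
    """Reorder per-beam state after pruning: row i of the new state is
    the row of the old state at beam_indices[i]."""
    new_memory = [memory[i] for i in beam_indices]
    new_cache = [kv_cache[i] for i in beam_indices]
    return new_memory, new_cache

memory = [[50363, 7], [50363, 8], [50363, 9]]  # per-beam token sequences
kv_cache = ["cache_a", "cache_b", "cache_c"]   # per-beam cached key-values
mem2, cache2 = permute(memory, kv_cache, beam_indices=[2, 2, 0])
print(mem2)    # -> [[50363, 9], [50363, 9], [50363, 7]]
print(cache2)  # -> ['cache_c', 'cache_c', 'cache_a']
```

Note that an index may repeat (beam 2 survives twice here), so pruning can duplicate a strong beam's state as well as drop weak ones.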
