Implementation:Speechbrain_Speechbrain_S2SWhisperBeamSearcher
| Field | Value |
|---|---|
| API | S2SWhisperGreedySearcher(model, temperature=0.0, use_kv_cache=True, suppress_blank=True, suppress_tokens="-1", sample_len=None, prefix=None, prompt=None, **kwargs) and S2SWhisperBeamSearcher(module, temperature=1.0, use_kv_cache=True, suppress_blank=True, suppress_tokens="-1", sample_len=None, prefix=None, prompt=None, **kwargs). Both: forward(enc_states, wav_len) -> (hyps, lengths, scores, log_probs) |
| Source | speechbrain/decoders/seq2seq.py:L352-565 (S2SWhisperGreedySearcher), L1855-2118 (S2SWhisperBeamSearcher) |
| Import | from speechbrain.decoders.seq2seq import S2SWhisperGreedySearcher, S2SWhisperBeamSearcher |
| Type | API Doc |
| Inputs | enc_states (encoder output tensor from Whisper encoder), wav_len (relative length tensor) |
| Outputs | hyps (List[List[int]] - decoded token IDs), lengths, scores, log_probs |
| Related Principle | Principle:Speechbrain_Speechbrain_Beam_Search_Decoding |
Purpose
Provides Whisper-specific greedy and beam search decoders that handle Whisper's unique token structure (language, task, timestamp tokens), KV caching for efficient autoregressive decoding, and token suppression for clean ASR outputs.
S2SWhisperGreedySearcher
Used during validation for fast hypothesis generation.
Constructor
```python
from speechbrain.decoders.seq2seq import S2SWhisperGreedySearcher

greedy_searcher = S2SWhisperGreedySearcher(
    model=whisper_model,
    temperature=0.0,
    use_kv_cache=True,
    suppress_blank=True,
    suppress_tokens="-1",
    min_decode_ratio=0.0,
    max_decode_ratio=1.0,
)
```
Parameters
| Parameter | Type | Default | Description |
|---|---|---|---|
| model | Whisper | required | The SpeechBrain Whisper model instance |
| temperature | float | 0.0 | Decoding temperature. 0.0 = argmax (deterministic greedy) |
| use_kv_cache | bool | True | Enable key-value caching for efficient autoregressive decoding |
| suppress_blank | bool | True | Suppress blank tokens and EOS at the first decoding step |
| suppress_tokens | str or list | "-1" | Token IDs to suppress. "-1" suppresses non-speech tokens defined in model.non_speech_tokens |
| sample_len | int | None | Maximum number of tokens to sample. Defaults to max_attn_tokens // 2 |
| prefix | str or list | None | Optional prefix to prepend to the decoder input |
| prompt | str or list | None | Optional prompt for conditional generation |
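The suppress_blank and suppress_tokens options both work by masking logits so that suppressed tokens can never be selected. The following is an illustrative pure-Python sketch of that mechanism (not SpeechBrain's actual tensor code; the toy vocabulary and `suppress` helper are assumptions for demonstration):

```python
import math

def suppress(logits, suppress_ids):
    """Return a copy of `logits` with the given token IDs masked to -inf,
    so argmax/softmax can never pick them (hypothetical helper)."""
    out = list(logits)
    for tok in suppress_ids:
        out[tok] = -math.inf
    return out

# Toy 5-token vocabulary; token 3 would win without suppression.
logits = [0.1, 0.5, 0.2, 2.0, 0.4]
masked = suppress(logits, suppress_ids=[3])
best = max(range(len(masked)), key=lambda i: masked[i])
print(best)  # token 1 is now the argmax
```

In the real searcher, `suppress_tokens="-1"` expands to the model's non-speech token list before this kind of masking is applied.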
S2SWhisperBeamSearcher
Used during testing for higher-quality hypothesis generation.
Constructor
```python
from speechbrain.decoders.seq2seq import S2SWhisperBeamSearcher

beam_searcher = S2SWhisperBeamSearcher(
    module=[whisper_model],
    temperature=1.0,
    use_kv_cache=True,
    suppress_blank=True,
    suppress_tokens="-1",
    min_decode_ratio=0.0,
    max_decode_ratio=1.0,
    beam_size=8,
)
```
Note: The module parameter takes a list containing the Whisper model (not the model directly). This follows the S2SBeamSearcher interface convention where module[0] is accessed internally.
Additional Parameter
| Parameter | Type | Default | Description |
|---|---|---|---|
| beam_size | int | required (via kwargs) | Number of beams to maintain during search. Typical value: 8 |
| module | list | required | List containing the Whisper model as its first element |
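To see why beam_size matters for the memory bookkeeping discussed below, here is a hedged sketch of a single beam-pruning step in plain Python (an illustration of the general technique, not SpeechBrain's implementation): the searcher keeps the beam_size highest-scoring extensions and records which parent beam each survivor came from, since the decoder memory and KV cache must later be reordered by those parent indices.

```python
def prune(candidates, beam_size):
    """candidates: list of (score, parent_beam, token) tuples.
    Keep the `beam_size` best-scoring candidates (hypothetical helper)."""
    return sorted(candidates, key=lambda c: c[0], reverse=True)[:beam_size]

# Four candidate extensions spawned from two parent beams.
candidates = [
    (-1.2, 0, 7), (-0.4, 0, 9),
    (-0.9, 1, 7), (-2.5, 1, 3),
]
survivors = prune(candidates, beam_size=2)
parent_indices = [c[1] for c in survivors]
print(survivors)       # [(-0.4, 0, 9), (-0.9, 1, 7)]
print(parent_indices)  # [0, 1] -> used to reorder memory / KV cache
```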
Usage in Training Recipe
```python
# In compute_forward:
enc_out, logits, _ = self.modules.whisper(wavs, bos_tokens)

# Validation: greedy search for speed
if stage == sb.Stage.VALID:
    hyps, _, _, _ = self.hparams.valid_search(enc_out.detach(), wav_lens)
# Testing: beam search for quality
elif stage == sb.Stage.TEST:
    hyps, _, _, _ = self.hparams.test_search(enc_out.detach(), wav_lens)

# Decode hypotheses to text
predicted_words = [
    tokenizer.decode(t, skip_special_tokens=True).strip() for t in hyps
]
```
YAML Configuration
```yaml
min_decode_ratio: 0.0
max_decode_ratio: 1.0
test_beam_size: 8

valid_search: !new:speechbrain.decoders.seq2seq.S2SWhisperGreedySearcher
    model: !ref <whisper>
    min_decode_ratio: !ref <min_decode_ratio>
    max_decode_ratio: !ref <max_decode_ratio>

test_search: !new:speechbrain.decoders.seq2seq.S2SWhisperBeamSearcher
    module: [!ref <whisper>]
    min_decode_ratio: !ref <min_decode_ratio>
    max_decode_ratio: !ref <max_decode_ratio>
    beam_size: !ref <test_beam_size>
```
Key Internal Methods
reset_mem(batch_size, device)
Initializes the decoder memory with Whisper's initial token sequence:
```python
# Initial tokens are derived from the tokenizer's prefix_tokens:
# [<|startoftranscript|>, <|language|>, <|task|>, <|notimestamps|>]
# The last token becomes the first decoder input; the remaining
# tokens are stored as memory.
memory_tokens = self.initial_tokens[:-1]
mem = torch.tensor([memory_tokens] * batch_size).to(device)
```
Also resets the KV cache to None at the start of each new utterance.
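The split between "first decoder input" and "memory" can be sketched with plain Python lists. The token IDs below are the usual multilingual Whisper special-token IDs, used here purely as an illustrative assumption (in practice they come from the tokenizer's prefix_tokens):

```python
# Assumed multilingual Whisper special-token IDs (illustrative only).
SOT, LANG_EN, TRANSCRIBE, NOTIMESTAMPS = 50258, 50259, 50359, 50363
initial_tokens = [SOT, LANG_EN, TRANSCRIBE, NOTIMESTAMPS]

first_input = initial_tokens[-1]     # last prefix token seeds the decoder
memory_tokens = initial_tokens[:-1]  # the rest is stored as memory

batch_size = 2
mem = [list(memory_tokens) for _ in range(batch_size)]
kv_cache = None                      # cache is cleared for each utterance
print(first_input, mem)
```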
forward_step(inp_tokens, memory, enc_states, enc_lens)
Performs a single decoding step:
- Appends the new input token to the memory sequence.
- Runs model.forward_decoder with optional KV caching.
- Computes no-speech probabilities at the BOS position.
- Extracts logits for the last position.
- Updates the KV cache.
- Applies blank suppression and token suppression.
- Returns logits (greedy) or log probabilities (beam search).
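The overall loop around forward_step can be sketched as follows. This is a hedged, self-contained illustration of the greedy decoding contract, not SpeechBrain's code: `toy_step` is a stand-in for the model's single-step forward, and the `max_len` cap plays the role of the end-condition check.

```python
EOS = 0  # assumed end-of-sequence token ID for this toy example

def toy_step(tokens):
    """Stand-in for one decoder forward step: deterministic toy logits
    that emit tokens 5, 6, 7, then EOS."""
    script = {0: 5, 1: 6, 2: 7}
    nxt = script.get(len(tokens) - 1, EOS)
    return [1.0 if i == nxt else 0.0 for i in range(10)]

def greedy_decode(first_input, max_len=16):
    tokens = [first_input]
    hyp = []
    for _ in range(max_len):              # cap: the end-condition check
        logits = toy_step(tokens)         # one forward_step
        tok = max(range(len(logits)), key=lambda i: logits[i])
        if tok == EOS:
            break
        hyp.append(tok)
        tokens.append(tok)                # memory grows by one token
    return hyp

print(greedy_decode(first_input=9))  # [5, 6, 7]
```

With KV caching enabled, the real forward_step feeds only the newest token to the decoder and reuses cached attention states for everything earlier, rather than re-running the full sequence each step.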
_check_end_condition
Checks whether the generated sequence has reached the maximum attention window length (max_attn_tokens - sample_begin).
permute_mem (Beam Search only)
Reorders the decoder memory and KV cache when beams are pruned. Uses _reorder_cache to select the correct past key-value pairs for surviving beams.
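What this reordering accomplishes can be shown with plain lists standing in for the per-beam past key/value tensors (a conceptual sketch, not SpeechBrain's `_reorder_cache`): after pruning, each surviving beam must point at the cached states of its parent beam.

```python
def reorder_cache(cache, parent_indices):
    """Per layer, select the cached states of each survivor's parent
    beam (hypothetical helper mirroring the reorder step)."""
    return [[layer[i] for i in parent_indices] for layer in cache]

# Two layers of per-beam cached states for beams 0..2.
cache = [["k0", "k1", "k2"], ["v0", "v1", "v2"]]

# Survivors descend from beams 2 and 0 (beam 1 was pruned).
reordered = reorder_cache(cache, parent_indices=[2, 0])
print(reordered)  # [['k2', 'k0'], ['v2', 'v0']]
```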
See Also
- Principle:Speechbrain_Speechbrain_Beam_Search_Decoding
- Implementation:Speechbrain_Speechbrain_Whisper_HFTransformersInterface
- Implementation:Speechbrain_Speechbrain_Whisper_ASR_Compute_Forward