Implementation:Openai Whisper DecodingTask Run
Overview
DecodingTask is the core class that orchestrates single-segment decoding in Whisper. Its __init__ method assembles the full decoding pipeline (tokenizer, decoder strategy, logit filters, inference engine), and its run() method executes the pipeline on a batched mel spectrogram tensor.
Source
- File: whisper/decoding.py:L508-570 (__init__), whisper/decoding.py:L712-789 (run)
- Import: from whisper.decoding import DecodingTask
- Repository: https://github.com/openai/whisper
Signatures
def __init__(self, model: "Whisper", options: DecodingOptions):
@torch.no_grad()
def run(self, mel: Tensor) -> List[DecodingResult]:
Initialization (__init__)
The constructor sets up all components needed for decoding:
Tokenizer
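A Tokenizer is created with the specified language and task. It provides the vocabulary, special token IDs, and encoding/decoding methods. When no language is given, "en" is used as a placeholder, and the language token is overwritten once language detection has run.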
Group Size
The n_group attribute determines how many parallel candidate sequences are maintained per audio segment:
- If beam_size is set, n_group = beam_size
- If best_of is set, n_group = best_of
- Otherwise, n_group = 1
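The group-size rule above fits in one expression. The following is a minimal sketch; group_size is a hypothetical helper, not part of Whisper's API. Whisper itself rejects options that set both beam_size and best_of, so the precedence only matters when exactly one of them is given.

```python
from typing import Optional


def group_size(beam_size: Optional[int] = None, best_of: Optional[int] = None) -> int:
    """Hypothetical helper: number of parallel candidates kept per audio segment."""
    # beam_size wins, then best_of; otherwise a single sequence is decoded.
    return beam_size or best_of or 1


print(group_size(beam_size=5))  # 5
print(group_size(best_of=3))    # 3
print(group_size())             # 1
```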
Initial Token Sequence
The start-of-transcript (SOT) sequence is constructed:
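- SOT token + language token + task token (the language and task tokens are present only for multilingual models), followed by the <|notimestamps|> token when without_timestamps=True
- Optional prompt tokens, preceded by the <|startofprev|> token, placed before the SOT sequence
- Optional prefix tokens appended after the SOT sequence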
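The assembly order of the initial tokens can be sketched with plain lists. The token IDs below are hypothetical placeholders (real IDs come from the tokenizer), and the sketch assumes a multilingual model transcribing English. It also omits the truncation Whisper applies to long prompts. The token appended when timestamps are disabled is <|notimestamps|>, and a prompt is introduced by <|startofprev|>.

```python
from typing import List, Optional

# Hypothetical token IDs for illustration only; real values come from the tokenizer.
SOT, SOT_PREV, LANG_EN, TRANSCRIBE, NO_TIMESTAMPS = 100, 101, 102, 103, 104


def build_initial_tokens(
    without_timestamps: bool,
    prompt: Optional[List[int]] = None,
    prefix: Optional[List[int]] = None,
) -> List[int]:
    # SOT + language + task, plus <|notimestamps|> when timestamps are disabled.
    sot_sequence = [SOT, LANG_EN, TRANSCRIBE]
    if without_timestamps:
        sot_sequence.append(NO_TIMESTAMPS)
    tokens = list(sot_sequence)
    if prefix:
        # The prefix continues right after the SOT sequence.
        tokens = tokens + prefix
    if prompt:
        # Previous-context prompt goes first, introduced by <|startofprev|>.
        tokens = [SOT_PREV] + prompt + tokens
    return tokens


print(build_initial_tokens(True, prompt=[1, 2], prefix=[7]))
# [101, 1, 2, 100, 102, 103, 104, 7]
```

Decoding starts sampling right after the last of these tokens, so the prefix constrains the beginning of the output while the prompt only supplies context.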
Inference Engine
A PyTorchInference instance is created, which wraps the model's decoder with key-value caching for efficient autoregressive generation.
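The following toy, pure-Python sketch makes the caching idea concrete under strong simplifying assumptions: one attention head, scalar keys and values, and a hypothetical project() function standing in for the learned key/value projections. It does not reproduce how PyTorchInference actually stores its cache. It only shows that reusing stored keys and values gives the same attention output while projecting each token once. For this sequence that means 4 projections instead of 2 + 3 + 4 = 9.

```python
import math
from typing import List, Tuple


def project(token: int) -> Tuple[float, float]:
    # Hypothetical stand-in for a layer's key/value projections.
    return math.sin(token), math.cos(token)


def attend(query: float, keys: List[float], values: List[float]) -> float:
    # Single-head softmax attention over scalar keys/values.
    scores = [query * k for k in keys]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    return sum(w * v for w, v in zip(weights, values)) / sum(weights)


class ToyKVCache:
    """Keeps keys/values of tokens already seen, so each step projects only new tokens."""

    def __init__(self) -> None:
        self.keys: List[float] = []
        self.values: List[float] = []
        self.projections = 0  # number of project() calls made so far

    def step(self, new_tokens: List[int]) -> float:
        for t in new_tokens:
            k, v = project(t)
            self.keys.append(k)
            self.values.append(v)
            self.projections += 1
        # The query comes from the newest token, as in autoregressive decoding.
        return attend(float(new_tokens[-1]), self.keys, self.values)


def step_without_cache(all_tokens: List[int]) -> Tuple[float, int]:
    # Re-projects the whole prefix on every call; returns (output, projections used).
    kv = [project(t) for t in all_tokens]
    keys = [k for k, _ in kv]
    values = [v for _, v in kv]
    return attend(float(all_tokens[-1]), keys, values), len(all_tokens)


# Initial tokens [5, 9], then two generated tokens: 2 and 7.
cache = ToyKVCache()
cached_outputs = [cache.step([5, 9]), cache.step([2]), cache.step([7])]
uncached = [step_without_cache([5, 9]), step_without_cache([5, 9, 2]), step_without_cache([5, 9, 2, 7])]
print(cache.projections)            # 4
print(sum(n for _, n in uncached))  # 9
```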
Sequence Ranker
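A MaximumLikelihoodRanker is instantiated with the configured length_penalty. It scores each candidate by its summed log probability divided by a length penalty. When length_penalty is None, the penalty is the sequence length, so the score is the average log probability per token. Otherwise, the penalty is ((5 + length) / 6) ** length_penalty.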
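The ranking rule can be sketched in plain Python under the scoring described above. The function name rank and the list-of-lists input layout are assumptions for illustration: one inner list of candidates per audio segment.

```python
from typing import List, Optional


def rank(
    sum_logprobs: List[List[float]],
    lengths: List[List[int]],
    length_penalty: Optional[float] = None,
) -> List[int]:
    """Return the index of the best-scoring candidate for each audio segment."""

    def score(logprob: float, length: int) -> float:
        if length_penalty is None:
            penalty = length  # average log probability per token
        else:
            penalty = ((5 + length) / 6) ** length_penalty
        return logprob / penalty

    best = []
    for lps, lens in zip(sum_logprobs, lengths):
        scores = [score(lp, n) for lp, n in zip(lps, lens)]
        best.append(max(range(len(scores)), key=scores.__getitem__))
    return best


print(rank([[-4.0, -5.0]], [[2, 10]]))                      # [1]
print(rank([[-4.0, -5.0]], [[2, 10]], length_penalty=0.0))  # [0]
```

The example shows why length normalization matters. By raw summed log probability, the short candidate (-4.0) wins. Per token, the longer one (-0.5 vs -2.0) is preferred. A length_penalty of 0 turns the penalty into 1 and reverts to the raw sum.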
Decoder Strategy
Based on the options:
- temperature == 0 and no beam search: GreedyDecoder with temperature=0 (argmax)
- temperature > 0 and no beam search: GreedyDecoder with the specified temperature (sampling)
- beam_size specified: BeamSearchDecoder with the given beam size and patience
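The two GreedyDecoder modes differ only in how one next token is picked from the logits. The following is a minimal sketch of that choice on Python lists; next_token is a hypothetical function. Whisper's GreedyDecoder does the equivalent on batched tensors and also tracks summed log probabilities, which is omitted here. Beam search is not sketched.

```python
import math
import random
from typing import List


def next_token(logits: List[float], temperature: float, rng: random.Random) -> int:
    if temperature == 0:
        # Greedy: take the highest-scoring token (argmax).
        return max(range(len(logits)), key=logits.__getitem__)
    # Sampling: draw from softmax(logits / temperature).
    # Assumes at least one logit is finite.
    scaled = [l / temperature for l in logits]
    m = max(scaled)
    weights = [math.exp(s - m) for s in scaled]  # suppressed (-inf) logits get weight 0
    return rng.choices(range(len(logits)), weights=weights, k=1)[0]


rng = random.Random(0)
print(next_token([0.1, 2.0, -1.0], temperature=0.0, rng=rng))              # 1
print(next_token([0.0, -math.inf, -math.inf], temperature=1.0, rng=rng))  # 0
```

Higher temperatures flatten the distribution, so sampling at T > 0 produces varied candidates. This is what makes best_of (several sampled sequences, then ranking) useful.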
Logit Filters
A list of logit filters is assembled:
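- SuppressBlank: if suppress_blank=True, prevents a blank or EOT token at the first sampled position
- SuppressTokens: suppresses the token IDs listed in suppress_tokens (the default "-1" expands to the tokenizer's non-speech symbols), plus special tokens such as SOT and the task tokens
- ApplyTimestampRules: if timestamps are enabled, enforces monotonic timestamp ordering and proper timestamp-text alternation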
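The two suppression filters work by setting logits to negative infinity, so the suppressed tokens can never be chosen. The following is a minimal sketch on Python lists that edits the logits in place, as Whisper's filters do. The function names and the 6-token vocabulary are hypothetical. The timestamp rules are not sketched.

```python
import math
from typing import List, Sequence


def suppress_blank(logits: List[float], step: int, blank_ids: Sequence[int], eot: int) -> None:
    # Only at the first sampled position: forbid a blank token or an immediate EOT.
    if step == 0:
        for i in [*blank_ids, eot]:
            logits[i] = -math.inf


def suppress_tokens(logits: List[float], token_ids: Sequence[int]) -> None:
    # At every step: forbid the configured token IDs (e.g. non-speech symbols).
    for i in token_ids:
        logits[i] = -math.inf


# Hypothetical vocabulary: ID 2 is a blank, ID 3 a non-speech symbol, ID 5 is EOT.
logits = [1.0] * 6
suppress_blank(logits, step=0, blank_ids=[2], eot=5)
suppress_tokens(logits, token_ids=[3])
print(logits)  # [1.0, 1.0, -inf, -inf, 1.0, -inf]
```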
Execution (run())
The run() method processes a batch of mel spectrograms through the full decoding pipeline:
Step-by-Step Flow
- Reset decoder state (clear any cached sequences)
- Encode audio features: pass the mel tensor through the model's encoder to produce audio feature representations
- Detect language (if needed): if no language is specified, perform a forward pass to identify the most probable language from the encoder output
- Repeat tokens for beam/sampling groups: duplicate the initial token sequence n_group times per batch element to support parallel beam or sampling candidates
- Main autoregressive loop (_main_loop): iteratively generate tokens:
  - Forward pass through the decoder with cached key-value pairs
  - Apply all logit filters
  - Select next tokens via the decoder strategy (greedy/beam)
  - Check for completion (all sequences ended with EOT, or the maximum sample length was reached)
- Finalize sequences: collect completed token sequences
- Rank sequences: use MaximumLikelihoodRanker to select the best sequence per batch element
- Build DecodingResult objects: decode tokens to text and attach metadata (language, average log probability, no-speech probability, compression ratio)
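The repeat-then-loop portion of this flow can be sketched with plain lists and greedy selection. decode_loop and toy_logits are hypothetical. The real loop runs on batched tensors with a cached decoder, supports beam search, and is followed by the finalize and rank steps, none of which is reproduced here.

```python
from typing import Callable, List

Filter = Callable[[List[float], int], None]


def decode_loop(
    initial: List[List[int]],
    n_group: int,
    logits_fn: Callable[[List[int]], List[float]],
    filters: List[Filter],
    eot: int,
    sample_len: int,
) -> List[List[int]]:
    # Repeat each segment's initial tokens n_group times (like repeat_interleave).
    seqs = [list(t) for t in initial for _ in range(n_group)]
    sample_begin = len(initial[0])
    for step in range(sample_len):
        for seq in seqs:
            if seq[-1] == eot:
                continue  # this candidate already finished
            logits = logits_fn(seq)  # stand-in for the decoder forward pass
            for f in filters:
                f(logits, step)  # filters edit the logits in place
            seq.append(max(range(len(logits)), key=logits.__getitem__))  # greedy pick
        if all(seq[-1] == eot for seq in seqs):
            break  # every candidate has emitted EOT
    # Keep only the sampled tokens, cut at the first EOT.
    out = []
    for seq in seqs:
        body = seq[sample_begin:]
        out.append(body[: body.index(eot)] if eot in body else body)
    return out


def toy_logits(seq: List[int]) -> List[float]:
    # Toy 6-token model (EOT = 5) that always favors "previous token + 1".
    logits = [0.0] * 6
    logits[min(seq[-1] + 1, 5)] = 1.0
    return logits


print(decode_loop([[0]], n_group=2, logits_fn=toy_logits, filters=[], eot=5, sample_len=10))
# [[1, 2, 3, 4], [1, 2, 3, 4]]
```

With greedy decoding, all n_group copies are identical, as the output shows. Groups only pay off with sampling (best_of) or beam search, where candidates diverge.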
Inputs and Outputs
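- Inputs: mel tensor of shape (batch, 80, 3000), i.e. a batch of 30-second log-mel spectrograms with 80 mel bins (large-v3 checkpoints use 128) and 3000 frames
- Outputs: List[DecodingResult] with one result per batch element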
Usage Example
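import whisper
from whisper.decoding import DecodingTask, DecodingOptions
model = whisper.load_model("base")
audio = whisper.pad_or_trim(whisper.load_audio("audio.wav"))  # path to your audio; padded/trimmed to 30 s
mel_segment = whisper.log_mel_spectrogram(audio, n_mels=model.dims.n_mels).to(model.device).unsqueeze(0)  # (1, n_mels, 3000)
options = DecodingOptions(beam_size=5, language="en", fp16=False)  # fp16=False keeps CPU runs in float32
task = DecodingTask(model, options)
results = task.run(mel_segment)
for r in results:
    print(r.text, r.avg_logprob)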
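One piece of metadata attached to each DecodingResult is the compression ratio of the decoded text. The following sketch assumes the usual definition: the UTF-8 byte length divided by the zlib-compressed byte length. A high ratio typically signals repetitive, looping output.

```python
import zlib


def compression_ratio(text: str) -> float:
    # Raw UTF-8 size over zlib-compressed size; repetitive text compresses well.
    data = text.encode("utf-8")
    return len(data) / len(zlib.compress(data))


print(compression_ratio("la " * 200) > compression_ratio("The quick brown fox jumps over the lazy dog."))  # True
```

Short, varied text can even score below 1 because of zlib's fixed header and checksum overhead.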
Key Notes
- DecodingTask is typically not called directly by users. The decode() function creates and invokes it internally, and transcribe() reaches it by calling decode().
- The run() method is decorated with @torch.no_grad(), which disables gradient tracking to save memory and computation during inference.
- Key-value caching in PyTorchInference is essential for performance. Without it, the decoder would need to reprocess all previous tokens at each step.
- The class handles both single-segment decoding (called by decode()) and repeated-segment decoding (called by transcribe() in a loop).