Implementation:Openai Whisper Find Alignment
Overview
find_alignment() is a function in Whisper's timing module that extracts word-level timestamps by capturing cross-attention weights during a forward pass, normalizing and filtering them, applying Dynamic Time Warping (DTW), and mapping the alignment path to word boundaries. It is the core implementation of cross-attention alignment in the Whisper speech recognition system.
Source
- File: whisper/timing.py, lines 163-242
- Repository: https://github.com/openai/whisper
- Import: from whisper.timing import find_alignment
Signature
```python
def find_alignment(
    model: "Whisper",
    tokenizer: Tokenizer,
    text_tokens: List[int],
    mel: torch.Tensor,
    num_frames: int,
    *,
    medfilt_width: int = 7,
    qk_scale: float = 1.0,
) -> List[WordTiming]:
```
Parameters
| Parameter | Type | Description |
|---|---|---|
| model | Whisper | Whisper model instance with the alignment_heads attribute set |
| tokenizer | Tokenizer | Tokenizer providing sot_sequence, no_timestamps, eot, and split_to_word_tokens() |
| text_tokens | List[int] | Token IDs of the segment text, without timestamp tokens (from the decoding result) |
| mel | torch.Tensor | Mel spectrogram of the audio segment, without a batch dimension |
| num_frames | int | Number of mel frames of real audio in the segment (excluding padding); the attention weights are trimmed to num_frames // 2 positions |
| medfilt_width | int | Median filter kernel width along the time axis (default 7; must be odd) |
| qk_scale | float | Multiplier applied to the QK attention scores before the softmax (default 1.0) |
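num_frames is counted in mel frames, while the captured attention weights are indexed by audio encoder positions. The arithmetic below is a small sketch that relates the two. The constants are assumed to match whisper/audio.py (16 kHz audio, hop length 160, and an encoder that halves the frame rate). The helper functions are hypothetical names used only for illustration.

```python
# Assumed constants, mirroring the values used by Whisper's audio front end
SAMPLE_RATE = 16000                                  # audio samples per second
HOP_LENGTH = 160                                     # samples between consecutive mel frames
FRAMES_PER_SECOND = SAMPLE_RATE // HOP_LENGTH        # 100 mel frames per second
TOKENS_PER_SECOND = SAMPLE_RATE // (HOP_LENGTH * 2)  # 50 encoder positions per second


def attention_positions(num_frames: int) -> int:
    """Encoder positions kept for alignment (hypothetical helper).

    The encoder downsamples the mel frames by a factor of 2.
    """
    return num_frames // 2


def position_to_seconds(index: int) -> float:
    """Convert an encoder position index, e.g. from the DTW path, to seconds."""
    return index / TOKENS_PER_SECOND


if __name__ == "__main__":
    num_frames = 1234  # mel frames for a 12.34 s segment
    n = attention_positions(num_frames)
    print(n, position_to_seconds(n))  # 617 12.34
```

A full 30-second window (3000 mel frames) therefore maps to 1500 attention positions.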
Return Value
Returns a List[WordTiming]. The WordTiming dataclass is defined at whisper/timing.py:L154-160:
```python
@dataclass
class WordTiming:
    word: str
    tokens: List[int]
    start: float
    end: float
    probability: float
```
Each WordTiming contains:
- word: The decoded word text
- tokens: List of token IDs composing this word
- start: Start time in seconds
- end: End time in seconds
- probability: Average token probability for this word
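The probability field averages the probabilities of the tokens that make up each word. The sketch below shows one way to group a flat list of per-token probabilities by word and build WordTiming objects. Keep these points in mind:

- build_word_timings() is a hypothetical helper.
- The dataclass is copied from the definition above.
- The token IDs in the usage example are made up.
- Whisper itself also appends eot as an extra word when computing boundaries. This sketch omits that detail.

```python
from dataclasses import dataclass
from typing import List

import numpy as np


@dataclass
class WordTiming:
    word: str
    tokens: List[int]
    start: float
    end: float
    probability: float


def build_word_timings(
    words: List[str],
    word_tokens: List[List[int]],
    starts: List[float],
    ends: List[float],
    token_probs: List[float],
) -> List[WordTiming]:
    """Hypothetical helper: token_probs holds one value per token, in flattened word order."""
    # Offsets of each word's first token: [0, len(w0), len(w0) + len(w1), ...]
    boundaries = np.pad(np.cumsum([len(t) for t in word_tokens]), (1, 0))
    # Mean probability of the tokens belonging to each word
    probs = [
        float(np.mean(token_probs[i:j]))
        for i, j in zip(boundaries[:-1], boundaries[1:])
    ]
    return [
        WordTiming(w, t, float(s), float(e), p)
        for w, t, s, e, p in zip(words, word_tokens, starts, ends, probs)
    ]
```

For example, build_word_timings(["a", "b"], [[1, 2], [3]], [0.0, 0.5], [0.5, 1.0], [0.5, 1.0, 0.25]) gives word probabilities of 0.75 and 0.25.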
Behavior
The function performs the following steps in sequence:
- Token Preparation: Builds the decoder input by prepending sot_sequence and the no_timestamps token to text_tokens and appending eot. An empty text_tokens list returns [] immediately.
- Hook Installation: Registers a forward hook on the cross-attention module of every decoder block to capture its QK attention matrix during the forward pass.
- Forward Pass: Runs the model with SDPA (PyTorch's fused scaled_dot_product_attention) disabled, because the fused path does not return the QK matrix. The same pass yields the logits from which per-token probabilities are taken.
- Alignment Head Extraction: Keeps only the (layer, head) pairs listed in model.alignment_heads and trims the time axis to num_frames // 2 positions.
- Softmax: Multiplies the QK scores by qk_scale and applies a softmax over the time axis.
- Z-Score Normalization: Standardizes the weights across the token dimension (subtract the mean, divide by the standard deviation).
- Median Filtering: Smooths the weights along the time axis with a median filter of width medfilt_width.
- Head Averaging: Averages the weights across the selected alignment heads, producing a single token-by-time matrix.
- DTW Alignment: Runs Dynamic Time Warping on the negated matrix to find the monotonic path with the highest total attention.
- Word Splitting: Uses tokenizer.split_to_word_tokens() to group subword tokens into words.
- Boundary Mapping: Maps the DTW path to word boundaries, converting path positions to seconds. It then computes each word's start time, end time, and mean token probability.
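The numeric steps from Softmax through Head Averaging can be sketched in NumPy. This is a minimal sketch, assuming qk already holds the raw QK scores of the alignment heads with shape (heads, tokens, positions). It does not install hooks or run a model. The median filter is reimplemented here with reflect padding rather than calling Whisper's own median_filter(), and alignment_matrix() is a hypothetical name.

```python
import numpy as np


def softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
    # Subtract the max for numerical stability before exponentiating
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)


def median_filter(x: np.ndarray, width: int) -> np.ndarray:
    """Median filter along the last (time) axis with reflect padding."""
    if width <= 0 or width % 2 == 0:
        raise ValueError("width must be a positive odd number")
    pad = width // 2
    if x.shape[-1] <= pad:
        return x  # too short to reflect-pad; leave unchanged
    padded = np.pad(x, [(0, 0)] * (x.ndim - 1) + [(pad, pad)], mode="reflect")
    windows = np.lib.stride_tricks.sliding_window_view(padded, width, axis=-1)
    return np.median(windows, axis=-1)


def alignment_matrix(
    qk: np.ndarray, num_frames: int, medfilt_width: int = 7, qk_scale: float = 1.0
) -> np.ndarray:
    """qk: (heads, tokens, positions) raw scores; returns a (tokens, positions) matrix."""
    w = qk[:, :, : num_frames // 2]          # keep positions that cover real audio
    w = softmax(w * qk_scale, axis=-1)       # attention distribution over time
    mean = w.mean(axis=-2, keepdims=True)    # z-score across the token axis
    std = w.std(axis=-2, keepdims=True)
    w = (w - mean) / std
    w = median_filter(w, medfilt_width)      # smooth along the time axis
    return w.mean(axis=0)                    # average over alignment heads
```

The median filter removes isolated spikes. For example, a single 9.0 in a row of zeros is filtered to 0.0 with width 3, which is why it is applied before the path search.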
Example Usage
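```python
import whisper
from whisper.audio import N_FRAMES
from whisper.timing import find_alignment
from whisper.tokenizer import get_tokenizer

model = whisper.load_model("base")
tokenizer = get_tokenizer(
    model.is_multilingual,
    num_languages=model.num_languages,
    language="en",
)

# One 30-second window; "audio.wav" is a placeholder path
mel = whisper.log_mel_spectrogram(whisper.load_audio("audio.wav"), model.dims.n_mels)
num_frames = min(mel.shape[-1], N_FRAMES)  # frames of real audio, before padding
mel_segment = whisper.pad_or_trim(mel, N_FRAMES).to(model.device)

# Text tokens (no timestamp tokens) from a decoding result
options = whisper.DecodingOptions(language="en", without_timestamps=True, fp16=False)
text_tokens = whisper.decode(model, mel_segment, options).tokens

word_timings = find_alignment(model, tokenizer, text_tokens, mel_segment, num_frames)
for wt in word_timings:
    print(f"[{wt.start:.2f}-{wt.end:.2f}] {wt.word} (p={wt.probability:.2f})")
```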
Sample output:
[0.00-0.48] Hello (p=0.95)
[0.48-0.92] world (p=0.91)
[0.92-1.20] how (p=0.88)
[1.20-1.56] are (p=0.93)
[1.56-2.00] you (p=0.90)
Key Dependencies
- dtw(): Dynamic Time Warping for optimal alignment path computation
- median_filter(): Median filtering for attention weight smoothing
- tokenizer.split_to_word_tokens(): Word boundary detection from subword tokens
- Forward hooks on the MultiHeadAttention cross-attention layers of the decoder blocks
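The DTW and boundary-mapping steps can be sketched in the same way. The dtw() below is a textbook O(N·M) dynamic program written for illustration, not Whisper's own dtw(). word_times() is a hypothetical helper. Two further assumptions apply:

- The cost matrix has one row per text token plus one trailing row, so the first path step on the trailing row marks the end of the last word.
- TOKENS_PER_SECOND = 50 matches Whisper's encoder frame rate.

In find_alignment the cost passed to DTW is the negated head-averaged attention matrix.

```python
import numpy as np

TOKENS_PER_SECOND = 50  # assumed: encoder positions per second of audio


def dtw(cost: np.ndarray):
    """Minimum-cost monotonic path from (0, 0) to (N-1, M-1) through a (tokens, positions) matrix."""
    n, m = cost.shape
    acc = np.full((n + 1, m + 1), np.inf)
    acc[0, 0] = 0.0
    for i in range(1, n + 1):
        for j in range(1, m + 1):
            # Diagonal, vertical (next token) or horizontal (next position) step
            acc[i, j] = cost[i - 1, j - 1] + min(acc[i - 1, j - 1], acc[i - 1, j], acc[i, j - 1])
    # Backtrack from the bottom-right corner, preferring the diagonal on ties
    i, j, path = n, m, []
    while i > 0 and j > 0:
        path.append((i - 1, j - 1))
        step = np.argmin([acc[i - 1, j - 1], acc[i - 1, j], acc[i, j - 1]])
        if step == 0:
            i, j = i - 1, j - 1
        elif step == 1:
            i -= 1
        else:
            j -= 1
    path.reverse()
    text_idx, time_idx = map(np.array, zip(*path))
    return text_idx, time_idx


def word_times(text_idx: np.ndarray, time_idx: np.ndarray, word_token_counts):
    """Start/end seconds per word; word_token_counts gives the number of tokens in each word."""
    # A jump is the first path step that lands on a new token row
    jumps = np.pad(np.diff(text_idx), (1, 0), constant_values=1).astype(bool)
    jump_times = time_idx[jumps] / TOKENS_PER_SECOND  # one entry per row
    boundaries = np.pad(np.cumsum(word_token_counts), (1, 0))
    return jump_times[boundaries[:-1]], jump_times[boundaries[1:]]
```

Usage, with attention weights given as a (rows, positions) array named matrix: text_idx, time_idx = dtw(-matrix), followed by word_times(text_idx, time_idx, [len(t) for t in word_tokens]). Negating the weights turns "follow the strongest attention" into a minimum-cost path search.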
Links
- Principle:Openai_Whisper_Cross_Attention_Alignment
- Environment:Openai_Whisper_PyTorch_CUDA
- Heuristic:Openai_Whisper_SDPA_Disabling_For_Attention_Extraction