Principle:Openai Whisper Cross Attention Alignment
Overview
Cross-Attention Alignment is a technique for extracting word-level timing information from the cross-attention weights in encoder-decoder transformer models. In OpenAI's Whisper model, specific cross-attention heads in the decoder correlate strongly with the alignment between audio frames and text tokens, enabling precise word-level timestamps without a separate forced alignment model.
Domain
- Speech Recognition
- Attention Mechanisms
- Forced Alignment
Theoretical Background
In encoder-decoder transformer architectures such as Whisper, the decoder attends to the encoder output via cross-attention layers. Each cross-attention head computes attention weights of shape (num_text_tokens, num_audio_frames), representing how much each decoded token "attends to" each audio frame.
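To make the (num_text_tokens, num_audio_frames) shape concrete, the following is a minimal sketch of how one cross-attention head turns decoder queries and encoder keys into a weight matrix. It uses plain Python lists so that it runs without dependencies. The function name `cross_attention_weights` and the toy vectors are illustrative assumptions, not part of Whisper.

```python
import math


def cross_attention_weights(queries, keys):
    """Return a (T, F) matrix where row t is softmax(q_t . k_f / sqrt(d)) over frames f.

    queries: T decoder-side vectors (one per text token), each of length d.
    keys:    F encoder-side vectors (one per audio frame), each of length d.
    """
    d = len(queries[0])
    scale = 1.0 / math.sqrt(d)
    weights = []
    for q in queries:
        logits = [scale * sum(qi * ki for qi, ki in zip(q, k)) for k in keys]
        peak = max(logits)  # subtract the max for numerical stability
        exps = [math.exp(x - peak) for x in logits]
        total = sum(exps)
        weights.append([e / total for e in exps])
    return weights


# Toy example (hypothetical values): 2 text tokens, 3 audio frames, head size 2.
queries = [[1.0, 0.0], [0.0, 1.0]]
keys = [[4.0, 0.0], [0.0, 0.0], [0.0, 4.0]]
A = cross_attention_weights(queries, keys)
```

Each row of `A` sums to 1 and describes how one decoded token distributes its attention over the audio frames. In this toy input the first token attends mostly to the first frame and the second token mostly to the last frame.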
In the Whisper reference implementation that accompanies Radford et al. (2022), certain cross-attention heads are observed to exhibit strong diagonal alignment patterns: the attention weight for a given text token peaks at the audio frames where the corresponding word was spoken. These heads are called alignment heads, and the model records which (layer, head) pairs they are in its alignment_heads attribute.
Extraction Pipeline
The process for extracting word-level timestamps from cross-attention weights involves several steps:
- Forward Pass with Raw Attention: Run the decoder with fused Scaled Dot-Product Attention (SDPA) disabled. Fused kernels, including flash attention, return only the attention output and never materialize the weight matrix, so attention must be computed explicitly for the weights to be available.
- Hook-Based Weight Capture: Install forward hooks on the cross-attention layers to capture the QK (query-key) score matrices during the forward pass. A softmax over the audio-frame axis turns these scores into attention weights.
- Alignment Head Selection: Extract attention weights only from the identified alignment heads, discarding the rest.
- Z-Score Normalization: Normalize the attention weights of each head across the token dimension, so that every audio frame column has zero mean and unit variance. This puts heads with different attention scales on a comparable footing before they are averaged.
- Median Filtering: Apply a median filter along the time axis to smooth out impulse noise and frame-level fluctuations while preserving transitions.
- Head Averaging: Average the normalized, filtered weights across the selected alignment heads to produce a single attention matrix.
- Dynamic Time Warping (DTW): Apply DTW on the negative attention matrix to find the optimal monotonic alignment path from text tokens to audio frames.
- Word Grouping: Split tokens into word-level groups and map DTW path indices to word-level start and end times.
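The final word-grouping step can be sketched as follows. This is a simplified illustration, not Whisper's own code: the function name `word_times`, the toy path, and the word lengths are hypothetical, and the 20 ms frame duration is an assumption based on Whisper's encoder producing 1500 frames per 30-second window.

```python
SECONDS_PER_FRAME = 0.02  # assumption: 30 s of audio / 1500 encoder frames


def word_times(path, word_lengths, seconds_per_frame=SECONDS_PER_FRAME):
    """Map a monotonic DTW path to (start, end) times per word.

    path:         list of (token_index, frame_index) pairs, in order.
    word_lengths: number of tokens in each word, in transcript order.
    """
    # First frame at which the path reaches each token.
    first_frame = {}
    for token, frame in path:
        first_frame.setdefault(token, frame)

    end_of_audio = path[-1][1] + 1  # exclusive end frame for the last word
    times = []
    start_token = 0
    for length in word_lengths:
        next_token = start_token + length
        start = first_frame[start_token]
        # A word ends where the next word's first token begins.
        end = first_frame.get(next_token, end_of_audio)
        times.append((start * seconds_per_frame, end * seconds_per_frame))
        start_token = next_token
    return times


# Toy path (hypothetical): 4 tokens aligned to 6 frames, grouped into 3 words.
path = [(0, 0), (0, 1), (1, 2), (2, 3), (2, 4), (3, 5)]
times = word_times(path, word_lengths=[1, 2, 1])
```

Here the second word spans tokens 1 and 2, so it starts at the first frame of token 1 and ends at the first frame of token 3. Treating the start of the next word as the end of the current one is a design choice of this sketch; it leaves no gaps between words, which may overstate word durations across silences.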
Mathematical Formulation
Given an attention weight matrix A of shape (T, F) where T is the number of text tokens and F is the number of audio frames:
- Z-score normalization: For each head, normalize across the token dimension so that each frame column has zero mean and unit variance.
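- Median filter: Along the frame axis, replace each value with the median of the k values in a sliding window centered on it (k is odd, typically 7).
- DTW alignment: Find the path P through cost matrix C = -A that runs from the first token and frame to the last token and frame with minimum total cost. Monotonicity means each step advances the token index, the frame index, or both, and never moves backward.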
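The three operations above can be sketched for a single head in plain Python. This is an illustrative, dependency-free version under stated assumptions: the function names are hypothetical, the median filter uses reflect padding at the edges, and the guard against constant columns is an addition of this sketch. A real implementation would operate on tensors and average several heads before DTW.

```python
import statistics


def zscore_columns(A):
    """Give every frame column zero mean and unit variance across tokens."""
    T, F = len(A), len(A[0])
    out = [[0.0] * F for _ in range(T)]
    for f in range(F):
        col = [A[t][f] for t in range(T)]
        mean = statistics.fmean(col)
        std = statistics.pstdev(col) or 1.0  # assumption: skip scaling if constant
        for t in range(T):
            out[t][f] = (A[t][f] - mean) / std
    return out


def median_filter_rows(A, k=7):
    """Median-filter each row along the frame axis with an odd window k."""
    assert k % 2 == 1, "window width must be odd"
    p = k // 2
    out = []
    for row in A:
        row = list(row)
        if len(row) <= p:  # too short to pad by reflection; leave unchanged
            out.append(row)
            continue
        padded = row[p:0:-1] + row + row[-2:-p - 2:-1]  # reflect padding
        out.append([statistics.median(padded[i:i + k]) for i in range(len(row))])
    return out


def dtw_path(cost):
    """Return the minimum-cost monotonic path as (token, frame) pairs."""
    T, F = len(cost), len(cost[0])
    inf = float("inf")
    acc = [[inf] * (F + 1) for _ in range(T + 1)]
    step = [[0] * (F + 1) for _ in range(T + 1)]
    acc[0][0] = 0.0
    for t in range(1, T + 1):
        for f in range(1, F + 1):
            # 0 = diagonal, 1 = previous token, 2 = previous frame
            prev = (acc[t - 1][f - 1], acc[t - 1][f], acc[t][f - 1])
            best = min(range(3), key=prev.__getitem__)
            acc[t][f] = cost[t - 1][f - 1] + prev[best]
            step[t][f] = best
    path, t, f = [], T, F
    while t > 0 and f > 0:  # backtrack from the last cell to the first
        path.append((t - 1, f - 1))
        s = step[t][f]
        if s == 0:
            t, f = t - 1, f - 1
        elif s == 1:
            t -= 1
        else:
            f -= 1
    return path[::-1]


# Toy matrix (hypothetical): 3 tokens, 6 frames, two frames per token.
A = [[1, 1, 0, 0, 0, 0],
     [0, 0, 1, 1, 0, 0],
     [0, 0, 0, 0, 1, 1]]
Z = zscore_columns(A)
smoothed = median_filter_rows(Z, k=3)  # small window for this tiny example
path = dtw_path([[-x for x in row] for row in smoothed])
```

The toy example uses k=3 only because it has six frames; a width of 7 would smear such short spans. Negating the matrix before DTW turns high attention into low cost, so the minimum-cost path follows the ridge of strongest attention.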
Significance
This approach is a form of attention-based forced alignment that leverages the learned internal representations of the model, rather than requiring a separate alignment model or phoneme-level annotations. It enables:
- Word-level timestamps for transcriptions
- Karaoke-style subtitle generation with per-word highlighting
- Quality assessment of transcription segments based on alignment confidence
References
- Radford, A., Kim, J.W., Xu, T., Brockman, G., McLeavey, C., & Sutskever, I. (2022). Robust Speech Recognition via Large-Scale Weak Supervision. OpenAI.
Implementation
- Implementation:Openai_Whisper_Find_Alignment
- Heuristic:Openai_Whisper_SDPA_Disabling_For_Attention_Extraction