Implementation:Openai Whisper Find Alignment

From Leeroopedia

Overview

find_alignment() is a function in Whisper's timing module (whisper/timing.py) that computes word-level timestamps for text that has already been decoded. It runs a teacher-forced forward pass over the known tokens, captures the cross-attention weights of the designated alignment heads, normalizes and median-filters them, aligns tokens to audio frames with Dynamic Time Warping (DTW), and maps the resulting path to word boundaries. It is the core implementation of cross-attention alignment in Whisper and underlies transcribe(..., word_timestamps=True).

Source

whisper/timing.py

Signature

def find_alignment(
    model: "Whisper",
    tokenizer: Tokenizer,
    text_tokens: List[int],
    mel: torch.Tensor,
    num_frames: int,
    *,
    medfilt_width: int = 7,
    qk_scale: float = 1.0,
) -> List[WordTiming]:

Parameters

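  • model (Whisper): Whisper model instance; its alignment_heads buffer selects the (layer, head) pairs used for alignment
  • tokenizer (Tokenizer): Tokenizer providing sot_sequence, no_timestamps, eot, and split_to_word_tokens()
  • text_tokens (List[int]): Text token IDs for the segment, taken from the decoding result with timestamp and other special tokens removed (every ID is below tokenizer.eot)
  • mel (torch.Tensor): Unbatched log-Mel spectrogram of shape (n_mels, 3000), i.e. padded or trimmed to Whisper's 30-second window
  • num_frames (int): Number of mel frames in the segment that contain actual audio; attention is kept only for the first num_frames // 2 encoder positions
  • medfilt_width (int): Median filter kernel width (default 7; must be odd)
  • qk_scale (float): Factor applied to the QK attention scores before the softmax (default 1.0)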
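Because medfilt_width must be odd, every filter window has a single middle element centered on the sample being smoothed. The following is a minimal pure-Python sketch of such a filter on one 1-D row, assuming reflect padding at the edges; Whisper's own median_filter() works on the last axis of a tensor, and the name median_filter_1d is hypothetical.

```python
from statistics import median
from typing import List, Sequence


def median_filter_1d(x: Sequence[float], width: int) -> List[float]:
    """Hypothetical helper (illustrative sketch, not Whisper's median_filter()).

    Median-filters a 1-D sequence with an odd window, reflect-padding the edges.
    """
    if width <= 0 or width % 2 == 0:
        raise ValueError("width must be a positive odd number")
    x = list(x)
    pad = width // 2
    if len(x) <= pad:
        # Too short to reflect-pad; return the input unchanged.
        return x
    # Reflect padding mirrors around the edge sample without repeating it:
    # [1, 2, 3, 4] with pad=2 becomes [3, 2, 1, 2, 3, 4, 3, 2].
    padded = x[pad:0:-1] + x + x[-2:-pad - 2:-1]
    # Window i covers padded[i : i + width] and is centered on x[i].
    return [median(padded[i:i + width]) for i in range(len(x))]


# A single-frame spike is removed entirely by a width-3 filter.
print(median_filter_1d([0, 0, 9, 0, 0], 3))  # [0, 0, 0, 0, 0]
```

Larger widths smooth more aggressively: any peak narrower than half the window is flattened, which is why the default of 7 suppresses isolated attention spikes while keeping sustained ones.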

Return Value

Returns a List[WordTiming] with one entry per word, or an empty list when text_tokens is empty. The WordTiming dataclass is defined at whisper/timing.py:L154-160:

from dataclasses import dataclass
from typing import List

@dataclass
class WordTiming:
    word: str
    tokens: List[int]
    start: float
    end: float
    probability: float

Each WordTiming contains:

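  • word: The decoded word text (typically with a leading space, e.g. " Hello")
  • tokens: List of token IDs that make up this word
  • start: Start time in seconds, relative to the start of the mel segment
  • end: End time in seconds, relative to the start of the mel segment
  • probability: Mean of the model's probabilities for this word's tokens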
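The start, end, and probability fields come from the boundary-mapping step described under Behavior. The sketch below illustrates that step in plain Python under stated assumptions: the DTW path is given as two parallel lists, it covers one row per text token plus one final row (so the last word's end time can be read from it), and there are 50 alignment frames per second (Whisper's TOKENS_PER_SECOND). The function name words_from_path and its inputs are hypothetical; Whisper performs the equivalent computation with NumPy inside find_alignment().

```python
from typing import List, Tuple

TOKENS_PER_SECOND = 50  # assumed: 20 ms per encoder position, as in Whisper


def words_from_path(
    text_indices: List[int],
    time_indices: List[int],
    word_lengths: List[int],
    token_probs: List[float],
) -> List[Tuple[float, float, float]]:
    """Hypothetical helper: (start, end, probability) per word from a DTW path."""
    # A "jump" is the first path step that reaches a new token row; its frame
    # index, converted to seconds, is when that token starts.
    jump_times = [
        t / TOKENS_PER_SECOND
        for k, (row, t) in enumerate(zip(text_indices, time_indices))
        if k == 0 or row != text_indices[k - 1]
    ]
    # Word w spans token rows bounds[w] .. bounds[w + 1] - 1.
    bounds = [0]
    for n in word_lengths:
        bounds.append(bounds[-1] + n)
    words = []
    for first, after_last in zip(bounds[:-1], bounds[1:]):
        probs = token_probs[first:after_last]
        # A word ends when the path reaches the token row that follows it.
        words.append((jump_times[first], jump_times[after_last], sum(probs) / len(probs)))
    return words


# Three text tokens grouped into words of 2 and 1 tokens; row 3 is the final row.
path_rows = [0, 0, 1, 1, 2, 3, 3]
path_frames = [0, 1, 2, 3, 4, 5, 6]
print(words_from_path(path_rows, path_frames, [2, 1], [0.5, 1.0, 0.25]))
# [(0.0, 0.08, 0.75), (0.08, 0.1, 0.25)]
```

Reading each word's end from the start of the next token means consecutive words share a boundary, which matches the back-to-back timings in the sample output.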

Behavior

The function performs the following steps in sequence:

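  1. Token Preparation: Builds the decoder input by prepending tokenizer.sot_sequence and the no_timestamps token to text_tokens and appending eot. If text_tokens is empty, the function returns an empty list immediately.
  2. Hook Installation: Registers a forward hook on the cross-attention module of every decoder block to capture its QK attention-score matrix during the forward pass.
  3. Forward Pass: Runs one teacher-forced forward pass over the full token sequence with SDPA (PyTorch's fused scaled_dot_product_attention) disabled, because the fused path returns only the attention output and not the QK matrix the hooks need. The same pass provides each text token's softmax probability. The hooks are removed afterwards.
  4. Alignment Head Extraction: Keeps only the (layer, head) pairs listed in model.alignment_heads, stacks them into a heads × tokens × frames tensor, and truncates the frame axis to num_frames // 2, because the encoder produces one position per two mel frames.
  5. Softmax and Z-Score Normalization: Multiplies the scores by qk_scale, applies a softmax over the frame axis, then z-score normalizes the result (subtract the mean, divide by the standard deviation) along the token axis, separately for each frame.
  6. Median Filtering: Applies a median filter of width medfilt_width along the frame axis to suppress isolated spikes.
  7. Head Averaging: Averages the filtered weights across the alignment heads into a single tokens × frames matrix and drops the rows belonging to sot_sequence and the final eot token.
  8. DTW Alignment: Runs dtw() on the negated matrix, so that strong attention means low cost, to find the optimal monotonic path of (token, frame) index pairs.
  9. Word Splitting: Calls tokenizer.split_to_word_tokens() on text_tokens plus eot to group subword tokens into words.
  10. Boundary Mapping: Takes the frame at which the path first reaches each token, converts it to seconds at 50 frames per second (TOKENS_PER_SECOND), reads each word's start and end times at its token boundaries, and averages the token probabilities within each word.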
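The normalization stage (steps 4 to 7) can be sketched in plain Python on small nested lists. This is an illustration, not Whisper's tensor code: it assumes the per-head tokens × frames score matrices have already been selected via model.alignment_heads, and the helper names softmax, medfilt, and alignment_matrix are hypothetical. The axes follow Whisper's code: softmax over frames, z-score over tokens (independently per frame), and median filtering along frames.

```python
import math
from statistics import mean, median, pstdev
from typing import List

Matrix = List[List[float]]  # rows = tokens, columns = encoder frames


def softmax(row: List[float]) -> List[float]:
    m = max(row)  # subtract the max for numerical stability
    exps = [math.exp(v - m) for v in row]
    total = math.fsum(exps)  # exactly rounded, independent of element order
    return [e / total for e in exps]


def medfilt(row: List[float], width: int) -> List[float]:
    pad = width // 2
    if len(row) <= pad:
        return list(row)
    padded = row[pad:0:-1] + row + row[-2:-pad - 2:-1]  # reflect padding
    return [median(padded[i:i + width]) for i in range(len(row))]


def alignment_matrix(
    heads: List[Matrix], num_frames: int, qk_scale: float = 1.0, medfilt_width: int = 7
) -> Matrix:
    """Hypothetical helper (illustrative sketch of steps 4-7, not Whisper's API)."""
    processed = []
    for qk in heads:
        # Step 4: keep encoder positions covered by audio (one per two mel frames).
        # Step 5a: scale, then softmax over frames within each token row.
        w = [softmax([v * qk_scale for v in row[: num_frames // 2]]) for row in qk]
        # Step 5b: z-score each frame (column) across tokens (rows).
        for f in range(len(w[0])):
            col = [r[f] for r in w]
            mu = mean(col)
            sd = pstdev(col) or 1.0  # guard a constant column; tensor code would give NaN
            for r in w:
                r[f] = (r[f] - mu) / sd
        # Step 6: median-filter each token row along the frame axis.
        processed.append([medfilt(r, medfilt_width) for r in w])
    # Step 7: average the heads element-wise into one tokens x frames matrix.
    return [[mean(vals) for vals in zip(*rows)] for rows in zip(*processed)]


# Toy input: two identical heads in which token r attends most to frame 2r.
head = [[5.0 if f == 2 * r else 0.0 for f in range(6)] for r in range(3)]
m = alignment_matrix([head, head], num_frames=12, medfilt_width=1)
print([max(range(6), key=row.__getitem__) for row in m])  # [0, 2, 4]
```

Normalizing across tokens highlights, for each frame, which token claims it most strongly, which is the quantity the subsequent DTW step needs.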

Example Usage

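import whisper
from whisper.audio import N_FRAMES, log_mel_spectrogram, pad_or_trim
from whisper.timing import find_alignment
from whisper.tokenizer import get_tokenizer

model = whisper.load_model("base")
audio = whisper.load_audio("audio.wav")  # placeholder path; only the first 30 s are aligned
mel = log_mel_spectrogram(audio, model.dims.n_mels)
num_frames = min(mel.shape[-1], N_FRAMES)  # mel frames that contain real audio
mel_segment = pad_or_trim(mel, N_FRAMES).to(model.device)  # shape (n_mels, 3000)

tokenizer = get_tokenizer(
    model.is_multilingual, num_languages=model.num_languages, language="en", task="transcribe"
)
# Decode without timestamps, then keep only text tokens (IDs below eot)
options = whisper.DecodingOptions(language="en", without_timestamps=True, fp16=False)
result = whisper.decode(model, mel_segment, options)
text_tokens = [t for t in result.tokens if t < tokenizer.eot]

word_timings = find_alignment(model, tokenizer, text_tokens, mel_segment, num_frames)
for wt in word_timings:
    print(f"[{wt.start:.2f}-{wt.end:.2f}] {wt.word} (p={wt.probability:.2f})")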

Sample output:

[0.00-0.48]  Hello (p=0.95)
[0.48-0.92]  world (p=0.91)
[0.92-1.20]  how (p=0.88)
[1.20-1.56]  are (p=0.93)
[1.56-2.00]  you (p=0.90)

Key Dependencies

  • dtw(): Dynamic Time Warping for optimal alignment path computation
  • median_filter(): Median filtering for attention weight smoothing
  • tokenizer.split_to_word_tokens(): Word boundary detection from subword tokens
  • Forward hooks on MultiHeadAttention cross-attention layers
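The dtw() dependency finds a minimum-cost monotonic path through a cost matrix; find_alignment() passes it the negated attention matrix so that strong attention becomes low cost. Below is a textbook dynamic-programming DTW sketch in plain Python. It is not Whisper's dtw(), which has separate CPU and GPU implementations with their own tie-breaking; the function name dtw_path is hypothetical.

```python
import math
from typing import List, Tuple


def dtw_path(cost: List[List[float]]) -> Tuple[List[int], List[int]]:
    """Hypothetical helper: (row_indices, col_indices) of the minimum-cost monotonic path."""
    n, m = len(cost), len(cost[0])
    # acc[i][j] = cheapest cost of a path ending at cell (i - 1, j - 1);
    # row 0 and column 0 are infinite sentinels, except the start cell acc[0][0].
    acc = [[math.inf] * (m + 1) for _ in range(n + 1)]
    acc[0][0] = 0.0
    for i in range(1, n + 1):
        for j in range(1, m + 1):
            # Predecessors: diagonal, previous token at the same frame,
            # or the same token at the previous frame.
            acc[i][j] = cost[i - 1][j - 1] + min(acc[i - 1][j - 1], acc[i - 1][j], acc[i][j - 1])
    # Backtrack from the bottom-right cell to the top-left one, always stepping
    # to the predecessor with the smallest accumulated cost (first one on ties).
    rows, cols = [], []
    i, j = n, m
    while i > 0 and j > 0:
        rows.append(i - 1)
        cols.append(j - 1)
        i, j = min(((i - 1, j - 1), (i - 1, j), (i, j - 1)), key=lambda ij: acc[ij[0]][ij[1]])
    return rows[::-1], cols[::-1]


# Two tokens over four frames: token 0 attends to frames 0-1, token 1 to frames 2-3.
attention = [[1.0, 1.0, 0.0, 0.0], [0.0, 0.0, 1.0, 1.0]]
print(dtw_path([[-a for a in row] for row in attention]))  # ([0, 0, 1, 1], [0, 1, 2, 3])
```

The returned lists play the role of the (token, frame) path that the boundary-mapping step consumes: the frame at which the row index first changes marks the start of each token.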

Links

Principle:Openai_Whisper_Cross_Attention_Alignment

Environment:Openai_Whisper_PyTorch_CUDA
Heuristic:Openai_Whisper_SDPA_Disabling_For_Attention_Extraction

Metadata

2025-06-25 00:00 GMT
