Heuristic:OpenAI Whisper SDPA Disabling For Attention Extraction
| Knowledge Sources | |
|---|---|
| Domains | Timestamps, Optimization |
| Last Updated | 2025-06-25 00:00 GMT |
Overview
Temporarily disabling PyTorch's Scaled Dot-Product Attention (SDPA) so that the raw cross-attention scores needed for word-level timestamp alignment via dynamic time warping (DTW) can be extracted.
Description
PyTorch 2.0+ provides `torch.nn.functional.scaled_dot_product_attention` (SDPA), an optimized implementation that fuses the attention computation into a single kernel but returns only the attention output, not the intermediate score matrix (QK^T) or its softmax. Whisper's word-level timestamp pipeline needs the raw cross-attention scores to align text tokens with audio frames, so SDPA must be disabled for the alignment forward pass. The `disable_sdpa()` context manager sets the class attribute `MultiHeadAttention.use_sdpa = False` for the duration of the `with` block, forcing the fallback path that computes and returns the scores explicitly, and restores the previous value on exit.
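The toggle can be illustrated with a toy, pure-Python model. `ToyAttention` is hypothetical, not Whisper's class: it computes single-head attention on plain lists and, while its class-level `use_sdpa` flag is set, drops the scores to mimic a fused kernel (a real fused kernel never hands them back in the first place). Flipping the flag, as `disable_sdpa()` does, makes the same call return the scores while the output stays identical.

```python
import math


class ToyAttention:
    """Hypothetical single-head stand-in for Whisper's MultiHeadAttention.

    Works on plain Python lists; it is an illustration, not Whisper code.
    """

    # Class-level flag, like MultiHeadAttention.use_sdpa: one assignment
    # switches every instance at once.
    use_sdpa = True

    def qkv_attention(self, q, k, v):
        d = len(q[0])
        scale = d ** -0.25  # applied to q and k, so q.k ends up divided by sqrt(d)
        qk = [
            [sum((a * scale) * (b * scale) for a, b in zip(q_row, k_row)) for k_row in k]
            for q_row in q
        ]
        weights = []
        for row in qk:
            m = max(row)  # subtract the row max for a numerically stable softmax
            exps = [math.exp(x - m) for x in row]
            total = sum(exps)
            weights.append([e / total for e in exps])
        out = [
            [sum(w * v_row[c] for w, v_row in zip(w_row, v)) for c in range(len(v[0]))]
            for w_row in weights
        ]
        if ToyAttention.use_sdpa:
            # Stand-in for the fused path: same output, but the scores are dropped.
            return out, None
        return out, qk


attn = ToyAttention()
q = [[1.0, 0.0]]
k = [[1.0, 0.0], [0.0, 1.0]]
v = [[1.0], [3.0]]

fast_out, fast_qk = attn.qkv_attention(q, k, v)  # default "SDPA" path: fast_qk is None

ToyAttention.use_sdpa = False  # what disable_sdpa() does on entry
try:
    slow_out, slow_qk = attn.qkv_attention(q, k, v)  # explicit path: scores returned
finally:
    ToyAttention.use_sdpa = True  # restore the previous value (True here), as on exit
```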
Usage
Use this heuristic when cross-attention weights are needed for word-level timestamps. The `find_alignment()` function in `whisper/timing.py` wraps its model forward pass in `disable_sdpa()` so that the weights are available. Regular decoding, including the decoding that runs before alignment, keeps using SDPA (when available) for faster inference.
The Insight (Rule of Thumb)
- Action: Wrap the forward pass in `disable_sdpa()` context manager when attention weights are needed.
- Value: Enables extraction of per-head attention weight matrices from cross-attention layers.
- Trade-off: The alignment pass runs without SDPA fusion, so attention is slower and the full score matrix is materialized. This is acceptable because alignment adds only one extra forward pass per decoded 30-second window, not one per decoding step.
- Key detail: On the SDPA path the attention code sets `qk = None`, so any code that needs `qk` must disable SDPA first.
Reasoning
SDPA fuses the QK^T product, scaling, masking, softmax, and the multiplication by V into a single kernel for speed, so the intermediate attention scores are never returned (and, with memory-efficient backends, may never be fully materialized). The Q, K, V projections themselves happen outside SDPA. Word-level timestamps need these scores to build an alignment matrix between text tokens and audio frames. Because `use_sdpa` is a class-level flag, one assignment switches every attention layer in the model at once: normal decoding keeps SDPA's speed, and alignment temporarily falls back to the explicit path.
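As an illustration of how the extracted scores are consumed, the following textbook DTW sketch aligns a tiny tokens × frames weight matrix. The matrix values are invented and `dtw_path` is a hypothetical helper, not Whisper's own DTW routine in `whisper/timing.py`, which may differ in details; in Whisper the matrix is derived from the `qk` of selected cross-attention heads after further post-processing. The weights are negated so that strong attention means low cost.

```python
def dtw_path(cost):
    """Textbook DTW over a tokens x frames cost matrix (list of lists).

    Returns a monotonic list of (token_index, frame_index) pairs from
    (0, 0) to (n - 1, m - 1) that minimizes the summed cost.
    """
    n, m = len(cost), len(cost[0])
    inf = float("inf")
    # acc[i][j] = cheapest cost of aligning the first i tokens with the first j frames.
    acc = [[inf] * (m + 1) for _ in range(n + 1)]
    acc[0][0] = 0.0
    for i in range(1, n + 1):
        for j in range(1, m + 1):
            acc[i][j] = cost[i - 1][j - 1] + min(
                acc[i - 1][j - 1],  # advance token and frame
                acc[i - 1][j],      # advance token only
                acc[i][j - 1],      # advance frame only
            )
    # Backtrack from the bottom-right corner, always stepping to the cheapest predecessor.
    path = []
    i, j = n, m
    while i > 0 and j > 0:
        path.append((i - 1, j - 1))
        diag, up, left = acc[i - 1][j - 1], acc[i - 1][j], acc[i][j - 1]
        best = min(diag, up, left)
        if best == diag:
            i, j = i - 1, j - 1
        elif best == up:
            i -= 1
        else:
            j -= 1
    path.reverse()
    return path


# Invented attention weights: token 0 attends early frames, token 1 late frames.
weights = [
    [0.9, 0.6, 0.2, 0.1],
    [0.1, 0.3, 0.8, 0.7],
]
cost = [[-w for w in row] for row in weights]
path = dtw_path(cost)
print(path)  # [(0, 0), (0, 1), (1, 1), (1, 2), (1, 3)]
```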
Code evidence from `whisper/model.py:71-78`:
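```python
@contextmanager
def disable_sdpa():
    # Remember the current setting so nested or pre-disabled callers are respected.
    prev_state = MultiHeadAttention.use_sdpa
    try:
        # Force every attention layer onto the explicit path that returns qk.
        MultiHeadAttention.use_sdpa = False
        yield
    finally:
        # Restore the previous setting even if the body raised.
        MultiHeadAttention.use_sdpa = prev_state
```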
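Saving and restoring `prev_state`, rather than simply setting the flag back to `True`, matters for nested use and for exceptions. The sketch below copies the pattern onto a hypothetical `FakeAttention` class (an assumption standing in for `MultiHeadAttention`) and checks both cases, plus a caller that had already turned SDPA off.

```python
from contextlib import contextmanager


class FakeAttention:
    """Hypothetical holder for a class-level flag, standing in for MultiHeadAttention."""

    use_sdpa = True


@contextmanager
def disable_sdpa():
    # Same save / override / restore pattern as the excerpt above.
    prev_state = FakeAttention.use_sdpa
    try:
        FakeAttention.use_sdpa = False
        yield
    finally:
        FakeAttention.use_sdpa = prev_state


# Nesting: leaving the inner block restores the outer block's value (False), not True.
with disable_sdpa():
    with disable_sdpa():
        pass
    after_inner = FakeAttention.use_sdpa
after_outer = FakeAttention.use_sdpa

# Exceptions: the finally clause runs even when the body raises.
try:
    with disable_sdpa():
        raise RuntimeError("simulated failure in the forward pass")
except RuntimeError:
    pass
after_error = FakeAttention.use_sdpa

# A caller that had already switched SDPA off keeps it off afterwards.
FakeAttention.use_sdpa = False
with disable_sdpa():
    pass
kept_off = FakeAttention.use_sdpa
FakeAttention.use_sdpa = True  # reset for any later code
```

One consequence of the class-level design: the flag is process-wide rather than thread-local, so overlapping `disable_sdpa()` blocks in different threads can interfere (one thread's exit may restore `True` while another thread is still running its alignment pass).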
SDPA conditional path from `whisper/model.py:123-137`:
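```python
if SDPA_AVAILABLE and MultiHeadAttention.use_sdpa:
    # Fused kernel: returns only the attention output.
    a = scaled_dot_product_attention(
        q, k, v, is_causal=mask is not None and n_ctx > 1
    )
    out = a.permute(0, 2, 1, 3).flatten(start_dim=2)
    qk = None  # scores are not available on this path
else:
    # Explicit path: materialize the pre-softmax scores (scale applied to both q and k).
    qk = (q * scale) @ (k * scale).transpose(-1, -2)
    if mask is not None:
        qk = qk + mask[:n_ctx, :n_ctx]
    # Compute the softmax in float32, then cast back to the input dtype.
    qk = qk.float()
    w = F.softmax(qk, dim=-1).to(q.dtype)
    out = (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2)
    # Detach so callers can keep the scores without holding the autograd graph.
    qk = qk.detach()
```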
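Both branches compute the same attention. With per-head dimension $d = n_\text{state} / n_\text{head}$, Whisper sets `scale = (n_state // self.n_head) ** -0.25` earlier in the same method, so applying it to both `q` and `k` reproduces the usual $1/\sqrt{d}$ factor, which is also SDPA's default `scale` ($1/\sqrt{E}$ with $E$ the query's last dimension). Whisper's decoder mask $M$ is $-\infty$ above the diagonal and $0$ elsewhere, so `is_causal=True` on the SDPA branch plays the same role as adding `mask`:

$$
\left(q\,d^{-1/4}\right)\left(k\,d^{-1/4}\right)^{\top} = d^{-1/2}\,q k^{\top} = \frac{q k^{\top}}{\sqrt{d}},
\qquad
\mathrm{out} = \operatorname{softmax}\!\left(\frac{q k^{\top}}{\sqrt{d}} + M\right) v,
\qquad
M_{ij} = \begin{cases} 0 & j \le i \\ -\infty & j > i \end{cases}
$$

Only the explicit branch returns the pre-softmax matrix $q k^{\top}/\sqrt{d} + M$ (as `qk`), which is what the alignment code consumes.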
Usage in alignment from `whisper/timing.py:196-197`:
```python
with torch.no_grad(), disable_sdpa():
    logits = model(mel.unsqueeze(0), tokens.unsqueeze(0))[0]
```