Heuristic:Openai Whisper SDPA Disabling For Attention Extraction

From Leeroopedia
Domains: Timestamps, Optimization
Last Updated: 2025-06-25 00:00 GMT

Overview

Temporarily disable PyTorch's fused Scaled Dot-Product Attention (SDPA) so that Whisper exposes the raw cross-attention weights needed to align words with audio frames via dynamic time warping (DTW).

Description

PyTorch 2.0+ provides a fused `torch.nn.functional.scaled_dot_product_attention` (SDPA) that computes softmax(QK^T / sqrt(d)) V in one call and returns only the attended output; the attention matrix itself is never returned. Whisper's word-level timestamp pipeline needs exactly that matrix (the cross-attention between text tokens and audio frames), so SDPA must be switched off for the alignment forward pass. The `disable_sdpa()` context manager sets the class attribute `MultiHeadAttention.use_sdpa = False` for the duration of the block, forcing the fallback path that computes the attention logits explicitly and returns them as `qk`, and restores the previous value on exit.
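To make the difference concrete, here is a minimal sketch, assuming PyTorch 2.0+ and small random toy tensors rather than Whisper's real shapes. The fused call yields one output tensor, while the explicit path produces the same output and leaves the weight matrix available as an ordinary tensor.

```python
import math

import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Toy shapes: (batch, heads, query_len, head_dim) and (batch, heads, key_len, head_dim)
B, H, Lq, Lk, D = 1, 2, 3, 5, 4
q = torch.randn(B, H, Lq, D)
k = torch.randn(B, H, Lk, D)
v = torch.randn(B, H, Lk, D)

# Fused path: a single output tensor, no attention weights are returned
fused_out = F.scaled_dot_product_attention(q, k, v)

# Explicit path: the pre-softmax logits and the weights are materialized
logits = (q @ k.transpose(-1, -2)) / math.sqrt(D)
weights = logits.softmax(dim=-1)  # shape (B, H, Lq, Lk), each row sums to 1
explicit_out = weights @ v

print(torch.allclose(fused_out, explicit_out, atol=1e-5))  # True
print(weights.shape)  # torch.Size([1, 2, 3, 5])
```

Both paths agree on the output up to floating-point error; only the explicit one leaves `weights` behind for later use.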

Usage

Use this heuristic when extracting cross-attention weights for word-level timestamp computation. The `find_alignment()` function in `whisper/timing.py` wraps the model forward pass in `disable_sdpa()` to ensure attention weights are available. Normal decoding (without word timestamps) uses SDPA for faster inference.
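The pattern can be sketched in miniature. In the example below, `ToyCrossAttention` and `toy_disable_sdpa` are hypothetical stand-ins, not Whisper code. The module returns `(out, qk)` like Whisper's attention, a forward hook records `qk`, and the class-level switch decides whether the hook sees `None` or a real tensor.

```python
from contextlib import contextmanager

import torch
import torch.nn.functional as F
from torch import nn


class ToyCrossAttention(nn.Module):
    """Hypothetical stand-in for a Whisper-style attention layer."""

    use_sdpa = True  # class-level switch, shared by every instance

    def forward(self, q, k, v):
        if ToyCrossAttention.use_sdpa:
            out = F.scaled_dot_product_attention(q, k, v)
            qk = None  # the fused kernel exposes no attention logits
        else:
            qk = (q @ k.transpose(-1, -2)) * q.shape[-1] ** -0.5
            out = qk.softmax(dim=-1) @ v
            qk = qk.detach()
        return out, qk


@contextmanager
def toy_disable_sdpa():
    prev = ToyCrossAttention.use_sdpa
    try:
        ToyCrossAttention.use_sdpa = False
        yield
    finally:
        ToyCrossAttention.use_sdpa = prev


layer = ToyCrossAttention()
captured = []
# The hook receives the module's output tuple; keep its last element (qk)
handle = layer.register_forward_hook(lambda mod, ins, outs: captured.append(outs[-1]))

# Queries from 3 "text tokens", keys/values from 6 "audio frames"
q, k, v = torch.randn(1, 2, 3, 4), torch.randn(1, 2, 6, 4), torch.randn(1, 2, 6, 4)
with torch.no_grad():
    layer(q, k, v)  # fused path: the hook captures None
    with toy_disable_sdpa():
        layer(q, k, v)  # explicit path: the hook captures a (1, 2, 3, 6) tensor
handle.remove()

print(captured[0] is None, captured[1].shape)  # True torch.Size([1, 2, 3, 6])
```

Whisper's `find_alignment()` follows the same shape: it registers forward hooks on each decoder block's `cross_attn` and runs the forward pass inside `disable_sdpa()`.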

The Insight (Rule of Thumb)

  • Action: Wrap the forward pass in the `disable_sdpa()` context manager whenever attention weights are needed.
  • Value: Enables extraction of per-head attention weight matrices from the cross-attention layers.
  • Trade-off: Attention runs on the slower explicit path (no SDPA fusion) during the alignment pass. This is acceptable because alignment is a single forward pass over tokens that have already been decoded, not a cost paid at every autoregressive decoding step.
  • Key detail: When SDPA is active, the fused branch of `qkv_attention()` sets `qk = None`, so any code that consumes `qk` (for example, a forward hook on a cross-attention layer) must disable SDPA first.
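Following the "Key detail" above, a consumer of captured weights can fail loudly instead of silently receiving `None`. This is a minimal sketch: `select_alignment_heads` is a hypothetical helper, plain lists stand in for per-head tensors, and it only loosely mirrors how Whisper picks heads listed in `model.alignment_heads`.

```python
from typing import Optional, Sequence, Tuple


def select_alignment_heads(
    qks: Sequence[Optional[list]],
    heads: Sequence[Tuple[int, int]],
) -> list:
    """Hypothetical helper: pick the (layer, head) attention maps used for alignment.

    `qks[layer]` is whatever a forward hook captured for that layer; it is None
    when the layer ran on the fused SDPA path.
    """
    missing = [layer for layer, _ in heads if qks[layer] is None]
    if missing:
        raise RuntimeError(
            f"no attention weights captured for layers {sorted(set(missing))}; "
            "run the forward pass inside disable_sdpa()"
        )
    return [qks[layer][head] for layer, head in heads]


captured = [None, None]  # the forward pass ran with SDPA enabled
try:
    select_alignment_heads(captured, [(1, 0)])
except RuntimeError as err:
    print(err)

captured = [["L0H0", "L0H1"], ["L1H0", "L1H1"]]  # placeholders for per-head maps
print(select_alignment_heads(captured, [(1, 0), (0, 1)]))  # ['L1H0', 'L0H1']
```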

Reasoning

SDPA fuses the QK^T product, scaling, masking, softmax, and multiplication by V into a single kernel call for performance (the Q, K, and V projections remain separate linear layers), and it does not return the intermediate attention weights. Word-level timestamps need exactly those weights to build an alignment matrix between text tokens and audio frames. The class-level `use_sdpa` flag provides a clean toggle: because it lives on the class, one assignment switches every attention layer in the model at once, so normal decoding keeps SDPA's speed while the alignment pass temporarily falls back to the explicit path.
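Why a class attribute rather than a per-layer setting can be shown with plain Python. `Attention` below is a hypothetical stand-in for `MultiHeadAttention`; the point is that flipping the class attribute reaches every existing instance, whereas assigning on one instance only shadows it there.

```python
class Attention:
    """Hypothetical stand-in for MultiHeadAttention's class-level switch."""

    use_sdpa = True


encoder_layer, decoder_layer = Attention(), Attention()

# Flipping the class attribute is seen by every instance at once...
Attention.use_sdpa = False
print(encoder_layer.use_sdpa, decoder_layer.use_sdpa)  # False False
Attention.use_sdpa = True

# ...whereas assigning on one instance only shadows the class attribute there
decoder_layer.use_sdpa = False
print(encoder_layer.use_sdpa, decoder_layer.use_sdpa)  # True False
```

Whisper's check reads `MultiHeadAttention.use_sdpa` explicitly rather than `self.use_sdpa`, so an instance-level assignment like the last one would have no effect on which path runs.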

Code evidence from `whisper/model.py:71-78`:

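from contextlib import contextmanager

# MultiHeadAttention is defined in the same module (whisper/model.py)
@contextmanager
def disable_sdpa():
    # Remember the current setting so nested or pre-disabled uses restore correctly
    prev_state = MultiHeadAttention.use_sdpa
    try:
        MultiHeadAttention.use_sdpa = False
        yield
    finally:
        # Runs even if the body raises
        MultiHeadAttention.use_sdpa = prev_state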
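The save/try/finally shape matters in two situations: an exception inside the block, and nested use. The sketch below replays the same structure on a hypothetical `Attention` class so it runs without Whisper installed.

```python
from contextlib import contextmanager


class Attention:
    use_sdpa = True  # hypothetical stand-in for MultiHeadAttention


@contextmanager
def disable_flag():
    # Same save/try/finally structure as disable_sdpa()
    prev_state = Attention.use_sdpa
    try:
        Attention.use_sdpa = False
        yield
    finally:
        Attention.use_sdpa = prev_state


# 1) The flag is restored even if the body raises
try:
    with disable_flag():
        raise ValueError("forward pass failed")
except ValueError:
    pass
print(Attention.use_sdpa)  # True

# 2) Nested use restores the state seen at each entry, not a hard-coded True
with disable_flag():
    with disable_flag():
        pass
    print(Attention.use_sdpa)  # False (still inside the outer block)
print(Attention.use_sdpa)  # True
```

Because the flag lives on the class, it is process-wide: any other thread running the same model while the block is active would presumably also take the explicit path. The outputs stay equivalent up to floating-point differences, but the speed benefit of SDPA is lost for that thread.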

SDPA conditional path in `MultiHeadAttention.qkv_attention()` (`whisper/model.py:123-137`):

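# Inside MultiHeadAttention.qkv_attention(): q, k, v are shaped
# (n_batch, n_head, n_ctx, head_dim) and scale = (n_state // n_head) ** -0.25
if SDPA_AVAILABLE and MultiHeadAttention.use_sdpa:
    # Fused kernel: returns only the attended values, never the weights
    a = scaled_dot_product_attention(
        q, k, v, is_causal=mask is not None and n_ctx > 1
    )
    out = a.permute(0, 2, 1, 3).flatten(start_dim=2)
    qk = None
else:
    # Explicit path: head_dim ** -0.25 on q and on k gives head_dim ** -0.5 overall
    qk = (q * scale) @ (k * scale).transpose(-1, -2)
    if mask is not None:
        qk = qk + mask[:n_ctx, :n_ctx]  # additive causal mask (decoder self-attention only)
    qk = qk.float()
    w = F.softmax(qk, dim=-1).to(q.dtype)
    out = (w @ v).permute(0, 2, 1, 3).flatten(start_dim=2)
    qk = qk.detach()  # pre-softmax logits, returned alongside out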
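The two branches are meant to compute the same attention. The sketch below, with toy sizes chosen for illustration, checks the Whisper-style split scaling and an additive upper-triangular `-inf` mask against `scaled_dot_product_attention(..., is_causal=True)`. For cross-attention Whisper passes no mask, so `is_causal` is False there and both paths reduce to plain unmasked attention.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
n_head, n_ctx, head_dim = 2, 4, 8
n_state = n_head * head_dim
q = torch.randn(1, n_head, n_ctx, head_dim)
k = torch.randn(1, n_head, n_ctx, head_dim)
v = torch.randn(1, n_head, n_ctx, head_dim)

# Whisper-style scaling: head_dim ** -0.25 on q and on k, i.e. head_dim ** -0.5 overall
scale = (n_state // n_head) ** -0.25
# Additive causal mask: -inf strictly above the diagonal, 0 elsewhere
mask = torch.empty(n_ctx, n_ctx).fill_(float("-inf")).triu_(1)

qk = (q * scale) @ (k * scale).transpose(-1, -2)
qk = (qk + mask[:n_ctx, :n_ctx]).float()
explicit_out = F.softmax(qk, dim=-1) @ v

fused_out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(torch.allclose(explicit_out, fused_out, atol=1e-5))  # True
```

Splitting the scale across `q` and `k` keeps the magnitudes of the intermediate products smaller than scaling only one side, which is helpful in half precision.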

Usage in alignment from `whisper/timing.py:196-197`:

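import torch
from whisper.model import disable_sdpa

# Forward hooks on each decoder block's cross_attn record qk during this pass
with torch.no_grad(), disable_sdpa():
    logits = model(mel.unsqueeze(0), tokens.unsqueeze(0))[0]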
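Once the weights are captured, Whisper roughly normalizes and median-filters the selected heads, averages them, and runs DTW on the negated matrix, so high attention becomes low cost. The function below is a simplified O(N*M) DTW written for illustration; it is not Whisper's `dtw_cpu`/`dtw_cuda`, and the toy attention matrix is invented.

```python
import numpy as np


def dtw_path(cost: np.ndarray) -> list:
    """Simplified DTW: monotonic path from (0, 0) to (N-1, M-1) minimizing summed cost."""
    n, m = cost.shape
    acc = np.full((n + 1, m + 1), np.inf)
    acc[0, 0] = 0.0
    for i in range(1, n + 1):
        for j in range(1, m + 1):
            acc[i, j] = cost[i - 1, j - 1] + min(
                acc[i - 1, j - 1],  # advance token and frame
                acc[i - 1, j],  # advance token only
                acc[i, j - 1],  # advance frame only
            )
    # Backtrack from the bottom-right corner along the cheapest predecessors
    i, j, path = n, m, []
    while i > 0 and j > 0:
        path.append((i - 1, j - 1))
        step = np.argmin([acc[i - 1, j - 1], acc[i - 1, j], acc[i, j - 1]])
        if step == 0:
            i, j = i - 1, j - 1
        elif step == 1:
            i -= 1
        else:
            j -= 1
    return path[::-1]


# Toy "attention" over 3 text tokens and 6 audio frames:
# token t attends mostly to frames 2t and 2t + 1
attn = np.full((3, 6), 0.02)
for t in range(3):
    attn[t, 2 * t : 2 * t + 2] = 0.45

path = dtw_path(-attn)  # negate so that high attention means low cost
print(path)
```

Each `(token, frame)` pair on the path assigns a text token to an audio frame; the first and last frame per token then give word boundaries once frames are converted to time.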
