Implementation:Openai Whisper Median Filter

From Leeroopedia

Overview

median_filter() applies a median filter along the last dimension of a tensor. It accepts both CPU and CUDA tensors and uses a Triton-based kernel for GPU acceleration when the input is on a CUDA device. Whisper's word-level timestamp pipeline uses it to smooth cross-attention weights before DTW alignment.

Source

  • File: whisper/timing.py, lines 19-54
  • CUDA kernel: whisper/triton_ops.py, lines 106-117 (median_filter_cuda)
  • Repository: https://github.com/openai/whisper
  • Import: from whisper.timing import median_filter

Signature

def median_filter(x: torch.Tensor, filter_width: int) -> torch.Tensor:

Parameters

Parameter | Type | Description
x | torch.Tensor | Input tensor to filter (filtering applied on last dimension)
filter_width | int | Kernel width for the median filter (must be odd; default 7 in find_alignment)

Return Value

Returns a torch.Tensor of the same shape as the input, with median filtering applied along the last dimension.

Behavior

The function performs the following steps:

  1. Validation: Asserts that filter_width is odd (required for symmetric padding).
  2. Padding: Pads the input with reflect padding of size filter_width // 2 along the last dimension to handle boundary effects.
  3. CUDA Path: If the tensor is on a CUDA device, attempts to use the Triton-based median_filter_cuda kernel for GPU-accelerated computation.
  4. Fallback: For CPU tensors, or when the Triton kernel cannot be used, the function uses an unfold + sort approach in plain PyTorch:
    • torch.Tensor.unfold() creates sliding windows of size filter_width along the last dimension.
    • torch.sort() sorts each window (this is faster than torch.median() for this use case).
    • The median value is extracted as the middle element of the sorted window.
  5. Dimension Handling: For 1D and 2D inputs, temporarily adds batch dimensions to ensure consistent processing, then removes them before returning.
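
The pad, unfold, and sort steps above can be sketched in plain PyTorch as follows. This is a minimal illustration, not the Whisper source: the name median_filter_sketch is hypothetical, and flattening all leading dimensions into one batch axis is a simplification that stands in for the batch-dimension handling described in step 5.

```python
import torch
import torch.nn.functional as F


def median_filter_sketch(x: torch.Tensor, filter_width: int) -> torch.Tensor:
    """Median-filter `x` along its last dimension (illustrative sketch only)."""
    assert filter_width > 0 and filter_width % 2 == 1, "filter_width must be odd"
    pad = filter_width // 2

    # Reflect padding needs the pad width to be smaller than the padded dimension;
    # this sketch simply rejects inputs that are too short.
    if x.shape[-1] <= pad:
        raise ValueError("last dimension must be longer than filter_width // 2")

    # Assumption of this sketch: flatten leading dimensions into one batch axis so
    # that 1-D reflect padding always sees a (batch, channel, width) input.
    original_shape = x.shape
    flat = x.reshape(-1, 1, original_shape[-1])

    padded = F.pad(flat, (pad, pad), mode="reflect")

    # unfold yields one window of size filter_width per output position.
    windows = padded.unfold(-1, filter_width, 1)

    # Sort each window and take the middle element, which is the median for odd widths.
    medians = windows.sort(dim=-1)[0][..., pad]

    return medians.reshape(original_shape)


spiky = torch.tensor([1.0, 1.0, 9.0, 1.0, 1.0])
smoothed_sketch = median_filter_sketch(spiky, 3)
print(smoothed_sketch)  # tensor([1., 1., 1., 1., 1.])
```

The isolated spike is removed because every width-3 window contains at least two ones, so its median is 1.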

Why Sort Instead of torch.median

The implementation calls torch.sort() and indexes the middle element of each window rather than calling torch.median() directly. A comment in the source notes that sort() is faster than torch.median() here. Because filter_width is odd, the middle element of a sorted window is exactly its median, so the two approaches return the same values for the typical window sizes used (e.g., 7).
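
As a small check of that equivalence, the sketch below compares the two approaches on random windows. The tensor shape and variable names are arbitrary choices for illustration; the sketch verifies only that the values match, not the relative speed.

```python
import torch

torch.manual_seed(0)
filter_width = 7
# Hypothetical batch of windows, each of odd width filter_width.
windows = torch.randn(4, 10, filter_width)

# Approach described above: sort each window, then index the middle element.
via_sort = windows.sort(dim=-1)[0][..., filter_width // 2]

# Direct call; with dim given, torch.median returns (values, indices).
via_median = windows.median(dim=-1).values

same_values = torch.equal(via_sort, via_median)
print(same_values)  # True
```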

Example Usage

from whisper.timing import median_filter
import torch

# Filter attention weights
weights = torch.randn(8, 50, 100)  # 8 heads, 50 tokens, 100 frames
smoothed = median_filter(weights, filter_width=7)
print(smoothed.shape)  # torch.Size([8, 50, 100])

Implementation Details

Reflect Padding

Reflect padding mirrors the signal at the boundaries. For a signal [a, b, c, d, e] with pad width 3:

[d, c, b, a, b, c, d, e, d, c, b]

This avoids discontinuities at the edges that could distort the filtered output near segment boundaries.
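
The same pattern can be reproduced with torch.nn.functional.pad, using the numbers 1 to 5 in place of a to e. The reshape to (1, 1, 5) is there because reflect padding of the last dimension expects a batched input; the variable names are illustrative.

```python
import torch
import torch.nn.functional as F

signal = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0])  # stands in for [a, b, c, d, e]

# Pad 3 elements on each side; the edge element itself is not repeated.
reflected = F.pad(signal.view(1, 1, -1), (3, 3), mode="reflect").flatten()

print(reflected.tolist())
# [4.0, 3.0, 2.0, 1.0, 2.0, 3.0, 4.0, 5.0, 4.0, 3.0, 2.0]
```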

GPU Acceleration

The Triton-based median_filter_cuda kernel (in whisper/triton_ops.py) provides GPU-accelerated median filtering. If Triton is not available or the kernel fails to launch, the function falls back to the unfold + sort approach in plain PyTorch, which runs on whichever device holds the tensor.
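
The fallback behavior can be sketched as a try/except around an optional fast path. Everything named here (filter_with_fallback, fast_kernel, broken_kernel) is hypothetical and exists only to illustrate the pattern; the exception type caught and the warning text are assumptions, not the Whisper source. The input is assumed to be already padded.

```python
import warnings

import torch


def filter_with_fallback(padded, filter_width, fast_kernel=None):
    """Try an accelerated kernel first; fall back to unfold + sort if it fails."""
    result = None
    if fast_kernel is not None:
        try:
            result = fast_kernel(padded, filter_width)
        except RuntimeError:
            warnings.warn("fast kernel failed; falling back to unfold + sort")

    if result is None:
        # Generic path: works on any device that holds `padded`.
        result = padded.unfold(-1, filter_width, 1).sort()[0][..., filter_width // 2]
    return result


def broken_kernel(padded, filter_width):
    # Simulates a kernel that cannot be launched.
    raise RuntimeError("simulated kernel launch failure")


already_padded = torch.tensor([[3.0, 1.0, 2.0, 5.0, 4.0]])
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    fallback_result = filter_with_fallback(already_padded, 3, fast_kernel=broken_kernel)

print(fallback_result)  # tensor([[2., 2., 4.]])
```

Keeping the generic path as the final branch means a failed fast path costs only a warning, not a failed call.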

Links

  • Principle:Openai_Whisper_Median_Filtering
  • Environment:Openai_Whisper_Triton

Metadata

2025-06-25 00:00 GMT
