Implementation:Openai Whisper Median Filter
Overview
median_filter() is a function that applies a median filter to a tensor along its last dimension. It supports both CPU and GPU execution, with a Triton-based CUDA kernel for GPU acceleration. This function is used in Whisper's word-level timestamp pipeline to smooth cross-attention weights before DTW alignment.
Source
- File: whisper/timing.py, lines 19-54
- CUDA kernel: whisper/triton_ops.py, lines 106-117 (median_filter_cuda)
- Repository: https://github.com/openai/whisper
- Import: from whisper.timing import median_filter
Signature
def median_filter(x: torch.Tensor, filter_width: int) -> torch.Tensor:
Parameters
| Parameter | Type | Description |
|---|---|---|
| x | torch.Tensor | Input tensor to filter (filtering applied on last dimension) |
| filter_width | int | Kernel width for the median filter (must be odd; default 7 in find_alignment) |
Return Value
Returns a torch.Tensor of the same shape as the input, with median filtering applied along the last dimension.
Behavior
The function performs the following steps:
- Validation: Asserts that filter_width is odd, so every window is centered on its output position and has a single middle element.
- Padding: Pads the input with reflect padding of size filter_width // 2 on both ends of the last dimension. This handles boundary effects and keeps the output the same length as the input.
- CUDA Path: If the tensor is on a CUDA device, attempts to use the Triton-based median_filter_cuda kernel for GPU-accelerated computation.
- Fallback Path: If the tensor is on the CPU, or the Triton kernel cannot be used, applies an unfold + sort approach on whatever device holds the tensor:
  - torch.Tensor.unfold() creates sliding windows of size filter_width along the last dimension.
  - torch.sort() sorts each window (this is faster than torch.median() for this use case).
  - The median value is extracted as the middle element of the sorted window.
- Dimension Handling: For 1D and 2D inputs, temporarily adds leading batch dimensions to ensure consistent processing, then removes them before returning.
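The fallback steps above can be sketched as a standalone function. This is a minimal sketch, not Whisper's code: the name median_filter_reference is hypothetical, and instead of adding batch dimensions it assumes it is acceptable to flatten all leading dimensions into one row axis before reflect-padding the last dimension. Like any reflect padding, it needs the last dimension to be longer than filter_width // 2.

```python
import torch
import torch.nn.functional as F


def median_filter_reference(x: torch.Tensor, filter_width: int) -> torch.Tensor:
    """Hypothetical re-implementation of the unfold + sort path (not Whisper's code)."""
    assert filter_width > 0 and filter_width % 2 == 1, "filter_width must be odd"
    pad = filter_width // 2
    shape = x.shape
    # Assumption: flatten leading dims so padding always sees a 3D (1, rows, length) tensor.
    x3d = x.reshape(1, -1, shape[-1])
    # Reflect-pad both ends of the last dimension so every position gets a full window.
    padded = F.pad(x3d, (pad, pad), mode="reflect")
    # unfold -> (1, rows, length, filter_width): one sliding window per output position.
    windows = padded.unfold(-1, filter_width, 1)
    # Sort each window and take its middle element, which is the median for odd widths.
    medians = windows.sort(dim=-1).values[..., pad]
    # Restore the caller's original shape.
    return medians.reshape(shape)


signal = torch.tensor([1.0, 2.0, 100.0, 4.0, 5.0])
print(median_filter_reference(signal, filter_width=3))  # tensor([2., 2., 4., 5., 4.])
```

The isolated spike at index 2 is replaced by a neighboring value, which is the smoothing effect the alignment pipeline relies on.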
Why Sort Instead of torch.median
The implementation uses torch.sort() followed by indexing the middle element rather than calling torch.median() directly. A comment in the source notes that sort() is faster than torch.median() for this purpose. The windows are small (width 7 by default in find_alignment), so sorting each one is cheap, and because filter_width is odd, the element at index filter_width // 2 of a sorted window is exactly its median.
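Because both approaches select an element of each odd-length window, replacing torch.median() with sort + middle index should affect only performance, not results. A quick equivalence check, assuming random windows shaped like the output of Tensor.unfold(-1, 7, 1) (the shapes are illustrative, and no timing is measured here):

```python
import torch

torch.manual_seed(0)
filter_width = 7
# Illustrative batch of sliding windows: (heads, tokens, positions, filter_width).
windows = torch.randn(4, 50, 94, filter_width)

# Sort-based median: sort each window, then take the middle element.
via_sort = windows.sort(dim=-1).values[..., filter_width // 2]
# Direct reduction along the window dimension, for comparison.
via_median = windows.median(dim=-1).values

# For odd widths both select the same element, so the tensors match exactly.
print(torch.equal(via_sort, via_median))  # True
```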
Example Usage
from whisper.timing import median_filter
import torch
# Filter attention weights
weights = torch.randn(8, 50, 100) # 8 heads, 50 tokens, 100 frames
smoothed = median_filter(weights, filter_width=7)
print(smoothed.shape) # torch.Size([8, 50, 100])
Implementation Details
Reflect Padding
Reflect padding mirrors the signal at the boundaries without repeating the edge sample itself. For a signal [a, b, c, d, e] with pad width 3 (the padding used when filter_width is 7):
[d, c, b, a, b, c, d, e, d, c, b]
This avoids discontinuities at the edges that could distort the filtered output near segment boundaries.
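The mirroring rule can be written out in plain Python. This is a minimal sketch with a hypothetical helper name, reflect_pad; it follows the same edge-excluding convention as torch.nn.functional.pad(..., mode="reflect"), which likewise requires the pad width to be smaller than the signal length.

```python
def reflect_pad(signal: list, pad: int) -> list:
    """Hypothetical helper: mirror `pad` samples at each end, excluding the edge sample."""
    if not 0 <= pad < len(signal):
        raise ValueError("reflect padding needs 0 <= pad < len(signal)")
    left = signal[pad:0:-1]          # samples pad..1, in reverse order
    right = signal[-2:-pad - 2:-1]   # samples -2..-(pad + 1), in reverse order
    return left + signal + right


print("".join(reflect_pad(list("abcde"), 3)))  # dcbabcdedcb
```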
GPU Acceleration
The Triton-based median_filter_cuda kernel (in whisper/triton_ops.py) provides GPU-accelerated median filtering. If Triton is not available or the kernel fails to launch, the function falls back to the unfold + sort approach, which then runs on the same CUDA tensor.
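The try-then-fall-back dispatch can be illustrated without a GPU. In this minimal pure-Python sketch, median_with_fallback, sort_median, and unavailable_kernel are hypothetical names, the failing kernel simulates a missing Triton install, and catching RuntimeError is an assumption rather than a statement of exactly which exceptions Whisper handles.

```python
import warnings
from typing import Callable, List, Sequence

Window = Sequence[float]


def sort_median(windows: Sequence[Window]) -> List[float]:
    """Portable path: sort each odd-length window and take its middle element."""
    return [sorted(w)[len(w) // 2] for w in windows]


def unavailable_kernel(windows: Sequence[Window]) -> List[float]:
    """Hypothetical stand-in for an accelerated kernel that cannot be launched."""
    raise RuntimeError("Triton is not installed")


def median_with_fallback(
    windows: Sequence[Window],
    fast_kernel: Callable[[Sequence[Window]], List[float]],
) -> List[float]:
    """Try the accelerated kernel first; on failure, warn and use the portable path."""
    try:
        return fast_kernel(windows)
    except RuntimeError as exc:  # assumption: the exception type caught is illustrative
        warnings.warn(f"accelerated kernel failed ({exc}); using sort-based fallback")
        return sort_median(windows)


print(median_with_fallback([[3, 1, 2], [9, 7, 8]], unavailable_kernel))  # [2, 8]
```

Because the portable path always produces a result, the accelerated kernel acts purely as an optimization: callers get the same values either way, only faster when the kernel is usable.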
Links
- Principle:Openai_Whisper_Median_Filtering
- Environment:Openai_Whisper_Triton