# Heuristic: Kornia Gradient Detach Stability
| Knowledge Sources | |
|---|---|
| Domains | Optimization, Deep_Learning |
| Last Updated | 2026-02-09 15:00 GMT |
## Overview
Gradient stability technique using `.detach()` on normalization statistics to prevent noisy gradients from `F.grid_sample` during descriptor training.
## Description
When training local feature detectors and descriptors, patch extraction via `F.grid_sample` produces gradients that flow back through the sampling grid. If these gradients also flow through the batch normalization statistics (mean and standard deviation), the combined gradient signal becomes extremely noisy, destabilizing detector training entirely. The solution is to `.detach()` the mean and std tensors before using them for normalization, so gradients only flow through the normalized output, not through the statistics computation.
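The double gradient path can be verified directly: `F.grid_sample` backpropagates to both its input image and its sampling grid. A minimal check (illustrative shapes, not Kornia code):

```python
import torch
import torch.nn.functional as F

# A tiny image and a 4x4 sampling grid with coordinates in [-1, 1],
# both marked as requiring gradients.
image = torch.rand(1, 1, 16, 16, requires_grad=True)
grid = (torch.rand(1, 4, 4, 2) * 2 - 1).requires_grad_(True)

patches = F.grid_sample(image, grid, align_corners=False)
patches.sum().backward()

# Both tensors receive gradients: the image, and the grid -- which in a
# detector/descriptor pipeline is produced from learned keypoint locations.
print(image.grad is not None, grid.grad is not None)
```

Because the grid is a function of the detector output, any extra gradient terms injected downstream (such as those from normalization statistics) propagate all the way back into the detector.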
## Usage
Apply this heuristic when implementing patch-based descriptor networks that use `F.grid_sample` for patch extraction upstream. Specifically, if your pipeline has: (1) a detector that produces keypoint locations, (2) a grid_sample-based patch extractor, and (3) a descriptor CNN with input normalization, you must detach the normalization statistics.
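The three-stage pipeline above can be sketched as follows. This is a minimal illustration with hypothetical helper names (`extract_patches`, `normalize_patches`), not Kornia's actual API:

```python
import torch
import torch.nn.functional as F

def extract_patches(image: torch.Tensor, grid: torch.Tensor) -> torch.Tensor:
    # Stage 2: grid_sample-based patch extraction. Gradients flow back
    # through both `image` and `grid`; in a real pipeline the grid is
    # derived from the detector's keypoint locations (stage 1).
    return F.grid_sample(image, grid, align_corners=False)

def normalize_patches(patches: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Stage 3: descriptor input normalization. Per-patch statistics over
    # (C, H, W), detached so they do not feed noisy gradients back upstream.
    sp, mp = torch.std_mean(patches, dim=(-3, -2, -1), keepdim=True)
    return (patches - mp.detach()) / (sp.detach() + eps)

# Toy usage: one image, one 8x8 sampling grid in normalized [-1, 1] coords.
image = torch.rand(1, 1, 32, 32, requires_grad=True)
grid = torch.rand(1, 8, 8, 2) * 2 - 1
patches = extract_patches(image, grid)
normalized = normalize_patches(patches)
```

The normalized patches still carry gradients back to the image (and hence the detector); only the statistics path is severed.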
## The Insight (Rule of Thumb)
- Action: Call `.detach()` on mean and standard deviation before using them to normalize input patches in descriptor networks.
- Value: `return (x - mp.detach()) / (sp.detach() + eps)` where `mp` is mean and `sp` is standard deviation.
- Trade-off: Prevents gradients from flowing through the normalization statistics, which means the statistics themselves are not optimized. This is acceptable because the normalization is just a preprocessing step, not a learned operation.
## Reasoning
`F.grid_sample` backpropagation produces gradients with respect to both the input image and the sampling grid coordinates. When these gradients combine with gradients from the statistics computation (which also depends on the same input), the resulting gradient signal has high variance and can cause training divergence. By detaching, the backward pass only needs to differentiate through `(x - constant) / constant`, which is a simple affine transformation with stable gradients.
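The difference between the two backward paths can be made concrete. In this sketch (assumed setup, not Kornia code), the detached version yields the clean affine gradient, while the full version mixes in the extra terms from differentiating the mean and standard deviation:

```python
import torch

torch.manual_seed(0)
x = torch.randn(2, 1, 4, 4, requires_grad=True)
w = torch.randn(2, 1, 4, 4)  # arbitrary fixed weighting for a non-trivial loss
eps = 1e-6

def normalize(t: torch.Tensor, detach_stats: bool) -> torch.Tensor:
    sp, mp = torch.std_mean(t, dim=(-3, -2, -1), keepdim=True)
    if detach_stats:
        mp, sp = mp.detach(), sp.detach()
    return (t - mp) / (sp + eps)

# Detached path: each element's gradient is just w / (std + eps),
# a simple affine backward with no statistics terms.
g_detached, = torch.autograd.grad((normalize(x, True) * w).sum(), x)

# Full path: extra terms from d(mean)/dx and d(std)/dx mix in. These are
# the data-dependent terms that become noisy when x itself comes from
# F.grid_sample patch extraction.
g_full, = torch.autograd.grad((normalize(x, False) * w).sum(), x)

# The two gradients differ in general, confirming the statistics path
# is active unless detached.
diff = (g_detached - g_full).abs().max()
```

In training, only the first, detached path is used; the comparison just shows what the detach removes.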
## Code Evidence
Detached normalization in HardNet from `kornia/feature/hardnet.py:92-103`:
```python
@staticmethod
def _normalize_input(x: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """Normalize the input by batch."""
    if not is_mps_tensor_safe(x):
        sp, mp = torch.std_mean(x, dim=(-3, -2, -1), keepdim=True)
    else:
        mp = torch.mean(x, dim=(-3, -2, -1), keepdim=True)
        sp = torch.std(x, dim=(-3, -2, -1), keepdim=True)
    # WARNING: we need to .detach() input, otherwise the gradients produced by
    # the patches extractor with F.grid_sample are very noisy, making the detector
    # training totally unstable.
    return (x - mp.detach()) / (sp.detach() + eps)
```
Related ZCA whitening detach pattern from `kornia/enhance/zca.py`:
```python
# This implementation uses torch.svd which yields NaNs in the backwards step
if self.detach_transforms:
    self.mean_vector = self.mean_vector.detach()
```
MPS-specific fallback in the same function from `kornia/feature/hardnet.py:95-99`:
```python
if not is_mps_tensor_safe(x):
    sp, mp = torch.std_mean(x, dim=(-3, -2, -1), keepdim=True)
else:
    mp = torch.mean(x, dim=(-3, -2, -1), keepdim=True)
    sp = torch.std(x, dim=(-3, -2, -1), keepdim=True)
```