Implementation:Kornia Kornia Total Variation Loss
| Knowledge Sources | |
|---|---|
| Domains | Vision, Loss_Functions |
| Last Updated | 2026-02-09 15:00 GMT |
Overview
Total Variation Loss computes the total variation of an image, measuring the amount of pixel-to-pixel variation and commonly used as a regularization term to encourage smoothness.
Description
Total Variation (TV) is a measure of the complexity or noise in an image, computed as the sum of absolute differences between neighboring pixels in both horizontal and vertical directions. It is widely used as a regularization term in image processing to encourage piecewise-smooth solutions.
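For an image $I$ with $H$ rows and $W$ columns, where $I_{i,j}$ is the pixel in row $i$ and column $j$, the total variation is computed as:

$$
\mathrm{TV}(I) = \sum_{i=1}^{H-1}\sum_{j=1}^{W} \left| I_{i+1,j} - I_{i,j} \right| + \sum_{i=1}^{H}\sum_{j=1}^{W-1} \left| I_{i,j+1} - I_{i,j} \right|
$$

The first sum runs over vertically adjacent pixel pairs and the second over horizontally adjacent pairs.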
The implementation supports two reduction modes:
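- sum (default): Sums the absolute differences along both axes.
- mean: Averages the absolute differences along each axis separately and adds the two averages. The vertical sum is divided by its (H - 1) × W pairs and the horizontal sum by its H × (W - 1) pairs, so the value does not grow with the number of pixels.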
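A 2×2 worked example shows how the two modes differ. Take the image with rows (0, 1) and (2, 3). The vertical differences are |2 - 0| = 2 and |3 - 1| = 2, and the horizontal differences are |1 - 0| = 1 and |3 - 2| = 1. Each direction contains two pixel pairs, so the mean reduction divides each directional sum by 2:

$$
\begin{aligned}
\text{sum:}\quad & (2 + 2) + (1 + 1) = 6 \\
\text{mean:}\quad & \frac{2 + 2}{2} + \frac{1 + 1}{2} = 2 + 1 = 3
\end{aligned}
$$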
The function accepts tensors with arbitrary leading batch dimensions and expects the last two dimensions to be the spatial (H, W) dimensions.
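The slicing-based computation behind these rules can be sketched in plain PyTorch. `tv_reference` is a hypothetical helper written for illustration, assuming a floating-point input and the anisotropic (sum of absolute differences) formulation described above. It is not Kornia's source code. It follows the same contract: reduce over the last two axes only and keep every leading dimension.

```python
import torch


def tv_reference(img: torch.Tensor, reduction: str = "sum") -> torch.Tensor:
    """Hypothetical illustration of anisotropic TV over the last two (H, W) axes.

    Assumes a floating-point tensor of shape (*, H, W); not Kornia's code.
    """
    if reduction not in ("sum", "mean"):
        raise ValueError(f"reduction must be 'sum' or 'mean', got {reduction!r}")
    # Absolute differences between vertically adjacent pixels: shape (*, H-1, W).
    dy = (img[..., 1:, :] - img[..., :-1, :]).abs()
    # Absolute differences between horizontally adjacent pixels: shape (*, H, W-1).
    dx = (img[..., :, 1:] - img[..., :, :-1]).abs()
    if reduction == "mean":
        # Average each direction over its own number of pairs, then add.
        return dy.mean(dim=(-2, -1)) + dx.mean(dim=(-2, -1))
    # Sum each direction over the spatial axes; leading dims are kept.
    return dy.sum(dim=(-2, -1)) + dx.sum(dim=(-2, -1))


img = torch.tensor([[0.0, 1.0], [2.0, 3.0]])
print(tv_reference(img))                    # tensor(6.)
print(tv_reference(img, reduction="mean"))  # tensor(3.)

batch = torch.rand(2, 3, 8, 8)
print(tv_reference(batch).shape)            # torch.Size([2, 3])
```

Reducing only over `dim=(-2, -1)` is what lets the same code handle a single `(H, W)` image, a `(C, H, W)` image, or a `(B, C, H, W)` batch without reshaping.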
Usage
Use this loss as a regularization term in image generation, style transfer, image denoising, and neural radiance field tasks. Adding it to the objective penalizes differences between neighboring pixels, which encourages smoother outputs and suppresses noise artifacts.
Code Reference
Source Location
- Repository: Kornia
- File: kornia/losses/total_variation.py
- Lines: 1-104
Signature
```python
def total_variation(
    img: torch.Tensor,
    reduction: str = "sum",
) -> torch.Tensor: ...


class TotalVariation(nn.Module):
    def forward(self, img: torch.Tensor) -> torch.Tensor: ...
```
Import
```python
from kornia.losses import TotalVariation
from kornia.losses import total_variation
```
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| img | torch.Tensor | Yes | Input image tensor with shape (*, H, W), where * denotes zero or more leading (e.g. batch and channel) dimensions |
| reduction | str | No | Reduction mode: 'sum' (default) or 'mean' |
Outputs
| Name | Type | Description |
|---|---|---|
| loss | torch.Tensor | Total variation for each leading index, with shape (*) matching the leading dimensions of the input; a plain (H, W) input yields a 0-dim tensor |
Usage Examples
```python
import torch
from kornia.losses import TotalVariation, total_variation

# Constant image has zero total variation
result = total_variation(torch.ones(4, 4))
# tensor(0.)

# Batched usage: one TV value per (batch, channel) pair
img = torch.ones(2, 3, 4, 4, requires_grad=True)
tv = TotalVariation()
output = tv(img)
# output.shape = torch.Size([2, 3])
output.sum().backward()  # reduce to a scalar before calling backward()

# Mean reduction: the value does not grow with the number of pixels
img = torch.rand(1, 3, 64, 64)
tv_loss = total_variation(img, reduction="mean")
# tv_loss.shape = torch.Size([1, 3])
```