Implementation:Kornia Kornia Dice Loss
| Knowledge Sources | |
|---|---|
| Domains | Vision, Loss_Functions |
| Last Updated | 2026-02-09 15:00 GMT |
Overview
Dice Loss computes the Sorensen-Dice Coefficient loss, a region-based loss function commonly used for image segmentation tasks.
Description
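The Dice loss is based on the Sorensen-Dice coefficient, which measures the overlap between two sets. It is widely used in medical image segmentation and other pixel-level classification tasks. The Dice coefficient is computed as:

$$\text{Dice}(X, Y) = \frac{2\,|X \cap Y|}{|X| + |Y|}$$

where $X$ holds the predicted class scores and $Y$ the one-hot encoded target labels. The loss is then:

$$\text{loss}(X, Y) = 1 - \text{Dice}(X, Y)$$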
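To make the set formula concrete, here is a small worked example on plain Python sets of pixel indices. It is a sketch for intuition only: the helper `dice_coefficient` is hypothetical, and returning 1.0 for two empty sets is a convention assumed here, not something the source specifies.

```python
def dice_coefficient(x: set, y: set) -> float:
    """Sorensen-Dice coefficient 2|X ∩ Y| / (|X| + |Y|) for two finite sets."""
    if not x and not y:
        return 1.0  # assumed convention: two empty sets overlap perfectly
    return 2 * len(x & y) / (len(x) + len(y))


# Pixel indices labelled foreground in the prediction and in the ground truth.
pred_pixels = {0, 1, 2, 3}
true_pixels = {2, 3, 4, 5}

dice = dice_coefficient(pred_pixels, true_pixels)  # 2 * 2 / (4 + 4)
loss = 1.0 - dice
print(dice, loss)  # 0.5 0.5
```

Identical sets give a coefficient of 1 (loss 0) and disjoint sets give 0 (loss 1), so the loss always lies in [0, 1].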
The implementation supports two averaging modes:
- micro: Calculates the loss across all classes at once, pooling the overlap counts of every class into a single Dice score per sample.
- macro: Calculates the loss for each class separately and averages the per-class results.
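The difference between the two modes can be sketched in plain PyTorch. This is an illustrative reimplementation under the assumption that "micro" sums the intersection and cardinality over classes and pixels while "macro" sums over pixels only; the function `soft_dice_loss` is hypothetical and is not Kornia's source.

```python
import torch
import torch.nn.functional as F


def soft_dice_loss(probs: torch.Tensor, one_hot: torch.Tensor,
                   average: str = "micro", eps: float = 1e-8) -> torch.Tensor:
    """Soft Dice loss for (N, C, H, W) probabilities and one-hot targets (illustrative only)."""
    if average == "micro":
        dims = (1, 2, 3)  # pool classes and pixels: one Dice score per sample
    elif average == "macro":
        dims = (2, 3)     # pool pixels only: one Dice score per sample and class
    else:
        raise ValueError(f"average must be 'micro' or 'macro', got {average!r}")
    intersection = (probs * one_hot).sum(dim=dims)
    cardinality = (probs + one_hot).sum(dim=dims)
    dice = 2.0 * intersection / (cardinality + eps)
    return (1.0 - dice).mean()  # average over samples (and over classes for macro)


# One sample, two classes, a 1x4 image: class 1 is rare (a single pixel).
labels = torch.tensor([[[0, 0, 0, 1]]])     # ground truth, shape (N=1, H=1, W=4)
predicted = torch.tensor([[[0, 0, 0, 0]]])  # the model misses the rare class entirely
one_hot = F.one_hot(labels, num_classes=2).permute(0, 3, 1, 2).float()
probs = F.one_hot(predicted, num_classes=2).permute(0, 3, 1, 2).float()  # hard "probabilities"

micro = soft_dice_loss(probs, one_hot, average="micro")
macro = soft_dice_loss(probs, one_hot, average="macro")
print(round(micro.item(), 4), round(macro.item(), 4))  # 0.25 0.5714
```

Micro averaging reports 0.25 because the well-predicted background dominates the pooled counts (Dice 6/8). Macro averaging gives the missed class equal weight (per-class Dice 6/7 and 0), so it reports about 0.571.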
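Internally, the implementation applies softmax over the class dimension of `pred`, so `pred` should contain raw logits rather than probabilities, and it converts `target` to a one-hot encoding.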
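The preprocessing step can be sketched as follows. The helper `prepare_dice_inputs` is hypothetical, and zeroing ignored pixels in both tensors (so they add nothing to the intersection or cardinality) is an assumption of this sketch; it does not claim to reproduce how Kornia masks `ignore_index` internally.

```python
import torch
import torch.nn.functional as F


def prepare_dice_inputs(logits: torch.Tensor, target: torch.Tensor, ignore_index: int = -100):
    """Hypothetical helper: class softmax plus one-hot targets, with ignored pixels zeroed."""
    num_classes = logits.shape[1]
    probs = logits.softmax(dim=1)  # (N, C, H, W); sums to 1 over C at every pixel
    valid = target != ignore_index  # (N, H, W) boolean mask
    # F.one_hot rejects negative labels, so temporarily map ignored pixels to class 0.
    safe_target = torch.where(valid, target, torch.zeros_like(target))
    one_hot = F.one_hot(safe_target, num_classes).permute(0, 3, 1, 2).to(probs.dtype)
    mask = valid.unsqueeze(1).to(probs.dtype)  # (N, 1, H, W); broadcasts over C
    return probs * mask, one_hot * mask


logits = torch.randn(2, 3, 4, 5)
target = torch.randint(0, 3, (2, 4, 5))
target[0, 0, 0] = -100  # mark one pixel as ignored
probs, one_hot = prepare_dice_inputs(logits, target)
print(probs.shape, one_hot.shape)  # torch.Size([2, 3, 4, 5]) torch.Size([2, 3, 4, 5])
```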
Usage
Use this loss for semantic segmentation, especially when classes are imbalanced. Because Dice loss directly optimizes the overlap metric, it is less sensitive to class imbalance than pixel-wise cross-entropy loss. It is particularly popular in medical image segmentation.
Code Reference
Source Location
- Repository: Kornia
- File: kornia/losses/dice.py
- Lines: 1-214
Signature
```python
from typing import Optional

import torch
from torch import nn


def dice_loss(
    pred: torch.Tensor,
    target: torch.Tensor,
    average: str = "micro",
    eps: float = 1e-8,
    weight: Optional[torch.Tensor] = None,
    ignore_index: Optional[int] = -100,
) -> torch.Tensor: ...


class DiceLoss(nn.Module):
    def __init__(
        self,
        average: str = "micro",
        eps: float = 1e-8,
        weight: Optional[torch.Tensor] = None,
        ignore_index: Optional[int] = -100,
    ) -> None: ...

    def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor: ...
```
Import
```python
from kornia.losses import DiceLoss
from kornia.losses import dice_loss
```
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| pred | torch.Tensor | Yes | Logits tensor with shape (N, C, H, W) where C = number of classes |
| target | torch.Tensor | Yes | Integer (`torch.long`) labels tensor with shape (N, H, W); each value is in [0, C-1] or equals `ignore_index` |
| average | str | No | Averaging mode: 'micro' (default) or 'macro' |
| eps | float | No | Small constant added to the denominator for numerical stability (default: 1e-8) |
| weight | Optional[torch.Tensor] | No | Per-class weights with shape (C,) (default: None) |
| ignore_index | Optional[int] | No | Target label value whose pixels are excluded from the loss (default: -100) |
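The input contract above can be checked before calling the loss. The function `check_dice_inputs` below is a hypothetical pre-flight helper, not part of Kornia; requiring `torch.long` targets is an assumption based on the usage example, and the library performs its own validation, which may differ.

```python
from typing import Optional

import torch


def check_dice_inputs(pred: torch.Tensor, target: torch.Tensor,
                      weight: Optional[torch.Tensor] = None,
                      ignore_index: Optional[int] = -100) -> None:
    """Hypothetical check mirroring the I/O contract table; raises ValueError on mismatch."""
    if pred.dim() != 4:
        raise ValueError(f"pred must be (N, C, H, W), got {tuple(pred.shape)}")
    n, c, h, w = pred.shape
    if target.shape != (n, h, w):
        raise ValueError(f"target must be (N, H, W) = {(n, h, w)}, got {tuple(target.shape)}")
    if target.dtype != torch.long:  # assumption: integer class labels as torch.long
        raise ValueError(f"target must hold torch.long labels, got {target.dtype}")
    labels = target if ignore_index is None else target[target != ignore_index]
    if labels.numel() > 0 and (labels.min() < 0 or labels.max() >= c):
        raise ValueError(f"target labels must lie in [0, {c - 1}] or equal ignore_index")
    if weight is not None and weight.shape != (c,):
        raise ValueError(f"weight must have shape ({c},), got {tuple(weight.shape)}")


pred = torch.randn(1, 5, 3, 5)
target = torch.randint(0, 5, (1, 3, 5))
check_dice_inputs(pred, target, weight=torch.ones(5))  # passes silently
```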
Outputs
| Name | Type | Description |
|---|---|---|
| loss | torch.Tensor | Scalar Dice loss (1 - Dice coefficient), averaged over the batch and, for 'macro', over classes |
Usage Examples
```python
import torch
from kornia.losses import DiceLoss, dice_loss

# Setup
N = 5  # num_classes
pred = torch.randn(1, N, 3, 5, requires_grad=True)  # logits, shape (B, C, H, W)
target = torch.empty(1, 3, 5, dtype=torch.long).random_(N)  # labels in [0, N-1], shape (B, H, W)

# Using the module API
criterion = DiceLoss()
output = criterion(pred, target)
output.backward()

# Using the functional API with the same defaults
output_fn = dice_loss(pred, target)

# With macro averaging and class weights
weights = torch.ones(N)
criterion_macro = DiceLoss(average="macro", weight=weights)
output = criterion_macro(pred, target)
```