Implementation:Microsoft DeepSpeedExamples BingBert Turing FocalLoss
| Knowledge Sources | |
|---|---|
| Domains | Loss Functions, Classification |
| Last Updated | 2026-02-07 12:00 GMT |
Overview
A PyTorch implementation of Focal Loss that down-weights well-classified examples to focus training on hard, misclassified samples, with configurable alpha weighting and gamma focusing parameters.
Description
FocalLoss is a PyTorch nn.Module implementing the focal loss from the paper "Focal Loss for Dense Object Detection." For logits x and target class c, with p = softmax(x)[c], the loss is Loss(x, c) = -alpha[c] * (1 - p)^gamma * log(p). The modulating factor (1 - p)^gamma is near 0 for well-classified examples (p close to 1) and near 1 for misclassified ones (p close to 0). It therefore shrinks the contribution of easy examples and concentrates training on hard ones.
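To make the modulating factor concrete, the short computation below evaluates the per-example focal term next to plain cross-entropy for an easy example (p = 0.9) and a hard one (p = 0.1), with gamma = 2 and alpha = 1. It is a standalone illustration using Python's math module, not code from the repository; the helper names are hypothetical.

```python
import math


def cross_entropy_term(p: float) -> float:
    """Per-example cross-entropy, given the true-class probability p."""
    return -math.log(p)


def focal_term(p: float, gamma: float = 2.0, alpha: float = 1.0) -> float:
    """Per-example focal loss: -alpha * (1 - p)**gamma * log(p)."""
    return -alpha * (1.0 - p) ** gamma * math.log(p)


for p in (0.9, 0.1):
    ce = cross_entropy_term(p)
    fl = focal_term(p)
    print(f"p={p}: CE={ce:.4f}  FL={fl:.4f}  FL/CE={fl / ce:.2f}")
```

The ratio FL/CE is exactly (1 - p)^gamma: the easy example keeps only (1 - 0.9)^2 = 1% of its cross-entropy, while the hard example keeps (1 - 0.1)^2 = 81%.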
The constructor takes class_num, the number of classes; an optional alpha tensor of per-class weights (default: uniform, torch.ones(class_num, 1)); a focusing parameter gamma (default: 2); and a size_average flag that averages the loss over the minibatch when True and sums it when False. When the inputs are on the GPU, alpha is moved to CUDA automatically.
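The (class_num, 1) shape of the default alpha matters once alpha is indexed by the target classes. The sketch below assumes that per-sample true-class probabilities are kept as a column of shape (N, 1); under that assumption a column-shaped alpha yields one weight per sample, while a 1D alpha of shape (class_num,) silently broadcasts to an (N, N) matrix. All variable names are illustrative.

```python
import torch

N, C = 4, 3
targets = torch.tensor([0, 2, 1, 2])  # class index per sample, shape (N,)
probs = torch.rand(N, 1)              # assumed true-class probabilities, shape (N, 1)

alpha_col = torch.ones(C, 1)                # default-style alpha, shape (C, 1)
alpha_flat = torch.tensor([1.0, 2.0, 3.0])  # 1D alpha, shape (C,)
alpha_fixed = alpha_flat.view(-1, 1)        # same weights reshaped to (C, 1)

# Indexing by target ids keeps alpha's trailing dimension.
print((alpha_col[targets] * probs).shape)    # torch.Size([4, 1]): one weight per sample
print((alpha_flat[targets] * probs).shape)   # torch.Size([4, 4]): unintended broadcast
print((alpha_fixed[targets] * probs).shape)  # torch.Size([4, 1])
```

Reshaping a custom alpha with .view(-1, 1) before passing it in keeps it consistent with the default.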
In the forward pass, the module applies softmax to the logits, builds a one-hot class mask with scatter_, multiplies and sums with that mask to extract each sample's probability for its target class, then applies the alpha weighting and focal modulation before averaging or summing. It uses only standard PyTorch tensor operations, so it runs on both CPU and GPU.
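The forward pass described above can be sketched as a small module. This is a hypothetical re-implementation (FocalLossSketch) written from the description, not the code in training/bing_bert/turing/loss.py. It assumes alpha is stored as a (class_num, 1) column and that size_average selects mean versus sum.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class FocalLossSketch(nn.Module):
    """Illustrative focal loss following the described forward pass.

    Not the repository's implementation; names and details are assumptions.
    """

    def __init__(self, class_num, alpha=None, gamma=2.0, size_average=True):
        super().__init__()
        # Uniform per-class weights by default, stored as a (class_num, 1) column.
        alpha = torch.ones(class_num, 1) if alpha is None else alpha.view(-1, 1)
        # A buffer moves with module.to(device) instead of being moved lazily.
        self.register_buffer("alpha", alpha)
        self.gamma = gamma
        self.size_average = size_average

    def forward(self, inputs, targets):
        n, c = inputs.shape
        probs = F.softmax(inputs, dim=1)   # (N, C) class probabilities
        ids = targets.view(-1, 1)          # (N, 1) target indices
        # One-hot mask built with scatter_: 1 at each sample's target column.
        class_mask = torch.zeros(n, c, dtype=inputs.dtype, device=inputs.device)
        class_mask.scatter_(1, ids, 1.0)
        p_t = (probs * class_mask).sum(dim=1, keepdim=True)  # (N, 1) true-class prob
        alpha_t = self.alpha[ids.view(-1)]                   # (N, 1) per-sample weight
        batch_loss = -alpha_t * (1.0 - p_t).pow(self.gamma) * p_t.log()
        return batch_loss.mean() if self.size_average else batch_loss.sum()
```

Taking log of a softmax probability can underflow to -inf for extremely confident wrong predictions; gathering from F.log_softmax instead is the numerically safer variant if that matters for your data.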
Usage
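Use FocalLoss in place of CrossEntropyLoss when training on imbalanced datasets or when optimization should concentrate on hard examples. It is not a strict drop-in replacement: the constructor requires class_num, and it exposes only alpha, gamma, and size_average (no options such as ignore_index). With gamma = 0 and uniform alpha, it reduces to standard cross-entropy. In the Bing BERT Turing pipeline it is intended for tasks with class imbalance.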
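One difference to watch when switching from CrossEntropyLoss concerns the reduction once alpha is non-uniform. The sketch below recomputes the documented formula with a hypothetical helper (focal_formula) instead of importing the repository class, and averages over the minibatch as the size_average=True description states. With gamma = 0 and per-class alpha, the result equals class-weighted cross-entropy summed and divided by N. That is not the default of CrossEntropyLoss(weight=...), which divides by the sum of the applied weights.

```python
import torch
import torch.nn.functional as F


def focal_formula(logits, targets, alpha, gamma):
    """Batch mean of -alpha[t] * (1 - p_t)**gamma * log(p_t) (illustrative helper)."""
    log_p = F.log_softmax(logits, dim=1).gather(1, targets.view(-1, 1)).squeeze(1)
    p = log_p.exp()
    return (-alpha[targets] * (1 - p).pow(gamma) * log_p).mean()


torch.manual_seed(0)
logits = torch.randn(16, 4)
targets = torch.randint(0, 4, (16,))
uniform = torch.ones(4)
weights = torch.tensor([1.0, 2.0, 0.5, 3.0])

# gamma = 0 with uniform alpha: identical to standard cross-entropy.
plain = focal_formula(logits, targets, uniform, 0.0)

# gamma = 0 with per-class alpha: a plain mean over the N samples ...
fl = focal_formula(logits, targets, weights, 0.0)
# ... which equals weighted cross-entropy summed and divided by N,
sum_over_n = F.cross_entropy(logits, targets, weight=weights, reduction="sum") / len(targets)
# ... whereas the weighted "mean" reduction divides by sum(weights[targets]).
weighted_mean = F.cross_entropy(logits, targets, weight=weights)
print(plain.item(), fl.item(), sum_over_n.item(), weighted_mean.item())
```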
Code Reference
Source Location
- Repository: Microsoft_DeepSpeedExamples
- File: training/bing_bert/turing/loss.py
- Lines: 1-60
Signature
```python
class FocalLoss(nn.Module):
    def __init__(self, class_num, alpha=None, gamma=2, size_average=True): ...
    def forward(self, inputs, targets): ...
```
Import
```python
from turing.loss import FocalLoss
```
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| class_num | int | Yes | Number of target classes |
| alpha | Tensor or None | No | Per-class weighting factors, shape (class_num, 1); defaults to uniform torch.ones(class_num, 1) |
| gamma | float | No | Focusing parameter (gamma >= 0); larger values down-weight easy examples more strongly, and 0 recovers cross-entropy. Default: 2 |
| size_average | bool | No | If True, average loss over minibatch; if False, sum. Default: True |
| inputs | Tensor | Yes | Raw logits of shape (N, C) where N is batch size and C is class count |
| targets | Tensor | Yes | Ground truth class indices of shape (N,) |
Outputs
| Name | Type | Description |
|---|---|---|
| loss | Tensor | Scalar focal loss value, either averaged or summed over the minibatch |
Usage Examples
```python
import torch
from turing.loss import FocalLoss

# Initialize with 10 classes and the default gamma=2
criterion = FocalLoss(class_num=10)

# Forward pass; logits need requires_grad=True for backward() to succeed here
logits = torch.randn(32, 10, requires_grad=True)  # batch_size=32, num_classes=10
targets = torch.randint(0, 10, (32,))
loss = criterion(logits, targets)
loss.backward()

# Custom per-class alpha for imbalanced classes, reshaped to (class_num, 1)
# to match the default alpha shape
alpha = torch.tensor([1.0, 2.0, 1.5, 1.0, 3.0, 1.0, 1.0, 2.0, 1.0, 1.5]).view(-1, 1)
criterion = FocalLoss(class_num=10, alpha=alpha, gamma=3, size_average=False)
loss = criterion(logits, targets)
```
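The custom alpha values above are hand-picked. One common heuristic, not something the source prescribes, derives them from inverse class frequencies in the training labels. The sketch below uses a hypothetical helper, inverse_frequency_alpha, that returns a (class_num, 1) column scaled so that the weights of the classes present average to 1.

```python
import torch


def inverse_frequency_alpha(labels: torch.Tensor, class_num: int) -> torch.Tensor:
    """Hypothetical helper: per-class alpha proportional to 1 / class count.

    Returns a (class_num, 1) column to match FocalLoss's default alpha shape.
    Classes absent from `labels` get weight 0; adjust if that is not desired.
    """
    counts = torch.bincount(labels, minlength=class_num).float()
    present = counts > 0
    inv = torch.zeros_like(counts)
    inv[present] = 1.0 / counts[present]
    # Rescale so the weights of the classes that occur average to 1.
    inv = inv * present.sum() / inv.sum()
    return inv.view(-1, 1)


train_labels = torch.tensor([0, 0, 0, 0, 1, 1, 2])
alpha = inverse_frequency_alpha(train_labels, class_num=3)
print(alpha.view(-1))  # rarest class (2) gets 4x the weight of class 0
```

The resulting tensor can then be passed as FocalLoss(class_num=3, alpha=alpha).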