Implementation:Sktime Pytorch forecasting QuantileLoss
| Knowledge Sources | |
|---|---|
| Domains | Time_Series, Loss_Functions, Probabilistic_Forecasting |
| Last Updated | 2026-02-08 07:00 GMT |
Overview
A concrete tool, provided by the pytorch-forecasting library, for computing the multi-quantile pinball loss used in probabilistic time series forecasting.
Description
The QuantileLoss class extends MultiHorizonMetric to compute the pinball (quantile) loss for several quantiles at once. For each quantile it computes the asymmetric loss between the target and the matching prediction column. It then concatenates the per-quantile results along the last dimension and multiplies them by 2, so that the loss at the 0.5 quantile equals the absolute error (MAE normalization). The class also provides to_prediction(), which extracts the median quantile as a point forecast, and to_quantiles(), which returns all quantile predictions. The default quantiles are [0.02, 0.1, 0.25, 0.5, 0.75, 0.9, 0.98].
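For one quantile level q, a target y and the matching prediction column ŷ_q, the per-element term described above can be written as follows. This is a sketch of the math consistent with the description (including the factor 2), not a transcription of the library source; the symbols L_q, y and ŷ_q are notation introduced here.

```latex
L_q(y, \hat{y}_q) = 2\,\max\bigl(q\,(y - \hat{y}_q),\ (q - 1)\,(y - \hat{y}_q)\bigr)

L_{0.5}(y, \hat{y}_{0.5}) = 2 \cdot 0.5\,\lvert y - \hat{y}_{0.5}\rvert = \lvert y - \hat{y}_{0.5}\rvert

% worked example: y = 10, \hat{y}_q = 8 (under-prediction by 2)
L_{0.9} = 2\max(0.9 \cdot 2,\ -0.1 \cdot 2) = 3.6, \qquad
L_{0.1} = 2\max(0.1 \cdot 2,\ -0.9 \cdot 2) = 0.4
```

Under-prediction is penalized more heavily at high quantiles and over-prediction at low ones, which is what pushes each output column toward its own quantile.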
Usage
Use it as the loss function for TemporalFusionTransformer; it is the default loss when calling TemporalFusionTransformer.from_dataset(). The model's output_size must match the number of quantiles.
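The constraints stated on this page (output_size equal to the number of quantiles, and a 0.5 quantile available so to_prediction() can extract the median) can be checked before a model is built. The helper check_quantile_config below is hypothetical and not part of pytorch-forecasting; it is a plain-Python sketch that only reproduces those rules, plus the definitional requirement that quantile levels lie strictly between 0 and 1.

```python
# Hypothetical helper, not part of pytorch-forecasting: validates a quantile
# list against the output_size you intend to pass to the model.
DEFAULT_QUANTILES = [0.02, 0.1, 0.25, 0.5, 0.75, 0.9, 0.98]  # default listed on this page


def check_quantile_config(quantiles, output_size):
    """Return the number of quantiles, or raise ValueError if the config is inconsistent."""
    if quantiles is None:
        quantiles = DEFAULT_QUANTILES  # same fallback as QuantileLoss(quantiles=None)
    if output_size != len(quantiles):
        raise ValueError(
            f"output_size={output_size}, but {len(quantiles)} quantiles were given"
        )
    if not all(0.0 < q < 1.0 for q in quantiles):
        raise ValueError("quantile levels must lie strictly between 0 and 1")
    if 0.5 not in quantiles:
        # without a median column there is no point forecast to extract
        raise ValueError("quantiles must include 0.5 for a median point forecast")
    return len(quantiles)


print(check_quantile_config(None, 7))             # 7
print(check_quantile_config([0.1, 0.5, 0.9], 3))  # 3
```

Failing fast on a mismatch is cheaper than discovering it as a tensor shape error during the first training step.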
Code Reference
Source Location
- Repository: pytorch-forecasting
- File: pytorch_forecasting/metrics/quantile.py
- Lines: L10-68
Signature
```python
class QuantileLoss(MultiHorizonMetric):
    def __init__(
        self,
        quantiles: list[float] | None = None,
        **kwargs,
    ):
        """
        Quantile loss.

        Args:
            quantiles: quantiles for metric.
                Default: [0.02, 0.1, 0.25, 0.5, 0.75, 0.9, 0.98]
        """

    def loss(self, y_pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Compute quantile loss per quantile."""

    def to_prediction(self, y_pred: torch.Tensor) -> torch.Tensor:
        """Extract median (q=0.5) as point prediction."""

    def to_quantiles(self, y_pred: torch.Tensor) -> torch.Tensor:
        """Return all quantile predictions."""
```
Import
```python
from pytorch_forecasting.metrics import QuantileLoss
```
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| quantiles | list[float] | No | Quantile levels (default: [0.02, 0.1, 0.25, 0.5, 0.75, 0.9, 0.98]) |
| y_pred | torch.Tensor | Yes (for loss()) | Model predictions (batch, horizon, n_quantiles) |
| target | torch.Tensor | Yes (for loss()) | Actual target values (batch, horizon) |
Outputs
| Name | Type | Description |
|---|---|---|
| loss() | torch.Tensor | Element-wise quantile losses, scaled by 2 (batch, horizon, n_quantiles) |
| to_prediction() | torch.Tensor | Point forecast: the median quantile (batch, horizon) |
| to_quantiles() | torch.Tensor | All quantile predictions (batch, horizon, n_quantiles) |
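The shape contract in the tables above can be illustrated without PyTorch. The sketch below is a plain-Python, nested-list re-implementation written for this page, not the library source: quantile_loss, to_prediction and to_quantiles are illustrative functions that follow the Description (per-quantile pinball loss stacked along the last axis and scaled by 2, the median column as the point forecast, all columns returned as-is). It assumes the y_pred columns are ordered like the quantiles list.

```python
def quantile_loss(y_pred, target, quantiles):
    """Pinball loss per element, scaled by 2.

    y_pred:  nested list shaped (batch, horizon, n_quantiles)
    target:  nested list shaped (batch, horizon)
    returns: nested list shaped (batch, horizon, n_quantiles)
    """
    losses = []
    for pred_row, target_row in zip(y_pred, target):
        row = []
        for preds, y in zip(pred_row, target_row):
            cell = []
            for q, p in zip(quantiles, preds):
                error = y - p
                # under-prediction (error > 0) costs q per unit, over-prediction costs 1 - q
                cell.append(2 * max(q * error, (q - 1) * error))
            row.append(cell)
        losses.append(row)
    return losses


def to_prediction(y_pred, quantiles):
    """Point forecast: the 0.5-quantile column, shaped (batch, horizon)."""
    idx = quantiles.index(0.5)
    return [[preds[idx] for preds in pred_row] for pred_row in y_pred]


def to_quantiles(y_pred):
    """All quantile forecasts, shaped (batch, horizon, n_quantiles), returned as-is."""
    return y_pred


quantiles = [0.1, 0.5, 0.9]
y_pred = [[[8.0, 10.0, 12.0]]]  # batch=1, horizon=1, n_quantiles=3
target = [[10.0]]
print(quantile_loss(y_pred, target, quantiles))
print(to_prediction(y_pred, quantiles))  # [[10.0]]
```

For the 0.5 column the scaled loss reduces to the absolute error, which is the reason the factor 2 is described as an MAE normalization.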
Usage Examples
Default QuantileLoss for TFT
```python
from pytorch_forecasting import TemporalFusionTransformer
from pytorch_forecasting.metrics import QuantileLoss

# Default: 7 quantiles
loss = QuantileLoss()
print(f"Quantiles: {loss.quantiles}")
# Quantiles: [0.02, 0.1, 0.25, 0.5, 0.75, 0.9, 0.98]

# Use with TFT; `training` is a TimeSeriesDataSet built beforehand
tft = TemporalFusionTransformer.from_dataset(
    training,
    loss=loss,
    output_size=7,  # matches the number of quantiles
)
```
Custom Quantiles
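```python
from pytorch_forecasting import TemporalFusionTransformer
from pytorch_forecasting.metrics import QuantileLoss

# Three quantiles: a narrower (nominally 80%) interval around the median
loss = QuantileLoss(quantiles=[0.1, 0.5, 0.9])
tft = TemporalFusionTransformer.from_dataset(
    training,  # a TimeSeriesDataSet built beforehand
    loss=loss,
    output_size=3,  # matches the number of quantiles
)
```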
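With custom quantiles, interval bounds are read back by column position. The helper prediction_interval below is hypothetical (not a pytorch-forecasting function); it is a plain-Python sketch that assumes one sample of to_quantiles() output as a (horizon, n_quantiles) nested list whose columns follow the order of the quantiles list passed to QuantileLoss. With [0.1, 0.5, 0.9] the band between the outer quantiles has a nominal coverage of 80%, versus 96% for the default 0.02 and 0.98 pair.

```python
def prediction_interval(quantile_forecast, quantiles, lower=0.1, upper=0.9):
    """Hypothetical helper: (lower, upper) bounds for each forecast step.

    quantile_forecast: nested list shaped (horizon, n_quantiles)
    quantiles: the list given to QuantileLoss, in the same column order
    """
    if not lower < upper:
        raise ValueError("lower quantile must be below upper quantile")
    lo_idx = quantiles.index(lower)  # raises ValueError if lower was not trained
    hi_idx = quantiles.index(upper)
    return [(step[lo_idx], step[hi_idx]) for step in quantile_forecast]


forecast = [[8.0, 10.0, 12.0], [9.0, 11.0, 14.0]]  # horizon=2, quantiles [0.1, 0.5, 0.9]
print(prediction_interval(forecast, [0.1, 0.5, 0.9]))  # [(8.0, 12.0), (9.0, 14.0)]
```

Only quantiles the model was trained on can be read back this way, so the interval you want to report should be decided when the quantiles list is chosen.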