

Implementation:Sktime Pytorch forecasting DeepAR Predict

From Leeroopedia


Knowledge Sources
Domains Time_Series, Probabilistic_Forecasting, Model_Evaluation
Last Updated 2026-02-08 07:00 GMT

Overview

Concrete tool, provided by the pytorch-forecasting library, for generating probabilistic predictions from a trained DeepAR model.

Description

The DeepAR.predict method overrides BaseModel.predict to add probabilistic sampling support. It accepts a DataLoader, DataFrame, or TimeSeriesDataSet and returns predictions in one of several modes: point forecasts (mode="prediction"), quantile forecasts ("quantiles"), raw Monte Carlo sample paths ("samples"), or the raw network output ("raw"). The n_samples parameter (default: 100) sets how many sample paths are drawn from the learned distribution. Point and quantile forecasts are estimated from these samples, so a larger n_samples gives more stable prediction intervals at a higher compute cost. The call is delegated to BaseModel.predict, which collects results through a PredictCallback. If output_dir is set, predictions are written to that directory per batch or per epoch, as chosen by write_interval.
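Why n_samples matters can be sketched without the library. The snippet below estimates the 0.9 quantile of a standard normal from Monte Carlo draws and measures how far the estimate lands from the exact value. The standard normal is a stand-in assumption for DeepAR's learned distribution, and mean_abs_quantile_error is a hypothetical helper written for this illustration, not part of pytorch-forecasting.

```python
import numpy as np

# Exact 0.9 quantile of a standard normal, used as ground truth.
# (Standard normal is an illustrative stand-in for DeepAR's learned distribution.)
Z_90 = 1.2815515655446004


def mean_abs_quantile_error(n_samples: int, reps: int = 200, seed: int = 0) -> float:
    """Hypothetical helper: average |empirical 0.9 quantile - true 0.9 quantile|
    over `reps` independent Monte Carlo runs of `n_samples` draws each."""
    rng = np.random.default_rng(seed)
    draws = rng.standard_normal((reps, n_samples))
    estimates = np.quantile(draws, 0.9, axis=1)  # one quantile estimate per run
    return float(np.mean(np.abs(estimates - Z_90)))


for n in (100, 1_000, 10_000):
    print(f"n_samples={n:>6}: mean abs error {mean_abs_quantile_error(n):.4f}")
```

The error shrinks roughly as 1/sqrt(n_samples), so 100 times more samples buys about a tenfold tighter quantile estimate; that accuracy-versus-cost trade-off is what n_samples controls.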

Usage

Call on a trained DeepAR model with validation or test data. Use mode="prediction" for point forecasts, mode="quantiles" for prediction intervals, or mode="samples" for full sample paths. Set return_x=True and return_y=True to get inputs and actuals alongside predictions for evaluation.
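How the modes relate can be illustrated with NumPy on a synthetic sample array. This is a minimal sketch, assuming (as we read the library's DeepAR decoder) that sampled paths are stacked on the last axis as (batch, horizon, n_samples). The Gaussian data and the mean-over-samples point forecast are illustrative assumptions, not the library's exact reduction.

```python
import numpy as np

rng = np.random.default_rng(42)
batch, horizon, n_samples = 4, 6, 100

# Hypothetical stand-in for mode="samples" output: sampled paths stacked on the
# LAST axis -> shape (batch, horizon, n_samples). Location rises over the horizon.
mu = np.linspace(10.0, 20.0, horizon)
samples = rng.normal(loc=mu[None, :, None], scale=2.0, size=(batch, horizon, n_samples))

# Point forecast: collapse the sample axis (the mean is an illustrative choice).
point = samples.mean(axis=-1)  # (batch, horizon)

# Quantile forecast: np.quantile puts the quantile axis first, so move it last
# to get (batch, horizon, n_quantiles).
qs = [0.1, 0.5, 0.9]
quantiles = np.moveaxis(np.quantile(samples, qs, axis=-1), 0, -1)

print(point.shape, quantiles.shape)  # (4, 6) (4, 6, 3)
```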

Code Reference

Source Location

  • Repository: pytorch-forecasting
  • File: pytorch_forecasting/models/deepar/_deepar.py
  • Lines: L437-508

Signature

class DeepAR:
    def predict(
        self,
        data: DataLoader | pd.DataFrame | TimeSeriesDataSet,
        mode: str | tuple[str, str] = "prediction",
        return_index: bool = False,
        return_decoder_lengths: bool = False,
        batch_size: int = 64,
        num_workers: int = 0,
        fast_dev_run: bool = False,
        return_x: bool = False,
        return_y: bool = False,
        mode_kwargs: dict[str, Any] = None,
        trainer_kwargs: dict[str, Any] | None = None,
        write_interval: str = "batch",
        output_dir: str | None = None,
        n_samples: int = 100,
        **kwargs,
    ) -> Prediction:
        """
        Predict dataloader.

        Args:
            data: dataloader, dataframe or dataset
            mode: "prediction", "quantiles", "samples", or "raw"
            return_index: if to return prediction index
            return_x: if to return network inputs
            return_y: if to return network targets
            n_samples: number of samples to draw (default: 100)

        Returns:
            Prediction namedtuple
        """

Import

from pytorch_forecasting import DeepAR
# model.predict(val_dataloader, mode="prediction", n_samples=100)

I/O Contract

Inputs

Name | Type | Required | Description
--- | --- | --- | ---
data | DataLoader, DataFrame or TimeSeriesDataSet | Yes | Data to predict on
mode | str | No | Output mode: "prediction", "quantiles", "samples" or "raw" (default: "prediction")
n_samples | int | No | Number of Monte Carlo samples (default: 100)
return_index | bool | No | Return the prediction time index (default: False)
return_x | bool | No | Return network inputs (default: False)
return_y | bool | No | Return actual targets (default: False)
batch_size | int | No | Batch size for prediction, used only when data is not already a DataLoader (default: 64)

Outputs

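Name | Type | Description
--- | --- | ---
return | Prediction | Namedtuple with fields output (predictions tensor), x (inputs), y (actuals), index and decoder_lengths; x, y, index and decoder_lengths are populated only when the corresponding return_* flag is set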

Usage Examples

Point Predictions

# Get point forecasts (derived from the n_samples sampled paths)
predictions = model.predict(
    val_dataloader,
    mode="prediction",
    return_x=True,
    return_y=True,
    n_samples=100,
)

# Expected shape for a single target: (number of predicted windows, prediction_length)
print(f"Predictions shape: {predictions.output.shape}")
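With return_y=True the actuals come back alongside the forecasts, so a point-error metric is a short step away. The sketch below computes the mean absolute error with NumPy. Here mae is a hypothetical helper, the toy arrays stand in for predictions.output and the actual targets, and tensors returned by predict would first be converted with .cpu().numpy(). In the pytorch-forecasting versions we have looked at, predictions.y is a (target, weight) pair, so the target would be predictions.y[0]; treat this as an assumption and check it against your installed version.

```python
import numpy as np


def mae(point_forecast, actuals) -> float:
    """Hypothetical helper: mean absolute error over all windows and steps,
    skipping NaN entries (e.g. padding for windows with shorter horizons)."""
    point_forecast = np.asarray(point_forecast, dtype=float)
    actuals = np.asarray(actuals, dtype=float)
    return float(np.nanmean(np.abs(point_forecast - actuals)))


# Toy arrays standing in for predictions.output and the actual targets
forecast = np.array([[1.0, 2.0], [3.0, 4.0]])
actual = np.array([[1.0, 3.0], [5.0, 4.0]])
print(mae(forecast, actual))  # 0.75
```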

Sample-Based Prediction Intervals

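# Get raw sample paths for custom interval computation
import torch

raw_predictions = model.predict(
    val_dataloader,
    mode="samples",
    n_samples=200,
    return_y=True,  # ensures a Prediction namedtuple with an .output field is returned
)

# raw_predictions.output has shape (number of predicted windows, prediction_length, n_samples):
# the samples are stacked on the last axis, so reduce over dim=-1
lower = torch.quantile(raw_predictions.output, 0.1, dim=-1)
upper = torch.quantile(raw_predictions.output, 0.9, dim=-1)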
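Once lower and upper bounds are available, a simple calibration check is empirical coverage: for a 10% to 90% interval, about 80% of the actuals should fall inside. In the sketch below, interval_coverage is a hypothetical helper. The synthetic samples assume the (windows, horizon, n_samples) layout and are drawn from the same distribution as the actuals, which is what makes roughly 80% the expected coverage. Real tensors from predict would be converted with .cpu().numpy() first.

```python
import numpy as np


def interval_coverage(lower, upper, actuals) -> float:
    """Hypothetical helper: fraction of non-NaN actuals inside [lower, upper]."""
    lower, upper, actuals = (np.asarray(a, dtype=float) for a in (lower, upper, actuals))
    valid = ~np.isnan(actuals)                      # ignore padded / missing targets
    inside = (actuals >= lower) & (actuals <= upper)
    return float(inside[valid].mean())


# Synthetic check: samples and actuals share a standard normal distribution,
# so a 10%-90% interval should cover roughly 80% of the actuals.
rng = np.random.default_rng(0)
samples = rng.normal(size=(200, 20, 200))  # (windows, horizon, n_samples)
actual = rng.normal(size=(200, 20))
lo = np.quantile(samples, 0.1, axis=-1)
hi = np.quantile(samples, 0.9, axis=-1)
print(round(interval_coverage(lo, hi, actual), 3))
```

Coverage well below the nominal level suggests intervals that are too narrow; coverage well above it suggests they are too wide.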

Related Pages

Implements Principle

Requires Environment
