Implementation:Sktime Pytorch forecasting NBeats From Dataset

From Leeroopedia


Knowledge Sources
Domains: Time_Series, Deep_Learning, Signal_Decomposition
Last Updated: 2026-02-08 07:00 GMT

Overview

Concrete tool for instantiating an N-BEATS forecasting model from a pytorch-forecasting TimeSeriesDataSet.

Description

The NBeats.from_dataset factory method (inherited from NBeatsAdapter) creates an N-BEATS model whose prediction_length and context_length are inferred from the dataset's max_prediction_length and max_encoder_length. Before constructing the model it validates strict requirements on the dataset: a single (univariate) target, a continuous rather than categorical target, fixed encoder and decoder lengths, no length randomization (randomize_length unset), no relative time index (add_relative_time_idx disabled), and the target as the only input variable. Unless stack_types is overridden, the model is configured with one trend stack and one seasonality stack.
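The requirements above can be pre-checked before calling from_dataset. The following is a minimal sketch, not the library's own validation: nbeats_dataset_problems is a hypothetical helper, it assumes the dataset exposes the attributes it reads by name (target, reals, flat_categoricals, and so on), and it omits the categorical-target check. A SimpleNamespace stands in for a real TimeSeriesDataSet so the sketch runs without pytorch-forecasting installed.

```python
from types import SimpleNamespace


def nbeats_dataset_problems(dataset) -> list[str]:
    """Return reasons why `dataset` would not suit N-BEATS (hypothetical helper).

    Mirrors the requirements listed in the Description; it reads attributes
    by name, so any object exposing them can be checked.
    """
    problems = []
    if not isinstance(dataset.target, str):
        problems.append("exactly one target is required")
    if dataset.min_encoder_length != dataset.max_encoder_length:
        problems.append("encoder length must be fixed")
    if dataset.min_prediction_length != dataset.max_prediction_length:
        problems.append("prediction length must be fixed")
    if dataset.randomize_length is not None:
        problems.append("randomize_length must be None")
    if dataset.add_relative_time_idx:
        problems.append("add_relative_time_idx must be False")
    if list(dataset.reals) != [dataset.target] or len(dataset.flat_categoricals) > 0:
        problems.append("the target must be the only input variable")
    return problems


# Stand-in object with the attributes the helper reads (an assumption for
# illustration, not a real TimeSeriesDataSet).
good = SimpleNamespace(
    target="value",
    min_encoder_length=30,
    max_encoder_length=30,
    min_prediction_length=7,
    max_prediction_length=7,
    randomize_length=None,
    add_relative_time_idx=False,
    reals=["value"],
    flat_categoricals=[],
)
# Same settings, but with a variable encoder length.
bad = SimpleNamespace(**{**vars(good), "min_encoder_length": 10})

print(nbeats_dataset_problems(good))  # → []
print(nbeats_dataset_problems(bad))   # → ['encoder length must be fixed']
```

Returning a list of problems rather than raising on the first one makes it easier to fix a dataset configuration in a single pass.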

Usage

Call this method after constructing a TimeSeriesDataSet configured for N-BEATS: fixed-length windows, single univariate target, no covariates. The dataset must have min_encoder_length == max_encoder_length and min_prediction_length == max_prediction_length.
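One way to satisfy these constraints is to derive the dataset's length arguments from just two numbers. The sketch below builds a dictionary of TimeSeriesDataSet keyword arguments; nbeats_dataset_kwargs is a hypothetical helper, and the column names ("value", "series", "time_idx") and window lengths are placeholder assumptions, not values taken from this page.

```python
def nbeats_dataset_kwargs(
    target: str,
    group_ids: list[str],
    time_idx: str,
    context_length: int,
    prediction_length: int,
) -> dict:
    """Build TimeSeriesDataSet keyword arguments for an N-BEATS-style dataset.

    Hypothetical helper: min and max lengths are set to the same value so the
    windows are fixed-length, and the target is the only input variable.
    """
    return dict(
        time_idx=time_idx,
        target=target,
        group_ids=group_ids,
        min_encoder_length=context_length,
        max_encoder_length=context_length,
        min_prediction_length=prediction_length,
        max_prediction_length=prediction_length,
        time_varying_unknown_reals=[target],
        add_relative_time_idx=False,
        randomize_length=None,
    )


# Placeholder column names and lengths, chosen only for illustration.
kwargs = nbeats_dataset_kwargs("value", ["series"], "time_idx", 60, 20)

# With pytorch-forecasting installed and a long-format DataFrame `df`
# containing those columns, the dataset could then be built as:
#   from pytorch_forecasting import TimeSeriesDataSet
#   training = TimeSeriesDataSet(df, **kwargs)
```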

Code Reference

Source Location

  • Repository: pytorch-forecasting
  • File: pytorch_forecasting/models/nbeats/_nbeats_adapter.py (from_dataset: L117-175), pytorch_forecasting/models/nbeats/_nbeats.py (__init__: L98-174)

Signature

class NBeats(NBeatsAdapter):
    @classmethod
    def from_dataset(cls, dataset: TimeSeriesDataSet, **kwargs):
        """
        Convenience function to create network from TimeSeriesDataSet.

        Parameters
        ----------
        dataset : TimeSeriesDataSet
            dataset where sole predictor is the target.
        **kwargs
            additional arguments to be passed to __init__ method.

        Returns
        -------
        NBeats
        """

    def __init__(
        self,
        stack_types: list[str] | None = None,  # default: ["trend", "seasonality"]
        num_blocks: list[int] | None = None,  # default: [3, 3]
        num_block_layers: list[int] | None = None,  # default: [3, 3]
        widths: list[int] | None = None,  # default: [32, 512]
        sharing: list[bool] | None = None,  # default: [True, True]
        expansion_coefficient_lengths: list[int] | None = None,  # default: [3, 7]
        prediction_length: int = 1,
        context_length: int = 1,
        dropout: float = 0.1,
        learning_rate: float = 1e-2,
        weight_decay: float = 1e-3,
        loss: MultiHorizonMetric = None,  # default: MASE()
        backcast_loss_ratio: float = 0.0,
        **kwargs,
    ):

Import

from pytorch_forecasting import NBeats

I/O Contract

Inputs

Name | Type | Required | Description
dataset | TimeSeriesDataSet | Yes | Univariate dataset with fixed-length windows
learning_rate | float | No | Learning rate (default: 1e-2)
weight_decay | float | No | Weight decay for regularization (default: 1e-3)
stack_types | list[str] | No | Stack types, one per stack: "trend", "seasonality", or "generic" (default: ["trend", "seasonality"])
num_blocks | list[int] | No | Number of blocks per stack (default: [3, 3])
widths | list[int] | No | Hidden widths per stack (default: [32, 512])
dropout | float | No | Dropout rate (default: 0.1)
backcast_loss_ratio | float | No | Weight of backcast loss term (default: 0.0)

All inputs other than dataset are forwarded to __init__ through **kwargs.
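Several of these inputs are per-stack lists, so their lengths need to agree with stack_types. The sketch below is a hypothetical pre-check (check_stack_lists is not part of pytorch-forecasting, and the library may or may not perform a similar check itself); it flags unknown stack types and any per-stack list whose length differs from the number of stacks.

```python
ALLOWED_STACK_TYPES = {"trend", "seasonality", "generic"}


def check_stack_lists(stack_types: list[str], **per_stack: list) -> None:
    """Raise ValueError if per-stack lists disagree with `stack_types`.

    Hypothetical helper: every keyword argument is treated as a per-stack
    list that needs exactly one entry per stack.
    """
    unknown = [s for s in stack_types if s not in ALLOWED_STACK_TYPES]
    if unknown:
        raise ValueError(f"unknown stack types: {unknown}")
    n_stacks = len(stack_types)
    mismatched = {
        name: len(values)
        for name, values in per_stack.items()
        if len(values) != n_stacks
    }
    if mismatched:
        raise ValueError(
            f"expected {n_stacks} entries per list, got lengths {mismatched}"
        )


# The documented defaults are consistent: two stacks, two entries each.
check_stack_lists(
    ["trend", "seasonality"],
    num_blocks=[3, 3],
    num_block_layers=[3, 3],
    widths=[32, 512],
    sharing=[True, True],
    expansion_coefficient_lengths=[3, 7],
)

# Three stacks combined with a two-element list is reported as an error.
try:
    check_stack_lists(["generic"] * 3, num_blocks=[1, 1, 1], num_block_layers=[3, 3])
except ValueError as err:
    print(err)
```

Running the check before NBeats.from_dataset gives a readable message instead of a failure deep inside model construction.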

Outputs

Name | Type | Description
return | NBeats | Configured N-BEATS model (trend and seasonality stacks unless stack_types is overridden)

Usage Examples

Standard N-BEATS Instantiation

from pytorch_forecasting import NBeats

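# `training` is a TimeSeriesDataSet built beforehand that meets the
# requirements described under Usage.
model = NBeats.from_dataset(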
    training,
    learning_rate=3e-2,
    weight_decay=1e-2,
    widths=[32, 512],
    backcast_loss_ratio=1.0,
)

print(f"Number of parameters: {model.size() / 1e3:.1f}k")

N-BEATS with Generic Stack

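from pytorch_forecasting import NBeats

# Every per-stack list gets one entry per stack; with three stacks the
# two-element defaults would be too short.
model = NBeats.from_dataset(
    training,
    stack_types=["generic", "generic", "generic"],
    num_blocks=[1, 1, 1],
    num_block_layers=[3, 3, 3],
    widths=[512, 512, 512],
    sharing=[True, True, True],
    expansion_coefficient_lengths=[32, 32, 32],
    learning_rate=1e-3,
)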
