Implementation:Pyro ppl Pyro BART Forecast
| Property | Value |
|---|---|
| Implementation Type | Pattern Doc |
| Source File | examples/contrib/forecast/bart.py |
| Module | pyro.contrib.forecast |
| Pyro Features | pyro.contrib.forecast (ForecastingModel, backtest); pyro.distributions (GaussianHMM, LKJCholesky); pyro.ops.tensor_utils (periodic_cumsum, periodic_repeat); seasonal modeling |
| Dataset | BART (Bay Area Rapid Transit) origin-destination ridership data |
Overview
This file demonstrates Pyro's forecasting framework by modeling hourly arrivals and departures at the Embarcadero BART station. The model subclasses ForecastingModel and implements a .model() method that defines a bivariate time series model with:
- Seasonal component: weekly seasonality (period = 24 hours × 7 days = 168 hours), built from an initial seasonal pattern repeated with periodic_repeat plus a slow seasonal drift accumulated with periodic_cumsum.
- Noise model: a GaussianHMM with correlated bivariate transitions that models autocorrelated noise; the transition and observation correlations receive LKJCholesky priors (Cholesky factors of LKJ-distributed correlation matrices).
- Global parameters: noise scale, transition timescale, transition location, transition scale and correlation, and observation scale and correlation.
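To make the seasonal construction concrete, here is a small pure-Python re-implementation of the two helpers along a single time axis. It is an illustration only, not Pyro's code: the names periodic_repeat_1d, periodic_cumsum_1d and seasonal_prediction are hypothetical, and the sketch assumes periodic_repeat tiles one period so that result[t] = pattern[t % period], while periodic_cumsum adds each value to the result exactly one period earlier.

```python
def periodic_repeat_1d(pattern, size):
    """Tile one period of values out to `size` steps: result[t] = pattern[t % period]."""
    period = len(pattern)
    return [pattern[t % period] for t in range(size)]


def periodic_cumsum_1d(values, period):
    """Cumulative sum within each phase: result[t] = values[t] + result[t - period]."""
    result = []
    for t, value in enumerate(values):
        previous = result[t - period] if t >= period else 0.0
        result.append(value + previous)
    return result


def seasonal_prediction(season_init, season_noise):
    """Exactly repeated initial pattern plus a slowly drifting seasonal offset."""
    period = len(season_init)
    repeated = periodic_repeat_1d(season_init, len(season_noise))
    drift = periodic_cumsum_1d(season_noise, period)
    return [r + d for r, d in zip(repeated, drift)]
```

With period 168, each hour-of-week slot accumulates one noise term per week, so the weekly shape can drift gradually instead of being fixed for the whole series.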
The model is evaluated by backtesting with the backtest() function, which repeatedly trains the model on a sliding window of train_window hours and evaluates it on the following held-out test_window hours, advancing by stride hours between splits.
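The sliding-window splitting can be sketched with a small hypothetical helper, backtest_windows. It assumes each split trains on the train_window hours ending at train_end and tests on the next test_window hours, with train_end advancing by stride; Pyro's backtest() may handle boundaries and other options differently.

```python
def backtest_windows(duration, train_window, test_window, stride):
    """Return (train_begin, train_end, test_end) index triples for sliding splits."""
    splits = []
    # The last split must leave room for a full test window at the end of the data.
    for train_end in range(train_window, duration - test_window + 1, stride):
        splits.append((train_end - train_window, train_end, train_end + test_window))
    return splits
```

With the default settings, the first split trains on 2160 hours (90 days) and tests on the next 336 hours (14 days); each later split shifts both windows forward by 168 hours (one week).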
Code Reference
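```python
import torch

import pyro
import pyro.distributions as dist
from pyro.contrib.forecast import ForecastingModel
from pyro.ops.tensor_utils import periodic_cumsum, periodic_repeat


class Model(ForecastingModel):
    def model(self, zero_data, covariates):
        period = 24 * 7  # weekly seasonality in hours
        duration, dim = zero_data.shape[-2:]  # dim == 2: (arrivals, departures)
        noise_scale = pyro.sample(
            "noise_scale", dist.LogNormal(torch.full((dim,), -3.0), 1.0).to_event(1))
        trans_timescale = pyro.sample(
            "trans_timescale", dist.LogNormal(torch.zeros(dim), 1).to_event(1))
        # ... more global parameters ...
        # Initial seasonal pattern, one value per hour of the week.
        with pyro.plate("season_plate", period, dim=-1):
            season_init = pyro.sample(
                "season_init", dist.Normal(torch.zeros(dim), 1).to_event(1))
        # Independent seasonal noise at every time step.
        with self.time_plate:
            season_noise = pyro.sample(
                "season_noise", dist.Normal(0, noise_scale).to_event(1))
        # Exactly repeated weekly pattern plus slow drift, along the time dim (-2).
        prediction = periodic_repeat(season_init, duration, dim=-2) + \
            periodic_cumsum(season_noise, period, dim=-2)
        # init_dist, trans_mat, trans_dist, obs_mat and obs_dist are built
        # from the elided global parameters above.
        noise_model = dist.GaussianHMM(init_dist, trans_mat, trans_dist,
                                       obs_mat, obs_dist, duration=duration)
        # Register the noise model and the prediction with the forecaster.
        self.predict(noise_model, prediction)
```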
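The arguments passed to GaussianHMM are elided in the code above. As a rough, dependency-free illustration of what a Gaussian HMM with correlated bivariate transitions encodes, the sketch below simulates a 2-D latent AR(1) state observed with correlated noise. It rests on assumptions this page does not state: a diagonal transition matrix with entries exp(-timescale), an identity observation matrix, and a latent state that starts at zero. Each scale_tril is built as a per-series scale times the Cholesky factor of a 2×2 correlation matrix; in the Pyro model that correlation factor would be an LKJCholesky sample. The helper names are hypothetical.

```python
import math
import random


def scaled_cholesky_2x2(scales, rho):
    """Return diag(scales) @ L, where L is the Cholesky factor of [[1, rho], [rho, 1]]."""
    corr_chol = [[1.0, 0.0], [rho, math.sqrt(1.0 - rho * rho)]]
    return [[scales[i] * corr_chol[i][j] for j in range(2)] for i in range(2)]


def matvec(m, v):
    """2x2 matrix times 2-vector."""
    return [m[0][0] * v[0] + m[0][1] * v[1], m[1][0] * v[0] + m[1][1] * v[1]]


def simulate_noise(duration, timescale, trans_loc, trans_tril, obs_tril, seed=0):
    """Simulate z_t = a * z_{t-1} + loc + trans noise, then x_t = z_t + obs noise."""
    rng = random.Random(seed)
    decay = [math.exp(-s) for s in timescale]  # assumed diagonal transition matrix
    z = [0.0, 0.0]  # assumed zero initial state (a simplification)
    observations = []
    for _ in range(duration):
        step = matvec(trans_tril, [rng.gauss(0.0, 1.0), rng.gauss(0.0, 1.0)])
        z = [decay[i] * z[i] + trans_loc[i] + step[i] for i in range(2)]
        obs_noise = matvec(obs_tril, [rng.gauss(0.0, 1.0), rng.gauss(0.0, 1.0)])
        observations.append([z[i] + obs_noise[i] for i in range(2)])  # identity obs matrix
    return observations
```

Multiplying the correlation Cholesky factor row-wise by the scales gives a covariance of diag(s) C diag(s), so the scales and the correlation can carry separate priors.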
I/O Contract
| Parameter | Type | Description |
|---|---|---|
| --train-window | int | Training window size in hours (default: 2160, i.e. 90 days) |
| --test-window | int | Test window size in hours (default: 336, i.e. 14 days) |
| --stride | int | Stride between backtest windows in hours (default: 168, i.e. 7 days) |
| -n / --num-steps | int | Number of SVI optimization steps (default: 501) |
| -lr / --learning-rate | float | Optimizer learning rate |
| --num-samples | int | Number of forecast samples (default: 100) |
| --dct | flag | Use DCT gradients during optimization |
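A command-line interface matching this table could be declared with argparse as below. This is a hypothetical reconstruction from the table, not the example's actual parser; the -lr/--learning-rate default is not stated on this page, so it is left as None here.

```python
import argparse


def build_parser():
    """Hypothetical parser mirroring the I/O contract table."""
    parser = argparse.ArgumentParser(description="BART ridership forecasting (sketch)")
    parser.add_argument("--train-window", type=int, default=2160)  # hours (90 days)
    parser.add_argument("--test-window", type=int, default=336)  # hours (14 days)
    parser.add_argument("--stride", type=int, default=168)  # hours (7 days)
    parser.add_argument("-n", "--num-steps", type=int, default=501)  # SVI steps
    parser.add_argument("--num-samples", type=int, default=100)  # forecast samples
    # Default not stated on this page; None means "let the forecaster decide".
    parser.add_argument("-lr", "--learning-rate", type=float, default=None)
    parser.add_argument("--dct", action="store_true")  # use DCT gradients
    return parser
```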
Output:
- Backtest metrics: MAE, RMSE, CRPS (mean +/- std across windows)
- Data shape: bivariate (arrivals, departures) hourly counts
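The reported metrics can be illustrated for one series and then summarized across backtest windows. In this sketch the point forecast is the sample median for MAE and the sample mean for RMSE, and CRPS uses the standard sample-based estimator E|X - y| - 0.5 E|X - X'|; these choices are assumptions and may differ in detail from Pyro's implementation. For the bivariate data, the same computation would apply to each of arrivals and departures. The helper names are hypothetical.

```python
import statistics


def sample_crps(xs, y):
    """Sample-based CRPS estimate: E|X - y| - 0.5 * E|X - X'| over the samples xs."""
    n = len(xs)
    spread = sum(abs(a - b) for a in xs for b in xs) / (n * n)
    return sum(abs(x - y) for x in xs) / n - 0.5 * spread


def window_metrics(samples, truth):
    """MAE, RMSE and CRPS for one test window.

    samples: list of forecast trajectories, each a list with len(truth) values.
    """
    per_step = [[s[t] for s in samples] for t in range(len(truth))]
    pairs = list(zip(per_step, truth))
    mae = statistics.mean(abs(statistics.median(xs) - y) for xs, y in pairs)
    rmse = statistics.mean((statistics.mean(xs) - y) ** 2 for xs, y in pairs) ** 0.5
    crps = statistics.mean(sample_crps(xs, y) for xs, y in pairs)
    return {"mae": mae, "rmse": rmse, "crps": crps}


def summarize(metrics_per_window, name):
    """Mean and population standard deviation of one metric across windows."""
    values = [m[name] for m in metrics_per_window]
    return statistics.mean(values), statistics.pstdev(values)
```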
Usage Examples
```shell
# Run backtesting with 501 SVI steps, learning rate 0.01 and 100 forecast samples
python bart.py -n 501 -lr 0.01 --num-samples 100

# Use DCT gradients for faster training
python bart.py --dct -n 501
```
Related Pages
- Pyro_ppl_Pyro_GP_TimeSeries - GP-based time series models