Implementation:Pyro ppl Pyro SIR Models
| Property | Value |
|---|---|
| Implementation Type | Pattern Doc |
| Source File | examples/contrib/epidemiology/sir.py |
| Module | pyro.contrib.epidemiology |
| Pyro Features | pyro.contrib.epidemiology.models, MCMC (NUTS/HMC), SVI, SMC heuristic initialization, Haar wavelet reparameterization |
| Model Classes | SimpleSIRModel, SimpleSEIRModel, OverdispersedSIRModel, OverdispersedSEIRModel, SuperspreadingSIRModel, SuperspreadingSEIRModel, HeterogeneousSIRModel |
Overview
This file demonstrates Pyro's high-level compartmental epidemiology modeling framework for fitting SIR and SEIR models to infection count data. It dispatches between seven different model classes based on command-line arguments:
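- SimpleSIRModel / SimpleSEIRModel: Basic stochastic compartmental models with binomial transitions between compartments.
- OverdispersedSIRModel / OverdispersedSEIRModel: Variants with a global overdispersion parameter that replaces the binomial transition and observation distributions with overdispersed (beta-binomial) ones.
- SuperspreadingSIRModel / SuperspreadingSEIRModel: Variants with a concentration parameter k that models superspreading, i.e. large variation in the number of secondary infections caused by each case.
- HeterogeneousSIRModel: SIR variant in which the reproduction number Rt and the response rate rho vary over time.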
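The overdispersed and superspreading variants differ from the simple models mainly in using count distributions that are wider than a binomial with the same mean. The sketch below illustrates that effect with a beta-binomial (the finite-population analogue of the negative binomial). Its parameterization, a Beta with mean `p` and a `concentration` parameter, is an illustrative assumption and not necessarily how Pyro parameterizes `overdispersion` or `k`; the helper names are hypothetical.

```python
import random
import statistics


def binomial_draw(rng, n, p):
    """Binomial(n, p) as a sum of n Bernoulli trials (adequate for small n)."""
    return sum(rng.random() < p for _ in range(n))


def beta_binomial_draw(rng, n, p, concentration):
    """Beta-binomial draw: sample a probability q from a Beta distribution
    with mean p, then draw Binomial(n, q).

    Smaller `concentration` spreads q more widely, giving more overdispersion;
    as concentration -> infinity this approaches Binomial(n, p).
    """
    q = rng.betavariate(p * concentration, (1 - p) * concentration)
    return binomial_draw(rng, n, q)


def beta_binomial_variance(n, p, concentration):
    """Closed-form variance of the beta-binomial parameterized as above."""
    return n * p * (1 - p) * (n + concentration) / (1 + concentration)


# Same mean (n * p = 10 expected events), very different spread.
rng = random.Random(0)
n, p = 50, 0.2
binom_var = statistics.pvariance([binomial_draw(rng, n, p) for _ in range(4000)])
overdisp_var = statistics.pvariance(
    [beta_binomial_draw(rng, n, p, concentration=2.0) for _ in range(4000)]
)
print(f"binomial variance ~ {binom_var:.1f} (exact {n * p * (1 - p):.1f})")
print(f"beta-binomial variance ~ {overdisp_var:.1f} "
      f"(exact {beta_binomial_variance(n, p, 2.0):.1f})")
```

With identical means, the beta-binomial counts vary far more from draw to draw. That extra variability is what lets the overdispersed and superspreading models explain bursty case counts.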
The script provides a complete workflow: data generation from a known model, inference via MCMC or SVI, evaluation of parameter estimates, and forecasting of future infections.
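To make the data-generation step concrete, here is a stdlib-only sketch of a discrete-time stochastic SIR simulation with binomial transitions and partially observed infections. It is not the script's `generate_data`. The infection-probability formula, the initial number of infected people, and the names `simulate_sir` / `binomial_draw` are assumptions chosen for illustration. The returned keys `new_I` and `obs` simply mirror the dataset keys used in `main`.

```python
import random


def binomial_draw(rng, n, p):
    """Binomial(n, p) as a sum of n Bernoulli trials (adequate for small n)."""
    return sum(rng.random() < p for _ in range(n))


def simulate_sir(population, initial_infected, R0, recovery_time, rho,
                 duration, seed=0):
    """Simulate a discrete-time stochastic SIR epidemic (illustrative sketch).

    Assumptions: each infectious person infects each susceptible person
    independently with daily probability beta / population, where
    beta = R0 / recovery_time; each infectious person recovers with daily
    probability 1 / recovery_time; each new infection is reported with
    probability rho (the response rate).
    """
    rng = random.Random(seed)
    beta = R0 / recovery_time
    S, I, R = population - initial_infected, initial_infected, 0
    new_I, obs, states = [], [], []
    for _ in range(duration):
        # Probability a susceptible is infected by at least one of I people.
        p_infect = 1 - (1 - beta / population) ** I
        S2I = binomial_draw(rng, S, p_infect)           # new infections
        I2R = binomial_draw(rng, I, 1 / recovery_time)  # new recoveries
        S, I, R = S - S2I, I + S2I - I2R, R + I2R
        new_I.append(S2I)
        obs.append(binomial_draw(rng, S2I, rho))        # reported cases
        states.append((S, I, R))
    return {"new_I": new_I, "obs": obs, "states": states}


data = simulate_sir(population=1000, initial_infected=10, R0=1.5,
                    recovery_time=7.0, rho=0.5, duration=20)
print("new infections:", data["new_I"])
print("observed cases:", data["obs"])
```

Because only a fraction rho of new infections is observed, inference has to recover both R0 and rho from `obs` alone. The true `new_I` series is then useful for checking forecasts.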
Code Reference
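```python
import math
from pyro.contrib.epidemiology.models import (
    HeterogeneousSIRModel,
    SimpleSEIRModel,
    SimpleSIRModel,
    SuperspreadingSEIRModel,
    SuperspreadingSIRModel,
)

# Excerpt: the branches selecting OverdispersedSIRModel / OverdispersedSEIRModel are omitted.
def Model(args, data):
    """Dispatch between different model classes."""
    if args.heterogeneous:
        return HeterogeneousSIRModel(args.population, args.recovery_time, data)
    elif args.incubation_time > 0:  # SEIR models
        if args.concentration < math.inf:
            return SuperspreadingSEIRModel(
                args.population, args.incubation_time, args.recovery_time, data)
        else:
            return SimpleSEIRModel(
                args.population, args.incubation_time, args.recovery_time, data)
    else:  # SIR models
        if args.concentration < math.inf:
            return SuperspreadingSIRModel(args.population, args.recovery_time, data)
        else:
            return SimpleSIRModel(args.population, args.recovery_time, data)

def main(args):
    # generate_data, infer_mcmc, infer_svi, evaluate and predict are defined elsewhere in sir.py.
    dataset = generate_data(args)
    model = Model(args, dataset["obs"])
    infer = {"mcmc": infer_mcmc, "svi": infer_svi}[args.infer]
    samples = infer(args, model)
    evaluate(args, model, samples)
    if args.forecast:
        predict(args, model, truth=dataset["new_I"])
```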
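The Overview lists seven model classes, while the excerpt above shows only some of the dispatch branches. The full decision table can be sketched without Pyro by returning class names instead of constructing models. `select_model_name` and its `overdispersed` switch are hypothetical, since this page does not document which flag selects the Overdispersed* classes. Giving superspreading precedence over overdispersion is also an assumption.

```python
import math


def select_model_name(heterogeneous=False, incubation_time=0.0,
                      concentration=math.inf, overdispersed=False):
    """Return the name of the model class these (assumed) rules would pick.

    `overdispersed` is a hypothetical switch standing in for whatever option
    selects the Overdispersed* classes in the real script.
    """
    if heterogeneous:
        return "HeterogeneousSIRModel"
    # A positive incubation time adds an Exposed compartment (SEIR).
    kind = "SEIR" if incubation_time > 0 else "SIR"
    if concentration < math.inf:
        prefix = "Superspreading"
    elif overdispersed:
        prefix = "Overdispersed"
    else:
        prefix = "Simple"
    return f"{prefix}{kind}Model"


print(select_model_name())                                       # default settings
print(select_model_name(incubation_time=3, concentration=0.5))  # cf. usage example 2
```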
I/O Contract
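| Parameter | Type | Description |
|---|---|---|
| --population | int | Total population size (default: 1000) |
| --duration | int | Number of observed days (default: 20) |
| --forecast | int | Number of days to forecast (default: 10) |
| --basic-reproduction-number | float | True R0 used for data generation (default: 1.5) |
| --recovery-time | float | Mean recovery time in days (default: 7.0) |
| --incubation-time | float | Mean incubation time in days; 0 selects an SIR model, values > 1 select an SEIR model |
| --concentration | float | Superspreading concentration k; finite values select the superspreading models (default: infinity, i.e. no superspreading) |
| --heterogeneous | flag | Use HeterogeneousSIRModel |
| --infer | str | Inference method: "mcmc" or "svi"; the usage examples use the shorthand flags --mcmc and --svi |
| --haar | flag | Use Haar wavelet reparameterization |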
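The `--haar` option re-expresses latent time series in a Haar wavelet basis, where coarse coefficients describe slow trends and fine coefficients describe day-to-day wiggles. According to Pyro's documentation of its Haar reparameterizer, the goal is a posterior covariance that is closer to diagonal, which helps diagonal mass matrices in HMC and diagonal guides in SVI. The following is a minimal stdlib sketch of an orthonormal Haar transform. It assumes a series length that is a power of two. Pyro's own implementation handles general lengths and is not reproduced here, and `haar_forward` / `haar_inverse` are hypothetical names.

```python
import math


def _check_power_of_two(n):
    if n == 0 or n & (n - 1):
        raise ValueError("length must be a power of two")


def haar_forward(x):
    """Orthonormal Haar transform; returns coefficients ordered coarse-to-fine."""
    _check_power_of_two(len(x))
    s = 1 / math.sqrt(2)
    approx, details = list(x), []
    while len(approx) > 1:
        pairs = list(zip(approx[0::2], approx[1::2]))
        # Each pass yields coarser details than the last, so prepend them.
        details = [(a - b) * s for a, b in pairs] + details
        approx = [(a + b) * s for a, b in pairs]
    return approx + details


def haar_inverse(c):
    """Invert haar_forward, rebuilding the series one scale at a time."""
    _check_power_of_two(len(c))
    s = 1 / math.sqrt(2)
    approx, pos = [c[0]], 1
    while len(approx) < len(c):
        k = len(approx)
        level = c[pos:pos + k]  # detail coefficients for the next finer scale
        pos += k
        approx = [v for a, d in zip(approx, level)
                  for v in ((a + d) * s, (a - d) * s)]
    return approx


series = [3.0, 1.0, 4.0, 1.0, 5.0, 9.0, 2.0, 6.0]
coeffs = haar_forward(series)
print([round(v, 3) for v in coeffs])
print([round(v, 3) for v in haar_inverse(coeffs)])
```

The transform is orthonormal, so it preserves the sum of squares and is exactly invertible. A constant series maps to a single nonzero coarse coefficient.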
Output:
- Estimated R0, response rate (rho), and other parameters with uncertainties
- Forecasted infection time series with posterior credible intervals
- Optional matplotlib plots of posteriors and predictions
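Estimates "with uncertainties" and forecast bands are usually summaries of posterior samples. The sketch below shows one common way to produce them. It assumes the samples are available as plain Python lists. The 90% level, the toy sample values, and the names `summarize` and `summarize_forecast` are illustrative and not taken from sir.py.

```python
import statistics


def summarize(samples):
    """Mean, standard deviation and 90% central interval of posterior samples."""
    cuts = statistics.quantiles(samples, n=20, method="inclusive")  # 5%, ..., 95%
    return {
        "mean": statistics.fmean(samples),
        "std": statistics.stdev(samples),
        "interval_90": (cuts[0], cuts[-1]),
    }


def summarize_forecast(trajectories):
    """Per-day (5%, median, 95%) bands from sampled trajectories.

    `trajectories` holds one list of daily values per posterior sample.
    """
    bands = []
    for day_values in zip(*trajectories):
        cuts = statistics.quantiles(day_values, n=20, method="inclusive")
        bands.append((cuts[0], statistics.median(day_values), cuts[-1]))
    return bands


# Toy values for illustration only, not real posterior output.
toy_R0_samples = [1.2 + 0.01 * i for i in range(61)]
print(summarize(toy_R0_samples))
```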
Usage Examples
```bash
# Simple SIR model with MCMC inference
python sir.py -p 1000 -d 20 -R0 1.5 --mcmc -n 200

# SEIR model with superspreading
python sir.py -p 10000 -e 3 -k 0.5 --mcmc --haar

# SVI inference with plotting
python sir.py --svi --plot -n 200 -ss 5000
```
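The `--mcmc` and `--svi` flags in these commands appear to be shorthands for `--infer mcmc` and `--infer svi`. One way to wire that up with `argparse` is to point both flags at the same destination with `store_const`. `build_parser` is hypothetical. It covers only the long-form options from the I/O Contract table, plus `--concentration` and `--heterogeneous`, which the dispatch code reads as `args.concentration` and `args.heterogeneous`. The defaults for `--infer` and `--incubation-time` are assumptions.

```python
import argparse
import math


def build_parser():
    """Hypothetical parser mirroring the I/O Contract table."""
    parser = argparse.ArgumentParser(description="SIR/SEIR fitting (sketch)")
    parser.add_argument("--population", type=int, default=1000)
    parser.add_argument("--duration", type=int, default=20)
    parser.add_argument("--forecast", type=int, default=10)
    parser.add_argument("--basic-reproduction-number", type=float, default=1.5)
    parser.add_argument("--recovery-time", type=float, default=7.0)
    parser.add_argument("--incubation-time", type=float, default=0.0)  # assumed default
    parser.add_argument("--concentration", type=float, default=math.inf)
    parser.add_argument("--heterogeneous", action="store_true")
    parser.add_argument("--infer", choices=["mcmc", "svi"], default="mcmc")  # assumed default
    # Shorthand flags write into the same destination as --infer;
    # if several are given, the last one on the command line wins.
    parser.add_argument("--mcmc", action="store_const", const="mcmc", dest="infer")
    parser.add_argument("--svi", action="store_const", const="svi", dest="infer")
    parser.add_argument("--haar", action="store_true")
    return parser


args = build_parser().parse_args(["--svi", "--haar"])
print(args.infer, args.haar)
```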
Related Pages
- Pyro_ppl_Pyro_Regional_SIR - Regional variant with coupled populations