
Implementation:Pyro ppl Pyro SIR Models

From Leeroopedia


Property | Value
Implementation Type | Pattern Doc
Source File | examples/contrib/epidemiology/sir.py
Module | pyro.contrib.epidemiology
Pyro Features | pyro.contrib.epidemiology.models, MCMC (NUTS/HMC), SVI, SMC heuristic initialization, Haar wavelet reparameterization
Model Classes | SimpleSIRModel, SimpleSEIRModel, OverdispersedSIRModel, OverdispersedSEIRModel, SuperspreadingSIRModel, SuperspreadingSEIRModel, HeterogeneousSIRModel

Overview

This file demonstrates Pyro's high-level compartmental epidemiology modeling framework for fitting SIR and SEIR models to infection count data. It dispatches between seven different model classes based on command-line arguments:

  • SimpleSIRModel / SimpleSEIRModel: Basic discrete-time compartmental models with binomial transitions between compartments.
  • OverdispersedSIRModel / OverdispersedSEIRModel: Models that add a single global overdispersion parameter, widening the transition and observation distributions beyond the binomial.
  • SuperspreadingSIRModel / SuperspreadingSEIRModel: Models with a concentration parameter k for superspreading events (smaller k gives more heterogeneous, heavier-tailed infection counts).
  • HeterogeneousSIRModel: SIR model with individual-level heterogeneity in transmission.
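To make the differences between these variants concrete, here is a plain-Python sketch of a discrete-time chain-binomial epidemic. It is not Pyro's code. The function name `simulate` and its parameters are hypothetical. The daily infection probability `1 - exp(-(R0 / tau) * I / N)` is a standard chain-binomial assumption and is not necessarily the exact formula Pyro uses. A positive `incubation_time` switches from SIR to SEIR. A finite `concentration` adds beta-binomial overdispersion to infections, which is only an illustration; Pyro's Overdispersed* and Superspreading* models use their own parameterizations.

```python
import math
import random


def simulate(population, recovery_time, r0, duration, incubation_time=0.0,
             concentration=math.inf, response_rate=0.5, initial_infected=10,
             seed=0):
    """Chain-binomial SIR (incubation_time == 0) or SEIR (incubation_time > 0).

    Returns (daily observed cases, final (S, E, I, R) counts).
    """
    rng = random.Random(seed)

    def binomial(n, p):
        # Sum of n Bernoulli(p) draws; adequate for small populations.
        return sum(rng.random() < p for _ in range(n))

    S, E, I, R = population - initial_infected, 0, initial_infected, 0
    observed = []
    for _ in range(duration):
        # Assumed daily infection probability for each susceptible individual.
        p = 1.0 - math.exp(-(r0 / recovery_time) * I / population)
        if concentration < math.inf and 0.0 < p < 1.0:
            # Beta-binomial overdispersion: draw one shared daily rate with mean p;
            # smaller concentration gives burstier infection counts.
            p = rng.betavariate(concentration * p, concentration * (1.0 - p))
        infections = binomial(S, p)             # S -> E (SEIR) or S -> I (SIR)
        I2R = binomial(I, 1.0 / recovery_time)  # I -> R
        if incubation_time > 0:
            E2I = binomial(E, 1.0 / incubation_time)  # E -> I
            E += infections - E2I
            new_I = E2I
        else:
            new_I = infections
        S -= infections
        I += new_I - I2R
        R += I2R
        # Each newly infectious case is reported with probability response_rate.
        observed.append(binomial(new_I, response_rate))
    return observed, (S, E, I, R)
```

For example, `simulate(1000, 7.0, 1.5, 20)` roughly mirrors the default SIR setting. `simulate(10000, 7.0, 1.5, 20, incubation_time=3.0, concentration=0.5)` mirrors the SEIR-with-superspreading flags used in the usage examples.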

The script provides a complete workflow: data generation from a known model, inference via MCMC or SVI, evaluation of parameter estimates, and forecasting of future infections.
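The "generate data, then infer R0" loop can be illustrated with a deliberately simplified sketch. It relies on strong assumptions. First, the daily S, I and new-infection counts are fully observed, with no reporting rate. Second, the recovery time is known. Third, R0 is sampled with random-walk Metropolis rather than the NUTS/HMC or SVI machinery the script uses. The real workflow observes only reported counts and must also infer the latent flows. All function names here are hypothetical.

```python
import math
import random


def simulate_sir(population, recovery_time, r0, duration, initial_infected, rng):
    """Fully observed chain-binomial SIR; returns a list of (S, I, new_infections)."""
    S, I = population - initial_infected, initial_infected
    history = []
    for _ in range(duration):
        p = 1.0 - math.exp(-(r0 / recovery_time) * I / population)
        new_inf = sum(rng.random() < p for _ in range(S))
        recovered = sum(rng.random() < 1.0 / recovery_time for _ in range(I))
        history.append((S, I, new_inf))
        S, I = S - new_inf, I + new_inf - recovered
    return history


def log_likelihood(r0, history, population, recovery_time):
    """Binomial log-likelihood of the daily infection counts (constants dropped)."""
    total = 0.0
    for S, I, k in history:
        rate = (r0 / recovery_time) * I / population
        if rate == 0.0:
            if k > 0:
                return -math.inf
            continue
        # log p = log(1 - exp(-rate)) and log(1 - p) = -rate
        total += k * math.log(-math.expm1(-rate)) - (S - k) * rate
    return total


def metropolis_r0(history, population, recovery_time, num_samples=2000,
                  warmup=500, step=0.1, seed=0):
    """Random-walk Metropolis on R0 with a flat prior on (0, 10)."""
    rng = random.Random(seed)
    r0 = 1.0
    current = log_likelihood(r0, history, population, recovery_time)
    samples = []
    for i in range(warmup + num_samples):
        proposal = r0 + rng.gauss(0.0, step)
        if 0.0 < proposal < 10.0:  # proposals outside the prior are rejected
            cand = log_likelihood(proposal, history, population, recovery_time)
            if rng.random() < math.exp(min(0.0, cand - current)):
                r0, current = proposal, cand
        if i >= warmup:
            samples.append(r0)
    return samples
```

With `history = simulate_sir(1000, 7.0, 1.5, 20, 20, random.Random(42))`, the mean of `metropolis_r0(history, 1000, 7.0)` should land near the true value of 1.5. The gap between them is a crude analogue of the parameter-recovery check the script performs.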

Code Reference

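import math

from pyro.contrib.epidemiology.models import (
    HeterogeneousSIRModel, OverdispersedSEIRModel, OverdispersedSIRModel, SimpleSEIRModel,
    SimpleSIRModel, SuperspreadingSEIRModel, SuperspreadingSIRModel)

def Model(args, data):
    """Dispatch between the seven model classes based on command-line arguments."""
    if args.heterogeneous:
        return HeterogeneousSIRModel(args.population, args.recovery_time, data)
    elif args.incubation_time > 0:
        # SEIR family: adds an exposed compartment with the given incubation time.
        if args.concentration < math.inf:
            return SuperspreadingSEIRModel(
                args.population, args.incubation_time, args.recovery_time, data)
        elif args.overdispersion > 0:
            return OverdispersedSEIRModel(
                args.population, args.incubation_time, args.recovery_time, data)
        else:
            return SimpleSEIRModel(
                args.population, args.incubation_time, args.recovery_time, data)
    else:
        # SIR family.
        if args.concentration < math.inf:
            return SuperspreadingSIRModel(args.population, args.recovery_time, data)
        elif args.overdispersion > 0:
            return OverdispersedSIRModel(args.population, args.recovery_time, data)
        else:
            return SimpleSIRModel(args.population, args.recovery_time, data)

def main(args):
    # generate_data, infer_mcmc, infer_svi, evaluate and predict are defined elsewhere in sir.py.
    dataset = generate_data(args)
    model = Model(args, dataset["obs"])
    infer = {"mcmc": infer_mcmc, "svi": infer_svi}[args.infer]
    samples = infer(args, model)
    evaluate(args, model, samples)
    if args.forecast:
        predict(args, model, truth=dataset["new_I"])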

I/O Contract

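Parameter | Type | Description
-p, --population | int | Total population size (default: 1000)
-d, --duration | int | Number of observed days (default: 20)
--forecast | int | Number of days to forecast (default: 10)
-R0, --basic-reproduction-number | float | True R0 used for data generation (default: 1.5)
--recovery-time | float | Mean recovery time in days (default: 7.0)
-e, --incubation-time | float | Incubation time in days; 0 selects an SIR model, values > 1 select an SEIR model
-k, --concentration | float | Superspreading concentration; any finite value selects a Superspreading model (default: infinity, i.e. no superspreading)
--overdispersion | float | Values > 0 select an Overdispersed model
--heterogeneous | flag | Use HeterogeneousSIRModel
--mcmc / --svi | flag | Inference method, stored as args.infer ("mcmc" or "svi")
--haar | flag | Use Haar wavelet reparameterization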
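The command-line contract can be mirrored with a minimal `argparse` sketch. This is a hypothetical reconstruction, not the script's actual parser. Short aliases are limited to those seen in the usage examples, and the options `-n` and `-ss` from those examples are omitted. The default of `"mcmc"` for `args.infer` and the `--overdispersion` default of 0 are assumptions.

```python
import argparse
import math


def make_parser():
    """Hypothetical reconstruction of the flags in the table above (not sir.py's parser)."""
    parser = argparse.ArgumentParser(description="SIR/SEIR epidemiology example (sketch)")
    parser.add_argument("-p", "--population", type=int, default=1000)
    parser.add_argument("-d", "--duration", type=int, default=20)
    parser.add_argument("--forecast", type=int, default=10)
    parser.add_argument("-R0", "--basic-reproduction-number", type=float, default=1.5)
    parser.add_argument("--recovery-time", type=float, default=7.0)
    # 0 selects an SIR model; values > 1 select an SEIR model.
    parser.add_argument("-e", "--incubation-time", type=float, default=0.0)
    # Any finite concentration selects a superspreading model.
    parser.add_argument("-k", "--concentration", type=float, default=math.inf)
    # Assumed default of 0, meaning no overdispersion.
    parser.add_argument("--overdispersion", type=float, default=0.0)
    parser.add_argument("--heterogeneous", action="store_true")
    parser.add_argument("--haar", action="store_true")
    # --mcmc and --svi both write to args.infer, which main() uses to pick a routine.
    parser.add_argument("--mcmc", dest="infer", action="store_const", const="mcmc")
    parser.add_argument("--svi", dest="infer", action="store_const", const="svi")
    parser.set_defaults(infer="mcmc")  # assumed default
    return parser
```

Parsing `["-p", "10000", "-e", "3", "-k", "0.5", "--mcmc", "--haar"]` produces the attributes that `Model` inspects. In this case it takes the SuperspreadingSEIRModel branch.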

Output:

  • Posterior estimates of R0, the response rate rho (the probability that an infection is observed), and other parameters, with uncertainties
  • Forecasted infection time series with credible intervals
  • Optional matplotlib plots of posteriors and predictions
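Reducing posterior samples to reported numbers can be sketched as follows. Scalar parameters such as R0 get a mean, standard deviation and central credible interval. Forecast trajectories get a per-day median and interval. The function names are hypothetical, and the 90% default level is an assumption rather than the script's exact output format. Quantiles use linear interpolation between order statistics.

```python
import statistics


def summarize(samples, level=0.9):
    """Mean, standard deviation and central credible interval of 1-D samples."""
    tail = (1.0 - level) / 2.0
    ordered = sorted(samples)
    n = len(ordered)

    def quantile(q):
        # Linear interpolation between neighbouring order statistics.
        pos = q * (n - 1)
        lo = int(pos)
        hi = min(lo + 1, n - 1)
        return ordered[lo] + (pos - lo) * (ordered[hi] - ordered[lo])

    return {
        "mean": statistics.fmean(ordered),
        "std": statistics.stdev(ordered),
        "lower": quantile(tail),
        "upper": quantile(1.0 - tail),
    }


def summarize_forecast(paths, level=0.9):
    """Per-day median and central interval across sampled trajectories."""
    out = []
    for day in zip(*paths):  # transpose: one tuple of sampled values per day
        s = summarize(day, level)
        out.append({"median": statistics.median(day),
                    "lower": s["lower"], "upper": s["upper"]})
    return out
```

Here `paths` is a list of sampled trajectories, one list of daily infection counts per posterior sample.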

Usage Examples

# Simple SIR model with MCMC inference
# python sir.py -p 1000 -d 20 -R0 1.5 --mcmc -n 200

# SEIR model with superspreading
# python sir.py -p 10000 -e 3 -k 0.5 --mcmc --haar

# SVI inference with plotting
# python sir.py --svi --plot -n 200 -ss 5000
