Implementation:Pyro ppl Pyro Regional SIR
| Property | Value |
|---|---|
| Implementation Type | Pattern Doc |
| Source File | examples/contrib/epidemiology/regional.py |
| Module | pyro.contrib.epidemiology |
| Pyro Features | pyro.contrib.epidemiology.models.RegionalSIRModel, MCMC, SVI, Haar wavelet reparameterization, multi-region coupling |
| Pattern | Coupled multi-region compartmental epidemic model |
Overview
This file demonstrates a regional SIR (Susceptible-Infected-Recovered) model that extends the basic SIR framework to multiple coupled populations. Each region has its own infection dynamics, but regions interact through a coupling matrix that models cross-regional transmission.
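To make the coupling concrete, here is a deterministic sketch of one daily step of a coupled SIR model. It is a hypothetical illustration, not RegionalSIRModel's implementation, which defines a probabilistic model in Pyro rather than this expected-value update. The sketch assumes three things. First, region i's force of infection uses a coupled infectious count and a coupled population, each computed as a row vector times the coupling matrix. Second, the transmission rate is R0 / recovery_time. Third, the daily recovery probability is 1 / recovery_time. The name regional_sir_step is made up for this page.

```python
import math


def regional_sir_step(S, I, R, population, coupling, R0, recovery_time):
    """Advance every region by one day using expected (mean) flows.

    Hypothetical mean-field sketch: region i sees a coupled infectious count
    sum_j I[j] * coupling[j][i] and a coupled population
    sum_j population[j] * coupling[j][i] (row vector times matrix).
    """
    n = len(population)
    beta = R0 / recovery_time    # assumed per-day transmission rate
    gamma = 1.0 / recovery_time  # assumed per-day recovery probability
    S_next, I_next, R_next = [], [], []
    for i in range(n):
        I_coupled = sum(I[j] * coupling[j][i] for j in range(n))
        N_coupled = sum(population[j] * coupling[j][i] for j in range(n))
        # Probability that one susceptible in region i is infected today.
        p_infect = 1.0 - math.exp(-beta * I_coupled / N_coupled)
        new_infections = S[i] * p_infect  # expected S -> I flow
        recoveries = I[i] * gamma         # expected I -> R flow
        S_next.append(S[i] - new_infections)
        I_next.append(I[i] + new_infections - recoveries)
        R_next.append(R[i] + recoveries)
    return S_next, I_next, R_next


# Outbreak seeded in region 0 only; region 1 is reached through the coupling.
population = [1000.0, 1000.0]
coupling = [[1.0, 0.1], [0.1, 1.0]]
S, I, R = regional_sir_step([990.0, 1000.0], [10.0, 0.0], [0.0, 0.0],
                            population, coupling, R0=1.5, recovery_time=7.0)
```

With the identity matrix (coupling 0) the unseeded region never leaves its initial state. Any positive off-diagonal entry lets infections spill over from the seeded region.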
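The model uses RegionalSIRModel from pyro.contrib.epidemiology.models, whose constructor takes:
- A population vector giving the size of each region
- A square coupling matrix of relative transmission strengths between regions (diagonal = within-region, off-diagonal = between-region)
- The mean recovery time, i.e. the average duration of the infected state
- The observed data for each region (dataset["obs"] in main)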
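The coupling matrix is built by clamping an identity matrix from below at args.coupling. torch.eye(n).clamp(min=c) keeps 1 on the diagonal and raises every off-diagonal entry to c, which gives the same symmetric coupling between every pair of regions.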
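The same construction in plain Python lets you inspect the resulting matrix without torch. make_coupling is a hypothetical helper for illustration, not part of Pyro. It keeps the range check from Model, because a coupling above 1 would lift the diagonal as well.

```python
def make_coupling(num_regions, coupling):
    """Plain-Python analogue of torch.eye(num_regions).clamp(min=coupling)."""
    # Same range check as Model(); above 1 the clamp would also raise the diagonal.
    assert 0 <= coupling <= 1, coupling
    return [
        [max(1.0 if i == j else 0.0, coupling) for j in range(num_regions)]
        for i in range(num_regions)
    ]


print(make_coupling(3, 0.1))
# prints [[1.0, 0.1, 0.1], [0.1, 1.0, 0.1], [0.1, 0.1, 1.0]]
```

Coupling 0 gives independent regions (the identity). Coupling 1 gives a matrix of ones, so every region weights every other region as heavily as itself.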
Code Reference
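import pyro
import torch
from pyro.contrib.epidemiology.models import RegionalSIRModel

# generate_data, infer_mcmc, infer_svi and predict are defined elsewhere in regional.py.

def Model(args, data):
    # Coupling must lie in [0, 1]: 1 on the diagonal, args.coupling off it.
    assert 0 <= args.coupling <= 1, args.coupling
    population = torch.full((args.num_regions,), float(args.population))
    coupling = torch.eye(args.num_regions).clamp(min=args.coupling)
    return RegionalSIRModel(population, coupling, args.recovery_time, data)

def main(args):
    pyro.set_rng_seed(args.rng_seed)
    dataset = generate_data(args)  # simulated observations plus true S2I flows
    model = Model(args, dataset["obs"])
    # Dispatch on the chosen inference method.
    infer = {"mcmc": infer_mcmc, "svi": infer_svi}[args.infer]
    infer(args, model)
    # Forecast, passing the simulated new infections (S2I) as ground truth.
    predict(args, model, truth=dataset["S2I"])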
I/O Contract
| Parameter | Type | Description |
|---|---|---|
| -p / --population | int | Population per region (default: 1000) |
| -r / --num-regions | int | Number of coupled regions (default: 2) |
| -c / --coupling | float | Inter-region coupling strength in [0, 1] (default: 0.1) |
| -d / --duration | int | Number of observed days (default: 20) |
| -f / --forecast | int | Number of forecast days (default: 10) |
| -R0 | float | Basic reproduction number (default: 1.5) |
| --infer | str | Inference method, "mcmc" or "svi" (read by main as args.infer; the usage examples select it with --mcmc or --svi) |
| --haar | flag | Use Haar wavelet reparameterization |
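The table can be read as an argparse configuration. The parser below is a hypothetical reconstruction from the rows above, not the script's own parser. The table lists the method as an --infer string, while the usage examples pass --mcmc or --svi. This sketch follows the usage examples and stores the choice in infer, the attribute main() dispatches on. The "mcmc" default is an assumption. Flags that appear only in the usage examples (-n, -ss, --plot) are left out.

```python
import argparse


def unit_interval(text):
    """Parse a float and enforce the 0 <= coupling <= 1 range that Model() asserts."""
    value = float(text)
    if not 0.0 <= value <= 1.0:
        raise argparse.ArgumentTypeError(f"coupling must be in [0, 1], got {value}")
    return value


def build_parser():
    # Hypothetical parser mirroring the I/O table; not the script's actual code.
    parser = argparse.ArgumentParser(description="Regional SIR example (sketch)")
    parser.add_argument("-p", "--population", type=int, default=1000)
    parser.add_argument("-r", "--num-regions", type=int, default=2)
    parser.add_argument("-c", "--coupling", type=unit_interval, default=0.1)
    parser.add_argument("-d", "--duration", type=int, default=20)
    parser.add_argument("-f", "--forecast", type=int, default=10)
    parser.add_argument("-R0", dest="R0", type=float, default=1.5)
    parser.add_argument("--haar", action="store_true")
    # --mcmc / --svi both write to args.infer; the "mcmc" default is assumed.
    method = parser.add_mutually_exclusive_group()
    method.add_argument("--mcmc", dest="infer", action="store_const", const="mcmc")
    method.add_argument("--svi", dest="infer", action="store_const", const="svi")
    parser.set_defaults(infer="mcmc")
    return parser


args = build_parser().parse_args(["-r", "5", "--svi", "--haar"])
```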
Output:
- Median predicted new infections per region
- Per-region time series with 90% credible intervals
- MCMC energy trace or SVI convergence plot
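Per-region medians and 90% intervals like those listed above are usually computed from posterior predictive samples by taking the 5th, 50th and 95th percentiles at each day and region. The plain-Python sketch below assumes samples are nested as samples[draw][day][region]. quantile and summarize are hypothetical helpers and do not describe how predict() in regional.py computes its summaries.

```python
def quantile(values, q):
    """Quantile with linear interpolation between order statistics (0 <= q <= 1)."""
    xs = sorted(values)
    pos = q * (len(xs) - 1)
    lo = int(pos)                  # floor, since pos >= 0
    hi = min(lo + 1, len(xs) - 1)
    return xs[lo] + (xs[hi] - xs[lo]) * (pos - lo)


def summarize(samples, level=0.90):
    """Return summary[region][day] = (lower, median, upper).

    samples is nested as samples[draw][day][region]; the interval is the
    central `level` interval, i.e. the 5th to 95th percentile for level=0.90.
    """
    tail = (1.0 - level) / 2.0
    num_days, num_regions = len(samples[0]), len(samples[0][0])
    summary = []
    for r in range(num_regions):
        series = []
        for t in range(num_days):
            draws = [draw[t][r] for draw in samples]
            series.append((quantile(draws, tail),
                           quantile(draws, 0.5),
                           quantile(draws, 1.0 - tail)))
        summary.append(series)
    return summary


# 101 evenly spaced draws for 1 day and 2 regions keep the numbers easy to check.
samples = [[[float(k), 10.0 * k]] for k in range(101)]
lower, median, upper = summarize(samples)[1][0]  # region 1, day 0
```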
Usage Examples
# Run with 3 regions and MCMC
python regional.py -r 3 -p 1000 -c 0.1 --mcmc -n 200 --plot
# Run with SVI and Haar wavelets
python regional.py -r 5 --svi --haar -ss 5000 --plot
Related Pages
- Pyro_ppl_Pyro_SIR_Models - Single-population SIR/SEIR models