Implementation:Pyro ppl Pyro AIR Training
| Property | Value |
|---|---|
| Implementation Type | Pattern Doc |
| Source File | examples/air/main.py |
| Module | examples.air |
| Pyro Features | SVI, TraceGraph_ELBO, JitTraceGraph_ELBO, pyro.optim.Adam, poutine.trace, poutine.replay |
| Dataset | Multi-MNIST |
Overview
This file is the main training script for the Attend, Infer, Repeat (AIR) model applied to the multi-MNIST dataset. It orchestrates data loading, model construction, SVI-based optimization, evaluation, and visualization.
Key features of the training pipeline:
- Prior annealing: The z_pres prior probability is decayed over training (linearly or exponentially) to encourage the model to use fewer object slots as training progresses.
- Per-parameter learning rates: Baseline network parameters use a separate, typically higher learning rate (--baseline-learning-rate, default 1e-3) than all other model and guide parameters (--learning-rate, default 1e-4).
- Count accuracy evaluation: The script measures how accurately the model infers the number of digits in each image.
- Visdom visualization: Optional real-time visualization of inferred object locations and reconstructions.
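The prior-annealing feature above can be sketched as two schedules over the optimization step t. This is a minimal sketch, assuming that annealing holds the starting probability until a begin step, interpolates toward a final value over a fixed duration (linearly for lin, geometrically for exp) and then holds the final value. The helper names lin_decay, exp_decay and anneal and their argument order are illustrative assumptions, not a quote of main.py.

```python
def lin_decay(initial, final, begin, duration, t):
    """Linearly interpolate from `initial` to `final` over [begin, begin + duration)."""
    if t < begin:
        return initial
    if t < begin + duration:
        return initial + (final - initial) * (t - begin) / duration
    return final


def exp_decay(initial, final, begin, duration, t):
    """Geometrically interpolate from `initial` to `final` (both must be > 0)."""
    if initial <= 0 or final <= 0:
        raise ValueError("exponential decay needs positive endpoints")
    if t < begin:
        return initial
    if t < begin + duration:
        return initial * (final / initial) ** ((t - begin) / duration)
    return final


def anneal(strategy, initial, final, begin, duration, t):
    """Dispatch on a strategy name mirroring --anneal-prior: 'none', 'lin' or 'exp'."""
    if strategy == "none":
        return initial
    schedules = {"lin": lin_decay, "exp": exp_decay}
    return schedules[strategy](initial, final, begin, duration, t)


if __name__ == "__main__":
    # Hypothetical settings: start at 0.5, decay toward 1e-7 from step 1000 over 100000 steps.
    for step in (0, 1000, 51000, 101000, 200000):
        p = anneal("exp", 0.5, 1e-7, 1000, 100000, step)
        print(f"step {step:>6}: p(z_pres=1) = {p:.3e}")
```

Geometric interpolation moves through the orders of magnitude evenly, which suits a target as small as 1e-7; a linear schedule from 0.5 would still be at 0.25 halfway through and would only approach the target near the very end.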
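The training loop uses TraceGraph_ELBO (or its JIT-compiled variant, JitTraceGraph_ELBO, when args.jit is set) because the model contains discrete latent variables (z_pres) that cannot be reparameterized. Their gradients come from the score-function (REINFORCE) estimator, and TraceGraph_ELBO reduces its variance by exploiting the dependency structure of the model and guide (Rao-Blackwellization) and by supporting the learned baselines that AIR enables through use_baselines.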
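To see why a discrete z_pres needs a score-function estimator, consider the gradient of E[f(z)] with z ~ Bernoulli(p). The sketch below is a stdlib-only illustration, not Pyro code: it estimates d/dp E[f(z)] = E[(f(z) - b) * d/dp log Bernoulli(z; p)] once with b = 0 and once with a constant baseline b, and compares both with the exact value f(1) - f(0). It does not reproduce TraceGraph_ELBO's Rao-Blackwellization; f, p and b are arbitrary illustrative choices.

```python
import random
import statistics


def f(z):
    # Arbitrary downstream "reward": f(0) = 1.0, f(1) = 3.0.
    return 1.0 + 2.0 * z


def score(z, p):
    # d/dp log Bernoulli(z; p) = z / p - (1 - z) / (1 - p)
    return z / p - (1 - z) / (1 - p)


def reinforce_samples(p, baseline, n, rng):
    """Per-sample REINFORCE terms (f(z) - baseline) * score(z, p)."""
    terms = []
    for _ in range(n):
        z = 1 if rng.random() < p else 0
        terms.append((f(z) - baseline) * score(z, p))
    return terms


p = 0.3
exact = f(1) - f(0)  # d/dp [p * f(1) + (1 - p) * f(0)] = 2.0
rng = random.Random(0)
plain = reinforce_samples(p, baseline=0.0, n=100_000, rng=rng)
# A constant baseline keeps the estimator unbiased (E[score] = 0);
# choosing it equal to E[f(z)] shrinks the variance considerably.
with_bl = reinforce_samples(p, baseline=p * f(1) + (1 - p) * f(0), n=100_000, rng=rng)

print("exact gradient:        ", exact)
print("REINFORCE, no baseline:", statistics.fmean(plain), "var", statistics.pvariance(plain))
print("REINFORCE, baseline:   ", statistics.fmean(with_bl), "var", statistics.pvariance(with_bl))
```

The baseline here is the exact mean of f, which is only available in a toy problem; with use_baselines, AIR instead learns baseline networks, which is why those parameters get their own learning rate.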
Code Reference
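```python
import argparse
from functools import partial

import pyro.optim as optim
from pyro.infer import SVI, JitTraceGraph_ELBO, TraceGraph_ELBO

from air import AIR  # examples/air/air.py


def main(**kwargs):
    args = argparse.Namespace(**kwargs)
    # load_data, model_args, per_param_optim_args and z_pres_prior_p are
    # defined elsewhere in main.py and elided from this excerpt.
    X, true_counts = load_data()
    air = AIR(
        num_steps=args.model_steps,  # number of AIR inference steps
        x_size=50,  # multi-MNIST images are 50x50 pixels
        z_what_size=args.encoder_latent_size,
        use_masking=not args.no_masking,
        use_baselines=not args.no_baselines,
        **model_args
    )
    # Adam with per-parameter options (separate baseline learning rate);
    # the JIT-compiled ELBO is optional.
    adam = optim.Adam(per_param_optim_args)
    elbo = JitTraceGraph_ELBO() if args.jit else TraceGraph_ELBO()
    svi = SVI(air.model, air.guide, adam, loss=elbo)

    for i in range(1, args.num_steps + 1):
        # Bind the current optimization step i so the z_pres prior can be annealed.
        loss = svi.step(X, batch_size=args.batch_size,
                        z_pres_prior_p=partial(z_pres_prior_p, i))
```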
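The excerpt relies on two helpers it does not show: the per-parameter options callable passed to optim.Adam and the step-dependent z_pres_prior_p. The sketch below is a hypothetical, stdlib-only reconstruction of their shape. It assumes the callable receives a single parameter name, that baseline parameters can be recognized by a "bl_" substring in that name, and that the prior function takes (opt_step, time_step); the parameter names and the placeholder schedule are illustrative only. It shows how functools.partial fixes the optimization step so that the model only has to supply the AIR time step.

```python
from functools import partial

LEARNING_RATE = 1e-4           # --learning-rate default from the I/O table
BASELINE_LEARNING_RATE = 1e-3  # --baseline-learning-rate default from the I/O table


def per_param_optim_args(param_name):
    # Assumption: baseline-network parameter names contain "bl_".
    if "bl_" in param_name:
        return {"lr": BASELINE_LEARNING_RATE}
    return {"lr": LEARNING_RATE}


def z_pres_prior_p(opt_step, time_step):
    # Hypothetical placeholder schedule that depends only on the optimization
    # step; the real schedule is selected by --anneal-prior.
    return 0.5 / (1 + opt_step / 10_000)


# Inside the training loop, step i is bound once; the model then calls the
# resulting one-argument function with each AIR time step t.
i = 51000
prior_at_step = partial(z_pres_prior_p, i)
print(per_param_optim_args("bl_predict.weight"), per_param_optim_args("encode.weight"))
print([prior_at_step(t) for t in range(3)])
```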
I/O Contract
| Parameter | Type | Description |
|---|---|---|
| --num-steps | int | Number of optimization steps (default: 1e8) |
| --batch-size | int | Batch size (default: 64) |
| --learning-rate | float | Learning rate for all non-baseline parameters (default: 1e-4) |
| --baseline-learning-rate | float | Learning rate for baseline network parameters (default: 1e-3) |
| --model-steps | int | Number of AIR time steps (default: 3) |
| --z-pres-prior | float | Prior success probability for z_pres (default: 0.5) |
| --anneal-prior | str | Prior annealing strategy: one of none, lin, exp |
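The options in the table map onto an argparse parser whose parsed Namespace is unpacked into main(**kwargs). The sketch below covers only the options listed above, the --viz flag from the output section and the short aliases -n, -b and -lr from the command-line usage example; the real script defines more options, and the "none" default for --anneal-prior is an assumption.

```python
import argparse


def build_parser():
    # Only the options documented on this page; main.py defines many more.
    parser = argparse.ArgumentParser(description="AIR training (partial sketch)")
    parser.add_argument("-n", "--num-steps", type=int, default=int(1e8))
    parser.add_argument("-b", "--batch-size", type=int, default=64)
    parser.add_argument("-lr", "--learning-rate", type=float, default=1e-4)
    parser.add_argument("--baseline-learning-rate", type=float, default=1e-3)
    parser.add_argument("--model-steps", type=int, default=3)
    parser.add_argument("--z-pres-prior", type=float, default=0.5)
    # Assumed default: no annealing unless requested.
    parser.add_argument("--anneal-prior", choices=["none", "lin", "exp"], default="none")
    parser.add_argument("--viz", action="store_true")
    return parser


args = build_parser().parse_args(["-n", "200000", "--anneal-prior", "exp", "--viz"])
kwargs = vars(args)  # dest names use underscores: num_steps, anneal_prior, ...
print(kwargs)
# main(**kwargs) would then rebuild the Namespace with argparse.Namespace(**kwargs).
```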
Output:
- Trained AIR model parameters (optionally saved to file)
- Count accuracy metrics printed to stdout
- Visdom visualizations (if the --viz flag is set)
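Count accuracy compares the number of objects the guide infers for each image with the true digit count. This is a minimal stdlib sketch, assuming the inferred count for an image is the sum of its per-step z_pres samples (each 0 or 1) and that results are tallied in a true-count by inferred-count confusion matrix; the function name, the max_true=2 bound (at most two digits per image) and the toy data are hypothetical.

```python
def count_accuracy(true_counts, z_pres_per_image, max_true=2, max_steps=3):
    """Return (accuracy, confusion) where confusion[true][inferred] counts images."""
    if len(true_counts) != len(z_pres_per_image):
        raise ValueError("size mismatch")
    # Assumption: true counts lie in 0..max_true, inferred counts in 0..max_steps.
    confusion = [[0] * (max_steps + 1) for _ in range(max_true + 1)]
    for true, z_pres in zip(true_counts, z_pres_per_image):
        inferred = sum(z_pres)  # one 0/1 z_pres sample per AIR time step
        confusion[true][inferred] += 1
    correct = sum(confusion[k][k] for k in range(max_true + 1))
    return correct / len(true_counts), confusion


true_counts = [0, 1, 2, 2, 1]
z_pres = [[0, 0, 0], [1, 0, 0], [1, 1, 0], [1, 0, 0], [1, 1, 1]]
acc, confusion = count_accuracy(true_counts, z_pres)
print(f"count accuracy: {acc:.2f}")
for k, row in enumerate(confusion):
    print(k, row)
```

The off-diagonal entries show which way the model errs: here one two-digit image is under-counted and one one-digit image is over-counted.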
Usage Examples
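```python
# Run training from the command line (from the examples/air directory):
# python main.py -n 200000 -b 64 -lr 1e-4 --anneal-prior exp --viz

# Or call main() programmatically. main() wraps its keyword arguments in an
# argparse.Namespace and reads options as attributes, so every option it
# accesses must be passed; the full script may read more than shown here.
from main import main

main(
    num_steps=200000,
    batch_size=64,
    learning_rate=1e-4,
    baseline_learning_rate=1e-3,
    model_steps=3,
    encoder_latent_size=50,
    window_size=28,
    rnn_hidden_size=256,
    z_pres_prior=0.5,
    anneal_prior="exp",
    anneal_prior_to=1e-7,
    anneal_prior_begin=1000,
    anneal_prior_duration=100000,
    seed=42,
    no_masking=False,  # read by main() when building AIR
    no_baselines=False,  # read by main() when building AIR
    jit=False,  # selects TraceGraph_ELBO over JitTraceGraph_ELBO
)
```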
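For the annealing settings in this example, and assuming the exp strategy interpolates geometrically from a starting probability $p_0$ to $p_T$ = anneal_prior_to over $T$ = anneal_prior_duration steps beginning at $t_b$ = anneal_prior_begin, the z_pres prior at optimization step $t$ works out as follows. Taking $p_0 = 0.5$ (the --z-pres-prior default) purely for illustration:

```latex
p(t) =
\begin{cases}
p_0, & t < t_b \\
p_0 \left( \dfrac{p_T}{p_0} \right)^{(t - t_b)/T}, & t_b \le t < t_b + T \\
p_T, & t \ge t_b + T
\end{cases}
\qquad t_b = 1000,\; T = 100000,\; p_T = 10^{-7}

p(51000) = 0.5 \left( \frac{10^{-7}}{0.5} \right)^{1/2}
         = \sqrt{0.5 \times 10^{-7}}
         \approx 2.24 \times 10^{-4}
```

Halfway through the schedule the prior is the geometric mean of its endpoints, and from step 101000 onward it stays at $10^{-7}$.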
Related Pages
- Pyro_ppl_Pyro_AIR_Model - The AIR model and guide implementation
- Pyro_ppl_Pyro_AIR_Modules - Neural network modules used by AIR