Implementation:Pyro ppl Pyro AIR Training

From Leeroopedia


Property | Value
Implementation Type | Pattern Doc
Source File | examples/air/main.py
Module | examples.air
Pyro Features | SVI, TraceGraph_ELBO, JitTraceGraph_ELBO, pyro.optim.Adam, poutine.trace, poutine.replay
Dataset | Multi-MNIST

Overview

This file is the main training script for the Attend, Infer, Repeat (AIR) model applied to the multi-MNIST dataset. It orchestrates data loading, model construction, SVI-based optimization, evaluation, and visualization.

Key features of the training pipeline:

  • Prior annealing: The z_pres prior probability is decayed over training (linearly or exponentially) to encourage the model to use fewer object slots as training progresses.
  • Per-parameter learning rates: Baseline network parameters use their own learning rate (--baseline-learning-rate, default 1e-3), which is typically higher than the rate used for the other parameters (--learning-rate, default 1e-4).
  • Count accuracy evaluation: The script measures how accurately the model infers the number of digits in each image.
  • Visdom visualization: Optional real-time visualization of inferred object locations and reconstructions.
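
The prior annealing schedules mentioned above can be sketched as clamped decay functions. This is a minimal illustration, assuming the decay starts at step `begin`, runs for `duration` steps, and that `initial` is larger than `final`; the function names and signatures are illustrative and may differ from what main.py defines.

```python
import math


def lin_decay(initial, final, begin, duration, t):
    """Move linearly from `initial` to `final` between steps begin and begin + duration."""
    assert duration > 0
    x = initial + (final - initial) * (t - begin) / duration
    # Clamp: hold `initial` before `begin` and `final` once the window has passed.
    return max(min(x, initial), final)


def exp_decay(initial, final, begin, duration, t):
    """Decay exponentially from `initial` to `final` over the same window."""
    assert final > 0 and duration > 0
    decay_rate = math.log(initial / final) / duration
    x = initial * math.exp(-decay_rate * (t - begin))
    return max(min(x, initial), final)


# Example: anneal a z_pres prior probability from 0.5 down to 1e-7.
for step in (0, 1000, 51000, 101000, 200000):
    print(step,
          lin_decay(0.5, 1e-7, 1000, 100000, step),
          exp_decay(0.5, 1e-7, 1000, 100000, step))
```

Halfway through the window the exponential schedule is already far below the linear one, so it pushes the model toward fewer object slots much earlier in training.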

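The training loop uses TraceGraph_ELBO (or its JIT variant, JitTraceGraph_ELBO) because the model contains discrete latent variables (z_pres) that cannot be reparameterized, so their gradients must be estimated with a REINFORCE-style (score-function) estimator. TraceGraph_ELBO reduces the variance of that estimator by exploiting the dependency structure of the model and guide (Rao-Blackwellization) and by supporting baselines.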
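
To make the role of baselines concrete, here is a toy calculation for a single Bernoulli variable. It is not Pyro's implementation: the `cost` function and the baseline value are made up for illustration, and the mean and variance are computed exactly by enumerating both outcomes rather than by sampling.

```python
def score(z, p):
    """Derivative of log Bernoulli(z; p) with respect to p."""
    return 1.0 / p if z == 1 else -1.0 / (1.0 - p)


def estimator_stats(f, p, baseline=0.0):
    """Exact mean and variance of the estimator (f(z) - baseline) * score(z, p)."""
    probs = {1: p, 0: 1.0 - p}
    values = {z: (f(z) - baseline) * score(z, p) for z in (0, 1)}
    mean = sum(probs[z] * values[z] for z in (0, 1))
    var = sum(probs[z] * (values[z] - mean) ** 2 for z in (0, 1))
    return mean, var


def cost(z):
    # Hypothetical downstream cost that depends on the discrete choice z.
    return 12.0 if z == 1 else 10.0


p = 0.5
true_grad = cost(1) - cost(0)  # d/dp of E[cost(z)] for z ~ Bernoulli(p)
plain = estimator_stats(cost, p)
with_baseline = estimator_stats(cost, p, baseline=11.0)
print(true_grad, plain, with_baseline)
```

Both estimators have the same mean as the true gradient, but subtracting a baseline close to the typical cost shrinks the variance sharply, which is the motivation for training baseline networks alongside the guide.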

Code Reference

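# Excerpt from main(). load_data, model_args, per_param_optim_args and
# z_pres_prior_p are defined elsewhere in main.py and are omitted here.
import argparse
from functools import partial

import pyro.optim as optim
from pyro.infer import SVI, JitTraceGraph_ELBO, TraceGraph_ELBO

from air import AIR


def main(**kwargs):
    args = argparse.Namespace(**kwargs)
    X, true_counts = load_data()  # multi-MNIST images and per-image digit counts

    air = AIR(
        num_steps=args.model_steps,
        x_size=50,
        z_what_size=args.encoder_latent_size,
        use_masking=not args.no_masking,
        use_baselines=not args.no_baselines,
        **model_args
    )

    # per_param_optim_args selects the optimizer settings for each parameter,
    # which is how baseline parameters get their own learning rate.
    adam = optim.Adam(per_param_optim_args)
    # TraceGraph_ELBO handles the non-reparameterizable z_pres sites.
    elbo = JitTraceGraph_ELBO() if args.jit else TraceGraph_ELBO()
    svi = SVI(air.model, air.guide, adam, loss=elbo)

    for i in range(1, args.num_steps + 1):
        # Bind the current step i so the z_pres prior can be annealed over training.
        loss = svi.step(X, batch_size=args.batch_size,
                        z_pres_prior_p=partial(z_pres_prior_p, i))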
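
The excerpt passes `per_param_optim_args` to the optimizer without showing how it is built. One way to build it is as a callable from parameter name to optimizer arguments, sketched below. The sketch assumes baseline parameters can be recognized by a substring in their name; the `bl_` tag, the factory function, and the parameter names are illustrative assumptions.

```python
def make_per_param_optim_args(learning_rate, baseline_learning_rate, baseline_tag="bl_"):
    """Return a callable mapping a parameter name to optimizer keyword arguments."""

    def per_param_optim_args(param_name):
        # Assumption: baseline network parameters carry `baseline_tag` in their name.
        is_baseline = baseline_tag in param_name
        return {"lr": baseline_learning_rate if is_baseline else learning_rate}

    return per_param_optim_args


per_param_optim_args = make_per_param_optim_args(1e-4, 1e-3)
print(per_param_optim_args("bl_rnn.weight"))   # hypothetical baseline parameter name
print(per_param_optim_args("encoder.weight"))  # hypothetical non-baseline parameter name
```

Pyro optimizers accept either a fixed dictionary of arguments or a callable like this one, so the result can be handed to pyro.optim.Adam in place of a single global learning rate.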

I/O Contract

Parameter | Type | Description
--num-steps | int | Number of optimization steps (default: 1e8)
--batch-size | int | Batch size (default: 64)
--learning-rate | float | Learning rate (default: 1e-4)
--baseline-learning-rate | float | Baseline LR (default: 1e-3)
--model-steps | int | Number of AIR time steps (default: 3)
--z-pres-prior | float | Prior success probability for z_pres (default: 0.5)
--anneal-prior | str | Prior annealing strategy: none, lin, exp
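
The parameters above map naturally onto an argparse parser. The following is a sketch reconstructed from this table only, not a copy of the script's parser: the short flags are taken from the command-line example on this page, and the `none` default for --anneal-prior is an assumption.

```python
import argparse


def build_parser():
    parser = argparse.ArgumentParser(description="Sketch of the AIR training options")
    parser.add_argument("-n", "--num-steps", type=int, default=int(1e8))
    parser.add_argument("-b", "--batch-size", type=int, default=64)
    parser.add_argument("-lr", "--learning-rate", type=float, default=1e-4)
    parser.add_argument("--baseline-learning-rate", type=float, default=1e-3)
    parser.add_argument("--model-steps", type=int, default=3)
    parser.add_argument("--z-pres-prior", type=float, default=0.5)
    # Assumption: "none" is the default strategy; the table does not state one.
    parser.add_argument("--anneal-prior", choices=["none", "lin", "exp"], default="none")
    return parser


args = build_parser().parse_args(["-n", "200000", "--anneal-prior", "exp"])
print(vars(args))
```

argparse turns dashes into underscores, so --num-steps becomes `args.num_steps`. Because main() takes keyword arguments, a parsed namespace could be forwarded as `main(**vars(args))`.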

Output:

  • Trained AIR model parameters (optionally saved to file)
  • Count accuracy metrics printed to stdout
  • Visdom visualizations (if --viz flag is set)
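
The count accuracy metric can be illustrated with a small helper. This is a simplified sketch, assuming the inferred count for each image has already been obtained (for example by summing the sampled z_pres values per image); the function name and the sample data are hypothetical.

```python
def count_accuracy(inferred_counts, true_counts):
    """Fraction of images whose inferred object count matches the ground truth."""
    if len(inferred_counts) != len(true_counts):
        raise ValueError("count lists must have the same length")
    if not true_counts:
        raise ValueError("need at least one image")
    correct = sum(int(a == b) for a, b in zip(inferred_counts, true_counts))
    return correct / len(true_counts)


# Hypothetical inferred and ground-truth digit counts for five images.
inferred = [2, 1, 0, 2, 1]
truth = [2, 1, 1, 2, 0]
print(count_accuracy(inferred, truth))
```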

Usage Examples

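# Run training from the command line
# python main.py -n 200000 -b 64 -lr 1e-4 --anneal-prior exp --viz

# Or call main() programmatically. main() wraps its keyword arguments in an
# argparse.Namespace without applying the CLI defaults, so every option it
# reads must be passed explicitly; a missing one raises AttributeError.
# The call below covers only the options named on this page and may need
# further keywords from the script's argument parser.
from main import main
main(
    num_steps=200000,
    batch_size=64,
    learning_rate=1e-4,
    baseline_learning_rate=1e-3,
    model_steps=3,
    z_pres_prior=0.5,
    encoder_latent_size=50,
    window_size=28,
    rnn_hidden_size=256,
    anneal_prior="exp",
    anneal_prior_to=1e-7,
    anneal_prior_begin=1000,
    anneal_prior_duration=100000,
    no_masking=False,
    no_baselines=False,
    jit=False,
    seed=42,
)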
