

Implementation:Pyro ppl Pyro AIR Model

From Leeroopedia
Revision as of 16:22, 16 February 2026 by Admin (auto-imported from implementations/Pyro_ppl_Pyro_AIR_Model.md)


Property             Value
Implementation Type  Pattern Doc
Source File          examples/air/air.py
Module               examples.air
Pyro Features        pyro.sample, pyro.module, pyro.param, pyro.plate, Bernoulli/Normal distributions, spatial transformers
Paper                Eslami et al., "Attend, Infer, Repeat: Fast Scene Understanding with Generative Models" (NeurIPS 2016)

Overview

This file implements the Attend, Infer, Repeat (AIR) model, a structured deep generative model for scene understanding. AIR decomposes an image into a variable number of objects by repeatedly attending to a region of the image, inferring a latent description of the object found there, and rendering each object back into image space.

The AIR class extends nn.Module and defines both the model (the generative process) and the guide (the inference network). The model runs for a fixed maximum number of steps (num_steps), sampling at each step:

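  • z_pres (Bernoulli): Whether an object is present. Its probability is multiplied by the previous step's z_pres, so once a step is absent all later steps are absent too; this is how the number of objects varies between 0 and num_steps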
  • z_where (Normal): The position and scale of the attention window (3D: scale, x, y)
  • z_what (Normal): The latent code describing the object appearance
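The per-step sampling order can be sketched with the Python standard library alone. This is a simplified, hypothetical stand-in for the model's prior: the function name sample_prior_latents is made up, no image is rendered, and the prior means and standard deviations used for z_where are illustrative assumptions. It assumes each step's presence probability is multiplied by the previous z_pres, so the first absent step switches off every later one.

```python
import random


def sample_prior_latents(num_steps, z_what_size, z_pres_prior_p, rng):
    """Hypothetical sketch of AIR-style per-step prior sampling (no rendering)."""
    steps = []
    prev_z_pres = 1
    for t in range(num_steps):
        # z_pres_t ~ Bernoulli(p_t * z_pres_{t-1}): once it is 0 it stays 0
        p = z_pres_prior_p(t) * prev_z_pres
        z_pres = 1 if rng.random() < p else 0
        # z_where_t = (scale, x, y); these prior values are illustrative assumptions
        z_where = (rng.gauss(3.0, 0.1), rng.gauss(0.0, 1.0), rng.gauss(0.0, 1.0))
        # z_what_t: appearance code, here a standard normal vector
        z_what = [rng.gauss(0.0, 1.0) for _ in range(z_what_size)]
        steps.append({"z_pres": z_pres, "z_where": z_where, "z_what": z_what})
        prev_z_pres = z_pres
    return steps


latents = sample_prior_latents(3, 4, lambda t: 0.5, random.Random(0))
num_objects = sum(step["z_pres"] for step in latents)
print([step["z_pres"] for step in latents], num_objects)
```

As in the description above, z_where and z_what are still drawn at absent steps; only z_pres decides whether they contribute.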

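The guide uses an LSTM-based recurrent network. At each step the LSTM consumes the image together with the previous step's latents; the predict network maps its hidden state to the variational parameters of z_pres and z_where, and the encode network maps the window extracted at z_where to the parameters of z_what. A separate baseline network estimates the REINFORCE baseline for the discrete z_pres variable, which reduces the variance of its gradient estimates.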
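Why a baseline helps for the discrete z_pres sites can be seen on a one-variable toy problem. The sketch below estimates d/dp E[f(z)] for z ~ Bernoulli(p) with the score-function (REINFORCE) estimator, once without a baseline and once with a constant baseline subtracted from f(z). The cost f, the value of p, and the choice baseline = E[f(z)] are illustrative assumptions; AIR instead predicts its baseline with a learned network.

```python
import random
import statistics


def score_function_estimates(f, p, baseline, num_samples, rng):
    """REINFORCE estimates of d/dp E[f(z)] for z ~ Bernoulli(p)."""
    estimates = []
    for _ in range(num_samples):
        z = 1 if rng.random() < p else 0
        # d/dp log Bernoulli(z; p) = z/p - (1 - z)/(1 - p)
        score = z / p - (1 - z) / (1 - p)
        # A constant baseline keeps the estimator unbiased because E[score] = 0,
        # but it can shrink the variance dramatically
        estimates.append((f(z) - baseline) * score)
    return estimates


def f(z):
    return 10.0 + z  # toy cost; the exact gradient d/dp E[f] is 1


p = 0.3
plain = score_function_estimates(f, p, 0.0, 20000, random.Random(0))
with_bl = score_function_estimates(f, p, 10.0 + p, 20000, random.Random(0))

print(statistics.mean(plain), statistics.variance(plain))
print(statistics.mean(with_bl), statistics.variance(with_bl))
```

Both estimators average to roughly 1, but the baselined one has far lower variance, which is what makes learning the z_pres parameters practical.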

Spatial transformer functions (window_to_image, image_to_window) convert between attention windows and the full image space using affine grid sampling.
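Both directions can be sketched with PyTorch's F.affine_grid and F.grid_sample. This is an illustrative re-implementation that assumes the [s, x, y] parameterization of z_where; it mirrors the function names used in the text but is not copied from air.py, and the helpers expand_z_where and z_where_inv, as well as the explicit align_corners=False, are choices made here.

```python
import torch
import torch.nn.functional as F


def expand_z_where(z_where):
    """Map a batch of [s, x, y] rows to 2x3 affine matrices [[s, 0, x], [0, s, y]]."""
    s, x, y = z_where.unbind(dim=1)
    zero = torch.zeros_like(s)
    return torch.stack([s, zero, x, zero, s, y], dim=1).view(-1, 2, 3)


def z_where_inv(z_where):
    """Parameters of the inverse transform: [s, x, y] -> [1/s, -x/s, -y/s]."""
    s = z_where[:, :1]
    return torch.cat([1.0 / s, -z_where[:, 1:] / s], dim=1)


def window_to_image(z_where, window_size, image_size, windows):
    """Render flattened windows [N, window_size**2] into images [N, image_size, image_size]."""
    n = windows.size(0)
    # affine_grid maps each output (image) location to an input (window) location
    theta = expand_z_where(z_where)
    grid = F.affine_grid(theta, torch.Size((n, 1, image_size, image_size)), align_corners=False)
    out = F.grid_sample(windows.view(n, 1, window_size, window_size), grid, align_corners=False)
    return out.view(n, image_size, image_size)


def image_to_window(z_where, window_size, image_size, images):
    """Extract flattened windows [N, window_size**2] from images [N, image_size, image_size]."""
    n = images.size(0)
    # The inverse transform maps window locations back to image locations
    theta_inv = expand_z_where(z_where_inv(z_where))
    grid = F.affine_grid(theta_inv, torch.Size((n, 1, window_size, window_size)), align_corners=False)
    out = F.grid_sample(images.view(n, 1, image_size, image_size), grid, align_corners=False)
    return out.view(n, -1)
```

Because affine_grid maps output coordinates to input coordinates, a window rendered with scale s covers about 1/s of the image width, so a scale near 3 places an object in roughly a third of the canvas.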

Code Reference

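import torch
import torch.nn as nn
import pyro
import pyro.distributions as dist

class AIR(nn.Module):
    def __init__(self, num_steps, x_size, window_size, z_what_size,
                 rnn_hidden_size, ...):
        # Configures prior parameters, encoder/decoder networks, RNN, baseline networks
        ...

    def model(self, data, batch_size, **kwargs):
        # Register the decoder so its parameters are trained
        pyro.module("decode", self.decode)
        # Subsample batch_size images from the full data set
        with pyro.plate("data", data.size(0), subsample_size=batch_size,
                        device=data.device) as ix:
            batch = data[ix]
            n = batch.size(0)
            # Sample per-step latents and the rendered image x from the prior
            (z_where, z_pres), x = self.prior(n, **kwargs)
            # Gaussian pixel likelihood; to_event(1) treats all x_size**2 pixels as one event
            pyro.sample("obs", dist.Normal(x.view(n, -1),
                self.likelihood_sd * torch.ones(n, self.x_size**2, **self.options)
            ).to_event(1), obs=batch.view(n, -1))

    def guide(self, data, batch_size, **kwargs):
        pyro.module("rnn", self.rnn)
        pyro.module("predict", self.predict)
        pyro.module("encode", self.encode)
        # Registers all neural network components with Pyro
        # Iteratively processes image through LSTM, sampling z_pres, z_where, z_what
        # Returns z_where, z_pres: per-step lists of sampled latents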
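The observation term in model is a diagonal Gaussian over all x_size**2 pixels. The sketch below uses plain torch.distributions, where Independent(..., 1) plays the role of Pyro's .to_event(1); random tensors stand in for the rendered image x and the observed batch, and the sizes are illustrative assumptions. It checks that the event-level log-probability is the sum of per-pixel Gaussian log densities.

```python
import math

import torch
from torch.distributions import Independent, Normal

torch.manual_seed(0)
n, x_size, likelihood_sd = 2, 4, 0.3   # illustrative sizes
x = torch.rand(n, x_size, x_size)      # stand-in for the image rendered by the prior
batch = torch.rand(n, x_size, x_size)  # stand-in for the observed images

# Diagonal Gaussian over flattened pixels; Independent(..., 1) moves the pixel
# dimension into the event shape, like Pyro's .to_event(1)
obs_dist = Independent(Normal(x.view(n, -1), likelihood_sd * torch.ones(n, x_size**2)), 1)
log_lik = obs_dist.log_prob(batch.view(n, -1))  # one log-likelihood per image, shape [n]

# The same quantity written out as a sum of per-pixel Gaussian log densities
diff = batch.view(n, -1) - x.view(n, -1)
closed_form = (-0.5 * (diff / likelihood_sd) ** 2
               - math.log(likelihood_sd) - 0.5 * math.log(2 * math.pi)).sum(dim=1)
```

Summing over pixels inside the event, rather than leaving them as batch dimensions, is what lets the data plate treat each image as a single observation.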

I/O Contract

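Parameter        Type          Description
data             torch.Tensor  Full image data set, shape [N, x_size, x_size]; the data plate draws batch_size of the N images per call
batch_size       int           Subsample size for the data plate
z_pres_prior_p   callable      Model keyword argument, passed through **kwargs to self.prior: a function t -> float returning the prior probability of z_pres at step t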
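Assuming z_pres stays off once it has switched off, the values returned by z_pres_prior_p fix a prior over the number of objects. One hypothetical way to pick them is to target a count distribution proportional to k**c for c = 0..num_steps; the helpers make_z_pres_prior_p and implied_count_distribution, and the parameter k, are illustrative and not part of air.py.

```python
def make_z_pres_prior_p(k, num_steps):
    """Hypothetical schedule with P(count = c) proportional to k**c, c = 0..num_steps (k > 0)."""
    weights = [k ** c for c in range(num_steps + 1)]
    total = sum(weights)
    target = [w / total for w in weights]
    # survival[t] = P(count >= t)
    survival = [sum(target[t:]) for t in range(num_steps + 1)]
    # p_t = P(object at step t | objects at steps 0..t-1) = P(count >= t+1) / P(count >= t)
    probs = [survival[t + 1] / survival[t] for t in range(num_steps)]
    return lambda t: probs[t]


def implied_count_distribution(z_pres_prior_p, num_steps):
    """P(count = c) when every step after the first absent one is also absent."""
    probs_by_count, alive = [], 1.0
    for t in range(num_steps):
        probs_by_count.append(alive * (1.0 - z_pres_prior_p(t)))
        alive *= z_pres_prior_p(t)
    probs_by_count.append(alive)
    return probs_by_count


prior_p = make_z_pres_prior_p(k=0.5, num_steps=3)
print([round(prior_p(t), 3) for t in range(3)])
print(implied_count_distribution(prior_p, 3))
```

Small k favors few objects and k = 1 makes every count from 0 to num_steps equally likely.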

Outputs (from guide):

  • z_where: List of tensors [N, 3] (scale, x, y) per time step
  • z_pres: List of tensors [N, 1] (presence indicators) per time step
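Downstream code often needs per-image summaries of these lists. A small sketch, assuming guide outputs shaped as listed above (count_objects and stack_z_where are hypothetical helpers, and the tensors are hand-made stand-ins): the number of objects inferred for each image is the sum of its z_pres indicators.

```python
import torch


def count_objects(z_pres):
    """Per-image object counts from a list of [N, 1] presence tensors (one per step)."""
    return torch.cat(z_pres, dim=1).sum(dim=1)


def stack_z_where(z_where):
    """Stack a list of [N, 3] (scale, x, y) tensors into one [N, T, 3] tensor."""
    return torch.stack(z_where, dim=1)


# Hand-made stand-in for guide output: N = 2 images, T = 3 steps
z_pres = [torch.tensor([[1.0], [1.0]]),
          torch.tensor([[1.0], [0.0]]),
          torch.tensor([[0.0], [0.0]])]
z_where = [torch.randn(2, 3) for _ in range(3)]

counts = count_objects(z_pres)      # image 0 has 2 objects, image 1 has 1
all_where = stack_z_where(z_where)  # shape [2, 3, 3]
# Keep only the windows of present objects, e.g. for drawing bounding boxes
present_where = all_where[torch.cat(z_pres, dim=1).bool()]
```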

Named sample sites:

  • z_pres_{t}: Bernoulli presence variable at step t
  • z_where_{t}: Normal attention window parameters at step t
  • z_what_{t}: Normal latent appearance code at step t
  • obs: Normal observation likelihood

Usage Examples

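import torch
from pyro.infer import SVI, TraceGraph_ELBO
from pyro.optim import Adam

from air import AIR  # examples/air/air.py; run from that directory

# Placeholder data for illustration only: 1000 random 50x50 "images" in [0, 1]
X = torch.rand(1000, 50, 50)

# Create the AIR model
air = AIR(
    num_steps=3,
    x_size=50,
    window_size=28,
    z_what_size=50,
    rnn_hidden_size=256,
    encoder_net=[200],
    decoder_net=[200],
    use_masking=True,
    use_baselines=True,
    likelihood_sd=0.3,
)

# Use with Pyro SVI; TraceGraph_ELBO uses the guide's baselines to reduce
# gradient variance for the discrete z_pres sites (learning rate is illustrative)
adam = Adam({"lr": 1e-4})
svi = SVI(air.model, air.guide, adam, loss=TraceGraph_ELBO())
loss = svi.step(X, batch_size=64)

# Sample from the prior: returns ((z_where, z_pres), x)
(z_where, z_pres), x = air.prior(5)

# Run the guide to get inferred latents for a subsample of the data
X_batch = X[:64]
z_where, z_pres = air.guide(X_batch, batch_size=64)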
