Implementation:Pyro ppl Pyro AIR Model
| Property | Value |
|---|---|
| Implementation Type | Pattern Doc |
| Source File | examples/air/air.py |
| Module | examples.air |
| Pyro Features | pyro.sample, pyro.module, pyro.param, pyro.plate, Bernoulli/Normal distributions, spatial transformers |
| Paper | Eslami et al., "Attend, Infer, Repeat: Fast scene understanding with generative models" (NeurIPS 2016) |
Overview
This file implements the Attend, Infer, Repeat (AIR) model, a structured deep generative model for scene understanding. AIR decomposes an image into a variable number of objects by iteratively attending to parts of the image, inferring their latent descriptions, and rendering them back.
The AIR class extends nn.Module and defines both the model (generative process) and the guide (inference network). The model runs for a fixed, configurable number of steps (num_steps) and samples the following at each step:
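- z_pres (Bernoulli): Whether an object is present. The presence probability is multiplied by the previous step's z_pres, so once a step samples 0, all later steps are 0 as well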
- z_where (Normal): The position and scale of the attention window (3D: scale, x, y)
- z_what (Normal): The latent code describing the object appearance
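The per-step sampling described above can be sketched without Pyro as a plain Python loop. This is a minimal illustration, not the code in air.py: the prior constants (scale mean 3.0 and sd 0.1, position mean 0.0 and sd 1.0, a 4-dimensional z_what) and the function name sample_prior_latents are hypothetical choices made only for this sketch. It shows the key structural property: presence is gated by the previous step, so the number of objects is the length of the leading run of ones.

```python
import random

# Hypothetical prior settings, chosen for illustration only (not read from air.py).
NUM_STEPS = 3
Z_WHAT_SIZE = 4
SCALE_MEAN, SCALE_SD = 3.0, 0.1
POS_MEAN, POS_SD = 0.0, 1.0


def sample_prior_latents(num_steps, z_pres_prior_p, rng):
    """Sample (z_pres, z_where, z_what) at every step for a single image."""
    steps = []
    prev_pres = 1
    for t in range(num_steps):
        # Presence is gated by the previous step: once it is 0, it stays 0.
        p = z_pres_prior_p(t) * prev_pres
        z_pres = 1 if rng.random() < p else 0
        # Attention window (scale, x, y) drawn from independent Normals.
        # Sampled at every step; a renderer would only draw windows with z_pres == 1.
        z_where = (rng.gauss(SCALE_MEAN, SCALE_SD),
                   rng.gauss(POS_MEAN, POS_SD),
                   rng.gauss(POS_MEAN, POS_SD))
        # Appearance code drawn from a standard Normal.
        z_what = [rng.gauss(0.0, 1.0) for _ in range(Z_WHAT_SIZE)]
        steps.append((z_pres, z_where, z_what))
        prev_pres = z_pres
    return steps


rng = random.Random(0)
latents = sample_prior_latents(NUM_STEPS, lambda t: 0.5, rng)
num_objects = sum(z_pres for z_pres, _, _ in latents)
```

In air.py the same three draws are pyro.sample sites inside the data plate, which is what lets the guide's proposals be matched to them by name.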
The guide uses an LSTM-based recurrent network that processes the image and produces variational parameters for each latent variable. A separate baseline network estimates the REINFORCE baseline for the discrete z_pres variable.
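Why a baseline helps with the discrete z_pres sites can be checked exactly for a single Bernoulli variable. The sketch below computes the REINFORCE (score-function) gradient estimator with respect to the Bernoulli probability p, summing over both outcomes instead of sampling. Assumptions: the baseline is a fixed scalar rather than the output of a learned baseline network, and the function names are hypothetical. Subtracting any constant baseline leaves the expected gradient unchanged (it is f(1) - f(0) either way), while a baseline close to f(z) sharply reduces the estimator's variance.

```python
def bernoulli_score(z, p):
    """d/dp log Bernoulli(z; p)."""
    return z / p - (1 - z) / (1 - p)


def expected_reinforce_grad(f, p, baseline):
    """Exact E_q[(f(z) - b) * d/dp log q(z)], summing over z in {0, 1}."""
    total = 0.0
    for z, prob in ((1, p), (0, 1 - p)):
        total += prob * (f(z) - baseline) * bernoulli_score(z, p)
    return total


def reinforce_grad_variance(f, p, baseline):
    """Exact variance of the single-sample estimator (f(z) - b) * score."""
    mean = expected_reinforce_grad(f, p, baseline)
    second_moment = sum(
        prob * ((f(z) - baseline) * bernoulli_score(z, p)) ** 2
        for z, prob in ((1, p), (0, 1 - p))
    )
    return second_moment - mean ** 2


# A reward that is large on average but differs only slightly between outcomes.
reward = lambda z: 10.0 + z
no_baseline = reinforce_grad_variance(reward, 0.3, 0.0)
with_baseline = reinforce_grad_variance(reward, 0.3, 10.5)
```

Both settings give the same expected gradient, but the variance drops by several orders of magnitude when the baseline is near the typical reward, which is the role the baseline network plays for z_pres.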
Spatial transformer functions (window_to_image, image_to_window) convert between attention windows and the full image space using affine grid sampling.
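The affine part of the spatial transformer can be illustrated with plain arithmetic. This sketch assumes the convention that a z_where vector [s, x, y] expands to the 2x3 matrix [[s, 0, x], [0, s, y]] in normalized coordinates, and that the opposite direction uses the parameters [1/s, -x/s, -y/s]; the helper names are hypothetical. In a PyTorch implementation such matrices would typically be passed to torch.nn.functional.affine_grid and then torch.nn.functional.grid_sample to resample pixels; here only the coordinate mapping is shown.

```python
def expand_z_where(z_where):
    """[s, x, y] -> 2x3 affine matrix [[s, 0, x], [0, s, y]]."""
    s, x, y = z_where
    return [[s, 0.0, x], [0.0, s, y]]


def z_where_inv(z_where):
    """Parameters of the inverse transform: [s, x, y] -> [1/s, -x/s, -y/s]."""
    s, x, y = z_where
    return (1.0 / s, -x / s, -y / s)


def apply_affine(theta, point):
    """Map a 2D point (u, v) through a 2x3 affine matrix."""
    u, v = point
    return (theta[0][0] * u + theta[0][1] * v + theta[0][2],
            theta[1][0] * u + theta[1][1] * v + theta[1][2])


# Round trip: transforming a point and then applying the inverse parameters
# recovers the original point (up to floating-point error).
z_where = (3.0, 0.4, -0.2)
point = (0.25, -0.5)
moved = apply_affine(expand_z_where(z_where), point)
restored = apply_affine(expand_z_where(z_where_inv(z_where)), moved)
```

Because the inverse is again of the form [s', x', y'], the same expansion routine can serve both window_to_image and image_to_window.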
Code Reference
```python
import torch
import torch.nn as nn
import pyro
import pyro.distributions as dist


class AIR(nn.Module):
    def __init__(self, num_steps, x_size, window_size, z_what_size,
                 rnn_hidden_size, ...):
        # Configures prior parameters, encoder/decoder networks, RNN, baseline networks
        ...

    def model(self, data, batch_size, **kwargs):
        # Register the decoder's parameters with Pyro.
        pyro.module("decode", self.decode)
        with pyro.plate("data", data.size(0), device=data.device) as ix:
            batch = data[ix]
            n = batch.size(0)
            # Sample the latents from the prior and render the image x.
            (z_where, z_pres), x = self.prior(n, **kwargs)
            # Pixel-wise Normal likelihood; to_event(1) treats each flattened image as one event.
            pyro.sample("obs",
                        dist.Normal(x.view(n, -1),
                                    self.likelihood_sd * torch.ones(n, self.x_size ** 2, **self.options))
                            .to_event(1),
                        obs=batch.view(n, -1))

    def guide(self, data, batch_size, **kwargs):
        pyro.module("rnn", self.rnn)
        pyro.module("predict", self.predict)
        pyro.module("encode", self.encode)
        # Registers all neural network components with Pyro
        # Iteratively processes image through LSTM, sampling z_pres, z_where, z_what
        ...
```
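The observation site in model() uses a Normal with a fixed standard deviation (likelihood_sd) over the flattened image, wrapped in .to_event(1). The sketch below spells out what that means numerically: the per-pixel log densities are summed into one log-likelihood per image. It is a plain-Python illustration with hypothetical helper names, not the Pyro implementation.

```python
import math


def normal_log_prob(value, mean, sd):
    """Log density of a univariate Normal(mean, sd) at value."""
    return -0.5 * ((value - mean) / sd) ** 2 - math.log(sd) - 0.5 * math.log(2 * math.pi)


def image_log_likelihood(rendered, observed, likelihood_sd):
    """Sum of per-pixel Normal log densities for one flattened image.

    Summing over the pixel dimension mirrors .to_event(1): the pixels form a
    single event, so each image contributes one log-probability term.
    """
    assert len(rendered) == len(observed), "rendered and observed must have the same number of pixels"
    return sum(normal_log_prob(o, r, likelihood_sd) for r, o in zip(rendered, observed))


# A perfect reconstruction of a 4-pixel image versus a poor one.
perfect = image_log_likelihood([0.1] * 4, [0.1] * 4, 0.3)
poor = image_log_likelihood([0.0] * 4, [1.0] * 4, 0.3)
```

A smaller likelihood_sd penalizes reconstruction errors more strongly, since the squared error is divided by sd squared.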
I/O Contract
| Parameter | Type | Description |
|---|---|---|
| data | torch.Tensor | Batch of images, shape [N, x_size, x_size] |
| batch_size | int | Subsample size for the data plate |
| z_pres_prior_p | callable | Keyword argument (forwarded via **kwargs); function t -> float returning the prior probability of z_pres at step t |
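Because z_pres is gated step by step, the z_pres_prior_p schedule fully determines the prior over the number of objects: P(N = n) is the product of the first n step probabilities times (1 - p_n). The sketch below builds one hypothetical schedule that makes P(N = n) proportional to k**n, by setting each step probability to a ratio of survival probabilities P(N >= t + 1) / P(N >= t). The function names and the geometric target are assumptions for illustration; any callable with the same t -> float shape could be passed instead.

```python
def geometric_count_prior(k, num_steps):
    """Hypothetical z_pres_prior_p with P(N = n) proportional to k**n, n = 0..num_steps."""
    assert k > 0, "k must be positive"
    weights = [k ** n for n in range(num_steps + 1)]
    total = sum(weights)
    target = [w / total for w in weights]
    # survival[n] = P(N >= n)
    survival = [sum(target[n:]) for n in range(num_steps + 1)]
    probs = [survival[t + 1] / survival[t] for t in range(num_steps)]
    # Valid for t in 0..num_steps-1; larger t raises IndexError.
    return lambda t: probs[t]


def count_distribution(z_pres_prior_p, num_steps):
    """Prior probability of exactly n objects, for n = 0..num_steps."""
    counts = []
    alive = 1.0  # probability that every step so far had z_pres == 1
    for t in range(num_steps):
        p = z_pres_prior_p(t)
        counts.append(alive * (1 - p))
        alive *= p
    counts.append(alive)
    return counts


prior_p = geometric_count_prior(0.5, 3)
counts = count_distribution(prior_p, 3)
```

With k = 0.5 and three steps, the counts come out as 8/15, 4/15, 2/15 and 1/15, so the prior favors fewer objects; k closer to 1 flattens it.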
Outputs (from guide):
- z_where: list of tensors of shape [N, 3] (scale, x, y), one per time step
- z_pres: list of tensors of shape [N, 1] (presence indicators), one per time step
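A common next step with the guide's outputs is to count the inferred objects per image and keep only the windows that are present. The sketch below uses nested Python lists as stand-ins for the per-step tensors (each inner list holds one value per image); the helper names are hypothetical. With actual tensors, the count would correspond to something like torch.stack(z_pres).sum(dim=0).

```python
def count_objects(z_pres_steps):
    """Number of present objects per image.

    z_pres_steps: one entry per time step, each a list of N presence values
    (0 or 1), standing in for the [N, 1] tensors returned by the guide.
    """
    num_images = len(z_pres_steps[0])
    return [sum(step[i] for step in z_pres_steps) for i in range(num_images)]


def present_windows(z_where_steps, z_pres_steps):
    """Per image, the (scale, x, y) windows whose presence flag is 1."""
    num_images = len(z_pres_steps[0])
    return [
        [z_where_steps[t][i] for t in range(len(z_pres_steps)) if z_pres_steps[t][i] == 1]
        for i in range(num_images)
    ]


# Three steps, three images (placeholder values).
z_pres_steps = [[1, 1, 0], [1, 0, 0], [0, 0, 0]]
z_where_steps = [
    [(3.0, 0.1, 0.2), (2.9, -0.4, 0.0), (3.1, 0.0, 0.0)],
    [(3.2, -0.5, 0.5), (3.0, 0.3, 0.3), (3.0, 0.0, 0.0)],
    [(3.0, 0.0, 0.0)] * 3,
]
num_objects = count_objects(z_pres_steps)
windows = present_windows(z_where_steps, z_pres_steps)
```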
Named sample sites:
- z_pres_{t}: Bernoulli presence variable at step t
- z_where_{t}: Normal attention window parameters at step t
- z_what_{t}: Normal latent appearance code at step t
- obs: Normal observation likelihood
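Since the site names follow a fixed pattern, they can be generated programmatically, for example to build a data dictionary for pyro.poutine.condition or to look up entries in a trace. The helpers below are hypothetical; the tuples stand in for the [N, 3] tensors that would be used in practice.

```python
def air_site_names(num_steps):
    """Sample-site names for a run with num_steps steps, following the pattern above."""
    names = []
    for t in range(num_steps):
        names += [f"z_pres_{t}", f"z_where_{t}", f"z_what_{t}"]
    names.append("obs")
    return names


def condition_on_windows(windows):
    """Hypothetical helper: map a list of per-step z_where values to their site names."""
    return {f"z_where_{t}": w for t, w in enumerate(windows)}


names = air_site_names(3)
fixed_windows = condition_on_windows([(3.0, 0.0, 0.0), (3.0, 0.5, -0.5)])
```

A dictionary like fixed_windows could then be passed as the data argument of pyro.poutine.condition(air.model, data) to fix the attention windows while the other sites are still sampled.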
Usage Examples
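```python
import torch
from pyro.infer import SVI, TraceGraph_ELBO
from pyro.optim import Adam

from air import AIR  # examples/air/air.py

# Create the AIR model
air = AIR(
    num_steps=3,
    x_size=50,
    window_size=28,
    z_what_size=50,
    rnn_hidden_size=256,
    encoder_net=[200],
    decoder_net=[200],
    use_masking=True,
    use_baselines=True,
    likelihood_sd=0.3,
)

# Placeholder data: random 50x50 images of shape [N, x_size, x_size]
X = torch.rand(1000, 50, 50)
X_batch = X[:64]

# Use with Pyro SVI; TraceGraph_ELBO supports variance reduction for the discrete z_pres sites
adam = Adam({"lr": 1e-4})
svi = SVI(air.model, air.guide, adam, loss=TraceGraph_ELBO())
loss = svi.step(X, batch_size=64)

# Sample from the prior: returns ((z_where, z_pres), x)
z, x = air.prior(5)

# Run the guide to get inferred latents
z_where, z_pres = air.guide(X_batch, batch_size=64)
```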
Related Pages
- Pyro_ppl_Pyro_AIR_Training - Training entry point for the AIR model
- Pyro_ppl_Pyro_AIR_Modules - Neural network modules used by AIR