Implementation:Pyro ppl Pyro AIR Modules
| Property | Value |
|---|---|
| Implementation Type | Pattern Doc |
| Source File | examples/air/modules.py |
| Module | examples.air |
| Pyro Features | nn.Module subclasses used with pyro.module |
| Dependencies | PyTorch nn, softplus |
Overview
This file defines the neural network building blocks used by the AIR (Attend, Infer, Repeat) model. These modules are registered with Pyro via pyro.module() so their parameters are tracked during optimization.
The file provides five module classes:
- Encoder: Maps attention-window pixel intensities to the parameters (mean, std) of the variational distribution over z_what. Uses softplus to ensure positive standard deviations.
- Decoder: Maps a latent code z_what back to pixel intensities in the attention window. Optionally adds a bias and applies a sigmoid nonlinearity.
- MLP: A general-purpose multi-layer perceptron used as a building block by the other modules. Supports configurable hidden layer sizes and an optional output nonlinearity.
- Predict: Maps the guide RNN hidden state to the parameters of the guide distributions over z_pres (via sigmoid) and z_where (mean and softplus scale).
- Identity: A pass-through module used when no embedding network is configured.
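MLP and Identity are listed above but do not appear in the Code Reference excerpt. The following is a minimal sketch consistent with how Encoder, Decoder and Predict call `MLP(in_size, h_sizes + [out_size], non_linear_layer)`: each entry of `out_sizes` becomes an `nn.Linear` layer, `non_linear_layer()` is inserted between consecutive layers, and an optional keyword (assumed here to be named `output_non_linearity`) also applies it after the last layer. The exact layer layout and keyword name are assumptions, not a copy of the file.

```python
import torch
import torch.nn as nn


class MLP(nn.Module):
    # Sketch: nn.Linear layers with non_linear_layer() between consecutive ones.
    # out_sizes lists every layer's output width; the last entry is the final size.
    def __init__(self, in_size, out_sizes, non_linear_layer, output_non_linearity=False):
        super().__init__()
        assert len(out_sizes) >= 1
        in_sizes = [in_size] + out_sizes[:-1]
        layers = []
        for i, (n_in, n_out) in enumerate(zip(in_sizes, out_sizes)):
            layers.append(nn.Linear(n_in, n_out))
            is_last = i == len(out_sizes) - 1
            if not is_last or output_non_linearity:
                layers.append(non_linear_layer())
        self.seq = nn.Sequential(*layers)

    def forward(self, x):
        return self.seq(x)


class Identity(nn.Module):
    # Pass-through used when no embedding network is configured.
    def forward(self, x):
        return x


mlp = MLP(784, [200, 100], nn.ReLU)  # Linear(784->200), ReLU, Linear(200->100)
x = torch.rand(8, 784)
print(mlp(x).shape)  # torch.Size([8, 100])
assert Identity()(x) is x
```

Leaving the final layer linear by default lets callers such as Encoder and Predict apply their own output transforms (softplus, sigmoid) to slices of the result.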
Code Reference
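```python
import torch
import torch.nn as nn
from torch.nn.functional import softplus

# MLP is defined in the same file (examples/air/modules.py).

class Encoder(nn.Module):
    # Attention window -> (z_what_loc, z_what_scale).
    def __init__(self, x_size, h_sizes, z_size, non_linear_layer):
        super().__init__()
        self.z_size = z_size
        output_size = 2 * z_size
        self.mlp = MLP(x_size, h_sizes + [output_size], non_linear_layer)

    def forward(self, x):
        a = self.mlp(x)
        # First z_size columns give the mean; softplus keeps the std positive.
        return a[:, 0:self.z_size], softplus(a[:, self.z_size:])

class Decoder(nn.Module):
    # z_what -> pixel intensities in the attention window.
    def __init__(self, x_size, h_sizes, z_size, bias, use_sigmoid, non_linear_layer):
        super().__init__()
        self.bias = bias  # stored so forward() can read it
        self.use_sigmoid = use_sigmoid
        self.mlp = MLP(z_size, h_sizes + [x_size], non_linear_layer)

    def forward(self, z):
        a = self.mlp(z)
        if self.bias is not None:
            a = a + self.bias
        return torch.sigmoid(a) if self.use_sigmoid else a

class Predict(nn.Module):
    # Guide RNN hidden state -> parameters of the z_pres and z_where distributions.
    def __init__(self, input_size, h_sizes, z_pres_size, z_where_size, non_linear_layer):
        super().__init__()
        self.z_pres_size = z_pres_size
        self.z_where_size = z_where_size
        output_size = z_pres_size + 2 * z_where_size
        self.mlp = MLP(input_size, h_sizes + [output_size], non_linear_layer)

    def forward(self, h):
        out = self.mlp(h)
        z_pres_p = torch.sigmoid(out[:, 0:self.z_pres_size])
        z_where_loc = out[:, self.z_pres_size:self.z_pres_size + self.z_where_size]
        z_where_scale = softplus(out[:, (self.z_pres_size + self.z_where_size):])
        return z_pres_p, z_where_loc, z_where_scale
```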
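To illustrate why Predict squashes its outputs the way it does, here is a torch-only sketch that splits a raw output vector exactly as `Predict.forward` does and uses the pieces to parameterize a Bernoulli over z_pres and a Normal over z_where. It substitutes `torch.distributions` for Pyro's distributions, replaces the MLP with a single `nn.Linear` (the `h_sizes=[]` case), and the sizes (z_pres_size=1, z_where_size=3, hidden size 256) are illustrative assumptions.

```python
import torch
import torch.nn as nn
from torch.distributions import Bernoulli, Normal
from torch.nn.functional import softplus

torch.manual_seed(0)
z_pres_size, z_where_size, hidden = 1, 3, 256  # illustrative sizes

# Stand-in for Predict's MLP with h_sizes=[]: one linear layer.
head = nn.Linear(hidden, z_pres_size + 2 * z_where_size)

h = torch.randn(4, hidden)  # batch of 4 guide RNN hidden states
out = head(h)
z_pres_p = torch.sigmoid(out[:, :z_pres_size])  # probability in (0, 1)
z_where_loc = out[:, z_pres_size:z_pres_size + z_where_size]
z_where_scale = softplus(out[:, z_pres_size + z_where_size:])  # strictly positive

z_pres = Bernoulli(probs=z_pres_p).sample()  # 0/1 "object present" flags
z_where = Normal(z_where_loc, z_where_scale).rsample()  # reparameterized sample
print(z_pres.shape, z_where.shape)  # torch.Size([4, 1]) torch.Size([4, 3])
```

The sigmoid keeps z_pres_p a valid Bernoulli probability and the softplus keeps the Normal scale positive; Encoder relies on the same softplus trick for z_what_scale.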
I/O Contract
| Module | Input | Output |
|---|---|---|
| Encoder | Flattened attention window [batch, window_size^2] | (z_what_loc, z_what_scale), each [batch, z_size] |
| Decoder | Latent code [batch, z_size] | Pixel intensities [batch, window_size^2] |
| Predict | RNN hidden state [batch, rnn_hidden_size] | (z_pres_p, z_where_loc, z_where_scale): [batch, z_pres_size], [batch, z_where_size], [batch, z_where_size] |
| MLP | Arbitrary input [batch, in_size] | Output [batch, out_sizes[-1]] |
| Identity | Any tensor | Same tensor (pass-through) |
Usage Examples
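```python
import torch
import torch.nn as nn
from modules import Encoder, Decoder, Predict, MLP, Identity  # run from examples/air

batch_size = 8  # illustrative batch size
x_att = torch.rand(batch_size, 28 * 28)  # flattened 28x28 attention windows
z_what = torch.randn(batch_size, 50)     # latent codes
h = torch.randn(batch_size, 256)         # guide RNN hidden states

# Create encoder: window_size^2 -> [200] -> 2*z_what_size
encoder = Encoder(28 * 28, [200], 50, nn.ReLU)
z_what_loc, z_what_scale = encoder(x_att)
# Create decoder: z_what_size -> [200] -> window_size^2
decoder = Decoder(28 * 28, [200], 50, bias=None, use_sigmoid=False, non_linear_layer=nn.ReLU)
y_att = decoder(z_what)
# Create predict network: rnn_hidden -> z_pres_p, z_where_loc, z_where_scale
predict = Predict(256, [], 1, 3, nn.ReLU)
z_pres_p, z_where_loc, z_where_scale = predict(h)
```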
Related Pages
- Pyro_ppl_Pyro_AIR_Model - The AIR model that uses these modules
- Pyro_ppl_Pyro_AIR_Training - Training entry point for the AIR model