
Implementation:Pyro ppl Pyro AIR Modules

From Leeroopedia
Revision as of 16:22, 16 February 2026 by Admin (talk | contribs) (Auto-imported from implementations/Pyro_ppl_Pyro_AIR_Modules.md)


Property | Value
Implementation Type | Pattern Doc
Source File | examples/air/modules.py
Module | examples.air
Pyro Features | nn.Module subclasses used with pyro.module
Dependencies | PyTorch nn, softplus

Overview

This file defines the neural network building blocks used by the AIR (Attend, Infer, Repeat) model. These modules are registered with Pyro via pyro.module() so their parameters are tracked during optimization.

The file provides five module classes:

  • Encoder: Maps attention window pixel intensities to parameters (mean, std) of the variational distribution over z_what. Uses softplus to ensure positive standard deviations.
  • Decoder: Maps latent code z_what back to pixel intensities in the attention window. Optionally applies a bias and sigmoid nonlinearity.
  • MLP: A general-purpose multi-layer perceptron used as a building block by other modules. Supports configurable hidden layer sizes and optional output nonlinearity.
  • Predict: Maps the guide RNN hidden state to parameters of the guide distributions over z_pres (via sigmoid) and z_where (mean and softplus scale).
  • Identity: A pass-through module used when no embedding network is configured.
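MLP and Identity do not appear in the code reference on this page, so the following is only a sketch of how such modules could be written from the descriptions above. The class names `MLPSketch` and `IdentitySketch`, and the `output_non_linearity` keyword, are assumptions made for this illustration and may differ from the source file.

```python
import torch
import torch.nn as nn


class MLPSketch(nn.Module):
    """Hypothetical MLP: linear layers with a nonlinearity between them."""

    def __init__(self, in_size, out_sizes, non_linear_layer, output_non_linearity=False):
        super().__init__()
        assert len(out_sizes) >= 1
        layers = []
        in_sizes = [in_size] + out_sizes[:-1]
        for i, (n_in, n_out) in enumerate(zip(in_sizes, out_sizes)):
            layers.append(nn.Linear(n_in, n_out))
            is_last = i == len(out_sizes) - 1
            # Hidden layers always get a nonlinearity; the output layer only on request.
            if not is_last or output_non_linearity:
                layers.append(non_linear_layer())
        self.seq = nn.Sequential(*layers)

    def forward(self, x):
        return self.seq(x)


class IdentitySketch(nn.Module):
    """Hypothetical pass-through module: returns its input unchanged."""

    def forward(self, x):
        return x


x = torch.rand(4, 784)
out = MLPSketch(784, [200, 100], nn.ReLU)(x)  # 784 -> 200 -> 100
same = IdentitySketch()(x)
```

Passing the nonlinearity as a class (nn.ReLU) rather than an instance lets each layer get its own module object.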

Code Reference

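import torch
import torch.nn as nn
from torch.nn.functional import softplus

# MLP is defined in the same file (examples/air/modules.py).

class Encoder(nn.Module):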
    def __init__(self, x_size, h_sizes, z_size, non_linear_layer):
        super().__init__()
        self.z_size = z_size
        output_size = 2 * z_size
        self.mlp = MLP(x_size, h_sizes + [output_size], non_linear_layer)

    def forward(self, x):
        a = self.mlp(x)
        return a[:, 0:self.z_size], softplus(a[:, self.z_size:])

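class Decoder(nn.Module):
    def __init__(self, x_size, h_sizes, z_size, bias, use_sigmoid, non_linear_layer):
        super().__init__()
        # Stored here because forward() reads both attributes.
        self.bias = bias
        self.use_sigmoid = use_sigmoid
        self.mlp = MLP(z_size, h_sizes + [x_size], non_linear_layer)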

    def forward(self, z):
        a = self.mlp(z)
        if self.bias is not None:
            a = a + self.bias
        return torch.sigmoid(a) if self.use_sigmoid else a

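class Predict(nn.Module):
    # __init__ is not shown in this excerpt; forward() relies on it to set
    # self.mlp, self.z_pres_size, and self.z_where_size.
    def forward(self, h):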
        out = self.mlp(h)
        z_pres_p = torch.sigmoid(out[:, 0:self.z_pres_size])
        z_where_loc = out[:, self.z_pres_size:self.z_pres_size + self.z_where_size]
        z_where_scale = softplus(out[:, (self.z_pres_size + self.z_where_size):])
        return z_pres_p, z_where_loc, z_where_scale
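The Predict excerpt above omits its constructor. The sketch below shows one way the output could be sized so that the three slices in forward() line up: z_pres_size values for the presence probability, then z_where_size values each for the location and the scale. The class name `PredictSketch`, the single nn.Linear standing in for the MLP, and the output size z_pres_size + 2 * z_where_size are assumptions for illustration, not the source implementation.

```python
import torch
import torch.nn as nn
from torch.nn.functional import softplus


class PredictSketch(nn.Module):
    """Hypothetical stand-in for Predict with an assumed constructor."""

    def __init__(self, input_size, z_pres_size, z_where_size):
        super().__init__()
        self.z_pres_size = z_pres_size
        self.z_where_size = z_where_size
        # Assumption: one block for z_pres, then a loc block and a scale block for z_where.
        self.mlp = nn.Linear(input_size, z_pres_size + 2 * z_where_size)

    def forward(self, h):
        out = self.mlp(h)
        # Sigmoid keeps the presence probability in [0, 1].
        z_pres_p = torch.sigmoid(out[:, : self.z_pres_size])
        z_where_loc = out[:, self.z_pres_size : self.z_pres_size + self.z_where_size]
        # Softplus keeps the scale positive.
        z_where_scale = softplus(out[:, self.z_pres_size + self.z_where_size :])
        return z_pres_p, z_where_loc, z_where_scale


h = torch.randn(8, 256)
z_pres_p, z_where_loc, z_where_scale = PredictSketch(256, 1, 3)(h)
```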

I/O Contract

Module | Input | Output
Encoder | Flattened attention window [batch, window_size^2] | (z_what_loc, z_what_scale) each [batch, z_size]
Decoder | Latent code [batch, z_size] | Pixel intensities [batch, window_size^2]
Predict | RNN hidden state [batch, rnn_hidden_size] | (z_pres_p, z_where_loc, z_where_scale)
MLP | Arbitrary input [batch, in_size] | Output [batch, out_sizes[-1]]
Identity | Any tensor | Same tensor (pass-through)
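The Encoder and Decoder rows of the contract can be checked with a short round trip: encode a window, sample a latent code from the resulting normal distribution, and decode it. This sketch uses plain nn.Sequential networks as stand-ins for the real modules; the sizes (28x28 window, 200 hidden units, 50 latent dimensions) and the use of a sigmoid on the decoder output are assumptions chosen for illustration.

```python
import torch
import torch.nn as nn
from torch.nn.functional import softplus

torch.manual_seed(0)
batch, window_size, z_size = 4, 28, 50

# Stand-ins for the Encoder and Decoder MLPs (one hidden layer of 200 units).
enc_mlp = nn.Sequential(nn.Linear(window_size**2, 200), nn.ReLU(), nn.Linear(200, 2 * z_size))
dec_mlp = nn.Sequential(nn.Linear(z_size, 200), nn.ReLU(), nn.Linear(200, window_size**2))

# Encoder contract: [batch, window_size^2] -> two tensors of shape [batch, z_size].
x_att = torch.rand(batch, window_size**2)
a = enc_mlp(x_att)
z_what_loc, z_what_scale = a[:, :z_size], softplus(a[:, z_size:])

# Sample a latent code; the positive scale from softplus makes this a valid Normal.
z_what = torch.distributions.Normal(z_what_loc, z_what_scale).rsample()

# Decoder contract: [batch, z_size] -> [batch, window_size^2].
y_att = torch.sigmoid(dec_mlp(z_what))
```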

Usage Examples

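import torch
import torch.nn as nn

# Assumes the working directory is examples/air so that modules.py is importable.
from modules import Encoder, Decoder, Predict, MLP, Identity

# Placeholder inputs (batch of 4) so the calls below run on their own.
x_att = torch.rand(4, 28*28)   # flattened attention windows
z_what = torch.randn(4, 50)    # latent codes
h = torch.randn(4, 256)        # guide RNN hidden states

# Create encoder: window_size^2 -> [200] -> 2*z_what_size
encoder = Encoder(28*28, [200], 50, nn.ReLU)
z_what_loc, z_what_scale = encoder(x_att)

# Create decoder: z_what_size -> [200] -> window_size^2
decoder = Decoder(28*28, [200], 50, bias=None, use_sigmoid=False, non_linear_layer=nn.ReLU)
y_att = decoder(z_what)

# Create predict network: rnn_hidden -> z_pres_p, z_where_loc, z_where_scale
predict = Predict(256, [], 1, 3, nn.ReLU)
z_pres_p, z_where_loc, z_where_scale = predict(h)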
