Implementation:Pyro ppl Pyro AutoRegressiveNN

From Leeroopedia


Property Value
Module pyro.nn.auto_reg_nn
Source pyro/nn/auto_reg_nn.py
Lines 360
Classes MaskedLinear, ConditionalAutoRegressiveNN, AutoRegressiveNN
Functions sample_mask_indices, create_mask
Dependencies torch, torch.nn

Overview

This module implements MADE (Masked Autoencoder for Distribution Estimation) networks for use in autoregressive normalizing flows. The key idea is to apply binary masks to the weight matrices of a feedforward neural network. The masks ensure that the output at position i depends only on the inputs that precede i in a chosen ordering. For the identity ordering these are the positions < i; in general the order is given by a permutation. This enforces the autoregressive property.
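Stated as formulas, this is a sketch of the standard MADE/autoregressive-flow property. Here π is the ordering induced by the permutation, D = input_dim, and y is any output parameter of the network (these symbols are ours, not the module's):

```latex
p(\mathbf{x}) = \prod_{i=1}^{D} p\left(x_{\pi(i)} \,\middle|\, x_{\pi(1)}, \ldots, x_{\pi(i-1)}\right),
\qquad
\frac{\partial\, y_{\pi(i)}}{\partial x_{\pi(j)}} = 0 \quad \text{for all } j \ge i .
```

Arranged in the order π, the Jacobian of the network outputs with respect to x is therefore strictly lower triangular. An affine flow y_i = μ_i + σ_i x_i built on such outputs has a triangular Jacobian, so its log-determinant reduces to the sum of log σ_i.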

Two main classes are provided:

  • ConditionalAutoRegressiveNN: An autoregressive network that can accept an additional context variable, following the conditional MADE architecture of Paige & Wood (2016).
  • AutoRegressiveNN: A simpler subclass without conditioning, following the original MADE architecture of Germain et al. (2015).

The module also provides MaskedLinear, a custom linear layer that multiplies weights by a binary mask before the forward pass, and utility functions for constructing MADE masks.

Code Reference

Function: sample_mask_indices(input_dim, hidden_dim, simple=True)

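Assigns integer indices to the hidden_dim units of one hidden layer for MADE mask construction. Fractional indices are first spaced evenly over [1, input_dim]. When simple=True, each is rounded to the nearest integer. When simple=False, each is rounded up with probability equal to its fractional part (Bernoulli sampling) and down otherwise.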
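The rounding scheme described above can be sketched in plain Python rather than torch, so the arithmetic is easy to follow. This is a hypothetical re-implementation for illustration, not Pyro's code. The even spacing over [1, input_dim] mirrors a linspace. The extra `rng` parameter is our own addition so the stochastic branch can be seeded.

```python
import math
import random


def sample_mask_indices(input_dim, hidden_dim, simple=True, rng=random):
    # Hypothetical sketch, not Pyro's implementation.
    # Assumption: fractional indices are spaced evenly over [1, input_dim].
    if hidden_dim == 1:
        fractional = [1.0]
    else:
        step = (input_dim - 1) / (hidden_dim - 1)
        fractional = [1 + k * step for k in range(hidden_dim)]

    if simple:
        # Deterministic: round to the nearest integer. Python's round(), like
        # torch.round, sends exact halves to the even neighbour.
        return [round(f) for f in fractional]

    # Stochastic: round up with probability equal to the fractional part, else down.
    # `rng` is a hypothetical parameter added for reproducibility.
    indices = []
    for f in fractional:
        lower = math.floor(f)
        indices.append(lower + (1 if rng.random() < f - lower else 0))
    return indices


print(sample_mask_indices(3, 5))  # [1, 2, 2, 2, 3]
print(sample_mask_indices(3, 5, simple=False, rng=random.Random(0)))
```

Fractional values that are already integers (here 1, 2 and 3) are unaffected by either rounding mode. Only the in-between units differ.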

Function: create_mask(input_dim, context_dim, hidden_dims, permutation, output_dim_multiplier)

Constructs the binary masks for all layers of a conditional MADE network. Returns a tuple (masks, mask_skip). masks is a list of layer masks, one per hidden layer plus one for the output layer. mask_skip is the mask for the direct input-to-output skip connections.

For conditional MADE (context_dim > 0), all context variables are assigned index 0. The hidden-unit indices are shifted to start at 0, so every hidden unit, and hence every output, can depend on the context.
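The following plain-Python sketch shows how such masks can be built. It uses nested 0/1 lists and applies the MADE connectivity rule of Germain et al.: a hidden unit may connect to any unit with index less than or equal to its own, while an output unit needs a strictly smaller index. It is an illustration under these assumptions, not Pyro's code. The helpers `_spread` and `dependencies` are hypothetical. `dependencies` multiplies the masks as boolean matrices to show which inputs each output can reach.

```python
def _spread(n, h):
    # Hypothetical helper: evenly spaced indices over [1, n], rounded to nearest
    # (the simple=True behaviour of sample_mask_indices).
    if h == 1:
        return [1]
    step = (n - 1) / (h - 1)
    return [round(1 + k * step) for k in range(h)]


def create_mask(input_dim, context_dim, hidden_dims, permutation, output_dim_multiplier):
    # Sketch of MADE mask construction with nested 0/1 lists (not Pyro's code).
    # var_index[j] is the 0-based position of input j in the autoregressive ordering.
    var_index = [0] * input_dim
    for position, j in enumerate(permutation):
        var_index[j] = position

    # Context variables get index 0; input variables get indices 1..input_dim.
    input_indices = [0] * context_dim + [1 + v for v in var_index]
    if context_dim > 0:
        # Hidden indices start at 0, so every hidden unit may see the context.
        hidden_indices = [[i - 1 for i in _spread(input_dim, h)] for h in hidden_dims]
    else:
        hidden_indices = [_spread(input_dim - 1, h) for h in hidden_dims]
    # One copy of the output indices per output block.
    output_indices = [1 + v for v in var_index] * output_dim_multiplier

    def connect(rows, cols, strict):
        # mask[r][c] == 1 when unit r may receive input from unit c.
        return [[int(r > c) if strict else int(r >= c) for c in cols] for r in rows]

    masks = [connect(hidden_indices[0], input_indices, strict=False)]
    for prev, cur in zip(hidden_indices, hidden_indices[1:]):
        masks.append(connect(cur, prev, strict=False))
    masks.append(connect(output_indices, hidden_indices[-1], strict=True))
    mask_skip = connect(output_indices, input_indices, strict=True)
    return masks, mask_skip


def dependencies(masks):
    # Boolean product of the masks: reach[o][i] is True if a path links input i to output o.
    reach = [[bool(v) for v in row] for row in masks[0]]
    for m in masks[1:]:
        reach = [[any(m[r][k] and reach[k][c] for k in range(len(reach)))
                  for c in range(len(reach[0]))] for r in range(len(m))]
    return reach


masks, mask_skip = create_mask(3, 0, [4], permutation=[0, 1, 2], output_dim_multiplier=1)
print(dependencies(masks))  # [[False, False, False], [True, False, False], [True, True, False]]
```

With the identity permutation, output i reaches only inputs before i. The first output in the ordering reaches no input at all and is driven only by biases. Under the strict rule, the skip mask has the same lower-triangular pattern.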

Class: MaskedLinear

A torch.nn.Linear subclass that stores a binary mask as a registered buffer and multiplies the weight by it during the forward pass:

import torch.nn.functional as F

def forward(self, _input):
    masked_weight = self.weight * self.mask  # zero out disallowed connections
    return F.linear(_input, masked_weight, self.bias)
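The arithmetic of this forward pass, F.linear(x, weight * mask, bias), can be mirrored in plain Python to make its effect visible. This is a minimal sketch; `masked_linear` is a hypothetical name, and the nested-list layout (out_features x in_features) is our assumption. A masked-out weight has no influence on the corresponding output, whatever its value.

```python
def masked_linear(x, weight, mask, bias):
    # Hypothetical sketch of MaskedLinear.forward: F.linear(x, weight * mask, bias).
    # weight and mask are out_features x in_features nested lists; x has in_features values.
    masked_weight = [[w * m for w, m in zip(w_row, m_row)] for w_row, m_row in zip(weight, mask)]
    return [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(masked_weight, bias)]


weight = [[1.0, 2.0], [3.0, 4.0]]
mask = [[1, 0],   # output 0 may only see input 0
        [1, 1]]   # output 1 may see both inputs
bias = [0.0, 0.0]
print(masked_linear([10.0, 100.0], weight, mask, bias))  # [10.0, 430.0]
print(masked_linear([10.0, -7.0], weight, mask, bias))   # [10.0, 2.0]
```

In PyTorch, applying the mask inside forward means the gradient with respect to a masked weight entry is zero, so those entries never change during training. Registering the mask as a buffer lets it follow the module through .to(device) and be saved in state_dict without being treated as a trainable parameter.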

Class: ConditionalAutoRegressiveNN

Constructor:

  • input_dim: Dimensionality of the input variable.
  • context_dim: Dimensionality of the context variable.
  • hidden_dims: List of hidden layer sizes (each must be >= input_dim).
  • param_dims: Output parameter dimensions (default [1, 1], e.g. a location and a log-scale for an affine flow).
  • permutation: Optional input ordering (random by default).
  • skip_connections: Whether to add direct input-to-output skip connections (default False).
  • nonlinearity: Activation function (default ReLU).
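One way to picture how param_dims shapes the output layer is the sketch below. It assumes the final layer emits sum(param_dims) blocks of input_dim units each, consistent with the output_dim_multiplier argument of create_mask, and that parameter k owns a contiguous run of those blocks. `param_slices` is a hypothetical helper written for illustration.

```python
from itertools import accumulate


def param_slices(param_dims):
    # Hypothetical helper: parameter k owns blocks [starts[k], ends[k]) of the
    # sum(param_dims) output blocks, each block holding input_dim units.
    ends = list(accumulate(param_dims))
    starts = [0] + ends[:-1]
    return [slice(s, e) for s, e in zip(starts, ends)]


input_dim, param_dims = 10, [1, 5, 3]
print(input_dim * sum(param_dims))  # 90 output units in the final layer
print(param_slices(param_dims))     # [slice(0, 1, None), slice(1, 6, None), slice(6, 9, None)]
```

Grouping outputs by parameter this way gives every parameter a full copy of the input_dim positions. Each copy carries the same autoregressive indices.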

Methods:

  • forward(x, context=None): Runs the MADE forward pass. Returns either a single tensor or a tuple of tensors depending on param_dims.
  • get_permutation(): Returns the permutation applied to inputs.

Class: AutoRegressiveNN

Subclass of ConditionalAutoRegressiveNN with context_dim fixed to 0. Its constructor omits context_dim, and its forward method takes only x (no context).

I/O Contract

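  • AutoRegressiveNN.forward
    Input: x: Tensor(..., input_dim).
    Output: one of the following.
      • Tensor(..., input_dim) if param_dims=[1].
      • Tensor(..., p, input_dim) if param_dims=[p] with p > 1.
      • Otherwise, a tuple with one tensor per entry of param_dims. Each has shape (..., input_dim) when every entry is 1, and (..., p_k, input_dim) otherwise.
  • ConditionalAutoRegressiveNN.forward
    Input: x: Tensor(..., input_dim) and context: Tensor(..., context_dim).
    Output: the same as AutoRegressiveNN.forward.
  • create_mask
    Input: the dimensions and the permutation.
    Output: a tuple (masks: List[Tensor], mask_skip: Tensor).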
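The return-shape rule of the I/O contract can be summarized as a small function. This sketch is based on our reading of Pyro's implementation and is not its actual code. `output_shapes` is a hypothetical helper that returns plain tuples instead of tensors.

```python
def output_shapes(batch_shape, input_dim, param_dims):
    # Hypothetical helper: the shape(s) forward() is expected to return.
    batch_shape = tuple(batch_shape)
    total = sum(param_dims)
    if total == 1:
        # param_dims == [1]: a single plain tensor
        return batch_shape + (input_dim,)
    if len(param_dims) == 1:
        # A single parameter with p > 1: one tensor with an extra dimension
        return batch_shape + (total, input_dim)
    if all(p == 1 for p in param_dims):
        # Several one-dimensional parameters: a tuple, extra dimension squeezed
        return tuple(batch_shape + (input_dim,) for _ in param_dims)
    # Mixed sizes: a tuple, extra dimension kept even where p == 1
    return tuple(batch_shape + (p, input_dim) for p in param_dims)


print(output_shapes((100,), 10, [1, 5, 3]))  # ((100, 1, 10), (100, 5, 10), (100, 3, 10))
```

The results match the shapes printed in the usage examples for param_dims=[1], [1, 1] and [1, 5, 3].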

Usage Examples

import torch
from pyro.nn.auto_reg_nn import AutoRegressiveNN, ConditionalAutoRegressiveNN

# Unconditional autoregressive network
x = torch.randn(100, 10)
arn = AutoRegressiveNN(10, [50], param_dims=[1])
p = arn(x)  # 1 parameter of size (100, 10)
print(p.shape)

# Two output parameters (e.g., mean and log scale for IAF)
arn2 = AutoRegressiveNN(10, [50], param_dims=[1, 1])
m, s = arn2(x)  # each of size (100, 10)

# Multiple parameters of different sizes
arn3 = AutoRegressiveNN(10, [50], param_dims=[1, 5, 3])
a, b, c = arn3(x)
print(a.shape, b.shape, c.shape)
# (100, 1, 10), (100, 5, 10), (100, 3, 10)

# Conditional autoregressive network with context
y = torch.randn(100, 5)
carn = ConditionalAutoRegressiveNN(10, 5, [50], param_dims=[1, 1])
m, s = carn(x, context=y)
print(m.shape, s.shape)  # (100, 10), (100, 10)
