Implementation:Pyro ppl Pyro AutoRegressiveNN
| Property | Value |
|---|---|
| Module | pyro.nn.auto_reg_nn |
| Source | pyro/nn/auto_reg_nn.py |
| Lines | 360 |
| Classes | MaskedLinear, ConditionalAutoRegressiveNN, AutoRegressiveNN |
| Functions | sample_mask_indices, create_mask |
| Dependencies | torch, torch.nn |
Overview
This module implements MADE (Masked Autoencoder for Distribution Estimation) networks for use in autoregressive normalizing flows. The key idea is to apply binary masks to the weight matrices of a feedforward neural network so that the output at position i depends only on inputs that come before i in a chosen ordering of the inputs (the natural order or any permutation of it). This enforces the autoregressive property.
Two main classes are provided:
- ConditionalAutoRegressiveNN: An autoregressive network that can accept an additional context variable, following the conditional MADE architecture of Paige & Wood (2016).
- AutoRegressiveNN: A simpler subclass without conditioning, following the original MADE architecture of Germain et al. (2015).
The module also provides MaskedLinear, a custom linear layer that multiplies weights by a binary mask before the forward pass, and utility functions for constructing MADE masks.
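The masking idea can be stated compactly. The notation below is introduced here for illustration and does not come from the module: D is the input dimension, m^{(l)}(k) is the integer index assigned to unit k of layer l (with layer 0 being the input), and L is the last hidden layer.

```latex
p(x) = \prod_{i=1}^{D} p(x_i \mid x_{<i}), \qquad
M^{(l)}_{k,d} = \mathbb{1}\left[ m^{(l)}(k) \ge m^{(l-1)}(d) \right], \qquad
M^{(\mathrm{out})}_{i,k} = \mathbb{1}\left[ i > m^{(L)}(k) \right]
```

The strict inequality in the output mask is what prevents output i from depending on input i itself.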
Code Reference
Function: sample_mask_indices
Assigns integer indices to hidden units for MADE mask construction. When simple=True, indices are spaced evenly and rounded to the nearest integer. When simple=False, each index is rounded up or down at random using Bernoulli sampling.
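The two rounding modes of sample_mask_indices can be sketched as follows. This is a minimal illustration written from the description above, not the library's code; the function name `sample_mask_indices_sketch` and the choice of the range 1 to input_dim are assumptions of the sketch.

```python
import torch


def sample_mask_indices_sketch(input_dim, hidden_dim, simple=True):
    """Assign each hidden unit an integer index (hypothetical helper)."""
    # Evenly spaced real-valued targets between 1 and input_dim.
    targets = torch.linspace(1, input_dim, steps=hidden_dim)
    if simple:
        # Deterministic mode: round each target to the nearest integer.
        return torch.round(targets)
    # Stochastic mode: round up with probability equal to the fractional part.
    lower = targets.floor()
    return lower + torch.bernoulli(targets - lower)


torch.manual_seed(0)
simple_idx = sample_mask_indices_sketch(4, 8, simple=True)
random_idx = sample_mask_indices_sketch(4, 8, simple=False)
print(simple_idx)
print(random_idx)
```

In both modes every hidden unit ends up with an integer index inside the input range, so each unit can be compared against input and output indices when the masks are built.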
Function: create_mask
Constructs the binary masks for all layers of a conditional MADE network. Returns a tuple (masks, mask_skip), where masks is a list of per-layer masks and mask_skip is the mask for the input-to-output skip connection.
For conditional MADE (context_dim > 0), all context variables are assigned index 0, ensuring they can influence all outputs.
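To make the index-0 convention concrete, here is a small sketch for a single hidden layer with the natural input ordering. The helper `made_masks_sketch` is hypothetical and is not the module's create_mask; in particular, letting hidden indices run from 0 to input_dim - 1 (so that index-0 hidden units carry context to every output) is an assumption of this sketch.

```python
import torch


def made_masks_sketch(input_dim, context_dim, hidden_dim):
    """Build masks for one hidden layer (hypothetical helper)."""
    # Context variables get index 0; inputs get indices 1..input_dim.
    in_idx = torch.cat(
        [torch.zeros(context_dim), torch.arange(1, input_dim + 1).float()]
    )
    # Assumption: hidden indices are spread evenly over 0..input_dim - 1.
    hid_idx = torch.round(torch.linspace(0, input_dim - 1, steps=hidden_dim))
    out_idx = torch.arange(1, input_dim + 1).float()
    # Hidden unit k may read input d when index(k) >= index(d).
    mask_in = (hid_idx.unsqueeze(-1) >= in_idx.unsqueeze(0)).float()
    # Output i may read hidden unit k only when i > index(k) (strict).
    mask_out = (out_idx.unsqueeze(-1) > hid_idx.unsqueeze(0)).float()
    return mask_in, mask_out


mask_in, mask_out = made_masks_sketch(input_dim=4, context_dim=2, hidden_dim=8)
# Entry (i, d) counts the unmasked paths from input column d to output i.
paths = mask_out @ mask_in
print(paths)
```

The first two columns of `paths` (the context) are positive in every row, while the remaining 4 x 4 block is strictly lower triangular: each output is reachable from the context and from earlier inputs only.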
Class: MaskedLinear
A torch.nn.Linear subclass that registers a mask buffer and applies it in the forward pass:
def forward(self, _input):
    # Zero out disallowed connections before the linear map.
    masked_weight = self.weight * self.mask
    return F.linear(_input, masked_weight, self.bias)
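A self-contained version of this pattern, for experimentation, might look like the sketch below. The class name `MaskedLinearSketch` and its constructor signature are assumptions made for the example; only the forward method is taken from the snippet above.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class MaskedLinearSketch(nn.Linear):
    """Linear layer whose weights are multiplied elementwise by a fixed 0/1 mask."""

    def __init__(self, in_features, out_features, mask, bias=True):
        super().__init__(in_features, out_features, bias)
        # A buffer moves with the module across devices and is saved in
        # state_dict, but it is not a parameter, so optimizers never update it.
        self.register_buffer("mask", mask)

    def forward(self, _input):
        masked_weight = self.weight * self.mask
        return F.linear(_input, masked_weight, self.bias)


# Strictly lower-triangular mask: output i reads only inputs j < i.
mask = torch.tril(torch.ones(3, 3), diagonal=-1)
layer = MaskedLinearSketch(3, 3, mask)

x = torch.randn(5, 3)
x_changed = x.clone()
x_changed[:, 2] += 1.0  # perturb only the last input

out, out_changed = layer(x), layer(x_changed)
print(torch.allclose(out, out_changed))
```

Because the last input feeds no output under this mask, perturbing it leaves the layer's output unchanged.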
Class: ConditionalAutoRegressiveNN
Constructor:
- input_dim: Dimensionality of the input variable.
- context_dim: Dimensionality of the context variable.
- hidden_dims: List of hidden layer sizes (each must be >= input_dim).
- param_dims: Output parameter dimensions (default [1, 1] for mean and scale).
- permutation: Optional input ordering (random by default).
- skip_connections: Whether to add input-to-output skip connections.
- nonlinearity: Activation function (default ReLU).
Methods:
- forward(x, context=None): Runs the MADE forward pass. Returns either a single tensor or a tuple of tensors, depending on param_dims.
- get_permutation(): Returns the permutation applied to inputs.
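How a flat network output could be turned into the tensors that forward returns is sketched below. The helper `split_params_sketch` is hypothetical, and the assumed layout (outputs grouped as sum(param_dims) blocks of input_dim values each) is chosen to reproduce the shapes shown in the usage examples on this page; it is not a statement about the library's internals.

```python
import torch


def split_params_sketch(h, input_dim, param_dims):
    """Split a flat output of size sum(param_dims) * input_dim (hypothetical helper)."""
    total = sum(param_dims)
    if total == 1:
        return h  # single scalar parameter: shape (..., input_dim)
    # Group the flat output as (..., total, input_dim).
    h = h.reshape(list(h.shape[:-1]) + [total, input_dim])
    if len(param_dims) == 1:
        return h  # one multi-dimensional parameter: no tuple needed
    if all(d == 1 for d in param_dims):
        # All parameters are scalar per input: drop the size-1 axis.
        return torch.unbind(h, dim=-2)
    # Mixed sizes: keep a parameter axis, even for size-1 parameters.
    ends = torch.cumsum(torch.tensor(param_dims), dim=0).tolist()
    starts = [0] + ends[:-1]
    return tuple(h[..., s:e, :] for s, e in zip(starts, ends))


h = torch.randn(100, 9 * 10)  # stand-in for the final masked layer's output
a, b, c = split_params_sketch(h, input_dim=10, param_dims=[1, 5, 3])
print(a.shape, b.shape, c.shape)
```

With param_dims=[1, 1] the same helper returns two tensors of shape (100, 10), which is the form typically wanted for a mean and a scale.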
Class: AutoRegressiveNN
Subclass of ConditionalAutoRegressiveNN with context_dim=0. The forward method takes only x (no context).
I/O Contract
| Method | Input | Output |
|---|---|---|
| AutoRegressiveNN.forward | x: Tensor(..., input_dim) | Tensor(..., input_dim) if param_dims=[1]; tuple of tensors otherwise |
| ConditionalAutoRegressiveNN.forward | x: Tensor(..., input_dim), context: Tensor(..., context_dim) | Same as above |
| create_mask | Dimensions and permutation | Tuple (masks: List[Tensor], mask_skip: Tensor) |
Usage Examples
import torch
from pyro.nn.auto_reg_nn import AutoRegressiveNN, ConditionalAutoRegressiveNN
# Unconditional autoregressive network
x = torch.randn(100, 10)
arn = AutoRegressiveNN(10, [50], param_dims=[1])
p = arn(x) # 1 parameter of size (100, 10)
print(p.shape)
# Two output parameters (e.g., mean and log scale for IAF)
arn2 = AutoRegressiveNN(10, [50], param_dims=[1, 1])
m, s = arn2(x) # each of size (100, 10)
# Multiple parameters of different sizes
arn3 = AutoRegressiveNN(10, [50], param_dims=[1, 5, 3])
a, b, c = arn3(x)
print(a.shape, b.shape, c.shape)
# (100, 1, 10), (100, 5, 10), (100, 3, 10)
# Conditional autoregressive network with context
y = torch.randn(100, 5)
carn = ConditionalAutoRegressiveNN(10, 5, [50], param_dims=[1, 1])
m, s = carn(x, context=y)
print(m.shape, s.shape) # (100, 10), (100, 10)
Related Pages
- Pyro_ppl_Pyro_DenseNN -- Non-autoregressive feedforward network