Implementation:Pyro ppl Pyro DenseNN
| Property | Value |
|---|---|
| Module | pyro.nn.dense_nn |
| Source | pyro/nn/dense_nn.py |
| Lines | 140 |
| Classes | ConditionalDenseNN, DenseNN |
| Dependencies | torch |
Overview
This module provides simple dense (fully connected) feedforward neural networks for use with normalizing flows and other Pyro components that require parameterized functions but do not need autoregressive masking. This contrasts with the MADE-based networks in pyro.nn.auto_reg_nn.
Two classes are provided:
- ConditionalDenseNN: takes both an input and a context variable, concatenates them, and produces multiple output parameter groups.
- DenseNN: a subclass with no context variable (context_dim=0).
Both classes support flexible output parameterization via param_dims, allowing the network to produce multiple parameter tensors of different sizes from a single forward pass.
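The sketch below shows how `param_dims` could map onto slices of the last output dimension. It assumes the output layer has `sum(param_dims)` units laid out group by group in order. The helper `param_slices` is hypothetical and is not part of `pyro.nn.dense_nn`.

```python
from itertools import accumulate


def param_slices(param_dims):
    """Hypothetical helper: slice of the last output dim for each param group."""
    ends = list(accumulate(param_dims))  # running totals, e.g. [1, 11, 21]
    starts = [0] + ends[:-1]             # each group starts where the previous ended
    return [slice(s, e) for s, e in zip(starts, ends)]


slices = param_slices([1, 10, 10])
print(slices)  # [slice(0, 1, None), slice(1, 11, None), slice(11, 21, None)]
```

With `param_dims=[1, 10, 10]`, the output layer needs 21 units, and a single forward pass yields three tensors of widths 1, 10 and 10.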
Code Reference
Class: ConditionalDenseNN
Constructor:
- input_dim (int): Input dimensionality.
- context_dim (int): Context dimensionality.
- hidden_dims (list of int): Hidden layer sizes, one hidden layer per entry.
- param_dims (list of int, default [1, 1]): The output layer has sum(param_dims) units, which are split into tensors of these sizes.
- nonlinearity (nn.Module, default ReLU()): Activation applied after each hidden layer (not applied to the final layer).
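The constructor arguments fully determine the layer sizes. This sketch assumes one fully connected layer per consecutive pair of sizes, with the context concatenated onto the input before the first layer. `layer_shapes` is a hypothetical helper for illustration, not a Pyro function.

```python
def layer_shapes(input_dim, context_dim, hidden_dims, param_dims=(1, 1)):
    """Hypothetical helper: (in_features, out_features) of each linear layer."""
    # Input and context are concatenated, so the first layer sees both.
    dims = [input_dim + context_dim, *hidden_dims, sum(param_dims)]
    # One linear layer between each consecutive pair of sizes.
    return list(zip(dims[:-1], dims[1:]))


print(layer_shapes(10, 5, [50], param_dims=[1, 10, 10]))  # [(15, 50), (50, 21)]
```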
Forward pass:
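- Broadcasts the context over the batch dimensions of the input and concatenates the two along the last dimension.
- Passes the result through the hidden layers, applying the nonlinearity after each one.
- Passes the result through the output layer (no nonlinearity).
- Reshapes the output into parameter groups based on param_dims.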
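These four steps can be sketched as a small torch module. This is a hypothetical re-implementation (the class name `SketchConditionalDenseNN` is made up) that assumes the context is placed before the input in the concatenation. To keep it short, it always returns a tuple from `torch.split` and leaves out the squeezing rules listed under Output shapes.

```python
import torch
from torch import nn


class SketchConditionalDenseNN(nn.Module):
    """Hypothetical re-implementation for illustration; not Pyro's class."""

    def __init__(self, input_dim, context_dim, hidden_dims, param_dims=(1, 1),
                 nonlinearity=None):
        super().__init__()
        self.param_dims = list(param_dims)
        dims = [input_dim + context_dim, *hidden_dims, sum(self.param_dims)]
        self.layers = nn.ModuleList(
            nn.Linear(d_in, d_out) for d_in, d_out in zip(dims[:-1], dims[1:])
        )
        self.f = nonlinearity if nonlinearity is not None else nn.ReLU()

    def forward(self, x, context):
        # Broadcast the context over the batch dims of x, then concatenate.
        context = context.expand(x.shape[:-1] + context.shape[-1:])
        h = torch.cat([context, x], dim=-1)
        # Hidden layers, each followed by the nonlinearity.
        for layer in self.layers[:-1]:
            h = self.f(layer(h))
        # Output layer, no nonlinearity.
        h = self.layers[-1](h)
        # Split the last dim into one tensor per parameter group.
        return torch.split(h, self.param_dims, dim=-1)


net = SketchConditionalDenseNN(10, 5, [50], param_dims=[1, 10])
a, b = net(torch.randn(100, 10), torch.randn(5))
print(a.shape, b.shape)  # torch.Size([100, 1]) torch.Size([100, 10])
```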
Output shapes:
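- If output_multiplier == 1 (where output_multiplier = sum(param_dims)): returns the raw output tensor.
- If param_dims has a single entry greater than 1: returns the output tensor of shape (..., param_dims[0]).
- If all param_dims are 1: returns a tuple of tensors with the last dimension squeezed out, each of shape (...,).
- Otherwise: returns a tuple of tensors sliced along the last dimension, with shape (..., d) for each d in param_dims.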
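The following plain-Python sketch applies the four rules to a single sample, with a list standing in for the last tensor dimension. The function `split_output` is hypothetical. Pyro applies the equivalent logic to tensors that can have arbitrary leading batch dimensions.

```python
def split_output(h, param_dims):
    """Hypothetical helper mirroring the output-shape rules for one sample.

    h is a plain list holding the last output dimension, len(h) == sum(param_dims).
    """
    if sum(param_dims) == 1:           # output_multiplier == 1: raw output
        return h
    if len(param_dims) == 1:           # single group wider than 1: whole vector
        return h
    if all(d == 1 for d in param_dims):
        return tuple(h)                # one scalar per group (last dim squeezed)
    out, start = [], 0
    for d in param_dims:               # otherwise: consecutive slices of width d
        out.append(h[start:start + d])
        start += d
    return tuple(out)


print(split_output([0.5, 2.0], [1, 1]))          # (0.5, 2.0)
print(split_output([0, 1, 2, 3, 4], [1, 2, 2]))  # ([0], [1, 2], [3, 4])
```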
Class: DenseNN
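Subclass of ConditionalDenseNN with context_dim=0. Its constructor is DenseNN(input_dim, hidden_dims, param_dims=[1, 1], nonlinearity=ReLU()), and its forward method takes only x.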
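Setting context_dim=0 is a natural special case. A zero-width context adds no columns, so the first layer sees exactly input_dim features. The snippet below illustrates this with plain torch tensors only and does not call Pyro.

```python
import torch

x = torch.randn(100, 10)
empty_context = torch.empty(100, 0)        # a context with context_dim = 0
h = torch.cat([empty_context, x], dim=-1)  # concatenation adds no features
print(h.shape)                             # torch.Size([100, 10])
print(torch.equal(h, x))                   # True
```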
I/O Contract
| Method | Input | Output |
|---|---|---|
| ConditionalDenseNN.forward | x: Tensor(..., input_dim), context: Tensor(..., context_dim) | Single Tensor or tuple of Tensors, depending on param_dims |
| DenseNN.forward | x: Tensor(..., input_dim) | Single Tensor or tuple of Tensors, depending on param_dims |
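In the shapes above, "..." denotes batch dimensions. The sketch below shows how a context with fewer batch dimensions can be expanded to match x before concatenation. It assumes expand-style broadcasting, where each batch dimension of the context must be 1 or equal to the matching dimension of x, and it uses plain torch only.

```python
import torch

x = torch.randn(8, 100, 10)    # batch shape (8, 100), input_dim = 10
context = torch.randn(100, 5)  # batch shape (100,), context_dim = 5

# Expand the context to x's batch shape, keeping its own last dimension.
expanded = context.expand(x.shape[:-1] + context.shape[-1:])
print(expanded.shape)          # torch.Size([8, 100, 5])

joined = torch.cat([expanded, x], dim=-1)
print(joined.shape)            # torch.Size([8, 100, 15])
```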
Usage Examples
```python
import torch
from pyro.nn.dense_nn import DenseNN, ConditionalDenseNN

# Simple dense network with three parameter groups
input_dim = 10
net = DenseNN(input_dim, [50], param_dims=[1, input_dim, input_dim])
z = torch.rand(100, input_dim)
a, b, c = net(z)
print(a.shape)  # torch.Size([100, 1])
print(b.shape)  # torch.Size([100, 10])
print(c.shape)  # torch.Size([100, 10])

# Conditional dense network (e.g., for conditional flows)
context_dim = 5
cnet = ConditionalDenseNN(input_dim, context_dim, [50], param_dims=[1, 1])
x = torch.rand(100, input_dim)
ctx = torch.rand(100, context_dim)
m, s = cnet(x, context=ctx)
print(m.shape)  # torch.Size([100]); all param_dims are 1, so outputs are squeezed
print(s.shape)  # torch.Size([100])

# Two hidden layers; param_dims=[1, 1] is the default, passed explicitly here
net2 = DenseNN(5, [32, 32], param_dims=[1, 1])
m, s = net2(torch.randn(64, 5))  # each of shape torch.Size([64])
```
Related Pages
- Pyro_ppl_Pyro_AutoRegressiveNN -- Autoregressive (MADE) version with masked weights