

Implementation:Pyro ppl Pyro DenseNN

From Leeroopedia


Module: pyro.nn.dense_nn
Source: pyro/nn/dense_nn.py
Lines: 140
Classes: ConditionalDenseNN, DenseNN
Dependencies: torch

Overview

This module provides simple dense (fully connected) feedforward neural networks for normalizing flows (for example coupling transforms such as AffineCoupling) and other Pyro components that need a learnable function of their input but no autoregressive masking. This contrasts with the MADE-based autoregressive networks in pyro.nn.auto_reg_nn.

Two classes are provided:

  • ConditionalDenseNN: Takes both an input and a context variable, concatenates them (context first) and maps the result to one or more output parameter groups.
  • DenseNN: A subclass with no context variable (it passes context_dim=0 to ConditionalDenseNN).

Both classes support flexible output parameterization via param_dims, allowing the network to produce multiple parameter tensors of different sizes from a single forward pass.

Code Reference

Class: ConditionalDenseNN

Constructor:

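  • input_dim (int): Dimensionality of the input x.
  • context_dim (int): Dimensionality of the context variable.
  • hidden_dims (list of int): Sizes of the hidden layers; at least one entry is required.
  • param_dims (list of int, default [1, 1]): Sizes of the output parameter groups. The output layer has sum(param_dims) units, which are split into tensors of these sizes.
  • nonlinearity (torch.nn.Module, default torch.nn.ReLU()): Activation applied after each hidden layer but not after the output layer, so the outputs are unbounded real numbers.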
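The constructor arguments fix both the layer sizes and how the output is split. The following is a minimal torch-only sketch, assuming the layout described above: the first linear layer takes input_dim + context_dim features, the last produces sum(param_dims) features, and each parameter group is a contiguous slice of that output. The helper name build_dense_layers is hypothetical and not part of Pyro.

```python
import torch

def build_dense_layers(input_dim, context_dim, hidden_dims, param_dims):
    """Hypothetical helper: build the linear layers and output slices
    of a dense network laid out as described above."""
    if not hidden_dims:
        raise ValueError("hidden_dims needs at least one layer size")
    # First layer sees input and context together; the last layer has one
    # unit per output value across all parameter groups.
    sizes = [input_dim + context_dim, *hidden_dims, sum(param_dims)]
    layers = torch.nn.ModuleList(
        [torch.nn.Linear(n_in, n_out) for n_in, n_out in zip(sizes[:-1], sizes[1:])]
    )
    # One contiguous slice of the output per entry of param_dims
    slices, start = [], 0
    for p in param_dims:
        slices.append(slice(start, start + p))
        start += p
    return layers, slices

layers, slices = build_dense_layers(10, 5, [50, 50], [1, 10, 10])
print([(l.in_features, l.out_features) for l in layers])
# [(15, 50), (50, 50), (50, 21)]
print(slices)
# [slice(0, 1, None), slice(1, 11, None), slice(11, 21, None)]
```

With param_dims=[1, 10, 10] the output layer has 21 units, and the three slices pick out 1, 10 and 10 of them respectively.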

Forward pass:

  1. Expands context to the batch shape of x and concatenates it in front of x along the last dimension.
  2. Passes through hidden layers with nonlinearity.
  3. Passes through output layer (no nonlinearity).
  4. Reshapes output into parameter groups based on param_dims.
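Steps 1 to 3 can be sketched in plain PyTorch as below. This is an illustrative reimplementation, not Pyro's code: dense_forward is a hypothetical name, it takes an already-built list of linear layers, and it stops before step 4, returning the raw output of shape (..., sum(param_dims)). It puts the context in front of x, and a single context vector is shared across the whole batch through expand.

```python
import torch

def dense_forward(layers, nonlinearity, x, context=None):
    """Hypothetical sketch of steps 1-3: returns the raw output of shape
    (..., sum(param_dims)) before it is split into parameter groups."""
    if context is not None:
        # Step 1: expand context over x's batch dims, then put it before x
        context = context.expand(x.shape[:-1] + (context.size(-1),))
        x = torch.cat([context, x], dim=-1)
    h = x
    # Step 2: hidden layers, each followed by the nonlinearity
    for layer in layers[:-1]:
        h = nonlinearity(layer(h))
    # Step 3: output layer with no nonlinearity
    return layers[-1](h)

torch.manual_seed(0)
input_dim, context_dim = 10, 5
layers = torch.nn.ModuleList([
    torch.nn.Linear(input_dim + context_dim, 32),
    torch.nn.Linear(32, 3),  # sum(param_dims) = 3 output units
])
x = torch.rand(100, input_dim)
ctx = torch.rand(context_dim)  # one context shared by the whole batch
h = dense_forward(layers, torch.nn.ReLU(), x, ctx)
print(h.shape)  # torch.Size([100, 3])
```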

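Output shapes, where output_multiplier = sum(param_dims):

  • If output_multiplier == 1 (param_dims == [1]): Returns the raw output tensor of shape (..., 1).
  • If param_dims has a single entry p > 1: Returns one tensor of shape (..., p).
  • If param_dims has several entries, all equal to 1: Returns a tuple of tensors with the last dimension removed, each of shape (...).
  • Otherwise: Returns a tuple of slices of shape (..., p), one for each p in param_dims; size-1 groups keep their trailing dimension.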
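The four output cases can be illustrated with a small shaping function applied to a raw output h of shape (..., output_multiplier), where output_multiplier is sum(param_dims). This is a hedged sketch of the rules listed above rather than Pyro's implementation; shape_output is a hypothetical name, and it returns h unchanged in the first two cases because h already has shape (..., sum(param_dims)).

```python
import torch

def shape_output(h, param_dims):
    """Hypothetical sketch of the four output cases, applied to the raw
    output h of shape (..., sum(param_dims))."""
    if sum(param_dims) == 1:             # output_multiplier == 1
        return h
    if len(param_dims) == 1:             # single parameter group of size > 1
        return h
    if all(p == 1 for p in param_dims):  # every group has size 1
        return torch.unbind(h, dim=-1)   # drops the last dimension
    # Mixed sizes: slice contiguous groups, keeping the last dimension
    out, start = [], 0
    for p in param_dims:
        out.append(h[..., start:start + p])
        start += p
    return tuple(out)

batch = 100
print(shape_output(torch.zeros(batch, 1), [1]).shape)  # torch.Size([100, 1])
print(shape_output(torch.zeros(batch, 4), [4]).shape)  # torch.Size([100, 4])
print([t.shape for t in shape_output(torch.zeros(batch, 2), [1, 1])])
# [torch.Size([100]), torch.Size([100])]
print([t.shape for t in shape_output(torch.zeros(batch, 21), [1, 10, 10])])
# [torch.Size([100, 1]), torch.Size([100, 10]), torch.Size([100, 10])]
```

The last case explains why param_dims=[1, input_dim, input_dim] yields a first parameter of shape (100, 1) rather than (100,).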

Class: DenseNN

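Subclass of ConditionalDenseNN with context_dim=0. Its constructor takes (input_dim, hidden_dims, param_dims=[1, 1], nonlinearity=torch.nn.ReLU()), and its forward method takes only x.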

I/O Contract

ConditionalDenseNN.forward
  • Input: x: Tensor(..., input_dim); context: Tensor(..., context_dim), broadcastable to the batch shape of x
  • Output: a single Tensor or a tuple of Tensors, depending on param_dims

DenseNN.forward
  • Input: x: Tensor(..., input_dim)
  • Output: a single Tensor or a tuple of Tensors, depending on param_dims
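The context only needs to be broadcastable to the batch shape of x, because the forward pass expands it before concatenation. The sketch below demonstrates this rule with plain Tensor.expand and does not call Pyro; expand_context is a hypothetical helper that mirrors the broadcasting in step 1 of the forward pass, and the shapes are arbitrary.

```python
import torch

input_dim, context_dim = 10, 5
x = torch.rand(8, 100, input_dim)  # batch shape (8, 100)

def expand_context(context, x):
    # Match x's batch shape, keeping the context's own last dimension
    return context.expand(x.shape[:-1] + (context.size(-1),))

shared = torch.rand(context_dim)            # one context for every element
per_group = torch.rand(8, 1, context_dim)   # one context per leading index
print(expand_context(shared, x).shape)      # torch.Size([8, 100, 5])
print(expand_context(per_group, x).shape)   # torch.Size([8, 100, 5])

# A batch size of 3 is neither 1 nor 100, so expand cannot broadcast it
try:
    expand_context(torch.rand(3, context_dim), x)
    broadcast_ok = True
except RuntimeError:
    broadcast_ok = False
print(broadcast_ok)  # False
```

Since expand returns a view, sharing one context across a large batch does not copy it.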

Usage Examples

import torch
from pyro.nn.dense_nn import DenseNN, ConditionalDenseNN

# Simple dense network
input_dim = 10
nn = DenseNN(input_dim, [50], param_dims=[1, input_dim, input_dim])
z = torch.rand(100, input_dim)
a, b, c = nn(z)
print(a.shape)  # (100, 1)
print(b.shape)  # (100, 10)
print(c.shape)  # (100, 10)

# Conditional dense network (e.g., for conditional flows)
context_dim = 5
cnn = ConditionalDenseNN(input_dim, context_dim, [50], param_dims=[1, 1])
x = torch.rand(100, input_dim)
ctx = torch.rand(100, context_dim)
m, s = cnn(x, context=ctx)
print(m.shape)  # (100,)
print(s.shape)  # (100,)

# Two scalar outputs per example; param_dims=[1, 1] is also the default
nn2 = DenseNN(5, [32, 32], param_dims=[1, 1])
m, s = nn2(torch.randn(64, 5))  # m.shape == s.shape == (64,)
