
Principle:Pyro ppl Pyro Bayesian Module Integration

From Leeroopedia


Knowledge Sources
Domains Deep_Learning, Probabilistic_Programming, Bayesian_Inference
Last Updated 2026-02-09 00:00 GMT

Overview

A principle for seamlessly integrating probabilistic programming with PyTorch's nn.Module system, enabling neural network modules whose parameters can be treated as random variables with prior distributions.

Description

PyTorch's nn.Module is the standard abstraction for building neural networks, providing parameter management, serialization, device placement, and composability. However, standard nn.Module parameters are deterministic point estimates. Bayesian module integration bridges the gap between deterministic deep learning and probabilistic programming by allowing module attributes to be either learnable constrained parameters or random variables sampled from prior distributions.

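This integration is achieved through PyroModule, a subclass of nn.Module that overrides attribute access (__getattr__ and __setattr__) so that certain attribute values carry probabilistic meaning. A custom metaclass additionally provides the PyroModule[nn.Linear] syntax for making existing PyTorch module classes Pyro-aware. When a module attribute is set to a PyroParam, it behaves like a standard nn.Parameter but supports constraints: the value is stored as an unconstrained nn.Parameter (named <name>_unconstrained), exposed in constrained form, and registered with Pyro's parameter store. When it is set to a PyroSample, the attribute becomes a random variable: each call to the module (one forward pass) draws a fresh sample from the specified prior, and repeated reads of the attribute within that same call return the same cached sample.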

This approach enables several powerful modeling patterns:

  • Bayesian neural networks: Place priors on all weights and biases to get full posterior uncertainty over network parameters. Each forward pass samples a different set of weights, naturally producing predictive uncertainty.
  • Bayesian layers mixed with point-estimated layers: Some layers can use PyroSample (Bayesian) while others use standard nn.Parameter (point-estimated), allowing flexible regularization.
  • PyTorch ecosystem compatibility: Since PyroModule inherits from nn.Module, standard PyTorch tools for serialization (state_dict), device management (.to(device)), and module composition (nested modules) work as expected.

PyroModule is the modern alternative to the older pyro.module() primitive. While pyro.module() requires explicit registration calls inside model functions, PyroModule provides a more natural, object-oriented interface where probabilistic semantics are defined at module construction time.

Two key attribute types enable this integration:

  • PyroParam(init_value, constraint, event_dim): Declares a learnable parameter with an optional constraint (e.g., positivity). Managed by Pyro's parameter store.
  • PyroSample(prior): Declares a random variable. The prior can be a distribution, or a callable that receives the module (self) and returns a distribution, so one attribute's prior can depend on other attributes (enabling hierarchical priors).

Usage

Use this principle when building models that combine deep learning architectures with Bayesian uncertainty quantification. It is particularly valuable for Bayesian neural networks, where you want posterior distributions over weights rather than point estimates. Also use it when you need better integration between Pyro models and the PyTorch ecosystem (e.g., using PyTorch Lightning, model serialization, or GPU management). Prefer PyroModule over pyro.module() for new code.

Theoretical Basis

A Bayesian neural network places prior distributions over the network's weights 𝐖:

$$p(\mathbf{W}) = \prod_{l} p(\mathbf{W}_l)$$

where each layer's weights $\mathbf{W}_l$ have independent priors (commonly $\mathcal{N}(0, \sigma^2)$). Given data $\mathcal{D} = \{(\mathbf{x}_i, y_i)\}$, the posterior over weights is:

$$p(\mathbf{W} \mid \mathcal{D}) \propto p(\mathcal{D} \mid \mathbf{W})\, p(\mathbf{W})$$

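This posterior is intractable for neural networks because the normalizing constant $p(\mathcal{D})$ requires integrating over all weights, so we approximate it using variational inference. A variational distribution $q_\phi(\mathbf{W})$ with parameters $\phi$ is optimized to minimize: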

$$\mathrm{KL}\big(q_\phi(\mathbf{W}) \,\|\, p(\mathbf{W} \mid \mathcal{D})\big)$$
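This KL divergence cannot be evaluated directly because it contains the unknown posterior. The standard decomposition below shows why minimizing it is equivalent to maximizing the evidence lower bound (ELBO): the left-hand side does not depend on φ, and the last KL term is non-negative. Pyro's Trace_ELBO loss is a Monte Carlo estimate of the negative ELBO.

```latex
\log p(\mathcal{D})
  = \underbrace{\mathbb{E}_{q_\phi(\mathbf{W})}\!\left[\log p(\mathcal{D} \mid \mathbf{W})\right]
    - \mathrm{KL}\!\left(q_\phi(\mathbf{W}) \,\|\, p(\mathbf{W})\right)}_{\mathrm{ELBO}(\phi)}
  + \mathrm{KL}\!\left(q_\phi(\mathbf{W}) \,\|\, p(\mathbf{W} \mid \mathcal{D})\right)
```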

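Example:

```python
# Bayesian module integration pattern
import pyro.distributions as dist
from pyro.nn import PyroModule, PyroSample


class BayesianLayer(PyroModule):
    def __init__(self, in_features, out_features):
        super().__init__()
        # Weight is a random variable, not a fixed parameter:
        # independent Normal(0, 1) priors on each entry
        self.weight = PyroSample(
            dist.Normal(0.0, 1.0).expand([out_features, in_features]).to_event(2))
        # Bias is also a random variable, with a wider Normal(0, 10) prior
        self.bias = PyroSample(dist.Normal(0.0, 10.0).expand([out_features]).to_event(1))

    def forward(self, x):
        # Each call to the layer draws one weight and bias sample;
        # repeated reads of self.weight within the same call reuse it
        return x @ self.weight.T + self.bias
```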

Predictive uncertainty is obtained by averaging predictions over multiple posterior weight samples:

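$$p(y^* \mid \mathbf{x}^*, \mathcal{D}) \approx \int p(y^* \mid \mathbf{x}^*, \mathbf{W})\, q_\phi(\mathbf{W})\, d\mathbf{W} \approx \frac{1}{S} \sum_{s=1}^{S} p(y^* \mid \mathbf{x}^*, \mathbf{W}^{(s)}), \qquad \mathbf{W}^{(s)} \sim q_\phi(\mathbf{W})$$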

Related Pages

Implemented By
