
Implementation:Pyro ppl Pyro PyroModule Class

From Leeroopedia


Knowledge Sources
Domains Deep_Learning, Probabilistic_Programming, Bayesian_Inference
Last Updated 2026-02-09 00:00 GMT

Overview

Concrete class that extends torch.nn.Module with Pyro probabilistic semantics, allowing module attributes to be learnable constrained parameters or random variables with prior distributions.

Description

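PyroModule is a subclass of nn.Module that overrides __getattr__, __setattr__, and __delattr__ to intercept attribute access. Its custom metaclass (_PyroModuleMeta) provides the PyroModule[SomeModule] indexing syntax, which creates a Pyro-enabled subclass of an existing nn.Module class. Attribute interception enables two special attribute types: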

  • PyroParam(init_value, constraint, event_dim): A learnable parameter with optional constraint enforcement. The module stores the unconstrained value as an nn.Parameter named <attribute>_unconstrained and applies the constraint's transform each time the attribute is accessed. In the default mode, the parameter is also registered in Pyro's global parameter store when it is accessed inside a forward call.
  • PyroSample(prior): A random variable. During a forward call, the first access to the attribute draws a sample from the prior and caches it, so repeated accesses within the same call (including from nested PyroModules) return the same value; the next call draws a new sample. The prior can be a Distribution object or a callable that takes the module as input and returns a distribution, which enables hierarchical priors that depend on other module attributes.

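When a PyroSample attribute is accessed during a forward call, PyroModule calls pyro.sample with a site name built from the module's dotted name prefix and the attribute name: a top-level attribute sigma becomes the site "sigma", and the weight of a child module stored as self.linear becomes "linear.weight". As a result, every sample site is named automatically and recorded by Pyro's effect handler system.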

PyroModule supports module_local_params mode (via pyro.settings.set(module_local_params=True)), which stores parameters locally within the module rather than in Pyro's global parameter store. This is recommended for most new code as it avoids global state and enables multiple independent model instances.

Usage

Use PyroModule as the base class for any neural network component that needs Bayesian treatment of its parameters. Subclass PyroModule instead of nn.Module, then set weight and bias attributes to PyroSample objects to create Bayesian layers. Use PyroParam for constrained parameters that should be optimized (not sampled). Combine with AutoNormal or other auto-guides for variational inference over the sampled parameters.

Code Reference

Source Location

  • Repository: pyro
  • File: pyro/nn/module.py
  • Lines: L339-473 (main class definition)

Signature

class PyroModule(torch.nn.Module, metaclass=_PyroModuleMeta):
    """
    Subclass of torch.nn.Module whose attributes can be modified by
    Pyro effects. Attributes can be set to PyroParam or PyroSample
    objects to create learnable parameters or random variables.

    Args:
        name: Optional name prefix for Pyro parameters and sample sites.
    """
    def __init__(self, name: str = ""):
        ...

Import

from pyro.nn import PyroModule, PyroParam, PyroSample

Helper Classes

class PyroParam(tuple):
    """
    Declares a learnable constrained parameter.

    Args:
        init_value: Initial parameter value (Tensor)
        constraint: torch.distributions.constraints object (default: real)
        event_dim: Number of rightmost dimensions treated as event dimensions (default: None)
    """

class PyroSample(tuple):
    """
    Declares a random variable with a prior distribution.

    Args:
        prior: A Distribution object or callable(module) -> Distribution
    """

I/O Contract

Inputs

| Name | Type | Required | Description |
|------|------|----------|-------------|
| name | str | No | Name prefix for Pyro parameter and sample site names (default: "") |
| Attributes set to PyroParam | PyroParam | No | Learnable constrained parameters: PyroParam(init_value, constraint, event_dim) |
| Attributes set to PyroSample | PyroSample | No | Random variables with priors: PyroSample(prior_distribution) |

Outputs

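| Name | Type | Description |
|------|------|-------------|
| instance | PyroModule (nn.Module subclass) | A module with Pyro probabilistic semantics; during a forward call, attribute access triggers pyro.sample statements and, in the default global-parameter-store mode, pyro.param statements |
| PyroParam attribute access | torch.Tensor | Returns the constrained parameter value |
| PyroSample attribute access | torch.Tensor | Returns the value at the corresponding sample site: a prior draw made once per forward call and reused for repeated accesses within that call (an active effect handler, such as guide replay, may supply the value instead) |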

Usage Examples

Bayesian Linear Layer

import pyro
import pyro.distributions as dist
from pyro.nn import PyroModule, PyroSample

class BayesianLinear(PyroModule):
    def __init__(self, in_features, out_features):
        super().__init__()
        # Weight prior: independent Normal for each element
        self.weight = PyroSample(
            dist.Normal(0., 1.).expand([out_features, in_features]).to_event(2)
        )
        # Bias prior: independent Normal for each output
        self.bias = PyroSample(
            dist.Normal(0., 10.).expand([out_features]).to_event(1)
        )

    def forward(self, x):
        # self.weight and self.bias are freshly sampled each forward pass
        return x @ self.weight.T + self.bias

Bayesian Regression Model

import torch
import pyro
import pyro.distributions as dist
from pyro.nn import PyroModule, PyroSample

class BayesianRegression(PyroModule):
    def __init__(self, in_features):
        super().__init__()
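        # BayesianLinear is the layer defined in the previous example
        self.linear = BayesianLinear(in_features, 1)
        # Observation noise scale: a random variable with a HalfNormal prior
        # (positive support); it is sampled, not optimized directly
        self.sigma = PyroSample(dist.HalfNormal(1.))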

    def forward(self, x, y=None):
        mu = self.linear(x).squeeze(-1)
        with pyro.plate("data", x.shape[0]):
            obs = pyro.sample("obs", dist.Normal(mu, self.sigma), obs=y)
        return mu

Training with AutoNormal Guide

import torch
import pyro
from pyro.infer import SVI, Trace_ELBO
from pyro.infer.autoguide import AutoNormal

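# Create model and guide (BayesianRegression is defined in the previous example)
pyro.clear_param_store()  # start from an empty global parameter store
model = BayesianRegression(in_features=3)
guide = AutoNormal(model)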

# Set up SVI
optimizer = pyro.optim.Adam({"lr": 0.01})
svi = SVI(model, guide, optimizer, loss=Trace_ELBO())

# Training loop
x_train = torch.randn(100, 3)
y_train = x_train @ torch.tensor([1.5, -0.3, 0.7]) + 0.1 * torch.randn(100)

for step in range(1000):
    loss = svi.step(x_train, y_train)
    if step % 200 == 0:
        print(f"Step {step}: loss = {loss:.4f}")

Using PyroParam for Constrained Parameters

import torch
import pyro.distributions.constraints as constraints
from pyro.nn import PyroModule, PyroParam

class ConstrainedModule(PyroModule):
    def __init__(self):
        super().__init__()
        # Positive parameter (e.g., length scale for a kernel)
        self.length_scale = PyroParam(
            torch.tensor(1.0),
            constraint=constraints.positive
        )
        # Parameter on the unit interval
        self.mixing_weight = PyroParam(
            torch.tensor(0.5),
            constraint=constraints.unit_interval
        )

Converting Existing nn.Module to Bayesian

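import torch
import torch.nn as nn
import pyro.distributions as dist
from pyro.nn import PyroSample
from pyro.nn.module import to_pyro_module_

# Start with a standard PyTorch model
class DeterministicNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(10, 50)
        self.fc2 = nn.Linear(50, 1)

    def forward(self, x):
        return self.fc2(torch.relu(self.fc1(x)))

model = DeterministicNet()

# Convert the model and all of its submodules to PyroModules in place.
# PyroModule[DeterministicNet]() would convert only the outer class; its
# children would stay plain nn.Linear layers, which raise a TypeError when a
# PyroSample is assigned to a registered parameter name such as "weight".
to_pyro_module_(model)

# Set priors on specific layers; fc2 keeps its ordinary learnable parameters
model.fc1.weight = PyroSample(
    dist.Normal(0., 1.).expand([50, 10]).to_event(2)
)
model.fc1.bias = PyroSample(
    dist.Normal(0., 10.).expand([50]).to_event(1)
)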

Related Pages

Implements Principle
