Implementation: Pyro (pyro-ppl) PyroModule Class
| Knowledge Sources | |
|---|---|
| Domains | Deep_Learning, Probabilistic_Programming, Bayesian_Inference |
| Last Updated | 2026-02-09 00:00 GMT |
Overview
Concrete class that extends torch.nn.Module with Pyro probabilistic semantics, allowing module attributes to be learnable constrained parameters or random variables with prior distributions.
Description
PyroModule is a subclass of nn.Module that overrides __getattr__, __setattr__, and __delattr__ to intercept attribute access; its metaclass (_PyroModuleMeta) additionally provides the PyroModule[Cls] bracket syntax for deriving Pyro-aware variants of existing nn.Module subclasses. This interception enables two special attribute types:
- PyroParam(init_value, constraint, event_dim): A learnable parameter with optional constraint enforcement. Internally managed via Pyro's parameter store, the unconstrained representation is stored and the constraint transform is applied on access.
- PyroSample(prior): A random variable. Each time the attribute is accessed during a forward pass, a new sample is drawn from the specified prior distribution. The prior can be a Distribution object or a callable that takes the module as input and returns a distribution (enabling hierarchical priors where the prior depends on other module attributes).
When PyroSample attributes are accessed, the class internally calls pyro.sample with a site name derived from the module's name prefix and the attribute name. This means all sample sites are automatically named and traced by Pyro's effect handler system.
PyroModule supports module_local_params mode (via pyro.settings.set(module_local_params=True)), which stores parameters locally within the module rather than in Pyro's global parameter store. This is recommended for most new code as it avoids global state and enables multiple independent model instances.
Usage
Use PyroModule as the base class for any neural network component that needs Bayesian treatment of its parameters. Subclass PyroModule instead of nn.Module, then set weight and bias attributes to PyroSample objects to create Bayesian layers. Use PyroParam for constrained parameters that should be optimized (not sampled). Combine with AutoNormal or other auto-guides for variational inference over the sampled parameters.
Code Reference
Source Location
- Repository: pyro
- File: pyro/nn/module.py
- Lines: L339-473 (main class definition)
Signature
class PyroModule(torch.nn.Module, metaclass=_PyroModuleMeta):
    """
    Subclass of torch.nn.Module whose attributes can be modified by
    Pyro effects. Attributes can be set to PyroParam or PyroSample
    objects to create learnable parameters or random variables.

    Args:
        name: Optional name prefix for Pyro parameters and sample sites.
    """

    def __init__(self, name: str = ""):
        ...
Import
from pyro.nn import PyroModule, PyroParam, PyroSample
Helper Classes
class PyroParam(tuple):
    """
    Declares a learnable constrained parameter.

    Args:
        init_value: Initial parameter value (Tensor)
        constraint: torch.distributions.constraints object (default: real)
        event_dim: Number of rightmost dimensions as event dims
            (default: None)
    """

class PyroSample(tuple):
    """
    Declares a random variable with a prior distribution.

    Args:
        prior: A Distribution object or callable(module) -> Distribution
    """
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| name | str | No | Name prefix for Pyro parameter and sample site names (default: "") |
| attributes set as PyroParam | PyroParam | No | Learnable constrained parameters: PyroParam(init_value, constraint, event_dim) |
| attributes set as PyroSample | PyroSample | No | Random variables with priors: PyroSample(prior_distribution) |
Outputs
| Name | Type | Description |
|---|---|---|
| instance | PyroModule (nn.Module subclass) | A module with Pyro probabilistic semantics; attribute access triggers pyro.sample or pyro.param calls as appropriate |
| PyroParam attribute access | torch.Tensor | Returns the constrained parameter value |
| PyroSample attribute access | torch.Tensor | Returns a fresh sample from the prior distribution |
Usage Examples
Bayesian Linear Layer
import pyro
import pyro.distributions as dist
from pyro.nn import PyroModule, PyroSample
class BayesianLinear(PyroModule):
    def __init__(self, in_features, out_features):
        super().__init__()
        # Weight prior: independent Normal for each element
        self.weight = PyroSample(
            dist.Normal(0., 1.).expand([out_features, in_features]).to_event(2)
        )
        # Bias prior: independent Normal for each output
        self.bias = PyroSample(
            dist.Normal(0., 10.).expand([out_features]).to_event(1)
        )

    def forward(self, x):
        # self.weight and self.bias are freshly sampled each forward pass
        return x @ self.weight.T + self.bias
Bayesian Regression Model
import torch
import pyro
import pyro.distributions as dist
from pyro.nn import PyroModule, PyroSample
class BayesianRegression(PyroModule):
    def __init__(self, in_features):
        super().__init__()
        self.linear = BayesianLinear(in_features, 1)
        # Noise scale with a HalfNormal prior (sampled, not optimized)
        self.sigma = PyroSample(dist.HalfNormal(1.))

    def forward(self, x, y=None):
        mu = self.linear(x).squeeze(-1)
        # Sample sigma outside the plate so it is a single global latent,
        # not one noise scale per data point
        sigma = self.sigma
        with pyro.plate("data", x.shape[0]):
            pyro.sample("obs", dist.Normal(mu, sigma), obs=y)
        return mu
Training with AutoNormal Guide
import torch
import pyro
from pyro.infer import SVI, Trace_ELBO
from pyro.infer.autoguide import AutoNormal
# Create model and guide
model = BayesianRegression(in_features=3)
guide = AutoNormal(model)
# Set up SVI
optimizer = pyro.optim.Adam({"lr": 0.01})
svi = SVI(model, guide, optimizer, loss=Trace_ELBO())
# Training loop
x_train = torch.randn(100, 3)
y_train = x_train @ torch.tensor([1.5, -0.3, 0.7]) + 0.1 * torch.randn(100)
for step in range(1000):
    loss = svi.step(x_train, y_train)
    if step % 200 == 0:
        print(f"Step {step}: loss = {loss:.4f}")
Using PyroParam for Constrained Parameters
import torch
import pyro.distributions.constraints as constraints
from pyro.nn import PyroModule, PyroParam
class ConstrainedModule(PyroModule):
    def __init__(self):
        super().__init__()
        # Positive parameter (e.g., length scale for a kernel)
        self.length_scale = PyroParam(
            torch.tensor(1.0),
            constraint=constraints.positive
        )
        # Parameter on the unit interval
        self.mixing_weight = PyroParam(
            torch.tensor(0.5),
            constraint=constraints.unit_interval
        )
Converting Existing nn.Module to Bayesian
import torch
import torch.nn as nn
import pyro
import pyro.distributions as dist
from pyro.nn import PyroModule, PyroSample
from pyro.nn.module import to_pyro_module_

# Start with a standard PyTorch model
class DeterministicNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(10, 50)
        self.fc2 = nn.Linear(50, 1)

    def forward(self, x):
        return self.fc2(torch.relu(self.fc1(x)))

# Convert in place: to_pyro_module_ recursively turns the model and all
# of its submodules into PyroModules, so their attributes can accept
# PyroSample priors. (PyroModule[DeterministicNet]() would convert only
# the top-level class, leaving fc1 a plain nn.Linear that raises a
# TypeError on PyroSample assignment to its weight.)
model = DeterministicNet()
to_pyro_module_(model)

# Set priors on specific layers
model.fc1.weight = PyroSample(
    dist.Normal(0., 1.).expand([50, 10]).to_event(2)
)
model.fc1.bias = PyroSample(
    dist.Normal(0., 10.).expand([50]).to_event(1)
)