Implementation:Pyro ppl Pyro Pyro Module Primitive
| Knowledge Sources | |
|---|---|
| Domains | Deep_Learning, Probabilistic_Programming |
| Last Updated | 2026-02-09 00:00 GMT |
Overview
Concrete tool, provided by the Pyro library, for registering PyTorch neural network modules within Pyro probabilistic programs.
Description
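pyro.module registers a torch.nn.Module instance with Pyro's parameter store. This makes the module's parameters visible to the optimizer used by SVI (stochastic variational inference). It is essential for deep generative models such as variational autoencoders (VAEs), where the encoder and decoder networks must be trained jointly during variational inference.
When called, it registers each of the module's trainable parameters (those with requires_grad=True) in Pyro's global parameter store, prefixing each parameter name with name. It returns the same module object that was passed in. If parameters with those prefixed names already exist in the store (for example, from an earlier call or a loaded checkpoint), the store keeps its existing values. Those values are written back into the module only when update_module_params=True.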
Usage
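Use pyro.module inside model and guide functions when a neural network's parameters need to be optimized by Pyro's SVI: register the decoder network in the model and the encoder network in the guide. Construct each network once, outside these functions, and pass the same instance on every call; otherwise each call builds a freshly initialized network that the optimizer never trains. For newer code, consider pyro.nn.PyroModule, a torch.nn.Module subclass whose attributes can carry Pyro parameters and sample statements directly.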
Code Reference
Source Location
- Repository: pyro
- File: pyro/primitives.py
- Lines: L429-503
Signature
```python
def module(
    name: str,
    nn_module: torch.nn.Module,
    update_module_params: bool = False,
) -> torch.nn.Module:
    """
    Register a torch.nn.Module with Pyro's parameter store.

    Args:
        name: name of the module in the parameter store
        nn_module: the torch.nn.Module to register
        update_module_params: whether to overwrite module params
            with values from the param store

    Returns:
        The same nn_module, registered with Pyro
    """
```
Import
```python
import pyro
# Used as: pyro.module("decoder", decoder_net)
```
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| name | str | Yes | Name prefix for parameter store registration |
| nn_module | torch.nn.Module | Yes | PyTorch neural network module |
| update_module_params | bool | No | Whether to overwrite the module's parameters with values already in the param store (default: False) |
Outputs
| Name | Type | Description |
|---|---|---|
| return | torch.nn.Module | The module object that was passed in (not a copy), with its parameters registered in Pyro's param store |
Usage Examples
VAE Model with Decoder
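```python
import torch
import torch.nn as nn
import pyro
import pyro.distributions as dist


class Decoder(nn.Module):
    """Maps a latent code z to per-pixel Bernoulli probabilities."""

    def __init__(self, z_dim, hidden_dim, x_dim):
        super().__init__()
        self.fc1 = nn.Linear(z_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, x_dim)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def forward(self, z):
        h = self.relu(self.fc1(z))
        return self.sigmoid(self.fc2(h))


# Build the network once, outside the model, so every SVI step
# trains the same parameters instead of a freshly initialized copy.
decoder_net = Decoder(50, 400, 784)


def model(x):
    # Register the decoder's parameters under the "decoder" prefix.
    decoder = pyro.module("decoder", decoder_net)
    with pyro.plate("data", x.shape[0]):
        # Standard normal prior over the 50-dimensional latent code.
        z = pyro.sample("z", dist.Normal(torch.zeros(50), torch.ones(50)).to_event(1))
        loc_img = decoder(z)
        # Each image is treated as a vector of 784 independent Bernoulli pixels.
        pyro.sample("obs", dist.Bernoulli(loc_img).to_event(1), obs=x)
```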