

Implementation:Pyro ppl Pyro Pyro Module Primitive



Knowledge Sources
Domains: Deep_Learning, Probabilistic_Programming
Last Updated: 2026-02-09 00:00 GMT

Overview

Concrete tool, provided by the Pyro library, for registering PyTorch neural network modules within Pyro probabilistic programs.

Description

pyro.module registers a torch.nn.Module instance with Pyro's parameter store, so that its parameters are recorded as pyro.param sites and handed to the SVI optimizer. This is essential for deep generative models such as VAEs, where the encoder and decoder networks must be trained jointly during variational inference.

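When called, it registers each of the module's trainable parameters (those with requires_grad=True) in Pyro's global parameter store under the namespace prefix given by name, and returns the nn_module object it was passed (not a copy). If parameters under that name already exist in the store, the module is left unchanged by default; with update_module_params=True, the module's parameters are replaced by the stored values.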

Usage

Use pyro.module inside model and guide functions whenever a neural network's parameters need to be optimized by Pyro's SVI: register the decoder network in the model and the encoder network in the guide. Construct each network once, outside these functions, and register the same instance on every call. For a more modern and more seamless integration with PyTorch, consider pyro.nn.PyroModule.

Code Reference

Source Location

  • Repository: pyro
  • File: pyro/primitives.py
  • Lines: L429-503

Signature

def module(
    name: str,
    nn_module: torch.nn.Module,
    update_module_params: bool = False,
) -> torch.nn.Module:
    """
    Register a torch.nn.Module with Pyro's parameter store.

    Args:
        name: name of the module in the parameter store
        nn_module: the torch.nn.Module to register
        update_module_params: whether to overwrite module params
            with values from the param store
    Returns:
        The same nn_module, registered with Pyro
    """

Import

import pyro
# Used as: pyro.module("decoder", decoder_net)

I/O Contract

Inputs

Name Type Required Description
name str Yes Name prefix for parameter store registration
nn_module torch.nn.Module Yes PyTorch neural network module
update_module_params bool No Whether to overwrite the module's parameters with values already held in the param store under this name (default: False)

Outputs

Name Type Description
return torch.nn.Module The same nn_module object that was passed in (not a copy), with its parameters registered in Pyro's param store

Usage Examples

VAE Model with Decoder

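import pyro
import pyro.distributions as dist
import torch
import torch.nn as nn

class Decoder(nn.Module):
    """Maps a latent code z to per-pixel Bernoulli probabilities."""

    def __init__(self, z_dim, hidden_dim, x_dim):
        super().__init__()
        self.fc1 = nn.Linear(z_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, x_dim)
        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def forward(self, z):
        h = self.relu(self.fc1(z))
        return self.sigmoid(self.fc2(h))

# Build the network once so the same parameters are reused on every call to model.
decoder_net = Decoder(50, 400, 784)

def model(x):
    # x: (batch, 784) tensor of binary pixel values.
    # Register the decoder's parameters with Pyro's param store.
    decoder = pyro.module("decoder", decoder_net)
    with pyro.plate("data", x.shape[0]):
        # Standard normal prior over the 50-dimensional latent code.
        z = pyro.sample("z", dist.Normal(torch.zeros(50), torch.ones(50)).to_event(1))
        loc_img = decoder(z)
        # Bernoulli likelihood over the 784 pixels of each image.
        pyro.sample("obs", dist.Bernoulli(loc_img).to_event(1), obs=x)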

Related Pages

Implements Principle
