Implementation:Pyro ppl Pyro AutoNormal Guide
Metadata
| Field | Value |
|---|---|
| Sources | Repo: Pyro |
| Domains | Bayesian_Inference, Variational_Inference |
| Updated | 2026-02-09 |
| Type | Class |
Overview
AutoNormal is an automatic guide that constructs an independent (diagonal) Normal variational distribution for each latent site in a Pyro model, implementing mean-field variational inference. It automatically discovers latent variables from the model structure and creates separate learnable location and scale parameters per site.
Description
The AutoNormal class inherits from AutoGuide and provides a mean-field Normal variational family. When first called, it traces the model to discover all stochastic latent sites, then creates two PyroParam parameters for each site: a location (loc) in unconstrained space and a scale (scale) constrained to be positive via softplus_positive. During each forward pass, it samples from independent Normal distributions in unconstrained space and transforms the samples back to constrained space using the appropriate bijective transform.
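The per-site computation described above can be sketched in plain Python. This is an illustration of the idea only, not Pyro's implementation: the helper names (`softplus`, `sample_site`), the site values, and the choice of `exp` as the bijection for a positive-support site are assumptions made for the example.

```python
import math
import random


def softplus(u):
    # Numerically stable log(1 + exp(u)); the result is always positive.
    return max(u, 0.0) + math.log1p(math.exp(-abs(u)))


def sample_site(loc, unconstrained_scale, to_constrained, rng):
    """Draw one sample for a single latent site (illustrative only)."""
    scale = softplus(unconstrained_scale)  # keep the scale positive
    z = rng.gauss(loc, scale)              # Normal draw in unconstrained space
    return to_constrained(z)               # map back to the site's support


rng = random.Random(0)

# Hypothetical sites: "loc" has real support (identity map) and
# "scale" has positive support (exp map assumed for illustration).
sites = {
    "loc": (3.0, -2.0, lambda z: z),
    "scale": (0.0, -2.0, math.exp),
}

# Each site is sampled independently, which is the mean-field assumption.
samples = {
    name: sample_site(loc, u, transform, rng)
    for name, (loc, u, transform) in sites.items()
}
print(samples)
```

Because every site is drawn from its own Normal distribution, no correlation between sites can be represented in this family.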
Unlike AutoDiagonalNormal, which flattens all latent variables into a single vector, AutoNormal maintains separate named Normal distributions per latent site. This provides more convenient site-level access to parameters and better support for TraceMeanField_ELBO.
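The difference between the two layouts can be illustrated with a small plain-Python sketch. The site names and sizes below are hypothetical, and the code shows only the bookkeeping involved, not how either guide is implemented in Pyro.

```python
# Hypothetical latent sites and their flattened sizes.
site_sizes = {"loc": 1, "scale": 1, "weights": 3}

# Per-site layout: parameters are looked up directly by site name.
per_site_loc = {name: [0.0] * size for name, size in site_sizes.items()}

# Flattened layout: one vector, plus offsets needed to locate each site.
flat_loc = [0.0] * sum(site_sizes.values())
offsets = {}
start = 0
for name, size in site_sizes.items():
    offsets[name] = (start, start + size)
    start += size


def unpack(flat, name):
    """Slice one site's entries back out of the flattened vector."""
    lo, hi = offsets[name]
    return flat[lo:hi]


print(per_site_loc["weights"])      # direct access by name
print(unpack(flat_loc, "weights"))  # access through offsets
```

With the per-site layout, inspecting or constraining a single site needs no offset arithmetic, which is the convenience the paragraph above refers to.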
The guide handles plates (batched latent variables) automatically by repeating initial values to full plate sizes when subsampling is used. Scale parameters are initialized to a configurable init_scale value (default 0.1) and constrained to remain positive throughout training.
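As a rough sketch of the plate handling, assume the initial location values are computed for a subsample and then tiled periodically up to the full plate size, with every scale entry set to `init_scale`. The function name `repeat_to_full_size` and all numbers are hypothetical; the snippet uses plain Python lists rather than tensors.

```python
from itertools import cycle, islice


def repeat_to_full_size(values, full_size):
    """Tile a subsample-sized list periodically up to full_size entries."""
    return list(islice(cycle(values), full_size))


init_scale = 0.1                       # default scale from the signature
subsample_init_loc = [0.5, -1.0, 2.0]  # hypothetical values for 3 subsampled rows
full_size = 8                          # hypothetical full plate size

# Location parameters cover the full plate, not only the subsample.
init_loc = repeat_to_full_size(subsample_init_loc, full_size)

# Scale parameters have the same shape, filled with init_scale.
init_scales = [init_scale] * len(init_loc)

print(init_loc)
print(init_scales)
```

Sizing the parameters to the full plate means each data row keeps its own variational parameters even though only a subsample is visited per step.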
Usage
Import this class when you need a simple, scalable variational approximation for a Pyro model. It is a reasonable default guide for models where posterior correlations between latent sites are not critical, since a mean-field family cannot represent them.
Code Reference
Source Location
- Repository: pyro-ppl/pyro
- File: pyro/infer/autoguide/guides.py
- Lines: L415--553
Signature
class AutoNormal(AutoGuide):
    scale_constraint = constraints.softplus_positive

    def __init__(
        self, model, *, init_loc_fn=init_to_feasible, init_scale=0.1, create_plates=None
    ):
Import
from pyro.infer.autoguide import AutoNormal
I/O Contract
Inputs
| Parameter | Type | Required | Description |
|---|---|---|---|
| model | callable | Yes | A Pyro model function containing pyro.sample statements for latent variables |
| init_loc_fn | callable | No | Per-site initialization function for location parameters (default: init_to_feasible) |
| init_scale | float | No | Initial scale for the standard deviation of each latent variable in unconstrained space (default: 0.1, must be > 0) |
| create_plates | callable or None | No | Optional function returning pyro.plate contexts for data subsampling |
Outputs
| Name | Type | Description |
|---|---|---|
| return value | dict | Dictionary mapping sample site names (str) to sampled latent values (Tensor), transformed to the constrained space of each site's distribution support |
Internal Parameters Created
| Attribute | Type | Description |
|---|---|---|
| self.locs | PyroModule | Container of per-site location parameters (PyroParam with constraints.real) |
| self.scales | PyroModule | Container of per-site scale parameters (PyroParam with softplus_positive constraint) |
| self._event_dims | dict | Dictionary mapping site names to unconstrained event dimensions |
Usage Examples
Basic Usage with SVI
import torch
import pyro
import pyro.distributions as dist
from pyro.infer import SVI, Trace_ELBO
from pyro.infer.autoguide import AutoNormal

# Define a simple Bayesian model
def model(data):
    loc = pyro.sample("loc", dist.Normal(0.0, 10.0))
    scale = pyro.sample("scale", dist.LogNormal(0.0, 1.0))
    with pyro.plate("data", len(data)):
        pyro.sample("obs", dist.Normal(loc, scale), obs=data)

# Construct the mean-field guide automatically
guide = AutoNormal(model)

# Set up stochastic variational inference
optimizer = pyro.optim.Adam({"lr": 0.01})
svi = SVI(model, guide, optimizer, loss=Trace_ELBO())

# Training loop on synthetic data centered near 3.0
data = torch.randn(100) + 3.0
for step in range(1000):
    loss = svi.step(data)
Accessing Learned Parameters
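# After training, access the variational parameters.
# Note: loc and scale describe the Normal in unconstrained space, and
# _get_loc_and_scale is a private helper, so it may change between versions.
# .item() works here because every latent site in this model is a scalar.
for name, site in guide.prototype_trace.iter_stochastic_nodes():
    loc, scale = guide._get_loc_and_scale(name)
    print(f"{name}: loc={loc.item():.3f}, scale={scale.item():.3f}")

# Get the posterior median (in the constrained space of each site)
median = guide.median(data)
print(median)  # e.g. {'loc': tensor(3.01), 'scale': tensor(1.02)}; exact values vary by run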
Custom Initialization
from pyro.infer.autoguide import AutoNormal
from pyro.infer.autoguide.initialization import init_to_sample
# Use sample-based initialization with larger initial scale
guide = AutoNormal(model, init_loc_fn=init_to_sample, init_scale=0.5)