
Implementation:Pyro ppl Pyro Abstract Infer

From Leeroopedia


Overview

The abstract_infer module (Template:Code) defines the foundational base classes for posterior inference in Pyro. It provides the abstract TracePosterior class from which posterior inference algorithms inherit, the deprecated TracePredictive class for posterior predictive sampling, and two classes for working with marginals: EmpiricalMarginal, a marginal distribution over one site (or several sites with the same shape), and Marginals, a container holding marginals over one or more sites.

TracePosterior is the core abstraction. When run, it collects a bag of execution traces from an approximate posterior distribution. Subclasses must implement the Template:Code generator method, which yields Template:Code or Template:Code tuples. After running, the collected traces and weights can be used by utility classes like Template:Code to construct marginal distributions over individual latent variables.

TracePredictive (deprecated in favor of Template:Code) generates traces from the posterior predictive distribution by replaying sampled parameter values from the approximate posterior through the model.

The module also provides WAIC (Widely Applicable Information Criterion) computation through the Template:Code method on TracePosterior.
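For reference, the quantities behind WAIC can be sketched in a few lines of PyTorch. `waic_sketch` is a hypothetical helper that follows the standard definition WAIC = -2 (lppd - p_waic), using self-normalized importance weights and a biased weighted variance for p_waic; Pyro's own estimator may differ in details such as the variance formula.

```python
import torch


def waic_sketch(log_lik, log_weights):
    """Hypothetical helper: weighted WAIC from a log-likelihood matrix.

    log_lik:     (S, N) tensor, log p(y_i | theta_s) for S posterior
                 samples and N data points.
    log_weights: (S,) tensor of unnormalized log importance weights.
    Returns (waic, p_waic), each summed over the N data points.
    """
    # Self-normalize the weights over samples, kept in log space: (S, 1)
    log_w = torch.log_softmax(log_weights, dim=0).unsqueeze(-1)
    w = log_w.exp()
    # Log pointwise predictive density: log sum_s w_s p(y_i | theta_s)
    lppd = torch.logsumexp(log_lik + log_w, dim=0)        # (N,)
    # Effective number of parameters: weighted (biased) variance of
    # the log-likelihood across samples, per data point
    mean = (w * log_lik).sum(dim=0)                       # (N,)
    p_waic = (w * (log_lik - mean) ** 2).sum(dim=0)       # (N,)
    # WAIC on the deviance scale
    waic = -2.0 * (lppd - p_waic)
    return waic.sum(), p_waic.sum()
```

With equal weights and a log-likelihood that does not vary across samples, p_waic is zero and WAIC reduces to -2 times the summed log-likelihood.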

Code Reference

File: Template:Code

Key Classes

Class Parent Description
Template:Code Template:Code (ABCMeta) Abstract base class for posterior inference. Collects execution traces from an approximate posterior.
Template:Code Template:Code Deprecated. Generates traces from the posterior predictive distribution. Use Template:Code instead.
Template:Code Template:Code Marginal distribution over a single site (or multiple sites with the same shape) from a TracePosterior.
Template:Code Template:Code Container holding marginal distributions over one or more sites from a TracePosterior.

TracePosterior Methods

Method Description
Template:Code Initializes with optional number of chains. Calls Template:Code to prepare internal state.
Template:Code Abstract. Must be implemented by subclasses. Generator yielding Template:Code or Template:Code.
Template:Code Calls Template:Code to populate execution traces, log_weights, and chain_ids. Builds a Categorical distribution over the collected traces, with the log weights as logits, for sampling.
Template:Code Draws a random trace from the collected traces, resampling in proportion to their importance weights. Removes observation nodes from the returned trace.
Template:Code Returns a Template:Code instance for the specified sites.
Template:Code Computes WAIC and effective number of parameters. Only supports models with a single observation node.

I/O Contract

TracePosterior

Inputs (via Template:Code and Template:Code):

  • Template:Code -- Arbitrary arguments passed through to the model/guide execution.

Internal State:

Outputs:

EmpiricalMarginal

Inputs:

Outputs:

  • An Template:Code distribution whose samples and weights are extracted from the trace posterior.

TracePredictive (Deprecated)

Inputs:

Outputs:

  • Generates Template:Code traces by replaying the model with posterior samples.

Usage Examples

Basic Importance Sampling with TracePosterior

import torch
import pyro
import pyro.distributions as dist
from pyro.infer import Importance, EmpiricalMarginal

def model():
    mu = pyro.sample("mu", dist.Normal(0, 1))
    # Condition on a single observed value
    pyro.sample("obs", dist.Normal(mu, 0.5), obs=torch.tensor(2.0))

def guide():
    # Proposal distribution for importance sampling
    loc = pyro.param("loc", torch.tensor(0.0))
    pyro.sample("mu", dist.Normal(loc, 1))

# Importance inherits from TracePosterior
posterior = Importance(model, guide, num_samples=1000)
posterior.run()

# Extract marginal distribution
marginal = EmpiricalMarginal(posterior, sites="mu")
# The exact posterior of this conjugate model has mean 1.6 and
# variance 0.2, so the estimates below should be close to those values
print("Posterior mean:", marginal.mean)
print("Posterior variance:", marginal.variance)

Using Marginals Container

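from pyro.infer.abstract_infer import Marginals

# Continues the example above: reuses the `posterior` that was run
marginals = posterior.marginal(sites=["mu"])
assert isinstance(marginals, Marginals)
support = marginals.support()          # site name -> tensor of collected samples
empirical_dists = marginals.empirical  # site name -> EmpiricalMarginal
print("Posterior mean:", empirical_dists["mu"].mean)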

Computing WAIC

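# Reuses model and guide from the first example; its single observed
# site ("obs") meets the one-observation-node requirement
posterior = Importance(model, guide, num_samples=500)
posterior.run()

# pointwise=False sums WAIC and p_waic over data points
ic = posterior.information_criterion(pointwise=False)
print("WAIC:", ic["waic"])
print("Effective parameters:", ic["p_waic"])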
