Implementation:Pyro ppl Pyro Abstract Infer
Overview
The abstract_infer module (Template:Code) defines the foundational base classes for posterior inference in Pyro. It provides the abstract TracePosterior class from which posterior inference algorithms inherit, the deprecated TracePredictive class for posterior predictive sampling, and two marginal distribution containers: EmpiricalMarginal and Marginals.
TracePosterior is the core abstraction. When run, it collects a bag of execution traces from an approximate posterior distribution. Subclasses must implement the Template:Code generator method, which yields Template:Code or Template:Code tuples. After running, the collected traces and weights can be used by utility classes like Template:Code to construct marginal distributions over individual latent variables.
TracePredictive (deprecated in favor of Template:Code) generates traces from the posterior predictive distribution by replaying sampled parameter values from the approximate posterior through the model.
The module also provides WAIC (Widely Applicable Information Criterion) computation through the Template:Code method on TracePosterior.
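For reference, here is the textbook definition sketched in plain Python, with no Pyro dependency:

WAIC = -2 (lppd - p_waic)

- lppd = sum_i log((1/S) sum_s p(y_i | theta_s)) over S posterior draws.
- p_waic = sum_i Var_s[log p(y_i | theta_s)].

This sketch assumes equally weighted draws. Pyro's method also passes the collected log-weights to its WAIC computation, so its numbers are not guaranteed to match this version exactly.

```python
import math
import statistics


def waic(log_lik):
    """Textbook WAIC from an S x N matrix of log-likelihoods.

    log_lik[s][i] = log p(y_i | theta_s) for posterior draw s and data point i.
    All draws are weighted equally (an assumption of this sketch).
    Returns (waic, p_waic).
    """
    num_draws = len(log_lik)
    num_points = len(log_lik[0])
    lppd = 0.0
    p_waic = 0.0
    for i in range(num_points):
        column = [log_lik[s][i] for s in range(num_draws)]
        # Log of the mean likelihood, computed stably via log-sum-exp.
        m = max(column)
        lppd += m + math.log(sum(math.exp(v - m) for v in column) / num_draws)
        # Sample variance of the log-likelihood across draws.
        p_waic += statistics.variance(column)
    return -2.0 * (lppd - p_waic), p_waic
```

When the log-likelihood does not vary across draws, p_waic is zero and WAIC reduces to -2 times the summed log-likelihood.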
Code Reference
File: Template:Code
Key Classes
| Class | Parent | Description |
|---|---|---|
| Template:Code | Template:Code (ABCMeta) | Abstract base class for posterior inference. Collects execution traces from an approximate posterior. |
| Template:Code | Template:Code | Deprecated. Generates traces from the posterior predictive distribution. Use Template:Code instead. |
| Template:Code | Template:Code | Marginal distribution over a single site (or multiple sites with the same shape) from a TracePosterior. |
| Template:Code | Template:Code | Container holding marginal distributions over one or more sites from a TracePosterior. |
TracePosterior Methods
| Method | Description |
|---|---|
| Template:Code | Initializes with an optional number of chains. Calls Template:Code to prepare internal state. |
| Template:Code | Abstract. Must be implemented by subclasses. Generator yielding Template:Code or Template:Code. |
| Template:Code | Calls Template:Code to populate execution traces, log_weights, and chain_ids. Builds a Categorical distribution for sampling. |
| Template:Code | Draws a random trace from the collected traces, sampled in proportion to the importance weights. Removes observation nodes from the returned trace. |
| Template:Code | Returns a Template:Code instance for the specified sites. |
| Template:Code | Computes WAIC and the effective number of parameters. Only supports models with a single observation node. |
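The weighted draw listed in the table amounts to a softmax over the log-weights followed by a categorical draw. Below is a Pyro-free sketch with hypothetical helper names. The actual method samples from a torch Categorical built during the run step and also maps indices across multiple chains, which this sketch omits.

```python
import math
import random


def normalized_weights(log_weights):
    """Turn log importance weights into probabilities via a numerically stable softmax."""
    m = max(log_weights)
    exps = [math.exp(w - m) for w in log_weights]
    total = sum(exps)
    return [e / total for e in exps]


def sample_trace_index(log_weights, rng=random):
    """Pick one trace index with probability proportional to exp(log_weight)."""
    probs = normalized_weights(log_weights)
    return rng.choices(range(len(probs)), weights=probs, k=1)[0]
```

Subtracting the maximum log-weight before exponentiating keeps the computation stable when log-weights are large in magnitude, which is common for importance weights.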
I/O Contract
TracePosterior
Inputs (via Template:Code and Template:Code):
- Template:Code -- Arbitrary arguments passed through to the model/guide execution.
Internal State:
- Template:Code -- List of Template:Code objects collected from inference.
- Template:Code -- List of float log-importance-weights, one per trace.
- Template:Code -- List of integer chain identifiers.
- Template:Code -- Integer specifying the number of parallel chains.
Outputs:
- Template:Code returns Template:Code for chaining.
- Template:Code returns a single Template:Code sampled in proportion to the importance weights, with observation nodes removed.
- Template:Code returns a Template:Code object.
- Template:Code returns an Template:Code with keys Template:Code and Template:Code.
EmpiricalMarginal
Inputs:
- Template:Code -- A Template:Code instance.
- Template:Code -- Optional list of site names or a single site name string (defaults to Template:Code).
Outputs:
- An Template:Code distribution whose samples and weights are extracted from the trace posterior.
TracePredictive (Deprecated)
Inputs:
- Template:Code -- A callable Pyro model.
- Template:Code -- A Template:Code instance.
- Template:Code -- Number of predictive samples to draw.
- Template:Code -- Optional list of sites to retain from the posterior.
Outputs:
- Generates Template:Code traces by replaying the model with posterior samples.
Usage Examples
Basic Importance Sampling with TracePosterior
import torch
import pyro
import pyro.distributions as dist
from pyro.infer import Importance, EmpiricalMarginal
def model():
    mu = pyro.sample("mu", dist.Normal(0, 1))
    pyro.sample("obs", dist.Normal(mu, 0.5), obs=torch.tensor(2.0))
def guide():
    loc = pyro.param("loc", torch.tensor(0.0))
    pyro.sample("mu", dist.Normal(loc, 1))
# Importance inherits from TracePosterior
posterior = Importance(model, guide, num_samples=1000)
posterior.run()
# Extract the weighted marginal distribution over the latent site "mu"
marginal = EmpiricalMarginal(posterior, sites="mu")
print("Posterior mean:", marginal.mean)
print("Posterior variance:", marginal.variance)
Using Marginals Container
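from pyro.infer.abstract_infer import Marginals
# Get marginals for specific sites; marginal() returns a Marginals container
marginals = posterior.marginal(sites=["mu"])
assert isinstance(marginals, Marginals)
support = marginals.support()  # maps each site name to its collected sample values
empirical_dists = marginals.empirical  # maps each site name to an EmpiricalMarginal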
Computing WAIC
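# Reuses model and guide from the first example (a single observation site, "obs")
posterior = Importance(model, guide, num_samples=500)
posterior.run()
# pointwise=False returns totals rather than per-data-point values
ic = posterior.information_criterion(pointwise=False)
print("WAIC:", ic["waic"])
print("Effective parameters:", ic["p_waic"])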
Related Pages
- Pyro_ppl_Pyro_Importance -- Importance sampling implementation that inherits from TracePosterior
- Pyro_ppl_Pyro_CSIS -- Compiled Sequential Importance Sampling, which inherits from Importance and therefore from TracePosterior
- Pyro_ppl_Pyro_Infer_Utilities -- Utility functions used throughout inference
- Pyro_ppl_Pyro_SMCFilter -- Sequential Monte Carlo filter for time-series models