
Implementation:Pyro ppl Pyro Importance

From Leeroopedia


Overview

The importance module (pyro.infer.importance) implements importance sampling for posterior inference in Pyro. It provides:

- the Importance class, which performs basic importance sampling by drawing one sample at a time;
- the LogWeightsMixin class, which computes analytics from importance weights;
- the vectorized_importance_weights function, which computes many importance weights at once;
- the psis_diagnostic function, which evaluates guide quality with Pareto Smoothed Importance Sampling (PSIS).

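Importance extends TracePosterior and performs posterior inference as follows. It draws samples from the guide (the proposal distribution) and weights each sample by the ratio of the model's joint probability to the guide's probability. In log space, each weight is the model trace's log-probability minus the guide trace's log-probability. If no guide is provided, Importance proposes from the model's prior.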
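The same rule can be written in symbols, using standard importance-sampling notation (the notation is ours, not Pyro's):

- x is the observed data and z the latent variables;
- p(x, z) is the model's joint density and q(z) the guide's density;
- N samples z_1, ..., z_N are drawn from q.

The first line gives the log-weight of each sample. The second gives the resulting self-normalized estimate of a posterior expectation. The third shows that when the guide is the prior, the log-weight reduces to the log-likelihood.

```latex
\begin{aligned}
\log w_i &= \log p(x, z_i) - \log q(z_i), \qquad z_i \sim q(z),\; i = 1, \dots, N \\
\mathbb{E}_{p(z \mid x)}\left[f(z)\right] &\approx \sum_{i=1}^{N} \frac{w_i}{\sum_{j=1}^{N} w_j}\, f(z_i) \\
q(z) = p(z) \;&\Rightarrow\; \log w_i = \log p(x \mid z_i)
\end{aligned}
```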

LogWeightsMixin adds convenience methods that compute three quantities from a list of log-weights:

- an estimate of the log normalizing constant (the log marginal likelihood);
- the normalized importance weights;
- the Effective Sample Size (ESS).
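The three quantities can be sketched in plain Python. This is an illustrative re-derivation, not Pyro's source:

- log_normalizer, normalized_weights and effective_sample_size are hypothetical stand-ins for the mixin's get_log_normalizer, get_normalized_weights and get_ESS.
- The ESS uses the common 1 / sum(w_i^2) formula over normalized weights. We assume this matches the library's definition.

```python
import math


def _logsumexp(values):
    """Numerically stable log(sum(exp(v) for v in values))."""
    top = max(values)
    return top + math.log(sum(math.exp(v - top) for v in values))


def log_normalizer(log_weights):
    """Log of the mean unnormalized weight: an estimate of log Z."""
    return _logsumexp(log_weights) - math.log(len(log_weights))


def normalized_weights(log_weights, log_scale=False):
    """Weights rescaled to sum to 1, or their logs if log_scale=True."""
    total = _logsumexp(log_weights)
    log_norm = [lw - total for lw in log_weights]
    return log_norm if log_scale else [math.exp(v) for v in log_norm]


def effective_sample_size(log_weights):
    """ESS = 1 / sum of squared normalized weights, computed in log space."""
    log_norm = normalized_weights(log_weights, log_scale=True)
    return math.exp(-_logsumexp([2.0 * v for v in log_norm]))
```

For example, the log-weights [log 1, log 3] give normalized weights [0.25, 0.75], a log normalizer of log 2 and an ESS of 1.6. N equal weights give an ESS of exactly N.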

vectorized_importance_weights is a faster alternative for models with static structure. It wraps the model and guide in an outer Pyro plate over particles, so the importance weights of all particles are computed in a single forward pass rather than in a Python loop.

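psis_diagnostic computes the Pareto tail index k of the importance weights, which assesses how well the guide approximates the posterior:

- k < 0.5 indicates a good approximation;
- 0.5 <= k <= 0.7 indicates a suboptimal approximation that may still be useful in practice;
- k > 0.7 indicates a poor approximation that should be used with caution.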
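The following is a rough sketch of how such a k can be estimated from a list of log-weights:

1. Sort and shift the log-weights.
2. Keep the largest M = ceil(min(0.2 S, 3 sqrt(S))) of the S weights as the tail.
3. Fit a generalized Pareto distribution to their exceedances with the Zhang and Stephens (2009) estimator, plus the weak prior toward 0.5 used by ArviZ.

This is pure Python written for illustration. As we understand it, Pyro's psis_diagnostic draws its log-weights through vectorized_importance_weights and uses its own fitting routine, so details may differ. The names _gpd_shape, psis_k, verdict and pareto_log_weights are hypothetical.

```python
import math
import random


def _gpd_shape(x):
    """Estimate the generalized Pareto shape of positive exceedances x
    (Zhang-Stephens profile-posterior estimator, ArviZ-style weak prior)."""
    x = sorted(x)
    n = len(x)
    m = 30 + int(math.sqrt(n))
    quartile = x[int(n / 4 + 0.5) - 1]
    # Grid of candidate theta values; theta = -k / sigma in this parameterization.
    thetas = [1.0 / x[-1] + (1.0 - math.sqrt(m / (j - 0.5))) / (3.0 * quartile)
              for j in range(1, m + 1)]
    ks = [sum(math.log1p(-t * xi) for xi in x) / n for t in thetas]
    # Profile log-likelihood of each candidate theta.
    loglik = [n * (math.log(-t / k) - k - 1.0) for t, k in zip(thetas, ks)]
    top = max(loglik)
    post = [math.exp(ll - top) for ll in loglik]  # unnormalized posterior weights
    theta_hat = sum(t * p for t, p in zip(thetas, post)) / sum(post)
    k_hat = sum(math.log1p(-theta_hat * xi) for xi in x) / n
    # Weak shrinkage toward 0.5 (prior weight 10), as ArviZ does.
    return (n * k_hat + 10 * 0.5) / (n + 10)


def psis_k(log_weights):
    """Pareto tail index k of a sample of importance log-weights."""
    lw = sorted(log_weights)
    top = lw[-1]
    lw = [v - top for v in lw]  # largest log-weight becomes 0
    s = len(lw)
    tail_len = int(math.ceil(min(0.2 * s, 3.0 * math.sqrt(s))))
    cutoff = max(math.log(1e-15), lw[-tail_len - 1])
    tail = [math.exp(v) - math.exp(cutoff) for v in lw if v > cutoff]
    if len(tail) < 10:
        return float("inf")  # too few tail samples for a reliable fit
    return _gpd_shape(tail)


def verdict(k):
    """Map k onto the good / acceptable / poor thresholds."""
    if k < 0.5:
        return "good"
    if k <= 0.7:
        return "acceptable"
    return "poor"


def pareto_log_weights(xi, size, seed=0):
    """Synthetic log-weights: w = U**(-xi) has a Pareto tail of shape xi."""
    rng = random.Random(seed)
    return [-xi * math.log(1.0 - rng.random()) for _ in range(size)]


print(verdict(psis_k(pareto_log_weights(0.1, 40000))))
print(verdict(psis_k(pareto_log_weights(1.5, 40000))))
```

Synthetic weights w = U^(-xi) have a true tail shape of xi, which makes them a quick sanity check. A small xi should come out below 0.5 and a large xi above 0.7.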

Code Reference

File: pyro/infer/importance.py

Key Classes

Class | Parent(s) | Description
Importance | TracePosterior, LogWeightsMixin | Importance sampling that uses the guide as the proposal distribution.
LogWeightsMixin | -- | Mixin providing analytics computed from a log_weights attribute.

Importance Methods

Method | Description
__init__(model, guide, num_samples) | Initialize with a model, an optional guide (defaults to proposing from the prior), and the number of samples to draw (default 10).
_traces(*args, **kwargs) | Generator yielding a (model_trace, log_weight) pair for each sample.
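Stripped of Pyro's tracing machinery, _traces follows a simple pattern: a generator of (sample, log-weight) pairs. The sketch below reproduces that pattern in plain Python for the Normal model used in the usage examples. It proposes from the prior, as Importance does when no guide is given. normal_logpdf and weighted_samples are hypothetical helpers; real Pyro code obtains the log-probabilities from model and guide traces rather than from closed-form densities.

```python
import math
import random


def normal_logpdf(x, loc, scale):
    """Log density of Normal(loc, scale) at x."""
    return (-0.5 * ((x - loc) / scale) ** 2
            - math.log(scale) - 0.5 * math.log(2.0 * math.pi))


def weighted_samples(num_samples, x_obs=2.0, seed=0):
    """Yield (z, log_weight) pairs, proposing z from the prior Normal(0, 1)."""
    rng = random.Random(seed)
    for _ in range(num_samples):
        z = rng.gauss(0.0, 1.0)  # proposal draw, q(z) = p(z)
        # log p(x, z) = log prior + log likelihood of the observed x
        log_p = normal_logpdf(z, 0.0, 1.0) + normal_logpdf(x_obs, z, 0.5)
        log_q = normal_logpdf(z, 0.0, 1.0)  # log q(z)
        yield z, log_p - log_q


pairs = list(weighted_samples(20000))
top = max(lw for _, lw in pairs)
weights = [math.exp(lw - top) for _, lw in pairs]  # stable unnormalized weights
posterior_mean = sum(w * z for w, (z, _) in zip(weights, pairs)) / sum(weights)
print(posterior_mean)  # should be close to the exact posterior mean, 1.6
```

For this conjugate model, the exact posterior is Normal(1.6, sqrt(0.2)), which makes it a convenient check on the self-normalized estimate.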

LogWeightsMixin Methods

Method | Description
get_log_normalizer() | Returns the log of the mean unnormalized importance weight, an estimate of log Z.
get_normalized_weights(log_scale=False) | Returns the normalized importance weights; if log_scale=True, returns their logarithms.
get_ESS() | Returns the Effective Sample Size computed from the normalized weights.

Standalone Functions

Function | Description
vectorized_importance_weights(model, guide, *args, **kwargs) | Vectorized computation of importance weights for static-structure models. Keyword args: num_samples, max_plate_nesting, normalized.
psis_diagnostic(model, guide, *args, **kwargs) | Computes the Pareto tail index k for a model/guide pair. Keyword args: num_particles, max_simultaneous_particles, max_plate_nesting.

I/O Contract

Importance Constructor

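Inputs:

  • model -- Probabilistic model defined as a Python function.
  • guide -- Optional guide used as the proposal distribution; if omitted, samples are proposed from the model's prior.
  • num_samples -- Number of samples to draw (default 10).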

Importance._traces

Inputs:

  • *args, **kwargs -- Arguments passed to the model and guide each time a sample is drawn.

Yields:

  • (model_trace, log_weight) -- Each model trace paired with its log importance weight (model log-probability minus guide log-probability).

vectorized_importance_weights

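Required Keyword Args:

  • max_plate_nesting -- Bound on the maximum number of nested pyro.plate contexts in the model/guide.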

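Optional Keyword Args:

  • num_samples -- Number of particles to draw.
  • normalized -- If True, the log-weights are normalized so that their exponentials sum to 1.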

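Output:

  • (log_weights, model_trace, guide_trace) -- A tensor of per-particle log importance weights, together with the vectorized model and guide traces.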

psis_diagnostic

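Optional Keyword Args:

  • num_particles -- Total number of particles used to form the diagnostic.
  • max_simultaneous_particles -- Maximum number of particles processed in a single vectorized pass.
  • max_plate_nesting -- Bound on the maximum number of nested pyro.plate contexts in the model/guide.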

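Output:

  • k -- The Pareto tail index as a float; lower is better, with the thresholds given in the Overview.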

Usage Examples

Basic Importance Sampling

import torch
import pyro
import pyro.distributions as dist
from pyro.infer import Importance, EmpiricalMarginal

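# Model: prior z ~ Normal(0, 1); observed x = 2.0 ~ Normal(z, 0.5)
def model():
    z = pyro.sample("z", dist.Normal(0, 1))
    pyro.sample("x", dist.Normal(z, 0.5), obs=torch.tensor(2.0))

# Guide (proposal): Normal(loc, 1). loc is never optimized here,
# so this proposal coincides with the prior.
def guide():
    loc = pyro.param("loc", torch.tensor(0.0))
    pyro.sample("z", dist.Normal(loc, 1.0))

# Draw 1000 guide samples and weight each by p(x, z) / q(z)
posterior = Importance(model, guide, num_samples=1000)
posterior.run()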

marginal = EmpiricalMarginal(posterior, "z")
print("Mean:", marginal.mean)
print("ESS:", posterior.get_ESS())
print("Log normalizer:", posterior.get_log_normalizer())

Prior-Based Importance Sampling

# With no guide, Importance proposes samples from the model's prior
posterior = Importance(model, num_samples=5000)
posterior.run()

Vectorized Importance Weights

from pyro.infer.importance import vectorized_importance_weights

log_weights, model_trace, guide_trace = vectorized_importance_weights(
    model, guide,
    num_samples=1000,
    max_plate_nesting=4,
    normalized=False
)
print("Log weights shape:", log_weights.shape)  # torch.Size([1000])

PSIS Diagnostic

from pyro.infer.importance import psis_diagnostic

k = psis_diagnostic(model, guide,
                    num_particles=5000,
                    max_plate_nesting=4)
if k < 0.5:
    print(f"Good guide fit (k={k:.3f})")
elif k <= 0.7:
    print(f"Acceptable guide fit (k={k:.3f})")
else:
    print(f"Poor guide fit (k={k:.3f})")
