Implementation:Pyro ppl Pyro ReweightedWakeSleep

From Leeroopedia


Overview

The rws module (Template:Code) implements Reweighted Wake-Sleep (RWS), a variational inference algorithm that optimizes model parameters (theta) and guide parameters (phi) separately, each with its own objective. Unlike a standard ELBO, which uses a single loss for both model and guide, RWS returns a model loss and a guide loss, built from three components:

  • Wake-theta loss: The IWAE (Importance Weighted Autoencoder) objective, which provides a tighter bound on the log marginal likelihood than the standard ELBO. This is used to train the model parameters.
  • Wake-phi loss: A self-normalized importance-weighted version of the CSIS objective, which trains the guide by reweighting samples to better match the posterior.
  • Sleep-phi loss: The standard CSIS (inference compilation) loss, computed by sampling from the unconditioned model and evaluating the guide's log-probability. This helps the guide learn to cover the posterior, especially for models with stochastic branching.
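The three components above can be sketched numerically. The following is a minimal illustration in plain Python, not Pyro's implementation: the function names (`logsumexp`, `wake_losses`, `sleep_phi_loss`) and the list-of-floats inputs are assumptions made for this example. It only computes loss values; in an autodiff framework the self-normalized weights would additionally be treated as constants when differentiating the wake-phi loss.

```python
import math


def logsumexp(values):
    """Numerically stable log(sum(exp(v) for v in values))."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))


def wake_losses(log_joint, log_q):
    """Return (wake_theta_loss, wake_phi_loss) for K guide samples.

    log_joint[k] is log p(x, z_k) under the model and
    log_q[k] is log q(z_k | x) under the guide (hypothetical inputs).
    """
    if len(log_joint) != len(log_q) or len(log_q) < 2:
        raise ValueError("need matching lists with at least two particles")
    k = len(log_q)
    # Log importance weights: log p(x, z_k) - log q(z_k | x).
    log_w = [lp - lq for lp, lq in zip(log_joint, log_q)]
    log_sum_w = logsumexp(log_w)
    # IWAE bound is the log of the mean weight; the loss is its negative.
    wake_theta_loss = -(log_sum_w - math.log(k))
    # Self-normalized weights sum to one across particles.
    norm_w = [math.exp(lw - log_sum_w) for lw in log_w]
    wake_phi_loss = -sum(w * lq for w, lq in zip(norm_w, log_q))
    return wake_theta_loss, wake_phi_loss


def sleep_phi_loss(log_q_on_model_samples):
    """Negative mean guide log-probability of samples drawn from the model."""
    return -sum(log_q_on_model_samples) / len(log_q_on_model_samples)
```

With equal importance weights the IWAE bound coincides with the per-particle ELBO; with unequal weights it is at least as large as the average of the log weights, which is the sense in which it is tighter.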

The insomnia parameter interpolates between the two guide losses: insomnia=1 uses only the wake-phi estimator, and insomnia=0 recovers pure sleep-phase training (CSIS).
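A minimal sketch of this interpolation, assuming it is a simple convex combination of the two guide losses. The helper name `combine_phi_losses` is hypothetical and is not part of Pyro's API.

```python
def combine_phi_losses(wake_phi_loss, sleep_phi_loss, insomnia=1.0):
    """Blend the two guide losses (assumed linear interpolation).

    insomnia=1.0 returns the wake-phi loss, insomnia=0.0 the sleep-phi loss.
    """
    if not 0.0 <= insomnia <= 1.0:
        raise ValueError("insomnia must be in [0, 1]")
    return insomnia * wake_phi_loss + (1.0 - insomnia) * sleep_phi_loss
```

For example, `combine_phi_losses(2.0, 4.0, insomnia=0.5)` weights both estimators equally.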

RWS is particularly useful for models with stochastic control flow (branching), where standard ELBO-based methods may struggle.

Code Reference

File: Template:Code

Key Classes

Class Parent Description
Template:Code Template:Code Reweighted Wake-Sleep inference with separate model/guide losses.

ReweightedWakeSleep Methods

Method Description
Template:Code Initialize RWS with configuration parameters.
Template:Code Compute wake-theta and phi losses without gradients. Returns a tuple (wake_theta_loss, phi_loss).
Template:Code Compute losses and perform backward on their sum. Returns (wake_theta_loss, phi_loss) as floats.
Template:Code Internal method that computes all three loss components (wake-theta, wake-phi, sleep-phi) and combines the phi losses.
Template:Code Returns paired model/guide traces with detached log-probabilities for importance weight computation.
Template:Code Static method that replays a model trace through the guide, passing observations as keyword arguments.
Template:Code Wraps a function with a plate for vectorized sleep-phase sampling.

I/O Contract

Constructor

Inputs:

  • num_particles -- Number of particles for wake-theta and wake-phi (must be > 1). Default is 2.
  • insomnia -- Interpolation between wake-phi (1.0) and sleep-phi (0.0). Default is 1.0. Must be in [0, 1].
  • model_has_params -- Whether the model has learnable parameters. If False and insomnia=0, skips wake computation. Default is True.
  • Template:Code -- Number of particles for sleep-phi. Defaults to Template:Code.
  • vectorize_particles -- Whether to vectorize particle computation. Default is True.
  • Template:Code -- Max nested plate depth. Default is infinity.
  • Template:Code -- Enumeration warning toggle. Default is True.

loss

Inputs:

Output:

loss_and_grads

Inputs:

Output:

  • Template:Code -- Tuple of (wake_theta_loss, phi_loss). Gradients are populated on all parameters.

Guide Requirements

For sleep-phi computation (insomnia < 1), the guide must accept an observations keyword argument: a dict mapping site names to observed tensor values. This is needed because sleep-phase samples are drawn from the unconditioned model, so the values of the observed sites differ from sample to sample and must be passed to the guide explicitly.
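The hand-off described here can be illustrated with a toy example. This is a sketch under stated assumptions, not Pyro code: the `sites` dict is a simplified stand-in for a model trace, and `collect_observations` and `toy_guide` are hypothetical names invented for this illustration.

```python
def collect_observations(sites):
    """Gather the values of sites that were observed in the original model.

    `sites` maps a site name to {"value": ..., "was_observed": bool};
    this layout is a toy stand-in for a real trace structure.
    """
    return {
        name: site["value"]
        for name, site in sites.items()
        if site["was_observed"]
    }


def toy_guide(observations=None):
    """A guide-like function that conditions on x via the keyword argument."""
    x = observations["x"] if observations else 0.0
    # Center the proposal for z on the value of x it was handed.
    return {"z_loc": x}


# One sleep-phase sample: both z and x were drawn from the model.
sleep_sample = {
    "z": {"value": 0.3, "was_observed": False},
    "x": {"value": 1.7, "was_observed": True},
}
obs = collect_observations(sleep_sample)
proposal = toy_guide(observations=obs)
```

The point of the sketch is that the guide never reads the data from a global or positional argument during the sleep phase; it sees only what arrives through `observations`.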

Usage Examples

Standard RWS

import torch
import pyro
import pyro.distributions as dist
from pyro.infer import SVI, ReweightedWakeSleep
from pyro.optim import Adam

def model(observations=None):
    z = pyro.sample("z", dist.Normal(0, 1))
    obs = observations.get("x") if observations else None
    pyro.sample("x", dist.Normal(z, 0.5), obs=obs)

def guide(observations=None):
    obs = observations.get("x") if observations else torch.tensor(0.0)
    loc = pyro.param("loc", torch.tensor(0.0))
    pyro.sample("z", dist.Normal(loc + obs, 1.0))

rws = ReweightedWakeSleep(
    num_particles=10,
    insomnia=0.5,  # 50% wake-phi, 50% sleep-phi
    vectorize_particles=True
)
svi = SVI(model, guide, Adam({"lr": 0.01}), loss=rws)

data = torch.tensor(2.0)
for step in range(1000):
    theta_loss, phi_loss = svi.step(observations={"x": data})

Pure Sleep Mode (CSIS)

from pyro.infer import ReweightedWakeSleep

rws = ReweightedWakeSleep(
    num_particles=10,
    insomnia=0.0,            # Pure sleep-phi
    model_has_params=False,  # Skip wake computation
)

Pure Wake Mode

from pyro.infer import ReweightedWakeSleep

rws = ReweightedWakeSleep(
    num_particles=10,
    insomnia=1.0,  # Pure wake-phi
)
