Implementation:Pyro ppl Pyro ReweightedWakeSleep
Overview
The rws module (Template:Code) implements Reweighted Wake-Sleep (RWS), a variational inference algorithm that optimizes the model parameters (theta) and the guide parameters (phi) with separate objectives. The standard ELBO uses a single loss for both model and guide. RWS instead returns two losses, one for theta and one for phi, built from three components:
- Wake-theta loss: the negative IWAE (Importance Weighted Autoencoder) bound. This is a lower bound on the log marginal likelihood that is at least as tight as the standard ELBO for the same guide. It trains the model parameters.
- Wake-phi loss: a self-normalized importance-weighted version of the CSIS objective. It trains the guide on its own samples, reweighted toward the posterior.
- Sleep-phi loss: the standard CSIS (inference compilation) loss. It is computed by sampling latents and observations from the unconditioned model, then evaluating the guide's log-probability of those latents given the sampled observations. This helps the guide cover the posterior, especially for models with stochastic branching.
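For reference, the three components can be written out as below. This is a sketch in standard RWS notation, not a transcription of the Pyro source. Here x is the observed data and z_k ~ q_phi(z | x) are K guide samples with log weights log w_k. The \bar w_k are the self-normalized weights, treated as constants when differentiating. The pairs (z_s, x_s) are S joint samples from the unconditioned model, and \iota denotes the insomnia weight.

```latex
\begin{aligned}
\log w_k &= \log p_\theta(x, z_k) - \log q_\phi(z_k \mid x),
  \qquad z_k \sim q_\phi(\cdot \mid x),\quad k = 1, \dots, K \\
\mathcal{L}_{\text{wake-}\theta} &= -\Big( \log \sum_{k=1}^{K} w_k - \log K \Big) \\
\mathcal{L}_{\text{wake-}\phi} &= -\sum_{k=1}^{K} \bar{w}_k \log q_\phi(z_k \mid x),
  \qquad \bar{w}_k = \frac{w_k}{\sum_{j=1}^{K} w_j} \\
\mathcal{L}_{\text{sleep-}\phi} &= -\frac{1}{S} \sum_{s=1}^{S} \log q_\phi(z_s \mid x_s),
  \qquad (z_s, x_s) \sim p_\theta(z, x) \\
\mathcal{L}_{\phi} &= \iota\, \mathcal{L}_{\text{wake-}\phi} + (1 - \iota)\, \mathcal{L}_{\text{sleep-}\phi}
\end{aligned}
```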
The insomnia parameter sets the guide loss to insomnia * wake_phi + (1 - insomnia) * sleep_phi. This interpolates between wake-phi (insomnia=1) and sleep-phi (insomnia=0). Setting insomnia=0 recovers pure sleep-phase training (CSIS), and setting insomnia=1 uses only the wake-phi estimator.
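The same computation can be sketched in plain Python on precomputed log-densities. The helper `rws_objectives` below is hypothetical, gradient-free, and not a Pyro function. It takes per-particle values of log p(x, z_k) and log q(z_k | x), plus guide log-densities evaluated on sleep samples. It returns the wake-theta loss and the insomnia-weighted phi loss.

```python
import math
from typing import List, Tuple


def logsumexp(xs: List[float]) -> float:
    """Numerically stable log(sum(exp(x)))."""
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))


def rws_objectives(
    log_p: List[float],        # log p(x, z_k) for K wake particles
    log_q: List[float],        # log q(z_k | x) for the same K particles
    log_q_sleep: List[float],  # log q(z_s | x_s) on S samples from the model
    insomnia: float,
) -> Tuple[float, float]:
    """Hypothetical sketch: return (wake_theta_loss, phi_loss) as floats."""
    if len(log_p) < 2:
        raise ValueError("self-normalization needs more than one particle")
    if not 0.0 <= insomnia <= 1.0:
        raise ValueError("insomnia must lie in [0, 1]")

    K = len(log_p)
    log_w = [lp - lq for lp, lq in zip(log_p, log_q)]
    log_sum_w = logsumexp(log_w)

    # Wake-theta: negative IWAE bound.
    wake_theta = -(log_sum_w - math.log(K))

    # Wake-phi: self-normalized weights times the guide log-density.
    w_bar = [math.exp(lw - log_sum_w) for lw in log_w]
    wake_phi = -sum(w * lq for w, lq in zip(w_bar, log_q))

    # Guide loss: insomnia * wake-phi + (1 - insomnia) * sleep-phi.
    phi = insomnia * wake_phi if insomnia > 0.0 else 0.0
    if insomnia < 1.0:
        # Sleep-phi: average negative guide log-density on model samples.
        sleep_phi = -sum(log_q_sleep) / len(log_q_sleep)
        phi += (1.0 - insomnia) * sleep_phi
    return wake_theta, phi
```

In an autodiff framework the same expressions would be built from tensors. The normalized weights would be detached so that the wake-phi gradient flows only through log q.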
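RWS is particularly useful for models with stochastic control flow (branching), where the set of sampled sites depends on discrete choices. Reparameterized ELBO gradients cannot pass through such choices, and score-function estimators can have high variance. The RWS guide losses, by contrast, only need gradients of the guide's log-density.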
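The plain-Python sketch below makes the branching case concrete with a hypothetical toy model and guide (not Pyro code). A coin c decides whether a latent z exists at all. If c = 1, then z ~ Normal(0, 1) and x ~ Normal(z, 1). If c = 0, then x ~ Normal(0, 0.5) and there is no z.

Each guide particle follows one branch. Its importance weight needs only log p and log q of the sites that branch actually produced, and that is all the wake-theta (IWAE) estimate uses. Because the exact log evidence can be computed by summing over both branches, the estimate can be checked against it.

```python
import math
import random


def normal_logpdf(v, loc, scale):
    return -0.5 * ((v - loc) / scale) ** 2 - math.log(scale) - 0.5 * math.log(2 * math.pi)


def logsumexp(xs):
    m = max(xs)
    return m + math.log(sum(math.exp(v - m) for v in xs))


def sample_guide_and_score(x, rng):
    """Run the hypothetical guide, then score the model on the same choices.

    Returns (log p(x, latents), log q(latents | x)). The latent sites depend
    on the branch taken: z only exists when c == 1.
    """
    c = 1 if rng.random() < 0.5 else 0          # q(c = 1) = 0.5
    log_q = math.log(0.5)
    log_p = math.log(0.5)                        # p(c = 1) = p(c = 0) = 0.5
    if c == 1:
        z = rng.gauss(0.5 * x, 1.0)              # q(z | x, c = 1) = Normal(x / 2, 1)
        log_q += normal_logpdf(z, 0.5 * x, 1.0)
        log_p += normal_logpdf(z, 0.0, 1.0) + normal_logpdf(x, z, 1.0)
    else:
        log_p += normal_logpdf(x, 0.0, 0.5)      # branch with no z site
    return log_p, log_q


def iwae_log_evidence(x, num_particles, seed=0):
    """IWAE estimate: log of the mean importance weight over the particles."""
    rng = random.Random(seed)
    log_w = []
    for _ in range(num_particles):
        lp, lq = sample_guide_and_score(x, rng)
        log_w.append(lp - lq)
    return logsumexp(log_w) - math.log(num_particles)


def exact_log_evidence(x):
    # c = 1: integrating z out gives x ~ Normal(0, sqrt(2)); c = 0: x ~ Normal(0, 0.5)
    return logsumexp([
        math.log(0.5) + normal_logpdf(x, 0.0, math.sqrt(2.0)),
        math.log(0.5) + normal_logpdf(x, 0.0, 0.5),
    ])
```

With many particles, `iwae_log_evidence(1.0, 20000)` lands close to `exact_log_evidence(1.0)`, even though different particles visit different sets of latent sites.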
Code Reference
File: Template:Code
Key Classes
| Class | Parent | Description |
|---|---|---|
| `ReweightedWakeSleep` | Template:Code | Reweighted Wake-Sleep inference with separate model and guide losses. |
ReweightedWakeSleep Methods
| Method | Description |
|---|---|
| `__init__` | Initialize RWS with configuration parameters. |
| `loss` | Compute the wake-theta and phi losses without tracking gradients. Returns a tuple (wake_theta_loss, phi_loss). |
| `loss_and_grads` | Compute both losses and call backward on their sum. Returns (wake_theta_loss, phi_loss) as floats. |
| Template:Code | Internal method that computes all three loss components (wake-theta, wake-phi, sleep-phi) and combines the phi losses. |
| Template:Code | Returns paired model/guide traces with detached log-probabilities for importance weight computation. |
| Template:Code | Static method that replays a model trace through the guide, passing observations as keyword arguments. |
| Template:Code | Wraps a function with a plate for vectorized sleep-phase sampling. |
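The observation-passing step performed by the static replay method can be sketched without Pyro. This sketch rests on several assumptions. The trace is a plain dict, and a `was_observed` flag marks sites that were observed before the model was unconditioned; this is modeled loosely on Pyro's uncondition handler. The guide is written as an explicit log-density function instead of being replayed. All names below are hypothetical.

```python
import math
from typing import Any, Callable, Dict, Tuple

# Hypothetical flat trace: site name -> {"value": ..., "was_observed": bool}
Trace = Dict[str, Dict[str, Any]]


def split_sleep_trace(model_trace: Trace) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Split a sleep-phase model trace into (latents, observations)."""
    latents, observations = {}, {}
    for name, site in model_trace.items():
        target = observations if site.get("was_observed", False) else latents
        target[name] = site["value"]
    return latents, observations


def replay_guide_log_prob(guide_log_prob: Callable[..., float], model_trace: Trace) -> float:
    """Score the model's latent values under the guide, passing observations by keyword."""
    latents, observations = split_sleep_trace(model_trace)
    return guide_log_prob(latents, observations=observations)


def example_guide_log_prob(latents, observations=None):
    # Hypothetical guide q(z | x) = Normal(x, 1), written as an explicit log-density.
    x = observations["x"]
    z = latents["z"]
    return -0.5 * (z - x) ** 2 - 0.5 * math.log(2 * math.pi)
```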
I/O Contract
Constructor
Inputs:
- `num_particles` -- Number of particles for wake-theta and wake-phi. Must be greater than 1, because self-normalized importance weights need more than one sample. Default is 2.
- `insomnia` -- Weight on wake-phi in the guide loss: 1.0 gives pure wake-phi, 0.0 gives pure sleep-phi. Must be in [0, 1]. Default is 1.0.
- `model_has_params` -- Whether the model has learnable parameters. If False and insomnia=0, the wake computation is skipped. Default is True.
- Template:Code -- Number of particles for sleep-phi. Defaults to Template:Code.
- `vectorize_particles` -- Whether to vectorize the particle computation. Default is True.
- Template:Code -- Max nested plate depth. Default is infinity.
- Template:Code -- Enumeration warning toggle. Default is True.
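The constructor rules above can be sketched as a small validated settings object. The sketch covers the particle count, the insomnia range, and the conditions under which wake or sleep work is needed. It is a hypothetical stand-in, not Pyro's class: `RWSSettings`, `runs_wake`, and `runs_sleep` are illustrative names.

```python
from dataclasses import dataclass


@dataclass
class RWSSettings:
    """Hypothetical mirror of the documented constructor rules (not Pyro's class)."""
    num_particles: int = 2
    insomnia: float = 1.0
    model_has_params: bool = True

    def __post_init__(self):
        if self.num_particles <= 1:
            raise ValueError("num_particles must be > 1 for self-normalized weights")
        if not 0.0 <= self.insomnia <= 1.0:
            raise ValueError("insomnia must be in [0, 1]")

    @property
    def runs_wake(self) -> bool:
        # Wake particles are skipped only for a parameter-free model in pure sleep mode.
        return self.model_has_params or self.insomnia > 0.0

    @property
    def runs_sleep(self) -> bool:
        # Sleep-phi contributes whenever insomnia < 1.
        return self.insomnia < 1.0
```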
loss
Inputs:
- `model` -- A Pyro model callable.
- `guide` -- A Pyro guide callable. It must accept an `observations` keyword argument when sleep-phi is used (insomnia < 1).
- Template:Code -- Passed to model and guide.
Output:
- Template:Code -- Tuple of (wake_theta_loss, phi_loss).
loss_and_grads
Inputs:
- Same as `loss`.
Output:
- Template:Code -- Tuple of (wake_theta_loss, phi_loss) as Python floats. Backward is called on their sum before returning, so gradients are populated on all parameters that contribute to either loss.
Guide Requirements
For sleep-phi computation (insomnia < 1), the guide must accept an `observations` keyword argument. This is a dict mapping observed site names to tensor values. It is needed because sleep-phase samples are drawn from the unconditioned model: the observation values change with every sleep sample, so they must be passed to the guide explicitly.
Usage Examples
Standard RWS
```python
import torch
import pyro
import pyro.distributions as dist
from pyro.infer import SVI, ReweightedWakeSleep
from pyro.optim import Adam


def model(observations=None):
    z = pyro.sample("z", dist.Normal(0.0, 1.0))
    # Condition on "x" only when a value is supplied.
    obs = observations.get("x") if observations else None
    pyro.sample("x", dist.Normal(z, 0.5), obs=obs)


def guide(observations=None):
    # Amortized guide q(z | x) = Normal(loc + x, 1); the observations kwarg
    # is required for the sleep-phi term (insomnia < 1).
    obs = observations.get("x") if observations else torch.tensor(0.0)
    loc = pyro.param("loc", torch.tensor(0.0))
    pyro.sample("z", dist.Normal(loc + obs, 1.0))


pyro.clear_param_store()
rws = ReweightedWakeSleep(
    num_particles=10,
    insomnia=0.5,  # guide loss: 50% wake-phi, 50% sleep-phi
    vectorize_particles=True,
)
svi = SVI(model, guide, Adam({"lr": 0.01}), loss=rws)

data = torch.tensor(2.0)
for step in range(1000):
    # With an RWS loss, each step yields a (wake_theta_loss, phi_loss) pair.
    theta_loss, phi_loss = svi.step(observations={"x": data})
```
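The example model is Normal-Normal conjugate (z ~ Normal(0, 1), x ~ Normal(z, 0.5)), so its exact posterior is available in closed form. That gives a reference point when inspecting the trained guide. The helper below is a hypothetical plain-Python function, not part of Pyro. For data x = 2.0 it gives Normal(1.6, sqrt(0.2) ≈ 0.447).

```python
import math


def normal_normal_posterior(x, prior_loc=0.0, prior_scale=1.0, obs_scale=0.5):
    """Exact p(z | x) for z ~ Normal(prior_loc, prior_scale), x ~ Normal(z, obs_scale)."""
    prior_prec = 1.0 / prior_scale ** 2
    obs_prec = 1.0 / obs_scale ** 2
    post_var = 1.0 / (prior_prec + obs_prec)
    post_loc = post_var * (prior_prec * prior_loc + obs_prec * x)
    return post_loc, math.sqrt(post_var)
```

The example guide has a fixed scale of 1.0 and learns only the offset `loc`. It can therefore at best match the posterior mean, not its spread.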
Pure Sleep Mode (CSIS)
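```python
# Reuses the model and guide above; the guide must accept `observations`.
rws = ReweightedWakeSleep(
    num_particles=10,
    insomnia=0.0,            # pure sleep-phi (CSIS)
    model_has_params=False,  # no model parameters, so the wake computation is skipped
)
```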
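In pure sleep mode the guide is fit only to samples drawn from the unconditioned model, so the observed data never enters its loss. Consider the guide from the standard example, q(z | x) = Normal(loc + x, 1). Its sleep-phi loss is E[(z - x - loc)^2] / 2 plus a constant, which is minimized at loc = E[z - x].

The plain-Python Monte Carlo sketch below estimates this value. It is a hypothetical helper, not Pyro's estimator. For the example model the result is 0, since z and x both have prior mean 0.

```python
import random


def sleep_phi_optimal_offset(num_samples=100_000, seed=0):
    """Monte Carlo estimate of the loc minimizing sleep-phi for q(z | x) = Normal(loc + x, 1).

    -log q is (z - x - loc)^2 / 2 plus a constant, so the minimizer is the
    mean of (z - x) over joint samples from the model.
    """
    rng = random.Random(seed)
    total = 0.0
    for _ in range(num_samples):
        z = rng.gauss(0.0, 1.0)   # z ~ p(z), as in the example model
        x = rng.gauss(z, 0.5)     # x ~ p(x | z), unconditioned (sleep phase)
        total += z - x
    return total / num_samples
```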
Pure Wake Mode
```python
rws = ReweightedWakeSleep(
    num_particles=10,
    insomnia=1.0,  # pure wake-phi; no sleep-phase samples are drawn
)
```
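In pure wake mode the guide is instead pulled toward a self-normalized importance-sampled (SNIS) estimate of the posterior for the observed data. For the guide q(z | x) = Normal(loc + x, 1), set the derivative of -sum_k w_k log q(z_k | x) with respect to loc to zero, holding the weights fixed. This gives loc = sum_k w_k z_k - x.

The plain-Python sketch below evaluates that expression with many particles. It is a hypothetical helper, not Pyro code. At x = 2.0 the SNIS mean approaches the exact posterior mean 1.6, so the wake-phi fixed point is about loc = -0.4. Sleep-phi training on unconditioned model samples would instead favor loc = E[z - x] = 0 for this guide.

```python
import math
import random


def normal_logpdf(v, loc, scale):
    return -0.5 * ((v - loc) / scale) ** 2 - math.log(scale) - 0.5 * math.log(2 * math.pi)


def wake_phi_target_offset(x, loc, num_particles=50_000, seed=0):
    """Offset that zeroes the wake-phi gradient for q(z | x) = Normal(loc + x, 1).

    With the self-normalized weights held fixed, d/d(loc) of
    -sum_k w_k * log q(z_k | x) vanishes at sum_k w_k * z_k - x.
    """
    rng = random.Random(seed)
    zs = [rng.gauss(loc + x, 1.0) for _ in range(num_particles)]
    # log weights: log p(z) + log p(x | z) - log q(z | x), as in the example model
    log_w = [
        normal_logpdf(z, 0.0, 1.0) + normal_logpdf(x, z, 0.5) - normal_logpdf(z, loc + x, 1.0)
        for z in zs
    ]
    m = max(log_w)
    w = [math.exp(lw - m) for lw in log_w]
    snis_mean = sum(wi * z for wi, z in zip(w, zs)) / sum(w)
    return snis_mean - x
```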
Related Pages
- Pyro_ppl_Pyro_CSIS -- Compiled Sequential Importance Sampling (equivalent to RWS with insomnia=0)
- Pyro_ppl_Pyro_RenyiELBO -- Rényi ELBO; with alpha=0 it gives the IWAE objective used for wake-theta
- Pyro_ppl_Pyro_Importance -- Basic importance sampling
- Pyro_ppl_Pyro_Infer_Utilities -- Utility functions used during trace computation