
Implementation:Pyro ppl Pyro SMCFilter

From Leeroopedia


Overview

The smcfilter module (pyro.infer.smcfilter) implements Sequential Monte Carlo (SMC) filtering, also known as particle filtering, for online Bayesian inference in time-series and state-space models. The SMCFilter class maintains a weighted set of particles that it propagates from one time step to the next, resampling them when the importance weights become too uneven (particle degeneracy).
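As a rough, library-free illustration of the technique described above, the sketch below runs a bootstrap particle filter on a hypothetical 1-D Gaussian random walk. It is not Pyro's code: bootstrap_filter and its parameters are made up for this example, and it resamples multinomially at every step (via random.Random.choices) instead of using the ESS-triggered systematic resampling that SMCFilter performs.

```python
import math
import random


def bootstrap_filter(observations, num_particles=500, process_sd=0.1,
                     obs_sd=0.5, seed=0):
    """Bootstrap particle filter for a hypothetical 1-D Gaussian random walk:
    z_0 ~ Normal(0, 1), z_t ~ Normal(z_{t-1}, process_sd), y_t ~ Normal(z_t, obs_sd).
    Returns the filtered posterior mean of z_t after each observation."""
    rng = random.Random(seed)
    # Initialize: draw particles from the prior over z_0.
    particles = [rng.gauss(0.0, 1.0) for _ in range(num_particles)]
    means = []
    for y in observations:
        # Propagate: the proposal is the transition prior (bootstrap filter),
        # so p(z_t | z_{t-1}) / q(z_t) cancels and only the likelihood remains.
        particles = [z + rng.gauss(0.0, process_sd) for z in particles]
        # Weight: Gaussian log-likelihood up to an additive constant.
        log_w = [-0.5 * ((y - z) / obs_sd) ** 2 for z in particles]
        top = max(log_w)
        w = [math.exp(lw - top) for lw in log_w]  # subtract max for stability
        total = sum(w)
        probs = [wi / total for wi in w]
        # Filtered estimate: weighted mean before resampling.
        means.append(sum(p * z for p, z in zip(probs, particles)))
        # Resample: multinomial, every step (simplest possible scheme).
        particles = rng.choices(particles, weights=probs, k=num_particles)
    return means
```

Feeding it a constant observation sequence such as bootstrap_filter([2.0] * 50) should give filtered means that move from the prior mean 0 toward 2.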

The design uses a two-phase protocol:

  1. init -- Initialize the particle set by calling the init methods of the model and guide.
  2. step -- Advance the filter by one time step by calling the step methods of the model and guide.

Both model and guide must be objects with init(state, ...) and step(state, ...) methods, where state is passed as the first argument. The SMCState (a dict-like container) stores all tensors that depend on sampled variables. The model can read and write the state, but the guide can only read it (enforced by a lock that is active while the guide runs).
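The two-phase protocol can be illustrated without Pyro by a small driver that calls init once and then step once per observation. This is a hypothetical sketch: CounterModel, ReadOnlyGuide and run_protocol are made-up names, no particles or weights are involved, and the guide is assumed to run before the model at each phase. Pyro enforces read-only access with a lock on SMCState (raising RuntimeError); this sketch substitutes types.MappingProxyType, a read-only dict view from the standard library that raises TypeError on writes.

```python
from types import MappingProxyType


class CounterModel:
    """Hypothetical model: writes to state in both init and step."""

    def init(self, state):
        state["z"] = 0.0

    def step(self, state, y=None):
        state["z"] = state["z"] + (y if y is not None else 0.0)


class ReadOnlyGuide:
    """Hypothetical guide: may read state but never writes to it."""

    def init(self, state):
        pass

    def step(self, state, y=None):
        _ = state.get("z")  # reading is allowed


def run_protocol(model, guide, observations):
    """Call init once, then step once per observation (guide before model)."""
    state = {}
    guide.init(MappingProxyType(state))  # guide only sees a read-only view
    model.init(state)                    # model gets the writable dict
    for y in observations:
        guide.step(MappingProxyType(state), y=y)
        model.step(state, y=y)
    return state
```

Passing an object that writes to state in the guide position (for example a second CounterModel) makes run_protocol fail on the first write, which mirrors the guide-write restriction.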

Key features:

  • Adaptive resampling based on an Effective Sample Size (ESS) threshold.
  • Systematic resampling, which preserves particle diversity better than multinomial resampling.
  • Inference complexity is O(len(state) * num_time_steps), so keeping the state at a fixed size avoids quadratic complexity in Markov models.
  • An SMCFailed exception is raised when every particle has zero probability.
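The ESS test behind adaptive resampling follows directly from the normalized weights p_i: ESS = 1 / Σ p_i². A minimal sketch, computing it from log-weights; effective_sample_size and should_resample are hypothetical helper names, and the 0.5 default threshold is this sketch's assumption.

```python
import math


def effective_sample_size(log_weights):
    """ESS = 1 / sum_i p_i**2 for normalized weights p (computed from log-weights)."""
    top = max(log_weights)
    if not top > -math.inf:
        raise ValueError("every particle has zero weight")
    # Subtract the max before exponentiating to avoid overflow/underflow.
    w = [math.exp(lw - top) for lw in log_weights]
    total = sum(w)
    return 1.0 / sum((wi / total) ** 2 for wi in w)


def should_resample(log_weights, ess_threshold=0.5):
    """Adaptive rule: resample only when ESS < ess_threshold * num_particles."""
    return effective_sample_size(log_weights) < ess_threshold * len(log_weights)
```

With 100 equally weighted particles the ESS is 100 and no resampling is triggered; when a single particle carries all the weight the ESS drops to 1.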

Code Reference

File: pyro/infer/smcfilter.py

Key Classes

  • SMCFilter (no parent class): Top-level interface for sequential Monte Carlo filtering.
  • SMCState (subclass of dict): Dictionary-like container for vectorized particle tensors with resampling support.
  • SMCFailed (subclass of ValueError): Exception raised when no hypothesis has nonzero probability.

SMCFilter Methods

  • __init__(model, guide, num_particles, max_plate_nesting, *, ess_threshold=0.5): Initialize with the model, guide, particle count, plate nesting bound, and ESS threshold.
  • init(*args, **kwargs): Initialize the particle set via model.init and guide.init. Performs the initial weight update and resampling.
  • step(*args, **kwargs): Advance one time step via model.step and guide.step. Updates the weights and resamples if needed.
  • get_empirical(): Returns a dict mapping state keys to Empirical distributions weighted by the particle log-weights.
  • _update_weights(model_trace, guide_trace): Updates the particle weights as w_t = w_{t-1} * p(y_t | z_t) * p(z_t | z_{t-1}) / q(z_t), computed in log space.
  • _maybe_importance_resample(): Resamples if ESS < ess_threshold * num_particles.
  • _importance_resample(probs): Performs systematic resampling using the given probability weights.
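The incremental update w_t = w_{t-1} · p(y_t | z_t) · p(z_t | z_{t-1}) / q(z_t) is usually carried out in log space so that products become sums. The sketch below assumes the per-particle log-densities are already available as plain lists (Pyro instead reads them from model and guide traces); update_log_weights and AllParticlesFailed (a stand-in for SMCFailed) are hypothetical names.

```python
import math


class AllParticlesFailed(ValueError):
    """Hypothetical stand-in for SMCFailed in this sketch."""


def update_log_weights(log_weights, log_p_obs, log_p_trans, log_q):
    """Per particle i:
    log w_t[i] = log w_{t-1}[i] + log p(y_t | z_t[i])
                 + log p(z_t[i] | z_{t-1}[i]) - log q(z_t[i])
    """
    new = [lw + lo + lp - lq
           for lw, lo, lp, lq in zip(log_weights, log_p_obs, log_p_trans, log_q)]
    top = max(new)
    # `not top > -inf` also catches NaN, not just the all -inf case.
    if not top > -math.inf:
        raise AllParticlesFailed("no particle has nonzero weight")
    # Shift so the largest log-weight is 0; relative weights are unchanged.
    return [lw - top for lw in new]
```

When the guide proposes from the transition prior (a bootstrap proposal), log p(z_t | z_{t-1}) and log q(z_t) cancel and only the observation likelihood changes the weights.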

SMCState Methods

  • __init__(num_particles): Initialize with the particle count and zero log-weights.
  • __setitem__(key, value): Store a tensor. Validates that values are tensors whose leading dimension equals the number of particles. Raises RuntimeError if the guide tries to write (state locked).
  • _resample(index): Resample all stored tensors using the given index tensor, then reset the log-weights to zero.
  • _lock(): Context manager that prevents writes (used while the guide runs).
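The container semantics described above (leading-dimension checks, writes blocked under a lock, and resampling that reindexes every entry and resets the log-weights) can be mimicked with a dict subclass. This is a simplified, hypothetical ToyState that stores Python lists instead of torch tensors; it is not Pyro's SMCState.

```python
import contextlib


class ToyState(dict):
    """Dict of per-particle lists, loosely modeled on the SMCState description."""

    def __init__(self, num_particles):
        super().__init__()
        self.num_particles = num_particles
        self.log_weights = [0.0] * num_particles
        self._locked = False

    @contextlib.contextmanager
    def lock(self):
        """While active, writes raise RuntimeError (as during guide execution)."""
        self._locked = True
        try:
            yield
        finally:
            self._locked = False

    def __setitem__(self, key, value):
        if self._locked:
            raise RuntimeError("guide cannot write to state")
        if len(value) != self.num_particles:
            raise ValueError(
                f"expected leading dim {self.num_particles}, got {len(value)}")
        super().__setitem__(key, value)

    def resample(self, index):
        """Reindex every stored value by `index`, then reset log-weights to zero."""
        for key, value in list(self.items()):
            self[key] = [value[i] for i in index]
        self.log_weights = [0.0] * self.num_particles
```

Because resampling duplicates whole particles, every entry must be reindexed with the same index list; otherwise values belonging to different hypotheses would be mixed.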

Helper Functions

  • _systematic_sample(probs): Systematic resampling implementation that preserves diversity better than multinomial sampling.
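Systematic resampling draws a single uniform offset u and selects particles at the evenly spaced positions (u + k) / N, k = 0, …, N - 1, on the cumulative weight scale, so particle i receives either ⌊N·p_i⌋ or ⌈N·p_i⌉ copies. The loop-based sketch below illustrates the scheme; systematic_resample is a hypothetical name, and Pyro's own helper works on tensors with vectorized operations instead.

```python
import random


def systematic_resample(probs, rng=random):
    """Return N ancestor indices for normalized weights `probs` (length N)."""
    n = len(probs)
    u = rng.random()  # one shared uniform offset in [0, 1)
    indices = []
    cumsum = probs[0]
    i = 0
    for k in range(n):
        position = (u + k) / n
        # Advance to the particle whose cumulative interval contains `position`;
        # the i < n - 1 guard absorbs floating-point shortfall in the total.
        while position >= cumsum and i < n - 1:
            i += 1
            cumsum += probs[i]
        indices.append(i)
    return indices
```

Compared with drawing N independent categorical samples, the single shared offset removes most of the randomness in how many copies each particle gets, which is why the scheme keeps more distinct particles.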

I/O Contract

SMCFilter Constructor

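Inputs:

  • model (object): Probabilistic model with init and step methods.
  • guide (object): Guide used for sampling, with init and step methods.
  • num_particles (int): Number of particles used to form the distribution.
  • max_plate_nesting (int): Bound on the maximum number of nested pyro.plate contexts in the model and guide.
  • ess_threshold (float, keyword-only, default 0.5): Resampling occurs when ESS < ess_threshold * num_particles.

Model/Guide Protocol

  • init(state, *args, **kwargs): Called by SMCFilter.init, which forwards its own arguments after state.
  • step(state, *args, **kwargs): Called by SMCFilter.step, which forwards its own arguments after state.
  • The model may read and write state; the guide may only read it.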

get_empirical

Output:

  • dict -- Maps state keys to Empirical distributions weighted by the particle log-weights.
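Each Empirical returned by get_empirical pairs a state tensor's particle values with the particle log-weights. Summary statistics of such a distribution, like the z_dist.mean printed in the basic example, amount to self-normalized weighted averages; weighted_mean_and_variance below is a hypothetical pure-Python helper showing that computation, not Pyro's Empirical class.

```python
import math


def weighted_mean_and_variance(values, log_weights):
    """Mean and variance of atoms `values` with unnormalized `log_weights`."""
    top = max(log_weights)
    # Normalize in a numerically stable way (log-sum-exp style).
    w = [math.exp(lw - top) for lw in log_weights]
    total = sum(w)
    probs = [wi / total for wi in w]
    mean = sum(p * v for p, v in zip(probs, values))
    var = sum(p * (v - mean) ** 2 for p, v in zip(probs, values))
    return mean, var
```

For example, atoms 0 and 10 with weights in the ratio 3:1 give a mean of 2.5 and a variance of 18.75.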

Usage Examples

Basic Particle Filtering

import torch
import pyro
import pyro.distributions as dist
from pyro.infer import SMCFilter

# Model: Gaussian random walk z_t ~ Normal(z_{t-1}, 0.1), observed as y_t ~ Normal(z_t, 0.5).
class Model:
    def init(self, state):
        # The particle plate broadcasts this site to shape (num_particles,).
        state["z"] = pyro.sample("z_0", dist.Normal(0.0, 1.0))

    def step(self, state, y=None):
        z_prev = state["z"]
        z = pyro.sample("z", dist.Normal(z_prev, 0.1))
        state["z"] = z  # only the model may write to state
        pyro.sample("y", dist.Normal(z, 0.5), obs=y)

# Guide: proposes from the transition prior (a bootstrap proposal) and only reads state.
class Guide:
    def init(self, state):
        pyro.sample("z_0", dist.Normal(0.0, 1.0))

    def step(self, state, y=None):
        z_prev = state["z"]
        pyro.sample("z", dist.Normal(z_prev, 0.1))

model = Model()
guide = Guide()
smc = SMCFilter(model, guide, num_particles=100, max_plate_nesting=0)

# Initialize
smc.init()

# Step through observations (synthetic data for illustration)
observations = torch.randn(50)
for y in observations:
    smc.step(y=y)

# Get filtered state
empirical = smc.get_empirical()
z_dist = empirical["z"]
print("Filtered z mean:", z_dist.mean)

Handling SMCFailed

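from pyro.infer.smcfilter import SMCFailed

# `smc` is the filter from the previous example; `extreme_observation` is a
# placeholder for an observation to which every particle assigns zero
# probability (e.g. a value outside the likelihood's support).
try:
    smc.step(y=extreme_observation)
except SMCFailed as e:
    print("All particles collapsed:", e)
    # Consider increasing num_particles or improving the guide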
