Implementation:Pyro ppl Pyro SMCFilter
Overview
The smcfilter module (`pyro/infer/smcfilter.py`) implements Sequential Monte Carlo (SMC) filtering, also known as particle filtering, for online Bayesian inference in time-series and state-space models. The SMCFilter class maintains a weighted set of particles that it propagates through time steps, resampling adaptively to prevent particle degeneracy.
The design uses a two-phase protocol:
- init -- Initialize the particle set using the model's and guide's `init` methods.
- step -- Advance the filter one time step using the model's and guide's `step` methods.
Both model and guide must be objects with `init` and `step` methods. The SMCState (a dict-like container) stores all tensors that depend on sampled variables. The model can read from and write to the state, but the guide can only read from it (enforced by a lock mechanism).
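To make the two-phase protocol concrete, here is a minimal from-scratch bootstrap particle filter in plain PyTorch for the same Gaussian random-walk model as the usage example. It does not use Pyro: the functions `init_particles` and `step_particles` are hypothetical, and for brevity it swaps in multinomial resampling at every step instead of SMCFilter's ESS-triggered systematic resampling.

```python
import torch


def init_particles(num_particles):
    # Phase 1 ("init"): draw particles from the prior z_0 ~ Normal(0, 1)
    # and start with uniform (all-zero) log-weights.
    z = torch.randn(num_particles)
    log_w = torch.zeros(num_particles)
    return z, log_w


def step_particles(z, log_w, y, trans_scale=0.1, obs_scale=0.5):
    # Phase 2 ("step"): propagate each particle through the transition
    # p(z_t | z_{t-1}) = Normal(z_{t-1}, trans_scale). Because the proposal q
    # equals the transition, p(z_t | z_{t-1}) / q(z_t) = 1, so only the
    # likelihood p(y_t | z_t) = Normal(z_t, obs_scale) changes the weights.
    z = z + trans_scale * torch.randn_like(z)
    log_w = log_w + torch.distributions.Normal(z, obs_scale).log_prob(y)
    # Resample multinomially (every step, for brevity) and reset the weights.
    probs = torch.softmax(log_w, dim=0)
    index = torch.multinomial(probs, z.numel(), replacement=True)
    return z[index], torch.zeros_like(log_w)


torch.manual_seed(0)
z, log_w = init_particles(1000)
for y in torch.tensor([0.2, 0.4, 0.5, 0.7]):
    z, log_w = step_particles(z, log_w, y)
print("filtered mean:", z.mean().item())
```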
Key features:
- Adaptive resampling, triggered when the Effective Sample Size (ESS) falls below a threshold.
- Systematic resampling, which preserves particle diversity better than multinomial resampling.
- Inference complexity is O(len(state) * num_time_steps), so keeping the state at a fixed size avoids quadratic complexity in Markov models.
- An `SMCFailed` exception is raised when all particles have zero probability.
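The adaptive-resampling rule can be written out directly. This sketch assumes the common ESS estimate 1 / sum(p_i^2) over normalized weights p_i; the helpers `effective_sample_size` and `should_resample` are hypothetical names, not Pyro API.

```python
import torch


def effective_sample_size(log_weights):
    # Normalize in log space for numerical stability, then ESS = 1 / sum(p_i^2).
    probs = torch.softmax(log_weights, dim=-1)
    return 1.0 / probs.dot(probs)


def should_resample(log_weights, ess_threshold=0.5):
    # Resample when ESS < ess_threshold * num_particles.
    num_particles = log_weights.numel()
    return bool(effective_sample_size(log_weights) < ess_threshold * num_particles)


uniform = torch.zeros(100)            # all particles equally weighted
degenerate = torch.full((100,), -1e3)
degenerate[0] = 0.0                   # one particle carries almost all the mass
print(effective_sample_size(uniform))                          # ESS ≈ 100
print(should_resample(uniform), should_resample(degenerate))   # False True
```

Uniform weights give the maximum ESS (the particle count), while a single dominant particle drives the ESS toward 1, which is exactly the degeneracy resampling is meant to repair.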
Code Reference
File: `pyro/infer/smcfilter.py`
Key Classes
| Class | Parent | Description |
|---|---|---|
| `SMCFilter` | -- | Top-level interface for sequential Monte Carlo filtering. |
| `SMCState` | `dict` | Dictionary-like container for vectorized particle tensors with resampling support. |
| `SMCFailed` | `ValueError` | Exception raised when no hypothesis has nonzero probability. |
SMCFilter Methods
| Method | Description |
|---|---|
| `__init__(model, guide, num_particles, max_plate_nesting, *, ess_threshold=0.5)` | Initialize with model, guide, particle count, plate nesting, and ESS threshold. |
| `init(*args, **kwargs)` | Initialize the particle set via model.init and guide.init, then perform the initial weight update and resample if needed. |
| `step(*args, **kwargs)` | Advance one time step via model.step and guide.step. Updates the weights and resamples if needed. |
| `get_empirical()` | Returns a dict mapping state keys to `Empirical` distributions weighted by the particle log-weights. |
| `_update_weights(model_trace, guide_trace)` | Updates the weights (in log space) as $w_t \leftarrow w_{t-1} \, p(y_t \mid z_t) \, p(z_t \mid z_{t-1}) / q(z_t)$. |
| `_maybe_importance_resample()` | Computes ESS = 1 / sum(p_i^2) from the normalized weights and resamples if ESS < ess_threshold * num_particles. |
| `_importance_resample(probs)` | Performs systematic resampling using the given probability weights. |
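The weight update is easiest to see in log space: each latent site contributes log p − log q, each observed site contributes log p, and the result is shifted so the best particle has log-weight 0. This sketch assumes per-particle log-probabilities have already been summed into 1-D tensors of length num_particles; `update_log_weights` and the local `SMCFailed` class are illustrative stand-ins, not Pyro's code (which may also check for failure after every site rather than once at the end).

```python
import math

import torch


class SMCFailed(ValueError):
    """Local stand-in for pyro.infer.smcfilter.SMCFailed."""


def update_log_weights(log_w, latent_terms, observed_terms):
    # latent_terms: list of (log_p, log_q) pairs, one per latent sample site.
    # observed_terms: list of log_p tensors, one per observed sample site.
    log_w = log_w.clone()
    for log_p, log_q in latent_terms:
        log_w += log_p - log_q          # factor p(z_t | z_{t-1}) / q(z_t)
    for log_p in observed_terms:
        log_w += log_p                  # factor p(y_t | z_t)
    if not (log_w.max() > -math.inf):
        raise SMCFailed("every particle has zero weight")
    return log_w - log_w.max()          # shift so the best particle is at 0


log_p_z = torch.tensor([-1.0, -2.0, -3.0])
log_q_z = torch.tensor([-1.0, -1.0, -1.0])
log_p_y = torch.tensor([-0.5, -0.5, -math.inf])   # particle 3 cannot explain y
result = update_log_weights(torch.zeros(3), [(log_p_z, log_q_z)], [log_p_y])
print(result)  # log-weights [0, -1, -inf]
```

A single particle with weight zero is harmless; only when every log-weight reaches -inf is there nothing left to resample from, which is the `SMCFailed` condition.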
SMCState Methods
| Method | Description |
|---|---|
| `__init__(num_particles)` | Initialize with the particle count and zero log-weights. |
| `__setitem__(key, value)` | Store a tensor. Validates that values are tensors whose leading dimension equals the particle count. Raises RuntimeError if the guide tries to write while the state is locked. |
| `_resample(index)` | Resample all stored tensors using the given index tensor and reset the log-weights to zero. |
| `_lock()` | Context manager that prevents writes (used during guide execution). |
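The SMCState contract can be sketched as a small `dict` subclass that checks type and leading dimension on write, refuses writes while locked, and resamples every stored tensor by index. The class name `LockedParticleState` and its public `lock`/`resample` methods are hypothetical, and this sketch validates unconditionally.

```python
import contextlib

import torch


class LockedParticleState(dict):
    def __init__(self, num_particles):
        super().__init__()
        self.num_particles = num_particles
        self.log_weights = torch.zeros(num_particles)
        self.locked = False

    @contextlib.contextmanager
    def lock(self):
        # While locked (e.g. during guide execution), reads work but writes fail.
        self.locked = True
        try:
            yield
        finally:
            self.locked = False

    def __setitem__(self, key, value):
        if self.locked:
            raise RuntimeError("Guide cannot write to the state")
        if not isinstance(value, torch.Tensor):
            raise TypeError("Only tensors can be stored")
        if value.dim() == 0 or value.size(0) != self.num_particles:
            raise ValueError(
                f"Expected leading dim {self.num_particles}, got {tuple(value.shape)}"
            )
        super().__setitem__(key, value)

    def resample(self, index):
        # Gather every stored tensor along the particle dim, then reset weights.
        for key, value in self.items():
            self[key] = value[index].contiguous()
        self.log_weights.fill_(0.0)


state = LockedParticleState(4)
state["z"] = torch.tensor([0.0, 1.0, 2.0, 3.0])
with state.lock():
    print(state["z"])                 # reading is allowed
    try:
        state["z"] = torch.zeros(4)   # writing is not
    except RuntimeError as err:
        print("blocked:", err)
state.resample(torch.tensor([3, 3, 0, 1]))
print(state["z"])                     # values [3, 3, 0, 1]
```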
Helper Functions
| Function | Description |
|---|---|
| `_systematic_sample(probs)` | Systematic resampling implementation that preserves particle diversity better than multinomial sampling. |
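Systematic resampling draws one uniform offset u ~ U(0, 1) and selects particles at the evenly spaced points (u + k) / N, k = 0, ..., N − 1, on the cumulative weight distribution, so particle i is copied either floor(N p_i) or ceil(N p_i) times. Multinomial resampling only matches these counts in expectation. The 1-D sketch below is a standard formulation built on `torch.searchsorted`; the name `systematic_resample` is hypothetical, and Pyro's `_systematic_sample` may be constructed differently.

```python
import torch


def systematic_resample(probs):
    # probs: 1-D tensor of normalized weights (sums to 1).
    n = probs.numel()
    u = torch.rand(1)                                           # one shared offset
    positions = (u + torch.arange(n, dtype=probs.dtype)) / n    # n points in [0, 1)
    cdf = probs.cumsum(0)
    cdf[-1] = 1.0                                 # guard against rounding below 1
    # For each position, take the first particle whose cumulative weight exceeds it.
    return torch.searchsorted(cdf, positions, right=True)


probs = torch.tensor([0.1, 0.2, 0.3, 0.4])
index = systematic_resample(probs)
counts = torch.bincount(index, minlength=4)
print(index, counts)
```

Because all N points share one random offset, the returned indices are sorted and particles with zero weight are never selected.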
I/O Contract
SMCFilter Constructor
Inputs:
- `model` -- Must have `init` and `step` methods.
- `guide` -- Must have `init` and `step` methods.
- `num_particles` (int) -- Number of particles.
- `max_plate_nesting` (int) -- Bound on the maximum number of nested `pyro.plate` contexts.
- `ess_threshold` (float, keyword-only) -- ESS threshold for resampling (default 0.5); resampling occurs when ESS < ess_threshold * num_particles. Must be in (0, 1].
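The constructor constraints can be restated as a standalone check. This sketch assumes the particle plate is placed at dim `-1 - max_plate_nesting`, to the left of every model plate, which is why per-particle tensors carry the particle count as their leading dimension; `check_smc_args` is a hypothetical helper, and the positivity checks on the integer arguments are assumptions added for illustration.

```python
import torch


def check_smc_args(num_particles, max_plate_nesting, ess_threshold=0.5):
    # Restates the documented input constraints (integer checks are assumed).
    if not (isinstance(num_particles, int) and num_particles > 0):
        raise ValueError("num_particles must be a positive int")
    if not (isinstance(max_plate_nesting, int) and max_plate_nesting >= 0):
        raise ValueError("max_plate_nesting must be a non-negative int")
    if not (0 < ess_threshold <= 1):
        raise ValueError("ess_threshold must lie in (0, 1]")
    # Assumed placement: the particle dim sits left of all model plate dims.
    return -1 - max_plate_nesting


particle_dim = check_smc_args(num_particles=100, max_plate_nesting=1)
# A site inside one model plate of size 3 then has shape (100, 3): particles first.
z = torch.randn(100, 3)
print(particle_dim, z.size(particle_dim))   # -2 100
```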
Model/Guide Protocol
- Both must have `init` and `step` methods, which receive `state` as their first argument followed by the arguments passed to SMCFilter's `init` and `step`.
- The `state` argument is an `SMCState` instance.
- The model may read from and write to `state`.
- The guide may only read from `state` (writes raise `RuntimeError`).
- State values must be `torch.Tensor` objects whose leading dimension equals `num_particles`.
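The leading-dimension rule and the complexity note interact: overwriting a fixed-size entry keeps every resampling step proportional to len(state), while appending a history satisfies the same rule but makes each stored tensor, and therefore each resampling copy, grow with t. A shape-only sketch in plain PyTorch, with hypothetical key names:

```python
import torch

num_particles, num_steps = 100, 5
fixed = {"z": torch.zeros(num_particles)}              # (N,) forever
growing = {"z_hist": torch.zeros(num_particles, 1)}    # (N, t + 1)

for t in range(num_steps):
    z_new = fixed["z"] + 0.1 * torch.randn(num_particles)
    fixed["z"] = z_new                                  # constant size
    growing["z_hist"] = torch.cat(
        [growing["z_hist"], z_new.unsqueeze(-1)], dim=-1
    )                                                   # one column per step
    # Both entries satisfy the leading-dimension rule.
    assert fixed["z"].size(0) == growing["z_hist"].size(0) == num_particles

# Resampling the history copies N * (t + 1) numbers at step t: O(N * T^2) overall.
print(fixed["z"].shape, growing["z_hist"].shape)  # torch.Size([100]) torch.Size([100, 6])
```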
get_empirical
Output:
- `dict` -- Maps each state key to a weighted `Empirical` distribution.
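Each distribution returned by get_empirical pairs a state tensor with the current particle log-weights. As a sketch of what such a weighted empirical distribution summarizes, the weighted mean and variance can be computed directly from values and log-weights. `weighted_moments` is a hypothetical helper; with Pyro installed, `pyro.distributions.Empirical(values, log_weights).mean` should report the same weighted mean.

```python
import torch


def weighted_moments(values, log_weights):
    # values, log_weights: (num_particles,) tensors; weights need not be normalized.
    probs = torch.softmax(log_weights, dim=0)
    mean = (probs * values).sum()
    var = (probs * (values - mean) ** 2).sum()
    return mean, var


values = torch.tensor([0.0, 1.0, 2.0])
log_weights = torch.log(torch.tensor([0.25, 0.5, 0.25]))
mean, var = weighted_moments(values, log_weights)
print(mean.item(), var.item())  # weighted mean ≈ 1.0, variance ≈ 0.5
```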
Usage Examples
Basic Particle Filtering
```python
import torch
import pyro
import pyro.distributions as dist
from pyro.infer import SMCFilter


class Model:
    def init(self, state):
        # Initial latent state z_0 ~ Normal(0, 1), one value per particle.
        state["z"] = pyro.sample("z_0", dist.Normal(0, 1).expand([1]))

    def step(self, state, y=None):
        # Random-walk transition followed by a noisy observation of z.
        z_prev = state["z"]
        z = pyro.sample("z", dist.Normal(z_prev, 0.1))
        state["z"] = z
        pyro.sample("y", dist.Normal(z, 0.5), obs=y)


class Guide:
    # Bootstrap guide: proposes from the same prior/transition as the model.
    # It only reads from state; writing to it would raise RuntimeError.
    def init(self, state):
        pyro.sample("z_0", dist.Normal(0, 1).expand([1]))

    def step(self, state, y=None):
        z_prev = state["z"]
        pyro.sample("z", dist.Normal(z_prev, 0.1))


model = Model()
guide = Guide()
smc = SMCFilter(model, guide, num_particles=100, max_plate_nesting=0)

# Initialize
smc.init()

# Step through observations
observations = torch.randn(50)
for y in observations:
    smc.step(y=y.unsqueeze(0))

# Get filtered state
empirical = smc.get_empirical()
z_dist = empirical["z"]
print("Filtered z mean:", z_dist.mean)
```
Handling SMCFailed
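```python
from pyro.infer.smcfilter import SMCFailed

# Continues the example above (reuses `smc`). `extreme_observation` is a
# placeholder: SMCFailed is raised only when every particle's log-weight
# becomes -inf, i.e. the observation has zero likelihood under all particles.
try:
    smc.step(y=extreme_observation)
except SMCFailed as e:
    print("All particles collapsed:", e)
    # Consider increasing num_particles or improving the guide
```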
Related Pages
- Pyro_ppl_Pyro_Importance -- Importance sampling, the foundation of SMC
- Pyro_ppl_Pyro_Resampler -- Importance resampling for prior predictive checks
- Pyro_ppl_Pyro_Abstract_Infer -- Base inference classes
- Pyro_ppl_Pyro_Infer_Utilities -- Validation utilities used by SMCState