Principle:Pyro ppl Pyro SMC Filtering Pattern
| Knowledge Sources | |
|---|---|
| Domains | Sequential Monte Carlo, Particle Filtering, Design Patterns |
| Last Updated | 2026-02-09 09:00 GMT |
Overview
The SMC filtering pattern provides an application-level design pattern for implementing particle filters in Pyro, structuring the model into initialization, transition, and observation components that integrate with Pyro's inference machinery.
Description
While the theoretical foundations of Sequential Monte Carlo are covered in the SMC principle, the SMC filtering pattern describes how to structure and implement particle filters within the Pyro probabilistic programming framework.
The pattern splits the state-space model into two methods, init and step, that the user implements on a model object:
init(state): Defines the prior distribution over the initial state z_0. It draws z_0 with a pyro.sample statement and stores it in the shared state dictionary (for example state["z"]), which carries it forward in time.
step(state, t, observations): Defines both the state transition and the observation model for a single time step. This method:
- Samples the next state z_t from the transition distribution p(z_t | z_{t-1}), reading z_{t-1} from state.
- Conditions on the observation x_t through the observation model p(x_t | z_t), using pyro.sample with the obs= keyword.
- Writes z_t back into state (for example state["z"] = z_t) rather than returning it.
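The framework handles the particle-level bookkeeping: maintaining multiple particles, computing importance weights, and performing resampling, so the user only writes the model for a single particle. In Pyro this is pyro.infer.SMCFilter(model, guide, num_particles, max_plate_nesting), which also takes a guide object with the same init and step methods; the guide's sample statements define the proposal (sampling from the transition prior gives the bootstrap filter). SMCFilter runs both objects inside a particle plate, so every tensor stored in state carries an extra batch dimension over particles, and any extra arguments passed to SMCFilter.init or SMCFilter.step are forwarded to the corresponding methods after state.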
Key design considerations:
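- State management: In Pyro's SMCFilter the state is a dictionary shared across steps. Each entry holds a tensor with one slice per particle so that resampling can reindex every entry by ancestor indices; several named entries can be stored to represent structured state.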
- Observation handling: Observations are passed as arguments, and the step method conditions on them with pyro.sample(..., obs=x_t).
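- Resampling strategy: Particles are resampled (for example with systematic or multinomial resampling) when the effective sample size (ESS) of the weights drops below a threshold; Pyro's SMCFilter sets this threshold with its ess_threshold argument, as a fraction of num_particles.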
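- Logging and diagnostics: Particle weights and quantities derived from them (ESS, resampling events) can be tracked for diagnostics; in Pyro, SMCFilter.get_empirical() returns the weighted particle approximation of each state entry as an Empirical distribution.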
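Systematic and multinomial resampling, together with an ESS test, fit in a few lines of plain Python. This is a hypothetical sketch, not Pyro's code: ess, systematic_resample, multinomial_resample, and maybe_resample are illustrative names. It resamples when ESS = 1 / sum_i w_i^2 (on normalized weights) drops below ess_threshold * N, the fraction-of-N convention used by SMCFilter's ess_threshold argument.

```python
import random

def ess(weights):
    """Effective sample size 1 / sum(w_i^2), computed on normalized weights."""
    total = sum(weights)
    return 1.0 / sum((w / total) ** 2 for w in weights)

def systematic_resample(weights, rng):
    """Systematic resampling: one uniform offset, N evenly spaced pointers."""
    n = len(weights)
    total = sum(weights)
    u0 = rng.random() / n
    ancestors = []
    i, cumulative = 0, weights[0] / total
    for k in range(n):
        pointer = u0 + k / n
        # Advance to the first particle whose cumulative weight reaches the pointer.
        while pointer > cumulative and i < n - 1:
            i += 1
            cumulative += weights[i] / total
        ancestors.append(i)
    return ancestors

def multinomial_resample(weights, rng):
    """Multinomial resampling: N independent categorical draws."""
    n = len(weights)
    return rng.choices(range(n), weights=weights, k=n)

def maybe_resample(particles, weights, ess_threshold=0.5, rng=random):
    """Resample if ESS < ess_threshold * N (weights reset to 1/N); else just normalize."""
    n = len(particles)
    if ess(weights) < ess_threshold * n:
        ancestors = systematic_resample(weights, rng)
        return [particles[a] for a in ancestors], [1.0 / n] * n
    total = sum(weights)
    return particles, [w / total for w in weights]
```

Systematic resampling uses a single uniform draw, so it typically adds less resampling noise than multinomial resampling at the same cost.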
Usage
Use the SMC filtering pattern when:
- Implementing a custom particle filter for a specific state-space model.
- The state transition or observation model is nonlinear or non-Gaussian, so the exact Kalman filter does not apply.
- You need online posterior estimation as observations arrive sequentially.
- Estimating marginal likelihoods for state-space models.
- Building custom proposal distributions that leverage observation information.
Theoretical Basis
Pattern structure:
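```python
import pyro

# The user writes a model object with two methods, init and step, for a single
# particle and passes it (together with a guide) to pyro.infer.SMCFilter.
# The three distribution arguments are user-supplied placeholders.
class MyModel:
    def __init__(self, initial_distribution, transition_distribution,
                 observation_distribution):
        self.initial_distribution = initial_distribution          # p(z_0)
        self.transition_distribution = transition_distribution    # z_{t-1} -> p(z_t | z_{t-1})
        self.observation_distribution = observation_distribution  # z_t -> p(x_t | z_t)

    def init(self, state):
        # Prior over the initial state, stored in the shared SMC state dict
        state["z"] = pyro.sample("z_0", self.initial_distribution)

    def step(self, state, t, observations):
        # Transition: z_t ~ p(z_t | z_{t-1}); state["z"] holds z_{t-1} on entry
        state["z"] = pyro.sample(f"z_{t}", self.transition_distribution(state["z"]))
        # Observation: condition on the observed x_t ~ p(x_t | z_t)
        pyro.sample(f"x_{t}", self.observation_distribution(state["z"]),
                    obs=observations[t])

# Driving it (guide: an object with matching init/step methods):
#   smc = pyro.infer.SMCFilter(MyModel(...), guide, num_particles=100, max_plate_nesting=0)
#   smc.init()
#   for t in range(1, T + 1):
#       smc.step(t, observations)          # observations[t] holds x_t
#   posterior = smc.get_empirical()["z"]   # weighted particles for z_T
```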
Framework execution flow:
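```
# For N particles (Pyro vectorizes the per-particle loops with a particle plate):
# Initialization:
#   for i = 1, ..., N:
#     state[i] = init()                    # z_0 proposed by guide.init, scored by model.init
#     weights[i] = 1/N                     # when z_0 is drawn from its prior
# Time loop:
#   for t = 1, ..., T:
#     # Check if resampling is needed (ESS = 1 / sum_i weights[i]^2):
#     if ESS(weights) < ess_threshold * N:
#       ancestors = resample(weights)
#       state = state[ancestors]
#       weights = 1/N
#
#     # Propagate each particle:
#     for i = 1, ..., N:
#       z_t ~ q(z_t | state[i], x_t)       # proposal from guide.step
#       weights[i] *= p(z_t | state[i]) * p(x_t | z_t) / q(z_t | state[i], x_t)   # model.step
#       state[i] = z_t
#     # With the bootstrap proposal q = p(z_t | z_{t-1}) this reduces to
#     #   weights[i] *= p(x_t | z_t)
#
#     # Normalize weights:
#     weights = weights / sum(weights)
```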
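The execution flow above can be turned into a small, self-contained bootstrap filter (proposal = transition prior, so each weight update is just p(x_t | z_t)). This is an illustrative plain-Python sketch, not Pyro's implementation: bootstrap_filter, normalize, normal_logpdf, and the scalar random-walk model in the usage lines are hypothetical. It keeps weights in log space and uses multinomial resampling for brevity.

```python
import math
import random

def normal_logpdf(x, mean, var):
    """Log density of N(mean, var) at x."""
    return -0.5 * (math.log(2 * math.pi * var) + (x - mean) ** 2 / var)

def normalize(logw):
    """Turn log weights into normalized weights, subtracting the max to avoid underflow."""
    m = max(logw)
    unnorm = [math.exp(lw - m) for lw in logw]
    total = sum(unnorm)
    return [u / total for u in unnorm]

def bootstrap_filter(init, step, loglik, observations, num_particles,
                     ess_threshold=0.5, rng=None):
    """Bootstrap particle filter; returns the filtering means E[z_t | x_{1:t}].

    init(rng) draws z_0 for one particle, step(z, rng) draws z_t ~ p(z_t | z_{t-1}),
    and loglik(x_t, z_t) returns log p(x_t | z_t).
    """
    rng = rng or random.Random()
    n = num_particles
    particles = [init(rng) for _ in range(n)]
    logw = [0.0] * n  # uniform initial weights (z_0 drawn from its prior)
    means = []
    for x_t in observations:
        # Resample (multinomially) if the ESS of the current weights is too small.
        w = normalize(logw)
        if 1.0 / sum(v * v for v in w) < ess_threshold * n:
            ancestors = rng.choices(range(n), weights=w, k=n)
            particles = [particles[a] for a in ancestors]
            logw = [0.0] * n
        # Propagate with the transition prior, then weight by p(x_t | z_t).
        particles = [step(z, rng) for z in particles]
        logw = [lw + loglik(x_t, z) for lw, z in zip(logw, particles)]
        w = normalize(logw)
        means.append(sum(v * z for v, z in zip(w, particles)))
    return means

# Usage on a hypothetical scalar random walk observed with small noise:
#   z_0 ~ N(0, 1), z_t ~ N(z_{t-1}, 1), x_t ~ N(z_t, 0.1^2)
means = bootstrap_filter(
    init=lambda r: r.gauss(0.0, 1.0),
    step=lambda z, r: z + r.gauss(0.0, 1.0),
    loglik=lambda x, z: normal_logpdf(x, z, 0.01),
    observations=[1.0, 2.0, 3.0],
    num_particles=2000,
    rng=random.Random(0),
)
```

Keeping log weights rather than raw weights avoids underflow when a single observation is very informative, as with the small observation noise used here.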
Custom proposal distributions:
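```python
import pyro

# Default proposal: the transition prior p(z_t | z_{t-1}) (bootstrap filter).
# Better proposal: q(z_t | z_{t-1}, x_t), which also uses the current observation.
# In Pyro the proposal lives in a guide object passed to SMCFilter next to the
# model; its sample sites use the same names as the model's latent sites.
class MyGuide:
    def __init__(self, initial_proposal, proposal_distribution):
        self.initial_proposal = initial_proposal            # placeholder: q(z_0)
        self.proposal_distribution = proposal_distribution  # placeholder: (z_{t-1}, x_t) -> q(z_t | z_{t-1}, x_t)

    def init(self, state):
        pyro.sample("z_0", self.initial_proposal)

    def step(self, state, t, observations):
        # For a linear-Gaussian model z_t = A z_{t-1} + N(0, Sigma_trans),
        # x_t = H z_t + N(0, R), the optimal proposal is
        #   q(z_t | z_{t-1}, x_t) = N(mu_opt, Sigma_opt)
        #   mu_opt = Sigma_opt * (Sigma_trans^{-1} A z_{t-1} + H^T R^{-1} x_t)
        #   Sigma_opt = (Sigma_trans^{-1} + H^T R^{-1} H)^{-1}
        # The guide runs before the model, so state["z"] still holds z_{t-1}.
        pyro.sample(f"z_{t}",
                    self.proposal_distribution(state["z"], observations[t]))
        # No x_t site here: the model's step scores x_t, and SMCFilter adds
        # log p(z_t | z_{t-1}) + log p(x_t | z_t) - log q(z_t | z_{t-1}, x_t)
        # to each particle's log weight.
```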
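In one dimension (Sigma_trans = Q, with A, H, R also scalars) the optimal-proposal formulas can be checked numerically. The hypothetical sketch below computes mu_opt and Sigma_opt, then the incremental log weight log p(z_t | z_{t-1}) + log p(x_t | z_t) - log q(z_t | z_{t-1}, x_t). Because q is then exactly p(z_t | z_{t-1}, x_t), that weight should not depend on the sampled z_t and should equal log p(x_t | z_{t-1}). The function names and parameter values are made up for illustration.

```python
import math

def normal_logpdf(x, mean, var):
    """Log density of N(mean, var) at x."""
    return -0.5 * (math.log(2 * math.pi * var) + (x - mean) ** 2 / var)

def optimal_proposal(z_prev, x_t, A, Q, H, R):
    """Scalar mu_opt, Sigma_opt for z_t = A z_{t-1} + N(0, Q), x_t = H z_t + N(0, R)."""
    sigma_opt = 1.0 / (1.0 / Q + H * H / R)
    mu_opt = sigma_opt * (A * z_prev / Q + H * x_t / R)
    return mu_opt, sigma_opt

def incremental_log_weight(z_t, z_prev, x_t, A, Q, H, R):
    """log p(z_t | z_{t-1}) + log p(x_t | z_t) - log q(z_t | z_{t-1}, x_t)."""
    mu_opt, sigma_opt = optimal_proposal(z_prev, x_t, A, Q, H, R)
    return (normal_logpdf(z_t, A * z_prev, Q)
            + normal_logpdf(x_t, H * z_t, R)
            - normal_logpdf(z_t, mu_opt, sigma_opt))

# Hypothetical parameters and data for illustration.
A, Q, H, R = 0.9, 0.5, 2.0, 0.3
z_prev, x_t = 1.2, 1.5
mu_opt, sigma_opt = optimal_proposal(z_prev, x_t, A, Q, H, R)

# The weight is the same wherever z_t lands...
w1 = incremental_log_weight(mu_opt - 0.7, z_prev, x_t, A, Q, H, R)
w2 = incremental_log_weight(mu_opt + 1.3, z_prev, x_t, A, Q, H, R)
# ...and equals the one-step predictive log p(x_t | z_{t-1}) = log N(x_t; H A z_{t-1}, H^2 Q + R).
predictive = normal_logpdf(x_t, H * A * z_prev, H * H * Q + R)
print(w1, w2, predictive)
```

This is why the optimal proposal minimizes weight variance: given z_{t-1}, every proposed particle receives the same weight.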
Marginal likelihood estimation:
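```
# SMC yields an unbiased estimate of the marginal likelihood p(x_{1:T}).
# If particles are resampled at every step and w_t^{(i)} is the unnormalized
# incremental weight of particle i at time t:
#   p_hat(x_{1:T}) = product_{t=1}^{T} (1/N * sum_i w_t^{(i)})
# With ESS-triggered resampling, each factor becomes sum_i W_{t-1}^{(i)} w_t^{(i)},
# where W_{t-1}^{(i)} are the normalized weights carried into step t
# (all equal to 1/N right after resampling, which recovers the formula above).
# p_hat itself is unbiased; log p_hat is biased downward (Jensen's inequality).
# The estimate can be used for:
# 1. Model comparison: compare log p_hat for different models
# 2. Parameter estimation: maximize p_hat(x_{1:T} | theta) over theta
# 3. Pseudo-marginal MCMC: use p_hat in place of the true marginal likelihood
#    in an MCMC sampler; unbiasedness keeps the exact posterior as the target
```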
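For a scalar linear-Gaussian model the exact marginal likelihood is available from a Kalman filter, which makes the product formula easy to sanity-check. The sketch below is a hypothetical plain-Python illustration (smc_log_marginal, kalman_log_marginal, the model parameters, and the observations are all made up): it accumulates log((1/N) * sum_i w_t^{(i)}) with resampling at every step, working in log space to avoid underflow.

```python
import math
import random

def normal_logpdf(x, mean, var):
    """Log density of N(mean, var) at x."""
    return -0.5 * (math.log(2 * math.pi * var) + (x - mean) ** 2 / var)

def logsumexp(values):
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))

def smc_log_marginal(observations, A, Q, R, m0, P0, num_particles, rng):
    """Bootstrap-SMC estimate of log p_hat(x_{1:T}), resampling at every step, for
    z_0 ~ N(m0, P0), z_t = A z_{t-1} + N(0, Q), x_t = z_t + N(0, R)."""
    n = num_particles
    particles = [rng.gauss(m0, math.sqrt(P0)) for _ in range(n)]
    log_z = 0.0
    for x_t in observations:
        particles = [rng.gauss(A * z, math.sqrt(Q)) for z in particles]
        logw = [normal_logpdf(x_t, z, R) for z in particles]  # log w_t^{(i)}
        lse = logsumexp(logw)
        log_z += lse - math.log(n)                            # log of (1/N) sum_i w_t^{(i)}
        probs = [math.exp(lw - lse) for lw in logw]
        particles = rng.choices(particles, weights=probs, k=n)  # multinomial resampling
    return log_z

def kalman_log_marginal(observations, A, Q, R, m0, P0):
    """Exact log p(x_{1:T}) for the same model via the Kalman filter."""
    m, P, total = m0, P0, 0.0
    for x_t in observations:
        m, P = A * m, A * A * P + Q        # predict
        S = P + R                          # innovation variance
        total += normal_logpdf(x_t, m, S)  # log p(x_t | x_{1:t-1})
        K = P / S                          # Kalman gain
        m, P = m + K * (x_t - m), (1 - K) * P
    return total

obs = [0.5, 1.1, 0.7, 1.8, 1.2]  # illustrative observations
exact = kalman_log_marginal(obs, A=0.8, Q=0.5, R=0.4, m0=0.0, P0=1.0)
estimate = smc_log_marginal(obs, A=0.8, Q=0.5, R=0.4, m0=0.0, P0=1.0,
                            num_particles=5000, rng=random.Random(0))
print(exact, estimate)
```

With a few thousand particles the two values should agree to within Monte Carlo error; averaged over many runs, exp(estimate) rather than estimate itself is what matches the exact value, reflecting the downward bias of log p_hat.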