

Principle:Pyro ppl Pyro SMC Filtering Pattern

From Leeroopedia


Knowledge Sources
Domains Sequential Monte Carlo, Particle Filtering, Design Patterns
Last Updated 2026-02-09 09:00 GMT

Overview

The SMC filtering pattern provides an application-level design pattern for implementing particle filters in Pyro, structuring the model into initialization, transition, and observation components that integrate with Pyro's inference machinery.

Description

While the theoretical foundations of Sequential Monte Carlo are covered in the SMC principle, the SMC filtering pattern describes how to structure and implement particle filters within the Pyro probabilistic programming framework.

The pattern prescribes a specific decomposition of the state-space model into two methods that a user must implement, which together cover the initialization, transition, and observation components:

init(): Defines the prior distribution over the initial state z_0. This method uses Pyro's sample statement to draw the initial state from a prior distribution. It sets up the state that will be propagated forward in time.

step(state, t, observations): Defines both the state transition and the observation model for a single time step. This method:

  1. Samples the next state z_t from the transition distribution p(z_t | z_{t-1}).
  2. Conditions on the observation x_t via the observation model p(x_t | z_t).
  3. Returns the updated state.

The framework handles the particle-level bookkeeping: maintaining multiple particles, computing importance weights, and performing resampling. The user only needs to define the model structure for a single particle.

Key design considerations:

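  • State management: The state passed between steps can be a tensor or a structured collection of tensors (tuple, dictionary), allowing flexible state representations; it must be indexable per particle so that resampling can copy ancestors. Pyro's SMCFilter, for example, keeps a dictionary-like state that the model updates in place (state["z"] = ...).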
  • Observation handling: Observations are passed as arguments, and the step function conditions on them by passing the observed value to pyro.sample through its obs= keyword.
  • Resampling strategy: The framework provides configurable resampling strategies (systematic, multinomial) triggered by effective sample size thresholds.
  • Logging and diagnostics: Particle weights, ESS, and resampled indices are tracked for diagnostics.
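The ESS-triggered resampling described in these bullets can be sketched independently of any framework. The following is a minimal pure-Python illustration, not Pyro's internal code: the helper names (`normalize_log_weights`, `effective_sample_size`, `multinomial_resample`, `systematic_resample`, `maybe_resample`) are hypothetical, and the ESS is assumed to be the common 1 / sum(w_i^2) over normalized weights, with resampling when it falls below `threshold * N`.

```python
import math
import random


def normalize_log_weights(log_weights):
    """Convert log weights to normalized weights (log-sum-exp for stability)."""
    m = max(log_weights)
    unnorm = [math.exp(lw - m) for lw in log_weights]
    total = sum(unnorm)
    return [w / total for w in unnorm]


def effective_sample_size(weights):
    """ESS = 1 / sum(w_i^2) for normalized weights; ranges from 1 to N."""
    return 1.0 / sum(w * w for w in weights)


def multinomial_resample(weights, rng=random):
    """Draw N ancestor indices independently with probabilities `weights`."""
    n = len(weights)
    return rng.choices(range(n), weights=weights, k=n)


def systematic_resample(weights, rng=random):
    """One uniform offset plus N evenly spaced points on the weight CDF."""
    n = len(weights)
    u0 = rng.random() / n
    ancestors = []
    cumulative = 0.0  # sum of weights[0..i-1]
    i = 0
    for k in range(n):
        point = u0 + k / n
        # Advance to the first particle whose cumulative weight exceeds `point`
        while i < n - 1 and cumulative + weights[i] <= point:
            cumulative += weights[i]
            i += 1
        ancestors.append(i)
    return ancestors


def maybe_resample(weights, threshold=0.5, method=systematic_resample, rng=random):
    """Resample only when ESS < threshold * N; return (ancestors, new_weights)."""
    n = len(weights)
    if effective_sample_size(weights) < threshold * n:
        return method(weights, rng), [1.0 / n] * n
    return list(range(n)), list(weights)
```

Systematic resampling uses a single random number and keeps each particle's offspring count within one of N * w_i, which usually gives lower resampling noise than multinomial resampling.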

Usage

Use the SMC filtering pattern when:

  • Implementing a custom particle filter for a specific state-space model.
  • The state transition is nonlinear or non-Gaussian, so the exact Kalman filter does not apply.
  • You need online posterior estimation as observations arrive sequentially.
  • Estimating marginal likelihoods for state-space models.
  • Building custom proposal distributions that leverage observation information.

Theoretical Basis

Pattern structure:

# User implements two methods:

class MyFilter(SMCFilter):
    def init(self):
        # Prior over initial state
        z_0 = pyro.sample("z_0", initial_distribution)
        return z_0

    def step(self, state, t, observations):
        # Transition: z_t ~ p(z_t | z_{t-1})
        z_t = pyro.sample(f"z_{t}",
            transition_distribution(state))

        # Observation: x_t ~ p(x_t | z_t)
        pyro.sample(f"x_{t}",
            observation_distribution(z_t),
            obs=observations[t])

        return z_t

Framework execution flow:

# For N particles:

# Initialization:
# for i = 1, ..., N:
#     state[i] = self.init()
#     weights[i] = 1/N

# Time loop:
# for t = 1, ..., T:
#     # Check if resampling needed:
#     if ESS(weights) < threshold:
#         ancestors = resample(weights)
#         state = state[ancestors]
#         weights = 1/N
#
#     # Propagate each particle:
#     for i = 1, ..., N:
#         state[i] = self.step(state[i], t, observations)
#         weights[i] *= p(x_t | state[i])  # from the observed sample site (obs=...)
#
#     # Normalize weights:
#     weights = weights / sum(weights)
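The execution flow above can be made concrete without Pyro. The sketch below is a plain-Python bootstrap filter for a hypothetical 1-D Gaussian random walk (transition N(z_{t-1}, q^2), observation N(z_t, r^2), prior z_0 ~ N(0, 1)); the names and parameter values are illustrative assumptions, not Pyro internals. It keeps log weights instead of raw weights to avoid underflow and resamples systematically when the ESS falls below half the particle count.

```python
import math
import random


def normal_logpdf(x, mean, std):
    """Log density of N(mean, std^2) evaluated at x."""
    return -0.5 * ((x - mean) / std) ** 2 - math.log(std * math.sqrt(2.0 * math.pi))


def log_normalize(logw):
    """Subtract log-sum-exp so that exp(logw) sums to one."""
    m = max(logw)
    lse = m + math.log(sum(math.exp(lw - m) for lw in logw))
    return [lw - lse for lw in logw]


def bootstrap_filter(xs, n=2000, q=1.0, r=0.5, ess_frac=0.5, seed=0):
    """Return filtered means E[z_t | x_{1:t}] for t = 1..T (hypothetical model)."""
    rng = random.Random(seed)
    # Initialization: z_0 ~ N(0, 1) for every particle, uniform weights 1/N
    z = [rng.gauss(0.0, 1.0) for _ in range(n)]
    logw = [-math.log(n)] * n
    means = []
    for x in xs:
        logw = log_normalize(logw)
        w = [math.exp(lw) for lw in logw]
        # Resample systematically if ESS = 1 / sum(w^2) is below ess_frac * N
        if 1.0 / sum(wi * wi for wi in w) < ess_frac * n:
            u, cum, i, ancestors = rng.random() / n, w[0], 0, []
            for k in range(n):
                while cum <= u + k / n and i < n - 1:
                    i += 1
                    cum += w[i]
                ancestors.append(i)
            z = [z[a] for a in ancestors]
            logw = [-math.log(n)] * n
        # Propagate each particle with the transition prior (bootstrap proposal)
        z = [rng.gauss(zi, q) for zi in z]
        # Multiply each weight by p(x_t | z_t), working in log space
        logw = [lw + normal_logpdf(x, zi, r) for lw, zi in zip(logw, z)]
        # Normalize and record the weighted posterior mean
        w = [math.exp(lw) for lw in log_normalize(logw)]
        means.append(sum(wi * zi for wi, zi in zip(w, z)))
    return means
```

Because this model is linear-Gaussian, the filtered means can be checked against an exact Kalman filter; in nonlinear models only the particle version is available.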

Custom proposal distributions:

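# Default proposal: transition prior p(z_t | z_{t-1})
# Better proposal: incorporate the current observation x_t

class MyFilter(SMCFilter):
    def step(self, state, t, observations):
        # The model keeps the transition prior, so that
        # p(z_t | z_{t-1}) still enters the importance weight
        z_t = pyro.sample(f"z_{t}",
            transition_distribution(state))

        pyro.sample(f"x_{t}",
            observation_distribution(z_t),
            obs=observations[t])

        return z_t

class MyProposal:
    def step(self, state, t, observations):
        # Proposal q(z_t | z_{t-1}, x_t) -- typically better than p(z_t | z_{t-1})
        # In Pyro this lives in the guide; its site name must match the model's

        # For linear-Gaussian models (z_t = A z_{t-1} + noise with covariance
        # Sigma_trans, x_t = H z_t + noise with covariance R), the optimal proposal is:
        # q(z_t | z_{t-1}, x_t) = N(mu_opt, Sigma_opt)
        # mu_opt = Sigma_opt * (Sigma_trans^{-1} A z_{t-1} + H^T R^{-1} x_t)
        # Sigma_opt = (Sigma_trans^{-1} + H^T R^{-1} H)^{-1}
        pyro.sample(f"z_{t}",
            proposal_distribution(state, observations[t]))

# Pyro's importance sampling machinery then corrects for the proposal:
# weights[i] *= p(x_t | z_t) * p(z_t | z_{t-1}) / q(z_t | z_{t-1}, x_t)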
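For a scalar linear-Gaussian model, the optimal-proposal formulas above reduce to simple arithmetic. The sketch below uses hypothetical helper names and treats `sigma_trans` and `R` as scalar variances. It computes mu_opt and Sigma_opt and checks a standard property of this proposal: the incremental weight p(x_t | z_t) p(z_t | z_{t-1}) / q(z_t | z_{t-1}, x_t) equals the one-step predictive density p(x_t | z_{t-1}) = N(x_t; H A z_{t-1}, H^2 Sigma_trans + R), whatever z_t was drawn.

```python
import math


def normal_logpdf(x, mean, var):
    """Log density of N(mean, var) at x (var is a variance, not a std)."""
    return -0.5 * (math.log(2.0 * math.pi * var) + (x - mean) ** 2 / var)


def optimal_proposal(z_prev, x, A, H, sigma_trans, R):
    """Mean and variance of q(z_t | z_{t-1}, x_t) = p(z_t | z_{t-1}, x_t), scalar case."""
    sigma_opt = 1.0 / (1.0 / sigma_trans + H * H / R)
    mu_opt = sigma_opt * (A * z_prev / sigma_trans + H * x / R)
    return mu_opt, sigma_opt


def log_incremental_weight(z, z_prev, x, A, H, sigma_trans, R):
    """log[p(x_t | z_t) p(z_t | z_{t-1}) / q(z_t | z_{t-1}, x_t)]."""
    mu_opt, sigma_opt = optimal_proposal(z_prev, x, A, H, sigma_trans, R)
    return (normal_logpdf(x, H * z, R)
            + normal_logpdf(z, A * z_prev, sigma_trans)
            - normal_logpdf(z, mu_opt, sigma_opt))


def log_predictive(z_prev, x, A, H, sigma_trans, R):
    """log p(x_t | z_{t-1}) = log N(x_t; H A z_{t-1}, H^2 sigma_trans + R)."""
    return normal_logpdf(x, H * A * z_prev, H * H * sigma_trans + R)
```

With A = H = 1, unit variances, z_{t-1} = 0 and x_t = 2, the proposal is N(1, 0.5): the observation pulls the particle halfway from the prior mean toward the data. In a Pyro guide this would be sampled as `dist.Normal(mu_opt, math.sqrt(sigma_opt))`, since Pyro's `Normal` takes a standard deviation.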

Marginal likelihood estimation:

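# The SMC filter provides an unbiased estimate of the marginal likelihood.
# When resampling at every step, with w_t^{(i)} the unnormalized incremental
# weight of particle i at time t (p(x_t | z_t^{(i)}) for the default proposal):
# p_hat(x_{1:T}) = product_{t=1}^{T} (1/N * sum_i w_t^{(i)})
# (With ESS-triggered resampling, 1/N is replaced by the normalized weights
#  carried over from step t-1.)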

# This can be used for:
# 1. Model comparison: compare log p_hat for different models
# 2. Parameter estimation: maximize p_hat(x | theta) over theta
# 3. Pseudo-marginal MCMC: use p_hat as a drop-in replacement
#    for the true marginal likelihood in an MCMC sampler
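The marginal-likelihood estimator can be checked on a model where the exact answer is known. The sketch below assumes the same kind of hypothetical 1-D random walk (q and r are standard deviations, z_0 ~ N(0, 1)). It runs a bootstrap filter that resamples at every step (multinomially), so the simple product form above applies. It accumulates log p_hat(x_{1:T}) as a sum of log(1/N * sum_i w_t^{(i)}) terms and compares the result with the exact Kalman-filter log-likelihood.

```python
import math
import random


def normal_logpdf(x, mean, var):
    """Log density of N(mean, var) at x (var is a variance)."""
    return -0.5 * (math.log(2.0 * math.pi * var) + (x - mean) ** 2 / var)


def smc_log_marginal(xs, n=2000, q=1.0, r=0.5, seed=0):
    """Bootstrap-filter estimate of log p(x_{1:T}), resampling at every step."""
    rng = random.Random(seed)
    z = [rng.gauss(0.0, 1.0) for _ in range(n)]  # z_0 ~ N(0, 1)
    log_phat = 0.0
    for x in xs:
        # Propagate through the transition prior, then weight by p(x_t | z_t)
        z = [rng.gauss(zi, q) for zi in z]
        logw = [normal_logpdf(x, zi, r * r) for zi in z]
        # Add log(1/N * sum_i w_t^(i)), computed with log-sum-exp
        m = max(logw)
        w = [math.exp(lw - m) for lw in logw]
        log_phat += m + math.log(sum(w) / n)
        # Multinomial resampling back to equal weights
        z = rng.choices(z, weights=w, k=n)
    return log_phat


def kalman_log_marginal(xs, q=1.0, r=0.5):
    """Exact log p(x_{1:T}) for the same linear-Gaussian model."""
    mean, var, total = 0.0, 1.0, 0.0
    for x in xs:
        var += q * q  # predict z_t
        total += normal_logpdf(x, mean, var + r * r)  # predictive density of x_t
        gain = var / (var + r * r)  # update with x_t
        mean += gain * (x - mean)
        var *= 1.0 - gain
    return total
```

Because p_hat is unbiased for p(x_{1:T}) rather than for its logarithm, log p_hat averaged over repeated runs tends to fall slightly below the exact log-likelihood (Jensen's inequality); the gap shrinks as N grows.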
