Implementation:Pyro ppl Pyro SMCFilter Example

From Leeroopedia


Property             Value
Implementation Type  Pattern Doc
Source File          examples/smcfilter.py
Module               examples
Pyro Features        pyro.infer.SMCFilter, pyro.sample, dist.Delta, dist.Normal, state-space model, online filtering, get_empirical()
Pattern              Sequential Monte Carlo filtering for state-space models

Overview

This file demonstrates how to use the SMCFilter algorithm with a simple noisy harmonic oscillator state-space model. The model has the form:

  • State transition: z[t] ~ Normal(A * z[t-1], B * sigma_z)
  • Observation: y[t] ~ Normal(z[t][0], sigma_y)

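where A is a rotation matrix (a 90-degree rotation in the code below), B is a vector that scales the process noise sigma_z per component, sigma_y is the measurement noise, and the state z is 2-dimensional (position and velocity). Only the first component, z[t][0], is observed.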
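The generative process above can be simulated without Pyro. The following is a minimal sketch in plain Python, assuming the row-vector convention of the code reference (previous state times A) and the values A = [[0, 1], [-1, 0]] and B = [3, 3] that appear there. The function name simulate_oscillator is hypothetical and not part of the example file.

```python
import random


def simulate_oscillator(num_timesteps, process_noise=1.0, measurement_noise=1.0,
                        initial=(1.0, 0.0), seed=0):
    """Draw one trajectory (zs) and its noisy observations (ys)."""
    rng = random.Random(seed)
    b = (3.0, 3.0)  # scaling vector B from the example
    z = list(initial)
    zs, ys = [list(z)], []
    for _ in range(num_timesteps):
        # Row vector times A = [[0, 1], [-1, 0]] maps (z0, z1) to (-z1, z0).
        mean = (-z[1], z[0])
        z = [rng.gauss(m, bi * process_noise) for m, bi in zip(mean, b)]
        # Only the first component (position) is observed.
        ys.append(rng.gauss(z[0], measurement_noise))
        zs.append(list(z))
    return zs, ys


zs, ys = simulate_oscillator(num_timesteps=500)
print(len(zs), len(ys))
```

With both noise levels set to zero the state returns to its starting point every four steps, which is a quick way to check the rotation.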

The example shows the complete SMC filtering pattern:

  1. Define a model class with init() and step() methods
  2. Define a guide class with matching init() and step() methods (the proposal distribution)
  3. Create an SMCFilter instance with a specified number of particles
  4. Run filtering: call smc.init() followed by smc.step(y) for each observation
  5. Extract posterior estimates via smc.get_empirical()
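To make the init/step/extract pattern in the steps above concrete, here is a minimal sketch of a generic particle filter in plain Python. It is not Pyro's SMCFilter: it is a sequential importance resampling scheme with multinomial resampling, and the class name TinyParticleFilter, the ess_fraction parameter, and the observation values are all assumptions made for illustration. The proposal plays the role of the guide, and the transition and observation densities play the role of the model.

```python
import math
import random


def normal_logpdf(x, mean, std):
    """Log density of a univariate normal distribution."""
    return -0.5 * ((x - mean) / std) ** 2 - math.log(std) - 0.5 * math.log(2 * math.pi)


class TinyParticleFilter:
    """Hypothetical stand-in for illustration; not Pyro's SMCFilter."""

    def __init__(self, num_particles, process_std=3.0, measurement_std=1.0,
                 proposal_std=1.0, ess_fraction=0.5, seed=0):
        self.n = num_particles
        self.process_std = process_std
        self.measurement_std = measurement_std
        self.proposal_std = proposal_std
        self.ess_fraction = ess_fraction  # illustrative resampling threshold
        self.rng = random.Random(seed)

    def init(self, initial):
        # All particles start at the known initial state with equal weight.
        self.particles = [list(initial) for _ in range(self.n)]
        self.log_weights = [0.0] * self.n

    def step(self, y):
        for i, (z0, z1) in enumerate(self.particles):
            mean = (-z1, z0)  # previous state rotated by A
            # Guide role: propose the next state.
            z = [self.rng.gauss(m, self.proposal_std) for m in mean]
            # Model role: score the proposal under the transition and the observation.
            log_p = sum(normal_logpdf(zi, m, self.process_std) for zi, m in zip(z, mean))
            log_q = sum(normal_logpdf(zi, m, self.proposal_std) for zi, m in zip(z, mean))
            log_lik = normal_logpdf(y, z[0], self.measurement_std)
            self.log_weights[i] += log_p + log_lik - log_q
            self.particles[i] = z
        self._resample_if_degenerate()

    def normalized_weights(self):
        top = max(self.log_weights)  # subtract the max for numerical stability
        w = [math.exp(lw - top) for lw in self.log_weights]
        total = sum(w)
        return [wi / total for wi in w]

    def _resample_if_degenerate(self):
        w = self.normalized_weights()
        ess = 1.0 / sum(wi * wi for wi in w)  # effective sample size
        if ess < self.ess_fraction * self.n:
            idx = self.rng.choices(range(self.n), weights=w, k=self.n)
            self.particles = [list(self.particles[j]) for j in idx]
            self.log_weights = [0.0] * self.n

    def mean(self):
        w = self.normalized_weights()
        return [sum(wi * p[d] for wi, p in zip(w, self.particles)) for d in range(2)]


pf = TinyParticleFilter(num_particles=100)
pf.init(initial=(1.0, 0.0))
for y in [0.2, -0.9, 0.1, 1.1]:  # made-up observations
    pf.step(y)
print(pf.mean())
```

Because the proposal differs from the transition, each weight is corrected by the ratio of transition density to proposal density, which is why the model and the guide must sample the same latent variables at every step.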

Code Reference

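import torch

import pyro
import pyro.distributions as dist
from pyro.infer import SMCFilter


class SimpleHarmonicModel:
    def __init__(self, process_noise, measurement_noise):
        self.A = torch.tensor([[0.0, 1.0], [-1.0, 0.0]])  # 90-degree rotation
        self.B = torch.tensor([3.0, 3.0])  # per-component noise scaling
        self.sigma_z = torch.tensor(process_noise)
        self.sigma_y = torch.tensor(measurement_noise)

    def init(self, state, initial):
        self.t = 0
        # Every particle starts at the same known initial state.
        state["z"] = pyro.sample("z_init", dist.Delta(initial, event_dim=1))

    def step(self, state, y=None):
        self.t += 1
        # Transition: rotate the previous state and add process noise.
        state["z"] = pyro.sample("z_{}".format(self.t),
            dist.Normal(state["z"].matmul(self.A), self.B * self.sigma_z).to_event(1))
        # Observation: only the first state component is measured.
        y = pyro.sample("y_{}".format(self.t),
            dist.Normal(state["z"][..., 0], self.sigma_y), obs=y)
        return state["z"], y


class SimpleHarmonicModel_Guide:
    # __init__ and init are needed by the calls in main(): the guide is
    # constructed with the model and must mirror the model's "z_init" site.
    def __init__(self, model):
        self.model = model

    def init(self, state, initial):
        self.t = 0
        pyro.sample("z_init", dist.Delta(initial, event_dim=1))

    def step(self, state, y=None):
        self.t += 1
        # Proposal distribution: same mean as the model, unit standard deviation.
        pyro.sample("z_{}".format(self.t),
            dist.Normal(state["z"].matmul(self.model.A),
                        torch.tensor([1.0, 1.0])).to_event(1))


def main(args):
    model = SimpleHarmonicModel(args.process_noise, args.measurement_noise)
    guide = SimpleHarmonicModel_Guide(model)
    smc = SMCFilter(model, guide, num_particles=args.num_particles, max_plate_nesting=0)
    smc.init(initial=torch.tensor([1.0, 0.0]))
    # ys holds the observations; the data generation step is omitted from this excerpt.
    for y in ys[1:]:
        smc.step(y)
    z = smc.get_empirical()["z"]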

I/O Contract

Parameter             Type   Description
-n / --num-timesteps  int    Number of time steps (default: 500)
-p / --num-particles  int    Number of SMC particles (default: 100)
--process-noise       float  Process noise standard deviation (default: 1.0)
--measurement-noise   float  Measurement noise standard deviation (default: 1.0)
--seed                int    Random seed (default: 0)
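The parameters in the table map directly onto a command-line parser. The following sketch uses the standard argparse module with the flag names, types, and defaults listed above; the helper name build_parser and the description string are assumptions, and the parser in the example file may be organized differently.

```python
import argparse


def build_parser():
    """Build a parser matching the flags and defaults in the table."""
    parser = argparse.ArgumentParser(description="SMC filtering example (sketch)")
    parser.add_argument("-n", "--num-timesteps", type=int, default=500)
    parser.add_argument("-p", "--num-particles", type=int, default=100)
    parser.add_argument("--process-noise", type=float, default=1.0)
    parser.add_argument("--measurement-noise", type=float, default=1.0)
    parser.add_argument("--seed", type=int, default=0)
    return parser


# argparse turns dashes into underscores, giving args.num_timesteps and so on.
args = build_parser().parse_args(["-n", "50", "--process-noise", "0.5"])
print(args.num_timesteps, args.num_particles, args.process_noise)
```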

Output:

  • True final state z[-1]
  • Posterior mean and standard deviation of z at the final time step
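The posterior mean and standard deviation reported above are summaries of a weighted particle set. As a rough illustration of how such summaries can be computed, the sketch below uses a plain weighted (population) variance over scalar samples. The function name weighted_mean_std and the sample values are hypothetical, and this is not a description of how the object returned by get_empirical() computes its statistics.

```python
import math


def weighted_mean_std(samples, log_weights):
    """Mean and standard deviation of a weighted set of scalar samples."""
    top = max(log_weights)  # subtract the max before exponentiating
    weights = [math.exp(lw - top) for lw in log_weights]
    total = sum(weights)
    probs = [w / total for w in weights]
    mean = sum(p * x for p, x in zip(probs, samples))
    var = sum(p * (x - mean) ** 2 for p, x in zip(probs, samples))
    return mean, math.sqrt(var)


# Made-up particles for the first state component, with equal weights.
mean, std = weighted_mean_std([0.0, 2.0, 4.0], [0.0, 0.0, 0.0])
print(mean, std)
```

Working with log weights and subtracting the maximum keeps the computation stable when a few particles dominate.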

Usage Examples

import torch
from pyro.infer import SMCFilter

# SimpleHarmonicModel and SimpleHarmonicModel_Guide are the classes from the Code Reference.
model = SimpleHarmonicModel(process_noise=1.0, measurement_noise=1.0)
guide = SimpleHarmonicModel_Guide(model)

smc = SMCFilter(model, guide, num_particles=100, max_plate_nesting=0)
smc.init(initial=torch.tensor([1.0, 0.0]))

# observations is assumed to be an iterable of scalar tensors, one per time step.
for observation in observations:
    smc.step(observation)

# Get particle approximation to posterior
empirical = smc.get_empirical()
posterior_mean = empirical["z"].mean
posterior_std = empirical["z"].variance ** 0.5
