Implementation:Pyro ppl Pyro SMCFilter Example
| Property | Value |
|---|---|
| Implementation Type | Pattern Doc |
| Source File | examples/smcfilter.py |
| Module | examples |
| Pyro Features | pyro.infer.SMCFilter, pyro.sample, dist.Delta, dist.Normal, state-space model, online filtering, get_empirical() |
| Pattern | Sequential Monte Carlo filtering for state-space models |
Overview
This file demonstrates how to use the SMCFilter algorithm with a simple noisy harmonic oscillator state-space model. The model has the form:
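- State transition: z[t] ~ Normal(A * z[t-1], B * sigma_z)
- Observation: y[t] ~ Normal(z[t][0], sigma_y)
where A = [[0, 1], [-1, 0]] is a 90-degree rotation matrix, B = [3, 3] is a per-coordinate scaling vector for the process noise, and the state z is 2-dimensional (position and velocity). Only the first coordinate, z[t][0], is observed. In the code, the transition mean is computed as the row-vector product z[t-1] A (state["z"].matmul(self.A)).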
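To see the generative process on its own, the following numpy sketch draws one trajectory from the transition and observation equations. It is a stand-in written without Pyro (the function name simulate is hypothetical) and uses the row-vector product z @ A, matching state["z"].matmul(self.A) in the example code.

```python
import numpy as np

# Parameters mirroring SimpleHarmonicModel: A rotates the 2-D state, B scales the noise.
A = np.array([[0.0, 1.0], [-1.0, 0.0]])
B = np.array([3.0, 3.0])


def simulate(num_timesteps, process_noise=1.0, measurement_noise=1.0, seed=0):
    """Hypothetical helper: draw states z[0..T] and observations y[1..T]."""
    rng = np.random.default_rng(seed)
    z = np.array([1.0, 0.0])  # same initial state as the example
    zs, ys = [z], []
    for _ in range(num_timesteps):
        # Row-vector convention, as in state["z"].matmul(self.A).
        z = rng.normal(z @ A, B * process_noise)
        # Only the first coordinate (position) is observed, with Gaussian noise.
        y = rng.normal(z[0], measurement_noise)
        zs.append(z)
        ys.append(y)
    return np.stack(zs), np.array(ys)


zs, ys = simulate(500)
print(zs.shape, ys.shape)  # (501, 2) (500,)
```

Because A is a 90-degree rotation, applying it four times returns a noiseless state to where it started, which is what makes the underlying dynamics oscillate.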
The example shows the complete SMC filtering pattern:
- Define a model class with init() and step() methods.
- Define a guide class with matching init() and step() methods (the proposal distribution).
- Create an SMCFilter instance with a specified number of particles.
- Run filtering: call smc.init() followed by smc.step(y) for each observation.
- Extract posterior estimates via smc.get_empirical().
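Each smc.step(y) call can be read through the standard sequential importance resampling recursion. As a sketch (assuming the usual SMC weight update; details such as when Pyro triggers resampling are not covered here), particle i is proposed from the guide and reweighted by the model-to-guide density ratio, using the row-vector convention z A from the code and diag(·) for a diagonal covariance:

```latex
\begin{aligned}
z_t^{(i)} &\sim \mathcal{N}\!\left(z_{t-1}^{(i)} A,\; I\right) && \text{(guide proposal)}\\
\log w_t^{(i)} &= \log w_{t-1}^{(i)}
  + \log \mathcal{N}\!\left(z_t^{(i)};\, z_{t-1}^{(i)} A,\; \operatorname{diag}(B\,\sigma_z)^2\right)
  + \log \mathcal{N}\!\left(y_t;\, z_{t,0}^{(i)},\; \sigma_y^2\right)
  - \log \mathcal{N}\!\left(z_t^{(i)};\, z_{t-1}^{(i)} A,\; I\right)
\end{aligned}
```

The normalized weights over the particles define the particle approximation that get_empirical() summarizes.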
Code Reference
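import torch
import pyro
import pyro.distributions as dist
from pyro.infer import SMCFilter

class SimpleHarmonicModel:
    def __init__(self, process_noise, measurement_noise):
        self.A = torch.tensor([[0.0, 1.0], [-1.0, 0.0]])  # 90-degree rotation of the state
        self.B = torch.tensor([3.0, 3.0])  # per-coordinate scale of the process noise
        self.sigma_z = torch.tensor(process_noise)
        self.sigma_y = torch.tensor(measurement_noise)
    def init(self, state, initial):
        self.t = 0
        # The initial state is known exactly, so it is drawn from a point mass.
        state["z"] = pyro.sample("z_init", dist.Delta(initial, event_dim=1))
    def step(self, state, y=None):
        self.t += 1
        state["z"] = pyro.sample("z_{}".format(self.t),
                                 dist.Normal(state["z"].matmul(self.A), self.B * self.sigma_z).to_event(1))
        # Only the first coordinate is observed; with y=None the observation is simulated.
        y = pyro.sample("y_{}".format(self.t),
                        dist.Normal(state["z"][..., 0], self.sigma_y), obs=y)
        return state["z"], y

class SimpleHarmonicModel_Guide:
    def __init__(self, model):
        self.model = model  # the guide reuses the model's rotation matrix A
    def init(self, state, initial):
        self.t = 0
        pyro.sample("z_init", dist.Delta(initial, event_dim=1))
    def step(self, state, y=None):
        self.t += 1
        # Proposal: same rotation as the model, but with unit noise scale.
        pyro.sample("z_{}".format(self.t),
                    dist.Normal(state["z"].matmul(self.model.A),
                                torch.tensor([1.0, 1.0])).to_event(1))

def main(args):
    pyro.set_rng_seed(args.seed)
    model = SimpleHarmonicModel(args.process_noise, args.measurement_noise)
    guide = SimpleHarmonicModel_Guide(model)
    smc = SMCFilter(model, guide, num_particles=args.num_particles, max_plate_nesting=0)
    # Simulate data by running the model forward; ys[0] is a placeholder for t = 0.
    state, initial = {}, torch.tensor([1.0, 0.0])
    model.init(state, initial=initial)
    zs, ys = [initial], [None]
    for _ in range(args.num_timesteps):
        z, y = model.step(state)
        zs.append(z)
        ys.append(y)
    # Filter: initialize the particles, then take one step per observation.
    smc.init(initial=torch.tensor([1.0, 0.0]))
    for y in ys[1:]:
        smc.step(y)
    z = smc.get_empirical()["z"]
    print("truth:", zs[-1], "mean:", z.mean, "std:", z.variance ** 0.5)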
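SMCFilter handles the particle bookkeeping internally. To make the computation concrete, here is a from-scratch numpy sketch of a filter with the same model and guide. It assumes multinomial resampling whenever the effective sample size drops below half the particle count; that rule is a common choice adopted here as an assumption, not a description of Pyro's internals. The names smc_filter and log_normal are hypothetical.

```python
import numpy as np

A = np.array([[0.0, 1.0], [-1.0, 0.0]])  # rotation shared by model and guide
B = np.array([3.0, 3.0])                 # model's per-coordinate process-noise scale


def log_normal(x, loc, scale):
    """Elementwise log-density of Normal(loc, scale) evaluated at x."""
    return -0.5 * ((x - loc) / scale) ** 2 - np.log(scale) - 0.5 * np.log(2 * np.pi)


def smc_filter(ys, num_particles=100, sigma_z=1.0, sigma_y=1.0, seed=0):
    """Hypothetical hand-rolled SMC filter; returns final particles and normalized weights."""
    rng = np.random.default_rng(seed)
    z = np.tile([1.0, 0.0], (num_particles, 1))  # like init(): all particles at the known start
    logw = np.zeros(num_particles)
    for y in ys:
        loc = z @ A                   # rotate every particle (row-vector convention)
        z_new = rng.normal(loc, 1.0)  # guide proposal with unit scale
        # Reweight: model transition + likelihood - proposal log-density.
        logw = logw + (log_normal(z_new, loc, B * sigma_z).sum(-1)
                       + log_normal(y, z_new[:, 0], sigma_y)
                       - log_normal(z_new, loc, 1.0).sum(-1))
        z = z_new
        w = np.exp(logw - logw.max())
        w /= w.sum()
        # Assumed rule: multinomial resampling when the ESS drops below N / 2.
        if 1.0 / np.sum(w ** 2) < 0.5 * num_particles:
            idx = rng.choice(num_particles, size=num_particles, p=w)
            z, logw = z[idx], np.zeros(num_particles)
    w = np.exp(logw - logw.max())
    return z, w / w.sum()


# Simulate 50 observations from the model, then filter them.
rng = np.random.default_rng(1)
z_true, ys = np.array([1.0, 0.0]), []
for _ in range(50):
    z_true = rng.normal(z_true @ A, B)
    ys.append(rng.normal(z_true[0], 1.0))
particles, weights = smc_filter(ys)
print("truth:", z_true, "filtered mean:", weights @ particles)
```

Because the guide proposes with unit scale while the model's transition scale is B * sigma_z = 3, the model-minus-guide log-density term is what corrects for the mismatch between the two.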
I/O Contract
| Parameter | Type | Description |
|---|---|---|
| -n / --num-timesteps | int | Number of time steps (default: 500) |
| -p / --num-particles | int | Number of SMC particles (default: 100) |
| --process-noise | float | Process noise scale sigma_z; the transition standard deviation is B * sigma_z (default: 1.0) |
| --measurement-noise | float | Measurement noise standard deviation sigma_y (default: 1.0) |
| --seed | int | Random seed (default: 0) |
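The I/O Contract table maps directly onto command-line flags. A minimal argparse sketch of that interface follows; the build_parser helper and the description string are assumptions, while the flag names, types, and defaults come from the table.

```python
import argparse


def build_parser():
    # Hypothetical helper; flags, types, and defaults follow the I/O Contract table.
    parser = argparse.ArgumentParser(description="SMCFilter harmonic-oscillator example")
    parser.add_argument("-n", "--num-timesteps", default=500, type=int)
    parser.add_argument("-p", "--num-particles", default=100, type=int)
    parser.add_argument("--process-noise", default=1.0, type=float)
    parser.add_argument("--measurement-noise", default=1.0, type=float)
    parser.add_argument("--seed", default=0, type=int)
    return parser


# argparse converts dashes to underscores, so --num-timesteps becomes args.num_timesteps.
args = build_parser().parse_args(["-n", "200", "--process-noise", "0.5"])
print(args.num_timesteps, args.num_particles, args.process_noise)  # 200 100 0.5
```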
Output:
- True final state z[-1]
- Posterior mean and standard deviation of z at the final time step
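The posterior mean and standard deviation are summaries of a weighted particle set. As a hedged illustration of what such summaries compute, here is a numpy sketch of self-normalized weighted moments. The helper weighted_summary is hypothetical, and Pyro's empirical distribution may differ in details such as the exact variance estimator.

```python
import numpy as np


def weighted_summary(particles, log_weights):
    """Hypothetical helper: weighted mean and std of particles with shape (N, D)."""
    w = np.exp(log_weights - np.max(log_weights))  # subtract the max for numerical stability
    w /= w.sum()                                   # self-normalize the weights
    mean = w @ particles
    var = w @ (particles - mean) ** 2              # weighted (biased) variance per coordinate
    return mean, np.sqrt(var)


particles = np.array([[0.0, 1.0], [2.0, 1.0], [4.0, 1.0]])
mean, std = weighted_summary(particles, np.log([0.25, 0.5, 0.25]))
print(mean, std)
```

With weights 0.25, 0.5, 0.25 this gives a mean of [2, 1] and a standard deviation of [sqrt(2), 0]; the second coordinate has zero spread because every particle agrees on it.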
Usage Examples
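import torch
from pyro.infer import SMCFilter

# SimpleHarmonicModel and SimpleHarmonicModel_Guide are the classes from the Code Reference.
# `observations` is any iterable of scalar tensors y[1], y[2], ... (one per time step).
model = SimpleHarmonicModel(process_noise=1.0, measurement_noise=1.0)
guide = SimpleHarmonicModel_Guide(model)
smc = SMCFilter(model, guide, num_particles=100, max_plate_nesting=0)
smc.init(initial=torch.tensor([1.0, 0.0]))  # arguments are passed on to model.init and guide.init
for observation in observations:
    smc.step(observation)  # one online filtering update per observation
# Get particle approximation to posterior
empirical = smc.get_empirical()
posterior_mean = empirical["z"].mean
posterior_std = empirical["z"].variance ** 0.5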
Related Pages
- Pyro_ppl_Pyro_Funsor_HMM - HMM models (batch inference rather than online filtering)