Implementation:Pyro ppl Pyro MixedHMM Model
| Property | Value |
|---|---|
| Implementation Type | Pattern Doc |
| Source File | examples/mixed_hmm/model.py |
| Module | examples.mixed_hmm |
| Pyro Features | config_enumerate, pyro.plate, pyro.markov, poutine.mask, Vindex, MaskedMixture, discrete/continuous random effects, zero-inflated distributions |
| Pattern | Hierarchical mixed-effects HMM for animal movement data |
Overview
This file implements a hierarchical mixed-effects Hidden Markov Model designed for analyzing animal movement data (step sizes, turning angles, dive activity). The model supports both discrete and continuous random effects at two levels of hierarchy:
- Group level: Random effects shared across individuals within a group (e.g., species, colony)
- Individual level: Random effects specific to each individual animal
The HMM transition logits (gamma) start from a zero, i.e. uniform, baseline and are composed additively from group-level and individual-level random effects. The random effects can be either:
- Discrete: Categorical latent types with learned embeddings (enabling mixture models)
- Continuous: Normal random effects with mean-field variational inference
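The additive composition can be sketched in plain Python (standard library only, so it runs without Pyro or torch). The helper names `softmax` and `transition_matrix` and the effect values are hypothetical. The model itself never writes an explicit softmax, because `dist.Categorical(logits=gamma_y)` normalizes each selected row.

```python
import math


def softmax(row):
    """Numerically stable softmax over a list of logits."""
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]


def transition_matrix(baseline, eps_group, eps_individual, n_state):
    """Add flattened logits (gamma + eps_g + eps_i), reshape, normalize each row.

    All inputs are flat lists of length n_state ** 2. Row r of the result is
    P(y_t = j | y_{t-1} = r) for j in range(n_state).
    """
    assert len(baseline) == len(eps_group) == len(eps_individual) == n_state ** 2
    logits = [b + g + i for b, g, i in zip(baseline, eps_group, eps_individual)]
    # reshape the flat vector into n_state rows, one per previous state
    rows = [logits[r * n_state:(r + 1) * n_state] for r in range(n_state)]
    return [softmax(row) for row in rows]


n_state = 3
baseline = [0.0] * n_state ** 2   # uniform baseline, like torch.zeros((N_state ** 2,))
eps_g = [0.0] * n_state ** 2      # hypothetical: no group effect
eps_i = [0.0] * n_state ** 2
eps_i[0] = 2.0                    # hypothetical: this individual tends to stay in state 0
P = transition_matrix(baseline, eps_g, eps_i, n_state)
```

With all-zero effects every row is uniform. The hypothetical `+2.0` touches only the first entry of row 0, so it raises the probability of staying in state 0 while rows 1 and 2 stay uniform.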
Three observation types are modeled at each time step:
- Step size: Gamma distribution with zero-inflation via MaskedMixture
- Turning angle: Von Mises distribution
- Dive activity: Beta distribution with zero-inflation via MaskedMixture
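`MaskedMixture(mask, component0, component1)` picks one component per element from a boolean mask and does not weight the two branches, so the probability of the zero branch has to be scored separately. The plain-Python sketch below shows the resulting per-observation log-probability for step size. It assumes the zero branch is a point mass at the sentinel `config["MISSING"]` with probability `p_zero`. `p_zero` and both function names are hypothetical, and in the model these parameters would be indexed by the current hidden state.

```python
import math


def gamma_log_pdf(x, concentration, rate):
    """Log density of Gamma(concentration, rate) at x > 0."""
    return (concentration * math.log(rate) - math.lgamma(concentration)
            + (concentration - 1.0) * math.log(x) - rate * x)


def zero_inflated_gamma_log_prob(x, p_zero, concentration, rate, missing=-1.0):
    """Log-probability of one step-size value under a zero-inflated Gamma.

    A value equal to the sentinel `missing` is explained by the point-mass
    branch (probability p_zero); any other value comes from the Gamma
    branch (probability 1 - p_zero) times its density.
    """
    if x == missing:
        return math.log(p_zero)
    return math.log1p(-p_zero) + gamma_log_pdf(x, concentration, rate)
```

The Beta-distributed dive activity follows the same pattern with the Gamma density swapped for a Beta density.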
Code Reference
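```python
import torch
import pyro
import pyro.distributions as dist
from pyro import poutine
from pyro.infer import config_enumerate
from pyro.ops.indexing import Vindex


@config_enumerate
def model_generic(config):
    """Hierarchical mixed-effects hidden markov model"""
    N_state = config["sizes"]["state"]
    N_c = config["sizes"]["group"]
    N_s = config["sizes"]["individual"]
    N_t = config["sizes"]["timesteps"]
    # ... parameter setup elided: probs_e_g, theta_g, loc_g, scale_g,
    # the individual-level equivalents, and the observation parameters

    # flattened N_state x N_state transition logits; zeros = uniform baseline
    gamma = torch.zeros((N_state ** 2,))
    with pyro.plate("group", N_c, dim=-1):
        if config["group"]["random"] == "discrete":
            # sample a latent group type and look up its learned embedding
            e_g = pyro.sample("e_g", dist.Categorical(probs_e_g))
            eps_g = Vindex(theta_g)[..., e_g, :]
        elif config["group"]["random"] == "continuous":
            eps_g = pyro.sample("eps_g", dist.Normal(loc_g, scale_g).to_event(1))
        else:
            eps_g = 0.0  # no group-level effect
        gamma = gamma + eps_g

        with pyro.plate("individual", N_s, dim=-2), poutine.mask(mask=config["individual"]["mask"]):
            # Individual-level random effects (same pattern as above, producing eps_i)...
            gamma = gamma + eps_i

            y = torch.tensor(0).long()  # initial hidden state
            for t in pyro.markov(range(N_t)):
                with poutine.mask(mask=config["timestep"]["mask"][..., t]):
                    # reshape the flat logits into a batch of transition matrices
                    gamma_t = gamma.reshape(tuple(gamma.shape[:-1]) + (N_state, N_state))
                    # select the transition row for the previous state y
                    gamma_y = Vindex(gamma_t)[..., y, :]
                    y = pyro.sample("y_{}".format(t), dist.Categorical(logits=gamma_y))
                    # Observe step, angle, omega...
```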
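Because `@config_enumerate` marks discrete sites such as `e_g` for enumeration, `TraceEnum_ELBO` can sum the model over every latent type instead of sampling one, assuming the guide does not sample `e_g` itself. For a single group, that sum is the log-sum-exp below. `log_lik_given_type[k]` is a hypothetical precomputed input standing for the group's data log-likelihood when `eps_g = theta_g[k]`. In the real model that term would itself involve the whole HMM.

```python
import math


def logsumexp(values):
    """Stable log(sum(exp(v) for v in values))."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))


def group_log_marginal(log_probs_e, log_lik_given_type):
    """log p(data_g) = log sum_k p(e_g = k) * p(data_g | theta_g[k])."""
    assert len(log_probs_e) == len(log_lik_given_type)
    return logsumexp([lp + ll for lp, ll in zip(log_probs_e, log_lik_given_type)])
```

This is what makes the discrete option a mixture model: each group's likelihood is a weighted average over the learned embeddings `theta_g`, with weights `probs_e_g`.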
I/O Contract
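| Parameter | Type | Description |
|---|---|---|
| config["sizes"]["state"] | int | Number of HMM hidden states |
| config["sizes"]["group"] | int | Number of groups |
| config["sizes"]["individual"] | int | Number of individuals per group |
| config["sizes"]["timesteps"] | int | Number of time steps |
| config["sizes"]["random"] | int | Number of latent types for discrete random effects |
| config["group"]["random"] | str | Group effect type: "discrete", "continuous", or "none" |
| config["individual"]["random"] | str | Individual effect type: "discrete", "continuous", or "none" |
| config["individual"]["mask"] | bool tensor (individual, group) | Mask over individuals, applied with poutine.mask |
| config["timestep"]["mask"] | bool tensor (individual, group, timesteps) | Mask over time steps, applied with poutine.mask |
| config["observations"] | dict | Observed step sizes, angles, and omega values per timestep |
| config["MISSING"] | float | Sentinel value marking missing or zero observations |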
Key sample sites:
- e_g / eps_g: Group-level discrete / continuous random effects
- e_i / eps_i: Individual-level discrete / continuous random effects
- y_{t}: Hidden state at time t (enumerated)
- step_{t}, angle_{t}, omega_{t}: Observations
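The tensor shapes the model expects follow from the plate layout (group at `dim=-1`, individual at `dim=-2`, time sliced with `[..., t]`) and match the masks in the usage example. The helpers `expected_shapes` and `shape_mismatches` below are hypothetical, a sketch of a pre-flight check rather than part of examples/mixed_hmm.

```python
def expected_shapes(sizes):
    """Shapes implied by the plates: individual at dim=-2, group at dim=-1,
    with time as the trailing axis sliced by [..., t]."""
    n_s, n_c, n_t = sizes["individual"], sizes["group"], sizes["timesteps"]
    return {
        "individual_mask": (n_s, n_c),
        "timestep_mask": (n_s, n_c, n_t),
        "observation": (n_s, n_c, n_t),  # each of step, angle, omega
    }


def shape_mismatches(sizes, actual):
    """Return the names in `actual` whose shape differs from expected_shapes(sizes)."""
    expected = expected_shapes(sizes)
    return [name for name, shape in actual.items() if expected.get(name) != tuple(shape)]
```

With torch tensors, `actual` can be filled from `.shape` (a `torch.Size`, which is a tuple), e.g. `{"timestep_mask": config["timestep"]["mask"].shape}`.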
Usage Examples
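```python
import math

import torch
from pyro.infer import SVI, TraceEnum_ELBO
from pyro.optim import Adam

from model import model_generic, guide_generic  # assumes cwd is examples/mixed_hmm

N_s, N_c, N_t = 10, 5, 100  # individuals, groups, timesteps
# placeholder observations shaped (individual, group, timestep); replace with real data
step_data = torch.rand(N_s, N_c, N_t) + 0.1                 # positive step sizes
angle_data = (torch.rand(N_s, N_c, N_t) * 2 - 1) * math.pi   # turning angles in [-pi, pi)
omega_data = torch.rand(N_s, N_c, N_t) * 0.98 + 0.01         # dive activity in (0, 1)

config = {
    "sizes": {"state": 3, "group": N_c, "individual": N_s, "timesteps": N_t, "random": 4},
    "group": {"random": "continuous"},
    "individual": {"random": "continuous", "mask": torch.ones(N_s, N_c).bool()},
    "timestep": {"mask": torch.ones(N_s, N_c, N_t).bool()},
    "observations": {"step": step_data, "angle": angle_data, "omega": omega_data},
    "MISSING": -1.0,
}
# Use with SVI and TraceEnum_ELBO for discrete state marginalization;
# max_plate_nesting=2 reserves dims -1 (group) and -2 (individual) for the plates
svi = SVI(model_generic, guide_generic, Adam({"lr": 0.01}),
          TraceEnum_ELBO(max_plate_nesting=2))
loss = svi.step(config)
```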
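The `y_{t}` sites are enumerated, and `pyro.markov` declares that each state depends only on the previous one, which lets `TraceEnum_ELBO` sum out the hidden states with cost linear in the number of time steps. The quantity being computed is the standard HMM forward recursion, sketched below in plain Python for one individual with fixed, hypothetical transition and emission log-probabilities. This illustrates the sum only; Pyro evaluates it through tensor variable elimination rather than this loop.

```python
import math


def logsumexp(values):
    """Stable log(sum(exp(v) for v in values)); handles all -inf inputs."""
    m = max(values)
    if m == -math.inf:
        return -math.inf
    return m + math.log(sum(math.exp(v - m) for v in values))


def hmm_log_likelihood(log_trans, log_emit, init_state=0):
    """Forward algorithm: log p(obs_0, ..., obs_{T-1}) with every y_t summed out.

    log_trans[i][j] = log P(y_t = j | y_{t-1} = i)
    log_emit[t][j]  = log p(obs_t | y_t = j)
    init_state plays the role of y = torch.tensor(0).long() before the loop.
    """
    n = len(log_trans)
    # alpha[j] = log p(obs_0..obs_t, y_t = j)
    alpha = [log_trans[init_state][j] + log_emit[0][j] for j in range(n)]
    for t in range(1, len(log_emit)):
        alpha = [
            logsumexp([alpha[i] + log_trans[i][j] for i in range(n)]) + log_emit[t][j]
            for j in range(n)
        ]
    return logsumexp(alpha)
```

Starting from `init_state=0` mirrors the model, where the first transition row used is always row 0.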
Related Pages
- Pyro_ppl_Pyro_CJS_Models - Simpler capture-recapture models with discrete state enumeration
- Pyro_ppl_Pyro_Funsor_HMM - HMM examples for music data