

Implementation:Pyro ppl Pyro MixedHMM Model

From Leeroopedia
Revision as of 16:24, 16 February 2026 by Admin (auto-imported from implementations/Pyro_ppl_Pyro_MixedHMM_Model.md)


Property | Value
--- | ---
Implementation Type | Pattern Doc
Source File | examples/mixed_hmm/model.py
Module | examples.mixed_hmm
Pyro Features | config_enumerate, pyro.plate, pyro.markov, poutine.mask, Vindex, MaskedMixture, discrete/continuous random effects, zero-inflated distributions
Pattern | Hierarchical mixed-effects HMM for animal movement data

Overview

This file implements a hierarchical mixed-effects Hidden Markov Model designed for analyzing animal movement data (step sizes, turning angles, dive activity). The model supports both discrete and continuous random effects at two levels of hierarchy:

  • Group level: Random effects shared across individuals within a group (e.g., species, colony)
  • Individual level: Random effects specific to each individual animal

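The HMM transition matrix is parameterized by logits gamma: a flat vector of length N_state² that starts as an all-zero baseline (uniform transitions) and to which group-level and individual-level random effects are added. At each time step it is reshaped into an N_state × N_state matrix whose row for the previous state gives the logits of the next state. The random effects can be either: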

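  • Discrete: a Categorical latent type per group (e_g) or per individual (e_i) that selects one row of a learned effect table such as theta_g, so the model becomes a finite mixture over latent types
  • Continuous: Normal random effects (eps_g, eps_i), fit with a mean-field variational guide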
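The additive composition can be traced with plain PyTorch tensors. This sketch assumes a discrete group effect and a continuous individual effect; theta_g, the sizes, and the random draws are illustrative stand-ins. In the model itself, e_g would be enumerated rather than sampled, and eps_i would come from pyro.sample.

```python
import torch

torch.manual_seed(0)
N_state, N_c, N_s, N_v = 3, 5, 10, 4  # states, groups, individuals, discrete types (assumed)
S2 = N_state ** 2

# Discrete group effect: each group's type e_g selects one row of an effect table.
theta_g = torch.randn(N_v, S2)        # hypothetical learned effect table
e_g = torch.randint(N_v, (N_c,))      # one latent type per group
eps_g = theta_g[e_g]                  # shape (N_c, S2)

# Continuous individual effect: one Normal draw per (individual, group).
eps_i = torch.randn(N_s, N_c, S2)

# Additive composition: zero baseline + group effect + individual effect.
gamma = torch.zeros(S2) + eps_g + eps_i  # broadcasts to (N_s, N_c, S2)

# Reshape into a batch of transition-logit matrices and normalize each row.
logits = gamma.reshape(tuple(gamma.shape[:-1]) + (N_state, N_state))
trans = torch.softmax(logits, dim=-1)  # trans[..., a, b] = p(y_t = b | y_{t-1} = a)
```

Working in logit space means the effects can be summed freely; only the final softmax over each row turns them into valid transition probabilities.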

Three observation types are modeled at each time step:

  • Step size: Gamma distribution with zero-inflation via MaskedMixture
  • Turning angle: Von Mises distribution
  • Dive activity: Beta distribution with zero-inflation via MaskedMixture

Code Reference

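import torch
import pyro
import pyro.distributions as dist
from pyro import poutine
from pyro.infer import config_enumerate
from pyro.ops.indexing import Vindex


@config_enumerate
def model_generic(config):
    """Hierarchical mixed-effects hidden markov model"""
    N_state = config["sizes"]["state"]
    N_c = config["sizes"]["group"]
    N_s = config["sizes"]["individual"]
    N_t = config["sizes"]["timesteps"]
    # Random-effect parameters (probs_e_g, theta_g, loc_g, scale_g, ...) and
    # observation parameters are initialized here (omitted)...
    gamma = torch.zeros((N_state ** 2,))  # baseline: uniform transition logits

    with pyro.plate("group", N_c, dim=-1):
        if config["group"]["random"] == "discrete":
            e_g = pyro.sample("e_g", dist.Categorical(probs_e_g))  # enumerated latent type
            eps_g = Vindex(theta_g)[..., e_g, :]  # effect row for type e_g
        elif config["group"]["random"] == "continuous":
            eps_g = pyro.sample("eps_g", dist.Normal(loc_g, scale_g).to_event(1))
        else:  # "none": no group-level effect
            eps_g = 0.0
        gamma = gamma + eps_g

        with pyro.plate("individual", N_s, dim=-2), \
                poutine.mask(mask=config["individual"]["mask"]):
            # Individual-level random effects (e_i / eps_i), analogous to the group level...
            gamma = gamma + eps_i

            y = torch.tensor(0).long()  # fixed initial state
            for t in pyro.markov(range(N_t)):
                with poutine.mask(mask=config["timestep"]["mask"][..., t]):
                    # Reshape flat logits into a batch of N_state x N_state matrices
                    # (torch reshape does not accept "...", so spell out the batch shape).
                    gamma_t = gamma.reshape(tuple(gamma.shape[:-1]) + (N_state, N_state))
                    gamma_y = Vindex(gamma_t)[..., y, :]  # row for the previous state y
                    y = pyro.sample("y_{}".format(t), dist.Categorical(logits=gamma_y))
                    # Observe step, angle, omega...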

I/O Contract

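Parameter | Type | Description
--- | --- | ---
config["sizes"]["state"] | int | Number of HMM hidden states
config["sizes"]["group"] | int | Number of groups
config["sizes"]["individual"] | int | Number of individuals per group
config["sizes"]["timesteps"] | int | Number of time steps
config["sizes"]["random"] | int | Number of latent types for discrete random effects
config["group"]["random"] | str | Group effect type: "discrete", "continuous", or "none"
config["individual"]["random"] | str | Individual effect type: "discrete", "continuous", or "none"
config["individual"]["mask"] | bool tensor (individual, group) | Mask over individuals (e.g. padding for groups with fewer individuals)
config["timestep"]["mask"] | bool tensor (individual, group, timesteps) | Mask over valid time steps, indexed as [..., t]
config["observations"] | dict | Observed "step", "angle", and "omega" tensors, indexed per time step as [..., t]
config["MISSING"] | float | Sentinel value; observations equal to it are assigned to the zero component of the zero-inflated likelihoods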
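The masks in the usage example suggest a layout for all per-individual tensors: individual on dim -2, group on dim -1, time last. A hypothetical check_config_shapes helper (not part of the Pyro example) can catch transposed or mis-sized inputs before inference; it assumes that layout and the three effect-type strings listed above.

```python
import torch


def check_config_shapes(config):
    """Hypothetical helper: list shape/value problems in a model_generic config.

    Assumes individual on dim -2, group on dim -1, and time as the trailing
    dim of masks and observations (indexed as [..., t] in the model).
    """
    sizes = config["sizes"]
    N_s, N_c, N_t = sizes["individual"], sizes["group"], sizes["timesteps"]
    expected = {
        'config["individual"]["mask"]': (config["individual"]["mask"], (N_s, N_c)),
        'config["timestep"]["mask"]': (config["timestep"]["mask"], (N_s, N_c, N_t)),
    }
    for name, obs in config["observations"].items():
        expected['config["observations"]["{}"]'.format(name)] = (obs, (N_s, N_c, N_t))

    problems = []
    for label, (tensor, shape) in expected.items():
        if tuple(tensor.shape) != shape:
            problems.append("{} has shape {}, expected {}".format(label, tuple(tensor.shape), shape))
    for level in ("group", "individual"):
        if config[level]["random"] not in ("discrete", "continuous", "none"):
            problems.append('config["{}"]["random"] has an unknown value'.format(level))
    return problems


# Synthetic observations with the usage example's sizes (placeholders, not real data).
config = {
    "sizes": {"state": 3, "group": 5, "individual": 10, "timesteps": 100, "random": 4},
    "group": {"random": "continuous"},
    "individual": {"random": "continuous", "mask": torch.ones(10, 5).bool()},
    "timestep": {"mask": torch.ones(10, 5, 100).bool()},
    "observations": {k: torch.rand(10, 5, 100) for k in ("step", "angle", "omega")},
    "MISSING": -1.0,
}
print(check_config_shapes(config))  # prints [] for this config
```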

Key sample sites:

  • e_g / eps_g: Group-level discrete/continuous random effects
  • e_i / eps_i: Individual-level discrete/continuous random effects
  • y_{t}: Hidden state at time t (enumerated)
  • step_{t}, angle_{t}, omega_{t}: Observations
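Because y_{t} is enumerated, TraceEnum_ELBO sums the hidden states out of the likelihood instead of sampling them. For one individual, that marginal equals the classic HMM forward recursion. The pure-PyTorch sketch below uses hypothetical emission log-probabilities and checks the recursion against brute-force enumeration of every state path. Pyro performs this sum through tensor variable elimination rather than with this code.

```python
import itertools

import torch


def forward_logprob(log_trans, emit_logp, y0=0):
    """Log marginal likelihood of one HMM sequence, summing out the hidden states.

    log_trans: (S, S) log transition matrix, row a = log p(y_t | y_{t-1} = a)
    emit_logp: (T, S) log p(observations at t | y_t = s)
    y0: fixed initial state, mirroring y = torch.tensor(0) in the model
    """
    alpha = log_trans[y0] + emit_logp[0]  # alpha[s] = log p(obs_0, y_0 = s)
    for t in range(1, emit_logp.shape[0]):
        # Sum over the previous state a: logsumexp_a(alpha[a] + log_trans[a, s]).
        alpha = torch.logsumexp(alpha.unsqueeze(-1) + log_trans, dim=-2) + emit_logp[t]
    return torch.logsumexp(alpha, dim=-1)


def brute_force_logprob(log_trans, emit_logp, y0=0):
    """Same quantity by explicitly enumerating all S**T state paths."""
    S, T = log_trans.shape[0], emit_logp.shape[0]
    scores = []
    for path in itertools.product(range(S), repeat=T):
        prev, score = y0, torch.tensor(0.0)
        for t, s in enumerate(path):
            score = score + log_trans[prev, s] + emit_logp[t, s]
            prev = s
        scores.append(score)
    return torch.logsumexp(torch.stack(scores), dim=0)


torch.manual_seed(0)
S, T = 3, 4
log_trans = torch.log_softmax(torch.randn(S, S), dim=-1)  # one individual's gamma_t
emit_logp = torch.randn(T, S)  # hypothetical emission log-probabilities
print(forward_logprob(log_trans, emit_logp), brute_force_logprob(log_trans, emit_logp))
```

The recursion costs O(T·S²) instead of the O(S^T) brute force, which is what makes enumerating y_{t} over long sequences feasible.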

Usage Examples

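import torch
from pyro.infer import SVI, TraceEnum_ELBO
from pyro.optim import Adam

from model import model_generic, guide_generic  # examples/mixed_hmm/model.py

# step_data, angle_data, omega_data: float tensors of shape
# (individual, group, timesteps) = (10, 5, 100), matching the masks below.
config = {
    "sizes": {"state": 3, "group": 5, "individual": 10, "timesteps": 100, "random": 4},
    "group": {"random": "continuous"},
    "individual": {"random": "continuous", "mask": torch.ones(10, 5).bool()},
    "timestep": {"mask": torch.ones(10, 5, 100).bool()},
    "observations": {"step": step_data, "angle": angle_data, "omega": omega_data},
    "MISSING": -1.0,
}

# Use with SVI and TraceEnum_ELBO to marginalize the discrete hidden states;
# max_plate_nesting=2 covers the "group" (dim=-1) and "individual" (dim=-2) plates.
svi = SVI(model_generic, guide_generic, Adam({"lr": 0.01}),
          TraceEnum_ELBO(max_plate_nesting=2))
loss = svi.step(config)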
