
Implementation:Pyro ppl Pyro CJS Models

From Leeroopedia
Revision as of 16:23, 16 February 2026 by Admin (auto-imported from implementations/Pyro_ppl_Pyro_CJS_Models.md)


Implementation Type: Pattern Doc
Source File: examples/capture_recapture/cjs.py
Module: examples.capture_recapture
Pyro Features: pyro.plate, pyro.markov, poutine.mask, TraceEnum_ELBO, TraceTMC_ELBO, AutoDiagonalNormal, parallel enumeration of discrete variables
Datasets: European Dipper, Meadow Voles

Overview

This file implements five variants of the Cormack-Jolly-Seber (CJS) model, a foundational model in ecological statistics for analyzing animal capture-recapture data. The models estimate survival probability (phi) and recapture probability (rho) from binary capture histories.
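To make these two quantities concrete, the sketch below computes the CJS probability of a single capture history by brute force: it conditions on the first capture, enumerates every possible alive/dead path after that occasion, and sums the joint probability of each path together with the observed captures. It is a plain-Python illustration with a hypothetical function name (history_prob_bruteforce), not code from cjs.py, and it is only practical for short histories because the number of paths grows as 2^T.

```python
from itertools import product


def history_prob_bruteforce(history, phi, rho):
    """P(captures after the first one | first capture) for a CJS model with
    constant survival phi and recapture probability rho.
    Assumes the history contains at least one capture (a 1)."""
    first = history.index(1)  # occasion of first capture, which is conditioned on
    later = history[first + 1:]
    total = 0.0
    # Enumerate every alive (1) / dead (0) path for the occasions after first capture.
    for path in product([0, 1], repeat=len(later)):
        prob = 1.0
        alive = 1  # the animal is alive when it is first captured
        for z, y in zip(path, later):
            # Survival transition: death is absorbing, so a dead animal stays dead.
            p_alive = phi * alive
            prob *= p_alive if z == 1 else 1.0 - p_alive
            # Observation: only a live animal can be recaptured.
            p_capture = rho * z
            prob *= p_capture if y == 1 else 1.0 - p_capture
            alive = z
        total += prob
    return total


# Seen at occasions 0 and 2, missed at 1: analytically phi*(1-rho)*phi*rho = 0.16.
print(history_prob_bruteforce([1, 0, 1], phi=0.8, rho=0.5))
```

Only the path "alive, alive" contributes to this example; every path in which the animal dies before occasion 2 has zero probability of the final recapture.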

The five model variants demonstrate increasing complexity:

  • model_1: Fixed-effect survival and recapture probabilities (phi, rho are scalars).
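  • model_2: Time-varying survival probability, with a separate fixed-effect phi_t for each interval between consecutive capture occasions (T-1 values).
  • model_3: Time-varying survival probability as random effects, with a hierarchical Normal prior on logit(phi_t) whose mean and scale are themselves latent variables.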
  • model_4: Group-level fixed effects for sex (separate phi for males/females).
  • model_5: Both fixed group effects and fixed time effects: logit(phi_t) = beta_group + gamma_t.
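The variants differ mainly in how the survival probability is built. The plain-Python sketch below writes the constructions for model_3, model_4 and model_5 side by side; all function names are hypothetical. For model_5 it assumes that beta_group is expressed as a single male-minus-female offset beta (so females get logit(phi_t) = gamma_t), which is one way to keep the group and time effects from being overparameterized.

```python
import math
import random


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def logit(p):
    return math.log(p) - math.log1p(-p)


def phi_model_4(sex, phi_male, phi_female):
    # Group fixed effect: pick the survival probability for the animal's sex
    # (0 = female, 1 = male), the same coding as the sex input.
    return sex * phi_male + (1 - sex) * phi_female


def sample_phi_t_model_3(phi_mean, phi_sigma, rng):
    # Random time effect: draw logit(phi_t) ~ Normal(logit(phi_mean), phi_sigma)
    # and map it back to a probability.
    return sigmoid(rng.gauss(logit(phi_mean), phi_sigma))


def phi_model_5(sex, beta, gamma_t):
    # Additive group and time effects on the logit scale. ASSUMPTION: the group
    # effect is a single male-minus-female offset beta, so females get
    # logit(phi_t) = gamma_t and males get beta + gamma_t.
    return sigmoid(beta * sex + gamma_t)


rng = random.Random(0)
print(phi_model_4(1, 0.7, 0.6))             # male survival: 0.7
print(sample_phi_t_model_3(0.8, 0.5, rng))  # one random phi_t; its median is phi_mean = 0.8
print(phi_model_5(0, 1.0, 0.0), phi_model_5(1, 1.0, 0.0))  # 0.5 for females, ~0.73 for males
```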

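All models use parallel enumeration to marginalize out the discrete alive/dead state z_t exactly, pyro.markov to declare that z_t depends only on z_{t-1} (so enumeration dimensions can be reused across time steps), and poutine.mask to drop each animal's likelihood terms up to and including its first capture, because individuals are first captured at different times and the CJS likelihood conditions on first capture.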
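Because z_t is Markov, summing it out exactly amounts to a forward recursion: carry the probability of each state (dead, alive) together with the observations so far, update it one occasion at a time, and sum at the end, instead of visiting all 2^T paths. The sketch below is a hand-written version of that recursion for constant phi and rho, skipping each animal's occasions up to and including its first capture in the same way the mask does. It illustrates roughly what exact enumeration computes; it is not Pyro's internal algorithm, and the name cjs_log_likelihood is hypothetical.

```python
import math


def cjs_log_likelihood(histories, phi, rho):
    """Sum over animals of log P(captures after first capture | first capture),
    with the alive/dead states summed out by a forward recursion."""
    total = 0.0
    for history in histories:
        seen = False  # plays the role of first_capture_mask for this animal
        # alpha = [P(dead, obs so far), P(alive, obs so far)]; alive at first capture.
        alpha = [0.0, 1.0]
        for y in history:
            if seen:
                # Transition: alive -> alive with prob phi; a dead animal stays dead.
                dead = alpha[0] + alpha[1] * (1.0 - phi)
                alive = alpha[1] * phi
                # Emission: a dead animal is never captured; a live one with prob rho.
                alpha = [dead * (1.0 if y == 0 else 0.0),
                         alive * (rho if y == 1 else 1.0 - rho)]
            seen = seen or y == 1
        total += math.log(alpha[0] + alpha[1])
    return total


# Two animals: log(0.16) + log(0.6) for phi = 0.8, rho = 0.5.
print(cjs_log_likelihood([[1, 0, 1], [0, 1, 0]], phi=0.8, rho=0.5))
```

An animal that is never captured keeps alpha = [0, 1] and contributes log(1) = 0, matching the fact that all of its terms stay masked.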

Code Reference

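import torch
import pyro
import pyro.distributions as dist
import pyro.poutine as poutine

def model_1(capture_history, sex):
    N, T = capture_history.shape
    phi = pyro.sample("phi", dist.Uniform(0.0, 1.0))  # survival probability
    rho = pyro.sample("rho", dist.Uniform(0.0, 1.0))  # recapture probability

    with pyro.plate("animals", N, dim=-1):
        z = torch.ones(N)  # every animal starts out alive
        # False until an animal's first capture has been recorded; masks out
        # the log-probabilities before (and at) each animal's first capture
        first_capture_mask = torch.zeros(N).bool()
        for t in pyro.markov(range(T)):
            with poutine.mask(mask=first_capture_mask):
                # alive with probability 1 until first capture; afterwards alive
                # only if alive at t-1 and it survived the interval (phi)
                mu_z_t = first_capture_mask.float() * phi * z + (1 - first_capture_mask.float())
                # parallel enumeration sums out the discrete state z_t exactly
                z = pyro.sample("z_{}".format(t), dist.Bernoulli(mu_z_t),
                                infer={"enumerate": "parallel"})
                # only a live animal can be recaptured
                mu_y_t = rho * z
                pyro.sample("y_{}".format(t), dist.Bernoulli(mu_y_t),
                            obs=capture_history[:, t])
            # from the next occasion on, this animal's terms are counted
            first_capture_mask |= capture_history[:, t].bool()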
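model_1 also reads as a generative story: each animal is alive at its first capture, survives each later interval with probability phi, and, if alive, is recaptured with probability rho. Simulating from that story is a common way to check that inference recovers known parameters. The helper below is a hypothetical stdlib sketch, not part of cjs.py; its first_capture argument (the occasion at which each simulated animal is first marked) is an assumption of the sketch.

```python
import random


def simulate_cjs(first_capture, T, phi, rho, seed=0):
    """Simulate binary capture histories (one list of length T per animal)."""
    rng = random.Random(seed)
    histories = []
    for f in first_capture:
        history = [0] * T
        history[f] = 1  # the first capture is always observed
        alive = True
        for t in range(f + 1, T):
            alive = alive and rng.random() < phi  # survive the interval
            if alive and rng.random() < rho:       # recapture only if alive
                history[t] = 1
        histories.append(history)
    return histories


# 200 animals first marked at occasion 0, 1 or 2 in a 7-occasion study.
data = simulate_cjs([i % 3 for i in range(200)], T=7, phi=0.6, rho=0.9)
```

The resulting lists can be turned into the [N, T] tensor that model_1 expects with torch.tensor(data).float().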

I/O Contract

Model function arguments (shared by model_1 through model_5):

  • capture_history (torch.Tensor): binary matrix of shape [N, T], where 1 = captured and 0 = not captured.
  • sex (torch.Tensor or None): vector of sex indicators of shape [N] (0 = female, 1 = male); used only by models 4 and 5, and available for the dipper data only.

Command-line flags of cjs.py:

  • --model (str): model variant, one of "1", "2", "3", "4" or "5".
  • --dataset (str): "dipper" or "vole".
  • --tmc (flag): use Tensor Monte Carlo (TraceTMC_ELBO) instead of exact enumeration.
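Before running a variant, it can help to check the inputs against this contract. The hypothetical helper below works on plain nested lists (convert a tensor with .tolist() first) and checks that capture_history is a rectangular binary [N, T] matrix, that the model name is valid, and that models 4 and 5 receive a 0/1 sex vector of length N. It is a sketch of the stated contract, not a function from cjs.py.

```python
def check_cjs_inputs(capture_history, sex, model):
    """Raise ValueError if the inputs do not match the I/O contract; return (N, T)."""
    if model not in {"1", "2", "3", "4", "5"}:
        raise ValueError("model must be one of '1', '2', '3', '4', '5'")
    N = len(capture_history)
    T = len(capture_history[0]) if N else 0
    for row in capture_history:
        if len(row) != T:
            raise ValueError("capture_history must be a rectangular [N, T] matrix")
        if any(v not in (0, 1) for v in row):
            raise ValueError("capture_history entries must be 0 or 1")
    if model in {"4", "5"}:
        # Models 4 and 5 need a per-animal sex indicator (0 = female, 1 = male).
        if sex is None or len(sex) != N or any(s not in (0, 1) for s in sex):
            raise ValueError("models 4 and 5 need a 0/1 sex vector of length N")
    return N, T


print(check_cjs_inputs([[1, 0, 1], [0, 1, 1]], [0, 1], "5"))  # (2, 3)
```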

Named sample sites:

  • phi, phi_{t}, phi_male, phi_female (and the other phi-prefixed sites of the hierarchical variants): Survival parameters
  • rho: Recapture probability
  • z_{t}: Discrete alive/dead state (enumerated)
  • y_{t}: Observed capture event
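The naming scheme matters for the guide in the usage example: expose_fn keeps a site visible to AutoDiagonalNormal only when the first three characters of its name are "phi" or "rho", so the continuous parameters reach the guide while the enumerated z_{t} and observed y_{t} sites stay hidden. The rule can be checked on plain message dictionaries; the site names below are illustrative.

```python
def expose_fn(msg):
    # Same rule as in the usage example: expose only sites named phi* or rho*.
    return msg["name"][0:3] in ["phi", "rho"]


site_names = ["phi", "phi_1", "phi_male", "phi_female", "rho", "z_0", "y_0", "z_5", "y_5"]
exposed = [name for name in site_names if expose_fn({"name": name})]
print(exposed)  # ['phi', 'phi_1', 'phi_male', 'phi_female', 'rho']
```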

Usage Examples

import torch
import pyro
import pyro.poutine as poutine
from pyro.infer import SVI, TraceEnum_ELBO
from pyro.infer.autoguide import AutoDiagonalNormal
from pyro.optim import Adam

# model_1 as defined in the Code Reference section

# Load capture-recapture data
capture_history = torch.tensor(...)  # shape [N, T]

pyro.clear_param_store()

# The guide sees only the continuous sites (names starting with "phi" or "rho");
# the discrete z_{t} sites are summed out exactly inside the model instead
def expose_fn(msg):
    return msg["name"][0:3] in ["phi", "rho"]
guide = AutoDiagonalNormal(poutine.block(model_1, expose_fn=expose_fn))

# TraceEnum_ELBO sums out the enumerated discrete states exactly;
# max_plate_nesting=1 matches the single "animals" plate at dim=-1
elbo = TraceEnum_ELBO(max_plate_nesting=1, num_particles=20, vectorize_particles=True)
svi = SVI(model_1, guide, Adam({"lr": 0.002}), elbo)

for step in range(400):
    loss = svi.step(capture_history, sex=None)  # model_1 does not use sex
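Because model_1 has only two continuous parameters, the SVI result can be cross-checked without variational inference by evaluating the exact marginal likelihood on a grid under the Uniform(0, 1) priors. The sketch below does this in plain Python with hypothetical function names, re-implementing the forward recursion so that it is self-contained. Expect only approximate agreement: AutoDiagonalNormal fits a diagonal Normal to the unconstrained (logit-transformed) parameters, and a pure-Python grid becomes slow for large n or many animals.

```python
import math


def log_likelihood(histories, phi, rho):
    # Exact CJS log-likelihood with z summed out by a forward recursion;
    # occasions up to and including the first capture are skipped.
    total = 0.0
    for history in histories:
        seen, dead, alive = False, 0.0, 1.0
        for y in history:
            if seen:
                dead, alive = dead + alive * (1.0 - phi), alive * phi
                dead = dead if y == 0 else 0.0  # a dead animal is never captured
                alive *= rho if y == 1 else 1.0 - rho
            seen = seen or y == 1
        total += math.log(dead + alive)
    return total


def grid_posterior_means(histories, n=50):
    # Uniform(0, 1) priors, so the posterior is proportional to the likelihood.
    grid = [(i + 0.5) / n for i in range(n)]  # cell midpoints, never 0 or 1
    logp = {(p, r): log_likelihood(histories, p, r) for p in grid for r in grid}
    top = max(logp.values())
    w = {k: math.exp(v - top) for k, v in logp.items()}  # subtract max for stability
    norm = sum(w.values())
    phi_mean = sum(p * wk for (p, _), wk in w.items()) / norm
    rho_mean = sum(r * wk for (_, r), wk in w.items()) / norm
    return phi_mean, rho_mean
```

Calling grid_posterior_means(capture_history.tolist()) returns posterior means of phi and rho that can be compared with the medians of the fitted guide.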
