

Implementation:Pyro ppl Pyro Funsor HMM

From Leeroopedia
Revision as of 16:23, 16 February 2026 by Admin (auto-imported from implementations/Pyro_ppl_Pyro_Funsor_HMM.md)


| Property | Value |
|---|---|
| Implementation Type | Pattern Doc |
| Source File | examples/contrib/funsor/hmm.py |
| Module | examples.contrib.funsor |
| Pyro Features | pyro.contrib.funsor backend, pyroapi, TraceEnum_ELBO, TraceMarkovEnum_ELBO, pyro.markov, pyro.vectorized_markov, AutoDelta, parallel enumeration, JIT compilation |
| Dataset | JSB Chorales (polyphonic music) |

Overview

This file demonstrates Pyro's Funsor backend on a collection of Hidden Markov Model (HMM) variants for polyphonic music modeling. Funsor acts as an intermediate representation for probabilistic programs, which allows the discrete hidden states to be summed out efficiently by variable elimination instead of by enumerating every hidden path.
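The payoff of variable elimination can be seen without Pyro at all. The NumPy sketch below, which assumes a small HMM with categorical emissions (the file itself uses Bernoulli tones) and uses hypothetical function names, compares summing the joint probability over all H^T hidden paths with eliminating the hidden states one at a time. Both give the same marginal likelihood, but the second costs O(T·H²) rather than O(H^T); this is roughly the computation that parallel enumeration plus `pyro.markov` lets Pyro/Funsor carry out.

```python
import itertools

import numpy as np


def brute_force_likelihood(init, trans, emit, obs):
    """Sum the joint probability over every hidden path: H**T terms."""
    H, T = trans.shape[0], len(obs)
    total = 0.0
    for path in itertools.product(range(H), repeat=T):
        p = init[path[0]] * emit[path[0], obs[0]]
        for t in range(1, T):
            p *= trans[path[t - 1], path[t]] * emit[path[t], obs[t]]
        total += p
    return total


def eliminate_likelihood(init, trans, emit, obs):
    """Sum out x_0, x_1, ... in order (forward algorithm): O(T * H**2)."""
    alpha = init * emit[:, obs[0]]             # message over x_0
    for y in obs[1:]:
        alpha = (alpha @ trans) * emit[:, y]   # eliminate the previous state
    return alpha.sum()


rng = np.random.default_rng(0)
H, V, T = 3, 4, 6                              # hypothetical sizes
init = rng.dirichlet(np.ones(H))
trans = rng.dirichlet(np.ones(H), size=H)      # each row sums to 1
emit = rng.dirichlet(np.ones(V), size=H)
obs = rng.integers(0, V, size=T)

print(brute_force_likelihood(init, trans, emit, obs),
      eliminate_likelihood(init, trans, emit, obs))
```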

Eight model variants are implemented with increasing complexity:

  • model_0: Simple HMM with sequential iteration over sequences
  • model_1: Vectorized HMM with batched sequences and JIT support
  • model_2: Autoregressive HMM (arHMM) with y[t] depending on y[t-1]
  • model_3: Factorial HMM (FHMM) with two independent hidden chains
  • model_4: Paired Factorial HMM (PFHMM) with dependent hidden chains
  • model_5: Neural HMM (nnHMM) with a CNN-based emission model
  • model_6: Second-order HMM (2HMM) with pyro.markov(history=2)
  • model_7: HMM vectorized over the time dimension with pyro.vectorized_markov and evaluated by a parallel scan (Funsor backend only)
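model_7's parallel scan works because chaining "transition times emission" factors is an associative matrix product in the log semiring. The T-1 step matrices can therefore be combined pairwise in about log2(T) batched rounds instead of T-1 dependent steps. The NumPy sketch below illustrates only that idea. It is not Funsor's implementation, and the categorical emissions and all function names are assumptions.

```python
import numpy as np


def log_matmul(a, b):
    """Log-semiring product: out[..., i, k] = logsumexp_j(a[..., i, j] + b[..., j, k])."""
    s = a[..., :, :, None] + b[..., None, :, :]
    m = s.max(axis=-2, keepdims=True)
    return (m + np.log(np.exp(s - m).sum(axis=-2, keepdims=True))).squeeze(-2)


def tree_reduce(mats):
    """Combine a (T, H, H) stack of log-matrices pairwise, preserving their order."""
    while mats.shape[0] > 1:
        carry = mats[-1:] if mats.shape[0] % 2 else mats[:0]   # odd one out waits
        body = mats[: mats.shape[0] - carry.shape[0]]
        # One batched product per round over all neighbouring pairs.
        mats = np.concatenate([log_matmul(body[0::2], body[1::2]), carry], axis=0)
    return mats[0]


def logsumexp(x):
    m = x.max()
    return m + np.log(np.exp(x - m).sum())


rng = np.random.default_rng(1)
H, V, T = 4, 5, 11                                   # hypothetical sizes
log_init = np.log(rng.dirichlet(np.ones(H)))
log_trans = np.log(rng.dirichlet(np.ones(H), size=H))
log_emit = np.log(rng.dirichlet(np.ones(V), size=H))
obs = rng.integers(0, V, size=T)

a0 = log_init + log_emit[:, obs[0]]
# Step matrices M_t[i, j] = log p(x_t=j | x_{t-1}=i) + log p(y_t | x_t=j)
mats = log_trans[None, :, :] + log_emit[:, obs[1:]].T[:, None, :]

# Sequential: T - 1 dependent updates.
alpha = a0
for m in mats:
    alpha = log_matmul(alpha[None, :], m)[0]
seq_ll = logsumexp(alpha)

# Scan style: reduce the step matrices in a balanced tree, then attach a0.
par_ll = logsumexp(a0[:, None] + tree_reduce(mats))
print(seq_ll, par_ll)
```

The two log-likelihoods agree. The tree version exposes more parallel work per round, which pays off on accelerators when T is long.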

All models are fit by MAP Baum-Welch estimation (the discrete hidden states are marginalized out, and the parameters are point-estimated) using an AutoDelta guide with TraceEnum_ELBO. model_7 uses TraceMarkovEnum_ELBO instead.
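Written out explicitly, the MAP Baum-Welch objective is the negative log prior of the parameters plus the negative log marginal likelihood, with hidden states summed out by the forward algorithm. With a point-mass guide this is what the ELBO reduces to. The NumPy sketch below reuses model_1's priors (Dirichlet(0.9·I + 0.1) rows for probs_x, Beta(0.1, 0.9) for probs_y) and starts the chain from state 0 as model_1 does. The function names and data are hypothetical, and in the file Pyro computes and differentiates this quantity rather than hand-written code.

```python
import math

import numpy as np


def logsumexp(x, axis=None):
    m = np.max(x, axis=axis, keepdims=True)
    out = m + np.log(np.exp(x - m).sum(axis=axis, keepdims=True))
    return np.squeeze(out, axis=axis) if axis is not None else out.item()


def dirichlet_logpdf(p, alpha):
    """log Dir(p | alpha) for one probability vector p."""
    return (math.lgamma(alpha.sum()) - sum(math.lgamma(a) for a in alpha)
            + float(((alpha - 1.0) * np.log(p)).sum()))


def beta_logpdf(p, a, b):
    """Elementwise log Beta(p | a, b), summed over all entries of p."""
    return float((math.lgamma(a + b) - math.lgamma(a) - math.lgamma(b)
                  + (a - 1.0) * np.log(p) + (b - 1.0) * np.log1p(-p)).sum())


def log_marginal(probs_x, probs_y, seq):
    """log p(y | params) for one (T, D) binary sequence, hidden states summed out."""
    # Bernoulli log-likelihood of each step's tones under each hidden state: (T, H)
    log_emit = seq @ np.log(probs_y).T + (1 - seq) @ np.log1p(-probs_y).T
    log_alpha = np.log(probs_x[0]) + log_emit[0]      # chain starts from state 0
    for t in range(1, len(seq)):
        log_alpha = logsumexp(log_alpha[:, None] + np.log(probs_x), axis=0) + log_emit[t]
    return logsumexp(log_alpha)


def map_loss(probs_x, probs_y, sequences):
    """Negative log posterior (up to a constant) of the point-estimated parameters."""
    H = probs_x.shape[0]
    concentration = 0.9 * np.eye(H) + 0.1             # same prior as model_1
    log_prior = sum(dirichlet_logpdf(row, c) for row, c in zip(probs_x, concentration))
    log_prior += beta_logpdf(probs_y, 0.1, 0.9)
    log_lik = sum(log_marginal(probs_x, probs_y, s) for s in sequences)
    return -(log_prior + log_lik)


rng = np.random.default_rng(2)
H, D = 3, 5                                            # hypothetical sizes
probs_x = rng.dirichlet(np.ones(H), size=H)
probs_y = rng.uniform(0.05, 0.95, size=(H, D))
sequences = [rng.integers(0, 2, size=(T, D)).astype(float) for T in (4, 7)]
print(map_loss(probs_x, probs_y, sequences))
```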

Code Reference

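import torch
# pyroapi dispatches to the backend selected with pyro_backend(...) (Pyro or Funsor).
from pyroapi import distributions as dist, handlers, pyro
from pyro.util import ignore_jit_warnings

# model_1: Batched HMM with JIT support
def model_1(sequences, lengths, args, batch_size=None, include_prior=True):
    # Read static shapes as Python ints; they are the same on every call,
    # so the JIT's integer-conversion warnings are safe to ignore here.
    with ignore_jit_warnings():
        num_sequences, max_length, data_dim = map(int, sequences.shape)
        assert lengths.shape == (num_sequences,)
        assert lengths.max() <= max_length
    # Global parameters; include_prior=False masks out their log-prior terms.
    with handlers.mask(mask=include_prior):
        probs_x = pyro.sample("probs_x",
            dist.Dirichlet(0.9 * torch.eye(args.hidden_dim) + 0.1).to_event(1))
        probs_y = pyro.sample("probs_y",
            dist.Beta(0.1, 0.9).expand([args.hidden_dim, data_dim]).to_event(2))

    tones_plate = pyro.plate("tones", data_dim, dim=-1)
    # Subsample batch_size sequences; dim=-2 because the tones plate uses dim=-1.
    with pyro.plate("sequences", num_sequences, batch_size, dim=-2) as batch:
        lengths = lengths[batch]
        x = 0
        # Under JIT, loop a fixed max_length steps so the program structure is static.
        for t in pyro.markov(range(max_length if args.jit else lengths.max())):
            # Mask out time steps past the end of each sequence.
            with handlers.mask(mask=(t < lengths).unsqueeze(-1)):
                # Hidden state, enumerated in parallel and summed out exactly.
                x = pyro.sample("x_{}".format(t), dist.Categorical(probs_x[x]),
                                infer={"enumerate": "parallel"})
                with tones_plate:
                    pyro.sample("y_{}".format(t),
                                dist.Bernoulli(probs_y[x.squeeze(-1)]),
                                obs=sequences[batch, t])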
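The `x.squeeze(-1)` in the emission is about tensor shapes under parallel enumeration. Pyro places enumerated dims to the left of all plate dims, so a hidden state enumerated in dim -3 (the first free dim left of the "sequences" and "tones" plates, assuming a plate nesting of 2) has shape (H, 1, 1). The NumPy sketch below, with hypothetical sizes, mimics that indexing to show the resulting Bernoulli log-probabilities line up as (enumeration, sequences, tones).

```python
import numpy as np

H, batch_size, D = 16, 8, 5                      # hypothetical sizes
probs_y = np.random.default_rng(0).uniform(0.05, 0.95, size=(H, D))

# Enumerated value of a hidden state living in dim -3: shape (H, 1, 1).
x = np.arange(H).reshape(H, 1, 1)

# Drop the rightmost slot (reserved for the tones plate) before indexing ...
idx = x.squeeze(-1)                              # (H, 1)
# ... so the (H, D) table contributes the tones axis on the right.
p = probs_y[idx]                                 # (H, 1, D)

# One time step of observations for the batch: (sequences, tones).
obs = np.random.default_rng(1).integers(0, 2, size=(batch_size, D)).astype(float)
# Bernoulli log-probabilities broadcast to (enumeration, sequences, tones).
log_prob = obs * np.log(p) + (1 - obs) * np.log1p(-p)
print(idx.shape, p.shape, log_prob.shape)
```

Without the squeeze, indexing with the (H, 1, 1) value would produce (H, 1, 1, D), pushing the tones axis one position too far left relative to the plates.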

I/O Contract

| Parameter | Type | Description |
|---|---|---|
| sequences | torch.Tensor | Polyphonic music data, shape [num_sequences, max_length, data_dim] |
| lengths | torch.Tensor | Sequence lengths, shape [num_sequences] |
| -m, --model | str | Model variant: "0" through "7" |
| -d, --hidden-dim | int | Hidden state dimension (default: 16) |
| --funsor | flag | Use the Funsor backend (required for model_7) |
| --jit | flag | Enable JIT compilation |
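To make the flag semantics concrete, here is a hypothetical argparse mirror of only the four options in the table, including the rule that model 7 requires `--funsor`. The real script defines more options and may enforce the rule differently. The default model value "1" is an assumption.

```python
import argparse


def make_parser():
    # Hypothetical mirror of the documented flags only; not the script's parser.
    parser = argparse.ArgumentParser(description="HMM variants (sketch)")
    parser.add_argument("-m", "--model", default="1",       # assumed default
                        choices=[str(i) for i in range(8)])
    parser.add_argument("-d", "--hidden-dim", default=16, type=int)
    parser.add_argument("--funsor", action="store_true")
    parser.add_argument("--jit", action="store_true")
    return parser


def parse(argv):
    parser = make_parser()
    args = parser.parse_args(argv)
    # model_7 relies on pyro.vectorized_markov, which only the Funsor backend provides.
    if args.model == "7" and not args.funsor:
        parser.error("model 7 requires --funsor")
    return args


print(parse(["-m", "3", "--jit"]))
```

`parser.error` exits with status 2, so a bad flag combination fails before any model is built.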

Output:

  • Training loss per step
  • Final training and test loss (per observation)
  • Model capacity (number of parameters)
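The two summary numbers can be sketched as follows. The sketch assumes that "per observation" means dividing the total loss by the number of real time steps, `lengths.sum()`, so padding is not counted. It also assumes that capacity counts the scalar entries of the point-estimated tensors, which for model_1 under AutoDelta are `probs_x` and `probs_y`. The helper names and `data_dim` value are hypothetical.

```python
import numpy as np


def loss_per_observation(total_loss, lengths):
    # Assumption: one observation = one time step of one sequence.
    return total_loss / float(np.sum(lengths))


def capacity(params):
    # Number of scalar entries across all learned tensors.
    return sum(int(np.prod(p.shape)) for p in params.values())


hidden_dim, data_dim = 16, 10          # data_dim is a hypothetical value
params = {
    "probs_x": np.zeros((hidden_dim, hidden_dim)),   # model_1 transition matrix
    "probs_y": np.zeros((hidden_dim, data_dim)),     # model_1 per-tone emission probs
}
print(capacity(params))                                    # 16*16 + 16*10 = 416
print(loss_per_observation(1234.5, np.array([3, 5, 2])))   # 1234.5 / 10 time steps
```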

Usage Examples

# Standard HMM with Pyro backend
# python hmm.py -m 1 -n 50 -b 8 -d 16 -lr 0.05

# Factorial HMM
# python hmm.py -m 3 -n 50 --jit

# Vectorized HMM with Funsor backend (parallel scan)
# python hmm.py -m 7 --funsor -n 50

# Second-order HMM with Raftery parameterization
# python hmm.py -m 6 -rp -n 50
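The `-rp` option swaps model_6's full second-order transition tensor for a Raftery parameterization. Assuming it follows Raftery's mixture transition form, p(x_t | x_{t-1}, x_{t-2}) is a convex mixture of a first-order table in x_{t-1} and one in x_{t-2}. This reduces the transition parameters from H³ to 2H² + 1. The NumPy sketch below builds such a tensor by broadcasting. The names `probs_x1`, `probs_x2`, and `mix_lambda` are illustrative.

```python
import numpy as np

rng = np.random.default_rng(0)
H = 4                                           # hypothetical hidden_dim
probs_x1 = rng.dirichlet(np.ones(H), size=H)    # p(x_t | x_{t-1}), shape (H, H)
probs_x2 = rng.dirichlet(np.ones(H), size=H)    # p(x_t | x_{t-2}), shape (H, H)
mix_lambda = 0.3                                # mixing weight in [0, 1]

# Index order: probs_x[x_{t-2}, x_{t-1}, x_t].
# probs_x1 broadcasts over the x_{t-2} axis; probs_x2 gets a new axis so it
# broadcasts over x_{t-1}. A convex mix of two distributions is a distribution.
probs_x = mix_lambda * probs_x1[None, :, :] + (1 - mix_lambda) * probs_x2[:, None, :]

full_params = H ** 3              # unrestricted second-order tensor
raftery_params = 2 * H * H + 1    # two H x H tables plus the mixing weight
print(probs_x.shape, full_params, raftery_params)
```

With the default hidden dimension of 16 the same counts would be 4096 versus 513.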
