Implementation:Pyro ppl Pyro Funsor HMM
| Property | Value |
|---|---|
| Implementation Type | Pattern Doc |
| Source File | examples/contrib/funsor/hmm.py |
| Module | examples.contrib.funsor |
| Pyro Features | pyro.contrib.funsor backend, pyroapi, TraceEnum_ELBO, TraceMarkovEnum_ELBO, pyro.markov, pyro.vectorized_markov, AutoDelta, parallel enumeration, JIT compilation |
| Dataset | JSB Chorales (polyphonic music) |
Overview
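This file demonstrates the Funsor backend for Pyro through a collection of Hidden Markov Model (HMM) variants for polyphonic music modeling. Funsor serves as an intermediate representation for probabilistic programs, which lets Pyro sum out discrete latent variables exactly with efficient variable elimination. Because the models are written against pyroapi, the same code runs on either the default Pyro backend or the pyro.contrib.funsor backend (selected with --funsor).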
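Variable elimination is what makes summing out the hidden states tractable. As a framework-free illustration (plain Python, not Funsor's or Pyro's API), the sketch below computes a toy HMM's marginal likelihood two ways: by brute-force enumeration of every hidden path, and by eliminating one hidden state at a time (the forward algorithm). The parameters are made up.

```python
import itertools
import math

# Toy HMM with made-up numbers: 2 hidden states, binary observations.
init = [0.6, 0.4]                  # p(x_0)
trans = [[0.9, 0.1], [0.2, 0.8]]   # trans[i][j] = p(x_t = j | x_{t-1} = i)
emit = [[0.7, 0.3], [0.1, 0.9]]    # emit[i][y] = p(y_t = y | x_t = i)
ys = [0, 1, 1, 0]                  # observed sequence


def brute_force_likelihood(ys):
    """Sum p(x, y) over all H**T hidden paths."""
    H = len(init)
    total = 0.0
    for path in itertools.product(range(H), repeat=len(ys)):
        p = init[path[0]] * emit[path[0]][ys[0]]
        for t in range(1, len(ys)):
            p *= trans[path[t - 1]][path[t]] * emit[path[t]][ys[t]]
        total += p
    return total


def forward_likelihood(ys):
    """Eliminate x_0, x_1, ... one at a time: O(T * H**2) work."""
    H = len(init)
    alpha = [init[i] * emit[i][ys[0]] for i in range(H)]  # p(x_0, y_0)
    for y in ys[1:]:
        # Sum out the previous state, then multiply in the new emission.
        alpha = [sum(alpha[i] * trans[i][j] for i in range(H)) * emit[j][y]
                 for j in range(H)]
    return sum(alpha)  # sum out the final state


print(math.isclose(brute_force_likelihood(ys), forward_likelihood(ys)))  # True
```

Both give the same number, but the brute-force sum grows exponentially in the sequence length while elimination grows linearly.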
Eight model variants are implemented with increasing complexity:
- model_0: Simple HMM with sequential iteration over sequences
- model_1: Vectorized HMM with batched sequences and JIT support
- model_2: Autoregressive HMM (arHMM) with y[t] depending on y[t-1]
- model_3: Factorial HMM (FHMM) with two independent hidden chains
- model_4: Paired Factorial HMM (PFHMM) with dependent hidden chains
- model_5: Neural HMM (nnHMM) with a CNN-based emission model
- model_6: Second-order HMM (2HMM) with pyro.markov(history=2)
- model_7: Vectorized time-dimension HMM using pyro.vectorized_markov with parallel scan (Funsor-only)
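model_7's pyro.vectorized_markov lets Funsor eliminate the whole time dimension at once instead of stepping through it. The plain-Python sketch below illustrates only the underlying idea; it is not Funsor's implementation, and real implementations work in log space. The HMM likelihood is a chain of matrix products, and because matrix multiplication is associative, the chain can be reduced pairwise in about log2(T) parallel rounds instead of T sequential steps. The numbers are made up.

```python
import math

# Toy HMM; all numbers are made up for illustration.
init = [0.6, 0.4]
trans = [[0.9, 0.1], [0.2, 0.8]]
emit = [[0.7, 0.3], [0.1, 0.9]]
ys = [0, 1, 1, 0, 1, 0, 0]


def matmul(a, b):
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def step_matrix(y):
    # M[i][j] = p(x_t = j | x_{t-1} = i) * p(y_t | x_t = j)
    H = len(trans)
    return [[trans[i][j] * emit[j][y] for j in range(H)] for i in range(H)]


def sequential_product(mats):
    # One dependent multiplication per time step, like a Python loop over time.
    out = mats[0]
    for m in mats[1:]:
        out = matmul(out, m)
    return out


def tree_product(mats):
    # Multiply neighbours pairwise; products within a round are independent,
    # so each round could run in parallel, giving about log2(T) rounds.
    while len(mats) > 1:
        paired = [matmul(mats[i], mats[i + 1]) for i in range(0, len(mats) - 1, 2)]
        if len(mats) % 2 == 1:
            paired.append(mats[-1])  # carry the unpaired last matrix forward
        mats = paired
    return mats[0]


def likelihood(product):
    H = len(init)
    alpha0 = [init[i] * emit[i][ys[0]] for i in range(H)]
    return sum(alpha0[i] * product[i][j] for i in range(H) for j in range(H))


mats = [step_matrix(y) for y in ys[1:]]
print(math.isclose(likelihood(sequential_product(mats)), likelihood(tree_product(mats))))  # True
```

The pairing keeps the matrices in time order, which matters because matrix multiplication is associative but not commutative.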
All models perform MAP Baum-Welch estimation: the discrete hidden states are marginalized out by enumeration, and the parameters are point-estimated with an AutoDelta guide trained against TraceEnum_ELBO (TraceMarkovEnum_ELBO for model_7).
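Because an AutoDelta guide places a point mass on each parameter and its Delta sites contribute zero log-density, the loss being minimized amounts to the negative log joint of the parameters and the data, with the hidden states summed out. The plain-Python sketch below (not Pyro code) computes that objective for model_1's priors on a single unmasked, unsubsampled batch. The parameter values and data are hypothetical, and the function names are illustrative only.

```python
import math


def logsumexp(xs):
    m = max(xs)
    if m == -math.inf:
        return m
    return m + math.log(sum(math.exp(x - m) for x in xs))


def log_dirichlet(probs, conc):
    return (math.lgamma(sum(conc)) - sum(math.lgamma(a) for a in conc)
            + sum((a - 1) * math.log(p) for a, p in zip(conc, probs)))


def log_beta(p, a, b):
    return (math.lgamma(a + b) - math.lgamma(a) - math.lgamma(b)
            + (a - 1) * math.log(p) + (b - 1) * math.log1p(-p))


def log_emission(row, notes):
    # Independent Bernoulli per tone, mirroring the "tones" plate.
    return sum(math.log(p) if y else math.log1p(-p) for p, y in zip(row, notes))


def log_marginal(probs_x, probs_y, sequence):
    # Log-space forward algorithm; like model_1, start from x = 0 so that
    # the first hidden state is drawn from row 0 of probs_x.
    H = len(probs_x)
    log_alpha = [0.0] + [-math.inf] * (H - 1)
    for notes in sequence:
        log_alpha = [logsumexp([log_alpha[i] + math.log(probs_x[i][j]) for i in range(H)])
                     + log_emission(probs_y[j], notes) for j in range(H)]
    return logsumexp(log_alpha)


def map_loss(probs_x, probs_y, sequences):
    """Negative log joint of parameters and data, hidden states summed out."""
    H = len(probs_x)
    # Rows of probs_x ~ Dirichlet(0.9 * eye + 0.1); entries of probs_y ~ Beta(0.1, 0.9).
    log_prior = sum(log_dirichlet(row, [0.9 * (i == j) + 0.1 for j in range(H)])
                    for i, row in enumerate(probs_x))
    log_prior += sum(log_beta(p, 0.1, 0.9) for row in probs_y for p in row)
    log_lik = sum(log_marginal(probs_x, probs_y, seq) for seq in sequences)
    return -(log_prior + log_lik)


# Hypothetical values: 2 hidden states, 3 tones, one sequence of 3 time steps.
probs_x = [[0.8, 0.2], [0.3, 0.7]]
probs_y = [[0.9, 0.1, 0.2], [0.2, 0.6, 0.5]]
sequences = [[[1, 0, 0], [1, 0, 1], [0, 1, 1]]]
print(map_loss(probs_x, probs_y, sequences))
```

In the actual example, gradient-based SVI adjusts the parameters to lower this kind of objective; the sketch only evaluates it.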
Code Reference
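```python
import torch
from pyroapi import distributions as dist, handlers, pyro


# model_1: Batched HMM with JIT support
def model_1(sequences, lengths, args, batch_size=None, include_prior=True):
    # sequences has shape [num_sequences, max_length, data_dim].
    num_sequences, max_length, data_dim = map(int, sequences.shape)
    with handlers.mask(mask=include_prior):
        # Row i of probs_x is p(x[t] | x[t-1] = i); the prior favors self-transitions.
        probs_x = pyro.sample("probs_x",
                              dist.Dirichlet(0.9 * torch.eye(args.hidden_dim) + 0.1).to_event(1))
        # probs_y[h, d] is the probability that tone d sounds in hidden state h.
        probs_y = pyro.sample("probs_y",
                              dist.Beta(0.1, 0.9).expand([args.hidden_dim, data_dim]).to_event(2))
    tones_plate = pyro.plate("tones", data_dim, dim=-1)
    # Subsample batch_size sequences; dim=-1 belongs to tones, so batch over dim=-2.
    with pyro.plate("sequences", num_sequences, batch_size, dim=-2) as batch:
        lengths = lengths[batch]
        x = 0
        # Under JIT, always loop to max_length so every minibatch traces the same program.
        for t in pyro.markov(range(max_length if args.jit else lengths.max())):
            # Time steps past a sequence's end are masked out of the log density.
            with handlers.mask(mask=(t < lengths).unsqueeze(-1)):
                x = pyro.sample("x_{}".format(t), dist.Categorical(probs_x[x]),
                                infer={"enumerate": "parallel"})
                with tones_plate:
                    pyro.sample("y_{}".format(t),
                                dist.Bernoulli(probs_y[x.squeeze(-1)]),
                                obs=sequences[batch, t])
```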
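The priors in model_1 encode simple assumptions about the music, which become visible from their means. The plain-Python sketch below uses the standard facts that a Dirichlet(alpha) has mean alpha_i / sum(alpha) and a Beta(a, b) has mean a / (a + b). It assumes the default hidden_dim of 16.

```python
hidden_dim = 16  # the documented --hidden-dim default


def dirichlet_mean(conc):
    total = sum(conc)
    return [a / total for a in conc]


# Row 0 of the concentration 0.9 * eye(hidden_dim) + 0.1: 1.0 on the diagonal, 0.1 elsewhere.
row_0 = [0.9 * (j == 0) + 0.1 for j in range(hidden_dim)]
mean_row_0 = dirichlet_mean(row_0)
self_transition = mean_row_0[0]   # 1.0 / 2.5
other_transition = mean_row_0[1]  # 0.1 / 2.5 for each of the other 15 states

# Beta(a, b) has mean a / (a + b).
a, b = 0.1, 0.9
tone_prior_mean = a / (a + b)

print(round(self_transition, 4), round(other_transition, 4), round(tone_prior_mean, 4))  # 0.4 0.04 0.1
```

So the transition prior makes each hidden state ten times more likely a priori to persist than to jump to any particular other state, and the emission prior expects most tones to be off. Since both Beta parameters are below 1, that prior is also U-shaped, pushing note probabilities toward 0 or 1.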
I/O Contract
| Parameter | Type | Description |
|---|---|---|
| sequences | torch.Tensor | Polyphonic music data, shape [num_sequences, max_length, data_dim] |
| lengths | torch.Tensor | Sequence lengths, shape [num_sequences] |
| --model | str | Model variant: "0" through "7" |
| --hidden-dim | int | Hidden state dimension (default: 16) |
| --funsor | flag | Use the Funsor backend (required for model_7) |
| --jit | flag | Enable JIT compilation |
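The command-line flags in the table, plus -n, -b, -lr and -rp from the usage examples, can be mirrored with argparse. This is a hypothetical reconstruction, not the script's actual parser. The long names --num-steps, --batch-size, --learning-rate and --raftery-parameterization, and every default except --hidden-dim=16, are assumptions. The check that model_7 needs --funsor follows the table.

```python
import argparse


def make_parser():
    # Hypothetical sketch: long names for -n, -b, -lr, -rp and most defaults are assumptions.
    parser = argparse.ArgumentParser(description="HMM variants on JSB Chorales (sketch)")
    parser.add_argument("-m", "--model", default="1", choices=[str(i) for i in range(8)])
    parser.add_argument("-n", "--num-steps", type=int, default=50)
    parser.add_argument("-b", "--batch-size", type=int, default=8)
    parser.add_argument("-d", "--hidden-dim", type=int, default=16)
    parser.add_argument("-lr", "--learning-rate", type=float, default=0.05)
    parser.add_argument("-rp", "--raftery-parameterization", action="store_true")
    parser.add_argument("--funsor", action="store_true")
    parser.add_argument("--jit", action="store_true")
    return parser


def parse(argv):
    args = make_parser().parse_args(argv)
    # model_7 relies on pyro.vectorized_markov, which only the Funsor backend supports.
    if args.model == "7" and not args.funsor:
        raise SystemExit("model_7 requires --funsor")
    return args


args = parse(["-m", "7", "--funsor", "-n", "50"])
print(args.model, args.funsor, args.hidden_dim)
```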
Output:
- Training loss per step
- Final training and test loss (per observation)
- Model capacity (number of parameters)
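For model_1, the capacity follows directly from the shapes in the code reference: probs_x has hidden_dim x hidden_dim entries and probs_y has hidden_dim x data_dim. The sketch below does that arithmetic. The data_dim of 88 is a hypothetical value, and normalizing the loss by the total number of observed time steps (the sum of lengths) is an assumption about how "per observation" is defined.

```python
def model_1_capacity(hidden_dim, data_dim):
    # probs_x is [hidden_dim, hidden_dim]; probs_y is [hidden_dim, data_dim].
    return hidden_dim * hidden_dim + hidden_dim * data_dim


def loss_per_observation(total_loss, lengths):
    # Assumption: normalize by the total number of observed time steps.
    return total_loss / float(sum(lengths))


print(model_1_capacity(16, 88))  # 1664
print(loss_per_observation(120.0, [10, 20, 30]))  # 2.0
```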
Usage Examples
```bash
# Standard HMM with Pyro backend
python hmm.py -m 1 -n 50 -b 8 -d 16 -lr 0.05
# Factorial HMM
python hmm.py -m 3 -n 50 --jit
# Vectorized HMM with Funsor backend (parallel scan)
python hmm.py -m 7 --funsor -n 50
# Second-order HMM with Raftery parameterization
python hmm.py -m 6 -rp -n 50
```
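The -rp option enables a Raftery parameterization for model_6. Raftery's mixture transition distribution (MTD) is a standard way to shrink a higher-order Markov chain: instead of a full table over (x[t-2], x[t-1]) pairs, it mixes a single first-order matrix applied at each lag. The plain-Python sketch below shows that general form. Whether hmm.py uses exactly this weighting (which lag gets lam, how lam is parameterized) is an assumption, and the matrix values are made up.

```python
def raftery_transition(q, lam, x_prev2, x_prev1):
    """Mixture transition distribution for a second-order chain:
    p(x_t = j | x_{t-1}, x_{t-2}) = lam * q[x_{t-1}][j] + (1 - lam) * q[x_{t-2}][j]
    """
    return [lam * q[x_prev1][j] + (1 - lam) * q[x_prev2][j] for j in range(len(q))]


def transition_entries(hidden_dim):
    # Full second-order table: one distribution per (x_{t-2}, x_{t-1}) pair.
    full = hidden_dim ** 3
    # MTD: one shared first-order matrix plus a single mixing weight.
    mtd = hidden_dim ** 2 + 1
    return full, mtd


q = [[0.7, 0.2, 0.1], [0.1, 0.8, 0.1], [0.3, 0.3, 0.4]]  # made-up first-order matrix
probs = raftery_transition(q, lam=0.75, x_prev2=2, x_prev1=0)
print([round(p, 3) for p in probs])
print(transition_entries(16))  # (4096, 257)
```

The trade-off is expressiveness for size: with the default hidden_dim of 16, the full table has 4096 entries, while the mixture needs 257.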
Related Pages
- Pyro_ppl_Pyro_CJS_Models - Another example using parallel enumeration of discrete states
- Pyro_ppl_Pyro_MixedHMM_Model - Mixed-effects HMM for animal movement