Principle:Pyro ppl Pyro Funsor Backend HMM

From Leeroopedia
Revision as of 17:51, 16 February 2026 by Admin (talk | contribs) (Auto-imported from principles/Pyro_ppl_Pyro_Funsor_Backend_HMM.md)

Knowledge Sources
Domains: Hidden Markov Models, Tensor Variable Elimination, Exact Inference
Last Updated: 2026-02-09 09:00 GMT

Overview

The Funsor backend enables exact inference in Hidden Markov Models by performing tensor variable elimination, automatically marginalizing out discrete latent states using the sum-product algorithm over plated factor graphs.

Description

Hidden Markov Models (HMMs) are a class of structured probabilistic models with:

  • A sequence of discrete latent states z_1, z_2, ..., z_T.
  • Observed emissions x_1, x_2, ..., x_T.
  • Markov property: given z_{t-1}, the state z_t is independent of all earlier states, and each emission x_t depends only on z_t.

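Standard inference in HMMs uses the forward-backward algorithm. The forward pass alone exactly marginalizes out the discrete states in O(T * K^2) time, where K is the number of states and T is the sequence length, rather than the O(K^T) cost of enumerating every state sequence. The backward pass adds the posterior marginals over states at the same asymptotic cost.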
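The O(T * K^2) cost follows from distributing the sums over the factorized joint. As a sketch in this page's notation (pi_0 for the initial distribution, A for the transition matrix, f for the emission density):

```latex
\begin{aligned}
p(x_{1:T})
  &= \sum_{z_1,\dots,z_T} \pi_0(z_1)\, f(x_1 \mid \theta_{z_1})
     \prod_{t=2}^{T} A(z_{t-1}, z_t)\, f(x_t \mid \theta_{z_t}) \\
  &= \sum_{z_T} f(x_T \mid \theta_{z_T})
     \sum_{z_{T-1}} A(z_{T-1}, z_T)\, f(x_{T-1} \mid \theta_{z_{T-1}})
     \cdots
     \sum_{z_1} A(z_1, z_2)\, \pi_0(z_1)\, f(x_1 \mid \theta_{z_1})
\end{aligned}
```

The first line sums over K^T state sequences. In the second line each inner sum ranges over K values for each of the K values of the next state, giving K^2 work per time step.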

The Funsor backend generalizes this to a broader class of structured models by implementing tensor variable elimination (TVE). Funsor (functional tensor) represents probability distributions as lazy tensor expressions that can be automatically simplified and marginalized.

Key advantages of the Funsor backend for HMMs:

  • Automatic algorithm derivation: The user writes the model as a standard Pyro program with discrete sample sites inside plate contexts, with the Markov dependencies declared in the model (in Pyro, typically via `pyro.markov`). Variable elimination then exploits that structure, which for a plain HMM reduces to the forward algorithm, so the user does not need to implement it manually.
  • Handling plates (batches): Multiple independent HMM sequences (e.g., different time series) can be processed in parallel using Pyro's plate construct. Funsor handles the tensor bookkeeping automatically.
  • Composability: HMM components can be composed with other model components (e.g., neural network emission models, hierarchical priors over transition parameters). Funsor handles the variable elimination for the discrete parts while allowing gradient-based inference for continuous parts.
  • Exact discrete marginalization: Because the discrete states are summed out exactly rather than sampled, the resulting objective (the marginal log-likelihood, or an ELBO over the remaining continuous variables) carries no Monte Carlo variance from the discrete variables, which typically gives much less noisy gradient estimates than sampling-based estimators.
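The variance point in the last bullet can be checked on a toy model. The sketch below is plain Python and does not use Pyro or Funsor. The model sizes and the emission table `emis` are made-up illustrative values. It compares exact enumeration of p(x) with a simple Monte Carlo estimator that samples state paths from the prior.

```python
import itertools
import math
import random
import statistics

# Toy HMM with K=2 states and T=3 steps (all numbers are illustrative assumptions).
pi0 = [0.6, 0.4]
A = [[0.9, 0.1], [0.2, 0.8]]
emis = [[0.5, 0.1], [0.4, 0.3], [0.1, 0.7]]  # emis[t][k] = f(x_t | theta_k)
K, T = 2, 3

def path_prior(z):
    """p(z_1, ..., z_T) under the Markov chain."""
    p = pi0[z[0]]
    for t in range(1, T):
        p *= A[z[t - 1]][z[t]]
    return p

def path_likelihood(z):
    """p(x_1, ..., x_T | z_1, ..., z_T)."""
    return math.prod(emis[t][z[t]] for t in range(T))

# Exact marginal: sum over all K**T paths (deterministic, no sampling noise).
exact = sum(path_prior(z) * path_likelihood(z)
            for z in itertools.product(range(K), repeat=T))

def mc_estimate(n, rng):
    """Unbiased but noisy estimate: average p(x | z) over n prior samples of z."""
    total = 0.0
    for _ in range(n):
        z = [rng.choices(range(K), weights=pi0)[0]]
        for t in range(1, T):
            z.append(rng.choices(range(K), weights=A[z[-1]])[0])
        total += path_likelihood(z)
    return total / n

# Repeat the estimator with different seeds to expose its spread.
estimates = [mc_estimate(20, random.Random(seed)) for seed in range(50)]
spread = statistics.pstdev(estimates)
```

Here `exact` is the same on every run, while `spread` is strictly positive. That sampling noise is what exact marginalization removes from the objective.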

Usage

Use the Funsor backend for HMMs when:

  • Your model contains discrete latent sequences with Markov structure.
  • You want exact marginalization of discrete states without hand-coding forward-backward.
  • You are building models that combine HMMs with neural network components.
  • You are working with multiple independent sequences (batched HMMs).
  • You want to avoid the high variance of Monte Carlo estimation for discrete variables.

Theoretical Basis

HMM specification:

# Initial distribution: z_1 ~ Categorical(pi_0)
# Transition model:     z_t | z_{t-1} ~ Categorical(A[z_{t-1}])   for t = 2..T
# Emission model:       x_t | z_t ~ f(x | theta[z_t])             for t = 1..T

# where:
# pi_0: initial state probabilities (K-vector)
# A: transition matrix (K x K)
# theta[k]: emission parameters for state k
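Read as a generative process, the specification can be sketched in plain Python. The function name `sample_hmm`, the choice of unit-variance Gaussian emissions, and all parameter values are assumptions made for this example only.

```python
import random

def sample_hmm(pi0, A, emit_means, T, seed=0):
    """Draw (states, observations) from a K-state HMM.

    Assumption for illustration: f(x | theta[k]) is Normal(emit_means[k], 1).
    """
    rng = random.Random(seed)
    K = len(pi0)
    states, obs = [], []
    # Initial state ~ Categorical(pi_0)
    z = rng.choices(range(K), weights=pi0)[0]
    for t in range(T):
        if t > 0:
            # Next state ~ Categorical(A[previous state])
            z = rng.choices(range(K), weights=A[z])[0]
        states.append(z)
        # Emission depends only on the current state
        obs.append(rng.gauss(emit_means[z], 1.0))
    return states, obs

pi0 = [0.6, 0.4]
A = [[0.9, 0.1], [0.2, 0.8]]
emit_means = [-2.0, 2.0]
states, obs = sample_hmm(pi0, A, emit_means, T=10)
```

In inference the direction is reversed: `obs` is given and `states` must be summed out.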

Forward algorithm (message passing):

# Forward messages alpha_t(k) = p(x_1, ..., x_t, z_t = k)

# Initialization:
# alpha_1(k) = pi_0(k) * f(x_1 | theta_k)

# Recursion:
# alpha_t(k) = f(x_t | theta_k) * sum_{j=1}^{K} alpha_{t-1}(j) * A(j, k)

# In matrix form:
# alpha_t = (A^T @ alpha_{t-1}) * emission_t

# Log marginal likelihood:
# log p(x_{1:T}) = log sum_k alpha_T(k)

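# Cost: O(T * K^2) time, O(K) memory for the forward pass alone
# (keeping every alpha_t for a backward pass or for gradients takes O(T * K) memory)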
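The recursion above can be sketched in a few lines of plain Python and checked against brute-force enumeration. The table `emis[t][k]` stands for f(x_t | theta_k). Its values, and both function names, are assumptions made for this example.

```python
import itertools

def forward_likelihood(pi0, A, emis):
    """Return p(x_{1:T}) via the forward recursion, in O(T * K^2) time."""
    K = len(pi0)
    # Initialization: alpha_1(k) = pi_0(k) * f(x_1 | theta_k)
    alpha = [pi0[k] * emis[0][k] for k in range(K)]
    for t in range(1, len(emis)):
        # Recursion: alpha_t(k) = f(x_t | theta_k) * sum_j alpha_{t-1}(j) * A(j, k)
        alpha = [emis[t][k] * sum(alpha[j] * A[j][k] for j in range(K))
                 for k in range(K)]
    return sum(alpha)

def brute_force_likelihood(pi0, A, emis):
    """Return p(x_{1:T}) by summing over all K**T state sequences."""
    K, T = len(pi0), len(emis)
    total = 0.0
    for z in itertools.product(range(K), repeat=T):
        p = pi0[z[0]] * emis[0][z[0]]
        for t in range(1, T):
            p *= A[z[t - 1]][z[t]] * emis[t][z[t]]
        total += p
    return total

pi0 = [0.6, 0.4]
A = [[0.9, 0.1], [0.2, 0.8]]
emis = [[0.5, 0.1], [0.4, 0.3], [0.1, 0.7], [0.3, 0.3]]  # illustrative values
fast = forward_likelihood(pi0, A, emis)
slow = brute_force_likelihood(pi0, A, emis)
```

The two results agree up to floating-point rounding. For long sequences the raw products underflow, so practical implementations work in log space or rescale alpha at every step.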

Tensor variable elimination (TVE):

# Factor graph for HMM:
# Factors: f_init(z_1), f_trans(z_{t-1}, z_t) for t=2..T, f_emit(z_t, x_t) for t=1..T

# Variable elimination order: z_1, z_2, ..., z_T
# At each step: sum out z_t, combining adjacent factors

# TVE generalizes this to plated factor graphs:
# With batch plate of size B:
#   B independent sequences, each of length T
#   Factors are shared across the batch (same A, theta)

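# Funsor represents each factor as a lazy tensor with named dimensions
# (schematic notation, not the literal Funsor constructor):
# f_trans: Tensor(A, dims=("z_prev", "z_curr"))
# f_emit:  Tensor(emissions, dims=("z_curr", "batch", "time"))

# Elimination multiplies the factors that mention a variable, then sums it out.
# With msg_prev the message over ("z_prev", "batch") and f_emit taken at the current time step:
# msg_curr = (msg_prev * f_trans).sum("z_prev") * f_emit
# Named dims are aligned automatically; the contraction order follows the plate and Markov structure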
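To make the named-dimension bookkeeping concrete, here is a toy stand-in written in plain Python. The `Factor` class, its methods, and all numbers are hypothetical and are not part of Funsor. The sketch only mimics the idea of multiplying factors by dimension name and summing out one variable at a time, with a shared batch plate of size B.

```python
import itertools

class Factor:
    """Dense table indexed by named dims (a toy stand-in, not the Funsor API)."""

    def __init__(self, dims, sizes, fn):
        self.dims = tuple(dims)
        self.sizes = {d: sizes[d] for d in self.dims}
        ranges = [range(self.sizes[d]) for d in self.dims]
        self.table = {idx: fn(dict(zip(self.dims, idx)))
                      for idx in itertools.product(*ranges)}

    def at(self, assignment):
        return self.table[tuple(assignment[d] for d in self.dims)]

    def __mul__(self, other):
        # Union of dims; shared names are aligned automatically.
        dims = self.dims + tuple(d for d in other.dims if d not in self.dims)
        sizes = {**self.sizes, **other.sizes}
        return Factor(dims, sizes, lambda a: self.at(a) * other.at(a))

    def sum_out(self, dim):
        rest = tuple(d for d in self.dims if d != dim)
        return Factor(rest, self.sizes, lambda a: sum(
            self.at({**a, dim: v}) for v in range(self.sizes[dim])))

K, T, B = 2, 3, 2
pi0 = [0.6, 0.4]
A = [[0.9, 0.1], [0.2, 0.8]]  # shared across the batch
# emis[b][t][k] = f(x_{b,t} | theta_k), illustrative values
emis = [[[0.5, 0.1], [0.4, 0.3], [0.1, 0.7]],
        [[0.2, 0.6], [0.3, 0.3], [0.8, 0.1]]]
sizes = {"batch": B, **{f"z{t}": K for t in range(1, T + 1)}}

def f_emit(t):
    return Factor((f"z{t}", "batch"), sizes,
                  lambda a: emis[a["batch"]][t - 1][a[f"z{t}"]])

def f_trans(t):
    return Factor((f"z{t - 1}", f"z{t}"), sizes,
                  lambda a: A[a[f"z{t - 1}"]][a[f"z{t}"]])

# Eliminate z1, z2, ..., zT in order; the "batch" dim rides along untouched.
msg = Factor(("z1",), sizes, lambda a: pi0[a["z1"]]) * f_emit(1)
for t in range(2, T + 1):
    msg = (msg * f_trans(t)).sum_out(f"z{t - 1}") * f_emit(t)
per_sequence = msg.sum_out(f"z{T}")  # dims: ("batch",)
likelihoods = [per_sequence.at({"batch": b}) for b in range(B)]
```

Each entry of `likelihoods` is p(x_{1:T}) for one sequence in the plate. The elimination loop is the forward recursion applied to all sequences at once, which is the bookkeeping a real tensor backend vectorizes.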

Integration with variational inference:

# For HMMs with continuous parameters theta:
# 1. Funsor exactly marginalizes discrete states z_{1:T}
# 2. Returns log p(x_{1:T} | theta) (exact marginal likelihood)
# 3. SVI optimizes theta (or a variational distribution over theta)

# The gradient grad_theta log p(x | theta) is exact w.r.t. discrete states
# Only continuous parameters need variational approximation
# No discrete states are sampled, so they contribute no Monte Carlo variance to the gradient
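The three steps can be imitated without Pyro. The sketch below assumes unit-variance Gaussian emissions with state means `mu` (the continuous parameters theta), computes the exact log marginal with a rescaled forward pass, and runs plain gradient ascent on a point estimate of `mu`. Central finite differences stand in for automatic differentiation, and the data, step size, and function names are made up for illustration.

```python
import math

def normal_pdf(x, mu):
    """Density of Normal(mu, 1) at x (assumed emission model)."""
    return math.exp(-0.5 * (x - mu) ** 2) / math.sqrt(2 * math.pi)

def log_marginal(mu, xs, pi0, A):
    """Exact log p(x_{1:T} | mu) via a rescaled forward pass."""
    K = len(pi0)
    alpha = [pi0[k] * normal_pdf(xs[0], mu[k]) for k in range(K)]
    log_p = 0.0
    for t in range(len(xs)):
        if t > 0:
            alpha = [normal_pdf(xs[t], mu[k]) *
                     sum(alpha[j] * A[j][k] for j in range(K))
                     for k in range(K)]
        norm = sum(alpha)            # rescale to avoid underflow
        log_p += math.log(norm)
        alpha = [a / norm for a in alpha]
    return log_p

def finite_diff_grad(mu, xs, pi0, A, eps=1e-5):
    """Central-difference gradient; a stand-in for autodiff in this sketch."""
    grad = []
    for k in range(len(mu)):
        up, down = list(mu), list(mu)
        up[k] += eps
        down[k] -= eps
        grad.append((log_marginal(up, xs, pi0, A) -
                     log_marginal(down, xs, pi0, A)) / (2 * eps))
    return grad

xs = [-2.1, -1.9, -2.3, 1.8, 2.2, 2.0, -2.0, 1.9]  # illustrative data
pi0 = [0.5, 0.5]
A = [[0.8, 0.2], [0.2, 0.8]]
mu = [-0.5, 0.5]                                   # initial guess for theta
initial_ll = log_marginal(mu, xs, pi0, A)
for _ in range(50):
    g = finite_diff_grad(mu, xs, pi0, A)
    mu = [m + 0.05 * gk for m, gk in zip(mu, g)]   # gradient ascent step
final_ll = log_marginal(mu, xs, pi0, A)
```

Every gradient here is deterministic given `mu`, because the discrete states never get sampled. A full SVI setup would replace the point estimate with a variational distribution over theta and use a stochastic optimizer.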
