

Principle:Pyro ppl Pyro Plated Einsum

From Leeroopedia


Knowledge Sources
Domains: Tensor Algebra, Graphical Models, Exact Inference
Last Updated: 2026-02-09 09:00 GMT

Overview

Plated einsum extends the Einstein summation convention to factor graphs with plate (conditional independence) structure, enabling efficient tensor contraction that exploits repeated structure to reduce computational cost.

Description

Einstein summation (einsum) is a compact notation for tensor contractions:

C_ij = sum_k A_ik * B_kj is written as einsum("ik,kj->ij", A, B)

In probabilistic models, plates (conditional independence structures) introduce repeated dimensions that are semantically different from dimensions being summed over. Standard einsum does not distinguish between these two kinds of dimensions: it sums over every index that does not appear in the output.

Plated einsum extends einsum with awareness of plates. It distinguishes:

  • Eliminated dimensions: Dimensions that are summed over (marginalized out).
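  • Plate dimensions: Dimensions indexing independent repetitions. A plate dimension is either kept in the output or reduced by a product rather than a sum, because independent factors multiply (in log space, this product becomes a sum of log-factors).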
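The difference between the two kinds of dimensions can be seen on a single toy factor. This NumPy sketch uses made-up values; `f[n, x]` has a plate index `n` with three elements, and each element has its own local variable `x_n` with two states. It is an illustration of the semantics, not Pyro's implementation.

```python
import itertools
import numpy as np

# Toy factor f[n, x] (illustrative values): plate index n with N = 3
# elements, each with its own local variable x_n taking 2 states.
f = np.array([[1.0, 2.0],
              [3.0, 4.0],
              [5.0, 6.0]])

# Standard einsum sums every index missing from the output,
# so the plate index n is treated exactly like the variable x.
standard = np.einsum("nx->", f)

# Plated semantics: sum out the eliminated dimension x within each
# plate element, then reduce the plate dimension n by a product.
per_plate = np.einsum("nx->n", f)
plated = np.prod(per_plate)

# Brute-force check: the plated result equals the sum, over all 2**3 joint
# assignments (x_1, x_2, x_3), of the product of the per-element factors.
brute = sum(
    np.prod([f[n, xs[n]] for n in range(3)])
    for xs in itertools.product(range(2), repeat=3)
)
```

Here `standard` is 21.0 while `plated` is 3 * 7 * 11 = 231.0, which matches the brute-force sum over joint assignments.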

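The key insight is that plate structure lets a contraction be split into many small, independent pieces. For example, suppose a factor f(x, y) sits inside a plate of size N, so that each plate element n has its own local variable x_n while y is shared across the plate:

  • Naive computation: For each value of y, sum over all |x|^N joint assignments of (x_1, ..., x_N), at cost O(N · |x|^N · |y|).
  • Plated computation: Sum out each x_n independently (reducing an N x |x| x |y| tensor to N x |y|), take the product along the plate dimension, then sum over y, at cost O(N · |x| · |y|).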

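Because each local variable is summed out within its own plate element, a cost that is exponential in the plate size becomes linear in it. The same idea applies to many structured models, including HMMs batched over independent sequences, topic models, and other models with conditional independence structure.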
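The naive-versus-plated comparison can be checked numerically on a small case. This sketch assumes a random factor `f[n, x, y]` (hypothetical sizes N = 4, |x| = 3, |y| = 2), where each plate element n has its own local `x_n` and `y` is shared across the plate. It enumerates all joint assignments of the local variables and compares that against the plated computation.

```python
import itertools
import numpy as np

rng = np.random.default_rng(0)
N, X, Y = 4, 3, 2              # hypothetical sizes: plate N, |x|, |y|
f = rng.random((N, X, Y))      # f[n, x, y]: one slice of the factor per plate element

# Naive: for each y, enumerate every joint assignment (x_1, ..., x_N).
naive = 0.0
for y in range(Y):
    for xs in itertools.product(range(X), repeat=N):   # |x|**N terms
        naive += np.prod([f[n, xs[n], y] for n in range(N)])

# Plated: sum out each local x_n, product over the plate, then sum over y.
summed_x = f.sum(axis=1)       # shape (N, Y): sum over x for every (n, y)
per_y = summed_x.prod(axis=0)  # shape (Y,): product-reduce the plate dimension n
plated = per_y.sum()           # eliminate the shared variable y last

# Wrong order: summing y inside the plate treats the shared y as if each
# plate element had its own copy, which overcounts.
wrong = np.prod(f.sum(axis=(1, 2)))
```

The design choice that matters is ordering. The local variable is summed before the plate product, and the shared variable is summed after it. Reversing the second step, as in `wrong`, gives a strictly larger (incorrect) value here because it adds cross terms in which different plate elements use different values of `y`.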

Usage

Use plated einsum when:

  • Performing exact marginalization in factor graphs with plate structure.
  • Implementing variable elimination for discrete latent variables in plated models.
  • Optimizing the contraction order of tensor networks with repeated structure.
  • Building efficient inference backends for probabilistic programming languages.

Theoretical Basis

Standard einsum:

# einsum("ij,jk->ik", A, B) = matrix multiply
# einsum("ii->", A) = trace
# einsum("ij->i", A) = row sums

# General: einsum(subscripts, *operands)
# subscripts: comma-separated index labels for each operand, '->' output indices
# Indices in operands but not in output are summed over
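The three standard einsum patterns above can be run directly with NumPy's `np.einsum`. This is a minimal sketch with small illustrative arrays:

```python
import numpy as np

A = np.arange(6.0).reshape(2, 3)   # shape (2, 3)
B = np.arange(12.0).reshape(3, 4)  # shape (3, 4)
S = np.arange(9.0).reshape(3, 3)   # square matrix for the trace

matmul = np.einsum("ij,jk->ik", A, B)  # j is absent from the output, so it is summed
trace = np.einsum("ii->", S)           # repeated index selects the diagonal, which is then summed
row_sums = np.einsum("ij->i", A)       # j is summed, i is kept
```

Each result agrees with the dedicated NumPy operation: `A @ B`, `np.trace(S)` (0 + 4 + 8 = 12.0), and `A.sum(axis=1)` ([3.0, 12.0]).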

Plated einsum extension:

# Notation: plated_einsum(equation, *operands, plates=set_of_plate_dims)
# Plate dims absent from the output are product-reduced; other absent dims are summed

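# Example: HMM with batch plate 'b' over independent sequences
# (time 't' is not a plate: consecutive hidden states are dependent,
#  so time is eliminated sequentially, Markov-chain style)
# Factors:
#   transition(prev, curr): shape (K, K)
#   emission(state, obs):   shape (K, V)
#   data(batch, time):      shape (B, T)  -- observed

# Naive enumeration of hidden-state paths: O(B * T * K^T)
# Plated einsum with plates={'b'} plus sequential elimination over t:
#   - Gather emission likelihoods for the observed data: O(B * T * K)
#   - Forward recursion, one (K x K) contraction per time step,
#     vectorized across the plate 'b': O(B * T * K^2)
#   Total: O(B * T * K^2) instead of O(B * T * K^T)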
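A minimal NumPy sketch of the HMM example follows. It assumes the batch dimension `b` is the only plate and that time is eliminated sequentially by the forward algorithm, since consecutive hidden states are dependent. The initial-state distribution `init` is an added assumption (the factor list above does not name one), and all sizes and random parameters are hypothetical. The sketch compares the batched forward recursion against brute-force enumeration of all K^T state paths.

```python
import itertools
import numpy as np

rng = np.random.default_rng(1)
B, T, K, V = 3, 5, 2, 4            # hypothetical sizes: batch, time, states, symbols

def normalize(a):
    # Make the last axis sum to 1 so each row is a probability distribution.
    return a / a.sum(axis=-1, keepdims=True)

init = normalize(rng.random(K))            # assumed initial-state distribution
trans = normalize(rng.random((K, K)))      # trans[prev, curr]
emit = normalize(rng.random((K, V)))       # emit[state, obs]
obs = rng.integers(0, V, size=(B, T))      # observed data, shape (B, T)

# lik[b, t, k] = emit[k, obs[b, t]]: gather step, O(B * T * K).
lik = emit[:, obs].transpose(1, 2, 0)      # (K, B, T) -> (B, T, K)

# Forward algorithm, vectorized over the plate b: O(B * T * K^2).
alpha = init * lik[:, 0, :]                # shape (B, K)
for t in range(1, T):
    alpha = np.einsum("bi,ij->bj", alpha, trans) * lik[:, t, :]
forward = alpha.sum(axis=1)                # p(obs_b) for each sequence b

# Naive enumeration of all K**T hidden-state paths per sequence.
brute = np.zeros(B)
for b in range(B):
    for z in itertools.product(range(K), repeat=T):
        p = init[z[0]] * lik[b, 0, z[0]]
        for t in range(1, T):
            p *= trans[z[t - 1], z[t]] * lik[b, t, z[t]]
        brute[b] += p
```

The forward recursion touches each (time step, state pair) once per sequence, while the brute-force loop grows as K^T. For long sequences a practical version would also work in log space to avoid underflow.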

Optimal contraction order:

# Given factors f_1, f_2, ..., f_n with shared indices:
# The cost of computing the result depends on contraction order

# Example: three factors
# f1(a, b), f2(b, c), f3(c, d), with every dimension of size K
# To compute sum_{b,c} f1(a,b) * f2(b,c) * f3(c,d):

# Order 1: contract f1*f2 first, then with f3
# Cost: O(|a|*|b|*|c|) + O(|a|*|c|*|d|) = O(K^3)

# Order 2: contract f2*f3 first, then with f1
# Cost: O(|b|*|c|*|d|) + O(|a|*|b|*|d|) = O(K^3)

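# With plates, the plate dimension should enter the contraction as late as possible.
# If 'a' is a plate of size N kept in the output, and |b| = |c| = |d| = K:
# Order 1: O(N*K^2) + O(N*K^2) -- the plate dimension is carried through both steps
# Order 2: O(K^3) + O(N*K^2)   -- the K^3 step runs once, independent of N
# Order 2 is cheaper whenever N > K, since K^3 + N*K^2 < 2*N*K^2 exactly when K < N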

# Finding optimal order is NP-hard in general
# Heuristics: greedy contraction, dynamic programming for small graphs
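The cost arithmetic for the three-factor example can be checked directly. This sketch assumes `a` is a plate of size N = 1000 kept in the output and that the other dimensions have size K = 10 (both hypothetical). It counts multiply-adds for each pairwise order and confirms that both orders produce the same tensor.

```python
import numpy as np

def order_costs(a, b, c, d):
    """Multiply-add counts for the two pairwise orders of
    sum_{b,c} f1(a,b) * f2(b,c) * f3(c,d), given the dimension sizes."""
    order1 = a * b * c + a * c * d   # (f1*f2) -> g(a,c), then g*f3 -> (a,d)
    order2 = b * c * d + a * b * d   # (f2*f3) -> h(b,d), then f1*h -> (a,d)
    return order1, order2

N, K = 1000, 10                          # hypothetical: plate 'a' of size N, others K
cost1, cost2 = order_costs(N, K, K, K)   # 2*N*K^2 vs K^3 + N*K^2

# Both orders compute the same tensor; only the cost differs.
rng = np.random.default_rng(2)
f1, f2, f3 = rng.random((N, K)), rng.random((K, K)), rng.random((K, K))
r1 = np.einsum("ac,cd->ad", np.einsum("ab,bc->ac", f1, f2), f3)
r2 = np.einsum("ab,bd->ad", f1, np.einsum("bc,cd->bd", f2, f3))
```

With these sizes, order 1 costs 200,000 multiply-adds and order 2 costs 101,000. For automatic ordering, `np.einsum_path` with `optimize="greedy"` reports the contraction path that NumPy's own greedy heuristic chooses.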

Connection to variable elimination:

# Variable elimination in a factor graph is equivalent to
# a sequence of einsum operations:

# Eliminate variable x from factors involving x:
# new_factor = sum_x product of factors containing x
# = einsum over x: f1 * f2 * ... (all factors containing x)

# Plated einsum handles the case where some factors are
# "inside" plates (replicated) and some are "outside" (shared)

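# The plate structure determines which contractions can be
# factored into smaller, independent computations: variables
# local to a plate are summed out per plate element before the
# plate is product-reduced, while shared variables outside the
# plate are kept until after that product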
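Variable elimination as a sequence of einsum calls can be sketched in a few lines. The helpers `eliminate` and `variable_elimination` are hypothetical names for illustration. Each factor is a pair of a variable string (one letter per axis, reused as the einsum index) and a NumPy array. Plates are omitted here, so every eliminated index is summed.

```python
import numpy as np

def eliminate(factors, var):
    """Sum out `var`: multiply every factor that mentions it in one einsum."""
    touching = [(vs, t) for vs, t in factors if var in vs]
    rest = [(vs, t) for vs, t in factors if var not in vs]
    if not touching:
        return factors
    # New factor's scope: union of the touching scopes, minus the eliminated variable.
    out_vars = "".join(sorted(set("".join(vs for vs, _ in touching)) - {var}))
    equation = ",".join(vs for vs, _ in touching) + "->" + out_vars
    new_factor = np.einsum(equation, *(t for _, t in touching))
    return rest + [(out_vars, new_factor)]

def variable_elimination(factors, order):
    """Eliminate variables in the given order; with every variable listed,
    the remaining factors are scalars whose product is the total sum."""
    for var in order:
        factors = eliminate(factors, var)
    result = 1.0
    for _, t in factors:
        result = result * t
    return float(result)

# Chain f1(a, b), f2(b, c), f3(c, d) with illustrative random values.
rng = np.random.default_rng(3)
factors = [("ab", rng.random((2, 3))), ("bc", rng.random((3, 4))), ("cd", rng.random((4, 2)))]
z_ve = variable_elimination(factors, order="bcad")
z_direct = np.einsum("ab,bc,cd->", *(t for _, t in factors))
```

Eliminating `b` first issues `einsum("ab,bc->ac", ...)`, which is the same product-then-sum step described above. A plated version would additionally product-reduce a plate dimension once all variables local to that plate have been summed out.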
