
Implementation:Pyro ppl Pyro OneTwoMatching

From Leeroopedia
Revision as of 16:24, 16 February 2026 by Admin (auto-imported from implementations/Pyro_ppl_Pyro_OneTwoMatching.md)


Knowledge Sources
Domains Probability_Distributions
Last Updated 2026-02-09 09:00 GMT

Overview

Description

OneTwoMatching is a discrete probability distribution over matchings from 2*N sources to N destinations, where each source matches exactly one destination and each destination matches exactly two sources. It extends TorchDistribution and is parameterized by a matrix of edge logits of shape (2*N, N).

The log probability of a matching v is defined as the sum of edge logits along the matching minus the log partition function:

log p(v) = sum_s logits[s, v[s]] - log Z

Samples are represented as long tensors of shape (2*N,) with values in {0, ..., N-1}.
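To make the formula and the sample representation concrete, here is a small dependency-free sketch in plain Python (no torch or Pyro). It enumerates every one-two matching for a tiny N as the distinct permutations of the multiset {0, 0, 1, 1, ..., N-1, N-1}, computes log Z by log-sum-exp over them, and checks that the resulting probabilities sum to 1. The helper names (one_two_matchings, score, log_prob) and the logits values are illustrative assumptions, not Pyro's API.

```python
import itertools
import math


def one_two_matchings(n):
    """Every one-two matching for n destinations, as a sorted list of tuples.

    Distinct permutations of the multiset {0, 0, 1, 1, ..., n-1, n-1}: each
    destination appears exactly twice and each source picks one destination.
    Brute force over (2n)! permutations, so only usable for very small n.
    """
    multiset = [d for d in range(n) for _ in range(2)]
    return sorted(set(itertools.permutations(multiset)))


def logsumexp(xs):
    """Numerically stable log(sum(exp(x) for x in xs))."""
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))


def score(logits, v):
    """Unnormalized log weight: sum_s logits[s][v[s]]."""
    return sum(logits[s][d] for s, d in enumerate(v))


def log_prob(logits, v):
    """log p(v) = sum_s logits[s][v[s]] - log Z, with log Z by enumeration."""
    n = len(logits[0])
    log_z = logsumexp([score(logits, m) for m in one_two_matchings(n)])
    return score(logits, v) - log_z


# N = 2: logits has shape (2*N, N) = (4, 2).
logits = [[0.5, -1.0], [0.2, 0.3], [-0.7, 1.1], [0.0, 0.4]]
support = one_two_matchings(2)
print(len(support))      # 6
print(support[0])        # (0, 0, 1, 1)
total = sum(math.exp(log_prob(logits, v)) for v in support)
print(round(total, 6))   # 1.0
```

For N = 2 there are 6 matchings, and in general (2N)!/2^N, which is why exact enumeration only scales to small N.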

The module provides two computational modes:

  • Exact computation (when bp_iters=None): Uses brute-force enumeration of all valid matchings, which the enumerate_support method generates recursively. Because there are (2N)!/2^N one-two matchings, this is only tractable for small N.
  • Approximate computation (when bp_iters is a positive integer): Uses Sinkhorn iteration to compute approximate mean-field beliefs, then evaluates the Bethe free energy to approximate the log partition function. This is based on belief propagation methods from the references by Chertkov, Huang, Vontobel, and others.
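The approximate mode can be sketched in plain Python as follows. This is an illustrative reconstruction under stated assumptions, not Pyro's code: the hypothetical bethe_log_z runs Sinkhorn iteration in log space so that beliefs b[s][d] have rows summing to 1 and columns summing to 2, then approximates log Z as the negative of a Bethe-style free energy F(b) = sum b (log b - logits) - sum (1 - b) log(1 - b), the form known from Bethe-permanent approximations. Treat that free-energy expression, adapted here to one-two matchings, as an assumption.

```python
import math


def logsumexp(xs):
    """Numerically stable log(sum(exp(x) for x in xs))."""
    m = max(xs)
    return m + math.log(sum(math.exp(x - m) for x in xs))


def bethe_log_z(logits, iters):
    """Hypothetical Sinkhorn + Bethe sketch; returns (approx_log_z, beliefs).

    logits is a (2N, N) list of lists and iters a positive int.
    """
    rows, cols = len(logits), len(logits[0])
    tiny = 1e-300  # floor inside log() so that 0 * log(0) evaluates to 0
    # Subtract each row's max for stability. Every matching uses each row once,
    # so log Z shifts by exactly sum(shift), which is added back at the end.
    shift = [max(row) for row in logits]
    x = [[logits[i][j] - shift[i] for j in range(cols)] for i in range(rows)]
    # Log-space potentials: s[i] normalizes rows to 1, d[j] normalizes columns to 2.
    d = [logsumexp([x[i][j] for i in range(rows)]) - math.log(2) for j in range(cols)]
    s = [0.0] * rows
    for _ in range(iters):
        s = [logsumexp([x[i][j] - d[j] for j in range(cols)]) for i in range(rows)]
        d = [logsumexp([x[i][j] - s[i] for i in range(rows)]) - math.log(2)
             for j in range(cols)]
    b = [[math.exp(x[i][j] - d[j] - s[i]) for j in range(cols)] for i in range(rows)]
    # Assumed Bethe-style free energy: F = sum b*(log b - x) - sum (1-b)*log(1-b).
    free_energy = 0.0
    for i in range(rows):
        for j in range(cols):
            bij = b[i][j]
            cij = max(1.0 - bij, 0.0)
            free_energy += bij * (math.log(max(bij, tiny)) - x[i][j])
            free_energy -= cij * math.log(max(cij, tiny))
    return sum(shift) - free_energy, b


logits = [[0.5, -1.0, 0.2], [0.3, -0.7, 1.1], [0.0, 0.4, -0.3],
          [1.2, 0.1, -0.5], [-0.2, 0.6, 0.9], [0.8, -0.4, 0.0]]
approx_log_z, beliefs = bethe_log_z(logits, iters=200)
print(approx_log_z)
print([round(sum(row), 6) for row in beliefs])        # rows sum to ~1
print([round(sum(col), 6) for col in zip(*beliefs)])  # columns sum to ~2
```

For N = 1 there is only one matching and the sketch recovers log Z exactly; for larger N it is an approximation whose accuracy this sketch does not attempt to quantify.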

The module also includes:

  • OneTwoMatchingConstraint - a custom constraint class that checks that samples have the correct one-two matching structure.
  • enumerate_one_two_matchings - a recursive function that generates all valid matchings for a given number of destinations.
  • maximum_weight_matching - computes the maximum probability matching using SciPy's linear_sum_assignment by duplicating destination columns and solving a linear assignment problem.
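The three helpers listed above can be illustrated with a compact plain-Python sketch. The names is_one_two_matching, enumerate_matchings and max_weight_matching are hypothetical stand-ins, not Pyro's functions, and a brute-force search over permutations of a square cost matrix stands in for SciPy's linear_sum_assignment so that the block runs without dependencies. The key idea it shows is the column-duplication reduction: copying each destination column once turns the (2N, N) one-two problem into a square (2N, 2N) one-to-one assignment, and column k maps back to destination k mod N.

```python
import itertools


def is_one_two_matching(v, n):
    """True if v has length 2n and every destination 0..n-1 appears exactly twice."""
    return len(v) == 2 * n and all(v.count(d) == 2 for d in range(n))


def enumerate_matchings(n):
    """Recursively yield every one-two matching for n destinations as a tuple."""
    def extend(prefix, capacity):
        if len(prefix) == 2 * n:
            yield tuple(prefix)
            return
        for d in range(n):
            if capacity[d] > 0:  # destination d still has a free slot
                capacity[d] -= 1
                prefix.append(d)
                yield from extend(prefix, capacity)
                prefix.pop()
                capacity[d] += 1

    yield from extend([], [2] * n)


def max_weight_matching(logits):
    """Maximum-weight one-two matching via duplicated destination columns.

    Brute force over all (2N)! permutations, so only for tiny N.
    """
    n = len(logits[0])
    wide = [row + row for row in logits]  # columns d and d + n are copies
    best = max(
        itertools.permutations(range(2 * n)),
        key=lambda perm: sum(wide[s][c] for s, c in enumerate(perm)),
    )
    return tuple(c % n for c in best)


logits = [[0.5, -1.0], [0.2, 0.3], [-0.7, 1.1], [0.0, 0.4]]
print(len(list(enumerate_matchings(2))))  # 6
print(max_weight_matching(logits))         # (0, 0, 1, 1)
```

Replacing the permutation search with scipy.optimize.linear_sum_assignment on the negated duplicated matrix gives the polynomial-time version described in the list.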

Usage

OneTwoMatching is useful in combinatorial optimization and probabilistic inference problems involving assignment or matching constraints. Applications include multi-object tracking (where each track must be assigned to exactly two detections), bipartite matching problems, and structured prediction tasks. The belief propagation approximation makes log_partition_function and log_prob tractable for moderate-sized problems; sampling, however, is only available in exact mode.

Code Reference

Source Location

  • File: pyro/distributions/one_two_matching.py
  • Repository: pyro-ppl/pyro

Signature

class OneTwoMatching(TorchDistribution):
    def __init__(self, logits, *, bp_iters=None, validate_args=None)

Import

from pyro.distributions import OneTwoMatching

I/O Contract

Inputs

Parameter | Type | Description
logits | torch.Tensor | A 2-dimensional tensor of shape (2*N, N) containing real-valued edge logits. Each entry logits[s, d] is the log-weight of matching source s to destination d.
bp_iters | int or None | Number of belief propagation (Sinkhorn) iterations for approximate inference. If None (the default), exact brute-force computation is used. Must be a positive integer if specified.
validate_args | bool or None | Whether to validate input arguments. Defaults to None.
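The input contract can be summarized as a small check. The function below, check_one_two_args, is hypothetical: it encodes only the constraints stated in this table (2-D logits of shape (2*N, N); bp_iters either None or a positive integer) and does not reproduce Pyro's own validation code or the exception types it raises.

```python
def check_one_two_args(shape, bp_iters=None):
    """Hypothetical validator for the documented input contract.

    shape: the logits shape as a tuple; must be 2-D and equal to (2*N, N).
    bp_iters: None (exact mode) or a positive int (approximate mode).
    Returns N on success and raises ValueError otherwise.
    """
    if len(shape) != 2:
        raise ValueError(f"logits must be 2-dimensional, got shape {shape}")
    num_sources, num_destins = shape
    if num_sources != 2 * num_destins:
        raise ValueError(f"expected shape (2*N, N), got {shape}")
    # bool is a subclass of int in Python, so reject it explicitly.
    if bp_iters is not None and (
        isinstance(bp_iters, bool) or not isinstance(bp_iters, int) or bp_iters <= 0
    ):
        raise ValueError(f"bp_iters must be None or a positive int, got {bp_iters!r}")
    return num_destins


print(check_one_two_args((20, 10), bp_iters=20))  # 10
```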

Outputs

Method | Return Type | Description
sample(sample_shape) | torch.Tensor | Returns matchings of shape sample_shape + (2*N,). In exact mode (bp_iters=None) it draws from a categorical distribution over the enumerated support; in approximate mode sampling is not implemented and raises NotImplementedError.
log_prob(value) | torch.Tensor | Returns the log probability of a matching value. In approximate mode it subtracts the Bethe estimate of log Z, so the result is approximate.
log_partition_function | torch.Tensor | Lazily computed log partition function (exact, or the Bethe approximation when bp_iters is set). Accessed as a property, not called.
enumerate_support(expand) | torch.Tensor | Returns all (2N)!/2^N valid one-two matchings as a tensor of shape (num_matchings, 2*N).
mode() | torch.Tensor | Returns a maximum probability matching of shape (2*N,) using SciPy's linear_sum_assignment.
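Exact-mode sampling, as described for sample(), amounts to a categorical draw over the enumerated support with weights proportional to exp(sum_s logits[s, v[s]]). Below is a plain-Python sketch under that reading; the helper sample_exact is hypothetical, and Python's random.choices stands in for torch's Categorical.

```python
import itertools
import math
import random


def sample_exact(logits, num_samples, rng):
    """Draw num_samples one-two matchings by brute force (tiny N only)."""
    n = len(logits[0])
    # Support: distinct permutations of {0, 0, 1, 1, ..., n-1, n-1}.
    support = sorted(set(itertools.permutations([d for d in range(n) for _ in range(2)])))
    # Unnormalized log weights sum_s logits[s][v[s]], shifted by the max for stability.
    scores = [sum(logits[s][d] for s, d in enumerate(v)) for v in support]
    top = max(scores)
    weights = [math.exp(x - top) for x in scores]
    return rng.choices(support, weights=weights, k=num_samples)


rng = random.Random(0)
draws = sample_exact([[0.5, -1.0], [0.2, 0.3], [-0.7, 1.1], [0.0, 0.4]], 5, rng)
print(draws)  # five tuples, each a valid one-two matching
```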

Usage Examples

import torch
from pyro.distributions import OneTwoMatching

# Create a distribution over matchings from 4 sources to 2 destinations
N = 2
logits = torch.randn(2 * N, N)
dist = OneTwoMatching(logits)

# Enumerate all valid matchings
all_matchings = dist.enumerate_support()
print(all_matchings.shape)  # torch.Size([6, 4]) for N=2, since (2N)!/2^N = 6

# Sample a matching (exact mode)
sample = dist.sample()
print(sample.shape)  # torch.Size([4])
print(sample)         # e.g., tensor([0, 1, 1, 0])

# Compute log probability
log_p = dist.log_prob(sample)
print(log_p.shape)  # torch.Size([])

import torch
from pyro.distributions import OneTwoMatching

# Approximate mode using belief propagation (for larger problems)
N = 10
logits = torch.randn(2 * N, N)
dist = OneTwoMatching(logits, bp_iters=20)
# dist.sample() is not implemented in this mode and raises NotImplementedError

# Compute the approximate log partition function
log_Z = dist.log_partition_function
print(log_Z)

# Compute a maximum probability matching (requires SciPy)
best_matching = dist.mode()
print(best_matching.shape)  # torch.Size([20])

# Log probability of the mode (approximate, since log Z is the Bethe estimate here)
log_p = dist.log_prob(best_matching)
print(log_p)
