Implementation:Pyro ppl Pyro OneTwoMatching
| Knowledge Sources | |
|---|---|
| Domains | Probability_Distributions |
| Last Updated | 2026-02-09 09:00 GMT |
Overview
Description
OneTwoMatching is a discrete probability distribution over matchings from 2*N sources to N destinations, where each source matches exactly one destination and each destination matches exactly two sources. It extends TorchDistribution and is parameterized by a matrix of edge logits of shape (2*N, N).
The log probability of a matching v is defined as the sum of edge logits along the matching minus the log partition function:
log p(v) = sum_s logits[s, v[s]] - log Z
Samples are represented as long tensors of shape (2*N,) with values in {0, ..., N-1}.
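To make the definition concrete, here is a minimal pure-Python sketch (not Pyro's implementation) that enumerates every valid matching for a tiny N, computes log Z with a log-sum-exp, and evaluates log p(v). The helper names `one_two_matchings` and `log_prob` and the example logits are hypothetical, chosen only for illustration.

```python
import itertools
import math


def one_two_matchings(n):
    """Yield every assignment of 2*n sources to n destinations in which
    each destination receives exactly two sources (brute force)."""
    for v in itertools.product(range(n), repeat=2 * n):
        if all(v.count(d) == 2 for d in range(n)):
            yield v


def score(logits, v):
    """Sum of edge logits along the matching v."""
    return sum(logits[s][v[s]] for s in range(len(v)))


def log_prob(logits, v):
    """log p(v) = sum_s logits[s][v[s]] - log Z, with Z summed over all matchings."""
    n = len(logits[0])
    scores = [score(logits, m) for m in one_two_matchings(n)]
    top = max(scores)  # subtract the max for a stable log-sum-exp
    log_z = top + math.log(sum(math.exp(x - top) for x in scores))
    return score(logits, v) - log_z


# Hypothetical logits for N = 2 (4 sources, 2 destinations).
logits = [[0.5, -1.0], [0.0, 0.3], [1.2, 0.1], [-0.4, 0.8]]
support = list(one_two_matchings(2))
total = sum(math.exp(log_prob(logits, v)) for v in support)
```

For N = 2 the support contains (2N)!/2^N = 6 matchings, and the probabilities computed this way sum to 1 up to floating-point error.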
The module provides two computational modes:
- Exact computation (when bp_iters=None): Uses brute-force enumeration of all valid matchings. The enumerate_support method recursively generates all valid one-two matchings. Because the number of matchings is (2N)!/2^N, this is only tractable for small N.
- Approximate computation (when bp_iters is a positive integer): Uses Sinkhorn iteration to compute approximate mean-field beliefs, then evaluates the Bethe free energy to approximate the log partition function. This is based on belief propagation methods from the references by Chertkov, Huang, Vontobel, and others.
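As a rough illustration of the Sinkhorn idea behind the approximate mode, the sketch below alternately rescales a positive weight matrix so that every source row sums to 1 and every destination column sums to 2. This is a generic Sinkhorn scaling written under that assumption; it does not reproduce Pyro's belief updates or its Bethe free energy evaluation, and the function name `sinkhorn_marginals` and the example logits are hypothetical.

```python
import math


def sinkhorn_marginals(logits, iters=200):
    """Rescale exp(logits) so rows sum to 1 and columns sum to (about) 2."""
    num_sources, num_destins = len(logits), len(logits[0])
    b = [[math.exp(x) for x in row] for row in logits]
    for _ in range(iters):
        # Column step: each destination should receive total mass 2.
        for d in range(num_destins):
            col = sum(b[s][d] for s in range(num_sources))
            for s in range(num_sources):
                b[s][d] *= 2.0 / col
        # Row step: each source should send total mass 1.
        for s in range(num_sources):
            row = sum(b[s])
            b[s] = [x / row for x in b[s]]
    return b


# Hypothetical logits for N = 2 (4 sources, 2 destinations).
logits = [[0.5, -1.0], [0.0, 0.3], [1.2, 0.1], [-0.4, 0.8]]
beliefs = sinkhorn_marginals(logits)
row_sums = [sum(row) for row in beliefs]
col_sums = [sum(row[d] for row in beliefs) for d in range(2)]
```

The resulting entries can be read as soft edge marginals: each lies in (0, 1) and the row and column totals match the one-two matching constraints.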
The module also includes:
- OneTwoMatchingConstraint - a custom constraint class that validates samples have the correct one-two matching structure.
- enumerate_one_two_matchings - a recursive function that generates all valid matchings for a given number of destinations.
- maximum_weight_matching - computes the maximum probability matching using SciPy's linear_sum_assignment by duplicating destination columns and solving a linear assignment problem.
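The column-duplication reduction and the constraint check can be sketched in pure Python as follows. This is an illustrative sketch, not Pyro's code: a brute-force search over permutations stands in for SciPy's linear_sum_assignment (so it only works for tiny inputs), the column layout is one possible choice, and the names `max_weight_one_two_matching` and `is_one_two_matching` are hypothetical.

```python
import itertools


def is_one_two_matching(v, num_destins):
    """True if v assigns 2*num_destins sources so that every destination
    is used exactly twice."""
    return len(v) == 2 * num_destins and all(
        v.count(d) == 2 for d in range(num_destins)
    )


def max_weight_one_two_matching(logits):
    num_sources, num_destins = len(logits), len(logits[0])
    assert num_sources == 2 * num_destins
    # Duplicate the destination columns: columns d and d + num_destins
    # are two slots of the same destination d, giving a square matrix.
    square = [row + row for row in logits]
    # Brute-force stand-in for a linear assignment solver: choose the
    # permutation (one distinct slot per source) with the largest total weight.
    best = max(
        itertools.permutations(range(num_sources)),
        key=lambda perm: sum(square[s][perm[s]] for s in range(num_sources)),
    )
    # Map each chosen slot back to its destination.
    return [slot % num_destins for slot in best]


# Hypothetical logits: sources 0 and 2 prefer destination 0,
# sources 1 and 3 prefer destination 1.
logits = [[2.0, 0.0], [0.0, 2.0], [1.5, 0.0], [0.0, 1.5]]
matching = max_weight_one_two_matching(logits)
```

Because each destination owns exactly two slots and every slot is used once, any solution of the square assignment problem maps back to a valid one-two matching.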
Usage
OneTwoMatching is useful in combinatorial optimization and probabilistic inference problems involving assignment or matching constraints. Applications include multi-object tracking (where each track must be assigned to exactly two detections), bipartite matching problems, and structured prediction tasks. The belief propagation approximation makes it tractable for moderate-sized problems.
Code Reference
Source Location
- File: pyro/distributions/one_two_matching.py
- Repository: pyro-ppl/pyro
Signature
class OneTwoMatching(TorchDistribution):
def __init__(self, logits, *, bp_iters=None, validate_args=None)
Import
from pyro.distributions import OneTwoMatching
I/O Contract
Inputs
| Parameter | Type | Description |
|---|---|---|
| logits | torch.Tensor | A 2-dimensional tensor of shape (2*N, N) containing real-valued edge logits. Each entry logits[s, d] specifies the log-weight of matching source s to destination d. |
| bp_iters | int or None | Number of belief propagation (Sinkhorn) iterations for approximate inference. If None, exact brute-force computation is used. Must be a positive integer if specified. |
| validate_args | bool or None | Whether to validate input arguments. Defaults to None. |
Outputs
| Method | Return Type | Description |
|---|---|---|
| sample(sample_shape) | torch.Tensor | Returns a matching tensor of shape sample_shape + (2*N,). Uses brute-force categorical sampling when bp_iters=None; raises NotImplementedError for approximate mode with non-empty sample_shape. |
| log_prob(value) | torch.Tensor | Returns the log probability of a matching value. |
| log_partition_function | torch.Tensor | Lazily computed log partition function (exact or Bethe approximation). |
| enumerate_support(expand) | torch.Tensor | Returns all valid one-two matchings as a tensor of shape (num_matchings, 2*N). |
| mode() | torch.Tensor | Returns the maximum probability matching of shape (2*N,) using SciPy's linear sum assignment. |
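The brute-force categorical sampling described for sample can be pictured with the pure-Python sketch below: enumerate the support, weight each matching by the exponential of its total logit, and draw one matching in proportion to those weights. This is a sketch of the idea, not Pyro's code; the function name `sample_matching`, the seed, and the example logits are hypothetical.

```python
import itertools
import math
import random


def sample_matching(logits, rng):
    """Draw one matching with probability proportional to exp(sum of its logits)."""
    n = len(logits[0])
    # Enumerate the full support (only feasible for small n).
    support = [
        v
        for v in itertools.product(range(n), repeat=2 * n)
        if all(v.count(d) == 2 for d in range(n))
    ]
    scores = [sum(logits[s][v[s]] for s in range(2 * n)) for v in support]
    top = max(scores)  # shift by the max so exp() cannot overflow
    weights = [math.exp(x - top) for x in scores]
    # random.Random.choices normalizes the weights internally.
    return rng.choices(support, weights=weights, k=1)[0]


rng = random.Random(0)  # fixed seed, for reproducibility only
logits = [[0.5, -1.0], [0.0, 0.3], [1.2, 0.1], [-0.4, 0.8]]
draw = sample_matching(logits, rng)
```

The draw is a tuple of length 2*N in which every destination index appears exactly twice, matching the sample layout described above.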
Usage Examples
import torch
from pyro.distributions import OneTwoMatching
# Create a distribution over matchings from 4 sources to 2 destinations
N = 2
logits = torch.randn(2 * N, N)
dist = OneTwoMatching(logits)
# Enumerate all valid matchings
all_matchings = dist.enumerate_support()
print(all_matchings.shape) # torch.Size([6, 4]) for N=2, since (2N)!/2^N = 6
# Sample a matching (exact mode)
sample = dist.sample()
print(sample.shape) # torch.Size([4])
print(sample) # e.g., tensor([0, 1, 1, 0])
# Compute log probability
log_p = dist.log_prob(sample)
print(log_p.shape) # torch.Size([])
import torch
from pyro.distributions import OneTwoMatching
# Approximate mode using belief propagation (for larger problems)
N = 10
logits = torch.randn(2 * N, N)
dist = OneTwoMatching(logits, bp_iters=20)
# Compute the approximate log partition function
log_Z = dist.log_partition_function
print(log_Z)
# Compute the maximum probability matching
best_matching = dist.mode()
print(best_matching.shape) # torch.Size([20])
# Compute log probability of the mode
log_p = dist.log_prob(best_matching)
print(log_p)
Related Pages
- Pyro_ppl_Pyro_OrderedLogistic - Another discrete distribution in Pyro with structured parameterization
- Pyro_ppl_Pyro_Rejector - Distribution using rejection sampling, another custom sampling strategy