

Implementation:Pyro ppl Pyro TorchDistributionMixin

From Leeroopedia


Knowledge Sources
Domains Probability_Distributions, Deep_Learning
Last Updated 2026-02-09 09:00 GMT

Overview

Core mixin and base classes that bridge PyTorch distributions with Pyro's probabilistic programming interface, plus supporting distribution wrappers for masking and batch expansion.

Description

This module defines the fundamental building blocks for all distributions in Pyro:

TorchDistributionMixin is the primary mixin class that provides Pyro compatibility for PyTorch distributions. It inherits from both Pyro's Distribution and Python's Callable. Key features include:

  • Callable interface -- Calling a distribution instance invokes rsample for reparameterized distributions or sample for non-reparameterized ones, enabling the dist(sample_shape) syntax.
  • Shape management -- Provides shape(), batch_shape, event_shape, and event_dim properties following PyTorch's shape convention: sample_shape + batch_shape + event_shape.
  • Shape inference -- The class method infer_shapes(**arg_shapes) statically infers batch_shape and event_shape from argument shapes without constructing actual tensors.
  • expand/expand_by -- expand(batch_shape) returns an ExpandedDistribution; expand_by(sample_shape) prepends dimensions to the batch shape.
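  • to_event(n) -- Reinterprets the n rightmost batch dimensions as event dimensions by wrapping the distribution in Independent. If n is None, all batch dimensions are reinterpreted; a negative n removes event dimensions from an existing Independent wrapper.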
  • mask(mask) -- Returns a MaskedDistribution that zeros out log-prob contributions where mask is False.

TorchDistribution is the recommended base class for new Pyro distributions. It inherits from both torch.distributions.Distribution and TorchDistributionMixin, providing a default expand method.

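MaskedDistribution wraps a base distribution with a boolean mask. Where the mask is False, log_prob, score_parts, and KL divergence return zeros. When the mask is the Python constant False, these computations are skipped entirely; with a tensor mask, the base values are computed and the masked-out entries are zeroed. This is essential for handling missing data and conditional observations in Pyro models.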

ExpandedDistribution wraps a base distribution with an expanded batch shape, handling both expanded dimensions (new leading dimensions) and interstitial dimensions (broadcasting size-1 dimensions to larger sizes). It correctly transposes sample dimensions during sampling to match the expanded shape.

The module also registers a KL divergence handler for pairs of MaskedDistribution objects: the KL term is kept only where both masks are True (the logical AND of the two masks) and is zero elsewhere.

Usage

Use TorchDistribution as the base class when implementing new distributions for Pyro. Use TorchDistributionMixin only when wrapping existing PyTorch distributions. Use the mask and to_event methods in Pyro models to control observation masking and shape semantics.

Code Reference

Source Location

Signature

class TorchDistributionMixin(Distribution, Callable):
    def __call__(self, sample_shape=torch.Size()) -> torch.Tensor
    def shape(self, sample_shape=torch.Size()) -> torch.Size
    @classmethod
    def infer_shapes(cls, **arg_shapes) -> tuple
    def expand(self, batch_shape, _instance=None) -> "ExpandedDistribution"
    def expand_by(self, sample_shape) -> "ExpandedDistribution"
    def to_event(self, reinterpreted_batch_ndims=None)
    def mask(self, mask) -> "MaskedDistribution"

class TorchDistribution(torch.distributions.Distribution, TorchDistributionMixin):
    pass

class MaskedDistribution(TorchDistribution):
    def __init__(self, base_dist, mask)

class ExpandedDistribution(TorchDistribution):
    def __init__(self, base_dist, batch_shape=torch.Size())

Import

from pyro.distributions.torch_distribution import TorchDistributionMixin
from pyro.distributions.torch_distribution import TorchDistribution
from pyro.distributions.torch_distribution import MaskedDistribution
from pyro.distributions.torch_distribution import ExpandedDistribution

I/O Contract

TorchDistributionMixin.__call__

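Inputs

| Name | Type | Required | Description |
|---|---|---|---|
| sample_shape | torch.Size | No | Shape of the iid batch of samples to draw. Defaults to torch.Size(), i.e. a single draw. |

Outputs

| Name | Type | Description |
|---|---|---|
| return | torch.Tensor | A reparameterized sample (via rsample) if has_rsample is True, otherwise a non-reparameterized sample (via sample). |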

TorchDistributionMixin.infer_shapes

Inputs

| Name | Type | Required | Description |
|---|---|---|---|
| **arg_shapes | dict of torch.Size | Yes | Keyword arguments mapping parameter names to their shapes. |

Outputs

| Name | Type | Description |
|---|---|---|
| return | tuple | A pair (batch_shape, event_shape) of torch.Size. |

MaskedDistribution Inputs

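| Name | Type | Required | Description |
|---|---|---|---|
| base_dist | Distribution | Yes | The underlying distribution to mask. |
| mask | bool or torch.BoolTensor | Yes | Boolean mask broadcastable with base_dist.batch_shape; the masked distribution's batch_shape is the broadcast of the two. Where False, log_prob returns zero. |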

MaskedDistribution Outputs

| Name | Type | Description |
|---|---|---|
| sample | torch.Tensor | Samples from the base distribution (unaffected by mask). |
| log_prob | torch.Tensor | Masked log probability; zero where mask is False. |
| score_parts | ScoreParts | Masked score parts for gradient estimation. |

ExpandedDistribution Inputs

| Name | Type | Required | Description |
|---|---|---|---|
| base_dist | Distribution | Yes | The underlying distribution to expand. |
| batch_shape | torch.Size | No | Target batch shape. Must be broadcastable from the base distribution's batch shape. |

ExpandedDistribution Outputs

| Name | Type | Description |
|---|---|---|
| sample | torch.Tensor | Samples with shape sample_shape + expanded_batch_shape + event_shape. |
| log_prob | torch.Tensor | Log probability expanded to the new batch shape. |

Usage Examples

Creating a New Distribution

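import math

import torch
from pyro.distributions.torch_distribution import TorchDistribution
from torch.distributions import constraints

class MyDistribution(TorchDistribution):
    """A Normal-like location-scale distribution written from scratch."""
    arg_constraints = {"loc": constraints.real, "scale": constraints.positive}
    support = constraints.real
    has_rsample = True

    def __init__(self, loc, scale, validate_args=None):
        # Broadcast parameters so they share a single batch shape.
        self.loc, self.scale = torch.broadcast_tensors(torch.as_tensor(loc), torch.as_tensor(scale))
        super().__init__(self.loc.shape, validate_args=validate_args)

    def rsample(self, sample_shape=torch.Size()):
        # Full shape: sample_shape + batch_shape + event_shape.
        shape = self._extended_shape(sample_shape)
        eps = torch.randn(shape, dtype=self.loc.dtype, device=self.loc.device)
        # Reparameterized: gradients flow through loc and scale.
        return self.loc + self.scale * eps

    def log_prob(self, value):
        # Gaussian log-density, including the -log(sqrt(2*pi)) normalizer.
        return (
            -((value - self.loc) / self.scale).pow(2) / 2
            - self.scale.log()
            - math.log(math.sqrt(2 * math.pi))
        )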

Using to_event for Multivariate Modeling

import pyro.distributions as dist
import torch

# Independent Normal over 5 dimensions
d = dist.Normal(torch.zeros(5), torch.ones(5)).to_event(1)
print(d.batch_shape)  # torch.Size([])
print(d.event_shape)  # torch.Size([5])

x = d.sample()
print(d.log_prob(x).shape)  # torch.Size([])

Masking Observations

import torch
import pyro
import pyro.distributions as dist

data = torch.randn(100)
mask = torch.rand(100) > 0.3  # True marks an observed entry (roughly 70% of them)

with pyro.plate("data", 100):
    pyro.sample("obs", dist.Normal(0.0, 1.0).mask(mask), obs=data)

Callable Interface

import torch
import pyro.distributions as dist

loc = torch.tensor(0.0, requires_grad=True)
d = dist.Normal(loc, 1.0)
# Calling the distribution uses rsample because Normal.has_rsample is True
x = d(torch.Size([10]))
print(x.shape)  # torch.Size([10])
print(x.requires_grad)  # True: the reparameterized sample is differentiable w.r.t. loc
