Implementation:Pyro ppl Pyro TorchDistributionMixin
| Knowledge Sources | |
|---|---|
| Domains | Probability_Distributions, Deep_Learning |
| Last Updated | 2026-02-09 09:00 GMT |
Overview
Core mixin and base classes that bridge PyTorch distributions with Pyro's probabilistic programming interface, plus supporting distribution wrappers for masking and batch expansion.
Description
This module defines the fundamental building blocks for all distributions in Pyro:
TorchDistributionMixin is the primary mixin class that provides Pyro compatibility for PyTorch distributions. It inherits from both Pyro's Distribution and Python's Callable. Key features include:
- Callable interface -- Calling a distribution instance invokes rsample for reparameterized distributions or sample for non-reparameterized ones, enabling the dist(sample_shape) syntax.
- Shape management -- Provides shape(), batch_shape, event_shape, and event_dim properties following PyTorch's shape convention: sample_shape + batch_shape + event_shape.
- Shape inference -- The class method infer_shapes(**arg_shapes) statically infers batch_shape and event_shape from argument shapes without constructing actual tensors.
- expand/expand_by -- expand(batch_shape) returns an ExpandedDistribution; expand_by(sample_shape) prepends dimensions to the batch shape.
- to_event(n) -- Reinterprets n rightmost batch dimensions as event dimensions, wrapping in Independent. Supports negative values to remove event dimensions.
- mask(mask) -- Returns a MaskedDistribution that zeros out log-prob contributions where mask is False.
TorchDistribution is the recommended base class for new Pyro distributions. It inherits from both torch.distributions.Distribution and TorchDistributionMixin, providing a default expand method.
MaskedDistribution wraps a base distribution with a boolean mask. Wherever the mask is False, log_prob, score_parts, and KL divergence contributions are zero; when the mask is the Python constant False, the underlying computation is skipped entirely and zeros are returned. Sampling is unaffected by the mask. This is essential for handling missing data and conditional observations in Pyro models.
ExpandedDistribution wraps a base distribution with an expanded batch shape, handling both expanded dimensions (new leading dimensions) and interstitial dimensions (broadcasting size-1 dimensions to larger sizes). It correctly transposes sample dimensions during sampling to match the expanded shape.
The module also registers a KL divergence handler for MaskedDistribution pairs that properly handles mask intersection logic.
Usage
Use TorchDistribution as the base class when implementing new distributions for Pyro. Use TorchDistributionMixin only when wrapping existing PyTorch distributions. Use the mask and to_event methods in Pyro models to control observation masking and shape semantics.
Code Reference
Source Location
- Repository: Pyro
- File: pyro/distributions/torch_distribution.py
Signature
class TorchDistributionMixin(Distribution, Callable):
    def __call__(self, sample_shape=torch.Size()) -> torch.Tensor
    def shape(self, sample_shape=torch.Size()) -> torch.Size
    @classmethod
    def infer_shapes(cls, **arg_shapes) -> tuple
    def expand(self, batch_shape, _instance=None) -> "ExpandedDistribution"
    def expand_by(self, sample_shape) -> "ExpandedDistribution"
    def to_event(self, reinterpreted_batch_ndims=None)
    def mask(self, mask) -> "MaskedDistribution"
class TorchDistribution(torch.distributions.Distribution, TorchDistributionMixin):
    pass
class MaskedDistribution(TorchDistribution):
    def __init__(self, base_dist, mask)
class ExpandedDistribution(TorchDistribution):
    def __init__(self, base_dist, batch_shape=torch.Size())
Import
from pyro.distributions.torch_distribution import TorchDistributionMixin
from pyro.distributions.torch_distribution import TorchDistribution
from pyro.distributions.torch_distribution import MaskedDistribution
from pyro.distributions.torch_distribution import ExpandedDistribution
I/O Contract
TorchDistributionMixin.__call__
| Name | Type | Required | Description |
|---|---|---|---|
| sample_shape | torch.Size | No | Shape of the iid batch to draw. Defaults to empty. |
| Name | Type | Description |
|---|---|---|
| return | torch.Tensor | A reparameterized sample (via rsample) if available, otherwise a non-reparameterized sample. |
TorchDistributionMixin.infer_shapes
| Name | Type | Required | Description |
|---|---|---|---|
| **arg_shapes | dict of torch.Size | Yes | Keyword arguments mapping parameter names to their shapes. |
| Name | Type | Description |
|---|---|---|
| return | tuple | A pair (batch_shape, event_shape) of torch.Size. |
MaskedDistribution Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| base_dist | Distribution | Yes | The underlying distribution to mask. |
| mask | bool or torch.BoolTensor | Yes | Boolean mask broadcastable to base_dist.batch_shape. When False, log_prob returns zero. |
MaskedDistribution Outputs
| Name | Type | Description |
|---|---|---|
| sample | torch.Tensor | Samples from the base distribution (unaffected by mask). |
| log_prob | torch.Tensor | Masked log probability; zero where mask is False. |
| score_parts | ScoreParts | Masked score parts for gradient estimation. |
ExpandedDistribution Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| base_dist | Distribution | Yes | The underlying distribution to expand. |
| batch_shape | torch.Size | No | Target batch shape. Must be broadcastable from the base distribution's batch shape. |
ExpandedDistribution Outputs
| Name | Type | Description |
|---|---|---|
| sample | torch.Tensor | Samples with shape sample_shape + expanded_batch_shape + event_shape. |
| log_prob | torch.Tensor | Log probability expanded to the new batch shape. |
Usage Examples
Creating a New Distribution
import math
import torch
from pyro.distributions.torch_distribution import TorchDistribution
from torch.distributions import constraints
class MyDistribution(TorchDistribution):
    arg_constraints = {"loc": constraints.real, "scale": constraints.positive}
    support = constraints.real
    has_rsample = True
    def __init__(self, loc, scale, validate_args=None):
        # Broadcast parameters so that both share the batch shape
        self.loc, self.scale = torch.broadcast_tensors(loc, scale)
        super().__init__(self.loc.shape, validate_args=validate_args)
    def rsample(self, sample_shape=torch.Size()):
        # Reparameterized draw: loc + scale * standard normal noise
        shape = self._extended_shape(sample_shape)
        eps = torch.randn(shape, dtype=self.loc.dtype, device=self.loc.device)
        return self.loc + self.scale * eps
    def log_prob(self, value):
        # Normal log density, including the -0.5 * log(2 * pi) normalizing constant
        z = (value - self.loc) / self.scale
        return -0.5 * z.pow(2) - self.scale.log() - 0.5 * math.log(2 * math.pi)
Using to_event for Multivariate Modeling
import pyro.distributions as dist
import torch
# Independent Normal over 5 dimensions
d = dist.Normal(torch.zeros(5), torch.ones(5)).to_event(1)
print(d.batch_shape) # torch.Size([])
print(d.event_shape) # torch.Size([5])
x = d.sample()
print(d.log_prob(x).shape) # torch.Size([])
Masking Observations
import torch
import pyro
import pyro.distributions as dist
data = torch.randn(100)
mask = torch.rand(100) > 0.3 # 70% observed
with pyro.plate("data", 100):
    pyro.sample("obs", dist.Normal(0.0, 1.0).mask(mask), obs=data)
Callable Interface
import torch
import pyro.distributions as dist
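# loc must require gradients for the sample to carry gradient information
loc = torch.tensor(0.0, requires_grad=True)
d = dist.Normal(loc, 1.0)
# Normal has has_rsample=True, so calling the distribution uses rsample
x = d(torch.Size([10]))
print(x.shape) # torch.Size([10])
print(x.requires_grad) # True (gradient flows back to loc)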