Implementation:Pyro ppl Pyro ImproperUniform

From Leeroopedia


Knowledge Sources
Domains Probability_Distributions
Last Updated 2026-02-09 09:00 GMT

Overview

Description

ImproperUniform is a Pyro distribution class representing an improper (non-normalizable) uniform distribution. Its log probability is zero everywhere on its support, and it defines no sampling procedure: calling sample() raises NotImplementedError.

This distribution extends TorchDistribution and is constructed from three arguments: a support constraint, a batch shape, and an event shape. The log_prob method ignores the values it is given and returns a tensor of zeros whose shape is the value's batch shape (its leading dimensions, excluding the event dimensions) broadcast against the distribution's batch shape. The distribution has no parameters at all: its arg_constraints dict is empty.

The class also implements expand, which returns a new instance with the given batch shape while preserving the original support and event shape.

Usage

ImproperUniform is primarily useful for transforming a generative directed acyclic graph (DAG) model into factor graph form, for example for use with Hamiltonian Monte Carlo (HMC). In a generative DAG model, each variable is sampled in turn from its conditional distribution given the variables before it. In factor graph form, all latent variables are drawn jointly at a single site whose improper, flat prior contributes zero to the log density, and each conditional relationship is expressed as an observed sample statement (a factor) that scores those values. The two forms define the same joint density.

To obtain a similar distribution that does support sampling, apply .mask(False) to a proper distribution: the masked distribution's log probability is zero, but sample() still draws from the underlying distribution.

Code Reference

Source Location

  • File: pyro/distributions/improper_uniform.py
  • Repository: pyro-ppl/pyro

Signature

class ImproperUniform(TorchDistribution):
    def __init__(self, support, batch_shape, event_shape)

Import

from pyro.distributions import ImproperUniform

I/O Contract

Inputs

Parameter | Type | Description
support | torch.distributions.constraints.Constraint | The support of the distribution, i.e. the domain on which it is defined (e.g., constraints.real, constraints.positive).
batch_shape | torch.Size (or tuple) | The batch shape: the shape of the batch of independent distributions.
event_shape | torch.Size (or tuple) | The event shape: the shape of a single event (one value of the random variable).

Outputs

Method | Return type | Description
log_prob(value) | torch.Tensor | A tensor of zeros whose shape is the value's batch shape (value.shape with the trailing event dimensions removed) broadcast against the distribution's batch shape.
sample(sample_shape) | N/A | Not supported; raises NotImplementedError.
expand(batch_shape) | ImproperUniform | A new ImproperUniform instance with the given batch shape and the same support and event shape.

Usage Examples

import torch
import pyro
import pyro.distributions as dist
from torch.distributions import constraints

# Converting a generative DAG model to factor graph form.
# The two versions below define the same joint distribution over x, y, z.

# Version 1: Generative DAG
def model_dag():
    x = pyro.sample("x", dist.Normal(0, 1))
    y = pyro.sample("y", dist.Normal(x, 1))
    z = pyro.sample("z", dist.Normal(y, 1))

# Version 2: Factor graph using ImproperUniform
def model_factor_graph():
    xyz = pyro.sample("xyz", dist.ImproperUniform(constraints.real, (), (3,)))
    x, y, z = xyz.unbind(-1)
    pyro.sample("x", dist.Normal(0, 1), obs=x)
    pyro.sample("y", dist.Normal(x, 1), obs=y)
    pyro.sample("z", dist.Normal(y, 1), obs=z)

import torch
import pyro.distributions as dist
from torch.distributions import constraints

# Using ImproperUniform with various supports
uniform_real = dist.ImproperUniform(constraints.real, (), (5,))
uniform_positive = dist.ImproperUniform(constraints.positive, (), (3,))
uniform_batched = dist.ImproperUniform(constraints.real, (10,), (5,))

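# log_prob is always zero; its shape follows the distribution's batch shape
value = torch.randn(5)
print(uniform_real.log_prob(value))  # tensor(0.)
print(uniform_batched.log_prob(value).shape)  # torch.Size([10])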

# Expanding batch shape
expanded = uniform_real.expand((4,))
print(expanded.batch_shape)  # torch.Size([4])
