

Implementation:Pyro ppl Pyro TruncatedPolyaGamma

From Leeroopedia


Knowledge Sources
Domains Probability_Distributions
Last Updated 2026-02-09 09:00 GMT

Overview

Description

TruncatedPolyaGamma is a distribution class in Pyro that implements the Polya-Gamma(1, 0) distribution truncated to the finite interval (0, 2.5). This truncation enables the use of the distribution in gradient-based inference methods like HMC. The class extends TorchDistribution.

Key implementation details:

  • Truncation point: The support is constrained to (0, 2.5) via the class constant truncation_point = 2.5.
  • Log probability: The log_prob method evaluates an alternating series truncated to num_log_prob_terms = 7 terms, where term n has magnitude (2n+1) x^(-3/2) exp(-(2n+1)^2 / (8x)) / sqrt(2π). It sums the even-indexed (positive) and odd-indexed (negative) terms separately using logsumexp for numerical stability, subtracts the odd sum from the even sum, and takes the log. This approach yields accuracy to approximately six decimal places.
  • Sampling: The sample method (not reparameterized, has_rsample = False) uses a rough approximation based on a sum of num_gamma_variates = 8 scaled exponential random variables, clamped to the truncation point. This sampler is intended only for initialization contexts where sample accuracy is not critical.
  • Prototype tensor: Instead of explicit distribution parameters, the constructor takes a prototype tensor that determines the dtype and device of returned values.
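To make the log-probability bullet concrete, here is a minimal pure-Python sketch of the same alternating-series computation, assuming only the standard library. It evaluates the Polya-Gamma(1, 0) density f(x) = (2π x^3)^(-1/2) Σ_{n≥0} (-1)^n (2n+1) exp(-(2n+1)^2 / (8x)) with a fixed number of terms. The helper names `logsumexp` and `pg1_log_density` are hypothetical, and this is an illustration of the described scheme, not Pyro's source.

```python
import math


def logsumexp(log_values):
    """Numerically stable log(sum(exp(v) for v in log_values))."""
    m = max(log_values)
    return m + math.log(sum(math.exp(v - m) for v in log_values))


def pg1_log_density(x, num_terms=7):
    """Approximate log density of Polya-Gamma(1, 0) at x > 0 (hypothetical helper).

    Truncates the alternating series after `num_terms` terms (num_terms >= 2).
    """
    # Log-magnitude of term n: log(2n+1) - 1.5*log(x) - (2n+1)^2 / (8x)
    log_terms = []
    for n in range(num_terms):
        k = 2 * n + 1
        log_terms.append(math.log(k) - 1.5 * math.log(x) - k * k / (8.0 * x))
    # Even-indexed terms enter with a + sign, odd-indexed terms with a - sign.
    log_even = logsumexp(log_terms[0::2])
    log_odd = logsumexp(log_terms[1::2])
    # Subtract the two partial sums, then apply the 1/sqrt(2*pi) constant.
    return math.log(math.exp(log_even) - math.exp(log_odd)) - 0.5 * math.log(2.0 * math.pi)
```

The terms shrink fastest for small x; near the truncation point they decay more slowly, so comparing `num_terms=7` against a much longer series (for example `num_terms=51`) is a quick way to inspect the cutoff error across the support.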

The distribution has no learnable parameters (arg_constraints = {}). Its event shape is scalar, and its batch shape is scalar (()) unless changed via expand.
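The sampling bullet above can be illustrated with a similar standard-library sketch. It assumes the series representation from Polson, Scott, and Windle, PG(1, 0) = (1/(2π^2)) Σ_{k≥1} E_k / (k - 1/2)^2 with independent E_k ~ Exponential(1), keeps only the first 8 terms, and clamps at 2.5. The names `approx_pg1_sample` and `approx_mean` are hypothetical; this is a sketch of the idea, not Pyro's code.

```python
import math
import random

TRUNCATION_POINT = 2.5   # upper end of the support
NUM_GAMMA_VARIATES = 8   # number of series terms kept


def approx_pg1_sample(rng=random):
    """Rough Polya-Gamma(1, 0) draw from a truncated series (hypothetical helper)."""
    total = 0.0
    for k in range(NUM_GAMMA_VARIATES):
        # Term k uses (k + 1/2)^2, i.e. (j - 1/2)^2 for j = k + 1.
        total += rng.expovariate(1.0) / (k + 0.5) ** 2
    # Scale by 1 / (2 pi^2) and clamp to the truncation point.
    return min(total / (2.0 * math.pi ** 2), TRUNCATION_POINT)


def approx_mean():
    """Exact mean of the 8-term series before clamping."""
    return sum(1.0 / (k + 0.5) ** 2 for k in range(NUM_GAMMA_VARIATES)) / (2.0 * math.pi ** 2)
```

Because the series is cut off after 8 terms, the approximation's mean is about 0.244 rather than the exact PG(1, 0) mean of 1/4, which is consistent with using this sampler only for initialization.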

Usage

TruncatedPolyaGamma is designed for use in Bayesian logistic regression and related models that employ the Polya-Gamma data augmentation scheme introduced by Polson, Scott, and Windle (2013). In this scheme, Polya-Gamma latent variables are introduced to turn the logistic likelihood into a conditionally Gaussian form, enabling efficient Gibbs sampling or HMC-based inference. The truncation and approximate log_prob make it suitable for HMC where accurate log probability gradients are needed but exact sampling is not.
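The identity behind the augmentation can be checked numerically. For a Bernoulli label y and linear predictor ψ, with κ = y - 1/2, Polson, Scott, and Windle show that p(y | ψ) = (1/2) exp(κψ) E[exp(-ωψ^2/2)] with ω ~ PG(1, 0); for fixed ω the integrand is a Gaussian kernel in ψ. The sketch below assumes the closed-form Laplace transform E[exp(-tω)] = 1 / cosh(sqrt(t/2)) of PG(1, 0) and confirms that the right-hand side reproduces the logistic likelihood. All function names are hypothetical.

```python
import math


def pg1_laplace(t):
    """E[exp(-t * omega)] for omega ~ PG(1, 0), t >= 0."""
    return 1.0 / math.cosh(math.sqrt(t / 2.0))


def augmented_likelihood(y, psi):
    """(1/2) * exp(kappa * psi) * E[exp(-omega * psi**2 / 2)], with kappa = y - 1/2."""
    kappa = y - 0.5
    return 0.5 * math.exp(kappa * psi) * pg1_laplace(psi ** 2 / 2.0)


def logistic_likelihood(y, psi):
    """Bernoulli likelihood with success probability sigmoid(psi)."""
    p = 1.0 / (1.0 + math.exp(-psi))
    return p if y == 1 else 1.0 - p
```

In Gibbs or HMC inference ω is kept as a latent variable instead of being integrated out, so each observation contributes exp(κψ - ωψ^2/2), which is Gaussian in ψ.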

Code Reference

Source Location

  • File: pyro/distributions/polya_gamma.py
  • Repository: pyro-ppl/pyro

Signature

class TruncatedPolyaGamma(TorchDistribution):
    def __init__(self, prototype, validate_args=None)

Import

from pyro.distributions import TruncatedPolyaGamma

I/O Contract

Inputs

  • prototype (torch.Tensor): A prototype tensor used to determine the dtype and device for samples and log_prob computations. Its values are not used.
  • validate_args (bool or None): Whether to validate input arguments. Defaults to None.

Outputs

  • sample(sample_shape) → torch.Tensor: Returns an approximate sample of shape batch_shape + sample_shape from the truncated Polya-Gamma distribution. Values are clamped at the truncation point, so they lie in (0, 2.5].
  • log_prob(value) → torch.Tensor: Returns the approximate log probability density of value, accurate to approximately six decimal places.
  • expand(batch_shape) → TruncatedPolyaGamma: Returns a new instance with the specified batch shape.

Usage Examples

import torch
from pyro.distributions import TruncatedPolyaGamma

# Create a TruncatedPolyaGamma distribution; the prototype only fixes dtype and device
prototype = torch.tensor(0.0)
pg = TruncatedPolyaGamma(prototype)

# Draw an approximate sample
sample = pg.sample()
print(sample.shape)  # torch.Size([])
print(sample)        # A scalar in (0, 2.5]

# Draw multiple samples
samples = pg.sample((1000,))
print(samples.shape)  # torch.Size([1000])
print(samples.min(), samples.max())  # within (0, 2.5]

# Compute log probability
log_p = pg.log_prob(torch.tensor(0.5))
print(log_p)
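
import math

import pyro
import pyro.distributions as dist
import torch

# Bayesian logistic regression with Polya-Gamma data augmentation.
# For labels y in {0, 1} and kappa = y - 1/2, the augmented joint density is
#   p(y, omega | eta) = (1/2) * exp(kappa * eta - omega * eta**2 / 2) * PG(omega | 1, 0),
# which is Gaussian in eta for fixed omega (approximate here because PG is truncated at 2.5).
def model(X, y):
    p = X.shape[-1]
    beta = pyro.sample("beta", dist.Normal(torch.zeros(p), torch.ones(p)).to_event(1))

    with pyro.plate("data", len(X)):
        eta = X @ beta
        # Polya-Gamma auxiliary variable, one per observation; the plate expands
        # the distribution's scalar batch shape to (len(X),)
        omega = pyro.sample("omega", dist.TruncatedPolyaGamma(X.new_tensor(0.0)))
        # Remaining augmented likelihood terms (log PG(omega) is already
        # contributed by the "omega" sample site)
        kappa = y.float() - 0.5
        pyro.factor("y", kappa * eta - 0.5 * omega * eta**2 - math.log(2.0))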
