Implementation:Pyro ppl Pyro TruncatedPolyaGamma
| Knowledge Sources | |
|---|---|
| Domains | Probability_Distributions |
| Last Updated | 2026-02-09 09:00 GMT |
Overview
Description
TruncatedPolyaGamma is a distribution class in Pyro that implements the Polya-Gamma(1, 0) distribution truncated to the finite interval (0, 2.5). This truncation enables the use of the distribution in gradient-based inference methods like HMC. The class extends TorchDistribution.
Key implementation details:
- Truncation point: The support is constrained to `(0, 2.5)` via the class constant `truncation_point = 2.5`.
- Log probability: The `log_prob` method uses an alternating series approximation with `num_log_prob_terms = 7` terms. The computation evaluates the even- and odd-indexed terms separately using `logsumexp` for numerical stability, then takes the difference. This approach yields accuracy to approximately six decimal places.
- Sampling: The `sample` method (not reparameterized; `has_rsample = False`) uses a rough approximation based on a sum of `num_gamma_variates = 8` scaled exponential random variables, clamped to the truncation point. This sampler is intended only for initialization contexts where sample accuracy is not critical.
- Prototype tensor: Instead of explicit distribution parameters, the constructor takes a `prototype` tensor that determines the `dtype` and `device` of returned values.
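The alternating-series log density described above can be sketched without Pyro. This sketch assumes the standard small-x series for PG(1, 0): p(x) = sum over n >= 0 of (-1)^n (2n + 1) / sqrt(2 pi x^3) * exp(-(2n + 1)^2 / (8x)).

The sketch works in three steps:
- It keeps 7 terms.
- It log-sum-exps the even (positive) and odd (negative) terms separately.
- It combines the two as log S_even + log1p(-S_odd / S_even).

That last step is a choice made here to avoid underflow with plain Python floats, and it may differ from Pyro's code. `pg_log_prob` is a hypothetical helper, not part of Pyro.

```python
import math

TRUNCATION_POINT = 2.5
NUM_LOG_PROB_TERMS = 7  # odd, so the partial sum ends on a positive term


def _logsumexp(values):
    """Numerically stable log(sum(exp(v) for v in values))."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))


def pg_log_prob(x, num_terms=NUM_LOG_PROB_TERMS):
    """Approximate log density of PolyaGamma(1, 0) at x in (0, TRUNCATION_POINT].

    Truncates the alternating series
        p(x) = sum_{n>=0} (-1)^n (2n + 1) / sqrt(2 pi x^3) * exp(-(2n + 1)^2 / (8 x))
    after `num_terms` terms (hypothetical helper, not Pyro's implementation).
    """
    if num_terms < 2:
        raise ValueError("need at least one positive and one negative term")
    # Log magnitude of each term, without the shared 1/sqrt(2 pi) factor.
    log_terms = [
        math.log(2 * n + 1) - 1.5 * math.log(x) - (2 * n + 1) ** 2 / (8.0 * x)
        for n in range(num_terms)
    ]
    log_even = _logsumexp(log_terms[0::2])  # terms entering with a + sign
    log_odd = _logsumexp(log_terms[1::2])  # terms entering with a - sign
    # log(S_even - S_odd) = log S_even + log1p(-S_odd / S_even), avoiding underflow.
    log_diff = log_even + math.log1p(-math.exp(log_odd - log_even))
    return log_diff - 0.5 * math.log(2.0 * math.pi)


print(pg_log_prob(0.5))
```

The terms decay like exp(-(2n + 1)^2 / (8x)). As a result, a fixed cutoff is most accurate near 0 and least accurate toward the upper end of the support.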
The distribution has no parameters (`arg_constraints = {}`). Its event shape is scalar, and its batch shape is scalar unless the distribution is expanded with `expand`.
Usage
TruncatedPolyaGamma is designed for Bayesian logistic regression and related models that use the Polya-Gamma data augmentation scheme introduced by Polson, Scott, and Windle (2013). In this scheme, each observation receives a Polya-Gamma latent variable. Conditional on these latents, the logistic likelihood becomes a Gaussian function of the regression coefficients, which enables efficient Gibbs sampling or HMC-based inference. The truncation and approximate `log_prob` make the distribution suitable for HMC, which needs an accurate log density and its gradients but not exact samples.
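The scheme rests on the following identity (Polson, Scott, and Windle, 2013). Take a binary outcome y with linear predictor eta, and let kappa = y - 1/2. Then p(y | eta) = 1/2 * exp(kappa * eta) * E[exp(-omega * eta^2 / 2)], with omega ~ PG(1, 0). For fixed omega, the integrand is proportional to exp(-omega * (eta - kappa / omega)^2 / 2). That is a Gaussian kernel in eta with mean kappa / omega and precision omega.

The stdlib-only sketch below checks the identity numerically through the PG(1, 0) Laplace transform E[exp(-omega * t)] = 1 / cosh(sqrt(t / 2)). All helper names are hypothetical, and no Pyro code is involved.

```python
import math


def sigmoid(eta):
    return 1.0 / (1.0 + math.exp(-eta))


def pg10_laplace(t):
    """E[exp(-omega * t)] for omega ~ PolyaGamma(1, 0) and t >= 0."""
    return 1.0 / math.cosh(math.sqrt(t / 2.0))


def augmented_marginal(y, eta):
    """Average of 0.5 * exp(kappa * eta - omega * eta**2 / 2) over omega ~ PG(1, 0)."""
    kappa = y - 0.5
    return 0.5 * math.exp(kappa * eta) * pg10_laplace(eta ** 2 / 2.0)


def bernoulli_logit_likelihood(y, eta):
    """p(y | eta) for y in {0, 1} under a logistic link."""
    p = sigmoid(eta)
    return p if y == 1 else 1.0 - p


for y in (0, 1):
    print(y, augmented_marginal(y, 1.2), bernoulli_logit_likelihood(y, 1.2))
```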
Code Reference
Source Location
- File: `pyro/distributions/polya_gamma.py`
- Repository: pyro-ppl/pyro
Signature
```python
class TruncatedPolyaGamma(TorchDistribution):
    def __init__(self, prototype, validate_args=None): ...
```
Import
```python
from pyro.distributions import TruncatedPolyaGamma
```
I/O Contract
Inputs
| Parameter | Type | Description |
|---|---|---|
| `prototype` | `torch.Tensor` | A prototype tensor used to determine the `dtype` and `device` for samples and `log_prob` computations. Its values are not used. |
| `validate_args` | `bool` or `None` | Whether to validate input arguments. Defaults to `None`. |
Outputs
| Method | Return Type | Description |
|---|---|---|
| `sample(sample_shape)` | `torch.Tensor` | Returns an approximate sample of shape `batch_shape + sample_shape` from the truncated Polya-Gamma distribution. Values are clamped from above at the truncation point, so they lie in (0, 2.5]. |
| `log_prob(value)` | `torch.Tensor` | Returns the approximate log probability density of `value`, accurate to approximately six decimal places. |
| `expand(batch_shape)` | `TruncatedPolyaGamma` | Returns a new instance with the specified batch shape. |
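The following stdlib-only sketch shows the kind of approximate sampler `sample` is documented to use. It assumes the series representation of PG(1, 0) from Polson, Scott, and Windle (2013): omega = 1 / (2 pi^2) * sum over k >= 0 of g_k / (k + 1/2)^2, with g_k ~ Exponential(1). The series is cut off after 8 terms and clamped at 2.5. `pg_sample_approx` is a hypothetical name, and this is not Pyro's code.

```python
import math
import random

TRUNCATION_POINT = 2.5
NUM_GAMMA_VARIATES = 8  # number of series terms kept (mirrors num_gamma_variates)


def pg_sample_approx(rng, num_variates=NUM_GAMMA_VARIATES):
    """Rough draw from PolyaGamma(1, 0), clamped above at TRUNCATION_POINT.

    Keeps the first `num_variates` terms of
        omega = 1 / (2 pi^2) * sum_{k>=0} g_k / (k + 1/2)^2,  g_k ~ Exponential(1)
    and drops the rest, so draws are biased slightly low.
    """
    total = sum(rng.expovariate(1.0) / (k + 0.5) ** 2 for k in range(num_variates))
    return min(total / (2.0 * math.pi ** 2), TRUNCATION_POINT)


# A handful of reproducible draws
rng = random.Random(0)
print([round(pg_sample_approx(rng), 3) for _ in range(5)])
```

Dropping the tail of the series biases this sketch low. Its mean is the sum over k = 0..7 of (k + 1/2)^-2, divided by 2 pi^2, which is about 0.244 versus the exact PG(1, 0) mean of 0.25. That bias is tolerable for initialization but not for inference.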
Usage Examples
```python
import torch
from pyro.distributions import TruncatedPolyaGamma

# Create a TruncatedPolyaGamma distribution; the prototype only fixes dtype/device
prototype = torch.tensor(0.0)
pg = TruncatedPolyaGamma(prototype)

# Draw an approximate sample
sample = pg.sample()
print(sample.shape)  # torch.Size([])
print(sample)  # a scalar in (0, 2.5]

# Draw multiple samples
samples = pg.sample((1000,))
print(samples.shape)  # torch.Size([1000])
print(samples.min(), samples.max())  # within (0, 2.5]

# Compute the approximate log density at 0.5
log_p = pg.log_prob(torch.tensor(0.5))
print(log_p)
```
```python
import pyro
import pyro.distributions as dist
import torch


# Bayesian logistic regression with Polya-Gamma data augmentation.
# X: (N, P) float tensor of covariates; y: (N,) tensor of 0/1 labels.
def model(X, y):
    p = X.shape[-1]
    beta = pyro.sample("beta", dist.Normal(torch.zeros(p), torch.ones(p)).to_event(1))
    with pyro.plate("data", len(X)):
        eta = X @ beta
        # Polya-Gamma(1, 0) auxiliary variable, one per observation
        omega = pyro.sample("omega", dist.TruncatedPolyaGamma(X.new_tensor(0.0)))
        # Augmented likelihood: log p(y, omega | eta) = kappa * eta - omega * eta**2 / 2
        # + log PG(omega) + const, with kappa = y - 1/2; the "omega" site supplies log PG(omega)
        kappa = y.float() - 0.5
        pyro.factor("y", kappa * eta - 0.5 * omega * eta ** 2)
```
Related Pages
- Pyro_ppl_Pyro_ImproperUniform - Another distribution with specialized support constraints
- Pyro_ppl_Pyro_Rejector - Rejection-sampled distribution, another distribution with a specialized sampler