Implementation:Pyro ppl Pyro Rejector
| Knowledge Sources | |
|---|---|
| Domains | Probability_Distributions |
| Last Updated | 2026-02-09 09:00 GMT |
Overview
Description
Rejector is a distribution class in Pyro that implements rejection sampling with a differentiable acceptance rate function. It extends TorchDistribution and supports reparameterized sampling (has_rsample = True).
The distribution is defined by three components:
- propose - A proposal distribution from which candidate samples are drawn.
- log_prob_accept - A callable that takes a batch of proposed samples and returns a batch of log acceptance probabilities.
- log_scale - The total log probability of acceptance (the log of the normalizing constant for the acceptance function).
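To make the role of log_scale concrete, the sketch below estimates the total acceptance probability by Monte Carlo for a standard normal proposal with a hard "accept if x >= 0" rule, where the exact answer is log(0.5). It uses only the Python standard library. `estimate_log_scale` and the two small helper functions are hypothetical names introduced for illustration and are not part of Pyro.

```python
import math
import random


def estimate_log_scale(propose, log_prob_accept, num_samples=100_000, seed=0):
    """Monte Carlo estimate of log E_{x ~ propose}[exp(log_prob_accept(x))].

    Hypothetical helper for illustration only; Pyro does not provide it.
    """
    rng = random.Random(seed)
    total = 0.0
    for _ in range(num_samples):
        x = propose(rng)
        total += math.exp(log_prob_accept(x))  # exp(-inf) evaluates to 0.0
    return math.log(total / num_samples)


def propose_standard_normal(rng):
    # Stand-in for a Normal(0, 1) proposal distribution.
    return rng.gauss(0.0, 1.0)


def hard_log_accept(x):
    # Log acceptance probability: 0 (accept) for x >= 0, -inf (reject) otherwise.
    return 0.0 if x >= 0 else float("-inf")


log_scale_estimate = estimate_log_scale(propose_standard_normal, hard_log_accept)
print(log_scale_estimate, math.log(0.5))
```

When a closed form for the acceptance mass is available it is preferable to an estimate like this, since any error in log_scale shifts every log_prob value by the same amount.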
The rsample method implements parallel batched accept-reject sampling: it draws a batch of proposals from the proposal distribution, accepts each with probability exp(log_prob_accept(x)), and keeps drawing fresh proposals until every position in the batch has been accepted. A boolean mask tracks which positions are already accepted, so later rounds fill only the remaining ones and never overwrite an accepted sample.
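The accept-reject loop described above can be sketched in plain Python. This is a simplified illustration using lists and the standard library, not Pyro's tensor-based implementation. `batched_rejection_sample` and its arguments are hypothetical names, and the `done` list plays the role of the boolean mask.

```python
import math
import random


def batched_rejection_sample(propose, log_prob_accept, batch_size, rng):
    """Return batch_size accepted samples, refilling only rejected positions."""
    x = [propose(rng) for _ in range(batch_size)]
    # done[i] is True once position i holds an accepted sample.
    done = [rng.random() < min(1.0, math.exp(log_prob_accept(xi))) for xi in x]
    while not all(done):
        proposed = [propose(rng) for _ in range(batch_size)]
        for i, xi in enumerate(proposed):
            if done[i]:
                continue  # masked out: an accepted position is never overwritten
            if rng.random() < min(1.0, math.exp(log_prob_accept(xi))):
                x[i] = xi
                done[i] = True
    return x


def hard_log_accept(x):
    # Accept only non-negative proposals.
    return 0.0 if x >= 0 else float("-inf")


samples = batched_rejection_sample(
    lambda rng: rng.gauss(0.0, 1.0), hard_log_accept, 1000, random.Random(0)
)
print(min(samples))
```

The number of rounds is random. With an acceptance rate near 0.5, as in this example, a batch of 1000 typically fills within about ten rounds, while very low acceptance rates make the loop correspondingly slower.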
The log_prob method computes the log probability as the sum of the proposal's log probability and the normalized log acceptance probability:
log_prob(x) = propose.log_prob(x) + log_prob_accept(x) - log_scale
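As a quick sanity check of this formula, the sketch below evaluates it by hand for a standard normal proposal truncated to x >= 0 and compares the result with the half-normal log density log(2) + log N(x; 0, 1). The helper names (`normal_log_prob`, `rejector_log_prob`, `hard_log_accept`) are hypothetical and only mirror the formula; they are not Pyro functions.

```python
import math


def normal_log_prob(x, loc=0.0, scale=1.0):
    # Log density of a univariate normal distribution.
    z = (x - loc) / scale
    return -0.5 * z * z - math.log(scale) - 0.5 * math.log(2.0 * math.pi)


def rejector_log_prob(x, propose_log_prob, log_prob_accept, log_scale):
    # Direct transcription of the formula above.
    return propose_log_prob(x) + log_prob_accept(x) - log_scale


def hard_log_accept(x):
    # Accept only non-negative values.
    return 0.0 if x >= 0 else float("-inf")


x = 0.5
value = rejector_log_prob(x, normal_log_prob, hard_log_accept, math.log(0.5))
half_normal = math.log(2.0) + normal_log_prob(x)
outside = rejector_log_prob(-1.0, normal_log_prob, hard_log_accept, math.log(0.5))
print(value, half_normal, outside)
```

Subtracting log_scale = log(0.5) is what doubles the density on the accepted region, and points outside the support get a log probability of negative infinity.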
The score_parts method returns a ScoreParts named tuple that separates the score function (the normalized acceptance log probability) from the full log probability, which is useful for computing gradient estimators in variational inference.
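The split described above can be illustrated with a small stand-in. `ScorePartsSketch` and `score_parts_sketch` are hypothetical names that mirror the three fields this page lists for ScoreParts. Setting the entropy term to the full log probability is an assumption made for the sketch. The log_scale value is the closed-form acceptance mass for a Normal(0, 2) proposal with acceptance exp(-0.5 (x - 1)^2).

```python
import math
from collections import namedtuple

# Hypothetical stand-in with the three fields this page lists for ScoreParts.
ScorePartsSketch = namedtuple(
    "ScorePartsSketch", ["log_prob", "score_function", "entropy_term"]
)


def normal_log_prob(x, loc, scale):
    # Log density of a univariate normal distribution.
    z = (x - loc) / scale
    return -0.5 * z * z - math.log(scale) - 0.5 * math.log(2.0 * math.pi)


def score_parts_sketch(x, propose_log_prob, log_prob_accept, log_scale):
    # Score-function part: the normalized log acceptance probability.
    score_function = log_prob_accept(x) - log_scale
    # Full log density: proposal term plus the normalized acceptance term.
    log_prob = propose_log_prob(x) + score_function
    # Assumption: the entropy term is taken to be the full log probability.
    return ScorePartsSketch(log_prob, score_function, log_prob)


def soft_log_accept(x):
    # Smooth acceptance rule that favors values near 1.
    return -0.5 * (x - 1.0) ** 2


# Closed form: E_{x ~ Normal(0, 2)}[exp(-0.5 (x - 1)^2)] = exp(-0.1) / sqrt(5).
log_scale = -0.5 * math.log(5.0) - 0.1
parts = score_parts_sketch(
    0.3, lambda x: normal_log_prob(x, 0.0, 2.0), soft_log_accept, log_scale
)
print(parts.log_prob, parts.score_function)
```

The difference between the two printed values is the proposal's log density, which is the part of the log probability that the score-function term does not cover.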
The class includes LRU(1) caches for both _log_prob_accept and _propose_log_prob to avoid redundant computation when the same sample is passed to multiple methods.
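A one-entry cache of this kind can be sketched as follows. `SingleEntryCache` is a hypothetical class, not Pyro's code, and keying the cache on object identity (`is`) rather than on value equality is an assumption of the sketch.

```python
_MISSING = object()  # sentinel so that None can also be used as an input


class SingleEntryCache:
    """One-slot cache keyed on object identity (illustrative sketch)."""

    def __init__(self, fn):
        self.fn = fn
        self._key = _MISSING
        self._value = None
        self.calls = 0  # counts real evaluations, for demonstration only

    def __call__(self, x):
        if x is not self._key:
            # Different object than last time: recompute and replace the slot.
            self._key = x
            self._value = self.fn(x)
            self.calls += 1
        return self._value


cached_square = SingleEntryCache(lambda xs: [v * v for v in xs])
sample = [1.0, 2.0]
first = cached_square(sample)        # computed
second = cached_square(sample)       # same object: served from the cache
third = cached_square([1.0, 2.0])    # equal but distinct object: recomputed
print(cached_square.calls)
```

An identity check is cheap and never has to compare tensor contents, but it also means that mutating the cached input in place would leave a stale value in the slot.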
Usage
Rejector is useful for constructing custom distributions where the target density can be expressed as a reweighted proposal distribution. It is particularly valuable when the acceptance probability is differentiable, enabling gradient-based inference. Common use cases include truncated distributions, conditional distributions, and distributions defined by density ratios.
Code Reference
Source Location
- File: pyro/distributions/rejector.py
- Repository: pyro-ppl/pyro
Signature
class Rejector(TorchDistribution):
    def __init__(self, propose, log_prob_accept, log_scale, *,
                 batch_shape=None, event_shape=None)
Import
from pyro.distributions import Rejector
I/O Contract
Inputs
| Parameter | Type | Description |
|---|---|---|
| propose | Distribution | A proposal distribution. Must support log_prob and be callable (for sampling). If sample_shape is used with rsample, the proposal must support a sample_shape argument. |
| log_prob_accept | callable | A function that takes a batch of proposed samples and returns a batch of log acceptance probabilities. Should be differentiable for reparameterized gradients. |
| log_scale | torch.Tensor or float | The total log probability of acceptance, used to normalize the acceptance function. |
| batch_shape | torch.Size or None | Optional batch shape. Defaults to the proposal's batch shape. |
| event_shape | torch.Size or None | Optional event shape. Defaults to the proposal's event shape. |
Outputs
| Method | Return Type | Description |
|---|---|---|
| rsample(sample_shape) | torch.Tensor | Returns a reparameterized sample obtained via rejection sampling. Shape is sample_shape + batch_shape + event_shape. |
| log_prob(x) | torch.Tensor | Returns the log probability, computed as proposal log_prob plus normalized acceptance log probability. |
| score_parts(x) | ScoreParts | Returns a ScoreParts named tuple with log_prob, score_function (the normalized acceptance log probability), and entropy_term. |
Usage Examples
import math

import torch
import pyro.distributions as dist

# Create a truncated normal distribution using Rejector
# Target: Normal(0, 1) truncated to [0, inf)
proposal = dist.Normal(0.0, 1.0)

def log_prob_accept(x):
    # Accept (log probability 0) if x >= 0, otherwise reject (log probability -inf)
    return torch.where(x >= 0, torch.zeros_like(x), torch.full_like(x, float('-inf')))

# log_scale = log(0.5) since half of the normal's mass lies in [0, inf)
log_scale = torch.tensor(math.log(0.5))
truncated_normal = dist.Rejector(proposal, log_prob_accept, log_scale)

# Sample from the truncated distribution
sample = truncated_normal.rsample((1000,))
print(sample.min())  # Should be >= 0

# Compute log probability
log_p = truncated_normal.log_prob(torch.tensor(0.5))
print(log_p)

import math

import torch
import pyro.distributions as dist

# Soft acceptance function (differentiable)
proposal = dist.Normal(0.0, 2.0)

def soft_log_accept(x):
    # Smoothly upweight samples near x=1; the value is always <= 0,
    # so exp(soft_log_accept(x)) is a valid acceptance probability
    return -0.5 * (x - 1.0) ** 2

# Log normalizing constant in closed form for this proposal and acceptance rule:
# E_{x ~ Normal(0, 2)}[exp(-0.5 * (x - 1)^2)] = exp(-0.1) / sqrt(5)
log_scale = torch.tensor(-0.5 * math.log(5.0) - 0.1)
rejector = dist.Rejector(proposal, soft_log_accept, log_scale)

# Use score_parts for gradient estimation
x = rejector.rsample()
parts = rejector.score_parts(x)
print(parts.log_prob)
print(parts.score_function)
Related Pages
- Pyro_ppl_Pyro_GaussianScaleMixture - Another distribution with specialized gradient handling (custom pathwise derivatives)
- Pyro_ppl_Pyro_TruncatedPolyaGamma - A truncated distribution that could alternatively be implemented via rejection sampling
- Pyro_ppl_Pyro_ImproperUniform - A non-standard distribution illustrating Pyro's flexible distribution system