Implementation:Pyro ppl Pyro Rejector

From Leeroopedia


Knowledge Sources
Domains Probability_Distributions
Last Updated 2026-02-09 09:00 GMT

Overview

Description

Rejector is a distribution class in Pyro that implements rejection sampling with a differentiable acceptance rate function. It extends TorchDistribution and supports reparameterized sampling (has_rsample = True).

The distribution is defined by three components:

  • propose - A proposal distribution from which candidate samples are drawn.
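  • log_prob_accept - A callable that takes a batch of proposed samples and returns a batch of log acceptance probabilities. Since these are log probabilities, the values should be <= 0; rsample clamps exp(log_prob_accept(x)) into [0, 1].
  • log_scale - The total log probability of acceptance, i.e. the log of the expected acceptance probability under the proposal, log E_propose[exp(log_prob_accept(x))]. It is the log normalizing constant of the accepted-sample density.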
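Because log_scale is the log of the average acceptance probability under the proposal, it can be sanity-checked by Monte Carlo when no closed form is at hand. Below is a minimal sketch in plain PyTorch that assumes a soft acceptance function exp(-0.5 * (x - 1)**2) and a Normal(0, 2) proposal. For this pair the exact value, from a Gaussian integral, is -0.5 * log(5) - 0.1 ≈ -0.9047. The helper name estimate_log_scale is hypothetical and not part of Pyro.

```python
import math

import torch
from torch.distributions import Normal


def estimate_log_scale(proposal, log_prob_accept, num_samples=1_000_000):
    """Hypothetical helper: Monte Carlo estimate of log E_proposal[exp(log_prob_accept(x))]."""
    x = proposal.sample((num_samples,))
    # logsumexp(...) - log(N) is a numerically stable log of the sample mean.
    return torch.logsumexp(log_prob_accept(x), dim=0) - math.log(num_samples)


def soft_log_accept(x):
    # Acceptance probability exp(-0.5 * (x - 1)**2) lies in (0, 1].
    return -0.5 * (x - 1.0) ** 2


torch.manual_seed(0)
proposal = Normal(0.0, 2.0)
estimate = estimate_log_scale(proposal, soft_log_accept)
exact = -0.5 * math.log(5.0) - 0.1
print(estimate.item(), exact)
```

With a million proposals, the Monte Carlo error on this estimate is on the order of 1e-3.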

The rsample method implements parallel batched accept-reject sampling. It draws a batch of proposals and accepts each position with probability exp(log_prob_accept(x)), clamped to [0, 1]. It then keeps drawing fresh full batches of proposals until every position has been accepted. A boolean mask of finished positions ensures that each new round only overwrites positions that have not yet been accepted, so accepted values are never replaced.
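The accept-reject loop described above can be sketched in plain PyTorch as follows. This is an illustration under stated assumptions, not Pyro's source. batched_accept_reject is a hypothetical name, and the proposal is passed as a function of sample_shape. The sketch also uses torch.where to merge newly accepted values instead of in-place masked assignment.

```python
import torch


def batched_accept_reject(propose_fn, log_prob_accept, sample_shape):
    """Hypothetical sketch of parallel batched accept-reject sampling."""
    x = propose_fn(sample_shape)
    # Acceptance probabilities are clamped into [0, 1] before the coin flips.
    prob = log_prob_accept(x).exp().clamp(0.0, 1.0)
    done = torch.bernoulli(prob).bool()
    while not done.all():
        # Draw a fresh full batch of proposals each round.
        proposed = propose_fn(sample_shape)
        prob = log_prob_accept(proposed).exp().clamp(0.0, 1.0)
        # Only positions that are not yet done may accept a new value.
        accept = torch.bernoulli(prob).bool() & ~done
        x = torch.where(accept, proposed, x)
        done = done | accept
    return x


def accept_nonnegative(x):
    # log 1 = 0 for x >= 0, log 0 = -inf otherwise (hard truncation).
    return torch.where(x >= 0, torch.zeros_like(x), torch.full_like(x, float("-inf")))


torch.manual_seed(0)
samples = batched_accept_reject(torch.randn, accept_nonnegative, (1000,))
print(samples.shape, bool((samples >= 0).all()))  # torch.Size([1000]) True
```

Using torch.where keeps the merge out-of-place, which is simpler to reason about under autograd. The masking logic (accept only where not yet done) is the same as described above.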

The log_prob method computes the log probability as the sum of the proposal's log probability and the normalized log acceptance probability:

log_prob(x) = propose.log_prob(x) + log_prob_accept(x) - log_scale
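To check this formula numerically, the sketch below evaluates it on a dense grid and integrates exp(log_prob) with the trapezoid rule. It assumes a standard normal proposal truncated to [0, inf), with acceptance probability 1 for x >= 0 and 0 otherwise, so log_scale = log 0.5. It uses plain torch.distributions.Normal rather than Rejector itself, and rejector_log_prob is a hypothetical name.

```python
import math

import torch
from torch.distributions import Normal

proposal = Normal(torch.tensor(0.0, dtype=torch.float64), torch.tensor(1.0, dtype=torch.float64))


def log_prob_accept(x):
    # Hard truncation to [0, inf): log 1 = 0 inside, log 0 = -inf outside.
    return torch.where(x >= 0, torch.zeros_like(x), torch.full_like(x, float("-inf")))


log_scale = math.log(0.5)


def rejector_log_prob(x):
    # propose.log_prob(x) + log_prob_accept(x) - log_scale
    return proposal.log_prob(x) + log_prob_accept(x) - log_scale


grid = torch.linspace(-10.0, 10.0, 200_001, dtype=torch.float64)
total = torch.trapezoid(rejector_log_prob(grid).exp(), grid)
print(total.item())  # close to 1.0
```

Without the "- log_scale" term, the same integral would come out close to 0.5 instead of 1.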

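The score_parts method returns a ScoreParts named tuple that separates the score-function term (the normalized acceptance log probability, log_prob_accept(x) - log_scale) from the full log probability. This matters for gradient estimators in variational inference. The accept-reject step itself is not reparameterizable, so its contribution to the gradient has to be supplied as a score-function (REINFORCE-style) term alongside the pathwise gradient through the proposal.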

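The class keeps LRU(1) caches for both _log_prob_accept and _propose_log_prob. When the same sample tensor is passed to several methods (for example log_prob and then score_parts), the proposal log density and the acceptance log probability are computed only once. The cache is keyed on tensor identity, not value, so a copy of the sample triggers recomputation.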
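This caching pattern can be sketched as a tiny wrapper. The sketch assumes a single-entry cache keyed on object identity (an `is` check), which avoids comparing tensor values on every call. IdentityLRU1 is a hypothetical name used only for illustration.

```python
import torch


class IdentityLRU1:
    """Hypothetical LRU(1) cache keyed on object identity rather than value."""

    def __init__(self, fn):
        self.fn = fn
        self._key = None
        self._value = None
        self.num_computations = 0

    def __call__(self, x):
        # Recompute only when a different object is passed in.
        if x is not self._key:
            self._key, self._value = x, self.fn(x)
            self.num_computations += 1
        return self._value


cached_square = IdentityLRU1(lambda t: t ** 2)
x = torch.arange(3.0)
cached_square(x)
cached_square(x)          # same object: served from the cache
cached_square(x.clone())  # equal values, different object: recomputed
print(cached_square.num_computations)  # 2
```

An identity check is constant-time, while a value comparison would cost as much as scanning the tensor.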

Usage

Rejector is useful for constructing custom distributions whose target density is proportional to a proposal density multiplied by an acceptance probability in [0, 1]. It is particularly valuable when the acceptance probability is differentiable, enabling gradient-based inference. Common use cases include truncated distributions, conditional distributions, and distributions defined by density ratios.

Code Reference

Source Location

Signature

class Rejector(TorchDistribution):
    def __init__(self, propose, log_prob_accept, log_scale, *,
                 batch_shape=None, event_shape=None)

Import

from pyro.distributions import Rejector

I/O Contract

Inputs

Parameter | Type | Description
propose | Distribution | A proposal distribution. Must support log_prob and be callable; calling it draws a batch of proposals. rsample supports a sample_shape argument only if the proposal, when called, also accepts a sample_shape argument.
log_prob_accept | callable | A function that takes a batch of proposed samples and returns a batch of log acceptance probabilities (values should be <= 0). Should be differentiable for reparameterized gradients.
log_scale | torch.Tensor or float | The total log probability of acceptance under the proposal, used to normalize the acceptance function.
batch_shape | torch.Size or None | Optional, keyword-only. Defaults to the proposal's batch shape.
event_shape | torch.Size or None | Optional, keyword-only. Defaults to the proposal's event shape.

Outputs

Method | Return type | Description
rsample(sample_shape) | torch.Tensor | A reparameterized sample obtained via rejection sampling, with shape sample_shape + batch_shape + event_shape.
log_prob(x) | torch.Tensor | The log probability: proposal log_prob plus the normalized acceptance log probability (log_prob_accept(x) - log_scale).
score_parts(x) | ScoreParts | A ScoreParts named tuple with fields log_prob, score_function (the normalized acceptance log probability) and entropy_term (equal to log_prob).

Usage Examples

import math

import torch
import pyro.distributions as dist

# Create a truncated normal distribution using Rejector
# Target: Normal(0, 1) truncated to [0, inf)
proposal = dist.Normal(0.0, 1.0)

def log_prob_accept(x):
    # Log acceptance probability: log 1 = 0 if x >= 0, log 0 = -inf otherwise
    return torch.where(x >= 0, torch.zeros_like(x), torch.full_like(x, float("-inf")))

# log_scale = log(0.5), since half of the normal's mass lies in [0, inf)
log_scale = torch.tensor(math.log(0.5))

truncated_normal = dist.Rejector(proposal, log_prob_accept, log_scale)

# Draw 1000 samples from the truncated distribution
samples = truncated_normal.rsample((1000,))
print(samples.min())  # should be >= 0

# Log probability at x = 0.5: Normal(0, 1) log density + log(2)
log_p = truncated_normal.log_prob(torch.tensor(0.5))
print(log_p)  # about -0.3508
Example with a soft (differentiable) acceptance function:

import math

import torch
import pyro.distributions as dist

proposal = dist.Normal(0.0, 2.0)

def soft_log_accept(x):
    # Smoothly upweight samples near x = 1; exp(-0.5 * (x - 1)**2) lies in (0, 1]
    return -0.5 * (x - 1.0) ** 2

# Exact log acceptance rate, log E[exp(-0.5 * (x - 1)**2)] under Normal(0, 2),
# which is -0.5 * log(5) - 0.1 (a Gaussian integral). A wrong constant here
# would leave log_prob unnormalized.
log_scale = torch.tensor(-0.5 * math.log(5.0) - 0.1)

rejector = dist.Rejector(proposal, soft_log_accept, log_scale)

# Use score_parts for gradient estimation
x = rejector.rsample()
parts = rejector.score_parts(x)
print(parts.log_prob)
print(parts.score_function)
