

Implementation:Pyro ppl Pyro RelaxedStraightThrough

From Leeroopedia


Attribute | Value
Sources | pyro/distributions/relaxed_straight_through.py
Domains | Probabilistic Programming, Discrete Variational Inference, Gradient Estimation
Last Updated | 2026-02-09

Overview

Description

The RelaxedStraightThrough module provides straight-through gradient estimator variants of PyTorch's relaxed discrete distributions: RelaxedOneHotCategoricalStraightThrough and RelaxedBernoulliStraightThrough. These distributions combine the Gumbel-Softmax relaxation technique with the straight-through estimator to enable gradient-based optimization through discrete stochastic nodes in neural networks.

The key idea is that in the forward pass the distributions return discrete (quantized) samples, either hard one-hot vectors or binary 0/1 values. In the backward pass, gradients flow as if the relaxed (unquantized) continuous samples had been used. This is implemented with custom PyTorch autograd functions, QuantizeCategorical and QuantizeBernoulli, which quantize in the forward pass and act as the identity in the backward pass. The quantized tensor also carries the relaxed sample in an _unquantize attribute, so log_prob can evaluate the density of the relaxed sample.
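The quantize-forward / identity-backward split can be sketched with a plain torch.autograd.Function. The class name StraightThroughOneHot is hypothetical. It mirrors the behavior described for QuantizeCategorical, but it is not Pyro's source, and it omits the _unquantize bookkeeping.

```python
import torch


class StraightThroughOneHot(torch.autograd.Function):
    """Hypothetical sketch: hard one-hot in the forward pass, identity in the backward pass."""

    @staticmethod
    def forward(ctx, soft_value):
        # Forward: replace the relaxed sample with a one-hot vector at its argmax.
        index = soft_value.argmax(dim=-1, keepdim=True)
        return torch.zeros_like(soft_value).scatter_(-1, index, 1.0)

    @staticmethod
    def backward(ctx, grad):
        # Backward: pass the upstream gradient through unchanged (straight-through).
        return grad


soft = torch.tensor([0.1, 0.7, 0.2], requires_grad=True)
hard = StraightThroughOneHot.apply(soft)
weights = torch.tensor([1.0, 2.0, 3.0])
(hard * weights).sum().backward()
print(hard)       # one-hot at index 1
print(soft.grad)  # tensor([1., 2., 3.])
```

The gradient reaching the soft sample equals the gradient with respect to the hard sample, even though argmax itself has no useful derivative.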

The module implements the techniques described in two foundational papers:

  • "The Concrete Distribution: A Continuous Relaxation of Discrete Random Variables" by Maddison, Mnih, and Teh
  • "Categorical Reparameterization with Gumbel-Softmax" by Jang, Gu, and Poole
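The relaxation from these papers can be illustrated from scratch. The logits are perturbed with Gumbel(0, 1) noise, and a temperature-scaled softmax is applied to the result. The function below is a hypothetical illustration and does not reproduce Pyro's or PyTorch's internal code path.

```python
import torch


def gumbel_softmax_sample(logits, temperature, generator=None):
    """Return a relaxed one-hot sample and the Gumbel-perturbed logits."""
    # Gumbel(0, 1) noise via the inverse CDF: g = -log(-log(u)), u ~ Uniform(0, 1).
    u = torch.rand(logits.shape, generator=generator).clamp(min=1e-10)
    gumbel = -torch.log(-torch.log(u))
    perturbed = logits + gumbel
    # Temperature-scaled softmax: smaller temperatures push the output toward one-hot.
    relaxed = torch.softmax(perturbed / temperature, dim=-1)
    return relaxed, perturbed


g = torch.Generator().manual_seed(0)
logits = torch.log(torch.tensor([0.2, 0.3, 0.5]))
relaxed, perturbed = gumbel_softmax_sample(logits, temperature=0.5, generator=g)
print(relaxed)  # a point on the probability simplex
# Softmax is monotone, so the argmax of the relaxed sample is the Gumbel-max sample,
# which is an exact draw from Categorical(probs=[0.2, 0.3, 0.5]).
print(relaxed.argmax().item() == perturbed.argmax().item())  # True
```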

Usage

These distributions are designed for use in variational autoencoders (VAEs) and other models that require differentiable sampling from discrete distributions. The straight-through variant is preferred when the downstream computation requires actual discrete samples (e.g., selecting items, indexing) while still needing to backpropagate gradients for training.
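PyTorch also provides a straight-through Gumbel-Softmax, torch.nn.functional.gumbel_softmax(..., hard=True), which is useful for comparison. As we understand it, this function uses the expression y_hard - y_soft.detach() + y_soft rather than a custom autograd Function. It is therefore an analogous estimator, not Pyro's implementation. The item-selection setup below, including item_embeddings, is a hypothetical example of "selecting items".

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.zeros(4, requires_grad=True)  # preferences over 4 candidate items
item_embeddings = torch.randn(4, 8)          # hypothetical table of item vectors

# hard=True: the forward value is (numerically) one-hot; gradients use the soft sample.
one_hot = F.gumbel_softmax(logits, tau=0.5, hard=True)
selected = one_hot @ item_embeddings         # picks out a single row
loss = selected.pow(2).sum()
loss.backward()

print(one_hot)      # a single 1.0 among zeros
print(logits.grad)  # non-zero gradient despite the discrete selection
```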

Code Reference

Source Location

Property | Value
File | pyro/distributions/relaxed_straight_through.py
Module | pyro.distributions.relaxed_straight_through
Repository | pyro-ppl/pyro

Signature

class RelaxedOneHotCategoricalStraightThrough(RelaxedOneHotCategorical):
    def rsample(self, sample_shape=torch.Size()):
        ...
    def log_prob(self, value):
        ...

class RelaxedBernoulliStraightThrough(RelaxedBernoulli):
    def rsample(self, sample_shape=torch.Size()):
        ...
    def log_prob(self, value):
        ...

class QuantizeCategorical(torch.autograd.Function):
    @staticmethod
    def forward(ctx, soft_value):
        ...
    @staticmethod
    def backward(ctx, grad):
        ...

class QuantizeBernoulli(torch.autograd.Function):
    @staticmethod
    def forward(ctx, soft_value):
        ...
    @staticmethod
    def backward(ctx, grad):
        ...

Import

from pyro.distributions import RelaxedBernoulliStraightThrough
from pyro.distributions import RelaxedOneHotCategoricalStraightThrough

# Or from the module directly:
from pyro.distributions.relaxed_straight_through import (
    RelaxedBernoulliStraightThrough,
    RelaxedOneHotCategoricalStraightThrough,
)

I/O Contract

RelaxedOneHotCategoricalStraightThrough

Parameter | Type | Description
temperature | torch.Tensor | Temperature of the relaxation; lower values give relaxed samples closer to one-hot
probs / logits | torch.Tensor | Category probabilities or unnormalized log-probabilities (inherited from RelaxedOneHotCategorical)

Method | Return Type | Description
rsample(sample_shape) | torch.Tensor | Returns hard one-hot samples; gradients pass through the relaxed samples
log_prob(value) | torch.Tensor | If value came from rsample, returns the log density of the relaxed (unquantized) sample stored in its _unquantize attribute; otherwise scores value as given
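The log_prob row can be illustrated with two hypothetical helpers built on PyTorch's RelaxedOneHotCategorical. The hard sample remembers its relaxed sample, and scoring reads it back with getattr. For brevity, this sketch swaps in the detach trick in place of a custom autograd Function for the straight-through step.

```python
import torch
from torch.distributions import RelaxedOneHotCategorical


def st_rsample(d, sample_shape=torch.Size()):
    """Hypothetical helper: hard one-hot forward, relaxed gradient backward."""
    soft = d.rsample(sample_shape)
    index = soft.argmax(dim=-1, keepdim=True)
    hard = torch.zeros_like(soft).scatter_(-1, index, 1.0)
    # Straight-through via the detach trick: adds exactly zero in the forward pass.
    hard = hard + (soft - soft.detach())
    hard._unquantize = soft  # remember the relaxed sample for later scoring
    return hard


def st_log_prob(d, value):
    """Score the relaxed sample if value carries one, else score value as given."""
    value = getattr(value, "_unquantize", value)
    return d.log_prob(value)


torch.manual_seed(0)
d = RelaxedOneHotCategorical(torch.tensor(0.5), probs=torch.tensor([0.2, 0.3, 0.5]))
hard = st_rsample(d)
print(hard)                  # one-hot values
print(st_log_prob(d, hard))  # finite log density of the relaxed sample
```

Scoring the relaxed sample instead of the exact one-hot vector matters because a hard one-hot vector lies on the boundary of the simplex, where the relaxed density is not finite.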

RelaxedBernoulliStraightThrough

Parameter | Type | Description
temperature | torch.Tensor | Temperature of the relaxation; lower values give relaxed samples closer to 0 or 1
probs / logits | torch.Tensor | Bernoulli probability or log-odds (inherited from RelaxedBernoulli)

Method | Return Type | Description
rsample(sample_shape) | torch.Tensor | Returns hard binary (0/1) samples; gradients pass through the relaxed samples
log_prob(value) | torch.Tensor | If value came from rsample, returns the log density of the relaxed (unquantized) sample stored in its _unquantize attribute; otherwise scores value as given
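The effect of the temperature can be checked with the PyTorch RelaxedBernoulli distribution that this class builds on. The check below measures how far relaxed samples sit from 0.5 at several temperatures. The temperatures and sample size are arbitrary choices for illustration.

```python
import torch
from torch.distributions import RelaxedBernoulli

torch.manual_seed(0)
logits = torch.zeros(10_000)  # p = 0.5 for every coordinate

spread = {}
for temperature in (0.1, 1.0, 5.0):
    soft = RelaxedBernoulli(torch.tensor(temperature), logits=logits).rsample()
    # Mean distance from 0.5: values near 0.5 mean the samples pile up at 0 and 1.
    spread[temperature] = (soft - 0.5).abs().mean().item()
    print(temperature, spread[temperature])
```

Low temperatures make the relaxed samples nearly binary before any rounding takes place. This narrows the gap between the quantized forward value and the relaxed sample that drives the gradient.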

QuantizeCategorical (Autograd Function)

Direction | Input | Output | Description
Forward | Soft sample (relaxed one-hot, on the simplex) | Hard one-hot vector | Takes the argmax over the last dimension and scatters a 1 there; attaches the soft value to the output as its _unquantize attribute
Backward | Upstream gradient | Same gradient (identity) | Straight-through: passes the gradient unchanged

QuantizeBernoulli (Autograd Function)

Direction | Input | Output | Description
Forward | Soft sample (continuous, in [0, 1]) | Binary value (0 or 1) | Applies round() to produce a discrete value; attaches the soft value to the output as its _unquantize attribute
Backward | Upstream gradient | Same gradient (identity) | Straight-through: passes the gradient unchanged
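A from-scratch analogue of the Bernoulli path combines two steps. First, logistic noise and a temperature-scaled sigmoid give the relaxed sample. Then rounding is applied with an identity backward. The names StraightThroughRound and relaxed_bernoulli_st are hypothetical, and this is a sketch rather than Pyro's code.

```python
import torch


class StraightThroughRound(torch.autograd.Function):
    """Hypothetical analogue of the Bernoulli quantizer: round forward, identity backward."""

    @staticmethod
    def forward(ctx, soft_value):
        # Forward: snap each relaxed value in (0, 1) to 0 or 1.
        return soft_value.round()

    @staticmethod
    def backward(ctx, grad):
        # Backward: treat rounding as the identity and pass the gradient through.
        return grad


def relaxed_bernoulli_st(logits, temperature, generator=None):
    """Return (hard 0/1 sample, relaxed sample) for a relaxed Bernoulli."""
    # Logistic noise log(u) - log(1 - u), added to the logits and scaled by temperature.
    u = torch.rand(logits.shape, generator=generator).clamp(1e-6, 1 - 1e-6)
    noise = torch.log(u) - torch.log1p(-u)
    soft = torch.sigmoid((logits + noise) / temperature)
    return StraightThroughRound.apply(soft), soft


g = torch.Generator().manual_seed(0)
logits = torch.zeros(5, requires_grad=True)
hard, soft = relaxed_bernoulli_st(logits, temperature=0.67, generator=g)
hard.sum().backward()
print(hard)         # entries are 0. or 1.
print(logits.grad)  # equals soft * (1 - soft) / 0.67, the gradient of the relaxed sample
```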

Usage Examples

Straight-Through Gumbel-Softmax for Categorical Variables

import torch
import pyro
import pyro.distributions as dist

# Define a straight-through categorical distribution
temperature = torch.tensor(0.5)
probs = torch.tensor([0.2, 0.3, 0.5])
st_cat = dist.RelaxedOneHotCategoricalStraightThrough(temperature, probs=probs)

# Sample returns discrete one-hot vectors
sample = st_cat.rsample()  # e.g., tensor([0., 0., 1.])
print(sample)  # Hard one-hot sample

# log_prob scores the relaxed sample that rsample attached to the hard sample
log_p = st_cat.log_prob(sample)
print(log_p)
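Softmax is monotone, so taking the argmax of a Gumbel-Softmax sample reproduces the Gumbel-max trick. The hard one-hot samples are therefore exact draws from Categorical(probs) at any temperature; the bias of the straight-through estimator lies in its gradients. A quick empirical check with PyTorch's RelaxedOneHotCategorical follows. The sample size and tolerance are arbitrary.

```python
import torch
from torch.distributions import RelaxedOneHotCategorical

torch.manual_seed(0)
probs = torch.tensor([0.2, 0.3, 0.5])
relaxed = RelaxedOneHotCategorical(torch.tensor(0.5), probs=probs)

soft = relaxed.rsample((20_000,))  # shape (20000, 3)
hard_index = soft.argmax(dim=-1)   # the category an argmax quantizer would pick
freq = torch.bincount(hard_index, minlength=3).float() / hard_index.numel()
print(freq)  # close to probs
```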

Straight-Through Bernoulli in a VAE Encoder

import torch
import pyro
import pyro.distributions as dist

# Binary latent variable with straight-through gradient
temperature = torch.tensor(0.67)
logits = torch.randn(10, requires_grad=True)  # must require grad for backward() to work
st_bern = dist.RelaxedBernoulliStraightThrough(temperature, logits=logits)

# Forward pass gives discrete 0/1 samples
z = st_bern.rsample()  # Binary tensor of shape (10,)
print(z)  # e.g., tensor([1., 0., 1., 1., 0., 0., 1., 0., 1., 0.], grad_fn=...)

# Gradients flow through the relaxed sample in the backward pass
loss = z.sum()
loss.backward()
print(logits.grad)  # non-zero, computed from the relaxed (unquantized) sample
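A VAE encoder setting can be sketched with plain PyTorch. A hypothetical linear encoder produces logits for 10 binary latents. These latents are straight-through rounded and decoded back to the input. This toy example trains only the reconstruction term. A real VAE would add the KL (ELBO) term. It also swaps in the detach trick in place of Pyro's autograd Function.

```python
import torch
import torch.nn as nn
from torch.distributions import RelaxedBernoulli

torch.manual_seed(0)
encoder = nn.Linear(20, 10)  # hypothetical encoder: x -> logits of 10 binary latents
decoder = nn.Linear(10, 20)  # hypothetical decoder: z -> reconstruction of x
params = list(encoder.parameters()) + list(decoder.parameters())
optimizer = torch.optim.Adam(params, lr=1e-2)

x = torch.rand(32, 20)  # toy batch of data in [0, 1]
temperature = torch.tensor(0.67)

soft = RelaxedBernoulli(temperature, logits=encoder(x)).rsample()
# Straight-through rounding: forward value is exactly 0/1, gradient flows via soft.
z = soft.round() + (soft - soft.detach())
recon = torch.sigmoid(decoder(z))
loss = nn.functional.binary_cross_entropy(recon, x)  # reconstruction term only

optimizer.zero_grad()
loss.backward()
optimizer.step()
print(loss.item(), encoder.weight.grad.abs().sum().item())
```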

Using in a Pyro Model

import torch
import pyro
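import pyro.distributions as dist
from torch.distributions import constraints

def model(data):
    # Keep the category probabilities on the simplex during optimization
    probs = pyro.param("probs", torch.ones(3) / 3, constraint=constraints.simplex)
    # Learnable per-category means, instead of fresh random weights on every call
    weights = pyro.param("weights", torch.randn(3))
    temperature = torch.tensor(0.5)
    with pyro.plate("data", len(data)):
        z = pyro.sample(
            "z",
            dist.RelaxedOneHotCategoricalStraightThrough(temperature, probs=probs),
        )
        # z has shape (len(data), 3); each row is a hard one-hot vector
        pyro.sample("obs", dist.Normal(z @ weights, 1.0), obs=data)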

Related Pages

  • RelaxedOneHotCategorical -- Base PyTorch distribution providing the Gumbel-Softmax relaxation for categorical variables
  • RelaxedBernoulli -- Base PyTorch distribution providing the relaxed Bernoulli (logistic sigmoid)
  • OneHotCategoricalStraightThrough -- PyTorch's built-in straight-through variant for OneHotCategorical
  • Delta -- Point-mass distribution for deterministic values
  • Distributions_Init -- Central registry of all Pyro distributions including these straight-through variants
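PyTorch's OneHotCategoricalStraightThrough, listed above, uses a different estimator. As we read PyTorch's source, its rsample draws an exact one-hot sample and returns sample + (probs - probs.detach()). Gradients therefore flow through the category probabilities, with no Gumbel relaxation or temperature involved. The weights below are arbitrary.

```python
import torch
from torch.distributions import OneHotCategoricalStraightThrough

torch.manual_seed(0)
logits = torch.zeros(3, requires_grad=True)
d = OneHotCategoricalStraightThrough(logits=logits)

sample = d.rsample()  # exact one-hot draw from Categorical(softmax(logits))
weights = torch.tensor([1.0, 2.0, 3.0])
(sample * weights).sum().backward()
print(sample)
# With uniform probs the gradient is p * (w - p.w) = [-1/3, 0, 1/3], whichever category was drawn.
print(logits.grad)
```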
