Overview
Description
The relaxed_straight_through module provides straight-through gradient estimator variants of PyTorch's relaxed discrete distributions: RelaxedOneHotCategoricalStraightThrough and RelaxedBernoulliStraightThrough. These distributions combine the Gumbel-Softmax relaxation with the straight-through estimator to enable gradient-based optimization through discrete stochastic nodes in neural networks.
The key idea is that the forward pass produces discrete (quantized) samples, that is, hard one-hot vectors or binary values, while the backward pass propagates gradients as if the relaxed (unquantized) continuous samples had been used. This is achieved through custom PyTorch autograd functions (QuantizeCategorical and QuantizeBernoulli) that quantize in the forward pass and act as the identity in the backward pass.
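As a rough illustration of this forward/backward split, here is a minimal sketch of a straight-through rounding function in plain PyTorch. The class name RoundStraightThrough is hypothetical and this is not the module's source; it only assumes, as described above, that the forward pass quantizes and the backward pass is the identity.

```python
import torch


class RoundStraightThrough(torch.autograd.Function):
    """Hypothetical stand-in: round in forward, identity in backward."""

    @staticmethod
    def forward(ctx, soft_value):
        # Quantize the continuous value to 0 or 1.
        return soft_value.round()

    @staticmethod
    def backward(ctx, grad):
        # Straight-through: hand the upstream gradient back unchanged.
        return grad


soft = torch.tensor([0.2, 0.7, 0.9], requires_grad=True)
hard = RoundStraightThrough.apply(soft)

# Weight each element so the gradient is easy to recognize.
weights = torch.tensor([1.0, 2.0, 3.0])
(hard * weights).sum().backward()

print(hard)       # quantized values
print(soft.grad)  # equals `weights`, as if no rounding had happened
```

With a plain call to round() instead, PyTorch would report a zero gradient for soft, which is why the custom backward is needed.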
The module implements the techniques described in two foundational papers:
- "The Concrete Distribution: A Continuous Relaxation of Discrete Random Variables" by Maddison, Mnih, and Teh
- "Categorical Reparameterization with Gumbel-Softmax" by Jang, Gu, and Poole
Usage
These distributions are designed for use in variational autoencoders (VAEs) and other models that require differentiable sampling from discrete distributions. The straight-through variant is preferred when the downstream computation requires actual discrete samples (e.g., selecting items, indexing) while still needing to backpropagate gradients for training.
Code Reference
Source Location
| Property | Value |
|---|---|
| File | pyro/distributions/relaxed_straight_through.py |
| Module | pyro.distributions.relaxed_straight_through |
| Repository | pyro-ppl/pyro |
Signature
class RelaxedOneHotCategoricalStraightThrough(RelaxedOneHotCategorical):
    def rsample(self, sample_shape=torch.Size()):
        ...
    def log_prob(self, value):
        ...

class RelaxedBernoulliStraightThrough(RelaxedBernoulli):
    def rsample(self, sample_shape=torch.Size()):
        ...
    def log_prob(self, value):
        ...

class QuantizeCategorical(torch.autograd.Function):
    @staticmethod
    def forward(ctx, soft_value):
        ...
    @staticmethod
    def backward(ctx, grad):
        ...

class QuantizeBernoulli(torch.autograd.Function):
    @staticmethod
    def forward(ctx, soft_value):
        ...
    @staticmethod
    def backward(ctx, grad):
        ...
Import
from pyro.distributions import RelaxedBernoulliStraightThrough
from pyro.distributions import RelaxedOneHotCategoricalStraightThrough
# Or from the module directly:
from pyro.distributions.relaxed_straight_through import (
    RelaxedBernoulliStraightThrough,
    RelaxedOneHotCategoricalStraightThrough,
)
I/O Contract
RelaxedOneHotCategoricalStraightThrough
| Parameter | Type | Description |
|---|---|---|
| temperature | torch.Tensor | Temperature parameter controlling relaxation (lower = more discrete) |
| probs / logits | torch.Tensor | Category probabilities or unnormalized log-probabilities (inherited from RelaxedOneHotCategorical) |
| Method | Return Type | Description |
|---|---|---|
| rsample(sample_shape) | torch.Tensor | Returns discrete one-hot samples; gradients pass through relaxed samples |
| log_prob(value) | torch.Tensor | Returns log probability of the relaxed (unquantized) sample |
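Since rsample always returns a hard one-hot vector, the effect of temperature is only visible in the underlying relaxed sample, which is what the gradients see. The sketch below uses the PyTorch base class RelaxedOneHotCategorical, not the straight-through variant, to compare how peaked the relaxed samples are at a low and a high temperature. The helper mean_peak and the chosen temperatures are illustrative assumptions.

```python
import torch
from torch.distributions import RelaxedOneHotCategorical

torch.manual_seed(0)
probs = torch.tensor([0.2, 0.3, 0.5])


def mean_peak(temperature, n=2000):
    """Average of the largest component over n relaxed samples (hypothetical helper)."""
    d = RelaxedOneHotCategorical(torch.tensor(temperature), probs=probs)
    samples = d.sample((n,))  # shape (n, 3); each row lies on the simplex
    return samples.max(-1).values.mean().item()


low = mean_peak(0.1)   # close to one-hot: largest component near 1
high = mean_peak(5.0)  # close to uniform: largest component well below 1
print(low, high)
```

Lower temperatures make the relaxed sample closer to the hard sample returned in the forward pass, at the cost of higher-variance gradients, as discussed in the two papers cited above.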
RelaxedBernoulliStraightThrough
| Parameter | Type | Description |
|---|---|---|
| temperature | torch.Tensor | Temperature parameter controlling relaxation |
| probs / logits | torch.Tensor | Bernoulli probability or log-odds (inherited from RelaxedBernoulli) |
| Method | Return Type | Description |
|---|---|---|
| rsample(sample_shape) | torch.Tensor | Returns discrete binary (0/1) samples; gradients pass through relaxed samples |
| log_prob(value) | torch.Tensor | Returns log probability of the relaxed (unquantized) sample |
QuantizeCategorical (Autograd Function)
| Direction | Input | Output | Description |
|---|---|---|---|
| Forward | Soft sample (point on the probability simplex) | Hard one-hot vector | Applies argmax and scatter to produce a one-hot encoding; stores soft value as _unquantize attribute |
| Backward | Upstream gradient | Same gradient (identity) | Straight-through: passes gradient unchanged |
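The argmax-and-scatter step in the forward row can be sketched in plain PyTorch as follows. This is an illustration of the technique on a made-up batch of soft samples, not the library's code, and it omits the _unquantize bookkeeping and the custom backward.

```python
import torch

# Two soft samples over three categories (each row sums to 1).
soft = torch.tensor([[0.1, 0.7, 0.2],
                     [0.5, 0.3, 0.2]])

# Index of the largest component per row, kept as a trailing dimension
# so it can be used directly as a scatter index.
index = soft.argmax(dim=-1, keepdim=True)

# Start from zeros and write a 1 at the argmax position of each row.
hard = torch.zeros_like(soft).scatter_(-1, index, 1.0)
print(hard)
```

Working along the last dimension keeps the operation valid for any number of leading batch dimensions.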
QuantizeBernoulli (Autograd Function)
| Direction | Input | Output | Description |
|---|---|---|---|
| Forward | Soft sample (continuous value in [0, 1]) | Binary value (0 or 1) | Applies round() to produce discrete value; stores soft value as _unquantize attribute |
| Backward | Upstream gradient | Same gradient (identity) | Straight-through: passes gradient unchanged |
Usage Examples
Straight-Through Gumbel-Softmax for Categorical Variables
import torch
import pyro
import pyro.distributions as dist
# Define a straight-through categorical distribution
temperature = torch.tensor(0.5)
probs = torch.tensor([0.2, 0.3, 0.5])
st_cat = dist.RelaxedOneHotCategoricalStraightThrough(temperature, probs=probs)
# Sample returns discrete one-hot vectors
sample = st_cat.rsample() # e.g., tensor([0., 0., 1.])
print(sample) # Hard one-hot sample
# Log prob uses the underlying relaxed distribution
log_p = st_cat.log_prob(sample)
print(log_p)
Straight-Through Bernoulli in a VAE Encoder
import torch
import pyro
import pyro.distributions as dist
# Binary latent variable with straight-through gradient
temperature = torch.tensor(0.67)
logits = torch.randn(10, requires_grad=True)  # must require grad so backward() has a target
st_bern = dist.RelaxedBernoulliStraightThrough(temperature, logits=logits)
# Forward pass gives discrete 0/1 samples
z = st_bern.rsample() # Binary tensor of shape (10,)
print(z) # e.g., tensor([1., 0., 1., 1., 0., 0., 1., 0., 1., 0.], grad_fn=...)
# Gradients flow through the relaxed sample in the backward pass
loss = z.sum()
loss.backward()
print(logits.grad) # Gradient with respect to the logits, shape (10,)
Using in a Pyro Model
import torch
import pyro
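import pyro.distributions as dist
from torch.distributions import constraints

def model(data):
    # Constrain the learnable probabilities to stay on the simplex
    probs = pyro.param("probs", torch.ones(3) / 3, constraint=constraints.simplex)
    # One mean per category, initialized once rather than redrawn on every call
    locs = pyro.param("locs", torch.randn(3))
    temperature = torch.tensor(0.5)
    with pyro.plate("data", len(data)):
        z = pyro.sample(
            "z",
            dist.RelaxedOneHotCategoricalStraightThrough(temperature, probs=probs),
        )
        # z is a hard one-hot vector usable for discrete operations
        pyro.sample("obs", dist.Normal(z @ locs, 1.0), obs=data)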
Related Pages
- RelaxedOneHotCategorical -- Base PyTorch distribution providing the Gumbel-Softmax relaxation for categorical variables
- RelaxedBernoulli -- Base PyTorch distribution providing the relaxed Bernoulli (a Logistic sample passed through a sigmoid)
- OneHotCategoricalStraightThrough -- PyTorch's built-in straight-through variant for OneHotCategorical
- Delta -- Distribution for deterministic values, another approach to discrete representations
- Distributions_Init -- Central registry of all Pyro distributions including these straight-through variants