Implementation:Pyro ppl Pyro FoldedDistribution

From Leeroopedia
Revision as of 16:23, 16 February 2026 by Admin (auto-imported from implementations/Pyro_ppl_Pyro_FoldedDistribution.md)


Knowledge Sources
Domains Probability_Distributions
Last Updated 2026-02-09 09:00 GMT

Overview

Description

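FoldedDistribution implements a "folded" version of any univariate base distribution by applying the absolute value transform. It is equivalent to TransformedDistribution(base_dist, AbsTransform()) but additionally supports log_prob. The absolute value transform is not invertible, so the usual change-of-variables computation for a transformed distribution does not apply, and the class supplies its own log_prob that accounts for the folding.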

The folding operation takes a distribution supported on the real line and maps it to the positive reals by reflecting negative values through zero. The resulting distribution has support constraints.positive. The log_prob method computes the log probability by summing the densities at both +value and -value under the base distribution using logsumexp:

log_prob(value) = logsumexp(base_dist.log_prob(+value), base_dist.log_prob(-value))

This is mathematically correct because the folded density at a point x > 0 is f(x) + f(-x), where f is the base density. The class restricts the base distribution to be univariate (empty event_shape) and raises a ValueError otherwise.
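The identity above can be checked numerically with plain PyTorch. The following is a minimal sketch, not Pyro's implementation: folded_log_prob is a hypothetical helper introduced only for illustration, and the Normal(2, 1) base and the integration grid are arbitrary choices.

```python
import torch
from torch.distributions import Normal


def folded_log_prob(base, value):
    """Hypothetical helper: log(f(value) + f(-value)) for a univariate base."""
    stacked = torch.stack([base.log_prob(value), base.log_prob(-value)])
    return torch.logsumexp(stacked, dim=0)


base = Normal(torch.tensor(2.0), torch.tensor(1.0))

# Compare the logsumexp form with the direct sum of the two densities.
x = torch.tensor(1.5)
direct = torch.log(base.log_prob(x).exp() + base.log_prob(-x).exp())
stable = folded_log_prob(base, x)
print("direct:", direct.item(), "logsumexp:", stable.item())

# The folded density should integrate to roughly 1 over x >= 0.
grid = torch.linspace(0.0, 12.0, 20001)
area = torch.trapz(folded_log_prob(base, grid).exp(), grid)
print("area under folded density:", area.item())
```

Working in log space avoids underflow when one of the two densities is very small, which is the usual reason to prefer logsumexp over exponentiating and adding.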

The class extends TransformedDistribution and inherits sampling behavior from the absolute value transform while overriding log_prob and expand for correctness.

Usage

FoldedDistribution is commonly used to construct folded normal distributions (also known as half-normal when the mean is zero) and other folded variants. It is useful in Bayesian modeling when a parameter is known to be positive but the natural model involves a symmetric distribution. For example, a folded normal prior on a scale parameter arises naturally when the scale is the absolute value of a normally distributed latent variable.

Code Reference

Source Location

pyro/distributions/folded.py

Signature

class FoldedDistribution(TransformedDistribution):
    support = constraints.positive

    def __init__(self, base_dist, validate_args=None):
        ...

Import

from pyro.distributions import FoldedDistribution

I/O Contract

Inputs

| Parameter | Type | Description |
| --- | --- | --- |
| base_dist | torch.distributions.Distribution | A univariate distribution to fold. Must have empty event_shape (i.e., scalar-valued). |
| validate_args | bool or None | Whether to validate input arguments. Defaults to None, in which case the global default validation setting applies. |

Outputs

| Method | Return Type | Description |
| --- | --- | --- |
| sample(sample_shape) | torch.Tensor | Draws samples from the base distribution and applies the absolute value transform. Returns positive values. |
| log_prob(value) | torch.Tensor | Computes the folded log probability as logsumexp(base.log_prob(+value), base.log_prob(-value)). |
| expand(batch_shape) | FoldedDistribution | Returns a new FoldedDistribution with expanded batch dimensions. |

Usage Examples

import torch
from pyro.distributions import FoldedDistribution
from torch.distributions import Normal

# Create a folded normal distribution (fold a Normal(2, 1) distribution)
base = Normal(torch.tensor(2.0), torch.tensor(1.0))
folded = FoldedDistribution(base)

# Draw samples (all positive)
samples = folded.sample((1000,))
print("Min sample:", samples.min().item())  # >= 0
print("Mean sample:", samples.mean().item())

# Compute log probability
value = torch.tensor(1.5)
log_p = folded.log_prob(value)
print("log_prob(1.5):", log_p.item())

# Half-normal distribution (folded standard normal)
half_normal = FoldedDistribution(Normal(0.0, 1.0))
samples = half_normal.sample((1000,))
print("Half-normal mean:", samples.mean().item())  # approximately sqrt(2/pi) ~ 0.798

# Batched usage
base_batch = Normal(torch.zeros(5), torch.ones(5))
folded_batch = FoldedDistribution(base_batch)
samples = folded_batch.sample((100,))
print("Batched sample shape:", samples.shape)  # torch.Size([100, 5])
