Implementation:Pyro ppl Pyro Special
| Property | Value |
|---|---|
| Module | pyro.ops.special |
| Source | pyro/ops/special.py |
| Lines | 219 |
| Functions | safe_log, log_beta, log_binomial, log_I1, get_quad_rule, sparse_multinomial_likelihood |
| Dependencies | torch, numpy (for Gauss-Hermite quadrature) |
Overview
This module provides special mathematical functions needed by Pyro's distributions and inference algorithms. The functions are designed to be numerically stable and differentiable, with several offering tunable accuracy-speed tradeoffs through tolerance parameters.
Key capabilities include:
- Safe logarithm: Avoids infinite gradients at log(0).
- Log Beta function: With an optional shifted Stirling's approximation for faster computation.
- Log binomial coefficient: Using the log Beta function.
- Log modified Bessel functions of the first kind: For von Mises and related distributions.
- Gauss-Hermite quadrature: For numerical integration against Gaussian measures.
- Sparse multinomial likelihood: Efficient computation for sparse count data.
Code Reference
Function: safe_log(x)
Computes torch.log(x) but avoids the infinite gradient at log(0). It uses a custom torch.autograd.Function whose backward pass divides the incoming gradient by x.clamp(min=finfo(x.dtype).eps), so the gradient magnitude is bounded by 1 / eps.
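The clamped-denominator trick can be sketched with a custom torch.autograd.Function. This is a minimal illustration that uses only public PyTorch APIs. The names _ClampedLog and clamped_log are hypothetical and are not part of Pyro.

```python
import torch


class _ClampedLog(torch.autograd.Function):
    """Hypothetical sketch: log(x) whose gradient is bounded by 1 / eps."""

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x.log()  # forward pass is the ordinary log, so log(0) = -inf

    @staticmethod
    def backward(ctx, grad):
        (x,) = ctx.saved_tensors
        # d/dx log(x) = 1/x; clamping the denominator keeps it finite at x = 0
        return grad / x.clamp(min=torch.finfo(x.dtype).eps)


def clamped_log(x):
    return _ClampedLog.apply(x)


x = torch.tensor([0.0, 0.5, 1.0], requires_grad=True)
clamped_log(x).sum().backward()
print(x.grad)  # first entry is 1 / eps instead of inf
```

Only the backward pass changes. The forward value at 0 is still -inf, so downstream code must still tolerate infinite log values.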
Function: log_beta(x, y, tol=0.0)
Computes the log Beta function log(B(x, y)) = lgamma(x) + lgamma(y) - lgamma(x+y).
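When tol >= 0.02, it uses a shifted Stirling's approximation for speed. Each shift applies the recursion lgamma(z) = lgamma(z + 1) - log(z), which moves the arguments away from zero, where Stirling's approximation is least accurate. The number of shifts is ceil(0.082 / tol), chosen so that the absolute error stays below tol. For tol < 0.02 (including the default tol=0.0), it defers to torch.lgamma, which is cheaper at that accuracy.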
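The following is a minimal sketch of the shifted Stirling scheme, assuming float64 inputs. It is an illustration and not Pyro's source. The names log_beta_stirling and log_beta_exact are hypothetical. The sketch accumulates the log corrections from the recursion and then applies Stirling's formula lgamma(z) ~ (z - 1/2) log(z) - z + log(2 pi) / 2 at the shifted arguments.

```python
import math

import torch


def log_beta_exact(x, y):
    # log B(x, y) = lgamma(x) + lgamma(y) - lgamma(x + y)
    return torch.lgamma(x) + torch.lgamma(y) - torch.lgamma(x + y)


def log_beta_stirling(x, y, tol):
    """Hypothetical sketch of a shifted Stirling approximation to log B(x, y)."""
    shift = int(math.ceil(0.082 / tol))
    xy = x + y
    log_factor = torch.zeros_like(xy)
    for _ in range(shift):
        # lgamma(z) = lgamma(z + 1) - log(z), applied to x, y and x + y
        log_factor = log_factor + xy.log() - x.log() - y.log()
        x, y, xy = x + 1, y + 1, xy + 1
    # Stirling at the shifted arguments; the "- z" terms give
    # -(x + n) - (y + n) + (x + y + n) = -n, hence "- shift"
    return (
        log_factor
        + (x - 0.5) * x.log()
        + (y - 0.5) * y.log()
        - (xy - 0.5) * xy.log()
        + math.log(2 * math.pi) / 2
        - shift
    )


x = torch.tensor([0.001, 0.5, 3.0, 40.0], dtype=torch.float64)
y = torch.tensor([0.002, 7.0, 3.0, 0.1], dtype=torch.float64)
for tol in (0.1, 0.05, 0.02):
    err = (log_beta_stirling(x, y, tol) - log_beta_exact(x, y)).abs().max()
    print(f"tol={tol}: max abs error {err.item():.2e}")
```

The error is largest when x and y are near zero, because the shifted arguments are then smallest. This is the case the shifts are meant to handle.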
Function: log_binomial(n, k, tol=0.0)
Computes the log binomial coefficient using log_beta. Supports the same Stirling's approximation via the tol parameter.
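The reduction to log_beta rests on B(k + 1, n - k + 1) = k! (n - k)! / (n + 1)!, which gives log C(n, k) = -log(n + 1) - log B(k + 1, n - k + 1). The sketch below checks that identity numerically using exact lgamma terms and no Stirling shift. The helper name log_binomial_via_beta is hypothetical and is not Pyro's function.

```python
import math

import torch


def log_binomial_via_beta(n, k):
    """Hypothetical sketch: log C(n, k) = -log(n + 1) - log B(k + 1, n - k + 1)."""
    a, b = k + 1, n - k + 1
    log_beta = torch.lgamma(a) + torch.lgamma(b) - torch.lgamma(a + b)
    return -torch.log(n + 1) - log_beta


n = torch.tensor([5.0, 10.0, 52.0], dtype=torch.float64)
k = torch.tensor([2.0, 0.0, 5.0], dtype=torch.float64)
print(log_binomial_via_beta(n, k).exp())  # approximately [10, 1, 2598960]
```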
Function: log_I1(orders, value, terms=250)
Computes log I_v(value) for each order v = 0, 1, ..., orders (that is, orders + 1 log modified Bessel functions of the first kind) using the power-series expansion, truncated after terms terms (default 250). Used internally by von Mises-type distributions.
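The power series is I_v(z) = sum over k >= 0 of (z/2)^(2k+v) / (k! Gamma(v + k + 1)). Evaluating it in log space with logsumexp avoids overflow for large z. The following is a minimal sketch assuming value > 0. The function name log_bessel_i_series and its broadcasting layout are hypothetical and need not match Pyro's vectorization.

```python
import torch


def log_bessel_i_series(orders: int, value: torch.Tensor, terms: int = 250) -> torch.Tensor:
    """Hypothetical sketch: log I_v(value) for v = 0..orders.

    Returns shape (orders + 1, *value.shape). Assumes value > 0.
    """
    opts = dict(dtype=value.dtype, device=value.device)
    v = torch.arange(orders + 1, **opts).view(-1, 1, 1)  # orders, shape (V, 1, 1)
    k = torch.arange(terms, **opts).view(1, 1, -1)  # series index, shape (1, 1, K)
    log_half_z = torch.log(value.reshape(1, -1, 1) / 2)  # shape (1, N, 1)
    # log of (z/2)^(2k+v) / (k! * Gamma(v + k + 1)) for every (v, z, k)
    log_terms = (2 * k + v) * log_half_z - torch.lgamma(k + 1) - torch.lgamma(v + k + 1)
    # Sum the truncated series in log space, then restore the shape of value
    return torch.logsumexp(log_terms, dim=-1).reshape(orders + 1, *value.shape)


z = torch.tensor([0.5, 1.0, 5.0, 20.0], dtype=torch.float64)
log_i = log_bessel_i_series(1, z)
print(log_i.shape)  # torch.Size([2, 4])
```

For moderate z, the truncation after terms terms is harmless because the summands decay factorially once k exceeds about z/2.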
Function: get_quad_rule(num_quad, prototype_tensor)
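Returns Gauss-Hermite quadrature points and log weights for numerical integration against the standard normal measure N(0, 1). It uses numpy.polynomial.hermite.hermgauss, whose rule targets the weight exp(-x^2). The points are therefore scaled by sqrt(2) and the weights divided by sqrt(pi) (a shift of -0.5 * log(pi) in log space), so that the weights sum to one. The results are converted to PyTorch tensors matching the dtype/device of prototype_tensor.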
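The rescaling follows from substituting t = sqrt(2) x in the N(0, 1) expectation. This gives E[f(X)] = (1 / sqrt(pi)) * integral of f(sqrt(2) x) exp(-x^2) dx. The following is a minimal sketch of that construction. The name gauss_hermite_normal is hypothetical, and the code is an illustration rather than Pyro's source.

```python
import numpy as np
import torch


def gauss_hermite_normal(num_quad, prototype_tensor):
    """Hypothetical sketch: points/log-weights for E[f(X)], X ~ N(0, 1)."""
    # hermgauss integrates against exp(-x^2); t = sqrt(2) * x turns that
    # weight into the N(0, 1) density up to a factor of 1 / sqrt(pi)
    points, weights = np.polynomial.hermite.hermgauss(num_quad)
    quad_points = points * np.sqrt(2.0)
    log_weights = np.log(weights) - 0.5 * np.log(np.pi)
    # Tensor.to(other) matches both the dtype and the device of the prototype
    return (
        torch.from_numpy(quad_points).to(prototype_tensor),
        torch.from_numpy(log_weights).to(prototype_tensor),
    )


quad_points, log_weights = gauss_hermite_normal(20, torch.zeros((), dtype=torch.float64))
weights = log_weights.exp()
# E[cos(X)] for X ~ N(0, 1) equals exp(-1/2)
print((weights * quad_points.cos()).sum().item(), np.exp(-0.5))
```

An n-point Gauss-Hermite rule is exact for polynomials up to degree 2n - 1. With 20 points, moments such as E[X^2] = 1 and E[X^4] = 3 are therefore recovered to floating-point precision.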
Function: sparse_multinomial_likelihood(total_count, nonzero_logits, nonzero_value)
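Computes the multinomial log probability from only the nonzero entries of the count vector. It is equivalent to Multinomial(logits=logits).log_prob(value).sum() under three conditions: nonzero_logits holds the normalized logits (logits - logits.logsumexp(-1)) gathered at the nonzero positions of value, nonzero_value holds the corresponding counts, and total_count is value.sum(-1). Zero counts contribute nothing to the likelihood, so skipping them saves work on sparse data. Log-factorial terms are memoized in a weakref-based cache.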
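The sparse form follows from log p(x) = log N! - sum_i log x_i! + sum_i x_i log p_i. Every term with x_i = 0 vanishes (0! = 1 and 0 * log p_i = 0). The following is a minimal sketch that omits the weakref cache and checks the result against the dense torch.distributions.Multinomial. The helper name sparse_multinomial_log_prob is hypothetical.

```python
import torch
from torch.distributions import Multinomial


def sparse_multinomial_log_prob(total_count, nonzero_logits, nonzero_value):
    """Hypothetical sketch of the sparse multinomial log-likelihood.

    nonzero_logits must already be normalized (logits - logits.logsumexp(-1))
    and gathered at the nonzero positions of the count vector.
    """
    return (
        torch.lgamma(total_count + 1).sum()  # log N!
        - torch.lgamma(nonzero_value + 1).sum()  # sum of log x_i! over nonzeros
        + torch.dot(nonzero_logits, nonzero_value)  # sum of x_i * log p_i
    )


torch.manual_seed(0)
logits = torch.randn(8, dtype=torch.float64)
value = torch.tensor([0.0, 3.0, 0.0, 0.0, 1.0, 2.0, 0.0, 0.0], dtype=torch.float64)

dense = Multinomial(int(value.sum()), logits=logits).log_prob(value)
nnz = value.nonzero(as_tuple=True)
sparse = sparse_multinomial_log_prob(
    value.sum(-1), (logits - logits.logsumexp(-1))[nnz], value[nnz]
)
print(dense.item(), sparse.item())
```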
I/O Contract
| Function | Input | Output |
|---|---|---|
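| safe_log(x) | Tensor (non-negative) | Tensor (log values with clamped gradients) |
| log_beta(x, y, tol) | x, y: positive Tensors; tol: non-negative float | Tensor |
| log_binomial(n, k, tol) | n: non-negative integer Tensor; k: integer Tensor in [0, n]; tol: non-negative float | Tensor |
| log_I1(orders, value, terms) | orders: int; value: Tensor; terms: int | Tensor of shape (orders+1, *value.shape) |
| get_quad_rule(num_quad, prototype_tensor) | int, Tensor | Tuple (quad_points: Tensor, log_weights: Tensor) |
| sparse_multinomial_likelihood(total_count, nonzero_logits, nonzero_value) | Tensors (total count, normalized logits and counts at the nonzero entries) | Scalar Tensor |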
Usage Examples
```python
import torch
from pyro.ops.special import safe_log, log_beta, get_quad_rule

# Safe log avoids infinite gradients
x = torch.tensor([0.0, 0.5, 1.0], requires_grad=True)
y = safe_log(x)
y.sum().backward()
print(x.grad)  # finite even at x=0

# Log beta function with Stirling approximation
a = torch.tensor(10.0)
b = torch.tensor(5.0)
exact = log_beta(a, b, tol=0.0)
approx = log_beta(a, b, tol=0.1)
print(f"Exact: {exact:.6f}, Approx: {approx:.6f}")

# Gauss-Hermite quadrature for integration
quad_points, log_weights = get_quad_rule(32, torch.tensor(0.0))
# Compute E[x^2] under N(0, 1) -- should be 1.0
variance = torch.logsumexp(
    quad_points.pow(2).log() + log_weights, dim=0
).exp()
print(f"Variance: {variance:.6f}")  # ~1.0
```
Related Pages
- Pyro_ppl_Pyro_TensorUtils -- Related numerical utilities
- Pyro_ppl_Pyro_Stats -- Statistical functions that may use quadrature