

Implementation:Pyro ppl Pyro Special

From Leeroopedia
Revision as of 16:25, 16 February 2026 by Admin (auto-imported from implementations/Pyro_ppl_Pyro_Special.md)


| Property | Value |
|---|---|
| Module | pyro.ops.special |
| Source | pyro/ops/special.py |
| Lines | 219 |
| Functions | safe_log, log_beta, log_binomial, log_I1, get_quad_rule, sparse_multinomial_likelihood |
| Dependencies | torch, numpy (for Gauss-Hermite quadrature) |

Overview

This module provides special mathematical functions used by Pyro's distributions and inference algorithms. The functions are designed to be numerically stable and differentiable, and several expose a parameter that trades accuracy for speed: tol in log_beta and log_binomial, and terms in log_I1.

Key capabilities include:

  • Safe logarithm: Avoids infinite gradients at log(0).
  • Log Beta function: With a Stirling's approximation option for faster computation.
  • Log binomial coefficient: Using the log Beta function.
  • Modified Bessel function of the first kind: For von Mises and related distributions.
  • Gauss-Hermite quadrature: For numerical integration against Gaussian measures.
  • Sparse multinomial likelihood: Efficient computation for sparse count data.

Code Reference

Function: safe_log(x)

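Computes torch.log(x) in the forward pass but keeps the gradient finite at x = 0. It is implemented as a custom torch.autograd.Function whose backward pass clamps the denominator of the gradient 1/x at finfo.eps for the input dtype, so the gradient magnitude never exceeds 1 / finfo.eps.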
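To make the clamping rule concrete, here is a scalar sketch of the forward and backward computations in plain Python. It assumes float64 and uses sys.float_info.epsilon in place of torch.finfo(dtype).eps; the helper names safe_log_forward and safe_log_backward are hypothetical and not part of Pyro.

```python
import math
import sys

# float64 machine epsilon, standing in for torch.finfo(dtype).eps
EPS = sys.float_info.epsilon


def safe_log_forward(x: float) -> float:
    """Forward pass: an ordinary log (inputs assumed non-negative), so log(0) is -inf."""
    if x == 0.0:
        return -math.inf
    return math.log(x)


def safe_log_backward(x: float, grad_output: float = 1.0) -> float:
    """Backward pass: d/dx log(x) = 1/x, with the denominator clamped at EPS."""
    return grad_output / max(x, EPS)


for x in (0.0, 0.5, 1.0):
    print(x, safe_log_forward(x), safe_log_backward(x))
```

The forward value at 0 is still -inf; only the gradient is capped, at 1 / EPS.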

Function: log_beta(x, y, tol=0.0)

Computes the log Beta function log(B(x, y)) = lgamma(x) + lgamma(y) - lgamma(x+y).

When tol >= 0.02, it uses a shifted Stirling's approximation for speed. The number of shift iterations is ceil(0.082 / tol), which bounds the absolute error below tol. For tol < 0.02, including the default tol=0.0, it evaluates the exact expression with torch.lgamma.
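The shifted approximation can be sketched for scalar inputs with only the standard library. This is a minimal illustration, assuming the usual form of the scheme: Stirling's formula lgamma(z) ≈ (z - 1/2) log z - z + log(2π)/2, preceded by `shift` applications of the recursion lgamma(z) = lgamma(z + 1) - log z. It is not Pyro's tensor implementation, and log_beta_approx is a hypothetical name.

```python
import math


def log_beta_approx(x: float, y: float, tol: float = 0.1) -> float:
    """Hypothetical scalar sketch of log B(x, y) via a shifted Stirling approximation."""
    if tol < 0.02:
        # Small tolerance: use the exact lgamma expression instead.
        return math.lgamma(x) + math.lgamma(y) - math.lgamma(x + y)
    shift = math.ceil(0.082 / tol)
    xy = x + y
    log_factor = 0.0
    for _ in range(shift):
        # Apply lgamma(z) = lgamma(z + 1) - log(z) to x, y and x + y;
        # for log B(x, y) the corrections are -log(x) - log(y) + log(x + y).
        log_factor += math.log(xy) - math.log(x) - math.log(y)
        x, y, xy = x + 1, y + 1, xy + 1
    # Stirling: lgamma(z) ~= (z - 1/2) log(z) - z + log(2 pi) / 2.
    # With the shifted values, the "-z" terms combine to -x - y + xy = -shift.
    return (
        log_factor
        + (x - 0.5) * math.log(x)
        + (y - 0.5) * math.log(y)
        - (xy - 0.5) * math.log(xy)
        + 0.5 * math.log(2 * math.pi)
        - shift
    )


exact = log_beta_approx(10.0, 5.0, tol=0.0)
approx = log_beta_approx(10.0, 5.0, tol=0.1)
print(f"Exact: {exact:.6f}, Approx: {approx:.6f}")
```

Each shift moves the Stirling evaluation to larger arguments, where its error (roughly 1/(12z)) is smaller; this is why more shifts are needed for a tighter tol.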

Function: log_binomial(n, k, tol=0.0)

Computes the log binomial coefficient using log_beta. Supports the same Stirling's approximation via the tol parameter.
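One way to express a log binomial coefficient through the log Beta function is the identity C(n, k) = 1 / ((n + 1) · B(k + 1, n - k + 1)), which follows from B(a, b) = Γ(a)Γ(b)/Γ(a + b) and Γ(m + 1) = m!. The scalar sketch below checks this identity against math.comb. It illustrates the identity under exact lgamma only; log_binomial_via_beta and log_beta_exact are hypothetical names, not Pyro's code.

```python
import math


def log_beta_exact(a: float, b: float) -> float:
    """log B(a, b) = lgamma(a) + lgamma(b) - lgamma(a + b)."""
    return math.lgamma(a) + math.lgamma(b) - math.lgamma(a + b)


def log_binomial_via_beta(n: int, k: int) -> float:
    """log C(n, k) = -log(n + 1) - log B(k + 1, n - k + 1), for 0 <= k <= n."""
    return -math.log(n + 1) - log_beta_exact(k + 1, n - k + 1)


for n, k in [(10, 3), (20, 0), (52, 5)]:
    print(n, k, math.exp(log_binomial_via_beta(n, k)), math.comb(n, k))
```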

Function: log_I1(orders, value, terms=250)

Computes log I_v(value) for every order v = 0, 1, ..., orders (that is, orders + 1 log modified Bessel functions of the first kind) from their power-series expansion, truncated after terms terms (default 250). Used internally by the von Mises distribution.
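The series being truncated is I_v(z) = Σ_k (z/2)^(2k+v) / (k! Γ(v + k + 1)). A scalar sketch that evaluates its logarithm stably, using a log-sum-exp over the first terms terms for each order 0..orders, could look like the following. It assumes z > 0 (because of log(z/2)); log_bessel_i and log_bessel_i_orders are hypothetical names.

```python
import math


def log_bessel_i(order: int, z: float, terms: int = 250) -> float:
    """log I_order(z) from the truncated power series, assuming z > 0."""
    log_half_z = math.log(z / 2)
    # log of the k-th series term: (2k + v) log(z/2) - log(k!) - lgamma(v + k + 1)
    log_terms = [
        (2 * k + order) * log_half_z - math.lgamma(k + 1) - math.lgamma(order + k + 1)
        for k in range(terms)
    ]
    # log-sum-exp: factor out the largest term to avoid overflow and underflow
    m = max(log_terms)
    return m + math.log(sum(math.exp(t - m) for t in log_terms))


def log_bessel_i_orders(orders: int, z: float, terms: int = 250) -> list:
    """log I_v(z) for v = 0, 1, ..., orders (orders + 1 values in total)."""
    return [log_bessel_i(v, z, terms) for v in range(orders + 1)]


print([math.exp(v) for v in log_bessel_i_orders(2, 1.0)])
```

Working in log space keeps the individual terms representable even when (z/2)^(2k+v) or the factorials overflow a float.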

Function: get_quad_rule(num_quad, prototype_tensor)

Returns Gauss-Hermite quadrature points and log weights for numerical integration. Points are scaled by sqrt(2) to match the standard normal measure. Uses numpy.polynomial.hermite.hermgauss and converts to PyTorch tensors matching the dtype/device of prototype_tensor.
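A NumPy-only sketch of such a rule follows. hermgauss integrates against exp(-t²), and substituting x = √2·t turns that integral into an expectation under N(0, 1), provided the weights are divided by √π. This sketch assumes that normalization (log weights shifted by -log(π)/2, so the weights sum to one); standard_normal_quad_rule is a hypothetical name, and the torch dtype/device conversion is omitted.

```python
import math

import numpy as np


def standard_normal_quad_rule(num_quad: int):
    """Gauss-Hermite points and log weights for expectations under N(0, 1)."""
    # hermgauss: integral of exp(-t^2) f(t) dt ~= sum_i w_i f(t_i)
    t, w = np.polynomial.hermite.hermgauss(num_quad)
    # Substituting x = sqrt(2) t: E[f(X)] ~= sum_i (w_i / sqrt(pi)) f(sqrt(2) t_i)
    points = math.sqrt(2.0) * t
    log_weights = np.log(w) - 0.5 * math.log(math.pi)
    return points, log_weights


points, log_weights = standard_normal_quad_rule(32)
weights = np.exp(log_weights)
# Total weight, E[X^2] and E[X^4] under N(0, 1)
print(weights.sum(), (weights * points**2).sum(), (weights * points**4).sum())
```

A rule with n points is exact for polynomials up to degree 2n - 1, so the low-order moments above should match 1, 1 and 3 to floating-point precision.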

Function: sparse_multinomial_likelihood(total_count, nonzero_logits, nonzero_value)

Computes multinomial log probability efficiently using only the nonzero entries. Equivalent to Multinomial(logits=logits).log_prob(value).sum() but avoids materializing the full probability vector. Uses a weakref-based cache for log-factorial computations.
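Only the nonzero entries are needed because the multinomial log-pmf is lgamma(N + 1) - Σ_i lgamma(x_i + 1) + Σ_i x_i·log p_i, and every entry with x_i = 0 contributes lgamma(1) = 0 and 0·log p_i = 0. The plain-Python sketch below compares a dense and a sparse evaluation for one count vector. It assumes finite logits and that the nonzero logits passed to the sparse version have already been normalized into log-probabilities (logits minus their log-sum-exp); all helper names are hypothetical.

```python
import math


def log_softmax(logits):
    """Normalize logits into log-probabilities."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(v - m) for v in logits))
    return [v - log_z for v in logits]


def dense_multinomial_log_prob(logits, counts):
    """Multinomial log-pmf using every category."""
    log_p = log_softmax(logits)
    n = sum(counts)
    return (
        math.lgamma(n + 1)
        - sum(math.lgamma(c + 1) for c in counts)
        + sum(c * lp for c, lp in zip(counts, log_p))
    )


def sparse_log_prob(total_count, nonzero_log_probs, nonzero_counts):
    """Same quantity from the nonzero entries only (log-probs already normalized)."""
    return (
        math.lgamma(total_count + 1)
        - sum(math.lgamma(c + 1) for c in nonzero_counts)
        + sum(c * lp for c, lp in zip(nonzero_counts, nonzero_log_probs))
    )


logits = [0.3, -1.2, 2.0, 0.0, -0.5, 1.1]
counts = [0, 4, 0, 0, 7, 1]
log_p = log_softmax(logits)
nnz = [i for i, c in enumerate(counts) if c != 0]
dense = dense_multinomial_log_prob(logits, counts)
sparse = sparse_log_prob(sum(counts), [log_p[i] for i in nnz], [counts[i] for i in nnz])
print(dense, sparse)
```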

I/O Contract

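| Function | Input | Output |
|---|---|---|
| safe_log(x) | Tensor (non-negative) | Tensor (log values with clamped gradients) |
| log_beta(x, y, tol) | x, y: positive Tensors; tol: float | Tensor |
| log_binomial(n, k, tol) | n, k: non-negative integer-valued Tensors; tol: float | Tensor |
| log_I1(orders, value, terms) | orders: int; value: Tensor; terms: int | Tensor of shape (orders+1, *value.shape) |
| get_quad_rule(num_quad, prototype_tensor) | num_quad: int; prototype_tensor: Tensor | Tuple (quad_points: Tensor, log_weights: Tensor) |
| sparse_multinomial_likelihood(total_count, nonzero_logits, nonzero_value) | Tensors (nonzero_logits and nonzero_value hold only the nonzero entries) | Tensor (summed log probability) |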

Usage Examples

import torch
from pyro.ops.special import safe_log, log_beta, get_quad_rule

# Safe log avoids infinite gradients
x = torch.tensor([0.0, 0.5, 1.0], requires_grad=True)
y = safe_log(x)
y.sum().backward()
print(x.grad)  # finite at x=0: 1/eps (about 8.4e6 for float32) instead of inf

# Log beta function with Stirling approximation
a = torch.tensor(10.0)
b = torch.tensor(5.0)
exact = log_beta(a, b, tol=0.0)
approx = log_beta(a, b, tol=0.1)
print(f"Exact: {exact:.6f}, Approx: {approx:.6f}")

# Gauss-Hermite quadrature for integration
quad_points, log_weights = get_quad_rule(32, torch.tensor(0.0))
# Compute E[x^2] under N(0, 1) in log space; should be ~1.0.
# With an even num_quad no quadrature point is exactly 0, so log(x^2) stays finite.
variance = torch.logsumexp(
    quad_points.pow(2).log() + log_weights, dim=0
).exp()
print(f"Variance: {variance:.6f}")  # ~1.0
