Heuristic:Pyro ppl Pyro Numerical Stability Patterns
| Knowledge Sources | |
|---|---|
| Domains | Numerical_Computing, Debugging |
| Last Updated | 2026-02-09 09:00 GMT |
Overview
Numerical stability techniques used throughout Pyro, including log-domain arithmetic, a custom `_logaddexp`, clamping guards, and Cholesky-based handling of precision matrices.
Description
Pyro relies heavily on log-domain computation to prevent underflow and overflow in probability calculations. Key patterns include: log-sum-exp for normalizing log-probabilities, a custom `_logaddexp` that is more stable than the naive `log(exp(x) + exp(y))`, clamping arguments before operations that are singular or overflow-prone (such as `log` near zero), and Cholesky factorization of precision matrices instead of explicit matrix inversion. These patterns are critical for the correctness of MCMC (potential energy computation), HMM forward algorithms, and ELBO estimation.
Usage
Apply these patterns when implementing custom distributions, writing models with extreme probability values, or debugging NaN/Inf values in ELBO losses or MCMC potential energy. Understanding these patterns helps diagnose numerical issues in inference.
The Insight (Rule of Thumb)
- Action 1: Always work in log-space for probability computations. Use `logsumexp` instead of `log(sum(exp(...)))`.
- Action 2: To add two log-probabilities, use `maxval + log1p(exp(minval - maxval))` (as `_logaddexp` does) instead of `log(exp(x) + exp(y))`.
- Action 3: Clamp arguments before `log()`, `exp()`, or `expm1()` to prevent `-inf` or `NaN` results.
- Action 4: Solve linear systems through a Cholesky factor of the precision matrix instead of forming an explicit matrix inverse.
- Value: Common clamp thresholds: `min=1e-4` for log, `min=-1, max=1` for bounded operations.
- Trade-off: Log-domain arithmetic adds complexity but prevents catastrophic floating-point errors. Clamping introduces small biases but prevents NaN propagation.
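The clamping trade-off in the list above can be seen on plain scalars. In this sketch `clamped_log` is a hypothetical helper, not a Pyro function, and the floor of `1e-4` is borrowed from the thresholds quoted above purely for illustration; an appropriate value depends on the model.

```python
import math

def clamped_log(x: float, floor: float = 1e-4) -> float:
    """log(max(x, floor)): finite for every finite x, but biased whenever x < floor."""
    return math.log(max(x, floor))

exact = clamped_log(0.5)      # unaffected: 0.5 is above the floor
guarded = clamped_log(0.0)    # log(1e-4) instead of a domain error or -inf
biased = clamped_log(1e-6)    # also log(1e-4), although the true value is log(1e-6)

print(exact, guarded, biased)
```

Every input below the floor maps to the same output, which is the small bias the trade-off refers to.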
Reasoning
Probabilities in deep models can be astronomically small (e.g., 1e-300). Such a value underflows to zero in float32 linear space and sits near the edge of the float64 range, but its logarithm, about -690.8, is perfectly representable. The log-sum-exp trick subtracts the maximum before exponentiation, so the largest exponentiated term is exactly 1 and the sum cannot overflow. Pyro's custom `_logaddexp` uses `log1p`, which is more accurate than evaluating `log(1 + x)` directly for small `x`.
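Both effects can be checked with Python's standard `math` module. This is a minimal sketch on plain float64 values rather than Pyro tensors, and the variable names are illustrative only.

```python
import math

# Two probabilities that are each representable in float64 ...
p1, p2 = 1e-200, 1e-200

# ... but whose product underflows to exactly zero in linear space.
linear_product = p1 * p2

# In log space the same product is an ordinary finite number.
log_product = math.log(p1) + math.log(p2)

# log1p(x) keeps precision for tiny x, where 1 + x rounds to exactly 1.0.
tiny = 1e-20
naive = math.log(1 + tiny)    # 1 + 1e-20 rounds to 1.0, so the result is 0.0
stable = math.log1p(tiny)     # approximately 1e-20

print(linear_product, log_product, naive, stable)
```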
Cholesky-based precision handling avoids the amplification of rounding errors that occurs with explicit matrix inversion. When computing `Lambda^{-1} x` for a precision matrix `Lambda` (for example, recovering a mean from an information vector), `cholesky_solve(x, chol(Lambda))` is more stable than `inv(Lambda) @ x`.
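The solve-instead-of-invert idea can be sketched without any dependencies. The `cholesky` and `cholesky_solve` functions below are hypothetical pure-Python stand-ins written for small dense matrices; in PyTorch the corresponding calls are `torch.linalg.cholesky` and `torch.cholesky_solve`, and Pyro's `safe_cholesky` is not reproduced here. The matrix and vector values are made up for illustration.

```python
import math

def cholesky(a):
    """Lower-triangular L with L @ L^T == a, for a symmetric positive-definite a."""
    n = len(a)
    L = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(i + 1):
            s = sum(L[i][k] * L[j][k] for k in range(j))
            if i == j:
                L[i][j] = math.sqrt(a[i][i] - s)
            else:
                L[i][j] = (a[i][j] - s) / L[j][j]
    return L

def cholesky_solve(b, L):
    """Solve (L @ L^T) x = b by forward, then backward, substitution."""
    n = len(b)
    # Forward substitution: L z = b
    z = [0.0] * n
    for i in range(n):
        z[i] = (b[i] - sum(L[i][k] * z[k] for k in range(i))) / L[i][i]
    # Backward substitution: L^T x = z
    x = [0.0] * n
    for i in reversed(range(n)):
        x[i] = (z[i] - sum(L[k][i] * x[k] for k in range(i + 1, n))) / L[i][i]
    return x

# Precision matrix and information vector of a 2-D Gaussian (illustrative values).
precision = [[4.0, 2.0], [2.0, 3.0]]
info_vec = [2.0, 1.0]

# Mean = precision^{-1} @ info_vec, obtained without ever forming the inverse.
loc = cholesky_solve(info_vec, cholesky(precision))
print(loc)
```

The two triangular solves reuse the factor directly, so no inverse matrix is ever materialized.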
Code evidence for custom log-addexp from `pyro/infer/mcmc/nuts.py:15-17`:
def _logaddexp(x, y):
    minval, maxval = (x, y) if x < y else (y, x)
    return (minval - maxval).exp().log1p() + maxval
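For experimenting outside PyTorch, the same identity can be written for plain Python floats. This is a sketch, not Pyro's code: `logaddexp_scalar` is a hypothetical name, and the guard for two `-inf` inputs is an addition that the three-line version above does not have (there, `-inf - (-inf)` evaluates to `nan` under IEEE floating-point arithmetic).

```python
import math

def logaddexp_scalar(x: float, y: float) -> float:
    """Return log(exp(x) + exp(y)) without exponentiating large values."""
    minval, maxval = (x, y) if x < y else (y, x)
    if maxval == -math.inf:
        # Both inputs are log(0); avoid computing -inf - (-inf) = nan.
        return -math.inf
    # minval - maxval <= 0, so exp() of it lies in [0, 1] and cannot overflow.
    return maxval + math.log1p(math.exp(minval - maxval))

# math.exp(1000.0) raises OverflowError, so the naive formula is unusable here.
big = logaddexp_scalar(1000.0, 1000.0)       # 1000 + log(2)
lopsided = logaddexp_scalar(-1000.0, 0.0)    # exp(-1000) underflows to 0.0, result is 0.0
empty = logaddexp_scalar(-math.inf, -math.inf)

print(big, lopsided, empty)
```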
Log-space normalization from `pyro/distributions/hmm.py:323-324`:
self.initial_logits = initial_logits - initial_logits.logsumexp(-1, True)
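The effect of subtracting `logsumexp` can be sketched without tensors. The helper below is a hypothetical pure-Python stand-in for `Tensor.logsumexp`, written for a flat list of finite logits; it is not Pyro's implementation, and the logit values are made up.

```python
import math

def logsumexp(logits):
    """log(sum(exp(v) for v in logits)), shifted by the maximum for stability."""
    m = max(logits)
    return m + math.log(sum(math.exp(v - m) for v in logits))

# Unnormalized log-weights so negative that exp() of each underflows to 0.0.
initial_logits = [-1000.0, -1001.0, -1002.0]

log_norm = logsumexp(initial_logits)
normalized = [v - log_norm for v in initial_logits]

# After normalization the implied probabilities sum to one.
total = sum(math.exp(v) for v in normalized)
print(normalized, total)
```

Exponentiating these logits directly would give a sum of zero, so the normalizer could not be computed in linear space at all.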
Clamping before expm1 from `examples/contrib/forecast/bart.py:137`:
pred.clamp(min=1e-4).expm1()
Cholesky-based solve from `pyro/distributions/hmm.py:931-932`:
scale_tril = safe_cholesky(logp.precision)
loc = cholesky_solve(logp.info_vec.unsqueeze(-1), scale_tril).squeeze(-1)