Implementation:Pyro ppl Pyro NUTS Kernel
Metadata
| Field | Value |
|---|---|
| Page Type | Implementation (API Doc) |
| Knowledge Sources | Repo (Pyro), Paper (The No-U-Turn Sampler), Paper (A Conceptual Introduction to Hamiltonian Monte Carlo) |
| Domains | MCMC, Bayesian_Inference |
| Last Updated | 2026-02-09 12:00 GMT |
Overview
Concrete MCMC kernel implementing the No-U-Turn Sampler (NUTS) algorithm in Pyro, providing automatic trajectory length tuning on top of the HMC kernel.
Description
NUTS is a subclass of HMC that implements the No-U-Turn Sampler algorithm. It extends HMC by replacing the fixed-length leapfrog trajectory with an adaptive tree-doubling scheme. At each MCMC iteration, NUTS builds a balanced binary tree of leapfrog states, doubling the trajectory length until a U-turn condition is detected or the maximum tree depth is reached.
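To make the tree-doubling idea concrete, here is a plain-Python sketch under simplifying assumptions. It is not Pyro's nuts.py. The helper names leapfrog, is_u_turn and trajectory_length are hypothetical. The inverse mass matrix is assumed to be diagonal and passed as a list. Only the outermost endpoints are tested with the endpoint U-turn rule from the NUTS paper, whereas full NUTS also checks every subtree. No state is selected from the trajectory.

```python
import random


def leapfrog(q, p, grad_potential, step_size, inv_mass):
    """One leapfrog step for H(q, p) = U(q) + 0.5 * sum(inv_mass[i] * p[i]**2)."""
    g = grad_potential(q)
    p = [pi - 0.5 * step_size * gi for pi, gi in zip(p, g)]  # half step in momentum
    q = [qi + step_size * mi * pi for qi, mi, pi in zip(q, inv_mass, p)]  # full step in position
    g = grad_potential(q)
    p = [pi - 0.5 * step_size * gi for pi, gi in zip(p, g)]  # second half step in momentum
    return q, p


def is_u_turn(q_minus, p_minus, q_plus, p_plus, inv_mass):
    """Endpoint U-turn test: True once either end's velocity points back along the span."""
    span = [b - a for a, b in zip(q_minus, q_plus)]
    v_minus = [m * p for m, p in zip(inv_mass, p_minus)]
    v_plus = [m * p for m, p in zip(inv_mass, p_plus)]
    dot_minus = sum(s * v for s, v in zip(span, v_minus))
    dot_plus = sum(s * v for s, v in zip(span, v_plus))
    return dot_minus < 0 or dot_plus < 0


def trajectory_length(q0, p0, grad_potential, step_size, inv_mass,
                      max_tree_depth=10, rng=random):
    """Double the trajectory in a random direction until a U-turn or max depth.

    Returns (depth, total leapfrog steps). Simplified: only the outer endpoints
    are tested and no state is sampled from the trajectory.
    """
    q_minus, p_minus = list(q0), list(p0)
    q_plus, p_plus = list(q0), list(p0)
    n_steps = 0
    for depth in range(max_tree_depth):
        direction = rng.choice((-1, 1))
        for _ in range(2 ** depth):  # each doubling adds 2**depth steps
            if direction == 1:
                q_plus, p_plus = leapfrog(q_plus, p_plus, grad_potential, step_size, inv_mass)
            else:  # a negative step size integrates backward in time
                q_minus, p_minus = leapfrog(q_minus, p_minus, grad_potential, -step_size, inv_mass)
            n_steps += 1
        if is_u_turn(q_minus, p_minus, q_plus, p_plus, inv_mass):
            return depth + 1, n_steps
    return max_tree_depth, n_steps
```

For a standard normal target, the gradient of the potential is q itself. In that case trajectory_length([1.0], [0.0], lambda q: q, 0.1, [1.0]) stops after a few doublings. Because each doubling adds 2**depth steps, the total is always 2**depth - 1 leapfrog steps.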
The kernel inherits all adaptation machinery from HMC, including:
- Step size adaptation via dual averaging, targeting a specified acceptance probability.
- Mass matrix adaptation using the Welford algorithm to estimate the diagonal or full covariance of the posterior.
- Automatic transforms that reparameterize constrained parameters to unconstrained space for more efficient sampling.
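The Welford update behind diagonal mass matrix adaptation fits in a few lines. The sketch below is a hypothetical stand-alone class, and WelfordDiagVariance is not a Pyro name. It estimates per-coordinate posterior variances from warmup samples, and a diagonal scheme would use those variances as the inverse mass matrix. Production adapters typically also regularize the estimate and re-estimate it over several warmup windows. Both refinements are omitted here.

```python
class WelfordDiagVariance:
    """Streaming per-coordinate mean and variance (Welford's algorithm)."""

    def __init__(self, dim):
        self.n = 0
        self.mean = [0.0] * dim
        self.m2 = [0.0] * dim  # running sums of squared deviations from the mean

    def update(self, x):
        """Fold one warmup sample (a sequence of length dim) into the running estimate."""
        self.n += 1
        for i, xi in enumerate(x):
            delta = xi - self.mean[i]
            self.mean[i] += delta / self.n
            self.m2[i] += delta * (xi - self.mean[i])  # uses the updated mean

    def variance(self):
        """Unbiased sample variance per coordinate, e.g. a diagonal inverse mass matrix."""
        if self.n < 2:
            raise ValueError("need at least two samples")
        return [m2 / (self.n - 1) for m2 in self.m2]
```

The single-pass update avoids storing all warmup samples. It also avoids the cancellation error of the naive sum-of-squares formula.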
NUTS adds the following behavior on top of HMC:
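- Tree doubling: Recursively builds a binary tree of leapfrog states, doubling the trajectory in a randomly chosen direction (forward or backward in time) until a U-turn is detected or max_tree_depth is reached.
- U-turn detection: Checks whether the trajectory has started to double back on itself. This happens when the momentum at either end points back toward the other end, so that further integration would bring the endpoints closer together.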
- Multinomial sampling: When use_multinomial_sampling=True (the default), the returned state is drawn from the whole trajectory with probability proportional to exp(-H), where H is the total energy of each state. This generally makes better use of the trajectory than the slice sampling of the original NUTS paper, which remains available with use_multinomial_sampling=False.
- Divergence tracking: Records the number of divergent transitions (leapfrog integrations whose energy error blows up numerically) for convergence diagnostics.
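The selection rule and the divergence check can be sketched as follows. This is a simplified illustration with hypothetical names. It picks from a finished list of states, whereas NUTS implementations typically sample progressively as subtrees are merged. The divergence threshold of 1000 on the energy error is an assumed value. It is a common choice in NUTS implementations, but this sketch does not confirm that Pyro uses it.

```python
import math
import random


def selection_probs(energies):
    """Multinomial weights proportional to exp(-H) for each trajectory state."""
    e_min = min(energies)  # shift by the minimum so exp() cannot overflow
    weights = [math.exp(-(e - e_min)) for e in energies]
    total = sum(weights)
    return [w / total for w in weights]


def multinomial_select(states, energies, rng=random):
    """Draw one state; lower-energy states are proportionally more likely."""
    return rng.choices(states, weights=selection_probs(energies), k=1)[0]


def is_divergent(energy, initial_energy, max_delta_energy=1000.0):
    """Flag a state whose energy error exceeds an assumed threshold (or is NaN)."""
    return math.isnan(energy) or energy - initial_energy > max_delta_energy
```

For example, two states whose energies differ by ln 2 are selected with probabilities 2/3 and 1/3.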
Code Reference
Source Location
Pyro repo, file: pyro/infer/mcmc/nuts.py, lines L55-523.
Class Hierarchy
NUTS inherits from HMC, which inherits from MCMCKernel.
Signature
```python
class NUTS(HMC):
    def __init__(
        self,
        model=None,
        potential_fn=None,
        step_size=1,
        adapt_step_size=True,
        adapt_mass_matrix=True,
        full_mass=False,
        use_multinomial_sampling=True,
        transforms=None,
        max_plate_nesting=None,
        jit_compile=False,
        jit_options=None,
        ignore_jit_warnings=False,
        target_accept_prob=0.8,
        max_tree_depth=10,
        init_strategy=init_to_uniform,
    ):
```
Import
```python
from pyro.infer.mcmc import NUTS
```
I/O Contract
Constructor Inputs
| Parameter | Type | Required | Description |
|---|---|---|---|
| model | callable | No* | A Pyro model (a Python callable containing Pyro primitives). |
| potential_fn | callable | No* | A Python callable that computes the potential energy from a dict of unconstrained (real-support) parameters. Alternative to model. |
| step_size | float | No | Initial leapfrog step size. Defaults to 1. Adapted during warmup if adapt_step_size=True. |
| adapt_step_size | bool | No | Whether to adapt the step size during warmup using dual averaging. Defaults to True. |
| adapt_mass_matrix | bool | No | Whether to adapt the mass matrix during warmup using the Welford algorithm. Defaults to True. |
| full_mass | bool | No | If True, uses a dense mass matrix (full covariance); if False, uses a diagonal mass matrix. Defaults to False. |
| use_multinomial_sampling | bool | No | If True, draws the returned state from the trajectory by multinomial sampling; if False, uses slice sampling. Defaults to True. |
| transforms | dict | No | Maps sample site names to transforms into unconstrained space. If None, transforms are derived automatically from each site's support via torch.distributions.constraint_registry. |
| max_plate_nesting | int | No | Optional bound on the number of nested pyro.plate contexts in the model. Required if the model contains discrete sample sites to be enumerated in parallel. |
| jit_compile | bool | No | Whether to JIT-compile the potential function using torch.jit.trace. Defaults to False. |
| jit_options | dict | No | Options passed to torch.jit.trace when jit_compile=True. |
| ignore_jit_warnings | bool | No | Whether to suppress warnings from the JIT tracer. Defaults to False. |
| target_accept_prob | float | No | Target acceptance probability for dual averaging step size adaptation. Higher values lead to smaller step sizes, making sampling slower but more robust. Defaults to 0.8. |
| max_tree_depth | int | No | Maximum depth of the NUTS tree (trajectory length up to 2^max_tree_depth leapfrog steps). Defaults to 10. |
| init_strategy | callable | No | Per-site strategy for initializing parameter values. Defaults to init_to_uniform. |

*Exactly one of model or potential_fn must be provided.
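The dual averaging scheme controlled by adapt_step_size and target_accept_prob can be sketched as below. The sketch follows Algorithm 5 of the NUTS paper and uses that paper's suggested constants: gamma=0.05, t0=10, kappa=0.75 and mu=log(10 * initial step size). The class name DualAveragingStepSize is hypothetical. Pyro's own adapter may differ in details, such as how it restarts across warmup windows.

```python
import math


class DualAveragingStepSize:
    """Dual averaging step size adaptation (Hoffman & Gelman, Algorithm 5)."""

    def __init__(self, initial_step_size, target_accept_prob=0.8,
                 gamma=0.05, t0=10.0, kappa=0.75):
        self.mu = math.log(10.0 * initial_step_size)  # log step size the iterates shrink toward
        self.target = target_accept_prob
        self.gamma, self.t0, self.kappa = gamma, t0, kappa
        self.t = 0
        self.h_bar = 0.0  # running average of (target - observed acceptance)
        self.log_step = math.log(initial_step_size)
        self.log_step_avg = 0.0  # averaged iterate, returned at the end of warmup

    def update(self, accept_prob):
        """Record one iteration's acceptance probability; return the next step size."""
        self.t += 1
        eta = 1.0 / (self.t + self.t0)
        self.h_bar = (1.0 - eta) * self.h_bar + eta * (self.target - accept_prob)
        self.log_step = self.mu - math.sqrt(self.t) / self.gamma * self.h_bar
        w = self.t ** (-self.kappa)
        self.log_step_avg = w * self.log_step + (1.0 - w) * self.log_step_avg
        return math.exp(self.log_step)

    def final_step_size(self):
        """Step size to freeze after warmup (the averaged iterate)."""
        return math.exp(self.log_step_avg)
```

As an illustration, assume a toy acceptance model accept_prob = exp(-step_size) and set target_accept_prob=0.8. Repeatedly calling update and then final_step_size() settles near -ln(0.8) ≈ 0.223, the step size at which this toy model accepts 80% of proposals. Acceptance below the target pushes the step size down, and acceptance above it pushes the step size up.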
Outputs
| Output | Type | Description |
|---|---|---|
| NUTS kernel instance | NUTS | A kernel object to be passed to the MCMC class for sampling. Provides sample(), setup(), and logging() methods. |
Usage Examples
Basic NUTS Sampling
```python
import torch
import pyro
import pyro.distributions as dist
from pyro.infer.mcmc import NUTS, MCMC


def model(data):
    mu = pyro.sample("mu", dist.Normal(0.0, 10.0))
    sigma = pyro.sample("sigma", dist.HalfNormal(10.0))
    with pyro.plate("data", len(data)):
        pyro.sample("obs", dist.Normal(mu, sigma), obs=data)


data = torch.randn(100) * 2 + 5  # synthetic data with mean 5 and std 2
nuts_kernel = NUTS(model)
mcmc = MCMC(nuts_kernel, num_samples=1000, warmup_steps=500)
mcmc.run(data)
samples = mcmc.get_samples()
print(samples["mu"].mean())  # approximately 5.0
print(samples["sigma"].mean())  # approximately 2.0
```
NUTS with Custom Configuration
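```python
from pyro.infer.mcmc import NUTS, MCMC
from pyro.infer.autoguide.initialization import init_to_median

# Reuses `model` and `data` from the basic example above.
nuts_kernel = NUTS(
    model,
    step_size=0.1,           # initial value; still adapted during warmup
    adapt_step_size=True,
    adapt_mass_matrix=True,
    full_mass=True,          # dense mass matrix
    target_accept_prob=0.9,  # smaller, more robust steps
    max_tree_depth=12,       # up to 2^12 leapfrog steps per iteration
    jit_compile=True,
    init_strategy=init_to_median,
)
mcmc = MCMC(nuts_kernel, num_samples=2000, warmup_steps=1000, num_chains=4)
mcmc.run(data)
mcmc.summary()
```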