Implementation:Pyro ppl Pyro NUTS Kernel

From Leeroopedia


Metadata

| Field | Value |
|---|---|
| Page Type | Implementation (API Doc) |
| Knowledge Sources | Repo (Pyro), Paper (The No-U-Turn Sampler), Paper (A Conceptual Introduction to Hamiltonian Monte Carlo) |
| Domains | MCMC, Bayesian_Inference |
| Last Updated | 2026-02-09 12:00 GMT |

Overview

Concrete MCMC kernel implementing the No-U-Turn Sampler (NUTS) algorithm in Pyro, providing automatic trajectory length tuning on top of the HMC kernel.

Description

NUTS is a subclass of HMC that implements the No-U-Turn Sampler algorithm. It extends HMC by replacing the fixed-length leapfrog trajectory with an adaptive tree-doubling scheme. At each MCMC iteration, NUTS builds a balanced binary tree of leapfrog states, doubling the trajectory length until a U-turn condition is detected or the maximum tree depth is reached.
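The doubling idea can be illustrated with a deliberately simplified sketch. It is not Pyro's implementation: it works in one dimension with a unit mass, checks the U-turn condition only between the two outermost states (real NUTS also checks subtrees), and uses the criterion from the original NUTS paper. The names `leapfrog`, `is_u_turn`, and `doubling_trajectory` are hypothetical helpers introduced for illustration.

```python
import random


def leapfrog(q, p, step_size, grad_potential):
    # Half kick, full drift, half kick (unit mass, 1-D for brevity).
    p = p - 0.5 * step_size * grad_potential(q)
    q = q + step_size * p
    p = p - 0.5 * step_size * grad_potential(q)
    return q, p


def is_u_turn(q_left, p_left, q_right, p_right):
    # Stop once either endpoint's momentum points back toward the other end.
    span = q_right - q_left
    return span * p_left < 0 or span * p_right < 0


def doubling_trajectory(q0, p0, step_size, grad_potential, max_tree_depth, rng):
    q_left = q_right = q0
    p_left = p_right = p0
    n_steps = 0
    doublings = 0
    while doublings < max_tree_depth:
        direction = rng.choice([-1, 1])  # extend forward or backward in time
        for _ in range(2 ** doublings):  # each doubling adds 2**doublings steps
            if direction == 1:
                q_right, p_right = leapfrog(q_right, p_right, step_size, grad_potential)
            else:
                q_left, p_left = leapfrog(q_left, p_left, -step_size, grad_potential)
            n_steps += 1
        doublings += 1
        if is_u_turn(q_left, p_left, q_right, p_right):
            break
    return n_steps, doublings


# Toy target: standard normal, whose potential q**2 / 2 has gradient q.
rng = random.Random(0)
n_steps, doublings = doubling_trajectory(0.0, 1.0, 0.1, lambda q: q, 10, rng)
print(n_steps, doublings)
```

After k doublings the trajectory has taken 2^k - 1 leapfrog steps in total, which is why the cost of one iteration grows geometrically with the tree depth.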

The kernel inherits all adaptation machinery from HMC, including:

  • Step size adaptation via dual averaging, targeting a specified acceptance probability.
  • Mass matrix adaptation using the Welford algorithm to estimate the diagonal or full covariance of the posterior.
  • Automatic transforms that reparameterize constrained parameters to unconstrained space for more efficient sampling.
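As a rough illustration of the mass matrix bullet above, the following sketch shows a Welford-style running estimate of per-dimension variance, which is the quantity a diagonal mass matrix is built from. It is written in plain Python and is not Pyro's internal estimator, which may add regularization or handle the dense case differently; `WelfordVariance` is a hypothetical name.

```python
class WelfordVariance:
    """Online per-dimension mean and variance (hypothetical helper)."""

    def __init__(self, dim):
        self.count = 0
        self.mean = [0.0] * dim
        self.m2 = [0.0] * dim  # running sum of squared deviations

    def update(self, sample):
        self.count += 1
        for i, x in enumerate(sample):
            delta = x - self.mean[i]
            self.mean[i] += delta / self.count
            # Uses the updated mean, which keeps the recurrence numerically stable.
            self.m2[i] += delta * (x - self.mean[i])

    def variance(self):
        if self.count < 2:
            raise ValueError("need at least two samples")
        return [m2 / (self.count - 1) for m2 in self.m2]


estimator = WelfordVariance(dim=2)
for draw in [(1.0, 10.0), (2.0, 20.0), (3.0, 30.0), (4.0, 40.0)]:
    estimator.update(draw)
diag_variance = estimator.variance()
print(diag_variance)
```

The appeal of this scheme during warmup is that each new draw updates the estimate in constant memory, so no history of samples has to be stored.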

NUTS adds the following behavior on top of HMC:

  • Tree doubling: Recursively builds a binary tree of leapfrog states by doubling in a random direction until a U-turn is detected.
  • U-turn detection: Checks whether the trajectory endpoints begin to converge, indicating the path is curving back.
  • Multinomial sampling: When use_multinomial_sampling=True (default), selects the return state from the trajectory using multinomial weighting by energy, which yields lower-variance estimates than the original slice sampling approach.
  • Divergence tracking: Records the number of divergent transitions (numerical failures during leapfrog integration) for convergence diagnostics.
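The multinomial selection described above can be sketched as follows, assuming every state on the trajectory and its energy (Hamiltonian value) are available at once. This is a simplification for illustration, not Pyro's code; a practical implementation would typically update the selection incrementally while the tree grows. `multinomial_weights` and `select_state` are hypothetical names.

```python
import math
import random


def multinomial_weights(energies):
    # Weights proportional to exp(-energy); subtracting the minimum
    # energy first avoids underflow without changing the ratios.
    min_energy = min(energies)
    unnormalized = [math.exp(-(e - min_energy)) for e in energies]
    total = sum(unnormalized)
    return [w / total for w in unnormalized]


def select_state(states, energies, rng):
    # Draw one state, favoring those with lower energy.
    weights = multinomial_weights(energies)
    return rng.choices(states, weights=weights, k=1)[0]


states = ["a", "b", "c"]
energies = [1.0, 2.0, 3.0]
weights = multinomial_weights(energies)
chosen = select_state(states, energies, random.Random(0))
print(weights, chosen)
```

States with lower energy receive exponentially more weight, so a state whose energy drifted upward through integration error is rarely chosen.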

Code Reference

Source Location

Pyro repo, file: pyro/infer/mcmc/nuts.py, lines L55-523.

Class Hierarchy

NUTS inherits from HMC, which inherits from MCMCKernel.

Signature

class NUTS(HMC):
    def __init__(
        self,
        model=None,
        potential_fn=None,
        step_size=1,
        adapt_step_size=True,
        adapt_mass_matrix=True,
        full_mass=False,
        use_multinomial_sampling=True,
        transforms=None,
        max_plate_nesting=None,
        jit_compile=False,
        jit_options=None,
        ignore_jit_warnings=False,
        target_accept_prob=0.8,
        max_tree_depth=10,
        init_strategy=init_to_uniform,
    ):

Import

from pyro.infer.mcmc import NUTS

I/O Contract

Constructor Inputs

| Parameter | Type | Required | Description |
|---|---|---|---|
| model | callable | No* | A Pyro model (stochastic function). Either model or potential_fn must be provided. |
| potential_fn | callable | No* | A Python callable that computes the potential energy given unconstrained parameters. Alternative to model. |
| step_size | float | No | Initial leapfrog step size. Defaults to 1. Adapted during warmup if adapt_step_size=True. |
| adapt_step_size | bool | No | Whether to adapt the step size during warmup using dual averaging. Defaults to True. |
| adapt_mass_matrix | bool | No | Whether to adapt the mass matrix during warmup using the Welford algorithm. Defaults to True. |
| full_mass | bool | No | If True, uses a dense mass matrix (full covariance). If False, uses a diagonal mass matrix. Defaults to False. |
| use_multinomial_sampling | bool | No | If True, uses multinomial sampling from the trajectory. If False, uses slice sampling. Defaults to True. |
| transforms | dict | No | Dictionary mapping sample site names to transforms into unconstrained space. If None, transforms are inferred automatically from the support of each sample site's distribution. |
| max_plate_nesting | int | No | Bound on the maximum number of nested pyro.plate contexts in the model. Needed when the model contains discrete sample sites that are enumerated in parallel. |
| jit_compile | bool | No | Whether to JIT-compile the potential function using torch.jit.trace. Defaults to False. |
| jit_options | dict | No | Options passed to torch.jit.trace when JIT compilation is enabled. |
| ignore_jit_warnings | bool | No | Whether to suppress JIT compilation warnings. Defaults to False. |
| target_accept_prob | float | No | Target Metropolis acceptance probability for dual averaging step size adaptation. Defaults to 0.8. |
| max_tree_depth | int | No | Maximum depth of the NUTS tree, which caps the trajectory at on the order of 2^max_tree_depth leapfrog steps per iteration. Defaults to 10. |
| init_strategy | callable | No | Strategy for initializing parameter values. Defaults to init_to_uniform. |

*Either model or potential_fn must be supplied.
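To make the role of target_accept_prob more concrete, here is a minimal sketch of dual averaging in the form given in the NUTS paper, with that paper's suggested constants (gamma=0.05, t0=10, kappa=0.75). It is an illustration under those assumptions and not Pyro's internal adapter; `DualAveragingStepSize` and `toy_accept_prob` are hypothetical names, and the toy acceptance curve stands in for a real sampler.

```python
import math


class DualAveragingStepSize:
    """Step size adaptation toward a target acceptance probability."""

    def __init__(self, initial_step_size, target_accept_prob=0.8,
                 gamma=0.05, t0=10.0, kappa=0.75):
        self.mu = math.log(10.0 * initial_step_size)  # shrinkage point
        self.target = target_accept_prob
        self.gamma = gamma
        self.t0 = t0
        self.kappa = kappa
        self.h_bar = 0.0         # running average of (target - accept_prob)
        self.log_step_bar = 0.0  # averaged log step size
        self.t = 0

    def update(self, accept_prob):
        self.t += 1
        eta = 1.0 / (self.t + self.t0)
        self.h_bar = (1.0 - eta) * self.h_bar + eta * (self.target - accept_prob)
        log_step = self.mu - math.sqrt(self.t) / self.gamma * self.h_bar
        weight = self.t ** (-self.kappa)
        self.log_step_bar = weight * log_step + (1.0 - weight) * self.log_step_bar
        return math.exp(log_step)  # step size to try on the next iteration

    def final_step_size(self):
        return math.exp(self.log_step_bar)  # value to freeze after warmup


def toy_accept_prob(step_size):
    # Assumed stand-in: acceptance decays smoothly as the step size grows.
    return math.exp(-step_size)


adapter = DualAveragingStepSize(initial_step_size=1.0, target_accept_prob=0.8)
step_size = 1.0
for _ in range(2000):
    step_size = adapter.update(toy_accept_prob(step_size))
adapted = adapter.final_step_size()
print(adapted, toy_accept_prob(adapted))
```

Raising target_accept_prob pushes the adapted step size down, which tends to reduce divergences at the cost of more leapfrog steps per iteration.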

Outputs

| Output | Type | Description |
|---|---|---|
| NUTS kernel instance | NUTS | A kernel object to be passed to the MCMC class for sampling. Provides sample(), setup(), and logging() methods. |

Usage Examples

Basic NUTS Sampling

import torch
import pyro
import pyro.distributions as dist
from pyro.infer.mcmc import NUTS, MCMC

def model(data):
    mu = pyro.sample("mu", dist.Normal(0, 10))
    sigma = pyro.sample("sigma", dist.HalfNormal(10))
    with pyro.plate("data", len(data)):
        pyro.sample("obs", dist.Normal(mu, sigma), obs=data)

data = torch.randn(100) * 2 + 5

nuts_kernel = NUTS(model)
mcmc = MCMC(nuts_kernel, num_samples=1000, warmup_steps=500)
mcmc.run(data)

samples = mcmc.get_samples()
print(samples["mu"].mean())   # approximately 5.0
print(samples["sigma"].mean())  # approximately 2.0

NUTS with Custom Configuration

# Reuses `model` and `data` from the basic example above.
from pyro.infer.mcmc import NUTS, MCMC
from pyro.infer.autoguide.initialization import init_to_median

nuts_kernel = NUTS(
    model,
    step_size=0.1,
    adapt_step_size=True,
    adapt_mass_matrix=True,
    full_mass=True,
    target_accept_prob=0.9,
    max_tree_depth=12,
    jit_compile=True,
    init_strategy=init_to_median,
)

mcmc = MCMC(nuts_kernel, num_samples=2000, warmup_steps=1000, num_chains=4)
mcmc.run(data)
mcmc.summary()
