

Implementation:Pyro ppl Pyro NanMaskedDistributions

From Leeroopedia
Revision as of 16:24, 16 February 2026 by Admin (auto-imported from implementations/Pyro_ppl_Pyro_NanMaskedDistributions.md)


Knowledge Sources
Domains: Probability_Distributions
Last Updated: 2026-02-09 09:00 GMT

Overview

Description

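The NanMasked module provides two distribution wrappers for partially observed data in which missing values are encoded as NaN: NanMaskedNormal and NanMaskedMultivariateNormal. Both classes override the log_prob method to treat NaN entries as missing data and marginalize them out. For the elementwise NanMaskedNormal this amounts to assigning each missing entry zero log probability; for NanMaskedMultivariateNormal the missing dimensions are integrated out of the joint density.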
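Why a zero log-density is the same as marginalization follows from integrating the missing coordinates out of the density. Writing O for the indices of observed entries, M for the missing ones, and sigma_i for the scale (standard deviation), the elementwise case and the multivariate case are:

```latex
\begin{aligned}
\log p(x_O) &= \log \int \prod_{i} \mathcal{N}(x_i;\mu_i,\sigma_i^2)\,dx_M
  = \sum_{i \in O} \log \mathcal{N}(x_i;\mu_i,\sigma_i^2)
  + \sum_{i \in M} \log \underbrace{\int \mathcal{N}(x_i;\mu_i,\sigma_i^2)\,dx_i}_{=\,1}
  = \sum_{i \in O} \log \mathcal{N}(x_i;\mu_i,\sigma_i^2) \\[6pt]
\int \mathcal{N}\!\left(
  \begin{bmatrix} x_O \\ x_M \end{bmatrix};
  \begin{bmatrix} \mu_O \\ \mu_M \end{bmatrix},
  \begin{bmatrix} \Sigma_{OO} & \Sigma_{OM} \\ \Sigma_{MO} & \Sigma_{MM} \end{bmatrix}
\right) dx_M &= \mathcal{N}(x_O;\mu_O,\Sigma_{OO})
\end{aligned}
```

In the first line each missing factor integrates to 1, and log 1 = 0 is exactly the value returned for a missing entry. The second line is the standard Gaussian marginal: keep the observed entries of the mean and the observed rows and columns of the covariance.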

NanMaskedNormal extends Normal and handles elementwise missing data. When log_prob is called, it identifies the finite (non-NaN) entries of value, computes the normal log probability only for those entries, and returns zeros at the NaN positions. Before filtering, value, loc, and scale are explicitly broadcast to a common shape so that the mask of finite entries lines up elementwise with the parameters.
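The masking step can be sketched with plain torch.distributions. The function masked_normal_log_prob below is a hypothetical helper written for illustration, not Pyro's source. It assumes value, loc, and scale are floating-point tensors of the same dtype that broadcast together, and it treats any non-finite entry as missing.

```python
import torch
from torch.distributions import Normal


def masked_normal_log_prob(value, loc, scale):
    """Hypothetical helper (not Pyro's code): elementwise Normal log-density
    in which non-finite entries of `value` are treated as missing and get 0."""
    ok = torch.isfinite(value)
    # Broadcast everything first so that the boolean mask, loc and scale
    # all have value's shape and can be indexed with the same mask.
    value, ok, loc, scale = torch.broadcast_tensors(value, ok, loc, scale)
    result = value.new_zeros(value.shape)  # missing entries stay at 0.0
    if ok.any():
        # Evaluate the Normal only on the finite entries.
        result[ok] = Normal(loc[ok], scale[ok]).log_prob(value[ok])
    return result


x = torch.tensor([1.0, float("nan"), -0.5, float("nan")])
print(masked_normal_log_prob(x, torch.zeros(4), torch.ones(4)))  # NaN slots are 0.0
```

Broadcasting before boolean indexing is what lets a lower-dimensional loc or scale be filtered with the same mask as value; indexing an unbroadcast parameter with value's mask would raise a shape error.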

NanMaskedMultivariateNormal extends MultivariateNormal and handles missing dimensions within multivariate observations, using a more involved marginalization strategy. It groups observations by their missingness pattern (which dimensions are observed and which are missing). For each unique pattern, it constructs a marginal multivariate normal over only the observed dimensions by selecting the corresponding entries of the mean vector and the corresponding rows and columns of the covariance matrix. This marginalization is exact, although an implementation note acknowledges that broadcasting all inputs eagerly may add some computational overhead.
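The per-pattern strategy can be sketched in the same way. masked_mvn_log_prob is a hypothetical helper, not Pyro's implementation. To stay short it assumes a single shared mean of shape (p,), a covariance of shape (p, p), and a 2-D value of shape (n, p), whereas the Pyro class also broadcasts batched parameters. Rows with no observed dimension are given log probability 0.

```python
import torch
from torch.distributions import MultivariateNormal


def masked_mvn_log_prob(value, loc, cov):
    """Hypothetical helper (not Pyro's code): log-density of each row of
    `value` (n, p) under N(loc, cov), marginalizing out non-finite entries."""
    ok = torch.isfinite(value)
    result = value.new_zeros(value.shape[0])
    # Build one marginal MVN per distinct missingness pattern.
    for pattern in set(map(tuple, ok.tolist())):
        cols = torch.tensor(pattern, dtype=torch.bool)
        if not cols.any():
            continue  # nothing observed in these rows: log-density stays 0
        rows = (ok == cols).all(dim=-1)  # rows sharing exactly this pattern
        # Marginal over observed dims: sub-select mean entries and the
        # matching rows and columns of the covariance.
        marginal = MultivariateNormal(loc[cols], cov[cols][:, cols])
        result[rows] = marginal.log_prob(value[rows][:, cols])
    return result
```

In this sketch the covariance is shared, so looping over distinct patterns rather than over rows costs one Cholesky factorization per pattern, which stays cheap when the data contain only a few missingness patterns.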

Both classes are designed as drop-in replacements for their base distributions, requiring no changes to model code beyond swapping the distribution class.

Usage

These distributions are useful in any Bayesian model where the observed data has missing values represented as NaN. Common scenarios include time series with irregular observations, survey data with unanswered questions, and sensor data with dropout. By using NanMasked wrappers, the model can process incomplete data without requiring manual imputation or data preprocessing.
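The wrappers expect missingness to be encoded as NaN in the observed tensor. If raw data instead arrive with Python None values or with a separate boolean mask of recorded entries (both assumed here for illustration), a minimal way to produce such a tensor is:

```python
import math
import torch

# Hypothetical raw input: None marks a reading that was never recorded.
raw = [0.4, None, 1.3, None, 0.7]
obs = torch.tensor([math.nan if v is None else v for v in raw])

# Hypothetical alternative: values plus a boolean mask of recorded entries.
readings = torch.tensor([0.4, 0.0, 1.3, 0.0, 0.7])
recorded = torch.tensor([True, False, True, False, True])
# Overwrite the unrecorded slots with NaN so they are treated as missing.
obs_from_mask = readings.masked_fill(~recorded, math.nan)
```

Either tensor can then be passed as obs=... without imputing the gaps.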

Code Reference

Source Location

Signature

class NanMaskedNormal(Normal):
    def log_prob(self, value: torch.Tensor) -> torch.Tensor

class NanMaskedMultivariateNormal(MultivariateNormal):
    def log_prob(self, value: torch.Tensor) -> torch.Tensor

Import

from pyro.distributions import NanMaskedNormal, NanMaskedMultivariateNormal

I/O Contract

Inputs (NanMaskedNormal)

| Parameter | Type | Description |
|---|---|---|
| loc | torch.Tensor | Mean of the normal distribution (inherited from Normal). |
| scale | torch.Tensor | Standard deviation of the normal distribution (inherited from Normal). |

Inputs (NanMaskedMultivariateNormal)

| Parameter | Type | Description |
|---|---|---|
| loc | torch.Tensor | Mean vector of the multivariate normal (inherited from MultivariateNormal). |
| covariance_matrix / precision_matrix / scale_tril | torch.Tensor | Covariance parameterization; exactly one of the three is given (inherited from MultivariateNormal). |

Outputs

| Method | Return Type | Description |
|---|---|---|
| NanMaskedNormal.log_prob(value) | torch.Tensor | Log probability for finite entries; zero for NaN entries. Shape matches the broadcast of value, loc, and scale. |
| NanMaskedMultivariateNormal.log_prob(value) | torch.Tensor | Marginal log probability computed over the observed (non-NaN) dimensions only. Shape is value.shape[:-1], after broadcasting value against the distribution's batch shape. |

Usage Examples

import torch
import pyro
from pyro.distributions import NanMaskedNormal
from math import nan

# Univariate case: partially observed 1-D data
data = torch.tensor([0.5, 0.1, nan, 0.9])

with pyro.plate("data", len(data)):
    pyro.sample("obs", NanMaskedNormal(0, 1), obs=data)

import torch
import pyro
from pyro.distributions import NanMaskedMultivariateNormal
from math import nan

# Multivariate case: some dimensions missing per observation
data = torch.tensor([
    [0.1, 0.2, 3.4],
    [0.5, 0.1, nan],
    [0.6, nan, nan],
    [nan, 0.5, nan],
    [nan, nan, nan],
])

with pyro.plate("data", len(data)):
    pyro.sample(
        "obs",
        NanMaskedMultivariateNormal(torch.zeros(3), torch.eye(3)),
        obs=data,
    )
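Because this example uses a zero mean and identity covariance, the dimensions are independent, so each row's marginal log probability should reduce to a sum of standard-normal log-densities over its observed entries. The torch-only sketch below computes those reference values; it does not call Pyro, so comparing them against NanMaskedMultivariateNormal(torch.zeros(3), torch.eye(3)).log_prob(data) is left as a check to run alongside it.

```python
import math
import torch
from torch.distributions import Normal

nan = math.nan
data = torch.tensor([
    [0.1, 0.2, 3.4],
    [0.5, 0.1, nan],
    [0.6, nan, nan],
    [nan, 0.5, nan],
    [nan, nan, nan],
])
ok = torch.isfinite(data)

# Normal.log_prob validates its input by default and rejects NaN, so the
# missing slots get a placeholder 0.0 and are zeroed out again afterwards.
safe = data.masked_fill(~ok, 0.0)
per_entry = Normal(0.0, 1.0).log_prob(safe).masked_fill(~ok, 0.0)
expected = per_entry.sum(dim=-1)  # one value per row, shape data.shape[:-1]
print(expected)  # the fully missing last row contributes 0.0
```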

import torch
from pyro.distributions import NanMaskedNormal
from math import nan

# Direct log_prob computation
dist = NanMaskedNormal(torch.zeros(4), torch.ones(4))
value = torch.tensor([1.0, nan, -0.5, nan])
log_p = dist.log_prob(value)
print(log_p)  # Non-NaN entries get normal log_prob; NaN entries get 0.0
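For the direct log_prob call, the finite entries should receive the standard normal log-density -x^2/2 - log(2*pi)/2, which is about -1.4189 for 1.0 and -1.0439 for -0.5. The torch-only sketch below evaluates that formula with the same masking, as a reference for the printed log_p; it does not call Pyro.

```python
import math
import torch

value = torch.tensor([1.0, math.nan, -0.5, math.nan])
ok = torch.isfinite(value)

# Standard normal log-density, -x**2 / 2 - log(2*pi) / 2, on finite entries
# only; the missing entries keep the value 0.0.
expected = torch.zeros_like(value)
expected[ok] = -0.5 * value[ok] ** 2 - 0.5 * math.log(2 * math.pi)
print(expected)
```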
