Implementation:Pyro ppl Pyro ScaleMessenger

From Leeroopedia


File: pyro/poutine/scale_messenger.py
Module: pyro.poutine.scale_messenger
Lines: 54
Parent class: Messenger
Purpose: Multiplicatively scale the log probabilities of all sample and observe sites
License: Apache-2.0 (Uber Technologies, Inc.)

Overview

ScaleMessenger multiplicatively scales the log probabilities of all sample and observe sites within its scope. Given a positive scale factor, it multiplies the msg["scale"] field of every message it handles; Trace.log_prob_sum() later multiplies each sample site's log probability by that field before summing.
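The arithmetic can be sketched in plain Python. This is a toy model rather than Pyro's code: sites are dicts with hypothetical `log_prob` and `scale` keys, and `toy_log_prob_sum` (also hypothetical) stands in for `Trace.log_prob_sum()` by weighting each site's log probability with its accumulated scale before summing.

```python
import math


def normal_log_prob(x: float, loc: float, sd: float) -> float:
    """Log density of a univariate Normal(loc, sd) evaluated at x."""
    return -math.log(sd) - 0.5 * math.log(2 * math.pi) - (x - loc) ** 2 / (2 * sd**2)


def toy_log_prob_sum(sites: list) -> float:
    """Toy stand-in for Trace.log_prob_sum(): sum of scale * log_prob over sites."""
    return sum(site["scale"] * site["log_prob"] for site in sites)


# Every message starts with scale 1.0. A scale handler with factor 0.1 wraps
# only the "z" site, so only that site's log probability is down-weighted.
z_site = {"name": "z", "log_prob": normal_log_prob(0.3, 0.0, 1.0), "scale": 1.0}
x_site = {"name": "x", "log_prob": normal_log_prob(1.2, 0.3, 0.5), "scale": 1.0}

z_site["scale"] = 0.1 * z_site["scale"]  # the update ScaleMessenger(0.1) would apply

total = toy_log_prob_sum([z_site, x_site])
print(total)
```

The site outside the handler's scope keeps its default scale of 1.0 and contributes its full log probability.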

This handler operates on all message types (not just sample sites), as it overrides _process_message rather than _pyro_sample.
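The difference between the two hooks can be modeled with toy classes. This sketch assumes, as the sentence above implies, that the base class's default `_process_message` forwards a message to a method named `_pyro_<type>` when one exists; `ToyMessenger`, `SampleOnlyScale`, and `AllMessagesScale` are hypothetical names, not Pyro classes.

```python
class ToyMessenger:
    """Toy model of per-type dispatch (not Pyro's Messenger)."""

    def _process_message(self, msg: dict) -> None:
        # Default behavior: route to a type-specific hook such as _pyro_sample.
        method = getattr(self, "_pyro_{}".format(msg["type"]), None)
        if method is not None:
            method(msg)


class SampleOnlyScale(ToyMessenger):
    def __init__(self, scale: float) -> None:
        self.scale = scale

    def _pyro_sample(self, msg: dict) -> None:
        # Reached only for messages whose type is "sample".
        msg["scale"] = self.scale * msg["scale"]


class AllMessagesScale(ToyMessenger):
    def __init__(self, scale: float) -> None:
        self.scale = scale

    def _process_message(self, msg: dict) -> None:
        # Overriding _process_message bypasses per-type dispatch entirely.
        msg["scale"] = self.scale * msg["scale"]


msgs_a = [{"type": "sample", "scale": 1.0}, {"type": "param", "scale": 1.0}]
msgs_b = [dict(m) for m in msgs_a]
for m in msgs_a:
    SampleOnlyScale(0.5)._process_message(m)
for m in msgs_b:
    AllMessagesScale(0.5)._process_message(m)

print([m["scale"] for m in msgs_a])  # [0.5, 1.0]
print([m["scale"] for m in msgs_b])  # [0.5, 0.5]
```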

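The scale must be strictly positive. The constructor always raises ValueError for a non-positive Python number; for a torch.Tensor scale, the element-wise check runs only when validation is enabled (is_validation_enabled()). For masking with zeros or booleans, use MaskMessenger (poutine.mask) instead.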
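The constructor's branching can be mirrored in a small pure-Python function. `check_scale` is hypothetical, Python lists stand in for tensors, and the `validation_enabled` flag stands in for `is_validation_enabled()`.

```python
from typing import Sequence, Union


def check_scale(scale: Union[float, Sequence[float]], validation_enabled: bool) -> None:
    """Toy mirror of ScaleMessenger's constructor checks (lists stand in for tensors)."""
    if isinstance(scale, (list, tuple)):
        # Tensor-like scales are checked only when validation is enabled.
        if validation_enabled and not all(s > 0 for s in scale):
            raise ValueError("Expected scale > 0 but got {}".format(scale))
    elif not scale > 0:
        # Scalar scales are always checked.
        raise ValueError("Expected scale > 0 but got {}".format(scale))


check_scale(0.5, validation_enabled=False)         # passes
check_scale([0.0, 1.0], validation_enabled=False)  # passes: the check is skipped
try:
    check_scale(0.0, validation_enabled=False)
except ValueError as err:
    print(err)  # Expected scale > 0 but got 0.0
```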

Code Reference

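from typing import TYPE_CHECKING, Union

import torch

from pyro.poutine.messenger import Messenger
from pyro.poutine.util import is_validation_enabled

if TYPE_CHECKING:
    from pyro.poutine.runtime import Message


class ScaleMessenger(Messenger):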
    def __init__(self, scale: Union[float, torch.Tensor]) -> None:
        if isinstance(scale, torch.Tensor):
            if is_validation_enabled() and not (scale > 0).all():
                raise ValueError(
                    "Expected scale > 0 but got {}. "
                    "Consider using poutine.mask() instead of poutine.scale().".format(scale)
                )
        elif not (scale > 0):
            raise ValueError("Expected scale > 0 but got {}".format(scale))
        super().__init__()
        self.scale = scale

    def _process_message(self, msg: "Message") -> None:
        msg["scale"] = self.scale * msg["scale"]

I/O Contract

Parameter: scale (float or torch.Tensor), a positive factor that multiplies the log probabilities of the sites in scope.
Message effect: msg["scale"] is multiplied by the scale factor, so it composes multiplicatively with other scale effects.

Validation: Raises ValueError if a scalar scale is not positive; a tensor scale is checked element-wise only when validation is enabled.

Usage Examples

Scaling Log Probabilities

import torch

import pyro
import pyro.distributions as dist

def model(x):
    s = pyro.param("s", torch.tensor(0.5))
    pyro.sample("z", dist.Normal(x, s), obs=torch.tensor(1.0))

scaled_model = pyro.poutine.scale(model, scale=0.5)
scaled_tr = pyro.poutine.trace(scaled_model).get_trace(0.0)
unscaled_tr = pyro.poutine.trace(model).get_trace(0.0)

assert (scaled_tr.log_prob_sum() == 0.5 * unscaled_tr.log_prob_sum()).all()
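For the concrete numbers in this example (observation 1.0, location x = 0.0, and the initial value s = 0.5), the single observed site contributes roughly:

```latex
\log \mathcal{N}(1.0 \mid 0.0,\ 0.5)
  = -\log 0.5 - \tfrac{1}{2}\log(2\pi) - \frac{(1.0 - 0.0)^2}{2 \cdot 0.5^2}
  \approx 0.6931 - 0.9189 - 2.0000 = -2.2258,
\qquad
0.5 \times (-2.2258) = -1.1129.
```

Multiplying by 0.5 is exact in binary floating point, which is presumably why the assertion can compare with `==` rather than a tolerance.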

KL Annealing

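import pyro
import pyro.distributions as dist

# Use scale for KL annealing during training. Counting from epoch + 1 keeps
# beta > 0; scale=0 would raise ValueError (use poutine.mask to drop a site).
warmup_epochs, epoch = 10, 0  # example values; normally set by the training loop
beta = min(1.0, (epoch + 1) / warmup_epochs)

with pyro.poutine.scale(scale=beta):
    z = pyro.sample("z", dist.Normal(0, 1))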
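The schedule and its effect on a single-sample objective can be sketched in plain Python. `kl_beta` and `annealed_elbo` are hypothetical helpers; the second assumes the same `beta` scale wraps the `z` site in both the model and the guide, so the log p(z) and log q(z) terms are both weighted by beta. The schedule counts from `epoch + 1` because ScaleMessenger rejects a scale of exactly 0.

```python
def kl_beta(epoch: int, warmup_epochs: int) -> float:
    """Hypothetical linear warm-up; (epoch + 1) keeps beta strictly positive."""
    return min(1.0, (epoch + 1) / warmup_epochs)


def annealed_elbo(log_lik: float, log_p_z: float, log_q_z: float, beta: float) -> float:
    """Single-sample ELBO with the z site scaled by beta in model and guide."""
    return log_lik + beta * log_p_z - beta * log_q_z


betas = [kl_beta(e, warmup_epochs=4) for e in range(6)]
print(betas)  # [0.25, 0.5, 0.75, 1.0, 1.0, 1.0]
```

With beta = 1.0 the annealed objective reduces to the ordinary single-sample ELBO.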

Tensor-Valued Scale

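import torch
import pyro
import pyro.distributions as dist
# Per-element scaling: the weights broadcast against the site's batch shape [3]
importance_weights = torch.tensor([0.5, 1.0, 2.0])
data = torch.tensor([0.1, -0.3, 0.8])  # example observations
with pyro.poutine.scale(scale=importance_weights):
    pyro.sample("obs", dist.Normal(0, 1).expand([3]), obs=data)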
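Numerically, a tensor-valued scale weights each element's log probability before the sum. The sketch below reproduces that weighting with the standard library only; `normal_log_prob` is a hypothetical helper and the `data` values are made up for illustration.

```python
import math


def normal_log_prob(x: float, loc: float = 0.0, sd: float = 1.0) -> float:
    """Log density of a univariate Normal(loc, sd) evaluated at x."""
    return -math.log(sd) - 0.5 * math.log(2 * math.pi) - (x - loc) ** 2 / (2 * sd**2)


importance_weights = [0.5, 1.0, 2.0]
data = [0.1, -0.3, 0.8]  # hypothetical observations

# Element-wise: each observation's log probability is multiplied by its own
# weight and the results are summed, as with a tensor-valued msg["scale"].
per_element = [w * normal_log_prob(x) for w, x in zip(importance_weights, data)]
weighted_total = sum(per_element)
print(weighted_total)
```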
