Implementation:Pyro ppl Pyro ScaleMessenger
| Attribute | Value |
|---|---|
| File | pyro/poutine/scale_messenger.py |
| Module | pyro.poutine.scale_messenger |
| Lines | 54 |
| Parent Class | Messenger |
| Purpose | Multiplicatively scale the log probabilities of all sample and observe sites |
| License | Apache-2.0 (Uber Technologies, Inc.) |
Overview
ScaleMessenger multiplicatively scales the log probabilities of all sample and observe sites within its scope. Given a positive scale factor, it multiplies the msg["scale"] field of every message, which is later used when computing log_prob_sum in the Trace.
This handler operates on all message types (not just sample sites), as it overrides _process_message rather than _pyro_sample.
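The scale must be strictly positive. For a plain Python number the constructor always checks this and raises ValueError on a non-positive value. For a torch.Tensor the check runs only when Pyro validation is enabled (is_validation_enabled()) and requires every element to be positive. For zero/boolean masking, use MaskMessenger (poutine.mask()) instead.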
Code Reference
class ScaleMessenger(Messenger):
    def __init__(self, scale: Union[float, torch.Tensor]) -> None:
        if isinstance(scale, torch.Tensor):
            if is_validation_enabled() and not (scale > 0).all():
                raise ValueError(
                    "Expected scale > 0 but got {}. "
                    "Consider using poutine.mask() instead of poutine.scale().".format(scale)
                )
        elif not (scale > 0):
            raise ValueError("Expected scale > 0 but got {}".format(scale))
        super().__init__()
        self.scale = scale

    def _process_message(self, msg: "Message") -> None:
        msg["scale"] = self.scale * msg["scale"]
I/O Contract
| Parameter | Type | Description |
|---|---|---|
| scale | float or torch.Tensor | A positive scaling factor for log probabilities |

| Message Effect | Description |
|---|---|
| msg["scale"] | Multiplied by the scale factor (composable with other scale effects) |
Validation: Raises ValueError if scale is not positive. For a tensor-valued scale, the check runs only when Pyro validation is enabled.
Usage Examples
Scaling Log Probabilities
import torch
import pyro
import pyro.distributions as dist

def model(x):
    s = pyro.param("s", torch.tensor(0.5))
    pyro.sample("z", dist.Normal(x, s), obs=torch.tensor(1.0))

# Wrap the model so that every site's log probability is multiplied by 0.5
scaled_model = pyro.poutine.scale(model, scale=0.5)
scaled_tr = pyro.poutine.trace(scaled_model).get_trace(0.0)
unscaled_tr = pyro.poutine.trace(model).get_trace(0.0)
assert (scaled_tr.log_prob_sum() == 0.5 * unscaled_tr.log_prob_sum()).all()
KL Annealing
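# Use scale for KL annealing during training.
# epoch and warmup_epochs come from the surrounding training loop.
# Counting from epoch + 1 keeps beta > 0; a scale of 0 raises ValueError.
beta = min(1.0, (epoch + 1) / warmup_epochs)
with pyro.poutine.scale(scale=beta):
    z = pyro.sample("z", dist.Normal(0.0, 1.0))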
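Since the scale must be strictly positive, an annealing schedule that reaches exactly zero would be rejected by the constructor. One possible way to guard against that is a clamped linear warm-up. `kl_weight` and its `floor` parameter are hypothetical names chosen for this sketch, and the floor value is an arbitrary assumption.

```python
def kl_weight(epoch: int, warmup_epochs: int, floor: float = 1e-3) -> float:
    """Linear warm-up clamped to [floor, 1.0] so the result is always positive."""
    if warmup_epochs <= 0:
        # No warm-up requested: use the full weight from the start.
        return 1.0
    return min(1.0, max(floor, epoch / warmup_epochs))


# Weights for the first six epochs with a four-epoch warm-up.
schedule = [kl_weight(epoch, 4) for epoch in range(6)]
```

Each value in `schedule` can be passed as `scale=` to `pyro.poutine.scale` without triggering the positivity check.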
Tensor-Valued Scale
# Per-element scaling: one weight per entry of the batch.
# data is assumed to be an observed tensor of shape [3].
importance_weights = torch.tensor([0.5, 1.0, 2.0])
with pyro.poutine.scale(scale=importance_weights):
    pyro.sample("obs", dist.Normal(0, 1).expand([3]), obs=data)
Related Pages
- Pyro_ppl_Pyro_Messenger_Base -- Parent class providing the handler protocol
- Pyro_ppl_Pyro_Poutine_Handlers -- The poutine.scale() factory function
- Pyro_ppl_Pyro_MaskMessenger -- Related: binary masking instead of multiplicative scaling
- Pyro_ppl_Pyro_SubsampleMessenger -- Uses scaling for subsampling correction (size / subsample_size)
- Pyro_ppl_Pyro_Trace_Struct -- log_prob_sum uses the scale field for computation