Implementation:Pyro ppl Pyro SubsampleMessenger
| Attribute | Value |
|---|---|
| File | pyro/poutine/subsample_messenger.py |
| Module | pyro.poutine.subsample_messenger |
| Lines | 218 |
| Parent Class | IndepMessenger |
| Purpose | Extend IndepMessenger with data subsampling and automatic scaling |
| Architecture Role | Intermediate class between IndepMessenger and PlateMessenger; handles subsampling logic |
| License | Apache-2.0 (Uber Technologies, Inc.) |
Overview
SubsampleMessenger extends IndepMessenger with data subsampling capabilities. It randomly selects a subsample of indices from a range and automatically scales log probabilities by size / subsample_size to provide unbiased gradient estimates.
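To see why multiplying by size / subsample_size keeps the estimate unbiased, the following sketch (plain Python, not Pyro code) enumerates every equally likely subsample of a small list of per-datum log-probability terms and averages the scaled subsample sums. The values in `terms` and the helper name `scaled_subsample_sum` are made up for illustration.

```python
from itertools import combinations


def scaled_subsample_sum(terms, indices):
    """Sum the selected terms and rescale by size / subsample_size."""
    size, subsample_size = len(terms), len(indices)
    return sum(terms[i] for i in indices) * size / subsample_size


# Hypothetical per-datum log-probability terms (illustrative values only).
terms = [-1.5, -0.25, -3.0, -0.75, -2.0, -0.5]
subsample_size = 2

# Every index subset of a given size is equally likely under a truncated
# random permutation, so the expectation is a plain average over subsets.
subsets = list(combinations(range(len(terms)), subsample_size))
expected = sum(scaled_subsample_sum(terms, s) for s in subsets) / len(subsets)
full_sum = sum(terms)
```

Averaged over all subsets, the scaled subsample sum matches the full sum, which is the property the scale factor is meant to preserve.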
The module defines two classes:
- _Subsample -- An internal Distribution that generates random subsample indices. It returns a random permutation of indices (truncated to subsample_size) and has zero log probability so that the plate can provide the unbiased estimate.
- SubsampleMessenger -- Extends IndepMessenger with:
  - Subsample generation via the _Subsample distribution (processed through the stack via apply_stack).
  - Scale correction: msg["scale"] *= size / subsample_size.
  - Post-processing of param and subsample messages to index-select parameters along the plate dimension.
The post-processing logic handles parameter subsampling by calling value.index_select(dim, indices) and tracking the subsample indices on the unconstrained parameter via param._pyro_subsample.
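A minimal sketch of the index-selection step, using only torch. The plate dimension, event_dim, and tensor shapes are illustrative assumptions; the dim = plate_dim - event_dim arithmetic follows the post-processing logic described above.

```python
import torch

size, subsample_size = 6, 3
plate_dim = -1  # plate occupies the rightmost batch dimension (assumed)
event_dim = 1   # the parameter has one event dimension on the right (assumed)

# A full-size parameter of shape (size, 2): one row per plate element.
value = torch.arange(size * 2, dtype=torch.float32).reshape(size, 2)

# Subsample indices, drawn as a truncated random permutation.
indices = torch.randperm(size)[:subsample_size]

# Event dimensions sit to the right of batch dimensions, so the plate
# dimension of the full tensor is shifted left by event_dim.
dim = plate_dim - event_dim  # -2
subsampled = value.index_select(dim, indices)
```

Only the plate dimension shrinks from size to subsample_size; the event dimension is left untouched.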
Code Reference
_Subsample Distribution
import torch
from pyro.distributions.distribution import Distribution
class _Subsample(Distribution):
    """Internal distribution for random subsample generation."""
    def __init__(self, size, subsample_size, use_cuda=None, device=None):
        self.size = size
        self.subsample_size = subsample_size
        self.device = device or torch.tensor(tuple()).device
    def sample(self, sample_shape=torch.Size()):
        subsample_size = self.subsample_size
        if subsample_size is None or subsample_size >= self.size:
            # No subsampling: return every index in order.
            return torch.arange(self.size, device=self.device)
        # Random permutation truncated to subsample_size.
        return torch.randperm(self.size, device=self.device)[:subsample_size].clone()
    def log_prob(self, x):
        # Zero, so that the plate's scale factor gives an unbiased estimate.
        return torch.tensor(0.0, device=self.device)
SubsampleMessenger
class SubsampleMessenger(IndepMessenger):
    def __init__(self, name, size=None, subsample_size=None,
                 subsample=None, dim=None, use_cuda=None, device=None):
        full_size, self.subsample_size, subsample = self._subsample(
            name, size, subsample_size, subsample, use_cuda, device)
        super().__init__(name, full_size, dim, device)
        self._indices = subsample
    def _process_message(self, msg):
        frame = CondIndepStackFrame(
            name=self.name, dim=self.dim,
            size=self.subsample_size, counter=self.counter,
            full_size=self.size)
        msg["cond_indep_stack"] = (frame,) + msg["cond_indep_stack"]
        msg["scale"] = msg["scale"] * self.size / self.subsample_size
    def _postprocess_message(self, msg):
        if msg["type"] in ("param", "subsample") and self.dim is not None:
            event_dim = msg["kwargs"].get("event_dim")
            if event_dim is not None:
                dim = self.dim - event_dim
                shape = msg["value"].shape
                if len(shape) >= -dim and shape[dim] != 1:
                    if self.subsample_size < self.size:
                        # Subsample parameters with known batch semantics
                        msg["value"] = msg["value"].index_select(dim, self._indices)
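The effect of _process_message can be mimicked on a plain dictionary. This is a sketch, not Pyro's code: `Frame` and `process_message` are hypothetical stand-ins, and the sizes are arbitrary. It shows that each subsampling plate prepends one frame and multiplies the running scale, so nested plates compose multiplicatively.

```python
from collections import namedtuple

# Hypothetical stand-in for Pyro's CondIndepStackFrame (illustration only).
Frame = namedtuple("Frame", ["name", "dim", "size", "full_size"])


def process_message(msg, name, dim, size, subsample_size):
    """Mimic the two effects of _process_message on a message dict."""
    frame = Frame(name=name, dim=dim, size=subsample_size, full_size=size)
    msg["cond_indep_stack"] = (frame,) + msg["cond_indep_stack"]
    msg["scale"] = msg["scale"] * size / subsample_size
    return msg


# A fresh message starts with unit scale and an empty stack.
msg = {"scale": 1.0, "cond_indep_stack": ()}

# Apply two plates in turn: 10 of 1000 items, then 5 of 50 items.
process_message(msg, "inner", dim=-1, size=1000, subsample_size=10)
process_message(msg, "outer", dim=-2, size=50, subsample_size=5)
```

After both calls the scale is the product of the two ratios (100 x 10), and the most recently applied frame sits at the front of the stack.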
I/O Contract
| Parameter | Type | Description |
|---|---|---|
| name | str | Name of the plate context |
| size | Optional[int] | Total size of the data |
| subsample_size | Optional[int] | Subsample size (if None or >= size, no subsampling) |
| subsample | Optional[torch.Tensor] | Explicit subsample indices |
| dim | Optional[int] | Batch dimension (negative integer) |
| device | Optional[str] | Device for index tensors |

| Message Effect | Description |
|---|---|
| msg["cond_indep_stack"] | Prepends CondIndepStackFrame with subsample_size and full_size |
| msg["scale"] | Multiplied by size / subsample_size for unbiased estimation |
| msg["value"] (param/subsample) | Index-selected along plate dimension for subsampled parameters |
Usage Examples
Plate with Subsampling
import torch
import pyro
import pyro.distributions as dist
data = torch.randn(10000)
def model(data):
    loc = pyro.sample("loc", dist.Normal(0, 1))
    # Subsample 100 of 10000 items; log prob scaled by 10000 / 100 = 100x
    with pyro.plate("data", len(data), subsample_size=100) as idx:
        pyro.sample("obs", dist.Normal(loc, 1), obs=data[idx])
Explicit Subsample Indices
data = torch.randn(100)  # illustrative data matching the plate size
indices = torch.tensor([0, 5, 10, 15, 20])
with pyro.plate("data", 100, subsample=indices) as idx:
    pyro.sample("obs", dist.Normal(0, 1), obs=data[idx])
Related Pages
- Pyro_ppl_Pyro_IndepMessenger -- Parent class providing independence annotation and dimension allocation
- Pyro_ppl_Pyro_PlateMessenger -- Child class that adds broadcasting on top of subsampling
- Pyro_ppl_Pyro_ScaleMessenger -- Related: multiplicative scaling of log probabilities
- Pyro_ppl_Pyro_Poutine_Runtime -- apply_stack is called to process subsample messages