Implementation:Pyro ppl Pyro Vindex
| Property | Value |
|---|---|
| Module | pyro.ops.indexing |
| Source | pyro/ops/indexing.py |
| Lines | 218 |
| Classes | Index, Vindex |
| Functions | index, vindex |
| Dependencies | torch |
Overview
This module provides vectorized advanced indexing with broadcasting semantics for PyTorch tensors. The central utility is Vindex, which enables indexing that is compatible with batching and enumeration -- a critical capability for selecting mixture components with discrete random variables in Pyro models.
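Standard PyTorch advanced indexing does not broadcast tensor indices the way probabilistic programming needs. When the tensor indices are adjacent, as in x[:, i, j], the broadcast index dimensions replace the indexed dimensions in place. When they are separated by a slice, as in x[i, :, j], the broadcast dimensions are moved to the front. The position of the batch dimensions in the result therefore depends on the form of the index expression, and any batch dimensions of x itself are not paired with those of i and j. Vindex handles this by broadcasting tensor indices and reshaping slices so that the broadcast batch dimensions always come first and the sliced event dimensions come last.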
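PyTorch largely follows NumPy's advanced-indexing rules, so the dimension-placement behavior described above can be reproduced with NumPy alone. The sketch below uses NumPy as a stand-in for PyTorch (an assumption made only to keep the example dependency-free) and compares adjacent and separated tensor indices on an array of shape (5, 4, 3).

```python
import numpy as np

x = np.zeros((5, 4, 3))
i = np.array([0, 1, 2])   # shape (3,)
j = np.array([[0], [1]])  # shape (2, 1); broadcasts with i to (2, 3)

# Index arrays separated by a slice: the broadcast dims (2, 3) move to the front.
print(x[i, :, j].shape)   # (2, 3, 4)

# Adjacent index arrays: the broadcast dims replace the indexed dims in place,
# so the sliced dimension of size 5 stays on the left.
print(x[:, i, j].shape)   # (5, 2, 3)
```

Under the Vindex convention both expressions place the broadcast shape (2, 3) first, giving (2, 3, 4) and (2, 3, 5) respectively.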
The module also provides Index for handling nested tuple indexing, which is useful when indexing code must be compatible with multiple interpretations (scalar, vectorized, reshaping).
Code Reference
Function: vindex(tensor, args)
Performs vectorized advanced indexing with broadcasting semantics. Key conventions:
- Ellipsis as the first argument denotes batch dimensions; it can only appear on the left.
- slice(None) (i.e. a bare :) keeps the corresponding event dimension in the output.
- Integer and tensor arguments are broadcast together and with the batch dimensions of the indexed tensor.
- A non-leading Ellipsis raises NotImplementedError.
- Nontrivial slices (anything other than :) raise NotImplementedError.

If no argument is a tensor with dim() > 0, vindex falls back to standard indexing.
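One way to realize these conventions is to rewrite every index position as an integer array and let ordinary advanced indexing broadcast them all at once: each : becomes an arange placed to the left of the slices to its right, and each index array gets trailing singleton dimensions so its shape lines up with the batch dimensions. The sketch below does this in NumPy rather than PyTorch; the name vindex_sketch is hypothetical, and the code is an illustration of the technique, not Pyro's source.

```python
import numpy as np


def vindex_sketch(array, args):
    """Hypothetical NumPy sketch of the vindex conventions (not Pyro's code).

    Supports a leading Ellipsis (batch dims), slice(None) (kept event dims),
    Python ints, and integer index arrays whose shapes act as batch shapes.
    """
    if not isinstance(args, tuple):
        args = (args,)
    if args and args[0] is Ellipsis:
        args = args[1:]
        event_dim = len(args)
        # Every dim left of the event dims is a batch dim, indexed with ':'.
        args = (slice(None),) * (array.ndim - event_dim) + args
    else:
        event_dim = array.ndim
        # Missing trailing indices default to ':'.
        args = args + (slice(None),) * (array.ndim - len(args))
    if any(a is Ellipsis for a in args):
        raise NotImplementedError("Non-leading Ellipsis is not supported")
    assert len(args) == array.ndim, "too many indices"

    # Number of ':' among the event dims; each becomes a trailing output dim.
    new_event_dim = sum(isinstance(a, slice) for a in args[len(args) - event_dim:])

    out, n_slices_right = [], 0
    for size, arg in reversed(list(zip(array.shape, args))):
        if isinstance(arg, slice):
            if arg != slice(None):
                raise NotImplementedError("Nontrivial slices are not supported")
            # Replace ':' with an arange sitting left of all slices to its right.
            arg = np.arange(size).reshape((-1,) + (1,) * n_slices_right)
            n_slices_right += 1
        elif isinstance(arg, np.ndarray) and arg.ndim:
            # Shift index-array dims to the left of the event (slice) dims.
            arg = arg.reshape(arg.shape + (1,) * new_event_dim)
        out.append(arg)
    # All positions are now advanced indices, so NumPy simply broadcasts them.
    return array[tuple(reversed(out))]


x = np.arange(60).reshape(5, 4, 3)
i = np.array([0, 1, 2])
j = np.array([[0], [1]])
print(vindex_sketch(x, (i, slice(None), j)).shape)  # (2, 3, 4)
print(vindex_sketch(x, (slice(None), i, j)).shape)  # (2, 3, 5)
```

Because every position is converted to an advanced index, NumPy never has to choose between in-place and moved-to-front placement, which is what makes the output layout predictable.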
Class: Vindex
Convenience wrapper providing __getitem__ syntax:
```python
Vindex(x)[..., i, j, :]
# is equivalent to
vindex(x, (Ellipsis, i, j, slice(None)))
```
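The wrapper works through Python's subscription protocol: writing obj[..., i, j, :] calls obj.__getitem__ with the single tuple (Ellipsis, i, j, slice(None)). The hypothetical ShowKey class below simply returns that tuple to make the translation visible; a wrapper in the style of Vindex only needs to forward it to vindex.

```python
class ShowKey:
    """Hypothetical helper that returns whatever __getitem__ receives."""

    def __getitem__(self, key):
        return key


key = ShowKey()[..., 1, 2, :]
print(key)  # (Ellipsis, 1, 2, slice(None, None, None))
```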
Function: index(tensor, args)
Handles indexing with nested tuples by flattening the tuple and combining consecutive Ellipsis entries. This is useful when index expressions are constructed programmatically and may contain nested tuples.
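The flattening step can be sketched in a few lines of plain Python. The helper name flatten_index is hypothetical, and NumPy stands in for PyTorch when the flattened index is applied. The loop generalizes x[..., t] over three interpretations of t: a scalar, a slice, and the reshaping tuple (Ellipsis, None), which without flattening would yield the nested index (Ellipsis, (Ellipsis, None)).

```python
import numpy as np


def flatten_index(args, out=None):
    """Hypothetical sketch: flatten nested index tuples, merging repeated Ellipsis."""
    if out is None:
        out = []
    if isinstance(args, tuple):
        for arg in args:
            flatten_index(arg, out)
    elif not (args is Ellipsis and out and out[-1] is Ellipsis):
        # Skip an Ellipsis that directly follows another Ellipsis.
        out.append(args)
    return out


x = np.zeros((3, 4))
for t in (1, slice(None), (Ellipsis, None)):
    key = tuple(flatten_index((Ellipsis, t)))
    print(key, x[key].shape)  # shapes: (3,), (3, 4), (3, 4, 1)
```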
Class: Index
Convenience wrapper for index:
```python
Index(x)[..., i, j, :]
# is equivalent to
index(x, (Ellipsis, i, j, slice(None)))
```
I/O Contract
| Function | Input | Output |
|---|---|---|
| vindex(tensor, args) | Tensor, tuple of (Ellipsis, slice(None), int, LongTensor) | Tensor with broadcast batch shape + one event dimension per slice(None) |
| index(tensor, args) | Tensor, tuple (possibly nested) | Tensor |
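The vindex output-shape contract can be computed without touching any data: the result's batch shape is the broadcast of the indexed tensor's batch shape with the shapes of all integer and tensor indices, and the event shape keeps one dimension per :. The hypothetical helper below predicts that shape using NumPy's np.broadcast_shapes (available in NumPy 1.20 and later, an assumption of this sketch).

```python
import numpy as np


def vindex_output_shape(tensor_shape, args):
    """Hypothetical helper (not part of Pyro): predicted shape of vindex(tensor, args)."""
    tensor_shape = tuple(tensor_shape)
    if args and args[0] is Ellipsis:
        event_args = args[1:]
        batch_shape = tensor_shape[: len(tensor_shape) - len(event_args)]
    else:
        # Without Ellipsis the tensor has no batch dims; pad missing indices with ':'.
        event_args = args + (slice(None),) * (len(tensor_shape) - len(args))
        batch_shape = ()
    event_shape = tensor_shape[len(tensor_shape) - len(event_args):]

    # Ints have shape (); index arrays contribute their full shape.
    index_shapes = [np.shape(a) for a in event_args if not isinstance(a, slice)]
    new_batch = np.broadcast_shapes(batch_shape, *index_shapes)
    # Each ':' keeps the size of the dimension it indexes.
    new_event = tuple(s for s, a in zip(event_shape, event_args) if isinstance(a, slice))
    return tuple(new_batch) + new_event


i = np.zeros(3, dtype=int)       # shape (3,)
j = np.zeros((2, 1), dtype=int)  # shape (2, 1)
print(vindex_output_shape((5, 4, 3), (i, slice(None), j)))  # (2, 3, 4)
```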
Usage Examples
```python
import torch
from pyro.ops.indexing import Index, Vindex

# Basic vectorized indexing
x = torch.randn(5, 4, 3)

# Scalar indexing (same as standard indexing)
assert Vindex(x)[2, :, 1].shape == (4,)

# Vectorized indexing with tensor indices
i = torch.tensor([0, 1, 2])   # shape (3,)
j = torch.tensor([[0], [1]])  # shape (2, 1)
# i and j broadcast to (2, 3); the sliced dimension is appended on the right
result = Vindex(x)[i, :, j]
assert result.shape == (2, 3, 4)  # broadcast(i.shape, j.shape) + (4,)

# With batch dimensions (Ellipsis convention)
x_batched = torch.randn(7, 5, 4, 3)  # batch_shape=(7,), event_shape=(5, 4, 3)
i = torch.randint(5, (7,))           # one index per batch element
result = Vindex(x_batched)[..., i, :, 0]
assert result.shape == (7, 4)        # (7,) batch + (4,) event

# Nested tuple indexing
x = torch.randn(3, 4)
t = (Ellipsis, None)  # reshaping operation, equivalent to x.unsqueeze(-1)
result = Index(x)[..., t]
assert result.shape == (3, 4, 1)
```
Related Pages
- Pyro_ppl_Pyro_PackedTensor -- Related packed tensor indexing utilities