

Implementation:Pyro ppl Pyro Vindex

From Leeroopedia


| Property | Value |
|---|---|
| Module | pyro.ops.indexing |
| Source | pyro/ops/indexing.py |
| Lines | 218 |
| Classes | Index, Vindex |
| Functions | index, vindex |
| Dependencies | torch |

Overview

This module provides vectorized advanced indexing with broadcasting semantics for PyTorch tensors. The central utility is Vindex, which makes indexing code compatible with batching and enumeration. This is especially useful for selecting mixture components indexed by discrete random variables in Pyro models.

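Standard PyTorch advanced indexing follows NumPy's rules, which never align tensor indices with the batch dimensions of the indexed tensor. For example, if x has shape (7, 5, 4, 3) with a leading batch dimension of size 7 and i is a LongTensor of shape (7,), the standard expression x[..., i, :, 0] has shape (7, 7, 4): every batch element is paired with every entry of i instead of being matched elementwise. In addition, where the broadcast index dimensions land in the result depends on whether the tensor indices are adjacent or separated by a slice. Vindex handles this by broadcasting tensor indices against the batch dimensions and reshaping slices, so that batch dimensions always come first and sliced (event) dimensions last; Vindex(x)[..., i, :, 0] has shape (7, 4).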

The module also provides Index for handling nested tuple indexing, which is useful when indexing code must be compatible with multiple interpretations (scalar, vectorized, reshaping).

Code Reference

Function: vindex(tensor, args)

Performs vectorized advanced indexing with broadcasting semantics. Key conventions:

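  • Ellipsis as the first argument stands for the tensor's batch dimensions (it can only appear on the left); the remaining arguments index the rightmost (event) dimensions. Without a leading Ellipsis, every dimension of the tensor is treated as an event dimension.
  • slice(None) (i.e., :) keeps the corresponding dimension as an event dimension of the result.
  • Integer and tensor arguments are broadcast together and, when a leading Ellipsis is present, with the tensor's batch dimensions.
  • Non-leading Ellipsis raises NotImplementedError.
  • Nontrivial slices (other than :) raise NotImplementedError. Boolean mask indices are not supported.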

If no argument is a tensor with dim() > 0, vindex falls back to standard indexing.

Class: Vindex

Convenience wrapper providing __getitem__ syntax:

Vindex(x)[..., i, j, :]
# is equivalent to
vindex(x, (Ellipsis, i, j, slice(None)))

Function: index(tensor, args)

Handles indexing with nested tuples by flattening the tuple and combining consecutive Ellipsis entries. This is useful when index expressions are constructed programmatically and may contain nested tuples.

Class: Index

Convenience wrapper for index:

Index(x)[..., i, j, :]
# is equivalent to
index(x, (Ellipsis, i, j, slice(None)))

I/O Contract

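| Function | Input | Output |
|---|---|---|
| vindex(tensor, args) | Tensor; tuple whose entries are Ellipsis (leading only), slice(None), int, or LongTensor | Tensor whose shape is the broadcast batch shape followed by the event shape (sizes of the sliced dimensions) |
| index(tensor, args) | Tensor; tuple, possibly nested | Tensor |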

Usage Examples

import torch
from pyro.ops.indexing import Vindex

# Basic vectorized indexing
x = torch.randn(5, 4, 3)

# Scalar indexing (same as standard)
assert Vindex(x)[2, :, 1].shape == (4,)

# Vectorized indexing with tensor indices
i = torch.tensor([0, 1, 2])  # shape (3,)
j = torch.tensor([[0], [1]])  # shape (2, 1)

# Result broadcasts i and j, preserving the slice dimension
result = Vindex(x)[i, :, j]
assert result.shape == (2, 3, 4)  # broadcast(i.shape, j.shape) + (4,)

# With batch dimensions (Ellipsis convention)
x_batched = torch.randn(7, 5, 4, 3)  # batch_shape=(7,), event_shape=(5,4,3)
i = torch.randint(5, (7,))  # one index into the size-5 dim per batch element
result = Vindex(x_batched)[..., i, :, 0]
assert result.shape == (7, 4)  # (7,) batch + (4,) event

# Nested tuple indexing
from pyro.ops.indexing import Index
x = torch.randn(3, 4)
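t = (Ellipsis, None)  # reshaping operation, equivalent to x.unsqueeze(-1)
result = Index(x)[..., t]  # nested (Ellipsis, (Ellipsis, None)) is flattened to (Ellipsis, None)
assert result.shape == (3, 4, 1)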
