Implementation:Pyro ppl Pyro PackedTensor
| Property | Value |
|---|---|
| Module | pyro.ops.packed |
| Source | pyro/ops/packed.py |
| Lines | 201 |
| Functions | pack, unpack, broadcast_all, gather, mul, scale_and_mask, neg, exp, rename_equation |
| Dependencies | torch, pyro.distributions.util, pyro.util |
Overview
This module provides utilities for working with "packed" tensors -- tensors annotated with a ._pyro_dims string attribute that names each dimension. This dimension-naming convention enables Pyro's variable elimination and einsum-based inference algorithms to manipulate tensors without tracking explicit dimension positions.
In the packed representation, tensors are squeezed (all size-1 dimensions removed) and each remaining dimension is given a character name. Operations like broadcasting, multiplication, and gathering are performed by matching dimension names rather than positions.
Code Reference
Function: pack(value, dim_to_symbol)
Converts a standard tensor to a packed tensor by squeezing out size-1 dimensions and naming each remaining dimension with the character that dim_to_symbol assigns to its negative index. Raises ValueError if a dimension of size greater than 1 has no entry in dim_to_symbol.
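The naming step can be sketched in plain Python without touching tensors. The helper name packed_dims is hypothetical and is not part of pyro.ops.packed; it covers only how the dims string follows from a shape and a dim_to_symbol mapping, assuming dimensions are indexed from the right with negative integers.

```python
def packed_dims(shape, dim_to_symbol):
    """Return the dims string that packing a tensor of `shape` would carry.

    Hypothetical helper: mirrors only the naming step, not the squeeze
    or the attribute assignment on the tensor.
    """
    ndim = len(shape)
    symbols = []
    for pos, size in enumerate(shape):
        if size == 1:
            continue  # size-1 dims are squeezed away and need no name
        negative_dim = pos - ndim  # e.g. position 1 of 3 dims -> -2
        if negative_dim not in dim_to_symbol:
            raise ValueError(
                "shape {} has a dim {} of size {} with no symbol".format(
                    tuple(shape), negative_dim, size))
        symbols.append(dim_to_symbol[negative_dim])
    return "".join(symbols)


print(packed_dims((1, 3, 4), {-2: "a", -1: "b"}))  # ab
```

A shape such as (2, 3, 4) with the same mapping would raise ValueError in this sketch, because the leftmost dimension has size 2 but no symbol for index -3.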
Function: unpack(value, symbol_to_dim)
Converts a packed tensor back to an unpacked tensor by permuting its dimensions into the order given by symbol_to_dim and reshaping so that each named dimension sits at its mapped negative index, with size-1 dimensions filling any gaps.
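As a rough illustration of the permute-then-reshape step, the following sketch computes the permutation and the target shape from a dims string and a packed shape. The helper unpacked_layout is hypothetical; it operates on plain tuples rather than tensors and assumes every symbol in the dims string has an entry in symbol_to_dim.

```python
def unpacked_layout(packed_dims, packed_shape, symbol_to_dim):
    """Return (permutation, shape) describing how to restore positional dims.

    Hypothetical helper working on plain tuples, not on tensors.
    """
    if not packed_dims:
        return [], ()  # a scalar has nothing to permute or reshape
    targets = [symbol_to_dim[s] for s in packed_dims]  # e.g. [-1, -3]
    ordered = sorted(targets)                          # leftmost target first
    permutation = [targets.index(d) for d in ordered]
    shape = [1] * -min(ordered)  # rank reaches as far left as the leftmost dim
    for d in ordered:
        shape[d] = packed_shape[targets.index(d)]
    return permutation, tuple(shape)


perm, shape = unpacked_layout("ba", (4, 3), {"a": -3, "b": -1})
print(perm, shape)  # [1, 0] (3, 1, 4)
```

In this sketch the output rank is set by the leftmost mapped dimension, so size-1 dimensions further to the left are not reintroduced.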
Function: broadcast_all(*values, dims)
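Broadcasts multiple packed tensors to a common set of dimensions. Each tensor is permuted to match the target ordering, reshaped with size-1 placeholders for the dimensions it lacks, and expanded to the broadcast shape. The dims keyword argument, if given, names the target ordering.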
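The bookkeeping behind this can be sketched on (dims, shape) pairs instead of tensors. The helper broadcast_plan is hypothetical. It assumes that operands sharing a name agree on its size, and that the target ordering defaults to the sorted union of names when dims is omitted (an assumption of this sketch, not a documented guarantee).

```python
def broadcast_plan(operands, dims=None):
    """Plan a named broadcast for a list of (dims_string, shape) pairs.

    Hypothetical helper: returns the target dims, the target shape, and
    one (permutation, reshape_shape) pair per operand.
    """
    sizes = {}
    for op_dims, shape in operands:
        for name, size in zip(op_dims, shape):
            sizes[name] = size  # assumes shared names have equal sizes
    if dims is None:
        dims = "".join(sorted(sizes))  # assumed default ordering
    target_shape = tuple(sizes[d] for d in dims)
    steps = []
    for op_dims, _ in operands:
        # Reorder the operand's own dims to follow the target ordering.
        permutation = tuple(op_dims.index(d) for d in dims if d in op_dims)
        # Insert size-1 placeholders for dims the operand lacks.
        reshape_to = tuple(sizes[d] if d in op_dims else 1 for d in dims)
        steps.append((permutation, reshape_to))
    return dims, target_shape, steps


dims, shape, steps = broadcast_plan([("ba", (4, 3)), ("c", (5,))])
print(dims, shape)  # abc (3, 4, 5)
print(steps)        # [((1, 0), (3, 4, 1)), ((0,), (1, 1, 5))]
```

After the reshape, expanding each operand to the target shape gives every tensor the same dims string and shape.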
Function: gather(value, index, dim)
Performs a broadcasted gather along a named dimension dim. The named dimension must be present in value but absent in index.
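The precondition and the resulting dimension names can be sketched as follows. The helper gather_result_dims is hypothetical and works on dims strings only. The sorted ordering of the result is an assumption of this sketch.

```python
def gather_result_dims(value_dims, index_dims, dim):
    """Return the dims left after gathering `value` along `dim` with `index`.

    Hypothetical helper: checks the stated preconditions and computes
    only the names of the result's dimensions.
    """
    if dim not in value_dims:
        raise ValueError("dim {!r} must be present in value".format(dim))
    if dim in index_dims:
        raise ValueError("dim {!r} must be absent from index".format(dim))
    names = (set(value_dims) | set(index_dims)) - {dim}
    return "".join(sorted(names))  # sorted order is an assumption


print(gather_result_dims("ab", "bc", "a"))  # bc
```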
Function: mul(lhs, rhs)
Packed broadcasted multiplication using torch.einsum. Computes the union of dimension names and constructs an einsum equation.
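A minimal sketch of how such an equation could be assembled from two dims strings is shown below. The helper mul_equation is hypothetical, and placing the output names in sorted order is an assumption consistent with the "abc" result in the usage example.

```python
def mul_equation(lhs_dims, rhs_dims):
    """Build an einsum equation for a named, broadcasted product.

    Hypothetical helper: the output is the union of both operands'
    names, so no dimension is summed out.
    """
    out_dims = "".join(sorted(set(lhs_dims + rhs_dims)))
    return "{},{}->{}".format(lhs_dims, rhs_dims, out_dims)


print(mul_equation("ab", "bc"))  # ab,bc->abc
```

Because every input name appears in the output, the einsum performs an elementwise product with broadcasting rather than a contraction.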
Function: scale_and_mask(tensor, scale, mask)
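Scales and masks a packed tensor, broadcasting the mask where needed and skipping operations that would have no effect. Scalar masks need no tensor operation on the mask: None or True means no masking, and False yields an all-zero result.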
Function: neg(value) / exp(value)
Pointwise negation and exponentiation that preserve _pyro_dims.
Function: rename_equation(equation, *operands)
Renames symbols in an einsum/ubersum equation to match the ._pyro_dims attributes of the provided operands.
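The renaming can be illustrated with a small sketch that takes dims strings directly instead of packed tensors. The helper rename_symbols is hypothetical. It assumes each input term in the equation has the same length as the corresponding dims string.

```python
def rename_symbols(equation, *operand_dims):
    """Rewrite an einsum equation using the operands' dimension names.

    Hypothetical helper: takes dims strings, not packed tensors.
    """
    inputs, _ = equation.split("->")
    terms = inputs.split(",")
    if len(terms) != len(operand_dims):
        raise ValueError("expected one dims string per input term")
    rename = {}
    for term, dims in zip(terms, operand_dims):
        for old, new in zip(term, dims):
            rename[old] = new  # map each equation symbol to a dim name
    # Symbols without a mapping, and the ',' and '->' separators, pass through.
    return "".join(rename.get(ch, ch) for ch in equation)


print(rename_symbols("ij,jk->ik", "ab", "bc"))  # ab,bc->ac
```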
I/O Contract
| Function | Input | Output |
|---|---|---|
| pack(value, dim_to_symbol) | Tensor, dict mapping negative int to char | Packed tensor with ._pyro_dims |
| unpack(value, symbol_to_dim) | Packed tensor, dict mapping char to negative int | Standard tensor |
| broadcast_all(*values) | Packed tensors | Tuple of packed tensors with common dims |
| mul(lhs, rhs) | Two packed tensors | Packed tensor with union of dims |
Usage Examples
import torch
from pyro.ops.packed import pack, unpack, mul, broadcast_all
# Define dimension mapping
dim_to_symbol = {-2: "a", -1: "b"}
symbol_to_dim = {"a": -2, "b": -1}
# Pack a tensor
x = torch.randn(1, 3, 4)
x_packed = pack(x, dim_to_symbol)
print(x_packed._pyro_dims) # "ab"
print(x_packed.shape) # torch.Size([3, 4])
# Unpack back
x_unpacked = unpack(x_packed, symbol_to_dim)
print(x_unpacked.shape) # torch.Size([3, 4]); the leading size-1 dim of x is not restored because no symbol maps to -3
# Packed multiplication
y = torch.randn(4, 5)
y._pyro_dims = "bc"
result = mul(x_packed, y)
print(result._pyro_dims) # "abc"
Related Pages
- Pyro_ppl_Pyro_Rings -- Ring classes that operate on packed tensors