

Implementation:Pyro ppl Pyro PackedTensor

From Leeroopedia


| Property | Value |
| --- | --- |
| Module | pyro.ops.packed |
| Source | pyro/ops/packed.py |
| Lines | 201 |
| Functions | pack, unpack, broadcast_all, gather, mul, scale_and_mask, neg, exp, rename_equation |
| Dependencies | torch, pyro.distributions.util, pyro.util |

Overview

This module provides utilities for working with "packed" tensors. A packed tensor carries a ._pyro_dims string attribute that gives each dimension a one-character name. This naming convention lets Pyro's variable elimination and einsum-based inference algorithms manipulate tensors by dimension name rather than by explicit position.

In the packed representation, tensors are squeezed (all size-1 dimensions removed) and each remaining dimension is given a character name. Operations like broadcasting, multiplication, and gathering are performed by matching dimension names rather than positions.

Code Reference

Function: pack(value, dim_to_symbol)

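Converts a standard (unpacked) tensor to a packed tensor. Each non-singleton dimension is named by looking up its negative, right-aligned position in dim_to_symbol, and then all size-1 dimensions are squeezed out. Raises ValueError if a non-singleton dimension has no entry in dim_to_symbol, that is, if the tensor's shape does not fit the expected dimension mapping.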
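The packing step can be sketched in plain PyTorch. This is a minimal illustration, not Pyro's source. The name `pack_sketch` is hypothetical and the error message is simplified. The sketch names each non-singleton dimension by its negative position and then squeezes out the size-1 dimensions.

```python
import torch


def pack_sketch(value, dim_to_symbol):
    """Hypothetical re-implementation of pack(), for illustration only."""
    if not isinstance(value, torch.Tensor):
        return value  # plain numbers pass through unchanged
    shift = value.dim()
    try:
        # Name every non-singleton dim by its negative (right-aligned) position.
        dims = "".join(
            dim_to_symbol[i - shift]
            for i, size in enumerate(value.shape)
            if size > 1
        )
    except KeyError as e:
        raise ValueError(f"Invalid tensor shape: {tuple(value.shape)}") from e
    packed = value.squeeze()  # drop every size-1 dim
    packed._pyro_dims = dims
    assert packed.dim() == len(dims)
    return packed


x = pack_sketch(torch.randn(1, 3, 4), {-2: "a", -1: "b"})
print(x._pyro_dims, tuple(x.shape))  # ab (3, 4)
```

Singleton dimensions never need a name, so a mapping only has to cover the dimensions that can actually be larger than 1.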

Function: unpack(value, symbol_to_dim)

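Converts a packed tensor back to an unpacked tensor. The named dimensions are permuted into the positional order given by symbol_to_dim. The result is then reshaped so that every unnamed position between the leftmost named dimension and the last dimension becomes a size-1 dimension.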
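Here is a matching sketch of unpacking. It uses a hypothetical name (`unpack_sketch`) and assumes every name in `_pyro_dims` appears in `symbol_to_dim`. The example stores the dims out of order ("ca") to show the permutation. The unnamed position -2 comes back as size 1.

```python
import torch


def unpack_sketch(value, symbol_to_dim):
    """Hypothetical re-implementation of unpack(), for illustration only."""
    if not isinstance(value, torch.Tensor) or value.dim() == 0:
        return value
    unsorted_dims = [symbol_to_dim[s] for s in value._pyro_dims]
    dims = sorted(unsorted_dims)  # most negative (leftmost) position first
    # Reorder the named dims into positional order.
    value = value.permute(*(unsorted_dims.index(d) for d in dims))
    # Unnamed positions between the leftmost named dim and -1 become size 1.
    shape = [1] * -min(dims)
    for d, size in zip(dims, value.shape):
        shape[d] = size
    return value.reshape(shape)


packed = torch.randn(4, 3)
packed._pyro_dims = "ca"  # "c" names the size-4 dim, "a" the size-3 dim
out = unpack_sketch(packed, {"a": -3, "b": -2, "c": -1})
print(tuple(out.shape))  # (3, 1, 4)
```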

Function: broadcast_all(*values, dims)

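Broadcasts multiple packed tensors to a common set of dimensions. Each tensor's dimensions are permuted into the target ordering, size-1 placeholders are inserted for names it lacks, and the result is expanded to the broadcast shape. The target ordering is given by the optional dims keyword. If dims is omitted, the sorted union of all operands' dimension names is used.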
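The broadcasting logic can be sketched as follows. The function `broadcast_all_sketch` is hypothetical. It assumes the operands agree on the size of every shared name, and that `dims`, when given, contains every name.

```python
import torch


def broadcast_all_sketch(*values, dims=None):
    """Hypothetical re-implementation of broadcast_all(), for illustration only."""
    sizes = {d: s for v in values for d, s in zip(v._pyro_dims, v.shape)}
    if dims is None:
        dims = "".join(sorted(sizes))  # default: sorted union of all names
    shape = tuple(sizes[d] for d in dims)
    results = []
    for v in values:
        old = v._pyro_dims
        # Move the dims this tensor has into target order...
        perm = [old.index(d) for d in dims if d in old]
        x = v.permute(*perm) if perm else v
        # ...add size-1 placeholders for the names it lacks, then expand.
        x = x.reshape(tuple(sizes[d] if d in old else 1 for d in dims))
        x = x.expand(shape)
        x._pyro_dims = dims
        results.append(x)
    return tuple(results)


x = torch.randn(3, 4)
x._pyro_dims = "ab"
y = torch.randn(4, 5)
y._pyro_dims = "bc"
xb, yb = broadcast_all_sketch(x, y)
print(xb._pyro_dims, tuple(xb.shape), tuple(yb.shape))  # abc (3, 4, 5) (3, 4, 5)
```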

Function: gather(value, index, dim)

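Performs a broadcasted gather along the named dimension dim. For each position in the broadcast of value and index, it selects the entry of value whose dim coordinate is given by index. The dimension dim must be present in value but absent from index, and it is dropped from the result's dimension names.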
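Here is a sketch of a named-dimension gather built on `torch.gather`. The functions `gather_sketch` and `_align` are hypothetical. The sketch assumes `index` holds valid integer (long) positions along `dim`. Because the index lacks `dim`, it is kept at size 1 there. `torch.gather` therefore picks one entry per remaining coordinate, and that axis is then squeezed away.

```python
import torch


def _align(t, dims):
    """Hypothetical helper: permute packed t into the order of `dims`,
    inserting size-1 dims for names that t lacks."""
    sizes = dict(zip(t._pyro_dims, t.shape))
    perm = [t._pyro_dims.index(d) for d in dims if d in sizes]
    aligned = t.permute(*perm) if perm else t
    return aligned.reshape(tuple(sizes.get(d, 1) for d in dims))


def gather_sketch(value, index, dim):
    """Hypothetical re-implementation of packed gather(), for illustration only."""
    assert dim in value._pyro_dims and dim not in index._pyro_dims
    dims = "".join(sorted(set(value._pyro_dims + index._pyro_dims)))
    pos = dims.index(dim)
    v = _align(value, dims)
    i = _align(index, dims)  # size 1 at `pos`, because index lacks `dim`
    shape = list(torch.broadcast_shapes(v.shape, i.shape))
    v = v.expand(tuple(shape))
    shape[pos] = 1  # keep a single index slot along `dim`
    i = i.expand(tuple(shape))
    result = v.gather(pos, i).squeeze(pos)  # one value entry per index entry
    result._pyro_dims = dims.replace(dim, "")
    return result


value = torch.randn(3, 4)
value._pyro_dims = "ab"
index = torch.tensor([0, 3, 1])
index._pyro_dims = "a"
r = gather_sketch(value, index, "b")
print(r._pyro_dims, tuple(r.shape))  # a (3,)
```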

Function: mul(lhs, rhs)

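Packed broadcasted multiplication via torch.einsum. The output dimensions are the sorted union of both operands' names. The einsum equation is built directly from them (for example "ab,bc->abc"), so shared names are multiplied elementwise and no dimension is summed out.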
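The equation construction can be sketched with `torch.einsum` directly. The function `mul_sketch` is hypothetical and handles only the case where both arguments are packed tensors.

```python
import torch


def mul_sketch(lhs, rhs):
    """Hypothetical packed multiply (tensor-tensor case only), for illustration."""
    dims = "".join(sorted(set(lhs._pyro_dims + rhs._pyro_dims)))  # union of names
    equation = lhs._pyro_dims + "," + rhs._pyro_dims + "->" + dims
    result = torch.einsum(equation, lhs, rhs)  # every name is kept: no summation
    result._pyro_dims = dims
    return result


x = torch.randn(3, 4)
x._pyro_dims = "ab"
y = torch.randn(4, 5)
y._pyro_dims = "bc"
result = mul_sketch(x, y)  # equation "ab,bc->abc"
print(result._pyro_dims, tuple(result.shape))  # abc (3, 4, 5)
```

Because every input name appears in the output, einsum performs a pure broadcasted product. Contraction is left to the inference code that calls these helpers.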

Function: scale_and_mask(tensor, scale, mask)

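Scales and masks a packed tensor, broadcasting where needed and skipping unnecessary operations. With a scale of 1 and no mask, the input is returned unchanged. A mask of True is treated like None (no masking), and a mask of False yields a packed scalar zero.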
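The short-circuit logic can be sketched in simplified form. The function `scale_and_mask_sketch` is hypothetical and assumes `scale` is a Python number. It also assumes a tensor `mask` is boolean with the same shape and `_pyro_dims` as `tensor`, so it omits the broadcasting step described above.

```python
import torch


def scale_and_mask_sketch(tensor, scale=1.0, mask=None):
    """Hypothetical simplification of scale_and_mask(). Assumes `scale` is a
    Python number and a tensor `mask` is boolean with the same shape and
    _pyro_dims as `tensor` (no broadcasting is done here)."""
    if mask is True:
        mask = None  # an all-True mask masks nothing
    if scale == 1.0 and mask is None:
        return tensor  # nothing to do: return the input untouched
    if mask is False:
        zero = tensor.new_zeros(())  # everything masked out: a packed scalar 0
        zero._pyro_dims = ""
        return zero
    result = tensor * scale
    if mask is not None:
        result = result.masked_fill(~mask, 0.0)  # zero the masked-out entries
    result._pyro_dims = tensor._pyro_dims
    return result


t = torch.ones(2, 3)
t._pyro_dims = "ab"
m = torch.tensor([[True, False, True], [False, True, True]])
print(scale_and_mask_sketch(t, 2.0, m).tolist())  # [[2.0, 0.0, 2.0], [0.0, 2.0, 2.0]]
```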

Function: neg(value) / exp(value)

Pointwise negation and exponentiation that copy the _pyro_dims attribute onto the result. Ordinary PyTorch operations return new tensors without it.
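These wrappers are needed because plain PyTorch operations return new tensor objects that do not carry the custom attribute. Here is a minimal sketch with the hypothetical names `neg_sketch` and `exp_sketch`:

```python
import math
import torch


def neg_sketch(value):
    """Hypothetical packed negation: negate, then copy the dim names over."""
    result = -value
    if isinstance(value, torch.Tensor):
        result._pyro_dims = value._pyro_dims
    return result


def exp_sketch(value):
    """Hypothetical packed exp: torch for tensors, math.exp for plain numbers."""
    if isinstance(value, torch.Tensor):
        result = value.exp()
        result._pyro_dims = value._pyro_dims
        return result
    return math.exp(value)


t = torch.tensor([0.0, 1.0])
t._pyro_dims = "a"
print(hasattr(-t, "_pyro_dims"))  # False: plain ops drop the attribute
print(neg_sketch(t)._pyro_dims, exp_sketch(t)._pyro_dims)  # a a
```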

Function: rename_equation(equation, *operands)

Renames symbols in an einsum/ubersum equation to match the ._pyro_dims attributes of the provided operands.
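Renaming can be sketched as a character substitution driven by each operand's `_pyro_dims`. The function `rename_equation_sketch` is hypothetical and assumes the equation has an explicit `->` output. The renamed equation can be passed straight to `torch.einsum`.

```python
import torch


def rename_equation_sketch(equation, *operands):
    """Hypothetical re-implementation of rename_equation(), for illustration only."""
    inputs, _ = equation.split("->")
    terms = inputs.split(",")
    assert len(terms) == len(operands)
    # Map each symbol in the equation to the name its operand actually uses.
    rename = {
        old: new
        for term, operand in zip(terms, operands)
        for old, new in zip(term, operand._pyro_dims)
    }
    # Substitute symbol by symbol; ",", "-" and ">" pass through unchanged.
    return "".join(rename.get(s, s) for s in equation)


x = torch.randn(2, 3)
x._pyro_dims = "ij"
y = torch.randn(3, 4)
y._pyro_dims = "jk"
eq = rename_equation_sketch("ab,bc->ac", x, y)
print(eq)  # ij,jk->ik
```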

I/O Contract

| Function | Input | Output |
| --- | --- | --- |
| pack(value, dim_to_symbol) | Tensor; dict mapping negative int to char | Packed tensor with ._pyro_dims |
| unpack(value, symbol_to_dim) | Packed tensor; dict mapping char to negative int | Standard tensor |
| broadcast_all(*values) | Packed tensors | Tuple of packed tensors with common dims |
| mul(lhs, rhs) | Two packed tensors | Packed tensor with union of dims |

Usage Examples

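```python
import torch
from pyro.ops.packed import pack, unpack, mul, broadcast_all

# Map negative (right-aligned) dims to one-character names, and back
dim_to_symbol = {-2: "a", -1: "b"}
symbol_to_dim = {"a": -2, "b": -1}

# Pack a tensor: size-1 dims are squeezed out, the rest are named
x = torch.randn(1, 3, 4)
x_packed = pack(x, dim_to_symbol)
print(x_packed._pyro_dims)  # "ab"
print(x_packed.shape)        # torch.Size([3, 4])

# Unpack back: the leading size-1 dim is not restored, because unpack
# only rebuilds positions up to the most negative mapped dim (-2 here)
x_unpacked = unpack(x_packed, symbol_to_dim)
print(x_unpacked.shape)      # torch.Size([3, 4])

# Packed multiplication: "b" is shared, so the result has dims "abc"
y = torch.randn(4, 5)
y._pyro_dims = "bc"
result = mul(x_packed, y)
print(result._pyro_dims)  # "abc"

# Broadcast both operands to the common dims "abc"
x_b, y_b = broadcast_all(x_packed, y)
print(x_b.shape, y_b.shape)  # torch.Size([3, 4, 5]) torch.Size([3, 4, 5])
```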
