Implementation:Facebookresearch Habitat lab TensorDict

From Leeroopedia
Knowledge Sources
Domains Embodied_AI, Data_Structures, Reinforcement_Learning
Last Updated 2026-02-15 00:00 GMT

Overview

TensorDict and its family of classes (NDArrayDict, TensorOrNDArrayDict) provide a dictionary-tree data structure where leaves are tensors or arrays, supporting both string-key dictionary access and tensor-style indexing in a single unified interface.

Description

The module implements a hierarchy of dictionary-tree (DictTree) classes built on top of Python's dict:

_DictTreeBase[T] is the abstract base class providing:

  • Dual indexing: String keys access sub-trees or leaves (t["key"]), while tensor indices (int, slice, Tensor, ndarray) broadcast across all leaves (t[0:5] returns a new tree with each leaf sliced).
  • from_tree() / to_tree(): Convert between typed DictTree instances and plain nested dicts.
  • flatten() / from_flattened(): Flatten the tree into a list of (key-tuple, leaf) pairs and reconstruct from that representation. Keys are tuples of strings representing the path through the tree.
  • map(): Apply a function to all leaves, returning a new tree with transformed leaves.
  • map_in_place(): Same as map but modifies the current tree.
  • apply(): Apply a function to all leaves without returning values (side-effect only).
  • set(): Assign values using either string keys or tensor indices; key checking is strict by default (strict=True).
  • slice_keys(): Create a shallow copy with only the specified keys.
  • __deepcopy__(): Deep copy via a round trip through the plain nested-dict form (to_tree(), deep-copy, then from_tree()).
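
The dual-indexing rule and the leaf-wise helpers can be illustrated with a small plain-Python sketch. MiniTree below is a hypothetical class, not the habitat-baselines implementation: it uses Python lists as stand-in leaves so it runs without torch or numpy, and it only mirrors the behavior described above for __getitem__, map(), map_in_place(), apply(), and slice_keys().

```python
from typing import Any, Callable


class MiniTree(dict):
    """Hypothetical dict-tree: values are MiniTree nodes or sliceable leaves."""

    def __getitem__(self, index: Any) -> Any:
        # String keys behave like a normal dict lookup (sub-tree or leaf).
        if isinstance(index, str):
            return super().__getitem__(index)
        # Any other index (int, slice, ...) is applied to every leaf,
        # producing a new tree with the same keys.
        return self.map(lambda leaf: leaf[index])

    def map(self, func: Callable[[Any], Any]) -> "MiniTree":
        # Build a new tree; the original is left untouched.
        out = type(self)()
        for k, v in self.items():
            out[k] = v.map(func) if isinstance(v, MiniTree) else func(v)
        return out

    def map_in_place(self, func: Callable[[Any], Any]) -> "MiniTree":
        # Overwrite the leaves of this tree and return it for chaining.
        for k, v in self.items():
            if isinstance(v, MiniTree):
                v.map_in_place(func)
            else:
                self[k] = func(v)
        return self

    def apply(self, func: Callable[[Any], None]) -> None:
        # Call func on each leaf purely for its side effects.
        for v in self.values():
            if isinstance(v, MiniTree):
                v.apply(func)
            else:
                func(v)

    def slice_keys(self, *keys: str) -> "MiniTree":
        # Shallow copy restricted to the requested top-level keys.
        return type(self)({k: dict.__getitem__(self, k) for k in keys})


t = MiniTree(a=[1, 2, 3, 4], b=MiniTree(c=[10, 20, 30, 40]))
first = t[0]    # int index applied to every leaf
head = t[0:2]   # slice applied to every leaf
```

With real tensors as leaves, the same dispatch lets an int, slice, or index tensor select along the leading (batch) dimension of every leaf at once.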

TensorDict specializes _DictTreeBase[torch.Tensor]:

  • Automatically converts numpy arrays and scalars to torch.Tensor via _to_instance().
  • numpy(): Converts to NDArrayDict.

NDArrayDict specializes _DictTreeBase[np.ndarray]:

  • Automatically converts torch tensors and scalars to np.ndarray.
  • as_tensor(): Converts to TensorDict.

TensorOrNDArrayDict accepts either torch tensors or numpy arrays as leaves.
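
TensorDict and NDArrayDict differ mainly in how incoming leaves are converted. The sketch below shows that pattern as a classmethod hook called by from_tree(). FloatArrayDict and ListDict are invented stand-ins that use the standard-library array module and plain lists instead of np.ndarray and torch.Tensor, and wrapping scalars in a length-1 container is a choice made for this sketch only.

```python
from array import array
from typing import Any, Dict


class LeafConvertingTree(dict):
    """Hypothetical base class: subclasses choose the leaf type."""

    @classmethod
    def _to_instance(cls, v: Any) -> Any:
        raise NotImplementedError

    @classmethod
    def from_tree(cls, tree: Dict[str, Any]) -> "LeafConvertingTree":
        res = cls()
        for k, v in tree.items():
            # Nested dicts become sub-trees of the same class; every other
            # value goes through the class-specific leaf conversion.
            res[k] = cls.from_tree(v) if isinstance(v, dict) else cls._to_instance(v)
        return res


class FloatArrayDict(LeafConvertingTree):
    """Stand-in for an array-backed tree: every leaf becomes array('d')."""

    @classmethod
    def _to_instance(cls, v: Any) -> array:
        if isinstance(v, (int, float)):
            return array("d", [float(v)])  # sketch choice: scalar -> length-1 array
        return array("d", v)  # copies any iterable of numbers


class ListDict(LeafConvertingTree):
    """Stand-in for a second leaf type: every leaf becomes a plain list."""

    @classmethod
    def _to_instance(cls, v: Any) -> list:
        if isinstance(v, (int, float)):
            return [v]
        return list(v)


# Build a tree from plain data, then rebuild it with the other leaf type.
arrays = FloatArrayDict.from_tree({"obs": [1, 2], "reward": 0.5, "info": {"step": 3}})
lists = ListDict.from_tree(arrays)
```

Rebuilding a tree through another class's from_tree() has the same shape as numpy() and as_tensor(); how the real methods are implemented is not shown here.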

The module also provides two utility functions:

  • iterate_dicts_recursively(): Iterates multiple DictTrees in lockstep, yielding tuples of corresponding leaves.
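  • transpose_list_of_dicts(): Takes dicts as separate positional arguments (unpack a list with *) and returns a single dict mapping each key to the list of its values across the inputs.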
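
Both utilities can be sketched in plain Python over ordinary nested dicts. The names iterate_leaves_lockstep and transpose_dicts are hypothetical, chosen so they are not mistaken for the library functions; raising KeyError on mismatched keys and recursing into nested dicts during the transpose are assumptions of this sketch.

```python
from typing import Any, Dict, Iterator, Tuple


def iterate_leaves_lockstep(*trees: Dict[str, Any]) -> Iterator[Tuple[Any, ...]]:
    # Walk the first tree's keys; every other tree must have the same keys.
    for k in trees[0]:
        if not all(k in t for t in trees):
            raise KeyError(f"key {k!r} missing from at least one tree")
        children = tuple(t[k] for t in trees)
        if isinstance(children[0], dict):
            yield from iterate_leaves_lockstep(*children)
        else:
            yield children


def transpose_dicts(*dicts: Dict[str, Any]) -> Dict[str, Any]:
    # Gather each key's values across the inputs into one list,
    # recursing into nested dicts (a choice made for this sketch).
    out: Dict[str, Any] = {}
    for k in dicts[0]:
        values = [d[k] for d in dicts]
        out[k] = transpose_dicts(*values) if isinstance(values[0], dict) else values
    return out


steps = [{"r": 0.0}, {"r": 1.0}, {"r": 0.5}]
per_key = transpose_dicts(*steps)  # a list of per-step dicts is unpacked with *
```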

Usage

TensorDict is the primary data structure used throughout habitat-baselines for storing batched observations, rollout buffers, and action data. It allows treating nested observation dictionaries as if they were single tensors, enabling batch slicing, device transfer, and functional transformations.
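
As an illustration of the rollout-buffer use case, the sketch below preallocates a nested buffer and writes one step's nested observation into it at a given index, the kind of indexed write that set() supports. Everything here is hypothetical: make_buffer and set_at are invented helpers, Python lists stand in for preallocated tensors, and reading "strict" as "the incoming keys must exactly match the buffer's keys" is an assumption about the real class's key checking.

```python
from typing import Any, Dict


def make_buffer(num_steps: int, template: Dict[str, Any]) -> Dict[str, Any]:
    # Preallocate one list of length num_steps per leaf of the template.
    return {
        k: make_buffer(num_steps, v) if isinstance(v, dict) else [None] * num_steps
        for k, v in template.items()
    }


def set_at(buffer: Dict[str, Any], index: int, value: Dict[str, Any], strict: bool = True) -> None:
    # Strict mode (assumed semantics): the incoming step must carry exactly
    # the buffer's keys, otherwise nothing at this level is written.
    if strict and buffer.keys() != value.keys():
        raise KeyError(f"key mismatch: {sorted(buffer)} vs {sorted(value)}")
    for k, v in value.items():
        if k not in buffer:
            continue  # only reachable when strict=False: unknown keys are ignored
        if isinstance(buffer[k], dict):
            set_at(buffer[k], index, v, strict)
        else:
            buffer[k][index] = v


buf = make_buffer(3, {"obs": {"rgb": None, "depth": None}, "reward": None})
set_at(buf, 0, {"obs": {"rgb": "img0", "depth": "d0"}, "reward": 1.0})
```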

Code Reference

Source Location

Signature

class _DictTreeBase(Dict[str, Union["_DictTreeBase[T]", T]]):
    @classmethod
    def from_tree(cls, tree: Dict[str, Any]) -> _DictTreeInst: ...
    def to_tree(self) -> DictTree: ...
    @classmethod
    def from_flattened(cls, spec: List[Tuple[str, ...]], leaves: List[TensorLike]) -> _DictTreeInst: ...
    def flatten(self) -> Tuple[List[Tuple[str, ...]], List[T]]: ...
    def __getitem__(self, index: Union[str, TensorIndexType]) -> Union[_DictTreeInst, T]: ...
    def set(self, index, value, strict=True) -> None: ...
    def map(self, func: _MapFuncType) -> _DictTreeInst: ...
    def map_in_place(self, func: _MapFuncType) -> _DictTreeInst: ...
    def apply(self, func: _ApplyFuncType) -> None: ...
    def slice_keys(self, *keys) -> _DictTreeInst: ...


class TensorDict(_DictTreeBase[torch.Tensor]):
    def numpy(self) -> NDArrayDict: ...


class NDArrayDict(_DictTreeBase[np.ndarray]):
    def as_tensor(self) -> TensorDict: ...


class TensorOrNDArrayDict(_DictTreeBase[Union[torch.Tensor, np.ndarray]]):
    ...

def iterate_dicts_recursively(*dicts_i: _DictTreeBase[T]) -> Iterable[Tuple[T, ...]]: ...
def transpose_list_of_dicts(*dicts_i: Dict[Any, Any]) -> Dict[Any, List[Any]]: ...

Import

from habitat_baselines.common.tensor_dict import (
    TensorDict,
    NDArrayDict,
    TensorOrNDArrayDict,
    iterate_dicts_recursively,
    transpose_list_of_dicts,
)

I/O Contract

Inputs

  • tree (Dict[str, Any], optional): Nested dictionary to create a DictTree from (via from_tree()).
  • spec (List[Tuple[str, ...]], optional): Key paths for constructing a tree from its flattened representation (via from_flattened()).
  • leaves (List[TensorLike], optional): Leaf values, aligned with spec, for constructing a tree from its flattened representation.
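
The from_tree() input and its to_tree() inverse can be sketched with a hypothetical Tree class. Unlike TensorDict it performs no leaf conversion, so only the conversion between plain dicts and tree nodes is shown; the deep copy described earlier is modeled here as a to_tree()/from_tree() round trip.

```python
import copy
from typing import Any, Dict


class Tree(dict):
    """Hypothetical dict-tree without leaf conversion."""

    @classmethod
    def from_tree(cls, tree: Dict[str, Any]) -> "Tree":
        # Recursively wrap nested plain dicts in Tree instances.
        return cls({k: cls.from_tree(v) if isinstance(v, dict) else v for k, v in tree.items()})

    def to_tree(self) -> Dict[str, Any]:
        # Recursively unwrap Tree nodes back into plain dicts.
        return {k: v.to_tree() if isinstance(v, Tree) else v for k, v in self.items()}

    def __deepcopy__(self, memo: Dict[int, Any]) -> "Tree":
        # Deep copy the plain-dict form, then rebuild the typed tree.
        return self.from_tree(copy.deepcopy(self.to_tree(), memo))


plain = {"obs": {"rgb": [1, 2]}, "reward": [0.5]}
t = Tree.from_tree(plain)
clone = copy.deepcopy(t)
```

Note that from_tree() here shares leaf objects with the input dict; only the deep copy produces independent leaves.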

Outputs

  • TensorDict instance (TensorDict): A dictionary tree where all leaves are torch.Tensor.
  • NDArrayDict instance (NDArrayDict): A dictionary tree where all leaves are np.ndarray.
  • flatten() result (Tuple[List[Tuple[str, ...]], List[T]]): Flattened key paths and the corresponding leaf values, in matching order.
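
The flatten() output pairs each root-to-leaf key path with its leaf, and from_flattened() reverses it. A minimal plain-dict sketch of that round trip follows; flatten_tree and unflatten_tree are hypothetical names, and the depth-first, insertion-ordered traversal is this sketch's choice rather than a documented guarantee of the library.

```python
from typing import Any, Dict, List, Tuple

KeyPath = Tuple[str, ...]


def flatten_tree(tree: Dict[str, Any], prefix: KeyPath = ()) -> Tuple[List[KeyPath], List[Any]]:
    # Depth-first walk in insertion order, recording the path to every leaf.
    spec: List[KeyPath] = []
    leaves: List[Any] = []
    for k, v in tree.items():
        path = prefix + (k,)
        if isinstance(v, dict):
            sub_spec, sub_leaves = flatten_tree(v, path)
            spec.extend(sub_spec)
            leaves.extend(sub_leaves)
        else:
            spec.append(path)
            leaves.append(v)
    return spec, leaves


def unflatten_tree(spec: List[KeyPath], leaves: List[Any]) -> Dict[str, Any]:
    # Rebuild nested dicts by walking each path and creating parents on demand.
    if len(spec) != len(leaves):
        raise ValueError("spec and leaves must have the same length")
    tree: Dict[str, Any] = {}
    for path, leaf in zip(spec, leaves):
        node = tree
        for key in path[:-1]:
            node = node.setdefault(key, {})
        node[path[-1]] = leaf
    return tree


spec, leaves = flatten_tree({"a": [1], "b": {"c": [2], "d": [3]}})
```

Because the leaves end up in a flat list, an operation over a sequence of leaves can be applied once and the results re-attached to the same paths.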

Usage Examples

Basic Usage

import torch
from habitat_baselines.common.tensor_dict import TensorDict

# Create a TensorDict from keyword arguments
t = TensorDict(
    a=torch.randn(4, 3),
    b=TensorDict(
        c=torch.randn(4, 5),
        d=torch.randn(4, 2),
    ),
)

# String indexing to access leaves or sub-trees
print(t["a"].shape)       # torch.Size([4, 3])
print(t["b"]["c"].shape)  # torch.Size([4, 5])

# Tensor indexing broadcasts across all leaves
batch_0 = t[0]
print(batch_0["a"].shape)       # torch.Size([3])
print(batch_0["b"]["c"].shape)  # torch.Size([5])

# Slice indexing
first_two = t[0:2]
print(first_two["a"].shape)  # torch.Size([2, 3])

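# Map a function over all leaves (returns a new tree; t itself is unchanged)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
t_dev = t.map(lambda x: x.to(device))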

# Map in place
t.map_in_place(lambda x: x * 2.0)

# Flatten and reconstruct
spec, leaves = t.flatten()
t_reconstructed = TensorDict.from_flattened(spec, leaves)

# Convert between TensorDict and NDArrayDict
nd = t.numpy()
t_back = nd.as_tensor()

# Create from a plain dict
plain_dict = {"obs": torch.zeros(4, 3), "action": torch.ones(4)}
td = TensorDict.from_tree(plain_dict)

# Slice to specific keys
subset = t.slice_keys("a")
