
Heuristic:Huggingface Datasets Flatten Indices Performance

From Leeroopedia

Overview

Performance optimization technique using flatten_indices() to eliminate random memory access patterns caused by dataset operations like select(), shuffle(), and sort().

Description

Many Dataset operations do not physically reorder the underlying Arrow data. Instead, they create an indices mapping -- a lightweight table of row indices that defines a virtual ordering over the original data. For example, calling shuffle() generates a permutation of the row indices 0 to len(dataset) - 1 and stores it in the dataset's _indices attribute. The underlying Arrow table remains unchanged on disk; only the mapping is new.
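To make the idea concrete, here is a toy model in plain Python (no Arrow, no datasets library). The "physical" table is a dict of column lists, and a shuffle produces only a permutation of row positions. The names `toy_shuffle` and `read_row` are hypothetical helpers for illustration, not library APIs.

```python
import random

# Toy "physical" table: columns stored as lists; nothing below modifies them.
physical = {"id": list(range(8)), "text": [f"row-{i}" for i in range(8)]}


def toy_shuffle(num_rows, seed):
    """Return a permuted list of row positions (the toy 'indices mapping')."""
    indices = list(range(num_rows))
    random.Random(seed).shuffle(indices)
    return indices


def read_row(table, indices, logical_idx):
    """Resolve a logical row through the mapping, then read the physical row."""
    physical_idx = logical_idx if indices is None else indices[logical_idx]
    return {name: col[physical_idx] for name, col in table.items()}


mapping = toy_shuffle(len(physical["id"]), seed=42)
first = read_row(physical, mapping, 0)  # logical row 0 lives at physical row mapping[0]
```

The table itself never changes; only `mapping` decides which physical row each logical position refers to.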

When the dataset is later iterated, each row access must first resolve the logical index through the mapping to find the physical row. Because the mapping is typically non-sequential (e.g., row 0 maps to physical row 47392, row 1 maps to physical row 8101, etc.), the access pattern becomes random rather than sequential. Arrow tables are stored in contiguous columnar chunks optimized for sequential reads, so random access defeats prefetching, caching, and memory locality.
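The locality argument can be illustrated by counting how often a reader has to jump to a different chunk. The sketch below assumes a table split into fixed-size chunks (the chunk size of 100 is arbitrary) and counts chunk transitions for sequential versus shuffled access. It models only the access pattern, not real I/O or cache costs, and `chunk_switches` is a hypothetical helper.

```python
import random


def chunk_switches(access_order, chunk_size):
    """Count how many reads land in a different chunk than the previous read."""
    switches = 0
    prev_chunk = None
    for physical_idx in access_order:
        chunk = physical_idx // chunk_size
        if prev_chunk is not None and chunk != prev_chunk:
            switches += 1
        prev_chunk = chunk
    return switches


num_rows, chunk_size = 1000, 100
sequential = list(range(num_rows))
shuffled = sequential[:]
random.Random(0).shuffle(shuffled)

seq_switches = chunk_switches(sequential, chunk_size)   # 9: one per chunk boundary
shuf_switches = chunk_switches(shuffled, chunk_size)    # most reads switch chunks
```

Sequential reads cross each chunk boundary exactly once, while reading through a random permutation changes chunks on most rows.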

The flatten_indices() method resolves this by performing a full rewrite of the dataset: it reads every row in the mapped order and writes them out contiguously. After flattening, the _indices mapping is removed, and the physical order matches the logical order once again.
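In the same toy model, flattening means gathering every column in mapped order and discarding the mapping. `toy_flatten_indices` is a hypothetical helper that only illustrates the effect, not the library's implementation.

```python
def toy_flatten_indices(table, indices):
    """Rewrite every column in mapped order and drop the mapping.

    Returns (new_table, None): the new physical order equals the old logical order.
    """
    if indices is None:
        return table, None
    new_table = {name: [col[i] for i in indices] for name, col in table.items()}
    return new_table, None


physical = {"id": [10, 11, 12, 13], "label": ["a", "b", "c", "d"]}
mapping = [2, 0, 3, 1]  # e.g. produced by a shuffle
flat, new_mapping = toy_flatten_indices(physical, mapping)
# flat == {"id": [12, 10, 13, 11], "label": ["c", "a", "d", "b"]}
```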

Usage

Call flatten_indices() in these scenarios:

  • After select() with non-contiguous indices, shuffle(), sort(), or train_test_split() -- any operation that creates an indices mapping rather than physically rewriting data.
  • Before converting to an iterable dataset via to_iterable_dataset() -- the library logs an INFO-level message recommending this (see Code Evidence below).
  • Before training loops -- repeated iteration over a dataset with an indices mapping incurs the random access penalty on every epoch.
  • Before save_to_disk() or push_to_hub() -- flattening ensures the saved files have an optimal contiguous layout for downstream consumers.
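  • Before add_column() -- not strictly required, since the library calls flatten_indices() internally when indices exist (horizontal table concatenation requires aligned physical rows), but flattening explicitly lets you control when the rewrite happens and pass options such as num_proc or cache_file_name.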

The Insight (Rule of Thumb)

  • Action: Call ds = ds.flatten_indices() after operations that create indices mappings (select, shuffle, sort, train_test_split).
  • Value: Converts random access patterns to sequential access for optimal read performance. The official documentation states that an indices mapping can make access up to 10x slower.
  • Trade-off: flatten_indices() requires a full rewrite of the dataset (internally it calls self.map(batched=True)), which takes time and writes a new cache file roughly the size of the mapped data (unless keep_in_memory=True is used). This cost is only justified if you plan to iterate over the dataset multiple times (e.g., multi-epoch training). For one-shot iteration, the rewrite overhead may exceed the savings, because the flattening pass itself must read every row through the mapping.
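A back-of-envelope way to apply this trade-off: if one epoch through the mapping takes t_mapped, one contiguous epoch takes t_contiguous, and flattening takes t_flatten, then flattening pays off once the accumulated per-epoch savings exceed the one-time cost. The helper `break_even_epochs` and the timings below are hypothetical placeholders, not measurements; plug in your own numbers.

```python
import math


def break_even_epochs(t_mapped, t_contiguous, t_flatten):
    """Smallest number of epochs E for which flattening first is strictly cheaper.

    Compares E * t_mapped (iterate through the mapping every epoch) with
    t_flatten + E * t_contiguous (flatten once, then iterate contiguously).
    Returns None when contiguous reads are not faster, since flattening never pays off.
    """
    saving_per_epoch = t_mapped - t_contiguous
    if saving_per_epoch <= 0:
        return None
    # Need t_flatten < E * saving_per_epoch, i.e. E strictly greater than the ratio.
    return math.floor(t_flatten / saving_per_epoch) + 1


# Illustrative timings only (e.g. minutes): 10 per mapped epoch, 1 contiguous, 12 to flatten.
epochs = break_even_epochs(t_mapped=10.0, t_contiguous=1.0, t_flatten=12.0)  # 2
```

With these numbers, one epoch costs 10 without flattening versus 13 with it, while two epochs cost 20 versus 14, so flattening wins from the second epoch on.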

Reasoning

Apache Arrow tables store columns in contiguous memory buffers. Sequential iteration walks through these buffers linearly, benefiting from CPU cache lines, memory prefetching, and OS page cache read-ahead. When an indices mapping introduces a non-sequential access pattern, each row fetch potentially touches a different memory page or disk block, destroying locality.

The flatten_indices() method works by calling self.map(batched=True) with no transformation function -- it simply reads batches of rows in the mapped order and writes them sequentially to a new cache file. The result is a fresh Arrow table where physical row order matches the desired logical order, and the _indices attribute is set to None.
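The batched structure of that pass can be mimicked in plain Python: slice the mapping into batches, gather each batch in mapped order, and append it to the output. `toy_batched_flatten` and its `batch_size` parameter are hypothetical; in the real method, map()'s batching and `writer_batch_size` (rows per write) are separate settings, and the data is Arrow rather than lists.

```python
def toy_batched_flatten(table, indices, batch_size=1000):
    """Flatten a mapping one batch at a time, like an identity map with batched=True.

    Returns the new table and the number of batches processed.
    """
    out = {name: [] for name in table}
    num_batches = 0
    for start in range(0, len(indices), batch_size):
        batch_idx = indices[start:start + batch_size]  # physical rows for this batch
        for name, col in table.items():
            out[name].extend(col[i] for i in batch_idx)  # gather in mapped order
        num_batches += 1
    return out, num_batches


table = {"x": list(range(10))}
indices = [9, 8, 7, 6, 5, 4, 3, 2, 1, 0]
flat, batches = toy_batched_flatten(table, indices, batch_size=4)
# flat == {"x": [9, 8, 7, 6, 5, 4, 3, 2, 1, 0]}, batches == 3 (4 + 4 + 2 rows)
```

The output is written strictly in order, so only the reads are random; every later pass over `flat` is sequential.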

Operations that create indices mappings include:

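  • select() -- stores the selected indices as a pa.Table of uint64 values when they are non-contiguous; a contiguous range is handled by slicing instead of building a new mapping.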
  • shuffle() -- permutes [0:len(dataset)] and stores the permutation as an indices mapping.
  • sort() -- computes a sorted order and stores it as an indices mapping.
  • train_test_split() -- internally calls select() for train and test splits.
  • concatenate_datasets() with axis=1 -- the reverse case: rather than creating a mapping, it forces flatten_indices() on the input datasets, since column-wise concatenation requires aligned physical rows.
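As a rough illustration of how sort() can produce a mapping without touching the data, the sketch below computes an "argsort" permutation in plain Python. `toy_sort_indices` is a hypothetical helper; Python's sorted() stands in for whatever sort routine the library actually applies to Arrow data.

```python
def toy_sort_indices(column, reverse=False):
    """Return the permutation that orders `column`, leaving the column itself untouched."""
    return sorted(range(len(column)), key=lambda i: column[i], reverse=reverse)


scores = [0.7, 0.1, 0.9, 0.4]
mapping = toy_sort_indices(scores)          # [1, 3, 0, 2]
in_order = [scores[i] for i in mapping]     # [0.1, 0.4, 0.7, 0.9]
```

`scores` stays in its original physical order; only `mapping` encodes the sorted view, which is exactly why reads through it are non-sequential.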

Code Evidence

Log Message on Iterable Conversion (arrow_dataset.py:5494-5498)

if self._indices is not None:
    logger.info(
        "Converting an Arrow dataset to iterable but it has an indices mapping that can make it slower. "
        "You can use `ds = ds.flatten_indices()` to write your dataset in contiguous chunks of data and have optimal speed."
    )

This message is logged at INFO level (via logger.info, not as a Python warning) inside to_iterable_dataset() when a dataset with an active indices mapping is converted, so it is easy to miss unless INFO logging is enabled.

Shuffle Docstring Performance Warning (arrow_dataset.py:4522-4534)

Shuffling takes the list of indices [0:len(my_dataset)] and shuffles it to create an indices mapping.
However as soon as your Dataset has an indices mapping, the speed can become 10x slower.
This is because there is an extra step to get the row index to read using the indices mapping,
and most importantly, you aren't reading contiguous chunks of data anymore.
To restore the speed, you'd need to rewrite the entire dataset on your disk again using
Dataset.flatten_indices, which removes the indices mapping.

Usage pattern from the docstring:

my_dataset[0]  # fast
my_dataset = my_dataset.shuffle(seed=42)
my_dataset[0]  # up to 10x slower
my_dataset = my_dataset.flatten_indices()  # rewrite the shuffled dataset on disk as contiguous chunks of data
my_dataset[0]  # fast again

flatten_indices() Implementation (arrow_dataset.py:3960-4004)

def flatten_indices(
    self,
    keep_in_memory: bool = False,
    cache_file_name: Optional[str] = None,
    writer_batch_size: Optional[int] = 1000,
    features: Optional[Features] = None,
    disable_nullable: bool = False,
    num_proc: Optional[int] = None,
    new_fingerprint: Optional[str] = None,
) -> "Dataset":
    """Create and cache a new Dataset by flattening the indices mapping."""
    return self.map(
        batched=True,  # for speed
        keep_in_memory=keep_in_memory,
        cache_file_name=cache_file_name,
        writer_batch_size=writer_batch_size,
        features=features,
        disable_nullable=disable_nullable,
        new_fingerprint=new_fingerprint,
        desc="Flattening the indices",
        num_proc=num_proc,
    )

The implementation reveals that flatten_indices() is a specialized map() call with no transform function, using batched=True for throughput. It supports num_proc for parallel flattening and keep_in_memory to avoid disk writes for smaller datasets.

Indices Composition in select() (arrow_dataset.py:4257-4262)

indices_array = pa.array(indices, type=pa.uint64())
# Check if we need to convert indices
if self._indices is not None:
    indices_array = self._indices.column(0).take(indices_array)

indices_table = pa.Table.from_arrays([indices_array], names=["indices"])

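When select() is called on a dataset that already has an indices mapping, it composes the two mappings with a single take() rather than flattening first. The composed mapping still points directly at physical rows, so chained select() calls keep one level of indirection instead of stacking lookups; however, the access pattern stays non-sequential until flatten_indices() is called.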
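The composition step amounts to indexing one mapping with another. This plain-Python sketch mirrors the `self._indices.column(0).take(indices_array)` line with a list comprehension; `compose` is a hypothetical helper, and lists stand in for the uint64 Arrow arrays.

```python
def compose(existing, selected):
    """Compose mappings like existing.take(selected): one lookup instead of two."""
    if existing is None:
        return list(selected)
    return [existing[i] for i in selected]


physical = ["a", "b", "c", "d", "e"]
after_shuffle = [3, 0, 4, 1, 2]                # first mapping (e.g. from a shuffle)
selected = [0, 2]                              # select() logical rows 0 and 2
composed = compose(after_shuffle, selected)    # [3, 4]

# Reading through the composed mapping equals reading through both mappings in turn.
direct = [physical[i] for i in composed]
two_step = [physical[after_shuffle[i]] for i in selected]
```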

Forced Flattening in add_column() (arrow_dataset.py:6119)

dataset = self.flatten_indices() if self._indices is not None else self

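The library itself calls flatten_indices() (with its default arguments) whenever a mapping exists, because the new column is concatenated horizontally and can only line up with rows that are stored in logical order. This shows that some operations cannot work through an indices mapping at all.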

Forced Flattening in concatenate_datasets() (arrow_dataset.py:6550)

for i in range(len(dsets)):
    dsets[i] = dsets[i].flatten_indices()

When concatenating datasets along axis=1 (column-wise), the input datasets are forcibly flattened in the branch shown above (which handles more than one input dataset), because the rows must be physically aligned for horizontal concatenation.
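Why alignment matters can be seen in a small plain-Python sketch: if one side still has a mapping, pairing its physical rows with the other side's rows silently mismatches them. `resolve` and `concat_columns` are hypothetical helpers that follow the same idea as the library (resolve the mapping first, then join column-wise); they are not its implementation.

```python
def resolve(column, indices):
    """Return a column in logical order; a None mapping means physical order already matches."""
    return list(column) if indices is None else [column[i] for i in indices]


def concat_columns(left, left_indices, right, right_indices):
    """Join two column dicts side by side after resolving each side's mapping."""
    left_cols = {name: resolve(col, left_indices) for name, col in left.items()}
    right_cols = {name: resolve(col, right_indices) for name, col in right.items()}
    lengths = {len(col) for col in (*left_cols.values(), *right_cols.values())}
    if len(lengths) > 1:
        raise ValueError("column-wise concatenation needs the same number of rows")
    return {**left_cols, **right_cols}


left = {"id": [0, 1, 2]}
left_mapping = [2, 0, 1]                 # logical order of `left` after, e.g., a shuffle
right = {"label": ["x", "y", "z"]}       # already in the logical order of `left`

joined = concat_columns(left, left_mapping, right, None)
# joined == {"id": [2, 0, 1], "label": ["x", "y", "z"]}

# Ignoring the mapping pairs physical rows instead (id 0 with "x"), which is wrong here.
naive = {**left, **right}
```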
