
Implementation:Allenai Open instruct TensorDataCollatorWithFlattening



Knowledge Sources
Domains: Machine Learning, Deep Learning, Systems Optimization
Last Updated: 2026-02-07 00:00 GMT

Overview

A concrete tool for padding-free data collation, provided by the Open Instruct library: it concatenates multiple training examples into a single sequence for efficient GPU utilization, avoiding wasted computation on padding tokens.

Description

The TensorDataCollatorWithFlattening class inherits from HuggingFace's DefaultDataCollator and implements a custom __call__ method that packs multiple examples into a single batch-of-one tensor. For each batch of features, it:

  1. Concatenates all input_ids into a single 1D tensor (unsqueezed to batch dimension 1).
  2. Constructs labels by replacing the first label of each example with the separator token (-100) before concatenating, so the boundary token is ignored and the shifted (causal) loss never crosses an example boundary.
  3. Computes cumulative sequence lengths (cu_seq_lens_q/cu_seq_lens_k) and maximum sequence length for Flash Attention's variable-length attention kernel.
  4. Generates position IDs that reset to 0 at each example boundary for correct positional encoding.
  5. Assigns sequence indices (seq_idx) that map each token to its source example.

The output batch has a batch dimension of 1 (since all examples are packed into a single sequence), and the model uses the Flash Attention metadata to compute attention independently within each packed example.
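
The following is a minimal sketch of that packing logic, written from the description above rather than copied from the library; the helper name flatten_batch and the fallback to input_ids when labels are absent are illustrative assumptions, while the output keys follow the I/O contract below.

import torch

def flatten_batch(features, separator_id=-100):
    # Illustrative sketch of steps 1-5 above; not the library's exact code.
    input_ids, labels, position_ids, seq_idx = [], [], [], []
    cu_seq_lens = [0]
    max_length = 0
    for idx, feature in enumerate(features):
        ids = feature["input_ids"]
        # Assumption: fall back to input_ids when an example has no labels.
        lab = feature.get("labels", ids)
        n = ids.numel()
        input_ids.append(ids)
        # Step 2: the separator replaces the first label of each example.
        labels.append(torch.cat([torch.tensor([separator_id]), lab[1:]]))
        cu_seq_lens.append(cu_seq_lens[-1] + n)              # step 3
        max_length = max(max_length, n)
        position_ids.append(torch.arange(n))                 # step 4: restart at 0
        seq_idx.append(torch.full((n,), idx))                # step 5
    return {
        "input_ids": torch.cat(input_ids).unsqueeze(0),      # step 1: batch dim of 1
        "labels": torch.cat(labels).unsqueeze(0),
        "cu_seq_lens_q": torch.tensor(cu_seq_lens, dtype=torch.int32),
        "cu_seq_lens_k": torch.tensor(cu_seq_lens, dtype=torch.int32),
        "max_length_q": max_length,
        "max_length_k": max_length,
        "position_ids": torch.cat(position_ids).unsqueeze(0),
        "seq_idx": torch.cat(seq_idx).unsqueeze(0),
    }

Applied to the two-example batch in Basic Usage below, this sketch reproduces the shapes, cu_seq_lens, and position_ids shown there.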

Usage

Use this collator when packing=True is set in the training configuration. It replaces the default data collator and requires a model that supports Flash Attention 2 with variable-length inputs. It is enabled in finetune.py via the --packing flag.
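
For example, a minimal invocation sketch (all other required arguments elided; the exact command depends on your training setup):

python open_instruct/finetune.py --packing ...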

Code Reference

Source Location

  • Repository: Open Instruct
  • File: open_instruct/padding_free_collator.py
  • Lines: L8-73

Signature

@dataclass
class TensorDataCollatorWithFlattening(DefaultDataCollator):
    """
    Data collator for padding-free training along the lines of
    https://huggingface.co/blog/packing-with-FA2
    """

    return_flash_attn_kwargs: bool = True
    return_position_ids: bool = True
    return_seq_idx: bool = True
    separator_id: int = -100

    def __call__(self, features, return_tensors=None, separator_id=None) -> dict:
        ...

Import

from open_instruct.padding_free_collator import TensorDataCollatorWithFlattening

I/O Contract

Inputs

Constructor:

  • return_flash_attn_kwargs (bool, optional; default True): Whether to include cumulative sequence lengths and maximum sequence length in the output.
  • return_position_ids (bool, optional; default True): Whether to include position IDs in the output.
  • return_seq_idx (bool, optional; default True): Whether to include sequence indices in the output.
  • separator_id (int, optional; default -100): Token ID used as the separator between packed examples in labels; -100 is PyTorch's cross-entropy ignore index.

__call__():

  • features (list[dict], required): A list of example dicts, each containing "input_ids" (torch.Tensor) and optionally "labels" (torch.Tensor).
  • return_tensors (str or None, optional): Tensor return format, inherited from the parent collator; defaults to the class attribute.
  • separator_id (int or None, optional): Override for the separator ID; if None, the class default is used.

Outputs

  • input_ids (torch.Tensor, shape [1, total_len]): Concatenated input token IDs from all examples, with batch dimension 1.
  • labels (torch.Tensor, shape [1, total_len]): Concatenated labels with the -100 separator at example boundaries.
  • cu_seq_lens_q (torch.Tensor, shape [B+1], where B is the number of packed examples): Cumulative sequence lengths for Flash Attention queries (present if return_flash_attn_kwargs=True).
  • cu_seq_lens_k (torch.Tensor, shape [B+1]): Cumulative sequence lengths for Flash Attention keys; identical to cu_seq_lens_q.
  • max_length_q (int): Maximum individual example length in the batch (present if return_flash_attn_kwargs=True).
  • max_length_k (int): Same as max_length_q.
  • position_ids (torch.Tensor, shape [1, total_len]): Per-token position IDs, reset to 0 at each example boundary (present if return_position_ids=True).
  • seq_idx (torch.Tensor, shape [1, total_len]): Per-token index of the source example (present if return_seq_idx=True).
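
These fields satisfy a few invariants that make handy sanity checks. The helper below is an illustration written against the contract above, not part of the library:

def check_flattened_batch(batch):
    # Illustrative invariant checks for the I/O contract above.
    cu = batch["cu_seq_lens_q"]
    total = int(cu[-1])
    # The packed batch always has batch dimension 1 and total_len tokens.
    assert batch["input_ids"].shape == (1, total)
    assert batch["labels"].shape == (1, total)
    # max_length_q is the longest individual example, not the packed length.
    seq_lens = cu[1:] - cu[:-1]
    assert batch["max_length_q"] == int(seq_lens.max())
    # Position IDs restart at 0 at the start of every packed example.
    assert all(int(batch["position_ids"][0, int(s)]) == 0 for s in cu[:-1])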

Usage Examples

Basic Usage

from open_instruct.padding_free_collator import TensorDataCollatorWithFlattening
import torch

collator = TensorDataCollatorWithFlattening()

# Two examples of different lengths
features = [
    {"input_ids": torch.tensor([1, 2, 3, 4, 5]), "labels": torch.tensor([1, 2, 3, 4, 5])},
    {"input_ids": torch.tensor([10, 20, 30]), "labels": torch.tensor([-100, -100, 30])},
]

batch = collator(features)

# batch["input_ids"].shape = [1, 8]  (5 + 3 tokens)
# batch["labels"].shape = [1, 8]
# batch["cu_seq_lens_q"] = tensor([0, 5, 8])
# batch["position_ids"] = tensor([[0, 1, 2, 3, 4, 0, 1, 2]])

Integration with DataLoader

from torch.utils.data import DataLoader
from open_instruct.padding_free_collator import TensorDataCollatorWithFlattening

collator = TensorDataCollatorWithFlattening()
dataloader = DataLoader(
    train_dataset,
    batch_size=8,  # 8 examples packed into 1 sequence
    collate_fn=collator,
    shuffle=True,
)

for batch in dataloader:
    outputs = model(**batch)
    loss = outputs.loss
    ...

Related Pages

Implements Principle
