# Implementation: AllenAI Open Instruct `TensorDataCollatorWithFlattening`
| Knowledge Sources | |
|---|---|
| Domains | Machine Learning, Deep Learning, Systems Optimization |
| Last Updated | 2026-02-07 00:00 GMT |
## Overview
A concrete tool, provided by the Open Instruct library, for padding-free data collation: it concatenates multiple training examples into a single sequence for efficient GPU utilization.
## Description
The `TensorDataCollatorWithFlattening` class inherits from HuggingFace's `DefaultDataCollator` and implements a custom `__call__` method that packs multiple examples into a single batch-of-one tensor. For each batch of features, it:
- Concatenates all `input_ids` into a single 1D tensor (unsqueezed to batch dimension 1).
- Constructs `labels` by inserting a separator token (-100) at each example boundary, then concatenating the labels (skipping the first token of each example to align with the shifted loss).
- Computes cumulative sequence lengths (`cu_seq_lens_q` / `cu_seq_lens_k`) and the maximum sequence length for Flash Attention's variable-length attention kernel.
- Generates position IDs that reset to 0 at each example boundary for correct positional encoding.
- Assigns sequence indices (`seq_idx`) that map each token to its source example.
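The steps above can be sketched in plain PyTorch. This is an illustrative re-implementation of the described behavior, not the library's code; the helper name `flatten_batch` is ours, and it assumes every feature dict carries both `input_ids` and `labels`:

```python
import torch


def flatten_batch(features, separator_id=-100):
    """Sketch of padding-free collation: pack a list of examples into a
    single batch-of-one tensor plus Flash Attention varlen metadata.
    Illustration only -- not the Open Instruct implementation."""
    # 1) Concatenate all input_ids into one [1, total_len] tensor.
    input_ids = torch.cat([f["input_ids"] for f in features]).unsqueeze(0)

    # 2) Labels: a separator at each example boundary, skipping each
    #    example's first token so the shifted loss stays aligned.
    labels = torch.cat(
        [torch.cat([torch.tensor([separator_id]), f["labels"][1:]]) for f in features]
    ).unsqueeze(0)

    # 3) Cumulative sequence lengths for the varlen attention kernel.
    lengths = [len(f["input_ids"]) for f in features]
    cu_seq_lens = torch.cat(
        [torch.tensor([0]), torch.cumsum(torch.tensor(lengths), dim=0)]
    ).to(torch.int32)

    # 4) Position IDs reset to 0 at each example boundary.
    position_ids = torch.cat([torch.arange(n) for n in lengths]).unsqueeze(0)

    # 5) Per-token index of the source example.
    seq_idx = torch.cat(
        [torch.full((n,), i, dtype=torch.int32) for i, n in enumerate(lengths)]
    ).unsqueeze(0)

    return {
        "input_ids": input_ids,
        "labels": labels,
        "cu_seq_lens_q": cu_seq_lens,
        "cu_seq_lens_k": cu_seq_lens,
        "max_length_q": max(lengths),
        "max_length_k": max(lengths),
        "position_ids": position_ids,
        "seq_idx": seq_idx,
    }
```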
The output batch has a batch dimension of 1 (since all examples are packed into a single sequence), and the model uses the Flash Attention metadata to compute attention independently within each packed example.
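Conceptually, the cumulative sequence lengths describe a block-diagonal causal attention pattern: each token attends only to earlier tokens within its own packed example. The sketch below (our illustration, not library code) materializes the mask that the varlen kernel implicitly realizes:

```python
import torch


def block_diagonal_causal_mask(cu_seq_lens):
    # Token i may attend to token j iff both fall inside the same packed
    # example and j <= i (causal). The real Flash Attention kernel never
    # builds this mask; it reads cu_seq_lens directly.
    total = int(cu_seq_lens[-1])
    mask = torch.zeros(total, total, dtype=torch.bool)
    for start, end in zip(cu_seq_lens[:-1].tolist(), cu_seq_lens[1:].tolist()):
        n = end - start
        mask[start:end, start:end] = torch.tril(torch.ones(n, n, dtype=torch.bool))
    return mask
```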
## Usage
Use this collator when `packing=True` is set in the training configuration. It replaces the default data collator and requires a model that supports Flash Attention 2 with variable-length inputs. It is enabled in `finetune.py` via the `--packing` flag.
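For example, from the command line (only the `--packing` flag is specific to this collator; all other model and dataset arguments are elided):

```shell
# Enable padding-free packing in the fine-tuning script
# (hypothetical invocation sketch; remaining arguments elided)
python open_instruct/finetune.py --packing ...
```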
## Code Reference
### Source Location
- Repository: Open Instruct
- File: `open_instruct/padding_free_collator.py`
- Lines: 8-73
### Signature
```python
@dataclass
class TensorDataCollatorWithFlattening(DefaultDataCollator):
    """
    Data collator for padding-free training along the lines of
    https://huggingface.co/blog/packing-with-FA2
    """

    return_flash_attn_kwargs: bool = True
    return_position_ids: bool = True
    return_seq_idx: bool = True
    separator_id: int = -100

    def __call__(self, features, return_tensors=None, separator_id=None) -> dict:
        ...
```
### Import
```python
from open_instruct.padding_free_collator import TensorDataCollatorWithFlattening
```
## I/O Contract
### Inputs
Constructor:
| Name | Type | Required | Description |
|---|---|---|---|
| return_flash_attn_kwargs | bool | No | Whether to include cumulative sequence lengths and max length in the output. Defaults to True. |
| return_position_ids | bool | No | Whether to include position IDs in the output. Defaults to True. |
| return_seq_idx | bool | No | Whether to include sequence indices in the output. Defaults to True. |
| separator_id | int | No | Token ID used as separator between packed examples in labels. Defaults to -100 (PyTorch ignore index). |
`__call__()`:
| Name | Type | Required | Description |
|---|---|---|---|
| features | list[dict] | Yes | A list of example dicts, each containing "input_ids" (torch.Tensor) and optionally "labels" (torch.Tensor). |
| return_tensors | str or None | No | Tensor return format (inherited from parent). Defaults to class attribute. |
| separator_id | int or None | No | Override for the separator ID. If None, uses the class default. |
### Outputs
| Name | Type | Description |
|---|---|---|
| input_ids | torch.Tensor [1, total_len] | Concatenated input token IDs from all examples, with batch dimension 1. |
| labels | torch.Tensor [1, total_len] | Concatenated labels with -100 separators at example boundaries. |
| cu_seq_lens_q | torch.Tensor [B+1] | Cumulative sequence lengths for Flash Attention queries (if return_flash_attn_kwargs=True). |
| cu_seq_lens_k | torch.Tensor [B+1] | Cumulative sequence lengths for Flash Attention keys (same as cu_seq_lens_q). |
| max_length_q | int | Maximum individual sequence length in the batch (if return_flash_attn_kwargs=True). |
| max_length_k | int | Same as max_length_q. |
| position_ids | torch.Tensor [1, total_len] | Per-token position IDs, reset to 0 at each example boundary (if return_position_ids=True). |
| seq_idx | torch.Tensor [1, total_len] | Per-token example index (if return_seq_idx=True). |
## Usage Examples
### Basic Usage
```python
from open_instruct.padding_free_collator import TensorDataCollatorWithFlattening
import torch

collator = TensorDataCollatorWithFlattening()

# Two examples of different lengths
features = [
    {"input_ids": torch.tensor([1, 2, 3, 4, 5]), "labels": torch.tensor([1, 2, 3, 4, 5])},
    {"input_ids": torch.tensor([10, 20, 30]), "labels": torch.tensor([-100, -100, 30])},
]

batch = collator(features)
# batch["input_ids"].shape = [1, 8] (5 + 3 tokens)
# batch["labels"].shape = [1, 8]
# batch["cu_seq_lens_q"] = tensor([0, 5, 8])
# batch["position_ids"] = tensor([[0, 1, 2, 3, 4, 0, 1, 2]])
```
### Integration with DataLoader
```python
from torch.utils.data import DataLoader
from open_instruct.padding_free_collator import TensorDataCollatorWithFlattening

collator = TensorDataCollatorWithFlattening()
dataloader = DataLoader(
    train_dataset,
    batch_size=8,  # 8 examples packed into 1 sequence
    collate_fn=collator,
    shuffle=True,
)

for batch in dataloader:
    outputs = model(**batch)
    loss = outputs.loss
    ...
```