
Implementation:Hpcaitech ColossalAI DataCollatorForPreferenceDataset

Knowledge Sources

  • Domains: NLP, Data_Engineering
  • Last Updated: 2026-02-09 00:00 GMT

Overview

Concrete tool for collating preference pair batches for DPO training, provided by ColossalChat.

Description

DataCollatorForPreferenceDataset pads chosen and rejected sequences to the maximum length in each batch, handling attention masks and loss masks for both sequences. StatefulDistributedSampler extends PyTorch's DistributedSampler with state tracking for training resumption.
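
For orientation, here is a minimal sketch of what a preference-pair collator of this kind does. It is not the ColossalChat implementation; the sample keys and helper name are assumptions that mirror the I/O contract below.

import torch
from torch.nn.utils.rnn import pad_sequence

def collate_preference_batch(samples, pad_token_id):
    """Illustrative sketch, not the ColossalChat code.

    Assumes each sample is a dict of 1-D LongTensors keyed by
    "<side>_input_ids" and "<side>_loss_mask" for side in
    {"chosen", "rejected"} (keys mirror the I/O contract below).
    """
    batch = {}
    for side in ("chosen", "rejected"):
        ids = [s[f"{side}_input_ids"] for s in samples]
        loss_masks = [s[f"{side}_loss_mask"] for s in samples]
        lengths = torch.tensor([len(x) for x in ids])
        padded = pad_sequence(ids, batch_first=True, padding_value=pad_token_id)
        # Attention mask built from true lengths: 1 for real tokens, 0 for padding.
        positions = torch.arange(padded.size(1)).unsqueeze(0)
        batch[f"{side}_input_ids"] = padded
        batch[f"{side}_attention_mask"] = (positions < lengths.unsqueeze(1)).long()
        batch[f"{side}_loss_mask"] = pad_sequence(loss_masks, batch_first=True, padding_value=0)
    return batch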

Usage

Pass an instance as the collate_fn of a DataLoader over tokenized DPO preference data (see Usage Examples below).

Code Reference

Source Location

  • Repository: ColossalAI
  • File: applications/ColossalChat/coati/dataset/loader.py
  • Lines: 234-349

Signature

class DataCollatorForPreferenceDataset:
    """Collator for preference datasets (DPO, SimPO, ORPO)."""
    # Pads chosen_input_ids, rejected_input_ids, and their masks to the batch max (capped by max_length)

class StatefulDistributedSampler(DistributedSampler):
    """Sampler that maintains state for training resumption."""
    def __init__(
        self,
        dataset: Dataset,
        num_replicas: Optional[int] = None,
        rank: Optional[int] = None,
        shuffle: bool = True,
    ):
        """Extends DistributedSampler with state tracking."""

Import

from coati.dataset import (
    DataCollatorForPreferenceDataset,
    StatefulDistributedSampler,
    load_tokenized_dataset,
)

I/O Contract

Inputs

  • dataset_paths (List[str], required): Paths to tokenized Arrow datasets
  • tokenizer (PreTrainedTokenizer, required): Supplies pad_token_id for padding
  • max_length (int, optional): Maximum sequence length

Outputs

  • DataLoader batches (Dict[str, Tensor]): Batches containing chosen_input_ids, chosen_attention_mask, chosen_loss_mask, rejected_input_ids, rejected_attention_mask, and rejected_loss_mask

Usage Examples

from coati.dataset import (
    DataCollatorForPreferenceDataset,
    StatefulDistributedSampler,
    load_tokenized_dataset,
)
from torch.utils.data import DataLoader
from transformers import AutoTokenizer

# Any Hugging Face tokenizer with a pad token works; this model name is illustrative.
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token  # collator needs pad_token_id

# Load tokenized preference data
dataset = load_tokenized_dataset(dataset_paths=["/data/tokenized/preference"], mode="train")

# Create dataloader; sequences are padded per batch, capped at max_length
data_collator = DataCollatorForPreferenceDataset(tokenizer=tokenizer, max_length=4096)
sampler = StatefulDistributedSampler(dataset, shuffle=True)
dataloader = DataLoader(
    dataset,
    batch_size=4,
    sampler=sampler,
    collate_fn=data_collator,
    drop_last=True,
)
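
Continuing the example, a quick way to inspect the batch structure; the key names follow the output contract above.

for batch in dataloader:
    # Each value is a LongTensor of shape (batch_size, padded_seq_len).
    print(batch["chosen_input_ids"].shape, batch["rejected_input_ids"].shape)
    # chosen_loss_mask / rejected_loss_mask mark the positions (typically
    # the response tokens) that enter the DPO log-probability terms.
    break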
