Implementation:Hpcaitech ColossalAI DataCollatorForPreferenceDataset
| Knowledge Sources | |
|---|---|
| Domains | NLP, Data_Engineering |
| Last Updated | 2026-02-09 00:00 GMT |
Overview
A concrete collator, provided by ColossalChat, for assembling preference-pair batches used in DPO training.
Description
DataCollatorForPreferenceDataset pads chosen and rejected sequences to the maximum length in each batch, handling attention masks and loss masks for both sequences. StatefulDistributedSampler extends PyTorch's DistributedSampler with state tracking for training resumption.
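The padding behaviour described above can be sketched in a few lines. This is an illustrative simplification, not the ColossalChat source: plain Python lists stand in for tensors, `pad_id` stands in for `tokenizer.pad_token_id`, and the function names (`pad_to_batch_max`, `collate_preference_batch`) are hypothetical.

```python
def pad_to_batch_max(sequences, pad_id=0):
    """Right-pad sequences to the longest length in the batch;
    return (input_ids, attention_mask)."""
    max_len = max(len(s) for s in sequences)
    input_ids = [s + [pad_id] * (max_len - len(s)) for s in sequences]
    attention_mask = [[1] * len(s) + [0] * (max_len - len(s)) for s in sequences]
    return input_ids, attention_mask

def collate_preference_batch(samples, pad_id=0):
    """Collate chosen/rejected pairs into one padded batch dict.
    Each side is padded independently to its own batch max length."""
    batch = {}
    for side in ("chosen", "rejected"):
        ids, mask = pad_to_batch_max(
            [s[f"{side}_input_ids"] for s in samples], pad_id
        )
        batch[f"{side}_input_ids"] = ids
        batch[f"{side}_attention_mask"] = mask
    return batch
```

Note that chosen and rejected sequences are padded separately, so the two sides of a batch may have different widths.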
Usage
Use as the collate_fn in a DataLoader for DPO training data.
Code Reference
Source Location
- Repository: ColossalAI
- File: applications/ColossalChat/coati/dataset/loader.py
- Lines: 234-349
Signature
class DataCollatorForPreferenceDataset:
    """Collator for preference datasets (DPO, SimPO, ORPO)."""
    # Pads chosen_input_ids, rejected_input_ids, and their masks to the batch max length

class StatefulDistributedSampler(DistributedSampler):
    """Sampler that maintains state for training resumption."""

    def __init__(
        self,
        dataset: Dataset,
        num_replicas: Optional[int] = None,
        rank: Optional[int] = None,
        shuffle: bool = True,
    ):
        """Extends DistributedSampler with state tracking."""
Import
from coati.dataset import (
    DataCollatorForPreferenceDataset,
    StatefulDistributedSampler,
    load_tokenized_dataset,
)
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| dataset_paths | List[str] | Yes | Paths to tokenized Arrow datasets |
| tokenizer | PreTrainedTokenizer | Yes | Supplies the pad_token_id used for padding |
| max_length | int | No | Maximum sequence length |
Outputs
| Name | Type | Description |
|---|---|---|
| DataLoader batches | Dict[str, Tensor] | Batches with chosen_input_ids, chosen_attention_mask, chosen_loss_mask, rejected_input_ids, rejected_attention_mask, rejected_loss_mask |
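The loss masks in the output batch mark which token positions contribute to the preference loss: response tokens do, prompt and padding tokens do not. A hedged sketch of how such masks are typically consumed (plain lists standing in for tensors; this is a simplified view, since the full DPO loss also involves a reference model and a sigmoid over the scaled margin):

```python
def masked_logprob_sum(token_logprobs, loss_mask):
    """Sum per-token log-probs only where loss_mask == 1,
    i.e. response tokens, excluding prompt and padding."""
    return sum(lp for lp, m in zip(token_logprobs, loss_mask) if m)

def preference_margin(chosen_lp, chosen_mask, rejected_lp, rejected_mask):
    """Masked chosen log-prob minus masked rejected log-prob;
    positive means the policy prefers the chosen response."""
    return masked_logprob_sum(chosen_lp, chosen_mask) - masked_logprob_sum(
        rejected_lp, rejected_mask
    )
```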
Usage Examples
from coati.dataset import (
    DataCollatorForPreferenceDataset,
    StatefulDistributedSampler,
    load_tokenized_dataset,
)
from torch.utils.data import DataLoader

# Load tokenized preference data
dataset = load_tokenized_dataset(dataset_paths=["/data/tokenized/preference"], mode="train")

# Create dataloader
# tokenizer: a PreTrainedTokenizer loaded separately (provides pad_token_id)
data_collator = DataCollatorForPreferenceDataset(tokenizer=tokenizer, max_length=4096)
sampler = StatefulDistributedSampler(dataset, shuffle=True)
dataloader = DataLoader(
    dataset,
    batch_size=4,
    sampler=sampler,
    collate_fn=data_collator,
    drop_last=True,
)