Implementation: DataCollatorForSupervisedDataset
| Knowledge Sources | |
|---|---|
| Domains | NLP, Data_Engineering |
| Last Updated | 2026-02-08 00:00 GMT |
Overview
A concrete tool from the LLMBook-zh repository for padding and batching SFT training data, with IGNORE_INDEX label masking.
Description
The DataCollatorForSupervisedDataset is a dataclass-based collator that takes a list of (input_ids, labels) instances, pads them using torch.nn.utils.rnn.pad_sequence, and returns a batched dictionary. Input sequences are padded with the tokenizer's pad_token_id, while labels are padded with IGNORE_INDEX (-100).
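The behavior described above can be sketched as follows. This is a reconstruction from the description and the I/O contract below, not the repository's exact code (lines 48-61 of the source file may differ in details):

```python
from dataclasses import dataclass

import torch
from torch.nn.utils.rnn import pad_sequence

# Labels equal to -100 are skipped by torch.nn.CrossEntropyLoss
# (ignore_index), so padded label positions contribute no loss.
IGNORE_INDEX = -100


@dataclass
class DataCollatorForSupervisedDataset:
    tokenizer: object  # anything exposing a pad_token_id attribute

    def __call__(self, instances: list) -> dict:
        input_ids = [inst["input_ids"] for inst in instances]
        labels = [inst["labels"] for inst in instances]
        # Right-pad every sequence to the longest one in the batch.
        input_ids = pad_sequence(
            input_ids, batch_first=True, padding_value=self.tokenizer.pad_token_id
        )
        labels = pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX)
        return dict(input_ids=input_ids, labels=labels)
```

Padding inputs with pad_token_id but labels with IGNORE_INDEX is the key design choice: it keeps the two tensors shape-aligned while ensuring the loss is computed only on real (unmasked) label tokens.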
Usage
Pass an instance of this class as data_collator to the Hugging Face Trainer when doing supervised fine-tuning with variable-length sequences.
Code Reference
Source Location
- Repository: LLMBook-zh
- File: code/7.1 SFT实践.py
- Lines: 48-61
Signature
```python
@dataclass
class DataCollatorForSupervisedDataset:
    tokenizer: PreTrainedTokenizer

    def __call__(self, instances: list[dict]) -> dict:
        """
        Pads and batches a list of (input_ids, labels) instances.

        Args:
            instances: List of dicts with 'input_ids' and 'labels' keys (Tensors).

        Returns:
            dict(input_ids=Tensor[batch, max_len], labels=Tensor[batch, max_len])
        """
```
Import
```python
from sft_training import DataCollatorForSupervisedDataset
```
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| instances | list[dict] | Yes | List of dicts with 'input_ids' and 'labels' Tensor keys |
| tokenizer | PreTrainedTokenizer | Yes | Tokenizer (provides pad_token_id) |
Outputs
| Name | Type | Description |
|---|---|---|
| input_ids | Tensor | Padded batch [batch_size, max_seq_len], padded with pad_token_id |
| labels | Tensor | Padded batch [batch_size, max_seq_len], padded with -100 |
Usage Examples
```python
import torch
from transformers import AutoTokenizer

from sft_training import DataCollatorForSupervisedDataset

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")
# The Llama-2 tokenizer defines no pad token by default, so pad_token_id
# would be None; reuse the EOS token for padding before collating.
tokenizer.pad_token = tokenizer.eos_token

collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)

# Simulate two variable-length examples
instances = [
    {"input_ids": torch.tensor([1, 2, 3]), "labels": torch.tensor([-100, 2, 3])},
    {"input_ids": torch.tensor([4, 5, 6, 7, 8]), "labels": torch.tensor([-100, -100, 6, 7, 8])},
]

batch = collator(instances)
print(batch["input_ids"].shape)  # torch.Size([2, 5])
print(batch["labels"].shape)     # torch.Size([2, 5])
```