Implementation: Axolotl MultipackBatchSampler
| Knowledge Sources | |
|---|---|
| Domains | Training_Efficiency, Data_Loading |
| Last Updated | 2026-02-06 23:00 GMT |
Overview
A concrete tool provided by the Axolotl framework for constructing efficiently packed batches of variable-length sequences.
Description
The MultipackBatchSampler class extends PyTorch's BatchSampler to produce batches in which multiple sequences are packed into fixed-capacity bins. It supports both parallel packing (using the first-fit-decreasing (FFD) heuristic with multiprocessing) and sequential packing (preserving the original sequence order). The sampler also estimates batch counts for distributed-training synchronization and offers a safe mode for more conservative packing.
Key features:
- FFD bin packing for efficient, near-optimal sequence grouping
- Multiprocessing support for parallel packing of large datasets
- Sequential mode for order-preserving packing
- Safe mode for conservative packing that prevents training instability
- Optional Numba JIT acceleration of the packing algorithm
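The FFD idea behind the sampler can be sketched in a few lines of plain Python. This is an illustrative toy, not Axolotl's actual (multiprocessing, Numba-accelerated) implementation; the `ffd_pack` helper and its toy lengths are invented for this example:

```python
import numpy as np

def ffd_pack(lengths, bin_capacity, bin_size):
    """Illustrative first-fit-decreasing (FFD) bin packing.

    Sorts sequences longest-first, then places each one into the first
    bin with enough remaining token capacity (and fewer than `bin_size`
    items), opening a new bin when none fits.
    """
    order = np.argsort(lengths)[::-1]  # longest first
    bins = []                          # each bin: {"remaining": int, "items": [indices]}
    for idx in order:
        length = int(lengths[idx])
        for b in bins:
            if b["remaining"] >= length and len(b["items"]) < bin_size:
                b["remaining"] -= length
                b["items"].append(int(idx))
                break
        else:
            bins.append({"remaining": bin_capacity - length, "items": [int(idx)]})
    return [b["items"] for b in bins]

lengths = np.array([700, 1200, 300, 900, 400, 100])
bins = ffd_pack(lengths, bin_capacity=2048, bin_size=8)
# Each bin's token total stays within the 2048-token capacity.
```

Sorting longest-first is what makes FFD effective: large sequences claim bins early, and short ones later fill the leftover gaps.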
Usage
Used automatically when cfg.sample_packing is enabled. The HFCausalTrainerBuilder creates this sampler and passes it to the data loader.
Code Reference
Source Location
- Repository: axolotl
- File: src/axolotl/utils/samplers/multipack.py
- Lines: L244-473
Signature
```python
class MultipackBatchSampler(BatchSampler):
    """Batch sampler for efficient packing of variable-length sequences."""

    def __init__(
        self,
        sampler: Union[Sampler[int], Iterable[int]],
        batch_size: int,      # Number of bins per batch
        batch_max_len: int,   # Maximum sequence length (bin capacity)
        lengths: np.ndarray,  # Sequence lengths array
        bin_size: int,        # Max samples per bin
        packing_efficiency_estimate: float = 1.0,
        drop_last: bool = True,
        num_count_samples: int = 4,
        sequential: bool = False,
        group_size: int = 100_000,
        num_processes: int | None = None,
        safe_mode: bool = True,
        mp_start_method: str = "fork",
        **kwargs,
    ):
```
Import
```python
from axolotl.utils.samplers.multipack import MultipackBatchSampler
```
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| sampler | Sampler or Iterable[int] | Yes | Base sampler providing dataset indices |
| batch_size | int | Yes | Number of bins (packed sequences) per batch |
| batch_max_len | int | Yes | Maximum total tokens per bin (sequence length capacity) |
| lengths | np.ndarray | Yes | Array of sequence lengths for all dataset samples |
| bin_size | int | Yes | Maximum number of individual sequences per bin |
| sequential | bool | No (default: False) | Preserve original sequence ordering |
| safe_mode | bool | No (default: True) | Conservative packing for stability |
Outputs
| Name | Type | Description |
|---|---|---|
| batches (iterator) | list[list[list[int]]] | Batches of bins, each bin containing indices of packed sequences |
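The nested list[list[list[int]]] shape can be mimicked with a small helper (hypothetical, not part of the library): each batch is a list of bins, and each bin is a list of dataset indices that were packed together.

```python
def group_bins_into_batches(bins, batch_size, drop_last=True):
    """Group packed bins into batches of `batch_size` bins each,
    mirroring the list[list[list[int]]] shape the sampler yields."""
    batches = [bins[i:i + batch_size] for i in range(0, len(bins), batch_size)]
    if drop_last and batches and len(batches[-1]) < batch_size:
        batches.pop()  # discard the incomplete final batch, like drop_last=True
    return batches

# Toy bins: each inner list holds dataset indices packed into one bin
bins = [[0, 3], [1], [2, 4], [5], [6]]
batches = group_bins_into_batches(bins, batch_size=2)
# batches -> [[[0, 3], [1]], [[2, 4], [5]]]; the lone bin [6] is dropped
```

A collate function downstream flattens each bin's sequences into one packed example of up to batch_max_len tokens.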
Usage Examples
Basic Usage with DataLoader
```python
from axolotl.utils.samplers.multipack import MultipackBatchSampler
from torch.utils.data import DataLoader, SequentialSampler
import numpy as np

# Get sequence lengths from dataset
lengths = np.array([len(sample["input_ids"]) for sample in dataset])

sampler = SequentialSampler(dataset)
batch_sampler = MultipackBatchSampler(
    sampler=sampler,
    batch_size=4,        # 4 packed bins per batch
    batch_max_len=2048,  # max 2048 tokens per packed sequence
    lengths=lengths,
    bin_size=8,          # max 8 original sequences per pack
)

# Use with DataLoader
dataloader = DataLoader(dataset, batch_sampler=batch_sampler)
```
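To gauge how well packing is working, you can measure what fraction of the bins' token capacity is actually filled. The `token_utilization` helper below is hypothetical, related in spirit to the sampler's packing_efficiency_estimate parameter but not the library's internal computation:

```python
import numpy as np

def token_utilization(lengths, batches, batch_max_len):
    """Fraction of total bin capacity occupied by real tokens.

    `batches` uses the sampler's nested shape: batches -> bins -> indices.
    """
    lengths = np.asarray(lengths)
    packed_tokens = sum(
        int(lengths[i]) for batch in batches for b in batch for i in b
    )
    num_bins = sum(len(batch) for batch in batches)
    return packed_tokens / (num_bins * batch_max_len)

# Toy example: two batches of one bin each, 2048-token capacity per bin
lengths = [1200, 700, 900, 1000]
batches = [[[0, 1]], [[2, 3]]]
util = token_utilization(lengths, batches, batch_max_len=2048)
# 3800 packed tokens over 2 * 2048 capacity ~= 0.93 utilization
```

Utilization close to 1.0 means little padding is wasted; persistently low values suggest batch_max_len or bin_size is mismatched to the dataset's length distribution.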