
Implementation: axolotl-ai-cloud Axolotl MultipackBatchSampler

From Leeroopedia


Knowledge Sources
Domains Training_Efficiency, Data_Loading
Last Updated 2026-02-06 23:00 GMT

Overview

A concrete tool, provided by the Axolotl framework, for constructing optimally packed batches of variable-length sequences.

Description

The MultipackBatchSampler class extends PyTorch's BatchSampler to produce batches in which multiple sequences are packed into fixed-capacity bins. It supports both parallel packing (using the first-fit-decreasing (FFD) algorithm with multiprocessing) and sequential packing (which preserves the original sequence order). The sampler also estimates batch counts to keep distributed training ranks synchronized, and offers a safe mode for more conservative packing.

Key features:

  • First-fit-decreasing (FFD) bin packing for near-optimal sequence grouping
  • Multiprocessing support for parallel packing of large datasets
  • Sequential mode for order-preserving packing
  • Safe mode with conservative packing that prevents training instability
  • Optional Numba JIT acceleration of the packing algorithm
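The core idea behind the features above can be sketched with a minimal first-fit-decreasing packer in plain Python. This is an illustrative sketch, not Axolotl's actual implementation (which lives in multipack.py and adds multiprocessing, batch-count estimation, and Numba acceleration); the function name and values below are hypothetical.

```python
def ffd_pack(lengths, bin_capacity, bin_size):
    """Pack sequence indices into bins with first-fit decreasing (FFD).

    Sorts sequences longest-first, then places each one into the first
    bin that still has enough token capacity (and fewer than bin_size
    sequences), opening a new bin when none fits.
    """
    order = sorted(range(len(lengths)), key=lambda i: lengths[i], reverse=True)
    bins = []        # each bin holds a list of sequence indices
    remaining = []   # remaining token capacity of each bin
    for idx in order:
        length = lengths[idx]
        for b, cap in enumerate(remaining):
            if length <= cap and len(bins[b]) < bin_size:
                bins[b].append(idx)
                remaining[b] -= length
                break
        else:
            # no existing bin fits: open a new one
            bins.append([idx])
            remaining.append(bin_capacity - length)
    return bins

# Hypothetical sequence lengths; capacity mirrors a 2048-token bin.
lengths = [1500, 900, 600, 500, 400, 100]
bins = ffd_pack(lengths, bin_capacity=2048, bin_size=4)
# → [[0, 3], [1, 2, 4, 5]]: six sequences fit into two bins
```

Sorting longest-first is what makes FFD effective: large sequences claim bins early, and small ones fill the leftover gaps.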

Usage

Used automatically when cfg.sample_packing is enabled. The HFCausalTrainerBuilder creates this sampler and passes it to the data loader.
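In an Axolotl YAML config this corresponds to keys along the lines of the following sketch (a hedged example; key names are from memory and may differ across versions):

```yaml
# Hedged sketch: exact option names may vary by Axolotl version.
sample_packing: true      # enables MultipackBatchSampler
sequence_len: 2048        # packed-bin capacity (batch_max_len)
micro_batch_size: 4       # bins per batch (batch_size)
```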

Code Reference

Source Location

  • Repository: axolotl
  • File: src/axolotl/utils/samplers/multipack.py
  • Lines: L244-473

Signature

class MultipackBatchSampler(BatchSampler):
    """Batch sampler for efficient packing of variable-length sequences."""

    def __init__(
        self,
        sampler: Union[Sampler[int], Iterable[int]],
        batch_size: int,              # Number of bins per batch
        batch_max_len: int,           # Maximum sequence length (bin capacity)
        lengths: np.ndarray,          # Sequence lengths array
        bin_size: int,                # Max samples per bin
        packing_efficiency_estimate: float = 1.0,
        drop_last: bool = True,
        num_count_samples: int = 4,
        sequential: bool = False,
        group_size: int = 100_000,
        num_processes: int | None = None,
        safe_mode: bool = True,
        mp_start_method: str = "fork",
        **kwargs,
    ):

Import

from axolotl.utils.samplers.multipack import MultipackBatchSampler

I/O Contract

Inputs

Name Type Required Description
sampler Sampler or Iterable[int] Yes Base sampler providing dataset indices
batch_size int Yes Number of bins (packed sequences) per batch
batch_max_len int Yes Maximum total tokens per bin (sequence length capacity)
lengths np.ndarray Yes Array of sequence lengths for all dataset samples
bin_size int Yes Maximum number of individual sequences per bin
sequential bool No (default: False) Preserve original sequence ordering
safe_mode bool No (default: True) Conservative packing for stability

Outputs

Name Type Description
batches (iterator) list[list[list[int]]] Batches of bins, each bin containing indices of packed sequences
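The three nesting levels of this output can be illustrated without axolotl or a real dataset; all values below are hypothetical hand-built examples of what one yielded batch might look like.

```python
# Hypothetical single batch yielded by the sampler: a list of bins,
# where each bin is a list of dataset indices to be packed together.
batch = [[0, 3], [1, 2, 4, 5]]

# Hypothetical per-sample sequence lengths (the `lengths` array).
dataset_lengths = [1500, 900, 600, 500, 400, 100]

# A collator concatenates the sequences in each bin; the packed length
# of every bin must stay within batch_max_len (here, 2048 tokens).
packed_lengths = [sum(dataset_lengths[i] for i in bin_) for bin_ in batch]
# → [2000, 2000]
```

A downstream collate function consumes each bin's indices, concatenates the corresponding tokenized samples, and builds attention masks so packed sequences do not attend to each other.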

Usage Examples

Basic Usage with DataLoader

from axolotl.utils.samplers.multipack import MultipackBatchSampler
from torch.utils.data import DataLoader, SequentialSampler
import numpy as np

# Get sequence lengths from dataset
lengths = np.array([len(sample["input_ids"]) for sample in dataset])

sampler = SequentialSampler(dataset)
batch_sampler = MultipackBatchSampler(
    sampler=sampler,
    batch_size=4,           # 4 packed sequences per batch
    batch_max_len=2048,     # Max 2048 tokens per packed sequence
    lengths=lengths,
    bin_size=8,             # Max 8 original sequences per pack
)

# Use with DataLoader
dataloader = DataLoader(dataset, batch_sampler=batch_sampler)

