

Implementation:Hiyouga LLaMA Factory V1 Batching

From Leeroopedia


Knowledge Sources
Domains: Machine Learning, Data Processing, Distributed Training
Last Updated: 2026-02-06 19:00 GMT

Overview

BatchGenerator is a stateful batch iterator that manages the full data-to-batch pipeline (distributed sampling, tokenization, buffering, and micro-batch generation) and persists its state so that training can be resumed.

Description

The batching module provides BatchGenerator, which implements Python's Iterator protocol, and the default_collate_fn helper function. BatchGenerator wraps a StatefulDataLoader driven by a StatefulDistributedSampler, which shards the data across distributed processes. Tokenized samples produced by the Renderer are accumulated in a StatefulBuffer, and default_collate_fn then turns buffered samples into micro-batches, padding and truncating every sequence to a fixed cutoff length. Gradient accumulation is controlled by the micro_batch_size and global_batch_size parameters, from which the number of micro-batches per optimizer step (num_micro_batch) is derived. The state_dict and load_state_dict methods support checkpointing and resuming training, and the batching strategy can be extended through the BatchingPlugin system.
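The padding and truncation step described above can be illustrated with a small, self-contained sketch. It is not the library's default_collate_fn: toy_collate, pad_or_truncate, and ToyBatchInput are hypothetical names, the function takes a plain list of already-tokenized samples instead of a StatefulBuffer and BatchInfo, it returns Python lists instead of tensors, it omits loss_weights, and it assumes a pad id of 0 and the common label ignore index -100.

```python
from __future__ import annotations

from typing import TypedDict

PAD_ID = 0            # hypothetical pad token id (the real value comes from the tokenizer)
IGNORE_INDEX = -100   # assumed label value for positions excluded from the loss


class ToyBatchInput(TypedDict):
    input_ids: list[list[int]]
    labels: list[list[int]]
    attention_mask: list[list[int]]


def pad_or_truncate(seq: list[int], cutoff_len: int, fill: int) -> list[int]:
    """Cut seq to at most cutoff_len items, then right-pad with `fill` to exactly cutoff_len."""
    seq = seq[:cutoff_len]
    return seq + [fill] * (cutoff_len - len(seq))


def toy_collate(
    samples: list[dict[str, list[int]]],
    micro_batch_size: int,
    num_micro_batch: int,
    cutoff_len: int,
) -> list[ToyBatchInput] | None:
    """Split samples into num_micro_batch micro-batches of fixed-length sequences."""
    if len(samples) < micro_batch_size * num_micro_batch:
        return None  # not enough samples buffered for a full optimizer step yet
    micro_batches: list[ToyBatchInput] = []
    for i in range(num_micro_batch):
        chunk = samples[i * micro_batch_size : (i + 1) * micro_batch_size]
        micro_batches.append(
            {
                "input_ids": [pad_or_truncate(s["input_ids"], cutoff_len, PAD_ID) for s in chunk],
                "labels": [pad_or_truncate(s["labels"], cutoff_len, IGNORE_INDEX) for s in chunk],
                # 1 marks real tokens, 0 marks padding
                "attention_mask": [pad_or_truncate([1] * len(s["input_ids"]), cutoff_len, 0) for s in chunk],
            }
        )
    return micro_batches


samples = [
    {"input_ids": [1, 2, 3], "labels": [IGNORE_INDEX, 2, 3]},
    {"input_ids": [4, 5, 6, 7, 8], "labels": [4, 5, 6, 7, 8]},
]
print(toy_collate(samples, micro_batch_size=2, num_micro_batch=1, cutoff_len=4)[0]["input_ids"])
# [[1, 2, 3, 0], [4, 5, 6, 7]]
```

Padding every sequence to the same cutoff_len keeps tensor shapes fixed across steps, at the cost of computing on padding tokens for short samples.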

Usage

BatchGenerator is typically instantiated by BaseTrainer during initialization, so it rarely needs to be created directly unless you are building a custom training loop. Configure it through TrainingArguments fields such as micro_batch_size, global_batch_size, cutoff_len, batching_workers, and batching_strategy.

Code Reference

Source Location

Signature

def default_collate_fn(buffer: StatefulBuffer, batch_info: BatchInfo) -> list[BatchInput] | None: ...

class BatchGenerator(Iterator):
    def __init__(
        self,
        dataset: TorchDataset,
        renderer: Renderer,
        micro_batch_size: int = 1,
        global_batch_size: int | None = None,
        cutoff_len: int = 2048,
        batching_workers: int = 0,
        batching_strategy: BatchingStrategy = BatchingStrategy.NORMAL,
        pin_memory: bool = True,
        drop_last: bool = True,
    ) -> None: ...

    def __len__(self) -> int: ...
    def __iter__(self): ...
    def __next__(self): ...

    def state_dict(self) -> dict[str, Any]: ...
    def load_state_dict(self, state: dict[str, Any]) -> None: ...
    def set_epoch(self, epoch: int) -> None: ...
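The state_dict / load_state_dict / set_epoch trio follows a common resumable-iterator pattern. The toy class below (ToyStatefulIterator is a hypothetical name, not part of LLaMA Factory) records only the epoch and the number of items already yielded, and uses a deterministic per-epoch shuffle so that a fresh instance can continue exactly where a checkpoint left off. The real BatchGenerator additionally saves buffer contents and dataloader state.

```python
import random
from collections.abc import Iterator
from typing import Any


class ToyStatefulIterator(Iterator):
    """Hypothetical resumable iterator over dataset indices (not part of LLaMA Factory)."""

    def __init__(self, num_samples: int, seed: int = 0) -> None:
        self.num_samples = num_samples
        self.seed = seed
        self.epoch = 0
        self.position = 0  # indices already yielded in the current epoch

    def _order(self) -> list[int]:
        # Deterministic shuffle per (seed, epoch), so a resumed run sees the same order.
        order = list(range(self.num_samples))
        random.Random(self.seed + self.epoch).shuffle(order)
        return order

    def set_epoch(self, epoch: int) -> None:
        # Start a new epoch with a different (but reproducible) order.
        self.epoch = epoch
        self.position = 0

    def __iter__(self) -> "ToyStatefulIterator":
        return self

    def __next__(self) -> int:
        if self.position >= self.num_samples:
            raise StopIteration
        index = self._order()[self.position]  # order recomputed each call for simplicity
        self.position += 1
        return index

    def state_dict(self) -> dict[str, Any]:
        return {"epoch": self.epoch, "position": self.position}

    def load_state_dict(self, state: dict[str, Any]) -> None:
        self.epoch = state["epoch"]
        self.position = state["position"]


# Consume part of an epoch, checkpoint, then resume in a fresh instance.
it = ToyStatefulIterator(num_samples=8, seed=42)
seen = [next(it) for _ in range(3)]
checkpoint = it.state_dict()  # {'epoch': 0, 'position': 3}
remaining = list(it)

resumed = ToyStatefulIterator(num_samples=8, seed=42)
resumed.load_state_dict(checkpoint)
assert list(resumed) == remaining  # resumption continues exactly where the checkpoint was taken
```

Storing a position plus a reproducible ordering keeps the checkpoint small; the alternative of saving the full remaining index list is simpler but grows with dataset size.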

Import

from llamafactory.v1.core.utils.batching import BatchGenerator, default_collate_fn

I/O Contract

Inputs

Name | Type | Required | Description
--- | --- | --- | ---
dataset | TorchDataset | Yes | The training dataset (typically a DataEngine instance).
renderer | Renderer | Yes | Renderer that turns raw samples into tokenized model inputs.
micro_batch_size | int | No | Number of samples per micro-batch. Default: 1.
global_batch_size | int or None | No | Total batch size across all data-parallel (DP) ranks; must be divisible by dp_size * micro_batch_size. Default: dp_size * micro_batch_size (one micro-batch per step, i.e. no gradient accumulation).
cutoff_len | int | No | Maximum sequence length for padding and truncation. Default: 2048.
batching_workers | int | No | Number of dataloader worker processes. Default: 0.
batching_strategy | BatchingStrategy | No | Strategy used to build batches. Default: BatchingStrategy.NORMAL.
pin_memory | bool | No | Whether to pin host memory for faster GPU transfer. Default: True.
drop_last | bool | No | Whether to drop the last incomplete batch. Default: True; only True is supported.
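The divisibility rule for global_batch_size determines how many micro-batches each step produces per rank. A minimal sketch of that arithmetic follows; resolve_num_micro_batch is a hypothetical helper, and dp_size stands for the number of data-parallel ranks.

```python
from typing import Optional


def resolve_num_micro_batch(
    micro_batch_size: int,
    dp_size: int,
    global_batch_size: Optional[int] = None,
) -> int:
    """Return the number of micro-batches (gradient accumulation steps) per optimizer step."""
    samples_per_micro_step = dp_size * micro_batch_size  # samples consumed by all ranks per micro-batch
    if global_batch_size is None:
        # Documented default: one micro-batch per rank, i.e. no gradient accumulation.
        global_batch_size = samples_per_micro_step
    if global_batch_size % samples_per_micro_step != 0:
        raise ValueError(
            f"global_batch_size={global_batch_size} is not divisible by "
            f"dp_size * micro_batch_size = {samples_per_micro_step}"
        )
    return global_batch_size // samples_per_micro_step


print(resolve_num_micro_batch(micro_batch_size=2, dp_size=1, global_batch_size=4))  # 2
print(resolve_num_micro_batch(micro_batch_size=2, dp_size=2, global_batch_size=4))  # 1
```

With the values from the usage example below (micro_batch_size=2, global_batch_size=4), a single process accumulates 2 micro-batches per optimizer step, while 2 data-parallel ranks each process 1.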

Outputs

Name | Type | Description
--- | --- | ---
__next__() | list[BatchInput] | A list of micro-batches, one per gradient accumulation step, each containing padded/truncated input_ids, labels, attention_mask, and loss_weights tensors.
state_dict() | dict[str, Any] | Serializable state for resuming training, including buffer contents and dataloader state.

Usage Examples

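from llamafactory.v1.core.utils.batching import BatchGenerator
from llamafactory.v1.core.data_engine import DataEngine
from llamafactory.v1.core.model_engine import ModelEngine

# model_args is assumed to be a model-arguments object constructed elsewhere
# (e.g. parsed from the training config); it is not defined in this snippet.
data_engine = DataEngine("data/v1_sft_demo.yaml")
model_engine = ModelEngine(model_args=model_args)

# global_batch_size must be divisible by dp_size * micro_batch_size;
# on a single process, 4 / (1 * 2) gives 2 micro-batches per step.
batch_generator = BatchGenerator(
    dataset=data_engine,
    renderer=model_engine.renderer,
    micro_batch_size=2,
    global_batch_size=4,
    cutoff_len=2048,
    batching_workers=0,
)

# Each iteration yields the micro-batches for one optimizer step.
for micro_batches in batch_generator:
    for micro_batch in micro_batches:
        print(micro_batch["input_ids"].shape)
    break  # inspect only the first step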
