Implementation:Hiyouga LLaMA Factory V1 Batching
| Knowledge Sources | |
|---|---|
| Domains | Machine Learning, Data Processing, Distributed Training |
| Last Updated | 2026-02-06 19:00 GMT |
Overview
BatchGenerator is a stateful iterator that manages the complete data-to-batch pipeline: distributed sampling, tokenization, buffering, and micro-batch generation. Its state can be saved and restored so that training can resume from a checkpoint.
Description
The batching module provides BatchGenerator (which implements Python's Iterator protocol) and the default_collate_fn helper function. BatchGenerator wraps a StatefulDataLoader with a StatefulDistributedSampler to shard the data across processes. It uses a StatefulBuffer to accumulate the tokenized samples produced by the Renderer, then generates micro-batches via default_collate_fn, which pads and truncates sequences to a fixed cutoff length. The generator supports gradient accumulation through the configurable micro_batch_size and global_batch_size parameters, which together determine num_micro_batch, the number of micro-batches returned per step. It provides state_dict / load_state_dict methods for checkpointing and training resumption. The batching strategy is extensible through the BatchingPlugin system.
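The pad-and-truncate step described above can be illustrated with a minimal sketch on plain Python lists. This is not the body of default_collate_fn: the names pad_or_truncate and collate_fixed_length, the padding token id 0, and the label fill value -100 are all assumptions made for illustration, and the real helper works on a StatefulBuffer and returns tensors.

```python
from typing import TypedDict

IGNORE_INDEX = -100  # assumption: label value that is excluded from the loss
PAD_TOKEN_ID = 0     # assumption: id used for padding positions


class Sample(TypedDict):
    input_ids: list[int]
    labels: list[int]


def pad_or_truncate(values: list[int], cutoff_len: int, pad_value: int) -> list[int]:
    """Return exactly cutoff_len items: cut the tail, or right-pad with pad_value."""
    clipped = values[:cutoff_len]
    return clipped + [pad_value] * (cutoff_len - len(clipped))


def collate_fixed_length(samples: list[Sample], cutoff_len: int) -> dict[str, list[list[int]]]:
    """Bring every sample to the same length and build a matching attention mask."""
    batch: dict[str, list[list[int]]] = {"input_ids": [], "labels": [], "attention_mask": []}
    for sample in samples:
        real_len = min(len(sample["input_ids"]), cutoff_len)
        batch["input_ids"].append(pad_or_truncate(sample["input_ids"], cutoff_len, PAD_TOKEN_ID))
        batch["labels"].append(pad_or_truncate(sample["labels"], cutoff_len, IGNORE_INDEX))
        # 1 marks real tokens, 0 marks padding.
        batch["attention_mask"].append([1] * real_len + [0] * (cutoff_len - real_len))
    return batch


batch = collate_fixed_length(
    [
        {"input_ids": [5, 6, 7], "labels": [5, 6, 7]},                    # shorter than cutoff
        {"input_ids": [1, 2, 3, 4, 5, 6], "labels": [1, 2, 3, 4, 5, 6]},  # longer than cutoff
    ],
    cutoff_len=4,
)
```

With a fixed cutoff length every micro-batch has the same shape, at the cost of some wasted padding on short samples.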
Usage
BatchGenerator is typically instantiated by BaseTrainer during initialization. It rarely needs to be created directly, except when building a custom training loop. Configure it through TrainingArguments fields such as micro_batch_size, global_batch_size, cutoff_len, batching_workers, and batching_strategy.
Code Reference
Source Location
- Repository: Hiyouga_LLaMA_Factory
- File: src/llamafactory/v1/core/utils/batching.py
- Lines: 1-244
Signature
def default_collate_fn(buffer: StatefulBuffer, batch_info: BatchInfo) -> list[BatchInput] | None: ...

class BatchGenerator(Iterator):
    def __init__(
        self,
        dataset: TorchDataset,
        renderer: Renderer,
        micro_batch_size: int = 1,
        global_batch_size: int | None = None,
        cutoff_len: int = 2048,
        batching_workers: int = 0,
        batching_strategy: BatchingStrategy = BatchingStrategy.NORMAL,
        pin_memory: bool = True,
        drop_last: bool = True,
    ) -> None: ...
    def __len__(self) -> int: ...
    def __iter__(self): ...
    def __next__(self): ...
    def state_dict(self) -> dict[str, Any]: ...
    def load_state_dict(self, state: dict[str, Any]) -> None: ...
    def set_epoch(self, epoch: int) -> None: ...
Import
from llamafactory.v1.core.utils.batching import BatchGenerator, default_collate_fn
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| dataset | TorchDataset | Yes | The training dataset (typically a DataEngine instance). |
| renderer | Renderer | Yes | The renderer for processing raw samples into tokenized model inputs. |
| micro_batch_size | int | No | Number of samples per micro-batch (default: 1). |
| global_batch_size | int or None | No | Total batch size across all DP ranks. Must be divisible by (dp_size * micro_batch_size). Default: dp_size * micro_batch_size. |
| cutoff_len | int | No | Maximum sequence length for padding/truncation (default: 2048). |
| batching_workers | int | No | Number of dataloader workers (default: 0). |
| batching_strategy | BatchingStrategy | No | Strategy for batch generation (default: NORMAL). |
| pin_memory | bool | No | Whether to pin memory for faster GPU transfer (default: True). |
| drop_last | bool | No | Whether to drop the last incomplete batch (default: True; only True is supported). |
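The relationship between the batch-size parameters can be made concrete with a short sketch. The function name resolve_batch_sizes is hypothetical, dp_size stands for the number of data-parallel ranks, and the formula for the number of micro-batches per step is inferred from the divisibility constraint in the table rather than copied from the source file.

```python
from __future__ import annotations


def resolve_batch_sizes(
    micro_batch_size: int,
    dp_size: int,
    global_batch_size: int | None = None,
) -> tuple[int, int]:
    """Return (global_batch_size, num_micro_batch) for one training step."""
    samples_per_pass = dp_size * micro_batch_size  # samples consumed by one micro-batch on every rank
    if global_batch_size is None:
        # Documented default: no gradient accumulation.
        global_batch_size = samples_per_pass
    if global_batch_size % samples_per_pass != 0:
        raise ValueError(
            f"global_batch_size={global_batch_size} must be divisible by "
            f"dp_size * micro_batch_size = {samples_per_pass}"
        )
    # Assumed: each rank accumulates this many micro-batches per step.
    return global_batch_size, global_batch_size // samples_per_pass


single_rank = resolve_batch_sizes(micro_batch_size=2, dp_size=1, global_batch_size=4)
default_case = resolve_batch_sizes(micro_batch_size=2, dp_size=4)
```

Under these assumptions, micro_batch_size=2 and global_batch_size=4 on a single rank give two micro-batches per step, while leaving global_batch_size unset gives one.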
Outputs
| Name | Type | Description |
|---|---|---|
| __next__ return | list[BatchInput] | A list of micro-batches (one per gradient accumulation step), each containing padded/truncated input_ids, labels, attention_mask, and loss_weights tensors. |
| state_dict return | dict[str, Any] | Serializable state for training resumption, including buffer contents and dataloader state. |
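The checkpoint-and-resume pattern behind state_dict / load_state_dict can be sketched with a toy iterator. ToyStatefulBatches is a hypothetical stand-in, not the real class: it tracks only a read position, whereas the state described above also covers buffer contents and dataloader state.

```python
from typing import Any


class ToyStatefulBatches:
    """Hypothetical stand-in that yields fixed-size batches and can resume."""

    def __init__(self, samples: list[int], batch_size: int) -> None:
        self.samples = samples
        self.batch_size = batch_size
        self.position = 0  # index of the next unread sample

    def __iter__(self) -> "ToyStatefulBatches":
        return self

    def __next__(self) -> list[int]:
        end = self.position + self.batch_size
        if end > len(self.samples):  # drop the last incomplete batch
            raise StopIteration
        batch = self.samples[self.position:end]
        self.position = end
        return batch

    def state_dict(self) -> dict[str, Any]:
        return {"position": self.position}

    def load_state_dict(self, state: dict[str, Any]) -> None:
        self.position = state["position"]


samples = list(range(10))
original = ToyStatefulBatches(samples, batch_size=2)
first = next(original)
checkpoint = original.state_dict()  # snapshot taken after one batch

# A fresh iterator restored from the snapshot continues where the first one stopped.
resumed = ToyStatefulBatches(samples, batch_size=2)
resumed.load_state_dict(checkpoint)
remaining_after_resume = list(resumed)
remaining_original = list(original)
```

The property to preserve is that a restored generator yields the same remaining batches as the uninterrupted one would have.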
Usage Examples
from llamafactory.v1.core.utils.batching import BatchGenerator
from llamafactory.v1.core.data_engine import DataEngine
from llamafactory.v1.core.model_engine import ModelEngine

# model_args is assumed to be defined elsewhere; it is not constructed in this snippet.
data_engine = DataEngine("data/v1_sft_demo.yaml")
model_engine = ModelEngine(model_args=model_args)

batch_generator = BatchGenerator(
    dataset=data_engine,
    renderer=model_engine.renderer,
    micro_batch_size=2,
    global_batch_size=4,
    cutoff_len=2048,
    batching_workers=0,
)

# Each iteration yields a list of micro-batches, one per gradient accumulation step.
for micro_batches in batch_generator:
    for micro_batch in micro_batches:
        print(micro_batch["input_ids"].shape)
    break  # inspect only the first step
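How a training loop might consume one list of micro-batches can be sketched in plain Python. This is an illustration of gradient accumulation in general, not the BaseTrainer implementation: the one-parameter model, the mse_grad helper, and the toy data are assumptions, and a real loop would use tensors and an optimizer.

```python
def mse_grad(w: float, xs: list[float], ys: list[float]) -> float:
    """Gradient of mean((w * x - y) ** 2) with respect to w."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


# Hypothetical stand-in for one item yielded by the generator:
# a list of equally sized micro-batches.
micro_batches = [
    {"x": [1.0, 2.0], "y": [2.0, 4.0]},
    {"x": [3.0, 4.0], "y": [6.0, 8.0]},
]

w, learning_rate = 0.0, 0.01
accumulated_grad = 0.0
for micro_batch in micro_batches:
    # Scale each micro-batch gradient so the sum is an average over the step.
    accumulated_grad += mse_grad(w, micro_batch["x"], micro_batch["y"]) / len(micro_batches)

# Reference: gradient computed on all four samples at once.
full_batch_grad = mse_grad(w, [1.0, 2.0, 3.0, 4.0], [2.0, 4.0, 6.0, 8.0])

# One parameter update per list of micro-batches.
w -= learning_rate * accumulated_grad
```

Because the micro-batches have equal size, the averaged accumulated gradient matches the full-batch gradient, which is why a step can be split into micro-batches without changing the update.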
Related Pages
- Hiyouga_LLaMA_Factory_V1_Base_Trainer - Creates and consumes BatchGenerator in the training loop.
- Hiyouga_LLaMA_Factory_V1_Rendering - The Renderer used as the collate function for tokenization.
- Hiyouga_LLaMA_Factory_V1_Data_Engine - The dataset that BatchGenerator wraps.