Implementation:Hiyouga LLaMA Factory Gradient Checkpointing
| Knowledge Sources | |
|---|---|
| Domains | Memory Optimization, Training Infrastructure |
| Last Updated | 2026-02-06 19:00 GMT |
Overview
Configures gradient checkpointing strategies and prepares models for training, including Unsloth-style CPU offloading, checkpointing restricted to trainable layers, layernorm upcasting, and float32 upcasting of the language model head output.
Description
This module provides several gradient checkpointing implementations and a model preparation routine:
- get_unsloth_gradient_checkpointing_func returns a checkpointing function backed by a custom torch.autograd.Function. During the forward pass it offloads each layer's input hidden states to CPU memory. During the backward pass it reloads them to the GPU and recomputes the layer. Keeping these saved activations in host RAM instead of on the GPU reduces VRAM usage, at the cost of host-device transfers.
- get_custom_gradient_checkpointing_func wraps any gradient checkpointing function so that it is applied only to layers that have trainable parameters. Frozen layers run a plain forward pass, so no compute is spent recomputing them.
- _gradient_checkpointing_enable is a monkey-patched replacement for the standard Hugging Face gradient_checkpointing_enable method. It supports both reentrant and non-reentrant checkpointing and selects the Unsloth-style function when use_unsloth_gc is set.
- prepare_model_for_training orchestrates the full preparation pipeline: it upcasts layernorm weights to float32 for numerical stability, enables gradient checkpointing with the selected strategy, disables the KV cache, and optionally upcasts the language model head output to float32.
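For intuition, the CPU-offloading idea can be sketched as a small custom autograd Function. This is a simplified illustration, not LLaMA Factory's code: the class name CPUOffloadCheckpoint is hypothetical, the sketch assumes the wrapped layer returns a single tensor, and it copies data synchronously. It also runs on a CPU-only machine, which makes it easy to check against an ordinary backward pass.

```python
import torch


class CPUOffloadCheckpoint(torch.autograd.Function):
    """Hypothetical sketch of CPU-offloaded activation checkpointing.

    forward: stash the layer input on the CPU and run the layer without
    keeping any intermediate activations.
    backward: move the stash back to its original device, recompute the
    layer with autograd enabled, and backpropagate the incoming gradient.
    Assumes the wrapped layer returns a single tensor.
    """

    @staticmethod
    def forward(ctx, forward_function, hidden_states, *args):
        ctx.origin_device = hidden_states.device
        ctx.forward_function = forward_function
        ctx.args = args
        # Only the (offloaded) layer input is kept; activations are recomputed later.
        ctx.save_for_backward(hidden_states.to("cpu"))
        with torch.no_grad():
            return forward_function(hidden_states, *args)

    @staticmethod
    def backward(ctx, grad_output):
        (saved,) = ctx.saved_tensors
        hidden_states = saved.to(ctx.origin_device).detach().requires_grad_(True)
        with torch.enable_grad():
            outputs = ctx.forward_function(hidden_states, *ctx.args)
        # Accumulates parameter gradients and fills hidden_states.grad.
        torch.autograd.backward(outputs, grad_output)
        # One entry per forward input: forward_function, hidden_states, *args.
        return (None, hidden_states.grad) + (None,) * len(ctx.args)


# Compare against an ordinary forward/backward pass on a toy layer.
torch.manual_seed(0)
layer = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.GELU())
x = torch.randn(4, 8, requires_grad=True)

CPUOffloadCheckpoint.apply(layer, x).sum().backward()
ckpt_input_grad = x.grad.clone()
ckpt_weight_grad = layer[0].weight.grad.clone()

x.grad = None
layer.zero_grad(set_to_none=True)
layer(x).sum().backward()
```

The sketch copies synchronously for simplicity; overlapping the copies with computation would additionally require pinned host memory and non-blocking transfers.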
Usage
The model patcher calls prepare_model_for_training automatically when a model is loaded for training; call it directly only when preparing a model outside that path. Use the individual gradient checkpointing functions when implementing custom training loops or extending the framework.
Code Reference
Source Location
- Repository: Hiyouga_LLaMA_Factory
- File: src/llamafactory/model/model_utils/checkpointing.py
- Lines: 1-184
Signature
```python
def get_unsloth_gradient_checkpointing_func() -> Callable:
    ...

def get_custom_gradient_checkpointing_func(
    gradient_checkpointing_func: Callable,
) -> Callable:
    ...

def _gradient_checkpointing_enable(
    self: "PreTrainedModel",
    gradient_checkpointing_kwargs: Optional[dict[str, Any]] = None,
    use_unsloth_gc: bool = False,
) -> None:
    ...

def prepare_model_for_training(
    model: "PreTrainedModel",
    model_args: "ModelArguments",
) -> None:
    ...
```
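_gradient_checkpointing_enable takes self as its first parameter because it is meant to be bound onto a model instance at runtime. A common way to install such a replacement on a single instance is functools.partial plus types.MethodType, sketched below in plain Python. ToyModel and patched_gradient_checkpointing_enable are hypothetical names, and the patched body only records its arguments instead of enabling checkpointing.

```python
from functools import partial
from types import MethodType
from typing import Any, Optional


class ToyModel:
    """Hypothetical stand-in for a PreTrainedModel."""

    def __init__(self) -> None:
        self.gc_config: Optional[dict[str, Any]] = None

    def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None) -> None:
        # Stock behaviour: remember that the original method ran.
        self.gc_config = {"stock": True}


def patched_gradient_checkpointing_enable(
    self: ToyModel,
    gradient_checkpointing_kwargs: Optional[dict[str, Any]] = None,
    use_unsloth_gc: bool = False,
) -> None:
    # This sketch defaults to reentrant mode when no kwargs are given (an assumption).
    if gradient_checkpointing_kwargs is None:
        gradient_checkpointing_kwargs = {"use_reentrant": True}
    self.gc_config = {"unsloth": use_unsloth_gc, **gradient_checkpointing_kwargs}


model = ToyModel()
# Pre-fill the extra option with partial, then bind the result to this one
# instance; other ToyModel instances keep the stock method.
model.gradient_checkpointing_enable = MethodType(
    partial(patched_gradient_checkpointing_enable, use_unsloth_gc=True), model
)
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})

other = ToyModel()
other.gradient_checkpointing_enable()
```

Binding per instance keeps the patch local: code that later calls model.gradient_checkpointing_enable(...) with the standard keyword arguments picks up the replacement without any change to the model class.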
Import
from llamafactory.model.model_utils.checkpointing import prepare_model_for_training
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| model | PreTrainedModel | Yes | The pretrained model to prepare for training |
| model_args | ModelArguments | Yes | Model arguments controlling checkpointing behavior |
| model_args.upcast_layernorm | bool | No | Whether to cast layernorm weights to float32 |
| model_args.disable_gradient_checkpointing | bool | No | Whether to skip gradient checkpointing entirely |
| model_args.use_unsloth_gc | bool | No | Whether to use Unsloth-style CPU offloading for gradient checkpointing |
| model_args.use_reentrant_gc | bool | No | Whether to use reentrant gradient checkpointing (auto-disabled for FSDP2) |
| model_args.upcast_lmhead_output | bool | No | Whether to cast language model head output to float32 |
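The use_reentrant_gc flag chooses between the two modes of torch.utils.checkpoint.checkpoint. One practical difference shows up when a checkpointed block's input does not require gradients, for example hidden states coming out of a frozen embedding layer. The PyTorch-only sketch below uses a hypothetical toy layer to illustrate this; it does not call LLaMA Factory.

```python
import warnings

import torch
from torch.utils.checkpoint import checkpoint

torch.manual_seed(0)
layer = torch.nn.Linear(8, 8)        # trainable parameters
frozen_hidden = torch.randn(2, 8)    # e.g. output of a frozen embedding; no grad

with warnings.catch_warnings():
    # Reentrant mode warns that none of the inputs require grad.
    warnings.simplefilter("ignore")
    reentrant_out = checkpoint(layer, frozen_hidden, use_reentrant=True)

non_reentrant_out = checkpoint(layer, frozen_hidden, use_reentrant=False)

# Reentrant checkpointing tracks gradients only through its tensor inputs, so
# here the layer's weights are cut off from the graph; non-reentrant mode is not.
print(reentrant_out.requires_grad, non_reentrant_out.requires_grad)

# Workaround for reentrant mode: make the hidden states require grad.
fixed_out = checkpoint(layer, frozen_hidden.clone().requires_grad_(True), use_reentrant=True)
fixed_out.sum().backward()
```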
Outputs
| Name | Type | Description |
|---|---|---|
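| (side effect) | None | Modifies the model in place, depending on model_args: upcasts layernorm weights to float32, enables gradient checkpointing, disables the KV cache, and casts the language model head output (not its weights) to float32 |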
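How these side effects combine can be illustrated on a toy model. This is a simplified sketch under stated assumptions: ToyLM, prepare_toy_model, and fp32_output_hook are hypothetical names; layer norms are located with isinstance(module, torch.nn.LayerNorm) rather than whatever matching the real module uses; and gradient checkpointing is represented by a placeholder flag.

```python
from types import SimpleNamespace

import torch


def fp32_output_hook(module, inputs, output):
    # A forward hook that returns a value replaces the module's output.
    return output.to(torch.float32)


class ToyLM(torch.nn.Module):
    """Hypothetical stand-in for a causal LM; forward is omitted because only
    parameters, config, and hooks matter for this sketch."""

    def __init__(self) -> None:
        super().__init__()
        self.norm = torch.nn.LayerNorm(8)
        self.lm_head = torch.nn.Linear(8, 16, bias=False)
        self.config = SimpleNamespace(use_cache=True)
        self.gradient_checkpointing = False


def prepare_toy_model(model: torch.nn.Module, args: SimpleNamespace) -> None:
    """Sketch of the flag-driven steps; not LLaMA Factory's implementation."""
    if args.upcast_layernorm:
        # Keep normalization weights in float32 for numerical stability.
        for module in model.modules():
            if isinstance(module, torch.nn.LayerNorm):
                module.float()
    if not args.disable_gradient_checkpointing:
        model.gradient_checkpointing = True  # placeholder for the real enable call
        # Hugging Face models treat use_cache=True as incompatible with checkpointing.
        model.config.use_cache = False
    if args.upcast_lmhead_output:
        # Only the lm_head output is cast; its weights keep their dtype.
        model.lm_head.register_forward_hook(fp32_output_hook)


model = ToyLM().to(torch.bfloat16)
args = SimpleNamespace(
    upcast_layernorm=True,
    disable_gradient_checkpointing=False,
    upcast_lmhead_output=True,
)
prepare_toy_model(model, args)
logits = model.lm_head(torch.randn(2, 8, dtype=torch.bfloat16))
```

Casting only the head output keeps the large output projection in low precision while giving the loss computation float32 logits.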
Usage Examples
```python
from functools import partial

from torch.utils.checkpoint import checkpoint

from llamafactory.model.model_utils.checkpointing import (
    get_custom_gradient_checkpointing_func,
    get_unsloth_gradient_checkpointing_func,
    prepare_model_for_training,
)

# Standard preparation with gradient checkpointing
# (`model` is a loaded PreTrainedModel, `model_args` a ModelArguments instance)
prepare_model_for_training(model, model_args)

# The Unsloth gradient checkpointing function can be used standalone:
# it offloads hidden states to CPU during forward and reloads them during backward
gc_func = get_unsloth_gradient_checkpointing_func()

# Custom gradient checkpointing that skips frozen layers:
# custom_func only checkpoints layers that have trainable parameters
base_func = partial(checkpoint, use_reentrant=False)
custom_func = get_custom_gradient_checkpointing_func(base_func)
```
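The skip-frozen-layers behaviour can be sketched in plain PyTorch. This is an illustrative reimplementation, not the library's code: checkpoint_only_trainable and counting_checkpoint are hypothetical names, and the sketch assumes the checkpointing function is called with a bound module method (or a functools.partial wrapping one) followed by the layer inputs.

```python
from functools import partial

import torch
from torch.utils.checkpoint import checkpoint


def checkpoint_only_trainable(checkpoint_fn):
    """Hypothetical sketch: route a layer through checkpoint_fn only if it has
    trainable parameters. Assumes checkpoint_fn(func, *args) is called with a
    bound module method (or a functools.partial of one) as func."""

    def wrapper(func, *args, **kwargs):
        bound = func.func if isinstance(func, partial) else func
        module = bound.__self__
        if not any(p.requires_grad for p in module.parameters()):
            # Frozen layer: plain forward pass, nothing will be recomputed.
            return func(*args, **kwargs)
        # Make the first floating-point tensor (assumed to be the hidden states)
        # require grad, which reentrant checkpointing needs to reach the weights.
        for arg in args:
            if torch.is_tensor(arg) and torch.is_floating_point(arg):
                arg.requires_grad_(True)
                break
        return checkpoint_fn(func, *args, **kwargs)

    return wrapper


# Record which modules actually go through checkpointing.
checkpointed = []


def counting_checkpoint(func, *args, **kwargs):
    checkpointed.append(func.__self__)
    return checkpoint(func, *args, use_reentrant=False, **kwargs)


torch.manual_seed(0)
frozen = torch.nn.Linear(8, 8).requires_grad_(False)
trainable = torch.nn.Linear(8, 8)
gc_func = checkpoint_only_trainable(counting_checkpoint)

hidden = torch.randn(2, 8)
for layer in (frozen, trainable):
    hidden = gc_func(layer.__call__, hidden)
hidden.sum().backward()
```

Because the frozen layer bypasses checkpoint_fn entirely, the counter records only the trainable layer, and only the trainable layer's weights receive a gradient.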
Related Pages
- Hiyouga_LLaMA_Factory_Model_Loader - Model loader that triggers training preparation
- Hiyouga_LLaMA_Factory_Training_Args - Training arguments that control checkpointing behavior