

Implementation:Hiyouga LLaMA Factory Gradient Checkpointing

From Leeroopedia


Knowledge Sources
Domains: Memory Optimization, Training Infrastructure
Last Updated: 2026-02-06 19:00 GMT

Overview

Configures gradient checkpointing strategies and prepares models for training, including Unsloth-style CPU offloading, selective layer checkpointing, layernorm upcasting, and output head precision management.

Description

This module provides several gradient checkpointing implementations and model-preparation utilities. get_unsloth_gradient_checkpointing_func returns a custom torch.autograd.Function that offloads each checkpointed layer's input hidden states to CPU memory during the forward pass and copies them back during the backward pass, so those saved activations do not occupy VRAM between the two passes. get_custom_gradient_checkpointing_func wraps any gradient checkpointing function so that it is applied only to layers with trainable parameters; frozen layers are run without checkpointing. _gradient_checkpointing_enable is a monkey-patched replacement for the standard Hugging Face gradient_checkpointing_enable method and supports both reentrant and non-reentrant modes. prepare_model_for_training orchestrates the full preparation pipeline: it upcasts layernorm weights to float32 for numerical stability, enables gradient checkpointing with the selected strategy, disables the KV cache, and optionally upcasts the language model head output to float32.
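The CPU-offloading strategy can be illustrated with a minimal custom autograd function. This is a hypothetical, simplified sketch, not LLaMA-Factory's code: the class name OffloadedCheckpoint is ours, and it assumes the wrapped layer takes the hidden states as its first argument and returns a single tensor. During the forward pass it keeps only a CPU copy of the layer input and runs the layer without recording a graph; during the backward pass it moves the copy back to the original device, recomputes the layer with autograd enabled, and back-propagates through it.

```python
from typing import Any, Callable

import torch


class OffloadedCheckpoint(torch.autograd.Function):
    """Hypothetical sketch of Unsloth-style checkpointing with CPU offloading.

    Assumes forward_function(hidden_states, *args) returns a single tensor.
    """

    @staticmethod
    def forward(ctx, forward_function: Callable, hidden_states: torch.Tensor, *args: Any) -> torch.Tensor:
        # Keep only a CPU copy of the layer input; no intermediate activations are stored.
        saved_hidden_states = hidden_states.to("cpu")
        with torch.no_grad():
            output = forward_function(hidden_states, *args)
        ctx.save_for_backward(saved_hidden_states)
        ctx.forward_function = forward_function
        ctx.device = hidden_states.device
        ctx.args = args
        return output

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        (saved_hidden_states,) = ctx.saved_tensors
        # Reload the input onto its original device and recompute the layer with autograd on.
        hidden_states = saved_hidden_states.to(ctx.device).detach()
        hidden_states.requires_grad_(True)
        with torch.enable_grad():
            output = ctx.forward_function(hidden_states, *ctx.args)
        # Accumulates the layer's parameter gradients and fills hidden_states.grad.
        torch.autograd.backward(output, grad_output)
        # One gradient slot per forward input: None for the callable and the extra args.
        return (None, hidden_states.grad) + (None,) * len(ctx.args)


layer = torch.nn.Linear(8, 8)
x = torch.randn(4, 8, requires_grad=True)
y = OffloadedCheckpoint.apply(layer, x)  # forward keeps only an offloaded copy of x
y.sum().backward()                       # backward reloads x and recomputes layer(x)
```

Because the recomputation calls torch.autograd.backward directly, the layer's parameter gradients accumulate exactly as in an ordinary backward pass; the trade-off is one extra forward computation per checkpointed layer plus the host-device copies.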

Usage

Use prepare_model_for_training after loading a model to configure it for training. The function is called automatically by the model patcher. Use the individual gradient checkpointing functions when implementing custom training loops or extending the framework.

Code Reference

Source Location

Signature

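from typing import TYPE_CHECKING, Any, Callable, Optional

if TYPE_CHECKING:
    from transformers import PreTrainedModel

    from llamafactory.hparams import ModelArguments

def get_unsloth_gradient_checkpointing_func() -> Callable:
    ...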

def get_custom_gradient_checkpointing_func(
    gradient_checkpointing_func: Callable,
) -> Callable:
    ...

def _gradient_checkpointing_enable(
    self: "PreTrainedModel",
    gradient_checkpointing_kwargs: Optional[dict[str, Any]] = None,
    use_unsloth_gc: bool = False,
) -> None:
    ...

def prepare_model_for_training(
    model: "PreTrainedModel",
    model_args: "ModelArguments",
) -> None:
    ...

Import

from llamafactory.model.model_utils.checkpointing import prepare_model_for_training

I/O Contract

Inputs

Name | Type | Required | Description
--- | --- | --- | ---
model | PreTrainedModel | Yes | The pretrained model to prepare for training
model_args | ModelArguments | Yes | Model arguments controlling checkpointing behavior
model_args.upcast_layernorm | bool | No | Whether to cast layernorm weights to float32
model_args.disable_gradient_checkpointing | bool | No | Whether to skip gradient checkpointing entirely
model_args.use_unsloth_gc | bool | No | Whether to use Unsloth-style CPU offloading for gradient checkpointing
model_args.use_reentrant_gc | bool | No | Whether to use reentrant gradient checkpointing (auto-disabled for FSDP2)
model_args.upcast_lmhead_output | bool | No | Whether to cast the language model head output to float32
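The use_reentrant_gc option corresponds, we assume, to the use_reentrant argument of PyTorch's torch.utils.checkpoint.checkpoint. The following framework-independent sketch shows why the choice matters when a checkpointed layer receives inputs that do not require gradients, as happens when the embedding layer is frozen: the reentrant variant returns an output that is cut off from the layer's parameters, while the non-reentrant variant still propagates gradients to them.

```python
import warnings

import torch
from torch.utils.checkpoint import checkpoint

layer = torch.nn.Linear(8, 8)      # trainable layer being checkpointed
frozen_hidden = torch.randn(4, 8)  # input that does NOT require grad (e.g. from a frozen embedding)

with warnings.catch_warnings():
    # Reentrant checkpointing warns that none of its inputs require grad.
    warnings.simplefilter("ignore")
    out_reentrant = checkpoint(layer, frozen_hidden, use_reentrant=True)

out_non_reentrant = checkpoint(layer, frozen_hidden, use_reentrant=False)
out_non_reentrant.sum().backward()

print(out_reentrant.requires_grad)    # False: no graph reaches layer.weight
print(layer.weight.grad is not None)  # True: non-reentrant still reaches the parameters
```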

Outputs

Name | Type | Description
--- | --- | ---
(side effect) | None | Modifies the model in place: enables gradient checkpointing, casts layernorm weights and the LM head output to float32 when requested, and disables the KV cache
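The two casting side effects can be sketched in plain PyTorch. The helper upcast_for_training_sketch and the toy TinyLM module are hypothetical. The sketch assumes layernorm weights are the 1-D parameters whose names contain "norm" (the real module may match a different set of names) and implements the output upcast as a forward hook on the output head, one plausible way to obtain float32 logits while leaving the head's own weights in their original dtype.

```python
import torch
from torch import nn


def upcast_for_training_sketch(
    model: nn.Module,
    lm_head: nn.Module,
    upcast_layernorm: bool,
    upcast_lmhead_output: bool,
) -> None:
    """Hypothetical sketch of the in-place casting side effects (not LLaMA-Factory's code)."""
    if upcast_layernorm:
        for name, param in model.named_parameters():
            # Assumption: layernorm weights are 1-D parameters whose name contains "norm".
            if param.ndim == 1 and "norm" in name:
                param.data = param.data.to(torch.float32)

    if upcast_lmhead_output:

        def fp32_output_hook(module: nn.Module, inputs: tuple, output: torch.Tensor) -> torch.Tensor:
            # A forward hook that returns a value replaces the module's output.
            return output.to(torch.float32)

        lm_head.register_forward_hook(fp32_output_hook)


class TinyLM(nn.Module):
    """Toy container with a norm and an output head (forward pass omitted)."""

    def __init__(self) -> None:
        super().__init__()
        self.norm = nn.LayerNorm(8)
        self.lm_head = nn.Linear(8, 16, bias=False)


model = TinyLM().to(torch.bfloat16)
upcast_for_training_sketch(model, model.lm_head, upcast_layernorm=True, upcast_lmhead_output=True)
logits = model.lm_head(torch.randn(2, 8, dtype=torch.bfloat16))
print(model.norm.weight.dtype, model.lm_head.weight.dtype, logits.dtype)
# torch.float32 torch.bfloat16 torch.float32
```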

Usage Examples

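from llamafactory.model.model_utils.checkpointing import prepare_model_for_training

# Standard preparation with gradient checkpointing.
# Assumes `model` (a PreTrainedModel) and `model_args` (a ModelArguments) already exist,
# e.g. as produced by the framework's model-loading code.
prepare_model_for_training(model, model_args)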

# The Unsloth gradient checkpointing function can be used standalone
from llamafactory.model.model_utils.checkpointing import get_unsloth_gradient_checkpointing_func

gc_func = get_unsloth_gradient_checkpointing_func()
# gc_func offloads hidden states to CPU during forward, reloads during backward

# Custom gradient checkpointing that skips frozen layers
from llamafactory.model.model_utils.checkpointing import get_custom_gradient_checkpointing_func
from functools import partial
from torch.utils.checkpoint import checkpoint

base_func = partial(checkpoint, use_reentrant=False)
custom_func = get_custom_gradient_checkpointing_func(base_func)
# custom_func only checkpoints layers that have trainable parameters
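The skip-frozen-layers wrapper can be sketched as follows. The name skip_frozen_checkpointing is hypothetical, and the sketch assumes, as in the usage example above, that the wrapped function follows the calling convention of torch.utils.checkpoint.checkpoint(function, *args, **kwargs), with function being a bound method such as layer.__call__ so that the owning module is reachable through __self__.

```python
from functools import partial
from typing import Any, Callable

import torch
from torch.utils.checkpoint import checkpoint


def skip_frozen_checkpointing(gradient_checkpointing_func: Callable) -> Callable:
    """Hypothetical sketch: checkpoint only layers that own trainable parameters."""

    def wrapped(func: Callable, *args: Any, **kwargs: Any) -> Any:
        # Assumption: `func` is a bound method such as `layer.__call__`.
        module: torch.nn.Module = func.__self__
        if not any(p.requires_grad for p in module.parameters()):
            # Frozen layer: no parameter gradients to recompute for, so just run it.
            return func(*args, **kwargs)
        for arg in args:
            if torch.is_tensor(arg) and torch.is_floating_point(arg):
                # Sketch's own choice: mark the hidden states as requiring grad so that
                # reentrant checkpointing still back-propagates into this layer.
                if not arg.requires_grad:
                    arg.requires_grad_(True)
                break
        return gradient_checkpointing_func(func, *args, **kwargs)

    return wrapped


trainable = torch.nn.Linear(8, 8)
frozen = torch.nn.Linear(8, 8).requires_grad_(False)
ckpt = skip_frozen_checkpointing(partial(checkpoint, use_reentrant=True))

hidden = torch.randn(4, 8)                 # e.g. output of a frozen embedding
hidden = ckpt(frozen.__call__, hidden)     # frozen: runs directly, no checkpointing
hidden = ckpt(trainable.__call__, hidden)  # trainable: checkpointed
hidden.sum().backward()
```

Marking the first floating-point tensor argument as requiring grad matters here because the input to the trainable layer comes from frozen layers; without it, the reentrant call would return an output that does not require grad, and the trainable layer would receive no gradients.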
