

Principle:Hiyouga LLaMA Factory Gradient Checkpointing Theory

From Leeroopedia


Knowledge Sources
Domains: Deep Learning, Memory Optimization
Last Updated: 2026-02-06 19:00 GMT

Overview

Gradient checkpointing is a memory-efficient technique that trades additional computation for reduced GPU VRAM consumption during backpropagation by selectively discarding and recomputing intermediate activations.

Description

During the forward pass of a deep neural network, each layer produces intermediate activations that must be stored for the backward pass. For large transformer models with dozens or hundreds of layers, these stored activations often dominate GPU memory consumption. Gradient checkpointing addresses this by not retaining activations for every layer during the forward pass. Instead, only activations at designated checkpoint boundaries are kept. During the backward pass, any discarded activations are recomputed on the fly from the nearest stored checkpoint.
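The checkpoint-boundary idea can be sketched with PyTorch's built-in torch.utils.checkpoint.checkpoint_sequential on a toy 8-layer network. The model below is an illustrative stand-in, not LLaMA-Factory code. It splits the stack into 4 segments, keeps only the activations at segment boundaries for the checkpointed segments, and should yield the same input gradients as ordinary backpropagation.

```python
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint_sequential

torch.manual_seed(0)
# Toy stand-in for a transformer's layer stack (assumption): 8 small blocks.
model = nn.Sequential(*[nn.Sequential(nn.Linear(16, 16), nn.Tanh()) for _ in range(8)])
x = torch.randn(4, 16, requires_grad=True)

# Split the 8 blocks into 4 segments. PyTorch checkpoints the first 3 segments
# (only their inputs are kept; inner activations are recomputed in backward)
# and runs the final segment normally.
out = checkpoint_sequential(model, 4, x, use_reentrant=False)
out.sum().backward()
ckpt_grad = x.grad.clone()

# Reference: ordinary backpropagation that stores every activation.
x.grad = None
model(x).sum().backward()
assert torch.allclose(ckpt_grad, x.grad)  # same gradients, smaller activation footprint
```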

The original paper by Chen et al. (2016), "Training Deep Nets with Sublinear Memory Cost", demonstrated that for a network with n layers, standard backpropagation requires O(n) memory for activations, whereas gradient checkpointing reduces this to O(√n) memory at the cost of one additional forward pass through each segment.

In LLaMA-Factory, gradient checkpointing is implemented with two strategies:

  • Standard checkpointing via PyTorch's torch.utils.checkpoint.checkpoint, which runs each checkpointed segment's forward pass without storing its intermediate activations and recomputes them, with gradients enabled, during the backward pass.
  • Unsloth gradient checkpointing, which additionally offloads hidden states to CPU RAM during the forward pass and retrieves them during backward, further reducing GPU VRAM at the cost of CPU-GPU data transfer.

A key enhancement in LLaMA-Factory is the custom gradient checkpointing function that only applies checkpointing to layers containing trainable parameters. This is critical for parameter-efficient fine-tuning (e.g., LoRA), where most layers are frozen; skipping checkpoint overhead on frozen layers avoids unnecessary recomputation and maintains efficiency.

Usage

Gradient checkpointing should be enabled when training large transformer models that exceed available GPU memory with full activation storage. It is particularly effective when:

  • Fine-tuning models with billions of parameters on consumer or mid-range GPUs.
  • Using long context lengths, where activation memory grows at least linearly with sequence length.
  • Combining with LoRA or other PEFT methods, where the custom gradient checkpointing function avoids wasted recomputation on frozen layers.

When gradient checkpointing is combined with FSDP2, use_reentrant is automatically set to False for compatibility with the distributed framework.

Theoretical Basis

The core insight from the sublinear memory cost paper is that memory and computation can be traded off systematically. For a sequential network of n layers:

Standard backpropagation stores all n activations:

$$\text{Memory} = O(n), \qquad \text{Computation} = O(n)$$

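Gradient checkpointing divides the network into k segments and stores only the k segment-boundary activations (checkpoints). During the backward pass, segments are recomputed one at a time from their checkpoints, which temporarily requires the n/k activations of a single segment:

$$\text{Memory} = O\left(k + \frac{n}{k}\right)$$

Choosing $k = \sqrt{n}$ minimizes this bound and yields the optimal trade-off:

$$\text{Memory} = O(\sqrt{n}), \qquad \text{Computation} = O(n) \quad \text{(with one extra forward pass per segment)}$$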

The total computation cost increases by at most one full forward pass, making the overhead modest relative to the memory savings.
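To make the k + n/k bound concrete, the small sketch below counts activation memory in units of "one layer's activations" (ignoring constant factors) and brute-forces the best segment count. The helper names are hypothetical. For n = 64 layers it picks k = 8 and needs 16 units instead of the 64 that standard backpropagation stores.

```python
import math


def checkpoint_memory(n_layers: int, k_segments: int) -> float:
    """k stored checkpoints plus the n/k activations of the one segment
    being recomputed during the backward pass (units: one layer's activations)."""
    return k_segments + n_layers / k_segments


def best_segment_count(n_layers: int) -> int:
    # Brute-force search over k; the analysis predicts k close to sqrt(n).
    return min(range(1, n_layers + 1), key=lambda k: checkpoint_memory(n_layers, k))


for n in (16, 64, 100):
    k = best_segment_count(n)
    print(f"n={n}: best k={k} (sqrt(n)={math.sqrt(n):.0f}), "
          f"memory={checkpoint_memory(n, k):.0f} units vs {n} without checkpointing")
```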

The Unsloth variant extends this by offloading checkpoint tensors to CPU memory via non-blocking transfers:

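```python
import torch

# Excerpt from a custom torch.autograd.Function: forward_function is the wrapped
# layer, args its extra inputs, grad_output the gradient received by backward().

# Forward: save hidden_states to CPU, then run the layer without building a graph
saved_hidden_states = hidden_states.to("cpu", non_blocking=True)
with torch.no_grad():
    outputs = forward_function(hidden_states, *args)

# Backward: reload from CPU to GPU and recompute with autograd enabled
hidden_states = saved_hidden_states.to("cuda", non_blocking=True).detach()
hidden_states.requires_grad_(True)
with torch.enable_grad():
    outputs = forward_function(hidden_states, *args)
torch.autograd.backward(outputs, grad_output)  # fills hidden_states.grad and parameter grads
```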

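This approach exploits the CPU-GPU memory hierarchy: the checkpointed hidden states wait in CPU RAM, so GPU VRAM only needs to hold the activations of the layer currently being computed or recomputed. The non_blocking=True flag lets the host enqueue each copy asynchronously instead of waiting for it to finish; whether a transfer actually overlaps with GPU computation also depends on pinned host memory and CUDA stream scheduling.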
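Assembled into a complete torch.autograd.Function, the offloading pattern can be sketched as follows. This is a minimal illustration, not the Unsloth or LLaMA-Factory source: the class name OffloadedCheckpoint is hypothetical, the wrapped layer is assumed to return a single tensor, and instead of hard-coding "cuda" it reloads to whatever device the input came from, so it also runs on CPU-only machines (where the offload is a no-op).

```python
import torch


class OffloadedCheckpoint(torch.autograd.Function):
    """Hypothetical sketch: checkpoint one layer, parking its input in CPU RAM."""

    @staticmethod
    def forward(ctx, forward_function, hidden_states, *args):
        # Remember the original device so backward can move the input back.
        ctx.device = hidden_states.device
        # Offload the layer input to host memory (asynchronous copy request).
        saved_hidden_states = hidden_states.to("cpu", non_blocking=True)
        ctx.save_for_backward(saved_hidden_states)
        ctx.forward_function = forward_function
        ctx.args = args
        # Run the layer without recording a graph, so no intermediates are kept.
        with torch.no_grad():
            return forward_function(hidden_states, *args)

    @staticmethod
    def backward(ctx, grad_output):
        (saved_hidden_states,) = ctx.saved_tensors
        # Reload the input onto the original device as a fresh leaf tensor.
        hidden_states = saved_hidden_states.to(ctx.device, non_blocking=True).detach()
        hidden_states.requires_grad_(True)
        # Recompute the layer with autograd enabled, then backpropagate through it;
        # this also accumulates gradients into the layer's parameters.
        with torch.enable_grad():
            outputs = ctx.forward_function(hidden_states, *ctx.args)
        torch.autograd.backward(outputs, grad_output)
        # One gradient slot per forward() input: forward_function, hidden_states, *args.
        return (None, hidden_states.grad) + (None,) * len(ctx.args)


# Usage with a toy layer. The input must require grad: otherwise the output does
# not require grad, backward() never runs, and parameter gradients are lost.
torch.manual_seed(0)
layer = torch.nn.Sequential(torch.nn.Linear(4, 8), torch.nn.GELU(), torch.nn.Linear(8, 4))
x = torch.randn(3, 4, requires_grad=True)
out = OffloadedCheckpoint.apply(layer, x)
out.sum().backward()
```

Storing only the layer input and rerunning the whole layer in backward is what keeps VRAM usage low; the cost is one extra forward pass per layer plus two host-device copies.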

The selective checkpointing applied in LLaMA-Factory further optimizes this by checking param.requires_grad for each module:

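```python
# Inside the wrapper around the checkpointing function: module is the layer
# whose forward call func is about to run.
has_grad = any(param.requires_grad for param in module.parameters())
if has_grad:
    return gradient_checkpointing_func(func, *args, **kwargs)
else:
    return func(*args, **kwargs)  # skip checkpointing for frozen layers
```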

This ensures that layers without any trainable parameters (common during LoRA fine-tuning) execute their forward pass normally without the overhead of checkpoint wrapping, while layers with trainable parameters still benefit from memory savings.
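A self-contained sketch of such a wrapper, built around torch.utils.checkpoint.checkpoint with use_reentrant=False, is shown below. The names make_selective_checkpoint_func and CountingLayer are hypothetical, and the wrapper assumes func is a bound module method or a functools.partial of one. Counting forward calls shows that the trainable layer is recomputed during backward while the frozen layer runs only once.

```python
import functools

import torch
from torch.utils.checkpoint import checkpoint


def make_selective_checkpoint_func(gradient_checkpointing_func):
    """Hypothetical wrapper: checkpoint only modules that own trainable parameters."""

    def selective_func(func, *args, **kwargs):
        # func is assumed to be a bound method (or a partial of one) of a module.
        target = func.func if isinstance(func, functools.partial) else func
        module = target.__self__
        if any(param.requires_grad for param in module.parameters()):
            return gradient_checkpointing_func(func, *args, **kwargs)
        return func(*args, **kwargs)  # frozen layer: plain forward, no recomputation

    return selective_func


class CountingLayer(torch.nn.Module):
    """Toy layer that counts how often its forward pass runs."""

    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)
        self.calls = 0

    def forward(self, x):
        self.calls += 1
        return torch.tanh(self.linear(x))


ckpt = make_selective_checkpoint_func(functools.partial(checkpoint, use_reentrant=False))
frozen, trainable = CountingLayer(), CountingLayer()
frozen.requires_grad_(False)  # simulates a layer left untouched by LoRA

x = torch.randn(2, 4, requires_grad=True)
h = ckpt(frozen.__call__, x)     # runs once, activations kept normally
y = ckpt(trainable.__call__, h)  # checkpointed: recomputed during backward
y.sum().backward()
print(f"frozen forward calls: {frozen.calls}, trainable forward calls: {trainable.calls}")
```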
