Heuristic:Hiyouga LLaMA Factory Gradient Checkpointing Memory Optimization
| Knowledge Sources | |
|---|---|
| Domains | Memory_Optimization, LLMs |
| Last Updated | 2026-02-06 20:00 GMT |
Overview
Memory optimization technique using gradient checkpointing with optional Unsloth CPU offloading to drastically reduce VRAM usage during training.
Description
LLaMA Factory implements three gradient checkpointing strategies: standard reentrant (default), non-reentrant, and Unsloth-style CPU offloading. Standard gradient checkpointing trades compute for memory by recomputing activations during the backward pass. The Unsloth variant goes further by offloading saved hidden states to CPU RAM during the forward pass and moving them back to GPU during backward, achieving even greater VRAM savings. Additionally, a custom wrapper ensures gradient checkpointing only applies to layers with trainable parameters, which is critical for efficient LoRA and freeze fine-tuning.
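The compute-for-memory trade can be illustrated without any deep learning framework. The sketch below is a toy, not LLaMA Factory code: it differentiates a chain of scalar `tanh` layers twice, once storing every layer input and once storing only one input per segment and recomputing the rest during the backward pass. The function names, the `segment` parameter, and the scalar model are all assumptions made for illustration.

```python
import math


def layer(w, x):
    """One toy 'layer': y = tanh(w * x)."""
    return math.tanh(w * x)


def layer_grad(w, x):
    """dy/dx for the toy layer, computed from the layer input x."""
    y = math.tanh(w * x)
    return w * (1.0 - y * y)


def grad_full(weights, x0):
    """Baseline: keep every layer input alive until the backward pass."""
    inputs, x = [], x0
    for w in weights:
        inputs.append(x)
        x = layer(w, x)
    g = 1.0
    for w, xin in zip(reversed(weights), reversed(inputs)):
        g *= layer_grad(w, xin)
    return g, len(inputs)


def grad_checkpointed(weights, x0, segment):
    """Checkpointed: keep one input per segment, recompute the others in backward."""
    starts = list(range(0, len(weights), segment))
    checkpoints, x = [], x0
    for start in starts:
        checkpoints.append(x)
        for w in weights[start:start + segment]:
            x = layer(w, x)  # intermediate inputs are not stored
    g, peak = 1.0, len(checkpoints)
    for start, xin in zip(reversed(starts), reversed(checkpoints)):
        seg = weights[start:start + segment]
        inputs, x = [], xin
        for w in seg:  # extra forward work: rebuild this segment's inputs
            inputs.append(x)
            x = layer(w, x)
        peak = max(peak, len(checkpoints) + len(inputs))
        for w, xi in zip(reversed(seg), reversed(inputs)):
            g *= layer_grad(w, xi)
    return g, peak


weights = [0.5 + 0.1 * i for i in range(16)]
g_full, stored_full = grad_full(weights, 0.3)
g_ckpt, stored_ckpt = grad_checkpointed(weights, 0.3, segment=4)
print(stored_full, stored_ckpt)  # 16 8
```

Both functions return the same gradient; the checkpointed version holds at most 8 values at a time instead of 16, at the cost of running each layer's forward computation a second time.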
Usage
Use this heuristic when you are VRAM-constrained during training. Gradient checkpointing is enabled by default in LLaMA Factory. Consider the Unsloth variant (use_unsloth_gc=True) when fine-tuning models with 7B or more parameters on consumer GPUs (e.g., RTX 3090/4090 with 24 GB VRAM).
The Insight (Rule of Thumb)
- Action: Keep gradient checkpointing enabled (default). For maximum VRAM savings, set use_unsloth_gc=True.
- Value: Standard GC reduces VRAM by ~40-50%. Unsloth GC can save an additional 10-20% by offloading to CPU.
- Trade-off: Standard GC adds ~20-30% training time overhead. Unsloth GC adds more overhead due to CPU-GPU data transfers. Non-reentrant GC (use_reentrant_gc=False) may increase VRAM (not yet empirically verified per code comments).
- Critical: When gradient checkpointing is enabled, use_cache is automatically set to False (KV cache is incompatible with GC).
- FSDP2 Note: FSDP2 automatically forces use_reentrant_gc=False.
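The interaction between these rules can be sketched as a small pure-Python helper. This is an illustration only, not LLaMA Factory's actual code path: the function name `resolve_gc_settings`, its `enable_gc` and `env` parameters, and the returned dictionary layout are hypothetical. The option names `use_unsloth_gc`, `use_reentrant_gc`, and `use_cache`, and the environment variables `ACCELERATE_USE_FSDP` and `FSDP_VERSION`, come from this page.

```python
def resolve_gc_settings(enable_gc=True, use_reentrant_gc=True, use_unsloth_gc=False, env=None):
    """Hypothetical helper: apply the rules of thumb above to a set of options."""
    env = {} if env is None else env

    # FSDP2 is detected the same way as in the excerpt at the end of this page.
    fsdp2 = (
        env.get("ACCELERATE_USE_FSDP", "false").lower() == "true"
        and int(env.get("FSDP_VERSION", "1")) == 2
    )
    if fsdp2:
        use_reentrant_gc = False  # FSDP2 forces non-reentrant checkpointing

    return {
        "gradient_checkpointing": enable_gc,
        "use_unsloth_gc": enable_gc and use_unsloth_gc,
        "use_reentrant_gc": use_reentrant_gc,
        # The KV cache is switched off whenever checkpointing is on;
        # None means "left unchanged" in this sketch.
        "use_cache": False if enable_gc else None,
    }


defaults = resolve_gc_settings()
under_fsdp2 = resolve_gc_settings(env={"ACCELERATE_USE_FSDP": "true", "FSDP_VERSION": "2"})
print(defaults)
print(under_fsdp2)
```

Passing the environment as a dictionary rather than reading `os.environ` directly keeps the sketch easy to exercise with different launch configurations.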
Reasoning
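Deep transformer models keep large intermediate activations for backpropagation: a single hidden-state tensor has Batch x SeqLen x Hidden elements, and every layer holds several tensors of at least that size. Gradient checkpointing keeps only the input of each checkpointed layer, discards the activations computed inside the layer, and recomputes them on the fly during the backward pass. The custom wrapper in LLaMA Factory adds an optimization: it checkpoints only layers that have at least one trainable parameter. In freeze tuning, or in LoRA setups where some layers carry no adapters, the fully frozen layers skip checkpointing entirely, which avoids unnecessary recomputation. For layers that are checkpointed, the wrapper also marks the first floating-point tensor argument as requiring gradients, so the checkpointed segment still receives gradients when everything upstream of it is frozen.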
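To get a feel for the Batch x SeqLen x Hidden term, the arithmetic can be worked through directly. The shapes below (batch 4, sequence length 2048, hidden size 4096, 32 layers, 2 bytes per element for bf16/fp16) are illustrative assumptions, not measurements from LLaMA Factory, and the helper names are hypothetical. The estimate counts only one hidden-state tensor per layer, which is roughly what checkpointing retains; a layer without checkpointing keeps several additional tensors of comparable or larger size.

```python
def hidden_state_bytes(batch, seq_len, hidden, bytes_per_elem=2):
    """Size of one hidden-state tensor of shape (batch, seq_len, hidden)."""
    return batch * seq_len * hidden * bytes_per_elem


def checkpoint_bytes(num_layers, batch, seq_len, hidden, bytes_per_elem=2):
    """Memory for one saved layer input per layer (what checkpointing retains)."""
    return num_layers * hidden_state_bytes(batch, seq_len, hidden, bytes_per_elem)


MIB = 1024 ** 2
GIB = 1024 ** 3

per_tensor = hidden_state_bytes(batch=4, seq_len=2048, hidden=4096)
all_layers = checkpoint_bytes(num_layers=32, batch=4, seq_len=2048, hidden=4096)

print(per_tensor / MIB)  # 64.0 (MiB per hidden-state tensor)
print(all_layers / GIB)  # 2.0 (GiB of saved layer inputs across 32 layers)
```

Under these assumptions the saved layer inputs alone occupy 2 GiB, which is the portion the Unsloth variant moves to CPU RAM.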
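The Unsloth variant goes further: during the forward pass it copies the input hidden states of each checkpointed layer to CPU memory with a non-blocking transfer, so the device-to-host copy can overlap with the GPU's forward computation, and it moves them back to the GPU when the backward pass recomputes the layer.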
Code evidence from src/llamafactory/model/model_utils/checkpointing.py:43-77:
class UnslothGradientCheckpointing(torch.autograd.Function):
    r"""Saves VRAM by smartly offloading to RAM."""

    @staticmethod
    @torch.cuda.amp.custom_fwd
    def forward(ctx, forward_function, hidden_states, *args):
        saved_hidden_states = hidden_states.to("cpu", non_blocking=True)
        with torch.no_grad():
            outputs = forward_function(hidden_states, *args)
        ctx.save_for_backward(saved_hidden_states)
        ctx.forward_function = forward_function
        ctx.args = args
        return outputs
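The excerpt above shows only the forward pass. A backward pass consistent with it has to move the saved hidden states back to the compute device, re-run the layer with gradients enabled, and backpropagate through the recomputed graph. The following is a minimal, device-agnostic sketch of that idea, not the code from checkpointing.py: the class name `OffloadedCheckpoint` is hypothetical, the AMP decorators are omitted, and it assumes the wrapped function returns a single tensor.

```python
import torch


class OffloadedCheckpoint(torch.autograd.Function):
    """Illustrative sketch: checkpoint one layer and park its input on the CPU."""

    @staticmethod
    def forward(ctx, forward_function, hidden_states, *args):
        ctx.device = hidden_states.device
        # Keep only a CPU copy of the layer input (a no-op copy if already on CPU).
        saved = hidden_states.detach().to("cpu", non_blocking=True)
        with torch.no_grad():  # do not build a graph, so no activations are retained
            outputs = forward_function(hidden_states, *args)
        ctx.save_for_backward(saved)
        ctx.forward_function = forward_function
        ctx.args = args
        return outputs

    @staticmethod
    def backward(ctx, grad_output):
        (saved,) = ctx.saved_tensors
        # Bring the input back to the compute device and make it a fresh leaf.
        hidden_states = saved.to(ctx.device).detach().requires_grad_(True)
        with torch.enable_grad():  # recompute the layer, this time recording a graph
            outputs = ctx.forward_function(hidden_states, *ctx.args)
        # Accumulates gradients into hidden_states.grad and the layer's parameters.
        torch.autograd.backward(outputs, grad_output)
        # No gradient for forward_function or for the extra args.
        return (None, hidden_states.grad) + (None,) * len(ctx.args)


# Usage on CPU with a toy layer; on a GPU the same code offloads to host RAM.
torch.manual_seed(0)
toy_layer = torch.nn.Linear(4, 4)
x = torch.randn(2, 4, requires_grad=True)
out = OffloadedCheckpoint.apply(toy_layer, x)
out.sum().backward()
print(x.grad.shape)
```

The input must require gradients for the function's output to be connected to the autograd graph, which is one reason the custom wrapper on this page calls `requires_grad_(True)` on the hidden states.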
Custom wrapper skipping frozen layers from src/llamafactory/model/model_utils/checkpointing.py:80-103:
def get_custom_gradient_checkpointing_func(gradient_checkpointing_func):
    r"""Only applies gradient checkpointing to trainable layers."""

    def custom_gradient_checkpointing_func(func, *args, **kwargs):
        module = func.__self__
        has_grad = False
        if any(param.requires_grad for param in module.parameters()):
            has_grad = True
            for arg in args:
                if torch.is_tensor(arg) and torch.is_floating_point(arg):
                    arg.requires_grad_(True)
                    break
        if has_grad:
            return gradient_checkpointing_func(func, *args, **kwargs)
        else:
            return func(*args, **kwargs)

    return custom_gradient_checkpointing_func
FSDP2 forces non-reentrant checkpointing, from src/llamafactory/model/model_utils/checkpointing.py:157-162:
if (
    os.environ.get("ACCELERATE_USE_FSDP", "false").lower() == "true"
    and int(os.environ.get("FSDP_VERSION", "1")) == 2
):
    model_args.use_reentrant_gc = False
    logger.warning_rank0("You are using fsdp2, `use_reentrant_gc` has been set to False.")