Heuristic:Hiyouga LLaMA Factory Gradient Checkpointing Memory Optimization

From Leeroopedia



Knowledge Sources
Domains: Memory_Optimization, LLMs
Last Updated: 2026-02-06 20:00 GMT

Overview

Gradient checkpointing, optionally combined with Unsloth-style CPU offloading, is a memory optimization technique that reduces VRAM usage during training at the cost of extra compute.

Description

LLaMA Factory implements three gradient checkpointing strategies: standard reentrant (default), non-reentrant, and Unsloth-style CPU offloading. Standard gradient checkpointing trades compute for memory by recomputing activations during the backward pass. The Unsloth variant goes further by offloading saved hidden states to CPU RAM during the forward pass and moving them back to GPU during backward, achieving even greater VRAM savings. Additionally, a custom wrapper ensures gradient checkpointing only applies to layers with trainable parameters, which is critical for efficient LoRA and freeze fine-tuning.

Usage

Use this heuristic when training is VRAM-constrained. Gradient checkpointing is enabled by default in LLaMA Factory. Consider the Unsloth variant (use_unsloth_gc=True) when fine-tuning models with 7B or more parameters on consumer GPUs (e.g., RTX 3090/4090 with 24GB VRAM).

The Insight (Rule of Thumb)

  • Action: Keep gradient checkpointing enabled (default). For maximum VRAM savings, set use_unsloth_gc=True.
  • Value: Standard GC reduces VRAM by ~40-50%. Unsloth GC can save an additional 10-20% by offloading to CPU.
  • Trade-off: Standard GC adds ~20-30% training time overhead. Unsloth GC adds further overhead because of the CPU-GPU data transfers. Non-reentrant GC (use_reentrant_gc=False) may increase VRAM usage; according to the code comments, this has not yet been verified empirically.
  • Critical: When gradient checkpointing is enabled, use_cache is automatically set to False (KV cache is incompatible with GC).
  • FSDP2 Note: FSDP2 automatically forces use_reentrant_gc=False.
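
The two automatic adjustments in the list above can be sketched in plain Python. This is a minimal illustration, not LLaMA Factory's actual code path: `GCSettings` and `resolve_gc_settings` are hypothetical names, and only the option names and the two environment variables are taken from this page.

```python
import os
from dataclasses import dataclass, replace


@dataclass
class GCSettings:
    """Hypothetical container; the field names mirror options named on this page."""
    gradient_checkpointing: bool = True  # enabled by default
    use_unsloth_gc: bool = False
    use_reentrant_gc: bool = True        # reentrant is the default strategy
    use_cache: bool = True


def resolve_gc_settings(settings, env=None):
    """Return a copy of the settings with the automatic adjustments applied."""
    env = os.environ if env is None else env
    resolved = replace(settings)  # shallow copy, the input is left untouched

    # The KV cache is switched off whenever gradient checkpointing is on.
    if resolved.gradient_checkpointing:
        resolved.use_cache = False

    # FSDP2 forces the non-reentrant variant.
    uses_fsdp = env.get("ACCELERATE_USE_FSDP", "false").lower() == "true"
    if uses_fsdp and int(env.get("FSDP_VERSION", "1")) == 2:
        resolved.use_reentrant_gc = False

    return resolved


default_run = resolve_gc_settings(GCSettings(), env={})
fsdp2_run = resolve_gc_settings(
    GCSettings(), env={"ACCELERATE_USE_FSDP": "true", "FSDP_VERSION": "2"}
)
```

Passing `env` explicitly keeps the sketch testable without touching the real process environment.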

Reasoning

Deep transformer models store large intermediate activations (on the order of Batch x SeqLen x Hidden per layer) for backpropagation. Gradient checkpointing keeps only the inputs of each checkpointed layer, discards the activations computed inside it, and recomputes them on the fly during the backward pass. The custom wrapper in LLaMA Factory adds an optimization: it only checkpoints layers that have trainable parameters. For LoRA or freeze fine-tuning, where many layers can be fully frozen, those frozen layers skip checkpointing entirely, which avoids unnecessary recomputation.

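The Unsloth variant goes one step further: it copies each checkpointed layer's input hidden states to CPU memory with a non-blocking transfer, so the copy can overlap with the layer's forward computation on the GPU. During the backward pass, the hidden states are moved back to the GPU before the layer is recomputed.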
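
A back-of-the-envelope estimate shows how much GPU memory the offloaded hidden states represent. The sketch assumes that one Batch x SeqLen x Hidden tensor is saved per checkpointed layer and that each element takes 2 bytes (fp16/bf16). The helper name and the example model shape are illustrative assumptions, not measurements.

```python
def offloaded_hidden_state_bytes(num_layers, batch, seq_len, hidden, bytes_per_element=2):
    """Bytes of layer-input hidden states saved across all checkpointed layers."""
    return num_layers * batch * seq_len * hidden * bytes_per_element


# Illustrative shape (an assumption): 32 layers, batch 1, 4096 tokens, hidden size 4096.
total = offloaded_hidden_state_bytes(num_layers=32, batch=1, seq_len=4096, hidden=4096)
print(f"{total / 2**30:.2f} GiB")  # prints "1.00 GiB"
```

The figure scales linearly with batch size and sequence length, so the amount that offloading can move out of VRAM grows with longer contexts.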

Code evidence from src/llamafactory/model/model_utils/checkpointing.py:43-77 (excerpt showing only the forward pass):

class UnslothGradientCheckpointing(torch.autograd.Function):
    r"""Saves VRAM by smartly offloading to RAM."""
    @staticmethod
    @torch.cuda.amp.custom_fwd
    def forward(ctx, forward_function, hidden_states, *args):
        saved_hidden_states = hidden_states.to("cpu", non_blocking=True)
        with torch.no_grad():
            outputs = forward_function(hidden_states, *args)
        ctx.save_for_backward(saved_hidden_states)
        ctx.forward_function = forward_function
        ctx.args = args
        return outputs
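
The excerpt above covers only the forward pass. The following sketch completes the pattern with one plausible backward pass, so the offload-and-recompute cycle can be run end to end. It is an illustration under stated assumptions, not a copy of the repository's code: the class name `OffloadedCheckpoint` and the helper `matches_plain_backward` are hypothetical, the AMP decorator is left out, and the device handling is generalized so the example also runs on a CPU-only machine.

```python
import torch


class OffloadedCheckpoint(torch.autograd.Function):
    """Illustrative sketch: checkpoint one layer and park its input on the CPU."""

    @staticmethod
    def forward(ctx, forward_function, hidden_states, *args):
        ctx.device = hidden_states.device  # remember where the input lived
        # Copy the layer input to host memory (a no-op if it is already on the CPU).
        saved_hidden_states = hidden_states.to("cpu", non_blocking=True)
        with torch.no_grad():  # no graph is recorded, so inner activations are freed
            outputs = forward_function(hidden_states, *args)
        ctx.save_for_backward(saved_hidden_states)
        ctx.forward_function = forward_function
        ctx.args = args
        return outputs

    @staticmethod
    def backward(ctx, grad_output):
        (hidden_states,) = ctx.saved_tensors
        # Bring the input back and turn it into a fresh leaf that can receive a gradient.
        hidden_states = hidden_states.to(ctx.device, non_blocking=True).detach()
        hidden_states.requires_grad_(True)
        with torch.enable_grad():  # recompute the layer, this time recording the graph
            outputs = ctx.forward_function(hidden_states, *ctx.args)
        torch.autograd.backward(outputs, grad_output)
        # One slot per forward() input: the function, hidden_states, then *args.
        return (None, hidden_states.grad) + (None,) * len(ctx.args)


def matches_plain_backward():
    """Compare gradients from the sketch with an ordinary forward/backward pass."""
    torch.manual_seed(0)
    layer = torch.nn.Linear(4, 4)
    x = torch.randn(2, 4, requires_grad=True)

    OffloadedCheckpoint.apply(layer, x).sum().backward()
    grad_x, grad_w = x.grad.clone(), layer.weight.grad.clone()

    x.grad = None
    layer.zero_grad()
    layer(x).sum().backward()
    return torch.allclose(grad_x, x.grad) and torch.allclose(grad_w, layer.weight.grad)
```

Note that the input must require gradients for the backward pass to be invoked at all, which is one reason the custom wrapper shown next calls `requires_grad_(True)` on the first floating-point argument.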

Custom wrapper skipping frozen layers from src/llamafactory/model/model_utils/checkpointing.py:80-103:

def get_custom_gradient_checkpointing_func(gradient_checkpointing_func):
    r"""Only applies gradient checkpointing to trainable layers."""
    def custom_gradient_checkpointing_func(func, *args, **kwargs):
        module = func.__self__
        has_grad = False
        if any(param.requires_grad for param in module.parameters()):
            has_grad = True
            for arg in args:
                if torch.is_tensor(arg) and torch.is_floating_point(arg):
                    arg.requires_grad_(True)
                    break
        if has_grad:
            return gradient_checkpointing_func(func, *args, **kwargs)
        else:
            return func(*args, **kwargs)
    return custom_gradient_checkpointing_func
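
To see the effect of the wrapper, the sketch below applies the same decision logic to one frozen and one trainable layer and records which of them is actually routed through `torch.utils.checkpoint.checkpoint`. The names `skip_frozen_layers` and `recording_checkpoint`, the `name` attribute set on the layers, and the toy `Linear` layers are assumptions made for the demonstration.

```python
import torch
from torch.utils.checkpoint import checkpoint


def skip_frozen_layers(gradient_checkpointing_func):
    """Same decision logic as the excerpt above: checkpoint only trainable modules."""

    def wrapper(func, *args, **kwargs):
        module = func.__self__  # func must be a bound method such as layer.__call__
        if any(p.requires_grad for p in module.parameters()):
            for arg in args:
                if torch.is_tensor(arg) and torch.is_floating_point(arg):
                    arg.requires_grad_(True)  # let gradients flow through the call
                    break
            return gradient_checkpointing_func(func, *args, **kwargs)
        return func(*args, **kwargs)

    return wrapper


checkpointed = []  # names of the layers that actually went through checkpointing


def recording_checkpoint(func, *args, **kwargs):
    checkpointed.append(func.__self__.name)
    return checkpoint(func, *args, use_reentrant=False, **kwargs)


frozen = torch.nn.Linear(4, 4).requires_grad_(False)
frozen.name = "frozen"
trainable = torch.nn.Linear(4, 4)
trainable.name = "trainable"

gc_func = skip_frozen_layers(recording_checkpoint)
x = torch.randn(2, 4)
hidden = gc_func(frozen.__call__, x)          # plain forward, nothing recorded
output = gc_func(trainable.__call__, hidden)  # routed through checkpoint()
output.sum().backward()
print(checkpointed)  # prints ['trainable']
```

Only the trainable layer pays the recomputation cost; the frozen layer runs a normal forward pass and never receives gradients.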

FSDP2 forcing non-reentrant checkpointing, from src/llamafactory/model/model_utils/checkpointing.py:157-162:

if (
    os.environ.get("ACCELERATE_USE_FSDP", "false").lower() == "true"
    and int(os.environ.get("FSDP_VERSION", "1")) == 2
):
    model_args.use_reentrant_gc = False
    logger.warning_rank0("You are using fsdp2, `use_reentrant_gc` has been set to False.")
