
Implementation:Huggingface Peft Prepare Model For Kbit Training

From Leeroopedia


Metadata

Source: PEFT (https://github.com/huggingface/peft)
Domains: Quantization, Training
Last Updated: 2026-02-07 00:00 GMT

Overview

prepare_model_for_kbit_training is a utility function that prepares a quantized transformers model for parameter-efficient fine-tuning. It implements the protocol described in the Quantized Model Preparation principle: it freezes the base model parameters, casts fp16/bf16 parameters that are not packed quantized weights to fp32 for numerical stability, and enables gradient checkpointing to reduce activation memory. It is the usual entry point for QLoRA and other quantized-adapter training workflows.

Source

File: src/peft/utils/other.py, lines 130-215

Repository: huggingface/peft

Signature

def prepare_model_for_kbit_training(
    model,
    use_gradient_checkpointing=True,
    gradient_checkpointing_kwargs=None
):

Import

from peft import prepare_model_for_kbit_training

Parameters

  • model (transformers.PreTrainedModel, required): The model loaded from transformers. Typically a quantized model, e.g. one loaded with a BitsAndBytesConfig (load_in_4bit=True or load_in_8bit=True) or quantized with GPTQ, AQLM, EETQ, HQQ, or TorchAO. Non-quantized models are also accepted (see Edge Cases and Notes).
  • use_gradient_checkpointing (bool, default True): If True, enables gradient checkpointing, which saves memory at the cost of a slower backward pass. Only takes effect when the model is detected as quantized.
  • gradient_checkpointing_kwargs (dict or None, default None): Keyword arguments forwarded through model.gradient_checkpointing_enable to torch.utils.checkpoint.checkpoint. Requires transformers > 4.34.1. Common usage: {"use_reentrant": False}.

Return Value

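Returns the same model object, modified in place: all base parameters are frozen, fp16/bf16 parameters are upcast to fp32 unless the quantization method manages its own precision, and gradient checkpointing is enabled if it was requested and the model is quantized.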

Behavior

The function performs the following operations in sequence:

1. Detect quantization method: The function checks which quantization backend was used by inspecting model attributes:

loaded_in_kbit = getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False)
is_gptq_quantized = getattr(model, "quantization_method", None) == "gptq"
is_aqlm_quantized = getattr(model, "quantization_method", None) == "aqlm"
is_eetq_quantized = getattr(model, "quantization_method", None) == "eetq"
is_torchao_quantized = getattr(model, "quantization_method", None) == "torchao"
is_hqq_quantized = getattr(model, "quantization_method", None) == "hqq" or getattr(model, "hqq_quantized", False)
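How these flags decide which of the later steps run can be summarized in a small pure-Python sketch. plan_kbit_preparation is a hypothetical helper (not part of PEFT) that reuses the attribute checks above; SimpleNamespace objects stand in for real models, so no transformers install is needed.

```python
from types import SimpleNamespace

# Quantization methods whose weights keep their own precision (no fp32 upcast).
# A tuple rather than a set, so str-based enums are compared by value.
SELF_MANAGED_METHODS = ("gptq", "aqlm", "eetq", "torchao", "hqq")


def plan_kbit_preparation(model, use_gradient_checkpointing=True):
    """Hypothetical helper: report which preparation steps would apply to `model`."""
    loaded_in_kbit = bool(
        getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False)
    )
    method = getattr(model, "quantization_method", None)
    self_managed = method in SELF_MANAGED_METHODS or bool(getattr(model, "hqq_quantized", False))
    return {
        "freeze": True,  # step 2 runs for every model
        "upcast_to_fp32": not self_managed,  # step 3
        "gradient_checkpointing": (loaded_in_kbit or self_managed) and use_gradient_checkpointing,  # step 4
    }


bnb_4bit = SimpleNamespace(is_loaded_in_4bit=True)
gptq = SimpleNamespace(quantization_method="gptq")
unquantized = SimpleNamespace()

print(plan_kbit_preparation(bnb_4bit))
# {'freeze': True, 'upcast_to_fp32': True, 'gradient_checkpointing': True}
print(plan_kbit_preparation(gptq))
# {'freeze': True, 'upcast_to_fp32': False, 'gradient_checkpointing': True}
print(plan_kbit_preparation(unquantized))
# {'freeze': True, 'upcast_to_fp32': True, 'gradient_checkpointing': False}
```

Note that an unquantized model still gets frozen and upcast; only gradient checkpointing depends on a quantization flag being set.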

2. Freeze all base parameters: Every parameter in the model has requires_grad set to False:

for name, param in model.named_parameters():
    param.requires_grad = False
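A minimal sketch of the freezing step on a toy PyTorch module. freeze_base_parameters and count_trainable are hypothetical helper names, and the small nn.Sequential stands in for a real transformers model.

```python
from torch import nn


def count_trainable(model: nn.Module) -> int:
    """Number of parameter elements that would receive gradients."""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


def freeze_base_parameters(model: nn.Module) -> nn.Module:
    """Hypothetical helper mirroring the loop above: freeze every existing parameter."""
    for _name, param in model.named_parameters():
        param.requires_grad = False
    return model


# Toy stand-in for a transformers model
model = nn.Sequential(nn.Linear(8, 8), nn.LayerNorm(8))
trainable_before = count_trainable(model)  # 72 (Linear: 64 weights + 8 bias) + 16 (LayerNorm) = 88
freeze_base_parameters(model)
trainable_after = count_trainable(model)  # 0
print(trainable_before, trainable_after)
```

Adapter weights are created afterwards (for example by get_peft_model), so this loop does not affect them.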

3. Cast to fp32 (conditional): Unless the model was quantized with GPTQ, AQLM, EETQ, HQQ, or TorchAO, every float16/bfloat16 parameter is cast to float32. This covers bitsandbytes-quantized models as well as non-quantized ones. Params4bit instances are skipped so that they remain in their quantized form:

for param in model.parameters():
    if (
        (param.dtype == torch.float16) or (param.dtype == torch.bfloat16)
    ) and param.__class__.__name__ != "Params4bit":
        param.data = param.data.to(torch.float32)
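To see which parameters the upcast touches, the sketch below applies the same condition to a toy module. The Params4bit class here is a hypothetical stand-in that subclasses torch.nn.Parameter only so that its class name matches; the real class lives in bitsandbytes. An int8 parameter is included to show that integer storage is already excluded by the dtype check.

```python
import torch
from torch import nn


class Params4bit(nn.Parameter):
    """Hypothetical stand-in for bitsandbytes' Params4bit; only the class name matters here."""


def upcast_half_precision(model: nn.Module) -> None:
    """Same condition as above: fp16/bf16 parameters become fp32, Params4bit is skipped."""
    for param in model.parameters():
        if param.dtype in (torch.float16, torch.bfloat16) and param.__class__.__name__ != "Params4bit":
            param.data = param.data.to(torch.float32)


toy = nn.Module()
# e.g. a norm weight kept in bf16 next to the quantized layers
toy.norm_weight = nn.Parameter(torch.ones(4, dtype=torch.bfloat16), requires_grad=False)
# integer storage is never matched by the dtype check
toy.int8_weight = nn.Parameter(torch.zeros(4, dtype=torch.int8), requires_grad=False)
# "packed" weights: excluded by the class-name check even though the dtype is bf16
toy.packed_weight = Params4bit(torch.zeros(4, dtype=torch.bfloat16), requires_grad=False)

upcast_half_precision(toy)
for name, param in toy.named_parameters():
    print(name, param.dtype)
# norm_weight torch.float32
# int8_weight torch.int8
# packed_weight torch.bfloat16
```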

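4. Enable gradient checkpointing (conditional): If the model was detected as quantized (by any of the methods above) and use_gradient_checkpointing=True, gradient checkpointing is enabled via model.gradient_checkpointing_enable. Reentrant checkpointing, which is assumed when use_reentrant is not specified, only propagates gradients if the checkpointed inputs require grad. In that case the function first makes the input-embedding outputs require gradients, using model.enable_input_require_grads() when available and otherwise a forward hook on the input embeddings: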

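if gradient_checkpointing_kwargs is None:
    gradient_checkpointing_kwargs = {}
# A missing "use_reentrant" key is treated as reentrant checkpointing
if "use_reentrant" not in gradient_checkpointing_kwargs or gradient_checkpointing_kwargs["use_reentrant"]:
    if hasattr(model, "enable_input_require_grads"):
        model.enable_input_require_grads()
    else:
        # Fallback for models without enable_input_require_grads
        def make_inputs_require_grad(module, input, output):
            output.requires_grad_(True)
        model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
# Simplified: PEFT only forwards the kwargs when the installed transformers accepts them
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=gradient_checkpointing_kwargs)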
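Why reentrant checkpointing needs this hook can be reproduced with plain PyTorch. In this sketch a single nn.Linear stands in for a trainable adapter inside a checkpointed block and a random tensor stands in for the frozen embedding output; torch.utils.checkpoint.checkpoint is called directly rather than through transformers, and adapter_gets_grad is a hypothetical helper.

```python
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint


def adapter_gets_grad(use_reentrant: bool, input_requires_grad: bool) -> bool:
    """Return True if a trainable layer inside a checkpointed segment receives a gradient."""
    torch.manual_seed(0)
    layer = nn.Linear(4, 4)  # stands in for a trainable adapter layer
    hidden = torch.randn(2, 4)  # stands in for the output of the frozen embedding layer
    if input_requires_grad:
        hidden.requires_grad_(True)  # the effect of the hook / enable_input_require_grads
    out = checkpoint(layer, hidden, use_reentrant=use_reentrant)
    if not out.requires_grad:
        return False  # output is detached from autograd; backward() would raise
    out.sum().backward()
    return layer.weight.grad is not None


for use_reentrant, input_requires_grad in [(True, False), (True, True), (False, False)]:
    result = adapter_gets_grad(use_reentrant, input_requires_grad)
    print(f"use_reentrant={use_reentrant}, input requires grad={input_requires_grad}: {result}")
# use_reentrant=True, input requires grad=False: False
# use_reentrant=True, input requires grad=True: True
# use_reentrant=False, input requires grad=False: True
```

With the reentrant variant and an input that does not require grad, PyTorch also emits a UserWarning that gradients will be None. Making the input require grad, or switching to use_reentrant=False, restores the gradient flow.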

Usage Example

import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from peft import prepare_model_for_kbit_training, LoraConfig, get_peft_model

# Load a model in 4-bit quantization
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True,
)
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    quantization_config=bnb_config,
    device_map="auto",
)

# Prepare for k-bit training
model = prepare_model_for_kbit_training(model)

# Attach LoRA adapter
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    lora_dropout=0.05,
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()
# Prints trainable vs. total parameter counts; LoRA r=16 on these four projections
# of Llama-2-7B adds 16,777,216 trainable parameters (32 layers x 4 modules x 16 x (4096 + 4096))
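The trainable count can be checked by hand. Each LoRA-wrapped linear layer of shape (d_in, d_out) adds an A matrix of r x d_in and a B matrix of d_out x r. The sketch assumes Llama-2-7B's configuration (32 decoder layers, hidden size 4096, and 4096 x 4096 q/k/v/o projections) and LoraConfig's default bias="none"; lora_trainable_params is a hypothetical helper.

```python
def lora_trainable_params(r, layer_shapes, num_layers):
    """Hypothetical helper: LoRA adds A (r x d_in) and B (d_out x r) per wrapped Linear."""
    per_layer = sum(r * (d_in + d_out) for d_in, d_out in layer_shapes)
    return per_layer * num_layers


# Assumed Llama-2-7B shapes: hidden size 4096, q/k/v/o projections all 4096 -> 4096
hidden = 4096
attention_projections = [(hidden, hidden)] * 4  # q_proj, k_proj, v_proj, o_proj
total = lora_trainable_params(r=16, layer_shapes=attention_projections, num_layers=32)
print(f"{total:,}")  # 16,777,216
```

The total parameter count reported by print_trainable_parameters depends on how the installed PEFT version counts packed 4-bit weights, so only the trainable count is checked here.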

Using non-reentrant gradient checkpointing:

# Non-reentrant checkpointing does not need the input-embedding require-grads hook
model = prepare_model_for_kbit_training(
    model,
    use_gradient_checkpointing=True,
    gradient_checkpointing_kwargs={"use_reentrant": False},
)

Disabling gradient checkpointing:

# When memory is not a concern, skip gradient checkpointing for faster training
model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=False)

Edge Cases and Notes

  • Non-quantized models: The function can be called on non-quantized models. It will still freeze all parameters and cast any fp16/bf16 parameters to fp32 (doubling their memory footprint), but gradient checkpointing will not be enabled because none of the quantization detection flags are set. For non-quantized models, it is generally preferable to let the PEFT method handle freezing directly.
  • GPTQ/AQLM/EETQ/HQQ/TorchAO models: The fp32 upcasting step is skipped for these quantization methods. They manage their own precision internally, and upcasting could interfere with their custom kernels.
  • Older transformers versions: If the installed transformers version does not accept gradient_checkpointing_kwargs (versions <= 4.34.1), a FutureWarning is emitted when non-empty kwargs were passed, the kwargs are ignored, and gradient checkpointing is enabled with its default settings.
  • Params4bit exclusion: The Params4bit class from bitsandbytes stores quantized weights and should not be cast to fp32. The function explicitly checks param.__class__.__name__ != "Params4bit" to skip these parameters during upcasting.
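The version check behind the transformers note can be sketched by inspecting the method signature: if gradient_checkpointing_enable has no gradient_checkpointing_kwargs parameter, the kwargs are dropped with a FutureWarning. This sketch assumes a signature-based check; build_gc_enable_kwargs, old_enable, and new_enable are hypothetical names, the stubs only mimic older and newer transformers signatures, and the warning text is illustrative.

```python
import inspect
import warnings


def build_gc_enable_kwargs(gradient_checkpointing_enable, gradient_checkpointing_kwargs):
    """Hypothetical helper: decide what to pass to `gradient_checkpointing_enable`."""
    params = inspect.signature(gradient_checkpointing_enable).parameters
    if "gradient_checkpointing_kwargs" in params:
        return {"gradient_checkpointing_kwargs": gradient_checkpointing_kwargs}
    if gradient_checkpointing_kwargs:
        # Older API: the kwargs cannot be forwarded, so they are dropped with a warning
        warnings.warn(
            "gradient_checkpointing_kwargs is not supported by this transformers version; "
            "the passed kwargs will be ignored.",
            FutureWarning,
        )
    return {}


def old_enable():  # stub mimicking transformers <= 4.34.1
    pass


def new_enable(gradient_checkpointing_kwargs=None):  # stub mimicking newer transformers
    pass


kwargs = {"use_reentrant": False}
print(build_gc_enable_kwargs(new_enable, kwargs))  # {'gradient_checkpointing_kwargs': {'use_reentrant': False}}
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    print(build_gc_enable_kwargs(old_enable, kwargs))  # {}
print([w.category.__name__ for w in caught])  # ['FutureWarning']
```

In either case the caller then invokes gradient_checkpointing_enable(**result), so checkpointing is still turned on when the kwargs are dropped.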
