Implementation:Huggingface Peft Prepare Model For Kbit Training
Metadata
| Field | Value |
|---|---|
| Source | PEFT (https://github.com/huggingface/peft) |
| Domains | Quantization, Training |
| Last Updated | 2026-02-07 00:00 GMT |
Overview
prepare_model_for_kbit_training is a utility function that prepares a quantized transformer model for parameter-efficient fine-tuning. It implements the complete protocol described in the Quantized Model Preparation principle: freezing base model parameters, casting non-integer parameters to fp32 for numerical stability, and enabling gradient checkpointing for memory efficiency. This function is the standard entry point for any QLoRA or quantized adapter training workflow.
Source
File: src/peft/utils/other.py, lines 130-215
Repository: huggingface/peft
Signature
def prepare_model_for_kbit_training(
    model,
    use_gradient_checkpointing=True,
    gradient_checkpointing_kwargs=None
):
Import
from peft import prepare_model_for_kbit_training
Parameters
| Parameter | Type | Default | Description |
|---|---|---|---|
| model | transformers.PreTrainedModel | required | The loaded model from transformers. Must be a model loaded with quantization enabled (e.g., load_in_4bit=True or load_in_8bit=True). |
| use_gradient_checkpointing | bool | True | If True, enables gradient checkpointing to save memory at the expense of a slower backward pass. Only takes effect when the model is detected as quantized. |
| gradient_checkpointing_kwargs | dict or None | None | Keyword arguments passed to torch.utils.checkpoint.checkpoint. Requires transformers > 4.34.1. Common usage: {"use_reentrant": False}. |
Return Value
Returns the modified model (same object, mutated in-place) with frozen parameters, fp32 casts applied, and gradient checkpointing enabled.
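The in-place contract can be illustrated with a minimal sketch; `Model` and `prepare` here are hypothetical stand-ins, not PEFT classes, used only to show that the returned object is the argument itself, so the conventional reassignment `model = prepare_model_for_kbit_training(model)` is idiomatic but not strictly required.

```python
# Hypothetical stand-ins illustrating the in-place mutation contract.
class Model:
    def __init__(self):
        self.frozen = False

def prepare(model):
    model.frozen = True  # mutate the model in place, like the real function
    return model         # hand back the very same object

m = Model()
out = prepare(m)
# out is m -> True; m.frozen -> True
```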
Behavior
The function performs the following operations in sequence:
1. Detect quantization method: The function checks which quantization backend was used by inspecting model attributes:
loaded_in_kbit = getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False)
is_gptq_quantized = getattr(model, "quantization_method", None) == "gptq"
is_aqlm_quantized = getattr(model, "quantization_method", None) == "aqlm"
is_eetq_quantized = getattr(model, "quantization_method", None) == "eetq"
is_torchao_quantized = getattr(model, "quantization_method", None) == "torchao"
is_hqq_quantized = getattr(model, "quantization_method", None) == "hqq" or getattr(model, "hqq_quantized", False)
2. Freeze all base parameters: Every parameter in the model has requires_grad set to False:
for name, param in model.named_parameters():
    param.requires_grad = False
3. Cast to fp32 (conditional): For bitsandbytes-quantized models (not GPTQ, AQLM, EETQ, HQQ, or TorchAO), all float16/bfloat16 parameters are cast to float32, except Params4bit instances which remain in their quantized form:
for param in model.parameters():
    if (
        (param.dtype == torch.float16) or (param.dtype == torch.bfloat16)
    ) and param.__class__.__name__ != "Params4bit":
        param.data = param.data.to(torch.float32)
4. Enable gradient checkpointing (conditional): If the model is quantized and use_gradient_checkpointing=True, gradient checkpointing is enabled. For reentrant checkpointing, a forward hook is registered on the input embeddings to ensure output tensors require gradients:
if "use_reentrant" not in gradient_checkpointing_kwargs or gradient_checkpointing_kwargs["use_reentrant"]:
    if hasattr(model, "enable_input_require_grads"):
        model.enable_input_require_grads()
    else:
        def make_inputs_require_grad(module, input, output):
            output.requires_grad_(True)
        model.get_input_embeddings().register_forward_hook(make_inputs_require_grad)
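The detection, freezing, and conditional upcasting steps above can be condensed into a torch-free sketch. `ToyParam` and `ToyModel` are illustrative stand-ins (not PEFT or transformers classes) that duck-type just the attributes the real function inspects via getattr.

```python
# Sketch of the preparation protocol using hypothetical stand-ins;
# dtypes are represented as strings so no torch install is needed.
class ToyParam:
    def __init__(self, dtype):
        self.dtype = dtype
        self.requires_grad = True

class ToyModel:
    # Mimic the attributes the real function inspects via getattr().
    is_loaded_in_4bit = True
    quantization_method = None

    def __init__(self):
        self._params = {"w": ToyParam("float16"), "b": ToyParam("float32")}

    def named_parameters(self):
        return self._params.items()

def prepare_sketch(model):
    # Step 1: detect the quantization backend.
    loaded_in_kbit = (
        getattr(model, "is_loaded_in_8bit", False)
        or getattr(model, "is_loaded_in_4bit", False)
    )
    is_gptq = getattr(model, "quantization_method", None) == "gptq"
    # Step 2: freeze every base parameter.
    for _, p in model.named_parameters():
        p.requires_grad = False
    # Step 3: upcast half-precision params (skipped for GPTQ-style backends).
    if loaded_in_kbit and not is_gptq:
        for _, p in model.named_parameters():
            if p.dtype in ("float16", "bfloat16"):
                p.dtype = "float32"
    return model

m = prepare_sketch(ToyModel())
```

After the call, every toy parameter is frozen and the float16 weight has been upcast, mirroring the post-conditions of the real function.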
Usage Example
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from peft import prepare_model_for_kbit_training, LoraConfig, get_peft_model

# Load a model in 4-bit quantization
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True,
)

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    quantization_config=bnb_config,
    device_map="auto",
)

# Prepare for k-bit training
model = prepare_model_for_kbit_training(model)

# Attach LoRA adapter
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    lora_dropout=0.05,
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()
# Output: trainable params: 6,553,600 || all params: 3,506,341,888 || trainable%: 0.1869
Using non-reentrant gradient checkpointing:
# Non-reentrant checkpointing avoids the need for the input_require_grads hook
model = prepare_model_for_kbit_training(
    model,
    use_gradient_checkpointing=True,
    gradient_checkpointing_kwargs={"use_reentrant": False},
)
Disabling gradient checkpointing:
# When memory is not a concern, skip gradient checkpointing for faster training
model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=False)
Edge Cases and Notes
- Non-quantized models: The function can be called on non-quantized models. It will still freeze parameters and cast to fp32, but gradient checkpointing will not be enabled because the quantization detection flags will be False. For non-quantized models, it is generally preferable to let the PEFT method handle freezing directly.
- GPTQ/AQLM/EETQ/HQQ/TorchAO models: The fp32 upcasting step is skipped for these quantization methods. They manage their own precision internally, and upcasting could interfere with their custom kernels.
- Older transformers versions: If the installed transformers version does not support gradient_checkpointing_kwargs (versions <= 4.34.1), a FutureWarning is emitted and the kwargs are silently ignored.
- Params4bit exclusion: The Params4bit class from bitsandbytes stores quantized weights and should not be cast to fp32. The function explicitly checks param.__class__.__name__ != "Params4bit" to skip these parameters during upcasting.
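The Params4bit exclusion hinges on comparing the parameter's class *name*, not its type identity. The sketch below uses hypothetical stand-in classes (not the real bitsandbytes Params4bit) to show how any class literally named "Params4bit" is skipped during upcasting.

```python
# Hypothetical stand-ins; dtypes are strings for illustration only.
class Params4bit:  # mimics bitsandbytes.nn.Params4bit by name
    dtype = "float16"

class PlainParam:  # an ordinary half-precision parameter
    dtype = "float16"

def should_upcast(param):
    # Mirrors the condition in the fp32-casting step: half precision,
    # and not an instance of a class named "Params4bit".
    return (
        param.dtype in ("float16", "bfloat16")
        and param.__class__.__name__ != "Params4bit"
    )

# should_upcast(PlainParam())  -> True  (cast to fp32)
# should_upcast(Params4bit())  -> False (stays quantized)
```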
Related Pages
- Principle:Huggingface_Peft_Quantized_Model_Preparation
- Environment:Huggingface_Peft_Python_Core_Dependencies
- Environment:Huggingface_Peft_BitsAndBytes_Quantization
- Environment:Huggingface_Peft_Optional_Quantization_Backends
- Environment:Huggingface_Peft_GPU_Hardware_Detection
- Heuristic:Huggingface_Peft_Gradient_Checkpointing_With_Quantization