Implementation: Hugging Face Alignment Handbook get_model (Quantized Loading)
| Knowledge Sources | Details |
|---|---|
| Domains | NLP, Deep_Learning, Optimization |
| Last Updated | 2026-02-07 00:00 GMT |
Overview
A concrete tool for loading pretrained models with 4-bit bitsandbytes quantization for QLoRA training, provided by the alignment-handbook and TRL libraries.
Description
The get_model function, when model_args.load_in_4bit is True, activates the QLoRA model loading path. Internally, it calls TRL's get_quantization_config to create a BitsAndBytesConfig and get_kbit_device_map to set up appropriate GPU device mapping. The model is then loaded with 4-bit quantized weights via AutoModelForCausalLM.from_pretrained.
This is the same get_model function used for standard loading, but the quantization branch is activated by the ModelConfig flags set in the QLoRA YAML recipe configs.
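The branch described above can be sketched in plain Python. This is a simplified mirror, not the library code: ModelArgs is a hypothetical stand-in for trl.ModelConfig, and plain dicts stand in for BitsAndBytesConfig and the device map returned by get_kbit_device_map.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class ModelArgs:
    # Hypothetical stand-in for the trl.ModelConfig fields used here
    model_name_or_path: str
    load_in_4bit: bool = False
    torch_dtype: Optional[str] = None


def build_load_kwargs(args: ModelArgs) -> dict:
    # Mirrors the branch get_model takes: when load_in_4bit is set, a
    # quantization config and a device map are added to the kwargs that
    # AutoModelForCausalLM.from_pretrained would receive. Plain dicts
    # stand in for BitsAndBytesConfig and the device map.
    kwargs: dict = {"torch_dtype": args.torch_dtype}
    if args.load_in_4bit:
        kwargs["quantization_config"] = {"load_in_4bit": True}
        kwargs["device_map"] = {"": 0}  # all weights on the local GPU
    return kwargs


kwargs = build_load_kwargs(
    ModelArgs("mistralai/Mistral-7B-v0.1", load_in_4bit=True, torch_dtype="bfloat16")
)
print(sorted(kwargs))  # ['device_map', 'quantization_config', 'torch_dtype']
```

When load_in_4bit is False, neither key is added, which is how the same function also serves the standard (non-quantized) loading path.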
Usage
Use this when loading models for QLoRA fine-tuning. The quantization is controlled entirely by YAML config (load_in_4bit: true in the recipe file), requiring no code changes from the standard loading path.
Code Reference
Source Location
- Repository: alignment-handbook
- File: src/alignment/model_utils.py (lines 37-57)
- Config: recipes/zephyr-7b-beta/sft/config_qlora.yaml (lines 1-71)
Signature
def get_model(model_args: ModelConfig, training_args: SFTConfig) -> AutoModelForCausalLM:
    """Get the model.

    When model_args.load_in_4bit is True, internally calls:
    - get_quantization_config(model_args) -> BitsAndBytesConfig
    - get_kbit_device_map() -> dict (device mapping for the quantized model)
    """

# TRL utility functions used internally
from trl import get_quantization_config, get_kbit_device_map

def get_quantization_config(model_args: ModelConfig) -> Optional[BitsAndBytesConfig]:
    """Returns a BitsAndBytesConfig if load_in_4bit or load_in_8bit is True, else None."""

def get_kbit_device_map() -> Optional[dict]:
    """Returns a device map for quantized model loading (None when no GPU is available)."""
Import
from alignment import get_model
from trl import ModelConfig
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| model_args | ModelConfig | Yes | Model config with quantization flags |
| model_args.model_name_or_path | str | Yes | HuggingFace model ID (e.g., "mistralai/Mistral-7B-v0.1") |
| model_args.load_in_4bit | bool | Yes | Must be True for QLoRA quantization |
| model_args.torch_dtype | str | No | Compute dtype (e.g., "bfloat16") |
| model_args.attn_implementation | str | No | Attention backend (e.g., "flash_attention_2") |
| training_args | SFTConfig | Yes | Training config (gradient_checkpointing toggles use_cache) |
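Given the required/optional split in the table above, a pre-flight check might look like the following. This validate_qlora_inputs helper is hypothetical, not part of the library:

```python
def validate_qlora_inputs(cfg: dict) -> None:
    # Hypothetical helper: enforces the fields the table marks as
    # required on the QLoRA path before get_model is called.
    if not cfg.get("model_name_or_path"):
        raise ValueError("model_name_or_path is required")
    if not cfg.get("load_in_4bit"):
        raise ValueError("the QLoRA path requires load_in_4bit: true")


validate_qlora_inputs({
    "model_name_or_path": "mistralai/Mistral-7B-v0.1",
    "load_in_4bit": True,
})  # passes silently
```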
Outputs
| Name | Type | Description |
|---|---|---|
| return | AutoModelForCausalLM | 4-bit quantized model on GPU, ready for LoRA adapter injection by SFTTrainer/DPOTrainer |
Usage Examples
QLoRA YAML Config
# From recipes/zephyr-7b-beta/sft/config_qlora.yaml
model_name_or_path: mistralai/Mistral-7B-v0.1
torch_dtype: bfloat16
attn_implementation: flash_attention_2
load_in_4bit: true
use_peft: true
lora_r: 16
lora_alpha: 16
lora_dropout: 0.05
lora_target_modules:
- q_proj
- k_proj
- v_proj
- o_proj
- gate_proj
- up_proj
- down_proj
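The lora_* keys in the recipe correspond to constructor arguments of peft.LoraConfig (r, lora_alpha, lora_dropout, target_modules). A small hypothetical translation helper makes the mapping explicit:

```python
def lora_kwargs_from_recipe(recipe: dict) -> dict:
    # Hypothetical helper: renames the recipe's lora_* keys to the
    # keyword names peft.LoraConfig expects.
    return {
        "r": recipe["lora_r"],
        "lora_alpha": recipe["lora_alpha"],
        "lora_dropout": recipe["lora_dropout"],
        "target_modules": recipe["lora_target_modules"],
    }


recipe = {
    "lora_r": 16,
    "lora_alpha": 16,
    "lora_dropout": 0.05,
    "lora_target_modules": ["q_proj", "k_proj", "v_proj", "o_proj",
                            "gate_proj", "up_proj", "down_proj"],
}
print(lora_kwargs_from_recipe(recipe)["r"])  # 16
```

In practice this translation is handled for you by TRL's config parsing; the sketch only shows which YAML field feeds which LoRA parameter.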
Programmatic Usage
from alignment import get_model
# model_args populated from QLoRA YAML config
# model_args.load_in_4bit = True triggers quantization path
model = get_model(model_args, training_args)
# Model is now 4-bit quantized
# LoRA adapters are injected by the trainer via peft_config
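A back-of-envelope estimate shows why this path matters: at 4 bits per weight, a roughly 7.24B-parameter model (Mistral-7B's approximate size, an assumption here) needs about 3.6 GB for its frozen base weights, versus roughly 14.5 GB in bfloat16. The sketch below counts weights only and ignores activations, KV cache, LoRA adapters, optimizer state, and bitsandbytes quantization constants:

```python
def weight_memory_gb(n_params: float, bits: int) -> float:
    # Weights-only estimate: parameters * bits per parameter, in GB.
    # Excludes activations, KV cache, LoRA adapters, optimizer state,
    # and quantization metadata.
    return n_params * bits / 8 / 1e9


N = 7.24e9  # approximate Mistral-7B parameter count (assumption)
print(round(weight_memory_gb(N, 4), 2))   # 3.62
print(round(weight_memory_gb(N, 16), 2))  # 14.48
```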