Implementation: LMSYS FastChat AutoModelForCausalLM.from_pretrained QLoRA
| Knowledge Sources | |
|---|---|
| Domains | NLP, Training, Quantization |
| Last Updated | 2026-02-07 14:00 GMT |
Overview
Wrapper around AutoModelForCausalLM.from_pretrained() with BitsAndBytesConfig for loading models in 4-bit NF4 quantization, as used in FastChat's LoRA/QLoRA training script.
Description
In fastchat/train/train_lora.py, the model loading logic conditionally applies 4-bit quantization based on the q_lora flag from LoraArguments. When q_lora=True, a BitsAndBytesConfig is constructed with NF4 quantization, double quantization enabled, and the compute dtype derived from the training precision flags (--fp16 or --bf16). The device map is set to route the model to the local GPU rank for DDP compatibility. When q_lora=False, the quantization_config parameter is None and the model loads in full precision.
The code also includes a guard that logs a warning if FSDP or DeepSpeed ZeRO Stage 3 is enabled alongside QLoRA, as these distributed strategies are incompatible with quantized weight storage.
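The device-map and compute-dtype selection described above can be condensed into a small, torch-free sketch. The helper names are illustrative (not part of FastChat), and dtypes are returned as name strings here; the real script returns torch.float16 / torch.bfloat16 / torch.float32.

```python
import os
from typing import Optional


def resolve_device_map(q_lora: bool) -> Optional[dict]:
    # Under QLoRA with DDP (WORLD_SIZE > 1), pin the entire model ("")
    # to the local GPU rank; otherwise leave placement to from_pretrained.
    world_size = int(os.environ.get("WORLD_SIZE", 1))
    ddp = world_size != 1
    if q_lora and ddp:
        return {"": int(os.environ.get("LOCAL_RANK") or 0)}
    return None


def resolve_compute_dtype(fp16: bool, bf16: bool) -> str:
    # --fp16 takes precedence over --bf16, mirroring the nested
    # conditional in train_lora.py; default is full precision.
    if fp16:
        return "float16"
    if bf16:
        return "bfloat16"
    return "float32"
```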
Usage
Use this pattern when loading a base model for QLoRA fine-tuning in FastChat, particularly when training large models on memory-constrained hardware.
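To see why 4-bit loading matters on memory-constrained hardware, here is a rough back-of-envelope estimate for a 7B-parameter model. It counts weight storage only, ignoring activations, LoRA adapters, optimizer state, and NF4 quantization metadata overhead:

```python
def approx_weight_gib(n_params, bits_per_param):
    # parameters * bits -> bytes -> GiB; a coarse lower bound on weight memory
    return n_params * bits_per_param / 8 / 1024**3


params_7b = 7e9
print(round(approx_weight_gib(params_7b, 16), 1))  # fp16 weights: ~13.0 GiB
print(round(approx_weight_gib(params_7b, 4), 1))   # 4-bit NF4 weights: ~3.3 GiB
```

The roughly 4x reduction is what allows a 7B base model to fit comfortably on a single consumer GPU during QLoRA fine-tuning.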
Code Reference
Source Location
- Repository: FastChat
- File: fastchat/train/train_lora.py (lines 118-146, model loading with quantization)
- File: fastchat/train/train_lora.py (lines 55-65, LoraArguments dataclass)
Signature
```python
# From fastchat/train/train_lora.py:train(), lines 118-146
device_map = None
world_size = int(os.environ.get("WORLD_SIZE", 1))
ddp = world_size != 1
if lora_args.q_lora:
    device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)} if ddp else None
    if len(training_args.fsdp) > 0 or deepspeed.is_deepspeed_zero3_enabled():
        logging.warning(
            "FSDP and ZeRO3 are both currently incompatible with QLoRA."
        )

compute_dtype = (
    torch.float16
    if training_args.fp16
    else (torch.bfloat16 if training_args.bf16 else torch.float32)
)

model = transformers.AutoModelForCausalLM.from_pretrained(
    model_args.model_name_or_path,
    cache_dir=training_args.cache_dir,
    device_map=device_map,
    quantization_config=BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=compute_dtype,
    )
    if lora_args.q_lora
    else None,
)
```
Import
```python
import logging
import os

import deepspeed
import torch
import transformers
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
```
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| model_args.model_name_or_path | str | Yes | HuggingFace Hub model ID or local path to a pretrained model directory |
| training_args.cache_dir | str or None | No | Directory for caching downloaded model files; default: None |
| lora_args.q_lora | bool | No | Whether to enable 4-bit QLoRA quantization; default: False |
| training_args.fp16 | bool | No | Use FP16 mixed precision (sets compute_dtype to torch.float16); default: False |
| training_args.bf16 | bool | No | Use BF16 mixed precision (sets compute_dtype to torch.bfloat16); default: False |
| device_map | dict or None | No | Device placement mapping; {"": LOCAL_RANK} for QLoRA+DDP, None otherwise |
Outputs
| Name | Type | Description |
|---|---|---|
| model | PreTrainedModel | Loaded causal language model, optionally quantized to 4-bit NF4, ready for LoRA adapter wrapping |
Usage Examples
QLoRA Loading (4-bit NF4)
```python
import torch
import transformers
from transformers import BitsAndBytesConfig

compute_dtype = torch.bfloat16

model = transformers.AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    cache_dir="/data/model_cache",
    device_map={"": 0},
    quantization_config=BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=compute_dtype,
    ),
)
```
Standard LoRA Loading (No Quantization)
```python
import transformers

model = transformers.AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    cache_dir="/data/model_cache",
    device_map=None,
    quantization_config=None,
)
```
Full Training Command (QLoRA with DeepSpeed)
```shell
deepspeed fastchat/train/train_lora.py \
    --model_name_or_path meta-llama/Llama-2-7b-hf \
    --data_path data/dummy_conversation.json \
    --bf16 True \
    --output_dir output_qlora \
    --q_lora True \
    --deepspeed playground/deepspeed_config_s2.json
```