
Implementation:Lm sys FastChat AutoModelForCausalLM From Pretrained QLoRA

From Leeroopedia


Knowledge Sources
Domains NLP, Training, Quantization
Last Updated 2026-02-07 14:00 GMT

Overview

Pattern for calling AutoModelForCausalLM.from_pretrained() with a BitsAndBytesConfig to load a model in 4-bit NF4 quantization, as used in FastChat's LoRA/QLoRA training script.

Description

In fastchat/train/train_lora.py, the model loading logic applies 4-bit quantization only when the q_lora flag from LoraArguments is set. When q_lora=True, a BitsAndBytesConfig is constructed with NF4 quantization, double quantization enabled, and a compute dtype derived from the training precision flags: torch.float16 for --fp16, torch.bfloat16 for --bf16, and torch.float32 if neither is set. In multi-process (DDP) runs, i.e. when WORLD_SIZE is not 1, the device map pins the entire model to the GPU given by LOCAL_RANK; single-process runs pass device_map=None. When q_lora=False, both quantization_config and device_map are None, so the model loads unquantized; because no torch_dtype is passed, transformers' default dtype applies.
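The quantization branch described above can be mirrored in a few lines of plain Python. This is a minimal sketch, not FastChat code: the helper names `select_compute_dtype` and `quantization_kwargs` are hypothetical, and dtype names are plain strings standing in for the `torch` dtypes so the logic can be checked without a GPU or the `transformers` package.

```python
from typing import Optional


def select_compute_dtype(fp16: bool, bf16: bool) -> str:
    """Mirror FastChat's compute_dtype expression: fp16 first, then bf16, else float32.

    Strings stand in for torch.float16 / torch.bfloat16 / torch.float32 (assumption).
    """
    if fp16:
        return "float16"
    if bf16:
        return "bfloat16"
    return "float32"


def quantization_kwargs(q_lora: bool, fp16: bool, bf16: bool) -> Optional[dict]:
    """Hypothetical helper: the keyword arguments FastChat passes to BitsAndBytesConfig,
    or None when QLoRA is disabled (the model then loads unquantized)."""
    if not q_lora:
        return None
    return {
        "load_in_4bit": True,               # 4-bit weight storage
        "bnb_4bit_use_double_quant": True,  # quantize the quantization constants too
        "bnb_4bit_quant_type": "nf4",       # NormalFloat4 data type
        "bnb_4bit_compute_dtype": select_compute_dtype(fp16, bf16),
    }


print(quantization_kwargs(q_lora=True, fp16=False, bf16=True))
print(quantization_kwargs(q_lora=False, fp16=False, bf16=True))
```

In the real script the dtype would be an actual `torch` dtype and the settings would be passed as `BitsAndBytesConfig(...)` keyword arguments; keeping the decision in one function makes the `q_lora=False` path (no config at all) explicit.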

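The code also includes a guard for QLoRA runs: if FSDP or DeepSpeed ZeRO Stage 3 is enabled, it logs the warning "FSDP and ZeRO3 are both currently incompatible with QLoRA.", as these sharding strategies are incompatible with quantized weight storage. The guard only warns; it does not abort training, so the launch configuration itself must avoid these combinations.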
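The guard's condition can be isolated as a small function for testing. This is a minimal sketch under stated assumptions: the name `warn_if_qlora_incompatible` is hypothetical, `fsdp` is assumed to be a list of FSDP option strings (as `TrainingArguments.fsdp` holds after parsing, e.g. `["full_shard", "auto_wrap"]`), and `zero3_enabled` stands in for the result of `is_deepspeed_zero3_enabled()`.

```python
import logging
from typing import Sequence


def warn_if_qlora_incompatible(q_lora: bool, fsdp: Sequence[str], zero3_enabled: bool) -> bool:
    """Return True (and log a warning) when QLoRA is combined with FSDP or ZeRO-3.

    Like FastChat's guard, this only warns; the caller keeps running.
    """
    if q_lora and (len(fsdp) > 0 or zero3_enabled):
        logging.warning("FSDP and ZeRO3 are both currently incompatible with QLoRA.")
        return True
    return False


warn_if_qlora_incompatible(True, ["full_shard", "auto_wrap"], False)  # logs the warning
```

Returning a boolean instead of only logging lets a wrapper script turn the warning into a hard failure if that is preferred.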

Usage

Use this pattern when loading a base model for QLoRA fine-tuning in FastChat, particularly when training large models on memory-constrained hardware.
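A back-of-the-envelope calculation shows why 4-bit loading matters on memory-constrained hardware. This sketch covers weight storage only and uses the block-size accounting from the QLoRA paper (one fp32 scale per 64 weights; with double quantization, 8-bit scales plus one fp32 constant per 256 blocks). It ignores activations, optimizer state, LoRA adapter weights, and any modules bitsandbytes typically leaves unquantized (such as the output head); the 7e9 parameter count is an assumed, roughly 7B-sized model.

```python
def weight_gigabytes(n_params: float, bits_per_param: float) -> float:
    """Approximate weight storage in GB (10**9 bytes)."""
    return n_params * bits_per_param / 8 / 1e9


BLOCK = 64          # weights per quantization block
SECOND_BLOCK = 256  # first-level scales per second-level block (double quantization)

overhead_single = 32 / BLOCK                               # fp32 scale per block: 0.5 bits/param
overhead_double = 8 / BLOCK + 32 / (BLOCK * SECOND_BLOCK)  # about 0.127 bits/param

n_params = 7e9  # assumption: a roughly 7B-parameter model
print(f"fp16:             {weight_gigabytes(n_params, 16):.2f} GB")
print(f"nf4:              {weight_gigabytes(n_params, 4 + overhead_single):.2f} GB")
print(f"nf4 + double q.:  {weight_gigabytes(n_params, 4 + overhead_double):.2f} GB")
```

Under these assumptions the weights shrink from 14 GB in fp16 to roughly 3.6 GB with NF4 and double quantization, which is what makes single-GPU fine-tuning of such models feasible.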

Code Reference

Source Location

  • Repository: FastChat
  • File: fastchat/train/train_lora.py (lines 118-146, model loading with quantization)
  • File: fastchat/train/train_lora.py (lines 55-65, LoraArguments dataclass)

Signature

# From fastchat/train/train_lora.py:train(), lines 118-146

device_map = None
world_size = int(os.environ.get("WORLD_SIZE", 1))
ddp = world_size != 1
if lora_args.q_lora:
    device_map = {"": int(os.environ.get("LOCAL_RANK") or 0)} if ddp else None
    if len(training_args.fsdp) > 0 or deepspeed.is_deepspeed_zero3_enabled():
        logging.warning(
            "FSDP and ZeRO3 are both currently incompatible with QLoRA."
        )

compute_dtype = (
    torch.float16
    if training_args.fp16
    else (torch.bfloat16 if training_args.bf16 else torch.float32)
)

model = transformers.AutoModelForCausalLM.from_pretrained(
    model_args.model_name_or_path,
    cache_dir=training_args.cache_dir,
    device_map=device_map,
    quantization_config=BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_dtype=compute_dtype,
    )
    if lora_args.q_lora
    else None,
)
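The `device_map` computation depends on two environment variables that distributed launchers such as `torchrun` typically set per process. The sketch below isolates it as a pure function taking an environment mapping instead of reading `os.environ`, so each case can be checked directly; the name `resolve_device_map` is hypothetical. The mapping `{"": rank}` asks transformers to place the whole model (the empty-string key denotes the root module) on that GPU.

```python
from typing import Mapping, Optional


def resolve_device_map(q_lora: bool, env: Mapping[str, str]) -> Optional[dict]:
    """Mirror FastChat's device_map logic (hypothetical helper name)."""
    world_size = int(env.get("WORLD_SIZE", 1))
    ddp = world_size != 1
    if q_lora and ddp:
        # `or 0` also covers LOCAL_RANK being present but empty ("").
        return {"": int(env.get("LOCAL_RANK") or 0)}
    # Single-process runs and non-QLoRA runs pass device_map=None.
    return None


print(resolve_device_map(True, {"WORLD_SIZE": "4", "LOCAL_RANK": "2"}))
print(resolve_device_map(True, {}))
```

Note that `int(env.get("LOCAL_RANK") or 0)` is slightly more forgiving than `int(env.get("LOCAL_RANK", 0))`, which would raise on an empty string.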

Import

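import logging
import os

import torch
import transformers
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from transformers import deepspeed  # provides is_deepspeed_zero3_enabled(); recent releases expose it as transformers.integrations.is_deepspeed_zero3_enabled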

I/O Contract

Inputs

Name | Type | Required | Description
model_args.model_name_or_path | str | Yes | Hugging Face Hub model ID or local path to a pretrained model directory
training_args.cache_dir | str or None | No | Directory for caching downloaded model files; default: None
lora_args.q_lora | bool | No | Enable 4-bit QLoRA quantization; default: False
training_args.fp16 | bool | No | FP16 mixed precision; sets compute_dtype to torch.float16; default: False
training_args.bf16 | bool | No | BF16 mixed precision; sets compute_dtype to torch.bfloat16 (torch.float32 if neither flag is set); default: False
device_map | dict or None | Derived | Not passed by the user: {"": int(LOCAL_RANK)} when q_lora is True and WORLD_SIZE != 1, None otherwise

Outputs

Name | Type | Description
model | PreTrainedModel | Loaded causal language model, optionally quantized to 4-bit NF4, ready for LoRA adapter wrapping

Usage Examples

QLoRA Loading (4-bit NF4)

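import torch
import transformers
from transformers import BitsAndBytesConfig

# Match the training precision flag (--bf16 True); requires bitsandbytes and a CUDA GPU.
compute_dtype = torch.bfloat16

model = transformers.AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    cache_dir="/data/model_cache",
    device_map={"": 0},  # place the whole model on GPU 0
    quantization_config=BitsAndBytesConfig(
        load_in_4bit=True,                     # store weights in 4 bits
        bnb_4bit_use_double_quant=True,        # also quantize the quantization constants
        bnb_4bit_quant_type="nf4",             # NormalFloat4 data type
        bnb_4bit_compute_dtype=compute_dtype,  # dtype used for matmuls at run time
    ),
)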

Standard LoRA Loading (No Quantization)

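import transformers

# No quantization_config: weights load unquantized. No torch_dtype is set,
# so transformers' default dtype applies.
model = transformers.AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    cache_dir="/data/model_cache",
    device_map=None,
    quantization_config=None,
)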

Full Training Command (QLoRA with DeepSpeed)

deepspeed fastchat/train/train_lora.py \
    --model_name_or_path meta-llama/Llama-2-7b-hf \
    --data_path data/dummy_conversation.json \
    --bf16 True \
    --output_dir output_qlora \
    --q_lora True \
    --deepspeed playground/deepspeed_config_s2.json
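Because the in-script guard only warns, a pre-flight check on the DeepSpeed config can catch a ZeRO-3 setup before launching a QLoRA run. This is a minimal sketch assuming the config is plain JSON with an integer `zero_optimization.stage` field (DeepSpeed's standard layout, where an absent section means stage 0); the helper names are hypothetical, and nothing here is claimed about `playground/deepspeed_config_s2.json` beyond what its name suggests.

```python
import json


def zero_stage_from_json(config_text: str) -> int:
    """Return the ZeRO stage declared in a DeepSpeed JSON config (0 when absent)."""
    config = json.loads(config_text)
    return int(config.get("zero_optimization", {}).get("stage", 0))


def qlora_safe(config_text: str) -> bool:
    """Hypothetical pre-flight check: FastChat warns on ZeRO-3 with QLoRA, so reject stage 3."""
    return zero_stage_from_json(config_text) != 3


print(qlora_safe('{"zero_optimization": {"stage": 2}}'))
print(qlora_safe('{"zero_optimization": {"stage": 3}}'))
```

To check a real file, pass its contents, for example `qlora_safe(pathlib.Path("playground/deepspeed_config_s2.json").read_text())`. A non-integer `stage` value would raise, which is acceptable for a pre-flight check.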

Related Pages

Implements Principle
