Heuristic: Haotian Liu's LLaVA: MM Projector Exclusion from Quantization
| Knowledge Sources | |
|---|---|
| Domains | Optimization, Deep_Learning, Computer_Vision |
| Last Updated | 2026-02-13 23:00 GMT |
Overview
The multimodal projector (`mm_projector`) must be excluded from int8 quantization and kept at full or half precision to preserve vision-language alignment quality.
Description
When using 4-bit or 8-bit quantization for LLaVA training or inference, the multimodal projector — the critical bridge between the CLIP vision encoder and the LLaMA language model — is explicitly excluded from quantization. The `BitsAndBytesConfig` sets `llm_int8_skip_modules=["mm_projector"]` to prevent the projector from being quantized. Additionally, during quantized training the projector is cast to the compute dtype and placed on the correct device explicitly. During LoRA training, the `find_all_linear_names` function also excludes `mm_projector`, `vision_tower`, and `vision_resampler` from LoRA target modules.
Usage
Apply this heuristic when configuring quantized training (QLoRA with `--bits 4` or `--bits 8`) or quantized inference (`--load-4bit` or `--load-8bit`). The `mm_projector` must remain at higher precision to maintain the quality of the vision-to-language feature mapping.
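As a minimal sketch of the inference-side configuration (assuming `transformers` with bitsandbytes support is installed; the comment about where the config is passed reflects the standard `from_pretrained` API, not LLaVA-specific code):

```python
from transformers import BitsAndBytesConfig

# 8-bit quantization that skips the multimodal projector, mirroring the
# training-side setting described above.
bnb_config = BitsAndBytesConfig(
    load_in_8bit=True,
    llm_int8_skip_modules=["mm_projector"],  # keep projector in fp16/bf16
    llm_int8_threshold=6.0,
)
# Pass as `quantization_config=bnb_config` to `AutoModelForCausalLM.from_pretrained(...)`.
```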
The Insight (Rule of Thumb)
- Action: Always set `llm_int8_skip_modules=["mm_projector"]` in `BitsAndBytesConfig`. Exclude `mm_projector`, `vision_tower`, and `vision_resampler` from LoRA target modules.
- Value: Preserves vision-language alignment while still quantizing the bulk of the language model.
- Trade-off: Slightly higher memory usage than full quantization, but critical for maintaining model quality.
- Additional: During quantized training, norm layers must be kept in float32 for stability, and `lm_head`/`embed_tokens` should match the training precision (bf16/fp16).
Reasoning
The multimodal projector is a small MLP (2-layer with GELU activation in V1.5) that projects CLIP vision features into the language model's embedding space. Quantizing this component would degrade the quality of the vision-language mapping, as the projector's weights need higher precision to accurately transform high-dimensional visual features. The vision tower (CLIP) is similarly excluded because it produces the features that the projector transforms. The `lm_head` is excluded from LoRA targets because it is shared with the embedding layer in 16-bit mode and modifying it with LoRA could destabilize output generation.
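The projector's structure can be sketched as follows. This is an illustrative reconstruction of the V1.5 `mlp2x_gelu` shape, not the LLaVA source; the dimensions (1024 for CLIP ViT-L/14 features, 4096 for LLaMA-7B embeddings, 576 patch tokens) are the commonly used values and serve only as an example:

```python
import torch
import torch.nn as nn

clip_dim, llm_dim = 1024, 4096  # illustrative: CLIP ViT-L/14 -> LLaMA-7B

# 2-layer MLP with GELU, as in LLaVA-1.5's "mlp2x_gelu" projector type
mm_projector = nn.Sequential(
    nn.Linear(clip_dim, llm_dim),
    nn.GELU(),
    nn.Linear(llm_dim, llm_dim),
)

# Kept in half precision rather than int8, per this heuristic
mm_projector = mm_projector.to(torch.bfloat16)

patches = torch.randn(1, 576, clip_dim, dtype=torch.bfloat16)  # 576 patch tokens
out = mm_projector(patches)  # shape (1, 576, 4096), dtype bfloat16
```

At roughly 21M parameters, the projector is a tiny fraction of a 7B-parameter model, so excluding it from quantization costs little memory relative to the alignment quality it preserves.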
Code Evidence
Quantization config excluding mm_projector from `train.py:798-813`:
```python
if training_args.bits in [4, 8]:
    from transformers import BitsAndBytesConfig
    bnb_model_from_pretrained_args.update(dict(
        device_map={"": training_args.device},
        load_in_4bit=training_args.bits == 4,
        load_in_8bit=training_args.bits == 8,
        quantization_config=BitsAndBytesConfig(
            load_in_4bit=training_args.bits == 4,
            load_in_8bit=training_args.bits == 8,
            llm_int8_skip_modules=["mm_projector"],
            llm_int8_threshold=6.0,
            llm_int8_has_fp16_weight=False,
            bnb_4bit_compute_dtype=compute_dtype,
            bnb_4bit_use_double_quant=training_args.double_quant,
            bnb_4bit_quant_type=training_args.quant_type
        )
    ))
```
LoRA target exclusion from `train.py:169-182`:
```python
def find_all_linear_names(model):
    cls = torch.nn.Linear
    lora_module_names = set()
    multimodal_keywords = ['mm_projector', 'vision_tower', 'vision_resampler']
    for name, module in model.named_modules():
        if any(mm_keyword in name for mm_keyword in multimodal_keywords):
            continue
        if isinstance(module, cls):
            names = name.split('.')
            lora_module_names.add(names[0] if len(names) == 1 else names[-1])
    if 'lm_head' in lora_module_names:  # needed for 16-bit
        lora_module_names.remove('lm_head')
    return list(lora_module_names)
```
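The filtering logic can be illustrated without a model instance. The sketch below (a hypothetical helper, not LLaVA code) applies the same keyword filter and suffix extraction to a pre-collected list of `nn.Linear` module names:

```python
MULTIMODAL_KEYWORDS = ['mm_projector', 'vision_tower', 'vision_resampler']

def lora_target_names(linear_module_names):
    """Mimic find_all_linear_names given the dotted names of all Linear modules."""
    targets = set()
    for name in linear_module_names:
        # Skip anything under the vision stack or the projector
        if any(kw in name for kw in MULTIMODAL_KEYWORDS):
            continue
        parts = name.split('.')
        # LoRA targets are keyed by the leaf module name (e.g. 'q_proj')
        targets.add(parts[0] if len(parts) == 1 else parts[-1])
    targets.discard('lm_head')  # excluded for 16-bit stability
    return sorted(targets)

names = [
    'model.layers.0.self_attn.q_proj',
    'model.layers.0.mlp.gate_proj',
    'model.mm_projector.0',
    'model.vision_tower.encoder.layers.0.mlp.fc1',
    'lm_head',
]
print(lora_target_names(names))  # ['gate_proj', 'q_proj']
```

Note that only the language model's projection layers survive the filter: the projector, the vision tower, and `lm_head` all drop out.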
Norm layers kept in float32 during quantized training from `train.py:946-957`:
```python
if training_args.bits in [4, 8]:
    from peft.tuners.lora import LoraLayer
    for name, module in model.named_modules():
        if isinstance(module, LoraLayer):
            if training_args.bf16:
                module = module.to(torch.bfloat16)
        if 'norm' in name:
            module = module.to(torch.float32)
        if 'lm_head' in name or 'embed_tokens' in name:
            if hasattr(module, 'weight'):
                if training_args.bf16 and module.weight.dtype == torch.float32:
                    module = module.to(torch.bfloat16)
```
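The casting rules above can be summarized as a small decision function. This is a hypothetical helper for illustration only (the name and return strings are invented, not part of the LLaVA codebase):

```python
def quantized_train_dtype(name, is_lora=False, bf16=True):
    """Return the dtype a module's weights should hold during quantized training."""
    compute = 'bfloat16' if bf16 else 'float16'
    if 'norm' in name:
        return 'float32'    # norm layers stay fp32 for numerical stability
    if is_lora or 'lm_head' in name or 'embed_tokens' in name:
        return compute      # match the training compute precision
    return 'int8/int4'      # remaining backbone weights are quantized

print(quantized_train_dtype('model.layers.0.input_layernorm'))  # 'float32'
print(quantized_train_dtype('lm_head'))                         # 'bfloat16'
```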