
Heuristic: Haotian Liu LLaVA Quantization MM Projector Exclusion

From Leeroopedia
Knowledge Sources
Domains Optimization, Deep_Learning, Computer_Vision
Last Updated 2026-02-13 23:00 GMT

Overview

The multimodal projector (`mm_projector`) must be excluded from 4-bit and 8-bit quantization and kept at full or half precision to preserve vision-language alignment quality.

Description

When using 4-bit or 8-bit quantization for LLaVA training or inference, the multimodal projector — the critical bridge between the CLIP vision encoder and the LLaMA language model — is explicitly excluded from quantization. The `BitsAndBytesConfig` sets `llm_int8_skip_modules=["mm_projector"]` to prevent the projector from being quantized. Additionally, during quantized training the projector is cast to the compute dtype and placed on the correct device explicitly. During LoRA training, the `find_all_linear_names` function also excludes `mm_projector`, `vision_tower`, and `vision_resampler` from LoRA target modules.

Usage

Apply this heuristic when configuring quantized training (QLoRA with `--bits 4` or `--bits 8`) or quantized inference (`--load-4bit` or `--load-8bit`). The `mm_projector` must remain in higher precision to maintain the quality of vision-to-language feature mapping.
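As a minimal sketch, the quantization settings can be assembled as a plain dict of keyword arguments mirroring the `BitsAndBytesConfig` call shown in the Code Evidence section below (the helper function name and default values here are illustrative, not part of the LLaVA codebase):

```python
# Sketch: quantization kwargs for a hypothetical 4-bit QLoRA run.
# Mirrors the BitsAndBytesConfig arguments in the evidence below;
# "mm_projector" is the only module skipped, keeping the projector
# in the compute dtype.
def make_quant_kwargs(bits, compute_dtype="bfloat16",
                      double_quant=True, quant_type="nf4"):
    assert bits in (4, 8), "only 4- and 8-bit quantization apply here"
    return {
        "load_in_4bit": bits == 4,
        "load_in_8bit": bits == 8,
        "llm_int8_skip_modules": ["mm_projector"],  # never quantize the projector
        "llm_int8_threshold": 6.0,
        "llm_int8_has_fp16_weight": False,
        "bnb_4bit_compute_dtype": compute_dtype,
        "bnb_4bit_use_double_quant": double_quant,
        "bnb_4bit_quant_type": quant_type,
    }

kwargs = make_quant_kwargs(bits=4)
```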

The Insight (Rule of Thumb)

  • Action: Always set `llm_int8_skip_modules=["mm_projector"]` in `BitsAndBytesConfig`. Exclude `mm_projector`, `vision_tower`, and `vision_resampler` from LoRA target modules.
  • Value: Preserves vision-language alignment while still quantizing the bulk of the language model.
  • Trade-off: Slightly higher memory usage than full quantization, but critical for maintaining model quality.
  • Additional: During quantized training, norm layers must be kept in float32 for stability, and `lm_head`/`embed_tokens` should match the training precision (bf16/fp16).

Reasoning

The multimodal projector is a small MLP (2-layer with GELU activation in V1.5) that projects CLIP vision features into the language model's embedding space. Quantizing this component would degrade the quality of the vision-language mapping, as the projector's weights need higher precision to accurately transform high-dimensional visual features. The vision tower (CLIP) is similarly excluded because it produces the features that the projector transforms. The `lm_head` is excluded from LoRA targets because it is shared with the embedding layer in 16-bit mode and modifying it with LoRA could destabilize output generation.
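For concreteness, the V1.5 projector described above can be sketched as a small PyTorch module. The layer widths are illustrative assumptions (1024 for CLIP ViT-L/14 patch features, 4096 for a LLaMA-7B embedding space), not values read from any config:

```python
import torch
import torch.nn as nn

# Sketch of the 2-layer GELU projector described above; the widths
# (1024 -> 4096) are illustrative assumptions matching CLIP ViT-L
# features and a LLaMA-7B embedding size.
def build_mm_projector(vision_dim=1024, hidden_dim=4096):
    return nn.Sequential(
        nn.Linear(vision_dim, hidden_dim),
        nn.GELU(),
        nn.Linear(hidden_dim, hidden_dim),
    )

projector = build_mm_projector()
# One image with 576 patch tokens of CLIP features -> LLM embedding space
vision_feats = torch.randn(1, 576, 1024)
tokens = projector(vision_feats)
print(tokens.shape)  # torch.Size([1, 576, 4096])
```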

Code Evidence

Quantization config excluding mm_projector from `train.py:798-813`:

if training_args.bits in [4, 8]:
    from transformers import BitsAndBytesConfig
    bnb_model_from_pretrained_args.update(dict(
        device_map={"": training_args.device},
        load_in_4bit=training_args.bits == 4,
        load_in_8bit=training_args.bits == 8,
        quantization_config=BitsAndBytesConfig(
            load_in_4bit=training_args.bits == 4,
            load_in_8bit=training_args.bits == 8,
            llm_int8_skip_modules=["mm_projector"],
            llm_int8_threshold=6.0,
            llm_int8_has_fp16_weight=False,
            bnb_4bit_compute_dtype=compute_dtype,
            bnb_4bit_use_double_quant=training_args.double_quant,
            bnb_4bit_quant_type=training_args.quant_type
        )
    ))

LoRA target exclusion from `train.py:169-182`:

def find_all_linear_names(model):
    cls = torch.nn.Linear
    lora_module_names = set()
    multimodal_keywords = ['mm_projector', 'vision_tower', 'vision_resampler']
    for name, module in model.named_modules():
        if any(mm_keyword in name for mm_keyword in multimodal_keywords):
            continue
        if isinstance(module, cls):
            names = name.split('.')
            lora_module_names.add(names[0] if len(names) == 1 else names[-1])

    if 'lm_head' in lora_module_names: # needed for 16-bit
        lora_module_names.remove('lm_head')
    return list(lora_module_names)
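The exclusion logic above can be exercised without loading a model. The module names below are hypothetical stand-ins for what `model.named_modules()` would yield on a LLaVA checkpoint:

```python
# Hypothetical module names standing in for model.named_modules().
module_names = [
    "model.layers.0.self_attn.q_proj",
    "model.layers.0.self_attn.v_proj",
    "model.layers.0.mlp.gate_proj",
    "model.mm_projector.0",                          # excluded: multimodal keyword
    "model.vision_tower.encoder.layers.0.mlp.fc1",   # excluded: multimodal keyword
    "lm_head",                                       # excluded for 16-bit stability
]

multimodal_keywords = ["mm_projector", "vision_tower", "vision_resampler"]

lora_targets = set()
for name in module_names:
    if any(kw in name for kw in multimodal_keywords):
        continue
    parts = name.split(".")
    lora_targets.add(parts[0] if len(parts) == 1 else parts[-1])

lora_targets.discard("lm_head")
print(sorted(lora_targets))  # ['gate_proj', 'q_proj', 'v_proj']
```

Only language-model linear layers survive as LoRA targets; every projector- and vision-side module is filtered out by name.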

Norm layers kept in float32 during quantized training from `train.py:946-957`:

if training_args.bits in [4, 8]:
    from peft.tuners.lora import LoraLayer
    for name, module in model.named_modules():
        if isinstance(module, LoraLayer):
            if training_args.bf16:
                module = module.to(torch.bfloat16)
        if 'norm' in name:
            module = module.to(torch.float32)
        if 'lm_head' in name or 'embed_tokens' in name:
            if hasattr(module, 'weight'):
                if training_args.bf16 and module.weight.dtype == torch.float32:
                    module = module.to(torch.bfloat16)
