Heuristic: Bitsandbytes Compute Dtype Mismatch Warning
| Knowledge Sources | |
|---|---|
| Domains | Optimization, Quantization |
| Last Updated | 2026-02-07 13:00 GMT |
## Overview
Performance warning: passing float16 input to `Linear4bit` with the default `compute_dtype=float32` causes slow inference; set `bnb_4bit_compute_dtype=torch.float16` or `torch.bfloat16` for speed.
## Description
`Linear4bit` performs dequantization and matmul in a compute dtype that defaults to float32. When the input tensor is already float16, this default forces an unnecessary fp16→fp32 promotion that slows computation by roughly 2-3x. Bitsandbytes detects this mismatch and emits a warning at runtime. The fix is to explicitly set the compute dtype to match the input dtype, or to use bfloat16, which is both fast and numerically stable. For 8-bit (`Linear8bitLt`), a separate warning is emitted when inputs are not float16, because the INT8 quantization pathway requires casting to fp16.
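The promotion itself can be illustrated without bitsandbytes at all. A minimal pure-PyTorch sketch (the `w` tensor merely stands in for a dequantized 4-bit weight; it is not bitsandbytes code):

```python
import torch

# Stand-ins: a half-precision activation and a "dequantized" weight.
x = torch.randn(1, 16, dtype=torch.float16)
w = torch.randn(16, 16, dtype=torch.float16)

# compute_dtype=float32 (the problematic default): both operands are
# upcast before the matmul, so the kernel runs entirely in fp32 and
# bypasses the fp16/bf16 tensor-core paths.
y_slow = x.float() @ w.float()

# compute_dtype=bfloat16 (the recommended setting): no fp32 promotion.
y_fast = x.bfloat16() @ w.bfloat16()

print(y_slow.dtype)  # torch.float32
print(y_fast.dtype)  # torch.bfloat16
```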
## Usage
Apply this heuristic when using 4-bit quantized models (`BitsAndBytesConfig` with `load_in_4bit=True`) and observing unexpectedly slow inference or training. Check whether your input dtype matches the compute dtype setting.
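A configuration sketch of the fix using the Hugging Face `BitsAndBytesConfig` (the model id below is a placeholder, not taken from the source; loading requires a GPU-capable environment):

```python
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# Explicitly pin the 4-bit compute dtype so it matches the fast path
# instead of falling back to the float32 default.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.bfloat16,  # or torch.float16 if bf16 is unavailable
)

model = AutoModelForCausalLM.from_pretrained(
    "some-org/some-model",  # placeholder model id
    quantization_config=bnb_config,
)
```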
## The Insight (Rule of Thumb)
- Action: Set `bnb_4bit_compute_dtype=torch.bfloat16` in BitsAndBytesConfig, or `torch.float16` if bf16 is unavailable.
- Value: bfloat16 is preferred (fast + stable). float16 is acceptable. float32 is slow but most precise.
- Trade-off: Lower precision compute dtypes (fp16/bf16) are 2-3x faster but may introduce minor numerical differences. float32 is the safest but slowest option.
- Single-batch detection: The warning distinguishes single-batch inference (`x.numel() == x.shape[-1]`) from multi-batch input in order to tailor the message ("slow inference" vs. "slow inference or training speed"); both cases trigger a warning.
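The single-batch test can be sketched in isolation (`is_single_batch` is a hypothetical helper mirroring the condition, not a bitsandbytes API):

```python
import torch

def is_single_batch(x: torch.Tensor) -> bool:
    # A lone activation vector (e.g. shape (1, 1, hidden)) has exactly
    # hidden elements, so numel() equals the last dimension.
    return x.numel() == x.shape[-1]

single = torch.zeros(1, 1, 4096, dtype=torch.float16)    # one token, one sequence
batched = torch.zeros(8, 128, 4096, dtype=torch.float16)  # 8 sequences x 128 tokens

print(is_single_batch(single))   # True
print(is_single_batch(batched))  # False
```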
## Reasoning
The performance gap occurs because GPU tensor cores operate most efficiently in fp16/bf16. When compute_dtype=float32, the dequantized 4-bit weights and inputs must be promoted to fp32 before the matmul, bypassing tensor core acceleration. The warning is designed to catch the most common misconfiguration: loading a model in 4-bit without explicitly setting the compute dtype.
Compute dtype detection from `bitsandbytes/nn/modules.py:493-511`:
```python
def set_compute_type(self, x):
    if x.dtype in [torch.float32, torch.bfloat16]:
        # the input is in a dtype that is safe to compute in, we switch
        # to this type for speed and stability
        self.compute_dtype = x.dtype
    elif x.dtype == torch.float16:
        # single-batch inference: x is a lone activation row
        if self.compute_dtype in [None, torch.float32] and (x.numel() == x.shape[-1]):
            warnings.warn(
                "Input type into Linear4bit is torch.float16, but bnb_4bit_compute_dtype=torch.float32 "
                "(default). This will lead to slow inference.",
            )
        # multi-batch input: inference or training
        if self.compute_dtype in [None, torch.float32] and (x.numel() != x.shape[-1]):
            warnings.warn(
                "Input type into Linear4bit is torch.float16, but bnb_4bit_compute_dtype=torch.float32 "
                "(default). This will lead to slow inference or training speed.",
            )
```
INT8 input casting warning from `bitsandbytes/autograd/_functions.py:122-123`:
```python
if A.dtype != torch.float16 and not _is_compiling():
    warnings.warn(f"MatMul8bitLt: inputs will be cast from {A.dtype} to float16 during quantization")
```