Implementation: Bitsandbytes Linear4bit Forward
Metadata
| Field | Value |
|---|---|
| Page Type | Implementation (API Doc) |
| Knowledge Sources | Repo (bitsandbytes), Paper (QLoRA) |
| Domains | Quantization, Linear_Algebra |
| Last Updated | 2026-02-07 14:00 GMT |
Overview
The concrete mechanism for performing the forward computation through 4-bit quantized linear layers, as provided by the bitsandbytes library.
Description
The forward computation pipeline for 4-bit quantized linear layers involves two key components: Linear4bit.forward() and the matmul_4bit() dispatch function.
Linear4bit.forward()
The forward() method of Linear4bit orchestrates the forward pass:
- Recover quant state: Calls `fix_4bit_weight_quant_state_from_module()` to ensure the weight's `quant_state` is available (handles cases where it was lost during serialization or FSDP).
- CPU optimization: On CPU with AVX-512 BF16 support, converts the weight packing format for optimized CPU inference.
- Bias dtype alignment: Casts the bias to match the input activation dtype if necessary.
- Compute dtype resolution: If the compute dtype was not explicitly set, infers it from the input dtype on the first call. Warns if float16 input is used with float32 compute (the default), as this leads to slow inference.
- Dtype casting: Casts input activations and bias to the compute dtype.
- Dispatch: Calls `bnb.matmul_4bit(x, weight.t(), bias=bias, quant_state=quant_state)`.
- Output casting: Casts the result back to the original input dtype.
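The dtype round-trip above (resolve compute dtype, cast in, compute, cast back out) can be sketched as follows. This is a hedged illustration, not the bitsandbytes source: the function name is hypothetical and the weight is passed in already dequantized, so the `matmul_4bit` dispatch reduces to a plain affine map.

```python
import numpy as np

# Hypothetical sketch of the dtype round-trip in Linear4bit.forward().
def forward_sketch(x, weight_fp, bias=None, compute_dtype=None):
    inp_dtype = x.dtype
    # Compute dtype resolution: infer from the input if not explicitly set.
    if compute_dtype is None:
        compute_dtype = x.dtype
    # Dtype casting: input activations and bias to the compute dtype.
    x = x.astype(compute_dtype)
    if bias is not None:
        bias = bias.astype(compute_dtype)
    # Stand-in for bnb.matmul_4bit(x, weight.t(), ...); the weight here is
    # already dequantized, shape (output_features, input_features).
    out = x @ weight_fp.T
    if bias is not None:
        out = out + bias
    # Output casting: back to the original input dtype.
    return out.astype(inp_dtype)

x = np.random.randn(2, 8).astype(np.float16)
w = np.random.randn(4, 8).astype(np.float32)
y = forward_sketch(x, w, compute_dtype=np.float32)
print(y.dtype, y.shape)  # float16 (2, 4)
```

Note that the output dtype always matches the input dtype, regardless of the compute dtype used internally.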
matmul_4bit()
The matmul_4bit() function dispatches the actual computation:
- CPU path: On CPU, if the weight is in the optimized CPU packing format, calls `gemv_4bit` directly. Otherwise, falls back to `MatMul4Bit.apply()`.
- GEMV path (single token): When `A.numel() == A.shape[-1]` (single token) and `requires_grad == False` and the hidden dimension is aligned to the block size, dispatches to `gemv_4bit(A, B.t(), out, state=quant_state)`. This is the fast path for autoregressive generation.
- MatMul4Bit path (batched): For all other cases (batched inputs, training, misaligned dimensions), uses `MatMul4Bit.apply(A, B, out, bias, quant_state)`.
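The GEMV fast-path condition can be expressed as a simple predicate over the input shape. This is a hedged sketch: the function name and the blocksize default are illustrative, not taken from the bitsandbytes source.

```python
from math import prod

# Hypothetical predicate mirroring the dispatch condition described above.
def takes_gemv_path(a_shape, requires_grad, blocksize=64):
    single_token = prod(a_shape) == a_shape[-1]  # A.numel() == A.shape[-1]
    aligned = a_shape[-1] % blocksize == 0       # hidden dim fits the block size
    return single_token and not requires_grad and aligned

print(takes_gemv_path((1, 1, 4096), requires_grad=False))    # True: decode step
print(takes_gemv_path((1, 128, 4096), requires_grad=False))  # False: batched prefill
print(takes_gemv_path((1, 1, 4096), requires_grad=True))     # False: training
```

During autoregressive generation, every decode step after prefill processes exactly one token per sequence, which is why this path dominates generation throughput.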
MatMul4Bit Autograd Function
MatMul4Bit is a torch.autograd.Function that provides differentiable 4-bit matmul:
- Forward: Dequantizes the weight via `dequantize_4bit(B, quant_state)`, casts it to the input dtype, transposes, and calls `torch.nn.functional.linear(A, dequantized_weight.t(), bias)`.
- Backward: For `grad_A`, dequantizes the weight again and computes `grad_output @ dequantized_weight`. The gradient with respect to the quantized weight `B` is not computed (frozen weights). The bias gradient is computed via `grad_output.sum(0)` if needed.
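The forward/backward pair can be sketched in numpy. This is a hedged illustration, not the bitsandbytes implementation: `dequantize_mock` is a stand-in for `dequantize_4bit(B, quant_state)` (the real function unpacks 4-bit codes via a codebook; here we just rescale a stored integer matrix), and the function names are hypothetical.

```python
import numpy as np

# Stand-in for dequantize_4bit(B, quant_state): rescale to float.
def dequantize_mock(B_packed, scale):
    return B_packed.astype(np.float32) * scale  # (output_features, input_features)

def matmul4bit_forward(A, B_packed, scale, bias=None):
    W = dequantize_mock(B_packed, scale)
    out = A @ W.T  # equivalent of F.linear(A, W, bias)
    if bias is not None:
        out = out + bias
    return out

def matmul4bit_backward(grad_output, B_packed, scale, bias_requires_grad=False):
    W = dequantize_mock(B_packed, scale)  # weight is dequantized again here
    grad_A = grad_output @ W              # gradient w.r.t. the activations
    grad_B = None                         # frozen quantized weights: no gradient
    grad_bias = grad_output.sum(0) if bias_requires_grad else None
    return grad_A, grad_B, grad_bias

A = np.random.randn(3, 8).astype(np.float32)
B = np.random.randint(-8, 8, size=(4, 8)).astype(np.int8)
out = matmul4bit_forward(A, B, scale=0.1)
gA, gB, gbias = matmul4bit_backward(np.ones_like(out), B, scale=0.1,
                                    bias_requires_grad=True)
print(out.shape, gA.shape, gB, gbias.shape)  # (3, 4) (3, 8) None (4,)
```

Dequantizing twice (once in forward, once in backward) trades extra compute for memory: the full-precision weight is never kept alive between the two passes.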
Code Reference
Source Location
bitsandbytes repo:
- `bitsandbytes/nn/modules.py`: `Linear4bit.forward` (lines L529-557)
- `bitsandbytes/autograd/_functions.py`: `matmul_4bit` (lines L369-401), `MatMul4Bit` (lines L294-348)
Signature: Linear4bit.forward
```python
def forward(self, x: torch.Tensor) -> torch.Tensor:
```
Signature: matmul_4bit
```python
def matmul_4bit(
    A: torch.Tensor,
    B: torch.Tensor,
    quant_state: QuantState,
    out: Optional[torch.Tensor] = None,
    bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
```
Import
```python
from bitsandbytes.autograd._functions import matmul_4bit

# or accessed via:
import bitsandbytes as bnb
bnb.matmul_4bit(...)
```
I/O Contract
Inputs (Linear4bit.forward)
| Parameter | Type | Required | Description |
|---|---|---|---|
| `x` | `torch.Tensor` | Yes | Input activations. Shape: `(batch, ..., input_features)`. Any floating-point dtype. |
Internal State
| Component | Type | Description |
|---|---|---|
| `weight` | `Params4bit` | Packed 4-bit weights (after quantization). Accessed as `self.weight`. |
| `quant_state` | `QuantState` | Quantization metadata (absmax, shape, codebook, blocksize, quant_type, dtype, optional nested state). |
| `bias` | `torch.Tensor` | Optional bias term, cast to compute dtype during forward. |
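The `quant_state` fields listed above can be pictured as a small record type. This is a hedged sketch mirroring only the fields named in this page; the real bitsandbytes `QuantState` class also carries serialization and device-transfer logic, and the class name here is hypothetical.

```python
from dataclasses import dataclass
from typing import Any, Optional

# Hypothetical stand-in for the QuantState fields described in the table above.
@dataclass
class QuantStateSketch:
    absmax: Any                  # per-block absolute-max scaling factors
    shape: tuple                 # original (unpacked) weight shape
    code: Any                    # codebook mapping 4-bit indices to float values
    blocksize: int = 64
    quant_type: str = "nf4"      # "nf4" or "fp4"
    dtype: Any = None            # dtype to restore on dequantization
    state2: Optional["QuantStateSketch"] = None  # nested state if double-quantized

qs = QuantStateSketch(absmax=[0.5] * 64, shape=(4096, 4096), code=list(range(16)))
print(qs.quant_type, qs.blocksize, qs.state2)  # nf4 64 None
```

When double quantization is enabled (`compress_statistics=True`), the `absmax` values are themselves quantized and `state2` holds the metadata needed to recover them.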
Inputs (matmul_4bit)
| Parameter | Type | Required | Description |
|---|---|---|---|
| `A` | `torch.Tensor` | Yes | Input activations. Shape: `(batch, ..., input_features)`. |
| `B` | `torch.Tensor` | Yes | Packed 4-bit weight tensor (transposed from `Linear4bit`). |
| `quant_state` | `QuantState` | Yes | Quantization state for dequantizing `B`. |
| `out` | `torch.Tensor` | No | Pre-allocated output tensor. |
| `bias` | `torch.Tensor` | No | Bias term to add to the output. |
Outputs
| Output | Type | Description |
|---|---|---|
| result | `torch.Tensor` | Output activations. Shape: `(batch, ..., output_features)`. Dtype is cast back to the original input activation dtype. |
Usage Examples
Typical Inference Through a Quantized Model
```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

# Load model in 4-bit
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
)
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    quantization_config=quantization_config,
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")

# Each forward pass through Linear4bit layers automatically:
# 1. Casts input to bfloat16 (compute dtype)
# 2. Dispatches to gemv_4bit (single token) or MatMul4Bit (batched)
# 3. Casts output back to input dtype
inputs = tokenizer("The capital of France is", return_tensors="pt").to("cuda")
with torch.no_grad():
    outputs = model.generate(**inputs, max_new_tokens=20)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```
Direct Use of matmul_4bit
```python
import torch
import bitsandbytes as bnb
from bitsandbytes.functional import quantize_4bit

# Create and quantize a weight matrix
weight = torch.randn(4096, 4096, dtype=torch.float16, device="cuda")
packed_weight, quant_state = quantize_4bit(
    weight,
    quant_type="nf4",
    compress_statistics=True,
)

# Perform 4-bit matrix multiplication
input_tensor = torch.randn(1, 128, 4096, dtype=torch.float16, device="cuda")
output = bnb.matmul_4bit(
    input_tensor,
    packed_weight.t(),
    quant_state=quant_state,
)
# output shape: (1, 128, 4096), dtype: float16
```