Implementation:Bitsandbytes foundation Bitsandbytes Linear4bit Forward

From Leeroopedia


Metadata

| Field | Value |
|---|---|
| Page Type | Implementation (API Doc) |
| Knowledge Sources | Repo (bitsandbytes), Paper (QLoRA) |
| Domains | Quantization, Linear_Algebra |
| Last Updated | 2026-02-07 14:00 GMT |

Overview

Concrete tool for performing forward computation through 4-bit quantized linear layers provided by the bitsandbytes library.

Description

The forward computation pipeline for 4-bit quantized linear layers involves two key components: Linear4bit.forward() and the matmul_4bit() dispatch function.

Linear4bit.forward()

The forward() method of Linear4bit orchestrates the forward pass:

  1. Recover quant state: Calls fix_4bit_weight_quant_state_from_module() to ensure the weight's quant_state is available (handles cases where it was lost during serialization or FSDP).
  2. CPU optimization: On CPU with AVX-512 BF16 support, converts the weight packing format for optimized CPU inference.
  3. Bias dtype alignment: Casts the bias to match the input activation dtype if necessary.
  4. Compute dtype resolution: If the compute dtype was not explicitly set, infers it from the input dtype on first call. Warns if float16 input is used with float32 compute (the default), as this leads to slow inference.
  5. Dtype casting: Casts input activations and bias to the compute dtype.
  6. Dispatch: Calls bnb.matmul_4bit(x, weight.t(), bias=bias, quant_state=quant_state).
  7. Output casting: Casts the result back to the original input dtype.
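The compute-dtype resolution in step 4 can be sketched in plain Python (a hypothetical helper, not the actual bitsandbytes code, which performs this inline inside Linear4bit.forward; dtypes are modeled as strings for illustration):

```python
def resolve_compute_dtype(compute_dtype, input_dtype):
    """Sketch of step 4: pick the compute dtype for Linear4bit.forward.

    Hypothetical helper; returns (resolved_dtype, warn_slow).
    """
    if compute_dtype is None:
        # Not explicitly set: infer from the input dtype on first call.
        return input_dtype, False
    # float16 activations against the float32 default lead to slow inference.
    warn_slow = compute_dtype == "float32" and input_dtype == "float16"
    return compute_dtype, warn_slow
```

With an unset compute dtype and bfloat16 input, the layer computes in bfloat16 with no warning; an explicit (or default) float32 compute dtype with float16 input triggers the slow-inference warning.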

matmul_4bit()

The matmul_4bit() function dispatches the actual computation:

  • CPU path: On CPU, if the weight is in the optimized CPU packing format, calls gemv_4bit directly. Otherwise, falls back to MatMul4Bit.apply().
  • GEMV path (single token): When A.numel() == A.shape[-1] (single token) and requires_grad == False and the hidden dimension is aligned to the block size, dispatches to gemv_4bit(A, B.t(), out, state=quant_state). This is the fast path for autoregressive generation.
  • MatMul4Bit path (batched): For all other cases (batched inputs, training, misaligned dimensions), uses MatMul4Bit.apply(A, B, out, bias, quant_state).
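The dispatch rules above can be condensed into a small decision function (an illustrative sketch with hypothetical argument names, not the real matmul_4bit implementation):

```python
def pick_matmul_4bit_path(a_numel, hidden_dim, requires_grad, blocksize,
                          device="cuda", cpu_optimized=False):
    """Sketch of the matmul_4bit dispatch rules (illustrative only)."""
    if device == "cpu":
        # Optimized CPU packing format goes straight to gemv_4bit.
        return "gemv_4bit" if cpu_optimized else "MatMul4Bit"
    single_token = a_numel == hidden_dim       # A.numel() == A.shape[-1]
    aligned = hidden_dim % blocksize == 0      # hidden dim aligned to block
    if single_token and not requires_grad and aligned:
        return "gemv_4bit"                     # fast autoregressive path
    return "MatMul4Bit"                        # batched / training path
```

Note how a single decoded token (shape (1, 1, hidden_dim)) rides the gemv fast path during generation, while the same layer falls back to MatMul4Bit as soon as gradients are required or more than one token is present.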

MatMul4Bit Autograd Function

MatMul4Bit is a torch.autograd.Function that provides differentiable 4-bit matmul:

  • Forward: Dequantizes the weight via dequantize_4bit(B, quant_state), casts to input dtype, transposes, and calls torch.nn.functional.linear(A, dequantized_weight.t(), bias).
  • Backward: For grad_A, dequantizes the weight again and computes grad_output @ dequantized_weight. Gradient with respect to the quantized weight B is not computed (frozen weights). Bias gradient is computed via grad_output.sum(0) if needed.
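The forward/backward dataflow can be sketched in dependency-free Python (a toy stand-in using nested lists; dequantize stands in for dequantize_4bit(B, quant_state), and the real implementation operates on CUDA tensors):

```python
def _matmul(X, Y):
    # Plain-Python matrix product for the sketch below.
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*Y)]
            for row in X]

def _transpose(X):
    return [list(col) for col in zip(*X)]

def matmul4bit_forward(A, B_packed, dequantize, bias=None):
    # Forward: dequantize the frozen weight, then an ordinary linear.
    W = dequantize(B_packed)            # stands in for dequantize_4bit
    out = _matmul(A, _transpose(W))     # linear(A, W): A @ W.T
    if bias is not None:
        out = [[v + b for v, b in zip(row, bias)] for row in out]
    return out

def matmul4bit_backward(grad_out, B_packed, dequantize, needs_bias_grad=False):
    # Backward: dequantize again; the packed weight B is frozen, so it
    # receives no gradient (returned as None).
    W = dequantize(B_packed)
    grad_A = _matmul(grad_out, W)       # grad_output @ dequantized_weight
    grad_bias = ([sum(col) for col in zip(*grad_out)]   # grad_output.sum(0)
                 if needs_bias_grad else None)
    return grad_A, None, grad_bias
```

Dequantizing twice (once per pass) trades compute for memory: the full-precision weight never needs to be stored between forward and backward, which is what makes QLoRA-style fine-tuning fit on a single GPU.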

Code Reference

Source Location

bitsandbytes repo:

  • bitsandbytes/nn/modules.py: Linear4bit.forward (lines 529-557)
  • bitsandbytes/autograd/_functions.py: matmul_4bit (lines 369-401), MatMul4Bit (lines 294-348)

Signature: Linear4bit.forward

def forward(self, x: torch.Tensor) -> torch.Tensor:

Signature: matmul_4bit

def matmul_4bit(
    A: torch.Tensor,
    B: torch.Tensor,
    quant_state: QuantState,
    out: Optional[torch.Tensor] = None,
    bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:

Import

from bitsandbytes.autograd._functions import matmul_4bit
# or accessed via:
import bitsandbytes as bnb
bnb.matmul_4bit(...)

I/O Contract

Inputs (Linear4bit.forward)

| Parameter | Type | Required | Description |
|---|---|---|---|
| x | torch.Tensor | Yes | Input activations. Shape: (batch, ..., input_features). Any floating-point dtype. |

Internal State

| Component | Type | Description |
|---|---|---|
| weight | Params4bit | Packed 4-bit weights (after quantization). Accessed as self.weight. |
| quant_state | QuantState | Quantization metadata (absmax, shape, codebook, blocksize, quant_type, dtype, optional nested state). |
| bias | torch.Tensor | Optional bias term, cast to compute dtype during forward. |
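The quant_state fields listed above can be pictured as a dataclass (an illustrative stand-in; the real QuantState class in bitsandbytes.functional differs in its exact attributes and behavior):

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class QuantStateSketch:
    """Illustrative stand-in for bitsandbytes' QuantState metadata."""
    absmax: list                 # per-block absolute maxima (scales)
    shape: tuple                 # original (unpacked) weight shape
    code: list                   # quantization codebook (NF4 or FP4 levels)
    blocksize: int               # elements per quantization block
    quant_type: str              # "nf4" or "fp4"
    dtype: str                   # dtype to dequantize back to
    state2: Optional["QuantStateSketch"] = None  # nested double-quant state
```

When double quantization is enabled (compress_statistics=True), the absmax values are themselves quantized and the nested state (state2 here) carries their own scales and codebook.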

Inputs (matmul_4bit)

| Parameter | Type | Required | Description |
|---|---|---|---|
| A | torch.Tensor | Yes | Input activations. Shape: (batch, ..., input_features). |
| B | torch.Tensor | Yes | Packed 4-bit weight tensor (transposed from Linear4bit). |
| quant_state | QuantState | Yes | Quantization state for dequantizing B. |
| out | torch.Tensor | No | Pre-allocated output tensor. |
| bias | torch.Tensor | No | Bias term to add to the output. |

Outputs

| Output | Type | Description |
|---|---|---|
| result | torch.Tensor | Output activations. Shape: (batch, ..., output_features). Dtype is cast back to the original input activation dtype. |
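The shape contract is the usual one for a linear layer, sketched below as a hypothetical helper (not part of the bitsandbytes API):

```python
def linear4bit_output_shape(input_shape, out_features):
    # The trailing input_features dim is replaced by out_features;
    # all leading (batch) dims are preserved.
    return (*input_shape[:-1], out_features)
```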

Usage Examples

Typical Inference Through a Quantized Model

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

# Load model in 4-bit
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
)

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    quantization_config=quantization_config,
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")

# Each forward pass through Linear4bit layers automatically:
# 1. Casts input to bfloat16 (compute dtype)
# 2. Dispatches to gemv_4bit (single token) or MatMul4Bit (batched)
# 3. Casts output back to input dtype
inputs = tokenizer("The capital of France is", return_tensors="pt").to("cuda")
with torch.no_grad():
    outputs = model.generate(**inputs, max_new_tokens=20)

print(tokenizer.decode(outputs[0], skip_special_tokens=True))

Direct Use of matmul_4bit

import torch
import bitsandbytes as bnb
from bitsandbytes.functional import quantize_4bit

# Create and quantize a weight matrix
weight = torch.randn(4096, 4096, dtype=torch.float16, device="cuda")
packed_weight, quant_state = quantize_4bit(
    weight,
    quant_type="nf4",
    compress_statistics=True,
)

# Perform 4-bit matrix multiplication
input_tensor = torch.randn(1, 128, 4096, dtype=torch.float16, device="cuda")
output = bnb.matmul_4bit(
    input_tensor,
    packed_weight.t(),
    quant_state=quant_state,
)
# output shape: (1, 128, 4096), dtype: float16
