
Implementation:Bitsandbytes foundation Bitsandbytes Linear4bit

From Leeroopedia


Metadata

Page Type: Implementation (API Doc)
Knowledge Sources: Repo (bitsandbytes), Paper (QLoRA)
Domains: Quantization, Model_Architecture
Last Updated: 2026-02-07 14:00 GMT

Overview

Linear4bit is the concrete tool the bitsandbytes library provides for 4-bit quantized linear computation.

Description

Linear4bit extends torch.nn.Linear to provide a 4-bit quantized linear layer. Upon construction, it replaces the standard weight parameter with a Params4bit object. The Params4bit parameter stores the weight data and carries quantization configuration (block size, quantization type, compression settings).
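The wrapper pattern described above can be sketched in plain Python. This is an illustrative stand-in, not the bitsandbytes implementation: the class and field names mirror the description, but the real Params4bit is a torch.nn.Parameter subclass.

```python
from dataclasses import dataclass

# Illustrative sketch only: a parameter object that carries its tensor data
# together with its quantization configuration, as the paragraph above
# describes for Params4bit. Not the bitsandbytes implementation.
@dataclass
class QuantizedParam:
    data: list                       # weight values, still in original precision
    blocksize: int = 64              # block size for blockwise quantization (assumed default)
    quant_type: str = "fp4"          # "fp4" or "nf4"
    compress_statistics: bool = True
    bnb_quantized: bool = False      # flips to True once quantization has run

    def quantize(self):
        # Placeholder for the real blockwise quantization step; here we only
        # record that quantization has happened.
        self.bnb_quantized = True
        return self

p = QuantizedParam(data=[0.5, -1.0, 0.25])
assert not p.bnb_quantized           # construction leaves weights unquantized
p.quantize()
assert p.bnb_quantized               # in bitsandbytes, device transfer triggers this
```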

The quantization lifecycle is as follows:

  1. Construction: Weights are stored in their original dtype (e.g., float16). A Params4bit wrapper is created with bnb_quantized=False.
  2. Device transfer (.to(device)): When moved to a CUDA or other compute device, Params4bit._quantize() is triggered. This calls bitsandbytes.functional.quantize_4bit(), which blockwise-quantizes the weight tensor and packs two 4-bit values per byte. The QuantState metadata is stored alongside the packed data.
  3. Forward pass: Linear4bit.forward() calls bnb.matmul_4bit(), which dequantizes the packed weights and performs the matrix multiplication. The output is cast back to the input activation dtype.

The layer also maintains a quant_state attribute on the module itself (in addition to the one on the weight parameter) to support recovery after serialization or FSDP parameter flattening.

Two convenience subclasses are provided: LinearFP4 (fixes quant_type="fp4") and LinearNF4 (fixes quant_type="nf4").
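The convenience-subclass pattern amounts to thin wrappers that pin quant_type and forward everything else. The sketch below illustrates the shape with made-up class names; the real LinearFP4 and LinearNF4 live in bitsandbytes.nn.

```python
# Illustrative sketch of the convenience-subclass pattern (names are
# hypothetical stand-ins, not the bitsandbytes classes).

class Linear4bitSketch:
    def __init__(self, in_features, out_features, quant_type="fp4", **kwargs):
        self.in_features = in_features
        self.out_features = out_features
        self.quant_type = quant_type

class LinearFP4Sketch(Linear4bitSketch):
    def __init__(self, in_features, out_features, **kwargs):
        super().__init__(in_features, out_features, quant_type="fp4", **kwargs)

class LinearNF4Sketch(Linear4bitSketch):
    def __init__(self, in_features, out_features, **kwargs):
        super().__init__(in_features, out_features, quant_type="nf4", **kwargs)

assert LinearNF4Sketch(64, 64).quant_type == "nf4"
assert LinearFP4Sketch(64, 64).quant_type == "fp4"
```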

Code Reference

Source Location

bitsandbytes repo, file: bitsandbytes/nn/modules.py, lines L422-557.

Signature

class Linear4bit(nn.Linear):
    def __init__(
        self,
        input_features,
        output_features,
        bias=True,
        compute_dtype=None,
        compress_statistics=True,
        quant_type="fp4",
        quant_storage=torch.uint8,
        device=None,
    ):

Import

from bitsandbytes.nn import Linear4bit

I/O Contract

Constructor Inputs

input_features (int, required): Number of input features (columns of the weight matrix).
output_features (int, required): Number of output features (rows of the weight matrix).
bias (bool, optional): Whether to include a bias term. Defaults to True.
compute_dtype (torch.dtype, optional): The dtype for computation during the forward pass. If None, inferred from the input dtype at the first forward call.
compress_statistics (bool, optional): Whether to apply double quantization to the absmax scaling factors. Defaults to True.
quant_type (str, optional): The quantization data type, "fp4" or "nf4". Defaults to "fp4".
quant_storage (torch.dtype, optional): The dtype used to physically store the packed 4-bit values. Defaults to torch.uint8.
device (torch.device, optional): Initial device for weight allocation. Quantization is deferred until transfer to a non-CPU device.

Forward Inputs

x (torch.Tensor, required): Input activations. Shape: (batch, ..., input_features).

Forward Outputs

result (torch.Tensor): Output activations. Shape: (batch, ..., output_features). Dtype matches the input activation dtype.
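The forward contract (dequantize, then multiply) can be sketched in pure Python on toy data. This mirrors what the page describes for bnb.matmul_4bit conceptually only; the real kernel fuses these steps on the GPU and uses the fp4/nf4 codebooks rather than the linear code assumed here.

```python
# Illustrative sketch: unpack two 4-bit codes per byte, rescale each block by
# its absmax, then do an ordinary matmul (here, a single dot product).

def dequantize_4bit(packed, absmax, blocksize=4):
    values = []
    for byte in packed:
        for code in ((byte >> 4) & 0xF, byte & 0xF):
            values.append(code - 8)          # undo the 0..15 offset
    # rescale each value by its block's absmax
    return [v / 7 * absmax[i // blocksize] for i, v in enumerate(values)]

packed, absmax = [241, 202], [1.0]           # two bytes = four 4-bit codes, one block
w = dequantize_4bit(packed, absmax)          # -> [1.0, -1.0, 4/7, 2/7]
x = [1.0, 1.0, 1.0, 1.0]
y = sum(wi * xi for wi, xi in zip(w, x))     # plain matmul on dequantized weights
```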

Usage Examples

Creating and Quantizing a Linear4bit Layer

import torch
import torch.nn as nn
from bitsandbytes.nn import Linear4bit

# Create a 4-bit quantized linear layer
layer = Linear4bit(
    input_features=4096,
    output_features=4096,
    bias=False,
    compute_dtype=torch.bfloat16,
    compress_statistics=True,
    quant_type="nf4",
)

# Load pretrained float16 weights
pretrained_linear = nn.Linear(4096, 4096, bias=False)
layer.load_state_dict(pretrained_linear.state_dict(), strict=False)

# Quantization occurs on device transfer
layer = layer.to("cuda")  # Weights are now packed 4-bit

# Forward pass
x = torch.randn(1, 128, 4096, dtype=torch.bfloat16, device="cuda")
output = layer(x)  # Shape: (1, 128, 4096), dtype: torch.bfloat16

Building a Quantized Model

import torch
import torch.nn as nn
from bitsandbytes.nn import Linear4bit

# Build a simple model with 4-bit layers
fp16_model = nn.Sequential(
    nn.Linear(64, 64),
    nn.Linear(64, 64),
)

quantized_model = nn.Sequential(
    Linear4bit(64, 64, quant_type="nf4"),
    Linear4bit(64, 64, quant_type="nf4"),
)

# Copy weights from fp16 model, then quantize
quantized_model.load_state_dict(fp16_model.state_dict())
quantized_model = quantized_model.to("cuda")  # Quantization happens here
