Implementation: Bitsandbytes Linear4bit
Metadata
| Field | Value |
|---|---|
| Page Type | Implementation (API Doc) |
| Knowledge Sources | Repo (bitsandbytes), Paper (QLoRA) |
| Domains | Quantization, Model_Architecture |
| Last Updated | 2026-02-07 14:00 GMT |
Overview
Concrete tool for 4-bit quantized linear computation provided by the bitsandbytes library.
Description
Linear4bit extends torch.nn.Linear to provide a 4-bit quantized linear layer. Upon construction, it replaces the standard weight parameter with a Params4bit object. The Params4bit parameter stores the weight data and carries quantization configuration (block size, quantization type, compression settings).
The quantization lifecycle is as follows:
- Construction: Weights are stored in their original dtype (e.g., float16). A `Params4bit` wrapper is created with `bnb_quantized=False`.
- Device transfer (`.to(device)`): When moved to CUDA or another compute device, `Params4bit._quantize()` is triggered. This calls `bitsandbytes.functional.quantize_4bit()`, which blockwise-quantizes the weight tensor and packs two 4-bit values per byte. The `QuantState` metadata is stored alongside the packed data.
- Forward pass: `Linear4bit.forward()` calls `bnb.matmul_4bit()`, which dequantizes the packed weights and performs the matrix multiplication. The output is cast back to the input activation dtype.
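The "two 4-bit values per byte" packing step can be illustrated with a small standalone sketch in plain Python. This shows the storage idea only; the actual bitsandbytes kernel memory layout may differ:

```python
def pack_4bit(codes):
    """Pack pairs of 4-bit code indices (each 0..15) into single bytes:
    first index in the high nibble, second in the low nibble.
    Illustrative sketch, not the exact bitsandbytes layout."""
    assert len(codes) % 2 == 0 and all(0 <= c <= 15 for c in codes)
    return bytes((hi << 4) | lo for hi, lo in zip(codes[0::2], codes[1::2]))

def unpack_4bit(packed):
    """Recover the 4-bit code indices from the packed bytes."""
    return [nibble for b in packed for nibble in (b >> 4, b & 0x0F)]

codes = [0, 15, 7, 8]
packed = pack_4bit(codes)        # 2 bytes instead of 4 elements
assert unpack_4bit(packed) == codes
```

Halving the element count per byte is where the 4x storage reduction relative to float16 comes from (before accounting for per-block metadata).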
The layer also maintains a quant_state attribute on the module itself (in addition to the one on the weight parameter) to support recovery after serialization or FSDP parameter flattening.
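Conceptually, that recovery metadata must describe everything needed to undo the packing. A minimal dataclass sketch of such a record follows; the field names are illustrative and may not match the real `QuantState` attributes exactly:

```python
from dataclasses import dataclass
from typing import Sequence, Tuple

@dataclass
class QuantMetadataSketch:
    """Illustrative record of the per-tensor state a 4-bit layer must
    retain to dequantize its packed weights after a round trip through
    serialization or parameter flattening (hypothetical field names)."""
    absmax: Sequence[float]   # per-block scaling factors
    blocksize: int            # elements per quantization block
    quant_type: str           # "fp4" or "nf4" codebook selector
    shape: Tuple[int, ...]    # original (out_features, in_features)
    dtype: str                # original weight dtype, e.g. "float16"
```

Keeping a copy of this state on the module itself means the layer can rebuild the weight's metadata even if the parameter object is replaced by a flattened view.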
Two convenience subclasses are provided: LinearFP4 (fixes quant_type="fp4") and LinearNF4 (fixes quant_type="nf4").
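The difference between the two quant types is the 16-entry codebook. For `nf4`, the levels are quantiles of a standard normal distribution normalized into [-1, 1], as described in the QLoRA paper. A standalone sketch of that construction using only the standard library (the offset constant mirrors the one used in bitsandbytes; treat this as illustrative, not the exact kernel lookup table):

```python
from statistics import NormalDist

def create_nf4_levels():
    """Sketch of the NF4 codebook: asymmetric normal quantiles
    (8 positive, 7 negative, plus an exact zero), normalized so the
    largest-magnitude level is 1.0."""
    nd = NormalDist()
    offset = 0.9677083  # probability-mass offset used in bitsandbytes
    pos = [nd.inv_cdf(offset - i * (offset - 0.5) / 8) for i in range(8)]
    neg = [-nd.inv_cdf(offset - i * (offset - 0.5) / 7) for i in range(7)]
    levels = sorted(pos + [0.0] + neg)
    m = max(abs(x) for x in levels)
    return [x / m for x in levels]
```

The exact zero is what distinguishes NF4 from a naive uniform 4-bit grid: sparse and near-zero weights are represented without bias.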
Code Reference
Source Location
bitsandbytes repo, file: bitsandbytes/nn/modules.py, lines 422-557.
Signature
```python
class Linear4bit(nn.Linear):
    def __init__(
        self,
        input_features,
        output_features,
        bias=True,
        compute_dtype=None,
        compress_statistics=True,
        quant_type="fp4",
        quant_storage=torch.uint8,
        device=None,
    ):
```
Import
```python
from bitsandbytes.nn import Linear4bit
```
I/O Contract
Constructor Inputs
| Parameter | Type | Required | Description |
|---|---|---|---|
| `input_features` | int | Yes | Number of input features (columns of the weight matrix). |
| `output_features` | int | Yes | Number of output features (rows of the weight matrix). |
| `bias` | bool | No | Whether to include a bias term. Defaults to True. |
| `compute_dtype` | torch.dtype | No | The dtype for computation during the forward pass. If None, inferred from the input dtype at the first forward call. |
| `compress_statistics` | bool | No | Whether to apply double quantization to the absmax scaling factors. Defaults to True. |
| `quant_type` | str | No | The quantization data type: "fp4" or "nf4". Defaults to "fp4". |
| `quant_storage` | torch.dtype | No | The dtype used to physically store the packed 4-bit values. Defaults to torch.uint8. |
| `device` | torch.device | No | Initial device for weight allocation. Quantization is deferred until a non-CPU device transfer. |
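The storage effect of `compress_statistics` can be estimated with a back-of-the-envelope calculator. The figures assume the setup reported in the QLoRA paper: blocksize 64 with one fp32 absmax per block, and double quantization re-quantizing absmax to 8 bits in groups of 256 with an fp32 secondary constant. Exact bitsandbytes layouts may differ slightly:

```python
def approx_bits_per_param(blocksize=64, double_quant=True):
    """Estimate total storage bits per weight element for a 4-bit layer.

    Assumptions (QLoRA paper setup, not measured from bitsandbytes):
    one fp32 absmax per block of `blocksize` elements; with double
    quantization, absmax values are stored as 8-bit codes in groups of
    256 with one fp32 constant per group."""
    bits = 4.0                            # packed weight codes
    if double_quant:
        bits += 8.0 / blocksize           # 8-bit quantized absmax per block
        bits += 32.0 / (blocksize * 256)  # fp32 secondary constant
    else:
        bits += 32.0 / blocksize          # raw fp32 absmax per block
    return bits
```

With the defaults this gives about 4.127 bits per parameter versus 4.5 without double quantization, matching the roughly 0.37 bits-per-parameter saving reported in the QLoRA paper.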
Forward Inputs
| Parameter | Type | Required | Description |
|---|---|---|---|
| `x` | torch.Tensor | Yes | Input activations. Shape: (batch, ..., input_features). |
Forward Outputs
| Output | Type | Description |
|---|---|---|
| result | torch.Tensor | Output activations. Shape: (batch, ..., output_features). Dtype matches the input activation dtype. |
Usage Examples
Creating and Quantizing a Linear4bit Layer
```python
import torch
import torch.nn as nn
from bitsandbytes.nn import Linear4bit

# Create a 4-bit quantized linear layer
layer = Linear4bit(
    input_features=4096,
    output_features=4096,
    bias=False,
    compute_dtype=torch.bfloat16,
    compress_statistics=True,
    quant_type="nf4",
)

# Load pretrained float16 weights
pretrained_linear = nn.Linear(4096, 4096, bias=False)
layer.load_state_dict(pretrained_linear.state_dict(), strict=False)

# Quantization occurs on device transfer
layer = layer.to("cuda")  # Weights are now packed 4-bit

# Forward pass
x = torch.randn(1, 128, 4096, dtype=torch.bfloat16, device="cuda")
output = layer(x)  # Shape: (1, 128, 4096), dtype: torch.bfloat16
```
Building a Quantized Model
```python
import torch
import torch.nn as nn
from bitsandbytes.nn import Linear4bit

# Build a simple model with 4-bit layers
fp16_model = nn.Sequential(
    nn.Linear(64, 64),
    nn.Linear(64, 64),
)
quantized_model = nn.Sequential(
    Linear4bit(64, 64, quant_type="nf4"),
    Linear4bit(64, 64, quant_type="nf4"),
)

# Copy weights from fp16 model, then quantize
quantized_model.load_state_dict(fp16_model.state_dict())
quantized_model = quantized_model.to("cuda")  # Quantization happens here
```
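The deferred quantization triggered by the `.to("cuda")` calls above can be approximated by a blockwise absmax sketch in plain Python: scale each block by its largest magnitude, snap every element to the nearest codebook level, and invert the scaling on dequantization. This illustrates the scheme, not the bitsandbytes kernel:

```python
def quantize_block(values, levels):
    """Blockwise absmax quantization sketch: normalize the block into
    [-1, 1] by its largest magnitude, then snap each element to the
    nearest codebook level. Returns code indices and the block scale."""
    absmax = max(abs(v) for v in values) or 1.0  # guard all-zero blocks
    codes = [
        min(range(len(levels)), key=lambda i: abs(levels[i] - v / absmax))
        for v in values
    ]
    return codes, absmax

def dequantize_block(codes, absmax, levels):
    """Invert quantize_block up to the codebook's rounding error."""
    return [levels[c] * absmax for c in codes]

# Toy symmetric codebook (a real 4-bit code has 16 levels)
levels = [-1.0, -0.5, 0.0, 0.5, 1.0]
codes, absmax = quantize_block([0.9, -0.4, 0.0, 2.0], levels)
restored = dequantize_block(codes, absmax, levels)
```

The reconstruction error per element is bounded by the block's absmax times half the codebook's level spacing, which is why small block sizes (the default is 64) keep 4-bit quantization accurate.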