Implementation: Bitsandbytes Linear4bit FSDP
Metadata
| Field | Value |
|---|---|
| Sources | Repo: bitsandbytes, Blog: FSDP QLoRA |
| Domains | Quantization, Distributed_Training |
| Last updated | 2026-02-07 14:00 GMT |
Overview
Concrete tool for FSDP-compatible 4-bit quantized linear layers provided by the bitsandbytes library.
Description
This page documents the Linear4bit class from the FSDP perspective. The key difference from standard Linear4bit usage is setting the quant_storage parameter to torch.bfloat16 (instead of the default torch.uint8). This lets FSDP shard the quantized weights as if they were regular bfloat16 parameters.
The forward() method calls fix_4bit_weight_quant_state_from_module() to recover the QuantState that can be lost when FSDP flattens and reshards parameters. The module stores quant_storage (L490) so the state can be recovered even when weight.quant_state is None.
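The storage-dtype trick can be illustrated with plain PyTorch: the packed 4-bit payload is just bytes, and viewing the same buffer as bfloat16 halves the element count without changing a single byte, which is what lets FSDP treat it like an ordinary bfloat16 parameter. A minimal sketch (not bitsandbytes code):

```python
import torch

# Pretend these 16 bytes are a packed 4-bit weight payload (32 4-bit values).
packed_u8 = torch.arange(16, dtype=torch.uint8)

# Reinterpret the same buffer as bfloat16: 2 bytes per element -> 8 elements,
# but the underlying bytes are untouched. FSDP can now shard this tensor as
# if it were a regular bfloat16 parameter.
as_bf16 = packed_u8.view(torch.bfloat16)
print(as_bf16.numel())  # 8

# Round-tripping back to uint8 recovers the identical packed payload.
restored = as_bf16.view(torch.uint8)
print(torch.equal(restored, packed_u8))  # True
```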
Code Reference
| Field | Value |
|---|---|
| Source | bitsandbytes repo |
| File | bitsandbytes/nn/modules.py |
| Lines | L422-557 (Linear4bit), L405-420 (fix_4bit_weight_quant_state_from_module) |
| Key FSDP-specific lines | L463 (quant_storage param), L490 (self.quant_storage stored), L530 (fix call in forward) |
Signature
```python
class Linear4bit(nn.Linear):
    def __init__(
        self,
        input_features,
        output_features,
        bias=True,
        compute_dtype=None,
        compress_statistics=True,
        quant_type="fp4",
        quant_storage=torch.uint8,  # Set to torch.bfloat16 for FSDP
        device=None,
    ):
        ...

def fix_4bit_weight_quant_state_from_module(
    module: Union["Embedding4bit", "Linear4bit"],
):
    ...
```
Import
```python
from bitsandbytes.nn import Linear4bit
```
I/O Contract
Inputs
| Parameter | Type | Required | Default | Description |
|---|---|---|---|---|
| input_features | int | Yes | -- | Number of input features |
| output_features | int | Yes | -- | Number of output features |
| bias | bool | No | True | Whether to include bias term |
| compute_dtype | torch.dtype | No | None | Dtype for computation during forward pass |
| compress_statistics | bool | No | True | Use double quantization for quantization constants |
| quant_type | str | No | "fp4" | Quantization type: "fp4" or "nf4" |
| quant_storage | torch.dtype | No | torch.uint8 | MUST be set to torch.bfloat16 (or match torch_dtype) for FSDP |
| device | torch.device | No | None | Device for parameter placement |
Forward
Same as standard Linear4bit, with an added fix_4bit_weight_quant_state_from_module() call at the top of forward().
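The recovery step at the top of forward() follows a simple pattern: if FSDP's flatten/unflatten cycle dropped the per-parameter quant_state, restore it from the copy cached on the module. A simplified pure-Python sketch of that pattern (hypothetical stand-in classes, not the bitsandbytes implementation):

```python
from dataclasses import dataclass

@dataclass
class QuantState:
    # Stand-in for bitsandbytes' QuantState (absmax, code, blocksize, ...).
    quant_type: str

class FakeWeight:
    # Stand-in for the 4-bit weight parameter; FSDP flatten/unflatten
    # can drop custom attributes like quant_state.
    def __init__(self, quant_state):
        self.quant_state = quant_state

class FakeLinear4bit:
    def __init__(self):
        state = QuantState(quant_type="nf4")
        self.weight = FakeWeight(state)
        # The module keeps its own reference, surviving FSDP's rewrapping.
        self.quant_state = state

    def fix_quant_state(self):
        # Mirrors the role of fix_4bit_weight_quant_state_from_module:
        # restore the per-parameter state from the module-level copy.
        if getattr(self.weight, "quant_state", None) is None:
            self.weight.quant_state = self.quant_state

    def forward(self):
        self.fix_quant_state()  # first thing forward() does
        return self.weight.quant_state.quant_type

layer = FakeLinear4bit()
layer.weight.quant_state = None  # simulate FSDP losing the state
print(layer.forward())  # "nf4"
```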
Outputs
torch.Tensor -- the result of the 4-bit quantized linear transformation.
Usage Examples
Loading a model for FSDP QLoRA training with bfloat16 quant_storage:
```python
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
import torch

# Configure 4-bit quantization with bfloat16 storage for FSDP
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_storage=torch.bfloat16,  # Critical for FSDP
)

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-70b-hf",
    quantization_config=bnb_config,
    torch_dtype=torch.bfloat16,  # Must match quant_storage
)
```