
Implementation:Bitsandbytes foundation Bitsandbytes Linear4bit FSDP

From Leeroopedia


Metadata

Sources: Repo: bitsandbytes; Blog: FSDP QLoRA
Domains: Quantization, Distributed_Training
Last updated: 2026-02-07 14:00 GMT

Overview

An FSDP-compatible 4-bit quantized linear layer provided by the bitsandbytes library.

Description

This page documents the Linear4bit class from the FSDP perspective. The key difference from standard Linear4bit usage is setting the quant_storage parameter to torch.bfloat16 instead of the default torch.uint8, which lets FSDP shard the packed quantized weights as if they were ordinary bfloat16 parameters.
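The effect of quant_storage on what FSDP sees can be shown with plain arithmetic (a sketch with illustrative sizes, no bitsandbytes required): the packed 4-bit buffer holds the same bytes either way; only the element type, and therefore the element count that FSDP shards, changes.

```python
# Illustrative layer sizes (hypothetical, not from this page).
out_features, in_features = 4096, 4096
n_weights = out_features * in_features   # number of 4-bit values
packed_bytes = n_weights // 2            # two 4-bit values per byte

# The same byte buffer exposed under different storage dtypes:
elems_as_uint8 = packed_bytes // 1       # default: 1 byte per uint8 element
elems_as_bf16 = packed_bytes // 2        # FSDP-friendly: 2 bytes per bfloat16 element

print(elems_as_uint8)  # 8388608
print(elems_as_bf16)   # 4194304
```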

The forward() method calls fix_4bit_weight_quant_state_from_module() to recover a QuantState lost during FSDP operations. Because the module itself stores quant_storage (L490), the state can be recovered even when weight.quant_state is None.
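The recovery pattern can be sketched in plain Python (dummy classes for illustration only, not the real bitsandbytes types): the module keeps enough metadata at module level to rebuild the weight's quant state after FSDP strips it.

```python
# Dummy stand-ins; the real logic lives in bitsandbytes/nn/modules.py
# (fix_4bit_weight_quant_state_from_module).
class DummyQuantState:
    def __init__(self, quant_type, dtype):
        self.quant_type = quant_type
        self.dtype = dtype

class DummyLinear4bit:
    def __init__(self, quant_type="nf4", quant_storage="bfloat16"):
        # Module-level copies survive FSDP sharding even when the
        # per-parameter quant_state is dropped.
        self.quant_type = quant_type
        self.quant_storage = quant_storage
        self.weight_quant_state = DummyQuantState(quant_type, quant_storage)

    def fix_weight_quant_state(self):
        # Mirrors the fix call at the top of forward(): rebuild the
        # state from module attributes if the weight lost it.
        if self.weight_quant_state is None:
            self.weight_quant_state = DummyQuantState(
                self.quant_type, self.quant_storage
            )
        return self.weight_quant_state

layer = DummyLinear4bit()
layer.weight_quant_state = None        # simulate FSDP dropping the state
state = layer.fix_weight_quant_state()
print(state.quant_type)                # nf4
```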

Code Reference

Source: bitsandbytes repo
File: bitsandbytes/nn/modules.py
Lines: L422-557 (Linear4bit), L405-420 (fix_4bit_weight_quant_state_from_module)
Key FSDP-specific lines: L463 (quant_storage param), L490 (self.quant_storage stored), L530 (fix call in forward)

Signature

class Linear4bit(nn.Linear):
    def __init__(
        self,
        input_features,
        output_features,
        bias=True,
        compute_dtype=None,
        compress_statistics=True,
        quant_type="fp4",
        quant_storage=torch.uint8,  # Set to torch.bfloat16 for FSDP
        device=None,
    ): ...

def fix_4bit_weight_quant_state_from_module(
    module: Union["Embedding4bit", "Linear4bit"],
): ...

Import

from bitsandbytes.nn import Linear4bit
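What quant_storage actually changes can be demonstrated with torch alone (a sketch assuming only torch is installed; the buffer below is arbitrary bytes, not a real quantized weight): the packed buffer is raw bytes, and the storage dtype merely reinterprets them.

```python
import torch

packed = torch.arange(8, dtype=torch.uint8)  # 8 bytes = 16 packed 4-bit values

as_uint8 = packed.view(torch.uint8)    # default quant_storage view
as_bf16 = packed.view(torch.bfloat16)  # FSDP-friendly view: same bytes

print(as_uint8.numel())  # 8
print(as_bf16.numel())   # 4  (each bfloat16 element covers 2 bytes)
```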

I/O Contract

Inputs

input_features (int, required) -- Number of input features
output_features (int, required) -- Number of output features
bias (bool, default True) -- Whether to include a bias term
compute_dtype (torch.dtype, default None) -- Dtype for computation during the forward pass
compress_statistics (bool, default True) -- Use double quantization for the quantization constants
quant_type (str, default "fp4") -- Quantization type: "fp4" or "nf4"
quant_storage (torch.dtype, default torch.uint8) -- MUST be set to torch.bfloat16 (or match torch_dtype) for FSDP
device (torch.device, default None) -- Device for parameter placement

Forward

Same as the standard Linear4bit forward pass, with an added fix_4bit_weight_quant_state_from_module call at the top of forward().

Outputs

torch.Tensor -- the result of the 4-bit quantized linear transformation.

Usage Examples

Loading a model for FSDP QLoRA training with bfloat16 quant_storage:

from transformers import AutoModelForCausalLM, BitsAndBytesConfig
import torch

# Configure 4-bit quantization with bfloat16 storage for FSDP
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_storage=torch.bfloat16,  # Critical for FSDP
)

model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-70b-hf",
    quantization_config=bnb_config,
    torch_dtype=torch.bfloat16,  # Must match quant_storage
)
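Before handing the model to FSDP, it can help to confirm that every 4-bit layer uses the expected storage dtype. The helper below is a hypothetical sketch (check_quant_storage is not a bitsandbytes or transformers API), shown with a stand-in module tree so it runs without a GPU:

```python
import torch

def check_quant_storage(model, expected=torch.bfloat16):
    """Return (name, dtype) pairs for modules whose quant_storage differs."""
    mismatched = []
    for name, module in model.named_modules():
        storage = getattr(module, "quant_storage", None)
        if storage is not None and storage != expected:
            mismatched.append((name, storage))
    return mismatched

# Stand-in tree (plain nn.Module, not a quantized model):
root = torch.nn.Module()
child = torch.nn.Module()
child.quant_storage = torch.uint8      # simulate a misconfigured layer
root.proj = child

print(check_quant_storage(root))       # [('proj', torch.uint8)]
```

An empty result means all quantized layers already match the model's torch_dtype and are safe to shard.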
