
Implementation:Bitsandbytes foundation Bitsandbytes Fix 4bit Weight Quant State

From Leeroopedia


Metadata

Field Value
Sources Repo: bitsandbytes
Domains Distributed_Training, Quantization
Last updated 2026-02-07 14:00 GMT

Overview

A concrete tool, provided by the bitsandbytes library, for recovering a module's 4-bit quantization state after FSDP operations.

Description

fix_4bit_weight_quant_state_from_module checks if module.weight.quant_state is None. If so, it copies module.quant_state to module.weight.quant_state. This is needed because FSDP unshard operations may create new weight tensors without preserving custom attributes.

The function performs the following steps:

  1. Early return if module.weight.quant_state is already set (L406-407)
  2. Warning if module.quant_state is also None -- indicates the layer has not been quantized yet (L409-412)
  3. Shape assertion -- verifies weight.shape[1] == 1 to confirm packed quantized format (L416)
  4. Weight wrapping -- if the weight is not a Params4bit instance, wraps it with quant_storage from the module (L417-418)
  5. State recovery -- assigns module.quant_state to module.weight.quant_state (L419)
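The recovery logic of steps 1, 2, and 5 can be sketched with plain Python objects, with no bitsandbytes or GPU dependency. This is a minimal mock of the control flow, not the library's implementation; the function name `fix_quant_state_sketch` is hypothetical, and the packed-shape assertion and `Params4bit` re-wrapping (steps 3-4) are elided here:

```python
import warnings
from types import SimpleNamespace

def fix_quant_state_sketch(module):
    # Step 1: early return if the weight already carries its quant_state
    if getattr(module.weight, "quant_state", None) is not None:
        return
    # Step 2: warn if the module itself has nothing to restore from
    if getattr(module, "quant_state", None) is None:
        warnings.warn("quantization state not initialized")
    # Step 5: copy the state registered on the module back onto the weight
    module.weight.quant_state = module.quant_state

# Simulate a weight that lost its quant_state (as after an FSDP unshard)
state = object()
module = SimpleNamespace(
    quant_state=state,
    weight=SimpleNamespace(quant_state=None),
)
fix_quant_state_sketch(module)
assert module.weight.quant_state is state
```

Calling the sketch a second time is a no-op (step 1's early return), which is why the real function is safe to invoke on every forward pass.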

Called at the top of Linear4bit.forward() (L530).

Code Reference

Field Value
Source bitsandbytes repo
File bitsandbytes/nn/modules.py
Lines L405-420

Signature

def fix_4bit_weight_quant_state_from_module(
    module: Union["Embedding4bit", "Linear4bit"]
) -> None:

Import

from bitsandbytes.nn.modules import fix_4bit_weight_quant_state_from_module

Source Code (L405-420)

def fix_4bit_weight_quant_state_from_module(module: Union["Embedding4bit", "Linear4bit"]):
    if getattr(module.weight, "quant_state", None) is not None:
        return

    if getattr(module, "quant_state", None) is None:
        warnings.warn(
            "FP4 quantization state not initialized. Please call .cuda() or .to(device) on the LinearFP4 layer first.",
        )

    # the quant state got lost when the parameter got converted. This happens for example for fsdp
    # since we registered the module, we can recover the state here
    assert module.weight.shape[1] == 1
    if not isinstance(module.weight, Params4bit):
        module.weight = Params4bit(module.weight, quant_storage=module.quant_storage, bnb_quantized=True)
    module.weight.quant_state = module.quant_state

I/O Contract

Inputs

Parameter Type Required Description
module Linear4bit or Embedding4bit Yes The module whose weight's quant_state may need recovery

Outputs

None -- mutates module.weight.quant_state in-place.

Side Effects

  • Sets module.weight.quant_state = module.quant_state if the weight's quant_state was None
  • May wrap module.weight in a new Params4bit instance if it was converted to a plain tensor by FSDP
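The second side effect, re-wrapping a weight that FSDP converted to a plain tensor, can also be illustrated with a mock. This is a hedged sketch: `FakeParams4bit` and `recover` are hypothetical stand-ins for bitsandbytes' `Params4bit` and the real recovery function, and the `weight.shape[1] == 1` assertion is skipped since the mock carries no tensor data:

```python
from types import SimpleNamespace

class FakeParams4bit(SimpleNamespace):
    """Stand-in for bitsandbytes' Params4bit (hypothetical mock)."""

def recover(module):
    # Early return if the weight already carries its quant_state
    if getattr(module.weight, "quant_state", None) is not None:
        return
    # FSDP may hand back a plain tensor: re-wrap it so the weight
    # regains the parameter subclass behaviour expected downstream
    if not isinstance(module.weight, FakeParams4bit):
        module.weight = FakeParams4bit(
            data=module.weight,
            quant_storage=module.quant_storage,
        )
    module.weight.quant_state = module.quant_state

plain = SimpleNamespace()  # weight stripped of custom attributes by "FSDP"
module = SimpleNamespace(weight=plain, quant_state="qs", quant_storage="uint8")
recover(module)
assert isinstance(module.weight, FakeParams4bit)
assert module.weight.quant_state == "qs"
```

Note that the wrapping preserves the module-level `quant_storage`, mirroring how the real function passes `quant_storage=module.quant_storage` when reconstructing `Params4bit`.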

Usage Examples

Called automatically within Linear4bit.forward(); shown here in context:

class Linear4bit(nn.Linear):
    def forward(self, x: torch.Tensor):
        # Recovery call -- ensures quant_state is present after FSDP unshard
        fix_4bit_weight_quant_state_from_module(self)

        quant_state = self.weight.quant_state
        # ... rest of forward pass uses quant_state for dequantization
        return bnb.matmul_4bit(x, self.weight.t(), bias=bias, quant_state=quant_state).to(inp_dtype)
