Implementation: bitsandbytes Fix 4bit Weight Quant State
Metadata
| Field | Value |
|---|---|
| Sources | Repo: bitsandbytes |
| Domains | Distributed_Training, Quantization |
| Last updated | 2026-02-07 14:00 GMT |
Overview
A concrete utility from the bitsandbytes library for recovering a module's 4-bit quantization state after FSDP operations.
Description
fix_4bit_weight_quant_state_from_module checks if module.weight.quant_state is None. If so, it copies module.quant_state to module.weight.quant_state. This is needed because FSDP unshard operations may create new weight tensors without preserving custom attributes.
The function performs the following steps:
- Early return if `module.weight.quant_state` is already set (L406-407)
- Warning if `module.quant_state` is also `None` -- indicates the layer has not been quantized yet (L409-412)
- Shape assertion -- verifies `weight.shape[1] == 1` to confirm packed quantized format (L416)
- Weight wrapping -- if the weight is not a Params4bit instance, wraps it with `quant_storage` from the module (L417-418)
- State recovery -- assigns `module.quant_state` to `module.weight.quant_state` (L419)
Called at the top of Linear4bit.forward() (L530).
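The recovery flow above can be sketched with plain Python objects. The classes below (`PlainWeight`, `WrappedWeight`, `FakeModule`, `fix_quant_state`) are hypothetical stand-ins for illustration only, not the real bitsandbytes types; they mirror the five steps without requiring torch or a GPU:

```python
import warnings

# Hypothetical stand-ins -- NOT the real bitsandbytes classes.
class PlainWeight:
    """Mimics a bare tensor produced by an FSDP unshard (no quant_state)."""
    def __init__(self, shape):
        self.shape = shape

class WrappedWeight(PlainWeight):
    """Mimics Params4bit: a weight that can carry a quant_state."""
    def __init__(self, base, quant_storage=None):
        super().__init__(base.shape)
        self.quant_storage = quant_storage
        self.quant_state = None

class FakeModule:
    def __init__(self):
        self.weight = PlainWeight(shape=(64, 1))  # packed 4-bit column layout
        self.quant_state = "saved-state"          # state registered on the module
        self.quant_storage = "uint8"

def fix_quant_state(module):
    # Step 1: early return if the weight already carries its state.
    if getattr(module.weight, "quant_state", None) is not None:
        return
    # Step 2: warn if the module itself was never quantized.
    if getattr(module, "quant_state", None) is None:
        warnings.warn("quantization state not initialized")
    # Step 3: packed 4-bit weights are stored as an (N, 1) column.
    assert module.weight.shape[1] == 1
    # Step 4: re-wrap a plain tensor that lost its parameter subclass.
    if not isinstance(module.weight, WrappedWeight):
        module.weight = WrappedWeight(module.weight, quant_storage=module.quant_storage)
    # Step 5: copy the module-level state back onto the weight.
    module.weight.quant_state = module.quant_state

m = FakeModule()
fix_quant_state(m)
print(m.weight.quant_state)  # "saved-state"
```

The key design point this mirrors is that the module keeps a second copy of `quant_state` as a registered attribute, so the state survives even when FSDP replaces `module.weight` with a fresh tensor.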
Code Reference
| Field | Value |
|---|---|
| Source | bitsandbytes repo |
| File | bitsandbytes/nn/modules.py |
| Lines | L405-420 |
Signature
def fix_4bit_weight_quant_state_from_module(
module: Union["Embedding4bit", "Linear4bit"]
) -> None:
Import
from bitsandbytes.nn.modules import fix_4bit_weight_quant_state_from_module
Source Code (L405-420)
def fix_4bit_weight_quant_state_from_module(module: Union["Embedding4bit", "Linear4bit"]):
if getattr(module.weight, "quant_state", None) is not None:
return
if getattr(module, "quant_state", None) is None:
warnings.warn(
"FP4 quantization state not initialized. Please call .cuda() or .to(device) on the LinearFP4 layer first.",
)
# the quant state got lost when the parameter got converted. This happens for example for fsdp
# since we registered the module, we can recover the state here
assert module.weight.shape[1] == 1
if not isinstance(module.weight, Params4bit):
module.weight = Params4bit(module.weight, quant_storage=module.quant_storage, bnb_quantized=True)
module.weight.quant_state = module.quant_state
I/O Contract
Inputs
| Parameter | Type | Required | Description |
|---|---|---|---|
| module | Linear4bit or Embedding4bit | Yes | The module whose weight's quant_state may need recovery |
Outputs
None -- mutates module.weight.quant_state in-place.
Side Effects
- Sets `module.weight.quant_state = module.quant_state` if the weight's `quant_state` was `None`
- May wrap `module.weight` in a new Params4bit instance if it was converted to a plain tensor by FSDP
Usage Examples
Called automatically within Linear4bit.forward(); shown here in context:
class Linear4bit(nn.Linear):
def forward(self, x: torch.Tensor):
# Recovery call -- ensures quant_state is present after FSDP unshard
fix_4bit_weight_quant_state_from_module(self)
quant_state = self.weight.quant_state
# ... rest of forward pass uses quant_state for dequantization
return bnb.matmul_4bit(x, self.weight.t(), bias=bias, quant_state=quant_state).to(inp_dtype)
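Because the function early-returns when `module.weight.quant_state` is already set, running it at the top of every forward pass is cheap: only the first call after an FSDP unshard does real work. A minimal sketch of this idempotency, using hypothetical stand-in objects rather than the real bitsandbytes classes:

```python
class Weight:
    def __init__(self):
        self.shape = (8, 1)        # packed 4-bit column layout
        self.quant_state = None    # lost attribute, as after an FSDP unshard

class Module:
    def __init__(self):
        self.weight = Weight()
        self.quant_state = object()  # state registered on the module

calls_past_guard = 0  # counts how often recovery actually runs

def recover(module):
    global calls_past_guard
    # Early-return guard: already recovered, nothing to do.
    if getattr(module.weight, "quant_state", None) is not None:
        return
    calls_past_guard += 1
    module.weight.quant_state = module.quant_state

m = Module()
for _ in range(3):  # e.g. three forward passes
    recover(m)
print(calls_past_guard)  # 1 -- only the first call does real work
```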