Implementation:Predibase Lorax BitsAndBytes Layers
| Knowledge Sources | |
|---|---|
| Domains | Quantization, Inference |
| Last Updated | 2026-02-08 00:00 GMT |
Overview
⚠️ DEPRECATION WARNING: The 8-bit quantization path (Linear8bitLt) is deprecated. Use EETQ as a drop-in replacement. See Heuristic:Predibase_Lorax_Warning_Deprecated_BitsAndBytes_8bit.
Provides 8-bit (LLM.int8()) and 4-bit (NF4/FP4) quantized linear layers, built on the bitsandbytes library, for memory-efficient inference.
Description
This module wraps the bitsandbytes library to provide quantized linear layer implementations:
- warn_deprecate_bnb: A cached helper function that logs a deprecation warning recommending that users switch from bitsandbytes 8-bit to EETQ for better performance.
- Linear8bitLt: An nn.Module implementing the LLM.int8() mixed-precision decomposition. During initialization, it creates a bnb.MatmulLtState to manage quantization state, converts weights to Int8Params using bitsandbytes, and moves them to CUDA. Key features:
  - Configurable threshold for outlier feature detection in the mixed-precision scheme. The signature default is 0.0; 6.0 is the typical value.
  - init_8bit_state transfers quantized weight data (CB, SCB) from the parameters to the matmul state.
  - forward calls bnb.matmul with the quantization state, casting the bias to the input dtype when the two differ. After the first pass, it deletes the row-major weight copy (CB), since the Turing/Ampere-formatted weight (CxB) is retained.
- Linear4bit: An nn.Module implementing 4-bit quantization using bitsandbytes Params4bit. Supports both NF4 (NormalFloat4) and FP4 (Float4) quantization types via the quant_type parameter. Uses compressed statistics for efficient storage. The forward method calls bnb.matmul_4bit with the quantization state, supporting an optional compute dtype override and automatic bias dtype casting.
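To make the role of the outlier threshold concrete, here is a minimal pure-Python sketch of an LLM.int8()-style mixed-precision matmul. It is a toy illustration, not the bitsandbytes kernel: the function names (absmax_quantize, int8_mixed_matmul) are hypothetical, and the scaling scheme is simplified to one absmax scale per row. Feature columns that contain a value at or above the threshold stay in full precision, and the remaining columns go through int8 codes.

```python
# Toy illustration of LLM.int8()-style mixed-precision matmul (pure Python, CPU only).
# Assumption: this is NOT the bitsandbytes implementation; all names are hypothetical.

def absmax_quantize(vec):
    """Quantize a list of floats to int8 codes using one absmax scale."""
    absmax = max((abs(v) for v in vec), default=0.0)
    scale = absmax / 127.0 or 1.0  # fall back to 1.0 for empty or all-zero input
    return [round(v / scale) for v in vec], scale


def int8_mixed_matmul(x, w, threshold=6.0):
    """Compute y = x @ w.T for x of shape (rows, k) and w of shape (out, k).

    Columns of x holding any |value| >= threshold are treated as outliers and
    multiplied in full precision. A threshold of 0.0 disables the split.
    """
    k = len(x[0])
    outlier_cols = []
    if threshold > 0.0:
        outlier_cols = [j for j in range(k) if any(abs(row[j]) >= threshold for row in x)]
    regular_cols = [j for j in range(k) if j not in outlier_cols]

    # One scale per row of x and one per row of w (per output feature).
    xq = [absmax_quantize([row[j] for j in regular_cols]) for row in x]
    wq = [absmax_quantize([row[j] for j in regular_cols]) for row in w]

    y = []
    for (x_codes, x_scale), x_row in zip(xq, x):
        out_row = []
        for (w_codes, w_scale), w_row in zip(wq, w):
            # Integer dot product, dequantized with the two scales.
            acc = sum(a * b for a, b in zip(x_codes, w_codes))
            regular = acc * x_scale * w_scale
            # Outlier columns are multiplied without quantization.
            outlier = sum(x_row[j] * w_row[j] for j in outlier_cols)
            out_row.append(regular + outlier)
        y.append(out_row)
    return y
```

Calling int8_mixed_matmul(x, w, threshold=6.0) on an input whose third column contains 8.0 routes that column through the full-precision path, so the large value does not stretch the int8 scale of the other columns.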
Usage
These layers are instantiated by the get_linear factory when the quantization method is bitsandbytes (8-bit), bitsandbytes-nf4 (NF4 4-bit), or bitsandbytes-fp4 (FP4 4-bit). The 8-bit variant is considered deprecated in favor of EETQ.
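The dispatch described above can be sketched as follows. This is a hypothetical stand-in for the relevant branch of get_linear, not its actual code: the names BNB_DISPATCH, select_bnb_layer, and warn_once_8bit_deprecated are assumptions, as is the use of functools.lru_cache to get the "cached" once-only warning. The 8-bit keyword arguments mirror the usage example on this page.

```python
import logging
from functools import lru_cache

logger = logging.getLogger(__name__)

# Quantization strings named on this page, mapped to (layer class name, kwargs).
# Assumption: the mapping is illustrative and may not match get_linear exactly.
BNB_DISPATCH = {
    "bitsandbytes": ("Linear8bitLt", {"has_fp16_weights": False, "threshold": 6.0}),
    "bitsandbytes-nf4": ("Linear4bit", {"quant_type": "nf4"}),
    "bitsandbytes-fp4": ("Linear4bit", {"quant_type": "fp4"}),
}


@lru_cache(maxsize=1)
def warn_once_8bit_deprecated():
    """Log the 8-bit deprecation notice; the cache turns repeat calls into no-ops."""
    logger.warning("bitsandbytes 8-bit is deprecated; consider EETQ instead.")
    return True


def select_bnb_layer(quantize):
    """Return (layer class name, kwargs) for a bitsandbytes quantization string."""
    if quantize not in BNB_DISPATCH:
        raise ValueError(f"not a bitsandbytes quantization method: {quantize!r}")
    if quantize == "bitsandbytes":
        warn_once_8bit_deprecated()
    return BNB_DISPATCH[quantize]
```

Caching the warning helper keeps the log readable when a model contains hundreds of quantized linear layers, since only the first 8-bit layer triggers the message.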
Code Reference
Source Location
- Repository: Predibase_Lorax
- File: server/lorax_server/layers/bnb.py
- Lines: 1-103
Signature
class Linear8bitLt(torch.nn.Module):
    def __init__(self, weight, bias, has_fp16_weights=True,
                 memory_efficient_backward=False, threshold=0.0, index=None):

class Linear4bit(torch.nn.Module):
    def __init__(self, weight, bias, quant_type):
Import
from lorax_server.layers.bnb import Linear8bitLt, Linear4bit
I/O Contract
Inputs (Linear8bitLt)
| Name | Type | Required | Description |
|---|---|---|---|
| weight | torch.Tensor | Yes | Full-precision weight tensor to be quantized to int8 |
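| bias | torch.Tensor or None | Yes | Bias tensor; pass None for a layer without bias (the parameter has no default) |
| has_fp16_weights | bool | No | Whether to keep fp16 weights (default True; set to False for inference) |
| memory_efficient_backward | bool | No | Backward-pass option listed in the signature (default False); not needed for inference |
| threshold | float | No | Outlier threshold for LLM.int8() decomposition (default 0.0; 6.0 is the typical value) |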
| index | int or None | No | Optional index for the layer |
Inputs (Linear4bit)
| Name | Type | Required | Description |
|---|---|---|---|
| weight | torch.Tensor | Yes | Full-precision weight tensor to be quantized to 4-bit |
| bias | torch.Tensor or None | Yes | Bias tensor; pass None for a layer without bias (the parameter has no default) |
| quant_type | str | Yes | Quantization type: "nf4" or "fp4" |
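To illustrate what a 4-bit quant_type selects, the following pure-Python sketch performs block-wise absmax quantization against a 16-entry code table. It is a simplified model under stated assumptions, not the bitsandbytes kernel: the function names are hypothetical, the NF4 levels are rounded approximations, the block size of 4 is chosen for readability, and the real library additionally packs two 4-bit codes per byte. An FP4 variant would swap in a different level table.

```python
# Approximate NF4 code values (16 levels in [-1, 1]), rounded to 4 decimals.
NF4_LEVELS = [
    -1.0, -0.6962, -0.5251, -0.3949, -0.2844, -0.1848, -0.0911, 0.0,
    0.0796, 0.1609, 0.2461, 0.3379, 0.4407, 0.5626, 0.7230, 1.0,
]


def quantize_blockwise(values, levels=NF4_LEVELS, block_size=4):
    """Return (codes, absmaxes): each code is a 4-bit index into levels."""
    codes, absmaxes = [], []
    for start in range(0, len(values), block_size):
        block = values[start:start + block_size]
        absmax = max(abs(v) for v in block) or 1.0  # guard all-zero blocks
        absmaxes.append(absmax)
        for v in block:
            normalized = v / absmax  # now in [-1, 1]
            # Pick the index of the nearest level.
            nearest = min(range(len(levels)), key=lambda i: abs(levels[i] - normalized))
            codes.append(nearest)
    return codes, absmaxes


def dequantize_blockwise(codes, absmaxes, levels=NF4_LEVELS, block_size=4):
    """Map codes back to floats using each block's stored absmax."""
    return [levels[c] * absmaxes[i // block_size] for i, c in enumerate(codes)]
```

Each block stores one absmax value alongside its 4-bit codes, so the reconstruction error of any element is bounded by half the widest gap between adjacent levels times that block's absmax.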
Outputs
| Name | Type | Description |
|---|---|---|
| output | torch.Tensor | Result of the quantized linear transformation, dtype matches input |
Usage Examples
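# Used internally via the linear factory; direct use requires CUDA and bitsandbytes
import torch
from lorax_server.layers.bnb import Linear8bitLt, Linear4bit

# Illustrative tensors: weight is (out_features, in_features)
weight = torch.randn(1024, 512, dtype=torch.float16, device="cuda")
input_tensor = torch.randn(4, 512, dtype=torch.float16, device="cuda")

# 8-bit quantized layer (deprecated in favor of EETQ)
layer_8bit = Linear8bitLt(weight, bias=None, has_fp16_weights=False, threshold=6.0)
output = layer_8bit(input_tensor)

# 4-bit NF4 quantized layer
layer_4bit = Linear4bit(weight, bias=None, quant_type="nf4")
output = layer_4bit(input_tensor)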