
Implementation:Bitsandbytes foundation Bitsandbytes MatMul8bitLt

From Leeroopedia


Metadata

Field Value
Sources Repo: bitsandbytes, Paper: LLM.int8()
Domains Quantization, Linear_Algebra
Type API Doc
Last updated 2026-02-07 14:00 GMT

Overview

A concrete tool, provided by the bitsandbytes library, for mixed-precision INT8/FP16 matrix multiplication with outlier decomposition.

Description

MatMul8bitLt is a torch.autograd.Function that implements the LLM.int8() mixed-precision matrix multiplication with full autograd support for both forward and backward passes.

Forward pass:

  1. Quantize activations (A): Calls int8_vectorwise_quant (fast path, used when no gradient is needed for B) or int8_double_quant (slow path, used when a gradient is needed for B) to quantize the input activations to INT8. Outlier columns are identified as a side effect.
  2. Quantize weights (B): If weights are not already quantized (first pass or has_fp16_weights=True), calls int8_vectorwise_quant on the weights. Quantized weights (CB) and scales (SCB) are stored in the MatmulLtState.
  3. Dispatch matmul:
    • With outliers (threshold > 0): Dispatches to torch.ops.bitsandbytes.int8_mixed_scaled_mm, which performs the split INT8 + FP16 computation.
    • Without outliers (threshold = 0): Dispatches to torch.ops.bitsandbytes.int8_scaled_mm, which performs pure INT8 scaled matmul.
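The forward split can be illustrated with a small pure-PyTorch sketch. This is not the bitsandbytes kernel path (int8_mixed_scaled_mm fuses these steps on the GPU); the function name and the column-splitting details are illustrative, but the structure mirrors the LLM.int8() scheme: row-wise INT8 quantization for regular columns, FP16 matmul for outlier columns, results summed.

```python
import torch

def int8_matmul_with_outliers(A: torch.Tensor, B: torch.Tensor, threshold: float = 6.0):
    """Illustrative sketch of the LLM.int8() forward split (not the real kernels).

    A: (m, k) activations; B: (n, k) weights, stored (out_features, in_features)
    as in torch.nn.Linear. Columns of A whose max |value| exceeds `threshold`
    stay in floating point; the rest go through row-wise INT8 quantization.
    """
    # 1. Identify outlier columns of A (features with large magnitude).
    outlier_cols = (A.abs().amax(dim=0) > threshold).nonzero().flatten()
    outlier_set = set(outlier_cols.tolist())
    regular_cols = torch.tensor(
        [c for c in range(A.shape[1]) if c not in outlier_set], dtype=torch.long
    )

    # 2. Row-wise (vector-wise) INT8 quantization of the regular part.
    A_reg, B_reg = A[:, regular_cols], B[:, regular_cols]
    scale_A = A_reg.abs().amax(dim=1, keepdim=True) / 127.0  # per-row scale of A
    scale_B = B_reg.abs().amax(dim=1, keepdim=True) / 127.0  # per-row scale of B
    A_q = torch.round(A_reg / scale_A).to(torch.int8)
    B_q = torch.round(B_reg / scale_B).to(torch.int8)

    # 3. INT8 matmul with int32 accumulation, rescaled back to float.
    acc = A_q.to(torch.int32) @ B_q.to(torch.int32).T
    out = acc.to(A.dtype) * (scale_A @ scale_B.T)

    # 4. Floating-point matmul for the outlier columns, added back in.
    if outlier_cols.numel() > 0:
        out = out + A[:, outlier_cols] @ B[:, outlier_cols].T
    return out
```

Note how the split keeps the handful of high-magnitude feature dimensions out of the INT8 path entirely, which is what preserves accuracy at scale.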

Backward pass:

  1. Gradient w.r.t. input (grad_A): Dequantizes weights via CB * SCB / 127 and performs FP16/BF16 matmul with grad_output.
  2. Gradient w.r.t. weights (grad_B): Uses int8_double_quant on the gradient and performs INT8 scaled matmul with the transposed quantized activations. If outliers were present, adds the FP16 gradient contribution for outlier columns.
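The grad_A step above can be sketched in a few lines of plain PyTorch. The function name is hypothetical; it only illustrates the dequantization formula from the text (W ≈ CB * SCB / 127) followed by a standard matmul.

```python
import torch

def grad_A_from_int8_weights(grad_output: torch.Tensor,
                             CB: torch.Tensor,
                             SCB: torch.Tensor) -> torch.Tensor:
    """Sketch of the grad_A computation: dequantize INT8 weights row-wise,
    then matmul with the incoming gradient.

    CB:  (n, k) INT8 quantized weights
    SCB: (n,)   per-row absmax scales, so W ~= CB * SCB / 127
    """
    W = CB.to(grad_output.dtype) * SCB.unsqueeze(1).to(grad_output.dtype) / 127.0
    return grad_output @ W  # (m, n) @ (n, k) -> (m, k)
```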

The matmul dispatch function:

The top-level matmul() function selects the appropriate implementation:

  • GPU (CUDA): Always uses MatMul8bitLt.
  • CPU/XPU (training): Uses MatMul8bitFp, a fallback that dequantizes weights before standard matmul (faster on CPU/XPU due to lack of fast INT8 quant/dequant kernels).
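The MatMul8bitFp-style fallback described above amounts to dequantize-then-matmul. A minimal sketch, assuming the same CB/SCB storage convention as MatmulLtState (the function name is illustrative, not the bitsandbytes API):

```python
import torch

def matmul_8bit_fp_fallback(A: torch.Tensor,
                            CB: torch.Tensor,
                            SCB: torch.Tensor,
                            bias: torch.Tensor = None) -> torch.Tensor:
    """Sketch of the CPU/XPU fallback: dequantize weights to A's dtype and use
    a standard matmul, which is faster there than INT8 quant/dequant round-trips."""
    W = CB.to(A.dtype) * SCB.unsqueeze(1).to(A.dtype) / 127.0  # (n, k) dequantized
    out = A @ W.T                                              # (m, k) @ (k, n)
    if bias is not None:
        out = out + bias
    return out
```

The trade-off: memory savings from INT8 storage are kept, but the compute happens entirely in floating point.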

Supporting classes:

  • MatmulLtState (dataclass, L55-93): Tracks per-layer state including quantized weights (CB), scaling factors (SCB), outlier threshold, training mode, and FP16 weight retention flag.
  • GlobalOutlierPooler (singleton, L22-49): Pools outlier column indices across layers. Particularly useful for small models where outlier features are less systematic and occur with low frequency.
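The pooling behavior of GlobalOutlierPooler can be sketched as a small singleton that unions outlier indices across layers of matching feature dimension. The class and method names below are illustrative (the real implementation lives in bitsandbytes/autograd/_functions.py and stores indices differently):

```python
class OutlierPoolerSketch:
    """Minimal sketch of a GlobalOutlierPooler-style singleton: accumulate
    outlier feature indices seen across layers so that infrequent outliers
    are treated consistently everywhere. Illustrative only."""
    _instance = None

    def __init__(self):
        self.outliers = set()
        self.model_dim = None

    @classmethod
    def get_instance(cls):
        # Singleton access: every layer shares the same pool.
        if cls._instance is None:
            cls._instance = cls()
        return cls._instance

    def add_outliers(self, outlier_idx, feature_dim):
        if self.model_dim is None:
            self.model_dim = feature_dim
        if feature_dim == self.model_dim:  # only pool matching feature dims
            self.outliers.update(outlier_idx)

    def get_current_outlier_idx(self):
        return sorted(self.outliers)
```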

Code Reference

  • Source: bitsandbytes repo
  • Files:
    • bitsandbytes/autograd/_functions.py: MatMul8bitLt (L95-237), matmul (L351-366)
    • bitsandbytes/autograd/_functions.py: MatmulLtState (L55-93), GlobalOutlierPooler (L22-49)
    • bitsandbytes/nn/modules.py: Linear8bitLt.forward (L1087-1101)
  • Import:
import bitsandbytes as bnb
# Use via: bnb.matmul(A, B, ...)

# Or import directly:
from bitsandbytes.autograd._functions import MatMul8bitLt
  • Signature for matmul:
def matmul(
    A: torch.Tensor,
    B: torch.Tensor,
    out: Optional[torch.Tensor] = None,
    state: Optional[MatmulLtState] = None,
    threshold: float = 0.0,
    bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
  • Signature for MatMul8bitLt.forward:
@staticmethod
def forward(
    ctx: torch.autograd.function.FunctionCtx,
    A: torch.Tensor,
    B: torch.Tensor,
    out: Optional[torch.Tensor] = None,
    bias: Optional[torch.Tensor] = None,
    state: Optional[MatmulLtState] = None,
) -> torch.Tensor:

I/O Contract

Inputs (matmul)

Parameter Type Required Default Description
A torch.Tensor Yes -- Input activations. Will be cast to float16 if not already.
B torch.Tensor / Int8Params Yes -- Weight matrix. Can be FP16 (will be quantized) or pre-quantized Int8Params.
out torch.Tensor or None No None Optional pre-allocated output tensor.
state MatmulLtState or None No None Per-layer state object tracking quantized weights and configuration. A new one is created if None.
threshold float No 0.0 Outlier detection threshold. Overrides state.threshold if > 0.
bias torch.Tensor or None No None Optional bias vector added to the output.

Outputs

Output Type Description
result torch.Tensor Output activations with shape (*A.shape[:-1], B.shape[0]).
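The shape rule in the contract follows the torch.nn.Linear convention: B is stored as (out_features, in_features). A quick pure-PyTorch check of that rule (plain matmul standing in for the mixed-precision computation, not a call into bitsandbytes):

```python
import torch

# result.shape == (*A.shape[:-1], B.shape[0]) per the contract above.
A = torch.randn(2, 10, 64)   # batched activations (..., in_features)
B = torch.randn(128, 64)     # weights (out_features, in_features)
result = A @ B.T             # same shape semantics as bnb.matmul(A, B)
assert result.shape == (*A.shape[:-1], B.shape[0])  # torch.Size([2, 10, 128])
```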

Usage Examples

Using bnb.matmul directly:

import torch
import bitsandbytes as bnb
from bitsandbytes.autograd._functions import MatmulLtState

# Create input tensors
A = torch.randn(2, 64, dtype=torch.float16, device="cuda")
B = torch.randn(128, 64, dtype=torch.float16, device="cuda")

# Create state for tracking quantized weights across calls
state = MatmulLtState()
state.threshold = 6.0
state.has_fp16_weights = True
state.is_training = False

# Perform mixed-precision INT8 matmul
output = bnb.matmul(A, B, state=state)
print(output.shape)  # torch.Size([2, 128])

Using bnb.matmul with threshold (outlier decomposition):

import torch
import bitsandbytes as bnb

A = torch.randn(4, 256, dtype=torch.float16, device="cuda")
B = torch.randn(512, 256, dtype=torch.float16, device="cuda")

# threshold > 0 enables LLM.int8() outlier decomposition
output = bnb.matmul(A, B, threshold=6.0)
print(output.shape)  # torch.Size([4, 512])
