Implementation:Bitsandbytes foundation Bitsandbytes Research FP8 Matmul
| Knowledge Sources | |
|---|---|
| Domains | Quantization, Research, Training |
| Last Updated | 2026-02-07 13:31 GMT |
Overview
Experimental PyTorch autograd functions that simulate FP8 matrix multiplication in the forward and backward passes, using either mixed (blockwise plus global) or fully global quantization.
Description
This module provides experimental PyTorch autograd functions that simulate FP8 (8-bit floating point) training: activations and weights are quantized to an FP8 codebook and dequantized again, and the matrix multiplication itself runs in standard floating point. Two quantization strategies are offered. MatMulFP8Mixed uses blockwise quantization for activations and global quantization for weights, while MatMulFP8Global uses global quantization for both. A third function, SwitchBackBnb, implements INT8 matmul with outlier decomposition for research use. The module also provides convenience wrappers (matmul_fp8_global, matmul_fp8_mixed, switchback_bnb) and a helper that selects block sizes.
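The difference between global and blockwise scaling can be illustrated without a GPU. The following is a minimal pure-Python sketch, not the library's implementation: `make_toy_codebook`, `nearest`, `fake_quant_global`, `fake_quant_blockwise`, and `max_rel_error` are hypothetical names, and the toy codebook only approximates what an E4M3-style map might look like (it is not the output of `create_fp8_map`). Each value is scaled by an absmax, snapped to the nearest codebook entry, and scaled back.

```python
from bisect import bisect_left


def make_toy_codebook(exp_bits=4, man_bits=3):
    """Hypothetical signed mini-float grid, normalized to [-1, 1]."""
    values = {0.0}
    for e in range(2 ** exp_bits):
        for m in range(2 ** man_bits):
            frac = m / 2 ** man_bits
            # Exponent field 0 holds subnormals; the bias cancels after normalization.
            mag = frac * 2.0 if e == 0 else (1.0 + frac) * 2.0 ** e
            values.update((mag, -mag))
    top = max(values)
    return sorted(v / top for v in values)


def nearest(code, x):
    """Return the entry of the sorted codebook closest to x."""
    i = bisect_left(code, x)
    candidates = code[max(i - 1, 0): i + 1]
    return min(candidates, key=lambda c: abs(c - x))


def fake_quant_global(xs, code):
    """Quantize-dequantize with one absmax scale for the whole list."""
    absmax = max(abs(x) for x in xs) or 1.0  # avoid dividing by zero
    return [nearest(code, x / absmax) * absmax for x in xs]


def fake_quant_blockwise(xs, code, blocksize):
    """Quantize-dequantize with a separate absmax scale per block."""
    out = []
    for start in range(0, len(xs), blocksize):
        out.extend(fake_quant_global(xs[start:start + blocksize], code))
    return out


def max_rel_error(xs, ys):
    """Largest relative error; assumes xs contains no exact zeros."""
    return max(abs(x - y) / abs(x) for x, y in zip(xs, ys))


code = make_toy_codebook()
small = [1e-4, 2e-4, 3e-4, 4e-4]     # tiny values in the first block
large = [100.0, 50.0, 25.0, 12.5]    # large values in the second block
xs = small + large

err_global = max_rel_error(xs, fake_quant_global(xs, code))
err_block = max_rel_error(xs, fake_quant_blockwise(xs, code, blocksize=4))
print("global:", err_global, "blockwise:", err_block)
```

With a single global scale, the smallest inputs fall below the codebook's smallest nonzero entry and round to zero, whereas per-block scales keep them within the representable range. Real activations rarely span this many orders of magnitude, so the gap in practice is likely smaller than in this contrived input.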
Usage
Use these functions when experimenting with FP8-simulated training to evaluate quantization-aware matmul accuracy without requiring hardware FP8 support. They serve as research tools for studying the impact of different quantization granularities (blockwise vs global) on training dynamics.
Code Reference
Source Location
- Repository: bitsandbytes
- File: bitsandbytes/research/autograd/_functions.py
- Lines: 1-396
Signature
class MatMulFP8Mixed(torch.autograd.Function):
    @staticmethod
    def forward(ctx, A, B, out=None, fw_code=None, bw_code=None, bsz=1024, bsz2=1024):
        ...
    @staticmethod
    def backward(ctx, grad_output):
        ...
class MatMulFP8Global(torch.autograd.Function):
    @staticmethod
    def forward(ctx, A, B, out=None, fw_code=None, bw_code=None, bsz=1024, bsz2=1024):
        ...
    @staticmethod
    def backward(ctx, grad_output):
        ...
class SwitchBackBnb(torch.autograd.Function):
    @staticmethod
    def forward(ctx, A, B, out=None, bias=None, state=None):
        ...
    @staticmethod
    def backward(ctx, grad_output):
        ...
def matmul_fp8_global(A: torch.Tensor, B: torch.Tensor, fw_code: torch.Tensor,
                      bw_code: torch.Tensor, out=None, bsz=-1, bsz2=-1):
    ...
def matmul_fp8_mixed(A: torch.Tensor, B: torch.Tensor, fw_code: torch.Tensor,
                     bw_code: torch.Tensor, out=None, bsz=-1, bsz2=-1):
    ...
def switchback_bnb(A: torch.Tensor, B: torch.Tensor, out=None, state=None,
                   threshold=0.0, bias=None):
    ...
Import
from bitsandbytes.research.autograd._functions import (
    MatMulFP8Mixed, MatMulFP8Global, SwitchBackBnb,
    matmul_fp8_global, matmul_fp8_mixed, switchback_bnb,
)
# Or via the research namespace:
import bitsandbytes as bnb
bnb.research.matmul_fp8_global(...)
bnb.research.matmul_fp8_mixed(...)
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| A | torch.Tensor | Yes | Input activation tensor (2D or 3D) |
| B | torch.Tensor | Yes | Weight tensor |
| fw_code | torch.Tensor | Yes | FP8 quantization codebook for forward pass (e.g., E4M3 map) |
| bw_code | torch.Tensor | Yes | FP8 quantization codebook for backward pass (e.g., E5M2 map) |
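| bsz | int | No | Block size for blockwise quantization of A; selected automatically when -1 (the wrapper default) |
| bsz2 | int | No | Block size for blockwise quantization along the output dimension; selected automatically when -1 (the wrapper default) |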
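The automatic selection for bsz and bsz2 can be pictured as choosing a block size from the feature dimensions. The sketch below is an assumption for illustration only: `pick_block_size`, `resolve_block_sizes`, and the candidate list are hypothetical, and the library's own helper may use different thresholds or resolve both sizes together.

```python
def pick_block_size(n_features, candidates=(4096, 2048, 1024, 512, 256, 128, 64)):
    """Hypothetical heuristic: largest candidate block that fits in n_features."""
    for size in candidates:
        if n_features >= size:
            return size
    # Fall back to the smallest candidate for very narrow dimensions.
    return candidates[-1]


def resolve_block_sizes(in_features, out_features, bsz=-1, bsz2=-1):
    """Replace any -1 sentinel with a size derived from the matching dimension."""
    if bsz == -1:
        bsz = pick_block_size(in_features)
    if bsz2 == -1:
        bsz2 = pick_block_size(out_features)
    return bsz, bsz2


# Example: activations with 512 input features, weights with 256 output features.
print(resolve_block_sizes(512, 256))
```

Explicit values pass through unchanged, so a caller can pin one block size and leave the other on automatic.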
Outputs
| Name | Type | Description |
|---|---|---|
| output | torch.Tensor | Result of quantize-dequantize-matmul (same shape as standard matmul) |
| grad_A | torch.Tensor | Gradient w.r.t. A, computed in the backward pass from the FP8-quantized output gradient |
| grad_B | torch.Tensor | Gradient w.r.t. B (backward pass) |
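The outputs above follow the usual matmul gradient identities with a quantize-dequantize step applied to selected tensors. As a hedged sketch: write $\mathrm{DQ}_{\mathrm{fw}}$ and $\mathrm{DQ}_{\mathrm{bw}}$ for quantize-then-dequantize with the forward and backward codebooks (notation introduced here, not taken from the source), and let $G$ be the incoming gradient. Consistent with the table, only grad_A is shown as using the quantized output gradient. Whether the backward pass multiplies by the original A and B or by their dequantized copies is an assumption in this sketch.

```latex
\begin{aligned}
\hat{A} &= \mathrm{DQ}_{\mathrm{fw}}(A), \qquad \hat{B} = \mathrm{DQ}_{\mathrm{fw}}(B), \qquad Y = \hat{A}\,\hat{B} \\
G &= \frac{\partial L}{\partial Y}, \qquad \hat{G} = \mathrm{DQ}_{\mathrm{bw}}(G) \\
\nabla_A L &\approx \hat{G}\, B^{\top}, \qquad \nabla_B L \approx A^{\top} G
\end{aligned}
```

For $A \in \mathbb{R}^{m \times k}$ and $B \in \mathbb{R}^{k \times n}$, $Y$ and $G$ are $m \times n$, so $\nabla_A L$ is $m \times k$ and $\nabla_B L$ is $k \times n$, matching the shapes of A and B.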
Usage Examples
FP8 Mixed Quantization Matmul
import torch
import bitsandbytes as bnb
# Create FP8 codebooks: E4M3 for forward, E5M2 for backward
fw_code = bnb.functional.create_fp8_map(True, 4, 3, 8).to("cuda")
bw_code = bnb.functional.create_fp8_map(True, 5, 2, 8).to("cuda")
A = torch.randn(32, 512, device="cuda", dtype=torch.float16, requires_grad=True)
B = torch.randn(256, 512, device="cuda", dtype=torch.float16, requires_grad=True)
# Mixed: A quantized blockwise, B quantized globally
output = bnb.research.matmul_fp8_mixed(A, B.t(), fw_code=fw_code, bw_code=bw_code)
output.sum().backward()
FP8 Global Quantization Matmul
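# Reuses A, B, fw_code and bw_code from the previous example;
# gradients accumulate in A.grad and B.grad unless they are reset first.
# Global: A and B are each quantized with a single global absmax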
output = bnb.research.matmul_fp8_global(A, B.t(), fw_code=fw_code, bw_code=bw_code)
output.sum().backward()