

Implementation:Bitsandbytes foundation Bitsandbytes Research FP8 Matmul



Knowledge Sources
Domains: Quantization, Research, Training
Last Updated: 2026-02-07 13:31 GMT

Overview

Research autograd functions that simulate FP8 matrix multiplication in the forward and backward passes, using either a mixed-granularity strategy (blockwise activations, global weights) or a global strategy.

Description

This module provides experimental PyTorch autograd functions that simulate FP8 (8-bit floating point) training. Activations and weights are quantized to 8-bit indices into an FP8 codebook and immediately dequantized back to their original dtype, so the matrix multiplication itself is an ordinary floating-point matmul that carries FP8 rounding error. Two quantization strategies are offered: MatMulFP8Mixed uses blockwise quantization (one absmax scale per block) for activations and global quantization (one absmax scale per tensor) for weights, while MatMulFP8Global uses global quantization for both. A third function, SwitchBackBnb, is a research variant of the INT8 matmul with outlier decomposition. The module also provides convenience wrappers (matmul_fp8_global, matmul_fp8_mixed, switchback_bnb) and a helper that selects block sizes automatically.
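The difference between the two granularities can be illustrated without bitsandbytes. The sketch below is a minimal pure-Python model, assuming that each group of values is divided by its absmax, snapped to the nearest entry of a codebook normalized to [-1, 1], and rescaled. The function names and the 9-entry toy codebook are hypothetical stand-ins for the library's quantization kernels and 256-entry FP8 maps.

```python
from bisect import bisect_left

# Toy 9-entry codebook standing in for a 256-entry FP8 map (hypothetical).
CODE = [-1.0, -0.5, -0.25, -0.125, 0.0, 0.125, 0.25, 0.5, 1.0]


def nearest(code, x):
    """Return the entry of the sorted codebook closest to x."""
    i = bisect_left(code, x)
    candidates = code[max(i - 1, 0):i + 1]
    return min(candidates, key=lambda c: abs(c - x))


def quant_dequant(values, code):
    """Scale a group by its absmax, snap to the codebook, and rescale."""
    absmax = max(abs(v) for v in values) or 1.0  # all-zero group: avoid /0
    return [nearest(code, v / absmax) * absmax for v in values]


def global_qdq(values, code):
    """Global strategy: one absmax shared by the whole tensor."""
    return quant_dequant(values, code)


def blockwise_qdq(values, code, blocksize):
    """Blockwise strategy: an independent absmax for every block."""
    out = []
    for start in range(0, len(values), blocksize):
        out.extend(quant_dequant(values[start:start + blocksize], code))
    return out


def err(original, approx):
    """Total absolute quantization error."""
    return sum(abs(a - b) for a, b in zip(original, approx))


# One large outlier in the first block, small values in the second.
x = [100.0, 1.0, 2.0, 3.0, 0.3, 0.1, 0.2, 0.4]
g = global_qdq(x, CODE)
b = blockwise_qdq(x, CODE, blocksize=4)
print("global:   ", g, err(x, g))
print("blockwise:", b, err(x, b))
```

With one outlier, the global scale flushes every other value to zero, whereas the blockwise scale confines that loss to the outlier's own block and keeps the second block close to its original values.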

Usage

Use these functions to experiment with FP8-simulated training: they show how FP8 rounding affects matmul accuracy without requiring hardware FP8 support, because the multiplication itself runs in the input dtype. They serve as research tools for studying how quantization granularity (blockwise vs. global) affects training dynamics.

Code Reference

Source Location

Signature

class MatMulFP8Mixed(torch.autograd.Function):
    @staticmethod
    def forward(ctx, A, B, out=None, fw_code=None, bw_code=None, bsz=1024, bsz2=1024):
        ...
    @staticmethod
    def backward(ctx, grad_output):
        ...

class MatMulFP8Global(torch.autograd.Function):
    @staticmethod
    def forward(ctx, A, B, out=None, fw_code=None, bw_code=None, bsz=1024, bsz2=1024):
        ...
    @staticmethod
    def backward(ctx, grad_output):
        ...

class SwitchBackBnb(torch.autograd.Function):
    @staticmethod
    def forward(ctx, A, B, out=None, bias=None, state=None):
        ...
    @staticmethod
    def backward(ctx, grad_output):
        ...

def matmul_fp8_global(A: torch.Tensor, B: torch.Tensor, fw_code: torch.Tensor,
                       bw_code: torch.Tensor, out=None, bsz=-1, bsz2=-1):
    ...

def matmul_fp8_mixed(A: torch.Tensor, B: torch.Tensor, fw_code: torch.Tensor,
                      bw_code: torch.Tensor, out=None, bsz=-1, bsz2=-1):
    ...

def switchback_bnb(A: torch.Tensor, B: torch.Tensor, out=None, state=None,
                    threshold=0.0, bias=None):
    ...

Import

from bitsandbytes.research.autograd._functions import (
    MatMulFP8Mixed, MatMulFP8Global, SwitchBackBnb,
    matmul_fp8_global, matmul_fp8_mixed, switchback_bnb,
)
# Or via the research namespace:
import bitsandbytes as bnb
bnb.research.matmul_fp8_global(...)
bnb.research.matmul_fp8_mixed(...)

I/O Contract

Inputs

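Name | Type | Required | Description
A | torch.Tensor | Yes | Input activation tensor (2D or 3D)
B | torch.Tensor | Yes | Weight tensor of shape (in_features, out_features), so the result matches torch.matmul(A, B); an (out_features, in_features) weight is passed as B.t()
fw_code | torch.Tensor | Yes | FP8 quantization codebook for the forward pass (e.g., an E4M3 map)
bw_code | torch.Tensor | Yes | FP8 quantization codebook for the backward pass (e.g., an E5M2 map)
bsz | int | No | Block size for blockwise quantization of A (used by the mixed variant); auto-detected if -1
bsz2 | int | No | Block size for blockwise quantization along the output dimension (used by the mixed variant); auto-detected if -1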
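When bsz or bsz2 is -1, the wrappers pick block sizes from the tensor shapes. The sketch below is a hypothetical re-implementation of that rule as we understand it, not the library helper: each dimension gets the smallest block size from 64 to 4096 that is at least as large as the dimension, and anything above 4096 gets 4096. The output dimension is taken from whichever axis of B is not matched against A's last axis.

```python
# Block sizes the sketch chooses from (hypothetical list, largest first).
BLOCK_SIZES = [4096, 2048, 1024, 512, 256, 128, 64]


def pick_block_size(features: int) -> int:
    """Smallest listed block size >= features, capped at the largest one."""
    for size in reversed(BLOCK_SIZES):  # try 64, 128, ..., 4096
        if features <= size:
            return size
    return BLOCK_SIZES[0]


def pick_block_sizes(a_shape, b_shape):
    """Return (bsz, bsz2) for A of shape a_shape and B of shape b_shape."""
    in_features = a_shape[-1]
    # B may be laid out (in, out) or (out, in); the output dimension is
    # whichever axis is not matched against A's last axis.
    out_features = b_shape[0] if b_shape[1] == in_features else b_shape[1]
    return pick_block_size(in_features), pick_block_size(out_features)


print(pick_block_sizes((32, 512), (512, 256)))
```

For the shapes used in the usage examples (A is 32 x 512 and B.t() is 512 x 256), this sketch returns (512, 256).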

Outputs

Name | Type | Description
output | torch.Tensor | Result of quantize-dequantize followed by matmul; same shape as torch.matmul(A, B)
grad_A | torch.Tensor | Gradient w.r.t. A (backward pass), computed from the output gradient after FP8 quantization with bw_code
grad_B | torch.Tensor | Gradient w.r.t. B (backward pass)
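The gradients in the table come from the autograd backward. A minimal pure-Python sketch of one plausible reading of the simulated backward for output = A @ B: the incoming gradient G is quantize-dequantized with bw_code, grad_A is formed from it and the FP8-rounded weights, and grad_B is formed from the unrounded A and G. The rounding is treated as the identity when differentiating (a straight-through estimate). Exactly which operands the library rounds is an assumption here, global (single-absmax) rounding stands in for both variants, and the helper names and toy codebook are hypothetical.

```python
def transpose(M):
    """Transpose a nested-list matrix."""
    return [list(col) for col in zip(*M)]


def matmul(X, Y):
    """Plain full-precision matrix product of nested lists."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*Y)] for row in X]


def qdq(M, code):
    """Global quantize-dequantize: one absmax, nearest codebook entry."""
    absmax = max(abs(v) for row in M for v in row) or 1.0

    def snap(x):
        return min(code, key=lambda c: abs(c - x))

    return [[snap(v / absmax) * absmax for v in row] for row in M]


def simulated_backward(A, B, G, fw_code, bw_code):
    """Gradients of output = A @ B given the incoming gradient G."""
    B_hat = qdq(B, fw_code)  # weights as rounded in the forward pass
    G_hat = qdq(G, bw_code)  # output gradient rounded with the backward code
    grad_A = matmul(G_hat, transpose(B_hat))  # straight-through w.r.t. rounding
    grad_B = matmul(transpose(A), G)          # unrounded operands (assumption)
    return grad_A, grad_B


CODE = [-1.0, -0.5, 0.0, 0.5, 1.0]  # toy codebook (hypothetical)
A = [[1.0, 2.0, 3.0]]
B = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
G = [[0.3, 1.0]]  # 0.3 is rounded to 0.5 by CODE
grad_A, grad_B = simulated_backward(A, B, G, CODE, CODE)
print(grad_A)  # -> [[0.5, 1.0, 1.5]]
print(grad_B)
```

Rounding 0.3 up to 0.5 shifts every entry of grad_A that depends on it, while grad_B in this sketch stays exact; grad_A and grad_B keep the shapes of A and B.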

Usage Examples

FP8 Mixed Quantization Matmul

import torch
import bitsandbytes as bnb

# Create FP8 codebooks: E4M3 for forward, E5M2 for backward
fw_code = bnb.functional.create_fp8_map(True, 4, 3, 8).to("cuda")
bw_code = bnb.functional.create_fp8_map(True, 5, 2, 8).to("cuda")

A = torch.randn(32, 512, device="cuda", dtype=torch.float16, requires_grad=True)
B = torch.randn(256, 512, device="cuda", dtype=torch.float16, requires_grad=True)

# Mixed: A quantized blockwise, B quantized globally
output = bnb.research.matmul_fp8_mixed(A, B.t(), fw_code=fw_code, bw_code=bw_code)
output.sum().backward()
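The forward codebook (E4M3) and backward codebook (E5M2) trade precision against range: E4M3 keeps three mantissa bits, while E5M2 gives one of them up for an extra exponent bit, a common choice for gradients whose magnitudes span many orders. The positional arguments of create_fp8_map appear to be (signed, exponent bits, precision bits, total bits). The sketch below makes the trade-off concrete by enumerating a simplified IEEE-style minifloat (bias 2^(e-1) - 1, subnormals, no values reserved for infinity or NaN). It is not a re-implementation of create_fp8_map, whose exponent bias and exact values may differ, and the function names are hypothetical.

```python
def minifloat_values(exp_bits: int, man_bits: int) -> list[float]:
    """All finite values of a signed, IEEE-style minifloat (no inf/NaN)."""
    bias = 2 ** (exp_bits - 1) - 1
    mags = set()
    for e in range(2 ** exp_bits):
        for m in range(2 ** man_bits):
            frac = m / 2 ** man_bits
            if e == 0:
                mags.add(frac * 2.0 ** (1 - bias))        # subnormal (incl. zero)
            else:
                mags.add((1 + frac) * 2.0 ** (e - bias))  # normal
    # Mirror to negative values; +0.0 and -0.0 collapse into one entry.
    return sorted(mags | {-v for v in mags})


def normalized_code(exp_bits: int, man_bits: int) -> list[float]:
    """Scale the value set into [-1, 1], the range an absmax codebook uses."""
    vals = minifloat_values(exp_bits, man_bits)
    top = max(vals)
    return [v / top for v in vals]


for name, (e, m) in {"E4M3": (4, 3), "E5M2": (5, 2)}.items():
    vals = minifloat_values(e, m)
    smallest = min(v for v in vals if v > 0)
    print(f"{name}: {len(vals)} values, smallest {smallest}, largest {max(vals)}")
```

Under these assumptions E4M3 covers magnitudes from 2^-9 up to 480 and E5M2 covers 2^-16 up to 114688, but adjacent E5M2 values are twice as far apart relative to their magnitude.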

FP8 Global Quantization Matmul

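# Reuses A, B, fw_code and bw_code from the mixed example above;
# clear the gradients accumulated by its backward pass first.
A.grad, B.grad = None, None
# Global: A and B are each quantized with one tensor-wide absmax
output = bnb.research.matmul_fp8_global(A, B.t(), fw_code=fw_code, bw_code=bw_code)
output.sum().backward()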
