Implementation:Bitsandbytes foundation Bitsandbytes LinearFP8
| Knowledge Sources | |
|---|---|
| Domains | Research, Quantization, Neural_Network_Modules |
| Last Updated | 2026-02-07 13:31 GMT |
Overview
Research-only linear layers (nn.Linear subclasses) that simulate FP8 precision in both the forward and backward passes, using either a mixed or a global quantization strategy.
Description
LinearFP8Mixed and LinearFP8Global are drop-in nn.Linear replacements for research on FP8-simulated training. Both lazily create their FP8 codebooks on the first forward call (E4M3 for the forward pass, E5M2 for the backward pass) and automatically select quantization block sizes from the input and output feature dimensions. LinearFP8Mixed uses blockwise quantization for activations and global quantization for weights, while LinearFP8Global uses global quantization for both operands. The two modules delegate to the matmul_fp8_mixed and matmul_fp8_global autograd functions, respectively.
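What an FP8 codebook contains, and how the global and blockwise strategies differ, can be sketched in plain Python. This is a minimal illustration under stated assumptions, not the library's code: the helper names (fp8_codebook, quantize_dequantize, simulate_global, simulate_blockwise) are hypothetical, the codebook treats every bit pattern as a finite value (no NaN/Inf codes reserved) and is normalized to [-1, 1], and quantization uses absmax scaling with nearest-value rounding. The library's own codebooks and kernels may differ in these details.

```python
def fp8_codebook(exp_bits, man_bits):
    """Hypothetical helper: every finite value of a signed FP8-like format,
    scaled so the largest magnitude is 1.0 (no NaN/Inf codes reserved)."""
    bias = 2 ** (exp_bits - 1) - 1
    mags = set()
    for e in range(2 ** exp_bits):
        for m in range(2 ** man_bits):
            frac = m / 2 ** man_bits
            if e == 0:
                mags.add(frac * 2.0 ** (1 - bias))  # zero and subnormals
            else:
                mags.add((1 + frac) * 2.0 ** (e - bias))  # normal numbers
    top = max(mags)
    # +0.0 and -0.0 compare equal, so zero appears only once.
    return sorted({s * v / top for v in mags for s in (1, -1)})


def quantize_dequantize(xs, codebook):
    """Absmax-scale the values, snap each to the nearest codebook entry, rescale."""
    absmax = max(abs(v) for v in xs) or 1.0
    return [absmax * min(codebook, key=lambda c: abs(v / absmax - c)) for v in xs]


def simulate_global(xs, codebook):
    # One scale shared by the whole tensor ("global" quantization).
    return quantize_dequantize(xs, codebook)


def simulate_blockwise(xs, codebook, block_size):
    # One scale per contiguous block of `block_size` elements.
    out = []
    for start in range(0, len(xs), block_size):
        out.extend(quantize_dequantize(xs[start:start + block_size], codebook))
    return out


def abs_error(approx, exact):
    return sum(abs(a - b) for a, b in zip(approx, exact))


E4M3 = fp8_codebook(4, 3)  # forward-pass format per the Description
E5M2 = fp8_codebook(5, 2)  # backward-pass format per the Description

# A block of small activations followed by a block containing a large outlier.
xs = [0.001, 0.002, 0.003, 0.004, 1000.0, 1.0, 2.0, 3.0]
glob = simulate_global(xs, E4M3)
block = simulate_blockwise(xs, E4M3, block_size=4)

print("global error on small block:   ", abs_error(glob[:4], xs[:4]))
print("blockwise error on small block:", abs_error(block[:4], xs[:4]))
```

In LinearFP8Mixed terms, simulate_blockwise plays the role of the activation path and simulate_global the weight path. In the toy data, the outlier forces a single global scale that rounds the small values coarsely (two of them to zero), while a per-block scale keeps them close to their original values.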
Usage
Use as nn.Linear replacements when experimenting with FP8 training. These are research modules and not intended for production deployment.
Code Reference
Source Location
- Repository: bitsandbytes
- File: bitsandbytes/research/nn/modules.py
- Lines: 1-76
Signature
class LinearFP8Mixed(nn.Linear):
    def __init__(self, input_features: int, output_features: int, bias: bool = True):
        ...

    def forward(self, x: torch.Tensor) -> torch.Tensor: ...

class LinearFP8Global(nn.Linear):
    def __init__(self, input_features: int, output_features: int, bias: bool = True):
        ...

    def forward(self, x: torch.Tensor) -> torch.Tensor: ...
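Neither constructor takes a block-size argument; block sizes are derived from the feature dimensions instead. The sketch below assumes one plausible rule: pick the smallest size from a fixed list of powers of two that is at least the feature dimension, capped at 4096. The candidate list and the helper pick_block_size are hypothetical; check modules.py for the exact rule.

```python
# Assumed candidate block sizes; the real list lives in the library source.
CANDIDATE_BLOCK_SIZES = [64, 128, 256, 512, 1024, 2048, 4096]


def pick_block_size(features):
    """Hypothetical rule: smallest candidate >= features, capped at the largest."""
    if features <= 0:
        raise ValueError("features must be positive")
    for size in CANDIDATE_BLOCK_SIZES:
        if features <= size:
            return size
    return CANDIDATE_BLOCK_SIZES[-1]


# Under this assumed rule, a LinearFP8Mixed(768, 3072) layer would use:
in_block = pick_block_size(768)    # from input_features
out_block = pick_block_size(3072)  # from output_features
print(in_block, out_block)
```

Capping the size keeps blocks bounded for very wide layers, at the cost of splitting such rows into several independently scaled blocks.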
Import
from bitsandbytes.research.nn.modules import LinearFP8Mixed, LinearFP8Global
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| input_features | int | Yes | Input feature dimension |
| output_features | int | Yes | Output feature dimension |
| bias | bool | No | Include bias term (default True) |
| x | torch.Tensor | Yes | Input to forward(), of shape (*, input_features) |
Outputs
| Name | Type | Description |
|---|---|---|
| output | torch.Tensor | FP8-simulated linear output of shape (*, output_features) |
Usage Examples
FP8 Mixed Linear Layer
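import torch
from bitsandbytes.research.nn.modules import LinearFP8Mixed

# Drop-in replacement for nn.Linear(768, 3072); parameters are float32 by default.
layer = LinearFP8Mixed(768, 3072).cuda()
# Input of shape (batch, seq_len, input_features), kept in float32 to match the weights.
x = torch.randn(32, 128, 768, device="cuda")
output = layer(x)  # FP8-simulated matmul; output shape (32, 128, 3072)
output.sum().backward()  # gradients are simulated with the E5M2 codebook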