Implementation:Bitsandbytes foundation Bitsandbytes SwitchBackLinear
| Knowledge Sources | |
|---|---|
| Domains | Quantization, Neural_Network_Modules, INT8 |
| Last Updated | 2026-02-07 13:31 GMT |
Overview
Triton-based INT8 quantized linear layer that quantizes activations and weights to INT8 during forward and backward passes using configurable quantization strategies.
Description
SwitchBackLinear is an nn.Linear replacement that performs INT8 quantized matrix multiplication using Triton JIT-compiled kernels. It supports three quantization strategies:
- Global quantization: activations are quantized rowwise; weights are quantized with a single global scaling factor.
- Vector-wise quantization: both activations and weights are quantized rowwise.
- Memory-efficient global: same as global, but the quantized activations are saved in the autograd graph instead of the full-precision tensors, which reduces memory at the cost of a dequantization step in the backward pass.
The module also supports pre-quantizing weights for inference via prepare_for_eval().
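The difference between rowwise and global scaling can be illustrated in plain PyTorch. The following is a minimal CPU sketch of symmetric absmax quantization, not the Triton kernels the module actually uses. The function names and the divide-by-127 convention are assumptions made for illustration.

```python
import torch


def quantize_rowwise(t: torch.Tensor):
    """Symmetric INT8 quantization with one scale per row (row absmax / 127)."""
    scale = t.abs().amax(dim=1, keepdim=True).clamp(min=1e-8) / 127.0
    q = torch.round(t / scale).clamp(-127, 127).to(torch.int8)
    return q, scale  # q: (rows, cols) int8, scale: (rows, 1)


def quantize_global(t: torch.Tensor):
    """Symmetric INT8 quantization with a single scale for the whole tensor."""
    scale = t.abs().max().clamp(min=1e-8) / 127.0
    q = torch.round(t / scale).clamp(-127, 127).to(torch.int8)
    return q, scale  # scale is a 0-dim tensor


def int8_linear_emulated(x: torch.Tensor, w: torch.Tensor, vector_wise: bool = False):
    """Emulate y = x @ w.T with INT8 operands and integer accumulation."""
    x_q, x_scale = quantize_rowwise(x)      # activations: always rowwise
    if vector_wise:
        w_q, w_scale = quantize_rowwise(w)  # one scale per output feature
        w_scale = w_scale.t()               # (1, N) so it broadcasts over rows
    else:
        w_q, w_scale = quantize_global(w)   # one scale for the whole weight
    # Accumulate in int64 on CPU; products of int8 values are exact here.
    acc = x_q.to(torch.int64) @ w_q.to(torch.int64).t()
    # Dequantize by multiplying the activation and weight scales back in.
    return acc.to(torch.float32) * x_scale * w_scale
```

With small random inputs, both variants should stay close to the full-precision `x @ w.T`. The vector-wise variant stores one scale per weight row rather than a single scalar.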
Usage
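Use SwitchBackLinear as a drop-in replacement for nn.Linear when you want the layer's matmuls to run in INT8 during training. The memory-efficient variant (mem_efficient=True) additionally reduces the memory held for the backward pass. Triton must be installed.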
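Because the layer depends on Triton, code that must also run on machines without it may want to check availability before choosing a layer type. A minimal sketch using only the standard library follows. `module_available` and `choose_linear_backend` are hypothetical helper names, not part of bitsandbytes.

```python
import importlib.util


def module_available(name: str) -> bool:
    """Return True if the top-level module `name` can be found for import."""
    return importlib.util.find_spec(name) is not None


def choose_linear_backend() -> str:
    """Hypothetical policy: prefer the Triton-backed layer when Triton is present."""
    return "SwitchBackLinear" if module_available("triton") else "nn.Linear"


# Example: report which layer type a model builder would pick on this machine.
print(choose_linear_backend())
```

The check returns a label rather than a class so that the sketch runs without bitsandbytes installed. A real model builder would map the label to the corresponding class.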
Code Reference
Source Location
- Repository: bitsandbytes
- File: bitsandbytes/nn/triton_based_modules.py
- Lines: 1-264
Signature
class SwitchBackLinear(nn.Linear):
    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = True,
        device=None,
        dtype=None,
        vector_wise_quantization: bool = False,
        mem_efficient: bool = False,
    ): ...

    def prepare_for_eval(self) -> None:
        """Pre-quantize weights for inference."""

    def forward(self, x: torch.Tensor) -> torch.Tensor: ...

# Convenience aliases:
SwitchBackLinearGlobal = partial(SwitchBackLinear, vector_wise_quantization=False)
SwitchBackLinearGlobalMemEfficient = partial(SwitchBackLinear, vector_wise_quantization=False, mem_efficient=True)
SwitchBackLinearVectorwise = partial(SwitchBackLinear, vector_wise_quantization=True)
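The convenience aliases are functools.partial objects, which has a practical consequence: they build instances of the base class but are not classes themselves. The sketch below shows this with a stand-in class. `StandInLinear` and its aliases are illustrative names, not the real bitsandbytes objects.

```python
from functools import partial


class StandInLinear:
    """Stand-in exposing the same two flags; not the real bitsandbytes class."""

    def __init__(self, in_features, out_features, bias=True,
                 vector_wise_quantization=False, mem_efficient=False):
        self.in_features = in_features
        self.out_features = out_features
        self.bias = bias
        self.vector_wise_quantization = vector_wise_quantization
        self.mem_efficient = mem_efficient


# Aliases that pre-bind keyword arguments, mirroring the pattern above.
StandInGlobal = partial(StandInLinear, vector_wise_quantization=False)
StandInGlobalMemEfficient = partial(
    StandInLinear, vector_wise_quantization=False, mem_efficient=True
)
StandInVectorwise = partial(StandInLinear, vector_wise_quantization=True)

layer = StandInVectorwise(768, 3072)

# A partial is a callable, not a type: type checks must use the base class.
print(isinstance(layer, StandInLinear))  # True
```

Passing an alias as the second argument to isinstance raises TypeError, so code that inspects layer types should test against the base class and read the flags from the instance.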
Import
from bitsandbytes.nn import SwitchBackLinear
# Or specific variants:
from bitsandbytes.nn.triton_based_modules import (
    SwitchBackLinearGlobal,
    SwitchBackLinearGlobalMemEfficient,
    SwitchBackLinearVectorwise,
)
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| in_features | int | Yes | Input feature dimension |
| out_features | int | Yes | Output feature dimension |
| bias | bool | No | Whether to include bias (default True) |
| vector_wise_quantization | bool | No | Use rowwise quantization for weights (default False = global) |
| mem_efficient | bool | No | Save quantized activations to reduce memory (default False) |
| x | torch.Tensor | Yes | Input tensor passed to forward(), of shape (*, in_features) |
Outputs
| Name | Type | Description |
|---|---|---|
| output | torch.Tensor | Linear transformation result of shape (*, out_features) |
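The (*, in_features) to (*, out_features) contract is the usual one for linear layers. A matmul kernel that works on 2-D operands can honor it by flattening the leading dimensions and restoring them afterwards. The sketch below shows that pattern in plain PyTorch under that assumption. The helper name is hypothetical, and the exact reshaping inside the Triton path may differ.

```python
import torch


def linear_over_leading_dims(x: torch.Tensor, weight: torch.Tensor, bias=None):
    """Apply a 2-D matmul to an input with arbitrary leading dimensions."""
    out_features, in_features = weight.shape
    x_2d = x.reshape(-1, in_features)   # (*, in) -> (M, in)
    y_2d = x_2d @ weight.t()            # (M, in) @ (in, out) -> (M, out)
    if bias is not None:
        y_2d = y_2d + bias              # broadcast over the M rows
    # Restore the original leading dimensions: (M, out) -> (*, out)
    return y_2d.reshape(*x.shape[:-1], out_features)


x = torch.randn(2, 3, 16)
weight = torch.randn(8, 16)
bias = torch.randn(8)
print(linear_over_leading_dims(x, weight, bias).shape)  # torch.Size([2, 3, 8])
```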
Usage Examples
Basic SwitchBack Linear
import torch
from bitsandbytes.nn import SwitchBackLinear
# Replace nn.Linear with SwitchBackLinear
layer = SwitchBackLinear(768, 3072, bias=True).cuda().half()  # fp16 parameters to match the fp16 input
x = torch.randn(32, 128, 768, device="cuda", dtype=torch.float16)
output = layer(x) # INT8 quantized matmul via Triton
output.sum().backward()
Pre-quantized Inference
# Pre-quantize weights for faster inference
layer.eval()
layer.prepare_for_eval()
# Now forward uses pre-computed INT8 weights
with torch.no_grad():
    output = layer(x)
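The idea behind pre-quantized inference can be emulated on CPU: quantize the weight once, cache the INT8 copy and its scale, and reuse them on every eval-mode forward. The class below is an illustrative sketch built on nn.Linear with a global weight scale. The names `PreQuantizedLinearSketch`, `prepare_for_eval_sketch`, `w_int8_sketch`, and `w_scale_sketch` are hypothetical and do not describe the internals of bitsandbytes.

```python
import torch


class PreQuantizedLinearSketch(torch.nn.Linear):
    """Illustrative only: caches an INT8 copy of the weight for eval-mode forward."""

    def prepare_for_eval_sketch(self) -> None:
        with torch.no_grad():
            scale = self.weight.abs().max().clamp(min=1e-8) / 127.0
            w_int8 = torch.round(self.weight / scale).clamp(-127, 127).to(torch.int8)
        # Buffers follow the module across devices and appear in state_dict.
        self.register_buffer("w_int8_sketch", w_int8)
        self.register_buffer("w_scale_sketch", scale)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.training or not hasattr(self, "w_int8_sketch"):
            return super().forward(x)  # regular full-precision path
        x_2d = x.reshape(-1, self.in_features)
        # Activations are still quantized per call, one scale per row.
        x_scale = x_2d.abs().amax(dim=1, keepdim=True).clamp(min=1e-8) / 127.0
        x_q = torch.round(x_2d / x_scale).clamp(-127, 127).to(torch.int64)
        # The weight is not re-quantized: the cached INT8 copy is reused.
        acc = x_q @ self.w_int8_sketch.to(torch.int64).t()
        y = acc.to(x.dtype) * x_scale * self.w_scale_sketch
        if self.bias is not None:
            y = y + self.bias
        return y.reshape(*x.shape[:-1], self.out_features)
```

In this sketch the cached copy is a snapshot, so it goes stale if the weights are updated afterwards. Calling the prepare step again after further training refreshes it.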