Implementation:Bitsandbytes foundation Bitsandbytes SwitchBackLinear

From Leeroopedia
Revision as of 14:34, 16 February 2026 by Admin (talk | contribs) (Auto-imported from implementations/Bitsandbytes_foundation_Bitsandbytes_SwitchBackLinear.md)


Knowledge Sources
Domains Quantization, Neural_Network_Modules, INT8
Last Updated 2026-02-07 13:31 GMT

Overview

Triton-based INT8 quantized linear layer that quantizes activations and weights to INT8 during forward and backward passes using configurable quantization strategies.

Description

SwitchBackLinear is an nn.Linear replacement that performs INT8 quantized matrix multiplication using Triton JIT-compiled kernels. It supports three quantization strategies: (1) global quantization, where activations are quantized rowwise and weights are quantized with a single global scaling factor; (2) vector-wise quantization, where both activations and weights are quantized rowwise; and (3) memory-efficient global quantization, which matches global quantization but saves quantized activations in the autograd graph instead of full-precision tensors, reducing memory at the cost of a dequantization step in the backward pass. The module also supports pre-quantizing weights for inference via prepare_for_eval().
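
The difference between the global and vector-wise strategies can be illustrated without a GPU. The following is a minimal pure-Python sketch, assuming symmetric absmax scaling to the range [-127, 127]; the function names quantize_rowwise, quantize_global, and int8_linear are chosen for this illustration and are not the library's Triton kernels.

```python
def quantize_rowwise(matrix):
    """Quantize each row to INT8 values using that row's own absmax scale."""
    q_rows, scales = [], []
    for row in matrix:
        absmax = max(abs(v) for v in row) or 1.0  # guard against all-zero rows
        scales.append(absmax)
        q_rows.append([round(127 * v / absmax) for v in row])
    return q_rows, scales


def quantize_global(matrix):
    """Quantize the whole matrix to INT8 values using one shared absmax scale."""
    absmax = max(abs(v) for row in matrix for v in row) or 1.0
    q_rows = [[round(127 * v / absmax) for v in row] for row in matrix]
    return q_rows, absmax


def int8_linear(x, w, vector_wise=False):
    """Approximate x @ w.T with integer operands; w has shape (out, in)."""
    x_q, x_scales = quantize_rowwise(x)  # activations are always rowwise
    if vector_wise:
        w_q, w_scales = quantize_rowwise(w)  # one scale per output row
    else:
        w_q, w_scale = quantize_global(w)  # one scale for all weights
        w_scales = [w_scale] * len(w)
    out = []
    for x_row, sx in zip(x_q, x_scales):
        out_row = []
        for w_row, sw in zip(w_q, w_scales):
            acc = sum(a * b for a, b in zip(x_row, w_row))  # integer accumulate
            out_row.append(acc * sx * sw / (127 * 127))  # dequantize
        out.append(out_row)
    return out


x = [[1.0, -2.0], [0.5, 0.25]]
w = [[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]]
y_global = int8_linear(x, w, vector_wise=False)
y_vector = int8_linear(x, w, vector_wise=True)
```

Both results approximate the full-precision product x @ w.T. Rowwise scaling of the weights tends to help when rows differ widely in magnitude, since a single global scale is dominated by the largest entry.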

Usage

Use SwitchBackLinear as a drop-in replacement for nn.Linear when training with INT8 quantized matmuls. Pass mem_efficient=True to also reduce the activation memory held for the backward pass. Triton must be installed.

Code Reference

Source Location

Signature

class SwitchBackLinear(nn.Linear):
    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = True,
        device=None,
        dtype=None,
        vector_wise_quantization: bool = False,
        mem_efficient: bool = False,
    ): ...

    def prepare_for_eval(self) -> None:
        """Pre-quantize weights for inference."""

    def forward(self, x: torch.Tensor) -> torch.Tensor: ...

# Convenience aliases (defined with functools.partial):
SwitchBackLinearGlobal = partial(SwitchBackLinear, vector_wise_quantization=False)
SwitchBackLinearGlobalMemEfficient = partial(SwitchBackLinear, vector_wise_quantization=False, mem_efficient=True)
SwitchBackLinearVectorwise = partial(SwitchBackLinear, vector_wise_quantization=True)
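
The aliases above are functools.partial objects that preset keyword arguments. The sketch below shows the pattern with a hypothetical stand-in class, DemoLinear, which only records its constructor arguments; it is not the library class and performs no quantization.

```python
from functools import partial


class DemoLinear:
    """Hypothetical stand-in that mirrors the constructor flags shown above."""

    def __init__(self, in_features, out_features, bias=True,
                 vector_wise_quantization=False, mem_efficient=False):
        self.in_features = in_features
        self.out_features = out_features
        self.bias = bias
        self.vector_wise_quantization = vector_wise_quantization
        self.mem_efficient = mem_efficient


# Each alias presets some keyword arguments; the rest are supplied at call time.
DemoVectorwise = partial(DemoLinear, vector_wise_quantization=True)
DemoGlobalMemEfficient = partial(
    DemoLinear, vector_wise_quantization=False, mem_efficient=True
)

layer = DemoVectorwise(768, 3072)
mem_layer = DemoGlobalMemEfficient(768, 3072, bias=False)

# A partial object is a callable, not a class, so type checks
# must target the underlying class rather than the alias.
is_demo_linear = isinstance(layer, DemoLinear)
```

One consequence of this design is that an alias cannot be subclassed or passed to isinstance; code that inspects module types should check against the underlying class.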

Import

from bitsandbytes.nn import SwitchBackLinear
# Or specific variants:
from bitsandbytes.nn.triton_based_modules import (
    SwitchBackLinearGlobal,
    SwitchBackLinearGlobalMemEfficient,
    SwitchBackLinearVectorwise,
)

I/O Contract

Inputs

Name | Type | Required | Description
in_features | int | Yes | Input feature dimension
out_features | int | Yes | Output feature dimension
bias | bool | No | Whether to include bias (default True)
vector_wise_quantization | bool | No | Use rowwise quantization for weights (default False = global)
mem_efficient | bool | No | Save quantized activations to reduce memory (default False)
x | torch.Tensor | Yes | Input tensor of shape (*, in_features)

Outputs

Name | Type | Description
output | torch.Tensor | Linear transformation result of shape (*, out_features)
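
The shape contract can be checked before a tensor ever reaches the layer. The helper below, linear_io_shapes, is a hypothetical name introduced for this sketch. It assumes the leading dimensions are flattened into rows before rowwise quantization, which would give one scale per flattened row; that flattening is an assumption about the implementation, not something this page states.

```python
from math import prod


def linear_io_shapes(input_shape, in_features, out_features):
    """Return (flattened_row_count, output_shape) for a (*, in_features) input."""
    if not input_shape or input_shape[-1] != in_features:
        raise ValueError(
            f"expected last dimension {in_features}, got shape {tuple(input_shape)}"
        )
    leading = tuple(input_shape[:-1])
    rows = prod(leading)  # prod of an empty tuple is 1
    return rows, leading + (out_features,)


rows, out_shape = linear_io_shapes((32, 128, 768), 768, 3072)
```

For the (32, 128, 768) input used in the examples on this page, this yields 4096 flattened rows and an output shape of (32, 128, 3072).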

Usage Examples

Basic SwitchBack Linear

import torch
from bitsandbytes.nn import SwitchBackLinear

# Replace nn.Linear with SwitchBackLinear; cast to fp16 to match the input dtype
layer = SwitchBackLinear(768, 3072, bias=True).cuda().half()

x = torch.randn(32, 128, 768, device="cuda", dtype=torch.float16)
output = layer(x)  # INT8 quantized matmul via Triton
output.sum().backward()

Pre-quantized Inference

# Pre-quantize weights for faster inference
# (reuses `layer` and `x` from the previous example)
layer.eval()
layer.prepare_for_eval()

# Now forward uses pre-computed INT8 weights
with torch.no_grad():
    output = layer(x)
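
The benefit of pre-quantizing is that the weight quantization runs once instead of on every forward call. The following pure-Python sketch illustrates that caching pattern with a hypothetical class, TinyInt8Linear, assuming global absmax quantization for weights and rowwise quantization for activations; it is not the library's implementation.

```python
class TinyInt8Linear:
    """Hypothetical stand-in showing a prepare_for_eval()-style weight cache."""

    def __init__(self, weight):
        self.weight = weight      # full-precision weights, shape (out, in)
        self.w_int8 = None        # cached quantized weights
        self.w_scale = None       # cached global scale
        self.quantize_calls = 0   # counts weight quantizations, for illustration

    def _quantize_weight(self):
        self.quantize_calls += 1
        absmax = max(abs(v) for row in self.weight for v in row) or 1.0
        q = [[round(127 * v / absmax) for v in row] for row in self.weight]
        return q, absmax

    def prepare_for_eval(self):
        # Quantize once and keep the result for all later forward calls.
        self.w_int8, self.w_scale = self._quantize_weight()

    def forward(self, x):
        if self.w_int8 is None:
            w_int8, w_scale = self._quantize_weight()  # re-quantize every call
        else:
            w_int8, w_scale = self.w_int8, self.w_scale  # reuse the cache
        out = []
        for row in x:
            x_absmax = max(abs(v) for v in row) or 1.0
            x_int8 = [round(127 * v / x_absmax) for v in row]
            out.append([
                sum(a * b for a, b in zip(x_int8, w_row))
                * x_absmax * w_scale / (127 * 127)
                for w_row in w_int8
            ])
        return out


demo = TinyInt8Linear([[1.0, 0.0], [0.0, 1.0]])
inputs = [[1.0, -2.0]]
before = demo.forward(inputs)
demo.forward(inputs)
calls_without_cache = demo.quantize_calls

demo.prepare_for_eval()
after = demo.forward(inputs)
demo.forward(inputs)
calls_with_cache = demo.quantize_calls - calls_without_cache
```

In this sketch the two uncached forward calls each quantize the weights, while after prepare_for_eval() the weights are quantized only once regardless of how many forward calls follow, and the outputs are unchanged.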

Related Pages
