Implementation:Bitsandbytes foundation Bitsandbytes Quantize Global
| Knowledge Sources | |
|---|---|
| Domains | Quantization, Triton, INT8 |
| Last Updated | 2026-02-07 13:31 GMT |
Overview
Triton kernels for global INT8 quantization of tensors using a single absolute maximum scaling factor, with an optional fused transpose variant.
Description
This module provides two Triton-based global quantization operations:
- quantize_global computes the absolute maximum (absmax) of the entire tensor, then scales every element into [-127, 127] in a single kernel pass, producing an INT8 tensor and the absmax scalar.
- quantize_global_transpose performs the same global quantization but writes the output in transposed layout using a tiled, grouped kernel.
Both kernels are autotuned across block sizes and warp configurations.
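The arithmetic described above can be sketched in plain Python, so it can be checked without a GPU or Triton. This is a minimal reference under stated assumptions, not the library's kernel: the helper names `quantize_global_ref` and `quantize_global_transpose_ref` are hypothetical, the input is a list of rows of floats, and the handling of an all-zero input is an assumption made only to avoid dividing by zero.

```python
def quantize_global_ref(rows):
    """Reference global absmax INT8 quantization for a list of rows of floats.

    Hypothetical helper for illustration; not part of bitsandbytes.
    """
    # One scale for the whole tensor: the largest magnitude of any element.
    absmax = max(abs(v) for row in rows for v in row)
    if absmax == 0.0:
        # Assumption: an all-zero input maps to all zeros.
        return [[0 for _ in row] for row in rows], absmax
    # Map [-absmax, absmax] onto [-127, 127] and round to the nearest integer.
    quantized = [[round(127.0 * v / absmax) for v in row] for row in rows]
    return quantized, absmax


def quantize_global_transpose_ref(rows):
    """Same quantization, with the result returned in transposed layout."""
    quantized, absmax = quantize_global_ref(rows)
    # zip(*...) turns rows into columns; a fused kernel avoids this extra pass.
    return [list(col) for col in zip(*quantized)], absmax


x = [[0.5, -2.0, 0.75],
     [0.25, 0.0, -1.5]]
q, absmax = quantize_global_ref(x)
q_t, absmax_t = quantize_global_transpose_ref(x)

print(q)       # [[32, -127, 48], [16, 0, -95]]
print(q_t)     # [[32, 16], [-127, 0], [48, -95]]
print(absmax)  # 2.0
```

The element with the largest magnitude always maps to -127 or 127, and every other element is scaled relative to it, which is why one outlier reduces the resolution available to the rest of the tensor.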
Usage
Used by the SwitchBack global quantization linear layer for quantizing weight matrices. The transpose variant is used in backward passes where the weight matrix needs to be quantized and transposed simultaneously.
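To illustrate why a single absmax per matrix is sufficient for a linear layer, the sketch below multiplies two globally quantized matrices as integers and rescales the result once. It is a simplified illustration, not the SwitchBack implementation: it assumes both the activations and the weights use one global scale (the real layer may quantize activations differently), and `quantize_global_ref` and `int8_linear_ref` are hypothetical helper names.

```python
def quantize_global_ref(rows):
    """Hypothetical reference quantizer: one absmax scale for the whole matrix."""
    absmax = max(abs(v) for row in rows for v in row)
    if absmax == 0.0:
        return [[0 for _ in row] for row in rows], absmax
    return [[round(127.0 * v / absmax) for v in row] for row in rows], absmax


def int8_linear_ref(x, w):
    """Approximate y = x @ w.T using integer products and a single rescale."""
    xq, sx = quantize_global_ref(x)
    wq, sw = quantize_global_ref(w)
    # Each quantized value approximates 127 * v / absmax, so a product of two
    # quantized values carries a factor of 127 * 127 / (sx * sw) to undo.
    rescale = (sx * sw) / (127.0 * 127.0)
    return [[rescale * sum(a * b for a, b in zip(x_row, w_row)) for w_row in wq]
            for x_row in xq]


x = [[1.0, -0.5], [0.25, 2.0]]    # activations, shape (2, 2)
w = [[0.5, 1.0], [-1.0, 0.75]]    # weights, shape (out_features, in_features)

y_int8 = int8_linear_ref(x, w)
y_exact = [[sum(a * b for a, b in zip(x_row, w_row)) for w_row in w]
           for x_row in x]
max_err = max(abs(a - b)
              for r_q, r_e in zip(y_int8, y_exact)
              for a, b in zip(r_q, r_e))
print(max_err < 0.05)  # True
```

Because the scale is a scalar, it factors out of the integer sum and is applied once per output element rather than once per product.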
Code Reference
Source Location
- Repository: bitsandbytes
- File: bitsandbytes/triton/quantize_global.py
- Lines: 1-124
Signature
def quantize_global(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Global INT8 quantization: single scaling factor for entire tensor.
    Returns: (output_int8, absmax)
    """

def quantize_global_transpose(input: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Global INT8 quantization with fused transpose.
    Returns: (output_int8_transposed, absmax)
    """
Import
from bitsandbytes.triton.quantize_global import quantize_global, quantize_global_transpose
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| x (quantize_global) / input (quantize_global_transpose) | torch.Tensor | Yes | Input tensor on CUDA; any shape for quantize_global, 2D for quantize_global_transpose |
Outputs
| Name | Type | Description |
|---|---|---|
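| output | torch.Tensor (int8) | Quantized tensor; same shape as the input for quantize_global, shape (N, M) for an (M, N) input to quantize_global_transpose |
| absmax | torch.Tensor (floating point) | Single-element tensor holding the absolute maximum of the input, used as the global scale |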
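The two outputs together are enough to recover an approximation of the input: each INT8 value is multiplied by absmax / 127. The plain-Python sketch below shows this round trip on a small matrix. The helper name `dequantize_global_ref` is hypothetical and the quantization step is written inline, so nothing here calls the library itself.

```python
def dequantize_global_ref(quantized, absmax):
    """Hypothetical helper: map INT8 values back to floats with the global scale."""
    return [[q * absmax / 127.0 for q in row] for row in quantized]


x = [[0.5, -2.0, 0.75],
     [0.25, 0.0, -1.5]]

# Inline reference quantization: one absmax for the whole matrix.
absmax = max(abs(v) for row in x for v in row)
x_int8 = [[round(127.0 * v / absmax) for v in row] for row in x]

x_hat = dequantize_global_ref(x_int8, absmax)

# Rounding moves each value by at most half a quantization step,
# and one step is absmax / 127, so the error is at most absmax / 254.
worst = max(abs(a - b)
            for row_a, row_b in zip(x, x_hat)
            for a, b in zip(row_a, row_b))
bound = absmax / 254.0
print(worst <= bound)  # True
```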
Usage Examples
Global Quantization
import torch
from bitsandbytes.triton.quantize_global import quantize_global, quantize_global_transpose

W = torch.randn(768, 3072, device="cuda", dtype=torch.float16)

# Global quantize (preserves shape)
W_int8, absmax = quantize_global(W)
# W_int8.shape = (768, 3072), dtype=int8

# Global quantize + transpose
W_int8_t, absmax_t = quantize_global_transpose(W)
# W_int8_t.shape = (3072, 768), dtype=int8
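The transposed output shape comes from writing each input tile into the mirrored position of the output. A minimal plain-Python sketch of such a tiled write pattern follows. It is an illustration only: `transpose_tiled_ref`, `block_m`, and `block_n` are hypothetical names, tiles are visited in plain row-major order, and the grouped tile ordering and autotuned block sizes of the real kernel are not reproduced.

```python
def transpose_tiled_ref(rows, block_m=2, block_n=2):
    """Hypothetical helper: transpose an M x N matrix one tile at a time."""
    m, n = len(rows), len(rows[0])
    out = [[0] * m for _ in range(n)]  # transposed shape: N x M
    for i0 in range(0, m, block_m):          # tile row start
        for j0 in range(0, n, block_n):      # tile column start
            # min(...) plays the role of a boundary mask for partial tiles.
            for i in range(i0, min(i0 + block_m, m)):
                for j in range(j0, min(j0 + block_n, n)):
                    out[j][i] = rows[i][j]   # element (i, j) lands at (j, i)
    return out


# A 3 x 5 matrix, chosen so that neither dimension is a multiple of the tile size.
a = [[r * 5 + c for c in range(5)] for r in range(3)]
a_t = transpose_tiled_ref(a)
print(len(a_t), len(a_t[0]))  # 5 3
```

In a GPU kernel each tile would be handled by a separate program instance, with the quantization applied to the tile's values between the load and the transposed store.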