Implementation:Bitsandbytes foundation Bitsandbytes Quantize Columnwise Transpose
| Knowledge Sources | |
|---|---|
| Domains | Quantization, Triton, INT8 |
| Last Updated | 2026-02-07 13:31 GMT |
Overview
Triton kernel that performs fused columnwise INT8 quantization and matrix transpose in a single GPU pass.
Description
This module provides a Triton JIT-compiled kernel that fuses two operations into one GPU pass: (1) quantizing each column of a matrix to INT8 by computing the per-column absolute maximum and scaling values to [-127, 127], and (2) writing the result in transposed layout. Fusing the two steps avoids materializing an intermediate quantized matrix before transposing, which saves memory bandwidth. The kernel is autotuned over the number of pipeline stages and warps. It is used by the SwitchBack linear layer in the backward pass, where the weight matrix is quantized and transposed ahead of the INT8 matmul that computes the input gradient.
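To make the arithmetic concrete, the following is a minimal pure-Python sketch of what the fused operation computes, written for a small nested-list matrix rather than a CUDA tensor. The function name `quantize_columnwise_and_transpose_ref` is hypothetical and not part of bitsandbytes. Rounding to the nearest integer and the guard for all-zero columns are assumptions of this sketch, not statements about the Triton kernel.

```python
def quantize_columnwise_and_transpose_ref(x):
    """Reference sketch: x is an M x N list of lists of floats."""
    m, n = len(x), len(x[0])
    # One scale per column: the per-column absolute maximum.
    col_maxs = [max(abs(x[i][j]) for i in range(m)) for j in range(n)]
    out = []
    for j in range(n):
        # Assumption of this sketch: avoid dividing by zero on all-zero columns.
        scale = col_maxs[j] if col_maxs[j] != 0 else 1.0
        # Row j of the output holds quantized column j of the input,
        # so the transpose falls out of the write order.
        out.append([round(127.0 * x[i][j] / scale) for i in range(m)])
    return out, col_maxs


x = [[1.0, -4.0, 0.5],
     [-3.0, 1.0, 0.1]]
q_t, maxs = quantize_columnwise_and_transpose_ref(x)
# q_t  -> [[42, -127], [-127, 32], [127, 25]]   (shape N x M = 3 x 2)
# maxs -> [3.0, 4.0, 0.5]
```

Each output row uses its own scale, so the largest-magnitude entry of every input column maps to exactly 127 or -127.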
Usage
Used internally by the SwitchBack vector-wise backward pass to quantize and transpose the weight matrix in a single kernel call.
Code Reference
Source Location
- Repository: bitsandbytes
- File: bitsandbytes/triton/quantize_columnwise_and_transpose.py
- Lines: 1-75
Signature
def quantize_columnwise_and_transpose(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Quantize each column of x to INT8 and transpose.

    Args:
        x: Input tensor of shape (M, N), float16/float32

    Returns:
        output: INT8 tensor of shape (N, M) (transposed)
        output_maxs: Per-column absolute max of shape (N,), float16
    """
Import
from bitsandbytes.triton.quantize_columnwise_and_transpose import quantize_columnwise_and_transpose
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| x | torch.Tensor | Yes | Input matrix of shape (M, N) on CUDA |
Outputs
| Name | Type | Description |
|---|---|---|
| output | torch.Tensor (int8) | Quantized and transposed matrix of shape (N, M) |
| output_maxs | torch.Tensor (float16) | Per-column absolute maximum values of shape (N,) |
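Because `output` is transposed and `output_maxs` holds one scale per original column, an approximate reconstruction of the input multiplies row `j` of `output` by `output_maxs[j] / 127`. The sketch below illustrates this in plain Python on nested lists. `dequantize_transposed_ref` is a hypothetical helper for illustration only and is not a bitsandbytes function.

```python
def dequantize_transposed_ref(q_t, col_maxs):
    """Approximate the original M x N matrix from the (N, M) quantized values."""
    n, m = len(q_t), len(q_t[0])
    # Element (i, j) of the original lives at (j, i) in the transposed output,
    # and is rescaled by the absolute maximum of original column j.
    return [[q_t[j][i] * col_maxs[j] / 127.0 for j in range(n)] for i in range(m)]


# Quantized values and scales for the 2 x 3 input
# [[1.0, -4.0, 0.5], [-3.0, 1.0, 0.1]].
q_t = [[42, -127], [-127, 32], [127, 25]]
col_maxs = [3.0, 4.0, 0.5]
x_approx = dequantize_transposed_ref(q_t, col_maxs)
```

With round-to-nearest quantization, each reconstructed entry differs from the original by at most half a quantization step, that is `col_maxs[j] / 254` for column `j`.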
Usage Examples
Fused Columnwise Quantize + Transpose
import torch
from bitsandbytes.triton.quantize_columnwise_and_transpose import quantize_columnwise_and_transpose

# Requires a CUDA device with Triton installed.
W = torch.randn(768, 3072, device="cuda", dtype=torch.float16)

# Quantize each column of W to INT8 and transpose in one kernel call.
W_int8_transposed, col_maxs = quantize_columnwise_and_transpose(W)
# W_int8_transposed.shape = (3072, 768), dtype=int8
# col_maxs.shape = (3072,), dtype=float16