Implementation:Bitsandbytes foundation Bitsandbytes Quantize Columnwise Transpose

From Leeroopedia


Knowledge Sources
Domains: Quantization, Triton, INT8
Last Updated: 2026-02-07 13:31 GMT

Overview

Triton kernel that performs fused columnwise INT8 quantization and matrix transpose in a single GPU pass.

Description

This module provides a Triton JIT-compiled kernel that combines two operations in one GPU pass: (1) quantizing each column of a matrix to INT8 by computing the column's absolute maximum and scaling its values into the range [-127, 127], and (2) transposing the result. Fusing the two steps avoids materializing an intermediate quantized matrix that would otherwise have to be read back and transposed, which saves memory bandwidth. The kernel is autotuned over the number of pipeline stages and warps. It is used in the backward pass of the SwitchBack linear layer, where the weight matrix is quantized and transposed to compute the gradient with respect to the layer input.
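To make the arithmetic concrete without Triton, the NumPy sketch below mirrors the contract described above: per-column absolute maximum, scaling by 127 / absmax with round-to-nearest, and a transposed (N, M) INT8 result. The function name `reference_quantize_columnwise_and_transpose` is hypothetical and not part of bitsandbytes. It assumes round-half-to-even (`np.rint`), which may differ from the GPU kernel at exact ties, and its handling of all-zero columns is an assumption of this sketch rather than documented kernel behavior.

```python
import numpy as np


def reference_quantize_columnwise_and_transpose(x):
    """Hypothetical NumPy reference for the fused kernel's contract.

    x: array of shape (M, N). Returns (q_t, col_maxs):
      q_t      int8 array of shape (N, M): quantized columns of x, stored as rows
      col_maxs float16 array of shape (N,): per-column absolute maximum of x
    """
    x = np.asarray(x, dtype=np.float32)
    # Per-column absolute maximum, shape (N,).
    col_maxs = np.abs(x).max(axis=0)
    # Assumption of this sketch: all-zero columns use a scale of 1 so they
    # quantize to 0 instead of dividing 0 by 0.
    safe_maxs = np.where(col_maxs == 0, 1.0, col_maxs)
    # Broadcasting divides every column by its own maximum; the largest |x| maps to 127.
    # np.rint rounds half to even, which may differ from the GPU kernel at exact ties.
    q = np.rint(127.0 * x / safe_maxs).astype(np.int8)
    # Transpose so that row j of the result holds quantized column j of x.
    # The fused kernel does this in its store step; here it is a separate copy.
    q_t = np.ascontiguousarray(q.T)
    return q_t, col_maxs.astype(np.float16)


# Usage on a small random matrix (M=4, N=3).
rng = np.random.default_rng(0)
W = rng.standard_normal((4, 3)).astype(np.float16)
W_q_t, W_col_maxs = reference_quantize_columnwise_and_transpose(W)
print(W_q_t.shape, W_q_t.dtype, W_col_maxs.shape)  # (3, 4) int8 (3,)
```

The separate transpose copy is exactly the extra memory traffic the fused kernel is described as avoiding.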

Usage

Used internally by the SwitchBack vector-wise backward pass to quantize and transpose the weight matrix in a single kernel call.
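How the outputs can be consumed is sketched below. Assuming the `nn.Linear` convention Y = X Wᵀ with W of shape (out_features, in_features), the gradient with respect to the input is G W. Quantizing W column by column gives one scale per input feature, which lines up with the columns of that product. This is a hypothetical NumPy illustration: the helper names are invented for this sketch, the plain integer matmul stands in for bitsandbytes' Triton INT8 matmul with dequantization, and row-wise quantization of the incoming gradient G is an assumption about how the other operand is treated.

```python
import numpy as np


def rowwise_quantize(a):
    """Hypothetical helper: int8 codes with one absmax scale per row of a (B, K) matrix."""
    maxs = np.abs(a).max(axis=1)
    safe = np.where(maxs == 0, 1.0, maxs)
    return np.rint(127.0 * a / safe[:, None]).astype(np.int8), maxs


def columnwise_quantize_transpose(w):
    """Hypothetical helper: int8 codes with one absmax scale per column of w (K, N), returned as (N, K)."""
    maxs = np.abs(w).max(axis=0)
    safe = np.where(maxs == 0, 1.0, maxs)
    return np.ascontiguousarray(np.rint(127.0 * w / safe).astype(np.int8).T), maxs


def int8_grad_input(grad_out, weight):
    """Approximate grad_X = grad_out @ weight for Y = X @ weight.T using int8 operands.

    grad_out: (B, out_features); weight: (out_features, in_features).
    """
    g_q, g_maxs = rowwise_quantize(grad_out)               # (B, out), (B,)
    w_q_t, w_maxs = columnwise_quantize_transpose(weight)  # (in, out), (in,)
    # Integer product with int32 accumulation; w_q_t.T restores the (out, in) orientation.
    acc = g_q.astype(np.int32) @ w_q_t.T.astype(np.int32)  # (B, in)
    # Undo both scalings: one factor per row of grad_out, one per column of weight.
    return acc * (g_maxs[:, None] / 127.0) * (w_maxs[None, :] / 127.0)


rng = np.random.default_rng(0)
G = rng.standard_normal((8, 16))  # (batch, out_features)
W = rng.standard_normal((16, 32))  # (out_features, in_features)
approx = int8_grad_input(G, W)
exact = G @ W
rel_err = np.linalg.norm(approx - exact) / np.linalg.norm(exact)
```

Because each scale is tied to either a row of G or a column of W, dequantization reduces to multiplying the integer result by an outer product of the two scale vectors.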

Code Reference

Source Location

Signature

def quantize_columnwise_and_transpose(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Quantize each column of x to INT8 and transpose.

    Args:
        x: Input tensor of shape (M, N), float16/float32
    Returns:
        output: INT8 tensor of shape (N, M) (transposed)
        output_maxs: Per-column absolute max of shape (N,), float16
    """

Import

from bitsandbytes.triton.quantize_columnwise_and_transpose import quantize_columnwise_and_transpose

I/O Contract

Inputs

Name | Type | Required | Description
--- | --- | --- | ---
x | torch.Tensor | Yes | Input matrix of shape (M, N) on CUDA

Outputs

Name | Type | Description
--- | --- | ---
output | torch.Tensor (int8) | Quantized and transposed matrix of shape (N, M)
output_maxs | torch.Tensor (float16) | Per-column absolute maximum values of shape (N,)
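The two outputs are enough to recover an approximation of the input: x[i, j] ≈ output[j, i] · output_maxs[j] / 127. With round-to-nearest, each reconstructed element should lie within half a quantization step, output_maxs[j] / 254, of the original, ignoring the float16 rounding of `output_maxs` itself. The NumPy helper below is a hypothetical illustration of that formula, not a bitsandbytes function, and the int8 codes in the usage lines were computed by hand from the round(127 · x / absmax) rule rather than taken from the kernel.

```python
import numpy as np


def dequantize_columnwise_transposed(q_t, col_maxs):
    """Hypothetical inverse of the documented contract.

    q_t: int8 array (N, M); col_maxs: array (N,). Returns a float32 array of shape (M, N).
    """
    # Row j of q_t holds column j of the original matrix, scaled so that its absmax maps to 127.
    scale = col_maxs.astype(np.float32)[:, None] / np.float32(127.0)  # (N, 1)
    x_t = q_t.astype(np.float32) * scale                              # (N, M)
    return x_t.T                                                      # back to (M, N)


# Codes computed by hand with round(127 * x / absmax) for
# x = [[1, -3], [0.25, 4], [-1, 1]] (column absmax values 1 and 4).
x = np.array([[1.0, -3.0], [0.25, 4.0], [-1.0, 1.0]], dtype=np.float32)
q_t = np.array([[127, 32, -127], [-95, 127, 32]], dtype=np.int8)
col_maxs = np.array([1.0, 4.0], dtype=np.float16)
x_hat = dequantize_columnwise_transposed(q_t, col_maxs)
```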

Usage Examples

Fused Columnwise Quantize + Transpose

import torch
from bitsandbytes.triton.quantize_columnwise_and_transpose import quantize_columnwise_and_transpose

W = torch.randn(768, 3072, device="cuda", dtype=torch.float16)
W_int8_transposed, col_maxs = quantize_columnwise_and_transpose(W)
# W_int8_transposed.shape = (3072, 768), dtype=int8
# col_maxs.shape = (3072,), dtype=float16
