Implementation:Predibase Lorax EETQ Linear
| Knowledge Sources | |
|---|---|
| Domains | Quantization, Inference |
| Last Updated | 2026-02-08 00:00 GMT |
Overview
Implements a W8A16 (8-bit weight, 16-bit activation) quantized linear layer using the EETQ library for efficient int8 weight-only quantization with high-performance GEMM.
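As a rough illustration of what weight-only int8 storage saves, the arithmetic below compares the weight footprint of a float16 layer with an int8 layer plus per-channel scales. The 4096 x 4096 layer size and the 2-byte scale width are assumptions made for the example, not values taken from the EETQ library.

```python
def linear_weight_bytes(out_features, in_features, weight_bits, scale_bytes=0):
    """Bytes needed for a linear layer's weight plus optional per-channel scales."""
    weight_bytes = out_features * in_features * weight_bits // 8
    return weight_bytes + out_features * scale_bytes


# Hypothetical 4096 x 4096 projection layer stored in float16.
fp16_bytes = linear_weight_bytes(4096, 4096, weight_bits=16)

# Same layer in int8, assuming one 2-byte (float16) scale per output channel.
int8_bytes = linear_weight_bytes(4096, 4096, weight_bits=8, scale_bytes=2)

# Fraction of the float16 footprint that the int8 layer occupies.
ratio = int8_bytes / fp16_bytes
```

Under these assumptions the scales add only a few kilobytes, so the int8 layer needs roughly half the weight memory of the float16 one. Activations stay in 16-bit and are unaffected.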
Description
This module provides a single class that wraps the EETQ (Easy and Efficient Quantization for Transformers) library:
EETQLinear: An nn.Module that quantizes weights to int8 at initialization time and performs W8A16 GEMM during inference. During __init__:
- Converts the weight to float16 if not already.
- Transposes the weight, makes it contiguous, and moves it to the CPU.
- Calls quant_weights from the EETQ library to produce an int8 weight tensor and a per-channel float scale tensor.
- Moves both the quantized weight and scale back to CUDA.
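The quantization step above can be sketched in plain Python. This is not the EETQ implementation: it assumes a common symmetric scheme (scale = max |w| / 127 per output channel), stores one row per output channel rather than the transposed layout described above, and omits any kernel-specific weight layout. The function name `quantize_per_channel` is hypothetical.

```python
def quantize_per_channel(weight):
    """Symmetric int8 quantization with one scale per output channel.

    `weight` is a list of rows, one row per output channel.
    Returns (q_rows, scales) with weight[i][j] ~= q_rows[i][j] * scales[i].
    """
    q_rows, scales = [], []
    for row in weight:
        max_abs = max(abs(v) for v in row)
        # Guard against an all-zero channel, which would divide by zero.
        scale = max_abs / 127.0 if max_abs > 0 else 1.0
        # Round to the nearest integer and clamp to the symmetric int8 range.
        q_rows.append([max(-127, min(127, round(v / scale))) for v in row])
        scales.append(scale)
    return q_rows, scales


# Tiny illustrative weight: 2 output channels, 3 input features.
weight = [[0.5, -1.0, 0.25], [2.0, 0.0, -0.5]]
q_weight, scales = quantize_per_channel(weight)
```

With this scheme the largest-magnitude entry in each channel maps to +127 or -127, and every reconstructed value `q * scale` differs from the original by at most half a scale step.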
The forward method calls w8_a16_gemm, which multiplies the float16 input by the int8 weight, using the scale tensor for dequantization, and then adds the bias if one was provided.
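The forward computation can be written as a slow reference in plain Python to show where the scale enters. This sketch assumes per-output-channel scales and one weight row per output channel. It does not reproduce how the real w8_a16_gemm kernel orders or fuses these operations, and the function name `w8a16_linear` is hypothetical.

```python
def w8a16_linear(x, q_weight, scales, bias=None):
    """Reference for y = x @ dequant(W).T + bias with int8 weights.

    x:        list of input rows (activations kept in floating point)
    q_weight: int8 values, one row per output channel
    scales:   one floating-point scale per output channel
    bias:     optional list with one value per output channel
    """
    out = []
    for x_row in x:
        y_row = []
        for n, (q_row, scale) in enumerate(zip(q_weight, scales)):
            # Accumulate against the raw int8 values, then apply the channel
            # scale once; this equals dequantizing every weight first.
            acc = sum(a * q for a, q in zip(x_row, q_row))
            y = acc * scale
            if bias is not None:
                y += bias[n]
            y_row.append(y)
        out.append(y_row)
    return out


# Illustrative values: 2 output channels, 3 input features, batch of 1.
q_weight = [[64, -127, 32], [127, 0, -32]]
scales = [1.0 / 127, 2.0 / 127]
x = [[1.0, 2.0, 3.0]]
y = w8a16_linear(x, q_weight, scales, bias=[0.5, -0.5])
y_no_bias = w8a16_linear(x, q_weight, scales)
```

Because the scale is constant within a channel, it can be factored out of the inner sum, which is why a weight-only scheme needs just one multiply per output element for dequantization.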
Usage
This layer is instantiated by the get_linear factory when quantize == "eetq". It is recommended as a drop-in replacement for bitsandbytes 8-bit quantization, over which it offers significantly better inference performance.
Code Reference
Source Location
- Repository: Predibase_Lorax
- File: server/lorax_server/layers/eetq.py
- Lines: 1-25
Signature
class EETQLinear(torch.nn.Module):
    def __init__(self, weight, bias) -> None:
Import
from lorax_server.layers.eetq import EETQLinear
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| weight | torch.Tensor | Yes | Full-precision weight tensor to be quantized to int8 |
| bias | torch.Tensor or None | Yes (may be None) | Bias tensor, or None for no bias; the argument has no default and must be passed |
Outputs
| Name | Type | Description |
|---|---|---|
| output | torch.Tensor | Result of the W8A16 GEMM, with optional bias added |
Usage Examples
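# Used internally via the linear factory; direct construction needs a CUDA device and the EETQ package
import torch
from lorax_server.layers.eetq import EETQLinear
weight = torch.randn(4096, 4096, dtype=torch.float16, device="cuda")  # illustrative (out_features, in_features) weight
layer = EETQLinear(weight, bias=None)
input_tensor = torch.randn(1, 4096, dtype=torch.float16, device="cuda")  # float16 activations
output = layer(input_tensor)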