Implementation:Predibase Lorax FP8 Linear
| Knowledge Sources | |
|---|---|
| Domains | Quantization, Inference |
| Last Updated | 2026-02-08 00:00 GMT |
Overview
Implements an FP8 (8-bit floating point) quantized linear layer using vLLM's CUTLASS-based scaled GEMM kernels for high-throughput inference with dynamic per-token input quantization.
Description
This module, adapted from vLLM, provides FP8 linear layer support through two components:
apply_fp8_linear: A standalone function that performs the full FP8 GEMM pipeline:
- Quantizes the input tensor to FP8 using ops.scaled_fp8_quant with dynamic per-token scaling (or a static input scale if provided). An optional input_scale_ub upper bound can constrain the scale.
- Performs the scaled matrix multiplication via ops.cutlass_scaled_mm, which multiplies the FP8 input by the FP8 weight, applies separate scale factors for the input (scale_a) and the weight (scale_b), returns the result in the input's original dtype, and optionally adds the bias.
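The arithmetic behind these two steps can be sketched in plain Python. This is an illustrative emulation only, not the CUDA kernels: the helper names (round_to_e4m3, quantize_per_token, scaled_mm) are hypothetical, the E4M3 format (maximum 448, 3 mantissa bits) is assumed, and the weight is assumed to be already laid out as (in_features, out_features).

```python
import math

FP8_E4M3_MAX = 448.0  # largest finite value in the assumed E4M3 format


def round_to_e4m3(x: float) -> float:
    """Round x to the nearest E4M3-representable value, saturating at +/-448."""
    if x == 0.0:
        return 0.0
    sign = -1.0 if x < 0 else 1.0
    a = min(abs(x), FP8_E4M3_MAX)
    _, e = math.frexp(a)       # a = m * 2**e with 0.5 <= m < 1
    exp = max(e - 1, -6)       # -6 is the smallest normal exponent
    step = 2.0 ** (exp - 3)    # 3 mantissa bits give 8 steps per binade
    return sign * round(a / step) * step


def quantize_per_token(x):
    """Dynamic per-token quantization: one scale per row (token) of x."""
    q, scales = [], []
    for row in x:
        amax = max(abs(v) for v in row)
        scale = max(amax / FP8_E4M3_MAX, 1e-12)  # guard against all-zero rows
        scales.append(scale)
        q.append([round_to_e4m3(v / scale) for v in row])
    return q, scales


def scaled_mm(qa, qb, scale_a, scale_b, bias=None):
    """out[i][j] = scale_a[i] * scale_b[j] * sum_k qa[i][k] * qb[k][j] (+ bias[j])."""
    n = len(qb[0])
    out = []
    for i, row in enumerate(qa):
        out_row = []
        for j in range(n):
            acc = sum(row[k] * qb[k][j] for k in range(len(row)))
            val = acc * scale_a[i] * scale_b[j]
            if bias is not None:
                val += bias[j]
            out_row.append(val)
        out.append(out_row)
    return out


# Usage: quantize two tokens, then multiply by an identity "weight" with unit scales.
x = [[1.0, -2.0], [0.5, 0.25]]
qx, sx = quantize_per_token(x)
qw = [[1.0, 0.0], [0.0, 1.0]]
y = scaled_mm(qx, qw, sx, [1.0, 1.0])
```

Each token gets its own scale, so a row with small activations keeps as much FP8 resolution as a row with large ones; the scales are only folded back in after the low-precision product.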
Fp8Linear: An nn.Module that stores pre-quantized FP8 weights. During __init__:
- Transposes the weight matrix for the CUTLASS kernel layout.
- Reshapes weight_scale to (1, -1) and casts to float32 for proper broadcasting.
- Stores the optional bias as qbias and, when one is provided, the input scale cast to float32.
The forward method delegates to apply_fp8_linear. A weight property returns the transposed qweight for compatibility with the layer interface.
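The layout steps in the constructor can be illustrated with nested lists. This is a rough sketch under assumptions: the checkpoint weight is taken to be in the usual (out_features, in_features) layout, and the helpers transpose, as_row, and broadcast_scale are hypothetical stand-ins for the tensor operations, not functions from the module.

```python
def transpose(matrix):
    """Swap rows and columns: (out_features, in_features) -> (in_features, out_features)."""
    return [list(col) for col in zip(*matrix)]


def as_row(scale):
    """Mimic a reshape to (1, -1): a scalar or a flat list becomes a single row of floats."""
    if isinstance(scale, (list, tuple)):
        return [[float(s) for s in scale]]
    return [[float(scale)]]


def broadcast_scale(out, scale_row):
    """Multiply an (M, N) result by a (1, N) per-channel or (1, 1) per-tensor scale row."""
    row = scale_row[0]
    n = len(out[0])
    cols = row * n if len(row) == 1 else row  # repeat a per-tensor scale across columns
    return [[v * cols[j] for j, v in enumerate(r)] for r in out]


weight = [[1, 2, 3], [4, 5, 6]]   # assumed (out_features=2, in_features=3)
qweight = transpose(weight)       # (in_features=3, out_features=2)

per_channel = as_row([0.5, 2.0])  # one scale per output channel
per_tensor = as_row(0.5)          # a single scale for the whole weight

gemm_out = [[1.0, 1.0], [2.0, 2.0]]  # pretend (M=2, N=2) GEMM result
scaled_channel = broadcast_scale(gemm_out, per_channel)
scaled_tensor = broadcast_scale(gemm_out, per_tensor)
```

Keeping the scale as a single row means the same code path handles both per-tensor and per-channel scales, since either shape broadcasts across the token dimension of the output.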
Usage
This layer is instantiated by the get_linear factory when quantize is set to fp8 or fp8-kv. It requires a model that has already been quantized to FP8 and that stores its weight scales alongside the weights.
Code Reference
Source Location
- Repository: Predibase_Lorax
- File: server/lorax_server/layers/fp8.py
- Lines: 1-52
Signature
def apply_fp8_linear(
    input: torch.Tensor,
    qweight: torch.Tensor,
    weight_scale: torch.Tensor,
    input_scale: Optional[torch.Tensor] = None,
    input_scale_ub: Optional[torch.Tensor] = None,
    qbias: Optional[torch.Tensor] = None,
) -> torch.Tensor:

class Fp8Linear(torch.nn.Module):
    def __init__(self, weight, bias, weight_scale, input_scale) -> None:
Import
from lorax_server.layers.fp8 import Fp8Linear, apply_fp8_linear
I/O Contract
Inputs (Fp8Linear.__init__)
| Name | Type | Required | Description |
|---|---|---|---|
| weight | torch.Tensor (FP8) | Yes | Pre-quantized FP8 weight tensor |
| bias | torch.Tensor or None | No | Optional bias tensor |
| weight_scale | torch.Tensor | Yes | Per-channel or per-tensor weight scale factor |
| input_scale | torch.Tensor or None | No | Optional static input scale; if None, dynamic per-token scaling is used |
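How the choice between a static and a dynamic input scale plays out can be sketched in plain Python. The function name choose_input_scales is hypothetical, and the treatment of the upper bound (capping each token's absolute maximum before the scale is derived) is an assumption about the kernel's behavior rather than something this page confirms.

```python
FP8_E4M3_MAX = 448.0  # largest finite value in the assumed E4M3 format


def choose_input_scales(x, input_scale=None, input_scale_ub=None):
    """Return one quantization scale per row (token) of x."""
    if input_scale is not None:
        # Static path: the precomputed scale is reused for every token.
        return [float(input_scale)] * len(x)

    scales = []
    for row in x:
        amax = max(abs(v) for v in row)
        if input_scale_ub is not None:
            # Assumption: the upper bound caps the per-token absolute maximum.
            amax = min(amax, float(input_scale_ub))
        scales.append(max(amax / FP8_E4M3_MAX, 1e-12))  # avoid a zero scale
    return scales


tokens = [[448.0, 1.0], [2.0, -896.0]]
dynamic = choose_input_scales(tokens)
static = choose_input_scales(tokens, input_scale=0.5)
bounded = choose_input_scales(tokens, input_scale_ub=448.0)
```

With these inputs the dynamic path yields a different scale per token, the static path repeats the supplied value, and the bounded path limits the second token's scale so that its outlier saturates instead of stretching the range.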
Outputs
| Name | Type | Description |
|---|---|---|
| output | torch.Tensor | Result of the FP8 GEMM, cast back to the input dtype |
Usage Examples
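# Normally constructed by the get_linear factory; built directly here for illustration.
# Assumes weight (FP8), weight_scale, input_scale (or None), and input_tensor already exist.
from lorax_server.layers.fp8 import Fp8Linear
layer = Fp8Linear(weight, bias=None, weight_scale=weight_scale, input_scale=input_scale)
output = layer(input_tensor)  # returned in the same dtype as input_tensor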