Implementation:Predibase Lorax FP8 Linear


Knowledge Sources
Domains: Quantization, Inference
Last Updated: 2026-02-08 00:00 GMT

Overview

Implements an FP8 (8-bit floating point) quantized linear layer using vLLM's CUTLASS-based scaled GEMM kernels for high-throughput inference with dynamic per-token input quantization.

Description

This module provides FP8 linear layer support adapted from vLLM:

apply_fp8_linear: A standalone function that performs the full FP8 GEMM pipeline:

  1. Quantizes the input tensor to FP8 using ops.scaled_fp8_quant with dynamic per-token scaling (or a static input scale if provided). An optional input_scale_ub upper bound can constrain the scale.
  2. Performs the scaled matrix multiplication via ops.cutlass_scaled_mm, which takes the FP8 quantized input and weight, applies separate scale factors for input (scale_a) and weight (scale_b), casts the output back to the input dtype, and optionally adds bias.

Fp8Linear: An nn.Module that stores pre-quantized FP8 weights. During __init__:

  • Transposes the weight matrix for the CUTLASS kernel layout.
  • Reshapes weight_scale to (1, -1) and casts to float32 for proper broadcasting.
  • Stores the optional bias as qbias and, if an input scale is provided, casts it to float32.

The forward method delegates to apply_fp8_linear. A weight property returns the transposed qweight for compatibility with the layer interface.
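The shape handling in the constructor steps listed above can be sketched as follows. The helper prepare_fp8_params is hypothetical and is not part of lorax_server; the sketch assumes the checkpoint stores the weight as (out_features, in_features) and uses a float32 tensor as a stand-in for real FP8 weights so that only the shapes and dtypes are shown.

```python
import torch


def prepare_fp8_params(weight, weight_scale, input_scale=None):
    """Hypothetical helper mirroring the constructor's preprocessing."""
    # Assumed layout: (out_features, in_features) -> (in_features, out_features).
    qweight = weight.t()
    # Row vector of scales, in float32, so it broadcasts over the token dimension.
    weight_scale = weight_scale.view(1, -1).float()
    if input_scale is not None:
        input_scale = input_scale.float()
    return qweight, weight_scale, input_scale


out_features, in_features, tokens = 8, 16, 4
weight = torch.zeros(out_features, in_features)   # float32 stand-in for FP8 weights
weight_scale = torch.full((out_features,), 0.5, dtype=torch.float16)

qweight, scale, input_scale = prepare_fp8_params(weight, weight_scale)

# A GEMM output of shape (tokens, out_features) times a (1, out_features) scale
# applies one scale per output channel to every token.
gemm_out = torch.ones(tokens, out_features)
scaled = gemm_out * scale
```

Reshaping to (1, -1) is what lets a per-channel scale line up with the output columns; a per-tensor scale with a single element reshapes to (1, 1) and broadcasts the same way.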

Usage

This layer is instantiated by the get_linear factory when quantize is set to fp8 or fp8-kv. It requires a model checkpoint that has already been quantized to FP8 and that stores the corresponding weight scales.

Code Reference

Source Location

  • Repository: Predibase_Lorax
  • File: server/lorax_server/layers/fp8.py
  • Lines: 1-52

Signature

def apply_fp8_linear(
    input: torch.Tensor,
    qweight: torch.Tensor,
    weight_scale: torch.Tensor,
    input_scale: Optional[torch.Tensor] = None,
    input_scale_ub: Optional[torch.Tensor] = None,
    qbias: Optional[torch.Tensor] = None,
) -> torch.Tensor:

class Fp8Linear(torch.nn.Module):
    def __init__(self, weight, bias, weight_scale, input_scale) -> None:

Import

from lorax_server.layers.fp8 import Fp8Linear, apply_fp8_linear

I/O Contract

Inputs (Fp8Linear.__init__)

Name | Type | Required | Description
weight | torch.Tensor (FP8) | Yes | Pre-quantized FP8 weight tensor
bias | torch.Tensor or None | No | Optional bias tensor
weight_scale | torch.Tensor | Yes | Per-channel or per-tensor weight scale factor
input_scale | torch.Tensor or None | No | Optional static input scale; if None, dynamic per-token scaling is used

Outputs

Name | Type | Description
output | torch.Tensor | Result of the FP8 GEMM, cast back to the input dtype

Usage Examples

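# Normally created by the get_linear factory; direct construction is shown for illustration.
# weight, weight_scale, and input_scale are assumed to come from a pre-quantized FP8 checkpoint.
from lorax_server.layers.fp8 import Fp8Linear

layer = Fp8Linear(weight, bias=None, weight_scale=weight_scale, input_scale=input_scale)
output = layer(input_tensor)  # result is cast back to input_tensor.dtype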
