Implementation:Sgl project Sglang Scalar Type Python
| Knowledge Sources | |
|---|---|
| Domains | Quantization, Type System, LLM Inference |
| Last Updated | 2026-02-10 00:00 GMT |
Overview
Python implementation of the ScalarType system for describing a wide range of numeric types, including sub-byte quantized formats. It mirrors the C++ ScalarType class so that the same type descriptor can be passed between Python and the C++ kernels.
Description
scalar_type.py implements a comprehensive type system that can represent a wide range of floating point and integer types, including sub-byte formats that torch.dtype does not natively support. This is essential for quantized model inference where weights may be stored in 2-bit, 3-bit, or 4-bit formats.
NanRepr Enum: Defines how NaN values are represented in a scalar type:
- NONE -- NaNs are not supported (common for finite-only types)
- IEEE_754 -- Standard IEEE 754 NaN encoding (exponent all 1s, mantissa not all 0s)
- EXTD_RANGE_MAX_MIN -- Extended-range encoding where only the pattern with exponent all 1s and mantissa all 1s is NaN; the other all-1s-exponent patterns encode finite values, which extends the range (this is how float8_e4m3fn reaches 448)
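To see how the NaN scheme changes the representable range, here is a standalone sketch. It is not the library's code: `max_finite` is a hypothetical helper, and it assumes the standard exponent bias 2^(e-1) - 1. With the same 4-bit exponent and 3-bit mantissa, IEEE 754 rules top out at 240, while EXTD_RANGE_MAX_MIN reaches 448.

```python
from enum import Enum


class NanRepr(Enum):
    NONE = 0
    IEEE_754 = 1
    EXTD_RANGE_MAX_MIN = 2


def max_finite(exponent: int, mantissa: int, nan_repr: NanRepr) -> float:
    """Largest finite value of a signed minifloat with the given field widths.

    Assumes the standard exponent bias 2**(exponent - 1) - 1.
    """
    exp_bias = (1 << (exponent - 1)) - 1
    max_mant = (1 << mantissa) - 1   # all mantissa bits set
    max_exp = (1 << exponent) - 2    # IEEE 754: all-ones exponent is inf/NaN
    if nan_repr is not NanRepr.IEEE_754:
        max_exp += 1                 # all-ones exponent now holds finite values
    if nan_repr is NanRepr.EXTD_RANGE_MAX_MIN:
        max_mant -= 1                # ...except all-ones mantissa, reserved for NaN
    return 2.0 ** (max_exp - exp_bias) * (1 + max_mant / (1 << mantissa))


print(max_finite(4, 3, NanRepr.IEEE_754))            # 240.0
print(max_finite(4, 3, NanRepr.EXTD_RANGE_MAX_MIN))  # 448.0
print(max_finite(2, 1, NanRepr.NONE))                # 6.0
```

The same function gives 57344.0 for E5M2 and 65504.0 for IEEE float16, which are the familiar maxima of those formats.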
ScalarType Dataclass: A frozen (immutable) dataclass with the following fields:
- exponent (int) -- Number of exponent bits (0 for integer types)
- mantissa (int) -- Number of mantissa bits (or value bits excluding sign for integers)
- signed (bool) -- Whether the type has a sign bit
- bias (int) -- Encoding bias where stored_value = value + bias (e.g., GPTQ 4-bit uses bias=8)
- _finite_values_only (bool) -- Whether infinities are excluded (private, defaults to False; query it with has_infs())
- nan_repr (NanRepr) -- NaN representation scheme
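The bias field is a plain offset between the stored code and the logical value. A minimal sketch for an unsigned biased type such as GPTQ's uint4b8; `encode` and `decode` are hypothetical helpers for illustration, not sgl_kernel functions:

```python
def encode(value: int, size_bits: int, bias: int) -> int:
    """Logical value -> stored unsigned code, using stored = value + bias."""
    stored = value + bias
    if not 0 <= stored < (1 << size_bits):
        raise ValueError(f"{value} does not fit in {size_bits} bits with bias {bias}")
    return stored


def decode(stored: int, bias: int) -> int:
    """Stored unsigned code -> logical value."""
    return stored - bias


# uint4b8: the 16 stored codes 0..15 decode to the values -8..7
print([decode(code, 8) for code in range(16)])
```

Storing the value as an unsigned code lets a 4-bit field represent a symmetric signed range without a sign bit.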
Key properties and methods:
- id -- Encodes the type into a compact int64 by bit-packing all fields (exponent:8, mantissa:8, signed:1, bias:32, finite:1, nan_repr:8). This ID is passed through PyTorch custom ops to the C++ side. Computing the id also records the type in a global _SCALAR_TYPES_ID_MAP, which provides the reverse mapping used by from_id.
- size_bits -- Total bit width: exponent + mantissa, plus 1 if the type is signed
- min() / max() -- Compute the representable value range, accounting for the encoding bias, the NaN representation, and the exponent bias. Floating-point bounds are derived by assembling the equivalent IEEE 754 double bit pattern.
- is_floating_point() / is_integer() / is_signed() / has_bias() / has_infs() / has_nans() / is_ieee_754() -- Type classification predicates.
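The id packing can be illustrated with a standalone pack/unpack pair. This is a sketch, not the library's code: it assumes the fields are packed from the least significant bit upward in the order and widths listed for id (8 + 8 + 1 + 32 + 1 + 8 = 58 bits) and that bias is a signed 32-bit value. `pack_id`, `unpack_id` and `ID_FIELDS` are hypothetical names; the authoritative layout is whatever the C++ side decodes.

```python
# (field, bit width), packed lowest bits first (assumed order, see lead-in)
ID_FIELDS = [
    ("exponent", 8),
    ("mantissa", 8),
    ("signed", 1),
    ("bias", 32),
    ("finite_values_only", 1),
    ("nan_repr", 8),
]


def pack_id(fields: dict) -> int:
    """OR each field into its slot; 58 bits in total, so it fits in an int64."""
    val, offset = 0, 0
    for name, width in ID_FIELDS:
        val |= (int(fields[name]) & ((1 << width) - 1)) << offset
        offset += width
    assert offset <= 64
    return val


def unpack_id(packed: int) -> dict:
    """Inverse of pack_id; bias is sign-extended from 32 bits."""
    fields, offset = {}, 0
    for name, width in ID_FIELDS:
        fields[name] = (packed >> offset) & ((1 << width) - 1)
        offset += width
    if fields["bias"] >= 1 << 31:
        fields["bias"] -= 1 << 32
    return fields


uint4b8 = {"exponent": 0, "mantissa": 4, "signed": 0, "bias": 8,
           "finite_values_only": 0, "nan_repr": 1}
assert unpack_id(pack_id(uint4b8)) == uint4b8
```

Masking each field to its width before shifting keeps a negative bias from spilling into neighbouring fields; the sign extension on unpack restores it.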
Factory Methods:
- ScalarType.int_(size_bits, bias) -- Creates a signed integer type (size_bits includes sign bit)
- ScalarType.uint(size_bits, bias) -- Creates an unsigned integer type
- ScalarType.float_IEEE754(exponent, mantissa) -- Creates a standard IEEE 754 float type
- ScalarType.float_(exponent, mantissa, finite_values_only, nan_repr) -- Creates a non-standard float type
- ScalarType.from_id(scalar_type_id) -- Looks up a previously created type by its packed ID
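The two integer factories differ only in how size_bits is split between the sign bit and the value bits. A hypothetical integer-only stand-in (`IntType`, not part of sgl_kernel) sketches that mapping together with the bias-adjusted range that min() and max() are described as returning:

```python
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class IntType:
    """Hypothetical stripped-down stand-in for ScalarType's integer case."""
    mantissa: int   # value bits, excluding the sign bit
    signed: bool
    bias: int = 0

    @property
    def size_bits(self) -> int:
        return self.mantissa + int(self.signed)  # exponent is 0 for integers

    @classmethod
    def int_(cls, size_bits: int, bias: Optional[int]) -> "IntType":
        # size_bits includes the sign bit, so one bit goes to the sign
        return cls(size_bits - 1, True, bias or 0)

    @classmethod
    def uint(cls, size_bits: int, bias: Optional[int]) -> "IntType":
        return cls(size_bits, False, bias or 0)

    def min(self) -> int:
        raw = -(1 << (self.size_bits - 1)) if self.signed else 0
        return raw - self.bias

    def max(self) -> int:
        return (1 << self.mantissa) - 1 - self.bias


int4 = IntType.int_(4, None)
uint4b8 = IntType.uint(4, 8)
print(int4.size_bits, int4.min(), int4.max())           # 4 -8 7
print(uint4b8.size_bits, uint4b8.min(), uint4b8.max())  # 4 -8 7
```

Both types cover -8..7 in four bits: int4 via two's complement, uint4b8 via the bias.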
Predefined Types (scalar_types class):
| Name | Description | Bits |
|---|---|---|
| int4 | Signed 4-bit integer | 4 |
| uint4 | Unsigned 4-bit integer | 4 |
| int8 | Signed 8-bit integer | 8 |
| uint8 | Unsigned 8-bit integer | 8 |
| float8_e4m3fn | FP8 E4M3 (finite, extended range NaN) | 8 |
| float8_e5m2 | FP8 E5M2 (IEEE 754) | 8 |
| float16_e8m7 | BFloat16 (IEEE 754) | 16 |
| float16_e5m10 | Float16 (IEEE 754) | 16 |
| float6_e3m2f | FP6 from fp6_llm (finite only, no NaN) | 6 |
| float4_e2m1f | FP4 from OCP MX format (finite only, no NaN) | 4 |
| uint2b2 | GPTQ 2-bit with bias 2 | 2 |
| uint3b4 | GPTQ 3-bit with bias 4 | 3 |
| uint4b8 | GPTQ 4-bit with bias 8 | 4 |
| uint8b128 | GPTQ 8-bit with bias 128 | 8 |
| bfloat16 | Alias for float16_e8m7 | 16 |
| float16 | Alias for float16_e5m10 | 16 |
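The string representation follows the ml_dtypes naming convention: float<size>_e<exp>m<mant>[f][n] for floats and [u]int<size>[b<bias>] for integers. A float name without flags follows IEEE 754; f marks finite values only (no infinities) and n marks NaN support with a non-standard encoding. For integers, the b<bias> suffix is omitted when the bias is 0.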
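The naming rules can be sketched as a small standalone function. `type_name` is a hypothetical helper written from the convention described here, not the library's `__str__`; it reproduces the names in the predefined-types table.

```python
from enum import Enum


class NanRepr(Enum):
    NONE = 0
    IEEE_754 = 1
    EXTD_RANGE_MAX_MIN = 2


def type_name(exponent: int, mantissa: int, signed: bool, bias: int = 0,
              finite_values_only: bool = False,
              nan_repr: NanRepr = NanRepr.IEEE_754) -> str:
    size_bits = exponent + mantissa + int(signed)
    if exponent != 0:  # floating point
        name = f"float{size_bits}_e{exponent}m{mantissa}"
        ieee_754 = nan_repr is NanRepr.IEEE_754 and not finite_values_only
        if not ieee_754:
            if finite_values_only:
                name += "f"  # no infinities
            if nan_repr is not NanRepr.NONE:
                name += "n"  # NaNs with a non-standard encoding
        return name
    name = ("int" if signed else "uint") + str(size_bits)
    return name + (f"b{bias}" if bias else "")


print(type_name(4, 3, True, 0, True, NanRepr.EXTD_RANGE_MAX_MIN))  # float8_e4m3fn
print(type_name(0, 4, False, 8))                                    # uint4b8
```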
Usage
Use ScalarType and scalar_types when calling kernel operations that require quantization format specification, such as gptq_marlin_gemm which takes a b_q_type parameter. The id property is used to pass the type through PyTorch custom ops to the C++ kernel layer.
Code Reference
Source Location
- Repository: Sgl_project_Sglang
- File: sgl-kernel/python/sgl_kernel/scalar_type.py
- Lines: 1-352
Signature
```python
import functools
from dataclasses import dataclass
from enum import Enum
from typing import Optional, Union


class NanRepr(Enum):
    NONE = 0
    IEEE_754 = 1
    EXTD_RANGE_MAX_MIN = 2


@dataclass(frozen=True)
class ScalarType:
    exponent: int
    mantissa: int
    signed: bool
    bias: int
    _finite_values_only: bool = False
    nan_repr: NanRepr = NanRepr.IEEE_754

    @functools.cached_property
    def id(self) -> int: ...

    @property
    def size_bits(self) -> int: ...

    def min(self) -> Union[int, float]: ...
    def max(self) -> Union[int, float]: ...
    def is_signed(self) -> bool: ...
    def is_floating_point(self) -> bool: ...
    def is_integer(self) -> bool: ...
    def has_bias(self) -> bool: ...
    def has_infs(self) -> bool: ...
    def has_nans(self) -> bool: ...
    def is_ieee_754(self) -> bool: ...

    @classmethod
    def int_(cls, size_bits: int, bias: Optional[int]) -> "ScalarType": ...

    @classmethod
    def uint(cls, size_bits: int, bias: Optional[int]) -> "ScalarType": ...

    @classmethod
    def float_IEEE754(cls, exponent: int, mantissa: int) -> "ScalarType": ...

    @classmethod
    def float_(cls, exponent: int, mantissa: int, finite_values_only: bool,
               nan_repr: NanRepr) -> "ScalarType": ...

    @classmethod
    def from_id(cls, scalar_type_id: int) -> "ScalarType": ...


class scalar_types:
    int4 = ScalarType.int_(4, None)
    uint4 = ScalarType.uint(4, None)
    int8 = ScalarType.int_(8, None)
    uint8 = ScalarType.uint(8, None)
    float8_e4m3fn = ScalarType.float_(4, 3, True, NanRepr.EXTD_RANGE_MAX_MIN)
    float8_e5m2 = ScalarType.float_IEEE754(5, 2)
    float16_e8m7 = ScalarType.float_IEEE754(8, 7)
    float16_e5m10 = ScalarType.float_IEEE754(5, 10)
    float6_e3m2f = ScalarType.float_(3, 2, True, NanRepr.NONE)  # fp6_llm
    float4_e2m1f = ScalarType.float_(2, 1, True, NanRepr.NONE)  # OCP MX FP4
    # GPTQ types: unsigned integers with a bias
    uint2b2 = ScalarType.uint(2, 2)
    uint3b4 = ScalarType.uint(3, 4)
    uint4b8 = ScalarType.uint(4, 8)
    uint8b128 = ScalarType.uint(8, 128)
    # colloquial aliases
    bfloat16 = float16_e8m7
    float16 = float16_e5m10
```
Import
```python
from sgl_kernel.scalar_type import ScalarType, scalar_types, NanRepr
```
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| exponent | int | Yes | Number of exponent bits (0 for integer types) |
| mantissa | int | Yes | Number of mantissa/value bits (excluding sign) |
| signed | bool | Yes | Whether the type has a sign bit |
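| bias | int | Yes | Encoding bias, stored_value = value + bias (0 for standard types; int_ and uint also accept None for 0) |
| finite_values_only | bool | No (default False) | Exclude infinities; passed to float_ and stored as _finite_values_only |
| nan_repr | NanRepr | No (default IEEE_754) | NaN representation scheme; passed to float_ |
| size_bits | int | Yes (int_, uint) | Total bit width, including the sign bit for int_ |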
| scalar_type_id | int | Yes (from_id) | Packed int64 ID from a previous ScalarType.id call |
Outputs
| Name | Type | Description |
|---|---|---|
| ScalarType instance | ScalarType | Immutable type descriptor |
| id | int | Packed int64 representation for passing through custom ops |
| min/max | Union[int, float] | Representable value range |
| size_bits | int | Total bit width of the type |
Usage Examples
```python
from sgl_kernel.scalar_type import ScalarType, scalar_types, NanRepr
from sgl_kernel import gptq_marlin_gemm

# Using predefined types for GPTQ Marlin GEMM.
# activations, quantized_weights, scales, zeros, workspace, M, N and K are
# placeholders for tensors and sizes prepared by the caller.
result = gptq_marlin_gemm(
    a=activations,
    c=None,
    b_q_weight=quantized_weights,
    b_scales=scales,
    global_scale=None,
    b_zeros=zeros,
    g_idx=None,
    perm=None,
    workspace=workspace,
    b_q_type=scalar_types.uint4b8,  # GPTQ 4-bit with bias 8
    size_m=M,
    size_n=N,
    size_k=K,
)

# Querying type properties
t = scalar_types.float8_e4m3fn
print(t.size_bits)            # 8
print(t.is_floating_point())  # True
print(t.max())                # 448.0
print(t.has_nans())           # True (EXTD_RANGE_MAX_MIN)

# Creating custom types
my_type = ScalarType.uint(6, bias=32)  # 6-bit unsigned with bias 32
print(my_type)     # uint6b32
print(my_type.id)  # packed int64 for C++ ops

# Looking up by ID
recovered = ScalarType.from_id(my_type.id)
assert recovered == my_type
```