Implementation:Vllm project Vllm Scaled MM Epilogues C3X
| Knowledge Sources | |
|---|---|
| Domains | CUTLASS, Epilogue, Quantization, GEMM |
| Last Updated | 2026-02-08 00:00 GMT |
Overview
Defines CUTLASS 3.x epilogues that fuse scaling, bias, and zero-point correction into quantized matrix multiplications on NVIDIA Hopper (SM90a) and later GPUs. They cover symmetric quantization, asymmetric (zero-point) activation quantization, and pointer-array batching for grouped (MoE) GEMMs.
Description
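This header combines the Sm90 epilogue fusion visitors from broadcast_load_epilogue_c3x.hpp and broadcast_load_epilogue_array_c3x.hpp with CUTLASS Sm90EVT compute nodes. As a result, each dequantization formula is evaluated directly in the GEMM epilogue. It provides the following variants:
- TrivialEpilogue: identity passthrough. The accumulator is only converted to ElementD.
- ScaledEpilogue: applies a_scales (per-tensor or per-row) and b_scales (per-tensor or per-column).
- ScaledEpilogueBias: ScaledEpilogue plus a row bias of shape (1,N).
- ScaledEpilogueColumnBias: ScaledEpilogue plus a column bias of shape (M,1), for transposed GEMMs.
- ScaledEpilogueBiasAzp: adds a per-tensor activation zero-point correction and an optional bias.
- ScaledEpilogueBiasAzpToken: adds a per-token activation zero-point correction and an optional bias.
- ScaledEpilogueArray: reads scales through per-group pointer arrays, for batched MoE workloads.
Each class exposes two members. EVTCompute is a public type holding the Sm90EVT tree. prepare_args is a static function that builds the matching EVTCompute::Arguments.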
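The tree composition can be made concrete with a small host-side C++ toy, assuming a simplified per-element visitor model. The types Ctx, AccumLeaf, ScaleALeaf, ScaleBLeaf, BiasLeaf, Multiplies, MultiplyAdd, Node, and run_tree are hypothetical. They do not reproduce CUTLASS's Sm90EVT, Sm90Compute, or Sm90AccFetch interfaces. They only show how ScaledEpilogue and ScaledEpilogueBias nest a scale-B multiply inside a scale-A multiply or multiply-add.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical toy: mimics only the *shape* of an Sm90EVT tree, not its API.
struct Ctx {
  std::size_t m, n;                // output coordinate
  float acc;                       // accumulator value at (m, n)
  std::vector<float> const* sa;    // per-row scales of A (length M)
  std::vector<float> const* sb;    // per-column scales of B (length N)
  std::vector<float> const* bias;  // row bias (length N); may be empty
};

// Leaves: each fetches one operand for the current output element.
struct AccumLeaf { static float visit(Ctx const& c) { return c.acc; } };
struct ScaleALeaf { static float visit(Ctx const& c) { return (*c.sa)[c.m]; } };
struct ScaleBLeaf { static float visit(Ctx const& c) { return (*c.sb)[c.n]; } };
struct BiasLeaf { static float visit(Ctx const& c) { return (*c.bias)[c.n]; } };

// Compute ops, named after cutlass::multiplies and cutlass::multiply_add.
struct Multiplies {
  static float apply(float x, float y) { return x * y; }
};
struct MultiplyAdd {
  static float apply(float x, float y, float z) { return x * y + z; }
};

// Interior node: visit every child, then combine the results with Op.
template <typename Op, typename... Children>
struct Node {
  static float visit(Ctx const& c) { return Op::apply(Children::visit(c)...); }
};

// Tree shapes mirroring the Signature section:
//   ScaledEpilogue     = Compute1(ScaleA, EVTCompute0)
//   ScaledEpilogueBias = Compute1(ScaleA, EVTCompute0, Bias)
using ToyEVTCompute0 = Node<Multiplies, ScaleBLeaf, AccumLeaf>;
using ToyScaled = Node<Multiplies, ScaleALeaf, ToyEVTCompute0>;
using ToyScaledBias = Node<MultiplyAdd, ScaleALeaf, ToyEVTCompute0, BiasLeaf>;

// Evaluate a tree once per element of a row-major M x N accumulator tile.
template <typename Tree>
std::vector<float> run_tree(std::vector<float> const& acc, std::size_t M,
                            std::size_t N, std::vector<float> const& sa,
                            std::vector<float> const& sb,
                            std::vector<float> const& bias = {}) {
  assert(acc.size() == M * N && sa.size() == M && sb.size() == N);
  std::vector<float> d(M * N);
  for (std::size_t m = 0; m < M; ++m)
    for (std::size_t n = 0; n < N; ++n)
      d[m * N + n] = Tree::visit(Ctx{m, n, acc[m * N + n], &sa, &sb, &bias});
  return d;
}
```

In the real kernel, each leaf is a broadcast load and the root node also rounds to ElementD. The toy keeps everything in float so the nesting order stays easy to follow.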
Usage
This header is included when compiling the CUTLASS 3.x quantized GEMM kernels for Hopper and later GPUs. vLLM's cutlass_scaled_mm, cutlass_scaled_mm_azp, and cutlass_moe_mm operations use it when running on SM90a or later hardware.
Code Reference
Source Location
- Repository: vllm
- File: csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp
- Lines: 1-450
Signature
```cpp
namespace vllm::c3x {

// In every epilogue below, ArgumentType is an alias for
// typename EVTCompute::Arguments.

// Identity epilogue: converts the accumulator to ElementD without scaling.
template <typename ElementAcc, typename ElementD, typename TileShape>
struct TrivialEpilogue {
  using EVTCompute = Sm90EVT<Compute, Accum>;
  // Accepts any arguments and ignores them.
  template <typename... Args>
  static ArgumentType prepare_args(Args... args);
};

template <typename ElementAcc, typename ElementD, typename TileShape>
struct ScaledEpilogueBase { /* common load descriptors */ };

// D = a_scales * (b_scales * Accum)
template <typename ElementAcc, typename ElementD, typename TileShape>
struct ScaledEpilogue : private ScaledEpilogueBase<...> {
  using EVTCompute = Sm90EVT<Compute1, ScaleA, EVTCompute0>;
  static ArgumentType prepare_args(torch::Tensor const& a_scales,
                                   torch::Tensor const& b_scales);
};

// D = a_scales * (b_scales * Accum) + bias, with bias a (1,N) row vector
template <typename ElementAcc, typename ElementD, typename TileShape>
struct ScaledEpilogueBias : private ScaledEpilogueBase<...> {
  using EVTCompute = Sm90EVT<Compute1, ScaleA, EVTCompute0, Bias>;
  static ArgumentType prepare_args(torch::Tensor const& a_scales,
                                   torch::Tensor const& b_scales,
                                   torch::Tensor const& bias);
};

// Same as ScaledEpilogueBias, but bias is an (M,1) column vector
template <typename ElementAcc, typename ElementD, typename TileShape>
struct ScaledEpilogueColumnBias : private ScaledEpilogueBase<...> {
  using EVTCompute = Sm90EVT<Compute1, ScaleA, EVTCompute0, Bias>;
  static ArgumentType prepare_args(torch::Tensor const& a_scales,
                                   torch::Tensor const& b_scales,
                                   torch::Tensor const& bias);
};

// D = a_scales * b_scales * (Accum - azp_adj) + bias, azp_adj = azp * J @ B
template <typename ElementAcc, typename ElementD, typename TileShape>
struct ScaledEpilogueBiasAzp : private ScaledEpilogueBase<...> {
  static ArgumentType prepare_args(torch::Tensor const& a_scales,
                                   torch::Tensor const& b_scales,
                                   torch::Tensor const& azp_adj,
                                   std::optional<torch::Tensor> const& bias);
};

// D = a_scales * b_scales * (Accum - azp * azp_adj) + bias, azp_adj = J @ B
template <typename ElementAcc, typename ElementD, typename TileShape>
struct ScaledEpilogueBiasAzpToken : private ScaledEpilogueBase<...> {
  static ArgumentType prepare_args(torch::Tensor const& a_scales,
                                   torch::Tensor const& b_scales,
                                   torch::Tensor const& azp_adj,
                                   torch::Tensor const& azp,
                                   std::optional<torch::Tensor> const& bias);
};

// Scales are read through per-group pointer arrays (batched MoE GEMMs)
template <typename ElementAcc, typename ElementD, typename TileShape>
struct ScaledEpilogueArray : private ScaledEpilogueBase<...> {
  static ArgumentType prepare_args(/* array-based pointer arguments */);
};

}  // namespace vllm::c3x
```
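a_scales and b_scales can each be a single per-tensor value or a vector, so the scale applied to an output element depends on which form was passed. The sketch below is a float reference for ScaledEpilogue's arithmetic. ScaleArg, scale_a_at, scale_b_at, and scaled_epilogue_ref are hypothetical helpers. The example assumes that a one-element scale tensor is treated as per-tensor, which is meant to mirror how prepare_args chooses the broadcast mode.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for a scale tensor: one value or a vector.
struct ScaleArg {
  std::vector<float> values;
  // Assumption: a one-element tensor means per-tensor (scalar) scaling.
  bool is_vector() const { return values.size() != 1; }
};

// Scale of A for output row m: (M,1) column vector or scalar.
float scale_a_at(ScaleArg const& s, std::size_t m) {
  return s.is_vector() ? s.values[m] : s.values[0];
}

// Scale of B for output column n: (1,N) row vector or scalar.
float scale_b_at(ScaleArg const& s, std::size_t n) {
  return s.is_vector() ? s.values[n] : s.values[0];
}

// Float reference: D[m][n] = a_scale(m) * (b_scale(n) * acc[m][n]).
std::vector<float> scaled_epilogue_ref(std::vector<float> const& acc,
                                       std::size_t M, std::size_t N,
                                       ScaleArg const& a, ScaleArg const& b) {
  assert(acc.size() == M * N);
  assert(!a.is_vector() || a.values.size() == M);
  assert(!b.is_vector() || b.values.size() == N);
  std::vector<float> d(M * N);
  for (std::size_t m = 0; m < M; ++m)
    for (std::size_t n = 0; n < N; ++n)
      d[m * N + n] = scale_a_at(a, m) * (scale_b_at(b, n) * acc[m * N + n]);
  return d;
}
```

Any mix of per-tensor and vector scales resolves to the same per-element product. This is why the epilogue needs only one tree for all four scale combinations.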
Import
```cpp
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"
```
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| a_scales | torch::Tensor | Yes | Quantization scales for operand A; scalar (per-tensor) or (M,1) column vector (per-row) |
| b_scales | torch::Tensor | Yes | Quantization scales for operand B; scalar (per-tensor) or (1,N) row vector (per-column) |
| bias | torch::Tensor or std::optional<torch::Tensor> | Variant-dependent | Per-output-channel bias. Required by ScaledEpilogueBias and ScaledEpilogueColumnBias, optional in the Azp variants. Shape is (1,N) except in ScaledEpilogueColumnBias, where it is (M,1) |
| azp_adj | torch::Tensor | Azp variants only | Activation zero-point adjustment of shape (1,N): azp * J @ B for ScaledEpilogueBiasAzp, J @ B for ScaledEpilogueBiasAzpToken, where J is an all-ones matrix shaped like A |
| azp | torch::Tensor | ScaledEpilogueBiasAzpToken only | Per-token activation zero-points of shape (M,1) |
| ElementAcc | template param | Yes | Accumulator element type (e.g., float, int32_t) |
| ElementD | template param | Yes | Output element type (e.g., cutlass::half_t, cutlass::bfloat16_t) |
| TileShape | template param | Yes | CTA tile shape for the GEMM kernel |
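The azp_adj and azp inputs come from expanding asymmetric activation quantization. Write the dequantized activation as a_scale * (Aq - azp) and B as b_scale * Bq. Then D = a_scale * b_scale * (Aq @ Bq - azp * (J @ Bq)). Every row of J @ Bq equals the column sums of Bq, so the adjustment fits in a (1,N) row.

The C++ sketch below checks this identity numerically with per-tensor scales and no bias. gemm_i32, ones_times_b, azp_tensor_epilogue_ref, azp_token_epilogue_ref, and dequant_reference are hypothetical names. The int8 inputs with int32 accumulation are assumptions of the example.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Row-major int8 GEMM with int32 accumulation: acc = Aq (M x K) @ Bq (K x N).
std::vector<std::int32_t> gemm_i32(std::vector<std::int8_t> const& aq,
                                   std::vector<std::int8_t> const& bq,
                                   std::size_t M, std::size_t K, std::size_t N) {
  assert(aq.size() == M * K && bq.size() == K * N);
  std::vector<std::int32_t> acc(M * N, 0);
  for (std::size_t m = 0; m < M; ++m)
    for (std::size_t k = 0; k < K; ++k)
      for (std::size_t n = 0; n < N; ++n)
        acc[m * N + n] += std::int32_t(aq[m * K + k]) * std::int32_t(bq[k * N + n]);
  return acc;
}

// J @ Bq: every row equals the column sums of Bq, so store one (1 x N) row.
std::vector<std::int32_t> ones_times_b(std::vector<std::int8_t> const& bq,
                                       std::size_t K, std::size_t N) {
  std::vector<std::int32_t> colsum(N, 0);
  for (std::size_t k = 0; k < K; ++k)
    for (std::size_t n = 0; n < N; ++n) colsum[n] += bq[k * N + n];
  return colsum;
}

// Per-tensor form: D = a_scale * b_scale * (acc - azp_adj[n]),
// where azp_adj = azp * (J @ B) was precomputed by the caller.
std::vector<float> azp_tensor_epilogue_ref(std::vector<std::int32_t> const& acc,
                                           std::vector<std::int32_t> const& azp_adj,
                                           float a_scale, float b_scale,
                                           std::size_t M, std::size_t N) {
  assert(acc.size() == M * N && azp_adj.size() == N);
  std::vector<float> d(M * N);
  for (std::size_t m = 0; m < M; ++m)
    for (std::size_t n = 0; n < N; ++n)
      d[m * N + n] = a_scale * b_scale * float(acc[m * N + n] - azp_adj[n]);
  return d;
}

// Per-token form: D = a_scale * b_scale * (acc - azp[m] * azp_adj[n]),
// where azp_adj = J @ B and azp holds one zero point per row (token).
std::vector<float> azp_token_epilogue_ref(std::vector<std::int32_t> const& acc,
                                          std::vector<std::int32_t> const& azp_adj,
                                          std::vector<std::int32_t> const& azp,
                                          float a_scale, float b_scale,
                                          std::size_t M, std::size_t N) {
  assert(acc.size() == M * N && azp_adj.size() == N && azp.size() == M);
  std::vector<float> d(M * N);
  for (std::size_t m = 0; m < M; ++m)
    for (std::size_t n = 0; n < N; ++n)
      d[m * N + n] =
          a_scale * b_scale * float(acc[m * N + n] - azp[m] * azp_adj[n]);
  return d;
}

// Direct reference: subtract each row's zero point from Aq, then multiply.
std::vector<float> dequant_reference(std::vector<std::int8_t> const& aq,
                                     std::vector<std::int8_t> const& bq,
                                     std::vector<std::int32_t> const& azp,
                                     float a_scale, float b_scale,
                                     std::size_t M, std::size_t K, std::size_t N) {
  std::vector<float> d(M * N);
  for (std::size_t m = 0; m < M; ++m)
    for (std::size_t n = 0; n < N; ++n) {
      std::int32_t sum = 0;
      for (std::size_t k = 0; k < K; ++k)
        sum += (std::int32_t(aq[m * K + k]) - azp[m]) * std::int32_t(bq[k * N + n]);
      d[m * N + n] = a_scale * b_scale * float(sum);
    }
  return d;
}
```

Both epilogue forms reproduce the direct result exactly because the correction is done in integer arithmetic before the final float scaling.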
Outputs
| Name | Type | Description |
|---|---|---|
| prepare_args return value | ArgumentType (alias of EVTCompute::Arguments) | Fully constructed epilogue arguments struct, ready to pass to the kernel launch |
Usage Examples
```cpp
#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp"

// Illustrative CTA tile shape for these examples.
using TileShape = cute::Shape<cute::_128, cute::_128, cute::_128>;

// Scales are float32 tensors; bias is (1,N) and col_bias is (M,1).
void build_epilogue_args(torch::Tensor const& a_scales,
                         torch::Tensor const& b_scales,
                         torch::Tensor const& bias,
                         torch::Tensor const& col_bias) {
  // ScaledEpilogue on Hopper: D = (a_scales * A) @ (b_scales * B)
  using Epilogue = vllm::c3x::ScaledEpilogue<float, cutlass::half_t, TileShape>;
  auto args = Epilogue::prepare_args(a_scales, b_scales);
  // ScaledEpilogueBias: D = (a_scales * A) @ (b_scales * B) + bias
  using EpilogueBias = vllm::c3x::ScaledEpilogueBias<float, cutlass::half_t, TileShape>;
  auto bias_args = EpilogueBias::prepare_args(a_scales, b_scales, bias);
  // ScaledEpilogueColumnBias: for transposed GEMMs (e.g., 2:4 sparse)
  using EpilogueColBias = vllm::c3x::ScaledEpilogueColumnBias<float, cutlass::half_t, TileShape>;
  auto col_args = EpilogueColBias::prepare_args(a_scales, b_scales, col_bias);
  // TrivialEpilogue: identity passthrough (no scaling)
  using TrivialEpi = vllm::c3x::TrivialEpilogue<float, cutlass::half_t, TileShape>;
  auto trivial_args = TrivialEpi::prepare_args();
}
```
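ScaledEpilogueColumnBias is needed because a transposed GEMM produces D^T, so a per-output-channel bias that was a row vector becomes a column vector. The minimal float sketch below uses vector scales and the hypothetical helpers scaled_bias_ref and transpose. It shows that the column-bias form, run on the transposed problem with the two scale vectors swapped, reproduces the transpose of the row-bias result. The scale swap is an assumption of this sketch about how a caller maps operands in the transposed case.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Float reference for ScaledEpilogueBias (bias_is_row = true) and
// ScaledEpilogueColumnBias (bias_is_row = false) on a row-major M x N tile,
// with per-row scales sa (length M) and per-column scales sb (length N).
std::vector<float> scaled_bias_ref(std::vector<float> const& acc,
                                   std::size_t M, std::size_t N,
                                   std::vector<float> const& sa,
                                   std::vector<float> const& sb,
                                   std::vector<float> const& bias,
                                   bool bias_is_row) {
  assert(acc.size() == M * N && sa.size() == M && sb.size() == N);
  assert(bias.size() == (bias_is_row ? N : M));
  std::vector<float> d(M * N);
  for (std::size_t m = 0; m < M; ++m)
    for (std::size_t n = 0; n < N; ++n)
      d[m * N + n] = sa[m] * (sb[n] * acc[m * N + n]) +
                     (bias_is_row ? bias[n] : bias[m]);
  return d;
}

// Transpose a row-major R x C matrix into a row-major C x R matrix.
std::vector<float> transpose(std::vector<float> const& x, std::size_t R,
                             std::size_t C) {
  assert(x.size() == R * C);
  std::vector<float> t(R * C);
  for (std::size_t r = 0; r < R; ++r)
    for (std::size_t c = 0; c < C; ++c) t[c * R + r] = x[r * C + c];
  return t;
}
```

The test values are powers of two and small integers, so the two multiplication orders round identically. With arbitrary floats the two results would agree only up to rounding.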