Implementation:Vllm project Vllm Broadcast Load Epilogue C3X
| Knowledge Sources | |
|---|---|
| Domains | CUTLASS, Epilogue, Quantization, GEMM |
| Last Updated | 2026-02-08 00:00 GMT |
Overview
Implements CUTLASS 3.x epilogue visitors for broadcasting quantization scales and biases from device pointers during GEMM operations on Hopper (SM90+) GPUs.
Description
This file is a modified excerpt of CUTLASS v3.5.0 sm90_visitor_load_tma_warpspecialized.hpp, adapted to support row, column, or scalar broadcasting where the tensor is always passed via a device pointer. It leverages shared memory and TMA-based data movement on Hopper GPUs for high-performance broadcast loading. Like its C2X counterpart, this design avoids torch.compile graph breaks caused by CPU-resident scalars by keeping all scale tensors on the device.
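The three broadcast modes can be illustrated with a small host-side model. This is a sketch under stated assumptions, not the header's implementation: it uses plain C++ with ordinary host pointers standing in for device pointers, and the helper names `row_or_scalar_at` and `col_or_scalar_at` are hypothetical. Only the field names (`ptr_row`, `row_broadcast`, `ptr_col`, `col_broadcast`) come from the `Arguments` structs of this header.

```cpp
#include <cassert>
#include <cstddef>

// Host-side stand-in for Sm90RowOrScalarBroadcast::Arguments (stride omitted).
struct RowOrScalarArgs {
    const float* ptr_row = nullptr;  // a device pointer in the real kernel
    bool row_broadcast = true;       // true: one value per column n; false: one scalar
};

// Host-side stand-in for Sm90ColOrScalarBroadcast::Arguments (stride omitted).
struct ColOrScalarArgs {
    const float* ptr_col = nullptr;  // a device pointer in the real kernel
    bool col_broadcast = true;       // true: one value per row m; false: one scalar
};

// Value contributed at output coordinate (m, n) by a row-or-scalar broadcast.
// The row index m is ignored: every row of the output sees the same vector.
inline float row_or_scalar_at(const RowOrScalarArgs& args,
                              std::size_t /*m*/, std::size_t n) {
    assert(args.ptr_row != nullptr);
    return args.row_broadcast ? args.ptr_row[n] : args.ptr_row[0];
}

// Value contributed at output coordinate (m, n) by a column-or-scalar broadcast.
// The column index n is ignored: every column of the output sees the same vector.
inline float col_or_scalar_at(const ColOrScalarArgs& args,
                              std::size_t m, std::size_t /*n*/) {
    assert(args.ptr_col != nullptr);
    return args.col_broadcast ? args.ptr_col[m] : args.ptr_col[0];
}
```

In both cases the scalar path still dereferences the pointer, which is what lets a per-tensor scale stay on the device.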
Usage
This header is included by scaled_mm_epilogues_c3x.hpp and compiled as part of the CUTLASS 3.x quantized GEMM kernels targeting NVIDIA Hopper (SM90a) and later architectures. It is used whenever vLLM performs scaled matrix multiplication with INT8 or FP8 operands on Hopper+ GPUs.
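To make the role of these broadcasts in a scaled matrix multiplication concrete, here is a minimal host-side reference. It assumes the epilogue multiplies each accumulator by a per-row (or scalar) scale for A and a per-column (or scalar) scale for B; that formula, the function name `scaled_mm_reference`, and the row-major layouts are assumptions made for this sketch and are not taken from the header.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical host reference:
//   D[m][n] = a_scale(m) * b_scale(n) * sum_k A[m][k] * B[k][n]
// A is M x K and B is K x N, both row-major (layout chosen for this sketch only).
std::vector<float> scaled_mm_reference(
    const std::vector<std::int8_t>& A, const std::vector<std::int8_t>& B,
    std::size_t M, std::size_t N, std::size_t K,
    const std::vector<float>& a_scales, bool a_per_row,
    const std::vector<float>& b_scales, bool b_per_col) {
    assert(A.size() == M * K && B.size() == K * N);
    assert(a_scales.size() >= (a_per_row ? M : 1));
    assert(b_scales.size() >= (b_per_col ? N : 1));

    std::vector<float> D(M * N);
    for (std::size_t m = 0; m < M; ++m) {
        for (std::size_t n = 0; n < N; ++n) {
            // Integer accumulation, as in an INT8 GEMM mainloop.
            std::int32_t acc = 0;
            for (std::size_t k = 0; k < K; ++k) {
                acc += static_cast<std::int32_t>(A[m * K + k]) *
                       static_cast<std::int32_t>(B[k * N + n]);
            }
            // Column-or-scalar broadcast: varies with m, or a single value.
            const float sa = a_per_row ? a_scales[m] : a_scales[0];
            // Row-or-scalar broadcast: varies with n, or a single value.
            const float sb = b_per_col ? b_scales[n] : b_scales[0];
            D[m * N + n] = sa * (sb * static_cast<float>(acc));
        }
    }
    return D;
}
```

Passing `a_per_row = true` with `b_per_col = false` models per-token activation scales combined with a per-tensor weight scale; the other combinations follow the same pattern.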
Code Reference
Source Location
- Repository: vllm
- File: csrc/cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp
- Lines: 1-447
Signature
```cpp
template<
  int Stages,
  class CtaTileShapeMNK,
  class Element,
  class StrideMNL = Stride<_0,_1,_0>,
  int Alignment = 128 / sizeof_bits_v<Element>
>
struct Sm90RowOrScalarBroadcast {
  struct SharedStorage {
    array_aligned<Element, size<1>(CtaTileShapeMNK{})> smem;
  };
  struct Arguments {
    Element const* ptr_row = nullptr;
    bool row_broadcast = true;
    StrideMNL dRow = {};
  };
  using Params = Arguments;
};

template<
  int Stages,
  class CtaTileShapeMNK,
  class Element,
  class StrideMNL = Stride<_1,_0,_0>,
  int Alignment = 128 / sizeof_bits_v<Element>
>
struct Sm90ColOrScalarBroadcast {
  struct SharedStorage {
    array_aligned<Element, size<0>(CtaTileShapeMNK{})> smem;
  };
  struct Arguments {
    Element const* ptr_col = nullptr;
    bool col_broadcast = true;
    StrideMNL dCol = {};
  };
  using Params = Arguments;
};
```
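The `SharedStorage` declarations size the staging buffer from the CTA tile: the row variant holds `size<1>` (tile N) elements and the column variant holds `size<0>` (tile M) elements. The following sketch computes the corresponding payload sizes in plain C++. The `TileShape` struct and the helper names are hypothetical stand-ins for the CuTe types, and any padding added by `array_aligned` is ignored.

```cpp
#include <cstddef>

// Hypothetical stand-in for a CUTLASS CtaTileShapeMNK with static extents.
struct TileShape {
    int m;
    int n;
    int k;
};

// Payload bytes of the row-broadcast buffer: one Element per tile column.
template <class Element>
constexpr std::size_t row_smem_bytes(TileShape tile) {
    return static_cast<std::size_t>(tile.n) * sizeof(Element);
}

// Payload bytes of the column-broadcast buffer: one Element per tile row.
template <class Element>
constexpr std::size_t col_smem_bytes(TileShape tile) {
    return static_cast<std::size_t>(tile.m) * sizeof(Element);
}
```

For example, under these assumptions a 128 x 256 x 64 tile with `float` scales needs 256 * 4 bytes for the row buffer and 128 * 4 bytes for the column buffer.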
Import
```cpp
#include "cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp"
```
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| ptr_row / ptr_col | Element const* | Yes | Device pointer to the broadcast tensor (scale, bias, or zero-point) |
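| row_broadcast / col_broadcast | bool | No | When true, loads a vector through the pointer (one value per output column for row broadcast, one per output row for column broadcast); when false, reads a single scalar through the pointer and broadcasts it to every element (default: true) |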
| dRow / dCol | StrideMNL | No | Stride descriptor for the broadcast tensor layout |
| Stages | int (template param) | Yes | Number of pipeline stages (must be 0 for row/col broadcast) |
| CtaTileShapeMNK | class (template param) | Yes | CTA tile shape defining the M, N, and K extents |
| Element | class (template param) | Yes | Data type of the broadcast tensor (e.g., float, half) |
| Alignment | int (template param) | No | Alignment, in elements, for vector loads (default: 128 / sizeof_bits_v<Element>) |
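The default `Alignment` of `128 / sizeof_bits_v<Element>` is the number of elements that fit in one 128-bit access. The sketch below reproduces that arithmetic in standard C++. It substitutes `sizeof(Element) * 8` for CUTLASS's `sizeof_bits_v`, which is an assumption that holds only for byte-sized and larger types, and the helper names are hypothetical.

```cpp
#include <cstdint>

// Stand-in for cutlass::sizeof_bits_v (valid only for types of at least one byte).
template <class Element>
constexpr int bits_of() {
    return static_cast<int>(sizeof(Element) * 8);
}

// Default alignment in elements: how many Elements fill a 128-bit vector access.
template <class Element>
constexpr int default_alignment() {
    return 128 / bits_of<Element>();
}

static_assert(default_alignment<float>() == 4, "four 32-bit elements per 128 bits");
static_assert(default_alignment<std::uint16_t>() == 8, "eight 16-bit elements per 128 bits");
static_assert(default_alignment<std::uint8_t>() == 16, "sixteen 8-bit elements per 128 bits");
```

Here `std::uint16_t` stands in for a 16-bit type such as half, so a `float` scale vector defaults to 4-element alignment and a 16-bit one to 8-element alignment.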
Outputs
| Name | Type | Description |
|---|---|---|
| visit() return | Array<Element, FragmentSize> | Fragment of broadcast values loaded via shared memory into registers for epilogue computation |
Usage Examples
```cpp
// Using Sm90RowOrScalarBroadcast in a CUTLASS 3.x epilogue
using ScaleB = cutlass::epilogue::fusion::Sm90RowOrScalarBroadcast<
    0 /*Stages*/, TileShape, float, Stride<Int<0>, Int<1>, Int<0>>>;

// Construct arguments for per-channel scale broadcast
typename ScaleB::Arguments scale_args{
    scale_data_ptr,  // device pointer to scale tensor
    true,            // row_broadcast = true for per-channel
    {}               // default stride
};

// Construct arguments for per-tensor scalar broadcast
typename ScaleB::Arguments scalar_args{
    scalar_data_ptr,  // device pointer to single scalar
    false,            // row_broadcast = false for per-tensor
    {}                // default stride
};
```