Implementation:Vllm project Vllm Broadcast Load Epilogue C2X
| Knowledge Sources | |
|---|---|
| Domains | CUTLASS, Epilogue, Quantization, GEMM |
| Last Updated | 2026-02-08 00:00 GMT |
Overview
Implements CUTLASS 2.x epilogue visitors for broadcasting quantization scales and biases from device pointers during GEMM operations on Ampere (SM80) GPUs.
Description
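This file is a modified excerpt of CUTLASS v3.5.0 visitor_load.hpp, adapted for row, column, or scalar broadcasting in which the tensor being loaded is always passed through a device pointer. Because the same pointer can address either a full vector or a single value, one compiled kernel handles per-tensor, per-channel, and per-token quantization. It provides three visitors: VisitorRowOrScalarBroadcast (a row vector or a scalar), VisitorRowOrZeroBroadcast (a row vector, or zeros when the pointer is null), and VisitorColOrScalarBroadcast (a column vector or a scalar). Keeping scales as device-resident tensors also avoids a problem with an earlier interface, where per-tensor scales were passed as scalar values that had to reside on the CPU, which caused torch.compile graph breaks.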
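To make the "one kernel for every granularity" point concrete, the following is a minimal host-side C++ reference model of the arithmetic a scaled-GEMM epilogue built from these visitors performs. It assumes the composition D = scale_a * (scale_b * acc), with scale_a read per row (per-token) or as a scalar and scale_b read per column (per-channel) or as a scalar. The function and parameter names are hypothetical, and no CUTLASS types are used.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical reference model (not vLLM or CUTLASS code): computes
// D[m][n] = scale_a * (scale_b * acc[m][n]) for a row-major M x N accumulator.
// Each scale is always read through a pointer; the flag says whether the
// pointer addresses a vector (one value per row/column) or a single scalar.
std::vector<float> scaled_epilogue_ref(const std::vector<float>& acc,
                                       std::size_t M, std::size_t N,
                                       const float* scale_a, bool a_is_vector,
                                       const float* scale_b, bool b_is_vector) {
  std::vector<float> d(M * N);
  for (std::size_t m = 0; m < M; ++m) {
    // Per-token (column-vector) scale, or the same scalar for every row.
    float sa = a_is_vector ? scale_a[m] : scale_a[0];
    for (std::size_t n = 0; n < N; ++n) {
      // Per-channel (row-vector) scale, or the same scalar for every column.
      float sb = b_is_vector ? scale_b[n] : scale_b[0];
      d[m * N + n] = sa * (sb * acc[m * N + n]);
    }
  }
  return d;
}
```

Per-tensor quantization passes a one-element buffer with the flag set to false; per-token or per-channel quantization passes a full vector with the flag set to true. The loop body is the same in every case, which is why a single compiled kernel can cover all combinations.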
Usage
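This header is included by scaled_mm_epilogues_c2x.hpp and compiled into vLLM's CUTLASS 2.x quantized GEMM kernels, which target pre-Hopper GPUs such as Ampere (SM80). The visitors supply the scale, bias, and zero-point operands whenever vLLM runs a CUTLASS 2.x scaled matrix multiplication, for example with INT8 operands on Ampere. Ampere has no native FP8 support, so FP8 operands on this path require a newer architecture (SM89 or later).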
Code Reference
Source Location
- Repository: vllm
- File: csrc/cutlass_extensions/epilogue/broadcast_load_epilogue_c2x.hpp
- Lines: 1-497
Signature
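```cpp
// Excerpt: only the Arguments/Params members of each visitor are shown.
namespace cutlass::epilogue::threadblock {

// Loads a row vector (one value per output column) or broadcasts a single scalar.
template <class ThreadMap, class Element, class StrideMNL>
struct VisitorRowOrScalarBroadcast {
  struct Arguments {
    Element const* ptr_row = nullptr;
    bool row_broadcast = true;  // false: *ptr_row is a single scalar
    StrideMNL dRow = {};
  };
  using Params = Arguments;
};

// Loads a row vector, or broadcasts zeros when ptr_row is nullptr.
template <class ThreadMap, class Element, class StrideMNL>
struct VisitorRowOrZeroBroadcast {
  struct Arguments {
    Element const* ptr_row = nullptr;
    StrideMNL dRow = {};
  };
  using Params = Arguments;
};

// Loads a column vector (one value per output row) or broadcasts a single scalar.
template <class ThreadMap, class Element, class StrideMNL>
struct VisitorColOrScalarBroadcast {
  struct Arguments {
    Element const* ptr_col = nullptr;
    bool col_broadcast = true;  // false: *ptr_col is a single scalar
    StrideMNL dCol = {};
  };
  using Params = Arguments;
};

}  // namespace cutlass::epilogue::threadblock
```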
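The value each visitor contributes to output element (m, n) can be summarized with simplified host-side stand-ins for the three Arguments structs. This is a sketch under stated assumptions: the struct and function names are hypothetical, StrideMNL is dropped, indexing is assumed contiguous, and the real visitors load into registers on the device rather than calling per-element functions.

```cpp
#include <cstddef>

// Hypothetical, simplified models of the three Arguments structs
// (StrideMNL member dropped; contiguous indexing assumed).
struct RowOrScalarArgs {
  const float* ptr_row = nullptr;
  bool row_broadcast = true;  // true: vector of N values; false: one scalar
};
struct RowOrZeroArgs {
  const float* ptr_row = nullptr;  // nullptr means "broadcast zeros"
};
struct ColOrScalarArgs {
  const float* ptr_col = nullptr;
  bool col_broadcast = true;  // true: vector of M values; false: one scalar
};

// Value seen by any output element in column n (row-indexed visitors)
// or in row m (column-indexed visitor).
float row_or_scalar_value(const RowOrScalarArgs& a, std::size_t n) {
  return a.row_broadcast ? a.ptr_row[n] : a.ptr_row[0];
}
float row_or_zero_value(const RowOrZeroArgs& a, std::size_t n) {
  return a.ptr_row != nullptr ? a.ptr_row[n] : 0.0f;
}
float col_or_scalar_value(const ColOrScalarArgs& a, std::size_t m) {
  return a.col_broadcast ? a.ptr_col[m] : a.ptr_col[0];
}
```

The null-pointer check in the row-or-zero model is what lets an optional operand such as a bias be omitted without compiling a separate kernel.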
Import
```cpp
#include "cutlass_extensions/epilogue/broadcast_load_epilogue_c2x.hpp"
```
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
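| ptr_row / ptr_col | Element const* | Yes | Device pointer to the broadcast tensor (scale, bias, or zero-point). For VisitorRowOrZeroBroadcast, nullptr broadcasts zeros |
| row_broadcast / col_broadcast | bool | No | VisitorRowOrScalarBroadcast and VisitorColOrScalarBroadcast only. When true, reads one value per output column (row) or per output row (col) from the vector; when false, reads the single value at the pointer and broadcasts it across the tile (default: true) |
| dRow / dCol | StrideMNL | No | Stride of the broadcast tensor in the M, N, and L modes (default: {}) |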
| ThreadMap | template param | Yes | CUTLASS thread map defining output tile partitioning |
| Element | template param | Yes | Data type of the broadcast tensor (e.g., float, half) |
| StrideMNL | template param | Yes | Stride type parameterizing M, N, L dimensions |
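The StrideMNL parameter is what turns a one-dimensional buffer into a broadcast: a stride of 0 in a mode means every coordinate along that mode reads the same element. The sketch below illustrates this with plain integers, assuming the inner-product offset rule that CuTe layouts apply; StrideMNL3 and offset_of are hypothetical names, whereas in CUTLASS the strides are compile-time cute::Int values. kRowStride matches the Stride<Int<0>, Int<1>, Int<0>> used in the usage example; kColStride is the analogous layout for a column vector.

```cpp
#include <cstddef>

// Hypothetical stand-in for a StrideMNL value: one stride per M, N, L mode.
struct StrideMNL3 {
  std::ptrdiff_t dM, dN, dL;
};

// Linear offset of coordinate (m, n, l): the inner product of coordinate and
// stride. A zero stride makes the offset independent of that coordinate.
constexpr std::ptrdiff_t offset_of(const StrideMNL3& s, std::ptrdiff_t m,
                                   std::ptrdiff_t n, std::ptrdiff_t l) {
  return m * s.dM + n * s.dN + l * s.dL;
}

// Row vector: depends on n only (Stride<Int<0>, Int<1>, Int<0>>).
constexpr StrideMNL3 kRowStride{0, 1, 0};
// Column vector: depends on m only (Stride<Int<1>, Int<0>, Int<0>>).
constexpr StrideMNL3 kColStride{1, 0, 0};
```

Because both strides are zero along L, every batch in the L mode shares the same scale vector under these layouts.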
Outputs
| Name | Type | Description |
|---|---|---|
| visit() return | Array<Element, FragmentSize> | Fragment of broadcast values loaded into registers for epilogue computation |
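The returned fragment holds the broadcast values for a run of consecutive output elements owned by one thread. The sketch below models this for the row case with std::array standing in for cutlass::Array. Assumptions: row_fragment is a hypothetical name, the chunk starts at col_start and spans FragmentSize columns, and out-of-range lanes are left at zero in this model; the real visitor performs predicated loads over CuTe partitions rather than this scalar loop.

```cpp
#include <array>
#include <cstddef>

// Hypothetical model of the fragment a thread receives for FragmentSize
// consecutive output columns starting at col_start.
// Vector mode: one value per column. Scalar mode: ptr[0] in every lane.
// Lanes at or beyond n_cols stay zero (model assumption).
template <std::size_t FragmentSize>
std::array<float, FragmentSize> row_fragment(const float* ptr, bool row_broadcast,
                                             std::size_t col_start,
                                             std::size_t n_cols) {
  std::array<float, FragmentSize> frag{};  // value-initialized to 0.0f
  for (std::size_t i = 0; i < FragmentSize; ++i) {
    std::size_t n = col_start + i;
    if (n >= n_cols) continue;  // predicated off: leave lane at zero
    frag[i] = row_broadcast ? ptr[n] : ptr[0];
  }
  return frag;
}
```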
Usage Examples
```cpp
// Using VisitorRowOrScalarBroadcast in a CUTLASS 2.x epilogue.
// OutputTileThreadMap, scale_data_ptr, and scalar_data_ptr are placeholders
// supplied by the surrounding kernel code.
using namespace cute;  // Stride, Int

// Row layout: stride 0 along M and L, stride 1 along N (one value per column).
using ScaleB = cutlass::epilogue::threadblock::VisitorRowOrScalarBroadcast<
    OutputTileThreadMap, float, Stride<Int<0>, Int<1>, Int<0>>>;

// Vector broadcast: per-channel scales, one float per output column.
typename ScaleB::Arguments scale_args{
    scale_data_ptr,  // device pointer to the scale vector
    true,            // row_broadcast = true
    {}               // default (compile-time) stride
};

// Scalar broadcast: per-tensor scale, a single float read from device memory.
typename ScaleB::Arguments scalar_args{
    scalar_data_ptr,  // device pointer to a single scalar
    false,            // row_broadcast = false
    {}                // default (compile-time) stride
};
```
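In practice the broadcast flag is usually derived from the size of the scale buffer rather than hard-coded: a one-element buffer means a per-tensor scalar, anything larger means one value per column. The sketch below shows that rule with a hypothetical host-side model of the Arguments layout (StrideMNL omitted); make_row_or_scalar_args and is_valid_row_scale_size are illustrative names, and the requirement that a vector have exactly one value per output column is an assumption of this sketch.

```cpp
#include <cstddef>

// Hypothetical host-side model of VisitorRowOrScalarBroadcast::Arguments
// (StrideMNL omitted).
struct RowOrScalarArgsModel {
  const float* ptr_row = nullptr;
  bool row_broadcast = true;
};

// One element -> per-tensor scalar; otherwise -> per-column vector.
RowOrScalarArgsModel make_row_or_scalar_args(const float* device_ptr,
                                             std::size_t numel) {
  return RowOrScalarArgsModel{device_ptr, numel != 1};
}

// Assumed validity rule: a scalar, or exactly one value per output column.
bool is_valid_row_scale_size(std::size_t numel, std::size_t n_cols) {
  return numel == 1 || numel == n_cols;
}
```

Deriving the flag this way keeps the kernel signature identical for per-tensor and per-channel scales; only the runtime Arguments change.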