Implementation:Vllm project Vllm Broadcast Load Epilogue Array C3X
| Knowledge Sources | |
|---|---|
| Domains | CUTLASS Epilogue, Quantized GEMM |
| Last Updated | 2026-02-08 00:00 GMT |
Overview
Implements CUTLASS 3.x epilogue visitors for broadcasting scales and biases from device pointers during quantized GEMM operations on Hopper (SM90+) GPUs.
Description
This header extends the CUTLASS Sm90 epilogue fusion framework with broadcast-load nodes for quantization scales and biases. Sm90RowOrScalarBroadcastArray broadcasts either a row vector (one value per output column, indexed along N) or a single scalar; Sm90ColOrScalarBroadcastArray does the same for a column vector indexed along M. Both structs take an array of device pointers, one per batch or group, so batched and grouped GEMM operations can use per-group scales. A boolean flag (row_broadcast in the row variant) selects vector or scalar mode at run time. Because the scalar is also read through a device pointer, no CPU-resident scalar value is needed; such a value would break torch.compile graph compilation.
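The vector-versus-scalar switch can be illustrated with a small host-side model. This is a sketch under stated assumptions, not the header's device code: `broadcast_row_or_scalar` and `broadcast_tile` are hypothetical names, `float` stands in for `Element`, and the per-group pointer table lives in ordinary host memory here.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Host-side reference model (hypothetical, not the CUTLASS device code).
// Returns the value that would be broadcast to output column n of one group.
// ptr_row_array holds one pointer per group; row_broadcast selects the mode.
inline float broadcast_row_or_scalar(const float* const* ptr_row_array,
                                     bool row_broadcast,
                                     std::size_t group,
                                     std::size_t n) {
  assert(ptr_row_array != nullptr && ptr_row_array[group] != nullptr);
  const float* row = ptr_row_array[group];
  // Vector mode: one value per column. Scalar mode: element 0 for every column.
  return row_broadcast ? row[n] : row[0];
}

// Fills a row-major M x N tile with the broadcast values for one group.
inline std::vector<float> broadcast_tile(const float* const* ptr_row_array,
                                         bool row_broadcast,
                                         std::size_t group,
                                         std::size_t M,
                                         std::size_t N) {
  std::vector<float> tile(M * N);
  for (std::size_t m = 0; m < M; ++m) {
    for (std::size_t n = 0; n < N; ++n) {
      tile[m * N + n] =
          broadcast_row_or_scalar(ptr_row_array, row_broadcast, group, n);
    }
  }
  return tile;
}
```

In both modes the caller passes the same kind of argument (a pointer table plus a flag), which is the property that lets one compiled kernel cover both cases.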
Usage
This header is included by CUTLASS-based quantized GEMM kernels in the vLLM CUDA backend. It is used when running quantized inference on NVIDIA Hopper (SM90+) GPUs, supporting per-tensor, per-channel, and per-token quantization schemes through the same compiled kernel.
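As an illustration of how one code path can serve all three quantization schemes, the following host-side reference dequantizes an accumulator tile as D[m][n] = a_scale(m) * b_scale(n) * acc[m][n]. It is a minimal sketch: `ScaleArg` and `dequantize` are hypothetical names, and the mapping of per-token scales to a length-M vector and per-channel scales to a length-N vector is an assumption about how callers arrange their data, not something taken from the header.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical descriptor: a pointer plus a mode flag, mirroring the
// pointer-and-boolean style of the epilogue arguments.
struct ScaleArg {
  const float* ptr;  // points at a vector, or at a single scalar
  bool is_vector;    // true: index per row/column; false: always read ptr[0]

  float at(std::size_t i) const {
    assert(ptr != nullptr);
    return is_vector ? ptr[i] : ptr[0];
  }
};

// Reference dequantization of a row-major M x N int32 accumulator tile.
// a_scale: per-token (length M) or per-tensor (scalar).
// b_scale: per-channel (length N) or per-tensor (scalar).
inline std::vector<float> dequantize(const std::vector<std::int32_t>& acc,
                                     std::size_t M,
                                     std::size_t N,
                                     ScaleArg a_scale,
                                     ScaleArg b_scale) {
  assert(acc.size() == M * N);
  std::vector<float> out(M * N);
  for (std::size_t m = 0; m < M; ++m) {
    for (std::size_t n = 0; n < N; ++n) {
      out[m * N + n] = a_scale.at(m) * b_scale.at(n) *
                       static_cast<float>(acc[m * N + n]);
    }
  }
  return out;
}
```

Switching a scale from per-token or per-channel to per-tensor only flips `is_vector`; the function itself is unchanged.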
Code Reference
Source Location
- Repository: vllm
- File: csrc/cutlass_extensions/epilogue/broadcast_load_epilogue_array_c3x.hpp
- Lines: 1-457
Signature
namespace cutlass::epilogue::fusion {

// Row or scalar broadcast for quantization scales along N dimension
template<
  int Stages,
  class CtaTileShapeMNK,
  class Element,
  class StrideMNL = Stride<_0,_1,_0>,
  int Alignment = 128 / sizeof_bits_v<Element>
>
struct Sm90RowOrScalarBroadcastArray {
  struct Arguments {
    const Element* const* ptr_row_array = nullptr;
    bool row_broadcast = true;
    StrideMNL dRow = {};
  };

  using Params = Arguments;

  template <class ProblemShape>
  static constexpr Params to_underlying_arguments(
      ProblemShape const& problem_shape, Arguments const& args, void* workspace);

  template <class ProblemShape>
  static bool can_implement(ProblemShape const& problem_shape, Arguments const& args);

  template <class... Args>
  struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks {
    CUTLASS_DEVICE void begin();
    CUTLASS_DEVICE void begin_loop(int epi_m, int epi_n);

    template <typename ElementAccumulator, int FragmentSize>
    CUTLASS_DEVICE Array<Element, FragmentSize>
    visit(Array<ElementAccumulator, FragmentSize> const& frg_acc,
          int epi_v, int epi_m, int epi_n);
  };
};

// Column or scalar broadcast for quantization scales along M dimension
template<
  int Stages,
  class CtaTileShapeMNK,
  class Element,
  class StrideMNL = Stride<_1,_0,_0>,
  int Alignment = 128 / sizeof_bits_v<Element>
>
struct Sm90ColOrScalarBroadcastArray;

} // namespace cutlass::epilogue::fusion
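Two template defaults in the signature are easy to misread, so here is a hedged host-side sketch of what they work out to. `sizeof_bits_model`, `default_alignment`, `StrideModel`, and `linear_offset` are hypothetical stand-ins written for this page; they imitate `sizeof_bits_v` and a CuTe stride triple for ordinary byte-addressable types and are not CUTLASS APIs.

```cpp
#include <cassert>
#include <climits>
#include <cstddef>
#include <cstdint>

// Stand-in for sizeof_bits_v, valid only for byte-addressable types.
template <class T>
constexpr int sizeof_bits_model = static_cast<int>(sizeof(T) * CHAR_BIT);

// Mirrors the default `Alignment = 128 / sizeof_bits_v<Element>`:
// the number of elements that fit in one 128-bit access.
template <class T>
constexpr int default_alignment = 128 / sizeof_bits_model<T>;

static_assert(default_alignment<float> == 4, "four 32-bit elements");
static_assert(default_alignment<std::uint16_t> == 8, "eight 16-bit elements");
static_assert(default_alignment<std::int8_t> == 16, "sixteen 8-bit elements");

// Stand-in for a stride triple over (M, N, L) coordinates.
struct StrideModel {
  std::ptrdiff_t m;
  std::ptrdiff_t n;
  std::ptrdiff_t l;
};

// Linear offset of coordinate (m, n, l): the dot product of coordinate and stride.
inline std::ptrdiff_t linear_offset(StrideModel s, std::ptrdiff_t m,
                                    std::ptrdiff_t n, std::ptrdiff_t l) {
  assert(m >= 0 && n >= 0 && l >= 0);
  return s.m * m + s.n * n + s.l * l;
}
```

With the row default `Stride<_0,_1,_0>` the offset depends only on n, so every row of the tile reads the same vector; with the column default `Stride<_1,_0,_0>` it depends only on m.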
Import
#include "cutlass/cutlass.h"
#include "cutlass/arch/barrier.h"
#include "cute/tensor.hpp"
#include "cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp"
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| ptr_row_array | const Element* const* | Yes | Array of device pointers to scale/bias data, one pointer per batch/group |
| row_broadcast | bool | No (default: true) | If true, each pointer refers to a row vector with one value per output column; if false, each pointer refers to a single scalar that is broadcast to every output element |
| dRow | StrideMNL | No | Stride descriptor for the broadcast tensor layout |
| frg_acc | Array<ElementAccumulator, FragmentSize> | Yes | Accumulator fragment from the GEMM mainloop, passed to visit() |
Outputs
| Name | Type | Description |
|---|---|---|
| (return) | Array<Element, FragmentSize> | Broadcast scale/bias fragment to be applied to GEMM output |
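The contract in the tables above can be modeled on the host as follows. This is a simplified sketch and partly an assumption: it presumes that, as in other CUTLASS load-style visitor nodes, `visit` accepts the accumulator fragment only to satisfy the visitor interface and returns values taken from broadcast data already staged for the current tile. `visit_model`, `staged_row`, and `staged_len` are hypothetical names.

```cpp
#include <array>
#include <cassert>
#include <cstddef>

// Simplified host model of visit(): the accumulator fragment is accepted but
// unused; the result is the epi_v-th FragmentSize-wide slice of staged data.
template <class Element, class ElementAccumulator, std::size_t FragmentSize>
std::array<Element, FragmentSize> visit_model(
    const std::array<ElementAccumulator, FragmentSize>& /*frg_acc*/,
    const Element* staged_row,  // hypothetical: values staged for this thread
    std::size_t staged_len,     // number of staged values
    std::size_t epi_v) {        // which fragment of the staged data to return
  assert(staged_row != nullptr);
  assert((epi_v + 1) * FragmentSize <= staged_len);
  std::array<Element, FragmentSize> frg{};
  for (std::size_t i = 0; i < FragmentSize; ++i) {
    frg[i] = staged_row[epi_v * FragmentSize + i];
  }
  return frg;
}
```

A downstream compute node in the fused epilogue would then combine this fragment with the accumulator, for example by multiplying them elementwise.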
Usage Examples
// Define an epilogue node with row-broadcast scales for quantized GEMM.
// CtaTileShape and device_scale_ptrs are assumed to be defined elsewhere.
using EpilogueScale = cutlass::epilogue::fusion::Sm90RowOrScalarBroadcastArray<
    0,                                            // Stages (0 for row broadcast)
    CtaTileShape,                                 // CTA tile shape (M, N, K)
    float,                                        // Element type for scales
    cute::Stride<cute::_0, cute::_1, cute::_0>>;  // Row stride

// Configure arguments
EpilogueScale::Arguments scale_args;
scale_args.ptr_row_array = device_scale_ptrs;  // array of per-group scale pointers
scale_args.row_broadcast = true;               // vector mode (false for scalar)

// Pass scale_args to the CUTLASS kernel launch as part of the epilogue fusion arguments
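The example above takes `device_scale_ptrs` as given. One way such a per-group pointer table could be assembled is sketched below on the host, assuming all groups' scales are packed back to back in a single buffer. `make_group_pointers` is a hypothetical helper, not part of vLLM or CUTLASS.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper: given one packed buffer holding every group's scales
// back to back, returns one pointer per group. In vector mode a group's
// length is its N extent; in scalar mode it is 1.
inline std::vector<const float*> make_group_pointers(
    const float* packed, const std::vector<std::size_t>& group_lengths) {
  assert(packed != nullptr);
  std::vector<const float*> ptrs;
  ptrs.reserve(group_lengths.size());
  std::size_t offset = 0;
  for (std::size_t len : group_lengths) {
    ptrs.push_back(packed + offset);  // start of this group's scales
    offset += len;
  }
  return ptrs;
}
```

On a GPU, `packed` would be a device allocation, and the resulting pointer table would presumably also have to be copied to device memory (for example with `cudaMemcpy`) before being passed as `ptr_row_array`; that step is omitted here.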