Implementation:Sgl project Sglang CUTLASS Epilogue Scale
| Knowledge Sources | |
|---|---|
| Domains | CUDA_Kernels, CUTLASS_Extensions, Quantized_Inference |
| Last Updated | 2026-02-10 00:00 GMT |
Overview
CUTLASS epilogue visitor that applies per-row and per-column scaling factors to GEMM accumulator output, enabling fused dequantization for SmoothQuant-style quantized inference.
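The reason the scaling can be fused into the epilogue follows from a short derivation. This is a sketch under a symmetric-quantization assumption; the symbols $X_q$, $W_q$, $s^{x}$, $s^{w}$ and $A$ are illustrative and do not appear in the source file. Here $X_q$ and $W_q$ are the quantized activation and weight matrices, $s^{x}_i$ is the scale of activation row $i$, $s^{w}_j$ is the scale of weight column $j$, and $A$ is the GEMM accumulator.

```latex
\begin{aligned}
X &\approx \operatorname{diag}(s^{x})\, X_q, \qquad W \approx W_q\, \operatorname{diag}(s^{w}), \\
Y_{ij} &= \sum_k X_{ik} W_{kj}
        \approx s^{x}_i\, s^{w}_j \sum_k (X_q)_{ik} (W_q)_{kj}
        = s^{x}_i\, s^{w}_j\, A_{ij}.
\end{aligned}
```

Because both scales factor out of the sum over $k$, they only need to be applied once per output element, after accumulation.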
Description
The EpilogueVisitorPerRowPerCol class template in the cutlass::epilogue::threadblock namespace implements the CUTLASS epilogue-visitor callback pattern, scaling the GEMM output along both its row and column axes. It is adapted from NVIDIA TensorRT-LLM and takes the following template parameters:
- ThreadblockShape_ -- the tile dimensions of the GEMM threadblock
- ThreadCount -- number of threads in the threadblock
- ScaleTileIterator_ -- tile iterator for loading per-row and per-column scale factors
- OutputTileIterator_ -- tile iterator for writing the scaled output
- ElementAccumulator_ -- accumulator element type (typically float)
- ElementCompute_ -- computation element type for scaling
- ElementwiseFunctor_ -- functor for the elementwise scaling operation
- UseMasking_ -- compile-time boolean that enables masking support (defaults to false)
The class defines nested Arguments and Params structs. Arguments carries the elementwise functor parameters together with three batch strides (batch_stride_alpha, batch_stride_C, batch_stride_D) used for batched GEMM. SharedStorage manages shared memory for scale factor caching. During the epilogue phase, the per-row and per-column alpha scales are loaded via tile iterators and multiplied into the accumulated GEMM results, so the dequantized output is written in a single fused pass without a separate dequantization kernel.
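The arithmetic of the fused pass can be checked against a plain host-side reference. The following is a minimal sketch, not code from the header: the function name scale_per_row_per_col and the row-major layout are assumptions made for illustration, and any elementwise functor parameters are ignored.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical host-side reference (not part of the sglang header).
// acc:       row-major M x N accumulator values from the GEMM
// alpha_row: M per-row scale factors (activation scales)
// alpha_col: N per-column scale factors (weight scales)
template <typename Acc>
std::vector<float> scale_per_row_per_col(const std::vector<Acc>& acc,
                                         const std::vector<float>& alpha_row,
                                         const std::vector<float>& alpha_col) {
    const std::size_t m = alpha_row.size();
    const std::size_t n = alpha_col.size();
    std::vector<float> out(m * n);
    for (std::size_t i = 0; i < m; ++i) {
        for (std::size_t j = 0; j < n; ++j) {
            // Each output element is scaled by its row scale and its column scale.
            out[i * n + j] = static_cast<float>(acc[i * n + j]) * alpha_row[i] * alpha_col[j];
        }
    }
    return out;
}
```

A reference like this is mainly useful in unit tests, where the output of the fused kernel can be compared element by element within a tolerance.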
Usage
Use this epilogue visitor when implementing quantized GEMM kernels that require per-row/per-column scaling (e.g., SmoothQuant W8A8 quantization). It is integrated into the CUTLASS GEMM kernel pipeline via the GemmWithEpilogueVisitor kernel.
Code Reference
Source Location
- Repository: Sgl_project_Sglang
- File: sgl-kernel/csrc/cutlass_extensions/epilogue/epilogue_per_row_per_col_scale.h
- Lines: 1-309
Signature
namespace cutlass::epilogue::threadblock {

template <
    typename ThreadblockShape_,
    int ThreadCount,
    typename ScaleTileIterator_,
    typename OutputTileIterator_,
    typename ElementAccumulator_,
    typename ElementCompute_,
    typename ElementwiseFunctor_,
    bool UseMasking_ = false>
class EpilogueVisitorPerRowPerCol {
 public:
    using ThreadblockShape = ThreadblockShape_;
    static int const kThreadCount = ThreadCount;
    using ScaleTileIterator = ScaleTileIterator_;
    using OutputTileIterator = OutputTileIterator_;

    static int const kIterations = OutputTileIterator::kIterations;
    static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess;

    using AccumulatorFragment = Array<ElementAccumulator_, kElementsPerAccess>;
    using ComputeFragment = Array<ElementCompute_, kElementsPerAccess>;

    struct Arguments {
        typename ElementwiseFunctor_::Params elementwise;
        int64_t batch_stride_alpha;
        int64_t batch_stride_C;
        int64_t batch_stride_D;
    };

    struct Params { ... };
    struct SharedStorage { ... };
};

}  // namespace cutlass::epilogue::threadblock
Import
#include "cutlass_extensions/epilogue/epilogue_per_row_per_col_scale.h"
// Underlying dependencies:
#include <cutlass/arch/memory.h>
#include <cutlass/numeric_conversion.h>
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| accumulator | AccumulatorFragment | Yes | GEMM accumulator values from the MMA stage |
| alpha_row | ScaleTileIterator | Yes | Per-row scale factors (activation scales) |
| alpha_col | ScaleTileIterator | Yes | Per-column scale factors (weight scales) |
| batch_stride_alpha | int64_t | No | Stride between batches for scale factors |
| batch_stride_C | int64_t | No | Stride between batches for input C |
| batch_stride_D | int64_t | No | Stride between batches for output D |
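The role of the three batch strides can be sketched as simple pointer arithmetic. This is an illustration only: batch_offset is a hypothetical helper, and it assumes the strides are counted in elements rather than bytes, which should be verified against the header.

```cpp
#include <cstdint>

// Hypothetical helper (not part of the sglang header): advance a base pointer
// to the start of the data belonging to one batch. The stride is assumed to be
// measured in elements of T.
template <typename T>
T* batch_offset(T* base, int batch_idx, int64_t batch_stride) {
    return base + static_cast<int64_t>(batch_idx) * batch_stride;
}
```

Under this convention a stride of 0 makes every batch read the same data, which is how a single set of scale factors could be shared across a batch.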
Outputs
| Name | Type | Description |
|---|---|---|
| output | OutputVector (Array<ElementOutput, kElementsPerAccess>) | Scaled output elements written to global memory |
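At the granularity of one vectorized access, the scaling can be pictured as follows. This is a host-compilable sketch rather than the visitor's actual code: std::array stands in for cutlass::Array, scale_fragment is a hypothetical name, and it assumes that one access of kElementsPerAccess elements lies within a single output row, so the row scale is a scalar while the column scales vary along the access.

```cpp
#include <array>
#include <cstddef>

// Stand-in for cutlass::Array<T, N>; std::array keeps the sketch host-compilable.
template <typename T, std::size_t N>
using Fragment = std::array<T, N>;

// Scale one accumulator fragment and convert it to the output element type.
// Assumption: all N elements belong to the same output row.
template <typename Out, typename Acc, std::size_t N>
Fragment<Out, N> scale_fragment(const Fragment<Acc, N>& accum,
                                const Fragment<float, N>& alpha_col,
                                float alpha_row) {
    Fragment<Out, N> result{};
    for (std::size_t i = 0; i < N; ++i) {
        // Scale in float, then narrow to the output type.
        const float v = static_cast<float>(accum[i]) * alpha_col[i] * alpha_row;
        result[i] = static_cast<Out>(v);
    }
    return result;
}
```

Doing the multiplication in the compute type and converting only at the end mirrors the split between ElementCompute_ and the output element type in the template parameters.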
Usage Examples
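// Define the epilogue visitor type. ThreadblockShape, kThreadCount,
// ScaleTileIterator, OutputTileIterator, and LinearCombination are assumed
// to be defined by the surrounding GEMM kernel configuration.
using EpilogueVisitor = cutlass::epilogue::threadblock::EpilogueVisitorPerRowPerCol<
    ThreadblockShape,
    kThreadCount,
    ScaleTileIterator,
    OutputTileIterator,
    float,              // ElementAccumulator
    float,              // ElementCompute
    LinearCombination,  // ElementwiseFunctor
    false>;             // UseMasking

// Set up arguments with batch strides
typename EpilogueVisitor::Arguments epilogue_args{
    {alpha, beta},       // elementwise params
    batch_stride_alpha,  // stride between batches for scale factors
    batch_stride_C,      // stride between batches for input C
    batch_stride_D       // stride between batches for output D
};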