

Implementation:Vllm project Vllm Broadcast Load Epilogue Array C3X

From Leeroopedia
Revision as of 17:05, 16 February 2026 by Admin (auto-imported from implementations/Vllm_project_Vllm_Broadcast_Load_Epilogue_Array_C3X.md)


Knowledge Sources
Domains: CUTLASS Epilogue, Quantized GEMM
Last Updated: 2026-02-08 00:00 GMT

Overview

Implements CUTLASS 3.x epilogue visitors for broadcasting scales and biases from device pointers during quantized GEMM operations on Hopper (SM90+) GPUs.

Description

This header extends the CUTLASS Sm90 epilogue fusion framework so that quantization scales and biases can be loaded and broadcast flexibly. Sm90RowOrScalarBroadcastArray broadcasts either a row vector (one value per output column, indexed along N) or a single scalar. Sm90ColOrScalarBroadcastArray broadcasts either a column vector (one value per output row, indexed along M) or a scalar. Both structs take an array of device pointers, one per batch or group, so batched and grouped GEMMs can use per-group scales. A key design decision is a boolean flag (row_broadcast in the row variant) that switches between vector and scalar modes at runtime. In scalar mode the value is still read from device memory through the pointer array. This avoids CPU-resident scalar values, which would break torch.compile graph compilation.
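The selection logic can be modeled on the host in plain C++. This is a simplified sketch, not the CUTLASS implementation. The names broadcast_value and BroadcastAxis are hypothetical, and std::vector<const float*> stands in for the device pointer array. The sketch also assumes that scalar mode reads the element that the group's pointer points to.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical host-side model (not CUTLASS code) of the value the two broadcast
// nodes supply for output element (m, n) of group l.
enum class BroadcastAxis { Row, Col };  // Row: indexed by n, Col: indexed by m

float broadcast_value(const std::vector<const float*>& ptr_array,
                      bool vector_mode, BroadcastAxis axis,
                      std::size_t l, std::size_t m, std::size_t n) {
  assert(l < ptr_array.size() && ptr_array[l] != nullptr);
  const float* group_ptr = ptr_array[l];  // one pointer per batch/group
  if (!vector_mode) {
    // Scalar mode (assumed): the single value the group's pointer refers to.
    return group_ptr[0];
  }
  // Vector mode: a row vector varies with n, a column vector varies with m.
  return axis == BroadcastAxis::Row ? group_ptr[n] : group_ptr[m];
}
```

Because the mode is a runtime argument rather than a template parameter, one compiled function covers both cases. The flag in the real structs plays the same role.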

Usage

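This header is included by CUTLASS-based quantized GEMM kernels in the vLLM CUDA backend. It is used for quantized inference on NVIDIA Hopper (SM90+) GPUs. It lets the same compiled kernel serve per-tensor, per-channel, and per-token quantization schemes:
- per-channel scales are row (N) broadcasts;
- per-token scales are column (M) broadcasts;
- per-tensor scales use the scalar mode of either struct.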
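The following host-side sketch shows how a per-token scale on A and a per-channel scale on B combine in a dequantizing epilogue, each with a scalar fallback. It assumes an epilogue of the form D[m][n] = a_scale(m) * b_scale(n) * acc[m][n] with an integer accumulator, as for int8 inputs. The name dequantize is hypothetical, and the multiplication order is illustrative. The real kernel does this inside the fused epilogue, not in a host loop.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical host model of a scaled epilogue over a row-major M x N accumulator.
// a_scale acts as a column broadcast: per-token (M values) or a single scalar.
// b_scale acts as a row broadcast: per-channel (N values) or a single scalar.
std::vector<float> dequantize(const std::vector<int>& acc, std::size_t M, std::size_t N,
                              const std::vector<float>& a_scale, bool a_per_token,
                              const std::vector<float>& b_scale, bool b_per_channel) {
  assert(acc.size() == M * N);
  assert(a_scale.size() == (a_per_token ? M : 1));
  assert(b_scale.size() == (b_per_channel ? N : 1));
  std::vector<float> out(M * N);
  for (std::size_t m = 0; m < M; ++m) {
    for (std::size_t n = 0; n < N; ++n) {
      float sa = a_scale[a_per_token ? m : 0];    // column broadcast along M
      float sb = b_scale[b_per_channel ? n : 0];  // row broadcast along N
      out[m * N + n] = sa * sb * static_cast<float>(acc[m * N + n]);
    }
  }
  return out;
}
```

In this model, the per-tensor case and the per-token/per-channel case go through the same function and differ only in the two flags. This mirrors how one compiled kernel can serve all three schemes.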

Code Reference

Source Location

Signature

namespace cutlass::epilogue::fusion {

// Row or scalar broadcast for quantization scales along N dimension
template<
  int Stages,
  class CtaTileShapeMNK,
  class Element,
  class StrideMNL = Stride<_0,_1,_0>,
  int Alignment = 128 / sizeof_bits_v<Element>
>
struct Sm90RowOrScalarBroadcastArray {
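  struct Arguments {
    const Element* const* ptr_row_array = nullptr;  // one device pointer per batch/group
    bool row_broadcast = true;                      // true: row vector, false: scalar
    StrideMNL dRow = {};                            // stride of the broadcast tensor (M, N, L)
  };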

  using Params = Arguments;

  template <class ProblemShape>
  static constexpr Params to_underlying_arguments(
      ProblemShape const& problem_shape, Arguments const& args, void* workspace);

  template <class ProblemShape>
  static bool can_implement(ProblemShape const& problem_shape, Arguments const& args);

  template <class... Args>
  struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks {
    CUTLASS_DEVICE void begin();
    CUTLASS_DEVICE void begin_loop(int epi_m, int epi_n);
    template <typename ElementAccumulator, int FragmentSize>
    CUTLASS_DEVICE Array<Element, FragmentSize>
    visit(Array<ElementAccumulator, FragmentSize> const& frg_acc,
          int epi_v, int epi_m, int epi_n);
  };
};

// Column or scalar broadcast for quantization scales along M dimension
template<
  int Stages,
  class CtaTileShapeMNK,
  class Element,
  class StrideMNL = Stride<_1,_0,_0>,
  int Alignment = 128 / sizeof_bits_v<Element>
>
struct Sm90ColOrScalarBroadcastArray;

} // namespace cutlass::epilogue::fusion
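The consumer callbacks split the work into per-tile setup (begin), per-subtile selection (begin_loop), and per-fragment reads (visit). The following single-threaded model is a rough illustration of that lifecycle. It is hypothetical and simplified:
- Real callbacks partition the tile across threads following the MMA layout, which this model ignores.
- This model fills columns past N with zero.
- This model drops the frg_acc parameter of visit().

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical single-thread model of a row-or-scalar broadcast node's callbacks.
template <int TileN, int EpiN, int FragmentSize>
struct RowOrScalarCallbacksModel {
  static_assert(TileN % EpiN == 0 && EpiN % FragmentSize == 0, "tiles must divide evenly");
  std::array<float, TileN> tile{};  // "registers" holding this CTA tile's N-slice
  int epi_offset = 0;               // start of the current epilogue subtile

  // begin(): load once per CTA tile, or fill the whole tile with the scalar.
  void begin(const float* group_ptr, bool row_broadcast, int n0, int N) {
    for (int i = 0; i < TileN; ++i) {
      if (!row_broadcast) tile[i] = group_ptr[0];
      else tile[i] = (n0 + i < N) ? group_ptr[n0 + i] : 0.0f;  // model choice: zero past N
    }
  }

  // begin_loop(): select the epilogue subtile along N (epi_m is unused for rows).
  void begin_loop(int /*epi_m*/, int epi_n) { epi_offset = epi_n * EpiN; }

  // visit(): return the epi_v-th fragment of the current subtile.
  std::array<float, FragmentSize> visit(int epi_v) const {
    assert(epi_offset + (epi_v + 1) * FragmentSize <= TileN);
    std::array<float, FragmentSize> frg{};
    for (int i = 0; i < FragmentSize; ++i) {
      frg[i] = tile[epi_offset + epi_v * FragmentSize + i];
    }
    return frg;
  }
};
```

The point of the split is that the global-memory load (or scalar fill) happens once per CTA tile. The per-fragment calls then only read values that are already loaded.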

Import

#include "cutlass/cutlass.h"
#include "cutlass/arch/barrier.h"
#include "cute/tensor.hpp"
#include "cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp"

I/O Contract

Inputs

Name | Type | Required | Description
ptr_row_array | const Element* const* | Yes | Array of device pointers to scale/bias vectors (one per batch/group)
row_broadcast | bool | No (defaults to true) | If true, load a row vector; if false, broadcast the value each group's pointer refers to as a scalar
dRow | StrideMNL | No | Stride descriptor for the broadcast tensor layout
frg_acc | Array<ElementAccumulator, FragmentSize> | Yes (a visit() parameter, not an Arguments field) | Accumulator fragment from the GEMM mainloop
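ptr_row_array holds one pointer per group. A common way to fill it is to point into a single buffer that stores every group's scales back to back. The following sketch assumes that packed layout; make_ptr_array is a hypothetical helper. A real launch would point into a device buffer and copy the pointer array itself to device memory, but host pointers keep the sketch runnable.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper: build the per-group pointer array from one packed buffer
// in which group g's scales occupy group_sizes[g] consecutive elements.
std::vector<const float*> make_ptr_array(const std::vector<float>& packed,
                                         const std::vector<std::size_t>& group_sizes) {
  std::vector<const float*> ptrs;
  ptrs.reserve(group_sizes.size());
  std::size_t offset = 0;
  for (std::size_t len : group_sizes) {
    assert(offset + len <= packed.size());
    ptrs.push_back(packed.data() + offset);  // start of this group's scales
    offset += len;
  }
  return ptrs;
}
```

Arguments holds a single row_broadcast flag, so every group in one launch uses the same mode. All groups must therefore supply vectors, or all must supply scalars.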

Outputs

Name | Type | Description
(return value of visit()) | Array<Element, FragmentSize> | Broadcast scale/bias fragment to be applied to the GEMM output

Usage Examples

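// Define epilogue with row-broadcast scales for quantized GEMM
// (assumes this header and cute/tensor.hpp are included)
using namespace cute;

// Illustrative CTA tile shape (M, N, K); use the shape your kernel is built with
using CtaTileShape = Shape<_128, _128, _128>;

using EpilogueScale = cutlass::epilogue::fusion::Sm90RowOrScalarBroadcastArray<
    0,                    // Stages (0 for row broadcast)
    CtaTileShape,         // CTA tile shape MNK
    float,                // Element type for scales
    Stride<_0, _1, _0>>;  // Row stride: varies along N only

// Configure arguments (make_scale_args is a hypothetical helper).
// device_scale_ptrs is a device-resident array with one scale pointer per group.
EpilogueScale::Arguments make_scale_args(float const* const* device_scale_ptrs,
                                         bool per_channel) {
  EpilogueScale::Arguments scale_args;
  scale_args.ptr_row_array = device_scale_ptrs;  // array of per-group scale pointers
  scale_args.row_broadcast = per_channel;        // true: vector mode, false: scalar
  return scale_args;
}

// Pass scale_args to the CUTLASS kernel launch as part of the epilogue fusion arguments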
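Host code does not have to hard-code row_broadcast. It can derive the flag from how many scale values each group supplies. The following is a minimal sketch under that assumption. choose_row_broadcast is a hypothetical helper, not part of CUTLASS or vLLM.

```cpp
#include <cassert>
#include <cstddef>
#include <stdexcept>

// Hypothetical helper: one value per group selects scalar mode, N values select
// row-vector mode, and any other count is rejected.
bool choose_row_broadcast(std::size_t scales_per_group, std::size_t N) {
  assert(N > 0);
  if (scales_per_group == 1) return false;  // per-tensor: broadcast one scalar
  if (scales_per_group == N) return true;   // per-channel: one scale per column
  throw std::invalid_argument("scale count must be 1 or N");
}
```

Rejecting other counts up front keeps a mis-shaped scale tensor from being read as a length-N vector on the device.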
