Implementation:Vllm project Vllm Broadcast Load Epilogue C3X

From Leeroopedia


Knowledge Sources
Domains: CUTLASS, Epilogue, Quantization, GEMM
Last Updated: 2026-02-08 00:00 GMT

Overview

Implements CUTLASS 3.x epilogue visitors for broadcasting quantization scales and biases from device pointers during GEMM operations on Hopper (SM90+) GPUs.

Description

This file is a modified excerpt of CUTLASS v3.5.0 sm90_visitor_load_tma_warpspecialized.hpp. The adaptation lets a single visitor perform either a vector broadcast (row or column) or a scalar broadcast, selected at run time by a boolean flag, with the tensor always passed via a device pointer. It leverages shared memory and TMA-based data movement on Hopper GPUs for high-performance broadcast loading. Like its C2X counterpart, this design keeps all scale tensors on the device, which avoids the torch.compile graph breaks that CPU-resident scalars would cause.
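To make the two broadcast modes concrete, the following is a minimal host-side sketch in plain C++, not the CUTLASS visitor itself. The function names `load_row_or_scalar` and `scale_tile` are hypothetical; only the pointer argument and the `row_broadcast` flag mirror the `Arguments` struct documented on this page. It also assumes the broadcast value is applied as a multiplicative scale, which is one possible use.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Host-side illustration only; the real visitor runs inside a CUDA kernel.
// Returns the N values that every row of an M x N tile would see.
// `ptr` stands in for the device pointer `ptr_row`; `row_broadcast` mirrors
// the runtime flag in Sm90RowOrScalarBroadcast::Arguments.
std::vector<float> load_row_or_scalar(const float* ptr, bool row_broadcast,
                                      std::size_t n_cols) {
  assert(ptr != nullptr);
  std::vector<float> row(n_cols);
  for (std::size_t n = 0; n < n_cols; ++n) {
    // Vector case: one value per column. Scalar case: always element 0.
    row[n] = row_broadcast ? ptr[n] : ptr[0];
  }
  return row;
}

// Multiplies each element of a row-major M x N accumulator tile by the
// broadcast value for its column (assumed use: a dequantization scale).
void scale_tile(std::vector<float>& acc, std::size_t m_rows,
                std::size_t n_cols, const float* ptr, bool row_broadcast) {
  assert(acc.size() == m_rows * n_cols);
  const std::vector<float> row = load_row_or_scalar(ptr, row_broadcast, n_cols);
  for (std::size_t m = 0; m < m_rows; ++m) {
    for (std::size_t n = 0; n < n_cols; ++n) {
      acc[m * n_cols + n] *= row[n];
    }
  }
}
```

Because the scalar case also reads through the pointer, the same call site serves per-tensor and per-channel scales without a host-side value ever being needed.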

Usage

This header is included by scaled_mm_epilogues_c3x.hpp and compiled as part of the CUTLASS 3.x quantized GEMM kernels targeting NVIDIA Hopper (SM90a) and later architectures. It is used whenever vLLM performs scaled matrix multiplication with INT8 or FP8 operands on Hopper+ GPUs.
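As a rough picture of what such a scaled matrix multiplication computes, here is a host-side reference in plain C++. It is a sketch under assumptions: the exact epilogue composition lives in scaled_mm_epilogues_c3x.hpp, which is not shown on this page, so the formula D = a_scale * b_scale * (A x B), the function name `scaled_mm_ref`, and the `per_token` / `per_channel` flags are all illustrative. The per-row scale corresponds to a column-or-scalar broadcast and the per-column scale to a row-or-scalar broadcast.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Reference scaled matmul on the host (assumed formula, see lead-in).
// A is M x K row-major int8, B is K x N row-major int8.
// a_scale holds M values if per_token, else 1 value (column-or-scalar).
// b_scale holds N values if per_channel, else 1 value (row-or-scalar).
std::vector<float> scaled_mm_ref(const std::vector<std::int8_t>& a,
                                 const std::vector<std::int8_t>& b,
                                 std::size_t m, std::size_t n, std::size_t k,
                                 const std::vector<float>& a_scale,
                                 bool per_token,
                                 const std::vector<float>& b_scale,
                                 bool per_channel) {
  assert(a.size() == m * k && b.size() == k * n);
  assert(a_scale.size() == (per_token ? m : 1));
  assert(b_scale.size() == (per_channel ? n : 1));
  std::vector<float> d(m * n);
  for (std::size_t i = 0; i < m; ++i) {
    for (std::size_t j = 0; j < n; ++j) {
      // Integer accumulation, as in an INT8 GEMM mainloop.
      std::int32_t acc = 0;
      for (std::size_t p = 0; p < k; ++p) {
        acc += static_cast<std::int32_t>(a[i * k + p]) *
               static_cast<std::int32_t>(b[p * n + j]);
      }
      // Epilogue: pick vector element or the single scalar per operand.
      const float sa = per_token ? a_scale[i] : a_scale[0];
      const float sb = per_channel ? b_scale[j] : b_scale[0];
      d[i * n + j] = sa * sb * static_cast<float>(acc);
    }
  }
  return d;
}
```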

Code Reference

Source Location

Signature

template<
  int Stages,
  class CtaTileShapeMNK,
  class Element,
  class StrideMNL = Stride<_0,_1,_0>,
  int Alignment = 128 / sizeof_bits_v<Element>
>
struct Sm90RowOrScalarBroadcast {
  struct SharedStorage {
    array_aligned<Element, size<1>(CtaTileShapeMNK{})> smem;
  };
  struct Arguments {
    Element const* ptr_row = nullptr;
    bool row_broadcast = true;
    StrideMNL dRow = {};
  };
  using Params = Arguments;
};

template<
  int Stages,
  class CtaTileShapeMNK,
  class Element,
  class StrideMNL = Stride<_1,_0,_0>,
  int Alignment = 128 / sizeof_bits_v<Element>
>
struct Sm90ColOrScalarBroadcast {
  struct SharedStorage {
    array_aligned<Element, size<0>(CtaTileShapeMNK{})> smem;
  };
  struct Arguments {
    Element const* ptr_col = nullptr;
    bool col_broadcast = true;
    StrideMNL dCol = {};
  };
  using Params = Arguments;
};
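The default `Alignment` and the size of the `smem` staging array both follow from simple arithmetic on the element type and the tile extent. The sketch below reproduces that arithmetic in standalone C++. `bits_of` is a stand-in for `sizeof_bits_v` that is only valid for ordinary byte-addressable types, `smem_bytes` is a hypothetical helper that ignores any padding `array_aligned` may add, and the tile extent of 128 is an assumed example value.

```cpp
#include <climits>
#include <cstddef>
#include <cstdint>

// Stand-in for cutlass::sizeof_bits_v, valid for byte-addressable types only.
template <class T>
constexpr int bits_of = static_cast<int>(sizeof(T) * CHAR_BIT);

// Default Alignment from the signature: elements per 128-bit access.
template <class T>
constexpr int default_alignment = 128 / bits_of<T>;

// Payload bytes of a staging buffer with one element per tile column
// (row broadcast, size<1>) or per tile row (column broadcast, size<0>).
template <class T>
constexpr std::size_t smem_bytes(std::size_t tile_extent) {
  return tile_extent * sizeof(T);
}

static_assert(default_alignment<float> == 4, "four floats per 128 bits");
static_assert(default_alignment<std::uint16_t> == 8,
              "eight 16-bit values (same width as half) per 128 bits");
static_assert(smem_bytes<float>(128) == 512,
              "assumed 128-wide tile extent of floats");
```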

Import

#include "cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp"

I/O Contract

Inputs

| Name | Type | Required | Description |
|---|---|---|---|
| ptr_row / ptr_col | Element const* | Yes | Device pointer to the broadcast tensor (scale, bias, or zero-point) |
| row_broadcast / col_broadcast | bool | No | When true, loads from a vector; when false, broadcasts a scalar (default: true) |
| dRow / dCol | StrideMNL | No | Stride descriptor for the broadcast tensor layout |
| Stages | int | Yes | Number of pipeline stages (must be 0 for row/col broadcast) |
| CtaTileShapeMNK | template param | Yes | CTA tile shape defining M, N, K dimensions |
| Element | template param | Yes | Data type of the broadcast tensor (e.g., float, half) |
| Alignment | int | No | Memory alignment for vector loads (default: 128 / sizeof_bits) |
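The default stride descriptors, `Stride<_0,_1,_0>` for rows and `Stride<_1,_0,_0>` for columns, encode the broadcast itself: a stride of 0 in a mode means every coordinate along that mode maps to the same element. The following plain-C++ sketch illustrates that mapping. The `StrideMNL` alias here is a `std::array` stand-in for the CuTe stride tuple, and `broadcast_offset` is a hypothetical helper, not a CUTLASS function.

```cpp
#include <array>
#include <cassert>
#include <cstddef>

// Plain-C++ stand-in for a CuTe (M, N, L) stride tuple.
using StrideMNL = std::array<std::size_t, 3>;

// Linear offset of logical coordinate (m, n, l) into the broadcast tensor.
constexpr std::size_t broadcast_offset(const StrideMNL& s, std::size_t m,
                                       std::size_t n, std::size_t l) {
  return s[0] * m + s[1] * n + s[2] * l;
}

// Row vector: varies along N only, shared by every row and batch.
constexpr StrideMNL kRowStride{0, 1, 0};
// Column vector: varies along M only, shared by every column and batch.
constexpr StrideMNL kColStride{1, 0, 0};

static_assert(broadcast_offset(kRowStride, 7, 3, 0) == 3,
              "row broadcast ignores the M coordinate");
static_assert(broadcast_offset(kColStride, 7, 3, 0) == 7,
              "column broadcast ignores the N coordinate");
```

A non-zero L stride would select a different vector per batch; the defaults shown on this page use 0 there, so all batches share one vector.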

Outputs

| Name | Type | Description |
|---|---|---|
| visit() return | Array<Element, FragmentSize> | Fragment of broadcast values loaded via shared memory into registers for epilogue computation |

Usage Examples

// Using Sm90RowOrScalarBroadcast in a CUTLASS 3.x epilogue
using ScaleB = cutlass::epilogue::fusion::Sm90RowOrScalarBroadcast<
    0 /*Stages*/, TileShape, float, Stride<Int<0>, Int<1>, Int<0>>>;

// Construct arguments for per-channel scale broadcast
typename ScaleB::Arguments scale_args{
    scale_data_ptr,    // device pointer to scale tensor
    true,              // row_broadcast = true for per-channel
    {}                 // default stride
};

// Construct arguments for per-tensor scalar broadcast
typename ScaleB::Arguments scalar_args{
    scalar_data_ptr,   // device pointer to single scalar
    false,             // row_broadcast = false for per-tensor
    {}                 // default stride
};
