
Implementation:Vllm project Vllm Marlin MMA

From Leeroopedia


Knowledge Sources
Domains Quantization, CUDA_Kernels, Tensor_Core
Last Updated 2026-02-08 00:00 GMT

Overview

Provides inline PTX wrappers for CUDA tensor core matrix multiply-accumulate (MMA) instructions used by the Marlin quantized kernel framework.

Description

This header defines two templated device functions, mma and mma_trans, that wrap NVIDIA PTX inline assembly for warp-level matrix multiply-accumulate operations. The functions support multiple scalar types (FP16, BF16, FP8 E4M3, INT8) with configurable accumulator precision (FP32 or FP16) and K-dimension sizes (16 or 32). Architecture-specific code paths handle SM75 (Turing) with m16n8k8 instructions and SM80+ (Ampere/Hopper) with m16n8k16 or m16n8k32 instructions.
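As a rough illustration of what such a wrapper looks like (this is a hedged sketch, not the vLLM source: the function name and raw-register signature are hypothetical, while the PTX instruction itself is the standard SM80+ m16n8k16 FP16 MMA with FP32 accumulation):

```cuda
#include <cstdint>

// Hypothetical standalone sketch of the FP16 input / FP32 accumulate path.
// Per the PTX ISA, A occupies 4 x .f16x2 registers, B occupies 2 x .f16x2
// registers, and C/D are 4 x f32 registers. The "+f" constraints make the
// accumulator a read-modify-write operand (D = A*B + C in place).
__device__ inline void mma_m16n8k16_f16_f32(const uint32_t (&a)[4],
                                            const uint32_t (&b)[2],
                                            float (&c)[4]) {
  asm volatile(
      "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
      "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%0,%1,%2,%3};\n"
      : "+f"(c[0]), "+f"(c[1]), "+f"(c[2]), "+f"(c[3])
      : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]),
        "r"(b[0]), "r"(b[1]));
}
```

On SM75 (Turing) the equivalent path must instead issue two m16n8k8 instructions, since the m16n8k16 shape is only available on SM80 and later.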

Usage

Use these functions as the low-level matrix multiplication building blocks inside Marlin quantized GEMM kernels. They are called from warp-level loops in Marlin's tiled GEMM implementation to perform the actual tensor core computations on quantized weight fragments.
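A minimal sketch of how such a call site might look inside a warp-level tile loop (the array sizes, loop bounds, and function name here are illustrative assumptions, not Marlin's actual tiling):

```cuda
// Hypothetical shape of a Marlin-style warp-level inner loop. Each
// iteration issues one tensor core MMA per (m-tile, n-tile) pair, and
// frag_c accumulates in place across the K dimension.
template <vllm::ScalarTypeId type_id>
__device__ void warp_tile_gemm_sketch(
    typename MarlinScalarType<type_id>::FragA (&frag_a)[4],
    typename MarlinScalarType<type_id>::FragB (&frag_b)[4],
    typename MarlinScalarType<type_id>::FragC (&frag_c)[4][4]) {
  #pragma unroll
  for (int m = 0; m < 4; ++m) {
    #pragma unroll
    for (int n = 0; n < 4; ++n) {
      mma<type_id, /*use_fp16_accum=*/false>(frag_a[m], frag_b[n],
                                             frag_c[m][n]);
    }
  }
}
```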

Code Reference

Source Location

Signature

namespace MARLIN_NAMESPACE_NAME {

// Standard MMA: m16n8k{16,32} tensor core instruction
template <vllm::ScalarTypeId type_id, bool use_fp16_accum, int k_size = 16>
__device__ inline void mma(
    const typename MarlinScalarType<type_id>::FragA& a_frag,
    const typename MarlinScalarType<type_id>::FragB& frag_b,
    typename MarlinScalarType<type_id>::FragC& frag_c, int idx = 0);

// Transposed MMA: computes with transposed B operand (two B fragments)
template <vllm::ScalarTypeId type_id, bool use_fp16_accum, int k_size = 16>
__device__ inline void mma_trans(
    const typename MarlinScalarType<type_id>::FragA& a_frag,
    const typename MarlinScalarType<type_id>::FragB& frag_b,
    const typename MarlinScalarType<type_id>::FragB& frag_b2,
    typename MarlinScalarType<type_id>::FragC& frag_c);

}  // namespace MARLIN_NAMESPACE_NAME

Import

#include "marlin_dtypes.cuh"
// Then include this header in Marlin kernel files:
#include "marlin_mma.h"

I/O Contract

Inputs

Name Type Required Description
a_frag MarlinScalarType<type_id>::FragA Yes Fragment A register tile (activations)
frag_b MarlinScalarType<type_id>::FragB Yes Fragment B register tile (weights)
frag_b2 MarlinScalarType<type_id>::FragB mma_trans only Second fragment B for transposed MMA
frag_c MarlinScalarType<type_id>::FragC Yes Accumulator fragment (read-modify-write)
idx int No Sub-tile index for FP8/INT8 k_size=16 variants (default 0)

Outputs

Name Type Description
frag_c MarlinScalarType<type_id>::FragC Updated accumulator fragment with MMA result added in-place

Usage Examples

// Inside a Marlin kernel warp-level loop (FP16 with FP32 accumulation):
using ScalarType = MarlinScalarType<vllm::kFp16>;
ScalarType::FragA a_frag;
ScalarType::FragB b_frag;
ScalarType::FragC c_frag = {};  // zero-initialized accumulator

// Load fragments from shared memory...
// Then perform tensor core MMA:
mma<vllm::kFp16, /*use_fp16_accum=*/false>(a_frag, b_frag, c_frag);

// For INT8 with k_size=32:
mma<vllm::kInt8, /*use_fp16_accum=*/false, /*k_size=*/32>(
    a_frag_i8, b_frag_i8, c_frag_i32);
