Implementation: vLLM Marlin MMA
| Knowledge Sources | |
|---|---|
| Domains | Quantization, CUDA_Kernels, Tensor_Core |
| Last Updated | 2026-02-08 00:00 GMT |
Overview
Provides inline PTX assembly wrappers for CUDA tensor core matrix multiply-accumulate (MMA) instructions used by the Marlin quantized kernel framework.
Description
This header defines two templated device functions, mma and mma_trans, that wrap NVIDIA PTX inline assembly for warp-level matrix multiply-accumulate operations. The functions support multiple scalar types (FP16, BF16, FP8 E4M3, INT8) with configurable accumulator precision (FP32 or FP16) and K-dimension sizes (16 or 32). Architecture-specific code paths handle SM75 (Turing) with m16n8k8 instructions and SM80+ (Ampere/Hopper) with m16n8k16 or m16n8k32 instructions.
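The general shape of such a wrapper can be sketched as follows. This is a minimal illustration for FP16 inputs with FP32 accumulation on SM80+ using the m16n8k16 instruction; the function name, fragment layout, and register constraints here are illustrative, not the actual Marlin source.

```cuda
#include <cstdint>

// Hedged sketch of a warp-level MMA wrapper: each thread holds 4 A registers
// (8 halves), 2 B registers (4 halves), and 4 FP32 accumulator values of the
// 16x8x16 tile. The accumulator is updated in place (D = A*B + C).
__device__ inline void mma_f16_f32_sketch(const uint32_t (&a)[4],
                                          const uint32_t (&b)[2],
                                          float (&c)[4]) {
  asm volatile(
      "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 "
      "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%0,%1,%2,%3};\n"
      : "+f"(c[0]), "+f"(c[1]), "+f"(c[2]), "+f"(c[3])
      : "r"(a[0]), "r"(a[1]), "r"(a[2]), "r"(a[3]),
        "r"(b[0]), "r"(b[1]));
}
```

On SM75 the same logical tile would instead be covered by two m16n8k8 instructions, which is why the header carries architecture-specific code paths.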
Usage
Use these functions as the low-level matrix multiplication building blocks inside Marlin quantized GEMM kernels. They are called from warp-level loops in Marlin's tiled GEMM implementation to perform the actual tensor core computations on quantized weight fragments.
Code Reference
Source Location
- Repository: vllm
- File: csrc/quantization/marlin/marlin_mma.h
- Lines: 1-269
Signature
namespace MARLIN_NAMESPACE_NAME {
// Standard MMA: m16n8k{16,32} tensor core instruction
template <vllm::ScalarTypeId type_id, bool use_fp16_accum, int k_size = 16>
__device__ inline void mma(
const typename MarlinScalarType<type_id>::FragA& a_frag,
const typename MarlinScalarType<type_id>::FragB& frag_b,
typename MarlinScalarType<type_id>::FragC& frag_c, int idx = 0);
// Transposed MMA: computes with transposed B operand (two B fragments)
template <vllm::ScalarTypeId type_id, bool use_fp16_accum, int k_size = 16>
__device__ inline void mma_trans(
const typename MarlinScalarType<type_id>::FragA& a_frag,
const typename MarlinScalarType<type_id>::FragB& frag_b,
const typename MarlinScalarType<type_id>::FragB& frag_b2,
typename MarlinScalarType<type_id>::FragC& frag_c);
} // namespace MARLIN_NAMESPACE_NAME
Import
#include "marlin_dtypes.cuh"
// Then include this header in Marlin kernel files:
#include "marlin_mma.h"
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| a_frag | MarlinScalarType<type_id>::FragA | Yes | Fragment A register tile (activations) |
| frag_b | MarlinScalarType<type_id>::FragB | Yes | Fragment B register tile (weights) |
| frag_b2 | MarlinScalarType<type_id>::FragB | mma_trans only | Second fragment B for transposed MMA |
| frag_c | MarlinScalarType<type_id>::FragC | Yes | Accumulator fragment (read-modify-write) |
| idx | int | No | Sub-tile index for FP8/INT8 k_size=16 variants (default 0) |
Outputs
| Name | Type | Description |
|---|---|---|
| frag_c | MarlinScalarType<type_id>::FragC | Updated accumulator fragment with MMA result added in-place |
Usage Examples
// Inside a Marlin kernel warp-level loop (FP16 with FP32 accumulation):
using ScalarType = MarlinScalarType<vllm::kFp16>;
ScalarType::FragA a_frag;
ScalarType::FragB b_frag;
ScalarType::FragC c_frag = {}; // zero-initialized accumulator
// Load fragments from shared memory...
// Then perform tensor core MMA:
mma<vllm::kFp16, /*use_fp16_accum=*/false>(a_frag, b_frag, c_frag);
// For INT8 with k_size=32:
mma<vllm::kInt8, /*use_fp16_accum=*/false, /*k_size=*/32>(
a_frag_i8, b_frag_i8, c_frag_i32);
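Two further call shapes follow directly from the signatures above; the fragment variable names here are illustrative:

```cuda
// Transposed MMA consumes two B fragments (per the mma_trans signature):
mma_trans<vllm::kFp16, /*use_fp16_accum=*/false>(a_frag, b_frag, b_frag2,
                                                 c_frag);

// FP8/INT8 variants with k_size=16 accept the optional idx argument to
// select the sub-tile being accumulated:
mma<vllm::kInt8, /*use_fp16_accum=*/false, /*k_size=*/16>(
    a_frag_i8, b_frag_i8, c_frag_i32, /*idx=*/1);
```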