# Implementation: vLLM DNNL Helper
| Knowledge Sources | |
|---|---|
| Domains | CPU_Inference, DNNL |
| Last Updated | 2026-02-08 00:00 GMT |
## Overview
Implements an LRU cache for OneDNN (DNNL) matrix multiplication primitives, avoiding expensive primitive recompilation during CPU inference.
## Description
This file provides the DNNLPrimitiveCache template class, which maintains an LRU-ordered collection of compiled OneDNN primitives keyed by their configuration parameters. It also implements DNNLMatMulPrimitiveHandler for weight prepacking and runtime memory management, along with derived handlers W8A8MatMulPrimitiveHandler and MatMulPrimitiveHandler for INT8 quantized and floating-point matmul operations respectively.
The cache combines a std::list with a std::unordered_map to provide O(1) lookup and O(1) eviction. Reusing compiled primitives avoids repeated just-in-time primitive generation in OneDNN, which is essential for low-latency CPU serving.
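The list-plus-map combination described above can be sketched as follows. This is a minimal illustration of the `get_or_create` pattern only, not the vLLM implementation; the class name, member layout, and eviction details here are assumptions:

```cpp
#include <cassert>
#include <list>
#include <memory>
#include <string>
#include <unordered_map>

// Minimal LRU cache sketch: a std::list keeps keys in recency order, and a
// std::unordered_map points each key at its list node plus the cached value,
// so both lookup and eviction are O(1).
template <typename KT, typename VT>
class LRUPrimitiveCache {
 public:
  using result_value_t = std::shared_ptr<VT>;

  explicit LRUPrimitiveCache(size_t capacity) : capacity_(capacity) {}

  // On a hit, move the key to the front (most recently used) and return the
  // cached value; on a miss, invoke the creator and evict the back if full.
  template <typename F>
  result_value_t get_or_create(const KT& key, F&& creator) {
    auto it = map_.find(key);
    if (it != map_.end()) {
      order_.splice(order_.begin(), order_, it->second.first);
      return it->second.second;
    }
    if (map_.size() >= capacity_) {
      map_.erase(order_.back());
      order_.pop_back();
    }
    order_.push_front(key);
    auto value = std::make_shared<VT>(creator());
    map_.emplace(key, std::make_pair(order_.begin(), value));
    return value;
  }

  size_t size() const { return map_.size(); }

 private:
  size_t capacity_;
  std::list<KT> order_;
  std::unordered_map<KT, std::pair<typename std::list<KT>::iterator,
                                   result_value_t>> map_;
};
```

The `splice` call is what makes the hit path O(1): it relinks the existing list node to the front without invalidating the iterator stored in the map.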
## Usage
This code is compiled as part of the vLLM CPU backend extension. It is used internally by the DNNL kernel layer (dnnl_kernels.cpp) whenever OneDNN matmul primitives are needed, such as during onednn_mm and onednn_qmatmul calls for CPU-based inference.
## Code Reference
### Source Location
- Repository: vllm
- File: csrc/cpu/dnnl_helper.cpp
- Lines: 1-569
### Signature
```cpp
void release_dnnl_matmul_handler(int64_t handler);

template <typename KT, typename VT>
class DNNLPrimitiveCache {
  DNNLPrimitiveCache(size_t capacity);

  template <typename F>
  result_value_t get_or_create(const KT& key, F&& creator);
};

DNNLMatMulPrimitiveHandler::DNNLMatMulPrimitiveHandler(
    const Args& args, dnnl::memory::data_type b_type);

void DNNLMatMulPrimitiveHandler::prepack_weight(
    void* original_b_ptr, dnnl::memory::desc original_b_md,
    dnnl::memory::desc b_target_mem_desc);

void DNNLMatMulPrimitiveHandler::set_runtime_memory_ptr(
    size_t index, dnnl_memory* memory_ptr);

W8A8MatMulPrimitiveHandler::W8A8MatMulPrimitiveHandler(const Args& args);
```
### Import
```cpp
#include "cpu/dnnl_helper.h"
```
## I/O Contract
### Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| handler | int64_t | Yes | Opaque pointer to a DNNLMatMulPrimitiveHandler, cast to int64_t for release |
| args | DNNLMatMulPrimitiveHandler::Args | Yes | Configuration struct containing weight dimensions (b_n_size, b_k_size), strides, data types, and primitive cache size |
| b_type | dnnl::memory::data_type | Yes | OneDNN data type for the weight matrix (e.g., s8 for INT8) |
| original_b_ptr | void* | Yes | Pointer to the original weight data for prepacking into OneDNN-optimal layout |
| key | KT (template) | Yes | Cache key identifying the specific primitive configuration |
| creator | F (callable) | Yes | Factory function invoked on cache miss to create a new primitive |
### Outputs
| Name | Type | Description |
|---|---|---|
| cached_primitive | VT (shared_ptr) | Cached or newly created OneDNN primitive handler, ready for execution |
| packed_weight | dnnl::memory | Weight memory repacked into OneDNN-preferred blocked layout for efficient matmul |
## Usage Examples
```cpp
// Create a W8A8 matmul handler with INT8 weights
W8A8MatMulPrimitiveHandler::Args args;
args.b_k_size = weight.size(0);
args.b_n_size = weight.size(1);
args.b_k_stride = weight.stride(0);
args.b_n_stride = weight.stride(1);
args.b_ptr = weight.data_ptr();
args.primitive_cache_size = 16;
args.a_quantization_strategy = QuantizationStrategy::PER_TENSOR;
args.b_quantization_strategy = QuantizationStrategy::PER_OUTPUT_CHANNEL;
auto handler = new W8A8MatMulPrimitiveHandler(args);

// Release when done
release_dnnl_matmul_handler(reinterpret_cast<int64_t>(handler));
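Because the handler is released through a raw `int64_t` cast, callers can pair the `new`/`release` calls with an RAII guard so the release is never skipped on early returns. A minimal sketch, assuming only the `release_dnnl_matmul_handler` signature shown above; the stub and guard class here are hypothetical, not part of the vLLM API:

```cpp
#include <cassert>
#include <cstdint>

// Stand-in for the real release function declared in cpu/dnnl_helper.h;
// in vLLM the real function frees the handler behind the opaque int64_t.
// This stub only counts calls so the guard's behavior can be observed.
static int g_released = 0;
void release_dnnl_matmul_handler(int64_t handler) {
  (void)handler;
  ++g_released;
}

// Hypothetical RAII guard: owns a handler pointer and invokes the release
// function on destruction, so cleanup happens on every exit path.
template <typename Handler>
class DNNLHandlerGuard {
 public:
  explicit DNNLHandlerGuard(Handler* h) : h_(h) {}
  ~DNNLHandlerGuard() {
    if (h_) release_dnnl_matmul_handler(reinterpret_cast<int64_t>(h_));
  }
  DNNLHandlerGuard(const DNNLHandlerGuard&) = delete;
  DNNLHandlerGuard& operator=(const DNNLHandlerGuard&) = delete;
  Handler* get() const { return h_; }

 private:
  Handler* h_;
};
```

In real code the guard would wrap the `new W8A8MatMulPrimitiveHandler(args)` result from the example above, replacing the manual `release_dnnl_matmul_handler` call.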