Implementation: vLLM ROCm AITER Ops (_aiter_ops.py)
| Knowledge Sources | Details |
|---|---|
| Domains | ROCm, GPU_Acceleration, Inference |
| Last Updated | 2026-02-08 00:00 GMT |
Overview
Provides ROCm AITER (AI Tensor Engine for ROCm) operations that wrap optimized AMD GPU kernels for attention, matrix multiplication, quantization, and Mixture-of-Experts computations.
Description
_aiter_ops.py is the central integration point for AMD's AITER library within vLLM. It registers custom torch operations (via direct_register_custom_op) that delegate to hardware-optimized kernels for ROCm platforms with gfx9 architecture (e.g., MI300X). The file provides implementations for fused MoE (Mixture-of-Experts), assembly-optimized MoE, top-k softmax, paged attention (both CK and ASM variants), RMSNorm, FP8 quantization, batched matrix multiplication, and MLA (Multi-head Latent Attention) sparse attention operations.
Each operation has both a real implementation (calling the aiter library) and a "fake" implementation that returns correctly shaped empty tensors for torch.compile graph tracing. Platform availability is checked via is_aiter_found_and_supported(), which validates the ROCm platform, the gfx9 architecture, and the presence of the aiter library.
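The real/fake pairing can be illustrated with a small pure-Python sketch. The registry and the `register_op`/`call_op` names below are hypothetical, introduced only to show the pattern; vLLM itself registers each pair as a torch custom op via direct_register_custom_op.

```python
import math

# Hypothetical sketch of the real/fake dual-implementation pattern.
# The registry and these function names are illustrative only.
_OPS = {}

def register_op(name, real_impl, fake_impl):
    """Store both implementations under one op name."""
    _OPS[name] = (real_impl, fake_impl)

def call_op(name, *args, tracing=False, **kwargs):
    """During graph tracing only output shapes matter, so the fake
    implementation is used instead of the real kernel."""
    real, fake = _OPS[name]
    return (fake if tracing else real)(*args, **kwargs)

# "Real" impl: stands in for an optimized aiter RMSNorm kernel.
def rmsnorm_real(x, weight, eps):
    scale = [math.sqrt(sum(v * v for v in row) / len(row) + eps) for row in x]
    return [[v / s * w for v, w in zip(row, weight)] for row, s in zip(x, scale)]

# "Fake" impl: returns a correctly shaped placeholder for tracing.
def rmsnorm_fake(x, weight, eps):
    return [[0.0] * len(row) for row in x]

register_op("rmsnorm", rmsnorm_real, rmsnorm_fake)
```

The fake path never touches real data, which is exactly what a compiler needs to trace shapes without a GPU present.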
Usage
This module is used internally by vLLM's model execution layers when running on AMD GPUs. Operations are registered as custom torch ops and invoked through vLLM's attention backends and MoE layers. The VLLM_ROCM_USE_AITER and related environment variables (e.g., VLLM_ROCM_USE_AITER_LINEAR, VLLM_ROCM_USE_AITER_MOE) control which optimizations are enabled.
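The environment gating can be approximated with a short sketch. The exact precedence between the global flag and the per-feature flags is an assumption here, not taken from the source:

```python
import os

def _env_flag(name: str, default: bool) -> bool:
    """Read a "0"/"1" environment flag, falling back to a default."""
    val = os.environ.get(name)
    return default if val is None else val == "1"

def aiter_moe_enabled() -> bool:
    """Sketch of AITER MoE gating; the assumption that the per-feature
    flag defaults to the global flag's value is illustrative only."""
    global_on = _env_flag("VLLM_ROCM_USE_AITER", False)
    return _env_flag("VLLM_ROCM_USE_AITER_MOE", global_on)
```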
Code Reference
Source Location
- Repository: vllm
- File: vllm/_aiter_ops.py
- Lines: 1-1738
Signature
```python
def is_aiter_found() -> bool: ...
def is_aiter_found_and_supported() -> bool: ...
def if_aiter_supported(func: Callable) -> Callable: ...

class rocm_aiter_ops:
    @staticmethod
    def is_enabled() -> bool: ...
    @staticmethod
    def fused_moe(hidden_states, w1, w2, topk_weight, topk_ids, ...) -> torch.Tensor: ...
    @staticmethod
    def asm_moe_tkw1(hidden_states, w1, w2, topk_weights, topk_ids, ...) -> torch.Tensor: ...
    @staticmethod
    def topk_softmax(topk_weights, topk_indices, token_expert_indices, gating_output, renormalize) -> None: ...
    @staticmethod
    def paged_attention_rocm(...) -> None: ...
    @staticmethod
    def rmsnorm(output, input, weight, epsilon) -> None: ...
    @staticmethod
    def batched_gemm_fp8(...) -> torch.Tensor: ...
```
Import
```python
from vllm._aiter_ops import rocm_aiter_ops, is_aiter_found_and_supported

# Check if AITER is available and supported
if is_aiter_found_and_supported():
    result = rocm_aiter_ops.fused_moe(hidden_states, w1, w2, topk_weight, topk_ids)
```
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| hidden_states | torch.Tensor | Yes | Input activations tensor for MoE or attention operations |
| w1, w2 | torch.Tensor | Yes | Expert weight matrices for MoE gate and down projections |
| topk_weight | torch.Tensor | Yes | Router weights for top-k selected experts |
| topk_ids | torch.Tensor | Yes | Indices of top-k selected experts per token |
| w1_scale, w2_scale | torch.Tensor | No | Quantization scales for FP8 weight dequantization |
| a1_scale, a2_scale | torch.Tensor | No | Quantization scales for FP8 activation dequantization |
| VLLM_ROCM_USE_AITER | env var | No | Enable/disable AITER operations globally ("0" or "1") |
Outputs
| Name | Type | Description |
|---|---|---|
| result | torch.Tensor | Output tensor from the fused MoE, attention, or GEMM operation |
| (in-place) | None | Some operations (rmsnorm, paged_attention) modify output tensors in-place |
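For intuition, the routing step that topk_softmax performs can be sketched in plain Python. This is a simplified single-token version; the real op operates on batched GPU tensors and writes into the preallocated topk_weights and topk_indices outputs in-place:

```python
import math

def topk_softmax(gating_output, k, renormalize=True):
    """Simplified sketch of top-k expert routing for one token:
    softmax over the expert logits, keep the k largest probabilities,
    and optionally renormalize the kept weights to sum to 1."""
    m = max(gating_output)
    exps = [math.exp(x - m) for x in gating_output]  # stable softmax
    total = sum(exps)
    probs = [e / total for e in exps]
    # Indices of the k largest probabilities (the selected experts).
    topk_ids = sorted(range(len(probs)), key=lambda i: -probs[i])[:k]
    topk_weight = [probs[i] for i in topk_ids]
    if renormalize:
        s = sum(topk_weight)
        topk_weight = [w / s for w in topk_weight]
    return topk_weight, topk_ids
```

The returned weights and indices correspond to the topk_weight and topk_ids inputs consumed by the fused MoE kernels above.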
Usage Examples
```python
import torch

from vllm._aiter_ops import rocm_aiter_ops, is_aiter_found_and_supported

# Check platform support before using AITER ops
if is_aiter_found_and_supported() and rocm_aiter_ops.is_enabled():
    # Fused MoE with FP8 quantization on AMD MI300X
    output = torch.ops.vllm.rocm_aiter_fused_moe(
        hidden_states,
        w1, w2,
        topk_weight, topk_ids,
        expert_mask=None,
        activation_method=0,  # SiLU
        quant_method=1,       # FP8
        w1_scale=w1_scale,
        w2_scale=w2_scale,
    )
```
```python
# Using the decorator to guard functions
from vllm._aiter_ops import if_aiter_supported

@if_aiter_supported
def my_rocm_optimized_op():
    # Only executes on supported ROCm platforms
    pass
```