Implementation:Hiyouga LLaMA Factory V1 NPU RMSNorm
| Knowledge Sources | |
|---|---|
| Domains | Machine Learning, Hardware Optimization |
| Last Updated | 2026-02-06 19:00 GMT |
Overview
NpuRMSNormKernel is a concrete kernel implementation that replaces standard RMSNorm forward methods with NPU-optimized versions using torch_npu.npu_rms_norm.
Description
This module implements the NpuRMSNormKernel class, which extends BaseKernel to provide NPU-native RMS normalization. It defines a standalone npu_rms_norm_forward function that calls torch_npu.npu_rms_norm directly. The kernel's apply method iterates over all modules of the model and uses a case-insensitive regex to match every module whose class name contains "RMSNorm". For each match, it replaces the forward method by binding npu_rms_norm_forward to that instance with types.MethodType. The kernel is registered automatically by the @register_kernel decorator.
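The matching and rebinding steps described above can be sketched in plain Python. This is a minimal illustration, not the module's actual code: TinyRMSNorm, TinyLinear, fast_forward, and patch_rmsnorm are hypothetical names, a plain list stands in for iterating over the model's modules, and the stand-in forward returns a string instead of calling torch_npu.npu_rms_norm.

```python
import re
import types


class TinyRMSNorm:
    """Hypothetical stand-in for an RMSNorm module (not a torch.nn.Module)."""

    def __init__(self, name):
        self.name = name

    def forward(self, hidden_states):
        return f"default:{self.name}:{hidden_states}"


class TinyLinear:
    """Hypothetical stand-in for a module that must stay unpatched."""

    def forward(self, hidden_states):
        return f"linear:{hidden_states}"


def fast_forward(self, hidden_states):
    # Stand-in for the NPU-optimized function; the real one is described
    # as calling torch_npu.npu_rms_norm.
    return f"fast:{self.name}:{hidden_states}"


def patch_rmsnorm(modules):
    """Rebind forward on every module whose class name contains 'RMSNorm'."""
    pattern = re.compile("rmsnorm", re.IGNORECASE)
    for module in modules:
        if pattern.search(type(module).__name__):
            # MethodType binds the function to this one instance,
            # so `self` inside fast_forward is the matched module.
            module.forward = types.MethodType(fast_forward, module)
    return modules


norm, linear = patch_rmsnorm([TinyRMSNorm("n0"), TinyLinear()])
print(norm.forward("x"))    # fast:n0:x
print(linear.forward("x"))  # linear:x
```

Because the new method is set on the instance rather than on the class, other instances of the same class are unaffected unless they are patched as well.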
Usage
This kernel is applied automatically when running on NPU hardware and the kernel system is configured to use npu_fused_rmsnorm. It requires torch_npu to be installed and an NPU accelerator to be active. No manual invocation is needed in typical workflows.
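As a rough illustration of the installation requirement, the following sketch checks whether the torch_npu package can be located without importing it. The helper name torch_npu_installed is hypothetical, and this is not necessarily how the kernel's own dependency check is implemented; it also says nothing about whether an NPU device is actually active.

```python
import importlib.util


def torch_npu_installed() -> bool:
    """Return True if the torch_npu package can be located (it is not imported)."""
    return importlib.util.find_spec("torch_npu") is not None


if torch_npu_installed():
    print("torch_npu found; the NPU kernel may be usable")
else:
    print("torch_npu missing; keep the default RMSNorm forward")
```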
Code Reference
Source Location
- Repository: Hiyouga_LLaMA_Factory
- File: src/llamafactory/v1/plugins/model_plugins/kernels/ops/rms_norm/npu_rms_norm.py
- Lines: 1-91
Signature
def npu_rms_norm_forward(self, hidden_states) -> Tensor: ...
@register_kernel
class NpuRMSNormKernel(BaseKernel):
    _kernel_id = "npu_fused_rmsnorm"
    _device = DeviceType.NPU
    @classmethod
    def apply(cls, **kwargs) -> HFModel: ...
Import
from llamafactory.v1.plugins.model_plugins.kernels.ops.rms_norm.npu_rms_norm import NpuRMSNormKernel
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| hidden_states (npu_rms_norm_forward) | Tensor | Yes | Input hidden states tensor to be normalized |
| self (npu_rms_norm_forward) | Module | Yes | RMSNorm module instance with weight and variance_epsilon attributes |
| model (apply kwarg) | HFModel | Yes | HuggingFace model instance whose RMSNorm modules will be patched |
Outputs
| Name | Type | Description |
|---|---|---|
| npu_rms_norm_forward | Tensor | Normalized tensor with the same shape as hidden_states, intended to match the output of the module's original RMSNorm forward |
| apply | HFModel | The model with all RMSNorm modules patched to use NPU-optimized forward |
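For reference, the baseline computation that an RMSNorm forward is generally understood to perform can be written in a few lines of plain Python. This sketch uses the common definition (divide by the root mean square, with epsilon added inside the square root, then scale by the weight). The function name rms_norm_reference and the default eps value are assumptions for illustration, and the sketch makes no claim about the numerical tolerance between this formula and the NPU kernel.

```python
import math


def rms_norm_reference(x, weight, eps=1e-6):
    """Normalize one vector: x / sqrt(mean(x^2) + eps), scaled elementwise by weight."""
    mean_square = sum(v * v for v in x) / len(x)
    inv_rms = 1.0 / math.sqrt(mean_square + eps)
    return [v * inv_rms * w for v, w in zip(x, weight)]


# With unit weights and eps=0, the output has a mean square of 1.
out = rms_norm_reference([3.0, 4.0], weight=[1.0, 1.0], eps=0.0)
print(out)
```

In a real module, weight and the epsilon value would come from the patched instance's weight and variance_epsilon attributes listed in the Inputs table.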
Usage Examples
# Automatic registration and application via the kernel system.
# `model` is assumed to be an already-loaded HFModel instance.
from llamafactory.v1.plugins.model_plugins.kernels.ops.rms_norm.npu_rms_norm import NpuRMSNormKernel
# The kernel is auto-registered via @register_kernel on import.
# Typical usage is through the kernel interface:
from llamafactory.v1.plugins.model_plugins.kernels.interface import apply_kernel
model = apply_kernel(model=model, kernel_id="npu_fused_rmsnorm")
# Direct application (less common); patch only when dependencies are available:
if NpuRMSNormKernel.check_deps():
    model = NpuRMSNormKernel.apply(model=model)
Related Pages
- Hiyouga_LLaMA_Factory_V1_Kernel_Base - Abstract base class that NpuRMSNormKernel extends
- Hiyouga_LLaMA_Factory_V1_Kernel_Registry - Registry where this kernel is registered via decorator