Implementation: Microsoft Onnxruntime CPU FakeQuant
| Knowledge Sources | |
|---|---|
| Domains | Training, CPU_Kernels |
| Last Updated | 2026-02-10 04:00 GMT |
Overview
CPU kernels implementing the fake-quantization forward pass and its straight-through-estimator gradient in the ONNX Runtime training framework.
Description
This file implements two kernels for quantization-aware training:
FakeQuant: Simulates quantization during training by quantizing the input to the integer range [quant_min, quant_max] and immediately dequantizing back to floating point: output = (clamp(round(input / scale) + zero_point, quant_min, quant_max) - zero_point) * scale. It also outputs a boolean mask indicating which elements fell within the quantization range (i.e., were not clipped by the clamp). The operation runs in parallel on the thread pool using a per-element cost model.
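As an illustration of the formula above, here is a simplified, single-threaded C++ sketch of per-tensor fake quantization. This is not the ORT kernel itself (which parallelizes over the thread pool); the function and variable names are illustrative.

```cpp
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

// Illustrative per-tensor fake quantization: round each element to the
// integer grid, clamp to [quant_min, quant_max], then dequantize back to
// floating point. The mask records which elements were NOT clipped.
void FakeQuantPerTensorSketch(const std::vector<float>& input,
                              float scale, float zero_point,
                              int64_t quant_min, int64_t quant_max,
                              std::vector<float>& output,
                              std::vector<bool>& mask) {
  output.resize(input.size());
  mask.resize(input.size());
  for (size_t i = 0; i < input.size(); ++i) {
    const float q = std::nearbyint(input[i] / scale) + zero_point;
    const float clamped =
        std::min(std::max(q, static_cast<float>(quant_min)),
                 static_cast<float>(quant_max));
    output[i] = (clamped - zero_point) * scale;
    // true only if the rounded value fell inside the quantization range
    mask[i] = (q >= static_cast<float>(quant_min) &&
               q <= static_cast<float>(quant_max));
  }
}
```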
FakeQuantGrad: Computes the straight-through estimator gradient: dX = dY * mask, where the mask is true if the element was within quantization range during the forward pass, and false (gradient is zeroed) if it was clipped. This implements the standard straight-through estimator used in quantization-aware training.
Both kernels support per-tensor quantization. Per-channel quantization is noted as a TODO.
Usage
These kernels are used in quantization-aware training (QAT) to simulate the effects of quantization during training while maintaining differentiability through the straight-through estimator.
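To show how the two kernels combine in practice, the following toy sketch performs one QAT weight update with plain SGD, assuming a zero point of 0 and an upstream gradient of all ones (all names and values are illustrative, not taken from ORT):

```cpp
#include <cmath>
#include <cstdint>
#include <vector>

// Toy QAT step: fake-quantize each weight, then apply the straight-through
// estimator. In-range weights receive the (assumed all-ones) upstream
// gradient; clipped weights receive zero gradient and stay unchanged.
std::vector<float> QatStepSketch(std::vector<float> w, float scale,
                                 int64_t qmin, int64_t qmax, float lr) {
  for (size_t i = 0; i < w.size(); ++i) {
    const float q = std::nearbyint(w[i] / scale);  // zero_point = 0 here
    const bool in_range = (q >= static_cast<float>(qmin) &&
                           q <= static_cast<float>(qmax));
    if (in_range) {
      w[i] -= lr * 1.0f;  // STE passes the gradient straight through
    }
    // clipped elements: gradient masked to zero, weight untouched
  }
  return w;
}
```

The second weight below quantizes far outside [-8, 7], so its gradient is zeroed and it is not updated, while the first weight takes a normal SGD step.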
Code Reference
Source Location
- Repository: Microsoft_Onnxruntime
- File: orttraining/orttraining/training_ops/cpu/quantization/fake_quant.cc
- Lines: 1-118
Signature
template <typename T>
void FakeQuantPerTensor(OpKernelContext* ctx, const int64_t num_elements,
const T* input_data, T quant_scale, T quant_zero_point,
int64_t quant_min, int64_t quant_max,
T* fake_quantized_data, bool* quantization_mask_data);
template <typename T>
void FakeQuantGradImpl(const Tensor& dY, const Tensor& gradient_mask, Tensor& dX);
template <typename T>
Status FakeQuant<T>::Compute(OpKernelContext* ctx) const;
template <typename T>
Status FakeQuantGrad<T>::Compute(OpKernelContext* ctx) const;
Import
#include "orttraining/orttraining/training_ops/cpu/quantization/fake_quant.h"
I/O Contract
Inputs (FakeQuant)
| Name | Type | Required | Description |
|---|---|---|---|
| input | Tensor(float) | Yes | Input tensor to fake-quantize |
| scale | Tensor(float) | Yes | Quantization scale (scalar) |
| zero_point | Tensor(float) | Yes | Quantization zero point (scalar) |
Outputs (FakeQuant)
| Name | Type | Description |
|---|---|---|
| fake_quantized | Tensor(float) | Fake-quantized output (same shape as input) |
| quantization_mask | Tensor(bool) | Mask indicating in-range elements |
Inputs (FakeQuantGrad)
| Name | Type | Required | Description |
|---|---|---|---|
| dY | Tensor(float) | Yes | Upstream gradient |
| gradient_mask | Tensor(bool) | Yes | Quantization mask from forward pass |
Outputs (FakeQuantGrad)
| Name | Type | Description |
|---|---|---|
| dX | Tensor(float) | Gradient w.r.t. input (masked by straight-through estimator) |
Usage Examples
ONNX_OPERATOR_TYPED_KERNEL_EX(
FakeQuant, kMSDomain, 1, float, kCpuExecutionProvider,
(*KernelDefBuilder::Create())
.TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
FakeQuant<float>);
ONNX_OPERATOR_TYPED_KERNEL_EX(
FakeQuantGrad, kMSDomain, 1, float, kCpuExecutionProvider,
(*KernelDefBuilder::Create())
.TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
FakeQuantGrad<float>);