
Implementation:Microsoft Onnxruntime CPU FakeQuant

From Leeroopedia


Knowledge Sources
Domains Training, CPU_Kernels
Last Updated 2026-02-10 04:00 GMT

Overview

Concrete CPU kernels for fake-quantization forward and gradient computation in the ONNX Runtime training framework.

Description

This file implements two kernels for quantization-aware training:

FakeQuant: Simulates quantization during training by quantizing the input to the integer range [quant_min, quant_max] and immediately dequantizing it back to floating point: output = (clamp(round(input / scale) + zero_point, quant_min, quant_max) - zero_point) * scale. It also outputs a boolean mask indicating which elements fell within the quantization range (i.e., were not clipped). The operation runs in parallel on the thread pool with a per-element cost model.
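The per-element forward computation can be sketched as follows. This is a minimal standalone illustration of the formula above, not the actual ORT kernel; the function name and signature are hypothetical, and the real kernel additionally parallelizes the loop over the thread pool.

```cpp
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

// Hypothetical sketch of the fake-quantization formula:
//   output = (clamp(round(x / scale) + zero_point, qmin, qmax) - zero_point) * scale
// mask[i] is true when the quantized value landed inside [qmin, qmax] (not clipped).
void FakeQuantSketch(const std::vector<float>& input, float scale, float zero_point,
                     int64_t quant_min, int64_t quant_max,
                     std::vector<float>& output, std::vector<bool>& mask) {
  output.resize(input.size());
  mask.resize(input.size());
  for (size_t i = 0; i < input.size(); ++i) {
    // Quantize: scale down, round to nearest integer, shift by the zero point.
    const float quantized = std::nearbyint(input[i] / scale) + zero_point;
    // Record whether this element survives without clipping.
    const bool in_range = quantized >= static_cast<float>(quant_min) &&
                          quantized <= static_cast<float>(quant_max);
    // Clamp to the integer range, then dequantize back to floating point.
    const float clamped = std::clamp(quantized, static_cast<float>(quant_min),
                                     static_cast<float>(quant_max));
    output[i] = (clamped - zero_point) * scale;
    mask[i] = in_range;
  }
}
```

For example, with scale 0.1 and an int8-style range [-128, 127], an input of 0.5 round-trips unchanged (mask true), while 100.0 saturates at 127 and dequantizes to 12.7 (mask false).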

FakeQuantGrad: Computes the straight-through estimator gradient: dX = dY * mask, where the mask is true if the element was within quantization range during the forward pass, and false (gradient is zeroed) if it was clipped. This implements the standard straight-through estimator used in quantization-aware training.

Both kernels support per-tensor quantization. Per-channel quantization is noted as a TODO.

Usage

These kernels are used in quantization-aware training (QAT) to simulate the effects of quantization during training while maintaining differentiability through the straight-through estimator.

Code Reference

Source Location

Signature

template <typename T>
void FakeQuantPerTensor(OpKernelContext* ctx, const int64_t num_elements,
    const T* input_data, T quant_scale, T quant_zero_point,
    int64_t quant_min, int64_t quant_max,
    T* fake_quantized_data, bool* quantization_mask_data);

template <typename T>
void FakeQuantGradImpl(const Tensor& dY, const Tensor& gradient_mask, Tensor& dX);

template <typename T>
Status FakeQuant<T>::Compute(OpKernelContext* ctx) const;

template <typename T>
Status FakeQuantGrad<T>::Compute(OpKernelContext* ctx) const;

Import

#include "orttraining/orttraining/training_ops/cpu/quantization/fake_quant.h"

I/O Contract

Inputs (FakeQuant)

Name Type Required Description
input Tensor(float) Yes Input tensor to fake-quantize
scale Tensor(float) Yes Quantization scale (scalar)
zero_point Tensor(float) Yes Quantization zero point (scalar)

Outputs (FakeQuant)

Name Type Description
fake_quantized Tensor(float) Fake-quantized output (same shape as input)
quantization_mask Tensor(bool) Mask indicating in-range elements

Inputs (FakeQuantGrad)

Name Type Required Description
dY Tensor(float) Yes Upstream gradient
gradient_mask Tensor(bool) Yes Quantization mask from forward pass

Outputs (FakeQuantGrad)

Name Type Description
dX Tensor(float) Gradient w.r.t. input (masked by straight-through estimator)

Usage Examples

ONNX_OPERATOR_TYPED_KERNEL_EX(
    FakeQuant, kMSDomain, 1, float, kCpuExecutionProvider,
    (*KernelDefBuilder::Create())
        .TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
    FakeQuant<float>);

ONNX_OPERATOR_TYPED_KERNEL_EX(
    FakeQuantGrad, kMSDomain, 1, float, kCpuExecutionProvider,
    (*KernelDefBuilder::Create())
        .TypeConstraint("T", DataTypeImpl::GetTensorType<float>()),
    FakeQuantGrad<float>);
