Implementation:Microsoft Onnxruntime CUDA FakeQuant

From Leeroopedia


Knowledge Sources
Domains: Training, CUDA_Kernels
Last Updated: 2026-02-10 04:00 GMT

Overview

Concrete tool for fake quantization and its gradient in the ONNX Runtime CUDA training framework.

Description

Implements the FakeQuant and FakeQuantGrad operators for CUDA, which simulate quantization effects during quantization-aware training (QAT). The FakeQuant forward pass takes three inputs: input, scale, and zero_point. The scale and zero_point tensors reside in CPU memory and are validated to be scalars or 1-element vectors. The forward pass calls FakeQuantPerTensor to quantize and immediately dequantize the input, producing a fake-quantized output and a boolean gradient mask that marks which elements fell within the [quant_min, quant_max] range. The FakeQuantGrad backward pass applies this mask as a straight-through estimator, passing gradients through only for elements that were within the quantization range. Currently registered only for the float type.
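The forward pass described above can be illustrated with a small CPU-side reference. This is a minimal sketch, not the CUDA kernel itself: the name FakeQuantReference and the FakeQuantResult struct are hypothetical, and the rounding mode (round to nearest via std::nearbyint) and the exact clamping order are assumptions based on conventional per-tensor fake quantization rather than details confirmed by this page.

```cpp
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical container for the two outputs of the forward pass.
struct FakeQuantResult {
  std::vector<float> fake_quantized;  // quantize-then-dequantize result
  std::vector<bool> mask;             // true where the value was in range
};

// CPU reference sketch (assumed semantics, not the actual CUDA kernel):
//   q     = round(x / scale) + zero_point
//   out   = (clamp(q, quant_min, quant_max) - zero_point) * scale
//   mask  = quant_min <= q <= quant_max
inline FakeQuantResult FakeQuantReference(const std::vector<float>& input,
                                          float scale, float zero_point,
                                          std::int64_t quant_min,
                                          std::int64_t quant_max) {
  const float lo = static_cast<float>(quant_min);
  const float hi = static_cast<float>(quant_max);

  FakeQuantResult result;
  result.fake_quantized.reserve(input.size());
  result.mask.reserve(input.size());

  for (float x : input) {
    // Quantize: map to the integer grid and shift by the zero point.
    const float q = std::nearbyint(x / scale) + zero_point;
    // Clamp to the representable range, then dequantize immediately.
    const float clamped = std::min(hi, std::max(lo, q));
    result.fake_quantized.push_back((clamped - zero_point) * scale);
    // Record whether the unclamped value was already inside the range.
    result.mask.push_back(q >= lo && q <= hi);
  }
  return result;
}
```

Under these assumptions, with scale 0.5, zero_point 0, and a range of [0, 255], an input of 200.0 quantizes to 400, is clamped to 255, and dequantizes to 127.5 with a mask value of false. An input of 100.0 quantizes to 200, stays in range, and passes through unchanged with a mask value of true.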

Usage

Used during quantization-aware training (QAT) to simulate the effects of integer quantization in the forward pass while maintaining differentiability for backpropagation.

Code Reference

Source Location

Signature

template <typename T>
class FakeQuant : public CudaKernel {
  Status ComputeInternal(OpKernelContext* ctx) const;
};

template <typename T>
class FakeQuantGrad : public CudaKernel {
  Status ComputeInternal(OpKernelContext* ctx) const;
};

Import

#include "orttraining/training_ops/cuda/quantization/fake_quant.h"

I/O Contract

Inputs (FakeQuant)

Name | Type | Required | Description
input | Tensor(T) | Yes | Input tensor to fake-quantize
scale | Tensor(T) | Yes | Quantization scale (scalar, CPU memory)
zero_point | Tensor(T) | Yes | Quantization zero point (scalar, CPU memory)
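The shape requirement on scale and zero_point (a scalar or a 1-element vector) can be expressed as a simple check on the tensor dimensions. The sketch below is illustrative only: IsScalarOrOneElementVector is a hypothetical helper operating on a plain dimension list, and it assumes that "scalar" means rank 0 and "1-element vector" means rank 1 with a single entry.

```cpp
#include <cstdint>
#include <vector>

// Hypothetical shape check (assumed interpretation of the validation rule):
// accepts rank-0 tensors (scalars) and rank-1 tensors with exactly one element.
inline bool IsScalarOrOneElementVector(const std::vector<std::int64_t>& dims) {
  if (dims.empty()) {
    return true;  // rank 0: a scalar
  }
  return dims.size() == 1 && dims[0] == 1;  // rank 1 with one element
}
```

Under this interpretation a shape of {1, 1} is rejected, since it is a rank-2 tensor rather than a vector.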

Outputs (FakeQuant)

Name | Type | Description
fake_quantized | Tensor(T) | Fake-quantized output (same shape as input)
quantization_mask | Tensor(bool) | Boolean mask for gradient computation

Usage Examples

REGISTER_FAKEQUANT_KERNEL_TYPED(float)
// Forward: FakeQuantPerTensor(stream, size, input, scale, zero_point, quant_min, quant_max, output, mask)
// Backward: FakeQuantGradImpl(stream, size, dY, mask, dX)
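To make the backward step concrete, the following CPU-side sketch shows how a gradient mask can gate the incoming gradient in the straight-through-estimator style. FakeQuantGradReference is a hypothetical name, and writing zero for masked-out elements is an assumption that follows from gradients being passed only for in-range elements; the actual FakeQuantGradImpl CUDA kernel is not reproduced here.

```cpp
#include <cstddef>
#include <vector>

// CPU reference sketch of the masked gradient (assumed semantics):
//   dX[i] = mask[i] ? dY[i] : 0
// Elements beyond the shorter of the two inputs are left at zero.
inline std::vector<float> FakeQuantGradReference(const std::vector<float>& dY,
                                                 const std::vector<bool>& mask) {
  std::vector<float> dX(dY.size(), 0.0f);
  for (std::size_t i = 0; i < dY.size() && i < mask.size(); ++i) {
    if (mask[i]) {
      dX[i] = dY[i];  // in-range element: gradient passes straight through
    }
  }
  return dX;
}
```

In this sketch the quantization step is treated as the identity for in-range elements, so the gradient is copied unchanged, while out-of-range elements contribute no gradient.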
