Implementation:Microsoft Onnxruntime CUDA FakeQuant
| Knowledge Sources | |
|---|---|
| Domains | Training, CUDA_Kernels |
| Last Updated | 2026-02-10 04:00 GMT |
Overview
Concrete tool for fake quantization and its gradient in the ONNX Runtime CUDA training framework.
Description
Implements the FakeQuant and FakeQuantGrad CUDA operators, which simulate quantization effects during training (quantization-aware training). The FakeQuant forward pass takes an input tensor plus scale and zero_point, both of which must be CPU-resident scalars. It calls FakeQuantPerTensor to quantize and immediately dequantize the input, producing a fake-quantized output and a boolean gradient mask. The mask marks the elements whose quantized values fell within the [quant_min, quant_max] range. The FakeQuantGrad backward pass applies this mask as a straight-through estimator: gradients pass through unchanged for elements inside the quantization range and are zeroed elsewhere. Scale and zero_point are validated to be scalars or 1-element vectors. The kernels are currently registered only for the float type.
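The quantize-then-dequantize step can be illustrated with a small CPU reference. This is a sketch, not the ONNX Runtime CUDA kernel: it assumes the standard per-tensor formula (round x / scale to the nearest integer, add zero_point, clamp to [quant_min, quant_max], then subtract zero_point and multiply by scale), and the function name `FakeQuantPerTensorReference` is hypothetical.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical CPU reference for per-tensor fake quantization (assumed formula):
//   q      = nearbyint(x / scale) + zero_point
//   output = (clamp(q, quant_min, quant_max) - zero_point) * scale
//   mask   = quant_min <= q <= quant_max   (evaluated on the unclamped q)
void FakeQuantPerTensorReference(const std::vector<float>& input, float scale, float zero_point,
                                 int64_t quant_min, int64_t quant_max,
                                 std::vector<float>& output, std::vector<bool>& mask) {
  assert(scale > 0.0f && quant_min <= quant_max);  // precondition for this sketch
  output.resize(input.size());
  mask.resize(input.size());
  const float qmin = static_cast<float>(quant_min);
  const float qmax = static_cast<float>(quant_max);
  for (std::size_t i = 0; i < input.size(); ++i) {
    // Quantize: round to nearest (ties to even in the default rounding mode), then shift.
    const float q = std::nearbyint(input[i] / scale) + zero_point;
    // Dequantize the clamped value back into the float domain.
    output[i] = (std::min(qmax, std::max(qmin, q)) - zero_point) * scale;
    // Remember whether the unclamped value was representable; the backward pass uses this.
    mask[i] = (qmin <= q && q <= qmax);
  }
}
```

For example, with scale 0.25, zero_point 0 and range [0, 255], the input {-1, 0.26, 1, 70} maps to {0, 0.25, 1, 63.75} with mask {false, true, true, false}: the first and last elements were clamped, so they will receive no gradient.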
Usage
Used during quantization-aware training (QAT) to simulate the effects of integer quantization in the forward pass while maintaining differentiability for backpropagation.
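Assuming the standard per-tensor formulation (with $s$ = scale, $z$ = zero_point, $q_{\min}$ = quant_min, $q_{\max}$ = quant_max, and $\operatorname{round}$ meaning round to nearest integer), the forward pass and its straight-through gradient can be written as:

```latex
q_i = \operatorname{round}\!\left(\frac{x_i}{s}\right) + z
\qquad
y_i = \left(\operatorname{clamp}(q_i,\, q_{\min},\, q_{\max}) - z\right) s
\qquad
m_i = \left[\, q_{\min} \le q_i \le q_{\max} \,\right]

\frac{\partial L}{\partial x_i} \approx m_i \, \frac{\partial L}{\partial y_i}
```

The rounding step has zero derivative almost everywhere, so the straight-through estimator treats it as the identity inside the range and blocks the gradient wherever clamping occurred.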
Code Reference
Source Location
- Repository: Microsoft_Onnxruntime
- File: orttraining/orttraining/training_ops/cuda/quantization/fake_quant.cc
- Lines: 1-91
Signature
```cpp
template <typename T>
class FakeQuant : public CudaKernel {
 public:
  Status ComputeInternal(OpKernelContext* ctx) const override;
};

template <typename T>
class FakeQuantGrad : public CudaKernel {
 public:
  Status ComputeInternal(OpKernelContext* ctx) const override;
};
```
Import
```cpp
#include "orttraining/training_ops/cuda/quantization/fake_quant.h"
```
I/O Contract
Inputs (FakeQuant)
| Name | Type | Required | Description |
|---|---|---|---|
| input | Tensor(T) | Yes | Input tensor to fake-quantize |
| scale | Tensor(T) | Yes | Quantization scale (scalar or 1-element vector, in CPU memory) |
| zero_point | Tensor(T) | Yes | Quantization zero point (scalar or 1-element vector, in CPU memory) |
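The "scalar or 1-element vector" rule for scale and zero_point amounts to a simple shape check. ONNX Runtime performs this check with its own helper; the function below is a hypothetical standalone stand-in that operates on a plain list of dimensions rather than on an ONNX Runtime tensor.

```cpp
#include <cstdint>
#include <vector>

// Hypothetical shape check: accept a rank-0 tensor (a true scalar)
// or a rank-1 tensor holding exactly one element.
bool IsScalarOr1ElementVectorShape(const std::vector<int64_t>& dims) {
  if (dims.empty()) {
    return true;  // rank 0: scalar
  }
  return dims.size() == 1 && dims[0] == 1;  // rank 1 with a single element
}
```

Shapes such as `{}` and `{1}` pass, while `{2}` (a per-channel vector) or `{1, 1}` (rank 2) are rejected, which is consistent with the operator supporting only per-tensor quantization.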
Outputs (FakeQuant)
| Name | Type | Description |
|---|---|---|
| fake_quantized | Tensor(T) | Fake-quantized output (same shape as input) |
| quantization_mask | Tensor(bool) | Per-element mask (same shape as input), true where the quantized value lies in [quant_min, quant_max]; consumed by FakeQuantGrad |
Usage Examples
```cpp
// Kernel registration for float, the only currently registered type.
REGISTER_FAKEQUANT_KERNEL_TYPED(float)

// Forward: FakeQuantPerTensor(stream, size, input, scale, zero_point, quant_min, quant_max, output, mask)
// Backward: FakeQuantGradImpl(stream, size, dY, mask, dX)
```
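The backward call applies the mask elementwise. The sketch below is a CPU reference of that straight-through rule, assuming dX equals dY where the mask is true and 0 elsewhere; `FakeQuantGradReference` is a hypothetical name, not the CUDA `FakeQuantGradImpl` itself.

```cpp
#include <cstddef>
#include <stdexcept>
#include <vector>

// Hypothetical CPU reference for the straight-through-estimator backward pass:
// dX[i] = dY[i] when mask[i] is true (element was inside [quant_min, quant_max]),
// otherwise 0 (element was clamped in the forward pass).
std::vector<float> FakeQuantGradReference(const std::vector<float>& dY,
                                          const std::vector<bool>& mask) {
  if (dY.size() != mask.size()) {
    throw std::invalid_argument("dY and mask must have the same number of elements");
  }
  std::vector<float> dX(dY.size());
  for (std::size_t i = 0; i < dY.size(); ++i) {
    dX[i] = mask[i] ? dY[i] : 0.0f;
  }
  return dX;
}
```

Because the mask is produced once in the forward pass and stored as a bool tensor, the backward pass needs neither the original input nor the scale and zero_point.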