Implementation:Microsoft Onnxruntime CUDA MixedPrecisionScale
| Knowledge Sources | |
|---|---|
| Domains | Training, CUDA_Kernels |
| Last Updated | 2026-02-10 04:00 GMT |
Overview
Concrete tool for scaling tensors with mixed-precision type conversion in the ONNX Runtime CUDA training framework.
Description
Implements the CUDA MixedPrecisionScale operator, which multiplies one or more input tensors by a single float scale factor and converts each result to a target data type in the same pass. The target type is set by the to attribute (a TensorProto_DataType value); supported targets include float16, bfloat16, float, and double. When fuse_outputs is enabled, all results are written into one contiguous output buffer, with each input's result placed at a computed byte offset, which reduces memory fragmentation. Otherwise, each input produces its own output tensor. For each input tensor, the implementation calls Impl_MixedPrecisionScale to perform the fused scale-and-cast on the GPU. Kernels are registered for the MLFloat16, float, and BFloat16 source types.
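The per-element behavior described above can be sketched on the CPU. This is a minimal reference, not the CUDA kernel: ScaleAndCast is a hypothetical name, and the sketch assumes the multiplication is done in float before the cast to the target type, which this page does not spell out. Standard C++ has no portable half type, so float and double stand in for the source and target types.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical CPU reference for the per-tensor scale-and-cast step.
// SrcT and DstT stand in for the operator's source and target element types.
template <typename SrcT, typename DstT>
std::vector<DstT> ScaleAndCast(const std::vector<SrcT>& input, float scale) {
  std::vector<DstT> output;
  output.reserve(input.size());
  for (const SrcT& x : input) {
    // Assumption: widen to float, multiply, then convert to the target type.
    output.push_back(static_cast<DstT>(static_cast<float>(x) * scale));
  }
  // One output element is produced per input element.
  assert(output.size() == input.size());
  return output;
}
```

Calling ScaleAndCast<float, double> on {1.0f, 2.0f, -4.0f} with a scale of 0.5f yields {0.5, 1.0, -2.0}.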
Usage
Used during mixed-precision training to scale gradients or activations while converting between precision formats, commonly as part of loss scaling.
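As a rough illustration of the loss-scaling case, the sketch below undoes a loss scale on a gradient buffer. UnscaleGradients is a hypothetical helper, not part of ONNX Runtime, and it assumes the common arrangement in which the loss was multiplied by loss_scale before the backward pass, so the gradients must be multiplied by its reciprocal afterwards.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper: undo loss scaling on a gradient buffer.
// A scale operator only multiplies by the factor it is given, so
// unscaling means passing the reciprocal of the loss scale.
std::vector<float> UnscaleGradients(const std::vector<float>& scaled_grads,
                                    float loss_scale) {
  assert(loss_scale != 0.0f);
  const float inv_scale = 1.0f / loss_scale;
  std::vector<float> grads(scaled_grads.size());
  for (std::size_t i = 0; i < scaled_grads.size(); ++i) {
    grads[i] = scaled_grads[i] * inv_scale;
  }
  return grads;
}
```

With a loss scale of 1024, scaled gradients {1024.0f, -512.0f} come back as {1.0f, -0.5f}.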
Code Reference
Source Location
- Repository: Microsoft_Onnxruntime
- File: orttraining/orttraining/training_ops/cuda/math/mixed_precision_scale.cc
- Lines: 1-125
Signature
template <typename SrcT>
class MixedPrecisionScale : public CudaKernel {
 public:
  MixedPrecisionScale(const OpKernelInfo& info);
  Status ComputeInternal(OpKernelContext* context) const;
};
Status BytesPerElement(ONNX_NAMESPACE::TensorProto_DataType to, size_t& bytes_per_elem);
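To show how an element size feeds the fused-output byte offsets, here is a self-contained sketch. Everything in it is an assumption made for illustration: TargetType is a stand-in for the TensorProto_DataType values, BytesPerElementSketch returns a bool rather than an ORT Status, and FusedByteOffsets assumes results are packed back to back with no padding or alignment, which this page does not confirm.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for the TensorProto_DataType values named above.
enum class TargetType { kFloat16, kBFloat16, kFloat, kDouble };

// Writes the element size for `to`; returns false for unsupported types.
bool BytesPerElementSketch(TargetType to, std::size_t& bytes_per_elem) {
  switch (to) {
    case TargetType::kFloat16:
    case TargetType::kBFloat16:
      bytes_per_elem = sizeof(std::uint16_t);  // both are 16-bit formats
      return true;
    case TargetType::kFloat:
      bytes_per_elem = sizeof(float);
      return true;
    case TargetType::kDouble:
      bytes_per_elem = sizeof(double);
      return true;
  }
  return false;
}

// Byte offset of each tensor's result inside one fused buffer,
// assuming tight packing in input order.
std::vector<std::size_t> FusedByteOffsets(
    const std::vector<std::size_t>& element_counts,
    std::size_t bytes_per_elem) {
  assert(bytes_per_elem > 0);
  std::vector<std::size_t> offsets;
  offsets.reserve(element_counts.size());
  std::size_t running = 0;
  for (std::size_t count : element_counts) {
    offsets.push_back(running);
    running += count * bytes_per_elem;
  }
  return offsets;
}
```

For three tensors of 4, 2, and 3 elements converted to a 2-byte type, the offsets are 0, 8, and 12 bytes.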
Import
#include "orttraining/training_ops/cuda/math/mixed_precision_scale.h"
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| scale | Tensor(float) | Yes | Scalar scale factor |
| inputs | Tensor(SrcT)... | Yes | One or more tensors to scale and convert |
Outputs
| Name | Type | Description |
|---|---|---|
| outputs | Tensor(DstT)... | Scaled and type-converted tensors (fused or separate based on fuse_outputs) |
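The contract above can be mimicked with a small CPU reference that takes one scale, several inputs, and a fuse flag. This is only a sketch: Tensor and MixedPrecisionScaleReference are hypothetical names, tensors are modeled as flat float vectors, and the type conversion is left out so the example stays focused on how the number of outputs depends on fuse_outputs.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

using Tensor = std::vector<float>;  // hypothetical stand-in for a device tensor

// Returns one output per input, or a single concatenated output when fused.
std::vector<Tensor> MixedPrecisionScaleReference(
    float scale, const std::vector<Tensor>& inputs, bool fuse_outputs) {
  std::vector<Tensor> outputs;
  if (fuse_outputs) {
    outputs.emplace_back();  // one buffer holding every result back to back
  }
  for (const Tensor& input : inputs) {
    if (!fuse_outputs) {
      outputs.emplace_back();  // a fresh output tensor for this input
    }
    Tensor& dst = outputs.back();
    for (float x : input) {
      dst.push_back(x * scale);
    }
  }
  // In fused mode there is exactly one output buffer.
  assert(!fuse_outputs || outputs.size() == 1);
  return outputs;
}
```

With inputs {1, 2} and {3} and a scale of 2, the fused call returns a single tensor {2, 4, 6}, while the unfused call returns {2, 4} and {6}.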
Usage Examples
REGISTER_MIXEDPRECISIONSCALE_KERNEL_TYPED(MLFloat16)
REGISTER_MIXEDPRECISIONSCALE_KERNEL_TYPED(float)
REGISTER_MIXEDPRECISIONSCALE_KERNEL_TYPED(BFloat16)
// Converts and scales: input(MLFloat16) * scale -> output(float)
// Attribute "to" controls the target type
// Attribute "fuse_outputs" controls output buffer fusion