Implementation:Microsoft Onnxruntime CPU Scale
| Knowledge Sources | |
|---|---|
| Domains | Training, CPU_Kernels |
| Last Updated | 2026-02-10 04:00 GMT |
Overview
Concrete tool for element-wise tensor scaling on CPU in the ONNX Runtime training framework.
Description
This file implements the Scale kernel, which multiplies every element of a tensor by a scalar. The scale factor is read from a second input, a tensor that holds exactly one element. When the optional scale_down attribute (default 0) is nonzero, the kernel uses the reciprocal of the scale value, so the input is divided rather than multiplied. The computation itself is performed with Eigen maps: output = scale_value * input. Kernels are registered for every combination of data tensor type (float, double) and scale tensor type (float, double, int64_t, int32_t), giving 2 × 4 = 8 variants.
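The behavior described above can be sketched in standalone C++ without Eigen or any ONNX Runtime headers. This is a minimal illustration, not the kernel's source: the name ScaleSketch, the use of std::vector in place of tensors, and the choice to convert the scale value to T before multiplying are all assumptions made for the example.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for the kernel's computation; not the ONNX Runtime API.
// T is the data type, ScaleT the type of the single-element scale input.
template <typename T, typename ScaleT>
std::vector<T> ScaleSketch(const std::vector<T>& input,
                           const std::vector<ScaleT>& scale,
                           bool scale_down = false) {
  // The scale input must hold exactly one element.
  assert(scale.size() == 1);

  // Assumption: the scale is converted to T before it is applied.
  T scale_value = static_cast<T>(scale[0]);

  // scale_down turns multiplication into division by taking the reciprocal.
  if (scale_down) {
    scale_value = static_cast<T>(1) / scale_value;
  }

  // Element-wise product: output = scale_value * input.
  std::vector<T> output(input.size());
  for (std::size_t i = 0; i < input.size(); ++i) {
    output[i] = scale_value * input[i];
  }
  return output;
}
```

Calling ScaleSketch<float, std::int64_t> with a scale of 4 and scale_down set to true multiplies each element by 0.25, which is the division case the attribute exists for.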
Usage
This kernel is commonly used in training for gradient scaling operations, such as dividing gradients by the number of micro-batches in gradient accumulation, or scaling loss values.
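As a rough illustration of the gradient accumulation case, the sketch below sums gradients from several micro-batches and then divides by the micro-batch count, which is the step a Scale node with scale_down enabled would perform. The helper name AverageAccumulatedGradients and the nested std::vector layout are hypothetical and exist only for this example.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical helper, not part of ONNX Runtime: averages gradients summed
// over several micro-batches, mirroring Scale with scale_down enabled.
std::vector<float> AverageAccumulatedGradients(
    const std::vector<std::vector<float>>& micro_batch_grads) {
  assert(!micro_batch_grads.empty());

  // Accumulation step: element-wise sum over all micro-batches.
  std::vector<float> sum(micro_batch_grads[0].size(), 0.0f);
  for (const auto& grad : micro_batch_grads) {
    assert(grad.size() == sum.size());
    for (std::size_t i = 0; i < sum.size(); ++i) {
      sum[i] += grad[i];
    }
  }

  // Scaling step: multiply by the reciprocal of the micro-batch count.
  const float scale_value = 1.0f / static_cast<float>(micro_batch_grads.size());
  for (float& value : sum) {
    value *= scale_value;
  }
  return sum;
}
```

Passing the count as the scale and letting the attribute request the reciprocal means the graph can carry an integer count (int64_t or int32_t) rather than a precomputed fractional factor.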
Code Reference
Source Location
- Repository: Microsoft_Onnxruntime
- File: orttraining/orttraining/training_ops/cpu/math/scale.cc
- Lines: 1-57
Signature
template <typename T, typename ScaleT>
Scale<T, ScaleT>::Scale(const OpKernelInfo& info);
template <typename T, typename ScaleT>
Status Scale<T, ScaleT>::Compute(OpKernelContext* context) const;
Import
#include "orttraining/training_ops/cpu/math/scale.h"
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| input | Tensor(T) | Yes | Input tensor to scale |
| scale | Tensor(ScaleT) | Yes | Scalar scale factor (single element) |
Outputs
| Name | Type | Description |
|---|---|---|
| output | Tensor(T) | Scaled output tensor |
Usage Examples
// All eight registered (T, ScaleT) variants:
REGISTER_SCALE_KERNEL_TYPED(float, float)
REGISTER_SCALE_KERNEL_TYPED(float, double)
REGISTER_SCALE_KERNEL_TYPED(float, int64_t)
REGISTER_SCALE_KERNEL_TYPED(float, int32_t)
REGISTER_SCALE_KERNEL_TYPED(double, float)
REGISTER_SCALE_KERNEL_TYPED(double, double)
REGISTER_SCALE_KERNEL_TYPED(double, int64_t)
REGISTER_SCALE_KERNEL_TYPED(double, int32_t)