Implementation:Microsoft Onnxruntime CUDA Adam
| Knowledge Sources | |
|---|---|
| Domains | Training, CUDA_Kernels |
| Last Updated | 2026-02-10 04:00 GMT |
Overview
CUDA kernel implementing the Adam optimizer for ONNX Runtime training.
Description
Implements the AdamOptimizer operator for the CUDA execution provider, with extensive mixed-precision type support. The kernel registration uses seven type parameters: T1 (learning rate), T2 (step count), T3 (moments), T4 (weights), T_GRAD (gradients), T_GRAD_NORM (gradient norm), and T_MIXED_PRECISION_FP (mixed-precision weights). Each updated input is aliased to its corresponding output, so the step count, weights, gradients, moment-1, moment-2, and mixed-precision weights are all updated in place. The step count stays in CPU memory while the optimization step itself executes on the GPU. The optional do_update flag, also a CPU input, controls whether the update is applied, which supports gradient accumulation. Nine type combinations are registered, covering float and MLFloat16 configurations.
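The interplay between do_update and a host-resident step count can be illustrated with a minimal host-side C++ sketch of gradient accumulation. It assumes that a call with do_update set to false leaves the weights and the step count untouched. To keep the control flow visible, it substitutes a plain SGD step for the Adam arithmetic. HostOptimizerState, OptimizerCall, and TrainWithAccumulation are hypothetical names, not ONNX Runtime APIs.

```cpp
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical optimizer state: the step counter lives in ordinary host memory,
// analogous to the CPU-resident step input/output described for the kernel.
struct HostOptimizerState {
  int64_t step = 0;
};

// One optimizer call. Assumed behavior: when do_update is false nothing changes.
// A plain SGD update stands in for the Adam arithmetic.
void OptimizerCall(float eta, bool do_update, std::vector<float>& weights,
                   const std::vector<float>& grads, HostOptimizerState& state) {
  assert(weights.size() == grads.size());
  if (!do_update) return;
  ++state.step;
  for (size_t i = 0; i < weights.size(); ++i) {
    weights[i] -= eta * grads[i];  // weights are modified in place
  }
}

// Accumulates a constant gradient per micro-batch and requests an update only
// every `accumulation_steps` micro-batches, then clears the accumulator.
HostOptimizerState TrainWithAccumulation(std::vector<float>& weights, float grad_value,
                                         int micro_batches, int accumulation_steps,
                                         float eta) {
  assert(accumulation_steps > 0);
  HostOptimizerState state;
  std::vector<float> accumulated(weights.size(), 0.0f);
  for (int mb = 0; mb < micro_batches; ++mb) {
    for (float& g : accumulated) g += grad_value;  // stands in for one backward pass
    const bool do_update = (mb + 1) % accumulation_steps == 0;
    OptimizerCall(eta, do_update, weights, accumulated, state);
    if (do_update) std::fill(accumulated.begin(), accumulated.end(), 0.0f);
  }
  return state;
}
```

Suppose accumulation_steps is 2 and there are four micro-batches. The step counter ends at 2, and the weight receives two updates, each using the gradient summed over two micro-batches.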
Usage
Invoked during the optimization step of training when the Adam optimizer is selected. Supports both full-precision and mixed-precision training configurations.
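Mixed-precision configurations generally keep a full-precision master copy of the weights next to a reduced-precision copy. Here that pairing is assumed to correspond to weights and mixed_precision_weights. The sketch shows why the master copy matters: small updates applied directly to a reduced-precision value can be rounded away, whereas a full-precision master accumulates them. bfloat16 is swapped in for MLFloat16 because its round-to-nearest-even conversion fits in a few lines of portable C++. FloatToBf16, Bf16ToFloat, and CompareUpdates are hypothetical helpers, and NaN handling is omitted.

```cpp
#include <cassert>
#include <cstdint>
#include <cstring>

// Round-to-nearest-even float -> bfloat16 (keeps the top 16 bits; NaN handling omitted).
uint16_t FloatToBf16(float value) {
  uint32_t bits;
  std::memcpy(&bits, &value, sizeof(bits));
  const uint32_t rounding_bias = 0x7FFFu + ((bits >> 16) & 1u);
  return static_cast<uint16_t>((bits + rounding_bias) >> 16);
}

// bfloat16 -> float is exact: the 16 bits become the high half of a float.
float Bf16ToFloat(uint16_t half) {
  const uint32_t bits = static_cast<uint32_t>(half) << 16;
  float value;
  std::memcpy(&value, &bits, sizeof(value));
  return value;
}

struct MixedPrecisionResult {
  float master_copy;  // reduced-precision copy derived from a float master weight
  float direct;       // weight that is only ever stored in reduced precision
};

// Applies `steps` identical additive updates both ways and compares the results.
MixedPrecisionResult CompareUpdates(float initial, float delta, int steps) {
  assert(steps >= 0);
  float master = initial;
  uint16_t direct = FloatToBf16(initial);
  for (int i = 0; i < steps; ++i) {
    master += delta;                                     // full-precision accumulation
    direct = FloatToBf16(Bf16ToFloat(direct) + delta);   // rounded after every update
  }
  return {Bf16ToFloat(FloatToBf16(master)), Bf16ToFloat(direct)};
}
```

For example, apply ten updates of 0.001 starting from 1.0. The bfloat16-only weight stays at exactly 1.0. The copy derived from the float master becomes 1.0078125, the nearest bfloat16 value to about 1.01.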
Code Reference
Source Location
- Repository: Microsoft_Onnxruntime
- File: orttraining/orttraining/training_ops/cuda/optimizer/adam.cc
- Lines: 1-167
Signature
template <typename T1, typename T2, typename T3, typename T4,
          typename T_GRAD, typename T_GRAD_NORM, typename T_MIXED_PRECISION_FP>
class AdamOptimizer : public CudaKernel {
 public:
  Status ComputeInternal(OpKernelContext* ctx) const override;
};
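The separate type parameters can be motivated with a host-side C++ sketch of the standard bias-corrected Adam update. The sketch is templated on independent weight, moment, and gradient storage types. Several details are assumptions: the names AdamStep, beta1, beta2, and epsilon, the default hyperparameter values, and the choice to do all arithmetic in float. This is not the CUDA kernel's code, and it omits loss scaling, gradient clipping, and weight decay.

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>
#include <vector>

// Standard bias-corrected Adam with independent storage types for weights,
// moments, and gradients. Arithmetic is performed in float (an assumption).
template <typename TWeight, typename TMoment, typename TGrad>
void AdamStep(float eta, int64_t step, std::vector<TWeight>& weights,
              const std::vector<TGrad>& grads, std::vector<TMoment>& m1,
              std::vector<TMoment>& m2, float beta1 = 0.9f,
              float beta2 = 0.999f, float epsilon = 1e-8f) {
  assert(step >= 1);
  assert(grads.size() == weights.size() && m1.size() == weights.size() &&
         m2.size() == weights.size());
  const float t = static_cast<float>(step);
  const float bias1 = 1.0f - std::pow(beta1, t);  // first-moment bias correction
  const float bias2 = 1.0f - std::pow(beta2, t);  // second-moment bias correction
  for (size_t i = 0; i < weights.size(); ++i) {
    const float g = static_cast<float>(grads[i]);
    const float m = beta1 * static_cast<float>(m1[i]) + (1.0f - beta1) * g;
    const float v = beta2 * static_cast<float>(m2[i]) + (1.0f - beta2) * g * g;
    m1[i] = static_cast<TMoment>(m);  // moments written back in their own type
    m2[i] = static_cast<TMoment>(v);
    const float update = eta * (m / bias1) / (std::sqrt(v / bias2) + epsilon);
    weights[i] = static_cast<TWeight>(static_cast<float>(weights[i]) - update);
  }
}
```

Instantiating the template with double moments and float weights mirrors, on the host, how a single registered kernel can keep moments and weights in different precisions. On the first step, the bias correction makes the update roughly eta times the sign of the gradient.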
Import
#include "orttraining/training_ops/cuda/optimizer/adam.h"
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| eta | Tensor(T1) | Yes | Learning rate |
| step | Tensor(T2) | Yes | Step count (CPU memory) |
| weights | Tensor(T4) | Yes | Current parameter weights |
| gradients | Tensor(T_GRAD) | Yes | Current gradients |
| moment_1 | Tensor(T3) | Yes | First moment estimate |
| moment_2 | Tensor(T3) | Yes | Second moment estimate |
| mixed_precision_weights | Tensor(T_MIXED_PRECISION_FP) | No | Mixed-precision weight copy |
| loss_scale | Tensor(T3) | No | Loss scale factor |
| grad_norm | Tensor(T_GRAD_NORM) | No | Global gradient norm |
| do_update | Tensor(bool) | No | Whether to apply the update (CPU memory) |
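The table lists loss_scale and grad_norm without saying how they enter the update. One common treatment in mixed-precision training has two parts. First, divide the gradients by the loss scale. Then shrink them when the global norm exceeds a threshold. This is sketched here as an assumption, not as what adam.cc does. PrepareGradients and the max_norm parameter are hypothetical, and grad_norm is assumed to be measured on the unscaled gradients.

```cpp
#include <cassert>
#include <vector>

// Hypothetical preprocessing: undo loss scaling, then clip by global norm.
// max_norm and the unscaled-norm convention are illustrative assumptions.
std::vector<float> PrepareGradients(const std::vector<float>& scaled_grads,
                                    float loss_scale, float grad_norm,
                                    float max_norm) {
  assert(loss_scale > 0.0f && max_norm > 0.0f && grad_norm >= 0.0f);
  // Shrink factor applied only when the global norm exceeds the threshold.
  const float clip = grad_norm > max_norm ? max_norm / grad_norm : 1.0f;
  std::vector<float> out;
  out.reserve(scaled_grads.size());
  for (float g : scaled_grads) out.push_back(g / loss_scale * clip);
  return out;
}
```

For example, gradients {3072, 4096} with a loss scale of 1024 unscale to {3, 4}, whose norm is 5. With max_norm set to 1, they are clipped to roughly {0.6, 0.8}.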
Outputs
| Name | Type | Description |
|---|---|---|
| step_new | Tensor(T2) | Updated step count (CPU memory) |
| moment_1_new | Tensor(T3) | Updated first moment (in-place) |
| moment_2_new | Tensor(T3) | Updated second moment (in-place) |
| weights_new | Tensor(T4) | Updated weights (in-place) |
| gradients_new | Tensor(T_GRAD) | Updated gradients (in-place) |
| mixed_precision_weights_new | Tensor(T_MIXED_PRECISION_FP) | Updated mixed-precision weights (in-place) |
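The in-place pairs implied by the two tables can be derived mechanically. This sketch makes two assumptions. First, the row order of each table matches the operator's positional input and output indices. Second, each output named `<name>_new` overwrites the input named `<name>`. kInputs, kOutputs, and InPlaceAliases are hypothetical helpers.

```cpp
#include <cassert>
#include <string>
#include <utility>
#include <vector>

// Names in table order; position is assumed to equal the positional index.
const std::vector<std::string> kInputs = {
    "eta", "step", "weights", "gradients", "moment_1", "moment_2",
    "mixed_precision_weights", "loss_scale", "grad_norm", "do_update"};
const std::vector<std::string> kOutputs = {
    "step_new", "moment_1_new", "moment_2_new", "weights_new",
    "gradients_new", "mixed_precision_weights_new"};

// Returns (input index, output index) pairs by matching "<name>_new" to "<name>".
std::vector<std::pair<int, int>> InPlaceAliases() {
  const std::string suffix = "_new";
  std::vector<std::pair<int, int>> aliases;
  for (size_t out = 0; out < kOutputs.size(); ++out) {
    const std::string& name = kOutputs[out];
    assert(name.size() > suffix.size());
    const std::string base = name.substr(0, name.size() - suffix.size());
    for (size_t in = 0; in < kInputs.size(); ++in) {
      if (kInputs[in] == base) {
        aliases.emplace_back(static_cast<int>(in), static_cast<int>(out));
      }
    }
  }
  return aliases;
}
```

Under those assumptions, the function yields six pairs. Examples are step (input 1) to step_new (output 0) and weights (input 2) to weights_new (output 3). The inputs eta, loss_scale, grad_norm, and do_update have no aliased output.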
Usage Examples
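// Template arguments, in order: T1, T2, T3, T4, T_GRAD, T_GRAD_NORM, T_MIXED_PRECISION_FP
// int64_t step count; float in every other slot except the MLFloat16 mixed-precision weights
REGISTER_ADAM_KERNEL_TYPED(float, int64_t, float, float, float, float, MLFloat16)
// MLFloat16 in T1, T4, and T_MIXED_PRECISION_FP; float in T3, T_GRAD, and T_GRAD_NORM
REGISTER_ADAM_KERNEL_TYPED(MLFloat16, int64_t, float, MLFloat16, float, float, MLFloat16)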
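Each REGISTER_ADAM_KERNEL_TYPED line registers one combination of the seven template arguments. A toy macro shows the idea using preprocessor techniques commonly used for this kind of registration: token pasting for a unique identifier and stringizing for a lookup key. TOY_REGISTER_ADAM, ToyAdamKernel, ToyRegistry, and Float16Tag are hypothetical and do not reproduce ONNX Runtime's actual registration macros.

```cpp
#include <cstdint>
#include <string>
#include <vector>

// Hypothetical placeholder standing in for ONNX Runtime's MLFloat16 type.
struct Float16Tag {};

// Hypothetical kernel template with the same seven type parameters.
template <typename T1, typename T2, typename T3, typename T4,
          typename T_GRAD, typename T_GRAD_NORM, typename T_MIXED_PRECISION_FP>
struct ToyAdamKernel {};

// Hypothetical registry holding one key per registered type combination.
std::vector<std::string>& ToyRegistry() {
  static std::vector<std::string> names;
  return names;
}

// Token pasting builds a unique variable name per combination, stringizing
// builds the registry key, and sizeof forces the template to be instantiated.
#define TOY_REGISTER_ADAM(T1, T2, T3, T4, TG, TGN, TMP)                              \
  const bool toy_registered_##T1##_##T2##_##T3##_##T4##_##TG##_##TGN##_##TMP =        \
      (static_cast<void>(sizeof(ToyAdamKernel<T1, T2, T3, T4, TG, TGN, TMP>)),        \
       ToyRegistry().push_back(#T1 "_" #T2 "_" #T3 "_" #T4 "_" #TG "_" #TGN "_" #TMP), \
       true);

// Mirrors the two registrations listed above, with Float16Tag in place of MLFloat16.
TOY_REGISTER_ADAM(float, int64_t, float, float, float, float, Float16Tag)
TOY_REGISTER_ADAM(Float16Tag, int64_t, float, Float16Tag, float, float, Float16Tag)
```

Because the two variables are defined in one translation unit, they are initialized in definition order before main runs. The registry therefore lists the keys in the same order as the macro invocations.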