
Implementation:Microsoft Onnxruntime CUDA Adam

From Leeroopedia


Knowledge Sources
Domains: Training, CUDA_Kernels
Last Updated: 2026-02-10 04:00 GMT

Overview

CUDA kernel that implements the Adam optimizer for training in ONNX Runtime.

Description

Implements the AdamOptimizer operator for CUDA with extensive mixed-precision type support. The kernel is templated on seven type parameters: T1 (learning rate), T2 (step count), T3 (moments), T4 (weights), T_GRAD (gradients), T_GRAD_NORM (gradient norm), and T_MIXED_PRECISION_FP (mixed-precision weights). Outputs alias their inputs, so the step count, weights, gradients, moment-1, moment-2, and mixed-precision weights are all updated in place. The step count is kept in CPU memory, while the optimization step itself executes on the GPU. The do_update flag, also a CPU input, controls whether the update is applied, which supports gradient accumulation. Nine type combinations are registered, covering float and MLFloat16 configurations.
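To make the in-place and do_update semantics concrete, here is a minimal host-side sketch of a single textbook Adam step with bias correction. It is not the ONNX Runtime kernel: the hyperparameter names and defaults (`AdamHyperParams`, `beta1`, `beta2`, `epsilon`) and the function `AdamStepReference` are hypothetical, gradient-norm clipping and the mixed-precision copy are omitted, and the CUDA kernel may differ in details such as epsilon placement.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <vector>

// Illustrative hyperparameters; names and defaults are assumptions, not ORT attributes.
struct AdamHyperParams {
  float beta1 = 0.9f;
  float beta2 = 0.999f;
  float epsilon = 1e-8f;
};

// Host-side reference for one Adam step with bias correction. Weights and both
// moments are overwritten in place; the new step count is returned.
// When do_update is false, nothing is modified and the step count is unchanged.
int64_t AdamStepReference(float eta, int64_t step, std::vector<float>& weights,
                          const std::vector<float>& gradients,
                          std::vector<float>& moment_1, std::vector<float>& moment_2,
                          float loss_scale, bool do_update,
                          const AdamHyperParams& hp = AdamHyperParams{}) {
  assert(gradients.size() == weights.size());
  assert(moment_1.size() == weights.size() && moment_2.size() == weights.size());
  assert(loss_scale > 0.0f);
  if (!do_update) return step;  // e.g. an intermediate gradient-accumulation step

  const int64_t t = step + 1;
  const float bias_corr_1 = 1.0f - std::pow(hp.beta1, static_cast<float>(t));
  const float bias_corr_2 = 1.0f - std::pow(hp.beta2, static_cast<float>(t));
  for (std::size_t i = 0; i < weights.size(); ++i) {
    const float g = gradients[i] / loss_scale;  // undo loss scaling
    moment_1[i] = hp.beta1 * moment_1[i] + (1.0f - hp.beta1) * g;
    moment_2[i] = hp.beta2 * moment_2[i] + (1.0f - hp.beta2) * g * g;
    const float m_hat = moment_1[i] / bias_corr_1;
    const float v_hat = moment_2[i] / bias_corr_2;
    weights[i] -= eta * m_hat / (std::sqrt(v_hat) + hp.epsilon);
  }
  return t;
}
```

Calling it with `do_update = false` leaves every buffer and the step count untouched. On the first real step, bias correction makes the update roughly `eta` times the sign of the gradient, so a weight of 1.0 with `eta = 0.1` moves to about 0.9.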

Usage

Invoked during the optimization step of training when the Adam optimizer is selected. Supports both full-precision and mixed-precision training configurations.
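The motivation for keeping a separate mixed-precision weight copy can be shown with a small sketch. Standard C++ before C++23 has no portable half type, so here `double` stands in for the full-precision weights and `float` for the MLFloat16 copy. The struct `MixedPrecisionParam` and the function `ApplyDelta` are hypothetical and only illustrate the pattern: update in high precision, then refresh the reduced-precision copy.

```cpp
#include <cassert>

// Hypothetical parameter holding a high-precision master weight plus a
// reduced-precision copy. double/float stand in for float/MLFloat16.
struct MixedPrecisionParam {
  double master;        // updated by the optimizer
  float low_precision;  // copy consumed by forward/backward passes
};

// Apply an already-computed optimizer delta in high precision, then
// refresh the low-precision copy from the master weight.
void ApplyDelta(MixedPrecisionParam& p, double delta) {
  p.master += delta;
  p.low_precision = static_cast<float>(p.master);
}
```

Each delta of 1e-8 is below half a float ULP at 1.0, so adding it directly to a `float` weight never moves it. The `double` master instead accumulates 1000 such steps to about 1.00001, and the refreshed `float` copy reflects that change.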

Code Reference

Source Location

Signature

template <typename T1, typename T2, typename T3, typename T4,
          typename T_GRAD, typename T_GRAD_NORM, typename T_MIXED_PRECISION_FP>
class AdamOptimizer : public CudaKernel {
 public:
  // Reads the inputs listed under I/O Contract and launches the CUDA update.
  Status ComputeInternal(OpKernelContext* ctx) const override;
};
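Separate template parameters let each tensor keep its own storage type. The sketch below shows the usual pattern under that assumption: values are widened to `float` for arithmetic and narrowed back on store. It uses a plain momentum-SGD step rather than Adam to stay short. `MomentumStep` and its parameters `TWeight`, `TMoment`, and `TGrad` are hypothetical stand-ins for T4, T3, and T_GRAD, and `double` replaces MLFloat16 because standard C++ lacks a half type.

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical momentum-SGD step with independent storage types per tensor.
// Each value is widened to float for the math, then narrowed to its own type.
template <typename TWeight, typename TMoment, typename TGrad>
void MomentumStep(TWeight* weights, TMoment* moment, const TGrad* grads,
                  std::size_t n, float lr, float beta) {
  for (std::size_t i = 0; i < n; ++i) {
    const float g = static_cast<float>(grads[i]);
    const float m = beta * static_cast<float>(moment[i]) + (1.0f - beta) * g;
    moment[i] = static_cast<TMoment>(m);  // store moment in its own type
    weights[i] = static_cast<TWeight>(static_cast<float>(weights[i]) - lr * m);
  }
}
```

One template body then serves every type combination, for example `double` weights with `float` moments, or `float` weights with `double` moments. In the real kernel, only the combinations that are explicitly registered get instantiated.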

Import

#include "orttraining/training_ops/cuda/optimizer/adam.h"

I/O Contract

Inputs

Name | Type | Required | Description
eta | Tensor(T1) | Yes | Learning rate
step | Tensor(T2) | Yes | Step count (CPU memory)
weights | Tensor(T4) | Yes | Current parameter weights
gradients | Tensor(T_GRAD) | Yes | Current gradients
moment_1 | Tensor(T3) | Yes | First moment estimate
moment_2 | Tensor(T3) | Yes | Second moment estimate
mixed_precision_weights | Tensor(T_MIXED_PRECISION_FP) | No | Mixed-precision weight copy
loss_scale | Tensor(T3) | No | Loss scale factor
grad_norm | Tensor(T_GRAD_NORM) | No | Global gradient norm
do_update | Tensor(bool) | No | Whether to perform the update (CPU memory)
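A common way to use the optional loss_scale and grad_norm inputs is to fold them into a single factor that divides every gradient. This sketch assumes three things that the page does not state: absent inputs mean "no scaling" and "no clipping", grad_norm is measured on the scaled gradients, and clipping follows the global-norm rule with a threshold `max_norm`. `GradientDivisor` is a hypothetical helper, not an ONNX Runtime function.

```cpp
#include <cassert>
#include <optional>

// Hypothetical: combine the optional loss_scale and grad_norm inputs into one
// factor that gradients are divided by. Defaults for absent inputs and the
// global-norm clipping rule at max_norm are assumptions.
float GradientDivisor(std::optional<float> loss_scale,
                      std::optional<float> grad_norm, float max_norm) {
  const float scale = loss_scale.value_or(1.0f);  // absent => loss was not scaled
  float clip = 1.0f;
  if (grad_norm) {
    // Assumed to be the norm of the scaled gradients, so unscale it first.
    const float true_norm = *grad_norm / scale;
    if (true_norm > max_norm) clip = true_norm / max_norm;
  }
  return scale * clip;
}
```

With a loss scale of 1024 and a reported norm of 2048, the unscaled norm is 2. Against `max_norm = 1`, that produces a clip factor of 2 and a combined divisor of 2048.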

Outputs

Name | Type | Description
step_new | Tensor(T2) | Updated step count (CPU memory)
moment_1_new | Tensor(T3) | Updated first moment (in-place)
moment_2_new | Tensor(T3) | Updated second moment (in-place)
weights_new | Tensor(T4) | Updated weights (in-place)
gradients_new | Tensor(T_GRAD) | Updated gradients (in-place)
mixed_precision_weights_new | Tensor(T_MIXED_PRECISION_FP) | Updated mixed-precision weights (in-place)
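The in-place aliasing can be read directly off the two tables: each output `<name>_new` reuses the buffer of input `<name>`. The sketch below derives (input index, output index) pairs from the table order. It assumes the kernel's input and output positions follow the same order as the tables. `kInputs`, `kOutputs`, and `AliasPairs` are illustrative names.

```cpp
#include <cassert>
#include <string>
#include <utility>
#include <vector>

// Input and output names in the order listed in the I/O Contract tables.
const std::vector<std::string> kInputs = {
    "eta", "step", "weights", "gradients", "moment_1", "moment_2",
    "mixed_precision_weights", "loss_scale", "grad_norm", "do_update"};
const std::vector<std::string> kOutputs = {
    "step_new", "moment_1_new", "moment_2_new",
    "weights_new", "gradients_new", "mixed_precision_weights_new"};

// Pair each output "<name>_new" with the input "<name>" whose buffer it reuses,
// returning (input_index, output_index) alias pairs in input order.
std::vector<std::pair<int, int>> AliasPairs() {
  std::vector<std::pair<int, int>> pairs;
  for (int in = 0; in < static_cast<int>(kInputs.size()); ++in) {
    for (int out = 0; out < static_cast<int>(kOutputs.size()); ++out) {
      if (kOutputs[out] == kInputs[in] + "_new") pairs.emplace_back(in, out);
    }
  }
  return pairs;
}
```

This yields six pairs: step (1, 0), weights (2, 3), gradients (3, 4), moment_1 (4, 1), moment_2 (5, 2), and mixed_precision_weights (6, 5). The inputs eta, loss_scale, grad_norm, and do_update are read-only.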

Usage Examples

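Two of the nine registered type combinations:

// Type arguments, in order: T1 (eta), T2 (step), T3 (moments), T4 (weights),
// T_GRAD, T_GRAD_NORM, T_MIXED_PRECISION_FP
REGISTER_ADAM_KERNEL_TYPED(float, int64_t, float, float, float, float, MLFloat16)
REGISTER_ADAM_KERNEL_TYPED(MLFloat16, int64_t, float, MLFloat16, float, float, MLFloat16)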
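The idea behind one macro call per type combination can be mimicked with a toy registry. This is not the ONNX Runtime macro. `TOY_REGISTER_ADAM`, `RecordAdamCombo`, and `AdamTypeRegistry` are hypothetical, and they only record each combination as a string key at static-initialization time. Because the type arguments are only stringified, names like `MLFloat16` do not need to be declared here.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Toy registry, NOT the ONNX Runtime mechanism: it only records which type
// combinations were registered, one macro call per combination.
std::vector<std::string>& AdamTypeRegistry() {
  static std::vector<std::string> combos;  // function-local static: safe init order
  return combos;
}

bool RecordAdamCombo(const char* combo) {
  AdamTypeRegistry().push_back(combo);
  return true;
}

#define TOY_CONCAT_INNER(a, b) a##b
#define TOY_CONCAT(a, b) TOY_CONCAT_INNER(a, b)

// Stringify the seven type arguments (in signature order) into one key and
// record it through a uniquely named static, one per source line.
#define TOY_REGISTER_ADAM(T1, T2, T3, T4, T_GRAD, T_GRAD_NORM, T_MP)      \
  static const bool TOY_CONCAT(toy_adam_registered_, __LINE__) =         \
      RecordAdamCombo(#T1 "," #T2 "," #T3 "," #T4 "," #T_GRAD "," \
                      #T_GRAD_NORM "," #T_MP)

TOY_REGISTER_ADAM(float, int64_t, float, float, float, float, MLFloat16);
TOY_REGISTER_ADAM(MLFloat16, int64_t, float, MLFloat16, float, float, MLFloat16);
```

After static initialization, the registry holds one key per invocation, such as `"float,int64_t,float,float,float,float,MLFloat16"`. Under this design, a combination that was never registered is simply unavailable.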
