Implementation:Microsoft Onnxruntime CUDA AdamW
| Knowledge Sources | |
|---|---|
| Domains | Training, CUDA_Kernels |
| Last Updated | 2026-02-10 04:00 GMT |
Overview
CUDA kernel for the AdamW optimizer (Adam with decoupled weight decay) in the ONNX Runtime training framework.
Description
Implements the AdamWOptimizer operator for CUDA. The update is applied through AdamWMTAFunctor, a Multi-Tensor Apply (MTA) functor that processes all parameter groups in a single fused kernel launch instead of one launch per tensor. The chunk size and group size are configured by MTA_ADAMW_CHUNK_SIZE and MTA_ADAMW_GROUP_SIZE. The optimizer parameters are alpha (beta1), beta (beta2), epsilon, learning rate, weight decay, adam mode, the correct_bias flag, and the current step count.
Weights, first moments, and second moments are passed as tensor sequences and can be updated in place through input/output aliases. An optional update signal, read from CPU memory, controls whether the optimization step is performed. When an output is not aliased to its input, the results are copied into the separate output buffer.
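As a rough illustration of the computation described above, the following CPU-only sketch applies the textbook AdamW update to a list of tensors, walking each tensor in fixed-size chunks the way a multi-tensor launch hands (tensor, chunk) pairs to thread blocks. It is not the ONNX Runtime kernel: `AdamWParams`, `AdamWElement`, `AdamWMultiTensor`, and `kSketchChunkSize` are hypothetical names, the chunk size is an arbitrary stand-in for MTA_ADAMW_CHUNK_SIZE, and the adam mode parameter is ignored because its variants are not specified here.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical hyperparameter bundle; field names are illustrative only.
struct AdamWParams {
  float lr = 1e-3f;
  float beta1 = 0.9f;    // "alpha" in the operator's naming
  float beta2 = 0.999f;  // "beta" in the operator's naming
  float epsilon = 1e-8f;
  float weight_decay = 0.0f;
  bool correct_bias = true;
};

// Textbook AdamW update for one element (assumed form, decoupled decay first).
inline void AdamWElement(float& w, float g, float& m, float& v,
                         const AdamWParams& p, std::int64_t step) {
  w -= p.lr * p.weight_decay * w;              // decoupled weight decay
  m = p.beta1 * m + (1.0f - p.beta1) * g;      // first moment estimate
  v = p.beta2 * v + (1.0f - p.beta2) * g * g;  // second moment estimate
  float c1 = 1.0f;
  float c2 = 1.0f;
  if (p.correct_bias) {
    c1 = 1.0f - std::pow(p.beta1, static_cast<float>(step));
    c2 = 1.0f - std::pow(p.beta2, static_cast<float>(step));
  }
  const float m_hat = m / c1;
  const float v_hat = v / c2;
  w -= p.lr * m_hat / (std::sqrt(v_hat) + p.epsilon);
}

// Arbitrary stand-in for MTA_ADAMW_CHUNK_SIZE; the real value is not given here.
constexpr std::size_t kSketchChunkSize = 4;

using TensorList = std::vector<std::vector<float>>;

// Visits every tensor chunk by chunk and updates weights and moments in place.
inline void AdamWMultiTensor(TensorList& weights, const TensorList& grads,
                             TensorList& m1, TensorList& m2,
                             const AdamWParams& p, std::int64_t step) {
  assert(weights.size() == grads.size());
  assert(weights.size() == m1.size() && weights.size() == m2.size());
  for (std::size_t t = 0; t < weights.size(); ++t) {
    const std::size_t n = weights[t].size();
    assert(grads[t].size() == n && m1[t].size() == n && m2[t].size() == n);
    for (std::size_t begin = 0; begin < n; begin += kSketchChunkSize) {
      const std::size_t end = std::min(begin + kSketchChunkSize, n);
      for (std::size_t i = begin; i < end; ++i) {
        AdamWElement(weights[t][i], grads[t][i], m1[t][i], m2[t][i], p, step);
      }
    }
  }
}
```

On a GPU the two outer loops would typically be replaced by a grid of thread blocks, one per chunk, which is what lets many small parameter tensors share a launch.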
Usage
Invoked during the optimization step of training when the AdamW optimizer is selected. AdamW is commonly used for training modern transformer models.
Code Reference
Source Location
- Repository: Microsoft_Onnxruntime
- File: orttraining/orttraining/training_ops/cuda/optimizer/adamw/adamw.cc
- Lines: 1-73
Signature
class AdamWOptimizer : public CudaKernel, public AdamWOptimizerBase {
 public:
  Status ComputeInternal(OpKernelContext* ctx) const override;
};
Import
#include "orttraining/training_ops/cuda/optimizer/adamw/adamw.h"
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| learning_rate | Tensor(float) | Yes | Learning rate (CPU memory) |
| step | Tensor(int64_t) | Yes | Step count (CPU memory) |
| weights | Sequence(Tensor) | Yes | Parameter weights |
| gradients | Sequence(Tensor) | Yes | Gradients |
| momentums_1 | Sequence(Tensor) | Yes | First moments |
| momentums_2 | Sequence(Tensor) | Yes | Second moments |
| update_signal | Tensor(bool) | No | Whether to perform update (CPU memory) |
Outputs
| Name | Type | Description |
|---|---|---|
| updated_flag | Tensor(bool) | Whether the update was performed (CPU memory) |
| updated_weights | Sequence(Tensor) | Updated weights (in-place) |
| updated_momentums_1 | Sequence(Tensor) | Updated first moments (in-place) |
| updated_momentums_2 | Sequence(Tensor) | Updated second moments (in-place) |
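The interaction between update_signal and updated_flag can be sketched on the CPU as follows. This is a simplified model, not ONNX Runtime code: `StepResult` and `RunStep` are hypothetical, a single float vector stands in for the tensor sequences, and treating a missing update_signal as "perform the update" is an assumption based on the input being optional.

```cpp
#include <cassert>
#include <vector>

// Hypothetical result type for this sketch; not an ONNX Runtime type.
struct StepResult {
  bool updated_flag;
  std::vector<float> updated_weights;
};

// `update_signal` models the optional boolean input: nullptr means it was
// not provided. `apply_update` is a placeholder for the optimizer math.
template <typename UpdateFn>
StepResult RunStep(const std::vector<float>& weights,
                   const bool* update_signal, UpdateFn apply_update) {
  // Start from a copy of the inputs, modelling outputs that are NOT aliased.
  StepResult result{false, weights};
  // Assumption: an absent signal means the step is performed.
  const bool do_update = (update_signal == nullptr) || *update_signal;
  if (do_update) {
    apply_update(result.updated_weights);
    result.updated_flag = true;
  }
  return result;
}
```

When the signal is false, the outputs in this model are plain copies of the inputs and updated_flag reports that no step was taken, which is the behavior a caller would rely on to skip a step, for example after detecting non-finite gradients.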
Usage Examples
ONNX_OPERATOR_KERNEL_EX(
    AdamWOptimizer, kMSDomain, 1, kCudaExecutionProvider,
    (*KernelDefBuilder::Create())
        .InputMemoryType(OrtMemTypeCPUInput, 0)    // learning_rate
        .InputMemoryType(OrtMemTypeCPUInput, 1)    // step
        .InputMemoryType(OrtMemTypeCPUInput, 6)    // update_signal
        .OutputMemoryType(OrtMemTypeCPUOutput, 0)  // updated_flag
        .Alias(2, 1)   // weights -> updated_weights
        .Alias(4, 2)   // momentums_1 -> updated_momentums_1
        .Alias(5, 3),  // momentums_2 -> updated_momentums_2
    AdamWOptimizer);
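To make the alias pairs in this registration easier to read against the I/O contract, the small lookup below maps an input index to the output index that may share its buffer. `AliasPair`, `kAliases`, and `AliasedOutput` are hypothetical helpers written for this page; only the index pairs come from the registration.

```cpp
#include <cassert>

// Hypothetical helper type: one (input index, output index) alias pair.
struct AliasPair {
  int input;
  int output;
};

// Index pairs taken from the .Alias(...) calls in the kernel registration.
constexpr AliasPair kAliases[] = {{2, 1}, {4, 2}, {5, 3}};

// Returns the output index that may reuse the buffer of `input`,
// or -1 when the input has no aliased output.
inline int AliasedOutput(int input) {
  for (const AliasPair& a : kAliases) {
    if (a.input == input) return a.output;
  }
  return -1;
}
```

Read this way, gradients (input 3) have no aliased output, which matches their role as a read-only input, while weights and both moment sequences each map to the output of the same name.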