Implementation: Microsoft Onnxruntime CPU AdamW
| Knowledge Sources | |
|---|---|
| Domains | Training, CPU_Kernels |
| Last Updated | 2026-02-10 04:00 GMT |
Overview
Concrete implementation of the AdamW optimizer kernel for CPU in the ONNX Runtime training framework.
Description
This file implements the AdamWOptimizer kernel, which performs the AdamW (Adam with decoupled weight decay) optimization step. It supports two computation modes:
Mode 0 (PyTorch-style): Weight decay is applied before the gradient-based update, and bias correction is applied individually to the first and second moments. The update is: w = w - lr * weight_decay * w, then w = w - lr * (m1 / alpha_correction) / (sqrt(m2 / beta_correction) + epsilon), where alpha_correction = 1 - beta1^step and beta_correction = 1 - beta2^step are the standard Adam bias-correction terms.
Mode 1 (HuggingFace-style): Bias correction is folded into the learning rate, and weight decay is applied after the gradient-based update. The update is: w = w - lr_corrected * m1 / (sqrt(m2) + epsilon) - lr * weight_decay * w, where lr_corrected = lr * sqrt(1 - beta2^step) / (1 - beta1^step).
The kernel processes multiple weight tensors (via TensorSeq) in parallel using thread pools. It supports an optional boolean update signal that, when false, causes the update step to be skipped. Weights, gradients, and momentum tensors are updated in-place through output aliasing.
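The two update modes above can be sketched in scalar form. This is a hedged, framework-free illustration of the math, not ORT's actual code: the hyperparameter names (lr, beta1, beta2, epsilon, weight_decay) and the AdamWState struct are assumptions based on the standard AdamW algorithm, and the real kernel operates on whole tensors rather than scalars.

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>

// Illustrative moment state for a single scalar weight (names are assumed,
// not ORT's actual member names).
struct AdamWState {
  float m1 = 0.f;  // first moment estimate
  float m2 = 0.f;  // second moment estimate
};

// Mode 0 (PyTorch-style): decoupled weight decay applied before the step;
// bias correction applied individually to the two moments.
float AdamWMode0Step(float w, float g, AdamWState& s, int64_t step,
                     float lr, float beta1, float beta2,
                     float epsilon, float weight_decay) {
  s.m1 = beta1 * s.m1 + (1.f - beta1) * g;
  s.m2 = beta2 * s.m2 + (1.f - beta2) * g * g;
  const float alpha_correction = 1.f - std::pow(beta1, static_cast<float>(step));
  const float beta_correction = 1.f - std::pow(beta2, static_cast<float>(step));
  w -= lr * weight_decay * w;  // decay first
  const float denom = std::sqrt(s.m2 / beta_correction) + epsilon;
  return w - lr * (s.m1 / alpha_correction) / denom;
}

// Mode 1 (HuggingFace-style): bias correction folded into the learning rate;
// decoupled weight decay applied after the step.
float AdamWMode1Step(float w, float g, AdamWState& s, int64_t step,
                     float lr, float beta1, float beta2,
                     float epsilon, float weight_decay) {
  s.m1 = beta1 * s.m1 + (1.f - beta1) * g;
  s.m2 = beta2 * s.m2 + (1.f - beta2) * g * g;
  const float lr_corrected =
      lr * std::sqrt(1.f - std::pow(beta2, static_cast<float>(step))) /
      (1.f - std::pow(beta1, static_cast<float>(step)));
  w -= lr_corrected * s.m1 / (std::sqrt(s.m2) + epsilon);
  return w - lr * weight_decay * w;  // decay last
}
```

Note that for small weight_decay the two modes produce very similar steps; they differ in exactly when the decay is applied relative to the gradient update, which matters when composing with learning-rate schedules.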
Usage
This kernel is invoked at each training step when the AdamW optimizer is used. It updates all model weights and their momentum states in a single kernel call.
Code Reference
Source Location
- Repository: Microsoft_Onnxruntime
- File: orttraining/orttraining/training_ops/cpu/optimizer/adamw/adamw.cc
- Lines: 1-204
Signature
Status AdamWOptimizerBase::PrepareForCompute(OpKernelContext* ctx,
AdamWOptimizerBase::Prepare& prepare) const;
template <typename T>
Status AdamWOptimizer<T>::AdamWComputeMode0(Tensor& weight, Tensor& gradient,
Tensor& momentums_1, Tensor& momentums_2,
float lr, float alpha_correction, float beta_correction) const;
template <typename T>
Status AdamWOptimizer<T>::AdamWComputeMode1(Tensor& weight, Tensor& gradient,
Tensor& momentums_1, Tensor& momentums_2,
float lr, float lr_corrected) const;
template <typename T>
Status AdamWOptimizer<T>::Compute(OpKernelContext* ctx) const;
Import
#include "orttraining/orttraining/training_ops/cpu/optimizer/adamw/adamw.h"
I/O Contract
Inputs
| Name | Type | Required | Description |
|---|---|---|---|
| learning_rate | Tensor(float) | Yes | Scalar learning rate |
| step | Tensor(int64) | Yes | Current training step (for bias correction) |
| weights | TensorSeq(float) | Yes | Sequence of weight tensors |
| gradients | TensorSeq(float) | Yes | Sequence of gradient tensors |
| momentums_1 | TensorSeq(float) | Yes | First moment estimates (mean of gradients) |
| momentums_2 | TensorSeq(float) | Yes | Second moment estimates (mean of squared gradients) |
| update_signal | Tensor(bool) | No | If present and false, the update step is skipped |
Outputs
| Name | Type | Description |
|---|---|---|
| updated_flag | Tensor(bool) | Whether the update was performed |
| updated_weights | TensorSeq(float) | Updated weight tensors (in-place alias) |
| updated_momentums_1 | TensorSeq(float) | Updated first moments (in-place alias) |
| updated_momentums_2 | TensorSeq(float) | Updated second moments (in-place alias) |
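The I/O contract above can be illustrated with a hedged, framework-free sketch: a driver that walks a sequence of parameter buffers, honors the optional update signal, and returns the value that would feed the updated_flag output. All names (Param, AdamWStep) and the default hyperparameter values are illustrative assumptions, not ORT's API; the per-element math follows the Mode 1 update described earlier.

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>
#include <vector>

// Stand-in for one entry of the weights/gradients/momentums TensorSeq inputs
// (all four buffers share a shape, here flattened to 1-D).
struct Param {
  std::vector<float> w, g, m1, m2;
};

// Applies one Mode-1-style AdamW step in place to every tensor in the
// sequence, unless update_signal is present and false. The return value
// models the updated_flag output; w/m1/m2 model the aliased outputs.
bool AdamWStep(std::vector<Param>& params, float lr, int64_t step,
               const bool* update_signal,
               float beta1 = 0.9f, float beta2 = 0.999f,
               float epsilon = 1e-8f, float weight_decay = 1e-2f) {
  if (update_signal != nullptr && !*update_signal) {
    return false;  // skip: weights and moments pass through unchanged
  }
  const float lr_corrected =
      lr * std::sqrt(1.f - std::pow(beta2, static_cast<float>(step))) /
      (1.f - std::pow(beta1, static_cast<float>(step)));
  for (Param& p : params) {
    for (size_t i = 0; i < p.w.size(); ++i) {
      p.m1[i] = beta1 * p.m1[i] + (1.f - beta1) * p.g[i];
      p.m2[i] = beta2 * p.m2[i] + (1.f - beta2) * p.g[i] * p.g[i];
      p.w[i] -= lr_corrected * p.m1[i] / (std::sqrt(p.m2[i]) + epsilon);
      p.w[i] -= lr * weight_decay * p.w[i];
    }
  }
  return true;
}
```

The in-place mutation here mirrors the kernel's aliasing: the "outputs" are the same buffers as the inputs, so a skipped step leaves everything untouched and only the flag tells the caller what happened.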
Usage Examples
ONNX_OPERATOR_KERNEL_EX(
AdamWOptimizer, kMSDomain, 1, kCpuExecutionProvider,
(*KernelDefBuilder::Create())
.Alias(2, 1) /* Return updated weights in-place */
.Alias(4, 2) /* Return updated moment-1 in-place */
.Alias(5, 3) /* Return updated moment-2 in-place */
.TypeConstraint("T1", DataTypeImpl::GetTensorType<float>())
.TypeConstraint("T2", DataTypeImpl::GetTensorType<int64_t>())
.TypeConstraint("S_WEIGHT", DataTypeImpl::AllFixedSizeSequenceTensorTypes())
.TypeConstraint("S_GRAD", DataTypeImpl::AllFixedSizeSequenceTensorTypes())
.TypeConstraint("S_MOMENT", DataTypeImpl::AllFixedSizeSequenceTensorTypes()),
AdamWOptimizer<float>);