
Implementation:Microsoft Onnxruntime CPU AdamW

From Leeroopedia


Knowledge Sources
Domains Training, CPU_Kernels
Last Updated 2026-02-10 04:00 GMT

Overview

Implementation page for the AdamW optimizer kernel on CPU in the ONNX Runtime training framework.

Description

This file implements the AdamWOptimizer kernel, which performs the AdamW (Adam with decoupled weight decay) optimization step. It supports two computation modes:

Mode 0 (PyTorch-style): Weight decay is applied before the weight update, and bias correction is applied individually to the first and second moments. The update is: w = w - lr * w * weight_decay, then w = w - lr * m1 / (alpha_correction * (sqrt(m2/beta_correction) + epsilon)).

Mode 1 (HuggingFace-style): Bias correction is applied to the learning rate, and weight decay is applied after the weight update. The update is: w = w - lr_corrected * m1 / (sqrt(m2) + epsilon) - lr * weight_decay * w.
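The two update rules above can be sketched as plain scalar C++, independent of the ONNX Runtime kernel. The moment recurrences are the standard Adam updates; the function and parameter names here are illustrative, and beta1, beta2, and epsilon are assumed to be kernel attributes:

```cpp
#include <cassert>
#include <cmath>

// Optimizer state for one scalar weight: first and second moment estimates.
struct AdamWState { float m1 = 0.f, m2 = 0.f; };

// Shared moment update: m1 tracks the gradient mean, m2 the squared-gradient mean.
void UpdateMoments(AdamWState& s, float g, float beta1, float beta2) {
  s.m1 = beta1 * s.m1 + (1.f - beta1) * g;
  s.m2 = beta2 * s.m2 + (1.f - beta2) * g * g;
}

// Mode 0 (PyTorch-style): decay the weight first, then apply the update with
// bias correction on each moment individually.
float StepMode0(float w, float g, AdamWState& s, float lr, float weight_decay,
                float alpha_correction, float beta_correction,
                float beta1, float beta2, float eps) {
  UpdateMoments(s, g, beta1, beta2);
  w -= lr * w * weight_decay;
  w -= lr * s.m1 / (alpha_correction * (std::sqrt(s.m2 / beta_correction) + eps));
  return w;
}

// Mode 1 (HuggingFace-style): bias correction folded into the learning rate,
// weight decay applied after the update.
float StepMode1(float w, float g, AdamWState& s, float lr, float lr_corrected,
                float weight_decay, float beta1, float beta2, float eps) {
  UpdateMoments(s, g, beta1, beta2);
  w -= lr_corrected * s.m1 / (std::sqrt(s.m2) + eps);
  w -= lr * weight_decay * w;
  return w;
}
```

For step 1 with beta1 = 0.9, the bias corrections are alpha_correction = 0.1 and beta_correction = 0.001, which is why the first update is not dominated by the near-zero initial moments.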

The kernel processes multiple weight tensors (via TensorSeq) in parallel using a thread pool. It supports an optional boolean update signal that, when false, skips the update step. Weights, gradients, and momentum tensors are updated in-place through aliasing.
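The per-tensor parallel dispatch can be sketched with std::thread in place of the runtime's thread pool. This is a simplified model, not the kernel's actual scheduling; the Mode 1 arithmetic is folded inline, and all names are illustrative:

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <thread>
#include <vector>

// One weight tensor and its optimizer state, analogous to one TensorSeq element.
struct ParamGroup {
  std::vector<float> w, g, m1, m2;
};

// Apply a Mode 1 (HuggingFace-style) step to every group, one thread per tensor.
// All four vectors in each group are mutated in-place, mirroring the aliasing
// behavior of the kernel.
void ParallelAdamWStep(std::vector<ParamGroup>& groups, float lr,
                       float lr_corrected, float weight_decay,
                       float beta1, float beta2, float eps) {
  std::vector<std::thread> workers;
  for (ParamGroup& p : groups) {
    workers.emplace_back([&p, lr, lr_corrected, weight_decay, beta1, beta2, eps] {
      for (std::size_t i = 0; i < p.w.size(); ++i) {
        p.m1[i] = beta1 * p.m1[i] + (1.f - beta1) * p.g[i];
        p.m2[i] = beta2 * p.m2[i] + (1.f - beta2) * p.g[i] * p.g[i];
        p.w[i] -= lr_corrected * p.m1[i] / (std::sqrt(p.m2[i]) + eps);
        p.w[i] -= lr * weight_decay * p.w[i];
      }
    });
  }
  for (std::thread& t : workers) t.join();
}
```

Spawning one thread per tensor is only for illustration; a production kernel reuses pooled threads and may further partition large tensors.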

Usage

This kernel is invoked at each training step when the AdamW optimizer is used. It updates all model weights and their momentum states in a single kernel call.

Code Reference

Source Location

Signature

Status AdamWOptimizerBase::PrepareForCompute(OpKernelContext* ctx,
                                             AdamWOptimizerBase::Prepare& prepare) const;

template <typename T>
Status AdamWOptimizer<T>::AdamWComputeMode0(Tensor& weight, Tensor& gradient,
    Tensor& momentums_1, Tensor& momentums_2,
    float lr, float alpha_correction, float beta_correction) const;

template <typename T>
Status AdamWOptimizer<T>::AdamWComputeMode1(Tensor& weight, Tensor& gradient,
    Tensor& momentums_1, Tensor& momentums_2,
    float lr, float lr_corrected) const;

template <typename T>
Status AdamWOptimizer<T>::Compute(OpKernelContext* ctx) const;

Import

#include "orttraining/orttraining/training_ops/cpu/optimizer/adamw/adamw.h"

I/O Contract

Inputs

Name Type Required Description
learning_rate Tensor(float) Yes Scalar learning rate
step Tensor(int64) Yes Current training step (for bias correction)
weights TensorSeq(float) Yes Sequence of weight tensors
gradients TensorSeq(float) Yes Sequence of gradient tensors
momentums_1 TensorSeq(float) Yes First moment estimates (mean of gradients)
momentums_2 TensorSeq(float) Yes Second moment estimates (mean of squared gradients)
update_signal Tensor(bool) No Optional signal to skip update

Outputs

Name Type Description
updated_flag Tensor(bool) Whether the update was performed
updated_weights TensorSeq(float) Updated weight tensors (in-place alias)
updated_momentums_1 TensorSeq(float) Updated first moments (in-place alias)
updated_momentums_2 TensorSeq(float) Updated second moments (in-place alias)
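The step input drives the bias corrections referenced in the mode descriptions above. The formulas below are the standard Adam bias-correction terms; that the kernel computes its corrected learning rate exactly this way is an assumption here, and the names are illustrative:

```cpp
#include <cassert>
#include <cmath>
#include <cstdint>

// Standard Adam bias-correction terms for a given (1-based) training step.
float AlphaCorrection(int64_t step, float beta1) {
  return 1.f - std::pow(beta1, static_cast<float>(step));
}
float BetaCorrection(int64_t step, float beta2) {
  return 1.f - std::pow(beta2, static_cast<float>(step));
}

// Mode 1 folds both corrections into the learning rate:
// lr_corrected = lr * sqrt(1 - beta2^step) / (1 - beta1^step).
float CorrectedLr(float lr, int64_t step, float beta1, float beta2) {
  return lr * std::sqrt(BetaCorrection(step, beta2)) / AlphaCorrection(step, beta1);
}
```

Both corrections approach 1 as the step count grows, so bias correction mostly affects the first few optimizer steps.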

Usage Examples

ONNX_OPERATOR_KERNEL_EX(
    AdamWOptimizer, kMSDomain, 1, kCpuExecutionProvider,
    (*KernelDefBuilder::Create())
        .Alias(2, 1)  /* Return updated weights in-place */
        .Alias(4, 2)  /* Return updated moment-1 in-place */
        .Alias(5, 3)  /* Return updated moment-2 in-place */
        .TypeConstraint("T1", DataTypeImpl::GetTensorType<float>())
        .TypeConstraint("T2", DataTypeImpl::GetTensorType<int64_t>())
        .TypeConstraint("S_WEIGHT", DataTypeImpl::AllFixedSizeSequenceTensorTypes())
        .TypeConstraint("S_GRAD", DataTypeImpl::AllFixedSizeSequenceTensorTypes())
        .TypeConstraint("S_MOMENT", DataTypeImpl::AllFixedSizeSequenceTensorTypes()),
    AdamWOptimizer<float>);
