Implementation:Microsoft Onnxruntime CUDA AdamW

From Leeroopedia
Revision as of 15:45, 16 February 2026 by Admin (talk | contribs) (Auto-imported from implementations/Microsoft_Onnxruntime_CUDA_AdamW.md)


Knowledge Sources
Domains: Training, CUDA_Kernels
Last Updated: 2026-02-10 04:00 GMT

Overview

A concrete implementation of the AdamW optimizer (Adam with decoupled weight decay) for the CUDA execution provider in the ONNX Runtime training framework.
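To make "decoupled weight decay" concrete, here is a minimal CPU reference of one textbook AdamW step in C++. It is a sketch under stated assumptions, not the CUDA kernel: the struct AdamWConfig and the function AdamWStepReference are hypothetical names, the default values are illustrative, and the kernel's adam mode (which may order or scale these operations differently) is not modeled. The field names beta1, beta2, epsilon, weight_decay, and correct_bias follow the parameters listed in the Description section.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical configuration for illustration; defaults are common textbook
// values, not values read from the ONNX Runtime source.
struct AdamWConfig {
  float lr = 1e-3f;
  float beta1 = 0.9f;    // called "alpha" by the operator
  float beta2 = 0.999f;  // called "beta" by the operator
  float epsilon = 1e-8f;
  float weight_decay = 0.0f;
  bool correct_bias = true;
};

// One textbook AdamW step over a single flat tensor, updating w, m1, m2 in place.
// Assumption: decay is applied to the weights first, then the Adam update.
void AdamWStepReference(const AdamWConfig& cfg, std::int64_t step,
                        std::vector<float>& w, const std::vector<float>& g,
                        std::vector<float>& m1, std::vector<float>& m2) {
  assert(step >= 1);
  assert(w.size() == g.size() && w.size() == m1.size() && w.size() == m2.size());

  // Bias-correction terms; both stay at 1 when correct_bias is off.
  float bc1 = 1.0f;
  float bc2 = 1.0f;
  if (cfg.correct_bias) {
    bc1 = 1.0f - std::pow(cfg.beta1, static_cast<float>(step));
    bc2 = 1.0f - std::pow(cfg.beta2, static_cast<float>(step));
  }

  for (std::size_t i = 0; i < w.size(); ++i) {
    // Decoupled weight decay: shrink the weight directly, not via the gradient.
    w[i] *= 1.0f - cfg.lr * cfg.weight_decay;

    // Exponential moving averages of the gradient and the squared gradient.
    m1[i] = cfg.beta1 * m1[i] + (1.0f - cfg.beta1) * g[i];
    m2[i] = cfg.beta2 * m2[i] + (1.0f - cfg.beta2) * g[i] * g[i];

    const float m_hat = m1[i] / bc1;
    const float v_hat = m2[i] / bc2;
    w[i] -= cfg.lr * m_hat / (std::sqrt(v_hat) + cfg.epsilon);
  }
}
```

Because the decay multiplies the weight directly, it is not rescaled by the second-moment estimate, which is the difference from Adam with L2 regularization folded into the gradient.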

Description

Implements the AdamWOptimizer operator for CUDA. The kernel applies the AdamW update through a multi-tensor functor, AdamWMTAFunctor (MTA stands for Multi-Tensor Apply), which processes all parameter groups in a single fused kernel launch instead of launching one kernel per tensor. The chunk size (MTA_ADAMW_CHUNK_SIZE) and group size (MTA_ADAMW_GROUP_SIZE) are configurable.

The update is governed by alpha (the first-moment decay rate, beta1), beta (the second-moment decay rate, beta2), epsilon, the learning rate, weight decay, the adam mode, the correct_bias flag, and the current step count.

Weights, moment-1, and moment-2 are sequence tensor types, and the kernel aliases each of them to its corresponding output so that they can be updated in place. An optional update signal (a CPU input) controls whether the optimization step is performed. When an output is not aliased to its input, the results are copied to the output buffer.
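The chunking idea behind a multi-tensor apply can be sketched on the host side as follows. This is an illustration only: ChunkRef and BuildChunks are hypothetical names, the real AdamWMTAFunctor and its launch machinery are not reproduced, and the role of MTA_ADAMW_GROUP_SIZE is not modeled. The sketch only shows how tensors of different sizes can be flattened into one list of fixed-size work items that a single launch could iterate over.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical work item: one contiguous chunk of one tensor.
struct ChunkRef {
  std::size_t tensor_index;  // which tensor in the sequence
  std::size_t offset;        // first element covered by this chunk
  std::size_t length;        // number of elements (<= chunk_size)
};

// Splits every tensor into chunks of at most chunk_size elements and returns
// them as one flat work list. Empty tensors contribute no chunks.
std::vector<ChunkRef> BuildChunks(const std::vector<std::size_t>& tensor_sizes,
                                  std::size_t chunk_size) {
  assert(chunk_size > 0);
  std::vector<ChunkRef> chunks;
  for (std::size_t t = 0; t < tensor_sizes.size(); ++t) {
    for (std::size_t off = 0; off < tensor_sizes[t]; off += chunk_size) {
      const std::size_t len = std::min(chunk_size, tensor_sizes[t] - off);
      chunks.push_back({t, off, len});
    }
  }
  return chunks;
}
```

For example, calling BuildChunks with tensor sizes {5, 3, 0, 8} and a chunk size of 4 yields five work items; on a GPU each item would typically map to one thread block, so many small tensors share one launch.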

Usage

The kernel is invoked during the optimization step of training when the AdamW optimizer is selected. AdamW is commonly used in modern transformer training.

Code Reference

Source Location

Signature

class AdamWOptimizer : public CudaKernel, public AdamWOptimizerBase {
 public:
  // Entry point called by the CUDA execution provider for each optimizer step.
  Status ComputeInternal(OpKernelContext* ctx) const override;
};

Import

#include "orttraining/training_ops/cuda/optimizer/adamw/adamw.h"

I/O Contract

Inputs

Name           Type              Required  Description
learning_rate  Tensor(float)     Yes       Learning rate (CPU memory)
step           Tensor(int64_t)   Yes       Step count (CPU memory)
weights        Sequence(Tensor)  Yes       Parameter weights
gradients      Sequence(Tensor)  Yes       Gradients
momentums_1    Sequence(Tensor)  Yes       First moments
momentums_2    Sequence(Tensor)  Yes       Second moments
update_signal  Tensor(bool)      No        Whether to perform update (CPU memory)

Outputs

Name                 Type              Description
updated_flag         Tensor(bool)      Whether the update was performed (CPU memory)
updated_weights      Sequence(Tensor)  Updated weights (in-place)
updated_momentums_1  Sequence(Tensor)  Updated first moments (in-place)
updated_momentums_2  Sequence(Tensor)  Updated second moments (in-place)
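The interplay between update_signal, updated_flag, and the in-place outputs can be sketched on the host as below. This is a simplified model, not the kernel's control flow: RunGatedStep and TensorSeq are hypothetical names, a null update_signal pointer stands in for the optional input being absent, and two behaviors are assumptions rather than facts taken from the source (a missing signal means "update", and a separate output buffer is filled whether or not the update ran).

```cpp
#include <cassert>
#include <functional>
#include <vector>

// Hypothetical stand-in for a Sequence(Tensor) of float tensors.
using TensorSeq = std::vector<std::vector<float>>;

// Returns the value that would be reported through updated_flag.
// update_signal:   nullptr models the optional input being absent.
// separate_output: nullptr models the in-place (aliased) case.
bool RunGatedStep(const bool* update_signal, TensorSeq& weights,
                  TensorSeq* separate_output,
                  const std::function<void(TensorSeq&)>& apply_update) {
  assert(static_cast<bool>(apply_update));

  // Assumption: with no signal provided, the step is performed.
  const bool do_update = (update_signal == nullptr) ? true : *update_signal;
  if (do_update) {
    apply_update(weights);  // mutates the input buffer in place
  }

  // When the output is not aliased to the input, copy the (possibly
  // unchanged) state into the distinct output buffer.
  if (separate_output != nullptr) {
    *separate_output = weights;
  }
  return do_update;
}
```

A caller that skips the step by passing a false signal still gets a valid output sequence holding the unchanged weights, which keeps downstream consumers of updated_weights well defined under the assumptions above.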

Usage Examples

ONNX_OPERATOR_KERNEL_EX(
    AdamWOptimizer, kMSDomain, 1, kCudaExecutionProvider,
    (*KernelDefBuilder::Create())
        .InputMemoryType(OrtMemTypeCPUInput, 0)   // learning_rate
        .InputMemoryType(OrtMemTypeCPUInput, 1)   // step
        .InputMemoryType(OrtMemTypeCPUInput, 6)   // update_signal
        .OutputMemoryType(OrtMemTypeCPUOutput, 0)  // updated_flag
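        // Alias(input, output): the output may reuse the input's buffer.
        .Alias(2, 1)   // weights     -> updated_weights
        .Alias(4, 2)   // momentums_1 -> updated_momentums_1
        .Alias(5, 3),  // momentums_2 -> updated_momentums_2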
    AdamWOptimizer);
