

Implementation:Microsoft Onnxruntime CUDA LAMB

From Leeroopedia


Knowledge Sources
Domains: Training, CUDA_Kernels
Last Updated: 2026-02-10 04:00 GMT

Overview

Concrete implementation of LAMB (Layer-wise Adaptive Moments optimizer for Batch training) in the ONNX Runtime CUDA training framework.

Description

Implements the LambOptimizer operator for CUDA, which applies one LAMB step to many weight groups in a single operator call. The operator accepts up to 1024 parameter groups and updates them in place through the input/output alias pairs returned by GenerateLambExtraAliasMapping. For each group, ComputeInternal reads the weights, the gradients, the first moment (momentum 1), the second moment (momentum 2) and, optionally, a mixed-precision copy of the weights. The computation proceeds in four steps:

1. Compute the update direction from the bias-corrected first and second moments, optionally removing the loss scale and clipping by the global gradient norm.
2. Compute the L2 norm of each group's weights and of its update direction with reduce_square_sum.
3. Scale the update by the LAMB trust ratio, clamped to the range [ratio_min, ratio_max].
4. Update the weights and, if a mixed-precision copy is present, write the new values to it.

The implementation uses multi-tensor functors, so each GPU launch processes a batch of tensors instead of one tensor at a time. Kernels are registered for several mixed-precision type combinations, including float and MLFloat16.
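For reference, the four steps in the description above can be written in the standard LAMB form for a single group with weights w, gradient g (already unscaled and clipped), first moment m and second moment v at update count t. This is a sketch of the textbook update, not a transcription of the CUDA kernels: the symbols β1, β2, ε, λ (weight decay) and η (the eta input), as well as where ε and the bias correction enter, are assumptions and may differ in detail from ONNX Runtime.

```latex
\begin{aligned}
m_t &= \beta_1 m_{t-1} + (1-\beta_1)\, g_t, \qquad
v_t = \beta_2 v_{t-1} + (1-\beta_2)\, g_t^{2} \\
\hat m_t &= \frac{m_t}{1-\beta_1^{t}}, \qquad
\hat v_t = \frac{v_t}{1-\beta_2^{t}} \\
d_t &= \frac{\hat m_t}{\sqrt{\hat v_t} + \epsilon} + \lambda\, w_{t-1} \\
r_t &= \operatorname{clip}\!\left(\frac{\lVert w_{t-1}\rVert_2}{\lVert d_t\rVert_2},\ \mathrm{ratio\_min},\ \mathrm{ratio\_max}\right) \\
w_t &= w_{t-1} - \eta\, r_t\, d_t
\end{aligned}
```

The first three lines correspond to step 1, the two norms in the ratio to step 2, the clipped ratio to step 3 and the last line to step 4.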

Usage

Invoked during the optimization step of distributed training when the LAMB optimizer is selected. LAMB is commonly used for large-batch training of BERT and similar models.

Code Reference

Source Location

Signature

std::vector<std::pair<int, int>> GenerateLambExtraAliasMapping();

template <typename T1, typename T2, typename T3, typename T4, typename T_GRAD, typename T_MIXED_PRECISION_FP>
class LambOptimizer : public CudaKernel {
 public:
  // Runs the LAMB stages for every weight group passed to the operator.
  Status ComputeInternal(OpKernelContext* ctx) const override;
};
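The following sketch shows the kind of (input index, output index) pairs an alias-mapping function like GenerateLambExtraAliasMapping returns, which is what allows the per-group outputs to reuse their input buffers. The index layout is an assumption inferred from the I/O tables on this page: five leading non-group inputs, one leading non-group output (update_count_new), and five tensors per group on each side. SketchLambAliasMapping and the constants are hypothetical names, not ONNX Runtime code.

```cpp
#include <cassert>
#include <utility>
#include <vector>

// Hypothetical layout constants inferred from this page's I/O tables:
// five leading non-group inputs (eta, loss_scale, gradient_norm, do_update,
// update_count) and one leading non-group output (update_count_new).
constexpr int kLeadingInputs = 5;
constexpr int kLeadingOutputs = 1;
// Each group has [w, g, m1, m2, w_mixed_precision] on both sides.
constexpr int kGroupSize = 5;
constexpr int kMaxGroups = 1024;

// Returns (input_index, output_index) pairs stating that each per-group output
// may share the buffer of the matching input, i.e. an in-place update.
std::vector<std::pair<int, int>> SketchLambAliasMapping(int group_count = kMaxGroups) {
  assert(group_count >= 0 && group_count <= kMaxGroups);
  std::vector<std::pair<int, int>> pairs;
  pairs.reserve(group_count * kGroupSize);
  for (int group = 0; group < group_count; ++group) {
    const int first_input = kLeadingInputs + group * kGroupSize;
    const int first_output = kLeadingOutputs + group * kGroupSize;
    for (int k = 0; k < kGroupSize; ++k) {
      pairs.emplace_back(first_input + k, first_output + k);
    }
  }
  return pairs;
}
```

With a single group, SketchLambAliasMapping(1) yields the five pairs (5, 1) through (9, 5). Generating pairs for the maximum group count up front means the mapping does not depend on how many groups a particular model actually uses.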

Import

#include "orttraining/training_ops/cuda/optimizer/lamb.h"

I/O Contract

Inputs

Name | Type | Required | Description
eta | Tensor(T1) | Yes | Learning rate
loss_scale | Tensor(T3) | No | Loss scale factor for mixed precision
gradient_norm | Tensor(T_GRAD_NORM) | No | Global gradient norm for clipping
do_update | Tensor(bool) | No | Whether to perform the update (CPU input)
update_count | Tensor(T4) | Yes | Iteration count for bias correction
weights_i | Tensor(T2) | Yes | Parameter weights per group (up to 1024 groups)
gradients_i | Tensor(T_GRAD) | Yes | Gradients per group
moment_1_i | Tensor(T3) | Yes | First moment per group
moment_2_i | Tensor(T3) | Yes | Second moment per group
mixed_precision_weights_i | Tensor(T_MIXED_PRECISION_FP) | No | Mixed-precision weight copy per group
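The three optional scalar inputs (loss_scale, gradient_norm, do_update) decide how the raw gradients are rescaled and whether a step happens at all. The sketch below is one plausible way to combine them and requires C++17 for std::optional. It assumes that gradient_norm is measured on the still loss-scaled gradients and that clipping uses a threshold max_norm; both are assumptions, and SketchGradientDivisor and SketchShouldUpdate are hypothetical helpers, not ONNX Runtime functions.

```cpp
#include <cassert>
#include <optional>

// Hypothetical helper: returns the value the raw (loss-scaled) gradients are
// divided by before the moment update. Both optional inputs may be absent.
float SketchGradientDivisor(std::optional<float> loss_scale,
                            std::optional<float> gradient_norm,
                            float max_norm) {
  assert(max_norm > 0.0f);
  // An absent loss_scale behaves like a scale of 1.
  float divisor = loss_scale.value_or(1.0f);
  assert(divisor > 0.0f);
  if (gradient_norm) {
    // Assumption: gradient_norm was computed on the scaled gradients, so the
    // loss scale is removed first to obtain the true global norm.
    const float true_norm = *gradient_norm / divisor;
    if (true_norm > max_norm) {
      // Enlarge the divisor so the clipped global norm equals max_norm.
      divisor *= true_norm / max_norm;
    }
  }
  return divisor;
}

// Hypothetical gate: an absent do_update means "update"; false skips the step.
// The source notes do_update is a CPU input, so it can be read on the host.
bool SketchShouldUpdate(std::optional<bool> do_update) {
  return do_update.value_or(true);
}
```

For example, with loss_scale = 128, a scaled gradient norm of 256 and max_norm = 1, the true norm is 2, so the divisor becomes 256 and the unscaled, clipped gradients have norm 1.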

Outputs

Name | Type | Description
update_count_new | Tensor(T4) | Updated iteration count
weights_new_i | Tensor(T2) | Updated weights per group (in-place)
gradients_new_i | Tensor(T_GRAD) | Updated gradients per group (in-place)
moment_1_new_i | Tensor(T3) | Updated first moment per group (in-place)
moment_2_new_i | Tensor(T3) | Updated second moment per group (in-place)
mixed_precision_weights_new_i | Tensor(T_MIXED_PRECISION_FP) | Updated mixed-precision weights per group (in-place)

Usage Examples

// Registration for float weights, float gradients, MLFloat16 mixed precision
REGISTER_LAMB_KERNEL_TYPED(float, float, float, float, float, MLFloat16)
// The multi-tensor LAMB update then runs in three stages:
// Stage 1: LambComputeDirection (compute d, update m1, m2)
// Stage 2: reduce_square_sum for w_norm and d_norm
// Stage 3: LambUpdate (apply trust ratio and update weights)
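To make the three stages listed above concrete, here is a CPU reference sketch that applies them to a list of weight groups. It is an illustration under stated assumptions, not the CUDA implementation: the struct and function names are hypothetical, the hyperparameter defaults are illustrative, the ε placement and weight-decay term follow the textbook LAMB form, gradients are assumed to be already unscaled and clipped, and a trust ratio of 1 is assumed when either norm is zero. It requires C++17 for std::clamp.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical hyperparameters; defaults are illustrative, not ONNX Runtime's.
struct SketchLambParams {
  float beta1 = 0.9f;
  float beta2 = 0.999f;
  float epsilon = 1e-6f;
  float weight_decay = 0.01f;
  float ratio_min = 0.0f;
  float ratio_max = 10.0f;
};

// One weight group. Gradients are assumed to be already unscaled and clipped.
struct SketchLambGroup {
  std::vector<float> w, g, m1, m2;
};

// Applies one LAMB step to every group, one stage at a time.
void SketchLambStep(std::vector<SketchLambGroup>& groups, float eta, int step,
                    const SketchLambParams& p) {
  assert(step >= 1 && p.ratio_min <= p.ratio_max);
  const float bias1 = 1.0f - std::pow(p.beta1, static_cast<float>(step));
  const float bias2 = 1.0f - std::pow(p.beta2, static_cast<float>(step));

  // Stage 1: update both moments in place and compute the direction d.
  std::vector<std::vector<float>> d(groups.size());
  for (std::size_t i = 0; i < groups.size(); ++i) {
    SketchLambGroup& grp = groups[i];
    d[i].resize(grp.w.size());
    for (std::size_t j = 0; j < grp.w.size(); ++j) {
      grp.m1[j] = p.beta1 * grp.m1[j] + (1.0f - p.beta1) * grp.g[j];
      grp.m2[j] = p.beta2 * grp.m2[j] + (1.0f - p.beta2) * grp.g[j] * grp.g[j];
      const float m_hat = grp.m1[j] / bias1;
      const float v_hat = grp.m2[j] / bias2;
      d[i][j] = m_hat / (std::sqrt(v_hat) + p.epsilon) + p.weight_decay * grp.w[j];
    }
  }

  // Stage 2: per-group L2 norms from sums of squares (the reduce_square_sum role).
  std::vector<float> w_norm(groups.size()), d_norm(groups.size());
  for (std::size_t i = 0; i < groups.size(); ++i) {
    double w_sq = 0.0, d_sq = 0.0;
    for (std::size_t j = 0; j < d[i].size(); ++j) {
      w_sq += static_cast<double>(groups[i].w[j]) * groups[i].w[j];
      d_sq += static_cast<double>(d[i][j]) * d[i][j];
    }
    w_norm[i] = static_cast<float>(std::sqrt(w_sq));
    d_norm[i] = static_cast<float>(std::sqrt(d_sq));
  }

  // Stage 3: clamp the trust ratio and apply the scaled update to the weights.
  for (std::size_t i = 0; i < groups.size(); ++i) {
    float ratio = 1.0f;  // assumed fallback when either norm is zero
    if (w_norm[i] > 0.0f && d_norm[i] > 0.0f) {
      ratio = std::clamp(w_norm[i] / d_norm[i], p.ratio_min, p.ratio_max);
    }
    for (std::size_t j = 0; j < d[i].size(); ++j) {
      groups[i].w[j] -= eta * ratio * d[i][j];
    }
  }
}
```

With w = {3, 4}, g = {1, 1}, zero moments, step 1, ε = 0 and no weight decay, the direction is {1, 1}, the trust ratio is 5 / √2 ≈ 3.536 and eta = 0.1 moves the weights to about {2.646, 3.646}. The sketch finishes each stage for all groups before starting the next because stage 3 needs the complete per-group norms from stage 2; the CUDA version, according to the description, batches each stage across many tensors with multi-tensor functors instead of looping.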
